diff --git a/.gitattributes b/.gitattributes index 1815c810bcab07646132a6dc7dd1e391dd96727a..267d0e5d84cba76ed182c682bec3156e7d39352b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -86,3 +86,18 @@ MLPY/Lib/site-packages/PIL/_imaging.cp39-win_amd64.pyd filter=lfs diff=lfs merge MLPY/Lib/site-packages/PIL/_imagingft.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text MLPY/Lib/site-packages/pythonwin/mfc140u.dll filter=lfs diff=lfs merge=lfs -text MLPY/Lib/site-packages/pythonwin/win32ui.pyd filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/bin/fbgemm.dll filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/bin/protoc.exe filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/dnnl.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/libprotobuf-lite.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/libprotobuf.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/torch_cpu.dll filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/torch_cpu.lib filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/torch_python.dll filter=lfs diff=lfs merge=lfs -text +MLPY/Lib/site-packages/torch/lib/XNNPACK.lib filter=lfs diff=lfs merge=lfs -text diff --git a/MLPY/Lib/site-packages/torch/_C.cp39-win_amd64.pyd b/MLPY/Lib/site-packages/torch/_C.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..46a4a4ab50e1e50d15f4b62676d0be4e47217ebd Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_C.cp39-win_amd64.pyd differ diff --git a/MLPY/Lib/site-packages/torch/_C/_VariableFunctions.pyi b/MLPY/Lib/site-packages/torch/_C/_VariableFunctions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e7bc45da38b2228706f8e353adb5af335d22eae3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_VariableFunctions.pyi @@ -0,0 +1,25648 @@ +# @generated from torch/_C/_VariableFunctions.pyi.in +# mypy: disable-error-code="type-arg" + +import builtins +from typing import ( + Any, + Callable, + ContextManager, + Iterator, + List, + Literal, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + TypeVar, + Union, +) + +import torch +from torch import contiguous_format, Generator, inf, memory_format, strided, SymInt, Tensor +from torch.types import ( + _bool, + _complex, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Device, + Number, +) + +from torch._prims_common import DeviceLikeType + +@overload +def __and__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __and__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Tensor) -> Tensor: ... 
+@overload +def __rshift__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +def _adaptive_avg_pool2d(input: Tensor, output_size: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]) -> Tensor: ... +def _adaptive_avg_pool3d(input: Tensor, output_size: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]) -> Tensor: ... +def _add_batch_dim(input: Tensor, batch_dim: _int, level: _int) -> Tensor: ... +@overload +def _add_relu(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def _add_relu(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def _add_relu_(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def _add_relu_(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +def _addmm_activation(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, use_gelu: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def _aminmax(input: Tensor) -> Tuple[Tensor, Tensor]: ... +@overload +def _aminmax(input: Tensor, dim: _int, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: ... +def _amp_foreach_non_finite_check_and_unscale_(self: Union[Tuple[Tensor, ...], List[Tensor]], found_inf: Tensor, inv_scale: Tensor) -> None: ... +def _amp_update_scale_(input: Tensor, growth_tracker: Tensor, found_inf: Tensor, scale_growth_factor: _float, scale_backoff_factor: _float, growth_interval: _int) -> Tensor: ... +@overload +def _assert_async(input: Tensor) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + ... +@overload +def _assert_async(input: Tensor, assert_msg: str) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + ... 
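# A minimal usage sketch for the `_assert_async` stubs declared above, assuming a
# plain CPU run; the tensor values below are illustrative and are not part of the
# generated stub file.
import torch

flag = torch.tensor(1)            # any one-element, nonzero tensor passes
torch._assert_async(flag)         # on CPU this checks immediately, like `assert flag`

# torch._assert_async(torch.tensor(0))   # would fail; on CUDA, per the docstring,
#                                        # the failure may only surface at a later
#                                        # kernel launch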
+def _assert_scalar(self: Union[Number, _complex], assert_msg: str) -> None: ... +def _assert_tensor_metadata(a: Tensor, size: Optional[Sequence[Union[_int, SymInt]]] = None, stride: Optional[Sequence[Union[_int, SymInt]]] = None, dtype: Optional[_dtype] = None) -> None: ... +def _batch_norm_impl_index(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, _int]: ... +def _cast_Byte(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Char(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Double(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Float(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Half(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Int(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Long(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Short(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _choose_qparams_per_tensor(input: Tensor, reduce_range: _bool = False) -> Tuple[_float, _int]: ... +def _chunk_cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int, num_chunks: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _coalesce(input: Tensor) -> Tensor: ... +def _compute_linear_combination(input: Tensor, coefficients: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _conj(input: Tensor) -> Tensor: ... +def _conj_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _conj_physical(input: Tensor) -> Tensor: ... +def _convert_indices_from_coo_to_csr(input: Tensor, size: _int, *, out_int32: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +def _convert_indices_from_csr_to_coo(crow_indices: Tensor, col_indices: Tensor, *, out_int32: _bool = False, transpose: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +def _convert_weight_to_int4pack(input: Tensor, innerKTiles: _int) -> Tensor: ... +@overload +def _convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: _size, groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, cudnn_enabled: _bool) -> Tensor: ... +@overload +def _convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, cudnn_enabled: _bool, allow_tf32: _bool) -> Tensor: ... +def _convolution_mode(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: str, dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def _copy_from(input: Tensor, dst: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _copy_from_and_resize(input: Tensor, dst: Tensor) -> Tensor: ... +def _cslt_compress(input: Tensor) -> Tensor: ... +def _cslt_sparse_mm(compressed_A: Tensor, dense_B: Tensor, bias: Optional[Tensor] = None, alpha: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, transpose_result: _bool = False, alg_id: _int = 0) -> Tensor: ... 
+def _cslt_sparse_mm_search(compressed_A: Tensor, dense_B: Tensor, bias: Optional[Tensor] = None, alpha: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, transpose_result: _bool = False) -> _int: ... +@overload +def _ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int = 0, zero_infinity: _bool = False) -> Tuple[Tensor, Tensor]: ... +@overload +def _ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int = 0, zero_infinity: _bool = False) -> Tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int, deterministic: _bool, zero_infinity: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int, deterministic: _bool, zero_infinity: _bool) -> Tuple[Tensor, Tensor]: ... +def _cudnn_init_dropout_state(dropout: _float, train: _bool, dropout_seed: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _cudnn_rnn(input: Tensor, weight: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, weight_buf: Optional[Tensor], hx: Tensor, cx: Optional[Tensor], mode: _int, hidden_size: Union[_int, SymInt], proj_size: Union[_int, SymInt], num_layers: _int, batch_first: _bool, dropout: _float, train: _bool, bidirectional: _bool, batch_sizes: Sequence[Union[_int, SymInt]], dropout_state: Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _cudnn_rnn_flatten_weight(weight_arr: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, input_size: Union[_int, SymInt], mode: _int, hidden_size: Union[_int, SymInt], proj_size: Union[_int, SymInt], num_layers: _int, batch_first: _bool, bidirectional: _bool) -> Tensor: ... +def _cufft_clear_plan_cache(device_index: _int) -> None: ... +def _cufft_get_plan_cache_max_size(device_index: _int) -> _int: ... +def _cufft_get_plan_cache_size(device_index: _int) -> _int: ... +def _cufft_set_plan_cache_max_size(device_index: _int, max_size: _int) -> None: ... +def _cummax_helper(input: Tensor, values: Tensor, indices: Tensor, dim: _int) -> None: ... +def _cummin_helper(input: Tensor, values: Tensor, indices: Tensor, dim: _int) -> None: ... +def _debug_has_internal_overlap(input: Tensor) -> _int: ... +def _dim_arange(like: Tensor, dim: _int) -> Tensor: ... +def _dirichlet_grad(x: Tensor, alpha: Tensor, total: Tensor) -> Tensor: ... +def _disable_functionalization(): ... +@overload +def _efficientzerotensor(size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _efficientzerotensor(*size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... 
+def _embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def _embedding_bag_forward_only(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def _empty_affine_quantized(size: Sequence[Union[_int, SymInt]], *, scale: _float = 1, zero_point: _int = 0, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_affine_quantized(*size: _int, scale: _float = 1, zero_point: _int = 0, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized(size: Sequence[Union[_int, SymInt]], *, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized(*size: _int, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _enable_functionalization(*, reapply_views: _bool = False): ... +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: ... +def _fake_quantize_learnable_per_channel_affine(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int, quant_min: _int, quant_max: _int, grad_factor: _float = 1.0) -> Tensor: ... +def _fake_quantize_learnable_per_tensor_affine(input: Tensor, scale: Tensor, zero_point: Tensor, quant_min: _int, quant_max: _int, grad_factor: _float = 1.0) -> Tensor: ... +def _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(input: Tensor, scale: Tensor, zero_point: Tensor, fake_quant_enabled: Tensor, quant_min: _int, quant_max: _int) -> torch.return_types._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: ... +def _fft_c2c(input: Tensor, dim: Sequence[Union[_int, SymInt]], normalization: _int, forward: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _fft_c2r(input: Tensor, dim: _size, normalization: _int, last_dim_size: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: ... +def _fft_r2c(input: Tensor, dim: _size, normalization: _int, onesided: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _fill_mem_eff_dropout_mask_(input: Tensor, dropout_p: _float, seed: _int, offset: _int) -> Tensor: ... +def _foobar(input: Tensor, arg1: _bool = True, arg2: _bool = True, *, arg3: _bool = True) -> Tensor: ... 
+def _foreach_abs(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_abs(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + ... +def _foreach_abs_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_abs_(self: List[Tensor]) -> None + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + ... +def _foreach_acos(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_acos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + ... +def _foreach_acos_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_acos_(self: List[Tensor]) -> None + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor, *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> None: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> None: ... 
+@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> None: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> None: ... +def _foreach_asin(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_asin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + ... +def _foreach_asin_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_asin_(self: List[Tensor]) -> None + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + ... +def _foreach_atan(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_atan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + ... +def _foreach_atan_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_atan_(self: List[Tensor]) -> None + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + ... +def _foreach_ceil(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_ceil(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + ... +def _foreach_ceil_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_ceil_(self: List[Tensor]) -> None + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... 
+@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_copy_(self: Union[Tuple[Tensor, ...], List[Tensor]], src: Union[Tuple[Tensor, ...], List[Tensor]], non_blocking: _bool = False) -> None: ... +def _foreach_cos(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_cos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + ... +def _foreach_cos_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_cos_(self: List[Tensor]) -> None + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + ... +def _foreach_cosh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_cosh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + ... +def _foreach_cosh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_cosh_(self: List[Tensor]) -> None + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_erf(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_erf(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + ... +def _foreach_erf_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_erf_(self: List[Tensor]) -> None + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + ... +def _foreach_erfc(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_erfc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erfc` to each Tensor of the input list. 
+ """ + ... +def _foreach_erfc_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_erfc_(self: List[Tensor]) -> None + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + ... +def _foreach_exp(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_exp(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + ... +def _foreach_exp_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_exp_(self: List[Tensor]) -> None + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + ... +def _foreach_expm1(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_expm1(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + ... +def _foreach_expm1_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_expm1_(self: List[Tensor]) -> None + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + ... +def _foreach_floor(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_floor(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + ... +def _foreach_floor_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_floor_(self: List[Tensor]) -> None + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + ... +def _foreach_frac(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_frac(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + ... +def _foreach_frac_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_frac_(self: List[Tensor]) -> None + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + ... +@overload +def _foreach_lerp(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weight: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_lerp(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weights: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_lerp_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weight: Union[Number, _complex]) -> None: ... +@overload +def _foreach_lerp_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weights: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_lgamma(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_lgamma(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + ... +def _foreach_lgamma_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_lgamma_(self: List[Tensor]) -> None + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + ... +def _foreach_log(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log` to each Tensor of the input list. + """ + ... 
+def _foreach_log10(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log10(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + ... +def _foreach_log10_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log10_(self: List[Tensor]) -> None + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + ... +def _foreach_log1p(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log1p(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + ... +def _foreach_log1p_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log1p_(self: List[Tensor]) -> None + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + ... +def _foreach_log2(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log2(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + ... +def _foreach_log2_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log2_(self: List[Tensor]) -> None + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + ... +def _foreach_log_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log_(self: List[Tensor]) -> None + + Apply :func:`torch.log` to each Tensor of the input list. + """ + ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> Tuple[Tensor, ...]: ... 
+@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_neg(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_neg(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + ... +def _foreach_neg_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_neg_(self: List[Tensor]) -> None + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + ... +def _foreach_norm(self: Union[Tuple[Tensor, ...], List[Tensor]], ord: Union[Number, _complex] = 2) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Number, _complex], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Number, _complex]) -> None: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_reciprocal(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_reciprocal(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + ... +def _foreach_reciprocal_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_reciprocal_(self: List[Tensor]) -> None + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + ... +def _foreach_round(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_round(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.round` to each Tensor of the input list. + """ + ... +def _foreach_round_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_round_(self: List[Tensor]) -> None + + Apply :func:`torch.round` to each Tensor of the input list. + """ + ... +def _foreach_sigmoid(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sigmoid(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + ... 
+def _foreach_sigmoid_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sigmoid_(self: List[Tensor]) -> None + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + ... +def _foreach_sign(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +def _foreach_sign_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_sin(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + ... +def _foreach_sin_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sin_(self: List[Tensor]) -> None + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + ... +def _foreach_sinh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sinh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + ... +def _foreach_sinh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sinh_(self: List[Tensor]) -> None + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + ... +def _foreach_sqrt(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sqrt(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + ... +def _foreach_sqrt_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sqrt_(self: List[Tensor]) -> None + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +def _foreach_tan(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_tan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + ... +def _foreach_tan_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_tan_(self: List[Tensor]) -> None + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + ... +def _foreach_tanh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_tanh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + ... +def _foreach_tanh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_tanh_(self: List[Tensor]) -> None + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + ... 
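# A rough sketch of the `_foreach_*` stubs above: each call applies the corresponding
# elementwise op across a whole list of tensors at once. The tensors and the SGD-style
# in-place update below are illustrative assumptions, not part of the generated stub.
import torch

params = [torch.ones(3), torch.full((2, 2), 4.0)]
grads  = [torch.ones(3), torch.ones(2, 2)]

roots = torch._foreach_sqrt(params)              # out-of-place: returns a tuple of results
torch._foreach_add_(params, grads, alpha=-0.1)   # in-place: one fused step over the list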
+def _foreach_trunc(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_trunc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + ... +def _foreach_trunc_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_trunc_(self: List[Tensor]) -> None + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + ... +def _foreach_zero_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_zero_(self: List[Tensor]) -> None + + Apply :func:`torch.zero` to each Tensor of the input list. + """ + ... +def _from_functional_tensor(t: Tensor) -> Tensor: ... +def _functional_assert_async(input: Tensor, assert_msg: str, dep_token: Tensor) -> Tensor: ... +def _functional_assert_scalar(self: Union[Number, _complex], assert_msg: str, dep_token: Tensor) -> Tensor: ... +def _functional_sym_constrain_range(size: Union[Number, _complex], min: Optional[_int], max: Optional[_int], dep_token: Tensor) -> Tensor: ... +def _functional_sym_constrain_range_for_size(size: Union[Number, _complex], min: Optional[_int], max: Optional[_int], dep_token: Tensor) -> Tensor: ... +def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ... +def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ... +def _functionalize_commit_update(t: Tensor) -> None: ... +def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ... +def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ... +def _functionalize_sync(t: Tensor) -> None: ... +@overload +def _fused_adam_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: Tensor, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_adam_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: _float, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_adamw_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: Tensor, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... 
+@overload +def _fused_adamw_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: _float, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +def _fused_dropout(input: Tensor, p: _float, generator: Optional[Generator] = None) -> Tuple[Tensor, Tensor]: ... +def _fused_moving_avg_obs_fq_helper(input: Tensor, observer_on: Tensor, fake_quant_on: Tensor, running_min: Tensor, running_max: Tensor, scale: Tensor, zero_point: Tensor, averaging_const: _float, quant_min: _int, quant_max: _int, ch_axis: _int, per_row_fake_quant: _bool = False, symmetric_quant: _bool = False) -> torch.return_types._fused_moving_avg_obs_fq_helper: ... +def _fused_sdp_choice(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: _float = 0.0, is_causal: _bool = False, *, scale: Optional[_float] = None) -> _int: ... +@overload +def _fused_sgd_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], momentum_buffer_list: Union[Tuple[Tensor, ...], List[Tensor]], *, weight_decay: _float, momentum: _float, lr: Tensor, dampening: _float, nesterov: _bool, maximize: _bool, is_first_step: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_sgd_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], momentum_buffer_list: Union[Tuple[Tensor, ...], List[Tensor]], *, weight_decay: _float, momentum: _float, lr: _float, dampening: _float, nesterov: _bool, maximize: _bool, is_first_step: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +def _fw_primal_copy(input: Tensor, level: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _grid_sampler_2d_cpu_fallback(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def _has_compatible_shallow_copy_type(input: Tensor, from_: Tensor) -> _bool: ... +def _histogramdd_bin_edges(input: Tensor, bins: _size, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> Tuple[Tensor, ...]: ... +def _histogramdd_from_bin_cts(input: Tensor, bins: _size, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> Tensor: ... +def _histogramdd_from_bin_tensors(input: Tensor, bins: Union[Tuple[Tensor, ...], List[Tensor]], *, weight: Optional[Tensor] = None, density: _bool = False) -> Tensor: ... +def _index_put_impl_(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False, unsafe: _bool = False) -> Tensor: ... +def _indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _int_mm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _is_all_true(input: Tensor) -> Tensor: ... +def _is_any_true(input: Tensor) -> Tensor: ... +def _is_functional_tensor(t: Tensor) -> _bool: ... +def _is_zerotensor(input: Tensor) -> _bool: ... +def _lazy_clone(input: Tensor) -> Tensor: ... 
+def _linalg_check_errors(info: Tensor, api_name: str, *, is_matrix: _bool) -> None: ... +def _linalg_det(A: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_det: ... +def _linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_eigh: ... +def _linalg_slogdet(A: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_slogdet: ... +def _linalg_solve_ex(A: Tensor, B: Tensor, *, left: _bool = True, check_errors: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_solve_ex: ... +def _linalg_svd(A: Tensor, full_matrices: _bool = False, compute_uv: _bool = True, *, driver: Optional[str] = None, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_svd: ... +def _log_softmax(input: Tensor, dim: _int, half_to_float: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _log_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input_dtype: _dtype, *, out: Optional[Tensor] = None) -> Tensor: ... +def _logcumsumexp(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _lstm_mps(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _lu_with_info(input: Tensor, pivot: _bool = True, check_errors: _bool = True) -> torch.return_types._lu_with_info: ... +def _make_dep_token(*, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _make_dual(primal: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _make_dual_copy(primal: Tensor, tangent: Tensor, level: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _make_per_channel_quantized_tensor(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int) -> Tensor: ... +def _make_per_tensor_quantized_tensor(input: Tensor, scale: _float, zero_point: _int) -> Tensor: ... +def _masked_scale(input: Tensor, mask: Tensor, scale: _float) -> Tensor: ... +def _masked_softmax(input: Tensor, mask: Tensor, dim: Optional[_int] = None, mask_type: Optional[_int] = None) -> Tensor: ... +def _mixed_dtypes_linear(input: Tensor, weight: Tensor, scale: Tensor, *, bias: Optional[Tensor] = None, activation: Optional[str] = None) -> Tensor: ... +def _mkldnn_reshape(input: Tensor, shape: _size) -> Tensor: ... +def _mkldnn_transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mkldnn_transpose_(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mps_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def _mps_convolution_transpose(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... 
+@overload +def _native_batch_norm_legit(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Tensor, running_var: Tensor, training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def _native_batch_norm_legit(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +def _native_batch_norm_legit_no_training(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Tensor, running_var: Tensor, momentum: _float, eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def _native_multi_head_attention(query: Tensor, key: Tensor, value: Tensor, embed_dim: _int, num_head: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, mask: Optional[Tensor] = None, need_weights: _bool = True, average_attn_weights: _bool = True, mask_type: Optional[_int] = None) -> Tuple[Tensor, Tensor]: ... +def _neg_view(input: Tensor) -> Tensor: ... +def _neg_view_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_from_padded(padded: Tensor, cpu_nested_shape_example: Tensor, fuse_transform_0213: _bool = False) -> Tensor: ... +def _nested_from_padded_and_nested_example(padded: Tensor, nt_example: Tensor) -> Tensor: ... +def _nested_get_jagged_dummy(any: Tensor) -> Tensor: ... +def _nested_get_lengths(input: Tensor) -> Tensor: ... +def _nested_get_offsets(input: Tensor) -> Tensor: ... +def _nested_get_ragged_idx(input: Tensor) -> _int: ... +def _nested_get_values(input: Tensor) -> Tensor: ... +def _nested_get_values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_tensor_from_mask(t: Tensor, mask: Tensor, mask_check: _bool = True) -> Tensor: ... +def _nested_tensor_from_mask_left_aligned(t: Tensor, mask: Tensor) -> _bool: ... +def _nested_tensor_from_tensor_list(list: Union[Tuple[Tensor, ...], List[Tensor]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = None) -> Tensor: ... +def _nested_tensor_softmax_with_shape(input: Tensor, query: Tensor) -> Tensor: ... +def _nested_view_from_buffer(input: Tensor, nested_size: Tensor, nested_strides: Tensor, offsets: Tensor) -> Tensor: ... +def _nested_view_from_buffer_copy(input: Tensor, nested_size: Tensor, nested_strides: Tensor, offsets: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_view_from_jagged(input: Tensor, offsets: Tensor, dummy: Tensor, lengths: Optional[Tensor] = None, ragged_idx: _int = 1) -> Tensor: ... +def _nested_view_from_jagged_copy(input: Tensor, offsets: Tensor, dummy: Tensor, lengths: Optional[Tensor] = None, ragged_idx: _int = 1, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nnpack_available() -> _bool: ... +def _nnpack_spatial_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]], stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def _pack_padded_sequence(input: Tensor, lengths: Tensor, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def _pad_packed_sequence(data: Tensor, batch_sizes: Tensor, batch_first: _bool, padding_value: Union[Number, _complex], total_length: _int) -> Tuple[Tensor, Tensor]: ... 
+def _pin_memory(input: Tensor, device: Optional[Optional[DeviceLikeType]] = None) -> Tensor: ... +def _prelu_kernel(input: Tensor, weight: Tensor) -> Tensor: ... +def _print(s: str) -> None: ... +def _propagate_xla_data(input: Tensor, output: Tensor) -> None: ... +def _remove_batch_dim(input: Tensor, level: _int, batch_size: _int, out_dim: _int) -> Tensor: ... +def _reshape_alias_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None) -> Tensor: ... +def _reshape_from_tensor(input: Tensor, shape: Tensor) -> Tensor: ... +def _resize_output_(input: Tensor, size: Sequence[Union[_int, SymInt]], device: Optional[DeviceLikeType]) -> Tensor: ... +def _rowwise_prune(weight: Tensor, mask: Tensor, compressed_indices_dtype: _dtype) -> Tuple[Tensor, Tensor]: ... +def _sample_dirichlet(input: Tensor, generator: Optional[Generator] = None) -> Tensor: ... +def _saturate_weight_to_fp16(weight: Tensor) -> Tensor: ... +def _scaled_dot_product_attention_math(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: _float = 0.0, is_causal: _bool = False, dropout_mask: Optional[Tensor] = None, *, scale: Optional[_float] = None) -> Tuple[Tensor, Tensor]: ... +def _scaled_dot_product_cudnn_attention(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, return_debug_mask: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_cudnn_attention: ... +def _scaled_dot_product_efficient_attention(query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Tensor], compute_log_sumexp: _bool, dropout_p: _float = 0.0, is_causal: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_efficient_attention: ... +def _scaled_dot_product_flash_attention(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, return_debug_mask: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_flash_attention: ... +def _scaled_dot_product_flash_attention_for_cpu(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, *, attn_mask: Optional[Tensor] = None, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_flash_attention_for_cpu: ... +def _scaled_mm(input: Tensor, mat2: Tensor, *, bias: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, scale_a: Optional[Tensor] = None, scale_b: Optional[Tensor] = None, scale_result: Optional[Tensor] = None, use_fast_accum: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor]: ... +def _shape_as_tensor(input: Tensor) -> Tensor: ... +def _sobol_engine_draw(quasi: Tensor, n: _int, sobolstate: Tensor, dimension: _int, num_generated: _int, dtype: Optional[_dtype]) -> Tuple[Tensor, Tensor]: ... +def _sobol_engine_ff_(input: Tensor, n: _int, sobolstate: Tensor, dimension: _int, num_generated: _int) -> Tensor: ... +def _sobol_engine_initialize_state_(input: Tensor, dimension: _int) -> Tensor: ... +def _sobol_engine_scramble_(input: Tensor, ltm: Tensor, dimension: _int) -> Tensor: ... +def _softmax(input: Tensor, dim: _int, half_to_float: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input_dtype: _dtype, *, grad_input: Optional[Tensor] = None) -> Tensor: ... 
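# A tentative sketch of calling the reference "math" attention path declared above;
# the shapes are arbitrary illustrations, and this private entry point is normally
# reached indirectly through torch.nn.functional.scaled_dot_product_attention.
import torch

q = torch.randn(1, 2, 4, 8)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

out, weights = torch._scaled_dot_product_attention_math(q, k, v)
# out: (1, 2, 4, 8) attention output; weights: (1, 2, 4, 4) softmax attention matrix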
+def _sparse_broadcast_to(input: Tensor, size: _size) -> Tensor: ... +def _sparse_broadcast_to_copy(input: Tensor, size: _size, *, out: Optional[Tensor] = None) -> Tensor: ... +def _sparse_csr_prod(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_csr_sum(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_log_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input: Tensor) -> Tensor: ... +def _sparse_semi_structured_linear(input: Tensor, weight: Tensor, meta: Tensor, *, bias: Optional[Tensor] = None, activation: Optional[str] = None, out_dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input: Tensor) -> Tensor: ... +def _sparse_sparse_matmul(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, *, dtype: _dtype) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: Union[_int, _size]) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: Union[_int, _size], *, dtype: _dtype) -> Tensor: ... +def _stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: ... +def _standard_gamma(input: Tensor, generator: Optional[Generator] = None) -> Tensor: ... +def _standard_gamma_grad(input: Tensor, output: Tensor) -> Tensor: ... +def _sync(t: Tensor) -> None: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor) -> Tensor: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor, b: _bool) -> Tensor: ... +def _test_autograd_multiple_dispatch_view(input: Tensor) -> Tensor: ... +def _test_autograd_multiple_dispatch_view_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _test_check_tensor(input: Tensor) -> Tensor: ... +def _test_functorch_fallback(input: Tensor, other: Tensor) -> Tensor: ... +def _test_parallel_materialize(input: Tensor, num_parallel: _int, skip_first: _bool = False) -> Tensor: ... +def _test_serialization_subcmul(input: Tensor, other: Tensor, alpha: Union[Number, _complex] = 1) -> Tensor: ... +def _to_cpu(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +def _to_functional_tensor(t: Tensor) -> Tensor: ... +def _to_sparse_semi_structured(dense: Tensor) -> Tuple[Tensor, Tensor]: ... +def _transform_bias_rescale_qkv(qkv: Tensor, qkv_bias: Tensor, num_heads: _int) -> Tuple[Tensor, Tensor, Tensor]: ... +def _transformer_encoder_layer_fwd(src: Tensor, embed_dim: _int, num_heads: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, use_gelu: _bool, norm_first: _bool, eps: _float, norm_weight_1: Tensor, norm_bias_1: Tensor, norm_weight_2: Tensor, norm_bias_2: Tensor, ffn_weight_1: Tensor, ffn_bias_1: Tensor, ffn_weight_2: Tensor, ffn_bias_2: Tensor, mask: Optional[Tensor] = None, mask_type: Optional[_int] = None) -> Tensor: ... +def _trilinear(i1: Tensor, i2: Tensor, i3: Tensor, expand1: _size, expand2: _size, expand3: _size, sumdim: _size, unroll_dim: _int = 1) -> Tensor: ... +def _triton_multi_head_attention(query: Tensor, key: Tensor, value: Tensor, embed_dim: _int, num_head: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, mask: Optional[Tensor] = None) -> Tensor: ... 
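+# The _sparse_sum overloads above appear to back torch.sparse.sum. A minimal
+# sketch of the public call, assuming a small COO input (values for
+# illustration only):
+#
+#     import torch
+#     i = torch.tensor([[0, 1], [1, 0]])
+#     v = torch.tensor([2.0, 3.0])
+#     s = torch.sparse_coo_tensor(i, v, (2, 2))
+#     torch.sparse.sum(s)                           # dense scalar: tensor(5.)
+#     torch.sparse.sum(s, dim=[0])                  # sparse result reduced over dim 0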
+def _triton_scaled_dot_attention(q: Tensor, k: Tensor, v: Tensor, dropout_p: _float = 0.0) -> Tensor: ... +def _unique(input: Tensor, sorted: _bool = True, return_inverse: _bool = False) -> Tuple[Tensor, Tensor]: ... +def _unique2(input: Tensor, sorted: _bool = True, return_inverse: _bool = False, return_counts: _bool = False) -> Tuple[Tensor, Tensor, Tensor]: ... +def _unpack_dual(dual: Tensor, level: _int) -> torch.return_types._unpack_dual: ... +def _unsafe_index(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]]) -> Tensor: ... +def _unsafe_index_put(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +@overload +def _use_cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int) -> _bool: ... +@overload +def _use_cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int) -> _bool: ... +def _use_cudnn_rnn_flatten_weight() -> _bool: ... +def _validate_compressed_sparse_indices(is_crow: _bool, compressed_idx: Tensor, plain_idx: Tensor, cdim: _int, dim: _int, nnz: _int) -> None: ... +def _validate_sparse_bsc_tensor_args(ccol_indices: Tensor, row_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_bsr_tensor_args(crow_indices: Tensor, col_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_compressed_tensor_args(compressed_indices: Tensor, plain_indices: Tensor, values: Tensor, size: _size, layout: _layout) -> None: ... +def _validate_sparse_coo_tensor_args(indices: Tensor, values: Tensor, size: _size, is_coalesced: Optional[_bool] = None) -> None: ... +def _validate_sparse_csc_tensor_args(ccol_indices: Tensor, row_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_csr_tensor_args(crow_indices: Tensor, col_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _weight_int4pack_mm(input: Tensor, mat2: Tensor, qGroupSize: _int, qScaleAndZeros: Tensor) -> Tensor: ... +def _weight_int8pack_mm(input: Tensor, mat2: Tensor, scales: Tensor) -> Tensor: ... +def _weight_norm(v: Tensor, g: Tensor, dim: _int = 0) -> Tensor: ... +def _weight_norm_interface(v: Tensor, g: Tensor, dim: _int = 0) -> Tuple[Tensor, Tensor]: ... +def abs(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + abs(input, *, out=None) -> Tensor + + Computes the absolute value of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = |\text{input}_{i}| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) + """ + ... +def abs_(input: Tensor) -> Tensor: ... +def absolute(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + absolute(input, *, out=None) -> Tensor + + Alias for :func:`torch.abs` + """ + ... +def acos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + acos(input, *, out=None) -> Tensor + + Computes the inverse cosine of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) + """ + ... +def acos_(input: Tensor) -> Tensor: ... +def acosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + acosh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + + Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) + """ + ... +def acosh_(input: Tensor) -> Tensor: ... +def adaptive_avg_pool1d(input: Tensor, output_size: Union[_int, _size]) -> Tensor: ... +def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ... +@overload +def add(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def add(self: Tensor, alpha: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. 
+ + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def add(self: Tensor, alpha: Union[Number, _complex], other: Tensor, *, out: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. 
+ + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addcdiv(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcdiv(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor, *, out: Tensor) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcdiv(input: Tensor, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcmul(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addcmul(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor, *, out: Tensor) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addcmul(input: Tensor, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. 
math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor, *, out: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. 
+ The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. 
warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... 
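+# A minimal keyword-style torch.addmm call matching the documented signature
+# above (values chosen for illustration only; beta=0 drops the `input` term,
+# as the docstring notes):
+#
+#     import torch
+#     M = torch.zeros(2, 2)
+#     a = torch.eye(2)
+#     b = torch.ones(2, 2)
+#     torch.addmm(M, a, b, beta=0.0, alpha=2.0)     # == 2 * (a @ b)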
+@overload +def addmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor, *, out: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor, *, out: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(input: Tensor, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor, *, out: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv_(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor) -> Tensor: ... +@overload +def addmv_(input: Tensor, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def addmv_(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor) -> Tensor: ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], vec1: Tensor, vec2: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], vec1: Tensor, vec2: Tensor, *, out: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. 
+ + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(input: Tensor, vec1: Tensor, vec2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, vec1: Tensor, vec2: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, vec1: Tensor, vec2: Tensor, *, out: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +def adjoint(input: Tensor) -> Tensor: + r""" + adjoint(Tensor) -> Tensor + Returns a view of the tensor conjugated and with the last two dimensions transposed. + + ``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and + to ``x.transpose(-2, -1)`` for real tensors. + + Example:: + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) + """ + ... +def affine_grid_generator(theta: Tensor, size: Sequence[Union[_int, SymInt]], align_corners: _bool) -> Tensor: ... +def alias_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.alias`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def all(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. 
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: Optional[_size] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. 
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +def allclose(input: Tensor, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> _bool: + r""" + allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool + + This function checks if :attr:`input` and :attr:`other` satisfy the condition: + + .. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert + + elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to + `numpy.allclose `_ + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + + Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True + """ + ... +def alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def amax(input: Tensor, dim: Union[_int, _size] = (), keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + amax(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the maximum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) + """ + ... +def amin(input: Tensor, dim: Union[_int, _size] = (), keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + amin(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the minimum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) + """ + ... +def aminmax(input: Tensor, *, dim: Optional[_int] = None, keepdim: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.aminmax: + r""" + aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + + Computes the minimum and maximum values of the :attr:`input` tensor. + + Args: + input (Tensor): + The input tensor + + Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + + Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + + Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + + .. note:: + NaN values are propagated to the output if at least one value is NaN. + + .. seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + + Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) + """ + ... +def angle(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + angle(input, *, out=None) -> Tensor + + Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + + .. math:: + \text{out}_{i} = angle(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + .. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + + Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) + """ + ... +@overload +def any(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. 
function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: Optional[_size] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. 
function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def arange(start: Number, end: Number, step: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. 
+ + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Number, end: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(end: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(end: Union[Number, _complex], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. 
+ + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Union[Number, _complex], end: Union[Number, _complex], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Union[Number, _complex], end: Union[Number, _complex], step: Union[Number, _complex] = 1, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +def arccos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arccos(input, *, out=None) -> Tensor + + Alias for :func:`torch.acos`. + """ + ... +def arccos_(input: Tensor) -> Tensor: ... +def arccosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arccosh(input, *, out=None) -> Tensor + + Alias for :func:`torch.acosh`. + """ + ... +def arccosh_(input: Tensor) -> Tensor: ... 
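+# A minimal sketch of the dtype-inference rule documented for the arange()
+# overloads above: integer-only arguments are inferred as torch.int64, while
+# any floating-point argument switches the result to the default dtype
+# (assumed here to be the stock torch.float32; see torch.get_default_dtype()).
+#
+#   >>> torch.arange(5).dtype             # all-integer arguments
+#   torch.int64
+#   >>> torch.arange(1, 2.5, 0.5).dtype   # a float argument is present
+#   torch.float32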
+def arcsin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arcsin(input, *, out=None) -> Tensor + + Alias for :func:`torch.asin`. + """ + ... +def arcsin_(input: Tensor) -> Tensor: ... +def arcsinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arcsinh(input, *, out=None) -> Tensor + + Alias for :func:`torch.asinh`. + """ + ... +def arcsinh_(input: Tensor) -> Tensor: ... +def arctan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctan(input, *, out=None) -> Tensor + + Alias for :func:`torch.atan`. + """ + ... +def arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctan2(input, other, *, out=None) -> Tensor + Alias for :func:`torch.atan2`. + """ + ... +def arctan_(input: Tensor) -> Tensor: ... +def arctanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctanh(input, *, out=None) -> Tensor + + Alias for :func:`torch.atanh`. + """ + ... +def arctanh_(input: Tensor) -> Tensor: ... +def argmax(input: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + argmax(input) -> LongTensor + + Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple maximal values then the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + + .. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + + Returns the indices of the maximum values of a tensor across a dimension. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. If ``None``, the argmax of the flattened input is returned. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) + """ + ... +def argmin(input: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + argmin(input, dim=None, keepdim=False) -> LongTensor + + Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + + This is the second value returned by :meth:`torch.min`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. If ``None``, the argmin of the flattened input is returned. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) + """ + ... +@overload +def argsort(input: Tensor, *, stable: _bool, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +@overload +def argsort(input: Tensor, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +@overload +def argsort(input: Tensor, dim: Union[str, ellipsis, None], descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. 
+ dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +def argwhere(input: Tensor) -> Tensor: + r""" + argwhere(input) -> Tensor + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + .. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. + + Args: + {input} + + Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) + """ + ... +def as_strided(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided(input, size, stride, storage_offset=None) -> Tensor + + Create a view of an existing `torch.Tensor` :attr:`input` with specified + :attr:`size`, :attr:`stride` and :attr:`storage_offset`. + + .. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.expand`, + to setting a view's strides manually with `as_strided`, as this + function's behavior depends on the implementation of a tensor's storage. + The constructed view of the storage must only refer to elements within + the storage or a runtime error will be thrown, and if the view is + "overlapped" (with multiple indices referring to the same element in + memory) its behavior is undefined. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + + Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) + """ + ... +def as_strided_(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: ... 
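+# A small sketch of how the (size, stride, storage_offset) arguments of
+# as_strided(), documented above, address the underlying storage: element
+# [i][j] of the returned view reads storage[storage_offset + i*stride[0] + j*stride[1]].
+#
+#   >>> x = torch.arange(9.)
+#   >>> torch.as_strided(x, (2, 2), (3, 1), 1)   # rows advance by 3 elements, columns by 1, starting at offset 1
+#   tensor([[1., 2.],
+#           [4., 5.]])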
+def as_strided_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.as_strided`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def as_strided_scatter(input: Tensor, src: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the elements corresponding to the result of calling + input.as_strided(size, stride, storage_offset). + + This function returns a tensor with fresh storage; it does not + return a view. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + + Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + """ + ... +def as_tensor(data: Any, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None) -> Tensor: + r""" + as_tensor(data, dtype=None, device=None) -> Tensor + + Converts :attr:`data` into a tensor, sharing data and preserving autograd + history if possible. + + If :attr:`data` is already a tensor with the requested dtype and device + then :attr:`data` itself is returned, but if :attr:`data` is a + tensor with a different dtype or device then it's copied as if using + `data.to(dtype=dtype, device=device)`. + + If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a + tensor is constructed using :func:`torch.from_numpy`. + + .. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) + """ + ... 
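+# A short sketch of why the as_strided_scatter() example above yields
+# tensor([[1., 3., 2.], [4., 0., 0.], [0., 0., 0.]]): with size (2, 2) and
+# stride (1, 2), element [i][j] of src is written to flat storage position
+# i*1 + j*2 of the output, so the src values 1, 2, 3, 4 land at row-major
+# positions 0, 2, 1, 3. Reading the result back through a matching as_strided
+# view recovers src exactly:
+#
+#   >>> src = torch.arange(4.).reshape(2, 2) + 1
+#   >>> out = torch.as_strided_scatter(torch.zeros(3, 3), src, (2, 2), (1, 2))
+#   >>> torch.as_strided(out, (2, 2), (1, 2))
+#   tensor([[1., 2.],
+#           [3., 4.]])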
+def asarray(obj: Any, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, copy: Optional[_bool] = None, requires_grad: _bool = False) -> Tensor: + r""" + asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor + + Converts :attr:`obj` to a tensor. + + :attr:`obj` can be one of: + + 1. a tensor + 2. a NumPy array or a NumPy scalar + 3. a DLPack capsule + 4. an object that implements Python's buffer protocol + 5. a scalar + 6. a sequence of scalars + + When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, + by default, not require a gradient, have the same datatype as :attr:`obj`, be on the + same device, and share memory with it. These properties can be controlled with the + :attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. + If the returned tensor is of a different datatype, on a different device, or a copy is + requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` + is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is + also a tensor with an autograd history then the returned tensor will have the same history. + + When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's + buffer protocol then the buffer is interpreted as an array of bytes grouped according to + the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is + passed then the default floating point datatype is used, instead.) The returned tensor + will have the specified datatype (or default floating point datatype if none is specified) + and, by default, be on the CPU device and share memory with the buffer. + + When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on + the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will + be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + + When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the + returned tensor will, by default, infer its datatype from the scalar values, be on the + current default device, and not share its memory. + + .. seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + + Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. + Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. 
Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) + """ + ... +def asin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + asin(input, *, out=None) -> Tensor + + Returns a new tensor with the arcsine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) + """ + ... +def asin_(input: Tensor) -> Tensor: ... +def asinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + asinh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) + """ + ... +def asinh_(input: Tensor) -> Tensor: ... +def atan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atan(input, *, out=None) -> Tensor + + Returns a new tensor with the arctangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) + """ + ... +def atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atan2(input, other, *, out=None) -> Tensor + + Element-wise arctangent of :math:`\text{input}_{i} / \text{other}_{i}` + with consideration of the quadrant. 
Returns a new tensor with the signed angles + in radians between vector :math:`(\text{other}_{i}, \text{input}_{i})` + and vector :math:`(1, 0)`. (Note that :math:`\text{other}_{i}`, the second + parameter, is the x-coordinate, while :math:`\text{input}_{i}`, the first + parameter, is the y-coordinate.) + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) + """ + ... +def atan_(input: Tensor) -> Tensor: ... +def atanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atanh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + + Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + + .. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) + """ + ... +def atanh_(input: Tensor) -> Tensor: ... +def avg_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, ceil_mode: _bool = False, count_include_pad: _bool = True) -> Tensor: ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. 
Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... 
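The formula documented above can be checked directly. A minimal sketch (illustrative shapes and scale factors, float32 tolerances assumed) comparing torch.baddbmm against the expression beta * input + alpha * (batch1 @ batch2):

import torch

M = torch.randn(10, 3, 5)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)

# out_i = beta * input_i + alpha * (batch1_i @ batch2_i)
out = torch.baddbmm(M, batch1, batch2, beta=0.5, alpha=2.0)
reference = 0.5 * M + 2.0 * torch.bmm(batch1, batch2)
assert torch.allclose(out, reference, atol=1e-5)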
+@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def bartlett_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. 
If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def bartlett_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... 
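As a quick illustration of the periodic/symmetric relationship stated in the docstring above, the following sketch (default dtype assumed, window length chosen arbitrarily) verifies that the periodic window of length L matches the symmetric window of length L + 1 with its last sample dropped:

import torch

L = 8
periodic = torch.bartlett_window(L, periodic=True)
symmetric = torch.bartlett_window(L + 1, periodic=False)[:-1]
assert torch.allclose(periodic, symmetric)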
+def batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ... +def batch_norm_backward_elemt(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], sum_dy: Tensor, sum_dy_xmu: Tensor, count: Tensor) -> Tensor: ... +def batch_norm_backward_reduce(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], input_g: _bool, weight_g: _bool, bias_g: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def batch_norm_elemt(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], mean: Tensor, invstd: Tensor, eps: _float, *, out: Optional[Tensor] = None) -> Tensor: ... +def batch_norm_gather_stats(input: Tensor, mean: Tensor, invstd: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float, eps: _float, count: _int) -> Tuple[Tensor, Tensor]: ... +def batch_norm_gather_stats_with_counts(input: Tensor, mean: Tensor, invstd: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float, eps: _float, counts: Tensor) -> Tuple[Tensor, Tensor]: ... +def batch_norm_stats(input: Tensor, eps: _float) -> Tuple[Tensor, Tensor]: ... +def batch_norm_update_stats(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float) -> Tuple[Tensor, Tensor]: ... +@overload +def bernoulli(input: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + bernoulli(input, *, generator=None, out=None) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. + + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... +@overload +def bernoulli(input: Tensor, p: _float, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli(input, *, generator=None, out=None) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. 
+ + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... +def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ... +def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, weight: Optional[Tensor] = None, pos_weight: Optional[Tensor] = None, reduction: _int = 1) -> Tensor: ... +def bincount(input: Tensor, weights: Optional[Tensor] = None, minlength: _int = 0) -> Tensor: + r""" + bincount(input, weights=None, minlength=0) -> Tensor + + Count the frequency of each value in an array of non-negative ints. + + The number of bins (size 1) is one larger than the largest value in + :attr:`input` unless :attr:`input` is empty, in which case the result is a + tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least + :attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size + :attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, + ``out[n] += weights[i]`` if :attr:`weights` is specified else + ``out[n] += 1``. + + Note: + This operation may produce nondeterministic gradients when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. + + Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + + Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) + """ + ... +def binomial(count: Tensor, prob: Tensor, generator: Optional[Generator] = None) -> Tensor: ... 
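The accumulation rule quoted in the bincount docstring (``out[n] += weights[i]`` for ``n = input[i]``) can be made concrete with the same values its example uses; a short sketch:

import torch

values = torch.tensor([4, 3, 6, 3, 4])
weights = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00])

counts = torch.bincount(values)             # plain frequencies per bin
weighted = torch.bincount(values, weights)  # weights accumulated per bin

assert counts.tolist() == [0, 0, 0, 2, 2, 0, 1]
assert torch.allclose(weighted, torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.5]))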
+@overload +def bitwise_and(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_and(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_and(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_left_shift(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +@overload +def bitwise_left_shift(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. 
+ + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +@overload +def bitwise_left_shift(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +def bitwise_not(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_not(input, *, out=None) -> Tensor + + Computes the bitwise NOT of the given input tensor. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical NOT. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) + """ + ... +@overload +def bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... +@overload +def bitwise_or(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... 
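A small sketch, grounded in the int8 example above, confirming that torch.bitwise_or applies Python's ``|`` operator elementwise (the choice of values mirrors the docstring example):

import torch

a = torch.tensor([-1, -2, 3], dtype=torch.int8)
b = torch.tensor([1, 0, 3], dtype=torch.int8)

expected = torch.tensor([x | y for x, y in zip(a.tolist(), b.tolist())], dtype=torch.int8)
assert torch.equal(torch.bitwise_or(a, b), expected)  # tensor([-1, -2, 3], dtype=torch.int8)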
+@overload +def bitwise_or(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... +@overload +def bitwise_right_shift(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_right_shift(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_right_shift(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. 
math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def bitwise_xor(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def bitwise_xor(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def blackman_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. 
:attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def blackman_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bmm(input, mat2, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. + + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + ... +def broadcast_to(input: Tensor, size: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + broadcast_to(input, shape) -> Tensor + + Broadcasts :attr:`input` to the shape :attr:`\shape`. + Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + + Args: + input (Tensor): the input tensor. + shape (list, tuple, or :class:`torch.Size`): the new shape. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) + """ + ... +@overload +def bucketize(input: Tensor, boundaries: Tensor, *, out_int32: _bool = False, right: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. 
torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of :attr:`boundaries` (one pass the last index). + In other words, if False, gets the lower bound index for each value in :attr:`input` + from :attr:`boundaries`. If True, gets the upper bound index instead. + Default value is False. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + ... +@overload +def bucketize(self: Union[Number, _complex], boundaries: Tensor, *, out_int32: _bool = False, right: _bool = False) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of :attr:`boundaries` (one pass the last index). + In other words, if False, gets the lower bound index for each value in :attr:`input` + from :attr:`boundaries`. If True, gets the upper bound index instead. + Default value is False. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + ... +def can_cast(from_: _dtype, to: _dtype) -> _bool: + r""" + can_cast(from, to) -> bool + + Determines if a type conversion is allowed under PyTorch casting rules + described in the type promotion :ref:`documentation `. + + Args: + from (dtype): The original :class:`torch.dtype`. 
+ to (dtype): The target :class:`torch.dtype`. + + Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False + """ + ... +@overload +def cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of :attr:`seq` tensors in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): any python sequence of tensors of the same type. + Non-empty tensors provided must have the same shape, except in the + cat dimension. + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + ... +@overload +def cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of :attr:`seq` tensors in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): any python sequence of tensors of the same type. + Non-empty tensors provided must have the same shape, except in the + cat dimension. + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + ... +def ccol_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... 
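To make the shape behaviour of torch.cat and the casting rules of torch.can_cast concrete, here is a brief sketch using only shapes and dtypes (sizes are illustrative):

import torch

x = torch.randn(2, 3)
assert torch.cat((x, x, x), dim=0).shape == torch.Size([6, 3])  # grows the cat dimension 0
assert torch.cat((x, x, x), dim=1).shape == torch.Size([2, 9])  # grows the cat dimension 1

assert torch.can_cast(torch.double, torch.float)   # narrowing float-to-float cast is allowed
assert not torch.can_cast(torch.float, torch.int)  # float -> int is not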
+def ceil(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ceil(input, *, out=None) -> Tensor + + Returns a new tensor with the ceil of the elements of :attr:`input`, + the smallest integer greater than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) + """ + ... +def ceil_(input: Tensor) -> Tensor: ... +def celu(input: Tensor, alpha: Union[Number, _complex] = 1.0) -> Tensor: ... +def celu_(input: Tensor, alpha: Union[Number, _complex] = 1.0) -> Tensor: ... +def channel_shuffle(input: Tensor, groups: Union[_int, SymInt]) -> Tensor: ... +def cholesky(input: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky(input, upper=False, *, out=None) -> Tensor + + Computes the Cholesky decomposition of a symmetric positive-definite + matrix :math:`A` or for batches of symmetric positive-definite matrices. + + If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and + the decomposition has the form: + + .. math:: + + A = U^TU + + If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and + the decomposition has the form: + + .. math:: + + A = LL^T + + If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite + matrices, then the returned tensor will be composed of upper-triangular Cholesky factors + of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned + tensor will be composed of lower-triangular Cholesky factors of each of the individual + matrices. + + .. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + + Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. Default: ``False`` + + Keyword args: + out (Tensor, optional): the output matrix + + Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) + """ + ... 
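The deprecation note above recommends torch.linalg.cholesky as the replacement for torch.cholesky; a short sketch of the suggested migration, assuming a well-conditioned float32 input:

import torch

A = torch.randn(3, 3)
A = A @ A.mT + 1e-3 * torch.eye(3)   # make symmetric positive-definite

L = torch.linalg.cholesky(A)         # lower factor, replaces torch.cholesky(A)
U = L.mH                             # upper factor, replaces torch.cholesky(A, upper=True)

assert torch.allclose(L @ L.mT, A, atol=1e-5)
assert torch.allclose(U.mT @ U, A, atol=1e-5)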
+def cholesky_inverse(input: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky_inverse(L, upper=False, *, out=None) -> Tensor + + Computes the inverse of a complex Hermitian or real symmetric + positive-definite matrix given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Computes the inverse matrix :math:`A^{-1}`. + + Supports input of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` is a batch of matrices + then the output has the same batch dimensions. + + Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) + """ + ... +def cholesky_solve(input: Tensor, input2: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + + Computes the solution of a system of linear equations with complex Hermitian + or real symmetric positive-definite lhs given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Returns the solution :math:`X` of the following linear system: + + .. math:: + + AX = B + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices + then the output has the same batch dimensions. + + Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) + """ + ... +def choose_qparams_optimized(input: Tensor, numel: _int, n_bins: _int, ratio: _float, bit_width: _int) -> Tuple[Tensor, Tensor]: ... +def chunk(input: Tensor, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + chunk(input, chunks, dim=0) -> List of Tensors + + Attempts to split a tensor into the specified number of chunks. Each chunk is a view of + the input tensor. + + + .. note:: + + This function may return fewer than the specified number of chunks! + + .. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + + If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, + all returned chunks will be the same size. + If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, + all returned chunks will be the same size, except the last one. + If such division is not possible, this function may return fewer + than the specified number of chunks. + + Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + + Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) + """ + ... +@overload +def clamp(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + ... +@overload +def clamp(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + ... +@overload +def clamp_(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: ... +@overload +def clamp_max(input: Tensor, max: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_max(input: Tensor, max: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Tensor) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Union[Number, _complex]) -> Tensor: ... +@overload +def clamp_min(input: Tensor, min: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_min(input: Tensor, min: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Tensor) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Union[Number, _complex]) -> Tensor: ... +@overload +def clip(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + ... +@overload +def clip(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + ... +@overload +def clip_(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: ... +@overload +def clip_(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: ... 
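+# Illustrative sketch (hypothetical `_demo_*` helper, not part of the generated stub):
+# checks the elementwise formula y_i = min(max(x_i, min_i), max_i) documented above when
+# tensor-valued bounds are broadcast against the input, and that torch.clip behaves as an
+# alias of torch.clamp.
+def _demo_clamp_tensor_bounds() -> None:
+    import torch  # repeated so the sketch is self-contained
+    a = torch.randn(4)
+    lo = torch.linspace(-1, 1, steps=4)       # per-element lower bounds
+    hi = torch.full((4,), 0.5)                # per-element upper bounds (lo > hi at the last slot)
+    out = torch.clamp(a, min=lo, max=hi)      # where lo > hi, the result is hi, as noted above
+    assert torch.allclose(out, torch.minimum(torch.maximum(a, lo), hi))
+    assert torch.equal(torch.clip(a, lo, hi), out)   # clip is an alias of clamp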
+def clone(input: Tensor, *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + clone(input, *, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of :attr:`input`. + + .. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. + """ + ... +def col_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.col_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def column_stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + column_stack(tensors, *, out=None) -> Tensor + + Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + + Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` + in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + """ + ... +def combinations(input: Tensor, r: _int = 2, with_replacement: _bool = False) -> Tensor: + r""" + combinations(input, r=2, with_replacement=False) -> seq + + Compute combinations of length :math:`r` of the given tensor. The behavior is similar to + python's `itertools.combinations` when `with_replacement` is set to `False`, and + `itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + + Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + + Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. + + Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + """ + ... +def complex(real: Tensor, imag: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + complex(real, imag, *, out=None) -> Tensor + + Constructs a complex tensor with its real part equal to :attr:`real` and its + imaginary part equal to :attr:`imag`. 
+ + Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + """ + ... +@overload +def concat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concatenate(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concatenate(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +def conj(input: Tensor) -> Tensor: + r""" + conj(input) -> Tensor + + Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, + this function just returns :attr:`input`. + + .. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + + .. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True + """ + ... +def conj_physical(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + conj_physical(input, *, out=None) -> Tensor + + Computes the element-wise conjugate of the given :attr:`input` tensor. + If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + + .. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + + .. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + .. math:: + \text{out}_{i} = conj(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + """ + ... +def conj_physical_(input: Tensor) -> Tensor: ... 
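+# Illustrative sketch (hypothetical `_demo_*` helper, not part of the generated stub):
+# contrasts the lazy conjugation of torch.conj, which only flips the conjugate bit, with
+# torch.conj_physical, which materializes the conjugated values, as the docstrings above describe.
+def _demo_conj_vs_conj_physical() -> None:
+    import torch  # repeated so the sketch is self-contained
+    x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
+    lazy = torch.conj(x)                      # view with the conjugate bit set
+    eager = torch.conj_physical(x)            # new tensor holding the conjugated values
+    assert lazy.is_conj() and not eager.is_conj()
+    assert torch.equal(torch.resolve_conj(lazy), eager)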
+def constant_pad_nd(input: Tensor, pad: Sequence[Union[_int, SymInt]], value: Union[Number, _complex] = 0) -> Tensor: ... +@overload +def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +def conv_tbc(input: Tensor, weight: Tensor, bias: Tensor, pad: _int = 0) -> Tensor: ... +def conv_transpose1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def conv_transpose2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def conv_transpose3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... 
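+# Illustrative sketch (hypothetical `_demo_*` helper, not part of the generated stub):
+# exercises the two conv2d overloads declared above -- integer padding versus the string
+# form (padding="valid" / "same"; "same" assumes stride 1) -- and checks the resulting
+# spatial shapes.
+def _demo_conv2d_padding() -> None:
+    import torch  # repeated so the sketch is self-contained
+    x = torch.randn(1, 3, 8, 8)               # (N, C_in, H, W)
+    w = torch.randn(6, 3, 3, 3)               # (C_out, C_in, kH, kW)
+    b = torch.zeros(6)
+    y_valid = torch.conv2d(x, w, b, stride=1, padding="valid")   # no padding
+    y_same = torch.conv2d(x, w, b, stride=1, padding="same")     # output keeps H and W
+    y_int = torch.conv2d(x, w, b, stride=1, padding=1)           # explicit symmetric padding of 1
+    assert y_valid.shape == (1, 6, 6, 6)
+    assert y_same.shape == (1, 6, 8, 8)
+    assert torch.allclose(y_same, y_int)      # for a 3x3 kernel at stride 1, "same" equals padding=1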
+def convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +@overload +def copysign(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + ... +@overload +def copysign(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. 
If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + ... +def corrcoef(input: Tensor) -> Tensor: + r""" + corrcoef(input) -> Tensor + + Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, + where rows are the variables and columns are the observations. + + .. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + + .. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Returns: + (Tensor) The correlation coefficient matrix of the variables. + + .. seealso:: + + :func:`torch.cov` covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) + """ + ... +def cos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cos(input, *, out=None) -> Tensor + + Returns a new tensor with the cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cos(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) + """ + ... +def cos_(input: Tensor) -> Tensor: ... +def cosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cosh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic cosine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + ... +def cosh_(input: Tensor) -> Tensor: ... +def cosine_embedding_loss(input1: Tensor, input2: Tensor, target: Tensor, margin: _float = 0.0, reduction: _int = 1) -> Tensor: ... +def cosine_similarity(x1: Tensor, x2: Tensor, dim: _int = 1, eps: _float = 1e-08) -> Tensor: ... +@overload +def count_nonzero(input: Tensor, dim: Optional[_int] = None) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. 
+ + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + ... +@overload +def count_nonzero(input: Tensor, dim: _size) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + ... +def cov(input: Tensor, *, correction: _int = 1, fweights: Optional[Tensor] = None, aweights: Optional[Tensor] = None) -> Tensor: + r""" + cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + + Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are + the variables and columns are the observations. + + A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains + the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents + a single variable (Scalar or 1D) then its variance is returned. + + The sample covariance of the variables :math:`x` and :math:`y` is given by: + + .. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + + where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and + :math:`\delta N` is the :attr:`correction`. + + If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance + is calculated, which is given by: + + .. math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + + where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is + provided, or :math:`w = f \times a` if both are provided, and + :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not + provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. 
+ These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + + Returns: + (Tensor) The covariance matrix of the variables. + + .. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. + + Example:: + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) + """ + ... +def cross(input: Tensor, other: Tensor, dim: Optional[_int] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cross(input, other, dim=None, *, out=None) -> Tensor + + + Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` + and :attr:`other`. + + Supports input of float, double, cfloat and cdouble dtypes. Also supports batches + of vectors, for which it computes the product along the dimension :attr:`dim`. + In this case, the output has the same batch dimensions as the inputs. + + .. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + + .. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + """ + ... +def crow_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.crow_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int = 0, reduction: _int = 1, zero_infinity: _bool = False) -> Tensor: ... +@overload +def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int = 0, reduction: _int = 1, zero_infinity: _bool = False) -> Tensor: ... +def cudnn_affine_grid_generator(theta: Tensor, N: _int, C: _int, H: _int, W: _int) -> Tensor: ... 
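+# Illustrative sketch (hypothetical `_demo_*` helper, not part of the generated stub):
+# checks the relation stated in the corrcoef docstring above,
+# R_ij = C_ij / sqrt(C_ii * C_jj), by deriving the correlation matrix from torch.cov.
+def _demo_cov_corrcoef_relation() -> None:
+    import torch  # repeated so the sketch is self-contained
+    x = torch.randn(2, 6)                     # 2 variables, 6 observations each
+    C = torch.cov(x)                          # unbiased covariance (correction=1 by default)
+    d = torch.sqrt(torch.diag(C))             # per-variable standard deviations
+    R = C / torch.outer(d, d)
+    assert torch.allclose(R, torch.corrcoef(x), atol=1e-6)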
+def cudnn_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, exponential_average_factor: _float, epsilon: _float) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def cudnn_convolution(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, allow_tf32: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def cudnn_convolution_add_relu(input: Tensor, weight: Tensor, z: Tensor, alpha: Optional[Union[Number, _complex]], bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def cudnn_convolution_relu(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def cudnn_convolution_transpose(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, allow_tf32: _bool) -> Tensor: ... +def cudnn_grid_sampler(input: Tensor, grid: Tensor) -> Tensor: ... +def cudnn_is_acceptable(input: Tensor) -> _bool: ... +@overload +def cummax(input: Tensor, dim: _int, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + ... +@overload +def cummax(input: Tensor, dim: Union[str, ellipsis, None], *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + ... +@overload +def cummin(input: Tensor, dim: _int, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + ... +@overload +def cummin(input: Tensor, dim: Union[str, ellipsis, None], *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + ... +@overload +def cumprod(input: Tensor, dim: _int, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. 
Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + ... +@overload +def cumprod(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + ... +@overload +def cumsum(input: Tensor, dim: _int, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + ... +@overload +def cumsum(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + ... +@overload +def cumulative_trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + ... +@overload +def cumulative_trapezoid(y: Tensor, *, dx: Union[Number, _complex] = 1, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + ... +def deg2rad(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + deg2rad(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in degrees to radians. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + """ + ... +def deg2rad_(input: Tensor) -> Tensor: ... +@overload +def dequantize(input: Tensor) -> Tensor: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + ... +@overload +def dequantize(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + ... +def det(input: Tensor) -> Tensor: + r""" + det(input) -> Tensor + + Alias for :func:`torch.linalg.det` + """ + ... +def detach(input: Tensor) -> Tensor: ... +def detach_(input: Tensor) -> Tensor: ... 
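+# Illustrative sketch (hypothetical `_demo_*` helper, not part of the generated stub):
+# contrasts torch.clone, which keeps an autograd connection to its input, with
+# torch.detach, which returns a storage-sharing view cut off from autograd, as the
+# clone docstring earlier in this file points out.
+def _demo_clone_vs_detach() -> None:
+    import torch  # repeated so the sketch is self-contained
+    x = torch.randn(3, requires_grad=True)
+    c = torch.clone(x)                        # differentiable copy
+    d = torch.detach(x)                       # non-differentiable view of the same storage
+    assert c.requires_grad and not d.requires_grad
+    assert d.data_ptr() == x.data_ptr()       # detach shares memory with its input
+    c.sum().backward()                        # gradients flow back through the clone
+    assert torch.equal(x.grad, torch.ones(3))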
+def detach_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.detach`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def diag(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + diag(input, diagonal=0, *, out=None) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + + The argument :attr:`diagonal` controls which diagonal to consider: + + - If :attr:`diagonal` = 0, it is the main diagonal. + - If :attr:`diagonal` > 0, it is above the main diagonal. + - If :attr:`diagonal` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + + Examples: + + Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) + """ + ... +def diag_embed(input: Tensor, offset: _int = 0, dim1: _int = -2, dim2: _int = -1) -> Tensor: + r""" + diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + + Creates a tensor whose diagonals of certain 2D planes (specified by + :attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. + To facilitate creating batched diagonal matrices, the 2D planes formed by + the last two dimensions of the returned tensor are chosen by default. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + The size of the new matrix will be calculated to make the specified diagonal + of the size of the last input dimension. + Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` + and :attr:`dim2` matters. Exchanging them is equivalent to changing the + sign of :attr:`offset`. + + Applying :meth:`torch.diagonal` to the output of this function with + the same arguments yields a matrix identical to input. However, + :meth:`torch.diagonal` has different default dimensions, so those + need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: -1. 
+ + Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) + """ + ... +def diagflat(input: Tensor, offset: _int = 0) -> Tensor: + r""" + diagflat(input, offset=0) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) + """ + ... +@overload +def diagonal(input: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
+ + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a, 0) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + ... +@overload +def diagonal(input: Tensor, *, outdim: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None], dim2: Union[str, ellipsis, None], offset: _int = 0) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. + + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a, 0) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + ... +def diagonal_copy(input: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.diagonal`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def diagonal_scatter(input: Tensor, src: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the diagonal elements of :attr:`input`, with respect to :attr:`dim1` + and :attr:`dim2`. + + This function returns a tensor with fresh storage; it does not + return a view. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. 
+ offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + + Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) + """ + ... +def diff(input: Tensor, n: _int = 1, dim: _int = -1, prepend: Optional[Tensor] = None, append: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + + Computes the n-th forward difference along the given dimension. + + The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order + differences are calculated by using :func:`torch.diff` recursively. + + Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) + """ + ... +def digamma(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + digamma(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.digamma`. + """ + ... +def dist(input: Tensor, other: Tensor, p: Union[Number, _complex] = 2) -> Tensor: + r""" + dist(input, other, p=2) -> Tensor + + Returns the p-norm of (:attr:`input` - :attr:`other`) + + The shapes of :attr:`input` and :attr:`other` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) + >>> torch.dist(x, y, 1) + tensor(2.6537) + """ + ... +def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + div(input, other, *, rounding_mode=None, out=None) -> Tensor + + Divides each element of the input ``input`` by the corresponding element of + :attr:`other`. + + .. math:: + \text{out}_i = \frac{\text{input}_i}{\text{other}_i} + + .. 
note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + Always promotes integer types to the default scalar type. + + Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + + Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. + * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + out (Tensor, optional): the output tensor. + + Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... [-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + """ + ... +@overload +def divide(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Tensor, *, rounding_mode: Optional[str], out: Optional[Tensor] = None) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Union[Number, _complex], *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +def dot(input: Tensor, tensor: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + dot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D tensors. + + .. note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + """ + ... +def dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... 
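As a small, hedged illustration of the stubs above (assuming a standard torch install; the values are purely illustrative), the sketch below checks the diag/diagonal/diag_embed round-trip and the three rounding behaviours of torch.div described in the generated docstrings.

    import torch

    # diag: 1-D input -> diagonal matrix, 2-D input -> its diagonal
    v = torch.arange(1.0, 4.0)                    # tensor([1., 2., 3.])
    m = torch.diag(v)                             # 3x3 matrix with v on the main diagonal
    assert torch.equal(torch.diag(m), v)          # extracting the diagonal recovers v

    # diag_embed inverts diagonal when the same offset/dim1/dim2 are passed explicitly
    d = torch.diagonal(m, offset=0, dim1=0, dim2=1)
    assert torch.equal(torch.diag_embed(d, offset=0, dim1=0, dim2=1), m)

    # div: true division by default, "trunc" rounds toward zero, "floor" rounds down
    a, b = torch.tensor([7.0, -7.0]), torch.tensor([2.0, 2.0])
    print(torch.div(a, b))                           # tensor([ 3.5000, -3.5000])
    print(torch.div(a, b, rounding_mode="trunc"))    # tensor([ 3., -3.])
    print(torch.div(a, b, rounding_mode="floor"))    # tensor([ 3., -4.])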
+def dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def dsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + ... +@overload +def dsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + ... +def dstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + dstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence depthwise (along third axis). + + This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + """ + ... +def embedding(weight: Tensor, indices: Tensor, padding_idx: Union[_int, SymInt] = -1, scale_grad_by_freq: _bool = False, sparse: _bool = False) -> Tensor: ... +@overload +def embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool, mode: _int, sparse: _bool, per_sample_weights: Optional[Tensor], include_last_offset: _bool, padding_idx: Optional[_int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def embedding_renorm_(input: Tensor, indices: Tensor, max_norm: _float, norm_type: _float) -> Tensor: ... +@overload +def empty(size: Sequence[Union[_int, SymInt]], *, memory_format: Optional[memory_format] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. 
+ + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(*size: _int, memory_format: Optional[memory_format] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. 
+ Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +def empty_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns an uninitialized tensor with the same size as :attr:`input`. + ``torch.empty_like(input)`` is equivalent to + ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) + """ + ... +def empty_permuted(size: Sequence[Union[_int, SymInt]], physical_layout: _size, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates an uninitialized, non-overlapping and dense tensor with the + specified :attr:`size`, with :attr:`physical_layout` specifying how the + dimensions are physically laid out in memory (each logical dimension is listed + from outermost to innermost). :attr:`physical_layout` is a generalization + of NCHW/NHWC notation: if each dimension is assigned a number according to + what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` + while NHWC is ``(0, 2, 3, 1)``. 
Equivalently, the strides of the output + tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` + (notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + + Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense + tensor with no overlaps. If possible, prefer using this function over + :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) + """ + ... +def empty_quantized(size: _size, qtensor: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def empty_strided(size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + + .. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + + .. 
note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) + """ + ... +@overload +def eq(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + ... +@overload +def eq(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + ... +def equal(input: Tensor, other: Tensor) -> _bool: + r""" + equal(input, other) -> bool + + ``True`` if two tensors have the same size and elements, ``False`` otherwise. + + Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True + """ + ... +def erf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erf(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erf`. + """ + ... 
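A hedged sketch of the empty_permuted semantics and the eq/equal distinction documented above; the stride values simply restate the docstring's own NHWC example, everything else is illustrative.

    import torch

    # empty_permuted: physical_layout fixes the memory order, the logical shape stays (2, 3, 5, 7)
    t = torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1))
    assert t.shape == (2, 3, 5, 7)
    assert t.stride() == (105, 1, 21, 3)          # same strides as memory_format=torch.channels_last

    # eq compares element-wise and returns a bool tensor; equal returns one bool for the whole tensor
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 0, 3])
    print(torch.eq(a, b))                         # tensor([ True, False,  True])
    print(torch.equal(a, b))                      # False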
+def erf_(input: Tensor) -> Tensor: ... +def erfc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erfc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfc`. + """ + ... +def erfc_(input: Tensor) -> Tensor: ... +def erfinv(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erfinv(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfinv`. + """ + ... +def exp(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + exp(input, *, out=None) -> Tensor + + Returns a new tensor with the exponential of the elements + of the input tensor :attr:`input`. + + .. math:: + y_{i} = e^{x_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) + """ + ... +def exp2(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + exp2(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.exp2`. + """ + ... +def exp2_(input: Tensor) -> Tensor: ... +def exp_(input: Tensor) -> Tensor: ... +def expand_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], *, implicit: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.expand`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def expm1(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + expm1(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expm1`. + """ + ... +def expm1_(input: Tensor) -> Tensor: ... +@overload +def eye(n: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + ... 
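A brief, non-authoritative sketch tying together a few of the element-wise stubs above: expm1 agrees with exp(x) - 1 (it exists for better accuracy near zero), and eye relates to diag as expected. Values are illustrative.

    import math
    import torch

    x = torch.tensor([0.0, math.log(2.0)])
    assert torch.allclose(torch.exp(x), torch.tensor([1.0, 2.0]))
    # expm1(x) computes exp(x) - 1, but with better precision for small x
    assert torch.allclose(torch.expm1(x), torch.exp(x) - 1.0)

    # eye(n) is the n x n identity; eye(n, m) is rectangular
    assert torch.equal(torch.eye(3), torch.diag(torch.ones(3)))
    assert torch.eye(2, 3).shape == (2, 3)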
+@overload +def eye(n: Union[_int, SymInt], m: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + ... +def fake_quantize_per_channel_affine(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) + """ + ... 
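To make the per-channel fake-quantization formula above concrete, here is a hedged sketch that recomputes it by hand with torch.round/torch.clamp and compares against the built-in. The values are arbitrary and chosen away from .5 rounding boundaries; exact agreement is assumed rather than guaranteed, hence allclose.

    import torch

    x = torch.tensor([[0.30, -0.42],
                      [1.27,  0.05]])
    scales = torch.tensor([0.1, 0.2])               # one scale per channel along axis=1
    zero_points = torch.zeros(2, dtype=torch.int32)

    fq = torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255)

    # out = (clamp(round(x / scale) + zero_point, quant_min, quant_max) - zero_point) * scale
    manual = (torch.clamp(torch.round(x / scales) + zero_points, 0, 255) - zero_points) * scales
    assert torch.allclose(fq, manual)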
+@overload +def fake_quantize_per_tensor_affine(input: Tensor, scale: _float, zero_point: _int, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + ... +@overload +def fake_quantize_per_tensor_affine(input: Tensor, scale: Tensor, zero_point: Tensor, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + ... +def fbgemm_linear_fp16_weight(input: Tensor, packed_weight: Tensor, bias: Tensor) -> Tensor: ... +def fbgemm_linear_fp16_weight_fp32_activation(input: Tensor, packed_weight: Tensor, bias: Tensor) -> Tensor: ... +def fbgemm_linear_int8_weight(input: Tensor, weight: Tensor, packed: Tensor, col_offsets: Tensor, weight_scale: Union[Number, _complex], weight_zero_point: Union[Number, _complex], bias: Tensor) -> Tensor: ... +def fbgemm_linear_int8_weight_fp32_activation(input: Tensor, weight: Tensor, packed: Tensor, col_offsets: Tensor, weight_scale: Union[Number, _complex], weight_zero_point: Union[Number, _complex], bias: Tensor) -> Tensor: ... +def fbgemm_linear_quantize_weight(input: Tensor) -> Tuple[Tensor, Tensor, _float, _int]: ... +def fbgemm_pack_gemm_matrix_fp16(input: Tensor) -> Tensor: ... 
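The two fake_quantize_per_tensor_affine overloads above differ only in whether scale/zero_point are Python scalars or tensors; a small sketch with illustrative values (patterned on the docstring example) checks that both paths agree.

    import torch

    x = torch.tensor([0.06, 0.97, 0.40, -1.08])

    a = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)                              # scalar scale / zero_point
    b = torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255)  # tensor scale / zero_point

    assert torch.allclose(a, b)
    # -1.08 maps below quant_min=0, so it is clamped to zero_point, i.e. 0.0 after dequantization
    assert a[3].item() == 0.0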
+@overload +def fbgemm_pack_quantized_matrix(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor, K: _int, N: _int) -> Tensor: ... +def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +@overload +def fill(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill(input: Tensor, value: Union[Number, _complex]) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def fix(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fix(input, *, out=None) -> Tensor + + Alias for :func:`torch.trunc` + """ + ... +def fix_(input: Tensor) -> Tensor: ... +@overload +def flatten(input: Tensor, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, start_dim: _int, end_dim: _int, out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... 
[[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, start_dim: Union[str, ellipsis, None], end_dim: Union[str, ellipsis, None], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, dims: Sequence[Union[str, ellipsis, None]], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +def flip(input: Tensor, dims: _size) -> Tensor: + r""" + flip(input, dims) -> Tensor + + Reverse the order of an n-D tensor along given axis in dims. + + .. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + + Args: + input (Tensor): the input tensor. 
+ dims (a list or tuple): axis to flip on + + Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) + """ + ... +def fliplr(input: Tensor) -> Tensor: + r""" + fliplr(input) -> Tensor + + Flip tensor in the left/right direction, returning a new tensor. + + Flip the entries in each row in the left/right direction. + Columns are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 2-D. + + .. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + + Args: + input (Tensor): Must be at least 2-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) + """ + ... +def flipud(input: Tensor) -> Tensor: + r""" + flipud(input) -> Tensor + + Flip tensor in the up/down direction, returning a new tensor. + + Flip the entries in each column in the up/down direction. + Rows are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 1-D. + + .. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + + Args: + input (Tensor): Must be at least 1-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) + """ + ... +@overload +def float_power(input: Tensor, exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... 
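A small sketch of the view/copy behaviour spelled out in the flatten and flip docstrings above; the data_ptr check assumes a contiguous input, which is the case where flatten can return a view.

    import torch

    t = torch.arange(8).reshape(2, 2, 2)

    # flatten of a contiguous tensor returns a view sharing the same storage ...
    flat = torch.flatten(t)
    assert flat.data_ptr() == t.data_ptr()

    # ... whereas flip always copies, so writing to the result does not touch the input
    f = torch.flip(t, [0])
    f[0, 0, 0] = 99
    assert t[1, 0, 0].item() == 4

    # fliplr / flipud on a 2-D tensor are just flips along dim 1 / dim 0
    x = torch.arange(4).reshape(2, 2)
    assert torch.equal(torch.fliplr(x), torch.flip(x, [1]))
    assert torch.equal(torch.flipud(x), torch.flip(x, [0]))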
+@overload +def float_power(self: Union[Number, _complex], exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... +@overload +def float_power(input: Tensor, exponent: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... +def floor(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + floor(input, *, out=None) -> Tensor + + Returns a new tensor with the floor of the elements of :attr:`input`, + the largest integer less than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) + """ + ... +def floor_(input: Tensor) -> Tensor: ... +def floor_divide(input: Union[Tensor, Number], other: Union[Tensor, Number], *, out: Optional[Tensor] = None) -> Tensor: + r""" + floor_divide(input, other, *, out=None) -> Tensor + + .. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + + Computes :attr:`input` divided by :attr:`other`, elementwise, and floors + the result. + + .. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + + + + Supports broadcasting to a common shape, type promotion, and integer and float inputs. + + Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) + """ + ... +def fmax(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmax(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.maximum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) + """ + ... +def fmin(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmin(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.minimum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) + """ + ... +@overload +def fmod(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. 
+ The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + ... +@overload +def fmod(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + ... +def frac(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + frac(input, *, out=None) -> Tensor + + Computes the fractional portion of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + + Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) + """ + ... +def frac_(input: Tensor) -> Tensor: ... 
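+# An illustrative contrast between torch.fmod and torch.remainder, assuming default
+# CPU float tensors: fmod truncates the quotient (the result's sign follows the
+# dividend), while remainder floors it (the sign follows the divisor), matching the
+# identity a - a.div(b, rounding_mode="trunc") * b quoted in the docstring above.
+# >>> a = torch.tensor([-3., -2., -1., 1., 2., 3.])
+# >>> torch.fmod(a, 2)
+# tensor([-1., -0., -1.,  1.,  0.,  1.])
+# >>> torch.remainder(a, 2)
+# tensor([1., 0., 1., 1., 0., 1.])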
+def frexp(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.frexp: + r""" + frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + + Decomposes :attr:`input` into mantissa and exponent tensors + such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + + The range of mantissa is the open interval (-1, 1). + + Supports float inputs. + + Args: + input (Tensor): the input tensor + + + Keyword args: + out (tuple, optional): the output tensors + + Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) + """ + ... +def frobenius_norm(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +def from_file(filename: str, shared: Optional[_bool] = None, size: Optional[_int] = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + + Creates a CPU tensor with a storage backed by a memory-mapped file. + + If ``shared`` is True, then memory is shared between processes. All changes are written to the file. + If ``shared`` is False, then changes to the tensor do not affect the file. + + ``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain + at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + + .. note:: + Only CPU tensors can be mapped to files. + + .. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. + + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """ + ... +def from_numpy(ndarray) -> Tensor: + r""" + from_numpy(ndarray) -> Tensor + + Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + + The returned tensor and :attr:`ndarray` share the same memory. Modifications to + the tensor will be reflected in the :attr:`ndarray` and vice versa. 
The returned + tensor is not resizable. + + It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, + ``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, + ``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, + and ``bool``. + + .. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + """ + ... +def frombuffer(buffer: Any, *, dtype: _dtype, count: int = -1, offset: int = 0, requires_grad: _bool = False) -> Tensor: + r""" + frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + + Creates a 1-dimensional :class:`Tensor` from an object that implements + the Python buffer protocol. + + Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of + the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` + elements. + + Note that either of the following must be true: + + 1. :attr:`count` is a positive non-zero number, and the total number of bytes + in the buffer is more than :attr:`offset` plus :attr:`count` times the size + (in bytes) of :attr:`dtype`. + + 2. :attr:`count` is negative, and the length (number of bytes) of the buffer + subtracted by the :attr:`offset` is a multiple of the size (in bytes) of + :attr:`dtype`. + + The returned tensor and buffer share the same memory. Modifications to + the tensor will be reflected in the buffer and vice versa. The returned + tensor is not resizable. + + .. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + + .. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + + .. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + + Args: + buffer (object): a Python object that exposes the buffer interface. + + Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) + """ + ... 
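+# An illustrative sketch of the count/offset arithmetic described above, assuming a
+# CPU buffer: offset skips whole bytes, count is measured in elements of dtype, and
+# the resulting tensor shares memory with the buffer.
+# >>> import array
+# >>> buf = array.array('h', [10, 20, 30, 40, 50])   # five int16 values, 10 bytes
+# >>> t = torch.frombuffer(buf, dtype=torch.int16, count=2, offset=4)
+# >>> t                                               # skips 10 and 20, reads two elements
+# tensor([30, 40], dtype=torch.int16)
+# >>> buf[2] = -1                                     # writes through the shared memory
+# >>> t
+# tensor([-1, 40], dtype=torch.int16)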
+@overload +def full(size: _size, fill_value: Union[Number, _complex], *, out: Optional[Tensor] = None, layout: _layout = strided, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... +@overload +def full(size: _size, fill_value: Union[Number, _complex], *, names: List[Union[str, None]], layout: _layout = strided, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... 
+@overload +def full(size: Sequence[Union[_int, SymInt]], fill_value: Union[Number, _complex], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... +@overload +def full(size: _size, fill_value: Union[Number, _complex], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... 
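+# A small illustrative sketch of the dtype inference mentioned above ("the tensor's
+# dtype is inferred from fill_value"), assuming default settings (float32 default dtype):
+# >>> torch.full((2, 3), 7)                      # integer fill value -> int64 tensor
+# tensor([[7, 7, 7],
+#         [7, 7, 7]])
+# >>> torch.full((2,), True)                     # bool fill value -> bool tensor
+# tensor([True, True])
+# >>> torch.full((2,), 1.5, dtype=torch.float64) # explicit dtype overrides inference
+# tensor([1.5000, 1.5000], dtype=torch.float64)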
+def full_like(input: Tensor, fill_value: Union[Number, _complex], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. + ``torch.full_like(input, fill_value)`` is equivalent to + ``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + fill_value: the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +def fused_moving_avg_obs_fake_quant(input: Tensor, observer_on: Tensor, fake_quant_on: Tensor, running_min: Tensor, running_max: Tensor, scale: Tensor, zero_point: Tensor, averaging_const: _float, quant_min: _int, quant_max: _int, ch_axis: _int, per_row_fake_quant: _bool = False, symmetric_quant: _bool = False) -> Tensor: ... +@overload +def gather(input: Tensor, dim: _int, index: Tensor, *, sparse_grad: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + ... +@overload +def gather(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, *, sparse_grad: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. 
+ + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + ... +def gcd(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + gcd(input, other, *, out=None) -> Tensor + + Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`gcd(0, 0) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) + """ + ... +def gcd_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def ge(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + ... +@overload +def ge(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + ... 
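+# An illustrative sketch of the gather indexing rule quoted above (for dim=1,
+# out[i][j] = input[i][index[i][j]]; for dim=0, out[i][j] = input[index[i][j]][j]),
+# assuming small CPU tensors:
+# >>> t = torch.tensor([[1, 2], [3, 4]])
+# >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))   # pick columns per row
+# tensor([[1, 1],
+#         [4, 3]])
+# >>> torch.gather(t, 0, torch.tensor([[0, 1], [1, 0]]))   # pick rows per column
+# tensor([[1, 4],
+#         [3, 2]])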
+def geqrf(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.geqrf: + r""" + geqrf(input, *, out=None) -> (Tensor, Tensor) + + This is a low-level function for calling LAPACK's geqrf directly. This function + returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + + Computes a QR decomposition of :attr:`input`. + Both `Q` and `R` matrices are stored in the same output tensor `a`. + The elements of `R` are stored on and above the diagonal. + Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` + are stored below the diagonal. + The results of this function can be used together with :func:`torch.linalg.householder_product` + to obtain the `Q` matrix or + with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, + for an efficient matrix-matrix multiplication. + + See `LAPACK documentation for geqrf`_ for further details. + + .. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + + Args: + input (Tensor): the input matrix + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + + .. _LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + """ + ... +def ger(input: Tensor, vec2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ger(input, vec2, *, out=None) -> Tensor + + Alias of :func:`torch.outer`. + + .. warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. + """ + ... +def get_default_dtype() -> _dtype: + r""" + get_default_dtype() -> torch.dtype + + Get the current default floating point :class:`torch.dtype`. + + Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + """ + ... +def get_num_interop_threads() -> _int: + r""" + get_num_interop_threads() -> int + + Returns the number of threads used for inter-op parallelism on CPU + (e.g. in JIT interpreter) + """ + ... +def get_num_threads() -> _int: + r""" + get_num_threads() -> int + + Returns the number of threads used for parallelizing CPU operations + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Optional[Union[Number, _complex]] = None, dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. 
+ + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. 
+ >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Sequence[Union[Number, _complex]], dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. 
Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. 
For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Sequence[Union[Number, _complex]], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Tuple[Tensor, ...], List[Tensor]], dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Number, _complex], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Tuple[Tensor, ...], List[Tensor]], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def greater(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + ... +@overload +def greater(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + ... +@overload +def greater_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + ... +@overload +def greater_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + ... +def grid_sampler(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def grid_sampler_2d(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def group_norm(input: Tensor, num_groups: _int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: _float = 1e-05, cudnn_enabled: _bool = True) -> Tensor: ... +@overload +def gru(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def gru(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... 
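+# --- Illustrative sketch (editorial note, not part of the generated stub) ---
+# The gradient() docstrings above quote the non-uniform central-difference
+# estimate
+#     f'(x) ~= (h_l^2 f(x+h_r) - h_r^2 f(x-h_l) + (h_r^2 - h_l^2) f(x))
+#              / (h_r h_l^2 + h_r^2 h_l)
+# The snippet below re-derives one interior value returned by torch.gradient
+# for the docstring's own f(x) = x^2 samples; the variable names are
+# illustrative choices, not part of the API.
+import torch
+
+_coords = torch.tensor([-2.0, -1.0, 1.0, 4.0])      # unevenly spaced x values
+_values = _coords ** 2                               # f(x) = x^2, so f'(x) = 2x
+(_grad,) = torch.gradient(_values, spacing=(_coords,))
+
+_h_l = _coords[1] - _coords[0]                       # spacing to the left of x = -1
+_h_r = _coords[2] - _coords[1]                       # spacing to the right of x = -1
+_manual = (_h_l**2 * _values[2] - _h_r**2 * _values[0]
+           + (_h_r**2 - _h_l**2) * _values[1]) / (_h_r * _h_l**2 + _h_r**2 * _h_l)
+
+# _grad is tensor([-3., -2., 2., 5.]) as in the docstring example; _manual
+# evaluates to -2.0, matching both _grad[1] and the exact derivative f'(-1).
+# -----------------------------------------------------------------------------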
+def gru_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... +@overload +def gt(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + ... +@overload +def gt(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + ... +@overload +def hamming_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 
Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hamming_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... 
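+# --- Illustrative sketch (editorial note, not part of the generated stub) ---
+# Quick numerical check of the identity quoted in the hamming_window
+# docstrings above:
+#     torch.hamming_window(L, periodic=True)
+#         == torch.hamming_window(L + 1, periodic=False)[:-1]
+# The window length below is an arbitrary illustrative choice.
+import torch
+
+_L = 8
+_periodic = torch.hamming_window(_L, periodic=True)
+_symmetric = torch.hamming_window(_L + 1, periodic=False)
+
+# The periodic window is the symmetric window with its duplicate last sample
+# dropped, so the two agree elementwise.
+assert torch.allclose(_periodic, _symmetric[:-1])
+# -----------------------------------------------------------------------------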
+@overload +def hamming_window(window_length: _int, periodic: _bool, alpha: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hamming_window(window_length: _int, periodic: _bool, alpha: _float, beta: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. 
Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hann_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def hann_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +def hardshrink(input: Tensor, lambd: Union[Number, _complex] = 0.5, *, out: Optional[Tensor] = None) -> Tensor: ... +def heaviside(input: Tensor, values: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + heaviside(input, values, *, out=None) -> Tensor + + Computes the Heaviside step function for each element in :attr:`input`. + The Heaviside step function is defined as: + + .. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} + + + Args: + input (Tensor): the input tensor. 
+ values (Tensor): The values to use where :attr:`input` is zero. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + """ + ... +def hinge_embedding_loss(input: Tensor, target: Tensor, margin: _float = 1.0, reduction: _int = 1) -> Tensor: ... +def histc(input: Tensor, bins: _int = 100, min: Union[Number, _complex] = 0, max: Union[Number, _complex] = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + + Computes the histogram of a tensor. + + The elements are sorted into equal width bins between :attr:`min` and + :attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and + maximum values of the data are used. + + Elements lower than min and higher than max and ``NaN`` elements are ignored. + + Args: + input (Tensor): the input tensor. + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: Histogram represented as a tensor + + Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) + """ + ... +@overload +def histogram(input: Tensor, bins: Tensor, *, weight: Optional[Tensor] = None, density: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. 
+ + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + ... +@overload +def histogram(input: Tensor, bins: _int = 100, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + ... +@overload +def histogramdd(input: Tensor, bins: _int, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. 
+ + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... 
weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +@overload +def histogramdd(input: Tensor, bins: _size, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. 
+ If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +@overload +def histogramdd(input: Tensor, bins: Union[Tuple[Tensor, ...], List[Tensor]], range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. 
A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def hsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. 
+ + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + ... +@overload +def hsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + ... +def hspmm(mat1: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + hspmm(mat1, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of a :ref:`sparse COO matrix + ` :attr:`mat1` and a strided matrix :attr:`mat2`. The + result is a (1 + 1)-dimensional :ref:`hybrid COO matrix + `. + + Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + """ + ... +def hstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + hstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence horizontally (column wise). + + This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. 
+ + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + """ + ... +def hypot(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + hypot(input, other, *, out=None) -> Tensor + + Given the legs of a right triangle, return its hypotenuse. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + """ + ... +def i0(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + i0(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.i0`. + """ + ... +def i0_(input: Tensor) -> Tensor: ... +def igamma(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + igamma(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammainc`. + """ + ... +def igammac(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + igammac(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammaincc`. + """ + ... +def imag(input: Tensor) -> Tensor: + r""" + imag(input) -> Tensor + + Returns a new tensor containing imaginary values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + .. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + """ + ... +@overload +def index_add(input: Tensor, dim: _int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_add(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_copy(input: Tensor, dim: _int, index: Tensor, source: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_copy(input, dim, index, source, *, out=None) -> Tensor + + See :meth:`~Tensor.index_copy_` for function description. + """ + ... +@overload +def index_copy(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy(input, dim, index, source, *, out=None) -> Tensor + + See :meth:`~Tensor.index_copy_` for function description. + """ + ... +@overload +def index_fill(input: Tensor, dim: _int, index: Tensor, value: Tensor) -> Tensor: ...
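+# Illustrative usage sketch for the index_add/index_copy overloads above; the
+# tensors and values below are assumed for the example and are not part of the
+# generated stub:
+#   >>> x = torch.zeros(3, 3)
+#   >>> idx = torch.tensor([0, 2])
+#   >>> src = torch.ones(2, 3)
+#   >>> torch.index_add(x, 0, idx, src, alpha=2.0)  # rows 0 and 2 become 2.0
+#   >>> torch.index_copy(x, 0, idx, src)            # rows 0 and 2 become 1.0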
+@overload +def index_fill(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ... +@overload +def index_fill(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: ... +@overload +def index_fill(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def index_put(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +def index_put_(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +def index_reduce(input: Tensor, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool = True, out: Optional[Tensor] = None) -> Tensor: + r""" + index_reduce(input, dim, index, source, reduce, *, include_self=True, out=None) -> Tensor + + See :meth:`~Tensor.index_reduce_` for function description. + """ + ... +@overload +def index_select(input: Tensor, dim: _int, index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + ... +@overload +def index_select(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + ... +def indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def init_num_threads() -> None: ... +def inner(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + inner(input, other, *, out=None) -> Tensor + + Computes the dot product for 1D tensors. For higher dimensions, sums the product + of elements from :attr:`input` and :attr:`other` along their last dimension. + + .. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + + Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + + Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. + + Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) + """ + ... +def instance_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], use_input_stats: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ... +def int_repr(input: Tensor) -> Tensor: ... +def inverse(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + inverse(input, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.inv` + """ + ... +def is_complex(input: Tensor) -> _bool: + r""" + is_complex(input) -> (bool) + + Returns True if the data type of :attr:`input` is a complex data type i.e., + one of ``torch.complex64``, and ``torch.complex128``. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_conj(input: Tensor) -> _bool: + r""" + is_conj(input) -> (bool) + + Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. 
+ + Args: + input (Tensor): the input tensor. + """ + ... +def is_distributed(input: Tensor) -> _bool: ... +def is_floating_point(input: Tensor) -> _bool: + r""" + is_floating_point(input) -> (bool) + + Returns True if the data type of :attr:`input` is a floating point data type i.e., + one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_grad_enabled() -> _bool: + r""" + is_grad_enabled() -> (bool) + + Returns True if grad mode is currently enabled. + """ + ... +def is_inference(input: Tensor) -> _bool: + r""" + is_inference(input) -> (bool) + + Returns True if :attr:`input` is an inference tensor. + + A non-view tensor is an inference tensor if and only if it was + allocated during inference mode. A view tensor is an inference + tensor if and only if the tensor it is a view of is an inference tensor. + + For details on inference mode please see + `Inference Mode `_. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_inference_mode_enabled() -> _bool: + r""" + is_inference_mode_enabled() -> (bool) + + Returns True if inference mode is currently enabled. + """ + ... +def is_neg(input: Tensor) -> _bool: ... +def is_nonzero(input: Tensor) -> _bool: + r""" + is_nonzero(input) -> (bool) + + Returns True if the :attr:`input` is a single element tensor which is not equal to zero + after type conversions. + i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or + ``torch.tensor([False])``. + Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case + of sparse tensors). + + Args: + input (Tensor): the input tensor. + + Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous + """ + ... +def is_same_size(input: Tensor, other: Tensor) -> _bool: ... +def is_signed(input: Tensor) -> _bool: ... +def is_vulkan_available() -> _bool: ... +def isclose(input: Tensor, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> Tensor: + r""" + isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + Returns a new tensor with boolean elements representing if each element of + :attr:`input` is "close" to the corresponding element of :attr:`other`. + Closeness is defined as: + + .. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert + + + where :attr:`input` and :attr:`other` are finite. Where :attr:`input` + and/or :attr:`other` are nonfinite they are close if and only if + they are equal, with NaNs being considered equal to each other when + :attr:`equal_nan` is True. + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + + Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) + """ + ... +def isfinite(input: Tensor) -> Tensor: + r""" + isfinite(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element is `finite` or not. + + Real values are finite when they are not NaN, negative infinity, or infinity. + Complex values are finite when both their real and imaginary parts are finite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + + Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) + """ + ... +@overload +def isin(elements: Tensor, test_elements: Tensor, *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +@overload +def isin(element: Union[Number, _complex], test_elements: Tensor, *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. 
Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +@overload +def isin(elements: Tensor, test_element: Union[Number, _complex], *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +def isinf(input: Tensor) -> Tensor: + r""" + isinf(input) -> Tensor + + Tests if each element of :attr:`input` is infinite + (positive or negative infinity) or not. + + .. note:: + Complex values are infinite when their real or imaginary part is + infinite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + + Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) + """ + ... +def isnan(input: Tensor) -> Tensor: + r""" + isnan(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` + is NaN or not. Complex values are considered NaN when either their real + and/or imaginary part is NaN. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + + Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) + """ + ... +def isneginf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + isneginf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is negative infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) + """ + ... +def isposinf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + isposinf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is positive infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) + """ + ... +def isreal(input: Tensor) -> Tensor: + r""" + isreal(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. + All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + + Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) + """ + ... +def istft(input: Tensor, n_fft: _int, hop_length: Optional[_int] = None, win_length: Optional[_int] = None, window: Optional[Tensor] = None, center: _bool = True, normalized: _bool = False, onesided: Optional[_bool] = None, length: Optional[_int] = None, return_complex: _bool = False) -> Tensor: ... +@overload +def kaiser_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... 
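+# Illustrative usage sketch for the kaiser_window overloads; the signal length,
+# n_fft and beta values are assumed for the example, not part of the generated
+# stub:
+#   >>> w = torch.kaiser_window(400, periodic=True, beta=12.0)
+#   >>> spec = torch.stft(torch.randn(16000), n_fft=400, window=w, return_complex=True)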
+@overload +def kaiser_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... +@overload +def kaiser_window(window_length: _int, periodic: _bool, beta: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. 
+ The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... +def kl_div(input: Tensor, target: Tensor, reduction: _int = 1, *, log_target: _bool = False) -> Tensor: ... +def kron(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + kron(input, other, *, out=None) -> Tensor + + Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + + If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a + :math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a + :math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + + .. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + + where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. + If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + + Supports real-valued and complex-valued inputs. + + .. note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + + Arguments: + input (Tensor) + other (Tensor) + + Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) + """ + ... 
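+# Quick illustrative check of the block structure described in the kron
+# docstring above; the shapes and values are assumed for the example:
+#   >>> B = torch.arange(1., 5.).reshape(2, 2)
+#   >>> K = torch.kron(torch.eye(2), B)  # 4x4 block-diagonal result
+#   >>> torch.equal(K[:2, :2], B) and torch.equal(K[2:, 2:], B)
+#   True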
+@overload +def kthvalue(input: Tensor, k: _int, dim: _int = -1, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + ... +@overload +def kthvalue(input: Tensor, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) 
+ >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + ... +def layer_norm(input: Tensor, normalized_shape: Sequence[Union[_int, SymInt]], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: _float = 1e-05, cudnn_enable: _bool = True) -> Tensor: ... +def lcm(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lcm(input, other, *, out=None) -> Tensor + + Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) + """ + ... +def lcm_(input: Tensor, other: Tensor) -> Tensor: ... +def ldexp(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ldexp(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by 2 ** :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i + + + Typically this function is used to construct floating point numbers by multiplying + mantissas in :attr:`input` with integral powers of two created from the exponents + in :attr:`other`. + + Args: + input (Tensor): the input tensor. + other (Tensor): a tensor of exponents, typically integers. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + """ + ... +def ldexp_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def le(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + ... +@overload +def le(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + ... +@overload +def lerp(input: Tensor, end: Tensor, weight: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + ... +@overload +def lerp(input: Tensor, end: Tensor, weight: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + ... +@overload +def less(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + ... +@overload +def less(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + ... +@overload +def less_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + ... 
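+# Illustrative sketch for the lerp overloads documented above, contrasting a
+# scalar weight with a per-element tensor weight (values assumed for the example):
+#   >>> lo, hi = torch.zeros(4), torch.arange(4.)
+#   >>> torch.lerp(lo, hi, 0.25)
+#   tensor([0.0000, 0.2500, 0.5000, 0.7500])
+#   >>> torch.lerp(lo, hi, torch.tensor([0., 0.5, 0.5, 1.]))
+#   tensor([0.0000, 0.5000, 1.0000, 3.0000])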
+@overload +def less_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + ... +def lgamma(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lgamma(input, *, out=None) -> Tensor + + Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + + .. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) + """ + ... +@overload +def linspace(start: Number, end: Number, steps: Optional[_int] = None, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... 
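+# Illustrative check of the spacing formula in the linspace docstring above,
+# step = (end - start) / (steps - 1); the endpoints are assumed for the example:
+#   >>> t = torch.linspace(0., 1., steps=5)
+#   >>> t
+#   tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
+#   >>> torch.allclose(t[1:] - t[:-1], torch.full((4,), 0.25))
+#   True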
+@overload +def linspace(start: Tensor, end: Tensor, steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Union[Number, _complex], end: Tensor, steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. 
If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Tensor, end: Union[Number, _complex], steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Union[Number, _complex], end: Union[Number, _complex], steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +def log(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{e} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) + """ + ... +def log10(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log10(input, *, out=None) -> Tensor + + Returns a new tensor with the logarithm to the base 10 of the elements + of :attr:`input`. + + .. 
math:: + y_{i} = \log_{10} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + """ + ... +def log10_(input: Tensor) -> Tensor: ... +def log1p(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log1p(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + + .. math:: + y_i = \log_{e} (x_i + 1) + + .. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) + """ + ... +def log1p_(input: Tensor) -> Tensor: ... +def log2(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log2(input, *, out=None) -> Tensor + + Returns a new tensor with the logarithm to the base 2 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{2} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + """ + ... +def log2_(input: Tensor) -> Tensor: ... +def log_(input: Tensor) -> Tensor: ... +@overload +def log_softmax(input: Tensor, dim: _int, dtype: Optional[_dtype] = None, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def log_softmax(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: ... +def logaddexp(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logaddexp(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs. + + Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful + in statistics where the calculated probabilities of events may be so small as to + exceed the range of normal floating point numbers. In such cases the logarithm + of the calculated probability is stored. This function allows adding + probabilities stored in such a fashion. + + This op should be disambiguated with :func:`torch.logsumexp` which performs a + reduction on a single tensor. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) + """ + ... +def logaddexp2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logaddexp2(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs in base-2. + + Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See + :func:`torch.logaddexp` for more details. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + """ + ... +@overload +def logcumsumexp(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{j=0}^{i} \exp(x_{ij}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + ... +@overload +def logcumsumexp(input: Tensor, dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{j=0}^{i} \exp(x_{ij}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + ... +def logdet(input: Tensor) -> Tensor: + r""" + logdet(input) -> Tensor + + Calculates log determinant of a square matrix or batches of square matrices. + + It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has + a negative determinant. + + .. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + + .. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + + Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + + Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) + """ + ... +def logical_and(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_and(input, other, *, out=None) -> Tensor + + Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the tensor to compute AND with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) + """ + ... +def logical_not(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_not(input, *, out=None) -> Tensor + + Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool + dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) + """ + ... +def logical_or(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_or(input, other, *, out=None) -> Tensor + + Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute OR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) + """ + ... +def logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_xor(input, other, *, out=None) -> Tensor + + Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute XOR with + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) + """ + ... +def logit(input: Tensor, eps: Optional[_float] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logit(input, eps=None, *, out=None) -> Tensor + + Alias for :func:`torch.special.logit`. + """ + ... +def logit_(input: Tensor, eps: Optional[_float] = None) -> Tensor: ... +@overload +def logspace(start: Number, end: Number, steps: Optional[_int] = None, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Tensor, end: Tensor, steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... 
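+# A minimal usage sketch for the linspace/logspace overloads above (illustrative only,
+# not emitted by the stub generator; printed values are approximate). Per the formula in
+# the docstrings, logspace(a, b, n, base) equals base ** linspace(a, b, n):
+#
+#     >>> import torch
+#     >>> torch.logspace(0.0, 3.0, steps=4)            # 10 ** linspace(0, 3, 4)
+#     tensor([   1.,   10.,  100., 1000.])
+#     >>> torch.logspace(0.0, 3.0, steps=4, base=2.0)  # 2 ** linspace(0, 3, 4)
+#     tensor([1., 2., 4., 8.])
+#     >>> torch.allclose(torch.logspace(0.0, 3.0, 4), 10 ** torch.linspace(0.0, 3.0, 4))
+#     True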
+@overload +def logspace(start: Union[Number, _complex], end: Tensor, steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Tensor, end: Union[Number, _complex], steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. 
math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Union[Number, _complex], end: Union[Number, _complex], steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. 
+ + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logsumexp(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + ... +@overload +def logsumexp(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. 
+ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + ... +@overload +def lstm(data: Tensor, batch_sizes: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def lstm(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor, Tensor]: ... +def lstm_cell(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def lt(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + ... +@overload +def lt(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + ... +def lu_solve(input: Tensor, LU_data: Tensor, LU_pivots: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + + Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted + LU factorization of A from :func:`~linalg.lu_factor`. + + This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + + .. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. 
code:: python + + X = linalg.lu_solve(LU, pivots, B) + + Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) + """ + ... +def lu_unpack(LU_data: Tensor, LU_pivots: Tensor, unpack_data: _bool = True, unpack_pivots: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.lu_unpack: + r""" + lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + + .. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. + + Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + + Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + + Returns: + A namedtuple ``(P, L, U)`` + + Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + """ + ... +def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor, margin: _float = 0.0, reduction: _int = 1) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Tensor) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor: ... +def masked_select(input: Tensor, mask: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + masked_select(input, mask, *, out=None) -> Tensor + + Returns a new 1-D tensor which indexes the :attr:`input` tensor according to + the boolean mask :attr:`mask` which is a `BoolTensor`. 
+ + The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need + to match, but they must be :ref:`broadcastable `. + + .. note:: The returned tensor does **not** use the same storage + as the original tensor + + Args: + input (Tensor): the input tensor. + mask (BoolTensor): the tensor containing the binary mask to index with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) + """ + ... +def matmul(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + matmul(input, other, *, out=None) -> Tensor + + Matrix product of two tensors. + + The behavior depends on the dimensionality of the tensors as follows: + + - If both tensors are 1-dimensional, the dot product (scalar) is returned. + - If both arguments are 2-dimensional, the matrix-matrix product is returned. + - If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. + - If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. + - If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + + This operation has support for arguments with :ref:`sparse layouts`. In particular the + matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions + as :func:`torch.mm` + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. 
+ + Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + """ + ... +def matrix_exp(input: Tensor) -> Tensor: + r""" + matrix_exp(A) -> Tensor + + Alias for :func:`torch.linalg.matrix_exp`. + """ + ... +def matrix_power(input: Tensor, n: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + matrix_power(input, n, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.matrix_power` + """ + ... +@overload +def max(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. 
warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.max: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.max: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def max_pool1d_with_indices(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tuple[Tensor, Tensor]: ... +def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... 
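+# A minimal sketch of how the max/max_pool stubs above are typically called
+# (illustrative only, not emitted by the stub generator; results depend on the
+# random input and are shown schematically):
+#
+#     >>> import torch
+#     >>> x = torch.randn(4, 4)
+#     >>> torch.max(x)                      # first overload: reduce all elements to a 0-d tensor
+#     >>> torch.max(x, dim=1)               # returns torch.return_types.max(values, indices)
+#     >>> torch.max(x, torch.zeros(4, 4))   # element-wise form, equivalent to torch.maximum
+#     >>> torch.max_pool1d(torch.randn(1, 1, 16), kernel_size=2, stride=2).shape
+#     torch.Size([1, 1, 8])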
+def maximum(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + maximum(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) + """ + ... +@overload +def mean(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def mean(input: Tensor, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def median(input: Tensor) -> Tensor: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... 
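+# A minimal sketch of the "lower median" behaviour described in the docstring above
+# for inputs with an even number of elements (illustrative only, not emitted by the
+# stub generator):
+#
+#     >>> import torch
+#     >>> t = torch.tensor([1., 2., 3., 4.])
+#     >>> torch.median(t)          # lower of the two middle values
+#     tensor(2.)
+#     >>> torch.quantile(t, 0.5)   # mean of the two middle values, as the note suggests
+#     tensor(2.5000)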
+@overload +def median(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... +@overload +def median(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. 
warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... +@overload +def min(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. 
note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.min: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. 
function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.min: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... 
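+# A short sketch of the ``min`` overloads above, using values without ties so
+# the returned indices are unambiguous (illustrative comments only):
+#
+#   import torch
+#   x = torch.tensor([[1., 5.], [4., 2.]])
+#   torch.min(x)                       # tensor(1.)
+#   result = torch.min(x, dim=1)
+#   result.values                      # tensor([1., 2.])
+#   result.indices                     # tensor([0, 1])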
+def minimum(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + minimum(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) + """ + ... +def miopen_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, exponential_average_factor: _float, epsilon: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def miopen_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_convolution_add_relu(input: Tensor, weight: Tensor, z: Tensor, alpha: Optional[Union[Number, _complex]], bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def miopen_convolution_relu(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def miopen_convolution_transpose(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_depthwise_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_rnn(input: Tensor, weight: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, hx: Tensor, cx: Optional[Tensor], mode: _int, hidden_size: _int, num_layers: _int, batch_first: _bool, dropout: _float, train: _bool, bidirectional: _bool, batch_sizes: _size, dropout_state: Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def mkldnn_adaptive_avg_pool2d(input: Tensor, output_size: Union[_int, _size], *, out: Optional[Tensor] = None) -> Tensor: ... +def mkldnn_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def mkldnn_linear_backward_weights(grad_output: Tensor, input: Tensor, weight: Tensor, bias_defined: _bool) -> Tuple[Tensor, Tensor]: ... +def mkldnn_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... 
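+# ``minimum`` (declared above) is element-wise and, as its docstring notes,
+# returns NaN whenever either compared element is NaN. A deterministic sketch:
+#
+#   import torch
+#   a = torch.tensor([1., float('nan'), 3.])
+#   b = torch.tensor([2., 0., float('nan')])
+#   torch.minimum(a, b)                # tensor([1., nan, nan])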
+def mkldnn_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def mkldnn_rnn_layer(input: Tensor, weight0: Tensor, weight1: Tensor, weight2: Tensor, weight3: Tensor, hx_: Tensor, cx_: Tensor, reverse: _bool, batch_sizes: _size, mode: _int, hidden_size: _int, num_layers: _int, has_biases: _bool, bidirectional: _bool, batch_first: _bool, train: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def mm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mm(input, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided it's layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + ... +@overload +def mode(input: Tensor, dim: _int = -1, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor( + [[0, 0, 0, 2, 0, 0, 2], + [0, 3, 0, 0, 2, 0, 1], + [2, 2, 2, 0, 0, 0, 3], + [2, 2, 3, 0, 1, 1, 0], + [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + ... +@overload +def mode(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor( + [[0, 0, 0, 2, 0, 0, 2], + [0, 3, 0, 0, 2, 0, 1], + [2, 2, 2, 0, 0, 0, 3], + [2, 2, 3, 0, 1, 1, 0], + [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + ... +@overload +def moveaxis(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +@overload +def moveaxis(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... 
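+# ``moveaxis`` is an alias of ``movedim`` (declared next); both only permute
+# dimensions, so only the shape changes. A quick shape-level sketch:
+#
+#   import torch
+#   t = torch.zeros(3, 2, 1)
+#   torch.moveaxis(t, 1, 0).shape             # torch.Size([2, 3, 1])
+#   torch.movedim(t, (1, 2), (0, 1)).shape    # torch.Size([2, 1, 3])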
+@overload +def movedim(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +@overload +def movedim(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +def msort(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + msort(input, *, out=None) -> Tensor + + Sorts the elements of the :attr:`input` tensor along its first dimension + in ascending order by value. + + .. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) + """ + ... +def mul(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + mul(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by :attr:`other`. + + + .. 
math:: + \text{out}_i = \text{input}_i \times \text{other}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number) - the tensor or number to multiply input by. + + Keyword args: + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) + """ + ... +def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + + Returns a tensor where each row contains :attr:`num_samples` indices sampled + from the multinomial (a stricter definition would be multivariate, + refer to torch.distributions.multinomial.Multinomial for more details) + probability distribution located in the corresponding row + of tensor :attr:`input`. + + .. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + + Indices are ordered from left to right according to when each was sampled + (first samples are placed in first column). + + If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + + If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape + :math:`(m \times \text{num\_samples})`. + + If replacement is ``True``, samples are drawn with replacement. + + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + + .. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + + Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 4) # ERROR! + RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, + not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320 + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) + """ + ... +@overload +def multiply(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + ... 
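+# ``multiply`` is an alias for ``mul``, which broadcasts its arguments to a
+# common shape. A deterministic sketch of the broadcasting behaviour:
+#
+#   import torch
+#   col = torch.tensor([[1.], [2.]])          # shape (2, 1)
+#   row = torch.tensor([10., 20., 30.])       # shape (3,)
+#   torch.mul(col, row)                       # shape (2, 3):
+#   # tensor([[10., 20., 30.],
+#   #         [20., 40., 60.]])
+#   torch.multiply(col, 2)                    # tensor([[2.], [4.]])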
+@overload +def multiply(input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + ... +def mv(input: Tensor, vec: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mv(input, vec, *, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`input` and the vector + :attr:`vec`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) + """ + ... +def mvlgamma(input: Tensor, p: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mvlgamma(input, p, *, out=None) -> Tensor + + Alias for :func:`torch.special.multigammaln`. + """ + ... +def nan_to_num(input: Tensor, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + + Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` + with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. + By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the + greatest finite value representable by :attr:`input`'s dtype, and negative infinity + is replaced with the least finite value representable by :attr:`input`'s dtype. + + Args: + input (Tensor): the input tensor. + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + """ + ... +def nan_to_num_(input: Tensor, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None) -> Tensor: ... +def nanmean(input: Tensor, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes the mean of all `non-NaN` elements along the specified dimensions. + + This function is identical to :func:`torch.mean` when there are no `NaN` values + in the :attr:`input` tensor. 
In the presence of `NaN`, :func:`torch.mean` will + propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the + `NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + + Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) + """ + ... +@overload +def nanmedian(input: Tensor) -> Tensor: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
+ + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanmedian(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanmedian(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. 
+ + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanquantile(input: Tensor, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + ... 
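+# The NaN-ignoring reductions above differ from their plain counterparts only
+# in how NaN is treated. A deterministic sketch, assuming the default float32
+# dtype:
+#
+#   import torch
+#   t = torch.tensor([float('nan'), 1., 2.])
+#   torch.mean(t), torch.nanmean(t)            # (tensor(nan), tensor(1.5000))
+#   torch.median(t), torch.nanmedian(t)        # (tensor(nan), tensor(1.))
+#   torch.quantile(t, 0.5), torch.nanquantile(t, 0.5)
+#   # (tensor(nan), tensor(1.5000))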
+@overload +def nanquantile(input: Tensor, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + ... +def nansum(input: Tensor, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + nansum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + + .. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. + If :attr:`dim` is a list of dimensions, reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + 1.0 + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) 
+ >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) + """ + ... +@overload +def narrow(input: Tensor, dim: _int, start: Tensor, length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + ... +@overload +def narrow(input: Tensor, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + ... +def narrow_copy(input: Tensor, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: + r""" + narrow_copy(input, dim, start, length, *, out=None) -> Tensor + + Same as :meth:`Tensor.narrow` except this returns a copy rather + than shared storage. This is primarily for sparse tensors, which + do not have a shared-storage narrow method. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + + .. seealso:: + + :func:`torch.narrow` for a non copy variant + """ + ... +def native_batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +def native_channel_shuffle(input: Tensor, groups: Union[_int, SymInt]) -> Tensor: ... +def native_dropout(input: Tensor, p: _float, train: Optional[_bool]) -> Tuple[Tensor, Tensor]: ... +def native_group_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], N: Union[_int, SymInt], C: Union[_int, SymInt], HxW: Union[_int, SymInt], group: _int, eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def native_layer_norm(input: Tensor, normalized_shape: Sequence[Union[_int, SymInt]], weight: Optional[Tensor], bias: Optional[Tensor], eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def native_norm(input: Tensor, p: Optional[Union[Number, _complex]], dim: Union[_int, _size], keepdim: _bool, dtype: Optional[_dtype]) -> Tensor: ... +@overload +def native_norm(input: Tensor, p: Union[Number, _complex] = 2) -> Tensor: ... +@overload +def ne(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + ... +@overload +def ne(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + ... +def neg(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + neg(input, *, out=None) -> Tensor + + Returns a new tensor with the negative of the elements of :attr:`input`. + + .. math:: + \text{out} = -1 \times \text{input} + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) + """ + ... +def neg_(input: Tensor) -> Tensor: ... +def negative(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + negative(input, *, out=None) -> Tensor + + Alias for :func:`torch.neg` + """ + ... +def negative_(input: Tensor) -> Tensor: ... +def nextafter(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + nextafter(input, other, *, out=None) -> Tensor + + Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + """ + ... +@overload +def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. 
+ + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + ... +@overload +def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... 
[0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + ... +def nonzero_static(input: Tensor, *, size: _int, fill_value: _int = -1, out: Optional[Tensor] = None) -> Tensor: ... +def norm_except_dim(v: Tensor, pow: _int = 2, dim: _int = 0) -> Tensor: ... +@overload +def normal(mean: Tensor, std: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... 
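+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# the nonzero docstrings above note that ``x[x.nonzero(as_tuple=True)]`` gathers all
+# nonzero values via advanced indexing; a minimal example of that idiom:
+# >>> x = torch.tensor([[1, 0], [0, 2]])
+# >>> x.nonzero(as_tuple=True)
+# (tensor([0, 1]), tensor([0, 1]))
+# >>> x[x.nonzero(as_tuple=True)]
+# tensor([1, 2])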
+@overload +def normal(mean: Tensor, std: _float = 1, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def normal(mean: _float, std: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. 
+ + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def normal(mean: _float, std: _float, size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. 
note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def not_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + ... +@overload +def not_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + ... +@overload +def nuclear_norm(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def nuclear_norm(input: Tensor, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +def numel(self: Tensor) -> _int: + r""" + numel(input) -> int + + Returns the total number of elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + """ + ... 
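+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# per the note in the torch.normal docstrings above, when the shapes of ``mean`` and
+# ``std`` differ but their element counts match, the output takes the shape of ``mean``:
+# >>> torch.normal(mean=torch.zeros(2, 2), std=torch.ones(4)).shape
+# torch.Size([2, 2])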
+@overload +def ones(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... +@overload +def ones(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... 
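+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# the two torch.ones overloads above accept the size either as one sequence or as
+# separate ints; ``dtype``, ``device`` and the other options are keyword-only.
+# >>> torch.ones((2, 3)).shape == torch.ones(2, 3).shape
+# True
+# >>> torch.ones(2, dtype=torch.int64)
+# tensor([1, 1])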
+@overload +def ones(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... +@overload +def ones(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... 
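+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# the ``names`` overloads above build named tensors, a prototype feature that may emit
+# a warning on first use; a minimal example:
+# >>> torch.ones(2, 3, names=('N', 'C'))
+# tensor([[1., 1., 1.],
+#         [1., 1., 1.]], names=('N', 'C'))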
+def ones_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the same size as + :attr:`input`. ``torch.ones_like(input)`` is equivalent to + ``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + ... +def orgqr(input: Tensor, input2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + orgqr(input, tau) -> Tensor + + Alias for :func:`torch.linalg.householder_product`. + """ + ... +def ormqr(input: Tensor, input2: Tensor, input3: Tensor, left: _bool = True, transpose: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + + Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + + Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, + where `Q` is represented using Householder reflectors `(input, tau)`. + See `Representation of Orthogonal or Unitary Matrices`_ for further details. + + If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`. + When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`. + It has size :math:`n \times n` otherwise. + If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op. + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + + .. seealso:: + :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q` + from the QR decomposition. + + .. note:: + This function supports backward but it is only fast when ``(input, tau)`` do not require gradients + and/or ``tau.size(-1)`` is very small. 
+ `` + + Args: + input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions + and `mn` equals to `m` or `n` depending on the :attr:`left`. + tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions. + other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + left (bool): controls the order of multiplication. + transpose (bool): controls whether the matrix `Q` is conjugate transposed or not. + + Keyword args: + out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`. + + .. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html + """ + ... +def outer(input: Tensor, vec2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + outer(input, vec2, *, out=None) -> Tensor + + Outer product of :attr:`input` and :attr:`vec2`. + If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of + size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + + Keyword args: + out (Tensor, optional): optional output matrix + + Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) + """ + ... +def pairwise_distance(x1: Tensor, x2: Tensor, p: _float = 2, eps: _float = 1e-06, keepdim: _bool = False) -> Tensor: ... +def pdist(input: Tensor, p: _float = 2) -> Tensor: ... +def permute(input: Tensor, dims: _size) -> Tensor: + r""" + permute(input, dims) -> Tensor + + Returns a view of the original tensor :attr:`input` with its dimensions permuted. + + Args: + input (Tensor): the input tensor. + dims (tuple of int): The desired ordering of dimensions + + Example: + >>> x = torch.randn(2, 3, 5) + >>> x.size() + torch.Size([2, 3, 5]) + >>> torch.permute(x, (2, 0, 1)).size() + torch.Size([5, 2, 3]) + """ + ... +def permute_copy(input: Tensor, dims: _size, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.permute`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def pinverse(input: Tensor, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse(input, rcond=1e-15) -> Tensor + + Alias for :func:`torch.linalg.pinv` + """ + ... +def pixel_shuffle(input: Tensor, upscale_factor: _int) -> Tensor: ... +def pixel_unshuffle(input: Tensor, downscale_factor: _int) -> Tensor: ... +def poisson(input: Tensor, generator: Optional[Generator] = None) -> Tensor: + r""" + poisson(input, generator=None) -> Tensor + + Returns a tensor of the same size as :attr:`input` with each element + sampled from a Poisson distribution with rate parameter given by the corresponding + element in :attr:`input` i.e., + + .. math:: + \text{out}_i \sim \text{Poisson}(\text{input}_i) + + :attr:`input` must be non-negative. + + Args: + input (Tensor): the input tensor containing the rates of the Poisson distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + + Example:: + + >>> rates = torch.rand(4, 4) * 5 # rate parameter between 0 and 5 + >>> torch.poisson(rates) + tensor([[9., 1., 3., 5.], + [8., 6., 6., 0.], + [0., 4., 5., 3.], + [2., 1., 4., 2.]]) + """ + ... 
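+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# pixel_shuffle / pixel_unshuffle above carry no docstring here; they rearrange a
+# (N, C*r**2, H, W) tensor to (N, C, H*r, W*r) and back, e.g.:
+# >>> x = torch.randn(1, 9, 4, 4)
+# >>> torch.pixel_shuffle(x, 3).shape
+# torch.Size([1, 1, 12, 12])
+# >>> torch.pixel_unshuffle(torch.pixel_shuffle(x, 3), 3).shape
+# torch.Size([1, 9, 4, 4])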
+def poisson_nll_loss(input: Tensor, target: Tensor, log_input: _bool, full: _bool, eps: _float, reduction: _int) -> Tensor: ... +def polar(abs: Tensor, angle: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + polar(abs, angle, *, out=None) -> Tensor + + Constructs a complex tensor whose elements are Cartesian coordinates + corresponding to the polar coordinates with absolute value :attr:`abs` and angle + :attr:`angle`. + + .. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + + .. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + + + Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) + """ + ... +def polygamma(n: _int, input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + polygamma(n, input, *, out=None) -> Tensor + + Alias for :func:`torch.special.polygamma`. + """ + ... +def positive(input: Tensor) -> Tensor: + r""" + positive(input) -> Tensor + + Returns :attr:`input`. + Throws a runtime error if :attr:`input` is a bool tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.randn(5) + >>> t + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.positive(t) + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + """ + ... +@overload +def pow(input: Tensor, exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. 
+ The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +@overload +def pow(self: Union[Number, _complex], exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +@overload +def pow(input: Tensor, exponent: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) 
+ + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +def prelu(input: Tensor, weight: Tensor) -> Tensor: ... +@overload +def prod(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +@overload +def prod(input: Tensor, dim: _int, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. 
+ + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +@overload +def prod(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +def promote_types(type1: _dtype, type2: _dtype) -> _dtype: + r""" + promote_types(type1, type2) -> dtype + + Returns the :class:`torch.dtype` with the smallest size and scalar kind that is + not smaller nor of lower kind than either `type1` or `type2`. See type promotion + :ref:`documentation ` for more information on the type + promotion logic. + + Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + + Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long + """ + ... +def put(input: Tensor, index: Tensor, source: Tensor, accumulate: _bool = False) -> Tensor: ... +def q_per_channel_axis(input: Tensor) -> _int: ... 
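+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# the ``dtype`` keyword documented for torch.prod above lets the reduction accumulate
+# in a wider type, avoiding overflow of the input dtype:
+# >>> x = torch.full((40,), 100.0)          # float32
+# >>> torch.prod(x)                         # 100**40 overflows float32
+# tensor(inf)
+# >>> torch.prod(x, dtype=torch.float64)
+# tensor(1.0000e+80, dtype=torch.float64)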
+def q_per_channel_scales(input: Tensor) -> Tensor: ... +def q_per_channel_zero_points(input: Tensor) -> Tensor: ... +def q_scale(input: Tensor) -> _float: ... +def q_zero_point(input: Tensor) -> _int: ... +def qr(input: Tensor, some: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.qr: + r""" + qr(input, some=True, *, out=None) -> (Tensor, Tensor) + + Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, + and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` + with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and + :math:`R` being an upper triangular matrix or batch of upper triangular matrices. + + If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. + Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + + .. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + + .. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + + .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + + Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + + Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. + + Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True + """ + ... 
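+# Editor's sketch (a comment added for illustration, not part of the generated stub):
+# q_scale / q_zero_point above read back the quantization parameters of a per-tensor
+# quantized tensor, e.g.:
+# >>> qx = torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0]), 0.1, 10, torch.quint8)
+# >>> torch.q_scale(qx), torch.q_zero_point(qx)
+# (0.1, 10)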
+@overload +def quantile(input: Tensor, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + ... +@overload +def quantile(input: Tensor, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. 
If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + ... +def quantize_per_channel(input: Tensor, scales: Tensor, zero_points: Tensor, axis: _int, dtype: _dtype) -> Tensor: + r""" + quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + + Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + + Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor + + Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) + """ + ... +@overload +def quantize_per_tensor(input: Tensor, scale: Tensor, zero_point: Tensor, dtype: _dtype) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +@overload +def quantize_per_tensor(input: Tensor, scale: _float, zero_point: _int, dtype: _dtype) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
+ + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +@overload +def quantize_per_tensor(tensors: Union[Tuple[Tensor, ...], List[Tensor]], scales: Tensor, zero_points: Tensor, dtype: _dtype) -> Tuple[Tensor, ...]: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +def quantize_per_tensor_dynamic(input: Tensor, dtype: _dtype, reduce_range: _bool) -> Tensor: + r""" + quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + + Converts a float tensor to a quantized tensor with scale and zero_point calculated + dynamically based on the input. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+        Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``
+        reduce_range (bool): a flag indicating whether to reduce the range of quantized
+                             data by 1 bit; this is required to avoid instruction overflow on some hardware
+
+    Returns:
+        Tensor: A newly (dynamically) quantized tensor
+
+    Example::
+
+        >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False)
+        >>> print(t)
+        tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8,
+               quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941,
+               zero_point=85)
+        >>> t.int_repr()
+        tensor([ 0, 85, 170, 255], dtype=torch.uint8)
+    """
+    ...
+def quantized_batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], mean: Tensor, var: Tensor, eps: _float, output_scale: _float, output_zero_point: _int) -> Tensor:
+    r"""
+    quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor
+
+    Applies batch normalization on a 4D (NCHW) quantized tensor.
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Arguments:
+        input (Tensor): quantized tensor
+        weight (Tensor): float tensor that corresponds to the gamma, size C
+        bias (Tensor): float tensor that corresponds to the beta, size C
+        mean (Tensor): float mean value in batch normalization, size C
+        var (Tensor): float tensor for variance, size C
+        eps (float): a value added to the denominator for numerical stability.
+        output_scale (float): output quantized tensor scale
+        output_zero_point (int): output quantized tensor zero_point
+
+    Returns:
+        Tensor: A quantized tensor with batch normalization applied.
+
+    Example::
+
+        >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8)
+        >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2)
+        tensor([[[[-0.2000, -0.2000],
+                  [ 1.6000, -0.2000]],
+
+                 [[-0.4000, -0.4000],
+                  [-0.4000,  0.6000]]],
+
+
+                [[[-0.2000, -0.2000],
+                  [-0.2000, -0.2000]],
+
+                 [[ 0.6000, -0.4000],
+                  [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8,
+               quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2)
+    """
+    ...
+def quantized_gru_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ...
+def quantized_lstm_cell(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tuple[Tensor, Tensor]: ...
+def quantized_max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor:
+    r"""
+    quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor
+
+    Applies a 1D max pooling over an input quantized tensor composed of several input planes.
+ + Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool1d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + ... +def quantized_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: + r""" + quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 2D max pooling over an input quantized tensor composed of several input planes. + + Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool2d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + ... +def quantized_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def quantized_rnn_relu_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ... +def quantized_rnn_tanh_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ... +def rad2deg(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + rad2deg(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in radians to degrees. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + """ + ... +def rad2deg_(input: Tensor) -> Tensor: ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. 
+ Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. 
+ Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... 
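+# A minimal usage sketch of the `generator` keyword documented above (illustrative only,
+# not part of the generated stub; assumes a default torch build). Rewinding a seeded
+# torch.Generator replays the same uniform draw:
+#
+#     g = torch.Generator().manual_seed(0)
+#     x = torch.rand(2, 3, generator=g)          # uniform samples in [0, 1)
+#     g.manual_seed(0)                           # rewind the generator state
+#     assert torch.equal(x, torch.rand(2, 3, generator=g))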
+@overload +def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +def rand_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a uniform distribution on the interval :math:`[0, 1)`. + ``torch.rand_like(input)`` is equivalent to + ``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +@overload +def randint(low: _int, high: _int, size: _size, *, generator: Optional[Generator] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(high: _int, size: _size, *, generator: Optional[Generator] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. 
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... 
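+# A small sketch of the randint semantics documented above (illustrative only, not part of
+# the generated stub): values are drawn from the half-open interval [low, high), the default
+# dtype is torch.int64, and another integer dtype can be requested explicitly:
+#
+#     x = torch.randint(3, 10, (2, 2))                     # values in [3, 10)
+#     y = torch.randint(0, 256, (4,), dtype=torch.uint8)   # explicit dtype
+#     assert x.dtype == torch.int64 and y.dtype == torch.uint8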
+@overload +def randint(high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(low: Union[_int, SymInt], high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(low: Union[_int, SymInt], high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... 
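+# A sketch of the `out=` keyword documented above (illustrative only, not part of the
+# generated stub), using a preallocated torch.int64 buffer to match the default result dtype:
+#
+#     buf = torch.empty(2, 2, dtype=torch.int64)
+#     torch.randint(0, 5, (2, 2), out=buf)   # fills buf in place with values in [0, 5)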
+@overload +def randint_like(input: Tensor, high: Union[_int, SymInt], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +@overload +def randint_like(input: Tensor, low: Union[_int, SymInt], high: Union[_int, SymInt], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. 
Default: ``torch.preserve_format``. + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. 
sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. 
Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... 
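+# A quick numeric check of the complex case described above (illustrative only, not part of
+# the generated stub): for a complex dtype the real and imaginary parts each have variance
+# 1/2, so for a large sample their standard deviations land near sqrt(0.5) ~= 0.7071:
+#
+#     z = torch.randn(100_000, dtype=torch.complex64)
+#     z.real.std(), z.imag.std()     # both approximately 0.7071
+#     (z.abs() ** 2).mean()          # total variance, approximately 1.0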
+@overload +def randn(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. 
math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +def randn_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the + sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to + ``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +@overload +def randperm(n: Union[_int, SymInt], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + ... +@overload +def randperm(n: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + ... +def range(start: Number, end: Number, step: Number = 1, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` + with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is + the gap between two values in the tensor. + + .. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. + + .. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + + Args: + start (float): the starting value for the set of points. Default: ``0``. + end (float): the ending value for the set of points + step (float): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) + """ + ... +def ravel(input: Tensor) -> Tensor: + r""" + ravel(input) -> Tensor + + Return a contiguous flattened tensor. A copy is made only if needed. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + """ + ... +def real(input: Tensor) -> Tensor: + r""" + real(input) -> Tensor + + Returns a new tensor containing real values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + """ + ... +def reciprocal(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + reciprocal(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the elements of :attr:`input` + + .. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + + .. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) + """ + ... +def reciprocal_(input: Tensor) -> Tensor: ... +def relu(input: Tensor) -> Tensor: ... +def relu_(input: Tensor) -> Tensor: ... +@overload +def remainder(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. 
In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... +@overload +def remainder(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... +@overload +def remainder(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... 
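+# Illustrative sketch: torch.remainder follows the sign of the divisor (like
+# Python's ``%``), whereas torch.fmod follows the sign of the dividend (like
+# C++'s ``std::fmod``), as the note above explains.
+#
+#     >>> torch.remainder(torch.tensor([-3.0]), 2)
+#     tensor([1.])
+#     >>> torch.fmod(torch.tensor([-3.0]), 2)
+#     tensor([-1.])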
+def renorm(input: Tensor, p: Union[Number, _complex], dim: _int, maxnorm: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + + Returns a tensor where each sub-tensor of :attr:`input` along dimension + :attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower + than the value :attr:`maxnorm` + + .. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + + Args: + input (Tensor): the input tensor. + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) + """ + ... +@overload +def repeat_interleave(input: Tensor, repeats: Tensor, dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +@overload +def repeat_interleave(repeats: Tensor, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. 
warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +@overload +def repeat_interleave(input: Tensor, repeats: Union[_int, SymInt], dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. 
+ + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +def reshape(input: Tensor, shape: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + reshape(input, shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`input`, + but with the specified shape. When possible, the returned tensor will be a view + of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs + with compatible strides can be reshaped without copying, but you should not + depend on the copying vs. viewing behavior. + + See :meth:`torch.Tensor.view` on when it is possible to return a view. + + A single dimension may be -1, in which case it's inferred from the remaining + dimensions and the number of elements in :attr:`input`. + + Args: + input (Tensor): the tensor to be reshaped + shape (tuple of int): the new shape + + Example:: + + >>> a = torch.arange(4.) + >>> torch.reshape(a, (2, 2)) + tensor([[ 0., 1.], + [ 2., 3.]]) + >>> b = torch.tensor([[0, 1], [2, 3]]) + >>> torch.reshape(b, (-1,)) + tensor([ 0, 1, 2, 3]) + """ + ... +def resize_as_(input: Tensor, the_template: Tensor, *, memory_format: Optional[memory_format] = None) -> Tensor: ... +def resize_as_sparse_(input: Tensor, the_template: Tensor) -> Tensor: ... +def resolve_conj(input: Tensor) -> Tensor: + r""" + resolve_conj(input) -> Tensor + + Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False + """ + ... +def resolve_neg(input: Tensor) -> Tensor: + r""" + resolve_neg(input) -> Tensor + + Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False + """ + ... 
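+# Illustrative sketch: torch.reshape may return either a view or a copy, and
+# the docstring above advises not to rely on which. For a contiguous input a
+# view is expected, which can be observed by comparing storage pointers.
+#
+#     >>> a = torch.arange(6)
+#     >>> b = torch.reshape(a, (2, -1))        # -1 is inferred as 3
+#     >>> b.shape
+#     torch.Size([2, 3])
+#     >>> b.data_ptr() == a.data_ptr()         # shares storage here
+#     True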
+@overload +def result_type(tensor: Tensor, other: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(scalar: Union[Number, _complex], tensor: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(tensor: Tensor, other: Union[Number, _complex]) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(scalar1: Union[Number, _complex], scalar2: Union[Number, _complex]) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def rnn_relu(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def rnn_relu(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def rnn_relu_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... 
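+# Illustrative sketch for torch.result_type (declared above): the promoted
+# dtype can be computed up front, for example to pre-allocate an ``out``
+# tensor with the dtype that type promotion would produce.
+#
+#     >>> a = torch.tensor([1, 2], dtype=torch.int32)
+#     >>> b = torch.tensor([1.0, 2.0])
+#     >>> dt = torch.result_type(a, b)
+#     >>> dt
+#     torch.float32
+#     >>> torch.add(a, b, out=torch.empty(2, dtype=dt)).dtype
+#     torch.float32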
+@overload +def rnn_tanh(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def rnn_tanh(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def rnn_tanh_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... +def roll(input: Tensor, shifts: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]], dims: Union[_int, _size] = ()) -> Tensor: + r""" + roll(input, shifts, dims=None) -> Tensor + + Roll the tensor :attr:`input` along the given dimension(s). Elements that are + shifted beyond the last position are re-introduced at the first position. If + :attr:`dims` is `None`, the tensor will be flattened before rolling and then + restored to the original shape. + + Args: + input (Tensor): the input tensor. + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) + """ + ... +def rot90(input: Tensor, k: _int = 1, dims: _size = (0,1)) -> Tensor: + r""" + rot90(input, k=1, dims=[0,1]) -> Tensor + + Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + + Args: + input (Tensor): the input tensor. + k (int): number of times to rotate. Default value is 1 + dims (a list or tuple): axis to rotate. Default value is [0, 1] + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) + """ + ... +@overload +def round(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. 
seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + ... +@overload +def round(input: Tensor, *, decimals: _int, out: Optional[Tensor] = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + ... +@overload +def round_(input: Tensor) -> Tensor: ... +@overload +def round_(input: Tensor, *, decimals: _int) -> Tensor: ... +def row_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def row_stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + row_stack(tensors, *, out=None) -> Tensor + + Alias of :func:`torch.vstack`. + """ + ... +def rrelu(input: Tensor, lower: Union[Number, _complex] = 0.125, upper: Union[Number, _complex] = 0.3333333333333333, training: _bool = False, generator: Optional[Generator] = None) -> Tensor: ... 
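+# Illustrative sketch of the two torch.round behaviours documented above:
+# ties round to the nearest even value, and a negative ``decimals`` rounds to
+# the left of the decimal point.
+#
+#     >>> torch.round(torch.tensor([0.5, 1.5, 2.5]))
+#     tensor([0., 2., 2.])
+#     >>> torch.round(torch.tensor([1234.5678]), decimals=-2)
+#     tensor([1200.])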
+def rrelu_(input: Tensor, lower: Union[Number, _complex] = 0.125, upper: Union[Number, _complex] = 0.3333333333333333, training: _bool = False, generator: Optional[Generator] = None) -> Tensor: ... +def rsqrt(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + rsqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the square-root of each of + the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + >>> torch.rsqrt(a) + tensor([ nan, 1.8351, 0.8053, nan]) + """ + ... +def rsqrt_(input: Tensor) -> Tensor: ... +@overload +def rsub(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def rsub(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1, out: Optional[Tensor] = None) -> Tensor: ... +def scalar_tensor(s: Union[Number, _complex], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, reduce: str, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter_add(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... +@overload +def scatter_add(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... 
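+# Illustrative sketch: the out-of-place scalar-value overload of torch.scatter
+# declared above can build a one-hot encoding without in-place mutation.
+#
+#     >>> idx = torch.tensor([[1], [3], [0]])
+#     >>> torch.scatter(torch.zeros(3, 5), 1, idx, 1.0)
+#     tensor([[0., 1., 0., 0., 0.],
+#             [0., 0., 0., 1., 0.],
+#             [1., 0., 0., 0., 0.]])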
+def scatter_reduce(input: Tensor, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool = True, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + ... +@overload +def searchsorted(sorted_sequence: Tensor, input: Tensor, *, out_int32: _bool = False, right: _bool = False, side: Optional[str] = None, sorter: Optional[Tensor] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. 
+ sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + ... +@overload +def searchsorted(sorted_sequence: Tensor, self: Union[Number, _complex], *, out_int32: _bool = False, right: _bool = False, side: Optional[str] = None, sorter: Optional[Tensor] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. 
"left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + ... +def segment_reduce(data: Tensor, reduce: str, *, lengths: Optional[Tensor] = None, indices: Optional[Tensor] = None, offsets: Optional[Tensor] = None, axis: _int = 0, unsafe: _bool = False, initial: Optional[Union[Number, _complex]] = None) -> Tensor: ... +@overload +def select(input: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + ... +@overload +def select(input: Tensor, dim: Union[str, ellipsis, None], index: _int) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + ... +def select_copy(input: Tensor, dim: _int, index: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.select`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... 
+def select_scatter(input: Tensor, src: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select_scatter(input, src, dim, index) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into. + index (int): the index to select with + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.select(input, dim, index)`` + + Example:: + + >>> a = torch.zeros(2, 2) + >>> b = torch.ones(2) + >>> a.select_scatter(b, 0, 0) + tensor([[1., 1.], + [0., 0.]]) + """ + ... +def selu(input: Tensor) -> Tensor: ... +def selu_(input: Tensor) -> Tensor: ... +def set_flush_denormal(mode: _bool) -> _bool: + r""" + set_flush_denormal(mode) -> bool + + Disables denormal floating numbers on CPU. + + Returns ``True`` if your system supports flushing denormal numbers and it + successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal` + is supported on x86 architectures supporting SSE3 and AArch64 architecture. + + Args: + mode (bool): Controls whether to enable flush denormal mode or not + + Example:: + + >>> torch.set_flush_denormal(True) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor([ 0.], dtype=torch.float64) + >>> torch.set_flush_denormal(False) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor(9.88131e-324 * + [ 1.0000], dtype=torch.float64) + """ + ... +def set_num_interop_threads(num: _int) -> None: + r""" + set_num_interop_threads(int) + + Sets the number of threads used for interop parallelism + (e.g. in JIT interpreter) on CPU. + + .. warning:: + Can only be called once and before any inter-op parallel work + is started (e.g. JIT execution). + """ + ... +def set_num_threads(num: _int) -> None: + r""" + set_num_threads(int) + + Sets the number of threads used for intraop parallelism on CPU. + + .. warning:: + To ensure that the correct number of threads is used, set_num_threads + must be called before running eager, JIT or autograd code. + """ + ... +def sgn(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sgn(input, *, out=None) -> Tensor + + This function is an extension of torch.sign() to complex tensors. + It computes a new tensor whose elements have + the same angles as the corresponding elements of :attr:`input` and + absolute values (i.e. magnitudes) of one for complex tensors and + is equivalent to torch.sign() for non-complex tensors. + + .. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) + """ + ... +def sigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sigmoid(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expit`. + """ + ... +def sigmoid_(input: Tensor) -> Tensor: ... 
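+# Illustrative sketch for the thread-count setters above; the values 4 and 2
+# are arbitrary, and both calls must happen before any parallel work starts
+# (set_num_interop_threads may only be called once per process).
+#
+#     >>> torch.set_num_threads(4)
+#     >>> torch.set_num_interop_threads(2)
+#     >>> torch.get_num_threads()
+#     4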
+def sign(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sign(input, *, out=None) -> Tensor + + Returns a new tensor with the signs of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> a + tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) + >>> torch.sign(a) + tensor([ 1., -1., 0., 1.]) + """ + ... +def signbit(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + signbit(input, *, out=None) -> Tensor + + Tests if each element of :attr:`input` has its sign bit set or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + + .. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + """ + ... +def sin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sin(input, *, out=None) -> Tensor + + Returns a new tensor with the sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) + """ + ... +def sin_(input: Tensor) -> Tensor: ... +def sinc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sinc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.sinc`. + """ + ... +def sinc_(input: Tensor) -> Tensor: ... +def sinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sinh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic sine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + ... +def sinh_(input: Tensor) -> Tensor: ... +def slice_copy(input: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.slice`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def slice_inverse(input: Tensor, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1) -> Tensor: ... 
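+# Illustrative sketch relating ``slice_copy`` to ordinary Python slicing; ``x``
+# is a hypothetical example tensor, and the result has fresh storage rather
+# than aliasing the input.
+#
+#   >>> x = torch.arange(10)
+#   >>> torch.equal(torch.slice_copy(x, 0, 2, 8, 2), x[2:8:2])
+#   True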
+def slice_scatter(input: Tensor, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1, *, out: Optional[Tensor] = None) -> Tensor: + r""" + slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given + dimension. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into + start (Optional[int]): the start index of where to insert the slice + end (Optional[int]): the end index of where to insert the slice + step (int): the how many elements to skip in + + Example:: + + >>> a = torch.zeros(8, 8) + >>> b = torch.ones(2, 8) + >>> a.slice_scatter(b, start=6) + tensor([[0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1.]]) + + >>> b = torch.ones(8, 2) + >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2) + tensor([[0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.]]) + """ + ... +def slogdet(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.slogdet: + r""" + slogdet(input) -> (Tensor, Tensor) + + Alias for :func:`torch.linalg.slogdet` + """ + ... +def smm(input: Tensor, mat2: Tensor) -> Tensor: + r""" + smm(input, mat) -> Tensor + + Performs a matrix multiplication of the sparse matrix :attr:`input` + with the dense matrix :attr:`mat`. + + Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied + """ + ... +@overload +def softmax(input: Tensor, dim: _int, dtype: Optional[_dtype] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + ... +@overload +def softmax(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + ... +@overload +def sort(input: Tensor, *, stable: Optional[_bool], dim: _int = -1, descending: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. 
+ + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +@overload +def sort(input: Tensor, dim: _int = -1, descending: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +@overload +def sort(input: Tensor, *, stable: Optional[_bool], dim: Union[str, ellipsis, None], descending: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... 
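+# Illustrative sketch of the ``out=`` tuple form of :func:`torch.sort` described
+# above; ``vals`` and ``idx`` are hypothetical pre-allocated output buffers.
+#
+#   >>> x = torch.tensor([3., 1., 2.])
+#   >>> vals = torch.empty(3)
+#   >>> idx = torch.empty(3, dtype=torch.long)
+#   >>> _ = torch.sort(x, out=(vals, idx))
+#   >>> vals
+#   tensor([1., 2., 3.])
+#   >>> idx
+#   tensor([1, 2, 0])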
+@overload +def sort(input: Tensor, dim: Union[str, ellipsis, None], descending: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +def sparse_bsc_tensor(ccol_indices: Union[Tensor, List], row_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse + Column)) ` with specified 2-dimensional blocks at the + given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in BSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. 
+ + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) + """ + ... +def sparse_bsr_tensor(crow_indices: Union[Tensor, List], col_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) + ` with specified 2-dimensional blocks at the given + :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix + multiplication operations in BSR format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. 
+ + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) + """ + ... +def sparse_compressed_tensor(compressed_indices: Union[Tensor, List], plain_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, *, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, + CSC, BSR, or BSC - ` with specified values at + the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse + matrix multiplication operations in Compressed Sparse format are + typically faster than that for sparse tensors in COO format. Make you + have a look at :ref:`the note on the data type of the indices + `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. 
If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + ... +def sparse_coo_tensor(indices: Tensor, values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None, is_coalesced: Optional[_bool] = None) -> Tensor: + r""" + sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + + Constructs a :ref:`sparse tensor in COO(rdinate) format + ` with specified values at the given + :attr:`indices`. + + .. 
note:: + + This function returns an :ref:`uncoalesced tensor + ` when :attr:`is_coalesced` is + unspecified or ``None``. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + indices (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` + internally. The indices are the coordinates of the non-zero values in the matrix, and thus + should be two-dimensional where the first dimension is the number of tensor dimensions and + the second dimension is the number of non-zero values. + values (array_like): Initial values for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not + provided the size will be inferred as the minimum size big enough to hold all non-zero + elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if None, infers data type from :attr:`values`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + is_coalesced (bool, optional): When``True``, the caller is + responsible for providing tensor indices that correspond to a + coalesced tensor. If the :attr:`check_invariants` flag is + False, no error will be raised if the prerequisites are not + met and this will lead to silently incorrect results. To force + coalescion please use :meth:`coalesce` on the resulting + Tensor. + Default: None: except for trivial cases (e.g. nnz < 2) the + resulting Tensor has is_coalesced set to ``False```. + + Example:: + + >>> i = torch.tensor([[0, 1, 1], + ... [2, 0, 2]]) + >>> v = torch.tensor([3, 4, 5], dtype=torch.float32) + >>> torch.sparse_coo_tensor(i, v, [2, 4]) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 4), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v) # Shape inference + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v, [2, 4], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_coo) + + # Create an empty sparse tensor with the following invariants: + # 1. sparse_dim + dense_dim = len(SparseTensor.shape) + # 2. SparseTensor._indices().shape = (sparse_dim, nnz) + # 3. 
SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + + .. _torch.sparse: https://pytorch.org/docs/stable/sparse.html + """ + ... +def sparse_csc_tensor(ccol_indices: Union[Tensor, List], row_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) + ` with specified values at the given + :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in CSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. 
+ Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) + """ + ... +def sparse_csr_tensor(crow_indices: Union[Tensor, List], col_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified + values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations + in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look + at :ref:`the note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. 
+ + Example:: + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + ... +def split_copy(input: Tensor, split_size: Union[_int, SymInt], dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None: + r""" + Performs the same operation as :func:`torch.split`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def split_with_sizes(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... +def split_with_sizes_copy(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None: + r""" + Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def spmm(input: Tensor, mat2: Tensor) -> Tensor: ... +def sqrt(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the square-root of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) + """ + ... +def sqrt_(input: Tensor) -> Tensor: ... +def square(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + square(input, *, out=None) -> Tensor + + Returns a new tensor with the square of the elements of :attr:`input`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) + """ + ... +def square_(input: Tensor) -> Tensor: ... +@overload +def squeeze(input: Tensor) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. 
+ + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: _int) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: _size) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. 
+ + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: Union[str, ellipsis, None]) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def squeeze_copy(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def squeeze_copy(input: Tensor, dim: _size, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def sspaddmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. 
+ + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +@overload +def sspaddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +@overload +def sspaddmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +def stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + stack(tensors, dim=0, *, out=None) -> Tensor + + Concatenates a sequence of tensors along a new dimension. + + All tensors need to be of the same size. + + .. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. + + Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> x = torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + >>> x + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x.size() + torch.Size([2, 2, 3]) + >>> x = torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x = torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> x = torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + """ + ... +@overload +def std(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. 
math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... 
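+# Illustrative sketch of the ``correction`` keyword described in the
+# :func:`torch.std` docstrings above: ``correction=0`` gives the population
+# (biased) standard deviation, while the default ``correction=1`` applies
+# Bessel's correction. The input values are arbitrary.
+#
+#   >>> a = torch.tensor([1.0, 2.0, 3.0, 4.0])
+#   >>> torch.std(a, correction=1)
+#   tensor(1.2910)
+#   >>> torch.std(a, correction=0)
+#   tensor(1.1180)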
+@overload +def std(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. 
versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. 
math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, unbiased: _bool = True) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... 
[[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. 
+ + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def sub(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def sub(self: Tensor, alpha: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def sub(self: Tensor, alpha: Union[Number, _complex], other: Tensor, *, out: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. 
math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def subtract(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + ... +@overload +def subtract(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + ... +@overload +def sum(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +@overload +def sum(input: Tensor, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +@overload +def sum(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +def svd(input: Tensor, some: _bool = True, compute_uv: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.svd: + r""" + svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Computes the singular value decomposition of either a matrix or batch of + matrices :attr:`input`. The singular value decomposition is represented as a + namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. + where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, + and the conjugate transpose of `V` for complex inputs. + If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also + batched with the same batch dimensions as :attr:`input`. + + If :attr:`some` is `True` (default), the method returns the reduced singular + value decomposition. In this case, if the last two dimensions of :attr:`input` are + `m` and `n`, then the returned `U` and `V` matrices will contain only + `min(n, m)` orthonormal columns. + + If :attr:`compute_uv` is `False`, the returned `U` and `V` will be + zero-filled matrices of shape `(m, m)` and `(n, n)` + respectively, and the same device as :attr:`input`. The argument :attr:`some` + has no effect when :attr:`compute_uv` is `False`. + + Supports :attr:`input` of float, double, cfloat and cdouble data types. + The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will + always be real-valued, even if :attr:`input` is complex. + + .. warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. code:: python + + S = torch.linalg.svdvals(A) + + .. note:: Differences with :func:`torch.linalg.svd`: + + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that + default value for both is `True`, so the default behavior is + effectively the opposite. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns + `Vh`, that is, :math:`V^{\text{H}}`. + * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled + tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns + empty tensors. + + .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + + .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`. + + .. 
note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]` + and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors + can be arbitrary bases of the corresponding subspaces. + + .. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd` + (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously, + on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 + and later, and MAGMA's routine `gesdd` on earlier versions of CUDA. + + .. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will + be represented as a column-major matrix (i.e. Fortran-contiguous). + + .. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not + have zero nor repeated singular values. + + .. warning:: If the distance between any two singular values is close to zero, the gradients with respect to + `U` and `V` will be numerically unstable, as they depends on + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix + has small singular values, as these gradients also depend on `S^{-1}`. + + .. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique, + as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column. + The same happens when :attr:`input` has repeated singular values, where one may multiply + the columns of the spanning subspace in `U` and `V` by a rotation matrix + and `the resulting vectors will span the same subspace`_. + Different platforms, like NumPy, or inputs on different device types, + may produce different `U` and `V` tensors. + + Args: + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m, n)` matrices. + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently, the shape of returned `U` and `V`. Default: `True`. + compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`. + + Keyword args: + out (tuple, optional): the output tuple of tensors + + Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + + .. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) + """ + ... +def swapaxes(input: Tensor, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(input, axis0, axis1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. 
+ + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + ... +def swapdims(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(input, dim0, dim1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + ... +def sym_constrain_range(size: Union[Number, _complex], *, min: Optional[_int] = None, max: Optional[_int] = None) -> None: ... +def sym_constrain_range_for_size(size: Union[Number, _complex], *, min: Optional[_int] = None, max: Optional[_int] = None) -> None: ... +def t(input: Tensor) -> Tensor: + r""" + t(input) -> Tensor + + Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 + and 1. + + 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this + is equivalent to ``transpose(input, 0, 1)``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + + See also :func:`torch.transpose`. + """ + ... +def t_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.t`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def take(input: Tensor, index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + take(input, index) -> Tensor + + Returns a new tensor with the elements of :attr:`input` at the given indices. + The input tensor is treated as if it were viewed as a 1-D tensor. The result + takes the same shape as the indices. + + Args: + input (Tensor): the input tensor. + index (LongTensor): the indices into tensor + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) + """ + ... +def take_along_dim(input: Tensor, indices: Tensor, dim: Optional[_int] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + + Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + + If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. + + Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, + are designed to work with this function. See the examples below. + + .. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + + Args: + input (Tensor): the input tensor. + indices (tensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) + """ + ... +def tan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + tan(input, *, out=None) -> Tensor + + Returns a new tensor with the tangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) + """ + ... +def tan_(input: Tensor) -> Tensor: ... +def tanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + tanh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic tangent of the elements + of :attr:`input`. + + .. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) + """ + ... +def tanh_(input: Tensor) -> Tensor: ... +def tensor(data: Any, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + + .. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.clone().detach()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.clone().detach().requires_grad_(True)``. + + .. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + + Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... 
dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) + """ + ... +@overload +def tensor_split(input: Tensor, tensor_indices_or_sections: Tensor, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +@overload +def tensor_split(input: Tensor, sections: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. 
+ If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +@overload +def tensor_split(input: Tensor, indices: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +def threshold(input: Tensor, threshold: Union[Number, _complex], value: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +def threshold_(input: Tensor, threshold: Union[Number, _complex], value: Union[Number, _complex]) -> Tensor: ... +def tile(input: Tensor, dims: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + tile(input, dims) -> Tensor + + Constructs a tensor by repeating the elements of :attr:`input`. + The :attr:`dims` argument specifies the number of repetitions + in each dimension. + + If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then + ones are prepended to :attr:`dims` until all dimensions are specified. + For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` + is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + + Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` + specifies, then :attr:`input` is treated as if it were unsqueezed at + dimension zero until it has as many dimensions as :attr:`dims` specifies. + For example, if :attr:`input` has shape (4, 2) and :attr:`dims` + is (3, 3, 2, 2), then :attr:`input` is treated as if it had the + shape (1, 1, 4, 2). + + .. note:: + + This function is similar to NumPy's tile function. + + Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + """ + ... +def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: + r""" + topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + + Returns the :attr:`k` largest elements of the given :attr:`input` tensor along + a given dimension. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`largest` is ``False`` then the `k` smallest elements are returned. + + A namedtuple of `(values, indices)` is returned with the `values` and + `indices` of the largest `k` elements of each row of the `input` tensor in the + given dimension `dim`. + + The boolean option :attr:`sorted` if ``True``, will make sure that the returned + `k` elements are themselves sorted + + Args: + input (Tensor): the input tensor. 
+ k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) + """ + ... +def trace(input: Tensor) -> Tensor: + r""" + trace(input) -> Tensor + + Returns the sum of the elements of the diagonal of the input 2-D matrix. + + Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) + """ + ... +@overload +def transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + ... +@overload +def transpose(input: Tensor, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. 
The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + ... +def transpose_copy(input: Tensor, dim0: _int, dim1: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.transpose`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. 
+ The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + ... +@overload +def trapezoid(y: Tensor, *, dx: Union[Number, _complex] = 1, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. 
The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + ... +@overload +def trapz(y: Tensor, *, dx: _float = 1, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + ... +@overload +def trapz(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + ... +def triangular_solve(input: Tensor, A: Tensor, upper: _bool = True, transpose: _bool = False, unitriangular: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.triangular_solve: + r""" + triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + + Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` + and multiple right-hand sides :math:`b`. + + In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular + (or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. + + `torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are + batches of 2D matrices. If the inputs are batches, then returns + batched outputs `X` + + If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and + :attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, + the result may contain `NaN` s. + + Supports input of float, double, cfloat and cdouble data types. 
+
+    .. warning::
+
+        :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular`
+        and will be removed in a future PyTorch release.
+        :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a
+        copy of one of the inputs.
+
+        ``X = torch.triangular_solve(B, A).solution`` should be replaced with
+
+        .. code:: python
+
+            X = torch.linalg.solve_triangular(A, B)
+
+    Args:
+        b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
+                :math:`*` is zero or more batch dimensions
+        A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
+                where :math:`*` is zero or more batch dimensions
+        upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``.
+        transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``,
+                and `op(A) = A` if it is ``False``. Default: ``False``.
+        unitriangular (bool, optional): whether :math:`A` is unit triangular.
+            If True, the diagonal elements of :math:`A` are assumed to be
+            1 and not referenced from :math:`A`. Default: ``False``.
+
+    Keyword args:
+        out ((Tensor, Tensor), optional): tuple of two tensors to write
+            the output to. Ignored if `None`. Default: `None`.
+
+    Returns:
+        A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient`
+        is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b`
+        (or whatever variant of the system of equations, depending on the keyword arguments.)
+
+    Examples::
+
+        >>> A = torch.randn(2, 2).triu()
+        >>> A
+        tensor([[ 1.1527, -1.0753],
+                [ 0.0000,  0.7986]])
+        >>> b = torch.randn(2, 3)
+        >>> b
+        tensor([[-0.0210,  2.3513, -1.5492],
+                [ 1.5429,  0.7403, -1.0243]])
+        >>> torch.triangular_solve(b, A)
+        torch.return_types.triangular_solve(
+        solution=tensor([[ 1.7841,  2.9046, -2.5405],
+                [ 1.9320,  0.9270, -1.2826]]),
+        cloned_coefficient=tensor([[ 1.1527, -1.0753],
+                [ 0.0000,  0.7986]]))
+    """
+    ...
+def tril(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor:
+    r"""
+    tril(input, diagonal=0, *, out=None) -> Tensor
+
+    Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices
+    :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0.
+
+    The lower triangular part of the matrix is defined as the elements on and
+    below the diagonal.
+
+    The argument :attr:`diagonal` controls which diagonal to consider. If
+    :attr:`diagonal` = 0, all elements on and below the main diagonal are
+    retained. A positive value includes just as many diagonals above the main
+    diagonal, and similarly a negative value excludes just as many diagonals below
+    the main diagonal. The main diagonal is the set of indices
+    :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where
+    :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+
+    Args:
+        input (Tensor): the input tensor.
+        diagonal (int, optional): the diagonal to consider
+
+    Keyword args:
+        out (Tensor, optional): the output tensor.
+ + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) + """ + ... +def tril_indices(row: _int, col: _int, offset: _int = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the lower triangular part of a :attr:`row`-by- + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) + """ + ... 
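+# A minimal illustrative sketch of how the two functions above relate:
+# torch.tril_indices enumerates exactly the positions that torch.tril retains
+# for the same diagonal offset, so either can be used to build a lower-triangular mask.
+#
+#     >>> mask = torch.zeros(4, 3, dtype=torch.bool)
+#     >>> idx = torch.tril_indices(4, 3, offset=-1)
+#     >>> mask[idx[0], idx[1]] = True
+#     >>> torch.equal(mask, torch.tril(torch.ones(4, 3, dtype=torch.bool), diagonal=-1))
+#     True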
+def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, margin: _float = 1.0, p: _float = 2, eps: _float = 1e-06, swap: _bool = False, reduction: _int = 1) -> Tensor: ... +def triu(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + triu(input, diagonal=0, *, out=None) -> Tensor + + Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) + """ + ... +def triu_indices(row: _int, col: _int, offset: _int = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the upper triangular part of a :attr:`row` by + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and above the main diagonal are + retained. 
A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) + """ + ... +def true_divide(input: Union[Tensor, Number], other: Union[Tensor, Number], *, out: Optional[Tensor] = None) -> Tensor: + r""" + true_divide(dividend, divisor, *, out) -> Tensor + + Alias for :func:`torch.div` with ``rounding_mode=None``. + """ + ... +def trunc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + trunc(input, *, out=None) -> Tensor + + Returns a new tensor with the truncated integer values of + the elements of :attr:`input`. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) + """ + ... +def trunc_(input: Tensor) -> Tensor: ... +@overload +def unbind(input: Tensor, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + ... +@overload +def unbind(input: Tensor, dim: Union[str, ellipsis, None]) -> Tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + ... 
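+# A minimal illustrative sketch of torch.unbind documented above: the returned
+# slices, restacked along the same dimension with torch.stack, reproduce the input.
+#
+#     >>> t = torch.tensor([[1, 2, 3], [4, 5, 6]])
+#     >>> rows = torch.unbind(t, dim=0)
+#     >>> rows
+#     (tensor([1, 2, 3]), tensor([4, 5, 6]))
+#     >>> torch.equal(torch.stack(rows, dim=0), t)
+#     True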
+def unbind_copy(input: Tensor, dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None:
+    r"""
+    Performs the same operation as :func:`torch.unbind`, but all output tensors
+    are freshly created instead of aliasing the input.
+    """
+    ...
+@overload
+def unflatten(input: Tensor, dim: Union[str, ellipsis, None], sizes: Sequence[Union[_int, SymInt]], names: Sequence[Union[str, ellipsis, None]]) -> Tensor:
+    r"""
+    unflatten(input, dim, sizes) -> Tensor
+
+    Expands a dimension of the input tensor over multiple dimensions.
+
+    .. seealso::
+
+        :func:`torch.flatten` is the inverse of this function. It coalesces several dimensions into one.
+
+    Args:
+        input (Tensor): the input tensor.
+        dim (int): Dimension to be unflattened, specified as an index into
+            ``input.shape``.
+        sizes (Tuple[int]): New shape of the unflattened dimension.
+            One of its elements can be `-1` in which case the corresponding output
+            dimension is inferred. Otherwise, the product of ``sizes`` *must*
+            equal ``input.shape[dim]``.
+
+    Returns:
+        A View of input with the specified dimension unflattened.
+
+    Examples::
+        >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
+        torch.Size([3, 2, 2, 1])
+        >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
+        torch.Size([3, 2, 2, 1])
+        >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
+        torch.Size([5, 2, 2, 3, 1, 1, 3])
+    """
+    ...
+@overload
+def unflatten(input: Tensor, dim: _int, sizes: Sequence[Union[_int, SymInt]]) -> Tensor:
+    r"""
+    unflatten(input, dim, sizes) -> Tensor
+
+    Expands a dimension of the input tensor over multiple dimensions.
+
+    .. seealso::
+
+        :func:`torch.flatten` is the inverse of this function. It coalesces several dimensions into one.
+
+    Args:
+        input (Tensor): the input tensor.
+        dim (int): Dimension to be unflattened, specified as an index into
+            ``input.shape``.
+        sizes (Tuple[int]): New shape of the unflattened dimension.
+            One of its elements can be `-1` in which case the corresponding output
+            dimension is inferred. Otherwise, the product of ``sizes`` *must*
+            equal ``input.shape[dim]``.
+
+    Returns:
+        A View of input with the specified dimension unflattened.
+
+    Examples::
+        >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
+        torch.Size([3, 2, 2, 1])
+        >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
+        torch.Size([3, 2, 2, 1])
+        >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
+        torch.Size([5, 2, 2, 3, 1, 1, 3])
+    """
+    ...
+def unfold_copy(input: Tensor, dimension: _int, size: _int, step: _int, *, out: Optional[Tensor] = None) -> Tensor:
+    r"""
+    Performs the same operation as :func:`torch.unfold`, but all output tensors
+    are freshly created instead of aliasing the input.
+    """
+    ...
+def unique_dim(input: Tensor, dim: _int, sorted: _bool = True, return_inverse: _bool = False, return_counts: _bool = False) -> Tuple[Tensor, Tensor, Tensor]: ...
+def unsafe_chunk(input: Tensor, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]:
+    r"""
+    unsafe_chunk(input, chunks, dim=0) -> List of Tensors
+
+    Works like :func:`torch.chunk` but without enforcing the autograd restrictions
+    on inplace modification of the outputs.
+
+    .. warning::
+        This function is safe to use as long as only the input, or only the outputs
+        are modified inplace after calling this function. It is the user's
+        responsibility to ensure that is the case. If both the input and one or more
+        of the outputs are modified inplace, gradients computed by autograd will be
+        silently incorrect.
+ """ + ... +def unsafe_split(input: Tensor, split_size: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + + Works like :func:`torch.split` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + ... +def unsafe_split_with_sizes(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... +def unsqueeze(input: Tensor, dim: _int) -> Tensor: + r""" + unsqueeze(input, dim) -> Tensor + + Returns a new tensor with a dimension of size one inserted at the + specified position. + + The returned tensor shares the same underlying data with this tensor. + + A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` + can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` + applied at :attr:`dim` = ``dim + input.dim() + 1``. + + Args: + input (Tensor): the input tensor. + dim (int): the index at which to insert the singleton dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) + """ + ... +def unsqueeze_copy(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.unsqueeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.values`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def vander(x: Tensor, N: Optional[_int] = None, increasing: _bool = False) -> Tensor: + r""" + vander(x, N=None, increasing=False) -> Tensor + + Generates a Vandermonde matrix. + + The columns of the output matrix are elementwise powers of the input vector :math:`x^{(N-1)}, x^{(N-2)}, ..., x^0`. + If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{(N-1)}`. Such a + matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + + Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + + Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{(N-1)}`, + the second :math:`x^{(N-2)}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{(N-1)}`. + + Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + """ + ... 
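+# A minimal illustrative sketch of torch.vander documented above: with
+# increasing=True, column j of the output is the elementwise power x**j, so the
+# matrix can be rebuilt by stacking those powers as columns.
+#
+#     >>> x = torch.tensor([1, 2, 3])
+#     >>> manual = torch.stack([x**0, x**1, x**2], dim=1)
+#     >>> torch.equal(manual, torch.vander(x, N=3, increasing=True))
+#     True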
+@overload +def var(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. 
versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. 
+ + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... 
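+# A minimal illustrative sketch of the variance formula documented above: the
+# divisor is N - correction, so correction=1 (the default, Bessel's correction)
+# divides by N - 1 and correction=0 divides by N.
+#
+#     >>> x = torch.tensor([1.0, 2.0, 4.0])
+#     >>> torch.var(x, correction=1)
+#     tensor(2.3333)
+#     >>> ((x - x.mean()) ** 2).sum() / (x.numel() - 1)
+#     tensor(2.3333)
+#     >>> torch.var(x, correction=0)
+#     tensor(1.5556)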
+@overload +def var_mean(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, unbiased: _bool = True) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. 
+ :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +def vdot(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + vdot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D vectors along a dimension. + + In symbols, this function computes + + .. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + + where :math:`\overline{x_i}` denotes the conjugate for complex + vectors, and it is the identity for real vectors. + + .. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + .. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + + .. note:: out (Tensor, optional): the output tensor. + + + Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) + """ + ... +def view_as_complex(input: Tensor) -> Tensor: + r""" + view_as_complex(input) -> Tensor + + Returns a view of :attr:`input` as a complex tensor. For an input complex + tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a + new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last + dimension of the input tensor is expected to represent the real and imaginary + components of complex numbers. + + .. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) + """ + ... +def view_as_complex_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_complex`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def view_as_real(input: Tensor) -> Tensor: + r""" + view_as_real(input) -> Tensor + + Returns a view of :attr:`input` as a real tensor. For an input complex tensor of + :attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new + real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 + represents the real and imaginary components of complex numbers. + + .. 
warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) + """ + ... +def view_as_real_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_real`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def view_copy(input: Tensor, dtype: _dtype, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def view_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def vsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + ... +@overload +def vsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. 
+ + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + ... +def vstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + vstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence vertically (row wise). + + This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + """ + ... +@overload +def where(condition: Tensor) -> Tuple[Tensor, ...]: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. 
math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... 
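+# A minimal illustrative sketch of the two torch.where forms documented above:
+# the three-argument form selects elementwise between two values, while the
+# single-argument form is equivalent to torch.nonzero(condition, as_tuple=True).
+#
+#     >>> cond = torch.tensor([True, False, True])
+#     >>> out = torch.where(cond, torch.tensor([1, 2, 3]), torch.tensor([10, 20, 30]))
+#     >>> torch.equal(out, torch.tensor([1, 20, 3]))
+#     True
+#     >>> torch.where(cond)
+#     (tensor([0, 2]),)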
+@overload +def where(condition: Tensor, input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, self: Union[Number, _complex], other: Union[Number, _complex]) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) 
+ tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def xlogy(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy(self: Union[Number, _complex], other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def xlogy_(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +def zero_(input: Tensor) -> Tensor: ... +@overload +def zeros(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +def zeros_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the same size as + :attr:`input`. ``torch.zeros_like(input)`` is equivalent to + ``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... 
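
A small doctest-style sketch tying the ``zeros`` overloads above together with ``zeros_like``: the variadic and sequence forms of ``zeros`` are interchangeable, and ``zeros_like`` takes its dtype and device from the input unless they are overridden. Nothing here depends on random state:

    >>> import torch
    >>> torch.equal(torch.zeros(2, 3), torch.zeros((2, 3)))    # variadic ints vs. a single size sequence
    True
    >>> x = torch.empty(2, 3, dtype=torch.float64)
    >>> torch.zeros_like(x).dtype                              # inherited from the input tensor
    torch.float64
    >>> torch.zeros_like(x, dtype=torch.int32).dtype           # explicitly overridden
    torch.int32
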
+ +__all__ = ['__and__', '__lshift__', '__or__', '__rshift__', '__xor__', '_adaptive_avg_pool2d', + '_adaptive_avg_pool3d', '_add_batch_dim', '_add_relu', '_add_relu_', '_addmm_activation', + '_aminmax', '_amp_foreach_non_finite_check_and_unscale_', '_amp_update_scale_', '_assert_async', + '_assert_scalar', '_assert_tensor_metadata', '_batch_norm_impl_index', '_cast_Byte', '_cast_Char', + '_cast_Double', '_cast_Float', '_cast_Half', '_cast_Int', '_cast_Long', '_cast_Short', + '_choose_qparams_per_tensor', '_chunk_cat', '_coalesce', '_compute_linear_combination', '_conj', + '_conj_copy', '_conj_physical', '_convert_indices_from_coo_to_csr', + '_convert_indices_from_csr_to_coo', '_convert_weight_to_int4pack', '_convolution', + '_convolution_mode', '_copy_from', '_copy_from_and_resize', '_cslt_compress', '_cslt_sparse_mm', + '_cslt_sparse_mm_search', '_ctc_loss', '_cudnn_ctc_loss', '_cudnn_init_dropout_state', + '_cudnn_rnn', '_cudnn_rnn_flatten_weight', '_cufft_clear_plan_cache', + '_cufft_get_plan_cache_max_size', '_cufft_get_plan_cache_size', '_cufft_set_plan_cache_max_size', + '_cummax_helper', '_cummin_helper', '_debug_has_internal_overlap', '_dim_arange', + '_dirichlet_grad', '_disable_functionalization', '_efficientzerotensor', '_embedding_bag', + '_embedding_bag_forward_only', '_empty_affine_quantized', '_empty_per_channel_affine_quantized', + '_enable_functionalization', '_euclidean_dist', '_fake_quantize_learnable_per_channel_affine', + '_fake_quantize_learnable_per_tensor_affine', + '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', + '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', '_fft_c2c', '_fft_c2r', '_fft_r2c', + '_fill_mem_eff_dropout_mask_', '_foobar', '_foreach_abs', '_foreach_abs_', '_foreach_acos', + '_foreach_acos_', '_foreach_add', '_foreach_add_', '_foreach_addcdiv', '_foreach_addcdiv_', + '_foreach_addcmul', '_foreach_addcmul_', '_foreach_asin', '_foreach_asin_', '_foreach_atan', + '_foreach_atan_', '_foreach_ceil', '_foreach_ceil_', '_foreach_clamp_max', '_foreach_clamp_max_', + '_foreach_clamp_min', '_foreach_clamp_min_', '_foreach_copy_', '_foreach_cos', '_foreach_cos_', + '_foreach_cosh', '_foreach_cosh_', '_foreach_div', '_foreach_div_', '_foreach_erf', + '_foreach_erf_', '_foreach_erfc', '_foreach_erfc_', '_foreach_exp', '_foreach_exp_', + '_foreach_expm1', '_foreach_expm1_', '_foreach_floor', '_foreach_floor_', '_foreach_frac', + '_foreach_frac_', '_foreach_lerp', '_foreach_lerp_', '_foreach_lgamma', '_foreach_lgamma_', + '_foreach_log', '_foreach_log10', '_foreach_log10_', '_foreach_log1p', '_foreach_log1p_', + '_foreach_log2', '_foreach_log2_', '_foreach_log_', '_foreach_maximum', '_foreach_maximum_', + '_foreach_minimum', '_foreach_minimum_', '_foreach_mul', '_foreach_mul_', '_foreach_neg', + '_foreach_neg_', '_foreach_norm', '_foreach_pow', '_foreach_pow_', '_foreach_reciprocal', + '_foreach_reciprocal_', '_foreach_round', '_foreach_round_', '_foreach_sigmoid', + '_foreach_sigmoid_', '_foreach_sign', '_foreach_sign_', '_foreach_sin', '_foreach_sin_', + '_foreach_sinh', '_foreach_sinh_', '_foreach_sqrt', '_foreach_sqrt_', '_foreach_sub', + '_foreach_sub_', '_foreach_tan', '_foreach_tan_', '_foreach_tanh', '_foreach_tanh_', + '_foreach_trunc', '_foreach_trunc_', '_foreach_zero_', '_from_functional_tensor', + '_functional_assert_async', '_functional_assert_scalar', '_functional_sym_constrain_range', + '_functional_sym_constrain_range_for_size', + '_functionalize_are_all_mutations_hidden_from_autograd', + 
'_functionalize_are_all_mutations_under_no_grad_or_inference_mode', '_functionalize_commit_update', + '_functionalize_mark_mutation_hidden_from_autograd', '_functionalize_replace', + '_functionalize_sync', '_fused_adam_', '_fused_adamw_', '_fused_dropout', + '_fused_moving_avg_obs_fq_helper', '_fused_moving_avg_obs_fq_helper', '_fused_sdp_choice', + '_fused_sgd_', '_fw_primal_copy', '_grid_sampler_2d_cpu_fallback', + '_has_compatible_shallow_copy_type', '_histogramdd_bin_edges', '_histogramdd_from_bin_cts', + '_histogramdd_from_bin_tensors', '_index_put_impl_', '_indices_copy', '_int_mm', '_is_all_true', + '_is_any_true', '_is_functional_tensor', '_is_zerotensor', '_lazy_clone', '_linalg_check_errors', + '_linalg_det', '_linalg_det', '_linalg_eigh', '_linalg_eigh', '_linalg_slogdet', '_linalg_slogdet', + '_linalg_solve_ex', '_linalg_solve_ex', '_linalg_svd', '_linalg_svd', '_log_softmax', + '_log_softmax_backward_data', '_logcumsumexp', '_lstm_mps', '_lu_with_info', '_lu_with_info', + '_make_dep_token', '_make_dual', '_make_dual_copy', '_make_per_channel_quantized_tensor', + '_make_per_tensor_quantized_tensor', '_masked_scale', '_masked_softmax', '_mixed_dtypes_linear', + '_mkldnn_reshape', '_mkldnn_transpose', '_mkldnn_transpose_', '_mps_convolution', + '_mps_convolution_transpose', '_native_batch_norm_legit', '_native_batch_norm_legit_no_training', + '_native_multi_head_attention', '_neg_view', '_neg_view_copy', '_nested_from_padded', + '_nested_from_padded_and_nested_example', '_nested_get_jagged_dummy', '_nested_get_lengths', + '_nested_get_offsets', '_nested_get_ragged_idx', '_nested_get_values', '_nested_get_values_copy', + '_nested_tensor_from_mask', '_nested_tensor_from_mask_left_aligned', + '_nested_tensor_from_tensor_list', '_nested_tensor_softmax_with_shape', '_nested_view_from_buffer', + '_nested_view_from_buffer_copy', '_nested_view_from_jagged', '_nested_view_from_jagged_copy', + '_nnpack_available', '_nnpack_spatial_convolution', '_pack_padded_sequence', + '_pad_packed_sequence', '_pin_memory', '_prelu_kernel', '_print', '_propagate_xla_data', + '_remove_batch_dim', '_reshape_alias_copy', '_reshape_from_tensor', '_resize_output_', + '_rowwise_prune', '_sample_dirichlet', '_saturate_weight_to_fp16', + '_scaled_dot_product_attention_math', '_scaled_dot_product_cudnn_attention', + '_scaled_dot_product_cudnn_attention', '_scaled_dot_product_efficient_attention', + '_scaled_dot_product_efficient_attention', '_scaled_dot_product_flash_attention', + '_scaled_dot_product_flash_attention', '_scaled_dot_product_flash_attention_for_cpu', + '_scaled_dot_product_flash_attention_for_cpu', '_scaled_mm', '_shape_as_tensor', + '_sobol_engine_draw', '_sobol_engine_ff_', '_sobol_engine_initialize_state_', + '_sobol_engine_scramble_', '_softmax', '_softmax_backward_data', '_sparse_broadcast_to', + '_sparse_broadcast_to_copy', '_sparse_csr_prod', '_sparse_csr_sum', + '_sparse_log_softmax_backward_data', '_sparse_semi_structured_linear', + '_sparse_softmax_backward_data', '_sparse_sparse_matmul', '_sparse_sum', '_stack', + '_standard_gamma', '_standard_gamma_grad', '_sync', '_test_autograd_multiple_dispatch', + '_test_autograd_multiple_dispatch_view', '_test_autograd_multiple_dispatch_view_copy', + '_test_check_tensor', '_test_functorch_fallback', '_test_parallel_materialize', + '_test_serialization_subcmul', '_to_cpu', '_to_functional_tensor', '_to_sparse_semi_structured', + '_transform_bias_rescale_qkv', '_transformer_encoder_layer_fwd', '_trilinear', + '_triton_multi_head_attention', 
'_triton_scaled_dot_attention', '_unique', '_unique2', + '_unpack_dual', '_unpack_dual', '_unsafe_index', '_unsafe_index_put', '_use_cudnn_ctc_loss', + '_use_cudnn_rnn_flatten_weight', '_validate_compressed_sparse_indices', + '_validate_sparse_bsc_tensor_args', '_validate_sparse_bsr_tensor_args', + '_validate_sparse_compressed_tensor_args', '_validate_sparse_coo_tensor_args', + '_validate_sparse_csc_tensor_args', '_validate_sparse_csr_tensor_args', '_values_copy', + '_weight_int4pack_mm', '_weight_int8pack_mm', '_weight_norm', '_weight_norm_interface', 'abs', + 'abs_', 'absolute', 'acos', 'acos_', 'acosh', 'acosh_', 'adaptive_avg_pool1d', + 'adaptive_max_pool1d', 'add', 'addbmm', 'addcdiv', 'addcmul', 'addmm', 'addmv', 'addmv_', 'addr', + 'adjoint', 'affine_grid_generator', 'alias_copy', 'all', 'allclose', 'alpha_dropout', + 'alpha_dropout_', 'amax', 'amin', 'aminmax', 'aminmax', 'angle', 'any', 'arange', 'arccos', + 'arccos_', 'arccosh', 'arccosh_', 'arcsin', 'arcsin_', 'arcsinh', 'arcsinh_', 'arctan', 'arctan2', + 'arctan_', 'arctanh', 'arctanh_', 'argmax', 'argmin', 'argsort', 'argwhere', 'as_strided', + 'as_strided_', 'as_strided_copy', 'as_strided_scatter', 'as_tensor', 'asarray', 'asin', 'asin_', + 'asinh', 'asinh_', 'atan', 'atan2', 'atan_', 'atanh', 'atanh_', 'avg_pool1d', 'baddbmm', + 'bartlett_window', 'batch_norm', 'batch_norm_backward_elemt', 'batch_norm_backward_reduce', + 'batch_norm_elemt', 'batch_norm_gather_stats', 'batch_norm_gather_stats_with_counts', + 'batch_norm_stats', 'batch_norm_update_stats', 'bernoulli', 'bilinear', + 'binary_cross_entropy_with_logits', 'bincount', 'binomial', 'bitwise_and', 'bitwise_left_shift', + 'bitwise_not', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'blackman_window', 'bmm', + 'broadcast_to', 'bucketize', 'can_cast', 'cat', 'ccol_indices_copy', 'ceil', 'ceil_', 'celu', + 'celu_', 'channel_shuffle', 'cholesky', 'cholesky_inverse', 'cholesky_solve', + 'choose_qparams_optimized', 'chunk', 'clamp', 'clamp_', 'clamp_max', 'clamp_max_', 'clamp_min', + 'clamp_min_', 'clip', 'clip_', 'clone', 'col_indices_copy', 'column_stack', 'combinations', + 'complex', 'concat', 'concatenate', 'conj', 'conj_physical', 'conj_physical_', 'constant_pad_nd', + 'conv1d', 'conv2d', 'conv3d', 'conv_tbc', 'conv_transpose1d', 'conv_transpose2d', + 'conv_transpose3d', 'convolution', 'copysign', 'corrcoef', 'cos', 'cos_', 'cosh', 'cosh_', + 'cosine_embedding_loss', 'cosine_similarity', 'count_nonzero', 'cov', 'cross', 'crow_indices_copy', + 'ctc_loss', 'cudnn_affine_grid_generator', 'cudnn_batch_norm', 'cudnn_convolution', + 'cudnn_convolution_add_relu', 'cudnn_convolution_relu', 'cudnn_convolution_transpose', + 'cudnn_grid_sampler', 'cudnn_is_acceptable', 'cummax', 'cummax', 'cummin', 'cummin', 'cumprod', + 'cumsum', 'cumulative_trapezoid', 'deg2rad', 'deg2rad_', 'dequantize', 'det', 'detach', 'detach_', + 'detach_copy', 'diag', 'diag_embed', 'diagflat', 'diagonal', 'diagonal_copy', 'diagonal_scatter', + 'diff', 'digamma', 'dist', 'div', 'divide', 'dot', 'dropout', 'dropout_', 'dsmm', 'dsplit', + 'dstack', 'embedding', 'embedding_bag', 'embedding_renorm_', 'empty', 'empty_like', + 'empty_permuted', 'empty_quantized', 'empty_strided', 'eq', 'equal', 'erf', 'erf_', 'erfc', + 'erfc_', 'erfinv', 'exp', 'exp2', 'exp2_', 'exp_', 'expand_copy', 'expm1', 'expm1_', 'eye', + 'fake_quantize_per_channel_affine', 'fake_quantize_per_tensor_affine', 'fbgemm_linear_fp16_weight', + 'fbgemm_linear_fp16_weight_fp32_activation', 'fbgemm_linear_int8_weight', + 
'fbgemm_linear_int8_weight_fp32_activation', 'fbgemm_linear_quantize_weight', + 'fbgemm_pack_gemm_matrix_fp16', 'fbgemm_pack_quantized_matrix', 'feature_alpha_dropout', + 'feature_alpha_dropout_', 'feature_dropout', 'feature_dropout_', 'fill', 'fill_', 'fix', 'fix_', + 'flatten', 'flip', 'fliplr', 'flipud', 'float_power', 'floor', 'floor_', 'floor_divide', 'fmax', + 'fmin', 'fmod', 'frac', 'frac_', 'frexp', 'frexp', 'frobenius_norm', 'from_file', 'from_numpy', + 'frombuffer', 'full', 'full_like', 'fused_moving_avg_obs_fake_quant', 'gather', 'gcd', 'gcd_', + 'ge', 'geqrf', 'geqrf', 'ger', 'get_default_dtype', 'get_num_interop_threads', 'get_num_threads', + 'gradient', 'greater', 'greater_equal', 'grid_sampler', 'grid_sampler_2d', 'grid_sampler_3d', + 'group_norm', 'gru', 'gru_cell', 'gt', 'hamming_window', 'hann_window', 'hardshrink', 'heaviside', + 'hinge_embedding_loss', 'histc', 'histogram', 'histogram', 'histogramdd', 'histogramdd', 'hsmm', + 'hsplit', 'hspmm', 'hstack', 'hypot', 'i0', 'i0_', 'igamma', 'igammac', 'imag', 'index_add', + 'index_copy', 'index_fill', 'index_put', 'index_put_', 'index_reduce', 'index_select', + 'indices_copy', 'init_num_threads', 'inner', 'instance_norm', 'int_repr', 'inverse', 'is_complex', + 'is_conj', 'is_distributed', 'is_floating_point', 'is_grad_enabled', 'is_inference', + 'is_inference_mode_enabled', 'is_neg', 'is_nonzero', 'is_same_size', 'is_signed', + 'is_vulkan_available', 'isclose', 'isfinite', 'isin', 'isinf', 'isnan', 'isneginf', 'isposinf', + 'isreal', 'istft', 'kaiser_window', 'kl_div', 'kron', 'kthvalue', 'kthvalue', 'layer_norm', 'lcm', + 'lcm_', 'ldexp', 'ldexp_', 'le', 'lerp', 'less', 'less_equal', 'lgamma', 'linspace', 'log', + 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_softmax', 'logaddexp', + 'logaddexp2', 'logcumsumexp', 'logdet', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', + 'logit', 'logit_', 'logspace', 'logsumexp', 'lstm', 'lstm_cell', 'lt', 'lu_solve', 'lu_unpack', + 'lu_unpack', 'margin_ranking_loss', 'masked_fill', 'masked_scatter', 'masked_select', 'matmul', + 'matrix_exp', 'matrix_power', 'max', 'max', 'max_pool1d', 'max_pool1d_with_indices', 'max_pool2d', + 'max_pool3d', 'maximum', 'mean', 'median', 'median', 'min', 'min', 'minimum', 'miopen_batch_norm', + 'miopen_convolution', 'miopen_convolution_add_relu', 'miopen_convolution_relu', + 'miopen_convolution_transpose', 'miopen_depthwise_convolution', 'miopen_rnn', + 'mkldnn_adaptive_avg_pool2d', 'mkldnn_convolution', 'mkldnn_linear_backward_weights', + 'mkldnn_max_pool2d', 'mkldnn_max_pool3d', 'mkldnn_rnn_layer', 'mm', 'mode', 'mode', 'moveaxis', + 'movedim', 'msort', 'mul', 'multinomial', 'multiply', 'mv', 'mvlgamma', 'nan_to_num', + 'nan_to_num_', 'nanmean', 'nanmedian', 'nanmedian', 'nanquantile', 'nansum', 'narrow', + 'narrow_copy', 'native_batch_norm', 'native_channel_shuffle', 'native_dropout', + 'native_group_norm', 'native_layer_norm', 'native_norm', 'ne', 'neg', 'neg_', 'negative', + 'negative_', 'nextafter', 'nonzero', 'nonzero_static', 'norm_except_dim', 'normal', 'not_equal', + 'nuclear_norm', 'numel', 'ones', 'ones_like', 'orgqr', 'ormqr', 'outer', 'pairwise_distance', + 'pdist', 'permute', 'permute_copy', 'pinverse', 'pixel_shuffle', 'pixel_unshuffle', 'poisson', + 'poisson_nll_loss', 'polar', 'polygamma', 'positive', 'pow', 'prelu', 'prod', 'promote_types', + 'put', 'q_per_channel_axis', 'q_per_channel_scales', 'q_per_channel_zero_points', 'q_scale', + 'q_zero_point', 'qr', 'qr', 'quantile', 'quantize_per_channel', 
'quantize_per_tensor', + 'quantize_per_tensor_dynamic', 'quantized_batch_norm', 'quantized_gru_cell', 'quantized_lstm_cell', + 'quantized_max_pool1d', 'quantized_max_pool2d', 'quantized_max_pool3d', 'quantized_rnn_relu_cell', + 'quantized_rnn_tanh_cell', 'rad2deg', 'rad2deg_', 'rand', 'rand_like', 'randint', 'randint_like', + 'randn', 'randn_like', 'randperm', 'range', 'ravel', 'real', 'reciprocal', 'reciprocal_', 'relu', + 'relu_', 'remainder', 'renorm', 'repeat_interleave', 'reshape', 'resize_as_', 'resize_as_sparse_', + 'resolve_conj', 'resolve_neg', 'result_type', 'rnn_relu', 'rnn_relu_cell', 'rnn_tanh', + 'rnn_tanh_cell', 'roll', 'rot90', 'round', 'round_', 'row_indices_copy', 'row_stack', 'rrelu', + 'rrelu_', 'rsqrt', 'rsqrt_', 'rsub', 'saddmm', 'scalar_tensor', 'scatter', 'scatter_add', + 'scatter_reduce', 'searchsorted', 'segment_reduce', 'select', 'select_copy', 'select_scatter', + 'selu', 'selu_', 'set_flush_denormal', 'set_num_interop_threads', 'set_num_threads', 'sgn', + 'sigmoid', 'sigmoid_', 'sign', 'signbit', 'sin', 'sin_', 'sinc', 'sinc_', 'sinh', 'sinh_', + 'slice_copy', 'slice_inverse', 'slice_scatter', 'slogdet', 'slogdet', 'smm', 'softmax', 'sort', + 'sort', 'sparse_bsc_tensor', 'sparse_bsr_tensor', 'sparse_compressed_tensor', 'sparse_coo_tensor', + 'sparse_csc_tensor', 'sparse_csr_tensor', 'split_copy', 'split_with_sizes', + 'split_with_sizes_copy', 'spmm', 'sqrt', 'sqrt_', 'square', 'square_', 'squeeze', 'squeeze_copy', + 'sspaddmm', 'stack', 'std', 'std_mean', 'sub', 'subtract', 'sum', 'svd', 'svd', 'swapaxes', + 'swapdims', 'sym_constrain_range', 'sym_constrain_range_for_size', 't', 't_copy', 'take', + 'take_along_dim', 'tan', 'tan_', 'tanh', 'tanh_', 'tensor', 'tensor_split', 'threshold', + 'threshold_', 'tile', 'topk', 'topk', 'trace', 'transpose', 'transpose_copy', 'trapezoid', 'trapz', + 'triangular_solve', 'triangular_solve', 'tril', 'tril_indices', 'triplet_margin_loss', 'triu', + 'triu_indices', 'true_divide', 'trunc', 'trunc_', 'unbind', 'unbind_copy', 'unflatten', + 'unfold_copy', 'unique_dim', 'unsafe_chunk', 'unsafe_split', 'unsafe_split_with_sizes', + 'unsqueeze', 'unsqueeze_copy', 'values_copy', 'vander', 'var', 'var_mean', 'vdot', + 'view_as_complex', 'view_as_complex_copy', 'view_as_real', 'view_as_real_copy', 'view_copy', + 'vsplit', 'vstack', 'where', 'xlogy', 'xlogy_', 'zero_', 'zeros', 'zeros_like'] diff --git a/MLPY/Lib/site-packages/torch/_C/__init__.pyi b/MLPY/Lib/site-packages/torch/_C/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ae27eb407aa24f500f6c21b09b51d5418bc2a12c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/__init__.pyi @@ -0,0 +1,10976 @@ +# @generated from torch/_C/__init__.pyi.in +# mypy: disable-error-code="type-arg" + +import builtins +from enum import Enum, IntEnum +from pathlib import Path +from typing import ( + Any, + AnyStr, + BinaryIO, + Callable, + ContextManager, + Dict, + Generic, + Iterable, + Iterator, + List, + Literal, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + SupportsIndex, + Tuple, + Type, + TypeVar, + Union, + overload, + runtime_checkable, +) +from typing_extensions import ParamSpec + +import torch +from torch import inf, SymInt, Tensor +from torch.autograd.graph import Node as _Node +from torch.package import PackageExporter +from torch.storage import UntypedStorage, TypedStorage +from torch.types import ( + _bool, + _complex, + _device, + _dispatchkey, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Device, + Number, + Storage, +) + +from 
torch._prims_common import DeviceLikeType + +# This module is defined in torch/csrc/Module.cpp + +from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti, _verbose + +K = TypeVar("K") +T = TypeVar("T") +S = TypeVar("S", bound="torch.Tensor") +P = ParamSpec("P") +ReturnVal = TypeVar("ReturnVal", covariant=True) # return value (always covariant) +_T_co = TypeVar("_T_co", covariant=True) + + +@runtime_checkable +class _NestedSequence(Protocol[_T_co]): + """A protocol for representing nested sequences. + + References:: + `numpy._typing._NestedSequence` + + """ + + def __len__(self, /) -> builtins.int: ... + def __getitem__(self, index: builtins.int, /) -> _T_co | _NestedSequence[_T_co]: ... + def __contains__(self, x: builtins.object, /) -> builtins.bool: ... + def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... + def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... + def count(self, value: Any, /) -> builtins.int: ... + def index(self, value: Any, /) -> builtins.int: ... + + +# Defined in torch/csrc/Device.cpp +class device: + type: str # THPDevice_type + index: _int # THPDevice_index + + def __get__(self, instance, owner=None) -> device: ... + + # THPDevice_pynew + @overload + def __init__(self, device: DeviceLikeType) -> None: ... + @overload + def __init__(self, type: str, index: _int) -> None: ... + + # Uncomment if we ever make torch.device a decorator + # def __call__(self, func: T) -> T: ... + + def __enter__(self) -> device: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce + +# Defined in torch/csrc/Stream.cpp +class Stream: + stream_id: _int # Stream id + device_index: _int + device_type: _int + + device: device # The device of the stream + +# Defined in torch/csrc/Size.cpp +class Size(Tuple[_int, ...]): + # TODO: __reduce__ + + @overload # type: ignore[override] + def __getitem__(self: Size, key: _int) -> _int: ... + @overload + def __getitem__(self: Size, key: slice) -> Size: ... + def numel(self: Size) -> _int: ... + +# Defined in torch/csrc/Dtype.cpp +class dtype: + # TODO: __reduce__ + is_floating_point: _bool + is_complex: _bool + is_signed: _bool + itemsize: _int + def to_real(self) -> dtype: ... + def to_complex(self) -> dtype: ... + +# Defined in torch/csrc/TypeInfo.cpp +class iinfo: + bits: _int + min: _int + max: _int + dtype: str + + def __init__(self, dtype: _dtype) -> None: ... + +class finfo: + bits: _int + min: _float + max: _float + eps: _float + tiny: _float + smallest_normal: _float + resolution: _float + dtype: str + + @overload + def __init__(self, dtype: _dtype) -> None: ... + @overload + def __init__(self) -> None: ... + +float32: dtype = ... +float: dtype = ... +float64: dtype = ... +double: dtype = ... +float16: dtype = ... +bfloat16: dtype = ... +float8_e4m3fn: dtype = ... +float8_e4m3fnuz: dtype = ... +float8_e5m2: dtype = ... +float8_e5m2fnuz: dtype = ... +half: dtype = ... +uint8: dtype = ... +uint16: dtype = ... +uint32: dtype = ... +uint64: dtype = ... +int8: dtype = ... +int16: dtype = ... +short: dtype = ... +int32: dtype = ... +int: dtype = ... +int64: dtype = ... +long: dtype = ... +complex32: dtype = ... +complex64: dtype = ... +chalf: dtype = ... +cfloat: dtype = ... +complex128: dtype = ... +cdouble: dtype = ... +quint8: dtype = ... +qint8: dtype = ... +qint32: dtype = ... +bool: dtype = ... +quint4x2: dtype = ... +quint2x4: dtype = ... +bits1x8: dtype = ... +bits2x4: dtype = ... 
+bits4x2: dtype = ... +bits8: dtype = ... +bits16: dtype = ... + +# Defined in torch/csrc/Layout.cpp +class layout: ... + +# Defined in torch/csrc/utils/disable_torch_function.cpp +def DisableTorchFunction(): ... +def DisableTorchFunctionSubclass(): ... + +# Defined in torch/csrc/utils/tensor_layouts.cpp +strided: layout = ... +sparse_coo: layout = ... +sparse_csr: layout = ... +sparse_csc: layout = ... +sparse_bsr: layout = ... +sparse_bsc: layout = ... +_mkldnn: layout = ... +jagged: layout = ... + +# Defined in torch/csrc/MemoryFormat.cpp +class memory_format: ... + +# Defined in torch/csrc/utils/tensor_memoryformats.cpp +contiguous_format: memory_format = ... +channels_last: memory_format = ... +channels_last_3d: memory_format = ... +preserve_format: memory_format = ... + +# Defined in torch/csrc/QScheme.cpp +class qscheme: ... + +# Defined in torch/csrc/utils/tensor_qschemes.h +per_tensor_affine: qscheme = ... +per_channel_affine: qscheme = ... +per_tensor_symmetric: qscheme = ... +per_channel_symmetric: qscheme = ... +per_channel_affine_float_qparams: qscheme = ... + +# Defined in torch/csrc/autograd/python_function.cpp +class _FunctionBase: + saved_tensors: Tuple[Tensor] + _raw_saved_tensors: Tuple[Any] + next_functions: Tuple[Tuple[Any, _int], ...] + needs_input_grad: Tuple[_bool] + metadata: dict + _materialize_non_diff_grads: _bool + # skip adding type hints for the fields that have wrappers defined + # in torch/autograd/function.py + +# Defined in torch/csrc/autograd/python_legacy_variable.cpp +class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy + def __init__( + self, + data: Optional[Tensor] = ..., + requires_grad: Optional[_bool] = ..., + volatile: Optional[_bool] = ..., + _grad_fn: Optional[_FunctionBase] = ..., + ) -> None: ... + +# Defined in torch/csrc/jit/python/init.cpp +class IODescriptor: ... +class JITException: ... + +class Future(Generic[T]): + def __init__(self, devices: List[device]) -> None: ... + def done(self) -> _bool: ... + def value(self) -> T: ... + def wait(self) -> T: ... + def add_done_callback(self, callback: Callable) -> None: ... + def then(self, callback: Callable) -> Future[T]: ... + def set_result(self, result: T) -> None: ... + def _set_unwrap_func(self, callback: Callable) -> None: ... + +class _Await: + def __init__(self) -> None: ... + def fn(self) -> Callable: ... + def args(self) -> Tuple[Any, ...]: ... + def is_nowait(self) -> _bool: ... + +def _jit_set_num_profiled_runs(num: _size) -> _size: ... + +# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h +class _MobileOptimizerType: ... + +CONV_BN_FUSION: _MobileOptimizerType +INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType +REMOVE_DROPOUT: _MobileOptimizerType +FUSE_ADD_RELU: _MobileOptimizerType +HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType +VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType + +def fork(*args: Any, **kwargs: Any) -> Future: ... +def wait(fut: Future) -> Any: ... +def _awaitable(*args: Any, **kwargs: Any) -> _Await: ... +def _awaitable_wait(aw: _Await) -> Any: ... +def _awaitable_nowait(x: Any) -> _Await: ... +def _collect_all(futures: List[Future]) -> Future: ... +def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ... +def unify_type_list(types: List[JitType]) -> JitType: ... +def _freeze_module( + module: ScriptModule, + preserved_attrs: List[str] = [], + freeze_interfaces: _bool = True, + preserveParameters: _bool = True, +) -> ScriptModule: ... 
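
The dtype constants and the ``iinfo``/``finfo`` classes declared earlier in this stub are commonly combined for numeric-range introspection; a brief doctest-style sketch (the printed values are the standard limits PyTorch reports for these types):

    >>> import torch
    >>> torch.finfo(torch.float32).eps       # machine epsilon of float32
    1.1920928955078125e-07
    >>> torch.iinfo(torch.int16).max         # largest representable int16
    32767
    >>> torch.float32.is_floating_point, torch.int64.is_floating_point
    (True, False)
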
+def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ... +def _jit_pass_optimize_for_inference( + module: torch.jit.ScriptModule, + other_methods: List[str] = [], +) -> None: ... +def _jit_pass_fold_frozen_conv_bn(graph: Graph): ... +def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ... +def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ... +def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ... +def _jit_pass_concat_frozen_linear(graph: Graph): ... +def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ... +def _jit_pass_transpose_frozen_linear(graph: Graph): ... +def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ... +def _is_tracing() -> _bool: ... +def _jit_init() -> _bool: ... +def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... +def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ... +def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ... +def _get_operation_overload( + op_name: str, + op_overload_name: str, +) -> Tuple[Callable, Callable, List[Any]]: ... +def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ... +def _jit_pass_optimize_for_mobile( + module: torch.jit.ScriptModule, + optimization_blocklist: Set[_MobileOptimizerType], + preserved_methods: List[AnyStr], +) -> torch.jit.ScriptModule: ... +def _clone_module_with_class( + module: torch.jit.ScriptModule, + ignored_methods: List[AnyStr], + ignored_attributes: List[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_vulkan_optimize_for_mobile( + module: torch.jit.ScriptModule, + optimization_blocklist: Set[_MobileOptimizerType], + preserved_methods: List[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_metal_optimize_for_mobile( + module: torch.jit.ScriptModule, + preserved_methods: List[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_inline(Graph) -> None: ... +def _jit_pass_constant_propagation(Graph) -> None: ... +def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ... +def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ... +def _jit_erase_non_input_shape_information(Graph) -> None: ... +def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ... +def _jit_get_all_schemas() -> List[FunctionSchema]: ... +def _jit_check_alias_annotation( + g: Graph, + args: Tuple[Any, ...], + unqualified_op_name: str, +): ... +def _jit_can_fuse_on_cpu() -> _bool: ... +def _jit_can_fuse_on_gpu() -> _bool: ... +def _jit_can_fuse_on_cpu_legacy() -> _bool: ... +def _debug_get_fusion_group_inlining() -> _bool: ... +def _debug_set_fusion_group_inlining(enable: _bool): ... +def _jit_texpr_fuser_enabled() -> _bool: ... +def _jit_nvfuser_enabled() -> _bool: ... +def _jit_llga_enabled() -> _bool: ... +def _jit_set_llga_enabled(enable: _bool): ... +def _llvm_enabled() -> _bool: ... +def _jit_override_can_fuse_on_cpu(override: _bool): ... +def _jit_override_can_fuse_on_gpu(override: _bool): ... +def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ... +def _jit_set_symbolic_shapes_test_mode(override: _bool): ... +def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ... +def _jit_set_texpr_fuser_enabled(enable: _bool): ... +def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ... +def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ... +def _jit_cat_wo_conditionals(optimize_cat: _bool): ... +def _jit_opt_conditionals(opt_conds: _bool): ... +def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ... 
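
``_freeze_module`` and the ``_jit_pass_optimize_*`` passes above are internal bindings; they are normally reached through the public ``torch.jit.freeze`` and ``torch.jit.optimize_for_inference`` wrappers. A minimal sketch of that public path (``Net`` is a placeholder module; the module must be in eval mode before freezing):

    >>> import torch
    >>> class Net(torch.nn.Module):
    ...     def forward(self, x):
    ...         return torch.relu(x) + 1
    ...
    >>> scripted = torch.jit.script(Net().eval())
    >>> frozen = torch.jit.freeze(scripted)                  # folds parameters/attributes into the graph
    >>> optimized = torch.jit.optimize_for_inference(frozen)
    >>> optimized(torch.randn(4)).shape
    torch.Size([4])
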
+def _jit_pass_erase_shape_information(graph: Graph): ... +def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ... +def _jit_pass_insert_observers( + module: torch.jit.ScriptModule, + method_name: str, + qconfig_dict: Dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... +def _jit_pass_insert_quant_dequant( + module: torch.jit.ScriptModule, + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... +def _jit_pass_insert_quant_dequant_for_ondevice_ptq( + module: torch.jit.ScriptModule, + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... +def _jit_pass_quant_finalize( + module: torch.jit.ScriptModule, + quant_type: _int, + preserved_attrs: Sequence[str], +): ... +def _jit_pass_quant_finalize_for_ondevice_ptq( + module: torch.jit.ScriptModule, + quant_type: _int, + method_name: str, +): ... +def _jit_pass_insert_observer_method_for_ondevice_ptq( + module: torch.jit.ScriptModule, + method_name: str, + qconfig_dict: Dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... +def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ... +def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ... +def _jit_set_fusion_strategy( + strategy: List[Tuple[str, _int]], +) -> List[Tuple[str, _int]]: ... +def _jit_try_infer_type(obj: Any) -> InferredType: ... +def _jit_get_trigger_value(trigger_name: str) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +ResolutionCallback = Callable[[str], Callable[..., Any]] + +# Defined in torch/csrc/jit/python/script_init.cpp +# and torch/csrc/jit/python/init.cpp +def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ... +def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... +def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... +def _jit_assert_is_instance(obj: Any, type: JitType): ... +def _jit_clear_class_registry() -> None: ... +def _jit_set_emit_hooks( + ModuleHook: Optional[Callable], + FunctionHook: Optional[Callable], +) -> None: ... +def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ... +def _load_for_lite_interpreter( + filename: Union[str, Path], + map_location: Optional[DeviceLikeType], +): ... +def _load_for_lite_interpreter_from_buffer( + buffer: BinaryIO, + map_location: Optional[DeviceLikeType], +): ... +def _export_operator_list(module: LiteScriptModule): ... +def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ... +def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ... +def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ... +def _backport_for_mobile( + filename_input: Union[str, Path], + filename_output: Union[str, Path], + to_version: _int, +) -> None: ... +def _backport_for_mobile_from_buffer( + buffer: BinaryIO, + filename_output: Union[str, Path], + to_version: _int, +) -> None: ... +def _backport_for_mobile_to_buffer( + filename_input: Union[str, Path], + to_version: _int, +) -> bytes: ... +def _backport_for_mobile_from_buffer_to_buffer( + buffer: BinaryIO, + to_version: _int, +) -> bytes: ... +def _get_model_ops_and_info(filename: Union[str, Path]): ... +def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ... +def _get_mobile_model_contained_types(filename: Union[str, Path]): ... +def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ... +def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ... +def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ... 
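
The profiling-executor and profiling-mode toggles declared above are internal knobs, mostly exercised by PyTorch's own tests; if they are used at all, the prior settings they return should be restored afterwards. An illustrative sketch built only on the signatures in this stub (a sketch of internal, unstable API, not a recommendation):

    >>> import torch
    >>> prev_exec = torch._C._jit_set_profiling_executor(True)   # returns the previous flag
    >>> prev_mode = torch._C._jit_set_profiling_mode(True)
    >>> # ... run or benchmark some scripted code here ...
    >>> _ = torch._C._jit_set_profiling_executor(prev_exec)      # restore the old settings
    >>> _ = torch._C._jit_set_profiling_mode(prev_mode)
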
+def _set_graph_executor_optimize(optimize: _bool): ... +def _export_opnames(module: ScriptModule) -> List[str]: ... +def _create_function_from_trace( + qualname: str, + func: Callable[..., Any], + input_tuple: Tuple[Any, ...], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: List[str], +) -> Tuple[Graph, Stack]: ... +def _create_function_from_trace_with_dict( + qualname: str, + func: Callable[..., Any], + input_dict: Dict[str, Any], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: List[str], +) -> Tuple[Graph, Stack]: ... +def _jit_is_script_object(obj: Any) -> _bool: ... +def _last_executed_optimized_graph() -> Graph: ... +def parse_type_comment(comment: str) -> Decl: ... +def _get_upgraders_map_size() -> _int: ... +def _get_upgraders_entry_map() -> Dict[str, str]: ... +def _dump_upgraders_map() -> Dict[str, str]: ... +def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... +def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ... +def merge_type_from_type_comment( + decl: Decl, + type_annotation_decl: Decl, + is_method: _bool, +) -> Decl: ... +def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ... +def parse_schema(schema: str) -> FunctionSchema: ... +def get_device(input: Tensor) -> _int: ... +def _resolve_type_from_object( + obj: Any, + range: SourceRange, + rcb: ResolutionCallback, +) -> JitType: ... +def _create_module_with_type(ty: JitType) -> ScriptModule: ... +def _create_object_with_type(ty: ClassType) -> ScriptObject: ... +def _run_emit_module_hook(m: ScriptModule): ... +def _replace_overloaded_method_decl( + overload_decl: Decl, + implementation_def: Def, + new_name: str, +) -> Def: ... +def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... +def _jit_pass_onnx_set_dynamic_input_shape( + graph: Graph, + dynamic_axes: Dict[str, Dict[_int, str]], + input_names: List[str], +) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference( + graph: Graph, + params_dict: Dict[str, IValue], + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_assign_output_shape( + graph: Graph, + tensors: List[Tensor], + desc: IODescriptor, + onnx_shape_inference: _bool, + is_script: _bool, + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_remove_inplace_ops_for_onnx( + graph: Graph, + module: Optional[ScriptModule] = None, +) -> None: ... +def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... +def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... +def _jit_pass_peephole( + graph: Graph, + disable_shape_peepholes: _bool = False, +) -> None: ... +def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ... +def _jit_pass_fuse_addmm(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess(graph: Graph) -> None: ... +def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ... +def _jit_pass_onnx_remove_print(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ... +def _jit_pass_onnx_unpack_quantized_weights( + graph: Graph, + paramsDict: Dict[str, IValue], + caffe2: _bool, +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_quantization_insert_permutes( + graph: Graph, + paramsDict: Dict[str, IValue], +) -> Dict[str, IValue]: ... +def _jit_pass_custom_pattern_based_rewrite_graph( + pattern: str, + fused_node_name: str, + graph: Graph, +) -> None: ... +def _jit_onnx_list_model_parameters( + module: ScriptModule, +) -> Tuple[ScriptModule, List[IValue]]: ... 
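
The ``_create_function_from_trace*`` bindings above sit behind the public ``torch.jit.trace`` entry point; a minimal sketch of that public path (``f`` is a placeholder function and the shapes are arbitrary):

    >>> import torch
    >>> def f(x, y):
    ...     return x @ y + y
    ...
    >>> traced = torch.jit.trace(f, (torch.randn(3, 3), torch.randn(3, 3)))
    >>> isinstance(traced.graph, torch._C.Graph)     # the recorded IR; see the Graph class further below
    True
    >>> traced(torch.ones(3, 3), torch.ones(3, 3)).shape
    torch.Size([3, 3])
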
+def _jit_pass_erase_number_types(graph: Graph) -> None: ... +def _jit_pass_onnx_lint(graph: Graph) -> None: ... +def _jit_pass_onnx( + graph: Graph, + _jit_pass_onnx: _onnx.OperatorExportTypes, +) -> Graph: ... +def _jit_pass_onnx_scalar_type_analysis( + graph: Graph, + lowprecision_cast: _bool, + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_peephole( + graph: Graph, + opset_version: _int, + fixed_batch_size: _bool, +) -> None: ... +def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ... +def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ... +def _jit_pass_onnx_function_extraction( + graph: Graph, + module_names: Set[str], + param_names: List[str], +) -> Dict[Node, Dict[str, str]]: ... +def _jit_pass_onnx_clear_scope_records() -> None: ... +def _jit_pass_onnx_track_scope_attributes( + graph: Graph, + onnx_attrs: Dict[str, Any], +) -> None: ... +def _jit_is_onnx_log_enabled() -> _bool: ... +def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ... +def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ... +def _jit_onnx_log(*args: Any) -> None: ... +def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ... +def _jit_pass_inline_fork_wait(graph: Graph) -> None: ... +def _jit_pass_onnx_deduplicate_initializers( + graph: Graph, + params_dict: Dict[str, IValue], + is_train: _bool, +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eval_peephole( + graph: Graph, + paramsDict: Dict[str, IValue], +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_constant_fold( + graph: Graph, + paramsDict: Dict[str, IValue], + opset_version: _int, +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eliminate_unused_items( + graph: Graph, + paramsDict: Dict[str, IValue], +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... +def _jit_pass_filter_non_tensor_arguments( + params: Dict[str, IValue], +) -> Dict[str, Tensor]: ... +def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... +def _jit_pass_onnx_node_shape_type_inference( + n: Node, + paramsDict: Dict[str, IValue], + opset_version: _int, +) -> None: ... +def _jit_onnx_convert_pattern_from_subblock( + block: Block, + n: Node, + env: Dict[Value, Value], +) -> List[Value]: ... +def _jit_pass_onnx_block( + old_block: Block, + new_block: Block, + operator_export_type: _onnx.OperatorExportTypes, + env: Dict[Value, Value], + is_sub_block: _bool, +) -> Dict[Value, Value]: ... +def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ... +def _jit_pass_fixup_onnx_controlflow_node( + n: Node, + opset_version: _int, +) -> List[Value]: ... +def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ... +def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ... +def _generate_upgraders_graph() -> Dict[str, Graph]: ... +def _calculate_package_version_based_on_upgraders(val: _bool): ... +def _get_version_calculator_flag() -> _bool: ... +def _jit_script_interface_compile( + name: str, + class_def: ClassDef, + rcb: ResolutionCallback, + is_module: _bool, +): ... +def _jit_script_compile_overload( + qualname: str, + overload_decl: Decl, + implementation_def: Def, + rcb: ResolutionCallback, + implementation_defaults: Dict[str, Any], + signature: Any, +): ... +def _jit_script_compile( + qual_name: str, + definition: Def, + rcb: ResolutionCallback, + defaults: Dict[str, Any], +): ... 
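
The ``_jit_pass_onnx_*`` passes above are driven by the public exporter rather than called directly; a small hedged sketch of the usual entry point, writing to an in-memory buffer (``model`` and the tensor shapes are placeholders, and this assumes the TorchScript-based ONNX exporter and its dependencies are available):

    >>> import io, torch
    >>> model = torch.nn.Linear(4, 2).eval()
    >>> buf = io.BytesIO()
    >>> torch.onnx.export(model, (torch.randn(1, 4),), buf, input_names=["x"], output_names=["y"])
    >>> len(buf.getvalue()) > 0       # a serialized ONNX protobuf was written to the buffer
    True
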
+def _jit_script_class_compile( + qual_name: str, + definition: ClassDef, + defaults: Dict[str, Dict[str, Any]], + rcb: ResolutionCallback, +): ... +def _parse_source_def(src: str) -> Def: ... +def import_ir_module( + cu: CompilationUnit, + filename: Union[str, Path], + map_location: Optional[DeviceLikeType], + extra_files: Dict[str, Any], +) -> ScriptModule: ... +def import_ir_module_from_buffer( + cu: CompilationUnit, + buffer: BinaryIO, + map_location: Optional[DeviceLikeType], + extra_files: Dict[str, Any], +) -> ScriptModule: ... +def _import_ir_module_from_package( + cu: CompilationUnit, + reader: PyTorchFileReader, + storage_context: DeserializationStorageContext, + map_location: Optional[DeviceLikeType], + ts_id: str, +) -> ScriptModule: ... +def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ... +def _check_onnx_proto(proto: str) -> None: ... +def _propagate_and_assign_input_shapes( + graph: Graph, + inputs: Tuple[Tensor, ...], + param_count_list: List[_int], + with_grad: _bool, + propagate: _bool, +) -> Graph: ... + +# Defined in torch/csrc/jit/runtime/graph_executor.h +class GraphExecutorState: ... + +# Defined in torch/torch/csrc/jit/ir/alias_analysis.h +class AliasDb: + def __str__(self) -> str: ... + +class _InsertPoint: + def __enter__(self) -> None: ... + def __exit__(self, *args) -> None: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Use: + @property + def user(self) -> Node: ... + @property + def offset(self) -> _int: ... + def isAfter(self, other: Use) -> _bool: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Value: + def type(self) -> JitType: ... + def setType(self, t: JitType) -> Value: ... + def setTypeAs(self, other: Value) -> Value: ... + def inferTypeFrom(self, t: Tensor) -> None: ... + def debugName(self) -> str: ... + def setDebugName(self, name: str) -> None: ... + def unique(self) -> _int: ... + def offset(self) -> _int: ... + def node(self) -> Node: ... + def uses(self) -> List[Use]: ... + def replaceAllUsesWith(self, val: Value) -> None: ... + def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ... + def requires_grad(self) -> _bool: ... + def requiresGrad(self) -> _bool: ... + def copyMetadata(self, other: Value) -> Value: ... + def isCompleteTensor(self) -> _bool: ... + def toIValue(self) -> IValue: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Block: + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... + def nodes(self) -> Iterator[Node]: ... + def paramNode(self) -> Node: ... + def returnNode(self) -> Node: ... + def owningNode(self) -> Node: ... + def registerOutput(self, n: Value) -> _int: ... + def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Node: + def __getitem__(self, key: str) -> Any: ... + def schema(self) -> str: ... + def input(self) -> Value: ... + def inputs(self) -> Iterator[Value]: ... + def inputsAt(self, idx: _int) -> Value: ... + def inputsSize(self) -> _int: ... + def output(self) -> Value: ... + def outputs(self) -> Iterator[Value]: ... + def outputsAt(self, idx: _int) -> Value: ... + def outputsSize(self) -> _int: ... + def hasMultipleOutputs(self) -> _bool: ... + def blocks(self) -> List[Block]: ... + def addBlock(self) -> Block: ... + def mustBeNone(self) -> _bool: ... + def matches(self, pattern: str) -> _bool: ... + def kind(self) -> str: ... + def kindOf(self, name: str) -> str: ... + def addInput(self, name: str) -> Value: ... 
+ def replaceInput(self, i: _int, newValue: Value) -> Value: ... + def replaceInputWith(self, from_: Value, to: Value) -> None: ... + def replaceAllUsesWith(self, n: Node) -> None: ... + def insertBefore(self, n: Node) -> Node: ... + def insertAfter(self, n: Node) -> Node: ... + def isBefore(self, n: Node) -> _bool: ... + def isAfter(self, n: Node) -> _bool: ... + def moveBefore(self, n: Node) -> None: ... + def moveAfter(self, n: Node) -> None: ... + def removeInput(self, i: _int) -> None: ... + def removeAllInputs(self, i: _int) -> None: ... + def hasUses(self) -> _bool: ... + def eraseOutput(self, i: _int) -> None: ... + def addOutput(self) -> Value: ... + def scopeName(self) -> str: ... + def isNondeterministic(self) -> _bool: ... + def copyAttributes(self, rhs: Node) -> Node: ... + def copyMetadata(self, rhs: Node) -> Node: ... + def hasAttributes(self) -> _bool: ... + def hasAttribute(self, name: str) -> _bool: ... + def removeAttribute(self, attr: str) -> Node: ... + def namedInput(self, name: str) -> Value: ... + def sourceRange(self) -> SourceRange: ... + def owningBlock(self) -> Block: ... + def findNode(self, kind: str, recurse: _bool = True) -> Node: ... + def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ... + def getModuleHierarchy(self) -> str: ... + def prev(self) -> Node: ... + def destroy(self) -> None: ... + def attributeNames(self) -> List[str]: ... + + # Accessors for attributes as types. + def f(self, name: str) -> _float: ... + def f_(self, name: str, val: _float) -> Node: ... + def fs(self, name: str) -> List[_float]: ... + def fs_(self, name: str, val: List[_float]) -> Node: ... + def c(self, name: str) -> complex: ... + def c_(self, name: str, val: complex) -> Node: ... + def s(self, name: str) -> str: ... + def s_(self, name: str, val: str) -> Node: ... + def ss(self, name: str) -> List[str]: ... + def ss_(self, name: str, val: List[str]) -> Node: ... + def i(self, name: str) -> _int: ... + def i_(self, name: str, val: _int) -> Node: ... + # Cannot define "is" like this because it's a reserved keyword in python. + # def is(self, name: str) -> List[_int]: ... + # def is_(self, name: str, val: List[_int]) -> Node: ... + def g(self, name: str) -> Graph: ... + def g_(self, name: str, val: Graph) -> Node: ... + def gs(self, name: str) -> List[Graph]: ... + def gs_(self, name: str, val: List[Graph]) -> Node: ... + def ival(self, name: str) -> IValue: ... + def ival_(self, name: str, val: IValue) -> Node: ... + def t(self, name: str) -> Tensor: ... + def t_(self, name: str, val: Tensor) -> Node: ... + def ts(self, name: str) -> List[Tensor]: ... + def ts_(self, name: str, val: List[Tensor]) -> Node: ... + def ty(self, name: str) -> JitType: ... + def ty_(self, name: str, val: JitType) -> Node: ... + def tys(self, name: str) -> List[JitType]: ... + def tys_(self, name: str, val: List[JitType]) -> Node: ... + +# Defined in torch/torch/csrc/jit/ir/ir.h +class Graph: + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... + def nodes(self) -> Iterator[Node]: ... + def param_node(self) -> Node: ... + def return_node(self) -> Node: ... + def addInput(self, name: str = "") -> Value: ... + def eraseInput(self, i: _int) -> None: ... + def registerOutput(self, n: Value) -> _int: ... + def eraseOutput(self, i: _int) -> None: ... + def create(self, name: str, args, num_outputs: _int) -> Node: ... + def appendNode(self, n: Node) -> Node: ... + def prependNode(self, n: Node) -> Node: ... + def insertNode(self, n: Node) -> Node: ... 
+ def block(self) -> Block: ... + def lint(self) -> None: ... + def alias_db(self) -> AliasDb: ... + def setInsertPoint(self, n: Union[Block, Node]) -> None: ... + def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ... + def insertPoint(self) -> Node: ... + def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ... + def makeMultiOutputIntoTuple(self) -> None: ... + def copy(self) -> Graph: ... + +# Defined in torch/aten/src/ATen/core/alias_info.h +class AliasInfo: + is_write: _bool + before_set: Set[str] + after_set: Set[str] + +# Defined in torch/aten/src/ATen/core/function_schema.h +class Argument: + name: str + type: JitType + default_value: Optional[Any] + def has_default_value(self) -> _bool: ... + kwarg_only: _bool + is_out: _bool + alias_info: Optional[AliasInfo] + +class FunctionSchema: + arguments: List[Argument] + returns: List[Argument] + name: str + overload_name: str + +class _UpgraderEntry: + bumped_at_version: _int + upgrader_name: str + old_schema: str + def __init__( + self, + bumped_at_version: _int, + upgrader_name: str, + old_schema: str, + ) -> None: ... + +class _UpgraderRange: + min_version: _int + max_version: _int + +def _get_max_operator_version() -> _int: ... +def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ... +def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ... +def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ... +def _test_only_remove_entry_to_op_version(op_name: str) -> None: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class ScriptModuleSerializer: + def __init__(self, export_writer: PyTorchFileWriter) -> None: ... + def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ... + def write_files(self) -> None: ... + def storage_context(self) -> SerializationStorageContext: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class SerializationStorageContext: + def __init__(self) -> None: ... + def has_storage(self, storage: Storage) -> _bool: ... + def get_or_add_storage(self, storage: Storage) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class DeserializationStorageContext: + def __init__(self) -> None: ... + def get_storage(self, name: str, dtype: _dtype) -> Tensor: ... + def has_storage(self, name: str) -> _bool: ... + def add_storage(self, name: str, tensor: Tensor) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class ConcreteModuleTypeBuilder: + def __init__(self, obj: Any) -> None: ... + def set_module_dict(self): ... + def set_module_list(self): ... + def set_parameter_list(self): ... + def set_parameter_dict(self): ... + def add_attribute( + self, + name: str, + ty: JitType, + is_param: _bool, + is_buffer: _bool, + ): ... + def add_module(self, name: str, meta: ConcreteModuleType): ... + def add_constant(self, name: str, value: Any): ... + def add_overload(self, method_name: str, overloaded_method_names: List[str]): ... + def add_builtin_function(self, name: str, symbol_name: str): ... + def add_failed_attribute(self, name: str, failure_reason: str): ... + def add_function_attribute( + self, + name: str, + ty: JitType, + func: Callable[..., Any], + ): ... + def add_ignored_attribute(self, name: str): ... + def add_ignored_attributes(self, names: List[str]): ... + def add_forward_hook(self, hook: Callable[..., Any]): ... + def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ... + +class ConcreteModuleType: + def get_constants(self) -> Dict[str, Any]: ... 
+ def equals(self, other: ConcreteModuleType) -> _bool: ... + @staticmethod + def from_jit_type(ty: JitType) -> ConcreteModuleType: ... + +class CallStack: + def __init__(self, name: str, range: SourceRange): ... + +class ErrorReport: + def __init__(self, range: SourceRange) -> None: ... + def what(self) -> str: ... + @staticmethod + def call_stack() -> str: ... + +class CompilationUnit: + def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ... + def find_function(self, name: str) -> ScriptFunction: ... + def __getattr__(self, name: str) -> ScriptFunction: ... + def define( + self, + script: str, + rcb: ResolutionCallback = ..., + _frames_up: _int = ..., + ): ... + def get_interface(self, name: str) -> InterfaceType: ... + def get_functions(self) -> List[ScriptFunction]: ... + def create_function( + self, + name: str, + graph: Graph, + shouldMangle: _bool = ..., + ) -> ScriptFunction: ... + def get_class(self, name: str) -> ClassType: ... + +class ScriptObject: + def setattr(self, name: str, value: Any): ... + +class ScriptModule(ScriptObject): + def _method_names(self) -> List[str]: ... + def _get_method(self, name: str) -> ScriptMethod: ... + +class LiteScriptModule: + def __call__(self, *input): ... + def find_method(self, method_name: str): ... + def forward(self, *input) -> List[str]: ... + def run_method(self, method_name: str, *input): ... + +# NOTE: switch to collections.abc.Callable in python 3.9 +class ScriptFunction(Generic[P, ReturnVal]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... + def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ... + def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ... + @property + def graph(self) -> Graph: ... + def inlined_graph(self) -> Graph: ... + def schema(self) -> FunctionSchema: ... + def code(self) -> str: ... + def name(self) -> str: ... + @property + def qualified_name(self) -> str: ... + +# NOTE: switch to collections.abc.Callable in python 3.9 +class ScriptMethod(Generic[P, ReturnVal]): + graph: Graph + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... + @property + def owner(self) -> ScriptModule: ... + @property + def name(self) -> str: ... + +class ScriptDict(Generic[K, T]): + def __init__(self, dict: Dict[K, T]) -> None: ... + def __len__(self) -> _int: ... + def __contains__(self, key: K) -> _bool: ... + def __getitem__(self, key: K) -> T: ... + def __setitem__(self, key: K, value: T) -> None: ... + def __delitem__(self, key: K) -> None: ... + def __iter__(self) -> Iterator[K]: ... + def items(self) -> Iterator[tuple[K, T]]: ... + def keys(self) -> Iterator[K]: ... + +class ScriptList(Generic[T]): + def __init__(self, list: List[T]) -> None: ... + def __len__(self) -> _int: ... + def __contains__(self, item: T) -> _bool: ... + @overload + def __getitem__(self, idx: _int) -> T: ... + @overload + def __getitem__(self, idx: slice) -> ScriptList[T]: ... + @overload + def __setitem__(self, idx: _int, value: T) -> None: ... + @overload + def __setitem__(self, idx: slice, value: List[T]) -> None: ... + def __delitem__(self, idx: _int) -> None: ... + def __iter__(self) -> Iterator[T]: ... + def count(self, value: T) -> _int: ... + def remove(self, value: T) -> None: ... + def append(self, value: T) -> None: ... + def clear(self) -> None: ... + @overload + def extend(self, values: List[T]) -> None: ... + @overload + def extend(self, values: Iterable[T]) -> None: ... + @overload + def pop(self) -> T: ... 
+ @overload + def pop(self, idx: _int) -> T: ... + +class ModuleDict: + def __init__(self, mod: ScriptModule) -> None: ... + def items(self) -> List[Tuple[str, Any]]: ... + +class ParameterDict: + def __init__(self, mod: ScriptModule) -> None: ... + +class BufferDict: + def __init__(self, mod: ScriptModule) -> None: ... + +# Defined in torch/csrc/jit/api/module.h +class Module: ... + +# Defined in torch/csrc/Module.cpp +def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension +def _autograd_init() -> _bool: ... # THPAutograd_initExtension +def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr +def _init_names(arg: Sequence[Type]) -> None: ... # THPModule_initNames +def _has_distributed() -> _bool: ... # THPModule_hasDistributed +def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType +def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype +def _infer_size(arg1: Size, arg2: Size) -> Size: ... # THPModule_inferSize +def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN +def _crash_if_csrc_ubsan() -> _int: ... # THPModule_crashIfCsrcUBSAN +def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN +def _show_config() -> str: ... # THPModule_showConfig +def _cxx_flags() -> str: ... # THPModule_cxxFlags +def _parallel_info() -> str: ... # THPModule_parallelInfo +def _get_cpu_capability() -> str: ... # THPModule_getCpuCapability +def _set_backcompat_broadcast_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatBroadcastWarn +def _get_backcompat_broadcast_warn() -> _bool: ... # THPModule_getBackcompatBroadcastWarn +def _set_backcompat_keepdim_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatKeepdimWarn +def _get_backcompat_keepdim_warn() -> _bool: ... # THPModule_getBackcompatKeepdimWarn +def get_num_thread() -> _int: ... # THPModule_getNumThreads +def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads +def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads +def set_num_interop_threads( + nthreads: _int, +) -> None: ... # THPModule_setNumInteropThreads +def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN +def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN +def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP +def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash +def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_mem_efficient( + arg: _bool, +) -> None: ... # THPModule_setSDPUseMemEfficient +def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath +def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_cudnn(arg: _bool) -> None: ... # THPModule_setSDPUseMath +def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn +def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn +def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN +def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN +def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN +def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN +def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms +def _get_deterministic_algorithms_warn_only() -> _bool: ... 
# THPModule_deterministicAlgorithmsWarnOnly +def _set_deterministic_algorithms( + mode: _bool, + *, + warn_only: _bool = ..., +) -> None: ... # THPModule_setDeterministicAlgorithms +def _get_deterministic_fill_uninitialized_memory() -> _bool: ... # THPModule_deterministicFillUninitializedMemory +def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ... # THPModule_setDeterministicFillUninitializedMemory +def _get_nnpack_enabled() -> _bool: ... # THPModule_userEnabledNNPACK +def _set_nnpack_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledNNPACK +def _get_warnAlways() -> _bool: ... # THPModule_warnAlways +def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways +def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN +def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN +def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS +def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS +def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecision +def _set_float32_matmul_precision( + arg: str, +) -> None: ... # THPModule_setFloat32MatmulPrecision +def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... # THPModule_allowFP16ReductionCuBLAS +def _set_cublas_allow_fp16_reduced_precision_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS +def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... # THPModule_allowBF16ReductionCuBLAS +def _set_cublas_allow_bf16_reduced_precision_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS +def _set_conj(x: Tensor, conj: _bool) -> None: ... +def _set_neg(x: Tensor, neg: _bool) -> None: ... +def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... +def _meta_in_tls_dispatch_include() -> _bool: ... +def _stash_obj_in_tls(key: str, arg: Any) -> None: ... +def _get_obj_in_tls(key: str) -> Any: ... +def _is_key_in_tls(key: str) -> _bool: ... +def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... +def _conv_determine_backend_memory_format( + input: Tensor, + weight: Tensor, + backend: ConvBackend, +) -> memory_format: ... +def _has_storage(x: Tensor) -> _bool: ... +def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ... +def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... +def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ... + +# NB: There is no Capsule type in typing, see +# https://code.activestate.com/lists/python-dev/139675/ +def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack +def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack +def _get_cpp_backtrace( + frames_to_skip: _int, + maximum_number_of_frames: _int, +) -> str: ... # THPModule_getCppBacktrace +def set_flush_denormal(arg: _bool) -> _bool: ... # THPModule_setFlushDenormal +def get_default_dtype() -> _dtype: ... # THPModule_getDefaultDtype +def _get_default_device() -> str: ... # THPModule_getDefaultDevice +def _get_qengine() -> _int: ... # THPModule_qEngine +def _set_qengine(qengine: _int) -> None: ... # THPModule_setQEngine +def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines +def _is_xnnpack_enabled() -> _bool: ... 
# THPModule_isEnabledXNNPACK +def _check_sparse_tensor_invariants() -> _bool: ... # THPModule_checkSparseTensorInvariants +def _set_check_sparse_tensor_invariants( + arg: _bool, +) -> None: ... # THPModule_setCheckSparseTensorInvariants +def _set_default_mobile_cpu_allocator() -> None: ... # THPModule_setDefaultMobileCPUAllocator +def _unset_default_mobile_cpu_allocator() -> None: ... # THPModule_unsetDefaultMobileCPUAllocator +def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction +def _has_torch_function( + args: Iterable[Any], +) -> _bool: ... # THPModule_has_torch_function +def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary +def _has_torch_function_variadic( + *args: Any, +) -> _bool: ... # THPModule_has_torch_function_variadic +def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting +def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting +def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython +def _log_api_usage_metadata(event: str, metadata_map: Dict[str, str]) -> None: ... # LogAPIUsageMetadataFromPython +def _demangle(str) -> str: ... # c10::demangle +def _disabled_torch_function_impl( + func: Callable, + types: Iterable[Type], + args: Tuple, + kwargs: Dict, +) -> Any: ... # THPModule_disable_torch_function +def _disabled_torch_dispatch_impl( + func: Callable, + types: Iterable[Type], + args: Tuple, + kwargs: Dict, +) -> Any: ... # THPModule_disable_dispatch_function +def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ... +def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ... + +class _LinalgBackend: + Default: _LinalgBackend + Cusolver: _LinalgBackend + Magma: _LinalgBackend + +class ConvBackend(Enum): ... + +class Tag(Enum): + core: _int = 0 + data_dependent_output: _int = 1 + dynamic_output_shape: _int = 2 + generated: _int = 3 + inplace_view: _int = 4 + needs_fixed_stride_order: _int = 5 + nondeterministic_bitwise: _int = 6 + nondeterministic_seeded: _int = 7 + pointwise: _int = 8 + pt2_compliant_tag: _int = 9 + view_copy: _int = 10 + +# Defined in `valgrind.h` and `callgrind.h` respectively. +def _valgrind_supported_platform() -> _bool: ... # NVALGRIND +def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT +def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS + +has_openmp: _bool +has_mkl: _bool +_has_mps: _bool +has_lapack: _bool +_has_cuda: _bool +_has_magma: _bool +_has_xpu: _bool +_has_mkldnn: _bool +_has_cudnn: _bool +has_spectral: _bool +_GLIBCXX_USE_CXX11_ABI: _bool +default_generator: Generator + +# Defined in torch/csrc/autograd/init.cpp +def _set_grad_enabled(enabled: _bool) -> None: ... +def is_grad_enabled() -> _bool: ... +def _set_fwd_grad_enabled(enabled: _bool) -> None: ... +def _is_fwd_grad_enabled() -> _bool: ... +def is_inference_mode_enabled() -> _bool: ... +def set_autocast_enabled(enabled: _bool) -> None: ... +def is_autocast_enabled() -> _bool: ... +def clear_autocast_cache() -> None: ... +def set_autocast_cpu_enabled(enabled: _bool) -> None: ... +def is_autocast_cpu_enabled() -> _bool: ... +def _is_any_autocast_enabled() -> _bool: ... +def set_autocast_cpu_dtype(dtype: _dtype) -> None: ... +def set_autocast_gpu_dtype(dtype: _dtype) -> None: ... +def get_autocast_cpu_dtype() -> _dtype: ... +def get_autocast_gpu_dtype() -> _dtype: ... +def autocast_increment_nesting() -> _int: ... +def autocast_decrement_nesting() -> _int: ... 
+def is_autocast_cache_enabled() -> _bool: ... +def set_autocast_cache_enabled(enabled: _bool) -> None: ... +def _increment_version(tensor: Tensor) -> None: ... +def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ... +def is_anomaly_enabled() -> _bool: ... +def is_anomaly_check_nan_enabled() -> _bool: ... +def _is_multithreading_enabled() -> _bool: ... +def _set_multithreading_enabled(enabled: _bool) -> None: ... +def _set_view_replay_enabled(enabled: _bool) -> None: ... +def _is_view_replay_enabled() -> _bool: ... +def _enter_dual_level() -> _int: ... +def _exit_dual_level(level: _int) -> None: ... +def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ... +def __set_forward_AD_enabled(enabled: _bool) -> None: ... +def __is_forward_AD_enabled() -> _bool: ... +def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ... +def _reset_default_hooks() -> None: ... +def _is_torch_function_mode_enabled() -> _bool: ... +def _set_torch_function_mode(cls: Any) -> None: ... +def _push_on_torch_function_stack(cls: Any) -> None: ... +def _pop_torch_function_stack() -> Any: ... +def _get_function_stack_at(idx: _int) -> Any: ... +def _len_torch_function_stack() -> _int: ... +def _set_torch_dispatch_mode(cls: Any) -> None: ... +def _push_on_torch_dispatch_stack(cls: Any) -> None: ... +def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ... +def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ... +def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Any: ... +def _set_dispatch_mode(mode: Any) -> None: ... +def _get_dispatch_stack_at(idx: _int) -> Any: ... +def _len_torch_dispatch_stack() -> _int: ... + +class _DisableTorchDispatch: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _EnableTorchFunction: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _EnablePythonDispatcher: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _DisablePythonDispatcher: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _EnablePreDispatch: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _DisableFuncTorch: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _DisableAutocast: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _InferenceMode: + def __init__(self, enabled: _bool): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +def _set_autograd_fallback_mode(mode: str) -> None: ... +def _get_autograd_fallback_mode() -> str: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class LoggerBase: ... +class NoopLogger(LoggerBase): ... +class LockingLogger(LoggerBase): ... + +class AggregationType(Enum): + SUM = 0 + AVG = 1 + +class FileCheck: + def run(self, test_string: str) -> None: ... + def check(self, test_string: str) -> FileCheck: ... + def check_not(self, test_string: str) -> FileCheck: ... + def check_same(self, test_string: str) -> FileCheck: ... 
+ def check_next(self, test_string: str) -> FileCheck: ... + def check_count( + self, + test_string: str, + count: _int, + exactly: _bool = False, + ) -> FileCheck: ... + def check_dag(self, test_string: str) -> FileCheck: ... + def check_source_highlighted(self, test_string: str) -> FileCheck: ... + def check_regex(self, test_string: str) -> FileCheck: ... + +# Defined in torch/csrc/jit/python/init.cpp +class PyTorchFileReader: + @overload + def __init__(self, name: str) -> None: ... + @overload + def __init__(self, buffer: BinaryIO) -> None: ... + def get_record(self, name: str) -> bytes: ... + def serialization_id(self) -> str: ... + +class PyTorchFileWriter: + @overload + def __init__(self, name: str) -> None: ... + @overload + def __init__(self, buffer: BinaryIO) -> None: ... + def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ... + def write_end_of_file(self) -> None: ... + def set_min_version(self, version: _int) -> None: ... + def get_all_written_records(self) -> List[str]: ... + def archive_name(self) -> str: ... + def serialization_id(self) -> str: ... + +def _jit_get_inline_everything_mode() -> _bool: ... +def _jit_set_inline_everything_mode(enabled: _bool) -> None: ... +def _jit_get_logging_option() -> str: ... +def _jit_set_logging_option(option: str) -> None: ... +def _jit_set_logging_stream(stream_name: str) -> None: ... +def _jit_pass_cse(Graph) -> _bool: ... +def _jit_pass_dce(Graph) -> None: ... +def _jit_pass_lint(Graph) -> None: ... + +# Defined in torch/csrc/jit/python/python_custom_class.cpp +def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ... + +# Defined in torch/csrc/Module.cpp +def _rename_privateuse1_backend(backend: str) -> None: ... +def _get_privateuse1_backend_name() -> str: ... + +# Defined in torch/csrc/Generator.cpp +class Generator: + device: _device + def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ... + def get_state(self) -> Tensor: ... + def set_state(self, _new_state: Tensor) -> Generator: ... + def set_offset(self, offset: _int) -> Generator: ... + def get_offset(self) -> _int: ... + def manual_seed(self, seed: _int) -> Generator: ... + def seed(self) -> _int: ... + def initial_seed(self) -> _int: ... + +# Defined in torch/csrc/utils/python_dispatch.cpp + +class _DispatchOperatorHandle: + def schema(self) -> FunctionSchema: ... + def debug(self) -> str: ... + +class _DispatchModule: + def def_(self, schema: str, alias: str = "") -> _DispatchModule: ... + def def_legacy(self, schema: str) -> _DispatchModule: ... + def def_name_t_t( + self, + name: str, + dispatch: str, + debug: str = "default_def_name_t_t", + ) -> _DispatchModule: ... + def def_schema_t_t( + self, + schema: str, + dispatch: str, + alias: str, + debug: str = "default_def_schema_t_t", + ) -> _DispatchModule: ... + def impl_t_t( + self, + name: str, + dispatch: str, + debug: str = "impl_t_t", + ) -> _DispatchModule: ... + def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ... + def define(self, schema: str, alias: str = "") -> _DispatchModule: ... + def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ... + +def _dispatch_library( + kind: str, + name: str, + dispatch: str, + file: str = "", + linenum: Any = 0, +) -> _DispatchModule: ... +def _dispatch_dump(name: str) -> str: ... +def _dispatch_dump_table(name: str) -> str: ... +def _dispatch_check_invariants(name: str) -> None: ... +def _dispatch_check_all_invariants() -> None: ... 
+def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ... +def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ... +def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ... +def _dispatch_has_kernel(name: str) -> _bool: ... +def _dispatch_has_kernel_for_dispatch_key( + name: str, + dispatch: _dispatchkey, +) -> _bool: ... +def _dispatch_has_kernel_for_any_dispatch_key( + name: str, + dispatch_key_set: DispatchKeySet, +) -> _bool: ... +def _dispatch_has_computed_kernel_for_dispatch_key( + name: str, + dispatch: _dispatchkey, +) -> _bool: ... +def _dispatch_find_dangling_impls() -> List[str]: ... +def _dispatch_get_all_op_names() -> List[str]: ... +def _dispatch_tls_set_dispatch_key_excluded( + dispatch: _dispatchkey, + val: _bool, +) -> None: ... +def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_tls_set_dispatch_key_included( + dispatch: _dispatchkey, + val: _bool, +) -> None: ... +def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ... +def _dispatch_key_name(dispatch: _dispatchkey) -> str: ... +def _dispatch_key_for_device(device_type: str) -> str: ... +def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ... +def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ... +def _dispatch_num_backends() -> _int: ... +def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ... +def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ... +def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ... +def _functionalization_reapply_views_tls() -> _bool: ... + +class DispatchKey(Enum): + Undefined: DispatchKey = ... + FPGA: DispatchKey = ... + ORT: DispatchKey = ... + Vulkan: DispatchKey = ... + Metal: DispatchKey = ... + MKLDNN: DispatchKey = ... + OpenGL: DispatchKey = ... + OpenCL: DispatchKey = ... + IDEEP: DispatchKey = ... + CustomRNGKeyId: DispatchKey = ... + MkldnnCPU: DispatchKey = ... + Sparse: DispatchKey = ... + SparseCsr: DispatchKey = ... + NestedTensor: DispatchKey = ... + Dense: DispatchKey = ... + PreDispatch: DispatchKey = ... + Python: DispatchKey = ... + FuncTorchDynamicLayerBackMode: DispatchKey = ... + ZeroTensor: DispatchKey = ... + Conjugate: DispatchKey = ... + Negative: DispatchKey = ... + BackendSelect: DispatchKey = ... + Named: DispatchKey = ... + AutogradOther: DispatchKey = ... + AutogradFunctionality: DispatchKey = ... + AutogradNestedTensor: DispatchKey = ... + Tracer: DispatchKey = ... + Autocast: DispatchKey = ... + Batched: DispatchKey = ... + VmapMode: DispatchKey = ... + FuncTorchGradWrapper: DispatchKey = ... + FuncTorchBatched: DispatchKey = ... + BatchedNestedTensor: DispatchKey = ... + FuncTorchVmapMode: DispatchKey = ... + FuncTorchDynamicLayerFrontMode: DispatchKey = ... + Functionalize: DispatchKey = ... + TESTING_ONLY_GenericWrapper: DispatchKey = ... + TESTING_ONLY_GenericMode: DispatchKey = ... + ADInplaceOrView: DispatchKey = ... + Autograd: DispatchKey = ... + CompositeImplicitAutograd: DispatchKey = ... + CompositeImplicitAutogradNestedTensor: DispatchKey = ... + CompositeExplicitAutograd: DispatchKey = ... + CompositeExplicitAutogradNonFunctional: DispatchKey = ... + FuncTorchBatchedDecomposition: DispatchKey = ... + CPU: DispatchKey = ... + CUDA: DispatchKey = ... + HIP: DispatchKey = ... + XLA: DispatchKey = ... 
+ MTIA: DispatchKey = ... + MPS: DispatchKey = ... + IPU: DispatchKey = ... + XPU: DispatchKey = ... + HPU: DispatchKey = ... + VE: DispatchKey = ... + Lazy: DispatchKey = ... + Meta: DispatchKey = ... + PrivateUse1: DispatchKey = ... + PrivateUse2: DispatchKey = ... + PrivateUse3: DispatchKey = ... + QuantizedCPU: DispatchKey = ... + QuantizedCUDA: DispatchKey = ... + QuantizedHIP: DispatchKey = ... + QuantizedXLA: DispatchKey = ... + QuantizedMTIA: DispatchKey = ... + QuantizedMPS: DispatchKey = ... + QuantizedIPU: DispatchKey = ... + QuantizedXPU: DispatchKey = ... + QuantizedHPU: DispatchKey = ... + QuantizedVE: DispatchKey = ... + QuantizedLazy: DispatchKey = ... + QuantizedMeta: DispatchKey = ... + QuantizedPrivateUse1: DispatchKey = ... + QuantizedPrivateUse2: DispatchKey = ... + QuantizedPrivateUse3: DispatchKey = ... + SparseCPU: DispatchKey = ... + SparseCUDA: DispatchKey = ... + SparseHIP: DispatchKey = ... + SparseXLA: DispatchKey = ... + SparseMTIA: DispatchKey = ... + SparseMPS: DispatchKey = ... + SparseIPU: DispatchKey = ... + SparseXPU: DispatchKey = ... + SparseHPU: DispatchKey = ... + SparseVE: DispatchKey = ... + SparseLazy: DispatchKey = ... + SparseMeta: DispatchKey = ... + SparsePrivateUse1: DispatchKey = ... + SparsePrivateUse2: DispatchKey = ... + SparsePrivateUse3: DispatchKey = ... + SparseCsrCPU: DispatchKey = ... + SparseCsrCUDA: DispatchKey = ... + SparseCsrHIP: DispatchKey = ... + SparseCsrXLA: DispatchKey = ... + SparseCsrMTIA: DispatchKey = ... + SparseCsrMPS: DispatchKey = ... + SparseCsrIPU: DispatchKey = ... + SparseCsrXPU: DispatchKey = ... + SparseCsrHPU: DispatchKey = ... + SparseCsrVE: DispatchKey = ... + SparseCsrLazy: DispatchKey = ... + SparseCsrMeta: DispatchKey = ... + SparseCsrPrivateUse1: DispatchKey = ... + SparseCsrPrivateUse2: DispatchKey = ... + SparseCsrPrivateUse3: DispatchKey = ... + NestedTensorCPU: DispatchKey = ... + NestedTensorCUDA: DispatchKey = ... + NestedTensorHIP: DispatchKey = ... + NestedTensorXLA: DispatchKey = ... + NestedTensorMTIA: DispatchKey = ... + NestedTensorMPS: DispatchKey = ... + NestedTensorIPU: DispatchKey = ... + NestedTensorXPU: DispatchKey = ... + NestedTensorHPU: DispatchKey = ... + NestedTensorVE: DispatchKey = ... + NestedTensorLazy: DispatchKey = ... + NestedTensorMeta: DispatchKey = ... + NestedTensorPrivateUse1: DispatchKey = ... + NestedTensorPrivateUse2: DispatchKey = ... + NestedTensorPrivateUse3: DispatchKey = ... + AutogradCPU: DispatchKey = ... + AutogradCUDA: DispatchKey = ... + AutogradHIP: DispatchKey = ... + AutogradXLA: DispatchKey = ... + AutogradMTIA: DispatchKey = ... + AutogradMPS: DispatchKey = ... + AutogradIPU: DispatchKey = ... + AutogradXPU: DispatchKey = ... + AutogradHPU: DispatchKey = ... + AutogradVE: DispatchKey = ... + AutogradLazy: DispatchKey = ... + AutogradMeta: DispatchKey = ... + AutogradPrivateUse1: DispatchKey = ... + AutogradPrivateUse2: DispatchKey = ... + AutogradPrivateUse3: DispatchKey = ... + +class DispatchKeySet: + def __init__(self, key: DispatchKey) -> None: ... + def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def highestPriorityTypeId(self) -> DispatchKey: ... + def has(self, k: _dispatchkey) -> _bool: ... + def add(self, k: _dispatchkey) -> DispatchKeySet: ... + def remove(self, k: _dispatchkey) -> DispatchKeySet: ... + def __repr__(self) -> str: ... 
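+
+# A rough usage sketch for the DispatchKeySet bindings declared above. This is
+# illustrative only: it assumes the private torch._C helpers keep the signatures
+# given in this stub (_dispatch_keys and _dispatch_keyset_to_string are declared
+# just below) and that a dense CPU tensor's key set includes DispatchKey.CPU.
+#
+#   import torch
+#   ks = torch._C._dispatch_keys(torch.ones(2))      # DispatchKeySet for a CPU tensor
+#   ks.has(torch._C.DispatchKey.CPU)                 # expected True on a CPU build
+#   ks2 = ks.add(torch._C.DispatchKey.Lazy)          # add() returns a new DispatchKeySet
+#   print(torch._C._dispatch_keyset_to_string(ks2))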
+ +_dispatch_autogradother_backends: DispatchKeySet +_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet + +def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ... +def _dispatch_keyset_full() -> DispatchKeySet: ... +def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ... +def _dispatch_get_backend_keyset_from_autograd( + dispatch: _dispatchkey, +) -> DispatchKeySet: ... +def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ... +def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ... +def _dispatch_tls_local_include_set() -> DispatchKeySet: ... +def _dispatch_is_included_in_alias( + dispatch_a: _dispatchkey, + dispatch_b: _dispatchkey, +) -> _bool: ... +def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ... +def _replace_(a: Tensor, b: Tensor) -> None: ... +def _commit_update(a: Tensor) -> None: ... + +class _ExcludeDispatchKeyGuard: + def __init__(self, keyset: DispatchKeySet): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _IncludeDispatchKeyGuard: + def __init__(self, k: DispatchKey): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _ForceDispatchKeyGuard: + def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +class _AutoDispatchBelowAutograd: + def __init__(self): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ... +def _dispatch_get_registrations_for_dispatch_key( + dispatch_key: str = "", +) -> List[str]: ... +def _are_functorch_transforms_active() -> _bool: ... + +# Define in torch/csrc/autograd/init.cpp +def _set_python_dispatcher(dispatcher: object) -> None: ... + +def _get_nested_int(id: _int, coeff: _int) -> SymInt: ... + +def _get_constant_bool_symnode(val: _bool) -> Any: ... + +class _TorchDispatchModeKey(Enum): + FAKE: _TorchDispatchModeKey = ... + PROXY: _TorchDispatchModeKey = ... + FUNCTIONAL: _TorchDispatchModeKey = ... + +class _SetExcludeDispatchKeyGuard: + def __init__(self, k: DispatchKey, enabled: _bool): ... + def __enter__(self): ... + def __exit__(self, exc_type, exc_value, traceback): ... + +# Defined in torch/csrc/utils/init.cpp +class BenchmarkConfig: + num_calling_threads: _int + num_worker_threads: _int + num_warmup_iters: _int + num_iters: _int + profiler_output_path: str + +class BenchmarkExecutionStats: + latency_avg_ms: _float + num_iters: _int + +class ThroughputBenchmark: + def __init__(self, module: Any) -> None: ... + def add_input(self, *args: Any, **kwargs: Any) -> None: ... + def run_once(self, *args: Any, **kwargs: Any) -> Any: ... + def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ... + +# Defined in torch/csrc/Storage.cpp +class StorageBase(object): ... + +# TODO: where +class DoubleTensor(Tensor): ... +class FloatTensor(Tensor): ... +class BFloat16Tensor(Tensor): ... +class LongTensor(Tensor): ... +class IntTensor(Tensor): ... +class ShortTensor(Tensor): ... +class HalfTensor(Tensor): ... +class CharTensor(Tensor): ... +class ByteTensor(Tensor): ... +class BoolTensor(Tensor): ... + +# Defined in torch/csrc/autograd/python_engine.cpp +class _ImperativeEngine: + def queue_callback(self, callback: Callable[[], None]) -> None: ... 
+ def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ... + def is_checkpoint_valid(self) -> _bool: ... + +# Defined in torch/csrc/autograd/python_variable.cpp +class _TensorMeta(type): ... + +# Defined in torch/csrc/autograd/python_variable.cpp +class TensorBase(metaclass=_TensorMeta): + requires_grad: _bool + retains_grad: _bool + shape: Size + data: Tensor + names: List[str] + device: _device + dtype: _dtype + layout: _layout + real: Tensor + imag: Tensor + T: Tensor + H: Tensor + mT: Tensor + mH: Tensor + ndim: _int + output_nr: _int + _version: _int + _base: Optional[Tensor] + _cdata: _int + grad_fn: Optional[_Node] + _grad_fn: Any + _grad: Optional[Tensor] + grad: Optional[Tensor] + _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]] + nbytes: _int + itemsize: _int + _has_symbolic_sizes_strides: _bool + def __abs__(self) -> Tensor: ... + def __add__(self, other: Any) -> Tensor: ... + @overload + def __and__(self, other: Tensor) -> Tensor: ... + @overload + def __and__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __and__(self, other: Any) -> Tensor: ... + def __bool__(self) -> builtins.bool: ... + def __complex__(self) -> builtins.complex: ... + def __div__(self, other: Any) -> Tensor: ... + def __eq__(self, other: Any) -> Tensor: ... # type: ignore[override] + def __float__(self) -> builtins.float: ... + def __floordiv__(self, other: Any) -> Tensor: ... + def __ge__(self, other: Any) -> Tensor: ... + def __getitem__(self, indices: Union[Union[SupportsIndex, Union[None, _bool, _int, slice, ellipsis, Tensor], _NestedSequence[Union[None, _bool, _int, slice, ellipsis, Tensor]]], tuple[Union[SupportsIndex, Union[None, _bool, _int, slice, ellipsis, Tensor], _NestedSequence[Union[None, _bool, _int, slice, ellipsis, Tensor]]], ...]]) -> Tensor: ... + def __gt__(self, other: Any) -> Tensor: ... + def __iadd__(self, other: Any) -> Tensor: ... + @overload + def __iand__(self, other: Tensor) -> Tensor: ... + @overload + def __iand__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __iand__(self, other: Any) -> Tensor: ... + def __idiv__(self, other: Any) -> Tensor: ... + def __ifloordiv__(self, other: Any) -> Tensor: ... + @overload + def __ilshift__(self, other: Tensor) -> Tensor: ... + @overload + def __ilshift__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __ilshift__(self, other: Any) -> Tensor: ... + def __imod__(self, other: Any) -> Tensor: ... + def __imul__(self, other: Any) -> Tensor: ... + def __index__(self) -> builtins.int: ... + @overload + def __init__(self, *args: Any, device: Optional[DeviceLikeType] = None) -> None: ... + @overload + def __init__(self, storage: Storage) -> None: ... + @overload + def __init__(self, other: Tensor) -> None: ... + @overload + def __init__(self, size: _size, *, device: Optional[DeviceLikeType] = None) -> None: ... + def __int__(self) -> builtins.int: ... + def __invert__(self) -> Tensor: ... + @overload + def __ior__(self, other: Tensor) -> Tensor: ... + @overload + def __ior__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __ior__(self, other: Any) -> Tensor: ... + @overload + def __irshift__(self, other: Tensor) -> Tensor: ... + @overload + def __irshift__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __irshift__(self, other: Any) -> Tensor: ... + def __isub__(self, other: Any) -> Tensor: ... + @overload + def __ixor__(self, other: Tensor) -> Tensor: ... 
+ @overload + def __ixor__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __ixor__(self, other: Any) -> Tensor: ... + def __le__(self, other: Any) -> Tensor: ... + def __long__(self) -> builtins.int: ... + @overload + def __lshift__(self, other: Tensor) -> Tensor: ... + @overload + def __lshift__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __lshift__(self, other: Any) -> Tensor: ... + def __lt__(self, other: Any) -> Tensor: ... + def __matmul__(self, other: Any) -> Tensor: ... + def __mod__(self, other: Any) -> Tensor: ... + def __mul__(self, other: Any) -> Tensor: ... + def __ne__(self, other: Any) -> Tensor: ... # type: ignore[override] + def __neg__(self) -> Tensor: ... + def __new__(self, *args, **kwargs) -> Tensor: ... + def __nonzero__(self) -> builtins.bool: ... + @overload + def __or__(self, other: Tensor) -> Tensor: ... + @overload + def __or__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __or__(self, other: Any) -> Tensor: ... + def __pow__(self, other: Any) -> Tensor: ... + def __radd__(self, other: Any) -> Tensor: ... + def __rand__(self, other: Any) -> Tensor: ... + def __rfloordiv__(self, other: Any) -> Tensor: ... + def __rmul__(self, other: Any) -> Tensor: ... + def __ror__(self, other: Any) -> Tensor: ... + def __rpow__(self, other: Any) -> Tensor: ... + @overload + def __rshift__(self, other: Tensor) -> Tensor: ... + @overload + def __rshift__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __rshift__(self, other: Any) -> Tensor: ... + def __rsub__(self, other: Any) -> Tensor: ... + def __rtruediv__(self, other: Any) -> Tensor: ... + def __rxor__(self, other: Any) -> Tensor: ... + def __setitem__(self, indices: Union[Union[SupportsIndex, Union[None, _bool, _int, slice, ellipsis, Tensor], _NestedSequence[Union[None, _bool, _int, slice, ellipsis, Tensor]]], tuple[Union[SupportsIndex, Union[None, _bool, _int, slice, ellipsis, Tensor], _NestedSequence[Union[None, _bool, _int, slice, ellipsis, Tensor]]], ...]], val: Union[Tensor, Number]) -> None: ... + def __sub__(self, other: Any) -> Tensor: ... + def __truediv__(self, other: Any) -> Tensor: ... + @overload + def __xor__(self, other: Tensor) -> Tensor: ... + @overload + def __xor__(self, other: Union[Number, _complex]) -> Tensor: ... + @overload + def __xor__(self, other: Any) -> Tensor: ... + def _addmm_activation(self, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, use_gelu: _bool = False) -> Tensor: ... + def _autocast_to_full_precision(self, cuda_enabled: _bool, cpu_enabled: _bool) -> Tensor: ... + def _autocast_to_reduced_precision(self, cuda_enabled: _bool, cpu_enabled: _bool, cuda_dtype: _dtype, cpu_dtype: _dtype) -> Tensor: ... + def _coalesced_(self, coalesced: _bool) -> Tensor: ... + def _conj(self) -> Tensor: ... + def _conj_physical(self) -> Tensor: ... + def _dimI(self) -> _int: ... + def _dimV(self) -> _int: ... + def _indices(self) -> Tensor: ... + def _is_all_true(self) -> Tensor: ... + def _is_any_true(self) -> Tensor: ... + def _is_view(self) -> _bool: ... + def _is_zerotensor(self) -> _bool: ... + def _lazy_clone(self) -> Tensor: ... + @staticmethod + def _make_subclass(cls: Type[S], data: Tensor, require_grad: _bool = False, dispatch_strides: _bool = False, dispatch_device: _bool = False, device_for_backend_keys: Optional[_device] = None) -> S: ... + def _neg_view(self) -> Tensor: ... + def _nested_tensor_size(self) -> Tensor: ... 
+ def _nested_tensor_storage_offsets(self) -> Tensor: ... + def _nested_tensor_strides(self) -> Tensor: ... + def _nnz(self) -> _int: ... + def _sparse_mask_projection(self, mask: Tensor, accumulate_matches: _bool = False) -> Tensor: ... + def _to_dense(self, dtype: Optional[_dtype] = None, masked_grad: Optional[_bool] = None) -> Tensor: ... + @overload + def _to_sparse(self, *, layout: Optional[_layout] = None, blocksize: Optional[Union[_int, _size]] = None, dense_dim: Optional[_int] = None) -> Tensor: ... + @overload + def _to_sparse(self, sparse_dim: _int) -> Tensor: ... + def _to_sparse_bsc(self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None) -> Tensor: ... + def _to_sparse_bsr(self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None) -> Tensor: ... + def _to_sparse_csc(self, dense_dim: Optional[_int] = None) -> Tensor: ... + def _to_sparse_csr(self, dense_dim: Optional[_int] = None) -> Tensor: ... + def _values(self) -> Tensor: ... + def abs(self) -> Tensor: + r""" + abs() -> Tensor + + See :func:`torch.abs` + """ + ... + def abs_(self) -> Tensor: + r""" + abs_() -> Tensor + + In-place version of :meth:`~Tensor.abs` + """ + ... + def absolute(self) -> Tensor: + r""" + absolute() -> Tensor + + Alias for :func:`abs` + """ + ... + def absolute_(self) -> Tensor: + r""" + absolute_() -> Tensor + + In-place version of :meth:`~Tensor.absolute` + Alias for :func:`abs_` + """ + ... + def acos(self) -> Tensor: + r""" + acos() -> Tensor + + See :func:`torch.acos` + """ + ... + def acos_(self) -> Tensor: + r""" + acos_() -> Tensor + + In-place version of :meth:`~Tensor.acos` + """ + ... + def acosh(self) -> Tensor: + r""" + acosh() -> Tensor + + See :func:`torch.acosh` + """ + ... + def acosh_(self) -> Tensor: + r""" + acosh_() -> Tensor + + In-place version of :meth:`~Tensor.acosh` + """ + ... + def add(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + add(other, *, alpha=1) -> Tensor + + Add a scalar or tensor to :attr:`self` tensor. If both :attr:`alpha` + and :attr:`other` are specified, each element of :attr:`other` is scaled by + :attr:`alpha` before being used. + + When :attr:`other` is a tensor, the shape of :attr:`other` must be + :ref:`broadcastable ` with the shape of the underlying + tensor + + See :func:`torch.add` + """ + ... + def add_(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], *, alpha: Optional[Union[Number, _complex]] = 1) -> Tensor: + r""" + add_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.add` + """ + ... + def addbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addbmm` + """ + ... + def addbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addbmm` + """ + ... + def addcdiv(self, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1) -> Tensor: + r""" + addcdiv(tensor1, tensor2, *, value=1) -> Tensor + + See :func:`torch.addcdiv` + """ + ... 
+ def addcdiv_(self, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1) -> Tensor: + r""" + addcdiv_(tensor1, tensor2, *, value=1) -> Tensor + + In-place version of :meth:`~Tensor.addcdiv` + """ + ... + def addcmul(self, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1) -> Tensor: + r""" + addcmul(tensor1, tensor2, *, value=1) -> Tensor + + See :func:`torch.addcmul` + """ + ... + def addcmul_(self, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1) -> Tensor: + r""" + addcmul_(tensor1, tensor2, *, value=1) -> Tensor + + In-place version of :meth:`~Tensor.addcmul` + """ + ... + def addmm(self, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addmm` + """ + ... + def addmm_(self, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addmm` + """ + ... + def addmv(self, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addmv(mat, vec, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addmv` + """ + ... + def addmv_(self, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addmv` + """ + ... + def addr(self, vec1: Tensor, vec2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addr` + """ + ... + def addr_(self, vec1: Tensor, vec2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addr` + """ + ... + def adjoint(self) -> Tensor: + r""" + adjoint() -> Tensor + + Alias for :func:`adjoint` + """ + ... + def align_as(self, other: Tensor) -> Tensor: + r""" + align_as(other) -> Tensor + + Permutes the dimensions of the :attr:`self` tensor to match the dimension order + in the :attr:`other` tensor, adding size-one dims for any new names. + + This operation is useful for explicit broadcasting by names (see examples). + + All of the dims of :attr:`self` must be named in order to use this method. + The resulting tensor is a view on the original tensor. + + All dimension names of :attr:`self` must be present in ``other.names``. + :attr:`other` may contain named dimensions that are not in ``self.names``; + the output tensor has a size-one dimension for each of those new names. + + To align a tensor to a specific order, use :meth:`~Tensor.align_to`. 
+ + Examples:: + + # Example 1: Applying a mask + >>> mask = torch.randint(2, [127, 128], dtype=torch.bool).refine_names('W', 'H') + >>> imgs = torch.randn(32, 128, 127, 3, names=('N', 'H', 'W', 'C')) + >>> imgs.masked_fill_(mask.align_as(imgs), 0) + + + # Example 2: Applying a per-channel-scale + >>> def scale_channels(input, scale): + >>> scale = scale.refine_names('C') + >>> return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names=('C',)) + >>> imgs = torch.rand(32, 128, 128, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(32, num_channels, 128, 128, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 128, 128, 128, names=('N', 'C', 'H', 'W', 'D')) + + # scale_channels is agnostic to the dimension order of the input + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + + .. warning:: + The named tensor API is experimental and subject to change. + """ + ... + @overload + def align_to(self, order: Sequence[Union[str, ellipsis, None]], ellipsis_idx: _int) -> Tensor: ... + @overload + def align_to(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + @overload + def all(self) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + ... + @overload + def all(self, dim: Optional[_size] = None, keepdim: _bool = False) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + ... + @overload + def all(self, dim: _int, keepdim: _bool = False) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + ... + @overload + def all(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + ... + def allclose(self, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> _bool: + r""" + allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + See :func:`torch.allclose` + """ + ... + def amax(self, dim: Union[_int, _size] = (), keepdim: _bool = False) -> Tensor: + r""" + amax(dim=None, keepdim=False) -> Tensor + + See :func:`torch.amax` + """ + ... + def amin(self, dim: Union[_int, _size] = (), keepdim: _bool = False) -> Tensor: + r""" + amin(dim=None, keepdim=False) -> Tensor + + See :func:`torch.amin` + """ + ... + def aminmax(self, *, dim: Optional[_int] = None, keepdim: _bool = False) -> torch.return_types.aminmax: + r""" + aminmax(*, dim=None, keepdim=False) -> (Tensor min, Tensor max) + + See :func:`torch.aminmax` + """ + ... + def angle(self) -> Tensor: + r""" + angle() -> Tensor + + See :func:`torch.angle` + """ + ... + @overload + def any(self) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + ... + @overload + def any(self, dim: Optional[_size] = None, keepdim: _bool = False) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + ... + @overload + def any(self, dim: _int, keepdim: _bool = False) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + ... + @overload + def any(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + ... 
+ def apply_(self, callable: Callable) -> Tensor: + r""" + apply_(callable) -> Tensor + + Applies the function :attr:`callable` to each element in the tensor, replacing + each element with the value returned by :attr:`callable`. + + .. note:: + + This function only works with CPU tensors and should not be used in code + sections that require high performance. + """ + ... + def arccos(self) -> Tensor: + r""" + arccos() -> Tensor + + See :func:`torch.arccos` + """ + ... + def arccos_(self) -> Tensor: + r""" + arccos_() -> Tensor + + In-place version of :meth:`~Tensor.arccos` + """ + ... + def arccosh(self) -> Tensor: + r""" + acosh() -> Tensor + + See :func:`torch.arccosh` + """ + ... + def arccosh_(self) -> Tensor: + r""" + acosh_() -> Tensor + + In-place version of :meth:`~Tensor.arccosh` + """ + ... + def arcsin(self) -> Tensor: + r""" + arcsin() -> Tensor + + See :func:`torch.arcsin` + """ + ... + def arcsin_(self) -> Tensor: + r""" + arcsin_() -> Tensor + + In-place version of :meth:`~Tensor.arcsin` + """ + ... + def arcsinh(self) -> Tensor: + r""" + arcsinh() -> Tensor + + See :func:`torch.arcsinh` + """ + ... + def arcsinh_(self) -> Tensor: + r""" + arcsinh_() -> Tensor + + In-place version of :meth:`~Tensor.arcsinh` + """ + ... + def arctan(self) -> Tensor: + r""" + arctan() -> Tensor + + See :func:`torch.arctan` + """ + ... + def arctan2(self, other: Tensor) -> Tensor: + r""" + arctan2(other) -> Tensor + + See :func:`torch.arctan2` + """ + ... + def arctan2_(self, other: Tensor) -> Tensor: + r""" + atan2_(other) -> Tensor + + In-place version of :meth:`~Tensor.arctan2` + """ + ... + def arctan_(self) -> Tensor: + r""" + arctan_() -> Tensor + + In-place version of :meth:`~Tensor.arctan` + """ + ... + def arctanh(self) -> Tensor: + r""" + arctanh() -> Tensor + + See :func:`torch.arctanh` + """ + ... + def arctanh_(self) -> Tensor: + r""" + arctanh_(other) -> Tensor + + In-place version of :meth:`~Tensor.arctanh` + """ + ... + def argmax(self, dim: Optional[_int] = None, keepdim: _bool = False) -> Tensor: + r""" + argmax(dim=None, keepdim=False) -> LongTensor + + See :func:`torch.argmax` + """ + ... + def argmin(self, dim: Optional[_int] = None, keepdim: _bool = False) -> Tensor: + r""" + argmin(dim=None, keepdim=False) -> LongTensor + + See :func:`torch.argmin` + """ + ... + @overload + def argsort(self, *, stable: _bool, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + ... + @overload + def argsort(self, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + ... + @overload + def argsort(self, dim: Union[str, ellipsis, None], descending: _bool = False) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + ... + def argwhere(self) -> Tensor: + r""" + argwhere() -> Tensor + + See :func:`torch.argwhere` + """ + ... + def as_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided(size, stride, storage_offset=None) -> Tensor + + See :func:`torch.as_strided` + """ + ... 
+ def as_strided_(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided_(size, stride, storage_offset=None) -> Tensor + + In-place version of :meth:`~Tensor.as_strided` + """ + ... + def as_strided_scatter(self, src: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor + + See :func:`torch.as_strided_scatter` + """ + ... + def as_subclass(self, cls: Type[S]) -> S: + r""" + as_subclass(cls) -> Tensor + + Makes a ``cls`` instance with the same data pointer as ``self``. Changes + in the output mirror changes in ``self``, and the output stays attached + to the autograd graph. ``cls`` must be a subclass of ``Tensor``. + """ + ... + def asin(self) -> Tensor: + r""" + asin() -> Tensor + + See :func:`torch.asin` + """ + ... + def asin_(self) -> Tensor: + r""" + asin_() -> Tensor + + In-place version of :meth:`~Tensor.asin` + """ + ... + def asinh(self) -> Tensor: + r""" + asinh() -> Tensor + + See :func:`torch.asinh` + """ + ... + def asinh_(self) -> Tensor: + r""" + asinh_() -> Tensor + + In-place version of :meth:`~Tensor.asinh` + """ + ... + def atan(self) -> Tensor: + r""" + atan() -> Tensor + + See :func:`torch.atan` + """ + ... + def atan2(self, other: Tensor) -> Tensor: + r""" + atan2(other) -> Tensor + + See :func:`torch.atan2` + """ + ... + def atan2_(self, other: Tensor) -> Tensor: + r""" + atan2_(other) -> Tensor + + In-place version of :meth:`~Tensor.atan2` + """ + ... + def atan_(self) -> Tensor: + r""" + atan_() -> Tensor + + In-place version of :meth:`~Tensor.atan` + """ + ... + def atanh(self) -> Tensor: + r""" + atanh() -> Tensor + + See :func:`torch.atanh` + """ + ... + def atanh_(self) -> Tensor: + r""" + atanh_(other) -> Tensor + + In-place version of :meth:`~Tensor.atanh` + """ + ... + def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.baddbmm` + """ + ... + def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.baddbmm` + """ + ... + @overload + def bernoulli(self, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli(*, generator=None) -> Tensor + + Returns a result tensor where each :math:`\texttt{result[i]}` is independently + sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have + floating point ``dtype``, and the result will have the same ``dtype``. + + See :func:`torch.bernoulli` + """ + ... + @overload + def bernoulli(self, p: _float, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli(*, generator=None) -> Tensor + + Returns a result tensor where each :math:`\texttt{result[i]}` is independently + sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have + floating point ``dtype``, and the result will have the same ``dtype``. + + See :func:`torch.bernoulli` + """ + ... 
+ @overload + def bernoulli_(self, p: Tensor, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli_(p=0.5, *, generator=None) -> Tensor + + Fills each location of :attr:`self` with an independent sample from + :math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral + ``dtype``. + + :attr:`p` should either be a scalar or tensor containing probabilities to be + used for drawing the binary random number. + + If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor + will be set to a value sampled from + :math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have + floating point ``dtype``. + + See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` + """ + ... + @overload + def bernoulli_(self, p: _float = 0.5, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli_(p=0.5, *, generator=None) -> Tensor + + Fills each location of :attr:`self` with an independent sample from + :math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral + ``dtype``. + + :attr:`p` should either be a scalar or tensor containing probabilities to be + used for drawing the binary random number. + + If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor + will be set to a value sampled from + :math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have + floating point ``dtype``. + + See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` + """ + ... + def bfloat16(self) -> Tensor: + r""" + bfloat16(memory_format=torch.preserve_format) -> Tensor + ``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def bincount(self, weights: Optional[Tensor] = None, minlength: _int = 0) -> Tensor: + r""" + bincount(weights=None, minlength=0) -> Tensor + + See :func:`torch.bincount` + """ + ... + @overload + def bitwise_and(self, other: Tensor) -> Tensor: + r""" + bitwise_and() -> Tensor + + See :func:`torch.bitwise_and` + """ + ... + @overload + def bitwise_and(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_and() -> Tensor + + See :func:`torch.bitwise_and` + """ + ... + @overload + def bitwise_and_(self, other: Tensor) -> Tensor: + r""" + bitwise_and_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_and` + """ + ... + @overload + def bitwise_and_(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_and_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_and` + """ + ... + @overload + def bitwise_left_shift(self, other: Tensor) -> Tensor: + r""" + bitwise_left_shift(other) -> Tensor + + See :func:`torch.bitwise_left_shift` + """ + ... + @overload + def bitwise_left_shift(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_left_shift(other) -> Tensor + + See :func:`torch.bitwise_left_shift` + """ + ... + @overload + def bitwise_left_shift_(self, other: Tensor) -> Tensor: + r""" + bitwise_left_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_left_shift` + """ + ... + @overload + def bitwise_left_shift_(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_left_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_left_shift` + """ + ... + def bitwise_not(self) -> Tensor: + r""" + bitwise_not() -> Tensor + + See :func:`torch.bitwise_not` + """ + ... 
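+    # Illustrative sketch for the elementwise bitwise methods above, using small
+    # integer literals so the bit patterns are easy to verify by hand.
+    #   >>> a = torch.tensor([0b1100])              # 12
+    #   >>> a.bitwise_and(torch.tensor([0b1010]))   # 12 & 10
+    #   tensor([8])
+    #   >>> a.bitwise_left_shift(torch.tensor([1])) # 12 << 1
+    #   tensor([24])
+    #   >>> a.bitwise_not()                         # ~12
+    #   tensor([-13])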
+ def bitwise_not_(self) -> Tensor: + r""" + bitwise_not_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_not` + """ + ... + @overload + def bitwise_or(self, other: Tensor) -> Tensor: + r""" + bitwise_or() -> Tensor + + See :func:`torch.bitwise_or` + """ + ... + @overload + def bitwise_or(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_or() -> Tensor + + See :func:`torch.bitwise_or` + """ + ... + @overload + def bitwise_or_(self, other: Tensor) -> Tensor: + r""" + bitwise_or_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_or` + """ + ... + @overload + def bitwise_or_(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_or_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_or` + """ + ... + @overload + def bitwise_right_shift(self, other: Tensor) -> Tensor: + r""" + bitwise_right_shift(other) -> Tensor + + See :func:`torch.bitwise_right_shift` + """ + ... + @overload + def bitwise_right_shift(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_right_shift(other) -> Tensor + + See :func:`torch.bitwise_right_shift` + """ + ... + @overload + def bitwise_right_shift_(self, other: Tensor) -> Tensor: + r""" + bitwise_right_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_right_shift` + """ + ... + @overload + def bitwise_right_shift_(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_right_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_right_shift` + """ + ... + @overload + def bitwise_xor(self, other: Tensor) -> Tensor: + r""" + bitwise_xor() -> Tensor + + See :func:`torch.bitwise_xor` + """ + ... + @overload + def bitwise_xor(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_xor() -> Tensor + + See :func:`torch.bitwise_xor` + """ + ... + @overload + def bitwise_xor_(self, other: Tensor) -> Tensor: + r""" + bitwise_xor_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_xor` + """ + ... + @overload + def bitwise_xor_(self, other: Union[Number, _complex]) -> Tensor: + r""" + bitwise_xor_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_xor` + """ + ... + def bmm(self, mat2: Tensor) -> Tensor: + r""" + bmm(batch2) -> Tensor + + See :func:`torch.bmm` + """ + ... + def bool(self) -> Tensor: + r""" + bool(memory_format=torch.preserve_format) -> Tensor + + ``self.bool()`` is equivalent to ``self.to(torch.bool)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + @overload + def broadcast_to(self, size: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + broadcast_to(shape) -> Tensor + + See :func:`torch.broadcast_to`. + """ + ... + @overload + def broadcast_to(self, *size: _int) -> Tensor: + r""" + broadcast_to(shape) -> Tensor + + See :func:`torch.broadcast_to`. + """ + ... + def byte(self) -> Tensor: + r""" + byte(memory_format=torch.preserve_format) -> Tensor + + ``self.byte()`` is equivalent to ``self.to(torch.uint8)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def cauchy_(self, median: _float = 0, sigma: _float = 1, *, generator: Optional[Generator] = None) -> Tensor: + r""" + cauchy_(median=0, sigma=1, *, generator=None) -> Tensor + + Fills the tensor with numbers drawn from the Cauchy distribution: + + .. 
math:: + + f(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2} + + .. note:: + Sigma (:math:`\sigma`) is used to denote the scale parameter in Cauchy distribution. + """ + ... + def ccol_indices(self) -> Tensor: ... + def ceil(self) -> Tensor: + r""" + ceil() -> Tensor + + See :func:`torch.ceil` + """ + ... + def ceil_(self) -> Tensor: + r""" + ceil_() -> Tensor + + In-place version of :meth:`~Tensor.ceil` + """ + ... + def chalf(self, *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + chalf(memory_format=torch.preserve_format) -> Tensor + + ``self.chalf()`` is equivalent to ``self.to(torch.complex32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def char(self) -> Tensor: + r""" + char(memory_format=torch.preserve_format) -> Tensor + + ``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def cholesky(self, upper: _bool = False) -> Tensor: + r""" + cholesky(upper=False) -> Tensor + + See :func:`torch.cholesky` + """ + ... + def cholesky_inverse(self, upper: _bool = False) -> Tensor: + r""" + cholesky_inverse(upper=False) -> Tensor + + See :func:`torch.cholesky_inverse` + """ + ... + def cholesky_solve(self, input2: Tensor, upper: _bool = False) -> Tensor: + r""" + cholesky_solve(input2, upper=False) -> Tensor + + See :func:`torch.cholesky_solve` + """ + ... + def chunk(self, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + chunk(chunks, dim=0) -> List of Tensors + + See :func:`torch.chunk` + """ + ... + @overload + def clamp(self, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: + r""" + clamp(min=None, max=None) -> Tensor + + See :func:`torch.clamp` + """ + ... + @overload + def clamp(self, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: + r""" + clamp(min=None, max=None) -> Tensor + + See :func:`torch.clamp` + """ + ... + @overload + def clamp_(self, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: + r""" + clamp_(min=None, max=None) -> Tensor + + In-place version of :meth:`~Tensor.clamp` + """ + ... + @overload + def clamp_(self, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: + r""" + clamp_(min=None, max=None) -> Tensor + + In-place version of :meth:`~Tensor.clamp` + """ + ... + @overload + def clamp_max(self, max: Tensor) -> Tensor: ... + @overload + def clamp_max(self, max: Union[Number, _complex]) -> Tensor: ... + @overload + def clamp_max_(self, max: Tensor) -> Tensor: ... + @overload + def clamp_max_(self, max: Union[Number, _complex]) -> Tensor: ... + @overload + def clamp_min(self, min: Tensor) -> Tensor: ... + @overload + def clamp_min(self, min: Union[Number, _complex]) -> Tensor: ... + @overload + def clamp_min_(self, min: Tensor) -> Tensor: ... + @overload + def clamp_min_(self, min: Union[Number, _complex]) -> Tensor: ... + @overload + def clip(self, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: + r""" + clip(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp`. + """ + ... 
+ @overload + def clip(self, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: + r""" + clip(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp`. + """ + ... + @overload + def clip_(self, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: + r""" + clip_(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp_`. + """ + ... + @overload + def clip_(self, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: + r""" + clip_(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp_`. + """ + ... + def clone(self, *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + clone(*, memory_format=torch.preserve_format) -> Tensor + + See :func:`torch.clone` + """ + ... + def coalesce(self) -> Tensor: + r""" + coalesce() -> Tensor + + Returns a coalesced copy of :attr:`self` if :attr:`self` is an + :ref:`uncoalesced tensor `. + + Returns :attr:`self` if :attr:`self` is a coalesced tensor. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + """ + ... + def col_indices(self) -> Tensor: + r""" + col_indices() -> IntTensor + + Returns the tensor containing the column indices of the :attr:`self` + tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. + The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz()) + and of type ``int32`` or ``int64``. When using MKL routines such as sparse + matrix multiplication, it is necessary to use ``int32`` indexing in order + to avoid downcasting and potentially losing information. + + Example:: + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.col_indices() + tensor([0, 1, 2, 3, 4], dtype=torch.int32) + """ + ... + def conj(self) -> Tensor: + r""" + conj() -> Tensor + + See :func:`torch.conj` + """ + ... + def conj_physical(self) -> Tensor: + r""" + conj_physical() -> Tensor + + See :func:`torch.conj_physical` + """ + ... + def conj_physical_(self) -> Tensor: + r""" + conj_physical_() -> Tensor + + In-place version of :meth:`~Tensor.conj_physical` + """ + ... + def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: + r""" + contiguous(memory_format=torch.contiguous_format) -> Tensor + + Returns a contiguous in memory tensor containing the same data as :attr:`self` tensor. If + :attr:`self` tensor is already in the specified memory format, this function returns the + :attr:`self` tensor. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + """ + ... + def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: + r""" + copy_(src, non_blocking=False) -> Tensor + + Copies the elements from :attr:`src` into :attr:`self` tensor and returns + :attr:`self`. + + The :attr:`src` tensor must be :ref:`broadcastable ` + with the :attr:`self` tensor. It may be of a different data type or reside on a + different device. + + Args: + src (Tensor): the source tensor to copy from + non_blocking (bool): if ``True`` and this copy is between CPU and GPU, + the copy may occur asynchronously with respect to the host. For other + cases, this argument has no effect. + """ + ... + @overload + def copysign(self, other: Tensor) -> Tensor: + r""" + copysign(other) -> Tensor + + See :func:`torch.copysign` + """ + ... 
+ @overload + def copysign(self, other: Union[Number, _complex]) -> Tensor: + r""" + copysign(other) -> Tensor + + See :func:`torch.copysign` + """ + ... + @overload + def copysign_(self, other: Tensor) -> Tensor: + r""" + copysign_(other) -> Tensor + + In-place version of :meth:`~Tensor.copysign` + """ + ... + @overload + def copysign_(self, other: Union[Number, _complex]) -> Tensor: + r""" + copysign_(other) -> Tensor + + In-place version of :meth:`~Tensor.copysign` + """ + ... + def corrcoef(self) -> Tensor: + r""" + corrcoef() -> Tensor + + See :func:`torch.corrcoef` + """ + ... + def cos(self) -> Tensor: + r""" + cos() -> Tensor + + See :func:`torch.cos` + """ + ... + def cos_(self) -> Tensor: + r""" + cos_() -> Tensor + + In-place version of :meth:`~Tensor.cos` + """ + ... + def cosh(self) -> Tensor: + r""" + cosh() -> Tensor + + See :func:`torch.cosh` + """ + ... + def cosh_(self) -> Tensor: + r""" + cosh_() -> Tensor + + In-place version of :meth:`~Tensor.cosh` + """ + ... + @overload + def count_nonzero(self, dim: Optional[_int] = None) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + ... + @overload + def count_nonzero(self, dim: _size) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + ... + @overload + def count_nonzero(self, *dim: _int) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + ... + def cov(self, *, correction: _int = 1, fweights: Optional[Tensor] = None, aweights: Optional[Tensor] = None) -> Tensor: + r""" + cov(*, correction=1, fweights=None, aweights=None) -> Tensor + + See :func:`torch.cov` + """ + ... + def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: + r""" + cpu(memory_format=torch.preserve_format) -> Tensor + + Returns a copy of this object in CPU memory. + + If this object is already in CPU memory and on the correct device, + then no copy is performed and the original object is returned. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def cross(self, other: Tensor, dim: Optional[_int] = None) -> Tensor: + r""" + cross(other, dim=None) -> Tensor + + See :func:`torch.cross` + """ + ... + def crow_indices(self) -> Tensor: + r""" + crow_indices() -> IntTensor + + Returns the tensor containing the compressed row indices of the :attr:`self` + tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. + The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1) + and of type ``int32`` or ``int64``. When using MKL routines such as sparse + matrix multiplication, it is necessary to use ``int32`` indexing in order + to avoid downcasting and potentially losing information. + + Example:: + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.crow_indices() + tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32) + """ + ... + def cuda(self, device: Optional[Union[_device, _int, str]] = None, non_blocking: _bool = False, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: + r""" + cuda(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, + then no copy is performed and the original object is returned. + + Args: + device (:class:`torch.device`): The destination GPU device. 
+ Defaults to the current CUDA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + @overload + def cummax(self, dim: _int) -> torch.return_types.cummax: + r""" + cummax(dim) -> (Tensor, Tensor) + + See :func:`torch.cummax` + """ + ... + @overload + def cummax(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummax: + r""" + cummax(dim) -> (Tensor, Tensor) + + See :func:`torch.cummax` + """ + ... + @overload + def cummin(self, dim: _int) -> torch.return_types.cummin: + r""" + cummin(dim) -> (Tensor, Tensor) + + See :func:`torch.cummin` + """ + ... + @overload + def cummin(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummin: + r""" + cummin(dim) -> (Tensor, Tensor) + + See :func:`torch.cummin` + """ + ... + @overload + def cumprod(self, dim: _int, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumprod(dim, dtype=None) -> Tensor + + See :func:`torch.cumprod` + """ + ... + @overload + def cumprod(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumprod(dim, dtype=None) -> Tensor + + See :func:`torch.cumprod` + """ + ... + @overload + def cumprod_(self, dim: _int, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumprod_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumprod` + """ + ... + @overload + def cumprod_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumprod_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumprod` + """ + ... + @overload + def cumsum(self, dim: _int, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumsum(dim, dtype=None) -> Tensor + + See :func:`torch.cumsum` + """ + ... + @overload + def cumsum(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumsum(dim, dtype=None) -> Tensor + + See :func:`torch.cumsum` + """ + ... + @overload + def cumsum_(self, dim: _int, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumsum_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumsum` + """ + ... + @overload + def cumsum_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + cumsum_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumsum` + """ + ... + def data_ptr(self) -> _int: + r""" + data_ptr() -> int + + Returns the address of the first element of :attr:`self` tensor. + """ + ... + def deg2rad(self) -> Tensor: + r""" + deg2rad() -> Tensor + + See :func:`torch.deg2rad` + """ + ... + def deg2rad_(self) -> Tensor: + r""" + deg2rad_() -> Tensor + + In-place version of :meth:`~Tensor.deg2rad` + """ + ... + def dense_dim(self) -> _int: + r""" + dense_dim() -> int + + Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. + + .. note:: + Returns ``len(self.shape)`` if :attr:`self` is not a sparse tensor. + + See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. + """ + ... + def dequantize(self) -> Tensor: + r""" + dequantize() -> Tensor + + Given a quantized Tensor, dequantize it and return the dequantized float Tensor. + """ + ... + def det(self) -> Tensor: + r""" + det() -> Tensor + + See :func:`torch.det` + """ + ... + def detach(self) -> Tensor: ... 
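+    # Illustrative sketch for clamp() and the cumulative reductions documented
+    # above (outputs assume default CPU float tensors).
+    #   >>> torch.tensor([-2.0, 0.5, 3.0]).clamp(min=0.0, max=1.0)
+    #   tensor([0.0000, 0.5000, 1.0000])
+    #   >>> torch.tensor([1., 2., 3.]).cumsum(dim=0)
+    #   tensor([1., 3., 6.])
+    #   >>> torch.tensor([1., 2., 3.]).cumprod(dim=0)
+    #   tensor([1., 2., 6.])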
+ def detach_(self) -> Tensor: ... + def diag(self, diagonal: _int = 0) -> Tensor: + r""" + diag(diagonal=0) -> Tensor + + See :func:`torch.diag` + """ + ... + def diag_embed(self, offset: _int = 0, dim1: _int = -2, dim2: _int = -1) -> Tensor: + r""" + diag_embed(offset=0, dim1=-2, dim2=-1) -> Tensor + + See :func:`torch.diag_embed` + """ + ... + def diagflat(self, offset: _int = 0) -> Tensor: + r""" + diagflat(offset=0) -> Tensor + + See :func:`torch.diagflat` + """ + ... + @overload + def diagonal(self, *, outdim: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None], dim2: Union[str, ellipsis, None], offset: _int = 0) -> Tensor: + r""" + diagonal(offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal` + """ + ... + @overload + def diagonal(self, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal(offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal` + """ + ... + def diagonal_scatter(self, src: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal_scatter` + """ + ... + def diff(self, n: _int = 1, dim: _int = -1, prepend: Optional[Tensor] = None, append: Optional[Tensor] = None) -> Tensor: + r""" + diff(n=1, dim=-1, prepend=None, append=None) -> Tensor + + See :func:`torch.diff` + """ + ... + def digamma(self) -> Tensor: + r""" + digamma() -> Tensor + + See :func:`torch.digamma` + """ + ... + def digamma_(self) -> Tensor: + r""" + digamma_() -> Tensor + + In-place version of :meth:`~Tensor.digamma` + """ + ... + def dim(self) -> _int: + r""" + dim() -> int + + Returns the number of dimensions of :attr:`self` tensor. + """ + ... + def dist(self, other: Tensor, p: Union[Number, _complex] = 2) -> Tensor: + r""" + dist(other, p=2) -> Tensor + + See :func:`torch.dist` + """ + ... + def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: + r""" + div(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.div` + """ + ... + def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: + r""" + div_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.div` + """ + ... + @overload + def divide(self, other: Tensor) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + ... + @overload + def divide(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + ... + @overload + def divide(self, other: Union[Number, _complex], *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + ... + @overload + def divide(self, other: Union[Number, _complex]) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + ... + @overload + def divide_(self, other: Tensor) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + ... + @overload + def divide_(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + ... 
+ @overload + def divide_(self, other: Union[Number, _complex], *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + ... + @overload + def divide_(self, other: Union[Number, _complex]) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + ... + def dot(self, tensor: Tensor) -> Tensor: + r""" + dot(other) -> Tensor + + See :func:`torch.dot` + """ + ... + def double(self) -> Tensor: + r""" + double(memory_format=torch.preserve_format) -> Tensor + + ``self.double()`` is equivalent to ``self.to(torch.float64)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + @overload + def dsplit(self, sections: _int) -> Tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + ... + @overload + def dsplit(self, indices: _size) -> Tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + ... + @overload + def dsplit(self, *indices: _int) -> Tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + ... + def element_size(self) -> _int: + r""" + element_size() -> int + + Returns the size in bytes of an individual element. + + Example:: + + >>> torch.tensor([]).element_size() + 4 + >>> torch.tensor([], dtype=torch.uint8).element_size() + 1 + """ + ... + @overload + def eq(self, other: Tensor) -> Tensor: + r""" + eq(other) -> Tensor + + See :func:`torch.eq` + """ + ... + @overload + def eq(self, other: Union[Number, _complex]) -> Tensor: + r""" + eq(other) -> Tensor + + See :func:`torch.eq` + """ + ... + @overload + def eq_(self, other: Tensor) -> Tensor: + r""" + eq_(other) -> Tensor + + In-place version of :meth:`~Tensor.eq` + """ + ... + @overload + def eq_(self, other: Union[Number, _complex]) -> Tensor: + r""" + eq_(other) -> Tensor + + In-place version of :meth:`~Tensor.eq` + """ + ... + def equal(self, other: Tensor) -> _bool: + r""" + equal(other) -> bool + + See :func:`torch.equal` + """ + ... + def erf(self) -> Tensor: + r""" + erf() -> Tensor + + See :func:`torch.erf` + """ + ... + def erf_(self) -> Tensor: + r""" + erf_() -> Tensor + + In-place version of :meth:`~Tensor.erf` + """ + ... + def erfc(self) -> Tensor: + r""" + erfc() -> Tensor + + See :func:`torch.erfc` + """ + ... + def erfc_(self) -> Tensor: + r""" + erfc_() -> Tensor + + In-place version of :meth:`~Tensor.erfc` + """ + ... + def erfinv(self) -> Tensor: + r""" + erfinv() -> Tensor + + See :func:`torch.erfinv` + """ + ... + def erfinv_(self) -> Tensor: + r""" + erfinv_() -> Tensor + + In-place version of :meth:`~Tensor.erfinv` + """ + ... + def exp(self) -> Tensor: + r""" + exp() -> Tensor + + See :func:`torch.exp` + """ + ... + def exp2(self) -> Tensor: + r""" + exp2() -> Tensor + + See :func:`torch.exp2` + """ + ... + def exp2_(self) -> Tensor: + r""" + exp2_() -> Tensor + + In-place version of :meth:`~Tensor.exp2` + """ + ... + def exp_(self) -> Tensor: + r""" + exp_() -> Tensor + + In-place version of :meth:`~Tensor.exp` + """ + ... 
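+    # Illustrative sketch for the rounding_mode keyword of div()/divide() above:
+    # None keeps true division, 'trunc' rounds toward zero, 'floor' toward -inf.
+    #   >>> t = torch.tensor([-7.0, 7.0])
+    #   >>> t.div(2)
+    #   tensor([-3.5000,  3.5000])
+    #   >>> t.div(2, rounding_mode='trunc')
+    #   tensor([-3.,  3.])
+    #   >>> t.div(2, rounding_mode='floor')
+    #   tensor([-4.,  3.])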
+ @overload + def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool = False) -> Tensor: + r""" + expand(*sizes) -> Tensor + + Returns a new view of the :attr:`self` tensor with singleton dimensions expanded + to a larger size. + + Passing -1 as the size for a dimension means not changing the size of + that dimension. + + Tensor can be also expanded to a larger number of dimensions, and the + new ones will be appended at the front. For the new dimensions, the + size cannot be set to -1. + + Expanding a tensor does not allocate new memory, but only creates a + new view on the existing tensor where a dimension of size one is + expanded to a larger size by setting the ``stride`` to 0. Any dimension + of size 1 can be expanded to an arbitrary value without allocating new + memory. + + Args: + *sizes (torch.Size or int...): the desired expanded size + + .. warning:: + + More than one element of an expanded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + + Example:: + + >>> x = torch.tensor([[1], [2], [3]]) + >>> x.size() + torch.Size([3, 1]) + >>> x.expand(3, 4) + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + >>> x.expand(-1, 4) # -1 means not changing the size of that dimension + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + """ + ... + @overload + def expand(self, *size: _int, implicit: _bool = False) -> Tensor: + r""" + expand(*sizes) -> Tensor + + Returns a new view of the :attr:`self` tensor with singleton dimensions expanded + to a larger size. + + Passing -1 as the size for a dimension means not changing the size of + that dimension. + + Tensor can be also expanded to a larger number of dimensions, and the + new ones will be appended at the front. For the new dimensions, the + size cannot be set to -1. + + Expanding a tensor does not allocate new memory, but only creates a + new view on the existing tensor where a dimension of size one is + expanded to a larger size by setting the ``stride`` to 0. Any dimension + of size 1 can be expanded to an arbitrary value without allocating new + memory. + + Args: + *sizes (torch.Size or int...): the desired expanded size + + .. warning:: + + More than one element of an expanded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + + Example:: + + >>> x = torch.tensor([[1], [2], [3]]) + >>> x.size() + torch.Size([3, 1]) + >>> x.expand(3, 4) + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + >>> x.expand(-1, 4) # -1 means not changing the size of that dimension + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + """ + ... + def expand_as(self, other: Tensor) -> Tensor: + r""" + expand_as(other) -> Tensor + + Expand this tensor to the same size as :attr:`other`. + ``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``. + + Please see :meth:`~Tensor.expand` for more information about ``expand``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. + """ + ... + def expm1(self) -> Tensor: + r""" + expm1() -> Tensor + + See :func:`torch.expm1` + """ + ... + def expm1_(self) -> Tensor: + r""" + expm1_() -> Tensor + + In-place version of :meth:`~Tensor.expm1` + """ + ... 
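+    # Small illustrative addition to the expand() docstring above: the expanded
+    # dimension gets stride 0, so no new memory is allocated for the view.
+    #   >>> x = torch.tensor([[1], [2], [3]])
+    #   >>> x.expand(3, 4).stride()
+    #   (1, 0)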
+ def exponential_(self, lambd: _float = 1, *, generator: Optional[Generator] = None) -> Tensor: + r""" + exponential_(lambd=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements drawn from the PDF (probability density function): + + .. math:: + + f(x) = \lambda e^{-\lambda x}, x > 0 + + .. note:: + In probability theory, exponential distribution is supported on interval [0, :math:`\inf`) (i.e., :math:`x >= 0`) + implying that zero can be sampled from the exponential distribution. + However, :func:`torch.Tensor.exponential_` does not sample zero, + which means that its actual support is the interval (0, :math:`\inf`). + + Note that :func:`torch.distributions.exponential.Exponential` is supported on the interval [0, :math:`\inf`) and can sample zero. + """ + ... + @overload + def fill_(self, value: Tensor) -> Tensor: + r""" + fill_(value) -> Tensor + + Fills :attr:`self` tensor with the specified value. + """ + ... + @overload + def fill_(self, value: Union[Number, _complex]) -> Tensor: + r""" + fill_(value) -> Tensor + + Fills :attr:`self` tensor with the specified value. + """ + ... + def fill_diagonal_(self, fill_value: Union[Number, _complex], wrap: _bool = False) -> Tensor: + r""" + fill_diagonal_(fill_value, wrap=False) -> Tensor + + Fill the main diagonal of a tensor that has at least 2-dimensions. + When dims>2, all dimensions of input must be of equal length. + This function modifies the input tensor in-place, and returns the input tensor. + + Arguments: + fill_value (Scalar): the fill value + wrap (bool): the diagonal 'wrapped' after N columns for tall matrices. + + Example:: + + >>> a = torch.zeros(3, 3) + >>> a.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + >>> b = torch.zeros(7, 3) + >>> b.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> c = torch.zeros(7, 3) + >>> c.fill_diagonal_(5, wrap=True) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + """ + ... + def fix(self) -> Tensor: + r""" + fix() -> Tensor + + See :func:`torch.fix`. + """ + ... + def fix_(self) -> Tensor: + r""" + fix_() -> Tensor + + In-place version of :meth:`~Tensor.fix` + """ + ... + @overload + def flatten(self, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + ... + @overload + def flatten(self, start_dim: _int, end_dim: _int, out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + ... + @overload + def flatten(self, start_dim: Union[str, ellipsis, None], end_dim: Union[str, ellipsis, None], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + ... + @overload + def flatten(self, dims: Sequence[Union[str, ellipsis, None]], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + ... + @overload + def flip(self, dims: _size) -> Tensor: + r""" + flip(dims) -> Tensor + + See :func:`torch.flip` + """ + ... + @overload + def flip(self, *dims: _int) -> Tensor: + r""" + flip(dims) -> Tensor + + See :func:`torch.flip` + """ + ... + def fliplr(self) -> Tensor: + r""" + fliplr() -> Tensor + + See :func:`torch.fliplr` + """ + ... 
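+    # Illustrative sketch for flatten() above; start_dim limits the collapse to
+    # the trailing dimensions.
+    #   >>> t = torch.arange(8).reshape(2, 2, 2)
+    #   >>> t.flatten().shape
+    #   torch.Size([8])
+    #   >>> t.flatten(start_dim=1).shape
+    #   torch.Size([2, 4])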
+ def flipud(self) -> Tensor: + r""" + flipud() -> Tensor + + See :func:`torch.flipud` + """ + ... + def float(self) -> Tensor: + r""" + float(memory_format=torch.preserve_format) -> Tensor + + ``self.float()`` is equivalent to ``self.to(torch.float32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + @overload + def float_power(self, exponent: Tensor) -> Tensor: + r""" + float_power(exponent) -> Tensor + + See :func:`torch.float_power` + """ + ... + @overload + def float_power(self, exponent: Union[Number, _complex]) -> Tensor: + r""" + float_power(exponent) -> Tensor + + See :func:`torch.float_power` + """ + ... + @overload + def float_power_(self, exponent: Tensor) -> Tensor: + r""" + float_power_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.float_power` + """ + ... + @overload + def float_power_(self, exponent: Union[Number, _complex]) -> Tensor: + r""" + float_power_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.float_power` + """ + ... + def floor(self) -> Tensor: + r""" + floor() -> Tensor + + See :func:`torch.floor` + """ + ... + def floor_(self) -> Tensor: + r""" + floor_() -> Tensor + + In-place version of :meth:`~Tensor.floor` + """ + ... + def floor_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor] = None) -> Tensor: + r""" + floor_divide(value) -> Tensor + + See :func:`torch.floor_divide` + """ + ... + def floor_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: + r""" + floor_divide_(value) -> Tensor + + In-place version of :meth:`~Tensor.floor_divide` + """ + ... + def fmax(self, other: Tensor) -> Tensor: + r""" + fmax(other) -> Tensor + + See :func:`torch.fmax` + """ + ... + def fmin(self, other: Tensor) -> Tensor: + r""" + fmin(other) -> Tensor + + See :func:`torch.fmin` + """ + ... + @overload + def fmod(self, other: Tensor) -> Tensor: + r""" + fmod(divisor) -> Tensor + + See :func:`torch.fmod` + """ + ... + @overload + def fmod(self, other: Union[Number, _complex]) -> Tensor: + r""" + fmod(divisor) -> Tensor + + See :func:`torch.fmod` + """ + ... + @overload + def fmod_(self, other: Tensor) -> Tensor: + r""" + fmod_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.fmod` + """ + ... + @overload + def fmod_(self, other: Union[Number, _complex]) -> Tensor: + r""" + fmod_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.fmod` + """ + ... + def frac(self) -> Tensor: + r""" + frac() -> Tensor + + See :func:`torch.frac` + """ + ... + def frac_(self) -> Tensor: + r""" + frac_() -> Tensor + + In-place version of :meth:`~Tensor.frac` + """ + ... + def frexp(self) -> torch.return_types.frexp: + r""" + frexp(input) -> (Tensor mantissa, Tensor exponent) + + See :func:`torch.frexp` + """ + ... + @overload + def gather(self, dim: _int, index: Tensor, *, sparse_grad: _bool = False) -> Tensor: + r""" + gather(dim, index) -> Tensor + + See :func:`torch.gather` + """ + ... + @overload + def gather(self, dim: Union[str, ellipsis, None], index: Tensor, *, sparse_grad: _bool = False) -> Tensor: + r""" + gather(dim, index) -> Tensor + + See :func:`torch.gather` + """ + ... + def gcd(self, other: Tensor) -> Tensor: + r""" + gcd(other) -> Tensor + + See :func:`torch.gcd` + """ + ... + def gcd_(self, other: Tensor) -> Tensor: + r""" + gcd_(other) -> Tensor + + In-place version of :meth:`~Tensor.gcd` + """ + ... 
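+    # Illustrative sketch for gather() above: along dim=1,
+    # out[i][j] = self[i][index[i][j]], and index must be an integer tensor.
+    #   >>> src = torch.tensor([[1, 2], [3, 4]])
+    #   >>> src.gather(1, torch.tensor([[0, 0], [1, 0]]))
+    #   tensor([[1, 1],
+    #           [4, 3]])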
+ @overload + def ge(self, other: Tensor) -> Tensor: + r""" + ge(other) -> Tensor + + See :func:`torch.ge`. + """ + ... + @overload + def ge(self, other: Union[Number, _complex]) -> Tensor: + r""" + ge(other) -> Tensor + + See :func:`torch.ge`. + """ + ... + @overload + def ge_(self, other: Tensor) -> Tensor: + r""" + ge_(other) -> Tensor + + In-place version of :meth:`~Tensor.ge`. + """ + ... + @overload + def ge_(self, other: Union[Number, _complex]) -> Tensor: + r""" + ge_(other) -> Tensor + + In-place version of :meth:`~Tensor.ge`. + """ + ... + def geometric_(self, p: _float, *, generator: Optional[Generator] = None) -> Tensor: + r""" + geometric_(p, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements drawn from the geometric distribution: + + .. math:: + + P(X=k) = (1 - p)^{k - 1} p, k = 1, 2, ... + + .. note:: + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`, whereas + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`. + """ + ... + def geqrf(self) -> torch.return_types.geqrf: + r""" + geqrf() -> (Tensor, Tensor) + + See :func:`torch.geqrf` + """ + ... + def ger(self, vec2: Tensor) -> Tensor: + r""" + ger(vec2) -> Tensor + + See :func:`torch.ger` + """ + ... + def get_device(self) -> _int: + r""" + get_device() -> Device ordinal (Integer) + + For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. + For CPU tensors, this function returns `-1`. + + Example:: + + >>> x = torch.randn(3, 4, 5, device='cuda:0') + >>> x.get_device() + 0 + >>> x.cpu().get_device() + -1 + """ + ... + @overload + def greater(self, other: Tensor) -> Tensor: + r""" + greater(other) -> Tensor + + See :func:`torch.greater`. + """ + ... + @overload + def greater(self, other: Union[Number, _complex]) -> Tensor: + r""" + greater(other) -> Tensor + + See :func:`torch.greater`. + """ + ... + @overload + def greater_(self, other: Tensor) -> Tensor: + r""" + greater_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater`. + """ + ... + @overload + def greater_(self, other: Union[Number, _complex]) -> Tensor: + r""" + greater_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater`. + """ + ... + @overload + def greater_equal(self, other: Tensor) -> Tensor: + r""" + greater_equal(other) -> Tensor + + See :func:`torch.greater_equal`. + """ + ... + @overload + def greater_equal(self, other: Union[Number, _complex]) -> Tensor: + r""" + greater_equal(other) -> Tensor + + See :func:`torch.greater_equal`. + """ + ... + @overload + def greater_equal_(self, other: Tensor) -> Tensor: + r""" + greater_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater_equal`. + """ + ... + @overload + def greater_equal_(self, other: Union[Number, _complex]) -> Tensor: + r""" + greater_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater_equal`. + """ + ... + @overload + def gt(self, other: Tensor) -> Tensor: + r""" + gt(other) -> Tensor + + See :func:`torch.gt`. + """ + ... + @overload + def gt(self, other: Union[Number, _complex]) -> Tensor: + r""" + gt(other) -> Tensor + + See :func:`torch.gt`. + """ + ... + @overload + def gt_(self, other: Tensor) -> Tensor: + r""" + gt_(other) -> Tensor + + In-place version of :meth:`~Tensor.gt`. + """ + ... 
+ @overload + def gt_(self, other: Union[Number, _complex]) -> Tensor: + r""" + gt_(other) -> Tensor + + In-place version of :meth:`~Tensor.gt`. + """ + ... + def half(self) -> Tensor: + r""" + half(memory_format=torch.preserve_format) -> Tensor + + ``self.half()`` is equivalent to ``self.to(torch.float16)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def hardshrink(self, lambd: Union[Number, _complex] = 0.5) -> Tensor: + r""" + hardshrink(lambd=0.5) -> Tensor + + See :func:`torch.nn.functional.hardshrink` + """ + ... + def has_names(self) -> _bool: + r""" + Is ``True`` if any of this tensor's dimensions are named. Otherwise, is ``False``. + """ + ... + def heaviside(self, values: Tensor) -> Tensor: + r""" + heaviside(values) -> Tensor + + See :func:`torch.heaviside` + """ + ... + def heaviside_(self, values: Tensor) -> Tensor: + r""" + heaviside_(values) -> Tensor + + In-place version of :meth:`~Tensor.heaviside` + """ + ... + def histc(self, bins: _int = 100, min: Union[Number, _complex] = 0, max: Union[Number, _complex] = 0) -> Tensor: + r""" + histc(bins=100, min=0, max=0) -> Tensor + + See :func:`torch.histc` + """ + ... + @overload + def histogram(self, bins: Tensor, *, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + + See :func:`torch.histogram` + """ + ... + @overload + def histogram(self, bins: _int = 100, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + + See :func:`torch.histogram` + """ + ... + @overload + def hsplit(self, sections: _int) -> Tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + ... + @overload + def hsplit(self, indices: _size) -> Tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + ... + @overload + def hsplit(self, *indices: _int) -> Tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + ... + def hypot(self, other: Tensor) -> Tensor: + r""" + hypot(other) -> Tensor + + See :func:`torch.hypot` + """ + ... + def hypot_(self, other: Tensor) -> Tensor: + r""" + hypot_(other) -> Tensor + + In-place version of :meth:`~Tensor.hypot` + """ + ... + def i0(self) -> Tensor: + r""" + i0() -> Tensor + + See :func:`torch.i0` + """ + ... + def i0_(self) -> Tensor: + r""" + i0_() -> Tensor + + In-place version of :meth:`~Tensor.i0` + """ + ... + def igamma(self, other: Tensor) -> Tensor: + r""" + igamma(other) -> Tensor + + See :func:`torch.igamma` + """ + ... + def igamma_(self, other: Tensor) -> Tensor: + r""" + igamma_(other) -> Tensor + + In-place version of :meth:`~Tensor.igamma` + """ + ... + def igammac(self, other: Tensor) -> Tensor: + r""" + igammac(other) -> Tensor + See :func:`torch.igammac` + """ + ... + def igammac_(self, other: Tensor) -> Tensor: + r""" + igammac_(other) -> Tensor + In-place version of :meth:`~Tensor.igammac` + """ + ... 
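+    # Illustrative sketch for histc() above: with min=0, max=3 and bins=4 the bin
+    # width is 0.75, so both 1.0 values fall into the second bin.
+    #   >>> torch.tensor([1., 2., 1.]).histc(bins=4, min=0, max=3)
+    #   tensor([0., 2., 1., 0.])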
+ @overload + def index_add(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + index_add(dim, index, source, *, alpha=1) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_add_`. + """ + ... + @overload + def index_add(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + index_add(dim, index, source, *, alpha=1) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_add_`. + """ + ... + def index_add_(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + index_add_(dim, index, source, *, alpha=1) -> Tensor + + Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` + tensor by adding to the indices in the order given in :attr:`index`. For example, + if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\ th row of + ``source`` is subtracted from the ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of ``source`` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + For a 3-D tensor the output is given as:: + + self[index[i], :, :] += alpha * src[i, :, :] # if dim == 0 + self[:, index[i], :] += alpha * src[:, i, :] # if dim == 1 + self[:, :, index[i]] += alpha * src[:, :, i] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (Tensor): the tensor containing values to add + + Keyword args: + alpha (Number): the scalar multiplier for ``source`` + + Example:: + + >>> x = torch.ones(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_add_(0, index, t) + tensor([[ 2., 3., 4.], + [ 1., 1., 1.], + [ 8., 9., 10.], + [ 1., 1., 1.], + [ 5., 6., 7.]]) + >>> x.index_add_(0, index, t, alpha=-1) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + ... + @overload + def index_copy(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy(dim, index, tensor2) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_copy_`. + """ + ... + @overload + def index_copy(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy(dim, index, tensor2) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_copy_`. + """ + ... + @overload + def index_copy_(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy_(dim, index, tensor) -> Tensor + + Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting + the indices in the order given in :attr:`index`. For example, if ``dim == 0`` + and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the + ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + .. 
note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + + Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) + """ + ... + @overload + def index_copy_(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy_(dim, index, tensor) -> Tensor + + Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting + the indices in the order given in :attr:`index`. For example, if ``dim == 0`` + and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the + ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + .. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + + Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) + """ + ... + @overload + def index_fill(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + ... + @overload + def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + ... + @overload + def index_fill(self, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + ... + @overload + def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + ... + @overload + def index_fill_(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. 
+ + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + ... + @overload + def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + ... + @overload + def index_fill_(self, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + ... + @overload + def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + ... + def index_put(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: + r""" + index_put(indices, values, accumulate=False) -> Tensor + + Out-place version of :meth:`~Tensor.index_put_`. + """ + ... + def index_put_(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: + r""" + index_put_(indices, values, accumulate=False) -> Tensor + + Puts values from the tensor :attr:`values` into the tensor :attr:`self` using + the indices specified in :attr:`indices` (which is a tuple of Tensors). The + expression ``tensor.index_put_(indices, values)`` is equivalent to + ``tensor[indices] = values``. Returns :attr:`self`. + + If :attr:`accumulate` is ``True``, the elements in :attr:`values` are added to + :attr:`self`. If accumulate is ``False``, the behavior is undefined if indices + contain duplicate elements. + + Args: + indices (tuple of LongTensor): tensors used to index into `self`. + values (Tensor): tensor of same dtype as `self`. 
+ accumulate (bool): whether to accumulate into self + """ + ... + def index_reduce(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool = True) -> Tensor: ... + def index_reduce_(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool = True) -> Tensor: + r""" + index_reduce_(dim, index, source, reduce, *, include_self=True) -> Tensor + + Accumulate the elements of ``source`` into the :attr:`self` + tensor by accumulating to the indices in the order given in :attr:`index` + using the reduction given by the ``reduce`` argument. For example, if ``dim == 0``, + ``index[i] == j``, ``reduce == prod`` and ``include_self == True`` then the ``i``\ th + row of ``source`` is multiplied by the ``j``\ th row of :attr:`self`. If + :obj:`include_self="True"`, the values in the :attr:`self` tensor are included + in the reduction, otherwise, rows in the :attr:`self` tensor that are accumulated + to are treated as if they were filled with the reduction identites. + + The :attr:`dim`\ th dimension of ``source`` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + For a 3-D tensor with :obj:`reduce="prod"` and :obj:`include_self=True` the + output is given as:: + + self[index[i], :, :] *= src[i, :, :] # if dim == 0 + self[:, index[i], :] *= src[:, i, :] # if dim == 1 + self[:, :, index[i]] *= src[:, :, i] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + This function only supports floating point tensors. + + .. warning:: + + This function is in beta and may change in the near future. + + Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (FloatTensor): the tensor containing values to accumulate + reduce (str): the reduction operation to apply + (:obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + + Keyword args: + include_self (bool): whether the elements from the ``self`` tensor are + included in the reduction + + Example:: + + >>> x = torch.empty(5, 3).fill_(2) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2, 0]) + >>> x.index_reduce_(0, index, t, 'prod') + tensor([[20., 44., 72.], + [ 2., 2., 2.], + [14., 16., 18.], + [ 2., 2., 2.], + [ 8., 10., 12.]]) + >>> x = torch.empty(5, 3).fill_(2) + >>> x.index_reduce_(0, index, t, 'prod', include_self=False) + tensor([[10., 22., 36.], + [ 2., 2., 2.], + [ 7., 8., 9.], + [ 2., 2., 2.], + [ 4., 5., 6.]]) + """ + ... + @overload + def index_select(self, dim: _int, index: Tensor) -> Tensor: + r""" + index_select(dim, index) -> Tensor + + See :func:`torch.index_select` + """ + ... + @overload + def index_select(self, dim: Union[str, ellipsis, None], index: Tensor) -> Tensor: + r""" + index_select(dim, index) -> Tensor + + See :func:`torch.index_select` + """ + ... + def indices(self) -> Tensor: + r""" + indices() -> Tensor + + Return the indices tensor of a :ref:`sparse COO tensor `. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See also :meth:`Tensor.values`. + + .. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. + """ + ... 
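+    # Illustrative sketch for index_select() above: whole slices along dim
+    # (rows here) are picked in the order given by index.
+    #   >>> x = torch.arange(12).reshape(3, 4)
+    #   >>> x.index_select(0, torch.tensor([0, 2]))
+    #   tensor([[ 0,  1,  2,  3],
+    #           [ 8,  9, 10, 11]])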
+ def inner(self, other: Tensor) -> Tensor: + r""" + inner(other) -> Tensor + + See :func:`torch.inner`. + """ + ... + def int(self) -> Tensor: + r""" + int(memory_format=torch.preserve_format) -> Tensor + + ``self.int()`` is equivalent to ``self.to(torch.int32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def int_repr(self) -> Tensor: + r""" + int_repr() -> Tensor + + Given a quantized Tensor, + ``self.int_repr()`` returns a CPU Tensor with uint8_t as data type that stores the + underlying uint8_t values of the given Tensor. + """ + ... + def inverse(self) -> Tensor: + r""" + inverse() -> Tensor + + See :func:`torch.inverse` + """ + ... + def is_coalesced(self) -> _bool: + r""" + is_coalesced() -> bool + + Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor + ` that is coalesced, ``False`` otherwise. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See :meth:`coalesce` and :ref:`uncoalesced tensors `. + """ + ... + def is_complex(self) -> _bool: + r""" + is_complex() -> bool + + Returns True if the data type of :attr:`self` is a complex data type. + """ + ... + def is_conj(self) -> _bool: + r""" + is_conj() -> bool + + Returns True if the conjugate bit of :attr:`self` is set to true. + """ + ... + def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: + r""" + is_contiguous(memory_format=torch.contiguous_format) -> bool + + Returns True if :attr:`self` tensor is contiguous in memory in the order specified + by memory format. + + Args: + memory_format (:class:`torch.memory_format`, optional): Specifies memory allocation + order. Default: ``torch.contiguous_format``. + """ + ... + is_cpu: _bool + r"""Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise.""" + is_cuda: _bool + r"""Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise.""" + def is_distributed(self) -> _bool: ... + def is_floating_point(self) -> _bool: + r""" + is_floating_point() -> bool + + Returns True if the data type of :attr:`self` is a floating point data type. + """ + ... + def is_inference(self) -> _bool: + r""" + is_inference() -> bool + + See :func:`torch.is_inference` + """ + ... + is_ipu: _bool + r"""Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise.""" + is_leaf: _bool + r"""All Tensors that have :attr:`requires_grad` which is ``False`` will be leaf Tensors by convention. + + For Tensors that have :attr:`requires_grad` which is ``True``, they will be leaf Tensors if they were + created by the user. This means that they are not the result of an operation and so + :attr:`grad_fn` is None. + + Only leaf Tensors will have their :attr:`grad` populated during a call to :func:`backward`. + To get :attr:`grad` populated for non-leaf Tensors, you can use :func:`retain_grad`. 
+ + Example:: + + >>> a = torch.rand(10, requires_grad=True) + >>> a.is_leaf + True + >>> b = torch.rand(10, requires_grad=True).cuda() + >>> b.is_leaf + False + # b was created by the operation that cast a cpu Tensor into a cuda Tensor + >>> c = torch.rand(10, requires_grad=True) + 2 + >>> c.is_leaf + False + # c was created by the addition operation + >>> d = torch.rand(10).cuda() + >>> d.is_leaf + True + # d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + >>> e = torch.rand(10).cuda().requires_grad_() + >>> e.is_leaf + True + # e requires gradients and has no operations creating it + >>> f = torch.rand(10, requires_grad=True, device="cuda") + >>> f.is_leaf + True + # f requires grad, has no operation creating it""" + is_meta: _bool + r"""Is ``True`` if the Tensor is a meta tensor, ``False`` otherwise. Meta tensors + are like normal tensors, but they carry no data.""" + is_mkldnn: _bool + is_mps: _bool + r"""Is ``True`` if the Tensor is stored on the MPS device, ``False`` otherwise.""" + is_mtia: _bool + def is_neg(self) -> _bool: + r""" + is_neg() -> bool + + Returns True if the negative bit of :attr:`self` is set to true. + """ + ... + is_nested: _bool + def is_nonzero(self) -> _bool: ... + is_ort: _bool + def is_pinned(self, device: Optional[Optional[DeviceLikeType]] = None) -> _bool: + r""" + Returns true if this tensor resides in pinned memory. + """ + ... + is_quantized: _bool + r"""Is ``True`` if the Tensor is quantized, ``False`` otherwise.""" + def is_same_size(self, other: Tensor) -> _bool: ... + def is_set_to(self, tensor: Tensor) -> _bool: + r""" + is_set_to(tensor) -> bool + + Returns True if both tensors are pointing to the exact same memory (same + storage, offset, size and stride). + """ + ... + def is_signed(self) -> _bool: + r""" + is_signed() -> bool + + Returns True if the data type of :attr:`self` is a signed data type. + """ + ... + is_sparse: _bool + r"""Is ``True`` if the Tensor uses sparse COO storage layout, ``False`` otherwise.""" + is_sparse_csr: _bool + r"""Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise.""" + is_vulkan: _bool + def isclose(self, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> Tensor: + r""" + isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + See :func:`torch.isclose` + """ + ... + def isfinite(self) -> Tensor: + r""" + isfinite() -> Tensor + + See :func:`torch.isfinite` + """ + ... + def isinf(self) -> Tensor: + r""" + isinf() -> Tensor + + See :func:`torch.isinf` + """ + ... + def isnan(self) -> Tensor: + r""" + isnan() -> Tensor + + See :func:`torch.isnan` + """ + ... + def isneginf(self) -> Tensor: + r""" + isneginf() -> Tensor + + See :func:`torch.isneginf` + """ + ... + def isposinf(self) -> Tensor: + r""" + isposinf() -> Tensor + + See :func:`torch.isposinf` + """ + ... + def isreal(self) -> Tensor: + r""" + isreal() -> Tensor + + See :func:`torch.isreal` + """ + ... + def istft(self, n_fft: _int, hop_length: Optional[_int] = None, win_length: Optional[_int] = None, window: Optional[Tensor] = None, center: _bool = True, normalized: _bool = False, onesided: Optional[_bool] = None, length: Optional[_int] = None, return_complex: _bool = False) -> Tensor: + r""" + istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor + + See :func:`torch.istft` + """ + ... 
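+    # A minimal sketch of memory-format-aware contiguity checks (illustrative shapes only);
+    # a channels_last tensor is not contiguous in the default format but is contiguous
+    # when queried with memory_format=torch.channels_last:
+    #
+    #     >>> x = torch.randn(2, 3, 4, 4)
+    #     >>> x.is_contiguous()
+    #     True
+    #     >>> y = x.contiguous(memory_format=torch.channels_last)
+    #     >>> y.is_contiguous()
+    #     False
+    #     >>> y.is_contiguous(memory_format=torch.channels_last)
+    #     True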
+ def item(self) -> Number: + r""" + item() -> number + + Returns the value of this tensor as a standard Python number. This only works + for tensors with one element. For other cases, see :meth:`~Tensor.tolist`. + + This operation is not differentiable. + + Example:: + + >>> x = torch.tensor([1.0]) + >>> x.item() + 1.0 + """ + ... + def kron(self, other: Tensor) -> Tensor: + r""" + kron(other) -> Tensor + + See :func:`torch.kron` + """ + ... + @overload + def kthvalue(self, k: _int, dim: _int = -1, keepdim: _bool = False) -> torch.return_types.kthvalue: + r""" + kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.kthvalue` + """ + ... + @overload + def kthvalue(self, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.kthvalue: + r""" + kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.kthvalue` + """ + ... + def lcm(self, other: Tensor) -> Tensor: + r""" + lcm(other) -> Tensor + + See :func:`torch.lcm` + """ + ... + def lcm_(self, other: Tensor) -> Tensor: + r""" + lcm_(other) -> Tensor + + In-place version of :meth:`~Tensor.lcm` + """ + ... + def ldexp(self, other: Tensor) -> Tensor: + r""" + ldexp(other) -> Tensor + + See :func:`torch.ldexp` + """ + ... + def ldexp_(self, other: Tensor) -> Tensor: + r""" + ldexp_(other) -> Tensor + + In-place version of :meth:`~Tensor.ldexp` + """ + ... + @overload + def le(self, other: Tensor) -> Tensor: + r""" + le(other) -> Tensor + + See :func:`torch.le`. + """ + ... + @overload + def le(self, other: Union[Number, _complex]) -> Tensor: + r""" + le(other) -> Tensor + + See :func:`torch.le`. + """ + ... + @overload + def le_(self, other: Tensor) -> Tensor: + r""" + le_(other) -> Tensor + + In-place version of :meth:`~Tensor.le`. + """ + ... + @overload + def le_(self, other: Union[Number, _complex]) -> Tensor: + r""" + le_(other) -> Tensor + + In-place version of :meth:`~Tensor.le`. + """ + ... + @overload + def lerp(self, end: Tensor, weight: Tensor) -> Tensor: + r""" + lerp(end, weight) -> Tensor + + See :func:`torch.lerp` + """ + ... + @overload + def lerp(self, end: Tensor, weight: Union[Number, _complex]) -> Tensor: + r""" + lerp(end, weight) -> Tensor + + See :func:`torch.lerp` + """ + ... + @overload + def lerp_(self, end: Tensor, weight: Tensor) -> Tensor: + r""" + lerp_(end, weight) -> Tensor + + In-place version of :meth:`~Tensor.lerp` + """ + ... + @overload + def lerp_(self, end: Tensor, weight: Union[Number, _complex]) -> Tensor: + r""" + lerp_(end, weight) -> Tensor + + In-place version of :meth:`~Tensor.lerp` + """ + ... + @overload + def less(self, other: Tensor) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.less`. + """ + ... + @overload + def less(self, other: Union[Number, _complex]) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.less`. + """ + ... + @overload + def less_(self, other: Tensor) -> Tensor: + r""" + less_(other) -> Tensor + + In-place version of :meth:`~Tensor.less`. + """ + ... + @overload + def less_(self, other: Union[Number, _complex]) -> Tensor: + r""" + less_(other) -> Tensor + + In-place version of :meth:`~Tensor.less`. + """ + ... + @overload + def less_equal(self, other: Tensor) -> Tensor: + r""" + less_equal(other) -> Tensor + + See :func:`torch.less_equal`. + """ + ... + @overload + def less_equal(self, other: Union[Number, _complex]) -> Tensor: + r""" + less_equal(other) -> Tensor + + See :func:`torch.less_equal`. + """ + ... 
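+    # A minimal sketch of lerp (illustrative values): the result is
+    # start + weight * (end - start), evaluated elementwise:
+    #
+    #     >>> start = torch.zeros(3)
+    #     >>> end = torch.full((3,), 10.)
+    #     >>> start.lerp(end, 0.5)
+    #     tensor([5., 5., 5.])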
+ @overload + def less_equal_(self, other: Tensor) -> Tensor: + r""" + less_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.less_equal`. + """ + ... + @overload + def less_equal_(self, other: Union[Number, _complex]) -> Tensor: + r""" + less_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.less_equal`. + """ + ... + def lgamma(self) -> Tensor: + r""" + lgamma() -> Tensor + + See :func:`torch.lgamma` + """ + ... + def lgamma_(self) -> Tensor: + r""" + lgamma_() -> Tensor + + In-place version of :meth:`~Tensor.lgamma` + """ + ... + def log(self) -> Tensor: + r""" + log() -> Tensor + + See :func:`torch.log` + """ + ... + def log10(self) -> Tensor: + r""" + log10() -> Tensor + + See :func:`torch.log10` + """ + ... + def log10_(self) -> Tensor: + r""" + log10_() -> Tensor + + In-place version of :meth:`~Tensor.log10` + """ + ... + def log1p(self) -> Tensor: + r""" + log1p() -> Tensor + + See :func:`torch.log1p` + """ + ... + def log1p_(self) -> Tensor: + r""" + log1p_() -> Tensor + + In-place version of :meth:`~Tensor.log1p` + """ + ... + def log2(self) -> Tensor: + r""" + log2() -> Tensor + + See :func:`torch.log2` + """ + ... + def log2_(self) -> Tensor: + r""" + log2_() -> Tensor + + In-place version of :meth:`~Tensor.log2` + """ + ... + def log_(self) -> Tensor: + r""" + log_() -> Tensor + + In-place version of :meth:`~Tensor.log` + """ + ... + def log_normal_(self, mean: _float = 1, std: _float = 2, *, generator: Optional[Generator] = None) -> Tensor: + r""" + log_normal_(mean=1, std=2, *, generator=None) + + Fills :attr:`self` tensor with numbers samples from the log-normal distribution + parameterized by the given mean :math:`\mu` and standard deviation + :math:`\sigma`. Note that :attr:`mean` and :attr:`std` are the mean and + standard deviation of the underlying normal distribution, and not of the + returned distribution: + + .. math:: + + f(x) = \dfrac{1}{x \sigma \sqrt{2\pi}}\ e^{-\frac{(\ln x - \mu)^2}{2\sigma^2}} + """ + ... + @overload + def log_softmax(self, dim: _int, dtype: Optional[_dtype] = None) -> Tensor: ... + @overload + def log_softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: ... + def logaddexp(self, other: Tensor) -> Tensor: + r""" + logaddexp(other) -> Tensor + + See :func:`torch.logaddexp` + """ + ... + def logaddexp2(self, other: Tensor) -> Tensor: + r""" + logaddexp2(other) -> Tensor + + See :func:`torch.logaddexp2` + """ + ... + @overload + def logcumsumexp(self, dim: _int) -> Tensor: + r""" + logcumsumexp(dim) -> Tensor + + See :func:`torch.logcumsumexp` + """ + ... + @overload + def logcumsumexp(self, dim: Union[str, ellipsis, None]) -> Tensor: + r""" + logcumsumexp(dim) -> Tensor + + See :func:`torch.logcumsumexp` + """ + ... + def logdet(self) -> Tensor: + r""" + logdet() -> Tensor + + See :func:`torch.logdet` + """ + ... + def logical_and(self, other: Tensor) -> Tensor: + r""" + logical_and() -> Tensor + + See :func:`torch.logical_and` + """ + ... + def logical_and_(self, other: Tensor) -> Tensor: + r""" + logical_and_() -> Tensor + + In-place version of :meth:`~Tensor.logical_and` + """ + ... + def logical_not(self) -> Tensor: + r""" + logical_not() -> Tensor + + See :func:`torch.logical_not` + """ + ... + def logical_not_(self) -> Tensor: + r""" + logical_not_() -> Tensor + + In-place version of :meth:`~Tensor.logical_not` + """ + ... + def logical_or(self, other: Tensor) -> Tensor: + r""" + logical_or() -> Tensor + + See :func:`torch.logical_or` + """ + ... 
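+    # A minimal sketch of why logaddexp is preferred over exp/sum/log for very negative
+    # inputs (illustrative values): the naive form underflows to -inf, logaddexp does not:
+    #
+    #     >>> a = torch.tensor([-1000.0])
+    #     >>> b = torch.tensor([-1000.0])
+    #     >>> (a.exp() + b.exp()).log()
+    #     tensor([-inf])
+    #     >>> a.logaddexp(b)
+    #     tensor([-999.3069])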
+ def logical_or_(self, other: Tensor) -> Tensor: + r""" + logical_or_() -> Tensor + + In-place version of :meth:`~Tensor.logical_or` + """ + ... + def logical_xor(self, other: Tensor) -> Tensor: + r""" + logical_xor() -> Tensor + + See :func:`torch.logical_xor` + """ + ... + def logical_xor_(self, other: Tensor) -> Tensor: + r""" + logical_xor_() -> Tensor + + In-place version of :meth:`~Tensor.logical_xor` + """ + ... + def logit(self, eps: Optional[_float] = None) -> Tensor: + r""" + logit() -> Tensor + + See :func:`torch.logit` + """ + ... + def logit_(self, eps: Optional[_float] = None) -> Tensor: + r""" + logit_() -> Tensor + + In-place version of :meth:`~Tensor.logit` + """ + ... + @overload + def logsumexp(self, dim: Union[_int, _size], keepdim: _bool = False) -> Tensor: + r""" + logsumexp(dim, keepdim=False) -> Tensor + + See :func:`torch.logsumexp` + """ + ... + @overload + def logsumexp(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False) -> Tensor: + r""" + logsumexp(dim, keepdim=False) -> Tensor + + See :func:`torch.logsumexp` + """ + ... + def long(self) -> Tensor: + r""" + long(memory_format=torch.preserve_format) -> Tensor + + ``self.long()`` is equivalent to ``self.to(torch.int64)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + @overload + def lt(self, other: Tensor) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.lt`. + """ + ... + @overload + def lt(self, other: Union[Number, _complex]) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.lt`. + """ + ... + @overload + def lt_(self, other: Tensor) -> Tensor: + r""" + lt_(other) -> Tensor + + In-place version of :meth:`~Tensor.lt`. + """ + ... + @overload + def lt_(self, other: Union[Number, _complex]) -> Tensor: + r""" + lt_(other) -> Tensor + + In-place version of :meth:`~Tensor.lt`. + """ + ... + def lu_solve(self, LU_data: Tensor, LU_pivots: Tensor) -> Tensor: + r""" + lu_solve(LU_data, LU_pivots) -> Tensor + + See :func:`torch.lu_solve` + """ + ... + def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ... + def map_(self, tensor: Tensor, callable: Callable) -> Tensor: + r""" + map_(tensor, callable) + + Applies :attr:`callable` for each element in :attr:`self` tensor and the given + :attr:`tensor` and stores the results in :attr:`self` tensor. :attr:`self` tensor and + the given :attr:`tensor` must be :ref:`broadcastable `. + + The :attr:`callable` should have the signature:: + + def callable(a, b) -> number + """ + ... + @overload + def masked_fill(self, mask: Tensor, value: Tensor) -> Tensor: + r""" + masked_fill(mask, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_fill_` + """ + ... + @overload + def masked_fill(self, mask: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + masked_fill(mask, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_fill_` + """ + ... + @overload + def masked_fill_(self, mask: Tensor, value: Tensor) -> Tensor: + r""" + masked_fill_(mask, value) + + Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is + True. The shape of :attr:`mask` must be + :ref:`broadcastable ` with the shape of the underlying + tensor. + + Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with + """ + ... 
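+    # A minimal sketch of logsumexp as the numerically stable form of
+    # x.exp().sum(dim).log() (illustrative values):
+    #
+    #     >>> x = torch.tensor([[0., 0., 0.]])
+    #     >>> x.logsumexp(dim=1)
+    #     tensor([1.0986])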
+ @overload + def masked_fill_(self, mask: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + masked_fill_(mask, value) + + Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is + True. The shape of :attr:`mask` must be + :ref:`broadcastable ` with the shape of the underlying + tensor. + + Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with + """ + ... + def masked_scatter(self, mask: Tensor, source: Tensor) -> Tensor: + r""" + masked_scatter(mask, tensor) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_scatter_` + + .. note:: + + The inputs :attr:`self` and :attr:`mask` + :ref:`broadcast `. + + Example: + + >>> self = torch.tensor([0, 0, 0, 0, 0]) + >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]]) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + """ + ... + def masked_scatter_(self, mask: Tensor, source: Tensor) -> Tensor: + r""" + masked_scatter_(mask, source) + + Copies elements from :attr:`source` into :attr:`self` tensor at positions where + the :attr:`mask` is True. Elements from :attr:`source` are copied into :attr:`self` + starting at position 0 of :attr:`source` and continuing in order one-by-one for each + occurrence of :attr:`mask` being True. + The shape of :attr:`mask` must be :ref:`broadcastable ` + with the shape of the underlying tensor. The :attr:`source` should have at least + as many elements as the number of ones in :attr:`mask`. + + Args: + mask (BoolTensor): the boolean mask + source (Tensor): the tensor to copy from + + .. note:: + + The :attr:`mask` operates on the :attr:`self` tensor, not on the given + :attr:`source` tensor. + + Example: + + >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) + >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]]) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter_(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + """ + ... + def masked_select(self, mask: Tensor) -> Tensor: + r""" + masked_select(mask) -> Tensor + + See :func:`torch.masked_select` + """ + ... + def matmul(self, other: Tensor) -> Tensor: + r""" + matmul(tensor2) -> Tensor + + See :func:`torch.matmul` + """ + ... + def matrix_exp(self) -> Tensor: + r""" + matrix_exp() -> Tensor + + See :func:`torch.matrix_exp` + """ + ... + def matrix_power(self, n: _int) -> Tensor: + r""" + matrix_power(n) -> Tensor + + .. note:: :meth:`~Tensor.matrix_power` is deprecated, use :func:`torch.linalg.matrix_power` instead. + + Alias for :func:`torch.linalg.matrix_power` + """ + ... + @overload + def max(self) -> Tensor: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + ... + @overload + def max(self, other: Tensor) -> Tensor: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + ... + @overload + def max(self, dim: _int, keepdim: _bool = False) -> torch.return_types.max: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + ... + @overload + def max(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.max: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + ... + def maximum(self, other: Tensor) -> Tensor: + r""" + maximum(other) -> Tensor + + See :func:`torch.maximum` + """ + ... 
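+    # A minimal sketch combining a boolean mask with masked_fill and masked_select
+    # (illustrative values); the mask must be broadcastable to the tensor's shape:
+    #
+    #     >>> x = torch.tensor([1., 2., 3., 4.])
+    #     >>> mask = x > 2
+    #     >>> x.masked_fill(mask, 0.)
+    #     tensor([1., 2., 0., 0.])
+    #     >>> x.masked_select(mask)
+    #     tensor([3., 4.])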
+ @overload + def mean(self, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + ... + @overload + def mean(self, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + ... + @overload + def mean(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + ... + @overload + def median(self) -> Tensor: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + ... + @overload + def median(self, dim: _int, keepdim: _bool = False) -> torch.return_types.median: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + ... + @overload + def median(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.median: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + ... + @overload + def min(self) -> Tensor: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + ... + @overload + def min(self, other: Tensor) -> Tensor: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + ... + @overload + def min(self, dim: _int, keepdim: _bool = False) -> torch.return_types.min: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + ... + @overload + def min(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.min: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + ... + def minimum(self, other: Tensor) -> Tensor: + r""" + minimum(other) -> Tensor + + See :func:`torch.minimum` + """ + ... + def mm(self, mat2: Tensor) -> Tensor: + r""" + mm(mat2) -> Tensor + + See :func:`torch.mm` + """ + ... + @overload + def mode(self, dim: _int = -1, keepdim: _bool = False) -> torch.return_types.mode: + r""" + mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.mode` + """ + ... + @overload + def mode(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.mode: + r""" + mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.mode` + """ + ... + @overload + def moveaxis(self, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(source, destination) -> Tensor + + See :func:`torch.moveaxis` + """ + ... + @overload + def moveaxis(self, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(source, destination) -> Tensor + + See :func:`torch.moveaxis` + """ + ... + @overload + def movedim(self, source: _int, destination: _int) -> Tensor: + r""" + movedim(source, destination) -> Tensor + + See :func:`torch.movedim` + """ + ... + @overload + def movedim(self, source: _size, destination: _size) -> Tensor: + r""" + movedim(source, destination) -> Tensor + + See :func:`torch.movedim` + """ + ... + def msort(self) -> Tensor: + r""" + msort() -> Tensor + + See :func:`torch.msort` + """ + ... + def mul(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor] = None) -> Tensor: + r""" + mul(value) -> Tensor + + See :func:`torch.mul`. 
+ """ + ... + def mul_(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]) -> Tensor: + r""" + mul_(value) -> Tensor + + In-place version of :meth:`~Tensor.mul`. + """ + ... + def multinomial(self, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None) -> Tensor: + r""" + multinomial(num_samples, replacement=False, *, generator=None) -> Tensor + + See :func:`torch.multinomial` + """ + ... + @overload + def multiply(self, other: Tensor) -> Tensor: + r""" + multiply(value) -> Tensor + + See :func:`torch.multiply`. + """ + ... + @overload + def multiply(self, other: Union[Number, _complex]) -> Tensor: + r""" + multiply(value) -> Tensor + + See :func:`torch.multiply`. + """ + ... + @overload + def multiply_(self, other: Tensor) -> Tensor: + r""" + multiply_(value) -> Tensor + + In-place version of :meth:`~Tensor.multiply`. + """ + ... + @overload + def multiply_(self, other: Union[Number, _complex]) -> Tensor: + r""" + multiply_(value) -> Tensor + + In-place version of :meth:`~Tensor.multiply`. + """ + ... + def mv(self, vec: Tensor) -> Tensor: + r""" + mv(vec) -> Tensor + + See :func:`torch.mv` + """ + ... + def mvlgamma(self, p: _int) -> Tensor: + r""" + mvlgamma(p) -> Tensor + + See :func:`torch.mvlgamma` + """ + ... + def mvlgamma_(self, p: _int) -> Tensor: + r""" + mvlgamma_(p) -> Tensor + + In-place version of :meth:`~Tensor.mvlgamma` + """ + ... + def nan_to_num(self, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None) -> Tensor: + r""" + nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor + + See :func:`torch.nan_to_num`. + """ + ... + def nan_to_num_(self, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None) -> Tensor: + r""" + nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor + + In-place version of :meth:`~Tensor.nan_to_num`. + """ + ... + def nanmean(self, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.nanmean` + """ + ... + @overload + def nanmedian(self) -> Tensor: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + ... + @overload + def nanmedian(self, dim: _int, keepdim: _bool = False) -> torch.return_types.nanmedian: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + ... + @overload + def nanmedian(self, dim: Union[str, ellipsis, None], keepdim: _bool = False) -> torch.return_types.nanmedian: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + ... + @overload + def nanquantile(self, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear") -> Tensor: + r""" + nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.nanquantile` + """ + ... + @overload + def nanquantile(self, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear") -> Tensor: + r""" + nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.nanquantile` + """ + ... 
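+    # A minimal sketch of NaN-aware reduction and cleanup (illustrative values):
+    # nanmean ignores NaNs, nan_to_num replaces them:
+    #
+    #     >>> x = torch.tensor([1., float('nan'), 3.])
+    #     >>> x.nanmean()
+    #     tensor(2.)
+    #     >>> x.nan_to_num(nan=0.0)
+    #     tensor([1., 0., 3.])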
+ def nansum(self, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + nansum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.nansum` + """ + ... + @overload + def narrow(self, dim: _int, start: Tensor, length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(dimension, start, length) -> Tensor + + See :func:`torch.narrow`. + """ + ... + @overload + def narrow(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(dimension, start, length) -> Tensor + + See :func:`torch.narrow`. + """ + ... + def narrow_copy(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: + r""" + narrow_copy(dimension, start, length) -> Tensor + + See :func:`torch.narrow_copy`. + """ + ... + def ndimension(self) -> _int: + r""" + ndimension() -> int + + Alias for :meth:`~Tensor.dim()` + """ + ... + @overload + def ne(self, other: Tensor) -> Tensor: + r""" + ne(other) -> Tensor + + See :func:`torch.ne`. + """ + ... + @overload + def ne(self, other: Union[Number, _complex]) -> Tensor: + r""" + ne(other) -> Tensor + + See :func:`torch.ne`. + """ + ... + @overload + def ne_(self, other: Tensor) -> Tensor: + r""" + ne_(other) -> Tensor + + In-place version of :meth:`~Tensor.ne`. + """ + ... + @overload + def ne_(self, other: Union[Number, _complex]) -> Tensor: + r""" + ne_(other) -> Tensor + + In-place version of :meth:`~Tensor.ne`. + """ + ... + def neg(self) -> Tensor: + r""" + neg() -> Tensor + + See :func:`torch.neg` + """ + ... + def neg_(self) -> Tensor: + r""" + neg_() -> Tensor + + In-place version of :meth:`~Tensor.neg` + """ + ... + def negative(self) -> Tensor: + r""" + negative() -> Tensor + + See :func:`torch.negative` + """ + ... + def negative_(self) -> Tensor: + r""" + negative_() -> Tensor + + In-place version of :meth:`~Tensor.negative` + """ + ... + def nelement(self) -> _int: + r""" + nelement() -> int + + Alias for :meth:`~Tensor.numel` + """ + ... + @overload + def new(self, *args: Any, device: Optional[DeviceLikeType] = None) -> Tensor: ... + @overload + def new(self, storage: Storage) -> Tensor: ... + @overload + def new(self, other: Tensor) -> Tensor: ... + @overload + def new(self, size: _size, *, device: Optional[DeviceLikeType] = None) -> Tensor: ... + @overload + def new_empty(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with uninitialized data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + ... + @overload + def new_empty(self, *size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with uninitialized data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + ... + def new_empty_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with + uninitialized data. By default, the returned Tensor has the same + :class:`torch.dtype` and :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty_strided((2, 3), (3, 1)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + ... + def new_full(self, size: Sequence[Union[_int, SymInt]], fill_value: Union[Number, _complex], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_full(size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + fill_value (scalar): the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones((2,), dtype=torch.float64) + >>> tensor.new_full((3, 4), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416]], dtype=torch.float64) + """ + ... + @overload + def new_ones(self, size: _size, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + ... 
+ @overload + def new_ones(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + ... + @overload + def new_ones(self, *size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + ... + def new_tensor(self, data: Any, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + new_tensor(data, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a new Tensor with :attr:`data` as the tensor data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + .. 
warning:: + + :func:`new_tensor` always copies :attr:`data`. If you have a Tensor + ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` + or :func:`torch.Tensor.detach`. + If you have a numpy array and want to avoid a copy, use + :func:`torch.from_numpy`. + + .. warning:: + + When data is a tensor `x`, :func:`new_tensor()` reads out 'the data' from whatever it is passed, + and constructs a leaf variable. Therefore ``tensor.new_tensor(x)`` is equivalent to ``x.clone().detach()`` + and ``tensor.new_tensor(x, requires_grad=True)`` is equivalent to ``x.clone().detach().requires_grad_(True)``. + The equivalents using ``clone()`` and ``detach()`` are recommended. + + Args: + data (array_like): The returned Tensor copies :attr:`data`. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones((2,), dtype=torch.int8) + >>> data = [[0, 1], [2, 3]] + >>> tensor.new_tensor(data) + tensor([[ 0, 1], + [ 2, 3]], dtype=torch.int8) + """ + ... + @overload + def new_zeros(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``0``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + """ + ... 
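+    # A minimal sketch of the new_* constructors (illustrative values): by default they
+    # inherit this tensor's dtype and device rather than the global defaults:
+    #
+    #     >>> base = torch.ones(2, dtype=torch.float64)
+    #     >>> base.new_zeros(2, 3).dtype
+    #     torch.float64
+    #     >>> base.new_full((2, 2), 7.0).dtype
+    #     torch.float64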
+ @overload + def new_zeros(self, *size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``0``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + """ + ... + def nextafter(self, other: Tensor) -> Tensor: + r""" + nextafter(other) -> Tensor + See :func:`torch.nextafter` + """ + ... + def nextafter_(self, other: Tensor) -> Tensor: + r""" + nextafter_(other) -> Tensor + In-place version of :meth:`~Tensor.nextafter` + """ + ... + @overload + def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: + r""" + nonzero() -> LongTensor + + See :func:`torch.nonzero` + """ + ... + @overload + def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: + r""" + nonzero() -> LongTensor + + See :func:`torch.nonzero` + """ + ... + def nonzero_static(self, *, size: _int, fill_value: _int = -1) -> Tensor: + r""" + nonzero_static(input, *, size, fill_value=-1) -> Tensor + + Returns a 2-D tensor where each row is the index for a non-zero value. + The returned Tensor has the same `torch.dtype` as `torch.nonzero()`. + + Args: + input (Tensor): the input tensor to count non-zero elements. + + Keyword args: + size (int): the size of non-zero elements expected to be included in the out + tensor. Pad the out tensor with `fill_value` if the `size` is larger + than total number of non-zero elements, truncate out tensor if `size` + is smaller. The size must be a non-negative integer. + fill_value (int): the value to fill the output tensor with when `size` is larger + than the total number of non-zero elements. Default is `-1` to represent + invalid index. 
+ + Example: + + # Example 1: Padding + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 4 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([[ 0, 0], + [ 1, 0], + [ 1, 1], + [ -1, -1]], dtype=torch.int64) + + # Example 2: Truncating + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([[ 0, 0], + [ 1, 0]], dtype=torch.int64) + + # Example 3: 0 size + >>> input_tensor = torch.tensor([10]) + >>> static_size = 0 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([], size=(0, 1), dtype=torch.int64) + + # Example 4: 0 rank input + >>> input_tensor = torch.tensor(10) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([], size=(2, 0), dtype=torch.int64) + """ + ... + def normal_(self, mean: _float = 0, std: _float = 1, *, generator: Optional[Generator] = None) -> Tensor: + r""" + normal_(mean=0, std=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements samples from the normal distribution + parameterized by :attr:`mean` and :attr:`std`. + """ + ... + @overload + def not_equal(self, other: Tensor) -> Tensor: + r""" + not_equal(other) -> Tensor + + See :func:`torch.not_equal`. + """ + ... + @overload + def not_equal(self, other: Union[Number, _complex]) -> Tensor: + r""" + not_equal(other) -> Tensor + + See :func:`torch.not_equal`. + """ + ... + @overload + def not_equal_(self, other: Tensor) -> Tensor: + r""" + not_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.not_equal`. + """ + ... + @overload + def not_equal_(self, other: Union[Number, _complex]) -> Tensor: + r""" + not_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.not_equal`. + """ + ... + def numel(self) -> _int: + r""" + numel() -> int + + See :func:`torch.numel` + """ + ... + def numpy(self, *, force: _bool = False) -> Any: + r""" + numpy(*, force=False) -> numpy.ndarray + + Returns the tensor as a NumPy :class:`ndarray`. + + If :attr:`force` is ``False`` (the default), the conversion + is performed only if the tensor is on the CPU, does not require grad, + does not have its conjugate bit set, and is a dtype and layout that + NumPy supports. The returned ndarray and the tensor will share their + storage, so changes to the tensor will be reflected in the ndarray + and vice versa. + + If :attr:`force` is ``True`` this is equivalent to + calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``. + If the tensor isn't on the CPU or the conjugate or negative bit is set, + the tensor won't share its storage with the returned ndarray. + Setting :attr:`force` to ``True`` can be a useful shorthand. + + Args: + force (bool): if ``True``, the ndarray may be a copy of the tensor + instead of always sharing memory, defaults to ``False``. + """ + ... + def orgqr(self, input2: Tensor) -> Tensor: + r""" + orgqr(input2) -> Tensor + + See :func:`torch.orgqr` + """ + ... + def ormqr(self, input2: Tensor, input3: Tensor, left: _bool = True, transpose: _bool = False) -> Tensor: + r""" + ormqr(input2, input3, left=True, transpose=False) -> Tensor + + See :func:`torch.ormqr` + """ + ... + def outer(self, vec2: Tensor) -> Tensor: + r""" + outer(vec2) -> Tensor + + See :func:`torch.outer`. + """ + ... + @overload + def permute(self, dims: _size) -> Tensor: + r""" + permute(*dims) -> Tensor + + See :func:`torch.permute` + """ + ... 
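+    # A minimal sketch of numpy() semantics (illustrative values): the default conversion
+    # shares storage with the CPU tensor, while force=True may detach and copy:
+    #
+    #     >>> t = torch.arange(3)
+    #     >>> a = t.numpy()
+    #     >>> a[0] = 10
+    #     >>> t
+    #     tensor([10,  1,  2])
+    #     >>> torch.ones(2, requires_grad=True).numpy(force=True)
+    #     array([1., 1.], dtype=float32)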
+ @overload + def permute(self, *dims: _int) -> Tensor: + r""" + permute(*dims) -> Tensor + + See :func:`torch.permute` + """ + ... + def pin_memory(self, device: Optional[Optional[DeviceLikeType]] = None) -> Tensor: + r""" + pin_memory() -> Tensor + + Copies the tensor to pinned memory, if it's not already pinned. + """ + ... + def pinverse(self, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse() -> Tensor + + See :func:`torch.pinverse` + """ + ... + def polygamma(self, n: _int) -> Tensor: + r""" + polygamma(n) -> Tensor + + See :func:`torch.polygamma` + """ + ... + def polygamma_(self, n: _int) -> Tensor: + r""" + polygamma_(n) -> Tensor + + In-place version of :meth:`~Tensor.polygamma` + """ + ... + def positive(self) -> Tensor: + r""" + positive() -> Tensor + + See :func:`torch.positive` + """ + ... + @overload + def pow(self, exponent: Tensor) -> Tensor: + r""" + pow(exponent) -> Tensor + + See :func:`torch.pow` + """ + ... + @overload + def pow(self, exponent: Union[Number, _complex]) -> Tensor: + r""" + pow(exponent) -> Tensor + + See :func:`torch.pow` + """ + ... + @overload + def pow_(self, exponent: Tensor) -> Tensor: + r""" + pow_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.pow` + """ + ... + @overload + def pow_(self, exponent: Union[Number, _complex]) -> Tensor: + r""" + pow_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.pow` + """ + ... + def prelu(self, weight: Tensor) -> Tensor: ... + @overload + def prod(self, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + ... + @overload + def prod(self, dim: _int, keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + ... + @overload + def prod(self, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + ... + def put(self, index: Tensor, source: Tensor, accumulate: _bool = False) -> Tensor: + r""" + put(input, index, source, accumulate=False) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.put_`. + `input` corresponds to `self` in :meth:`torch.Tensor.put_`. + """ + ... + def put_(self, index: Tensor, source: Tensor, accumulate: _bool = False) -> Tensor: + r""" + put_(index, source, accumulate=False) -> Tensor + + Copies the elements from :attr:`source` into the positions specified by + :attr:`index`. For the purpose of indexing, the :attr:`self` tensor is treated as if + it were a 1-D tensor. + + :attr:`index` and :attr:`source` need to have the same number of elements, but not necessarily + the same shape. + + If :attr:`accumulate` is ``True``, the elements in :attr:`source` are added to + :attr:`self`. If accumulate is ``False``, the behavior is undefined if :attr:`index` + contain duplicate elements. + + Args: + index (LongTensor): the indices into self + source (Tensor): the tensor containing values to copy from + accumulate (bool): whether to accumulate into self + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> src.put_(torch.tensor([1, 3]), torch.tensor([9, 10])) + tensor([[ 4, 9, 5], + [ 10, 7, 8]]) + """ + ... 
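+    # A minimal sketch of permute (illustrative shape): dims gives the new ordering of
+    # the existing dimensions, returning a view with rearranged strides:
+    #
+    #     >>> x = torch.randn(2, 3, 4)
+    #     >>> x.permute(2, 0, 1).shape
+    #     torch.Size([4, 2, 3])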
+ def q_per_channel_axis(self) -> _int: + r""" + q_per_channel_axis() -> int + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns the index of dimension on which per-channel quantization is applied. + """ + ... + def q_per_channel_scales(self) -> Tensor: + r""" + q_per_channel_scales() -> Tensor + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns a Tensor of scales of the underlying quantizer. It has the number of + elements that matches the corresponding dimensions (from q_per_channel_axis) of + the tensor. + """ + ... + def q_per_channel_zero_points(self) -> Tensor: + r""" + q_per_channel_zero_points() -> Tensor + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns a tensor of zero_points of the underlying quantizer. It has the number of + elements that matches the corresponding dimensions (from q_per_channel_axis) of + the tensor. + """ + ... + def q_scale(self) -> _float: + r""" + q_scale() -> float + + Given a Tensor quantized by linear(affine) quantization, + returns the scale of the underlying quantizer(). + """ + ... + def q_zero_point(self) -> _int: + r""" + q_zero_point() -> int + + Given a Tensor quantized by linear(affine) quantization, + returns the zero_point of the underlying quantizer(). + """ + ... + def qr(self, some: _bool = True) -> torch.return_types.qr: + r""" + qr(some=True) -> (Tensor, Tensor) + + See :func:`torch.qr` + """ + ... + def qscheme(self) -> _qscheme: + r""" + qscheme() -> torch.qscheme + + Returns the quantization scheme of a given QTensor. + """ + ... + @overload + def quantile(self, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear") -> Tensor: + r""" + quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.quantile` + """ + ... + @overload + def quantile(self, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear") -> Tensor: + r""" + quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.quantile` + """ + ... + def rad2deg(self) -> Tensor: + r""" + rad2deg() -> Tensor + + See :func:`torch.rad2deg` + """ + ... + def rad2deg_(self) -> Tensor: + r""" + rad2deg_() -> Tensor + + In-place version of :meth:`~Tensor.rad2deg` + """ + ... + @overload + def random_(self, *, generator: Optional[Generator] = None) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + ... + @overload + def random_(self, from_: _int, to: Optional[_int], *, generator: Optional[Generator] = None) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. 
For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + ... + @overload + def random_(self, to: _int, *, generator: Optional[Generator] = None) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + ... + def ravel(self) -> Tensor: + r""" + ravel() -> Tensor + + see :func:`torch.ravel` + """ + ... + def reciprocal(self) -> Tensor: + r""" + reciprocal() -> Tensor + + See :func:`torch.reciprocal` + """ + ... + def reciprocal_(self) -> Tensor: + r""" + reciprocal_() -> Tensor + + In-place version of :meth:`~Tensor.reciprocal` + """ + ... + def record_stream(self, s: Stream) -> None: + r""" + record_stream(stream) + + Marks the tensor as having been used by this stream. When the tensor + is deallocated, ensure the tensor memory is not reused for another tensor + until all work queued on :attr:`stream` at the time of deallocation is + complete. + + .. note:: + + The caching allocator is aware of only the stream where a tensor was + allocated. Due to the awareness, it already correctly manages the life + cycle of tensors on only one stream. But if a tensor is used on a stream + different from the stream of origin, the allocator might reuse the memory + unexpectedly. Calling this method lets the allocator know which streams + have used the tensor. + + .. warning:: + + This method is most suitable for use cases where you are providing a + function that created a tensor on a side stream, and want users to be able + to make use of the tensor without having to think carefully about stream + safety when making use of them. These safety guarantees come at some + performance and predictability cost (analogous to the tradeoff between GC + and manual memory management), so if you are in a situation where + you manage the full lifetime of your tensors, you may consider instead + manually managing CUDA events so that calling this method is not necessary. + In particular, when you call this method, on later allocations the + allocator will poll the recorded stream to see if all operations have + completed yet; you can potentially race with side stream computation and + non-deterministically reuse or fail to reuse memory for an allocation. + + You can safely use tensors allocated on side streams without + :meth:`~Tensor.record_stream`; you must manually ensure that + any non-creation stream uses of a tensor are synced back to the creation + stream before you deallocate the tensor. As the CUDA caching allocator + guarantees that the memory will only be reused with the same creation stream, + this is sufficient to ensure that writes to future reallocations of the + memory will be delayed until non-creation stream uses are done. + (Counterintuitively, you may observe that on the CPU side we have already + reallocated the tensor, even though CUDA kernels on the old tensor are + still in progress. This is fine, because CUDA operations on the new + tensor will appropriately wait for the old operations to complete, as they + are all on the same stream.) 
+ + Concretely, this looks like this:: + + with torch.cuda.stream(s0): + x = torch.zeros(N) + + s1.wait_stream(s0) + with torch.cuda.stream(s1): + y = some_comm_op(x) + + ... some compute on s0 ... + + # synchronize creation stream s0 to side stream s1 + # before deallocating x + s0.wait_stream(s1) + del x + + Note that some discretion is required when deciding when to perform + ``s0.wait_stream(s1)``. In particular, if we were to wait immediately + after ``some_comm_op``, there wouldn't be any point in having the side + stream; it would be equivalent to have run ``some_comm_op`` on ``s0``. + Instead, the synchronization must be placed at some appropriate, later + point in time where you expect the side stream ``s1`` to have finished + work. This location is typically identified via profiling, e.g., using + Chrome traces produced + :meth:`torch.autograd.profiler.profile.export_chrome_trace`. If you + place the wait too early, work on s0 will block until ``s1`` has finished, + preventing further overlapping of communication and computation. If you + place the wait too late, you will use more memory than is strictly + necessary (as you are keeping ``x`` live for longer.) For a concrete + example of how this guidance can be applied in practice, see this post: + `FSDP and CUDACachingAllocator + `_. + """ + ... + def refine_names(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + def relu(self) -> Tensor: ... + def relu_(self) -> Tensor: ... + @overload + def remainder(self, other: Tensor) -> Tensor: + r""" + remainder(divisor) -> Tensor + + See :func:`torch.remainder` + """ + ... + @overload + def remainder(self, other: Union[Number, _complex]) -> Tensor: + r""" + remainder(divisor) -> Tensor + + See :func:`torch.remainder` + """ + ... + @overload + def remainder_(self, other: Tensor) -> Tensor: + r""" + remainder_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.remainder` + """ + ... + @overload + def remainder_(self, other: Union[Number, _complex]) -> Tensor: + r""" + remainder_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.remainder` + """ + ... + def rename(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ... + def rename_(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ... + def renorm(self, p: Union[Number, _complex], dim: _int, maxnorm: Union[Number, _complex]) -> Tensor: + r""" + renorm(p, dim, maxnorm) -> Tensor + + See :func:`torch.renorm` + """ + ... + def renorm_(self, p: Union[Number, _complex], dim: _int, maxnorm: Union[Number, _complex]) -> Tensor: + r""" + renorm_(p, dim, maxnorm) -> Tensor + + In-place version of :meth:`~Tensor.renorm` + """ + ... + @overload + def repeat(self, repeats: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + repeat(*sizes) -> Tensor + + Repeats this tensor along the specified dimensions. + + Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + + .. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. + + Args: + sizes (torch.Size or int...): The number of times to repeat this tensor along each + dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) + """ + ... 
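A minimal sketch of the distinction drawn in the warning above (an illustrative aside, assuming a standard ``torch`` install): ``repeat`` tiles the whole tensor in the manner of ``numpy.tile``, whereas ``repeat_interleave`` repeats individual elements in the manner of ``numpy.repeat``::

    >>> import torch
    >>> x = torch.tensor([1, 2, 3])
    >>> x.repeat(2)              # tiles the whole tensor
    tensor([1, 2, 3, 1, 2, 3])
    >>> x.repeat_interleave(2)   # repeats each element
    tensor([1, 1, 2, 2, 3, 3])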
+ @overload + def repeat(self, *repeats: _int) -> Tensor: + r""" + repeat(*sizes) -> Tensor + + Repeats this tensor along the specified dimensions. + + Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + + .. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. + + Args: + sizes (torch.Size or int...): The number of times to repeat this tensor along each + dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) + """ + ... + @overload + def repeat_interleave(self, repeats: Tensor, dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + + See :func:`torch.repeat_interleave`. + """ + ... + @overload + def repeat_interleave(self, repeats: Union[_int, SymInt], dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + + See :func:`torch.repeat_interleave`. + """ + ... + def requires_grad_(self, mode: _bool = True) -> Tensor: + r""" + requires_grad_(requires_grad=True) -> Tensor + + Change if autograd should record operations on this tensor: sets this tensor's + :attr:`requires_grad` attribute in-place. Returns this tensor. + + :func:`requires_grad_`'s main use case is to tell autograd to begin recording + operations on a Tensor ``tensor``. If ``tensor`` has ``requires_grad=False`` + (because it was obtained through a DataLoader, or required preprocessing or + initialization), ``tensor.requires_grad_()`` makes it so that autograd will + begin to record operations on ``tensor``. + + Args: + requires_grad (bool): If autograd should record operations on this tensor. + Default: ``True``. + + Example:: + + >>> # Let's say we want to preprocess some saved weights and use + >>> # the result as new weights. + >>> saved_weights = [0.1, 0.2, 0.3, 0.25] + >>> loaded_weights = torch.tensor(saved_weights) + >>> weights = preprocess(loaded_weights) # some function + >>> weights + tensor([-0.5503, 0.4926, -2.1158, -0.8303]) + + >>> # Now, start to record operations done to weights + >>> weights.requires_grad_() + >>> out = weights.pow(2).sum() + >>> out.backward() + >>> weights.grad + tensor([-1.1007, 0.9853, -4.2316, -1.6606]) + """ + ... + @overload + def reshape(self, shape: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + reshape(*shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`self` + but with the specified shape. This method returns a view if :attr:`shape` is + compatible with the current shape. See :meth:`torch.Tensor.view` on when it is + possible to return a view. + + See :func:`torch.reshape` + + Args: + shape (tuple of ints or int...): the desired shape + """ + ... + @overload + def reshape(self, *shape: _int) -> Tensor: + r""" + reshape(*shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`self` + but with the specified shape. This method returns a view if :attr:`shape` is + compatible with the current shape. See :meth:`torch.Tensor.view` on when it is + possible to return a view. 
+ + See :func:`torch.reshape` + + Args: + shape (tuple of ints or int...): the desired shape + """ + ... + def reshape_as(self, other: Tensor) -> Tensor: + r""" + reshape_as(other) -> Tensor + + Returns this tensor as the same shape as :attr:`other`. + ``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``. + This method returns a view if ``other.sizes()`` is compatible with the current + shape. See :meth:`torch.Tensor.view` on when it is possible to return a view. + + Please see :meth:`reshape` for more information about ``reshape``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same shape + as :attr:`other`. + """ + ... + @overload + def resize_(self, size: Sequence[Union[_int, SymInt]], *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + + Resizes :attr:`self` tensor to the specified size. If the number of elements is + larger than the current storage size, then the underlying storage is resized + to fit the new number of elements. If the number of elements is smaller, the + underlying storage is not changed. Existing elements are preserved but any new + memory is uninitialized. + + .. warning:: + + This is a low-level method. The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + + .. note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + + Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + + Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) + """ + ... + @overload + def resize_(self, *size: _int, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + + Resizes :attr:`self` tensor to the specified size. If the number of elements is + larger than the current storage size, then the underlying storage is resized + to fit the new number of elements. If the number of elements is smaller, the + underlying storage is not changed. Existing elements are preserved but any new + memory is uninitialized. + + .. warning:: + + This is a low-level method. The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + + .. 
note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + + Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + + Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) + """ + ... + def resize_as_(self, the_template: Tensor, *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + resize_as_(tensor, memory_format=torch.contiguous_format) -> Tensor + + Resizes the :attr:`self` tensor to be the same size as the specified + :attr:`tensor`. This is equivalent to ``self.resize_(tensor.size())``. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``tensor.size()``. + """ + ... + def resize_as_sparse_(self, the_template: Tensor) -> Tensor: ... + def resolve_conj(self) -> Tensor: + r""" + resolve_conj() -> Tensor + + See :func:`torch.resolve_conj` + """ + ... + def resolve_neg(self) -> Tensor: + r""" + resolve_neg() -> Tensor + + See :func:`torch.resolve_neg` + """ + ... + def retain_grad(self) -> None: + r""" + retain_grad() -> None + + Enables this Tensor to have their :attr:`grad` populated during + :func:`backward`. This is a no-op for leaf tensors. + """ + ... + def roll(self, shifts: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]], dims: Union[_int, _size] = ()) -> Tensor: + r""" + roll(shifts, dims) -> Tensor + + See :func:`torch.roll` + """ + ... + def rot90(self, k: _int = 1, dims: _size = (0,1)) -> Tensor: + r""" + rot90(k, dims) -> Tensor + + See :func:`torch.rot90` + """ + ... + @overload + def round(self) -> Tensor: + r""" + round(decimals=0) -> Tensor + + See :func:`torch.round` + """ + ... + @overload + def round(self, *, decimals: _int) -> Tensor: + r""" + round(decimals=0) -> Tensor + + See :func:`torch.round` + """ + ... + @overload + def round_(self) -> Tensor: + r""" + round_(decimals=0) -> Tensor + + In-place version of :meth:`~Tensor.round` + """ + ... + @overload + def round_(self, *, decimals: _int) -> Tensor: + r""" + round_(decimals=0) -> Tensor + + In-place version of :meth:`~Tensor.round` + """ + ... + def row_indices(self) -> Tensor: ... + def rsqrt(self) -> Tensor: + r""" + rsqrt() -> Tensor + + See :func:`torch.rsqrt` + """ + ... + def rsqrt_(self) -> Tensor: + r""" + rsqrt_() -> Tensor + + In-place version of :meth:`~Tensor.rsqrt` + """ + ... + @overload + def scatter(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... + @overload + def scatter(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... 
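A brief illustrative aside on the out-of-place/in-place relationship noted above (assuming a standard ``torch`` install, not part of the generated stub): ``scatter`` returns a new tensor and leaves the base tensor unchanged, whereas ``scatter_`` would modify it in place::

    >>> import torch
    >>> base = torch.zeros(3, 5, dtype=torch.int64)
    >>> index = torch.tensor([[0, 1, 2, 0, 1]])
    >>> src = torch.arange(1, 6).unsqueeze(0)
    >>> base.scatter(0, index, src)   # new tensor
    tensor([[1, 0, 0, 4, 0],
            [0, 2, 0, 0, 5],
            [0, 0, 3, 0, 0]])
    >>> int(base.sum())               # base itself is still all zeros
    0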
+ @overload + def scatter(self, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... + @overload + def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... + @overload + def scatter(self, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... + @overload + def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... + @overload + def scatter_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. 
Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + ... + @overload + def scatter_(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. 
warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + ... + @overload + def scatter_(self, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + ... + @overload + def scatter_(self, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. 
+ + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + ... + @overload + def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... 
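As a short illustrative aside (assuming a standard ``torch`` install): unlike plain ``scatter``, duplicate indices in ``scatter_add`` accumulate rather than overwrite, and the out-of-place form shown here leaves the base tensor unchanged::

    >>> import torch
    >>> index = torch.tensor([0, 1, 0, 1, 2])
    >>> src = torch.ones(5)
    >>> torch.zeros(3).scatter_add(0, index, src)
    tensor([2., 2., 1.])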
+ @overload + def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... + def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add_(dim, index, src) -> Tensor + + Adds all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor in a similar fashion as + :meth:`~torch.Tensor.scatter_`. For each value in :attr:`src`, it is added to + an index in :attr:`self` which is specified by its index in :attr:`src` + for ``dimension != dim`` and by the corresponding value in :attr:`index` for + ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + + :attr:`self`, :attr:`index` and :attr:`src` should have same number of + dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all + dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions + ``d != dim``. Note that ``index`` and ``src`` do not broadcast. + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and add, can be + either empty or of the same dimensionality as ``src``. When empty, the + operation returns ``self`` unchanged. + src (Tensor): the source elements to scatter and add + + Example:: + + >>> src = torch.ones((2, 5)) + >>> index = torch.tensor([[0, 1, 2, 0, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[1., 0., 0., 1., 1.], + [0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.]]) + >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[2., 0., 0., 1., 1.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 1., 1.]]) + """ + ... + def scatter_reduce(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool = True) -> Tensor: + r""" + scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + ... + def scatter_reduce_(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool = True) -> Tensor: + r""" + scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor + + Reduces all values from the :attr:`src` tensor to the indices specified in + the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction + defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, + :obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an + index in :attr:`self` which is specified by its index in :attr:`src` for + ``dimension != dim`` and by the corresponding value in :attr:`index` for + ``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self` + tensor are included in the reduction. + + :attr:`self`, :attr:`index` and :attr:`src` should all have + the same number of dimensions. 
It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the + output is given as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + .. warning:: + + This function is in beta and may change in the near future. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and reduce. + src (Tensor): the source elements to scatter and reduce + reduce (str): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + include_self (bool): whether elements from the :attr:`self` tensor are + included in the reduction + + Example:: + + >>> src = torch.tensor([1., 2., 3., 4., 5., 6.]) + >>> index = torch.tensor([0, 1, 0, 1, 2, 1]) + >>> input = torch.tensor([1., 2., 3., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum") + tensor([5., 14., 8., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False) + tensor([4., 12., 5., 4.]) + >>> input2 = torch.tensor([5., 4., 3., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax") + tensor([5., 6., 5., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False) + tensor([3., 6., 5., 2.]) + """ + ... + @overload + def select(self, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select(dim, index) -> Tensor + + See :func:`torch.select` + """ + ... + @overload + def select(self, dim: Union[str, ellipsis, None], index: _int) -> Tensor: + r""" + select(dim, index) -> Tensor + + See :func:`torch.select` + """ + ... + def select_scatter(self, src: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select_scatter(src, dim, index) -> Tensor + + See :func:`torch.select_scatter` + """ + ... + @overload + def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: + r""" + set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + + Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, + :attr:`self` tensor will share the same storage and have the same size and + strides as :attr:`source`. Changes to elements in one tensor will be reflected + in the other. + + If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying + storage, offset, size, and stride. + + Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. + """ + ... + @overload + def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: + r""" + set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + + Sets the underlying storage, size, and strides. 
If :attr:`source` is a tensor, + :attr:`self` tensor will share the same storage and have the same size and + strides as :attr:`source`. Changes to elements in one tensor will be reflected + in the other. + + If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying + storage, offset, size, and stride. + + Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. + """ + ... + def sgn(self) -> Tensor: + r""" + sgn() -> Tensor + + See :func:`torch.sgn` + """ + ... + def sgn_(self) -> Tensor: + r""" + sgn_() -> Tensor + + In-place version of :meth:`~Tensor.sgn` + """ + ... + def short(self) -> Tensor: + r""" + short(memory_format=torch.preserve_format) -> Tensor + + ``self.short()`` is equivalent to ``self.to(torch.int16)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... + def sigmoid(self) -> Tensor: + r""" + sigmoid() -> Tensor + + See :func:`torch.sigmoid` + """ + ... + def sigmoid_(self) -> Tensor: + r""" + sigmoid_() -> Tensor + + In-place version of :meth:`~Tensor.sigmoid` + """ + ... + def sign(self) -> Tensor: + r""" + sign() -> Tensor + + See :func:`torch.sign` + """ + ... + def sign_(self) -> Tensor: + r""" + sign_() -> Tensor + + In-place version of :meth:`~Tensor.sign` + """ + ... + def signbit(self) -> Tensor: + r""" + signbit() -> Tensor + + See :func:`torch.signbit` + """ + ... + def sin(self) -> Tensor: + r""" + sin() -> Tensor + + See :func:`torch.sin` + """ + ... + def sin_(self) -> Tensor: + r""" + sin_() -> Tensor + + In-place version of :meth:`~Tensor.sin` + """ + ... + def sinc(self) -> Tensor: + r""" + sinc() -> Tensor + + See :func:`torch.sinc` + """ + ... + def sinc_(self) -> Tensor: + r""" + sinc_() -> Tensor + + In-place version of :meth:`~Tensor.sinc` + """ + ... + def sinh(self) -> Tensor: + r""" + sinh() -> Tensor + + See :func:`torch.sinh` + """ + ... + def sinh_(self) -> Tensor: + r""" + sinh_() -> Tensor + + In-place version of :meth:`~Tensor.sinh` + """ + ... + @overload + def size(self, dim: None = None) -> Size: + r""" + size(dim=None) -> torch.Size or int + + Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, + the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. + If ``dim`` is specified, returns an int holding the size of that dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + + Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + """ + ... + @overload + def size(self, dim: _int) -> _int: + r""" + size(dim=None) -> torch.Size or int + + Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, + the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. + If ``dim`` is specified, returns an int holding the size of that dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + + Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + """ + ... 
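Referring back to the ``set_`` entry above, a minimal doctest-style sketch of the storage-sharing behaviour it describes (an illustrative aside, assuming a standard ``torch`` install)::

    >>> import torch
    >>> a = torch.zeros(3)
    >>> b = torch.tensor([])   # same dtype as ``a``
    >>> b.set_(a)              # ``b`` now shares ``a``'s storage, size and strides
    tensor([0., 0., 0.])
    >>> b.fill_(7.)
    tensor([7., 7., 7.])
    >>> a                      # the write through ``b`` is visible in ``a``
    tensor([7., 7., 7.])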
+ def slice_inverse(self, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1) -> Tensor: ... + def slice_scatter(self, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1) -> Tensor: + r""" + slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor + + See :func:`torch.slice_scatter` + """ + ... + def slogdet(self) -> torch.return_types.slogdet: + r""" + slogdet() -> (Tensor, Tensor) + + See :func:`torch.slogdet` + """ + ... + def smm(self, mat2: Tensor) -> Tensor: + r""" + smm(mat) -> Tensor + + See :func:`torch.smm` + """ + ... + @overload + def softmax(self, dim: _int, dtype: Optional[_dtype] = None) -> Tensor: + r""" + softmax(dim) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + ... + @overload + def softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + softmax(dim) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + ... + @overload + def sort(self, *, stable: Optional[_bool], dim: _int = -1, descending: _bool = False) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + ... + @overload + def sort(self, dim: _int = -1, descending: _bool = False) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + ... + @overload + def sort(self, *, stable: Optional[_bool], dim: Union[str, ellipsis, None], descending: _bool = False) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + ... + @overload + def sort(self, dim: Union[str, ellipsis, None], descending: _bool = False) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + ... + def sparse_dim(self) -> _int: + r""" + sparse_dim() -> int + + Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. + + .. note:: + Returns ``0`` if :attr:`self` is not a sparse tensor. + + See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. + """ + ... + def sparse_mask(self, mask: Tensor) -> Tensor: + r""" + sparse_mask(mask) -> Tensor + + Returns a new :ref:`sparse tensor ` with values from a + strided tensor :attr:`self` filtered by the indices of the sparse + tensor :attr:`mask`. The values of :attr:`mask` sparse tensor are + ignored. :attr:`self` and :attr:`mask` tensors must have the same + shape. + + .. note:: + + The returned sparse tensor might contain duplicate values if :attr:`mask` + is not coalesced. It is therefore advisable to pass ``mask.coalesce()`` + if such behavior is not desired. + + .. note:: + + The returned sparse tensor has the same indices as the sparse tensor + :attr:`mask`, even when the corresponding values in :attr:`self` are + zeros. + + Args: + mask (Tensor): a sparse tensor whose indices are used as a filter + + Example:: + + >>> nse = 5 + >>> dims = (5, 5, 2, 2) + >>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)), + ... 
torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse) + >>> V = torch.randn(nse, dims[2], dims[3]) + >>> S = torch.sparse_coo_tensor(I, V, dims).coalesce() + >>> D = torch.randn(dims) + >>> D.sparse_mask(S) + tensor(indices=tensor([[0, 0, 0, 2], + [0, 1, 4, 3]]), + values=tensor([[[ 1.6550, 0.2397], + [-0.1611, -0.0779]], + + [[ 0.2326, -1.0558], + [ 1.4711, 1.9678]], + + [[-0.5138, -0.0411], + [ 1.9417, 0.5158]], + + [[ 0.0793, 0.0036], + [-0.2569, -0.1055]]]), + size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo) + """ + ... + def sparse_resize_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: + r""" + sparse_resize_(size, sparse_dim, dense_dim) -> Tensor + + Resizes :attr:`self` :ref:`sparse tensor ` to the desired + size and the number of sparse and dense dimensions. + + .. note:: + If the number of specified elements in :attr:`self` is zero, then + :attr:`size`, :attr:`sparse_dim`, and :attr:`dense_dim` can be any + size and positive integers such that ``len(size) == sparse_dim + + dense_dim``. + + If :attr:`self` specifies one or more elements, however, then each + dimension in :attr:`size` must not be smaller than the corresponding + dimension of :attr:`self`, :attr:`sparse_dim` must equal the number + of sparse dimensions in :attr:`self`, and :attr:`dense_dim` must + equal the number of dense dimensions in :attr:`self`. + + .. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + + Args: + size (torch.Size): the desired size. If :attr:`self` is non-empty + sparse tensor, the desired size cannot be smaller than the + original size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions + """ + ... + def sparse_resize_and_clear_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: + r""" + sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor + + Removes all specified elements from a :ref:`sparse tensor + ` :attr:`self` and resizes :attr:`self` to the desired + size and the number of sparse and dense dimensions. + + .. warning: + Throws an error if :attr:`self` is not a sparse tensor. + + Args: + size (torch.Size): the desired size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions + """ + ... + @overload + def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ... + @overload + def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ... + def split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... + def sqrt(self) -> Tensor: + r""" + sqrt() -> Tensor + + See :func:`torch.sqrt` + """ + ... + def sqrt_(self) -> Tensor: + r""" + sqrt_() -> Tensor + + In-place version of :meth:`~Tensor.sqrt` + """ + ... + def square(self) -> Tensor: + r""" + square() -> Tensor + + See :func:`torch.square` + """ + ... + def square_(self) -> Tensor: + r""" + square_() -> Tensor + + In-place version of :meth:`~Tensor.square` + """ + ... + @overload + def squeeze(self) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + ... + @overload + def squeeze(self, dim: _int) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + ... + @overload + def squeeze(self, dim: _size) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + ... + @overload + def squeeze(self, *dim: _int) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + ... 
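A short illustrative aside on the ``squeeze`` overloads above and below (assuming a standard ``torch`` install): with no argument every size-1 dimension is removed, with a ``dim`` argument only that dimension is considered, and dimensions whose size is not 1 are left alone::

    >>> import torch
    >>> x = torch.zeros(2, 1, 3, 1)
    >>> x.squeeze().shape      # all size-1 dims removed
    torch.Size([2, 3])
    >>> x.squeeze(1).shape     # only dim 1 is considered
    torch.Size([2, 3, 1])
    >>> x.squeeze(0).shape     # dim 0 has size 2, so nothing changes
    torch.Size([2, 1, 3, 1])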
+ @overload + def squeeze(self, dim: Union[str, ellipsis, None]) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + ... + @overload + def squeeze_(self) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + ... + @overload + def squeeze_(self, dim: _int) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + ... + @overload + def squeeze_(self, dim: _size) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + ... + @overload + def squeeze_(self, *dim: _int) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + ... + @overload + def squeeze_(self, dim: Union[str, ellipsis, None]) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + ... + def sspaddmm(self, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.sspaddmm` + """ + ... + @overload + def std(self, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + ... + @overload + def std(self, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + ... + @overload + def std(self, unbiased: _bool = True) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + ... + @overload + def std(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + ... + @overload + def std(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + ... + def untyped_storage(self) -> UntypedStorage: ... + def storage_offset(self) -> _int: + r""" + storage_offset() -> int + + Returns :attr:`self` tensor's offset in the underlying storage in terms of + number of storage elements (not bytes). + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5]) + >>> x.storage_offset() + 0 + >>> x[3:].storage_offset() + 3 + """ + ... + def storage_type(self) -> Storage: ... + @overload + def stride(self, dim: None = None) -> Tuple[_int, ...]: + r""" + stride(dim) -> tuple or int + + Returns the stride of :attr:`self` tensor. + + Stride is the jump necessary to go from one element to the next one in the + specified dimension :attr:`dim`. A tuple of all strides is returned when no + argument is passed in. Otherwise, an integer value is returned as the stride in + the particular dimension :attr:`dim`. + + Args: + dim (int, optional): the desired dimension in which stride is required + + Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + """ + ... + @overload + def stride(self, dim: _int) -> _int: + r""" + stride(dim) -> tuple or int + + Returns the stride of :attr:`self` tensor. 
+ + Stride is the jump necessary to go from one element to the next one in the + specified dimension :attr:`dim`. A tuple of all strides is returned when no + argument is passed in. Otherwise, an integer value is returned as the stride in + the particular dimension :attr:`dim`. + + Args: + dim (int, optional): the desired dimension in which stride is required + + Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + """ + ... + def sub(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + sub(other, *, alpha=1) -> Tensor + + See :func:`torch.sub`. + """ + ... + def sub_(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], *, alpha: Optional[Union[Number, _complex]] = 1) -> Tensor: + r""" + sub_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.sub` + """ + ... + @overload + def subtract(self, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract(other, *, alpha=1) -> Tensor + + See :func:`torch.subtract`. + """ + ... + @overload + def subtract(self, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract(other, *, alpha=1) -> Tensor + + See :func:`torch.subtract`. + """ + ... + @overload + def subtract_(self, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.subtract`. + """ + ... + @overload + def subtract_(self, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.subtract`. + """ + ... + @overload + def sum(self, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + ... + @overload + def sum(self, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + ... + @overload + def sum(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + ... + @overload + def sum_to_size(self, size: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + sum_to_size(*size) -> Tensor + + Sum ``this`` tensor to :attr:`size`. + :attr:`size` must be broadcastable to ``this`` tensor size. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + """ + ... + @overload + def sum_to_size(self, *size: _int) -> Tensor: + r""" + sum_to_size(*size) -> Tensor + + Sum ``this`` tensor to :attr:`size`. + :attr:`size` must be broadcastable to ``this`` tensor size. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + """ + ... + def svd(self, some: _bool = True, compute_uv: _bool = True) -> torch.return_types.svd: + r""" + svd(some=True, compute_uv=True) -> (Tensor, Tensor, Tensor) + + See :func:`torch.svd` + """ + ... + def swapaxes(self, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(axis0, axis1) -> Tensor + + See :func:`torch.swapaxes` + """ + ... 
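+    # Minimal usage sketch for sum_to_size above, which has no inline example in
+    # this stub; the target size must be broadcastable to the tensor's size:
+    #
+    #   >>> t = torch.ones(2, 3)
+    #   >>> t.sum_to_size(1, 3)
+    #   tensor([[2., 2., 2.]])
+    #   >>> t.sum_to_size(2, 1)
+    #   tensor([[3.],
+    #           [3.]])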
+ def swapaxes_(self, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes_(axis0, axis1) -> Tensor + + In-place version of :meth:`~Tensor.swapaxes` + """ + ... + def swapdims(self, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(dim0, dim1) -> Tensor + + See :func:`torch.swapdims` + """ + ... + def swapdims_(self, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims_(dim0, dim1) -> Tensor + + In-place version of :meth:`~Tensor.swapdims` + """ + ... + def t(self) -> Tensor: + r""" + t() -> Tensor + + See :func:`torch.t` + """ + ... + def t_(self) -> Tensor: + r""" + t_() -> Tensor + + In-place version of :meth:`~Tensor.t` + """ + ... + def take(self, index: Tensor) -> Tensor: + r""" + take(indices) -> Tensor + + See :func:`torch.take` + """ + ... + def take_along_dim(self, indices: Tensor, dim: Optional[_int] = None) -> Tensor: + r""" + take_along_dim(indices, dim) -> Tensor + + See :func:`torch.take_along_dim` + """ + ... + def tan(self) -> Tensor: + r""" + tan() -> Tensor + + See :func:`torch.tan` + """ + ... + def tan_(self) -> Tensor: + r""" + tan_() -> Tensor + + In-place version of :meth:`~Tensor.tan` + """ + ... + def tanh(self) -> Tensor: + r""" + tanh() -> Tensor + + See :func:`torch.tanh` + """ + ... + def tanh_(self) -> Tensor: + r""" + tanh_() -> Tensor + + In-place version of :meth:`~Tensor.tanh` + """ + ... + @overload + def tensor_split(self, indices: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + See :func:`torch.tensor_split` + """ + ... + @overload + def tensor_split(self, tensor_indices_or_sections: Tensor, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + See :func:`torch.tensor_split` + """ + ... + @overload + def tensor_split(self, sections: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + See :func:`torch.tensor_split` + """ + ... + @overload + def tile(self, dims: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + tile(dims) -> Tensor + + See :func:`torch.tile` + """ + ... + @overload + def tile(self, *dims: _int) -> Tensor: + r""" + tile(dims) -> Tensor + + See :func:`torch.tile` + """ + ... + @overload + def to(self, dtype: _dtype, non_blocking: _bool = False, copy: _bool = False, *, memory_format: Optional[torch.memory_format] = None) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. 
If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking`, tries to convert asynchronously with respect to + the host if possible, e.g., converting a CPU Tensor with pinned memory to a + CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. When :attr:`non_blocking`, tries to convert + asynchronously with respect to the host if possible, e.g., converting a CPU + Tensor with pinned memory to a CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + ... + @overload + def to(self, device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None, non_blocking: _bool = False, copy: _bool = False, *, memory_format: Optional[torch.memory_format] = None) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking`, tries to convert asynchronously with respect to + the host if possible, e.g., converting a CPU Tensor with pinned memory to a + CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. 
method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. When :attr:`non_blocking`, tries to convert + asynchronously with respect to the host if possible, e.g., converting a CPU + Tensor with pinned memory to a CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + ... + @overload + def to(self, other: Tensor, non_blocking: _bool = False, copy: _bool = False, *, memory_format: Optional[torch.memory_format] = None) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking`, tries to convert asynchronously with respect to + the host if possible, e.g., converting a CPU Tensor with pinned memory to a + CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. When :attr:`non_blocking`, tries to convert + asynchronously with respect to the host if possible, e.g., converting a CPU + Tensor with pinned memory to a CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. 
+ + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + ... + def to_dense(self, dtype: Optional[_dtype] = None, *, masked_grad: Optional[_bool] = None) -> Tensor: + r""" + to_dense(dtype=None, *, masked_grad=True) -> Tensor + + Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. + + Keyword args: + {dtype} + masked_grad (bool, optional): If set to ``True`` (default) and + :attr:`self` has a sparse layout then the backward of + :meth:`to_dense` returns ``grad.sparse_mask(self)``. + + Example:: + + >>> s = torch.sparse_coo_tensor( + ... torch.tensor([[1, 1], + ... [0, 2]]), + ... torch.tensor([9, 10]), + ... size=(3, 3)) + >>> s.to_dense() + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + """ + ... + def to_mkldnn(self, dtype: Optional[_dtype] = None) -> Tensor: + r""" + to_mkldnn() -> Tensor + Returns a copy of the tensor in ``torch.mkldnn`` layout. + """ + ... + def to_padded_tensor(self, padding: _float, output_size: Optional[Sequence[Union[_int, SymInt]]] = None) -> Tensor: + r""" + to_padded_tensor(padding, output_size=None) -> Tensor + See :func:`to_padded_tensor` + """ + ... + @overload + def to_sparse(self, *, layout: Optional[_layout] = None, blocksize: Optional[Union[_int, _size]] = None, dense_dim: Optional[_int] = None) -> Tensor: + r""" + to_sparse(sparseDims) -> Tensor + + Returns a sparse copy of the tensor. PyTorch supports sparse tensors in + :ref:`coordinate format `. + + Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + + Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + + .. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + + Returns a sparse tensor with the specified layout and blocksize. If + the :attr:`self` is strided, the number of dense dimensions could be + specified, and a hybrid sparse tensor will be created, with + `dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch + dimension. + + .. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + + Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. 
For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + """ + ... + @overload + def to_sparse(self, sparse_dim: _int) -> Tensor: + r""" + to_sparse(sparseDims) -> Tensor + + Returns a sparse copy of the tensor. PyTorch supports sparse tensors in + :ref:`coordinate format `. + + Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + + Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + + .. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + + Returns a sparse tensor with the specified layout and blocksize. If + the :attr:`self` is strided, the number of dense dimensions could be + specified, and a hybrid sparse tensor will be created, with + `dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch + dimension. + + .. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + + Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. 
This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + """ + ... + def to_sparse_bsc(self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None) -> Tensor: + r""" + to_sparse_bsc(blocksize, dense_dim) -> Tensor + + Convert a tensor to a block sparse column (BSC) storage format of + given blocksize. If the :attr:`self` is strided, then the number of + dense dimensions could be specified, and a hybrid BSC tensor will be + created, with `dense_dim` dense dimensions and `self.dim() - 2 - + dense_dim` batch dimension. + + Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSC tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsc = sparse.to_sparse_bsc((5, 5)) + >>> sparse_bsc.row_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsc((2, 1), 1) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 1, 0]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsc) + """ + ... + def to_sparse_bsr(self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None) -> Tensor: + r""" + to_sparse_bsr(blocksize, dense_dim) -> Tensor + + Convert a tensor to a block sparse row (BSR) storage format of given + blocksize. If the :attr:`self` is strided, then the number of dense + dimensions could be specified, and a hybrid BSR tensor will be + created, with `dense_dim` dense dimensions and `self.dim() - 2 - + dense_dim` batch dimension. + + Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSR tensor. 
This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsr = sparse.to_sparse_bsr((5, 5)) + >>> sparse_bsr.col_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsr((2, 1), 1) + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsr) + """ + ... + def to_sparse_csc(self, dense_dim: Optional[_int] = None) -> Tensor: + r""" + to_sparse_csc() -> Tensor + + Convert a tensor to compressed column storage (CSC) format. Except + for strided tensors, only works with 2D tensors. If the :attr:`self` + is strided, then the number of dense dimensions could be specified, + and a hybrid CSC tensor will be created, with `dense_dim` dense + dimensions and `self.dim() - 2 - dense_dim` batch dimension. + + Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csc() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csc(dense_dim=2) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csc) + """ + ... + def to_sparse_csr(self, dense_dim: Optional[_int] = None) -> Tensor: + r""" + to_sparse_csr(dense_dim=None) -> Tensor + + Convert a tensor to compressed row storage format (CSR). Except for + strided tensors, only works with 2D tensors. If the :attr:`self` is + strided, then the number of dense dimensions could be specified, and a + hybrid CSR tensor will be created, with `dense_dim` dense dimensions + and `self.dim() - 2 - dense_dim` batch dimension. + + Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csr() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csr(dense_dim=2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csr) + """ + ... + def tolist(self) -> List: + r""" + tolist() -> list or number + + Returns the tensor as a (nested) list. For scalars, a standard + Python number is returned, just like with :meth:`~Tensor.item`. + Tensors are automatically moved to the CPU first if necessary. + + This operation is not differentiable. + + Examples:: + + >>> a = torch.randn(2, 2) + >>> a.tolist() + [[0.012766935862600803, 0.5415473580360413], + [-0.08909505605697632, 0.7729271650314331]] + >>> a[0,0].tolist() + 0.012766935862600803 + """ + ... 
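+    # Round-trip sketch tying together to_sparse_csr and to_dense documented
+    # above; the input is deterministic, so the check holds on any build:
+    #
+    #   >>> d = torch.eye(3)
+    #   >>> torch.equal(d.to_sparse_csr().to_dense(), d)
+    #   True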
+ def topk(self, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True) -> torch.return_types.topk: + r""" + topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor) + + See :func:`torch.topk` + """ + ... + def trace(self) -> Tensor: + r""" + trace() -> Tensor + + See :func:`torch.trace` + """ + ... + @overload + def transpose(self, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(dim0, dim1) -> Tensor + + See :func:`torch.transpose` + """ + ... + @overload + def transpose(self, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]) -> Tensor: + r""" + transpose(dim0, dim1) -> Tensor + + See :func:`torch.transpose` + """ + ... + def transpose_(self, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose_(dim0, dim1) -> Tensor + + In-place version of :meth:`~Tensor.transpose` + """ + ... + def triangular_solve(self, A: Tensor, upper: _bool = True, transpose: _bool = False, unitriangular: _bool = False) -> torch.return_types.triangular_solve: + r""" + triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + + See :func:`torch.triangular_solve` + """ + ... + def tril(self, diagonal: _int = 0) -> Tensor: + r""" + tril(diagonal=0) -> Tensor + + See :func:`torch.tril` + """ + ... + def tril_(self, diagonal: _int = 0) -> Tensor: + r""" + tril_(diagonal=0) -> Tensor + + In-place version of :meth:`~Tensor.tril` + """ + ... + def triu(self, diagonal: _int = 0) -> Tensor: + r""" + triu(diagonal=0) -> Tensor + + See :func:`torch.triu` + """ + ... + def triu_(self, diagonal: _int = 0) -> Tensor: + r""" + triu_(diagonal=0) -> Tensor + + In-place version of :meth:`~Tensor.triu` + """ + ... + def true_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor] = None) -> Tensor: + r""" + true_divide(value) -> Tensor + + See :func:`torch.true_divide` + """ + ... + def true_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: + r""" + true_divide_(value) -> Tensor + + In-place version of :meth:`~Tensor.true_divide_` + """ + ... + def trunc(self) -> Tensor: + r""" + trunc() -> Tensor + + See :func:`torch.trunc` + """ + ... + def trunc_(self) -> Tensor: + r""" + trunc_() -> Tensor + + In-place version of :meth:`~Tensor.trunc` + """ + ... + @overload + def type(self, dtype: None = None, non_blocking: _bool = False) -> str: + r""" + type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor + Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + ... + @overload + def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: + r""" + type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor + Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. 
+ + Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + ... + def type_as(self, other: Tensor) -> Tensor: + r""" + type_as(tensor) -> Tensor + + Returns this tensor cast to the type of the given tensor. + + This is a no-op if the tensor is already of the correct type. This is + equivalent to ``self.type(tensor.type())`` + + Args: + tensor (Tensor): the tensor which has the desired type + """ + ... + @overload + def unbind(self, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unbind(dim=0) -> seq + + See :func:`torch.unbind` + """ + ... + @overload + def unbind(self, dim: Union[str, ellipsis, None]) -> Tuple[Tensor, ...]: + r""" + unbind(dim=0) -> seq + + See :func:`torch.unbind` + """ + ... + @overload + def unflatten(self, dim: Union[str, ellipsis, None], sizes: Sequence[Union[_int, SymInt]], names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + @overload + def unflatten(self, dim: _int, sizes: Sequence[Union[_int, SymInt]]) -> Tensor: ... + def unfold(self, dimension: _int, size: _int, step: _int) -> Tensor: + r""" + unfold(dimension, size, step) -> Tensor + + Returns a view of the original tensor which contains all slices of size :attr:`size` from + :attr:`self` tensor in the dimension :attr:`dimension`. + + Step between two slices is given by :attr:`step`. + + If `sizedim` is the size of dimension :attr:`dimension` for :attr:`self`, the size of + dimension :attr:`dimension` in the returned tensor will be + `(sizedim - size) / step + 1`. + + An additional dimension of size :attr:`size` is appended in the returned tensor. + + Args: + dimension (int): dimension in which unfolding happens + size (int): the size of each slice that is unfolded + step (int): the step between each slice + + Example:: + + >>> x = torch.arange(1., 8) + >>> x + tensor([ 1., 2., 3., 4., 5., 6., 7.]) + >>> x.unfold(0, 2, 1) + tensor([[ 1., 2.], + [ 2., 3.], + [ 3., 4.], + [ 4., 5.], + [ 5., 6.], + [ 6., 7.]]) + >>> x.unfold(0, 2, 2) + tensor([[ 1., 2.], + [ 3., 4.], + [ 5., 6.]]) + """ + ... + def uniform_(self, from_: _float = 0, to: _float = 1, *, generator: Optional[Generator] = None) -> Tensor: + r""" + uniform_(from=0, to=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the continuous uniform + distribution: + + .. math:: + f(x) = \dfrac{1}{\text{to} - \text{from}} + """ + ... + def unsafe_chunk(self, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unsafe_chunk(chunks, dim=0) -> List of Tensors + + See :func:`torch.unsafe_chunk` + """ + ... + def unsafe_split(self, split_size: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unsafe_split(split_size, dim=0) -> List of Tensors + + See :func:`torch.unsafe_split` + """ + ... + def unsafe_split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... + def unsqueeze(self, dim: _int) -> Tensor: + r""" + unsqueeze(dim) -> Tensor + + See :func:`torch.unsqueeze` + """ + ... + def unsqueeze_(self, dim: _int) -> Tensor: + r""" + unsqueeze_(dim) -> Tensor + + In-place version of :meth:`~Tensor.unsqueeze` + """ + ... 
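+    # Quick shape sketch for unsqueeze above (its docstring only points at
+    # torch.unsqueeze):
+    #
+    #   >>> x = torch.tensor([1, 2, 3])
+    #   >>> x.unsqueeze(0).shape
+    #   torch.Size([1, 3])
+    #   >>> x.unsqueeze(-1).shape
+    #   torch.Size([3, 1])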
+ def values(self) -> Tensor: + r""" + values() -> Tensor + + Return the values tensor of a :ref:`sparse COO tensor `. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See also :meth:`Tensor.indices`. + + .. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. + """ + ... + @overload + def var(self, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + ... + @overload + def var(self, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + ... + @overload + def var(self, unbiased: _bool = True) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + ... + @overload + def var(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + ... + @overload + def var(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + ... + def vdot(self, other: Tensor) -> Tensor: + r""" + vdot(other) -> Tensor + + See :func:`torch.vdot` + """ + ... + @overload + def view(self, dtype: _dtype) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. 
+ + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + ... + @overload + def view(self, size: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. 
For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. 
+ + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + ... + @overload + def view(self, *size: _int) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. 
For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + ... + def view_as(self, other: Tensor) -> Tensor: + r""" + view_as(other) -> Tensor + + View this tensor as the same size as :attr:`other`. + ``self.view_as(other)`` is equivalent to ``self.view(other.size())``. + + Please see :meth:`~Tensor.view` for more information about ``view``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. + """ + ... + @overload + def vsplit(self, sections: _int) -> Tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + ... + @overload + def vsplit(self, indices: _size) -> Tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + ... 
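+    # Shape-only sketch for view_as and the vsplit overloads in this block; both
+    # results are deterministic for the inputs shown:
+    #
+    #   >>> a = torch.arange(6)
+    #   >>> a.view_as(torch.empty(2, 3)).shape
+    #   torch.Size([2, 3])
+    #   >>> [t.shape for t in torch.arange(16).reshape(4, 4).vsplit(2)]
+    #   [torch.Size([2, 4]), torch.Size([2, 4])]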
+ @overload + def vsplit(self, *indices: _int) -> Tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + ... + @overload + def where(self, condition: Tensor, other: Tensor) -> Tensor: + r""" + where(condition, y) -> Tensor + + ``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. + See :func:`torch.where` + """ + ... + @overload + def where(self, condition: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + where(condition, y) -> Tensor + + ``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. + See :func:`torch.where` + """ + ... + @overload + def xlogy(self, other: Tensor) -> Tensor: + r""" + xlogy(other) -> Tensor + + See :func:`torch.xlogy` + """ + ... + @overload + def xlogy(self, other: Union[Number, _complex]) -> Tensor: + r""" + xlogy(other) -> Tensor + + See :func:`torch.xlogy` + """ + ... + @overload + def xlogy_(self, other: Tensor) -> Tensor: + r""" + xlogy_(other) -> Tensor + + In-place version of :meth:`~Tensor.xlogy` + """ + ... + @overload + def xlogy_(self, other: Union[Number, _complex]) -> Tensor: + r""" + xlogy_(other) -> Tensor + + In-place version of :meth:`~Tensor.xlogy` + """ + ... + def zero_(self) -> Tensor: + r""" + zero_() -> Tensor + + Fills :attr:`self` tensor with zeros. + """ + ... + +_TensorBase = TensorBase + +# Defined in torch/csrc/multiprocessing/init.cpp +def _multiprocessing_init() -> None: ... + +# Defined in torch/csrc/mps/Module.cpp +def _mps_deviceSynchronize() -> None: ... +def _mps_get_default_generator() -> Generator: ... +def _mps_emptyCache() -> None: ... +def _mps_setMemoryFraction(fraction: _float) -> None: ... +def _mps_currentAllocatedMemory() -> _int: ... +def _mps_driverAllocatedMemory() -> _int: ... +def _mps_is_available() -> _bool: ... +def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ... +def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ... +def _mps_profilerStopTrace() -> None: ... +def _mps_acquireEvent(enable_timing: _bool) -> _int: ... +def _mps_releaseEvent(event_id: _int) -> None: ... +def _mps_recordEvent(event_id: _int) -> None: ... +def _mps_waitForEvent(event_id: _int) -> None: ... +def _mps_synchronizeEvent(event_id: _int) -> None: ... +def _mps_queryEvent(event_id: _int) -> _bool: ... +def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ... + + +# Defined in torch/csrc/cuda/Module.cpp +def _cuda_getCurrentStream(device: _int) -> Tuple: ... +def _cuda_getCurrentRawStream(device: _int) -> _int: ... +def _cuda_getDefaultStream(device: _int) -> Tuple: ... +def _cuda_getCurrentBlasHandle() -> _int: ... +def _cuda_clearCublasWorkspaces() -> None: ... +def _cuda_setDevice(device: _int) -> None: ... +def _cuda_exchangeDevice(device: _int) -> _int: ... +def _cuda_maybeExchangeDevice(device: _int) -> _int: ... +def _cuda_getDevice() -> _int: ... +def _cuda_getDeviceCount() -> _int: ... +def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ... +def _cuda_get_sync_debug_mode() -> _int: ... +def _cuda_sleep(cycles: _int) -> None: ... +def _cuda_synchronize() -> None: ... +def _cuda_ipc_collect() -> None: ... +def _cuda_getArchFlags() -> Optional[str]: ... +def _cuda_init() -> None: ... +def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +def _cuda_getCompiledVersion() -> _int: ... +def _cuda_cudaHostAllocator() -> _int: ... 
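+# The _mps_*/_cuda_* names in this block are private C-extension bindings; user
+# code goes through the public wrappers in torch.backends.mps and torch.cuda.
+# For example (assuming a CUDA-enabled build), torch.cuda.synchronize() and
+# torch.cuda.empty_cache() ultimately call _cuda_synchronize() and
+# _cuda_emptyCache():
+#
+#   >>> if torch.cuda.is_available():
+#   ...     torch.cuda.synchronize()
+#   ...     torch.cuda.empty_cache()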
+def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... +def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... +def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ... +def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... +def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... +def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... +def _cuda_checkPoolLiveAllocations(device: _int, mempool_id: Tuple[_int, _int], expected_live_allocations: Set) -> _bool: ... +def _cuda_setCheckpointPoolState(device: _int, state: _cuda_CUDAAllocator_AllocatorState, stale_storages: List[_int], storages_to_add_deleters_to: List[_int]) -> None: ... +def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ... +def _cuda_emptyCache() -> None: ... +def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ... +def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ... +def _cuda_resetPeakMemoryStats(device: _int) -> None: ... +def _cuda_memorySnapshot() -> Dict[str, Any]: ... +def _cuda_record_memory_history_legacy( + enabled: _bool, + record_context: _bool, + record_context_cpp: _bool, + alloc_trace_max_entries: _int, + alloc_trace_record_context: _bool, +) -> None: ... +def _cuda_record_memory_history( + enabled: Optional[str], + context: Optional[str], + stacks: str, + max_entries +) -> None: ... +def _cuda_isHistoryEnabled() -> _bool: ... + +def _cuda_getAllocatorBackend() -> str: ... +class _cuda_CUDAAllocator_AllocatorState: + pass +def _cuda_getCheckpointState(device: _int, mempool: Tuple[_int, _int]) -> _cuda_CUDAAllocator_AllocatorState: ... +def _set_cached_tensors_enabled(enabled: _bool) -> None: ... +def _add_cached_tensor(t: Tensor) -> None: ... +def _remove_cached_tensor(t: Tensor) -> None: ... +def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ... +def _storage_Use_Count(storage_ptr: _int) -> _int: ... +def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ... +def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ... +def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ... + +class _cuda_CUDAAllocator: ... + +def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ... +def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ... +def _cuda_getAllocator() -> _cuda_CUDAAllocator: ... +def _cuda_lock_mutex() -> None: ... +def _cuda_unlock_mutex() -> None: ... +def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... +def _cuda_jiterator_compile_and_launch_kernel( + code_string: str, + kernel_name: str, + return_by_ref: _bool, + num_outputs: _int, + tensors: Tuple, + kwargs: Dict[str, Union[_int, _float, _bool]], +) -> Tensor: ... +def _cuda_get_cudnn_benchmark_limit() -> _int: ... +def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ... +def _cuda_get_conv_benchmark_empty_cache() -> _bool: ... +def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ... +def _nccl_version() -> _int: ... +def _nccl_version_suffix() -> bytes : ... +def _nccl_unique_id() -> bytes: ... +def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ... +def _nccl_reduce( + input: Sequence[Tensor], + output: Tensor, + root: _int, + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... 
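+# The _nccl_* bindings here back the collective helpers in torch.cuda.nccl
+# (all_reduce, broadcast, all_gather, ...); they are only usable on builds
+# compiled with NCCL support, e.g.:
+#
+#   >>> torch.cuda.nccl.version()   # assumes an NCCL-enabled build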
+def _nccl_all_reduce( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_broadcast( + input: Sequence[Tensor], + root: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_all_gather( + input: Sequence[Tensor], + output: Sequence[Tensor], + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_reduce_scatter( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _rocm_is_backward_pass() -> _bool: ... + +class _CudaDeviceProperties: + name: str + major: _int + minor: _int + multi_processor_count: _int + total_memory: _int + is_integrated: _int + is_multi_gpu_board: _int + max_threads_per_multi_processor: _int + gcnArchName: str + +# Functions related to SDPA +class _SDPAParams: + query: Tensor + key: Tensor + value: Tensor + attn_mask: Optional[Tensor] + dropout: _float + is_causal: _bool + def __init__( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor], + dropout: _float, + is_causal: _bool) -> None: ... + +class _SDPBackend(Enum): + ERROR = -1 + MATH = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + CUDNN_ATTENTION = 3 + +def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... +def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... + +# Defined in torch/csrc/cuda/python_comm.cpp +def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ... +def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ... +def _broadcast_coalesced( + tensors: List[Tensor], + devices: List[_int], + buffer_size: _int, +) -> List[List[Tensor]]: ... +def _scatter( + tensor: Tensor, + devices: List[_int], + chunk_sizes: Optional[List[_int]], + dim: _int, + streams: Optional[List[Stream]], +) -> List[Tensor]: ... +def _scatter_out( + tensor: Tensor, + out_tensors: List[Tensor], + dim: _int, + streams: Optional[List[Stream]], +) -> List[Tensor]: ... +def _gather( + tensors: List[Tensor], + dim: _int, + destination_index: Optional[_int], +) -> Tensor: ... +def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... + +# Defined in torch/csrc/cuda/Stream.cpp +class _CudaStreamBase(Stream): + stream_id: _int + device_index: _int + device_type: _int + + device: _device + cuda_stream: _int + priority: _int + + def __new__( + self, + priority: _int = 0, + stream_id: _int = 0, + device_index: _int = 0, + stream_ptr: _int = 0, + ) -> _CudaStreamBase: ... + def query(self) -> _bool: ... + def synchronize(self) -> None: ... + def priority_range(self) -> Tuple[_int, _int]: ... + +# Defined in torch/csrc/cuda/Event.cpp +class _CudaEventBase: + device: _device + cuda_event: _int + + def __new__( + cls, + enable_timing: _bool = False, + blocking: _bool = False, + interprocess: _bool = False, + ) -> _CudaEventBase: ... + @classmethod + def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ... + def record(self, stream: _CudaStreamBase) -> None: ... + def wait(self, stream: _CudaStreamBase) -> None: ... + def query(self) -> _bool: ... + def elapsed_time(self, other: _CudaEventBase) -> _float: ... + def synchronize(self) -> None: ... + def ipc_handle(self) -> bytes: ... 
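For orientation, _CudaStreamBase and _CudaEventBase above are the C++ bindings behind the public torch.cuda.Stream and torch.cuda.Event classes. A hedged sketch of that public surface (assumes a CUDA-enabled build with at least one visible device; not part of the vendored stubs):

import torch

# Illustrative only: time a matmul on a side stream via the public wrappers
# over _CudaStreamBase/_CudaEventBase.
if torch.cuda.is_available():
    side = torch.cuda.Stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    side.wait_stream(torch.cuda.current_stream())  # order after the allocations above
    with torch.cuda.stream(side):
        start.record()      # records on the current (side) stream
        c = a @ b
        end.record()
    end.synchronize()       # block until the second event has completed
    print(start.elapsed_time(end), "ms")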
+ +# Defined in torch/csrc/cuda/Graph.cpp +class _CUDAGraph: + def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ... + def capture_end(self) -> None: ... + def replay(self) -> None: ... + def reset(self) -> None: ... + def pool(self) -> Tuple[_int, _int]: ... + def enable_debug_mode(self) -> None: ... + def debug_dump(self, debug_path: str) -> None: ... + +def _cuda_isCurrentStreamCapturing() -> _bool: ... +def _graph_pool_handle() -> Tuple[_int, _int]: ... + +# Defined in torch/csrc/xpu/Module.cpp +def _xpu_setDevice(device: _int) -> None: ... +def _xpu_exchangeDevice(device: _int) -> _int: ... +def _xpu_maybeExchangeDevice(device: _int) -> _int: ... +def _xpu_getDevice() -> _int: ... +def _xpu_getDeviceCount() -> _int: ... +def _xpu_init() -> None: ... +def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +def _xpu_getCurrentStream(device: _int) -> Tuple: ... +def _xpu_getCurrentRawStream(device: _int) -> _int: ... +def _xpu_synchronize(device: _int) -> None: ... +def _xpu_emptyCache() -> None: ... + +class _XpuDeviceProperties: + name: str + platform_name: str + total_memory: _int + max_compute_units: _int + gpu_eu_count: _int + gpu_subslice_count: _int + max_work_group_size: _int + max_num_sub_groups: _int + sub_group_sizes: List[_int] + type: str + +# Defined in torch/csrc/xpu/Stream.cpp +class _XpuStreamBase(Stream): + stream_id: _int + device_index: _int + device_type: _int + + device: _device + sycl_queue: _int + priority: _int + + def __new__( + cls, + priority: _int = 0, + stream_id: _int = 0, + device_index: _int = 0, + device_type: _int = 0, + ) -> _XpuStreamBase: ... + def query(self) -> _bool: ... + def synchronize(self) -> None: ... + @staticmethod + def priority_range() -> Tuple: ... + +# Defined in torch/csrc/xpu/Event.cpp +class _XpuEventBase: + device: _device + sycl_event: _int + + def __new__(cls, enable_timing: _bool = False) -> _XpuEventBase: ... + def record(self, stream: _XpuEventBase) -> None: ... + def wait(self, stream: _XpuStreamBase) -> None: ... + def query(self) -> _bool: ... + def elapsed_time(self, other: _XpuEventBase) -> _float: ... + def synchronize(self) -> None: ... + +# Defined in torch/csrc/DataLoader.cpp +def _set_worker_signal_handlers( + *arg: Any, +) -> None: ... # THPModule_setWorkerSignalHandlers +def _set_worker_pids( + key: _int, + child_pids: Tuple[_int, ...], +) -> None: ... # THPModule_setWorkerPIDs +def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs +def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails + +# Defined in torch/csrc/jit/python/python_tracer.cpp +class TracingState: + def push_scope(self, scope_name: str) -> None: ... + def pop_scope(self) -> None: ... + def current_scope(self) -> str: ... + def set_graph(self, graph: Graph) -> None: ... + def graph(self) -> Graph: ... + +def _create_graph_by_tracing( + func: Callable[..., Any], + inputs: Any, + var_name_lookup_fn: Callable[[Tensor], str], + strict: Any, + force_outplace: Any, + self: Any = None, + argument_names: List[str] = [], +) -> Tuple[Graph, Stack]: ... +def _tracer_warn_use_python(): ... +def _get_tracing_state() -> TracingState: ... + +# Defined in torch/csrc/jit/python/python_ir.cpp +# Not actually defined in python_ir.cpp, not sure where they are. +class IValue: ... + +Stack = List[IValue] + +class JitType: + annotation_str: str + def isSubtypeOf(self, other: JitType) -> _bool: ... 
+ def with_dtype(self, dtype: _dtype) -> JitType: ... + def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ... + def kind(self) -> str: ... + def scalarType(self) -> Optional[str]: ... + def getElementType(self) -> JitType: ... + def dtype(self) -> Optional[_dtype]: ... + +class InferredType: + def __init__(self, arg: Union[JitType, str]): ... + def type(self) -> JitType: ... + def success(self) -> _bool: ... + def reason(self) -> str: ... + +R = TypeVar("R", bound=JitType) + +class AnyType(JitType): + @staticmethod + def get() -> AnyType: ... + +class NoneType(JitType): + @staticmethod + def get() -> NoneType: ... + +class BoolType(JitType): + @staticmethod + def get() -> BoolType: ... + +class FloatType(JitType): + @staticmethod + def get() -> FloatType: ... + +class ComplexType(JitType): + @staticmethod + def get() -> ComplexType: ... + +class IntType(JitType): + @staticmethod + def get() -> IntType: ... + +class SymIntType(JitType): + @staticmethod + def get() -> SymIntType: ... + +class SymBoolType(JitType): + @staticmethod + def get() -> SymBoolType: ... + +class NumberType(JitType): + @staticmethod + def get() -> NumberType: ... + +class StringType(JitType): + @staticmethod + def get() -> StringType: ... + +class DeviceObjType(JitType): + @staticmethod + def get() -> DeviceObjType: ... + +class _GeneratorType(JitType): + @staticmethod + def get() -> _GeneratorType: ... + +class StreamObjType(JitType): + @staticmethod + def get() -> StreamObjType: ... + +class ListType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + @staticmethod + def ofInts() -> ListType: ... + @staticmethod + def ofTensors() -> ListType: ... + @staticmethod + def ofFloats() -> ListType: ... + @staticmethod + def ofComplexDoubles() -> ListType: ... + @staticmethod + def ofBools() -> ListType: ... + @staticmethod + def ofStrings() -> ListType: ... + +class DictType(JitType): + def __init__(self, key: JitType, value: JitType) -> None: ... + def getKeyType(self) -> JitType: ... + def getValueType(self) -> JitType: ... + +class TupleType(JitType): + def __init__(self, a: List[Optional[JitType]]) -> None: ... + def elements(self) -> List[JitType]: ... + +class UnionType(JitType): + def __init__(self, a: List[JitType]) -> None: ... + +class ClassType(JitType): + def __init__(self, qualified_name: str) -> None: ... + +class InterfaceType(JitType): + def __init__(self, qualified_name: str) -> None: ... + def getMethod(self, name: str) -> Optional[FunctionSchema]: ... + def getMethodNames(self) -> List[str]: ... + +class OptionalType(JitType, Generic[R]): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + @staticmethod + def ofTensor() -> OptionalType: ... + +class FutureType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + +class AwaitType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + +class RRefType(JitType): + def __init__(self, a: JitType) -> None: ... + +class EnumType(JitType): + def __init__( + self, + qualified_name: str, + value_type: JitType, + enum_names_values: List[Any], + ) -> None: ... + +class TensorType(JitType): + @classmethod + def get(cls) -> TensorType: ... + @classmethod + def getInferred(cls) -> TensorType: ... + def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ... + def sizes(self) -> Optional[List[_int]]: ... 
+ def varyingSizes(self) -> Optional[List[Optional[_int]]]: ... + def strides(self) -> Optional[List[_int]]: ... + def device(self) -> Optional[_device]: ... + def dim(self) -> _int: ... + def dtype(self) -> Optional[_dtype]: ... + @staticmethod + def create_from_tensor(t: Tensor) -> TensorType: ... + +# Defined in torch/csrc/jit/python/python_tree_views.cpp +class SourceRange: ... +class TreeView: ... + +class Ident(TreeView): + @property + def name(self) -> str: ... + +class ClassDef(TreeView): ... + +class Def(TreeView): + def name(self) -> Ident: ... + +class Decl(TreeView): ... + +# Defined in torch/csrc/distributed/rpc/init.cpp +def _rpc_init() -> _bool: ... + +# Defined in torch/csrc/distributed/autograd/init.cpp +def _dist_autograd_init() -> _bool: ... + +# Defined in torch/csrc/distributed/c10d/init.cpp +def _c10d_init() -> _bool: ... + +# Defined in torch/csrc/distributed/rpc/testing/init.cpp +def _faulty_agent_init() -> _bool: ... +def _register_py_class_for_device(device: str, cls: Any) -> None: ... +def _activate_cuda_trace() -> None: ... + +# Defined in torch/csrc/Module.cpp +def _current_graph_task_id() -> _int: ... +def _current_autograd_node() -> _Node: ... + +# Defined in torch/csrc/Exceptions.cpp +class _OutOfMemoryError(RuntimeError): ... +class _DistError(RuntimeError): ... +class _DistBackendError(RuntimeError): ... +class _DistStoreError(RuntimeError): ... +class _DistNetworkError(RuntimeError): ... + +# Defined in torch/csrc/profiler/init.cpp +class CapturedTraceback: + pass +def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ... +def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ... + +def _load_mobile_module_from_file(filename: str): ... +def _load_mobile_module_from_bytes(bytes_: bytes): ... +def _load_jit_module_from_file(filename: str): ... +def _load_jit_module_from_bytes(bytes_: bytes): ... +def _save_mobile_module(m: LiteScriptModule, filename: str): ... +def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ... +def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ... +def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ... +def _get_module_info_from_flatbuffer(data: bytes): ... +def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ... +def _swap_tensor_impl(t1: Tensor, t2: Tensor): ... +def _save_pickle(obj: Any) -> bytes: ... + +# Defined in torch/csrc/jit/runtime/static/init.cpp +def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ... +def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_aoti.pyi b/MLPY/Lib/site-packages/torch/_C/_aoti.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6a567acae8b4d281361fb2db9855e10184b03cdd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_aoti.pyi @@ -0,0 +1,3 @@ +# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp +class AOTIModelContainerRunnerCpu: ... +class AOTIModelContainerRunnerCuda: ... 
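The JitType hierarchy stubbed above (TensorType, IntType, ListType, ...) is what TorchScript reports when a scripted graph is inspected from Python. A small illustrative sketch using only the public torch.jit surface; the function f is invented for the example:

import torch

# The types printed here are instances of the JitType subclasses stubbed above.
@torch.jit.script
def f(x: torch.Tensor, n: int) -> torch.Tensor:
    return x * n

for value in f.graph.inputs():
    t = value.type()
    print(t.kind(), t.annotation_str)   # e.g. "TensorType Tensor", "IntType int"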
diff --git a/MLPY/Lib/site-packages/torch/_C/_autograd.pyi b/MLPY/Lib/site-packages/torch/_C/_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..027c69854ec12fcda0397d9f00729f18d41695a4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_autograd.pyi @@ -0,0 +1,123 @@ +from enum import Enum +from typing import Any, Callable, List, Optional, Set + +import torch + +from ._profiler import ( + _ProfilerEvent, + ActiveProfilerType, + ProfilerActivity, + ProfilerConfig, +) + +# Defined in tools/autograd/init.cpp + +class DeviceType(Enum): + CPU = ... + CUDA = ... + MKLDNN = ... + OPENGL = ... + OPENCL = ... + IDEEP = ... + HIP = ... + FPGA = ... + ORT = ... + XLA = ... + MPS = ... + HPU = ... + Meta = ... + Vulkan = ... + Metal = ... + PrivateUse1 = ... + +class ProfilerEvent: + def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cpu_memory_usage(self) -> int: ... + def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ... + def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cuda_memory_usage(self) -> int: ... + def device(self) -> int: ... + def handle(self) -> int: ... + def has_cuda(self) -> bool: ... + def is_remote(self) -> bool: ... + def kind(self) -> int: ... + def name(self) -> str: ... + def node_id(self) -> int: ... + def sequence_nr(self) -> int: ... + def shapes(self) -> List[List[int]]: ... + def thread_id(self) -> int: ... + def flops(self) -> float: ... + def is_async(self) -> bool: ... + +class _KinetoEvent: + def name(self) -> str: ... + def device_index(self) -> int: ... + def start_us(self) -> int: ... + def duration_us(self) -> int: ... + def is_async(self) -> bool: ... + def linked_correlation_id(self) -> int: ... + def shapes(self) -> List[List[int]]: ... + def dtypes(self) -> List[str]: ... + def concrete_inputs(self) -> List[Any]: ... + def device_type(self) -> DeviceType: ... + def start_thread_id(self) -> int: ... + def end_thread_id(self) -> int: ... + def correlation_id(self) -> int: ... + def fwd_thread_id(self) -> int: ... + def stack(self) -> List[str]: ... + def scope(self) -> int: ... + def sequence_nr(self) -> int: ... + def flops(self) -> int: ... + def cuda_elapsed_us(self) -> int: ... + def privateuse1_elapsed_us(self) -> int: ... + +class _ProfilerResult: + def events(self) -> List[_KinetoEvent]: ... + def legacy_events(self) -> List[List[ProfilerEvent]]: ... + def save(self, path: str) -> None: ... + def experimental_event_tree(self) -> List[_ProfilerEvent]: ... + def trace_start_us(self) -> int: ... + +class SavedTensor: ... + +def _enable_profiler( + config: ProfilerConfig, + activities: Set[ProfilerActivity], +) -> None: ... +def _prepare_profiler( + config: ProfilerConfig, + activities: Set[ProfilerActivity], +) -> None: ... +def _disable_profiler() -> _ProfilerResult: ... +def _profiler_enabled() -> bool: ... +def _add_metadata_json(key: str, value: str) -> None: ... +def _kineto_step() -> None: ... +def _get_sequence_nr() -> int: ... +def kineto_available() -> bool: ... +def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ... +def _record_function_with_args_exit(handle: torch.Tensor) -> None: ... +def _supported_activities() -> Set[ProfilerActivity]: ... +def _enable_record_function(enable: bool) -> None: ... +def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... +def _push_saved_tensors_default_hooks( + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], +) -> None: ... 
+def _pop_saved_tensors_default_hooks() -> None: ... +def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ... +def _enable_profiler_legacy(config: ProfilerConfig) -> None: ... +def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ... +def _profiler_type() -> ActiveProfilerType: ... +def _saved_tensors_hooks_enable() -> None: ... +def _saved_tensors_hooks_disable(message: str) -> None: ... +def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ... + +class CreationMeta(Enum): + DEFAULT = ... + IN_CUSTOM_FUNCTION = ... + MULTI_OUTPUT_NODE = ... + NO_GRAD_MODE = ... + INFERENCE_MODE = ... + +def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ... +def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_cpu.pyi b/MLPY/Lib/site-packages/torch/_C/_cpu.pyi new file mode 100644 index 0000000000000000000000000000000000000000..9dfd41a9f6dee4cceb52e89cb6c20c5c06c941b1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_cpu.pyi @@ -0,0 +1,5 @@ +from torch.types import _bool + +# Defined in torch/csrc/cpu/Module.cpp + +def _is_cpu_support_vnni() -> _bool: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_cudnn.pyi b/MLPY/Lib/site-packages/torch/_C/_cudnn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..15d6289a9180e36bf02fb7e726675302826b3daf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_cudnn.pyi @@ -0,0 +1,17 @@ +from enum import Enum + +from torch.types import _bool, Tuple + +# Defined in torch/csrc/cuda/shared/cudnn.cpp +is_cuda: _bool + +def getRuntimeVersion() -> Tuple[int, int, int]: ... +def getCompileVersion() -> Tuple[int, int, int]: ... +def getVersionInt() -> int: ... + +class RNNMode(int, Enum): + value: int + rnn_relu = ... + rnn_tanh = ... + lstm = ... + gru = ... diff --git a/MLPY/Lib/site-packages/torch/_C/_distributed_autograd.pyi b/MLPY/Lib/site-packages/torch/_C/_distributed_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b1a4062d58a119a4a352a2e565a94100b466d420 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_distributed_autograd.pyi @@ -0,0 +1,26 @@ +from typing import Any, Dict, List, Set + +import torch + +# This module is defined in torch/csrc/distributed/autograd/init.cpp + +class DistAutogradContext: + def _context_id(self) -> int: ... + def _recv_functions(self) -> Dict[int, Any]: ... + def _send_functions(self) -> Dict[int, Any]: ... + def _known_worker_ids(self) -> Set[int]: ... + +def _new_context() -> DistAutogradContext: ... +def _release_context(context_id: int) -> None: ... +def _get_max_id() -> int: ... +def _is_valid_context(worker_id: int) -> bool: ... +def _retrieve_context(context_id: int) -> DistAutogradContext: ... +def _current_context() -> DistAutogradContext: ... +def _init(worker_id: int) -> None: ... +def _get_debug_info() -> Dict[str, str]: ... +def backward( + context_id: int, + roots: List[torch.Tensor], + retain_graph=False, +) -> None: ... +def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ... 
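backward and get_gradients in _distributed_autograd.pyi back the public torch.distributed.autograd module. A hedged single-worker sketch (a real job would span several RPC workers; the address, port and model are placeholders):

import os
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Placeholder rendezvous settings for a single local worker.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)   # also initializes dist autograd

model = torch.nn.Linear(4, 1)
with dist_autograd.context() as context_id:
    loss = model(torch.randn(2, 4)).sum()
    dist_autograd.backward(context_id, [loss])
    grads = dist_autograd.get_gradients(context_id)  # Dict[Tensor, Tensor]

rpc.shutdown()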
diff --git a/MLPY/Lib/site-packages/torch/_C/_distributed_c10d.pyi b/MLPY/Lib/site-packages/torch/_C/_distributed_c10d.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e0d0cdef4a575a09335b52f85d0f976169103b38 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_distributed_c10d.pyi @@ -0,0 +1,590 @@ +# mypy: disable-error-code="type-arg" +from datetime import timedelta +from enum import Enum +from typing import Any, Dict, List, Optional, overload, Tuple, Union + +import torch +from torch import Tensor +from torch._C import ScriptObject +from torch.futures import Future + +# This module is defined in torch/csrc/distributed/c10d/init.cpp + +_DEFAULT_FIRST_BUCKET_BYTES: int +_DEFAULT_NO_TIMEOUT: timedelta +_DEFAULT_PG_TIMEOUT: timedelta +_DEFAULT_PG_NCCL_TIMEOUT: timedelta + +class BuiltinCommHookType(Enum): + ALLREDUCE = ... + FP16_COMPRESS = ... + +def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ... +def _register_builtin_comm_hook( + reducer: Reducer, + comm_hook_type: BuiltinCommHookType, +): ... +def _set_global_rank(rank: int) -> None: ... +def _hash_tensors(tensors: List[Tensor]) -> int: ... + +class GradBucket: + def index(self) -> int: ... + def buffer(self) -> Tensor: ... + def gradients(self) -> List[Tensor]: ... + def is_last(self) -> bool: ... + def set_buffer(self, tensor: Tensor) -> None: ... + def parameters(self) -> List[Tensor]: ... + +class Reducer: + def __init__( + self, + params: List[Tensor], + bucket_indices: List[List[int]], + per_bucket_size_limits: List[int], + process_group: ProcessGroup, + expect_sparse_gradients: List[bool] = ..., + bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp + find_unused_parameters: bool = ..., + gradient_as_bucket_view: bool = ..., + param_to_name_mapping: Dict[int, str] = ..., + first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp + ): ... + def prepare_for_forward(self) -> None: ... + def prepare_for_backward(self, output: List[Tensor]) -> None: ... + def get_backward_stats(self) -> List[int]: ... + def _install_post_backward_futures(self, futures: List[Future]) -> None: ... + def _rebuild_buckets(self) -> bool: ... + def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ... + def _push_all_rebuilt_params(self) -> None: ... + def _set_forward_pass_work_handle( + self, + work: Work, + use_static_world_size: bool, + ): ... + def _get_local_used_map(self) -> Tensor: ... + def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ... + def _set_static_graph(self) -> None: ... + def _run_comm_hook(self, bucket: GradBucket) -> Future: ... + def set_logger(self, logger: Logger) -> None: ... + def _remove_autograd_hooks(self) -> None: ... + def _check_reducer_finalized(self) -> None: ... + def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ... + def _reset_state(self) -> None: ... + def _update_process_group(self, new_process_group: ProcessGroup) -> None: ... + +class DDPLoggingData: + strs_map: Dict[str, str] + ints_map: Dict[str, int] + +class Logger: + def __init__(self, reducer: Reducer): ... + def set_construction_data_and_log( + self, + module_name: str, + device_ids: List[int], + output_device: int, + broadcast_buffers: bool, + has_sync_bn: bool, + static_graph: bool, + ): ... + def set_runtime_stats_and_log(self) -> None: ... + def set_error_and_log(self, error: str) -> None: ... + def _get_ddp_logging_data(self) -> DDPLoggingData: ... + def _set_comm_hook_name(self, comm_hook: str) -> None: ... 
+ def _set_uneven_input_join(self) -> None: ... + def _set_static_graph(self) -> None: ... + +def get_debug_level(): ... +def set_debug_level(): ... +def set_debug_level_from_env(): ... + +class DebugLevel(Enum): + OFF = ... + INFO = ... + DETAIL = ... + +class ReduceOp: + def __init__(self, op: RedOpType): ... + + SUM: RedOpType = ... + AVG: RedOpType = ... + PRODUCT: RedOpType = ... + MIN: RedOpType = ... + MAX: RedOpType = ... + BAND: RedOpType = ... + BOR: RedOpType = ... + BXOR: RedOpType = ... + PREMUL_SUM: RedOpType = ... + UNUSED: RedOpType = ... + + class RedOpType(Enum): ... + +class BroadcastOptions: + rootRank: int + rootTensor: int + timeout: timedelta + asyncOp: bool + +class AllreduceOptions: + reduceOp: ReduceOp + timeout: timedelta + +class AllreduceCoalescedOptions(AllreduceOptions): ... + +class ReduceOptions: + reduceOp: ReduceOp + rootRank: int + rootTensor: int + timeout: timedelta + +class AllgatherOptions: + timeout: timedelta + asyncOp: bool + +class GatherOptions: + rootRank: int + timeout: timedelta + +class ScatterOptions: + rootRank: int + timeout: timedelta + asyncOp: bool + +class ReduceScatterOptions: + reduceOp: ReduceOp + timeout: timedelta + asyncOp: bool + +class BarrierOptions: + device_ids: List[int] + device: torch.device + timeout: timedelta + +class AllToAllOptions: + timeout: timedelta + +class Store: + def set(self, key: str, value: str): ... + def get(self, key: str) -> bytes: ... + def add(self, key: str, value: int) -> int: ... + def compare_set( + self, + key: str, + expected_value: str, + desired_value: str, + ) -> bytes: ... + def delete_key(self, key: str) -> bool: ... + def num_keys(self) -> int: ... + def set_timeout(self, timeout: timedelta): ... + @overload + def wait(self, keys: List[str]): ... + @overload + def wait(self, keys: List[str], timeout: timedelta): ... + +class FileStore(Store): + def __init__(self, path: str, numWorkers: int = ...): ... + +class HashStore(Store): + def __init__(self): ... + +class TCPStore(Store): + def __init__( + self, + host_name: str, + port: int, + world_size: Optional[int] = ..., + is_master: bool = ..., + timeout: timedelta = ..., + wait_for_workers: bool = ..., + multi_tenant: bool = ..., + master_listen_fd: Optional[int] = ..., + use_libuv: Optional[bool] = ..., + ): ... + @property + def host(self) -> str: ... + @property + def port(self) -> int: ... + +class PrefixStore(Store): + def __init__(self, prefix: str, store: Store): ... + @property + def underlying_store(self) -> Store: ... + +class _DistributedBackendOptions: + def __init__(self): ... + @property + def store(self) -> Store: ... + @store.setter + def store(self, store: Store) -> None: ... + @property + def group_rank(self) -> int: ... + @group_rank.setter + def group_rank(self, rank: int) -> None: ... + @property + def group_size(self) -> int: ... + @group_size.setter + def group_size(self, size: int) -> None: ... + @property + def timeout(self) -> timedelta: ... + @timeout.setter + def timeout(self, timeout: timedelta) -> None: ... + @property + def group_id(self) -> str: ... + @group_id.setter + def group_id(self, group_id: str) -> None: ... + @property + def global_ranks_in_group(self) -> List[int]: ... + @global_ranks_in_group.setter + def global_ranks_in_group(self, ranks: List[int]) -> None: ... + +class Work: + def is_completed(self) -> bool: ... + def is_success(self) -> bool: ... + def exception(self) -> Any: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def get_future(self) -> Future: ... 
+ def source_rank(self) -> int: ... + def _source_rank(self) -> int: ... + def result(self) -> List[Tensor]: ... + def synchronize(self): ... + def boxed(self) -> ScriptObject: ... + @staticmethod + def unbox(obj: ScriptObject) -> Work: ... + +class Backend: + def __init__( + self, + rank: int, + size: int, + ): ... + @property + def supports_splitting(self) -> bool: ... + def rank(self) -> int: ... + def size(self) -> int: ... + def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ... + def _set_sequence_number_for_group(self) -> None: ... + +class ProcessGroup: + class Options: + def __init__(self, backend: str, timeout: timedelta = ...): ... + @property + def backend(self) -> str: ... + @property + def _timeout(self) -> timedelta: ... + @_timeout.setter + def _timeout(self, val: timedelta) -> None: ... + + class BackendType(Enum): + UNDEFINED = ... + GLOO = ... + NCCL = ... + UCC = ... + MPI = ... + CUSTOM = ... + def __init__(self, store: Store, rank: int, size: int, options: Options): ... + def rank(self) -> int: ... + def size(self) -> int: ... + @overload + def broadcast( + self, + tensors: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def broadcast( + self, + tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def allreduce( + self, + tensors: List[Tensor], + opts: AllreduceOptions = ..., + ) -> Work: ... + @overload + def allreduce( + self, + tensors: List[Tensor], + op=..., + ) -> Work: ... + @overload + def allreduce( + self, + tensor: Tensor, + op=..., + ) -> Work: ... + def allreduce_coalesced( + self, + tensors: List[Tensor], + opts=..., + ) -> Work: ... + def reduce_scatter_tensor_coalesced( + self, + outputTensors: List[Tensor], + inputTensors: List[Tensor], + opts: Optional[ReduceScatterOptions] = None, + ) -> Work: ... + @overload + def reduce( + self, + tensors: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def reduce( + self, + tensor: Tensor, + root: int, + op=..., + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: List[List[Tensor]], + input_tensors: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: List[Tensor], + input_tensor: Tensor, + ) -> Work: ... + def _allgather_base( + self, + output: Tensor, + input: Tensor, + opts=..., + ) -> Work: ... + def allgather_coalesced( + self, + output_lists: List[List[Tensor]], + input_list: List[Tensor], + opts=..., + ) -> Work: ... + def allgather_into_tensor_coalesced( + self, + output_lists: List[Tensor], + input_list: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def gather( + self, + output_tensors: List[List[Tensor]], + input_tensors: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def gather( + self, + output_tensors: List[Tensor], + input_tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def scatter( + self, + output_tensors: List[Tensor], + input_tensors: List[List[Tensor]], + opts=..., + ) -> Work: ... + @overload + def scatter( + self, + output_tensor: Tensor, + input_tensors: List[Tensor], + root: int, + ) -> Work: ... + @overload + def reduce_scatter( + self, + output_tensors: List[Tensor], + input_tensors: List[List[Tensor]], + opts=..., + ) -> Work: ... + @overload + def reduce_scatter( + self, + output_tensors: Tensor, + input_tensor: List[Tensor], + ) -> Work: ... + def _reduce_scatter_base( + self, + outputTensor: Tensor, + inputTensor: Tensor, + opts: Optional[ReduceScatterOptions], + ) -> Work: ... 
+ @overload + def alltoall_base( + self, + output_tensor: Tensor, + input_tensor: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + opts=..., + ) -> Work: ... + @overload + def alltoall_base( + self, + output: Tensor, + input: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + ) -> Work: ... + @overload + def alltoall( + self, + output_tensor: List[Tensor], + input_tensor: List[Tensor], + opts=..., + ) -> Work: ... + @overload + def alltoall( + self, + output: List[Tensor], + input: List[Tensor], + ) -> Work: ... + def send( + self, + tensors: List[Tensor], + dstRank: int, + tag: int, + ) -> Work: ... + def recv( + self, + tensors: List[Tensor], + srcRank: int, + tag: int, + ) -> Work: ... + def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ... + def barrier(self, opts=...) -> Work: ... + def boxed(self) -> ScriptObject: ... + @staticmethod + def unbox(obj: ScriptObject) -> ProcessGroup: ... + def _start_coalescing(self, device: torch.device) -> None: ... + def _end_coalescing(self, device: torch.device) -> Work: ... + def _get_backend_name(self) -> str: ... + def _backend_id(self, backend_type: BackendType) -> int: ... + @property + def _device_types(self) -> List[torch.device]: ... + def _get_backend(self, device: torch.device) -> Backend: ... + def _register_backend( + self, + device: torch.device, + backend_type: BackendType, + backend: Optional[Backend], + ) -> None: ... + def _set_group_name(self, name: str) -> None: ... + def name(self) -> str: ... + def _has_hooks(self) -> bool: ... + def _wait_for_pending_works(self) -> None: ... + def _set_sequence_number_for_group(self) -> None: ... + @property + def bound_device_id(self) -> Optional[torch.device]: ... + @bound_device_id.setter + def bound_device_id(self, device: Optional[torch.device]) -> None: ... + @property + def group_name(self) -> str: ... + +class ProcessGroupRoundRobin(ProcessGroup): ... + +def _round_robin_process_groups( + process_groups: List[ProcessGroup], +) -> ProcessGroupRoundRobin: ... + +class ProcessGroupGloo(Backend): + class Device: ... + class Options: ... + + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + @staticmethod + def create_device(hostname="", interface="") -> Device: ... + @staticmethod + def create_default_device() -> Device: ... + def _set_default_timeout(self, timeout) -> None: ... + +class _ProcessGroupWrapper(Backend): + def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ... + wrapped_pg: Backend + +class ProcessGroupNCCL(Backend): + class Options: + def __init__(self, timeout: Optional[timedelta] = None): ... + @property + def backend(self) -> str: ... + @property + def _timeout(self) -> timedelta: ... + @_timeout.setter + def _timeout(self, val: timedelta) -> None: ... + @property + def _is_high_priority_stream(self) -> bool: ... + @_is_high_priority_stream.setter + def _is_high_priority_stream(self, val: bool) -> None: ... + + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + def _group_start(self) -> None: ... + def _group_end(self) -> None: ... + def _set_default_timeout(self, timeout) -> None: ... + def _shutdown(self) -> None: ... + @property + def uid(self) -> int: ... + +class ProcessGroupUCC(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ): ... + +class ProcessGroupMPI(Backend): + def __init__( + self, + rank: int, + size: int, + pgComm: int, + ): ... 
+ @staticmethod + def create(ranks: List[int]) -> ProcessGroupMPI: ... + +def _compute_bucket_assignment_by_size( + tensors: List[Tensor], + bucket_size_limits: List[int], + expect_sparse_gradient: List[bool] = ..., + tensor_indices: List[int] = ..., +) -> Tuple[List[List[int]], List[int]]: ... +def _broadcast_coalesced( + process_group: ProcessGroup, + tensors: List[Tensor], + buffer_size: int, + src: int, +): ... +def _test_python_store(store: Store): ... +def _verify_params_across_processes( + process_group: ProcessGroup, + params: List[Tensor], + logger: Optional[Logger], +): ... +def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ... +def _register_process_group( + group_name: str, + process_group: ProcessGroup, +) -> None: ... +def _resolve_process_group(group_name: str) -> ProcessGroup: ... +def _unregister_all_process_groups() -> None: ... +def _unregister_process_group(group_name: str) -> None: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_distributed_rpc.pyi b/MLPY/Lib/site-packages/torch/_C/_distributed_rpc.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8ecf79635a4131324743a256f07af50f196c70df --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_distributed_rpc.pyi @@ -0,0 +1,188 @@ +# mypy: disable-error-code="type-arg" +from datetime import timedelta +from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar + +import torch + +from . import Future +from ._autograd import ProfilerEvent +from ._distributed_c10d import Store +from ._profiler import ProfilerConfig + +# This module is defined in torch/csrc/distributed/rpc/init.cpp + +_DEFAULT_INIT_METHOD: str +_DEFAULT_NUM_WORKER_THREADS: int +_UNSET_RPC_TIMEOUT: float +_DEFAULT_RPC_TIMEOUT_SEC: float + +_T = TypeVar("_T") + +class RpcBackendOptions: + rpc_timeout: float + init_method: str + def __init__( + self, + rpc_timeout: float = ..., + init_method: str = ..., + ): ... + +class WorkerInfo: + def __init__(self, name: str, worker_id: int): ... + @property + def name(self) -> str: ... + @property + def id(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + +class RpcAgent: + def join(self, shutdown: bool = False, timeout: float = 0): ... + def sync(self): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... + def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ... + def get_debug_info(self) -> Dict[str, str]: ... + def get_metrics(self) -> Dict[str, str]: ... + +class PyRRef(Generic[_T]): + def __init__(self, value: _T, type_hint: Any = None) -> None: ... + def is_owner(self) -> bool: ... + def confirmed_by_owner(self) -> bool: ... + def owner(self) -> WorkerInfo: ... + def owner_name(self) -> str: ... + def to_here(self, timeout: float = ...) -> _T: ... + def local_value(self) -> Any: ... + def rpc_sync(self, timeout: float = ...) -> Any: ... + def rpc_async(self, timeout: float = ...) -> Any: ... + def remote(self, timeout: float = ...) -> Any: ... + def _serialize(self) -> Tuple: ... + @staticmethod + def _deserialize(tp: Tuple) -> PyRRef: ... + def _get_type(self) -> Type[_T]: ... + def _get_future(self) -> Future[_T]: ... + def _get_profiling_future(self) -> Future[_T]: ... + def _set_profiling_future(self, profilingFuture: Future[_T]): ... 
+ +class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): + num_worker_threads: int + device_maps: Dict[str, Dict[torch.device, torch.device]] + devices: List[torch.device] + def __init__( + self, + num_worker_threads: int, + _transports: Optional[List], + _channels: Optional[List], + rpc_timeout: float = ..., + init_method: str = ..., + device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006 + devices: List[torch.device] = [], # noqa: B006 + ): ... + def _set_device_map( + self, + to: str, + device_map: Dict[torch.device, torch.device], + ): ... + +class TensorPipeAgent(RpcAgent): + def __init__( + self, + store: Store, + name: str, + worker_id: int, + world_size: Optional[int], + opts: _TensorPipeRpcBackendOptionsBase, + reverse_device_maps: Dict[str, Dict[torch.device, torch.device]], + devices: List[torch.device], + ): ... + def join(self, shutdown: bool = False, timeout: float = 0): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + @overload + def get_worker_info(self, id: int) -> WorkerInfo: ... + def get_worker_infos(self) -> List[WorkerInfo]: ... + def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ... + def _update_group_membership( + self, + worker_info: WorkerInfo, + my_devices: List[torch.device], + reverse_device_map: Dict[str, Dict[torch.device, torch.device]], + is_join: bool, + ): ... + def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ... + @property + def is_static_group(self) -> bool: ... + @property + def store(self) -> Store: ... + +def _is_current_rpc_agent_set() -> bool: ... +def _get_current_rpc_agent() -> RpcAgent: ... +def _set_and_start_rpc_agent(agent: RpcAgent): ... +def _reset_current_rpc_agent(): ... +def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ... +def _destroy_rref_context(ignoreRRefLeak: bool): ... +def _rref_context_get_debug_info() -> Dict[str, str]: ... +def _cleanup_python_rpc_handler(): ... +def _invoke_rpc_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any, +): ... +def _invoke_rpc_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: List[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... +def _invoke_rpc_torchscript( + dstWorkerName: str, + qualifiedNameStr: str, + argsTuple: Tuple, + kwargsDict: Dict, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... +def _invoke_remote_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any, +): ... +def _invoke_remote_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: List[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... +def _invoke_remote_torchscript( + dstWorkerName: WorkerInfo, + qualifiedNameStr: str, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, + *args: Any, + **kwargs: Any, +): ... +def get_rpc_timeout() -> float: ... +def enable_gil_profiling(flag: bool): ... +def _set_rpc_timeout(rpcTimeoutSeconds: float): ... + +class RemoteProfilerManager: + @staticmethod + def set_current_profiling_key(key: str): ... + +def _enable_server_process_global_profiler(new_config: ProfilerConfig): ... +def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ... +def _set_profiler_node_id(default_node_id: int): ... +def _enable_jit_rref_pickle(): ... +def _disable_jit_rref_pickle(): ... 
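These RPC bindings surface publicly as torch.distributed.rpc. A hedged, single-process sketch of rpc_async, remote and the PyRRef methods stubbed above (worker name, address and port are illustrative):

import os
import torch
import torch.distributed.rpc as rpc

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
rpc.init_rpc("worker0", rank=0, world_size=1)

# Asynchronous call returning a Future.
fut = rpc.rpc_async("worker0", torch.add, args=(torch.ones(2), 3))
print(fut.wait())            # tensor([4., 4.])

# remote() returns an RRef (PyRRef above); to_here() fetches the value.
rref = rpc.remote("worker0", torch.add, args=(torch.ones(2), 3))
print(rref.to_here())

rpc.shutdown()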
diff --git a/MLPY/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi b/MLPY/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bf66235d1eaea85bafde64e316a5b3168ecb61db --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi @@ -0,0 +1,35 @@ +from typing import Dict, List + +import torch + +from ._distributed_c10d import Store +from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent + +# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp + +class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + def __init__( + self, + num_worker_threads: int, + rpc_timeout: float, + init_method: str, + messages_to_fail: List[str], + messages_to_delay: Dict[str, float], + num_fail_sends: int, + ): ... + num_send_recv_threads: int + messages_to_fail: List[str] + messages_to_delay: Dict[str, float] + num_fail_sends: int + +class FaultyTensorPipeAgent(TensorPipeAgent): + def __init__( + self, + store: Store, + name: str, + rank: int, + world_size: int, + options: FaultyTensorPipeRpcBackendOptions, + reverse_device_maps: Dict[str, Dict[torch.device, torch.device]], + devices: List[torch.device], + ): ... diff --git a/MLPY/Lib/site-packages/torch/_C/_functions.pyi b/MLPY/Lib/site-packages/torch/_C/_functions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c50c31039b91219308ff1ce00c90ac7d77870f19 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_functions.pyi @@ -0,0 +1,11 @@ +from typing import AnyStr, List + +from torch import Tensor + +class UndefinedGrad: + def __init__(self) -> None: ... + def __call__(self, *inputs: Tensor) -> List[Tensor]: ... + +class DelayedError: + def __init__(self, msg: AnyStr, num_inputs: int) -> None: ... + def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_functorch.pyi b/MLPY/Lib/site-packages/torch/_C/_functorch.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8dce498aa642ad78f96fe1ed5895b0da41f72b4c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_functorch.pyi @@ -0,0 +1,77 @@ +from enum import Enum +from typing import Optional, Tuple + +from torch import Tensor + +# Defined in torch/csrc/functorch/init.cpp + +def _set_dynamic_layer_keys_included(included: bool) -> None: ... +def get_unwrapped(tensor: Tensor) -> Tensor: ... +def is_batchedtensor(tensor: Tensor) -> bool: ... +def is_functionaltensor(tensor: Tensor) -> bool: ... +def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ... +def is_gradtrackingtensor(tensor: Tensor) -> bool: ... +def maybe_get_bdim(tensor: Tensor) -> int: ... +def maybe_get_level(tensor: Tensor) -> int: ... +def maybe_current_level() -> Optional[int]: ... +def unwrap_if_dead(tensor: Tensor) -> Tensor: ... +def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... +def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... +def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ... +def current_level() -> int: ... +def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ... +def set_single_level_autograd_function_allowed(allowed: bool) -> None: ... +def get_single_level_autograd_function_allowed() -> bool: ... +def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ... +def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ... 
+def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ... +def _vmap_decrement_nesting() -> int: ... +def _grad_increment_nesting() -> int: ... +def _grad_decrement_nesting() -> int: ... + +# Defined in aten/src/ATen/functorch/Interpreter.h +class TransformType(Enum): + Torch: TransformType = ... + Vmap: TransformType = ... + Grad: TransformType = ... + Jvp: TransformType = ... + Functionalize: TransformType = ... + +class RandomnessType(Enum): + Error: TransformType = ... + Same: TransformType = ... + Different: TransformType = ... + +class CInterpreter: + def key(self) -> TransformType: ... + def level(self) -> int: ... + +class CGradInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def lift(self, Tensor) -> Tensor: ... + def prevGradMode(self) -> bool: ... + +class CJvpInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def lift(self, Tensor) -> Tensor: ... + def prevFwdGradMode(self) -> bool: ... + +class CFunctionalizeInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def functionalizeAddBackViews(self) -> bool: ... + +class CVmapInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def batchSize(self) -> int: ... + def randomness(self) -> RandomnessType: ... + +class DynamicLayer: ... + +def get_interpreter_stack() -> list[CInterpreter]: ... +def peek_interpreter_stack() -> CInterpreter: ... +def pop_dynamic_layer_stack() -> DynamicLayer: ... +def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_itt.pyi b/MLPY/Lib/site-packages/torch/_C/_itt.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a6f2559396fde84b318a768d3e6563ba6be93873 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_itt.pyi @@ -0,0 +1,5 @@ +# Defined in torch/csrc/itt.cpp +def is_available() -> None: ... +def rangePush(message: str) -> None: ... +def rangePop() -> None: ... +def mark(message: str) -> None: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_lazy.pyi b/MLPY/Lib/site-packages/torch/_C/_lazy.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7d7889a1981afa90a28eab6ef08ade70280b1e18 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_lazy.pyi @@ -0,0 +1,28 @@ +from typing import List + +from torch import Tensor + +# defined in torch/csrc/lazy/python/init.cpp +def _mark_step(device: str, devices: List[str], wait: bool): ... +def _wait_device_ops(devices: List[str]): ... +def _reset_metrics(): ... +def _counter_names() -> List[str]: ... +def _counter_value(name: str) -> int: ... +def _metrics_report() -> str: ... +def _get_graph_hash(tensors: List[Tensor]) -> str: ... +def _sync_multi( + tensors: List[Tensor], + devices: List[str], + wait: bool = True, + sync_ltc_data: bool = True, +): ... +def _get_tensor_id(tensor: Tensor) -> int: ... +def _get_tensors_text(tensors: List[Tensor]) -> str: ... +def _get_tensors_dot(tensors: List[Tensor]) -> str: ... +def _get_tensors_backend(tensors: List[Tensor]) -> str: ... +def _get_force_fallback() -> str: ... +def _set_force_fallback(newval: str): ... +def _clear_ir_cache(): ... +def _dump_ir_cache(filename: str): ... +def _set_reuse_ir(val: bool): ... +def _get_default_device_type(): ... 
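The _functorch.pyi stubs above (_vmap_increment_nesting, _add_batch_dim, ...) are the plumbing beneath the public torch.func transforms. A minimal sketch of that public API:

import torch

def dot(a, b):
    return (a * b).sum()

# vmap maps dot over the leading dimension of both inputs.
batched = torch.func.vmap(dot)
x, y = torch.randn(8, 3), torch.randn(8, 3)
print(batched(x, y).shape)   # torch.Size([8])

# grad differentiates with respect to the first argument.
grad_dot = torch.func.grad(dot)
print(grad_dot(x[0], y[0]))  # equals y[0]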
diff --git a/MLPY/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi b/MLPY/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi new file mode 100644 index 0000000000000000000000000000000000000000..535af4c4851101fb690292378ffe6da55eb80e32 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi @@ -0,0 +1,11 @@ +# defined in torch/csrc/lazy/python/init.cpp + +from typing import Any, List, Tuple + +from torch import Tensor + +def _init(): ... +def _get_tensors_ts_device_data_node( + tensors: List[Tensor], +) -> Tuple[List[int], List[Any]]: ... +def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_monitor.pyi b/MLPY/Lib/site-packages/torch/_C/_monitor.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6d33ebde320174a6bd7d4eb9505e4d2245c852ea --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_monitor.pyi @@ -0,0 +1,44 @@ +# Defined in torch/csrc/monitor/python_init.cpp + +import datetime +from enum import Enum +from typing import Callable, Dict, List, Union + +class Aggregation(Enum): + VALUE = ... + MEAN = ... + COUNT = ... + SUM = ... + MAX = ... + MIN = ... + +class Stat: + name: str + count: int + def __init__( + self, + name: str, + aggregations: List[Aggregation], + window_size: int, + max_samples: int = -1, + ) -> None: ... + def add(self, v: float) -> None: ... + def get(self) -> Dict[Aggregation, float]: ... + +class Event: + name: str + timestamp: datetime.datetime + data: Dict[str, Union[int, float, bool, str]] + def __init__( + self, + name: str, + timestamp: datetime.datetime, + data: Dict[str, Union[int, float, bool, str]], + ) -> None: ... + +def log_event(e: Event) -> None: ... + +class EventHandlerHandle: ... + +def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ... +def unregister_event_handler(handle: EventHandlerHandle) -> None: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_nn.pyi b/MLPY/Lib/site-packages/torch/_C/_nn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..9beb7e61a5753a6206a426599ae99fe19611d331 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_nn.pyi @@ -0,0 +1,86 @@ +# mypy: disable-error-code="type-arg" +from typing import List, Optional, overload, Sequence, Tuple, Union + +from torch import memory_format, Tensor +from torch.types import _bool, _device, _dtype, _int, _size + +# Defined in tools/autograd/templates/python_nn_functions.cpp + +def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ... +def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ... +def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ... +def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ... +def elu_(input: Tensor, alpha: float = ...) -> Tensor: ... +def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ... 
+def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ... +def gelu(input: Tensor, approximate: str = ...) -> Tensor: ... +def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ... +def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ... +def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ... +def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ... +def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ... +def log_sigmoid(input: Tensor) -> Tensor: ... +def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ... +def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ... +def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None) -> Tensor: ... +def softplus(input: Tensor, beta: float = ..., threshold: float = ...) -> Tensor: ... +def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ... + +# Defined in aten/src/ATen/native/mkldnn/Linear.cpp +def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ... + +# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +def mkldnn_reorder_conv2d_weight( + self: Tensor, + padding: List, + stride: List, + dilatation: List, + groups: int, +) -> Tensor: ... +def mkldnn_reorder_conv3d_weight( + self: Tensor, + padding: List, + stride: List, + dilatation: List, + groups: int, +) -> Tensor: ... + +# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp +def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ... + +# Defined at tools/autograd/templates/python_nn_functions.cpp +@overload +def _parse_to( + device: _device, + dtype: _dtype, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> Tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to( + dtype: _dtype, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> Tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to( + tensor: Tensor, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> Tuple[_device, _dtype, _bool, memory_format]: ... + +# Defined in aten/src/ATen/native/PadSequence.cpp +def pad_sequence( + sequences: List[Tensor], + batch_first: bool = False, + padding_value: float = ..., +) -> Tensor: ... +def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ... +def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_nvtx.pyi b/MLPY/Lib/site-packages/torch/_C/_nvtx.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ff1b574b947940e31ff403cca714c8d1bec0c50d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_nvtx.pyi @@ -0,0 +1,6 @@ +# Defined in torch/csrc/cuda/shared/nvtx.cpp +def rangePushA(message: str) -> int: ... +def rangePop() -> int: ... +def rangeStartA(message: str) -> int: ... +def rangeEnd(int) -> None: ... +def markA(message: str) -> None: ... 
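pad_sequence and scaled_dot_product_attention from the _nn stubs above are exposed as torch.nn.utils.rnn.pad_sequence and torch.nn.functional.scaled_dot_product_attention. A brief sketch with arbitrary shapes:

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# Pad three variable-length sequences into one (3, 7, 4) batch.
seqs = [torch.randn(5, 4), torch.randn(3, 4), torch.randn(7, 4)]
padded = pad_sequence(seqs, batch_first=True)
print(padded.shape)          # torch.Size([3, 7, 4])

# Fused attention over (batch, heads, seq, head_dim) inputs.
q = k = v = torch.randn(2, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)             # torch.Size([2, 8, 16, 64])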
diff --git a/MLPY/Lib/site-packages/torch/_C/_onnx.pyi b/MLPY/Lib/site-packages/torch/_C/_onnx.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ac1d0f6d51934957c01ea23e7b0059d4ea42316f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_onnx.pyi @@ -0,0 +1,40 @@ +# Defined in torch/csrc/onnx/init.cpp + +from enum import Enum + +_CAFFE2_ATEN_FALLBACK: bool +PRODUCER_VERSION: str + +class TensorProtoDataType(Enum): + UNDEFINED = ... + FLOAT = ... + UINT8 = ... + INT8 = ... + UINT16 = ... + INT16 = ... + INT32 = ... + INT64 = ... + STRING = ... + BOOL = ... + FLOAT16 = ... + DOUBLE = ... + UINT32 = ... + UINT64 = ... + COMPLEX64 = ... + COMPLEX128 = ... + BFLOAT16 = ... + FLOAT8E5M2 = ... + FLOAT8E4M3FN = ... + FLOAT8E5M2FNUZ = ... + FLOAT8E4M3FNUZ = ... + +class OperatorExportTypes(Enum): + ONNX = ... + ONNX_ATEN = ... + ONNX_ATEN_FALLBACK = ... + ONNX_FALLTHROUGH = ... + +class TrainingMode(Enum): + EVAL = ... + PRESERVE = ... + TRAINING = ... diff --git a/MLPY/Lib/site-packages/torch/_C/_profiler.pyi b/MLPY/Lib/site-packages/torch/_C/_profiler.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7cc8dc08e673c4906ac78d506b3c0012a5c9b0bd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_profiler.pyi @@ -0,0 +1,238 @@ +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from torch._C import device, dtype, layout +from typing_extensions import TypeAlias + +# defined in torch/csrc/profiler/python/init.cpp + +class RecordScope(Enum): + FUNCTION = ... + BACKWARD_FUNCTION = ... + TORCHSCRIPT_FUNCTION = ... + KERNEL_FUNCTION_DTYPE = ... + CUSTOM_CLASS = ... + BUILD_FEATURE = ... + LITE_INTERPRETER = ... + USER_SCOPE = ... + STATIC_RUNTIME_OP = ... + STATIC_RUNTIME_MODEL = ... + +class ProfilerState(Enum): + Disable = ... + CPU = ... + CUDA = ... + NVTX = ... + ITT = ... + KINETO = ... + KINETO_GPU_FALLBACK = ... + KINETO_PRIVATEUSE1_FALLBACK = ... + KINETO_PRIVATEUSE1 = ... + +class ActiveProfilerType(Enum): + NONE = ... + LEGACY = ... + KINETO = ... + NVTX = ... + ITT = ... + +class ProfilerActivity(Enum): + CPU = ... + CUDA = ... + MTIA = ... + PrivateUse1 = ... + +class _EventType(Enum): + TorchOp = ... + Backend = ... + Allocation = ... + OutOfMemory = ... + PyCall = ... + PyCCall = ... + Kineto = ... + +class _ExperimentalConfig: + def __init__( + self, + profiler_metrics: List[str] = ..., + profiler_measure_per_kernel: bool = ..., + verbose: bool = ..., + performance_events: List[str] = ..., + enable_cuda_sync_events: bool = ..., + ) -> None: ... + +class ProfilerConfig: + def __init__( + self, + state: ProfilerState, + report_input_shapes: bool, + profile_memory: bool, + with_stack: bool, + with_flops: bool, + with_modules: bool, + experimental_config: _ExperimentalConfig, + ) -> None: ... 
+ +class _ProfilerEvent: + start_tid: int + start_time_ns: int + children: List[_ProfilerEvent] + + # TODO(robieta): remove in favor of `self.typed` + extra_fields: Union[ + _ExtraFields_TorchOp, + _ExtraFields_Backend, + _ExtraFields_Allocation, + _ExtraFields_OutOfMemory, + _ExtraFields_PyCall, + _ExtraFields_PyCCall, + _ExtraFields_Kineto, + ] + + @property + def typed( + self, + ) -> Union[ + Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp], + Tuple[Literal[_EventType.Backend], _ExtraFields_Backend], + Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation], + Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory], + Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall], + Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall], + Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto], + ]: ... + @property + def name(self) -> str: ... + @property + def tag(self) -> _EventType: ... + @property + def id(self) -> int: ... + @property + def parent(self) -> Optional[_ProfilerEvent]: ... + @property + def correlation_id(self) -> int: ... + @property + def end_time_ns(self) -> int: ... + @property + def duration_time_ns(self) -> int: ... + +class _TensorMetadata: + impl_ptr: Optional[int] + storage_data_ptr: Optional[int] + id: Optional[int] + + @property + def allocation_id(self) -> Optional[int]: ... + @property + def layout(self) -> layout: ... + @property + def device(self) -> device: ... + @property + def dtype(self) -> dtype: ... + @property + def sizes(self) -> List[int]: ... + @property + def strides(self) -> List[int]: ... + +Scalar: TypeAlias = Union[int, float, bool, complex] +Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]] + +class _ExtraFields_TorchOp: + name: str + sequence_number: int + allow_tf32_cublas: bool + + @property + def inputs(self) -> List[Input]: ... + @property + def scope(self) -> RecordScope: ... + +class _ExtraFields_Backend: ... + +class _ExtraFields_Allocation: + ptr: int + id: Optional[int] + alloc_size: int + total_allocated: int + total_reserved: int + + @property + def allocation_id(self) -> Optional[int]: ... + @property + def device(self) -> device: ... + +class _ExtraFields_OutOfMemory: ... + +class _PyFrameState: + line_number: int + function_name: str + + @property + def file_name(self) -> str: ... + +class _NNModuleInfo: + @property + def self_ptr(self) -> int: ... + @property + def cls_ptr(self) -> int: ... + @property + def cls_name(self) -> str: ... + @property + def parameters( + self, + ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ... + +class _OptimizerInfo: + @property + def parameters( + self, + ) -> List[ + Tuple[ + # Parameter + _TensorMetadata, + # + # Gradient (if present during optimizer.step()) + Optional[_TensorMetadata], + # + # Optimizer state for Parameter as (name, tensor) pairs + List[Tuple[str, _TensorMetadata]], + ] + ]: ... + +class _ExtraFields_PyCCall: + @property + def caller(self) -> _PyFrameState: ... + +class _ExtraFields_PyCall: + @property + def callsite(self) -> _PyFrameState: ... + @property + def caller(self) -> _PyFrameState: ... + @property + def module(self) -> Optional[_NNModuleInfo]: ... + @property + def optimizer(self) -> Optional[_OptimizerInfo]: ... + +class _ExtraFields_Kineto: ... + +def _add_execution_trace_observer(output_file_path: str) -> bool: ... +def _remove_execution_trace_observer() -> None: ... +def _enable_execution_trace_observer() -> None: ... +def _disable_execution_trace_observer() -> None: ... 
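# --- Annotation (illustrative sketch, not part of the vendored files) -------
# The execution-trace observer functions stubbed just above register a callback
# that records the executed ops into a JSON execution trace. A hedged sketch of
# driving them directly; the add -> enable -> disable -> remove order and the
# output path "et.json" are assumptions here, and the higher-level
# torch.profiler wrapper is normally preferred.
import torch
from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _remove_execution_trace_observer,
)

_add_execution_trace_observer("et.json")    # register observer + output file
_enable_execution_trace_observer()          # start recording
y = torch.randn(8, 8) @ torch.randn(8, 8)   # ops captured while enabled
_disable_execution_trace_observer()         # stop recording
_remove_execution_trace_observer()          # unregister and finalize the file
# ---------------------------------------------------------------------------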
+def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ... +def _set_fwd_bwd_enabled_val(val: bool) -> None: ... +def _set_cuda_sync_enabled_val(val: bool) -> None: ... + +class CapturedTraceback: ... + +def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ... + +# The Dict has name, filename, line +def symbolize_tracebacks( + to_symbolize: List[CapturedTraceback], +) -> List[List[Dict[str, str]]]: ... + +class _RecordFunctionFast: + def __init__(self, name: str) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ... diff --git a/MLPY/Lib/site-packages/torch/_C/_verbose.pyi b/MLPY/Lib/site-packages/torch/_C/_verbose.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6d1dbfda288978aa1680412ad24bf488160ba854 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_C/_verbose.pyi @@ -0,0 +1,3 @@ +# Defined in torch/csrc/utils/verbose.cpp +def mkl_set_verbose(enable: int) -> int: ... +def mkldnn_set_verbose(level: int) -> int: ... diff --git a/MLPY/Lib/site-packages/torch/_VF.py b/MLPY/Lib/site-packages/torch/_VF.py new file mode 100644 index 0000000000000000000000000000000000000000..53724c3246e81163c95826a3c69f5912e0dc3304 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_VF.py @@ -0,0 +1,30 @@ +""" +This makes the functions in torch._C._VariableFunctions available as + torch._VF. +without mypy being able to find them. + +A subset of those functions are mapped to ATen functions in +torch/jit/_builtins.py + +See https://github.com/pytorch/pytorch/issues/21478 for the reason for +introducing torch._VF + +""" +import sys +import types + +import torch + + +class VFModule(types.ModuleType): + vf: types.ModuleType + + def __init__(self, name): + super().__init__(name) + self.vf = torch._C._VariableFunctions + + def __getattr__(self, attr): + return getattr(self.vf, attr) + + +sys.modules[__name__] = VFModule(__name__) diff --git a/MLPY/Lib/site-packages/torch/_VF.pyi b/MLPY/Lib/site-packages/torch/_VF.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e7bc45da38b2228706f8e353adb5af335d22eae3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_VF.pyi @@ -0,0 +1,25648 @@ +# @generated from torch/_C/_VariableFunctions.pyi.in +# mypy: disable-error-code="type-arg" + +import builtins +from typing import ( + Any, + Callable, + ContextManager, + Iterator, + List, + Literal, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + TypeVar, + Union, +) + +import torch +from torch import contiguous_format, Generator, inf, memory_format, strided, SymInt, Tensor +from torch.types import ( + _bool, + _complex, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Device, + Number, +) + +from torch._prims_common import DeviceLikeType + +@overload +def __and__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __and__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Tensor) -> Tensor: ... 
+@overload +def __xor__(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +def _adaptive_avg_pool2d(input: Tensor, output_size: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]) -> Tensor: ... +def _adaptive_avg_pool3d(input: Tensor, output_size: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]]) -> Tensor: ... +def _add_batch_dim(input: Tensor, batch_dim: _int, level: _int) -> Tensor: ... +@overload +def _add_relu(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def _add_relu(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def _add_relu_(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def _add_relu_(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +def _addmm_activation(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, use_gelu: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def _aminmax(input: Tensor) -> Tuple[Tensor, Tensor]: ... +@overload +def _aminmax(input: Tensor, dim: _int, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: ... +def _amp_foreach_non_finite_check_and_unscale_(self: Union[Tuple[Tensor, ...], List[Tensor]], found_inf: Tensor, inv_scale: Tensor) -> None: ... +def _amp_update_scale_(input: Tensor, growth_tracker: Tensor, found_inf: Tensor, scale_growth_factor: _float, scale_backoff_factor: _float, growth_interval: _int) -> Tensor: ... +@overload +def _assert_async(input: Tensor) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + ... +@overload +def _assert_async(input: Tensor, assert_msg: str) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + ... +def _assert_scalar(self: Union[Number, _complex], assert_msg: str) -> None: ... 
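# --- Annotation (illustrative sketch, not part of the vendored files) -------
# Per the _assert_async docstring above, the check takes a one-element tensor
# and, on CUDA, does not synchronize, so a failure may only surface at a later
# kernel launch. A tiny hedged example of guarding an invariant without forcing
# a device sync:
import torch

x = torch.randn(1024)
torch._assert_async((x == x).all())   # one-element bool tensor; trips only if x contains NaNs
# ---------------------------------------------------------------------------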
+def _assert_tensor_metadata(a: Tensor, size: Optional[Sequence[Union[_int, SymInt]]] = None, stride: Optional[Sequence[Union[_int, SymInt]]] = None, dtype: Optional[_dtype] = None) -> None: ... +def _batch_norm_impl_index(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, _int]: ... +def _cast_Byte(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Char(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Double(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Float(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Half(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Int(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Long(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Short(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _choose_qparams_per_tensor(input: Tensor, reduce_range: _bool = False) -> Tuple[_float, _int]: ... +def _chunk_cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int, num_chunks: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _coalesce(input: Tensor) -> Tensor: ... +def _compute_linear_combination(input: Tensor, coefficients: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _conj(input: Tensor) -> Tensor: ... +def _conj_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _conj_physical(input: Tensor) -> Tensor: ... +def _convert_indices_from_coo_to_csr(input: Tensor, size: _int, *, out_int32: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +def _convert_indices_from_csr_to_coo(crow_indices: Tensor, col_indices: Tensor, *, out_int32: _bool = False, transpose: _bool = False, out: Optional[Tensor] = None) -> Tensor: ... +def _convert_weight_to_int4pack(input: Tensor, innerKTiles: _int) -> Tensor: ... +@overload +def _convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: _size, groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, cudnn_enabled: _bool) -> Tensor: ... +@overload +def _convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, cudnn_enabled: _bool, allow_tf32: _bool) -> Tensor: ... +def _convolution_mode(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: str, dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def _copy_from(input: Tensor, dst: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _copy_from_and_resize(input: Tensor, dst: Tensor) -> Tensor: ... +def _cslt_compress(input: Tensor) -> Tensor: ... +def _cslt_sparse_mm(compressed_A: Tensor, dense_B: Tensor, bias: Optional[Tensor] = None, alpha: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, transpose_result: _bool = False, alg_id: _int = 0) -> Tensor: ... 
+def _cslt_sparse_mm_search(compressed_A: Tensor, dense_B: Tensor, bias: Optional[Tensor] = None, alpha: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, transpose_result: _bool = False) -> _int: ... +@overload +def _ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int = 0, zero_infinity: _bool = False) -> Tuple[Tensor, Tensor]: ... +@overload +def _ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int = 0, zero_infinity: _bool = False) -> Tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int, deterministic: _bool, zero_infinity: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int, deterministic: _bool, zero_infinity: _bool) -> Tuple[Tensor, Tensor]: ... +def _cudnn_init_dropout_state(dropout: _float, train: _bool, dropout_seed: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _cudnn_rnn(input: Tensor, weight: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, weight_buf: Optional[Tensor], hx: Tensor, cx: Optional[Tensor], mode: _int, hidden_size: Union[_int, SymInt], proj_size: Union[_int, SymInt], num_layers: _int, batch_first: _bool, dropout: _float, train: _bool, bidirectional: _bool, batch_sizes: Sequence[Union[_int, SymInt]], dropout_state: Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _cudnn_rnn_flatten_weight(weight_arr: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, input_size: Union[_int, SymInt], mode: _int, hidden_size: Union[_int, SymInt], proj_size: Union[_int, SymInt], num_layers: _int, batch_first: _bool, bidirectional: _bool) -> Tensor: ... +def _cufft_clear_plan_cache(device_index: _int) -> None: ... +def _cufft_get_plan_cache_max_size(device_index: _int) -> _int: ... +def _cufft_get_plan_cache_size(device_index: _int) -> _int: ... +def _cufft_set_plan_cache_max_size(device_index: _int, max_size: _int) -> None: ... +def _cummax_helper(input: Tensor, values: Tensor, indices: Tensor, dim: _int) -> None: ... +def _cummin_helper(input: Tensor, values: Tensor, indices: Tensor, dim: _int) -> None: ... +def _debug_has_internal_overlap(input: Tensor) -> _int: ... +def _dim_arange(like: Tensor, dim: _int) -> Tensor: ... +def _dirichlet_grad(x: Tensor, alpha: Tensor, total: Tensor) -> Tensor: ... +def _disable_functionalization(): ... +@overload +def _efficientzerotensor(size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _efficientzerotensor(*size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... 
+def _embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def _embedding_bag_forward_only(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def _empty_affine_quantized(size: Sequence[Union[_int, SymInt]], *, scale: _float = 1, zero_point: _int = 0, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_affine_quantized(*size: _int, scale: _float = 1, zero_point: _int = 0, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized(size: Sequence[Union[_int, SymInt]], *, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized(*size: _int, scales: Tensor, zero_points: Tensor, axis: _int, memory_format: Optional[memory_format] = contiguous_format, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _enable_functionalization(*, reapply_views: _bool = False): ... +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: ... +def _fake_quantize_learnable_per_channel_affine(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int, quant_min: _int, quant_max: _int, grad_factor: _float = 1.0) -> Tensor: ... +def _fake_quantize_learnable_per_tensor_affine(input: Tensor, scale: Tensor, zero_point: Tensor, quant_min: _int, quant_max: _int, grad_factor: _float = 1.0) -> Tensor: ... +def _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(input: Tensor, scale: Tensor, zero_point: Tensor, fake_quant_enabled: Tensor, quant_min: _int, quant_max: _int) -> torch.return_types._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: ... +def _fft_c2c(input: Tensor, dim: Sequence[Union[_int, SymInt]], normalization: _int, forward: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _fft_c2r(input: Tensor, dim: _size, normalization: _int, last_dim_size: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: ... +def _fft_r2c(input: Tensor, dim: _size, normalization: _int, onesided: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _fill_mem_eff_dropout_mask_(input: Tensor, dropout_p: _float, seed: _int, offset: _int) -> Tensor: ... +def _foobar(input: Tensor, arg1: _bool = True, arg2: _bool = True, *, arg3: _bool = True) -> Tensor: ... 
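# --- Annotation (illustrative sketch, not part of the vendored files) -------
# _fft_c2c/_fft_c2r/_fft_r2c above are the internal FFT kernels behind the
# public torch.fft API. A hedged round-trip example (real signal -> one-sided
# spectrum -> real signal), which dispatches to _fft_r2c and _fft_c2r:
import torch

x = torch.randn(16)
spectrum = torch.fft.rfft(x)              # real-to-complex, one-sided (_fft_r2c)
x_back = torch.fft.irfft(spectrum, n=16)  # complex-to-real inverse (_fft_c2r)
torch.testing.assert_close(x, x_back)     # round trip matches within float tolerance
# ---------------------------------------------------------------------------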
+def _foreach_abs(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_abs(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + ... +def _foreach_abs_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_abs_(self: List[Tensor]) -> None + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + ... +def _foreach_acos(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_acos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + ... +def _foreach_acos_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_acos_(self: List[Tensor]) -> None + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor, *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_add_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> None: ... +@overload +def _foreach_addcdiv_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> None: ... 
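# --- Annotation (illustrative sketch, not part of the vendored files) -------
# The _foreach_* stubs above operate on whole lists of tensors in one call,
# which is how the multi-tensor optimizer paths avoid a Python loop per
# parameter. A small hedged example of an SGD-style update via _foreach_add_
# (in-place, list `other`, keyword-only `alpha`, as stubbed above):
import torch

params = [torch.randn(3, 3) for _ in range(4)]
grads = [torch.randn(3, 3) for _ in range(4)]
lr = 0.1

torch._foreach_add_(params, grads, alpha=-lr)   # params[i] += -lr * grads[i] for every i
# ---------------------------------------------------------------------------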
+@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Tensor) -> None: ... +@overload +def _foreach_addcmul_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensor1: Union[Tuple[Tensor, ...], List[Tensor]], tensor2: Union[Tuple[Tensor, ...], List[Tensor]], value: Union[Number, _complex] = 1) -> None: ... +def _foreach_asin(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_asin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + ... +def _foreach_asin_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_asin_(self: List[Tensor]) -> None + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + ... +def _foreach_atan(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_atan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + ... +def _foreach_atan_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_atan_(self: List[Tensor]) -> None + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + ... +def _foreach_ceil(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_ceil(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + ... +def _foreach_ceil_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_ceil_(self: List[Tensor]) -> None + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_clamp_max_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... 
+@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_clamp_min_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_copy_(self: Union[Tuple[Tensor, ...], List[Tensor]], src: Union[Tuple[Tensor, ...], List[Tensor]], non_blocking: _bool = False) -> None: ... +def _foreach_cos(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_cos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + ... +def _foreach_cos_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_cos_(self: List[Tensor]) -> None + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + ... +def _foreach_cosh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_cosh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + ... +def _foreach_cosh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_cosh_(self: List[Tensor]) -> None + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_div_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_erf(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_erf(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + ... +def _foreach_erf_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_erf_(self: List[Tensor]) -> None + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + ... +def _foreach_erfc(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_erfc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erfc` to each Tensor of the input list. 
+ """ + ... +def _foreach_erfc_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_erfc_(self: List[Tensor]) -> None + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + ... +def _foreach_exp(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_exp(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + ... +def _foreach_exp_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_exp_(self: List[Tensor]) -> None + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + ... +def _foreach_expm1(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_expm1(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + ... +def _foreach_expm1_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_expm1_(self: List[Tensor]) -> None + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + ... +def _foreach_floor(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_floor(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + ... +def _foreach_floor_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_floor_(self: List[Tensor]) -> None + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + ... +def _foreach_frac(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_frac(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + ... +def _foreach_frac_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_frac_(self: List[Tensor]) -> None + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + ... +@overload +def _foreach_lerp(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weight: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_lerp(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weights: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_lerp_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weight: Union[Number, _complex]) -> None: ... +@overload +def _foreach_lerp_(self: Union[Tuple[Tensor, ...], List[Tensor]], tensors1: Union[Tuple[Tensor, ...], List[Tensor]], weights: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_lgamma(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_lgamma(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + ... +def _foreach_lgamma_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_lgamma_(self: List[Tensor]) -> None + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + ... +def _foreach_log(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log` to each Tensor of the input list. + """ + ... 
+def _foreach_log10(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log10(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + ... +def _foreach_log10_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log10_(self: List[Tensor]) -> None + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + ... +def _foreach_log1p(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log1p(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + ... +def _foreach_log1p_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log1p_(self: List[Tensor]) -> None + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + ... +def _foreach_log2(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_log2(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + ... +def _foreach_log2_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log2_(self: List[Tensor]) -> None + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + ... +def _foreach_log_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_log_(self: List[Tensor]) -> None + + Apply :func:`torch.log` to each Tensor of the input list. + """ + ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_maximum_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_minimum_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> Tuple[Tensor, ...]: ... 
+@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Tensor) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +@overload +def _foreach_mul_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_neg(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_neg(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + ... +def _foreach_neg_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_neg_(self: List[Tensor]) -> None + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + ... +def _foreach_norm(self: Union[Tuple[Tensor, ...], List[Tensor]], ord: Union[Number, _complex] = 2) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow(self: Union[Number, _complex], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Number, _complex]) -> None: ... +@overload +def _foreach_pow_(self: Union[Tuple[Tensor, ...], List[Tensor]], exponent: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_reciprocal(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_reciprocal(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + ... +def _foreach_reciprocal_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_reciprocal_(self: List[Tensor]) -> None + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + ... +def _foreach_round(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_round(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.round` to each Tensor of the input list. + """ + ... +def _foreach_round_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_round_(self: List[Tensor]) -> None + + Apply :func:`torch.round` to each Tensor of the input list. + """ + ... +def _foreach_sigmoid(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sigmoid(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + ... 
+def _foreach_sigmoid_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sigmoid_(self: List[Tensor]) -> None + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + ... +def _foreach_sign(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +def _foreach_sign_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: ... +def _foreach_sin(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + ... +def _foreach_sin_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sin_(self: List[Tensor]) -> None + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + ... +def _foreach_sinh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sinh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + ... +def _foreach_sinh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sinh_(self: List[Tensor]) -> None + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + ... +def _foreach_sqrt(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_sqrt(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + ... +def _foreach_sqrt_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_sqrt_(self: List[Tensor]) -> None + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> Tuple[Tensor, ...]: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalars: Sequence[Union[Number, _complex]]) -> None: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], other: Union[Tuple[Tensor, ...], List[Tensor]], *, alpha: Union[Number, _complex] = 1) -> None: ... +@overload +def _foreach_sub_(self: Union[Tuple[Tensor, ...], List[Tensor]], scalar: Union[Number, _complex]) -> None: ... +def _foreach_tan(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_tan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + ... +def _foreach_tan_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_tan_(self: List[Tensor]) -> None + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + ... +def _foreach_tanh(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_tanh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + ... +def _foreach_tanh_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_tanh_(self: List[Tensor]) -> None + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + ... 
+def _foreach_trunc(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + _foreach_trunc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + ... +def _foreach_trunc_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_trunc_(self: List[Tensor]) -> None + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + ... +def _foreach_zero_(self: Union[Tuple[Tensor, ...], List[Tensor]]) -> None: + r""" + _foreach_zero_(self: List[Tensor]) -> None + + Apply :func:`torch.zero` to each Tensor of the input list. + """ + ... +def _from_functional_tensor(t: Tensor) -> Tensor: ... +def _functional_assert_async(input: Tensor, assert_msg: str, dep_token: Tensor) -> Tensor: ... +def _functional_assert_scalar(self: Union[Number, _complex], assert_msg: str, dep_token: Tensor) -> Tensor: ... +def _functional_sym_constrain_range(size: Union[Number, _complex], min: Optional[_int], max: Optional[_int], dep_token: Tensor) -> Tensor: ... +def _functional_sym_constrain_range_for_size(size: Union[Number, _complex], min: Optional[_int], max: Optional[_int], dep_token: Tensor) -> Tensor: ... +def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ... +def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ... +def _functionalize_commit_update(t: Tensor) -> None: ... +def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ... +def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ... +def _functionalize_sync(t: Tensor) -> None: ... +@overload +def _fused_adam_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: Tensor, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_adam_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: _float, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_adamw_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: Tensor, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... 
+@overload +def _fused_adamw_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], exp_avgs: Union[Tuple[Tensor, ...], List[Tensor]], exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], max_exp_avg_sqs: Union[Tuple[Tensor, ...], List[Tensor]], state_steps: Union[Tuple[Tensor, ...], List[Tensor]], *, lr: _float, beta1: _float, beta2: _float, weight_decay: _float, eps: _float, amsgrad: _bool, maximize: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +def _fused_dropout(input: Tensor, p: _float, generator: Optional[Generator] = None) -> Tuple[Tensor, Tensor]: ... +def _fused_moving_avg_obs_fq_helper(input: Tensor, observer_on: Tensor, fake_quant_on: Tensor, running_min: Tensor, running_max: Tensor, scale: Tensor, zero_point: Tensor, averaging_const: _float, quant_min: _int, quant_max: _int, ch_axis: _int, per_row_fake_quant: _bool = False, symmetric_quant: _bool = False) -> torch.return_types._fused_moving_avg_obs_fq_helper: ... +def _fused_sdp_choice(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: _float = 0.0, is_causal: _bool = False, *, scale: Optional[_float] = None) -> _int: ... +@overload +def _fused_sgd_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], momentum_buffer_list: Union[Tuple[Tensor, ...], List[Tensor]], *, weight_decay: _float, momentum: _float, lr: Tensor, dampening: _float, nesterov: _bool, maximize: _bool, is_first_step: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +@overload +def _fused_sgd_(self: Union[Tuple[Tensor, ...], List[Tensor]], grads: Union[Tuple[Tensor, ...], List[Tensor]], momentum_buffer_list: Union[Tuple[Tensor, ...], List[Tensor]], *, weight_decay: _float, momentum: _float, lr: _float, dampening: _float, nesterov: _bool, maximize: _bool, is_first_step: _bool, grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None) -> None: ... +def _fw_primal_copy(input: Tensor, level: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _grid_sampler_2d_cpu_fallback(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def _has_compatible_shallow_copy_type(input: Tensor, from_: Tensor) -> _bool: ... +def _histogramdd_bin_edges(input: Tensor, bins: _size, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> Tuple[Tensor, ...]: ... +def _histogramdd_from_bin_cts(input: Tensor, bins: _size, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> Tensor: ... +def _histogramdd_from_bin_tensors(input: Tensor, bins: Union[Tuple[Tensor, ...], List[Tensor]], *, weight: Optional[Tensor] = None, density: _bool = False) -> Tensor: ... +def _index_put_impl_(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False, unsafe: _bool = False) -> Tensor: ... +def _indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _int_mm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _is_all_true(input: Tensor) -> Tensor: ... +def _is_any_true(input: Tensor) -> Tensor: ... +def _is_functional_tensor(t: Tensor) -> _bool: ... +def _is_zerotensor(input: Tensor) -> _bool: ... +def _lazy_clone(input: Tensor) -> Tensor: ... 
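# --- Annotation (illustrative sketch, not part of the vendored files) -------
# _fused_adam_/_fused_adamw_/_fused_sgd_ above are the single-kernel update
# steps behind torch.optim's fused=True path. A hedged example of the closely
# related foreach=True path, which instead builds on the _foreach_* ops and
# runs on CPU:
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)
loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()   # parameter updates issued as multi-tensor _foreach_* calls
# ---------------------------------------------------------------------------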
+def _linalg_check_errors(info: Tensor, api_name: str, *, is_matrix: _bool) -> None: ... +def _linalg_det(A: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_det: ... +def _linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_eigh: ... +def _linalg_slogdet(A: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_slogdet: ... +def _linalg_solve_ex(A: Tensor, B: Tensor, *, left: _bool = True, check_errors: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_solve_ex: ... +def _linalg_svd(A: Tensor, full_matrices: _bool = False, compute_uv: _bool = True, *, driver: Optional[str] = None, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types._linalg_svd: ... +def _log_softmax(input: Tensor, dim: _int, half_to_float: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _log_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input_dtype: _dtype, *, out: Optional[Tensor] = None) -> Tensor: ... +def _logcumsumexp(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _lstm_mps(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _lu_with_info(input: Tensor, pivot: _bool = True, check_errors: _bool = True) -> torch.return_types._lu_with_info: ... +def _make_dep_token(*, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _make_dual(primal: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _make_dual_copy(primal: Tensor, tangent: Tensor, level: _int, *, out: Optional[Tensor] = None) -> Tensor: ... +def _make_per_channel_quantized_tensor(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int) -> Tensor: ... +def _make_per_tensor_quantized_tensor(input: Tensor, scale: _float, zero_point: _int) -> Tensor: ... +def _masked_scale(input: Tensor, mask: Tensor, scale: _float) -> Tensor: ... +def _masked_softmax(input: Tensor, mask: Tensor, dim: Optional[_int] = None, mask_type: Optional[_int] = None) -> Tensor: ... +def _mixed_dtypes_linear(input: Tensor, weight: Tensor, scale: Tensor, *, bias: Optional[Tensor] = None, activation: Optional[str] = None) -> Tensor: ... +def _mkldnn_reshape(input: Tensor, shape: _size) -> Tensor: ... +def _mkldnn_transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mkldnn_transpose_(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mps_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def _mps_convolution_transpose(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... 
+@overload +def _native_batch_norm_legit(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Tensor, running_var: Tensor, training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def _native_batch_norm_legit(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +def _native_batch_norm_legit_no_training(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Tensor, running_var: Tensor, momentum: _float, eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def _native_multi_head_attention(query: Tensor, key: Tensor, value: Tensor, embed_dim: _int, num_head: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, mask: Optional[Tensor] = None, need_weights: _bool = True, average_attn_weights: _bool = True, mask_type: Optional[_int] = None) -> Tuple[Tensor, Tensor]: ... +def _neg_view(input: Tensor) -> Tensor: ... +def _neg_view_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_from_padded(padded: Tensor, cpu_nested_shape_example: Tensor, fuse_transform_0213: _bool = False) -> Tensor: ... +def _nested_from_padded_and_nested_example(padded: Tensor, nt_example: Tensor) -> Tensor: ... +def _nested_get_jagged_dummy(any: Tensor) -> Tensor: ... +def _nested_get_lengths(input: Tensor) -> Tensor: ... +def _nested_get_offsets(input: Tensor) -> Tensor: ... +def _nested_get_ragged_idx(input: Tensor) -> _int: ... +def _nested_get_values(input: Tensor) -> Tensor: ... +def _nested_get_values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_tensor_from_mask(t: Tensor, mask: Tensor, mask_check: _bool = True) -> Tensor: ... +def _nested_tensor_from_mask_left_aligned(t: Tensor, mask: Tensor) -> _bool: ... +def _nested_tensor_from_tensor_list(list: Union[Tuple[Tensor, ...], List[Tensor]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = None) -> Tensor: ... +def _nested_tensor_softmax_with_shape(input: Tensor, query: Tensor) -> Tensor: ... +def _nested_view_from_buffer(input: Tensor, nested_size: Tensor, nested_strides: Tensor, offsets: Tensor) -> Tensor: ... +def _nested_view_from_buffer_copy(input: Tensor, nested_size: Tensor, nested_strides: Tensor, offsets: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nested_view_from_jagged(input: Tensor, offsets: Tensor, dummy: Tensor, lengths: Optional[Tensor] = None, ragged_idx: _int = 1) -> Tensor: ... +def _nested_view_from_jagged_copy(input: Tensor, offsets: Tensor, dummy: Tensor, lengths: Optional[Tensor] = None, ragged_idx: _int = 1, *, out: Optional[Tensor] = None) -> Tensor: ... +def _nnpack_available() -> _bool: ... +def _nnpack_spatial_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]], stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def _pack_padded_sequence(input: Tensor, lengths: Tensor, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def _pad_packed_sequence(data: Tensor, batch_sizes: Tensor, batch_first: _bool, padding_value: Union[Number, _complex], total_length: _int) -> Tuple[Tensor, Tensor]: ... 
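# --- Annotation (illustrative sketch, not part of the vendored files) -------
# _pack_padded_sequence/_pad_packed_sequence above back the public RNN packing
# utilities. A hedged round-trip example with batch-first padded input:
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.randn(3, 5, 7)            # (batch, max_seq_len, features)
lengths = torch.tensor([5, 3, 2])   # true lengths, sorted descending
packed = pack_padded_sequence(x, lengths, batch_first=True)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
# ---------------------------------------------------------------------------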
+def _pin_memory(input: Tensor, device: Optional[Optional[DeviceLikeType]] = None) -> Tensor: ... +def _prelu_kernel(input: Tensor, weight: Tensor) -> Tensor: ... +def _print(s: str) -> None: ... +def _propagate_xla_data(input: Tensor, output: Tensor) -> None: ... +def _remove_batch_dim(input: Tensor, level: _int, batch_size: _int, out_dim: _int) -> Tensor: ... +def _reshape_alias_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None) -> Tensor: ... +def _reshape_from_tensor(input: Tensor, shape: Tensor) -> Tensor: ... +def _resize_output_(input: Tensor, size: Sequence[Union[_int, SymInt]], device: Optional[DeviceLikeType]) -> Tensor: ... +def _rowwise_prune(weight: Tensor, mask: Tensor, compressed_indices_dtype: _dtype) -> Tuple[Tensor, Tensor]: ... +def _sample_dirichlet(input: Tensor, generator: Optional[Generator] = None) -> Tensor: ... +def _saturate_weight_to_fp16(weight: Tensor) -> Tensor: ... +def _scaled_dot_product_attention_math(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: _float = 0.0, is_causal: _bool = False, dropout_mask: Optional[Tensor] = None, *, scale: Optional[_float] = None) -> Tuple[Tensor, Tensor]: ... +def _scaled_dot_product_cudnn_attention(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, return_debug_mask: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_cudnn_attention: ... +def _scaled_dot_product_efficient_attention(query: Tensor, key: Tensor, value: Tensor, attn_bias: Optional[Tensor], compute_log_sumexp: _bool, dropout_p: _float = 0.0, is_causal: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_efficient_attention: ... +def _scaled_dot_product_flash_attention(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, return_debug_mask: _bool = False, *, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_flash_attention: ... +def _scaled_dot_product_flash_attention_for_cpu(query: Tensor, key: Tensor, value: Tensor, dropout_p: _float = 0.0, is_causal: _bool = False, *, attn_mask: Optional[Tensor] = None, scale: Optional[_float] = None) -> torch.return_types._scaled_dot_product_flash_attention_for_cpu: ... +def _scaled_mm(input: Tensor, mat2: Tensor, *, bias: Optional[Tensor] = None, out_dtype: Optional[_dtype] = None, scale_a: Optional[Tensor] = None, scale_b: Optional[Tensor] = None, scale_result: Optional[Tensor] = None, use_fast_accum: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor]: ... +def _shape_as_tensor(input: Tensor) -> Tensor: ... +def _sobol_engine_draw(quasi: Tensor, n: _int, sobolstate: Tensor, dimension: _int, num_generated: _int, dtype: Optional[_dtype]) -> Tuple[Tensor, Tensor]: ... +def _sobol_engine_ff_(input: Tensor, n: _int, sobolstate: Tensor, dimension: _int, num_generated: _int) -> Tensor: ... +def _sobol_engine_initialize_state_(input: Tensor, dimension: _int) -> Tensor: ... +def _sobol_engine_scramble_(input: Tensor, ltm: Tensor, dimension: _int) -> Tensor: ... +def _softmax(input: Tensor, dim: _int, half_to_float: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def _softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input_dtype: _dtype, *, grad_input: Optional[Tensor] = None) -> Tensor: ... 
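+# The `_scaled_dot_product_*` kernels declared above sit behind the public
+# `torch.nn.functional.scaled_dot_product_attention`; on a CPU-only build such
+# as this one, the math / flash-for-cpu variants are the ones typically selected.
+# A minimal sketch (shapes are illustrative, not prescriptive):
+#
+#   >>> import torch
+#   >>> import torch.nn.functional as F
+#   >>> q = k = v = torch.randn(2, 4, 8, 16)       # (batch, heads, seq, head_dim)
+#   >>> out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+#   >>> out.shape
+#   torch.Size([2, 4, 8, 16])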
+def _sparse_broadcast_to(input: Tensor, size: _size) -> Tensor: ... +def _sparse_broadcast_to_copy(input: Tensor, size: _size, *, out: Optional[Tensor] = None) -> Tensor: ... +def _sparse_csr_prod(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_csr_sum(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_log_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input: Tensor) -> Tensor: ... +def _sparse_semi_structured_linear(input: Tensor, weight: Tensor, meta: Tensor, *, bias: Optional[Tensor] = None, activation: Optional[str] = None, out_dtype: Optional[_dtype] = None) -> Tensor: ... +def _sparse_softmax_backward_data(grad_output: Tensor, output: Tensor, dim: _int, input: Tensor) -> Tensor: ... +def _sparse_sparse_matmul(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, *, dtype: _dtype) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: Union[_int, _size]) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: Union[_int, _size], *, dtype: _dtype) -> Tensor: ... +def _stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: ... +def _standard_gamma(input: Tensor, generator: Optional[Generator] = None) -> Tensor: ... +def _standard_gamma_grad(input: Tensor, output: Tensor) -> Tensor: ... +def _sync(t: Tensor) -> None: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor) -> Tensor: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor, b: _bool) -> Tensor: ... +def _test_autograd_multiple_dispatch_view(input: Tensor) -> Tensor: ... +def _test_autograd_multiple_dispatch_view_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _test_check_tensor(input: Tensor) -> Tensor: ... +def _test_functorch_fallback(input: Tensor, other: Tensor) -> Tensor: ... +def _test_parallel_materialize(input: Tensor, num_parallel: _int, skip_first: _bool = False) -> Tensor: ... +def _test_serialization_subcmul(input: Tensor, other: Tensor, alpha: Union[Number, _complex] = 1) -> Tensor: ... +def _to_cpu(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: ... +def _to_functional_tensor(t: Tensor) -> Tensor: ... +def _to_sparse_semi_structured(dense: Tensor) -> Tuple[Tensor, Tensor]: ... +def _transform_bias_rescale_qkv(qkv: Tensor, qkv_bias: Tensor, num_heads: _int) -> Tuple[Tensor, Tensor, Tensor]: ... +def _transformer_encoder_layer_fwd(src: Tensor, embed_dim: _int, num_heads: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, use_gelu: _bool, norm_first: _bool, eps: _float, norm_weight_1: Tensor, norm_bias_1: Tensor, norm_weight_2: Tensor, norm_bias_2: Tensor, ffn_weight_1: Tensor, ffn_bias_1: Tensor, ffn_weight_2: Tensor, ffn_bias_2: Tensor, mask: Optional[Tensor] = None, mask_type: Optional[_int] = None) -> Tensor: ... +def _trilinear(i1: Tensor, i2: Tensor, i3: Tensor, expand1: _size, expand2: _size, expand3: _size, sumdim: _size, unroll_dim: _int = 1) -> Tensor: ... +def _triton_multi_head_attention(query: Tensor, key: Tensor, value: Tensor, embed_dim: _int, num_head: _int, qkv_weight: Tensor, qkv_bias: Tensor, proj_weight: Tensor, proj_bias: Tensor, mask: Optional[Tensor] = None) -> Tensor: ... 
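+# The `_sparse_sum` overloads declared above are reached through the public
+# `torch.sparse.sum` wrapper. A small sketch on a 2x3 COO tensor: reducing over
+# all sparse dimensions yields a dense scalar, while a partial reduction stays
+# sparse (hence the `.to_dense()` for display):
+#
+#   >>> import torch
+#   >>> i = torch.tensor([[0, 1, 1], [2, 0, 2]])
+#   >>> v = torch.tensor([3., 4., 5.])
+#   >>> s = torch.sparse_coo_tensor(i, v, (2, 3))
+#   >>> torch.sparse.sum(s)
+#   tensor(12.)
+#   >>> torch.sparse.sum(s, dim=1).to_dense()
+#   tensor([3., 9.])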
+def _triton_scaled_dot_attention(q: Tensor, k: Tensor, v: Tensor, dropout_p: _float = 0.0) -> Tensor: ... +def _unique(input: Tensor, sorted: _bool = True, return_inverse: _bool = False) -> Tuple[Tensor, Tensor]: ... +def _unique2(input: Tensor, sorted: _bool = True, return_inverse: _bool = False, return_counts: _bool = False) -> Tuple[Tensor, Tensor, Tensor]: ... +def _unpack_dual(dual: Tensor, level: _int) -> torch.return_types._unpack_dual: ... +def _unsafe_index(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]]) -> Tensor: ... +def _unsafe_index_put(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +@overload +def _use_cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int) -> _bool: ... +@overload +def _use_cudnn_ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int) -> _bool: ... +def _use_cudnn_rnn_flatten_weight() -> _bool: ... +def _validate_compressed_sparse_indices(is_crow: _bool, compressed_idx: Tensor, plain_idx: Tensor, cdim: _int, dim: _int, nnz: _int) -> None: ... +def _validate_sparse_bsc_tensor_args(ccol_indices: Tensor, row_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_bsr_tensor_args(crow_indices: Tensor, col_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_compressed_tensor_args(compressed_indices: Tensor, plain_indices: Tensor, values: Tensor, size: _size, layout: _layout) -> None: ... +def _validate_sparse_coo_tensor_args(indices: Tensor, values: Tensor, size: _size, is_coalesced: Optional[_bool] = None) -> None: ... +def _validate_sparse_csc_tensor_args(ccol_indices: Tensor, row_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _validate_sparse_csr_tensor_args(crow_indices: Tensor, col_indices: Tensor, values: Tensor, size: _size) -> None: ... +def _values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def _weight_int4pack_mm(input: Tensor, mat2: Tensor, qGroupSize: _int, qScaleAndZeros: Tensor) -> Tensor: ... +def _weight_int8pack_mm(input: Tensor, mat2: Tensor, scales: Tensor) -> Tensor: ... +def _weight_norm(v: Tensor, g: Tensor, dim: _int = 0) -> Tensor: ... +def _weight_norm_interface(v: Tensor, g: Tensor, dim: _int = 0) -> Tuple[Tensor, Tensor]: ... +def abs(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + abs(input, *, out=None) -> Tensor + + Computes the absolute value of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = |\text{input}_{i}| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) + """ + ... +def abs_(input: Tensor) -> Tensor: ... +def absolute(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + absolute(input, *, out=None) -> Tensor + + Alias for :func:`torch.abs` + """ + ... +def acos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + acos(input, *, out=None) -> Tensor + + Computes the inverse cosine of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) + """ + ... +def acos_(input: Tensor) -> Tensor: ... +def acosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + acosh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + + Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) + """ + ... +def acosh_(input: Tensor) -> Tensor: ... +def adaptive_avg_pool1d(input: Tensor, output_size: Union[_int, _size]) -> Tensor: ... +def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ... +@overload +def add(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def add(self: Tensor, alpha: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. 
+ + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def add(self: Tensor, alpha: Union[Number, _complex], other: Tensor, *, out: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. 
+ + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + ... +@overload +def addcdiv(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcdiv(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor, *, out: Tensor) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcdiv(input: Tensor, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + ... +@overload +def addcmul(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addcmul(self: Tensor, value: Union[Number, _complex], tensor1: Tensor, tensor2: Tensor, *, out: Tensor) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addcmul(input: Tensor, tensor1: Tensor, tensor2: Tensor, *, value: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. 
math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor, *, out: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. 
+ The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. 
warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... 
+@overload +def addmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor, *, out: Tensor) -> Tensor: + r""" + addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor, *, out: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(input: Tensor, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor, *, out: Tensor) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + ... +@overload +def addmv_(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat: Tensor, vec: Tensor) -> Tensor: ... +@overload +def addmv_(input: Tensor, mat: Tensor, vec: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def addmv_(beta: Union[Number, _complex], self: Tensor, mat: Tensor, vec: Tensor) -> Tensor: ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], vec1: Tensor, vec2: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], vec1: Tensor, vec2: Tensor, *, out: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. 
+ + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(input: Tensor, vec1: Tensor, vec2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, vec1: Tensor, vec2: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +@overload +def addr(beta: Union[Number, _complex], self: Tensor, vec1: Tensor, vec2: Tensor, *, out: Tensor) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + ... +def adjoint(input: Tensor) -> Tensor: + r""" + adjoint(Tensor) -> Tensor + Returns a view of the tensor conjugated and with the last two dimensions transposed. + + ``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and + to ``x.transpose(-2, -1)`` for real tensors. + + Example:: + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) + """ + ... +def affine_grid_generator(theta: Tensor, size: Sequence[Union[_int, SymInt]], align_corners: _bool) -> Tensor: ... +def alias_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.alias`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def all(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. 
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: Optional[_size] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. 
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +@overload +def all(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + all(input) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + ... +def allclose(input: Tensor, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> _bool: + r""" + allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool + + This function checks if :attr:`input` and :attr:`other` satisfy the condition: + + .. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert + + elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to + `numpy.allclose `_ + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + + Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True + """ + ... +def alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def amax(input: Tensor, dim: Union[_int, _size] = (), keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + amax(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the maximum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) + """ + ... +def amin(input: Tensor, dim: Union[_int, _size] = (), keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + amin(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the minimum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) + """ + ... +def aminmax(input: Tensor, *, dim: Optional[_int] = None, keepdim: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.aminmax: + r""" + aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + + Computes the minimum and maximum values of the :attr:`input` tensor. + + Args: + input (Tensor): + The input tensor + + Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + + Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + + Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + + .. note:: + NaN values are propagated to the output if at least one value is NaN. + + .. seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + + Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) + """ + ... +def angle(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + angle(input, *, out=None) -> Tensor + + Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + + .. math:: + \text{out}_{i} = angle(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + .. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + + Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) + """ + ... +@overload +def any(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. 
function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: Optional[_size] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. 
function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def any(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + any(input) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + ... +@overload +def arange(start: Number, end: Number, step: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. 
+ + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Number, end: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(end: Number, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(end: Union[Number, _complex], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. 
+ + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Union[Number, _complex], end: Union[Number, _complex], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +@overload +def arange(start: Union[Number, _complex], end: Union[Number, _complex], step: Union[Number, _complex] = 1, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + ... +def arccos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arccos(input, *, out=None) -> Tensor + + Alias for :func:`torch.acos`. + """ + ... +def arccos_(input: Tensor) -> Tensor: ... +def arccosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arccosh(input, *, out=None) -> Tensor + + Alias for :func:`torch.acosh`. + """ + ... +def arccosh_(input: Tensor) -> Tensor: ... 
+def arcsin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arcsin(input, *, out=None) -> Tensor + + Alias for :func:`torch.asin`. + """ + ... +def arcsin_(input: Tensor) -> Tensor: ... +def arcsinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arcsinh(input, *, out=None) -> Tensor + + Alias for :func:`torch.asinh`. + """ + ... +def arcsinh_(input: Tensor) -> Tensor: ... +def arctan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctan(input, *, out=None) -> Tensor + + Alias for :func:`torch.atan`. + """ + ... +def arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctan2(input, other, *, out=None) -> Tensor + Alias for :func:`torch.atan2`. + """ + ... +def arctan_(input: Tensor) -> Tensor: ... +def arctanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + arctanh(input, *, out=None) -> Tensor + + Alias for :func:`torch.atanh`. + """ + ... +def arctanh_(input: Tensor) -> Tensor: ... +def argmax(input: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + argmax(input) -> LongTensor + + Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple maximal values then the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + + .. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + + Returns the indices of the maximum values of a tensor across a dimension. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. If ``None``, the argmax of the flattened input is returned. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) + """ + ... +def argmin(input: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + argmin(input, dim=None, keepdim=False) -> LongTensor + + Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + + This is the second value returned by :meth:`torch.min`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. If ``None``, the argmin of the flattened input is returned. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) + """ + ... +@overload +def argsort(input: Tensor, *, stable: _bool, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +@overload +def argsort(input: Tensor, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +@overload +def argsort(input: Tensor, dim: Union[str, ellipsis, None], descending: _bool = False) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. 
+ dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + ... +def argwhere(input: Tensor) -> Tensor: + r""" + argwhere(input) -> Tensor + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + .. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. + + Args: + {input} + + Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) + """ + ... +def as_strided(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided(input, size, stride, storage_offset=None) -> Tensor + + Create a view of an existing `torch.Tensor` :attr:`input` with specified + :attr:`size`, :attr:`stride` and :attr:`storage_offset`. + + .. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.expand`, + to setting a view's strides manually with `as_strided`, as this + function's behavior depends on the implementation of a tensor's storage. + The constructed view of the storage must only refer to elements within + the storage or a runtime error will be thrown, and if the view is + "overlapped" (with multiple indices referring to the same element in + memory) its behavior is undefined. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + + Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) + """ + ... +def as_strided_(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: ... 
+def as_strided_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.as_strided`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def as_strided_scatter(input: Tensor, src: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the elements corresponding to the result of calling + input.as_strided(size, stride, storage_offset). + + This function returns a tensor with fresh storage; it does not + return a view. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + + Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + """ + ... +def as_tensor(data: Any, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None) -> Tensor: + r""" + as_tensor(data, dtype=None, device=None) -> Tensor + + Converts :attr:`data` into a tensor, sharing data and preserving autograd + history if possible. + + If :attr:`data` is already a tensor with the requested dtype and device + then :attr:`data` itself is returned, but if :attr:`data` is a + tensor with a different dtype or device then it's copied as if using + `data.to(dtype=dtype, device=device)`. + + If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a + tensor is constructed using :func:`torch.from_numpy`. + + .. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) + """ + ... 
+def asarray(obj: Any, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, copy: Optional[_bool] = None, requires_grad: _bool = False) -> Tensor: + r""" + asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor + + Converts :attr:`obj` to a tensor. + + :attr:`obj` can be one of: + + 1. a tensor + 2. a NumPy array or a NumPy scalar + 3. a DLPack capsule + 4. an object that implements Python's buffer protocol + 5. a scalar + 6. a sequence of scalars + + When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, + by default, not require a gradient, have the same datatype as :attr:`obj`, be on the + same device, and share memory with it. These properties can be controlled with the + :attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. + If the returned tensor is of a different datatype, on a different device, or a copy is + requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` + is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is + also a tensor with an autograd history then the returned tensor will have the same history. + + When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's + buffer protocol then the buffer is interpreted as an array of bytes grouped according to + the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is + passed then the default floating point datatype is used, instead.) The returned tensor + will have the specified datatype (or default floating point datatype if none is specified) + and, by default, be on the CPU device and share memory with the buffer. + + When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on + the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will + be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + + When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the + returned tensor will, by default, infer its datatype from the scalar values, be on the + current default device, and not share its memory. + + .. seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + + Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. + Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. 
Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) + """ + ... +def asin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + asin(input, *, out=None) -> Tensor + + Returns a new tensor with the arcsine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) + """ + ... +def asin_(input: Tensor) -> Tensor: ... +def asinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + asinh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) + """ + ... +def asinh_(input: Tensor) -> Tensor: ... +def atan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atan(input, *, out=None) -> Tensor + + Returns a new tensor with the arctangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) + """ + ... +def atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atan2(input, other, *, out=None) -> Tensor + + Element-wise arctangent of :math:`\text{input}_{i} / \text{other}_{i}` + with consideration of the quadrant. 
Returns a new tensor with the signed angles + in radians between vector :math:`(\text{other}_{i}, \text{input}_{i})` + and vector :math:`(1, 0)`. (Note that :math:`\text{other}_{i}`, the second + parameter, is the x-coordinate, while :math:`\text{input}_{i}`, the first + parameter, is the y-coordinate.) + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) + """ + ... +def atan_(input: Tensor) -> Tensor: ... +def atanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + atanh(input, *, out=None) -> Tensor + + Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + + Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + + .. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) + """ + ... +def atanh_(input: Tensor) -> Tensor: ... +def avg_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, ceil_mode: _bool = False, count_include_pad: _bool = True) -> Tensor: ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(input: Tensor, batch1: Tensor, batch2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. 
Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... 
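+# Minimal sketch of the baddbmm formula documented above,
+# out = beta * input + alpha * (batch1 @ batch2); the shapes and the
+# beta/alpha values here are arbitrary examples.
+#
+#     >>> M = torch.randn(10, 3, 5)
+#     >>> b1, b2 = torch.randn(10, 3, 4), torch.randn(10, 4, 5)
+#     >>> out = torch.baddbmm(M, b1, b2, beta=0.5, alpha=2.0)
+#     >>> torch.allclose(out, 0.5 * M + 2.0 * torch.bmm(b1, b2))
+#     True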
+@overload +def baddbmm(beta: Union[Number, _complex], self: Tensor, batch1: Tensor, batch2: Tensor, *, out: Tensor) -> Tensor: + r""" + baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + ... +@overload +def bartlett_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. 
If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def bartlett_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... 
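+# Sketch of the periodic/symmetric relationship stated in the bartlett_window
+# docstring above; the window length L is an arbitrary example.
+#
+#     >>> L = 8
+#     >>> torch.allclose(torch.bartlett_window(L, periodic=True),
+#     ...                torch.bartlett_window(L + 1, periodic=False)[:-1])
+#     True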
+def batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ... +def batch_norm_backward_elemt(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], sum_dy: Tensor, sum_dy_xmu: Tensor, count: Tensor) -> Tensor: ... +def batch_norm_backward_reduce(grad_out: Tensor, input: Tensor, mean: Tensor, invstd: Tensor, weight: Optional[Tensor], input_g: _bool, weight_g: _bool, bias_g: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def batch_norm_elemt(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], mean: Tensor, invstd: Tensor, eps: _float, *, out: Optional[Tensor] = None) -> Tensor: ... +def batch_norm_gather_stats(input: Tensor, mean: Tensor, invstd: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float, eps: _float, count: _int) -> Tuple[Tensor, Tensor]: ... +def batch_norm_gather_stats_with_counts(input: Tensor, mean: Tensor, invstd: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float, eps: _float, counts: Tensor) -> Tuple[Tensor, Tensor]: ... +def batch_norm_stats(input: Tensor, eps: _float) -> Tuple[Tensor, Tensor]: ... +def batch_norm_update_stats(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], momentum: _float) -> Tuple[Tensor, Tensor]: ... +@overload +def bernoulli(input: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + bernoulli(input, *, generator=None, out=None) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. + + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... +@overload +def bernoulli(input: Tensor, p: _float, *, generator: Optional[Generator] = None) -> Tensor: + r""" + bernoulli(input, *, generator=None, out=None) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. 
+ + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... +def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ... +def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, weight: Optional[Tensor] = None, pos_weight: Optional[Tensor] = None, reduction: _int = 1) -> Tensor: ... +def bincount(input: Tensor, weights: Optional[Tensor] = None, minlength: _int = 0) -> Tensor: + r""" + bincount(input, weights=None, minlength=0) -> Tensor + + Count the frequency of each value in an array of non-negative ints. + + The number of bins (size 1) is one larger than the largest value in + :attr:`input` unless :attr:`input` is empty, in which case the result is a + tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least + :attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size + :attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, + ``out[n] += weights[i]`` if :attr:`weights` is specified else + ``out[n] += 1``. + + Note: + This operation may produce nondeterministic gradients when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. + + Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + + Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) + """ + ... +def binomial(count: Tensor, prob: Tensor, generator: Optional[Generator] = None) -> Tensor: ... 
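+# Sketch of the minlength argument described in the bincount docstring above
+# (the input values and minlength are arbitrary examples):
+#
+#     >>> x = torch.tensor([1, 1, 3])
+#     >>> torch.bincount(x, minlength=6)
+#     tensor([0, 2, 0, 1, 0, 0])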
+@overload +def bitwise_and(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_and(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_and(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + ... +@overload +def bitwise_left_shift(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +@overload +def bitwise_left_shift(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. 
+ + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +@overload +def bitwise_left_shift(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + ... +def bitwise_not(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_not(input, *, out=None) -> Tensor + + Computes the bitwise NOT of the given input tensor. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical NOT. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) + """ + ... +@overload +def bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... +@overload +def bitwise_or(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... 
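+# Sketch of the scalar overloads above: the bitwise ops also accept a plain
+# Python int as the second operand (the operands here are arbitrary examples).
+#
+#     >>> torch.bitwise_and(torch.tensor([0, 1, 2, 3]), 1)
+#     tensor([0, 1, 0, 1])
+#     >>> torch.bitwise_or(torch.tensor([0, 1, 4]), 2)
+#     tensor([2, 3, 6])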
+@overload +def bitwise_or(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_or(input, other, *, out=None) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + ... +@overload +def bitwise_right_shift(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_right_shift(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_right_shift(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. 
math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + ... +@overload +def bitwise_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def bitwise_xor(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def bitwise_xor(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + ... +@overload +def blackman_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. 
:attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def blackman_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + bmm(input, mat2, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. + + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + ... +def broadcast_to(input: Tensor, size: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + broadcast_to(input, shape) -> Tensor + + Broadcasts :attr:`input` to the shape :attr:`\shape`. + Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + + Args: + input (Tensor): the input tensor. + shape (list, tuple, or :class:`torch.Size`): the new shape. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) + """ + ... +@overload +def bucketize(input: Tensor, boundaries: Tensor, *, out_int32: _bool = False, right: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. 
torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of :attr:`boundaries` (one pass the last index). + In other words, if False, gets the lower bound index for each value in :attr:`input` + from :attr:`boundaries`. If True, gets the upper bound index instead. + Default value is False. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + ... +@overload +def bucketize(self: Union[Number, _complex], boundaries: Tensor, *, out_int32: _bool = False, right: _bool = False) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of :attr:`boundaries` (one pass the last index). + In other words, if False, gets the lower bound index for each value in :attr:`input` + from :attr:`boundaries`. If True, gets the upper bound index instead. + Default value is False. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + ... +def can_cast(from_: _dtype, to: _dtype) -> _bool: + r""" + can_cast(from, to) -> bool + + Determines if a type conversion is allowed under PyTorch casting rules + described in the type promotion :ref:`documentation `. + + Args: + from (dtype): The original :class:`torch.dtype`. 
+ to (dtype): The target :class:`torch.dtype`. + + Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False + """ + ... +@overload +def cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of :attr:`seq` tensors in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): any python sequence of tensors of the same type. + Non-empty tensors provided must have the same shape, except in the + cat dimension. + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + ... +@overload +def cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of :attr:`seq` tensors in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): any python sequence of tensors of the same type. + Non-empty tensors provided must have the same shape, except in the + cat dimension. + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + ... +def ccol_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... 
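+# Sketch of the torch.cat / torch.chunk relationship noted in the cat docstring
+# above (the tensor shape and chunk count are arbitrary examples):
+#
+#     >>> x = torch.arange(6).reshape(2, 3)
+#     >>> chunks = torch.chunk(x, 3, dim=1)
+#     >>> torch.equal(torch.cat(chunks, dim=1), x)
+#     True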
+def ceil(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ceil(input, *, out=None) -> Tensor + + Returns a new tensor with the ceil of the elements of :attr:`input`, + the smallest integer greater than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) + """ + ... +def ceil_(input: Tensor) -> Tensor: ... +def celu(input: Tensor, alpha: Union[Number, _complex] = 1.0) -> Tensor: ... +def celu_(input: Tensor, alpha: Union[Number, _complex] = 1.0) -> Tensor: ... +def channel_shuffle(input: Tensor, groups: Union[_int, SymInt]) -> Tensor: ... +def cholesky(input: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky(input, upper=False, *, out=None) -> Tensor + + Computes the Cholesky decomposition of a symmetric positive-definite + matrix :math:`A` or for batches of symmetric positive-definite matrices. + + If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and + the decomposition has the form: + + .. math:: + + A = U^TU + + If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and + the decomposition has the form: + + .. math:: + + A = LL^T + + If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite + matrices, then the returned tensor will be composed of upper-triangular Cholesky factors + of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned + tensor will be composed of lower-triangular Cholesky factors of each of the individual + matrices. + + .. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + + Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. Default: ``False`` + + Keyword args: + out (Tensor, optional): the output matrix + + Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) + """ + ... 
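+# Sketch of the migration suggested in the torch.cholesky warning above: for a
+# symmetric positive-definite A (constructed here as an arbitrary example), the
+# deprecated call and torch.linalg.cholesky agree on the lower-triangular factor.
+#
+#     >>> A = torch.randn(3, 3)
+#     >>> A = A @ A.mT + 1e-3 * torch.eye(3)
+#     >>> torch.allclose(torch.cholesky(A), torch.linalg.cholesky(A))
+#     True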
+def cholesky_inverse(input: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky_inverse(L, upper=False, *, out=None) -> Tensor + + Computes the inverse of a complex Hermitian or real symmetric + positive-definite matrix given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Computes the inverse matrix :math:`A^{-1}`. + + Supports input of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` is a batch of matrices + then the output has the same batch dimensions. + + Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) + """ + ... +def cholesky_solve(input: Tensor, input2: Tensor, upper: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + + Computes the solution of a system of linear equations with complex Hermitian + or real symmetric positive-definite lhs given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Returns the solution :math:`X` of the following linear system: + + .. math:: + + AX = B + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices + then the output has the same batch dimensions. + + Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) + """ + ... +def choose_qparams_optimized(input: Tensor, numel: _int, n_bins: _int, ratio: _float, bit_width: _int) -> Tuple[Tensor, Tensor]: ... +def chunk(input: Tensor, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + chunk(input, chunks, dim=0) -> List of Tensors + + Attempts to split a tensor into the specified number of chunks. Each chunk is a view of + the input tensor. + + + .. note:: + + This function may return fewer than the specified number of chunks! + + .. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + + If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, + all returned chunks will be the same size. + If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, + all returned chunks will be the same size, except the last one. + If such division is not possible, this function may return fewer + than the specified number of chunks. + + Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + + Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) + """ + ... +@overload +def clamp(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + ... +@overload +def clamp(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + ... +@overload +def clamp_(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: ... +@overload +def clamp_max(input: Tensor, max: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_max(input: Tensor, max: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Tensor) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Union[Number, _complex]) -> Tensor: ... +@overload +def clamp_min(input: Tensor, min: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_min(input: Tensor, min: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Tensor) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Union[Number, _complex]) -> Tensor: ... +@overload +def clip(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + ... +@overload +def clip(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + ... +@overload +def clip_(input: Tensor, min: Optional[Tensor] = None, max: Optional[Tensor] = None) -> Tensor: ... +@overload +def clip_(input: Tensor, min: Optional[Union[Number, _complex]] = None, max: Optional[Union[Number, _complex]] = None) -> Tensor: ... 
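+# clamp_max / clamp_min above are the undocumented one-sided counterparts of clamp;
+# a small deterministic illustration (each behaves like clamp with only max= or min= set):
+#
+#     >>> t = torch.tensor([-2.0, 0.5, 3.0])
+#     >>> torch.clamp_max(t, 1.0)             # like torch.clamp(t, max=1.0)
+#     tensor([-2.0000,  0.5000,  1.0000])
+#     >>> torch.clamp_min(t, 0.0)             # like torch.clamp(t, min=0.0)
+#     tensor([0.0000, 0.5000, 3.0000])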
+def clone(input: Tensor, *, memory_format: Optional[memory_format] = None) -> Tensor: + r""" + clone(input, *, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of :attr:`input`. + + .. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. + """ + ... +def col_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.col_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def column_stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + column_stack(tensors, *, out=None) -> Tensor + + Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + + Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` + in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + """ + ... +def combinations(input: Tensor, r: _int = 2, with_replacement: _bool = False) -> Tensor: + r""" + combinations(input, r=2, with_replacement=False) -> seq + + Compute combinations of length :math:`r` of the given tensor. The behavior is similar to + python's `itertools.combinations` when `with_replacement` is set to `False`, and + `itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + + Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + + Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. + + Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + """ + ... +def complex(real: Tensor, imag: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + complex(real, imag, *, out=None) -> Tensor + + Constructs a complex tensor with its real part equal to :attr:`real` and its + imaginary part equal to :attr:`imag`. 
+ + Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + """ + ... +@overload +def concat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concatenate(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +@overload +def concatenate(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + ... +def conj(input: Tensor) -> Tensor: + r""" + conj(input) -> Tensor + + Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, + this function just returns :attr:`input`. + + .. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + + .. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True + """ + ... +def conj_physical(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + conj_physical(input, *, out=None) -> Tensor + + Computes the element-wise conjugate of the given :attr:`input` tensor. + If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + + .. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + + .. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + .. math:: + \text{out}_{i} = conj(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + """ + ... +def conj_physical_(input: Tensor) -> Tensor: ... 
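+# The note under conj above mentions lazy conjugation; a short sketch of how the
+# conjugate bit relates to resolve_conj and conj_physical (illustrative only):
+#
+#     >>> z = torch.tensor([1 + 2j])
+#     >>> v = torch.conj(z)                   # view with the conjugate bit set
+#     >>> v.is_conj()
+#     True
+#     >>> torch.resolve_conj(v).is_conj()     # materialized copy, bit cleared
+#     False
+#     >>> torch.conj_physical(z)              # eager element-wise conjugate
+#     tensor([1.-2.j])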
+def constant_pad_nd(input: Tensor, pad: Sequence[Union[_int, SymInt]], value: Union[Number, _complex] = 0) -> Tensor: ... +@overload +def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +@overload +def conv3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: str = "valid", dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, groups: Union[_int, SymInt] = 1) -> Tensor: ... +def conv_tbc(input: Tensor, weight: Tensor, bias: Tensor, pad: _int = 0) -> Tensor: ... +def conv_transpose1d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def conv_transpose2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... +def conv_transpose3d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1, padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, output_padding: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 0, groups: Union[_int, SymInt] = 1, dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] = 1) -> Tensor: ... 
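+# The functional conv stubs above carry no docstrings here; they follow the usual
+# torch.nn.functional shape conventions.  A minimal sketch, with sizes picked purely
+# for illustration:
+#
+#     >>> x = torch.randn(1, 3, 8, 8)                 # (N, C_in, H, W)
+#     >>> w = torch.randn(6, 3, 3, 3)                 # (C_out, C_in / groups, kH, kW)
+#     >>> torch.conv2d(x, w, padding=1).shape         # "same" spatial size for a 3x3 kernel
+#     torch.Size([1, 6, 8, 8])
+#     >>> wt = torch.randn(6, 3, 2, 2)                # transposed conv weight: (C_in, C_out / groups, kH, kW)
+#     >>> torch.conv_transpose2d(torch.randn(1, 6, 8, 8), wt, stride=2).shape
+#     torch.Size([1, 3, 16, 16])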
+def convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], transposed: _bool, output_padding: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +@overload +def copysign(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + ... +@overload +def copysign(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. 
If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + ... +def corrcoef(input: Tensor) -> Tensor: + r""" + corrcoef(input) -> Tensor + + Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, + where rows are the variables and columns are the observations. + + .. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + + .. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Returns: + (Tensor) The correlation coefficient matrix of the variables. + + .. seealso:: + + :func:`torch.cov` covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) + """ + ... +def cos(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cos(input, *, out=None) -> Tensor + + Returns a new tensor with the cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cos(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) + """ + ... +def cos_(input: Tensor) -> Tensor: ... +def cosh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cosh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic cosine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + ... +def cosh_(input: Tensor) -> Tensor: ... +def cosine_embedding_loss(input1: Tensor, input2: Tensor, target: Tensor, margin: _float = 0.0, reduction: _int = 1) -> Tensor: ... +def cosine_similarity(x1: Tensor, x2: Tensor, dim: _int = 1, eps: _float = 1e-08) -> Tensor: ... +@overload +def count_nonzero(input: Tensor, dim: Optional[_int] = None) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. 
+ + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + ... +@overload +def count_nonzero(input: Tensor, dim: _size) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + ... +def cov(input: Tensor, *, correction: _int = 1, fweights: Optional[Tensor] = None, aweights: Optional[Tensor] = None) -> Tensor: + r""" + cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + + Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are + the variables and columns are the observations. + + A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains + the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents + a single variable (Scalar or 1D) then its variance is returned. + + The sample covariance of the variables :math:`x` and :math:`y` is given by: + + .. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + + where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and + :math:`\delta N` is the :attr:`correction`. + + If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance + is calculated, which is given by: + + .. math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + + where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is + provided, or :math:`w = f \times a` if both are provided, and + :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not + provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. 
+ These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + + Returns: + (Tensor) The covariance matrix of the variables. + + .. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. + + Example:: + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) + """ + ... +def cross(input: Tensor, other: Tensor, dim: Optional[_int] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + cross(input, other, dim=None, *, out=None) -> Tensor + + + Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` + and :attr:`other`. + + Supports input of float, double, cfloat and cdouble dtypes. Also supports batches + of vectors, for which it computes the product along the dimension :attr:`dim`. + In this case, the output has the same batch dimensions as the inputs. + + .. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + + .. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + """ + ... +def crow_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.crow_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: _size, target_lengths: _size, blank: _int = 0, reduction: _int = 1, zero_infinity: _bool = False) -> Tensor: ... +@overload +def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor, blank: _int = 0, reduction: _int = 1, zero_infinity: _bool = False) -> Tensor: ... +def cudnn_affine_grid_generator(theta: Tensor, N: _int, C: _int, H: _int, W: _int) -> Tensor: ... 
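+# ctc_loss above is the raw overload that takes integer reduction codes; everyday code
+# normally goes through torch.nn.functional.ctc_loss instead.  A minimal hedged sketch
+# (sizes are arbitrary; class 0 is assumed to be the blank label):
+#
+#     >>> import torch.nn.functional as F
+#     >>> T, N, C, S = 30, 2, 10, 5                   # time steps, batch, classes, target length
+#     >>> log_probs = torch.randn(T, N, C).log_softmax(2)
+#     >>> targets = torch.randint(1, C, (N, S))
+#     >>> input_lengths = torch.full((N,), T, dtype=torch.long)
+#     >>> target_lengths = torch.full((N,), S, dtype=torch.long)
+#     >>> F.ctc_loss(log_probs, targets, input_lengths, target_lengths).shape   # scalar (mean reduction)
+#     torch.Size([])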
+def cudnn_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, exponential_average_factor: _float, epsilon: _float) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def cudnn_convolution(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, allow_tf32: _bool, *, out: Optional[Tensor] = None) -> Tensor: ... +def cudnn_convolution_add_relu(input: Tensor, weight: Tensor, z: Tensor, alpha: Optional[Union[Number, _complex]], bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def cudnn_convolution_relu(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def cudnn_convolution_transpose(input: Tensor, weight: Tensor, padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool, allow_tf32: _bool) -> Tensor: ... +def cudnn_grid_sampler(input: Tensor, grid: Tensor) -> Tensor: ... +def cudnn_is_acceptable(input: Tensor) -> _bool: ... +@overload +def cummax(input: Tensor, dim: _int, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + ... +@overload +def cummax(input: Tensor, dim: Union[str, ellipsis, None], *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + ... +@overload +def cummin(input: Tensor, dim: _int, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + ... +@overload +def cummin(input: Tensor, dim: Union[str, ellipsis, None], *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + ... +@overload +def cumprod(input: Tensor, dim: _int, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. 
Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + ... +@overload +def cumprod(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + ... +@overload +def cumsum(input: Tensor, dim: _int, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + ... +@overload +def cumsum(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + ... +@overload +def cumulative_trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + ... +@overload +def cumulative_trapezoid(y: Tensor, *, dx: Union[Number, _complex] = 1, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + ... +def deg2rad(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + deg2rad(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in degrees to radians. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + """ + ... +def deg2rad_(input: Tensor) -> Tensor: ... +@overload +def dequantize(input: Tensor) -> Tensor: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + ... +@overload +def dequantize(tensors: Union[Tuple[Tensor, ...], List[Tensor]]) -> Tuple[Tensor, ...]: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + ... +def det(input: Tensor) -> Tensor: + r""" + det(input) -> Tensor + + Alias for :func:`torch.linalg.det` + """ + ... +def detach(input: Tensor) -> Tensor: ... +def detach_(input: Tensor) -> Tensor: ... 
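+# detach above is the functional spelling of Tensor.detach; a tiny sketch of the
+# autograd behaviour it implies (illustrative only):
+#
+#     >>> a = torch.ones(3, requires_grad=True)
+#     >>> b = torch.detach(a * 2)             # same values, cut out of the autograd graph
+#     >>> b.requires_grad, b.grad_fn
+#     (False, None)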
+def detach_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.detach`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def diag(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + diag(input, diagonal=0, *, out=None) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + + The argument :attr:`diagonal` controls which diagonal to consider: + + - If :attr:`diagonal` = 0, it is the main diagonal. + - If :attr:`diagonal` > 0, it is above the main diagonal. + - If :attr:`diagonal` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + + Examples: + + Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) + """ + ... +def diag_embed(input: Tensor, offset: _int = 0, dim1: _int = -2, dim2: _int = -1) -> Tensor: + r""" + diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + + Creates a tensor whose diagonals of certain 2D planes (specified by + :attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. + To facilitate creating batched diagonal matrices, the 2D planes formed by + the last two dimensions of the returned tensor are chosen by default. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + The size of the new matrix will be calculated to make the specified diagonal + of the size of the last input dimension. + Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` + and :attr:`dim2` matters. Exchanging them is equivalent to changing the + sign of :attr:`offset`. + + Applying :meth:`torch.diagonal` to the output of this function with + the same arguments yields a matrix identical to input. However, + :meth:`torch.diagonal` has different default dimensions, so those + need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: -1. 
+ + Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) + """ + ... +def diagflat(input: Tensor, offset: _int = 0) -> Tensor: + r""" + diagflat(input, offset=0) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) + """ + ... +@overload +def diagonal(input: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
+ + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a, 0) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + ... +@overload +def diagonal(input: Tensor, *, outdim: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None], dim2: Union[str, ellipsis, None], offset: _int = 0) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. + + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a, 0) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + ... +def diagonal_copy(input: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.diagonal`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def diagonal_scatter(input: Tensor, src: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1) -> Tensor: + r""" + diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the diagonal elements of :attr:`input`, with respect to :attr:`dim1` + and :attr:`dim2`. + + This function returns a tensor with fresh storage; it does not + return a view. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. 
+ offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + + Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) + """ + ... +def diff(input: Tensor, n: _int = 1, dim: _int = -1, prepend: Optional[Tensor] = None, append: Optional[Tensor] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + + Computes the n-th forward difference along the given dimension. + + The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order + differences are calculated by using :func:`torch.diff` recursively. + + Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) + """ + ... +def digamma(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + digamma(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.digamma`. + """ + ... +def dist(input: Tensor, other: Tensor, p: Union[Number, _complex] = 2) -> Tensor: + r""" + dist(input, other, p=2) -> Tensor + + Returns the p-norm of (:attr:`input` - :attr:`other`) + + The shapes of :attr:`input` and :attr:`other` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) + >>> torch.dist(x, y, 1) + tensor(2.6537) + """ + ... +def div(input: Union[Tensor, Number], other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + div(input, other, *, rounding_mode=None, out=None) -> Tensor + + Divides each element of the input ``input`` by the corresponding element of + :attr:`other`. + + .. math:: + \text{out}_i = \frac{\text{input}_i}{\text{other}_i} + + .. 
note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + Always promotes integer types to the default scalar type. + + Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + + Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. + * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + out (Tensor, optional): the output tensor. + + Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... [-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + """ + ... +@overload +def divide(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Tensor, *, rounding_mode: Optional[str], out: Optional[Tensor] = None) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Union[Number, _complex], *, rounding_mode: Optional[str]) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +@overload +def divide(input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + ... +def dot(input: Tensor, tensor: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + dot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D tensors. + + .. note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + """ + ... +def dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... 
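+# Hedged sketch of the rounding_mode keyword documented for torch.div above: on a
+# negative quotient, "trunc" rounds toward zero while "floor" rounds toward
+# negative infinity, so the two modes can differ by one.
+#
+#   >>> torch.div(torch.tensor([-7.0]), torch.tensor([2.0]), rounding_mode='trunc')
+#   tensor([-3.])
+#   >>> torch.div(torch.tensor([-7.0]), torch.tensor([2.0]), rounding_mode='floor')
+#   tensor([-4.])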
+def dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def dsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + ... +@overload +def dsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + ... +def dstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + dstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence depthwise (along third axis). + + This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + """ + ... +def embedding(weight: Tensor, indices: Tensor, padding_idx: Union[_int, SymInt] = -1, scale_grad_by_freq: _bool = False, sparse: _bool = False) -> Tensor: ... +@overload +def embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool, mode: _int, sparse: _bool, per_sample_weights: Optional[Tensor], include_last_offset: _bool, padding_idx: Optional[_int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def embedding_renorm_(input: Tensor, indices: Tensor, max_norm: _float, norm_type: _float) -> Tensor: ... +@overload +def empty(size: Sequence[Union[_int, SymInt]], *, memory_format: Optional[memory_format] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. 
+ + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(*size: _int, memory_format: Optional[memory_format] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. 
+ Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +@overload +def empty(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + ... +def empty_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns an uninitialized tensor with the same size as :attr:`input`. + ``torch.empty_like(input)`` is equivalent to + ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) + """ + ... +def empty_permuted(size: Sequence[Union[_int, SymInt]], physical_layout: _size, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates an uninitialized, non-overlapping and dense tensor with the + specified :attr:`size`, with :attr:`physical_layout` specifying how the + dimensions are physically laid out in memory (each logical dimension is listed + from outermost to innermost). :attr:`physical_layout` is a generalization + of NCHW/NHWC notation: if each dimension is assigned a number according to + what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` + while NHWC is ``(0, 2, 3, 1)``. 
Equivalently, the strides of the output + tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` + (notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + + Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense + tensor with no overlaps. If possible, prefer using this function over + :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) + """ + ... +def empty_quantized(size: _size, qtensor: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def empty_strided(size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + + .. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + + .. 
note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) + """ + ... +@overload +def eq(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + ... +@overload +def eq(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + ... +def equal(input: Tensor, other: Tensor) -> _bool: + r""" + equal(input, other) -> bool + + ``True`` if two tensors have the same size and elements, ``False`` otherwise. + + Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True + """ + ... +def erf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erf(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erf`. + """ + ... 
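+# Hedged sketch contrasting torch.eq and torch.equal as documented above: eq
+# compares elementwise and returns a boolean tensor, while equal returns a single
+# Python bool that is True only when the shapes and all elements match.
+#
+#   >>> a = torch.tensor([1, 2, 3])
+#   >>> b = torch.tensor([1, 0, 3])
+#   >>> torch.eq(a, b)
+#   tensor([ True, False,  True])
+#   >>> torch.equal(a, b)
+#   False
+#   >>> torch.equal(a, a.clone())
+#   True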
+def erf_(input: Tensor) -> Tensor: ... +def erfc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erfc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfc`. + """ + ... +def erfc_(input: Tensor) -> Tensor: ... +def erfinv(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + erfinv(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfinv`. + """ + ... +def exp(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + exp(input, *, out=None) -> Tensor + + Returns a new tensor with the exponential of the elements + of the input tensor :attr:`input`. + + .. math:: + y_{i} = e^{x_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) + """ + ... +def exp2(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + exp2(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.exp2`. + """ + ... +def exp2_(input: Tensor) -> Tensor: ... +def exp_(input: Tensor) -> Tensor: ... +def expand_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], *, implicit: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.expand`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def expm1(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + expm1(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expm1`. + """ + ... +def expm1_(input: Tensor) -> Tensor: ... +@overload +def eye(n: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + ... 
+@overload +def eye(n: Union[_int, SymInt], m: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + ... +def fake_quantize_per_channel_affine(input: Tensor, scale: Tensor, zero_point: Tensor, axis: _int, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) + """ + ... 
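+# Hedged worked example of the fake-quantization formula documented for
+# torch.fake_quantize_per_channel_affine above, using a single channel with the
+# (display-rounded) scale 0.0475 from the docstring example: -0.2525 / 0.0475 is
+# about -5.3, which rounds to -5, is clamped into [0, 255] as 0, and dequantizes
+# to (0 - 0) * 0.0475 = 0.0. Negative inputs therefore saturate to zero when
+# zero_point is 0 on an unsigned quantization range.
+#
+#   >>> x = torch.tensor([[-0.2525]])
+#   >>> scale = torch.tensor([0.0475])
+#   >>> zero_point = torch.tensor([0], dtype=torch.int32)
+#   >>> torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, 0, 255)
+#   tensor([[0.]])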
+@overload +def fake_quantize_per_tensor_affine(input: Tensor, scale: _float, zero_point: _int, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + ... +@overload +def fake_quantize_per_tensor_affine(input: Tensor, scale: Tensor, zero_point: Tensor, quant_min: _int, quant_max: _int) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + ... +def fbgemm_linear_fp16_weight(input: Tensor, packed_weight: Tensor, bias: Tensor) -> Tensor: ... +def fbgemm_linear_fp16_weight_fp32_activation(input: Tensor, packed_weight: Tensor, bias: Tensor) -> Tensor: ... +def fbgemm_linear_int8_weight(input: Tensor, weight: Tensor, packed: Tensor, col_offsets: Tensor, weight_scale: Union[Number, _complex], weight_zero_point: Union[Number, _complex], bias: Tensor) -> Tensor: ... +def fbgemm_linear_int8_weight_fp32_activation(input: Tensor, weight: Tensor, packed: Tensor, col_offsets: Tensor, weight_scale: Union[Number, _complex], weight_zero_point: Union[Number, _complex], bias: Tensor) -> Tensor: ... +def fbgemm_linear_quantize_weight(input: Tensor) -> Tuple[Tensor, Tensor, _float, _int]: ... +def fbgemm_pack_gemm_matrix_fp16(input: Tensor) -> Tensor: ... 
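+# Hedged worked check of the per-tensor example above (scale = 0.1,
+# zero_point = 0, quant range [0, 255]): 0.9730 / 0.1 = 9.73 rounds to 10, stays
+# in range, and dequantizes to 10 * 0.1 = 1.0; -1.0780 / 0.1 = -10.78 rounds to
+# -11, clamps to 0, and dequantizes to 0.0. Both agree with the
+# tensor([0.1000, 1.0000, 0.4000, 0.0000]) output shown in the docstring.
+#
+#   >>> torch.fake_quantize_per_tensor_affine(torch.tensor([0.9730]), 0.1, 0, 0, 255)
+#   tensor([1.])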
+@overload +def fbgemm_pack_quantized_matrix(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor, K: _int, N: _int) -> Tensor: ... +def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +@overload +def fill(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill(input: Tensor, value: Union[Number, _complex]) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def fix(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fix(input, *, out=None) -> Tensor + + Alias for :func:`torch.trunc` + """ + ... +def fix_(input: Tensor) -> Tensor: ... +@overload +def flatten(input: Tensor, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, start_dim: _int, end_dim: _int, out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... 
[[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, start_dim: Union[str, ellipsis, None], end_dim: Union[str, ellipsis, None], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +@overload +def flatten(input: Tensor, dims: Sequence[Union[str, ellipsis, None]], out_dim: Union[str, ellipsis, None]) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + ... +def flip(input: Tensor, dims: _size) -> Tensor: + r""" + flip(input, dims) -> Tensor + + Reverse the order of an n-D tensor along given axis in dims. + + .. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + + Args: + input (Tensor): the input tensor. 
+ dims (a list or tuple): axis to flip on + + Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) + """ + ... +def fliplr(input: Tensor) -> Tensor: + r""" + fliplr(input) -> Tensor + + Flip tensor in the left/right direction, returning a new tensor. + + Flip the entries in each row in the left/right direction. + Columns are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 2-D. + + .. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + + Args: + input (Tensor): Must be at least 2-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) + """ + ... +def flipud(input: Tensor) -> Tensor: + r""" + flipud(input) -> Tensor + + Flip tensor in the up/down direction, returning a new tensor. + + Flip the entries in each column in the up/down direction. + Rows are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 1-D. + + .. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + + Args: + input (Tensor): Must be at least 1-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) + """ + ... +@overload +def float_power(input: Tensor, exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... 
+@overload +def float_power(self: Union[Number, _complex], exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... +@overload +def float_power(input: Tensor, exponent: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + ... +def floor(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + floor(input, *, out=None) -> Tensor + + Returns a new tensor with the floor of the elements of :attr:`input`, + the largest integer less than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) + """ + ... +def floor_(input: Tensor) -> Tensor: ... +def floor_divide(input: Union[Tensor, Number], other: Union[Tensor, Number], *, out: Optional[Tensor] = None) -> Tensor: + r""" + floor_divide(input, other, *, out=None) -> Tensor + + .. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + + Computes :attr:`input` divided by :attr:`other`, elementwise, and floors + the result. + + .. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + + + + Supports broadcasting to a common shape, type promotion, and integer and float inputs. + + Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) + """ + ... +def fmax(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmax(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.maximum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) + """ + ... +def fmin(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmin(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.minimum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) + """ + ... +@overload +def fmod(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. 
+ The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + ... +@overload +def fmod(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + ... +def frac(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + frac(input, *, out=None) -> Tensor + + Computes the fractional portion of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + + Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) + """ + ... +def frac_(input: Tensor) -> Tensor: ... 
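Editorial note, not part of the generated stub: the sign convention of `fmod` (result follows the dividend) is easiest to see next to `torch.remainder` (result follows the divisor); the values below are taken from the docstrings above.

import torch

a = torch.tensor([-3., -2., 2., 3.])
torch.fmod(a, 2)       # tensor([-1., -0., 0., 1.])  -> sign follows the dividend
torch.remainder(a, 2)  # tensor([ 1.,  0., 0., 1.])  -> sign follows the divisor
torch.frac(torch.tensor([1.0, 2.5, -3.2]))  # tensor([ 0.0000, 0.5000, -0.2000])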
+def frexp(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.frexp: + r""" + frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + + Decomposes :attr:`input` into mantissa and exponent tensors + such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + + The range of mantissa is the open interval (-1, 1). + + Supports float inputs. + + Args: + input (Tensor): the input tensor + + + Keyword args: + out (tuple, optional): the output tensors + + Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) + """ + ... +def frobenius_norm(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +def from_file(filename: str, shared: Optional[_bool] = None, size: Optional[_int] = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + + Creates a CPU tensor with a storage backed by a memory-mapped file. + + If ``shared`` is True, then memory is shared between processes. All changes are written to the file. + If ``shared`` is False, then changes to the tensor do not affect the file. + + ``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain + at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + + .. note:: + Only CPU tensors can be mapped to files. + + .. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. + + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """ + ... +def from_numpy(ndarray) -> Tensor: + r""" + from_numpy(ndarray) -> Tensor + + Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + + The returned tensor and :attr:`ndarray` share the same memory. Modifications to + the tensor will be reflected in the :attr:`ndarray` and vice versa. 
The returned + tensor is not resizable. + + It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, + ``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, + ``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, + and ``bool``. + + .. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + """ + ... +def frombuffer(buffer: Any, *, dtype: _dtype, count: int = -1, offset: int = 0, requires_grad: _bool = False) -> Tensor: + r""" + frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + + Creates a 1-dimensional :class:`Tensor` from an object that implements + the Python buffer protocol. + + Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of + the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` + elements. + + Note that either of the following must be true: + + 1. :attr:`count` is a positive non-zero number, and the total number of bytes + in the buffer is more than :attr:`offset` plus :attr:`count` times the size + (in bytes) of :attr:`dtype`. + + 2. :attr:`count` is negative, and the length (number of bytes) of the buffer + subtracted by the :attr:`offset` is a multiple of the size (in bytes) of + :attr:`dtype`. + + The returned tensor and buffer share the same memory. Modifications to + the tensor will be reflected in the buffer and vice versa. The returned + tensor is not resizable. + + .. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + + .. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + + .. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + + Args: + buffer (object): a Python object that exposes the buffer interface. + + Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) + """ + ... 
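Editorial note, not part of the generated stub: both `from_numpy` and `frombuffer` are zero-copy, so writes through the tensor are visible in the source object. A small sketch mirroring the docstring examples above:

import array

import numpy as np
import torch

a = np.array([1, 2, 3])
t = torch.from_numpy(a)        # shares memory with the ndarray
t[0] = -1
assert a[0] == -1              # the write is visible on the NumPy side

buf = array.array('i', [1, 2, 3])             # 'i' is 32-bit on common platforms
tb = torch.frombuffer(buf, dtype=torch.int32)  # also zero-copy, via the buffer protocol
tb[1] = 20
assert buf[1] == 20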
+@overload +def full(size: _size, fill_value: Union[Number, _complex], *, out: Optional[Tensor] = None, layout: _layout = strided, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... +@overload +def full(size: _size, fill_value: Union[Number, _complex], *, names: List[Union[str, None]], layout: _layout = strided, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... 
+@overload +def full(size: Sequence[Union[_int, SymInt]], fill_value: Union[Number, _complex], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... +@overload +def full(size: _size, fill_value: Union[Number, _complex], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + ... 
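Editorial note, not part of the generated stub: since `full` infers the output dtype from `fill_value`, the following sketch (assuming a recent PyTorch, where integer and boolean fills are supported without an explicit dtype) shows the inference in action:

import torch

torch.full((2, 3), 7).dtype     # torch.int64   (inferred from the Python int)
torch.full((2, 3), 7.0).dtype   # torch.float32 (the default floating-point dtype)
torch.full((2, 3), True).dtype  # torch.bool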
+def full_like(input: Tensor, fill_value: Union[Number, _complex], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. + ``torch.full_like(input, fill_value)`` is equivalent to + ``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + fill_value: the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +def fused_moving_avg_obs_fake_quant(input: Tensor, observer_on: Tensor, fake_quant_on: Tensor, running_min: Tensor, running_max: Tensor, scale: Tensor, zero_point: Tensor, averaging_const: _float, quant_min: _int, quant_max: _int, ch_axis: _int, per_row_fake_quant: _bool = False, symmetric_quant: _bool = False) -> Tensor: ... +@overload +def gather(input: Tensor, dim: _int, index: Tensor, *, sparse_grad: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + ... +@overload +def gather(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, *, sparse_grad: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. 
+ + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + ... +def gcd(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + gcd(input, other, *, out=None) -> Tensor + + Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`gcd(0, 0) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) + """ + ... +def gcd_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def ge(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + ... +@overload +def ge(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + ... 
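Editorial note, not part of the generated stub: the indexing rule `out[i][j] == input[i][index[i][j]]` for `gather` along dim 1 can be checked directly against the docstring example above.

import torch

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
out = torch.gather(t, 1, idx)
# along dim 1: out[i][j] == t[i][ idx[i][j] ]
assert out.tolist() == [[1, 1], [4, 3]]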
+def geqrf(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.geqrf: + r""" + geqrf(input, *, out=None) -> (Tensor, Tensor) + + This is a low-level function for calling LAPACK's geqrf directly. This function + returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + + Computes a QR decomposition of :attr:`input`. + Both `Q` and `R` matrices are stored in the same output tensor `a`. + The elements of `R` are stored on and above the diagonal. + Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` + are stored below the diagonal. + The results of this function can be used together with :func:`torch.linalg.householder_product` + to obtain the `Q` matrix or + with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, + for an efficient matrix-matrix multiplication. + + See `LAPACK documentation for geqrf`_ for further details. + + .. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + + Args: + input (Tensor): the input matrix + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + + .. _LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + """ + ... +def ger(input: Tensor, vec2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ger(input, vec2, *, out=None) -> Tensor + + Alias of :func:`torch.outer`. + + .. warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. + """ + ... +def get_default_dtype() -> _dtype: + r""" + get_default_dtype() -> torch.dtype + + Get the current default floating point :class:`torch.dtype`. + + Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + """ + ... +def get_num_interop_threads() -> _int: + r""" + get_num_interop_threads() -> int + + Returns the number of threads used for inter-op parallelism on CPU + (e.g. in JIT interpreter) + """ + ... +def get_num_threads() -> _int: + r""" + get_num_threads() -> int + + Returns the number of threads used for parallelizing CPU operations + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Optional[Union[Number, _complex]] = None, dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. 
+ + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. 
+ >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Sequence[Union[Number, _complex]], dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. 
Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. 
For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Sequence[Union[Number, _complex]], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Tuple[Tensor, ...], List[Tensor]], dim: Optional[_int] = None, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Number, _complex], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, spacing: Union[Tuple[Tensor, ...], List[Tensor]], dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def gradient(input: Tensor, *, dim: _size, edge_order: _int = 1) -> Tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + ... +@overload +def greater(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + ... +@overload +def greater(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + ... +@overload +def greater_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + ... +@overload +def greater_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + ... +def grid_sampler(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def grid_sampler_2d(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: _int, padding_mode: _int, align_corners: _bool) -> Tensor: ... +def group_norm(input: Tensor, num_groups: _int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: _float = 1e-05, cudnn_enabled: _bool = True) -> Tensor: ... +@overload +def gru(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def gru(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... 
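The gradient docstrings above derive the interior-point estimate from two Taylor expansions and quote the resulting non-uniform central-difference formula. The short sketch below is an illustrative check, not part of the generated stub: it assumes only that torch.gradient is available at runtime, and the helper f plus the sample coordinates are made up for the example. It compares the interior estimate returned by torch.gradient against the quoted formula for a cubic sampled on non-uniform coordinates (h_l = 1, h_r = 2 around x = 1).

    import torch

    def f(x):
        # Illustrative test function; any smooth function works here.
        return x ** 3

    # Non-uniform sample coordinates; around x = 1 we have h_l = 1 and h_r = 2.
    coords = torch.tensor([0.0, 1.0, 3.0, 4.0])
    values = f(coords)

    # torch.gradient returns a tuple with one tensor per differentiated dimension.
    (estimated,) = torch.gradient(values, spacing=(coords,))

    # The interior-point formula from the docstring, evaluated by hand at x = 1.
    x, h_l, h_r = 1.0, 1.0, 2.0
    manual = (h_l**2 * f(x + h_r) - h_r**2 * f(x - h_l)
              + (h_r**2 - h_l**2) * f(x)) / (h_r * h_l**2 + h_r**2 * h_l)

    print(estimated[1].item(), manual)  # both should print 5.0 for this sample

Both numbers agree because the interior points follow exactly the second-order scheme quoted above. The exact derivative at x = 1 is 3; the difference of 2 matches the h_l * h_r * f'''(x) / 6 remainder implied by the Taylor expansions for a cubic.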
+def gru_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... +@overload +def gt(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + ... +@overload +def gt(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + ... +@overload +def hamming_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 
Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hamming_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... 
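The hamming_window docstring above makes two concrete claims: the closed form w[n] = alpha - beta * cos(2 * pi * n / (N - 1)), and the identity torch.hamming_window(L, periodic=True) == torch.hamming_window(L + 1, periodic=False)[:-1]. The sketch below is an illustrative check, not part of the generated stub; the window length L = 8 and the reliance on the default coefficients alpha = 0.54, beta = 0.46 are arbitrary choices for the example.

    import math
    import torch

    L = 8
    periodic = torch.hamming_window(L, periodic=True)
    symmetric = torch.hamming_window(L + 1, periodic=False)

    # Periodic window == symmetric window of length L + 1 with its last sample dropped.
    print(torch.allclose(periodic, symmetric[:-1]))  # expected: True

    # Closed form for the symmetric window: N = L + 1, so N - 1 = L.
    n = torch.arange(L + 1, dtype=symmetric.dtype)
    manual = 0.54 - 0.46 * torch.cos(2 * math.pi * n / L)
    print(torch.allclose(symmetric, manual))  # expected: True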
+@overload +def hamming_window(window_length: _int, periodic: _bool, alpha: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hamming_window(window_length: _int, periodic: _bool, alpha: _float, beta: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. 
Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + ... +@overload +def hann_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +@overload +def hann_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + ... +def hardshrink(input: Tensor, lambd: Union[Number, _complex] = 0.5, *, out: Optional[Tensor] = None) -> Tensor: ... +def heaviside(input: Tensor, values: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + heaviside(input, values, *, out=None) -> Tensor + + Computes the Heaviside step function for each element in :attr:`input`. + The Heaviside step function is defined as: + + .. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} + + + Args: + input (Tensor): the input tensor. 
+ values (Tensor): The values to use where :attr:`input` is zero. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + """ + ... +def hinge_embedding_loss(input: Tensor, target: Tensor, margin: _float = 1.0, reduction: _int = 1) -> Tensor: ... +def histc(input: Tensor, bins: _int = 100, min: Union[Number, _complex] = 0, max: Union[Number, _complex] = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + + Computes the histogram of a tensor. + + The elements are sorted into equal width bins between :attr:`min` and + :attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and + maximum values of the data are used. + + Elements lower than min and higher than max and ``NaN`` elements are ignored. + + Args: + input (Tensor): the input tensor. + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: Histogram represented as a tensor + + Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) + """ + ... +@overload +def histogram(input: Tensor, bins: Tensor, *, weight: Optional[Tensor] = None, density: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. 
+ + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + ... +@overload +def histogram(input: Tensor, bins: _int = 100, *, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + ... +@overload +def histogramdd(input: Tensor, bins: _int, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. 
+ + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... 
weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +@overload +def histogramdd(input: Tensor, bins: _size, range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. 
+ If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +@overload +def histogramdd(input: Tensor, bins: Union[Tuple[Tensor, ...], List[Tensor]], range: Optional[Sequence[_float]] = None, weight: Optional[Tensor] = None, density: _bool = False) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. 
A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + ... +def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def hsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. 
+ + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + ... +@overload +def hsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + ... +def hspmm(mat1: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + hspmm(mat1, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of a :ref:`sparse COO matrix + ` :attr:`mat1` and a strided matrix :attr:`mat2`. The + result is a (1 + 1)-dimensional :ref:`hybrid COO matrix + `. + + Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + """ + ... +def hstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + hstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence horizontally (column wise). + + This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. 
+ + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + """ + ... +def hypot(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + hypot(input, other, *, out=None) -> Tensor + + Given the legs of a right triangle, return its hypotenuse. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + """ + ... +def i0(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + i0(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.i0`. + """ + ... +def i0_(input: Tensor) -> Tensor: ... +def igamma(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + igamma(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammainc`. + """ + ... +def igammac(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + igammac(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammaincc`. + """ + ... +def imag(input: Tensor) -> Tensor: + r""" + imag(input) -> Tensor + + Returns a new tensor containing imaginary values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + .. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + """ + ... +@overload +def index_add(input: Tensor, dim: _int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_add(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_copy(input: Tensor, dim: _int, index: Tensor, source: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_copy(input, dim, index, source, *, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_copy(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy(input, dim, index, source, *, out=None) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + ... +@overload +def index_fill(input: Tensor, dim: _int, index: Tensor, value: Tensor) -> Tensor: ... 
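+# A minimal worked sketch of torch.index_add, the out-of-place counterpart of
+# Tensor.index_add_ referenced by the stubs above; it follows the doctest style
+# of the surrounding docstrings, and the tensors below are illustrative values
+# assuming the default CPU float dtype.
+#
+#     >>> x = torch.ones(5, 3)
+#     >>> t = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
+#     >>> index = torch.tensor([0, 4, 2])
+#     >>> torch.index_add(x, 0, index, t)  # row i of t is added to row index[i] of x
+#     tensor([[ 2.,  3.,  4.],
+#             [ 1.,  1.,  1.],
+#             [ 8.,  9., 10.],
+#             [ 1.,  1.,  1.],
+#             [ 5.,  6.,  7.]])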
+@overload +def index_fill(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ... +@overload +def index_fill(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex]) -> Tensor: ... +@overload +def index_fill(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def index_put(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +def index_put_(input: Tensor, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool = False) -> Tensor: ... +def index_reduce(input: Tensor, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool = True, out: Optional[Tensor] = None) -> Tensor: + r""" + index_reduce(input, dim, index, source, reduce, *, include_self=True, out=None) -> Tensor + + See :meth:`~Tensor.index_reduce_` for function description. + """ + ... +@overload +def index_select(input: Tensor, dim: _int, index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + ... +@overload +def index_select(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + ... +def indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def init_num_threads() -> None: ... +def inner(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + inner(input, other, *, out=None) -> Tensor + + Computes the dot product for 1D tensors. For higher dimensions, sums the product + of elements from :attr:`input` and :attr:`other` along their last dimension. + + .. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + + Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + + Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. + + Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) + """ + ... +def instance_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], use_input_stats: _bool, momentum: _float, eps: _float, cudnn_enabled: _bool) -> Tensor: ... +def int_repr(input: Tensor) -> Tensor: ... +def inverse(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + inverse(input, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.inv` + """ + ... +def is_complex(input: Tensor) -> _bool: + r""" + is_complex(input) -> (bool) + + Returns True if the data type of :attr:`input` is a complex data type i.e., + one of ``torch.complex64``, and ``torch.complex128``. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_conj(input: Tensor) -> _bool: + r""" + is_conj(input) -> (bool) + + Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. 
+ + Args: + input (Tensor): the input tensor. + """ + ... +def is_distributed(input: Tensor) -> _bool: ... +def is_floating_point(input: Tensor) -> _bool: + r""" + is_floating_point(input) -> (bool) + + Returns True if the data type of :attr:`input` is a floating point data type i.e., + one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_grad_enabled() -> _bool: + r""" + is_grad_enabled() -> (bool) + + Returns True if grad mode is currently enabled. + """ + ... +def is_inference(input: Tensor) -> _bool: + r""" + is_inference(input) -> (bool) + + Returns True if :attr:`input` is an inference tensor. + + A non-view tensor is an inference tensor if and only if it was + allocated during inference mode. A view tensor is an inference + tensor if and only if the tensor it is a view of is an inference tensor. + + For details on inference mode please see + `Inference Mode `_. + + Args: + input (Tensor): the input tensor. + """ + ... +def is_inference_mode_enabled() -> _bool: + r""" + is_inference_mode_enabled() -> (bool) + + Returns True if inference mode is currently enabled. + """ + ... +def is_neg(input: Tensor) -> _bool: ... +def is_nonzero(input: Tensor) -> _bool: + r""" + is_nonzero(input) -> (bool) + + Returns True if the :attr:`input` is a single element tensor which is not equal to zero + after type conversions. + i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or + ``torch.tensor([False])``. + Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case + of sparse tensors). + + Args: + input (Tensor): the input tensor. + + Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous + """ + ... +def is_same_size(input: Tensor, other: Tensor) -> _bool: ... +def is_signed(input: Tensor) -> _bool: ... +def is_vulkan_available() -> _bool: ... +def isclose(input: Tensor, other: Tensor, rtol: _float = 1e-05, atol: _float = 1e-08, equal_nan: _bool = False) -> Tensor: + r""" + isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + Returns a new tensor with boolean elements representing if each element of + :attr:`input` is "close" to the corresponding element of :attr:`other`. + Closeness is defined as: + + .. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert + + + where :attr:`input` and :attr:`other` are finite. Where :attr:`input` + and/or :attr:`other` are nonfinite they are close if and only if + they are equal, with NaNs being considered equal to each other when + :attr:`equal_nan` is True. + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + + Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) + """ + ... +def isfinite(input: Tensor) -> Tensor: + r""" + isfinite(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element is `finite` or not. + + Real values are finite when they are not NaN, negative infinity, or infinity. + Complex values are finite when both their real and imaginary parts are finite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + + Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) + """ + ... +@overload +def isin(elements: Tensor, test_elements: Tensor, *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +@overload +def isin(element: Union[Number, _complex], test_elements: Tensor, *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. 
Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +@overload +def isin(elements: Tensor, test_element: Union[Number, _complex], *, assume_unique: _bool = False, invert: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + ... +def isinf(input: Tensor) -> Tensor: + r""" + isinf(input) -> Tensor + + Tests if each element of :attr:`input` is infinite + (positive or negative infinity) or not. + + .. note:: + Complex values are infinite when their real or imaginary part is + infinite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + + Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) + """ + ... +def isnan(input: Tensor) -> Tensor: + r""" + isnan(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` + is NaN or not. Complex values are considered NaN when either their real + and/or imaginary part is NaN. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + + Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) + """ + ... +def isneginf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + isneginf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is negative infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) + """ + ... +def isposinf(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + isposinf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is positive infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) + """ + ... +def isreal(input: Tensor) -> Tensor: + r""" + isreal(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. + All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + + Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) + """ + ... +def istft(input: Tensor, n_fft: _int, hop_length: Optional[_int] = None, win_length: Optional[_int] = None, window: Optional[Tensor] = None, center: _bool = True, normalized: _bool = False, onesided: Optional[_bool] = None, length: Optional[_int] = None, return_complex: _bool = False) -> Tensor: ... +@overload +def kaiser_window(window_length: _int, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... 
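+# A small sketch of the periodic/symmetric relationship stated in the
+# kaiser_window docstring above (a periodic window of length L equals the
+# symmetric window of length L + 1 with its last sample dropped); L and beta
+# below are illustrative choices using the default global dtype.
+#
+#     >>> L, beta = 8, 12.0
+#     >>> periodic = torch.kaiser_window(L, True, beta)
+#     >>> symmetric = torch.kaiser_window(L + 1, False, beta)
+#     >>> torch.allclose(periodic, symmetric[:-1])
+#     True
+#     >>> torch.kaiser_window(1)  # a length-one window is defined as a single one
+#     tensor([1.])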
+@overload +def kaiser_window(window_length: _int, periodic: _bool, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... +@overload +def kaiser_window(window_length: _int, periodic: _bool, beta: _float, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. 
+ The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + ... +def kl_div(input: Tensor, target: Tensor, reduction: _int = 1, *, log_target: _bool = False) -> Tensor: ... +def kron(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + kron(input, other, *, out=None) -> Tensor + + Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + + If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a + :math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a + :math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + + .. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + + where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. + If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + + Supports real-valued and complex-valued inputs. + + .. note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + + Arguments: + input (Tensor) + other (Tensor) + + Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) + """ + ... 
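+# A quick shape and block-structure check for torch.kron, following the
+# Kronecker-product formula in its docstring above; A and B are illustrative
+# random matrices (any shapes would do).
+#
+#     >>> A = torch.randn(2, 3)
+#     >>> B = torch.randn(4, 5)
+#     >>> torch.kron(A, B).shape  # (2*4, 3*5)
+#     torch.Size([8, 15])
+#     >>> torch.allclose(torch.kron(A, B)[:4, :5], A[0, 0] * B)  # top-left block is A[0, 0] * B
+#     True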
+@overload +def kthvalue(input: Tensor, k: _int, dim: _int = -1, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + ... +@overload +def kthvalue(input: Tensor, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) 
+ >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + ... +def layer_norm(input: Tensor, normalized_shape: Sequence[Union[_int, SymInt]], weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: _float = 1e-05, cudnn_enable: _bool = True) -> Tensor: ... +def lcm(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lcm(input, other, *, out=None) -> Tensor + + Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) + """ + ... +def lcm_(input: Tensor, other: Tensor) -> Tensor: ... +def ldexp(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ldexp(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by 2 ** :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i + + + Typically this function is used to construct floating point numbers by multiplying + mantissas in :attr:`input` with integral powers of two created from the exponents + in :attr:`other`. + + Args: + input (Tensor): the input tensor. + other (Tensor): a tensor of exponents, typically integers. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + """ + ... +def ldexp_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def le(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + ... +@overload +def le(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + ... +@overload +def lerp(input: Tensor, end: Tensor, weight: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + ... +@overload +def lerp(input: Tensor, end: Tensor, weight: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + ... +@overload +def less(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + ... +@overload +def less(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + ... +@overload +def less_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + ... 
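+# A short numeric check of the lerp formula documented above,
+# out_i = start_i + weight * (end_i - start_i), using illustrative values and a
+# scalar weight.
+#
+#     >>> start = torch.arange(1., 5.)
+#     >>> end = torch.full((4,), 10.)
+#     >>> w = 0.25
+#     >>> torch.allclose(torch.lerp(start, end, w), start + w * (end - start))
+#     True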
+@overload +def less_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + ... +def lgamma(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lgamma(input, *, out=None) -> Tensor + + Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + + .. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) + """ + ... +@overload +def linspace(start: Number, end: Number, steps: Optional[_int] = None, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... 
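+# A sketch relating torch.linspace to the closed-form spacing given in its
+# docstring above; the start/end/steps values are illustrative.
+#
+#     >>> steps = 5
+#     >>> x = torch.linspace(3, 10, steps)
+#     >>> step = (10 - 3) / (steps - 1)
+#     >>> torch.allclose(x, 3 + step * torch.arange(steps, dtype=x.dtype))
+#     True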
+@overload +def linspace(start: Tensor, end: Tensor, steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Union[Number, _complex], end: Tensor, steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. 
If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Tensor, end: Union[Number, _complex], steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +@overload +def linspace(start: Union[Number, _complex], end: Union[Number, _complex], steps: _int, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + ... +def log(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{e} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) + """ + ... +def log10(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log10(input, *, out=None) -> Tensor + + Returns a new tensor with the logarithm to the base 10 of the elements + of :attr:`input`. + + .. 
math:: + y_{i} = \log_{10} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + """ + ... +def log10_(input: Tensor) -> Tensor: ... +def log1p(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log1p(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + + .. math:: + y_i = \log_{e} (x_i + 1) + + .. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) + """ + ... +def log1p_(input: Tensor) -> Tensor: ... +def log2(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + log2(input, *, out=None) -> Tensor + + Returns a new tensor with the logarithm to the base 2 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{2} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + """ + ... +def log2_(input: Tensor) -> Tensor: ... +def log_(input: Tensor) -> Tensor: ... +@overload +def log_softmax(input: Tensor, dim: _int, dtype: Optional[_dtype] = None, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def log_softmax(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor: ... +def logaddexp(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logaddexp(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs. + + Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful + in statistics where the calculated probabilities of events may be so small as to + exceed the range of normal floating point numbers. In such cases the logarithm + of the calculated probability is stored. This function allows adding + probabilities stored in such a fashion. + + This op should be disambiguated with :func:`torch.logsumexp` which performs a + reduction on a single tensor. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) + """ + ... +def logaddexp2(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logaddexp2(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs in base-2. + + Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See + :func:`torch.logaddexp` for more details. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + """ + ... +@overload +def logcumsumexp(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{j=0}^{i} \exp(x_{ij}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + ... +@overload +def logcumsumexp(input: Tensor, dim: Union[str, ellipsis, None], *, out: Optional[Tensor] = None) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{j=0}^{i} \exp(x_{ij}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + ... +def logdet(input: Tensor) -> Tensor: + r""" + logdet(input) -> Tensor + + Calculates log determinant of a square matrix or batches of square matrices. + + It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has + a negative determinant. + + .. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + + .. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + + Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + + Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) + """ + ... +def logical_and(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_and(input, other, *, out=None) -> Tensor + + Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the tensor to compute AND with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) + """ + ... +def logical_not(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_not(input, *, out=None) -> Tensor + + Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool + dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) + """ + ... +def logical_or(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_or(input, other, *, out=None) -> Tensor + + Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute OR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) + """ + ... +def logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logical_xor(input, other, *, out=None) -> Tensor + + Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute XOR with + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) + """ + ... +def logit(input: Tensor, eps: Optional[_float] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logit(input, eps=None, *, out=None) -> Tensor + + Alias for :func:`torch.special.logit`. + """ + ... +def logit_(input: Tensor, eps: Optional[_float] = None) -> Tensor: ... +@overload +def logspace(start: Number, end: Number, steps: Optional[_int] = None, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Tensor, end: Tensor, steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... 
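+# Editor's note: the doctest-style sketch below is an illustrative addition and is
+# not part of the generated stub. It shows the relationship described by the
+# logspace overloads above, namely that logspace(start, end, steps, base) yields the
+# same points as base ** linspace(start, end, steps); printed values are rounded.
+#
+#   >>> import torch
+#   >>> a = torch.logspace(0, 3, steps=4)        # tensor([1., 10., 100., 1000.])
+#   >>> b = 10 ** torch.linspace(0, 3, steps=4)  # same points via linspace
+#   >>> torch.allclose(a, b)
+#   True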
+@overload +def logspace(start: Union[Number, _complex], end: Tensor, steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Tensor, end: Union[Number, _complex], steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. 
math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logspace(start: Union[Number, _complex], end: Union[Number, _complex], steps: _int, base: _float = 10.0, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. 
+ + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + ... +@overload +def logsumexp(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + ... +@overload +def logsumexp(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. 
+ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + ... +@overload +def lstm(data: Tensor, batch_sizes: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def lstm(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor, Tensor]: ... +def lstm_cell(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ... +@overload +def lt(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + ... +@overload +def lt(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + ... +def lu_solve(input: Tensor, LU_data: Tensor, LU_pivots: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + + Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted + LU factorization of A from :func:`~linalg.lu_factor`. + + This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + + .. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. 
code:: python + + X = linalg.lu_solve(LU, pivots, B) + + Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) + """ + ... +def lu_unpack(LU_data: Tensor, LU_pivots: Tensor, unpack_data: _bool = True, unpack_pivots: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.lu_unpack: + r""" + lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + + .. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. + + Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + + Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + + Returns: + A namedtuple ``(P, L, U)`` + + Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + """ + ... +def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor, margin: _float = 0.0, reduction: _int = 1) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Tensor) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Union[Number, _complex]) -> Tensor: ... +def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor: ... +def masked_select(input: Tensor, mask: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + masked_select(input, mask, *, out=None) -> Tensor + + Returns a new 1-D tensor which indexes the :attr:`input` tensor according to + the boolean mask :attr:`mask` which is a `BoolTensor`. 
+ + The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need + to match, but they must be :ref:`broadcastable `. + + .. note:: The returned tensor does **not** use the same storage + as the original tensor + + Args: + input (Tensor): the input tensor. + mask (BoolTensor): the tensor containing the binary mask to index with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) + """ + ... +def matmul(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + matmul(input, other, *, out=None) -> Tensor + + Matrix product of two tensors. + + The behavior depends on the dimensionality of the tensors as follows: + + - If both tensors are 1-dimensional, the dot product (scalar) is returned. + - If both arguments are 2-dimensional, the matrix-matrix product is returned. + - If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. + - If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. + - If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + + This operation has support for arguments with :ref:`sparse layouts`. In particular the + matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions + as :func:`torch.mm` + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. 
+ + Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + """ + ... +def matrix_exp(input: Tensor) -> Tensor: + r""" + matrix_exp(A) -> Tensor + + Alias for :func:`torch.linalg.matrix_exp`. + """ + ... +def matrix_power(input: Tensor, n: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + matrix_power(input, n, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.matrix_power` + """ + ... +@overload +def max(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. 
warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.max: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +@overload +def max(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.max: + r""" + max(input) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + ... +def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def max_pool1d_with_indices(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tuple[Tensor, Tensor]: ... +def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... 
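+# Editor's note: an illustrative doctest-style sketch, not part of the generated
+# stub, for the dim-reducing torch.max overload above. A deterministic input is
+# used so the printed values hold; ``indices`` is the per-row argmax.
+#
+#   >>> import torch
+#   >>> x = torch.arange(6.).reshape(2, 3)       # [[0., 1., 2.], [3., 4., 5.]]
+#   >>> values, indices = torch.max(x, dim=1)    # returns a (values, indices) namedtuple
+#   >>> values, indices
+#   (tensor([2., 5.]), tensor([2, 2]))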
+def maximum(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + maximum(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) + """ + ... +@overload +def mean(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def mean(input: Tensor, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + ... +@overload +def median(input: Tensor) -> Tensor: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... 
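+# Illustrative sketch (not part of the generated stub): how `dim`/`keepdim`
+# interact for torch.mean, and the "lower median" behaviour noted for
+# torch.median on an even number of elements. The tensor values are
+# assumptions chosen only for this example.
+#
+#     >>> import torch
+#     >>> x = torch.arange(6.).reshape(2, 3)          # tensor([[0., 1., 2.], [3., 4., 5.]])
+#     >>> torch.mean(x, dim=1)                        # one value per row
+#     tensor([1., 4.])
+#     >>> torch.mean(x, dim=1, keepdim=True).shape    # reduced dim kept with size 1
+#     torch.Size([2, 1])
+#     >>> torch.median(torch.tensor([1., 2., 3., 4.]))          # lower of the two middle values
+#     tensor(2.)
+#     >>> torch.quantile(torch.tensor([1., 2., 3., 4.]), 0.5)   # interpolated middle instead
+#     tensor(2.5000)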
+@overload +def median(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... +@overload +def median(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. 
warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + ... +@overload +def min(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. 
note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.min: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. 
function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... +@overload +def min(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.min: + r""" + min(input) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + ... 
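+# Illustrative sketch (not part of the generated stub): the dim-reduction form
+# of torch.min returns a (values, indices) namedtuple that can be unpacked or
+# accessed by field, while the two-tensor form behaves like torch.minimum
+# documented next. Tensor values are assumptions chosen only for this example.
+#
+#     >>> import torch
+#     >>> x = torch.tensor([[1., 5.], [4., 2.]])
+#     >>> values, indices = torch.min(x, dim=1)       # namedtuple unpacking
+#     >>> values
+#     tensor([1., 2.])
+#     >>> indices
+#     tensor([0, 1])
+#     >>> torch.min(x, dim=1).values                  # or access the fields directly
+#     tensor([1., 2.])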
+def minimum(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + minimum(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) + """ + ... +def miopen_batch_norm(input: Tensor, weight: Tensor, bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, exponential_average_factor: _float, epsilon: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def miopen_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_convolution_add_relu(input: Tensor, weight: Tensor, z: Tensor, alpha: Optional[Union[Number, _complex]], bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def miopen_convolution_relu(input: Tensor, weight: Tensor, bias: Optional[Tensor], stride: Sequence[Union[_int, SymInt]], padding: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def miopen_convolution_transpose(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], output_padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_depthwise_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt], benchmark: _bool, deterministic: _bool) -> Tensor: ... +def miopen_rnn(input: Tensor, weight: Union[Tuple[Tensor, ...], List[Tensor]], weight_stride0: _int, hx: Tensor, cx: Optional[Tensor], mode: _int, hidden_size: _int, num_layers: _int, batch_first: _bool, dropout: _float, train: _bool, bidirectional: _bool, batch_sizes: _size, dropout_state: Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def mkldnn_adaptive_avg_pool2d(input: Tensor, output_size: Union[_int, _size], *, out: Optional[Tensor] = None) -> Tensor: ... +def mkldnn_convolution(input: Tensor, weight: Tensor, bias: Optional[Tensor], padding: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], dilation: Sequence[Union[_int, SymInt]], groups: Union[_int, SymInt]) -> Tensor: ... +def mkldnn_linear_backward_weights(grad_output: Tensor, input: Tensor, weight: Tensor, bias_defined: _bool) -> Tuple[Tensor, Tensor]: ... +def mkldnn_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... 
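+# Illustrative sketch (not part of the generated stub): the NaN-propagation
+# note for torch.minimum / torch.maximum above. torch.fmin / torch.fmax are
+# mentioned as the NaN-ignoring counterparts; they are not documented in this
+# excerpt, so treat that comparison as an assumption to verify.
+#
+#     >>> import torch
+#     >>> a = torch.tensor([1., float('nan')])
+#     >>> b = torch.tensor([2., 0.])
+#     >>> torch.minimum(a, b)      # NaN wins the comparison
+#     tensor([1., nan])
+#     >>> torch.fmin(a, b)         # NaN-ignoring variant
+#     tensor([1., 0.])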
+def mkldnn_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def mkldnn_rnn_layer(input: Tensor, weight0: Tensor, weight1: Tensor, weight2: Tensor, weight3: Tensor, hx_: Tensor, cx_: Tensor, reverse: _bool, batch_sizes: _size, mode: _int, hidden_size: _int, num_layers: _int, has_biases: _bool, bidirectional: _bool, batch_first: _bool, train: _bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... +def mm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mm(input, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided it's layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + ... +@overload +def mode(input: Tensor, dim: _int = -1, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor( + [[0, 0, 0, 2, 0, 0, 2], + [0, 3, 0, 0, 2, 0, 1], + [2, 2, 2, 0, 0, 0, 3], + [2, 2, 3, 0, 1, 1, 0], + [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + ... +@overload +def mode(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor( + [[0, 0, 0, 2, 0, 0, 2], + [0, 3, 0, 0, 2, 0, 1], + [2, 2, 2, 0, 0, 0, 3], + [2, 2, 3, 0, 1, 1, 0], + [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + ... +@overload +def moveaxis(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +@overload +def moveaxis(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... 
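+# Illustrative sketch (not part of the generated stub): torch.moveaxis is an
+# alias for torch.movedim (documented next); both only reorder dimensions, so
+# checking shapes is usually enough. The shapes here are assumptions chosen
+# only for this example.
+#
+#     >>> import torch
+#     >>> t = torch.zeros(2, 3, 4)
+#     >>> torch.moveaxis(t, 0, -1).shape              # move dim 0 to the end
+#     torch.Size([3, 4, 2])
+#     >>> torch.movedim(t, (0, 1), (1, 2)).shape      # tuple form moves several dims at once
+#     torch.Size([4, 2, 3])
+#     >>> t.permute(1, 2, 0).shape                    # explicit permutation equivalent to the first call
+#     torch.Size([3, 4, 2])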
+@overload +def movedim(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +@overload +def movedim(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + ... +def msort(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + msort(input, *, out=None) -> Tensor + + Sorts the elements of the :attr:`input` tensor along its first dimension + in ascending order by value. + + .. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) + """ + ... +def mul(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + mul(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by :attr:`other`. + + + .. 
math:: + \text{out}_i = \text{input}_i \times \text{other}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number) - the tensor or number to multiply input by. + + Keyword args: + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) + """ + ... +def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + + Returns a tensor where each row contains :attr:`num_samples` indices sampled + from the multinomial (a stricter definition would be multivariate, + refer to torch.distributions.multinomial.Multinomial for more details) + probability distribution located in the corresponding row + of tensor :attr:`input`. + + .. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + + Indices are ordered from left to right according to when each was sampled + (first samples are placed in first column). + + If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + + If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape + :math:`(m \times \text{num\_samples})`. + + If replacement is ``True``, samples are drawn with replacement. + + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + + .. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + + Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 4) # ERROR! + RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, + not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320 + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) + """ + ... +@overload +def multiply(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + ... 
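+# Illustrative sketch (not part of the generated stub): seeding
+# torch.multinomial with an explicit torch.Generator for reproducible draws.
+# The weights and sample count are assumptions chosen only for this example,
+# and the sampled indices depend on the generator state, so no concrete
+# output is shown.
+#
+#     >>> import torch
+#     >>> weights = torch.tensor([1., 3.])            # need not sum to 1
+#     >>> g = torch.Generator().manual_seed(0)
+#     >>> torch.multinomial(weights, 5, replacement=True, generator=g)
+#     # -> LongTensor of 5 indices in {0, 1}, drawn from category 1 roughly
+#     #    three times as often as category 0; rerunning with the same seed
+#     #    reproduces the same draw.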
+@overload +def multiply(input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + ... +def mv(input: Tensor, vec: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mv(input, vec, *, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`input` and the vector + :attr:`vec`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) + """ + ... +def mvlgamma(input: Tensor, p: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + mvlgamma(input, p, *, out=None) -> Tensor + + Alias for :func:`torch.special.multigammaln`. + """ + ... +def nan_to_num(input: Tensor, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + + Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` + with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. + By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the + greatest finite value representable by :attr:`input`'s dtype, and negative infinity + is replaced with the least finite value representable by :attr:`input`'s dtype. + + Args: + input (Tensor): the input tensor. + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + """ + ... +def nan_to_num_(input: Tensor, nan: Optional[_float] = None, posinf: Optional[_float] = None, neginf: Optional[_float] = None) -> Tensor: ... +def nanmean(input: Tensor, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes the mean of all `non-NaN` elements along the specified dimensions. + + This function is identical to :func:`torch.mean` when there are no `NaN` values + in the :attr:`input` tensor. 
In the presence of `NaN`, :func:`torch.mean` will + propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the + `NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + + Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) + """ + ... +@overload +def nanmedian(input: Tensor) -> Tensor: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
+ + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanmedian(input: Tensor, dim: _int, keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanmedian(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. 
+ + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + ... +@overload +def nanquantile(input: Tensor, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + ... 
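+# Illustrative sketch (not part of the generated stub): how the NaN-aware
+# reductions in this group (torch.nanmean, torch.nanquantile, and the
+# torch.nansum documented just below) relate to their plain counterparts on a
+# tensor containing a NaN. The values are assumptions chosen for this example.
+#
+#     >>> import torch
+#     >>> x = torch.tensor([1., float('nan'), 3.])
+#     >>> torch.sum(x), torch.nansum(x)               # NaN propagates vs. treated as zero
+#     (tensor(nan), tensor(4.))
+#     >>> torch.mean(x), torch.nanmean(x)             # NaN propagates vs. ignored
+#     (tensor(nan), tensor(2.))
+#     >>> torch.quantile(x, 0.5), torch.nanquantile(x, 0.5)
+#     (tensor(nan), tensor(2.))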
+@overload +def nanquantile(input: Tensor, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + ... +def nansum(input: Tensor, dim: Optional[Union[_int, _size]] = None, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + nansum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + + .. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. + If :attr:`dim` is a list of dimensions, reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + 1.0 + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) 
+ >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) + """ + ... +@overload +def narrow(input: Tensor, dim: _int, start: Tensor, length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + ... +@overload +def narrow(input: Tensor, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + ... +def narrow_copy(input: Tensor, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: + r""" + narrow_copy(input, dim, start, length, *, out=None) -> Tensor + + Same as :meth:`Tensor.narrow` except this returns a copy rather + than shared storage. This is primarily for sparse tensors, which + do not have a shared-storage narrow method. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + + .. seealso:: + + :func:`torch.narrow` for a non copy variant + """ + ... +def native_batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], training: _bool, momentum: _float, eps: _float, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> Tuple[Tensor, Tensor, Tensor]: ... +def native_channel_shuffle(input: Tensor, groups: Union[_int, SymInt]) -> Tensor: ... +def native_dropout(input: Tensor, p: _float, train: Optional[_bool]) -> Tuple[Tensor, Tensor]: ... +def native_group_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], N: Union[_int, SymInt], C: Union[_int, SymInt], HxW: Union[_int, SymInt], group: _int, eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +def native_layer_norm(input: Tensor, normalized_shape: Sequence[Union[_int, SymInt]], weight: Optional[Tensor], bias: Optional[Tensor], eps: _float) -> Tuple[Tensor, Tensor, Tensor]: ... +@overload +def native_norm(input: Tensor, p: Optional[Union[Number, _complex]], dim: Union[_int, _size], keepdim: _bool, dtype: Optional[_dtype]) -> Tensor: ... +@overload +def native_norm(input: Tensor, p: Union[Number, _complex] = 2) -> Tensor: ... +@overload +def ne(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + ... +@overload +def ne(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + ... +def neg(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + neg(input, *, out=None) -> Tensor + + Returns a new tensor with the negative of the elements of :attr:`input`. + + .. math:: + \text{out} = -1 \times \text{input} + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) + """ + ... +def neg_(input: Tensor) -> Tensor: ... +def negative(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + negative(input, *, out=None) -> Tensor + + Alias for :func:`torch.neg` + """ + ... +def negative_(input: Tensor) -> Tensor: ... +def nextafter(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + nextafter(input, other, *, out=None) -> Tensor + + Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + """ + ... +@overload +def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. 
+ + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + ... +@overload +def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... 
[0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + ... +def nonzero_static(input: Tensor, *, size: _int, fill_value: _int = -1, out: Optional[Tensor] = None) -> Tensor: ... +def norm_except_dim(v: Tensor, pow: _int = 2, dim: _int = 0) -> Tensor: ... +@overload +def normal(mean: Tensor, std: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... 
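+# A minimal usage sketch for the tensor/tensor overload above: per-element
+# sampling where the output takes the shape of ``mean``, seeded through a
+# ``torch.Generator`` for reproducibility.
+# >>> mean = torch.zeros(2, 3)
+# >>> std = torch.full((2, 3), 0.1)
+# >>> g = torch.Generator().manual_seed(0)
+# >>> torch.normal(mean, std, generator=g).shape
+# torch.Size([2, 3])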
+@overload +def normal(mean: Tensor, std: _float = 1, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def normal(mean: _float, std: Tensor, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. 
+ + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def normal(mean: _float, std: _float, size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator] = None, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. 
note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + ... +@overload +def not_equal(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + ... +@overload +def not_equal(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + ... +@overload +def nuclear_norm(input: Tensor, dim: Union[_int, _size], keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +@overload +def nuclear_norm(input: Tensor, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: ... +def numel(self: Tensor) -> _int: + r""" + numel(input) -> int + + Returns the total number of elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + """ + ... 
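+# A minimal sketch of :func:`torch.numel`: the count equals the product of the
+# sizes, and a 0-dim tensor still holds exactly one element.
+# >>> t = torch.zeros(3, 4, 5)
+# >>> torch.numel(t) == 3 * 4 * 5
+# True
+# >>> torch.numel(torch.tensor(7.))
+# 1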
+@overload +def ones(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... +@overload +def ones(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... 
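+# A minimal sketch of the two overloads above: ``size`` may be given either as
+# separate integer arguments or as a single sequence, with identical results.
+# >>> torch.ones(2, 3).shape == torch.ones((2, 3)).shape == torch.ones([2, 3]).shape
+# True
+# >>> torch.ones(4, dtype=torch.int64)
+# tensor([1, 1, 1, 1])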
+@overload +def ones(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... +@overload +def ones(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + ... 
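+# A minimal sketch of the ``names`` overloads above (named tensors are a
+# prototype feature, so exact behaviour may change between releases).
+# >>> torch.ones(2, 3, names=('N', 'C')).names
+# ('N', 'C')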
+def ones_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the same size as + :attr:`input`. ``torch.ones_like(input)`` is equivalent to + ``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + ... +def orgqr(input: Tensor, input2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + orgqr(input, tau) -> Tensor + + Alias for :func:`torch.linalg.householder_product`. + """ + ... +def ormqr(input: Tensor, input2: Tensor, input3: Tensor, left: _bool = True, transpose: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + + Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + + Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, + where `Q` is represented using Householder reflectors `(input, tau)`. + See `Representation of Orthogonal or Unitary Matrices`_ for further details. + + If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`. + When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`. + It has size :math:`n \times n` otherwise. + If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op. + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + + .. seealso:: + :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q` + from the QR decomposition. + + .. note:: + This function supports backward but it is only fast when ``(input, tau)`` do not require gradients + and/or ``tau.size(-1)`` is very small. 
+ `` + + Args: + input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions + and `mn` equals to `m` or `n` depending on the :attr:`left`. + tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions. + other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + left (bool): controls the order of multiplication. + transpose (bool): controls whether the matrix `Q` is conjugate transposed or not. + + Keyword args: + out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`. + + .. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html + """ + ... +def outer(input: Tensor, vec2: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + outer(input, vec2, *, out=None) -> Tensor + + Outer product of :attr:`input` and :attr:`vec2`. + If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of + size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + + Keyword args: + out (Tensor, optional): optional output matrix + + Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) + """ + ... +def pairwise_distance(x1: Tensor, x2: Tensor, p: _float = 2, eps: _float = 1e-06, keepdim: _bool = False) -> Tensor: ... +def pdist(input: Tensor, p: _float = 2) -> Tensor: ... +def permute(input: Tensor, dims: _size) -> Tensor: + r""" + permute(input, dims) -> Tensor + + Returns a view of the original tensor :attr:`input` with its dimensions permuted. + + Args: + input (Tensor): the input tensor. + dims (tuple of int): The desired ordering of dimensions + + Example: + >>> x = torch.randn(2, 3, 5) + >>> x.size() + torch.Size([2, 3, 5]) + >>> torch.permute(x, (2, 0, 1)).size() + torch.Size([5, 2, 3]) + """ + ... +def permute_copy(input: Tensor, dims: _size, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.permute`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def pinverse(input: Tensor, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse(input, rcond=1e-15) -> Tensor + + Alias for :func:`torch.linalg.pinv` + """ + ... +def pixel_shuffle(input: Tensor, upscale_factor: _int) -> Tensor: ... +def pixel_unshuffle(input: Tensor, downscale_factor: _int) -> Tensor: ... +def poisson(input: Tensor, generator: Optional[Generator] = None) -> Tensor: + r""" + poisson(input, generator=None) -> Tensor + + Returns a tensor of the same size as :attr:`input` with each element + sampled from a Poisson distribution with rate parameter given by the corresponding + element in :attr:`input` i.e., + + .. math:: + \text{out}_i \sim \text{Poisson}(\text{input}_i) + + :attr:`input` must be non-negative. + + Args: + input (Tensor): the input tensor containing the rates of the Poisson distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + + Example:: + + >>> rates = torch.rand(4, 4) * 5 # rate parameter between 0 and 5 + >>> torch.poisson(rates) + tensor([[9., 1., 3., 5.], + [8., 6., 6., 0.], + [0., 4., 5., 3.], + [2., 1., 4., 2.]]) + """ + ... 
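+# A minimal sketch of :func:`torch.poisson`: counts are sampled per element and
+# returned in the floating dtype of the rate tensor; a seeded generator makes
+# the draw reproducible.
+# >>> rates = torch.full((2, 2), 4.0)
+# >>> torch.poisson(rates, generator=torch.Generator().manual_seed(0)).dtype
+# torch.float32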
+def poisson_nll_loss(input: Tensor, target: Tensor, log_input: _bool, full: _bool, eps: _float, reduction: _int) -> Tensor: ... +def polar(abs: Tensor, angle: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + polar(abs, angle, *, out=None) -> Tensor + + Constructs a complex tensor whose elements are Cartesian coordinates + corresponding to the polar coordinates with absolute value :attr:`abs` and angle + :attr:`angle`. + + .. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + + .. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + + + Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) + """ + ... +def polygamma(n: _int, input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + polygamma(n, input, *, out=None) -> Tensor + + Alias for :func:`torch.special.polygamma`. + """ + ... +def positive(input: Tensor) -> Tensor: + r""" + positive(input) -> Tensor + + Returns :attr:`input`. + Throws a runtime error if :attr:`input` is a bool tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.randn(5) + >>> t + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.positive(t) + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + """ + ... +@overload +def pow(input: Tensor, exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. 
+ The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +@overload +def pow(self: Union[Number, _complex], exponent: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +@overload +def pow(input: Tensor, exponent: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) 
+ + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + ... +def prelu(input: Tensor, weight: Tensor) -> Tensor: ... +@overload +def prod(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +@overload +def prod(input: Tensor, dim: _int, keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. 
+ + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +@overload +def prod(input: Tensor, dim: Union[str, ellipsis, None], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + prod(input, *, dtype=None) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + ... +def promote_types(type1: _dtype, type2: _dtype) -> _dtype: + r""" + promote_types(type1, type2) -> dtype + + Returns the :class:`torch.dtype` with the smallest size and scalar kind that is + not smaller nor of lower kind than either `type1` or `type2`. See type promotion + :ref:`documentation ` for more information on the type + promotion logic. + + Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + + Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long + """ + ... +def put(input: Tensor, index: Tensor, source: Tensor, accumulate: _bool = False) -> Tensor: ... +def q_per_channel_axis(input: Tensor) -> _int: ... 
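+# A minimal sketch of :func:`torch.q_per_channel_axis`, assuming the function
+# variant declared above mirrors the ``Tensor.q_per_channel_axis`` method: it
+# reports which dimension was quantized per channel.
+# >>> w = torch.randn(2, 3)
+# >>> wq = torch.quantize_per_channel(w, torch.tensor([0.1, 0.05]), torch.tensor([0, 0]), 0, torch.qint8)
+# >>> torch.q_per_channel_axis(wq)
+# 0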
+def q_per_channel_scales(input: Tensor) -> Tensor: ... +def q_per_channel_zero_points(input: Tensor) -> Tensor: ... +def q_scale(input: Tensor) -> _float: ... +def q_zero_point(input: Tensor) -> _int: ... +def qr(input: Tensor, some: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.qr: + r""" + qr(input, some=True, *, out=None) -> (Tensor, Tensor) + + Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, + and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` + with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and + :math:`R` being an upper triangular matrix or batch of upper triangular matrices. + + If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. + Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + + .. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + + .. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + + .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + + Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + + Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. + + Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True + """ + ... 
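+# A minimal migration sketch for the deprecation note above: the reduced
+# factorization returned by ``torch.qr`` corresponds to the default
+# ``mode='reduced'`` of :func:`torch.linalg.qr`.
+# >>> A = torch.tensor([[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]])
+# >>> Q, R = torch.linalg.qr(A)
+# >>> torch.allclose(Q @ R, A)
+# True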
+@overload +def quantile(input: Tensor, q: Tensor, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + ... +@overload +def quantile(input: Tensor, q: _float, dim: Optional[_int] = None, keepdim: _bool = False, *, interpolation: str = "linear", out: Optional[Tensor] = None) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. 
If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + dim (int): the dimension to reduce. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + ... +def quantize_per_channel(input: Tensor, scales: Tensor, zero_points: Tensor, axis: _int, dtype: _dtype) -> Tensor: + r""" + quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + + Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + + Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor + + Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) + """ + ... +@overload +def quantize_per_tensor(input: Tensor, scale: Tensor, zero_point: Tensor, dtype: _dtype) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +@overload +def quantize_per_tensor(input: Tensor, scale: _float, zero_point: _int, dtype: _dtype) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
+ + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +@overload +def quantize_per_tensor(tensors: Union[Tuple[Tensor, ...], List[Tensor]], scales: Tensor, zero_points: Tensor, dtype: _dtype) -> Tuple[Tensor, ...]: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + ... +def quantize_per_tensor_dynamic(input: Tensor, dtype: _dtype, reduce_range: _bool) -> Tensor: + r""" + quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + + Converts a float tensor to a quantized tensor with scale and zero_point calculated + dynamically based on the input. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``
+ reduce_range (bool): a flag to indicate whether to reduce the range of quantized
+ data by 1 bit; this is required to avoid instruction overflow on some hardware
+
+ Returns:
+ Tensor: A newly (dynamically) quantized tensor
+
+ Example::
+
+ >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False)
+ >>> print(t)
+ tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8,
+ quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941,
+ zero_point=85)
+ >>> t.int_repr()
+ tensor([ 0, 85, 170, 255], dtype=torch.uint8)
+ """
+ ...
+def quantized_batch_norm(input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], mean: Tensor, var: Tensor, eps: _float, output_scale: _float, output_zero_point: _int) -> Tensor:
+ r"""
+ quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor
+
+ Applies batch normalization on a 4D (NCHW) quantized tensor.
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ Arguments:
+ input (Tensor): quantized tensor
+ weight (Tensor): float tensor that corresponds to the gamma, size C
+ bias (Tensor): float tensor that corresponds to the beta, size C
+ mean (Tensor): float mean value in batch normalization, size C
+ var (Tensor): float tensor for variance, size C
+ eps (float): a value added to the denominator for numerical stability.
+ output_scale (float): output quantized tensor scale
+ output_zero_point (int): output quantized tensor zero_point
+
+ Returns:
+ Tensor: A quantized tensor with batch normalization applied.
+
+ Example::
+
+ >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8)
+ >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2)
+ tensor([[[[-0.2000, -0.2000],
+ [ 1.6000, -0.2000]],
+
+ [[-0.4000, -0.4000],
+ [-0.4000, 0.6000]]],
+
+
+ [[[-0.2000, -0.2000],
+ [-0.2000, -0.2000]],
+
+ [[ 0.6000, -0.4000],
+ [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8,
+ quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2)
+ """
+ ...
+def quantized_gru_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ...
+def quantized_lstm_cell(input: Tensor, hx: Union[Tuple[Tensor, ...], List[Tensor]], w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tuple[Tensor, Tensor]: ...
+def quantized_max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor:
+ r"""
+ quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor
+
+ Applies a 1D max pooling over an input quantized tensor composed of several input planes.
+ + Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool1d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + ... +def quantized_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: + r""" + quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 2D max pooling over an input quantized tensor composed of several input planes. + + Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool2d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + ... +def quantized_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Union[_int, _size] = (), padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: _bool = False) -> Tensor: ... +def quantized_rnn_relu_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ... +def quantized_rnn_tanh_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Tensor, b_hh: Tensor, packed_ih: Tensor, packed_hh: Tensor, col_offsets_ih: Tensor, col_offsets_hh: Tensor, scale_ih: Union[Number, _complex], scale_hh: Union[Number, _complex], zero_point_ih: Union[Number, _complex], zero_point_hh: Union[Number, _complex]) -> Tensor: ... +def rad2deg(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + rad2deg(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in radians to degrees. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + """ + ... +def rad2deg_(input: Tensor) -> Tensor: ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. 
+ Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. 
+ Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... 
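+# Usage sketch for the ``torch.rand`` overloads above (illustrative only; variable names are
+# placeholders, and the calls assume the documented keyword arguments on a CPU tensor):
+#
+#   g = torch.Generator().manual_seed(0)  # dedicated PRNG state for reproducible sampling
+#   x = torch.rand(2, 3, generator=g)     # uniform samples drawn from [0, 1)
+#   buf = torch.empty(2, 3)
+#   torch.rand(2, 3, out=buf)             # writes the samples into a preallocated tensor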
+@overload +def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +@overload +def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + ... +def rand_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a uniform distribution on the interval :math:`[0, 1)`. + ``torch.rand_like(input)`` is equivalent to + ``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +@overload +def randint(low: _int, high: _int, size: _size, *, generator: Optional[Generator] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(high: _int, size: _size, *, generator: Optional[Generator] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. 
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... 
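+# Usage sketch for the ``torch.randint`` overloads above (illustrative only; variable names
+# are placeholders, and values are drawn uniformly from the half-open range [low, high)):
+#
+#   g = torch.Generator().manual_seed(0)
+#   idx = torch.randint(0, 10, (4,), generator=g)           # dtype defaults to torch.int64
+#   small = torch.randint(0, 10, (4,), dtype=torch.int32)   # request a narrower integer dtype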
+@overload +def randint(high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(low: Union[_int, SymInt], high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... +@overload +def randint(low: Union[_int, SymInt], high: Union[_int, SymInt], size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + ... 
+@overload
+def randint_like(input: Tensor, high: Union[_int, SymInt], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor:
+ r"""
+ randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor
+
+ Returns a tensor with the same shape as Tensor :attr:`input` filled with
+ random integers generated uniformly between :attr:`low` (inclusive) and
+ :attr:`high` (exclusive).
+
+ .. note::
+ With the global dtype default (``torch.float32``), this function returns
+ a tensor with dtype ``torch.int64``.
+
+ Args:
+ input (Tensor): the size of :attr:`input` will determine size of the output tensor.
+ low (int, optional): Lowest integer to be drawn from the distribution. Default: 0.
+ high (int): One above the highest integer to be drawn from the distribution.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
+ Default: if ``None``, defaults to the dtype of :attr:`input`.
+ layout (:class:`torch.layout`, optional): the desired layout of returned tensor.
+ Default: if ``None``, defaults to the layout of :attr:`input`.
+ device (:class:`torch.device`, optional): the desired device of returned tensor.
+ Default: if ``None``, defaults to the device of :attr:`input`.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+ returned Tensor. Default: ``torch.preserve_format``.
+ """
+ ...
+@overload
+def randint_like(input: Tensor, low: Union[_int, SymInt], high: Union[_int, SymInt], *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor:
+ r"""
+ randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor
+
+ Returns a tensor with the same shape as Tensor :attr:`input` filled with
+ random integers generated uniformly between :attr:`low` (inclusive) and
+ :attr:`high` (exclusive).
+
+ .. note::
+ With the global dtype default (``torch.float32``), this function returns
+ a tensor with dtype ``torch.int64``.
+
+ Args:
+ input (Tensor): the size of :attr:`input` will determine size of the output tensor.
+ low (int, optional): Lowest integer to be drawn from the distribution. Default: 0.
+ high (int): One above the highest integer to be drawn from the distribution.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
+ Default: if ``None``, defaults to the dtype of :attr:`input`.
+ layout (:class:`torch.layout`, optional): the desired layout of returned tensor.
+ Default: if ``None``, defaults to the layout of :attr:`input`.
+ device (:class:`torch.device`, optional): the desired device of returned tensor.
+ Default: if ``None``, defaults to the device of :attr:`input`.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+ returned Tensor.
Default: ``torch.preserve_format``. + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. 
sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. 
Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... 
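+# Minimal usage sketch for the randn overloads above: passing an explicit
+# torch.Generator makes the draws reproducible, and a complex dtype selects the
+# complex-normal sampling described in the docstring. Seed and shapes here are
+# illustrative only.
+g = torch.Generator().manual_seed(0)
+a = torch.randn(2, 3, generator=g)
+b = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
+assert torch.equal(a, b)                              # same seed -> identical samples
+assert torch.randn(4, dtype=torch.cfloat).dtype == torch.complex64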
+@overload +def randn(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. 
math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +@overload +def randn(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + ... +def randn_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the + sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to + ``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + ... +@overload +def randperm(n: Union[_int, SymInt], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + ... +@overload +def randperm(n: Union[_int, SymInt], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + ... +def range(start: Number, end: Number, step: Number = 1, *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` + with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is + the gap between two values in the tensor. + + .. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. + + .. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + + Args: + start (float): the starting value for the set of points. Default: ``0``. + end (float): the ending value for the set of points + step (float): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) + """ + ... +def ravel(input: Tensor) -> Tensor: + r""" + ravel(input) -> Tensor + + Return a contiguous flattened tensor. A copy is made only if needed. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + """ + ... +def real(input: Tensor) -> Tensor: + r""" + real(input) -> Tensor + + Returns a new tensor containing real values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + """ + ... +def reciprocal(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + reciprocal(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the elements of :attr:`input` + + .. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + + .. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) + """ + ... +def reciprocal_(input: Tensor) -> Tensor: ... +def relu(input: Tensor) -> Tensor: ... +def relu_(input: Tensor) -> Tensor: ... +@overload +def remainder(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. 
In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... +@overload +def remainder(self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... +@overload +def remainder(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + ... 
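+# Sketch contrasting the remainder overloads above with torch.fmod: remainder
+# follows the divisor's sign (floor division), while fmod follows the dividend's
+# sign (truncation toward zero). The last line checks the identity quoted in the
+# docstring.
+x = torch.tensor([-3.0, -1.0, 3.0])
+assert torch.equal(torch.remainder(x, 2), torch.tensor([1.0, 1.0, 1.0]))
+assert torch.equal(torch.fmod(x, 2), torch.tensor([-1.0, -1.0, 1.0]))
+assert torch.equal(torch.remainder(x, 2), x - x.div(2, rounding_mode="floor") * 2)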
+def renorm(input: Tensor, p: Union[Number, _complex], dim: _int, maxnorm: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + + Returns a tensor where each sub-tensor of :attr:`input` along dimension + :attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower + than the value :attr:`maxnorm` + + .. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + + Args: + input (Tensor): the input tensor. + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) + """ + ... +@overload +def repeat_interleave(input: Tensor, repeats: Tensor, dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +@overload +def repeat_interleave(repeats: Tensor, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. 
warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +@overload +def repeat_interleave(input: Tensor, repeats: Union[_int, SymInt], dim: Optional[_int] = None, *, output_size: Optional[Union[_int, SymInt]] = None) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. 
+ + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + ... +def reshape(input: Tensor, shape: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + reshape(input, shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`input`, + but with the specified shape. When possible, the returned tensor will be a view + of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs + with compatible strides can be reshaped without copying, but you should not + depend on the copying vs. viewing behavior. + + See :meth:`torch.Tensor.view` on when it is possible to return a view. + + A single dimension may be -1, in which case it's inferred from the remaining + dimensions and the number of elements in :attr:`input`. + + Args: + input (Tensor): the tensor to be reshaped + shape (tuple of int): the new shape + + Example:: + + >>> a = torch.arange(4.) + >>> torch.reshape(a, (2, 2)) + tensor([[ 0., 1.], + [ 2., 3.]]) + >>> b = torch.tensor([[0, 1], [2, 3]]) + >>> torch.reshape(b, (-1,)) + tensor([ 0, 1, 2, 3]) + """ + ... +def resize_as_(input: Tensor, the_template: Tensor, *, memory_format: Optional[memory_format] = None) -> Tensor: ... +def resize_as_sparse_(input: Tensor, the_template: Tensor) -> Tensor: ... +def resolve_conj(input: Tensor) -> Tensor: + r""" + resolve_conj(input) -> Tensor + + Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False + """ + ... +def resolve_neg(input: Tensor) -> Tensor: + r""" + resolve_neg(input) -> Tensor + + Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False + """ + ... 
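+# Sketch of the view-vs-copy behaviour described for torch.reshape above: a
+# contiguous input is typically reshaped as a view, while a transposed input
+# forces a copy. The docstring warns not to depend on which case occurs, so the
+# data_ptr comparisons below are illustrative rather than guaranteed.
+base = torch.arange(6)
+view = torch.reshape(base, (2, 3))
+view.data_ptr() == base.data_ptr()                    # True here: shares storage
+nc = torch.arange(6).reshape(2, 3).t()                # non-contiguous input
+flat = torch.reshape(nc, (6,))
+flat.data_ptr() != nc.data_ptr()                      # True here: reshape had to copy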
+@overload +def result_type(tensor: Tensor, other: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(scalar: Union[Number, _complex], tensor: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(tensor: Tensor, other: Union[Number, _complex]) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def result_type(scalar1: Union[Number, _complex], scalar2: Union[Number, _complex]) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + ... +@overload +def rnn_relu(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def rnn_relu(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def rnn_relu_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... 
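+# Sketch of what rnn_relu_cell is assumed to compute: the standard Elman RNN
+# cell update with a ReLU nonlinearity, using the usual (hidden_size, input_size)
+# and (hidden_size, hidden_size) weight layouts. Treat the formula as an
+# assumption; it is not spelled out in this stub.
+x, h = torch.randn(3, 4), torch.randn(3, 5)           # (batch, input), (batch, hidden)
+w_ih, w_hh = torch.randn(5, 4), torch.randn(5, 5)
+b_ih, b_hh = torch.randn(5), torch.randn(5)
+out = torch.rnn_relu_cell(x, h, w_ih, w_hh, b_ih, b_hh)
+ref = torch.relu(x @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
+assert torch.allclose(out, ref)                       # expected under this assumption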
+@overload +def rnn_tanh(data: Tensor, batch_sizes: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool) -> Tuple[Tensor, Tensor]: ... +@overload +def rnn_tanh(input: Tensor, hx: Tensor, params: Union[Tuple[Tensor, ...], List[Tensor]], has_biases: _bool, num_layers: _int, dropout: _float, train: _bool, bidirectional: _bool, batch_first: _bool) -> Tuple[Tensor, Tensor]: ... +def rnn_tanh_cell(input: Tensor, hx: Tensor, w_ih: Tensor, w_hh: Tensor, b_ih: Optional[Tensor] = None, b_hh: Optional[Tensor] = None) -> Tensor: ... +def roll(input: Tensor, shifts: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]], dims: Union[_int, _size] = ()) -> Tensor: + r""" + roll(input, shifts, dims=None) -> Tensor + + Roll the tensor :attr:`input` along the given dimension(s). Elements that are + shifted beyond the last position are re-introduced at the first position. If + :attr:`dims` is `None`, the tensor will be flattened before rolling and then + restored to the original shape. + + Args: + input (Tensor): the input tensor. + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) + """ + ... +def rot90(input: Tensor, k: _int = 1, dims: _size = (0,1)) -> Tensor: + r""" + rot90(input, k=1, dims=[0,1]) -> Tensor + + Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + + Args: + input (Tensor): the input tensor. + k (int): number of times to rotate. Default value is 1 + dims (a list or tuple): axis to rotate. Default value is [0, 1] + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) + """ + ... +@overload +def round(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. 
seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + ... +@overload +def round(input: Tensor, *, decimals: _int, out: Optional[Tensor] = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + ... +@overload +def round_(input: Tensor) -> Tensor: ... +@overload +def round_(input: Tensor, *, decimals: _int) -> Tensor: ... +def row_indices_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... +def row_stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + row_stack(tensors, *, out=None) -> Tensor + + Alias of :func:`torch.vstack`. + """ + ... +def rrelu(input: Tensor, lower: Union[Number, _complex] = 0.125, upper: Union[Number, _complex] = 0.3333333333333333, training: _bool = False, generator: Optional[Generator] = None) -> Tensor: ... 
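+# Sketch of the assumed rrelu behaviour: with training=False (the default above)
+# the randomized slope is not sampled and negative inputs are scaled by the
+# midpoint of [lower, upper]; during training the slope would instead be drawn
+# uniformly from that interval.
+x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
+mid = (0.125 + 1.0 / 3.0) / 2                         # midpoint of the default bounds
+ref = torch.where(x >= 0, x, x * mid)
+assert torch.allclose(torch.rrelu(x), ref)            # expected under this assumption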
+def rrelu_(input: Tensor, lower: Union[Number, _complex] = 0.125, upper: Union[Number, _complex] = 0.3333333333333333, training: _bool = False, generator: Optional[Generator] = None) -> Tensor: ... +def rsqrt(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + rsqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the square-root of each of + the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + >>> torch.rsqrt(a) + tensor([ nan, 1.8351, 0.8053, nan]) + """ + ... +def rsqrt_(input: Tensor) -> Tensor: ... +@overload +def rsub(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1) -> Tensor: ... +@overload +def rsub(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: ... +def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1, out: Optional[Tensor] = None) -> Tensor: ... +def scalar_tensor(s: Union[Number, _complex], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, reduce: str, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex], *, reduce: str, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: _int, index: Tensor, value: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, value: Union[Number, _complex]) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + ... +@overload +def scatter_add(input: Tensor, dim: _int, index: Tensor, src: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... +@overload +def scatter_add(input: Tensor, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + ... 
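+# Small sketch for the out-of-place scatter/scatter_add overloads above, whose
+# docstrings only point at the in-place counterparts; the expected values follow
+# from the in-place definitions (scatter overwrites, scatter_add accumulates).
+idx = torch.tensor([0, 1, 0, 2])
+src = torch.ones(4)
+assert torch.equal(torch.scatter_add(torch.zeros(3), 0, idx, src),
+                   torch.tensor([2.0, 1.0, 1.0]))     # index 0 receives two additions
+assert torch.equal(torch.scatter(torch.zeros(3), 0, idx, src),
+                   torch.tensor([1.0, 1.0, 1.0]))     # later writes simply overwrite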
+def scatter_reduce(input: Tensor, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool = True, out: Optional[Tensor] = None) -> Tensor: + r""" + scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + ... +@overload +def searchsorted(sorted_sequence: Tensor, input: Tensor, *, out_int32: _bool = False, right: _bool = False, side: Optional[str] = None, sorter: Optional[Tensor] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. 
+ sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + ... +@overload +def searchsorted(sorted_sequence: Tensor, self: Union[Number, _complex], *, out_int32: _bool = False, right: _bool = False, side: Optional[str] = None, sorter: Optional[Tensor] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. 
"left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + ... +def segment_reduce(data: Tensor, reduce: str, *, lengths: Optional[Tensor] = None, indices: Optional[Tensor] = None, offsets: Optional[Tensor] = None, axis: _int = 0, unsafe: _bool = False, initial: Optional[Union[Number, _complex]] = None) -> Tensor: ... +@overload +def select(input: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + ... +@overload +def select(input: Tensor, dim: Union[str, ellipsis, None], index: _int) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + ... +def select_copy(input: Tensor, dim: _int, index: Union[_int, SymInt], *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.select`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... 
+def select_scatter(input: Tensor, src: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: + r""" + select_scatter(input, src, dim, index) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into. + index (int): the index to select with + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.select(input, dim, index)`` + + Example:: + + >>> a = torch.zeros(2, 2) + >>> b = torch.ones(2) + >>> a.select_scatter(b, 0, 0) + tensor([[1., 1.], + [0., 0.]]) + """ + ... +def selu(input: Tensor) -> Tensor: ... +def selu_(input: Tensor) -> Tensor: ... +def set_flush_denormal(mode: _bool) -> _bool: + r""" + set_flush_denormal(mode) -> bool + + Disables denormal floating numbers on CPU. + + Returns ``True`` if your system supports flushing denormal numbers and it + successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal` + is supported on x86 architectures supporting SSE3 and AArch64 architecture. + + Args: + mode (bool): Controls whether to enable flush denormal mode or not + + Example:: + + >>> torch.set_flush_denormal(True) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor([ 0.], dtype=torch.float64) + >>> torch.set_flush_denormal(False) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor(9.88131e-324 * + [ 1.0000], dtype=torch.float64) + """ + ... +def set_num_interop_threads(num: _int) -> None: + r""" + set_num_interop_threads(int) + + Sets the number of threads used for interop parallelism + (e.g. in JIT interpreter) on CPU. + + .. warning:: + Can only be called once and before any inter-op parallel work + is started (e.g. JIT execution). + """ + ... +def set_num_threads(num: _int) -> None: + r""" + set_num_threads(int) + + Sets the number of threads used for intraop parallelism on CPU. + + .. warning:: + To ensure that the correct number of threads is used, set_num_threads + must be called before running eager, JIT or autograd code. + """ + ... +def sgn(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sgn(input, *, out=None) -> Tensor + + This function is an extension of torch.sign() to complex tensors. + It computes a new tensor whose elements have + the same angles as the corresponding elements of :attr:`input` and + absolute values (i.e. magnitudes) of one for complex tensors and + is equivalent to torch.sign() for non-complex tensors. + + .. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) + """ + ... +def sigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sigmoid(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expit`. + """ + ... +def sigmoid_(input: Tensor) -> Tensor: ... 
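+# Usage sketch (illustrative, assumed example values): for nonzero complex entries the
+# torch.sgn() formula documented above reduces to z / |z|, which can be checked directly.
+#
+#     z = torch.tensor([3 + 4j, -2j])
+#     assert torch.allclose(torch.sgn(z), z / z.abs())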
+def sign(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sign(input, *, out=None) -> Tensor + + Returns a new tensor with the signs of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> a + tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) + >>> torch.sign(a) + tensor([ 1., -1., 0., 1.]) + """ + ... +def signbit(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + signbit(input, *, out=None) -> Tensor + + Tests if each element of :attr:`input` has its sign bit set or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + + .. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + """ + ... +def sin(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sin(input, *, out=None) -> Tensor + + Returns a new tensor with the sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) + """ + ... +def sin_(input: Tensor) -> Tensor: ... +def sinc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sinc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.sinc`. + """ + ... +def sinc_(input: Tensor) -> Tensor: ... +def sinh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sinh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic sine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + ... +def sinh_(input: Tensor) -> Tensor: ... +def slice_copy(input: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.slice`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def slice_inverse(input: Tensor, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1) -> Tensor: ... 
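+# Usage sketch (illustrative, assumed example values): torch.sinh() documented above
+# satisfies the elementwise identity sinh(x) = (exp(x) - exp(-x)) / 2.
+#
+#     x = torch.linspace(-2.0, 2.0, 5)
+#     assert torch.allclose(torch.sinh(x), (torch.exp(x) - torch.exp(-x)) / 2)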
+def slice_scatter(input: Tensor, src: Tensor, dim: _int = 0, start: Optional[Union[_int, SymInt]] = None, end: Optional[Union[_int, SymInt]] = None, step: Union[_int, SymInt] = 1, *, out: Optional[Tensor] = None) -> Tensor:
+    r"""
+    slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor
+
+    Embeds the values of the :attr:`src` tensor into :attr:`input` at the given
+    dimension.
+    This function returns a tensor with fresh storage; it does not create a view.
+
+
+    Args:
+        input (Tensor): the input tensor.
+        src (Tensor): The tensor to embed into :attr:`input`
+        dim (int): the dimension to insert the slice into
+        start (Optional[int]): the start index of where to insert the slice
+        end (Optional[int]): the end index of where to insert the slice
+        step (int): the step (stride) between positions at which the slice is inserted
+
+    Example::
+
+        >>> a = torch.zeros(8, 8)
+        >>> b = torch.ones(2, 8)
+        >>> a.slice_scatter(b, start=6)
+        tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
+                [0., 0., 0., 0., 0., 0., 0., 0.],
+                [0., 0., 0., 0., 0., 0., 0., 0.],
+                [0., 0., 0., 0., 0., 0., 0., 0.],
+                [0., 0., 0., 0., 0., 0., 0., 0.],
+                [0., 0., 0., 0., 0., 0., 0., 0.],
+                [1., 1., 1., 1., 1., 1., 1., 1.],
+                [1., 1., 1., 1., 1., 1., 1., 1.]])
+
+        >>> b = torch.ones(8, 2)
+        >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2)
+        tensor([[0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.],
+                [0., 0., 1., 0., 1., 0., 0., 0.]])
+    """
+    ...
+def slogdet(input: Tensor, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.slogdet:
+    r"""
+    slogdet(input) -> (Tensor, Tensor)
+
+    Alias for :func:`torch.linalg.slogdet`
+    """
+    ...
+def smm(input: Tensor, mat2: Tensor) -> Tensor:
+    r"""
+    smm(input, mat) -> Tensor
+
+    Performs a matrix multiplication of the sparse matrix :attr:`input`
+    with the dense matrix :attr:`mat`.
+
+    Args:
+        input (Tensor): a sparse matrix to be matrix multiplied
+        mat (Tensor): a dense matrix to be matrix multiplied
+    """
+    ...
+@overload
+def softmax(input: Tensor, dim: _int, dtype: Optional[_dtype] = None, *, out: Optional[Tensor] = None) -> Tensor:
+    r"""
+    softmax(input, dim, *, dtype=None) -> Tensor
+
+    Alias for :func:`torch.nn.functional.softmax`.
+    """
+    ...
+@overload
+def softmax(input: Tensor, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None) -> Tensor:
+    r"""
+    softmax(input, dim, *, dtype=None) -> Tensor
+
+    Alias for :func:`torch.nn.functional.softmax`.
+    """
+    ...
+@overload
+def sort(input: Tensor, *, stable: Optional[_bool], dim: _int = -1, descending: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort:
+    r"""
+    sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor)
+
+    Sorts the elements of the :attr:`input` tensor along a given dimension
+    in ascending order by value.
+
+    If :attr:`dim` is not given, the last dimension of the `input` is chosen.
+
+    If :attr:`descending` is ``True`` then the elements are sorted in descending
+    order by value.
+
+    If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving
+    the order of equivalent elements.
+
+    A namedtuple of (values, indices) is returned, where the `values` are the
+    sorted values and `indices` are the indices of the elements in the original
+    `input` tensor.
+ + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +@overload +def sort(input: Tensor, dim: _int = -1, descending: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +@overload +def sort(input: Tensor, *, stable: Optional[_bool], dim: Union[str, ellipsis, None], descending: _bool = False, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... 
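+# Usage sketch (illustrative, assumed shapes): the two outputs of torch.sort() documented
+# above are related by gather(): indexing `input` with `indices` along the sorted
+# dimension reproduces `values`.
+#
+#     x = torch.randn(3, 4)
+#     values, indices = torch.sort(x, dim=-1)
+#     assert torch.equal(values, x.gather(-1, indices))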
+@overload +def sort(input: Tensor, dim: Union[str, ellipsis, None], descending: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + ... +def sparse_bsc_tensor(ccol_indices: Union[Tensor, List], row_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse + Column)) ` with specified 2-dimensional blocks at the + given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in BSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. 
+ + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) + """ + ... +def sparse_bsr_tensor(crow_indices: Union[Tensor, List], col_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) + ` with specified 2-dimensional blocks at the given + :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix + multiplication operations in BSR format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. 
+ + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) + """ + ... +def sparse_compressed_tensor(compressed_indices: Union[Tensor, List], plain_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, *, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, + CSC, BSR, or BSC - ` with specified values at + the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse + matrix multiplication operations in Compressed Sparse format are + typically faster than that for sparse tensors in COO format. Make you + have a look at :ref:`the note on the data type of the indices + `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. 
If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + ... +def sparse_coo_tensor(indices: Tensor, values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None, is_coalesced: Optional[_bool] = None) -> Tensor: + r""" + sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + + Constructs a :ref:`sparse tensor in COO(rdinate) format + ` with specified values at the given + :attr:`indices`. + + .. 
note::
+
+        This function returns an :ref:`uncoalesced tensor
+        ` when :attr:`is_coalesced` is
+        unspecified or ``None``.
+
+    .. note::
+
+        If the ``device`` argument is not specified the device of the given
+        :attr:`values` and indices tensor(s) must match. If, however, the
+        argument is specified the input Tensors will be converted to the
+        given device and in turn determine the device of the constructed
+        sparse tensor.
+
+    Args:
+        indices (array_like): Initial data for the tensor. Can be a list, tuple,
+            NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor`
+            internally. The indices are the coordinates of the non-zero values in the matrix, and thus
+            should be two-dimensional where the first dimension is the number of tensor dimensions and
+            the second dimension is the number of non-zero values.
+        values (array_like): Initial values for the tensor. Can be a list, tuple,
+            NumPy ``ndarray``, scalar, and other types.
+        size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not
+            provided the size will be inferred as the minimum size big enough to hold all non-zero
+            elements.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if None, infers data type from :attr:`values`.
+        device (:class:`torch.device`, optional): the desired device of returned tensor.
+            Default: if None, uses the current device for the default tensor type
+            (see :func:`torch.set_default_device`). :attr:`device` will be the CPU
+            for CPU tensor types and the current CUDA device for CUDA tensor types.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        check_invariants (bool, optional): If sparse tensor invariants are checked.
+            Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`,
+            initially False.
+        is_coalesced (bool, optional): When ``True``, the caller is
+            responsible for providing tensor indices that correspond to a
+            coalesced tensor. If the :attr:`check_invariants` flag is
+            False, no error will be raised if the prerequisites are not
+            met and this will lead to silently incorrect results. To force
+            coalescence, please use :meth:`coalesce` on the resulting
+            Tensor.
+            Default: ``None``; except for trivial cases (e.g. nnz < 2) the
+            resulting Tensor has ``is_coalesced`` set to ``False``.
+
+    Example::
+
+        >>> i = torch.tensor([[0, 1, 1],
+        ...                   [2, 0, 2]])
+        >>> v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        >>> torch.sparse_coo_tensor(i, v, [2, 4])
+        tensor(indices=tensor([[0, 1, 1],
+                               [2, 0, 2]]),
+               values=tensor([3., 4., 5.]),
+               size=(2, 4), nnz=3, layout=torch.sparse_coo)
+
+        >>> torch.sparse_coo_tensor(i, v)  # Shape inference
+        tensor(indices=tensor([[0, 1, 1],
+                               [2, 0, 2]]),
+               values=tensor([3., 4., 5.]),
+               size=(2, 3), nnz=3, layout=torch.sparse_coo)
+
+        >>> torch.sparse_coo_tensor(i, v, [2, 4],
+        ...                         dtype=torch.float64,
+        ...                         device=torch.device('cuda:0'))
+        tensor(indices=tensor([[0, 1, 1],
+                               [2, 0, 2]]),
+               values=tensor([3., 4., 5.]),
+               device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64,
+               layout=torch.sparse_coo)
+
+        # Create an empty sparse tensor with the following invariants:
+        # 1. sparse_dim + dense_dim = len(SparseTensor.shape)
+        # 2. SparseTensor._indices().shape = (sparse_dim, nnz)
+        # 3. 
SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + + .. _torch.sparse: https://pytorch.org/docs/stable/sparse.html + """ + ... +def sparse_csc_tensor(ccol_indices: Union[Tensor, List], row_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) + ` with specified values at the given + :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in CSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. 
+ Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) + """ + ... +def sparse_csr_tensor(crow_indices: Union[Tensor, List], col_indices: Union[Tensor, List], values: Union[Tensor, List], size: Optional[_size] = None, *, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, check_invariants: Optional[_bool] = None) -> Tensor: + r""" + sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified + values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations + in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look + at :ref:`the note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. 
+ + Example:: + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + ... +def split_copy(input: Tensor, split_size: Union[_int, SymInt], dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None: + r""" + Performs the same operation as :func:`torch.split`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def split_with_sizes(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... +def split_with_sizes_copy(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None: + r""" + Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def spmm(input: Tensor, mat2: Tensor) -> Tensor: ... +def sqrt(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + sqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the square-root of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) + """ + ... +def sqrt_(input: Tensor) -> Tensor: ... +def square(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + square(input, *, out=None) -> Tensor + + Returns a new tensor with the square of the elements of :attr:`input`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) + """ + ... +def square_(input: Tensor) -> Tensor: ... +@overload +def squeeze(input: Tensor) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. 
+ + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: _int) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: _size) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. 
+ + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze(input: Tensor, dim: Union[str, ellipsis, None]) -> Tensor: + r""" + squeeze(input, dim=None) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + ... +@overload +def squeeze_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def squeeze_copy(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def squeeze_copy(input: Tensor, dim: _size, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def sspaddmm(beta: Union[Number, _complex], self: Tensor, alpha: Union[Number, _complex], mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. 
+ + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +@overload +def sspaddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Union[Number, _complex] = 1, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +@overload +def sspaddmm(beta: Union[Number, _complex], self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + ... +def stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + stack(tensors, dim=0, *, out=None) -> Tensor + + Concatenates a sequence of tensors along a new dimension. + + All tensors need to be of the same size. + + .. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. + + Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> x = torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + >>> x + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x.size() + torch.Size([2, 2, 3]) + >>> x = torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x = torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> x = torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + """ + ... +@overload +def std(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. 
math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... 
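+# A minimal doctest-style sketch of the keyword-only ``correction`` argument described in the
+# ``torch.std`` docstrings above (the input tensor and the rounded outputs below are illustrative
+# assumptions, not captured from a real session):
+#
+#     >>> import torch
+#     >>> a = torch.tensor([[1., 2.], [3., 4.]])
+#     >>> torch.std(a, dim=1, correction=1)   # Bessel-corrected (default): divide by N - 1
+#     tensor([0.7071, 0.7071])
+#     >>> torch.std(a, dim=1, correction=0)   # population formula: divide by N
+#     tensor([0.5000, 0.5000])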
+@overload +def std(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. 
versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. 
math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, unbiased: _bool = True) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... 
[[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def std_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. 
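+
+    As a worked illustration with an assumed input (values rounded): for ``x = [1., 2., 3., 4.]``
+    the squared deviations from the mean ``2.5`` sum to ``5``, so ``correction=1`` gives
+    :math:`\sqrt{5/3} \approx 1.2910` while ``correction=0`` gives :math:`\sqrt{5/4} \approx 1.1180`::
+
+        >>> x = torch.tensor([1., 2., 3., 4.])
+        >>> torch.std_mean(x, correction=1)
+        (tensor(1.2910), tensor(2.5000))
+        >>> torch.std_mean(x, correction=0)
+        (tensor(1.1180), tensor(2.5000))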
+ + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def sub(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], *, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def sub(self: Tensor, alpha: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def sub(self: Tensor, alpha: Union[Number, _complex], other: Tensor, *, out: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. 
math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + ... +@overload +def subtract(input: Tensor, other: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor] = None) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + ... +@overload +def subtract(input: Tensor, other: Union[Number, _complex], alpha: Union[Number, _complex] = 1) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + ... +@overload +def sum(input: Tensor, *, dtype: Optional[_dtype] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +@overload +def sum(input: Tensor, dim: Optional[Union[_int, _size]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +@overload +def sum(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False, *, dtype: Optional[_dtype] = None, out: Optional[Tensor] = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + ... +def svd(input: Tensor, some: _bool = True, compute_uv: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.svd: + r""" + svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Computes the singular value decomposition of either a matrix or batch of + matrices :attr:`input`. The singular value decomposition is represented as a + namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. + where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, + and the conjugate transpose of `V` for complex inputs. + If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also + batched with the same batch dimensions as :attr:`input`. + + If :attr:`some` is `True` (default), the method returns the reduced singular + value decomposition. In this case, if the last two dimensions of :attr:`input` are + `m` and `n`, then the returned `U` and `V` matrices will contain only + `min(n, m)` orthonormal columns. + + If :attr:`compute_uv` is `False`, the returned `U` and `V` will be + zero-filled matrices of shape `(m, m)` and `(n, n)` + respectively, and the same device as :attr:`input`. The argument :attr:`some` + has no effect when :attr:`compute_uv` is `False`. + + Supports :attr:`input` of float, double, cfloat and cdouble data types. + The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will + always be real-valued, even if :attr:`input` is complex. + + .. warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. code:: python + + S = torch.linalg.svdvals(A) + + .. note:: Differences with :func:`torch.linalg.svd`: + + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that + default value for both is `True`, so the default behavior is + effectively the opposite. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns + `Vh`, that is, :math:`V^{\text{H}}`. + * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled + tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns + empty tensors. + + .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + + .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`. + + .. 
note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]` + and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors + can be arbitrary bases of the corresponding subspaces. + + .. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd` + (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously, + on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 + and later, and MAGMA's routine `gesdd` on earlier versions of CUDA. + + .. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will + be represented as a column-major matrix (i.e. Fortran-contiguous). + + .. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not + have zero nor repeated singular values. + + .. warning:: If the distance between any two singular values is close to zero, the gradients with respect to + `U` and `V` will be numerically unstable, as they depends on + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix + has small singular values, as these gradients also depend on `S^{-1}`. + + .. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique, + as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column. + The same happens when :attr:`input` has repeated singular values, where one may multiply + the columns of the spanning subspace in `U` and `V` by a rotation matrix + and `the resulting vectors will span the same subspace`_. + Different platforms, like NumPy, or inputs on different device types, + may produce different `U` and `V` tensors. + + Args: + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m, n)` matrices. + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently, the shape of returned `U` and `V`. Default: `True`. + compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`. + + Keyword args: + out (tuple, optional): the output tuple of tensors + + Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + + .. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) + """ + ... +def swapaxes(input: Tensor, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(input, axis0, axis1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. 
+ + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + ... +def swapdims(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(input, dim0, dim1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + ... +def sym_constrain_range(size: Union[Number, _complex], *, min: Optional[_int] = None, max: Optional[_int] = None) -> None: ... +def sym_constrain_range_for_size(size: Union[Number, _complex], *, min: Optional[_int] = None, max: Optional[_int] = None) -> None: ... +def t(input: Tensor) -> Tensor: + r""" + t(input) -> Tensor + + Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 + and 1. + + 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this + is equivalent to ``transpose(input, 0, 1)``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + + See also :func:`torch.transpose`. + """ + ... +def t_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.t`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def take(input: Tensor, index: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + take(input, index) -> Tensor + + Returns a new tensor with the elements of :attr:`input` at the given indices. + The input tensor is treated as if it were viewed as a 1-D tensor. The result + takes the same shape as the indices. + + Args: + input (Tensor): the input tensor. + index (LongTensor): the indices into tensor + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) + """ + ... +def take_along_dim(input: Tensor, indices: Tensor, dim: Optional[_int] = None, *, out: Optional[Tensor] = None) -> Tensor: + r""" + take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + + Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + + If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. + + Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, + are designed to work with this function. See the examples below. + + .. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + + Args: + input (Tensor): the input tensor. + indices (tensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) + """ + ... +def tan(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + tan(input, *, out=None) -> Tensor + + Returns a new tensor with the tangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) + """ + ... +def tan_(input: Tensor) -> Tensor: ... +def tanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + tanh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic tangent of the elements + of :attr:`input`. + + .. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) + """ + ... +def tanh_(input: Tensor) -> Tensor: ... +def tensor(data: Any, dtype: Optional[_dtype] = None, device: Optional[DeviceLikeType] = None, requires_grad: _bool = False, pin_memory: _bool = False) -> Tensor: + r""" + tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + + .. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.clone().detach()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.clone().detach().requires_grad_(True)``. + + .. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + + Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... 
dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) + """ + ... +@overload +def tensor_split(input: Tensor, tensor_indices_or_sections: Tensor, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +@overload +def tensor_split(input: Tensor, sections: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. 
+ If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +@overload +def tensor_split(input: Tensor, indices: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + ... +def threshold(input: Tensor, threshold: Union[Number, _complex], value: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: ... +def threshold_(input: Tensor, threshold: Union[Number, _complex], value: Union[Number, _complex]) -> Tensor: ... +def tile(input: Tensor, dims: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + tile(input, dims) -> Tensor + + Constructs a tensor by repeating the elements of :attr:`input`. + The :attr:`dims` argument specifies the number of repetitions + in each dimension. + + If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then + ones are prepended to :attr:`dims` until all dimensions are specified. + For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` + is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + + Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` + specifies, then :attr:`input` is treated as if it were unsqueezed at + dimension zero until it has as many dimensions as :attr:`dims` specifies. + For example, if :attr:`input` has shape (4, 2) and :attr:`dims` + is (3, 3, 2, 2), then :attr:`input` is treated as if it had the + shape (1, 1, 4, 2). + + .. note:: + + This function is similar to NumPy's tile function. + + Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + """ + ... +def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: + r""" + topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + + Returns the :attr:`k` largest elements of the given :attr:`input` tensor along + a given dimension. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`largest` is ``False`` then the `k` smallest elements are returned. + + A namedtuple of `(values, indices)` is returned with the `values` and + `indices` of the largest `k` elements of each row of the `input` tensor in the + given dimension `dim`. + + The boolean option :attr:`sorted` if ``True``, will make sure that the returned + `k` elements are themselves sorted + + Args: + input (Tensor): the input tensor. 
+ k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) + """ + ... +def trace(input: Tensor) -> Tensor: + r""" + trace(input) -> Tensor + + Returns the sum of the elements of the diagonal of the input 2-D matrix. + + Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) + """ + ... +@overload +def transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + ... +@overload +def transpose(input: Tensor, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. 
The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + ... +def transpose_copy(input: Tensor, dim0: _int, dim1: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.transpose`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. 
+ The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + ... +@overload +def trapezoid(y: Tensor, *, dx: Union[Number, _complex] = 1, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. 
The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + ... +@overload +def trapz(y: Tensor, *, dx: _float = 1, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + ... +@overload +def trapz(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + ... +def triangular_solve(input: Tensor, A: Tensor, upper: _bool = True, transpose: _bool = False, unitriangular: _bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.triangular_solve: + r""" + triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + + Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` + and multiple right-hand sides :math:`b`. + + In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular + (or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. + + `torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are + batches of 2D matrices. If the inputs are batches, then returns + batched outputs `X` + + If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and + :attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, + the result may contain `NaN` s. + + Supports input of float, double, cfloat and cdouble data types. 
+ + .. warning:: + + :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular` + and will be removed in a future PyTorch release. + :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a + copy of one of the inputs. + + ``X = torch.triangular_solve(B, A).solution`` should be replaced with + + .. code:: python + + X = torch.linalg.solve_triangular(A, B) + + Args: + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``. + transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``, + and `op(A) = A` if it is ``False``. Default: ``False``. + unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. + + Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + + Returns: + A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` + is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` + (or whatever variant of the system of equations, depending on the keyword arguments.) + + Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + torch.return_types.triangular_solve( + solution=tensor([[ 1.7841, 2.9046, -2.5405], + [ 1.9320, 0.9270, -1.2826]]), + cloned_coefficient=tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) + """ + ... +def tril(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + tril(input, diagonal=0, *, out=None) -> Tensor + + Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) + """ + ... +def tril_indices(row: _int, col: _int, offset: _int = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the lower triangular part of a :attr:`row`-by- + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) + """ + ... 
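+# Editor's note: a minimal illustrative sketch (not part of the generated
+# stub) showing how the coordinates returned by torch.tril_indices relate to
+# torch.tril above; the matrix and its size are arbitrary assumptions.
+#
+#   import torch
+#   a = torch.randn(3, 3)
+#   idx = torch.tril_indices(3, 3)               # 2-by-N row/column coordinates
+#   lower = torch.zeros_like(a)
+#   lower[idx[0], idx[1]] = a[idx[0], idx[1]]    # copy only the on/below-diagonal entries
+#   torch.testing.assert_close(lower, torch.tril(a))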
+def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, margin: _float = 1.0, p: _float = 2, eps: _float = 1e-06, swap: _bool = False, reduction: _int = 1) -> Tensor: ... +def triu(input: Tensor, diagonal: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: + r""" + triu(input, diagonal=0, *, out=None) -> Tensor + + Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) + """ + ... +def triu_indices(row: _int, col: _int, offset: _int = 0, *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the upper triangular part of a :attr:`row` by + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and above the main diagonal are + retained. 
A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) + """ + ... +def true_divide(input: Union[Tensor, Number], other: Union[Tensor, Number], *, out: Optional[Tensor] = None) -> Tensor: + r""" + true_divide(dividend, divisor, *, out) -> Tensor + + Alias for :func:`torch.div` with ``rounding_mode=None``. + """ + ... +def trunc(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + trunc(input, *, out=None) -> Tensor + + Returns a new tensor with the truncated integer values of + the elements of :attr:`input`. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) + """ + ... +def trunc_(input: Tensor) -> Tensor: ... +@overload +def unbind(input: Tensor, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + ... +@overload +def unbind(input: Tensor, dim: Union[str, ellipsis, None]) -> Tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + ... 
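+# Editor's note: a minimal illustrative sketch (not part of the generated
+# stub) of the complementary diagonal conventions documented for torch.triu
+# and torch.tril; the random input is an arbitrary assumption.
+#
+#   import torch
+#   a = torch.randn(4, 4)
+#   strict_upper = torch.triu(a, diagonal=1)     # strictly above the main diagonal
+#   inclusive_lower = torch.tril(a)              # on and below the main diagonal
+#   torch.testing.assert_close(strict_upper + inclusive_lower, a)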
+def unbind_copy(input: Tensor, dim: _int = 0, *, out: Union[Tuple[Tensor, ...], List[Tensor], None] = None) -> None: + r""" + Performs the same operation as :func:`torch.unbind`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def unflatten(input: Tensor, dim: Union[str, ellipsis, None], sizes: Sequence[Union[_int, SymInt]], names: Sequence[Union[str, ellipsis, None]]) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + ... +@overload +def unflatten(input: Tensor, dim: _int, sizes: Sequence[Union[_int, SymInt]]) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + ... +def unfold_copy(input: Tensor, dimension: _int, size: _int, step: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.unfold`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def unique_dim(input: Tensor, dim: _int, sorted: _bool = True, return_inverse: _bool = False, return_counts: _bool = False) -> Tuple[Tensor, Tensor, Tensor]: ... +def unsafe_chunk(input: Tensor, chunks: _int, dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unsafe_chunk(input, chunks, dim=0) -> List of Tensors + + Works like :func:`torch.chunk` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. 
+ """ + ... +def unsafe_split(input: Tensor, split_size: Union[_int, SymInt], dim: _int = 0) -> Tuple[Tensor, ...]: + r""" + unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + + Works like :func:`torch.split` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + ... +def unsafe_split_with_sizes(input: Tensor, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0) -> Tuple[Tensor, ...]: ... +def unsqueeze(input: Tensor, dim: _int) -> Tensor: + r""" + unsqueeze(input, dim) -> Tensor + + Returns a new tensor with a dimension of size one inserted at the + specified position. + + The returned tensor shares the same underlying data with this tensor. + + A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` + can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` + applied at :attr:`dim` = ``dim + input.dim() + 1``. + + Args: + input (Tensor): the input tensor. + dim (int): the index at which to insert the singleton dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) + """ + ... +def unsqueeze_copy(input: Tensor, dim: _int, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.unsqueeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def values_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.values`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def vander(x: Tensor, N: Optional[_int] = None, increasing: _bool = False) -> Tensor: + r""" + vander(x, N=None, increasing=False) -> Tensor + + Generates a Vandermonde matrix. + + The columns of the output matrix are elementwise powers of the input vector :math:`x^{(N-1)}, x^{(N-2)}, ..., x^0`. + If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{(N-1)}`. Such a + matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + + Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + + Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{(N-1)}`, + the second :math:`x^{(N-2)}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{(N-1)}`. + + Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + """ + ... 
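+# Editor's note: a minimal illustrative sketch (not part of the generated
+# stub) of the column convention described for torch.vander above; the input
+# vector is an arbitrary assumption.
+#
+#   import torch
+#   x = torch.tensor([1., 2., 3.])
+#   V = torch.vander(x, N=3, increasing=True)    # columns are x**0, x**1, x**2
+#   torch.testing.assert_close(V, torch.stack([x**0, x**1, x**2], dim=1))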
+@overload +def var(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. 
versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. 
+ + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False, *, out: Optional[Tensor] = None) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... 
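+# Editor's note: a minimal illustrative sketch (not part of the generated
+# stub) of the `correction` keyword documented for torch.var above:
+# correction=1 (Bessel's correction) divides by N - 1, correction=0 divides
+# by N. The sample size is an arbitrary assumption.
+#
+#   import torch
+#   a = torch.randn(100)
+#   n = a.numel()
+#   sq_dev = (a - a.mean()) ** 2
+#   torch.testing.assert_close(torch.var(a, correction=0), sq_dev.sum() / n)
+#   torch.testing.assert_close(torch.var(a, correction=1), sq_dev.sum() / (n - 1))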
+@overload +def var_mean(input: Tensor, dim: Optional[Union[_int, _size]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Optional[Union[_int, _size]] = None, *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, unbiased: _bool = True) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[Union[Number, _complex]] = None, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. 
+ :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +@overload +def var_mean(input: Tensor, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool = True, keepdim: _bool = False) -> Tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. 
+ out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + ... +def vdot(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + vdot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D vectors along a dimension. + + In symbols, this function computes + + .. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + + where :math:`\overline{x_i}` denotes the conjugate for complex + vectors, and it is the identity for real vectors. + + .. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + .. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + + .. note:: out (Tensor, optional): the output tensor. + + + Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) + """ + ... +def view_as_complex(input: Tensor) -> Tensor: + r""" + view_as_complex(input) -> Tensor + + Returns a view of :attr:`input` as a complex tensor. For an input complex + tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a + new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last + dimension of the input tensor is expected to represent the real and imaginary + components of complex numbers. + + .. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) + """ + ... +def view_as_complex_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_complex`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +def view_as_real(input: Tensor) -> Tensor: + r""" + view_as_real(input) -> Tensor + + Returns a view of :attr:`input` as a real tensor. For an input complex tensor of + :attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new + real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 + represents the real and imaginary components of complex numbers. + + .. 
warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) + """ + ... +def view_as_real_copy(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_real`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def view_copy(input: Tensor, dtype: _dtype, *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def view_copy(input: Tensor, size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + ... +@overload +def vsplit(input: Tensor, sections: _int) -> Tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + ... +@overload +def vsplit(input: Tensor, indices: _size) -> Tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. 
+ + Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + ... +def vstack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], *, out: Optional[Tensor] = None) -> Tensor: + r""" + vstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence vertically (row wise). + + This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + """ + ... +@overload +def where(condition: Tensor) -> Tuple[Tensor, ...]: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. 
math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, self: Union[Number, _complex], other: Tensor) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... 
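+# A hedged, minimal illustration of the single-argument form described in the docstring
+# above (``torch.where(condition)`` behaves like ``torch.nonzero(condition, as_tuple=True)``);
+# the tensor values below are made up purely for illustration:
+#
+#     >>> x = torch.tensor([[1, 0], [0, 2]])
+#     >>> torch.where(x != 0)
+#     (tensor([0, 1]), tensor([0, 1]))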
+@overload +def where(condition: Tensor, input: Tensor, other: Union[Number, _complex]) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def where(condition: Tensor, self: Union[Number, _complex], other: Union[Number, _complex]) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) 
+ tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + ... +@overload +def xlogy(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy(self: Union[Number, _complex], other: Tensor, *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy(input: Tensor, other: Union[Number, _complex], *, out: Optional[Tensor] = None) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + ... +@overload +def xlogy_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def xlogy_(input: Tensor, other: Union[Number, _complex]) -> Tensor: ... +def zero_(input: Tensor) -> Tensor: ... +@overload +def zeros(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(size: _size, *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +@overload +def zeros(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + ... +def zeros_like(input: Tensor, *, memory_format: Optional[memory_format] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: + r""" + zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the same size as + :attr:`input`. ``torch.zeros_like(input)`` is equivalent to + ``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + ... 
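+# A hedged, minimal sketch contrasting ``torch.zeros`` and ``torch.zeros_like`` as documented
+# above; the dtype keyword is the documented one and the shapes/values are illustrative only:
+#
+#     >>> base = torch.zeros(2, 2, dtype=torch.int64)
+#     >>> torch.zeros_like(base).dtype   # inherits dtype (and device, layout) from ``base``
+#     torch.int64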
+ +__all__ = ['__and__', '__lshift__', '__or__', '__rshift__', '__xor__', '_adaptive_avg_pool2d', + '_adaptive_avg_pool3d', '_add_batch_dim', '_add_relu', '_add_relu_', '_addmm_activation', + '_aminmax', '_amp_foreach_non_finite_check_and_unscale_', '_amp_update_scale_', '_assert_async', + '_assert_scalar', '_assert_tensor_metadata', '_batch_norm_impl_index', '_cast_Byte', '_cast_Char', + '_cast_Double', '_cast_Float', '_cast_Half', '_cast_Int', '_cast_Long', '_cast_Short', + '_choose_qparams_per_tensor', '_chunk_cat', '_coalesce', '_compute_linear_combination', '_conj', + '_conj_copy', '_conj_physical', '_convert_indices_from_coo_to_csr', + '_convert_indices_from_csr_to_coo', '_convert_weight_to_int4pack', '_convolution', + '_convolution_mode', '_copy_from', '_copy_from_and_resize', '_cslt_compress', '_cslt_sparse_mm', + '_cslt_sparse_mm_search', '_ctc_loss', '_cudnn_ctc_loss', '_cudnn_init_dropout_state', + '_cudnn_rnn', '_cudnn_rnn_flatten_weight', '_cufft_clear_plan_cache', + '_cufft_get_plan_cache_max_size', '_cufft_get_plan_cache_size', '_cufft_set_plan_cache_max_size', + '_cummax_helper', '_cummin_helper', '_debug_has_internal_overlap', '_dim_arange', + '_dirichlet_grad', '_disable_functionalization', '_efficientzerotensor', '_embedding_bag', + '_embedding_bag_forward_only', '_empty_affine_quantized', '_empty_per_channel_affine_quantized', + '_enable_functionalization', '_euclidean_dist', '_fake_quantize_learnable_per_channel_affine', + '_fake_quantize_learnable_per_tensor_affine', + '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', + '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams', '_fft_c2c', '_fft_c2r', '_fft_r2c', + '_fill_mem_eff_dropout_mask_', '_foobar', '_foreach_abs', '_foreach_abs_', '_foreach_acos', + '_foreach_acos_', '_foreach_add', '_foreach_add_', '_foreach_addcdiv', '_foreach_addcdiv_', + '_foreach_addcmul', '_foreach_addcmul_', '_foreach_asin', '_foreach_asin_', '_foreach_atan', + '_foreach_atan_', '_foreach_ceil', '_foreach_ceil_', '_foreach_clamp_max', '_foreach_clamp_max_', + '_foreach_clamp_min', '_foreach_clamp_min_', '_foreach_copy_', '_foreach_cos', '_foreach_cos_', + '_foreach_cosh', '_foreach_cosh_', '_foreach_div', '_foreach_div_', '_foreach_erf', + '_foreach_erf_', '_foreach_erfc', '_foreach_erfc_', '_foreach_exp', '_foreach_exp_', + '_foreach_expm1', '_foreach_expm1_', '_foreach_floor', '_foreach_floor_', '_foreach_frac', + '_foreach_frac_', '_foreach_lerp', '_foreach_lerp_', '_foreach_lgamma', '_foreach_lgamma_', + '_foreach_log', '_foreach_log10', '_foreach_log10_', '_foreach_log1p', '_foreach_log1p_', + '_foreach_log2', '_foreach_log2_', '_foreach_log_', '_foreach_maximum', '_foreach_maximum_', + '_foreach_minimum', '_foreach_minimum_', '_foreach_mul', '_foreach_mul_', '_foreach_neg', + '_foreach_neg_', '_foreach_norm', '_foreach_pow', '_foreach_pow_', '_foreach_reciprocal', + '_foreach_reciprocal_', '_foreach_round', '_foreach_round_', '_foreach_sigmoid', + '_foreach_sigmoid_', '_foreach_sign', '_foreach_sign_', '_foreach_sin', '_foreach_sin_', + '_foreach_sinh', '_foreach_sinh_', '_foreach_sqrt', '_foreach_sqrt_', '_foreach_sub', + '_foreach_sub_', '_foreach_tan', '_foreach_tan_', '_foreach_tanh', '_foreach_tanh_', + '_foreach_trunc', '_foreach_trunc_', '_foreach_zero_', '_from_functional_tensor', + '_functional_assert_async', '_functional_assert_scalar', '_functional_sym_constrain_range', + '_functional_sym_constrain_range_for_size', + '_functionalize_are_all_mutations_hidden_from_autograd', + 
'_functionalize_are_all_mutations_under_no_grad_or_inference_mode', '_functionalize_commit_update', + '_functionalize_mark_mutation_hidden_from_autograd', '_functionalize_replace', + '_functionalize_sync', '_fused_adam_', '_fused_adamw_', '_fused_dropout', + '_fused_moving_avg_obs_fq_helper', '_fused_moving_avg_obs_fq_helper', '_fused_sdp_choice', + '_fused_sgd_', '_fw_primal_copy', '_grid_sampler_2d_cpu_fallback', + '_has_compatible_shallow_copy_type', '_histogramdd_bin_edges', '_histogramdd_from_bin_cts', + '_histogramdd_from_bin_tensors', '_index_put_impl_', '_indices_copy', '_int_mm', '_is_all_true', + '_is_any_true', '_is_functional_tensor', '_is_zerotensor', '_lazy_clone', '_linalg_check_errors', + '_linalg_det', '_linalg_det', '_linalg_eigh', '_linalg_eigh', '_linalg_slogdet', '_linalg_slogdet', + '_linalg_solve_ex', '_linalg_solve_ex', '_linalg_svd', '_linalg_svd', '_log_softmax', + '_log_softmax_backward_data', '_logcumsumexp', '_lstm_mps', '_lu_with_info', '_lu_with_info', + '_make_dep_token', '_make_dual', '_make_dual_copy', '_make_per_channel_quantized_tensor', + '_make_per_tensor_quantized_tensor', '_masked_scale', '_masked_softmax', '_mixed_dtypes_linear', + '_mkldnn_reshape', '_mkldnn_transpose', '_mkldnn_transpose_', '_mps_convolution', + '_mps_convolution_transpose', '_native_batch_norm_legit', '_native_batch_norm_legit_no_training', + '_native_multi_head_attention', '_neg_view', '_neg_view_copy', '_nested_from_padded', + '_nested_from_padded_and_nested_example', '_nested_get_jagged_dummy', '_nested_get_lengths', + '_nested_get_offsets', '_nested_get_ragged_idx', '_nested_get_values', '_nested_get_values_copy', + '_nested_tensor_from_mask', '_nested_tensor_from_mask_left_aligned', + '_nested_tensor_from_tensor_list', '_nested_tensor_softmax_with_shape', '_nested_view_from_buffer', + '_nested_view_from_buffer_copy', '_nested_view_from_jagged', '_nested_view_from_jagged_copy', + '_nnpack_available', '_nnpack_spatial_convolution', '_pack_padded_sequence', + '_pad_packed_sequence', '_pin_memory', '_prelu_kernel', '_print', '_propagate_xla_data', + '_remove_batch_dim', '_reshape_alias_copy', '_reshape_from_tensor', '_resize_output_', + '_rowwise_prune', '_sample_dirichlet', '_saturate_weight_to_fp16', + '_scaled_dot_product_attention_math', '_scaled_dot_product_cudnn_attention', + '_scaled_dot_product_cudnn_attention', '_scaled_dot_product_efficient_attention', + '_scaled_dot_product_efficient_attention', '_scaled_dot_product_flash_attention', + '_scaled_dot_product_flash_attention', '_scaled_dot_product_flash_attention_for_cpu', + '_scaled_dot_product_flash_attention_for_cpu', '_scaled_mm', '_shape_as_tensor', + '_sobol_engine_draw', '_sobol_engine_ff_', '_sobol_engine_initialize_state_', + '_sobol_engine_scramble_', '_softmax', '_softmax_backward_data', '_sparse_broadcast_to', + '_sparse_broadcast_to_copy', '_sparse_csr_prod', '_sparse_csr_sum', + '_sparse_log_softmax_backward_data', '_sparse_semi_structured_linear', + '_sparse_softmax_backward_data', '_sparse_sparse_matmul', '_sparse_sum', '_stack', + '_standard_gamma', '_standard_gamma_grad', '_sync', '_test_autograd_multiple_dispatch', + '_test_autograd_multiple_dispatch_view', '_test_autograd_multiple_dispatch_view_copy', + '_test_check_tensor', '_test_functorch_fallback', '_test_parallel_materialize', + '_test_serialization_subcmul', '_to_cpu', '_to_functional_tensor', '_to_sparse_semi_structured', + '_transform_bias_rescale_qkv', '_transformer_encoder_layer_fwd', '_trilinear', + '_triton_multi_head_attention', 
'_triton_scaled_dot_attention', '_unique', '_unique2', + '_unpack_dual', '_unpack_dual', '_unsafe_index', '_unsafe_index_put', '_use_cudnn_ctc_loss', + '_use_cudnn_rnn_flatten_weight', '_validate_compressed_sparse_indices', + '_validate_sparse_bsc_tensor_args', '_validate_sparse_bsr_tensor_args', + '_validate_sparse_compressed_tensor_args', '_validate_sparse_coo_tensor_args', + '_validate_sparse_csc_tensor_args', '_validate_sparse_csr_tensor_args', '_values_copy', + '_weight_int4pack_mm', '_weight_int8pack_mm', '_weight_norm', '_weight_norm_interface', 'abs', + 'abs_', 'absolute', 'acos', 'acos_', 'acosh', 'acosh_', 'adaptive_avg_pool1d', + 'adaptive_max_pool1d', 'add', 'addbmm', 'addcdiv', 'addcmul', 'addmm', 'addmv', 'addmv_', 'addr', + 'adjoint', 'affine_grid_generator', 'alias_copy', 'all', 'allclose', 'alpha_dropout', + 'alpha_dropout_', 'amax', 'amin', 'aminmax', 'aminmax', 'angle', 'any', 'arange', 'arccos', + 'arccos_', 'arccosh', 'arccosh_', 'arcsin', 'arcsin_', 'arcsinh', 'arcsinh_', 'arctan', 'arctan2', + 'arctan_', 'arctanh', 'arctanh_', 'argmax', 'argmin', 'argsort', 'argwhere', 'as_strided', + 'as_strided_', 'as_strided_copy', 'as_strided_scatter', 'as_tensor', 'asarray', 'asin', 'asin_', + 'asinh', 'asinh_', 'atan', 'atan2', 'atan_', 'atanh', 'atanh_', 'avg_pool1d', 'baddbmm', + 'bartlett_window', 'batch_norm', 'batch_norm_backward_elemt', 'batch_norm_backward_reduce', + 'batch_norm_elemt', 'batch_norm_gather_stats', 'batch_norm_gather_stats_with_counts', + 'batch_norm_stats', 'batch_norm_update_stats', 'bernoulli', 'bilinear', + 'binary_cross_entropy_with_logits', 'bincount', 'binomial', 'bitwise_and', 'bitwise_left_shift', + 'bitwise_not', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'blackman_window', 'bmm', + 'broadcast_to', 'bucketize', 'can_cast', 'cat', 'ccol_indices_copy', 'ceil', 'ceil_', 'celu', + 'celu_', 'channel_shuffle', 'cholesky', 'cholesky_inverse', 'cholesky_solve', + 'choose_qparams_optimized', 'chunk', 'clamp', 'clamp_', 'clamp_max', 'clamp_max_', 'clamp_min', + 'clamp_min_', 'clip', 'clip_', 'clone', 'col_indices_copy', 'column_stack', 'combinations', + 'complex', 'concat', 'concatenate', 'conj', 'conj_physical', 'conj_physical_', 'constant_pad_nd', + 'conv1d', 'conv2d', 'conv3d', 'conv_tbc', 'conv_transpose1d', 'conv_transpose2d', + 'conv_transpose3d', 'convolution', 'copysign', 'corrcoef', 'cos', 'cos_', 'cosh', 'cosh_', + 'cosine_embedding_loss', 'cosine_similarity', 'count_nonzero', 'cov', 'cross', 'crow_indices_copy', + 'ctc_loss', 'cudnn_affine_grid_generator', 'cudnn_batch_norm', 'cudnn_convolution', + 'cudnn_convolution_add_relu', 'cudnn_convolution_relu', 'cudnn_convolution_transpose', + 'cudnn_grid_sampler', 'cudnn_is_acceptable', 'cummax', 'cummax', 'cummin', 'cummin', 'cumprod', + 'cumsum', 'cumulative_trapezoid', 'deg2rad', 'deg2rad_', 'dequantize', 'det', 'detach', 'detach_', + 'detach_copy', 'diag', 'diag_embed', 'diagflat', 'diagonal', 'diagonal_copy', 'diagonal_scatter', + 'diff', 'digamma', 'dist', 'div', 'divide', 'dot', 'dropout', 'dropout_', 'dsmm', 'dsplit', + 'dstack', 'embedding', 'embedding_bag', 'embedding_renorm_', 'empty', 'empty_like', + 'empty_permuted', 'empty_quantized', 'empty_strided', 'eq', 'equal', 'erf', 'erf_', 'erfc', + 'erfc_', 'erfinv', 'exp', 'exp2', 'exp2_', 'exp_', 'expand_copy', 'expm1', 'expm1_', 'eye', + 'fake_quantize_per_channel_affine', 'fake_quantize_per_tensor_affine', 'fbgemm_linear_fp16_weight', + 'fbgemm_linear_fp16_weight_fp32_activation', 'fbgemm_linear_int8_weight', + 
'fbgemm_linear_int8_weight_fp32_activation', 'fbgemm_linear_quantize_weight', + 'fbgemm_pack_gemm_matrix_fp16', 'fbgemm_pack_quantized_matrix', 'feature_alpha_dropout', + 'feature_alpha_dropout_', 'feature_dropout', 'feature_dropout_', 'fill', 'fill_', 'fix', 'fix_', + 'flatten', 'flip', 'fliplr', 'flipud', 'float_power', 'floor', 'floor_', 'floor_divide', 'fmax', + 'fmin', 'fmod', 'frac', 'frac_', 'frexp', 'frexp', 'frobenius_norm', 'from_file', 'from_numpy', + 'frombuffer', 'full', 'full_like', 'fused_moving_avg_obs_fake_quant', 'gather', 'gcd', 'gcd_', + 'ge', 'geqrf', 'geqrf', 'ger', 'get_default_dtype', 'get_num_interop_threads', 'get_num_threads', + 'gradient', 'greater', 'greater_equal', 'grid_sampler', 'grid_sampler_2d', 'grid_sampler_3d', + 'group_norm', 'gru', 'gru_cell', 'gt', 'hamming_window', 'hann_window', 'hardshrink', 'heaviside', + 'hinge_embedding_loss', 'histc', 'histogram', 'histogram', 'histogramdd', 'histogramdd', 'hsmm', + 'hsplit', 'hspmm', 'hstack', 'hypot', 'i0', 'i0_', 'igamma', 'igammac', 'imag', 'index_add', + 'index_copy', 'index_fill', 'index_put', 'index_put_', 'index_reduce', 'index_select', + 'indices_copy', 'init_num_threads', 'inner', 'instance_norm', 'int_repr', 'inverse', 'is_complex', + 'is_conj', 'is_distributed', 'is_floating_point', 'is_grad_enabled', 'is_inference', + 'is_inference_mode_enabled', 'is_neg', 'is_nonzero', 'is_same_size', 'is_signed', + 'is_vulkan_available', 'isclose', 'isfinite', 'isin', 'isinf', 'isnan', 'isneginf', 'isposinf', + 'isreal', 'istft', 'kaiser_window', 'kl_div', 'kron', 'kthvalue', 'kthvalue', 'layer_norm', 'lcm', + 'lcm_', 'ldexp', 'ldexp_', 'le', 'lerp', 'less', 'less_equal', 'lgamma', 'linspace', 'log', + 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_softmax', 'logaddexp', + 'logaddexp2', 'logcumsumexp', 'logdet', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', + 'logit', 'logit_', 'logspace', 'logsumexp', 'lstm', 'lstm_cell', 'lt', 'lu_solve', 'lu_unpack', + 'lu_unpack', 'margin_ranking_loss', 'masked_fill', 'masked_scatter', 'masked_select', 'matmul', + 'matrix_exp', 'matrix_power', 'max', 'max', 'max_pool1d', 'max_pool1d_with_indices', 'max_pool2d', + 'max_pool3d', 'maximum', 'mean', 'median', 'median', 'min', 'min', 'minimum', 'miopen_batch_norm', + 'miopen_convolution', 'miopen_convolution_add_relu', 'miopen_convolution_relu', + 'miopen_convolution_transpose', 'miopen_depthwise_convolution', 'miopen_rnn', + 'mkldnn_adaptive_avg_pool2d', 'mkldnn_convolution', 'mkldnn_linear_backward_weights', + 'mkldnn_max_pool2d', 'mkldnn_max_pool3d', 'mkldnn_rnn_layer', 'mm', 'mode', 'mode', 'moveaxis', + 'movedim', 'msort', 'mul', 'multinomial', 'multiply', 'mv', 'mvlgamma', 'nan_to_num', + 'nan_to_num_', 'nanmean', 'nanmedian', 'nanmedian', 'nanquantile', 'nansum', 'narrow', + 'narrow_copy', 'native_batch_norm', 'native_channel_shuffle', 'native_dropout', + 'native_group_norm', 'native_layer_norm', 'native_norm', 'ne', 'neg', 'neg_', 'negative', + 'negative_', 'nextafter', 'nonzero', 'nonzero_static', 'norm_except_dim', 'normal', 'not_equal', + 'nuclear_norm', 'numel', 'ones', 'ones_like', 'orgqr', 'ormqr', 'outer', 'pairwise_distance', + 'pdist', 'permute', 'permute_copy', 'pinverse', 'pixel_shuffle', 'pixel_unshuffle', 'poisson', + 'poisson_nll_loss', 'polar', 'polygamma', 'positive', 'pow', 'prelu', 'prod', 'promote_types', + 'put', 'q_per_channel_axis', 'q_per_channel_scales', 'q_per_channel_zero_points', 'q_scale', + 'q_zero_point', 'qr', 'qr', 'quantile', 'quantize_per_channel', 
'quantize_per_tensor', + 'quantize_per_tensor_dynamic', 'quantized_batch_norm', 'quantized_gru_cell', 'quantized_lstm_cell', + 'quantized_max_pool1d', 'quantized_max_pool2d', 'quantized_max_pool3d', 'quantized_rnn_relu_cell', + 'quantized_rnn_tanh_cell', 'rad2deg', 'rad2deg_', 'rand', 'rand_like', 'randint', 'randint_like', + 'randn', 'randn_like', 'randperm', 'range', 'ravel', 'real', 'reciprocal', 'reciprocal_', 'relu', + 'relu_', 'remainder', 'renorm', 'repeat_interleave', 'reshape', 'resize_as_', 'resize_as_sparse_', + 'resolve_conj', 'resolve_neg', 'result_type', 'rnn_relu', 'rnn_relu_cell', 'rnn_tanh', + 'rnn_tanh_cell', 'roll', 'rot90', 'round', 'round_', 'row_indices_copy', 'row_stack', 'rrelu', + 'rrelu_', 'rsqrt', 'rsqrt_', 'rsub', 'saddmm', 'scalar_tensor', 'scatter', 'scatter_add', + 'scatter_reduce', 'searchsorted', 'segment_reduce', 'select', 'select_copy', 'select_scatter', + 'selu', 'selu_', 'set_flush_denormal', 'set_num_interop_threads', 'set_num_threads', 'sgn', + 'sigmoid', 'sigmoid_', 'sign', 'signbit', 'sin', 'sin_', 'sinc', 'sinc_', 'sinh', 'sinh_', + 'slice_copy', 'slice_inverse', 'slice_scatter', 'slogdet', 'slogdet', 'smm', 'softmax', 'sort', + 'sort', 'sparse_bsc_tensor', 'sparse_bsr_tensor', 'sparse_compressed_tensor', 'sparse_coo_tensor', + 'sparse_csc_tensor', 'sparse_csr_tensor', 'split_copy', 'split_with_sizes', + 'split_with_sizes_copy', 'spmm', 'sqrt', 'sqrt_', 'square', 'square_', 'squeeze', 'squeeze_copy', + 'sspaddmm', 'stack', 'std', 'std_mean', 'sub', 'subtract', 'sum', 'svd', 'svd', 'swapaxes', + 'swapdims', 'sym_constrain_range', 'sym_constrain_range_for_size', 't', 't_copy', 'take', + 'take_along_dim', 'tan', 'tan_', 'tanh', 'tanh_', 'tensor', 'tensor_split', 'threshold', + 'threshold_', 'tile', 'topk', 'topk', 'trace', 'transpose', 'transpose_copy', 'trapezoid', 'trapz', + 'triangular_solve', 'triangular_solve', 'tril', 'tril_indices', 'triplet_margin_loss', 'triu', + 'triu_indices', 'true_divide', 'trunc', 'trunc_', 'unbind', 'unbind_copy', 'unflatten', + 'unfold_copy', 'unique_dim', 'unsafe_chunk', 'unsafe_split', 'unsafe_split_with_sizes', + 'unsqueeze', 'unsqueeze_copy', 'values_copy', 'vander', 'var', 'var_mean', 'vdot', + 'view_as_complex', 'view_as_complex_copy', 'view_as_real', 'view_as_real_copy', 'view_copy', + 'vsplit', 'vstack', 'where', 'xlogy', 'xlogy_', 'zero_', 'zeros', 'zeros_like'] diff --git a/MLPY/Lib/site-packages/torch/__config__.py b/MLPY/Lib/site-packages/torch/__config__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8cf5710d77a2c2a6e871006b7803f68c85aa7d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/__config__.py @@ -0,0 +1,22 @@ +import torch + + +def show(): + """ + Return a human-readable string with descriptions of the + configuration of PyTorch. + """ + return torch._C._show_config() + + +# TODO: In principle, we could provide more structured version/config +# information here. For now only CXX_FLAGS is exposed, as Timer +# uses them. 
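+# A hedged usage sketch for the helpers in this module (``show`` above, ``parallel_info`` below);
+# the printed output is build-dependent, so no exact text is assumed:
+#
+#     >>> import torch
+#     >>> print(torch.__config__.show())           # build flags, BLAS/LAPACK backends, etc.
+#     >>> print(torch.__config__.parallel_info())  # OpenMP / thread-pool settings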
+def _cxx_flags(): + """Returns the CXX_FLAGS used when building PyTorch.""" + return torch._C._cxx_flags() + + +def parallel_info(): + r"""Returns detailed string with parallelization settings""" + return torch._C._parallel_info() diff --git a/MLPY/Lib/site-packages/torch/_appdirs.py b/MLPY/Lib/site-packages/torch/_appdirs.py new file mode 100644 index 0000000000000000000000000000000000000000..13db32eea62e50c360c651fbbfd9dfff0124cdc8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_appdirs.py @@ -0,0 +1,666 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2005-2010 ActiveState Software Inc. +# Copyright (c) 2013 Eddy Petrișor + +# flake8: noqa + +""" +This file is directly from +https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py + +The license of https://github.com/ActiveState/appdirs copied below: + + +# This is the MIT license + +Copyright (c) 2010 ActiveState Software Inc. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +"""Utilities for determining application-specific dirs. + +See for details and usage. +""" +# Dev Notes: +# - MSDN on where to store app data files: +# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120 +# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html +# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html + +__version__ = "1.4.4" +__version_info__ = tuple(int(segment) for segment in __version__.split(".")) + + +import os +import sys + +unicode = str + +if sys.platform.startswith("java"): + import platform + + os_name = platform.java_ver()[3][0] + if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc. + system = "win32" + elif os_name.startswith("Mac"): # "Mac OS X", etc. + system = "darwin" + else: # "Linux", "SunOS", "FreeBSD", etc. + # Setting this to "linux2" is not ideal, but only Windows or Mac + # are actually checked for and the rest of the module expects + # *sys.platform* style strings. + system = "linux2" +else: + system = sys.platform + + +def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. 
You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user data directories are: + Mac OS X: ~/Library/Application Support/ + Unix: ~/.local/share/ # or in $XDG_DATA_HOME, if defined + Win XP (not roaming): C:\Documents and Settings\\Application Data\\ + Win XP (roaming): C:\Documents and Settings\\Local Settings\Application Data\\ + Win 7 (not roaming): C:\Users\\AppData\Local\\ + Win 7 (roaming): C:\Users\\AppData\Roaming\\ + + For Unix, we follow the XDG spec and support $XDG_DATA_HOME. + That means, by default "~/.local/share/". + """ + if system == "win32": + if appauthor is None: + appauthor = appname + const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA" + path = os.path.normpath(_get_win_folder(const)) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == "darwin": + path = os.path.expanduser("~/Library/Application Support/") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_data_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of data dirs should be + returned. By default, the first item from XDG_DATA_DIRS is + returned, or '/usr/local/share/', + if XDG_DATA_DIRS is not set + + Typical site data directories are: + Mac OS X: /Library/Application Support/ + Unix: /usr/local/share/ or /usr/share/ + Win XP: C:\Documents and Settings\All Users\Application Data\\ + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + Win 7: C:\ProgramData\\ # Hidden, but writeable on Win 7. + + For Unix, this is using the $XDG_DATA_DIRS[0] default. + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. 
+ """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == "darwin": + path = os.path.expanduser("/Library/Application Support") + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_DATA_DIRS + # only first, if multipath is False + path = os.getenv( + "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"]) + ) + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + if appname and version: + path = os.path.join(path, version) + return path + + +def user_config_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific config dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user config directories are: + Mac OS X: ~/Library/Preferences/ + Unix: ~/.config/ # or in $XDG_CONFIG_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME. + That means, by default "~/.config/". + """ + if system == "win32": + path = user_data_dir(appname, appauthor, None, roaming) + elif system == "darwin": + path = os.path.expanduser("~/Library/Preferences/") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_config_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of config dirs should be + returned. 
By default, the first item from XDG_CONFIG_DIRS is + returned, or '/etc/xdg/', if XDG_CONFIG_DIRS is not set + + Typical site config directories are: + Mac OS X: same as site_data_dir + Unix: /etc/xdg/ or $XDG_CONFIG_DIRS[i]/ for each value in + $XDG_CONFIG_DIRS + Win *: same as site_data_dir + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + + For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. + """ + if system == "win32": + path = site_data_dir(appname, appauthor) + if appname and version: + path = os.path.join(path, version) + elif system == "darwin": + path = os.path.expanduser("/Library/Preferences") + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_CONFIG_DIRS + # only first, if multipath is False + path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg") + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + +def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific cache dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Cache" to the base app data dir for Windows. See + discussion below. + + Typical user cache directories are: + Mac OS X: ~/Library/Caches/ + Unix: ~/.cache/ (XDG default) + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Cache + Vista: C:\Users\\AppData\Local\\\Cache + + On Windows the only suggestion in the MSDN docs is that local settings go in + the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming + app data dir (the default returned by `user_data_dir` above). Apps typically + put cache data somewhere *under* the given dir here. Some examples: + ...\Mozilla\Firefox\Profiles\\Cache + ...\Acme\SuperApp\Cache\1.0 + OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value. + This can be disabled with the `opinion=False` option. 
+ """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + if opinion: + path = os.path.join(path, "Cache") + elif system == "darwin": + path = os.path.expanduser("~/Library/Caches") + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_state_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific state dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user state directories are: + Mac OS X: same as user_data_dir + Unix: ~/.local/state/ # or in $XDG_STATE_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow this Debian proposal + to extend the XDG spec and support $XDG_STATE_HOME. + + That means, by default "~/.local/state/". + """ + if system in ["win32", "darwin"]: + path = user_data_dir(appname, appauthor, None, roaming) + else: + path = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific log dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Logs" to the base app data dir for Windows, and "log" to the + base cache dir for Unix. See discussion below. + + Typical user log directories are: + Mac OS X: ~/Library/Logs/ + Unix: ~/.cache//log # or under $XDG_CACHE_HOME if defined + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Logs + Vista: C:\Users\\AppData\Local\\\Logs + + On Windows the only suggestion in the MSDN docs is that local settings + go in the `CSIDL_LOCAL_APPDATA` directory. 
(Note: I'm interested in + examples of what some windows apps use for a logs dir.) + + OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA` + value for Windows and appends "log" to the user cache dir for Unix. + This can be disabled with the `opinion=False` option. + """ + if system == "darwin": + path = os.path.join(os.path.expanduser("~/Library/Logs"), appname) + elif system == "win32": + path = user_data_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "Logs") + else: + path = user_cache_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "log") + if appname and version: + path = os.path.join(path, version) + return path + + +class AppDirs(object): + """Convenience wrapper for getting application dirs.""" + + def __init__( + self, appname=None, appauthor=None, version=None, roaming=False, multipath=False + ): + self.appname = appname + self.appauthor = appauthor + self.version = version + self.roaming = roaming + self.multipath = multipath + + @property + def user_data_dir(self): + return user_data_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) + + @property + def site_data_dir(self): + return site_data_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) + + @property + def user_config_dir(self): + return user_config_dir( + self.appname, self.appauthor, version=self.version, roaming=self.roaming + ) + + @property + def site_config_dir(self): + return site_config_dir( + self.appname, self.appauthor, version=self.version, multipath=self.multipath + ) + + @property + def user_cache_dir(self): + return user_cache_dir(self.appname, self.appauthor, version=self.version) + + @property + def user_state_dir(self): + return user_state_dir(self.appname, self.appauthor, version=self.version) + + @property + def user_log_dir(self): + return user_log_dir(self.appname, self.appauthor, version=self.version) + + +# ---- internal support stuff + + +def _get_win_folder_from_registry(csidl_name): + """This is a fallback technique at best. I'm not sure if using the + registry for this guarantees us the correct answer for all CSIDL_* + names. + """ + import winreg as _winreg + + shell_folder_name = { + "CSIDL_APPDATA": "AppData", + "CSIDL_COMMON_APPDATA": "Common AppData", + "CSIDL_LOCAL_APPDATA": "Local AppData", + }[csidl_name] + + key = _winreg.OpenKey( + _winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir, type = _winreg.QueryValueEx(key, shell_folder_name) + return dir + + +def _get_win_folder_with_pywin32(csidl_name): + from win32com.shell import shell, shellcon + + dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0) + # Try to make this a unicode path because SHGetFolderPath does + # not return unicode strings when there is unicode data in the + # path. + try: + dir = unicode(dir) + + # Downgrade to short path name if have highbit chars. See + # . 
+ has_high_char = False + for c in dir: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + try: + import win32api + + dir = win32api.GetShortPathName(dir) + except ImportError: + pass + except UnicodeError: + pass + return dir + + +def _get_win_folder_with_ctypes(csidl_name): + import ctypes + + csidl_const = { + "CSIDL_APPDATA": 26, + "CSIDL_COMMON_APPDATA": 35, + "CSIDL_LOCAL_APPDATA": 28, + }[csidl_name] + + buf = ctypes.create_unicode_buffer(1024) + ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in buf: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf2 = ctypes.create_unicode_buffer(1024) + if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): + buf = buf2 + + return buf.value + + +def _get_win_folder_with_jna(csidl_name): + import array + + from com.sun import jna + from com.sun.jna.platform import win32 + + buf_size = win32.WinDef.MAX_PATH * 2 + buf = array.zeros("c", buf_size) + shell = win32.Shell32.INSTANCE + shell.SHGetFolderPath( + None, + getattr(win32.ShlObj, csidl_name), + None, + win32.ShlObj.SHGFP_TYPE_CURRENT, + buf, + ) + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in dir: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf = array.zeros("c", buf_size) + kernel = win32.Kernel32.INSTANCE + if kernel.GetShortPathName(dir, buf, buf_size): + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + return dir + + +if system == "win32": + try: + import win32com.shell + + _get_win_folder = _get_win_folder_with_pywin32 + except ImportError: + try: + from ctypes import windll + + _get_win_folder = _get_win_folder_with_ctypes + except ImportError: + try: + import com.sun.jna + + _get_win_folder = _get_win_folder_with_jna + except ImportError: + _get_win_folder = _get_win_folder_from_registry + + +# ---- self test code + +if __name__ == "__main__": + appname = "MyApp" + appauthor = "MyCompany" + + props = ( + "user_data_dir", + "user_config_dir", + "user_cache_dir", + "user_state_dir", + "user_log_dir", + "site_data_dir", + "site_config_dir", + ) + + print(f"-- app dirs {__version__} --") + + print("-- app dirs (with optional 'version')") + dirs = AppDirs(appname, appauthor, version="1.0") + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (without optional 'version')") + dirs = AppDirs(appname, appauthor) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (without optional 'appauthor')") + dirs = AppDirs(appname) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") + + print("\n-- app dirs (with disabled 'appauthor')") + dirs = AppDirs(appname, appauthor=False) + for prop in props: + print(f"{prop}: {getattr(dirs, prop)}") diff --git a/MLPY/Lib/site-packages/torch/_awaits/__init__.py b/MLPY/Lib/site-packages/torch/_awaits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9b1fef2960fcc66be6b43ba0e0d92856a799f5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_awaits/__init__.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import cast, Callable, Generic, Type, TypeVar + +import torch + +__all__ = ['Await'] + +W = TypeVar("W") + +class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef] + pass + 
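+# A hedged sketch of the eager-mode usage documented by the ``_Await`` class below; it relies
+# only on the ``torch.jit._awaitable`` / ``torch.jit._awaitable_wait`` helpers named in that
+# docstring, and the concrete function ``fn`` is a made-up example:
+#
+#     >>> def fn(x: int) -> int:
+#     ...     return x + 1
+#     >>> aw = torch.jit._awaitable(fn, 2)   # Await[int]; execution is delayed
+#     >>> torch.jit._awaitable_wait(aw)      # runs fn(2) and caches the result
+#     3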
+class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta): + r""" + Wrapper around a ``torch._C.Await`` which encapsulates delayed execution + of a callable. All manipulations happen with functions ``torch.jit._awaitable``, + ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``. + + Torch scriptable manipulations: + ``torch.jit._awaitable(func, *args)`` + Creates ``Await[W]`` object, where W is return type of func. + + Returns: + ``torch.jit._awaitable_wait(Await[W])`` + Returns the result of the function, specified at ``_awaitable``, with specified arguments. + + Returns: + The result of type ``W`` of the function call. The result is owned by ``Await[W]`` + and returned on all following ``_awaitable_wait`` calls. + + + ``torch.jit._awaitable_nowait(W)`` + Returns: + Trivial ``Await[W]`` with specified result. + + + Only in eager mode: + ``fn() -> Callable[Tuple[Any], W]`` + Returns: + Specified at ``_awaitable`` python function ``func``. + + ``args() -> Tuple[Any]`` + Returns: + Specified at ``_awaitable`` python args. + + ``is_nowait() -> _bool`` + Returns: + ``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`). + + In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``, + ``_awaitable_wait()`` call will be transparently added. + """ + pass diff --git a/MLPY/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c4372f600412b883e4b9823c9f9f435e9c687a3 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_classes.py b/MLPY/Lib/site-packages/torch/_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..1d160312d883b2081317c3bf013ea5c4604614f9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_classes.py @@ -0,0 +1,55 @@ +import types + +import torch._C + + +class _ClassNamespace(types.ModuleType): + def __init__(self, name): + super().__init__("torch.classes" + name) + self.name = name + + def __getattr__(self, attr): + proxy = torch._C._get_custom_class_python_wrapper(self.name, attr) + if proxy is None: + raise RuntimeError(f"Class {self.name}.{attr} not registered!") + return proxy + + +class _Classes(types.ModuleType): + __file__ = "_classes.py" + + def __init__(self): + super().__init__("torch.classes") + + def __getattr__(self, name): + namespace = _ClassNamespace(name) + setattr(self, name, namespace) + return namespace + + @property + def loaded_libraries(self): + return torch.ops.loaded_libraries + + def load_library(self, path): + """ + Loads a shared library from the given path into the current process. + + The library being loaded may run global initialization code to register + custom classes with the PyTorch JIT runtime. This allows dynamically + loading custom classes. For this, you should compile your class + and the static registration code into a shared library object, and then + call ``torch.classes.load_library('path/to/libcustom.so')`` to load the + shared object. + + After the library is loaded, it is added to the + ``torch.classes.loaded_libraries`` attribute, a set that may be inspected + for the paths of all libraries loaded using this function. + + Args: + path (str): A path to a shared library to load. 
+ """ + torch.ops.load_library(path) + + +# The classes "namespace" +classes = _Classes() diff --git a/MLPY/Lib/site-packages/torch/_compile.py b/MLPY/Lib/site-packages/torch/_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..576d4218c4c49cf55a34efd68198dd021f4ba7dc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_compile.py @@ -0,0 +1,30 @@ +""" +APIs related to torch.compile which lazily import torch._dynamo to avoid +circular dependencies. +""" +import functools + + +def _disable_dynamo(fn=None, recursive=True): + """ + This API should be only used inside torch, external users should still use + torch._dynamo.disable. The main goal of this API is to avoid circular + imports issues that is common while using _dynamo.disable inside torch + itself. + + This API avoids it by lazily importing torch._dynamo from the import time to + the invocation of the decorated function. + """ + if fn is not None: + + @functools.wraps(fn) + def inner(*args, **kwargs): + import torch._dynamo + + return torch._dynamo.disable(fn, recursive)(*args, **kwargs) + + return inner + else: + # decorator usage like @_disable_dynamo(recursive=False). The resulting + # object expects the original decorated function as the arg. + return functools.partial(_disable_dynamo, recursive=recursive) diff --git a/MLPY/Lib/site-packages/torch/_custom_op/__init__.py b/MLPY/Lib/site-packages/torch/_custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c603890cf093be83d8ecc39bd24bcb647bdd4d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ac24eb64d0272731f0947172c19cd171def9f1d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/functional.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628deaa6fee76d50c279406f89358802fe352b2c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/functional.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c8d06dfef35ab53fad39c9d20c230f6f443bc6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_custom_op/autograd.py b/MLPY/Lib/site-packages/torch/_custom_op/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..6be5ab372b96a203e485a4465ab984712f1b6380 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_custom_op/autograd.py @@ -0,0 +1,274 @@ +import torch +import torch.utils._pytree as pytree +from collections import namedtuple +import functools + + +# NOTE [CustomOp autograd kernel 
indirection] +# We register `inner` as the autograd kernel for this custom_op. +# `inner` either calls the autograd formula registered by the user, +# or goes into an `autograd_not_implemented` kernel. +# +# The reason why this indirection exists is +# so that we can swap out the autograd kernel (the PyTorch dispatcher +# doesn't actually allow us to do this). By default, we want +# the `autograd_not_implemented` behavior, but then the user may come +# and register something that is actually a backward formula +def autograd_kernel_indirection(custom_op): + autograd_fallback = autograd_not_implemented(custom_op) + + def inner(*args, **kwargs): + if custom_op._has_impl('autograd'): + kernel = custom_op._get_impl('autograd').func + return kernel(*args, **kwargs) + # As explained in NOTE ["backward", "save_for_backward", and "autograd"], + # after the user gives us "backward" and "save_for_backward", we generate + # the "autograd" impl. If the user only provided one, then we tell + # the user they've done something wrong. + if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'): + missing = ( + 'save_for_backward' if custom_op._has_impl('backward') + else 'backward' + ) + found = 'save_for_backward' if missing == 'backward' else 'backward' + loc = custom_op._get_impl(found).location + raise RuntimeError( + f"We found a '{found}' registration for {custom_op} at " + f"{loc} but were unable to find a '{missing}' registration. " + f"To use the CustomOp API to register a backward formula, " + f"please provide us both a backward function and a " + f"'save for backward' function via `impl_backward` and " + f"`impl_save_for_backward` respectively.") + return autograd_fallback(*args, **kwargs) + return inner + + +# TODO(#101191): Use the actual C++ autograd not implemented fallback, +# or change the default autograd fallback to the autograd not implemented fallback. +def autograd_not_implemented(custom_op): + def kernel(*args, **kwargs): + if torch.is_grad_enabled() and pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) + ): + raise RuntimeError("Autograd has not been implemented for operator") + with torch._C._AutoDispatchBelowAutograd(): + return custom_op(*args, **kwargs) + return kernel + + +def mark_non_differentiable(ctx, output, output_differentiability): + # Output types are restricted to be: + # - Tensor + # - Tensor[] + # - int, bool, Scalar, float + # See _check_can_register_backward + if output_differentiability is not None: + if not isinstance(output, tuple): + tuple_output = (output,) + else: + tuple_output = output # type: ignore[assignment] + assert len(output_differentiability) == len(tuple_output) + non_differentiable_tensors = [] + for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)): + if isinstance(out, torch.Tensor): + if not differentiable: + non_differentiable_tensors.append(out) + continue + if isinstance(out, list): + if not differentiable: + non_differentiable_tensors.extend(out) + continue + if differentiable: + raise RuntimeError( + f"With output_differentiability={output_differentiability}. 
" + f"At idx {idx}, we received an object of type {type(out)} that " + f"is not a Tensor, so it cannot have be marked as differentiable in " + f"output_differentiability.") + if non_differentiable_tensors: + ctx.mark_non_differentiable(*non_differentiable_tensors) + + +def construct_autograd_kernel( + schema, + output_differentiability, + custom_op, + op_overload, + save_for_backward_fn, + backward_fn): + + def apply(*args): + flat_args, spec = pytree.tree_flatten(args) + out_spec = None + + def forward(ctx, *flat_args): + ctx.set_materialize_grads(True) + args = pytree.tree_unflatten(list(flat_args), spec) + with torch._C._AutoDispatchBelowAutograd(): + output = op_overload(*args) + + # We use the info about args to give better error messages in backward + args_info = namedtuple_args( + schema, pytree.tree_map(type, args)) + + save_for_backward_fn_inputs = namedtuple_args(schema, args) + to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) + + save_pytree_for_backward(ctx, (to_save, args_info)) + mark_non_differentiable(ctx, output, output_differentiability) + + nonlocal out_spec + flat_output, out_spec = pytree.tree_flatten(output) + return tuple(flat_output) + + def backward(ctx, *flat_grad_output): + assert out_spec is not None + grads = pytree.tree_unflatten(list(flat_grad_output), out_spec) + saved, args_info = unpack_saved(ctx) + # There is nothing on the ctx object for now, it is just there so + # that we can add additional things in the future. + inner_ctx = object() + if not isinstance(grads, tuple): + grads = (grads,) + grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) + + # Massage the grad_inputs_dict to a form acceptable by + # autograd.Function. + validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info) + return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) + + generated_cls = gen_autograd_function( + custom_op._opname + '_customop', forward, backward) + + flat_output = generated_cls.apply(*flat_args) + assert out_spec is not None + return pytree.tree_unflatten(list(flat_output), out_spec) + return apply + + +def gen_autograd_function(name, forward, backward): + generated_cls = type( + name, + (torch.autograd.Function,), + { + 'forward': staticmethod(forward), + 'backward': staticmethod(backward), + } + ) + return generated_cls + + +@functools.lru_cache +def namedtuple_args_cls(schema): + attribs = [arg.name for arg in schema.arguments.flat_all] + name = str(schema.name) + "_args" + # mypy doesn't support dynamic namedtuple name + tuple_cls = namedtuple(name, attribs) # type: ignore[misc] + return tuple_cls + + +def namedtuple_args(schema, args): + assert isinstance(args, tuple) + tuple_cls = namedtuple_args_cls(schema) + return tuple_cls(*args) + + +def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): + def error(what): + backward = forward_op._get_impl('backward') + raise RuntimeError( + f"In the backward function defined for {forward_op} at " + f"{backward.location} using the CustomOp API, {what}") + + if not isinstance(grad_inputs_dict, dict): + error(f"expected the output of the backward function to be a dict but " + f"got {type(grad_inputs_dict)}") + + expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all + if arg.type.is_tensor_like()} + actual_keys = grad_inputs_dict.keys() + if expected_keys != actual_keys: + error(f"expected the returned grad_input dict to have keys " + f"{expected_keys} but got {actual_keys}. 
The backward " + f"function must return a gradient (can be None) for each arg " + f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " + f"Args declared to be non-Tensor-like types should not appear " + f"in the grad_input dict") + + for name, grad in grad_inputs_dict.items(): + arg_info = getattr(args_info, name) + + if isinstance(arg_info, list): + if not isinstance(grad, (tuple, list)): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of gradients but got object of type " + f"{type(grad)}.") + if not len(grad) == len(arg_info): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of {len(arg_info)} gradients but got " + f"{len(grad)}") + for idx, (g, info) in enumerate(zip(grad, arg_info)): + if g is None: + continue + if not isinstance(g, torch.Tensor): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of None or Tensor gradients but got " + f"object of {type(g)} at index {idx}") + if not issubclass(info, torch.Tensor): + error(f"for input '{name}', got a Tensor as the gradient " + f"for the {idx}-th value but expected None because " + f"the {idx}-th value was not a Tensor (it was " + f"type {arg_info}") + continue + + if grad is None: + continue + if not isinstance(grad, torch.Tensor): + error(f"got object of type {type(grad)} as the gradient for input " + f"'{name}', " + f"but expected the gradient to be either None or a Tensor") + if not issubclass(arg_info, torch.Tensor): + error(f"got a Tensor as the gradient for input '{name}' but " + f"expected None as the gradient because input '{name}' " + f"was not a Tensor (it was type {arg_info}).") + + +def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): + result = [] + for name, arg_info in args_info._asdict().items(): + if name not in grad_inputs_dict: + result.append(pytree.tree_map(lambda x: None, arg_info)) + continue + result.append(grad_inputs_dict[name]) + return tuple(pytree.tree_leaves(result)) + +# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. +# autograd.Function prefers that users use ctx.save_for_backward to +# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the +# ctx object. 
+def save_pytree_for_backward(ctx, stuff): + flat_stuff, spec = pytree.tree_flatten(stuff) + num_elts = len(flat_stuff) + tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if isinstance(thing, torch.Tensor)] + non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if not isinstance(thing, torch.Tensor)] + tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] + non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] + + ctx.spec = spec + ctx.num_elts = num_elts + ctx.save_for_backward(*tensors) + ctx.tensor_idxs = tensor_idxs + ctx.saved_non_tensors = non_tensors + ctx.non_tensor_idxs = non_tensor_idxs + + +# Inverse operation to save_pytree_for_backward +def unpack_saved(ctx): + flat_stuff = [None] * ctx.num_elts + for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): + flat_stuff[idx] = tensor + for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): + flat_stuff[idx] = non_tensor + stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) + return stuff diff --git a/MLPY/Lib/site-packages/torch/_custom_op/functional.py b/MLPY/Lib/site-packages/torch/_custom_op/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..a15e920c3c018e7156da4db86195f3e3a02fd0ef --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_custom_op/functional.py @@ -0,0 +1,187 @@ +import weakref + +import torch +import torch.utils._pytree as pytree +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._ops import OpOverload +from torch.library import Library +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + OperatorName, + OptionalType, + SchemaKind, +) + +from .autograd import autograd_not_implemented + + +def register_functional_op( + lib: Library, + new_op_name: str, + mutable_op: OpOverload, +) -> None: + """Given a mutable operator, registers the functional variant. + + This API also correctly links the functional variant with the mutable + operator for the purposes of functionalization. + + All of the new registrations are performed on the ``lib`` passed in. + + Arguments: + lib (Library): Should be a torch.library.Library object that has + the same namespace as ``mutable_op``'s namespace. + lib will be used to register the new functional op as well + as a functionalization kernel for the ``mutable_op`` + If you don't have a library handy, use + ``torch.library.Library(ns, 'FRAGMENT')`` to construct one. + new_op_name (str): The name of the functional operator (without the + namespace). If no namespace, the new functional variant will be + accessible under ``torch.ops.{lib.ns}.new_op_name``. + mutable_op (OpOverload): The mutable custom operator. Note + that you may need to add a `.default` to it, like + `torch.ops.aten.abs_.default`. + + """ + validate(mutable_op) + schema = functional_schema(new_op_name, mutable_op) + lib.define(schema) + + functional_impl = construct_functional_impl(mutable_op) + lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd') + + functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default + + # There's no easy way for us to generate the autograd kernel, so we + # use autograd_not_implemented. Also, this makes it so that the user + # is unable to register an autograd formula themselves. This shouldn't + # be a problem if the user doesn't use the functional op direclty + # in their program, but we may need to revist this in the future. 
+ lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd') + + f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op) + + lib.impl(mutable_op, f_kernel, 'Functionalize') + + +def construct_functional_impl(mutable_op): + def functional_impl(*args): + # Strategy: + # - clone args that would have been mutated + # - run mutable_op + # - return the cloned args as additional outputs + new_args = [] + extra_rets = [] + for is_write, arg in zip(mutable_args(mutable_op), args): + if is_write: + cloned = arg.clone() if arg is not None else None + new_args.append(cloned) + extra_rets.append(cloned) + else: + new_args.append(arg) + result = mutable_op(*new_args) + if result is None: + return tuple(extra_rets) + if isinstance(result, tuple): + return (*result, *extra_rets) + return (result, *extra_rets) + return functional_impl + + +def construct_functionalization_kernel(mutable_op, functional_op): + def kernel(*args): + # There's nothing to be functionalized! + # We can still end up here because DispatchKey::Functionalize is a mode key + if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args): + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + return mutable_op(*args) + + # NB: This differs from the codegen -- codegen handles cases where there + # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper. + # This only really matters for XLA (mixed CPU-XLA tensors) and + # running functionalization without the PT2 stack (which guarantees to us that + # all tensors are FunctionalTensorWrapper). + if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args): + raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper") + + unwrapped_args = [] + for arg in args: + if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg): + torch._sync(arg) + unwrapped = torch._from_functional_tensor(arg) + unwrapped_args.append(unwrapped) + else: + unwrapped_args.append(arg) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + output = functional_op(*unwrapped_args) + + num_actual_output = len(mutable_op._schema.returns) + actual_output = pytree.tree_map( + torch._to_functional_tensor, output[:num_actual_output]) + + new_values_to_propagate = output[num_actual_output:] + inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args) + if is_write] + assert len(new_values_to_propagate) == len(inputs_to_replace) + for new_value, arg in zip(new_values_to_propagate, inputs_to_replace): + if (arg is None and new_value is None) or (arg is not None and new_value is not None): + continue + torch._C._propagate_xla_data(arg, new_value) + torch._C._replace_(arg, new_value) + torch._C._commit_update(arg) + torch._sync(arg) + + if len(actual_output) == 1: + return actual_output[0] + elif len(actual_output) == 0: + return None + return actual_output + + return kernel + + +def validate(mutable_op: OpOverload): + if not isinstance(mutable_op, OpOverload): + raise TypeError( + f"register_functional_op(mutable_op): expected mutable_op to be instance of " + f"OpOverload but got {type(mutable_op)}") + + # There are generally three types of "in-place" or "mutable" ops. 
+ # Each of them have their own conventions: + # - inplace (first input modified in-place and returned as only output) + # - out= (some args modified in-place and returned as outputs) + # - mutable (some args modified in-place but none of those returned as outputs) + # In theory we can support all three, but we'll just support the last + # option right now for simplicity. + schema = FunctionSchema.parse(str(mutable_op._schema)) + if not schema.kind() == SchemaKind.mutable: + raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)") + for ret in schema.returns: + # construct_functionalization_kernel assumes this for simplicity + if ret.annotation is not None: + raise NotImplementedError( + "NYI: register_functional_op(op) where op returns a mutated or aliased value. " + "Please file an issue (and as a workaround, modify your operator to " + "not return the mutated value or aliases)") + for arg in schema.arguments.flat_all: + # construct_functionalization_kernel assumes this for simplicity + if arg.type.is_tensor_like() and ( + arg.type != BaseType(BaseTy.Tensor) + and arg.type != OptionalType(BaseType(BaseTy.Tensor)) + ): + raise NotImplementedError( + "NYI: register_functional_op(op) where op has a List[Tensor] input." + "Please file an issue.") + + +def functional_schema(new_op_name, op: OpOverload): + schema = FunctionSchema.parse(str(op._schema)) + schema = schema.signature().with_name(OperatorName.parse(new_op_name)) + return str(schema) + + +def mutable_args(op: OpOverload): + return tuple(False if arg.alias_info is None else arg.alias_info.is_write + for arg in op._schema.arguments) diff --git a/MLPY/Lib/site-packages/torch/_custom_op/impl.py b/MLPY/Lib/site-packages/torch/_custom_op/impl.py new file mode 100644 index 0000000000000000000000000000000000000000..e5afbbce849f18574f8d89cd6bed9c55fda753ed --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_custom_op/impl.py @@ -0,0 +1,976 @@ +import dataclasses +import functools +import inspect +import sys +import typing +import weakref + +from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy + +import torch +import torch._C as _C +import torch.library as library +from torch._library.abstract_impl import AbstractImplCtx +from torch.library import get_ctx + +from .autograd import autograd_kernel_indirection, construct_autograd_kernel + +""" +For a detailed guide on custom ops, please see +https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + +This file includes pieces of the implementation of our custom operator API. +""" + +__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"] + + +SUPPORTED_DEVICE_TYPE_TO_KEY = { + "cpu": "CPU", + "cuda": "CUDA", +} + +# We will not let users register CustomOps with anything that could look like +# PyTorch internals to avoid confusion. +RESERVED_NS = { + "prim", + "prims", + "aten", + "at", + "torch", + "pytorch", +} + + +def custom_op( + qualname: str, manual_schema: typing.Optional[str] = None +) -> typing.Callable: + r"""Creates a new CustomOp object. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. 
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define (create) the op + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the CustomOp object (the first step); + you must then perform the second step by calling various methods on + the CustomOp object. + + This API is used as a decorator (see examples). + + Arguments: + qualname (str): Should be a string that looks like + "namespace::operator_name". Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. The operator_name must be + the same as the name of the function you pass to custom_op + (see examples). + manual_schema (Optional[str]): Each PyTorch operator needs a schema that + tells PyTorch the types of the inputs/outputs. If None (default), + we will infer the schema from the type annotations on the function + (see examples). Otherwise, if you don't want to use type annotations, + you may provide us the schema string. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the CustomOp. + >>> # We need to provide the decorator a "prototype function" + >>> # (a function with Python ellipses as the body). + >>> @custom_op("my_library::numpy_sin") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> ... + >>> + >>> # numpy_sin is now an instance of class CustomOp + >>> print(type(numpy_sin)) + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @numpy_sin.impl('cpu') + >>> def numpy_sin_impl_cpu(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @numpy_sin.impl('cuda') + >>> def numpy_sin_impl_cuda(x): + >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> numpy_sin(x) # calls numpy_sin_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> numpy_sin(x) # calls numpy_sin_impl_cuda + + """ + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + ns, name = parse_qualname(qualname) + validate_namespace(ns) + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. 
" + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func) if manual_schema is None else manual_schema + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + if manual_schema is not None: + validate_function_matches_schema(function_schema, func) + + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + + result.__name__ = func.__name__ + result.__module__ = func.__module__ + result.__doc__ = func.__doc__ + + library.impl(lib, result._opname, "Autograd")( + autograd_kernel_indirection(weakref.proxy(result)) + ) + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + + return result + + return inner + + +# Global dictionary holding references to all CustomOp objects +# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime]) +# Used to query the CustomOp associated with a specific C++ dispatcher operator. +# An example usage is FakeTensor: FakeTensor checks if a specific operator +# has an implementation registered via the CustomOp API. +# Indexed by qualname (e.g. aten::foo) +global_registry: typing.Dict[str, "CustomOp"] = {} + + +class CustomOp: + r"""Class for custom operators in PyTorch. + + Use the CustomOp API to create user-defined custom operators that behave + just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it + comes to various PyTorch subsystems (like torch.compile). + + To construct a `CustomOp`, use `custom_op`. + """ + + def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False): + super().__init__() + if not _private_access: + raise RuntimeError( + "The CustomOp constructor is private and we do not guarantee " + "BC for it. Please use custom_op(...) to create a CustomOp object" + ) + name = f"{cpp_ns}::{operator_name}" + self._schema = schema + self._cpp_ns = cpp_ns + self._lib: library.Library = lib + self._ophandle: _C._DispatchOperatorHandle = ophandle + # Has the name of the op, e.g. "foo". We cache here for convenience. + self._opname: str = operator_name + # this is _opname but with namespace. e.g. "custom::foo" + self._qualname: str = name + self.__name__ = None # mypy requires this + # NB: Some of these impls are registered as kernels to DispatchKeys. + # Modifying the _impls dict directly won't do anything in that case. + self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {} + # See NOTE [CustomOp autograd kernel indirection] + self._registered_autograd_kernel_indirection = False + + global_registry[self._qualname] = self + + def _register_autograd_kernel_indirection(self): + assert not self._registered_autograd_kernel_indirection + self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd") + self._registered_autograd_kernel_indirection = True + + # Records the impl and the source location in self._impls + # Note that this doesn't cause torch.library to use the impl, that + # needs to be done in a separate self._lib.impl call. 
+ def _register_impl(self, kind, func, stacklevel=2): + if self._has_impl(kind): + func_and_location = self._impls[kind] + assert func_and_location is not None # Pacify mypy + location = func_and_location.location + raise RuntimeError( + f"Attempting to register a {kind} impl for operator {self._qualname} " + f"that already has a {kind} impl registered from Python at " + f"{location}. This is not supported." + ) + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + location = f"{frame.filename}:{frame.lineno}" + self._impls[kind] = FuncAndLocation(func, location) + + def _get_impl(self, kind): + return self._impls[kind] + + def _has_impl(self, kind): + return kind in self._impls + + def _destroy(self): + # NOTE: [CustomOp lifetime] + # A CustomOp, once created, lives forever. The mechanism is that the + # global registry holds a reference to it. However, to make testing + # easier, we want to be able to destroy CustomOp objects. + # CustomOp._destroy does the job, though it leaves the CustomOp + # in a garbage state. + del self._lib + + opnamespace = getattr(torch.ops, self._cpp_ns) + if hasattr(opnamespace, self._opname): + delattr(opnamespace, self._opname) + + del global_registry[self._qualname] + + def __repr__(self): + return f'' + + def __call__(self, *args, **kwargs): + # Bypass torch.ops.* and directly do OperatorHandle::callBoxed. + # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime + # issues from caching operators that make testing CustomOp difficult). + result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs) + return result + + def impl( + self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2, + ) -> typing.Callable: + r"""Register an implementation for a device type for this CustomOp object. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + If the CustomOp is passed multiple Tensor inputs with different device + types, it will dispatch to the registered implementation for the highest + priority device type among those present. + The supported device types, in order of priority, are {'cuda', 'cpu'}. + + This API is used as a decorator (see examples). + + Arguments: + device_types (str or Iterable[str]): the device type(s) to register the function for. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @custom_op("my_library::numpy_cos") + >>> def numpy_cos(x: Tensor) -> Tensor: + >>> ... 
+ >>> + >>> # Register an implementation for CPU Tensors + >>> @numpy_cos.impl('cpu') + >>> def numpy_cos_impl_cpu(x): + >>> return torch.from_numpy(np.cos(x.numpy())) + >>> + >>> # Register an implementation for CUDA Tensors + >>> @numpy_cos.impl('cuda') + >>> def numpy_cos_impl_cuda(x): + >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> numpy_cos(x) # calls numpy_cos_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> numpy_cos(x) # calls numpy_cos_impl_cuda + + """ + if isinstance(device_types, str): + device_types = [device_types] + for device_type in device_types: + validate_device_type(device_type) + + def inner(f): + for device_type in set(device_types): + self._check_doesnt_have_library_impl(device_type) + self._register_impl(device_type, f, stacklevel=_stacklevel) + dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + library.impl(self._lib, self._opname, dispatch_key)(f) + return f + + return inner + + def _check_doesnt_have_library_impl(self, device_type): + if self._has_impl(device_type): + return + key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl(..., device_types={device_type}): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing torch.library or TORCH_LIBRARY registration.") + + def impl_factory(self) -> typing.Callable: + r"""Register an implementation for a factory function.""" + + def inner(f): + self._register_impl("factory", f) + library.impl(self._lib, self._opname, "BackendSelect")(f) + return f + + return inner + + def impl_abstract(self, _stacklevel=2) -> typing.Callable: + r"""Register an abstract implementation for this operator. + + WARNING: please do not use this directly (and instead use the torch._custom_ops + APIs). Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + An "abstract implementation" specifies the behavior of this operator on + Tensors that carry no data. Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + The abstract implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write an abstract + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The abstract implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API is used as a decorator (see examples). + + Examples:: + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @custom_op('my_library::custom_linear') + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> ... 
+ >>> + >>> @custom_linear.impl_abstract() + >>> def custom_linear_abstract(x, weight): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @custom_op('my_library::custom_nonzero') + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> ... + >>> + >>> @custom_nonzero.impl_abstract() + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch._custom_op.get_ctx() + >>> nnz = ctx.create_unbacked_symint() + >>> shape = [x.dim(), nnz] + >>> result = x.new_empty(shape, dtype=torch.long) + >>> return result + >>> + >>> @custom_nonzero.impl(['cpu', 'cuda']) + >>> def custom_nonzero_impl(x): + >>> x_np = to_numpy(x) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> # unbacked symbolic ints in PyTorch must be >= 2, so we + >>> # constrain the range to at least 2 + >>> if res.shape[0] <= 1: + >>> raise RuntimeError("not supported") + >>> return torch.tensor(res, device=x.device) + + """ + + def inner(f): + self._check_doesnt_have_library_meta_impl() + self._register_impl("abstract", f, stacklevel=_stacklevel) + location = self._get_impl("abstract").location + + qualname = self._qualname + + # Handle DispatchKey.Meta registration + @functools.wraps(f) + def f_with_ctx(*args, **kwargs): + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname}." + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior. Otherwise, please remove the call to get_ctx() " + f"in the implementation registered with impl_abstract " + f"at {location}" + ) + + with torch._library.abstract_impl.set_ctx_getter(error_on_ctx): + return f(*args, **kwargs) + + self._lib.impl(self._opname, f_with_ctx, "Meta") + return f + + return inner + + def _check_can_register_backward(self): + def error(detail): + raise RuntimeError( + f"Cannot use torch._custom_ops APIs to register backward " + f"formula for {detail}. Got operator " + f"{self._qualname} with schema: {schema}" + ) + + schema = self._schema + if schema.kind() != SchemaKind.functional: + error("non-functional operator") + + rets = schema.returns + if not schema.returns: + error("operator with no returns") + + assert len(rets) > 0 + is_non_mutating_view = any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + error("operator that returns views") + + # We make assumptions about the schema's return types. 
+ allowed_return_types = { + BaseType(BaseTy.int): "int", + BaseType(BaseTy.SymInt): "SymInt", + BaseType(BaseTy.bool): "bool", + BaseType(BaseTy.float): "float", + BaseType(BaseTy.Tensor): "Tensor", + ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]", + } + for ret in schema.returns: + if ret.type in allowed_return_types: + continue + error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})") + + def _check_doesnt_have_library_autograd_impl(self): + if self._registered_autograd_kernel_indirection: + return + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an autograd formula; " + f"instead, the operator will decompose into its constituents and those " + f"can have autograd formulas defined on them.") + + # We can improve this by adding "all Autograd keys", but + # realistically people will just be using this API for CPU/CUDA for now. + for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]: + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: " + f"the operator {self._qualname} already has an Autograd kernel " + f"registered to DispatchKey::{key} vi a pre-existing " + f"torch.library or TORCH_LIBRARY registration. Please either " + f"remove those registrations or don't use the torch._custom_ops APIs") + + def _check_doesnt_have_library_meta_impl(self): + if self._has_impl("abstract"): + return + + # If the user's operator is CompositeExplicitAutograd, + # allow them to impl_abstract. This is being pragmatic + # (existing custom ops may have CompositeExplicitAutograd + # registration that don't work with Meta kernels, so this + # gives them an escape hatch). + if ( + _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd") + and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta") + ): + return + + # Otherwise, if the user's already has a Meta kernel or their + # op is CompositeImplicitAutograd or some other alias dispatch key, + # raise. + + # Special case for CompositeImplicitAutograd + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an abstract impl; " + f"instead, the operator will decompose into its constituents and those " + f"can have abstract impls defined on them.") + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call impl_abstract.") + + # NOTE ["backward", "save_for_backward", and "autograd"] + # As a part of the explicit autograd API, a user must provide us + # a "save_for_backward" function and a "backward" function. + # When both of these have been provided, then we automatically + # construct the "autograd" kernel. 
+ def _register_autograd_kernel(self): + assert self._has_impl("backward") + assert self._has_impl("save_for_backward") + kernel = construct_autograd_kernel( + self._schema, + self._output_differentiability, + self, + get_op(self._qualname), + self._get_impl("save_for_backward").func, + self._get_impl("backward").func) + self._register_impl("autograd", kernel) + + def impl_save_for_backward(self, _stacklevel=2): + r"""Register a function that tells us what to save for backward. + + Please see impl_backward for more details. + """ + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("save_for_backward", f, stacklevel=_stacklevel) + if self._has_impl("backward"): + self._register_autograd_kernel() + return inner + + def impl_backward(self, output_differentiability=None, _stacklevel=2): + r"""Registers a backward formula. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + In order for the CustomOp to work with autograd, you need to register + a backward formula. There are two pieces to this: + 1. You must give us a function to specify what to save for backward. + Call this the "save for backward" function. + 2. You must give us a function that computes gradients. Call this the + "backward" function. + + Use `impl_save_for_backward` to define a "save for backward" function + that specifies what gets saved for backward. The function should accept + two arguments ``(inputs, output)`` and return the quantities to be saved + for backward. + + During runtime, when you call the CustomOp, PyTorch will invoke the + "save for backward" function with the inputs and output of the CustomOp. + + Use `impl_backward` to define the "backward" function. The backward + function must accept ``(ctx, saved, *grads)``: + - ``ctx`` is a context object where we may provide information + - ``saved`` is exactly what gets returned from the "save for backward" + function + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the CustomOp. + + The backward function must return a dict that maps the name of + an input to the CustomOp to its corresponding gradient. All inputs that + were declared to be Tensors in the CustomOp definition must be accounted + for in the dict. The gradient may be a Tensor or None. 
+ + """ + if output_differentiability is not None: + def yell(): + raise RuntimeError( + f"impl_backward(output_differentiability): expected " + f"output_differentiability to be a list of bools with " + f"length equal to the number of outputs of this CustomOp " + f"got: {output_differentiability}") + + if not isinstance(output_differentiability, list): + yell() + for diff in output_differentiability: + if not isinstance(diff, bool): + yell() + if len(self._schema.returns) != len(output_differentiability): + yell() + + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("backward", f, stacklevel=_stacklevel) + self._output_differentiability = output_differentiability + if self._has_impl("save_for_backward"): + self._register_autograd_kernel() + return inner + + +@dataclasses.dataclass +class FuncAndLocation: + func: typing.Callable + location: str + + +def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName): + overload_name = ( + "" if operator_name.overload_name is None else operator_name.overload_name + ) + return _C._dispatch_find_schema_or_throw( + f"{cpp_ns}::{str(operator_name.name)}", overload_name + ) + + +def validate_namespace(ns: str) -> None: + if "." in ns: + raise ValueError( + f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a ' + f"valid variable name)" + ) + if ns in RESERVED_NS: + raise ValueError( + f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, " + f"please choose something else. " + ) + +def validate_schema(schema: FunctionSchema) -> None: + if not torch._library.utils.is_functional_schema(schema): + raise ValueError( + f"custom_op only supports functional operators " + f"(ops that do not mutate any inputs, do not return " + f"views of the inputs, and has at least one return). " + f"Got the following non-functional schema: {schema}" + ) + + # For simplicity: don't allow self arguments + if schema.arguments.self_arg is not None: + raise ValueError( + f"custom_op does not support arguments named 'self'. Please " + f"rename your argument. Got: {schema}" + ) + + +def parse_qualname(qualname: str) -> typing.Tuple[str, str]: + names = qualname.split("::", 1) + if len(names) != 2: + raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The " + f"operator name should look something like ns::foo") + if '.' in names[1]: + raise ValueError(f"The torch.custom_ops APIs do not handle overloads, " + f"i.e. operator names with '.' in them. " + f"Please name your operator something like ns::foo. " + f"Got: {qualname}") + return names[0], names[1] + + +def validate_device_type(device_type: str) -> None: + if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY: + raise ValueError( + f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type " + f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}." + ) + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def validate_function_matches_schema( + schema: FunctionSchema, func: typing.Callable +) -> None: + sig = inspect.signature(func) + + if not all(supported_param(p) for _, p in sig.parameters.items()): + raise ValueError( + f"custom_op(..., manual_schema)(func): positional-only args, " + f"varargs, and kwargs are not supported. Please rewrite `func` " + f"to not have them. 
Got `func` with signature: {sig}" + ) + + if ( + any( + p.annotation is not inspect.Parameter.empty + for _, p in sig.parameters.items() + ) + or sig.return_annotation is not inspect.Signature.empty + ): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func` to have no type annotations to avoid " + f"ambiguity. Got `func` with signature: {sig}" + ) + + positional = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwargonly = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + def error(): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func`'s signature to match `manual_schema` " + f"(aside from type annotations). " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def error_default_args(): + raise ValueError( + f"custom_op(..., manual_schema)(func): " + f"neither func nor manual_schema should have default " + f"arguments. Got " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def compare(sig_args, schema_args): + if len(sig_args) != len(schema_args): + error() + for (name, param), arg in zip(sig_args, schema_args): + if name != arg.name: + error() + if param.default is not inspect.Parameter.empty or arg.default is not None: + error_default_args() + + compare(positional, schema.arguments.flat_positional) + compare(kwargonly, schema.arguments.flat_kwarg_only) + + +def infer_schema(prototype_function: typing.Callable) -> str: + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError( + f"custom_op(...)(func): {what} " f"Got func with signature {sig})" + ) + + params = [ + parse_param(name, param, error_fn) for name, param in sig.parameters.items() + ] + ret = parse_return(sig.return_annotation, error_fn) + return f"({', '.join(params)}) -> {ret}" + + +def parse_param(name, param, error_fn): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + if param.annotation not in SUPPORTED_PARAM_TYPES.keys(): + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + + if param.default is not inspect.Parameter.empty: + error_fn( + f"Parameter {name} has a default value; this is not supported. " + f"If you want to use default values then create a function with " + f"default values that calls the CustomOp" + ) + + return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}" + + +def derived_types( + base_type, cpp_type, list_base, optional_base_list, optional_list_base +): + result = [ + (base_type, cpp_type), + (typing.Optional[base_type], f"{cpp_type}?"), + ] + if list_base: + result.append((typing.Sequence[base_type], f"{cpp_type}[]")) # type: ignore[valid-type] + if optional_base_list: + result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")) # type: ignore[valid-type] + if optional_list_base: + result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")) # type: ignore[valid-type] + return result + + +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? 
variant + (torch.Tensor, "Tensor", True, True, False), + (int, "SymInt", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (torch.types.Number, "Scalar", True, False, False), + (torch.dtype, "ScalarType", False, False, False), + (torch.device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + torch.Tensor: "Tensor", + typing.List[torch.Tensor]: "Tensor[]", + int: "SymInt", + float: "float", + bool: "bool", + torch.types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES.keys(): + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + + return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def report_error_callback(custom_op: typing.Any, key: str) -> None: + if key == "Undefined": + raise NotImplementedError( + f"{custom_op}: There were no Tensor inputs to this operator " + f"(e.g. you passed an empty list of Tensors). If your operator is a " + f"factory function (that is, it takes no Tensors and constructs " + f"a new one), then please use CustomOp.impl_factory to register " + f"an implementation for it" + ) + if key == "Meta": + raise NotImplementedError( + f"{custom_op}: when running with device='Meta' tensors: there is no " + f"abstract impl registered for this CustomOp. Please register one via " + f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors" + ) + if key in ("CPU", "CUDA"): + device = key.lower() + raise NotImplementedError( + f"{custom_op}: when running with device='{device}' tensors: there is no " + f"{device} impl registered for this CustomOp. Please register one via " + f"CustomOp.impl(device_type='{device}')" + ) + raise NotImplementedError( + f"{custom_op}: No implementation for dispatch key {key}. It is likely " + f"that we have not added this functionality yet, please either open an " + f"issue or if you're feeling adventurous, use the low-level " + f"torch.library API" + ) + + +def custom_op_from_existing(op): + ns = op.namespace + lib = torch.library.Library(ns, "FRAGMENT") + name = op.name().split("::")[-1] + schema_str = str(op._schema) + # CustomOp expects the schema string without the namespace + schema_str = schema_str.split("::")[-1] + schema = FunctionSchema.parse(schema_str) + return CustomOp(lib, ns, schema, name, op, _private_access=True) + + +def get_op(qualname): + def error_not_found(): + raise ValueError( + f"Could not find the operator {qualname}. 
Please make sure you have " + f"already registered the operator and (if registered from C++) " + f"loaded it via torch.ops.load_library.") + + ns, name = parse_qualname(qualname) + if not hasattr(torch.ops, ns): + error_not_found() + opnamespace = getattr(torch.ops, ns) + if not hasattr(opnamespace, name): + error_not_found() + packet = getattr(opnamespace, name) + if not hasattr(packet, 'default'): + error_not_found() + return packet.default + + +def _find_custom_op(qualname, also_check_torch_library=False): + if qualname in global_registry: + return global_registry[qualname] + if not also_check_torch_library: + raise RuntimeError( + f"Could not find custom op \"{qualname}\". Did you register it via " + f"the torch._custom_ops API?") + overload = get_op(qualname) + result = custom_op_from_existing(overload) + return result + + +def get_abstract_impl(qualname): + if qualname not in torch._custom_op.impl.global_registry: + return None + custom_op = torch._custom_op.impl.global_registry[qualname] + if custom_op is None: + return None + if not custom_op._has_impl("abstract"): + return None + return custom_op._get_impl("abstract").func + + +def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True): + ns, name = qualname.split("::") + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else [] + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str, tags=tags) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + result._register_autograd_kernel_indirection() + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + return get_op(qualname) diff --git a/MLPY/Lib/site-packages/torch/_custom_ops.py b/MLPY/Lib/site-packages/torch/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..77d01d0d23f87632a6e2499764140b5815193cae --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_custom_ops.py @@ -0,0 +1,322 @@ +import inspect + +from torch._custom_op.impl import ( + _custom_op_with_schema, + _find_custom_op, + infer_schema, + parse_qualname, + validate_namespace, +) +from torch.library import get_ctx + +__all__ = [ + "custom_op", + "impl", + "impl_abstract", + "get_ctx", + "impl_save_for_backward", + "impl_backward", +] + + +def custom_op(qualname, func_or_schema=None): + r"""Register a new custom operator + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define the op (by providing an operator name and schema) + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the custom operator (the first step) + you must then perform the second step by calling various + ``impl_*`` APIs. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + qualname (str): Should be a string that looks like + "namespace::operator_name". Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. 
+ func_or_schema (Union[Callable, str]): Each PyTorch operator needs a + schema that tells PyTorch the types of the inputs/outputs. + If this is a Callable, we will automatically infer the schema from + the type annotations on the function (see examples). Otherwise, + if you don't want to use type annotations, you may provide us the + schema string. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. + >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_sin + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu") + >>> def numpy_sin_impl_cpu(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda") + >>> def numpy_sin_impl_cuda(x): + >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_sin(x) # calls numpy_sin_impl_cuda + + """ + ns, name = parse_qualname(qualname) + validate_namespace(ns) + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. " + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func) + _custom_op_with_schema(qualname, schema) + return func + + if func_or_schema is None: + return inner + if isinstance(func_or_schema, str): + _custom_op_with_schema(qualname, func_or_schema) + else: + return inner(func_or_schema) + + +def impl(qualname, *, device_types=("cpu", "cuda"), func=None): + r"""Register an implementation for a device type for this custom op. + + If the op is passed multiple Tensor inputs with different device + types, it will dispatch to the registered implementation for the highest + priority device type among those present. + The supported device types, in order of priority, are {'cuda', 'cpu'}. + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Arguments: + device_types (str or Iterable[str]): the device type(s) to register the function for. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the custom op. + >>> # We need to provide the API a "prototype function" + >>> # (a function that returns NotImplementedError), from which + >>> # we will infer the types of the inputs and outputs. 
+ >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos") + >>> def numpy_cos(x: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> # The custom op is now accessible via the torch.ops module: + >>> torch.ops.mylibrary.numpy_cos + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu") + >>> def numpy_cos_impl_cpu(x): + >>> return torch.from_numpy(np.cos(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda") + >>> def numpy_cos_impl_cuda(x): + >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> torch.ops.mylibrary.numpy_cos(x) # calls numpy_cos_impl_cuda + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl(device_types, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_abstract(qualname, *, func=None): + r"""Register an abstract implementation for this operator. + + An "abstract implementation" specifies the behavior of this operator on + Tensors that carry no data. Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + The abstract implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write an abstract + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The abstract implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + Examples:: + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch._custom_ops.custom_op("mylibrary::custom_linear") + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> raise NotImplementedError() + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear") + >>> def custom_linear_abstract(x, weight): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero') + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> ... + >>> + >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero") + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. 
+ >>> ctx = torch._custom_ops.get_ctx() + >>> nnz = ctx.create_unbacked_symint() + >>> shape = [x.dim(), nnz] + >>> result = x.new_empty(shape, dtype=torch.long) + >>> return result + >>> + >>> @torch._custom_ops.impl("mylibrary::custom_nonzero") + >>> def custom_nonzero_impl(x): + >>> x_np = to_numpy(x) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> # unbacked symbolic ints in PyTorch must be >= 2, so we + >>> # constrain the range to at least 2 + >>> if res.shape[0] <= 1: + >>> raise RuntimeError("not supported") + >>> return torch.tensor(res, device=x.device) + + """ + import torch.library + + return torch.library.impl_abstract(qualname, func, _stacklevel=2) + + +def impl_save_for_backward(qualname, *, func=None): + r"""Register a function that tells us what to save for backward. + + Please see :func:`impl_backward` for more details. + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_save_for_backward(_stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def impl_backward(qualname, output_differentiability=None, *, func=None): + r"""Registers a backward formula for an operator. + + In order for an operator to work with autograd, you need to register + a backward formula. There are two pieces to this: + 1. You must give us a function to specify what to save for backward. + Call this the "save for backward" function. + 2. You must give us a function that computes gradients. Call this the + "backward" function. + + Use `impl_save_for_backward` to define a "save for backward" function + that specifies what gets saved for backward. The function should accept + two arguments ``(inputs, output)`` and return the quantities to be saved + for backward. + + During runtime, when you call the operator in a forwards pass, PyTorch + will invoke the "save for backward" function with the inputs and output + of the operator. + + Use `impl_backward` to define the "backward" function. The backward + function must accept ``(ctx, saved, *grads)``: + - ``ctx`` is a context object where we may provide information + - ``saved`` is exactly what gets returned from the "save for backward" + function + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + + The backward function must return a dict that maps the name of + an input to the operator to its corresponding gradient. All inputs that + were declared to be Tensors in the operator definition must be accounted + for in the dict. The gradient may be a Tensor or None. + + For a detailed guide on custom ops, please see + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + """ + + def inner(func): + custom_op = _find_custom_op(qualname, also_check_torch_library=True) + custom_op.impl_backward(output_differentiability, _stacklevel=3)(func) + return func + + if func is None: + return inner + return inner(func) + + +def _destroy(qualname): + """De-registers a custom op. 
For testing purposes only""" + custom_op = _find_custom_op(qualname) + custom_op._destroy() diff --git a/MLPY/Lib/site-packages/torch/_decomp/__init__.py b/MLPY/Lib/site-packages/torch/_decomp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..361ad0bc40e1c1fa6f5a8cb4959ed6083a5bd639 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_decomp/__init__.py @@ -0,0 +1,463 @@ +import inspect +from collections import defaultdict +from functools import wraps +from itertools import chain +from typing import Callable, Dict, List, Sequence, Union + +import torch +import torch.library +from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket +from torch._prims_common import CustomOutParamAnnotation +from torch.utils import _pytree as pytree + +__all__ = [ + "decomposition_table", + "pre_autograd_decomposition_table", + "meta_table", + "register_decomposition", + "get_decompositions", + "core_aten_decompositions", +] + + +# TODO: relax key type here; torch registrations should be possible to; but +# right now this type is accurate +global_decomposition_table: Dict[ + str, Dict[torch._ops.OperatorBase, Callable] +] = defaultdict(dict) + +decomposition_table = global_decomposition_table["post_autograd"] +pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"] +meta_table = global_decomposition_table["meta"] + + +def _add_op_to_registry(registry, op, fn): + """ + This is an internal API for adding an op to the decomposition table. + + If op is OpOverload, it will be added to the registry directly. + If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. + """ + overloads: List[Union[torch._ops.OperatorBase]] = [] + if isinstance(op, HigherOrderOperator): + # There's no concept of overloads for HigherOrderOperator + registry[op] = fn + return + elif isinstance(op, OpOverload): + overloads.append(op) + else: + assert isinstance(op, OpOverloadPacket) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) + + for op_overload in overloads: + if op_overload in registry: + raise RuntimeError(f"duplicate registrations for {op_overload}") + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out, e.g aten.add.float_int + if torch._C._dispatch_has_kernel(op_overload.name()): + registry[op_overload] = fn + + +def _convert_out_params(f): + out_annotation = f.__annotations__.get("out") + + # If there are no out params, do not wrap the function. + if not out_annotation: + return f + + # Hack to detect when out is a Tuple. 
There seems to be no pretty way of doing this + if getattr(out_annotation, "__origin__", None) is tuple: + sig = inspect.signature(f) + out_names = sig.return_annotation._fields + # If out is a tuple, we need to register a function that unpacks all the out + # elements as this is what native_functions.yaml expects + + @wraps(f) + def _fn(*args, **kwargs): + out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) + # Either all of the out kwargs are set or none of them + is_none = out_kwargs[0] is None + assert all((o is None) == is_none for o in out_kwargs) + return f(*args, **kwargs, out=None if is_none else out_kwargs) + + out_params = [ + inspect.Parameter( + o, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=t, + ) + for o, t in zip(out_names, out_annotation.__args__) + ] + # Drop the out parameter and concatenate the new kwargs in the signature + params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + ) + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + for o in out_params: + _fn.__annotations__[o.name] = o.annotation + + # Propagate that this function is wrapped by `out_wrapper` + _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined] + + return _fn + + # Alternatively, there may be a single tensor out parameter with a name + # other than "out". This will need special treatment and is indicated by an + # annotation, which we will remove here so it is not exposed after wrapping. + custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None) + if custom_out_param_name: + + @wraps(f) + def _fn(*args, **kwargs): + out_kwarg = kwargs.pop(custom_out_param_name, None) + return f(*args, **kwargs, out=out_kwarg) + + out_param = inspect.Parameter( + custom_out_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_annotation, + ) + + # Drop the out parameter and concatenate the new kwarg in the signature + sig = inspect.signature(f) + params = chain( + (v for k, v in sig.parameters.items() if k != "out"), (out_param,) + ) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] + ) + + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + _fn.__annotations__[out_param.name] = out_param.annotation + + return _fn + + return f + + +def register_decomposition( + aten_op, registry=None, *, type="post_autograd", unsafe=False +): + """ + A decorator to register a function as a decomposition to the Python + decomposition table. Use it like this:: + + @register_decomposition(torch.ops.aten.clamp_min) + def clamp_min(x): + return torch.clamp(self, min=min) + + If you are writing a new decomposition, consider contributing it + directly to PyTorch in torch._decomp.decompositions. + + This API is experimental; we are almost certainly going to extend + the API when we make decompositions eligible for use in transforms (e.g., + autograd) and not just backend tracing, where we then need to know if a + decomposition can be used to simulate a transform. 
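
A minimal runnable sketch of the decorator described above, registering into a private table via the `registry` argument shown in the signature so that nothing is added to the global decomposition table; the choice of `aten.clamp_min` and the sample values are illustrative only:

import torch
from torch._decomp import register_decomposition

my_table = {}  # private registry, keyed by OpOverload

@register_decomposition(torch.ops.aten.clamp_min, registry=my_table)
def clamp_min(x, min):
    return torch.clamp(x, min=min)

# The OpOverloadPacket is expanded to every clamp_min overload that has a
# dispatcher kernel, so the table now holds e.g. aten.clamp_min.default.
x = torch.tensor([-1.0, 0.5, 2.0])
print(my_table[torch.ops.aten.clamp_min.default](x, 0.0))  # negatives clamped to 0.0
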
+ + By default, we also will register it to the Meta key of dispatcher, + and replace the c++ Meta implementation if there is already one. + + unsafe kwarg is for reuse of this function for registering non-function + things + """ + + assert type in {"post_autograd", "pre_autograd", "meta"} + + def decomposition_decorator(fn: Callable) -> Callable: + orig_fn = fn + if not unsafe: + fn = _convert_out_params(fn) + + nonlocal registry + if registry is None: + registry = global_decomposition_table[type] + + def register(op): + _add_op_to_registry(registry, op, fn) + + # To handle allowing multiple aten_ops at once + pytree.tree_map_(register, aten_op) + return orig_fn + + return decomposition_decorator + + +def get_decompositions( + aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], + type: str = "post_autograd", +) -> Dict[torch._ops.OperatorBase, Callable]: + """ + Retrieve a dictionary of decompositions corresponding to the list of + operator overloads and overload packets passed as input. Overload + packets will include all decomposed overloads in the packet. If there is + no decomposition for a requested operator, it is silently ignored. + + This API is experimental; we are almost certainly going to give an alternate, + more recommended formulation, where a user provides the set of operators + they know how to implement, and we provide decompositions for everything + not in this set. + """ + assert type in {"post_autograd", "pre_autograd", "meta"} + + registry = global_decomposition_table[type] + packets_to_overloads = defaultdict(list) + for opo in registry: + if isinstance(opo, (OpOverload, OpOverloadPacket)): + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions: Dict[torch._ops.OperatorBase, Callable] = {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif isinstance(op, (torch._ops.OperatorBase)) and op in registry: + decompositions[op] = registry[op] + return decompositions + + +def remove_decompositions( + decompositions: Dict[torch._ops.OperatorBase, Callable], + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], +) -> None: + """ + Given a dictionary of decompositions obtained from get_decompositions(), removes + operators associated with a list of operator overloads and overload packets passed + as input. If the decomposition dictionary does not contain a decomposition that is + specified to be removed, it is silently ignored. 
+ """ + for op in aten_ops: + if isinstance(op, OpOverloadPacket): + for overload_name in op.overloads(): + opo = getattr(op, overload_name) + decompositions.pop(opo, None) + elif isinstance(op, OpOverload): + decompositions.pop(op, None) + + +# populate the table +import torch._decomp.decompositions +import torch._refs + + +# See NOTE [Core ATen Ops] +# +# list was copied from torch/_inductor/decomposition.py +# excluding decompositions that results in prim ops +# Resulting opset of decomposition is core aten ops +def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: + aten = torch.ops.aten + return get_decompositions( + [ + aten.addcdiv, + aten.addcdiv_, + aten.addcmul, + aten.addcmul_, + aten.addr, + aten.affine_grid_generator, + aten.all, + aten.aminmax, + aten.arange.default, + aten.arange.start, + aten.avg_pool2d_backward, + aten.baddbmm, + aten.binary_cross_entropy, + aten.binary_cross_entropy_backward, + aten.binary_cross_entropy_with_logits, + aten.block_diag, + aten.celu, + aten.celu_, + aten.clamp_max, + aten.clamp_min, + aten.col2im, + aten.count_nonzero, + aten.linalg_cross, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.deg2rad, + aten.deg2rad_, + aten.detach, + aten.diag_embed, + aten.diagonal_backward, + aten.dot, + aten.vdot, + aten.elu, + aten.elu_, + aten.elu_backward, + aten._embedding_bag, + aten.embedding_dense_backward, + aten.empty_like, + aten._euclidean_dist.default, + aten.expand_as, + aten.eye, + aten.fill, + aten.fill_, + aten.floor_divide, + aten.frac, + aten.frac_, + aten._fused_moving_avg_obs_fq_helper, + aten.gelu_, + aten.gelu_backward, + aten.glu, + aten.glu_backward, + aten.hardshrink, + aten.hardsigmoid, + aten.hardsigmoid_, + aten.hardsigmoid_backward, + aten.hardswish, + aten.hardswish_, + aten.hardswish_backward, + aten.hardtanh_, + aten.hardtanh_backward, + aten.heaviside, + aten.heaviside_, + aten.huber_loss, + aten.huber_loss_backward, + aten.im2col, + aten.index_add, + aten.index_add_, + aten.index_copy, + aten.index_copy_, + aten.index_fill, + aten.index_fill_, + aten.isin, + aten.isneginf, + aten.isposinf, + aten.l1_loss, + aten._lazy_clone, + aten._test_parallel_materialize, + aten.leaky_relu_, + aten.leaky_relu_backward, + aten.lerp, + aten.lerp_, + aten.linspace, + aten.logaddexp, + aten.logaddexp2, + aten.logit, + aten.logit_, + aten.logit_backward, + aten.log_sigmoid_backward, + aten.log_sigmoid_forward, + aten._log_softmax_backward_data, + aten.logspace, + aten.logsumexp.default, + aten.masked_fill, + aten.masked_fill_, + aten.mish, + aten.mish_, + aten.mse_loss, + aten.mse_loss_backward, + aten.multi_margin_loss, + aten.multilabel_margin_loss_forward, + aten.mv, + aten.mvlgamma, + aten.mvlgamma_, + aten.nansum, + aten.nan_to_num, + aten.nan_to_num_, + aten.narrow, + aten.native_batch_norm_backward, + aten.native_dropout_backward, + aten.native_group_norm_backward, + aten.native_layer_norm_backward, + aten.new_empty, + aten.new_full, + aten.new_ones, + aten.new_zeros, + aten.nll_loss_backward, + aten.nll_loss_forward, + aten.norm, + aten.ones, + aten.ones_like, + aten.pixel_shuffle, + aten.pixel_unshuffle, + aten._prelu_kernel, + aten._prelu_kernel_backward, + aten._reshape_alias, + aten.rad2deg, + aten.rad2deg_, + aten.reflection_pad1d, + aten.reflection_pad2d, + aten.reflection_pad3d, + aten.replication_pad1d, + aten.replication_pad2d, + aten.replication_pad3d, + aten.renorm, + aten.renorm_, + aten.replication_pad2d, + aten.roll, + aten.rot90, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + 
aten.rsub, + aten._scaled_dot_product_flash_attention_for_cpu.default, + aten.select_backward, + aten.select_scatter, + aten.sgn, + aten.sgn_, + aten.sigmoid_backward, + aten.silu, + aten.silu_, + aten.silu_backward, + aten.sinc, + aten.sinc_, + aten.slice_backward, + aten.smooth_l1_loss, + aten.smooth_l1_loss_backward, + aten.soft_margin_loss, + aten.soft_margin_loss_backward, + aten._softmax_backward_data, + aten.softplus, + aten.softplus_backward, + aten.softshrink, + aten.special_entr, + aten.special_log_ndtr, + aten.special_xlog1py, + aten.split.Tensor, + aten.split_with_sizes_copy, + aten.squeeze.default, + aten.squeeze.dim, + aten.std, + aten.std_mean, + aten.stack, + aten.sum.default, + aten.sum.out, + aten.t, + aten.take, + aten.tanh_backward, + aten.threshold, + aten.threshold_, + aten.threshold_backward, + aten.trace, + aten.transpose.int, + aten.tril, + aten.tril_, + aten.triu, + aten.triu_, + aten.unbind, + aten.unfold_backward, + aten.unfold_copy, + aten._unsafe_index, + aten.unsafe_split.Tensor, + aten.unsafe_split_with_sizes, + aten._unsafe_view, + aten.upsample_linear1d, + aten.upsample_bilinear2d, + aten.upsample_trilinear3d, + aten.upsample_nearest2d_backward, + aten.view_as_complex, + aten.xlogy, + aten.xlogy_, + aten.zero, + aten.zero_, + aten.zeros, + aten.zeros_like, + aten._chunk_cat, + aten._weight_norm_interface, + ] + ) diff --git a/MLPY/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d5083628a1709afeefd70ef32d4fe39dc50de62 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a892bcb3d2194c802567f1b3d99037cadfd481 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c625285fe7a4dbde286712829cb3f89138f950b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d509e239e7441a2032b870b82e1f206c7d3fb16b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_decomp/decompositions.py b/MLPY/Lib/site-packages/torch/_decomp/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b39a7b82004cd795ede252139422488d0d019a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_decomp/decompositions.py @@ -0,0 +1,4659 @@ +import functools +import numbers +import operator +import sys +from enum import Enum +from functools import partial, reduce +from itertools import chain, product +from typing import Any, Callable, cast, Iterable, List, Optional, 
Tuple, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch.nn.functional as F +from torch import sym_float, sym_int, Tensor +from torch._decomp import register_decomposition +from torch._higher_order_ops.out_dtype import out_dtype +from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + +# None of these functions are publicly accessible; get at them +# from torch._decomps +__all__: List[str] = [] + +aten = torch._ops.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided +# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +# Will need to validate the non-elementwise uses +def type_casts( + f: Callable, + type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND, + compute_dtype_only: bool = False, +): + @functools.wraps(f) + def inner(*args, **kwargs): + flat_args = [ + x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor) + ] + computation_dtype, result_dtype = utils.elementwise_dtypes( + *flat_args, type_promotion_kind=type_promotion + ) + + # TODO: pretty sure this is not quite right + def increase_prec(x): + if isinstance(x, Tensor): + return x.to(computation_dtype) + else: + return x + + def decrease_prec(x): + if isinstance(x, Tensor): + return x.to(result_dtype) + else: + return x + + r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs)) + if compute_dtype_only: + return r + else: + return tree_map(decrease_prec, r) + + return inner + + +compute_only_pw_cast_for_opmath = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + compute_dtype_only=True, +) +pw_cast_for_opmath = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +) +pw_cast_for_int_to_real = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) + + +# This expands x until x.dim() == dim. 
Might be useful as an operator +def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor: + for _ in range(dim - x.dim()): + x = x.unsqueeze(-1) + return x + + +@register_decomposition(aten.tanh_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def tanh_backward(out_grad: Tensor, y: Tensor): + return out_grad * (1 - y * y).conj_physical() + + +@register_decomposition(aten.sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def sigmoid_backward(out_grad: Tensor, y: Tensor): + return out_grad * (y * (1 - y)).conj_physical() + + +@register_decomposition(aten.softplus_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float): + z = (x * beta).exp() + return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) + + +@register_decomposition(aten.elu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def elu_backward( + grad_output: Tensor, + alpha: float, + scale: float, + input_scale: float, + is_result: bool, + self_or_result: Tensor, +): + negcoef = alpha * scale + poscoef = scale + negiptcoef = input_scale + if is_result: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * (self_or_result + negcoef), + grad_output * poscoef, + ) + else: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef), + grad_output * poscoef, + ) + + +@register_decomposition([aten.fill.Scalar]) +def fill_scalar(self, value): + return torch.full_like(self, value) + + +@register_decomposition([aten.fill.Tensor]) +def fill_tensor(self, value: Tensor): + torch._check( + value.dim() == 0, + lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", + ) + return aten.copy(self, value) + + +@register_decomposition(aten.hardsigmoid) +@out_wrapper() +@pw_cast_for_opmath +def hardsigmoid(self: Tensor) -> Tensor: + return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardsigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def hardsigmoid_backward(grad_output: Tensor, self: Tensor): + return torch.where( + (self > -3.0) & (self < 3.0), + grad_output * (1.0 / 6.0), + 0.0, + ) + + +@register_decomposition(aten.hardtanh_backward) +@out_wrapper("grad_input") +def hardtanh_backward( + grad_output: Tensor, self: Tensor, min_val: float, max_val: float +): + return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output) + + +@register_decomposition(aten.hardswish) +@out_wrapper() +@pw_cast_for_opmath +def hardswish(self: Tensor) -> Tensor: + return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardswish_backward) +@out_wrapper() +@pw_cast_for_opmath +def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor: + return torch.where( + self < -3, + 0.0, + torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output), + ) + + +@register_decomposition(aten.threshold_backward) +@out_wrapper("grad_input") +def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float): + return torch.where(self <= threshold, 0, grad_output) + + +@register_decomposition(aten.leaky_relu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def leaky_relu_backward( + grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool +): + return torch.where(self > 0, grad_output, grad_output * negative_slope) + + 
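
The pointwise decompositions above are intended to match the eager ATen kernels, so they can be spot-checked directly. A minimal sketch, pulling the entries registered above out through get_decompositions; the input values, the choice of hardsigmoid, and the default tolerances of assert_close are illustrative assumptions:

import torch
import torch.nn.functional as F
from torch._decomp import get_decompositions

aten = torch.ops.aten
decomps = get_decompositions([aten.hardsigmoid, aten.hardsigmoid_backward])

x = torch.randn(16)
# Forward: entries are keyed by OpOverload and share the operator's signature.
torch.testing.assert_close(decomps[aten.hardsigmoid.default](x), F.hardsigmoid(x))

# Backward: compare the decomposition against autograd through the eager op.
xr = x.clone().requires_grad_(True)
g = torch.randn_like(x)
(ref_grad,) = torch.autograd.grad(F.hardsigmoid(xr), xr, g)
dec_grad = decomps[aten.hardsigmoid_backward.default](g, x)
torch.testing.assert_close(dec_grad, ref_grad)
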
+@register_decomposition(aten.gelu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"): + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + x_sq = self * self + x_cube = x_sq * self + inner = kBeta * (self + kKappa * x_cube) + tanh_inner = torch.tanh(inner) + + left = 0.5 * self + right = 1 + tanh_inner + + left_derivative = 0.5 * right + + tanh_derivative = 1 - tanh_inner * tanh_inner + inner_derivative = kBeta * (1 + 3 * kKappa * x_sq) + right_derivative = left * tanh_derivative * inner_derivative + + return grad * (left_derivative + right_derivative) + else: + kAlpha = M_SQRT1_2 + kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 + cdf = 0.5 * (1 + torch.erf(self * kAlpha)) + pdf = kBeta * torch.exp(self * self * -0.5) + return grad * (cdf + self * pdf) + + +@register_decomposition(aten.mish_backward) +@pw_cast_for_opmath +def mish_backward(grad_output: Tensor, input: Tensor): + input_tanh_softplus = torch.tanh(F.softplus(input)) + input_sigmoid = torch.sigmoid(input) + out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus) + return grad_output * (input_tanh_softplus + out) + + +@register_decomposition(aten.silu) +@out_wrapper() +@pw_cast_for_opmath +def silu(self: Tensor) -> Tensor: + return self * torch.sigmoid(self) + + +@register_decomposition(aten.silu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor: + sigmoid = 1 / (1 + torch.exp(-self)) + return grad_output * sigmoid * (1 + self * (1 - sigmoid)) + + +@register_decomposition(aten._prelu_kernel) +def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor: + return torch.where(self > 0, self, weight * self) + + +@register_decomposition(aten._prelu_kernel_backward) +def _prelu_kernel_backward( + grad_output: Tensor, + self: Tensor, + weight: Tensor, +) -> Tuple[Tensor, Tensor]: + input_grad = torch.where(self > 0, grad_output, weight * grad_output) + weight_grad = torch.where(self > 0, 0.0, self * grad_output) + return (input_grad, weight_grad) + + +@register_decomposition(aten.rrelu_with_noise) +@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise( + self: Tensor, + noise: Tensor, + lower: float = 0.125, + upper: float = 0.3333333333333333, + training: bool = False, + generator: Optional[torch.Generator] = None, +) -> Tensor: + assert generator is None + if training: + not_positive = self <= 0 + r = aten.uniform(self, lower, upper) + output = torch.where(not_positive, self * r, self) + noise.copy_(torch.where(not_positive, r, 1)) + return output + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu(self, negative_slope) + + +@register_decomposition(aten.rrelu_with_noise_) +@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA) +@pw_cast_for_opmath +def rrelu_with_noise_( + self: Tensor, + noise: Tensor, + lower: float, + upper: float, + training: bool = False, + generator: Optional[torch.Generator] = None, +) -> Tensor: + return self.copy_(rrelu_with_noise(self, noise, lower, upper, training, generator)) + + +@register_decomposition(aten.rrelu_with_noise_backward) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise_backward( + grad_output: Tensor, + self: Tensor, + noise: Tensor, + lower: float, + upper: float, + training: 
bool, + self_is_result: bool, +) -> Tensor: + if training and upper - lower > 1e-6: + return grad_output.mul(noise) + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu_backward( + grad_output, self, negative_slope, self_is_result + ) + + +@register_decomposition(aten.log_sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor: + in_negative = self < 0 + max_deriv = torch.where(in_negative, 1, 0) + sign = torch.where(in_negative, 1, -1) + z = torch.exp(-torch.abs(self)) + return grad_output * (max_deriv - sign * (z / (1 + z))) + # CPU has a special formula that uses buffer, but disabled for convenience sake + # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output + + +def apply_loss_reduction(loss: Tensor, reduction: int): + if reduction == Reduction.MEAN.value: + return torch.mean(loss) + elif reduction == Reduction.SUM.value: + return torch.sum(loss) + else: + return loss + + +def to_real_dtype(dtype: torch.dtype): + if dtype == torch.complex32: + return torch.float16 + elif dtype == torch.complex64: + return torch.float32 + elif dtype == torch.complex128: + return torch.float64 + + +# TODO: None of these loss castings are quite correct, see +# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels +# perform the pointwise portion in opmath, but don't maintain it between the +# pointwise portion and the reduction + + +@register_decomposition(aten.mse_loss) +@out_wrapper() +@pw_cast_for_opmath +def mse_loss( + self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value +) -> Tensor: + loss = (self - target) ** 2 + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.mse_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def mse_loss_backward( + grad_output: Tensor, input: Tensor, target: Tensor, reduction: int +): + norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0 + return norm * (input - target) * grad_output + + +@register_decomposition(aten.smooth_l1_loss) +@out_wrapper() +@pw_cast_for_opmath +def smooth_l1_loss( + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, + beta: float = 1.0, +): + loss = (self - target).abs() + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.smooth_l1_loss_backward.default) +@pw_cast_for_opmath +def smooth_l1_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + abs_x = torch.abs(x) + norm_grad = norm * grad_output + return torch.where( + abs_x < beta, + norm_grad * x / beta, + norm_grad * torch.sign(x), + ) + + +@register_decomposition(aten.smooth_l1_loss_backward.grad_input) +@pw_cast_for_opmath +def smooth_l1_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + beta: float, + grad_input: Tensor, +): + result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +@register_decomposition(aten.huber_loss_backward.default) +@pw_cast_for_opmath +def huber_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float +): + norm = 1.0 / self.numel() if reduction 
== Reduction.MEAN.value else 1.0 + x = self - target + return torch.where( + x < -delta, + -norm * grad_output * delta, + torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output), + ) + + +# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input' +@register_decomposition(aten.huber_loss_backward.out) +@pw_cast_for_opmath +def huber_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + delta: float, + grad_input: Tensor, +): + result = huber_loss_backward(grad_output, self, target, reduction, delta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +def _nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + channel_dim = 0 if self.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(self) + grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(self.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + grad_output = grad_output * weight + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + return grad_input * grad_output + + +@register_decomposition(aten.glu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor: + assert self.dim() > 0, "glu does not support 0-dimensional tensors" + wrap_dim = utils.canonicalize_dim(self.dim(), dim) + nIn = self.size(wrap_dim) + assert ( + nIn % 2 == 0 + ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" + inputSize = nIn // 2 + firstHalf = self.narrow(wrap_dim, 0, inputSize) + secondHalf = self.narrow(wrap_dim, inputSize, inputSize) + gradInputFirstHalf = torch.sigmoid(secondHalf) + gradInputSecondHalf = ( + (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + ) + gradInputFirstHalf = gradInputFirstHalf * grad_output + return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim) + + +@register_decomposition(aten.nll_loss_backward) +@out_wrapper("grad_input") +def nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D" + assert ( + target.dim() <= 1 + ), "0D or 1D target tensor expected, multi-target not supported" + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or ( + self.shape[0] == target.shape[0] + ), f"size mismatch (got input: {self.shape}, target: {target.shape})" + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, got: ", + f"{total_weight.shape} ({total_weight.numel()} elements)", + ) + + assert ( + weight is None or weight.numel() == self.shape[-1] + ), "weight tensor should be defined either for all or no classes" + + if reduction == Reduction.NONE.value and self.dim() == 2: + assert grad_output.dim() == 1 and grad_output.shape[0] == 
self.shape[0], ( + f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but " + f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}" + ) + else: + assert ( + grad_output.dim() <= 1 and grad_output.numel() == 1 + ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}" + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.nll_loss2d_backward) +@out_wrapper("grad_input") +def nll_loss2d_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert ( + self.dim() == 4 + ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}" + + assert ( + target.dim() == 3 + ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}" + + assert ( + self.shape[0] == target.shape[0] + and self.shape[2] == target.shape[1] + and self.shape[3] == target.shape[2] + ), f"size mismatch (got input: {self.shape}, target: {target.shape}" + + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, " + f"got: {total_weight.shape} ( {total_weight.numel()}, elements)" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.binary_cross_entropy) +@out_wrapper() +@pw_cast_for_opmath +def binary_cross_entropy( + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + # We cannot currently model this without introducing data-dependent control flow + # TORCH_CHECK( + # (input_val >= 0) && (input_val <= 1), + # "all elements of input should be between 0 and 1" + # ) + loss = (target - 1) * torch.maximum( + torch.log1p(-self), self.new_full((), -100) + ) - target * torch.maximum(torch.log(self), self.new_full((), -100)) + if weight is not None: + loss = loss * weight + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.binary_cross_entropy_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def binary_cross_entropy_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + EPSILON = 1e-12 + result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON) + if weight is not None: + result = result * weight + if reduction == Reduction.MEAN.value: + result = result / self.numel() + return result + + +@register_decomposition(aten.soft_margin_loss) +@out_wrapper() +@pw_cast_for_opmath +def soft_margin_loss( + input: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + loss = torch.log1p(torch.exp(-input * target)) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.soft_margin_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def soft_margin_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + grad_input = target * grad_output * (torch.sigmoid(target * self) - 1) + if reduction == Reduction.MEAN.value: + grad_input = grad_input / self.numel() + return grad_input + + +@register_decomposition(aten.dist) +@out_wrapper() +def dist(input: Tensor, other: Tensor, p: float = 2): + 
return aten.norm(input - other, p=p) + + +@register_decomposition(aten._euclidean_dist) +@out_wrapper() +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: + x1_norm = x1.pow(2).sum(-1, True) + x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format) + x2_norm = x2.pow(2).sum(-1, True) + x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format) + x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1) + x2_ = torch.cat([x2, x2_pad, x2_norm], -1) + result = x1_.matmul(x2_.mT) + return result.clamp_min(0).sqrt() + + +@register_decomposition(aten.slice_backward) +@out_wrapper() +def slice_backward( + grad_output: Tensor, + input_sizes: List[int], + dim: int, + start: int, + end: int, + step: int, +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.slice_scatter(grad_input, grad_output, dim, start, end, step) + + +@register_decomposition(aten.slice.Tensor) +def slice_forward( + # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1 + self: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + ndim = self.dim() + if ndim == 0: + raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") + dim = utils.canonicalize_dim(self.dim(), dim) + sizes = list(self.size()) + strides = list(self.stride()) + + if step <= 0: + raise RuntimeError("slice step must be positive") + + start_val = start if start is not None else 0 + end_val = end if end is not None else sys.maxsize # 2^63 – 1 + + if start_val < 0: + start_val += sizes[dim] + + if end_val < 0: + end_val += sizes[dim] + + if start_val < 0: + start_val = 0 + elif start_val > sizes[dim]: + start_val = sizes[dim] + + if end_val < start_val: + end_val = start_val + elif end_val > sizes[dim]: + end_val = sizes[dim] + + storage_offset = self.storage_offset() + start_val * strides[dim] + len = end_val - start_val + sizes[dim] = (len + step - 1) // step + strides[dim] *= step + + if self.is_quantized: + raise NotImplementedError( + "Slice decomposition for quantized tensors aren't implemented" + ) + else: + return self.as_strided(sizes, strides, storage_offset) + + +@register_decomposition(aten.select_backward) +@out_wrapper() +def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int): + grad_input = grad_output.new_zeros(input_sizes) + return torch.select_scatter(grad_input, grad_output, dim, index) + + +@register_decomposition(aten.diagonal_backward) +@out_wrapper() +def diagonal_backward( + grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _cast_grad_to_input_dtype( + grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype +): + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input + + +@register_decomposition(aten._softmax_backward_data) +@out_wrapper("grad_input") +@compute_only_pw_cast_for_opmath +def _softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + new_grad_output = grad_output * output + grad_input = new_grad_output - output * torch.sum( + new_grad_output, dim=dim, keepdim=True + ) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, 
grad_input, input_dtype).contiguous() + + +@register_decomposition(aten._log_softmax_backward_data) +@out_wrapper() +@compute_only_pw_cast_for_opmath +def _log_softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + grad_input = grad_output - torch.exp(output) * torch.sum( + grad_output, dim=dim, keepdim=True + ) + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + +def _im2col_col2im_indices_along_dim( + input_d, kernel_d, dilation_d, padding_d, stride_d, device +): + """Utility function to implement im2col and col2im""" + blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1) + + arange_kw = partial(torch.arange, dtype=torch.int64, device=device) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1) + + # Broadcast and add kernel starting positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + return blocks_d_indices + kernel_grid + + +@register_decomposition(aten.im2col) +@out_wrapper() +def im2col( + input: Tensor, + kernel_size: List[int], + dilation: List[int], + padding: List[int], + stride: List[int], +) -> Tensor: + torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") + torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater {'than' zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(dilation, "padding", strict=False) + check_positive(stride, "stride") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4) and all(d != 0 for d in shape[-3:]), + lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + output_size = tuple( + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + shape[-2:], padding, dilation, kernel_size, stride + ) + ) + torch._check( + all(c > 0 for c in output_size), + lambda: f"Given an input with spacial size {tuple(shape[-2:])}, " + f"kernel_size={kernel_size}, dilation={dilation}, " + f"padding={padding}, stride={stride}, " + "the calculated shape of the array of sliding blocks " + f"is {output_size}, but its components must be at least one.", + ) + batched_input = ndim == 4 + if not batched_input: + input = input.unsqueeze(0) + + batch_dim, channel_dim, input_h, input_w = input.shape + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + blocks_row_indices = _im2col_col2im_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + blocks_col_indices = _im2col_col2im_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom) + # ugh + padded_input = F.pad(input, (padding_w, 
padding_w, padding_h, padding_h)) + + blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1) + output = padded_input[:, :, blocks_row_indices, blocks_col_indices] + output = output.permute(0, 1, 2, 4, 3, 5) + num_blocks_row = blocks_row_indices.size(1) + num_blocks_col = blocks_col_indices.size(1) + output = output.reshape( + batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col + ) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.col2im) +@out_wrapper() +@pw_cast_for_opmath +def col2im( + input: Tensor, + output_size: List[int], + kernel_size: List[int], + dilation: List[int], + padding: List[int], + stride: List[int], +) -> Tensor: + torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") + torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "only 2D padding supported") + torch._check(len(stride) == 2, lambda: "only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(padding, "padding", strict=False) + check_positive(stride, "stride") + check_positive(output_size, "output_size") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (2, 3) and all(d != 0 for d in shape[-2:]), + lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + prod_kernel_size = kernel_size[0] * kernel_size[1] + torch._check( + shape[-2] % prod_kernel_size == 0, + lambda: "Expected size of input's first non-batch dimension to be divisible by the " + f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " + f"kernel_size={kernel_size}", + ) + col = [ + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + output_size, padding, dilation, kernel_size, stride + ) + ] + L = col[0] * col[1] + torch._check( + shape[-1] == L, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + torch._check( + L > 0, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + batched_input = ndim == 3 + if not batched_input: + input = input.unsqueeze(0) + + shape = input.shape + + out_h, out_w = output_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand + input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col) + input = input.permute(0, 1, 2, 4, 3, 5) + + indices_row = _im2col_col2im_indices_along_dim( + out_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + indices_row = _unsqueeze_to_dim(indices_row, 4) + indices_col = _im2col_col2im_indices_along_dim( + out_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + 
output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)] + output = input.new_zeros( + [shape[0], shape[1] // prod(kernel_size)] + output_padded_size + ) + idx = (None, None, indices_row, indices_col) + output = aten._unsafe_index_put(output, idx, input, accumulate=True) + output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h)) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.native_dropout_backward) +@out_wrapper() +def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! + # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r + + +@register_decomposition(aten.unfold_backward) +@out_wrapper() +def unfold_backward( + grad: Tensor, input_size: List[int], dimension: int, size: int, step: int +) -> Tensor: + if len(input_size) == 0: + return torch.squeeze_copy(grad, 0) + dim = utils.canonicalize_dim(len(input_size), dimension) + idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32) + idx = idx.unfold(0, size, step).flatten() + grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1) + # nb. At the moment this generates two kernels in triton + # It could potentially be fused into one call to scatter_reduce, + # in the case step <= size provided scatter_reduce generates 1 kernel + grad_input = grad.new_zeros(input_size) + index = (None,) * dim + (idx,) + return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous() + + +@register_decomposition(aten.logit_backward.default) +@pw_cast_for_opmath +def logit_backward( + grad_output: Tensor, self: Tensor, eps: Optional[float] = None +) -> Tensor: + if eps is not None: + lo = eps + hi = 1.0 - lo + return torch.where( + torch.logical_and(self >= lo, self <= hi), + grad_output / (self * (1.0 - self)), + 0.0, + ) + else: + return torch.where( + torch.logical_and(self >= 0.0, self <= 1.0), + grad_output / (self * (1.0 - self)), + self.new_full((), float("nan")), + ) + + +@register_decomposition(aten.dropout) +@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.dropout.default.py_impl(DispatchKey.Autograd) +def dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + return aten.native_dropout(input, p, train)[0] + else: + return input.clone() + + +@register_decomposition(aten.native_dropout) +@out_wrapper("out0", "out1") +def native_dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + if p == 1: + return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)) + if not input.dtype.is_floating_point: + raise RuntimeError( + "result type Float can't be cast to the desired output type Long" + ) + bool_mask = torch.rand_like(input) > p + res = bool_mask * input * float(1.0 / (1.0 - p)) + return (res, bool_mask) + else: + return (input, torch.ones_like(input, dtype=torch.bool)) + + +@register_decomposition(aten._softmax) +@out_wrapper() +def _softmax(x: Tensor, dim: int, half_to_float: bool): + # eager softmax returns a contiguous tensor. Ensure that decomp also returns + # a contiguous tensor. 
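+ # The computation below subtracts amax(x, dim) before exponentiating so that
+ # exp() cannot overflow for large inputs; softmax is unchanged by this shift.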
+ x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + unnormalized = torch.exp(x) + else: + x_max = torch.amax(x, dim, keepdim=True) + unnormalized = torch.exp(x - x_max) + result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten._log_softmax) +@out_wrapper() +def _log_softmax(x: Tensor, dim: int, half_to_float: bool): + # eager log_softmax returns a contiguous tensor. Ensure that decomp also + # returns a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + shifted = x - x_max + shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True)) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten.embedding) +@out_wrapper() +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + assert weight.dim() == 2, "'weight' must be 2-D" + # Nb. scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] + + +@register_decomposition(aten.embedding_dense_backward) +@out_wrapper() +def embedding_dense_backward( + grad_output: Tensor, + indices: Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +): + computation_dtype, result_dtype = utils.elementwise_dtypes( + grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + grad_output = grad_output.to(computation_dtype) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] + if scale_grad_by_freq: + counts = indices.new_zeros((num_weights,)) + ones = torch.ones_like(indices) + counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(-1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] + ) + return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to( + result_dtype + ) + + +def prod(x: List[int]): + r = 1 + for i in x: + r *= i + return r + + +def _pad_chunk( + tensors: List[Tensor], + dim: int, + num_chunks: int, +) -> List[Tensor]: + padded_tensors = [] + for tensor in tensors: + tensor_size = tensor.size() + pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks + if pad_along_dim != tensor_size[dim]: + # Use aten.constant_pad_nd instead of copy_ for functionalization + pad = [0] * 2 * (tensor.ndim - dim - 1) + [ + 0, + pad_along_dim - tensor_size[dim], + ] + tensor = aten.constant_pad_nd(tensor, pad, 0) + view_size = tensor_size[:dim] + 
torch.Size([num_chunks, -1]) + padded_tensors.append(tensor.view(view_size)) + return padded_tensors + + +def have_same_ndims(tensors: List[Tensor]): + ndim = tensors[0].ndim + for tensor in tensors: + if tensor.ndim != ndim: + return False + return True + + +def leading_dimension_matches(tensors: List[Tensor], dim: int): + leading_dim_sizes = tensors[0].size()[:dim] + for tensor in tensors: + torch._check( + tensor.size()[:dim] == leading_dim_sizes, + lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors", + ) + + +def _preprocess_chunk_cat_inputs( + tensors: List[Tensor], + dim: int, + num_chunks: int, +): + torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks") + torch._check( + len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list" + ) + expected_dtype = tensors[0].dtype + expected_device = tensors[0].device + for tensor in tensors: + torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor") + torch._check( + tensor.dtype == expected_dtype, + lambda: "_chunk_cat expects all input tensors with the same dtype", + ) + torch._check( + tensor.device == expected_device, + lambda: "_chunk_cat expects all inputs tensors on the same device", + ) + if have_same_ndims(tensors): + dim = utils.canonicalize_dim(tensors[0].dim(), dim) + else: + torch._check( + dim >= 0, + lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims", + ) + for tensor in tensors: + torch._check( + dim < tensor.ndim, + lambda: "_chunk_cat expects dim < ndim for all input tensors", + ) + leading_dimension_matches(tensors, dim) + return dim + + +@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out]) +def _chunk_cat( + tensors: List[Tensor], + dim: int, + num_chunks: int, + out: Optional[Tensor] = None, +) -> Tensor: + dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks) + padded_tensors = _pad_chunk(tensors, dim, num_chunks) + if out is None: + return torch.cat(padded_tensors, dim + 1) + else: + torch.cat(padded_tensors, dim + 1, out=out) + return out + + +@register_decomposition(aten.split_with_sizes) +def split_with_sizes( + self: Tensor, split_sizes: List[int], dim: int = 0 +) -> List[Tensor]: + # NB: Perform the check_is_size tests first so that the + # sum test does not try to do a replacement + for i in range(len(split_sizes)): + torch._check_is_size( + split_sizes[i], + lambda: "split_with_sizes expects split_sizes have only non-negative entries", + ) + torch._check_with( + ValueError, + sum(split_sizes) == self.shape[dim], + lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", + ) + num_splits = len(split_sizes) + splits = [] + start_idx = 0 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import expect_true + + for i in range(num_splits): + length = split_sizes[i] + # We know this is true thanks to the sum, but this assertion helps + # out our internal reasoning + expect_true(start_idx + length <= self.shape[dim]) + splits.append(self.narrow(dim, start_idx, length)) + start_idx += length + return splits + + +# out_wrapper currently does not allow optional outputs +@register_decomposition( + [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out] +) +def split_with_sizes_copy( + self: Tensor, + split_sizes: List[int], + dim: int = 0, + out: Optional[List[Tensor]] = None, +) -> Optional[List[Tensor]]: + splits = split_with_sizes(self, split_sizes, dim=dim) + if out is 
None: + return [s.clone(memory_format=torch.contiguous_format) for s in splits] + else: + for output, split in zip(out, splits): + _maybe_resize_out(output, split.shape) + _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True) + return None + + +@register_decomposition(aten.unsafe_split.Tensor) +def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: + return aten.split.Tensor(input, split_size, dim) + + +@register_decomposition(aten.unsafe_split_with_sizes.default) +def unsafe_split_with_sizes( + input: Tensor, split_sizes: List[int], dim: int = 0 +) -> Tuple[Tensor, ...]: + return aten.split_with_sizes.default(input, split_sizes, dim) + + +@register_decomposition(aten.split.Tensor) +def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: + input_sizes = self.shape + dim_size = input_sizes[dim] + if split_size == 0: + assert dim_size == 0 + return (self,) + chunks = (dim_size + split_size - 1) // split_size + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import guard_int + + chunks = guard_int(chunks) + split_sizes = [split_size for i in range(chunks)] + split_sizes[-1] = split_size - (split_size * chunks - dim_size) + return torch.split(self, split_sizes, dim) + + +@aten.tensor_split.tensor_indices_or_sections.py_impl( + DispatchKey.CompositeImplicitAutograd +) +def tensor_split_tensor_indices_or_sections_py_impl( + self: Tensor, + tensor_indices_or_sections: Tensor, + dim: int = 0, +) -> Tuple[Tensor, ...]: + assert tensor_indices_or_sections.device.type == "cpu" + assert tensor_indices_or_sections.dtype == torch.int64 + split_dim = tensor_indices_or_sections.dim() + torch._check( + split_dim == 1 or split_dim == 0, + lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional " + f"or one-dimensional tensor, but got a tensor with {split_dim} dims", + ) + if split_dim == 0: + sections = tensor_indices_or_sections.item() + assert isinstance(sections, IntLike) + return self.tensor_split(sections, dim) + else: + indices = [i.item() for i in tensor_indices_or_sections] + return self.tensor_split(indices, dim) + + +# TODO: this doesn't appear to have enough precision in bfloat16 +@register_decomposition(aten.addmm) +@out_wrapper() +@pw_cast_for_opmath +def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mm(mat1, mat2) + if beta == 0: + return out + + # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition. + # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided. + # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition. + # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input. + # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases. + # This implementation is not ideal, and we should revisit this when we have a better solution. 
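+    # Illustrative sketch of the note above (values assumed for illustration
+    # only): if `self` is non-contiguous, e.g. self = torch.randn(3, 2).t(),
+    # then `beta * self + out` would tend to inherit self's transposed strides,
+    # while `out + beta * self` follows the contiguous strides of the mm result.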
+ return out + beta * self + + +@register_decomposition(aten._addmm_activation) +@out_wrapper() +@pw_cast_for_opmath +def _addmm_activation( + self: Tensor, + mat1: Tensor, + mat2: Tensor, + beta: int = 1, + alpha: int = 1, + use_gelu: bool = False, +): + out = addmm(self, mat1, mat2, beta, alpha) + if use_gelu: + if self.is_cuda: + return aten.gelu(out, approximate="tanh") + else: + return aten.gelu(out) + return aten.relu(out) + + +@register_decomposition(aten.addmv) +@out_wrapper() +@pw_cast_for_opmath +def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mv(mat1, vec) + if beta == 0: + return out + return out + beta * self + + +@register_decomposition(aten.native_group_norm_backward.default) +@pw_cast_for_opmath +def native_group_norm_backward( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + utils.check_same_device( + grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False + ) + utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) + utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) + torch._check( + input.numel() == N * C * HxW, + lambda: f"Expect input to have { N * C * HxW} elements", + ) + torch._check( + mean.shape == (N, group), + lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", + ) + torch._check( + gamma is None or gamma.numel() == C, + lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", + ) + + cpg, _rem = divmod(C, group) + torch._check( + _rem == 0, + lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}", + ) + + # Compute Internal gradients + ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2]) + db = grad_output.view(N, C, HxW).sum(dim=[2]) + + d_input: Optional[Tensor] = None + d_gamma: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + s = 1.0 / (HxW * cpg) + if gamma is not None: + ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + gamma.reshape(1, group, cpg), + ) + else: + ds_val = ds.reshape(N, group, cpg).sum(2) + db_val = db.reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + torch.ones((1, group, cpg), device=rstd.device), + ) + c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s + c3 = -c2 * mean - db_val * rstd * s + + c1 = c1.unsqueeze(-1) + c2 = _unsqueeze_to_dim(c2, 4) + c3 = _unsqueeze_to_dim(c3, 4) + d_input = ( + torch.mul(grad_output.reshape(N, group, cpg, HxW), c1) + + torch.mul(input.reshape(N, group, cpg, HxW), c2) + + c3 + ) + d_input = d_input.reshape(input.shape).to(input.dtype) + if output_mask[1]: + d_gamma = ( + ( + (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1)) + * rstd.unsqueeze(-1) + ) + .sum(dim=[0]) + .reshape(C) + ) + if output_mask[2]: + d_bias = db.sum(dim=[0]) + + return (d_input, d_gamma, d_bias) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_group_norm_backward.out) +def native_group_norm_backward_out( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: 
Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_group_norm_backward( + grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]: + if x is not None: + return x.to(dtype) + return x + + +# TODO: Take a closer look at the type promotion semantics +@register_decomposition(aten.native_layer_norm_backward.default) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + grad_out_cast, input_cast, weight_cast, bias_cast = ( + x.to(computation_dtype).contiguous() if x is not None else x + for x in (grad_out, input, weight, bias) + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: List[int] = [] + outer_dim_indices: List[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + input.new_zeros(input_shape[axis:]) if output_mask[2] else None, + ) + mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + x_hat = (input_cast - mean) * rstd + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + + inner = a - b - c3 + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + d_input = (rstd / N) * inner + + if output_mask[1] and weight_cast is not None: + if len(outer_dim_indices) > 0: + d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False) + else: + d_weight = grad_out_cast * x_hat + + if output_mask[2] and bias_cast is not None: + if len(outer_dim_indices) > 0: + d_bias = torch.sum(grad_out_cast, outer_dim_indices, False) + else: + d_bias = grad_out_cast.clone() + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + _maybe_cast(d_bias, input.dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_layer_norm_backward.out) +def native_layer_norm_backward_out( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: 
Optional[Tensor], + output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_layer_norm_backward( + grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def native_batch_norm_helper( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, + functional: bool, +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + reduction_dims = [0] + list(range(2, input.dim())) + computation_dtype = utils.get_computation_dtype(input.dtype) + new_running_mean = running_mean + new_running_var = running_var + if training: + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = input.to(dtype=computation_dtype) + biased_var, mean = torch.var_mean( + input_acc, dim=reduction_dims, correction=0, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + + output = (input - mean) * rstd + + save_mean = torch.squeeze(mean, reduction_dims) + save_rstd = torch.squeeze(rstd, reduction_dims) + if running_mean is not None: + new_running_mean = momentum * save_mean + (1 - momentum) * running_mean + if not functional: + running_mean.copy_(new_running_mean) + if running_var is not None: + n = input.numel() / input.shape[1] + # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction + # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose + # numerics probably don't matter. 
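+    # The batch statistics above use correction=0 (divide by n), while the
+    # running_var buffer stores the unbiased estimate (divide by n - 1), hence
+    # the n / (n - 1) rescale below; e.g. with n = 4 samples per channel the
+    # biased variance is scaled by 4/3 before entering the running average.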
+ squeezed_var = torch.squeeze(biased_var, reduction_dims) + unbiased_var = squeezed_var * (n / (n - 1)) + new_running_var = momentum * unbiased_var + (1 - momentum) * running_var + if not functional: + running_var.copy_(new_running_var) + else: + assert running_mean is not None and running_var is not None + running_mean = running_mean.to(dtype=computation_dtype, copy=True) + new_running_mean = running_mean + running_var = running_var.to(dtype=computation_dtype, copy=True) + new_running_var = running_var + mean = running_mean + invstd = 1 / (torch.sqrt(running_var + eps)) + # Very annoying inconsistency where CPU and CUDA give different shapes + if input.device.type != "cpu": + save_mean = running_mean + save_rstd = invstd + else: + save_mean = input.new_zeros((0,)) + save_rstd = input.new_zeros((0,)) + mean = _unsqueeze_to_dim(mean, input.dim() - 1) + invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) + output = (input - mean) * invstd + + if weight is not None: + weight = weight.flatten() + weight = _unsqueeze_to_dim(weight, input.dim() - 1) + output = output * weight + + if bias is not None: + bias = bias.flatten() + bias = _unsqueeze_to_dim(bias, input.dim() - 1) + output = output + bias + + if input.device.type == "cpu": + save_mean = save_mean.to(dtype=input.dtype) + save_rstd = save_rstd.to(dtype=input.dtype) + return ( + output.to(dtype=input.dtype), + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) + + +@register_decomposition(aten.native_batch_norm) +@out_wrapper("out", "save_mean", "save_invstd") +def native_batch_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm +# with our new correctly schema'd _native_batch_norm_legit and its variants, but +# we cannot do that immediately in the C++ because it would be forwards incompatible +# with some mobile use cases. +# +# Since this change is most impactful for aot autograd/functionalization, we simply +# register this decomposition on the Autograd key for the python dispatcher (which is +# currently only used by aot autograd/functionalization and no one else, really). +# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm +# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). +@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd) +@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def native_batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + if running_mean is None and running_var is None: + return aten._native_batch_norm_legit( + input, weight, bias, training, momentum, eps + ) + if running_mean is None: + raise RuntimeError( + "running_mean is None, but running_var is provided. " + "They should both be None or both be provided." 
+ ) + if running_var is None: + raise RuntimeError( + "running_var is None, but running_mean is provided. " + "They should both be None or both be provided." + ) + if training: + # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg. + return aten._native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + else: + return aten._native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]: + dim_size = tensor.size(dim) + split_size = (dim_size + chunks - 1) // chunks + + if split_size == 0 and dim_size == 0: + split_sizes = [split_size for _ in chunks] + split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim) + return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim) + + +@register_decomposition(aten._native_batch_norm_legit_no_training.default) +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + return aten._native_batch_norm_legit.default( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + ) + + +@register_decomposition(aten._native_batch_norm_legit.default) +def _native_batch_norm_legit( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit.no_stats) +def _native_batch_norm_legit_no_stats( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, None, None, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit_functional.default) +def _native_batch_norm_legit_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, True + ) + assert new_running_mean is not None, "new_running_mean should not be None" + assert new_running_var is not None, "new_running_var should not be None" + return output, save_mean, save_rstd, new_running_mean, new_running_var + + +@register_decomposition(aten._fused_dropout) +@out_wrapper("out0", "out1") +@pw_cast_for_opmath +def _fused_dropout_decomposition(input, p, generator=None): + assert generator is None + mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) + res = mask.type_as(input) * input * (1.0 / p) + return (res, mask) + + +def 
device_hint(tensor): + if isinstance(tensor, torch._subclasses.FakeTensor): + return tensor.fake_device + else: + return None + + +@register_decomposition(aten._to_copy) +@out_wrapper() +def _to_copy( + x: Tensor, + *, + dtype: Optional[torch.dtype] = None, + layout=None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: Optional[torch.memory_format] = None, +): + assert not layout or layout == torch.strided, "TODO" + assert not pin_memory, "TODO" + if device is None and dtype is None and memory_format is None: + return x.clone() + dtype_converted = False + common_device = device_hint(x) + + if device is not None and device != x.device: + # avoid conversions on cpu + if dtype is not None and device.type == "cpu": + x = torch._prims.convert_element_type(x, dtype) + dtype_converted = True + x = torch._prims.device_put(x, device) + + if dtype is not None and not dtype_converted: + x = torch._prims.convert_element_type(x, dtype) + dtype_converted = True + + if memory_format is not None: # no ref/prim for memory format + return torch.clone(x, memory_format=memory_format) + return x + + +# Questionable decompositions +# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. +# Note that this decomposition causes issues with in-place ops +@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) +@out_wrapper() +def nop_decomposition(x): + return aten.alias(x) + + +# Also register to the Autograd dispatch key, so this decomp can run above autograd. +# native_batch_norm needs to decompose into other ops before autograd. +@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.cudnn_batch_norm) +@out_wrapper("out0", "out1", "out2", "out3") +def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + # Cudnn return running mean and variance when training is True + if training: + return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + input.new_zeros((0,), dtype=torch.uint8), + ) + + +def _broadcast_batch_norm_backward(x, broadcast_mask): + for axis, mask in enumerate(broadcast_mask): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]): + x = x.unsqueeze(axis) + return x + + +@register_decomposition(aten.native_batch_norm_backward.default) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_dtype = input.dtype + if weight is not None: + weight_dtype = weight.dtype + else: + weight_dtype = input_dtype + computation_dtype = utils.get_computation_dtype(input.dtype) + ( + grad_out_cast, + input_cast, + weight_cast, + running_mean_cast, + running_var_cast, + save_mean_cast, + save_invstd_cast, + ) = ( + x.to(computation_dtype) if x is not None else x + for x in ( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, 
+ save_invstd, + ) + ) + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(list(input_shape)) / input_shape[axis] + mean = save_mean_cast + invstd = save_invstd_cast + if train: + assert save_mean_cast is not None and save_invstd_cast is not None + else: + assert running_mean_cast is not None and running_var_cast is not None + mean = running_mean_cast + invstd = torch.rsqrt(running_var_cast + eps) + + broadcast_mask: List[int] = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] + dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator] + + grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask) + proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator] + + if weight_cast is None: + grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] + else: + grad_scale = _broadcast_batch_norm_backward( + invstd * weight_cast, broadcast_mask + ) + + if train: + proj = (input_cast - mean) * proj_scale # type: ignore[operator] + grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out_cast * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + else: + grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp + + return ( + grad_input.to(input_dtype), + _maybe_cast(grad_weight, weight_dtype), + _maybe_cast(grad_bias, weight_dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_batch_norm_backward.out) +def native_batch_norm_backward_out( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + result = native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +@register_decomposition(aten.cudnn_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def cudnn_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, + reserveSpace: Tensor, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + 
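+# Note on the decomposition below: when the input spatial sizes are divisible
+# by the requested output sizes, every pooling window has the same shape, so
+# adaptive average pooling reduces to a plain avg_pool2d (the fast path at the
+# top of the function).  A hedged sketch of that equivalence, with values
+# chosen purely for illustration:
+#
+#   >>> x = torch.arange(16.).reshape(1, 1, 4, 4)
+#   >>> torch.allclose(
+#   ...     torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)),
+#   ...     torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2),
+#   ... )
+#   True
+#
+# The general (non-divisible) case instead gathers per-window indices and
+# masks out positions that fall beyond the shorter windows.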
+@register_decomposition(aten._adaptive_avg_pool2d) +@out_wrapper() +@pw_cast_for_opmath +def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): + # Preconditions + device = input.device + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4), + lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", + ) + for d in input.shape[-2:]: + torch._check( + d != 0, + lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}.", + ) + + # Optimisation (we should also do this in the kernel implementation) + if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0: + stride = tuple(i // o for i, o in zip(shape[-2:], output_size)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride) + ) + return torch.nn.functional.avg_pool2d(input, kernel, stride) + + def start_index(a, b, c): + return torch.div(a * c, b, rounding_mode="trunc") + + def end_index(a, b, c): + return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc") + + def compute_idx(in_size, out_size): + orange = torch.arange(out_size, device=device, dtype=torch.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = torch.arange(maxlength, device=device, dtype=torch.int64) + idx = i0.unsqueeze(-1) + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + # TODO make minimum accept scalars + maxval = torch.scalar_tensor( + in_size - 1, dtype=idx.dtype, device=idx.device + ) + idx = torch.minimum(idx, maxval) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + # length is not None if it's constant, otherwise we'll need to compute it + idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2]) + idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1]) + + vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw] + # Shortcut for the simpler case + if not adaptive_h and not adaptive_w: + return torch.mean(vals, dim=(-3, -1)) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, IntLike): + return vals, length + else: + # zero-out the things we didn't really want to select + assert dim < 0 + # hack + mask = range_max >= length.unsqueeze(-1) + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + vals = torch.masked_fill(vals, mask, 0.0) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + vals, length_h = maybe_mask( + vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2 + ) + vals, length_w = maybe_mask( + vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1 + ) + + # We unroll the sum as we assume that the kernels are going to be small + ret = None + for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])): + if ret is None: + ret = vals[..., i, :, j] + else: + ret = ret + vals[..., i, :, j] + return ret / (length_h * length_w) + + 
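+# Note on the index_add decompositions below: they lower to index_put with
+# accumulate=True, so repeated indices sum their contributions, matching the
+# eager op.  An informal example of the semantics being reproduced (values
+# chosen purely for illustration):
+#
+#   >>> x = torch.zeros(3, 2)
+#   >>> x.index_add(0, torch.tensor([0, 0, 2]), torch.ones(3, 2))
+#   tensor([[2., 2.],
+#           [0., 0.],
+#           [1., 1.]])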
+@register_decomposition(aten.index_add_) +def index_add_( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha) + + +@register_decomposition(aten.index_add) +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +def _index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + inplace: bool, + alpha: NumberType = 1, +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + index_size = index.size(0) if index.ndim == 1 else 1 + tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1 + torch._check( + tensor_size == index_size, + lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}", + ) + if alpha != 1: + python_type = utils.dtype_to_type(x.dtype) + torch._check( + python_type == bool + or utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + tensor = tensor * alpha + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor, accumulate=True) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +@register_decomposition(aten.pad_sequence.default) +@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def pad_sequence(sequences, batch_first=False, padding_value=0.0): + torch._check(len(sequences) > 0, lambda: "received an empty list of sequences") + sequences_size = len(sequences) + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max(x.size(0) for x in sequences) + if batch_first: + out_dims = (sequences_size, max_len) + else: + out_dims = (max_len, sequences_size) + out_dims = out_dims + trailing_dims + out = sequences[0].new_full(out_dims, padding_value) + dim_paddings = (0, 0) * len(trailing_dims) + for i in range(sequences_size): + currseq = sequences[i] + row = aten.constant_pad_nd( + currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value + ) + if batch_first: + out = aten.select_scatter(out, row, dim=0, index=i) + else: + out = aten.select_scatter(out, row, dim=1, index=i) + return out + + +@register_decomposition(aten.index_copy_) +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=True) + + +@register_decomposition(aten.index_copy) +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=False) + + +def _index_copy( + x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + index = index.unsqueeze(0) if index.ndim == 0 else index + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace 
else aten.index_put + out = index_put(x1, idx, tensor) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +# nb: Should use acc_t, not op_math +@register_decomposition(aten.log_sigmoid_forward) +@out_wrapper("output", "buffer") +@pw_cast_for_opmath +def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +@register_decomposition(aten.uniform) +@out_wrapper() +def uniform( + x: Tensor, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + generator: Optional[torch.Generator] = None, +): + return prims._uniform_helper( + x.shape, + low=sym_float(low), + high=sym_float(high), + dtype=x.dtype, + device=x.device, + generator=generator, + ) + + +@register_decomposition(aten.uniform_) +def uniform_(self, low=0, high=1, generator=None): + return self.copy_(uniform(self, low, high, generator)) + + +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + torch._check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + torch._check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(scale_factors) == spatial_dimensions, lambda: "") + output_size = [] + for i, s in enumerate(scale_factors): + if int(s) == s: + output_size.append(input_size[i + 2] * int(s)) + else: + output_size.append(sym_int(input_size[i + 2] * s)) + return output_size + torch._check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + +@register_decomposition(aten.upsample_nearest1d.vec) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) +def upsample_nearest1d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale = get_scale_value(scale_factors, 0) + + return aten.upsample_nearest1d.default(input, osize, scale) + + +@register_decomposition(aten._upsample_nearest_exact1d.vec) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact1d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale = get_scale_value(scale_factors, 0) + + return aten._upsample_nearest_exact1d.default(input, osize, scale) + + +@register_decomposition(aten.upsample_nearest2d.vec) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd) +def upsample_nearest2d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + + return aten.upsample_nearest2d.default(input, osize, 
scale_h, scale_w) + + +@register_decomposition(aten._upsample_nearest_exact2d.vec) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact2d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + + return aten._upsample_nearest_exact2d.default(input, osize, scale_h, scale_w) + + +@register_decomposition(aten.upsample_nearest3d.vec) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) +def upsample_nearest3d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_d = get_scale_value(scale_factors, 0) + scale_h = get_scale_value(scale_factors, 1) + scale_w = get_scale_value(scale_factors, 2) + + return aten.upsample_nearest3d.default(input, osize, scale_d, scale_h, scale_w) + + +@register_decomposition(aten._upsample_nearest_exact3d.vec) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact3d_vec(input, output_size, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_d = get_scale_value(scale_factors, 0) + scale_h = get_scale_value(scale_factors, 1) + scale_w = get_scale_value(scale_factors, 2) + + return aten._upsample_nearest_exact3d.default( + input, osize, scale_d, scale_h, scale_w + ) + + +def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): + # For each dim in output_size, compute the set of input indices used + # to produce the upsampled output. 
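+    # Hedged example (sizes chosen purely for illustration): upsampling a
+    # dimension of size 2 to size 4 gives scale = 2 / 4 = 0.5, so the gathered
+    # input indices are floor(([0, 1, 2, 3] + offset) * 0.5) = [0, 0, 1, 1],
+    # i.e. each source element is repeated twice.  For integer upsampling
+    # factors exact=False and exact=True agree; they differ for fractional
+    # scales (and when downsampling), where the half-pixel offset changes
+    # which source element an output pixel snaps to.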
+ indices = [] + num_spatial_dims = len(output_size) + offset = 0.5 if exact else 0.0 + + for d in range(num_spatial_dims): + # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp + # + # Indices are computed as following: + # scale = isize / osize + # Case: exact=False + # input_index = floor(output_index * scale) + # Same as OpenCV INTER_NEAREST + # + # Case: exact=False + # index_f32 = (output_index + 0.5) * scale - 0.5 + # input_index = round(index_f32) + # Same as Pillow and Scikit-Image/Scipy ndi.zoom + osize = output_size[d] + isize = input.shape[-num_spatial_dims + d] + scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize + + output_indices = torch.arange(osize, dtype=torch.float32, device=input.device) + input_indices = ((output_indices + offset) * scale).to(torch.int64) + for _ in range(num_spatial_dims - 1 - d): + input_indices = input_indices.unsqueeze(-1) + indices.append(input_indices) + return tuple(indices) + + +@register_decomposition(aten.upsample_nearest1d.default) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def upsample_nearest1d( + input: Tensor, + output_size: List[int], + scales: Optional[float] = None, +) -> Tensor: + (l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,)) + return aten._unsafe_index(input, (None, None, l_indices)) + + +@register_decomposition(aten._upsample_nearest_exact1d.default) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def _upsample_nearest_exact1d( + input: Tensor, + output_size: List[int], + scales: Optional[float] = None, +) -> Tensor: + (l_indices,) = _compute_upsample_nearest_indices( + input, output_size, (scales,), exact=True + ) + return aten._unsafe_index(input, (None, None, l_indices)) + + +def _upsample_nearest2d_common(input, h_indices, w_indices): + result = aten._unsafe_index(input, (None, None, h_indices, w_indices)) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + _, n_channels, _, _ = input.shape + if input.device.type == "cuda" and n_channels < 4: + memory_format = torch.contiguous_format + + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.upsample_nearest2d.default) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def upsample_nearest2d( + input: Tensor, + output_size: List[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + h_indices, w_indices = _compute_upsample_nearest_indices( + input, output_size, (scales_h, scales_w) + ) + return _upsample_nearest2d_common(input, h_indices, w_indices) + + +@register_decomposition(aten._upsample_nearest_exact2d.default) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def _upsample_nearest_exact2d( + input: Tensor, + output_size: List[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + h_indices, w_indices = _compute_upsample_nearest_indices( + input, output_size, (scales_h, scales_w), exact=True + ) + return _upsample_nearest2d_common(input, h_indices, w_indices) + + +@register_decomposition(aten.upsample_nearest3d.default) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def upsample_nearest3d( + input: Tensor, + 
output_size: List[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + d_indices, h_indices, w_indices = _compute_upsample_nearest_indices( + input, output_size, (scales_d, scales_h, scales_w) + ) + result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices)) + + return result + + +@register_decomposition(aten._upsample_nearest_exact3d.default) +@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) +@pw_cast_for_opmath +def _upsample_nearest_exact3d( + input: Tensor, + output_size: List[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + d_indices, h_indices, w_indices = _compute_upsample_nearest_indices( + input, output_size, (scales_d, scales_h, scales_w), exact=True + ) + result = aten._unsafe_index(input, (None, None, d_indices, h_indices, w_indices)) + + return result + + +def gather_params(params, has_biases, has_projections): + if has_biases and has_projections: + group_size = 5 + elif has_biases: + group_size = 4 + elif has_projections: + group_size = 3 + else: + group_size = 2 + + assert len(params) % group_size == 0, len(params) + return [ + tuple(params[i : i + group_size]) for i in range(0, len(params), group_size) + ] + + +def params_hiddens(params, hiddens, i, bidirectional): + if bidirectional: + cur_params, cur_hidden = params[2 * i], hiddens[2 * i] + bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1] + else: + cur_params, cur_hidden = params[i], hiddens[i] + bidir_params, bidir_hidden = None, None + + return cur_params, cur_hidden, bidir_params, bidir_hidden + + +def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens): + assert last_batch_size > batch_size + hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size)) + return cur_hidden.narrow(0, 0, batch_size) + + +def update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, batch_size, inp_hidden +): + if last_batch_size == batch_size: + return cur_hidden + assert last_batch_size < batch_size + return torch.concat( + ( + cur_hidden, + inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size), + ) + ) + + +def one_layer_rnn_data( + inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False +): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + step_output = [] + hiddens: List[torch.Tensor] = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + cur_hidden = hidden.narrow(0, 0, last_batch_size) + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + for inp in split_inp: + i = inp.shape[0] + + if last_batch_size == i: + pass # don't update cur_hidden + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + elif reverse: + cur_hidden = update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, i, hidden + ) + else: + cur_hidden = update_hidden_for_packed( + cur_hidden, last_batch_size, i, hiddens + ) + + cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + last_batch_size = i + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + else: + hiddens.append(cur_hidden) + hiddens.reverse() + + out = torch.cat(step_output, 0) + hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden + return out, 
hidden_out + + +def rnn_cell(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def rnn_cell_data(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + i = F.linear(i, ih_weight, ih_bias) + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + cur_hidden = hidden.unsqueeze(0) + step_output = [] + for i in precomputed_input: + cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, cur_hidden.squeeze(0) + + +def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + w0 = params[0] + w1 = params[1] + if has_biases: + w2 = params[2] + w3 = params[3] + else: + w2 = torch.zeros(w0.size()) + w3 = torch.zeros(w1.size()) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + batch_sizes: List[int] = [] + mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2 + hidden_size = hx.size(2) + num_layers = 1 + + # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here + bidirectional = False + batch_first = False + + train = False + # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here. 
+ # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous(); + inp = inp.contiguous() + hx = hx.contiguous() + cx = cx.contiguous() + outputs = torch.ops.aten.mkldnn_rnn_layer.default( + inp, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ) + y, hy, cy = outputs[0], outputs[1], outputs[2] + return y, (hy.squeeze(0), cy.squeeze(0)) + + +def _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, +): + input = input.transpose(0, 1) if batch_first else input + final_hiddens = [] + + for i in range(num_layers): + cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens( + params, hidden, i, bidirectional + ) + dropout = dropout if (train and num_layers < i - 1) else 0.0 + fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases) + final_hiddens.append(fwd_hidden) + + if bidirectional: + bwd_inp, bwd_hidden = layer_fn( + input, bidir_hidden, bidir_params, has_biases, reverse=True + ) + final_hiddens.append(bwd_hidden) + + if bidirectional: + input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined] + else: + input = fwd_inp + + if dropout != 0 and train and i < num_layers - 1: + input = torch.dropout(input, dropout, train=True) + + input = input.transpose(0, 1) if batch_first else input + return input, final_hiddens + + +@register_decomposition(aten.rnn_tanh.input) +@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd) +def rnn_tanh_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.input) +@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.input.py_impl(DispatchKey.Autograd) +def rnn_relu_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.data) +@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.data.py_impl(DispatchKey.Autograd) +def rnn_relu_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.relu), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_tanh.data) +@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd) 
+@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd) +def rnn_tanh_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.tanh), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim): + gates = F.linear(hx, hh_weight, hh_bias) + inp + chunked_gates = gates.chunk(4, chunk_dim) + in_gate = chunked_gates[0].sigmoid() + forget_gate = chunked_gates[1].sigmoid() + cell_gate = chunked_gates[2].tanh() + out_gate = chunked_gates[3].sigmoid() + cy = forget_gate * cx + (in_gate * cell_gate) + hy = out_gate * cy.tanh() + hy = hy if hr_weight is None else F.linear(hy, hr_weight, None) + + return hy, cy + + +def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + step_output = [] + for inp in precomputed_input: + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2) + step_output.append(hx) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, (hx.squeeze(1), cx.squeeze(1)) + + +def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + step_output = [] + hiddens = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + + orig_hx = hidden[0] + orig_cx = hidden[1] + hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow( + 0, 0, last_batch_size + ) + + for inp in split_inp: + i = inp.shape[0] + inp = F.linear(inp, ih_weight, ih_bias) + + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + if i < last_batch_size: + hiddens.append( + ( + hx.narrow(0, i, last_batch_size - i), + cx.narrow(0, i, last_batch_size - i), + ) + ) + hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i) + + # this will only happen when reverse=True + if i > last_batch_size: + hx = torch.concat( + (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + cx = torch.concat( + (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1) + last_batch_size = i + step_output.append(hx) + + if reverse: + step_output.reverse() + hidden_out = (hx, cx) + else: + hiddens.append((hx, cx)) + hiddens.reverse() + hidden0, hidden1 = zip(*hiddens) + hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0) + + out = torch.cat(step_output, 0) + return out, hidden_out + + +def 
select_one_layer_lstm_function(input, hx, params): + r"""Check whether we could use decompose lstm with mkldnn_rnn_layer. + All the below conditions need to be met: + * ``torch._C._get_mkldnn_enabled()`` returns ``True``. + * All the input args are on CPU. + * The dtypes of args are either torch.float or torch.bfloat16. + * Inference. + * ``has_projections`` returns ``False``. + + Args: + * input: the input sequence to LSTM + * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM + * params: the weight and bias tensors of LSTM + """ + + def use_mkldnn(input, hx, params): + if not torch._C._get_mkldnn_enabled(): + return False + + tensors = [input] + list(hx) + list(chain.from_iterable(params)) + devices = {t.device for t in tensors} + if len(devices) != 1: + return False + + device = devices.pop() + if device != torch.device("cpu"): + return False + # With autocast, possible to have mixed dtype here + dtypes = {t.dtype for t in tensors} + for dtype in dtypes: + if dtype not in [torch.float, torch.bfloat16]: + return False + + if input.requires_grad: + return False + + has_projections = hx[0].size(2) != hx[1].size(2) + if has_projections: + return False + + return True + + # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm + # will expand over the seq_len dim + if use_mkldnn(input, hx, params): + return mkldnn_one_layer_lstm + else: + return one_layer_lstm + + +@register_decomposition(aten.lstm.input) +@aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.input.py_impl(DispatchKey.Autograd) +def lstm_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + layer_fn = select_one_layer_lstm_function(input, hx, params) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +@register_decomposition(aten.lstm.data) +@aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.data.py_impl(DispatchKey.Autograd) +def lstm_data_impl( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_lstm_data, batch_sizes=batch_sizes), + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = inp.chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = 
F.linear(inp, ih_weight, ih_bias).chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +@register_decomposition(aten.gru.data) +@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.data.py_impl(DispatchKey.Autograd) +def gru_impl_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.gru.input) +@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.input.py_impl(DispatchKey.Autograd) +def gru_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=gru_cell), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten._upsample_bilinear2d_aa.vec) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bilinear2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten._upsample_bicubic2d_aa.vec) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bicubic2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten.upsample_bilinear2d.vec) +@register_decomposition(aten.upsample_trilinear3d.vec) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_linear_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = scale_factors if scale_factors else [None] * len(osize) + return _upsample_linear(input, osize, align_corners, scales) + + 
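+# --- Editor's note (illustrative sketch, not part of the upstream file) ---
+# The decompositions above are normally consumed through a decomposition
+# table during tracing rather than called directly. The commented sketch
+# below shows one way to exercise the bilinear-upsample path; it assumes
+# only the public helpers torch._decomp.get_decompositions and
+# torch.fx.experimental.proxy_tensor.make_fx, and the shapes/scale factor
+# are arbitrary example values.
+#
+#     import torch
+#     import torch.nn.functional as F
+#     from torch._decomp import get_decompositions
+#     from torch.fx.experimental.proxy_tensor import make_fx
+#
+#     aten = torch.ops.aten
+#     table = get_decompositions(
+#         [aten.upsample_bilinear2d.vec, aten.upsample_bilinear2d.default]
+#     )
+#
+#     def f(x):
+#         return F.interpolate(x, scale_factor=2.0, mode="bilinear")
+#
+#     # Tracing with the table should replace the upsample op with the
+#     # arange/_unsafe_index arithmetic emitted by the _upsample_linear_vec
+#     # decomposition above.
+#     gm = make_fx(f, decomposition_table=table)(torch.randn(1, 3, 8, 8))
+#     print(gm.graph)
+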
+@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out]) +@out_wrapper() +def upsample_linear1d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_w]) + + +@register_decomposition( + [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out] +) +@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def upsample_bilinear2d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w]) + + +@register_decomposition( + [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out] +) +@out_wrapper() +def upsample_trilinear3d( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear( + input, output_size, align_corners, [scales_d, scales_h, scales_w] + ) + + +def _compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0 + else: + return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size + + +def _compute_source_index(scale, dst_index, align_corners): + if align_corners: + return scale * dst_index + else: + return scale * (dst_index + 0.5) - 0.5 + + +@pw_cast_for_opmath +def _upsample_linear( + input: Tensor, + output_size: List[int], + align_corners: bool, + scales: List[Optional[float]], +) -> Tensor: + # get dimensions of original image + n_batch, n_channels = input.shape[:2] + inp_sizes = input.shape[2:] + n_dims = len(inp_sizes) + + _, dtype = utils.elementwise_dtypes( + input, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + def get_values(inp_size, out_size, scales, nsqueeze): + # First Calculate scaling factor + scale_factor = _compute_scale(inp_size, out_size, align_corners, scales) + # We have to create arange with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(out_size, device=input.device).to(dtype=dtype) + + x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0) + x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze)) + x = x_f32.to(torch.int64) + xp1 = (x + 1).clamp(max=inp_size - 1) + return x_f32, x, xp1 + + values = [ + get_values(inp_size, out_size, scales, n_dims - 1 - i) + for i, (inp_size, out_size, scales) in enumerate( + zip(inp_sizes, output_size, scales) + ) + ] + xs_f32, xs, xp1s = list(zip(*values)) + + vs = [] + for a in product(*[[0, 1]] * n_dims): + idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)] + v = aten._unsafe_index(input, idx) + v = _maybe_convert_to_dtype(v, dtype) + vs.append(v) + + for i in reversed(range(n_dims)): + xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype) + vs = [ + # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha + v1 + torch.mul(v2 - v1, xscale) + for v1, v2 in zip(vs[::2], vs[1::2]) + ] + + assert len(vs) == 1 + result = vs[0] + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous 
path" + if input.device.type == "cuda" and n_channels < 16: + memory_format = torch.contiguous_format + + assert isinstance(result, torch.Tensor) + + result = result.contiguous(memory_format=memory_format) + + if not input.is_floating_point(): + result = result.round() + + return result + + +# We should be applying decompositions after all transformations +@register_decomposition(aten.is_same_size.default) +def is_same_size(a: Tensor, b: Tensor) -> bool: + return a.shape == b.shape + + +@register_decomposition([aten._reshape_alias, aten._unsafe_view]) +@out_wrapper() +def _reshape_alias(x, shape, *args): + return aten.view(x, shape) + + +@register_decomposition([aten._unsafe_index]) +def _index(x, indices): + return aten.index(x, indices) + + +def _nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + # self can be [N, C] or [C] + # target can be [N] or [] + + n_dims = self.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + if weight is not None: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + self = self * w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + # target can be [N, 1] or [1] + + result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = self.new_full((), 0.0) + return result, total_weight + + if weight is not None: + w = w.expand(self.shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(self) + + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +@register_decomposition(aten.nll_loss_forward) +@out_wrapper("output", "total_weight") +def nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D" + assert ( + target.dim() <= 1 + ), "0D or 1D target tensor expected, multi-target not supported" + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or ( + self.shape[0] == target.shape[0] + ), f"size mismatch (got input: {self.shape}, target: {target.shape})" + + n_classes = self.shape[-1] + + assert weight is None or ( + weight.dim() == 1 and weight.numel() == n_classes + ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}" # noqa: B950 + + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +@register_decomposition(aten.nll_loss2d_forward) +@out_wrapper("output", "total_weight") +def nll_loss2d_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> Tuple[Tensor, Tensor]: + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +# These are adapted from aten/src/ATen/native/UpSample.h, wich is based on +# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +def 
_upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor: + return ((A + 2) * x - (A + 3)) * x * x + 1 + + +def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor: + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A + + +def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType: + A = -0.75 + return ( + _upsample_cubic_convolution2(t + 1.0, A), + _upsample_cubic_convolution1(t, A), + _upsample_cubic_convolution1(1.0 - t, A), + _upsample_cubic_convolution2(2.0 - t, A), + ) + + +def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor: + coeffs2 = _upsample_get_cubic_coefficients(ts) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2)) + + +# Need this instead of just sum() to keep mypy happy +def _sum_tensors(ts: Iterable[Tensor]) -> Tensor: + return reduce(torch.add, ts) + + +def _linspace_from_neg_one( + num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device +): + if num_steps <= 1: + return torch.tensor(0, device=device, dtype=dtype) + + a = ((num_steps - 1) / num_steps) if not align_corners else 1 + return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype) + + +def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated + # corresponding to each individual tensor: grid_x, grid_y, grid_one + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1) + grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0) + return grid_x + grid_y + grid_one + + +def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1) + grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1) + grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0) + grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0) + return grid_x + grid_y + grid_z + grid_one + + +def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool): + n, _, h, w = size + base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners) + # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3) + # We do manually a matrix multiplication which is faster than mm() + # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2) + grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, h, w, 2) + + +def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: 
bool): + n, _, d, h, w = size + base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) + # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4) + # We do manually a matrix multiplication which is faster than mm() + # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3) + grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, d, h, w, 3) + + +@register_decomposition(aten.affine_grid_generator) +@out_wrapper() +@pw_cast_for_opmath +def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool): + torch._check( + len(size) in (4, 5), + lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", + ) + if len(size) == 4: + return _affine_grid_generator_4d(theta, size, align_corners=align_corners) + else: + return _affine_grid_generator_5d(theta, size, align_corners=align_corners) + + +def _grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, + _expand_grid: bool = True, +) -> Tensor: + # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to + # optionally expand the input grid for performance reasons. + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + + torch._check( + interpolation_mode in (0, 1, 2), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" + ) + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. 
+ def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iH, iW = a.shape + _, oH, oW, two = grid.shape + assert two == 2 + + if _expand_grid: + # Let's expand grid to [N, C, oH, oW, 2] + # This allows to generate a single triton cuda kernel instead of two kernels. + # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW + # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW + # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW + grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2) + + def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: + return torch.logical_and( + 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH)) + ) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) + + def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: + cond = in_bounds_cond(xs, ys) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oH, oW) + for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) + ) + + def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, w_ = clip(ix, iy, w) + return a[N_idx, C_idx, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nw, iy_nw = ix.floor(), iy.floor() + ix_ne, iy_ne = ix_nw + 1, iy_nw + ix_sw, iy_sw = ix_nw, iy_nw + 1 + ix_se, iy_se = ix_ne, iy_sw + + w_nw = (ix_se - ix) * (iy_se - iy) + w_ne = (ix - ix_sw) * (iy_sw - iy) + w_sw = (ix_ne - ix) * (iy - iy_ne) + w_se = (ix - ix_nw) * (iy - iy_nw) + + return _sum_tensors( + get_summand(ix, iy, w) + for (ix, iy, w) in ( + (ix_nw, iy_nw, w_nw), + (ix_ne, iy_ne, w_ne), + (ix_sw, iy_sw, w_sw), + (ix_se, iy_se, w_se), + ) + ) + elif interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nearest = ix.round() + iy_nearest = iy.round() + + return get_summand(ix_nearest, iy_nearest, 1) + else: # interpolation_mode == 2, Bicubic + ix = unnormalize(x, iW) + iy = unnormalize(y, iH) + + ix_nw = ix.floor() + iy_nw = 
iy.floor() + + tx = ix - ix_nw + ty = iy - iy_nw + + if not _expand_grid: + tx = tx.unsqueeze(1) + ty = ty.unsqueeze(1) + + def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor: + x = compute_coordinates(ix, iW) + y = compute_coordinates(iy, iH) + return get_summand(x, y, 1) + + def get_coeff(ofs: int) -> Tensor: + iy_ofs = iy_nw + (ofs - 1) + cs = ( + get_value_bounded(ix_nw - 1, iy_ofs), + get_value_bounded(ix_nw, iy_ofs), + get_value_bounded(ix_nw + 1, iy_ofs), + get_value_bounded(ix_nw + 2, iy_ofs), + ) + return _upsample_cubic_interp1d(cs, tx) + + coeffs = tuple(get_coeff(ofs) for ofs in range(4)) + return _upsample_cubic_interp1d(coeffs, ty) + + +@register_decomposition(aten.grid_sampler_2d) +@out_wrapper() +@pw_cast_for_opmath +def grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> Tensor: + return _grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + +@register_decomposition(aten.mv) +@out_wrapper() +@pw_cast_for_opmath +def mv(self, vec): + torch._check( + self.dim() == 2 and vec.dim() == 1, + lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", + ) + torch._check( + self.size(1) == vec.size(0), + lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", + ) + return (self * vec).sum(dim=1) + + +@register_decomposition(aten.binary_cross_entropy_with_logits) +@out_wrapper() +def binary_cross_entropy_with_logits( + self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value +): + if pos_weight is not None: + log_weight = (pos_weight - 1) * target + 1 + loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) + else: + loss = (1 - target) * self - F.logsigmoid(self) + + if weight is not None: + loss = loss * weight + + return apply_loss_reduction(loss, reduction) + + +def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool: + # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp + + t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not (t1.ndim >= 3 and t2.ndim <= 2): + return False + if t2.requires_grad and not is_out: + return True + if tensor1.ndim == 2: + return False + if guard_size_oblivious(t1.numel() == 0): + return True + + t1_shape = t1.shape + t1_stride = t1.stride() + return all( + st1 == st2 * s2 + for (st1, st2, s2) in zip(t1_stride[:-2], t1_stride[1:-1], t1_shape[1:-1]) + ) + + +@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@out_wrapper(pass_is_out=True) +def matmul(tensor1, tensor2, *, is_out=False): + dim_tensor1 = tensor1.dim() + dim_tensor2 = tensor2.dim() + assert dim_tensor1 != 0 and dim_tensor2 != 0 + if dim_tensor1 == 1 and dim_tensor2 == 1: + return torch.dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return torch.mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0) + elif dim_tensor1 == 2 and dim_tensor2 == 2: + return torch.mm(tensor1, tensor2) + elif should_fold(tensor1, tensor2, is_out): + # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || + # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) + # and some condition on the strides is fulfilled + + # optimization: use mm instead of bmm by 
folding the batch of the larger tensor + # into its leading matrix dimension + transpose = dim_tensor2 > dim_tensor1 + t1 = tensor2.mT if transpose else tensor1 + t2 = ( + tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1) + ) + # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2) + # and t1 and t2 are matmul-compatible + + # Why not t1.view(-1, sizes_1[-1])? + # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. + # This can happen in e.g. [3, 5, 0] @ [0, 0]. + sizes_1 = t1.shape + output_shape = list(sizes_1[:-1]) + folded_dim1 = reduce(operator.mul, output_shape) + + # Readjust output_shape if we are multiplying by a matrix + t2_is_matrix = t2.dim() == 2 + if t2_is_matrix: + output_shape.append(t2.shape[1]) + + # This will almost always be a view. + # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) + if t2_is_matrix: + # This copies if we perform a 2D @ 3D and the first tensor requires_grad + # See should_fold native/LinearAlgebra.cpp for why. + output = t1_folded.mm(t2).view(output_shape) + return output.mT.contiguous() if transpose else output + else: + return t1_folded.mv(t2).view(output_shape) + + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1.size(-2) if dim_tensor1 > 1 else 1 + m1 = tensor1.size(-1) + batch_tensor1 = tensor1.shape[:-2] + m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) + p = tensor2.size(-1) if dim_tensor2 > 1 else 1 + + batch_tensor2: List[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2.size(i)) + + # Same optimization for the gradients as that in should_fold + # If we're going to broadcast, we force it to go through the should_fold branch + if ( + dim_tensor1 == 3 + and dim_tensor2 == 3 + and batch_tensor1[0] != batch_tensor2[0] + ): + if batch_tensor1[0] == 1 and tensor1.requires_grad: + return matmul(tensor1.squeeze(0), tensor2) + if batch_tensor2[0] == 1 and tensor2.requires_grad: + return matmul(tensor1, tensor2.squeeze(0)) + + # expand the batch portion (i.e. 
cut off matrix dimensions and expand rest) + expand_batch_portion = list( + torch.broadcast_shapes(batch_tensor1, batch_tensor2) + ) + + tensor1_expand_size = expand_batch_portion + [n, m1] + + expand_batch_product = prod(expand_batch_portion) + + # HACK: We need reshape with symint support + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 + ) + + vector_rhs = dim_tensor2 == 1 + if vector_rhs: + tensor2_expand_size = expand_batch_portion + [m2] + tensor2_expanded = ( + tensor2.expand(tensor2_expand_size) + .reshape(expand_batch_product, m2) + .unsqueeze(2) + ) + else: + tensor2_expand_size = expand_batch_portion + [m2, p] + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p + ) + + output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + if vector_rhs: + return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape) + else: + return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) + else: + torch._check(False, lambda: "both arguments to matmul need to be at least 1D") + + +@register_decomposition(aten.upsample_bicubic2d.default) +@pw_cast_for_opmath +def upsample_bicubic2d_default( + a: Tensor, + output_size: Tuple[int, int], + align_corners: bool, + scale_h: Optional[float] = None, + scale_w: Optional[float] = None, +) -> Tensor: + N, C, iH, iW = a.shape + oH, oW = output_size + + def compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1) / (out_size - 1) if out_size > 1 else 0 + else: + return 1 / scale if scale is not None and scale > 0 else in_size / out_size + + def compute_source_index(scale, dst_index, align_corners): + if align_corners: + return scale * dst_index + else: + return scale * (dst_index + 0.5) - 0.5 + + height_scale = compute_scale(iH, oH, align_corners, scale_h) + width_scale = compute_scale(iW, oW, align_corners, scale_w) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) + out_y = torch.arange(oH, device=a.device).view((1, 1, oH, 1)) + out_x = torch.arange(oW, device=a.device).view((1, 1, 1, oW)) + + real_x = compute_source_index(width_scale, out_x, align_corners) + in_x = real_x.floor() + t_x = real_x - in_x + ix = in_x.to(dtype=torch.int64) + + real_y = compute_source_index(height_scale, out_y, align_corners) + in_y = real_y.floor() + t_y = real_y - in_y + iy = in_y.to(dtype=torch.int64) + + iys_ofs = (iy - 1, iy, iy + 1, iy + 2) + ixs_ofs = (ix - 1, ix, ix + 1, ix + 2) + + def load_bounded(ys, xs): + y_idx = torch.clamp(ys, 0, iH - 1) + x_idx = torch.clamp(xs, 0, iW - 1) + return aten._unsafe_index(a, [N_idx, C_idx, y_idx, x_idx]) + + def get_x_interp(y): + coeffs_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs) + return _upsample_cubic_interp1d(coeffs_x, t_x) + + coeffs_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs) + result = _upsample_cubic_interp1d(coeffs_y, t_y) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(a) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.upsample_bicubic2d.vec) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_vec( + a: Tensor, + output_size: 
Optional[Tuple[int, int]], + align_corners: bool, + scale_factors: Optional[Tuple[float, float]] = None, +) -> Tensor: + torch._check( + bool(output_size) + bool(scale_factors) == 1, + lambda: "Must specify exactly one of output_size and scale_factors.", + ) + if output_size is None: + assert scale_factors is not None + output_size = cast( + Tuple[int, int], + tuple( + sym_int(sym_float(w) * scale) + for w, scale in zip(a.shape[2:], scale_factors) + ), + ) + scale_h, scale_w = scale_factors if scale_factors else (None, None) + return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) + + +@register_decomposition(aten.reflection_pad1d) +@register_decomposition(aten.reflection_pad2d) +@register_decomposition(aten.reflection_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +@register_decomposition(aten.replication_pad1d) +@register_decomposition(aten.replication_pad2d) +@register_decomposition(aten.replication_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +def _reflection_or_replication_pad( + a: Tensor, + padding: Tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], +) -> Tensor: + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: List[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.aminmax) +@out_wrapper("min", "max") +def aminmax(self, *, dim=None, keepdim=False): + amin = torch.amin(self, dim=dim, keepdim=keepdim) + amax = torch.amax(self, dim=dim, keepdim=keepdim) + return amin, amax + + +@register_decomposition(aten.nansum) +@out_wrapper() +def nansum(self, dim=None, keepdim=False, *, dtype=None): + return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) + + +@register_decomposition([aten.arange.default, aten.arange.out]) +@out_wrapper() +def arange_default( + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition([aten.arange.start]) +def arange_start( + start: NumberType, + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + 
return aten.arange.start_step( + start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition(out_dtype) +def out_dtype_decomp(*args, **kwargs): + from torch._higher_order_ops.out_dtype import out_dtype_dense + + return out_dtype_dense(*args, **kwargs) + + +@register_decomposition(aten.multi_margin_loss) +@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: NumberType = 1, + margin: NumberType = 1, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + input = torch.atleast_2d(input) + target = torch.atleast_1d(target) + nframe = input.shape[0] + dim = input.shape[1] + torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported") + torch._check( + input.ndim == 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}", + ) + torch._check( + target.ndim == 1 and target.numel() == nframe, + lambda: f"inconsistent target size, expected {nframe} but got {target.shape}", + ) + if weight is not None: + weight = torch.atleast_1d(weight) + torch._check( + weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr] + lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr] + ) + target = target.unsqueeze(1) + u = torch.gather(input, dim=1, index=target) + z = margin - u + input + z = z.clamp_min(0) + z = z if p == 1 else z * z + if weight is not None: + z = z * weight[target] + idx = torch.arange(dim, device=input.device) + z = torch.where(idx != target, z, 0) + if reduction == Reduction.MEAN.value: + return z.mean() + elif reduction == Reduction.SUM.value: + return z.sum() / z.shape[1] + else: + return z.mean(dim=1) + + +@register_decomposition(aten.multilabel_margin_loss_forward) +@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd) +@out_wrapper("output", "is_target") +def multilabel_margin_loss_forward( + input: Tensor, + target: Tensor, + reduction: int, +) -> Tuple[Tensor, Tensor]: + orig_input_shape = input.shape + orig_target_shape = target.shape + input = torch.atleast_2d(input) + target = torch.atleast_2d(target) + dim = input.shape[1] + torch._check( + len(orig_input_shape) <= 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}", + ) + torch._check( + len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape, + lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}", + ) + # ignores labels after the first -1, detects when -1 is not present + idx = torch.arange(dim, device=target.device) + is_end = target == -1 + end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True) + # target indices + target_mask = idx < end_idx + # masks target to be able to use gather, which doesn't allow -1 + tidx0 = torch.where(target_mask, target, 0) + u = torch.gather(input, dim=-1, index=tidx0) + # is_target + tidx1 = torch.where(target_mask, target, -1) + is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1) + # loss + z = 1.0 - u.T.unsqueeze(dim=-1) + input + z = z.clamp_min(0) + z = z / dim + # masks loss + z = torch.where(is_target, 0, z) + # reduction + if reduction == Reduction.MEAN.value: + z = z.sum(dim=(0, -1)).mean() + elif reduction == Reduction.SUM.value: + z = z.sum() + else: + z = z.sum(dim=(0, -1)) + # result + is_target = 
is_target.to(input.dtype).reshape(orig_target_shape) + return z, is_target + + +# scaled_dot_product_attention used to be decomposed in pre-autograd, given that +# it calls _scaled_dot_product_attention_math and +# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd +# kernel. As a result it's decomposed into ops with finer granularity. +# However recent PRs (#103826 #105131 #115913) added new logic in +# scaled_dot_product_attention and now it calls +# _scaled_dot_product_flash_attention_for_cpu in export path. This results +# in _scaled_dot_product_flash_attention_for_cpu showing up in export result. +# This decomposition ensures scaled_dot_product_attention is still decomposed +# the same way as before, i.e., going through +# _scaled_dot_product_attention_math. Notice that this decomp rule should be +# excluded by inductor. +@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default) +def scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> Tuple[Tensor, Tensor]: + dtype = query.dtype + torch._check( + torch.is_floating_point(query), + lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}", + ) + torch._check( + query.dim() == 4 and key.dim() == 4 and value.dim() == 4, + lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}", + ) + torch._check( + dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}" + ) + torch._check( + query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3], + lambda: "q, k, v should have the same head size", + ) + + output, attn = aten._scaled_dot_product_attention_math.default( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + dropout_mask=None, + scale=scale, + ) + # Why this change? + # In pre-dispatch export scaled_dot_product_attention is executed via + # * flash_attention. + # flash_attention allocates output tensor as (N, L, H, E) + # it then transposes that to get (N, H, L, E) which is supposed to be the return + # tensor dim for scaled_dot_product_attention + # assume x: [N, H, L, E] is the output sdpa + # In MHA code, this output is then permuted via (2, 0, 1, 3) to get + # (L, N, H, E) dim tensor + # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via + # x = x.view(L * N, H * E) + # During pre autograd dispatch call to contiguous is not traced because + # flash_attention output after the x.permute is already contiguous + # on which the view is valid + # However, during 2nd stage export, post-dispatch, we run _match variant + # instead of flash* to get the decomposition. _match variant returns + # x: [N, H, L, E] applying x.permute(2, 0, 1, 3) returns + # x: [L, N, H, E] and without converting this to contiguous tensor + # subsequent view is not valid and the export fails + # solution is to maintain the return tensor view from the decomp to be + # exactly same as *flash* variant. + # flash variants output is contiguous as [N, L, H, E] + # _match variant out is contiguous as [N, H, L, E] + # out = out.transpose(1, 2).contiguous gets output as contiguous + # in [N, L, H, E]. 
+ # Subsrequent transpose(1, 2) then returns a view on which + # aforementioned code snippet, as showm below, is valid + # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via + # x = x.view(L * N, H * E) + + # Really the invariant you want to maintain is: + # pre-dispatch op-output and its decomposed representation must + # return tensor with same view and dims + output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) + return (output.transpose(1, 2), attn) + + +def register_inplace(aten_op, outplace_op): + @register_decomposition(aten_op) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +@register_decomposition([aten.baddbmm]) +@out_wrapper() +@pw_cast_for_opmath +def baddbmm(self, batch1, batch2, beta=1, alpha=1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + result = torch.bmm(batch1, batch2) + if not isinstance(alpha, numbers.Number) or alpha != 1: + result = result * alpha + if beta == 0: + return result + if not isinstance(beta, numbers.Number) or beta != 1: + self = self * beta + return self + result + + +@register_decomposition(aten.floor_divide) +@out_wrapper() +def floor_divide(self, other): + return torch.div(self, other, rounding_mode="floor") + + +@register_decomposition(aten.sym_numel) +def sym_numel(t): + return functools.reduce(operator.mul, t.shape, 1) + + +@register_decomposition([aten.sum.default, aten.sum.out]) +def sum_default( + self: Tensor, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + return aten.sum.dim_IntList(self, [], dtype=dtype) + else: + return aten.sum.IntList_out(self, [], dtype=dtype, out=out) + + +@register_decomposition([aten.squeeze.default, aten.squeeze.dim]) +def squeeze_default(self: Tensor, dim: Optional[int] = None): + if dim is None: + return aten.squeeze.dims(self, list(range(self.dim()))) + else: + return aten.squeeze.dims(self, [dim]) + + +@register_decomposition(torch.ops.aten._weight_norm_interface) +def _weight_norm_interface(x, y, dim=0): + # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 + keep_dim = tuple(i for i in range(len(x.shape)) if i != dim) + norm = x.norm(2, keep_dim, keepdim=True) + return x * (y / norm), norm + + +@register_decomposition(aten.isin) +@out_wrapper() +def isin(elements, test_elements, *, assume_unique=False, invert=False): + # handle when either elements or test_elements are Scalars (they can't both be) + if not isinstance(elements, torch.Tensor): + elements = torch.tensor(elements, device=test_elements.device) + if not isinstance(test_elements, torch.Tensor): + test_elements = torch.tensor(test_elements, device=elements.device) + + if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145): + return isin_default(elements, test_elements, invert=invert) + else: + return isin_sorting( + elements, test_elements, assume_unique=assume_unique, invert=invert + ) + + +def isin_default(elements, test_elements, *, invert=False): + if elements.numel() == 0: + return torch.empty_like(elements, dtype=torch.bool) + + x = elements.view(*elements.shape, *((1,) * test_elements.ndim)) + if not invert: + cmp = x == test_elements + else: + cmp = x != test_elements + dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + return cmp.any(dim=dim) + + +def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False): + 
elements_flat = elements.flatten() + test_elements_flat = test_elements.flatten() + if assume_unique: + # This is the same as the aten implementation. For + # assume_unique=False, we cannot use unique() here, so we use a + # version with searchsorted instead. + all_elements = torch.cat([elements_flat, test_elements_flat]) + sorted_elements, sorted_order = torch.sort(all_elements, stable=True) + + duplicate_mask = sorted_elements[1:] == sorted_elements[:-1] + duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False) + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = torch.empty_like(duplicate_mask) + mask = mask.index_copy(0, sorted_order, duplicate_mask) + + return mask[0 : elements.numel()] + else: + sorted_test_elements, _ = torch.sort(test_elements_flat) + idx = torch.searchsorted(sorted_test_elements, elements_flat) + test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0) + cmp = sorted_test_elements[test_idx] == elements_flat + cmp = cmp.logical_not() if invert else cmp + return cmp.reshape(elements.shape) + + +@register_decomposition(aten.take) +@out_wrapper() +def take(self, index): + flattened = self.reshape(-1) + return flattened[index] + + +register_inplace(aten.addbmm_, aten.addbmm) +register_inplace(aten.addmm_, aten.addmm) +register_inplace(aten.addmv_, aten.addmv) +register_inplace(aten.baddbmm_, aten.baddbmm) +register_inplace(aten.fill_, aten.fill) +register_inplace(aten.gelu_, aten.gelu) +register_inplace(aten.hardswish_, aten.hardswish) +register_inplace(aten.hardtanh_, aten.hardtanh) +register_inplace(aten.hardsigmoid_, aten.hardsigmoid) +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.index_put_, aten.index_put) +register_inplace(aten.index_reduce_, aten.index_reduce) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) +register_inplace(aten.leaky_relu_, aten.leaky_relu) +register_inplace(aten.logit_, aten.logit) +register_inplace(aten.relu_, aten.relu) +register_inplace(aten.renorm_, aten.renorm) +register_inplace(aten.round_, aten.round) +register_inplace(aten.scatter_, aten.scatter) +register_inplace(aten.scatter_add_, aten.scatter_add) +register_inplace(aten.scatter_reduce_, aten.scatter_reduce) +register_inplace(aten.silu_, aten.silu) diff --git a/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py b/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py new file mode 100644 index 0000000000000000000000000000000000000000..bf91d9fb83427d13cc61133852ee8bb3fbba6e67 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py @@ -0,0 +1,302 @@ +import inspect +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch._decomp +from torch import Tensor +from torch._prims_common.wrappers import _maybe_remove_out_wrapper + +decomposition_table = torch._decomp.decomposition_table +decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {} +register_decomposition = torch._decomp.register_decomposition +aten = torch.ops.aten + +# NOTE: [forward-mode AD decompositions mechanism] +# +# The mechanism is in VariableType, +# IF any inputs have forward grad +# AND there is no forward AD formula implemented +# AND the functions is actually differentiable +# run the decomposition +# See run_jit_decomposition_with_args_for_jvp +# We currently use python decompositions that we torchscript. 
+# +# Note that we would be building the backward graph at the decomposed level +# too, but that is OK, because we would've errored out otherwise anyway. +# +# TODO: The mechanism we are using to register decompositions doesn't +# seem to be exclusively used for jvp. So open question here is whether +# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things. +# If that is the case, we may go down the decomposition path unexpectedly +# (and possibly produce an unintelligible error) vs erroring out earlier and +# printing that the forward AD formula is not implemented. +# +# The solution to this may be to have a explicitly white list control when +# to enable the decomposition. + + +def maybe_register_decomposition(op): + def decorator(f): + try: + return register_decomposition(op)(f) + except Exception: + return f + + return decorator + + +# Functions where we need a special decomposition for jvp but there's another version that +# should be used more generally (ex. for jvp we need to recompute the mean and variance for +# the backwards of a normalization function. Without jvp, it should use the saved value) +decomposition_table_for_jvp = {} + + +def register_decomposition_for_jvp(fn): + return register_decomposition(fn, registry=decomposition_table_for_jvp) + + +def _register_jit_decomposition_for_jvp(decomp, use_python=False): + if decomp in decomposition_table_for_jvp: + decomposition_table_used = decomposition_table_for_jvp + elif decomp in decomposition_table: + decomposition_table_used = decomposition_table + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + decomp_fn = decomposition_table_used[decomp] + + # `out_wrapper` extends a decompositions signature with + # an `out` parameter. However jit will use the unwrapped function's + # signature instead so we need to unwrap here to prevent an error + decomp_fn = _maybe_remove_out_wrapper(decomp_fn) + + if use_python: + decomp_fn = torch.jit.ignore(decomp_fn) + sig = inspect.signature(decomp_fn) + + # Create a string wrapping the function from the signature + # example output: + # def wrapped_decomp(x: torch.Tensor, y: int, z: int): + # return decomp_fn(x, y, z) + # Thanks copilot! + def get_function_def(sig): + param_def = [f"{param_str}" for param_str in sig.parameters.values()] + param_use = [f"{param_str}" for param_str in sig.parameters.keys()] + + return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" + + f_str = get_function_def(sig) + graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph + else: + graph = torch.jit.script(decomp_fn).graph + torch.jit._register_decomposition(decomp, graph) + + +# The only decompositions here are temporary or hacks for the purposes of jvp + + +# TODO: do these also belong here? +@maybe_register_decomposition(aten.trace.default) +def trace(self: Tensor) -> Tensor: + return torch.sum(torch.diag(self)) + + +@maybe_register_decomposition(aten.log_sigmoid_forward.default) +def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +def recompute_mean_var( + input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool +): + # for most norm decompositions, it will be the same as the core version except for here. 
+ # We recompute the mean and variance so that they track gradients through input + + mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim) + var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim) + eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside + eps = eps.detach() + rstd = 1 / torch.sqrt(var + eps) + return mean, rstd + + +@register_decomposition_for_jvp(aten.native_layer_norm_backward) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: List[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: List[bool], +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices = list(range(axis, input_ndim)) + outer_dim_indices = list(range(0, axis)) + + N = 1 + for i in inner_dims: + N *= i + M = 1 + for i in outer_dims: + M *= i + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape), + input.new_zeros(input_shape[axis:]), + input.new_zeros(input_shape[axis:]), + ) + + mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True) + + x_hat = (input - mean_) * rstd_ + if weight is not None: + grad_x_hat = grad_out * weight + else: + grad_x_hat = grad_out + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + inner = a - b - c3 + + if output_mask[0]: + d_input: Optional[Tensor] = (rstd_ / N) * inner + else: + d_input = torch.zeros_like(input) # should be None but doesn't work with vjp + + if output_mask[1] and weight is not None: + if len(outer_dim_indices) > 0: + d_weight: Optional[Tensor] = torch.sum( + grad_out * x_hat, outer_dim_indices, False + ) + else: + d_weight = grad_out * x_hat + elif weight is not None: + d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + d_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2] and bias is not None: + if len(outer_dim_indices) > 0: + d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False) + else: + d_bias = grad_out.clone() + elif bias is not None: + d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp + else: + d_bias = torch.zeros(()) # should be None but doesn't work with vjp + + return (d_input, d_weight, d_bias) + + +def prod(x: List[int]): + r = 1 + for i in x: + r *= i + return r + + +@register_decomposition_for_jvp(aten.native_batch_norm_backward) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: List[bool], +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type] + mean = save_mean + invstd = save_invstd + if train: + assert ( + save_mean is not None and save_invstd is not None + ), "when train=True, save_mean and save_invstd are required" + + reduciton_dims = [0] + list(range(2, input.dim())) + assert invstd is not 
None # for typing + mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False) + else: + assert running_mean is not None and running_var is not None + mean = running_mean + invstd = torch.rsqrt(running_var + eps) + + assert invstd is not None and mean is not None + + broadcast_mask = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: List[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out, reduction_axes) + dot_p = torch.sum(grad_out * (input - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) + + if weight is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 + else: + grad_scale = torch.reshape(invstd * weight, broadcast_mask) + + if train: + proj = (input - mean) * proj_scale + grad_input = ((grad_out - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + elif weight is not None: + grad_weight = torch.zeros_like( + weight + ) # should be None but doesn't work with vjp + else: + grad_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = torch.zeros_like( + grad_output_sum + ) # should be None but doesn't work with vjp + + return (grad_input, grad_weight, grad_bias) + + +_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) diff --git a/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_rng.py b/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_rng.py new file mode 100644 index 0000000000000000000000000000000000000000..3f9d21831d3430c6bc3a0d2c7712a55d0ff32c42 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_decomp/decompositions_for_rng.py @@ -0,0 +1,263 @@ +import functools +from collections import defaultdict +from typing import Callable, Dict + +import torch +import torch._decomp as decomp +from torch._decomp import get_decompositions +from torch._ops import OpOverload + +aten = torch.ops.aten + +rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict) + + +def register_rng_decomposition(aten_op): + return decomp.register_decomposition(aten_op, rng_decompositions) + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." 
+ ) + + +# TODO - We have to register many more distributions here, and also higher level +# ops like dropout which have fused implementation and can hide the rand inside. +@register_rng_decomposition(aten.rand) +def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False): + if device and device.type != "cuda": + throw_on_non_cuda(device) + seed, offset = PhiloxStateTracker.get_state_as_tuple() + dtype = dtype or torch.float32 + out, offset_jump = torch.ops.rngprims.philox_rand( + shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +@register_rng_decomposition(aten.rand_like) +def rand_like( + x: torch.Tensor, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, +): + device = device or x.device + if device.type != "cuda": + throw_on_non_cuda(device) + dtype = dtype or x.dtype + seed, offset = PhiloxStateTracker.get_state_as_tuple() + out, offset_jump = torch.ops.rngprims.philox_rand( + x.shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +class PhiloxState: + """ + Represents a PhiloxRngState - (seed, offset) where offset = base_offset + + relative_offset. seed and base_offset basically point to the rng state just + before tracing starts. relative offset tracks the totally consumed offset at + trace time. + """ + + def __init__(self): + self.reset() + + def reset(self): + self.seed = torch.tensor(()) + self.base_offset = torch.tensor(()) + self.relative_offset = 0 + self.offset_advanced_alteast_once = False + + def validate_state(self): + assert self.seed.numel() != 0 and self.base_offset.numel() != 0 + + def advance_offset(self, consumed_offset): + self.offset_advanced_alteast_once = True + self.relative_offset = self.relative_offset + consumed_offset + + def set_state(self, seed, base_offset, relative_offset=0): + self.seed = seed + self.base_offset = base_offset + self.relative_offset = relative_offset + + def get_state_as_tuple(self): + self.validate_state() + return (self.seed, self.base_offset + self.relative_offset) + + def get_state_as_tensor(self): + # Only needed because we override get_rng_state. + self.validate_state() + return torch.stack([self.seed, self.base_offset + self.relative_offset]) + + def set_state_from_tensor(self, state): + # Only needed because we override set_rng_state. + self.seed, self.base_offset = torch.unbind(state) + self.relative_offset = 0 + + +class PhiloxStateTracker: + """ + Singleton class to track the philox rng state during AOT Autograd tracing. + For each aot tracing instance, AOT Autograd resets this tracker and keeps + track of both forward and backward offsets. At runtime, we only care about + the total consumed forward and backward offsets. For dynamic shapes, these + offsets are a function of input shapes. Therefore, the AOT generated graphs + have additional outputs that compute total consumed forward and backward + offsets. 
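    A rough usage sketch (`seed` and `offset` stand in for the RNG state
    tensors that AOT Autograd captures; illustrative, not the exact call
    sequence):

        with PhiloxStateTracker():
            PhiloxStateTracker.record_state(seed, offset, "forward")
            # ... tracing the forward advances the offset via the rand decomps ...
            new_fwd_offset = PhiloxStateTracker.get_updated_fwd_offset()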
+ """ + + running_state: PhiloxState + fwd_state: PhiloxState + bwd_state: PhiloxState + + def __enter__(self): + PhiloxStateTracker.reset() + return self + + def __exit__(self, exc_type, exc_cal, exc_tb): + PhiloxStateTracker.reset() + + @classmethod + def reset(cls): + cls.running_state = PhiloxState() + cls.fwd_state = PhiloxState() + cls.bwd_state = PhiloxState() + + @classmethod + def mark_beginning_of_forward(cls): + # Tells the tracker to use fwd_state as the running state + cls.running_state = cls.fwd_state + + @classmethod + def mark_beginning_of_backward(cls): + # Tells the tracker to use bwd_state as the running state + cls.running_state = cls.bwd_state + + @classmethod + def record_state(cls, seed, offset, mode): + # Records the seed and offset tensors. These tensors are used to invoke + # the philox_rand functional primitives. + if mode == "forward": + cls.fwd_state.set_state(seed, offset) + cls.mark_beginning_of_forward() + else: + assert mode == "backward" + cls.bwd_state.set_state(seed, offset) + + @classmethod + def get_state_as_tensor(cls): + # The only reason this exists is because we override get_rng_state and + # set_rng_state during tracing. get_rng_state expects a tensor output, + # so return (seed, offset) tuple upset other parts of the program like + # ctx.saved_tensors. + + # A bad consequence is that if user saves and restores rng state, we + # have little bit of ugliness in the generated code, where we first + # concat the (seed, offset) to create a tensor for get_rng_state, and + # then split it back to get (seed, offset) tuple in set_rng_state. + + # TODO: Investigate if there is be a better way to wrap the tuple in a + # false Tensor object, and then desugar it later on. + return cls.running_state.get_state_as_tensor() + + @classmethod + def get_state_as_tuple(cls): + return cls.running_state.get_state_as_tuple() + + @classmethod + def set_state_from_tensor(cls, x): + # This is only needed because we override set_rng_state. Look at the + # comment in get_state_from_tensor method. + cls.running_state.set_state_from_tensor(x) + + @classmethod + def advance_offset(cls, consumed_offset): + cls.running_state.advance_offset(consumed_offset) + + @classmethod + def get_current_relative_offset(cls): + return cls.running_state.relative_offset + + @staticmethod + def multiple_of_4(offset): + # torch cuda rng state offset must be a multiple of 4. For inductor, as + # we sum up all the numel, the result might not be a multiple of 4. This + # method achieves that. + return (offset + 3) // 4 * 4 + + @classmethod + def get_updated_fwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.fwd_state.offset_advanced_alteast_once: + return cls.fwd_state.base_offset + return cls.multiple_of_4( + cls.fwd_state.base_offset + cls.fwd_state.relative_offset + ) + + @classmethod + def get_updated_bwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.bwd_state.offset_advanced_alteast_once: + return cls.bwd_state.base_offset + return cls.multiple_of_4( + cls.bwd_state.base_offset + cls.bwd_state.relative_offset + ) + + +# Adding more decompositions which eventually use rand_like inside decomps. +# Adding these in rng_decompositions ensures the functionalization of rand_like +# ops used in these decomps. The list is copied from inductor codebase, which +# uses it for similar purpose. +# +# Caution - These decomps do not have same accuracy as that of eager. 
However, +# we can't just disable them with a config flag like fallback_random, because +# for functionalization of rng ops, we have to decompose these ops. +extra_random_decomps = get_decompositions( + [ + aten.cauchy, + aten.cauchy_, + aten.exponential, + aten.exponential_, + aten.geometric, + aten.geometric_, + aten.native_dropout, + aten.normal, + aten.normal_, + aten.normal_functional, + aten.log_normal, + aten.log_normal_, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.uniform_, + ] +) +register_extra_random_decomp = functools.partial( + decomp.register_decomposition, registry=extra_random_decomps +) + + +@register_extra_random_decomp([aten.bernoulli_]) +def bernoulli_(self, p=0.5): + if self.device == torch.device("cpu"): + return NotImplemented + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) + + +@register_extra_random_decomp([aten.bernoulli.p]) +def bernoulli_p(self, p=0.5, *, generator=None): + if self.device == torch.device("cpu"): + return NotImplemented + assert generator is None + return torch.rand_like(self, dtype=torch.float32) < p + + +rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type] diff --git a/MLPY/Lib/site-packages/torch/_deploy.py b/MLPY/Lib/site-packages/torch/_deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee4b4d3b33430f15fd875103ddc1e291d353c70 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_deploy.py @@ -0,0 +1,105 @@ +import io + +import torch +from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer +from torch.package._package_pickler import create_pickler +from torch.package._package_unpickler import PackageUnpickler +from torch.serialization import _maybe_decode_ascii + + +def _save_storages(importer, obj): + serialized_storages = [] + serialized_dtypes = [] + + importer = importer if isinstance(importer, torch.package.PackageImporter) else None + importers: Importer + if importer is not None: + importers = OrderedImporter(importer, sys_importer) + else: + importers = sys_importer + + def persistent_id(obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + storage = obj._untyped_storage + dtype = obj.dtype + else: + storage = obj + dtype = torch.uint8 + + serialized_storages.append(obj) + serialized_dtypes.append(dtype) + return ("storage", len(serialized_storages) - 1) + + if hasattr(obj, "__reduce_deploy__"): + if _serialized_reduces.get(id(obj)) is None: + _serialized_reduces[id(obj)] = ( + "reduce_deploy", + id(obj), + *obj.__reduce_deploy__(importers), + ) + return _serialized_reduces[id(obj)] + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, importers) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + return ( + data_value, + serialized_storages, + serialized_dtypes, + importer.zip_reader if importer else None, + ) + + +def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + storage = serialized_storages[data[0]] + dtype = serialized_dtypes[data[0]] + return 
torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype + ) + + if typename == "reduce_deploy": + reduce_id, func, args = data + if reduce_id not in _loaded_reduces: + _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) + return _loaded_reduces[reduce_id] + + return None + + importer: Importer + if zip_reader is not None: + importer = OrderedImporter(_get_package(zip_reader), sys_importer) + else: + importer = sys_importer + + unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) + unpickler.persistent_load = persistent_load # type: ignore[method-assign] + result = _deploy_objects[id] = unpickler.load() + return result + + +def _get_package(zip_reader): + if zip_reader not in _raw_packages: + _raw_packages[zip_reader] = PackageImporter(zip_reader) + return _raw_packages[zip_reader] + + +_raw_packages: dict = {} +_deploy_objects: dict = {} +_serialized_reduces: dict = {} +_loaded_reduces: dict = {} diff --git a/MLPY/Lib/site-packages/torch/_dispatch/__init__.py b/MLPY/Lib/site-packages/torch/_dispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb7e7109147349cfef3bf5101730226d77db3b8b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2b6b06bf734f77ac55e166c18ee2087518378bb Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dispatch/python.py b/MLPY/Lib/site-packages/torch/_dispatch/python.py new file mode 100644 index 0000000000000000000000000000000000000000..fe1d23da4c5abf6fcfb0ad74a70befdd78b4342f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dispatch/python.py @@ -0,0 +1,178 @@ +import itertools +import unittest.mock +from contextlib import contextmanager +from typing import Iterator + +import torch +import torch._C +import torch._ops +import torch.utils._python_dispatch +import torch.utils._pytree as pytree + +__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] + +no_python_dispatcher = torch._C._DisablePythonDispatcher +enable_python_dispatcher = torch._C._EnablePythonDispatcher +enable_pre_dispatch = torch._C._EnablePreDispatch + +CROSSREF_FUNCTIONALIZE = False + + +def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: + """ + Warning: the set of overloads this will report is very subtle. It is precisely + the set of torch.ops functions that have actually been accessed from Python + (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT + from the set of registered operators, which will in general be a larger set, + as this would include all operators which we ran C++ static initializers or + Python operator registration on. This does not eagerly populate the list on + torch.ops.aten; this list is lazy! + + In other words, this is good for traversing over everything that has an + OpOverload object allocated in Python. 
We use it for cache invalidation, but + don't rely on this list being complete. + + Note that even if we did report all C++ registered overloads, this isn't guaranteed + to be complete either, as a subsequent lazy load of a library which triggers more + registrations could add more things to the set. + """ + for ns in torch.ops: + packets = getattr(torch.ops, ns) + for op_name in packets: + packet = getattr(packets, op_name) + for overload in packet: + yield getattr(packet, overload) + + +@contextmanager +def suspend_functionalization(): + f_tls = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + f_rv = torch._C._functionalization_reapply_views_tls() + if f_tls: + torch._disable_functionalization() + try: + yield + finally: + if f_tls: + torch._enable_functionalization(reapply_views=f_rv) + + +def check_tensor_metadata_matches(nv, rv, desc): + assert callable(desc) + assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" + same_strides, idx = torch._prims_common.check_significant_strides( + nv, rv, only_cuda=False + ) + assert ( + same_strides + ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + + +def check_metadata_matches(n, r, desc): + assert callable(desc) + n_vals, n_spec = pytree.tree_flatten(n) + r_vals, r_spec = pytree.tree_flatten(r) + # TODO: test the specs match; empirically sometimes we have a tuple + # on one side and a list on the other + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + +def _fmt(a: object) -> object: + if isinstance(a, torch.Tensor): + return Lit( + f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" + ) + else: + return a + + +def make_crossref_functionalize(op, final_key): + from torch._subclasses.fake_tensor import FakeTensorMode + + # This case is pretty weird, suppress it for now + if op == torch.ops.aten.lift_fresh.default: + return final_key + + def handler(*args, **kwargs): + fake_mode = FakeTensorMode() + + def fakeify_defun(t): + if isinstance(t, torch.Tensor): + if torch._is_functional_tensor(t): + r = torch._from_functional_tensor(t) + # NB: This assumes that the inner tensor sizes/strides match + # the outer tensor sizes/strides. 
This doesn't necessarily have to + # be the case, see discussion at + # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 + assert t.size() == r.size() + assert t.stride() == r.stride() + else: + r = t + # TODO: suppress guards + return fake_mode.from_tensor(r) + return t + + def maybe_detach(t): + if isinstance(t, torch.Tensor): + return t.detach() + else: + return t + + # TODO: This probably does the wrong thing if you're running other + # substantive modes with the normal op outside here + with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization(): + f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map( + maybe_detach, (f_args, f_kwargs) + ) + with fake_mode: + f_r = op(*f_args, **f_kwargs) + r = op._op_dk(final_key, *args, **kwargs) + + def desc(): + fmt_args = ", ".join( + itertools.chain( + (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), + ( + f"{k}={pytree.tree_map(_fmt, v)}" + for k, v in orig_f_kwargs.items() + ), + ) + ) + return f"{op}({fmt_args})" + + check_metadata_matches(f_r, r, desc) + return r + + return handler + + +# NB: enabling this is slow, don't do it in a hot loop. This is purely +# for debugging purposes. +@contextmanager +def enable_crossref_functionalize(): + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) + try: + with enable_python_dispatcher(), unittest.mock.patch( + "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True + ): + yield + finally: + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__init__.py b/MLPY/Lib/site-packages/torch/_dynamo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec689d5485e3429d92a7b2786991028f613f7e5f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/__init__.py @@ -0,0 +1,96 @@ +import torch +from . import convert_frame, eval_frame, resume_execution +from .backends.registry import list_backends, lookup_backend, register_backend +from .callback import callback_handler, on_compile_end, on_compile_start +from .code_context import code_context +from .convert_frame import replay +from .decorators import ( + allow_in_graph, + assume_constant_result, + disable, + disallow_in_graph, + forbid_in_graph, + graph_break, + mark_dynamic, + mark_static, + mark_static_address, + maybe_mark_dynamic, + run, +) +from .eval_frame import ( + _reset_guarded_backend_cache, + explain, + export, + is_dynamo_supported, + is_inductor_supported, + optimize, + optimize_assert, + OptimizedModule, + reset_code, +) +from .external_utils import is_compiling +from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count + +__all__ = [ + "allow_in_graph", + "assume_constant_result", + "disallow_in_graph", + "forbid_in_graph", + "graph_break", + "mark_dynamic", + "maybe_mark_dynamic", + "mark_static", + "mark_static_address", + "optimize", + "optimize_assert", + "export", + "explain", + "run", + "replay", + "disable", + "reset", + "OptimizedModule", + "is_compiling", + "register_backend", + "list_backends", + "lookup_backend", +] + +if torch.manual_seed is torch.random.manual_seed: + import torch.jit._builtins + + # Wrap manual_seed with the disable decorator. + # Can't do it at its implementation due to dependency issues. 
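    # The effect is that calls to torch.manual_seed are not traced into a
    # compiled graph; they run eagerly instead.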
+ torch.manual_seed = disable(torch.manual_seed) + # Add the new manual_seed to the builtin registry. + torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed") + + +def reset() -> None: + """Clear all compile caches and restore initial state""" + with convert_frame.compile_lock: + reset_code_caches() + convert_frame.input_codes.clear() + convert_frame.output_codes.clear() + orig_code_map.clear() + guard_failures.clear() + graph_break_reasons.clear() + resume_execution.ContinueExecutionCache.cache.clear() + _reset_guarded_backend_cache() + reset_frame_count() + torch._C._dynamo.compiled_autograd.clear_cache() + convert_frame.FRAME_COUNTER = 0 + convert_frame.FRAME_COMPILE_COUNTER.clear() + callback_handler.clear() + + +def reset_code_caches() -> None: + """Clear compile caches that are keyed by code objects""" + with convert_frame.compile_lock: + for weak_code in ( + convert_frame.input_codes.seen + convert_frame.output_codes.seen + ): + code = weak_code() + if code: + reset_code(code) + code_context.clear() diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e595649e8f4b26768d9002b430cd67f000b2ba52 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a418daeab7cba1196d6ca2951763b3d4fb682b71 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b784e1e9cab12ea0606a10aff0e3f884f3b4868d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..436a535ffe0a4608e5171ad0107f0ca31c32abe0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..086497812b8329dc31fa270d086c31adb9b76b4b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39f945c4cfe3c935e912195e91fbbfa93a9f0bf Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ec1ca30017eb99fc1c3cbc0be8985df66a2cca9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf78f1c9e841324d3d61945a5b1fb7be4d4417da Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8786952d037eb96360f186a4b3ab6a9cd2c0c1ae Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..785e6ce847c872c2493af27ccb443fcff53a43dc Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17977d16796b54377d433fd1ad03ad27e5d8e0ab Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..995dfb6aae2d842ac18dcd25c5d391ca205034f6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b8705a1d37912144b5447d3081883d2c4116a34 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64ee3eabbc3a5ad774dce911c1c74f503338176f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203dc81fdf05ab7cba2b282bc9b16c86a5dfd4f3 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faf6361d1622a5ce826ff33338e5390c42db58b2 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..091afa63485547d83e344f844d3eced092d7becb Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6750e39905a221f576b7bf229c20b6c45ea4a89 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb113430796c14cb5d62a90c9fc959152525426 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7ea2c7d53c5c2de4fa13b4c0bea8201d3ac473e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c809468d95116d60637130e32c3f4c522712cf Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a5819875c0851a45db6db33229c576f8915124e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15c48bd996de0295bf3194883114576dd39a183 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41af299a2ee895157ea2b2d0e501ad4f1df5deba Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e2db23f2eb3896a6978b2c51b53d574fd8fa38 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/polyfill.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/polyfill.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a30fcf32cddfd1a82d18ae35c210401bc4977021 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/polyfill.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a25c112870a15184e52b8ae5e54e13332312368c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c13dbe92d536948ae163e36ab52e3db666bc47c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4e78726dc788e3eef0c96ff94e700666b3df4ce Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec218254c0a181ddaa3423a665b3de12a3ce642b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c01c78c246f42ae1436763f4226cd8df1f1477ad Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/symbolic_convert.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/symbolic_convert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9b1cbf2d7076f565c775c4a1d8ab4453da13666 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/symbolic_convert.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c785bd495a995d217bf36ab4c47da575efc6d39 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df472ea4e106afee7c28a4c3b39411c55187abec Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff27a050d405988417308aa4346f3de302acb332 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc4011169374b780ec1fc233ba4431dcc8704ce0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/trace_rules.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/trace_rules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ad0951d1de826c850c95faa750b88ca60530c7a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/trace_rules.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d95c8a5c1086c1a7398d8d4a60439ac9b28f8e37 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a63ed4a6a4b72c98c18af0fdb3a53d4db416b4ba Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py b/MLPY/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6061934ad2badeaada3b9aa3501aa16044c3af --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -0,0 +1,120 @@ +import torch +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented + +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensorMode +from torch.fx.experimental._backward_state import BackwardState + +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils._python_dispatch import _get_current_dispatch_mode +from torch.utils._pytree import tree_map_only + + +__all__ = ["trace_wrapped"] + + +# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: +# if you make_fx trace through this call, we will not actually trace into fn; instead, +# we will directly insert it as a call_function to fn in the graph. +# (Unlike make_fx, Dynamo WILL inline into fn.) +# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing. 
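# A minimal sketch of the intended usage (the hook `fn` below is hypothetical):
#
#   def fn(grad):
#       # must return a tensor with the same metadata as its input
#       return torch.zeros_like(grad)
#
#   out = trace_wrapped(grad, fn=fn)  # make_fx records a call_function to fn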
+# +# Because proxy tensor tracing does not actually run the function, there are +# requirements on the behavior of fn. We are still figuring it out, but here is the current state: +# +# 1) fn SHOULD only take a single argument, which must be a tensor +# 2) fn MUST return a new tensor with the same metadata as the original tensor +# (e.g., zeros_like(input) is a permissible implementation of fn). +# This is verified via an extra assert that is inserted into the traced graph. +# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors +# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state) +# These requirements stem from the requirement that we need to continue performing proxy tensor tracing, +# which assumes accurate fake tensor metadata, without actually running fn. +# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns. +# +# Note that tensors / Python state are allowed to be mutated. +# This is relaxed constraint is not always sound, but it is sound for backward tracing with fake +# tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete +# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python). +# +# The intended use case for this function is to allow AOTAutograd to defer complex +# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves +# the function call as is in the graph, and only when we Dynamo through the backward graph in +# compiled autograd do we inline into the function. + + +def trace_wrapped(*args, **kwargs): + with torch.no_grad(): + return _trace_wrapped_op(*args, **kwargs) + + +# TODO(jansel): need to ensure this does not get DCEed +_trace_wrapped_op = HigherOrderOperator("trace_wrapped") + + +def _assert_meta(grad, size, stride, dtype): + assert grad.size() == size, "size mismatch" + assert grad.stride() == stride, "stride mismatch" + assert grad.dtype == dtype, "dtype mismatch" + return grad + + +@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) +def inner_trace(mode, *args, bw_state=None, **kwargs): + def self_invoke(*args, **dyn_kwargs): + with torch.no_grad(): + return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) + + def unwrap_proxies(x): + if isinstance(x, torch.Tensor): + return mode.tracer.unwrap_proxy(x) + if isinstance(x, (list, tuple)): + return type(x)(map(unwrap_proxies, x)) + if x is None: + return None + raise AssertionError(f"unhandled type: {type(x)}") + + proxy_kwargs = {} + if bw_state is not None: + assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None + proxy_kwargs["bw_state"] = bw_state.proxy + out_proxy = mode.tracer.create_proxy( + "call_function", + self_invoke, + unwrap_proxies(args), + proxy_kwargs, + name="trace_wrapped", + ) + + if args[0] is None: + grad = args[1] # module backward hooks + else: + grad = args[0] # other backward hooks + grad = tree_map_only(torch.Tensor, torch.empty_like, grad) + track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) + return grad + + +@_trace_wrapped_op.py_impl(FakeTensorMode) +def inner_fake(*args, **kwargs): + raise RuntimeError("This op should never be invoked here") + + +@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def _trace_wrapped_op_dense(*args, fn, **kwargs): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" 
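    # No dispatch mode is active at this key, so simply run the wrapped fn eagerly.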
+ return fn(*args, **kwargs) + + +_trace_wrapped_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_trace_wrapped_op, deferred_error=True) +) + + +@_trace_wrapped_op.py_functionalize_impl +def _trace_wrapped_functionalized(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs)) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__init__.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a33edf30961fc2e309e25d3d67ec87e0928e7e08 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2180a1b0618630d1116d16db89890750edbb81f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af61ca3f1f8bb79ca1eec2d0b479a90bcfbbddc5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..456c39b3c2aa53c54b32c6555acc2cbb3cb073ce Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a810e5114c8016f57ed69e5bb84e4580c8e3933c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..699c7a32dbc138946bcf4f676b3463861b8fd1bf Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eecd342aa1b5de107dc908f20fcfbc3a555b8975 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aea8ad1d121538fc9f782103852b21ae270a31bf Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88c170b964cc7e9ff8a2ecfa2ddcb376a4160af2 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0341cfc8ce6e91efc47228b398a0620d9493281 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..067b49623cde62f99536ffb1c80e99146c7fb4dd Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/common.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0e945f9920280c6e20906be541b102d2595e3f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/common.py @@ -0,0 +1,112 @@ +# mypy: ignore-errors + +import contextlib +import functools +import logging +from unittest.mock import patch + +import torch +from torch._dynamo import disable +from torch._dynamo.utils import counters, defake +from torch._functorch.aot_autograd import aot_module_simplified +from torch.utils._python_dispatch import _disable_current_modes + +log = logging.getLogger(__name__) + + +def aot_autograd(**kwargs): + def compiler_fn(gm: torch.fx.GraphModule, example_inputs): + # Hack to get around circular import problems with aot_eager_decomp_partition + if callable(kwargs.get("decompositions")): + kwargs["decompositions"] = kwargs["decompositions"]() + + # NB: dont delete counter increment + counters["aot_autograd"]["total"] += 1 + use_fallback = False + + if use_fallback: + log.debug("Unable to use AOT Autograd because graph has mutation") + counters["aot_autograd"]["not_ok"] += 1 + return gm + + # OK attempt to compile + + def _wrapped_bw_compiler(*args, **kwargs): + # stop TorchDynamo from trying to compile our generated backwards pass + return disable(disable(bw_compiler)(*args, **kwargs)) + + bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] + kwargs["bw_compiler"] = _wrapped_bw_compiler + kwargs["inference_compiler"] = ( + kwargs.get("inference_compiler") or kwargs["fw_compiler"] + ) + + from functorch.compile import nop + + from torch._inductor.debug import enable_aot_logging + + # debug asserts slow down compile time noticeably, + # So only default them on when the aot_eager backend is used. 
+ if kwargs.get("fw_compiler", None) == nop: + patch_config = patch("functorch.compile.config.debug_assert", True) + else: + patch_config = contextlib.nullcontext() + + try: + # NB: NOT cloned! + with enable_aot_logging(), patch_config: + cg = aot_module_simplified(gm, example_inputs, **kwargs) + counters["aot_autograd"]["ok"] += 1 + return disable(cg) + except Exception: + counters["aot_autograd"]["not_ok"] += 1 + raise + + return compiler_fn + + +def mem_efficient_fusion_kwargs(use_decomps): + from functorch.compile import ( + default_decompositions, + min_cut_rematerialization_partition, + ts_compile, + ) + + kwargs = { + # these are taken from memory_efficient_fusion() + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + } + + if use_decomps: + kwargs["decompositions"] = default_decompositions + + return kwargs + + +def fake_tensor_unsupported(fn): + """ + Decorator for backends that need real inputs. We swap out fake + tensors for zero tensors. + """ + + @functools.wraps(fn) + def wrapper(model, inputs, **kwargs): + with _disable_current_modes(): + inputs = list(map(defake, inputs)) + return fn(model, inputs, **kwargs) + + return wrapper + + +def device_from_inputs(example_inputs) -> torch.device: + for x in example_inputs: + if hasattr(x, "device"): + return x.device + + +def dtype_from_inputs(example_inputs) -> torch.dtype: + for x in example_inputs: + if hasattr(x, "dtype"): + return x.dtype diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..76c57c505905ce797365088111ab28a77ab2b722 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py @@ -0,0 +1,239 @@ +# mypy: ignore-errors + +import functools +import operator +from collections import defaultdict +from typing import Dict, List, Optional + +import torch +from torch._dynamo.backends.debugging import boxed_nop +from torch._inductor.cudagraph_trees import cudagraphify_impl +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + check_multiple_devices_or_any_cpu_nodes, + get_mutation_stack_trace, +) +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + has_incompatible_cudagraph_ops, + num_fw_fixed_arguments, + output_node, +) +from torch.multiprocessing.reductions import StorageWeakRef +from .common import aot_autograd +from .registry import register_backend + +perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + + +def find_input_mutations(g): + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + inputs = defaultdict(set) + input_idx = 0 + mutated_inputs = set() + for n in g.nodes: + if n.op == "placeholder": + if isinstance(meta_fk(n.meta), torch.Tensor): + inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) + input_idx += 1 + elif n.op == "call_function": + if n.target is operator.getitem: + continue + schema = n.target._schema + for i, arg in enumerate(schema.arguments): + if i < len(n.args): + argument = n.args[i] + else: + if arg.name not in n.kwargs: + continue + argument = n.kwargs[arg.name] + mut_arg = False + if arg.alias_info: + if arg.alias_info.is_write: + mut_arg = True + if mut_arg: + # TODO: not correct for args that contain tensors in a struct + # like list + mutated_inputs |= inputs[ + StorageWeakRef(meta_fk(argument.meta)._typed_storage()) + ] + + # TODO: error on unrecognized 
nodes + return mutated_inputs + + +def get_device_node_mapping(gm: torch.fx.GraphModule): + device_node_mapping: Dict[torch.device, torch.fx.Node] = {} + for n in gm.graph.nodes: + t = n.meta.get("val", None) + if isinstance(t, torch.Tensor) and t.device not in device_node_mapping: + device_node_mapping[t.device] = n + return device_node_mapping + + +def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]: + mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed)) + if not mutation_indices: + return None + + return get_mutation_stack_trace(aot_model, mutation_indices) + + +def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]: + if mut_skip := check_for_mutation(aot_model, num_fixed): + return mut_skip + + if skip := check_multiple_devices_or_any_cpu_nodes( + get_device_node_mapping(aot_model) + ): + return skip + + if has_incompatible_cudagraph_ops(aot_model): + return "skipping cudagraphs due to incompatible op" + + return None + + +def get_device_index(gm) -> int: + device = next(iter(get_device_node_mapping(gm))) + assert device.type == "cuda" + return device.index + + +def get_stack_traces(gm) -> List[Optional[str]]: + output = output_node(gm) + assert len(output.args) == 1 + return [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + + +def cudagraphs(dynamo_model, dynamo_inputs): + do_cudagraphs = BoxedBool(True) + boxed_device_index = BoxedDeviceIndex(None) + + def forward_cudagraphs(aot_model, aot_inputs, is_inference=False): + interp = boxed_nop(aot_model, aot_inputs) + fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) + if skip_msg := check_for_skip(aot_model, fixed): + BoxedBool.disable(do_cudagraphs) + perf_log.warning("skipping cudagraphs due to %s", skip_msg) + return interp + + boxed_device_index.set(get_device_index(aot_model)) + + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=boxed_device_index.value, + is_backward=False, + is_inference=False, + stack_traces=get_stack_traces(aot_model), + ) + out._boxed_call = True + return out + + def backward_cudagraphs(aot_model, aot_inputs): + interp = boxed_nop(aot_model, aot_inputs) + if not do_cudagraphs: + return aot_model + + fixed = count_tangents(aot_model) + if skip_msg := check_for_skip(aot_model, fixed): + perf_log.warning("skipping cudagraphs due to %s", skip_msg) + + # See [Backward Generation Handling] + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_device_index.value, create_if_none_exists=False + ) + assert manager is not None + + def fn(inputs): + manager.set_to_running_backward() + return aot_model(inputs) + + fn._boxed_call = True + return fn + + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=get_device_index(aot_model), + is_backward=True, + is_inference=False, + stack_traces=get_stack_traces(aot_model), + ) + out._boxed_call = True + return out + + aot_cudagraphs = aot_autograd( + fw_compiler=forward_cudagraphs, + bw_compiler=backward_cudagraphs, + inference_compiler=functools.partial(forward_cudagraphs, is_inference=True), + keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation, + ) + return aot_cudagraphs(dynamo_model, dynamo_inputs) + + +class CudagraphsBackend: + compiler_name = "cudagraphs" + + @staticmethod + def reset(): + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + @staticmethod + def 
__call__(model, inputs): + return cudagraphs(model, inputs) + + +# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful +# for debugging and can serve as a perf baseline. +register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend()) + + +def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True): + """This isn't registered as a backend, but is used in some benchmarks""" + assert isinstance(inputs, (list, tuple)) + if copy_inputs: + static_inputs = [torch.zeros_like(x) for x in inputs] + else: + static_inputs = list(inputs) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + model(*inputs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + static_outputs = model(*static_inputs) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + def run(*new_inputs): + assert len(static_inputs) == len(new_inputs) + if copy_inputs: + for dst, src in zip(static_inputs, new_inputs): + dst.copy_(src) + graph.replay() + if copy_outputs: + return [x.clone() for x in static_outputs] + else: + return static_outputs + + return run diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/debugging.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdc89fb699e2a07099f7d09e9fc4d3b1d8f3a43 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/debugging.py @@ -0,0 +1,289 @@ +# mypy: ignore-errors + +import dataclasses +import functools +from importlib import import_module +from typing import Any, List, Optional + +from functorch.compile import min_cut_rematerialization_partition + +import torch +from torch import _guards +from torch._functorch.compilers import ts_compile +from .common import aot_autograd +from .registry import register_debug_backend as register_backend + +""" +This file contains TorchDynamo backends intended for debugging uses. +""" + + +@register_backend +def eager(gm, fake_tensor_inputs): + return gm + + +@register_backend +def pre_dispatch_eager(gm, fake_tensor_inputs): + from torch.fx.experimental.proxy_tensor import make_fx + + def runnable_gm(*args): + return torch.fx.Interpreter(gm).run(*args) + + pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) + pre_dispatch_gm.print_readable() + + return pre_dispatch_gm + + +@register_backend +def eager_debug(gm, fake_tensor_inputs): + from torch._subclasses.schema_check_mode import SchemaCheckMode + + # We could add more debugging bits here. + # Right now, this backend can be used to check for and error on + # custom dispatcher ops that have incorrect schemas. + def inner(*args): + with SchemaCheckMode(): + return torch.fx.Interpreter(gm).run(*args) + + return inner + + +@register_backend(name="ts") +def torchscript(gm, fake_tensor_inputs): + return torch.jit.script(gm) + + +# used boxed call to discard inputs when they are no longer needed +def boxed_nop(fx_g, example_inputs): + def run(args): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +# Useful for debugging purpose +# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. 
+aot_eager = aot_autograd( + fw_compiler=boxed_nop, partition_fn=min_cut_rematerialization_partition +) +register_backend(name="aot_eager", compiler_fn=aot_eager) + +aot_eager_default_partitioner = aot_autograd(fw_compiler=boxed_nop) +register_backend( + name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner +) + +# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs +# inductor problems. +# aot_eager_decomp_partition just replaces the inductor compiler with nop to help +# isolate inductor vs aot_eager errors +aot_eager_decomp_partition = aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=boxed_nop, + bw_compiler=boxed_nop, + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), +) +register_backend( + name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition +) + +# AOT Autograd with torchscript backend. Default partitioner. +# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser +# by using the relevant fuser with torch.jit.fuser(...) +aot_ts = aot_autograd(fw_compiler=ts_compile) +register_backend(name="aot_ts", compiler_fn=aot_ts) + +# These buggy backends are used for inducing bugs so that we can test +# our repro extraction / minifier scripts + + +class ReluCompileError(Exception): + pass + + +class TestingOnlyCompileError(Exception): + pass + + +@register_backend +def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise ReluCompileError() + return gm + + +@register_backend +def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "ReluRuntimeError") + gm.recompile() + return gm + + +@register_backend +def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm + + +@register_backend +def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + # Require at least one non-trivial thing in the graph, + # see https://github.com/pytorch/pytorch/issues/102898 + for node in gm.graph.nodes: + if node.op == "call_function": + break + else: + return gm + for t in example_inputs: + if not t.is_leaf: + raise TestingOnlyCompileError() + return gm + + +@dataclasses.dataclass +class ExplainOutput: + """ + This is the output of :func:`torch._dynamo.explain()` + There is no reason to create this class directly. 
+ """ + + graphs: List[torch.fx.GraphModule] + graph_count: int + graph_break_count: int + break_reasons: List[ + Any + ] # Type is GraphCompileReason but doesn't matter for this purpose + op_count: int + ops_per_graph: Optional[List[torch.fx.Node]] = None + out_guards: Optional[List[_guards.Guard]] = None + compile_times: Optional[str] = None + + def __str__(self): + output = f"Graph Count: {self.graph_count}\n" + output += f"Graph Break Count: {self.graph_break_count}\n" + output += f"Op Count: {self.op_count}\n" + + output += "Break Reasons:\n" + for idx, break_reason in enumerate(self.break_reasons): + output += f" Break Reason {idx+1}:\n" + output += f" Reason: {break_reason.reason}\n" + output += " User Stack:\n" + for frame_summary in break_reason.user_stack: + output += f" {frame_summary}\n" + + if self.ops_per_graph is not None: + output += "Ops per Graph:\n" + for idx, ops in enumerate(self.ops_per_graph): + output += f" Ops {idx+1}:\n" + for op in ops: + output += f" {op}\n" + + if self.out_guards is not None: + output += "Out Guards:\n" + for i, guard in enumerate(self.out_guards): + output += f" Guard {i+1}:\n" + output += f" {str(guard)}" + + if self.compile_times is not None: + output += f"Compile Times: {self.compile_times}\n" + return output + + +def _explain_graph_detail( + gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons +): + """ + This function is a utility which processes a torch.fx.GraphModule and + accumulates information about its ops, graph breaks, and other details. It + is intended to be used by the ExplainWithBackend class and + `torch._dynamo.explain()` to provide details from Dynamo's graph capture. + + Parameters: + gm (torch.fx.GraphModule): The GraphModule to be processed. + graphs (list): A list that accumulates all the GraphModules processed. + op_count (int): The total count of operations in all GraphModules processed so far. + ops_per_graph (list): A list that accumulates the operations of each GraphModule. + break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule. + + Returns: + tuple: A tuple containing the processed GraphModule, the updated lists of graphs, + operations per graph, and break reasons, and the updated operation count. + """ + graphs.append(gm) + ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] + op_count += len(ops) + ops_per_graph.append(ops) + if gm.compile_subgraph_reason.graph_break: + break_reasons.append(gm.compile_subgraph_reason) + + return gm, graphs, op_count, ops_per_graph, break_reasons + + +class ExplainWithBackend: + """ + This class is intended to be used as a backend for `torch.compile`. It is + composable with other backends. When used in this way, it accumulates + information about graph breaks, ops, and other info and provides a string + representation summarizing this information. + + Attributes: + backend (str): The name of the backend to use for optimization. + graphs (list): A list of the graphs captured by TorchDynamo. + op_count (int): The total number of operations in all optimized graphs. + break_reasons (list): A list of graph break reasons with stack traces. 
+ + Example Usage: + def fn(x): + x = torch.sigmoid(x) + return x + + torch._dynamo.reset() + eb = ExplainWithBackend("inductor") + optimized_fn = torch.compile(fn, backend=eb) + result = optimized_fn(torch.randn(5)) + print(eb.output()) + """ + + def __init__(self, backend): + from .registry import lookup_backend + + self.backend = lookup_backend(backend) + self.graphs = [] + self.op_count = 0 + self.break_reasons = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( + gm, self.graphs, self.op_count, [], self.break_reasons + ) + return self.backend(gm, example_inputs) + + def output(self) -> ExplainOutput: + graph_count = len(self.graphs) + output = ExplainOutput( + self.graphs, + graph_count, + graph_count - 1, + self.break_reasons, + self.op_count, + ) + + return output diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/distributed.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6812035e21b663a1e75df98ceb5ed9fbfbe1cd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/distributed.py @@ -0,0 +1,612 @@ +# mypy: ignore-errors + +import logging +import traceback +from dataclasses import dataclass, field +from typing import Any, List, Optional +from unittest import mock + +import torch +from torch import fx +from torch._dynamo.output_graph import GraphCompileReason +from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.node import Node + +# Regular log messages should go through 'log'. +# ddp_graph_log is a separate artifact logger reserved for dumping graphs. +# See docs/source/logging.rst for more info. 
+log = logging.getLogger(__name__) +ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs") + + +def args_str(args): + # a debug helper + if torch.is_tensor(args): + return f"T[{args.shape}]" + elif isinstance(args, tuple): + return f"tuple({', '.join([args_str(x) for x in args])})" + elif isinstance(args, list): + return f"list({', '.join([args_str(x) for x in args])})" + else: + return str(args) + + +@dataclass +class Bucket: + size: int = 0 + params: List[str] = field(default_factory=list) + nodes: List[fx.Node] = field(default_factory=list) + + # param_ids is just used for unit testing + param_ids: List = field(default_factory=list) + + # keep track of any buckets that were extended for logging purposes + opcount_increased_to_capture_external_output: int = 0 + paramsize_before_opcount_increase: int = 0 + + +def bucket_has_external_output(bucket: Bucket) -> bool: + nodes_in_bucket = set() + # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards + # so we don't reverse it here + for node in bucket.nodes: + # assume node.op != output, since those are filtered in the original iteration + nodes_in_bucket.add(node) + for user in node.users: + if user not in nodes_in_bucket: + return True + return False + + +def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int): + headers = ("Index", "Size (b)", "Param Names") + rows = [] + extended_buckets = [] + for idx, bucket in enumerate(reversed(buckets)): + if len(bucket.params) > 0: + rows.append((idx, bucket.size, bucket.params[0])) + for param in bucket.params[1:]: + rows.append((None, None, param)) + if bucket.opcount_increased_to_capture_external_output > 0: + extended_buckets.append( + ( + idx, + bucket.opcount_increased_to_capture_external_output, + bucket.size - bucket.paramsize_before_opcount_increase, + ) + ) + + if len(rows): + log.info( + "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.", + bucket_bytes_cap, + len(buckets), + ) + + if len(extended_buckets): + log.warning( + "Some buckets were extended beyond their requested parameter capacities" + " in order to ensure each subgraph has an output node, required for fx graph partitioning." + " This can be the case when a subgraph would have only contained nodes performing inplace mutation," + " and returning no logical outputs. This should not be a problem, unless it results in too few graph" + " partitions for optimal DDP performance." + ) + + try: + from tabulate import tabulate + + log.debug( + "\nDDPOptimizer produced the following bucket assignments:\n%s", + tabulate(rows, headers=headers, tablefmt="simple_grid"), + ) + + if len(extended_buckets): + log.warning( + "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s", + tabulate( + extended_buckets, + headers=("Index", "Extra Ops", "Extra Param Size (b)"), + tablefmt="simple_grid", + ), + ) + except ImportError: + log.debug( + "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information." 
+ ) + else: + log.debug("DDPOptimizer captured no parameters and did not split this graph.") + + +def has_higher_order_op(gm): + # Check if there is a higher order op in the graph + for node in gm.graph.nodes: + if node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if isinstance(maybe_param, torch.fx.GraphModule): + return True + return False + + +# 3 (lazy compile): Replace submodules with lazily compiling submodule +class SubmoduleReplacer(torch.fx.interpreter.Interpreter): + def __init__(self, module, compiler): + super().__init__(module) + self.compiler = compiler + + def lazily_compiled_submod(self, input_mod): + """ + Create a wrapper around submodules which: + - lazily compiles each of the partitioned submodules using the user-provided compiler + - unpacks singleton tuples/lists into flat arg + """ + + class LazilyCompiledModule(torch.nn.Module): + def __init__(self, submod, compiler, unwrap_singleton_tuple): + super().__init__() + self.submod = submod + self.compiler = compiler + self.compiled = False + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, *args): + if not self.compiled: + # First compile with args as example_inputs + # These args will be fakeified if using Inductor/AOTAutograd + new_submod = self.compiler(self.submod, args) + del self.submod + self.submod = new_submod + self.compiled = True + self.compiler = None + + x = self.submod(*args) + # we must let 'input_mod' return a tuple, to make AOT happy. + # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple). + # however, we don't acutally want this tuple to be returned, since the fx logic that calls the submod + # will again wrap outputs from the submod in a tuple. So we unwrap it, and count on it being re-wrapped + if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in input_mod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + + input_mod.recompile() + input_mod.compile_subgraph_reason = GraphCompileReason( + "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." + " Set `torch._dynamo.config.optimize_ddp = False` to disable.", + [ + # it's close to useless to get a real stacktrace here, and quite verbose. + traceback.FrameSummary(__file__, 0, DDPOptimizer), + ], + ) + wrapper = LazilyCompiledModule( + input_mod, + self.compiler, + unwrap_singleton_tuple, + ) + return wrapper + + # We replace the submodules with lazy submodules which compile + # the corresponding submodules when they are run with real values + # Always returns `None` - we do not need to propagate values in order + # to replace submodules. + def run_node(self, n: Node) -> Any: + if n.op == "call_module": + real_mod = self.fetch_attr(n.target) + + ddp_graph_log.debug("\n---%s graph---\n%s", n.target, real_mod.graph) + + assert len(n.kwargs) == 0, "We assume only args for these modules" + lazily_compiled_submod = self.lazily_compiled_submod(real_mod) + + # We update the original (outer) graph with a call into the compiled module + # instead of the uncompiled one. 
+ self.module.delete_submodule(n.target) + n.target = "compiled_" + n.target + self.module.add_submodule(n.target, lazily_compiled_submod) + + +# 3 (no lazy compile): compile each of the partitioned submodules using the user-provided compiler +class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__(self, module, compiler, fake_mode): + super().__init__(module) + self.compiler = compiler + self.fake_mode = fake_mode + + def compile_submod(self, input_mod, args, kwargs): + """ + Compile the submodule, + using a wrapper to make sure its output is always a tuple, + which is required by AotAutograd based compilers + """ + assert len(kwargs) == 0, "We assume only args for these modules" + + class WrapperModule(torch.nn.Module): + def __init__(self, submod, unwrap_singleton_tuple): + super().__init__() + self.submod = submod + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, *args): + x = self.submod(*args) + # TODO(whc) + # for some reason the isinstance check is necessary if I split one node per submod + # - even though I supposedly wrapped the output in a tuple in those cases, the real + # compiled module was still returning a tensor + if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in input_mod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + + input_mod.recompile() + input_mod.compile_subgraph_reason = GraphCompileReason( + "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." + " Set `torch._dynamo.config.optimize_ddp = False` to disable.", + [ + # it's close to useless to get a real stacktrace here, and quite verbose. + traceback.FrameSummary(__file__, 0, DDPOptimizer), + ], + ) + + wrapper = WrapperModule( + self.compiler(input_mod, args), + unwrap_singleton_tuple, + ) + return wrapper + + # Note: + # + # The way distributed works today around fake tensors can be somewhat confusing. + # Some of these codepaths are shared in both runtime, and compile time. The presence + # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. + # + # A few things to keep in mind: + # + # 1) We invoke `compile_submod` with a real module. The output of that gets stored + # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. + # + # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the + # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. + # + # 3) Fake tensors should always be around during compile time. + # + # 4) Fake tensors should never be around at runtime. + # + # 5) We end up with a compilation mode that takes a real submodule and fake tensors, + # to match what aot_autograd expects. 
See Note: [Fake Modules and AOTAutograd] + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + new_args = [] + assert self.fake_mode + for arg in args: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor + ): + new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode)) + else: + new_args.append(arg) + + log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args)) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + if n.op == "call_module": + real_mod = self.fetch_attr(n.target) + if self.fake_mode: + curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode) + else: + curr_submod = real_mod + + ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph) + + # When calling the compiler on the submod, inputs (new_args) are expected to + # be FakeTensors already since Dynamo would have made them FakeTensors in the + # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors, + # since this wrapping happens during compilation + + # Note: Returning Fake Tensors on First AOT Autograd Call + # + # Inductor will optimize strides of outputs when it deems it profitable. + # For instance, converting to channels last. When we split the graph here + # into multiple inductor compilations, we need to make sure that the + # output strides of one compilation is appropriately passed to the subsequent + # compilations. However, the mapping from inductor output to dynamo output + # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing, + # subclass handling, etc. In order to replay all this logic we set a flag such that + # the first invocation of inductor in aot_autograd will return Fake Tensors with + # appropriate strides. Then, all of aot autograd's runtime logic is replayed. + # This gives us the appropriately strided outputs here which will reflect runtime strides. + + class FakeifyFirstAOTInvocationGuard: + def __init__(self): + self.tc = torch._guards.TracingContext.try_get() + assert self.tc + torch._guards.TracingContext.try_get().fakify_first_call = True + + def __del__(self): + self.tc.fakify_first_call = False + + # For aot_eager and other backends, tracing context is not set + has_tracing_context = torch._guards.TracingContext.try_get() is not None + if has_tracing_context: + g = FakeifyFirstAOTInvocationGuard() + + from torch._dynamo.utils import counters + + init = counters["aot_autograd"]["total"] + compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs) + + # TODO - better way of doing this? + # Only aot autograd handles fakifying first call + invoked_aot_autograd = init != counters["aot_autograd"]["total"] + + # We update the original (outer) graph with a call into the compiled module + # instead of the uncompiled one. 
+ self.module.delete_submodule(n.target) + n.target = "compiled_" + n.target + self.module.add_submodule(n.target, compiled_submod_real) + + # Finally, we have to produce inputs for use compiling the next submodule, + # and these need to be FakeTensors, so we execute the module under fake_mode + # Because parameters are not fake we patch fake tensor mode to allow non fake inputs + with self.fake_mode, mock.patch.object( + self.fake_mode, "allow_non_fake_inputs", True + ): + if has_tracing_context and invoked_aot_autograd: + out = compiled_submod_real(*new_args, **kwargs) + # output should be fake or subclass + assert all( + (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor) + for t in (out if isinstance(out, (list, tuple)) else [out]) + ) + return out + else: + return curr_submod(*new_args, **kwargs) + else: + # placeholder or output nodes don't need to get compiled, just executed + return getattr(self, n.op)(n.target, new_args, kwargs) + + +class DDPOptimizer: + + """Note [DDPOptimizer] + DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP), + breaking the dynamo graph into chunks to compile separately, with the breaks aligning to + the boundaries of gradient-allreduce buckets chosen by DDP. + + Background/Motivation + - DDP uses allreduce collectives to synchronize partial gradients computed on different workers + - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce + - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready + at around the same time during backward and thus can share the same allreduce efficiently + - Allreduces must overlap with backward compute for optimal training performance + - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which + operates when individual grads become 'ready' + - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the + autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole + fused backward function executes, preventing any overlap of compute and communication + + Algorithm + - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse + this graph in reverse order to determine the true order that gradients will become ready during backward. + - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started + and a graph break introduced + - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together + into an outer module that is returned to the user + + Notes + - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP, + and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does + in eager. + - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently + produce splits that do not necessarily align with the buckets used by DDP. This should result in performance + degradation approaching the baseline case where graph-splits are not used, but not worse. 
+ - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the + subgraphs being compiled + - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers + left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are + also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP, + it is not catastrophic but could impact performance by choosing sub-optimal bucket splits. + - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients, + and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by + DDPOptimizer) + + Debugging + - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb. + - In many cases, the log messages are helpful (they show bucket size assignments)- + just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'. + - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model + in a single process (or with torchrun, in multiple processes) + + Args: + bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be + set to match the equivalent parameter on the original DDP module. + + backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph. + + first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP + special-cases the first bucket size since it is sometimes optimal to start a small allreduce early. + + """ + + def __init__( + self, + bucket_bytes_cap: int, + backend_compile_fn, + first_bucket_cap: Optional[int] = None, + ): + if first_bucket_cap is not None: + self.first_bucket_cap = first_bucket_cap + elif torch.distributed.is_available(): + # this constant comes from C10D lib which is not always built + self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES + else: + self.first_bucket_cap = bucket_bytes_cap + + self.bucket_bytes_cap = bucket_bytes_cap + assert ( + self.first_bucket_cap <= self.bucket_bytes_cap + ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP" + + self.backend_compile_fn = backend_compile_fn + + def _ignore_parameter(self, parameter): + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored + + def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]): + """ + Implements graph splitting, first determining a set of of buckets by counting + parameter sizes in reverse graph order, then invoking the user/backend compiler + to compile each subgraph. Finally, stiches compiled graphs into one graphmodule + and returns its callable. + """ + if has_higher_order_op(gm): + # This indicates presence of a higher order op. For now, we + # have no way to break the higher order op into two buckets. + # Allowing higher order ops in the graph also requires + # changes in the split_module, becuase graph splitter + # currently assumes that all the args of all ops are + # tensors, but in the case of higher order ops, it could be + # a graph module. As a workaround, we are shortcircuiting + raise NotImplementedError( + "DDPOptimizer backend: Found a higher order op in the graph. " + "This is not supported. Please turn off DDP optimizer using " + "torch._dynamo.config.optimize_ddp=False. 
Note that this can " + "cause performance degradation because there will be one bucket " + "for the entire Dynamo graph. Please refer to this issue - " + "https://github.com/pytorch/pytorch/issues/104674." + ) + + # 1: compute the partition map according to DDP bucket logic + buckets = [Bucket()] # (size, param_names) + for node in reversed(gm.graph.nodes): + if node.op in ("output", "placeholder"): + continue + + if ( + buckets[0].size >= self.bucket_bytes_cap + or len(buckets) == 1 + and buckets[0].size >= self.first_bucket_cap + ): + if bucket_has_external_output(buckets[0]): + buckets.insert(0, Bucket()) + else: + # continue building this bucket past the point of filling its parameter capacity, + # to increase chances it contains at least one node that is either a global output or + # passed as input to a subsequent graph + + if buckets[0].opcount_increased_to_capture_external_output == 0: + buckets[0].paramsize_before_opcount_increase = buckets[0].size + buckets[0].opcount_increased_to_capture_external_output += 1 + + if node.op == "call_module": + target = gm.get_submodule(node.target) + for name, param in target.named_parameters(): + if param.requires_grad and not self._ignore_parameter(param): + buckets[0].size += param.untyped_storage().nbytes() + buckets[0].params.append(f"{node.target}_{name}") + buckets[0].param_ids.append(id(param)) + elif node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if maybe_param.requires_grad and not self._ignore_parameter( + maybe_param + ): + buckets[0].size += maybe_param.untyped_storage().nbytes() + buckets[0].params.append(node.target) + buckets[0].param_ids.append(id(maybe_param)) + + # All nodes have to be mapped to a bucket, even if they don't have their own params + # Ignored params still end up in buckets, we just don't count them towards the capacity + buckets[0].nodes.append(node) + + if len(buckets) > 1 and buckets[0].size == 0: + # we collected a small preamble graph with ops that don't include parameters, fuse it back + buckets[1].nodes.extend(buckets[0].nodes) + assert len(buckets[0].params) == 0, "Params should be empty if size is 0" + del buckets[0] + + # stash buckets for testing/debugging purposes + self.buckets = buckets + pretty_print_buckets(buckets, self.bucket_bytes_cap) + + if len(buckets) == 1: + # bypass split/fuse logic if there is only one bucket + return self.backend_compile_fn(gm, example_inputs) + + # 2: partition the graphmodule according to bucket capacity + partition_map = {} + for idx, b in enumerate(buckets): + for node in b.nodes: + partition_map[node] = idx + + split_gm = fx.passes.split_module.split_module( + gm, None, lambda node: partition_map[node] + ) + + debug_str = ( + f"\n---orig graph---\n{gm.graph}\n" + + f"\n---split graph---\n{split_gm.graph}\n" + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # only print the submod graphs, not their children + debug_str += f"\n---{name} graph---\n{module.graph}\n" + debug_str += "\n---------------\n" + ddp_graph_log.debug(debug_str) + + trace_structured( + "optimize_ddp_split_graph", + payload_fn=lambda: split_gm.print_readable(print_output=False), + ) + for name, module in split_gm.named_modules(): + if "." 
not in name and len(name): + trace_structured( + "optimize_ddp_split_child", + lambda: {"name": name}, + payload_fn=lambda: module.print_readable(print_output=False), + ) + + # NOTE, we want to enable `optimize_ddp_lazy_compile` by default as soon as possible, + # becuase it will fix stride mismatch errors (see motivation: https://github.com/pytorch/pytorch/pull/114154). + # However, lazy compile currently causes shape mismatch in other cases (`test_graph_split_inductor_transpose`) + # and we need to fix them before we can enable it by default. + if not torch._dynamo.config.optimize_ddp_lazy_compile: + # Today, optimize_ddp=True and keep_output_stride=False can lead to silent + # correctness issues. The problem is that ddp_optimizer works by partitioning + # the dynamo graph, sending each subgraph through aot autograd to inductor, + # and creates example inputs by eagerly interpreting each subgraph to get + # an output that with the same metadata that we'd get from eager mode. + # This is a problem though, for torch._inductor.config.keep_output_stride. + # The above config can cause the outputs of the first graph to have + # **different** strides from eager, causing the inputs that we pass + # to the second graph to be wrong. + # To really fix this, we would need to faithfully ask inductor + # what the outputs to each graph it expects are. + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + + if torch._dynamo.config.optimize_ddp_lazy_compile: + submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn) + else: + submod_compiler = SubmodCompiler( + split_gm, self.backend_compile_fn, fake_mode + ) + submod_compiler.run(*example_inputs) + split_gm.recompile() + + ddp_graph_log.debug( + "\n---final graph---\n%s\n---------------\n", split_gm.graph + ) + return split_gm diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/inductor.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/inductor.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a661378b616ce3e7975f0fb330b029ce9c7142 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/inductor.py @@ -0,0 +1,16 @@ +# mypy: ignore-errors + +import sys + +from torch._dynamo import register_backend + + +@register_backend +def inductor(*args, **kwargs): + if sys.platform == "win32": + raise RuntimeError("Windows not yet supported for inductor") + + # do import here to avoid loading inductor into memory when it is not used + from torch._inductor.compile_fx import compile_fx + + return compile_fx(*args, **kwargs) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/onnxrt.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/onnxrt.py new file mode 100644 index 0000000000000000000000000000000000000000..54e7a3c95f2f42f709bf38d78e088962f7b00df6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/onnxrt.py @@ -0,0 +1,37 @@ +# mypy: ignore-errors + +# This backend is maintained by ONNX team. To direct issues +# to the right people, please tag related GitHub issues with `module: onnx`. 
+# +# Maintainers' Github IDs: wschin, thiagocrepaldi, BowenBao, abock +from torch.onnx._internal.onnxruntime import ( + is_onnxrt_backend_supported, + torch_compile_backend, +) +from .registry import register_backend + + +def has_onnxruntime(): + # FIXME(abock): update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() + return is_onnxrt_backend_supported() + + +if is_onnxrt_backend_supported(): + register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +else: + + def information_displaying_backend(*args, **kwargs): + raise ImportError( + "onnxrt is not registered as a backend. " + "Please make sure all dependencies such as " + "numpy, onnx, onnxscript, and onnxruntime-training are installed. " + "Suggested procedure to fix dependency problem:\n" + " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" + " (2) Open a new python terminal.\n" + " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" + " (4) If it returns `True`, then you can use `onnxrt` backend.\n" + " (5) If it returns `False`, please execute the package importing section in " + "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." + ) + + register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/registry.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..038d9156838b876a4d44f80f07f6e5adccd89d57 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/registry.py @@ -0,0 +1,115 @@ +# mypy: ignore-errors + +import functools +import sys +from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple + +import torch +from torch import fx + + +class CompiledFn(Protocol): + def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: + ... + + +CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn] + +_BACKENDS: Dict[str, CompilerFn] = dict() + + +def register_backend( + compiler_fn: Optional[CompilerFn] = None, + name: Optional[str] = None, + tags: Sequence[str] = (), +): + """ + Decorator to add a given compiler to the registry to allow calling + `torch.compile` with string shorthand. Note: for projects not + imported by default, it might be easier to pass a function directly + as a backend and not use a string. 
+ + Args: + compiler_fn: Callable taking a FX graph and fake tensor inputs + name: Optional name, defaults to `compiler_fn.__name__` + tags: Optional set of string tags to categorize backend with + """ + if compiler_fn is None: + # @register_backend(name="") syntax + return functools.partial(register_backend, name=name, tags=tags) + assert callable(compiler_fn) + name = name or compiler_fn.__name__ + assert name not in _BACKENDS, f"duplicate name: {name}" + _BACKENDS[name] = compiler_fn + compiler_fn._tags = tuple(tags) + return compiler_fn + + +register_debug_backend = functools.partial(register_backend, tags=("debug",)) +register_experimental_backend = functools.partial( + register_backend, tags=("experimental",) +) + + +def lookup_backend(compiler_fn): + """Expand backend strings to functions""" + if isinstance(compiler_fn, str): + if compiler_fn not in _BACKENDS: + _lazy_import() + if compiler_fn not in _BACKENDS: + _lazy_import_entry_point(compiler_fn) + if compiler_fn not in _BACKENDS: + from ..exc import InvalidBackend + + raise InvalidBackend(name=compiler_fn) + compiler_fn = _BACKENDS[compiler_fn] + return compiler_fn + + +def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: + """ + Return valid strings that can be passed to: + + torch.compile(..., backend="name") + """ + _lazy_import() + exclude_tags = set(exclude_tags or ()) + return sorted( + [ + name + for name, backend in _BACKENDS.items() + if not exclude_tags.intersection(backend._tags) + ] + ) + + +@functools.lru_cache(None) +def _lazy_import(): + from .. import backends + from ..utils import import_submodule + + import_submodule(backends) + + from ..repro.after_dynamo import dynamo_minifier_backend + + assert dynamo_minifier_backend is not None + + +@functools.lru_cache(None) +def _lazy_import_entry_point(backend_name: str): + from importlib.metadata import entry_points + + compiler_fn = None + group_name = "torch_dynamo_backends" + if sys.version_info < (3, 10): + backend_eps = entry_points() + eps = [ep for ep in backend_eps.get(group_name, ()) if ep.name == backend_name] + if len(eps) > 0: + compiler_fn = eps[0].load() + else: + backend_eps = entry_points(group=group_name) + if backend_name in backend_eps.names: + compiler_fn = backend_eps[backend_name].load() + + if compiler_fn is not None and backend_name not in list_backends(tuple()): + register_backend(compiler_fn=compiler_fn, name=backend_name) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/tensorrt.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2ba60cdeb0f6581e049f670088269919fa0fa5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/tensorrt.py @@ -0,0 +1,14 @@ +# mypy: ignore-errors + +# import torch # type: ignore[import] +# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import] +# from .registry import register_backend # type: ignore[import] + +""" +Placeholder for TensorRT backend for dynamo via torch-tensorrt +""" + +# @register_backend +# def tensorrt(gm, example_inputs): +# import torch_tensorrt # type: ignore[import] +# pass diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/torchxla.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/torchxla.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c50e11e8341311a240497ad4c95adbaad7de36 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/torchxla.py @@ -0,0 +1,75 @@ +# mypy: ignore-errors + +import 
logging +import warnings + +from functorch.compile import make_boxed_func + +from ..backends.common import aot_autograd +from .registry import register_backend, register_experimental_backend + +log = logging.getLogger(__name__) + + +@register_experimental_backend +def torchxla_trivial(gm, fake_tensor_inputs): + return gm + + +@register_experimental_backend +def torchxla_trace_once(model, fake_tensor_inputs): + warnings.warn( + "This backend will be deprecated in 2.2, please use `openxla` backend instead" + ) + + return xla_backend_helper(model, fake_tensor_inputs) + + +@register_backend +def openxla_eval(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=False) + + +def openxla_eval_boxed(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=True) + + +def xla_backend_helper(model, fake_tensor_inputs, boxed=False): + try: + import torch_xla.core.dynamo_bridge as bridge + except ImportError as e: + raise ImportError( + "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla" + ) from e + + compiled_graph = None + + def fwd(*args): + nonlocal model + nonlocal compiled_graph + if compiled_graph is None: + compiled_graph = bridge.extract_compiled_graph(model, args) + del model + return compiled_graph(*args) + + return make_boxed_func(fwd) if boxed else fwd + + +aot_torchxla_trivial = aot_autograd( + fw_compiler=torchxla_trivial, +) +register_experimental_backend( + name="aot_torchxla_trivial", compiler_fn=aot_torchxla_trivial +) + +aot_torchxla_trace_once = aot_autograd( + fw_compiler=torchxla_trace_once, +) +register_experimental_backend( + name="aot_torchxla_trace_once", compiler_fn=aot_torchxla_trace_once +) + +openxla = aot_autograd( + fw_compiler=openxla_eval_boxed, +) +register_backend(name="openxla", compiler_fn=openxla) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/backends/tvm.py b/MLPY/Lib/site-packages/torch/_dynamo/backends/tvm.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f1bd3c7cac8349d0e5a681fe1eaf4ed90386a7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/backends/tvm.py @@ -0,0 +1,172 @@ +# mypy: ignore-errors + +import functools +import importlib +import logging +import os +import tempfile + +import torch +from .common import device_from_inputs, fake_tensor_unsupported + +from .registry import register_backend + +log = logging.getLogger(__name__) + + +@register_backend +@fake_tensor_unsupported +def tvm(gm, example_inputs, *, scheduler=None, trials=20000): + import tvm # type: ignore[import] + from tvm import relay # type: ignore[import] + from tvm.contrib import graph_executor # type: ignore[import] + + jit_mod = torch.jit.trace(gm, example_inputs) + device = device_from_inputs(example_inputs) + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + example_outputs = gm(*example_inputs) + if len(example_outputs) == 0: + log.warning("Explicitly fall back to eager due to zero output") + return gm.forward + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + if scheduler is None: + scheduler = os.environ.get("TVM_SCHEDULER", None) + + if scheduler == "auto_scheduler": + from tvm import auto_scheduler + + log_file = tempfile.NamedTemporaryFile() + + if not os.path.exists(log_file): + tasks, task_weights = 
auto_scheduler.extract_tasks( + mod["main"], params, target + ) + for task in tasks: + print(task.compute_dag) + else: + print("No tasks") + if len(tasks) != 0: + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + if not os.path.exists(log_file): + assert trials > 0 + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + early_stopping=2000, + ) + try: + tuner.tune(tune_option) + except Exception: + if os.path.exists(log_file): + os.unlink(log_file) + raise + + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=3, config={"relay.backend.use_auto_scheduler": True} + ): + lib = relay.build(mod, target=target, params=params) + elif scheduler == "meta_schedule": + from tvm import meta_schedule as ms + + with tempfile.TemporaryDirectory() as work_dir: + if device.type != "cuda": + # meta_schedule needs num-cores to be specified + # here we use the maximum core count + target = tvm.target.Target( + f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}" + ) + # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch + # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + work_dir=work_dir, + max_trials_global=20000, + num_trials_per_iter=64, + params=params, + strategy="evolutionary", + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + ) + elif scheduler == "default" or not scheduler: + # no autotuning + with tvm.transform.PassContext(opt_level=10): + lib = relay.build(mod, target=target, params=params) + else: + raise NotImplementedError( + "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " + "There are three available options: default, auto_scheduler and meta_schedule." + ) + m = graph_executor.GraphModule(lib["default"](dev)) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. 
+ return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if torch_tensor.dtype == torch.bool: + # same reason as above, fallback to numpy conversion which + # could introduce data copy overhead + return tvm.nd.array(torch_tensor.cpu().numpy()) + return tvm.nd.from_dlpack(torch_tensor) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + shape_info, _ = m.get_input_info() + active_inputs = {name for name, _ in shape_info.items()} + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + inp_name = f"inp_{idx}" + if inp_name not in active_inputs: + log.warning( + "input %s skipped as not found in tvm's runtime library", + inp_name, + ) + continue + m.set_input( + inp_name, + to_tvm_tensor(arg), + ) + m.run() + return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())] + + return exec_tvm + + +tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule") +tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler") + + +def has_tvm(): + try: + importlib.import_module("tvm") + return True + except ImportError: + return False + + +@functools.lru_cache(None) +def llvm_target(): + if "avx512" in open("/proc/cpuinfo").read(): + return "llvm -mcpu=skylake-avx512" + return "llvm -mcpu=core-avx2" diff --git a/MLPY/Lib/site-packages/torch/_dynamo/bytecode_analysis.py b/MLPY/Lib/site-packages/torch/_dynamo/bytecode_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..5332ed5b7ec8e77cf449652c4c319ac644454572 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/bytecode_analysis.py @@ -0,0 +1,250 @@ +import bisect +import dataclasses +import dis +import sys +from typing import Any, Set, Union + +TERMINAL_OPCODES = { + dis.opmap["RETURN_VALUE"], + dis.opmap["JUMP_FORWARD"], + dis.opmap["RAISE_VARARGS"], + # TODO(jansel): double check exception handling +} +if sys.version_info >= (3, 9): + TERMINAL_OPCODES.add(dis.opmap["RERAISE"]) +if sys.version_info >= (3, 11): + TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"]) + TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) +else: + TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) +JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs) +JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES} +HASLOCAL = set(dis.haslocal) +HASFREE = set(dis.hasfree) + +stack_effect = dis.stack_effect + + +def get_indexof(insts): + """ + Get a mapping from instruction memory address to index in instruction list. + Additionally checks that each instruction only appears once in the list. + """ + indexof = {} + for i, inst in enumerate(insts): + assert inst not in indexof + indexof[inst] = i + return indexof + + +def remove_dead_code(instructions): + """Dead code elimination""" + indexof = get_indexof(instructions) + live_code = set() + + def find_live_code(start): + for i in range(start, len(instructions)): + if i in live_code: + return + live_code.add(i) + inst = instructions[i] + if inst.exn_tab_entry: + find_live_code(indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + find_live_code(indexof[inst.target]) + if inst.opcode in TERMINAL_OPCODES: + return + + find_live_code(0) + + # change exception table entries if start/end instructions are dead + # assumes that exception table entries have been propagated, + # e.g. 
with bytecode_transformation.propagate_inst_exn_table_entries, + # and that instructions with an exn_tab_entry lies within its start/end. + if sys.version_info >= (3, 11): + live_idx = sorted(live_code) + for i, inst in enumerate(instructions): + if i in live_code and inst.exn_tab_entry: + # find leftmost live instruction >= start + start_idx = bisect.bisect_left( + live_idx, indexof[inst.exn_tab_entry.start] + ) + assert start_idx < len(live_idx) + # find rightmost live instruction <= end + end_idx = ( + bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1 + ) + assert end_idx >= 0 + assert live_idx[start_idx] <= i <= live_idx[end_idx] + inst.exn_tab_entry.start = instructions[live_idx[start_idx]] + inst.exn_tab_entry.end = instructions[live_idx[end_idx]] + + return [inst for i, inst in enumerate(instructions) if i in live_code] + + +def remove_pointless_jumps(instructions): + """Eliminate jumps to the next instruction""" + pointless_jumps = { + id(a) + for a, b in zip(instructions, instructions[1:]) + if a.opname == "JUMP_ABSOLUTE" and a.target is b + } + return [inst for inst in instructions if id(inst) not in pointless_jumps] + + +def propagate_line_nums(instructions): + """Ensure every instruction has line number set in case some are removed""" + cur_line_no = None + + def populate_line_num(inst): + nonlocal cur_line_no + if inst.starts_line: + cur_line_no = inst.starts_line + + inst.starts_line = cur_line_no + + for inst in instructions: + populate_line_num(inst) + + +def remove_extra_line_nums(instructions): + """Remove extra starts line properties before packing bytecode""" + + cur_line_no = None + + def remove_line_num(inst): + nonlocal cur_line_no + if inst.starts_line is None: + return + elif inst.starts_line == cur_line_no: + inst.starts_line = None + else: + cur_line_no = inst.starts_line + + for inst in instructions: + remove_line_num(inst) + + +@dataclasses.dataclass +class ReadsWrites: + reads: Set[Any] + writes: Set[Any] + visited: Set[Any] + + +def livevars_analysis(instructions, instruction): + indexof = get_indexof(instructions) + must = ReadsWrites(set(), set(), set()) + may = ReadsWrites(set(), set(), set()) + + def walk(state, start): + if start in state.visited: + return + state.visited.add(start) + + for i in range(start, len(instructions)): + inst = instructions[i] + if inst.opcode in HASLOCAL or inst.opcode in HASFREE: + if "LOAD" in inst.opname or "DELETE" in inst.opname: + if inst.argval not in must.writes: + state.reads.add(inst.argval) + elif "STORE" in inst.opname: + state.writes.add(inst.argval) + elif inst.opname == "MAKE_CELL": + pass + else: + raise NotImplementedError(f"unhandled {inst.opname}") + if inst.exn_tab_entry: + walk(may, indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + walk(may, indexof[inst.target]) + state = may + if inst.opcode in TERMINAL_OPCODES: + return + + walk(must, indexof[instruction]) + return must.reads | may.reads + + +@dataclasses.dataclass +class FixedPointBox: + value: bool = True + + +@dataclasses.dataclass +class StackSize: + low: Union[int, float] + high: Union[int, float] + fixed_point: FixedPointBox + + def zero(self): + self.low = 0 + self.high = 0 + self.fixed_point.value = False + + def offset_of(self, other, n): + prior = (self.low, self.high) + self.low = min(self.low, other.low + n) + self.high = max(self.high, other.high + n) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + def exn_tab_jump(self, depth): + prior = (self.low, self.high) + self.low = 
min(self.low, depth) + self.high = max(self.high, depth) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + +def stacksize_analysis(instructions) -> Union[int, float]: + assert instructions + fixed_point = FixedPointBox() + stack_sizes = { + inst: StackSize(float("inf"), float("-inf"), fixed_point) + for inst in instructions + } + stack_sizes[instructions[0]].zero() + + for _ in range(100): + if fixed_point.value: + break + fixed_point.value = True + + for inst, next_inst in zip(instructions, instructions[1:] + [None]): + stack_size = stack_sizes[inst] + # CALL_FINALLY in Python 3.8 is handled differently when determining stack depth. + # See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450. + # Essentially, the stack effect of CALL_FINALLY is computed with jump=True, + # but the resulting stack depth is propagated to the next instruction, not the + # jump target. + is_call_finally = ( + sys.version_info < (3, 9) and inst.opcode == dis.opmap["CALL_FINALLY"] + ) + if inst.opcode not in TERMINAL_OPCODES: + assert next_inst is not None, f"missing next inst: {inst}" + stack_sizes[next_inst].offset_of( + stack_size, + stack_effect(inst.opcode, inst.arg, jump=is_call_finally), + ) + if inst.opcode in JUMP_OPCODES and not is_call_finally: + stack_sizes[inst.target].offset_of( + stack_size, stack_effect(inst.opcode, inst.arg, jump=True) + ) + if inst.exn_tab_entry: + # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + # on why depth is computed this way. + depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 + stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) + + if False: + for inst in instructions: + stack_size = stack_sizes[inst] + print(stack_size.low, stack_size.high, inst) + + low = min([x.low for x in stack_sizes.values()]) + high = max([x.high for x in stack_sizes.values()]) + + assert fixed_point.value, "failed to reach fixed point" + assert low >= 0 + return high diff --git a/MLPY/Lib/site-packages/torch/_dynamo/bytecode_transformation.py b/MLPY/Lib/site-packages/torch/_dynamo/bytecode_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..767c11e2ab632c8885b3bcec5b8a84fdcbf749db --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/bytecode_transformation.py @@ -0,0 +1,1114 @@ +import copy +import dataclasses +import dis +import itertools +import sys +import types +from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple + +from .bytecode_analysis import ( + get_indexof, + propagate_line_nums, + remove_extra_line_nums, + stacksize_analysis, +) + + +@dataclasses.dataclass +class InstructionExnTabEntry: + start: "Instruction" + end: "Instruction" + target: "Instruction" + depth: int + lasti: bool + + def __repr__(self) -> str: + return ( + f"InstructionExnTabEntry(start={self.start.short_inst_repr()}, " + f"end={self.end.short_inst_repr()}, " + f"target={self.target.short_inst_repr()}, " + f"depth={self.depth}, lasti={self.lasti})" + ) + + def __eq__(self, o) -> bool: + return ( + self.start is o.start + and self.end is o.end + and self.target is o.target + and self.depth == o.depth + and self.lasti == o.lasti + ) + + +@dataclasses.dataclass +class Instruction: + """A mutable version of dis.Instruction""" + + opcode: int + opname: str + arg: Optional[int] + argval: Any + offset: Optional[int] = None + starts_line: Optional[int] = None + is_jump_target: bool = False + positions: Optional["dis.Positions"] = None + # extra fields to 
make modification easier: + target: Optional["Instruction"] = None + exn_tab_entry: Optional[InstructionExnTabEntry] = None + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other) -> bool: + return id(self) == id(other) + + def short_inst_repr(self) -> str: + return f"Instruction(opname={self.opname}, offset={self.offset})" + + +def convert_instruction(i: dis.Instruction) -> Instruction: + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + i.starts_line, + i.is_jump_target, + getattr(i, "positions", None), + ) + + +class _NotProvided: + def __repr__(self) -> str: + return "_NotProvided" + + +def create_instruction( + name, *, arg=None, argval=_NotProvided, target=None +) -> Instruction: + """ + At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. + This is to prevent ambiguity, e.g. does + create_instruction("LOAD_CONST", 5) + mean load the constant at co_consts[5], or load the constant 5? + + If `arg` is not provided, it will be computed during assembly from + `argval` or `target`. + + Do not use for LOAD_GLOBAL - use create_load_global instead. + """ + assert name != "LOAD_GLOBAL" + cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None) + if cnt > 1: + raise RuntimeError( + "only one of arg, argval, and target can be not None/_NotProvided" + ) + if arg is not None and not isinstance(arg, int): + raise RuntimeError("instruction arg must be int or None") + return Instruction( + opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target + ) + + +# Python 3.11 remaps +def create_jump_absolute(target) -> Instruction: + inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" + return create_instruction(inst, target=target) + + +def create_load_global(name, push_null) -> Instruction: + """ + `name` is the name of the global to be loaded. + `push_null` specifies whether or not a NULL should be pushed to the stack + before the global (Python 3.11+ only). + + Python 3.11 changed the LOAD_GLOBAL instruction in that the first bit of + the instruction arg specifies whether a NULL should be pushed to the stack + before the global. The remaining bits of the instruction arg contain the + name index. See `create_call_function` for why this NULL is needed. + + The instruction's `arg` is actually computed when assembling the bytecode. + For Python 3.11, push_null information is propagated through the arg. + + NOTE: we don't use create_instruction since LOAD_GLOBAL is the only instruction + where both arg and argval need to be specified. + """ + return Instruction( + opcode=dis.opmap["LOAD_GLOBAL"], + opname="LOAD_GLOBAL", + arg=push_null, + argval=name, + ) + + +def create_dup_top() -> Instruction: + if sys.version_info >= (3, 11): + return create_instruction("COPY", arg=1) + return create_instruction("DUP_TOP") + + +def create_rot_n(n) -> List[Instruction]: + """ + Returns a "simple" sequence of instructions that rotates TOS to the n-th + position in the stack. For Python < 3.11, returns a single ROT_* + instruction. If no such instruction exists, an error is raised and the + caller is expected to generate an equivalent sequence of instructions. + For Python >= 3.11, any rotation can be expressed as a simple sequence of + swaps. + """ + if n <= 1: + # don't rotate + return [] + + if sys.version_info >= (3, 11): + # rotate can be expressed as a sequence of swap operations + # e.g. 
rotate 3 is equivalent to swap 3, swap 2 + return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] + + # ensure desired rotate function exists + if sys.version_info < (3, 8) and n >= 4: + raise AttributeError(f"rotate {n} not supported for Python < 3.8") + if sys.version_info < (3, 10) and n >= 5: + raise AttributeError(f"rotate {n} not supported for Python < 3.10") + + if n <= 4: + return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] + return [create_instruction("ROT_N", arg=n)] + + +def create_call_function(nargs, push_null) -> List[Instruction]: + """ + Creates a sequence of instructions that makes a function call. + + `push_null` is used in Python 3.11+ only. It is used in codegen when + a function call is intended to be made with the NULL + fn convention, + and we know that the NULL has not been pushed yet. We will push a + NULL and rotate it to the correct position immediately before making + the function call. + push_null should default to True unless you know you are calling a function + that you codegen'd with a null already pushed, for example + (assume `math` is available in the global scope), + + create_load_global("math", True) # pushes a null + create_instruction("LOAD_ATTR", argval="sqrt") + create_instruction("LOAD_CONST", argval=25) + create_call_function(1, False) + """ + if sys.version_info >= (3, 11): + output = [] + if push_null: + output.append(create_instruction("PUSH_NULL")) + output.extend(create_rot_n(nargs + 2)) + output.append(create_instruction("PRECALL", arg=nargs)) + output.append(create_instruction("CALL", arg=nargs)) + return output + return [create_instruction("CALL_FUNCTION", arg=nargs)] + + +def create_call_method(nargs) -> List[Instruction]: + if sys.version_info >= (3, 11): + return [ + create_instruction("PRECALL", arg=nargs), + create_instruction("CALL", arg=nargs), + ] + return [create_instruction("CALL_METHOD", arg=nargs)] + + +def lnotab_writer( + lineno: int, byteno: int = 0 +) -> Tuple[List[int], Callable[[int, int], None]]: + """ + Used to create typing.CodeType.co_lnotab + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table if Python < 3.10 + """ + assert sys.version_info < (3, 10) + lnotab: List[int] = [] + + def update(lineno_new, byteno_new): + nonlocal byteno, lineno + while byteno_new != byteno or lineno_new != lineno: + byte_offset = max(0, min(byteno_new - byteno, 255)) + line_offset = max(-128, min(lineno_new - lineno, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno += byte_offset + lineno += line_offset + lnotab.extend((byte_offset, line_offset & 0xFF)) + + return lnotab, update + + +def linetable_310_writer(first_lineno): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table for Python 3.10 + """ + assert sys.version_info >= (3, 10) and sys.version_info < (3, 11) + linetable: List[int] = [] + lineno = first_lineno + lineno_delta = 0 + byteno = 0 + + def _update(byteno_delta, lineno_delta): + while byteno_delta != 0 or lineno_delta != 0: + byte_offset = max(0, min(byteno_delta, 254)) + line_offset = max(-127, min(lineno_delta, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno_delta -= byte_offset + lineno_delta -= line_offset + linetable.extend((byte_offset, line_offset & 0xFF)) + + def update(lineno_new, byteno_new): + nonlocal lineno, lineno_delta, byteno + byteno_delta = 
byteno_new - byteno + byteno = byteno_new + _update(byteno_delta, lineno_delta) + lineno_delta = lineno_new - lineno + lineno = lineno_new + + def end(total_bytes): + _update(total_bytes - byteno, lineno_delta) + + return linetable, update, end + + +def encode_varint(n: int) -> List[int]: + """ + 6-bit chunk encoding of an unsigned integer + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b[-1] |= 64 + b.append(n & 63) + n >>= 6 + return b + + +def linetable_311_writer(first_lineno: int): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + This is the internal format of the line number table for Python 3.11 + """ + assert sys.version_info >= (3, 11) + linetable = [] + lineno = first_lineno + + def update(positions: "dis.Positions", inst_size): + nonlocal lineno + lineno_new = positions.lineno if positions else None + + def _update(delta, size): + assert 0 < size <= 8 + # first byte - use 13 (no column info) is positions is + # malformed, otherwise use 14 (long form) + other_varints: Tuple[int, ...] = () + if ( + positions + and positions.lineno is not None + and positions.end_lineno is not None + and positions.col_offset is not None + and positions.end_col_offset is not None + ): + linetable.append(0b1_1110_000 + size - 1) + # for whatever reason, column offset needs `+ 1` + # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603 + other_varints = ( + positions.end_lineno - positions.lineno, + positions.col_offset + 1, + positions.end_col_offset + 1, + ) + else: + linetable.append(0b1_1101_000 + size - 1) + # encode signed int + if delta < 0: + delta = ((-delta) << 1) | 1 + else: + delta <<= 1 + # encode unsigned int + linetable.extend(encode_varint(delta)) + for n in other_varints: + linetable.extend(encode_varint(n)) + + if lineno_new is None: + lineno_delta = 0 + else: + lineno_delta = lineno_new - lineno + lineno = lineno_new + while inst_size > 8: + _update(lineno_delta, 8) + inst_size -= 8 + _update(lineno_delta, inst_size) + + return linetable, update + + +@dataclasses.dataclass +class ExceptionTableEntry: + start: int + end: int + target: int + depth: int + lasti: bool + + +def encode_exception_table_varint(n: int) -> List[int]: + """ + Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse. + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b.append(n & 63) + n >>= 6 + b.reverse() + for i in range(len(b) - 1): + b[i] |= 64 + return b + + +def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int: + """ + Inverse of `encode_exception_table_varint`. + """ + b = next(bytes_iter) + val = b & 63 + while b & 64: + val <<= 6 + b = next(bytes_iter) + val |= b & 63 + return val + + +def check_exception_table(tab: List[ExceptionTableEntry]) -> None: + """ + Verifies that a list of ExceptionTableEntries will make a well-formed + jump table: entries are non-empty, sorted, and do not overlap. 
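+
+    Illustrative example: entries covering byte ranges (start=0, end=10) and
+    (start=12, end=20) are accepted, while (0, 10) followed by (8, 20)
+    overlap and fail the assertion below.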
+ """ + for i in range(len(tab) - 1): + assert ( + tab[i].start <= tab[i].end + and tab[i].end < tab[i + 1].start + and tab[i + 1].start <= tab[i + 1].end + ) + + +def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]: + """ + Parse the exception table according to + https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + exntab_iter = iter(exntab) + tab = [] + try: + while True: + start = decode_exception_table_varint(exntab_iter) * 2 + length = decode_exception_table_varint(exntab_iter) * 2 + end = start + length - 2 + target = decode_exception_table_varint(exntab_iter) * 2 + dl = decode_exception_table_varint(exntab_iter) + depth = dl >> 1 + lasti = bool(dl & 1) + tab.append(ExceptionTableEntry(start, end, target, depth, lasti)) + except StopIteration: + check_exception_table(tab) + return tab + + +def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes: + """ + Inverse of parse_exception_table - encodes list of exception + table entries into bytes. + """ + b = [] + for entry in tab: + first_entry = encode_exception_table_varint(entry.start // 2) + first_entry[0] |= 1 << 7 + b.extend(first_entry) + length = entry.end - entry.start + 2 + b.extend(encode_exception_table_varint(length // 2)) + b.extend(encode_exception_table_varint(entry.target // 2)) + dl = (entry.depth << 1) + entry.lasti + b.extend(encode_exception_table_varint(dl)) + return bytes(b) + + +def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]: + """Do the opposite of dis.get_instructions()""" + code: List[int] = [] + if sys.version_info >= (3, 11): + lnotab, update_lineno = linetable_311_writer(firstlineno) + num_ext = 0 + for i, inst in enumerate(instructions): + if inst.opname == "EXTENDED_ARG": + inst_size = 1 + num_ext += 1 + # copy positions from the actual instruction + for j in (1, 2, 3): + if instructions[i + j].opname != "EXTENDED_ARG": + inst.positions = instructions[i + j].positions + break + else: + inst_size = instruction_size(inst) // 2 + num_ext + num_ext = 0 + update_lineno(inst.positions, inst_size) + num_ext = 0 + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + for _ in range(instruction_size(inst) // 2 - 1): + code.extend((0, 0)) + else: + if sys.version_info < (3, 10): + lnotab, update_lineno = lnotab_writer(firstlineno) + else: + lnotab, update_lineno, end = linetable_310_writer(firstlineno) + + for inst in instructions: + if inst.starts_line is not None: + update_lineno(inst.starts_line, len(code)) + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + + if sys.version_info >= (3, 10): + end(len(code)) + + return bytes(code), bytes(lnotab) + + +def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int): + """ + Get the instruction located at a given offset, accounting for EXTENDED_ARGs + """ + for n in (0, 2, 4, 6): + if offset_to_inst[offset + n].opcode != dis.EXTENDED_ARG: + return offset_to_inst[offset + n] + return None + + +def virtualize_jumps(instructions) -> None: + """Replace jump targets with pointers to make editing easier""" + jump_targets = {inst.offset: inst for inst in instructions} + + for inst in instructions: + if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: + inst.target = _get_instruction_by_offset(jump_targets, inst.argval) + + +_REL_JUMPS = set(dis.hasjrel) + + +def flip_jump_direction(instruction: Instruction) -> None: + if sys.version_info < (3, 11): + raise RuntimeError("Cannot flip jump direction in Python < 3.11") + 
if "FORWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("FORWARD", "BACKWARD") + elif "BACKWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("BACKWARD", "FORWARD") + else: + raise AttributeError("Instruction is not a forward or backward jump") + instruction.opcode = dis.opmap[instruction.opname] + assert instruction.opcode in _REL_JUMPS + + +def _get_instruction_front(instructions: List[Instruction], idx: int): + """ + i.e. get the first EXTENDED_ARG instruction (if any) when targeting + instructions[idx] with a jump. + """ + target = instructions[idx] + for offset in (1, 2, 3): + if idx >= offset and instructions[idx - offset].opcode == dis.EXTENDED_ARG: + target = instructions[idx - offset] + else: + break + return target + + +def devirtualize_jumps(instructions): + """Fill in args for virtualized jump target after instructions may have moved""" + indexof = get_indexof(instructions) + jumps = set(dis.hasjabs).union(set(dis.hasjrel)) + + for inst in instructions: + if inst.opcode in jumps: + target = _get_instruction_front(instructions, indexof[inst.target]) + if inst.opcode in dis.hasjabs: + if sys.version_info < (3, 10): + inst.arg = target.offset + elif sys.version_info < (3, 11): + # `arg` is expected to be bytecode offset, whereas `offset` is byte offset. + # Divide since bytecode is 2 bytes large. + inst.arg = int(target.offset / 2) + else: + raise RuntimeError("Python 3.11+ should not have absolute jumps") + else: # relative jump + # byte offset between target and next instruction + inst.arg = int(target.offset - inst.offset - instruction_size(inst)) + if inst.arg < 0: + if sys.version_info < (3, 11): + raise RuntimeError("Got negative jump offset for Python < 3.11") + inst.arg = -inst.arg + # forward jumps become backward + if "FORWARD" in inst.opname: + flip_jump_direction(inst) + elif inst.arg > 0: + # backward jumps become forward + if sys.version_info >= (3, 11) and "BACKWARD" in inst.opname: + flip_jump_direction(inst) + if sys.version_info >= (3, 10): + # see bytecode size comment in the absolute jump case above + inst.arg //= 2 + inst.argval = target.offset + inst.argrepr = f"to {target.offset}" + + +def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]): + """Replace exception table entries with pointers to make editing easier""" + exn_tab = parse_exception_table(exn_tab_bytes) + offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} + offsets = sorted(offset_to_inst.keys()) + end_offset_idx = 0 + exn_tab_iter = iter(exn_tab) + try: + + def step(): + nonlocal end_offset_idx + entry = next(exn_tab_iter) + # find rightmost offset <= entry.end, since entry.end may not be + # an actual instruction, e.g. if the end instruction is LOAD_GLOBAL, + # which takes more than 2 bytes, then entry.end points to the end + # of the LOAD_GLOBAL instruction, not the beginning. 
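+            # Illustrative example, assuming the 3.11 inline-cache sizes recorded
+            # in _PYOPCODE_CACHES below: a LOAD_GLOBAL starting at offset 10 spans
+            # 12 bytes, so entry.end would be 20, and the loop below maps 20 back
+            # to the recorded start offset 10.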
+ while ( + end_offset_idx < len(offsets) and offsets[end_offset_idx] <= entry.end + ): + end_offset_idx += 1 + assert end_offset_idx > 0 + end_offset = offsets[end_offset_idx - 1] + inst_entry = InstructionExnTabEntry( + _get_instruction_by_offset(offset_to_inst, entry.start), + _get_instruction_by_offset(offset_to_inst, end_offset), + _get_instruction_by_offset(offset_to_inst, entry.target), + entry.depth, + entry.lasti, + ) + return entry, inst_entry + + entry, inst_entry = step() + for inst in instructions: + while inst.offset > entry.end: + entry, inst_entry = step() + if inst.offset >= entry.start: + inst.exn_tab_entry = copy.copy(inst_entry) + except StopIteration: + pass + + +def compute_exception_table( + instructions: List[Instruction], +) -> List[ExceptionTableEntry]: + """Compute exception table in list format from instructions with exn_tab_entries""" + exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {} + indexof = get_indexof(instructions) + + for inst in instructions: + if inst.exn_tab_entry: + # account for prefixed EXTENDED_ARGS + start = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.start] + ).offset + # point to the last 2 bytes of the end instruction + end = ( + cast(int, inst.exn_tab_entry.end.offset) + + instruction_size(inst.exn_tab_entry.end) + - 2 + ) + target = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.target] + ).offset + key = (start, end) + val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) + if key in exn_dict: + assert exn_dict[key] == val + exn_dict[key] = val + + # Dynamo may construct nested exception table entries for convenience, + # but Python expects exception table entries to not overlap. + # NOTE: below, "keys" refer to old instruction entries' starts and ends, + # and "entries" refer to the generated exception table entries. + + # Sort keys by increasing start, then decreasing end + keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1])) + # smallest byte that the next exception table entry can start at + nexti = 0 + # stack of current nested keys + key_stack: List[Tuple[int, int]] = [] + exn_tab: List[ExceptionTableEntry] = [] + + def pop(): + """ + Pop the key_stack and append an exception table entry if possible. + """ + nonlocal nexti + if key_stack: + key = key_stack.pop() + if nexti <= key[1]: + exn_tab.append( + ExceptionTableEntry(max(key[0], nexti), key[1], *exn_dict[key]) + ) + nexti = key[1] + 2 + + for key in keys_sorted: + # pop keys that are no longer nested over the current key + while key_stack and key_stack[-1][1] < key[0]: + pop() + if key_stack: + # create an entry covering to the current key, if possible + assert key_stack[-1][0] <= key[0] <= key[1] <= key_stack[-1][1] + left = max(nexti, key_stack[-1][0]) + if left < key[0]: + exn_tab.append( + ExceptionTableEntry(left, key[0] - 2, *exn_dict[key_stack[-1]]) + ) + nexti = key[0] + key_stack.append(key) + while key_stack: + pop() + check_exception_table(exn_tab) + return exn_tab + + +def check_inst_exn_tab_entries_nested( + tab: List[InstructionExnTabEntry], indexof +) -> None: + """ + Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, + i.e. no entries partially overlap. + "Properly sorted" means entries are sorted by increasing starts, then + decreasing ends. 
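+
+    Illustrative example, using (start_index, end_index) pairs: the sequence
+    (0, 10), (0, 6), (2, 4) is properly nested, whereas (0, 6), (4, 10)
+    partially overlap and would trip the assertion.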
+ """ + entry_stack: List[Tuple[int, int]] = [] + for entry in tab: + key = (indexof[entry.start], indexof[entry.end]) + while entry_stack and entry_stack[-1][1] < key[0]: + entry_stack.pop() + if entry_stack: + assert entry_stack[-1][0] <= key[0] <= key[1] <= entry_stack[-1][1] + entry_stack.append(key) + + +def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None: + """ + Copies exception table entries to all instructions in an entry's range. + Supports nested exception table entries. + """ + indexof = get_indexof(instructions) + entries: Dict[Tuple[int, int], InstructionExnTabEntry] = {} + for inst in instructions: + if inst.exn_tab_entry: + key = ( + indexof[inst.exn_tab_entry.start], + indexof[inst.exn_tab_entry.end], + ) + if key in entries: + assert inst.exn_tab_entry == entries[key] + entries[key] = inst.exn_tab_entry + sorted_entries = [ + entries[key] for key in sorted(entries.keys(), key=lambda t: (t[0], -t[1])) + ] + check_inst_exn_tab_entries_nested(sorted_entries, indexof) + # Propagation of nested entries works since nested entries come later + # in sorted order. + for entry in sorted_entries: + for i in range(indexof[entry.start], indexof[entry.end] + 1): + instructions[i].exn_tab_entry = copy.copy(entry) + + +def check_inst_exn_tab_entries_valid(instructions: List[Instruction]): + """ + Checks that exn_tab_entries of instructions are valid. + An entry's start, end, and target must be in instructions. + Instructions with an exn_tab_entry are located within + the entry's start and end instructions. + Instructions do not share exn_tab_entries. + + Implicitly checks for no duplicate instructions. + """ + indexof = get_indexof(instructions) + exn_tab_entry_set = set() + for i, inst in enumerate(instructions): + if inst.exn_tab_entry: + assert sys.version_info >= (3, 11) + assert id(inst.exn_tab_entry) not in exn_tab_entry_set + exn_tab_entry_set.add(id(inst.exn_tab_entry)) + entry = inst.exn_tab_entry + assert entry.start in indexof + assert entry.end in indexof + assert entry.target in indexof + assert indexof[entry.start] <= i <= indexof[entry.end] + + +def strip_extended_args(instructions: List[Instruction]) -> None: + instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG] + + +def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]: + """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it""" + rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"} + for inst in instructions: + if inst.opname in rewrites: + inst.opname = rewrites[inst.opname] + inst.opcode = dis.opmap[inst.opname] + return instructions + + +def remove_jump_if_none(instructions: List[Instruction]) -> None: + new_insts = [] + for inst in instructions: + new_insts.append(inst) + if "_NONE" in inst.opname: + is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname)) + is_op.argval = is_op.arg + jump_op = create_instruction( + "POP_JUMP_FORWARD_IF_TRUE" + if "FORWARD" in inst.opname + else "POP_JUMP_BACKWARD_IF_TRUE", + target=inst.target, + ) + # modify inst in-place to preserve jump target + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + inst.arg = None + inst.argval = None + new_insts.extend([is_op, jump_op]) + instructions[:] = new_insts + + +def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None: + """convert super() with no args into explicit arg form""" + cell_and_free = (code.co_cellvars or tuple()) + (code.co_freevars or tuple()) + if not 
len(code.co_varnames): + # A function with no argument cannot contain a valid "super()" call + return + output = [] + for idx, inst in enumerate(instructions): + output.append(inst) + if inst.opname == "LOAD_GLOBAL" and inst.argval == "super": + nexti = instructions[idx + 1] + if nexti.opname in ("CALL_FUNCTION", "PRECALL") and nexti.arg == 0: + assert "__class__" in cell_and_free + output.append(create_instruction("LOAD_DEREF", argval="__class__")) + first_var = code.co_varnames[0] + if first_var in cell_and_free: + output.append(create_instruction("LOAD_DEREF", argval=first_var)) + else: + output.append(create_instruction("LOAD_FAST", argval=first_var)) + nexti.arg = 2 + nexti.argval = 2 + if nexti.opname == "PRECALL": + # also update the following CALL instruction + call_inst = instructions[idx + 2] + call_inst.arg = 2 + call_inst.argval = 2 + + instructions[:] = output + + +def fix_extended_args(instructions: List[Instruction]) -> int: + """Fill in correct argvals for EXTENDED_ARG ops""" + output: List[Instruction] = [] + + def maybe_pop_n(n): + for _ in range(n): + if output and output[-1].opcode == dis.EXTENDED_ARG: + output.pop() + + for inst in instructions: + if inst.opcode == dis.EXTENDED_ARG: + # Leave this instruction alone for now so we never shrink code + inst.arg = 0 + elif inst.arg and inst.arg > 0xFFFFFF: + maybe_pop_n(3) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 24)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFFFF: + maybe_pop_n(2) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFF: + maybe_pop_n(1) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + output.append(inst) + + added = len(output) - len(instructions) + assert added >= 0 + instructions[:] = output + return added + + +# from https://github.com/python/cpython/blob/v3.11.1/Include/internal/pycore_opcode.h#L41 +# TODO use the actual object instead, can interface from eval_frame.c +_PYOPCODE_CACHES = { + "BINARY_SUBSCR": 4, + "STORE_SUBSCR": 1, + "UNPACK_SEQUENCE": 1, + "STORE_ATTR": 4, + "LOAD_ATTR": 4, + "COMPARE_OP": 2, + "LOAD_GLOBAL": 5, + "BINARY_OP": 1, + "LOAD_METHOD": 10, + "PRECALL": 1, + "CALL": 4, +} + + +def instruction_size(inst) -> int: + if sys.version_info >= (3, 11): + return 2 * (_PYOPCODE_CACHES.get(dis.opname[inst.opcode], 0) + 1) + return 2 + + +def check_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + assert inst.offset == offset + offset += instruction_size(inst) + + +def update_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + inst.offset = offset + offset += instruction_size(inst) + + +def debug_bytes(*args) -> str: + index = range(max(map(len, args))) + result = [] + for arg in ( + [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]] + ): + result.append(" ".join(f"{x:03}" for x in arg)) + + return "bytes mismatch\n" + "\n".join(result) + + +def debug_checks(code): + """Make sure our assembler produces same bytes as we start with""" + dode = transform_code_object(code, lambda x, y: None, safe=True) + assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) + assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab) + + +HAS_LOCAL = set(dis.haslocal) +HAS_NAME = 
set(dis.hasname) +HAS_FREE = set(dis.hasfree) +HAS_CONST = set(dis.hasconst) + + +def get_const_index(code_options, val) -> int: + for i, v in enumerate(code_options["co_consts"]): + # NOTE: stronger comparison is required, since we have + # examples where two values compare equal but have + # different semantic meaning in some cases, e.g. + # 0.0 == -0.0 but have different effects in torch.copysign. + if val is v: + return i + code_options["co_consts"] += (val,) + return len(code_options["co_consts"]) - 1 + + +def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None): + # compute instruction arg from argval if arg is not provided + names = {name: idx for idx, name in enumerate(code_options["co_names"])} + if sys.version_info < (3, 11): + assert varname_from_oparg is None + varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])} + freenames = { + name: idx + for idx, name in enumerate( + code_options["co_cellvars"] + code_options["co_freevars"] + ) + } + else: + assert callable(varname_from_oparg) + allnames = {} + for idx in itertools.count(): + try: + name = varname_from_oparg(idx) + allnames[name] = idx + except IndexError: + break + varnames = {name: allnames[name] for name in code_options["co_varnames"]} + freenames = { + name: allnames[name] + for name in code_options["co_cellvars"] + code_options["co_freevars"] + } + for i in range(len(instructions)): + + def should_compute_arg(): + # argval is prioritized over arg + return instructions[i].argval is not _NotProvided + + if instructions[i].opname == "LOAD_GLOBAL": + # 3.11 LOAD_GLOBAL requires both arg and argval - see create_load_global + assert instructions[i].arg is not None + assert instructions[i].argval is not _NotProvided + if sys.version_info >= (3, 11): + instructions[i].arg = (names[instructions[i].argval] << 1) + ( + cast(int, instructions[i].arg) % 2 + ) + else: + instructions[i].arg = names[instructions[i].argval] + elif instructions[i].opcode in HAS_LOCAL: + if should_compute_arg(): + instructions[i].arg = varnames[instructions[i].argval] + elif instructions[i].opcode in HAS_NAME: + if should_compute_arg(): + instructions[i].arg = names[instructions[i].argval] + elif instructions[i].opcode in HAS_FREE: + if should_compute_arg(): + instructions[i].arg = freenames[instructions[i].argval] + elif instructions[i].opcode in HAS_CONST: + # NOTE: only update argval if arg is not provided. This assumes + # that any additions to co_consts are appended. + if instructions[i].arg is None: + # cannot use a dictionary since consts may not be hashable + idx = get_const_index(code_options, instructions[i].argval) + assert idx >= 0 + instructions[i].arg = idx + + +def get_code_keys() -> List[str]: + # Python 3.11 changes to code keys are not fully documented. + # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 + # for new format. 
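+    # NOTE: the order of these keys must match the positional-argument order of
+    # the types.CodeType constructor, since clean_and_assemble_instructions
+    # rebuilds the code object via types.CodeType(*[code_options[k] for k in keys]).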
+ keys = ["co_argcount"] + keys.append("co_posonlyargcount") + keys.extend( + [ + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + ] + ) + if sys.version_info >= (3, 11): + keys.append("co_qualname") + keys.append("co_firstlineno") + if sys.version_info >= (3, 10): + keys.append("co_linetable") + else: + keys.append("co_lnotab") + if sys.version_info >= (3, 11): + # not documented, but introduced in https://github.com/python/cpython/issues/84403 + keys.append("co_exceptiontable") + keys.extend( + [ + "co_freevars", + "co_cellvars", + ] + ) + return keys + + +def transform_code_object(code, transformations, safe=False) -> types.CodeType: + keys = get_code_keys() + code_options = {k: getattr(code, k) for k in keys} + assert len(code_options["co_varnames"]) == code_options["co_nlocals"] + + instructions = cleaned_instructions(code, safe) + propagate_line_nums(instructions) + + transformations(instructions, code_options) + return clean_and_assemble_instructions(instructions, keys, code_options)[1] + + +def clean_and_assemble_instructions( + instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any] +) -> Tuple[List[Instruction], types.CodeType]: + # also implicitly checks for no duplicate instructions + check_inst_exn_tab_entries_valid(instructions) + + code_options["co_nlocals"] = len(code_options["co_varnames"]) + varname_from_oparg = None + if sys.version_info >= (3, 11): + # temporary code object with updated names + tmp_code = types.CodeType(*[code_options[k] for k in keys]) + varname_from_oparg = tmp_code._varname_from_oparg # type: ignore[attr-defined] + fix_vars(instructions, code_options, varname_from_oparg=varname_from_oparg) + + dirty = True + while dirty: + update_offsets(instructions) + devirtualize_jumps(instructions) + # this pass might change offsets, if so we need to try again + dirty = bool(fix_extended_args(instructions)) + + remove_extra_line_nums(instructions) + bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) + if sys.version_info < (3, 10): + code_options["co_lnotab"] = lnotab + else: + code_options["co_linetable"] = lnotab + + code_options["co_code"] = bytecode + code_options["co_stacksize"] = stacksize_analysis(instructions) + assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { + "co_posonlyargcount" + } + if sys.version_info >= (3, 11): + code_options["co_exceptiontable"] = assemble_exception_table( + compute_exception_table(instructions) + ) + return instructions, types.CodeType(*[code_options[k] for k in keys]) + + +def populate_kw_names_argval(instructions, consts): + for inst in instructions: + if inst.opname == "KW_NAMES": + inst.argval = consts[inst.arg] + + +def cleaned_instructions(code, safe=False) -> List[Instruction]: + instructions = list(map(convert_instruction, dis.get_instructions(code))) + check_offsets(instructions) + if sys.version_info >= (3, 11): + populate_kw_names_argval(instructions, code.co_consts) + virtualize_exception_table(code.co_exceptiontable, instructions) + virtualize_jumps(instructions) + strip_extended_args(instructions) + if not safe: + if sys.version_info < (3, 11): + remove_load_call_method(instructions) + else: + remove_jump_if_none(instructions) + update_offsets(instructions) + devirtualize_jumps(instructions) + explicit_super(code, instructions) + return instructions + + +_unique_id_counter = itertools.count() + + +def unique_id(name) -> str: + return 
f"{name}_{next(_unique_id_counter)}" + + +def is_generator(code: types.CodeType) -> bool: + co_generator = 0x20 + return (code.co_flags & co_generator) > 0 diff --git a/MLPY/Lib/site-packages/torch/_dynamo/cache_size.py b/MLPY/Lib/site-packages/torch/_dynamo/cache_size.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbbd5e26541eae62db34a3e498f6191c00f4b99 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/cache_size.py @@ -0,0 +1,172 @@ +import logging +import types +import weakref +from dataclasses import dataclass +from typing import Tuple + +from . import config + +log = logging.getLogger(__name__) +""" +[Note on cache size limit] + +Background - TorchDynamo cache is a linked list. Each cache entry is a +(check_fn, out_code, next pointer). These are stored on the f_code's co_extra +scratch space. When a frame is invoked, we walk this linked list and run +check_fn in each cache_entry to decide if the frame needs recompilation. If none +of the check_fn's returns True, we recompile and add a new entry. To ensure we +don't end up recompiling infinitely, we put limits on the cache size. + +There are two limits +1) cache_size_limit +2) accumulated_cache_size_limit + + +Earlier we used to have only limit - maximum number of entries in 1 cache line +(which is now represented by (2) above). So, why do we need two limits? Lets try +to understand that. + +In general, we want our cache limit value to be a small number (e.g. 8 or even +lower). This ensures that for frames that cause too many recompilation fall to +eager quickly. However, there is another problem that prevents us from lowering +the value of cache_size_limit. This is due to ID_MATCH'd guards. Today, we put +ID_MATCH guards on nn module if there is a graph break. This means we will have +many recompilations for the same code object because the ID_MATCH guard fails +for different instances of the nn module. This is a common pattern in how models +are authored. Therefore, this requires us to keep the cache_size_limit high. + +We resolve this by introducing these two limits. The first limit (1) limits the +number of cache entries that have an ID_MATCH'd guard for an nn module instance. +And, (2)nd limit becomes a safeguard mechanism to have a maximum compilations +for a code object. One important question is - what is the limit for the code +object that does not have any ID_MATCH guard? For such code objects, we choose +(1) as the cache size limit. + +Lets take an example to understand how these limits help. Suppose, we have 16 +instances of a nn module and we ID_MATCH on the self object. Further, suppose +the inputs to these functions have varying batch size, leading to one +recompilation. In total, there will be 32 recompilations, and therefore 32 cache +entries on the forward code object. In the older case when we had only 1 limit, +our cache size limit must be >= 32 to capture all these recompilations. Now, +suppose there is a separate function in the same program which is very dynamic +and unsuitable for compilation. Such a function will need to undergo 32 +compilations to burst the cache and fallback to eager. These 32 recompilations +are too many and we want to fallback for these compilation-unfriendly functions +sooner. + +In the new scenario, we can have (1) cache_size_limit = 2, (2) +accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can +have maximum of two cache entries, and the maximum number of cache entries +(irrespective of ID_MATCH obj) is 32. 
This covers the case of forward code +object which has 32 recompilations. For the other function, the one unsuitable +for recompilation, our limit is 2. So, we will burst the cache in just 2 +recompilations. In this manner, these 2 limits help us resolve the tension +mentioned earlier. +""" + + +@dataclass +class CacheSizeRelevantForFrame: + """ + We track the number of cache entries that have same id_match objects as the + given frame. + + TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count - + https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this + could be useful for debugging as well. + """ + + # Total number of CacheEntry objects in the Dynamo linked list + num_cache_entries: int = 0 + + # Number of CacheEntry objects having same ID_MATCH'd objects as given frame. + num_cache_entries_with_same_id_matched_objs: int = 0 + + def will_compilation_exceed(self, limit: int) -> bool: + # Checks if a compilation will exceed the given limit (thats why >=). + return ( + self.will_compilation_exceed_accumulated_limit() + or self.will_compilation_exceed_specific_limit(limit) + ) + + def will_compilation_exceed_accumulated_limit(self) -> bool: + return self.num_cache_entries >= config.accumulated_cache_size_limit + + def will_compilation_exceed_specific_limit(self, limit: int) -> bool: + return self.num_cache_entries_with_same_id_matched_objs >= limit + + +def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str): + obj = frame.f_locals.get(local_name, None) + weak_id = None + try: + weak_id = weakref.ref(obj) + except TypeError: + pass # cannot weakref bool object + return weak_id + + +def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: + """ + Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones + in frame.f_locals. + """ + if not cache_entry: + return False + + for ( + local_name, + weakref_from_cache_entry, + ) in cache_entry.check_fn.id_matched_objs.items(): + if weakref_from_cache_entry() is not None: + weakref_from_frame = _get_weakref_from_f_locals(frame, local_name) + if weakref_from_frame != weakref_from_cache_entry: + return False + + # Also covers the case where no ID_MATCH objects are saved in frame.f_locals + return True + + +def compute_cache_size( + frame: types.FrameType, cache_entry +) -> CacheSizeRelevantForFrame: + # Walk the linked list to calculate the cache size + num_cache_entries = 0 + num_cache_entries_with_same_id_matched_objs = 0 + + while cache_entry: + num_cache_entries += 1 + # Track the number of cache entries having same ID_MATCH'd objects as + # that of frame.f_locals. This will be used later to compare against the + # cache_size_limit. + if _has_same_id_matched_objs(frame, cache_entry): + num_cache_entries_with_same_id_matched_objs += 1 + cache_entry = cache_entry.next + + return CacheSizeRelevantForFrame( + num_cache_entries, num_cache_entries_with_same_id_matched_objs + ) + + +def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool: + """ + If the frame (earlier parsed by compute_cache_size) has more than 1 cache + entry with same ID_MATCH'd objects, then its a recompilation. + """ + # Note that you can have multiple entries in the cache but still not a + # recompile, e.g., you can have 64 nn module instances, each one having an + # ID_MATCH guard, and each one having just 1 cache entry in the cache. In + # this case, we can have 64 entries in the cache, but no recompilation + # because there is only one entry for each id_matched_obj. 
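+    # In other words: return True once at least one existing cache entry shares
+    # this frame's ID_MATCH'd objects (or the accumulated limit has already been
+    # hit), i.e. the upcoming compilation would be a repeat for the same objects.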
+ return cache_size.will_compilation_exceed(1) + + +def exceeds_cache_size_limit(cache_size: CacheSizeRelevantForFrame) -> Tuple[bool, str]: + """ + Checks if we are exceeding the cache size limit. + """ + if cache_size.will_compilation_exceed_accumulated_limit(): + return True, "accumulated_cache_size_limit" + if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit): + return True, "cache_size_limit" + return False, "" diff --git a/MLPY/Lib/site-packages/torch/_dynamo/callback.py b/MLPY/Lib/site-packages/torch/_dynamo/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6979645af78764a7359ca88bf7880fbefd5a54 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/callback.py @@ -0,0 +1,82 @@ +class CompilationCallbackHandler: + def __init__(self): + self.start_callbacks = [] + self.end_callbacks = [] + + def register_start_callback(self, callback): + """ + Register a callback function to be called when the compilation starts. + + Args: + - callback (callable): The callback function to register. + """ + self.start_callbacks.append(callback) + return callback + + def register_end_callback(self, callback): + """ + Register a callback function to be called when the compilation ends. + + Args: + - callback (callable): The callback function to register. + """ + self.end_callbacks.append(callback) + return callback + + def remove_start_callback(self, callback): + """ + Remove a registered start callback function. + + Args: + - callback (callable): The callback function to remove. + """ + self.start_callbacks.remove(callback) + + def remove_end_callback(self, callback): + """ + Remove a registered end callback function. + + Args: + - callback (callable): The callback function to remove. + """ + self.end_callbacks.remove(callback) + + def run_start_callbacks(self): + """ + Execute all registered start callbacks. + """ + for callback in self.start_callbacks: + callback() + + def run_end_callbacks(self): + """ + Execute all registered end callbacks. + """ + for callback in self.end_callbacks: + callback() + + def clear(self): + """ + Clear all registered callbacks. + """ + self.start_callbacks.clear() + self.end_callbacks.clear() + + +callback_handler = CompilationCallbackHandler() + + +def on_compile_start(callback): + """ + Decorator to register a callback function for the start of the compilation. + """ + callback_handler.register_start_callback(callback) + return callback + + +def on_compile_end(callback): + """ + Decorator to register a callback function for the end of the compilation. 
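+
+    Illustrative usage::
+
+        @on_compile_end
+        def my_end_callback():
+            print("Compilation ended")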
+ """ + callback_handler.register_end_callback(callback) + return callback diff --git a/MLPY/Lib/site-packages/torch/_dynamo/code_context.py b/MLPY/Lib/site-packages/torch/_dynamo/code_context.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5804336fe270f0a29f7a2e17efc80e0bfe0f7d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/code_context.py @@ -0,0 +1,29 @@ +import types + +from .utils import ExactWeakKeyDictionary + + +class CodeContextDict: + def __init__(self): + self.code_context = ExactWeakKeyDictionary() + + def has_context(self, code: types.CodeType): + return code in self.code_context + + def get_context(self, code: types.CodeType): + ctx = self.code_context.get(code) + if ctx is None: + ctx = {} + self.code_context[code] = ctx + return ctx + + def pop_context(self, code: types.CodeType): + ctx = self.get_context(code) + self.code_context._remove_id(id(code)) + return ctx + + def clear(self): + self.code_context.clear() + + +code_context = CodeContextDict() diff --git a/MLPY/Lib/site-packages/torch/_dynamo/codegen.py b/MLPY/Lib/site-packages/torch/_dynamo/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..767bb2c80a5e9b62bd4fb3ac26d87e4c995b958c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/codegen.py @@ -0,0 +1,398 @@ +import collections +import dataclasses +import re +import sys +import types +from typing import Counter, Dict, List, Optional + +import torch.nn +from . import utils + +from .bytecode_transformation import ( + create_call_function, + create_dup_top, + create_instruction, + create_load_global, + create_rot_n, + Instruction, +) +from .exc import unimplemented +from .source import AttrSource, Source +from .utils import is_safe_constant, rot_n_helper +from .variables.base import VariableTracker +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .variables.torch_function import TensorWithTFOverrideVariable + + +@dataclasses.dataclass +class GraphOutputEntry: + index: int + variable: VariableTracker + + +class PyCodegen: + """ + Helper class uses for constructing Python bytecode + """ + + def __init__( + self, + tx=None, + root: Optional[torch.nn.Module] = None, + graph_output_var: Optional[str] = None, + tempvars=None, + ): + self.root = root + self.top_of_stack: Optional[VariableTracker] = None + self.uses: Counter[VariableTracker] = collections.Counter() + self.graph_outputs: Dict[int, GraphOutputEntry] = {} + self._output: List[Instruction] = [] + self.tempvars = tempvars or {} + self.tx = tx + self.graph_output_var = graph_output_var + self.code_options = self.tx.output.code_options + self.cell_and_freevars = self.tx.cell_and_freevars + self.new_var = self.tx.output.new_var + self.mutable_side_effects_from_source = False + self.value_from_source: bool = True + + def restore_stack(self, stack_values, *, value_from_source=True): + prior = self.mutable_side_effects_from_source + self.mutable_side_effects_from_source = True + prev = self.value_from_source + self.value_from_source &= value_from_source + try: + self.foreach(stack_values) + finally: + self.mutable_side_effects_from_source = prior + self.value_from_source = prev + + def graph_output_vars(self): + return [x.variable for x in self.graph_outputs.values()] + + def call_reconstruct(self, value): + res = value.reconstruct(self) + assert res is None, f"reconstruct!=None {value}" + + def __call__(self, value, 
allow_cache=True): + """Generate code such that top-of-stack (TOS) is set to value""" + if isinstance(value, Source): + self.call_reconstruct(value) + self.clear_tos() + return + + assert isinstance(value, VariableTracker) + output = self._output + graph_outputs = self.graph_outputs + + if self.top_of_stack is value and allow_cache: + output.append(create_dup_top()) + return + + if self.mutable_side_effects_from_source: + # this is needed to get aliasing relationships right + # value.mutable_local.source will get mutated to hold `value` + # mutable_side_effects_from_source=False is used to codegen the mutation + # mutable_side_effects_from_source=True is used to codegen a reference + from .side_effects import MutableSideEffects + + if isinstance(value.mutable_local, MutableSideEffects): + self(value.mutable_local.source) + return + + if allow_cache: + if value.mutable_local and value.mutable_local in self.tempvars: + output.append(self.create_load(self.tempvars[value.mutable_local])) + self.top_of_stack = value + return + if self.tempvars.get(value) is not None: + output.append(self.create_load(self.tempvars[value])) + self.top_of_stack = value + return + + if value.source is not None and allow_cache and self.value_from_source: + self.call_reconstruct(value.source) + elif value.is_python_constant() and is_safe_constant( + value.as_python_constant() + ): + output.append(self.create_load_const(value.as_python_constant())) + elif isinstance(value, TensorWithTFOverrideVariable): + graph_outputs_key = self.add_graph_output(value) + + self.load_import_from(utils.__name__, "to_subclass") + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append( + self.create_load_global( + value.global_mangled_class_name(self.tx), False, add=True + ) + ) + output.extend(create_call_function(2, True)) + elif isinstance( + value, + ( + TensorVariable, + SymNodeVariable, + UnspecializedPythonVariable, + NumpyNdarrayVariable, + ), + ): + graph_outputs_key = self.add_graph_output(value) + + if isinstance(value, NumpyNdarrayVariable): + self.load_import_from(utils.__name__, "to_numpy_helper") + + self.load_graph_output(graph_outputs[graph_outputs_key].index) + + if isinstance(value, NumpyNdarrayVariable): + output.extend(create_call_function(1, True)) + elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: + output.extend( + [self.create_load_attr("item")] + create_call_function(0, True) + ) + elif isinstance(value, NNModuleVariable): + parts = value.module_key.split(".") + if parts[0] in self.code_options["co_varnames"]: + output.append(self.create_load(parts[0])) + parts = parts[1:] + else: + assert self.root is not None + output.append(self.create_load_output(self.root)) + for part in parts: + output.append(self.create_load_attr(part)) + else: + self.uses[value] += 1 + try: + self.call_reconstruct(value) + except NotImplementedError: + unimplemented(f"reconstruct: {value}") + if allow_cache and value in self.tempvars: + self._output.append(create_dup_top()) + self.add_cache(value) + + self.top_of_stack = value + + def add_graph_output(self, value): + graph_outputs_key = id(value.as_proxy()) + if graph_outputs_key not in self.graph_outputs: + self.graph_outputs[graph_outputs_key] = GraphOutputEntry( + len(self.graph_outputs), value + ) + return graph_outputs_key + + def load_graph_output(self, index): + output = self._output + output.append(self.create_load(self.graph_output_var)) + output.append(self._create_load_const(index)) + 
output.append(create_instruction("BINARY_SUBSCR")) + + def add_cache(self, value): + var = self.new_var() + self.tempvars[value] = var + if value.mutable_local: + self.tempvars[value.mutable_local] = var + self._output.append(self.create_store(var)) + + def foreach(self, items): + for i in items: + self(i) + + def setup_globally_cached(self, name, value, push_null): + """Store value in a new global""" + name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) + f_globals = self.tx.f_globals + if name in f_globals: + assert id(f_globals[name]) == id(value) + else: + f_globals[name] = value + return [self.create_load_global(name, push_null, add=True)] + + def clear_tos(self): + self.top_of_stack = None + + def append_output(self, inst): + assert isinstance(inst, Instruction) + self._output.append(inst) + self.clear_tos() + + def extend_output(self, insts): + assert all(isinstance(x, Instruction) for x in insts) + self._output.extend(insts) + self.clear_tos() + + def get_instructions(self) -> List[Instruction]: + return self._output + + def create_load(self, name) -> Instruction: + if name in self.cell_and_freevars(): + return create_instruction("LOAD_DEREF", argval=name) + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("LOAD_FAST", argval=name) + + def create_load_closure(self, name) -> Instruction: + assert name in self.cell_and_freevars() + return create_instruction("LOAD_CLOSURE", argval=name) + + def create_store(self, name) -> Instruction: + if name in self.cell_and_freevars(): + return create_instruction("STORE_DEREF", argval=name) + assert name in self.code_options["co_varnames"] + return create_instruction("STORE_FAST", argval=name) + + def create_load_global(self, name, push_null, add=False) -> Instruction: + if add: + self.tx.output.update_co_names(name) + assert name in self.code_options["co_names"], f"{name} not in co_names" + return create_load_global(name, push_null) + + def create_load_const(self, value) -> Instruction: + assert is_safe_constant(value), f"unsafe constant {value}" + return self._create_load_const(value) + + def _create_load_const(self, value) -> Instruction: + return create_instruction("LOAD_CONST", argval=value) + + create_load_output = _create_load_const + + def create_load_method(self, name): + self.tx.output.update_co_names(name) + return create_instruction("LOAD_METHOD", argval=name) + + def create_load_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("LOAD_ATTR", argval=name) + + def load_attr(self, name): + self.append_output(self.create_load_attr(name)) + + def create_load_attrs(self, names): + return [self.create_load_attr(name) for name in names.split(".")] + + def create_store_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("STORE_ATTR", argval=name) + + def store_attr(self, name): + self.append_output(self.create_store_attr(name)) + + def load_function_name(self, fn_name, push_null, num_on_stack=0): + """Load the global fn_name on the stack num_on_stack down""" + output = [] + if push_null and sys.version_info >= (3, 11): + output.extend( + [create_instruction("PUSH_NULL"), *self.rot_n(num_on_stack + 1)] + ) + output.extend( + [ + self.create_load_global(fn_name, False, add=True), + *self.rot_n(num_on_stack + 1), + ] + ) + return output + + def rot_n(self, n): + try: + return create_rot_n(n) + except 
AttributeError: + # desired rotate bytecode doesn't exist, generate equivalent bytecode + return [ + create_instruction("BUILD_TUPLE", arg=n), + self._create_load_const(rot_n_helper(n)), + *create_rot_n(2), + create_instruction("CALL_FUNCTION_EX", arg=0), + create_instruction("UNPACK_SEQUENCE", arg=n), + ] + + def pop_null(self): + # POP_TOP doesn't work for null, so we pop nulls by pushing in a + # nop function, calling it (which consumes the null), and popping the result. + assert sys.version_info >= (3, 11) + return [ + self._create_load_const(lambda: None), + *create_call_function(0, False), + create_instruction("POP_TOP"), + ] + + def call_function(self, nargs: int, push_null: bool): + self.extend_output(create_call_function(nargs, push_null=push_null)) + + def dup_top(self): + self.append_output(create_dup_top()) + + def store(self, varname): + self.append_output(self.create_store(varname)) + + def make_function_with_closure( + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 + ): + freevars = code.co_freevars + assert freevars + output = self._output + if sys.version_info >= (3, 11) and push_null: + output.append(create_instruction("PUSH_NULL")) + output.extend(self.rot_n(num_on_stack + 1)) + for var in freevars: + assert var in self.cell_and_freevars() + output.append(create_instruction("LOAD_CLOSURE", argval=var)) + output.append(create_instruction("BUILD_TUPLE", arg=len(freevars))) + output.append(self.create_load_const(code)) + if sys.version_info < (3, 11): + output.append(self.create_load_const(fn_name)) + output.append(create_instruction("MAKE_FUNCTION", arg=0x08)) + output.extend(self.rot_n(num_on_stack + 1)) + self.clear_tos() + + def create_load_python_module(self, mod, push_null) -> Instruction: + """ + Generate a LOAD_GLOBAL instruction to fetch a given python module. 
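# --- illustrative sketch (not part of the vendored file) ---------------------
# rot_n() above falls back to BUILD_TUPLE + a helper call + UNPACK_SEQUENCE when
# the desired ROT_N bytecode does not exist for the target Python version. The
# effect of ROT_N itself can be simulated on a plain list used as a value stack
# (top of stack is the last element): the top item sinks n-1 positions.

def simulate_rot_n(stack, n):
    assert n >= 2 and len(stack) >= n
    top = stack.pop()
    stack.insert(len(stack) - (n - 1), top)
    return stack

assert simulate_rot_n([1, 2, 3, 4], 3) == [1, 4, 2, 3]  # TOS (4) sinks two slots
assert simulate_rot_n([1, 2], 2) == [2, 1]              # ROT_TWO is a swap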
+ """ + output = self.tx.output + global_scope = output.global_scope + name = re.sub(r"^.*[.]", "", mod.__name__) + if global_scope.get(name, None) is mod: + return self.create_load_global(name, push_null, add=True) + prefix = f"___module_{name}" + global_name = self.tx.output.install_global_by_id(prefix, mod) + return self.create_load_global(global_name, push_null, add=True) + + def make_call_generated_code(self, fn_name: str) -> None: + """Call the generated code function stored in fn_name""" + self.extend_output(self.load_function_name(fn_name, True)) + + graphargs = self.tx.output.graphargs + for arg in graphargs: + if arg.is_unspecialized: + self.extend_output( + [ + self.create_load_python_module(torch, True), + self.create_load_attr("as_tensor"), + ] + ) + self.call_reconstruct(arg) + self.extend_output(create_call_function(1, False)) + else: + self.call_reconstruct(arg) + + self.extend_output(create_call_function(len(graphargs), False)) + + def load_import_from(self, module_name, object_name) -> None: + self(AttrSource(self.tx.import_source(module_name), object_name)) + + def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]: + if sys.version_info >= (3, 11): + output = create_call_function(nargs, push_null) + assert output[-2].opname == "PRECALL" + kw_names_inst = create_instruction("KW_NAMES", argval=kw_names) + output.insert(-2, kw_names_inst) + return output + return [ + self.create_load_const(kw_names), + create_instruction("CALL_FUNCTION_KW", arg=nargs), + ] diff --git a/MLPY/Lib/site-packages/torch/_dynamo/compiled_autograd.py b/MLPY/Lib/site-packages/torch/_dynamo/compiled_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..b130f0c3f5ccd264a434736fe480cae8e96fe571 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/compiled_autograd.py @@ -0,0 +1,280 @@ +import contextlib +import functools +from typing import List, Optional + +import torch +from torch._dynamo.external_utils import call_backward, call_hook +from torch._dynamo.source import GetItemSource, LocalSource +from torch._dynamo.utils import counters, lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._prims_common import clone_preserve_strides +from torch._subclasses import FakeTensorMode +from torch.fx import GraphModule +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import ( + decompose, + disable_autocast_cache, + disable_proxy_modes_tracing, + fetch_object_proxy, + ProxyTorchDispatchMode, + PythonKeyTracer, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch.fx.proxy import Proxy + +compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") + + +def maybe_clone(x): + if x is not None: + return clone_preserve_strides(x) + return x + + +class AutogradCompilerInstance: + def __init__(self, compiler_fn) -> None: + self.compiler_fn = compiler_fn + self.stack = contextlib.ExitStack() + self.close = self.stack.close + self.shape_env = ShapeEnv() + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.fx_tracer = PythonKeyTracer() + self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") + self.hooks_proxy: Optional[Proxy] = None + + def wrap_fake(self, x, source): + assert isinstance(x, torch.Tensor) + return self.fake_tensor_mode.from_tensor(x, source=source) + + @staticmethod + def 
source(name, idx) -> GetItemSource: + return GetItemSource(LocalSource(name), idx) + + def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]): + counters["compiled_autograd"]["captures"] += 1 + self.fx_tracer.root = torch.nn.Module() + self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) + self.fx_tracer.tensor_attrs = {} + args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {}) + sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {}) + self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {}) + + # tensor inputs to fake tensors + inputs = [ + self.wrap_fake(x, self.source("inputs", idx)) + for idx, x in enumerate(inputs) + ] + proxies = [args_proxy[i] for i in range(len(inputs))] + self.bind_tensors_to_proxies(inputs, proxies) + + # size inputs to symints + sizes = [ + self.shape_env.create_unspecified_symint_and_symbol( + val, + self.source("sizes", idx), + DimDynamic.DYNAMIC, + ) + for idx, val in enumerate(sizes) + ] + self.bind_tensors_to_proxies(sizes, sizes_proxy) + + # TODO(jansel): are all these modes needed? + self.stack.enter_context(decompose({})) + self.stack.enter_context(self.fake_tensor_mode) + self.stack.enter_context(self.proxy_mode.sym_mode) + self.stack.enter_context(self.proxy_mode) + self.stack.enter_context(disable_autocast_cache()) + return inputs, sizes + + def proxy_call_backward( + self, + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ): + assert self.hooks_proxy is not None + backward_fn = self.hooks_proxy[backward_idx] # type: ignore[index] + proxies = self.fx_tracer.create_proxy( + kind="call_function", + target=call_backward, + args=( + backward_fn, + self.to_proxy(saved_tensors), + *self.to_proxy(inputs), + ), + kwargs={}, + ) + + with disable_proxy_modes_tracing(): + # create fake Tensors + grad_ins: List[Optional[torch.Tensor]] = [] + for output_metadata in output_metadatas: + if output_metadata is None: + grad_ins.append(None) + continue + + layout, device, dtype, size = output_metadata + grad_ins.append( + torch.empty(size=size, dtype=dtype, layout=layout, device=device) + ) + self.bind_tensors_to_proxies(grad_ins, proxies) + return tuple(grad_ins) + + def proxy_call_hook(self, hook, *args): + return self.fx_tracer.create_proxy( + "call_function", + call_hook, + ( + hook, + *[self.to_proxy(x) for x in args], + ), + {}, + ) + + def tensor_pre_hook(self, inputs, hook_id, i: int): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + inputs[i], + ) + with disable_proxy_modes_tracing(): + inputs[i] = maybe_clone(inputs[i]) + self.bind_tensors_to_proxies([inputs[i]], [proxy]) + return inputs + + def pre_hook(self, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + inputs, + ) + with disable_proxy_modes_tracing(): + inputs = [maybe_clone(x) for x in inputs] + self.bind_tensors_to_proxies(inputs, proxies) + return inputs + + def post_hook(self, outputs, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + outputs, + inputs, + ) + with disable_proxy_modes_tracing(): + outputs = [maybe_clone(x) for x in outputs] + self.bind_tensors_to_proxies(outputs, proxies) + return outputs + + def post_acc_grad_hook(self, input, hook_id): + assert isinstance(input, torch.Tensor) + 
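# --- illustrative sketch (not part of the vendored file) ---------------------
# begin_capture() above lifts real tensor inputs into fake tensors so the
# autograd graph can be traced on metadata alone. A minimal demonstration of
# that idea using FakeTensorMode directly (details may differ across PyTorch
# versions):

import torch
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
real = torch.randn(4, 3)
fake = fake_mode.from_tensor(real)
print(fake.shape, fake.dtype, fake.device)  # metadata matches `real`
# `fake` carries no real data; ops on it under the mode only propagate
# shapes/dtypes, which is all the tracer needs.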
assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + input, + ) + with disable_proxy_modes_tracing(): + input = [maybe_clone(input)] + self.bind_tensors_to_proxies(input, proxies) + return input + + def end_capture(self, outputs): + self.stack.close() + self.fx_tracer.create_node( + "output", + "output", + (self.fx_tracer.create_arg(self.to_proxy(outputs)),), + {}, + ) + graph = GraphModule( + self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" + ) + compiled_autograd_log.info( + "%s", lazy_format_graph_code("Compiled autograd graph", graph) + ) + trace_structured( + "compiled_autograd_graph", + payload_fn=lambda: graph.print_readable(print_output=False), + ) + return self.compiler_fn(graph) + + def to_proxy(self, t): + if t is None: + return None + if isinstance(t, list): + return [self.to_proxy(x) for x in t] + if isinstance(t, tuple): + return tuple(self.to_proxy(x) for x in t) + assert isinstance(t, (torch.Tensor, torch.SymInt)) + return fetch_object_proxy(self.fx_tracer)(t).proxy + + def bind_tensors_to_proxies(self, tensors, proxies): + if isinstance(proxies, torch.fx.Proxy): + proxies = [proxies[i] for i in range(len(tensors))] + assert len(tensors) == len(proxies) + track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer) + + def bind_backward_state(self, index: int): + assert self.hooks_proxy is not None + proxy = self.hooks_proxy[index] # type: ignore[index] + bw_state = BackwardState() + track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer) + return bw_state + + +compiled_autograd_enabled = False + +# We may have code like: +# with enable(compiler_fn): +# ... +# with disable(): +# ... +# ... +# The disable() call just want to disable compiled autograd temporarily. +# But overall the feature is enabled. +# +# The code covered by the disable context manager has no way to know if +# compiled autograd is overall eanbled. Use another variable +# compiled_autograd_enabled_count to indicate how many times compiled +# autograd has been enabled in the call stack for this purpose. +compiled_autograd_enabled_count = 0 + + +@contextlib.contextmanager +def enable(compiler_fn): + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( + functools.partial(AutogradCompilerInstance, compiler_fn) + ) + global compiled_autograd_enabled, compiled_autograd_enabled_count + compiled_autograd_enabled = True + compiled_autograd_enabled_count += 1 + try: + with torch.autograd.set_multithreading_enabled(False): + yield + finally: + compiled_autograd_enabled_count -= 1 + if not prior: + compiled_autograd_enabled = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + + +@contextlib.contextmanager +def disable(): + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + global compiled_autograd_enabled + compiled_autograd_enabled = False + try: + yield + finally: + if prior: + compiled_autograd_enabled = True + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/comptime.py b/MLPY/Lib/site-packages/torch/_dynamo/comptime.py new file mode 100644 index 0000000000000000000000000000000000000000..649bcbac947d51710f430efcb22824a02b29b24f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/comptime.py @@ -0,0 +1,373 @@ +# This file establishes the public comptime interface to Dynamo. 
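# --- illustrative usage sketch (not part of the vendored file) ---------------
# The enable()/disable() context managers in compiled_autograd.py above swap in
# AutogradCompilerInstance as the autograd compiler. A typical (hedged) usage
# pattern, assuming a build where compiled autograd is available: run
# .backward() inside enable() and let an ordinary Dynamo backend compile the
# captured backward graph.

import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # gm is the "CompiledAutograd" GraphModule produced by end_capture()
    return torch.compile(gm, backend="eager")

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()
with compiled_autograd.enable(compiler_fn):
    loss.backward()  # backward graph is captured and handed to compiler_fn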
+# This allows Dynamo users to execute arbitrary Python code while +# Dynamo is symbolically evaluating their original programs. +# +# The goal of the public API is to give users rope, without actually +# leaking private implementation details of Dynamo. + +import builtins +import dis +import traceback +from typing import Optional, Union + +import torch +from torch.fx.experimental.symbolic_shapes import free_symbols + +from .exc import unimplemented +from .variables.constant import ConstantVariable +from .variables.tensor import SymNodeVariable + + +class ComptimeVar: + """ + A ComptimeVar represents a Python value, at some particular point + in time, in the Python code we are symbolically evaluating with + torchdynamo. This must be distinguished from a runtime value, as + at compile-time there are some properties of the variable we + do not know (for example, if the ComptimeVar represents a Tensor, + we only know metadata about the tensor; we do NOT know what the + actual data in the Tensor is.) + """ + + def __init__(self, v): + self.__variable = v + + def as_proxy(self): + """ + Returns an fx.Proxy (or tuple/list of fx.Proxy) representing + this variable in the FX graph we are assembling to pass + to the user compiler. + + This method only works for variables we actually track in + the FX graph, aka Tensors (and ints, if you are compiling + with dynamic shapes). In particular, if you have a list + or tuple of tensors, you will get a list/tuple of proxies + (not a single proxy representing the entire list/tuple). + """ + return self.__variable.as_proxy() + + def is_proxy(self): + """ + Returns True if as_proxy() would succeed. + """ + return self.__variable.is_proxy() + + def as_fake(self): + """ + Returns a "fake" value (either a FakeTensor or a SymInt) + representing the variable in question. This only works + for variables that denote Tensor or int. You can use + this to query metadata; e.g., v.as_fake().size(0) will + tell you the compile-time known size of the tensor. + + WARNING: Do NOT mutate the returned tensor. + """ + return self.__variable.as_proxy().node.meta["example_value"] + + def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: + """ + Returns the size of the tensor (if dim is None) or the size + at the dimension dim. The returned size may be a SymInt. + """ + return self.as_fake().size(dim) + + def python_type(self): + """ + Returns what type(v) would have returned for the variable + at compile time. + """ + return self.__variable.python_type() + + def as_python_constant(self): + """ + Returns the Python value this variable would have, but only if it is + completely known at compile-time (e.g., it is constant). + + WARNING: Do NOT mutate the returned constant. The returned constant + may or may not correspond to the actual value this variable may take + on at runtime; for example, if the variable in question is a constant + list, we may return a copy of that list. + """ + return self.__variable.as_python_constant() + + def is_python_constant(self): + """ + Returns True if as_python_constant would succeed. 
+ """ + return self.__variable.is_python_constant() + + def is_dynamic(self): + if isinstance(self.__variable, SymNodeVariable): + fs = free_symbols(self.__variable.sym_num) + return bool(fs) + return False + + def force_static(self): + """ + Forces that a value is static, inducing a guard on its specific value + """ + if isinstance(self.__variable, SymNodeVariable): + self.__variable.evaluate_expr() + elif isinstance(self.__variable, ConstantVariable): + # TODO: Maybe complain if this isn't a int/bool/float variable + pass + else: + raise AssertionError( + f"cannot force {self.__variable} ({type(self.__variable)}) static" + ) + + def _i_will_not_complain_if_bc_breaks_VariableTracker(self): + """ + Returns the internal data structure VariableTracker that Dynamo uses + to represent variables at compile time. There are no BC guarantees on + this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on + it. + """ + return self.__variable + + def __repr__(self): + # TODO: The default repr is pretty bad, do better + return repr(self.__variable) + + # TODO: API for adding a custom guard + + +class ComptimeContext: + """ + This context class provides access to a public API for Dynamo's internals. + If there is something here you would find useful that is missing, please + file a feature request at https://github.com/pytorch/pytorch/ + """ + + def __init__(self, tx): + self.__tx = tx + + def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: + """ + Retrieve the compile-time known information about a local. + """ + tx = self.__get_tx(stacklevel) + return ComptimeVar(tx.symbolic_locals[name]) + + def graph_break(self, msg="ComptimeContext.graph_break"): + """ + Manually trigger a graph break + """ + unimplemented(msg) + + def graph(self): + """ + Retrieve the partially constructed FX graph that would be + passed to the user compiler after compilation. + """ + return self.__tx.output.graph + + def assert_static(self, val): + """ + Asserts that the int is static (and not dynamic, per dynamic shapes) + """ + assert ( + not val.is_dynamic() + ), "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" + + def print_graph(self, *, verbose=True, file=None): + """ + Print the partially constructed FX graph that would be passed + to the user compiler after compilation. + """ + print( + self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file + ) + + def parent(self): + return ComptimeContext(self.__tx.parent) + + def __get_tx(self, stacklevel): + tx = self.__tx + for _ in range(stacklevel): + tx = tx.parent + return tx + + def print_disas(self, *, file=None, stacklevel=0): + """ + Print the current series of opcodes being executed (not including + parent frames), including where you are in the particular opcode + stream. + """ + tx = self.__get_tx(stacklevel) + print( + dis.Bytecode( + tx.f_code, + current_offset=tx.instructions[tx.instruction_pointer].offset, + ).dis(), + file=file, + ) + + def print_value_stack(self, *, file=None, stacklevel=0): + """ + Print the current Python value stack. Note that this is NOT the same + as the traceback; use print_bt() to print that. Note that at + stacklevel=0, this will typically be empty, as comptime cannot + currently be used in an expression context where there would be + intermediates on the stack. 
If you would find this useful, please + file a bug at https://github.com/pytorch/pytorch/ + + NB: Stack grows downwards in our print + """ + # TODO: improve printing + tx = self.__get_tx(stacklevel) + for s in tx.stack: + print(f"- {s}", file=file) + + def print_locals(self, *, file=None, stacklevel=0): + """ + Print all of the locals available in the current context. + By default this view is very limited; you can get more information + about any individual local using get_local(). + """ + # TODO: improve by improving the VariableTracker printing + tx = self.__get_tx(stacklevel) + for k, v in tx.symbolic_locals.items(): + print(f"{k} = {v}", file=file) + + def print_bt(self, *, file=None, stacklevel=0): + """ + Print the user code backtrace, starting at the beginning of the + frame Dynamo started evaluating. Note that this MAY NOT go all + the way to the torch.compile invocation, as we may have done + a graph break and are compiling an intermediate frame as the + starting point. If you think the other behavior would be better, + file a bug at https://github.com/pytorch/pytorch/ + """ + stack = [] + tx = self.__get_tx(stacklevel) + while tx is not None: + stack.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + print( + "".join(traceback.StackSummary.from_list(reversed(stack)).format()), + file=file, + ) + + def print_guards(self, *, file=None): + """ + Print the currently installed guards for the Dynamo context. + This does NOT include guards associated with variables that + may or may not be installed in the future if those variables + are used. + """ + # TODO: improve print format, current guard format is extremely + # verbose + print( + "\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)), + file=file, + ) + + def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): + """ + Returns the internal data structure InstructionTranslator that Dynamo + uses to track state of symbolic evaluation. There are no BC + guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if + you rely on it. 
+ """ + return self.__tx + + +class _Comptime: + @staticmethod + def __call__(fn): + """fn gets called at compile time in TorchDynamo, does nothing otherwise""" + return + + # Convenience wrappers that are more compact to use + + @staticmethod + def graph_break(): + comptime(lambda ctx: ctx.graph_break()) + + @staticmethod + def print_graph(): + comptime(lambda ctx: ctx.print_graph()) + + @staticmethod + def print_disas(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_disas( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_value_stack(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + # This is a more useful variant of print_value_stack that can be used + # in an expression context; e.g., x + print_value_stack_and_return(y + z), + # you will see x on the stack prior to the addition operation + @staticmethod + def print_value_stack_and_return(e, *, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + return e + + @staticmethod + def print_locals(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_locals( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_bt(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_bt( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_guards(): + comptime(lambda ctx: ctx.print_guards()) + + @staticmethod + def assert_static(val): + comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) + + @staticmethod + def force_static(val): + comptime(lambda ctx: ctx.get_local("val").force_static()) + + @staticmethod + def breakpoint(): + """ + Like pdb breakpoint(), but drop into pdb whenever this line + of code is compiled by dynamo. 
Use it by putting + this in your model code:: + + from torch._dynamo.comptime import comptime + comptime.breakpoint() + + And then, inside pdb, you can access 'ctx' to query things + about the compilation context:: + + (Pdb) !ctx.print_bt() + (Pdb) !ctx.print_locals() + (Pdb) p ctx.get_local("attention").as_fake() + """ + + def inner(inner_ctx): + ctx = inner_ctx.parent() + builtins.breakpoint() + + comptime(inner) + + +comptime = _Comptime() diff --git a/MLPY/Lib/site-packages/torch/_dynamo/config.py b/MLPY/Lib/site-packages/torch/_dynamo/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cc707e2cb92a7e0533bfb5db217c1c885ab1c012 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/config.py @@ -0,0 +1,423 @@ +import getpass +import inspect +import os +import re +import sys +import tempfile +from os.path import abspath, dirname +from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union + +import torch + +# to configure logging for dynamo, aot, and inductor +# use the following API in the torch._logging module +# torch._logging.set_logs(dynamo=, aot=, inductor) +# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity) +# see this design doc for more detailed info +# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# +# the name of a file to write the logs to +# [@compile_ignored: debug] +log_file_name: Optional[str] = None + +# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors +verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1" + +# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend +verify_correctness = False + +# need this many ops to create an FX graph +minimum_call_count = 1 + +# turn on/off DCE pass +dead_code_elimination = True + +# disable (for a function) when cache reaches this size + +# controls the maximum number of cache entries with a guard on same ID_MATCH'd +# object. It also controls the maximum size of cache entries if they don't have +# any ID_MATCH'd guards. +# [@compile_ignored: runtime_behaviour] +cache_size_limit = 8 + +# [@compile_ignored: runtime_behaviour] controls the maximum number of entries for a code object. +accumulated_cache_size_limit = 64 + +# whether or not to specialize on int inputs. This only has an effect with +# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int +# inputs. Note that assume_static_by_default will also cause ints to get +# specialized, so this is mostly useful for export, where we want inputs +# to be dynamic, but accesses to ints should NOT get promoted into inputs. +specialize_int = False + +# legacy config, does nothing now! +dynamic_shapes = True + +use_lazy_graph_module = ( + os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1" +) + +# This is a temporarily flag, which changes the behavior of dynamic_shapes=True. +# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic. +# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API +# see [Note - on the state of mark_dynamic] +assume_static_by_default = True + +# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction +# with assume_static_by_default=True. 
+# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail +# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic. +automatic_dynamic_shapes = True + +# This flag changes how the shapes of parameters are treated. +# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic +# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static, +# while the shapes of torch.Tensor are assumed to be dynamic. +force_parameter_static_shapes = True + +# This flag ensures that the shapes of a nn module are always assumed to be static +# If the flag is set to True, then the shapes of a nn.module are assumed to be static +# If the flag is set to False, then the shapes of a nn.module can be dynamic +force_nn_module_property_static_shapes = True + +# Typically, if you mark_dynamic a dimension, we will error if the dimension +# actually ended up getting specialized. This knob changes the behavior so +# that we don't error at all. This is helpful for our CI where I'm using a +# heuristic to mark batch dimensions as dynamic and the heuristic may get it +# wrong. +allow_ignore_mark_dynamic = False + +# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) +guard_nn_modules = False + +# Uses CPython internal dictionary tags to detect mutation. There is some +# overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag. +# guard_nn_modules unspecializes the nn module instance and adds guard for each +# relevant member of the nn modules. On the other hand, +# guard_nn_modules_using_dict_tags specializes on each nn module instance but +# uses low overhead dict version matching to detect mutations, obviating the +# need to guard on members of the nn modules. With +# guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required +# but kept around for debugging and discussing unspecializing nn module +# variables. +# TODO(janimesh, voz): Remove both of these flags (or atleast guard_nn_modules) +# once we have reached stability for the guard_nn_modules_using_dict_tags. +guard_nn_modules_using_dict_tags = True + +# This feature doesn't really work. We offer this flag for experimental +# purposes / if you want to help us build out support. +# +# torchdynamo has very limited support for tensor subclasses that implement +# __torch_function__. Our current support is limited to tensor subclasses +# that DO NOT store metadata on the tensor (in general, dynamo does not +# support Python code that stores extra attributes on tensors at present). +# If your tensor subclass purely changes function call behavior via +# __torch_function__, you can allow torchdynamo to trace into it by +# adding it to traceable_tensor_subclasses. We don't do any safety checks, +# so it is up to you to ensure that your subclass is well behaved. See also +# https://github.com/pytorch/torchdynamo/issues/1948 +# +# We do NOT currently support __torch_dispatch__. The implementation is +# currently buggy, the main show stopper for nontrivial use is +# https://github.com/pytorch/torchdynamo/issues/1952 +traceable_tensor_subclasses: Set[Type[Any]] = set() + +# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. +# This is a good way to get your model to work one way or another, but you may +# lose optimization opportunities this way. 
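# --- illustrative sketch (not part of the vendored file) ---------------------
# The flags above interact with dynamic shapes roughly as follows: with
# assume_static_by_default=True, a dimension only becomes symbolic if it is
# marked dynamic (or, with automatic_dynamic_shapes=True, after a recompile
# observes it changing). A hedged usage example:

import torch
import torch._dynamo

@torch.compile(backend="eager")
def f(x):
    return x.sum(dim=0)

x = torch.randn(8, 4)
torch._dynamo.mark_dynamic(x, 0)  # request a symbolic size on dim 0 up front
f(x)
f(torch.randn(16, 4))             # the same compiled graph should be reusable here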
Devs, if your benchmark model is failing +# this way, you should figure out why instead of suppressing it. +suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) + +# Record and write an execution record of the current frame to a file +# if an exception is encountered +# @compile_ignored[debug] +replay_record_enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True + +# Disable dynamo +disable = os.environ.get("TORCH_COMPILE_DISABLE", False) + +# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo +cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False) + +# legacy config, does nothing now! +skipfiles_inline_module_allowlist: Dict[Any, Any] = {} + +# If a string representing a PyTorch module is in this ignorelist, +# the `allowed_functions.is_allowed` function will not consider it +# when creating a list of PyTorch functions that will appear in +# FX IR. +allowed_functions_module_string_ignorelist = { + "torch.distributions", + "torch.testing", + "torch._refs", + "torch._prims", + "torch._decomp", +} + +# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"} +# None - Minifier is switched off +# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails +# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails +# [@compile_ignored: debug] +repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None) + +# Compiler compilation debug info +# 1: Dumps the original graph out to repro.py if compilation fails +# 2: Dumps a minifier_launcher.py if compilation fails. +# 3: Always dumps a minifier_launcher.py. Good for segfaults. +# 4: Dumps a minifier_launcher.py if the accuracy fails. +# [@compile_ignored: debug] +repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2)) + +# By default, we try to detect accuracy failure by running both forward +# and backward of a torchdynamo produced graph (if you are using repro_after +# 'dynamo'). This setting forces us to only test the forward graph and +# not the backward graph. This can be helpful if you're trying to debug +# an inference only problem, but the minifier seems to be choking on the +# backwards step +# TODO: Detect this situation automatically so the user doesn't need +# to manually configure this +# [@compile_ignored: debug] +repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1" + +# The tolerance we should use when testing if a compiled graph +# has diverged so that we should treat it as an accuracy failure +# [@compile_ignored: debug] +repro_tolerance = 1e-3 + +# If True, when testing if two models are the same, we will test them against +# a third fp64 reference and only report a problem if the RMSE relative to the +# fp64 is greater. However, this will use more memory; you may disable this +# if memory usage is too high. +# [@compile_ignored: runtime_behaviour] +same_two_models_use_fp64 = True + +# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type. +# When this flag is set to False, we introduce a graph break instead of capturing. +# This requires dynamic_shapes to be True. +capture_scalar_outputs = False + +# Not all backends support operators that have dynamic output shape (e.g., +# nonzero, unique). When this flag is set to False, we introduce a graph +# break instead of capturing. This requires dynamic_shapes to be True. 
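# --- illustrative sketch (not part of the vendored file) ---------------------
# The repro_* knobs above drive the minifier. They can be set through the
# environment (TORCHDYNAMO_REPRO_AFTER / TORCHDYNAMO_REPRO_LEVEL) or assigned
# directly on the config module; the values below simply mirror the comments
# above.

import torch._dynamo.config as dynamo_config

dynamo_config.repro_after = "dynamo"  # minify TorchDynamo-produced graphs on failure
dynamo_config.repro_level = 4         # also dump minifier_launcher.py on accuracy failures
dynamo_config.repro_tolerance = 1e-3  # divergence threshold for "accuracy failure"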
+# If you set this to True, you probably also want capture_scalar_outputs +# (these are separated for historical reasons). +capture_dynamic_output_shape_ops = False + +# By default, dynamo will treat all ints as backed SymInts, which means (1) it +# will wait to see the int change over multiple runs before generalizing and +# (2) it will still always 0/1 specialize an int. When true, this knob +# forces dynamo to treat _length_per_key and _offset_per_key on +# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that +# they (1) generalize immediately and (2) unsoundly never compare equal to +# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently +# compile this code; however, this can be useful for export. +force_unspec_int_unbacked_size_like_on_torchrec_kjt = False + +# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and +# false_fn produces code with identical guards. +enforce_cond_guards_match = True + +# Specify how to optimize a compiiled DDP module. The flag accepts a bollean +# value or a string. There are 4 modes. +# 1. "ddp_optimizer" (or True): with "ddp_ptimizer", Dynamo will automatically +# split model graph into pieces to match DDP bucket sizes to allow DDP +# comm/compute overlap. +# 2. "python_reducer" (experimental): this optimization requires the usage +# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer +# and use the Python reducer to allow compiled_autograd to trace the +# communication and allow comm/compute overlap without graph-breaks. +# 3. "python_reducer_without_compiled_forward" (experimental): this mode is +# similar to "python_reducer". One should only use this optimization mode +# when compiled_autograd is used but the DDP module is not compiled. +# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor +# will Python reducer be used. With this mode, there will be no graph-breaks +# and the original DDP C++ reducer will be used. There will no comm/compute +# overlap. This mode CANNOT be used with compiled_autograd. +# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be +# specified with a boolean value. True is using ddp_optimizer and False is +# no optimization. +optimize_ddp: Union[bool, str] = True + +_ddp_optimization_mode = [ + "ddp_optimizer", + "python_reducer", # experimental mode + "python_reducer_without_compiled_forward", # experimental mode + "no_optimization", +] + + +def _get_optimize_ddp_mode(): + m = sys.modules[__name__] + if isinstance(m.optimize_ddp, bool): + if m.optimize_ddp: + mode = "ddp_optimizer" + else: + mode = "no_optimization" + elif isinstance(m.optimize_ddp, str): + mode = m.optimize_ddp + else: + raise ValueError(f"Invalid type, {type(optimize_ddp)=}") + + assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}" + return mode + + +# If True, delays DDPOptimizer submodule compilation to 1st run of the model, +# so that real tensor strides are used in all submodules +# (instead of using FakeTensor strides which can differ from real tensor strides and causes error in some cases). +# This feature is not hardened yet and it's known to cause issues to some models, so False by default. 
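# --- illustrative sketch (not part of the vendored file) ---------------------
# optimize_ddp accepts either the legacy booleans or one of the string modes in
# _ddp_optimization_mode; _get_optimize_ddp_mode() normalizes both spellings.
# Hedged examples of the accepted values:

import torch._dynamo.config as dynamo_config

dynamo_config.optimize_ddp = True              # same as "ddp_optimizer"
dynamo_config.optimize_ddp = False             # same as "no_optimization"
dynamo_config.optimize_ddp = "python_reducer"  # experimental; pair with compiled autograd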
+optimize_ddp_lazy_compile = False + +# Whether to skip guarding on FSDP-managed modules +skip_fsdp_guards = True + +# Make dynamo skip guarding on hooks on nn modules +# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them, +# dynamo will not notice and will execute whichever version you first compiled. +skip_nnmodule_hook_guards = True + +# If True, raises exception if TorchDynamo is called with a context manager +raise_on_ctx_manager_usage = True + +# If True, raise when aot autograd is unsafe to use +raise_on_unsafe_aot_autograd = False + +# If true, error if you torch.jit.trace over a dynamo-optimized function. +# If false, silently suppress dynamo +error_on_nested_jit_trace = True + +# If true, error with a better message if we symbolically trace over a +# dynamo-optimized function. If false, silently suppress dynamo. +error_on_nested_fx_trace = True + +# Disables graph breaking on rnn. YMMV with backends. +allow_rnn = False + +# If true, error if we try to compile a function that has +# been seen before. +# [@compile_ignored: runtime_behaviour] +error_on_recompile = False + +# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything) +report_guard_failures = True + +# [@compile_ignored: debug] root folder of the project +base_dir = dirname(dirname(dirname(abspath(__file__)))) + +# Trace through NumPy or graphbreak +trace_numpy = True + +# Trace through torch.distributed code +trace_distributed = False + +# Default NumPy dtypes when tracing with torch.compile +# We default to 64bits. For efficiency, one may want to change these to float32 +numpy_default_float = "float64" +numpy_default_complex = "complex128" +numpy_default_int = "int64" + +# use numpy's PRNG if True, pytorch otherwise +use_numpy_random_stream = False + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +def default_debug_dir_root(): + # [@compile_ignored: debug] + DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" + if DEBUG_DIR_VAR_NAME in os.environ: + return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug") + elif is_fbcode(): + return os.path.join( + tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug" + ) + else: + return os.path.join(os.getcwd(), "torch_compile_debug") + + +# [@compile_ignored: debug] +debug_dir_root = default_debug_dir_root() + +# [@compile_ignored: debug] +_save_config_ignore = { + "repro_after", + "repro_level", + # workaround: "cannot pickle PyCapsule" + "constant_functions", + # workaround: "cannot pickle module" + "skipfiles_inline_module_allowlist", +} + +# for backend="cudagraphs", mutations on input be sent to the cudagraph backend +# or replayed in aot_autograd epilogue. default is False because mutation on inputs +# can prevent cudagraphing. +cudagraph_backend_keep_input_mutation = False + +# When True, only ops that have the torch.Tag.pt2_compliant tag +# will be allowed into the graph; all other ops will be disallowed +# and will fall back to eager-mode PyTorch. Useful to ensure +# correctness of custom ops. +only_allow_pt2_compliant_ops = False + +capture_autograd_function = True + +# enable/disable dynamo tracing for `torch.func` transforms +capture_func_transforms = False + +# enable/disable user-defined triton kernel optimizations +optimize_user_defined_triton_kernels = True + +# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode). 
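# --- illustrative sketch (not part of the vendored file) ---------------------
# debug_dir_root above decides where TORCH_COMPILE_DEBUG artifacts land. It can
# be steered through the TORCH_COMPILE_DEBUG_DIR environment variable (read at
# import time) or by assigning the config value afterwards; the path below is
# an example, not a default.

import os
import torch._dynamo.config as dynamo_config

os.environ["TORCH_COMPILE_DEBUG_DIR"] = "/tmp/my_debug_dir"
dynamo_config.debug_dir_root = "/tmp/my_debug_dir/torch_compile_debug"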
+log_compilation_metrics = True + +# A set of logging functions which will be reordered to the end of graph breaks, +# allowing dynamo to construct larget graph. Note that there are some +# limitations to this, such as how it does not correctly print objects that were +# mutated after the print statement. +reorderable_logging_functions: Set[Callable[[Any], None]] = set() + +# simulates what would happen if we didn't have support for BUILD_SET opcode, +# used for testing +inject_BUILD_SET_unimplemented_TESTING_ONLY = False + +_autograd_backward_strict_mode_banned_ops = [ + "stride", + "requires_grad", + "storage_offset", + "layout", + "data", +] + +_autograd_backward_strict_mode_banned_ops.extend( + [name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)] +) + +# Enables caching of dispatches to fake tensors. +fake_tensor_cache_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1" +) + +# Enables cross checking between the fake tensor cache and dispatch. +fake_tensor_cache_crosscheck_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1" +) + +# support `context_fn` in torch.utils.checkpoint.checkpoint API under torch.compile(). +# WARNING: this is an experimental flag and is subject to change. +_experimental_support_context_fn_in_torch_utils_checkpoint = False + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + def _make_closure_patcher(**changes): + ... + + +from torch.utils._config_module import install_config_module + +install_config_module(sys.modules[__name__]) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/convert_frame.py b/MLPY/Lib/site-packages/torch/_dynamo/convert_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..0b325df31077d93e5ecddce7a73af613c3a64301 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/convert_frame.py @@ -0,0 +1,924 @@ +import collections +import dis +import functools +import itertools +import logging +import os +import random +import sys +import threading +import time +import traceback +import types +import typing +import weakref +from typing import Any, Callable, Dict, List, Optional, Set + +from torch.fx._lazy_graph_module import ( # type: ignore[attr-defined] + _use_lazy_graph_module, +) + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +import torch +import torch._logging +from torch._guards import compile_context, CompileContext, CompileId, tracing +from torch._logging import structured +from torch._utils_internal import signpost_event +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + GuardOnDataDependentSymNode, +) +from torch.fx.graph_module import _forward_from_src as original_forward_from_src +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils._python_dispatch import _disable_current_modes +from torch.utils._traceback import format_traceback_short + +from . 
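# --- illustrative sketch (not part of the vendored file) ---------------------
# install_config_module() above turns this plain module into a managed config
# object, which is what makes helpers such as patch() available. Typical
# (hedged) usage is as a context manager or decorator that temporarily
# overrides flags:

import torch._dynamo

with torch._dynamo.config.patch(error_on_recompile=True, cache_size_limit=4):
    pass  # code compiled here sees the overridden values

@torch._dynamo.config.patch(suppress_errors=True)
def run_model():
    ...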
import config, exc, trace_rules +from .backends.registry import CompilerFn +from .bytecode_analysis import remove_dead_code, remove_pointless_jumps +from .bytecode_transformation import ( + check_inst_exn_tab_entries_valid, + Instruction, + is_generator, + propagate_inst_exn_table_entries, + transform_code_object, +) +from .cache_size import ( + CacheSizeRelevantForFrame, + compute_cache_size, + exceeds_cache_size_limit, + is_recompilation, +) +from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher +from .exc import ( + augment_exc_message, + BackendCompilerFailed, + format_error_msg, + InternalTorchDynamoError, + TorchRuntimeError, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) +from .guards import ( + CheckFunctionManager, + get_and_maybe_log_recompilation_reason, + GuardedCode, +) +from .hooks import Hooks +from .output_graph import OutputGraph +from .replay_record import ExecutionRecord +from .symbolic_convert import InstructionTranslator, SpeculationLog +from .trace_rules import is_numpy +from .types import BytecodeHook +from .utils import ( + CleanupManager, + CompilationMetrics, + counters, + dynamo_timed, + format_bytecode, + frame_phase_timing, + gen_record_file_name, + increment_frame, + is_namedtuple, + istype, + LazyString, + maybe_cprofile, + orig_code_map, + record_compilation_metrics, + reset_graph_break_dup_checker, + setup_compile_debug, + troubleshooting_url, + write_record_to_file, +) + +log = logging.getLogger(__name__) +bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode") +GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard + +compile_lock = threading.RLock() + + +class Tracker: + def __init__(self): + self.seen = [] + self.seen_ids = set() + + def add(self, strong_obj): + idx = id(strong_obj) + if idx not in self.seen_ids: + obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx)) + self.seen.append(obj) + self.seen_ids.add(idx) + + def __contains__(self, item): + return id(item) in self.seen_ids + + def clear(self): + self.seen.clear() + self.seen_ids.clear() + + +input_codes = Tracker() +output_codes = Tracker() + +initial_global_state: Optional[GlobalStateGuard] = None + + +@functools.wraps(original_forward_from_src) +def fx_forward_from_src_skip_result(*args, **kwargs): + # we monkey patch FX to prevent infinite loop of trying to convert + # our generated code + result: types.FunctionType = original_forward_from_src(*args, **kwargs) + skip_code(result.__code__) + return result + + +def preserve_global_state(fn): + """ + Context manager to: + 1) Save/restore torch.is_grad_enabled() state + 2) Save/restore python random state + 3) Save/restore torch random state + 4) Monkey patch torch.fx.graph_module._forward_from_src + """ + + @functools.wraps(fn) + def _fn(*args, **kwargs): + guards = GlobalStateGuard() + prior_grad_mode = torch.is_grad_enabled() + prior_inference_mode = torch.is_inference_mode_enabled() + prior_deterministic = torch.are_deterministic_algorithms_enabled() + prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled() + py_rng_state = random.getstate() + torch_rng_state = torch.random.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + prior_fwd_from_src = torch.fx.graph_module._forward_from_src + torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result + cleanup = setup_compile_debug() + try: + return fn(*args, **kwargs) + finally: + cleanup.close() + torch._C._set_grad_enabled(prior_grad_mode) + 
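# --- illustrative sketch (not part of the vendored file) ---------------------
# preserve_global_state() above snapshots grad mode, RNG state, and more before
# tracing and restores everything afterwards so compilation is side-effect
# free. The core save/restore pattern, stripped down to just the RNG pieces:

import functools
import random

import torch

def preserve_rng_state(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        py_state = random.getstate()
        torch_state = torch.random.get_rng_state()
        try:
            return fn(*args, **kwargs)
        finally:
            random.setstate(py_state)             # undo any random.* calls in fn
            torch.random.set_rng_state(torch_state)
    return wrapper

@preserve_rng_state
def noisy():
    return random.random() + torch.rand(()).item()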
torch.torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) + torch.use_deterministic_algorithms( + prior_deterministic, warn_only=prior_warn_only + ) + random.setstate(py_rng_state) + torch.random.set_rng_state(torch_rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + torch.fx.graph_module._forward_from_src = prior_fwd_from_src + assert ( + guards.check() + ), "Global state changed while dynamo tracing, please report a bug" + + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + return _fn + + +@TorchPatcher.suppress_torch_distributed_warnings +def has_tensor_in_frame(frame): + """Check if the frame has torch.* related bits""" + # Check if the function was decorated using torch._dynamo.optimize + if frame.f_code in always_optimize_code_objects: + return True + + # Check if there is global import of torch.* + for co_name in frame.f_code.co_names: + if co_name in frame.f_globals: + obj = frame.f_globals[co_name] + if isinstance(obj, types.ModuleType) and ( + obj.__name__.startswith("torch.") or obj is torch + ): + return True + # ... or a global import of numpy.* + if np and config.trace_numpy and (obj is np or is_numpy(obj)): + return True + + seen_ids: Dict[int, bool] = dict() + + def has_tensor(obj): + """Recursively check if the obj has a tensor""" + obj_id = id(obj) + if obj_id in seen_ids: + return seen_ids[obj_id] + seen_ids[obj_id] = False + + if isinstance(obj, (torch.Tensor, torch.nn.Module)) or ( + istype(obj, type) and issubclass(obj, torch.nn.Module) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif ( + config.trace_numpy + and np + and (istype(obj, np.ndarray) or isinstance(obj, np.generic)) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif istype(obj, (list, tuple)): + seen_ids[obj_id] = any(has_tensor(v) for v in obj) + return seen_ids[obj_id] + elif istype(obj, dict): + # Some packages like pytest can be updated during runtime. 
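# --- illustrative sketch (not part of the vendored file) ---------------------
# has_tensor() above memoizes by id() and records a provisional "no tensor"
# answer before recursing, which is what keeps reference cycles from looping
# forever. The same idea in a standalone helper:

import torch

def contains_tensor(obj, _seen=None):
    if _seen is None:
        _seen = {}
    oid = id(obj)
    if oid in _seen:
        return _seen[oid]   # already answered, or currently being visited
    _seen[oid] = False      # provisional answer breaks cycles
    if isinstance(obj, torch.Tensor):
        _seen[oid] = True
    elif isinstance(obj, (list, tuple, set)):
        _seen[oid] = any(contains_tensor(v, _seen) for v in obj)
    elif isinstance(obj, dict):
        _seen[oid] = any(contains_tensor(v, _seen) for v in list(obj.values()))
    return _seen[oid]

cyclic = []
cyclic.append(cyclic)  # self-referencing list terminates cleanly
print(contains_tensor(cyclic), contains_tensor([1, {"w": torch.ones(2)}]))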
So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any(has_tensor(v) for v in values) + return seen_ids[obj_id] + elif istype(obj, (str, int, float, type(None), bool)): + seen_ids[obj_id] = False + return seen_ids[obj_id] + elif is_namedtuple(obj) and hasattr(obj, "_fields"): + seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields) + return seen_ids[obj_id] + else: + # if config.debug: + # print( + # f"Assuming that object of type {type(obj)} does not have a tensor" + # ) + return False + + # Check if the passed arguments are of type Tensor + for value in frame.f_locals.values(): + if has_tensor(value): + return True + + log.debug( + "skipping because no torch.* %s \ + %s %s", + frame.f_code.co_name, + frame.f_code.co_filename, + frame.f_code.co_firstlineno, + ) + + return False + + +def exception_handler(e, code, frame=None, export=False): + record_filename = None + if hasattr(e, "exec_record"): + record_filename = gen_record_file_name(e, code) + write_record_to_file(record_filename, e.exec_record) + e.record_filename = record_filename + + augment_exc_message(e, export=export) + + +FRAME_COUNTER = 0 +FRAME_COMPILE_COUNTER: typing.Counter[int] = collections.Counter() + + +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, +): + """Fully convert a frame into an FX graph""" + reset_graph_break_dup_checker() + + def _convert_frame_assert( + frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0 + ): + increment_frame() + + code = frame.f_code + + cache_size = compute_cache_size(frame, cache_entry) + recompile_reasons = None + if is_recompilation(cache_size): + recompile_reasons = get_and_maybe_log_recompilation_reason( + cache_entry, frame + ) + + input_codes.add(code) + if code in output_codes: + return None + if ( + os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") + and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name + ): + return None + if code.co_name == "" and code.co_filename.endswith( + ( + "transformers/file_utils.py", + "transformers/utils/generic.py", + "diffusers/utils/outputs.py", + ) + ): + # not needed, but cleans up torchbench error stats + return None + if code.co_name == "__setattr__": + # setattr could be tricky to handle generally, + # but also not likely useful to compile- skip the whole frame + return None + if code.co_name == "__init__" and code.co_filename.startswith( + os.path.dirname(torch.optim.__file__) + ): + # optimizer support is still incomplete see + # test_state_dict in test/dynamo/test_optimizers.py + return None + + # Check if the frame is generated by an exec builtin call + # TODO - Running exec generated frame seems propagates f_globals to the + # next frames. + if code.co_name == "" and code.co_filename == "": + return None + + if ( + code.co_name == "" + and code.co_filename == "" + and not bool(frame.f_builtins) + ): + # namedtuple subclass constructor. Empty builtins cause issue with + # len keyword in LIST_LEN guard. + return None + + if is_generator(code): + unimplemented("generator") + exceeded, limit_type = exceeds_cache_size_limit(cache_size) + if exceeded: + + def format_func_info(code): + return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})" + + def format_guard_failures(): + assert recompile_reasons, "TODO(whc) any other recompile reasons?" 
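# --- illustrative sketch (not part of the vendored file) ---------------------
# When exceeds_cache_size_limit() trips in the code above, Dynamo gives up on
# the frame and it runs eagerly. Two hedged knobs that are handy while
# debugging such cache blowups:

import torch
import torch._dynamo

torch._dynamo.config.cache_size_limit = 16  # tolerate more guarded variants per code object
torch._dynamo.reset()                       # drop existing cache entries and recompile from scratch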
+ return recompile_reasons[-1] + + log.warning( + "torch._dynamo hit config.%s (%s)\n" + " function: %s\n" + " last reason: %s\n" + 'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n' + "To diagnose recompilation issues, see %s.", + limit_type, + getattr(config, limit_type), + format_func_info(code), + format_guard_failures(), + troubleshooting_url, + ) + unimplemented(f"{limit_type} reached") + + if not has_tensor_in_frame(frame): + return None + + global initial_global_state + initial_global_state = GlobalStateGuard() + + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compile_id = CompileId(frame_id, frame_compile_id) + + signpost_event( + "dynamo", + "_convert_frame_assert._compile", + { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + }, + ) + + return _compile( + frame.f_code, + frame.f_globals, + frame.f_locals, + frame.f_builtins, + compiler_fn, + one_graph, + export, + export_constraints, + hooks, + cache_size, + frame, + frame_state=frame_state, + compile_id=compile_id, + skip=skip + 1, + ) + + _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + + def _clone_with_backend(backend): + return convert_frame_assert(backend, one_graph, export, export_constraints) + + _convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined] + return _convert_frame_assert + + +from collections import OrderedDict + +from torch.utils.hooks import RemovableHandle + +# we have to use `OrderedDict` to make `RemovableHandle` work. +_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict() + + +def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: + """Register hooks for bytecode generated by Dynamo. The hook can do some + logging, as well as return a new code object to be used. Please refer + to `BytecodeHook` for the hook signature. 
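# --- illustrative usage sketch (not part of the vendored file) ---------------
# register_bytecode_hook() above takes a callable receiving the original and
# the Dynamo-transformed code objects; returning None keeps Dynamo's output,
# while returning a code object substitutes it. A hedged, logging-only example:

import torch
from torch._dynamo.convert_frame import register_bytecode_hook

def log_bytecode_hook(code, out_code):
    print(f"dynamo rewrote {code.co_name} ({code.co_filename}:{code.co_firstlineno})")
    return None  # do not replace the generated code

handle = register_bytecode_hook(log_bytecode_hook)
torch.compile(lambda x: x + 1, backend="eager")(torch.randn(3))
handle.remove()  # RemovableHandle unregisters the hook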
+ """ + handle = RemovableHandle(_bytecode_hooks) + _bytecode_hooks[handle.id] = hook + return handle + + +@_use_lazy_graph_module(config.use_lazy_graph_module) +@maybe_cprofile +def _compile( + code: types.CodeType, + globals: Dict[str, object], + locals: Dict[str, object], + builtins: Dict[str, object], + compiler_fn: CompilerFn, + one_graph: bool, + export: bool, + export_constraints, + hooks: Hooks, + cache_size: CacheSizeRelevantForFrame, + frame: Optional[types.FrameType] = None, + frame_state=None, + compile_id=None, + *, + skip: int = 0, +) -> Optional[GuardedCode]: + from torch.fx.experimental.validator import ( + bisect, + BisectValidationException, + translation_validation_enabled, + ValidationException, + ) + + output: Optional[OutputGraph] = None + tracer: Optional[InstructionTranslator] = None + # This is shared across restarts + mutated_closure_cell_contents: Set[str] = set() + speculation_log = SpeculationLog() + torch._dynamo.callback_handler.run_start_callbacks() + + @preserve_global_state + def transform(instructions, code_options): + nonlocal output + nonlocal tracer + speculation_log.restart() + tracer = InstructionTranslator( + instructions, + code, + locals, + globals, + builtins, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + mutated_closure_cell_contents, + frame_state=frame_state, + speculation_log=speculation_log, + ) + + try: + with tracing(tracer.output.tracing_context), tracer.set_current_tx(): + tracer.run() + except exc.UnspecializeRestartAnalysis: + speculation_log.clear() + raise + except (exc.SpeculationRestartAnalysis, exc.SkipFrame): + raise + except Exception: + if translation_validation_enabled(): + bisect(tracer.output.shape_env) + raise + finally: + tracer.output.call_cleanup_hooks() + + output = tracer.output + assert output is not None + assert output.output_instructions + instructions[:] = output.output_instructions + code_options.update(output.code_options) + + if config.dead_code_elimination: + propagate_inst_exn_table_entries(instructions) + check_inst_exn_tab_entries_valid(instructions) + instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) + + @dynamo_timed(phase_name="entire_frame_compile") + def compile_inner( + code: types.CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[List[Instruction], Dict[str, Any]], Any], + ) -> Optional[GuardedCode]: + nonlocal output + for attempt in itertools.count(): + CompileContext.get().attempt = attempt + try: + out_code = transform_code_object(code, transform) + break + except exc.RestartAnalysis as e: + log.info( + "Restarting analysis due to %s", + LazyString(format_traceback_short, e.__traceback__), + ) + if attempt > 100: + unimplemented("100+ RestartAnalysis() calls") + except exc.SkipFrame as e: + log.debug( + "Skipping frame %s %s \ + %s %s", + e, + code.co_name, + code.co_filename, + code.co_firstlineno, + ) + if one_graph: + log.debug("No graph captured with one_graph=True") + return None + + def log_bytecode(prefix, name, filename, line_no, code): + if bytecode_log.isEnabledFor(logging.DEBUG): + bytecode_log.debug( + format_bytecode(prefix, name, filename, line_no, code) + ) + + log_bytecode( + "ORIGINAL BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + code, + ) + log_bytecode( + "MODIFIED BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + out_code, # type: ignore[possibly-undefined] + ) + + for hook in _bytecode_hooks.values(): + hook_output = hook(code, out_code) + if hook_output is 
not None: + out_code = hook_output + + orig_code_map[out_code] = code + output_codes.add(out_code) + + assert output is not None + + # Tests for new code objects. + # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c + # Only test once the code object is created. + # They are not tested during runtime. + + def count_args(code): + import inspect + + return ( + code.co_argcount + + code.co_kwonlyargcount + + bool(code.co_flags & inspect.CO_VARARGS) + + bool(code.co_flags & inspect.CO_VARKEYWORDS) + ) + + total_argcount_old = count_args(code) + total_argcount_new = count_args(out_code) + msg = "arg mismatch: " + msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, " + msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}" + assert ( + code.co_varnames[:total_argcount_old] + == out_code.co_varnames[:total_argcount_new] + ), msg + + msg = "free var mismatch: " + msg += f"old code object has free var {code.co_freevars}, " + msg += f"new code object has free var {out_code.co_freevars}" + assert code.co_freevars == out_code.co_freevars, msg + + msg = "cell var mismatch: " + msg += f"old code object has cell var {code.co_cellvars}, " + msg += f"new code object has cell var {out_code.co_cellvars}" + assert code.co_cellvars == out_code.co_cellvars, msg + + # Skipping Dynamo on a frame without any extracted graph. + # This does not affect eager functionality. But this is necessary + # for export for cases where Dynamo-reconstructed bytecode can create + # new function frames, confusing export in thinking that there + # are extra graphs now. + + if output.export and output.is_empty_graph(): + return None + + assert output.guards is not None + CleanupManager.instance[out_code] = output.cleanups + check_fn = CheckFunctionManager( + output, + hooks.guard_fail_fn if hooks else None, + ) + + guarded_code = GuardedCode(out_code, check_fn.check_fn) + + if not output.is_empty_graph() and hooks.guard_export_fn is not None: + # We should not run the guard_export_fn when Dynamo does not + # generate any graph. This can happen in export when TorchDynamo + # generated bytecode has some reconstruction logic for mutated + # variables which can trigger TorchDynamo on the children frames but + # they are benign and do not generate any new graphs. 
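+            # Hand the accumulated guards to the user-provided export hook.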
+ hooks.guard_export_fn(output.guards) + + return guarded_code + + with compile_context(CompileContext(compile_id)): + log.debug( + "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", + code.co_name, + code.co_filename, + code.co_firstlineno, + skip + 2, + # -2: omit current frame, omit contextlib decorator + "".join(traceback.format_list(traceback.extract_stack()[: -2 - skip])), + ) + # -4: -2 as above, plus trace_structured frames + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": structured.from_traceback( + traceback.extract_stack()[: -4 - skip] + ) + }, + ) + start_time = time.time() + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None + try: + guarded_code = compile_inner(code, one_graph, hooks, transform) + return guarded_code + except ( + Unsupported, + TorchRuntimeError, + BackendCompilerFailed, + AssertionError, + ConstraintViolationError, + GuardOnDataDependentSymNode, + ValidationException, + UncapturedHigherOrderOpError, + BisectValidationException, + ) as e: + fail_type = str(type(e)) + fail_reason = str(e) + exception_handler(e, code, frame, export=export) + if e.innermost_user_frame_summary is not None: # type: ignore[union-attr] + fail_user_frame_filename = e.innermost_user_frame_summary.filename # type: ignore[union-attr] + fail_user_frame_lineno = e.innermost_user_frame_summary.lineno # type: ignore[union-attr] + raise + except Exception as e: + fail_type = str(type(e)) + fail_reason = str(e) + exception_handler(e, code, frame, export=export) + if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined] + fail_user_frame_filename = e.innermost_user_frame_summary.filename # type: ignore[attr-defined] + fail_user_frame_lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined] + raise InternalTorchDynamoError(str(e)).with_traceback( + e.__traceback__ + ) from None + finally: + if tracer: + tracer.output.local_scope = {} + + from .utils import curr_frame + + frame_key = str(curr_frame) + if ( + fail_reason is None + and output is not None + and frame_key in frame_phase_timing + ): + guard_count = len(output.guards) + shape_env_guard_count = len(output.shape_env.guards) + graph_op_count = output.count_calls() + graph_node_count = len(output.graph.nodes) + graph_input_count = len(output.placeholders) + entire_frame_compile_time = frame_phase_timing[frame_key].get( + "entire_frame_compile", None + ) + backend_compile_time = frame_phase_timing[frame_key].get( + "backend_compile", None + ) + inductor_compile_time = frame_phase_timing[frame_key].get( + "inductor_compile", None + ) + code_gen_time = frame_phase_timing[frame_key].get("code_gen", None) + non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} + compliant_custom_ops = { + op.__qualname__ for op in output.compliant_custom_ops + } + else: + guard_count = None + shape_env_guard_count = None + graph_op_count = None + graph_node_count = None + graph_input_count = None + entire_frame_compile_time = None + backend_compile_time = None + inductor_compile_time = None + code_gen_time = None + non_compliant_ops = set({}) + compliant_custom_ops = set({}) + metrics = CompilationMetrics( + frame_key, + code.co_name, + code.co_filename, + code.co_firstlineno, + cache_size.num_cache_entries_with_same_id_matched_objs, + cache_size.num_cache_entries, + guard_count, + shape_env_guard_count, + graph_op_count, + graph_node_count, + 
graph_input_count, + start_time, + entire_frame_compile_time, + backend_compile_time, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + fail_user_frame_filename, + fail_user_frame_lineno, + non_compliant_ops, + compliant_custom_ops, + ) + record_compilation_metrics(metrics) + torch._dynamo.callback_handler.run_end_callbacks() + + +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + + def _convert_frame( + frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0 + ): + counters["frames"]["total"] += 1 + try: + result = inner_convert( + frame, cache_entry, hooks, frame_state, skip=skip + 1 + ) + counters["frames"]["ok"] += 1 + return result + except Exception as e: + # These two exception types are "soft" failure, in the sense that + # we know this is due to something we didn't implement all the + # way, scare the user less about it. That being said, if you + # are trying to understand why a graph break happened, it's still + # important to have this information, so offer it. + # + # NB: NotImplementedError used to be on this list, but actually + # it is impossible for it to reach here, as it is converted into + # InternalTorchDynamoError. This behavior seemed reasonable + # to me (ezyang, Aug 2023) so I kept it, but maybe at some point + # someone wanted these to also get suppressed. If so, you'll + # need to make these exceptions not get wrapped + + # We intentionally don't want to suppress error here. + if isinstance(e, UncapturedHigherOrderOpError): + raise + + soft_fail = isinstance(e, Unsupported) + if not config.suppress_errors and not soft_fail: + raise + + # Suppress the error. NB: It's very important to do the + # suppression logging HERE, where the actual suppression + # happens. Previously it was somewhere else and so it was + # possible to accidentally not log at all. 
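+            # Build the user-facing error message (including the replay-record path, if one was attached to the exception) before logging it below.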
+            record_filename = getattr(e, "record_filename", None)
+            code = frame.f_code
+            error_msg = format_error_msg(e, code, record_filename, frame)
+
+            if soft_fail:
+                log.info(error_msg, exc_info=True)
+            else:
+                log.warning(error_msg, exc_info=True)
+            return None
+
+    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+    _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks)  # type: ignore[attr-defined]
+    return _convert_frame
+
+
+# TODO mlazos: add support for same args, or record them
+def replay(filename):
+    from .backends.debugging import eager
+
+    original_replay_val = config.replay_record_enabled
+    config.replay_record_enabled = False
+    with open(filename, "rb") as in_file:
+        record = ExecutionRecord.load(in_file)
+    record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
+
+    try:
+        _compile(
+            record.code,
+            record.globals,
+            record.locals,
+            record.builtins,
+            compiler_fn=eager,
+            one_graph=False,
+            export=False,
+            export_constraints=None,
+            hooks=Hooks(),
+            cache_size=CacheSizeRelevantForFrame(0, 0),
+            frame=None,
+            frame_state={},
+        )
+    finally:
+        config.replay_record_enabled = original_replay_val
+
+
+def first_real_inst_idx(code):
+    if sys.version_info < (3, 11):
+        return 0
+    for inst in dis.get_instructions(code):
+        if inst.opname == "RESUME":
+            return inst.offset // 2
+    raise RuntimeError("RESUME instruction not found in code")
+
+
+def catch_errors_wrapper(callback, hooks: Hooks):
+    @functools.wraps(callback)
+    def catch_errors(frame, cache_entry, frame_state):
+        assert frame_state is not None
+
+        is_skipfile = trace_rules.check(frame.f_code)
+        if (
+            # TODO: the first condition is not covered by any test
+            frame.f_lasti >= first_real_inst_idx(frame.f_code)
+            or is_skipfile
+            or config.disable
+        ):
+            if log.isEnabledFor(logging.DEBUG):
+                skip_reason = (
+                    "traced frame already"
+                    if frame.f_lasti >= first_real_inst_idx(frame.f_code)
+                    else "in skipfiles"
+                    if trace_rules.check(frame.f_code)
+                    else "dynamo tracing is disabled"
+                )
+                if not is_skipfile or config.verbose:
+                    log.debug(
+                        "skipping: %s (reason: %s, file: %s)",
+                        frame.f_code.co_name,
+                        skip_reason,
+                        frame.f_code.co_filename,
+                    )
+            return None
+        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
+            # namedtuple constructor
+            return None
+        if config._get_optimize_ddp_mode() == "ddp_optimizer":
+            ddp_module = DistributedDataParallel._get_active_ddp_module()
+            if ddp_module:
+                with compile_lock:
+                    from torch._dynamo.backends.distributed import DDPOptimizer
+
+                    ddp_optimizer = DDPOptimizer(
+                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
+                        backend_compile_fn=callback._torchdynamo_orig_callable,
+                    )
+                    assert hasattr(
+                        callback, "_clone_with_backend"
+                    ), "DDPOptimizer only supports callback fns that know how to clone themselves."
+ hijacked_callback = callback._clone_with_backend( + ddp_optimizer.compile_fn, + ) + return hijacked_callback(frame, cache_entry, hooks, frame_state) + + with compile_lock, _disable_current_modes(): + # skip=1: skip this frame + return callback(frame, cache_entry, hooks, frame_state, skip=1) + + catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] + return catch_errors diff --git a/MLPY/Lib/site-packages/torch/_dynamo/current_scope_id.py b/MLPY/Lib/site-packages/torch/_dynamo/current_scope_id.py new file mode 100644 index 0000000000000000000000000000000000000000..7a619bcbc9214c5b60788c05ccd45a3d2c1443f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/current_scope_id.py @@ -0,0 +1,23 @@ +import contextlib +import threading + +# Global variable to identify which SubgraphTracer we are in. +# It is sometimes difficult to find an InstructionTranslator to use. +_current_scope_id = threading.local() + + +def current_scope_id(): + global _current_scope_id + if not hasattr(_current_scope_id, "value"): + _current_scope_id.value = 1 + return _current_scope_id.value + + +@contextlib.contextmanager +def enter_new_scope(): + global _current_scope_id + try: + _current_scope_id.value = current_scope_id() + 1 + yield + finally: + _current_scope_id.value = current_scope_id() - 1 diff --git a/MLPY/Lib/site-packages/torch/_dynamo/debug_utils.py b/MLPY/Lib/site-packages/torch/_dynamo/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23ef196751642b159cfe15b63a55a57ffcbaa7cd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/debug_utils.py @@ -0,0 +1,802 @@ +# mypy: disable-error-code="method-assign" + +import copy +import functools +import getpass +import inspect +import itertools +import logging +import os +import re +import subprocess +import tempfile +import textwrap +from collections import Counter +from importlib import import_module +from typing import Any, Callable, Dict, List, Optional, TypeVar + +import torch +import torch._prims_common as utils +import torch._subclasses.meta_utils +from torch import Tensor + +from torch._dynamo.testing import rand_strided +from torch._prims_common import is_float_dtype +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._content_store import ContentStoreReader, ContentStoreWriter + +from . 
import config +from .utils import clone_inputs, get_debug_dir + +log = logging.getLogger(__name__) + +T = TypeVar("T") + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +if use_buck: + import libfb.py.build_info + + +extra_deps = [] +extra_imports = "" +if use_buck: + extra_deps = [ + "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", + "//caffe2/torch/fb/sparsenn:sparsenn_operators", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops", + ] + cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined] + extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) + + +BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] + + +class BuckTargetWriter: + def __init__(self, filename): + self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) + self.target = self.py_file.replace(".py", "") + + # Get main_module path from fbcode + self.path = f'{self.subdir.replace("/", ".")}.{self.target}' + self.path = self.path[self.path.find("fbcode.") :] + self.path = self.path[7:] + + # Get cmd line path + tmp = self.subdir + tmp = tmp[tmp.find("fbcode/") :][7:] + self.cmd_line_path = f"//{tmp}:{self.target}" + + def build(self): + extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) + return textwrap.dedent( + f""" +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name="{self.target}", + srcs = ["{self.py_file}"], + compile = False, + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch", + "//triton:triton", + "{cur_target}", + ], + cpp_deps = [ +{extra_cpp_deps} + ], + main_module = "{self.path}", + par_style = "xar", +) +""" + ) + + def write(self, print_msg=True): + target_file = os.path.join(self.subdir, "TARGETS") + with open(target_file, "w") as fd: + fd.write(self.build()) + # log.warning("Wrote isolation TARGETS file at %s", target_file) + cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] + if print_msg: + log.warning( + "Found an example that reproduces the error. Run this cmd to repro - %s", + " ".join(cmd_split), + ) + return cmd_split + + +def minifier_dir(): + path = os.path.join(get_debug_dir(), "minifier") + if path is None: + path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + return path + + +MAX_CONSTANT_NUMEL_INLINE = 4 + + +class NNModuleToString: + safe_reprs = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.Softmax, + torch.nn.ReLU, + torch.nn.GELU, + torch.nn.Identity, + torch.nn.MaxPool2d, + torch.nn.Embedding, + torch.nn.Tanh, + torch.nn.ConvTranspose1d, + torch.nn.GLU, + torch.nn.LSTM, + torch.nn.Flatten, + torch.nn.AdaptiveAvgPool2d, + ] + + @staticmethod + def can_convert_to_string(gm): + cant_convert = set() + for _, module in gm.named_children(): + if type(module) not in NNModuleToString.safe_reprs: + cant_convert.add(module) + + if len(cant_convert) > 0: + log.warning("We have not tested reprs of some modules - %s", cant_convert) + # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct. 
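+        # Report the module as convertible anyway; untested child-module reprs were already flagged by the warning above.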
+ return True + + @staticmethod + def convert(gm): + from torch.nn.modules.module import _addindent + + tab = " " * 4 + + model_str = textwrap.dedent( + """ + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + """ + ) + + for module_name, module in gm.named_children(): + module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in gm._buffers.items(): + if buffer is None: + continue + # Serialize full data for small buffers + if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE: + from torch._tensor_str import PRINT_OPTS + + assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE + tensor_str = repr(buffer) + elif torch.is_floating_point(buffer): + tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" + else: + tensor_str = ( + f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" + ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" + model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" + + for param_name, param in gm._parameters.items(): + if param is None: + continue + maybe_device = "" + if param.is_cuda: + maybe_device = ', device="cuda"' + tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))" + model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" + + # TODO - Keep this code for now. But, I don't think we will need this. + # attrs = dir(gm) + # for attr in attrs: + # if "_tensor_constant" in attr: + # val = getattr(gm, attr) + # model_str += f" {attr} = {val!r}\n" + + model_str += f"{_addindent(gm.code, 4)}\n" + return model_str + + +@functools.lru_cache(None) # subprocess is expensive +def _cuda_system_info_comment(): + if not torch.cuda.is_available(): + return "# torch.cuda.is_available()==False, no GPU info collected\n" + + model_str = "# CUDA Info: \n" + try: + cuda_version_out = subprocess.check_output(["nvcc", "--version"]) + cuda_version_lines = cuda_version_out.decode().split("\n") + comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]]) + model_str += f"{comment}\n" + except (FileNotFoundError, subprocess.CalledProcessError): + model_str += "# nvcc not found\n" + + gpu_names = Counter( + torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) + ) + + model_str += "# GPU Hardware Info: \n" + for name, count in gpu_names.items(): + model_str += f"# {name} : {count} \n" + model_str += "\n" + return model_str + + +def generate_config_string(*, stable_output=False): + import torch._functorch.config + import torch._inductor.config + + if stable_output: + return "# config omitted due to stable_output=True" + + experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined] + return f"""\ +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +{torch._functorch.config.codegen_config()} +{experimental_config} +""" + + +def get_minifier_repro_path(): + return os.path.join(minifier_dir(), "minifier_launcher.py") + + +def helper_for_dump_minify(contents): + minified_repro_path = 
get_minifier_repro_path() + log.warning("Writing minified repro to:\n%s", minified_repro_path) + + if use_buck: + BuckTargetWriter(minified_repro_path).write() + try: + with open(minified_repro_path, "w") as fd: + fd.write(contents) + + except OSError as e: + log.exception(e) + raise NotImplementedError("Could not write to {minified_repro_path}") from e + + +class AccuracyError(Exception): + pass + + +def clone_inputs_retaining_gradness(example_inputs): + """ + This clone inputs is different from utils clone_input. In case of minifier, + all the tensors are leaf tensors while creating a new graph. So, we set the + requires_grad field w/o checking the leafness of the tensor. + """ + cloned_inputs = clone_inputs(example_inputs) + for idx in range(len(example_inputs)): + if isinstance(cloned_inputs[idx], torch.Tensor): + cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) + return cloned_inputs + + +def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): + """ + Runs a forward and possibly backward iteration for a given mod and args. + + When disable_clone is True, we will use args as-is without cloning. + This is higher fidelity but we may destroy the args in the process. + """ + from torch._functorch.aot_autograd import make_boxed_func + + from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass + + gm = copy.deepcopy(gm) + if not disable_clone: + args = clone_inputs_retaining_gradness(args) + + if hasattr(gm, "zero_grad"): + gm.zero_grad(True) + + # TorchInductor returned callable expects lists. So, boxing the call. + orig_named_parameters = getattr(gm, "named_parameters", None) + orig_named_buffers = getattr(gm, "named_buffers", None) + if not hasattr(gm, "_boxed_call") and ( + orig_named_parameters is not None or orig_named_buffers is not None + ): + gm = make_boxed_func(gm) + if orig_named_parameters is not None: + gm.named_parameters = orig_named_parameters + if orig_named_buffers is not None: + gm.named_buffers = orig_named_buffers + + out = gm(args) + if only_fwd: + return out + if requires_bwd_pass(out): + loss = reduce_to_scalar_loss(out) + loss.backward() + return collect_results(gm, out, None, args) + + +def same_two_models( + gm, + opt_gm, + example_inputs, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + """ + Check two models have same accuracy. + + require_fp64: if True, raise an error if we unable to calculate the fp64 reference + ignore_non_fp: if True, do not compare outputs which are not floating point. 
This + is mostly useful for the minifier (which wants to avoid quantizing floating point + error into integer/boolean error) + """ + from .eval_frame import OptimizedModule + from .testing import ( + named_buffers_for_optimized_module, + named_parameters_for_optimized_module, + ) + from .utils import same + + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + gm.named_buffers = named_buffers_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm) + + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) + + fp64_ref = None + if config.same_two_models_use_fp64: + try: + fp64_model, fp64_examples = cast_to_fp64( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) + except Exception: + if require_fp64: + raise RuntimeError("Could not generate fp64 outputs") # noqa: TRY200 + log.warning("Could not generate fp64 outputs") + + try: + res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) + except Exception as e: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return True. + log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph." + ) + return True + + passing = same( + ref, + res, + fp64_ref, + tol=config.repro_tolerance, + equal_nan=True, + ignore_non_fp=ignore_non_fp, + ) + return passing + + +def cast_dtype_args_to_fp64(model): + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.prims.convert_element_type.default + ): + assert len(node.args) == 2 + if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: + node.args = (node.args[0], torch.float64) + if node.op == "call_function": + dtype = node.kwargs.get("dtype") + if dtype is not None and is_float_dtype(dtype): + new_kwargs = dict(node.kwargs) + new_kwargs["dtype"] = torch.float64 + node.kwargs = new_kwargs + + model.graph.lint() + model.recompile() + return model + + +def cast_to(dtype, model, inputs): + from torch.utils._pytree import tree_map + + model = model.to(dtype) + if dtype == torch.float64: + # If casting to fp64 for accuracy comparison, we need to + # replace dtype arguments embedded in the graph with fp64 + model = cast_dtype_args_to_fp64(model) + + inputs = tree_map( + lambda x: x.to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else x, + inputs, + ) + return model, inputs + + +def cast_to_fp64(model, inputs): + return cast_to(torch.float64, model, inputs) + + +def backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + try: + compiled_gm = compiler_fn( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + return not same_two_models( + gm, + compiled_gm, + example_inputs, + only_fwd, + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + except Exception as e: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return False. 
+ log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph" + ) + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO SUPPORT CODE +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# Helper functions for computing what the default values of tensor +# values should be. These all coincide with factory functions, e.g., torch.empty + + +def _stride_or_default( + stride: Optional["torch._prims_common.StrideType"], + *, + shape: "torch._prims_common.ShapeType", +) -> "torch._prims_common.StrideType": + return stride if stride is not None else utils.make_contiguous_strides_for(shape) + + +def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]: + return lambda x: x if x is not None else d + + +_dtype_or_default = _mk_defaulter(torch.float32) +_device_or_default = _mk_defaulter(torch.device("cpu")) +_storage_offset_or_default = _mk_defaulter(0) +_requires_grad_or_default = _mk_defaulter(False) +_is_leaf_or_default = _mk_defaulter(False) + + +class NopInputReader: + def __init__(self): + self.total = 0 + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + self.total += 1 + + def tensor(self, *args, **kwargs): + pass + + def symint(self, *args, **kwargs): + pass + + +# TODO: Support bundling the entire repro into a zip file for ease of +# transferring around +class InputReader: + def __init__(self, save_dir=None, *, pbar=None): + # If None, we will generate random data instead. It's important + # to natively support this use case as it will allow people to + # share repros without including the real data, if the problem + # reproduces even on random data. + if save_dir is None: + log.warning("no save_dir specified, will generate random data") + self.store = ContentStoreReader(save_dir) if save_dir is not None else None + self.args = [] + self.pbar = pbar + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + if self.pbar is not None: + self.pbar.update(1) + device = _device_or_default(device) + dtype_hint = _dtype_or_default(dtype_hint) + if self.store is not None and storage_hash is not None: + try: + storage = self.store.read_storage(storage_hash) + except FileNotFoundError: + pass + else: + if device != storage.device: + log.warning("device mismatch: %s != %s", device, storage.device) + # TODO: transfer it to the right device? But failing this + # way would be very mysterious! Would have been better + # not to store device in the serialized format... 
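+                # Use the storage read from disk as-is, even when its device differs from the requested one.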
+ return storage + log.warning("could not load %s, generating random data instead", storage_hash) + shape = (nbytes // dtype_hint.itemsize,) + stride = _stride_or_default(None, shape=shape) + return rand_strided(shape, stride, dtype_hint, device).untyped_storage() + + def tensor( + self, + storage, + shape, + stride=None, + *, + storage_offset=None, + dtype=None, + requires_grad=None, + is_leaf=None, + **metadata, + ): + stride = _stride_or_default(stride, shape=shape) + storage_offset = _storage_offset_or_default(storage_offset) + dtype = _dtype_or_default(dtype) + is_leaf = _is_leaf_or_default(is_leaf) + requires_grad = _requires_grad_or_default(requires_grad) + t = torch.tensor( + [], dtype=dtype, device=storage.device, requires_grad=requires_grad + ) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + if not is_leaf: + # Fake up some autograd history in a very naughty way + with torch.enable_grad(): + t = t.clone(memory_format=torch.preserve_format) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf + torch._utils.set_tensor_metadata(t, metadata) + self.args.append(t) + return t # for BC + + def symint(self, val): + self.args.append(val) + return val # for BC + + +# Here is our writer strategy: +# 1. We will stream all of the inputs to disk +# 2. You can now deterministically randomize the inputs, or reload +# the inputs from disk +# 3. You can YOLO run the script without the inputs, in which case +# we'll fill the inputs with random data and pray. This is the +# legacy behavior, but it's also useful if you want to find out +# if we're so broken even random inputs trigger it +# 4. We could offer an in process "check if the randomized thing +# works too" but this is delicate so we don't do it + + +class InputWriter: + def __init__(self, save_dir, *, stable_hash=False): + self._lines = [] + # TODO: consider ensuring tensor and storage counters line up? 
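+        # Monotonic counter used to name the emitted buffers (buf0, buf1, ...) in the generated load_args() lines.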
+ self.storage_counter = itertools.count() + self.save_dir = save_dir + self.store = ( + ContentStoreWriter(save_dir, stable_hash=stable_hash) + if save_dir is not None + else None + ) + self.seen_storages = {} + + def lines(self): + r = [ + "def load_args(reader):", + ] + r.extend(f" {l}" for l in self._lines) + # In case we need to change the internal format of load_args + # in an FC-breaking way + r.append("load_args._version = 0") + return r + + # Storages are untyped, but we need to initialize them with data if + # we don't have the real data, so we give a hint saying what kind + # of initialization may be appropriate + # + # If we had a FakeTensor, device_hint tells us what device should be + def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: + ws = StorageWeakRef(untyped_storage) + v = self.seen_storages.get(ws) + if v is not None: + return v + v = f"buf{next(self.storage_counter)}" + maybe_dtype_hint = "" + if _dtype_or_default(None) != _dtype_or_default(dtype_hint): + maybe_dtype_hint = f", dtype_hint={dtype_hint!r}" + # TODO: being optional on device is kind of pointless as the default + # is CPU but most repros we care about are CUDA + maybe_device = "" + device = untyped_storage.device + if device.type == "meta": + assert device_hint is not None + device = device_hint + if _device_or_default(None) != device: + maybe_device = f", device={device!r}" + nbytes = untyped_storage.nbytes() + storage_hash = None + if self.store is not None and untyped_storage.device.type != "meta": + storage_hash = self.store.write_storage(untyped_storage) + self._lines.append( + f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})" + ) + self.seen_storages[ws] = v + return v + + def tensor(self, name, t) -> None: + storage = self.storage( + t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device + ) + args = [] + # NB: this is positional, must come first + if _stride_or_default(None, shape=t.shape) != t.stride(): + args.append(str(tuple(t.stride()))) + if _dtype_or_default(None) != t.dtype: + args.append(f"dtype={t.dtype!r}") + if _storage_offset_or_default(None) != t.storage_offset(): + args.append(f"storage_offset={t.storage_offset()!r}") + tensor_metadata = torch._utils.get_tensor_metadata(t) + if tensor_metadata: + args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items()) + if _requires_grad_or_default(None) != t.requires_grad: + args.append(f"requires_grad={t.requires_grad!r}") + is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t) + if _is_leaf_or_default(None) != is_leaf: + args.append(f"is_leaf={is_leaf!r}") + self._lines.append( + "reader.tensor(" + + ", ".join([storage, str(tuple(t.shape)), *args]) + + f") # {name}" + ) + + # TODO: this doesn't actually symint atm + def symint(self, name, val) -> None: + if isinstance(val, torch.SymInt): + val = val.node.hint + self._lines.append(f"reader.symint({val!r}) # {name}") + + +def aot_graph_input_parser( + func: Callable[[List[Tensor]], List[Tensor]], + device: str = "cuda", + sym_shapes: Optional[Dict[str, int]] = None, + default_sym_shape: Optional[int] = None, +) -> Dict[str, Any]: + """ + Takes in a function which has been printed with print_readable() and constructs kwargs to run it. + + Handles Tensor inputs, Symints, and a graph module which might have tensor constants. 
+ + Consider a function `forward` defined as follows: + + def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",): + _tensor_constant0: "i64[4190]" = self._tensor_constant0 + # Further implementation + + kwargs = aot_graph_input_parser(forward) + forward(**kwargs) + """ + + from torch.fx.graph import dtype_abbrs + + dtype_map = {value: key for key, value in dtype_abbrs.items()} + dtype_pattern = "|".join(dtype_abbrs.values()) + + # Extracting the source code from the function + source = inspect.getsource(func) + + # Regular expressions + tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)" + tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]" + sym_shape_regex = r"Sym\((s\d+)\)" + + class TensorContainer: + "Container for tensors as attributes" + pass + + # Dictionary for tensors from annotations + kwargs: Dict[str, Any] = {} + + sym_shapes = sym_shapes or {} + + def get_sym_int(symint): + torch._check( + symint in sym_shapes or default_sym_shape is not None, + lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", + ) + return sym_shapes.get(symint, default_sym_shape) + + def gen_tensor(shape, dtype) -> Tensor: + # Resolve symbolic shapes to concrete values + resolved_shape = [] + dynamic_dims = [] + for i, dim in enumerate(shape): + dim = dim.strip() + if "s" in dim: + s = get_sym_int(dim) + resolved_shape.append(s) + dynamic_dims.append(i) + else: + resolved_shape.append(int(dim)) + + constructor = torch.randn if dtype.is_floating_point else torch.zeros + out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg] + for d in dynamic_dims: + torch._dynamo.mark_dynamic(out, d) + return out + + # Parse function annotations for tensor generation + annotations = func.__annotations__ + for param, annotation in annotations.items(): + # Skip 'return' annotation + if param == "return": + continue + + match = re.search(tensor_regex, annotation) + if match: + data_type, shape_str = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + kwargs[param] = gen_tensor(shape, dtype) + + match = re.search(sym_shape_regex, annotation) + if match: + kwargs[param] = get_sym_int(match.group(1)) + + if "self" in inspect.signature(func).parameters: + container = TensorContainer() + kwargs["self"] = container + for match in re.finditer(tensor_assignment_regex, source): + attr_name, data_type, shape_str, _ = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + setattr(container, attr_name, gen_tensor(shape, dtype)) + + return kwargs diff --git a/MLPY/Lib/site-packages/torch/_dynamo/decorators.py b/MLPY/Lib/site-packages/torch/_dynamo/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..43a51da6151c6f7fce9e4b979fbb098e35e35ad9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/decorators.py @@ -0,0 +1,347 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from . 
import trace_rules, variables +from .comptime import comptime +from .eval_frame import DisableContext, innermost_fn, RunOnlyContext +from .exc import IncorrectUsage +from .external_utils import is_compiling + +if TYPE_CHECKING: + from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_eval_frame, + set_guard_error_hook, + skip_code, + unsupported, + ) +else: + for name in dir(torch._C._dynamo.eval_frame): + if name.startswith("__"): + continue + globals()[name] = getattr(torch._C._dynamo.eval_frame, name) + + +def run(fn=None): + """Don't do any dynamic compiles, just use prior optimizations""" + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return RunOnlyContext()(fn) + return RunOnlyContext() + + +def disable(fn=None, recursive=True): + """ + Decorator and context manager to disable TorchDynamo + + If recursive=True, Dynamo is completely skipped on the decorated function + frame as well as the recursively invoked functions. + + If recursive=False, Dynamo skips frames associated with the function code, + but still process recursively invoked frames. + """ + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return DisableContext()(fn) + return DisableContext() + else: + return skip(fn) + + +def skip(fn=None): + """ + Skip frames associated with the function code, but still process recursively + invoked frames + """ + if fn is None: + return skip + fn = innermost_fn(fn) + assert callable(fn) + skip_code(fn.__code__) + fn._torchdynamo_disable = True + return fn + + +def assume_constant_result(fn): + fn._dynamo_marked_constant = True + return fn + + +def allow_in_graph(fn): + """ + Customize which functions TorchDynamo will include in the generated + graph. Similar to `torch.fx.wrap()`. + :: + + torch._dynamo.allow_in_graph(my_custom_function) + + @torch._dynamo.optimize(...) + def fn(a): + x = torch.add(x, 1) + x = my_custom_function(x) + x = torch.add(x, 1) + return x + + fn(...) + + Will capture a single graph containing `my_custom_function()`. + """ + if isinstance(fn, (list, tuple)): + return [allow_in_graph(x) for x in fn] + assert callable(fn), "allow_in_graph expects a callable" + if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable: + trace_rules._disallowed_callable_ids.remove(id(fn)) + trace_rules._allowed_callable_ids.add(id(fn)) + return fn + + +def _disallow_in_graph_helper(throw_if_not_allowed): + def inner(fn): + if isinstance(fn, (list, tuple)): + return [disallow_in_graph(x) for x in fn] + assert callable(fn), "disallow_in_graph expects a callable" + if ( + throw_if_not_allowed + and trace_rules.lookup_callable(fn) + != variables.TorchInGraphFunctionVariable + and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable + ): + raise IncorrectUsage( + "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). " + "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph." + ) + trace_rules._allowed_callable_ids.remove(id(fn)) + trace_rules._disallowed_callable_ids.add(id(fn)) + return fn + + return inner + + +def disallow_in_graph(fn): + """ + Customize which functions TorchDynamo will exclude in the generated + graph and force a graph break on. + :: + + torch._dynamo.disallow_in_graph(torch.sub) + + @torch._dynamo.optimize(...) + def fn(a): + x = torch.add(x, 1) + x = torch.sub(x, 1) + x = torch.add(x, 1) + return x + + fn(...) 
+ + Will break the graph on `torch.sub`, and give two graphs each with a + single `torch.add()` op. + """ + return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn) + + +@_disallow_in_graph_helper(throw_if_not_allowed=False) +def graph_break(): + """Force a graph break""" + pass + + +def forbid_in_graph(fn): + """ + Customize which functions TorchDynamo will assert are not present while tracing. + + If you want a graph break on this function instead, use disallow_in_graph. + TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust + documentation would not be amiss. + """ + if isinstance(fn, (list, tuple)): + return [forbid_in_graph(x) for x in fn] + assert callable(fn), "forbid_in_graph applies only to callables" + fn._dynamo_forbidden = True + return fn + + +# Helper function to flatten a tensor subclass and apply a function to +# all inner tensors that match the outer dim. Used to reduce duplication +# across the various marking APIs. +def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): + assert is_traceable_wrapper_subclass(t) + + attrs, ctx = t.__tensor_flatten__() + for attr in attrs: + inner = getattr(t, attr) + if inner.dim() == t.dim(): + func(inner, *args, **kwargs) + + +@dataclass(frozen=True) +class _DimRange: + """ + This represents an dimension of a tensor and the corresponding + min and max values it can take. Don't create this + class directly; instead, use :func:`mark_dynamic`. + """ + + dim: int + min: int + max: int + + +@forbid_in_graph +def mark_dynamic(t, index, *, min=None, max=None): + """ + Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. + + [Note - on the state of mark_dynamic] + + The behavior of having a dynamic dimension on a tensor is governed by a few factors: + + 1) torch._dynamo.config dynamic_shapes True or False. + a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work. + a) dynamic_shapes=False - This config will raise an exception when used in conjunction with + mark_dynamic. We will eventually support this. + + 2) If the dimension is fully constrained - as in, it does not allow more than a single value + in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export), + we will raise an error + + 3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded + range of shapes, in eager we will pass it through, but export will raise an error. + + 4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made + before torch.compile. + + """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim( + mark_dynamic, t, index, min=min, max=max + ) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_dynamic_indices"): + t._dynamo_dynamic_indices = set() + t._dynamo_dynamic_range = set() + # TODO(voz): Should we bounds check? + t._dynamo_dynamic_indices.add(index) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_dynamic(t, i, min=min, max=max) + + +@forbid_in_graph +def maybe_mark_dynamic(t, index): + """ + Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this + dimension ends up getting specialized, don't error). 
+ """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_weak_dynamic_indices"): + t._dynamo_weak_dynamic_indices = set() + # TODO(voz): Should we bounds check? + t._dynamo_weak_dynamic_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + maybe_mark_dynamic(t, i) + + +def mark_static(t, index=None): + """ + Mark a tensor as having a static dim. + + This will prevent us from attempting to compile it dynamically + when dynamic=True; this can improve trace-time performance. + + This has lower precedence than mark_dynamic. + + Unlike mark_dynamic, this can be done inside a graph, in which case it + induces specialization on the tensor. + """ + if is_compiling(): + if index is None: + for s in t.size(): + comptime.force_static(s) + else: + comptime.force_static(t.size(index)) + return + + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_static() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_static_indices"): + t._dynamo_static_indices = set() + # TODO(voz): Should we bounds check? + t._dynamo_static_indices.add(index) + elif index is None: + for i in range(t.dim()): + mark_static(t, i) + else: + assert isinstance(index, (list, tuple)) + for i in index: + mark_static(t, i) + + +@forbid_in_graph +def mark_static_address(t, guard=True): + """ + Marks an input tensor whose data_ptr will not change across multiple calls + to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation + is not needed for this input. The data_ptr will be guarded if guard=True. Note: + Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called. + """ + if not isinstance(t, torch.Tensor): + raise TypeError(f"mark_static_address expects a tensor but recieved {type(t)}") + + if guard: + t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined] + else: + t._dynamo_static_input_type = "unguarded" # type: ignore[attr-defined] + + +# Note: this carefully avoids eagerly import einops. +# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2 +def _allow_in_graph_einops(): + import einops + + try: + # requires einops > 0.6.1, torch >= 2.0 + from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401 + _ops_were_registered_in_torchdynamo, + ) + + # einops > 0.6.1 will call the op registration logic as it is imported. 
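+        # The import above already registered the einops ops with dynamo, so there is nothing left to do here.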
+ pass + except ImportError: + # einops <= 0.6.1 + allow_in_graph(einops.rearrange) + allow_in_graph(einops.reduce) + if hasattr(einops, "repeat"): + allow_in_graph(einops.repeat) # available since einops 0.2.0 + if hasattr(einops, "einsum"): + allow_in_graph(einops.einsum) # available since einops 0.5.0 + if hasattr(einops, "pack"): + allow_in_graph(einops.pack) # available since einops 0.6.0 + if hasattr(einops, "unpack"): + allow_in_graph(einops.unpack) # available since einops 0.6.0 + + +trace_rules.add_module_init_func("einops", _allow_in_graph_einops) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/device_interface.py b/MLPY/Lib/site-packages/torch/_dynamo/device_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..62241e15711f75309983cb5fb605db324b2fa8f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/device_interface.py @@ -0,0 +1,199 @@ +import inspect +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union + +import torch +from torch._streambase import _EventBase, _StreamBase + +get_cuda_stream: Optional[Callable[[int], int]] +if torch.cuda._is_compiled(): + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream +else: + get_cuda_stream = None + +_device_t = Union[torch.device, str, int, None] + +# Recording the device properties in the main process but used in worker process. +caching_worker_device_properties: Dict[str, Any] = {} +caching_worker_current_devices: Dict[str, int] = {} + + +class DeviceInterfaceMeta(type): + def __new__(metacls, *args, **kwargs): + class_member = args[2] + if "Event" in class_member: + assert inspect.isclass(class_member["Event"]) and issubclass( + class_member["Event"], _EventBase + ), "DeviceInterface member Event should be inherit from _EventBase" + if "Stream" in class_member: + assert inspect.isclass(class_member["Stream"]) and issubclass( + class_member["Stream"], _StreamBase + ), "DeviceInterface member Stream should be inherit from _StreamBase" + return super().__new__(metacls, *args, **kwargs) + + +class DeviceInterface(metaclass=DeviceInterfaceMeta): + """ + This is a simple device runtime interface for Inductor. It enables custom + backends to be integrated with Inductor in a device-agnostic semantic. + """ + + class device: + def __new__(cls, device: _device_t): + raise NotImplementedError() + + class Worker: + """ + Worker API to query device properties that will work in multi processing + workers that cannot use the GPU APIs (due to processing fork() and + initialization time issues). Properties are recorded in the main process + before we fork the workers. 
+ """ + + @staticmethod + def set_device(device: int): + raise NotImplementedError() + + @staticmethod + def current_device() -> int: + raise NotImplementedError() + + @staticmethod + def get_device_properties(device: _device_t = None): + raise NotImplementedError() + + @staticmethod + def current_device(): + raise NotImplementedError() + + @staticmethod + def set_device(device: _device_t): + raise NotImplementedError() + + @staticmethod + def device_count(): + raise NotImplementedError() + + @staticmethod + def is_available() -> bool: + raise NotImplementedError() + + @staticmethod + def stream(stream: torch.Stream): + raise NotImplementedError() + + @staticmethod + def current_stream(): + raise NotImplementedError() + + @staticmethod + def set_stream(stream: torch.Stream): + raise NotImplementedError() + + @staticmethod + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): + raise NotImplementedError() + + @staticmethod + def get_raw_stream(): + raise NotImplementedError() + + @staticmethod + def synchronize(device: _device_t = None): + raise NotImplementedError() + + @staticmethod + def get_device_properties(device: _device_t = None): + raise NotImplementedError() + + @staticmethod + def get_compute_capability(device: _device_t = None): + raise NotImplementedError() + + +class CudaInterface(DeviceInterface): + device = torch.cuda.device + + # register Event and Stream class into the backend interface + # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase + Event = torch.cuda.Event + Stream = torch.cuda.Stream + + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["cuda"] = device + + @staticmethod + def current_device() -> int: + if "cuda" in caching_worker_current_devices: + return caching_worker_current_devices["cuda"] + return torch.cuda.current_device() + + @staticmethod + def get_device_properties(device: _device_t = None): + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "cuda" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = CudaInterface.Worker.current_device() + + if "cuda" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["cuda"] = device_prop + + return caching_worker_device_properties["cuda"][device] + + current_device = staticmethod(torch.cuda.current_device) + set_device = staticmethod(torch.cuda.set_device) + device_count = staticmethod(torch.cuda.device_count) + stream = staticmethod(torch.cuda.stream) # type: ignore[assignment] + current_stream = staticmethod(torch.cuda.current_stream) + set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment] + _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment] + synchronize = staticmethod(torch.cuda.synchronize) + get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] + get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[arg-type] + + # Can be mock patched by @patch decorator. 
+ @staticmethod + def is_available() -> bool: + return torch.cuda.is_available() + + @staticmethod + def get_compute_capability(device: _device_t = None): + major, min = torch.cuda.get_device_capability(device) + return major * 10 + min + + +device_interfaces: Dict[str, Type[DeviceInterface]] = {} + + +def register_interface_for_device( + device: Union[str, torch.device], device_interface: Type[DeviceInterface] +): + if isinstance(device, torch.device): + device = str(device) + device_interfaces[device] = device_interface + + +def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]: + if isinstance(device, torch.device): + device = str(device) + if device in device_interfaces: + return device_interfaces[device] + raise NotImplementedError(f"No interface for device {device}") + + +def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]: + return device_interfaces.items() + + +register_interface_for_device("cuda", CudaInterface) +for i in range(torch.cuda.device_count()): + register_interface_for_device(f"cuda:{i}", CudaInterface) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/eval_frame.py b/MLPY/Lib/site-packages/torch/_dynamo/eval_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..54772e5547291d68d7c16719e5adfa50c2c3b9dc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/eval_frame.py @@ -0,0 +1,1561 @@ +# mypy: disable-error-code="method-assign" + +""" +Functions in this file are responsible for modifying the eval frame +handler at RUNTIME. Therefore, all functions in this file are hot. +Functions that only execute at compile time should be placed +in torch._dynamo.convert_frame. +""" + +from __future__ import annotations + +import contextlib +import functools +import inspect +import logging +import os +import sys +import textwrap +import threading +import traceback +import types +import warnings +import weakref +from enum import Enum +from os.path import dirname, join +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union +from unittest.mock import patch + +import torch +import torch.fx +import torch.utils._pytree as pytree +import torch.utils.checkpoint +from torch import _guards +from torch._subclasses import fake_tensor +from torch._utils_internal import log_export_usage +from torch.export import Constraint +from torch.export.dynamic_shapes import _process_dynamic_shapes +from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + StatelessSymbolicContext, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from ..fx import GraphModule +from .backends.registry import CompilerFn, lookup_backend + +from .hooks import Hooks + +# see discussion at https://github.com/pytorch/pytorch/issues/120699 +reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401 +set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401 +set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401 +skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401 +unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401 + +from . 
import config, convert_frame, external_utils, trace_rules, utils +from .code_context import code_context +from .exc import CondOpArgsMismatchError, UserError, UserErrorType +from .mutation_guard import install_generation_tagging_init +from .types import CacheEntry, DynamoCallback +from .utils import common_constant_types, compile_times + +log = logging.getLogger(__name__) + +from torch._dispatch.python import enable_python_dispatcher + +always_optimize_code_objects = utils.ExactWeakKeyDictionary() +null_context = contextlib.nullcontext + + +import sympy + + +# See https://github.com/python/typing/pull/240 +class Unset(Enum): + token = 0 + + +unset = Unset.token + +guarded_backend_cache = threading.local() +cached_backends: Dict[int, CompilerFn] = {} + + +def check_current_backend(backend_obj_id: int): + """ + Called from guards to check if we need to recompile due to a backend change + """ + # TODO(jansel): we should move guarded_backend_cache to C++ + try: + if guarded_backend_cache.skip_backend_check_for_run_only_mode: + return True + except AttributeError: + # Go slightly faster next time + guarded_backend_cache.skip_backend_check_for_run_only_mode = False + try: + current_backend = guarded_backend_cache.current_backend + except AttributeError: + current_backend = None + return ( + # Avoid the dict lookup in case of exact same object + id(current_backend) == backend_obj_id + or current_backend == cached_backends.get(backend_obj_id, None) + ) + + +def _reset_guarded_backend_cache(): + global cached_backends + guarded_backend_cache.skip_backend_check_for_run_only_mode = False + guarded_backend_cache.current_backend = None + for backend in cached_backends.values(): + if hasattr(backend, "reset"): + backend.reset() + cached_backends.clear() + + +def backend_cache_manager(callback: DynamoCallback): + # callback is False for RunOnlyContext. RunOnlyContext is used + # as a way to re-use the previous compiled cache. + # We therefore skip the check and re-use whatever code that's already cached. + # Note: the cache that's actually used depends on the caching policy. + if callback is False: + + def change(): + try: + prev_skip = guarded_backend_cache.skip_backend_check_for_run_only_mode + except AttributeError: + prev_skip = False + guarded_backend_cache.skip_backend_check_for_run_only_mode = True + + def revert(): + guarded_backend_cache.skip_backend_check_for_run_only_mode = prev_skip + + return revert + + else: + backend = innermost_fn(callback) + + def change(): + cached_backends.setdefault(id(backend), backend) + try: + prev_backend = guarded_backend_cache.current_backend + except AttributeError: + prev_backend = None + guarded_backend_cache.current_backend = backend + + def revert(): + guarded_backend_cache.current_backend = prev_backend + + return revert + + return change + + +DONT_WRAP_FILES = { + # For tracing into fx modules + inspect.getsourcefile(GraphModule), + join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"), +} + + +def _debug_get_cache_entry_list( + code: Union[types.CodeType, Callable[..., Any]] +) -> List[CacheEntry]: + """ + Given a code object or a callable object, retrieve the cache entries + stored in this code. + """ + if callable(code): + code = code.__code__ + return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code) + + +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. 
+ """ + + _torchdynamo_orig_callable: Callable[..., Any] + get_compiler_config: Callable[[], Any] + + def __init__(self, mod: torch.nn.Module, dynamo_ctx): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + self.dynamo_ctx = dynamo_ctx + self._initialize() + + def _initialize(self): + # Do this stuff in constructor to lower overhead slightly + if isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check( + self._orig_mod.forward + ): + # This may be a torch.nn.* instance in trace_rules.py which + # won't trigger a frame evaluation workaround to add an extra + # frame we can capture + self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod)) + else: + # Invoke hooks outside of dynamo then pickup the inner frame + self.forward = self.dynamo_ctx(self._orig_mod.__call__) + + if hasattr(self._orig_mod, "_initialize_hook"): + self._forward = self.forward + self.forward = self._call_lazy_check + + def __getstate__(self): + state = dict(self.__dict__) + state.pop("forward", None) + state.pop("__call__", None) + return state + + def __setstate__(self, state): + self.__dict__ = state + self._initialize() + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def _call_lazy_check(self, *args, **kwargs): + if hasattr(self._orig_mod, "_initialize_hook"): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) + return self._forward(*args, **kwargs) + + def __dir__(self): + orig_mod_attrs = self._orig_mod.__dir__() + return orig_mod_attrs + [ + attr for attr in super().__dir__() if attr not in orig_mod_attrs + ] + + +def remove_from_cache(f): + """ + Make sure f.__code__ is not cached to force a recompile + """ + if isinstance(f, types.CodeType): + reset_code(f) + elif hasattr(f, "__code__"): + reset_code(f.__code__) + elif hasattr(getattr(f, "forward", None), "__code__"): + reset_code(f.forward.__code__) + else: + from . import reset # type: ignore[attr-defined] + + reset() + log.warning("could not determine __code__ for %s", f) + + +def nothing(): + pass + + +def always_false(): + return False + + +def innermost_fn(fn): + """ + In case of nesting of _TorchDynamoContext calls, find the innermost + function. TorchDynamo caches on fn.__code__ object, so its necessary to find + the innermost function to pass on the optimize, run, disable etc. 
+ """ + unaltered_fn = fn + while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): + unaltered_fn = unaltered_fn._torchdynamo_orig_callable + assert callable(unaltered_fn) + return unaltered_fn + + +def make_set_enable_dynamic(enable: bool): + assert isinstance(enable, bool) + if enable: + # Assume everything is dynamic by default + return config._make_closure_patcher(assume_static_by_default=False) + else: + return config._make_closure_patcher( + automatic_dynamic_shapes=False, assume_static_by_default=True + ) + + +class _TorchDynamoContext: + def __init__( + self, + callback: DynamoCallback, + on_enter=nothing, + backend_ctx_ctor=null_context, + patch_fn=nothing, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + ): + super().__init__() + assert callable(callback) or callback is False or callback is None + self.callback: DynamoCallback = callback + self.prior: Union[Unset, DynamoCallback] = unset + self.first_ctx = first_ctx + self.export = export + self.compiler_config = compiler_config + self.cleanup_fns: List[Callable[[], Any]] = [] + self.enter_exit_hooks = [backend_cache_manager(self.callback)] + patch_fn() + + if dynamic is not None: + self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) + + if on_enter is not nothing: + # this case is not common + def call_on_enter(): + on_enter() + return nothing + + self.enter_exit_hooks.append(call_on_enter) + + if backend_ctx_ctor is not contextlib.nullcontext: + # this case is not common + def call_backend_ctx(): + ctx = backend_ctx_ctor() + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_backend_ctx) + + def __enter__(self): + if config.raise_on_ctx_manager_usage: + raise RuntimeError( + "torch._dynamo.optimize(...) is used with a context manager. " + "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html " + "to use torch._dynamo.optimize(...) as an annotation/decorator. " + ) + self.cleanup_fns = [enter() for enter in self.enter_exit_hooks] + self.prior = set_eval_frame(self.callback) + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.prior is not unset + set_eval_frame(self.prior) + self.prior = unset + for cleanup in self.cleanup_fns: + cleanup() + self.cleanup_fns.clear() + + def __call__(self, fn): + # public api for compiler config/options + def get_compiler_config(): + return self.compiler_config + + fn = innermost_fn(fn) + + # add context containing GraphModule to any GraphModule forward functions + from torch.fx._lazy_graph_module import _LazyGraphModule + + if isinstance(fn, _LazyGraphModule) or ( + isinstance(getattr(fn, "__self__", None), _LazyGraphModule) + and fn.__name__ == "_lazy_forward" + ): + # Since dynamo will run the forward method for the GraphModule shortly + # anyways, it does not hurt to do the real recompilation here if + # this is a _LazyGraphModule. This makes it easier for dynamo to + # optimize a _LazyGraphModule. + + lazy_gm = fn if isinstance(fn, _LazyGraphModule) else fn.__self__ + + _LazyGraphModule.force_recompile(lazy_gm) + + # Assume that the underlying node metadata of `fn`, + # a GraphModule instance, accurately represents + # all instances of type(fn). 
+ code_context.get_context(lazy_gm.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(lazy_gm) + + if not isinstance(fn, _LazyGraphModule): + # replace fn with the real forward method + fn = lazy_gm.forward + elif isinstance(fn, GraphModule): + code_context.get_context(fn.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(fn) + + # Optimize the forward method of torch.nn.Module object + if isinstance(fn, torch.nn.Module): + mod = fn + new_mod = OptimizedModule(mod, self) + # Save the function pointer to find the original callable while nesting + # of decorators. + new_mod._torchdynamo_orig_callable = mod.forward + + # when compiling torch.nn.Module, + # provide public api OptimizedModule.get_compiler_config() + assert not hasattr(new_mod, "get_compiler_config") + new_mod.get_compiler_config = get_compiler_config + + return new_mod + assert callable(fn) + + try: + filename = inspect.getsourcefile(fn) + except TypeError: + filename = None + if ( + (filename is None or trace_rules.check(fn)) + and ( + getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"] + ) + and filename not in DONT_WRAP_FILES + ): + # call to a builtin without a frame for us to capture + fn = external_utils.wrap_inline(fn) + + callback = self.callback + + if isinstance(self, DisableContext): + is_jit_tracing = always_false + is_fx_tracing = always_false + else: + is_jit_tracing = torch._C._is_tracing + is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing + + @functools.wraps(fn) + def _fn(*args, **kwargs): + if is_fx_tracing(): + if config.error_on_nested_fx_trace: + raise RuntimeError( + "Detected that you are using FX to symbolically trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + + if is_jit_tracing(): + if config.error_on_nested_jit_trace: + raise RuntimeError( + "Detected that you are using FX to torch.jit.trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + + cleanups = [enter() for enter in self.enter_exit_hooks] + prior = set_eval_frame(callback) + try: + return fn(*args, **kwargs) + finally: + set_eval_frame(prior) + for cleanup in cleanups: + cleanup() + + # hooks to properly handle inlining + if isinstance(self, DisableContext): + _fn._torchdynamo_disable = True # type: ignore[attr-defined] + else: + _fn._torchdynamo_inline = fn # type: ignore[attr-defined] + + # Save the function pointer to find the original callable while nesting + # of decorators. + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + + # when compiling user function instead of nn.Module + # provide public api _fn.get_compiler_config() + assert not hasattr(_fn, "get_compiler_config") + _fn.get_compiler_config = get_compiler_config # type: ignore[attr-defined] + + # If the function is called using torch._dynamo.optimize decorator, we + # should prevent any type of skipping. + if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please wrap the relevant code into a function and optimize the + wrapper function. 
+ + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function and other code, wrap that up in a function + + >> def wrapper_fn(x): + >> y = mod(x) + >> return y.sum() + + and then optimize the wrapper_fn + + >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn) + """ + ) + ) + always_optimize_code_objects[fn.__code__] = True + + return _fn + + +class OptimizeContext(_TorchDynamoContext): + def __init__( + self, + callback, + backend_ctx_ctor, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + ): + def on_enter(): + install_generation_tagging_init() + + super().__init__( + callback=callback, + on_enter=on_enter, + backend_ctx_ctor=backend_ctx_ctor, + patch_fn=TorchPatcher.patch, + first_ctx=first_ctx, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + ) + + +class RunOnlyContext(_TorchDynamoContext): + def __init__(self): + # cudagraph trees relies on generation increment + def on_enter(): + torch._dynamo.mutation_guard.GenerationTracker.generation += 1 + + super().__init__(callback=False, on_enter=on_enter) + + +class DisableContext(_TorchDynamoContext): + def __init__(self): + super().__init__(callback=None) + + +def _optimize_catch_errors( + compile_fn, + hooks: Hooks, + backend_ctx_ctor=null_context, + export=False, + dynamic=None, + compiler_config=None, +): + return OptimizeContext( + convert_frame.catch_errors_wrapper(compile_fn, hooks), + backend_ctx_ctor=backend_ctx_ctor, + first_ctx=True, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + ) + + +def get_compiler_fn(compiler_fn): + from .repro.after_dynamo import wrap_backend_debug + + if hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name + elif isinstance(compiler_fn, str): + compiler_str = compiler_fn + else: + compiler_str = None + compiler_fn = lookup_backend(compiler_fn) + return wrap_backend_debug(compiler_fn, compiler_str) + + +class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] + def __call__(self, fn): + assert callable(fn) + return fn + + +def check_if_dynamo_supported(): + if sys.version_info >= (3, 12): + raise RuntimeError("Python 3.12+ not yet supported for torch.compile") + + +def is_dynamo_supported(): + try: + check_if_dynamo_supported() + return True + except Exception: + return False + + +def check_if_inductor_supported(): + check_if_dynamo_supported() + + if sys.platform == "win32": + raise RuntimeError("Windows not yet supported for inductor") + + +def is_inductor_supported(): + try: + check_if_inductor_supported() + return True + except Exception: + return False + + +def optimize( + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + guard_fail_fn=None, + disable=False, + dynamic=None, +): + """ + The main entrypoint of TorchDynamo. Do graph capture and call + backend() to optimize extracted graphs. + + Args: + backend: One of the two things: + - Either, a function/callable taking a torch.fx.GraphModule and + example_inputs and returning a python callable that runs the + graph faster. + One can also provide additional context for the backend, like + torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. + See AOTAutogradMemoryEfficientFusionWithContext for the usage. 
+ - Or, a string backend name in `torch._dynamo.list_backends()` + nopython: If True, graph breaks will be errors and there will + be a single whole-program graph. + disable: If True, turn this decorator into a no-op + dynamic: If True, upfront compile as dynamic a kernel as possible. If False, + disable all dynamic shapes support (always specialize). If None, automatically + detect when sizes vary and generate dynamic kernels upon recompile. + + Example Usage:: + + @torch._dynamo.optimize() + def toy_example(a, b): + ... + """ + check_if_dynamo_supported() + # Note: The hooks object could be global instead of passed around, *however* that would make + # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. + # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same + # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an + # easier to understand UX at the cost of a little more plumbing on our end. + hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn) + torch._C._log_api_usage_once("torch._dynamo.optimize") + if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": + return _NullDecorator() + + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + if nopython: + return optimize_assert( + backend, + dynamic=dynamic, + hooks=hooks, + ) + return _optimize_catch_errors( + convert_frame.convert_frame(backend, hooks=hooks), + hooks, + backend_ctx_ctor, + dynamic=dynamic, + compiler_config=backend.get_compiler_config() + if hasattr(backend, "get_compiler_config") + else None, + ) + + +# TODO(voz): Consider making "explain" output alongside a run / part of a run +@patch("torch._dynamo.symbolic_convert.explain", True) +def explain(f, *extra_args, **extra_kwargs): + def inner(*args, **kwargs): + # TODO(voz): Do we want a decorator for this? + from . import reset # type: ignore[attr-defined] + + reset() + + graphs: List[torch.fx.GraphModule] = [] + break_reasons: List[Any] = [] + op_count: int = 0 + ops_per_graph: List[torch.fx.Node] = [] + out_guards: List[_guards.Guard] = [] + + def dynamo_graph_accumulating_compiler( + gm: torch.fx.GraphModule, example_inputs + ): + from .backends.debugging import _explain_graph_detail + + nonlocal graphs + nonlocal op_count + nonlocal ops_per_graph + nonlocal break_reasons + + gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail( + gm, graphs, op_count, ops_per_graph, break_reasons + ) + + return gm.forward + + def guard_export_print(guards): + nonlocal out_guards + out_guards.extend(guards) + + opt_f = optimize( + dynamo_graph_accumulating_compiler, + nopython=False, + guard_export_fn=guard_export_print, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. + opt_f(*args, **kwargs) + + graph_count = len(graphs) + + # For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it. + deduped_reasons = {} + for reason in break_reasons: + innermost_frame = reason.user_stack[-1] + # __repr__ uniquely identifies a FrameSummary so we can use it for deduping + deduped_reasons[repr(innermost_frame)] = reason + + formatted_list = "" + for idx, break_reason in enumerate(deduped_reasons.values()): + formatted_stack = "".join(traceback.format_list(break_reason.user_stack)) + msg = f"{idx + 1}. 
Reason: {break_reason.reason}\n User Stack: {formatted_stack}\n" + formatted_list += msg + + graph_break_count = graph_count - 1 + compile_time = compile_times(repr="str") + + # TODO(voz): Do we want a decorator for this? + reset() + from .backends.debugging import ExplainOutput + + return ExplainOutput( + graphs, + graph_count, + graph_break_count, + break_reasons, + op_count, + ops_per_graph, + out_guards, + compile_time, + ) + + if extra_args or extra_kwargs: + warnings.warn( + "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " + "If you don't migrate, we may break your explain call in the future if your user defined kwargs " + "conflict with future kwargs added to explain(f)." + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): + def __init__( + self, + m: torch.fx.GraphModule, + flat_args: Tuple[Any], + matched_input_elements_positions: List[int], + flat_results: List[Any], + matched_output_elements_positions: List[int], + example_fake_inputs: List[torch.Tensor], + flat_args_dynamic_dims: List[Set[int]], + fake_mode: Optional[fake_tensor.FakeTensorMode] = None, + ): + super().__init__(m) + + assert len(flat_args_dynamic_dims) == len(flat_args) + matched_input_elements_to_fake = { + val: example_fake_inputs[ix] + for ix, val in enumerate(matched_input_elements_positions) + } + + self.new_args = [] + for i in range(0, len(flat_args)): + arg = super().placeholder(f"arg{i}", (), {}) + if i in matched_input_elements_to_fake: + arg.node.meta["val"] = matched_input_elements_to_fake[i] + else: + # Fill node.mata["val"] with faketensor from the input, + # if it's not found in matched_input_elements_positions + if fake_mode is not None and isinstance(flat_args[i], torch.Tensor): + # TODO(zhxchen17) Also preserve all the user constraints here. 
+ arg.node.meta["val"] = fake_mode.from_tensor( + flat_args[i], + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC + if d in flat_args_dynamic_dims[i] + else DimDynamic.STATIC + for d in range(len(flat_args[i].shape)) + ], + constraint_sizes=[None] * len(flat_args[i].shape), + ), + ) + self.new_args.append(arg) + self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions) + self.matched_output_elements_positions = matched_output_elements_positions + self.flat_results = flat_results + + def placeholder(self, target, args, kwargs): + arg = next(self.old_args_gen) + if "val" in self.current_node.meta: + arg.node.meta["val"] = self.current_node.meta["val"] + if "tensor_dict" in self.current_node.meta: + arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"] + if "example_value" in self.current_node.meta: + arg.node.meta["example_value"] = self.current_node.meta["example_value"] + return arg + + def output(self, target, args, kwargs): + dynamo_result_flat = args[0] + lookup = [*dynamo_result_flat, *self.new_args] + new_results_flat = [] + for i in range(len(self.flat_results)): + if self.matched_output_elements_positions[i] is not None: + new_results_flat.append( + lookup[self.matched_output_elements_positions[i]] + ) + else: + const_val = self.flat_results[i] + assert isinstance(const_val, tuple(common_constant_types)) + new_results_flat.append(const_val) + return super().output(target, (new_results_flat,), {}) + + def run_node(self, n): + self.current_node = n + result_proxy = super().run_node(n) + if "val" in self.current_node.meta: + result_proxy.node.meta["val"] = self.current_node.meta["val"] + if "example_value" in self.current_node.meta: + result_proxy.node.meta["example_value"] = self.current_node.meta[ + "example_value" + ] + if self.current_node.op != "output": + result_proxy.node._rename( + getattr(self.current_node, "name", result_proxy.node.name) + ) + return result_proxy + + def transform(self): + result_gm = super().transform() + if "dynamo_flat_name_to_original_fqn" in self.module.meta: + result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ + "dynamo_flat_name_to_original_fqn" + ] + return result_gm + + +class ExportResult(NamedTuple): + graph_module: torch.fx.GraphModule + guards: _guards.GuardsSet + # NB: Do not add new fields without overriding __iter__; people are + # destructuring so it is BC-breaking + + +def check_signature_rewritable(graph): + input_errors = [] + for node in graph.graph.nodes: + if node.op == "placeholder": + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + user_stacks = graph._source_to_user_stacks.get(source) + if user_stacks is None: + continue + assert len(user_stacks) > 0 + # In some cases we may not have a useful stack. Look for a + # useful stack + stack = None + for s in user_stacks: + if len(s) == 0: + continue + stack = s + break + if stack is None: + msg = f"{source.name()}, a closed over free variable" + else: + tb = "".join(traceback.format_list(stack)) + extra = "" + if len(user_stacks) > 1: + extra = f"(elided {len(user_stacks)-1} more accesses)" + msg = f"{source.name()}, accessed at:\n{tb}{extra}" + # TODO: option to print ALL of the stack traces at once + input_errors.append(msg) + + if input_errors: + raise UserError( + UserErrorType.INVALID_INPUT, + "Cannot export model which references tensors that are neither " + "buffers/parameters/constants nor are direct inputs. 
For each tensor, if you'd " + "like this tensor to be an explicit input, add it as a dummy argument " + "to the top-level model definition you are exporting; if you would " + "like its value to be embedded as an exported constant, wrap its access " + "in a function marked with @assume_constant_result.\n\n" + + "\n\n".join(input_errors), + ) + + +def rewrite_signature( + f_sig, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_output, + dynamo_traced_result, + flat_args_dynamic_dims, +): + orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) + + def check_user_input_output(flat_values, error_type): + supported_types = [ + torch.Tensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + torch._C.ScriptObject, + ] + list(common_constant_types) + + def is_supported_type(val): + return isinstance(val, tuple(supported_types)) + + value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" + # We only check that the outputs are not None. Inputs can be None. + for v in flat_values: + if not is_supported_type(v): + if error_type == UserErrorType.INVALID_INPUT and v is None: + continue + + raise UserError( + error_type, + f"It looks like one of the {value_type}s with type `{type(v)}` " + "is not supported or pytree-flattenable. \n" + f"Exported graphs {value_type}s can only contain the " + f"following supported types: {supported_types}. \n" + "If you are using a custom class object, " + "please register a pytree_flatten/unflatten function " + "using `torch.utils._pytree.register_pytree_node` or " + "`torch.export.register_dataclass`.", + ) + + check_user_input_output(flat_args, UserErrorType.INVALID_INPUT) + flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) + check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) + + def produce_matching(debug_type, sources, candidates): + matched_elements_positions: List[Optional[int]] = [] + dict_of_source_vals = {} + for i, val in enumerate(sources): + dict_of_source_vals[id(val)] = i + + for i, val in enumerate(candidates): + if isinstance(val, tuple(common_constant_types)): + matched_elements_positions.append(None) + elif id(val) not in dict_of_source_vals: + raise AssertionError( + f"Unexpectedly found a {type(val)} in the {debug_type}.\n" + 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"' + ) + else: + matched_elements_positions.append(dict_of_source_vals[id(val)]) + + return matched_elements_positions + + matched_input_elements_positions = produce_matching( + "inputs", flat_args, graph_captured_input + ) + + assert graph_captured_output is not None + matched_output_elements_positions = produce_matching( + "outputs", list(graph_captured_output) + flat_args, flat_results_traced + ) + + new_graph = FlattenInputOutputSignature( + graph, + flat_args, + matched_input_elements_positions, + flat_results_traced, + matched_output_elements_positions, + example_fake_inputs, + flat_args_dynamic_dims, + fake_mode, + ).transform() + + # Make dynamo graph to have same input/output spec as user code + def argument_names(f_sig, args, kwargs) -> List[str]: + def signature_to_fullargspec(sig: inspect.Signature): + # Get a list of Parameter objects from the Signature object + params = list(sig.parameters.values()) + # Separate positional arguments, keyword-only arguments and varargs/varkw + args = [ + p.name + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwonlyargs = [ + 
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY + ] + varargs = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), + None, + ) + varkw = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), + None, + ) + # Get default values for positional arguments and keyword-only arguments + defaults = tuple( + p.default + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is not inspect.Parameter.empty + ) + kwonlydefaults = { + p.name: p.default + for p in params + if p.kind == inspect.Parameter.KEYWORD_ONLY + and p.default is not inspect.Parameter.empty + } + # Get annotations for parameters and return value + annotations = {} + if sig.return_annotation: + annotations = {"return": sig.return_annotation} + for parameter in params: + annotations[parameter.name] = parameter.annotation + # Return a FullArgSpec object with the extracted attributes + return inspect.FullArgSpec( + args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations + ) + + fullargspec = signature_to_fullargspec(f_sig) + + # 1. Map `args` 1-to-1 to positional arguments in original signature. + input_strs = fullargspec.args[: len(args)] + + if len(args) > len(fullargspec.args): + # 2. If there are more arguments left in `args`, they map to varargs in original + # signature. Assign names as {varargs}_0, {varargs}_1, ... + assert fullargspec.varargs is not None, "More arguments than expected" + input_strs += [ + f"{fullargspec.varargs}_{i}" + for i in range(0, len(args) - len(input_strs)) + ] + elif len(args) < len(fullargspec.args): + # 3. If there are fewer arguments in `args` than `fullargspec.args`, + # it implies these are arguments either with default values, or provided in + # `kwargs`. The former can be safely ignored. Because Dynamo.export does not + # export them as part of the function signature. The latter will be handled + # in the next step. + for unprovided_arg in fullargspec.args[ + len(args) : -len(fullargspec.defaults or []) + ]: + assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" + + # 4. Keyword arguments provided in `kwargs`. + input_strs += list(kwargs.keys()) + + # 5. Keyword-only arguments with default values if not provided are not exported + # as part of the function signature. + for kwonly_arg in fullargspec.kwonlyargs: + kwonlydefaults = fullargspec.kwonlydefaults or {} + assert ( + kwonly_arg in kwargs or kwonly_arg in kwonlydefaults + ), f"Missing keyword only argument {kwonly_arg}" + + return input_strs + + new_graph.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + argument_names(f_sig, orig_args, orig_kwargs), + in_spec, + out_spec_traced, + ) + ) + new_graph.recompile() + return new_graph + + +def export( + f: Callable[..., Any], + *extra_args, + aten_graph: bool = False, + pre_dispatch: bool = False, + decomposition_table: Optional[ + Dict[torch._ops.OpOverload, Callable[..., Any]] + ] = None, + tracing_mode: str = "symbolic", + constraints: Optional[List[Constraint]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + assume_static_by_default: bool = False, + same_signature: bool = True, + disable_constraint_solver: bool = False, + _log_export_usage: bool = True, + **extra_kwargs, +) -> Callable[..., ExportResult]: + """ + Export an input function f to a format that can be executed outside of PyTorch using the FX graph. + + Args: + f (callable): A PyTorch function to be exported. 
+ + aten_graph (bool): If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. Default is False. + + pre_dispatch (bool): If True, exports a graph with ATen operators, + but before any logic in the PyTorch dispatcher has run. + This can be useful if you want to apply further transformations on a graph before running it + through autograd, autocast, or any other functionalities that are integrated into the dispatcher. + This flag is only valid if aten_graph=True is set. + Default is False. + + decomposition_table (dict): A dictionary that maps operators to their decomposition functions. + Required if aten_graph or tracing_mode is specified. Default is None. + + tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic". + + constraints: [DEPRECATED: use ``dynamic_shapes`` instead, see below] + An optional list of constraints on the dynamic arguments + that specify their possible range of shapes. By default, shapes of + input torch.Tensors are assumed to be static. If an input torch.Tensor + is expected to have dynamic shapes, please use :func:`dynamic_dim` + to define :class:`Constraint` objects that specify the dynamics and the possible + range of shapes. See :func:`dynamic_dim` docstring for examples on + how to use it. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + same_signature (bool): If True, rewrite the returned graph's signature to be the same as f. + + disable_constraint_solver (bool): Whether the dim constraint solver must be disabled. + + Returns: + A function that given args and kwargs, returns a tuple of (graph, guards) + Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. + Guards: The guards we accumulated during tracing f above + + Raises: + AssertionError: If decomposition_table is specified without setting aten_graph=True, + or if graph breaks during tracing in export. + + AssertionError: If Dynamo input and output is not consistent with traced input/output. + + Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. + """ + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_dynamo"}) + + # Deal with "local variable referenced before assignment" + _f = f + _assume_static_by_default = assume_static_by_default + + def inner(*args, **kwargs): + nonlocal constraints + if constraints is not None: + if _log_export_usage: + warnings.warn( + "Using `constraints` to specify dynamic shapes for export is DEPRECATED " + "and will not be supported in the future. 
" + "Please use `dynamic_shapes` instead (see docs on `torch.export.export`).", + DeprecationWarning, + stacklevel=2, + ) + else: + constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes) + f = _f + assume_static_by_default = _assume_static_by_default + check_if_dynamo_supported() + torch._C._log_api_usage_once("torch._dynamo.export") + if decomposition_table is not None: + assert ( + aten_graph + ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True" + if pre_dispatch: + assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" + f = innermost_fn(f) + call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f + original_signature = inspect.signature(call_to_inspect) + graph = None + out_guards = None + graph_captured_input = None + graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None + fake_mode = None + + def guard_export_print(guards: _guards.GuardsSet): + nonlocal out_guards + assert ( + out_guards is None + ), "whole graph export entails exactly one guard export" + out_guards = guards + + example_inputs = [] + + def dynamo_normalization_capturing_compiler( + gm: torch.fx.GraphModule, inner_example_inputs + ): + nonlocal graph + assert ( + graph is None + ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." + graph = gm + + nonlocal fake_mode, example_inputs + # NB: do NOT pass inner_example_inputs here, we are detecting the + # Dynamo allocated fake mode, which should be DISTINCT from a + # potential outer ambient fake mode which the user provided. + # example_inputs is always the user specified inputs, so they + # would have the wrong fake mode attached to them + fake_mode = _guards.detect_fake_mode() + example_inputs = inner_example_inputs + + def result_capturing_wrapper(*graph_inputs): + nonlocal graph_captured_result + nonlocal graph_captured_input + + graph_captured_input = graph_inputs + assert graph is not None + + named_parameters = dict(graph.named_parameters(remove_duplicate=False)) + named_buffers = dict(graph.named_buffers(remove_duplicate=False)) + + ambient_fake_mode = ( + _guards.detect_fake_mode(graph_inputs) + if _guards.detect_fake_mode(graph_inputs) is not None + else fake_mode + ) + + with ambient_fake_mode, enable_python_dispatcher(): + params_and_buffers = { + **named_parameters, + **named_buffers, + } + fake_params_buffers = dict() + + for name, value in params_and_buffers.items(): + fake_params_buffers[name] = ambient_fake_mode.from_tensor( + value, static_shapes=True + ) + + fake_graph_inputs = pytree.tree_map( + ambient_fake_mode.from_tensor, graph_inputs + ) + graph_captured_result = torch.func.functional_call( + graph, fake_params_buffers, fake_graph_inputs + ) + + return graph_captured_result + + return result_capturing_wrapper + + # Note: This is needed by rewrite_signature. We need to put it before + # optimize_assert since user program may mutate the inputs. 
+ flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + remove_from_cache(f) + constraint_violation_error = None + if tracing_mode != "symbolic": + assume_static_by_default = True + with config.patch( + specialize_int=True, + assume_static_by_default=assume_static_by_default, + automatic_dynamic_shapes=False, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ): + opt_f = optimize_assert( + dynamo_normalization_capturing_compiler, + hooks=Hooks( + guard_export_fn=guard_export_print, + guard_fail_fn=None, + ), + export=True, + export_constraints=constraints, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. + try: + result_traced = opt_f(*args, **kwargs) + except ConstraintViolationError as e: + constraint_violation_error = e + remove_from_cache(f) + + if ( + not disable_constraint_solver + and (shape_env := getattr(fake_mode, "shape_env", None)) is not None + and (dim_constraints := shape_env.dim_constraints) is not None + and not isinstance( + call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) + and not trace_rules.check(call_to_inspect) + ): + dim_constraints.solve() + dim_constraints.remove_redundant_dynamic_results() + forced_specializations = dim_constraints.forced_specializations() + msg = dim_constraints.prettify_results( + original_signature, constraint_violation_error, forced_specializations + ) + if constraint_violation_error: + constraint_violation_error.args = ( + constraint_violation_error.args[0] + msg, + ) + else: + if forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + else: + log.info( + "Summary of dimension constraints:%s", + msg, + ) + + # Error if we have any constraints on static values + for k in shape_env.var_to_range.keys(): + if isinstance(k, sympy.Integer): + constraint_violation_error = ConstraintViolationError( + f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" + "It appears that you're trying to set a constraint on a " + f"value which we evaluated to have a static value of {k}. " + 'Set TORCH_LOGS="+export" for more information.' + ) + if constraint_violation_error: + raise constraint_violation_error + + assert ( + graph is not None + ), "Failed to produce a graph during tracing as no tensor operations were found." 
+ assert hasattr(graph, "_source_to_user_stacks") + assert out_guards is not None, "Failed to produce guards during tracing" + assert fake_mode is not None + + log.info( + "Dynamo captured graph:\n\n%s", graph.print_readable(print_output=False) + ) + + # This check need to happened before aten_graph + # because placeholder's _source_node attribute is not preserved by make_fx + if same_signature: + check_signature_rewritable(graph) + + # NB: This is mostly hitting the cache; Dynamo already converted these + example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs] + + if aten_graph: + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(graph).run(*args) + + with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), ( + fake_mode + ): + try: + graph = make_fx( + graph_with_interpreter, + decomposition_table=decomposition_table, + tracing_mode="real", + _allow_non_fake_inputs=True, + pre_dispatch=pre_dispatch, + _allow_fake_constant=False, + )(*example_fake_inputs) + except CondOpArgsMismatchError as e: + # Wrap the internal error to the user-facing error + raise UserError( # noqa: TRY200 + UserErrorType.DYNAMIC_CONTROL_FLOW, + str(e), + case_name="cond_operands", + ) + + assert graph is not None + for node in graph.graph.nodes: + if node.op == "get_attr" and isinstance( + getattr(graph, node.target), torch.Tensor + ): + node.meta["val"] = fake_mode.from_tensor( + getattr(graph, node.target), static_shapes=True + ) + + if same_signature: + flat_args_dynamic_dims = [ + {c.dim for c in (constraints or ()) if c.w_tensor() is x} + for x in flat_args + ] + graph = rewrite_signature( + original_signature, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_result, + result_traced, # type: ignore[possibly-undefined] + flat_args_dynamic_dims, + ) + # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check + assert graph is not None + graph.meta["input_shape_constraints"] = ( + [constraint.serializable_spec for constraint in constraints] + if constraints + else [] + ) + + return ExportResult(graph, out_guards) + + if extra_args or extra_kwargs: + warnings.warn( + "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. " + "If you don't migrate, we may break your export call in the future if your user defined kwargs " + "conflict with future kwargs added to export(f)." + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +def optimize_assert( + backend, + *, + hooks=Hooks(None, None), + export=False, + export_constraints=None, + dynamic=None, +): + """ + The same as `torch._dynamo.optimize(backend, nopython=True)` + """ + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + return _optimize_catch_errors( + convert_frame.convert_frame_assert( + backend, export=export, export_constraints=export_constraints + ), + hooks, + backend_ctx_ctor, + export=export, + dynamic=dynamic, + ) + + +class TorchPatcher: + @staticmethod + @functools.lru_cache(None) + def patch(): + # A better way to disable the following would be decorate the source + # functions with @torch._disable_dynamo. However, this causes issues + # with torch.deploy internally. 
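+        # (Illustrative note) `disable(fn)` wraps `fn` so that Dynamo falls back
+        # to eager whenever it is encountered during tracing; the patches below
+        # apply it to tracing entry points and optimizer internals that should
+        # not be traced.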
+ from .decorators import disable + + torch.jit.trace = disable(torch.jit.trace) + torch.jit.trace_module = disable(torch.jit.trace_module) + torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph) + torch.fx._symbolic_trace.Tracer.trace = disable( + torch.fx._symbolic_trace.Tracer.trace + ) + torch.distributions.Distribution.set_default_validate_args(False) + + from ..optim import ( + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + ) + + optimizer_modules = { + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + } + + for opt_mod in optimizer_modules: + opt_name = opt_mod.__name__.split(".")[-1] + fused_fn_name = f"_fused_{opt_name}" + single_tensor_fn_name = f"_single_tensor_{opt_name}" + + if hasattr(opt_mod, fused_fn_name): + setattr( + opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name)) + ) + + optimizer_classes = [ + opt + for opt in torch.optim.__dict__.values() + if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer) + ] + + # Note: we don't support sparsity or tracing through backwards + excluded_optimizer_classes = { + torch.optim.SparseAdam, + torch.optim.LBFGS, + } + + for opt in optimizer_classes: + if opt in excluded_optimizer_classes: + opt.step = disable(opt.step) + + if hasattr(opt, "_init_group"): + opt._init_group = disable(opt._init_group) + + @staticmethod + def suppress_torch_distributed_warnings(fn): + def inner_fn(*args, **kwargs): + warnings.filterwarnings( + "ignore", category=UserWarning, module="torch.distributed" + ) + return fn(*args, **kwargs) + + return inner_fn diff --git a/MLPY/Lib/site-packages/torch/_dynamo/exc.py b/MLPY/Lib/site-packages/torch/_dynamo/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..b626595d3487782f27eaa962c4e13fdaea39df75 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/exc.py @@ -0,0 +1,335 @@ +import os +import textwrap +from enum import auto, Enum +from traceback import extract_stack, format_exc, format_list, StackSummary +from typing import cast, NoReturn, Optional + +import torch._guards + +from . import config + +from .utils import counters + + +def exportdb_error_message(case_name): + return ( + "For more information about this error, see: " + + "https://pytorch.org/docs/main/generated/exportdb/index.html#" + + case_name.replace("_", "-") + ) + + +import logging + +log = logging.getLogger(__name__) +graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") + + +class TorchDynamoException(RuntimeError): + pass + + +class InternalTorchDynamoError(TorchDynamoException): + pass + + +class RestartAnalysis(TorchDynamoException): + pass + + +class SpeculationRestartAnalysis(RestartAnalysis): + pass + + +class UnspecializeRestartAnalysis(RestartAnalysis): + pass + + +class SkipFrame(TorchDynamoException): + pass + + +class TorchRuntimeError(TorchDynamoException): + pass + + +class InvalidBackend(TorchDynamoException): + def __init__(self, name): + super().__init__( + f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends." + ) + + +class ResetRequired(TorchDynamoException): + def __init__(self): + super().__init__( + textwrap.dedent( + """ + Must call `torch._dynamo.reset()` before changing backends. Detected two calls to + `torch.compile()` with a different backend compiler arguments. 
+ """ + ) + ) + + +class BackendCompilerFailed(TorchDynamoException): + def __init__(self, backend_fn, inner_exception): + self.backend_name = getattr(backend_fn, "__name__", "?") + self.inner_exception = inner_exception + msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}" + super().__init__(msg) + + +class Unsupported(TorchDynamoException): + def __init__(self, msg): + super().__init__(msg) + self.real_stack = torch._guards.TracingContext.extract_stack() + self.msg = msg + self.category: Optional[str] = None + self.add_to_stats() + + def remove_from_stats(self): + assert self.category is not None + counters[self.category][self.msg] -= 1 + if counters[self.category][self.msg] <= 0: + del counters[self.category][self.msg] + + def add_to_stats(self, category="unimplemented"): + self.category = category + counters[category][self.msg] += 1 + + +class RecompileError(TorchDynamoException): + pass + + +class ArgsMismatchError(Unsupported): + def __init__(self, msg): + super().__init__(msg) + + +class AttributeMutationError(Unsupported): + def __init__(self, msg): + super().__init__(msg) + + +class CondOpArgsMismatchError(ArgsMismatchError): + """ + Internal error from cond() due to arguments mismatch. + """ + + def __init__(self, msg): + super().__init__(msg) + + +class UserErrorType(Enum): + DYNAMIC_CONTROL_FLOW = auto() + ANTI_PATTERN = auto() + STANDARD_LIBRARY = auto() + CONSTRAINT_VIOLATION = auto() + DYNAMIC_DIM = auto() + INVALID_INPUT = auto() + INVALID_OUTPUT = auto() + + +class UserError(Unsupported): + def __init__(self, error_type: UserErrorType, msg, case_name=None): + """ + Type of errors that would be valid in Eager, but not supported in TorchDynamo. + The error message should tell user about next actions. + + error_type: Type of user error + msg: Actionable error message + case_name: (Optional) Unique name (snake case) for the usage example in exportdb. + """ + if case_name is not None: + assert isinstance(case_name, str) + if msg.endswith("."): + msg += " " + else: + msg += "\n" + msg += exportdb_error_message(case_name) + super().__init__(msg) + self.error_type = error_type + self.message = msg + + +class UncapturedHigherOrderOpError(TorchDynamoException): + pass + + +class IncorrectUsage(Exception): + pass + + +# These exceptions are ok to fallback to eager/graph_break. +exceptions_allowed_to_be_fallback = ( + torch._subclasses.fake_tensor.DataDependentOutputException, + torch._subclasses.fake_tensor.DynamicOutputShapeException, + torch._subclasses.fake_tensor.UnsupportedOperatorException, + torch._subclasses.fake_tensor.UnsupportedFakeTensorException, +) + + +def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn: + # This function calls unimplemented internally and eventually graph breaks + # or falls to eager. unimplemented itself does not print any user warnings, + # i.e., its very silent. This helper function is intended when an error is + # encountered in the torch.compile stack which is worth showing as warning + # to the user. For example, if AOT Autograd backend fails with a fake tensor + # exception, its ok to fallback to eager but not silently. Here, we can use + # this function to log the message and the stack trace. 
+ graph_break_msg = format_error_msg_verbose(e, code) + graph_breaks_log.debug("%s", graph_break_msg) + log.warning(msg) + raise unimplemented(msg) from e + + +def unimplemented(msg: str) -> NoReturn: + assert msg != os.environ.get("BREAK", False) + raise Unsupported(msg) + + +def warning(msg: str) -> None: + counters["warnings"][msg] += 1 + assert msg != os.environ.get("BREAK", False) + + +# KeyError has special handling for its args +# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details +class KeyErrorMsg: + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self) -> str: + return self.__str__() + + +def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None: + import traceback + + exc.innermost_user_frame_summary = None # type: ignore[attr-defined] + + real_stack = get_real_stack(exc) + if real_stack is not None and len(real_stack) > 0: + exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined] + msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}" + + if config.replay_record_enabled and hasattr(exc, "record_filename"): + msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\ + torch._dynamo.replay('{exc.record_filename}').\n" + + if not config.verbose and hasattr(exc, "real_stack"): + msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n' + + if hasattr(exc, "inner_exception") and hasattr( + exc.inner_exception, "minifier_path" + ): + if hasattr(exc.inner_exception, "buck_command"): + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + f"this buck command to find the smallest traced graph " + f"which reproduces this error: {exc.inner_exception.buck_command}\n" + ) + else: + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + "this script to find the smallest traced graph which reproduces this error.\n" + ) + + if not config.suppress_errors and not export: + msg += ( + "\n\n" + "You can suppress this exception and fall back to eager by setting:\n" + " import torch._dynamo\n" + " torch._dynamo.config.suppress_errors = True\n" + ) + + old_msg = "" if len(exc.args) == 0 else str(exc.args[0]) + + if isinstance(exc, KeyError): + exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:] + else: + new_msg = old_msg + msg + exc.args = (new_msg,) + exc.args[1:] + + +def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]: + real_stack = getattr(exc, "real_stack", None) + if real_stack is None: + return None + + # NB: it's possible for real_stack to be []; we still attempt to + # report a stack anyway because the stack_above_dynamo may still + # be useful for debugging + + stack_above_dynamo = [] + if frame is not None: + # NB: frame is PyInterpreterFrame on Python 3.11 and later, + # not a TRUE frame object. You can't actually feed it + # to traceback because it doesn't have enough information. + # To solve this problem, we technically should just materialize + # the frame, the same way _PyFrame_GetFrameObject would do + # (but we cannot actually do this, because this populates + # frame_obj field, which default eval frame doesn't like). + # + # Fortunately, in this case, we can hack it: there's no need + # to actually use the truly top frame, we can just extract + # from where we are right now and rely on filter_stack to + # get rid of all the dynamo frames. 
For ease of testing + # we apply this behavior to ALL Python versions + stack_above_dynamo = filter_stack(extract_stack()) + + return cast(StackSummary, stack_above_dynamo + real_stack) + + +# filter out all frames after entering dynamo +def filter_stack(stack): + user_stack = [] + for frame in stack: + if "convert_frame" in frame.filename: + break + if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line: + continue + user_stack.append(frame) + + return user_stack + + +def format_error_msg_verbose( + exc: Exception, code, record_filename=None, frame=None +) -> str: + msg = ( + f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n" + ) + msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" + msg += format_exc() + real_stack = get_real_stack(exc, frame) + if real_stack is not None: + msg += ( + "\n" + + "=" * 10 + + " The above exception occurred while processing the following code " + + "=" * 10 + + "\n\n" + ) + msg += "".join(format_list(real_stack)) + msg += "\n" + msg += "=" * 10 + + return msg + + +def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str: + msg = os.linesep * 2 + + if config.verbose: + msg = format_error_msg_verbose(exc, code, record_filename, frame) + else: + msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\ + line {code.co_firstlineno} \ndue to: \n{format_exc()}" + + return msg diff --git a/MLPY/Lib/site-packages/torch/_dynamo/external_utils.py b/MLPY/Lib/site-packages/torch/_dynamo/external_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0856ec12aa67bb0192259474afd7948f36078221 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/external_utils.py @@ -0,0 +1,103 @@ +# This module contains functions that *will be allowed* by dynamo + +import functools + +import torch +import torch.utils._pytree as pytree + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + +def is_compiling() -> bool: + """ + Indicates whether we are tracing/compiling with torch.compile() or torch.export(). + + If need to check specifically that TorchDynamo is used, then use + torch.compiler.is_dynamo_compiling(). + + TODO(khabinov): we should deprecate this function and use one of these two: + * torch.compiler.is_compiling(), + * torch.compiler.is_dynamo_compiling(). + It will depend on the context where to use what. + """ + return torch.compiler.is_compiling() + + +def wrap_inline(fn): + """ + Create an extra frame around fn that is not in skipfiles + """ + + @functools.wraps(fn) + def inner(*args, **kwargs): + return fn(*args, **kwargs) + + return inner + + +def call_hook(hook, *args): + """ + Used by compiled autograd to handle hook returning None + """ + result = hook(*args) + if result is None: + return args[0] + return result + + +def wrap_numpy(f): + r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function + from ``torch.Tensor``s to ``torch.Tensor``s. 
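+
+    A minimal sketch (illustrative; requires NumPy to be installed)::
+
+        @wrap_numpy
+        def np_add(a, b):
+            return a + b
+
+        np_add(torch.ones(2), torch.ones(2))  # tensor([2., 2.])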
+ """ + if not np: + return f + + @functools.wraps(f) + def wrap(*args, **kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, lambda x: x.numpy(), (args, kwargs) + ) + out = f(*args, **kwargs) + return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out) + + return wrap + + +class FakeContext: + def __init__(self, saved_tensors): + # this will cache the results of saved_tensors + # and will no longer call into c++ binding + self.saved_tensors = saved_tensors + + +def call_backward(backward_fn, saved_tensors, *args): + grads = backward_fn(FakeContext(saved_tensors), *args) + + # in eager, we wrap in a tuple when there's only one grad output + if type(grads) is not tuple: + grads = (grads,) + + return grads + + +def untyped_storage_size(x: torch.Tensor): + return x.untyped_storage().size() + + +def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs): + return getattr(bw_state, hook_name)(*args, **kwargs) + + +def call_module_hooks_from_backward_state( + _, result, *args, bw_state, hooks_name: str, module_name: str +): + module = getattr(bw_state, module_name) + hooks = getattr(bw_state, hooks_name) + for hook in hooks: + new_result = hook(module, result, *args) + if new_result is not None: + result = new_result + return result diff --git a/MLPY/Lib/site-packages/torch/_dynamo/funcname_cache.py b/MLPY/Lib/site-packages/torch/_dynamo/funcname_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0dc1886f35fecd90de86254e22351d11fdf560 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/funcname_cache.py @@ -0,0 +1,57 @@ +import tokenize + +from typing import Dict, List, Optional + +cache: Dict[str, Dict[int, str]] = {} + + +def clearcache() -> None: + cache.clear() + + +def _add_file(filename: str) -> None: + try: + with open(filename) as f: + tokens = list(tokenize.generate_tokens(f.readline)) + except OSError: + cache[filename] = {} + return + + # NOTE: undefined behavior if file is not valid Python source, + # since tokenize will have undefined behavior. + result: Dict[int, str] = {} + # current full funcname, e.g. xxx.yyy.zzz + cur_name = "" + cur_indent = 0 + significant_indents: List[int] = [] + + for i, token in enumerate(tokens): + if token.type == tokenize.INDENT: + cur_indent += 1 + elif token.type == tokenize.DEDENT: + cur_indent -= 1 + # possible end of function or class + if significant_indents and cur_indent == significant_indents[-1]: + significant_indents.pop() + # pop the last name + cur_name = cur_name.rpartition(".")[0] + elif ( + token.type == tokenize.NAME + and i + 1 < len(tokens) + and tokens[i + 1].type == tokenize.NAME + and (token.string == "class" or token.string == "def") + ): + # name of class/function always follows class/def token + significant_indents.append(cur_indent) + if cur_name: + cur_name += "." 
+ cur_name += tokens[i + 1].string + result[token.start[0]] = cur_name + + cache[filename] = result + + +def get_funcname(filename: str, lineno: int) -> Optional[str]: + if filename not in cache: + _add_file(filename) + return cache[filename].get(lineno, None) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/guards.py b/MLPY/Lib/site-packages/torch/_dynamo/guards.py new file mode 100644 index 0000000000000000000000000000000000000000..a6cd7ab94ea3c9fab444757e96eea58dfc561939 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/guards.py @@ -0,0 +1,1505 @@ +from __future__ import annotations + +import ast +import builtins +import collections +import dataclasses +import enum +import functools +import importlib +import inspect +import itertools +import logging +import math +import os +import re +import sys +import textwrap +import types +import weakref +from inspect import currentframe, getframeinfo +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from weakref import ReferenceType + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +import torch +import torch.utils._device +from torch._dynamo.source import ( + is_from_local_source, + TensorProperty, + TensorPropertySource, +) + +from torch._guards import ( + DuplicateInputs, + Guard, + GuardBuilderBase, + GuardEnvExpr, + GuardSource, + Source, +) + +from torch._logging import structured +from torch.fx.experimental.symbolic_shapes import ( + EqualityConstraint, + is_symbolic, + SYMPY_INTERP, +) +from torch.utils._traceback import format_frame, report_compile_source_on_error +from torch.utils.weak import TensorWeakRef + +from . import config, convert_frame, exc, mutation_guard +from .eval_frame import set_guard_error_hook +from .source import AttrSource, DefaultsSource, LocalSource, TypeSource +from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 +from .utils import ( + common_constant_types, + dict_keys_repr, + guard_failures, + istype, + key_is_id, + key_to_id, + orig_code_map, + tensor_always_has_static_shape, + tuple_iterator_getitem, + tuple_iterator_len, +) + +log = logging.getLogger(__name__) +guards_log = torch._logging.getArtifactLogger(__name__, "guards") +recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") +recompiles_verbose_log = torch._logging.getArtifactLogger( + __name__, "recompiles_verbose" +) +verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") + +TensorGuards = torch._C._dynamo.guards.TensorGuards +check_obj_id = torch._C._dynamo.guards.check_obj_id +check_type_id = torch._C._dynamo.guards.check_type_id +dict_version = torch._C._dynamo.guards.dict_version + + +# For user stack printing +@functools.lru_cache(None) +def uninteresting_files(): + import torch._dynamo.external_utils + + mods = [ + torch._dynamo.external_utils, + ] + return {inspect.getfile(m) for m in mods} + + +CLOSURE_VARS = { + "___check_type_id": check_type_id, + "___check_obj_id": check_obj_id, + "___odict_getitem": collections.OrderedDict.__getitem__, + "___key_to_id": key_to_id, + "___dict_version": dict_version, + "___dict_contains": lambda a, b: a in b, + "___tuple_iterator_len": tuple_iterator_len, + "___tuple_iterator_getitem": tuple_iterator_getitem, + "__math_isnan": math.isnan, + "__numpy_isnan": None if np is None else np.isnan, + "inf": float("inf"), + "__load_module": importlib.import_module, + "utils_device": torch.utils._device, + "device": torch.device, + "___from_numpy": + # If 
not numpy array, piggy back on e.g. tensor guards to check type + (lambda a: torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a), + "torch": torch, + "inspect": inspect, +} + +if sys.version_info[:2] <= (3, 8): + # [Note: Python Version <= 3.8] + # This branch should be dropped when we drop support for Python 3.8. + # Reason: 'ast.unparse' function was introduced in Python 3.9. + + try: + import astunparse # type: ignore[import] + + def _ast_unparse(node: ast.AST) -> str: + return astunparse.unparse(node).replace("\n", "") + + HAS_UNPARSE_FUNCTIONS = True + except ImportError: + HAS_UNPARSE_FUNCTIONS = False + pass +else: + HAS_UNPARSE_FUNCTIONS = True + + def _ast_unparse(node: ast.AST) -> str: + return ast.unparse(node).replace("\n", "") + + +def strip_function_call(name): + """ + "___odict_getitem(a, 1)" => "a" + "a.layers[slice(2)][0]._xyz" ==> "a" + "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a" + "getattr(getattr(a.x[3], '0'), '3')" ==> "a" + "a.layers[slice(None, -1, None)][0]._xyz" ==> "a" + """ + # recursively find valid object name in function + valid_name = re.compile("[A-Za-z_].*") + curr = "" + for char in name: + if char in " (": + curr = "" + elif char in "),[]": + if curr and curr != "None" and valid_name.match(curr): + return strip_function_call(curr) + else: + curr += char + + return strip_getattr_getitem(name) + + +def strip_getattr_getitem(name): + """ + "a[1]" => "a" + "a.foo" => "a" + """ + return re.split(r"[.\[]", name)[0] + + +def get_verbose_code_part(code_part, guard): + extra = "" + if guard.user_stack: + for fs in reversed(guard.user_stack): + if fs.filename not in uninteresting_files(): + extra = f" # {format_frame(fs, line=True)}" + break + elif guard.stack: + extra = f" # {format_frame(guard.stack.summary()[-1])}" + + return f"{code_part:<60}{extra}" + + +def convert_to_concrete_values(size_or_stride): + converted: List[Optional[int]] = [] + for dim in size_or_stride: + if not is_symbolic(dim): + converted.append(dim) + else: + assert isinstance(dim, torch.SymInt) + converted.append(dim.node.maybe_as_int()) + return converted + + +def get_tensor_guard_code_part(value, name, sizes, strides): + pytype = type(value) + dispatch_key = ( + torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set() + ) - torch._C._dispatch_tls_local_exclude_set() + dtype = value.dtype + device_index = value.device.index + requires_grad = value.requires_grad + guard_str = ( + f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, " + f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})" + ) + return guard_str + + +# The ready to eval generated code (possibly multiple parts) for a guard, plus +# the original guard object that created it for provenance +@dataclasses.dataclass +class GuardCodeList: + code_list: List[str] + guard: Guard + + +class GuardBuilder(GuardBuilderBase): + def __init__( + self, + id_ref: Callable[[Any], str], + source_ref: Callable[[Source], str], + lookup_weakrefs: Callable[[object], ReferenceType[object]], + local_scope: Dict[str, object], + global_scope: Dict[str, object], + check_fn_manager: CheckFunctionManager, + ): + self.id_ref = id_ref + self.source_ref = source_ref + self.lookup_weakrefs = lookup_weakrefs + self.scope: Dict[str, Dict[str, object]] = {"L": local_scope, "G": global_scope} + self.scope["__builtins__"] = builtins.__dict__.copy() + for ( + name, + package_module, + ) in torch.package.package_importer._package_imported_modules.items(): + name = 
name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + # Write the package module into the scope so that we can import it + self.scope["__builtins__"][name] = package_module + # Write the demangled name to the scope so that we can use it + self.scope[name] = package_module + + self.argnames: List[str] = [] + # Code is python expression strings generated for each guard + self.code: List[GuardCodeList] = [] + # shape_env_code is only used by builder and is used for + # shape env code. This exists only because we need to make sure + # shape env guards get run after tensor match guards (since the + # tensor match guards make sure we actually have tensors) + self.shape_env_code: List[GuardCodeList] = [] + + # [Note - On Eager Tensor Guards] + # Most of the time, we generate Python code in a guard to directly + # check various properties. However, tensors are a bit special; + # it is too slow to check their properties one-by-one in Python. + # Instead, there is a C++ function TensorGuards.check which takes + # all of the tensor arguments and checks them all against compile-time + # examples entirely in C++. Thus, every time we process a + # TENSOR_MATCH guard, we just add another entry to + # tensor_check_names/tensor_check_examples, saying "for this local, + # check it against this example", and it all ends up getting + # swept up into a single call to ___check_tensors. Invariant: + # len(tensor_check_names) == len(tensor_check_examples). + # TODO: something here + self.tensor_check_names: List[str] = [] + self.tensor_check_examples: List[torch.Tensor] = [] + self.tensor_check_guards: List[Guard] = [] + + self.check_fn_manager: CheckFunctionManager = check_fn_manager + # Keep track of weak references of objects with ID_MATCH guard. This + # info is stored alongside optimized_code and check_fn and is used to + # limit the number of cache entries with same ID_MATCH'd object. + self.id_matched_objs: Dict[str, ReferenceType[object]] = {} + + # Warning: use this with care! This lets you access what the current + # value of the value you are guarding on is. You probably don't want + # to actually durably save this value though (because it's specific + # to this frame!) Instead, you should be reading out some property + # (like its type) which is what you permanently install into the + # guard code. + def get(self, name: str) -> Any: + return eval(name, self.scope, CLOSURE_VARS) + + # Registers the usage of the source name referenced by the + # string (or stored in the Guard) as being guarded upon. It's important + # to call this before generating some code that makes use of 'guard', + # because without this call, we won't actually bind the variable + # you reference in the actual guard closure (oops!) 
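+    # For example, for a guard named "a.layers[0]._xyz" the base name "a" is
+    # what gets recorded in self.argnames (via strip_function_call /
+    # strip_getattr_getitem), while the full name is returned for use in the
+    # generated guard expression.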
+ def arg_ref(self, guard: Union[str, Guard]) -> str: + name: str + if isinstance(guard, str): + name = guard + else: + name = guard.name + base = strip_getattr_getitem(strip_function_call(name)) + if base not in self.argnames: + if re.match(r"[a-zA-Z0-9_]+", base): + if re.match(r"^\d+$", base): + log.warning("invalid var name: %s", guard) + self.argnames.append(base) + + return name + + def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): + attr_source = AttrSource(guard.originating_source, attr_name) + # Copy the stack info + new_guard = Guard( + attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack + ) + new_guard.create(self) + + def TYPE_MATCH(self, guard: Guard) -> None: + # ___check_type_id is same as `id(type(x)) == y` + t = type(self.get(guard.name)) + obj_id = self.id_ref(t) + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" + self._produce_guard_code(guard, [code]) + + def DICT_VERSION(self, guard: Guard): + # ___check_dict_version is same as `dict_version(x) == y` + ref = self.arg_ref(guard) + version = dict_version(self.get(guard.name)) + code = f"___dict_version({ref}) == {version}" + self._produce_guard_code(guard, [code]) + + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): + dict_ref = self.arg_ref(guard) + + maybe_not = "not " if invert else "" + code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})" + return self._produce_guard_code(guard, [code]) + + def BOOL_FALSE(self, guard: Guard): + # Guard on the runtime value being 'False', + # can be faster than seemingly equivalent checks like DICT_KEYS for empty dict + # + # WARNING: this guard is not safe to use generally. It only works if the runtime + # value is of a type that supports bool(), and some types e.g. Tensor do not. + # Only use this guard in cases you can guarantee the runtime type will be friendly. + # (e.g. Specialized NNModule with mutation protection via setattr) + # + # Why not simply check the runtime type inside this guard? It's slow enough to defeat + # the purpose of using this guard, which itself is supposed to be a faster alternative + # to DICT_KEYS. + ref = self.arg_ref(guard) + code = f"not {ref}" + self._produce_guard_code(guard, [code]) + + def ID_MATCH(self, guard: Guard): + # ___check_obj_id is same as `id(x) == y` + if isinstance(guard.originating_source, TypeSource): + # optional optimization to produce cleaner/faster guard code + return self.TYPE_MATCH( + Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH) # type: ignore[arg-type] + ) + + ref = self.arg_ref(guard) + val = self.get(guard.name) + code = f"___check_obj_id({ref}, {self.id_ref(val)})" + self._produce_guard_code(guard, [code]) + + # Keep track of ID_MATCH'd objects. This will be used to modify the + # cache size logic + if isinstance(guard.originating_source, LocalSource): + # TODO(janimesh) - This is currently restricted to nn.Module objects + # because many other ID_MATCH'd objects fail - like DeviceMesh. + # Increase the scope of ID_MATCH'd objects. 
+ if isinstance(val, torch.nn.Module): + local_name = guard.originating_source.local_name + weak_id = self.lookup_weakrefs(val) + if weak_id is not None: + self.id_matched_objs[local_name] = weak_id + + def NAME_MATCH(self, guard: Guard): + obj = self.get(guard.name) + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + + def DATA_PTR_MATCH(self, guard: Guard): + obj = self.get(guard.name) + code = f"{self.arg_ref(guard)}.data_ptr() == {obj.data_ptr()}" + self._produce_guard_code(guard, [code]) + + def HASATTR(self, guard: Guard): + assert isinstance( + guard.originating_source, AttrSource + ), f"invalid source {guard.name}" + base_source = guard.originating_source.base + base = base_source.name() + attr = guard.originating_source.member + + ref = self.arg_ref(base) + val = hasattr(self.get(base), attr) + code = None + if val: + code = f"hasattr({ref}, {attr!r})" + else: + code = f"not hasattr({ref}, {attr!r})" + + self._produce_guard_code(guard, [code], provided_guarded_object=self.get(base)) + + def FUNCTORCH_STACK_MATCH(self, guard: Guard): + # Invalidate functorch code if current level is different than + # the one when FX graph was generated + # if torch._C._functorch.peek_interpreter_stack() is not None: + cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters() + states = [ci.get_state() for ci in cis] + code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] + self._produce_guard_code(guard, code) + + def EQUALS_MATCH(self, guard: Guard): + ref = self.arg_ref(guard) + val = self.get(guard.name) + t = type(val) + if np: + np_types: Tuple[Type[Any], ...] = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ) + else: + np_types = () + ok_types = tuple( + common_constant_types + | { + type, + list, + tuple, + set, + frozenset, + slice, + range, + torch.Size, + *np_types, + } + ) + if istype(val, dict): + assert all( + istype(x, ok_types) for x in itertools.chain(val.keys(), val.values()) + ) + else: + assert istype( + val, + ok_types, + ), f"Unexpected type {type(val)}, not in {ok_types}" + + # Special case for nan because float("nan") == float("nan") evaluates to False + if istype(val, float) and math.isnan(val): + self.TYPE_MATCH(guard) + code = list() + code.append(f"__math_isnan({ref})") + self._produce_guard_code(guard, code) + return + # Python math library doesn't support complex nan, so we need to use numpy + elif istype(val, complex) and np.isnan(val): + self.TYPE_MATCH(guard) + code = list() + code.append(f"__numpy_isnan({ref})") + self._produce_guard_code(guard, code) + return + + code = list() + + # If matching equality against list/tuple, we must also check that + # the internal types match. (TODO: what about nested lists?) + if istype(val, (list, tuple)): + # NB: SEQUENCE_LENGTH takes care of the outer __check_type_id test + self.SEQUENCE_LENGTH(guard) + + for idx, elem in enumerate(val): + code.append( + f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})" + ) + else: + # Add type check to prevent equality check between tensor and non-tensor. 
+            self.TYPE_MATCH(guard)
+
+        if istype(val, torch.Size):
+            val = tuple(val)
+
+        # Code objects cannot be compared against their string representation
+        # I.e `eval(f"{compile('2+2','','exec')!r}")` raises SyntaxError
+        assert not istype(val, types.CodeType)
+
+        # TODO: It feels like it would be better to just implement our own
+        # equality test in C that handles all of the necessary type checking
+        # and NaN tests
+        code.append(f"{ref} == {val!r}")
+        self._produce_guard_code(guard, code)
+
+    def CONSTANT_MATCH(self, guard: Guard):
+        val = self.get(guard.name)
+        if istype(val, (bool, type(None), types.CodeType)):
+            self.ID_MATCH(guard)
+        else:
+            self.EQUALS_MATCH(guard)
+
+    def NN_MODULE(self, guard: Guard):
+        self.ID_MATCH(guard)
+        ref = self.arg_ref(guard)
+        val = self.get(guard.name)
+
+        def setup_guard():
+            assert istype(val.training, bool)
+            self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
+
+        if hasattr(val, "training"):
+            # There are cases where a monkeypatched object has a guard made between __new__ and __init__
+            setup_guard()
+        else:
+            exc.unimplemented(f"Guard setup for uninitialized class {type(val)}")
+
+    def FUNCTION_MATCH(self, guard: Guard):
+        """things like torch.add and user defined functions"""
+        if guard.is_local():
+            return self.ID_MATCH(guard)
+
+    def CLOSURE_MATCH(self, guard: Guard):
+        """matches a closure by __code__ id."""
+        if guard.is_local():
+            val = self.get(guard.name)
+            # Strictly only want user-defined functions
+            if type(val) == types.FunctionType and hasattr(val, "__code__"):
+                self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR)
+                self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH)
+            else:
+                self.FUNCTION_MATCH(guard)
+
+    def BUILTIN_MATCH(self, guard: Guard):
+        return self.FUNCTION_MATCH(guard)
+
+    def PYMODULE_MATCH(self, guard: Guard):
+        return self.FUNCTION_MATCH(guard)
+
+    def SEQUENCE_LENGTH(self, guard):
+        # This guard is used to check the length of PySequence objects like list,
+        # tuple, collections.deque etc
+        ref = self.arg_ref(guard)
+        value = self.get(guard.name)
+        t = type(value)
+
+        self.TYPE_MATCH(guard)
+        code = list()
+        if len(value) == 0:
+            code.append(f"not {ref}")
+        else:
+            code.append(f"len({ref}) == {len(value)}")
+
+        self._produce_guard_code(guard, code)
+
+    def DICT_LENGTH(self, guard):
+        self.SEQUENCE_LENGTH(guard)
+
+    def TUPLE_ITERATOR_LEN(self, guard):
+        ref = self.arg_ref(guard)
+        value = self.get(guard.name)
+        t = type(value)
+
+        self.TYPE_MATCH(guard)
+        code = list()
+        code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")
+
+        self._produce_guard_code(guard, code)
+
+    # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
+    def DUPLICATE_INPUT(self, guard, source_b):
+        ref_a = self.arg_ref(guard)
+        ref_b = self.arg_ref(source_b.name())
+
+        code = [f"{ref_b} is {ref_a}"]
+        self._produce_guard_code(guard, code)
+
+    def DICT_KEYS(self, guard):
+        # Guard on the keys and their order
+        ref = self.arg_ref(guard)
+        value = self.get(guard.name)
+        t = type(value)
+
+        self.TYPE_MATCH(guard)
+        code = list()
+        any_key_is_id = any(key_is_id(k) for k in value.keys())
+        const_keys_repr = dict_keys_repr(
+            key_to_id(value),
+            local=is_from_local_source(guard.originating_source),
+        )
+        if any_key_is_id:
+            code.append(f"___key_to_id({ref}) == {const_keys_repr}")
+        else:
+            code.append(f"list({ref}.keys()) == {const_keys_repr}")
+
+        self._produce_guard_code(guard, code)
+
+    def WEAKREF_ALIVE(self, guard):
+        self._produce_guard_code(guard, [f"{self.arg_ref(guard)} is
not None"]) + + def NN_MODULE_PARAM_NAMES(self, guard): + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + keys = {k for k, v in value.named_parameters()} + + self.TYPE_MATCH(guard) + code = list() + code.append(f"{{k for k, v in {ref}.named_parameters()}} == {keys!r}") + + self._produce_guard_code(guard, code) + + def DICT_CONST_KEYS(self, guard): + """Constant keys match""" + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + self.TYPE_MATCH(guard) + code = list() + code.append(f"list({ref}.keys()) == {list(value.keys())!r}") + + self._produce_guard_code(guard, code) + + def OBJECT_MUTATION(self, guard: Guard): + mutation_guard.watch(self.get(guard.name), self.check_fn_manager) + + def GRAD_MODE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def TORCH_FUNCTION_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DEFAULT_DEVICE(self, guard: Guard): + """Guard on CURRENT_DEVICE per torch.utils._device""" + assert guard.source is GuardSource.GLOBAL + import torch.utils._device as m + + self._produce_guard_code( + guard, [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"] + ) + + def BACKEND_MATCH(self, guard: Guard): + """Guard on backend matching based on id of current_backend""" + assert guard.source is GuardSource.GLOBAL + backend_id = ( + f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}" + ) + code = [f"___check_current_backend({backend_id})"] + self._produce_guard_code(guard, code) + + def SHAPE_ENV(self, guard: Guard): + # Let's handle ShapeEnv guards. To do this, we will resolve + # shape variables to sources from tracked_fakes. This must happen after + # tensor checks. + assert guard.name == "" + output_graph = self.check_fn_manager.output_graph + # NB: self.output_graph can be None in the debug_nops tests + fs = output_graph.tracked_fakes + input_contexts = [a.symbolic_context for a in fs] + + def get_sources(t_id, dim): + # Looks up base sources mapped to a tensor id and uses them to create + # sources for the corresponding tensor dimension. + return [ + TensorPropertySource(source, TensorProperty.SIZE, dim) + for source in output_graph.tracked_fakes_id_to_source[t_id] + ] + + if output_graph.export_constraints: + from sympy import Symbol + + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[ # type: ignore[type-arg] + Tuple[Source, Union[Source, Symbol], Callable] + ] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in output_graph.export_constraints: + if constraint.t_id in output_graph.tracked_fakes_id_to_source: + torch.export.dynamic_shapes._process_equalities( + constraint, + get_sources, + output_graph.shape_env, + source_pairs, + derived_equalities, + phantom_symbols, + ) + else: + log.warning("Untracked tensor used in export constraints") + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + else: + equalities_inputs = None + guards = output_graph.shape_env.produce_guards( + [a.fake for a in fs], + [a.source for a in fs], + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + source_ref=self.source_ref, + # Export keeps static. 
+ ignore_static=(not self.check_fn_manager.output_graph.export), + ) + # When exporting, we may work with the shape constraints some more in + # postprocessing, so don't freeze yet + if not self.check_fn_manager.output_graph.export: + output_graph.shape_env.freeze() + for shape_guard in guards: + self._produce_guard_code(guard, [shape_guard], shape_env=True) + + def TENSOR_MATCH(self, guard: Guard, value=None): + if guard.is_nn_module() or guard.originating_source.is_dict_key(): + self.ID_MATCH(guard) + else: + if isinstance(value, TensorWeakRef): + value = value() + + value = value if value is not None else self.get(guard.name) + assert isinstance(value, torch.Tensor) + + tensor_name = self.arg_ref(guard) + # [Note - On Export Tensor Guards] + # + # In eager mode, tensor guards are evaluated through C++, in guards.cpp + # see [Note - On Eager Tensor Guards] for more info. + # + # In export mode, we instead maintain parallel logic between C++ and python + # here, with an exception of checking the dispatch key - with the idea that a dispatch key + # is an entirely runtime notion that would make no sense to keep in an exported graph. + # + # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although + # not entirely true. + # For example, suppose one of the input tensors had the negative dispatch key. + # You should end up with a graph that is specialized for tensors that have a negative dispatch key. + # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated. + # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't + # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key. + # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported + # subset of keys during export. + # + # The list of tensor fields and calls we care about can be found in `terms` below. + # TODO(voz): We are missing storage offset in all our tensor guards? + code: List[str] = list() + if self.check_fn_manager.output_graph.export: + self.TYPE_MATCH(guard) + terms = [ + "dtype", + "device", + "requires_grad", + "ndimension()", + ] + + for term in terms: + real_value = self.get(tensor_name + "." + term) + if istype(real_value, (torch.device, torch.dtype)): + # copy pasted from EQUALS_MATCH + code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}") + else: + code.append(f"{tensor_name}.{term} == {real_value}") + else: + self.tensor_check_names.append(tensor_name) + self.tensor_check_examples.append(value) + self.tensor_check_guards.append(guard) + + # A frame is valid for reuse with dynamic dimensions if the new + # (user-requested) dynamic dimensions are a subset of the old + # (already compiled) dynamic dimensions. + # + # It's a little non-obvious why you'd want this: in particular, + # if an already compiled frame matches all of the guards, why + # not just use it, why force a recompile? + # + # We force it for two reasons: + # + # - The user *required* us to compile with a new dynamic dimension, + # we should not ignore that and serve up the old, specialized + # frame. Listen to the user! + # + # - In fact, we are obligated to *raise an error* if we fail to + # make the requested dimension dynamic. If we don't + # recompile, we can't tell if that dimension can actually be + # made dynamic. 
+ # + # If the new dynamic dims are a subset of the old, we already know + # we can make them dynamic (since we made them dynamic in old). + # This is slightly unsound, because maybe your input size is + # [s0, s0, s1] and so you can do it dynamic if you say dynamic + # dims {0, 1, 2} but you can't if you only do {0, 2} (because now + # the second s0 is specialized). But we're not entirely sure if + # this is a good idea anyway lol... (if you want to try removing + # this logic, be my guest! -- ezyang 2024) + # + assert guard.source is not None + static, reason = tensor_always_has_static_shape( + value, is_tensor=True, guard_source=guard.source + ) + if not static: + if hasattr(value, "_dynamo_dynamic_indices"): + code.append( + f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # noqa: B950 + ) + # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of + # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled. + else: + code.append( + f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" + ) + if len(code) > 0: + self._produce_guard_code(guard, code) + + # A util that appends guarded code, or, in the case of export, adds data onto guards + def _produce_guard_code( + self, guard, code_list, provided_guarded_object=None, shape_env=False + ): + # WARNING: It is important that cur_frame/caller do NOT stay in + # the current frame, because they will keep things live longer + # than they should. See TestMisc.test_release_module_memory + cur_frame = currentframe() + assert cur_frame is not None + caller = cur_frame.f_back + del cur_frame + assert caller is not None + func_name = getframeinfo(caller)[2] + del caller + # We use func_name for export, so might as well get a nice defensive check out of it + assert func_name in dir( + self.__class__ + ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}" + + if shape_env: + self.shape_env_code.append(GuardCodeList(code_list, guard)) + else: + self.code.append(GuardCodeList(code_list, guard)) + + # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD) + if provided_guarded_object is None: + name_valid = guard.name is not None and guard.name != "" + + guarded_object = self.get(guard.name) if name_valid else None + else: + guarded_object = provided_guarded_object + + guarded_object_type = ( + weakref.ref(type(guarded_object)) if guarded_object is not None else None + ) + obj_ref = None + # Not necessary to have weakref for Enum type, but there is a bug that + # makes hasattr(guarded_object.__class__, "__weakref__") return True. + if hasattr(guarded_object.__class__, "__weakref__") and not isinstance( + guarded_object, enum.Enum + ): + obj_ref = weakref.ref(guarded_object) + + guard.set_export_info( + func_name, + guarded_object_type, + code_list, + obj_ref, + ) + + +# Common Sub-Expression Elimination for Python expressions. +# +# There are 2 steps to this pass: +# 1. Count the frequency of each sub-expression (i.e. inner +# node in the AST tree) +# +# 2. Replace those that occur more than once by a fresh variable 'v'. +# 'v' will be defined in the 'preface' list (output argument to +# 'NodeTransformer') +# +# NB: the use of 'ast.unparse' while visiting the nodes makes this pass +# quadratic on the depth of the tree. 
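+#
+# For example, with USE_THRESHOLD = 1, the code parts "x.a.b == 1" and
+# "x.a.c == 2" share the repeated inner node "x.a"; it is hoisted into the
+# preface as "_var0 = x.a" and the expressions are rewritten to
+# "_var0.b == 1" and "_var0.c == 2".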
+# +# NB: this pass creates a new variable for each AST node that is repeated +# more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c' +# and 'a.b' are also used 10 times. So, there will be a new variable for +# each of them. +class PyExprCSEPass: + # Maximum number of times a given expression can be used without being + # replaced by a fresh variable. + USE_THRESHOLD = 1 + + # Ad-Hoc: AST nodes this pass focuses on. + ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript) + + @dataclasses.dataclass + class Config: + expr_count: Dict[str, int] + expr_to_name: Dict[str, str] + + class ExprCounter(ast.NodeVisitor): + def __init__(self, config: PyExprCSEPass.Config) -> None: + self._config = config + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + self._config.expr_count[_ast_unparse(node)] += 1 + super().visit(node) + + class Replacer(ast.NodeTransformer): + def __init__( + self, + config: PyExprCSEPass.Config, + gen_name: Callable[[], str], + ) -> None: + super().__init__() + self._config = config + self._gen_name = gen_name + self.preface: List[str] = [] + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + expr = _ast_unparse(node) + + # Replacement only occurs if a given expression is used more + # than once. + if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD: + if expr not in self._config.expr_to_name: + # Parent 'visit' is called so that we CSE the inner expressions first. + # + # The resulting expression is used as right-hand-side of the variable + # assignment. i.e. we are CSE-ing the children before the parents. + # + # Indexing still uses the old 'node', since that's what was counted + # by the 'NodeVisitor'. + node_ = super().visit(node) + expr_ = _ast_unparse(node_) + var_name = self._gen_name() + self.preface.append(f"{var_name} = {expr_}") + self._config.expr_to_name[expr] = var_name + else: + var_name = self._config.expr_to_name[expr] + return ast.Name(var_name, ast.Load()) + + return super().visit(node) + + def __init__(self) -> None: + self._counter = 0 + self._config = self.Config( + expr_count=collections.defaultdict(lambda: 0), expr_to_name={} + ) + + def _new_var(self, prefix: str = "_var") -> str: + name = f"{prefix}{self._counter}" + self._counter += 1 + return name + + def count(self, exprs: List[str]) -> None: + counter = self.ExprCounter(self._config) + for e in exprs: + try: + counter.visit(ast.parse(e)) + except SyntaxError as ex: + log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e) + raise + + def replace(self, expr: str) -> Tuple[List[str], str]: + replacer = self.Replacer(self._config, self._new_var) + new_node = replacer.visit(ast.parse(expr)) + return replacer.preface, _ast_unparse(new_node) + + +def must_add_nn_module_guards(guard): + # For config.guard_nn_modules=False, we can skip all the guards that + # originate from inside of nn module except for a few categories. + return ( + # Guard for defaults + isinstance(guard.originating_source, DefaultsSource) + # Guard using dict tags if the config flag is set + or ( + config.guard_nn_modules_using_dict_tags + and guard.create_fn is GuardBuilder.NN_MODULE + ) + ) + + +class DeletedGuardFn: + pass + + +# NB: Naively, you'd expect this to only be a function that produces +# the callable that constitutes the guard. 
However, there is some
+# delicate handling for invalidating this check function when the
+# locals/globals get invalidated, so there's some extra state
+# we have to hold in this manager class.
+class CheckFunctionManager:
+    def __init__(
+        self,
+        output_graph=None,
+        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
+    ):
+        guards = output_graph.guards if output_graph else None
+        self._weakrefs: Dict[int, ReferenceType[object]] = {}
+        self.output_graph = output_graph
+        w_builder = None
+
+        def source_ref(source):
+            guard_source = source.guard_source()
+            if guard_source is GuardSource.CONSTANT:
+                # No need to track constants
+                return source.name()
+            assert w_builder
+            r_builder = w_builder()
+            assert r_builder is not None
+            return r_builder.arg_ref(source.name())
+
+        builder = GuardBuilder(
+            self.id_ref,
+            source_ref,
+            self.lookup_weakrefs,
+            output_graph.local_scope,
+            output_graph.global_scope,
+            self,
+        )
+
+        # Break retain cycle. See test_release_scope_memory
+        def cleanup_builder(weak_b):
+            b = weak_b()
+            if b:
+                b.scope = None
+
+        # Break retain cycle. See test_release_input_memory
+        w_builder = weakref.ref(builder, cleanup_builder)
+
+        for guard in sorted(guards or [], key=Guard.sort_key):
+            if (
+                not config.guard_nn_modules
+                and guard.is_nn_module()
+                # Default func args must be guarded on.
+                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
+                and "__defaults__" not in guard.name
+                and "__kwdefaults__" not in guard.name
+                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
+            ):
+                continue
+
+            guard.create(builder)
+        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
+        # Keep track of weak references of objects with ID_MATCH guard. This
+        # info is stored alongside optimized_code and check_fn and is used to
+        # limit the number of cache entries with same ID_MATCH'd object.
+        # TODO(janimesh) - Currently this information is stored as an attr on
+        # the check_fn itself to avoid changing the CacheEntry data structure in
+        # eval_frame.c. In future, we should probably replace check_fn with a
+        # queryable data structure such that this information is already present
+        # in some form.
+        self.check_fn.id_matched_objs = builder.id_matched_objs
+
+        # NB - We have to be very careful when cleaning up here. Because of the
+        # invalidate function, we can create a weakref finalizer that keeps
+        # `self` alive for very long. Sometimes by mistake, we can run
+        # invalidate for a type/object (check id_ref method) that Python can
+        # leak by design, preventing us from calling the finalizer. In that
+        # case, the `self` will be alive even though the cache entry will be
+        # deleted (check invalidate method), which can cause a memory leak,
+        # e.g., not setting output_graph = None can keep hold of nn_modules.
+        self._weakrefs.clear()
+        self.output_graph = None
+
+    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
+        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
+        largs = builder.argnames
+        largs += ["**___kwargs_ignored"]
+
+        guards_log.debug("GUARDS:")
+
+        # Don't report this guard, it's always the same, useless!
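+        # (It is appended to code_parts directly rather than via add_code_part,
+        # so it still runs at guard-check time but is not echoed to the guards log.)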
+ code_parts = ["___check_global_state()"] + verbose_code_parts = code_parts[:] + structured_guard_fns = [] + + def add_code_part(code_part, guard, log_only=False): + verbose_code_part = get_verbose_code_part(code_part, guard) + guards_log.debug("%s", verbose_code_part) + + structured_guard_fns.append( + lambda: { + "code": code_part, + "stack": structured.from_traceback(guard.stack.summary()) + if guard.stack + else None, + "user_stack": structured.from_traceback(guard.user_stack) + if guard.user_stack + else None, + } + ) + + if verbose_guards_log.isEnabledFor(logging.DEBUG): + maybe_stack = "" + maybe_user_stack = "" + if guard is not None: + if guard.stack: + maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}" + if guard.user_stack: + maybe_user_stack = ( + f"\nUser stack:\n{''.join(guard.user_stack.format())}" + ) + verbose_guards_log.debug( + "Guard: %s%s%s", + code_part, + maybe_stack, + maybe_user_stack, + ) + + if not log_only: + code_parts.append(code_part) + verbose_code_parts.append(verbose_code_part) + + seen = set() + for gcl in builder.code: + for code in gcl.code_list: + if code not in seen: + add_code_part(code, gcl.guard) + seen.add(code) + + tensor_check_names = builder.tensor_check_names + check_tensors_fn = None + check_tensors_verbose_fn = None + if tensor_check_names: + assert ( + not self.output_graph.export + ), "Illegal to set tensor_check_names in export." + tensor_check_examples = builder.tensor_check_examples + + dynamic_dims_sizes = [ + convert_to_concrete_values( + self.output_graph.tensor_weakref_to_sizes_strides[t]["size"] + ) + for t in tensor_check_examples + ] + + dynamic_dims_strides = [ + convert_to_concrete_values( + self.output_graph.tensor_weakref_to_sizes_strides[t]["stride"] + ) + for t in tensor_check_examples + ] + + tensor_guards = TensorGuards( + *tensor_check_examples, + dynamic_dims_sizes=dynamic_dims_sizes, + dynamic_dims_strides=dynamic_dims_strides, + ) + check_tensors_fn = tensor_guards.check + check_tensors_verbose_fn = tensor_guards.check_verbose + tensor_check_args = ", ".join( + tensor_check_names + ["tensor_check_names=tensor_check_names"] + ) + # Do this manually, to un-stagger the guards in log message + code_parts.append(f"___check_tensors({tensor_check_args})") + verbose_code_parts.append(f"___check_tensors({tensor_check_args})") + tensor_check_guards = builder.tensor_check_guards + + for i, name in enumerate(tensor_check_names): + # This is a copy of what guards.cpp checks against + # Keep this in sync with TensorCheck constructor + t = tensor_check_examples[i] + sizes = dynamic_dims_sizes[i] + strides = dynamic_dims_strides[i] + code_part = get_tensor_guard_code_part(t, name, sizes, strides) + add_code_part(code_part, tensor_check_guards[i], log_only=True) + + aotautograd_guards: List[GuardEnvExpr] = ( + self.output_graph.tracing_context.guards_context.aotautograd_guards + if self.output_graph + else [] + ) + for guard in aotautograd_guards: + if isinstance(guard, DuplicateInputs): + source_a = guard.input_source_a + source_b = guard.input_source_b + add_code_part(f"{source_a.name()} is {source_b.name()}", None) + else: + raise RuntimeError(f"Unknown GuardEnvExpr: {guard}") + + # TODO: the "guard" here is actually just the top level SHAPE_ENV + # which is useless. Get ShapeEnv to pass in more provenance. 
+ for gcl in builder.shape_env_code: + for code in gcl.code_list: + add_code_part(code, gcl.guard) + + # OK, all done generating guards + torch._logging.trace_structured( + "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns] + ) + + global_state = convert_frame.initial_global_state + if global_state is None: + # we should only hit this case in NopTests() + global_state = convert_frame.GlobalStateGuard() + closure_vars = { + "___check_tensors": check_tensors_fn, + "___check_tensors_verbose": check_tensors_verbose_fn, + "___check_global_state": global_state.check, + "___check_current_backend": torch._dynamo.eval_frame.check_current_backend, + "tensor_check_names": tensor_check_names, + **SYMPY_INTERP, + **CLOSURE_VARS, + } + + unique_code_parts = list(unique(code_parts)) + make_guard_fn_args = ", ".join(closure_vars.keys()) + guard_body, pycode = build_guard_function(unique_code_parts, make_guard_fn_args) + + if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1": + print("GUARDS\n", guard_body) + + out: Dict[str, Any] = dict() + + # We don't put builder.scope as the globals in exec call because + # guard_fn.__globals__ becomes equal to builder.scope. This causes + # guard_fn to hold a referece to f_locals sitting in builder.scope["L"] + globals_for_guard_fn = {"G": builder.scope["G"]} + try: + exec(pycode, globals_for_guard_fn, out) + except SyntaxError as ex: + log.exception("Failed to exec guard at line %s.\n%s", ex.lineno, pycode) + raise + guard_fn = out["___make_guard_fn"](*closure_vars.values()) + guard_fn.closure_vars = closure_vars + # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both + guard_fn.args = largs + guard_fn.code_parts = code_parts + guard_fn.verbose_code_parts = verbose_code_parts + # Grab only G, but preserve "G" because guards access it as "G" + guard_fn.global_scope = globals_for_guard_fn + guard_fn.guard_fail_fn = guard_fail_fn + # will be populated by a non-owning reference to CacheEntry/ExtraState + # when the CacheEntry is constructed + guard_fn.cache_entry = None + guard_fn.extra_state = None + return guard_fn + + def invalidate(self): + # Some tests reveal that CheckFunctionManager has no attribute + # check_fn, but this case should not be of any concern. + # This case doesn't seem easy to repro. + if ( + hasattr(self, "check_fn") + and self.check_fn is not DeletedGuardFn + and (cache_entry := self.check_fn.cache_entry) is not None + and (extra_state := self.check_fn.extra_state) is not None + ): + assert isinstance(cache_entry, CacheEntry) + assert isinstance(extra_state, ExtraState) + extra_state.invalidate(cache_entry) + self.check_fn.cache_entry = None + self.check_fn.extra_state = None + self.check_fn = DeletedGuardFn + + def id_ref(self, obj): + """add a weakref, return the id""" + try: + if id(obj) not in self._weakrefs: + # We will clear the _weakrefs dict at the end of __init__ + # function, which will delete the callbacks as well. Therefore, + # we are using a finalizer which is kept alive. 
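+                # (weakref.finalize registers self.invalidate to run when obj is
+                # garbage collected, which drops the stale cache entry.)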
+ self._weakrefs[id(obj)] = weakref.ref(obj) + weakref.finalize(obj, self.invalidate) + except TypeError: + pass # cannot weakref bool object + return id(obj) + + def lookup_weakrefs(self, obj): + """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" + if id(obj) in self._weakrefs: + return self._weakrefs[id(obj)] + return None + + +def build_guard_function(code_parts, closure_args) -> Tuple[str, str]: + from torch._inductor.utils import IndentedBuffer + + if HAS_UNPARSE_FUNCTIONS: + csepass = PyExprCSEPass() + csepass.count(code_parts) + + def replace(expr: str) -> Tuple[List[str], str]: + return csepass.replace(expr) + + else: + + def replace(expr: str) -> Tuple[List[str], str]: + return [], expr + + # Generate the inner body of the guard function. + # i.e. if-chain of the guard expressions. + guard_body = IndentedBuffer() + for expr in code_parts: + preface, expr = replace(expr) + guard_body.writelines(preface) + guard_body.writeline(f"if not ({expr}):") + with guard_body.indent(): + guard_body.writeline("return False") + + # Wrap the inner body into the actual guard function. + guard = IndentedBuffer() + guard.writeline("def guard(L):") + with guard.indent(): + guard.splice(guard_body) + guard.writeline("return True") + + # Wrap the whole guard function into another function + # with the closure variables. + make_guard_fn = IndentedBuffer() + make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):") + with make_guard_fn.indent(): + make_guard_fn.splice(guard) + make_guard_fn.writeline("return guard") + + return guard_body.getvalue(), make_guard_fn.getvalue() + + +def is_recompiles_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles") + + +def is_recompiles_verbose_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") + + +def get_guard_fail_reason( + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], +) -> str: + """ + Return the reason why `guard_fn` failed. + Updates `guard_failures` with the generated reason. + Only the first failed check of guard_fn is reported. + """ + scope = {"L": f_locals, "G": guard_fn.global_scope["G"]} + scope.update(guard_fn.closure_vars) + scope["___check_tensors"] = scope["___check_tensors_verbose"] + reasons: List[str] = [] + for part in guard_fn.verbose_code_parts: + global_scope = dict(guard_fn.global_scope) + global_scope["__compile_source__"] = part + with report_compile_source_on_error(): + try: + fail_reason = eval(part, global_scope, scope) + except Exception as e: + if is_recompiles_verbose_enabled(): + continue + else: + raise + # Only ___check_tensors knows how to return a fancy fail reason; + # for everything else we just report the code that failed + + if isinstance(fail_reason, bool) and not fail_reason: + fail_reason = part + if isinstance(fail_reason, str): + reasons.append(fail_reason) + if not is_recompiles_verbose_enabled(): + break + + reason_str = "\n".join(reasons) + guard_failures[orig_code_map[code]].append(reason_str) + + try: + if guard_fn.guard_fail_fn is not None: + guard_fn.guard_fail_fn( + GuardFail(reason_str or "unknown reason", orig_code_map[code]) + ) + except Exception as e: + log.exception( + "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", + ) + + return reason_str + + +def get_and_maybe_log_recompilation_reason( + cache_entry, frame: types.FrameType +) -> List[str]: + """ + Return the list of guard failure reasons using cache_entry. 
Logs the recompilation reason if `recompiles` logging is enabled.
+    Raises a RecompileError if `config.error_on_recompile` is enabled.
+    """
+    reasons = []
+    while cache_entry is not None:
+        reason = get_guard_fail_reason(
+            cache_entry.check_fn, cache_entry.code, frame.f_locals
+        )
+        if reason:
+            reasons.append(reason)
+        cache_entry = cache_entry.next
+
+    code = frame.f_code
+
+    # at least one of "recompiles" or "recompiles_verbose" is enabled
+    do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()
+
+    if do_recompiles_log or config.error_on_recompile:
+        if is_recompiles_verbose_enabled():
+            failures = "\n\n".join(
+                f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
+                for i, reason in enumerate(reasons)
+            )
+        else:
+            failures = textwrap.indent("\n".join(reasons), "- ")
+        guard_failure_details = (
+            f"triggered by the following guard failure(s):\n{failures}"
+        )
+        message = (
+            f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
+            f"{textwrap.indent(guard_failure_details, ' ')}"
+        )
+        if do_recompiles_log:
+            if is_recompiles_verbose_enabled():
+                recompiles_verbose_log.debug(message)
+            else:
+                recompiles_log.debug(message)
+        if config.error_on_recompile:
+            raise exc.RecompileError(message)
+
+    return reasons
+
+
+def guard_error_hook(
+    guard_fn: GuardFn,
+    code: types.CodeType,
+    f_locals: Dict[str, object],
+    index: int,
+    last: bool,
+):
+    print(
+        f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
+    )
+    print("lambda " + ", ".join(guard_fn.args) + ":")
+    print(" ", " and\n ".join(guard_fn.code_parts))
+    local_scope = {"L": f_locals, **guard_fn.closure_vars}
+    for guard in guard_fn.code_parts:
+        try:
+            eval(guard, guard_fn.global_scope, local_scope)
+        except:  # noqa: B001,E722
+            print(f"Malformed guard:\n{guard}")
+
+
+set_guard_error_hook(guard_error_hook)
+
+
+def unique(seq):
+    seen = set()
+    for x in seq:
+        if x not in seen:
+            yield x
+            seen.add(x)
+
+
+def make_dupe_guard(obj_source, dupe_source):
+    # Note - we may end up in a situation where we invoke something like
+    # def fn(x, y)
+    # with fn(x, x)
+    # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
+    # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
+    # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
+    # in the fn(x, x) example above, the traced call would look like a graph with a single input.
+    # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
+
+    # Note - we may not have a source, that is fine, it just means we had an object that is safe to
+    # leave unsourced - like a local list created and discharged entirely within a local scope.
+    if dupe_source and dupe_source != obj_source:
+        ser_source_is_local = is_from_local_source(dupe_source)
+        source_is_local = is_from_local_source(obj_source)
+        # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
+        # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
+        # so maybe we should do this refactor before we land this...
+        # TODO(voz): Combine local and global guard builders.
+        if ser_source_is_local == source_is_local:
+            # Note - this is a little aggressive - these being duplicate input does not always matter.
+            # However, this should always be a sound guard to add here.
+ return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source) + return None + + +def install_guard(*guards, skip=0): + """ + Add dynamo guards to the current tracing context. + + Args: + guards: guard(s) to add + skip: number of stack frames to ignore for debug stack trace + """ + from torch._guards import TracingContext + + collect_debug_stack = guards_log.isEnabledFor( + logging.DEBUG + ) or verbose_guards_log.isEnabledFor(logging.DEBUG) + add = TracingContext.get().guards_context.dynamo_guards.add + for guard in guards: + assert isinstance(guard, Guard) + add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/hooks.py b/MLPY/Lib/site-packages/torch/_dynamo/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..edffccc7c73e96e6af4dfe079d9d8aa2b504ccce --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/hooks.py @@ -0,0 +1,12 @@ +import dataclasses + +from typing import Callable, Optional + +from torch._guards import GuardsSet +from .types import GuardFail + + +@dataclasses.dataclass +class Hooks: + guard_export_fn: Optional[Callable[[GuardsSet], None]] = None + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None diff --git a/MLPY/Lib/site-packages/torch/_dynamo/logging.py b/MLPY/Lib/site-packages/torch/_dynamo/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..5e0fb984b3307f5986bdddafbb56f3997a5d9e25 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/logging.py @@ -0,0 +1,57 @@ +import itertools +import logging + +from torch.hub import _Faketqdm, tqdm + +# Disable progress bar by default, not in dynamo config because otherwise get a circular import +disable_progress = True + + +# Return all loggers that torchdynamo/torchinductor is responsible for +def get_loggers(): + return [ + logging.getLogger("torch.fx.experimental.symbolic_shapes"), + logging.getLogger("torch._dynamo"), + logging.getLogger("torch._inductor"), + ] + + +# Creates a logging function that logs a message with a step # prepended. +# get_step_logger should be lazily called (i.e. at runtime, not at module-load time) +# so that step numbers are initialized properly. 
e.g.: + +# @functools.lru_cache(None) +# def _step_logger(): +# return get_step_logger(logging.getLogger(...)) + +# def fn(): +# _step_logger()(logging.INFO, "msg") + +_step_counter = itertools.count(1) + +# Update num_steps if more phases are added: Dynamo, AOT, Backend +# This is very inductor centric +# _inductor.utils.has_triton() gives a circular import error here + +if not disable_progress: + try: + import triton # noqa: F401 + + num_steps = 3 + except ImportError: + num_steps = 2 + pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0) + + +def get_step_logger(logger): + if not disable_progress: + pbar.update(1) + if not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(f"{logger.name}") + + step = next(_step_counter) + + def log(level, msg, **kwargs): + logger.log(level, "Step %s: %s", step, msg, **kwargs) + + return log diff --git a/MLPY/Lib/site-packages/torch/_dynamo/mutation_guard.py b/MLPY/Lib/site-packages/torch/_dynamo/mutation_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..abd48febe14843bfe8f1a57bac1f77ad55651afe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/mutation_guard.py @@ -0,0 +1,126 @@ +# mypy: disable-error-code="method-assign" + +import functools +import weakref + +import torch.nn +from torch.nn import Module + +from .utils import ExactWeakKeyDictionary, is_lazy_module + + +class MutationTracker: + db = ExactWeakKeyDictionary() + + def __init__(self): + self.mutation_count = 0 + self.watchers = [] + + def on_mutation(self, name): + self.mutation_count += 1 + tmp = self.watchers + self.watchers = [] + for ref in tmp: + guarded = ref() + if guarded is not None: + guarded.invalidate(ref) + + def track(self, guarded_code): + self.watchers.append(weakref.ref(guarded_code)) + + +def watch(obj, guarded_code): + """invalidate guarded_code when obj is mutated""" + ensure_patched(type(obj)) + + if obj not in MutationTracker.db: + MutationTracker.db[obj] = MutationTracker() + tracker = MutationTracker.db[obj] + tracker.track(guarded_code) + + +def ensure_patched(cls): + if getattr(cls, "___needs_mutation_patch", True): + cls.___needs_mutation_patch = False + original_setattr = cls.__setattr__ + + @functools.wraps(original_setattr) + def custom_setattr(self, key, value): + try: + MutationTracker.db[self].on_mutation(key) + except KeyError: + pass + return original_setattr(self, key, value) + + cls.__setattr__ = custom_setattr + + +class GenerationTracker: + generation = 0 + dynamic_classes = ExactWeakKeyDictionary() + generation_values = ExactWeakKeyDictionary() + + @classmethod + def tag(cls, obj): + cls.generation_values[obj] = cls.generation + + @staticmethod + def mark_class_dynamic(cls): + assert issubclass(cls, torch.nn.Module) + GenerationTracker.dynamic_classes[cls] = True + + @classmethod + def get_generation_value(cls, obj): + if obj not in cls.generation_values: + return -1 + return cls.generation_values[obj] + + @classmethod + def check(cls, obj): + return ( + obj in cls.generation_values + and cls.generation_values[obj] == cls.generation + ) + + +def is_dynamic_nn_module(obj): + """Check for nn.Modules() created dynamically or mutated""" + if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__: + # A monkey patched `.forward` indicates something wacky is going on + return True + if hasattr(obj, "torchdynamo_force_dynamic"): + return obj.torchdynamo_force_dynamic + if is_lazy_module(obj): + return False + dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( + obj + ) + 
return dyn + + +def install_generation_tagging_init(): + """ + Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__ + so we can detect nn.Module instances created dynamically inside forward methods. + """ + + if getattr(Module, "___needs_generation_tag_patch", True): + init = Module.__init__ + + def patched_init(self, *args, **kwargs): + init(self, *args, **kwargs) + GenerationTracker.tag(self) + + Module.__init__ = patched_init + + setstate = Module.__setstate__ + + def patched_setstate(self, state): + setstate(self, state) + GenerationTracker.tag(self) + + Module.__setstate__ = patched_setstate + + Module.___needs_generation_tag_patch = False # type: ignore[attr-defined] + + GenerationTracker.generation += 1 diff --git a/MLPY/Lib/site-packages/torch/_dynamo/output_graph.py b/MLPY/Lib/site-packages/torch/_dynamo/output_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..c8164aacad717dc4dd3c84feb2f36bbe6dbdcbbe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/output_graph.py @@ -0,0 +1,2073 @@ +import collections +import contextlib +import copy +import functools +import itertools +import logging +import operator +import re +import sys +import traceback +import weakref +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union + +import sympy + +import torch._guards + +import torch._logging + +import torch.nn +import torch.utils._pytree as pytree +from torch import fx +from torch._guards import ( + Checkpointable, + GlobalContextCheckpointState, + GuardsCheckpointState, + Source, + TracingContext, +) +from torch._utils_internal import signpost_event +from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import SymNode +from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.interp import sympy_interp +from torch.utils._sympy.reference import PythonReferenceAnalysis +from torch.utils.weak import WeakTensorKeyDictionary + +from . 
import config, logging as torchdynamo_logging, variables +from .backends.registry import CompiledFn, CompilerFn +from .bytecode_transformation import ( + create_call_function, + create_instruction, + Instruction, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .current_scope_id import enter_new_scope +from .exc import ( + BackendCompilerFailed, + exceptions_allowed_to_be_fallback, + SkipFrame, + unimplemented, + unimplemented_with_warning, +) +from .guards import GuardBuilder, install_guard +from .mutation_guard import is_dynamic_nn_module +from .side_effects import SideEffects +from .source import ( + AttrSource, + BackwardStateSource, + ConstantSource, + GlobalStateSource, + is_constant_source, + is_from_local_source, + LocalSource, + ParamBufferSource, + ShapeEnvSource, + TensorProperty, + TensorPropertySource, +) +from .utils import ( + checkpoint_params, + CleanupHook, + clone_inputs, + count_calls, + counters, + dynamo_timed, + get_instruction_source_311, + get_static_address_type, + graph_break_reasons, + increment_op_count, + lazy_format_graph_code, + lazy_format_graph_tabular, + LazyString, + nn_module_proxy, + same, +) +from .variables.base import VariableTracker +from .variables.builder import ( + BackwardStateGraphArg, + GraphArg, + TrackedFake, + VariableBuilder, + wrap_fx_proxy, +) +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) + +from .variables.torch_function import TensorWithTFOverrideVariable + +log = logging.getLogger(__name__) +graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") +graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") + + +class OutputGraphState(NamedTuple): + input_source_to_var: Dict[Source, VariableTracker] + tracked_fakes: List[TrackedFake] + guard_state: GuardsCheckpointState + nn_modules: Optional[Dict[str, torch.nn.Module]] + register_finalizer_fns: List[Callable[[fx.GraphModule], None]] + global_state: Optional[Dict[str, bool]] + param_name_to_source: Optional[Dict[str, Source]] + side_effects: SideEffects + timestamp: int + non_compliant_ops: Set[torch._ops.OpOverload] + compliant_custom_ops: Set[torch._ops.OpOverload] + + def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]: + for k in self._fields: + if k == "guard_state": + r = self.guard_state.diff(other.guard_state) + if r is not None: + return r + continue + elif k == "side_effects": + r = self.side_effects.diff(other.side_effects) + if r is not None: + return r + continue + + sv = getattr(self, k) + ov = getattr(other, k) + if sv != ov: + return f"{prefix}{k} mismatch: {sv} != {ov}" + return None + + # Back compat .guards api + @property + def guards(self): + return self.guard_state.dynamo_guards + + +@functools.lru_cache(None) +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@dataclass +class GraphCompileReason: + """Stores why a given output graph was compiled; i.e. what caused the graph break.""" + + reason: str + user_stack: List[traceback.FrameSummary] + + # Indicates if this was a graph compile reason due to graph break. 
+ graph_break: bool = True + + def __post_init__(self): + if self.graph_break: + graph_break_reasons.append(self) + + +def _get_gen_rand_values_fn(random_calls): + def _gen_rand_values(): + return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] + + return _gen_rand_values + + +class FakeRootModule(torch.nn.Module): + """Trick the constructor of fx.GraphModule""" + + def __init__(self, nn_modules: Dict[str, torch.nn.Module]): + super().__init__() + for k, v in nn_modules.items(): + setattr(self, k, v) + + def __repr__(self): + return "FakeRootModule(...)" + + +class WrapperBackend: + def __init__(self, backend: CompilerFn): + self.backend: CompilerFn = backend + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.restore = checkpoint_params(gm) + self.gm = gm + copy_gm = copy.deepcopy(self.gm) + self.candidate = self.backend(copy_gm, example_inputs) + + if self.candidate is None or self.candidate is self.gm.forward: + return self.gm.forward + + if not config.verify_correctness: + return self.candidate + + # if verify_correctness=True + try: + correct = self.gm.forward(*clone_inputs(example_inputs)) + result = self.candidate(*clone_inputs(example_inputs)) + + # TODO: replace `same` function with the one in testing + if same(correct, result): + return self.candidate + + raise RuntimeError(f"incorrect results of backend {self}") + return self.gm.forward + + except Exception: + log.exception("error in verify_correctness") + raise + finally: + self.restore() + + +Scope = Dict[str, object] + + +class OutputGraph(Checkpointable[OutputGraphState]): + """ + Wrapper class to hold outputs of InstructionTranslator. Mainly the + generated fx.Graph. + + OutputGraph is 1:1 with a frame being processed. Each frame is associated + with some root InstructionTranslator. When user code calls a function, + we construct a InliningInstructionTranslator that continues to write into + the root InstructionTranslator's OutputGraph. + """ + + def __init__( + self, + code_options: Dict[str, Any], + compiler_fn: Optional[CompilerFn], + root_tx, + export: bool, + export_constraints, + frame_state, + local_scope: Scope, + global_scope: Scope, + f_code, + ): + super().__init__() + self.tracers = [SubgraphTracer(self, export_root=export)] + # Map from graph input's `Source` to its `VariableTracker` to + # de-duplicate graph inputs by source and reuse the tracker + self.input_source_to_var: Dict[Source, VariableTracker] = {} + self.export = export + self.export_constraints = export_constraints + self.frame_state = frame_state + self.tensor_weakref_to_sizes_strides = WeakTensorKeyDictionary() + self.cleanup_hooks: List[Callable[[], Any]] = [] + # compile_id is an id number for the current torch.compile + self.compile_id: int = next(_compile_id_counter) + # Set of globals installed via install_global* APIs + self.installed_globals: Set[str] = set() + + # TODO: maybe should just pass the entire f_code in here? Not + # sure... + self.co_fields = { + "co_name": f_code.co_name, + "co_filename": f_code.co_filename, + "co_firstlineno": f_code.co_firstlineno, + } + + # tracked_fakes says where any tensor that was wrapped to fake came + # from. It is similar to GraphArg, in that all GraphArgs will get + # will get added to TrackedFakes, but TrackedFakes also contains + # GraphArgs that got pruned, and things like Tensor attributes which + # aren't explicit graph inputs. 
Used by shape guard + self.tracked_fakes: List[TrackedFake] = [] + + # List of symbols for which we have exact bindings in the arguments + # already + self.bound_symbols: Set[sympy.Symbol] = set() + + shape_env = ShapeEnv( + # Reference Cycle! + # Share a reference to the list of TrackedFake. + # + # ShapeEnv needs this in order to be able to reproduce the call + # to produce_guards at an arbitrary time point. That is because + # TrackedFake instances may have its metadata changed throughout + # the program execution. + tracked_fakes=self.tracked_fakes, + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + co_fields=self.co_fields, + ) + + # In export mode, we force the shape_env to strictly disallow any constraining + # of the user marked dynamic dims + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=shape_env, + # TODO (tmanlaibaatar) Remove this once we always lift params and buffers + allow_non_fake_inputs=True if self.export else False, + ) + self.tracing_context: TracingContext = TracingContext(fake_mode) + self.init_ambient_guards() + + # Map each tensor id to a list of sources. This is necessary because + # tensor ids cannot be recovered from tracked fakes (in general). + # We use this map to interpret (i.e., check for violations of) constraints, + # specifically equality constraints, which have shared tensor ids in them. + # This map should also be generally useful, e.g., for (de)serialization. + self.tracked_fakes_id_to_source: Dict[ + int, List[Source] + ] = collections.defaultdict(list) + # Stores the full fqn of a param or buffer to the relevant source. + self.param_name_to_source: Optional[Dict[str, Source]] = dict() + self.side_effects = SideEffects() + self.code_options = dict(code_options) + self.output_instructions: List[Instruction] = [] + # used to track nodes that are added between calls of copy_graphstate + # and restore_graphstate + self.timestamp = 0 + + # A list of register_finalizer_fns to apply to the output graph module + self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = [] + + # Not checkpointed + self.compiler_fn: Optional[CompilerFn] = compiler_fn + self.global_scope = global_scope + self.local_scope = local_scope + self.root_tx = root_tx + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + + # Given a source, what are the user stacks of all locations that + # accessed it? + # + # For efficiency, we only populate this: + # - During export, and + # - If the source could potentially lead to a spurious export input + # + # Feel free to populate this more frequently if other use-cases arise, + # but be aware that we have to generate full stacks for each + # recording! + self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {} + + self._current_tx: List[InstructionTranslatorBase] = [] + self.cleanups: List[CleanupHook] = [] + self.should_exit = False + self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {} + self.torch_function_enabled = torch._C._is_torch_function_enabled() + # Tracks if the output graph has a user defined allowed function in the + # graph. This is used later to determine if we should fallback to eager + # for certain exceptions. THe idea is that if the user has applied + # allow_in_graph, they would like to see the error instead of falling + # back for backend errors. 
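+        # Illustrative only (hypothetical user code, not part of this module):
+        #
+        #   torch._dynamo.allow_in_graph(my_custom_fn)
+        #
+        # Once a call to `my_custom_fn` is traced into this graph, this flag is
+        # expected to be flipped to True so that backend errors surface rather
+        # than silently falling back to eager.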
+ self.has_user_defined_allowed_in_graph = False + + # Tracks a list of called ops that were not tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.non_compliant_ops: Set[torch._ops.OpOverload] = set({}) + + # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.compliant_custom_ops: Set[torch._ops.OpOverload] = set({}) + + # We save the global torch state here to be restored in case of graph + # breaks. The relevant issue is seen here + # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086 + # where inlining of a function changes the global state (because of the + # presence of torch.no_grad) and there is a graph break. + self.save_global_state() + + # Tracks the original FQNs of the constant tensors from the original graph, + # i.e. buffers and parameters. + self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {} + + # All calls to random() are replaced with a single call to __gen_rand_values + # functions that returns a tuple of random values for each original call. + # random_calls tracks calls to random() and random_values_var stores the name of + # the variable that stores __gen_rand_values results. + self.random_calls: List[ + Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] + ] = [] + self.random_values_var = None + + # Bytecode to insert right before we call the graph + self.pregraph_bytecode: List[Instruction] = [] + + # Use to pass values to backward hooks when using compiled autograd + self.backward_state: Dict[str, VariableTracker] = {} + self.backward_state_proxy: Optional[torch.fx.Proxy] = None + self.backward_state_var: Optional[str] = None + + def add_backward_state_hook(self, hook: VariableTracker): + name = f"hook{len(self.backward_state)}" + assert name not in self.backward_state + self.backward_state[name] = hook + return name, self.get_backward_state_proxy() + + def get_backward_state_proxy(self): + if self.backward_state_proxy is None: + if self.export: + unimplemented("backward_state does not support export") + self.backward_state_proxy = self.root_tracer.create_graph_input( + "dynamo_backward_state", BackwardState, source=BackwardStateSource() + ) + self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg() + self.backward_state_proxy.node.meta["example_value"] = BackwardState() + self.backward_state_var = self.new_var() + return self.backward_state_proxy + + # This gets its own helper function so guards DEBUG logs are more informative + def init_ambient_guards(self): + # Register a SHAPE_ENV guard to make sure we setup shape guards + # that show up in ShapeEnv + self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS) + ) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE)) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) + ) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH)) + + def add_cleanup_hook(self, fn: Callable[[], Any]): + self.cleanup_hooks.append(fn) + + def call_cleanup_hooks(self): + for hook in reversed(self.cleanup_hooks): + hook() + self.cleanup_hooks.clear() + + @property + def root_tracer(self): + return self.tracers[0] + + @property + def current_tracer(self): + return self.tracers[-1] + + def 
is_root_tracer(self): + # Helper to tell if we are inside the higher order operator tracing. + return len(self.tracers) == 1 + + @property + def graph(self): + return self.current_tracer.graph + + # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. + @graph.setter + def graph(self, value): + self.current_tracer.graph = value + + @property + def input_name_to_proxy(self): + return self.current_tracer.input_name_to_proxy + + @property + def real_value_cache(self): + return self.current_tracer.real_value_cache + + # If you are here, and you're looking for create_graph_input, + # to avoid ambiguity, please call one of the following: + # - self.current_tracer.create_graph_input + # - self.root_tracer.create_graph_input + # See NOTE [HigherOrderOperator tracing design] for more context. + + def create_proxy(self, *args, **kwargs): + return self.current_tracer.create_proxy(*args, **kwargs) + + def create_node(self, *args, **kwargs): + return self.current_tracer.create_node(*args, **kwargs) + + def remove_node(self, *args, **kwargs): + return self.current_tracer.remove_node(*args, **kwargs) + + @contextlib.contextmanager + def subtracer(self, source_target, prior_tracer): + new_scope_ctx = enter_new_scope() + try: + if prior_tracer: + # Lineage MUST stay preserved + assert prior_tracer.parent is self.current_tracer + new_scope_ctx.__enter__() + tracer = ( + prior_tracer + if prior_tracer + else SubgraphTracer( + self, parent=self.current_tracer, source_target=source_target + ) + ) + self.tracers.append(tracer) + yield tracer + finally: + new_scope_ctx.__exit__(None, None, None) + self.tracers.pop() + + @property + def output(self): + return self + + @property + def fake_mode(self): + return self.tracing_context.fake_mode + + @property + def shape_env(self): + return self.tracing_context.fake_mode.shape_env + + @property + def guards(self) -> torch._guards.GuardsSet: + return self.tracing_context.guards_context.dynamo_guards + + @property + def nn_modules(self) -> Dict[str, Any]: + return self.tracing_context.module_context.nn_modules + + def save_global_state(self, out=None): + """ + Saves to out if it is provided. Else saves to the tracing context's global_state. + """ + global_state = ( + out if out is not None else self.tracing_context.global_context.global_state + ) + + # TODO - Consider having a torch level API for torch_function_state. As + # of now, we create a ref cycle by passing the + # output.set_torch_function_state to + # output.tracing_context.global_context.global_state. In the interim, + # the problem can be solved by manually set + # output.tracing_context.global_context.global_state to None at cleanup. 
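+        # Each entry below is a (setter_fn, current_value) pair, e.g.
+        #   global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
+        # so the saved state can later be reapplied by calling setter_fn(current_value).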
+ global_state["torch_function_enabled"] = ( + self.set_torch_function_state, + self.torch_function_enabled, + ) + global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) + global_state["autocast_enabled"] = ( + torch.set_autocast_enabled, + torch.is_autocast_enabled(), + ) + global_state["autocast_cpu_enabled"] = ( + torch.set_autocast_cpu_enabled, + torch.is_autocast_cpu_enabled(), + ) + global_state["autocast_gpu_dtype"] = ( + torch.set_autocast_gpu_dtype, + torch.get_autocast_gpu_dtype(), + ) + global_state["autocast_cpu_dtype"] = ( + torch.set_autocast_cpu_dtype, + torch.get_autocast_cpu_dtype(), + ) + global_state["autocast_cache_enabled"] = ( + torch.set_autocast_cache_enabled, + torch.is_autocast_cache_enabled(), + ) + + def push_tx(self, tx): + self._current_tx.append(tx) + + def pop_tx(self): + return self._current_tx.pop() + + @property + def current_tx(self): + return self.root_tx if not self._current_tx else self._current_tx[-1] + + def copy_graphstate(self) -> OutputGraphState: + """Create a checkpoint of the current state by copying everything""" + assert self.param_name_to_source is not None + guards_graph_state = self.tracing_context.guards_context.copy_graphstate() + module_state = self.tracing_context.module_context.copy_graphstate() + global_state = self.tracing_context.global_context.copy_graphstate() + state = OutputGraphState( + dict(self.input_source_to_var), + list(self.tracked_fakes), + guards_graph_state, + module_state, + list(self.register_finalizer_fns), + global_state, + dict(self.param_name_to_source), + self.side_effects.clone(), + self.timestamp, + set(self.non_compliant_ops), + set(self.compliant_custom_ops), + ) + self.timestamp += 1 + return state + + def restore_graphstate(self, state: OutputGraphState): + """Restore a checkpoint created by self.copy_graphstate()""" + ( + self.input_source_to_var, + self.tracked_fakes, + guards_state, + module_state, + self.register_finalizer_fns, + global_state, + self.param_name_to_source, + self.side_effects, + self.timestamp, + self.non_compliant_ops, + self.compliant_custom_ops, + ) = state + self.tracing_context.guards_context.restore_graphstate(guards_state) + self.tracing_context.module_context.restore_graphstate(module_state) + self.tracing_context.global_context.restore_graphstate(global_state) + + # FX deepcopy doesn't work for a partially created graph, so just remove new nodes + removed_nodes = 0 + for node in reversed(list(self.graph.nodes)): + if ( + node.meta["creation_timestamp"] > self.timestamp + # placeholders here may have been lazily added by existing objects + and node.op != "placeholder" + ): + # Erasing node alone does not remove the meta information + # So, remove the help tensor explicitly + if "example_value" in node.meta: + del node.meta["example_value"] + self.remove_node(node) + self.real_value_cache.pop(node, None) + removed_nodes += 1 + log.debug("restore_graphstate: removed %s nodes", removed_nodes) + + def add_symbol_bindings(self, arg: GraphArg): + # Insert implicit size vars as necessary. With dynamic shapes, we + # maintain the invariant that every sizevar gets a direct SymInt input + # into the graph. This means downstream graph transforms can assume + # every size variable is explicitly bound and accessible, instead of + # having to pull it out implicitly from tensors. 
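+        # Illustrative example (source names hypothetical): if a graph input was
+        # faked with size (s0, 3), bind_symint() below adds a SymInt placeholder
+        # for s0 sourced from TensorPropertySource(<input source>, TensorProperty.SIZE, 0),
+        # so later passes can use s0 directly instead of re-deriving it from .size(0).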
+ + if self.export: + return + + assert arg.fake_tensor is not None + + def bind_symint(s, prop): + if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): + return + s0 = s.node.expr + if s0 in self.bound_symbols: + return + self.bound_symbols.add(s0) + log.debug("bind_symint %s %s", s, prop.name()) + # TODO: don't readd symint if we already have it in graph + # (this is harmless because we do remove the unused ones later) + proxy = self.root_tracer.create_graph_input( + str(s0), + torch.SymInt, + before=True, + source=prop, + ) + proxy.node.meta["example_value"] = s + proxy.node.meta["grapharg"] = GraphArg( + prop, + s, + is_unspecialized=False, + fake_tensor=None, + is_tensor=False, + ) + + def handle_tensor(t, src): + for i, s in enumerate(t.size()): + bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i)) + for i, s in enumerate(t.stride()): + bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i)) + bind_symint( + t.storage_offset(), + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), + ) + if is_traceable_wrapper_subclass(t): + attrs, ctx = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + handle_tensor(inner_t, AttrSource(src, attr)) + + handle_tensor(arg.fake_tensor, arg.source) + + def count_calls(self): + return count_calls(self.graph) + + def is_empty_graph(self): + return len(list(self.graph.nodes)) == 0 + + def get_submodule(self, keys): + assert keys + obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules + for k in keys.split("."): + if isinstance(obj, dict): + obj = obj[k] + else: + obj = getattr(obj, k) + return obj + + def new_var(self, name="tmp"): + existing = set(self.code_options["co_varnames"]) + for i in itertools.count(): + var = f"{name}_{i}" + if var not in existing: + self.code_options["co_varnames"] += (var,) + return var + + def update_co_names(self, name): + """Ensure self.code_options.co_names contains name""" + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + + @staticmethod + def module_key_name(*names): + # create a new unique name + name = "_".join(map(str, names)) + # Strip the guard lookup L/G access + name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) + # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv + name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) + # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + + if not name or not name[0].isalpha(): + name = "sub" + name + + return name + + def register_attr_or_module( + self, + target: Union[torch.nn.Module, torch.Tensor, Any], + *names, + **options, + ): + if is_dynamic_nn_module(target): + return variables.UnspecializedNNModuleVariable(target, **options) + + options = dict(options) + assert "source" in options + source = options["source"] + assert not isinstance(source, ParamBufferSource) + + if isinstance(target, torch.Tensor): + tracer = self.current_tracer + if not self.is_root_tracer(): + # For higher order ops, we don't want to insert the get_attr in + # innermost graph. Instead, we want to raise the params/buffers + # as inputs to the higher-order graph, and register them as + # get_attrs in the root tracer. + + # Note that Dynamo will still call lift_tracked_freevar_to_input + # when these inputs are encountered for the inner graph. The + # only difference is what happens at the root tracer for + # nn.Parameters vs free inputs. 
The free inputs are registered + # as placeholders in the root graph, whereas the nn.Parameters + # are registered as get_attr nodes in the root graph. + tracer = self.root_tracer + + if not is_constant_source(source): + install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) + + if get_static_address_type(target) == "guarded": + install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH)) + + def wrap_name(module_key): + assert self.param_name_to_source is not None + self.param_name_to_source[module_key] = source + + return wrap_fx_proxy( + self.root_tx, + tracer.create_proxy("get_attr", module_key, tuple(), {}), + example_value=target, + **options, + ) + + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) + + install_guard(source.make_guard(GuardBuilder.NN_MODULE)) + + def wrap_name(module_key): + return NNModuleVariable(type(target), module_key, target, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + + def wrap_name(module_key): + return SymNodeVariable.create( + self, + self.create_proxy("get_attr", module_key, tuple(), {}), + sym_num=target, + **options, + ) + + # HACKY CODE REGION END + else: + + def wrap_name(module_key): + self.output.update_co_names(module_key) + self.global_scope[module_key] = target + return VariableBuilder(self, ConstantSource(source_name=module_key))( + target + ) + + for k, v in self.nn_modules.items(): + if v is target: + # it already exists + return wrap_name(k) + + name = OutputGraph.module_key_name(*names) + + base = name + for i in itertools.count(): + if name not in self.nn_modules: + self.nn_modules[name] = target + if isinstance(target, torch.nn.Module): + + def register_leaf_name(leaf_name): + assert self.param_name_to_source is not None + new_source = ParamBufferSource(source, leaf_name) + new_name = f"{name}.{leaf_name}" + self.param_name_to_source[new_name] = new_source + if isinstance(source, LocalSource): + self.dynamo_flat_name_to_original_fqn[ + OutputGraph.module_key_name(new_source.name()) + ] = leaf_name + + # annoying, but there are cases when we do not have parameters + # see test_nn_moduledict_contains + if hasattr(target, "_parameters"): + for leaf_name, _ in target.named_parameters(): + register_leaf_name(leaf_name) + if hasattr(target, "_buffers"): + for leaf_name, _ in target.named_buffers(): + register_leaf_name(leaf_name) + + return wrap_name(name) + name = f"{base}_{i}" + + raise AssertionError("unreachable") + + def compile_subgraph( + self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None + ): + """ + Generate a subgraph to continue execution on user code. + Automatically restore live variables. 
+ """ + assert reason is not None + + from .decorators import disable + + self.partial_convert = partial_convert + self.compile_subgraph_reason = reason + self.should_exit = True + + log.debug("COMPILING GRAPH due to %s", reason) + + if not all(block.can_restore() for block in tx.block_stack): + unimplemented("compile_subgraph with block_depth != 0") + + prefix_insts: List[Instruction] = [] + if sys.version_info >= (3, 11): + # prefix instructions (Python 3.11+) + for inst in tx.prefix_insts: + if inst.opname == "MAKE_CELL": + prefix_insts.append( + create_instruction("MAKE_CELL", argval=inst.argval) + ) + elif inst.opname == "COPY_FREE_VARS": + prefix_insts.append( + create_instruction( + "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) + ) + ) + else: + prefix_insts.append(copy.copy(inst)) + assert not ( + self.pregraph_bytecode and self.export + ), "export does not support pregraph_bytecode" + prefix_insts.extend(self.pregraph_bytecode) + + def append_prefix_insts(): + self.add_output_instructions(prefix_insts) + prefix_insts.clear() + + for block in reversed(tx.block_stack): + block.exit(tx) + + self.cleanup_graph() + tx.prune_dead_locals() + stack_values = list(tx.stack) + # Use nn.Module "proxies" in the constructed GraphModule so that + # the resulting GM does not hold additional strong references to the original modules. + # This prevents a strong ref cycle where Dynamo created code holds on to references + # to modules that also have Dynamo code cache invalidation checks. + # When cache invalidation runs, the generated GM will be invalidated, which also deletes + # the proxies. + nn_modules_proxies = { + name: nn_module_proxy(mod) for name, mod in self.nn_modules.items() + } + root = FakeRootModule(nn_modules_proxies) + # Add all the local vars to the "stack" so restore at the end + restore_vars = [] + val_to_names: Dict[VariableTracker, List[str]] = {} + if stack_values: + val_to_names[stack_values[-1]] = list() + # NB: Typically (i.e., for graph compile from RETURN_VALUE), + # symbolic_locals will be empty at this point, as prune_dead_locals + # will clear out all of symbolic_locals because RETURN_VALUE is the + # last instruction and no more locals are used. The fanciness here + # is only needed for partial graphs. + for k, v in tx.symbolic_locals.items(): + # Note! this explicitly uses .local_name for matching + # Failure to do so will cause spurious registrations in val_to_names. + # This will in turn result in spurious variables showing up in the graph. + # This was very tricky to debug. 
For an example, dump the graph at call_user_compiler + # while running test_subgraphs.py + if isinstance(v.source, LocalSource) and v.source.local_name == k: + continue # no need to restore initial state + if v not in val_to_names: + val_to_names[v] = list() + val_to_names[v].append(k) + for v in val_to_names.keys(): + restore_vars.extend(val_to_names[v]) + stack_values.extend([v] * len(val_to_names[v])) + + # to handle random calls + if len(self.random_calls) > 0: + append_prefix_insts() + random_calls_instructions = [] + self.random_values_var = self.new_var("random_values") + rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) + rand_fn_name = self.install_global("__gen_rand_values", rand_fn) + codegen = PyCodegen(tx, root) + random_calls_instructions.extend( + codegen.load_function_name(rand_fn_name, True) + ) + random_calls_instructions.extend(create_call_function(0, False)) + random_calls_instructions.append( + codegen.create_store(tx.output.random_values_var), + ) + self.add_output_instructions(random_calls_instructions) + + if ( + stack_values + and all( + not isinstance( + v, + ( + UnspecializedPythonVariable, + NumpyNdarrayVariable, + TensorWithTFOverrideVariable, + ), + ) + for v in stack_values + ) + and all(isinstance(x, TensorVariable) for x in stack_values) + and len(set(stack_values)) == len(stack_values) + and self.side_effects.is_empty() + and not len(tx.debug_locals) != 0 + and not self.backward_state + ): + append_prefix_insts() + # optimization to generate better code in a common case + self.add_output_instructions( + self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) + + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] + ) + else: + graph_output_var = self.new_var("graph_out") + pass1 = PyCodegen(tx, root, graph_output_var) + self.codegen_suffix(tx, stack_values, pass1) + + # one more time now that we have established tempvars + pass2 = PyCodegen( + tx, + root, + graph_output_var, + tempvars={val: None for val, count in pass1.uses.items() if count > 1}, + ) + self.codegen_suffix(tx, stack_values, pass2) + + output = [] + if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: + output.extend( + self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) + ) + + if len(pass2.graph_outputs) != 0: + output.append(pass2.create_store(graph_output_var)) + else: + output.append(create_instruction("POP_TOP")) + append_prefix_insts() + self.add_output_instructions(output + pass2.get_instructions()) + + # restore all the live local vars + self.add_output_instructions( + [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] + ) + + def codegen_suffix(self, tx, stack_values, cg): + if self.backward_state: + assert not self.export + for name, val in self.backward_state.items(): + cg(val) + cg.append_output(cg.create_load(self.backward_state_var)) + cg.store_attr(name) + self.side_effects.codegen_hooks(cg) + self.side_effects.codegen_save_tempvars(cg) + + # Return variables used for logging at the end + for debug_var, args in tx.debug_locals: + cg(debug_var) + for arg in args: + cg(arg) + cg.extend_output(create_call_function(len(args), True)) + + cg.restore_stack(stack_values, value_from_source=not tx.export) + self.side_effects.codegen_update_mutated(cg) + + def cleanup_graph(self): + """ + Remove "creation_timestamp" from node meta + + Remove this pattern from the graph: + torch._C._set_grad_enabled(False) + torch._C._set_grad_enabled(True) + """ + assert self.should_exit + nodes = list(self.graph.nodes) 
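+        # Two passes follow: (1) strip the bookkeeping "creation_timestamp"
+        # meta entry from every node, and (2) erase adjacent
+        # _set_grad_enabled(False) / _set_grad_enabled(True) pairs that toggle
+        # grad mode and immediately restore it without any node in between.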
+ for node in nodes: + node.meta.pop("creation_timestamp", None) + + grad_enabled = torch.is_grad_enabled() + for node1, node2 in zip(nodes, nodes[1:]): + if ( + node1.target is torch._C._set_grad_enabled + and tuple(node1.args) == (not grad_enabled,) + and not node1._erased + ): + grad_enabled = node1.args[0] + if ( + node2.target is torch._C._set_grad_enabled + and tuple(node2.args) == (not grad_enabled,) + and not node2._erased + ): + grad_enabled = node2.args[0] + self.graph.erase_node(node1) + self.graph.erase_node(node2) + + def get_graph_sizes_structured(self): + ret = {} + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] + return ret + + def get_graph_sizes(self, name: str): + graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" + graph_sizes_str += f"===== {name} =====\n" + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + graph_sizes_str += f"{node.name}: {tuple(size)}\n" + concrete_size = [] + has_symint = False + for sz in size: + if isinstance(sz, int): + concrete_size.append(sz) + elif isinstance(sz, torch.SymInt): + has_symint = True + concrete_size.append(sz.node.hint) + else: + break + else: + if has_symint: + graph_sizes_str += ( + f"{node.name} (concrete): {tuple(concrete_size)}\n" + ) + return graph_sizes_str + + @contextlib.contextmanager + def restore_global_state(self): + """ + Momentarily restores the global state to what it was prior to tracing the current output + """ + prior_global_state = self.tracing_context.global_context.copy_graphstate() + current_global_state: Dict[str, Tuple[Any, bool]] = {} + self.save_global_state(out=current_global_state) + try: + # Set to state prior to tracing the graph + self.tracing_context.global_context.restore_graphstate(prior_global_state) + yield + finally: + # Reset to state at the current time (e.g. before calling the user compiler) + self.tracing_context.global_context.restore_graphstate( + GlobalContextCheckpointState(current_global_state) + ) + + @torch._guards.TracingContext.clear_frame() + def compile_and_call_fx_graph(self, tx, rv, root): + """ + Generate code from self.graph and return the Instruction()s to + call that generated code. 
+ """ + from .decorators import disable + + assert self.should_exit + + name = unique_id("__compiled_fn") + + assert isinstance(rv, list) + assert isinstance(root, FakeRootModule) + self.create_node( + "output", + "output", + (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), + {}, + ) + self.insert_deferred_runtime_asserts(root, name) + # NB: deferred runtime asserts can keep graphargs live, so make sure + # those are inserted before pruning + self.remove_unused_graphargs() + ncalls = count_calls(self.graph) + counters["stats"]["calls_captured"] += ncalls + + # free a bit of memory + self.real_value_cache.clear() + + gm = _make_graph_module(root, self.graph) + for register_finalizer in self.register_finalizer_fns: + register_finalizer(gm) + + gm.compile_subgraph_reason = self.compile_subgraph_reason + gm.meta[ + "dynamo_flat_name_to_original_fqn" + ] = self.dynamo_flat_name_to_original_fqn.copy() + + graph_code_log.debug("%s", lazy_format_graph_code(name, gm)) + torch._logging.trace_structured( + "dynamo_output_graph", + lambda: {"sizes": self.get_graph_sizes_structured()}, + payload_fn=lambda: gm.print_readable(print_output=False), + ) + graph_tabular_log.debug("%s", lazy_format_graph_tabular(name, gm)) + graph_sizes_log.debug("%s", LazyString(lambda: self.get_graph_sizes(name))) + self.call_cleanup_hooks() + old_fake_mode = self.tracing_context.fake_mode + if not self.export: + # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + backend_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=old_fake_mode.shape_env, + ) + # TODO(voz): Ostensibily, this should be scoped and + # restore back to old_fake_mode, but doing so currently violates + # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode + self.tracing_context.fake_mode = backend_fake_mode + + with self.restore_global_state(): + compiled_fn = self.call_user_compiler(gm) + compiled_fn = disable(compiled_fn) + + counters["stats"]["unique_graphs"] += 1 + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, compiled_fn) + + cg = PyCodegen(tx) + cg.make_call_generated_code(name) + return cg.get_instructions() + + @property + def placeholders(self) -> List[fx.Node]: + r = [] + for node in self.graph.nodes: + if node.op == "placeholder": + r.append(node) + continue + break + return r + + @property + def graphargs(self) -> List[GraphArg]: + return [node.meta["grapharg"] for node in self.placeholders] + + @dynamo_timed(phase_name="backend_compile") + def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: + assert self.compiler_fn is not None + tot = 0 + placeholders = [] + for node in gm.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + tot += 1 + if node.op == "placeholder": + placeholders.append(node) + increment_op_count(tot) + for pl in placeholders: + arg = pl.meta["grapharg"] + # TODO: Why isn't this stored in meta :think: + pl._dynamo_source = arg.source + + gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] + gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] + + try: + name = ( + self.compiler_fn.__name__ + if hasattr(self.compiler_fn, "__name__") + else "" + ) + _step_logger()(logging.INFO, f"calling compiler function {name}") + compiler_fn = self.compiler_fn + if config.verify_correctness: + compiler_fn = WrapperBackend(compiler_fn) + compiled_fn = compiler_fn(gm, self.example_inputs()) + 
_step_logger()(logging.INFO, f"done compiler function {name}") + assert callable(compiled_fn), "compiler_fn did not return callable" + except exceptions_allowed_to_be_fallback as e: + if self.has_user_defined_allowed_in_graph: + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None + msg = ( + "Backend compiler failed with a fake tensor exception at \n" + f"{self.root_tx.format_frame_summary()}" + "Adding a graph break." + ) + unimplemented_with_warning(e, self.root_tx.f_code, msg) + except SkipFrame as e: + # The backend compiler has requested that we skip the frame, instead of + # aborting execution. + raise e + except Exception as e: + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None + + signpost_event( + "dynamo", + "OutputGraph.call_user_compiler", + { + **self.co_fields, + "op_count": tot, + "node_count": len(gm.graph.nodes), + "input_count": len(placeholders), + }, + ) + + return compiled_fn + + def example_inputs(self) -> List[torch.Tensor]: + result = [] + for arg in self.graphargs: + result.append(arg.example) + return result + + def remove_unused_graphargs(self) -> None: + assert self.should_exit + # Miniature DCE pass, but only for obviously trivial operations + for node in reversed(list(self.graph.nodes)): + if len(list(node.users)) == 0: + if node.op == "get_attr": + self.remove_node(node) + elif node.op == "call_function" and node.target is operator.getitem: + self.remove_node(node) + + def placeholder_binds_symbol(node): + arg = node.meta["grapharg"] + example = arg.example + if isinstance(example, torch.SymInt) and isinstance( + example.node.expr, sympy.Symbol + ): + return example.node.expr + return None + + def remove_unused(node): + log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) + # I'm not really sure why you need to delete these from the + # node since the node is going to get removed + del node.meta["grapharg"] + self.remove_node(node) + self.real_value_cache.pop(node, None) + + used_symbols = set() + recheck_placeholders = [] + for node in self.placeholders: + binds_symbol = placeholder_binds_symbol(node) is not None + # Don't delete symbol bindings yet + if binds_symbol: + if not node.users: + recheck_placeholders.append(node) + else: + if not node.users and not isinstance( + node.meta["grapharg"], BackwardStateGraphArg + ): + remove_unused(node) + else: + # Register the free symbols as uses + arg = node.meta["grapharg"] + if isinstance(arg, BackwardStateGraphArg): + continue + fake = ( + arg.fake_tensor if arg.fake_tensor is not None else arg.example + ) + used_symbols |= free_symbols(fake) + + # After removing unused graphargs, prune unused binds_symbol + for node in recheck_placeholders: + symbol = placeholder_binds_symbol(node) + if symbol is not None: + if symbol not in used_symbols: + remove_unused(node) + else: + # Make sure we delete later occurrences of the same symbol + used_symbols.remove(symbol) + + # TODO: this is a generic pass that should live outside of Dynamo + def insert_deferred_runtime_asserts(self, root, name) -> None: + """ + During tracing, we may have discovered that some data-dependent values + had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime + that x.item() >= 0. This asserts can happen unpredictably during fake + tensor propagation, so we cannot conveniently insert them into the FX graph + when they occur. 
Instead, we accumulate them in the ShapeEnv, and in this + pass insert them into the graph as proper tests. + """ + # TODO: Request simplification on runtime asserts before emitting them + ras_by_symbol = self.shape_env.deferred_runtime_asserts.copy() + + if not any(ras for ras in ras_by_symbol.values()): + return + + gm = fx.GraphModule(root, self.graph) + graph_code_log.debug( + "%s", + lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm), + ) + + # We are going to mutate the dict + symbol_to_proxy = {} + placeholders = set() + last_placeholder = None + for node in self.graph.nodes: + if node.op != "placeholder": + last_placeholder = node + break + placeholders.add(node) + assert last_placeholder is not None + + # Identify what symbols we need to reify. This isn't strictly needed + # but helps reduce churn on the graph + needed_symbols: Set[sympy.Symbol] = set() + for ras in ras_by_symbol.values(): + for ra in ras: + needed_symbols.update(free_symbols(ra.expr)) + + log.debug("needed_symbols = %s", needed_symbols) + + for node in self.graph.nodes: + # Placeholders can match symbols, but when we destructure them + # with size we have to make sure we insert the nodes after all + # the placeholders + with self.graph.inserting_before( + node.next if node not in placeholders else last_placeholder.next + ): + if "example_value" not in node.meta: + continue + + defs = [] + + # For every new unbacked symbol, we need an fx.Node representing + # precisely this value. There are a few places where the unbacked + # symbol could have come from, and we will check them to setup + # these nodes. + # + # For a case like item(), this is trivial (no new node is added.) + # + # For nonzero(), we need to add something like i0 = out.size(0) + # + # We could end up with duplicate nodes this way but it is not a + # big deal. + # + # We also do this to setup backed SymInts, but those are all going + # to be matched from placeholders + def match_symbol(symint, cb): + if ( + isinstance(symint, torch.SymInt) + and isinstance(symint.node, SymNode) + and isinstance(s := symint.node.expr, sympy.Symbol) + and s not in symbol_to_proxy + and s in needed_symbols + ): + symbol_to_proxy[s] = fx.Proxy(cb()) + log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s]) + defs.append(s) + + match_symbol(node.meta["example_value"], lambda: node) + if isinstance(t := node.meta["example_value"], torch.Tensor): + for i, s in enumerate(t.size()): + match_symbol( + s, lambda: self.graph.call_method("size", (node, i)) + ) + for i, s in enumerate(t.stride()): + match_symbol( + s, lambda: self.graph.call_method("stride", (node, i)) + ) + match_symbol( + t.storage_offset(), + lambda: self.graph.call_method("storage_offset", (node,)), + ) + + for i0 in defs: + ras = ras_by_symbol.pop(i0, []) + # Before we perform any asserts, first apply range + # refinement. This is important, because if we are going + # to retrace the graph (and we typically are if we send + # the graph to AOTAutograd), we need to make sure we apply + # range refinement (ala _check_is_size) first, BEFORE we + # run any of the asserts. Otherwise, we may decide to + # perform substitutions based on the asserts which we then + # can't back out, because value ranges can only be applied + # to asserts.) + # + # A perhaps better long term plan is to avoid this order + # dependence by making it possible to refine ranges on + # arbitrary expressions, not just symbols. 
But it is not + # so easy to make use of this information, see + # https://twitter.com/ezyang/status/1745801370299482492 + # We actually made an attempt at this in + # https://github.com/pytorch/pytorch/pull/119043 + # which didn't work. + # + # Another ideas for how to do this: + # - Have bound_sympy be the source of truth of the ranges of any expression + # - Cache intermediate results for every subexpression of bound_sympy + # - This cache should be possible to edit to refine ranges + # + # One issue with this proposal is that if + # we have a bound on 2x, we are not going to be able to + # apply it for 4x. Similarly, we may have bounds for an + # equivalent expression that we are not applying because + # it's not a perfect match (e.g. x < y vs y > x)". + # + # The first issue we already have it and it's impossible + # to solve in general, so any implementation on a best + # effort basis should do. + # + # The second issue is a preexisting one. It can be mitigated + # with a normalisation algorithm. In general, it may also + # be on a best effort basis, but since our grammar is not + # terribly difficult, chances are we could even fully + # normalise SymPy expressions... who knows. + + if i0 in self.shape_env.size_like: + self.graph.call_function( + torch._check_is_size, (symbol_to_proxy[i0].node,) + ) + + vr = self.shape_env.var_to_range[i0] + if not self.shape_env._default_unspecified_value_range().issubset( + vr + ): + # The runtime range is constrained, so add a runtime + # assert and also explicitly refine the range + # (refinement should not be necessary once runtime + # asserts cause refinement, but that's NYI) + def convert(s): + try: + return int(s) + except TypeError: + return None + + self.graph.call_function( + torch._constrain_as_value, + ( + symbol_to_proxy[i0].node, + convert(vr.lower), + convert(vr.upper), + ), + ) + + for ra in ras: + log.debug("inserting runtime assert %s", ra.expr) + # Need to process ALL free symbols, not just unbacked ones + fvs = free_symbols(ra.expr) + missing = fvs - symbol_to_proxy.keys() + if missing: + i1 = sorted(missing)[0] + # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689 + # assert self.shape_env.is_unbacked_symint(i1), i1 + ras_by_symbol.setdefault(i1, []).append(ra) + else: + # Convert the sympy expression into a sequence of FX + # nodes + res = sympy_interp( + PythonReferenceAnalysis, symbol_to_proxy, ra.expr + ).node + self.graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Deferred runtime assertion failed {ra.expr}", + ), + ) + + def add_output_instructions(self, prefix: List[Instruction]) -> None: + """ + We call this on the creation of a new compiled subgraph that is inserted + before user code. + """ + self.output_instructions.extend(prefix) + self.should_exit = True + + def install_global_unsafe(self, name, value) -> None: + """ + WARNING: prefer the safer `install_global_by_id/install_global`. + torch.compile instances should be independent of each other; + one footgun is to have one instance depend on the existence of + a global installed by another instance. This can happen if we mangle + a global the same way across both instances. 
+ """ + assert name not in self.installed_globals + self.installed_globals.add(name) + self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) + + def install_global_by_id(self, prefix, value) -> str: + """ + Installs a global if it hasn't been installed already. + This is determined by (prefix, id(value)) pair. + + Returns the name of the newly installed global. + """ + # NB: need self.compile_id to distinguish this global + # from another global created in a different torch.compile instance + name = f"{prefix}_{id(value)}_c{self.compile_id}" + if name in self.installed_globals: + return name + self.install_global_unsafe(name, value) + return name + + def install_global(self, prefix, value) -> str: + """ + Installs a global, generating a unique name for it. + + Returns the name of the newly installed global. + """ + # NB: unique_id is unique, even across torch.compile instances + name = unique_id(prefix) + self.install_global_unsafe(name, value) + return name + + def cleanup(self) -> None: + # There is a reference cycle between tracer and OutputGraph, causing + # some of the tensor objects to be held alive for longer than necessary. + self.root_tx = None + self.nn_modules.clear() + self.param_name_to_source = None + + for node in self.graph.nodes: + if "grapharg" in node.meta: + del node.meta["grapharg"] + self.real_value_cache.clear() + self.input_name_to_proxy.clear() + self.side_effects.clear() + self.register_finalizer_fns.clear() + self.dynamo_flat_name_to_original_fqn.clear() + self.tracing_context.clear() + + def set_torch_function_state(self, enabled: bool) -> None: + self.torch_function_enabled = enabled + + def add_graph_finalizer( + self, register_finalizer: Callable[[fx.GraphModule], None] + ) -> None: + self.register_finalizer_fns.append(register_finalizer) + + def example_value_from_input_node(self, node: torch.fx.Node): + """Extract the non-fake example tensor""" + if node.op == "placeholder": + return node.meta["grapharg"].example + assert node.op == "get_attr" + return self.nn_modules[node.target] # type: ignore[index] + + +err_epilogue = ( + "With the current config, we will graph break " + "(and fall back to eager-mode PyTorch) on all ops " + "that have do not have the 'pt2_compliant_tag'. " + "Please see the following doc for how to mark this op as PT2 compliant " + "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ" +) + + +def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): + if kind != "call_function": + return + + def encountered_compliant_op(target): + if target.namespace in {"prim", "prims", "aten"}: + return + output_graph.compliant_custom_ops.add(target) + + def encountered_non_compliant_op(target, msg): + output_graph.non_compliant_ops.add(target) + if config.only_allow_pt2_compliant_ops: + unimplemented(msg + " " + err_epilogue) + + if isinstance(target, torch._ops.OpOverload): + if torch.Tag.pt2_compliant_tag in target.tags: + encountered_compliant_op(target) + return + encountered_non_compliant_op( + target, + f"Encountered the torch.ops.OpOverload {target} " + f"that is not PT2 compliant.", + ) + return + + if isinstance(target, torch._ops.OpOverloadPacket): + overloads = tuple(target.overloads()) + # Optimization: Overload resolution is expensive. + # If there's only one overload, we know what it will resolve to. 
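+        # e.g. a packet with exactly one registered overload (say torch.ops.mylib.foo
+        # exposing only "default"; the name is hypothetical) resolves directly via
+        # getattr below, skipping the _jit_resolve_packet call further down.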
+ if len(overloads) == 1: + op = getattr(target, overloads[0]) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + return + encountered_non_compliant_op( + op, + f"Encountered the non-overloaded " + f"torch.ops.OpOverloadPacket {target} " + f"that is not PT2 compliant. ", + ) + return + + args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( + output_graph.current_tx, (args, kwargs), False + ) + try: + overload = torch._C._jit_resolve_packet( + target._qualified_op_name, *args, **kwargs + ) + except RuntimeError as e: + unimplemented(str(e)) + + op = getattr(target, overload) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + else: + encountered_non_compliant_op( + op, + f"Encountered the torch.ops.OpOverloadPacket {target} " + f"which resolves to the overload ({overload}) that is " + f"not PT2 compliant.", + ) + + +_compile_id_counter = itertools.count() + + +class SubgraphTracer(fx.Tracer): + """ + Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer + and the separation of responsibilities is that SubgraphTracer is + responsible for building the graph while OutputGraph is responsible for + compiling and executing the graph. + """ + + def __init__( + self, output_graph, parent=None, export_root=False, source_target=None + ): + super().__init__() + self.output_graph = weakref.proxy(output_graph) + self.graph = torch.fx.Graph() + + # The export is only ever set for the ROOT tracer. It controls + # whether or not certain inputs are allowed to be added or not. + # Look at call sites of create_graph_input to see how it is used. + if export_root: + assert parent is None + self.export_root = export_root + # Map from graph input name to its placeholder proxy object, where the + # map's keys give all current placeholder node names and can be used to + # create unique node names + self.input_name_to_proxy: Dict[str, fx.Proxy] = {} + # Node => computed real value (see utils.get_real_value) + self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} + + # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] + self.parent = parent + # A dict mapping previously free variables (Proxy objects) + # to new Proxy objects that wrap inputs to this subgraph. + # + # This dict serves two purposes: + # - Proxies are associated with VariableTrackers. If we see + # the same VariableTracker twice (and it is a free variable), + # then we want to use the same Proxy in the current subgraph to + # record the tracing. + # - If we are tracing a HigherOrderOperator's body_fn, then we + # need to keep track of what free variables were lifted so we can + # rewrite the HigherOrderOperator call using the traced body_fn. + # Dicts maintain the order of args for the HigherOrderOperator call. + self.lifted_freevars = {} + self.prev_inst = None + + self._cur_code = None + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + # Each SubgraphTracer is associated with a source target, which indicates + # which operator this subgraph is attached to. We compute a source_fn_stack + # based on the source target. For the root tracer, it's set to []. + # This is useful for debugging and transforming the exported graph. 
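+        # Illustrative shape of the stack: the root tracer keeps [], while a
+        # tracer nested under a higher-order op accumulates entries of the form
+        #   parent.source_fn_stack + [("<stringified target>", source_target)]
+        # as set up just below.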
+ if self.parent is None: + self.source_fn_stack = [] + else: + self.source_fn_stack = self.parent.source_fn_stack + [ + (self.graph._target_to_str(source_target), source_target) + ] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + # NOTE: [Nested SubgraphTracer and free_variable handling] + # -------------------------------------------------------- + # Read NOTE [HigherOrderOperator tracing design] first. + # + # Let's say we're in the middle of introspecting the body of a possibly + # nested HigherOrderOperator, and we see a free variable. + # + # There are two cases: + # 1. We see a free variable that is already tracked by Dynamo. + # 2. We see a free variable that has not been tracked by Dynamo + # + # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below) + # which will lift the freevar to be an input of this subgraph + # and also recursively lift it to be an input on the parent(s). + # + # In case 2, before the call to `create_proxy`, the InstructionTranslator + # will see the freevar when it gets loaded by Python bytecode. + # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or + # LOAD_GLOBAL. + # There, the InstructionTranslator asks Dynamo to begin tracking the + # freevar by building a new Variable. + # Building a new Variable automatically lifts the freevar to be an + # input of the root SubgraphTracer. + # + # The implications for the code below are: + # - We will always be in Case 1 when we get to this code. + # - Any "free variable" we encounter here is guaranteed to already be + # bound, that is, it is either a graph input of the root graph, or + # some local variable of the root graph or a subgraph. + # - The additional work we need to do here is *only* that we need to + # lift this free variable into inputs (recursively) of each nested + # higher-order-op subgraph until we hit the subgraph where the free + # variable is bound + if self.parent is not None: + flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) + new_flat_args = [] + for arg in flat_args: + maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) + new_flat_args.append(maybe_new_arg) + + args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) + + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + # append stack trace to fx node + tx = self.output_graph.current_tx + + # log detailed location of line of code in 3.11 + if sys.version_info >= (3, 11) and kind in ( + "call_function", + "call_method", + "call_module", + ): + cur_inst = tx.current_instruction + if ( + cur_inst is not self.prev_inst + and cur_inst.positions is not None + and cur_inst.positions.lineno is not None + ): + tx_code = tx.f_code + header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) + + def get_trace_call_log_str(): + line = get_instruction_source_311(tx_code, cur_inst).rstrip() + return f"TRACE FX call {rv.node.name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + self.prev_inst = cur_inst + + # update reference to original meta if we're tracing a new code object + is_retracing = False + if tx.f_code is not self._cur_code: + orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( + "orig_graphmodule", lambda: None + )() + if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): + is_retracing = True + self._orig_gm_meta = [ + nd.meta for nd in orig_graphmodule_maybe.graph.nodes + ] + 
self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map + self._orig_gm_firstlineno = ( + orig_graphmodule_maybe.forward.__code__.co_firstlineno + ) + else: + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + unimplemented("Invoking an nn.Module inside HigherOrderOperator") + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] + + # preserve original meta if it is available + if ( + self._orig_gm_meta + and self._orig_gm_lineno_map + and self._orig_gm_firstlineno + ): + lineno = tx.current_instruction.starts_line + node_idx = None + if lineno is not None: + node_idx = self._orig_gm_lineno_map.get( + lineno - self._orig_gm_firstlineno, None + ) + if node_idx is not None: + meta = self._orig_gm_meta[node_idx] + for field in fx.proxy._COPY_META_FIELDS: + if field in meta: + rv.node.meta[field] = meta[field] + if "stack_trace" in meta: + rv.node.meta["stack_trace"] = meta["stack_trace"] + + if not is_retracing: + if "nn_module_stack" not in rv.node.meta: + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if "source_fn_stack" not in rv.node.meta: + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + unimplemented( + "Invoking an nn.Module inside HigherOrderOperator" + ) + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] + + if "stack_trace" not in rv.node.meta: + frame_summaries: List[traceback.FrameSummary] = [] + while tx: + frame_summaries.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + # Reverse the frame_summaries, such that the innermost frame is at the last + frame_summaries.reverse() + + # official from_list stub doesn't have new-style type + msgs = traceback.StackSummary.from_list(frame_summaries).format() + rv.node.stack_trace = "".join(msgs) + + return rv + + def create_node( + self, op, target, args=None, kwargs=None, name=None, type_expr=None + ): + check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) + if self.parent is not None: + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + for arg in flat_args: + if not isinstance(arg, torch.fx.Node): + continue + assert ( + arg.graph == self.graph + ), "create_node using arg not from this SubgraphTracer" + + node = super().create_node(op, target, args, kwargs, name, type_expr) + node.meta["creation_timestamp"] = self.output_graph.timestamp + return node + + # Note: we did not override erase_node since + # we call self.graph.erase_node elsewhere + def remove_node(self, node): + if len(node.users) > 0: + user_graph_nodes: List[torch.fx.Node] = [] + for user in node.users.keys(): + # For the case where user.graph == self.graph, that is a real bug and will raise + # properly. + if user.graph != self.graph: + # This is a nested graph, which needs to be deleted. 
+ # If we do not do this, we will raise on attempting to remove this. + # As we only get here during restoration cleanup, this is sound. + user_graph_nodes.extend(reversed(list(user.graph.nodes))) + for other_graph_node in user_graph_nodes: + other_graph_node.graph.erase_node(other_graph_node) + self.graph.erase_node(node) + self.input_name_to_proxy.pop(node.name, None) + + # when before=True, we will insert this input before the most recent + # inserted proxy. This is a hack to get around an ordering problem, + # where we first insert a tensor argument, and then insert bindings + # for SymInts that may occur in the tensor argument. + # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets + # fixed. + def create_graph_input(self, name, type_expr=None, before=False, source=None): + log.debug( + "create_graph_input %s %s", + name, + source.name() if source is not None else "(none)", + ) + if source is None: + assert ( + self.parent is not None + ), "you are required to provide a source for inputs on the root tracer" + + # In eager, we are generally OK with adding graph inputs whenever we + # want, because we take care of writing the bytecode that knows how + # to source all the inputs. + # + # In export, this is bad, because you want a self-contained export + # object which only depends on the inputs you explicitly passed to it. + # So we are a bit more strict about what sources can become inputs + # in export + if self.export_root: + if not is_from_local_source(source, allow_cell_or_freevar=False): + self.output_graph.source_to_user_stacks.setdefault(source, []).append( + TracingContext.extract_stack() + ) + + # unique + if name in self.input_name_to_proxy: + for i in itertools.count(): + candidate_name = f"{name}_{i}" + if candidate_name not in self.input_name_to_proxy: + name = candidate_name + break + + if self.input_name_to_proxy: + prev_name = next(reversed(self.input_name_to_proxy)) + node = self.input_name_to_proxy[prev_name].node + if before: + ctx = self.graph.inserting_before(node) + else: + ctx = self.graph.inserting_after(node) + else: + ctx = self.graph.inserting_before(None) + with ctx: + proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + if self.input_name_to_proxy and before: + k, v = self.input_name_to_proxy.popitem() + self.input_name_to_proxy[name] = proxy + self.input_name_to_proxy[k] = v + else: + self.input_name_to_proxy[name] = proxy + return proxy + + # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details + def lift_tracked_freevar_to_input(self, proxy): + # You're doing something wrong if we are the root SubgraphTracer because + # Dynamo adds tensors to graph inputs before creating a proxy for them. + assert ( + self.parent is not None + ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" + # Proxys are associated with VariableTracker. + # It is possible that we've already lifted the Proxy to be an input. + # If that is the case, just return the already lifted Proxy. + if proxy in self.lifted_freevars: + return self.lifted_freevars[proxy] + new_proxy = self.create_graph_input(proxy.node.name) + new_proxy.node.meta["example_value"] = proxy.node.meta["example_value"] + self.lifted_freevars[proxy] = new_proxy + if self.parent is not None and proxy.tracer != self.parent: + self.parent.lift_tracked_freevar_to_input(proxy) + return new_proxy + + def maybe_lift_tracked_freevar_to_input(self, arg): + """ + If arg is a free variable, then lift it to be an input. 
+ Returns the new lifted arg (if arg was a freevar), else the + original arg. + """ + if not isinstance(arg, torch.fx.Proxy): + return arg + elif arg.tracer == self: + return arg + return self.lift_tracked_freevar_to_input(arg) + + +# NOTE: [HigherOrderOperator tracing design] +# Ignoring HigherOrderOperators for a moment, +# OutputGraph represents the graph being built by Dynamo that may be compiled +# and executed. It holds a root SubgraphTracer where the FX graph is built. +# +# HigherOrderOperators are operators that take functions as their arguments. +# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect +# the function passed to it (call this the "body function"), capture it into a +# GraphModule, and rewrite the call to the HigherOrderOperator to use the +# GraphModule. +# +# The way we handle the capture of body functions is through having +# (possibly nested) SubgraphTracers, one per body function. +# +# Mechanically, we do the introspection by: +# - Creating a new SubgraphTracer via OutputGraph.subtracer +# - Executing the body function. +# This constructs the graph of the body function in the new SubgraphTracer +# while modifying the state of the OutputGraph. For example: +# - the OutputGraph can receive new GraphArgs (if we discover any new +# untracked Tensors) +# - side effects from the body function get accumulated into +# OutputGraph.side_effects +# - guards produced by the body function get accumulated into OutputGraph.guards +# +# The traced function has some special properties that make it easier for us +# to transform later down the line: +# - we lift all free variables to being inputs. +# +# If the introspection fails (due to the existence of graph breaks), then +# we roll back the current OutputGraph state and graph break on the +# HigherOrderOperator. diff --git a/MLPY/Lib/site-packages/torch/_dynamo/polyfill.py b/MLPY/Lib/site-packages/torch/_dynamo/polyfill.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6add71c1b2f2787c98bfd3823dfd564f28d687 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/polyfill.py @@ -0,0 +1,47 @@ +# mypy: ignore-errors + +""" +Python polyfills for common builtins. 
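+
+These are plain-Python reimplementations of builtins that Dynamo can inline
+and trace, rather than graph-breaking on their C implementations.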
+""" +import math + +import torch + + +def all(iterator): + for elem in iterator: + if not elem: + return False + return True + + +def any(iterator): + for elem in iterator: + if elem: + return True + return False + + +def index(iterator, item, start=0, end=None): + for i, elem in enumerate(list(iterator))[start:end]: + if item == elem: + return i + # This will not run in dynamo + raise ValueError(f"{item} is not in {type(iterator)}") + + +def repeat(item, count): + for i in range(count): + yield item + + +def radians(x): + return math.pi / 180.0 * x + + +def accumulate_grad(x, new_grad): + new_grad = torch.clone(new_grad) + if x.grad is None: + x.grad = new_grad + else: + x.grad.add_(new_grad) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/profiler.py b/MLPY/Lib/site-packages/torch/_dynamo/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..f26c3c7d010d2aa622e8061464c7be191fbd1297 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/profiler.py @@ -0,0 +1,155 @@ +import dataclasses +import os +from typing import Any, List + +import torch + +from .utils import print_once + + +@dataclasses.dataclass +class ProfileMetrics: + microseconds: float = 0.0 + operators: int = 0 + fusions: int = 0 + graphs: int = 0 + + def __iadd__(self, other: "ProfileMetrics"): + self.microseconds += other.microseconds + self.operators += other.operators + self.fusions += other.fusions + return self + + def __add__(self, other: "ProfileMetrics"): + assert isinstance(other, ProfileMetrics) + return ProfileMetrics( + self.microseconds + other.microseconds, + self.operators + other.operators, + self.fusions + other.fusions, + ) + + def __truediv__(self, other): + if isinstance(other, int): + other = ProfileMetrics(other, other, other) + return ProfileMetrics( + self.microseconds / max(1, other.microseconds), + self.operators / max(1, other.operators), + self.fusions / max(1, other.fusions), + ) + + def __str__(self): + return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time" + + def tocsv(self): + return [self.operators, self.microseconds] + + +class ProfileResult: + def __init__(self, captured, total, unique_graphs): + self.captured: ProfileMetrics = captured or ProfileMetrics() + self.total: ProfileMetrics = total or ProfileMetrics() + self.unique_graphs: int = unique_graphs + + def __iadd__(self, other: "ProfileResult"): + self.captured += other.captured + self.total += other.total + self.unique_graphs += other.unique_graphs + return self + + def percent(self): + return self.captured / self.total + + def __str__(self): + return ( + f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls " + f"{self.captured.operators:4}/{self.total.operators:4} = " + + str(self.percent()) + ) + + def tocsv(self): + return [ + self.unique_graphs, + self.captured.graphs, + self.captured.operators, + self.total.operators, + ] + self.percent().tocsv() + + +def should_print_missing(): + return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1" + + +def print_missing(stack): + if any("/torch/autograd/profiler.py" in x for x in stack): + return + stack = [ + x for x in stack if ("> ".join(stack[-3:])) + + +class Profiler: + unique_graphs = 0 + + def __init__(self): + self.prof = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + with_stack=should_print_missing(), + ) + + def results(self): + captured_regions = 0 + captured_ops = 0 + captured_microseconds = 0 + total_ops = 0 + total_microseconds = 0 + + last_op_end_time = -1 + captured_region_end_time = 
-1 + events = sorted(self.prof.events(), key=lambda x: x.time_range.start) + for e in events: + if e.name == "TORCHDYNAMO": + captured_region_end_time = e.time_range.end + captured_regions += 1 + # ignore `handle = torch.zeros(1)` in record_function.__init__() + total_ops -= 1 + elif e.time_range.start >= last_op_end_time: + last_op_end_time = e.time_range.end + if e.time_range.end <= captured_region_end_time: + captured_ops += 1 + captured_microseconds += e.time_range.elapsed_us() + elif should_print_missing(): + print_missing(e.stack) + total_ops += 1 + total_microseconds += e.time_range.elapsed_us() + else: + pass # ops recursively called from other ops (ignored) + + unique_graphs = Profiler.unique_graphs + Profiler.unique_graphs = 0 + # we counted one extra op that is part of the profiler setup code + total_ops -= 1 + + return ProfileResult( + captured=ProfileMetrics( + microseconds=captured_microseconds, + operators=captured_ops, + fusions=captured_ops - captured_regions, + graphs=captured_regions, + ), + total=ProfileMetrics( + microseconds=total_microseconds, + operators=total_ops, + fusions=total_ops - 1, + ), + unique_graphs=unique_graphs, + ) + + +def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]): + def _wrapped(*args): + with torch.profiler.record_function("TORCHDYNAMO"): + return gm.forward(*args) + + Profiler.unique_graphs += 1 + return _wrapped diff --git a/MLPY/Lib/site-packages/torch/_dynamo/replay_record.py b/MLPY/Lib/site-packages/torch/_dynamo/replay_record.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4690ed78b14bc765edbfeafa62bf4e35907a16 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/replay_record.py @@ -0,0 +1,110 @@ +import dataclasses +from dataclasses import field +from types import CodeType, ModuleType +from typing import Any, Dict + +from torch.utils._import_utils import import_dill + +dill = import_dill() + + +@dataclasses.dataclass +class ModuleRecord: + module: ModuleType + accessed_attrs: Dict[str, Any] = field(default_factory=dict) + + +@dataclasses.dataclass +class DummyModule: + name: str + is_torch: bool = False + + @property + def __name__(self): + return self.name + + +@dataclasses.dataclass +class ExecutionRecord: + code: CodeType + globals: Dict[str, Any] = field(default_factory=dict) + locals: Dict[str, Any] = field(default_factory=dict) + builtins: Dict[str, Any] = field(default_factory=dict) + code_options: Dict[str, Any] = field(default_factory=dict) + + def dump(self, f): + assert dill is not None, "replay_record requires `pip install dill`" + dill.dump(self, f) + + @classmethod + def load(cls, f): + assert dill is not None, "replay_record requires `pip install dill`" + return dill.load(f) + + +@dataclasses.dataclass +class ExecutionRecorder: + LOCAL_MOD_PREFIX = "___local_mod_" + + code: CodeType + globals: Dict[str, Any] = field(default_factory=dict) + locals: Dict[str, Any] = field(default_factory=dict) + builtins: Dict[str, Any] = field(default_factory=dict) + code_options: Dict[str, Any] = field(default_factory=dict) + name_to_modrec: Dict[str, Any] = field(default_factory=dict) + + def add_local_var(self, name, var): + if isinstance(var, ModuleType): + self.locals[name] = self._add_mod(var) + else: + self.locals[name] = var + + def add_global_var(self, name, var): + if isinstance(var, ModuleType): + self.globals[name] = self._add_mod(var) + else: + self.globals[name] = var + + def add_local_mod(self, name, mod): + assert isinstance(mod, ModuleType) + + 
self.add_global_var(name, mod) + + def record_module_access(self, mod, name, val): + if isinstance(val, ModuleType): + self.name_to_modrec[mod.__name__].accessed_attrs[name] = self._add_mod(val) + return + + if mod.__name__ in self.name_to_modrec: + self.name_to_modrec[mod.__name__].accessed_attrs[name] = val + + def get_record(self): + return ExecutionRecord( + self.code, + ExecutionRecorder._resolve_modules(self.globals), + ExecutionRecorder._resolve_modules(self.locals), + self.builtins.copy(), + self.code_options.copy(), + ) + + def _add_mod(self, mod): + if mod.__name__ not in self.name_to_modrec: + self.name_to_modrec[mod.__name__] = ModuleRecord(mod) + + return self.name_to_modrec[mod.__name__] + + # Convert ModuleRecords -> DummyModule tree + @classmethod + def _resolve_modules(cls, vars): + def resolve_module(var): + if not isinstance(var, ModuleRecord): + return var + + dummy_mod = DummyModule(var.module.__name__) + for attr_name, attr_value in var.accessed_attrs.items(): + attr_value = resolve_module(attr_value) + dummy_mod.__setattr__(attr_name, attr_value) + + return dummy_mod + + return {k: resolve_module(v) for k, v in vars.items()} diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/__init__.py b/MLPY/Lib/site-packages/torch/_dynamo/repro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7c37cb0533e34b507fcb72806e318b6e569a41 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81e858dd1f7e713cb81ef3bab109137a40888af9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a8604c8d018209ecbb642bab8d04b238dc4fcea Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/after_aot.py b/MLPY/Lib/site-packages/torch/_dynamo/repro/after_aot.py new file mode 100644 index 0000000000000000000000000000000000000000..3658d7430396b7bda82470dc7c5bde8025a18727 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/repro/after_aot.py @@ -0,0 +1,932 @@ +import argparse +import copy +import functools +import io +import logging +import os +import shutil +import subprocess +import sys +import textwrap +import uuid +from importlib import import_module +from tempfile import TemporaryFile +from typing import Any, Callable, Dict, Union + +import torch +import torch.fx as fx +import torch.nn as nn +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + AccuracyError, + backend_accuracy_fails, + BuckTargetWriter, + cast_to_fp64, + extra_imports, + generate_config_string, + helper_for_dump_minify, + InputReader, + InputWriter, + 
MAX_CONSTANT_NUMEL_INLINE, + minifier_dir, + NNModuleToString, + NopInputReader, + same_two_models, +) +from torch._dynamo.utils import clone_inputs, counters, same +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + fx_placeholder_targets, + has_free_symbols, +) +from torch.hub import tqdm + +from .. import config + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str): + """ + Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both + forward and backward call separately with the backend compiler_fn - like + inductor or nvfuser. Intercepting after Aot Autograd presents neat + abstraction, where all the params are lifted as graph inputs, making it easy + to save the graph as a string. + """ + + @functools.wraps(unconfigured_compiler_fn) + def debug_wrapper(gm, example_inputs, **kwargs): + from torch._subclasses import FakeTensorMode + + compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + + from torch._functorch.aot_autograd import get_aot_graph_name + + graph_name = get_aot_graph_name() + + # TODO: why do we need to deepcopy the original graph? + orig_graph = copy.deepcopy(gm.graph) + assert config.repro_after in ("dynamo", "aot", None) + + try: + # Call the compiler_fn - which is either aot_autograd or inductor + # with fake inputs + inner_compiled_fn = compiler_fn(gm, example_inputs) + except Exception as e: + # TODO: Failures here are troublesome because no real inputs, + # need a different serialization strategy + if config.repro_after == "aot": + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + log.error("CompilerError") + raise + + # We may run regular PyTorch compute that may trigger Dynamo, do NOT + # recursively attempt to accuracy minify in that case! + def deferred_for_real_inputs(real_inputs): + # This is a bit obscure: if we recursively try to accuracy minify + # the SAME function, this would trigger. But most of the time + # we should never hit this branch + if config.repro_after != "aot": + return inner_compiled_fn(real_inputs) + with config.patch(repro_after=None): + return inner_debug_fn(real_inputs) + + def inner_debug_fn(real_inputs): + """ + Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, + example_inputs can be fake tensors. We can call compiler_fn (which is + inductor or nvfuser) with fake tensors but the actually compiled_fn + should be called with real tensors. Therefore, the actual invocation + is deferred. + """ + # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor + # because inductor clears the tensor list in its codegen. And example_inputs + # are available only for the first invocation. 
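+            # (FakeTensorMode.from_tensor preserves metadata such as shape,
+            # stride, dtype and device without holding real storage, so these
+            # attributes stay available after the real inputs are consumed.)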
+ fake_mode = FakeTensorMode() + copy_tensor_attrs = [ + fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x + for x in real_inputs + ] + if config.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + fx.GraphModule(gm, orig_graph), real_inputs, compiler_name + ) + + if config.repro_level == 4: + if compiler_name != "inductor": + raise NotImplementedError( + "Accuracy minification is supported for inductor only" + ) + if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn): + log.warning( + "Accuracy failed for the AOT Autograd graph %s", graph_name + ) + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + dump_to_minify( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + raise AccuracyError("Bad accuracy detected") + else: + # Call the compiled function with real inputs + return inner_compiled_fn(real_inputs) + else: + try: + # Call the compiled function with real inputs + out = inner_compiled_fn(real_inputs) + # sync cuda kernels to ensure IMA detection + for arg in example_inputs: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + torch.cuda.synchronize() + break + return out + except Exception as e: + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + raise + + if config.repro_after == "aot": + compiled_fn = deferred_for_real_inputs + compiled_fn._boxed_call = True # type: ignore[attr-defined] + return compiled_fn + else: + return inner_compiled_fn + + return debug_wrapper + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=None): + model_str = textwrap.dedent( + f""" +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + model_str += NNModuleToString.convert(gm) + + # get hint shape/stride when dynamic shape enabled + def hint_if_symint(x): + return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x) + + writer = InputWriter(save_dir) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + else: + raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") + + model_str += "\n".join(writer.lines()) + "\n" + + model_str += "mod = Repro()\n" + return model_str + + +def save_graph_repro( + fd, + gm, + args, + compiler_name, + *, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + tracing_mode=None, + 
check_str=None, +): + fd.write( + generate_compiler_repro_string( + gm, + args, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + if tracing_mode is None: + tracing_mode = "real" + if any(has_free_symbols(a) for a in args): + tracing_mode = "symbolic" + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r}" + ")\n" + ) + + +def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro( + fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + pass + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP MINIFIER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify(gm, args, compiler_name: str): + out = io.StringIO() + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify") + return helper_for_dump_minify(out.getvalue()) + + +def isolate_fails( + fx_g, + args, + compiler_name: str, + env=None, + save_dir=None, + accuracy=None, + tracing_mode=None, + check_str=None, +): + if env is None: + env = {} + subdir = os.path.join(os.getcwd(), "isolate") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") + with open(file_name, "w") as fd: + save_graph_repro( + fd, + fx_g, + args, + compiler_name, + save_dir=save_dir, + command="minifier-query", + accuracy=accuracy, + tracing_mode=tracing_mode, + check_str=check_str, + ) + # with open(file_name, "r") as fd: + # print(fd.read()) + new_env = os.environ.copy() + new_env = {**new_env, **env} + stdout, stderr = TemporaryFile(), TemporaryFile() + + if use_buck: + cmd = BuckTargetWriter(file_name).write(print_msg=False) + else: + cmd = ["python", file_name] + + p = subprocess.Popen( + cmd, + cwd=subdir, + stdout=stdout, + stderr=stderr, + env=new_env, + ) + p.wait() + + stdout.seek(0) + stderr.seek(0) + print( + textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout + ) + print( + textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr + ) + # print(f"Isolated test failed - {file_name}") + return p.returncode != 0 + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER TOOLS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def inductor_fails(fx_g, args, check_str=None): + has_cuda = False + for arg in args: + if isinstance(arg, 
torch.Tensor) and arg.is_cuda: + has_cuda = True + break + + def sync(): + if has_cuda: + # Ensures that segfaults are surfaced + torch.cuda.synchronize() + + from torch._inductor.compile_fx import compile_fx_inner + + try: + result = fx_g(*args) + assert isinstance(result, (tuple, list)) + assert not any(isinstance(x, (tuple, list)) for x in result) + except Exception: + return False + + sync() + + try: + compile_mod = compile_fx_inner(fx_g, args) + compile_mod(args) + sync() + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + print(repr(e)) + return True + return False + + +def inductor_accuracy_fails( + fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False +): + from torch._inductor.compile_fx import compile_fx_inner + + return backend_aot_accuracy_fails( + fx_g, + args, + compile_fx_inner, + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + + +backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def repro_common(options, mod, load_args): + # Invariant for graphs we generate with the repro script + assert not any(mod.named_parameters()) + for n, b in mod.named_buffers(): + if b.numel() > MAX_CONSTANT_NUMEL_INLINE: + log.warning( + "Constant %s was not serialized, generated random data instead. " + "If you think this is affecting you, please comment on " + "https://github.com/pytorch/pytorch/issues/100468", + n, + ) + + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + # Turn mod into a GraphModule the slow way + # TODO: speed this up + mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args) + + torch._inductor.config.generate_intermediate_hooks = True + + return mod, args + + +ACCURACY_FAILS: Dict[str, Callable[[nn.Module, Any], bool]] = { + "": inductor_fails, + # This might look inverted but it's not. 
strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + "accuracy": functools.partial( + inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True + ), + "strict_accuracy": inductor_accuracy_fails, +} + + +def repro_minifier_query(options, mod, load_args): + mod, args = repro_common(options, mod, load_args) + fail_fn = functools.partial( + ACCURACY_FAILS[options.accuracy], check_str=options.check_str + ) + if fail_fn(mod, args): + sys.exit(1) + else: + sys.exit(0) + + +def repro_minify(options, mod, load_args): + from functorch.compile import minifier + + mod, args = repro_common(options, mod, load_args) + compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor" + + favored_device = 1 if torch.cuda.device_count() >= 2 else 0 + env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)} + + module_fails: Any + if options.isolate: + module_fails = functools.partial( + isolate_fails, + env=env_variables, + compiler_name=compiler_name, + save_dir=options.save_dir, + accuracy=options.accuracy, + tracing_mode=options.tracing_mode, + ) + else: + module_fails = ACCURACY_FAILS[options.accuracy] + + minifier( + mod, + args, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, compiler_name=compiler_name + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def repro_analyze(options, mod, load_args): + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.hooks import intermediate_hook + + mod, args = repro_common(options, mod, load_args) + + # TODO: The logic for cloning inputs/models here is intentionally + # modeled off of run_fwd_maybe_bwd, but arguably it is better not to + # clone inputs (as you are doubling your effective GPU memory usage). + # It is certainly faster though! It probably makes sense to let the + # user specify the offload strategy. 
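+    # The passes below therefore clone `args` where it must be reused, because
+    # the boxed calling convention consumes the input list in place (hence the
+    # `assert not new_args` checks after each run).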
+ + with tqdm(desc="Compiling"): + compiled = compile_fx_inner(mod, args) + total = counters["inductor"]["intermediate_hooks"] + + known_names = set() + + def save_hook(name, val): + known_names.add(name) + if not options.skip_saving_inductor_intermediates: + writer.write_tensor(os.path.join("inductor", name), val) + pbar.update(1) # type: ignore[has-type] + + writer = torch.utils._content_store.ContentStoreWriter( + options.save_dir, stable_hash=options.stable_hash + ) + reader = torch.utils._content_store.ContentStoreReader(options.save_dir) + + new_args = clone_inputs(args) + with intermediate_hook(save_hook), tqdm( + desc="Saving inductor intermediates", total=total + ) as pbar: + compiled(new_args) + assert not new_args + + def compare_tuples(tuple1, tuple2): + diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] + diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] + + if not diff_values: + return None + else: + return " and ".join(f"{a} != {b}" for a, b in diff_values) + + def check_hook(name, val): + meta = writer.compute_tensor_metadata(val) + meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})") + pbar.update(1) + + if not options.skip_check_deterministic: + new_args = clone_inputs(args) + with intermediate_hook(check_hook), tqdm( + desc="Checking inductor determinism", total=total + ) as pbar: + compiled(new_args) + assert not new_args + + class WriterInterp(fx.Interpreter): + def __init__(self, mod, subdir): + super().__init__(mod) + self.subdir = subdir + + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + pbar.update(1) + writer.write_tensor(os.path.join(self.subdir, name), r) + return r + + # NB: the module cast doesn't actually do anything, since there are no + # parameters/buffers on the module + if not options.skip_saving_float64_intermediates: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + with tqdm(desc="Saving float64 intermediates", total=total) as pbar: + WriterInterp(new_mod, "float64").boxed_run(new_args) + assert not new_args + + class ExactReaderInterp(fx.Interpreter): + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + meta = writer.compute_tensor_metadata(r) + meta2 = reader.read_tensor_metadata(os.path.join("float64", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})") + pbar.update(1) + return r + + # TODO: check eager determinism + + if not options.skip_check_deterministic: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + with tqdm(desc="Checking float64 determinism", total=total) as pbar: + ExactReaderInterp(new_mod).boxed_run(new_args) + assert not new_args + + # Now that we've saved everything, interp through the eager graph + # and do comparisons + class ReaderInterp(fx.Interpreter): + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + inductor = reader.read_tensor(os.path.join("inductor", name)) + float64 = reader.read_tensor(os.path.join("float64", name)) + logged = False + + def log_error(msg, *args): + nonlocal logged + logged = True + pbar.write(f"DIVERGED at {name}: {msg % args}") + + if not same( + r, + inductor, + float64, + tol=torch._dynamo.config.repro_tolerance, + equal_nan=True, + log_error=log_error, + ): + 
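+                    # `same` is expected to have reported the mismatch through
+                    # log_error before returning False.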
assert logged + pbar.update(1) + return r + + with tqdm(desc="Checking divergence", total=total) as pbar: + ReaderInterp(mod).boxed_run(args) + assert not args + + +def repro_run(options, mod, load_args): + from torch._inductor.compile_fx import compile_fx_inner + + mod, args = repro_common(options, mod, load_args) + + from torch.cuda import synchronize + + compiled = compile_fx_inner(mod, args) + + if options.accuracy != "": + # We don't really respect --accuracy vs --strict-accuracy here, it + # seems counterintuitive + if not same_two_models(mod, compiled, args, only_fwd=True): + raise AccuracyError("Bad accuracy detected") + else: + need_sync = False + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + ref = compiled(list(args)) + if need_sync: + synchronize() # ensure segfaults are surfaced + return lambda: compiled(list(args)) + + +# TODO: lazily load the inputs or something, rather than cloning them +def run_repro( + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + patch_code=None, + check_str=None, + **kwargs, +): + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + if patch_code is not None: + log.warning( + "patch_code no longer works on this version of PyTorch, silently ignoring" + ) + + parser = argparse.ArgumentParser( + description=f"""\ +An after_aot repro script, typically triggering a bug in PyTorch Inductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. 
+""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--tracing-mode", + type=str, + metavar="{real,fake,symbolic}", + default=tracing_mode, + help="how to trace the repro module into a GraphModule with metadata", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify,analyze}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_minify_isolate = parser_minify.add_mutually_exclusive_group() + parser_minify_isolate.add_argument( + "--isolate", + action="store_true", + default=True, + help="run in separate processes to avoid interference (default)", + ) + parser_minify_isolate.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + help="speed up by running all compilation in same process", + ) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + # TODO: make this an option for --analyze too + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. 
Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + parser_analyze = subparsers.add_parser( + "analyze", help="run the accuracy analyzer on the repro" + ) + common_flags(parser_analyze) + parser_analyze.add_argument( + "--skip-saving-inductor-intermediates", + action="store_true", + help="skip saving inductor intermediates on --analyze", + ) + parser_analyze.add_argument( + "--skip-saving-float64-intermediates", + action="store_true", + help="skip saving float64 intermediates", + ) + parser_analyze.add_argument( + "--skip-check-deterministic", + action="store_true", + help="skip checking that the network is deterministic", + ) + parser_analyze.add_argument( + "--stable-hash", + action="store_true", + help="use SHA-1 checksum instead of fast (but possibly unsound) hash", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "analyze": repro_analyze, + "minifier-query": repro_minifier_query, + "run": repro_run, + } + return COMMAND_FNS[options.command](options, mod, load_args) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py b/MLPY/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py new file mode 100644 index 0000000000000000000000000000000000000000..854807d5a0654e06d496fb37f7b5e76602e53e4e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py @@ -0,0 +1,566 @@ +import argparse +import copy +import functools +import logging +import os +import shutil +import sys +import textwrap +from importlib import import_module +from typing import Union + +import torch +import torch.fx as fx + +from torch._dynamo.debug_utils import ( + AccuracyError, + backend_accuracy_fails, + BUCK_CMD_PREFIX, + BuckTargetWriter, + extra_imports, + generate_config_string, + helper_for_dump_minify, + InputReader, + InputWriter, + minifier_dir, + NNModuleToString, + NopInputReader, + run_fwd_maybe_bwd, + same_two_models, +) +from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets +from torch.hub import tqdm + +from .. import config +from ..backends.registry import lookup_backend, register_debug_backend +from ..debug_utils import clone_inputs_retaining_gradness + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. 
+ As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + + @functools.wraps(unconfigured_compiler_fn) + def debug_wrapper(gm, example_inputs, **kwargs): + compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + assert config.repro_after in ("dynamo", "aot", None) + + if config.repro_after == "dynamo": + + def add_paths(exc): + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + if use_buck: + exc.buck_command = " ".join( + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] + ) + + if config.repro_level == 3: + dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) + + # Check for either accuracy (level 4) or other type of failures. + if config.repro_level == 4: + # Check Accuracy + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + if backend_accuracy_fails(gm, example_inputs, compiler_fn): + log.warning( + "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." + ) + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + compiler_name, + ) + exc = AccuracyError("Bad accuracy detected.") + add_paths(exc) + raise exc + else: + try: + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) + except Exception as exc: + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + if config.repro_level == 1: + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name + ) + dump_state_fn( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs + ) + elif config.repro_level == 2: + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + compiler_name, + ) + add_paths(exc) + raise + else: + compiled_gm = compiler_fn(gm, example_inputs) + + return compiled_gm + + debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + if hasattr(unconfigured_compiler_fn, "compiler_name"): + debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + return debug_wrapper + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO DUMPERS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=False, + *, + stable_output=False, + save_dir=None, + command="run", +): + """ + Generate a repro string for backend-agnostic minified version. 
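+
+    The emitted script rebuilds the module from its string form, reloads the
+    saved inputs via load_args, and then calls run_repro so the failure can be
+    reproduced standalone.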
+ """ + + model_str = NNModuleToString.convert(gm) + + # TODO: Figure out why torch.compile'd hash isn't work on this codepath + writer = InputWriter(save_dir, stable_hash=True) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + else: + raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") + load_args = "\n".join(writer.lines()) + + return textwrap.dedent( + f""" +from math import inf +import torch +from torch import tensor, device +import torch.fx as fx +import torch._dynamo +from torch._dynamo.testing import rand_strided +from torch._dynamo.debug_utils import run_fwd_maybe_bwd + +{generate_config_string(stable_output=stable_output)} + +{extra_imports} + +{model_str} +mod = Repro() + +{load_args} + +if __name__ == '__main__': + from torch._dynamo.repro.after_dynamo import run_repro + run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r}, + save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r}) +""" + ) + + +def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): + """ + Saves the repro to a repro.py file + """ + curdir = os.getcwd() + subdir = os.path.join(os.getcwd(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + + with open(file_name, "w") as fd: + fd.write( + generate_dynamo_fx_repro_string( + gm, args, compiler_name, check_accuracy, save_dir=subdir + ) + ) + latest_repro = os.path.join(curdir, "repro.py") + log.warning("Copying %s to %s for convenience", file_name, latest_repro) + + if use_buck: + BuckTargetWriter(latest_repro).write() + + shutil.copyfile(file_name, latest_repro) + + +def dump_backend_state(gm, args, compiler_name, check_accuracy=False): + """ + Dumps the dynamo graph to repro the issue. + 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a + repro.py file. + 2) If we can't convert Fx GraphModule to a string, we use to_folder to save + the module and save a tar file. + """ + assert NNModuleToString.can_convert_to_string(gm) + return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy) + # return dump_backend_repro_as_tarfile(gm, args, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER DUMPER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify_after_dynamo(gm, args, compiler_name): + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + helper_for_dump_minify( + generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=config.repro_level == 4, + save_dir=subdir, + command="minify", + ) + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER BACKENDS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@register_debug_backend +def dynamo_minifier_backend(gm, example_inputs, compiler_name): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) + + # TODO: It's inconsistent to pass SymInt inputs but REAL tensors. 
+ # We should pass ints and look at the GraphModule placeholders + # to resolve them to SymInt (if necessary) + example_inputs = [ + i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs + ] + + try: + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) + raise ValueError("No issue was detected") + except Exception as exc: + orig_failure = str(exc) + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + fails_fn = functools.partial( + backend_fails, + compiler_fn=compiler_fn, + orig_failure=orig_failure, + ) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + return gm + + +@register_debug_backend +def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) + + # Set the eval mode to remove randomness. + gm.eval() + + # Check Accuracy + if backend_accuracy_fails( + gm, example_inputs, compiler_fn, only_fwd=config.repro_forward_only + ): + log.warning("Accuracy failed for the TorchDynamo produced graph") + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name, check_accuracy=True + ) + fails_fn = functools.partial( + backend_accuracy_fails, + compiler_fn=compiler_fn, + only_fwd=config.repro_forward_only, + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + else: + log.error("Input graph does not fail accuracy testing") + return gm + + +def backend_fails(gm, example_inputs, compiler_fn, orig_failure): + """ + Minifier uses this function to identify if the minified graph module fails + with the same error. + + One caveat is that minifier can potentially go into a wrong direction when + the resulting graph module fails for a different reason. To avoid this, we + save the string for the original exception and check similarity between new + and old exception. They can be somewhat different in some cases, when the + exception string depends on the failing node information. So, we have a + loose similarity metric to guide the minifier path. + """ + from difflib import SequenceMatcher + + try: + # Run the original gm to check eager validity + run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) + return False + except Exception as e: + new_failure = str(e) + if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: + return True + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def run_load_args(options, mod, load_args): + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. 
We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return args + + +def repro_minify(options, mod, load_args): + args = run_load_args(options, mod, load_args) + + # Setup debug minifier compiler + if not options.accuracy: + compiler_fn = lookup_backend("dynamo_minifier_backend") + else: + compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend") + + if options.backend is None: + raise RuntimeError( + "Compiler name is None - this likely means that a custom compiler " + "was called by torchdynamo. Please remove this error, import your " + "custom compiler function, and replace the backend=None " + "line in run_repro to backend=" + ) + + dynamo_minifier_backend = functools.partial( + compiler_fn, + compiler_name=options.backend, + ) + opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod) + + with torch.cuda.amp.autocast(enabled=options.autocast): + opt_mod(*args) + + +def repro_run(options, mod, load_args): + opt_mod = torch._dynamo.optimize(options.backend)(mod) + + if options.accuracy != "": + mod.eval() + opt_mod.eval() + + with torch.cuda.amp.autocast(enabled=options.autocast): + # TODO: disable clone + args = run_load_args(options, mod, load_args) + assert same_two_models(mod, mod, args), "Eager itself failed" + if not same_two_models(mod, opt_mod, args): + raise AccuracyError("Dynamo failed") + else: + with torch.cuda.amp.autocast(enabled=options.autocast): + args = run_load_args(options, mod, load_args) + ref = run_fwd_maybe_bwd( + mod, args, only_fwd=options.only_fwd, disable_clone=True + ) + del args + + args = run_load_args(options, mod, load_args) + res = run_fwd_maybe_bwd( + opt_mod, args, only_fwd=options.only_fwd, disable_clone=True + ) + + +def run_repro( + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + autocast=False, + backend="inductor", + **kwargs, +): + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An after_dynamo repro script, typically triggering a bug in Dynamo or +AOTAutograd. When run with no arguments, this script defaults to running +'{command}'. Extra flags may be available; to find out more, try '{command} +--help'. There are also alternate subcommands available, see below. 
+ +default settings on this script: + {accuracy=} + {save_dir=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="test accuracy", + ) + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + default=False, + help="no isolate (doesn't do anything for after_dynamo)", + ) + parser.add_argument( + "--autocast", + default=autocast, + action="store_true", + help="use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--no-autocast", + dest="autocast", + action="store_false", + help="don't use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--backend", + type=str, + default=backend, + metavar="BACKEND", + help="torch.compile backend to use", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + parser_run.add_argument( + "--only-fwd", + action="store_true", + help="don't run backwards compilation for testing", + ) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + } + COMMAND_FNS[options.command](options, mod, load_args) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/resume_execution.py b/MLPY/Lib/site-packages/torch/_dynamo/resume_execution.py new file mode 100644 index 0000000000000000000000000000000000000000..2df133548254a35808b5c54b4f58fc05050fdd8c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/resume_execution.py @@ -0,0 +1,648 @@ +import copy +import dataclasses +import sys +import types +from typing import Any, cast, Dict, List, Optional, Tuple + +from .bytecode_transformation import ( + create_call_function, + create_call_method, + create_dup_top, + create_instruction, + create_jump_absolute, + Instruction, + InstructionExnTabEntry, + transform_code_object, + unique_id, +) +from .utils import ExactWeakKeyDictionary + +# taken from code.h in cpython +CO_OPTIMIZED = 0x0001 +CO_NEWLOCALS = 0x0002 +CO_VARARGS = 0x0004 +CO_VARKEYWORDS = 0x0008 +CO_NESTED = 0x0010 +CO_GENERATOR = 0x0020 +CO_NOFREE = 0x0040 +CO_COROUTINE = 0x0080 +CO_ITERABLE_COROUTINE = 0x0100 +CO_ASYNC_GENERATOR = 0x0200 + + +@dataclasses.dataclass(frozen=True) +class ReenterWith: + stack_index: int + target_values: Optional[Tuple[Any, ...]] = None + + # If we do not want to destroy the stack, we can do the same thing as a + # `SETUP_WITH` block, only that we store the context manager in a local_symbol + def try_except(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + load args + enter context + try: + (rest) + finally: + exit context + """ + 
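+        # Illustrative sketch only (placeholder names, not the exact generated
+        # locals): the prefix emitted here plus the cleanup appended below behave
+        # roughly like the following Python, where `ctx_factory` stands for the
+        # context-manager callable already on the stack:
+        #
+        #     ___context_manager_N = ctx_factory(*target_values)
+        #     ___context_manager_N.__enter__()
+        #     try:
+        #         ...  # rest of the resumed frame
+        #     finally:
+        #         ___context_manager_N.__exit__(None, None, None)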
load_args = [] + if self.target_values: + load_args = [ + create_instruction("LOAD_CONST", argval=val) + for val in self.target_values + ] + ctx_name = unique_id(f"___context_manager_{self.stack_index}") + if ctx_name not in code_options["co_varnames"]: + code_options["co_varnames"] += (ctx_name,) + for name in ["__enter__", "__exit__"]: + if name not in code_options["co_names"]: + code_options["co_names"] += (name,) + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally = [ + *load_args, + *create_call_function(len(load_args), True), + create_instruction("STORE_FAST", argval=ctx_name), + create_instruction("LOAD_FAST", argval=ctx_name), + create_instruction("LOAD_METHOD", argval="__enter__"), + *create_call_method(0), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + return [ + create_instruction("LOAD_FAST", argval=ctx_name), + create_instruction("LOAD_METHOD", argval="__exit__"), + create_instruction("LOAD_CONST", argval=None), + create_dup_top(), + create_dup_top(), + *create_call_method(3), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("BEGIN_FINALLY"), + except_jump_target, + *create_reset(), + create_instruction("END_FINALLY"), + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + *create_reset(), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("RERAISE"), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RERAISE", arg=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + *create_reset(), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + *create_reset(), + finally_exn_tab_end, # RERAISE 0 + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + + def __call__(self, code_options, cleanup): + """ + Codegen based off of: + with ctx(args): + (rest) + """ + load_args = [] + if self.target_values: + load_args = [ + create_instruction("LOAD_CONST", argval=val) + for val in self.target_values + ] + if sys.version_info < (3, 9): + with_cleanup_start = create_instruction("WITH_CLEANUP_START") + begin_finally = create_instruction("BEGIN_FINALLY") + cleanup[:] = [ + create_instruction("POP_BLOCK"), + begin_finally, + with_cleanup_start, + create_instruction("WITH_CLEANUP_FINISH"), + create_instruction("END_FINALLY"), + ] + cleanup + + return [ + *load_args, + create_instruction("CALL_FUNCTION", arg=len(load_args)), + create_instruction("SETUP_WITH", target=with_cleanup_start), + 
create_instruction("POP_TOP"), + ], None + elif sys.version_info < (3, 11): + with_except_start = create_instruction("WITH_EXCEPT_START") + pop_top_after_with_except_start = create_instruction("POP_TOP") + + cleanup_complete_jump_target = create_instruction("NOP") + + cleanup[:] = [ + create_instruction("POP_BLOCK"), + create_instruction("LOAD_CONST", argval=None), + create_instruction("DUP_TOP"), + create_instruction("DUP_TOP"), + create_instruction("CALL_FUNCTION", arg=3), + create_instruction("POP_TOP"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + with_except_start, + create_instruction( + "POP_JUMP_IF_TRUE", target=pop_top_after_with_except_start + ), + create_instruction("RERAISE"), + pop_top_after_with_except_start, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_EXCEPT"), + create_instruction("POP_TOP"), + cleanup_complete_jump_target, + ] + cleanup + + return [ + *load_args, + create_instruction("CALL_FUNCTION", arg=len(load_args)), + create_instruction("SETUP_WITH", target=with_except_start), + create_instruction("POP_TOP"), + ], None + else: + pop_top_after_with_except_start = create_instruction("POP_TOP") + cleanup_complete_jump_target = create_instruction("NOP") + + def create_load_none(): + return create_instruction("LOAD_CONST", argval=None) + + exn_tab_1_begin = create_instruction("POP_TOP") + exn_tab_1_end = create_instruction("NOP") + exn_tab_1_target = create_instruction("PUSH_EXC_INFO") + exn_tab_2_end = create_instruction("RERAISE", arg=2) + exn_tab_2_target = create_instruction("COPY", arg=3) + + exn_tab_1_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_1_begin, + exn_tab_1_end, + exn_tab_1_target, + self.stack_index + 1, + True, + ) + exn_tab_1_target.exn_tab_entry = InstructionExnTabEntry( + exn_tab_1_target, + exn_tab_2_end, + exn_tab_2_target, + self.stack_index + 3, + True, + ) + pop_top_after_with_except_start.exn_tab_entry = InstructionExnTabEntry( + pop_top_after_with_except_start, + pop_top_after_with_except_start, + exn_tab_2_target, + self.stack_index + 3, + True, + ) + + cleanup[:] = [ + exn_tab_1_end, + create_load_none(), + create_load_none(), + create_load_none(), + *create_call_function(2, False), + create_instruction("POP_TOP"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + exn_tab_1_target, # PUSH_EXC_INFO + create_instruction("WITH_EXCEPT_START"), + create_instruction( + "POP_JUMP_FORWARD_IF_TRUE", + target=pop_top_after_with_except_start, + ), + exn_tab_2_end, # RERAISE 2 + exn_tab_2_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), + pop_top_after_with_except_start, + create_instruction("POP_EXCEPT"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + cleanup_complete_jump_target, + ] + cleanup + + return [ + *load_args, + *create_call_function(len(load_args), True), + create_instruction("BEFORE_WITH"), + exn_tab_1_begin, # POP_TOP + ], exn_tab_1_target + + +@dataclasses.dataclass +class ResumeFunctionMetadata: + code: types.CodeType + instructions: List[Instruction] = dataclasses.field(default_factory=list) + # Python 3.11+ fields + # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists + # of instructions of all exception table entries that have the same target. 
+ + # map from PUSH_EXC_INFO's in the prefix to original block target offset + prefix_block_target_offset_remap: List[int] = dataclasses.field( + default_factory=list + ) + # map from new block target offsets to original block target offsets + block_target_offset_remap: Optional[Dict[int, int]] = None + + +def _filter_iter(l1, l2, cond): + """ + Two-pointer conditional filter. + e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) + returns the instructions with offsets in sorted_offsets + """ + it = iter(l2) + res = [] + try: + cur = next(it) + for val in l1: + if cond(val, cur): + res.append(val) + cur = next(it) + except StopIteration: + pass + return res + + +class ContinueExecutionCache: + cache = ExactWeakKeyDictionary() + generated_code_metadata = ExactWeakKeyDictionary() + + @classmethod + def lookup(cls, code, lineno, *key): + if code not in cls.cache: + cls.cache[code] = dict() + key = tuple(key) + if key not in cls.cache[code]: + cls.cache[code][key] = cls.generate(code, lineno, *key) + return cls.cache[code][key] + + @classmethod + def generate( + cls, + code, + lineno, + offset: int, + setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ + nstack: int, + argnames: Tuple[str], + setup_fns: Tuple[ReenterWith], + null_idxes: Tuple[int], + ) -> types.CodeType: + assert offset is not None + assert not ( + code.co_flags + & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR) + ) + assert code.co_flags & CO_OPTIMIZED + if code in ContinueExecutionCache.generated_code_metadata: + return cls.generate_based_on_original_code_object( + code, + lineno, + offset, + setup_fn_target_offsets, + nstack, + argnames, + setup_fns, + null_idxes, + ) + + is_py311_plus = sys.version_info >= (3, 11) + meta = ResumeFunctionMetadata(code) + + def update(instructions: List[Instruction], code_options: Dict[str, Any]): + meta.instructions = copy.deepcopy(instructions) + + args = [f"___stack{i}" for i in range(nstack)] + args.extend(v for v in argnames if v not in args) + freevars = tuple(code_options["co_cellvars"] or []) + tuple( + code_options["co_freevars"] or [] + ) + code_options[ + "co_name" + ] = f"torch_dynamo_resume_in_{code_options['co_name']}_at_{lineno}" + if is_py311_plus: + qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1) + if len(qualified_path) == 1: + code_options["co_qualname"] = code_options["co_name"] + else: + assert len(qualified_path) == 2 + module_name, co_name = qualified_path + code_options[ + "co_qualname" + ] = f"{module_name}.torch_dynamo_resume_in_{co_name}_at_{lineno}" + code_options["co_firstlineno"] = lineno + code_options["co_cellvars"] = tuple() + code_options["co_freevars"] = freevars + code_options["co_argcount"] = len(args) + code_options["co_posonlyargcount"] = 0 + code_options["co_kwonlyargcount"] = 0 + code_options["co_varnames"] = tuple( + args + [v for v in code_options["co_varnames"] if v not in args] + ) + code_options["co_flags"] = code_options["co_flags"] & ~( + CO_VARARGS | CO_VARKEYWORDS + ) + target = next(i for i in instructions if i.offset == offset) + + prefix = [] + if is_py311_plus: + if freevars: + prefix.append( + create_instruction("COPY_FREE_VARS", arg=len(freevars)) + ) + prefix.append(create_instruction("RESUME", arg=0)) + + cleanup: List[Instruction] = [] + hooks = {fn.stack_index: fn for fn in setup_fns} + hook_target_offsets = { + fn.stack_index: setup_fn_target_offsets[i] + for i, fn in enumerate(setup_fns) + } + offset_to_inst = {inst.offset: inst for inst in 
instructions} + # map old hook targets to new targets generated by the hook + old_hook_target_remap = {} + null_idxes_i = 0 + for i in range(nstack): + while ( + null_idxes_i < len(null_idxes) + and null_idxes[null_idxes_i] == i + null_idxes_i + ): + prefix.append(create_instruction("PUSH_NULL")) + null_idxes_i += 1 + prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}")) + if i in hooks: + hook = hooks.pop(i) + hook_insts, exn_target = hook(code_options, cleanup) + prefix.extend(hook_insts) + if is_py311_plus: + hook_target_offset = hook_target_offsets.pop(i) + old_hook_target = offset_to_inst[hook_target_offset] + meta.prefix_block_target_offset_remap.append(hook_target_offset) + old_hook_target_remap[old_hook_target] = exn_target + if is_py311_plus: + # reverse the mapping since targets of later/nested contexts are inserted + # into the mapping later, but show up earlier in the prefix. + meta.prefix_block_target_offset_remap = list( + reversed(meta.prefix_block_target_offset_remap) + ) + + assert not hooks + + prefix.append(create_jump_absolute(target)) + + # because the line number table monotonically increases from co_firstlineno + # remove starts_line for any instructions before the graph break instruction + # this will ensure the instructions after the break have the correct line numbers + for inst in instructions: + if inst.offset == target.offset: + break + inst.starts_line = None + if sys.version_info >= (3, 11): + inst.positions = None + + if cleanup: + prefix.extend(cleanup) + prefix.extend(cls.unreachable_codes(code_options)) + + # remap original instructions' exception table entries + if old_hook_target_remap: + assert is_py311_plus + for inst in instructions: + if ( + inst.exn_tab_entry + and inst.exn_tab_entry.target in old_hook_target_remap + ): + inst.exn_tab_entry.target = old_hook_target_remap[ + inst.exn_tab_entry.target + ] + + # TODO(jansel): add dead code elimination here + instructions[:] = prefix + instructions + + new_code = transform_code_object(code, update) + ContinueExecutionCache.generated_code_metadata[new_code] = meta + return new_code + + @staticmethod + def unreachable_codes(code_options) -> List[Instruction]: + """Codegen a `raise None` to make analysis work for unreachable code""" + return [ + create_instruction("LOAD_CONST", argval=None), + create_instruction("RAISE_VARARGS", arg=1), + ] + + @classmethod + def generate_based_on_original_code_object( + cls, code, lineno, offset: int, setup_fn_target_offsets: Tuple[int, ...], *args + ): + """ + This handles the case of generating a resume into code generated + to resume something else. We want to always generate starting + from the original code object so that if control flow paths + converge we only generated 1 resume function (rather than 2^n + resume functions). 
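+
+        For example (illustrative): if the original frame F breaks at offset A, and
+        the resume function generated for A later breaks at offset B, we translate
+        B back into an offset in F's code object and generate the second resume
+        function directly from F, rather than stacking it on the generated code.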
+ """ + + meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[ + code + ] + new_offset = None + + def find_new_offset( + instructions: List[Instruction], code_options: Dict[str, Any] + ): + nonlocal new_offset + (target,) = (i for i in instructions if i.offset == offset) + # match the functions starting at the last instruction as we have added a prefix + (new_target,) = ( + i2 + for i1, i2 in zip(reversed(instructions), reversed(meta.instructions)) + if i1 is target + ) + assert target.opcode == new_target.opcode + new_offset = new_target.offset + + transform_code_object(code, find_new_offset) + + if sys.version_info >= (3, 11): + # setup_fn_target_offsets currently contains the target offset of + # each setup_fn, based on `code`. When we codegen the resume function + # based on the original code object, `meta.code`, the offsets in + # setup_fn_target_offsets must be based on `meta.code` instead. + if not meta.block_target_offset_remap: + block_target_offset_remap = meta.block_target_offset_remap = {} + + def remap_block_offsets( + instructions: List[Instruction], code_options: Dict[str, Any] + ): + # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, + # so we can tell which block a prefix PUSH_EXC_INFO belongs to, + # by counting. Then we can use meta.prefix_block-target_offset_remap + # to determine where in the original code the PUSH_EXC_INFO offset + # replaced. + prefix_blocks: List[Instruction] = [] + for inst in instructions: + if len(prefix_blocks) == len( + meta.prefix_block_target_offset_remap + ): + break + if inst.opname == "PUSH_EXC_INFO": + prefix_blocks.append(inst) + + # offsets into prefix + for inst, o in zip( + prefix_blocks, meta.prefix_block_target_offset_remap + ): + block_target_offset_remap[cast(int, inst.offset)] = o + + # old bytecode targets are after the prefix PUSH_EXC_INFO's + old_start_offset = ( + cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1 + ) + # offsets into old bytecode + old_inst_offsets = sorted( + n for n in setup_fn_target_offsets if n > old_start_offset + ) + targets = _filter_iter( + instructions, old_inst_offsets, lambda inst, o: inst.offset == o + ) + new_targets = _filter_iter( + zip(reversed(instructions), reversed(meta.instructions)), + targets, + lambda v1, v2: v1[0] is v2, + ) + for new, old in zip(new_targets, targets): + block_target_offset_remap[old.offset] = new[1].offset + + transform_code_object(code, remap_block_offsets) + + # if offset is not in setup_fn_target_offsets, it is an error + setup_fn_target_offsets = tuple( + meta.block_target_offset_remap[n] for n in setup_fn_target_offsets + ) + return ContinueExecutionCache.lookup( + meta.code, lineno, new_offset, setup_fn_target_offsets, *args + ) + + +""" +# partially finished support for with statements + +def convert_locals_to_cells( + instructions: List[Instruction], + code_options: Dict[str, Any]): + + code_options["co_cellvars"] = tuple( + var + for var in code_options["co_varnames"] + if var not in code_options["co_freevars"] + and not var.startswith("___stack") + ) + cell_and_free = code_options["co_cellvars"] + code_options["co_freevars"] + for inst in instructions: + if str(inst.argval).startswith("___stack"): + continue + elif inst.opname == "LOAD_FAST": + inst.opname = "LOAD_DEREF" + elif inst.opname == "STORE_FAST": + inst.opname = "STORE_DEREF" + elif inst.opname == "DELETE_FAST": + inst.opname = "DELETE_DEREF" + else: + continue + inst.opcode = dis.opmap[inst.opname] + assert inst.argval in cell_and_free, inst.argval 
+ inst.arg = cell_and_free.index(inst.argval) + +def patch_setup_with( + instructions: List[Instruction], + code_options: Dict[str, Any] +): + nonlocal need_skip + need_skip = True + target_index = next( + idx for idx, i in enumerate(instructions) if i.offset == offset + ) + assert instructions[target_index].opname == "SETUP_WITH" + convert_locals_to_cells(instructions, code_options) + + stack_depth_before = nstack + stack_effect(instructions[target_index].opcode, + instructions[target_index].arg) + + inside_with = [] + inside_with_resume_at = None + stack_depth = stack_depth_before + idx = target_index + 1 + for idx in range(idx, len(instructions)): + inst = instructions[idx] + if inst.opname == "BEGIN_FINALLY": + inside_with_resume_at = inst + break + elif inst.target is not None: + unimplemented("jump from with not supported") + elif inst.opname in ("BEGIN_FINALLY", "WITH_CLEANUP_START", "WITH_CLEANUP_FINISH", "END_FINALLY", + "POP_FINALLY", "POP_EXCEPT", + "POP_BLOCK", "END_ASYNC_FOR"): + unimplemented("block ops not supported") + inside_with.append(inst) + stack_depth += stack_effect(inst.opcode, inst.arg) + assert inside_with_resume_at + + instructions = [ + create_instruction("LOAD_FAST", f"___stack{i}") for i in range(nstack) + ] + [ + create_instruction("SETUP_WITH", target=instructions[target_index].target) + ... call the function ... + unpack_tuple + ] + [ + create_instruction("JUMP_ABSOLUTE", target=inside_with_resume_at) + ] +""" diff --git a/MLPY/Lib/site-packages/torch/_dynamo/side_effects.py b/MLPY/Lib/site-packages/torch/_dynamo/side_effects.py new file mode 100644 index 0000000000000000000000000000000000000000..f1de34e052142d61a91cb0c80dd92d58da99b85d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/side_effects.py @@ -0,0 +1,542 @@ +import inspect +from typing import Any, Dict, List, Optional, Union + +import torch.nn + +from . import utils, variables +from .bytecode_transformation import ( + create_call_function, + create_call_method, + create_instruction, +) +from .codegen import PyCodegen +from .exc import unimplemented +from .source import LocalSource, Source +from .utils import nn_module_new, object_new +from .variables.base import ( + is_side_effect_safe, + MutableLocalBase, + MutableLocalSource, + VariableTracker, +) + + +class MutableSideEffects(MutableLocalBase): + """ + VariableTracker.mutable_local marker to indicate a list passed as + an input that if we mutate we need to re-apply those mutations after + the graph runs. + """ + + def __init__(self, source: Source, is_modified: bool = False): + super().__init__(MutableLocalSource.Existing) + self.source = source + self.is_modified = is_modified + + +class AttributeMutation(MutableLocalBase): + """ + VariableTracker.mutable_local marker to track changes to attributes + """ + + def __init__(self, typ: MutableLocalSource, source: Optional[Source]): + super().__init__(typ) + self.source = source + + +class AttributeMutationExisting(AttributeMutation): + def __init__(self, source: Source): + super().__init__(MutableLocalSource.Existing, source) + self.source = source + + +class AttributeMutationNew(AttributeMutation): + def __init__(self, source: Optional[Source], cls_source: Optional[Source]): + super().__init__(MutableLocalSource.Local, source) + self.cls_source = cls_source + + +class SideEffects: + """ + Track side effects (list mutation, setattr, etc) that need to be + applied after an FX graph is run. 
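+
+    For example (illustrative): if traced code calls `lst.append(x)` on a list that
+    existed before the frame was traced, the FX graph itself stays functional and
+    the mutation is re-applied to the real list by the generated bytecode after the
+    graph call (see codegen_update_mutated below).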
+ """ + + id_to_variable: Dict[int, VariableTracker] + store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]] + keepalive: List[Any] + + def __init__( + self, + id_to_variable=None, + store_attr_mutations=None, + keepalive=None, + save_for_backward=None, + tensor_hooks=None, + ): + super().__init__() + self.id_to_variable = id_to_variable or {} + self.store_attr_mutations = store_attr_mutations or {} + self.keepalive = keepalive or [] + self.save_for_backward = save_for_backward or [] + self.tensor_hooks = tensor_hooks or {} + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SideEffects) + # NB: do NOT test keepalive + return ( + self.id_to_variable == other.id_to_variable + and self.store_attr_mutations == other.store_attr_mutations + and self.save_for_backward == other.save_for_backward + and self.tensor_hooks == other.tensor_hooks + ) + + def diff(self, other: "SideEffects") -> Optional[str]: + if self.id_to_variable != other.id_to_variable: + sk_itv = self.id_to_variable.keys() + ok_itv = other.id_to_variable.keys() + if sk_itv != ok_itv: + return f"id_to_variable keys: {sk_itv} != {ok_itv}" + # Feel free to augment this with more fancy diffing logic + # if needed for debugging + return "id_to_variable: unknown diff" + elif self.store_attr_mutations != other.store_attr_mutations: + sk_sam = self.store_attr_mutations.keys() + ok_sam = other.store_attr_mutations.keys() + if sk_sam != ok_sam: + return f"store_attr_mutations keys: {sk_sam} != {ok_sam}" + return "store_attr_mutations: unknown diff" + elif self.save_for_backward != other.save_for_backward: + return "save_for_backward" + elif self.tensor_hooks != other.tensor_hooks: + return "tensor_hooks" + else: + return None + + def clone(self): + """Create a shallow copy""" + return self.__class__( + id_to_variable=dict(self.id_to_variable), + store_attr_mutations={ + k: dict(v) for k, v in self.store_attr_mutations.items() + }, + keepalive=list(self.keepalive), + save_for_backward=self.save_for_backward, + tensor_hooks=self.tensor_hooks, + ) + + def apply(self, fn, cache=None, skip_fn=lambda _: False): + if cache is None: + cache = dict() + + self.id_to_variable = { + k: VariableTracker.apply(fn, v, cache, skip_fn) + for k, v in self.id_to_variable.items() + } + self.store_attr_mutations = { + k: VariableTracker.apply(fn, v, cache, skip_fn) + for k, v in self.store_attr_mutations.items() + } + self.save_for_backward = VariableTracker.apply( + fn, self.save_for_backward, cache, skip_fn + ) + self.tensor_hooks = VariableTracker.apply(fn, self.tensor_hooks, cache, skip_fn) + + def __contains__(self, item): + return id(item) in self.id_to_variable + + def __getitem__(self, item): + return self.id_to_variable[id(item)] + + def check_allowed_side_effect(self, item): + from torch._dynamo.variables.misc import AutogradFunctionContextVariable + + # People do things like self.dim = dim inside autograd.Function. + # These are benign. 
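+        # e.g. (illustrative sketch):
+        #     class MyFn(torch.autograd.Function):
+        #         @staticmethod
+        #         def forward(ctx, x, dim):
+        #             ctx.dim = dim  # benign attribute mutation on the context
+        #             return x.sum(dim)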
+ if isinstance(item, AutogradFunctionContextVariable): + return True + if not is_side_effect_safe(item.mutable_local): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" + ) + + def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): + assert self.is_attribute_mutation(item) + self.check_allowed_side_effect(item) + if item.mutable_local not in self.store_attr_mutations: + self.store_attr_mutations[item.mutable_local] = {} + self.store_attr_mutations[item.mutable_local][name] = value + + def load_attr(self, item, name, deleted_ok=False): + assert self.is_attribute_mutation(item) + result = self.store_attr_mutations[item.mutable_local][name] + if not deleted_ok and isinstance(result, variables.DeletedVariable): + unimplemented("read deleted attribute") + return result + + def store_cell(self, cellvar, value): + assert isinstance(cellvar, variables.NewCellVariable) + assert isinstance(value, variables.VariableTracker) + self.store_attr(cellvar, "cell_contents", value) + + def load_cell(self, cellvar): + assert isinstance(cellvar, variables.NewCellVariable) + return self.load_attr(cellvar, "cell_contents") + + def load_global(self, gvar: VariableTracker, name: str): + assert isinstance(gvar, variables.VariableTracker) + return self.load_attr(gvar, name) + + def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): + assert isinstance(gvar, variables.VariableTracker) + assert isinstance(value, variables.VariableTracker) + self.store_attr(gvar, name, value) + + @staticmethod + def cls_supports_mutation_side_effects(cls): + return inspect.getattr_static(cls, "__setattr__", None) in ( + object.__setattr__, + torch.nn.Module.__setattr__, + ) + + def is_attribute_mutation(self, item): + return isinstance(item.mutable_local, AttributeMutation) + + def has_pending_mutation(self, item): + return self.is_attribute_mutation(item) and bool( + self.store_attr_mutations.get(item.mutable_local) + ) + + def is_modified(self, item): + if isinstance(item.mutable_local, AttributeMutationNew): + return True + if self.is_attribute_mutation(item): + return item.mutable_local in self.store_attr_mutations + return item.mutable_local.is_modified + + def _track_obj( + self, + item: Any, + variable: VariableTracker, + mutable_cls=MutableSideEffects, + ): + """Start tracking a new variable for mutation""" + assert variable.source is not None + variable.mutable_local = mutable_cls(variable.source) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + track_mutable = _track_obj + + def track_object_existing( + self, + item: Any, + variable: VariableTracker, + ): + return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting) + + def track_object_new( + self, + cls_source: Source, + user_cls: Any, + variable_cls: Any, + options, + ): + if user_cls is torch.autograd.function.FunctionCtx: + obj = torch.autograd.Function() + elif issubclass(user_cls, torch.nn.Module): + obj = nn_module_new(user_cls) + else: + obj = object_new(user_cls) + variable = variable_cls( + obj, + mutable_local=AttributeMutationNew(None, cls_source), + **options, + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def track_cell_new( + self, + ): + obj = object() + variable = variables.NewCellVariable( + mutable_local=AttributeMutationNew(None, None), + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def 
track_cell_existing(self, source: Source, item: Any): + variable = variables.NewCellVariable( + mutable_local=AttributeMutationExisting(source), + ) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + def track_global_existing(self, source: Source, item: Any): + variable = variables.NewGlobalVariable( + mutable_local=AttributeMutationExisting(source), + ) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + def track_save_for_backward(self, ctx, args): + assert isinstance(ctx, variables.AutogradFunctionContextVariable) + self.save_for_backward.append((ctx, args)) + + def track_tensor_variables_from_runahead_side_effects(self, other): + # In higher order ops we want to keep track of tensors seen in the + # speculate_subgraph so that we don't lift them again as a new input in + # other speculate_subgraph or in the root tracer. + for other_item in other.keepalive: + other_id = id(other_item) + other_variable = other.id_to_variable[other_id] + if other_id not in self.id_to_variable and isinstance( + other_variable, variables.TensorVariable + ): + self.track_object_existing(other_item, other_variable) + + def prune_dead_object_new(self, tx): + live_new_objects = set() + skip_obj = None + + def visit(var: VariableTracker): + if ( + isinstance(var.mutable_local, AttributeMutationNew) + and var.mutable_local is not skip_obj + ): + live_new_objects.add(var.mutable_local) + return var + + def is_live(var: Union[MutableLocalBase, VariableTracker]): + if isinstance(var, AttributeMutationNew): + return var in live_new_objects + if isinstance(var, VariableTracker): + return is_live(var.mutable_local) + return True + + VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals)) + for var in self.id_to_variable.values(): + if not isinstance(var.mutable_local, AttributeMutationNew): + VariableTracker.apply(visit, var) + + for skip_obj, setattrs in self.store_attr_mutations.items(): + VariableTracker.apply(visit, setattrs) + + self.id_to_variable = { + k: v for k, v in self.id_to_variable.items() if is_live(v) + } + self.store_attr_mutations = { + k: v for k, v in self.store_attr_mutations.items() if is_live(k) + } + + def mutation(self, var): + self.check_allowed_side_effect(var) + if isinstance(var.mutable_local, MutableSideEffects): + var.mutable_local = MutableSideEffects(var.mutable_local.source, True) + + def _get_modified_vars(self): + return [var for var in self.id_to_variable.values() if self.is_modified(var)] + + def codegen_save_tempvars(self, cg: PyCodegen): + for var in self._get_modified_vars(): + if isinstance( + var.mutable_local, (AttributeMutationExisting, AttributeMutationNew) + ) and isinstance(var, variables.NewCellVariable): + cg.load_import_from(utils.__name__, "make_cell") + cg.extend_output(create_call_function(0, True)) + cg.add_cache(var) + if isinstance(var.mutable_local, AttributeMutationNew): + var.mutable_local.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + elif isinstance(var.mutable_local, AttributeMutationNew): + if isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented("AutogradFunctionContextVariable escaped") + if "__call_nn_module_init" in self.store_attr_mutations.get( + var.mutable_local, {} + ): + assert isinstance(var, variables.UnspecializedNNModuleVariable) + cg.load_import_from(utils.__name__, "nn_module_new") + else: + cg.load_import_from(utils.__name__, "object_new") + cg(var.mutable_local.cls_source) + 
cg.extend_output(create_call_function(1, True)) + cg.add_cache(var) + var.mutable_local.source = LocalSource(cg.tempvars[var]) + elif var in cg.tempvars: + assert cg.tempvars.get(var) is None + # subsequent usage should point to the original variable + cg(var.mutable_local.source) + cg.add_cache(var) + + for ctx, args in self.save_for_backward: + cg(ctx.source) + cg.extend_output( + [create_instruction("LOAD_METHOD", argval="save_for_backward")] + ) + for arg in args: + cg(arg) + cg.extend_output( + [ + *create_call_method(len(args)), + create_instruction("POP_TOP"), + ] + ) + + def register_hook(self, tensor, hook, handle, name): + assert isinstance(tensor, variables.TensorVariable) + assert isinstance(hook, variables.VariableTracker) + assert ( + isinstance(handle, variables.RemovableHandleVariable) + and handle.mutable_local + ) + assert hasattr(torch.Tensor, name) + idx = len(self.tensor_hooks.keys()) + # duplicate index possible because of self.remove_hook() + while idx in self.tensor_hooks: + idx += 1 + self.tensor_hooks[idx] = (tensor, hook, handle, name) + assert not handle.idx + handle.idx = idx + + def remove_hook(self, idx): + del self.tensor_hooks[idx] + + def codegen_hooks(self, cg): + for ( + tensor, + hook, + handle, + name, + ) in self.tensor_hooks.values(): + # Note: [On tensor.register_hook] + # + # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented + # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries). + # + # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph. + # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in + # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able + # tensors. Because a source indicates knowledge of this object outside the torch compile region, and + # because we are running residuals firmly before .backward() can be run, it is sound to invoke + # `register_hook` on a known tensor. + # + # For tensors without a source, we support a limited subset of hooks. Global functions only, and + # compiled_autograd must be enabled or we will graph break. + # + # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the + # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed + # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the + # stack intact. + # + # Dynamo Tensor Hooks Workflow: + # - Functions passed to register_hook are lifted globally. + # - For tensors with sources: + # - In the "side_effects" phase of codegen, we iterate over tensors with hooks to: + # - Generate the tensor. + # - Issue a register_hook call on the tensor, linking to the globally stored function. + # - Incorporate a handle if one was established in the eager phase. + # - For tensors without sources: + # - We don't generate any instructions for registering a hook. + # - Handles from intermediary hooks are NYI. + # - We produce a call function that utilizes the trace_wrapped higher order op, closing over it. + # - We then manually insert the call function above into the graph. + # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. 
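+            # Illustrative sketch (placeholder names, not the literal generated
+            # bytecode): for an input tensor `x` hooked via `h = x.register_hook(fn)`,
+            # the residual emitted below behaves like
+            #     handle = x.register_hook(lifted_fn)  # lifted_fn was stashed in globals
+            # and `handle` is cached so a user-kept handle reconstructs correctly.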
+ assert tensor.source, "Hooks on non input tensors NYI - should not get here" + cg(tensor) + cg.extend_output([cg.create_load_attr(name)]) + cg(hook) + cg.extend_output(create_call_function(1, True)) + + # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will + # be associated with the return value of register_hook(). This consumes the top of stack. + cg.add_cache(handle) + + def codegen_update_mutated(self, cg: PyCodegen): + suffixes = [] + for var in self._get_modified_vars(): + if isinstance(var, variables.ListVariable): + # old[:] = new + cg(var, allow_cache=False) + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output( + [ + cg.create_load_const(None), + cg.create_load_const(None), + create_instruction("BUILD_SLICE", arg=2), + ] + ) + suffixes.append([create_instruction("STORE_SUBSCR")]) + elif isinstance(var, variables.ConstDictVariable): + cg.tx.output.update_co_names("clear") + cg.tx.output.update_co_names("update") + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output([create_instruction("LOAD_METHOD", argval="update")]) + cg(var, allow_cache=False) + + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")]) + + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + *create_call_method(1), # update + create_instruction("POP_TOP"), + ] + ) + elif self.is_attribute_mutation(var): + for name, value in self.store_attr_mutations.get( + var.mutable_local, {} + ).items(): + if isinstance(var, variables.NewGlobalVariable): + cg.tx.output.update_co_names(name) + cg(value) + suffixes.append( + [create_instruction("STORE_GLOBAL", argval=name)] + ) + elif name == "__call_nn_module_init": + pass # handled in codegen_save_tempvars + elif isinstance(value, variables.DeletedVariable): + if isinstance( + var.mutable_local, AttributeMutationExisting + ) and hasattr(getattr(var, "value", None), name): + cg.tx.output.update_co_names(name) + cg(var.mutable_local.source) + suffixes.append( + [create_instruction("DELETE_ATTR", argval=name)] + ) + else: + cg.tx.output.update_co_names(name) + cg(value) + cg(var.mutable_local.source) + suffixes.append([create_instruction("STORE_ATTR", argval=name)]) + elif isinstance(var, variables.TupleIteratorVariable): + for _ in range(var.index): + cg.load_import_from(utils.__name__, "iter_next") + cg(var.mutable_local.source) # type: ignore[attr-defined] + cg.extend_output(create_call_function(1, True)) + cg.append_output(create_instruction("POP_TOP")) + else: + raise AssertionError(type(var)) + + # do all the actual mutations at the very end to handle dependencies + for suffix in reversed(suffixes): + cg.extend_output(suffix) + + def is_empty(self): + return not ( + any(map(self.is_modified, self.id_to_variable.values())) + or self.tensor_hooks + or self.save_for_backward + or self.tensor_hooks + ) + + def clear(self): + self.keepalive.clear() + self.id_to_variable.clear() diff --git a/MLPY/Lib/site-packages/torch/_dynamo/source.py b/MLPY/Lib/site-packages/torch/_dynamo/source.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ca75c4eaa1f910ff8d9384074a873072c48c37 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/source.py @@ -0,0 +1,545 @@ +import collections +import dataclasses +import enum +from typing import Any, Optional, Union + +from torch._guards import ChainedSource, GuardSource, Source + +from . 
import utils +from .bytecode_transformation import create_call_function, create_instruction +from .utils import enum_repr + +# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, +# so those cases are omitted intentionally +_GUARD_SOURCE_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE, + GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE, + GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE, +} + +_GUARD_SOURCE_FSDP_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + +_GUARD_SOURCE_NOT_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL, + GuardSource.GLOBAL: GuardSource.GLOBAL, + GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL, + GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL, +} + + +def is_constant_source(source): + if isinstance(source, ConstantSource): + return True + try: + if source.guard_source() == GuardSource.CONSTANT: + return True + except NotImplementedError: + pass + + return False + + +def reconstruct_getitem( + source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice +): + source.base.reconstruct(codegen) + if isinstance(source.index, Source): + source.index.reconstruct(codegen) + else: + if index_is_slice: + assert isinstance(source, GetItemSource) + codegen.append_output(codegen.create_load_const(source.unpack_slice())) + else: + codegen.append_output(codegen.create_load_const(source.index)) + + +@dataclasses.dataclass(frozen=True) +class LocalSource(Source): + local_name: str + cell_or_freevar: bool = False + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.LOCAL + + def name(self): + return f"L[{repr(self.local_name)}]" + + +@dataclasses.dataclass(frozen=True) +class SyntheticLocalSource(Source): + local_name: str + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.SYNTHETIC_LOCAL + + def name(self): + return f"SYNTHETIC_LOCAL[{self.local_name!r}]" + + +@dataclasses.dataclass(frozen=True) +class RandomValueSource(Source): + random_call_index: int + + def guard_source(self): + return GuardSource.RANDOM_VALUE + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) + codegen.append_output(codegen.create_load_const(self.random_call_index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def name(self): + return f"random_value_{self.random_call_index}" + + +@dataclasses.dataclass(frozen=True) +class GlobalSource(Source): + global_name: str + + def reconstruct(self, codegen): + codegen.append_output( + codegen.create_load_global(self.global_name, False, add=True) + ) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]" + + +@dataclasses.dataclass(frozen=True) +class GlobalWeakRefSource(Source): + global_name: str + + def reconstruct(self, codegen): + codegen.append_output( + 
codegen.create_load_global(self.global_name, True, add=True) + ) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]()" + + +@dataclasses.dataclass(frozen=True) +class AttrSource(ChainedSource): + member: str + get_static: bool = False + + def __post_init__(self): + assert self.base, "Can't construct an AttrSource without a valid base source" + if "." in self.member: + member_parts = self.member.split(".") + object.__setattr__( + self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) + ) + object.__setattr__(self, "member", member_parts[-1]) + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if self.get_static: + return f"inspect.getattr_static({self.base.name()}, {self.member!r})" + elif not self.member.isidentifier(): + return f"getattr({self.base.name()}, {self.member!r})" + return f"{self.base.name()}.{self.member}" + + +@dataclasses.dataclass(frozen=True) +class ParamBufferSource(AttrSource): + def guard_source(self): + return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()] + + +# This source is intended to be used in places where a source is needed but it is expected +# that the symbol will be simplified out later on. Symbols with ephemeral sources are +# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral +# source. Guarding on this source is an error. +# +# Example: During subclass view fake-ification, any close-over ViewFunc state should be +# symbolicized / fake-ified to avoid invalid specialization during view replay. This source +# is useful for symbols utilized in the middle of the view chain that are not expected to be +# present within the final view shape metadata. 
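+# For instance (illustrative): if a guard-relevant relation like s_ephemeral == s_other
+# arises, the ephemeral symbol is the one preferentially substituted away, so it never
+# shows up in the final guards (make_guard on this source raises).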
+@dataclasses.dataclass(frozen=True) +class EphemeralSource(Source): + desc: Optional[str] = None + + def guard_source(self): + return GuardSource.EPHEMERAL + + def name(self): + return f"" + + def make_guard(self): + raise NotImplementedError() + + def is_ephemeral(self): + return True + + +class TensorProperty(enum.Enum): + SIZE = 0 + STRIDE = 1 + STORAGE_OFFSET = 2 + + def method_name(self): + if self is TensorProperty.SIZE: + return "size" + elif self is TensorProperty.STRIDE: + return "stride" + elif self is TensorProperty.STORAGE_OFFSET: + return "storage_offset" + + +@dataclasses.dataclass(frozen=True) +class TensorPropertySource(ChainedSource): + prop: TensorProperty + idx: Optional[int] = None # None for STORAGE_OFFSET + + def __post_init__(self): + assert self.base is not None + if self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + else: + assert self.idx is not None + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_attr(self.prop.method_name())) + if self.idx is not None: + codegen.append_output(codegen.create_load_const(self.idx)) + codegen.extend_output( + create_call_function(1 if self.idx is not None else 0, True) + ) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if self.prop is TensorProperty.SIZE: + return f"{self.base.name()}.size()[{self.idx}]" + elif self.prop is TensorProperty.STRIDE: + return f"{self.base.name()}.stride()[{self.idx}]" + elif self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + return f"{self.base.name()}.storage_offset()" + else: + raise AssertionError(f"unhandled {self.prop}") + + +@dataclasses.dataclass(frozen=True) +class NegateSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + raise NotImplementedError() + + def guard_source(self): + return self.base.guard_source() + + def name(self): + # NB: use method call so that function stripping regexes work + return f"{self.base.name()}.__neg__()" + + +@dataclasses.dataclass(frozen=True) +class ConvertIntSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"cast_symbool_to_symint_guardless({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class DefaultsSource(ChainedSource): + idx_key: Union[int, str] + is_kw: bool = False + field: str = dataclasses.field(init=False, repr=False, compare=False) + _name: str = dataclasses.field(init=False, repr=False, compare=False) + + def __post_init__(self): + assert ( + self.base + ), "Base must be a valid source in order to properly track and guard this Defaults to its origin." 
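+        # e.g. (illustrative): for `def f(a, b=3, *, k=5)`, the positional default of
+        # `b` is tracked via f.__defaults__[0] and the kw-only default of `k` via
+        # f.__kwdefaults__['k'].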
+ if self.is_kw: + assert isinstance(self.idx_key, str) + object.__setattr__(self, "field", "__kwdefaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" + ) + else: + assert isinstance(self.idx_key, int) + object.__setattr__(self, "field", "__defaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" + ) + + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(codegen.create_load_attrs(self.field)) + codegen.append_output(codegen.create_load_const(self.idx_key)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return self._name + + +@dataclasses.dataclass(frozen=True) +class GetItemSource(ChainedSource): + index: Any + index_is_slice: bool = False + + def __post_init__(self): + assert self.base is not None + if isinstance(self.index, slice): + # store the hashable version of the slice so the whole GetItemSource is hashable + super().__setattr__("index", self.index.__reduce__()) + super().__setattr__("index_is_slice", True) + + def reconstruct(self, codegen): + reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def unpack_slice(self): + assert self.index_is_slice + slice_class, slice_args = self.index + return slice_class(*slice_args) + + def name(self): + # Index can be of following types + # 1) ConstDictKeySource + # 2) enum.Enum + # 3) index is a slice - example 1:4 + # 4) index is a constant - example string, integer + if isinstance(self.index, Source): + if not isinstance(self.index, ConstDictKeySource): + raise ValueError( + "GetItemSource index must be a constant, enum or ConstDictKeySource" + ) + return f"{self.base.name()}[{self.index.name()}]" + elif self.index_is_slice: + return f"{self.base.name()}[{self.unpack_slice()!r}]" + elif isinstance(self.index, enum.Enum): + return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]" + else: + return f"{self.base.name()}[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class ConstDictKeySource(GetItemSource): + def is_dict_key(self): + return True + + def reconstruct(self, codegen): + codegen.load_import_from(utils.__name__, "dict_keys_getitem") + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, True)) + + def name(self): + # The list creation will be CSE'd by PyExprCSEPass + return f"list({self.base.name()}.keys())[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class TupleIteratorGetItemSource(GetItemSource): + def reconstruct(self, codegen): + codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") + self.base.reconstruct(codegen) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, True)) + + def name(self): + return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" + + +@dataclasses.dataclass(frozen=True) +class TypeSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + codegen.load_import_from("builtins", "type") + self.base.reconstruct(codegen) + codegen.extend_output(create_call_function(1, True)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + 
return f"type({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class ODictGetItemSource(ChainedSource): + index: Any + + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen): + codegen.append_output( + codegen._create_load_const(collections.OrderedDict.__getitem__) + ) + reconstruct_getitem(self, codegen, index_is_slice=False) + codegen.extend_output(create_call_function(2, True)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if isinstance(self.index, type): + rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}' + return f"___odict_getitem({self.base.name()}, {rep})" + elif isinstance(self.index, Source): + return f"___odict_getitem({self.base.name()}, {self.index.name()})" + else: + return f"___odict_getitem({self.base.name()}, {self.index!r})" + + +@dataclasses.dataclass(frozen=True) +class NNModuleSource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()] + + def name(self): + return self.base.name() + + +@dataclasses.dataclass(frozen=True) +class NotNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class FSDPNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class GlobalStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.GLOBAL + + +@dataclasses.dataclass(frozen=True) +class ConstantSource(Source): + source_name: str + + def reconstruct(self, codegen): + codegen.append_output( + codegen.create_load_global(self.source_name, False, add=False) + ) + + def guard_source(self): + return GuardSource.CONSTANT + + def name(self): + return self.source_name + + def make_guard(self, fn): + raise NotImplementedError() + + +@dataclasses.dataclass(frozen=True) +class NumpyTensorSource(ChainedSource): + def name(self) -> str: + return f"___from_numpy({self.base.name()})" + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen): + codegen.load_import_from("torch", "as_tensor") + self.base.reconstruct(codegen) + codegen.extend_output(create_call_function(1, True)) + + +# This is a synthetic source that is associated with the singleton +# shape env guard we always register for all frames. 
We get the actual +# guard contents from the ambient ShapeEnv +@dataclasses.dataclass(frozen=True) +class ShapeEnvSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.SHAPE_ENV + + +@dataclasses.dataclass(frozen=True) +class BackwardStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.BACKWARD_STATE + + +def is_from_local_source(source: Source, *, allow_cell_or_freevar=True): + if isinstance(source, ChainedSource): + return is_from_local_source( + source.base, allow_cell_or_freevar=allow_cell_or_freevar + ) + if not isinstance(source, LocalSource): + return False + if not allow_cell_or_freevar and source.cell_or_freevar: + return False + return True + + +# TODO: can probably write a generic "test this on everything in the chain" +# helper +def is_from_defaults(source: Source): + if isinstance(source, DefaultsSource): + return True + if isinstance(source, ChainedSource): + return is_from_defaults(source.base) + return False diff --git a/MLPY/Lib/site-packages/torch/_dynamo/symbolic_convert.py b/MLPY/Lib/site-packages/torch/_dynamo/symbolic_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..1a624fd411db334d373c1340afb86db2255b5c0e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/symbolic_convert.py @@ -0,0 +1,2603 @@ +import collections +import contextlib +import copy +import dataclasses +import dis +import functools +import importlib +import inspect +import itertools +import linecache +import logging +import operator +import sys +import textwrap +import threading +import traceback +import types +import typing +import weakref +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Type +from unittest.mock import patch + +import torch +import torch._logging +from torch._guards import Checkpointable, tracing, TracingContext + +from . 
import config, exc, logging as torchdynamo_logging, trace_rules, variables +from .bytecode_analysis import ( + get_indexof, + JUMP_OPNAMES, + livevars_analysis, + propagate_line_nums, +) +from .bytecode_transformation import ( + cleaned_instructions, + create_call_function, + create_instruction, + create_jump_absolute, + Instruction, + is_generator, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .current_scope_id import current_scope_id +from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported +from .funcname_cache import get_funcname +from .guards import GuardBuilder, install_guard +from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState +from .replay_record import DummyModule, ExecutionRecorder +from .resume_execution import ContinueExecutionCache, ReenterWith +from .source import ( + AttrSource, + GetItemSource, + GlobalSource, + GlobalWeakRefSource, + LocalSource, + Source, +) +from .trace_rules import is_builtin_constant, is_forbidden +from .utils import ( + counters, + get_fake_value, + get_instruction_source_311, + graph_break_dup_warning_checker, + istype, + LazyString, + proxy_args_kwargs, +) +from .variables.base import ( + _is_top_level_scope, + is_side_effect_safe, + MutableLocal, + typestr, + VariableTracker, +) +from .variables.builder import VariableBuilder, wrap_fx_proxy +from .variables.builtin import BuiltinVariable +from .variables.constant import ConstantVariable +from .variables.ctx_manager import ( + ContextWrappingVariable, + GenericContextWrappingVariable, + WithExitFunctionVariable, +) +from .variables.dicts import ConstDictVariable, SetVariable +from .variables.functions import ( + BaseUserFunctionVariable, + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) +from .variables.lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SliceVariable, + TupleVariable, +) +from .variables.misc import ( + ClosureVariable, + GetAttrVariable, + InlinedClosureVariable, + NullVariable, + PythonModuleVariable, + UnknownVariable, +) +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + supported_const_comparison_ops, + supported_tensor_comparison_ops, + SymNodeVariable, + TensorVariable, +) +from .variables.user_defined import ( + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedObjectVariable, + UserDefinedVariable, +) + +log = logging.getLogger(__name__) +graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") +trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source") +tls = threading.local() + + +@dataclasses.dataclass +class SpeculationEntry: + filename: str + lineno: int + instruction_pointer: int + failed: bool = False + reason: Optional[GraphCompileReason] = None + + def fail_and_restart_analysis(self): + """ + Start tracing of the current frame over again, and don't take this branch. + """ + self.failed = True + raise exc.SpeculationRestartAnalysis() + + +@dataclasses.dataclass +class SpeculationLog: + """ + SpeculationLog replaces the prior copy_graphstate/restore_graphstate + checkpointing. Rather than saving/restoring state, we restart the + dynamo conversion process over from the beginning -- but when we + hit the start of the speculation that failed, we instead generate + a graph break. 
+ """ + + entries: List[SpeculationEntry] = dataclasses.field(default_factory=list) + index: int = 0 + + def restart(self): + self.index = 0 + + def clear(self): + self.entries.clear() + self.index = 0 + + def next(self, filename: str, lineno: int, instruction_pointer) -> SpeculationEntry: + """ + Lookup or create a SpeculationEntry() that is shared across + RestartAnalysis calls. Args are used only for debug checks. + """ + if len(self.entries) == self.index: + self.entries.append(SpeculationEntry(filename, lineno, instruction_pointer)) + entry = self.entries[self.index] + self.index += 1 + assert ( + entry.instruction_pointer == instruction_pointer + and entry.filename == filename + and entry.lineno == lineno + ), textwrap.dedent( + f""" + SpecuationLog diverged at {self.index} of {len(self.entries)}: + - Run1: {entry.filename}:{entry.lineno} (ip={entry.instruction_pointer}) + - Run2: {filename}:{lineno} (ip={instruction_pointer}) + Please submit a bug report. + """ + ) + return entry + + +@functools.lru_cache(None) +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@dataclasses.dataclass +class BlockStackEntry: + target: Instruction + stack_index: Optional[int] = None + with_context: Optional[ContextWrappingVariable] = None + + def can_restore(self): + return self.with_context is not None + + def resume_fn(self): + assert self.stack_index is not None + if self.with_context and self.with_context.target_values: + return ReenterWith(self.stack_index, tuple(self.with_context.target_values)) + else: + return ReenterWith(self.stack_index) + + def exit(self, tx): + assert self.with_context is not None + return self.with_context.exit(tx) + + +class InstructionTranslatorGraphState(NamedTuple): + output: OutputGraphState + symbolic_locals: Dict[str, VariableTracker] + stack: List[VariableTracker] + block_stack: List[BlockStackEntry] + instruction_pointer: Optional[int] + current_instruction: Instruction + next_instruction: Optional[Instruction] + lineno: int + + def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]: + for k in self._fields: + if k == "output": + return self.output.diff(other.output, prefix=f"{k}.") + sv = getattr(self, k) + ov = getattr(other, k) + if sv != ov: + return f"{k} mismatch: {sv} != {ov}" + return None + + +def stack_op(fn: typing.Callable[..., object]): + nargs = len(inspect.signature(fn).parameters) + fn_var = BuiltinVariable(fn) + + @functools.wraps(fn) + def impl(self: "InstructionTranslatorBase", inst: Instruction): + self.push(fn_var.call_function(self, self.popn(nargs), {})) + + return impl + + +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", + truth_fn: typing.Callable[[object], bool], + push: bool, +): + # Detect if this jump instruction is assert and normalize the assert + # by pushing dummy error message when nothing is given. 
+ # + # Python 3.9 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # Python 3.8 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + assert isinstance(self.instruction_pointer, int) + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if sys.version_info < (3, 9): + if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": + return False + else: + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + # Use dummy error message if its hard to extract + error_msg = "assertion error" + + inst = self.instructions[current_instruction_pointer] + # DETECT RAISE_VARARGS or LOAD CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + error_msg = inst.argval + + # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION + # (PRECALL for Python 3.11+) + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + if inst.opname not in ("CALL_FUNCTION", "PRECALL"): + return False + + # for Python 3.11+, PRECALL should be followed by CALL, then RAISE_VARARGS + # for Python < 3.11, CALL_FUNCTION should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if inst.opname == "PRECALL": + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + self.push(ConstantVariable.create(error_msg)) + + return True + + +def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): + def inner(self: "InstructionTranslatorBase", inst: Instruction): + value: VariableTracker = self.pop() + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + # Skip over things like `assert True` + if value.is_python_constant() and bool(value.as_python_constant()): + self.jump(inst) + return + + # TODO maybe should respect DtoH sync intention of users later?? + # Manually insert torch._assert_async instead of python assert and jump over + # assert related instructions as we don't need them anymore. + + # if we see Tensor as assert statement, no need to call scalar_tensor + if isinstance(value, TensorVariable): + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((value, error_msg), {}), + ) + self.jump(inst) + return + + if isinstance(value, SymNodeVariable): + # if the assertion is normal shape expression. + # just install guard and bail out. + sym_expr = value.sym_num + if not isinstance(sym_expr, torch.SymBool): + sym_expr = sym_expr != 0 + + result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr) + if not result: + raise unimplemented( + "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?" 
+ ) + self.jump(inst) + return + + scalar_to_tensor_proxy = self.output.create_proxy( + "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {}) + ) + + scalar_to_tensor = wrap_fx_proxy( + self, + scalar_to_tensor_proxy, + example_value=get_fake_value(scalar_to_tensor_proxy.node, self), + ) + + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((scalar_to_tensor, error_msg), {}), + ) + self.jump(inst) + return + + if value.is_python_constant(): + if truth_fn(value.as_python_constant()): + push and self.push(value) + self.jump(inst) + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): + # compile a partial subgraph prefix then jump into user code + if self.has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) + + self.push(value) + log.debug("generic_jump triggered compile") + self.output.compile_subgraph( + self, + reason=GraphCompileReason( + f"generic_jump {typestr(value)}", [self.frame_summary()] + ), + ) + self.pop() + + if_next = self.create_call_resume_at(self.next_instruction) + push and self.push(value) + if_jump = self.create_call_resume_at(inst.target) + + self.output.add_output_instructions( + [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump + ) + elif isinstance(value, NNModuleVariable): + # Equivalent of "self.nn_module is not None" + mod = self.output.get_submodule(value.module_key) + if truth_fn(mod): + push and self.push(value) + self.jump(inst) + elif isinstance(value, UserDefinedObjectVariable): + x = value.var_getattr(self, "__bool__") + # if __bool__ is missing, trying __len__ to infer a truth value. + if isinstance(x, GetAttrVariable): + x = value.var_getattr(self, "__len__") + + # __bool__ or __len__ is function + if isinstance(x, UserMethodVariable): + result = x.call_function(self, [], {}) + if isinstance(result, ConstantVariable) and isinstance( + result.value, (bool, int) + ): + if truth_fn(result.value): + push and self.push(value) + self.jump(inst) + else: + unimplemented( + "generic_jump on UserDefined with __bool__ returning non-constant" + ) + # __bool__ or __len__ is non-function or not existed in the user defined object + else: + if truth_fn(True): + push and self.push(value) + self.jump(inst) + elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence( + self + ): + if truth_fn(len(value.unpack_var_sequence(self))): + push and self.push(value) + self.jump(inst) + elif isinstance(value, SymNodeVariable): + eval_result = value.evaluate_expr(self.output) + if truth_fn(eval_result): + push and self.push(value) + self.jump(inst) + elif isinstance(value, variables.BackwardHookVariable): + if truth_fn(True): + push and self.push(value) + self.jump(inst) + else: + from .source import is_constant_source + + if value.source is not None and is_constant_source(value.source): + if truth_fn(value.get_real_value()): # type: ignore[attr-defined] + push and self.push(value) + self.jump(inst) + else: + # TODO link the torch.cond doc later + raise exc.UserError( + exc.UserErrorType.DYNAMIC_CONTROL_FLOW, + "Dynamic control flow is not supported at the moment. 
Please use " + "functorch.experimental.control_flow.cond to explicitly capture the control flow.", + case_name="cond_operands", + ) + + return inner + + +explain = False + + +def break_graph_if_unsupported(*, push): + def decorator(inner_fn): + @functools.wraps(inner_fn) + def wrapper(self: "InstructionTranslatorBase", inst: Instruction): + speculation = self.speculate() + if speculation.failed: + assert speculation.reason is not None + return handle_graph_break(self, inst, speculation.reason) + try: + TracingContext.set_current_loc( + self.f_code.co_filename, self.lineno, self.f_code.co_name + ) + return inner_fn(self, inst) + except Unsupported as excp: + if self.generic_context_manager_depth > 0: + # We don't support graph break under GenericContextWrappingVariable, + # If there is, we roll back to the checkpoint and fall back. + excp.remove_from_stats() + unimplemented("Graph break under GenericContextWrappingVariable") + + if isinstance(excp, exc.UncapturedHigherOrderOpError): + raise + + if not self.should_compile_partial_graph(): + raise + + user_stack = excp.real_stack + # TODO: Also report the traceback from the parent frame + user_stack_formatted = "".join(traceback.format_list(user_stack)) + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + # torch._dynamo.explain() formats this a little nicer, and presents a slightly + # more actionable user code pointer + if ( + graph_break_log.isEnabledFor(logging.DEBUG) + and not explain + and graph_break_dup_warning_checker.add(frame_loc) + ): + # This log line is exercised from + # python test/dynamo/test_exc.py -k test_graph_break_log + graph_break_log.debug( + "Graph break: from user code at:\n%s", + user_stack_formatted, + exc_info=True, + ) + else: + # This log line MUST NOT contain the string "Graph break", + # exercised by + # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log + log.debug( + "Unsupported break in user code at %s:%s (details suppressed)", + *frame_loc, + ) + + if self.has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) from excp + + excp.remove_from_stats() + excp.add_to_stats("graph_break") + speculation.reason = GraphCompileReason(excp.msg, user_stack) + speculation.fail_and_restart_analysis() + + def handle_graph_break( + self: "InstructionTranslatorBase", + inst: Instruction, + reason: GraphCompileReason, + ): + self.output.compile_subgraph(self, reason=reason) + cg = PyCodegen(self) + cleanup: List[Instruction] = [] + # Reconstruct the context variables in the block stack + for b in self.block_stack: + assert b.with_context is not None + cg(b.with_context) + cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) + self.output.add_output_instructions(cg.get_instructions()) + del cg + + if sys.version_info >= (3, 11) and inst.opname == "CALL": + kw_names = ( + self.kw_names.as_python_constant() + if self.kw_names is not None + else () + ) + if len(kw_names) > 0: + self.output.add_output_instructions( + [create_instruction("KW_NAMES", argval=kw_names)] + ) + self.output.add_output_instructions( + create_call_function(inst.arg, False) + ) + else: + # copy instruction, but without exception table data + assert inst.target is None + inst_copy = copy.copy(inst) + inst_copy.exn_tab_entry = None + self.output.add_output_instructions([inst_copy]) + + self.output.add_output_instructions(cleanup) + + if sys.version_info >= (3, 11) and inst.opname == "CALL": + # 
stack effect for PRECALL + CALL is split between the two instructions + stack_effect = dis.stack_effect( + dis.opmap["PRECALL"], inst.arg + ) + dis.stack_effect(dis.opmap["CALL"], inst.arg) + else: + stack_effect = dis.stack_effect(inst.opcode, inst.arg) + self.popn(push - stack_effect) + + for _ in range(push): + self.push(UnknownVariable()) + self.output.add_output_instructions( + self.create_call_resume_at(self.next_instruction) + ) + + return wrapper + + return decorator + + +class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]): + output: OutputGraph + symbolic_locals: Dict[str, VariableTracker] + symbolic_globals: Dict[str, VariableTracker] + stack: List[VariableTracker] + instruction_pointer: Optional[int] + current_instruction: Instruction + next_instruction: Optional[Instruction] + block_stack: List[BlockStackEntry] + lineno: int + kw_names: Optional[ConstantVariable] + accept_prefix_inst: bool + prefix_insts: List[Instruction] + inline_depth: int + inconsistent_side_effects: bool + current_speculation: Optional[SpeculationEntry] + + def mark_inconsistent_side_effects(self): + """ + InstructionTranslator has encountered instructions which may cause + dynamo to see a different version of history from eager + See: https://github.com/pytorch/pytorch/issues/110765 + """ + self.inconsistent_side_effects = True + + def has_backedge(self): + cur_offset = self.current_instruction.offset + assert self.instruction_pointer is not None + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in JUMP_OPNAMES: + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + + def cell_and_freevars(self): + if not hasattr(self, "_cell_and_freevars"): + self._cell_and_freevars = tuple( + self.code_options["co_cellvars"] or [] + ) + tuple(self.code_options["co_freevars"] or []) + return self._cell_and_freevars + + def prune_dead_locals(self): + reads = livevars_analysis(self.instructions, self.current_instruction) + # implicit use by super() + # reads = reads | {"__class__"} + # output variables? + reads = reads | set(self.cell_and_freevars()) + self.symbolic_locals = { + k: v for k, v in self.symbolic_locals.items() if k in reads + } + self.output.side_effects.prune_dead_object_new(self) + + def call_function( + self, + fn: VariableTracker, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ): + assert isinstance(fn, VariableTracker) + assert isinstance(args, list) + assert isinstance(kwargs, dict) + assert all( + isinstance(x, VariableTracker) + for x in itertools.chain(args, kwargs.values()) + ) + inner_fn = None + if hasattr(fn, "value"): + inner_fn = fn.value + if hasattr(fn, "fn"): + inner_fn = fn.fn + if inner_fn and callable(inner_fn) and is_forbidden(inner_fn): + raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") + self.push(fn.call_function(self, args, kwargs)) + + def inline_user_function_return(self, fn, args, kwargs): + """ + A call to some user defined function by inlining it. 
+ """ + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + + def get_line_of_code_header(self, lineno=None): + if lineno is None: + lineno = self.lineno + inline_depth_str = ( + f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else "" + ) + funcname = get_funcname(self.f_code.co_filename, lineno) + funcname_str = "" if funcname is None else f" ({funcname})" + return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}" + + def get_log_starts_line_log_str(self): + log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n" + line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip() + log_str += f" {line}" + return log_str + + def log_starts_line(self): + trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str)) + + def step(self): + """Process exactly one instruction, return False we should exit""" + assert isinstance(self.instruction_pointer, int) + inst = self.instructions[self.instruction_pointer] + self.current_instruction = inst + self.instruction_pointer += 1 + if self.instruction_pointer < len(self.instructions): + self.next_instruction = self.instructions[self.instruction_pointer] + else: + self.instruction_pointer = None + self.next_instruction = None + if inst.starts_line and self.lineno != inst.starts_line: + self.lineno = inst.starts_line + self.log_starts_line() + + if ( + len(self.stack) == 0 + and self.should_compile_partial_graph() + and self.is_non_empty_graph() + ): + self.current_speculation = self.speculate() + if self.current_speculation.failed: + return self.step_graph_break(inst) + + log.debug("TRACE %s %s %s", inst.opname, inst.argval, self.stack) + + # 3.11 no longer uses a block stack, but we still keep track of one + # so that we know which contexts are currently active. + # For our purposes, all exception table entries with the same target + # are considered to be part of the same "block". + if sys.version_info >= (3, 11): + entry = inst.exn_tab_entry + if not ( + # still in the same block + self.block_stack + and entry + and self.block_stack[-1].target is entry.target + ): + if not entry: + # no longer in any block + # It is possible for NOPs to be between two instructions + # in the same block, but the NOPs are not covered by an + # exception table entry. In this case, assume that we + # are still in the same block. + if self.block_stack and inst.opname != "NOP": + # If we really escape from a block and the current + # instruction is not in another block, then there + # should be no other nested blocks that we are in. + assert len(self.block_stack) == 1 + self.block_stack.pop() + elif ( + # current instruction is in the previous block + len(self.block_stack) > 1 + and self.block_stack[-2].target is entry.target + ): + # exit the current block + self.block_stack.pop() + else: + # current instruction is in a new block + # push block to stack - note, BEFORE_WITH blocks won't + # be pushed here since BEFORE_WITH pushes the block, and + # the current instruction would be counted as being in that block. 
+ self.block_stack.append( + BlockStackEntry(entry.target, len(self.stack)) + ) + + try: + if not hasattr(self, inst.opname): + unimplemented(f"missing: {inst.opname}") + TracingContext.set_current_loc( + self.f_code.co_filename, self.lineno, self.f_code.co_name + ) + getattr(self, inst.opname)(inst) + + return inst.opname != "RETURN_VALUE" + except Unsupported: + if self.current_speculation is None: + log.debug("empty checkpoint") + raise + log.debug("step triggered compile", exc_info=True) + + self.current_speculation.fail_and_restart_analysis() + + def step_graph_break(self, continue_inst): + # generate code from checkpoint + assert not self.output.output_instructions + assert self.current_speculation is not None + self.output.compile_subgraph( + self, + partial_convert=True, + reason=GraphCompileReason("step_unsupported", [self.frame_summary()]), + ) + self.output.add_output_instructions( + [create_jump_absolute(continue_inst)] + self.instructions + ) + + def run_ctx_mgr(self): + # NB: Don't push the top level frame summary; set_current_loc will + # take care of it. However, DO make sure we attach real_stack to + # exceptions + return TracingContext.current_frame(None) + + def run(self): + with self.run_ctx_mgr(): + try: + self.output.push_tx(self) + while ( + self.instruction_pointer is not None + and not self.output.should_exit + and self.step() + ): + pass + except BackendCompilerFailed: + raise + except Exception as e: + if config.replay_record_enabled: + e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] + raise + finally: + self.output.pop_tx() + # Cleanup the outputGraph to delete the held tensors. We perform the + # cleanup only for InstructionTranslator and not + # InliningInstructionTranslator. The InliningInstructionTranslator + # mutates the output object and is restored to original state if + # there was an exception. 
+ if isinstance(self, InstructionTranslator): + self.output.cleanup() + + def push(self, val: Optional[VariableTracker]): + assert val is None or isinstance( + val, VariableTracker + ), f"push expects VariableTracker, got {typestr(val)}" + self.stack.append(val) # type: ignore[arg-type] + + def push_many(self, vals: List[VariableTracker]): + for val in vals: + self.push(val) + + def pop(self) -> VariableTracker: + return self.stack.pop() + + def popn(self, n: int) -> List[VariableTracker]: + assert n >= 0 + return list(reversed([self.pop() for _ in range(n)])) + + def LOAD_FAST(self, inst): + name = inst.argval + if name in self.f_locals and config.replay_record_enabled: + self.exec_recorder.add_local_var(name, self.f_locals[name]) + + if name.startswith(".") and name not in self.symbolic_locals: + # This happens in dict/list comprehensions + name = name.replace(".", "implicit") + assert name not in self.cell_and_freevars() + if name not in self.symbolic_locals: + unimplemented("undefined LOAD_FAST") + self.push(self.symbolic_locals[name]) + if name.startswith("___stack"): + self.symbolic_locals.pop(name) + + def LOAD_DEREF(self, inst): + assert inst.argval in self.cell_and_freevars() + + if inst.argval in self.f_locals and config.replay_record_enabled: + self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval]) + + if inst.argval not in self.symbolic_locals: + unimplemented(f"undefined LOAD_DEREF {inst.argval}") + self.push(self.symbolic_locals[inst.argval]) + + def STORE_FAST(self, inst): + loaded_vt = self.pop() + name = inst.argval + # Only rename at the top-level scope, this is to avoid the confusion between + # mutating a variable vs renaming it (e.g. a = b) during speculating a higher order op, + # where mutation is prohibited and it's difficult to differentiate it with renaming. 
+ if _is_top_level_scope(current_scope_id()): + loaded_vt = loaded_vt.rename(self, name) + self.symbolic_locals[name] = loaded_vt + + def DELETE_FAST(self, inst): + del self.symbolic_locals[inst.argval] + + STORE_DEREF = STORE_FAST + + def LOAD_CLOSURE(self, inst): + self.push(ClosureVariable(name=inst.argval)) + + def LOAD_CONST(self, inst): + # For empty tuples, create empty TupleVariable + if isinstance(inst.argval, tuple) and not inst.argval: + self.push(TupleVariable([])) + else: + self.push(ConstantVariable.create(value=inst.argval)) + + def get_global_source(self, name): + source: Source + if self.output.global_scope is self.f_globals: + source = GlobalSource(name) + else: + if "__name__" in self.f_globals: + source = AttrSource( + self.import_source(self.f_globals["__name__"]), name + ) + else: + mangled_name = self.output.install_global_by_id( + "___unnamed_scope", self.f_globals + ) + source = GetItemSource(GlobalSource(mangled_name), name) + return source + + def LOAD_GLOBAL(self, inst): + if sys.version_info >= (3, 11): + if inst.arg % 2: + self.PUSH_NULL(inst) + + name = inst.argval + + if config.replay_record_enabled: + if name in self.f_globals: + self.exec_recorder.add_global_var(name, self.f_globals[name]) + else: + assert name in self.f_builtins + self.exec_recorder.builtins[name] = self.f_builtins[name] + + if inst.argval == "AssertionError": + unimplemented("assert with non-string message") + + if name in self.symbolic_globals: + variable = self.output.side_effects[self.symbolic_globals[name]] + self.push(self.output.side_effects.load_global(variable, name)) + return + + try: + value = self.f_globals[name] + except KeyError: + return self.load_builtin(inst) + + source = self.get_global_source(name) + self.push(VariableBuilder(self, source)(value)) + + def STORE_GLOBAL(self, inst): + value = self.pop() + name = inst.argval + source = self.get_global_source(name) + if name not in self.symbolic_globals: + self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object + variable = self.output.side_effects.track_global_existing( + source, self.symbolic_globals[name] + ) + if isinstance(value, RemovableHandleVariable): + unimplemented("Storing handles in globals - NYI") + self.output.side_effects.store_global(variable, name, value) + + def import_source(self, module_name): + """Create an alias to a module for use in guards""" + if "torch_package" in module_name: + value = torch.package.package_importer._package_imported_modules[ + module_name + ] + alias = ( + module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + ) + else: + value = importlib.import_module(module_name) + alias = f"__import_{module_name.replace('.', '_dot_')}" + f_globals = self.output.global_scope + assert alias not in f_globals or f_globals[alias] is value + f_globals[alias] = value + self.output.update_co_names(alias) + return GlobalSource(alias) + + def resolve_name(self, name, package, level): + """ + Copied from the Cpython implementation of __import__ + Resolve a relative module name to an absolute one. 
+ https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902 + """ + bits = package.rsplit(".", level - 1) + if len(bits) < level: + raise ImportError("attempted relative import beyond top-level package") + base = bits[0] + return f"{base}.{name}" if name else base + + def calc_package(self): + """ + Copied from the Cpython implementation of __import__ + https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090 + """ + package = self.f_globals.get("__package__") + spec = self.f_globals.get("__spec__") + if package is not None: + if spec is not None and package != spec.parent: + log.warning( + "__package__ != __spec__.parent (%r != %r)", + package, + spec.parent, + stacklevel=3, + ) + return package + elif spec is not None: + return spec.parent + else: + log.warning( + "can't resolve package from __spec__ or __package__, " + "falling back on __name__ and __path__", + stacklevel=3, + ) + package = self.f_globals["__name__"] + if "__path__" not in self.f_globals: + package = package.rpartition(".")[0] + return package + + def IMPORT_NAME(self, inst): + level, fromlist = self.popn(2) + level = level.as_python_constant() + fromlist = fromlist.as_python_constant() + module_name = inst.argval + + # Are we replaying? if so, load recorded module + recorded_name = ( + f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}" + ) + if recorded_name in self.f_globals: + value = self.f_globals[recorded_name] + source = GlobalSource(recorded_name) + else: + value = __import__( + module_name, + fromlist=fromlist, + level=level, + globals=self.f_globals, + ) + + if level != 0: + pkg = self.calc_package() + module_name = self.resolve_name(module_name, pkg, level) + + # For __import__, when the name variable is of the form package.module, + # normally, the top-level package (the name up till the first dot) is + # returned, not the module named by module_name. However, when a + # non-empty fromlist argument is given, the module named by name is + # returned. Therefore, we set the source correctly here. 
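+            #
+            # Editorial illustration (not part of the upstream file): this is
+            # the documented behaviour of the builtin __import__, e.g.:
+            #
+            #     >>> __import__("json.decoder").__name__
+            #     'json'
+            #     >>> __import__("json.decoder", fromlist=["JSONDecoder"]).__name__
+            #     'json.decoder'
+            #
+            # hence the two different import_source() targets below.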
+ if not fromlist: + top_level_module_name = module_name.partition(".")[0] + source = self.import_source(top_level_module_name) + else: + source = self.import_source(module_name) + + if config.replay_record_enabled: + self.exec_recorder.add_local_mod(recorded_name, value) + + if istype(value, (types.ModuleType, DummyModule)): + self.push(PythonModuleVariable(value, source=source)) + else: + unimplemented(f"IMPORT_NAME {typestr(value)}") + + def IMPORT_FROM(self, inst): + self.DUP_TOP(inst) + self.LOAD_ATTR(inst) + + def load_builtin(self, inst): + if inst.argval not in self.f_builtins: + raise NameError(f"name '{inst.argval}' is not defined") + val = self.f_builtins[inst.argval] + + if callable(val): + self.push(VariableBuilder(self, GlobalSource(inst.argval))(val)) + else: + assert is_builtin_constant(val) + self.push(ConstantVariable.create(value=val)) + + def jump(self, inst): + self.instruction_pointer = self.indexof[inst.target] + + JUMP_FORWARD = jump + JUMP_ABSOLUTE = jump + + POP_JUMP_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_IF_TRUE = generic_jump(operator.truth, False) + JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True) + JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True) + + def SETUP_LOOP(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst.target)) + + def SETUP_EXCEPT(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst.target)) + + def POP_BLOCK(self, inst): + self.block_stack.pop() + + def SETUP_WITH(self, inst): + self.setup_or_before_with(inst) + + def SETUP_FINALLY(self, inst): + self.block_stack.append(BlockStackEntry(inst.target)) + + def BEGIN_FINALLY(self, inst): + self.push(None) + + def WITH_CLEANUP_START(self, inst): + exit, exc = self.popn(2) + assert exc is None + self.push(exc) + self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) + + def WITH_CLEANUP_FINISH(self, inst): + self.popn(2) + self.push(None) + + def CALL_FINALLY(self, inst): + """ + pushes the address of the next instruction onto the stack and increments + bytecode counter by delta + """ + # Python 3.8 only + assert self.next_instruction is not None + addr = self.indexof[self.next_instruction] + self.push(ConstantVariable.create(addr)) + self.instruction_pointer = self.indexof[inst.target] + + def END_FINALLY(self, inst): + # Python 3.8 only + # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY + tos = self.pop() + if isinstance(tos, ConstantVariable): + self.instruction_pointer = tos.as_python_constant() + else: + pass + + def POP_FINALLY(self, inst): + # Python 3.8 only + preserve_tos = inst.argval + if preserve_tos: + tos = self.pop() + _ = self.pop() + if preserve_tos: + self.push(tos) # type: ignore[possibly-undefined] + + def FOR_ITER(self, inst): + it = self.pop().realize() + if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)): + try: + val, next_iter = it.next_variables(self) + self.push(next_iter) + self.push(val) + except StopIteration: + self.jump(inst) + else: + unimplemented(f"FOR_ITER {typestr(it)}") + + def COMPARE_OP(self, inst): + left, right = self.popn(2) + op = inst.argval + supported_any = dict( + itertools.chain( + supported_tensor_comparison_ops.items(), + supported_const_comparison_ops.items(), + ) + ) + if ( + isinstance( + left, + ( + TensorVariable, + SymNodeVariable, + NNModuleVariable, + BaseListVariable, + UserDefinedVariable, + BaseUserFunctionVariable, + ConstDictVariable, + ), + ) + and 
isinstance(right, ConstantVariable) + and right.value is None + and op in supported_const_comparison_ops + ): + # is None + self.push( + ConstantVariable.create( + supported_const_comparison_ops[op](object(), right.value) + ) + ) + + elif ( + left.is_python_constant() + and right.is_python_constant() + and op in supported_any + ): + # constant fold + self.push( + ConstantVariable.create( + supported_any[op]( + left.as_python_constant(), right.as_python_constant() + ), + ) + ) + elif op in ("in", "not in"): + self.push(right.call_method(self, "__contains__", [left], {})) + if op == "not in": + self.UNARY_NOT(inst) + else: + self.push( + BuiltinVariable(supported_any[op]).call_function( + self, [left, right], {} + ) + ) + + def GET_ITER(self, inst): + self.call_function(BuiltinVariable(iter), [self.pop()], {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION(self, inst): + args = self.popn(inst.argval) + fn = self.pop() + self.call_function(fn, args, {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_EX(self, inst): + kwargsvars: VariableTracker + if inst.argval == 0: + kwargsvars = ConstDictVariable({}) + argsvars = self.pop() + elif inst.argval == 1: + kwargsvars = self.pop() + argsvars = self.pop() + else: + unimplemented("CALL_FUNCTION_EX") + fn = self.pop() + if sys.version_info >= (3, 11): + null = self.pop() + assert isinstance(null, NullVariable) + + if ( + isinstance(fn, GetAttrVariable) + and isinstance(fn.obj, TensorVariable) + and fn.name == "view" + and isinstance(argsvars, (ConstantVariable, TensorVariable)) + ): + # Hack to handle special case in some bert models. Converts + # x.view(*shape) into x.view(shape), which is correct for view() + # but not generally. See test_transpose_for_scores(). + argsvars = TupleVariable([argsvars]) + + if not isinstance( + argsvars, BaseListVariable + ) and argsvars.has_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) + + if not isinstance(argsvars, BaseListVariable) or not isinstance( + kwargsvars, ConstDictVariable + ): + unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}") + + # Map to a dictionary of str -> VariableTracker + kwargsvars = kwargsvars.keys_as_python_constant() + self.call_function(fn, argsvars.items, kwargsvars) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_KW(self, inst): + argnames = self.pop() + args = self.popn(inst.argval) + fn = self.pop() + assert isinstance(argnames, TupleVariable) and argnames.is_python_constant() + argnames = argnames.as_python_constant() + args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :] + kwargs = dict(zip(argnames, kwargs_list)) + assert len(kwargs) == len(argnames) + self.call_function(fn, args, kwargs) + + def LOAD_METHOD_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = self.code_options["co_names"][arg] + if sys.version_info < (3, 11): + self.LOAD_ATTR(dataclasses.replace(inst, argval=argval)) + else: + self.LOAD_METHOD(dataclasses.replace(inst, argval=argval)) + + def LOAD_ATTR_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = self.code_options["co_names"][arg] + self.LOAD_ATTR(dataclasses.replace(inst, argval=argval)) + + def LOAD_METHOD(self, inst): + self.LOAD_ATTR(inst) + obj = self.pop() + if sys.version_info >= (3, 11): + # always follow the NULL + fn convention, since if obj + # is actually a method, self is already bound to it, so it + # doesn't 
need to be passed in as an arg. + self.PUSH_NULL(inst) + self.push(obj) + else: + self.push(obj) + self.push(None) + + def CALL_METHOD(self, inst): + args = self.popn(inst.argval) + dummy = self.pop() + assert dummy is None + fn = self.pop() + self.call_function(fn, args, {}) + + def LOAD_ATTR(self, inst): + obj = self.pop() + result = BuiltinVariable(getattr).call_function( + self, [obj, ConstantVariable.create(inst.argval)], {} + ) + self.push(result) + + def STORE_ATTR(self, inst): + speculation = self.speculate() + if speculation.failed: + return self.store_attr_graph_break(inst) + val, obj = self.popn(2) + + if isinstance(obj, NNModuleVariable): + # We don't allow side effects during export + # https://github.com/pytorch/torchdynamo/issues/1475 + assert ( + not self.export + ), f"Mutating module attribute {inst.argval} during export." + + try: + BuiltinVariable(setattr).call_function( + self, [obj, ConstantVariable.create(inst.argval), val], {} + ) + return + except Unsupported as e: + if not self.should_compile_partial_graph(): + raise + log.debug("STORE_ATTR triggered compile", exc_info=True) + e.remove_from_stats() + e.add_to_stats("graph_break") + speculation.fail_and_restart_analysis() + + def store_attr_graph_break(self, inst): + self.output.compile_subgraph( + self, reason=GraphCompileReason("store_attr", [self.frame_summary()]) + ) + self.output.add_output_instructions([copy.copy(inst)]) + self.popn(2) + self.output.add_output_instructions( + self.create_call_resume_at(self.next_instruction) + ) + + def DELETE_ATTR(self, inst): + obj = self.pop() + BuiltinVariable(delattr).call_function( + self, [obj, ConstantVariable.create(inst.argval)], {} + ) + + def create_call_resume_at(self, offset): + raise AssertionError( + f"create_call_resume_at not overridden by subclass {type(self)}" + ) + + def should_compile_partial_graph(self) -> bool: + raise AssertionError( + f"should_compile_partial_graph not overridden by subclass {type(self)}" + ) + + @break_graph_if_unsupported(push=0) + def STORE_SUBSCR(self, inst): + val, obj, key = self.popn(3) + result = obj.call_method(self, "__setitem__", [key, val], {}) + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + self.push(TupleVariable(items)) + + def BUILD_SLICE(self, inst): + items = self.popn(inst.argval) + self.push(SliceVariable(items)) + + def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(ListVariable(items, mutable_local=MutableLocal())) + + def BUILD_SET(self, inst): + if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: + unimplemented("missing: BUILD_SET") + items = self.popn(inst.argval) + new_set = SetVariable(items, mutable_local=MutableLocal()) + self.push(new_set) + + def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): + seqs = self.popn(inst.argval) + items = list() + for seq in seqs: + try: + items.extend(seq.unpack_var_sequence(self)) + except NotImplementedError: + unimplemented(f"BUILD_LIST_UNPACK {seq}") + self.push(cls(items, mutable_local=MutableLocal())) + + def BUILD_TUPLE_UNPACK(self, inst): + self.BUILD_LIST_UNPACK(inst, cls=TupleVariable) + + BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK + + def BUILD_MAP(self, inst): + items = self.popn(inst.argval * 2) + d = dict(zip(items[::2], items[1::2])) + self.push(ConstDictVariable(d, mutable_local=MutableLocal())) + + def BUILD_MAP_UNPACK(self, inst): + items = self.popn(inst.argval) + # ensure everything is a dict + items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] + result = dict() + for x in 
items: + assert isinstance(x, ConstDictVariable) + result.update(x.items) + self.push( + ConstDictVariable( + result, + mutable_local=MutableLocal(), + ) + ) + + BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK + + def BUILD_CONST_KEY_MAP(self, inst): + keys = self.pop() + values = self.popn(inst.argval) + assert isinstance(keys, TupleVariable) + assert keys.is_python_constant() + + keys = keys.unpack_var_sequence(self) + assert len(keys) == len(values) + + self.push( + ConstDictVariable( + dict(zip(keys, values)), + mutable_local=MutableLocal(), + ) + ) + + def MAP_ADD(self, inst): + k, v = self.popn(2) + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type] + + def SET_ADD(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.mutable_local + return obj.call_method(self, "add", [v], {}) + + def LIST_APPEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ListVariable) + assert obj.mutable_local + self.output.side_effects.mutation(obj) + obj.items.append(v) + + def MAKE_FUNCTION(self, inst): + flags = inst.arg + old_stack = list(self.stack) + if sys.version_info < (3, 11): + fn_name = self.pop() + code = self.pop() + if sys.version_info >= (3, 11): + # MAKE_FUNCTION behavior actually changed in 3.11, see + # https://github.com/python/cpython/pull/93189/ + assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined] + fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined] + defaults = None + closure = None + annotations = None + kwdefaults = None + + if flags & 0x08: + closure = self.pop() + if flags & 0x04: + annotations = self.pop() + if flags & 0x02: + kwdefaults = self.pop() + if flags & 0x01: + defaults = self.pop() + + self.push( + NestedUserFunctionVariable( + fn_name, + code, + self.f_globals, + defaults, + kwdefaults, + annotations, + closure, + closure_scope=self, + ) + ) + + def UNPACK_SEQUENCE(self, inst): + seq = self.pop() + if isinstance(seq, TensorVariable): + val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) + elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): + # x, y = a.shape + proxy = getattr(seq.obj.as_proxy(), seq.name) + val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] + elif seq.has_unpack_var_sequence(self): + val = seq.unpack_var_sequence(self) + else: + unimplemented(f"UNPACK_SEQUENCE {seq}") + if len(val) != inst.argval: + unimplemented("UNPACK_SEQUENCE length mismatch") + for i in reversed(val): + self.push(i) + + def UNPACK_EX(self, inst): + assert 0 <= inst.argval <= 0xFFFF + prefix = inst.argval & 0xFF # low byte + suffix = inst.argval >> 8 # high byte + seq = self.pop() + if seq.has_unpack_var_sequence(self): + vals = list(seq.unpack_var_sequence(self)) + assert len(vals) >= prefix + suffix + vals_prefix = vals[:prefix] + vals_list = vals[prefix : len(vals) - suffix] + vals_suffix = vals[len(vals) - suffix :] + for item in reversed(vals_suffix): + self.push(item) + self.push(TupleVariable(vals_list)) + for item in reversed(vals_prefix): + self.push(item) + else: + unimplemented(f"UNPACK_EX {seq}") + + def NOP(self, inst): + pass + + def POP_TOP(self, inst): + self.pop() + + def ROT_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(a) + self.push(b) + + def 
ROT_THREE(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + self.push(a) + self.push(c) + self.push(b) + + def ROT_FOUR(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + d = self.pop() + self.push(a) + self.push(d) + self.push(c) + self.push(b) + + def DUP_TOP(self, inst): + a = self.pop() + self.push(a) + self.push(a) + + def DUP_TOP_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(b) + self.push(a) + self.push(b) + self.push(a) + + def FORMAT_VALUE(self, inst): + flags = inst.arg + if (flags & 0x04) == 0x04: + fmt_spec = self.pop() + else: + fmt_spec = ConstantVariable.create("") + + value = self.pop() + if isinstance(value, SymNodeVariable): + value = ConstantVariable.create(str(value.sym_num)) + if (flags & 0x03) == 0x01: + value = BuiltinVariable(str).call_function(self, [value], {}) + elif (flags & 0x03) == 0x02: + value = BuiltinVariable(repr).call_function(self, [value], {}) + elif (flags & 0x03) == 0x03: + value = BuiltinVariable(ascii).call_function(self, [value], {}) + + fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") + + self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) + + def BUILD_STRING(self, inst): + format_string_parts: List[str] = [] + args: List[VariableTracker] = [] + kwargs: Dict[str, VariableTracker] = {} + for part in self.popn(inst.arg): + if isinstance(part, ConstantVariable): + format_string_parts.append("{}") + args.append(part) + elif isinstance(part, variables.StringFormatVariable): + format_string_parts.append(part.format_string) + args.extend(part.sym_args) + if set(kwargs.keys()) & set(part.sym_kwargs.keys()): + unimplemented( + f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}" + ) + kwargs.update(part.sym_kwargs) + else: + unimplemented(f"BUILD_STRING {part}") + self.push( + variables.StringFormatVariable.create( + "".join(format_string_parts), args, kwargs + ) + ) + + def IS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + if inst.argval == 0: + new_argval = "is" + else: + new_argval = "is not" + new_inst = create_instruction("COMPARE_OP", argval=new_argval) + self.COMPARE_OP(new_inst) + + def CONTAINS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + left, right = self.popn(2) + op = inst.argval + self.push(right.call_method(self, "__contains__", [left], {})) + if op == 1: + self.UNARY_NOT(inst) + + def LIST_EXTEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, ListVariable) + assert obj.mutable_local + obj.call_method(self, "extend", [v], {}) + + def LIST_TO_TUPLE(self, inst): + self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) + + def DICT_MERGE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + assert obj.mutable_local + obj.call_method(self, "update", [v], {}) + + DICT_UPDATE = DICT_MERGE + + def GEN_START(self, inst): + self.pop() + + def GET_LEN(self, inst): + tos = self.stack[-1] + if tos.is_python_constant(): + self.push(ConstantVariable.create(len(tos.as_python_constant()))) + else: + self.push(tos.call_method(self, "__len__", [], {})) + + def MATCH_MAPPING(self, inst): + tos = self.stack[-1] + assert isinstance(tos, ConstDictVariable) + if isinstance(tos.items, collections.abc.Mapping): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_SEQUENCE(self, inst): + tos = self.stack[-1] + 
assert tos.is_python_constant() + tos_value = tos.as_python_constant() + if isinstance(tos_value, collections.abc.Sequence) and not isinstance( + tos_value, (str, bytes, bytearray) + ): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_KEYS(self, inst): + tos = self.stack[-1] + tos1 = self.stack[-2] + assert isinstance(tos1, ConstDictVariable) + + if all(k in tos1 for k in tos): # type: ignore[attr-defined] + self.push(TupleVariable([tos1.getitem_const(k) for k in tos])) # type: ignore[attr-defined] + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(None)) + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(False)) + + def LOAD_ASSERTION_ERROR(self, inst): + unimplemented("assert with non-string message") + + UNARY_POSITIVE = stack_op(operator.pos) + UNARY_NEGATIVE = stack_op(operator.neg) + UNARY_NOT = stack_op(operator.not_) + UNARY_INVERT = stack_op(operator.invert) + + BINARY_POWER = stack_op(operator.pow) + BINARY_MULTIPLY = stack_op(operator.mul) + BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul) + BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv) + BINARY_TRUE_DIVIDE = stack_op(operator.truediv) + BINARY_MODULO = stack_op(operator.mod) + BINARY_REMAINDER = stack_op(operator.mod) + BINARY_ADD = stack_op(operator.add) + BINARY_SUBTRACT = stack_op(operator.sub) + BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem)) + BINARY_LSHIFT = stack_op(operator.lshift) + BINARY_RSHIFT = stack_op(operator.rshift) + BINARY_AND = stack_op(operator.and_) + BINARY_OR = stack_op(operator.or_) + BINARY_XOR = stack_op(operator.xor) + + INPLACE_POWER = stack_op(operator.ipow) + INPLACE_MULTIPLY = stack_op(operator.imul) + INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul) + INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv) + INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv) + INPLACE_MODULO = stack_op(operator.imod) + INPLACE_REMAINDER = stack_op(operator.imod) + INPLACE_ADD = stack_op(operator.iadd) + INPLACE_SUBTRACT = stack_op(operator.isub) + INPLACE_LSHIFT = stack_op(operator.ilshift) + INPLACE_RSHIFT = stack_op(operator.irshift) + INPLACE_AND = stack_op(operator.iand) + INPLACE_XOR = stack_op(operator.ixor) + INPLACE_OR = stack_op(operator.ior) + + # 3.11 opcodes + def RESUME(self, inst): + if inst.arg == 0: + self.append_prefix_inst(inst) + self.accept_prefix_inst = False + else: + assert not self.accept_prefix_inst + + def BINARY_OP(self, inst): + if sys.version_info >= (3, 11): + opname = dis._nb_ops[inst.arg][0][3:] # type: ignore[attr-defined] + if opname.startswith("INPLACE"): + return getattr(self, "INPLACE_" + opname[8:])(inst) + return getattr(self, "BINARY_" + opname)(inst) + else: + unimplemented("BINARY_OP requires Python 3.11+") + + def PRECALL(self, inst): + pass + + def KW_NAMES(self, inst): + kw_names = self.code_options["co_consts"][inst.arg] + assert isinstance(kw_names, tuple) + for name in kw_names: + assert isinstance(name, str) + assert self.kw_names is None + self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment] + + def PUSH_NULL(self, inst): + self.push(NullVariable()) + + @break_graph_if_unsupported(push=1) + def CALL(self, inst): + # see https://docs.python.org/3.11/library/dis.html#opcode-CALL + # for convention + contents = self.popn(inst.arg + 2) + if isinstance(contents[0], NullVariable): + fn = contents[1] + args = [] + else: + fn = contents[0] + args = [contents[1]] 
+ kw_names = self.kw_names.value if self.kw_names else () + if kw_names: + args = args + contents[2 : -len(kw_names)] + kwargs_list = contents[-len(kw_names) :] + kwargs = dict(zip(kw_names, kwargs_list)) + assert len(kwargs) == len(kw_names) + else: + args = args + contents[2:] + kwargs = {} + self.call_function(fn, args, kwargs) + self.kw_names = None + + def COPY(self, inst): + self.push(self.stack[-inst.arg]) + + def SWAP(self, inst): + self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] + + JUMP_BACKWARD = jump + JUMP_BACKWARD_NO_INTERRUPT = jump + + POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) + + def CACHE(self, inst): + pass + + def BEFORE_WITH(self, inst): + self.setup_or_before_with(inst) + + def setup_or_before_with(self, inst): + ctx = self.pop() + if not isinstance(ctx, ContextWrappingVariable): + unimplemented(f"{inst.opname} {ctx}") + + if isinstance(ctx, GenericContextWrappingVariable): + self.generic_context_manager_depth += 1 + + exit = WithExitFunctionVariable( + ctx, + inst.target, + ) + if sys.version_info >= (3, 11): + # see create_call_resume_at for block stack details + assert self.next_instruction + assert self.next_instruction.exn_tab_entry + target = self.next_instruction.exn_tab_entry.target + else: + target = inst.target + if isinstance(self, InstructionTranslator): + self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx)) + else: + self.block_stack.append(BlockStackEntry(target)) + + self.push(exit) + self.push(ctx.enter(self)) + + def append_prefix_inst(self, inst): + assert self.accept_prefix_inst + self.prefix_insts.append(inst) + + def MAKE_CELL(self, inst): + self.append_prefix_inst(inst) + + def COPY_FREE_VARS(self, inst): + self.append_prefix_inst(inst) + + def RETURN_GENERATOR(self, inst): + self.append_prefix_inst(inst) + + def copy_graphstate(self) -> InstructionTranslatorGraphState: + """Create a checkpoint of the current state by copying everything""" + return InstructionTranslatorGraphState( + self.output.copy_graphstate(), + dict(self.symbolic_locals), + list(self.stack), + list(self.block_stack), + self.instruction_pointer, + self.current_instruction, + self.next_instruction, + self.lineno, + ) + + def restore_graphstate(self, state: InstructionTranslatorGraphState): + """Restore a checkpoint created by self.copy_graphstate()""" + ( + output_state, + self.symbolic_locals, + self.stack, + self.block_stack, + self.instruction_pointer, + self.current_instruction, + self.next_instruction, + self.lineno, + ) = state + self.output.restore_graphstate(output_state) + + def is_non_empty_graph(self): + if self.output.count_calls() > 1: + # perf optimization only + self.is_non_empty_graph = lambda: True # type: ignore[method-assign] + return True + return False + + def format_frame_summary(self, additional_stack_frames=None): + if additional_stack_frames is None: + additional_stack_frames = [] + return "".join( + traceback.format_list( + [self.frame_summary()] + list(reversed(additional_stack_frames)) + ) + ) + + def frame_summary(self): + return traceback.FrameSummary( + getattr(self.f_code, "co_filename", ""), + self.lineno, + getattr(self.f_code, "co_name", ""), + lookup_line=False, + ) + + def store_global_weakref_by_id(self, prefix, value): + global_name = self.output.install_global_by_id(prefix, 
weakref.ref(value)) + install_guard( + GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE) + ) + return global_name + + @property + def fake_mode(self): + return self.output.tracing_context.fake_mode + + def find_symbolic_locals_name(self, tensor_variable): + for key, value in self.symbolic_locals.items(): + if value is tensor_variable: + return key + return None + + @contextlib.contextmanager + def strict_translation_mode(self): + self.strict_checks_enabled = True + try: + yield + finally: + self.strict_checks_enabled = False + + def speculate(self) -> SpeculationEntry: + return self.speculation_log.next( + self.f_code.co_filename, self.lineno, self.instruction_pointer + ) + + def __init__( + self, + output: OutputGraph, + instructions: List[Instruction], + f_locals: Dict[str, Any], + f_globals: Dict[str, Any], + f_builtins: Dict[str, Any], + code_options: Dict[str, Any], + symbolic_locals: Dict[str, VariableTracker], + symbolic_globals: Dict[str, VariableTracker], + f_code: types.CodeType, + export: bool, + inline_depth: int, + speculation_log: SpeculationLog, + ): + super().__init__() + self.speculation_log = speculation_log + + # Mutable state checkpointed by copy_graphstate() + self.output = output + self.symbolic_locals = symbolic_locals + self.symbolic_globals = symbolic_globals + self.stack = [] + self.instruction_pointer = 0 + self.current_instruction = create_instruction("NOP") + self.next_instruction = None + self.block_stack = [] + # states before SETUP_WITH for checkpointing and fallback + self.generic_context_manager_depth = 0 + self.lineno = code_options["co_firstlineno"] + self.kw_names = None + self.accept_prefix_inst = True + self.prefix_insts = [] + + # Properties of the input/output code + self.instructions: List[Instruction] = instructions + self.indexof: Dict[Instruction, int] = get_indexof(self.instructions) + self.f_locals: Dict[ + str, Any + ] = f_locals # needed for recording accessed locals for replay + self.f_globals: Dict[str, Any] = f_globals + self.f_builtins: Dict[str, Any] = f_builtins + self.code_options: Dict[str, Any] = code_options + self.f_code: types.CodeType = f_code + + # Execution record for replaying errors + self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options) + # Stack of module being parsed, current nn.module is at the end of ordered dict. + # The first field of tuple is the fully qualified name of current module + # in original hierarchy. The second field is the type of current nn.module + self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} + # Flag to indicate whether tracing is used for export. 
+ self.export = export + + self.current_speculation = None + + self.strict_checks_enabled = False + + if sys.version_info >= (3, 10): + from .resume_execution import ( + CO_ASYNC_GENERATOR, + CO_COROUTINE, + CO_GENERATOR, + CO_ITERABLE_COROUTINE, + ) + + if f_code.co_flags & ( + CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR + ): + self.push(BuiltinVariable(None)) + + self.inline_depth = inline_depth + self.inconsistent_side_effects = False + linecache.lazycache(f_code.co_filename, f_globals) + self.log_starts_line() + + +class InstructionTranslator(InstructionTranslatorBase): + mutated_closure_cell_contents: Set[str] + + @staticmethod + def current_tx() -> "InstructionTranslator": + return tls.current_tx + + @contextlib.contextmanager + def set_current_tx(self): + prior = getattr(tls, "current_tx", None) + tls.current_tx = self + try: + yield + finally: + tls.current_tx = prior + + def __init__( + self, + instructions: List[Instruction], + f_code, + f_locals, + f_globals, + f_builtins, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + mutated_closure_cell_contents: Set[str], + frame_state, + speculation_log: SpeculationLog, + ): + _step_logger()( + logging.INFO, + f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}", + ) + super().__init__( + output=OutputGraph( + code_options, + compiler_fn, + self, + export, + export_constraints, + frame_state, + local_scope=f_locals, + global_scope=f_globals, + f_code=f_code, + ), + instructions=instructions, + f_locals=f_locals, + f_globals=f_globals, + f_builtins=f_builtins, + code_options=code_options, + symbolic_locals={}, # set below + # A global var is inserted only after a STORE_GLOBAL happens to it + symbolic_globals={}, + f_code=f_code, + export=export, + inline_depth=0, + speculation_log=speculation_log, + ) + + self._throw_if_in_functorch() + + # as soon as we create the tracing context we should keep it active, so any calls + # into dynamo apis can rely on finding it + with tracing(self.output.tracing_context), self.set_current_tx(): + self.one_graph: bool = one_graph + self.export = export + self.mutated_closure_cell_contents = mutated_closure_cell_contents + if self.export: + assert ( + self.one_graph + ), "Export without one graph - something has gone wrong." 
+ + vars = list(code_options["co_varnames"]) + cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars] + vars.extend(cells_and_freevars) + cells_and_freevars_set = set(cells_and_freevars) + + self.symbolic_locals = { + k: variables.LazyVariableTracker.create( + f_locals[k], + source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set), + ) + for k in vars + if k in f_locals + } + self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] + if export: + # export gets confused if we never realize unused inputs + # in export mode just eagerly realize everything + self.symbolic_locals = VariableTracker.apply( + lambda x: x.realize(), self.symbolic_locals + ) + + self._freevars_ids = dict() + for name in self.code_options["co_freevars"]: + if name in f_locals: + self._freevars_ids[name] = id(f_locals[name]) + + def _throw_if_in_functorch(self): + # Fallback to eager in case of a graph break inside vmap + eager = torch._dynamo.lookup_backend("eager") + compiler_fn = inspect.getattr_static( + self.output.compiler_fn, "compiler_fn", self.output.compiler_fn + ) + ci = torch._C._functorch.peek_interpreter_stack() + forbidden_keys = ( + torch._C._functorch.TransformType.Vmap, + torch._C._functorch.TransformType.Grad, + ) + if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager: + # if it reaches here, it means Dynamo failed to inline a functorch function + name = ci.key().name.lower() + msg = f"torch.func.{name}(fn) requires the function to be inlined by dynamo" + unimplemented(msg) + + def get_example_value(self, source: Source): + if isinstance(source, LocalSource): + return self.f_locals[source.local_name] + if isinstance(source, GlobalSource): + return self.f_globals[source.global_name] + raise KeyError() + + def run(self): + super().run() + + def match_nested_cell(self, name, cell): + """Match a cell in this method to one in a function we are inlining""" + try: + value = cell.cell_contents + except ValueError: + return None + # TODO(jansel): check the id of the cell rather than the contents + if id(value) != self._freevars_ids.get(name): + return None + return self.symbolic_locals[name] + + def should_compile_partial_graph(self): + return ( + all(b.can_restore() for b in self.block_stack) + and not self.one_graph + and self.generic_context_manager_depth == 0 + ) + + def create_call_resume_at(self, inst): + self.instruction_pointer = None + + if inst.opname == "RETURN_VALUE": + return [create_instruction("RETURN_VALUE")] + + reads = livevars_analysis(self.instructions, inst) + argnames = tuple( + k + for k in self.symbolic_locals.keys() + if k in reads and k not in self.cell_and_freevars() + ) + + cg = PyCodegen(self) + + # Python does not allow null to be an arg to a function, so + # we remove nulls from the stack and restore them in the + # prologue of the resume function + + # sorted list of indices of nulls on the stack + null_idxes: List[int] = [] + if sys.version_info >= (3, 11): + # find indices of NullVariables + for i, var in enumerate(self.stack): + if isinstance(var, NullVariable): + null_idxes.append(i) + # generate bytecode to pop the nulls + null_cnt = 0 + for i, var in enumerate(reversed(self.stack)): + if isinstance(var, NullVariable): + for j in range(2, i + 2 - null_cnt): + cg.append_output(create_instruction("SWAP", arg=j)) + cg.extend_output(cg.pop_null()) + null_cnt += 1 + + # we popped all nulls from the stack at runtime, + # so we should not count NullVariables + stack_len = len(self.stack) - len(null_idxes) 
+ nargs = stack_len + len(argnames) + + name = unique_id(f"__resume_at_{inst.offset}") + + new_code: types.CodeType = ContinueExecutionCache.lookup( + self.f_code, + self.lineno, + inst.offset, + tuple(b.target.offset for b in self.block_stack), + stack_len, + argnames, + tuple(b.resume_fn() for b in self.block_stack), + tuple(null_idxes), + ) + + # Add original GraphModule context to the resume function to handle + # the case of a graph break while tracing a GraphModule + orig_graphmodule_maybe = code_context.get_context(self.f_code).get( + "orig_graphmodule", lambda: None + )() + if orig_graphmodule_maybe is not None: + code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref( + orig_graphmodule_maybe + ) + + if new_code.co_freevars: + cg.make_function_with_closure(name, new_code, True, stack_len) + else: + # This is safe: we pre-generate a unique name + self.output.install_global_unsafe( + name, types.FunctionType(new_code, self.f_globals, name) + ) + cg.extend_output(cg.load_function_name(name, True, stack_len)) + + cg.extend_output([cg.create_load(k) for k in argnames]) + cg.extend_output(create_call_function(nargs, False)) + cg.append_output(create_instruction("RETURN_VALUE")) + return cg.get_instructions() + + def symbolic_locals_contain_module_class(self): + for v in self.symbolic_locals.values(): + if isinstance(v, UserDefinedClassVariable) and issubclass( + v.as_python_constant(), torch.nn.Module + ): + return True + return False + + def RETURN_VALUE(self, inst): + if ( + self.output.count_calls() == 0 + and not self.inconsistent_side_effects + and not self.symbolic_locals_contain_module_class() + and not self.export + ): + raise exc.SkipFrame("because no content in function call") + self.instruction_pointer = None + _step_logger()( + logging.INFO, + f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)", + ) + log.debug("RETURN_VALUE triggered compile") + self.output.compile_subgraph( + self, + reason=GraphCompileReason( + "return_value", [self.frame_summary()], graph_break=False + ), + ) + self.output.add_output_instructions([create_instruction("RETURN_VALUE")]) + + +class InliningInstructionTranslator(InstructionTranslatorBase): + """Trace and inline a called method""" + + symbolic_result: Optional[TensorVariable] + + @classmethod + def inline_call(cls, parent, func, args, kwargs): + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + return cls.inline_call_(parent, func, args, kwargs) + + @staticmethod + def check_inlineable(func): + if func.has_self(): + unimplemented("inline with __self__") + + result = trace_rules.check_verbose(func, is_inlined_call=True) + if result.skipped: + from torch._dynamo.variables.misc import produce_trampoline_autograd_apply + + # _origin marks this as coming from an internal dynamo known function that is safe to + # trace through. 
+ if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [ + produce_trampoline_autograd_apply, + ]: + # Known sound + return trace_rules.SkipResult( + False, "allowlist in dynamo known function" + ) + fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else "" + unimplemented( + f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'" + ) + + if isinstance(func, UserFunctionVariable) and inspect.getattr_static( + func.get_function(), "_torchdynamo_disable", False + ): + unimplemented( + f"call torch._dynamo.disable() wrapped function {func.get_function()}" + ) + else: + return result + + @staticmethod + def inline_call_( + parent, func: VariableTracker, args: List[VariableTracker], kwargs + ): + if isinstance(func, SkipFunctionVariable): + unimplemented("inline with functions in skip files") + assert isinstance( + func, + (UserFunctionVariable, NestedUserFunctionVariable), + ) + result = InliningInstructionTranslator.check_inlineable(func) + assert result.skipped is False + try: + sub_locals, closure_cells = func.bind_args(parent, args, kwargs) + except TypeError as e: + # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info + raise ArgsMismatchError( # noqa: TRY200 + "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( + reason=str(e), + func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", + args=[arg.python_type() for arg in args], + kwargs=kwargs, + ), + ) + + for v in itertools.chain(sub_locals.values(), closure_cells.values()): + if not isinstance(v, VariableTracker): + unimplemented(f"unconverted arg {v}") + + code: types.CodeType = func.get_code() + if code.co_name in ("__setitem__", "__setattr__") and not ( + args is not None + and len(args) > 0 + and isinstance(args[0], variables.CustomizedDictVariable) + ): + unimplemented(f"inline {code.co_name}") + + suffix = "" + # TODO: mlazos, add support for enabling multiple artifact logs + # with a single alias + if torch._logging._internal.log_state.is_artifact_enabled("output_code"): + suffix = f"\n{dis.Bytecode(code).dis()}" + if sys.version_info >= (3, 11): + cur_inst = parent.current_instruction + parent_code = parent.f_code + header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno) + + def get_trace_call_log_str(): + line = get_instruction_source_311(parent_code, cur_inst).rstrip() + return f"TRACE inlined call {code.co_name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + log.debug("INLINING %s%s, %s", code, suffix, result.reason) + + # Detect inline GraphModule calls in order to propagate node metadata, + # by checking if the first argument (self) is a variable tracking a GraphModule. + if args and isinstance(args[0], NNModuleVariable): + module = parent.output.get_submodule(args[0].module_key) + if isinstance(module, torch.fx.GraphModule): + # The inline call might not actually be a call to `forward`, + # but it is enough to add a context for `forward` in case it is called. 
+ code_context.get_context(module.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(module) + + tracer: InliningInstructionTranslator + if is_generator(code): + tracer = InliningGeneratorInstructionTranslator( + parent, code, sub_locals, parent.symbolic_globals, closure_cells, func + ) + else: + tracer = InliningInstructionTranslator( + parent, code, sub_locals, parent.symbolic_globals, closure_cells, func + ) + + strict_ctx: Any = contextlib.nullcontext() + if parent.strict_checks_enabled: + strict_ctx = tracer.strict_translation_mode() + try: + with strict_ctx: + tracer.run() + except exc.SkipFrame as e: + msg = f"SKIPPED INLINING {code}: {e}" + log.debug(msg) + raise Unsupported(msg) from e + except Exception as e: + log.debug("FAILED INLINING %s", code) + raise + assert tracer.symbolic_result is not None + func.export_freevars(parent, tracer) + + if tracer.f_globals is parent.f_globals: + # Merge symbolic_globals back if parent and child are in the same namespace + parent.symbolic_globals.update(tracer.symbolic_globals) + + parent.inconsistent_side_effects |= tracer.inconsistent_side_effects + + log.debug("DONE INLINING %s", code) + + if is_generator(code): + assert isinstance(tracer, InliningGeneratorInstructionTranslator) + assert tracer.symbolic_result.as_python_constant() is None + return ListIteratorVariable( + tracer.generated_items, + mutable_local=MutableLocal(), + ) + else: + return tracer.symbolic_result + + def __init__( + self, + parent: InstructionTranslatorBase, + code: types.CodeType, + symbolic_locals: Dict[str, VariableTracker], + symbolic_globals: Dict[str, VariableTracker], + closure_cells: Dict[str, VariableTracker], + funcvar: BaseUserFunctionVariable, + ): + f_globals = funcvar.get_globals() # type: ignore[attr-defined] + f_builtins = f_globals["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + instructions = cleaned_instructions(code) + propagate_line_nums(instructions) + super().__init__( + output=parent.output, + f_locals={}, + f_globals=f_globals, + f_builtins=f_builtins, + symbolic_locals=symbolic_locals, + symbolic_globals=symbolic_globals, + instructions=instructions, + code_options={k: getattr(code, k) for k in dir(code)}, + f_code=code, + export=parent.export, + inline_depth=parent.inline_depth + 1, + speculation_log=parent.speculation_log, + ) + self.parent = parent + self.symbolic_result = None + self.closure_cells = closure_cells + self.nn_module_stack = parent.nn_module_stack.copy() + + @property + def fake_mode(self): + return self.parent.fake_mode + + def run_ctx_mgr(self): + return TracingContext.current_frame(self.parent.frame_summary()) + + def STORE_DEREF(self, inst): + if inst.argval in self.closure_cells: + cell = self.closure_cells[inst.argval] + val = self.pop() + if isinstance(cell, ClosureVariable): + if not self.output.is_root_tracer(): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)" + ) + self.output.root_tx.symbolic_locals[cell.name] = val + else: + self.output.side_effects.store_cell(cell, val) + else: + maybe_cell = self.symbolic_locals.get(inst.argval) + if isinstance( + maybe_cell, + variables.NewCellVariable, + ): + self.output.side_effects.store_cell( + self.symbolic_locals[inst.argval], self.pop() + ) + else: + if ( + maybe_cell is not None + and maybe_cell.source.name() + not in self.output.root_tx.mutated_closure_cell_contents + ): + # Why is the source name here unique? 
+ # mutated_closure_cell_contents is a per-frame + # concept, and sources identify, e.g., particular + # locals from the frame. If you had two locals, + # they'll get different source names, and therefore + # differ here. + self.output.root_tx.mutated_closure_cell_contents.add( + maybe_cell.source.name() + ) + raise exc.UnspecializeRestartAnalysis() + unimplemented("write to __closure__ while inlining") + + def LOAD_DEREF(self, inst): + if inst.argval in self.closure_cells: + cell = self.closure_cells[inst.argval] + if isinstance(cell, ClosureVariable): + self.push(self.output.root_tx.symbolic_locals[cell.name]) + else: + self.push(self.output.side_effects.load_cell(cell)) + else: + maybe_sym_local = self.symbolic_locals.get(inst.argval, None) + if isinstance(maybe_sym_local, variables.NewCellVariable): + self.push(self.output.side_effects.load_cell(maybe_sym_local)) + else: + super().LOAD_DEREF(inst) + + def LOAD_CLOSURE(self, inst): + assert inst.argval in self.cell_and_freevars() + if inst.argval in self.closure_cells: + self.push(self.closure_cells[inst.argval]) + else: + self.push(InlinedClosureVariable(name=inst.argval)) + + def check_replace_is_safe(self, oldvar): + if not is_side_effect_safe(oldvar.mutable_local): + unimplemented( + "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)" + ) + + def should_compile_partial_graph(self): + return False # inlining functions is all-or-nothing + + def create_call_resume_at(self, offset): + unimplemented("cant resume while inlining") + + def RETURN_VALUE(self, inst): + self.symbolic_result = self.pop() # type: ignore[assignment] + self.instruction_pointer = None + + +class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): + generated_items: List[VariableTracker] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.generated_items = [] + + def YIELD_VALUE(self, inst: Instruction): + self.generated_items.append(self.pop()) + # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE + self.push(ConstantVariable.create(None)) + + def GET_YIELD_FROM_ITER(self, inst): + tos = self.stack[-1] + if not isinstance(tos, ListIteratorVariable): + self.pop() + res = BuiltinVariable(iter).call_function(self, [tos], {}) + self.push(res) + return self.YIELD_FROM(inst) + + def YIELD_FROM(self, inst): + while True: + tos = self.stack[-1].realize() + if isinstance(tos, ConstantVariable) and tos.value is None: + self.pop() + return + if isinstance( + tos, (variables.ListIteratorVariable, variables.IteratorVariable) + ): + try: + val, next_iter = tos.next_variables(self) + self.push(val) + # TODO(voz): Unclear if we need the push None in YIELD_VALUE? + self.YIELD_VALUE(inst) + self.pop() + self.push(next_iter) + except StopIteration: + return + else: + unimplemented(f"YIELD_FROM {typestr(tos)}") + + def SEND(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if isinstance(tos, ListIteratorVariable): + if isinstance(val, ConstantVariable) and val.value is None: + self.push(val) + self.instruction_pointer = self.indexof[inst.target] + else: + # invoke send + # Unreachable code - if you hit this, you are implementing generator support and have + # lifted the `unimplemented("generator")` in frame conversion. 
This codepath handles + # subgenerator and lines up with this line in Python 3.11 + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597 + unimplemented("Unreachable sub-generator code") + else: + unimplemented(f"SEND {typestr(tos)}") diff --git a/MLPY/Lib/site-packages/torch/_dynamo/tensor_version_op.py b/MLPY/Lib/site-packages/torch/_dynamo/tensor_version_op.py new file mode 100644 index 0000000000000000000000000000000000000000..5c20e1cd5ff504e40a14eb599d11880e61e9f9c1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/tensor_version_op.py @@ -0,0 +1,57 @@ +import torch +from torch._prims import _make_prim, RETURN_TYPE +from torch._subclasses import FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensorMode + +_tensor_version = _make_prim( + schema="_tensor_version(Tensor self) -> SymInt", + return_type=RETURN_TYPE.NEW, + meta=torch.ops.aten._version.default, + impl_aten=torch.ops.aten._version.default, + doc="Tracable unbacked SymInt version of torch.Tensor._version", +) + + +@_tensor_version.py_impl(FakeTensorMode) +def _tensor_version_fake(self): + """ + The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the + `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` + of input tensors to the graph. + """ + return self.fake_mode.shape_env.create_unbacked_symint() + + +_unsafe_set_version_counter = _make_prim( + schema="_unsafe_set_version_counter(Tensor self, SymInt version) -> ()", + return_type=RETURN_TYPE.NEW, + meta=lambda self, version: None, + impl_aten=torch._C._autograd._unsafe_set_version_counter, + doc="Tracable+SymInt version of torch._C._autograd._unsafe_set_version_counter", +) +torch.fx.node.has_side_effect(_unsafe_set_version_counter) + + +""" +When we functionalize _tensor_version + _unsafe_set_version_counter, +the ops disappear from the traced graph. We run them eagerly on the +fake tensors used for tracing, in order to get past asserts that would +fail in autograd. + +Why is this ok? +1) Versions on functional tensors don't make any sense since you can't mutate a functional tensor. +2) The whole point of version munging is to trick autograd into doing what we want, and after + AotAtuograd there is no longer any need for these ops. + +Note this is similar to how no_grad is handled. +""" + + +@_tensor_version.py_impl(FunctionalTensorMode) +def _tensor_version_functional(self): + return self._version + + +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) +def _unsafe_set_version_counter_functional(self, version): + torch._C._autograd._unsafe_set_version_counter(self, version) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/test_case.py b/MLPY/Lib/site-packages/torch/_dynamo/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..939d811ef3b5572b74d254f109d7733ba79bec74 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/test_case.py @@ -0,0 +1,78 @@ +import contextlib +import importlib +import logging +import sys + +import torch +import torch.testing +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + IS_WINDOWS, + TEST_WITH_CROSSREF, + TEST_WITH_TORCHDYNAMO, + TestCase as TorchTestCase, +) + +from . 
import config, reset, utils + +log = logging.getLogger(__name__) + + +def run_tests(needs=()): + from torch.testing._internal.common_utils import run_tests + + if ( + TEST_WITH_TORCHDYNAMO + or IS_WINDOWS + or TEST_WITH_CROSSREF + or sys.version_info >= (3, 12) + ): + return # skip testing + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda" and not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + run_tests() + + +class TestCase(TorchTestCase): + _exit_stack: contextlib.ExitStack + + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._exit_stack.enter_context( # type: ignore[attr-defined] + config.patch( + raise_on_ctx_manager_usage=True, + suppress_errors=False, + log_compilation_metrics=False, + ), + ) + + def setUp(self): + self._prior_is_grad_enabled = torch.is_grad_enabled() + super().setUp() + reset() + utils.counters.clear() + + def tearDown(self): + for k, v in utils.counters.items(): + print(k, v.most_common()) + reset() + utils.counters.clear() + super().tearDown() + if self._prior_is_grad_enabled is not torch.is_grad_enabled(): + log.warning("Running test changed grad mode") + torch.set_grad_enabled(self._prior_is_grad_enabled) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/test_minifier_common.py b/MLPY/Lib/site-packages/torch/_dynamo/test_minifier_common.py new file mode 100644 index 0000000000000000000000000000000000000000..46e7a272ff888aded2dd3c49354cbefa82e7f5ad --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/test_minifier_common.py @@ -0,0 +1,244 @@ +import dataclasses +import io +import logging +import os +import re +import shutil +import subprocess +import sys +import tempfile +import traceback +from typing import Optional +from unittest.mock import patch + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch.utils._traceback import report_compile_source_on_error + + +@dataclasses.dataclass +class MinifierTestResult: + minifier_code: str + repro_code: str + + def _get_module(self, t): + match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) + assert match is not None, "failed to find module" + r = match.group(0) + r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE) + r = re.sub(r"\n{3,}", "\n\n", r) + return r.strip() + + def minifier_module(self): + return self._get_module(self.minifier_code) + + def repro_module(self): + return self._get_module(self.repro_code) + + +class MinifierTestBase(torch._dynamo.test_case.TestCase): + DEBUG_DIR = tempfile.mkdtemp() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR) + ) + # These configurations make new process startup slower. Disable them + # for the minification tests to speed them up. 
+ cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._inductor.config.patch( + { + # https://github.com/pytorch/pytorch/issues/100376 + "pattern_matcher": False, + # multiprocess compilation takes a long time to warmup + "compile_threads": 1, + # https://github.com/pytorch/pytorch/issues/100378 + "cpp.vec_isa_ok": False, + } + ) + ) + + @classmethod + def tearDownClass(cls): + if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": + shutil.rmtree(cls.DEBUG_DIR) + else: + print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") + cls._exit_stack.close() # type: ignore[attr-defined] + + def _gen_codegen_fn_patch_code(self, device, bug_type): + assert bug_type in ("compile_error", "runtime_error", "accuracy") + return f"""\ +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} +""" + + def _maybe_subprocess_run(self, args, *, isolate, cwd=None): + if not isolate: + assert len(args) >= 2, args + assert args[0] == "python3", args + if args[1] == "-c": + assert len(args) == 3, args + code = args[2] + args = ["-c"] + else: + assert len(args) >= 2, args + with open(args[1]) as f: + code = f.read() + args = args[1:] + + # WARNING: This is not a perfect simulation of running + # the program out of tree. We only interpose on things we KNOW we + # need to handle for tests. If you need more stuff, you will + # need to augment this appropriately. + + # NB: Can't use save_config because that will omit some fields, + # but we must save and reset ALL fields + dynamo_config = torch._dynamo.config.shallow_copy_dict() + inductor_config = torch._inductor.config.shallow_copy_dict() + try: + stderr = io.StringIO() + log_handler = logging.StreamHandler(stderr) + log = logging.getLogger("torch._dynamo") + log.addHandler(log_handler) + try: + prev_cwd = os.getcwd() + if cwd is not None: + os.chdir(cwd) + with patch("sys.argv", args), report_compile_source_on_error(): + exec(code, {"__name__": "__main__", "__compile_source__": code}) + rc = 0 + except Exception: + rc = 1 + traceback.print_exc(file=stderr) + finally: + log.removeHandler(log_handler) + if cwd is not None: + os.chdir(prev_cwd) # type: ignore[possibly-undefined] + # Make sure we don't leave buggy compiled frames lying + # around + torch._dynamo.reset() + finally: + torch._dynamo.config.load_config(dynamo_config) + torch._inductor.config.load_config(inductor_config) + + # TODO: return a more appropriate data structure here + return subprocess.CompletedProcess( + args, + rc, + b"", + stderr.getvalue().encode("utf-8"), + ) + else: + return subprocess.run(args, capture_output=True, cwd=cwd, check=False) + + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. 
+ def _run_test_code(self, code, *, isolate): + proc = self._maybe_subprocess_run( + ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR + ) + + print("test stdout:", proc.stdout.decode("utf-8")) + print("test stderr:", proc.stderr.decode("utf-8")) + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + return proc, repro_dir_match.group(1) + return proc, None + + # Runs the minifier launcher script in `repro_dir` + def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): + self.assertIsNotNone(repro_dir) + launch_file = os.path.join(repro_dir, "minifier_launcher.py") + with open(launch_file) as f: + launch_code = f.read() + self.assertTrue(os.path.exists(launch_file)) + + args = ["python3", launch_file, "minify", *minifier_args] + if not isolate: + args.append("--no-isolate") + launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir) + print("minifier stdout:", launch_proc.stdout.decode("utf-8")) + stderr = launch_proc.stderr.decode("utf-8") + print("minifier stderr:", stderr) + self.assertNotIn("Input graph did not fail the tester", stderr) + + return launch_proc, launch_code + + # Runs the repro script in `repro_dir` + def _run_repro(self, repro_dir, *, isolate=True): + self.assertIsNotNone(repro_dir) + repro_file = os.path.join(repro_dir, "repro.py") + with open(repro_file) as f: + repro_code = f.read() + self.assertTrue(os.path.exists(repro_file)) + + repro_proc = self._maybe_subprocess_run( + ["python3", repro_file], isolate=isolate, cwd=repro_dir + ) + print("repro stdout:", repro_proc.stdout.decode("utf-8")) + print("repro stderr:", repro_proc.stderr.decode("utf-8")) + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file; usually + # just use this to turn on bugs via the config + def _gen_test_code(self, run_code, repro_after, repro_level): + return f"""\ +import torch +import torch._dynamo +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code + # 2. Run the generated minifier launcher script + # 3. 
Run the generated repro script + # + # If possible, you should run the test with isolate=False; use + # isolate=True only if the bug you're testing would otherwise + # crash the process + def _run_full_test( + self, run_code, repro_after, expected_error, *, isolate, minifier_args=() + ) -> Optional[MinifierTestResult]: + if isolate: + repro_level = 3 + elif expected_error is None or expected_error == "AccuracyError": + repro_level = 4 + else: + repro_level = 2 + test_code = self._gen_test_code(run_code, repro_after, repro_level) + print("running test", file=sys.stderr) + test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate) + if expected_error is None: + # Just check that there was no error + self.assertEqual(test_proc.returncode, 0) + self.assertIsNone(repro_dir) + return None + # NB: Intentionally do not test return code; we only care about + # actually generating the repro, we don't have to crash + self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) + self.assertIsNotNone(repro_dir) + print("running minifier", file=sys.stderr) + minifier_proc, minifier_code = self._run_minifier_launcher( + repro_dir, isolate=isolate, minifier_args=minifier_args + ) + print("running repro", file=sys.stderr) + repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate) + self.assertIn(expected_error, repro_proc.stderr.decode("utf-8")) + self.assertNotEqual(repro_proc.returncode, 0) + return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/testing.py b/MLPY/Lib/site-packages/torch/_dynamo/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9fed83380a5fdab8a73ba1a38e13249dc83a23 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/testing.py @@ -0,0 +1,372 @@ +import contextlib +import dis +import functools +import logging +import os.path +import random +import re +import sys +import types +import unittest +from typing import List, Optional, Sequence, Union +from unittest.mock import patch + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + np = None + +import torch +from torch import fx +from torch._dynamo.output_graph import OutputGraph + +from . import config, eval_frame, optimize_assert, reset +from .bytecode_transformation import ( + create_instruction, + debug_checks, + is_generator, + transform_code_object, +) +from .guards import CheckFunctionManager, GuardedCode +from .utils import same + +unsupported = eval_frame.unsupported +three = 3 + +log = logging.getLogger(__name__) + + +def clone_me(x): + if x is None: + return None + return x.detach().clone().requires_grad_(x.requires_grad) + + +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def named_buffers_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_buffers + + +def remove_optimized_module_prefix(name) -> str: + return re.sub(r"^_orig_mod[.]", "", name) + + +def collect_results(model, prediction, loss, example_inputs): + results = [] + results.append(prediction) + results.append(loss) + # if isinstance(loss, torch.Tensor) and loss.item() > 1: + # log.warning( + # f"High loss value alert - {loss:.2f}. Can result in unstable gradients." 
+ # ) + + grads = dict() + params = dict() + for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + param_copy = param + grad = param.grad + # Treat None and zero grad as same + if param.grad is None: + grad = torch.zeros_like(param) + grads[name + ".grad"] = grad + params[name] = param_copy + results.append(grads) + results.append(params) + buffers = dict() + for name, buffer in model.named_buffers(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + buffers[name] = buffer + results.append(buffers) + for example in example_inputs: + if isinstance(example, (tuple, list)): + for inp in example: + if isinstance(inp, torch.Tensor): + results.append(inp.grad) + else: + if isinstance(example, torch.Tensor): + results.append(example.grad) + return results + + +def requires_bwd_pass(out): + if isinstance(out, torch.Tensor): + return out.requires_grad + elif isinstance(out, (list, tuple)): + return any(requires_bwd_pass(x) for x in out) + elif out is None: + return False + elif isinstance(out, int): + return False + raise NotImplementedError("Don't know how to reduce", type(out)) + + +def reduce_to_scalar_loss(out): + """Reduce the output of a model to get scalar loss""" + if isinstance(out, torch.Tensor): + # Mean does not work on integer tensors + return out.sum() / out.numel() + elif isinstance(out, (list, tuple)): + return sum([reduce_to_scalar_loss(x) for x in out]) / len(out) + elif type(out).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + ): + return reduce_to_scalar_loss(out.logits) + elif type(out).__name__ == "SquashedNormal": + return out.mean.sum() + elif isinstance(out, dict): + return sum([reduce_to_scalar_loss(value) for value in out.values()]) / len( + out.keys() + ) + raise NotImplementedError("Don't know how to reduce", type(out)) + + +def debug_dir() -> str: + path = os.path.join(os.path.dirname(__file__), "../debug") + if not os.path.exists(path): + os.mkdir(path) + return path + + +def debug_dump(name, code: types.CodeType, extra="") -> None: + with open(os.path.join(debug_dir(), name), "w") as fd: + fd.write( + f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n" + ) + + +def debug_insert_nops( + frame, cache_size, hooks, _, *, skip: int = 0 +) -> Optional[GuardedCode]: + """used to debug jump updates""" + + def insert_nops(instructions, code_options): + instructions.insert(0, create_instruction("NOP")) + instructions.insert(0, create_instruction("NOP")) + + if is_generator(frame.f_code): + return None + + debug_checks(frame.f_code) + code = transform_code_object(frame.f_code, insert_nops) + graph = OutputGraph( + code_options={}, + compiler_fn=None, + root_tx=None, + export=False, + export_constraints=None, + frame_state={"_id": 0}, + # TODO: shouldn't this be f_locals/f_globals from frame? 
+ local_scope=locals(), + global_scope=globals(), + f_code=frame.f_code, + ) + + return GuardedCode(code, CheckFunctionManager(graph).check_fn) + + +class CompileCounter: + def __init__(self): + self.frame_count = 0 + self.op_count = 0 + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + return gm.forward + + def clear(self): + self.frame_count = 0 + self.op_count = 0 + + +class CompileCounterWithBackend: + def __init__(self, backend): + self.frame_count = 0 + self.op_count = 0 + self.backend = backend + self.graphs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + from .backends.registry import lookup_backend + + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + self.graphs.append(gm) + return lookup_backend(self.backend)(gm, example_inputs) + + +# Equivalent to backend="eager", but also records graphs that +# we can assert on +class EagerAndRecordGraphs: + def __init__(self): + self.graphs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + self.graphs.append(gm) + return gm + + +def strip_comment(code) -> str: + code = str(code) + return re.sub(r"(?m)^ *#.*\n?", "", code) + + +def remove_trailing_space(code) -> str: + return "\n".join([line.rstrip() for line in code.split("\n")]) + + +def normalize_gm(gm_str) -> str: + # strip comments as comments have path to files which may differ from + # system to system. + return remove_trailing_space(strip_comment(gm_str)) + + +def standard_test( + self, + fn, + nargs, + expected_ops=None, + expected_ops_dynamic=None, + expected_frame_count=1, +): + if not config.assume_static_by_default and expected_ops_dynamic is not None: + expected_ops = expected_ops_dynamic + + actual = CompileCounter() + + args1 = [torch.randn(10, 10) for _ in range(nargs)] + args2 = [torch.randn(10, 10) for _ in range(nargs)] + correct1 = fn(*args1) + correct2 = fn(*args2) + reset() + opt_fn = optimize_assert(actual)(fn) + val1a = opt_fn(*args1) + val2a = opt_fn(*args2) + val1b = opt_fn(*args1) + val2b = opt_fn(*args2) + reset() + self.assertTrue(same(val1a, correct1)) + self.assertTrue(same(val1b, correct1)) + self.assertTrue(same(val2a, correct2)) + self.assertTrue(same(val2b, correct2)) + self.assertEqual(actual.frame_count, expected_frame_count) + if expected_ops is not None: + self.assertEqual(actual.op_count, expected_ops) + + +def dummy_fx_compile(gm: fx.GraphModule, example_inputs): + return gm.forward + + +def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1): + if not is_correct: + return "ERROR" + if pvalue > pvalue_threshold: + return f"{speedup:.3f}x SAME" + return f"{speedup:.3f}x p={pvalue:.2f}" + + +def rand_strided( + size: Sequence[int], + stride: Sequence[int], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + extra_size: int = 0, +): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(size, stride)) + + 1 + + extra_size + ) + if dtype.is_floating_point: + buffer = torch.randn(needed_size, dtype=dtype, device=device) + else: + buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device) + return torch.as_strided(buffer, size, stride) + + +def _make_fn_with_patches(fn, *patches): + @functools.wraps(fn) + def _fn(*args, **kwargs): + with contextlib.ExitStack() as stack: + for module, attr, val in patches: + 
stack.enter_context(patch.object(module, attr, val)) + + return fn(*args, **kwargs) + + return _fn + + +def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches, xfail_prop=None): + DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) + DummyTestClass.__qualname__ = DummyTestClass.__name__ + + for name in dir(cls): + if name.startswith("test_"): + fn = getattr(cls, name) + if not callable(fn): + setattr(DummyTestClass, name, getattr(cls, name)) + continue + new_name = f"{name}{fn_suffix}" + new_fn = _make_fn_with_patches(fn, *patches) + new_fn.__name__ = new_name + if xfail_prop is not None and hasattr(fn, xfail_prop): + new_fn = unittest.expectedFailure(new_fn) + setattr(DummyTestClass, new_name, new_fn) + # NB: Doesn't handle slots correctly, but whatever + elif not hasattr(DummyTestClass, name): + setattr(DummyTestClass, name, getattr(cls, name)) + + return DummyTestClass + + +# test Python 3.11+ specific features +def skipIfNotPy311(fn): + if sys.version_info >= (3, 11): + return fn + return unittest.skip(fn) + + +# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py +# and test/dynamo/test_dynamic_shapes.py +def expectedFailureDynamic(fn): + fn._expected_failure_dynamic = True + return fn + + +# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py +def expectedFailureCodegenDynamic(fn): + fn._expected_failure_codegen_dynamic = True + return fn + + +# Controls test generated in test/inductor/test_cpp_wrapper.py +def expectedFailureDynamicWrapper(fn): + fn._expected_failure_dynamic_wrapper = True + return fn + + +def reset_rng_state(use_xla=False): + torch.manual_seed(1337) + random.seed(1337) + if np: + np.random.seed(1337) + if use_xla: + import torch_xla.core.xla_model as xm + + xm.set_rng_state(1337, str(xm.xla_device())) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/trace_rules.py b/MLPY/Lib/site-packages/torch/_dynamo/trace_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3638568af181f1293c225f03b2a8aec3df99ba --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/trace_rules.py @@ -0,0 +1,3460 @@ +import _collections_abc +import _weakrefset +import abc +import builtins +import collections +import contextlib +import copy +import copyreg +import dataclasses +import enum +import functools +import importlib +import inspect +import itertools +import linecache +import logging +import multiprocessing +import operator +import os +import posixpath +import random +import re +import selectors +import signal +import sys +import tempfile +import threading +import tokenize +import traceback +import types +import typing +import unittest +import weakref +from collections import defaultdict +from typing import Any, Callable, cast, Dict, List, Optional, Set, Union + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + pass + +import torch +import torch._inductor.test_operators +import torch.distributed +import torch.utils._content_store +from ..utils import _config_module +from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper + +from .variables import ( + BuiltinVariable, + FunctorchHigherOrderVariable, + NestedUserFunctionVariable, + SkipFunctionVariable, + TorchInGraphFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) + +from .variables.base import VariableTracker + + +""" +Map of function objects to their tracing rules (Dynamo variables). 
+* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g., + - torch.add: should be put into the FX graph. + - torch.is_floating_point: constant folded. +* SkipFunctionVariable: The objects should be skipped from tracing. +* UserFunctionVariable: The functions should be inlined. + +For developers: If you add/remove a torch level API, it may trigger failures from +test/dynamo/test_trace_rules.py:test_torch_name_rule_map_updated. To fix the failures: +If you are adding a new torch level API or Dynamo implementation: +* Add the name with the corresponding tracing rule to this map + if you are adding a new in graph function or Dynamo implementation for an existing function. +* Remove the object name from test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names if it's there. + +If you are removing an existing torch level API: +* Remove the entry represented the API from this map or test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names + depends on where it is. + + +""" +manual_torch_name_rule_map = { + "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, + "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, + "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, + "torch.jit.is_scripting": TorchInGraphFunctionVariable, + "torch.jit.is_tracing": TorchInGraphFunctionVariable, + "torch.jit.annotate": TorchInGraphFunctionVariable, + "torch.distributed.is_available": TorchInGraphFunctionVariable, + "torch.distributed.is_initialized": TorchInGraphFunctionVariable, + "torch.distributed.get_rank": TorchInGraphFunctionVariable, + "torch.distributed.get_world_size": TorchInGraphFunctionVariable, + "torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable, + "torch._utils.is_compiling": TorchInGraphFunctionVariable, + "torch.overrides.get_default_nowrap_functions": TorchInGraphFunctionVariable, + "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable, + "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, + "torch.autograd._profiler_enabled": SkipFunctionVariable, + # We graph break on RNG state setters or getters like + # `torch.get_rng_state` or `torch.set_rng_state`. These functions + # are not aten operations and therefore they are completely ignored + # by the AOT dispatcher. As a result, the AOT graph does not have + # these setter or getter functions, producing an incorrect graph + # when it comes to rng states. 
+ "torch.default_generator#get_state": SkipFunctionVariable, + "torch._C.Generator#get_state": SkipFunctionVariable, + "torch.get_rng_state": SkipFunctionVariable, + "torch.cuda.get_rng_state": SkipFunctionVariable, + "torch.default_generator#set_state": SkipFunctionVariable, + "torch._C.Generator#set_state": SkipFunctionVariable, + "torch.set_rng_state": SkipFunctionVariable, + "torch.cuda.set_rng_state": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/107187 + "torch.manual_seed": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/93501 + "torch.nn.utils.rnn.pack_padded_sequence": SkipFunctionVariable, + "torch.nn.Parameter": TorchInGraphFunctionVariable, + "torch._nested_tensor_from_mask": SkipFunctionVariable, + "torch._nested_from_padded": SkipFunctionVariable, + # symbol operators implemented in Python + "torch.sym_not": TorchInGraphFunctionVariable, + "torch.sym_float": TorchInGraphFunctionVariable, + "torch.sym_int": TorchInGraphFunctionVariable, + "torch.sym_max": TorchInGraphFunctionVariable, + "torch.sym_min": TorchInGraphFunctionVariable, + "torch.sym_sqrt": TorchInGraphFunctionVariable, + "torch.sym_ite": TorchInGraphFunctionVariable, + "torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable, + "torch.Tensor#__init__": SkipFunctionVariable, + "torch.cuda.set_device": SkipFunctionVariable, + "torch.cuda.current_device": SkipFunctionVariable, + "torch._C.autocast_decrement_nesting": SkipFunctionVariable, + "torch._C.autocast_increment_nesting": SkipFunctionVariable, + "torch.autograd.grad": SkipFunctionVariable, + "torch._C.clear_autocast_cache": SkipFunctionVariable, + "torch.distributions.constraints.is_dependent": SkipFunctionVariable, + "torch.jit.isinstance": SkipFunctionVariable, + "torch._C.set_anomaly_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cache_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_cpu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_enabled": SkipFunctionVariable, + "torch._C.set_autocast_gpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_xla_dtype": SkipFunctionVariable, + "torch._C.set_autocast_xla_enabled": SkipFunctionVariable, + "torch.resize_as_": SkipFunctionVariable, + "torch.resize_as_sparse_": SkipFunctionVariable, + "torch.get_default_device": TorchInGraphFunctionVariable, + # functorch/vmap + "torch._functorch.vmap._check_int_or_none": UserFunctionVariable, + "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable, + "torch._functorch.vmap._check_randomness_arg": UserFunctionVariable, + "torch._functorch.vmap._chunked_vmap": UserFunctionVariable, + "torch._functorch.vmap._concat_chunked_outputs": UserFunctionVariable, + "torch._functorch.vmap._create_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._flat_vmap": UserFunctionVariable, + "torch._functorch.vmap._flatten_chunks_output": UserFunctionVariable, + "torch._functorch.vmap._get_chunked_inputs": UserFunctionVariable, + "torch._functorch.vmap._get_name": UserFunctionVariable, + "torch._functorch.vmap._maybe_remove_batch_dim": UserFunctionVariable, + "torch._functorch.vmap._num_outputs": UserFunctionVariable, + "torch._functorch.vmap._process_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._unwrap_batched": UserFunctionVariable, + 
"torch._functorch.vmap._validate_and_get_batch_size": UserFunctionVariable, + "torch._functorch.vmap.doesnt_support_saved_tensors_hooks": UserFunctionVariable, + "torch._functorch.vmap.get_chunk_sizes": UserFunctionVariable, + # lazy_load_decompositions uses a lock that is not supported yet in dynamo + # "torch._functorch.vmap.lazy_load_decompositions": UserFunctionVariable, + "torch._functorch.vmap.restore_vmap": UserFunctionVariable, + "torch._functorch.apis.vmap": UserFunctionVariable, + "torch._functorch.vmap.unwrap_batched": UserFunctionVariable, + "torch._functorch.vmap.vmap_impl": FunctorchHigherOrderVariable, + "torch._functorch.vmap.wrap_batched": UserFunctionVariable, + # functorch/grad + "torch._functorch.eager_transforms.grad_impl": FunctorchHigherOrderVariable, + "torch._functorch.apis.grad_and_value": UserFunctionVariable, + "torch._functorch.eager_transforms._as_tuple": UserFunctionVariable, + "torch._functorch.eager_transforms._check_unique_non_empty": UserFunctionVariable, + "torch._functorch.eager_transforms._create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._slice_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._undo_create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnum": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_all_tensors": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_tensor_for_grad": UserFunctionVariable, + # functorch/jacrev + "torch._functorch.eager_transforms.jacrev": UserFunctionVariable, + "torch._functorch.eager_transforms.error_if_complex": UserFunctionVariable, + "torch._functorch.eager_transforms._chunked_standard_basis_for_": UserFunctionVariable, + "torch._functorch.eager_transforms._safe_zero_index": UserFunctionVariable, + # functorch/vjp + "torch._functorch.eager_transforms.vjp": UserFunctionVariable, + "torch._functorch.eager_transforms._vjp_with_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_non_empty_tensor_output": UserFunctionVariable, + "torch._constrain_as_size": UserFunctionVariable, + "torch._constrain_as_value": UserFunctionVariable, + "torch._tensor._convert": UserFunctionVariable, + "torch.jit._unwrap_optional": UserFunctionVariable, + "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable, + "torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable, + "torch._dynamo.mark_static": UserFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, + "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, + "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.sparse_bsc_tensor": SkipFunctionVariable, + "torch.sparse_bsr_tensor": SkipFunctionVariable, + "torch.sparse_csc_tensor": SkipFunctionVariable, + "torch.sparse_csr_tensor": SkipFunctionVariable, + "torch.sparse_compressed_tensor": SkipFunctionVariable, + "torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable, +} + + +# In graph functions (including constant folding) that are C bindings +torch_c_binding_in_graph_functions = 
dict.fromkeys( + [ + "math.acos", + "math.acosh", + "math.asin", + "math.asinh", + "math.atan", + "math.atan2", + "math.atanh", + "math.ceil", + "math.comb", + "math.copysign", + "math.cos", + "math.cosh", + "math.degrees", + "math.dist", + "math.erf", + "math.erfc", + "math.exp", + "math.expm1", + "math.fabs", + "math.factorial", + "math.floor", + "math.fmod", + "math.frexp", + "math.fsum", + "math.gamma", + "math.gcd", + "math.hypot", + "math.isclose", + "math.isfinite", + "math.isinf", + "math.isnan", + "math.isqrt", + "math.ldexp", + "math.lgamma", + "math.log", + "math.log10", + "math.log1p", + "math.log2", + "math.modf", + "math.nextafter", + "math.perm", + "math.pow", + "math.prod", + "math.radians", + "math.remainder", + "math.sin", + "math.sinh", + "math.tan", + "math.tanh", + "math.trunc", + "math.ulp", + "torch._adaptive_avg_pool2d", + "torch._adaptive_avg_pool3d", + "torch._add_batch_dim", + "torch._add_relu_", + "torch._add_relu", + "torch._addmm_activation", + "torch._aminmax", + "torch._amp_foreach_non_finite_check_and_unscale_", + "torch._amp_update_scale_", + "torch._assert_async", + "torch._assert_tensor_metadata", + "torch._batch_norm_impl_index", + "torch._C._activate_cuda_trace", + "torch._C._add_cached_tensor", + "torch._C._add_docstr", + "torch._C._are_functorch_transforms_active", + "torch._C._autograd_init", + "torch._C._awaitable_nowait", + "torch._C._awaitable_wait", + "torch._C._awaitable", + "torch._C._backport_for_mobile_from_buffer_to_buffer", + "torch._C._backport_for_mobile_from_buffer", + "torch._C._backport_for_mobile_to_buffer", + "torch._C._backport_for_mobile", + "torch._C._broadcast_coalesced", + "torch._C._broadcast_out", + "torch._C._broadcast", + "torch._C._c10d_init", + "torch._C._calculate_package_version_based_on_upgraders", + "torch._C._can_use_flash_attention", + "torch._C._can_use_mem_efficient_attention", + "torch._C._check_onnx_proto", + "torch._C._check_sparse_tensor_invariants", + "torch._C._collect_all", + "torch._C._commit_update", + "torch._C._compile_graph_to_code_table", + "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", + "torch._C._construct_storage_from_data_pointer", + "torch._C._conv_determine_backend_memory_format", + "torch._C._cpu._is_cpu_support_vnni", + "torch._C._crash_if_aten_asan", + "torch._C._crash_if_csrc_asan", + "torch._C._crash_if_csrc_ubsan", + "torch._C._crash_if_debug_asserts_fail", + "torch._C._crash_if_vptr_ubsan", + "torch._C._create_function_from_graph", + "torch._C._create_function_from_trace_with_dict", + "torch._C._create_function_from_trace", + "torch._C._create_graph_by_tracing", + "torch._C._create_module_with_type", + "torch._C._create_object_with_type", + "torch._C._cuda_attach_out_of_memory_observer", + "torch._C._cuda_beginAllocateCurrentStreamToPool", + "torch._C._cuda_canDeviceAccessPeer", + "torch._C._cuda_changeCurrentAllocator", + "torch._C._cuda_checkPoolLiveAllocations", + "torch._C._cuda_clearCublasWorkspaces", + "torch._C._cuda_cudaCachingAllocator_raw_alloc", + "torch._C._cuda_cudaCachingAllocator_raw_delete", + "torch._C._cuda_cudaCachingAllocator_set_allocator_settings", + "torch._C._cuda_cudaHostAllocator", + "torch._C._cuda_customAllocator", + "torch._C._cuda_emptyCache", + "torch._C._cuda_endAllocateCurrentStreamToPool", + "torch._C._cuda_exchangeDevice", + "torch._C._cuda_get_conv_benchmark_empty_cache", + "torch._C._cuda_get_cudnn_benchmark_limit", + "torch._C._cuda_get_sync_debug_mode", + "torch._C._cuda_getAllocator", + "torch._C._cuda_getAllocatorBackend", + 
"torch._C._cuda_getArchFlags", + "torch._C._cuda_getCheckpointState", + "torch._C._cuda_getCompiledVersion", + "torch._C._cuda_getCurrentBlasHandle", + "torch._C._cuda_getCurrentRawStream", + "torch._C._cuda_getCurrentStream", + "torch._C._cuda_getDefaultStream", + "torch._C._cuda_getDevice", + "torch._C._cuda_getDeviceCount", + "torch._C._cuda_hasPrimaryContext", + "torch._C._cuda_init", + "torch._C._cuda_ipc_collect", + "torch._C._cuda_isCurrentStreamCapturing", + "torch._C._cuda_isHistoryEnabled", + "torch._C._cuda_isInBadFork", + "torch._C._cuda_jiterator_compile_and_launch_kernel", + "torch._C._cuda_lock_mutex", + "torch._C._cuda_maybeExchangeDevice", + "torch._C._cuda_memorySnapshot", + "torch._C._cuda_memoryStats", + "torch._C._cuda_record_memory_history_legacy", + "torch._C._cuda_record_memory_history", + "torch._C._cuda_releasePool", + "torch._C._cuda_resetAccumulatedMemoryStats", + "torch._C._cuda_resetPeakMemoryStats", + "torch._C._cuda_set_cudnn_benchmark_limit", + "torch._C._cuda_set_sync_debug_mode", + "torch._C._cuda_setCheckpointPoolState", + "torch._C._cuda_setDevice", + "torch._C._cuda_setMemoryFraction", + "torch._C._cuda_setStream", + "torch._C._cuda_sleep", + "torch._C._cuda_synchronize", + "torch._C._cuda_unlock_mutex", + "torch._C._cudnn_set_conv_benchmark_empty_cache", + "torch._C._cudnn.getCompileVersion", + "torch._C._cudnn.getRuntimeVersion", + "torch._C._cudnn.getVersionInt", + "torch._C._current_autograd_node", + "torch._C._current_graph_task_execution_order", + "torch._C._current_graph_task_id", + "torch._C._cxx_flags", + "torch._C._debug_get_fusion_group_inlining", + "torch._C._debug_only_are_vmap_fallback_warnings_enabled", + "torch._C._debug_only_display_vmap_fallback_warnings", + "torch._C._debug_set_autodiff_subgraph_inlining", + "torch._C._debug_set_fusion_group_inlining", + "torch._C._demangle", + "torch._C._disabled_torch_dispatch_impl", + "torch._C._disabled_torch_function_impl", + "torch._C._dispatch_call_boxed", + "torch._C._dispatch_check_all_invariants", + "torch._C._dispatch_check_invariants", + "torch._C._dispatch_dump_table", + "torch._C._dispatch_dump", + "torch._C._dispatch_find_dangling_impls", + "torch._C._dispatch_find_schema_or_throw", + "torch._C._dispatch_get_all_op_names", + "torch._C._dispatch_get_backend_keyset_from_autograd", + "torch._C._dispatch_get_registrations_for_dispatch_key", + "torch._C._dispatch_has_backend_fallback", + "torch._C._dispatch_has_computed_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel_for_any_dispatch_key", + "torch._C._dispatch_has_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel", + "torch._C._dispatch_is_alias_key", + "torch._C._dispatch_is_included_in_alias", + "torch._C._dispatch_is_main_interpreter", + "torch._C._dispatch_isTensorSubclassLike", + "torch._C._dispatch_key_for_device", + "torch._C._dispatch_key_name", + "torch._C._dispatch_key_parse", + "torch._C._dispatch_key_set", + "torch._C._dispatch_keys", + "torch._C._dispatch_keyset_full_after", + "torch._C._dispatch_keyset_full", + "torch._C._dispatch_keyset_to_string", + "torch._C._dispatch_library", + "torch._C._dispatch_num_backends", + "torch._C._dispatch_print_registrations_for_dispatch_key", + "torch._C._dispatch_pystub", + "torch._C._dispatch_set_report_error_callback", + "torch._C._dispatch_tls_is_dispatch_key_excluded", + "torch._C._dispatch_tls_is_dispatch_key_included", + "torch._C._dispatch_tls_local_exclude_set", + "torch._C._dispatch_tls_local_include_set", + "torch._C._dispatch_tls_set_dispatch_key_excluded", + 
"torch._C._dispatch_tls_set_dispatch_key_included", + "torch._C._dist_autograd_init", + "torch._C._dump_local_tls_set", + "torch._C._dump_upgraders_map", + "torch._C._enable_mobile_interface_call_export", + "torch._C._enter_dual_level", + "torch._C._error_if_any_worker_fails", + "torch._C._exit_dual_level", + "torch._C._export_operator_list", + "torch._C._export_opnames", + "torch._C._faulty_agent_init", + "torch._C._fft.fft_fft", + "torch._C._fft.fft_fft2", + "torch._C._fft.fft_fftfreq", + "torch._C._fft.fft_fftn", + "torch._C._fft.fft_fftshift", + "torch._C._fft.fft_hfft", + "torch._C._fft.fft_hfft2", + "torch._C._fft.fft_hfftn", + "torch._C._fft.fft_ifft", + "torch._C._fft.fft_ifft2", + "torch._C._fft.fft_ifftn", + "torch._C._fft.fft_ifftshift", + "torch._C._fft.fft_ihfft", + "torch._C._fft.fft_ihfft2", + "torch._C._fft.fft_ihfftn", + "torch._C._fft.fft_irfft", + "torch._C._fft.fft_irfft2", + "torch._C._fft.fft_irfftn", + "torch._C._fft.fft_rfft", + "torch._C._fft.fft_rfft2", + "torch._C._fft.fft_rfftfreq", + "torch._C._fft.fft_rfftn", + "torch._C._free_And_Remove_DeleterFn", + "torch._C._freeze_module", + "torch._C._from_dlpack", + "torch._C._functionality_to_backend_keys", + "torch._C._functionalization_reapply_views_tls", + "torch._C._fuse_to_static_module", + "torch._C._gather_out", + "torch._C._gather", + "torch._C._generate_upgraders_graph", + "torch._C._get_autograd_fallback_mode", + "torch._C._get_backcompat_broadcast_warn", + "torch._C._get_backcompat_keepdim_warn", + "torch._C._get_caught_jit_exception_class_name", + "torch._C._get_caught_jit_exception_original_msg", + "torch._C._get_constant_bool_symnode", + "torch._C._get_cpp_backtrace", + "torch._C._get_cpu_capability", + "torch._C._get_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._get_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._get_cublas_allow_tf32", + "torch._C._get_cudnn_allow_tf32", + "torch._C._get_cudnn_benchmark", + "torch._C._get_cudnn_deterministic", + "torch._C._get_cudnn_enabled", + "torch._C._get_custom_class_python_wrapper", + "torch._C._get_default_device", + "torch._C._get_deterministic_algorithms_warn_only", + "torch._C._get_deterministic_algorithms", + "torch._C._get_deterministic_fill_uninitialized_memory", + "torch._C._get_dispatch_mode", + "torch._C._get_dispatch_stack_at", + "torch._C._get_file_format", + "torch._C._get_flash_sdp_enabled", + "torch._C._get_float32_matmul_precision", + "torch._C._get_function_stack_at", + "torch._C._get_graph_executor_optimize", + "torch._C._get_linalg_preferred_backend", + "torch._C._get_math_sdp_enabled", + "torch._C._get_max_operator_version", + "torch._C._get_mem_efficient_sdp_enabled", + "torch._C._get_mkldnn_enabled", + "torch._C._get_cudnn_sdp_enabled", + "torch._C._set_sdp_use_cudnn", + "torch._C._get_mobile_model_contained_types_from_buffer", + "torch._C._get_mobile_model_contained_types", + "torch._C._get_model_bytecode_version_from_buffer", + "torch._C._get_model_bytecode_version", + "torch._C._get_model_extra_files_from_buffer", + "torch._C._get_model_extra_files", + "torch._C._get_model_ops_and_info_from_buffer", + "torch._C._get_model_ops_and_info", + "torch._C._get_module_info_from_flatbuffer", + "torch._C._get_nnpack_enabled", + "torch._C._get_obj_in_tls", + "torch._C._get_operation_overload", + "torch._C._get_operator_version_map", + "torch._C._get_privateuse1_backend_name", + "torch._C._get_qengine", + "torch._C._get_schema", + "torch._C._get_nested_int", + "torch._C._get_tensor_metadata", + 
"torch._C._get_tracing_state", + "torch._C._get_upgrader_ranges", + "torch._C._get_upgraders_entry_map", + "torch._C._get_upgraders_map_size", + "torch._C._get_value_trace", + "torch._C._get_version_calculator_flag", + "torch._C._get_warnAlways", + "torch._C._graph_pool_handle", + "torch._C._group_tensors_by_device_and_dtype", + "torch._C._hack_do_not_use_clone_module_with_class", + "torch._C._has_distributed", + "torch._C._has_Standard_Deleter", + "torch._C._has_storage", + "torch._C._has_tensorexpr_cpp_tests", + "torch._C._run_tensorexpr_cpp_tests", + "torch._C._has_torch_function_unary", + "torch._C._has_torch_function_variadic", + "torch._C._has_torch_function", + "torch._C._import_ir_module_from_package", + "torch._C._increment_version", + "torch._C._infer_size", + "torch._C._init_names", + "torch._C._initExtension", + "torch._C._is_alias_of", + "torch._C._is_any_autocast_enabled", + "torch._C._is_cached_tensor", + "torch._C._is_fwd_grad_enabled", + "torch._C._is_key_in_tls", + "torch._C._is_multithreading_enabled", + "torch._C._is_torch_function_enabled", + "torch._C._is_torch_function_mode_enabled", + "torch._C._is_tracing", + "torch._C._is_view_replay_enabled", + "torch._C._is_xnnpack_enabled", + "torch._C._itt.is_available", + "torch._C._itt.mark", + "torch._C._itt.rangePop", + "torch._C._itt.rangePush", + "torch._C._ivalue_debug_python_object", + "torch._C._ivalue_tags_match", + "torch._C._jit_assert_is_instance", + "torch._C._jit_can_fuse_on_cpu_legacy", + "torch._C._jit_can_fuse_on_cpu", + "torch._C._jit_can_fuse_on_gpu", + "torch._C._jit_cat_wo_conditionals", + "torch._C._jit_check_alias_annotation", + "torch._C._jit_clear_class_registry", + "torch._C._jit_debug_fuser_num_cached_kernel_specs", + "torch._C._jit_debug_module_iterators", + "torch._C._jit_decay_packed_param_input_types", + "torch._C._jit_decomposition_graph_for_node", + "torch._C._jit_differentiate", + "torch._C._jit_erase_non_input_shape_information", + "torch._C._jit_flatten", + "torch._C._jit_fuser_get_fused_kernel_code", + "torch._C._jit_get_all_schemas", + "torch._C._jit_get_custom_class_schemas", + "torch._C._jit_get_emit_hooks", + "torch._C._jit_get_inline_everything_mode", + "torch._C._jit_get_logging_option", + "torch._C._jit_get_num_profiled_runs", + "torch._C._jit_get_operation", + "torch._C._jit_get_schemas_for_operator", + "torch._C._jit_get_te_cuda_pointwise_block_count", + "torch._C._jit_get_te_cuda_pointwise_block_size", + "torch._C._jit_get_te_cuda_pointwise_loop_levels", + "torch._C._jit_get_te_generate_block_code", + "torch._C._jit_get_te_must_use_llvm_cpu", + "torch._C._jit_get_tracer_state_warn", + "torch._C._jit_has_cpp_tests", + "torch._C._jit_init", + "torch._C._jit_interpret_graph", + "torch._C._jit_is_onnx_log_enabled", + "torch._C._jit_is_script_object", + "torch._C._jit_llga_enabled", + "torch._C._jit_nvfuser_can_be_enabled", + "torch._C._jit_nvfuser_clear_comparison_callback", + "torch._C._jit_nvfuser_enabled", + "torch._C._jit_nvfuser_horizontal_mode", + "torch._C._jit_nvfuser_set_comparison_callback", + "torch._C._jit_nvfuser_single_node_mode", + "torch._C._jit_object_is_non_holding", + "torch._C._jit_onnx_convert_pattern_from_subblock", + "torch._C._jit_onnx_create_full_scope_name", + "torch._C._jit_onnx_list_model_parameters", + "torch._C._jit_onnx_log", + "torch._C._jit_opt_conditionals", + "torch._C._jit_override_can_fuse_on_cpu_legacy", + "torch._C._jit_override_can_fuse_on_cpu", + "torch._C._jit_override_can_fuse_on_gpu", + "torch._C._jit_pass_autocast", + 
"torch._C._jit_pass_batch_mm", + "torch._C._jit_pass_canonicalize_graph_fuser_ops", + "torch._C._jit_pass_canonicalize", + "torch._C._jit_pass_complete_shape_analysis", + "torch._C._jit_pass_concat_frozen_linear", + "torch._C._jit_pass_constant_loop_unrolling", + "torch._C._jit_pass_constant_pooling", + "torch._C._jit_pass_constant_propagation_immutable_types", + "torch._C._jit_pass_constant_propagation", + "torch._C._jit_pass_convert_frozen_ops_to_mkldnn", + "torch._C._jit_pass_create_autodiff_subgraphs", + "torch._C._jit_pass_create_functional_graphs", + "torch._C._jit_pass_cse", + "torch._C._jit_pass_custom_pattern_based_rewrite_graph", + "torch._C._jit_pass_custom_pattern_based_rewrite", + "torch._C._jit_pass_dbr_quant_remove_redundant_aliases", + "torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects", + "torch._C._jit_pass_dce", + "torch._C._jit_pass_decompose_ops", + "torch._C._jit_pass_dedup_module_uses", + "torch._C._jit_pass_erase_number_types", + "torch._C._jit_pass_erase_shape_information", + "torch._C._jit_pass_filter_non_tensor_arguments", + "torch._C._jit_pass_fixup_onnx_controlflow_node", + "torch._C._jit_pass_fold_convbn", + "torch._C._jit_pass_fold_frozen_conv_add_or_sub", + "torch._C._jit_pass_fold_frozen_conv_bn", + "torch._C._jit_pass_fold_frozen_conv_mul_or_div", + "torch._C._jit_pass_fold_frozen_linear_bn", + "torch._C._jit_pass_fold_prepacking_ops", + "torch._C._jit_pass_functional_to_inplace_activation", + "torch._C._jit_pass_fuse_add_relu", + "torch._C._jit_pass_fuse_addmm", + "torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv", + "torch._C._jit_pass_fuse_frozen_conv_add_relu", + "torch._C._jit_pass_fuse_linear", + "torch._C._jit_pass_fuse_quantized_add_relu", + "torch._C._jit_pass_fuse_tensorexprs", + "torch._C._jit_pass_fuse", + "torch._C._jit_pass_inline_fork_wait", + "torch._C._jit_pass_inline_functional_graphs", + "torch._C._jit_pass_inline", + "torch._C._jit_pass_inplace_to_functional_activation", + "torch._C._jit_pass_insert_observer_method_for_ondevice_ptq", + "torch._C._jit_pass_insert_observers", + "torch._C._jit_pass_insert_prepack_unpack", + "torch._C._jit_pass_insert_prepacked_ops", + "torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq", + "torch._C._jit_pass_insert_quant_dequant", + "torch._C._jit_pass_integer_value_refinement", + "torch._C._jit_pass_lint", + "torch._C._jit_pass_loop_unrolling", + "torch._C._jit_pass_lower_all_tuples", + "torch._C._jit_pass_lower_graph", + "torch._C._jit_pass_metal_fold_prepacking_ops", + "torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_metal_insert_prepacked_ops", + "torch._C._jit_pass_metal_optimize_for_mobile", + "torch._C._jit_pass_onnx_assign_output_shape", + "torch._C._jit_pass_onnx_assign_scoped_names_for_node_and_value", + "torch._C._jit_pass_onnx_autograd_function_process", + "torch._C._jit_pass_onnx_block", + "torch._C._jit_pass_onnx_cast_all_constant_to_floating", + "torch._C._jit_pass_onnx_clear_scope_records", + "torch._C._jit_pass_onnx_constant_fold", + "torch._C._jit_pass_onnx_deduplicate_initializers", + "torch._C._jit_pass_onnx_eliminate_unused_items", + "torch._C._jit_pass_onnx_eval_peephole", + "torch._C._jit_pass_onnx_function_extraction", + "torch._C._jit_pass_onnx_function_substitution", + "torch._C._jit_pass_onnx_graph_shape_type_inference", + "torch._C._jit_pass_onnx_lint", + "torch._C._jit_pass_onnx_node_shape_type_inference", + "torch._C._jit_pass_onnx_peephole", + "torch._C._jit_pass_onnx_preprocess_caffe2", + 
"torch._C._jit_pass_onnx_preprocess", + "torch._C._jit_pass_onnx_quantization_insert_permutes", + "torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx", + "torch._C._jit_pass_onnx_remove_print", + "torch._C._jit_pass_onnx_scalar_type_analysis", + "torch._C._jit_pass_onnx_set_dynamic_input_shape", + "torch._C._jit_pass_onnx_track_scope_attributes", + "torch._C._jit_pass_onnx_unpack_quantized_weights", + "torch._C._jit_pass_onnx", + "torch._C._jit_pass_optimize_for_inference", + "torch._C._jit_pass_optimize_for_mobile", + "torch._C._jit_pass_optimize_frozen_graph", + "torch._C._jit_pass_pattern_based_rewrite", + "torch._C._jit_pass_peephole_list_idioms", + "torch._C._jit_pass_peephole", + "torch._C._jit_pass_prepare_division_for_onnx", + "torch._C._jit_pass_propagate_device", + "torch._C._jit_pass_propagate_dtype", + "torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute", + "torch._C._jit_pass_propagate_shapes_on_graph", + "torch._C._jit_pass_quant_finalize_for_ondevice_ptq", + "torch._C._jit_pass_quant_finalize", + "torch._C._jit_pass_quant_fusion", + "torch._C._jit_pass_refine_integer_values", + "torch._C._jit_pass_refine_tuple_types", + "torch._C._jit_pass_remove_dropout", + "torch._C._jit_pass_remove_expands", + "torch._C._jit_pass_remove_inplace_ops", + "torch._C._jit_pass_remove_mutation", + "torch._C._jit_pass_replace_old_ops_with_upgraders", + "torch._C._jit_pass_replicate_dequantize", + "torch._C._jit_pass_run_decompositions", + "torch._C._jit_pass_specialize_autogradzero", + "torch._C._jit_pass_swap_functional_linear", + "torch._C._jit_pass_transform_conv1d_to_conv2d", + "torch._C._jit_pass_transpose_frozen_linear", + "torch._C._jit_pass_vulkan_fold_prepacking_ops", + "torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_vulkan_insert_prepacked_ops", + "torch._C._jit_pass_vulkan_optimize_for_mobile", + "torch._C._jit_register_decomposition_for_schema", + "torch._C._jit_register_shape_compute_graph_for_node", + "torch._C._jit_resolve_packet", + "torch._C._jit_run_cpp_tests", + "torch._C._jit_script_class_compile", + "torch._C._jit_script_compile_overload", + "torch._C._jit_script_compile", + "torch._C._jit_script_interface_compile", + "torch._C._jit_set_autocast_mode", + "torch._C._jit_set_bailout_depth", + "torch._C._jit_set_emit_hooks", + "torch._C._jit_set_fusion_strategy", + "torch._C._jit_set_inline_everything_mode", + "torch._C._jit_set_llga_enabled", + "torch._C._jit_set_logging_option", + "torch._C._jit_set_logging_stream", + "torch._C._jit_set_num_profiled_runs", + "torch._C._jit_set_nvfuser_enabled", + "torch._C._jit_set_nvfuser_guard_mode", + "torch._C._jit_set_nvfuser_horizontal_mode", + "torch._C._jit_set_nvfuser_single_node_mode", + "torch._C._jit_set_nvfuser_skip_node_kind", + "torch._C._jit_set_onnx_log_enabled", + "torch._C._jit_set_onnx_log_output_stream", + "torch._C._jit_set_profiling_executor", + "torch._C._jit_set_profiling_mode", + "torch._C._jit_set_symbolic_shapes_test_mode", + "torch._C._jit_set_te_cuda_pointwise_block_count", + "torch._C._jit_set_te_cuda_pointwise_block_size", + "torch._C._jit_set_te_cuda_pointwise_loop_levels", + "torch._C._jit_set_te_generate_block_code", + "torch._C._jit_set_te_must_use_llvm_cpu", + "torch._C._jit_set_texpr_dynamic_shape_enabled", + "torch._C._jit_set_texpr_fuser_enabled", + "torch._C._jit_set_texpr_reductions_enabled", + "torch._C._jit_set_tracer_state_warn", + "torch._C._jit_set_utf8_decoding_ignore", + "torch._C._jit_shape_compute_graph_for_node", + 
"torch._C._jit_symbolic_shapes_test_mode_enabled", + "torch._C._jit_texpr_dynamic_shape_enabled", + "torch._C._jit_texpr_fallback_allowed", + "torch._C._jit_texpr_fuser_enabled", + "torch._C._jit_texpr_reductions_enabled", + "torch._C._jit_texpr_set_fallback_allowed", + "torch._C._jit_to_backend_selective", + "torch._C._jit_to_backend", + "torch._C._jit_to_static_module", + "torch._C._jit_trace_graph", + "torch._C._jit_trace_module", + "torch._C._jit_tree_views.FalseLiteral", + "torch._C._jit_tree_views.NoneLiteral", + "torch._C._jit_tree_views.TrueLiteral", + "torch._C._jit_try_infer_type", + "torch._C._jit_unflatten", + "torch._C._last_executed_optimized_graph", + "torch._C._len_torch_dispatch_stack", + "torch._C._len_torch_function_stack", + "torch._C._linalg._linalg_eigvals", + "torch._C._linalg.linalg_cholesky_ex", + "torch._C._linalg.linalg_cholesky", + "torch._C._linalg.linalg_cond", + "torch._C._linalg.linalg_cross", + "torch._C._linalg.linalg_det", + "torch._C._linalg.linalg_diagonal", + "torch._C._linalg.linalg_eig", + "torch._C._linalg.linalg_eigh", + "torch._C._linalg.linalg_eigvals", + "torch._C._linalg.linalg_eigvalsh", + "torch._C._linalg.linalg_householder_product", + "torch._C._linalg.linalg_inv_ex", + "torch._C._linalg.linalg_inv", + "torch._C._linalg.linalg_ldl_factor_ex", + "torch._C._linalg.linalg_ldl_factor", + "torch._C._linalg.linalg_ldl_solve", + "torch._C._linalg.linalg_lstsq", + "torch._C._linalg.linalg_lu_factor_ex", + "torch._C._linalg.linalg_lu_factor", + "torch._C._linalg.linalg_lu_solve", + "torch._C._linalg.linalg_lu", + "torch._C._linalg.linalg_matmul", + "torch._C._linalg.linalg_matrix_exp", + "torch._C._linalg.linalg_matrix_norm", + "torch._C._linalg.linalg_matrix_power", + "torch._C._linalg.linalg_matrix_rank", + "torch._C._linalg.linalg_multi_dot", + "torch._C._linalg.linalg_norm", + "torch._C._linalg.linalg_pinv", + "torch._C._linalg.linalg_qr", + "torch._C._linalg.linalg_slogdet", + "torch._C._linalg.linalg_solve_ex", + "torch._C._linalg.linalg_solve_triangular", + "torch._C._linalg.linalg_solve", + "torch._C._linalg.linalg_svd", + "torch._C._linalg.linalg_svdvals", + "torch._C._linalg.linalg_tensorinv", + "torch._C._linalg.linalg_tensorsolve", + "torch._C._linalg.linalg_vander", + "torch._C._linalg.linalg_vecdot", + "torch._C._linalg.linalg_vector_norm", + "torch._C._llvm_enabled", + "torch._C._load_for_lite_interpreter_from_buffer", + "torch._C._load_for_lite_interpreter", + "torch._C._load_jit_module_from_bytes", + "torch._C._load_jit_module_from_file", + "torch._C._load_mobile_module_from_bytes", + "torch._C._load_mobile_module_from_file", + "torch._C._log_api_usage_metadata", + "torch._C._log_api_usage_once", + "torch._C._logging_set_logger", + "torch._C._meta_in_tls_dispatch_include", + "torch._C._mps_acquireEvent", + "torch._C._mps_currentAllocatedMemory", + "torch._C._mps_deviceSynchronize", + "torch._C._mps_driverAllocatedMemory", + "torch._C._mps_elapsedTimeOfEvents", + "torch._C._mps_emptyCache", + "torch._C._mps_get_default_generator", + "torch._C._mps_is_available", + "torch._C._mps_is_in_bad_fork", + "torch._C._mps_is_on_macos_13_or_newer", + "torch._C._mps_profilerStartTrace", + "torch._C._mps_profilerStopTrace", + "torch._C._mps_queryEvent", + "torch._C._mps_recordEvent", + "torch._C._mps_releaseEvent", + "torch._C._mps_setMemoryFraction", + "torch._C._mps_synchronizeEvent", + "torch._C._mps_waitForEvent", + "torch._C._multiprocessing_init", + "torch._C._nccl_all_gather", + "torch._C._nccl_all_reduce", + "torch._C._nccl_broadcast", + 
"torch._C._nccl_init_rank", + "torch._C._nccl_reduce_scatter", + "torch._C._nccl_reduce", + "torch._C._nccl_unique_id", + "torch._C._nccl_version_suffix", + "torch._C._nccl_version", + "torch._C._nested.nested_tensor", + "torch._C._nested.nested_to_padded_tensor", + "torch._C._new_symbolic_shape_symbol", + "torch._C._nn_module_to_mobile", + "torch._C._nn._conv_depthwise2d", + "torch._C._nn._pad_circular", + "torch._C._nn._pad_enum", + "torch._C._nn._parse_to", + "torch._C._nn._test_ambiguous_defaults", + "torch._C._nn._test_optional_filled_intlist", + "torch._C._nn._test_optional_floatlist", + "torch._C._nn._test_optional_intlist", + "torch._C._nn._test_string_default", + "torch._C._nn._test_warn_in_autograd", + "torch._C._nn._upsample_bicubic2d_aa", + "torch._C._nn._upsample_bilinear2d_aa", + "torch._C._nn._upsample_nearest_exact1d", + "torch._C._nn._upsample_nearest_exact2d", + "torch._C._nn._upsample_nearest_exact3d", + "torch._C._nn.adaptive_avg_pool2d", + "torch._C._nn.adaptive_avg_pool3d", + "torch._C._nn.adaptive_max_pool2d", + "torch._C._nn.adaptive_max_pool3d", + "torch._C._nn.avg_pool2d", + "torch._C._nn.avg_pool3d", + "torch._C._nn.binary_cross_entropy", + "torch._C._nn.col2im", + "torch._C._nn.conv_depthwise3d", + "torch._C._nn.cross_entropy_loss", + "torch._C._nn.elu_", + "torch._C._nn.elu", + "torch._C._nn.flatten_dense_tensors", + "torch._C._nn.fractional_max_pool2d", + "torch._C._nn.fractional_max_pool3d", + "torch._C._nn.gelu_", + "torch._C._nn.gelu", + "torch._C._nn.glu", + "torch._C._nn.hardsigmoid_", + "torch._C._nn.hardsigmoid", + "torch._C._nn.hardswish_", + "torch._C._nn.hardswish", + "torch._C._nn.hardtanh_", + "torch._C._nn.hardtanh", + "torch._C._nn.huber_loss", + "torch._C._nn.im2col", + "torch._C._nn.l1_loss", + "torch._C._nn.leaky_relu_", + "torch._C._nn.leaky_relu", + "torch._C._nn.linear", + "torch._C._nn.log_sigmoid", + "torch._C._nn.max_pool2d_with_indices", + "torch._C._nn.max_pool3d_with_indices", + "torch._C._nn.max_unpool2d", + "torch._C._nn.max_unpool3d", + "torch._C._nn.mish_", + "torch._C._nn.mish", + "torch._C._nn.mkldnn_linear", + "torch._C._nn.mkldnn_reorder_conv2d_weight", + "torch._C._nn.mkldnn_reorder_conv3d_weight", + "torch._C._nn.mse_loss", + "torch._C._nn.multi_margin_loss", + "torch._C._nn.multilabel_margin_loss", + "torch._C._nn.nll_loss_nd", + "torch._C._nn.nll_loss", + "torch._C._nn.nll_loss2d", + "torch._C._nn.one_hot", + "torch._C._nn.pad_sequence", + "torch._C._nn.pad", + "torch._C._nn.reflection_pad1d", + "torch._C._nn.reflection_pad2d", + "torch._C._nn.reflection_pad3d", + "torch._C._nn.relu6_", + "torch._C._nn.relu6", + "torch._C._nn.replication_pad1d", + "torch._C._nn.replication_pad2d", + "torch._C._nn.replication_pad3d", + "torch._C._nn.rrelu_with_noise_", + "torch._C._nn.rrelu_with_noise", + "torch._C._nn.scaled_dot_product_attention", + "torch._C._nn.silu_", + "torch._C._nn.silu", + "torch._C._nn.slow_conv_dilated2d", + "torch._C._nn.slow_conv_dilated3d", + "torch._C._nn.slow_conv_transpose2d", + "torch._C._nn.slow_conv_transpose3d", + "torch._C._nn.slow_conv3d", + "torch._C._nn.smooth_l1_loss", + "torch._C._nn.soft_margin_loss", + "torch._C._nn.softplus", + "torch._C._nn.softshrink", + "torch._C._nn.thnn_conv2d", + "torch._C._nn.unflatten_dense_tensors", + "torch._C._nn.upsample_bicubic2d", + "torch._C._nn.upsample_bilinear2d", + "torch._C._nn.upsample_linear1d", + "torch._C._nn.upsample_nearest1d", + "torch._C._nn.upsample_nearest2d", + "torch._C._nn.upsample_nearest3d", + "torch._C._nn.upsample_trilinear3d", + 
"torch._C._non_sym_sizes", + "torch._C._overlaps", + "torch._C._parallel_info", + "torch._C._parse_dispatch_key", + "torch._C._parse_source_def", + "torch._C._pop_torch_dispatch_stack", + "torch._C._pop_torch_function_stack", + "torch._C._propagate_and_assign_input_shapes", + "torch._C._propagate_shapes", + "torch._C._propagate_xla_data", + "torch._C._push_on_torch_dispatch_stack", + "torch._C._push_on_torch_function_stack", + "torch._C._quantize_ondevice_ptq_dynamic", + "torch._C._register_py_class_for_device", + "torch._C._remove_cached_tensor", + "torch._C._remove_worker_pids", + "torch._C._rename_privateuse1_backend", + "torch._C._replace_", + "torch._C._replace_overloaded_method_decl", + "torch._C._resolve_type_from_object", + "torch._C._resolve_type", + "torch._C._rocm_is_backward_pass", + "torch._C._rpc_init", + "torch._C._run_emit_module_hook", + "torch._C._save_jit_module_to_bytes", + "torch._C._save_jit_module", + "torch._C._save_mobile_module_to_bytes", + "torch._C._save_mobile_module", + "torch._C._save_parameters", + "torch._C._scatter_out", + "torch._C._scatter", + "torch._C._select_conv_backend", + "torch._C._set_autograd_fallback_mode", + "torch._C._set_backcompat_broadcast_warn", + "torch._C._set_backcompat_keepdim_warn", + "torch._C._set_cached_tensors_enabled", + "torch._C._set_check_sparse_tensor_invariants", + "torch._C._set_conj", + "torch._C._set_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._set_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._set_cublas_allow_tf32", + "torch._C._set_cudnn_allow_tf32", + "torch._C._set_cudnn_benchmark", + "torch._C._set_cudnn_deterministic", + "torch._C._set_cudnn_enabled", + "torch._C._set_default_dtype", + "torch._C._set_default_mobile_cpu_allocator", + "torch._C._set_default_tensor_type", + "torch._C._set_deterministic_algorithms", + "torch._C._set_deterministic_fill_uninitialized_memory", + "torch._C._set_dispatch_mode", + "torch._C._set_float32_matmul_precision", + "torch._C._set_fwd_grad_enabled", + "torch._C._set_grad_enabled", + "torch._C._set_graph_executor_optimize", + "torch._C._set_linalg_preferred_backend", + "torch._C._set_meta_in_tls_dispatch_include", + "torch._C._set_mkldnn_enabled", + "torch._C._set_multithreading_enabled", + "torch._C._set_neg", + "torch._C._set_nnpack_enabled", + "torch._C._set_print_stack_traces_on_fatal_signal", + "torch._C._set_qengine", + "torch._C._set_sdp_use_flash", + "torch._C._set_sdp_use_math", + "torch._C._set_sdp_use_mem_efficient", + "torch._C._set_should_use_format_with_string_table", + "torch._C._set_storage_access_error_msg", + "torch._C._set_tensor_metadata", + "torch._C._set_tracing_state", + "torch._C._set_value_trace", + "torch._C._set_view_replay_enabled", + "torch._C._set_warnAlways", + "torch._C._set_worker_pids", + "torch._C._set_worker_signal_handlers", + "torch._C._should_allow_numbers_as_tensors", + "torch._C._show_config", + "torch._C._sparse._sparse_addmm", + "torch._C._sparse._sparse_log_softmax", + "torch._C._sparse._sparse_mm_reduce_impl", + "torch._C._sparse._sparse_mm", + "torch._C._sparse._sparse_softmax", + "torch._C._sparse._spdiags", + "torch._C._sparse.sparse_sampled_addmm", + "torch._C._special.special_airy_ai", + "torch._C._special.special_bessel_j0", + "torch._C._special.special_bessel_j1", + "torch._C._special.special_bessel_y0", + "torch._C._special.special_bessel_y1", + "torch._C._special.special_chebyshev_polynomial_t", + "torch._C._special.special_chebyshev_polynomial_u", + "torch._C._special.special_chebyshev_polynomial_v", 
+ "torch._C._special.special_chebyshev_polynomial_w", + "torch._C._special.special_digamma", + "torch._C._special.special_entr", + "torch._C._special.special_erf", + "torch._C._special.special_erfc", + "torch._C._special.special_erfcx", + "torch._C._special.special_erfinv", + "torch._C._special.special_exp2", + "torch._C._special.special_expit", + "torch._C._special.special_expm1", + "torch._C._special.special_gammainc", + "torch._C._special.special_gammaincc", + "torch._C._special.special_gammaln", + "torch._C._special.special_hermite_polynomial_h", + "torch._C._special.special_hermite_polynomial_he", + "torch._C._special.special_i0", + "torch._C._special.special_i0e", + "torch._C._special.special_i1", + "torch._C._special.special_i1e", + "torch._C._special.special_laguerre_polynomial_l", + "torch._C._special.special_legendre_polynomial_p", + "torch._C._special.special_log_ndtr", + "torch._C._special.special_log_softmax", + "torch._C._special.special_log1p", + "torch._C._special.special_logit", + "torch._C._special.special_logsumexp", + "torch._C._special.special_modified_bessel_i0", + "torch._C._special.special_modified_bessel_i1", + "torch._C._special.special_modified_bessel_k0", + "torch._C._special.special_modified_bessel_k1", + "torch._C._special.special_multigammaln", + "torch._C._special.special_ndtr", + "torch._C._special.special_ndtri", + "torch._C._special.special_polygamma", + "torch._C._special.special_psi", + "torch._C._special.special_round", + "torch._C._special.special_scaled_modified_bessel_k0", + "torch._C._special.special_scaled_modified_bessel_k1", + "torch._C._special.special_shifted_chebyshev_polynomial_t", + "torch._C._special.special_shifted_chebyshev_polynomial_u", + "torch._C._special.special_shifted_chebyshev_polynomial_v", + "torch._C._special.special_shifted_chebyshev_polynomial_w", + "torch._C._special.special_sinc", + "torch._C._special.special_softmax", + "torch._C._special.special_spherical_bessel_j0", + "torch._C._special.special_xlog1py", + "torch._C._special.special_xlogy", + "torch._C._special.special_zeta", + "torch._C._stash_obj_in_tls", + "torch._C._storage_id", + "torch._C._storage_Use_Count", + "torch._C._supported_qengines", + "torch._C._te.abs", + "torch._C._te.acos", + "torch._C._te.annotate_input_shapes", + "torch._C._te.asin", + "torch._C._te.atan", + "torch._C._te.atan2", + "torch._C._te.ceil", + "torch._C._te.Compute", + "torch._C._te.Compute2", + "torch._C._te.construct_codegen", + "torch._C._te.cos", + "torch._C._te.cosh", + "torch._C._te.erf", + "torch._C._te.erfc", + "torch._C._te.exp", + "torch._C._te.expm1", + "torch._C._te.fixup_missing_shape_info", + "torch._C._te.floor", + "torch._C._te.fmod", + "torch._C._te.frac", + "torch._C._te.ifThenElse", + "torch._C._te.is_graph_compilable", + "torch._C._te.isnan", + "torch._C._te.lgamma", + "torch._C._te.log", + "torch._C._te.log10", + "torch._C._te.log1p", + "torch._C._te.log2", + "torch._C._te.lower", + "torch._C._te.make_shapes_symbolic", + "torch._C._te.pow", + "torch._C._te.Reduce", + "torch._C._te.remainder", + "torch._C._te.remove_graph_output", + "torch._C._te.remove_unused_self_argument", + "torch._C._te.replace_list_output_with_tuple", + "torch._C._te.round", + "torch._C._te.rsqrt", + "torch._C._te.sigmoid", + "torch._C._te.simplify", + "torch._C._te.sin", + "torch._C._te.sinh", + "torch._C._te.sqrt", + "torch._C._te.tan", + "torch._C._te.tanh", + "torch._C._te.trim_graph", + "torch._C._te.trunc", + "torch._C._tensor_impl_raw_handle", + 
"torch._C._test_only_add_entry_to_op_version_map", + "torch._C._test_only_populate_upgraders", + "torch._C._test_only_remove_entry_to_op_version_map", + "torch._C._test_only_remove_upgraders", + "torch._C._to_dlpack", + "torch._C._to_functionality_key", + "torch._C._tracer_set_force_outplace", + "torch._C._tracer_set_get_unique_name_fn", + "torch._C._tracer_warn_use_python", + "torch._C._unset_default_mobile_cpu_allocator", + "torch._C._unset_dispatch_mode", + "torch._C._valgrind_supported_platform", + "torch._C._valgrind_toggle_and_dump_stats", + "torch._C._valgrind_toggle", + "torch._C._verbose.mkl_set_verbose", + "torch._C._verbose.mkldnn_set_verbose", + "torch._C._vmapmode_decrement_nesting", + "torch._C._vmapmode_increment_nesting", + "torch._C._warn_deprecation", + "torch._C._warn", + "torch._C._will_engine_execute_node", + "torch._C._wrap_tensor_impl", + "torch._C.fork", + "torch._C.get_autocast_cpu_dtype", + "torch._C.get_autocast_gpu_dtype", + "torch._C.get_autocast_ipu_dtype", + "torch._C.get_autocast_xla_dtype", + "torch._C.get_default_dtype", + "torch._C.get_num_interop_threads", + "torch._C.get_num_threads", + "torch._C.import_ir_module_from_buffer", + "torch._C.import_ir_module", + "torch._C.init_num_threads", + "torch._C.is_anomaly_check_nan_enabled", + "torch._C.is_anomaly_enabled", + "torch._C.is_autocast_cache_enabled", + "torch._C.is_autocast_cpu_enabled", + "torch._C.is_autocast_enabled", + "torch._C.is_autocast_ipu_enabled", + "torch._C.is_autocast_xla_enabled", + "torch._C.is_grad_enabled", + "torch._C.is_inference_mode_enabled", + "torch._C.merge_type_from_type_comment", + "torch._C.parse_ir", + "torch._C.parse_schema", + "torch._C.parse_type_comment", + "torch._C.read_vitals", + "torch._C.set_flush_denormal", + "torch._C.set_num_interop_threads", + "torch._C.set_num_threads", + "torch._C.set_vital", + "torch._C.unify_type_list", + "torch._C.vitals_enabled", + "torch._C.wait", + "torch._cast_Byte", + "torch._cast_Char", + "torch._cast_Double", + "torch._cast_Float", + "torch._cast_Half", + "torch._cast_Int", + "torch._cast_Long", + "torch._cast_Short", + "torch._choose_qparams_per_tensor", + "torch._chunk_cat", + "torch._coalesce", + "torch._compute_linear_combination", + "torch._conj_copy", + "torch._conj_physical", + "torch._conj", + "torch._convert_indices_from_coo_to_csr", + "torch._convert_indices_from_csr_to_coo", + "torch._convert_weight_to_int4pack", + "torch._convolution_mode", + "torch._convolution", + "torch._copy_from_and_resize", + "torch._copy_from", + "torch._cslt_compress", + "torch._cslt_sparse_mm", + "torch._ctc_loss", + "torch._cudnn_ctc_loss", + "torch._cudnn_init_dropout_state", + "torch._cudnn_rnn_flatten_weight", + "torch._cudnn_rnn", + "torch._cufft_clear_plan_cache", + "torch._cufft_get_plan_cache_max_size", + "torch._cufft_get_plan_cache_size", + "torch._cufft_set_plan_cache_max_size", + "torch._cummax_helper", + "torch._cummin_helper", + "torch._debug_has_internal_overlap", + "torch._dim_arange", + "torch._dirichlet_grad", + "torch._disable_functionalization", + "torch._efficientzerotensor", + "torch._embedding_bag_forward_only", + "torch._embedding_bag", + "torch._empty_affine_quantized", + "torch._empty_per_channel_affine_quantized", + "torch._enable_functionalization", + "torch._euclidean_dist", + "torch._fake_quantize_learnable_per_channel_affine", + "torch._fake_quantize_learnable_per_tensor_affine", + "torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "torch._fft_c2c", + "torch._fft_c2r", + "torch._fft_r2c", + 
"torch._fill_mem_eff_dropout_mask_", + "torch._foobar", + "torch._foreach_abs_", + "torch._foreach_abs", + "torch._foreach_acos_", + "torch._foreach_acos", + "torch._foreach_add_", + "torch._foreach_add", + "torch._foreach_addcdiv_", + "torch._foreach_addcdiv", + "torch._foreach_addcmul_", + "torch._foreach_addcmul", + "torch._foreach_asin_", + "torch._foreach_asin", + "torch._foreach_atan_", + "torch._foreach_atan", + "torch._foreach_ceil_", + "torch._foreach_ceil", + "torch._foreach_clamp_max_", + "torch._foreach_clamp_max", + "torch._foreach_clamp_min_", + "torch._foreach_clamp_min", + "torch._foreach_copy_", + "torch._foreach_cos_", + "torch._foreach_cos", + "torch._foreach_cosh_", + "torch._foreach_cosh", + "torch._foreach_div_", + "torch._foreach_div", + "torch._foreach_erf_", + "torch._foreach_erf", + "torch._foreach_erfc_", + "torch._foreach_erfc", + "torch._foreach_exp_", + "torch._foreach_exp", + "torch._foreach_expm1_", + "torch._foreach_expm1", + "torch._foreach_floor_", + "torch._foreach_floor", + "torch._foreach_frac_", + "torch._foreach_frac", + "torch._foreach_lerp_", + "torch._foreach_lerp", + "torch._foreach_lgamma_", + "torch._foreach_lgamma", + "torch._foreach_log_", + "torch._foreach_log", + "torch._foreach_log10_", + "torch._foreach_log10", + "torch._foreach_log1p_", + "torch._foreach_log1p", + "torch._foreach_log2_", + "torch._foreach_log2", + "torch._foreach_maximum_", + "torch._foreach_maximum", + "torch._foreach_minimum_", + "torch._foreach_minimum", + "torch._foreach_mul_", + "torch._foreach_mul", + "torch._foreach_neg_", + "torch._foreach_neg", + "torch._foreach_norm", + "torch._foreach_pow_", + "torch._foreach_pow", + "torch._foreach_reciprocal_", + "torch._foreach_reciprocal", + "torch._foreach_round_", + "torch._foreach_round", + "torch._foreach_sigmoid_", + "torch._foreach_sigmoid", + "torch._foreach_sign_", + "torch._foreach_sign", + "torch._foreach_sin_", + "torch._foreach_sin", + "torch._foreach_sinh_", + "torch._foreach_sinh", + "torch._foreach_sqrt_", + "torch._foreach_sqrt", + "torch._foreach_sub_", + "torch._foreach_sub", + "torch._foreach_tan_", + "torch._foreach_tan", + "torch._foreach_tanh_", + "torch._foreach_tanh", + "torch._foreach_trunc_", + "torch._foreach_trunc", + "torch._foreach_zero_", + "torch._freeze_functional_tensor", + "torch._from_functional_tensor", + "torch._functional_assert_async", + "torch._functional_sym_constrain_range_for_size", + "torch._functional_sym_constrain_range", + "torch._functionalize_are_all_mutations_hidden_from_autograd", + "torch._functionalize_commit_update", + "torch._functionalize_enable_reapply_views", + "torch._functionalize_has_data_mutation", + "torch._functionalize_has_metadata_mutation", + "torch._functionalize_is_multi_output_view", + "torch._functionalize_mark_mutation_hidden_from_autograd", + "torch._functionalize_replace", + "torch._functionalize_sync", + "torch._functionalize_was_storage_changed", + "torch._fused_adam_", + "torch._fused_adamw_", + "torch._fused_dropout", + "torch._fused_moving_avg_obs_fq_helper", + "torch._fused_sdp_choice", + "torch._fw_primal_copy", + "torch._grid_sampler_2d_cpu_fallback", + "torch._has_compatible_shallow_copy_type", + "torch._histogramdd_bin_edges", + "torch._histogramdd_from_bin_cts", + "torch._histogramdd_from_bin_tensors", + "torch._index_put_impl_", + "torch._indices_copy", + "torch._int_mm", + "torch._is_all_true", + "torch._is_any_true", + "torch._is_functional_tensor", + "torch._is_zerotensor", + "torch._linalg_check_errors", + "torch._linalg_det", + 
"torch._linalg_eigh", + "torch._linalg_slogdet", + "torch._linalg_solve_ex", + "torch._linalg_svd", + "torch._log_softmax_backward_data", + "torch._log_softmax", + "torch._logcumsumexp", + "torch._lstm_mps", + "torch._lu_with_info", + "torch._make_dep_token", + "torch._make_dual_copy", + "torch._make_dual", + "torch._make_per_channel_quantized_tensor", + "torch._make_per_tensor_quantized_tensor", + "torch._masked_scale", + "torch._masked_softmax", + "torch._mirror_autograd_meta_to", + "torch._mixed_dtypes_linear", + "torch._mkldnn_reshape", + "torch._mkldnn_transpose_", + "torch._mkldnn_transpose", + "torch._mps_convolution_transpose", + "torch._mps_convolution", + "torch._native_batch_norm_legit_no_training", + "torch._native_batch_norm_legit", + "torch._native_multi_head_attention", + "torch._neg_view_copy", + "torch._neg_view", + "torch._nested_from_padded_and_nested_example", + "torch._nested_tensor_from_mask_left_aligned", + "torch._nested_tensor_from_tensor_list", + "torch._nested_tensor_softmax_with_shape", + "torch._nested_view_from_buffer_copy", + "torch._nested_view_from_buffer", + "torch._nnpack_available", + "torch._nnpack_spatial_convolution", + "torch._pack_padded_sequence", + "torch._pad_packed_sequence", + "torch._pin_memory", + "torch._prelu_kernel", + "torch._propagate_xla_data", + "torch._remove_batch_dim", + "torch._reshape_alias_copy", + "torch._reshape_from_tensor", + "torch._resize_output_", + "torch._rowwise_prune", + "torch._sample_dirichlet", + "torch._saturate_weight_to_fp16", + "torch._scaled_dot_product_attention_math", + "torch._scaled_dot_product_efficient_attention", + "torch._scaled_dot_product_flash_attention", + "torch._scaled_dot_product_flash_attention_for_cpu", + "torch._scaled_dot_product_cudnn_attention", + "torch._scaled_mm", + "torch._shape_as_tensor", + "torch._sobol_engine_draw", + "torch._sobol_engine_ff_", + "torch._sobol_engine_initialize_state_", + "torch._sobol_engine_scramble_", + "torch._softmax_backward_data", + "torch._softmax", + "torch._sparse_broadcast_to_copy", + "torch._sparse_broadcast_to", + "torch._sparse_csr_prod", + "torch._sparse_csr_sum", + "torch._sparse_log_softmax_backward_data", + "torch._sparse_semi_structured_linear", + "torch._sparse_softmax_backward_data", + "torch._sparse_sparse_matmul", + "torch._sparse_sum", + "torch._stack", + "torch._standard_gamma_grad", + "torch._standard_gamma", + "torch._test_autograd_multiple_dispatch_view_copy", + "torch._test_autograd_multiple_dispatch_view", + "torch._test_autograd_multiple_dispatch", + "torch._test_check_tensor", + "torch._test_functorch_fallback", + "torch._test_serialization_subcmul", + "torch._to_cpu", + "torch._to_functional_tensor", + "torch._to_sparse_semi_structured", + "torch._transform_bias_rescale_qkv", + "torch._transformer_encoder_layer_fwd", + "torch._trilinear", + "torch._triton_multi_head_attention", + "torch._triton_scaled_dot_attention", + "torch._unique", + "torch._unique2", + "torch._unpack_dual", + "torch._unsafe_index_put", + "torch._unsafe_index", + "torch._use_cudnn_ctc_loss", + "torch._use_cudnn_rnn_flatten_weight", + "torch._values_copy", + "torch._weight_int4pack_mm", + "torch._weight_int8pack_mm", + "torch._weight_norm_interface", + "torch._weight_norm", + "torch.abs_", + "torch.abs", + "torch.absolute", + "torch.acos_", + "torch.acos", + "torch.acosh_", + "torch.acosh", + "torch.adaptive_avg_pool1d", + "torch.adaptive_max_pool1d", + "torch.add", + "torch.addbmm", + "torch.addcdiv", + "torch.addcmul", + "torch.addmm", + "torch.addmv_", + 
"torch.addmv", + "torch.addr", + "torch.adjoint", + "torch.affine_grid_generator", + "torch.alias_copy", + "torch.all", + "torch.allclose", + "torch.alpha_dropout_", + "torch.alpha_dropout", + "torch.amax", + "torch.amin", + "torch.aminmax", + "torch.angle", + "torch.any", + "torch.arange", + "torch.arccos_", + "torch.arccos", + "torch.arccosh_", + "torch.arccosh", + "torch.arcsin_", + "torch.arcsin", + "torch.arcsinh_", + "torch.arcsinh", + "torch.arctan_", + "torch.arctan", + "torch.arctan2", + "torch.arctanh_", + "torch.arctanh", + "torch.argmax", + "torch.argmin", + "torch.argsort", + "torch.argwhere", + "torch.as_strided_", + "torch.as_strided_copy", + "torch.as_strided_scatter", + "torch.as_strided", + "torch.as_tensor", + "torch.asarray", + "torch.asin_", + "torch.asin", + "torch.asinh_", + "torch.asinh", + "torch.atan_", + "torch.atan", + "torch.atan2", + "torch.atanh_", + "torch.atanh", + "torch.avg_pool1d", + "torch.baddbmm", + "torch.bartlett_window", + "torch.batch_norm_backward_elemt", + "torch.batch_norm_backward_reduce", + "torch.batch_norm_elemt", + "torch.batch_norm_gather_stats_with_counts", + "torch.batch_norm_gather_stats", + "torch.batch_norm_stats", + "torch.batch_norm_update_stats", + "torch.batch_norm", + "torch.bernoulli", + "torch.bilinear", + "torch.binary_cross_entropy_with_logits", + "torch.bincount", + "torch.binomial", + "torch.bitwise_and", + "torch.bitwise_left_shift", + "torch.bitwise_not", + "torch.bitwise_or", + "torch.bitwise_right_shift", + "torch.bitwise_xor", + "torch.blackman_window", + "torch.bmm", + "torch.broadcast_to", + "torch.bucketize", + "torch.can_cast", + "torch.cat", + "torch.ccol_indices_copy", + "torch.ceil_", + "torch.ceil", + "torch.celu_", + "torch.celu", + "torch.channel_shuffle", + "torch.cholesky_inverse", + "torch.cholesky_solve", + "torch.cholesky", + "torch.choose_qparams_optimized", + "torch.chunk", + "torch.clamp_", + "torch.clamp_max_", + "torch.clamp_max", + "torch.clamp_min_", + "torch.clamp_min", + "torch.clamp", + "torch.clip_", + "torch.clip", + "torch.clone", + "torch.col_indices_copy", + "torch.column_stack", + "torch.combinations", + "torch.complex", + "torch.concat", + "torch.concatenate", + "torch.conj_physical_", + "torch.conj_physical", + "torch.conj", + "torch.constant_pad_nd", + "torch.conv_tbc", + "torch.conv_transpose1d", + "torch.conv_transpose2d", + "torch.conv_transpose3d", + "torch.conv1d", + "torch.conv2d", + "torch.conv3d", + "torch.convolution", + "torch.copysign", + "torch.corrcoef", + "torch.cos_", + "torch.cos", + "torch.cosh_", + "torch.cosh", + "torch.cosine_embedding_loss", + "torch.cosine_similarity", + "torch.count_nonzero", + "torch.cov", + "torch.cross", + "torch.crow_indices_copy", + "torch.ctc_loss", + "torch.cudnn_affine_grid_generator", + "torch.cudnn_batch_norm", + "torch.cudnn_convolution_add_relu", + "torch.cudnn_convolution_relu", + "torch.cudnn_convolution_transpose", + "torch.cudnn_convolution", + "torch.cudnn_grid_sampler", + "torch.cudnn_is_acceptable", + "torch.cummax", + "torch.cummin", + "torch.cumprod", + "torch.cumsum", + "torch.cumulative_trapezoid", + "torch.deg2rad_", + "torch.deg2rad", + "torch.dequantize", + "torch.det", + "torch.detach_", + "torch.detach_copy", + "torch.detach", + "torch.diag_embed", + "torch.diag", + "torch.diagflat", + "torch.diagonal_copy", + "torch.diagonal_scatter", + "torch.diagonal", + "torch.diff", + "torch.digamma", + "torch.dist", + "torch.div", + "torch.divide", + "torch.dot", + "torch.dropout_", + "torch.dropout", + "torch.dsmm", + 
"torch.dsplit", + "torch.dstack", + "torch.embedding_bag", + "torch.embedding_renorm_", + "torch.embedding", + "torch.empty_like", + "torch.empty_permuted", + "torch.empty_quantized", + "torch.empty_strided", + "torch.empty", + "torch.eq", + "torch.equal", + "torch.erf_", + "torch.erf", + "torch.erfc_", + "torch.erfc", + "torch.erfinv", + "torch.exp_", + "torch.exp", + "torch.exp2_", + "torch.exp2", + "torch.expand_copy", + "torch.expm1_", + "torch.expm1", + "torch.eye", + "torch.fake_quantize_per_channel_affine", + "torch.fake_quantize_per_tensor_affine", + "torch.fbgemm_linear_fp16_weight_fp32_activation", + "torch.fbgemm_linear_fp16_weight", + "torch.fbgemm_linear_int8_weight_fp32_activation", + "torch.fbgemm_linear_int8_weight", + "torch.fbgemm_linear_quantize_weight", + "torch.fbgemm_pack_gemm_matrix_fp16", + "torch.fbgemm_pack_quantized_matrix", + "torch.feature_alpha_dropout_", + "torch.feature_alpha_dropout", + "torch.feature_dropout_", + "torch.feature_dropout", + "torch.fill_", + "torch.fill", + "torch.fix_", + "torch.fix", + "torch.flatten", + "torch.flip", + "torch.fliplr", + "torch.flipud", + "torch.float_power", + "torch.floor_", + "torch.floor_divide", + "torch.floor", + "torch.fmax", + "torch.fmin", + "torch.fmod", + "torch.frac_", + "torch.frac", + "torch.frexp", + "torch.frobenius_norm", + "torch.from_file", + "torch.from_numpy", + "torch.frombuffer", + "torch.full_like", + "torch.full", + "torch.fused_moving_avg_obs_fake_quant", + "torch.gather", + "torch.gcd_", + "torch.gcd", + "torch.ge", + "torch.geqrf", + "torch.ger", + "torch.get_device", + "torch.gradient", + "torch.greater_equal", + "torch.greater", + "torch.grid_sampler_2d", + "torch.grid_sampler_3d", + "torch.grid_sampler", + "torch.group_norm", + "torch.gru_cell", + "torch.gru", + "torch.gt", + "torch.hamming_window", + "torch.hann_window", + "torch.hardshrink", + "torch.heaviside", + "torch.hinge_embedding_loss", + "torch.histc", + "torch.histogram", + "torch.histogramdd", + "torch.hsmm", + "torch.hsplit", + "torch.hspmm", + "torch.hstack", + "torch.hypot", + "torch.i0_", + "torch.i0", + "torch.igamma", + "torch.igammac", + "torch.imag", + "torch.index_add", + "torch.index_copy", + "torch.index_fill", + "torch.index_put_", + "torch.index_put", + "torch.index_reduce", + "torch.index_select", + "torch.indices_copy", + "torch.inner", + "torch.instance_norm", + "torch.int_repr", + "torch.inverse", + "torch.is_complex", + "torch.is_conj", + "torch.is_distributed", + "torch.is_floating_point", + "torch.is_inference", + "torch.is_neg", + "torch.is_nonzero", + "torch.is_same_size", + "torch.is_signed", + "torch.is_vulkan_available", + "torch.isclose", + "torch.isfinite", + "torch.isin", + "torch.isinf", + "torch.isnan", + "torch.isneginf", + "torch.isposinf", + "torch.isreal", + "torch.istft", + "torch.kaiser_window", + "torch.kl_div", + "torch.kron", + "torch.kthvalue", + "torch.layer_norm", + "torch.lcm_", + "torch.lcm", + "torch.ldexp_", + "torch.ldexp", + "torch.le", + "torch.lerp", + "torch.less_equal", + "torch.less", + "torch.lgamma", + "torch.linspace", + "torch.log_", + "torch.log_softmax", + "torch.log", + "torch.log10_", + "torch.log10", + "torch.log1p_", + "torch.log1p", + "torch.log2_", + "torch.log2", + "torch.logaddexp", + "torch.logaddexp2", + "torch.logcumsumexp", + "torch.logdet", + "torch.logical_and", + "torch.logical_not", + "torch.logical_or", + "torch.logical_xor", + "torch.logit_", + "torch.logit", + "torch.logspace", + "torch.logsumexp", + "torch.lstm_cell", + "torch.lstm", + "torch.lt", + 
"torch.lu_solve", + "torch.lu_unpack", + "torch.margin_ranking_loss", + "torch.masked_fill", + "torch.masked_scatter", + "torch.masked_select", + "torch.matmul", + "torch.matrix_exp", + "torch.matrix_power", + "torch.max_pool1d_with_indices", + "torch.max_pool1d", + "torch.max_pool2d", + "torch.max_pool3d", + "torch.max", + "torch.maximum", + "torch.mean", + "torch.median", + "torch.min", + "torch.minimum", + "torch.miopen_batch_norm", + "torch.miopen_convolution_add_relu", + "torch.miopen_convolution_relu", + "torch.miopen_convolution_transpose", + "torch.miopen_convolution", + "torch.miopen_depthwise_convolution", + "torch.miopen_rnn", + "torch.mkldnn_adaptive_avg_pool2d", + "torch.mkldnn_convolution", + "torch.mkldnn_linear_backward_weights", + "torch.mkldnn_max_pool2d", + "torch.mkldnn_max_pool3d", + "torch.mkldnn_rnn_layer", + "torch.mm", + "torch.mode", + "torch.moveaxis", + "torch.movedim", + "torch.msort", + "torch.mul", + "torch.multinomial", + "torch.multiply", + "torch.mv", + "torch.mvlgamma", + "torch.nan_to_num_", + "torch.nan_to_num", + "torch.nanmean", + "torch.nanmedian", + "torch.nanquantile", + "torch.nansum", + "torch.narrow_copy", + "torch.narrow", + "torch.native_batch_norm", + "torch.native_channel_shuffle", + "torch.native_dropout", + "torch.native_group_norm", + "torch.native_layer_norm", + "torch.native_norm", + "torch.ne", + "torch.neg_", + "torch.neg", + "torch.negative_", + "torch.negative", + "torch.nextafter", + "torch.nonzero_static", + "torch.nonzero", + "torch.norm_except_dim", + "torch.normal", + "torch.not_equal", + "torch.nuclear_norm", + "torch.numel", + "torch.obj", + "torch.ones_like", + "torch.ones", + "torch.orgqr", + "torch.ormqr", + "torch.outer", + "torch.pairwise_distance", + "torch.pdist", + "torch.permute_copy", + "torch.permute", + "torch.pinverse", + "torch.pixel_shuffle", + "torch.pixel_unshuffle", + "torch.poisson_nll_loss", + "torch.poisson", + "torch.polar", + "torch.polygamma", + "torch.positive", + "torch.pow", + "torch.prelu", + "torch._print", + "torch.prod", + "torch.promote_types", + "torch.put", + "torch.q_per_channel_axis", + "torch.q_per_channel_scales", + "torch.q_per_channel_zero_points", + "torch.q_scale", + "torch.q_zero_point", + "torch.qr", + "torch.quantile", + "torch.quantize_per_channel", + "torch.quantize_per_tensor_dynamic", + "torch.quantize_per_tensor", + "torch.quantized_batch_norm", + "torch.quantized_gru_cell", + "torch.quantized_lstm_cell", + "torch.quantized_max_pool1d", + "torch.quantized_max_pool2d", + "torch.quantized_max_pool3d", + "torch.quantized_rnn_relu_cell", + "torch.quantized_rnn_tanh_cell", + "torch.rad2deg_", + "torch.rad2deg", + "torch.rand_like", + "torch.rand", + "torch.randint_like", + "torch.randint", + "torch.randn_like", + "torch.randn", + "torch.randperm", + "torch.range", + "torch.ravel", + "torch.real", + "torch.reciprocal_", + "torch.reciprocal", + "torch.relu_", + "torch.relu", + "torch.remainder", + "torch.renorm", + "torch.repeat_interleave", + "torch.reshape", + "torch.resolve_conj", + "torch.resolve_neg", + "torch.result_type", + "torch.rnn_relu_cell", + "torch.rnn_relu", + "torch.rnn_tanh_cell", + "torch.rnn_tanh", + "torch.roll", + "torch.rot90", + "torch.round_", + "torch.round", + "torch.row_indices_copy", + "torch.row_stack", + "torch.rrelu_", + "torch.rrelu", + "torch.rsqrt_", + "torch.rsqrt", + "torch.rsub", + "torch.saddmm", + "torch.scalar_tensor", + "torch.scatter_add", + "torch.scatter_reduce", + "torch.scatter", + "torch.searchsorted", + "torch.segment_reduce", + 
"torch.select_copy", + "torch.select_scatter", + "torch.select", + "torch.selu_", + "torch.selu", + "torch.sgn", + "torch.sigmoid_", + "torch.sigmoid", + "torch.sign", + "torch.signal.windows.windows.sqrt", + "torch.signbit", + "torch.sin_", + "torch.sin", + "torch.sinc_", + "torch.sinc", + "torch.sinh_", + "torch.sinh", + "torch.slice_copy", + "torch.slice_scatter", + "torch.slogdet", + "torch.smm", + "torch.softmax", + "torch.sort", + "torch.split_copy", + "torch.split_with_sizes_copy", + "torch.split_with_sizes", + "torch.spmm", + "torch.sqrt_", + "torch.sqrt", + "torch.square_", + "torch.square", + "torch.squeeze_copy", + "torch.squeeze", + "torch.sspaddmm", + "torch.stack", + "torch.std_mean", + "torch.std", + "torch.sub", + "torch.subtract", + "torch.sum", + "torch.svd", + "torch.swapaxes", + "torch.swapdims", + "torch.sym_constrain_range_for_size", + "torch.sym_constrain_range", + "torch.t_copy", + "torch.t", + "torch.take_along_dim", + "torch.take", + "torch.tan_", + "torch.tan", + "torch.tanh_", + "torch.tanh", + "torch.tensor_split", + "torch.tensor", + "torch.threshold_", + "torch.threshold", + "torch.tile", + "torch.topk", + "torch.trace", + "torch.transpose_copy", + "torch.transpose", + "torch.trapezoid", + "torch.trapz", + "torch.triangular_solve", + "torch.tril_indices", + "torch.tril", + "torch.triplet_margin_loss", + "torch.triu_indices", + "torch.triu", + "torch.true_divide", + "torch.trunc_", + "torch.trunc", + "torch.unbind_copy", + "torch.unbind", + "torch.unflatten", + "torch.unfold_copy", + "torch.unsafe_chunk", + "torch.unsafe_split_with_sizes", + "torch.unsafe_split", + "torch.unsqueeze_copy", + "torch.unsqueeze", + "torch.values_copy", + "torch.vander", + "torch.var_mean", + "torch.var", + "torch.vdot", + "torch.view_as_complex_copy", + "torch.view_as_complex", + "torch.view_as_real_copy", + "torch.view_as_real", + "torch.view_copy", + "torch.vsplit", + "torch.vstack", + "torch.where", + "torch.xlogy_", + "torch.xlogy", + "torch.zero_", + "torch.zeros", + "torch._fused_sgd_", + "torch.slice_inverse", + "torch._assert_scalar", + "torch._functional_assert_scalar", + ], + TorchInGraphFunctionVariable, +) + + +if sys.version_info >= (3, 9): + torch_c_binding_in_graph_functions["math.lcm"] = TorchInGraphFunctionVariable +if sys.version_info >= (3, 11): + torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable + torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable + + +# In graph functions (including constant folding) that are not C bindings +torch_non_c_binding_in_graph_functions = dict.fromkeys( + [ + "torch.__future__.get_overwrite_module_params_on_conversion", + "torch.__future__.set_overwrite_module_params_on_conversion", + "torch.__getattr__", + "torch._assert", + "torch._check_index", + "torch._check_is_size", + "torch._check_not_implemented", + "torch._check_tensor_all_with", + "torch._check_tensor_all", + "torch._check_type", + "torch._check_value", + "torch._check_with", + "torch._check", + "torch._compile._disable_dynamo", + "torch._functorch.apis.chunk_vmap", + "torch._functorch.autograd_function.custom_function_call_functionalize", + "torch._functorch.autograd_function.custom_function_call_grad", + "torch._functorch.autograd_function.custom_function_call_vmap_generate_rule", + "torch._functorch.autograd_function.custom_function_call_vmap", + "torch._functorch.autograd_function.generate_single_level_function", + "torch._functorch.autograd_function.get_tangents_in_dims", + 
"torch._functorch.autograd_function.has_overriden_vmap_rule", + "torch._functorch.autograd_function.reductify_leaf", + "torch._functorch.autograd_function.reductify", + "torch._functorch.autograd_function.validate_vmap_returns_tuple_of_two_elements", + "torch._functorch.autograd_function.vmapify_autograd_function", + "torch._functorch.autograd_function.wrap_outputs_maintaining_identity", + "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats", + "torch._functorch.batch_norm_replacement.replace_all_batch_norm_modules_", + "torch._functorch.deprecated.combine_state_for_ensemble", + "torch._functorch.deprecated.functionalize", + "torch._functorch.deprecated.get_warning", + "torch._functorch.deprecated.grad_and_value", + "torch._functorch.deprecated.hessian", + "torch._functorch.deprecated.jacfwd", + "torch._functorch.deprecated.jacrev", + "torch._functorch.deprecated.jvp", + "torch._functorch.deprecated.make_functional_with_buffers", + "torch._functorch.deprecated.make_functional", + "torch._functorch.deprecated.setup_docs", + "torch._functorch.deprecated.vjp", + "torch._functorch.deprecated.warn_deprecated", + "torch._functorch.eager_transforms._any_differentiable", + "torch._functorch.eager_transforms._autograd_grad", + "torch._functorch.eager_transforms._construct_standard_basis_for", + "torch._functorch.eager_transforms._vjp_treespec_compare", + "torch._functorch.eager_transforms._set_tensor_requires_grad", + "torch._functorch.eager_transforms._is_differentiable", + "torch._functorch.eager_transforms._jvp_with_argnums", + "torch._functorch.eager_transforms._maybe_unwrap_functional_tensor", + "torch._functorch.eager_transforms._maybe_wrap_functional_tensor", + "torch._functorch.eager_transforms._replace_args", + "torch._functorch.eager_transforms._unwrap_all_tensors_from_functional", + "torch._functorch.eager_transforms._wrap_all_tensors_to_functional", + "torch._functorch.eager_transforms.assert_flat_tuple_of_tensors", + "torch._functorch.eager_transforms.assert_non_empty_list_of_tensors", + "torch._functorch.eager_transforms.assert_output_is_tensor_or_tensors", + "torch._functorch.eager_transforms.functionalize", + "torch._functorch.eager_transforms.hessian", + "torch._functorch.eager_transforms.jacfwd", + "torch._functorch.eager_transforms.jvp", + "torch._functorch.eager_transforms.lazy_dynamo_disable", + "torch._functorch.eager_transforms.linearize", + "torch._functorch.eager_transforms.noop", + "torch._functorch.eager_transforms.safe_unflatten", + "torch._functorch.eager_transforms.safe_unpack_dual", + "torch._functorch.functional_call.construct_stacked_leaf", + "torch._functorch.functional_call.functional_call", + "torch._functorch.functional_call.stack_module_state", + "torch._functorch.pyfunctorch.coerce_cinterpreter", + "torch._functorch.pyfunctorch.dispatch_functorch", + "torch._functorch.pyfunctorch.nested", + "torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter", + "torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack", + "torch._functorch.utils.enable_single_level_autograd_function", + "torch._functorch.utils.exposed_in", + "torch._functorch.utils.unwrap_dead_wrappers", + "torch._functorch.vmap.lazy_load_decompositions", + "torch._guards.compile_context", + "torch._guards.detect_fake_mode", + "torch._guards.tracing", + "torch._higher_order_ops.map._has_potential_branch_input_alias", + "torch._higher_order_ops.map._has_potential_branch_input_mutation", + "torch._higher_order_ops.map._stack_pytree", + 
"torch._higher_order_ops.map._unstack_pytree", + "torch._higher_order_ops.map.create_fw_bw_graph", + "torch._higher_order_ops.map.map_autograd", + "torch._higher_order_ops.map.map_dense", + "torch._higher_order_ops.map.map_fake_tensor_mode", + "torch._higher_order_ops.map.map_functionalize", + "torch._higher_order_ops.map.map_proxy_torch_dispatch_mode", + "torch._higher_order_ops.map.map_wrapper", + "torch._higher_order_ops.map.trace_map", + "torch._higher_order_ops.out_dtype.elementwise_dtypes", + "torch._higher_order_ops.out_dtype.is_int_mm", + "torch._higher_order_ops.out_dtype.out_dtype_dense", + "torch._higher_order_ops.out_dtype.out_dtype_fake_tensor_mode", + "torch._higher_order_ops.out_dtype.out_dtype_fallback", + "torch._higher_order_ops.out_dtype.out_dtype_func", + "torch._higher_order_ops.out_dtype.out_dtype_proxy", + "torch._higher_order_ops.out_dtype.trace_out_dtype", + "torch._higher_order_ops.utils.autograd_not_implemented_inner", + "torch._higher_order_ops.utils.autograd_not_implemented", + "torch._linalg_utils._symeig", + "torch._linalg_utils.basis", + "torch._linalg_utils.bform", + "torch._linalg_utils.conjugate", + "torch._linalg_utils.eig", + "torch._linalg_utils.get_floating_dtype", + "torch._linalg_utils.is_sparse", + "torch._linalg_utils.lstsq", + "torch._linalg_utils.matmul", + "torch._linalg_utils.matrix_rank", + "torch._linalg_utils.qform", + "torch._linalg_utils.solve", + "torch._linalg_utils.symeig", + "torch._linalg_utils.transjugate", + "torch._linalg_utils.transpose", + "torch._load_global_deps", + "torch._lowrank._svd_lowrank", + "torch._lowrank.get_approximate_basis", + "torch._lowrank.pca_lowrank", + "torch._lowrank.svd_lowrank", + "torch._ops._compute_keyset", + "torch._ops._get_tensors", + "torch._ops._to_flat_tuple", + "torch._ops.add_cached_op", + "torch._ops.dl_open_guard", + "torch._ops.get_cached_ops", + "torch._ops.key_extractor", + "torch._ops.reset_cached_ops", + "torch._ops.resolve_key", + "torch._preload_cuda_deps", + "torch._register_device_module", + "torch._running_with_deploy", + "torch._utils._dummy_type", + "torch._weights_only_unpickler._get_allowed_globals", + "torch._weights_only_unpickler.load", + "torch.align_tensors", + "torch.amp.autocast_mode._enter_autocast", + "torch.amp.autocast_mode._exit_autocast", + "torch.amp.autocast_mode.autocast_decorator", + "torch.are_deterministic_algorithms_enabled", + "torch.atleast_1d", + "torch.atleast_2d", + "torch.atleast_3d", + "torch.autograd._calculate_shape", + "torch.autograd._is_checkpoint_valid", + "torch.autograd._make_grads", + "torch.autograd._register_py_tensor_class_for_device", + "torch.autograd._tensor_or_tensors_to_tuple", + "torch.autograd.backward", + "torch.autograd.forward_ad.enter_dual_level", + "torch.autograd.forward_ad.exit_dual_level", + "torch.autograd.forward_ad.make_dual", + "torch.autograd.forward_ad.unpack_dual", + "torch.autograd.function._iter_filter", + "torch.autograd.function._iter_jit_values", + "torch.autograd.function._iter_None_tensors", + "torch.autograd.function._iter_tensors_permissive", + "torch.autograd.function._iter_tensors", + "torch.autograd.function._jit_unwrap_structured", + "torch.autograd.function._map_tensor_data", + "torch.autograd.function._nested_map", + "torch.autograd.function._unflatten", + "torch.autograd.function.once_differentiable", + "torch.autograd.function.traceable", + "torch.autograd.functional._as_tuple_nocheck", + "torch.autograd.functional._as_tuple", + "torch.autograd.functional._autograd_grad", + 
"torch.autograd.functional._check_requires_grad", + "torch.autograd.functional._construct_standard_basis_for", + "torch.autograd.functional._fill_in_zeros", + "torch.autograd.functional._grad_postprocess", + "torch.autograd.functional._grad_preprocess", + "torch.autograd.functional._jacfwd", + "torch.autograd.functional._tuple_postprocess", + "torch.autograd.functional._validate_v", + "torch.autograd.functional.hessian", + "torch.autograd.functional.hvp", + "torch.autograd.functional.jacobian", + "torch.autograd.functional.jvp", + "torch.autograd.functional.vhp", + "torch.autograd.functional.vjp", + "torch.autograd.grad_mode._enter_inference_mode", + "torch.autograd.grad_mode._exit_inference_mode", + "torch.autograd.graph._get_sid", + "torch.autograd.graph._get_tid", + "torch.autograd.graph.allow_mutation_on_saved_tensors", + "torch.autograd.graph.get_gradient_edge", + "torch.autograd.graph.increment_version", + "torch.autograd.graph.register_multi_grad_hook", + "torch.autograd.variable", + "torch.backends.__allow_nonbracketed_mutation", + "torch.backends.cpu.get_cpu_capability", + "torch.backends.cuda.can_use_efficient_attention", + "torch.backends.cuda.can_use_flash_attention", + "torch.backends.cuda.enable_flash_sdp", + "torch.backends.cuda.enable_math_sdp", + "torch.backends.cuda.enable_mem_efficient_sdp", + "torch.backends.cuda.flash_sdp_enabled", + "torch.backends.cuda.is_built", + "torch.backends.cuda.math_sdp_enabled", + "torch.backends.cuda.mem_efficient_sdp_enabled", + "torch.backends.cuda.cudnn_sdp_enabled", + "torch.backends.cuda.enable_cudnn_sdp", + "torch.backends.cuda.preferred_linalg_library", + "torch.backends.cuda.sdp_kernel", + "torch.backends.cudnn._init", + "torch.backends.cudnn.flags", + "torch.backends.cudnn.is_acceptable", + "torch.backends.cudnn.is_available", + "torch.backends.cudnn.set_flags", + "torch.backends.cudnn.version", + "torch.backends.disable_global_flags", + "torch.backends.flags_frozen", + "torch.backends.mkl.is_available", + "torch.backends.mkldnn.flags", + "torch.backends.mkldnn.is_available", + "torch.backends.mkldnn.set_flags", + "torch.backends.mps._init", + "torch.backends.mps.is_available", + "torch.backends.mps.is_built", + "torch.backends.mps.is_macos13_or_newer", + "torch.backends.openmp.is_available", + "torch.backends.quantized._get_qengine_id", + "torch.backends.quantized._get_qengine_str", + "torch.block_diag", + "torch.broadcast_tensors", + "torch.cartesian_prod", + "torch.cdist", + "torch.chain_matmul", + "torch.compile", + "torch.compiled_with_cxx11_abi", + "torch.cpu._is_cpu_support_vnni", + "torch.cpu.current_device", + "torch.cpu.current_stream", + "torch.cpu.device_count", + "torch.cpu.is_available", + "torch.cpu.set_device", + "torch.cpu.stream", + "torch.cpu.synchronize", + "torch.cuda._check_capability", + "torch.cuda._check_cubins", + "torch.cuda._device_count_nvml", + "torch.cuda._get_device", + "torch.cuda._get_generator", + "torch.cuda._get_nvml_device_index", + "torch.cuda._get_pynvml_handler", + "torch.cuda._get_rng_state_offset", + "torch.cuda._is_compiled", + "torch.cuda._lazy_call", + "torch.cuda._lazy_init", + "torch.cuda._memory_viz._block_extra_legacy", + "torch.cuda._memory_viz._block_extra", + "torch.cuda._memory_viz._format_size", + "torch.cuda._memory_viz._format_viz", + "torch.cuda._memory_viz._frame_filter", + "torch.cuda._memory_viz._frame_fmt", + "torch.cuda._memory_viz._frames_fmt", + "torch.cuda._memory_viz._profile_to_snapshot", + "torch.cuda._memory_viz._report_free", + 
"torch.cuda._memory_viz._write_blocks", + "torch.cuda._memory_viz.calc_active", + "torch.cuda._memory_viz.compare", + "torch.cuda._memory_viz.format_flamegraph", + "torch.cuda._memory_viz.memory", + "torch.cuda._memory_viz.profile_plot", + "torch.cuda._memory_viz.segment_plot", + "torch.cuda._memory_viz.segments", + "torch.cuda._memory_viz.segsum", + "torch.cuda._memory_viz.trace_plot", + "torch.cuda._memory_viz.trace", + "torch.cuda._nvml_based_avail", + "torch.cuda._parse_visible_devices", + "torch.cuda._raw_device_count_nvml", + "torch.cuda._raw_device_uuid_nvml", + "torch.cuda._register_triton_kernels", + "torch.cuda._set_rng_state_offset", + "torch.cuda._set_stream_by_id", + "torch.cuda._sleep", + "torch.cuda._transform_uuid_to_ordinals", + "torch.cuda._utils._get_device_index", + "torch.cuda.amp.autocast_mode._cast", + "torch.cuda.amp.autocast_mode.custom_bwd", + "torch.cuda.amp.autocast_mode.custom_fwd", + "torch.cuda.amp.common.amp_definitely_not_available", + "torch.amp.grad_scaler._refresh_per_optimizer_state", + "torch.cuda.can_device_access_peer", + "torch.cuda.check_error", + "torch.cuda.clock_rate", + "torch.cuda.cudart", + "torch.cuda.current_blas_handle", + "torch.cuda.current_stream", + "torch.cuda.default_stream", + "torch.cuda.device_count", + "torch.cuda.get_arch_list", + "torch.cuda.get_device_capability", + "torch.cuda.get_device_name", + "torch.cuda.get_device_properties", + "torch.cuda.get_gencode_flags", + "torch.cuda.get_sync_debug_mode", + "torch.cuda.graphs.graph_pool_handle", + "torch.cuda.graphs.is_current_stream_capturing", + "torch.cuda.graphs.make_graphed_callables", + "torch.cuda.init", + "torch.cuda.ipc_collect", + "torch.cuda.is_available", + "torch.cuda.is_bf16_supported", + "torch.cuda.is_initialized", + "torch.cuda.jiterator._create_jit_fn", + "torch.cuda.jiterator._create_multi_output_jit_fn", + "torch.cuda.memory_usage", + "torch.cuda.memory._dump_snapshot", + "torch.cuda.memory._free_mutex", + "torch.cuda.memory._get_current_allocator", + "torch.cuda.memory._host_allocator", + "torch.cuda.memory._record_memory_history_impl", + "torch.cuda.memory._record_memory_history_legacy", + "torch.cuda.memory._record_memory_history", + "torch.cuda.memory._save_memory_usage", + "torch.cuda.memory._save_segment_usage", + "torch.cuda.memory._set_allocator_settings", + "torch.cuda.memory._snapshot", + "torch.cuda.memory.caching_allocator_alloc", + "torch.cuda.memory.caching_allocator_delete", + "torch.cuda.memory.change_current_allocator", + "torch.cuda.memory.empty_cache", + "torch.cuda.memory.get_allocator_backend", + "torch.cuda.memory.list_gpu_processes", + "torch.cuda.memory.max_memory_allocated", + "torch.cuda.memory.max_memory_cached", + "torch.cuda.memory.max_memory_reserved", + "torch.cuda.memory.mem_get_info", + "torch.cuda.memory.memory_allocated", + "torch.cuda.memory.memory_cached", + "torch.cuda.memory.memory_reserved", + "torch.cuda.memory.memory_snapshot", + "torch.cuda.memory.memory_stats_as_nested_dict", + "torch.cuda.memory.memory_stats", + "torch.cuda.memory.memory_summary", + "torch.cuda.memory.reset_accumulated_memory_stats", + "torch.cuda.memory.reset_max_memory_allocated", + "torch.cuda.memory.reset_max_memory_cached", + "torch.cuda.memory.reset_peak_memory_stats", + "torch.cuda.memory.set_per_process_memory_fraction", + "torch.cuda.nccl._check_sequence_type", + "torch.cuda.nccl.all_gather", + "torch.cuda.nccl.all_reduce", + "torch.cuda.nccl.broadcast", + "torch.cuda.nccl.init_rank", + "torch.cuda.nccl.is_available", + 
"torch.cuda.nccl.reduce_scatter", + "torch.cuda.nccl.reduce", + "torch.cuda.nccl.unique_id", + "torch.cuda.nccl.version", + "torch.cuda.nvtx.mark", + "torch.cuda.nvtx.range_end", + "torch.cuda.nvtx.range_pop", + "torch.cuda.nvtx.range_push", + "torch.cuda.nvtx.range_start", + "torch.cuda.nvtx.range", + "torch.cuda.power_draw", + "torch.cuda.profiler.init", + "torch.cuda.profiler.profile", + "torch.cuda.profiler.start", + "torch.cuda.profiler.stop", + "torch.cuda.random.get_rng_state_all", + "torch.cuda.random.initial_seed", + "torch.cuda.random.manual_seed_all", + "torch.cuda.random.manual_seed", + "torch.cuda.random.seed_all", + "torch.cuda.random.seed", + "torch.cuda.random.set_rng_state_all", + "torch.cuda.set_stream", + "torch.cuda.set_sync_debug_mode", + "torch.cuda.stream", + "torch.cuda.synchronize", + "torch.cuda.temperature", + "torch.cuda.utilization", + "torch.einsum", + "torch.functional._check_list_size", + "torch.functional._consecutive_return_counts", + "torch.functional._consecutive_return_inverse_false", + "torch.functional._consecutive_return_inverse_true", + "torch.functional._consecutive_return_inverse", + "torch.functional._consecutive_return_output", + "torch.functional._lu_impl", + "torch.functional._lu_no_infos", + "torch.functional._lu_with_infos", + "torch.functional._meshgrid", + "torch.functional._return_counts", + "torch.functional._return_inverse_false", + "torch.functional._return_inverse_true", + "torch.functional._return_inverse", + "torch.functional._return_output", + "torch.functional._unique_consecutive_impl", + "torch.functional._unique_impl", + "torch.functional._unravel_index", + "torch.functional.broadcast_shapes", + "torch.functional.lu", + "torch.functional.unique", + "torch.functional.unravel_index", + "torch.futures.collect_all", + "torch.futures.wait_all", + "torch.get_deterministic_debug_mode", + "torch.get_float32_matmul_precision", + "torch.is_deterministic_algorithms_warn_only_enabled", + "torch.is_storage", + "torch.is_tensor", + "torch.is_warn_always_enabled", + "torch.masked._ops._any", + "torch.masked._ops._apply_docstring_templates", + "torch.masked._ops._canonical_dim", + "torch.masked._ops._combine_input_and_mask", + "torch.masked._ops._generate_docstring", + "torch.masked._ops._input_mask", + "torch.masked._ops._output_mask", + "torch.masked._ops._reduction_identity", + "torch.masked._ops._sparse_coo_flatten_indices", + "torch.masked._ops._sparse_coo_scatter_reduction_helper", + "torch.masked._ops._sparse_coo_where", + "torch.masked._ops._sparse_csr_segment_reduction_helper", + "torch.masked._ops._sparse_csr_where", + "torch.masked._ops._std_var", + "torch.masked._ops._where", + "torch.masked._ops.amax", + "torch.masked._ops.amin", + "torch.masked._ops.argmax", + "torch.masked._ops.argmin", + "torch.masked._ops.corresponding_real_dtype", + "torch.masked._ops.cumprod", + "torch.masked._ops.cumsum", + "torch.masked._ops.log_softmax", + "torch.masked._ops.logaddexp", + "torch.masked._ops.logsumexp", + "torch.masked._ops.mean", + "torch.masked._ops.median", + "torch.masked._ops.norm", + "torch.masked._ops.normalize", + "torch.masked._ops.prod", + "torch.masked._ops.softmax", + "torch.masked._ops.softmin", + "torch.masked._ops.std", + "torch.masked._ops.sum", + "torch.masked._ops.var", + "torch.meshgrid", + "torch.mps._get_default_mps_generator", + "torch.mps.current_allocated_memory", + "torch.mps.driver_allocated_memory", + "torch.mps.empty_cache", + "torch.mps.get_rng_state", + "torch.mps.manual_seed", + "torch.mps.profiler.profile", + 
"torch.mps.profiler.start", + "torch.mps.profiler.stop", + "torch.mps.seed", + "torch.mps.set_per_process_memory_fraction", + "torch.mps.set_rng_state", + "torch.mps.synchronize", + "torch.nested._internal.nested_tensor.get_tensor_symint", + "torch.nested._internal.nested_tensor.is_expandable_to", + "torch.nested._internal.nested_tensor.jagged_from_list", + "torch.nested._internal.nested_tensor.jagged_from_tensor_and_lengths", + "torch.nested._internal.nested_tensor.nested_view_from_values_offsets", + "torch.nested._internal.nested_tensor.nested_view_from_values_offsets_lengths", + "torch.nested.as_nested_tensor", + "torch.nested.narrow", + "torch.nested.nested_tensor", + "torch.nn._reduction.get_enum", + "torch.nn._reduction.legacy_get_enum", + "torch.nn._reduction.legacy_get_string", + "torch.nn.factory_kwargs", + "torch.nn.functional._adaptive_max_pool1d", + "torch.nn.functional._adaptive_max_pool2d", + "torch.nn.functional._adaptive_max_pool3d", + "torch.nn.functional._canonical_mask", + "torch.nn.functional._fractional_max_pool2d", + "torch.nn.functional._fractional_max_pool3d", + "torch.nn.functional._get_softmax_dim", + "torch.nn.functional._in_projection_packed", + "torch.nn.functional._in_projection", + "torch.nn.functional._is_integer", + "torch.nn.functional._max_pool1d", + "torch.nn.functional._max_pool2d", + "torch.nn.functional._max_pool3d", + "torch.nn.functional._mha_shape_check", + "torch.nn.functional._no_grad_embedding_renorm_", + "torch.nn.functional._none_or_dtype", + "torch.nn.functional._threshold", + "torch.nn.functional._unpool_output_size", + "torch.nn.functional._verify_batch_size", + "torch.nn.functional._verify_spatial_size", + "torch.nn.functional.adaptive_avg_pool2d", + "torch.nn.functional.adaptive_avg_pool3d", + "torch.nn.functional.adaptive_max_pool1d_with_indices", + "torch.nn.functional.adaptive_max_pool1d", + "torch.nn.functional.adaptive_max_pool2d_with_indices", + "torch.nn.functional.adaptive_max_pool2d", + "torch.nn.functional.adaptive_max_pool3d_with_indices", + "torch.nn.functional.adaptive_max_pool3d", + "torch.nn.functional.affine_grid", + "torch.nn.functional.alpha_dropout", + "torch.nn.functional.assert_int_or_pair", + "torch.nn.functional.batch_norm", + "torch.nn.functional.binary_cross_entropy_with_logits", + "torch.nn.functional.binary_cross_entropy", + "torch.nn.functional.celu", + "torch.nn.functional.cosine_embedding_loss", + "torch.nn.functional.cross_entropy", + "torch.nn.functional.ctc_loss", + "torch.nn.functional.dropout", + "torch.nn.functional.dropout1d", + "torch.nn.functional.dropout2d", + "torch.nn.functional.dropout3d", + "torch.nn.functional.elu", + "torch.nn.functional.embedding_bag", + "torch.nn.functional.embedding", + "torch.nn.functional.feature_alpha_dropout", + "torch.nn.functional.fold", + "torch.nn.functional.fractional_max_pool2d_with_indices", + "torch.nn.functional.fractional_max_pool2d", + "torch.nn.functional.fractional_max_pool3d_with_indices", + "torch.nn.functional.fractional_max_pool3d", + "torch.nn.functional.gaussian_nll_loss", + "torch.nn.functional.glu", + "torch.nn.functional.grid_sample", + "torch.nn.functional.group_norm", + "torch.nn.functional.gumbel_softmax", + "torch.nn.functional.hardsigmoid", + "torch.nn.functional.hardswish", + "torch.nn.functional.hardtanh", + "torch.nn.functional.hinge_embedding_loss", + "torch.nn.functional.huber_loss", + "torch.nn.functional.instance_norm", + "torch.nn.functional.interpolate", + "torch.nn.functional.kl_div", + "torch.nn.functional.l1_loss", + 
"torch.nn.functional.layer_norm", + "torch.nn.functional.leaky_relu", + "torch.nn.functional.local_response_norm", + "torch.nn.functional.log_softmax", + "torch.nn.functional.lp_pool1d", + "torch.nn.functional.lp_pool2d", + "torch.nn.functional.margin_ranking_loss", + "torch.nn.functional.max_pool1d_with_indices", + "torch.nn.functional.max_pool1d", + "torch.nn.functional.max_pool2d_with_indices", + "torch.nn.functional.max_pool2d", + "torch.nn.functional.max_pool3d_with_indices", + "torch.nn.functional.max_pool3d", + "torch.nn.functional.max_unpool1d", + "torch.nn.functional.max_unpool2d", + "torch.nn.functional.max_unpool3d", + "torch.nn.functional.mish", + "torch.nn.functional.mse_loss", + "torch.nn.functional.multi_head_attention_forward", + "torch.nn.functional.multi_margin_loss", + "torch.nn.functional.multilabel_margin_loss", + "torch.nn.functional.multilabel_soft_margin_loss", + "torch.nn.functional.nll_loss", + "torch.nn.functional.normalize", + "torch.nn.functional.poisson_nll_loss", + "torch.nn.functional.relu", + "torch.nn.functional.relu6", + "torch.nn.functional.rrelu", + "torch.nn.functional.selu", + "torch.nn.functional.sigmoid", + "torch.nn.functional.silu", + "torch.nn.functional.smooth_l1_loss", + "torch.nn.functional.soft_margin_loss", + "torch.nn.functional.softmax", + "torch.nn.functional.softmin", + "torch.nn.functional.softsign", + "torch.nn.functional.tanh", + "torch.nn.functional.tanhshrink", + "torch.nn.functional.triplet_margin_loss", + "torch.nn.functional.unfold", + "torch.nn.functional.upsample_bilinear", + "torch.nn.functional.upsample_nearest", + "torch.nn.functional.upsample", + "torch.nn.grad._pair", + "torch.nn.grad._single", + "torch.nn.grad._triple", + "torch.nn.grad.conv1d_input", + "torch.nn.grad.conv1d_weight", + "torch.nn.grad.conv2d_input", + "torch.nn.grad.conv2d_weight", + "torch.nn.grad.conv3d_input", + "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._arg_requires_grad", + "torch.nn.modules.activation._check_arg_device", + "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.container._addindent", + "torch.nn.modules.transformer._detect_is_causal_mask", + "torch.nn.modules.transformer._generate_square_subsequent_mask", + "torch.nn.modules.transformer._get_activation_fn", + "torch.nn.modules.transformer._get_clones", + "torch.nn.modules.transformer._get_seq_len", + "torch.nn.modules.utils._list_with_default", + "torch.nn.modules.utils._ntuple", + "torch.nn.modules.utils._quadruple", + "torch.nn.modules.utils._reverse_repeat_tuple", + "torch.nn.modules.utils.consume_prefix_in_state_dict_if_present", + "torch.nn.parameter.is_lazy", + "torch.norm", + "torch.quantization.default_eval_fn", + "torch.random._seed_custom_device", + "torch.random.fork_rng", + "torch.random.initial_seed", + "torch.random.seed", + "torch.return_types.pytree_register_structseq", + "torch.set_default_device", + "torch.set_default_dtype", + "torch.set_default_tensor_type", + "torch.set_deterministic_debug_mode", + "torch.set_float32_matmul_precision", + "torch.set_warn_always", + "torch.signal.windows.windows._add_docstr", + "torch.signal.windows.windows._window_function_checks", + "torch.signal.windows.windows.bartlett", + "torch.signal.windows.windows.blackman", + "torch.signal.windows.windows.cosine", + "torch.signal.windows.windows.exponential", + "torch.signal.windows.windows.gaussian", + "torch.signal.windows.windows.general_cosine", + "torch.signal.windows.windows.general_hamming", + "torch.signal.windows.windows.hamming", + 
"torch.signal.windows.windows.hann", + "torch.signal.windows.windows.kaiser", + "torch.signal.windows.windows.merge_dicts", + "torch.signal.windows.windows.nuttall", + "torch.signal.windows.windows.parse_kwargs", + "torch.sparse.semi_structured.to_sparse_semi_structured", + "torch.sparse.sum", + "torch.split", + "torch.stft", + "torch.sym_float", + "torch.sym_int", + "torch.sym_ite", + "torch.sym_max", + "torch.sym_min", + "torch.sym_not", + "torch.tensordot", + "torch.typename", + "torch.unique_consecutive", + "torch.use_deterministic_algorithms", + ], + TorchInGraphFunctionVariable, +) + + +torch_name_rule_map = [ + manual_torch_name_rule_map, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, +] + + +""" +Generate the torch object - Dynamo tracing rule (the wrapping variable) map. +""" + + +@functools.lru_cache(None) +def get_torch_obj_rule_map(): + d: Dict[Any, VariableTracker] = dict() + for m in torch_name_rule_map: + for k, v in m.items(): # type: ignore[attr-defined] + obj = load_object(k) + if obj is not None: + if obj in d and d[obj] != v: + raise AssertionError( + f"Duplicate torch object {obj} with different rules: {v}, {d[obj]}" + ) + else: + d[obj] = v + return d + + +def _load_obj_from_str(fully_qualified_name): + module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) + return getattr(importlib.import_module(module), obj_name) + + +""" +Load string represented torch objects. +""" + + +def load_object(name): + try: + x = name.split("#") + if len(x) == 2: + obj = _load_obj_from_str(x[0]) + val = getattr(obj, x[1]) + else: + assert len(x) == 1, f"Invalid obj name {name}" + val = _load_obj_from_str(x[0]) + val = unwrap_if_wrapper(val) + except (AttributeError, ImportError): + val = None + return val + + +""" +Get all torch.Tensor methods which are allowed to be in graph functions. +""" + + +@functools.lru_cache(None) +def get_tensor_method(): + s = set() + for name in dir(torch.Tensor): + method = getattr(torch.Tensor, name) + if isinstance( + method, (types.MethodDescriptorType, types.WrapperDescriptorType) + ): + s.add(method) + return frozenset(s) + + +""" +Return if a torch object is ATen op or torch.Tensor method. +""" + + +def is_aten_op_or_tensor_method(obj): + return obj in get_tensor_method() or isinstance( + obj, + (torch._ops.OpOverloadPacket, torch._ops.OpOverload), + ) + + +class FunctionIdSet: + """ + Track a set of `id()`s of objects which are either allowed or not + allowed to go into the generated FX graph. Use to test for torch.*, + numpy.*, builtins.*, etc. + + Support user modification to permit customization of what can be + added to the graph and what will cause a graph break. 
+ """ + + function_ids: Optional[Set[int]] = None + function_names: Optional[Dict[int, str]] = None + + def __init__(self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]]): + self.lazy_initializer = lazy_initializer + + def __call__(self): + if self.function_ids is None: + value = self.lazy_initializer() + if isinstance(value, dict): + self.function_ids = set(value.keys()) + self.function_names = value + else: + assert isinstance(value, set) + self.function_ids = value + return self.function_ids + + def get_name(self, idx: int, default: str): + self() # lazy init + assert self.function_names is not None + return self.function_names.get(idx, default) + + def add(self, idx: int): + function_ids = self() # lazy init + function_ids.add(idx) + + def remove(self, idx: int): + function_ids = self() + if idx in function_ids: + function_ids.remove(idx) + + def __contains__(self, idx: int): + return idx in self() + + +@FunctionIdSet +def _allowed_callable_ids() -> Dict[int, str]: + rv: Dict[int, str] = {} + return rv + + +@FunctionIdSet +def _disallowed_callable_ids() -> Dict[int, str]: + rv: Dict[int, str] = {} + return rv + + +@FunctionIdSet +def _builtin_function_ids() -> Dict[int, str]: + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and callable(v) + } + rv.update( + { + id(v): f"operator.{k}" + for k, v in operator.__dict__.items() + if not k.startswith("_") and callable(v) + } + ) + rv.update( + {id(v): f"functools.{v.__name__}" for v in (itertools.chain, itertools.islice)} + ) + rv.update( + { + id(cast): "typing.cast", + id(functools.reduce): "functools.reduce", + id(copy.deepcopy): "copy.deepcopy", + } + ) + return rv + + +@FunctionIdSet +def _numpy_function_ids() -> Dict[int, str]: + rv = dict() + for mod in NP_SUPPORTED_MODULES: + rv.update( + { + id(v): f"{mod.__name__}.{k}" + for k, v in mod.__dict__.items() + if callable(v) + and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__ + } + ) + return rv + + +@FunctionIdSet +def _builtin_constant_ids() -> Dict[int, str]: + """ + Collects constant builtins by eliminating callable items. + """ + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and not callable(v) + } + return rv + + +_lazy_module_init: Dict[str, List[Callable[[], None]]] = defaultdict(list) + + +def add_module_init_func(name: str, init_func: Callable[[], None]) -> None: + """Register a module without eagerly importing it""" + # If the module is already imported, eagerly run init + assert "." 
not in name, f"Expected a root module name, but got {name}" + if name in sys.modules: + init_func() + + # Module is not yet imported, delay processing until needed + assert name not in _lazy_module_init + _lazy_module_init[name].append(init_func) + + +def _maybe_init_lazy_module(obj: object) -> None: + module = getattr(obj, "__module__", None) + if module is None: + return + + base_module = module.split(".")[0] + init_funcs = _lazy_module_init.pop(base_module, None) + if init_funcs is not None: + for fn in init_funcs: + fn() + + +def is_callable_allowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _allowed_callable_ids + + +def is_callable_disallowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _disallowed_callable_ids + + +def is_forbidden(obj) -> bool: + _maybe_init_lazy_module(obj) + return getattr(obj, "_dynamo_forbidden", False) + + +def is_builtin_callable(obj) -> bool: + return id(obj) in _builtin_function_ids + + +def is_builtin_constant(obj) -> bool: + return id(obj) in _builtin_constant_ids + + +def is_numpy(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids + + +""" +A note on skip/inline rules: + +Dynamo consults this file to determine whether function should be inlined or skipped. + +A skip applies at the frame boundary, meaning dynamo either triggers a graph break +at the beginning of the frame or attempts to trace/inline the whole frame. When skipping +a frame, recursively called frames are still traced by dynamo unless also skipped. + +Skipfiles (skipped at the file level instead of function level) still apply on a +frame-by-frame boundary as dynamo traces, but apply to all functions in that file. + +@skip is a helper decorator that can be applied to your function to cause it to be +included here. + +Dynamo skip/inline rules & priorities are defined as follows: +* Inline is the default behavior and will be used unless explicitly skipped. +* Dynamo has two SKIPLIST: BUILTIN_SKIPLIST and THIRDPARTY_SKIPLIST. + * BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc. + * THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc. +* Functions in these two SKIPLISTs are always skipped, except: + * They have explicitly defined rule in `manual_torch_name_rule_map`; + * The corresponding python module has been put into MOD_INLINELIST. +* PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases + where we want inline the functions under torch namespace. + We should specify inline for the functions in `manual_torch_name_rule_map` or + put the corresponding python module into MOD_INLINELIST to make dynamo inline them. +* If you call functions under skipped modules/files, Dynamo will wrap these functions + as SkipFunctionVariable. There are a few functions(e.g, collections.OrderedDict) that + we have special handling at SkipFunctionVariable.call_function. + +Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline) + +To figure out what the behavior is, check the following list in order: +* `manual_torch_name_rule_map` (Inline if YES) +* MOD_INLINELIST (Inline if YES) +* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES) +* Inline by default + +In general, if you want to force inline a function or module, please consider adding +the function's python module to MOD_INLINELIST first. 
+Use the `manual_torch_name_rule_map` only when there are other functions under the same module that +you don't want to inline them. +""" + + +BUILTIN_SKIPLIST = ( + abc, + collections, + contextlib, + copy, + copyreg, + dataclasses, + enum, + functools, + importlib, + inspect, + linecache, + logging, + multiprocessing, + operator, + os, + posixpath, + random, + re, + selectors, + signal, + tempfile, + threading, + tokenize, + torch, # torch/* is skipped by default unless specified in FUNC_INLINELIST or MOD_INLINELIST + traceback, + types, + typing, + unittest, + weakref, + _collections_abc, + _weakrefset, +) + +# third party libraries skiplist is defined by str, because users may not use these libraries. +# we should use lazy import & skip in the future. +THIRDPARTY_SKIPLIST = ( + "fx2trt_oss", + "hypothesis", + "networkx", + "numpy", + "omegaconf", + "onnx", + "onnxruntime", + "onnx_tf", + "pandas", + "sklearn", + "tabulate", + "tensorflow", + "tensorrt", + "torch2trt", + "tqdm", + "tree", + "tvm", + "xarray", +) + + +def _strip_init_py(s): + # TODO: Once we require py3.9 use removesuffix instead. + suffix = "__init__.py" + if s.endswith(suffix): + return s[: -len(suffix)] + else: + return s + + +def _module_dir(m: types.ModuleType): + # Protect against a module not exporting __file__ - this can happen for + # frozen modules, for example. + file = getattr(m, "__file__", None) + return file and _strip_init_py(file) + + +# These are legacy workarounds, don't add new modules to this list. +# Please use the MOD_INLINELIST instead to force inline functions under particular modules. +LEGACY_MOD_INLINELIST = { + "torch._dynamo.external_utils", + "torch._export.db.examples", + "torch._export.wrappers", + "torch._functorch.apis", + "torch._functorch.deprecated", + "torch._higher_order_ops.cond", + "torch.ao.quantization.pt2e.export_utils", + "torch.ao.quantization.pt2e.qat_utils", + "torch.ao.quantization.pt2e.representation.rewrite", + "torch.ao.quantization.pt2e.utils", + "torch.ao.quantization.quantizer.xnnpack_quantizer", + "torch.optim", +} + +if torch.distributed.is_available(): + LEGACY_MOD_INLINELIST |= { + "torch.distributed._tensor.api", + "torch.distributed._tensor.device_mesh", + "torch.distributed.device_mesh", + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", + "torch.distributed.tensor.parallel._data_parallel_utils", + "torch.distributed.tensor.parallel._utils", + "torch.distributed.tensor.parallel.style", + # we have to add replicate to LEGACY_MOD_INLINELIST to ensure + # the forward_hook won't be ignored. + "torch.distributed._composable.replicate", + } + + +# Force inline functions under these modules, even they are in *_SKIPLIST. +# We are using python module name instead of file or directory object to avoid circular dependency. +# Please keep this sorted alphabetically. 
+MOD_INLINELIST = { + "torch._refs", + "torch._prims", + "torch._decomp", + "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.comptime", + "torch._dynamo.polyfill", + "torch._functorch.vmap", + "torch._functorch.eager_transforms", + "torch._inductor.test_operators", + "torch.amp.autocast_mode", + "torch.ao.nn", + "torch.autograd.function", + "torch.backends.cuda", + "torch.cuda.amp.autocast_mode", + "torch.distributions", + "torch.fx._pytree", + "torch.fx.passes.shape_prop", + "torch.nn", + "torch.random", + "torch.sparse", + "torch.testing", + "torch.testing._internal.hypothesis_utils", + "torch.utils._content_store", + "torch.utils._contextlib", + "torch.utils._foreach_utils", + "torch.utils._pytree", + "torch.utils.hooks", + "torch._tensor", + "torch._higher_order_ops.strict_mode", + "torch._higher_order_ops.while_loop", +} + + +if torch.distributed.is_available(): + MOD_INLINELIST.add("torch.distributed") + MOD_INLINELIST.add("torch.distributed._functional_collectives") + MOD_INLINELIST.add("torch.distributed._composable.replicate") + + +@functools.lru_cache(None) +def get_legacy_mod_inlinelist(): + inlinelist = set() + for m in LEGACY_MOD_INLINELIST: + inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + return inlinelist + + +@functools.lru_cache(None) +def get_mod_inlinelist(): + inlinelist = set() + for m in MOD_INLINELIST: + inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + return inlinelist + + +# skip some standard python builtin libs +SKIP_DIRS = [ + "", + _config_module.__file__, +] +SKIP_DIRS.extend(filter(None, (_module_dir(m) for m in BUILTIN_SKIPLIST))) + +SKIP_DIRS_RE = re.compile(r"match nothing^") + +is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode() +# Skip fbcode paths(including torch.package paths) containing +# one of the following strings. 
+FBCODE_SKIP_DIRS = { + "torchrec/distributed", + "torchrec/fb/distributed", + "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py", +} +FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})") + + +def _recompile_re(): + global SKIP_DIRS_RE + SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})") + + +def add(import_name: str): + if isinstance(import_name, types.ModuleType): + return add(import_name.__name__) + assert isinstance(import_name, str) + from importlib.util import find_spec + + module_spec = find_spec(import_name) + if not module_spec: + return + origin = module_spec.origin + if origin is None: + return + global SKIP_DIRS_RE + SKIP_DIRS.append(_strip_init_py(origin)) + _recompile_re() + + +@dataclasses.dataclass +class SkipResult: + skipped: bool + reason: Optional[str] + + +def check_file(filename, is_inlined_call=False): + """Should skip this file?""" + if filename is None: + return SkipResult(True, "filename is None") + if any(filename.startswith(d) for d in get_legacy_mod_inlinelist()): + return SkipResult( + False, + "inlined according trace_rules.LEGACY_MOD_INLINELIST", + ) + if is_inlined_call and is_torch_inline_allowed(filename): + return SkipResult( + False, + "inlined according trace_rules.MOD_INLINELIST", + ) + if is_fbcode and bool(FBCODE_SKIP_DIRS_RE.match(filename)): + return SkipResult( + True, + "skipped according trace_rules.FBCODE_SKIP_DIRS", + ) + if bool(SKIP_DIRS_RE.match(filename)): + return SkipResult(True, "skipped according trace_rules.SKIP_DIRS") + else: + return SkipResult(False, "inlined by default") + + +@dataclasses.dataclass +class FunctionInfo: + py_obj: Optional[object] + name: Optional[str] + filename: str + code: Optional[types.CodeType] + + +""" +This is the main entry point to determine whether an object (function) should be inlined or skipped. +Let's illustrate the logic with an example: + @torch.compile + def f1(x, y): + ...... + f2(x, y) + ...... + + def f2(x, y): + ...... + f3(x, y) + ...... + + def f3(x, y): + ...... + +There are mainly three call sites of check/check_verbose: +* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py. +* When tracing the recursively called functions (like function f2 and f3). + * Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site + is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py. + * If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again + and the call site is in catch_errors_wrapper.catch_errors of convert_frame.py. +* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFunctionVariable in builder.py. + +`is_inlined_call` is used to indicate if the current function call is inlined (f2 is inlined call if it passes check) +or not (f3 is not inlined call if f2 is skipped). Inside of the `check_verbose` function, there are more rules +to be checked if this `is_inlined_call`. +The reason to have this flag is that if the upper level function call (e.g, f2) is skipped, +we don't want to inline the lower level function call (e.g, f3) by default. 
+""" + + +def check_verbose(obj, is_inlined_call=False): + if isinstance( + obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable) + ): + try: + py_obj = obj.get_function() + except NotImplementedError: + py_obj = None + fi = FunctionInfo(py_obj, obj.get_name(), obj.get_filename(), obj.get_code()) + elif isinstance(obj, types.CodeType): + fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj) + elif isinstance(obj, (types.FunctionType, types.MethodType)): + fi = FunctionInfo( + obj, obj.__name__, getfile(obj), obj.__code__ # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed + ) + else: + fi = FunctionInfo(obj, None, getfile(obj), None) + + # Consulte the central trace rules defined in torch._dynamo.trace_rules. + rule = torch._dynamo.trace_rules.lookup_inner( + fi.py_obj, fi.name, fi.filename, is_inlined_call + ) + if rule in [UserFunctionVariable, FunctorchHigherOrderVariable]: + return SkipResult( + False, + "inlined according trace_rules.lookup", + ) + else: + assert rule == SkipFunctionVariable, rule + return SkipResult( + True, + "skipped according trace_rules.lookup", + ) + + +def check(obj, is_inlined_call=False): + return check_verbose(obj, is_inlined_call).skipped + + +# skip common third party libs +for _name in THIRDPARTY_SKIPLIST: + add(_name) + +_recompile_re() + + +def is_torch_inline_allowed(filename): + return any(filename.startswith(d) for d in get_mod_inlinelist()) + + +@functools.lru_cache(None) +def dynamo_dir(): + import torch._dynamo + + return _module_dir(torch._dynamo) + + +def is_torch(filename): + if filename.startswith(dynamo_dir()): + return False + return filename.startswith(_module_dir(torch)) + + +""" +Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object. +""" + + +def lookup_callable(obj): + if not hashable(obj): + return None + # Custom allow/disallow in graph takes precedence over the general lookup. + if is_callable_disallowed(obj): + return SkipFunctionVariable + if is_callable_allowed(obj): + return TorchInGraphFunctionVariable + if is_builtin_callable(obj): + return BuiltinVariable + + +""" +Main entry point for looking up the trace rule (the Dynamo variable) for a given function object. +E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`. +""" + + +def lookup(obj): + return lookup_inner(obj) + + +def lookup_inner(obj, name=None, filename=None, is_direct_call=True): + # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. + # The rules defined in `torch_name_rule_map` mainly includes two parts: + # - Manually defined rules for any functions. + # - The list of torch in graph functions. + if not hashable(obj): + return None + if obj is not None: + if is_aten_op_or_tensor_method(obj): + return TorchInGraphFunctionVariable + rule = get_torch_obj_rule_map().get(obj, None) + if rule is not None: + return rule + + # Step 2: lookup obj's tracing rule by function name. + if is_direct_call: + if name == "patched_init": + return SkipFunctionVariable + elif name == "__torch_function__": + return UserFunctionVariable + + # Step 3: lookup obj's tracing rule by filename. 
+ if filename is None: + filename = getfile(obj) + + if check_file(filename, is_direct_call).skipped: + return SkipFunctionVariable + else: + return UserFunctionVariable diff --git a/MLPY/Lib/site-packages/torch/_dynamo/types.py b/MLPY/Lib/site-packages/torch/_dynamo/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b99182b472d457a0faa0a4cc06ef263b1a52aa83 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/types.py @@ -0,0 +1,99 @@ +import dataclasses +import sys +import types +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union + +from typing_extensions import TypeAlias + + +if sys.version_info >= (3, 11): + from torch._C._dynamo import eval_frame + + DynamoFrameType: TypeAlias = eval_frame._PyInterpreterFrame +else: + DynamoFrameType: TypeAlias = types.FrameType + +import torch + +# This class has a `check_fn` field for the guard, +# and a `code` field for the code object. +CacheEntry = torch._C._dynamo.eval_frame._CacheEntry + +ExtraState = torch._C._dynamo.eval_frame._ExtraState + +# We use a dict to store additional data per frame. +FrameState = Dict[Any, Any] + + +class GuardFail(NamedTuple): + # A string repr of the piece of failed guard code we eval-ed + reason: str + # A code object where we failed a guard + orig_code: types.CodeType + + +class GuardFn(Protocol): + closure_vars: Dict[str, object] + args: List[str] + code_parts: List[str] + verbose_code_parts: List[str] + global_scope: Dict[str, object] + guard_fail_fn: Optional[Callable[[GuardFail], None]] + cache_entry: Optional[CacheEntry] + extra_state: Optional[ExtraState] + + # maps locals of user function to bool + def __call__(self, f_locals: Dict[str, object]) -> bool: + ... + + +@dataclasses.dataclass +class GuardedCode: + code: types.CodeType + check_fn: GuardFn + + +class DynamoCallbackFn(Protocol): + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + frame_state: FrameState, + ) -> Optional[GuardedCode]: + ... + + +DynamoCallback = Union[DynamoCallbackFn, None, bool] + + +class DynamoGuardHook(Protocol): + def __call__( + self, + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], + index: int, + last: bool, + ) -> None: + ... + + +class ProfilerStartHook(Protocol): + def __call__( + self, + name: str, + # TODO(whc) how do I annotate a _RecordFunction here? + ) -> Any: + ... + + +class ProfilerEndHook(Protocol): + def __call__(self, record: Any) -> None: + ... + + +class BytecodeHook(Protocol): + def __call__( + self, code: types.CodeType, new_code: types.CodeType + ) -> Optional[types.CodeType]: + ... 
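The hook types defined above (GuardFn, DynamoCallbackFn, DynamoGuardHook, BytecodeHook, ...) are typing Protocols, so they are satisfied structurally: any callable with a matching signature type-checks without subclassing. Below is a minimal sketch of a callable that satisfies BytecodeHook; the function name and body are illustrative assumptions, and no registration mechanism is shown:

import types
from typing import Optional

from torch._dynamo.types import BytecodeHook


def log_recompile(code: types.CodeType, new_code: types.CodeType) -> Optional[types.CodeType]:
    # Observational hook: report which code object was rewritten and return
    # None rather than substituting different bytecode.
    print(f"rewrote {code.co_name} ({code.co_filename}:{code.co_firstlineno})")
    return None


# Structural typing: a type checker accepts this assignment; nothing is enforced at runtime.
hook: BytecodeHook = log_recompile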
diff --git a/MLPY/Lib/site-packages/torch/_dynamo/utils.py b/MLPY/Lib/site-packages/torch/_dynamo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07da9c0262e61baa16db423a829aabe6921f4785 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/utils.py @@ -0,0 +1,2563 @@ +import atexit +import collections +import contextlib +import copy +import cProfile +import dataclasses +import datetime +import dis +import enum +import functools +import gc +import inspect +import itertools +import linecache +import logging +import math +import operator +import os +import pstats +import re +import subprocess +import sys +import textwrap +import threading +import time +import types +import typing +import weakref +from contextlib import contextmanager +from functools import lru_cache, wraps +from pathlib import Path +from types import MethodWrapperType +from typing import ( + Any, + Callable, + cast, + ClassVar, + Counter, + DefaultDict, + Deque, + Dict, + Iterator, + KeysView, + List, + Optional, + Set, + Tuple, + Type, + Union, + ValuesView, +) + +from ..utils.hooks import RemovableHandle + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + import torch._logging + import torch._numpy as tnp + from torch._guards import detect_fake_mode # noqa: F401n + from torch._logging import LazyString + from . import config + + # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync. + if np: + NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = ( + np, + np.fft, + np.linalg, + np.random, + ) + + NP_TO_TNP_MODULE = { + np: tnp, + np.fft: tnp.fft, + np.linalg: tnp.linalg, + np.random: tnp.random, + } + else: + NP_SUPPORTED_MODULES = tuple() + + NP_TO_TNP_MODULE = {} + from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +except ImportError: + pass + +import importlib + +import torch +import torch._functorch.config +import torch.fx.experimental.symbolic_shapes +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._utils_internal import log_compilation_event + +from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._pytree import tree_map_only + + +counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter) +optimus_scuba_log: Dict[str, Any] = {} +troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html" +nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html" +nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations." 
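+# Illustrative note (an assumption, not part of the original module): `counters`
+# above is a two-level tally keyed by category and then by reason, so callers
+# elsewhere can do roughly:
+#     counters["graph_break"][reason] += 1
+#     counters["frames"]["total"] += 1
+# Only the Dict[str, Counter[str]] shape is fixed by the definition above; the
+# specific keys shown here are examples.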
+log = logging.getLogger(__name__) + +# profiling compilation time by function +compilation_time_metrics: Dict[str, List[float]] = {} + +# profiling compilation time by frame phase +frame_phase_timing: Dict[str, Dict[str, float]] = {} + +timer_counter = itertools.count() + + +def tabulate(rows, headers): + try: + import tabulate + + return tabulate.tabulate(rows, headers=headers) + except ImportError: + return "\n".join( + ", ".join(map(str, row)) for row in itertools.chain([headers], rows) + ) + + +def maybe_cprofile(func): + if config.cprofile: + return cprofile_wrapper(func) + return func + + +def cprofile_wrapper(func): + @wraps(func) + def profile_wrapper(*args, **kwargs): + global timer_counter + profile_cnt = next(timer_counter) + profile_path = Path(func.__name__ + f"{profile_cnt}.profile") + prof = cProfile.Profile() + prof.enable() + start_ts = time.time() + retval = prof.runcall(func, *args, **kwargs) + profile_latency = time.time() - start_ts + prof.disable() + print( + f"### Cprofile for {func.__name__} iter {profile_cnt} took {profile_latency:.3f} seconds ###" + ) + ps = pstats.Stats(prof) + prof.dump_stats(profile_path) + svg_path = profile_path.with_suffix(".svg") + try: + gprof2dot_process = subprocess.Popen( + [ + "gprof2dot", + "-f", + "pstats", + "--node-label=total-time-percentage", + "--node-label=self-time-percentage", + "--node-label=total-time", + str(profile_path), + ], + stdout=subprocess.PIPE, + ) + subprocess.check_call( + ["dot", "-Tsvg", "-o", str(svg_path)], + stdin=gprof2dot_process.stdout, + ) + print(f"Generated SVG from profile at {str(svg_path)}") + except FileNotFoundError: + print( + "Failed to generate SVG from profile -- dumping stats instead." + "Try installing gprof2dot and dot for a better visualization" + ) + ps.sort_stats(pstats.SortKey.TIME).print_stats(20) + ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) + return retval + + return profile_wrapper + + +curr_frame = 0 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def increment_frame(): + global curr_frame + curr_frame = curr_frame + 1 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def reset_frame_count(): + global curr_frame + frame_phase_timing.clear() + compilation_time_metrics.clear() + curr_frame = 0 + + +op_count = 0 + + +def increment_op_count(cnt): + global op_count + op_count += cnt + + +# Print a report of time spent so far +# Ex: +# TIMING: +# entire_frame_compile:8.574629999999999 +# backend_compile:5.26806 +def print_time_report(): + total = 0.0 + total_by_key = {} + for timings in frame_phase_timing.values(): + for key, timing in timings.items(): + total += timing + if key not in total_by_key: + total_by_key[key] = timing + else: + total_by_key[key] += timing + + out = "TIMING:" + for key, value in total_by_key.items(): + out = f"{out} {key}:{round(value, 5)}" + + print(out) + + +# dynamo_timed API works as a function decorator +# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics +# where the key is the functions name. +# For example: +# +# @dynamo_timed +# def _foo(...): +# +# Would show up as an entry in our timing dict: +# OrderedDict([('bar.._foo', [0.083690, 0.23949, 3.1425e-05])]) +# This is extremely useful for granular debugging. +# +# For a higher-level mode, pass a phase_name into dynamo_timed +# phase_names record an extra record into a separate compilation timing structure, +# one keyed on frame+name rather than function. 
+# The frame is incremented outside of this function, in def increment_frame() above. + + +def dynamo_timed(original_function=None, phase_name=None): + def dynamo_timed_inner(func): + if config.cprofile: + return func + + @wraps(func) + def time_wrapper(*args, **kwargs): + key = func.__qualname__ + if key not in compilation_time_metrics: + compilation_time_metrics[key] = [] + with torch.profiler.record_function(f"{key} (dynamo_timed)"): + t0 = time.time() + r = func(*args, **kwargs) + time_spent = time.time() - t0 + compilation_time_metrics[key].append(time_spent) + if phase_name: + frame_key = str(curr_frame) + if frame_key not in frame_phase_timing: + frame_phase_timing[frame_key] = {} + if phase_name not in frame_phase_timing[frame_key]: + frame_phase_timing[frame_key][phase_name] = time_spent + else: + frame_phase_timing[frame_key][phase_name] += time_spent + return r + + return time_wrapper + + if original_function: + return dynamo_timed_inner(original_function) + return dynamo_timed_inner + + +def compile_times(repr="str", aggregate=False): + """ + Get metrics about torchdynamo frontend/backend compilation times. + + Accumulates information from functions tagged with `@dynamo_timed`. + + repr='str' returns a printable string for user interaction, and 'csv' + returns headers, rows which can be logged for output + + aggregate causes values from multiple compilations (e.g. split graphs) + to be accumulated into one value. If false, expect more than one value + per metric. + """ + + def fmt_fn(values, item_fn=lambda x: x): + if aggregate: + return item_fn(sum(values)) + return ", ".join(map(item_fn, values)) + + if repr == "str": + rows = [ + (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}")) + for k in compilation_time_metrics + ] + out = "TorchDynamo compilation metrics:\n" + out += tabulate(rows, headers=("Function", "Runtimes (s)")) + return out + elif repr == "csv": + values = [ + fmt_fn(v, item_fn=lambda x: f"{x:.6f}") + for v in compilation_time_metrics.values() + ] + headers = list(compilation_time_metrics.keys()) + return headers, values + + +@atexit.register +def dump_compile_times(): + log.info(compile_times(repr="str", aggregate=True)) + + +tensortype_to_dtype = { + torch.FloatTensor: (torch.float32, torch.float), + torch.DoubleTensor: (torch.float64, torch.double), + torch.HalfTensor: (torch.float16, torch.half), + torch.BFloat16Tensor: (torch.bfloat16,), + torch.ByteTensor: (torch.uint8,), + torch.CharTensor: (torch.int8,), + torch.LongTensor: (torch.int64, torch.long), + torch.IntTensor: (torch.int32, torch.int), + torch.ShortTensor: (torch.int16, torch.short), + torch.BoolTensor: (torch.bool,), +} + + +class DuplicateWarningChecker: + def __init__(self, maxsize=4096): + self.maxsize = maxsize + self.reset() + + def reset(self): + self.set = collections.OrderedDict() + + def add(self, key): + if key in self.set: + self.set.move_to_end(key, last=True) + if not config.verbose: + return False + else: + self.set[key] = None + while len(self.set) > self.maxsize: + self.set.popitem(last=False) + return True + + +graph_break_dup_warning_checker = DuplicateWarningChecker() + + +def setup_compile_debug(): + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + if compile_debug: + torch._logging.set_logs( + dynamo=logging.DEBUG, + aot=logging.DEBUG, + inductor=logging.DEBUG, + output_code=True, # this is off by default + ) + return add_file_handler() + + return contextlib.ExitStack() + + +def reset_graph_break_dup_checker(): + 
graph_break_dup_warning_checker.reset() + + +def add_file_handler(): + log_path = os.path.join(get_debug_dir(), "torchdynamo") + os.makedirs(log_path, exist_ok=True) + + log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log")) + logger = logging.getLogger("torch._dynamo") + logger.addHandler(log_file_handler) + + exitstack = contextlib.ExitStack() + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + +def setup_log_file(): + exitstack = contextlib.ExitStack() + if config.log_file_name is not None: + log_file_handler = logging.FileHandler(config.log_file_name) + for logger in torch._logging._internal.get_loggers(): + logger.addHandler(log_file_handler) + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + return exitstack + + +def gen_record_file_name(exc, code): + return f"{get_debug_dir()}/error_recordings/\ +{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" + + +def write_record_to_file(filename, exec_record): + try: + if os.path.exists(filename): + log.warning( + "Unable to write execution record %s; file already exists.", filename + ) + else: + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + exec_record.dump(f) + except Exception: + log.exception("Unable to write execution record %s", filename) + + +def count_calls(g: fx.Graph): + c = 0 + for n in g.nodes: + if "call" in n.op: + c += 1 + return c + + +def identity(x): + return x + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + # cannot hash writable memoryview object + except ValueError: + return False + + +def nothing(*args, **kwargs): + pass + + +class ExactWeakKeyDictionary: + """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" + + def __init__(self): + self.values = dict() + self.refs = dict() + + def __getitem__(self, key): + return self.values[id(key)] + + def get(self, key, default=None): + return self.values.get(id(key), default) + + def __contains__(self, key): + return id(key) in self.values + + def __setitem__(self, key, value): + idx = id(key) + if idx not in self.refs: + self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) + self.values[idx] = value + + def _remove_id(self, idx): + if idx in self.values: + del self.values[idx] + if idx in self.refs: + del self.refs[idx] + + def clear(self): + self.refs.clear() + self.values.clear() + + +def istype(obj, allowed_types): + """isinstance() without subclasses""" + if isinstance(allowed_types, (tuple, list, set)): + return type(obj) in allowed_types + return type(obj) is allowed_types + + +def is_typing(value): + # _Final catches most of typing classes: + # - Any + # - Callable + # - Union + # ... + # + # NB: we intentionally ignore classes that inherit from Generic, since they + # can be used as both TypingVariable as well as UserDefinedClassVariable. 
+ return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] + + +def is_numpy_int_type(value): + if not np: + return False + + return istype( + value, + ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ) + + +def is_numpy_float_type(value): + if not np: + return False + + return istype( + value, + ( + np.float16, + np.float32, + np.float64, + ), + ) + + +def is_function_or_wrapper(value): + return ( + is_function(value) + or isinstance(value, functools._lru_cache_wrapper) + and is_function(inspect.getattr_static(value, "__wrapped__")) + or isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)) + ) + + +def is_function(value): + return isinstance( + value, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + torch.jit.ScriptFunction, + ), + ) + + +def unwrap_if_wrapper(fn): + return unwrap_with_attr_name_if_wrapper(fn)[0] + + +def unwrap_with_attr_name_if_wrapper(fn): + # unpack @functools.lru_cache wrapped function + if isinstance(fn, functools._lru_cache_wrapper): + fn = inspect.getattr_static(fn, "__wrapped__") + attr_name = "__wrapped__" + # unpack @torch._dynamo.optimize()(fn) wrapped function + elif is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + attr_name = "_torchdynamo_inline" + # unpack torch.jit.script_if_tracing + elif is_function(fn) and inspect.getattr_static( + fn, "__script_if_tracing_wrapper", False + ): + fn = inspect.getattr_static(fn, "__original_fn", fn) + attr_name = "__original_fn" + else: + attr_name = None + return fn, attr_name + + +def is_numpy_ndarray(value): + if not np: + return False + + return istype(value, np.ndarray) + + +def istensor(obj): + """Check of obj is a tensor""" + tensor_list = ( + torch.Tensor, + torch.nn.Parameter, + *config.traceable_tensor_subclasses, + ) + tensor_list = tensor_list + (torch._subclasses.FakeTensor,) + return istype(obj, tensor_list) + + +def is_lazy_module(mod): + return isinstance(mod, LazyModuleMixin) + + +@functools.lru_cache(4096) +def print_once(*args): + print(*args) + + +def make_cell(val=None): + """Some black magic to create a cell object that usually only exists in a closure""" + x = val + + def f(): + return x + + assert f.__closure__ is not None and len(f.__closure__) == 1 + return f.__closure__[0] + + +def proxy_args_kwargs(args, kwargs): + try: + proxy_args = tuple(arg.as_proxy() for arg in args) + proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return proxy_args, proxy_kwargs + except NotImplementedError as e: + from .exc import unimplemented + from .variables.base import typestr + + raise unimplemented( + f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}" + ) from e + + +@dataclasses.dataclass +class CompilationMetrics: + frame_key: str + co_name: str + co_filename: str + co_firstlineno: int + cache_size: int + accumulated_cache_size: int + guard_count: Optional[int] + shape_env_guard_count: Optional[int] + graph_op_count: Optional[int] + graph_node_count: Optional[int] + graph_input_count: Optional[int] + start_time: float + entire_frame_compile_time_s: Optional[float] + backend_compile_time_s: Optional[float] + inductor_compile_time_s: Optional[float] + code_gen_time_s: Optional[float] + fail_type: Optional[str] + fail_reason: Optional[str] + fail_user_frame_filename: Optional[str] + 
fail_user_frame_lineno: Optional[int] + non_compliant_ops: Set[str] + compliant_custom_ops: Set[str] + + +DEFAULT_COMPILATION_METRICS_LIMIT = 64 + + +_compilation_metrics: Deque[CompilationMetrics] = collections.deque( + maxlen=DEFAULT_COMPILATION_METRICS_LIMIT +) + + +def record_compilation_metrics(compilation_metrics: CompilationMetrics): + global _compilation_metrics + _compilation_metrics.append(compilation_metrics) + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) + + +def set_compilation_metrics_limit(new_size: int) -> None: + global _compilation_metrics + while len(_compilation_metrics) > new_size: + _compilation_metrics.popleft() + new_deque = collections.deque(_compilation_metrics, maxlen=new_size) + _compilation_metrics = new_deque + + +def clear_compilation_metrics() -> None: + global _compilation_metrics + _compilation_metrics.clear() + + +def get_compilation_metrics() -> List[CompilationMetrics]: + return list(_compilation_metrics) + + +@dataclasses.dataclass +class CleanupHook: + """Remove a global variable when hook is called""" + + scope: Dict[str, Any] + name: str + + def __call__(self, *args): + CleanupManager.count -= 1 + del self.scope[self.name] + + @staticmethod + def create(scope, name, val): + assert name not in scope + CleanupManager.count += 1 + scope[name] = val + return CleanupHook(scope, name) + + +class CleanupManager(ExactWeakKeyDictionary): + count = 0 + instance: ClassVar["CleanupManager"] + + def _remove_id(self, idx): + for hook in self.values[idx]: + hook() + super()._remove_id(idx) + + +CleanupManager.instance = CleanupManager() + + +def clone_tensor(x): + """Clone the tensor and its gradient""" + y = x.clone().requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = x.grad.clone() + return y + + +def clone_input(x, *, dtype=None): + """copy while preserving strides""" + # TODO: this is questionable + if is_fake(x): + # this func fails on fake tensors in __torch_dispatch__ + return x + + def torch_clone(x): + y = torch.clone(x) + if x.is_leaf: + y.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = clone_input(x.grad, dtype=dtype) + if hasattr(x, "_dynamo_dynamic_indices"): + y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return y + + with torch.no_grad(): + if x.device.type == "xla": + # Access data_ptr() for a xla tensor will cause crash + return torch_clone(x) + + needed_size = sum( + (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) + ) + if x.is_quantized: + result = torch.empty_quantized((needed_size + 32,), x) + else: + result = torch.empty( + needed_size + 32, dtype=dtype or x.dtype, device=x.device + ) + cache_line_offset = ( + (x.data_ptr() - result.data_ptr()) % 32 + ) // x.element_size() + result.as_strided_(x.size(), x.stride(), cache_line_offset) + try: + result.copy_(x.clone()) + if x.is_leaf: + result.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + result.grad = clone_input(x.grad, dtype=dtype) + except RuntimeError: + # RuntimeError: unsupported operation: more than one element of the written-to + # tensor refers to a single memory location. Please clone() the tensor before + # performing the operation. 
+ return torch_clone(x) + if hasattr(x, "_dynamo_dynamic_indices"): + result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return result + + +def clone_inputs(example_inputs): + res: Union[Dict[Any, Any], List[Any]] + if type(example_inputs) is dict: + res = dict(example_inputs) + for key, value in res.items(): + if isinstance(value, tuple): + res[key] = clone_inputs(value) + else: + assert isinstance(value, torch.Tensor), type(value) + res[key] = clone_input(value) + return res + + res = list(example_inputs) + for i in range(len(res)): + if isinstance(res[i], torch.Tensor): + res[i] = clone_input(res[i]) + return res + + +def skip_frame_if_in_functorch_mode(val: torch.Tensor): + try: + val.data_ptr() # will throw for functorch tensors + except RuntimeError as e: + from .exc import SkipFrame + + # This will be GradTrackingTensor/BatchedTensor/etc + functorch_subclass_name = re.sub(r"\(.*", "", repr(val)) + raise SkipFrame( + f"torch.compile cannot be run in context: {functorch_subclass_name}" + ) from e + + +@contextmanager +def preserve_rng_state(): + disable_functorch = torch._C._DisableFuncTorch + disable_current_modes = torch.utils._python_dispatch._disable_current_modes + with disable_current_modes(), disable_functorch(): + rng_state = torch.clone(torch.random.get_rng_state()) + skip_frame_if_in_functorch_mode(rng_state) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + try: + yield + finally: + with torch.utils._python_dispatch._disable_current_modes(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + + +def is_jit_model(model0): + return isinstance( + model0, + ( + torch.jit._trace.TopLevelTracedModule, + torch.jit._script.RecursiveScriptModule, + torch.jit.ScriptFunction, + torch.jit.ScriptModule, + ), + ) + + +def torchscript(model, example_inputs, verbose=False): + if is_jit_model(model): + # already done? + return model + + try: + return torch.jit.trace(model, example_inputs) + except Exception: + try: + return torch.jit.script(model) + except Exception: + if verbose: + log.exception("jit error") + else: + log.error("Both torch.jit.trace and torch.jit.script failed") + return None + + +def getfile(obj): + try: + return inspect.getfile(obj) + except (TypeError, OSError): + return None + + +def is_namedtuple(obj): + """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" + return is_namedtuple_cls(type(obj)) + + +def is_namedtuple_cls(cls): + """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" + try: + if issubclass(cls, tuple): + bases = getattr(cls, "__bases__", []) or [None] + module = getattr(cls, "__module__", None) + return module == "torch.return_types" or ( + bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields") + ) + except TypeError: + pass + return False + + +@functools.lru_cache(1) +def namedtuple_fields(cls): + """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" + if cls is slice: + return ["start", "stop", "step"] + + assert issubclass(cls, tuple) + if hasattr(cls, "_fields"): + # normal namedtuples + return cls._fields + + @dataclasses.dataclass + class Marker: + index: int + + # frustrating ones e.g. 
torch.return_types.max + assert cls.__module__ == "torch.return_types" + obj = cls(map(Marker, range(cls.n_fields))) + fields: List[Optional[str]] = [None] * cls.n_fields + for name in dir(obj): + if name[0] != "_" and isinstance(getattr(obj, name), Marker): + fields[getattr(obj, name).index] = name + return fields + + +def checkpoint_params(gm): + with torch.no_grad(): + rng_state = torch.clone(torch.random.get_rng_state()) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + saved_state = [] + for param in itertools.chain(gm.parameters(), gm.buffers()): + saved_state.append((param, param._version, torch.clone(param))) + + def restore(): + with torch.no_grad(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + for param, version, original_value in saved_state: + if param._version != version: + param.copy_(original_value) + + return restore + + +def timed(model, example_inputs, times=1): + if torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + else: + synchronize = nothing + + synchronize() + gc.collect() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize() + t1 = time.perf_counter() + return result, t1 - t0 # type: ignore[possibly-undefined] + + +def check_is_cuda(gm, example_inputs): + return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) + + +@lru_cache(32) +def rot_n_helper(n): + assert n > 1 + vars = [f"v{i}" for i in range(n)] + rotated = reversed(vars[-1:] + vars[:-1]) + fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") + fn.__name__ = f"rot_{n}_helper" + return fn + + +common_constant_types = { + int, + float, + complex, + bool, + str, + bytes, + type(None), + Ellipsis.__class__, + types.CodeType, + torch.device, + torch.dtype, + torch.memory_format, + torch.layout, +} + + +def is_safe_constant(v): + if istype(v, (tuple, frozenset)): + return all(map(is_safe_constant, v)) + return isinstance(v, (enum.Enum, type)) or istype( + v, + common_constant_types | {slice}, + ) + + +def specialize_symnode(arg): + from .variables import ConstantVariable, SymNodeVariable + + # Guard and specialize + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + + return arg + + +def guard_if_dyn(arg): + from .variables import ConstantVariable + + arg = specialize_symnode(arg) + + if isinstance(arg, ConstantVariable): + return arg.as_python_constant() + + return arg + + +def check_constant_args(args, kwargs): + return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) + + +def check_unspec_python_args(args, kwargs): + from .variables.constant import ConstantVariable + from .variables.tensor import UnspecializedPythonVariable + + unspec_count = 0 + for x in itertools.chain(args, kwargs.values()): + if isinstance(x, UnspecializedPythonVariable): + unspec_count += 1 + elif not isinstance(x, (UnspecializedPythonVariable, ConstantVariable)): + return False + else: + pass + + return unspec_count > 0 + + +def check_numpy_ndarray_args(args, kwargs): + from .variables.tensor import NumpyNdarrayVariable + + return any( + isinstance(x, NumpyNdarrayVariable) + for x in itertools.chain(args, kwargs.values()) + ) + + +dict_keys: Type[KeysView[Any]] = type(dict().keys()) +dict_values: Type[ValuesView[Any]] = type(dict().values()) +odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values()) 
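+
+
+# A minimal illustrative sketch (illustrative only; this helper is not part of
+# the original module): it exercises `is_safe_constant` and
+# `common_constant_types` defined above, which accept immutable constants
+# (recursing into tuples/frozensets) and reject mutable containers.
+def _is_safe_constant_examples():
+    assert is_safe_constant(3.14)  # float is in common_constant_types
+    assert is_safe_constant(("a", 1, None, Ellipsis))  # recurses into tuples
+    assert is_safe_constant(slice(0, 2))  # slice is special-cased as safe
+    assert not is_safe_constant([1, 2, 3])  # lists are mutable, hence unsafe
+
+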
+tuple_iterator: Type[Iterator[Any]] = type(iter(tuple())) +tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] +object_new = object.__new__ + + +def nn_module_new(cls): + obj = object_new(cls) + torch.nn.Module.__init__(obj) + return obj + + +def product(it): + return functools.reduce(operator.mul, it, 1) + + +def tuple_iterator_getitem(it, index): + _, (obj,), start = it.__reduce__() + return obj[start + index] + + +iter_next = next + + +def to_subclass(t, cls): + return t.as_subclass(cls) + + +def dict_keys_getitem(d, n): + return next(itertools.islice(iter(d), n, n + 1)) + + +def enum_repr(value, local): + # enum class can override __str__ method. Use __class__ and name attribute + # to extract the class name and key name. + name = value.__class__.__name__ + val = value.name + scope = "L" if local else "G" + local_name = f'{scope}["{name}"].{val}' + return local_name + + +def _get_fake_tensor(vt): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if not is_fake(fake_tensor): + from .exc import unimplemented + + unimplemented("Cannot check Tensor object identity without its fake value") + return fake_tensor + + +def iter_contains(items, search, tx, check_tensor_identity=False): + from .variables import ( + BuiltinVariable, + ConstantVariable, + TensorVariable, + VariableTracker, + ) + + if search.is_python_constant(): + found_const = any( + x.is_python_constant() + and x.as_python_constant() == search.as_python_constant() + for x in items + ) + return ConstantVariable.create(found_const) + + must_check_tensor_id = False + if check_tensor_identity and isinstance(search, TensorVariable): + must_check_tensor_id = True + # Match of Tensor means match of FakeTensor + search = _get_fake_tensor(search) + + found: Optional[VariableTracker] = None + for x in items: + if must_check_tensor_id: + if isinstance(x, TensorVariable): + if search is _get_fake_tensor(x): # Object equivalence + return ConstantVariable.create(True) + else: + check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) + if found is None: + found = check + else: + found = BuiltinVariable(operator.or_).call_function( + tx, [check, found], {} + ) + if found is None: + found = ConstantVariable.create(False) + return found + + +def key_is_id(k): + """Returns whether it indexes dictionaries using its id""" + return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) + + +def key_to_id(value): + return [id(k) if key_is_id(k) else k for k in value.keys()] + + +def const_repr(x, *, local) -> str: + from .trace_rules import is_builtin_callable + + if isinstance(x, (list, tuple)): + elems_repr = ",".join(const_repr(s, local=local) for s in x) + if isinstance(x, list): + return f"[{elems_repr}]" + else: + assert isinstance(x, tuple) + if len(x) == 1: + return f"({elems_repr},)" + else: + return f"({elems_repr})" + elif isinstance(x, enum.Enum): + # To workaround repr(Enum) returning invalid global reference before python 3.11 + # by calling enum_repr and removing quotes to render enum in guard code. + return enum_repr(x, local=local).replace("'", "") + elif is_builtin_callable(x): + return x.__name__ + elif isinstance(x, type): + + def fullname(o): + klass = o.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." 
+ klass.__qualname__ + + return fullname(x) + else: + return f"{x!r}" + + +def dict_keys_repr(const_keys, *, local) -> str: + keys_str = ",".join(const_repr(s, local=local) for s in const_keys) + return "[" + keys_str + "]" + + +GLOBAL_KEY_PREFIX = "__dict_key" + + +from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 + + +def wrap_fake_exception(fn): + try: + return fn() + except UnsupportedFakeTensorException as e: + from .exc import unimplemented + + msg = f"Unsupported: {e.reason} with fake tensor propagation." + log.warning(msg) + raise unimplemented(msg) from e + + +def deepcopy_to_fake_tensor(obj, fake_mode): + with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): + return wrap_fake_exception(lambda: copy.deepcopy(obj)) + + +def rmse(ref, res): + """ + Calculate root mean squared error + """ + return torch.sqrt(torch.mean(torch.square(ref - res))) + + +def same( + ref, + res, + fp64_ref=None, + cos_similarity=False, + tol=1e-4, + equal_nan=False, + exact_dtype=True, + relax_numpy_equality=False, + ignore_non_fp=False, + log_error=log.error, +): + """Check correctness to see if ref and res match""" + if fp64_ref is None: + fp64_ref = ref + if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): + assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" + if len(ref) != len(res): + log_error("Length mismatch") + return False + return len(ref) == len(res) and all( + same( + ai, + bi, + fp64_refi, + cos_similarity, + tol, + equal_nan, + exact_dtype, + relax_numpy_equality, + ignore_non_fp, + log_error=log_error, + ) + for ai, bi, fp64_refi in zip(ref, res, fp64_ref) + ) + elif isinstance(ref, dict): + assert isinstance(res, dict) + assert set(ref.keys()) == set( + res.keys() + ), f"keys mismatch {set(ref.keys())} == {set(res.keys())}" + for k in sorted(ref.keys()): + if not ( + same( + ref[k], + res[k], + fp64_ref[k], + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + ) + ): + log_error("Accuracy failed for key name %s", k) + return False + return True + elif isinstance(ref, (torch.Tensor, float)): + assert not isinstance(ref, torch._subclasses.FakeTensor) + assert not isinstance(res, torch._subclasses.FakeTensor) + + def to_tensor(t): + return t if isinstance(t, torch.Tensor) else torch.tensor(t) + + ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) + + if ref.is_sparse: + assert res.is_sparse + ref = ref.to_dense() + res = res.to_dense() + assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" + if exact_dtype: + if ref.dtype != res.dtype: + log_error("dtype mismatch %s, %s", ref.dtype, res.dtype) + return False + if ref.dtype == torch.bool: + if ignore_non_fp: + return True + # triton stores bool as int8, so add this for more accurate checking + r = torch.allclose( + ref.to(dtype=torch.uint8), + res.to(dtype=torch.uint8), + atol=tol, + rtol=tol, + equal_nan=equal_nan, + ) + if not r: + log_error("Accuracy failed: uint8 tensor did not match") + return r + + if cos_similarity: + ref = ref.flatten().to(torch.float32) + res = res.flatten().to(torch.float32) + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True): + # early exit that handles zero/nan better + # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 + return True + score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) + if score < 0.99: + 
log.warning("Similarity score=%s", score.cpu().detach().item()) + return score >= 0.99 + else: + if not exact_dtype: + ref = ref.to(res.dtype) + + # First try usual allclose + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan): + return True + + # Check error from fp64 version + if fp64_ref.dtype == torch.float64: + ref_error = rmse(fp64_ref, ref).item() + # ref unable to produce this with stable numerics in this precision, ignore + if math.isnan(ref_error): + log.warning( + "Found nan in reference. Consider running in higher precision." + ) + + res_error = rmse(fp64_ref, res).item() + + # In the case of using AMP (Automatic Mixed Precision), certain models have + # failed the benchmark's correctness check. However, the end-to-end model's + # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. + # Thus, it's possible that the correctness check failures for these models are + # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. + multiplier = 3.0 if res.dtype == torch.bfloat16 else 2.0 + + if ( + fp64_ref.numel() < 1000 + or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) + # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE + or tol >= 2 * 1e-2 + ): + # In the presence of noise, noise might dominate our error + # metric for smaller tensors. + # Similary, for 1x1 kernels, there seems to be high noise with amp. + multiplier = 3.0 + + passes_test = res_error <= (multiplier * ref_error + tol / 10.0) + if not passes_test: + log_error( + "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s", + res_error, + ref_error, + res.size(), + ) + # import pdb; pdb.set_trace() + return passes_test + + if ignore_non_fp: + return True + + log_error("Accuracy failed: allclose not within tol=%s", tol) + return False + elif isinstance(ref, (str, int, type(None), bool, torch.device)): + if ignore_non_fp: + return True + r = ref == res + if not r: + log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res) + return r + elif is_numpy_int_type(ref) or is_numpy_float_type(ref): + if relax_numpy_equality and not ( + is_numpy_int_type(res) or is_numpy_float_type(res) + ): + ref = ref.item() + r = (type(ref) is type(res)) and (ref == res) + if not r: + log_error("Accuracy failed (numpy): %s != %s", ref, res) + return r + elif is_numpy_ndarray(ref): + return (type(ref) is type(res)) and same( + torch.as_tensor(ref), + torch.as_tensor(res), + fp64_ref, + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + ) + elif type(ref).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + "LongformerMaskedLMOutput", + "Instances", + "SquashedNormal", + "Boxes", + "Normal", + "TanhTransform", + "Foo", + "Variable", + ): + assert type(ref) is type(res) + return all( + same( + getattr(ref, key), + getattr(res, key), + getattr(fp64_ref, key), + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + ) + for key in ref.__dict__.keys() + ) + else: + raise RuntimeError(f"unsupported type: {type(ref).__name__}") + + +def format_func_info(code): + short_filename = code.co_filename.split("/")[-1] + return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" + + +@contextlib.contextmanager +def 
disable_cache_limit(): + prior = config.cache_size_limit + config.cache_size_limit = sys.maxsize + prior_acc_limit = config.accumulated_cache_size_limit + config.accumulated_cache_size_limit = sys.maxsize + + try: + yield + finally: + config.cache_size_limit = prior + config.accumulated_cache_size_limit = prior_acc_limit + + +# map from transformed code back to original user code +orig_code_map = ExactWeakKeyDictionary() + +# keep a record of code_obj -> list of guard failure reasons for logging +guard_failures: DefaultDict[Any, List[Any]] = collections.defaultdict(list) + +# Keep a record of graph break reasons for logging +graph_break_reasons: List["torch._dynamo.output_graph.GraphCompileReason"] = list() + +# keep record of compiled code, if we are in "error if recompile" +# to track code that dynamo has compiled previously +seen_code_map = ExactWeakKeyDictionary() + + +class CompileProfiler: + """Utility for profiling how and what dynamo would compile. + + Can be used for + * diagnosing recompilation issues + * determining an appropriate compile cache limit + * (TODO)confirming which functions got compiled/skipped + """ + + def __init__(self): + self.frame_count = 0 + self.op_count = 0 + self.backend_ctx_ctor = disable_cache_limit + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + return gm.forward + + # no-op __enter__ and __exit__ to preserve BC + def __enter__(self): + return self + + def __exit__(self, typ, val, traceback): + pass + + def get_metrics(self): + return {"guard_failures": guard_failures} + + def report(self): + metrics = self.get_metrics() + gf = metrics["guard_failures"] + + def num_recompiles(code): + return len(gf[code]) + + def recompile_reasons(code): + return "\n".join([str(x) for x in gf[code]]) + + summarized_gf = [ + [format_func_info(code), num_recompiles(code), recompile_reasons(code)] + for code in gf + ] + + def graph_break_report(): + if "graph_break" in counters: + graph_breaks = counters["graph_break"] + return tabulate( + [[msg, graph_breaks[msg]] for msg in graph_breaks], + headers=["Graph Break Reason", "Count"], + ) + + def recompilation_report(): + if len(gf): + max_recompiles = max([num_recompiles(code) for code in gf]) + recomp_table = tabulate( + summarized_gf, + headers=["Function", "Recompiles", "Recompile Reasons"], + ) + return recomp_table + textwrap.dedent( + f""" + + Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited. + """ + ) + + report = textwrap.dedent( + """ + Torchdynamo Profiler Report + =========================== + + Graph Breaks + ------------ + Graph breaks happen when torchdynamo encounters code it can't safely trace. + If you want to find out why breaks are happening, check below for each break reason + You may gain additional insight by passing `fullgraph=True` to torch.compile, + to stop at the first break. + + """ + ) + report += graph_break_report() or "No graph breaks detected." + report += textwrap.dedent( + """ + + Recompilation + ------------- + These subgraphs were recompiled more than once due to guard failures + Guard failures indicate some condition assumed to be static by the tracer changed, + making it unsafe to reuse the compiled program. 
+ + """ + ) + report += recompilation_report() or "No recompilation detected.\n" + return report + + +# return same dir unless user changes config between calls +@functools.lru_cache(None) +def _get_debug_dir(root_dir): + dir_name = ( + "run_" + + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + # use pid to avoid conflicts among ranks + + "-pid_" + + str(os.getpid()) + ) + return os.path.join(root_dir, dir_name) + + +def get_debug_dir(): + debug_root = config.debug_dir_root + return _get_debug_dir(debug_root) + + +def extract_fake_example_value(node, required=True): + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + elif required: + from torch._dynamo.exc import unimplemented + + unimplemented("`FakeTensor` example value was required but not available") + else: + return None + + +def ensure_graph_fake(e, tx): + assert maybe_get_fake_mode(e) is tx.fake_mode + return e + + +def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): + def visit(n: torch.fx.Node): + if n.op == "call_function" and "example_value" not in n.meta: + # fake tensor validity is checked inside get_fake_value using + # ensure_graph_fake + return get_fake_value(n, tx, allow_non_graph_fake) + + out = n.meta["example_value"] + if not allow_non_graph_fake and isinstance(out, torch.Tensor): + return ensure_graph_fake(out, tx) + return out + + return torch.fx.node.map_arg(nodes, visit) + + +def get_fake_value(node, tx, allow_non_graph_fake=False): + """ + Run the computation represented by `node` using fake tensors and return the result. + + allow_non_graph_fake: whether to allow the return result to be: + 1. non-fake or 2. fake that is not created by this instance of Dynamo. + If `True`, you must be prepared to deal with such return values, ideally + by further wrapping them as this graph's fakes. + """ + from torch.utils._sympy.value_ranges import ValueRangeError + from .exc import ( + TorchRuntimeError, + unimplemented, + Unsupported, + UserError, + UserErrorType, + ) + + op = node.op + + # FX Node should always return the same fake value + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + + args, kwargs = get_fake_values_from_nodes( + tx, (node.args, node.kwargs), allow_non_graph_fake + ) + + nnmodule = None + if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): + # If the first argument is nn.Module, should copy to fake mode. + args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) + + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + nnmodule._infer_parameters(nnmodule, args) + + # no matter it's lazy module or not, we should copy to fake mode. 
+ nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + try: + with tx.fake_mode, enable_python_dispatcher(): + ret_val = wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + cause: BaseException = e + if e.__cause__ is not None: + cause = e.__cause__ + + if isinstance( + cause, torch._subclasses.fake_tensor.DataDependentOutputException + ): + unimplemented( + f"data dependent operator: {cause.func}; " + "to enable, set torch._dynamo.config.capture_scalar_outputs = True" + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.DynamicOutputShapeException + ): + unimplemented( + f"dynamic shape operator: {cause.func}; " + "to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True" + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.UnsupportedOperatorException + ): + op = cause.func + import_suggestion = "" + if isinstance(op, torch._ops.OpOverload): + maybe_pystub = torch._C._dispatch_pystub( + op._schema.name, op._schema.overload_name + ) + if maybe_pystub is not None: + module, ctx = maybe_pystub + import_suggestion = ( + f"It's possible that the support was implemented in " + f"module `{module}` and you may need to `import {module}`" + f"({ctx}), otherwise " + ) + unimplemented( + f"unsupported operator: {cause.func} ({import_suggestion}see " + "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0" + " for how to fix)" + ) + elif isinstance( + cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode + ): + raise UserError( # noqa: TRY200 + UserErrorType.CONSTRAINT_VIOLATION, + "Tried to use data-dependent value in the subsequent computation. " + "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " + "You will need to explicitly give hint to the compiler. Please take a look at " + f"constrain_as_value OR constrain_as_size APIs. {cause}", + case_name="constrain_as_size_example", + ) + elif isinstance(cause, ValueRangeError): + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + + if not allow_non_graph_fake: + _ = tree_map_only( + torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val + ) + return ret_val + + +_current_node = threading.local() + + +def get_current_node(): + return getattr(_current_node, "value", None) + + +@contextmanager +def set_current_node(node): + old = get_current_node() + _current_node.value = node + try: + yield + finally: + _current_node.value = old + + +def run_node(tracer, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dictated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The tracer arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. 
+ """ + op = node.op + + with set_current_node(node): + + def make_error_message(e): + return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e) + + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return tracer.get_submodule(node.target) + elif op == "placeholder": + assert "example_value" in node.meta + return node.meta["example_value"] + + except (NotImplementedError, UnsupportedFakeTensorException) as e: + # NB: mimic how wrap_fake_exception does it + from .exc import unimplemented + + raise unimplemented(make_error_message(e)) from e + except Exception as e: + raise RuntimeError(make_error_message(e)).with_traceback( + e.__traceback__ + ) from e + + raise AssertionError(op) + + +def get_real_value(node, tracer): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. + """ + from .exc import TorchRuntimeError + + cache = tracer.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( + (node.args, node.kwargs), + lambda n: get_real_value(n, tracer), + ) + + if op == "call_module": + nn_module = tracer.output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(tracer, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + return real_value + + +def assert_no_fake_params_or_buffers(gm): + from torch._subclasses.fake_tensor import FakeTensorConfig + + def stack_or_hint(t): + if FakeTensorConfig.debug: + import traceback + + return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" + else: + return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." + + for name, buffer in gm.named_buffers(): + assert not isinstance( + buffer, torch._subclasses.FakeTensor + ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" + for name, param in gm.named_parameters(): + assert not isinstance( + param, torch._subclasses.FakeTensor + ), f"Unexpected fake param {name} {stack_or_hint(param)}" + + +def fqn(obj: Any): + """ + Returns the fully qualified name of the object. 
+ """ + return f"{obj.__module__}.{obj.__qualname__}" + + +def ifdynstaticdefault(count1, count2): + if torch._dynamo.config.assume_static_by_default: + return count1 + else: + return count2 + + +def import_submodule(mod: types.ModuleType): + """ + Ensure all the files in a given submodule are imported + """ + for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))): + if filename.endswith(".py") and filename[0] != "_": + importlib.import_module(f"{mod.__name__}.{filename[:-3]}") + + +def object_has_getattribute(value: Any): + try: + if isinstance( + inspect.getattr_static(type(value), "__getattribute__"), + types.FunctionType, + ): + return True + except AttributeError: + pass + return False + + +def get_custom_getattr(value: Any): + try: + getattr_fn = inspect.getattr_static(type(value), "__getattr__") + except AttributeError: + getattr_fn = None + if getattr_fn is torch.nn.Module.__getattr__: + # ignore this case of getattr + getattr_fn = None + return getattr_fn + + +class TensorStaticReason(enum.Enum): + PARAMETER = 2 + NOT_TENSOR = 4 + NN_MODULE_PROPERTY = 5 + + +def tensor_static_reason_to_message(reason: TensorStaticReason): + if reason == TensorStaticReason.PARAMETER: + return "mark_dynamic on parameter, parameters are always static today." + if reason == TensorStaticReason.NOT_TENSOR: + return "mark_dynamic on a non tensor, how did this happen?" + if reason == TensorStaticReason.NN_MODULE_PROPERTY: + return "tensor is static because it is nn module associated." + raise AssertionError(f"Illegal reason {reason}") + + +def tensor_always_has_static_shape( + tensor: Union[torch.Tensor, Any], + is_tensor: bool, + guard_source: "torch._guards.GuardSource", +) -> Tuple[bool, Optional[TensorStaticReason]]: + """ + Given a tensor, source, and is_tensor flag, determine if a shape should be static. + + Args: + tensor - the real tensor to evaluate, parameters force a static shape. + is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable, + tensors not in a TensorVariable for whatever reason are forced static. + + Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape. + The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed. 
+ """ + if guard_source.is_nn_module() and config.force_nn_module_property_static_shapes: + return True, TensorStaticReason.NN_MODULE_PROPERTY + if type(tensor) is torch.nn.Parameter and config.force_parameter_static_shapes: + return True, TensorStaticReason.PARAMETER + if not is_tensor: + return True, TensorStaticReason.NOT_TENSOR + return False, None + + +def lazy_format_graph_code(name, gm, maybe_id=None): + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(print_output=False), + ) + ) + + +def _format_graph_code(name, filename, graph_str): + return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" + + +def lazy_format_graph_tabular(fn_name, gm): + def inner(): + try: + from tabulate import tabulate # TODO: Check that this is installed + except ImportError: + return ( + "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n" + + str(lazy_format_graph_code(fn_name, gm)) + ) + + node_specs = [ + [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes + ] + graph_str = tabulate( + node_specs, headers=["opcode", "name", "target", "args", "kwargs"] + ) + return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str) + + return LazyString(inner) + + +def format_bytecode(prefix, name, filename, line_no, code): + return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" + + +forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"] +backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"] +state_dict_hook_names = [ + "_state_dict_pre_hooks", + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", +] +all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names + + +def nn_module_get_all_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + reset_code = torch._C._dynamo.eval_frame.reset_code + """ + Sometimes its useful to differentiate between types of hooks such as forward/backward/pre + hooks executed during module.__call__, and state_dict hooks which are executed separately. + """ + hook_dicts_to_check = [] + check_all_hooks = ( + not check_forward_hooks + and not check_backward_hooks + and not check_state_dict_hooks + ) + if check_forward_hooks or check_all_hooks: + hook_dicts_to_check.extend(forward_hook_names) + if check_backward_hooks or check_all_hooks: + hook_dicts_to_check.extend(backward_hook_names) + if check_state_dict_hooks: + hook_dicts_to_check.extend(state_dict_hook_names) + + all_hooks = [] + for hook_dict_name in hook_dicts_to_check: + hooks = getattr(mod, hook_dict_name, []) + for hook_name in hooks: + hook = hooks[hook_name] + + all_hooks.append(hook) + return all_hooks + + +def nnmodule_has_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + """ + Helper function to check if a module has any hooks attached to it. 
+    """
+    hooks = nn_module_get_all_hooks(
+        mod,
+        check_forward_hooks=check_forward_hooks,
+        check_backward_hooks=check_backward_hooks,
+        check_state_dict_hooks=check_state_dict_hooks,
+    )
+    return bool(hooks)
+
+
+def to_numpy_helper(value):
+    """Convert tensor and tnp.ndarray to numpy.ndarray."""
+    if is_fake(value):
+        return value
+    if isinstance(value, tnp.ndarray):
+        return to_numpy_helper(value.tensor)
+    elif isinstance(value, torch.Tensor):
+        return value.numpy(force=True)
+    elif isinstance(value, (tuple, list)):
+        return type(value)(to_numpy_helper(obj) for obj in value)
+    else:
+        return value
+
+
+def numpy_to_tensor(value):
+    """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
+    assert np is not None
+    if isinstance(value, np.ndarray):
+        return torch.as_tensor(value)
+    if isinstance(value, tnp.ndarray):
+        return value.tensor
+    elif isinstance(value, (tuple, list)):
+        return type(value)(numpy_to_tensor(obj) for obj in value)
+    else:
+        return value
+
+
+class numpy_to_tensor_wrapper:
+    def __init__(self, f):
+        self.f = f
+        self.__name__ = "wrapped_" + self.f.__name__
+
+    def __repr__(self):
+        return f"<Wrapped function <original {self.f.__name__}>>"
+
+    def __call__(self, *args, **kwargs):
+        out = self.f(*args, **kwargs)
+        return numpy_to_tensor(out)
+
+
+def numpy_attr_wrapper(obj, name):
+    if isinstance(obj, tnp.ndarray):
+        out = getattr(obj, name)
+        return numpy_to_tensor(out)
+    elif isinstance(obj, torch.Tensor):
+        out = getattr(tnp.ndarray(obj), name)
+        return numpy_to_tensor(out)
+
+
+class numpy_method_wrapper:
+    """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor."""
+
+    def __init__(self, method: str):
+        self.method = method
+        self.__name__ = "wrapped_" + self.method
+
+    def __repr__(self):
+        return f"<Wrapped method <original {self.method}>>"
+
+    def __call__(self, *args, **kwargs):
+        obj = args[0]
+        if isinstance(obj, torch.Tensor):
+            obj = tnp.ndarray(obj)
+        method_callable = getattr(obj, self.method)
+        out = method_callable(*args[1:], **kwargs)
+        return numpy_to_tensor(out)
+
+
+class numpy_operator_wrapper:
+    """Implements dunder methods for tnp.ndarray via functions from the operator library"""
+
+    def __init__(self, op: Callable[..., Any]):
+        self.op = op
+        self.__name__ = f"wrapped_{op.__name__}"
+
+    def __repr__(self):
+        return f"<Wrapped operator <original {self.op.__name__}>>"
+
+    def __call__(self, *args, **kwargs):
+        assert not kwargs
+
+        args = (
+            tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
+        )
+        out = self.op(*args)
+        return numpy_to_tensor(out)
+
+
+def defake(x):
+    if not isinstance(x, FakeTensor):
+        return x
+    size: "torch._prims_common.ShapeType"
+    stride: "torch._prims_common.StrideType"
+    if x._has_symbolic_sizes_strides:
+        size = []
+        for s in x.size():
+            if isinstance(s, torch.SymInt):
+                size.append(s.node.shape_env.size_hint(s.node.expr))
+            else:
+                size.append(s)
+        stride = []
+        for s in x.stride():
+            if isinstance(s, torch.SymInt):
+                stride.append(s.node.shape_env.size_hint(s.node.expr))
+            else:
+                stride.append(s)
+    else:
+        size = x.size()
+        stride = x.stride()
+    y = torch.empty_strided(
+        size,
+        stride,
+        dtype=x.dtype,
+        device=x.device,
+        requires_grad=x.requires_grad,
+    )
+    y.zero_()
+    return y
+
+
+def is_utils_checkpoint(obj):
+    # Lazy import to avoid circular dependencies
+    import torch.utils.checkpoint
+
+    return obj is torch.utils.checkpoint.checkpoint
+
+
+def build_checkpoint_variable(**options):
+    import torch._higher_order_ops.wrap as higher_order_ops
+    from .variables.higher_order_ops import TorchHigherOrderOperatorVariable
+ + # TODO - This is a temporary situation where we have two versions of + # checkpointing implementation. We will converge on one and remove the other. + activation_checkpoint_op: "torch._ops.HigherOrderOperator" = ( + higher_order_ops.tag_activation_checkpoint + ) + if torch._functorch.config.functionalize_rng_ops: + activation_checkpoint_op = higher_order_ops.wrap_activation_checkpoint + + return TorchHigherOrderOperatorVariable.make( + activation_checkpoint_op, + **options, + ) + + +def is_compile_supported(device_type): + from .eval_frame import is_dynamo_supported + + compile_supported = is_dynamo_supported() + if device_type == "cpu": + pass + elif device_type == "cuda" and compile_supported: + from torch.utils._triton import has_triton + + compile_supported = has_triton() + else: + compile_supported = False + return compile_supported + + +# The following 3.11 source code functions are adapted from +# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py +# in order to output source code corresponding to bytecode in 3.11+. +# We need our own versions since we want to support multiline expressions. +def _fix_offset(str: str, offset: int) -> int: + """ + Convert byte offset `offset` of `str` into character offset. + Byte offset is used for 3.11+ instruction column data. + Takes things like unicode characters into consideration. + + Unchanged from CPython implementation. + """ + as_utf8 = str.encode("utf-8") + return len(as_utf8[:offset].decode("utf-8", errors="replace")) + + +@dataclasses.dataclass +class _Anchors: + # inclusive + left_end_lineno: int + left_end_offset: int + right_start_lineno: int + # exclusive + right_start_offset: int + + +def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: + """ + Given source code `segment` corresponding to a bytecode + instruction, determine: + - for binary ops, the location of the binary op + - for indexing, the location of the brackets. + `segment` is expected to be a valid Python expression + """ + assert sys.version_info >= (3, 11) + + import ast + + try: + # Without brackets, `segment` is parsed as a statement. + # We expect an expression, so wrap `segment` in + # brackets to handle multi-line expressions. + tree = ast.parse("(\n" + segment + "\n)") + except SyntaxError: + return None + + if len(tree.body) != 1: + return None + + lines = segment.split("\n") + + # get character index given byte offset + def normalize(lineno, offset): + return _fix_offset(lines[lineno], offset) + + # Gets the next valid character index in `lines`, if + # the current location is not valid. Handles empty lines. + def next_valid_char(lineno, col): + while lineno < len(lines) and col >= len(lines[lineno]): + col = 0 + lineno += 1 + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character index in `lines`. + def increment(lineno, col): + col += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character at least on the next line + def nextline(lineno, col): + col = 0 + lineno += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + statement = tree.body[0] + if isinstance(statement, ast.Expr): + expr = statement.value + if isinstance(expr, ast.BinOp): + # ast gives locations for BinOp subexpressions, e.g. 
+ # ( left_expr ) + ( right_expr ) + # left^^^^^ right^^^^^ + # -2 since end_lineno is 1-indexed and because we added an extra + # bracket to `segment` when calling ast.parse + cur_lineno = cast(int, expr.left.end_lineno) - 2 + cur_col = normalize(cur_lineno, expr.left.end_col_offset) + cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) + + # Heuristic to find the operator character. + # The original CPython implementation did not look for ), \, or #, + # leading to incorrect anchor location, e.g. + # (x) + (y) + # ~~^~~~~~~ + while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": + if ch in "\\#": + cur_lineno, cur_col = nextline(cur_lineno, cur_col) + else: + cur_lineno, cur_col = increment(cur_lineno, cur_col) + + # binary op is 1 or 2 characters long, on the same line + right_col = cur_col + 1 + if ( + right_col < len(lines[cur_lineno]) + and not (ch := lines[cur_lineno][right_col]).isspace() + and ch not in "\\#" + ): + right_col += 1 + # right_col can be invalid since it is exclusive + + return _Anchors(cur_lineno, cur_col, cur_lineno, right_col) + elif isinstance(expr, ast.Subscript): + # ast gives locations for value and slice subexpressions, e.g. + # ( value_expr ) [ slice_expr ] + # value^^^^^ slice^^^^^ + # subscript^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '[' after value) + left_lineno = cast(int, expr.value.end_lineno) - 2 + left_col = normalize(left_lineno, expr.value.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "[": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + elif isinstance(expr, ast.Call): + # ( func_expr ) (args, kwargs) + # func^^^^^ + # call^^^^^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '(' after func) + left_lineno = cast(int, expr.func.end_lineno) - 2 + left_col = normalize(left_lineno, expr.func.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "(": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + + return None + + +def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str: + """ + Python 3.11+ only. Returns lines of source code (from code object `code`) + corresponding to `inst`'s location data, and underlines relevant code to `inst`. + + Example: CALL on `g`: + f(g( + ^^ + h(x))) + ^^^^^ + + We need our own implementation since `format_frame_summary` in + Python's `traceback` module doesn't handle multi-line expressions + (and their anchor extraction code is not completely correct). + """ + assert inst.positions is not None + if inst.positions.lineno is None: + return "" + # The rstrip + "\n" pattern is used throughout this function to handle + # linecache.getline errors. Error lines are treated as empty strings "", but we want + # to treat them as blank lines "\n". 
+ first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip() + if inst.positions.end_lineno is None: + return first_line + if inst.positions.col_offset is None or inst.positions.end_col_offset is None: + return first_line + + # character index of the start of the instruction + start_offset = _fix_offset(first_line, inst.positions.col_offset) + # character index of the end of the instruction + # compute later since end may be a different line + end_offset = None + # expression corresponding to the instruction so we can get anchors + segment = "" + # underline markers to be printed - start with `~` marker and replace with `^` later + markers = [] + + # Compute segment and initial markers + if inst.positions.end_lineno == inst.positions.lineno: + end_offset = _fix_offset(first_line, inst.positions.end_col_offset) + segment = first_line[start_offset:end_offset] + markers.append(" " * start_offset + "~" * (end_offset - start_offset)) + else: + segment = first_line[start_offset:] + "\n" + markers.append(" " * start_offset + "~" * (len(first_line) - start_offset)) + last_line = linecache.getline( + code.co_filename, inst.positions.end_lineno + ).rstrip() + end_offset = _fix_offset(last_line, inst.positions.end_col_offset) + for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno): + line = linecache.getline(code.co_filename, lineno).rstrip() + segment += line + "\n" + # don't underline leading spaces + num_spaces = len(line) - len(line.lstrip()) + markers.append(" " * num_spaces + "~" * (len(line) - num_spaces)) + segment += last_line[:end_offset] + num_spaces = len(last_line) - len(last_line.lstrip()) + markers.append(" " * num_spaces + "~" * (end_offset - num_spaces)) + + anchors: Optional[_Anchors] = None + try: + anchors = _extract_anchors_from_expr(segment) + except AssertionError: + pass + + # replace `~` markers with `^` where necessary + if anchors is None: + markers = [marker.replace("~", "^") for marker in markers] + else: + # make markers mutable + mutable_markers: List[List[str]] = [list(marker) for marker in markers] + + # anchor positions do not take start_offset into account + if anchors.left_end_lineno == 0: + anchors.left_end_offset += start_offset + if anchors.right_start_lineno == 0: + anchors.right_start_offset += start_offset + + # Turn `~`` markers between anchors to `^` + for lineno in range(len(markers)): + for col in range(len(mutable_markers[lineno])): + if lineno < anchors.left_end_lineno: + continue + if lineno == anchors.left_end_lineno and col < anchors.left_end_offset: + continue + if ( + lineno == anchors.right_start_lineno + and col >= anchors.right_start_offset + ): + continue + if lineno > anchors.right_start_lineno: + continue + if mutable_markers[lineno][col] == "~": + mutable_markers[lineno][col] = "^" + + # make markers into strings again + markers = ["".join(marker) for marker in mutable_markers] + + result = "" + for i in range(len(markers)): + result += ( + linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip() + + "\n" + ) + result += markers[i] + "\n" + return result + + +def get_static_address_type(t): + if isinstance(t, torch.Tensor): + return getattr(t, "_dynamo_static_input_type", None) + + return None + + +def is_rng_state_getter_or_setter(value): + getters = ( + # The following two functions are not identical, so don't remove anyone! 
+ torch._C.Generator.get_state, + torch.default_generator.get_state, + torch.get_rng_state, + torch.cuda.get_rng_state, + ) + setters = ( + torch._C.Generator.set_state, + torch.default_generator.set_state, + torch.set_rng_state, + torch.cuda.set_rng_state, + ) + return value in (*setters, *getters) + + +def is_tensor_base_attr_getter(value): + return ( + isinstance(value, types.MethodWrapperType) + and value.__name__ == "__get__" + and value.__self__.__objclass__ is torch._C._TensorBase # type: ignore[attr-defined] + ) + + +def is_torch_function_object(value): + return hasattr(value, "__torch_function__") + + +def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bool: + from torch._dynamo.variables import UserDefinedObjectVariable + from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable + + return isinstance(vt, TensorWithTFOverrideVariable) or ( + isinstance(vt, UserDefinedObjectVariable) + and hasattr(vt.value, "__torch_function__") + ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + ) + + +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +@contextlib.contextmanager +def maybe_enable_compiled_autograd(should_enable): + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): + torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 + return torch._inductor.compile(gm_, example_inputs_) + + return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True) + + if should_enable: + with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx: + yield ctx + else: + yield + + +def invalid_removeable_handle(): + # need a subclass so weakref works + class Invalid(dict): # type: ignore[type-arg] + pass + + return RemovableHandle(Invalid()) + + +# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's. +# Attribute changes to the original object/proxy will be reflected in the other. +# This is useful for cases where we want a keep-alive reference to a module without increasing +# its reference count. 
+def nn_module_proxy(mod): + if not isinstance(mod, torch.nn.Module): + return mod + if isinstance(mod, torch.fx.GraphModule): + # Dynamo-generated GM's shouldn't contain user-created GM's + return mod + proxy = mod.__class__.__new__(mod.__class__) + proxy.__dict__ = mod.__dict__ + return proxy diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__init__.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89e0853c7ee0366809e82a6bf80fb7c718f748a0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/__init__.py @@ -0,0 +1,151 @@ +# mypy: ignore-errors + +from .base import VariableTracker +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + ContextWrappingVariable, + DeterministicAlgorithmsVariable, + DisabledSavedTensorsHooksVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + StreamContextVariable, + StreamVariable, + VmapIncrementNestingCtxManagerVariable, + WithExitFunctionVariable, +) +from .dicts import ( + ConstDictVariable, + CustomizedDictVariable, + DataClassVariable, + DefaultDictVariable, + SetVariable, +) +from .distributed import BackwardHookVariable +from .functions import ( + FunctoolsPartialVariable, + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) +from .higher_order_ops import ( + FunctorchHigherOrderVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ( + CountIteratorVariable, + CycleIteratorVariable, + IteratorVariable, + ItertoolsVariable, + RepeatIteratorVariable, +) +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + RestrictedListSubclassVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradFunctionContextVariable, + AutogradFunctionVariable, + ClosureVariable, + DeletedVariable, + GetAttrVariable, + InspectSignatureVariable, + LambdaVariable, + MethodWrapperVariable, + NewCellVariable, + NewGlobalVariable, + NumpyVariable, + PythonModuleVariable, + StringFormatVariable, + SuperVariable, + TypingVariable, + UnknownVariable, +) +from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .sdpa import SDPAParamsVariable +from .tensor import ( + FakeItemVariable, + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, + UntypedStorageVariable, +) +from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable +from .user_defined import ( + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedObjectVariable, +) + +__all__ = [ + "AutogradFunctionContextVariable", + "AutogradFunctionVariable", + "BackwardHookVariable", + "BaseListVariable", + "BuiltinVariable", + "ClosureVariable", + "ConstantVariable", + "ConstDictVariable", + "ContextWrappingVariable", + "CountIteratorVariable", + "CustomizedDictVariable", + "CycleIteratorVariable", + "DataClassVariable", + "DefaultDictVariable", + "DeletedVariable", + "DeterministicAlgorithmsVariable", + "EnumVariable", + "FakeItemVariable", + "GetAttrVariable", + "GradModeVariable", + "InspectSignatureVariable", + "IteratorVariable", + "ItertoolsVariable", + "LambdaVariable", + "LazyVariableTracker", + "ListIteratorVariable", + "ListVariable", + "NamedTupleVariable", + 
"NestedUserFunctionVariable", + "NewCellVariable", + "NewGlobalVariable", + "NNModuleVariable", + "NumpyNdarrayVariable", + "NumpyVariable", + "PythonModuleVariable", + "RangeVariable", + "RemovableHandleVariable", + "RepeatIteratorVariable", + "RestrictedListSubclassVariable", + "SDPAParamsVariable", + "SkipFunctionVariable", + "SliceVariable", + "StringFormatVariable", + "SuperVariable", + "TensorVariable", + "TorchCtxManagerClassVariable", + "TorchInGraphFunctionVariable", + "TupleVariable", + "UnknownVariable", + "UnspecializedNNModuleVariable", + "UnspecializedPythonVariable", + "UntypedStorageVariable", + "UserDefinedClassVariable", + "UserDefinedObjectVariable", + "UserFunctionVariable", + "UserMethodVariable", + "VariableTracker", + "WithExitFunctionVariable", +] diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c33bf1053649edb407ba14dbb0668426acf35d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e084b7472a6f0558dac4477445aa4b50522bb23e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470039655a5c1ebe9413c87e942e669b911be475 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b5163e9d1e1a7a88851274fd482ac6d575e6264 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..525ef4a6399508a7262bc45b3b21410ad65c8a98 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3075f3bbcbac44781c5c19b910e4a9f7eda28bd9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0687cd3e0712fd3a2f7c7a87c5111575afbef68 Binary files /dev/null and 
b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e79f54fdc8b89365afa17534868da8ab341ad8f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fbf35a7a02e46cdb98a3413d32c7311147c677d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1871d4873bd89ef3fa0a7799f86d278df189c3b5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3a2c8f5f6f9cfdeff64a1041eeb1ae760f1138b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e766dbd987a9de19fa8d2d0702d5c024c1237b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4e8f0a6a47caaf474131885557d6aec59db829 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d693d5180318e148ffa6311533c0edb711e850c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81c483f70811a7c8a4f368f0b2d8b84065d65d0a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..0ff635deec1661a7d0b2d6eb5c8f79850b090b89 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e300696957ba17cc4a6309ff2cc7001aa55f4f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92df6fe544a84ca6650bbab84704de865f81cb2b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6580e1cb9beb5a3de9c6b386886b7656a3cbedb Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..772ae91d5f7afe8967f11e9aa77ffd70fe7ef5ca Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a5694184b0fb9d328aaa3886b6474aac562df8b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/base.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1771d6d7e64d427086be3943751180f5216f6ad9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/base.py @@ -0,0 +1,420 @@ +# mypy: ignore-errors + +import collections +from enum import Enum +from typing import Any, Callable, Dict, List + +from .. import variables +from ..current_scope_id import current_scope_id +from ..exc import unimplemented +from ..source import AttrSource, Source +from ..utils import identity, istype + + +class MutableLocalSource(Enum): + """ + If the VariableTracker.mutable_local represents a Variable that: + - already existed that Dynamo began tracking while introspection (Existing) + - is a new variable that is created during Dynamo introspection (Local) + """ + + Existing = 0 + Local = 1 + + +class ParentsTracker: + """ + This is a perf optimization to limit the number of objects we need to visit in tx.replace_all. + This must be a seperate object so that it is not cloned in apply. 
+ """ + + def __init__(self): + # logically this is a set, but we use a dict to ensure deterministic ordering + self.parents: Dict[ParentsTracker, bool] = dict() + + def add(self, parent): + self.parents[parent] = True + + def recursive_parents(self): + rv = dict(self.parents) + worklist = list(self.parents) + while worklist: + for parent in worklist.pop().parents: + if parent not in rv: + assert isinstance(parent, ParentsTracker) + rv[parent] = True + worklist.append(parent) + return rv.keys() + + +class MutableLocalBase: + """ + Base class for Variable.mutable_local + """ + + def __init__(self, typ: MutableLocalSource): + # In HigherOrderOperator tracing, we need to distinguish + # between MutableLocals inside the HigherOrderOperator and + # ones outside it. For example, it is not safe to mutate + # `a` in the following example because it was constructed + # in a different scope. + # + # def f(x): + # a = 1 + # def g(x): + # nonlocal a + # a = 2 + # return x + # return wrap(g, x) + a + # + # We use self.scope to distinguish this. + # scope == 0: The object was an existing variable + # scope == 1: The object was created while Dynamo + # was introspecting a function + # (and no HigherOrderOps were involved) + # scope >= 2: The object was created through + # Dynamo introspection of a HigherOrderOp. + # The exact number corresponds to the level + # of nested HigherOrderOps. + if typ is MutableLocalSource.Existing: + self.scope = 0 + elif typ is MutableLocalSource.Local: + self.scope = current_scope_id() + else: + unimplemented(f"Unsupported MutableLocalSource: {typ}") + + +class MutableLocal(MutableLocalBase): + """ + Marker used to indicate this (list, iter, etc) was constructed in + local scope and can be mutated safely in analysis without leaking + state. + """ + + def __init__(self): + super().__init__(MutableLocalSource.Local) + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + +def _is_top_level_scope(scope_id): + return scope_id == 1 + + +def is_side_effect_safe(m: MutableLocalBase): + scope_id = current_scope_id() + + # In the top-level scope (if no HigherOrderOperators are involved), + # we are allowed to modify variables created in this scope as well + # as existing variables. + if _is_top_level_scope(scope_id): + return True + # Otherwise, only allow local mutation of variables created in the current scope + return m.scope == scope_id + + +class VariableTrackerMeta(type): + def __call__(cls, *args, **kwargs): + """Call __post_init__""" + obj = type.__call__(cls, *args, **kwargs) + obj.__post_init__(*args, **kwargs) + return obj + + def __instancecheck__(cls, instance) -> bool: + """Make isinstance work with LazyVariableTracker""" + if type.__instancecheck__( + variables.LazyVariableTracker, instance + ) and cls not in ( + VariableTracker, + variables.LazyVariableTracker, + ): + instance = instance.realize() + return type.__instancecheck__(cls, instance) + + +class VariableTracker(metaclass=VariableTrackerMeta): + """ + Base class for tracked locals and stack values + + VariableTracker instances are immutable and should be copied in + order to change them. 
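# Minimal standalone sketch (hypothetical helper, not the real implementation)
# of the rule encoded by MutableLocalBase.scope and is_side_effect_safe above:
# mutation is allowed freely in the top-level traced scope, otherwise only on
# objects created in the scope currently being traced.
def _toy_is_side_effect_safe(obj_scope: int, current_scope: int) -> bool:
    if current_scope == 1:  # top level, no HigherOrderOperators involved
        return True
    return obj_scope == current_scope

assert _toy_is_side_effect_safe(obj_scope=0, current_scope=1)      # pre-existing object, top level
assert _toy_is_side_effect_safe(obj_scope=2, current_scope=2)      # created inside the same HOO body
assert not _toy_is_side_effect_safe(obj_scope=1, current_scope=2)  # captured from an outer scope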
+ """ + + # fields to leave unmodified in apply() + _nonvar_fields = { + "value", + "guards", + "source", + "mutable_local", + "parents_tracker", + "user_code_variable_name", + } + + def clone(self, **kwargs): + """Shallow copy with some (optional) changes""" + args = dict(self.__dict__) + args.update(kwargs) + return self.__class__(**args) + + @classmethod + def copy(cls, value): + """Deeper (but not full) copy, leaving FX and user objects alone""" + return cls.apply(identity, value) + + @classmethod + def apply( + cls, + fn: Callable[["VariableTracker"], "VariableTracker"], + value, + cache=None, + skip_fn=lambda _: False, # Whether we should skip applying to this var + ): + """ + Walk this object and call fn on all the VariableTracker + instances + """ + if cache is None: + cache = dict() + + idx = id(value) + if idx in cache: + return cache[idx][0] + + if isinstance(value, VariableTracker): + if not skip_fn(value): + + def update_object_dict(v): + changed = False + rv = v.__dict__ + for key in rv.keys(): + if key not in v._nonvar_fields: + prior = rv[key] + rv[key] = cls.apply(fn, prior, cache, skip_fn) + changed = changed or prior is not rv[key] + + return v + + value = value.unwrap() + was_realized = value.is_realized() + result = fn(update_object_dict(value)) + if not was_realized and value.is_realized(): + # running fn() resulted in value getting realized, + # which means we missed updating the contents of result + result = update_object_dict(result.unwrap()) + else: + result = fn(value) + if result is not None: + result = result.unwrap() + elif istype(value, list): + result = [cls.apply(fn, v, cache, skip_fn) for v in value] + elif istype(value, tuple): + result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value) + elif istype(value, (dict, collections.OrderedDict)): + result = { + k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items()) + } + else: + result = value + + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = (result, value) + return result + + def __repr__(self): + return f"{self.__class__.__name__}()" + + def python_type(self): + """ + Abstract method to be implemented by subclasses of VariableTracker. + + This method should return the type represented by the instance of the subclass. + The purpose is to provide a standardized way to retrieve the Python type information + of the variable being tracked. + + Returns: + type: The Python type (such as int, str, list, etc.) of the variable tracked by + the subclass. If the type cannot be determined or is not relevant, + leaving it undefined or invoking super() is always sound. + + Note: + This is an abstract method and may be overridden in subclasses. + + Example: + class SetVariable(VariableTracker): + def python_type(self): + return set + + Raises: + NotImplementedError: If the method is not implemented in a subclass. 
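# Standalone sketch (hypothetical helper, not the real VariableTracker.apply) of
# the traversal pattern apply() above uses: recurse through list/tuple/dict
# containers, transform the leaves, and cache results by id() so aliased
# sub-objects are processed exactly once.
def _map_leaves(fn, value, cache=None):
    if cache is None:
        cache = {}
    key = id(value)
    if key in cache:
        return cache[key][0]
    if isinstance(value, list):
        result = [_map_leaves(fn, v, cache) for v in value]
    elif isinstance(value, tuple):
        result = tuple(_map_leaves(fn, v, cache) for v in value)
    elif isinstance(value, dict):
        result = {k: _map_leaves(fn, v, cache) for k, v in value.items()}
    else:
        result = fn(value)
    cache[key] = (result, value)  # keep `value` alive so its id() is not reused
    return result

_shared = [1, 2]
_out = _map_leaves(lambda x: x + 1, {"a": _shared, "b": _shared})
assert _out == {"a": [2, 3], "b": [2, 3]}
assert _out["a"] is _out["b"]  # the aliased input list maps to a single output list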
+ """ + raise NotImplementedError(f"{self} has no type") + + def as_python_constant(self): + """For constants""" + raise NotImplementedError(f"{self} is not a constant") + + def guard_as_python_constant(self): + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + try: + return self.as_python_constant() + except NotImplementedError as e: + unimplemented(str(e)) + + def is_python_constant(self): + try: + self.as_python_constant() + return True + except NotImplementedError: + return False + + def make_guard(self, fn): + if self.source: + return self.source.make_guard(fn) + raise NotImplementedError() + + def const_getattr(self, tx, name: str) -> Any: + """getattr(self, name) returning a python constant""" + raise NotImplementedError() + + def var_getattr(self, tx, name: str) -> "VariableTracker": + """getattr(self, name) returning a new variable""" + value = self.const_getattr(tx, name) + if not variables.ConstantVariable.is_literal(value): + raise NotImplementedError() + source = None + if self.source: + source = AttrSource(self.source, name) + return variables.ConstantVariable.create(value, source=source) + + def is_proxy(self): + try: + self.as_proxy() + return True + except NotImplementedError: + return False + + def as_proxy(self): + raise NotImplementedError(str(self)) + + def maybe_fx_node(self): + try: + proxy = self.as_proxy() + import torch.fx + + if isinstance(proxy, torch.fx.Proxy): + return proxy.node + return None + except NotImplementedError: + return None + + def reconstruct(self, codegen): + raise NotImplementedError() + + def can_reconstruct(self, tx): + """If it is possible to reconstruct the Python object this + VariableTracker represents.""" + assert tx is tx.output.root_tx, "Only root tx can reconstruct" + try: + from ..codegen import PyCodegen + + cg = PyCodegen(tx) + self.reconstruct(cg) + return True + except NotImplementedError: + return False + + def unpack_var_sequence(self, tx) -> List["VariableTracker"]: + raise NotImplementedError() + + def has_unpack_var_sequence(self, tx) -> bool: + try: + self.unpack_var_sequence(tx) + return True + except NotImplementedError: + return False + + def inspect_parameter_names(self) -> List[str]: + unimplemented(f"inspect_parameter_names: {self}") + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + unimplemented(f"hasattr {self.__class__.__name__} {name}") + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + unimplemented(f"call_function {self} {args} {kwargs}") + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__len__" and self.has_unpack_var_sequence(tx): + assert not (args or kwargs) + return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx))) + elif ( + name == "__getattr__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + ): + return self.var_getattr(tx, args[0].as_python_constant()) + raise unimplemented(f"call_method {self} {name} {args} {kwargs}") + + def rename(self, tx, name): + return self + + def realize(self) -> "VariableTracker": + """Used by LazyVariableTracker to build the real VariableTracker""" + return self + + def recursive_realize(self): + """Realize all objects under this""" + return VariableTracker.apply(lambda x: x.realize(), self) + + def unwrap(self) -> "VariableTracker": + """Used by LazyVariableTracker to 
return the real VariableTracker if it already exists""" + return self + + def is_realized(self): + """Used by LazyVariableTracker to indicate an unrealized node""" + return True + + def __init__( + self, + *, + source: Source = None, + mutable_local: MutableLocal = None, + parents_tracker: ParentsTracker = None, + ): + super().__init__() + self.source = source + self.mutable_local = mutable_local + self.parents_tracker = parents_tracker + + def __post_init__(self, *args, **kwargs): + if self.parents_tracker is None: + self.parents_tracker = ParentsTracker() + # visit children 1 level deep and ensure parent is set properly + VariableTracker.apply( + lambda node: node.parents_tracker.add(self.parents_tracker), + [v for k, v in self.__dict__.items() if k not in self._nonvar_fields], + skip_fn=lambda _: True, + ) + + +def typestr(*objs): + if len(objs) == 1: + (obj,) = objs + if isinstance(obj, VariableTracker): + return str(obj) + else: + return type(obj).__name__ + else: + return " ".join(map(typestr, objs)) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/builder.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4c0f423413637b4916192cab40da6defd0eadc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/builder.py @@ -0,0 +1,1976 @@ +# mypy: ignore-errors + +import abc +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import itertools +import logging +import operator +import re +import sys +import types +from typing import List, NamedTuple, Optional, Union + +from torch.utils._sympy.value_ranges import ValueRanges + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +import torch + +from torch import SymInt +from torch._guards import GuardSource, TracingContext +from torch._ops import HigherOrderOperator +from torch._streambase import _EventBase, _StreamBase +from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + DimDynamic, + RelaxedUnspecConstraint, + StatefulSymbolicContext, + SubclassSymbolicContext, + SymbolicContext, +) +from torch.fx.immutable_collections import immutable_list +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils.weak import TensorWeakRef +from .. 
import config, mutation_guard, replay_record, trace_rules + +from ..device_interface import get_registered_device_interfaces +from ..exc import InternalTorchDynamoError, unimplemented +from ..guards import GuardBuilder, install_guard, make_dupe_guard +from ..side_effects import SideEffects +from ..source import ( + AttrSource, + ConstantSource, + ConstDictKeySource, + ConvertIntSource, + GetItemSource, + is_constant_source, + is_from_defaults, + LocalSource, + NumpyTensorSource, + RandomValueSource, + Source, + TupleIteratorGetItemSource, +) +from ..trace_rules import is_callable_allowed, is_numpy +from ..utils import ( + build_checkpoint_variable, + clone_input, + common_constant_types, + get_fake_value, + get_static_address_type, + is_function_or_wrapper, + is_namedtuple, + is_typing, + is_utils_checkpoint, + istype, + odict_values, + preserve_rng_state, + tensor_always_has_static_shape, + tuple_iterator, + tuple_iterator_getitem, + tuple_iterator_len, + unwrap_with_attr_name_if_wrapper, + wrap_fake_exception, +) + +from .base import MutableLocal, typestr, VariableTracker +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + AutocastModeVariable, + EventVariable, + NullContextVariable, + PreserveVersionContextVariable, + StreamContextVariable, + StreamVariable, +) +from .dicts import ( + ConstDictVariable, + DataClassVariable, + DefaultDictVariable, + HFPretrainedConfigVariable, + PythonSysModulesVariable, + SetVariable, +) +from .distributed import ( + DeviceMeshVariable, + PlacementClassVariable, + PlacementVariable, + ProcessGroupVariable, +) +from .functions import ( + CollectiveFunctionRewriteVariable, + FunctoolsPartialVariable, + TritonKernelVariable, + UserMethodVariable, +) +from .higher_order_ops import TorchHigherOrderOperatorVariable +from .iter import ItertoolsVariable +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + RestrictedListSubclassVariable, + SizeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradFunctionContextVariable, + AutogradFunctionVariable, + ComptimeVariable, + DebuggingVariable, + GetAttrVariable, + GetSetDescriptorVariable, + InspectSignatureVariable, + LambdaVariable, + MethodWrapperVariable, + NumpyVariable, + PythonModuleVariable, + SavedTensorBox, + TypingVariable, +) +from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable +from .optimizer import OptimizerVariable + +from .sdpa import SDPAParamsVariable +from .tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorSubclassVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable +from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable +from .user_defined import ( + KeyedJaggedTensorVariable, + UserDefinedClassVariable, + UserDefinedObjectVariable, +) + + +log = logging.getLogger(__name__) + + +DimList = List + + +class _missing: + pass + + +@dataclasses.dataclass +class GraphArg: + source: Source + # TODO: storing a SymInt here but not a FakeTensor is a pretty strange + # thing to do. Probably should have example (which stores an int) and + # fake_example + _example: Union[TensorWeakRef, torch.SymInt] + is_unspecialized: bool + fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] + # UnspecializedPythonVariable often masquerades as a tensor. 
+ # We MUST NOT generate shape guard code + # that actually tries to access tensor properties on these values. + # is_tensor lets us tell if this graph arg actually is a tensor + # or not. + is_tensor: bool = True + # Sometimes, the Tensor we pass to example is freshly allocated (smh). + # Then we cannot only keep a weak reference to it. This lets you + # stash a strong reference too. + example_strong_ref: Optional[torch.Tensor] = None + + @property + def example(self): + if isinstance(self._example, TensorWeakRef): + r = self._example() + assert r is not None + return r + else: + return self._example + + def __post_init__(self): + if isinstance(self._example, torch.Tensor): + self._example = TensorWeakRef(self._example) + assert is_fake(self.fake_tensor) + + def reconstruct(self, codegen): + self.source.reconstruct(codegen) + + def erase(self): + self._example = None + self.example_strong_ref = None + + def __eq__(self, other): + return self.source.name() == other.source.name() + + +class BackwardStateGraphArg(GraphArg): + def __init__(self): + super().__init__( + source=None, + _example=BackwardState(), + is_unspecialized=False, + fake_tensor=None, + is_tensor=False, + ) + + def reconstruct(self, codegen): + assert codegen.tx.output.backward_state_var + codegen.load_import_from(BackwardState.__module__, "BackwardState") + codegen.call_function(0, True) + codegen.dup_top() + codegen.store(codegen.tx.output.backward_state_var) + + +@dataclasses.dataclass +class FrameStateSizeEntry: + scalar: Optional[int] + size: Optional[List[int]] + + +class VariableBuilder: + """Wrap a python value in a VariableTracker() instance""" + + def __init__( + self, + tx, + source: Source, + ): + assert ( + source is not None + ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally." + assert TracingContext.try_get() is not None, "Expected active TracingContext" + super().__init__() + self.tx = tx + self.source = source + self.name = source.name() + + def __call__(self, value): + if value in self.tx.output.side_effects: + side_effect_result = self.tx.output.side_effects[value] + dup_guard = make_dupe_guard(self.source, side_effect_result.source) + if dup_guard: + self.install_guards(dup_guard) + return side_effect_result + vt = self._wrap(value) + vt.source = self.source + if self._can_lift_attrs_to_inputs(vt): + vt = self.tx.output.side_effects.track_object_existing(value, vt) + return vt + + def _can_lift_attrs_to_inputs(self, vt): + if type(vt) in [ + TensorVariable, + TensorWithTFOverrideVariable, + UserDefinedObjectVariable, + NumpyNdarrayVariable, + ]: + return True + return False + + @staticmethod + @functools.lru_cache(None) + def _common_constants(): + return { + # We zero-one specialize shapes, so specialize these constants + # too + 0, + 1, + # NB: There used to be more constants here, but honestly it was + # pretty confusing. Note we specialize floats by default, and + # DON'T specialize ints by default. 
This all only matters with + # dynamic_shapes + } + + def get_source(self): + return self.source + + def install_guards(self, *guards): + source = self.get_source() + if ( + isinstance(source, ConstantSource) + or source.guard_source() == GuardSource.CONSTANT + ): + return None + install_guard(*[source.make_guard(guard) for guard in guards], skip=1) + return {} + + def set_source_and_track_mutable(self, value, var): + assert isinstance(var, VariableTracker) + var.source = self.source + return self.tx.output.side_effects.track_mutable(value, var) + + @classmethod + @functools.lru_cache(None) + def _type_dispatch(cls): + # NB: Careful not to close over self to avoid ref cycle from lru_cache + entries = [ + ( + ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + cls.wrap_tensor, + ), + ( + (tuple, list, odict_values, collections.deque, torch.Size), + cls.wrap_listlike, + ), + (tuple_iterator, cls.wrap_tuple_iterator), + ((slice, range), cls.wrap_slice_range), + (tuple(common_constant_types), cls.wrap_literal), + ] + + if config.trace_numpy and np: + entries.append((np.ndarray, cls.wrap_numpy_ndarray)) + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, tuple) else (ts,): + assert t not in result + result[t] = fn + + return result + + @classmethod + @functools.lru_cache(None) + def _id_dispatch(cls): + from ..comptime import comptime + + entries = [ + ( + inspect.signature, + lambda self, value: LambdaVariable( + InspectSignatureVariable.create, + source=self.source, + **self.install_guards(GuardBuilder.CLOSURE_MATCH), + ), + ), + (comptime, lambda self, value: ComptimeVariable()), + ( + dataclasses.fields, + lambda self, value: LambdaVariable( + _dataclasses_fields_lambda, + source=self.source, + **self.install_guards(GuardBuilder.FUNCTION_MATCH), + ), + ), + ] + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, (tuple, list)) else (ts,): + assert t not in result + result[id(t)] = fn + + return result + + def _wrap(self, value): + # import here to avoid circular dependencies + from torch.utils._triton import has_triton + + if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + else: + + class JITFunction: + pass + + class Autotuner: + pass + + # Handle exact type() match + type_dispatch = self._type_dispatch().get(type(value)) + if type_dispatch is not None: + return type_dispatch(self, value) + + # Handle exact id() match + id_dispatch = self._id_dispatch().get(id(value)) + if id_dispatch is not None: + return id_dispatch(self, value) + + # Note - There are some nested values where types mismatch! + # We want to get those out and wrap those. + value = inspect.getattr_static(value, "_torchdynamo_inline", value) + + # Everything else (NB: order matters!) + if is_traceable_wrapper_subclass(value) or istype( + value, config.traceable_tensor_subclasses + ): + return self.wrap_tensor(value) + elif is_namedtuple(value): + return self.wrap_listlike(value) + + elif value is torch.utils._pytree.SUPPORTED_NODES: + # For SUPPORTED_NODES, we guard on the dictionary version (PEP509) + # under the assumption that the values themselves don't change. 
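# Standalone sketch (hypothetical handlers, not the real wrap_* methods) of the
# exact-type dispatch used by _type_dispatch()/_wrap() above: build a dict keyed
# by concrete type once, route values by type(value), and fall back to the
# slower isinstance/special-case chain only when there is no exact match.
_TOY_DISPATCH = {
    int: lambda v: f"wrapped int {v}",
    list: lambda v: f"wrapped list of length {len(v)}",
}

def _toy_wrap(value):
    handler = _TOY_DISPATCH.get(type(value))  # exact type only, subclasses do not match
    if handler is not None:
        return handler(value)
    return f"fallback for {type(value).__name__}"

assert _toy_wrap(3) == "wrapped int 3"
assert _toy_wrap(True) == "fallback for bool"  # bool subclasses int but is not an exact match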
+ self.install_guards(GuardBuilder.DICT_VERSION) + result = { + ConstantVariable.create(k): UserDefinedObjectVariable( + v, + source=GetItemSource( + self.get_source(), ConstDictKeySource(self.get_source(), i) + ), + ) + for i, (k, v) in enumerate(value.items()) + } + return ConstDictVariable(result, type(value)) + elif value is sys.modules: + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return PythonSysModulesVariable(source=self.source) + elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): + if not value and self.get_source().is_nn_module(): + # It is faster to guard on 'false' property than to guard + # on actual dict keys, but we can't do this fast guard in general because + # it omits a crucial type check that ensures the value is actually still a dict at runtime. + + # Why is this OK for (specialized) nnmodules? We set up a setattr hook + # to check for module property mutations, which does a reasonable, + # but not completely secure job ensuring a property wasn't changed. + self.install_guards(GuardBuilder.BOOL_FALSE) + else: + self.install_guards(GuardBuilder.DICT_LENGTH) + + # Optimisation for the common case strings, ints, etc + all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) + if all_const: + self.install_guards(GuardBuilder.DICT_CONST_KEYS) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + if all_const: + key = ConstantVariable.create(k) + source_key = k + else: + source_key = ConstDictKeySource(self.get_source(), i) + key = LazyVariableTracker.create(k, source_key) + + source_value = GetItemSource(self.get_source(), source_key) + value = LazyVariableTracker.create(v, source_value) + + return key, value + + result = dict( + build_key_value(i, k, v) for i, (k, v) in enumerate(value.items()) + ) + + if istype(value, collections.defaultdict): + factory_source = AttrSource(self.source, "default_factory") + result = DefaultDictVariable( + result, + type(value), + default_factory=VariableBuilder(self.tx, factory_source)( + value.default_factory + ), + source=self.source, + ) + else: + result = ConstDictVariable(result, type(value), source=self.source) + + return self.set_source_and_track_mutable(value, result) + elif isinstance(value, torch.nn.Module): + return self.wrap_module(value) + elif ConstantVariable.is_literal(value): # non-atomic literals + return self.wrap_literal(value) + elif istype(value, frozenset) and ( + ConstantVariable.is_literal(x) for x in value + ): + # For frozenset, we can guard by object ID instead of value + # equality, this allows us to handle non-literal values + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantVariable.create(value=value, source=self.source) + elif isinstance(value, enum.Enum): + self.install_guards(GuardBuilder.ID_MATCH) + return EnumVariable(value=value, source=self.source) + elif DebuggingVariable.is_reorderable_logging_function(value): + # Put this above builtin_callable so that print() can be handled + # along with other builtin debugging functions + self.install_guards(GuardBuilder.BUILTIN_MATCH) + return DebuggingVariable(value, source=self.source) + elif is_utils_checkpoint(value): + return build_checkpoint_variable(source=self.source) + elif isinstance(value, functools.partial): + func_src = AttrSource(self.get_source(), "func") + func_obj = VariableBuilder(self.tx, func_src)(value.func) + + args = [] + args_source = AttrSource(self.get_source(), "args") + for i, arg in 
enumerate(value.args): + args.append( + VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) + ) + + keywords = {} + keywords_source = AttrSource(self.get_source(), "keywords") + for k, v in value.keywords.items(): + if not ConstantVariable.is_literal(k): + unimplemented("functools.partial with non-literal keyword") + keywords[k] = VariableBuilder( + self.tx, GetItemSource(keywords_source, k) + )(v) + + install_guard( + self.get_source().make_guard(GuardBuilder.TYPE_MATCH), + keywords_source.make_guard(GuardBuilder.DICT_KEYS), + args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + ) + return FunctoolsPartialVariable(func_obj, args, keywords) + elif is_typing(value): + # typing.List, typing.Mapping, etc. + self.install_guards(GuardBuilder.ID_MATCH) + return TypingVariable( + value, + source=self.source, + ) + elif np is not None and isinstance(value, np.generic): + # numpy array scalars: convert to 0D arrays + return self.wrap_numpy_ndarray(np.asarray(value)) + elif is_numpy(value): + assert np + self.install_guards( + GuardBuilder.FUNCTION_MATCH + if callable(value) + else GuardBuilder.TYPE_MATCH + ) + return NumpyVariable(value, source=self.source) + # NB: These can't be put in type_dispatch, they have to run later + elif CollectiveFunctionRewriteVariable.can_rewrite(value): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return CollectiveFunctionRewriteVariable.create( + self.tx, + value, + source=self.source, + ) + elif istype(value, torch.autograd.function.FunctionMeta): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return AutogradFunctionVariable( + value, + source=self.source, + ) + elif isinstance(value, torch.autograd.function.FunctionCtx): + saved_tensors_source = AttrSource(self.source, "saved_tensors") + install_guard( + self.source.make_guard(GuardBuilder.TYPE_MATCH), + saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + ) + saved_tensors = [ + VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v) + for n, v in enumerate(value.saved_tensors) + ] + return self.tx.output.side_effects.track_object_existing( + value, + AutogradFunctionContextVariable( + value, + source=self.source, + saved_tensors=SavedTensorBox(saved_tensors), + ), + ) + elif ( + isinstance(value, types.MethodType) + and istype( + getattr(value, "__self__", None), torch.autograd.function.FunctionMeta + ) + and getattr(value, "__name__", "") == "apply" + and value == getattr(value.__self__, "apply", None) + ): + # handle aliased autograd function `apply` calls + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return GetAttrVariable( + AutogradFunctionVariable( + value.__self__, source=AttrSource(self.source, member="__self__") + ), + "apply", + ) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if is_callable_allowed(value): + self.tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value).create_with_source( + value, source=self.source + ) + elif np and isinstance(value, np.number): + return self.wrap_unspecialized_primitive(value) + elif DataClassVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return DataClassVariable.wrap(self, value) + elif HFPretrainedConfigVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return HFPretrainedConfigVariable(value) + elif isinstance(value, HigherOrderOperator): + self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH) + return TorchHigherOrderOperatorVariable.make(value, 
source=self.source) + elif isinstance(value, torch.cuda.StreamContext): + self.install_guards(GuardBuilder.ID_MATCH) + stream_source = AttrSource(self.source, "stream") + stream_var = VariableBuilder(self.tx, stream_source)(value.stream) + return StreamContextVariable.create(self.tx, stream_var) + elif isinstance(value, _StreamBase): + self.install_guards(GuardBuilder.ID_MATCH) + return StreamVariable( + None, + value, + value.device, + source=self.source, + ) + elif isinstance(value, (torch._C._SDPAParams)): + self.install_guards(GuardBuilder.TYPE_MATCH) + return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, _EventBase): + self.install_guards(GuardBuilder.ID_MATCH) + return EventVariable( + None, + value, + source=self.source, + ) + elif ( + isinstance(value, torch._C._TensorMeta) + and value in config.traceable_tensor_subclasses + ): + return TensorSubclassVariable(value, source=self.source) + elif ( + istype(value, contextlib.nullcontext) + and inspect.getattr_static(value, "enter_result", None) is None + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return NullContextVariable(source=self.source) + elif KeyedJaggedTensorVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = KeyedJaggedTensorVariable(value, source=self.source) + # TODO: this doing it manually is bad + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, torch.optim.Optimizer): + self.install_guards(GuardBuilder.TYPE_MATCH) + return OptimizerVariable(value, source=self.source) + elif ProcessGroupVariable.is_process_group(value): + self.install_guards(GuardBuilder.ID_MATCH) + return ProcessGroupVariable(value, source=self.source) + elif DeviceMeshVariable.is_device_mesh(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return DeviceMeshVariable(value, source=self.source) + elif PlacementClassVariable.is_placement_type(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return PlacementClassVariable(value, source=self.source) + elif PlacementVariable.is_placement(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return PlacementVariable( + value, + source=self.source, + ) + elif istype(value, type) and value in itertools.__dict__.values(): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return ItertoolsVariable(value, source=self.source) + elif isinstance(value, torch.SymBool): + # Note: the idea here is to re-use the infra we've built for SymInt by simulating the + # user provided SymBool with a SymInt in dynamo. + + # Concretely, + # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). + # so that guards on the SymInts can be effectively applied on the original SymBool in user program. + # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program + # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly. 
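# Conceptual, standalone sketch (hypothetical helper, no ShapeEnv involved) of
# the SymBool strategy described in the comment above: model the user's bool as
# an integer restricted to {0, 1}, do the symbolic reasoning and guarding on the
# integer, and hand the program the boolean view `i == 1`.
def _int_backed_bool(flag: bool):
    i = int(flag)        # stand-in for the SymInt allocated in the shape env
    return i, (i == 1)   # (integer to guard on, bool view seen by the user code)

assert _int_backed_bool(True) == (1, True)
assert _int_backed_bool(False) == (0, False)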
+ + value_hint = value.node.require_hint() + new_source = ConvertIntSource(self.source) + + new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol( + int(value_hint), + new_source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(new_symint), + source=new_source, + ) + + sym_node_proxy.node.meta["grapharg"] = GraphArg( + new_source, + new_symint, + False, + None, + is_tensor=False, + example_strong_ref=new_symint, + ) + self.tx.output.bound_symbols.add(new_symint.node.expr) + self.tx.output.tracked_fakes.append( + TrackedFake(new_symint, new_source, None) + ) + return SymNodeVariable( + sym_node_proxy, + new_symint == 1, + ) + elif isinstance(value, (JITFunction, Autotuner)): + self.install_guards(GuardBuilder.ID_MATCH) + return TritonKernelVariable( + value, + None, # No kernel idx provided + None, # No grid provided + source=self.source, + ) + elif isinstance(value, torch.amp.autocast_mode.autocast): + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) + elif TorchCtxManagerClassVariable.is_matching_cls(value): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return TorchCtxManagerClassVariable(value, source=self.source) + elif is_function_or_wrapper(value): + value, attr_name = unwrap_with_attr_name_if_wrapper(value) + # For these wrappers, Dynamo points to the wrapped function, + # so source needs to be updated as well. + if attr_name is not None: + self.source = AttrSource(self.source, attr_name) + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + # Don't use istype, since some python modules are not subclasses of types.ModuleType directly. + # E.g, type(torch.ops) -> , + # type(torch.backends.cudnn) -> + elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return PythonModuleVariable( + value, + source=self.source, + ) + elif isinstance(value, types.MethodType) and isinstance( + value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) + ): + # don't let MethodTypes fall through to UserDefinedObject, + # which doesn't support 'CALL_FUNCTION' + + # TODO(whc): Why do we limit this to methods on NNModules? + # I don't have a good reason for this, but it preserves the existing behavior + # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. + # I suspect we probably want to relax this check and dig deeper there. + + # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, + # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here + # and then `__func__` gets wrapped inside UserMethodVariable. 
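# Standalone reminder (illustrative only, using public torch APIs) of the Python
# fact the comment above relies on: a bound method is just a function plus a
# captured `self`, so the two pieces can be wrapped separately.
import types
import torch

_mod = torch.nn.Linear(2, 2)
_bound = _mod.forward
assert isinstance(_bound, types.MethodType)
assert _bound.__self__ is _mod
assert _bound.__func__ is torch.nn.Linear.forward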
+ self_obj = VariableBuilder( + self.tx, source=AttrSource(self.source, "__self__") + )(value.__self__) + assert self_obj and isinstance( + self_obj, VariableTracker + ), "Failed to produce a valid self obj" + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return UserMethodVariable( + value.__func__, + self_obj, + source=self.source, + ) + elif isinstance(value, types.GetSetDescriptorType): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return GetSetDescriptorVariable(value) + elif isinstance(value, types.MethodWrapperType): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return MethodWrapperVariable(value) + elif issubclass(type(value), type): + if value in (torch.utils.hooks.BackwardHook, torch.nn.Parameter): + # TODO(jansel): combine this case with the one above + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + if value is torch.autograd._unsafe_preserve_version_counter: + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return PreserveVersionContextVariable.constructor(self.tx) + # This is a userdefined class, so install an ID_MATCH even if its a + # global variable. + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedClassVariable( + value, + source=self.source, + ) + elif RestrictedListSubclassVariable.is_matching_cls(type(value)): + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + return self.set_source_and_track_mutable( + value, + RestrictedListSubclassVariable( + [ + LazyVariableTracker.create( + value=value[i], source=GetItemSource(self.source, i) + ) + for i in range(len(value)) + ], + user_cls=type(value), + user_cls_source=AttrSource(self.source, "__class__"), + ), + ) + else: + self.install_guards(GuardBuilder.TYPE_MATCH) + result = UserDefinedObjectVariable(value, source=self.source) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + + def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): + if config.specialize_int and type(value) is torch.Size: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + # One can index a tensor with a list/tuple. Therefore, we need to + # have a stricter match. 
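# Tiny standalone illustration (not part of the upstream file) of why the
# comment above asks for a stricter, length-aware guard: indexing a tensor with
# a Python list is length-sensitive, so a graph specialized for one list would
# be wrong for a list of a different length.
import torch

_t = torch.arange(6).reshape(2, 3)
assert _t[[0, 1]].shape == (2, 3)     # two rows selected
assert _t[[0, 1, 1]].shape == (3, 3)  # a longer index list changes the output shape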
+ self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + for item in value: + if item is value: + unimplemented("list elements are pointing to the list itself") + + output = [ + LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i)) + for i, item in enumerate(value) + ] + + result = BaseListVariable.cls_for_instance(value)( + output, mutable_local=MutableLocal() + ) + if istype(value, list): + return self.set_source_and_track_mutable(value, result) + return result + + def wrap_tuple_iterator(self, value: tuple_iterator): + self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN) + output = [ + VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( + tuple_iterator_getitem(value, i) + ) + for i in range(tuple_iterator_len(value)) + ] + result = TupleIteratorVariable( + output, mutable_local=MutableLocal(), source=self.source + ) + + return self.set_source_and_track_mutable(value, result) + + def wrap_slice_range(self, value: Union[slice, range]): + items = [ + VariableBuilder(self.tx, AttrSource(self.get_source(), k))( + getattr(value, k) + ) + for k in ("start", "stop", "step") + ] + self.install_guards(GuardBuilder.TYPE_MATCH) + if isinstance(value, slice): + return SliceVariable(items, source=self.source) + else: + return RangeVariable(items, source=self.source) + + def wrap_module(self, value: torch.nn.Module): + from ..eval_frame import OptimizedModule + + if istype(value, OptimizedModule): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.source = AttrSource(self.source, "_orig_mod") + return self.wrap_module(value._orig_mod) + + if ( + isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) + and not config.allow_rnn + ): + unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") + if mutation_guard.is_dynamic_nn_module(value): + # created dynamically, don't specialize on it + self.install_guards(GuardBuilder.TYPE_MATCH) + result = UnspecializedNNModuleVariable(value, source=self.source) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass( + value.__class__, torch.nn.parallel.distributed.DistributedDataParallel + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return UnspecializedNNModuleVariable(value) + elif getattr(value, "_is_fsdp_managed_module", False): + # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + # in fully_sharded_data_parallel.py for more information + + # we can't do this assert inside FSDP constructor, + # since we don't know yet whether dynamo will be used + assert getattr( + value, "_fsdp_use_orig_params", False + ), "Dynamo only supports FSDP with use_orig_params=True" + + # Note on FSDP guarding + # 1. We expect FSDP wrapping mutates an nn module irreversably (no way to de-wrap). + # 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their + # model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams. + # + # Due to (1), once we enter this path we expect not to go back nor have to guard on type + # or _is_fsdp_managed_module. + # + # TODO(whc) We could add a guard on the opposite case, where a user compiled/ran + # pre-FSDP-wrapped model, then wrapped, to ensure that we recompile with the FSDP handling. 
+ # + # Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the + # guard source. This behavior is gated on config.skip_fsdp_guards. + # + # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps + # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager) + self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH) + return FSDPManagedNNModuleVariable(value, source=self.get_source()) + else: + return self.tx.output.register_attr_or_module( + value, + self.name, + source=self.get_source(), + # Guards are added inside register_attr_or_module + ) + + def wrap_literal(self, value): + unspec = not config.specialize_int + if unspec and type(value) is int: + # unspecializing int by default, but still + # specialize for the following conditions + if not TracingContext.get().force_unspec_int_unbacked_size_like and ( + value in self._common_constants() + # Assume integers from global variables want to be specialized + or not self.source.guard_source().is_local() + # Assume that integers that came from NN modules want to be + # specialized (as we don't expect users to be changing the + # NN modules on the fly) + or self.source.guard_source().is_nn_module() + or is_from_defaults(self.source) + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + else: + return self.wrap_unspecialized_primitive(value) + else: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): + if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: + raise InternalTorchDynamoError( + "Cannot wrap a Tensor that has already been", + "wrapped by this instance of Dynamo", + ) + + def wrap_tensor(self, value: torch.Tensor): + source = self.get_source() + + # We cannot already be tracking the tensor, which implies + # it would have already been wrapped + assert value not in self.tx.output.side_effects + + if ( + source.guard_source().is_nn_module() + or get_static_address_type(value) is not None + ) and not source.guard_source().is_fsdp_module(): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if is_constant_source(source): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=source, + # Guards are added inside register_attr_or_module + ) + + if type(value) in config.traceable_tensor_subclasses: + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. 
Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. + subclass_type = type(value) + else: + assert type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value), type(value) + subclass_type = None + + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + # By this point, we should have deduplicated all tensors + self.assert_not_wrapped_by_this_graph(value) + + # tx.output has multiple tracers if we're introspecting HigherOrderOperator. + # When we've discovered an untracked tensor, then we actually need + # to get Dynamo to track the tensor (which is what this function does) + # and put it as a graph input on the root tracer. Later on, + # if the input is actually used in the body of the HigherOrderOperator, + # then the relevant SubgraphTracer will lift it to being an input of + # the subgraph. + # See NOTE [HigherOrderOperator tracing design] for more details. + + tensor_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source + ) + options = {} + if type(value) in config.traceable_tensor_subclasses: + options["torch_function_fn"] = build_torch_function_fn( + self.tx, value, self.source + ) + self.install_guards(GuardBuilder.TYPE_MATCH) + + if ( + isinstance(value, torch.Tensor) + and value.is_nested + and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) + ): + unimplemented("torch.compile does not support strided NestedTensor") + + if is_sparse_any(value): + unimplemented( + f"torch.compile does not support sparse Tensor with {value.layout} layout" + ) + + tensor_variable = wrap_fx_proxy( + tx=self.tx, + proxy=tensor_proxy, + example_value=value, + subclass_type=subclass_type, + source=source, + **options, + ) + + self.install_guards( + functools.partial( + GuardBuilder.TENSOR_MATCH, + value=value + if isinstance(source, NumpyTensorSource) + else TensorWeakRef(value), + ) + ) + + # We install TYPE_MATCH guards for traceable wrapper subclass object, + # and recursively install corresponding guard for each inner attribute. 
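# Standalone illustration (not part of the upstream file) of why the
# TENSOR_MATCH guard installed above stores a TensorWeakRef rather than the
# tensor itself: the weak reference resolves while the user still holds the
# tensor, but does not keep it alive on its own.
import torch
from torch.utils.weak import TensorWeakRef

_t = torch.randn(2)
_ref = TensorWeakRef(_t)
assert _ref() is _t    # resolvable while the tensor is alive
del _t
assert _ref() is None  # the guard machinery alone does not extend the tensor's lifetime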
+ if is_traceable_wrapper_subclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + attrs, _ = value.__tensor_flatten__() + for attr in attrs: + inner_value = getattr(value, attr) + inner_source = AttrSource(self.source, attr) + VariableBuilder(self.tx, inner_source)(inner_value).recursive_realize() + + self.tx.output.input_source_to_var[source] = tensor_variable + assert "tensor_dict" not in tensor_proxy.node.meta + tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy() + + # Note: this information is conveyed via subclass_type now + fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] + if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: + raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") + + grapharg = GraphArg(source, value, False, fake_tensor_value) + tensor_proxy.node.meta["grapharg"] = grapharg + self.tx.output.add_symbol_bindings(grapharg) + return tensor_variable + + def wrap_numpy_ndarray(self, value): + assert np is not None + assert isinstance(value, np.ndarray) + + source = NumpyTensorSource(self.get_source()) + + from torch._numpy import _util + + readonly = not value.flags.writeable + if readonly: + try: + value.flags.writeable = True + except ValueError: + # One can not easily make nditer elements writable, + # but warning is not the end of the world + assert isinstance(value.base, np.nditer) + pass + + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) + + # We do this because we want the full behavior of guarding the numpy ndarray as if it were + # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here + # that there's not another great way to do this atm. + # This creates the right graphargs, as well as registration for guards in tensor names and shape env. + VariableBuilder(self.tx, source)(tensor_value).recursive_realize() + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source + ) + options = {"source": source} + numpy_ndarray_variable = wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=self.tx, + proxy=proxy, + example_value=tensor_value, + **options, + ) + + self.tx.output.input_source_to_var[source] = numpy_ndarray_variable + example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] + + # is_unspecialized should be true because we are wrapping a np.ndarray as argument input, and it needs to be + # converted to a tensor. 
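+        # Rough sketch of the flow here (hypothetical input): an argument such as
+        # np.ones(3) has already been converted to a torch.Tensor above, registered
+        # as a graph input, and wrapped as a NumpyNdarrayVariable; the GraphArg
+        # below records that the original argument was an ndarray that needed this
+        # conversion.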
+ grapharg = GraphArg( + source, + tensor_value, + is_unspecialized=True, + fake_tensor=example_value, + is_tensor=True, + example_strong_ref=tensor_value, + ) + proxy.node.meta["grapharg"] = grapharg + + return numpy_ndarray_variable + + def wrap_unspecialized_primitive(self, value): + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + else: + shape_env = self.tx.output.shape_env + if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance( + value, int + ): + wrapped_value = shape_env.create_unbacked_symint() + _constrain_range_for_size(wrapped_value) + self.tx.output.bound_symbols.add(wrapped_value.node.expr) + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, None) + ) + + # NB: We do not do float. For motivation, see + # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit + # but the general idea is that we generate kernels that can + # take unspecialized floats and use them in sizevar computation + elif ( + isinstance(value, int) + and not is_constant_source(self.get_source()) + and not isinstance(self.get_source(), RandomValueSource) + ): + if torch._dynamo.config.specialize_int: + # If specialize_int is False, also return + # a constant (but this should have been handled + # in the caller, TBH) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + name = self.source.name() + if name not in self.tx.output.frame_state: + # Note - this essentially means that if this name gets reused as a tensor, + # it will start fully dynamic. That should always be a safe option, and not awfully inefficient. + # Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not + # sure that is necessary for now. 
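+                    # Sketch of the recompile behavior this frame state drives,
+                    # assuming default dynamo configs (hypothetical user code):
+                    #
+                    #     @torch.compile
+                    #     def f(x, n):
+                    #         return x + n
+                    #
+                    #     f(torch.ones(3), 5)  # first call: scalar recorded as 5
+                    #     f(torch.ones(3), 7)  # mismatch below clears scalar -> n becomes dynamic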
+ frame_state_entry = FrameStateSizeEntry(scalar=value, size=None) + else: + frame_state_entry = self.tx.output.frame_state[name] + if frame_state_entry.scalar != value: + log.debug( + "automatic dynamic int %s val %s != %s", + name, + value, + frame_state_entry.scalar, + ) + frame_state_entry.scalar = None + self.tx.output.frame_state[name] = frame_state_entry + + # TODO: This should be dynamic, as we in general do not + # know if bare integers are actually going to be sizevars + # and it is inappropriate to eagerly duck size them with + # real sizevars + if ( + config.automatic_dynamic_shapes and frame_state_entry.scalar is None + ) or not config.assume_static_by_default: + dynamic_dim = DimDynamic.DYNAMIC + else: # assume_static_by_default + # TODO: dynamic_dim = DimDynamic.STATIC should work but + # for some reason it doesn't + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + wrapped_value = shape_env.create_unspecified_symint_and_symbol( + value, + source=self.source, + dynamic_dim=dynamic_dim, + ) + self.tx.output.bound_symbols.add(wrapped_value.node.expr) + + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, None) + ) + else: + wrapped_value = torch.tensor(value) + if not isinstance(self.get_source(), RandomValueSource): + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + options = {"source": self.get_source()} + if isinstance(wrapped_value, torch.Tensor): + options.update({"raw_value": value}) + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + source=self.get_source(), + ) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=wrapped_value, + **options, + ) + self.tx.output.unspec_variable_map[self.name] = unspec_var + if not is_constant_source(self.get_source()): + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + "Dynamo attempts to add additional input during export: value={}, source={}".format( + wrapped_value, self.get_source() + ) + ) + fake_tensor_value = None + if isinstance(unspec_var, ConstantVariable): + example_value = unspec_var.value + else: + example_value = unspec_var.proxy.node.meta["example_value"] + if is_fake(example_value): + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + isinstance(wrapped_value, torch.Tensor), + fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + return unspec_var + + +def _dataclasses_fields_lambda(obj): + if isinstance(obj, UserDefinedObjectVariable): + value = obj.value + elif isinstance(obj, DataClassVariable): + value = obj.user_cls + else: + unimplemented(f"Dataclass fields handling fails for type {obj}") + items = [] + for field in dataclasses.fields(value): + source = None + if obj.source: + source = GetItemSource( + AttrSource(obj.source, "__dataclass_fields__"), field.name + ) + items.append(UserDefinedObjectVariable(field, source=source)) + return TupleVariable(items) + + +def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options): + kwargs = { + "tx": tx, + "proxy": proxy, + "example_value": example_value, + "subclass_type": 
subclass_type, + **options, + } + if subclass_type is None: + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + else: + result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) + result.install_global(tx) + return result + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +# +# This is a horribly complicated function that does too many things, to +# explain what it does, let's first talk about the classic usage wrap_fx_proxy +# for a TensorVariable. There are two primary modes of use: +# +# 1. Wrapping a pre-existing Tensor. In this case, example_value is set +# to the pre-existing Tensor. (Note that this example_value will NOT +# be the final example_value we put into node.meta['example_value'], +# instead it is converted into a fake tensor using +# wrap_to_fake_tensor_and_record and registered as a graph input.) +# +# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In +# this case, example_value is None (and we are going to figure it out +# ourselves using FakeTensors, via get_fake_value, which will run +# the operation represented by the (singular!) FX node referenced by +# the passed in proxy.) +# +# The expectation is you end up with a Tensor output, and everything is +# straightforwardly traced into the graph. +# +# In all cases, the returned `TensorVariable` subclass will have an `example_value` +# and that `example_value` must be a `FakeTensor` produced by the currently running +# instance of Dynamo. +# +# Upon closer inspection, you may notice that there are a slurry of non-Tensor +# output cases. What gives? Well, we sometimes trace operations into the +# graph that don't involve tensors. +# +# * Some operators return tuples; we need to recursively handle their +# contents +# +# * Some operators have side effects that will affect subsequent AOTAutograd +# tracing but don't otherwise return anything. +# +# * Some operators return symbolic ints/floats/bools which can go in the +# graph and be traced (but only if they're actually symbolic! If they're +# static you don't want to put them in the graph, which means you +# shouldn't call this function.) +# +# The common theme is that you only use this function WHEN YOU ARE TRACING +# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call +# this function without a proxy. +def wrap_fx_proxy_cls( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" + + initial_example_value = example_value + + def _clone_input(value): + if isinstance(value, torch.Tensor): + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not ( + isinstance(value, FakeTensor) + or ( + # Is functional tensor fakeified by this instance of Dynamo + torch._is_functional_tensor(value) + and maybe_get_fake_mode(value) is tx.fake_mode + ) + or value.is_nested + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + with preserve_rng_state(): + if example_value is None: + # only allow_non_graph_fake in this instance because we handle the non-fake + # cases properly below. 
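+            # Sketch of mode (2) from the note above: for a traced call such as a
+            # node with op "call_function" and target torch.add, get_fake_value
+            # re-runs that single node on the FakeTensors stored in its inputs'
+            # meta["example_value"], and the fake result becomes this node's
+            # example_value.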
+ example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + + # Handle recursive calls here + elif maybe_get_fake_mode(example_value) is tx.fake_mode: + pass + + elif isinstance(example_value, torch.Tensor): + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + kwargs = { + "is_tensor": target_cls + in (TensorVariable, TensorWithTFOverrideVariable), + } + assert "source" in options and options["source"] is not None + kwargs["source"] = options["source"] + example_value = wrap_to_fake_tensor_and_record( + example_value, tx=tx, **kwargs + ) + if isinstance(example_value, torch.Tensor) and ( + maybe_get_fake_mode(example_value) is not tx.fake_mode + ): + raise InternalTorchDynamoError( + "`example_value` needs to be a `FakeTensor`" + f"wrapped by this instance of Dynamo. Found: {example_value}" + ) + + if isinstance(example_value, torch.Tensor): + is_parameter = isinstance(example_value, torch.nn.Parameter) + + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value) + proxy.node.meta["example_value"] = example_value + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + tensor_type = subclass_type if subclass_type else torch.Tensor + specialized_props["class_type"] = ( + torch.nn.Parameter if is_parameter else tensor_type + ) + + options.update(specialized_props) + return target_cls(proxy, **options) + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target == torch.random.set_rng_state + ): + return TorchInGraphFunctionVariable(proxy.node.target) + elif ( + proxy.node.target == torch._C._DisableFuncTorch + or proxy.node.target == torch.cuda._is_in_bad_fork + ): + return UserDefinedObjectVariable(example_value) + elif istype(example_value, torch.Size) and all( + isinstance(x, int) for x in example_value + ): + sizes = [ConstantVariable.create(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + proxy.node.meta["example_value"] = example_value + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable.create(None, **options), + ) + else: + unpacked.append( + wrap_fx_proxy_cls( + target_cls, + tx, + proxy.tracer.create_proxy( + "call_function", operator.getitem, (proxy, i), {} + ), + example_value=val, + **options, + ) + ) + if isinstance(example_value, torch.Size): + # NB: Keep the old proxy around. 
See SizeVariable for an + # explanation why + return SizeVariable(unpacked, proxy, **options) + elif istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, mutable_local=MutableLocal(), **options) + else: + assert example_value.__class__.__module__ == "torch.return_types" or hasattr( + example_value, "_fields" + ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable.create(None, **options) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + proxy.node.meta["example_value"] = example_value + return SymNodeVariable(proxy, example_value, **options) + elif ( + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, _StreamBase) + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + proxy.node.meta["example_value"] = example_value + return StreamVariable(proxy, example_value, example_value.device, **options) + elif ( + inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase) + ) or proxy.node.target in [ + device_interface.Event + for _, device_interface in get_registered_device_interfaces() + ]: + proxy.node.meta["example_value"] = example_value + return EventVariable(proxy, example_value, **options) + elif proxy.node.target == "query" and proxy.node.op == "call_method": + proxy.node.meta["example_value"] = example_value + return ConstantVariable(example_value, **options) + elif ( + example_value is not None + and isinstance(example_value, _EventBase) + and proxy.node.target == "record_event" + and proxy.node.op == "call_method" + ): + proxy.node.meta["example_value"] = example_value + return EventVariable(proxy, example_value, **options) + elif isinstance(example_value, int) and proxy.node.target in [ + torch.sym_int, + getattr, + operator.getitem, + torch._utils._element_size, + torch.seed, + operator.mod, + torch._C._functorch._vmap_increment_nesting, + torch._C._functorch._vmap_decrement_nesting, + torch._functorch.vmap._validate_and_get_batch_size, + torch._C._functorch._grad_increment_nesting, + torch._C._functorch._grad_decrement_nesting, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + # This always wants to be in the graph, even if the constraint + # results in a constant int + torch._constrain_as_value, + torch._constrain_as_size, + ]: + proxy.node.meta["example_value"] = example_value + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch.backends.cuda.SDPAParams): + from .sdpa import SDPAParamsVariable + + proxy.node.meta["example_value"] = example_value + return SDPAParamsVariable(proxy, **options) + elif isinstance(example_value, bool) and proxy.node.target in [ + torch.backends.cuda.can_use_flash_attention, + torch.backends.cuda.can_use_efficient_attention, + ]: + proxy.node.meta["example_value"] = example_value + return ConstantVariable.create(example_value, **options) + else: + unimplemented( + "torch.* op returned non-Tensor " + + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" + ) + + +# Tracks the sources of all 
fake tensors we wrap in Dynamo. +# Used by shape guard computation. +@dataclasses.dataclass +class TrackedFake: + fake: Union[FakeTensor, SymInt] + source: Source + # Is None when fake is SymInt + symbolic_context: Optional[SymbolicContext] + + def __hash__(self) -> int: + return hash((self.fake, self.source.name())) + + def __eq__(self, other: object) -> bool: + if isinstance(other, TrackedFake): + return self.fake is other.fake and self.source.name() == other.source.name() + return False + + +# Performs automatic dynamic dim determination. +# Returns a SymbolicContext +def _automatic_dynamic( + e, tx, source, static_shapes, outer_only=False +) -> SymbolicContext: + # strided NT not supported + if e.is_nested and not isinstance( + e, torch.nested._internal.nested_tensor.NestedTensor + ): + unimplemented("torch.compile does not support strided NestedTensor") + + name = source.name() + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + shape_env_to_source_to_symbol_cache = ( + prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None + ) + + # Get base context if the tensor is a view + view_base_context: Optional[SymbolicContext] = None + if e._is_view(): + base_source = AttrSource(source, "_base") + view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) + + if is_traceable_wrapper_subclass(e) and not outer_only: + # Get symbolic context for outer tensor + outer_context = _automatic_dynamic( + e, tx, source, static_shapes, outer_only=True + ) + + # Get symbolic contexts for inner tensors + attrs, _ = type(e).__tensor_flatten__(e) + inner_contexts = {} # mapping from attr -> symbolic context + for attr in attrs: + inner_tensor = getattr(e, attr) + inner_source = AttrSource(source, attr) + inner_context = _automatic_dynamic( + inner_tensor, tx, inner_source, static_shapes + ) + inner_contexts[attr] = inner_context + + return SubclassSymbolicContext( + dynamic_sizes=outer_context.dynamic_sizes, + constraint_sizes=outer_context.constraint_sizes, + view_base_context=view_base_context, + tensor_source=outer_context.tensor_source, + shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, + inner_contexts=inner_contexts, + ) + + if static_shapes: + return StatefulSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * e.dim(), + constraint_sizes=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # We preserve the dynamism of inputs. For example, when users call + # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. + from torch.fx.experimental.symbolic_shapes import is_nested_int + + if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): + return StatefulSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC + for s in e.size() + ], + constraint_sizes=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # Prep for automatic dynamic + frame_state_entry = None + if name not in tx.output.frame_state: + # If there is no entry for this source, add the tensor to frame state with its current static size. 
+ # E.g., {} -> {"x": [2, 4]} + frame_state_entry = FrameStateSizeEntry(None, None) + frame_state_entry.size = list(e.size()) + else: + frame_state_entry = tx.output.frame_state[name] + if frame_state_entry.size is not None: + if e.ndim != len(frame_state_entry.size): + # If there is already an entry, and the dim mismatches, replace the frame state entry with None. + # E.g. {"x": [2, 3, 4]} -> {"x": None} + log.debug( + "automatic dynamic %s dim %s != %s", + name, + e.ndim, + frame_state_entry.size, + ) + frame_state_entry.size = None + else: + # If there is already an entry, and the dim matches, for every size in the frame state which + # disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]} + for i, dim in enumerate(frame_state_entry.size): + if dim is not None and e.size()[i] != dim: + log.debug( + "automatic dynamic %s size(%s) %s != %s", + name, + i, + e.size(i), + dim, + ) + frame_state_entry.size[i] = None + + # TODO: index export_constraints ahead of time so we don't have to + # do a linear scan every time here + t_id = id(e) + dim2constraint = {} + + def update_dim2constraint(dim, constraint_range, debug_name): + if dim in dim2constraint: + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + old_constraint_range, old_debug_name = dim2constraint[dim] + new_constraint_range = StrictMinMaxConstraint( + vr=constraint_range.vr & old_constraint_range.vr, + warn_only=False, + ) + # It is possible for (non-None) old_debug_name and debug_name to be different + # but this will only happen the corresponding Dims can be derived equal. + new_debug_name = old_debug_name or debug_name + dim2constraint[dim] = new_constraint_range, new_debug_name + else: + dim2constraint[dim] = constraint_range, debug_name + + if tx.output.export_constraints: + for constraint in tx.output.export_constraints: + if constraint.t_id == t_id: + update_dim2constraint( + constraint.dim, constraint.constraint_range, constraint.debug_name + ) + if constraint.shared is not None and constraint.shared.t_id == t_id: + # We process constraint ranges for each shared dimension separately + # so that we can directly check range constraint violations on them + # without looking up which other shared dimensions have this info. + # In other words, for this t_id, we will have processed all of its + # constraint ranges, no matter where / how they were specified, by + # by the end of this loop. 
+ update_dim2constraint( + constraint.shared.dim, + constraint.constraint_range, + constraint.debug_name, + ) + + dynamic_dims = [] + constraint_dims = [] + for i in range(e.dim()): + # NB: mark dynamic has precedence over static + marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) + marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) + marked_static = i in getattr(e, "_dynamo_static_indices", set()) + + # NB: both static and dynamic have precedence over + automatic_dynamic = config.automatic_dynamic_shapes and ( + frame_state_entry.size is None or frame_state_entry.size[i] is None + ) + + # Reflect the user directive in the frame_state + # For dynamic, apply None always + if frame_state_entry.size and marked_dynamic: + log.debug("automatic dynamic %s marked dynamic", name) + frame_state_entry.size[i] = None + + # We will process constraints first, as they will imply that we + # have a dynamic dimension + # Precedence: export constraints > eager constraints + constraint = dim2constraint.get(i) + if constraint is None: + if marked_dynamic and not config.allow_ignore_mark_dynamic: + if hasattr(e, "_dynamo_dynamic_range"): + dim_range = [ + dr for dr in e._dynamo_dynamic_range if dr.dim == i + ].pop() + if dim_range.min is None and dim_range.max is None: + constraint_dim = RelaxedUnspecConstraint(warn_only=False) + else: + from torch.fx.experimental.symbolic_shapes import ( + StrictMinMaxConstraint, + ) + + constraint_dim = StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), + warn_only=False, + ) + else: + constraint_dim = RelaxedUnspecConstraint(warn_only=False) + + elif not marked_static and automatic_dynamic: + constraint_dim = RelaxedUnspecConstraint(warn_only=True) + else: + constraint_dim = None + else: + constraint_dim, debug_name = constraint + if debug_name is not None: + dim_name = f"{name}.size()[{i}]" + tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name + constraint_dims.append(constraint_dim) + + # Now, figure out if the dim is dynamic/duck/static + if ( + constraint_dim is not None + or marked_dynamic + or marked_weak_dynamic + or is_nested_int(e.shape[i]) + ): + # NB: We could assert static_shapes is False here, but it + # seems better to allow the user to override symbolic_context in this + # case + dynamic = DimDynamic.DYNAMIC + elif static_shapes or config.assume_static_by_default or marked_static: + dynamic = DimDynamic.STATIC + else: + dynamic = DimDynamic.DUCK + + dynamic_dims.append(dynamic) + + tx.output.frame_state[name] = frame_state_entry + + return StatefulSymbolicContext( + dynamic_sizes=dynamic_dims, + constraint_sizes=constraint_dims, + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + +# See note [Tensor Fakification and Symbol Caching] +def wrap_to_fake_tensor_and_record( + e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None +): + if ( + type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) + or isinstance(e, torch.Tensor) + or is_traceable_wrapper_subclass(e) + ): + assert source is not None + static_shapes, reason = tensor_always_has_static_shape( + e, is_tensor, guard_source=source.guard_source() + ) + + if not parent_context: + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + else: + # Parent contexts are passed in when we are recursively creating + # fake tensors for subclasses. 
A better design would be not to create a + # parent/child relationship, but to recursively call _automatic_dynamic + # as we recursively call wrap_to_fake_tensor_and_record. This runs + # into bugs around how meta_utils knows and works to create fake tensors + # with tensor subclasses. Ideally, dynamo would drive both the recursive + # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. + assert isinstance(source, AttrSource) + inner_context_name = source.member + symbolic_context = parent_context.inner_contexts[inner_context_name] + + log.debug( + "wrap_to_fake %s %s %s %s", + source.name(), + tuple(e.shape), + symbolic_context, + type(e), + ) + fake_e = wrap_fake_exception( + lambda: tx.fake_mode.from_tensor( + e, + source=source, + symbolic_context=symbolic_context, + ) + ) + + if is_traceable_wrapper_subclass(fake_e): + attrs, _ = fake_e.__tensor_flatten__() + for attr in attrs: + fake_inner = getattr(fake_e, attr) + inner = getattr(e, attr) + inner_source = AttrSource(source, attr) + wrap_to_fake_tensor_and_record( + inner, + tx, + source=inner_source, + is_tensor=isinstance(fake_inner, torch.Tensor), + parent_context=symbolic_context, + ) + + tx.output.tracing_context.tensor_to_context[e] = symbolic_context + tx.output.tensor_weakref_to_sizes_strides[e] = { + "size": fake_e.size(), + "stride": fake_e.stride(), + } + + if ( + is_tensor + and not (static_shapes and source.is_nn_module()) + and not is_constant_source(source) + ): + tx.output.tracked_fakes.append( + TrackedFake(fake_e, source, symbolic_context) + ) + tx.output.tracked_fakes_id_to_source[id(e)].append(source) + + return fake_e + else: + return e + + +class SourcelessBuilder: + """ + Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects + that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over + .), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, + there may be reasons to represent it as a ListVariable internally. + + NOTE - Objects produced here are born UNGUARDED due to the nature of sources! + + NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant + if/else type->VariableTracker trees that were cropping up all over dynamo. + """ + + def __call__(self, tx, value) -> VariableTracker: + if isinstance(value, VariableTracker): + # This is always valid to call, and useful for recursive calls. + return value + if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): + return UserDefinedObjectVariable(value) + if ConstantVariable.is_literal(value): + return SourcelessBuilder.wrap_constant_literal(value) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if is_callable_allowed(value): + self.tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value)(value) + elif is_function_or_wrapper(value): + return trace_rules.lookup(value)(value) + elif isinstance(value, enum.Enum): + return EnumVariable(value) + elif isinstance(value, (type, abc.ABCMeta)): + return UserDefinedClassVariable(value) + elif isinstance(value, dict): + items = {self(tx, k): self(tx, v) for k, v in value.items()} + return ConstDictVariable(items, mutable_local=MutableLocal()) + elif isinstance(value, set): + # Nb. value is a set here so the iteration below is non-deterministic! 
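+            # E.g. (sketch): passing the literal set {1, 2.0} here yields a
+            # SetVariable of ConstantVariables, in arbitrary order and, as noted
+            # in the class docstring, without any guards since there is no Source.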
+ return SetVariable( + [self(tx, x) for x in value], mutable_local=MutableLocal() + ) + elif isinstance(value, (tuple, list)): + cls = BaseListVariable.cls_for(type(value)) + return cls([self(tx, x) for x in value], mutable_local=MutableLocal()) + elif isinstance(value, types.MethodWrapperType): + return MethodWrapperVariable(value) + elif PlacementVariable.is_placement(value): + return PlacementVariable(value) + elif DeviceMeshVariable.is_device_mesh(value): + return DeviceMeshVariable(value) + unimplemented(f"Unexpected type in sourceless builder {type(value)}") + + @staticmethod + def wrap_constant_literal(value): + assert ConstantVariable.is_literal(value) + return ConstantVariable.create(value=value) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/builtin.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..94ff064f65e9ad9c20011fa69e4d9e077a76dfbb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/builtin.py @@ -0,0 +1,1748 @@ +# mypy: ignore-errors + +import contextlib +import functools +import inspect +import itertools +import logging +import math +import operator +import types +from collections import defaultdict, OrderedDict +from typing import Dict, List + +import torch +from torch import sym_float, sym_int + +from .. import config, polyfill, variables +from ..exc import ( + AttributeMutationError, + unimplemented, + Unsupported, + UserError, + UserErrorType, +) +from ..guards import GuardBuilder, install_guard +from ..replay_record import DummyModule +from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource +from ..utils import ( + check_constant_args, + check_numpy_ndarray_args, + check_unspec_python_args, + extract_fake_example_value, + get_fake_value, + guard_if_dyn, + istype, + numpy_operator_wrapper, + proxy_args_kwargs, + tensortype_to_dtype, +) +from .base import MutableLocal, typestr, VariableTracker +from .constant import ConstantVariable +from .ctx_manager import EventVariable, StreamVariable +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictView, + is_hashable, + SetVariable, +) +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SizeVariable, + TupleIteratorVariable, + TupleVariable, +) +from .tensor import ( + FakeItemVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .user_defined import UserDefinedVariable + +log = logging.getLogger(__name__) + + +IN_PLACE_DESUGARING_MAP = { + operator.iadd: operator.add, + operator.isub: operator.sub, + operator.imul: operator.mul, + operator.ifloordiv: operator.floordiv, + operator.itruediv: operator.truediv, + operator.imod: operator.mod, + operator.imatmul: operator.imatmul, + operator.ilshift: operator.lshift, + operator.irshift: operator.rshift, + operator.ipow: operator.pow, + operator.iand: operator.and_, + operator.ior: operator.or_, + operator.ixor: operator.xor, +} + + +def _polyfill_call_impl(name): + """Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}""" + + def call_fn(self, tx, *args, **kwargs): + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), args, kwargs + ) + + fn = getattr(polyfill, name) + call_fn.__name__ = f"call_{name}" + return call_fn + + +class BuiltinVariable(VariableTracker): + _SENTINEL = object() + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + 
return BuiltinVariable(value, source=source) + + @staticmethod + @functools.lru_cache(None) + def _constant_fold_functions(): + fns = { + abs, + all, + any, + bool, + callable, + chr, + divmod, + float, + getattr, + int, + len, + max, + min, + ord, + pow, + repr, + round, + str, + str.format, + sum, + type, + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.truth, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.sub, + operator.getitem, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + operator.index, + } + fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) + return fns + + def can_constant_fold_through(self): + return self.fn in self._constant_fold_functions() + + @staticmethod + @functools.lru_cache(None) + def _fx_graph_functions(): + fns = { + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.getitem, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + } + return fns + + @staticmethod + @functools.lru_cache(None) + def _binops(): + # function -> ([forward name, reverse name, in-place name], in-place op) + fns = { + operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), + operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), + operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), + operator.truediv: ( + ["__truediv__", "__rtruediv__", "__itruediv__"], + operator.itruediv, + ), + operator.floordiv: ( + ["__floordiv__", "__rfloordiv__", "__ifloordiv__"], + operator.ifloordiv, + ), + operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), + pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.lshift: ( + ["__lshift__", "__rlshift__", "__ilshift__"], + operator.ilshift, + ), + operator.rshift: ( + ["__rshift__", "__rrshift__", "__irshift__"], + operator.irshift, + ), + # NB: The follow binary operators are not supported for now, since the + # corresponding magic methods aren't defined on SymInt / SymFloat: + # operator.matmul + # divmod + # operator.and_ + # operator.or_ + # operator.xor + } + return fns + + @staticmethod + @functools.lru_cache(None) + def _binop_handlers(): + # Multiple dispatch mechanism defining custom binop behavior for certain type + # combinations. Handlers are attempted in order, and will be used if the type checks + # match. 
They are expected to have the signature: + # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker + + # Override table contains: op_fn -> [list of handlers] + op_handlers = {} + for ( + op, + (magic_method_names, in_place_op), + ) in BuiltinVariable._binops().items(): + op_handlers[op] = [] + op_handlers[in_place_op] = [] + + forward_name, reverse_name, inplace_name = magic_method_names + + # User-defined args (highest precedence) + def user_defined_handler( + tx, + a, + b, + options, + forward_name=forward_name, + reverse_name=reverse_name, + ): + # Manually handle reversing logic if needed (e.g. call __radd__) + + # TODO: If we expand this to handle tensor args, we need to manually + # handle cases like this: + # + # class A(int): + # def __radd__(self, other): + # print("woof") + # torch.randn(3) + A(3) + # + # In this example, A.__radd__() is not called -> nothing is printed, because + # Tensor.__add__ only does a subtype test against int, ignoring the subclass. + # To be fully correct, we should not call A.__radd__() here, and there may be + # other cases to reason about and add exceptions for. + if isinstance(a, UserDefinedVariable): + return a.call_method(tx, forward_name, [b], {}) + else: + return b.call_method(tx, reverse_name, [a], {}) + + op_handlers[op].append( + ((UserDefinedVariable, VariableTracker), user_defined_handler) + ) + op_handlers[op].append( + ((VariableTracker, UserDefinedVariable), user_defined_handler) + ) + + def user_defined_inplace_handler( + tx, a, b, options, forward_name=inplace_name + ): + return a.call_method(tx, forward_name, [b], {}) + + op_handlers[in_place_op].append( + ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler) + ) + + # Dynamic shape args + def dynamic_handler(tx, a, b, options, fn=op): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", fn, *proxy_args_kwargs([a, b], {}) + ), + **options, + ) + + op_handlers[op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # NB: Prefer out-of-place op when calling in-place op to generate valid graph + op_handlers[in_place_op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # Special cases - lower precedence but still prefer these over constant folding + + # List-like addition (e.g. [1, 2] + [3, 4]) + def tuple_add_handler(tx, a, b, options): + return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options) + + def size_add_handler(tx, a, b, options): + return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options) + + list_like_addition_handlers = [ + # NB: Prefer the tuple-specific logic over base logic because of + # some SizeVariable weirdness. Specifically, the tuple-specific logic + # drops the subclass type (e.g. SizeVariable) and returns TupleVariables. 
+ ( + (SizeVariable, SizeVariable), + size_add_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ( + (ConstantVariable, TupleVariable), + lambda tx, a, b, options: TupleVariable( + list(a.unpack_var_sequence(tx)) + b.items, **options + ), + ), + ( + (BaseListVariable, BaseListVariable), + lambda tx, a, b, options: type(a)(a.items + b.items, **options), + ), + ] + op_handlers[operator.add].extend(list_like_addition_handlers) + + def list_iadd_handler(tx, a, b, _): + if not a.mutable_local or not b.has_unpack_var_sequence(tx): + # Handler doesn't apply + return None + + seq = b.unpack_var_sequence(tx) + tx.output.side_effects.mutation(a) + a.items.extend(seq) + return a + + list_like_iadd_handlers = [ + ( + (ListVariable, VariableTracker), + list_iadd_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ] + op_handlers[operator.iadd].extend(list_like_iadd_handlers) + + # List-like expansion (e.g. [1, 2, 3] * 3) + def expand_list_like(tx, lst, const, options): + return lst.__class__( + items=lst.items * const.as_python_constant(), + mutable_local=MutableLocal(), + **options, + ) + + list_like_expansion_handlers = [ + ((ListVariable, ConstantVariable), expand_list_like), + ((TupleVariable, ConstantVariable), expand_list_like), + ( + (ConstantVariable, ListVariable), + lambda tx, a, b, options: expand_list_like(tx, b, a, options), + ), + ( + (ConstantVariable, TupleVariable), + lambda tx, a, b, options: expand_list_like(tx, b, a, options), + ), + ] + op_handlers[operator.mul].extend(list_like_expansion_handlers) + + return op_handlers + + @staticmethod + def _find_binop_handler(op, a, b): + handlers = BuiltinVariable._binop_handlers() + if op not in handlers: + return None + + # Return first handler that matches the type checks + for (type1, type2), handler in handlers[op]: + if isinstance(a, type1) and isinstance(b, type2): + return handler + + return None + + def can_insert_in_graph(self): + return self.fn in self._fx_graph_functions() + + def __init__(self, fn, **kwargs): + super().__init__(**kwargs) + self.fn = fn + + def __str__(self): + if self.fn is None: + name = "None" + else: + name = self.fn.__name__ + + return f"{self.__class__.__name__}({name})" + + def python_type(self): + return type(self.fn) + + def as_python_constant(self): + return self.fn + + def as_proxy(self): + DTYPE = { + bool: torch.bool, + int: torch.int64, + float: torch.float64, + } + if self.fn in DTYPE: + return DTYPE[self.fn] + return super().as_proxy() + + def reconstruct(self, codegen): + name = self.fn.__name__ + assert self.fn.__module__ == "builtins" + assert name not in codegen.tx.f_globals, "shadowed global" + codegen.append_output(codegen.create_load_global(name, False, add=True)) + + def constant_args(self, *args, **kwargs): + return check_constant_args(args, kwargs) + + def tensor_args(self, *args, **kwargs): + return any( + isinstance(i, variables.TensorVariable) + for i in itertools.chain(args, kwargs.values()) + ) and not any( + isinstance(i, variables.GetAttrVariable) + for i in itertools.chain(args, kwargs.values()) + ) + + def python_and_tensor_constant_only(self, *args, **kwargs): + tensor_args = [] + non_tensor_args = [] + for i in itertools.chain(args, kwargs.values()): + if isinstance(i, variables.TensorVariable): + tensor_args.append(i) + else: + non_tensor_args.append(i) + return all( + 
is_constant_source(t.source) if t.source is not None else False + for t in tensor_args + ) and self.constant_args(*non_tensor_args) + + def unspec_python_args(self, *args, **kwargs): + return check_unspec_python_args(args, kwargs) + + @staticmethod + def unwrap_unspec_args_kwargs(args, kwargs): + return [x.as_python_constant() for x in args], { + k: v.as_python_constant() for k, v in kwargs.items() + } + + def has_constant_handler(self, args, kwargs): + constant_args = check_constant_args(args, kwargs) + unspec_python_args = self.unspec_python_args(*args, **kwargs) + return self.can_constant_fold_through() and ( + constant_args or unspec_python_args + ) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import UserFunctionVariable + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + args = [v.realize() for v in args] + kwargs = {k: v.realize() for k, v in kwargs.items()} + assert isinstance(args, (list, tuple)) + assert isinstance(kwargs, dict) + tensor_args = self.tensor_args(*args, **kwargs) + + # args[0] is list and args[1] is unspec + if self.fn is operator.getitem and not isinstance( + args[0], variables.TensorVariable + ): + tensor_args = False + + if ( + self.can_insert_in_graph() + and tensor_args + and not ( + self.fn is operator.getitem + and isinstance(args[0], ConstDictVariable) + and isinstance(args[1], variables.TensorVariable) + ) + ): + try: + fn = self.fn + + # Constant fold for constant tensor and python constants + if tensor_args and self.python_and_tensor_constant_only( + *args, **kwargs + ): + from ..bytecode_transformation import unique_id + from .functions import invoke_and_store_as_constant + + return invoke_and_store_as_constant( + tx, fn, unique_id(fn.__name__), args, kwargs + ) + + if self.fn in IN_PLACE_DESUGARING_MAP and isinstance( + args[0], variables.ConstantVariable + ): + # In-place operators like += usually mustate tensor + # values, but in the edge case of immutable values they + # re-bind the variable. + # + # The easiest way to keep the graph consistent in this + # scenario is to de-sugar eagerly. + fn, args = IN_PLACE_DESUGARING_MAP[self.fn], [args[0], args[1]] + + if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. 
Rewrite as a regular torch op which will + # trace fine + fn, args = torch.select, [ + args[0], + variables.ConstantVariable.create(0), + args[1], + ] + + # Interaction between ndarray and tensors: + # We prefer the tensor op whenever there are tensors involved + if check_numpy_ndarray_args(args, kwargs) and not any( + type(arg) == variables.TensorVariable for arg in args + ): + proxy = tx.output.create_proxy( + "call_function", + numpy_operator_wrapper(self.fn), + *proxy_args_kwargs(args, kwargs), + ) + + return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy) + + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs(args, kwargs), + ) + if any(isinstance(arg, FakeItemVariable) for arg in args): + return wrap_fx_proxy_cls( + FakeItemVariable, + tx, + proxy, + ) + elif self.unspec_python_args(*args, **kwargs): + _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) + raw_value = self.fn(*_args, **_kwargs) + + need_unwrap = any( + x.need_unwrap + for x in itertools.chain(args, kwargs.values()) + if isinstance(x, variables.UnspecializedPythonVariable) + ) + + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx, + proxy, + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + elif all(isinstance(x, SymNodeVariable) for x in args): + return SymNodeVariable.create(tx, proxy, None) + else: + # Work around for vision_maskrcnn due to precision difference + # specialize the dividend when float divide by tensor + if self.fn is operator.truediv and isinstance( + args[0], variables.UnspecializedPythonVariable + ): + args[0] = args[0].convert_to_constant(tx) + return wrap_fx_proxy(tx, proxy) + + except NotImplementedError: + unimplemented(f"partial tensor op: {self} {args} {kwargs}") + + # Handle cases like int(torch.seed()) + # Also handle sym_float to sym_int cases + if self.fn in (int, float) and isinstance( + args[0], (SymNodeVariable, variables.TensorVariable) + ): + if isinstance(args[0], variables.TensorVariable): + item = args[0].call_method(tx, "item", [], {}) + else: + item = args[0] + fn_ = sym_int if self.fn is int else sym_float + out = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (item.as_proxy(),), + {}, + ), + ) + return out + + # Handle `str` on a user defined function + if self.fn == str and args and isinstance(args[0], (UserFunctionVariable)): + return variables.ConstantVariable.create(value=str(args[0].fn)) + + # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.) + # NB: Tensor args are handled above and not here + if len(kwargs) == 0 and len(args) == 2: + # Try to find a handler for the arg types; otherwise, fall through to constant handler + binop_handler = BuiltinVariable._find_binop_handler( + self.fn, args[0], args[1] + ) + if binop_handler: + res = binop_handler(tx, args[0], args[1], {}) + if res is not None: + return res + + handler = getattr(self, f"call_{self.fn.__name__}", None) + + if handler: + try: + result = handler(tx, *args, **kwargs) + if result is not None: + return result + except TypeError: + # Check if binding is bad. inspect signature bind is expensive. + # So check only when handler call fails. 
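+                # In other words (sketch): a TypeError caused by a bad arg count
+                # fails the bind below and, absent a constant handler, is reported
+                # via unimplemented(), while a TypeError raised inside a
+                # correctly-bound handler is re-raised unchanged.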
+ try: + inspect.signature(handler).bind(tx, *args, **kwargs) + except TypeError as e: + has_constant_handler = self.has_constant_handler(args, kwargs) + if not has_constant_handler: + log.warning( + "incorrect arg count %s %s and no constant handler", + handler, + e, + ) + unimplemented(f"invalid handler args {handler} {args} {kwargs}") + else: + raise + except Unsupported as exc: + has_constant_handler = self.has_constant_handler(args, kwargs) + if not has_constant_handler: + raise + # Actually, we will handle this just fine + exc.remove_from_stats() + + # NB: call to has_constant_handler is deliberately delayed post generic + # handler because has_constant_handler calls as_python_constant + # internally which realizes LazyVariableTracker for ConstantVariables, + # unnecessarily putting guards on objects which might not actually be used. + has_constant_handler = self.has_constant_handler(args, kwargs) + if has_constant_handler: + from .builder import SourcelessBuilder + + # constant fold + return SourcelessBuilder()( + tx, + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + return super().call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if self.fn == dict and name == "fromkeys": + return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) + if self.fn == itertools.chain and name == "from_iterable": + assert len(args) == 1 + assert len(kwargs) == 0 + obj = args[0] + items = [] + for item in obj.unpack_var_sequence(tx): + items.extend(item.unpack_var_sequence(tx)) + return variables.TupleVariable(items) + + return super().call_method(tx, name, args, kwargs) + + def _call_min_max(self, tx, *args): + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + # expand iterable + items = args[0].unpack_var_sequence(tx) + return self._call_min_max_seq(tx, items) + elif len(args) == 2: + return self._call_min_max_binary(tx, args[0], args[1]) + elif len(args) > 2: + return self._call_min_max_seq(tx, args) + + def _call_min_max_seq(self, tx, items): + assert len(items) > 0 + if len(items) == 1: + return items[0] + + return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) + + def _call_min_max_binary(self, tx, a, b): + if self.tensor_args(a, b): + if not isinstance(a, variables.TensorVariable): + a, b = b, a + assert isinstance(a, variables.TensorVariable) + + # result of an item call is a scalar convert to a tensor + if isinstance(a, FakeItemVariable): + a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function( + tx, [a], {} + ) + + # Dynamic input does not get resolved, rather, gets stored as call_function + if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + type(a), + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.fn, + *proxy_args_kwargs([a, b], {}), + ), + ) + + # convert min/max to torch ops + if b.is_python_constant(): + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + fn = variables.NumpyVariable(np.clip) + else: + fn = variables.TorchInGraphFunctionVariable(torch.clamp) + kwargs = {"min": b} if (self.fn is max) else {"max": b} + result = fn.call_function(tx, [a], kwargs) + else: + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + fn = {max: np.maximum, min: 
np.minimum}[self.fn] + fn = variables.NumpyVariable(fn) + else: + fn = {max: torch.maximum, min: torch.minimum}[self.fn] + fn = variables.TorchInGraphFunctionVariable(fn) + result = fn.call_function(tx, [a, b], {}) + + # return unspec if both a, b are unspec or const + if all( + isinstance( + i, + ( + variables.UnspecializedPythonVariable, + variables.ConstantVariable, + ), + ) + for i in [a, b] + ): + if any(isinstance(val, FakeItemVariable) for val in [a, b]): + return variables.FakeItemVariable.from_tensor_variable(result) + + if b.is_python_constant(): + raw_b = b.as_python_constant() + else: + raw_b = b.raw_value + if self.fn is max: + raw_res = max(a.raw_value, raw_b) + else: + raw_res = min(a.raw_value, raw_b) + + need_unwrap = any( + x.need_unwrap + for x in [a, b] + if isinstance(x, variables.UnspecializedPythonVariable) + ) + return variables.UnspecializedPythonVariable.from_tensor_variable( + result, raw_res, need_unwrap + ) + # otherwise return tensor + else: + return result + elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + fn = torch.sym_max if self.fn is max else torch.sym_min + proxy = tx.output.create_proxy( + "call_function", fn, *proxy_args_kwargs([a, b], {}) + ) + return SymNodeVariable.create(tx, proxy, None) + + call_min = _call_min_max + call_max = _call_min_max + + def call_abs(self, tx, arg: "VariableTracker"): + # Call arg.__abs__() + abs_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__abs__")], {} + ) + return abs_method.call_function(tx, [], {}) + + def call_pos(self, tx, arg: "VariableTracker"): + # Call arg.__pos__() + pos_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__pos__")], {} + ) + return pos_method.call_function(tx, [], {}) + + def call_round(self, tx, arg, *args, **kwargs): + # Call arg.__round__() + round_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__round__")], {} + ) + return round_method.call_function(tx, args, kwargs) + + def call_range(self, tx, *args): + if self.unspec_python_args(*args) or self.constant_args(*args): + return variables.RangeVariable(args) + elif self._dynamic_args(*args): + args = [ + variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args + ] + return variables.RangeVariable(args) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args, **kwargs): + return any(isinstance(x, SymNodeVariable) for x in args) or any( + isinstance(x, SymNodeVariable) for x in kwargs.values() + ) + + def call_slice(self, tx, *args): + return variables.SliceVariable(args) + + def _dyn_proxy(self, tx, *args, **kwargs): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + ) + + def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + if isinstance(obj, variables.IteratorVariable): + # For non-list iterators, we will guard on vars that + # determine the control flow + return obj + + cls = variables.BaseListVariable.cls_for(self.fn) + if obj is None: + return cls( + [], + mutable_local=MutableLocal(), + ) + elif obj.has_unpack_var_sequence(tx): + if obj.source and not is_constant_source(obj.source): + if isinstance(obj, TupleIteratorVariable): + install_guard( + obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) + ) + else: + 
install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + + return cls( + list(obj.unpack_var_sequence(tx)), + mutable_local=MutableLocal(), + ) + + def call_iter(self, tx, obj, *args, **kwargs): + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + + if ret is None: + # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. + # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call + # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. + return obj.call_method(tx, "__iter__", args, kwargs) + return ret + + call_tuple = _call_iter_tuple_list + call_list = _call_iter_tuple_list + + def call_callable(self, tx, arg): + from .functions import BaseUserFunctionVariable + + if isinstance( + arg, (variables.UserDefinedClassVariable, BaseUserFunctionVariable) + ): + return variables.ConstantVariable.create(True) + elif isinstance(arg, UserDefinedVariable): + return variables.ConstantVariable.create(callable(arg.value)) + elif isinstance(arg, (ConstantVariable, SymNodeVariable, TensorVariable)): + return variables.ConstantVariable.create(False) + + def call_cast(self, _, *args, **kwargs): + if len(args) == 2: + return args[1] + + unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}") + + def call_dict(self, tx, *args, **kwargs): + return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + @staticmethod + def call_custom_dict(tx, user_cls, *args, **kwargs): + if not kwargs: + if not args: + args = ({},) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, dict): + return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal()) + elif isinstance(arg, variables.ConstDictVariable): + return arg.clone(user_cls=user_cls, mutable_local=MutableLocal()) + elif isinstance( + arg, + ( + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ): + items = dict( + x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) + ) + return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) + elif not args and kwargs: + items = {ConstantVariable.create(k): v for k, v in kwargs.items()} + return variables.ConstDictVariable( + items, user_cls=user_cls, mutable_local=MutableLocal() + ) + unimplemented(f"{user_cls.__name__}(): {args} {kwargs}") + + @staticmethod + def call_custom_dict_fromkeys(tx, user_cls, *args, **kwargs): + assert user_cls in {dict, OrderedDict, defaultdict} + if kwargs: + # Only `OrderedDict.fromkeys` accepts `value` passed by keyword + assert user_cls is OrderedDict + assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs + args = (*args, kwargs.pop("value")) + if len(args) == 0: + raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0") + if len(args) == 1: + args = (*args, ConstantVariable.create(None)) + assert len(args) == 2 + arg, value = args + DictVariableType = ( + ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable + ) + + if isinstance(arg, dict): + arg = [ConstantVariable.create(k) for k in arg.keys()] + return DictVariableType( + dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() + ) + elif arg.has_unpack_var_sequence(tx) and all( + is_hashable(v) for v in arg.unpack_var_sequence(tx) + ): + keys = arg.unpack_var_sequence(tx) + return DictVariableType( + dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() + ) + 
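+        # Anything else (keys from a non-unpackable iterable, or unhashable
+        # key VariableTrackers) falls through to a graph break.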
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") + + def call_set(self, tx, *args, **kwargs): + # Can we merge this implementation and call_dict's one? + assert not kwargs + if not args: + return SetVariable([], mutable_local=MutableLocal()) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, variables.SetVariable): + return arg.clone(mutable_local=MutableLocal()) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) + return SetVariable(items, mutable_local=MutableLocal()) + else: + unimplemented(f"set(): {args} {kwargs}") + + def call_zip(self, tx, *args, **kwargs): + if kwargs: + assert len(kwargs) == 1 and "strict" in kwargs + if all(x.has_unpack_var_sequence(tx) for x in args): + unpacked = [arg.unpack_var_sequence(tx) for arg in args] + if kwargs.pop("strict", False) and len(unpacked) > 0: + if not all(len(u) == len(unpacked[0]) for u in unpacked): + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) + items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] + return variables.TupleVariable(items) + + def call_enumerate(self, tx, *args): + if len(args) == 1: + start = 0 + else: + assert len(args) == 2 + assert isinstance(args[1], variables.ConstantVariable) + start = args[1].as_python_constant() + if args[0].has_unpack_var_sequence(tx): + items = [ + variables.TupleVariable( + [variables.ConstantVariable.create(idx), var], + ) + for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) + ] + return variables.TupleVariable(items) + + def call_len(self, tx, *args, **kwargs): + return args[0].call_method(tx, "__len__", args[1:], kwargs) + + def call_getitem(self, tx, *args, **kwargs): + return args[0].call_method(tx, "__getitem__", args[1:], kwargs) + + def call_isinstance(self, tx, arg, isinstance_type): + try: + arg_type = arg.python_type() + except NotImplementedError: + unimplemented( + f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}" + ) + + isinstance_type = isinstance_type.as_python_constant() + + if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: + + def _tensor_isinstance(tensor_var, tensor_type): + def check_type(ty): + if ty not in tensortype_to_dtype: + return issubclass(arg.python_type(), ty) + + dtypes = tensortype_to_dtype[ty] + return arg.dtype in dtypes + + if type(tensor_type) is tuple: + return any(check_type(ty) for ty in tensor_type) + else: + return check_type(tensor_type) + + return variables.ConstantVariable.create( + _tensor_isinstance(arg, isinstance_type) + ) + # UserDefinedObject with C extensions can have torch.Tensor attributes, + # so break graph. 
+ if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, types.MemberDescriptorType + ): + unimplemented( + f"isinstance called on UserDefinedClass {arg} {isinstance_type}" + ) + # handle __instancecheck__ defined in user class + if ( + isinstance(arg, variables.UserDefinedObjectVariable) + and "__instancecheck__" in isinstance_type.__class__.__dict__ + ): + return variables.ConstantVariable.create( + isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) + ) + + try: + val = issubclass(arg_type, isinstance_type) + except TypeError: + val = arg_type is isinstance_type + return variables.ConstantVariable.create(val) + + def call_issubclass(self, tx, left_ty, right_ty): + """Checks if first arg is subclass of right arg""" + left_ty = left_ty.as_python_constant() + right_ty = right_ty.as_python_constant() + + return variables.ConstantVariable(issubclass(left_ty, right_ty)) + + def call_super(self, tx, a, b): + return variables.SuperVariable(a, b) + + def call_next(self, tx, arg): + if isinstance( + arg, (variables.ListIteratorVariable, variables.IteratorVariable) + ): + val, next_iter = arg.next_variables(tx) + return val + elif isinstance(arg, variables.BaseListVariable): + return arg.items[0] + + def call_hasattr(self, tx, obj, attr): + if attr.is_python_constant(): + name = attr.as_python_constant() + return obj.call_hasattr(tx, name) + + def call_map(self, tx, fn, seq): + if seq.has_unpack_var_sequence(tx): + items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)] + return variables.TupleVariable(items) + + def call_sum(self, tx, seq, start=_SENTINEL): + # Special case for sum on tuple of floats and ints + if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, (int, float)) + for x in seq.items + ): + if start is self._SENTINEL: + return variables.ConstantVariable.create( + sum(x.value for x in seq.items), + ) + if isinstance(start, variables.ConstantVariable) and isinstance( + start.value, (int, float) + ): + return variables.ConstantVariable.create( + sum((x.value for x in seq.items), start=start.value), + ) + if seq.has_unpack_var_sequence(tx): + if start is self._SENTINEL: + start = variables.ConstantVariable.create(0) + items = seq.unpack_var_sequence(tx) + return BuiltinVariable(functools.reduce).call_function( + tx, + [ + BuiltinVariable(operator.add), + variables.TupleVariable(items), + start, + ], + {}, + ) + + def call_reduce(self, tx, function, iterable, initial=_SENTINEL): + if iterable.has_unpack_var_sequence(tx): + items = iterable.unpack_var_sequence(tx) + if initial is self._SENTINEL: + value, items = items[0], items[1:] + else: + value = initial + for element in items: + value = function.call_function(tx, [value, element], {}) + return value + + def call_getattr( + self, tx, obj: VariableTracker, name_var: VariableTracker, default=None + ): + from .. import trace_rules + from . import ( + ConstantVariable, + GetAttrVariable, + PythonModuleVariable, + TorchInGraphFunctionVariable, + UserFunctionVariable, + ) + from .builder import SourcelessBuilder, VariableBuilder + + name = name_var.as_python_constant() + + if not name_var.is_python_constant(): + unimplemented("non-const getattr() name") + + if tx.output.side_effects.is_attribute_mutation(obj): + try: + # re-read a pending side effect? 
+ return tx.output.side_effects.load_attr(obj, name) + except KeyError: + pass + + if default is not None: + hasattr_var = self.call_hasattr(tx, obj, name_var) + assert hasattr_var.as_python_constant() in (True, False) + if not hasattr_var.as_python_constant(): + return default + + options = {} + if obj.source: + source = AttrSource(obj.source, name) + options["source"] = source + else: + source = None + + if name == "__bases__": + try: + value = obj.as_python_constant() + if isinstance(value, type): + bases = value.__bases__ + if source is not None: + tuple_args = [ + VariableBuilder(tx, GetItemSource(source, i))(b) + for i, b in enumerate(bases) + ] + else: + tuple_args = [SourcelessBuilder()(tx, b) for b in bases] + + return variables.TupleVariable(tuple_args, **options) + except NotImplementedError: + pass + + if isinstance(obj, variables.NNModuleVariable): + return obj.var_getattr(tx, name) + elif isinstance( + obj, + ( + variables.TensorVariable, + variables.NamedTupleVariable, + variables.ConstantVariable, + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return GetAttrVariable(obj, name, **options) + elif isinstance(obj, TorchInGraphFunctionVariable): + # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. + member = getattr(obj.value, name) + if isinstance( + member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) and trace_rules.is_aten_op_or_tensor_method(member): + return TorchInGraphFunctionVariable(member, **options) + elif isinstance(obj, (PythonModuleVariable, DummyModule)): + if obj.is_torch: + member = getattr(obj.value, name) + else: + member = obj.value.__dict__[name] + + if config.replay_record_enabled: + tx.exec_recorder.record_module_access(obj.value, name, member) + + if source is not None: + return VariableBuilder(tx, source)(member) + else: + return SourcelessBuilder()(tx, member) + elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): + return ConstantVariable.create(getattr(obj.fn, name)) + else: + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return GetAttrVariable(obj, name, **options) + + def call_setattr( + self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker + ): + from .distributed import PlacementVariable + + if isinstance( + obj, + ( + variables.DataClassVariable, + variables.CustomizedDictVariable, + PlacementVariable, + ), + ): + return obj.call_method(tx, "__setattr__", [name_var, val], {}) + elif ( + tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + name = name_var.as_python_constant() + if isinstance(obj, variables.TensorVariable): + from .builder import wrap_fx_proxy + + if name == "requires_grad": + # TODO(voz): Make it work properly + unimplemented( + "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " + "the middle of the graph, which aot_autograd does not currently know how to handle. " + ) + if name == "data": + # Remove the old reference in tracked fakes - if we don't do this + # new .data value size and shape differences will cause + # tracked fakes to produce incorrect guards. This is sound because the TensorVariable + # coming out of set_() below will be a new one, and get + # installed in tracked fakes. 
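+                    # Rough sketch of the rewrite performed below, for a
+                    # hypothetical `x.data = y` in user code:
+                    #   with no_grad(): x.set_(y); x._version -= 1
+                    # so autograd observes the same version counter it would
+                    # have seen for an eager `.data` assignment.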
+ to_remove = [] + for tf in tx.output.tracked_fakes: + if tf.source == obj.source: + to_remove.append(tf) + for tf in to_remove: + tx.output.tracked_fakes.remove(tf) + + # Step 1 - disable grads + with dynamo_disable_grad(tx), torch.no_grad(): + # Step 2 - call `set_` + out = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + torch.Tensor.set_, + *proxy_args_kwargs([obj, val], {}), + ), + ) + + # Step 3 - drop the version counter - this is a step required to get + # .data setting to play correctly with the autograd engine. + # Esentially, dynamo is trying to faithful preserve the (absurd) + # behavior of .data= from eager mode + def _lower_version_count_by_1(x): + version = x._version + if version > 0: + version = version - 1 + torch._C._autograd._unsafe_set_version_counter(x, version) + return x + + tx.output.create_proxy( + "call_function", + _lower_version_count_by_1, + (out.as_proxy(),), + {}, + ) + _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) + # This handles options prop, guards and ends with a clone + # Step 4 - replace all reference to the current object with the new one + return out + + tx.output.side_effects.store_attr(obj, name, val) + return val + elif isinstance(obj, variables.UserDefinedObjectVariable): + unimplemented( + f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}" + ) + elif isinstance(obj, variables.NNModuleVariable): + if not tx.output.is_root_tracer(): + raise AttributeMutationError( + "Can't inplace modify module params/buffers inside HigherOrderOp" + ) + if name_var.is_python_constant() and isinstance( + val, variables.TensorVariable + ): + assigning_fake_val = get_fake_value(val.as_proxy().node, tx) + + try: + getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) + except AttributeError: + getattr_var = None + + if isinstance(getattr_var, variables.TensorVariable): + # get_fake_val will get the same fake tensor + existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) + + # same tensor identiy, setattr is a no-op + mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") + if ( + existing_fake_attr is assigning_fake_val + and mod_setattr is torch.nn.Module.__setattr__ + ): + return getattr_var + + obj.convert_to_unspecialized(tx) + # FIXME (tmanlaibaatar) this is utter hack to unblock HuggingFace export + # Export generally doesn't want to allow mutations on objects directly, + # but we don't have good way to do this rn. For now, we make it an undefined + # behaviour and just set attributes directly on the PretrainedConfig object + # for now. 
+ elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export: + if name_var.is_python_constant() and isinstance( + val, variables.ConstantVariable + ): + setattr( + obj.obj, name_var.as_python_constant(), val.as_python_constant() + ) + return ConstantVariable(None) + + def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker): + return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) + + def call_type(self, tx, obj: VariableTracker): + from .builder import SourcelessBuilder, VariableBuilder + + try: + py_type = obj.python_type() + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None + + if obj.source is None: + return SourcelessBuilder()(tx, py_type) + else: + return VariableBuilder(tx, TypeSource(obj.source))(py_type) + + def call_reversed(self, tx, obj: VariableTracker): + if obj.has_unpack_var_sequence(tx): + items = list(reversed(obj.unpack_var_sequence(tx))) + return variables.TupleVariable(items) + + def call_sorted(self, tx, obj: VariableTracker, **kwargs): + if ( + obj.has_unpack_var_sequence(tx) + and not isinstance(obj, variables.TensorVariable) + and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) + ): + function = kwargs.pop("key", None) + reverse = kwargs.pop( + "reverse", ConstantVariable.create(False) + ).as_python_constant() + assert len(kwargs) == 0 + if function: + items = sorted( + obj.unpack_var_sequence(tx), + key=lambda x: function.call_function( + tx, [x], {} + ).as_python_constant(), + reverse=reverse, + ) + else: + items = sorted( + obj.unpack_var_sequence(tx), + key=lambda x: x.as_python_constant(), + reverse=reverse, + ) + return variables.ListVariable(items) + + def call_chain(self, tx, *args): + if all(obj.has_unpack_var_sequence(tx) for obj in args): + items = [] + for obj in args: + items.extend(obj.unpack_var_sequence(tx)) + return variables.TupleVariable(items) + + def call_islice(self, tx, iterable, *args): + if iterable.has_unpack_var_sequence(tx) and all( + x.is_python_constant() for x in args + ): + const_args = [x.as_python_constant() for x in args] + items = iterable.unpack_var_sequence(tx) + items = list(itertools.islice(items, *const_args)) + return variables.TupleVariable(items) + + # neg is a constant fold function, so we only get here if constant fold is not valid + def call_neg(self, tx, a): + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + (operator.neg)(a.as_proxy()), + sym_num=None, + ) + # None no-ops this handler and lets the driving function proceed + return None + + def call_format(self, tx, _format_string, *args, **kwargs): + format_string = _format_string.as_python_constant() + return variables.StringFormatVariable.create(format_string, args, kwargs) + + def call_id(self, tx, *args): + if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): + nn_mod_variable = args[0] + mod = tx.output.get_submodule(nn_mod_variable.module_key) + return variables.ConstantVariable.create(id(mod)) + else: + unimplemented(f"call_id with args {args}") + + def call_deepcopy(self, tx, x): + unimplemented(f"copy.deepcopy {repr(x)}") + + def _comparison(self, tx, left, right): + """ + Used to implement comparison operators for different types. + For example, list1 < list2 is implemented differently from tensor1 < tensor2 + """ + from . 
import ( + BaseListVariable, + ConstantVariable, + NNModuleVariable, + TensorVariable, + UserDefinedObjectVariable, + UserFunctionVariable, + ) + from .lists import SizeVariable + from .tensor import ( + supported_const_comparison_ops, + supported_tensor_comparison_ops, + ) + + op = self.fn + + def _unimplemented(): + unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}") + + if ( + all( + isinstance(x, (NNModuleVariable, ConstantVariable)) + for x in [left, right] + ) + and op in supported_const_comparison_ops.values() + ): + left = ( + tx.output.get_submodule(left.module_key) + if isinstance(left, NNModuleVariable) + else left.as_python_constant() + ) + right = ( + tx.output.get_submodule(right.module_key) + if isinstance(right, NNModuleVariable) + else right.as_python_constant() + ) + return ConstantVariable.create(op(left, right)) + + if isinstance(left, UserFunctionVariable): + if op not in supported_const_comparison_ops.values(): + _unimplemented() + if not isinstance(right, UserFunctionVariable): + _unimplemented() + return ConstantVariable.create(op(left.fn, right.fn)) + + # Note, we have a rare BaseListVariable subtype mismatch with valid comparison + # x = torch.randn([3, 3]) + # x.size() == (3, 3) # True + # (3, 3) == x.size() # True + if isinstance(left, (SizeVariable, TupleVariable)) and isinstance( + right, (TupleVariable, SizeVariable) + ): + return BaseListVariable.list_compare(tx, op, left, right) + + if isinstance(left, BaseListVariable): + if not type(left) == type(right): # Mismatch in BaseListVariable subclasses + _unimplemented() + return BaseListVariable.list_compare(tx, op, left, right) + + # If they implement set semantics (e.g. SetVariable or DictKeys) + if hasattr(left, "set_items") and hasattr(right, "set_items"): + return ConstantVariable.create(op(left.set_items, right.set_items)) + + if isinstance(left, TensorVariable) or isinstance(right, TensorVariable): + from .builder import wrap_fx_proxy_cls + + if op in [operator.is_, operator.is_not]: + is_result = ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and id(extract_fake_example_value(left.as_proxy().node)) + == id(extract_fake_example_value(right.as_proxy().node)) + ) + if op is operator.is_: + return ConstantVariable.create(is_result) + else: + return ConstantVariable.create(not is_result) + + if op not in supported_tensor_comparison_ops.values(): + _unimplemented() + if ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and (left.size and right.size) is not None + and left.size != right.size + ): + try: + torch.broadcast_shapes(left.size, right.size) + except RuntimeError: + # not broadcastable, can't be compared + _unimplemented() + tensor_cls = left if isinstance(left, TensorVariable) else right + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return wrap_fx_proxy_cls( + type(tensor_cls), # handle Ndarrays and Tensors + tx, + proxy, + ) + + if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable): + if op not in supported_tensor_comparison_ops.values(): + _unimplemented() + + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return SymNodeVariable.create( + tx, + proxy, + sym_num=None, + ) + + if isinstance(left, UserDefinedObjectVariable) and isinstance( + right, UserDefinedObjectVariable + ): + return ConstantVariable.create(op(left.value, right.value)) + + if isinstance(left, (StreamVariable, EventVariable)) 
or isinstance( + right, (StreamVariable, EventVariable) + ): + if type(left) == type(right) and op is operator.eq: + return ConstantVariable(op(left.value, right.value)) + + if isinstance(right, ConstantVariable) or isinstance( + left, ConstantVariable + ): + return ConstantVariable(op(left.value, right.value)) + + if op.__name__.startswith("is_"): + # If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively + if type(left) is not type(right): + return ConstantVariable.create(op.__name__ != "is_") + + if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable): + return ConstantVariable.create(op(left.fn, right.fn)) + + _unimplemented() + + def call_and_(self, tx, a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items & b.set_items)) + # None no-ops this handler and lets the driving function proceed + + def call_or_(self, tx, a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items | b.set_items)) + # None no-ops this handler and lets the driving function proceed + return None + + def call_not_(self, tx, a): + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.not_, *proxy_args_kwargs([a], {}) + ), + sym_num=None, + ) + + # Unwrap the underlying ConstDictVariable + if isinstance(a, DictView): + a = a.dv_dict + if isinstance(a, (ListVariable, ConstDictVariable)): + return ConstantVariable.create(len(a.items) == 0) + + return None + + call_eq = _comparison + call_gt = _comparison + call_lt = _comparison + call_ge = _comparison + call_le = _comparison + call_ne = _comparison + call_is_ = _comparison + call_is_not = _comparison + + call_all = _polyfill_call_impl("all") + call_any = _polyfill_call_impl("any") + + +@contextlib.contextmanager +def dynamo_disable_grad(tx): + from . import GradModeVariable + + org_value = torch.is_grad_enabled() + gmv = GradModeVariable.create(tx, False) + try: + gmv.enter(tx) + yield + finally: + gmv.exit(tx) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/constant.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..fe122599b3034a695c16dc7411a752b83bc81f68 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/constant.py @@ -0,0 +1,213 @@ +# mypy: ignore-errors + +import operator +from typing import Dict, List + +import torch +from torch._dynamo.source import GetItemSource + +from .. 
import variables +from ..exc import unimplemented, UserError, UserErrorType +from ..guards import GuardBuilder, install_guard +from ..utils import common_constant_types, istype, np +from .base import typestr, VariableTracker + +_type_to_assert_reason = { + # NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards. + # A locally created set should always become a SetVariable, as the items in the set will already either be sourced + # from somewhere else, or unsourced. An input set would imply sources derived from set contents. For example, an + # input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set, arbitrary access is + # not possible. This is a solvable problem, but one we have not taken on yet. As such, input sets are not allowed to + # become SetVariables. The solution here is to create a ConstantSetVariable that is more like a ConstantVariable. + # As this does not exist, we cannot add sets to this invariant. + list: "List types must use ListVariable.", + dict: "Dict types must use ConstDictVariable.", + torch.Tensor: "Tensor types must use TensorVariable.", + torch.SymInt: "SymInts must use SymNodeVariable. " + "If the underlying value is static, we will create a ConstantVariable and specialize.", + torch.SymFloat: "SymInts must use SymNodeVariable", +} + + +class ConstantVariable(VariableTracker): + @staticmethod + def create(value, **kwargs) -> VariableTracker: + source = kwargs.get("source", None) + is_literal = ConstantVariable.is_literal(value) + if not is_literal: + for disallowed_type, reason in _type_to_assert_reason.items(): + assert not isinstance(value, disallowed_type), reason + + # Routing for list and tuple literals. + if is_literal and isinstance(value, (list, tuple)): + items = [] + for i, x in enumerate(value): + item_source = GetItemSource(source, i) if source else None + if item_source: + install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + items.append( + ConstantVariable.create( + x, + source=item_source, + ) + ) + return variables.BaseListVariable.cls_for(type(value))(items, **kwargs) + + return ConstantVariable(value, **kwargs) + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + if not ConstantVariable.is_literal(value): + for disallowed_type, reason in _type_to_assert_reason.items(): + assert not isinstance(value, disallowed_type), reason + + assert not isinstance( + value, (list, tuple) + ), "ConstantVariable(list) is banned - please create a ListVariable(items)" + if np is not None and isinstance(value, np.number): + self.value = value.item() + else: + self.value = value + + def as_proxy(self): + return self.value + + def __str__(self): + return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + @property + def items(self): + """ + Need this when adding a BaseListVariable and a ConstantVariable together. + Happens in detectron2. 
+ """ + return self.unpack_var_sequence(tx=None) + + def getitem_const(self, arg: VariableTracker): + return ConstantVariable.create( + self.value[arg.as_python_constant()], + ) + + @staticmethod + def is_literal(obj): + if type(obj) in common_constant_types: + return True + # The structure within is_literal get routed to variables.BaseListVariable + if type(obj) in (list, tuple, set, frozenset, torch.Size): + return all(ConstantVariable.is_literal(x) for x in obj) + return False + + def unpack_var_sequence(self, tx): + try: + return [ConstantVariable.create(x) for x in self.as_python_constant()] + except TypeError as e: + raise NotImplementedError from e + + def const_getattr(self, tx, name): + if isinstance(self.value, type): + raise UserError( + UserErrorType.ANTI_PATTERN, + "Can't access members of type(obj) for a generated custom object. " + "Please use __class__ instead", + case_name="type_reflection_method", + ) + member = getattr(self.value, name) + if callable(member): + raise NotImplementedError() + return member + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .tensor import SymNodeVariable + + if name == "format" and istype(self.value, str): + return variables.BuiltinVariable(str.format).call_function( + tx, [self, *args], kwargs + ) + + if any(isinstance(x, SymNodeVariable) for x in args): + # Promote to SymNodeVariable for operations involving dynamic shapes. + return variables.SymNodeVariable(self.as_proxy(), self.value).call_method( + tx, name, args, kwargs + ) + + try: + const_args = [a.as_python_constant() for a in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + + def has_arith_binop(num_ty): + return ( + isinstance(self.value, num_ty) + and hasattr(operator, name) + and len(args) == 1 + and args[0].is_python_constant() + ) + + if isinstance(self.value, str) and name in str.__dict__.keys(): + method = getattr(self.value, name) + return ConstantVariable.create(method(*const_args, **const_kwargs)) + elif has_arith_binop(int) or has_arith_binop(float): + op = getattr(operator, name) + add_target = const_args[0] + if isinstance(add_target, (torch.SymInt, torch.SymFloat)): + from .tensor import SymNodeVariable + + # Addition between a non sym and sym makes a sym + # sym_num = tx.output.register_attr_or_module( + # add_target, f"sym_shape_{add_target}", source=None + # ) + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return SymNodeVariable.create(tx, proxy, add_target) + return ConstantVariable.create(op(self.value, add_target)) + elif name == "__len__" and not (args or kwargs): + return ConstantVariable.create(len(self.value)) + elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): + assert not kwargs + search = args[0].as_python_constant() + result = search in self.value + return ConstantVariable.create(result) + + unimplemented(f"const method call {typestr(self.value)}.{name}") + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + +class EnumVariable(VariableTracker): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def as_proxy(self): + return self.value + + def __str__(self): + return f"EnumVariable({type(self.value)})" + + def python_type(self): + 
return type(self.value) + + def as_python_constant(self): + return self.value + + def const_getattr(self, tx, name): + member = getattr(self.value, name) + if callable(member): + raise NotImplementedError() + return member diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d33a927f9b075702e4ea3331da0667a73f904373 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py @@ -0,0 +1,825 @@ +# mypy: ignore-errors + +import dataclasses +import inspect +from typing import Callable, Dict, List, Optional + +import torch._C +from torch._guards import Guard + +from .. import variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..device_interface import get_interface_for_device +from ..exc import unimplemented, Unsupported +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalStateSource +from .base import VariableTracker +from .functions import ( + NestedUserFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + WrappedUserFunctionVariable, + WrappedUserMethodVariable, +) + + +@dataclasses.dataclass +class ContextMangerState: + """ + Mutating `self` in VariableTracker is not allowed because we copy + them. This is a mutable container pointed to by context managers + that won't get copied, so it is safe to mutate. + """ + + cleanup_fn: Optional[Callable] = None + proxy: Optional[torch.fx.Proxy] = None + + def cleanup(self): + if self.cleanup_fn is not None: + self.cleanup_fn() + self.cleanup_fn = None + + def cleanup_assert(self): + assert self.cleanup_fn, "multiple exits?" + self.cleanup() + + +class ContextWrappingVariable(VariableTracker): + _nonvar_fields = { + "cm_obj", + "target_values", + "initial_values", + "state", + *VariableTracker._nonvar_fields, + } + + def __init__(self, target_values, initial_values=None, *, state=None, **kwargs): + super().__init__(**kwargs) + self.target_values = target_values + self.initial_values = initial_values + self.state = ContextMangerState() if state is None else state + + def enter(self, tx): + self._call_func(tx, self.target_values) + self.set_cleanup_hook(tx) + return variables.ConstantVariable.create(None) + + def set_cleanup_hook(self, tx, fn=None): + if fn is None: + + def fn(): + self._call_func(tx, self.initial_values) + + self.state.cleanup_fn = fn + tx.output.add_cleanup_hook(self.state.cleanup) + + def exit(self, tx, *args): + self.state.cleanup_assert() + return variables.ConstantVariable.create(None) + + def reconstruct(self, codegen): + codegen( + AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) + ) + + def module_name(self): + raise NotImplementedError("module_name called on base") + + def fn_name(self): + raise NotImplementedError("fn_name called on base") + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + assert len(args) == 1 + if isinstance(args[0], NestedUserFunctionVariable): + args[0] = UserFunctionVariable(args[0].get_function()) + assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable)) + + if isinstance(args[0], UserMethodVariable): + return WrappedUserMethodVariable(args[0], self) + + if isinstance(args[0], UserFunctionVariable): + return WrappedUserFunctionVariable(args[0], self) + + +class GenericContextWrappingVariable(ContextWrappingVariable): + 
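+    # Wraps an arbitrary user context-manager object (cm_obj); __enter__ and
+    # __exit__ are inlined as ordinary user method calls rather than getting
+    # special graph-level handling.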
def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs): + assert cm_obj is not None + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.cm_obj = cm_obj + + def enter(self, tx): + source = None if self.source is None else AttrSource(self.source, "__enter__") + try: + return variables.UserMethodVariable( + self.cm_obj.__enter__.__func__, + variables.UserDefinedObjectVariable(self.cm_obj), + source=source, + ).call_function(tx, [], {}) + except Unsupported as e: + raise unimplemented( + f"Unsupported context manager {self.cm_obj}'s __enter__ function" + ) from e + + def exit(self, tx, *args): + source = None if self.source is None else AttrSource(self.source, "__exit__") + try: + x = variables.UserMethodVariable( + self.cm_obj.__exit__.__func__, + variables.UserDefinedObjectVariable(self.cm_obj), + source=source, + ).call_function( + tx, + [ + variables.ConstantVariable.create(None), + variables.ConstantVariable.create(None), + variables.ConstantVariable.create(None), + ], + {}, + ) + except Unsupported as e: + raise unimplemented( + f"Unsupported context manager {self.cm_obj}'s __exit__ function" + ) from e + + tx.generic_context_manager_depth -= 1 + return x + + +class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): + """represents torch grad requries grad""" + + @staticmethod + def create(tx, target_values, **kwargs): + return GradInplaceRequiresGradCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx): + [enabled] = self.target_values + self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed(enabled) + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.set_inplace_requires_grad_allowed( + self.prev_state + ), + ) + self.state.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (enabled,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx, *args): + self.state.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.grad increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if grad is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a grad + # call from eager that calls the compiled function, as the grad levels + # may be different. 
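+    # (Roughly: FUNCTORCH_STACK_MATCH re-checks the functorch interpreter
+    # stack at call time and forces a recompile if it no longer matches what
+    # was seen during tracing.)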
+ _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + + @staticmethod + def create(tx, **kwargs): + var = GradIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx): + install_guard(self._guards_singleton) + grad_level = torch._C._functorch._grad_increment_nesting() + self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) + self.state.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._grad_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(grad_level) + + def exit(self, tx, *args): + self.state.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._grad_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch VMap increment/decrement nesting""" + + # A guard is needed as the vmap level is baked into the torch FX graph + # generated. This is fine if vmap is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a vmap + # call from eager that calls the compiled function, as the vmap levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + + @staticmethod + def create(tx, target_values, **kwargs): + var = VmapIncrementNestingCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx): + install_guard(self._guards_singleton) + batch_size, randomness = self.target_values + vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness) + self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting()) + self.state.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._vmap_increment_nesting, + (batch_size, randomness), + {}, + ) + return variables.ConstantVariable.create(vmap_level) + + def exit(self, tx, *args): + self.state.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._vmap_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class GradModeVariable(ContextWrappingVariable): + """represents torch.{no_grad,enable_grad,set_grad_mode}()""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) + + @staticmethod + def create(tx, target_value, initialized=False, **kwargs): + var = GradModeVariable( + target_values=[target_value], + initial_values=[torch.is_grad_enabled()], + **kwargs, + ) + if initialized: + var._call_func(tx, var.target_values) + return var + + def __init__(self, target_values, initial_values=None, initialized=True, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx): + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx, *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ): + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx, values): + assert len(values) == 1 + value = values[0] + # Coalesce grad mode mutations + if 
torch.is_grad_enabled() != value: + tx.output.create_node( + "call_function", torch._C._set_grad_enabled, (value,), {} + ) + torch._C._set_grad_enabled(value) + + def module_name(self): + return "torch" + + def fn_name(self): + return "set_grad_enabled" + + +class InferenceModeVariable(ContextWrappingVariable): + @staticmethod + def create(tx, target_value, **kwargs): + var = InferenceModeVariable( + [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs + ) + return var + + def __init__( + self, + target_values, + initial_values=None, + **kwargs, + ): + if initial_values is None: + # This must be called here since function defaults are evaluated at import time + initial_values = torch.is_inference_mode_enabled() + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.target_values = target_values + + def exit(self, tx, *args): + self.state.cleanup_assert() + tx.output.create_node( + "call_function", + torch.autograd.grad_mode._exit_inference_mode, + (self.state.proxy,), + {}, + ) + + def enter(self, tx): + ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) + self.set_cleanup_hook( + tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx) + ) + self.state.proxy = tx.output.create_node( + "call_function", + torch.autograd.grad_mode._enter_inference_mode, + (*self.target_values,), + {}, + ) + + def module_name(self): + return "torch" + + def fn_name(self): + return "inference_mode" + + +class TorchFunctionDisableVariable(ContextWrappingVariable): + """represents whether torch function overrides are enabled or not""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) + + @staticmethod + def create(tx, **kwargs): + var = TorchFunctionDisableVariable( + target_values=[False], + initial_values=[tx.output.torch_function_enabled], + **kwargs, + ) + # mlazos: I think this is here to make sure we don't reinvoke on clone() + var._call_func(tx, [False]) + var.set_cleanup_hook(tx) + return var + + def __init__(self, target_values, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def _call_func(self, tx, values): + assert len(values) == 1 + tx.output.set_torch_function_state(values[0]) + + +class DeterministicAlgorithmsVariable(ContextWrappingVariable): + """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" + + _guards_singleton = Guard( + GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS + ) + + @staticmethod + def create(tx, target_value, **kwargs): + var = DeterministicAlgorithmsVariable( + target_values=[target_value], + initial_values=[torch.are_deterministic_algorithms_enabled()], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__(self, target_values, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def _call_func(self, tx, values): + assert len(values) == 1 + value = values[0] + tx.output.create_node( + "call_function", torch._C._set_deterministic_algorithms, (value,), {} + ), + torch._C._set_deterministic_algorithms(value) + + def module_name(self): + return "torch" + + def 
fn_name(self): + return "use_deterministic_algorithms" + + +class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): + """represents torch.autograd.graph.disable_saved_tensors_hook.""" + + @staticmethod + def create(tx, target_value, **kwargs): + var = DisabledSavedTensorsHooksVariable( + target_values=[target_value], + initial_values=[ + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__(self, target_values, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def _call_func(self, tx, values): + assert len(values) == 1 + value = values[0] + if value is not None: + # Disable `saved_tensors_hooks` with message (`value`) + # OR + # we are exiting this context and restoring the previous message. + tx.output.create_node( + "call_function", + torch._C._autograd._saved_tensors_hooks_disable, + (value,), + {}, + ) + torch._C._autograd._saved_tensors_hooks_disable(value) + else: + # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`. + tx.output.create_node( + "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {} + ) + torch._C._autograd._saved_tensors_hooks_enable() + + def module_name(self): + return "torch.autograd.graph" + + def fn_name(self): + return "disable_saved_tensors_hooks" + + +class AutocastModeVariable(ContextWrappingVariable): + @staticmethod + def create(func, args, kwargs): + assert func in [ + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ] + # device_type : str, + # dtype : Optional[_dtype] = None, + # enabled : bool = True, + # cache_enabled : Optional[bool] = None):cache_enabled + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + target_values = [] + kwargs.clear() + + for key in ["device_type", "dtype", "enabled", "cache_enabled"]: + if key == "device_type" and func in [ + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ]: + arg = "cuda" if func is torch.cuda.amp.autocast else "cpu" + else: + arg = bound_args.arguments[key] + if isinstance(arg, VariableTracker): + target_values.append(arg.as_python_constant()) + else: + target_values.append(arg) + + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) + return var + + def __init__(self, target_values, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.target_values = target_values + + def exit(self, tx, *args): + self.state.cleanup_assert() + tx.output.create_node( + "call_function", torch.amp._exit_autocast, (self.state.proxy,), {} + ) + + def enter(self, tx): + ctx = torch.amp._enter_autocast(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) + self.state.proxy = tx.output.create_node( + "call_function", torch.amp._enter_autocast, (*self.target_values,), {} + ) + + def module_name(self): + return "torch.amp.autocast_mode" + + def fn_name(self): + return "autocast" + + +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + It's used as a placeholder for other context managers that Dynamo doesn't + support yet, e.g, torch.autograd.profiler.record_function. 
+ """ + + def __init__(self, target_values=None, **kwargs): + super().__init__(target_values=target_values, **kwargs) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def exit(self, tx, *args): + return variables.ConstantVariable.create(None) + + def module_name(self): + return "contextlib" + + def fn_name(self): + return "nullcontext" + + +class StreamContextVariable(ContextWrappingVariable): + @staticmethod + def create(tx, target_value, **kwargs): + from .builder import wrap_fx_proxy_cls + + current_stream_method = get_interface_for_device( + target_value.device + ).current_stream + current_stream = wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + current_stream_method, + (None,), + {}, + ), + ) + return StreamContextVariable( + target_values=[target_value], + initial_values=[current_stream], + device=target_value.device, + **kwargs, + ) + + def __init__(self, target_values, device, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.device = device + self.set_stream = get_interface_for_device(self.device).set_stream + self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id + + def enter(self, tx): + # stream generated inside the traced function + if self.target_values[0].as_proxy() is not None: + tx.output.create_proxy( + "call_function", + self.set_stream, + (self.target_values[0].as_proxy(),), + {}, + ) + # stream passed from outside the traced function + else: + stream = self.target_values[0].value + tx.output.create_proxy( + "call_function", + self.set_stream_id, + (stream.stream_id, stream.device_index, stream.device_type), + {}, + ) + self.set_stream(self.target_values[0].value) + self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value)) + + def exit(self, tx, *args): + tx.output.create_proxy( + "call_function", + self.set_stream, + (self.initial_values[0].as_proxy(),), + {}, + ) + self.state.cleanup_assert() + + +class PreserveVersionContextVariable(ContextWrappingVariable): + """ + Wraps torch.autograd._unsafe_preserve_version_counter + """ + + @staticmethod + def constructor(tx): + return variables.LambdaVariable( + lambda tensor: PreserveVersionContextVariable( + tensor, + tensor.var_getattr(tx, "_version"), + ) + ) + + def __init__(self, tensor, prev_version, **kwargs): + kwargs.setdefault("target_values", None) + super().__init__(**kwargs) + self.tensor = tensor + self.prev_version = prev_version + + def enter(self, tx): + pass + + def exit(self, tx, *args): + from ..tensor_version_op import _unsafe_set_version_counter + + return variables.TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [self.tensor, self.prev_version], {}) + + def reconstruct(self, codegen): + unimplemented( + "torch.autograd._unsafe_preserve_version_counter with graph break" + ) + + +class StreamVariable(VariableTracker): + def __init__(self, proxy, value, device, **kwargs): + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + assert ( + value.device.type == device.type + ), "stream value is not equal to the passed device" + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + self.device = device + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + assert hasattr(self.value, name), f"no stream method found 
named {name}" + assert name in [ + "wait_stream", + "synchronize", + "query", + "record_event", + "wait_event", + ], f" unsupported stream method {name}" + + from ..utils import proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait_stream", "synchronize", "wait_event"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return variables.ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=variables.ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name == "record_event": + return wrap_fx_proxy_cls( + target_cls=EventVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + else: + unimplemented(self.device + " stream method " + name + " unsupported") + + def as_proxy(self): + return self.proxy + + def reconstruct(self, codegen): + # If we got here, this stream is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + # Since we just proved that - for other such structures, like lists and dicts, reconstruction + # is fine and sound according to dynamo principles of treating collectives. However, + # streams are special in that we want to preserve the identity of the stream as the same as in the graph + # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not + # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending + # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there. + prefix = f"_stream_{self.device}" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output( + codegen.create_load_global(name, push_null=False, add=True) + ) + + +class EventVariable(VariableTracker): + def __init__(self, proxy, value, **kwargs): + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..utils import proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait", "record", "synchronize"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return variables.ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=variables.ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + else: + unimplemented(f"event method {name} unsupported") + + def as_proxy(self): + return self.proxy + + +class WithExitFunctionVariable(VariableTracker): + def __init__(self, ctx: ContextWrappingVariable, target, **kwargs): + super().__init__(**kwargs) + assert isinstance(ctx, ContextWrappingVariable) + self.ctx = ctx + self.target = target + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + assert not kwargs + return self.ctx.exit(tx, *args) + + def reconstruct(self, codegen): + # Note here we reconstruct the context manager rather than the + # exit function. 
The handler generated by BlockStackEntry + # will re-enter the context in the resume function. + codegen( + AttrSource( + codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name() + ) + ) + + if codegen.tx.output.partial_convert: + codegen.extend_output( + [codegen.create_load_const(val) for val in self.ctx.target_values] + ) + codegen.extend_output( + create_call_function(len(self.ctx.target_values), True) + ) + codegen.append_output(create_instruction("SETUP_WITH", target=self.target)) + codegen.append_output(create_instruction("POP_TOP")) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/dicts.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..10bd5c9ad7a6ee2c64fe3f229f5a9166ba7eec5d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/dicts.py @@ -0,0 +1,897 @@ +# mypy: ignore-errors + +import collections +import dataclasses +import functools +import inspect +import sys +from typing import Dict, List, Optional + +from torch._subclasses.fake_tensor import is_fake + +from .. import variables +from ..bytecode_transformation import ( + create_call_function, + create_call_method, + create_instruction, +) +from ..eval_frame import skip_code + +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GetItemSource +from ..utils import dict_keys, dict_values, istype, specialize_symnode +from .base import MutableLocal, VariableTracker +from .constant import ConstantVariable + +# [Adding a new supported class within the keys of ConstDictVarialble] +# - Add its tracker type to is_hashable +# - (perhaps) Define how it is compared in _HashableTracker._eq_impl + + +def is_hashable(x): + if isinstance(x, variables.TensorVariable): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return x.as_proxy().node.meta.get("example_value") is not None + elif isinstance(x, variables.TupleVariable): + return all(is_hashable(e) for e in x.items) + else: + return isinstance( + x, + ( + variables.BuiltinVariable, + variables.SymNodeVariable, + variables.ConstantVariable, + variables.EnumVariable, + variables.user_defined.UserDefinedClassVariable, + variables.UserFunctionVariable, + variables.SkipFunctionVariable, + variables.misc.NumpyVariable, + variables.NNModuleVariable, + variables.MethodWrapperVariable, + variables.TorchInGraphFunctionVariable, + variables.TypingVariable, + variables.FunctoolsPartialVariable, + ), + ) + + +class ConstDictVariable(VariableTracker): + class _HashableTracker: + """ + Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable + This should not be seen or touched by anything outside of ConstDictVariable and its children + Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing + """ + + def __init__(self, vt): + # We specialize SymNodes + vt = specialize_symnode(vt) + # TODO Temorarily remove to figure out what keys are we breaking on + # and add proper support for them + if not is_hashable(vt): + unimplemented(f"Dict key of type {type(vt)}. 
Key: {vt}") + self.vt = vt + + @property + def underlying_value(self): + if isinstance(self.vt, variables.TensorVariable): + x = self.vt.as_proxy().node.meta["example_value"] + elif isinstance(self.vt, variables.TupleVariable): + Hashable = ConstDictVariable._HashableTracker + x = tuple(Hashable(e).underlying_value for e in self.vt.items) + elif isinstance(self.vt, variables.NNModuleVariable): + return self.vt.module + elif isinstance(self.vt, variables.UserFunctionVariable): + return self.vt.get_function() + else: + x = self.vt.as_python_constant() + return x + + def __hash__(self): + return hash(self.underlying_value) + + @staticmethod + def _eq_impl(a, b): + # TODO: Put this in utils and share it between variables/builtin.py and here + if type(a) != type(b): + return False + elif isinstance(a, tuple): + Hashable = ConstDictVariable._HashableTracker + return len(a) == len(b) and all( + Hashable._eq_impl(u, v) for u, v in zip(a, b) + ) + elif is_fake(a): + return a is b + else: + return a == b + + def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: + Hashable = ConstDictVariable._HashableTracker + assert isinstance(other, Hashable) or ConstantVariable.is_literal( + other + ), type(other) + if isinstance(other, Hashable): + return Hashable._eq_impl(self.underlying_value, other.underlying_value) + + # constant + return Hashable._eq_impl(self.underlying_value, other) + + def __init__( + self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs + ): + super().__init__(**kwargs) + + Hashable = ConstDictVariable._HashableTracker + + # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers + assert all( + isinstance(x, (VariableTracker, Hashable)) + and isinstance(v, VariableTracker) + for x, v in items.items() + ) + + def make_hashable(key): + return key if isinstance(key, Hashable) else Hashable(key) + + self.items = {make_hashable(x): v for x, v in items.items()} + self.user_cls = user_cls + + def as_proxy(self): + return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} + + def as_python_constant(self): + return { + k.vt.as_python_constant(): v.as_python_constant() + for k, v in self.items.items() + } + + def keys_as_python_constant(self): + return {k.vt.as_python_constant(): v for k, v in self.items.items()} + + def python_type(self): + return self.user_cls + + def __contains__(self, vt): + assert isinstance(vt, VariableTracker) + Hashable = ConstDictVariable._HashableTracker + return is_hashable(vt) and Hashable(vt) in self.items + + def reconstruct(self, codegen): + # instructions to load collections.OrderedDict if necessary + if self.user_cls is collections.OrderedDict: + codegen.extend_output( + [ + codegen.create_load_python_module(collections, True), + codegen.create_load_attr("OrderedDict"), + ] + ) + # instructions to build the dict keys and values + for key, value in self.items.items(): + codegen(key.vt) + codegen(value) + # BUILD_MAP and calling collections.OrderedDict if necessary + if self.user_cls is collections.OrderedDict: + codegen.extend_output( + [ + create_instruction("BUILD_MAP", arg=len(self.items)), + *create_call_function(1, False), + ] + ) + # BUILD_MAP only if user_cls is dict + else: + codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items))) + + def getitem_const(self, arg: VariableTracker): + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + raise KeyError(arg.value) + return self.items[key] + + def call_method( + self, + tx, + 
name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + BuiltinVariable, + ConstantVariable, + ListIteratorVariable, + ListVariable, + TupleVariable, + ) + + Hashable = ConstDictVariable._HashableTracker + + arg_hashable = args and is_hashable(args[0]) + + if name == "__getitem__": + assert len(args) == 1 + return self.getitem_const(args[0]) + elif name == "items": + assert not (args or kwargs) + return TupleVariable( + [TupleVariable([k.vt, v]) for k, v in self.items.items()] + ) + elif name == "keys": + assert not (args or kwargs) + return DictKeys(self) + elif name == "values": + assert not (args or kwargs) + return DictValues(self) + elif name == "copy": + assert not (args or kwargs) + return self.clone(items=self.items.copy(), mutable_local=MutableLocal()) + elif name == "__len__": + assert not (args or kwargs) + return ConstantVariable.create(len(self.items)) + elif name == "__setitem__" and arg_hashable and self.mutable_local: + assert not kwargs and len(args) == 2 + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = args[1] + return ConstantVariable.create(None) + elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self: + # missing item, return the default value + if len(args) == 1: + return ConstantVariable(None) + else: + return args[1] + elif name == "pop" and arg_hashable and self.mutable_local: + tx.output.side_effects.mutation(self) + return self.items.pop(Hashable(args[0])) + elif name == "clear": + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif ( + name == "update" + and len(args) == 1 + and isinstance( + args[0], + ( + ConstDictVariable, + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ) + and self.mutable_local + ): + tx.output.side_effects.mutation(self) + if isinstance(args[0], ConstDictVariable): + dict_vt = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) + self.items.update(dict_vt.items) + # Wrap strings + kwargs = { + Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() + } + self.items.update(kwargs) + return ConstantVariable.create(None) + elif name in ("get", "__getattr__") and args[0] in self: + return self.getitem_const(args[0]) + elif name == "__contains__" and len(args) == 1: + return ConstantVariable.create(args[0] in self) + else: + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + return [x.vt for x in self.items.keys()] + + +class DefaultDictVariable(ConstDictVariable): + def __init__(self, items, user_cls, default_factory=None, **kwargs): + super().__init__(items, user_cls, **kwargs) + assert user_cls is collections.defaultdict + self.default_factory = default_factory + + def is_python_constant(self): + # Return false for unsupported defaults. This ensures that a bad handler + # path is not taken in BuiltinVariable for getitem. 
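+        # Hedged illustration (not from the traced program): an empty
+        # collections.defaultdict(set) has a factory outside the supported
+        # [list, tuple, dict] defaults, so it is reported as non-constant here
+        # and BuiltinVariable's constant-dict getitem fast path is avoided.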
+ if self.default_factory not in [list, tuple, dict] and not self.items: + return False + return super().is_python_constant() + + @staticmethod + def is_supported_arg(arg): + if isinstance(arg, variables.BuiltinVariable): + return arg.fn in [list, tuple, dict] + else: + return isinstance(arg, variables.functions.BaseUserFunctionVariable) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__": + assert len(args) == 1 + + if args[0] in self: + return self.getitem_const(args[0]) + else: + if self.default_factory is None: + raise KeyError(f"{args[0]}") + else: + default_var = self.default_factory.call_function(tx, [], {}) + super().call_method( + tx, "__setitem__", (args[0], default_var), kwargs + ) + return default_var + else: + return super().call_method(tx, name, args, kwargs) + + +class SetVariable(ConstDictVariable): + """We model a sets as dictonary with None values""" + + def __init__( + self, + items: List[VariableTracker], + **kwargs, + ): + items = dict.fromkeys(items, SetVariable._default_value()) + super().__init__(items, **kwargs) + + @property + def set_items(self): + return set(self.items.keys()) + + @staticmethod + def _default_value(): + # Variable to fill in he keys of the dictinary + return ConstantVariable.create(None) + + def as_proxy(self): + return {k.vt.as_proxy() for k in self.set_items} + + def python_type(self): + return set + + def as_python_constant(self): + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen): + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> "VariableTracker": + # We foward the calls to the dictionary model + if name == "add": + assert not kwargs + assert len(args) == 1 + name = "__setitem__" + args = (args[0], SetVariable._default_value()) + elif name == "pop": + assert not kwargs + assert not args + # Choose an item at random and pop it via the Dict.pop method + result = self.set_items.pop().vt + super().call_method(tx, name, (result,), kwargs) + return result + return super().call_method(tx, name, args, kwargs) + + def getitem_const(self, arg: VariableTracker): + raise RuntimeError("Illegal to getitem on a set") + + +class DictView(VariableTracker): + """ + Models _PyDictViewObject + + This is an "abstract" class. 
Subclasses will override kv and the items method + """ + + kv: Optional[str] = None + + def __init__(self, dv_dict: ConstDictVariable, **kwargs): + super().__init__(**kwargs) + assert self.kv in ("keys", "values") + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + @property + def view_items(self): + return getattr(self.dv_dict.items, self.kv)() + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + # Implement in the subclasses + raise NotImplementedError() + + def unpack_var_sequence(self, tx): + def unwrap(x): + return x.vt if self.kv == "keys" else x + + return [unwrap(x) for x in self.view_items] + + def reconstruct(self, codegen): + codegen(self.dv_dict) + codegen.extend_output( + [ + create_instruction("LOAD_METHOD", argval=self.kv), + *create_call_method(0), + ] + ) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__len__": + return self.dv_dict.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + +class DictKeys(DictView): + kv = "keys" + + @property + def set_items(self): + return set(self.view_items) + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + return [x.vt for x in self.view_items] + + def python_type(self): + return dict_keys + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__contains__": + return self.dv_dict.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + +class DictValues(DictView): + # DictValues is an iterable but cannot be compared. + kv = "values" + + @property + def view_items_vt(self): + return list(self.view_items) + + def python_type(self): + return dict_values + + +def _is_matching_transformers_cls(cls) -> bool: + mod = sys.modules.get("transformers.file_utils") + return mod is not None and issubclass(cls, mod.ModelOutput) + + +def _is_matching_diffusers_cls(cls) -> bool: + mod = sys.modules.get("diffusers.utils") + return mod is not None and issubclass(cls, mod.BaseOutput) + + +def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": + """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs""" + if name in self.items or hasattr(self.user_cls, name): + return ConstantVariable(True) + elif istype(self.mutable_local, MutableLocal) and self.source is None: + # Something created locally can't have any extra fields on it + return ConstantVariable(False) + elif self.mutable_local is None and self.source: + # Maybe add a guard + try: + example = tx.output.root_tx.get_example_value(self.source) + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + return ConstantVariable(hasattr(example, name)) + except KeyError: + pass + unimplemented( + f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}" + ) + + +class DataClassVariable(ConstDictVariable): + """ + This is a bit of a hack to deal with + transformers.file_utils.ModelOutput() from huggingface. + + ModelOutput causes trouble because it a a mix of a dataclass and a + OrderedDict and it calls super() methods implemented in C. 
+ """ + + # ModelOutput() excludes None, though generic datclasses don't + include_none = False + + @staticmethod + @functools.lru_cache(None) + def _patch_once(): + try: + from transformers.file_utils import ModelOutput + + for obj in ModelOutput.__dict__.values(): + if callable(obj): + skip_code(obj.__code__) + except ImportError: + pass + + try: + from diffusers.utils import BaseOutput + + for obj in BaseOutput.__dict__.values(): + if callable(obj): + skip_code(obj.__code__) + except ImportError: + pass + + @staticmethod + def is_matching_cls(cls): + return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) + + @classmethod + def is_matching_object(cls, obj): + return cls.is_matching_cls(type(obj)) + + @classmethod + def create(cls, user_cls, args, kwargs, options): + DataClassVariable._patch_once() + + skip_code(user_cls.__init__.__code__) + keys = [f.name for f in dataclasses.fields(user_cls)] + bound = inspect.signature(user_cls).bind(*args, **kwargs) + bound.apply_defaults() + assert set(bound.arguments.keys()) == set(keys) + items = {} + for key in keys: + val = bound.arguments[key] + key = ConstantVariable.create(key) + if isinstance(val, VariableTracker): + items[key] = val + else: + if cls.include_none: + assert variables.ConstantVariable.is_literal(val) + items[key] = variables.ConstantVariable.create(val) + else: + assert val is None, f"unexpected {val}" + + if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable): + unimplemented("DataClassVariable iterator constructor") + # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__ + + return cls(items, user_cls, **options) + + @classmethod + def wrap(cls, builder, obj): + user_cls = type(obj) + keys = [f.name for f in dataclasses.fields(user_cls)] + + excluded = [] + items = {} + for key in keys: + # __init__ function of a dataclass might not have yet defined the key + if hasattr(obj, key): + val = getattr(obj, key) + var = builder.__class__( + tx=builder.tx, source=AttrSource(builder.source, key) + )(val) + if val is not None or cls.include_none: + key = ConstantVariable.create(key) + items[key] = var + else: + excluded.append(var) + return cls(items, user_cls) + + def __init__(self, items, user_cls, **options): + super().__init__(items, user_cls, **options) + assert self.is_matching_cls(user_cls) + + def as_proxy(self): + raise NotImplementedError() + + def reconstruct(self, codegen): + codegen.extend_output([codegen._create_load_const(self.user_cls)]) + # All the keys are just wrapped strings + d = self.keys_as_python_constant() + codegen.foreach(d.values()) + keys = tuple(d.keys()) + codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__": + assert not kwargs and len(args) == 1 + val = args[0] + if val.python_type() == str: + return self.getitem_const(val) + else: + return self.call_method(tx, "to_tuple", [], {}).call_method( + tx, "__getitem__", args, kwargs + ) + elif name == "to_tuple": + assert not (args or kwargs) + return variables.TupleVariable(list(self.items.values())) + elif name == "__setattr__": + name = "__setitem__" + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name: str) -> "VariableTracker": + name_vt = ConstantVariable.create(name) + if name_vt in self: + return self.call_method(tx, "__getitem__", [name_vt], {}) + elif not 
self.include_none: + defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} + if name in defaults: + assert variables.ConstantVariable.is_literal(defaults[name]) + return variables.ConstantVariable.create(defaults[name]) + super().var_getattr(tx, name) + + call_hasattr = _call_hasattr_customobj + + +class CustomizedDictVariable(ConstDictVariable): + @staticmethod + def is_matching_cls(cls): + # True if using default OrderedDict.__init__ and did not implement __post_init__ + if ( + issubclass(cls, collections.OrderedDict) + and cls.__init__ is collections.OrderedDict.__init__ + and not hasattr(cls, "__post_init__") + ): + return True + # hack for HF usecase: + # assume dataclass annotation for ModelOutput subclass + # assume self.create is AA to ModelOutput.__post_init__ + return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) + + @classmethod + def is_matching_object(cls, obj): + return cls.is_matching_cls(type(obj)) + + # called from user_defined.py + # when is_matching_cls(cls) is true + @classmethod + def create(cls, user_cls, args, kwargs, options): + # avoid tracing when returning ModelOutput from forward func + for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"): + if hasattr(user_cls, attr_name): + fn = getattr(user_cls, attr_name) + assert callable(fn), f"expect callable attr {attr_name}" + if hasattr(fn, "__code__"): + skip_code(fn.__code__) + + if dataclasses.is_dataclass(user_cls): + # @dataclass CustomDict(a=1, b=2) + bound = inspect.signature(user_cls).bind(*args, **kwargs) + bound.apply_defaults() + + def make_var(x): + if isinstance(x, VariableTracker): + return x + elif ConstantVariable.is_literal(x): + return ConstantVariable.create(x) + else: + unimplemented( + "expect VariableTracker or ConstantVariable.is_literal" + ) + + items = { + ConstantVariable.create(k): make_var(v) + for k, v in bound.arguments.items() + } + elif not args: + # CustomDict(a=1, b=2) in the general (non-dataclass) case. 
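+            # Sketch of the mapping built below (illustrative only): keyword
+            # names become ConstantVariable keys, so CustomDict(a=1, b=2) yields
+            # {ConstantVariable("a"): <VT for 1>, ConstantVariable("b"): <VT for 2>};
+            # the values are already VariableTrackers by the time we get here.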
+ items = {ConstantVariable.create(k): v for k, v in kwargs.items()} + elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs: + # CustomDict({'a': 1, 'b': 2}) + items = args[0].items + else: + unimplemented("custom dict init with args/kwargs unimplemented") + + return cls(items, user_cls, **options) + + # called from builder.py + @classmethod + def wrap(cls, builder, obj): + raise NotImplementedError() + + def __init__(self, items, user_cls, **options): + super().__init__(items, user_cls, **options) + assert self.is_matching_cls(user_cls) + + def as_proxy(self): + raise NotImplementedError() + + # 'RETURN_VALUE triggered compile' + # called from torch/_dynamo/codegen.py + def reconstruct(self, codegen): + codegen.extend_output([codegen._create_load_const(self.user_cls)]) + # All the keys are just wrapped strings + d = self.keys_as_python_constant() + codegen.foreach(d.values()) + keys = tuple(d.keys()) + codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + fn = getattr(self.user_cls, name) + source = None if self.source is None else AttrSource(self.source, name) + + if hasattr(fn, "__objclass__") and fn.__objclass__ in ( + dict, + collections.OrderedDict, + ): + # for python dict method without overridden + return super().call_method(tx, name, args, kwargs) + elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"): + # for user overridden method + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=source), + [self] + list(args), + kwargs, + ) + + unimplemented("custom dict: call_method unimplemented name=%s", name) + + def var_getattr(self, tx, name: str) -> "VariableTracker": + name_vt = ConstantVariable.create(name) + if name_vt in self: + return self.call_method(tx, "__getitem__", [name_vt], {}) + super().var_getattr(tx, name) + + call_hasattr = _call_hasattr_customobj + + +@functools.lru_cache(None) +def _install_PretrainedConfig_patch(): + import transformers + + # We need to monkeypatch transformers here, sadly. + # TODO(voz): Upstream to transformers lib + + def _dynamo_overriden_transformers_eq(self, other): + if not hasattr(other, "__dict__"): + return False + return self.__dict__ == other.__dict__ + + transformers.configuration_utils.PretrainedConfig.__eq__ = ( + _dynamo_overriden_transformers_eq + ) + + +class HFPretrainedConfigVariable(VariableTracker): + """ + Hack for HuggingFace PretrainedConfig + """ + + @staticmethod + def is_matching_cls(cls): + mod = sys.modules.get("transformers.configuration_utils") + is_match = mod is not None and issubclass(cls, mod.PretrainedConfig) + + # Lazily install monkeypatch the first time we see it in dynamo + if is_match: + _install_PretrainedConfig_patch() + return is_match + + @classmethod + def is_matching_object(cls, obj): + return cls.is_matching_cls(type(obj)) + + def __init__(self, obj, **kwargs): + super().__init__(**kwargs) + self.obj = obj + assert self.is_matching_cls(type(obj)) + + def var_getattr(self, tx, name: str) -> "VariableTracker": + from . import ConstantVariable + + return ConstantVariable.create(getattr(self.obj, name)) + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + return variables.ConstantVariable.create(hasattr(self.obj, name)) + + +class PythonSysModulesVariable(VariableTracker): + """Special case for sys.modules. 
+ + Without this we will guard on the exact set of modules imported in the + lifetime of the python program. + """ + + def python_type(self): + return dict + + def reconstruct(self, codegen): + codegen.extend_output( + [ + codegen.create_load_python_module(sys, True), + codegen.create_load_attr("modules"), + ] + ) + + def call_method( + self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ): + from .builder import VariableBuilder + + if name == "__getitem__": + return self.call_getitem(tx, *args, **kwargs) + elif name == "get": + return self.call_get(tx, *args, **kwargs) + elif name == "__contains__": + return self.call_contains(tx, *args, **kwargs) + + # Fallback to dict implementation + real_dict = VariableBuilder(tx, self.source)(sys.modules) + return real_dict.call_method(tx, name, args, kwargs) + + def _contains_helper(self, tx, key: VariableTracker): + k = key.as_python_constant() + has_key = k in sys.modules + install_guard( + self.make_guard( + functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key) + ) + ) + return k, has_key + + def call_contains(self, tx, key: VariableTracker): + k, has_key = self._contains_helper(tx, key) + return ConstantVariable.create(value=has_key) + + def call_get( + self, tx, key: VariableTracker, default: Optional[VariableTracker] = None + ): + from .builder import VariableBuilder + + k, has_key = self._contains_helper(tx, key) + + if has_key: + return VariableBuilder( + tx, + GetItemSource(self.source, k), + )(sys.modules[k]) + + if default is not None: + return default + + return ConstantVariable.create(value=None) + + def call_getitem(self, tx, key: VariableTracker): + from .builder import VariableBuilder + + k, has_key = self._contains_helper(tx, key) + return VariableBuilder( + tx, + GetItemSource(self.source, k), + )(sys.modules[k]) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/distributed.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..066be295b93170c4f07a41be528650813649db43 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/distributed.py @@ -0,0 +1,388 @@ +# mypy: ignore-errors +import functools +import inspect +from typing import Dict, List + +import torch +from ...fx.experimental._backward_state import BackwardState +from .. import compiled_autograd, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..exc import unimplemented +from ..external_utils import call_module_hooks_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalSource +from ..utils import istype +from .base import VariableTracker +from .constant import ConstantVariable + + +class DistributedVariable(VariableTracker): + """ + The base distributed variable that encapsulates common methods + for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.). + Concrete distributed objects could inherit this class and add object + specific logic. + + i.e. It provides the check on the distributed package existance + and hold the tracking value for the corresponding distributed object. 
+ """ + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + if not DistributedVariable.is_available(): + unimplemented("torch.distributed package is not available!") + self.value = value + + def python_type(self): + return type(self.value) + + @staticmethod + def is_available(): + # check if the distributed package is available or not + return torch.distributed.is_available() + + +def is_from_local(value): + if not DistributedVariable.is_available(): + return False + from torch.distributed._tensor import DTensor + + return inspect.isfunction(value) and value is DTensor.from_local + + +def is_constant_pg_functions(value): + if not DistributedVariable.is_available(): + return False + + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + + constant_processgroup_functions = [ + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ] + + return inspect.isfunction(value) and value in constant_processgroup_functions + + +class PlacementClassVariable(DistributedVariable): + @staticmethod + def is_placement_type(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed._tensor.placement_types import Placement + + return type(value) is type and issubclass(value, Placement) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if ( + inspect.getattr_static(self.value, "__new__", None) in (object.__new__,) + and self.source + ): + # NOTE: we don't need to track mutations to the placement class as they + # suppose to be immutable. + new_obj = object.__new__(self.value) + var = PlacementVariable(new_obj) + if inspect.getattr_static(self.value, "__init__", None): + var.call_method(tx, "__init__", args, kwargs) + return var + + return super().call_function(tx, args, kwargs) + + +class PlacementVariable(DistributedVariable): + @staticmethod + def is_placement(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed._tensor.placement_types import Placement + + return isinstance(value, Placement) + + def as_python_constant(self): + return self.value + + def var_getattr(self, tx, name: str) -> VariableTracker: + if name == "dim": + return ConstantVariable.create(self.value.dim) + return super().var_getattr(tx, name) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable + + # Placement types dynamo tracking only allows following methods + # and __setattr__ is for case like `Shard(dim)` and methods. + # Methods in the list must satisfy: + # 1. Input arguments are constants and do not need to be guarded on; + # 2. Output is constant with respect to their inputs + constant_fold_functions = [ + "__init__", + "__setattr__", + "is_shard", + "is_partial", + "is_replicate", + ] + + if name in constant_fold_functions: + try: + value_type = type(self.value) + assert ( + inspect.getattr_static(value_type, "__getattr__", None) is None + ), "no custom getattr allowed!" 
+ method = inspect.getattr_static(value_type, name) + except AttributeError: + method = None + if method is object.__init__: + return ConstantVariable.create(None) + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + if name == "__setattr__": + method(self.value, *args, **kwargs) + return self + constant_val = method(self.value, *args, **kwargs) + return ConstantVariable.create(constant_val) + + return super().call_method(tx, name, args, kwargs) + + +class DeviceMeshVariable(DistributedVariable): + @staticmethod + def is_device_mesh(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.device_mesh import DeviceMesh + + return istype(value, DeviceMesh) + + def as_python_constant(self): + return self.value + + def var_getattr(self, tx, name: str) -> VariableTracker: + if name == "ndim": + return ConstantVariable.create(self.value.ndim) + return super().var_getattr(tx, name) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "size": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create(self.value.size(*const_args, **const_kwargs)) + if name == "get_coordinate": + return ConstantVariable.create(self.value.get_coordinate()) + if name == "get_group": + return ConstantVariable.create(self.value.get_group()) + if name == "_get_or_create_default_group": + return ProcessGroupVariable(self.value._get_or_create_default_group()) + return super().call_method(tx, name, args, kwargs) + + +class ProcessGroupVariable(DistributedVariable): + """ + We don't want a ProcessGroup object to end up in our output graph. + + But it's common for dynamo to intercept a PG that is then used to get info like + rank() or world_size(), as well as passed to utility functions in distributed_c10d + which desugar it into plain types like a ranklist and tag. + + For convenience and proper guarding, we construct a variable type. + + TODO: make it possible to use ProcessGroupVariable as input to simple functions + like _expand_group without dynamo complaining about making a proxy for it. + It is not a tensor-like type, and we don't want a proxy- but dynamo assumes + torch library functions are dealing with tensor-like types and would have proxies + for their args. + TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors + or just graph-break whenever one of our special cases is not hit? + """ + + def as_python_constant(self): + return self.value + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "rank": + return variables.ConstantVariable.create(self.value.rank()) + if name == "size": + return variables.ConstantVariable.create(self.value.size()) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "group_name": + return variables.ConstantVariable.create(self.value.group_name) + if name in ["rank", "size"]: + return variables.LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + # TODO should this just raise unimplemented? 
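+        # For illustration: reading pg.rank / pg.size as attributes returns the
+        # LambdaVariable above so the eventual call still constant-folds; any
+        # other attribute access currently falls through to the generic
+        # VariableTracker handling on the next line.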
+ return super().var_getattr(tx, name) + + @staticmethod + def is_process_group(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch._C._distributed_c10d import ProcessGroup + from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + + return istype(value, (ProcessGroup, FakeProcessGroup)) + + @staticmethod + def get_global_pg_variable(): + """ + Make a ProcessGroupVariable from torch.distributed.group.WORLD and + intall guards. + """ + import torch.distributed as dist + + source = AttrSource( + AttrSource( + base=AttrSource( + base=GlobalSource(global_name="torch"), + member="distributed", + get_static=False, + ), + member="group", + get_static=False, + ), + member="WORLD", + get_static=False, + ) + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return ProcessGroupVariable( + dist.group.WORLD, + source=source, + ) + + +class BackwardHookVariable(VariableTracker): + """ + Handles torch.utils.hooks.BackwardHook for module-level backward + hooks. + """ + + @staticmethod + def create( + tx, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + ): + if not compiled_autograd.compiled_autograd_enabled: + unimplemented("module-level backwards hooks require compiled autograd") + + def _in_graph_bw_hooks(bw_state: BackwardState): + """ + Rather than installing the user hooks in the graph (which + don't survive AotAutograd), we install hooks that will call + trace_wrapped in the backward pass that CompiledAutograd + can turn into actual hook calls. + """ + return torch.utils.hooks.BackwardHook( + None, + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_hooks_name, + module_name=module_name, + ), + ), + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_pre_hooks_name, + module_name=module_name, + ), + ), + ) + + module_name, bw_state_proxy = tx.output.add_backward_state_hook(module) + user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks) + user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks) + proxy = tx.output.create_proxy( + "call_function", + _in_graph_bw_hooks, + (bw_state_proxy,), + {}, + ) + proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ()) + return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks) + + def __init__( + self, + proxy: torch.fx.Proxy, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + **options, + ): + super().__init__(**options) + self.proxy = proxy + self.module = module + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + + def as_proxy(self): + return self.proxy + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> VariableTracker: + if name in ("setup_input_hook", "setup_output_hook"): + return self._setup_hook(tx, name, *args, **kwargs) + return super().call_method(tx, name, args, kwargs) + + def _setup_hook(self, tx, hook_method_name, args): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + hook_method_name, + (self.as_proxy(), args.as_proxy()), + {}, + ), + ) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/functions.py 
b/MLPY/Lib/site-packages/torch/_dynamo/variables/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c7c64010e19361d2b596553e7ffd38985486a2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/functions.py @@ -0,0 +1,947 @@ +# mypy: ignore-errors + +import collections +import functools +import inspect +import itertools +import types +from typing import Dict, List, Optional, TYPE_CHECKING, Union + +import torch + +from .. import variables +from ..bytecode_transformation import create_call_function, create_rot_n +from ..exc import unimplemented, Unsupported +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource +from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell +from .base import MutableLocal, typestr, VariableTracker +from .constant import ConstantVariable +from .distributed import ProcessGroupVariable + +if TYPE_CHECKING: + from torch._guards import Source + + +def wrap_bound_arg(tx, val, source=None): + # Source propagation is best effort since not every object we encounter has a source to begin with. + if isinstance(val, VariableTracker): + return val + elif not source: + from torch._dynamo.variables.builder import SourcelessBuilder + + return SourcelessBuilder()(tx, val) + else: + # Create a lazy variable to avoid guarding on __defaults__ unless really + # needed. + return variables.LazyVariableTracker.create(val, source) + + +def wrap_args_kwargs(tx, result): + for k, v in list(result.items()): + if isinstance(v, (tuple, dict)): + # args/kwargs + result[k] = wrap_bound_arg(tx, v) + + +def init_cellvars(parent, result, code): + closure_cells = dict() + side_effects = parent.output.side_effects + + # for name in itertools.chain(code.co_cellvars, code.co_freevars): + for name in code.co_cellvars: + closure_cells[name] = side_effects.track_cell_new() + if name in result: + side_effects.store_cell(closure_cells[name], result.pop(name)) + + return closure_cells + + +def _create_nested_fn( + code, f_globals, name, defaults, closure, kwdefaults, annotations +): + from types import FunctionType + + func = FunctionType(code, f_globals, name, defaults, closure) + func.__kwdefaults__ = kwdefaults + + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert annotations is None or isinstance(annotations, dict) + func.__annotations__ = annotations + + return func + + +class BaseUserFunctionVariable(VariableTracker): + def get_filename(self): + return self.get_code().co_filename + + def get_name(self): + return self.get_code().co_name + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + return tx.inline_user_function_return( + self, list(self.self_args()) + list(args), kwargs + ) + + def call_hasattr(self, tx, name: str) -> VariableTracker: + result = False + + try: + result = hasattr(self.get_function(), name) + except NotImplementedError: + if name == "__name__" and isinstance(self, NestedUserFunctionVariable): + result = True + return variables.ConstantVariable.create(result) + + def inspect_parameter_names(self): + return list(inspect.signature(self.get_function()).parameters) + + def closure_vars(self, tx): + return {} + + +class UserFunctionVariable(BaseUserFunctionVariable): + """Some unsupported user-defined global function""" + + 
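+    # Rough picture (illustrative, not exhaustive): this wraps a plain
+    # types.FunctionType such as
+    #     def scale(x, factor=2.0): return x * factor
+    # so Dynamo can inline it. bind_args() below maps call arguments,
+    # __defaults__, __kwdefaults__ and closure cells onto VariableTrackers
+    # before the body is inlined.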
@classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls( + value, + source=source, + ) + + def __init__(self, fn, is_constant=False, **kwargs): + super().__init__(**kwargs) + if getattr(fn, "_dynamo_marked_constant", False): + # This method should be treated as a constant for the purposes of compilation + self.is_constant = True + else: + self.is_constant = False + + assert isinstance( + fn, (types.FunctionType, torch.jit.ScriptFunction) + ), f"expected FunctionType found {typestr(fn)} {fn}" + # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + # unpack torch.jit.script_if_tracing + if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False): + fn = inspect.getattr_static(fn, "__original_fn", fn) + self.fn: types.FunctionType = fn + + def as_python_constant(self): + if istype(self, UserFunctionVariable): + return self.fn + # subclasses (such as methods) usually aren't a constant + return super().as_python_constant() + + def self_args(self): + return [] + + def get_function(self): + return self.fn + + def get_code(self): + return self.fn.__code__ + + def python_type(self): + return types.FunctionType + + def has_self(self): + return getattr(self.fn, "__self__", None) is not None + + def get_globals(self): + return self.fn.__globals__ + + def bind_args(self, parent, args, kwargs): + assert not self.is_constant + tx = parent.output.root_tx + wrap = functools.partial(wrap_bound_arg, tx=tx) + + fn: types.FunctionType = self.fn + defaults = fn.__defaults__ or [] + defaults_sources = [ + None if self.source is None else DefaultsSource(self.source, idx) + for idx, _ in enumerate(defaults) + ] + fake_func = types.FunctionType( + fn.__code__, + fn.__globals__, + fn.__name__, + tuple( + [ + wrap(val=arg, source=source) + for arg, source in zip(defaults, defaults_sources) + ] + ), + fn.__closure__, + ) + if fn.__kwdefaults__: + kwdefaults_sources = { + k: None + if self.source is None + else DefaultsSource(self.source, k, is_kw=True) + for k in fn.__kwdefaults__ + } + fake_func.__kwdefaults__ = { + k: wrap(val=v, source=kwdefaults_sources[k]) + for k, v in fn.__kwdefaults__.items() + } + + bound = inspect.signature(fake_func).bind(*args, **kwargs) + bound.apply_defaults() + result = dict(bound.arguments.items()) + + wrap_args_kwargs(tx, result) + closure_cells = init_cellvars(parent, result, fn.__code__) + closure = self.fn.__closure__ or () + assert len(closure) == len(self.fn.__code__.co_freevars) + for idx, name, cell in zip( + itertools.count(), self.fn.__code__.co_freevars, closure + ): + if name == "__class__": + source = AttrSource(self.source, "__class__") if self.source else None + result[name] = variables.UserDefinedClassVariable( + cell.cell_contents, + source=source, + ) + else: + var = tx.match_nested_cell(name, cell) + if var is not None: + # optimization for cleaner codegen + result[name] = var + elif self.source: + from .builder import VariableBuilder + + side_effects = parent.output.side_effects + if cell in side_effects: + out = side_effects[cell] + else: + closure_cell = GetItemSource( + AttrSource(self.source, "__closure__"), idx + ) + closure_cell_contents = AttrSource( + closure_cell, "cell_contents" + ) + try: + contents_var = VariableBuilder( + parent, closure_cell_contents + )(cell.cell_contents) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + + if ( + 
closure_cell_contents.name() + not in tx.mutated_closure_cell_contents + ): + # Optimistically don't allocate the cell, to + # reduce the number of side effects. This is + # important for cond, as without it, any accesses + # to closures create side effects and cond doesn't + # support side effects. If we're wrong and this + # closure cell gets written to, we will restart + # the analysis with this cell's name in the + # mutated list here + result[name] = contents_var + continue + + # cells are written to with "cell_contents", + # so the source should just be the closure_cell, not its contents + out = side_effects.track_cell_existing(closure_cell, cell) + side_effects.store_cell( + out, + contents_var, + ) + + result[name] = out + + else: + from .builder import SourcelessBuilder + + result[name] = SourcelessBuilder()(tx, cell.cell_contents) + + return result, closure_cells + + def export_freevars(self, parent, child): + pass + + def call_hasattr(self, tx, name: str) -> VariableTracker: + result = hasattr(self.fn, name) + return variables.ConstantVariable.create(result) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if self.is_constant: + return invoke_and_store_as_constant( + tx, self.fn, self.get_name(), args, kwargs + ) + + return super().call_function(tx, args, kwargs) + + +class UserMethodVariable(UserFunctionVariable): + """Some unsupported user-defined method""" + + def __init__(self, fn, obj, **kwargs): + super().__init__(fn=fn, **kwargs) + self.obj = obj + + def __str__(self): + return f"{self.__class__.__name__}({self.fn}, {self.obj})" + + def self_args(self): + return [self.obj] + + def python_type(self): + return types.MethodType + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution + # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method + # since we ensure `forward` of allowed modules can be traced by AOT safely. + # Note this is not only for allowed modules, as user customized modules can extend from + # allowed modules but using parent's `forward` method, which is also covered by this branch. + + # If we are tracing the higher order op, we want Dynamo to step inside + # the module call so that Dynamo can see the underlying parameters and + # buffers and raise them as inputs to the graph. The is_root_tracer + # check bypasses the if condition for non-root tracers and directly + # calls the super().call_function at the end, which is basically + # equivalent of inlining the method. 
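+        # Hedged example of the branch below: tracing mod.forward(x) where mod
+        # is an nn.Linear gives fn.__module__ == "torch.nn.modules.linear",
+        # which starts with "torch.nn.", so the call is redirected to
+        # NNModuleVariable.call_method instead of being inlined here.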
+ if tx.output.is_root_tracer() and isinstance( + self.obj, variables.NNModuleVariable + ): + module_attr = getattr(self.fn, "__module__", "") + if ( + module_attr is not None + and module_attr.startswith("torch.nn.") + or self.is_constant + ): + return self.obj.call_method( + tx, self.fn.__name__, args, kwargs, constant=self.is_constant + ) + return super().call_function(tx, args, kwargs) + + def inspect_parameter_names(self): + return super().inspect_parameter_names()[1:] + + +class WrappedUserMethodVariable(UserMethodVariable): + def __init__(self, wrapped, context, **kwargs): + kwargs.pop("fn", None) + kwargs.pop("obj", None) + super().__init__(wrapped.fn, wrapped.obj, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + +class WrappedUserFunctionVariable(UserFunctionVariable): + def __init__(self, wrapped, context, **kwargs): + kwargs.pop("fn", None) + kwargs.pop("obj", None) + super().__init__(wrapped.fn, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + +def invoke_and_store_as_constant(tx, fn, name, args, kwargs): + def convert(x): + if isinstance(x, variables.TensorVariable): + return x.get_real_value() + return x.as_python_constant() + + args = [convert(x) for x in args] + kwargs = {k: convert(v) for k, v in kwargs.items()} + res = fn(*args, **kwargs) + return tx.output.register_attr_or_module( + res, + name, + source=ConstantSource(name), + ) + + +class NestedUserFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "closure_scope", + "f_globals", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + fn_name, + code, + f_globals, + defaults, + kwdefaults, + annotations, + closure, + closure_scope, + wrapped_reconstructible=None, + **kwargs, + ): + super().__init__(**kwargs) + assert isinstance(fn_name.as_python_constant(), str) + assert isinstance(code.as_python_constant(), types.CodeType) + assert isinstance(f_globals, dict) + self.fn_name = fn_name + self.code = code + self.f_globals = f_globals + self.defaults = defaults + self.kwdefaults = kwdefaults + self.annotations = annotations + self.closure = closure + if closure is None: + closure_scope = None + self.closure_scope = closure_scope + # Either a source or a VT with .can_reconstruct() == True + self.wrapped_reconstructible: Optional[ + Union[Source, VariableTracker] + ] = wrapped_reconstructible + + def self_args(self): + return [] + + def get_code(self): + return self.code.as_python_constant() + + def get_function(self): + if self.closure: + raise NotImplementedError() + func = types.FunctionType( + self.code.as_python_constant(), + self.f_globals, + self.fn_name.as_python_constant(), + ) + if self.defaults: + func.__defaults__ = self.defaults.as_python_constant() + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.as_python_constant() + if self.annotations: + annotations = self.annotations.as_python_constant() + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set 
to a dict object + assert isinstance(annotations, dict) + func.__annotations__ = annotations + return func + + def has_closure(self): + return self.closure is not None + + def has_self(self): + return False + + def get_globals(self): + return self.f_globals + + def bind_args(self, parent, args, kwargs): + from .misc import InlinedClosureVariable + + code = self.get_code() + func = types.FunctionType( + code, + self.f_globals, + self.fn_name.as_python_constant(), + tuple(self.defaults.items) if self.defaults else None, + tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), + ) + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() + bound = inspect.signature(func).bind(*args, **kwargs) + bound.apply_defaults() + result = dict(bound.arguments.items()) + wrap_args_kwargs(parent.output.root_tx, result) + closure_cells = init_cellvars(parent, result, code) + + for idx, name in enumerate(code.co_freevars): + cell = self.closure.items[idx] + assert getattr(cell, name, name) == name + assert name not in result + if isinstance(cell, InlinedClosureVariable): + # InlinedClosureVariable's are created from LOAD_CLOSURE's from + # InliningInstructionTranslators when the variable name is not found in closure_cells. + # They should remain outside of closure_cells, so that our callee (the + # InliningInstructionTranslator that traces `func`) handles + # the cell correctly - that is, the cell's contents are treated as if they + # are local variables, like in UserFunctionVariable's bind_args for freevars. + cand = parent + while cand and name not in cand.symbolic_locals: + cand = cand.parent + if cand is None: + raise RuntimeError( + f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack" + ) + result[name] = cand.symbolic_locals[name] + else: + closure_cells[name] = self.closure.items[idx] + + return result, closure_cells + + def export_freevars(self, parent, child): + code = self.get_code() + for var in code.co_freevars: + if var in child.symbolic_locals: + parent.symbolic_locals[var] = child.symbolic_locals[var] + + def reconstruct(self, codegen): + codegen.load_import_from(__name__, "_create_nested_fn") + codegen(self.code) + codegen.extend_output([codegen._create_load_const(self.f_globals)]) + codegen(ConstantVariable.create(self.code.value.co_name)) + + if self.defaults: + codegen(self.defaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.closure: + codegen(self.closure) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.kwdefaults: + codegen(self.kwdefaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.annotations: + try: + annotations = self.annotations.as_python_constant() + codegen.extend_output([codegen._create_load_const(annotations)]) + except NotImplementedError: + codegen(self.annotations) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + codegen.extend_output(create_call_function(7, push_null=True)) + + if self.wrapped_reconstructible: + codegen.load_import_from("functools", "wraps") + codegen(self.wrapped_reconstructible) + codegen.extend_output(create_call_function(1, True)) + codegen.extend_output(create_rot_n(2)) + codegen.extend_output(create_call_function(1, True)) + + +class SkipFunctionVariable(VariableTracker): + def __init__(self, value, reason=None, **kwargs): + super().__init__(**kwargs) + self.value = value + self.reason = reason + + def python_type(self): + return type(self.value) + + def 
as_python_constant(self): + return self.value + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls( + value, + source=source, + ) + + @staticmethod + @functools.lru_cache(None) + def fold_through_function_to_wrapper(): + return { + collections.namedtuple: variables.UserDefinedClassVariable, + } + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}") + # Fold through the functions(e.g, collections.namedtuple) + # that inputs & outputs are all python constants + elif ( + self.value in self.fold_through_function_to_wrapper().keys() + and check_constant_args(args, kwargs) + ): + value = self.value( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + return self.fold_through_function_to_wrapper().get(self.value)( + value, mutable_local=MutableLocal() + ) + elif ( + self.value is functools.wraps + and not kwargs + and len(args) == 1 + and ( + args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx) + ) + ): + + def wraps(fn): + if isinstance(fn, variables.NestedUserFunctionVariable): + if args[0].source: + reconstructible = args[0].source + else: + reconstructible = args[0] + return fn.clone(wrapped_reconstructible=reconstructible) + unimplemented(f"functools.wraps({fn})") + + return variables.LambdaVariable(wraps) + else: + try: + path = inspect.getfile(self.value) + except TypeError: + path = f"Builtin {self.value.__name__}" + msg = f"'skip function {self.value.__qualname__} in file {path}'" + msg += f"', {self.reason}'" if self.reason else "" + unimplemented(msg) + + +def _traceable_collective_remaps(): + # We can't rely on importing from distributed, since it's not always built + if torch.distributed.is_available(): + from torch.distributed._functional_collectives import ( + traceable_collective_remaps, + ) + + return traceable_collective_remaps + return {} + + +def _traceable_collectives_source(tx, fn): + assert torch.distributed.is_available(), "Illegal invocation." + assert fn in _traceable_collective_remaps().values() + + inner_name = fn.__name__ + path_source = tx.import_source("torch.distributed._functional_collectives") + return AttrSource(path_source, inner_name) + + +class CollectiveFunctionRewriteVariable(UserFunctionVariable): + """ + Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives. + + This class provides both a way to check if a function is remappable, and perform the remapping. + + In the case that a function is 'remappable' but only for some combinations of call-time arguments, + we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse + than status-quo as we currently graph-break on all distributed.* collectives. 
+ """ + + def __init__(self, fn, *, replacement_var, **kwargs): + super().__init__(fn, **kwargs) + assert isinstance(replacement_var, UserFunctionVariable) + self.replacement_var = replacement_var + + @staticmethod + def create(tx, old_fn, source, **options): + new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) + return CollectiveFunctionRewriteVariable( + old_fn, + replacement_var=UserFunctionVariable(new_fn, source=new_source, **options), + source=source, + **options, + ) + + @staticmethod + def can_rewrite(variable): + return ( + inspect.isfunction(variable) and variable in _traceable_collective_remaps() + ) + + @staticmethod + def rewrite(tx, fn): + new_fn = _traceable_collective_remaps()[fn] + return new_fn, _traceable_collectives_source(tx, new_fn) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + # call_function must check any unsupported arguments and graph-break. + # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, + # since that's the contract for putting a mapping in `traceable_collective_remaps` + import torch.distributed as dist + from torch.distributed._functional_collectives import REDUCE_OP_TO_STR + + # Merge args into kwargs so positional and keyword args + # can be processed the same way. + signature = inspect.signature(self.fn) + kwargs = dict(signature.bind(*args, **kwargs).arguments) + args = () + + if "async_op" in kwargs and kwargs["async_op"].as_python_constant(): + unimplemented( + f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}" + ) + + if kwargs.get("group") is None or kwargs["group"].value is None: + kwargs["group"] = ProcessGroupVariable.get_global_pg_variable() + + if self.fn == dist.all_reduce: + reduce_op_var = kwargs.get("op") + reduce_op = ( + reduce_op_var.value + if reduce_op_var is not None + else signature.parameters["op"].default + ) + if reduce_op not in REDUCE_OP_TO_STR: + raise ValueError(f"Unsupported all_reduce op: {reduce_op}") + kwargs["op"] = variables.ConstantVariable.create( + REDUCE_OP_TO_STR[reduce_op] + ) + return self.replacement_var.call_function(tx, args, kwargs) + + +class FunctoolsPartialVariable(VariableTracker): + def __init__(self, func: VariableTracker, args, keywords, **kwargs): + super().__init__(**kwargs) + self.func = func + assert isinstance(args, list) + self.args = args + assert isinstance(keywords, dict) + self.keywords = keywords + + def reconstruct(self, codegen): + codegen.load_import_from("functools", "partial") + codegen(self.func) + if self.args: + codegen.foreach(self.args) + if not self.keywords: + codegen.extend_output(create_call_function(len(self.args) + 1, True)) + return + + codegen.foreach(self.keywords.values()) + keys = tuple(self.keywords.keys()) + codegen.extend_output( + codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, True) + ) + + def get_function(self): + return self.as_python_constant() + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + merged_args = self.args + args + merged_kwargs = {**self.keywords, **kwargs} + return self.func.call_function(tx, merged_args, merged_kwargs) + + def call_hasattr(self, tx, name: str) -> VariableTracker: + # functools.partial uses slots, so attributes are constant + return variables.ConstantVariable.create( + hasattr(functools.partial(identity), name) + ) + + def as_python_constant(self): + return 
functools.partial( + self.func.as_python_constant(), + *[arg.as_python_constant() for arg in self.args], + **{k: v.as_python_constant() for k, v in self.keywords.items()}, + ) + + def guard_as_python_constant(self): + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + return functools.partial( + self.func.guard_as_python_constant(), + *[v.guard_as_python_constant() for v in self.args], + **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, + ) + + +class TritonKernelVariable(VariableTracker): + def __init__(self, kernel, kernel_idx, grid, **kwargs): + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + super().__init__(**kwargs) + + assert kernel is not None + + self.kernel = kernel + self.kernel_idx = kernel_side_table.add_kernel(kernel) + + assert kernel_idx is None or self.kernel_idx == kernel_idx + + self.grid = grid + + if isinstance(kernel, Autotuner): + # We only support configs and keys arguments of triton.autotune + # Make sure other arguments are defaulted + defaults = inspect.signature(Autotuner.__init__).parameters + + # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep. + # The call to get_first_attr is to maintain backward-compatibility. + if ( + ( + "warmup" in defaults + and defaults["warmup"].default + != get_first_attr(kernel, "num_warmups", "warmup") + ) + or ( + "rep" in defaults + and defaults["rep"].default + != get_first_attr(kernel, "num_reps", "rep") + ) + or ( + "prune_configs_by" in defaults + and defaults["prune_configs_by"].default + != kernel.early_config_prune + ) + # Set via reset_to_zero argument + or len(kernel.reset_idx) != 0 + or len(kernel.restore_idx) != 0 + ): + raise Unsupported( + "Only configs and keys are supported for triton.autotune" + ) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from triton.runtime.autotuner import Autotuner + + from .constant import ConstantVariable + from .dicts import ConstDictVariable + from .lists import BaseListVariable + + if self.grid is None: + raise Unsupported("Triton kernels should always be called with a grid") + + # Both for grid's meta as well as for the kernel, we need combined + # args and kwargs normalized + names = ( + variables.ConstantVariable.create(name) for name in self.kernel.arg_names + ) + kwargs = {variables.ConstantVariable.create(k): v for k, v in kwargs.items()} + normalized_args = {**dict(zip(names, args)), **kwargs} + + configs = ( + [config.kwargs for config in self.kernel.configs] + if isinstance(self.kernel, Autotuner) + else [{}] + ) + grids = [] + for config_args in configs: + # If the grid is a function, then lets execute it and convert it to + # a list + grid = self.grid + if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)): + # Populate the special "meta" argument to call the grid function + config_args = { + ConstantVariable.create(k): ConstantVariable.create(v) + for k, v in config_args.items() + } + meta = ConstDictVariable({**normalized_args, **config_args}, dict) + grid = grid.call_function(tx, [meta], {}) + + # Now, the grid must be a list either originally or through above + # modification + if isinstance(grid, BaseListVariable): + grids.append(grid.as_proxy()) + else: + unimplemented(f"grid for the triton kernel is {type(grid)}") + + for i in range(len(grids)): + if not 
isinstance(grids[i], tuple): + raise Unsupported("Only tuple grids are supported") + # inductor expects all grids to be 3-tuple so lets make it + if len(grids[i]) == 1: + grids[i] = (grids[i][0], 1, 1) + elif len(grids[i]) == 2: + grids[i] = (grids[i][0], grids[i][1], 1) + elif len(grids[i]) > 3: + raise Unsupported("Grid can have at most rank 3") + + assert len(grids) != 0 + if len(set(grids)) == 1: + # If there's only one unique grid, lets simplify + grids = [grids[0]] + + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_mutation, + ) + + # Combine args and kwargs and pass as a dict so that if user defined triton + # kernel uses variables as 'grid' or 'kernel', it does not conflict with + # parameters of the wrapper function + meta = ConstDictVariable(normalized_args, dict) + tx.output.create_proxy( + "call_function", + triton_kernel_wrapper_mutation, + (), + { + "kernel_idx": self.kernel_idx, + "grid": grids, + "kwargs": meta.as_proxy(), + }, + ) + + return variables.ConstantVariable( + None, + ) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__": + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if self.grid is not None or len(args) != 1: + raise Unsupported( + "Triton kernels should be called with only a single grid" + ) + + return TritonKernelVariable( + kernel=self.kernel, + kernel_idx=self.kernel_idx, + grid=args[0], + ) + elif name == "run": + if "grid" not in kwargs: + raise Unsupported("Triton kernel requires to be called with a grid") + grid = kwargs.pop("grid") + kwargs.pop("warmup", None) + # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args) + return TritonKernelVariable( + kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid + ).call_function(tx, args, kwargs) + + # Bail out to parent's implementation + return super().call_method(tx, name, args, kwargs) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..010188362fff8be4e86b0643d74f4f1946dce2f0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py @@ -0,0 +1,1660 @@ +# mypy: ignore-errors + +import contextlib +import functools +import logging +import types + +from typing import Dict, List, Optional + +import torch._C +import torch.fx +import torch.nn +import torch.onnx.operators +from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value +from torch._dynamo.variables.base import VariableTracker +from torch._dynamo.variables.builtin import BuiltinVariable +from torch._dynamo.variables.functions import UserFunctionVariable +from torch._dynamo.variables.tensor import SymNodeVariable +from torch._guards import Source +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils import _pytree as pytree + +from ..exc import ( + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, + UserError, + UserErrorType, +) +from ..source import AttrSource, FSDPNNModuleSource, GetItemSource, NNModuleSource +from ..utils import proxy_args_kwargs +from .dicts import ConstDictVariable +from .lists import ListVariable, TupleVariable +from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable + + +log = logging.getLogger(__name__) + + +def 
raise_hard_error_if_graph_break(reason): + def deco(fn): + @functools.wraps(fn) + def graph_break_as_hard_error(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Unsupported as e: + msg = " Scroll up to find out what causes the graph break." + raise UncapturedHigherOrderOpError(reason + msg) from e + + return graph_break_as_hard_error + + return deco + + +@contextlib.contextmanager +def dynamo_enable_grad(tx, enable=True): + from . import GradModeVariable + + org_value = torch.is_grad_enabled() + try: + GradModeVariable.create(tx, enable, initialized=True) + yield + finally: + GradModeVariable.create(tx, org_value, initialized=True) + + +def only_consist_of(var, types, allow_none=False): + if isinstance(var, types): + return True + if allow_none and var.is_python_constant() and var.as_python_constant() is None: + return True + if isinstance(var, (TupleVariable, ListVariable)): + return all(only_consist_of(item, types, allow_none) for item in var.items) + if isinstance(var, ConstDictVariable): + return all( + only_consist_of(item, types, allow_none) for item in var.items.values() + ) + return False + + +# A more read-able syntax sugar for creating a UserFunctionVariable for f +# and run call_function on it. Make it return a function to preserve the calling +# convention of the original f. +def _make_inlined(tx, f): + assert callable(f), "Expect f to be a python callable." + + def inline_call(*args, **kwargs): + return UserFunctionVariable(f).call_function(tx, args, kwargs) + + return inline_call + + +def _call_function_and_unflatten_output(tx, fn, args, kwargs, ret_vt, ret_treespec): + from .builder import wrap_fx_proxy + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + ret_vt.as_proxy(), + ) + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {}) + return ( + _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec) + if ret_treespec + else flat_variable + ) + + +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = { + id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor) + } + output_tensor_ids = { + id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor) + } + assert input_tensor_ids.isdisjoint( + output_tensor_ids + ), "inputs to function body cannot alias outputs" + + +def validate_args_and_maybe_create_graph_inputs( + sub_args, + tracer, + tx, + set_subgraph_inputs, + description, +): + from . 
import AutogradFunctionContextVariable, ConstantVariable, EnumVariable + from .builder import wrap_fx_proxy_cls + + assert tracer.parent is not None + + if set_subgraph_inputs == "flatten_manual": + flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)( + ListVariable(sub_args) + ).unpack_var_sequence(tx) + + flat_inputs = validate_args_and_maybe_create_graph_inputs( + flat_args.unpack_var_sequence(tx), + tracer, + tx, + set_subgraph_inputs="manual", + description=description, + ) + + return _make_inlined(tx, pytree.tree_unflatten)( + ListVariable(flat_inputs), tree_spec + ).unpack_var_sequence(tx) + else: + args = [] + for a in sub_args: + assert isinstance(a, VariableTracker) + if set_subgraph_inputs == "automatic": + args.append(a) + continue + + if isinstance(a, (ConstantVariable, EnumVariable)): + # This arg is not used in the body of the higher order op. + # Currently, this new input is added to make the calls + # happy, which expect a fixed number of arguments. In + # future, we can clean this up. + tracer.create_graph_input("const") + new_arg = a + # Weird special case, we probably want to delete it or fold it + # into the next case (of `a` being placeable into a graph) + elif isinstance(a, AutogradFunctionContextVariable): + tracer.create_graph_input(a.as_proxy().node.name) + new_arg = a + # If `a` can be put into a graph + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + new_proxy = tracer.create_graph_input(node.name) + example_value = ( + node.meta["example_value"] if "example_value" in node.meta else None + ) + new_arg = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + # If `a` cannot be put into a graph + else: + # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). + raise unimplemented( + f"{description} with body that accepts non-Tensors as input. " + f"Got: {a.python_type()}" + ) + args.append(new_arg) + return args + + +# This helper function is used to make sure two graphs share the same input signature. For example, +# in torch.cond, two branches might lift different set of tensors as inputs. This function helps to +# dedup the inputs and modify the graphs to take the same set of inputs. +def _merge_graph_inputs( + l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name +): + def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars): + # The nn module attributes are guaranteed to be registered into the top-level graph module during + # higher order op speculation. Therefore, get_attr nodes in two branches with the same + # target refer to the same attribute and we can safely deduplicate them with their target. + # + # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But + # true_branch and false_branch belong to two separate tracing contexts, they may register the same + # attribute to top level seperately. This creates two get_attr proxies for the same attribute + # that have different meta data such as stack_trace (one stack trace for the true_branch, + # and the other for false_branch). It seems better to discard the proxy explicitly in cond + # than make dynamo create a single proxy for the same get_attr target. 
+ def shared_getattrs(l_lifted_proxies, r_lifted_proxies): + true_targets = { + proxy.node.target: proxy + for proxy in l_lifted_proxies + if proxy.node.op == "get_attr" + } + l_shared_getattrs = {} + r_shared_getattrs = {} + + for false_proxy in r_lifted_proxies: + if ( + false_proxy.node.op == "get_attr" + and false_proxy.node.target in true_targets + ): + true_proxy = true_targets[false_proxy.node.target] + l_shared_getattrs[true_proxy] = true_proxy + r_shared_getattrs[false_proxy] = true_proxy + return l_shared_getattrs, r_shared_getattrs + + l_shared_getattrs, r_shared_getattrs = shared_getattrs( + l_lifted_freevars.keys(), r_lifted_freevars.keys() + ) + + l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + l_shared_getattrs.keys() + ) + r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + r_shared_getattrs.keys() + ) + unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars + unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars + + def _sort_by_name(vars): + return sorted(vars, key=lambda var: var.node.name) + + return ( + list(_sort_by_name(list(l_shared_freevars))), + list(_sort_by_name(list(r_shared_freevars))), + list(_sort_by_name(list(unique_l_freevars))), + list(_sort_by_name(list(unique_r_freevars))), + ) + + (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars( + l_lifted_freevars, r_lifted_freevars + ) + + # Let's say we capture cond(pred, true_fn, false_fn, (x,)) + # With set_graph_input set to automatic, + # true_fn has lifted variables x, a, b, c + # false_fn has lifted variables x, a, b, d + # Then fixup_branch_inps make sure both branches have the same signature, i.e.: + # - true_fn(x, a, b, c_true_branch, d_false_branch) + # - false_fn(x, a, b, c_true_branch, d_false_branch) + # + # More formally, the signature has three parts in the following order: + # 1. used in both branches: x, a, b + # 2. only used in true branches: c, suffixed with _true_branch + # 3. only used in false branches: d, suffixed with _false_branch + # Within each part, we re-order the nodes by name to have a derterministic ordering for testing. + def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): + def _insert_or_replace_phs(new_args, name_suffix): + for arg in new_args: + new_ph = graph.placeholder(arg.node.name + name_suffix) + # Override with new_ph if there exists a old placeholder. + if arg in lifted_freevars: + old_ph = lifted_freevars[arg].node + old_ph.replace_all_uses_with(new_ph) + # replace_all_uses_with doesn't clean users. Clean it mannually so that we could erase it. + old_ph.users = {} + graph.erase_node(old_ph) + + first_not_ph_node = next( + node for node in graph.nodes if node.op != "placeholder" + ) + with graph.inserting_before(first_not_ph_node): + _insert_or_replace_phs(shared, "") + _insert_or_replace_phs(unique_l, "_" + l_name) + _insert_or_replace_phs(unique_r, "_" + r_name) + + fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r) + fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r) + return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r + + +# See NOTE [HigherOrderOperator tracing design] for details of the design +def speculate_subgraph( + tx, + f, + sub_args, + sub_kwargs, + description, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. 
+ source_target=None, + always_restore=False, + enable_grad=None, + # NOTE [argument `set_subgraph_inputs`] + # set_subgraph_inputs controls how to construct subgraphs' placeholders from sub_args. + # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended). + # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended). + # If sub_args contain Pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first. + # Then the flattened args are manually set as subgraph's placeholders. + # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders e.g. AutogradFunctionContextVariable + # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the + # restriction that users need to manually control how to create placeholders and VariableTrackers for the args. + set_subgraph_inputs="automatic", + restore_side_effects=True, + should_flatten_outputs=False, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer=None, +): + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.") + + try: + f, sub_args, sub_kwargs = VariableTracker.apply( + # ensure guards on args get installed in parent subgraph + lambda x: x.realize(), + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer) as subtracer: + args = validate_args_and_maybe_create_graph_inputs( + sub_args, subtracer, tx, set_subgraph_inputs, description + ) + + validate_args_and_maybe_create_graph_inputs( + sub_kwargs.values(), + subtracer, + tx, + set_subgraph_inputs="automatic", + description=description, + ) + + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + + # For handling side effects, we can make an argument that we don't + # have to do anything here. The side effects infra does a good job + # of graph breaking if we mutate any nonlocal or global variable + # while subtracing. As a result if tracing succeeds, side effects + # data structure will only contain read-only data structures that + # are put there for tracking purposes. + # But on the other hand, there is an argument that if we ever write + # a new side effect in Dynamo which does not go through the side + # effect infra, we can end up in a bad state. + # Therefore we restore the side effects after tracing. The catch is + # that we have to specially handle tensor variables. If we have seen a + # nonlocal variable tensor during subtracing, we want to keep + # track of that tensor, so that later subtracing or the root tracer + # itself does not create a new proxy for the already observed tensor + # variable.
+ if restore_side_effects: + prev_side_effects = tx.output.side_effects.clone() + + with autograd_ctx: + output = f.call_function(tx, args, sub_kwargs) + + if restore_side_effects: + new_side_effects = tx.output.side_effects.clone() + prev_side_effects.track_tensor_variables_from_runahead_side_effects( + new_side_effects + ) + tx.output.side_effects = prev_side_effects + + treespec = None + if should_flatten_outputs: + # Flatten the speculated subgraph output. + output, treespec = _make_inlined(tx, pytree.tree_flatten)( + output + ).unpack_var_sequence(tx) + # Actually, transform the list (returned by flatten) into a tuple + # for dynamo consistency. + output = BuiltinVariable(tuple).call_function(tx, [output], {}) + + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support pytree output + # We check always_restore because we dont use the output or side effects of always_restore code, + # like bwd. + if always_restore: + # Nothing left to do here + return (output, treespec), tx.output.graph, subtracer.lifted_freevars + else: + from . import TensorVariable + + if not only_consist_of(output, TensorVariable, allow_none=True): + unimplemented( + "HigherOrderOperator body's output must consist of tensors only" + ) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + return ( + (output, treespec), + graph, + lifted_freevars, + ) + + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." 
+ ) + log.info(msg) + log.info(ex) + raise ex + + +def make_attr(tx, name): + node = tx.output.create_proxy( + "get_attr", + name, + (), + {}, + ) + return node + + +def add_subgraph(tx, source, name, gm): + next_name = None + i = 0 + while not next_name: + candidate = f"{name}_{i}" + if candidate in tx.output.nn_modules: + i += 1 + else: + next_name = candidate + + gm.__name__ = next_name + if source.guard_source().is_fsdp_module(): + src = FSDPNNModuleSource(GetItemSource(source, next_name)) + else: + src = NNModuleSource(GetItemSource(source, next_name)) + gm.torchdynamo_force_dynamic = False + tx.output.register_attr_or_module(gm, next_name, source=src) + return next_name + + +class TorchHigherOrderOperatorVariable(VariableTracker): + def __init__(self, value, source: Optional[Source] = None, **kwargs): + super().__init__(**kwargs) + self.value = value + self.source = source + + @staticmethod + def make(value, source=None, **kwargs): + if value.__name__ == "cond": + return CondHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "while_loop": + return WhileLoopHigherOrderVariable(value, source, **kwargs) + elif value.__name__ in ("map", "map_impl"): + return MapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "executorch_call_delegate": + return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "out_dtype": + return OutDtypeHigherOrderVariable(value, source, **kwargs) + elif value is torch._functorch.eager_transforms.grad_impl: + return FunctorchGradHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap": + return WrapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ in ( + "wrap_activation_checkpoint", + "tag_activation_checkpoint", + ): + return CheckpointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "_export_tracepoint": + return ExportTracepointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "trace_wrapped": + return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs) + elif value.__name__ == "strict_mode": + return StrictModeHigherOrderVariable(value, source, **kwargs) + else: + unimplemented(f"HigherOrderOperator {value.__name__}") + + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + unimplemented(f"HigherOrderOperator {self.value.__name__}") + + +class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="Cond doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . 
import ( + ConstantVariable, + ListVariable, + NestedUserFunctionVariable, + TensorVariable, + UserFunctionVariable, + ) + + args, kwargs = VariableTracker.apply(lambda x: x.realize(), (args, kwargs)) + + for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len( + args + ), "did not provide the right number of non-keyword args" + args.append(v) + + if kwargs: + unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}") + + # TODO(voz): Support fake tensor dispatch for recursive + # ops - see torch/dispatch/_dispatcher.py + if len(args) != 4: + unimplemented( + f"Expected 4 arguments but got {len(args)}.\n" + f"Usage: cond(pred, true_fn, false_fn, operands)", + ) + # predicate + if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable): + unimplemented( + f"Expected pred to be bool or a boolean tensor with single " + f"item but got {str(type(args[0]))} " + f"with original python type {str(args[0].python_type())}.", + ) + + # operands + if not isinstance(args[3], (ListVariable, TupleVariable)): + unimplemented( + f"Expected a tuple but got {args[3].python_type()}", + ) + operands = args[3].unpack_var_sequence(tx) + if not only_consist_of(args[3], (TensorVariable,)): + unimplemented( + "Expect operands to be a tuple of pytrees that only consists of tensor leaves." + ) + + # branches + assert isinstance( + args[1], + ( + UserFunctionVariable, + NestedUserFunctionVariable, + NNModuleVariable, + UnspecializedNNModuleVariable, + ), + ), str( + type(args[1]) + ) # true_fn + + assert isinstance( + args[2], + ( + UserFunctionVariable, + NestedUserFunctionVariable, + NNModuleVariable, + UnspecializedNNModuleVariable, + ), + ), str( + type(args[2]) + ) # false_fn + + # Our strategy for tracing the true/false branches of cond + # are to checkpoint our graphstate, run the true branch, + # roll it back to the checkpoint, and run the false + # branch, and then merge the graphstates. Well, perhaps + # "merge" is too strong a word: we mostly assert that + # the resulting graphstates have to be the same. + # + # We only permit guards to diverge (we union the guards from + # both branches). In particular, this means that side + # effects are NOT permitted inside true/false branches; this + # would be difficult to implement, because of the path + # explosion problem. 
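+ # + # Illustrative call shape that can be captured here (matching the argument checks above): + # torch.cond(pred, true_fn, false_fn, (x,)) + # Each branch is traced below into its own subgraph; both results must share the + # same pytree structure and matching tensor metadata, otherwise we error out, since + # cond must be captured completely.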
+ + def speculate_branch(branch): + # NB: 0 is predicate + ix = 1 if branch else 2 + # TODO: Support kwargs + ( + (ret_val, ret_treespec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[ix], + operands, + {}, + "cond", + source_target=self.value, + should_flatten_outputs=True, + ) + + if not only_consist_of(ret_val, (TensorVariable,)): + unimplemented( + "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.", + ) + return ret_val, ret_treespec, ret_graph, ret_lifted_freevars + + (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch( + True + ) + true_nn_modules = dict(tx.output.nn_modules) + + ( + false_r, + false_treespec, + false_graph, + false_lifted_freevars, + ) = speculate_branch(False) + false_nn_modules = dict(tx.output.nn_modules) + + same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)( + true_treespec, false_treespec + ) + if not same_treespec.as_python_constant(): + unimplemented("Expected branches to return the same pytree structure.") + + def diff_meta(tensor_vars1, tensor_vars2): + assert all( + isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2 + ) + all_diffs = [] + for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)): + # We check the meta data associated with meta["example_value"] + meta1 = _extract_tensor_metadata( + var1.proxy.node.meta["example_value"], include_contiguity=False + ) + meta2 = _extract_tensor_metadata( + var2.proxy.node.meta["example_value"], include_contiguity=False + ) + if meta1 != meta2: + all_diffs.append((f"pair{i}:", meta1, meta2)) + return all_diffs + + if diffs := diff_meta( + true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx) + ): + unimplemented( + f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}" + ) + + ( + true_graph, + false_graph, + true_shared, + false_shared, + unique_true, + unique_false, + ) = _merge_graph_inputs( + true_graph, + true_lifted_freevars, + "true_branch", + false_graph, + false_lifted_freevars, + "false_branch", + ) + + true_name = add_subgraph( + tx, + self.source, + "cond_true", + torch.fx.GraphModule(true_nn_modules, true_graph), + ) + false_name = add_subgraph( + tx, + self.source, + "cond_false", + torch.fx.GraphModule(false_nn_modules, false_graph), + ) + + true_node = make_attr(tx, true_name) + false_node = make_attr(tx, false_name) + + p_args = ( + args[0].as_proxy(), + true_node, + false_node, + # We pick true_shared but it shouldn't matter + true_shared + unique_true + unique_false, + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.cond, p_args, {}, true_r, true_treespec + ) + + +class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="while_loop doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + from . 
import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable + + args, kwargs = VariableTracker.apply(lambda x: x.realize(), (args, kwargs)) + + for i, k in enumerate(["cond_fn", "body_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len( + args + ), "did not provide the right number of non-keyword args" + args.append(v) + + if kwargs: + unimplemented( + f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + if len(args) != 3: + unimplemented( + f"Expected 3 arguments but got {len(args)}.\n" + f"Usage: while_loop(cond_fn, body_fn, operands)", + ) + + def _check_supported_callable(fn_var): + assert isinstance( + fn_var, + ( + UserFunctionVariable, + NestedUserFunctionVariable, + NNModuleVariable, + UnspecializedNNModuleVariable, + ), + ), str(type(fn_var)) + + _check_supported_callable(args[0]) + _check_supported_callable(args[1]) + + # operands + if not isinstance(args[2], (ListVariable, TupleVariable)): + unimplemented( + f"Expected a tuple but got {args[2].python_type()}", + ) + + operands = args[2].unpack_var_sequence(tx) + if not only_consist_of(args[2], (TensorVariable,)): + unimplemented( + "Expect operands to be a tuple of pytrees that only consists of tensor leaves." + ) + + ( + (cond_r, cond_treespec), + cond_graph, + cond_lifted_freevars, + ) = speculate_subgraph( + tx, args[0], operands, {}, "while_loop", source_target=self.value + ) + cond_nn_modules = dict(tx.output.nn_modules) + if not isinstance(cond_r, TensorVariable): + unimplemented( + f"Expected cond_fn to return a tensor but got {cond_r.python_type()}", + ) + + cond_r_meta = _extract_tensor_metadata( + cond_r.proxy.node.meta["example_value"], include_contiguity=False + ) + if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size( + [] + ): + unimplemented( + f"Expected cond_fn to return a tensor with shape (,) but got {cond_r_meta.shape}" + ) + + ( + (body_r, body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[1], + operands, + {}, + "while_loop", + source_target=self.value, + should_flatten_outputs=True, + ) + body_nn_modules = dict(tx.output.nn_modules) + + ( + cond_graph, + body_graph, + cond_shared, + body_shared, + cond_unique, + body_unique, + ) = _merge_graph_inputs( + cond_graph, + cond_lifted_freevars, + "cond_fn", + body_graph, + body_lifted_freevars, + "body_fn", + ) + # We pick cond_shared but it shouldn't matter + merged_input = tuple(cond_shared + cond_unique + body_unique) + + cond_name = add_subgraph( + tx, + self.source, + "cond_fn", + torch.fx.GraphModule(cond_nn_modules, cond_graph), + ) + body_name = add_subgraph( + tx, + self.source, + "body_fn", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + cond_node = make_attr(tx, cond_name) + body_node = make_attr(tx, body_name) + + p_args = ( + cond_node, + body_node, + merged_input, + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.while_loop, p_args, {}, body_r, body_treespec + ) + + +def non_single_tensor_return_unsupported(api, ret): + from . import TensorVariable + + if not isinstance(ret, TensorVariable): + raise Unsupported( + f"{api} over function that returns something " f"other than one Tensor" + ) + + +class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + from . 
import NestedUserFunctionVariable, TensorVariable, UserFunctionVariable + from .builder import wrap_fx_proxy_cls + + if len(kwargs) > 0: + unimplemented( + "torch.ops.higher_order.map: kwargs are not supported in the map operator." + ) + + assert type(args[0].realize()) in ( + UserFunctionVariable, + NestedUserFunctionVariable, + ) + assert type(args[1].realize()) is TensorVariable + + sample_shape = get_fake_value(args[1].as_proxy().node, tx).size() + + if len(sample_shape) < 1 or sample_shape[0] == 0: + unimplemented( + "map() operator doesn't support scalar or zero-sized tensors during tracing." + ) + + # To get the example output from map() we will need to provide at least one sample to + # the loop body. In our case we will always use xs[0], and our map() won't support zero + # sized tensor during tracing. + first_dim = wrap_fx_proxy_cls( + target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0] + ) + + # TODO: Support kwargs + ( + (body_r, body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + [ + first_dim, + *args[2:], + ], + {}, + "torch.ops.higher_order.map", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + ) + + body_nn_modules = dict(tx.output.nn_modules) + + body_name = add_subgraph( + tx, + self.source, + "map_body", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + p_args = ( + body_node, + [args[1].as_proxy()], + [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()), + ) + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.map_impl, p_args, {}, body_r, body_spec + ) + + +class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + # This is operator for delegation within Executorch which calls a + # specific function in the given lowered module with the given + # operators. The actual operator is defined in the Executorch codebase. + # This is a bad hierarchical violation since + # executorch_call_delegate sits at a higher level than dynamo, but + # there's no real solution to this issue yet. + if len(kwargs) > 0: + unimplemented( + "executorch_call_delegate: kwargs arguments were not enabled." + ) + lowered_module = tx.output.get_submodule(args[0].module_key) + + lowered_node = make_attr(tx, args[0].module_key) + + p_args = tuple(arg.as_proxy() for arg in args[1:]) + real_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args + ) + + example_res = lowered_module.original_module.module()(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. 
+ _assert_tensors_nonaliasing(real_sub_args, example_res) + + example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) + + p_args = (lowered_node,) + p_args + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import ConstantVariable + from .builder import wrap_fx_proxy + + # TODO: Support `fn` with kwargs. + if not torch._dynamo.config.capture_func_transforms: + unimplemented( + "torch.func.grad capture is disabled, " + "it can be turned on by setting " + "`torch._dynamo.config.capture_func_transforms=True`" + ) + # [NOTE] Here we are (roughly) modelling the following + # + # grad_fn = torch.func.grad(fn, argnums=.., has_aux=..) + # grad_output = grad_fn(x) + grad_args = (args[0], args[1], args[2]) + + # get arguments + func, argnums, has_aux = grad_args + kwargs = args[4].items + if len(kwargs) > 0: + # Since speculate_subgraph doesn't support kwargs, we can't handle this for now. + unimplemented( + "torch.func.grad: kwargs arguments are currently unsupported." + ) + + # Trace through the `func` + # NOTE [HACK: Enable autograd while tracing function] + # `torch.func.grad` should not be affected by `no_grad` outside of `grad`. + # So, we enable_grad right before the function to which `grad` is applied + # (the parts explicitly disabled with `no_grad` inside the function are still disabled). + # Eg. + # def f(x): + # with no_grad(): # This will disable grad tracking under it. + # y = x * 2 + # + # return x ** 2 - y # grad tracking should be enabled irrespective of outside `no_grad`. + # + # with no_grad(): # This will not disable grad tracking inside of grad(f). + # grad_o = torch.func.grad(f)(x) + # TODO: Support kwargs + (body_r, _), body_graph, body_lifted_freevars = speculate_subgraph( + tx, + func, + args[3].items, + {}, + "torch.func.grad", + source_target=self.value, + # See NOTE [HACK: Enable autograd while tracing function] + enable_grad=True, + set_subgraph_inputs="manual", + ) + + body_name = add_subgraph( + tx, + self.source, + "grad_body", + torch.fx.GraphModule(tx.output.nn_modules, body_graph), + ) + body_node = make_attr(tx, body_name) + grad_proxy_args = ( + body_node, + *(arg.as_proxy() for arg in grad_args[1:]), + ) + + # Model `grad_fn = grad(fn, *grad_args, **grad_kwargs)` + grad_fn = tx.output.create_proxy( + "call_function", + torch.func.grad, + args=tuple(grad_proxy_args), + kwargs={}, + name="grad_proxy", + ) + + # Pass lifted freevars to the call to `grad_fn` + args = args[3].items + grad_fn_args = tuple(arg.as_proxy() for arg in args) + tuple( + body_lifted_freevars + ) + + # Call grad_fn with inputs. + # grad_output = grad_fn(*grad_fn_args, **grad_fn_kwargs) + grad_output = grad_fn(*grad_fn_args) + + # `grad_fn(*grad_fn_args, **grad_fn_kwargs)` + # Output of grad_fn is + # For has_aux=False, Tuple[gradients of inputs indicated by argnums]. + # For has_aux=True, Tuple[Tuple[gradients of inputs indicated by argnums], aux values] + # NOTE: example_value should match `grad_output`. 
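+ # Illustration of the torch.func semantics modeled here: grad(f, argnums=(0, 1), has_aux=True)(x, y) + # evaluates to ((grad_x, grad_y), aux), while an int argnums yields a single gradient tensor; + # example_value below is built to mirror that nesting.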
+ def _from_args(idx): + return args[idx].as_proxy().node.meta["example_value"].contiguous() + + def to_python_ints(argnums): + if not isinstance(argnums, (ConstantVariable, TupleVariable)): + raise UserError( + UserErrorType.INVALID_INPUT, + f"argnums is expected to be int or tuple of ints. Got {argnums}.", + ) + + if isinstance(argnums, ConstantVariable): + if not isinstance(argnums.value, (int, tuple)): + raise UserError( + UserErrorType.INVALID_INPUT, + f"argnums is expected to be int or tuple of ints. Got {argnums}.", + ) + return argnums.value + else: + const_vars = argnums.unpack_var_sequence(tx) + if not all( + isinstance(var, ConstantVariable) and isinstance(var.value, int) + for var in const_vars + ): + raise UserError( + UserErrorType.INVALID_INPUT, + f"argnums is expected to contain int only. Got {const_vars}.", + ) + return tuple(var.value for var in const_vars) + + argnums_v = to_python_ints(argnums) + example_value = pytree.tree_map(_from_args, argnums_v) + + if has_aux.value: + # case : has_aux = True + # NOTE: Currently speculate subgraph allows body_r to be + # Tensor or Tuple/List of Tensor. + # Since `grad` expects output with has_aux + # to be (output, aux), only valid output currently is + # (output, some_tensor) + body_r_proxy = body_r.as_proxy() + aux = body_r_proxy[1].node.meta["example_value"] + example_value = (example_value, aux) + + fx_proxy = wrap_fx_proxy(tx=tx, proxy=grad_output, example_value=example_value) + + # Call contiguous on all the computed grads. + if not has_aux.value: + if isinstance(argnums_v, int): + return fx_proxy.call_method(tx, "contiguous", (), {}) + else: + grads = fx_proxy + items = [] + for idx in range(len(argnums_v)): + proxy = grads.call_method( + tx, "__getitem__", (ConstantVariable.create(idx),), {} + ).call_method(tx, "contiguous", (), {}) + items.append(proxy) + return TupleVariable(items) + else: # case: has_aux.value = True + # fx_proxy -> Tuple(grads, aux) + grads = fx_proxy.call_method( + tx, "__getitem__", (ConstantVariable.create(0),), {} + ) + aux = fx_proxy.call_method( + tx, "__getitem__", (ConstantVariable.create(1),), {} + ) + if isinstance(argnums_v, int): + return TupleVariable([grads.call_method(tx, "contiguous", (), {}), aux]) + else: + items = [] + for idx in range(len(argnums_v)): + proxy = grads.call_method( + tx, "__getitem__", (ConstantVariable.create(idx),), {} + ).call_method(tx, "contiguous", (), {}) + items.append(proxy) + return TupleVariable([TupleVariable(items), aux]) + + +class FunctorchHigherOrderVariable(UserFunctionVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if not torch._dynamo.config.capture_func_transforms: + name = self.get_name() + assert name in ("grad_impl", "vmap_impl") + fn = name.split("_")[0] + unimplemented( + f"torch.func.{fn} capture is disabled, " + "it can be turned on by setting " + "`torch._dynamo.config.capture_func_transforms=True`" + ) + return super().call_function(tx, args, kwargs) + + +class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): + def create_wrapped_node(self, tx, args, kwargs, description): + # See NOTE [HigherOrderOperator tracing design] for more details + + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], # function + [*args[1:]], + kwargs, + description, + source_target=self.value, + should_flatten_outputs=True, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = 
add_subgraph( + tx, + self.source, + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. + lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + + proxy_args = (body_node,) + lifted_args + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return proxy_args, {}, example_value, body_r, treespec, body_gmod + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node( + tx, args, kwargs, "wrap" + ) + + if len(p_kwargs) > 0: + unimplemented("kwargs should have been flattened into lifted args") + + return _call_function_and_unflatten_output( + tx, self.value, tuple(p_args), p_kwargs, body_r, treespec + ) + + +class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + if len(kwargs) > 0: + unimplemented("out_dtype does not handle kwargs") + + p_args = tuple(arg.as_proxy() for arg in args) + op = p_args[0] + output_dtype = p_args[1] + fake_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:] + ) + # This is a simplified implementation of this operator just for tracing. + # Actual implementation may also first promote the arguments + example_value = op(*fake_sub_args).to(dtype=output_dtype) + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." 
+ ) + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + callable = args[0] + + unpacked_sequence = args[1].unpack_var_sequence(tx) + # TODO (tmanlaibaatar) support pytree here + for arg in unpacked_sequence: + if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): + unimplemented("strict_mode HOO only works for flat inputs for now") + + if kwargs: + unimplemented( + f"strict_mode HOO received unexpected kwargs: {list(kwargs.keys())}" + ) + + ( + (ret_val, ret_treespec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + unpacked_sequence, + {}, + "strict_mode", + source_target=self.value, + should_flatten_outputs=True, + ) + + strict_mode_nn_modules = dict(tx.output.nn_modules) + + strict_mode_name = add_subgraph( + tx, + self.source, + "strict_mode_body", + torch.fx.GraphModule(strict_mode_nn_modules, ret_graph), + ) + + strict_mode_node = make_attr(tx, strict_mode_name) + p_args = ( + strict_mode_node, + tuple(arg for arg in ret_lifted_freevars.keys()), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + ret_val.as_proxy(), + ) + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.ops.higher_order.strict_mode, + args=tuple(p_args), + kwargs={}, + ), + example_value=flat_example_value, + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.strict_mode, p_args, {}, ret_val, ret_treespec + ) + + +class CheckpointHigherOrderVariable(WrapHigherOrderVariable): + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + from torch._higher_order_ops.wrap import TagActivationCheckpoint + from torch.utils.checkpoint import noop_context_fn + from .builder import wrap_fx_proxy + + context_fn = None + if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn: + ctx = kwargs.pop("context_fn") + if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable): + context_fn = ctx.fn + elif isinstance( + ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + context_fn = ctx.as_python_constant() + else: + raise NotImplementedError( + f"checkpoint not implemented for {type(ctx)} context_fn" + ) + + checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs) + + # Here we use checkpoint_kwargs (and not gmod kwargs). gmod_kwargs are + # already flattened above and managed inside the fx graph. + ( + p_args, + _, + example_value, + body_r, + treespec, + checkpointed_gmod, + ) = self.create_wrapped_node( + tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint" + ) + if context_fn is not None: + checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn + + _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs) + + # Store the invocation as a call + variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs=checkpoint_kwargs, + ), + example_value=example_value, + ) + + if treespec is None: + return variable + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. 
+ variable = BuiltinVariable(list).call_function(tx, [variable], {}) + + return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec) + + +class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace + by unwrapping the higher order op and inlining through it. This op + is created by dynamo to survive through AotAutograd, then unwrapped + here in the call to dynamo from compiled autograd. + """ + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + kwargs = dict(kwargs) + fn = kwargs.pop("fn") + return fn.call_function(tx, args, kwargs) + + +class AutogradFunctionApplyVariable(VariableTracker): + def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs): + super().__init__(**kwargs) + self.fwd_graph = fwd_graph + self.bwd_graph = bwd_graph + self.parent_source = parent_source + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import ( + AutogradFunctionContextVariable, + UserDefinedClassVariable, + UserFunctionVariable, + UserMethodVariable, + ) + from .builder import wrap_fx_proxy + + """ + Consider the following: + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + @staticmethod + def backward(ctx, grad): + x, = ctx.saved_tensors + return grad * x.cos() + We want the resulting graphs to look like: + def fwd(ctx, x): + # (output, saved tensors / attrs) + return (x.sin(), [x]) + # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs) + def bwd(ctx, grad, x): + return grad * x.cos() + To accomplish this, we're going to: + 1. Construct a ctx object + 2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True) + 3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting + the ctx and grad inputs. + 4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph) + Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is + just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward + doesn't capture any arguments. + All these steps work if MySin.backward doesn't capture any values. This is a + limitation in general that we should check for. 
+ """ + + prev_side_effects = tx.output.side_effects.clone() + fwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=tx.output.current_tracer, + source_target="autograd.Function", + ) + + fwd_src = AttrSource(self.parent_source, member="forward") + ctx = AutogradFunctionContextVariable.create(tx) + if isinstance(self.fwd_graph, types.FunctionType): + fwd_fn = UserFunctionVariable(self.fwd_graph, source=fwd_src) + fwd_args = [ctx, *args] + elif isinstance(self.fwd_graph, types.MethodType): + fwd_fn = UserMethodVariable( + self.fwd_graph.__func__, + UserDefinedClassVariable(self.fwd_graph.__class__), + source=fwd_src, + ) + fwd_args = [fwd_fn.obj, ctx, *args] + else: + unimplemented("non-function or method") + + # Speculate subgraph on the fwd + (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph( + tx, + fwd_fn, + fwd_args, + kwargs, + "autograd.Function", + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=fwd_tracer, + ) + + if fwd_freevars: + unimplemented("NYI") + + if ctx.mutable_local in tx.output.side_effects.store_attr_mutations: + if ( + "_materialize_non_diff_grads" + in tx.output.side_effects.store_attr_mutations[ctx.mutable_local] + ): + unimplemented("NYI") + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + + # Speculate subgraph on the backward. We make the + # bwd tracer a child of the fwd tracer, because backward may rely on + # tensors/attrs created in the fwd tracer. + + from .lists import BaseListVariable + + if isinstance(fwd_out, BaseListVariable): + bwd_args = [ctx, *fwd_out.items] + else: + bwd_args = [ctx, fwd_out] + + bwd_src = AttrSource(self.parent_source, member="backward") + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + self.bwd_graph.__func__, + UserDefinedClassVariable(self.bwd_graph.__class__), + source=bwd_src, + ) + bwd_args = [bwd_fn.obj, *bwd_args] + else: + unimplemented("non-function or method") + + with tx.output.subtracer(fwd_fn, fwd_tracer), tx.strict_translation_mode(): + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + + # TODO: assert that bwd_graph didn't capture values that were + # not created inside fwd_graph. 
+ + # TODO(oulgen): Ideally, we would not do a linear search for output + # node but as things currently are there could be nodes after the + # output node + # This is bug prone as if there's code after the output node, then + # graph.output will append the output at the very end + # This might be a behavior difference + + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) + for node in fwd_graph.nodes: + if node.op == "output": + fwd_graph.erase_node(node) + break + + new_fwd_graph_outputs = (fwd_out.as_proxy(), list(bwd_freevars.keys())) + new_fwd_graph_outputs = pytree.tree_map(lambda x: x.node, new_fwd_graph_outputs) + fwd_graph.output(new_fwd_graph_outputs) + + # Store fwd_body + fwd_nn_modules = tx.copy_graphstate().output.nn_modules + fwd_name = add_subgraph( + tx, + fwd_src, + "fwd_body", + torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph), + ) + + fwd_node = make_attr(tx, fwd_name) + + # Store bwd_body + bwd_nn_modules = tx.copy_graphstate().output.nn_modules + bwd_name = add_subgraph( + tx, + bwd_src, + "bwd_body", + torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph), + ) + + bwd_node = make_attr(tx, bwd_name) + + tx.output.side_effects = prev_side_effects + + p_args = (fwd_node, bwd_node, *(arg.as_proxy() for arg in args)) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + fwd_out.as_proxy(), + ) + + # Store the invocation as a call + from torch._functorch.autograd_function import autograd_function_apply + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + autograd_function_apply, + args=p_args, + kwargs={}, + ), + example_value=example_value, + ) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/iter.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/iter.py new file mode 100644 index 0000000000000000000000000000000000000000..39968559a9caddb04ad42e221cf09a17f3b22e64 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/iter.py @@ -0,0 +1,260 @@ +# mypy: ignore-errors + +MAX_CYCLE = 3000 + +import itertools +import operator + +from typing import Dict, List, Optional + +from .. 
import polyfill, variables +from ..exc import unimplemented + +from .base import MutableLocal, VariableTracker +from .constant import ConstantVariable + + +class ItertoolsVariable(VariableTracker): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def __repr__(self): + return f"ItertoolsVariable({self.value})" + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if ( + self.value is itertools.product + and not kwargs + and all(arg.has_unpack_var_sequence(tx) for arg in args) + ): + seqs = [arg.unpack_var_sequence(tx) for arg in args] + items = [] + for item in itertools.product(*seqs): + items.append(variables.TupleVariable(list(item))) + return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif ( + self.value is itertools.chain + and not kwargs + and all(arg.has_unpack_var_sequence(tx) for arg in args) + ): + seqs = [arg.unpack_var_sequence(tx) for arg in args] + items = list(itertools.chain.from_iterable(seqs)) + return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif self.value is itertools.accumulate: + from .builtin import BuiltinVariable + + if any(key not in ["initial", "func"] for key in kwargs.keys()): + unimplemented( + "Unsupported kwargs for itertools.accumulate: " + f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}" + ) + + acc = kwargs.get("initial") + + if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + + if "func" in kwargs and len(args) == 1: + func = kwargs["func"].call_function + elif len(args) == 2: + func = args[1].call_function + elif len(args) == 1: + # Default to operator.add + func = BuiltinVariable(operator.add).call_function + else: + unimplemented( + "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg" + ) + else: + unimplemented("Unsupported arguments for itertools.accumulate") + + items = [] + if acc is not None: + items.append(acc) + for item in seq: + if acc is None: + acc = item + else: + try: + acc = func(tx, [acc, item], {}) + except Exception: + raise unimplemented( # noqa: TRY200 + f"Unexpected failure in invoking function during accumulate. 
Failed running func {func}({item}{acc})" + ) + items.append(acc) + + return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif ( + self.value is itertools.combinations + and not kwargs + and len(args) == 2 + and args[0].has_unpack_var_sequence(tx) + and args[1].is_python_constant() + ): + iterable = args[0].unpack_var_sequence(tx) + r = args[1].as_python_constant() + + items = [] + for item in itertools.combinations(iterable, r): + items.append(variables.TupleVariable(list(item))) + return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) + elif self.value is itertools.groupby: + if any(kw != "key" for kw in kwargs.keys()): + unimplemented( + "Unsupported kwargs for itertools.groupby: " + f"{','.join(set(kwargs.keys()) - {'key'})}" + ) + + def retrieve_const_key(key): + if isinstance(key, variables.SymNodeVariable): + return key.evaluate_expr() + elif isinstance(key, variables.ConstantVariable): + return key.as_python_constant() + else: + raise unimplemented( + "Unsupported key type for itertools.groupby: " + str(type(key)) + ) + + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + keyfunc = ( + ( + lambda x: ( + retrieve_const_key( + kwargs.get("key").call_function(tx, [x], {}) + ) + ) + ) + if "key" in kwargs + else None + ) + else: + unimplemented("Unsupported arguments for itertools.groupby") + + result = [] + try: + for k, v in itertools.groupby(seq, key=keyfunc): + result.append( + variables.TupleVariable( + [ + variables.ConstantVariable.create(k) + if variables.ConstantVariable.is_literal(k) + else k, + variables.ListIteratorVariable( + list(v), mutable_local=MutableLocal() + ), + ], + mutable_local=MutableLocal(), + ) + ) + except Exception: + raise unimplemented( # noqa: TRY200 + "Unexpected failure when calling itertools.groupby" + ) + return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) + elif self.value is itertools.repeat: + if len(args) < 2: + return variables.RepeatIteratorVariable( + *args, mutable_local=MutableLocal() + ) + + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder()(tx, polyfill.repeat), args, kwargs + ) + elif self.value is itertools.count: + return variables.CountIteratorVariable(*args, mutable_local=MutableLocal()) + elif self.value is itertools.cycle: + return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal()) + else: + return super().call_function(tx, args, kwargs) + + +class IteratorVariable(VariableTracker): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def next_variables(self, tx): + unimplemented("abstract method, must implement") + + +class RepeatIteratorVariable(IteratorVariable): + def __init__(self, item: VariableTracker, **kwargs): + super().__init__(**kwargs) + self.item = item + + # Repeat needs no mutation, clone self + def next_variables(self, tx): + return self.item, self + + +class CountIteratorVariable(IteratorVariable): + def __init__(self, item: int = 0, step: int = 1, **kwargs): + super().__init__(**kwargs) + if not isinstance(item, VariableTracker): + item = ConstantVariable.create(item) + if not isinstance(step, VariableTracker): + step = ConstantVariable.create(step) + self.item = item + self.step = step + + def next_variables(self, tx): + assert self.mutable_local + tx.output.side_effects.mutation(self) + next_item = self.item.call_method(tx, "__add__", [self.step], {}) + self.item = next_item + return self.item, self + + +class 
CycleIteratorVariable(IteratorVariable): + def __init__( + self, + iterator: IteratorVariable, + saved: List[VariableTracker] = None, + saved_index: int = 0, + item: Optional[VariableTracker] = None, + **kwargs, + ): + if saved is None: + saved = [] + super().__init__(**kwargs) + self.iterator = iterator + self.saved = saved + self.saved_index = saved_index + self.item = item + + def next_variables(self, tx): + assert self.mutable_local + + if self.iterator is not None: + try: + new_item, _ = self.iterator.next_variables(tx) + if len(self.saved) > MAX_CYCLE: + unimplemented( + "input iterator to itertools.cycle has too many items" + ) + tx.output.side_effects.mutation(self) + self.saved.append(new_item) + self.item = new_item + if self.item is None: + return self.next_variables(tx) + return self.item, self + except StopIteration: + self.iterator = None + return self.next_variables(tx) + elif len(self.saved) > 0: + tx.output.side_effects.mutation(self) + self.saved_index = (self.saved_index + 1) % len(self.saved) + return self.item, self + else: + raise StopIteration diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/lazy.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..74f0b43475d62d72d46bdf60b6f97b5d1bb9526e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/lazy.py @@ -0,0 +1,106 @@ +# mypy: ignore-errors + +import functools +from typing import Optional + +from .base import VariableTracker + + +class LazyCache: + """Container to cache the real VariableTracker""" + + def __init__(self, value, source): + assert source + self.value = value + self.source = source + self.vt: Optional[VariableTracker] = None + + def realize(self, parents_tracker): + assert self.vt is None + from ..symbolic_convert import InstructionTranslator + from .builder import VariableBuilder + + tx = InstructionTranslator.current_tx() + self.vt = VariableBuilder(tx, self.source)(self.value) + self.vt.parents_tracker.add(parents_tracker) + del self.value + del self.source + + +class LazyVariableTracker(VariableTracker): + """ + A structure that defers the creation of the actual VariableTracker + for a given underlying value until it is accessed. + + The `realize` function invokes VariableBuilder to produce the real object. + Once a LazyVariableTracker has been realized, internal bookkeeping will + prevent double realization. + + This object should be utilized for processing containers, or objects that + reference other objects where we may not want to take on creating all the + VariableTrackers right away. 
+ """ + + _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields} + + @staticmethod + def create(value, source, **options): + return LazyVariableTracker(LazyCache(value, source), source=source, **options) + + def __init__(self, _cache, **kwargs): + assert isinstance(_cache, LazyCache) + super().__init__(**kwargs) + self._cache = _cache + + def realize(self) -> VariableTracker: + """Force construction of the real VariableTracker""" + if self._cache.vt is None: + self._cache.realize(self.parents_tracker) + return self._cache.vt + + def unwrap(self): + """Return the real VariableTracker if it already exists""" + if self.is_realized(): + return self._cache.vt + return self + + def is_realized(self): + return self._cache.vt is not None + + def clone(self, **kwargs): + assert kwargs.get("_cache", self._cache) is self._cache + if kwargs.get("source", self.source) is not self.source: + self.realize() + return VariableTracker.clone(self.unwrap(), **kwargs) + + def __str__(self): + if self.is_realized(): + return self.unwrap().__str__() + return VariableTracker.__str__(self.unwrap()) + + def __getattr__(self, item): + return getattr(self.realize(), item) + + # most methods are auto-generated below, these are the ones we want to exclude + apply = VariableTracker.apply + copy = VariableTracker.copy + __post_init__ = VariableTracker.__post_init__ + __repr__ = VariableTracker.__repr__ + + +def _create_realize_and_forward(name): + @functools.wraps(getattr(VariableTracker, name)) + def realize_and_forward(self, *args, **kwargs): + return getattr(self.realize(), name)(*args, **kwargs) + + return realize_and_forward + + +def _populate(): + for name, value in VariableTracker.__dict__.items(): + if name not in LazyVariableTracker.__dict__: + if callable(value): + setattr(LazyVariableTracker, name, _create_realize_and_forward(name)) + + +_populate() diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/lists.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..ba727aa1e0bfffd773939ef63467b6f156b80538 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/lists.py @@ -0,0 +1,811 @@ +# mypy: ignore-errors + +import collections +import functools +import inspect +import operator +import types +from typing import Dict, List, Optional + +import torch +import torch.fx +from ..._guards import Source + +from .. 
import polyfill, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import unimplemented +from ..source import AttrSource, GetItemSource +from ..utils import ( + get_fake_value, + guard_if_dyn, + is_namedtuple, + istype, + iter_contains, + namedtuple_fields, + odict_values, +) +from .base import MutableLocal, VariableTracker +from .constant import ConstantVariable +from .functions import UserFunctionVariable, UserMethodVariable + + +class BaseListVariable(VariableTracker): + @staticmethod + def cls_for_instance(obj): + if is_namedtuple(obj): + return functools.partial(NamedTupleVariable, tuple_cls=type(obj)) + return BaseListVariable.cls_for(type(obj)) + + @staticmethod + def cls_for(obj): + return { + iter: ListIteratorVariable, + list: ListVariable, + slice: SliceVariable, + torch.Size: SizeVariable, + tuple: TupleVariable, + odict_values: ListVariable, + torch.nn.ParameterList: ListVariable, + torch.nn.ModuleList: ListVariable, + collections.deque: DequeVariable, + }[obj] + + def __init__( + self, + items: List[VariableTracker], + **kwargs, + ): + super().__init__(**kwargs) + assert isinstance(items, list) + assert all(isinstance(x, VariableTracker) for x in items) + self.items: List[VariableTracker] = items + + def _as_proxy(self): + return [x.as_proxy() for x in self.items] + + def modified(self, items, **kwargs): + return type(self)(items, **kwargs) + + @property + def value(self): + return self.as_python_constant() + + def as_python_constant(self): + return self.python_type()([x.as_python_constant() for x in self.items]) + + def as_proxy(self): + assert self.python_type() is not SizeVariable + return self.python_type()(self._as_proxy()) + + def getitem_const(self, arg: VariableTracker): + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + if self.source is not None: + return self.clone( + items=self.items[index], + source=GetItemSource(self.source, index), + mutable_local=MutableLocal() if self.mutable_local else None, + ) + else: + return self.clone( + items=self.items[index], + mutable_local=MutableLocal() if self.mutable_local else None, + ) + else: + assert isinstance(index, (int, torch.SymInt)) + return self.items[index] + + def unpack_var_sequence(self, tx): + return list(self.items) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__getitem__": + from .tensor import TensorVariable + + assert not kwargs and len(args) == 1 + if isinstance(args[0], TensorVariable): + value = get_fake_value(args[0].as_proxy().node, tx) + if value.constant is not None and value.constant.numel() == 1: + value = variables.ConstantVariable.create(value.constant.item()) + else: + unimplemented("__getitem__ with non-constant tensor") + else: + value = args[0] + return self.getitem_const(value) + elif name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.items, args[0], tx) + elif name == "index": + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder()(tx, polyfill.index), [self] + list(args), kwargs + ) + + return super().call_method(tx, name, args, kwargs) + + @staticmethod + def list_compare(tx, op, left, right): + from .builtin import BuiltinVariable + + eq_result = BaseListVariable.list_eq(tx, left, right) + if op is operator.eq: 
+ return eq_result + elif op is operator.ne: + return BuiltinVariable(operator.not_).call_function(tx, [eq_result], {}) + else: + unimplemented(f"list_compare {left} {op} {right}") + + @staticmethod + def list_eq(tx, left, right): + from .builtin import BuiltinVariable + + # Most list-like variables implement comparison ops the same way, + # so they can re-use this helper. + # There are quirks though, like how `tuple([2]) == torch.Size([2])`, + # but `tuple([2]) != list([2])` + if len(left.items) != len(right.items): + return ConstantVariable.create(False) + if len(left.items) == 0: + return ConstantVariable.create(True) + + # Generic list comparison works by iterating over left aka self and right the compared-to list. + # If we hit here, their lengths are the same and they cannot be expressed as python constants. + # So, we iterate over the zipped list items. + comps = [] + for l, r in zip(left.items, right.items): + comp = BuiltinVariable(operator.eq).call_function(tx, [l, r], {}) + if comp.is_python_constant() and not comp.as_python_constant(): + # early exit in false case + return comp + comps.append(comp) + + return functools.reduce( + lambda a, b: BuiltinVariable(operator.and_).call_function(tx, [a, b], {}), + comps, + ) + + +class RangeVariable(BaseListVariable): + def __init__(self, items, **kwargs): + items_to_map = items + start = variables.ConstantVariable.create(0) + stop = None + step = variables.ConstantVariable.create(1) + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError() + + assert stop is not None + super().__init__([start, stop, step], **kwargs) + + def python_type(self): + return range + + def as_python_constant(self): + return range(*[x.as_python_constant() for x in self.items]) + + def as_proxy(self): + return self.python_type()(*self._as_proxy()) + + def unpack_var_sequence(self, tx): + return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] + + def reconstruct(self, codegen): + assert "range" not in codegen.tx.f_globals + codegen.append_output(codegen.create_load_python_module(range, True)) + codegen.foreach(self.items) + codegen.extend_output(create_call_function(3, False)) + + def var_getattr(self, tx, name): + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented(f"range.{name}") + return self.items[fields.index(name)] + + +class CommonListMethodsVariable(BaseListVariable): + """ + Implement methods common to List and other List-like things + """ + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "append" and self.mutable_local: + assert not kwargs + (arg,) = args + tx.output.side_effects.mutation(self) + self.items.append(arg) + return ConstantVariable.create(None) + elif ( + name == "extend" + and self.mutable_local + and args + and args[0].has_unpack_var_sequence(tx) + ): + assert not kwargs + (arg,) = args + seq = arg.unpack_var_sequence(tx) + tx.output.side_effects.mutation(self) + self.items.extend(seq) + return ConstantVariable.create(None) + elif name == "insert" and self.mutable_local: + assert not kwargs + idx, value = args + const_idx = idx.as_python_constant() + tx.output.side_effects.mutation(self) + self.items.insert(const_idx, value) + return ConstantVariable.create(None) + elif name == "pop" and self.mutable_local: + assert not kwargs + 
tx.output.side_effects.mutation(self) + return self.items.pop(*[a.as_python_constant() for a in args]) + elif name == "clear" and self.mutable_local: + assert not kwargs and not args + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif ( + name == "__setitem__" + and self.mutable_local + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + self.items[key.as_python_constant()] = list(value.items) + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "copy": + # List copy() doesn't have args and kwargs + assert not kwargs + assert not args + items = list(self.items) + return self.modified(items, mutable_local=MutableLocal()) + else: + return super().call_method(tx, name, args, kwargs) + + +class ListVariable(CommonListMethodsVariable): + def python_type(self): + return list + + def reconstruct(self, codegen): + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items))) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if ( + name == "__setitem__" + and self.mutable_local + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + if not value.has_unpack_var_sequence(tx): + unimplemented( + f"Missing dynamo support for expanding {value} into a list for slice assignment." + ) + self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + if self.python_type() is not list: + return super().call_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr([], name)) + + +class DequeVariable(CommonListMethodsVariable): + def python_type(self): + return collections.deque + + def reconstruct(self, codegen): + assert "deque" not in codegen.tx.f_globals + codegen.append_output( + codegen.create_load_python_module(collections.deque, True) + ) + codegen.foreach(self.items) + codegen.extend_output(create_call_function(len(self.items), False)) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if ( + name == "__setitem__" + and self.mutable_local + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + assert key.is_python_constant() and isinstance( + key.as_python_constant(), int + ) + tx.output.side_effects.mutation(self) + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "extendleft" and self.mutable_local: + assert not kwargs + + (arg,) = args + prefix = arg.unpack_var_sequence(tx) + prefix.reverse() + tx.output.side_effects.mutation(self) + self.items = prefix + list(self.items) + return ConstantVariable.create(None) + elif name == "popleft" and self.mutable_local: + assert not args + assert not kwargs + item = self.items[0] + tx.output.side_effects.mutation(self) + self.items = self.items[1:] + return item + elif name == "appendleft" and self.mutable_local: + assert not kwargs + tx.output.side_effects.mutation(self) + 
self.items = [args[0]] + list(self.items) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + + +class TupleVariable(BaseListVariable): + def python_type(self): + return tuple + + def reconstruct(self, codegen): + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items))) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + return super().call_method(tx, name, args, kwargs) + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + if self.python_type() is not tuple: + return super().call_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr((), name)) + + +class SizeVariable(TupleVariable): + """torch.Size(...)""" + + def __init__( + self, + items: List[VariableTracker], + proxy: Optional[torch.fx.Proxy] = None, + **kwargs, + ): + self.proxy = proxy + super().__init__(items, **kwargs) + + def python_type(self): + return torch.Size + + def as_proxy(self): + if self.proxy is not None: + return self.proxy + + # torch.Size needs special handling. Normally, we pun a list-like + # container to directly contain Proxy/Node objects from FX, and FX + # knows to look inside containers (via map_aggregate). But torch.Size + # is weird; although it subclasses from tuple, it doesn't allow + # members which aren't int-like (rejecting Proxy and Node). This + # means we can't use the normal representation trick + # torch.Size([proxy0, proxy1]). I looked into seeing if I could + # relax torch.Size in PyTorch proper, but if torch.Size constructor + # sees a type that it doesn't recognize, it will try to call + # __index__() on it, so there is no BC way to actually change this + # behavior (though it occurs to me that I could have just added a + # YOLO no checking alternate constructor.) + # + # To work around this problem, I represent a torch.Size proxy as + # a straight up proxy, that would have been constructed by taking + # the constituent proxies as arguments. This trick can be generally + # used for any construct that we need a proxy for but we can't + # directly represent as an aggregate; I don't see very many examples + # of this in torchdynamo though! + + # Look for a proxy. 
If there are none, do the legacy behavior + tracer = None + proxies = self._as_proxy() + for proxy in proxies: + if isinstance(proxy, torch.fx.Proxy): + tracer = proxy.tracer + break + + if tracer is None: + return torch.Size(proxies) + + proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {}) + proxy.node.meta["example_value"] = torch.Size( + [ + p.node.meta["example_value"] if not isinstance(p, int) else p + for p in proxies + ] + ) + return proxy + + def reconstruct(self, codegen): + codegen.load_import_from("torch", "Size") + codegen.foreach(self.items) + build_torch_size = [ + create_instruction("BUILD_TUPLE", arg=len(self.items)), + ] + create_call_function(1, True) + codegen.extend_output(build_torch_size) + + def unpack_var_sequence(self, tx): + return list(self.items) + + def numel(self, tx): + from .builtin import BuiltinVariable + from .tensor import SymNodeVariable + + const_result = 1 + sym_sizes = [] + + for v in self.items: + if isinstance(v, ConstantVariable): + const_result *= v.value + else: + assert isinstance(v, SymNodeVariable), type(v) + # Delay proxy calls until we know it will be necessary + sym_sizes.append(v) + + result = ConstantVariable.create(const_result) + if sym_sizes and const_result == 1: + # Skip multiplying by 1 + result, *sym_sizes = sym_sizes + + if not sym_sizes or const_result == 0: + return result + + mul = BuiltinVariable(operator.mul) + for v in sym_sizes: + result = mul.call_function(tx, [result, v], {}) + return result + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__getitem__": + assert not kwargs and len(args) == 1 + out = self.get_item_dyn(tx, args[0]) + return out + elif name == "numel": + assert not args and not kwargs + return self.numel(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_item_dyn(self, tx, arg: VariableTracker): + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + if isinstance(index, slice): + return SizeVariable(self.items[index]) + else: + assert isinstance(index, (int, torch.SymInt)) + return self.items[index] + + +class NamedTupleVariable(TupleVariable): + def __init__(self, items, tuple_cls, **kwargs): + super().__init__(items, **kwargs) + self.tuple_cls = tuple_cls + + def python_type(self): + return self.tuple_cls + + def as_python_constant(self): + return self.python_type()(*[x.as_python_constant() for x in self.items]) + + def as_proxy(self): + assert self.python_type() is not SizeVariable + return self.python_type()(*self._as_proxy()) + + def reconstruct(self, codegen): + create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls) + codegen.append_output(codegen._create_load_const(create_fn)) + codegen.foreach(self.items) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.items)), + ] + + create_call_function(1, True) + ) + + def var_getattr(self, tx, name): + def check_and_create_method(): + method = inspect.getattr_static(self.tuple_cls, name, None) + if isinstance(method, classmethod): + # We need the unbounded cls method to avoid the inline __self__ + return UserMethodVariable( + method.__func__, + variables.UserDefinedClassVariable(self.tuple_cls), + ) + elif isinstance(method, staticmethod): + return UserFunctionVariable(method.__func__) + elif inspect.isfunction(method): + return UserMethodVariable(method, self) + else: + return None + + 
fields = namedtuple_fields(self.tuple_cls) + if name not in fields: + method = check_and_create_method() + if not method: + super().var_getattr(tx, name) + return method + return self.items[fields.index(name)] + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + fields = namedtuple_fields(self.tuple_cls) + return variables.ConstantVariable.create(name in fields) + + +class SliceVariable(BaseListVariable): + def __init__(self, items, **kwargs): + items_to_map = items + start, stop, step = [variables.ConstantVariable.create(None)] * 3 + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError() + + if isinstance(start, variables.TensorVariable) or isinstance( + stop, variables.TensorVariable + ): + unimplemented("Dynamic slicing on data-dependent value is not supported") + + super().__init__([start, stop, step], **kwargs) + + def as_proxy(self): + return slice(*self._as_proxy()) + + def python_type(self): + return slice + + def as_python_constant(self): + return slice(*[guard_if_dyn(x) for x in self.items]) + + def reconstruct(self, codegen): + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) + + def var_getattr(self, tx, name): + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented(f"slice.{name}") + return self.items[fields.index(name)] + + +class ListIteratorVariable(VariableTracker): + def __init__(self, items, index: int = 0, **kwargs): + super().__init__(**kwargs) + assert isinstance(items, list) + # Removing this check as it slows things down too much + # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 + + # assert all(isinstance(x, VariableTracker) for x in items) + self.items = items + self.index = index + + def __repr__(self): + return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" + + def next_variables(self, tx): + assert self.mutable_local + old_index = self.index + if old_index >= len(self.items): + raise StopIteration() + tx.output.side_effects.mutation(self) + self.index += 1 + return self.items[old_index], self + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ): + if name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.items[self.index :], args[0], tx) + + return super().call_method(tx, name, args, kwargs) + + def as_python_constant(self): + if self.index > 0: + raise NotImplementedError() + return iter([x.as_python_constant() for x in self.items]) + + def unpack_var_sequence(self, tx): + return list(self.items[self.index :]) + + def reconstruct(self, codegen): + remaining_items = self.items[self.index :] + codegen.foreach(remaining_items) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(remaining_items)), + create_instruction("GET_ITER"), + ] + ) + + +class TupleIteratorVariable(ListIteratorVariable): + pass + + +class RestrictedListSubclassVariable(ListVariable): + """ + This is a special case of UserDefinedObjectVariable where: + 1) The user subclasses list + 2) None of the list methods are overriden, merely some new methods are added + + In these cases, we can prevent graph breaks by not using the general + UserDefinedObjectVariable machinery and instead treating it like + a ListVariable. 
+ """ + + _nonvar_fields = {"user_cls", "user_cls_source", *ListVariable._nonvar_fields} + _allowed_names = { + "__call__", + "__module__", + "__dict__", + "__doc__", + "__name__", + "__qualname__", + } + _disallowed_names = { + "__getattribute__", + "__getattr__", + "__setattr__", + } + + @classmethod + def _is_non_conflicting_subclass( + cls, + user_cls: type, + python_cls: type, + ): + """Ensures user_cls inherits from python_cls (e.g. list) and does not override any methods on python_cls""" + if ( + not istype(user_cls, type) + or user_cls.__bases__ != (python_cls,) + or user_cls.__mro__ != (user_cls, python_cls, object) + ): + return False # not subclass + return not any( + hasattr(python_cls, name) or name in cls._disallowed_names + for name in set(user_cls.__dict__.keys()) - cls._allowed_names + ) + + @classmethod + def is_matching_cls(cls, user_cls: type): + return cls._is_non_conflicting_subclass(user_cls, list) + + def __init__(self, items, *, user_cls: type, user_cls_source: Source, **kwargs): + super().__init__(items=items, **kwargs) + self.user_cls = user_cls + self.user_cls_source = user_cls_source + assert istype(user_cls, type) + assert isinstance(user_cls_source, Source) + + def python_type(self): + return self.user_cls + + def as_proxy(self): + return [x.as_proxy() for x in self.items] + + def as_python_constant(self): + raise NotImplementedError() + + def is_python_constant(self): + return False + + @property + def value(self): + raise AttributeError("value") + + def modified(self, items, **kwargs): + return type(self)( + items, + user_cls=self.user_cls, + user_cls_source=self.user_cls_source, + **kwargs, + ) + + def reconstruct(self, codegen): + codegen(self.user_cls_source) + super().reconstruct(codegen) + codegen.extend_output(create_call_function(1, True)) + + def call_method( + self, + tx, + name, + args: List["VariableTracker"], + kwargs: Dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name in self.user_cls.__dict__: + method = self.user_cls.__dict__[name] + if isinstance(method, types.FunctionType): + # inline the method + source = AttrSource(self.user_cls_source, name) + return UserMethodVariable(method, self, source=source).call_function( + tx, args, kwargs + ) + unimplemented( + f"RestrictedListSubclassVariable method {self.user_cls.__name__}.{name}" + ) + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + return self.call_method(tx, "__call__", args, kwargs) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/misc.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4012a88966c5b5121a51d594c4e21263d59c03 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/misc.py @@ -0,0 +1,886 @@ +# mypy: ignore-errors + +import collections +import dataclasses +import functools +import inspect +import itertools +import sys +import types +from typing import Dict, List + +import torch._C +import torch._numpy as tnp +import torch.utils._pytree as pytree +from .. 
import config, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource +from ..utils import ( + check_constant_args, + check_unspec_python_args, + identity, + is_tensor_base_attr_getter, + proxy_args_kwargs, +) +from .base import VariableTracker +from .functions import NestedUserFunctionVariable, UserFunctionVariable +from .user_defined import UserDefinedObjectVariable + + +class SuperVariable(VariableTracker): + def __init__(self, typevar, objvar=None, specialized=False, **kwargs): + super().__init__(**kwargs) + # typevar is the fist argument to super(). In the case where no argument + # is provided to super(), it is the __class__ object where + # the super() function is being called + self.typevar = typevar + # objvar here must be an instance or subtype of typevar. + # In the case where super() is called without arguments, it is the first argument + # to the current function where super() is called from (self for regular method, + # cls for a classmethod) + self.objvar = objvar + self.specialized = specialized # directly get attr from self.typevar if true + + def reconstruct(self, codegen): + codegen(variables.BuiltinVariable(super)) + codegen(self.typevar) + if self.objvar is not None: + codegen(self.objvar) + codegen.extend_output(create_call_function(2, True)) + else: + codegen.extend_output(create_call_function(1, True)) + + def _resolved_getattr_and_source(self, tx, name): + assert self.objvar, "1-arg super not implemented" + if self.specialized: + return getattr(self.typevar.as_python_constant(), name) + search_type = self.typevar.as_python_constant() + + # The rest of this function does two things: + # - Walk the mro to find where the attribute comes from to be + # able to provide accurate source + # - Call the getattr to get the object + + # Find the class object, where the function lives. + # When objvar is "self", use type(self), when objvar is "cls", use it as-is + type_to_use = self.objvar.python_type() + type_to_use_source = ( + TypeSource(self.objvar.source) if self.objvar.source else None + ) + if issubclass(type_to_use, type): + type_to_use = self.objvar.value + type_to_use_source = self.objvar.source + + source = None + if self.objvar.source is not None: + # Walk the mro tuple to find out the actual class where the + # attribute resides. + search_mro = type_to_use.__mro__ + start_index = search_mro.index(search_type) + 1 + for index in range(start_index, len(search_mro)): + if hasattr(search_mro[index], name): + # Equivalent of something like type(L['self']).__mro__[1].attr_name + source = AttrSource( + GetItemSource(AttrSource(type_to_use_source, "__mro__"), index), + name, + ) + break + + # TODO(jansel): there is a small chance this could trigger user code, prevent that + return getattr(super(search_type, type_to_use), name), source + + def var_getattr(self, tx, name: str) -> "VariableTracker": + # Check if getattr is a constant. If not, delay the actual work by + # wrapping the result in GetAttrVariable. Mostly super is called with a + # method, so most of the work is delayed to call_function. + # + # We could have just implemented a const_getattr. However, super is + # special when it comes to finding sources. Compared to other VTs, super + # requires the attr name to walk the mro and find the actual source (and + # not just AttrSource). 
+ value, source = self._resolved_getattr_and_source(self, name) + if not variables.ConstantVariable.is_literal(value): + return GetAttrVariable(self, name) + if source: + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + return variables.ConstantVariable.create(value) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + inner_fn, source = self._resolved_getattr_and_source(self, name) + + if inner_fn is object.__init__: + return LambdaVariable(identity) + elif inner_fn is torch.nn.Module.__init__: + objvar = self.objvar + from ..side_effects import AttributeMutationNew + + if ( + isinstance(objvar, variables.UserDefinedObjectVariable) + and isinstance(objvar.mutable_local, AttributeMutationNew) + and not (args or kwargs) + ): + tx.output.side_effects.store_attr( + objvar, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return variables.ConstantVariable.create(None) + else: + unimplemented("super() nn.Module.__init__") + elif isinstance(inner_fn, types.FunctionType): + return variables.UserFunctionVariable( + inner_fn, source=source + ).call_function(tx, [self.objvar] + args, kwargs) + elif isinstance(inner_fn, types.MethodType): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source + ).call_function(tx, args, kwargs) + elif ( + inner_fn is collections.OrderedDict.__getitem__ + and isinstance(self.objvar, variables.UserDefinedObjectVariable) + and self.objvar.source + and len(args) == 1 + and len(kwargs) == 0 + and args[0].is_python_constant() + ): + from .builder import VariableBuilder + + key = args[0].as_python_constant() + return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))( + collections.OrderedDict.__getitem__(self.objvar.value, key) + ) + elif inner_fn in ( + collections.OrderedDict.__setitem__, + object.__setattr__, + ) and isinstance(self.objvar, variables.CustomizedDictVariable): + assert not kwargs and len(args) == 2 + return super(variables.CustomizedDictVariable, self.objvar).call_method( + tx, "__setitem__", args, kwargs + ) + else: + unimplemented(f"non-function or method super: {inner_fn}") + + +class UnknownVariable(VariableTracker): + """ + It could be anything! + """ + + +class DelayGraphBreakVariable(UnknownVariable): + """ + Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. 
+ """ + + +class ComptimeVariable(VariableTracker): + """ + This variable is special, it lets you execute arbitrary code at + Dynamo compile time + """ + + def reconstruct(self, codegen): + raise NotImplementedError("comptime is special form") + + def var_getattr(self, tx, name: str) -> "VariableTracker": + from ..comptime import comptime + + # To support the comptime.print_graph convenience accessors + from .functions import UserFunctionVariable + + return UserFunctionVariable( + getattr(comptime, name), source=AttrSource(self.source, name) + ) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from ..comptime import ComptimeContext + + # TODO: support an expression form as well + + assert not kwargs + assert len(args) == 1 + fn = args[0] + if isinstance(fn, UserFunctionVariable): + fn.get_function()(ComptimeContext(tx)) + elif isinstance(fn, NestedUserFunctionVariable): + # We have to manually bind the freevars ourselves + code = fn.get_code() + assert not fn.closure, ( + "comptime function must not have free variables, " + f"but these variables were free: {code.co_freevars}" + ) + func = types.FunctionType( + code, + fn.f_globals, + fn.fn_name.as_python_constant(), + tuple(fn.defaults.items) if fn.defaults else None, + # We could automatically promote free variables into + # ComptimeVar but this is confusing if you access + # a free variable that we actually DO have the runtime + # value for + # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items) + tuple(), + ) + func(ComptimeContext(tx)) + else: + raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") + + return variables.ConstantVariable.create(None) + + +class ClosureVariable(UnknownVariable): + def __init__(self, name, **kwargs): + super().__init__(**kwargs) + self.name = name + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load_closure(self.name)) + + +# closure variable created by an inlined function +class InlinedClosureVariable(UnknownVariable): + def __init__(self, name, **kwargs): + super().__init__(**kwargs) + self.name = name + + def reconstruct(self, codegen): + codegen.append_output(codegen.create_load_closure(self.name)) + + +class NewCellVariable(VariableTracker): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class NewGlobalVariable(VariableTracker): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class InspectSignatureVariable(VariableTracker): + """represents inspect.signature(...)""" + + @staticmethod + def create(callable, **kwargs): + if kwargs: + unimplemented(f"inspect.signature with {kwargs}") + return InspectSignatureVariable(callable) + + def __init__(self, inspected: VariableTracker, **kwargs): + super().__init__(**kwargs) + self.inspected = inspected + + def var_getattr(self, tx, name: str) -> "VariableTracker": + if name == "parameters": + return variables.ConstDictVariable( + { + variables.ConstantVariable.create(name): InspectParameterVariable() + for name in self.inspected.inspect_parameter_names() + }, + user_cls=dict, + ) + return super().var_getattr(tx, name) + + +class InspectParameterVariable(VariableTracker): + """This is not implemented, if used will graph break.""" + + pass + + +def produce_trampoline_autograd_apply(fn_cls): + def trampoline_autograd_apply(*args, **kwargs): + return fn_cls.apply(*args, **kwargs) + + trampoline_autograd_apply._origin = produce_trampoline_autograd_apply + return trampoline_autograd_apply + + +class 
AutogradFunctionVariable(VariableTracker): + """represents a torch.autograd.Function subclass""" + + def __init__(self, fn_cls, **kwargs): + super().__init__(**kwargs) + self.fn_cls = fn_cls + + def call_apply(self, tx, args, kwargs): + requires_grad = False + + def visit(node): + nonlocal requires_grad + if isinstance(node, variables.TensorVariable): + if node.requires_grad is not False: + requires_grad = True + if isinstance(node, variables.NNModuleVariable): + if node.is_training(tx): + requires_grad = True + return node + + VariableTracker.apply(visit, (args, kwargs)) + + if ( + requires_grad + and torch.is_grad_enabled() + and config.capture_autograd_function + ): + # Note - this is the same check used in autograd/function.py, except inverted. + # If we want to support functorch transforms here, we will need to enable this. + if ( + self.fn_cls.setup_context + != torch.autograd.function._SingleLevelFunction.setup_context + ): + unimplemented( + "NYI - autograd.Function with custom setup_context method" + ) + + vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] + if vjp_fn is not torch.autograd.Function.vjp: + unimplemented("NYI - User defind vjp") + + jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] + if jvp_fn is not torch.autograd.Function.jvp: + unimplemented("NYI - User defind jvp") + + from .higher_order_ops import AutogradFunctionApplyVariable + + source = self.source + if source is None: + source = AttrSource( + tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__ + ) + + return AutogradFunctionApplyVariable( + self.fn_cls.forward, + self.fn_cls.backward, + source, + source=AttrSource(source, member="apply"), + ).call_function(tx, args, kwargs) + + if self.source: + source = AttrSource(self.source, "forward") + else: + source = None + + fn = self.fn_cls.forward + ctx = AutogradFunctionContextVariable.create(tx) + args = [ctx, *args] + if isinstance(fn, types.FunctionType): + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, args, kwargs + ) + elif isinstance(fn, types.MethodType): + return variables.UserMethodVariable( + fn.__func__, + variables.UserDefinedClassVariable(self.fn_cls), + source=source, + ).call_function(tx, args, kwargs) + else: + unimplemented( + f"non-function or method in subclass of torch.autograd.Function: {fn}" + ) + + def call_function(self, tx, args, kwargs): + return AutogradFunctionVariable(self.fn_cls) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ): + from ..trace_rules import is_callable_allowed + from .builder import wrap_fx_proxy + + if name == "apply": + if is_callable_allowed(self.fn_cls): + trampoline_autograd_apply = produce_trampoline_autograd_apply( + self.fn_cls + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + trampoline_autograd_apply, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + return self.call_apply(tx, args, kwargs) + + else: + unimplemented(f"Unsupported method: {name}") + + +@dataclasses.dataclass +class SavedTensorBox: + tensors: List[VariableTracker] = dataclasses.field(default_factory=list) + + +class AutogradFunctionContextVariable(UserDefinedObjectVariable): + """ + Tracks an autograd.Function() context using mutation tracking in side_effects.py + """ + + _nonvar_fields = { + "proxy", + "inference", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value, + value_type=None, + inference=False, + proxy=None, + saved_tensors=None, + 
**kwargs, + ): + super().__init__(value=value, value_type=value_type, **kwargs) + self.inference = inference + self.proxy = proxy + self.saved_tensors = saved_tensors + + @staticmethod + def create(tx): + proxy = tx.output.create_proxy( + "call_function", torch.autograd.function.FunctionCtx, tuple(), {} + ) + out = tx.output.side_effects.track_object_new( + None, + torch.autograd.function.FunctionCtx, + functools.partial( + AutogradFunctionContextVariable, + inference=True, + proxy=proxy, + saved_tensors=SavedTensorBox(), + ), + {}, + ) + proxy.node.meta["example_value"] = out.value + return out + + def as_proxy(self): + if self.proxy is None: + unimplemented("proxy not set") + return self.proxy + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name != "save_for_backward": + unimplemented(f"autograd.Function context method: {name}") + if self.saved_tensors is None: + unimplemented( + "save_for_backward only supported on a newly constructed FunctionCtx" + ) + + if not self.inference: + assert self.source and not kwargs + tx.output.side_effects.track_save_for_backward(self, args) + + # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls. + if len(self.saved_tensors.tensors) > 0: + self.saved_tensors.tensors = [] + for arg in args: + self.saved_tensors.tensors.append(arg) + return variables.ConstantVariable.create(None) + + def var_getattr(self, tx, name): + if name == "save_for_backward": + return LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + if name == "saved_tensors" and self.saved_tensors is not None: + return variables.TupleVariable(list(self.saved_tensors.tensors)) + return super().var_getattr(tx, name) + + +class LambdaVariable(VariableTracker): + def __init__(self, fn, **kwargs): + super().__init__(**kwargs) + self.fn = fn + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + return self.fn(*args, **kwargs) + + +class GetAttrVariable(VariableTracker): + def __init__(self, obj, name, **kwargs): + super().__init__(**kwargs) + assert isinstance(obj, VariableTracker) + assert isinstance(name, str) + self.obj = obj + self.name = name + + def __str__(self): + return f"{self.__class__.__name__}({self.obj}, {self.name})" + + @staticmethod + def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): + return getattr(base_proxy, attr) + + def as_proxy(self): + return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) + + def const_getattr(self, tx, name): + if not isinstance(self.obj, variables.NNModuleVariable): + raise NotImplementedError() + step1 = tx.output.get_submodule(self.obj.module_key) + if self.name not in step1.__dict__: + raise NotImplementedError() + step2 = inspect.getattr_static(step1, self.name) + if name not in step2.__dict__: + raise NotImplementedError() + return inspect.getattr_static(step2, name) + + def reconstruct(self, codegen): + codegen(self.obj) + codegen.extend_output(codegen.create_load_attrs(self.name)) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + return self.obj.call_method(tx, self.name, args, kwargs) + + +class MethodWrapperVariable(VariableTracker): + def __init__(self, method_wrapper, **kwargs): + super().__init__(**kwargs) + self.method_wrapper = method_wrapper + + def call_function( + self, tx, args: 
"List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( + args[0], variables.TensorVariable + ): + assert len(args) == 1 and len(kwargs) == 0 + + return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) + + super().call_function(tx, args, kwargs) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.method_wrapper + + +class GetSetDescriptorVariable(VariableTracker): + def __init__(self, desc, **kwargs): + super().__init__(**kwargs) + self.desc = desc + + def var_getattr(self, tx, name): + if name == "__get__" and self.source: + from .builder import VariableBuilder + + return VariableBuilder(tx, AttrSource(self.source, "__get__"))( + self.desc.__get__ + ) + else: + return super().var_getattr(tx, name) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.desc + + +class PythonModuleVariable(VariableTracker): + def __init__(self, value: types.ModuleType, **kwargs): + super().__init__(**kwargs) + self.value = value + self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") + + def python_type(self): + return types.ModuleType + + def as_python_constant(self): + return self.value + + def __repr__(self): + return f"PythonModuleVariable({self.value})" + + def call_hasattr(self, tx, name): + if self.is_torch: + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + return super().call_hasattr(tx, name) + + +class TypingVariable(VariableTracker): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__" and len(args) == 1: + return variables.ConstantVariable.create( + self.value[args[0].as_python_constant()], + ) + unimplemented("typing") + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + +@functools.lru_cache(maxsize=1) +def get_np_to_tnp_map(): + from ..utils import NP_TO_TNP_MODULE + + np_fn_to_tnp_fn = {} + + for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): + for fn_name, tnp_fn in tnp_mod.__dict__.items(): + if callable(tnp_fn): + # some internal details do leak from tnp + # which are not part of numpy API. + if np_fn := getattr(np_mod, fn_name, None): + np_fn_to_tnp_fn[np_fn] = tnp_fn + + return np_fn_to_tnp_fn + + +class NumpyVariable(VariableTracker): + """ + Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. + """ + + constant_fold_functions = (tnp.issubdtype,) + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + @classmethod + def can_constant_fold_through(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return fn in cls.constant_fold_functions + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + if not config.trace_numpy: + unimplemented(f"numpy.{self.value}()") + + from ..utils import numpy_to_tensor_wrapper + + from .tensor import NumpyNdarrayVariable + + # lookup method name in tnp. Things like np.dtype(float) are not supported yet. + if self.value.__name__ == "dtype": + unimplemented( + f"numpy dtype function is not supported yet. 
Got type {type(self.value)}." + ) + else: # We are dealing with a callable. + func = get_np_to_tnp_map().get(self.value) + if func is None: + unimplemented( + f"Can't find numpy function {self.value} in torch._numpy. " + " Please file an issue to request support for this function." + ) + + if ( + func.__module__ == "torch._numpy.random" + and config.use_numpy_random_stream + ): + msg = f"delegate '{func.__qualname__}' to NumPy itself via " + msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}" + unimplemented(msg) + + args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) + + constant_args = check_constant_args(args, kwargs) + unspec_python_args = check_unspec_python_args(args, kwargs) + + if self.can_constant_fold_through(func) and ( + constant_args or unspec_python_args + ): + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + # TODO Add all the functions that go from constants to constants to can_constant_fold_through + proxy = tx.output.create_proxy( + "call_function", + numpy_to_tensor_wrapper(func), + *proxy_args_kwargs(args, kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented("numpy") + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + def as_proxy(self): + if config.trace_numpy and isinstance(self.value, type): + # This handles numpy dtype attributes such as np.float32 + # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph + # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does + return self.value.__name__ + + return super().as_proxy() + + +# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls +class NullVariable(VariableTracker): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __str__(self): + return "NullVariable" + + def reconstruct(self, codegen): + if sys.version_info < (3, 11): + unimplemented("cannot reconstruct NullVariable in < Python 3.11") + codegen.append_output(create_instruction("PUSH_NULL")) + + +class DeletedVariable(VariableTracker): + """Marker used to implement delattr()""" + + +class StringFormatVariable(VariableTracker): + """ + Represents a call to str.format(), we delay calling format until after the graph. 
+ """ + + _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} + + @classmethod + def create(cls, format_string, sym_args, sym_kwargs): + if all( + x.is_python_constant() + for x in itertools.chain(sym_args, sym_kwargs.values()) + ): + return variables.ConstantVariable.create( + format_string.format( + *[v.as_python_constant() for v in sym_args], + **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, + ) + ) + return cls(format_string, list(sym_args), dict(sym_kwargs)) + + def __init__(self, format_string, sym_args, sym_kwargs, **kwargs): + super().__init__(**kwargs) + assert isinstance(format_string, str) + self.format_string = format_string + self.sym_args = sym_args + self.sym_kwargs = sym_kwargs + + def __repr__(self): + return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" + + def reconstruct(self, codegen): + if sys.version_info >= (3, 11): + codegen.append_output(create_instruction("PUSH_NULL")) + codegen.append_output(codegen.create_load_const(self.format_string)) + codegen.append_output(codegen.create_load_attr("format")) + codegen(variables.TupleVariable(self.sym_args)) + kwargs = { + variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() + } + codegen(variables.ConstDictVariable(kwargs)) + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1)) + + +class DebuggingVariable(VariableTracker): + """ + Represents a call to a debugging function like print(), or something + registered to config.reorderable_logging_functions. + """ + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + @staticmethod + def is_reorderable_logging_function(obj): + return ( + callable(obj) + and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) + and obj in torch._dynamo.config.reorderable_logging_functions + ) + + def call_function(self, tx, args, kwargs): + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + + if not self.can_reorder_logs(self.value, args, kwargs): + unimplemented( + f"Reordering debugging function {self.value} " + f"with inputs {args} {kwargs} is not yet implemented." + ) + + tx.debug_locals.append((self, list(args))) + + def reconstruct(self, codegen): + return self.source.reconstruct(codegen) + + @staticmethod + def can_reorder_logs(fn, args, kwargs) -> True: + """ + Run some additional checks for what sort of function calls can we + actually reorder. + """ + + allowed_input_types = ( + variables.TensorVariable, + variables.ConstantVariable, + StringFormatVariable, + ) + + flat_args = pytree.tree_leaves([args, kwargs]) + for arg in flat_args: + if not isinstance(arg, allowed_input_types): + return False + + return True diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/nn_module.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/nn_module.py new file mode 100644 index 0000000000000000000000000000000000000000..9da90048ac33766ebbe0a42bed7391c793e89226 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/nn_module.py @@ -0,0 +1,813 @@ +# mypy: ignore-errors + +import functools +import inspect +import itertools +import types +from contextlib import contextmanager, nullcontext +from typing import Any, Dict, List + +import torch.nn + +from .. 
import trace_rules, variables +from ..exc import unimplemented, UnspecializeRestartAnalysis, Unsupported +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import GenerationTracker +from ..source import ( + AttrSource, + FSDPNNModuleSource, + GetItemSource, + NNModuleSource, + NotNNModuleSource, +) +from ..utils import ( + get_custom_getattr, + get_fake_value, + is_lazy_module, + is_namedtuple, + is_safe_constant, + istensor, + istype, + nnmodule_has_hooks, + object_has_getattribute, + proxy_args_kwargs, +) +from .base import MutableLocal, typestr, VariableTracker +from .functions import invoke_and_store_as_constant +from .lists import SliceVariable +from .user_defined import UserDefinedObjectVariable + + +def initialize_lazy_module(tx, mod, args, kwargs): + """ + Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable. + + Used to cause lazy module to be initialized (and delete its init hook) before tracing. Especially + useful now that 'allowed' modules graph-break on hooks, calling this first ensures there is no hook + by the time we trace __call__ and thus no graph-break for lazy allowed modules. + """ + if hasattr(mod, "_initialize_hook"): + + def convert_to_fake(x): + if is_namedtuple(x): + return type(x)(*(convert_to_fake(elem) for elem in x)) + elif isinstance(x, dict): + return {k: convert_to_fake(v) for k, v in x.items()} + elif isinstance(x, (list, tuple, set)): + return type(x)(convert_to_fake(elem) for elem in x) + elif isinstance(x, torch.fx.Proxy): + return get_fake_value(x.node, tx) + else: + return x + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + fake_args = [convert_to_fake(arg) for arg in proxy_args] + fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} + mod._infer_parameters(mod, fake_args, fake_kwargs) + + +@contextmanager +def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): + fully_qualified_name = source.name() + try: + tx.nn_module_stack[module_key] = (fully_qualified_name, type(mod)) + yield + finally: + del tx.nn_module_stack[module_key] + + +class NNModuleVariable(VariableTracker): + _nonvar_fields = {"module_type", "module_key", *VariableTracker._nonvar_fields} + + def __init__( + self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs + ): + super().__init__(**kwargs) + self.module_type = module_type + self.module_key = module_key + self.module = module + assert self.source + + def python_type(self): + return self.module_type + + def _wrap_submodule(self, tx, source, submod, *key_extra, **options): + return + + def unpack_var_sequence(self, tx): + # implement list/iter/tuple/etc calls + base = tx.output.get_submodule(self.module_key) + if isinstance(base, torch.nn.ModuleDict): + result = [] + for name, submod in base.items(): + name_var = variables.ConstantVariable.create(name) + tx.output.register_attr_or_module( + submod, + self.module_key, + name, + source=NNModuleSource(GetItemSource(self.source, name)), + ) + result.append(name_var) + return result + + assert isinstance( + base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential) + ), typestr(base) + assert self.source + result = [] + for idx, submod in enumerate(base): + result.append( + tx.output.register_attr_or_module( + submod, + self.module_key, + idx, + source=NNModuleSource(GetItemSource(self.source, idx)), + ) + ) + return result + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + mod = tx.output.get_submodule(self.module_key) + result = 
hasattr(mod, name) + install_guard( + NNModuleSource(AttrSource(self.source, name)).make_guard( + GuardBuilder.HASATTR + ) + ) + return variables.ConstantVariable.create(result) + + def is_training(self, tx): + mod = tx.output.get_submodule(self.module_key) + return getattr(mod, "training", False) + + def convert_to_unspecialized(self, tx): + """Restart analysis treating this module as an UnspecializedNNModuleVariable""" + mod = tx.output.get_submodule(self.module_key) + GenerationTracker.tag(mod) + + # Mark the class dynamic unless its module initialization + if tx.f_code.co_name != "__init__": + GenerationTracker.mark_class_dynamic(type(mod)) + raise UnspecializeRestartAnalysis() + + def _custom_getattr_fallback(self, base, tx, name, options): + """Check for a __getattr__ and handle it specially if it is implemented""" + if object_has_getattribute(base): + unimplemented("torch.nn.Module with a custom __getattribute__ defined") + + getattr_fn = get_custom_getattr(base) + if getattr_fn is None: + return None + + if not isinstance(getattr_fn, types.FunctionType): + unimplemented("torch.nn.Module with a non-function custom __getattr__") + + return variables.UserMethodVariable(getattr_fn, self, **options).call_function( + tx, [variables.ConstantVariable.create(name)], {} + ) + + def var_getattr(self, tx, name): + from .builder import VariableBuilder + + if self.source: + source = AttrSource(self.source, name) + else: + source = None + + base = tx.output.get_submodule(self.module_key) + base_dict = object.__getattribute__(base, "__dict__") + object_member = True + all_class_attribute_names = set() + for x in inspect.getmro(base.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + if not self.source: + unimplemented("GETATTR with no source") + + if name in base_dict: + subobj = base_dict[name] + elif ( + "_modules" in base_dict + and name in base_dict["_modules"] + and name not in all_class_attribute_names + ): + subobj = base_dict["_modules"][name] + elif "_parameters" in base_dict and name in base_dict["_parameters"]: + subobj = base_dict["_parameters"][name] + elif "_buffers" in base_dict and name in base_dict["_buffers"]: + subobj = base_dict["_buffers"][name] + else: + try: + subobj = inspect.getattr_static(base, name) + object_member = False + except AttributeError: + # see if we can fallback to __getattr__, which is not checked by getattr_static + result = self._custom_getattr_fallback( + base=base, tx=tx, name=name, options={"source": source} + ) + if result is not None: + return result + # if we can't find a __getattr__, just raise the AttributeError + raise + + if name == "__class__" and not object_member: + return variables.UserDefinedClassVariable(base.__class__, source=source) + + if object_member: + return VariableBuilder(tx, NNModuleSource(source))(subobj) + else: + if istype(subobj, property): + return variables.UserFunctionVariable( + subobj.fget, + source=source, + ).call_function(tx, [(self)], {}) + elif istype(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, + variables.UserDefinedObjectVariable(type(base)), + source=source, + ) + elif istype(subobj, staticmethod): + return variables.UserFunctionVariable( + subobj.__get__(base), source=source + ) + elif istype(subobj, types.FunctionType): + return variables.UserMethodVariable(subobj, self, source=source) + elif is_safe_constant(subobj) or istensor(subobj): + # Support possibly common cases of class members + return VariableBuilder(tx, NNModuleSource(source))(subobj) + else: + 
unimplemented(f"class property {typestr(base)} {typestr(subobj)}") + + return variables.GetAttrVariable(self, name, source=source) + + def call_function( + self, + tx, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + mod = tx.output.get_submodule(self.module_key) + + with record_nn_module_stack(self.module_key, self.source, tx, mod): + is_lazy = is_lazy_module(mod) + if ( + isinstance(mod, torch.nn.Sequential) + and mod.__class__.forward is torch.nn.Sequential.forward + ): + if nnmodule_has_hooks(mod): + # We do not want to unroll sequential if it has hooks, since evaporating it + # will cause hooks to not fire! + # This terminates and restart the tracing process + self.convert_to_unspecialized(tx) + + # Unroll sequential + assert ( + not is_lazy + ), "Expected lazy sequential isn't a valid combination?" + assert not kwargs + (arg,) = args + # TODO: Use named_children when it supports remove_duplicate=False. + for child_name, submod in mod._modules.items(): + tx.call_function( + tx.output.register_attr_or_module( + submod, + self.module_key, + child_name, + source=NNModuleSource(AttrSource(self.source, child_name)), + ), + [arg], + {}, + ) + arg = tx.pop() + return arg + + if is_lazy: + # The module type will change after it is called + if mod.cls_to_become is not None: + self.module_type = mod.cls_to_become + + # The pre-hook runs to initialize the module shapes, then deletes itself. After this, + # the module is more or less not lazy and can be treated as a normal module regardless of + # is_allowed or other variations. + initialize_lazy_module(tx, mod, args, kwargs) + + # If we are tracing the higher order op, we want Dynamo to step + # inside the module call so that Dynamo can see the underlying + # parameters and buffers and raise them as inputs to the graph. + if tx.output.is_root_tracer() and mod.__module__.startswith( + ("torch.nn.", "torch.ao.") + ): + if nnmodule_has_hooks( + mod, check_forward_hooks=True, check_backward_hooks=True + ): + # End of fn, this bubbles up and restarts tracing. + self.convert_to_unspecialized(tx) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_module", + self.module_key, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + assert self.source, ( + "Must provide a valid source in order to inline, " + "since inlined function may have default args which must be guarded." + ) + if isinstance(mod, torch.fx.GraphModule): + # TODO: do we want to support __call__ for GM's? + # If so at least some changes are needed, we don't allow inlining + # the call_wrapped currently, and maybe other issues too + fn = mod.forward + else: + fn = mod._call_impl + fn_source = AttrSource(self.source, "__call__") + if istype(fn, types.MethodType): + fn = fn.__func__ + fn_source = AttrSource(fn_source, "__func__") + args = [self] + args + else: + assert istype(fn, types.FunctionType) + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + args, + kwargs, + ) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + constant=False, + ) -> "VariableTracker": + from . import ConstantVariable, ListIteratorVariable, TupleVariable + + key = self.module_key + module = tx.output.get_submodule(key) + + def generic_call_method_helper(name): + # Helper function to put a `call_method` node in FX graph, + # with nn.Module as the first arg. 
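Editor's note: the helper above records the method invocation as a "call_method" node in the FX graph instead of inlining it. A minimal sketch of what such a node looks like when built by hand with the public torch.fx API (simplified: the method here is called on a plain tensor placeholder rather than on an nn.Module proxy, so this only illustrates the node kind, not Dynamo's exact construction):

import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.call_method("relu", (x,))   # a call_method node, analogous in spirit to the helper above
g.output(y)

gm = fx.GraphModule(torch.nn.Module(), g)
print(gm.graph)                         # shows the call_method node
print(gm(torch.tensor([-1.0, 2.0])))    # tensor([0., 2.])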
+ mod_proxy = tx.output.create_proxy( + "get_attr", + self.module_key, + tuple(), + {}, + ) + mod_proxy.node.meta["example_value"] = module + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + name, + args=(mod_proxy, *proxy_args), + kwargs=proxy_kwargs, + ), + ) + + if name in ["_call_impl", "_wrapped_call_impl"]: + # Example: `self.layer.__call__(x)` + # This is used for explicit calling `__call__` in a forward function. + # Dynamo inlines `__call__`, includes hooks. + return self.call_function(tx, args, kwargs) + elif name == "forward": + # Example: `self.layer.forward(x)` + # This is used for explicit calling `forward` in a forward function. + # Dynamo puts `call_method` node in FX, doesn't trigger hooks. + with record_nn_module_stack(self.module_key, self.source, tx, module): + return generic_call_method_helper(name) + + if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( + inspect.getfile(module.__class__._check_input_dim) + ): + return ConstantVariable.create(True) + + if name == "_get_item_by_idx": + assert args[1].is_python_constant() + assert isinstance(args[0], TupleVariable) + mod_var = args[0].items[args[1].value] + if isinstance(mod_var, UnspecializedNNModuleVariable): + return mod_var + key = mod_var.module_key + submod = tx.output.get_submodule(key) + return tx.output.register_attr_or_module( + submod, + key, + key, + source=NNModuleSource(GetItemSource(self.source, key)), + ) + + if constant: + fn = getattr(module, name) + name = f"{module.__class__.__name__}_{name}_result" + return invoke_and_store_as_constant(tx, fn, name, args, kwargs) + + def assert_all_args_kwargs_const(): + if not all( + x.is_python_constant() for x in itertools.chain(args, kwargs.values()) + ): + raise unimplemented(f"non-const NNModule method {name}") + + def get_kwargs(*names): + assert_all_args_kwargs_const() + fn = getattr(module, name) + bound_args = inspect.signature(fn).bind( + *([x.as_python_constant() for x in args]), + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + bound_args.apply_defaults() + bound_args = bound_args.arguments + return {k: bound_args[k] for k in names} + + def wrap_values(items): + result = [] + for name, submod in items: + result.append( + tx.output.register_attr_or_module( + submod, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ) + ) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + + def named_embed(name, obj): + return TupleVariable( + [ + ConstantVariable.create(name), + tx.output.register_attr_or_module( + obj, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ), + ] + ) + + def gen_source(source, name): + name_split = name.split(".") + if name_split[0] == "": + return source + while len(name_split) > 0: + x = name_split.pop(0) + source = AttrSource(source, x) + return source + + if name == "named_children": + assert not (args or kwargs) + result = [] + for name, submod in module.named_children(): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "named_parameters": + result = [] + for name, param in module.named_parameters( + **get_kwargs("prefix", "recurse") + ): + result.append(named_embed(name, param)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "named_buffers": + result = [] + for name, buffer in 
module.named_buffers( + **get_kwargs("prefix", "recurse", "remove_duplicate") + ): + result.append(named_embed(name, buffer)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "named_modules": + result = [] + for name, submod in module.named_modules( + **get_kwargs("memo", "prefix", "remove_duplicate") + ): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "children": + assert not (args or kwargs) + return wrap_values(module.named_children()) + elif name == "modules": + return wrap_values(module.named_modules()) + elif name == "parameters": + return wrap_values(module.named_parameters(**get_kwargs("recurse"))) + elif name == "buffers": + return wrap_values(module.named_buffers(**get_kwargs("recurse"))) + elif name == "keys": + assert not (args or kwargs) + result = [] + for name in module.keys(): + result.append(ConstantVariable.create(name)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "values": + assert not (args or kwargs) + return wrap_values(module.items()) + elif name == "items": + assert not (args or kwargs) + result = [] + for name, submod in module.items(): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutable_local=MutableLocal()) + elif name == "__len__": + assert not (args or kwargs) + return ConstantVariable.create(len(module)) + elif ( + name == "__contains__" + and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict)) + and args + and args[0].is_python_constant() + ): + return ConstantVariable.create( + args[0].as_python_constant() in module._modules + ) + elif name == "__getitem__": + assert not kwargs and len(args) == 1 + builtin_supported = ( + torch.nn.ModuleDict.__getitem__, + torch.nn.ModuleList.__getitem__, + torch.nn.ParameterDict.__getitem__, + torch.nn.ParameterList.__getitem__, + torch.nn.Sequential.__getitem__, + ) + + if type(module).__getitem__ not in builtin_supported: + assert isinstance(args[0], variables.ConstantVariable), typestr(args[0]) + key = args[0].as_python_constant() + assert isinstance(key, (str, int)) + fn = getattr(module, name).__func__ + + assert isinstance(fn, types.FunctionType) + + src = AttrSource(AttrSource(self.source, name), "__func__") + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=src), + [self] + list(args), + kwargs, + ) + + assert self.source + + if isinstance(args[0], SliceVariable): + # Build a TupleVariable of NNModules + result = [] + submods = [] + + # Turn the slice into the list of integers + keys = list(range(len(module)))[args[0].as_python_constant()] + for idx, submod in enumerate(module[args[0].as_python_constant()]): + key = keys[idx] + src = NNModuleSource(GetItemSource(self.source, key)) + result.append( + tx.output.register_attr_or_module( + submod, + key, + source=src, + ) + ) + submods.append(submod) + + new_module = torch.nn.Sequential(*submods) + new_module_variable = tx.output.register_attr_or_module( + new_module, + f"{self}.__getitem__(slice)", + source=NNModuleSource( + GetItemSource(self.source, args[0].as_python_constant()) + ), + ) + return new_module_variable + + from .tensor import SymNodeVariable + + if isinstance(args[0], SymNodeVariable): + key = args[0].evaluate_expr(tx.output) + else: + key = args[0].as_python_constant() + + submod = module[key] + return tx.output.register_attr_or_module( + submod, + self.module_key, + key, + 
source=NNModuleSource(GetItemSource(self.source, key)), + ) + elif ( + name == "_get_abs_string_index" + or ( + isinstance(module, torch.nn.modules.conv._ConvNd) + and name == "_conv_forward" + ) + or ( + isinstance(module, torch.nn.modules.conv._ConvTransposeNd) + and name == "_output_padding" + ) + ): + # Inline the function + fn = getattr(module, name).__func__ + fn_source = AttrSource(AttrSource(self.source, name), "__func__") + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + [self] + args, + kwargs, + ) + # A loose heuristic, but seems to be generally good before we drop into the + # manual handling of inputs + elif ( + name in module.__class__.__dict__ + and callable(module.__class__.__dict__[name]) + and all( + isinstance(x, variables.TensorVariable) + for x in itertools.chain(args, kwargs.values()) + ) + ): + return generic_call_method_helper(name) + else: + return super().call_method(tx, name, args, kwargs) + + +class UnspecializedNNModuleVariable(UserDefinedObjectVariable): + _nonvar_fields = {"value_type", *UserDefinedObjectVariable._nonvar_fields} + + """ + The above class will specialize on the id() of a module and place + parameters on the torch.fx.GraphModule. Giving one graph per + module instance. This version treats nn.Modules() like other user + defined objects and will pass parameters into the FX graph as inputs. + Giving one graph per module class. + """ + + def __init__(self, value, **kwargs): + if type(value) is torch.jit._script.RecursiveScriptModule: + raise Unsupported( + "ScriptModules aren't supported in UnspecializedNNModuleVariable" + " becuase their .forward function isn't a static member of their type" + ) + if "value_type" in kwargs: + lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) + if type(value) is lazy_value_to_become: + # We may have cloned a variabletracker for a LazyModule earlier (e.g. tracking side-effects) + # and then later we called and mutated the LazyModule into a MaterializedModule. + # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only + # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation. 
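Editor's note: the lazy-module bookkeeping discussed above mirrors what happens in eager mode, where a lazy module initializes its parameters on the first call and then swaps its class to its cls_to_become. A small sketch using the public torch.nn.LazyLinear API:

import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=4)
print(type(lazy).__name__)        # LazyLinear (parameters not yet materialized)

out = lazy(torch.randn(2, 3))     # first call infers in_features=3 and initializes parameters
print(type(lazy).__name__)        # Linear (the module became its cls_to_become)
print(out.shape)                  # torch.Size([2, 4])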
+ kwargs["value_type"] = type(value) + + super().__init__(value=value, **kwargs) + + @staticmethod + @functools.lru_cache(None) + def _nn_module_method_ids(): + return { + id(x.__code__) + for x in torch.nn.Module.__dict__.values() + if hasattr(x, "__code__") + } + + def unpack_var_sequence(self, tx): + from .builder import VariableBuilder + + try: + fn = inspect.getattr_static(self.value_type, "__iter__") + except AttributeError as e: + raise NotImplementedError from e + + if fn in ( + torch.nn.ModuleList.__iter__, + torch.nn.ParameterList.__iter__, + torch.nn.Sequential.__iter__, + ): + assert self.source + return [ + VariableBuilder(tx, source=GetItemSource(self.source, idx))(item) + for idx, item in enumerate(self.value) + ] + + return super().unpack_var_sequence(tx) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + mod = self.value + # see comment on lazy module handling in NNModuleVariable.call_function for context + if is_lazy_module(mod): + if mod.cls_to_become is not None: + self.value_type = mod.cls_to_become + initialize_lazy_module(tx, mod, args, kwargs) + name = "_call_impl" + fn = getattr(self.value_type, name) + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + else: + source = None + + ctx = ( + record_nn_module_stack(str(id(mod)), self.source, tx, mod) + if self.source + else nullcontext() + ) + with ctx: + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import VariableBuilder + + if name in ["_call_impl", "_wrapped_call_impl"]: + fn = getattr(self.value_type, name) + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + else: + source = None + + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + if name not in getattr(self.value, "__dict__", {}): + try: + method = inspect.getattr_static(type(self.value), name) + except AttributeError: + method = None + + if method is torch.nn.Module.parameters: + assert not args or kwargs + if tx.output.side_effects.has_pending_mutation(self): + unimplemented("Module.parameters() with pending mutation") + install_guard( + self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES) + ) + items = [] + for name, value in self.value.named_parameters(): + items.append( + VariableBuilder(tx, AttrSource(self.source, name))(value) + ) + return variables.ListIteratorVariable( + items, mutable_local=MutableLocal() + ) + elif isinstance(method, staticmethod): + source = AttrSource( + AttrSource(AttrSource(self.source, "__class__"), name), "__func__" + ) + return tx.inline_user_function_return( + variables.UserFunctionVariable(method.__func__, source=source), + args, + kwargs, + ) + + if id(method.__code__) in self._nn_module_method_ids(): + unimplemented(f"UnspecializedNNModuleVariable missing {name}") + + return super().call_method(tx, name, args, kwargs) + + +class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable): + """ + Tracing behavior: trace into submodules and treat them as Unspecialized, do not + register parameters to the top-level, treat them as function inputs. 
+ + Guards behavior: if 'skip_fsdp_guards', many guards that would be installed + by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis + that a user wrapping their model in FSDP(model) is already opting into a + requirement to not modify internal model state, which would already break FSDP without + compilation. + """ + + def __init__(self, value, **kwargs): + source = kwargs.get("source", None) + assert ( + source is not None + ), "FSDPManagedNNModule depends on having an accurate source to control guarding." + + super().__init__(value=value, **kwargs) + self.source = source + + @staticmethod + def _wrap_source(source): + if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)): + if torch._dynamo.config.skip_fsdp_guards: + return FSDPNNModuleSource(source) + else: + # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes + return NotNNModuleSource(source) + else: + return source + + def __setattr__(self, name: str, value: Any) -> None: + if name == "source": + value = FSDPManagedNNModuleVariable._wrap_source(value) + + return super().__setattr__(name, value) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/optimizer.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdfda8c3c0d532ab0073a83e41070a133d4fe31 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/optimizer.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +import weakref +from typing import Dict, List + +import torch + +from ..decorators import mark_static_address + +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, ConstDictKeySource, GetItemSource, GlobalWeakRefSource +from ..utils import GLOBAL_KEY_PREFIX + +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import ListVariable +from .misc import GetAttrVariable +from .user_defined import UserDefinedObjectVariable + + +class ArgMappingException(Exception): + pass + + +class GuardInstallException(Exception): + pass + + +class OptimizerVariable(UserDefinedObjectVariable): + def __init__( + self, + value, + grad_to_source=None, + static_tensor_names=None, + tensor_to_source=None, + **kwargs, + ): + super().__init__(value, **kwargs) + + for group in self.value.param_groups: + if "capturable" in group: + group["capturable"] = True + + for p in group["params"]: + mark_static_address(p, guard=False) + + self.grad_to_source = grad_to_source or {} + self.tensor_to_source = tensor_to_source or {} + self.static_tensor_names = static_tensor_names or set() + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + """This is an optimization to avoid tracing the very slow initialization of the optimizer""" + if name == "_init_group": + try: + py_args, py_kwargs = self.get_python_args(*args, **kwargs) + ret_val = self.value._init_group(*py_args, **py_kwargs) + self.map_sources_and_install_guards(tx) + self.update_list_args(tx, args, kwargs, py_args, py_kwargs) + # stash a weak_ptr to optimizer to invalidate code + # if the optimizer object dies + mangled_name = f"__optimizer_{id(self.value)}" + tx.store_global_weakref_by_id(mangled_name, self.value) + self.create_finalizer(tx) + + # This is currently safe only because the only actual `ret_val`s returned + # by the `_init_group` of existing optimizers are properties that are invariant 
+ # to the input tensors (e.g. dtype, layout). Changing these would trigger a + # recompilation and hence never result in the wrong specialization of `ret_val`. + return ConstantVariable.create(ret_val) + except (ArgMappingException, GuardInstallException) as _: + # trace normally if we can't map args or install guards correctly + pass + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "_init_group": + return GetAttrVariable(self, name) + + return super().var_getattr(tx, name) + + def get_python_args(self, *args, **kwargs): + """Get python values equivalent to the variable tracker args""" + + def map_arg(arg): + if isinstance(arg, ConstantVariable): + return arg.as_python_constant() + elif isinstance(arg, ListVariable) and not arg.items: + return [] + elif ( + isinstance(arg, ConstDictVariable) + and isinstance(arg.source, GetItemSource) + and isinstance(arg.source.base, AttrSource) + and arg.source.base.member == "param_groups" + ): + return self.value.param_groups[arg.source.index] + + raise ArgMappingException() + + new_args = [map_arg(arg) for arg in args] + new_kwargs = {k: map_arg(v) for k, v in kwargs.items()} + + return new_args, new_kwargs + + def map_sources_and_install_guards(self, tx): + self.grad_to_source = {} + self.tensor_to_source = {} + + from .builder import VariableBuilder + + param_groups_vt = VariableBuilder(tx, AttrSource(self.source, "param_groups"))( + self.value.param_groups + ).recursive_realize() + + for g_ind, (group, group_vt) in enumerate( + zip(self.value.param_groups, param_groups_vt.items) + ): + group_source = group_vt.source + params_vt = group_vt.getitem_const(ConstantVariable.create("params")) + for p_ind, (p, p_vt) in enumerate( + zip(group["params"], params_vt.unpack_var_sequence(tx)) + ): + param_source = p_vt.source + self.tensor_to_source[p] = param_source + grad_source = AttrSource( + param_source, + "grad", + ) + if p.grad is not None: + self.grad_to_source[p.grad] = grad_source + else: + install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + + # state guards take a long time to generate + # so we manually generate them here + state_source = AttrSource(self.source, "state") + install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS)) + for idx, (p, value) in enumerate(self.value.state.items()): + tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, p) + p_state_source = GetItemSource( + state_source, ConstDictKeySource(state_source, idx) + ) + install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS)) + for k, v in value.items(): + if ( + isinstance(v, torch.Tensor) + and v not in self.grad_to_source + and v not in self.tensor_to_source + ): + self.tensor_to_source[v] = GetItemSource(p_state_source, k) + elif v is None or isinstance(v, (bool, int, float, str)): + install_guard( + GetItemSource(p_state_source, k).make_guard( + GuardBuilder.CONSTANT_MATCH + ) + ) + else: + raise GuardInstallException() + + def wrap_tensor(self, tx, tensor_value): + """Wrap state tensor in a TensorVariable""" + from .builder import VariableBuilder + + # If we have a source for a tensor already use it, + # if we have not seen a tensor before, stash and use a + # global weak ref source, since it must be an optimizer tensor + # that we have missed + + if tensor_value in self.tensor_to_source: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=False) + builder = VariableBuilder(tx, self.tensor_to_source[tensor_value]) + 
self.static_tensor_names.add(tx.output.module_key_name(builder.name)) + elif tensor_value in self.grad_to_source: + builder = VariableBuilder(tx, self.grad_to_source[tensor_value]) + else: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=False) + + global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) + builder = VariableBuilder(tx, GlobalWeakRefSource(global_name)) + self.static_tensor_names.add(tx.output.module_key_name(builder.name)) + + result = builder(tensor_value) + return result + + def update_list_args(self, tx, args, kwargs, py_args, py_kwargs): + """Update the args and kwargs to the traced optimizer call""" + for arg, py_arg in zip(args, py_args): + if isinstance(arg, ListVariable): + assert isinstance( + py_arg, list + ), "py_arg should be a list in optimizer variable" + for i, val in enumerate(py_arg): + tx.output.side_effects.mutation(arg) + if isinstance(val, torch.Tensor): + arg.items.append(self.wrap_tensor(tx, val)) + else: + from .builder import SourcelessBuilder, VariableBuilder + + if arg.source: + arg.items.append( + VariableBuilder(tx, GetItemSource(arg.source, i))(val) + ) + else: + arg.items.append(SourcelessBuilder()(tx, val)) + + def create_finalizer(self, tx): + names_to_delete = self.static_tensor_names + value = self.value + tc = tx.output.tracing_context + + def init_finalizer(gm): + def clear_static_tensor_refs(): + for name in names_to_delete: + gm._buffers.pop(name, None) + gm._parameters.pop(name, None) + if tc.params_flat: + tc.params_flat.clear() + + weakref.finalize(value, clear_static_tensor_refs) + + tx.output.add_graph_finalizer(init_finalizer) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/sdpa.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..9dced6918b945e51d0d44a6ece0308ccb6d0e7b6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/sdpa.py @@ -0,0 +1,84 @@ +# mypy: ignore-errors + +from inspect import getattr_static + +from ..bytecode_transformation import create_call_function +from ..exc import Unsupported +from .base import VariableTracker + + +class SDPAParamsVariable(VariableTracker): + """Represents the c++ params struct for scaled dot product attention. 
+ This is a read-only container.""" + + @staticmethod + def create(tx, value, source): + from torch.backends.cuda import SDPAParams + from ..source import AttrSource + from .builder import VariableBuilder + from .torch import TorchInGraphFunctionVariable + + query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query) + key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key) + value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value) + attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))( + value.attn_mask + ) + dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout) + is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))( + value.is_causal + ) + param_vars = [ + query_var, + key_var, + value_var, + attn_mask_var, + dropout_var, + is_causal_var, + ] + return TorchInGraphFunctionVariable(SDPAParams).call_function( + tx, param_vars, {} + ) + + def __init__(self, proxy, param_vars, **kwargs): + self.proxy = proxy + self.param_vars = param_vars + super().__init__(**kwargs) + + def reconstruct(self, codegen): + assert self.source is None + assert self.param_vars is not None + codegen.load_import_from("torch._C", "_SDPAParams") + codegen.foreach(self.param_vars) + codegen.extend_output(create_call_function(len(self.param_vars), True)) + + def as_proxy(self): + return self.proxy + + def var_getattr(self, tx, name: str) -> VariableTracker: + import torch._C + from ..source import AttrSource + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + getattr_static(torch._C._SDPAParams, name) + except AttributeError: + # Using raise from is too verbose here + raise Unsupported( # noqa: TRY200 + f"Unsupported torch._C._SDPAParams attribute {name}" + ) + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @staticmethod + def is_sdpa_params(value): + from torch.backends.cuda import SDPAParams + + return value is SDPAParams diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/tensor.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..604cfd792e05aaf6f16b8984974b1af4f83e519a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/tensor.py @@ -0,0 +1,1189 @@ +# mypy: ignore-errors + +import functools + +import inspect +import operator +import types +from typing import Dict, List + +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from ..bytecode_transformation import create_call_method +from ..external_utils import call_hook_from_backward_state + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +import sympy + +import torch._numpy as tnp + +import torch.fx +import torch.random +from torch._dynamo import compiled_autograd +from torch._subclasses.meta_utils import is_sparse_any + +from torch.fx.experimental.symbolic_shapes import ( + guard_scalar, + GuardOnDataDependentSymNode, + has_free_symbols, + is_symbolic, + SymTypes, +) + +from .. 
import config, variables +from .._trace_wrapped_higher_order_op import trace_wrapped + +from ..exc import unimplemented, UserError, UserErrorType +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import ( + fqn, + get_custom_getattr, + get_fake_value, + get_real_value, + guard_if_dyn, + object_has_getattribute, + product, + proxy_args_kwargs, + tensortype_to_dtype, +) +from .base import VariableTracker +from .constant import ConstantVariable +from .lists import SizeVariable + +supported_tensor_comparison_ops = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, +} +supported_const_comparison_ops = { + "is": operator.is_, + "is not": operator.is_not, + "==": operator.eq, + "!=": operator.ne, +} + + +class TensorVariable(VariableTracker): + """A torch.Tensor input or an intermediate value in the FX graph""" + + _nonvar_fields = { + "proxy", + "dtype", + "device", + "layout", + "ndim", + "size", + "stride", + "requires_grad", + "is_quantized", + "is_contiguous", + "is_sparse", + "class_type", + "specialized_value", + *VariableTracker._nonvar_fields, + } + + def get_real_value(self): + """ + Get the actual value represented by this variable if computation is run + using the user-provided inputs. + NOTE: this runs actual tensor computation and may be + slow and memory-intensive. + """ + return get_real_value(self.proxy.node, self.proxy.tracer) + + def __init__( + self, + proxy: torch.fx.Proxy, + *, + dtype, + device, + layout, + ndim, + requires_grad, + is_quantized, + is_sparse, + class_type, + size=None, + stride=None, + is_contiguous=None, + **kwargs, + ): + super().__init__(**kwargs) + self.proxy = proxy + self.dtype = dtype + self.device = device + self.layout = layout + self.ndim = ndim + self.size = size + self.stride = stride + self.requires_grad = requires_grad + self.is_quantized = is_quantized + self.is_contiguous = is_contiguous + self.is_sparse = is_sparse + self.class_type = class_type + + def as_proxy(self): + return self.proxy + + def python_type(self): + return self.class_type + + @staticmethod + def specialize(value: torch.Tensor): + props = { + "dtype": value.dtype, + "device": value.device, + "layout": value.layout, + "ndim": int(value.ndim), + "requires_grad": value.requires_grad, + "is_quantized": value.is_quantized, + "is_sparse": value.is_sparse, + "class_type": type(value), + } + if is_sparse_any(value) and not has_free_symbols(value): + props["size"] = tuple( + [int(s) if is_symbolic(s) else s for s in value.size()] + ) + elif not has_free_symbols(value): + # this is a fully static shape, and the keys on props here inform specialization. + # We have to cast to int here, because these might get accessed as ConstantVariable, which has + # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant + # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for + # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and + # I'd like to keep it around for now. 
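Editor's note: the static properties that specialize() records correspond to metadata that is directly readable from an eager tensor. A plain-tensor sketch of the same information (editor's example, not the tracer code itself):

import torch

t = torch.randn(2, 3, 8, 8).contiguous(memory_format=torch.channels_last)
props = {
    "dtype": t.dtype,                  # torch.float32
    "device": t.device,                # cpu
    "layout": t.layout,                # torch.strided
    "ndim": t.ndim,                    # 4
    "requires_grad": t.requires_grad,  # False
    "size": tuple(t.size()),           # (2, 3, 8, 8)
    "stride": tuple(t.stride()),       # (192, 1, 24, 3) for channels_last
    "is_contiguous": t.is_contiguous(memory_format=torch.channels_last),  # True
}
print(props)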
+ props["size"] = tuple( + # the non is_symbolic case applies to the jagged layout + # NestedTensor case as singleton ints are not symbolic + [int(s) if is_symbolic(s) else s for s in value.size()] + ) + props["stride"] = tuple(value.stride()) + if torch._C._functorch.is_batchedtensor(value): + # Batched tensors does not support contiguity patterns, so + # we refrain from computing the `is_contiguous` property + props["is_contiguous"] = None + else: + props["is_contiguous"] = tuple( + [ + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ] + ) + return props + + def dynamic_getattr(self, tx, name): + fake_val = self.proxy.node.meta["example_value"] + # For getattrs on tensors without sources, + # we can do better than the default (creating a GetAttrVariable) + # if: + # (1) the tensor is a traceable tensor subclass + # (2) We are getattr'ing an inner tensor from that subclass + if not self.source and is_traceable_wrapper_subclass(fake_val): + fake_val = self.proxy.node.meta["example_value"] + attrs, ctx = fake_val.__tensor_flatten__() + proxy = getattr(self.as_proxy(), name) + example_value = getattr(fake_val, name) + if name in attrs: + # attrs returned from tensor_flatten are always tensors + assert isinstance(example_value, torch.Tensor) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value) + # any other attributes on the subclass (that are not methods) + # are assumed to be constant metadata. + elif not callable(example_value): + from .builder import SourcelessBuilder + + return SourcelessBuilder()(tx, example_value) + + if not (self.source and self.source.subguards_allowed()): + raise NotImplementedError() + + # For local source, we associate the real value. We use this real value + # for implementing getattr fallthrough on the variable tracker base class. + + # Note - this scope construction is mirrored in guards + # A subsequent PR will introduce a util. + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + # We raise in case we get a typerror bug w/ SuperSource. + # SuperSource has bugs in it atm, and can produce code like + # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, + # L['mod'].model.model.encoder.embed_positions)", scope) + # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. + _input_associated_real_value = eval(self.source.name(), scope) + except Exception as exc: + raise NotImplementedError() from exc + + if _input_associated_real_value is None: + raise NotImplementedError() + + if object_has_getattribute(_input_associated_real_value): + raise NotImplementedError() + + if get_custom_getattr(_input_associated_real_value): + raise NotImplementedError() + + real_value = getattr(_input_associated_real_value, name) + if callable(real_value): + # Callables have more nuanced handling, and we should let the existing system delegate here. + # Raising was past behavior and so should always be sound to fall back. 
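Editor's note: the two bail-out checks above refuse to proceed when the underlying object customizes attribute lookup. Roughly, and only as an illustrative sketch (the real helpers live in torch._dynamo.utils and differ in detail, e.g. they special-case nn.Module), the checks amount to:

class Plain:
    attr = 1

class Dynamic:
    def __getattr__(self, name):
        return f"dynamic:{name}"

def has_custom_getattribute(obj) -> bool:
    # True only when the class overrides __getattribute__ itself
    return type(obj).__getattribute__ is not object.__getattribute__

def custom_getattr(obj):
    # A user-defined __getattr__ if one exists, else None
    return getattr(type(obj), "__getattr__", None)

print(has_custom_getattribute(Plain()), custom_getattr(Plain()))      # False None
print(has_custom_getattribute(Dynamic()), custom_getattr(Dynamic()))  # False <function Dynamic.__getattr__ ...>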
+ # Note - at a certain point we may want to handle + raise NotImplementedError() + + from ..guards import GuardBuilder + from .builder import VariableBuilder + + attr_source = AttrSource(self.source, name) + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) + return VariableBuilder(tx, attr_source)(real_value) + + def method_attr_ndim(self, tx): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + else: + return self.call_method(tx, "dim", [], {}) + + def method_attr_dtype(self, tx): + if self.dtype is not None: + return ConstantVariable.create(self.dtype) + + def method_attr_device(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device) + + def method_attr_layout(self, tx): + if self.layout is not None: + return ConstantVariable.create(self.layout) + + def method_attr_is_cuda(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device.type == "cuda") + + def method_attr_shape(self, tx): + if self.size is not None: + sizes = [variables.ConstantVariable.create(x) for x in self.size] + return SizeVariable(sizes) + else: + return self.call_method(tx, "size", [], {}) + + def method_attr_requires_grad(self, tx): + if self.requires_grad is not None: + return ConstantVariable.create(self.requires_grad) + + def method_attr_is_quantized(self, tx): + if self.is_quantized is not None: + return ConstantVariable.create(self.is_quantized) + + def method_attr_is_sparse(self, tx): + if self.is_sparse is not None: + return ConstantVariable.create(self.is_sparse) + + def method_attr_data(self, tx): + return self.call_method(tx, "detach", [], {}) + + def method_attr__version(self, tx): + from ..tensor_version_op import _tensor_version + + return variables.TorchInGraphFunctionVariable(_tensor_version).call_function( + tx, [self], {} + ) + + def var_getattr(self, tx, name): + from . import UserDefinedClassVariable + + if tx.strict_checks_enabled: + if name in self._strict_mode_banned_ops(): + unimplemented(f"Illegal getattr invocation {name} in strict mode") + + if name == "__class__": + return UserDefinedClassVariable(self.python_type()) + + handler = getattr(self, f"method_attr_{name}", None) + result = handler(tx) if handler is not None else None + + # Add a guard for type matching, these guards are checked before tensor guards + # In some cases, a . guard can be evaluated first, and break if + # is later changed to another type + if ( + result is not None + and self.source + and self.source.subguards_allowed() + and not ( + name not in ("grad", "requires_grad") and result.is_python_constant() + ) + ): + install_guard(self.make_guard(GuardBuilder.TYPE_MATCH)) + result.source = AttrSource(self.source, name) + + # It's hard to get inplace view (metadata mutation) on graph input work properly across + # dynamo/aot/inductor, just fall back. + if self.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags + ): + # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. + return variables.misc.DelayGraphBreakVariable( + source=AttrSource(self.source, name) + ) + + # For attributes (not methods) that were not caught in the special handling above, + # (e.g. tensor.real), we handle these generically, assuming that the output type is + # a tensor. 
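Editor's note: the generic handling that follows keys off CPython descriptor types: data-like tensor attributes are exposed as getset descriptors, while methods appear as method descriptors. A quick check in plain Python (editor's sketch):

import types
import torch

print(type(torch.Tensor.H))                                  # getset_descriptor (data-like attribute)
print(type(torch.Tensor.H) is types.GetSetDescriptorType)    # True
print(type(torch.Tensor.size))                               # method_descriptor (a method, not an attribute)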
+ if result is None and name != "grad": + + def try_generic_attr_handling(): + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + static_attr = inspect.getattr_static(torch.Tensor, name) + except AttributeError: + return None + + # Make sure this is an attribute, not a method. + # type(torch.Tensor.H) should be "getset_descriptor" + # This is a because of CPython implementation, see THPVariableType: + # these attributes are implemented under tp_getset, which appear + # as `getset_descriptor`s, (compared to, say, methods which appear + # as `method_descriptor`s) + if type(static_attr) != types.GetSetDescriptorType: + return None + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + result = try_generic_attr_handling() + + if result is None: + result = self.dynamic_getattr(tx, name) + + if result is None: + raise NotImplementedError() + return result + + def has_unpack_var_sequence(self, tx): + return self.ndim > 0 + + def unpack_var_sequence(self, tx, idxes=None): + from .builder import wrap_fx_proxy_cls + + if idxes is None: + if self.size: + length = self.size[0] + else: + dyn_length = self.call_method( + tx, "size", [ConstantVariable.create(0)], {} + ) + # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through + # symbolic_shapes, but that end up as int/sympy.Integer + assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) + if isinstance(dyn_length, SymNodeVariable): + length = dyn_length.evaluate_expr(tx.output) + else: + length = dyn_length.value + idxes = range(length) + return [ + wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i]) + for i in idxes + ] + + def _strict_mode_banned_ops(self): + return torch._dynamo.config._autograd_backward_strict_mode_banned_ops + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if tx.strict_checks_enabled: + if name in self._strict_mode_banned_ops(): + unimplemented(f"Illegal method invocation {name} in strict mode") + + """ + Dispatch to a method-specific handler defined below. If the + handler returns None (or doesn't exist) we put the method call + in the graph. 
+ """ + try: + handler_method = getattr(self, f"method_{name}") + except AttributeError: + pass + else: + try: + result = handler_method(*args, **kwargs) + if result: + return result + except TypeError as e: + unimplemented(f"unhandled args for {name}: {e}") + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def method_size(self, *args, **kwargs): + return self._method_size_stride("size", *args, **kwargs) + + def method_stride(self, *args, **kwargs): + return self._method_size_stride("stride", *args, **kwargs) + + def _method_size_stride(self, name, dim=None): + dim = guard_if_dyn(dim) + + def make_const_size_variable(x, **options): + return SizeVariable( + [ConstantVariable.create(y, **options) for y in x], **options + ) + + RetVariable = ( + make_const_size_variable if name == "size" else ConstantVariable.create + ) + + # Technically, this should not be necessary, but I'm including it + # for enhanced BC, in case example_value is sometimes not set + # (it really should always be set though!) + if (r := getattr(self, name)) is not None: + if dim is None: + return RetVariable(r) + else: + return ConstantVariable.create(r[dim]) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + if dim is None: + fake_r = getattr(fake, name)() + if not has_free_symbols(fake_r): + # int conversion for safety, in case a SymInt refined + # to constant + return RetVariable(tuple(int(r) for r in fake_r)) + else: + fake_r = getattr(fake, name)(dim) + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + def method_numel(self): + if self.size is not None: + return ConstantVariable.create(product(self.size)) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + fake_r = fake.numel() + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + method_nelement = method_numel + + def method_dim(self): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + + method_ndimension = method_dim + + def method_is_floating_point(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_floating_point) + + def method_is_contiguous(self, memory_format=None): + memory_format = ( + memory_format.as_python_constant() + if memory_format is not None + else torch.contiguous_format + ) + if self.is_contiguous is not None: + return ConstantVariable.create(memory_format in self.is_contiguous) + elif (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create( + fake.is_contiguous(memory_format=memory_format) + ) + + def method_type(self, dtype=None, non_blocking=False, **kwargs): + if ( + dtype is None + and self.dtype is not None + and isinstance(self.device, torch.device) + ): + tensortype = next( + k for k, v in tensortype_to_dtype.items() if self.dtype in v + ) + if self.device.type == "cuda": + return ConstantVariable.create(f"torch.cuda.{tensortype.__name__}") + else: + return ConstantVariable.create(f"torch.{tensortype.__name__}") + elif ( + dtype is not None + and fqn(type(dtype.as_python_constant())) == "torch.tensortype" + ): + # torch.FloatTensor, etc. are all of type "torch.tensortype". + # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type. 
+ # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args) + tensor_type = dtype.as_python_constant() + tensor_type_const = ConstantVariable.create(fqn(tensor_type)) + + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + if non_blocking: + kwargs = {"non_blocking": non_blocking, **kwargs} + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + "type", + *proxy_args_kwargs([self, tensor_type_const], kwargs), + ), + ) + + def method_as_subclass(self, cls): + if isinstance(cls, TensorSubclassVariable) and cls.source: + from ..symbolic_convert import InstructionTranslator + from .builder import VariableBuilder + from .torch_function import TensorWithTFOverrideVariable + + tx = InstructionTranslator.current_tx() + + # [Note: __torch_function__] coerce this tensor variable into a TensorWithTFOverrideVariable + # in eager, this is just a type change. This isn't sound if a __torch_function__ tensor subclass + # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call. + # It is up to the user whether this is correct behavior or not. + py_cls = cls.as_python_constant() + torch_fn = VariableBuilder( + tx, + AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"), + )(py_cls.__torch_function__.__func__) + + return TensorWithTFOverrideVariable.from_tensor_var( + tx, self, py_cls, torch_fn + ) + + def method_get_device(self): + if isinstance(self.device, torch.device): + index = self.device.index if self.device.type != "cpu" else -1 + return ConstantVariable.create(index) + + def method_element_size(self): + return ConstantVariable.create(self.dtype.itemsize) + + def method_numpy(self, *, force=False): + if not config.trace_numpy: + unimplemented("Tensor.numpy(). config.trace_numpy is False") + if not np: + unimplemented("Tensor.numpy(). NumPy is not available") + if self.layout != torch.strided: + raise TypeError( + f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first" + ) + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # We don't check that the tensor is on CPU when force is False, as this + # allows us to execute NumPy code on CUDA. Same for requires_grad=True + if force and force.as_python_constant(): + # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) 
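Editor's note: in eager terms, force=True is what lets .numpy() succeed on tensors that would otherwise refuse the conversion. A small usage sketch (assumes NumPy is installed):

import torch

t = torch.arange(3.0, requires_grad=True)
# t.numpy() would raise here because the tensor requires grad;
# force=True detaches first (and would also move a GPU tensor to CPU).
arr = t.numpy(force=True)
print(arr)   # [0. 1. 2.]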
+ t = self.call_method(tx, "detach", [], {}) + proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {}) + else: + # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable + proxy = tx.output.create_proxy( + "call_method", "view_as", *proxy_args_kwargs([self, self], {}) + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def method_tolist(self): + from ..symbolic_convert import InstructionTranslator + from .builder import SourcelessBuilder + + tx = InstructionTranslator.current_tx() + + def tolist(tensor, sub_proxy): + def wrap(i, sub_proxy): + return SymNodeVariable.create( + tx, + sub_proxy.item(), + sym_num=tx.output.shape_env.create_unbacked_symint(), + ) + + if tensor.dtype not in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + unimplemented("Input tensor for tolist must be an integer tensor") + + if tensor.dim() == 0: + return wrap(tensor, sub_proxy) + + if tensor.dim() == 1: + return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] + + return [ + tolist(sub_tensor, sub_proxy=sub_proxy[i]) + for i, sub_tensor in enumerate(tensor) + ] + + tensor = self.as_proxy().node.meta["example_value"] + out = tolist(tensor, self.as_proxy()) + return SourcelessBuilder()(tx, out) + + def method_backward(self, *args, **kwargs): + unimplemented("Tensor.backward") + + def method_data_ptr(self, *args, **kwargs): + unimplemented("Tensor.data_ptr") + + def method_item(self, *args, **kwargs): + if not config.capture_scalar_outputs: + unimplemented("Tensor.item") + + def method___len__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + + def method___setitem__(self, key, value): + def has_bool_key(v): + if isinstance(v, TensorVariable): + return v.dtype in (torch.bool, torch.int8) + elif isinstance(v, variables.TupleVariable): + return any(has_bool_key(item) for item in v.items) + else: + return False + + if ( + has_bool_key(key) + and isinstance(value, TensorVariable) + and value.requires_grad + and torch.is_grad_enabled() + ): + unimplemented( + "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123" + ) + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + tx.output.create_proxy( + "call_function", + operator.setitem, + *proxy_args_kwargs([self, key, value], {}), + ) + return ConstantVariable.create(None) + + def method_resize_(self, *args, **kwargs): + unimplemented("Tensor.resize_") + + def method_resize_as_(self, *args, **kwargs): + unimplemented("Tensor.resize_as_") + + def method_set_(self, *args, **kwargs): + if len(args) > 1: + # torch.Tensor.set_() has several overloads. + # aten::set_.source_Tensor(Tensor) gets special handling + # in AOTAutograd and functionalization, because it is the most common + # overload and is used by FSDP. + # graph-breaking on aten::set_source_Tensor_storage_offset for now, + # unless we find that we need to make it work. 
+ unimplemented("Tensor.set_.source_Tensor_storage_offset") + + def method_add_(self, other, *, alpha=None): + if alpha is not None: + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [other, alpha], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method_addcdiv_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + result = variables.TorchInGraphFunctionVariable(torch.div).call_function( + tx, [tensor1, tensor2], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, value], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method___contains__(self, arg): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # Rewrite __contains__ here so that downstream passes can trace through + # without dealing with unbacked symbool. Roughly the code we translate is: + # def __contains__(self, x): + # return (x == self).any().item() + result = variables.TorchInGraphFunctionVariable(torch.eq).call_function( + tx, [self, arg], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.any).call_function( + tx, [result], {} + ) + return result.call_method(tx, "item", [], {}) + + def method_redistribute(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def redistribute_fn_with_prim_types(x): + return x.redistribute(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + redistribute_fn_with_prim_types.__name__ = "prim_redistribute" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + redistribute_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_register_hook(self, *args, **kwargs): + return self._method_register_hook("register_hook", *args, **kwargs) + + def method_register_post_accumulate_grad_hook(self, *args, **kwargs): + return self._method_register_hook( + "register_post_accumulate_grad_hook", *args, **kwargs + ) + + def _method_register_hook(self, name: str, hook: VariableTracker): + # Note - do not arbitrarily add hooks here - make sure they match the same contract + # see [On tensor.register_hook] + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + if not self.source: + if not compiled_autograd.compiled_autograd_enabled: + # TODO(voz): + # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary + # python state. + # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run + # them in a compiled bwd without re-entering dynamo as compiled_autograd does. + # + # Discussion point 1 - Should we bypass this if nopython/fullgraph = True? + # No. Because this was going to be a graph break anyway - this check does not + # introduce new graph breaks where there were none. 
+ # + # Discussion point 2 - Should we defer this check to backwards? + # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user + # would have no recourse - their forward traces just fine, but will fail at backwards unless + # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) + # then they have nothing they can do except disable compile. + unimplemented( + "Compilation of intermediate hooks requires compiled autograd" + ) + + hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) + + def _register_hook_trampoline(tensor, bw_state): + register_hook = getattr(tensor, name) + register_hook( + functools.partial( + trace_wrapped, + fn=call_hook_from_backward_state, + bw_state=bw_state, + hook_name=hook_name, + ) + ) + # TODO(jansel): returning None here is wrong, it should be + # RemovableHandle, but we need some extra work to support + # this properly. + return None + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + _register_hook_trampoline, + (self.as_proxy(), bw_state_proxy), + {}, + ), + ) + + handle_variable = variables.RemovableHandleVariable( + mutable_local=variables.base.MutableLocal(), + ) + tx.output.side_effects.register_hook(self, hook, handle_variable, name) + return handle_variable + + def method_requires_grad_(self, requires_grad=True): + if requires_grad is not True: + requires_grad = requires_grad.as_python_constant() + + if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: + unimplemented("Tensor.requires_grad_") + else: + return self + + def method_new(self, *args, **kwargs): + # Convert x.new(torch.Size) into x.new_empty(torch.Size), + # as Tensor.new acts differently with a Size input versus a tuple input. + if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( + len(args) >= 1 + and all( + isinstance(a, ConstantVariable) and a.python_type() == int for a in args + ) + ): + from ..symbolic_convert import InstructionTranslator + + return self.call_method( + InstructionTranslator.current_tx(), "new_empty", args, kwargs + ) + + def method_untyped_storage(self): + return UntypedStorageVariable( + self, self.as_proxy().node.meta["example_value"].untyped_storage() + ) + + def rename(self, tx, name): + self.proxy.node._rename(name) + return super().rename(tx, name) + + +class SymNodeVariable(VariableTracker): + """ + Represents a symbolic size, e.g., as returned by tensor.size(0) + """ + + @classmethod + def create(cls, tx, proxy, sym_num, **options): + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == sym_num + if sym_num is None: + sym_num = get_fake_value(proxy.node, tx) + proxy.node.meta["example_value"] = sym_num + + if isinstance(sym_num, (sympy.Integer, int, bool)): + sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num + return ConstantVariable.create(sym_num) + + return SymNodeVariable(proxy, sym_num, **options) + + def __init__(self, proxy, sym_num, **kwargs): + super().__init__(**kwargs) + self.proxy = proxy + # TODO: Should we allow non SymTypes here? 
Today it is allowed + self.sym_num = sym_num + + def python_type(self): + if isinstance(self.sym_num, SymTypes): + return self.sym_num.node.pytype + else: + return type(self.sym_num) + + def as_proxy(self): + return self.proxy + + def evaluate_expr(self, output_graph=None): + try: + return guard_scalar(self.sym_num) + except GuardOnDataDependentSymNode as e: + raise UserError( # noqa: TRY200 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._constrain_as_*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + +class NumpyNdarrayVariable(TensorVariable): + """ + Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray. + Use this for Tensor.numpy() call. + """ + + @staticmethod + def create(tx, proxy, **options): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=tx, + proxy=proxy, + **options, + ) + + def var_getattr(self, tx, name): + # NB: This INTENTIONALLY does not call super(), because there is + # no intrinsic reason ndarray properties are related to Tensor + # properties. The inheritance here is for implementation sharing. + + from ..utils import numpy_attr_wrapper + from .builder import wrap_fx_proxy + + result = None + + example_value = self.as_proxy().node.meta["example_value"] + example_ndarray = tnp.ndarray(example_value) + + def insert_into_graph(): + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {} + ), + ) + + if name in ["T", "real", "imag"]: + proxy = tx.output.create_proxy( + "call_function", + numpy_attr_wrapper, + (self.as_proxy(), name), + {}, + ) + result = NumpyNdarrayVariable.create(tx, proxy) + + # These are awkward to implement. The standard playbook for torch._numpy + # interop is to trace a call into the torch._numpy wrapper which works for + # Tensor operations. However, we don't want to do this for calls + # that don't return Tensors, because in those cases we may not want + # to trace the attribute access into the graph at all (it is sort + # of harmless to do so, because AOTAutograd will eliminate them, + # but it's best not to trace them in to begin with.) But in any + # case, tracing these into the graph is like trying to fit a square + # peg into a round hole; best not to do it. So instead we + # painstakingly implement these by hand + # + # NB: only ALWAYS specialized attributes can go here; notably, + # size/shape not allowed! 
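+        # (ndim and itemsize depend only on the array's rank and dtype, so baking them
+        # in as Python constants is safe; shape/size may carry symbolic dims and get
+        # the has_free_symbols treatment below.)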
+ elif name in ("ndim", "itemsize"): + return ConstantVariable.create(getattr(example_ndarray, name)) + elif name in ("shape", "stride"): + if not has_free_symbols(r := getattr(example_ndarray, name)): + return ConstantVariable.create(tuple(int(r) for r in r)) + return insert_into_graph() + elif name == "size": + if not has_free_symbols(r := example_ndarray.size): + return ConstantVariable.create(int(r)) + return insert_into_graph() + elif name in ["base", "flags", "dtype"]: + unimplemented(f"TODO: add support for ndarray.{name}") + elif name in ["__version__"]: + unimplemented("delegate np.__version__ to NumPy") + if result is None: + raise NotImplementedError() + return result + + @staticmethod + def patch_args(name, args, kwargs): + if name == "clip": + kwargs_rename = {"a_min": "min", "a_max": "max"} + kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()} + return args, kwargs + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..utils import numpy_method_wrapper + + args, kwargs = self.patch_args(name, args, kwargs) + + if name in ["__len__", "size", "tolist"]: + # delegate back to TensorVariable + return super().call_method(tx, name, args, kwargs) + if name == "tobytes": + unimplemented("tobytes is not modelled in torch._numpy") + proxy = tx.output.create_proxy( + "call_function", + numpy_method_wrapper(name), + *proxy_args_kwargs([self] + list(args), kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def python_type(self): + return np.ndarray + + +class UnspecializedPythonVariable(TensorVariable): + """ + This is a 1-element tensor represents unspecialized python float/int. + """ + + def __init__( + self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs + ): + super().__init__(proxy, **kwargs) + self.raw_value = raw_value + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True): + # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance. + return UnspecializedPythonVariable( + **dict(tensor_variable.__dict__), + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + + +class FakeItemVariable(TensorVariable): + """An unspecialized python variable which prevents access to the underlying raw value. 
+ This is needed if item is called on a FakeTensor.""" + + def __init__(self, proxy: torch.fx.Proxy, **kwargs): + need_unwrap = kwargs.pop("need_unwrap", False) + super().__init__(proxy, **kwargs) + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable): + return FakeItemVariable(**dict(tensor_variable.__dict__)) + + +class TensorSubclassVariable(VariableTracker): + def __init__(self, value, *args, **kwargs): + self.value = value + super().__init__(*args, **kwargs) + + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + if len(args) == 1 and isinstance(args[0], TensorVariable): + from .builder import VariableBuilder + from .torch_function import TensorWithTFOverrideVariable + + torch_fn = VariableBuilder( + tx, AttrSource(self.source, "__torch_function__") + )(self.value.__torch_function__) + + return TensorWithTFOverrideVariable.from_tensor_var( + tx, args[0], self.value, torch_fn + ) + + return super().call_function(tx, args, kwargs) + + def as_python_constant(self): + return self.value + + def python_type(self): + return type(self.value) + + +class UntypedStorageVariable(VariableTracker): + _nonvar_fields = { + "example_value", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + from_tensor: TensorVariable, + example_value: torch.UntypedStorage, + **kwargs, + ): + super().__init__(**kwargs), + self.from_tensor = from_tensor + # Example_value will always have device="meta" + self.example_value = example_value + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + assert not args + assert not kwargs + result = self.example_value.size() + if not has_free_symbols(result): + # avoid creating a node in the graph + return ConstantVariable.create(int(result)) + else: + from ..external_utils import untyped_storage_size + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + untyped_storage_size, + (self.from_tensor.as_proxy(),), + {}, + ), + ) + if name == "resize_" and len(args) == 1: + assert not kwargs + tx.output.create_proxy( + "call_function", + torch.ops.inductor.resize_storage_bytes_, + (self.from_tensor.as_proxy(), args[0].as_proxy()), + {}, + ) + return self + + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen): + codegen(self.from_tensor) + codegen.append_output(codegen.create_load_method("untyped_storage")) + codegen.extend_output(create_call_method(0)) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/torch.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac021468c7d37479eeeb8e95dfb58fb628ceb78 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/torch.py @@ -0,0 +1,823 @@ +# mypy: ignore-errors + +import inspect +import logging + +import math +import re +from typing import Dict, List + +import torch._C +import torch._refs +import torch.fx +import torch.nn +import torch.onnx.operators +from torch._logging import warning_once + +from torch._streambase import _StreamBase +from ..._guards import TracingContext +from .. 
import config, polyfill, variables +from ..codegen import PyCodegen +from ..device_interface import get_registered_device_interfaces +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import SyntheticLocalSource +from ..utils import ( + check_constant_args, + check_unspec_python_args, + guard_if_dyn, + has_torch_function, + hashable, + product, + proxy_args_kwargs, + unwrap_if_wrapper, +) +from .base import VariableTracker +from .ctx_manager import ( + AutocastModeVariable, + NullContextVariable, + TorchFunctionDisableVariable, +) +from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable +from .lists import ListVariable, TupleVariable +from .torch_function import can_dispatch_torch_function, dispatch_torch_function + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +log = logging.getLogger(__name__) + +supported_ctx_manager_classes = { + torch.profiler.profiler.profile, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + torch._C.DisableTorchFunctionSubclass, + torch._functorch.vmap.vmap_increment_nesting, + torch._functorch.eager_transforms.grad_increment_nesting, + torch._functorch.eager_transforms.enable_inplace_requires_grad, + torch.amp.autocast_mode.autocast, + torch.autograd.grad_mode.enable_grad, + torch.autograd.grad_mode.inference_mode, + torch.autograd.grad_mode.no_grad, + torch.autograd.grad_mode.set_grad_enabled, + torch.autograd.graph.disable_saved_tensors_hooks, + torch.cpu.amp.autocast_mode.autocast, + torch.cuda.amp.autocast_mode.autocast, +} + + +REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [ + torch.onnx.operators.shape_as_tensor, + torch._shape_as_tensor, +] + +constant_fold_functions = [ + torch._assert, + torch._utils._get_device_index, + torch._C._get_cublas_allow_tf32, + torch.cuda.get_device_properties, + torch.cuda.is_available, + torch.distributed.is_available, + torch.get_autocast_gpu_dtype, + torch.get_default_dtype, + torch.is_autocast_cache_enabled, + torch.is_autocast_cpu_enabled, + torch.is_autocast_enabled, + torch.is_complex, + torch.is_floating_point, + torch.nn.functional._Reduction.get_enum, + torch.promote_types, + torch._C._get_privateuse1_backend_name, +] + + +if torch.distributed.is_available(): + constant_fold_functions.extend( + [ + torch.distributed.is_initialized, + torch.distributed.get_rank, + torch.distributed.get_world_size, + ] + ) + + +tracing_state_functions = { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, +} + + +class BaseTorchVariable(VariableTracker): + """common base for all torch.* functions, classes, modules and other things""" + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls( + value, + source=source, + ) + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def reconstruct(self, codegen): + try: + name = f"{self.value.__module__}.{self.value.__name__}" + except Exception: + name = f"torch_obj_{id(self.value)}" + unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) + codegen.extend_output( + codegen.setup_globally_cached(unique_var_name, self.value, False) + ) + + def 
as_proxy(self): + return self.value + + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + + def call_hasattr(self, tx, name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def can_constant_fold_through(self): + if self.value in constant_fold_functions: + return True + return getattr(self.value, "__module__", None) == "math" + + +class TorchCtxManagerClassVariable(BaseTorchVariable): + """Points to a context manager class in torch.* that dynamo has implementations""" + + def __repr__(self): + return f"TorchCtxManagerClassVariable({self.value})" + + @staticmethod + def is_matching_cls(value): + # Unwrap if it's a functools.lru_cache wrapper + value = unwrap_if_wrapper(value) + # We can't do isinstance(value, type) check because some ctx managers + # are implemented as a function decorated by contextlib.contextmanager, + # E.g., torch._functorch.vmap.vmap_increment_nesting. + return hashable(value) and value in supported_ctx_manager_classes + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import ( + DisabledSavedTensorsHooksVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + StreamVariable, + VmapIncrementNestingCtxManagerVariable, + ) + + if self.value is torch.no_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, False) + return ctx.call_function(tx, args, kwargs) + else: + return GradModeVariable.create(tx, False) + elif self.value is torch.enable_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, True) + return ctx.call_function(tx, args, kwargs) + return GradModeVariable.create(tx, True) + elif self.value is torch.set_grad_enabled and len(args) == 1: + return GradModeVariable.create( + tx, args[0].as_python_constant(), initialized=True + ) + elif self.value is torch.inference_mode: + assert len(args) <= 1 and len(kwargs) == 0 + inf_mode = args[0].as_python_constant() if len(args) == 1 else True + return InferenceModeVariable.create(tx, inf_mode) + elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase): + from torch._dynamo.variables.builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + elif self.value in ( + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ): + return AutocastModeVariable.create(self.value, args, kwargs) + elif self.value in ( + torch.profiler.profile, + torch.profiler.record_function, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + ): + warning_once(log, "Profiler function %s will be ignored", self.value) + return NullContextVariable() + elif self.value is torch._C.DisableTorchFunctionSubclass: + assert not (args or kwargs) + return TorchFunctionDisableVariable.create(tx) + elif self.value is torch._functorch.vmap.vmap_increment_nesting: + assert len(args) == 2 + return VmapIncrementNestingCtxManagerVariable.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch._functorch.eager_transforms.grad_increment_nesting: + assert len(args) == 0 + return GradIncrementNestingCtxManagerVariable.create(tx) 
+ elif ( + self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad + ): + assert len(args) == 1 + return GradInplaceRequiresGradCtxManagerVariable.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.graph.disable_saved_tensors_hooks: + assert len(args) == 1 + return DisabledSavedTensorsHooksVariable.create( + tx, args[0].as_python_constant() + ) + + +class TorchInGraphFunctionVariable(BaseTorchVariable): + """Points to a torch function/method that should be put in FX graph""" + + def __repr__(self): + return f"TorchInGraphFunctionVariable({self.value})" + + def get_function(self): + return self.value + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from . import ( + ConstantVariable, + DeterministicAlgorithmsVariable, + GradModeVariable, + SDPAParamsVariable, + StreamContextVariable, + SymNodeVariable, + TensorVariable, + UserDefinedObjectVariable, + ) + + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + constant_args = check_constant_args(args, kwargs) + unspec_python_args = check_unspec_python_args(args, kwargs) + + if self.can_constant_fold_through() and (constant_args or unspec_python_args): + # constant fold + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif self.value in tracing_state_functions: + assert not args and not kwargs + # See: https://github.com/pytorch/pytorch/issues/110765 + if self.value in ( + torch._utils.is_compiling, + torch._dynamo.external_utils.is_compiling, + torch.compiler.is_compiling, + torch.compiler.is_dynamo_compiling, + ): + tx.mark_inconsistent_side_effects() + return ConstantVariable.create(tracing_state_functions[self.value]) + elif self.value is torch.overrides.get_default_nowrap_functions.__wrapped__: + # [Note: __torch_function__] we return empty here because we restrict + # the set of functions that we trace __torch_function__ on to + # functions outside of the actual set. 
Implementing this properly will require implementing + # some variable types to track and compare tensor getset descriptors + from .builder import SourcelessBuilder + + return SourcelessBuilder()( + tx, torch.overrides.get_default_nowrap_functions() + ) + elif self.value == torch.ops.inductor.accumulate_grad_.default: + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder()(tx, polyfill.accumulate_grad), args, kwargs + ) + elif self.value == math.radians and not (constant_args or unspec_python_args): + # Use polyfill to convert math.radians(x) into math.pi * x / 180.0 + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder()(tx, polyfill.radians), args, kwargs + ) + elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like): + assert len(args) == 1 + if isinstance(args[0], TensorVariable) or ( + self.value is torch.overrides.is_tensor_like + and isinstance(args[0], UserDefinedObjectVariable) + and hasattr(args[0].value, "__torch_function__") + ): + return ConstantVariable.create(True) + else: + return ConstantVariable.create(False) + elif self.value in ( + torch.is_floating_point, + torch.is_complex, + ): + input_arg = None + if args: + input_arg = args[0] + else: + assert "input" in kwargs + input_arg = kwargs["input"] + if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None: + if self.value is torch.is_floating_point: + return ConstantVariable.create(input_arg.dtype.is_floating_point) + elif self.value is torch.is_complex: + return ConstantVariable.create(input_arg.dtype.is_complex) + else: + raise AssertionError(f"calling {self.value}") + elif ( + self.value is torch.numel + and isinstance(args[0], TensorVariable) + and args[0].size is not None + ): + return ConstantVariable.create(product(args[0].size)) + elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD: + assert len(args) == 1 + assert isinstance(args[0], TensorVariable) + return args[0].call_method(tx, "size", [], {}) + elif self.value in ( + torch.nn.modules.utils._single, + torch.nn.modules.utils._pair, + torch.nn.modules.utils._triple, + torch.nn.modules.utils._quadruple, + torch.nn.modules.utils._ntuple, + ): + return self._call_ntuple(tx, args, kwargs) + elif self.value is torch.is_grad_enabled: + assert not (args or kwargs) + install_guard(GradModeVariable._guards_singleton) + return ConstantVariable.create(torch.is_grad_enabled()) + elif self.value is torch.use_deterministic_algorithms and len(args) == 1: + return DeterministicAlgorithmsVariable.create( + tx, args[0].as_python_constant() + ) + elif self.value is torch.are_deterministic_algorithms_enabled: + assert not (args or kwargs) + install_guard(DeterministicAlgorithmsVariable._guards_singleton) + return ConstantVariable.create(torch.are_deterministic_algorithms_enabled()) + elif self.value is torch._C._is_torch_function_enabled: + assert not (args or kwargs) + install_guard(TorchFunctionDisableVariable._guards_singleton) + return ConstantVariable.create(tx.output.torch_function_enabled) + elif self.value in ( + torch.overrides.has_torch_function, + torch.overrides.has_torch_function_variadic, + torch.overrides.has_torch_function_unary, + ): + assert not kwargs + elems = ( + args[0].unpack_var_sequence(tx) + if len(args) == 1 and isinstance(args[0], TupleVariable) + else args + ) + return ConstantVariable.create( + any(has_torch_function(x) for x in elems), + ) + elif any( + self.value is method + for method in [ + device_interface.stream + for 
_, device_interface in get_registered_device_interfaces() + ] + ): + assert len(args) == 1 + return StreamContextVariable.create(tx, args[0]) + elif self.value is torch.from_numpy: + if not config.trace_numpy: + unimplemented("torch.from_numpy. config.trace_numpy is False") + if not np: + unimplemented("torch.from_numpy. NumPy is not available") + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.as_tensor, + *proxy_args_kwargs(args, {}), + ), + example_value=None, + ) + elif can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + elif self.value is torch.jit.annotate: + assert len(args) == 2 + return args[1] + elif self.value is torch.backends.cudnn.is_acceptable: + # is_acceptable(tensor) returns true if + # (a) tensor dtype/device are supported by cudnn + # (b) cudnn is available + # (c) some initialization has completed + # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version) + assert ( + len(args) == 1 or "tensor" in kwargs + ), "Expect 1 input to cudnn.is_acceptable" + tensor_variable = args[0] if len(args) > 0 else kwargs["tensor"] + assert isinstance( + tensor_variable, TensorVariable + ), "Expect input to cudnn.is_acceptable to be a tensor" + tensor_inp = torch.tensor( + 0, dtype=tensor_variable.dtype, device=tensor_variable.device + ) + return ConstantVariable.create( + torch.backends.cudnn.is_acceptable(tensor_inp) + ) + elif self.value is torch.utils.hooks.BackwardHook: + return variables.BackwardHookVariable.create(tx, *args, **kwargs) + elif self.value is torch.nn.Parameter: + return self.call_nn_parameter(tx, *args, **kwargs) + elif ( + self.value == torch.numel + and len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + # TODO(voz): This is rewritten as a call_method because + # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + "numel", + *proxy_args_kwargs(args, kwargs), + ), + ) + # TODO: These special cases shouldn't be necessary; we should + # generically support torch.ops that return int + elif ( + self.value in (torch.ops.aten.sym_size, torch.ops.aten.sym_size.int) + and len(args) == 2 + and len(kwargs) == 0 + and isinstance(args[0], TensorVariable) + ): + # we see this when retracing already traced code + return args[0].call_method(tx, "size", [args[1]], {}) + elif ( + self.value in (torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int) + and len(args) == 2 + and len(kwargs) == 0 + and isinstance(args[0], TensorVariable) + ): + return args[0].call_method(tx, "stride", [args[1]], {}) + elif ( + self.value == torch.addcdiv + and len(args) == 3 + and "value" in kwargs + and len(kwargs) == 1 + ): + # decompose addcdiv into constituent ops, prevents a graph break due to converting + # value to a scalar + result = TorchInGraphFunctionVariable(torch.div).call_function( + tx, args[1:], {} + ) + result = TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, kwargs["value"]], {} + ) + return TorchInGraphFunctionVariable(torch.add).call_function( + tx, [args[0], result], {} + ) + elif ( + self.value is torch._assert + and len(args) >= 1 + and ( + (args[0].is_python_constant() and args[0].as_python_constant()) + or ( + isinstance(args[0], variables.SymNodeVariable) + and args[0].evaluate_expr() + ) + ) + ): + return ConstantVariable(None) + elif 
SDPAParamsVariable.is_sdpa_params(self.value): + return wrap_fx_proxy( + tx, + proxy=tx.output.create_proxy( + "call_function", + torch._C._SDPAParams, + *proxy_args_kwargs(args, kwargs), + ), + param_vars=args, + ) + elif is_constant_pg_functions(self.value): + # because the input is a "ProcessGroupVariable", we'll be guarding on its + # ID_MATCH based on how it was constructed. + + # We desugar it at trace-time into ranks by directly calling util + # bake the result into the trace + if len(args) == 1: + # group or group name + assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable)) + elif len(args) == 2: + # ranks + tag + assert isinstance(args[0], ListVariable) and isinstance( + args[1], ConstantVariable + ) + else: + raise AssertionError( + f"Invalid group value ({args}) for constant pg " + f"function {self.value}" + ) + args_as_value = [arg.as_python_constant() for arg in args] + invocation_result = self.value(*args_as_value) + + # Note - while we *could* cook up sources around invocations, like a FunctionSource + # the space of invoking functions in the middle of the guard chain is very iffy. As such, + # guard propagation via options is the best we can do. + from .builder import SourcelessBuilder + + return SourcelessBuilder()(tx, invocation_result) + elif is_from_local(self.value): + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args[1:]] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def fn_with_prim_types(x): + return self.value(x, *args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + fn_with_prim_types.__name__ = "prim " + self.value.__name__ + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_with_prim_types, + *proxy_args_kwargs([args[0]], {}), + ), + ) + elif ( + self.value is torch.nested.nested_tensor + and kwargs.get("layout", torch.strided) == torch.strided + ): + raise unimplemented("torch.compile does not support strided NestedTensor") + elif self.value is torch.nn.functional.one_hot and ( + len(args) + len(kwargs) == 1 + or ( + len(args) == 2 + and args[1].is_python_constant() + and args[1].as_python_constant() == -1 + ) + ): + raise unimplemented( + "torch.nn.functional.one_hot with data-dependent output shape" + ) + elif ( + self.value is torch.fx.experimental.symbolic_shapes.guard_size_oblivious + and len(args) == 1 + and isinstance(args[0], SymNodeVariable) + ): + # TODO: this probably should be folded somewhere else but I'm not + # sure where + # TODO: some of the other symbolic_shapes special tools can also + # get this treatment too + (cond,) = args + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_size_oblivious(cond.sym_num) + ) + elif self.value is torch._C._autograd._unsafe_set_version_counter: + from ..tensor_version_op import _unsafe_set_version_counter + + return TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, args, kwargs) + else: + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + bin_ops = {"add", "sub", "mul", "div", "sqrt"} + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats 
+ and all_ints_or_floats + ): + msg = f"""\ +Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. +To support this behavior, we need to allow const-propping tensors that store symint data. +For now, dynamo will explicitly graph break when it encounters user code with this behavior. +""" + log.warning(msg) + raise unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any(isinstance(x, SymNodeVariable) for x in args): + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + if fn_ is torch.tensor: + + def check_any_unspec(x): + # NB: This includes UnspecializedPythonVariable + if isinstance(x, (TensorVariable, SymNodeVariable)): + return True + elif isinstance(x, (ListVariable, TupleVariable)): + return any(check_any_unspec(y) for y in x.items) + # TODO: there maybe other recursive structures you need to + # check + else: + return False + + data_arg = None + if args: + data_arg = args[0] + elif "data" in kwargs: + data_arg = kwargs["data"] + + # NB: OK to pass torch.tensor(tensor), this will trace fine + if not isinstance(data_arg, TensorVariable) and check_any_unspec( + data_arg + ): + # This is slower and less canonical, so only use it if we + # have to + fn_ = torch._refs.tensor + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. +Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" + ) + + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + out_tensor.source + and out_tensor in tx.output.graphargs + and out_tensor.size != result_tensor.size + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. 
+ unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out.shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + else: + unimplemented(f"out variant of {type(kwargs['out'])}") + + return tensor_variable + + def _call_ntuple(self, tx, args, kwargs): + """inline behavior of torch.nn.modules.utils._ntuple""" + if self.value is torch.nn.modules.utils._ntuple: + count = args[0].as_python_constant() + else: + count = self.value.__closure__[0].cell_contents + assert isinstance(count, int) + assert not kwargs + + def handle_ntuple(value): + if value.has_unpack_var_sequence(tx): + return variables.TupleVariable( + list(value.unpack_var_sequence(tx)), + ) + elif value.is_python_constant(): + # constant prop through it + return variables.ConstantVariable.create( + torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), + ) + else: + unimplemented(f"torch.nn.modules.utils._ntuple({value})") + + if self.value is torch.nn.modules.utils._ntuple: + return variables.LambdaVariable(handle_ntuple) + else: + return handle_ntuple(args[0]) + + @classmethod + def call_nn_parameter(cls, tx, data=None, requires_grad=True): + """A call to torch.nn.Parameter() gets lifted to before the graph""" + if isinstance(requires_grad, variables.VariableTracker): + try: + requires_grad = requires_grad.as_python_constant() + except NotImplementedError: + unimplemented("Parameter(requires_grad=...) not constant") + + if not isinstance(data, variables.TensorVariable): + unimplemented(f"Parameter(data={data}) not implemented") + + # this results in cleaner graphs, but only works for inputs + if data.source: + return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + + unimplemented("Parameter() on non-input") + + @staticmethod + def _nn_param_via_prefix_insert(tx, data, requires_grad): + # Alternate version if we have a .source + from .builder import VariableBuilder + + varname = tx.output.new_var() + + # construct the nn.Parmeter before the graph save it to varname + cg = PyCodegen(tx) + cg.load_import_from("torch.nn", "Parameter") + cg(data.source) + cg(variables.ConstantVariable(requires_grad)) + cg.call_function(2, True) + cg.store(varname) + tx.output.pregraph_bytecode.extend(cg.get_instructions()) + + # add the newly constructed nn.Parameter as a graph input + source = SyntheticLocalSource(varname) + example_value = torch.nn.Parameter( + tx.output.example_value_from_input_node(data.as_proxy().node) + ) + result = VariableBuilder(tx, source)(example_value) + # No need to guard on this since we already guarded on `data`. 
+ # These guards would fail since varname doesn't exist until after the function starts + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/torch_function.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/torch_function.py new file mode 100644 index 0000000000000000000000000000000000000000..857767346ca644638952231fd03dea0b1797f5c0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/torch_function.py @@ -0,0 +1,270 @@ +# mypy: ignore-errors + +import inspect +from typing import Dict, List + +import torch.utils._pytree as pytree + +from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalSource +from ..utils import has_torch_function, is_tensor_base_attr_getter +from .base import VariableTracker +from .constant import ConstantVariable +from .lists import TupleVariable +from .tensor import TensorSubclassVariable, TensorVariable +from .user_defined import UserDefinedObjectVariable + + +# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues): +# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches +# __torch_function__ on attribute accesses, method calls, and torch API calls. +# The following is not supported: +# - triggering __torch_function__ on tensor subclass non-tensor custom attributes +# - graph breaking on mutating guardable tensor properties within a __torch_function__ context, this can cause +# excessive recompiles in certain degenerate cases +# - Matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of Torch API calls + +# The following is supported: +# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as +# any argument +# - triggering __torch_function__ on torch API calls with tensor subclass arguments +# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances +# - matches the dispatch ordering behavior of eager __torch_function__ with subclass/object argumnents in any argument position + +# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w +# for more information on the design. 
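+# As a rough illustration (not part of this module), the kind of subclass this file
+# handles is the classmethod-style override with the signature used throughout this
+# file; the class name below is hypothetical:
+#
+#     class LoggingTensor(torch.Tensor):
+#         @classmethod
+#         def __torch_function__(cls, func, types, args=(), kwargs=None):
+#             kwargs = kwargs or {}
+#             # inspect or rewrite the call here, then defer to the default behavior
+#             return super().__torch_function__(func, types, args, kwargs)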
+ +# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py + + +banned_attrs = [ + fn.__self__.__name__ + for fn in get_default_nowrap_functions() + if is_tensor_base_attr_getter(fn) +] + + +def _get_subclass_type(var): + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + return var.python_type() + + +def _get_subclass_type_var(tx, var): + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + if isinstance(var, TensorWithTFOverrideVariable): + return var.class_type_var(tx) + elif isinstance(var, UserDefinedObjectVariable): + from .builder import SourcelessBuilder, VariableBuilder + + if var.source: + return VariableBuilder(tx, var.source)(var.python_type()) + else: + return SourcelessBuilder()(tx, var.python_type()) + + +def _is_attr_overidden(tx, var, name): + import torch + + overridden = False + try: + attr_val = inspect.getattr_static(var.python_type(), name) + overridden |= attr_val != getattr(torch.Tensor, name) + except AttributeError: + pass + + return overridden + + +def call_torch_function( + tx, torch_function_type, torch_function_var, fn, types, args, kwargs +): + from .builder import SourcelessBuilder + + # signature: + # def __torch_function__(cls, func, types, args=(), kwargs=None): + tf_args = ( + torch_function_type, + fn, + types, + SourcelessBuilder()(tx, tuple(args)), + SourcelessBuilder()(tx, kwargs), + ) + return tx.inline_user_function_return(torch_function_var, tf_args, {}) + + +def build_torch_function_fn(tx, value, source): + from .builder import SourcelessBuilder, VariableBuilder + + if source: + return VariableBuilder( + tx, + AttrSource(AttrSource(source, "__torch_function__"), "__func__"), + )(value.__torch_function__.__func__) + else: + return SourcelessBuilder()(tx, value.__torch_function__.__func__) + + +def can_dispatch_torch_function(tx, args, kwargs): + if tx.output.torch_function_enabled: + all_args = pytree.arg_tree_leaves(*args, **kwargs) + return any(has_torch_function(arg) for arg in all_args) + else: + return False + + +def dispatch_torch_function(tx, fn, args, kwargs): + """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" + + all_args = pytree.arg_tree_leaves(*args, **kwargs) + overloaded_args = _get_overloaded_args( + [arg for arg in all_args if has_torch_function(arg)], + _get_subclass_type, + ) + + for arg in overloaded_args: + res = arg.call_torch_function( + tx, + fn, + TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + args, + kwargs, + ) + + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + + unimplemented( + f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented" + ) + + +class TensorWithTFOverrideVariable(TensorVariable): + """ + Represents a tensor subclass instance with a __torch_function__ override. 
+ """ + + def __init__(self, *args, **kwargs): + self.torch_function_fn = kwargs.pop("torch_function_fn") + super().__init__(*args, **kwargs) + + @classmethod + def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn): + import torch + + kwargs = dict(tensor_var.__dict__) + assert ( + kwargs.pop("class_type") is torch.Tensor + ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" + var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs) + var.install_global(tx) + return var + + def install_global(self, tx): + # stash the subclass type to rewrap an output tensor if needed + # this is needed because the actual type needs to be available + # each time the compiled artifact is run and outputs a wrapped tensor. + if self.global_mangled_class_name(tx) not in tx.output.global_scope: + # Safe because global_mangled_class_name figures it out + tx.output.install_global_unsafe( + self.global_mangled_class_name(tx), self.class_type + ) + + def python_type(self): + return self.class_type + + def class_type_var(self, tx): + return TensorSubclassVariable( + self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) + ) + + def global_mangled_class_name(self, tx): + # The global_mangled_class_name should be different for different + # invocations of torch.compile. Otherwise, we can run into a situation + # where multiple torch.compile invocations re-use the same global name, + # but the global's lifetime is tied to the first invocation (and + # may be deleted when the first torch.compile invocation is deleted) + # We mangle it based off of the output_graph's id. + compile_id = tx.output.compile_id + return f"__subclass_{self.class_type.__name__}_{id(self.class_type)}_c{id}" + + def var_getattr(self, tx, name): + # [Note: __torch_function__] We currently only support attributes that are defined on + # base tensors, custom attribute accesses will graph break. + import torch + from .builder import SourcelessBuilder + + if name in banned_attrs or not hasattr(torch.Tensor, name): + unimplemented( + f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" + ) + + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overridden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + + if tx.output.torch_function_enabled: + if self.source: + install_guard( + AttrSource(AttrSource(self.source, "__class__"), name).make_guard( + GuardBuilder.FUNCTION_MATCH + ) + ) + get_fn = SourcelessBuilder()(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + [self], + {}, + ) + else: + return super().var_getattr(tx, name) + + def call_torch_function(self, tx, fn, types, args, kwargs): + return call_torch_function( + tx, + self.class_type_var(tx), + self.torch_function_fn, + fn, + types, + args, + kwargs, + ) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + # This code block implements inlining the __torch_function__ override + # of `call_method`. 
+ if tx.output.torch_function_enabled: + import torch + from .builder import SourcelessBuilder, VariableBuilder + + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Calling overridden method {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + + # [Note: __torch_function__] Currently we only support methods that are defined on tensor + # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality + # We've established with the above check that the method is not overridden, so we guard that the method is the same + # as the impl defined on tensor and retrieve it + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(inspect.getattr_static(self.python_type(), name)) + else: + func_var = SourcelessBuilder()(tx, getattr(torch.Tensor, name)) + return dispatch_torch_function(tx, func_var, [self] + args, kwargs) + else: + return super().call_method(tx, name, args, kwargs) diff --git a/MLPY/Lib/site-packages/torch/_dynamo/variables/user_defined.py b/MLPY/Lib/site-packages/torch/_dynamo/variables/user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f3bb3ff6926b5f17ca1420a378217dd54ef743 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_dynamo/variables/user_defined.py @@ -0,0 +1,946 @@ +# mypy: ignore-errors + +import collections +import contextlib +import functools +import importlib +import inspect +import itertools +import random +import sys +import threading +import types +from typing import Dict, List + +from ..bytecode_transformation import create_call_function + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +try: + from torch.utils._cxx_pytree import PyTreeSpec +except ImportError: + PyTreeSpec = type(None) + +import torch._dynamo.config + +import torch.nn +from torch._guards import TracingContext + +from .. import variables +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource +from ..utils import ( + all_hook_names, + build_checkpoint_variable, + check_constant_args, + get_custom_getattr, + has_torch_function, + is_namedtuple_cls, + is_utils_checkpoint, + istype, + namedtuple_fields, + object_has_getattribute, + proxy_args_kwargs, + tensortype_to_dtype, +) +from .base import MutableLocal, VariableTracker +from .ctx_manager import GenericContextWrappingVariable, NullContextVariable +from .dicts import DefaultDictVariable + + +class UserDefinedVariable(VariableTracker): + pass + + +class UserDefinedClassVariable(UserDefinedVariable): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def python_type(self): + return type(self.value) + + def as_proxy(self): + return self.value + + def __str__(self): + return f"UserDefinedClassVariable({self.value})" + + @staticmethod + @functools.lru_cache(None) + def _constant_fold_classes(): + return { + torch.device, + torch.finfo, + torch.iinfo, + torch.Size, + } + + @staticmethod + @functools.lru_cache(None) + def _in_graph_classes(): + return set(tensortype_to_dtype.keys()) | { + torch.Tensor, + torch.cuda.Stream, + torch.cuda.Event, + } + + def can_constant_fold_through(self): + return self.value in self._constant_fold_classes() + + def var_getattr(self, tx, name: str) -> "VariableTracker": + from .. 
import trace_rules + from . import ConstantVariable + from .builder import VariableBuilder + + if name == "__name__": + return ConstantVariable.create(self.value.__name__) + + source = AttrSource(self.source, name) if self.source is not None else None + try: + obj = inspect.getattr_static(self.value, name) + except AttributeError: + obj = None + + if isinstance(obj, staticmethod): + func = obj.__get__(self.value) + if source is not None: + return trace_rules.lookup(func).create_with_source(func, source=source) + else: + return trace_rules.lookup(func)(func) + elif isinstance(obj, classmethod): + return variables.UserMethodVariable(obj.__func__, self, source=source) + elif source and inspect.ismemberdescriptor(obj): + return VariableBuilder(tx, source)(obj.__get__(self.value)) + + # Special handling of collections.OrderedDict.fromkeys() + # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with + # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method(). + # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys), + # and we need duplicate code to handle both cases. + if self.value is collections.OrderedDict and name == "fromkeys": + return super().var_getattr(tx, name) + + if name in getattr(self.value, "__dict__", {}) or ( + self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ): + if source: + return VariableBuilder(tx, source)(obj) + elif ConstantVariable.is_literal(obj): + return ConstantVariable.create(obj) + + return super().var_getattr(tx, name) + + def _call_cross_entropy_loss(self, tx, args, kwargs): + """ + functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional loss call: input, target, optional_output + """ + from . 
import ConstantVariable + + def normalize_args( + weight=ConstantVariable.create(None), + size_average=ConstantVariable.create(None), + ignore_index=ConstantVariable.create(-100), + reduce=ConstantVariable.create(None), + reduction=ConstantVariable.create("mean"), + label_smoothing=ConstantVariable.create(0.0), + ): + return ( + weight, + size_average, + ignore_index, + reduce, + reduction, + label_smoothing, + ) + + ( + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ) = normalize_args(*args, **kwargs) + + def fake_cross_entropy_loss(input, target): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.nn.functional.cross_entropy, + *proxy_args_kwargs( + [ + input, + target, + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ], + {}, + ), + ), + ) + + return variables.LambdaVariable(fake_cross_entropy_loss) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + name == "__subclasses__" + and len(args) == 0 + and not kwargs + and "__subclasses__" not in self.value.__dict__ + ): + options = {"mutable_local": MutableLocal()} + subs_as_vars: List[VariableTracker] = list() + for sub in self.value.__subclasses__(): + source = AttrSource(tx.import_source(sub.__module__), sub.__name__) + subs_as_vars.append( + variables.UserDefinedClassVariable(sub, source=source) + ) + + return variables.ListVariable(subs_as_vars, **options) + elif ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + from .builtin import BuiltinVariable + + return BuiltinVariable.call_custom_dict_fromkeys( + tx, self.value, *args, **kwargs + ) + + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from ..side_effects import SideEffects + from .builder import SourcelessBuilder, wrap_fx_proxy + from .builtin import BuiltinVariable + + constant_args = check_constant_args(args, kwargs) + + if self.can_constant_fold_through() and constant_args: + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif self.value is torch.nn.CrossEntropyLoss: + return self._call_cross_entropy_loss(tx, args, kwargs) + elif self.value is contextlib.nullcontext: + return NullContextVariable() + elif self.value is collections.OrderedDict: + return BuiltinVariable.call_custom_dict( + tx, collections.OrderedDict, *args, **kwargs + ) + elif ( + self.value is collections.defaultdict + and len(args) <= 1 + and DefaultDictVariable.is_supported_arg(args[0]) + ): + return DefaultDictVariable( + {}, + collections.defaultdict, + args[0], + mutable_local=MutableLocal(), + ) + elif self.value is collections.deque and not kwargs: + if len(args) == 0: + items = [] + elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): + items = args[0].unpack_var_sequence(tx) + else: + unimplemented("deque() with more than 1 arg not supported") + return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) + elif self.value is functools.partial: + if not args: + unimplemented("functools.partial malformed") + # The first arg, a callable (the ctor below will assert on types) + fn = args[0] + 
rest_args = args[1:] + # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the + # args and keywords + return variables.functions.FunctoolsPartialVariable( + fn, args=rest_args, keywords=kwargs + ) + elif ( + issubclass(type(self.value), type) + and hasattr( + self.value, "__enter__" + ) # TODO(voz): These can invoke user code! + and hasattr( + self.value, "__exit__" + ) # TODO(voz): These can invoke user code! + and check_constant_args(args, kwargs) + and self.value.__init__ == object.__init__ + and len(kwargs) == 0 # TODO(ybliang): support kwargs + ): + unwrapped_args = [x.as_python_constant() for x in args] + return GenericContextWrappingVariable( + unwrapped_args, + cm_obj=self.value(*unwrapped_args), + ) + + elif is_namedtuple_cls(self.value): + fields = namedtuple_fields(self.value) + # check if this a quasi-namedtuple or a real one + if self.value.__module__ == "torch.return_types": + # create pseudo-defaults from values of the quasi-namedtuple + field_defaults = dict(zip(fields, args[0].items)) + else: + field_defaults = self.value._field_defaults + + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + var_tracker_kwargs = {} + for field_name, var_tracker in zip(fields, items): + if var_tracker is None: + if field_name in kwargs: + field_var = kwargs[field_name] + else: + assert field_name in field_defaults + field_var = SourcelessBuilder()(tx, field_defaults[field_name]) + var_tracker_kwargs[field_name] = field_var + + for name, value in var_tracker_kwargs.items(): + assert name in fields + items[fields.index(name)] = value + + assert all(x is not None for x in items) + return variables.NamedTupleVariable(items, self.value) + elif ( + inspect.getattr_static(self.value, "__new__", None) in (object.__new__,) + and SideEffects.cls_supports_mutation_side_effects(self.value) + and self.source + ): + var = tx.output.side_effects.track_object_new( + self.source, + self.value, + variables.UnspecializedNNModuleVariable + if issubclass(self.value, torch.nn.Module) + else UserDefinedObjectVariable, + {}, + ) + if ( + inspect.getattr_static(self.value, "__init__", None) + is torch.nn.Module.__init__ + ): + tx.output.side_effects.store_attr( + var, + "__call_nn_module_init", + variables.ConstantVariable.create(True), + ) + return var + else: + var.call_method(tx, "__init__", args, kwargs) + return var + elif variables.CustomizedDictVariable.is_matching_cls(self.value): + options = {"mutable_local": MutableLocal()} + return variables.CustomizedDictVariable.create( + self.value, args, kwargs, options + ) + elif variables.DataClassVariable.is_matching_cls(self.value): + options = {"mutable_local": MutableLocal()} + return variables.DataClassVariable.create(self.value, args, kwargs, options) + elif ( + variables.RestrictedListSubclassVariable.is_matching_cls(self.value) + and self.source + ): + return variables.RestrictedListSubclassVariable( + variables.BuiltinVariable(list).call_function(tx, args, kwargs).items, + user_cls=self.value, + user_cls_source=self.source, + mutable_local=MutableLocal(), + ) + elif self.value in self._in_graph_classes(): + # torch.LongTensor cannot accept a list of FakeTensors. + # So we stack the list of FakeTensors instead. 
+ if ( + np + and self.value in tensortype_to_dtype + and len(args) == 1 + and isinstance(args[0], variables.ListVariable) + and len(args[0].items) > 1 + and all(isinstance(x, variables.TensorVariable) for x in args[0].items) + ): + # Stack FakeTensor + stacked = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.stack, + *proxy_args_kwargs(args, kwargs), + ), + ) + args = [stacked] + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + *proxy_args_kwargs(args, kwargs), + ), + ) + + return tensor_variable + + return super().call_function(tx, args, kwargs) + + def const_getattr(self, tx, name): + if name == "__name__": + return self.value.__name__ + return super().const_getattr(tx, name) + + +class UserDefinedObjectVariable(UserDefinedVariable): + """ + Mostly objects of defined type. Catch-all for something where we only know the type. + """ + + _nonvar_fields = {"value", "value_type", *UserDefinedVariable._nonvar_fields} + + def __init__(self, value, value_type=None, **kwargs): + super().__init__(**kwargs) + self.value = value + self.value_type = value_type or type(value) + assert type(value) is self.value_type + + def __str__(self): + inner = self.value_type.__name__ + if inner in [ + "builtin_function_or_method", + "getset_descriptor", + "method_descriptor", + "method", + ]: + inner = str(getattr(self.value, "__name__", None)) + return f"{self.__class__.__name__}({inner})" + + def python_type(self): + return self.value_type + + def guard_as_python_constant(self): + if self.source: + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + return self.value + return super().guard_as_python_constant() + + def torch_function_check(self): + assert has_torch_function( + self + ), f"calling torch function on object without __torch_function__ {self}" + + def get_torch_fn(self, tx): + self.torch_function_check() + from .torch_function import build_torch_function_fn + + return build_torch_function_fn(tx, self.value, self.source) + + def call_torch_function(self, tx, fn, types, args, kwargs): + self.torch_function_check() + + from .torch_function import _get_subclass_type_var, call_torch_function + + return call_torch_function( + tx, + _get_subclass_type_var(tx, self), + self.get_torch_fn(tx), + fn, + types, + args, + kwargs, + ) + + @staticmethod + @functools.lru_cache(None) + def _supported_random_functions(): + fns = { + random.random, + random.randint, + random.randrange, + random.uniform, + } + return fns + + def _maybe_get_baseclass_method(self, name): + if name not in getattr(self.value, "__dict__", {}): + try: + return inspect.getattr_static(type(self.value), name) + except AttributeError: + pass + return None + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from . 
import ( + BuiltinVariable, + ConstantVariable, + TupleVariable, + UserMethodVariable, + ) + + method = self._maybe_get_baseclass_method(name) + if method is not None: + if method is object.__init__: + return ConstantVariable.create(None) + + # [NOTE] OrderedDict, dict subtypes must always have source + # We cannot instantiate such subtypes in-graph due to builtin __new__ + if method is collections.OrderedDict.keys: + # subclass of OrderedDict + assert not (args or kwargs) + assert self.source # OrderedDict, dict subtypes must always have source + keys = list(self.value.keys()) + assert all(map(ConstantVariable.is_literal, keys)) + install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS)) + return TupleVariable([ConstantVariable.create(k) for k in keys]) + + if ( + method in (collections.OrderedDict.__contains__, dict.__contains__) + and len(args) == 1 + and isinstance(args[0], (ConstantVariable, BuiltinVariable)) + and inspect.getattr_static(type(self.value), "keys") + in (collections.OrderedDict.keys, dict.keys) + ): + assert not kwargs + assert self.source # OrderedDict, dict subtypes must always have source + install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS)) + return ConstantVariable.create( + args[0].as_python_constant() in self.value + ) + + if method is collections.OrderedDict.items and isinstance( + self.value, collections.OrderedDict + ): + assert self.source # OrderedDict, dict subtypes must always have source + assert not (args or kwargs) + items = [] + keys = self.call_method(tx, "keys", [], {}) + for key in keys.unpack_var_sequence(tx): + items.append( + TupleVariable( + [key, self.odict_getitem(tx, key)], + ) + ) + return TupleVariable(items) + + if method is collections.OrderedDict.__getitem__ and len(args) == 1: + assert not kwargs + assert self.source # OrderedDict, dict subtypes must always have source + return self.odict_getitem(tx, args[0]) + + # check for methods implemented in C++ + if isinstance(method, types.FunctionType): + source = ( + None + if self.source is None + else AttrSource(AttrSource(self.source, "__class__"), name) + ) + # TODO(jansel): add a guard to check for monkey patching? + return UserMethodVariable(method, self, source=source).call_function( + tx, args, kwargs + ) + + if method is list.__len__ and self.source and not (args or kwargs): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return ConstantVariable(len(self.value)) + + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + if ( + self.source + and self._maybe_get_baseclass_method("__iter__") is list.__iter__ + and self._maybe_get_baseclass_method("__len__") is list.__len__ + and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__ + ): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return [ + variables.LazyVariableTracker.create( + self.value[k], + source=GetItemSource(self.source, k), + ) + for k in range(len(self.value)) + ] + return super().unpack_var_sequence(tx) + + def is_supported_random(self): + try: + return self.value in self._supported_random_functions() + except TypeError: + # TypeError: unhashable type + return False + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .. 
import trace_rules + from .builder import VariableBuilder + + if ( + self.is_supported_random() + and all(k.is_python_constant() for k in args) + and all(v.is_python_constant() for v in kwargs.values()) + ): + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + random_call_index = len(tx.output.random_calls) + example_value = self.value(*args, **kwargs) + source = RandomValueSource(random_call_index) + tx.output.random_calls.append((self.value, args, kwargs)) + return VariableBuilder(tx, source).wrap_unspecialized_primitive( + example_value + ) + elif istype(self.value, types.MethodType): + func = self.value.__func__ + obj = self.value.__self__ + if ( + func is torch.utils._contextlib._DecoratorContextManager.clone + and variables.TorchCtxManagerClassVariable.is_matching_cls( + obj.__class__ + ) + and not (args or kwargs) + ): + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, args, kwargs) + + if ( + func is torch.autograd.grad_mode.inference_mode.clone + and obj.__class__ is torch.autograd.grad_mode.inference_mode + ): + # simulate the inference_mode.clone implementation + var = variables.ConstantVariable(obj.mode) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, [var], kwargs) + elif ( + istype(self.value, functools.partial) + and trace_rules.lookup(self.value.func) + == variables.TorchInGraphFunctionVariable + and all( + variables.ConstantVariable.is_literal(v) + for v in itertools.chain(self.value.args, self.value.keywords.values()) + ) + ): + if self.source: + install_guard( + AttrSource(self.source, "func").make_guard(GuardBuilder.ID_MATCH), + AttrSource(self.source, "args").make_guard( + GuardBuilder.CONSTANT_MATCH + ), + AttrSource(self.source, "keywords").make_guard( + GuardBuilder.CONSTANT_MATCH + ), + ) + + partial_args = [ + variables.ConstantVariable.create(v) for v in self.value.args + ] + partial_args.extend(args) + partial_kwargs = { + k: variables.ConstantVariable.create(v) + for k, v in self.value.keywords.items() + } + partial_kwargs.update(kwargs) + if is_utils_checkpoint(self.value.func): + return build_checkpoint_variable().call_function( + tx, partial_args, partial_kwargs + ) + return variables.TorchInGraphFunctionVariable( + self.value.func + ).call_function(tx, partial_args, partial_kwargs) + elif callable(self.value): + if self.source: + install_guard(self.source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return self.call_method(tx, "__call__", args, kwargs) + + return super().call_function(tx, args, kwargs) + + def _check_for_getattribute(self): + if object_has_getattribute(self.value): + unimplemented("UserDefinedObjectVariable with custom __getattribute__") + + def _check_for_getattr(self): + return get_custom_getattr(self.value) + + def _getattr_static(self, name): + if ( + isinstance(self.value, (torch.nn.Module, PyTreeSpec)) + or "__slots__" in self.value.__class__.__dict__ + or type(self.value) == threading.local + ): + # getattr_static doesn't work on these + subobj = getattr(self.value, name) + else: + subobj = inspect.getattr_static(self.value, name) + return subobj + + def var_getattr(self, tx, name): + from .. import trace_rules + from . 
import ConstantVariable + from .builder import VariableBuilder + + value = self.value + source = AttrSource(self.source, name) if self.source else None + self._check_for_getattribute() + getattr_fn = self._check_for_getattr() + + class NO_SUCH_SUBOBJ: + pass + + try: + subobj = self._getattr_static(name) + except AttributeError: + subobj = NO_SUCH_SUBOBJ + if isinstance(getattr_fn, types.FunctionType): + return variables.UserMethodVariable( + getattr_fn, self, source=source + ).call_function(tx, [ConstantVariable.create(name)], {}) + elif getattr_fn is not None: + unimplemented("UserDefined with non-function __getattr__") + + if isinstance(subobj, property): + # Rewrite the source being explicit about reading it statically. + if self.source: + source = AttrSource(self.source, name, get_static=True) + source = AttrSource(source, "fget") + return variables.UserMethodVariable( + subobj.fget, self, source=source + ).call_function(tx, [], {}) + elif isinstance(subobj, torch.distributions.utils.lazy_property): + subobj_var = UserDefinedObjectVariable(subobj, source=source) + return variables.UserMethodVariable( + subobj.__get__.__func__, subobj_var, source=source + ).call_function(tx, [self], {}) + elif isinstance(subobj, staticmethod): + func = subobj.__get__(self.value) + if source is not None: + return trace_rules.lookup(func).create_with_source(func, source=source) + else: + return trace_rules.lookup(func)(func) + elif isinstance(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, self.var_getattr(tx, "__class__"), source=source + ) + elif isinstance(subobj, types.FunctionType) or ( + isinstance(subobj, types.MethodType) + and isinstance(self.value, torch.nn.Module) + ): + # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. + # Static lookup can't tell us it's a method or function correctly, + # so we trigger dynamic lookup here to get the correct type. 
+ dynamic_subobj = getattr(self.value, name) + + while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"): + subobj = subobj._torchdynamo_inline + dynamic_subobj = subobj + source = AttrSource(source, "_torchdynamo_inline") if source else None + + if isinstance(subobj, types.MethodType): + if dynamic_subobj.__self__ is not self.value: + unimplemented("__self__ mismatch for bound method") + func = subobj.__func__ + else: + assert isinstance(subobj, types.FunctionType) + func = subobj + + if inspect.ismethod(dynamic_subobj): + return variables.UserMethodVariable(func, self, source=source) + elif inspect.isfunction(dynamic_subobj): + if is_utils_checkpoint(func): + return build_checkpoint_variable(source=source) + elif source is not None: + return trace_rules.lookup(func).create_with_source( + func, source=source + ) + else: + return trace_rules.lookup(func)(func) + + if ( + name in getattr(value, "__dict__", {}) + or ConstantVariable.is_literal(subobj) + or isinstance( + subobj, + ( + torch.Tensor, + torch.nn.Module, + ), + ) + ): + if source: + return VariableBuilder(tx, source)(subobj) + elif ConstantVariable.is_literal(subobj): + return ConstantVariable.create(subobj) + + if ( + name not in getattr(value, "__dict__", {}) + and type(value).__module__.startswith("torch.") + and "torch.optim" not in type(value).__module__ + and not callable(value) + and not isinstance(subobj, types.MethodDescriptorType) + ): + if not source: + assert getattr( + importlib.import_module(type(value).__module__), + type(value).__name__, + ) is type(value) + source = AttrSource( + AttrSource( + tx.import_source(type(value).__module__), type(value).__name__ + ), + name, + ) + + return VariableBuilder(tx, source)(subobj) + options = {"source": source} + if isinstance( + subobj, + ( + torch.distributions.constraints._Interval, + torch.distributions.constraints._Real, + torch.distributions.constraints.Constraint, + ), + ): + return UserDefinedObjectVariable(subobj, **options) + elif isinstance(self.value, torch.nn.Module) and name in all_hook_names: + assert isinstance(subobj, collections.OrderedDict) + if not subobj: + return variables.ConstDictVariable( + subobj, collections.OrderedDict, **options + ) + + if name == "__class__": + return UserDefinedClassVariable(type(self.value), **options) + + return variables.GetAttrVariable(self, name, **options) + + def call_hasattr(self, tx, name: str) -> "VariableTracker": + if tx.output.side_effects.is_attribute_mutation(self): + try: + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + return variables.ConstantVariable.create( + not isinstance(result, variables.DeletedVariable) + ) + except KeyError: + pass + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + if self._check_for_getattribute() or self._check_for_getattr(): + unimplemented("hasattr with custom __getattr__") + + try: + self._getattr_static(name) + return variables.ConstantVariable.create(True) + except AttributeError: + return variables.ConstantVariable.create(False) + + def odict_getitem(self, tx, key): + from .builder import VariableBuilder + from .dicts import is_hashable + + # TODO this should probably be merged with the dict handling + + index = ( + key.source + if is_hashable(key) and key.source is not None + else key.as_python_constant() + ) + + return VariableBuilder( + tx, + ODictGetItemSource(self.source, index), + )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant())) + + +class 
KeyedJaggedTensorVariable(UserDefinedObjectVariable): + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torchrec.sparse.jagged_tensor") + return mod is not None and type(obj) is mod.KeyedJaggedTensor + + def __init__(self, value, **kwargs): + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + assert type(value) is KeyedJaggedTensor + super().__init__(value, **kwargs) + + def var_getattr(self, tx, name): + if ( + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt + and self.source is not None + and name in ("_length_per_key", "_offset_per_key") + ): + with TracingContext.patch(force_unspec_int_unbacked_size_like=True): + return super().var_getattr(tx, name) + return super().var_getattr(tx, name) + + +class RemovableHandleVariable(VariableTracker): + REMOVED = -1 + + def __init__( + self, + mutable_local=None, + # index of the registration in the side_effects owned register_hook/handle list, used during removal. + idx=None, + **kwargs, + ): + super().__init__(**kwargs) + self.mutable_local = mutable_local + self.idx = idx + + def call_method(self, tx, method_name, args, kwargs): + if method_name == "remove": + if self.idx != self.REMOVED: + tx.output.side_effects.remove_hook(self.idx) + self.idx = self.REMOVED + return variables.ConstantVariable.create(None) + super().call_method(tx, method_name, args, kwargs) + + def reconstruct(self, codegen): + if self.idx == self.REMOVED: + # Hook has already been removed, return a dummy handle + codegen.load_import_from("torch._dynamo.utils", "invalid_removeable_handle") + codegen.extend_output(create_call_function(0, True)) + return + # unreachable due to codegen.add_cache() when the hook is installed + super().reconstruct(codegen) diff --git a/MLPY/Lib/site-packages/torch/_export/__init__.py b/MLPY/Lib/site-packages/torch/_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..421614c6e2707e7e661d599fdd847fdc331aab30 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/__init__.py @@ -0,0 +1,406 @@ +import copy +import dataclasses +import functools +import io +import json +import os +import re +import sys +import types +import warnings +import weakref +import zipfile +from collections import OrderedDict +from contextlib import contextmanager + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +import torch._dynamo +import torch.fx +import torch.utils._pytree as pytree + +from torch._decomp import core_aten_decompositions, get_decompositions +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.exc import UserError, UserErrorType +from torch._dynamo.source import ConstantSource +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._functorch.aot_autograd import aot_export_module, GraphSignature +from torch._functorch.eager_transforms import functionalize +from torch._guards import detect_fake_mode +from torch._inductor import config +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensor +from torch._utils_internal import log_export_usage +from torch.export._tree_utils import reorder_kwargs +from torch.export._unlift import _create_stateful_graph_module +from torch.export.dynamic_shapes import ( + _process_constraints, + _process_dynamic_shapes, + Constraint, + dims, + dynamic_dim, +) +from 
torch.export.exported_program import ( + _disable_prexisiting_fake_mode, + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) +from torch.export.graph_signature import ( + _sig_to_specs, + ArgumentSpec, + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) +from torch.fx import traceback as fx_traceback +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + GuardOnDataDependentSymNode, + ShapeEnv, + StrictMinMaxConstraint, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges + +from .passes.add_runtime_assertions_for_constraints_pass import ( + _AddRuntimeAssertionsForInlineConstraintsPass, +) +from .wrappers import _wrap_submodules + + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + allow_rnn: bool = True + + +@compatibility(is_backward_compatible=False) +def capture_pre_autograd_graph( + f: torch.nn.Module, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, +) -> torch.nn.Module: + """ + A helper function that is intended to trace a module before any pre-autograd + decomposition is run. The produced module will be "non-functional" and + composed of aten operators. Later this API will be deleted in favor of more general + torch.export API. + + Args: + f: nn.Module to be traced + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + Returns: + An nn.Module containing the traced method. + + """ + from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG + from torch.export.dynamic_shapes import _process_dynamic_shapes + + log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) + + assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance." + + if kwargs is None: + kwargs = {} + + constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes) + + # Do not decompose dropout for exported models, because in eval mode the dropout + # op disappears from the graph, which makes it difficult to switch to train mode. + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. 
+ decomp_table = { + op: op.decompose + for op in FunctionalTensor.maybe_aliasing_or_mutating_ops + if op != torch.ops.aten.dropout.default + } + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + m = torch._dynamo.export( + f, + constraints=constraints, + assume_static_by_default=True, + tracing_mode="symbolic", + decomposition_table=decomp_table, + pre_dispatch=True, + aten_graph=True, + _log_export_usage=False, + )( + *args, + **kwargs, + )[0] + + _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs) + + m.meta["inline_constraints"] = { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if re.match(r"^[if]\d+$", str(k)) + } + + if isinstance(f, torch.nn.Module): + from torch.export._trace import _restore_state_dict + _restore_state_dict(f, m) + + flat_args, _ = pytree.tree_flatten((args, kwargs or {})) + range_constraints = _process_constraints(fake_mode, m, 0, flat_args) + + module = _create_stateful_graph_module( + m, + range_constraints=range_constraints, + ) + + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: + + def _my_train(self, mode: bool = True): + ... + + def _my_eval(self): + ... + + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + +def save( + ep: ExportedProgram, + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + opset_version: Optional[Dict[str, int]] = None, +) -> None: + if not isinstance(ep, ExportedProgram): + raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}") + + from .serde.serialize import serialize, SerializedArtifact + from .serde.schema import SCHEMA_VERSION + artifact: SerializedArtifact = serialize(ep, opset_version) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with zipfile.ZipFile(f, 'w') as zipf: + # Save every field the SerializedArtifact to a file + assert isinstance(artifact.exported_program, bytes) + zipf.writestr("serialized_exported_program.json", artifact.exported_program) + zipf.writestr("serialized_state_dict.pt", artifact.state_dict) + zipf.writestr("serialized_constants.pt", artifact.constants) + + zipf.writestr('version', ".".join(map(str, SCHEMA_VERSION))) + + # Add extra files if provided + if extra_files: + for extra_file_name, content in extra_files.items(): + encoded_content = content.encode('utf-8') + zipf.writestr(f"extra_files/{extra_file_name}", encoded_content) + + +def load( + f: Union[str, os.PathLike, io.BytesIO], + *, + extra_files: Optional[Dict[str, Any]] = None, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ExportedProgram: + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + with zipfile.ZipFile(f, 'r') as zipf: + # Check the version + version = zipf.read('version').decode().split('.') + from .serde.schema import SCHEMA_VERSION + + assert len(version) == len(SCHEMA_VERSION) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + 
f"schema version {SCHEMA_VERSION}." + ) + + from .serde.serialize import deserialize, SerializedArtifact + + # Load serialized_ep and serialized_state_dict from the zip file + + serialized_exported_program: Optional[bytes] = None + serialized_state_dict: Optional[bytes] = None + serialized_constants: Optional[bytes] = None + + for file_info in zipf.infolist(): + file_content = zipf.read(file_info.filename) + + if file_info.filename == "serialized_exported_program.json": + serialized_exported_program = file_content + elif file_info.filename == "serialized_state_dict.json": + warnings.warn("This version of file is deprecated") + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.json": + warnings.warn("This version of file is deprecated") + serialized_constants = file_content + elif file_info.filename == "serialized_state_dict.pt": + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.pt": + serialized_constants = file_content + elif file_info.filename.startswith("extra_files"): + filename = file_info.filename.split("/", 1)[1] + extra_files[filename] = file_content.decode('utf-8') + + assert serialized_exported_program is not None + assert serialized_state_dict is not None + assert serialized_constants is not None + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + + return ep + + +def aot_compile( + f: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Note: this function is not stable yet + + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside, generates executable cpp code from the program, and returns + the path to the generated shared library + + Args: + f: the `nn.Module` or callable to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + options: A dictionary of options to control inductor + + disable_constraint_solver: Whether the dim constraint solver must be disabled. 
+ + Returns: + Path to the generated shared library + """ + from torch.export._trace import _export_to_torch_ir + from torch._inductor.decomposition import select_decomp_table + + constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes) + + if config.is_predispatch: + gm = torch.export._trace._export(f, args, kwargs, constraints, pre_dispatch=True).module() + else: + # We want to export to Torch IR here to utilize the pre_grad passes in + # inductor, which run on Torch IR. + gm = _export_to_torch_ir( + f, + args, + kwargs, + constraints, + disable_constraint_solver=disable_constraint_solver, + # Disabling this flag, because instead we can rely on the mapping + # dynamo_flat_name_to_original_fqn which is coming from Dynamo. + restore_fqn=False, + ) + flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {})) + + with torch.no_grad(): + so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options) # type: ignore[arg-type] + + return so_path + +def aot_load(so_path: str, device: str) -> Callable: + """ + Loads a shared library generated by aot_compile and returns a callable + + Args: + so_path: Path to the shared library + + Returns: + A callable + """ + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a873e020bcef25adda07cd8fe067e024c37627 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f3f52b4fece6c2f40f8867bf5c57a18d3bdb63c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/exported_program.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/exported_program.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff05a230c447af0434ca2940daf32afb067fc987 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/exported_program.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3abfbec9fcf44398b1a49abf1a99c6d4f6f6b16 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edcc8c3d545d734b12cab912584e891c2fbff8ca Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d51de431c6c9a0b8848d9dbf6608fbda7a73d3 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaca8665b1b7e4a39c4f67b1088ab85fe3f832e7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b18aab158e50dea5f34efab89f7b8e162b8a643f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/__init__.py b/MLPY/Lib/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
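The torch/_export/__init__.py file added above defines save() and load(), which round-trip an ExportedProgram through a zip archive containing the serialized program, state dict, constants, a schema version stamp, and any user-supplied extra_files. The snippet below is a minimal usage sketch, not part of the patch: it assumes the public torch.export.export entry point from the same PyTorch release, and TinyModule is an illustrative name only.

import io

import torch
import torch._export


class TinyModule(torch.nn.Module):
    # Illustrative module only; any exportable nn.Module would do.
    def forward(self, x):
        return torch.relu(x) + 1.0


# torch.export.export comes from the public export API of the same release
# (it is not defined in this file); it produces the ExportedProgram that
# save()/load() round-trip through a zip archive.
ep = torch.export.export(TinyModule(), (torch.randn(4, 8),))

buf = io.BytesIO()
# save() writes serialized_exported_program.json, the state dict, constants,
# a schema version stamp, and any extra_files entries into one zip archive.
torch._export.save(ep, buf)

buf.seek(0)
reloaded = torch._export.load(buf)
print(reloaded.module()(torch.randn(4, 8)).shape)

The aot_compile()/aot_load() pair in the same file follows a similar shape: aot_compile() returns the path to a compiled shared library, and aot_load(so_path, "cpu") (or a "cuda" device string) wraps it in a callable that flattens inputs and unflattens outputs via the stored call spec.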
diff --git a/MLPY/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb7e1b1135a74a20ff51df42755a6e22571b1f1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7cf69d8542739bbc1c3ca88b7bfb6e459d37cc8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017f6403ce886c3f68dc4eb7c3c6fc7f55ef6360 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6d44b63727c7f375f9b92f4b15c0b063a28c9e1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/case.py b/MLPY/Lib/site-packages/torch/_export/db/case.py new file mode 100644 index 0000000000000000000000000000000000000000..086d16b1a0c9d607b34c4f03c3b636956cdc29fe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/case.py @@ -0,0 +1,188 @@ +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple, Union +from types import ModuleType + +import torch + +_TAGS: Dict[str, Dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. + """ + + SUPPORTED = 1 + NOT_SUPPORTED_YET = 0 + + +class ExportArgs: + __slots__ = ("args", "kwargs") + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +InputsType = Union[Tuple[Any, ...], ExportArgs] + + +def check_inputs_type(x): + if not isinstance(x, (ExportArgs, tuple)): + raise ValueError( + f"Expecting inputs type to be either a tuple, or ExportArgs, got: {type(x)}" + ) + + +def _validate_tag(tag: str): + parts = tag.split(".") + t = _TAGS + for part in parts: + assert set(part) <= set( + string.ascii_lowercase + "-" + ), f"Tag contains invalid characters: {part}" + if part in t: + t = t[part] + else: + raise ValueError(f"Tag {tag} is not found in registered tags.") + + +@dataclass(frozen=True) +class ExportCase: + example_inputs: InputsType + description: str # A description of the use case. + model: torch.nn.Module + name: str + extra_inputs: Optional[InputsType] = None # For testing graph generalization. 
+ # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) + tags: Set[str] = field(default_factory=set) + support_level: SupportLevel = SupportLevel.SUPPORTED + dynamic_shapes: Optional[Dict[str, Any]] = None + + def __post_init__(self): + check_inputs_type(self.example_inputs) + if self.extra_inputs is not None: + check_inputs_type(self.extra_inputs) + + for tag in self.tags: + _validate_tag(tag) + + if not isinstance(self.description, str) or len(self.description) == 0: + raise ValueError(f'Invalid description: "{self.description}"') + + +_EXAMPLE_CASES: Dict[str, ExportCase] = {} +_MODULES: Set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {} + + +def register_db_case(case: ExportCase) -> None: + """ + Registers a user provided ExportCase into example bank. + """ + if case.name in _EXAMPLE_CASES: + if case.name not in _EXAMPLE_CONFLICT_CASES: + _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]] + _EXAMPLE_CONFLICT_CASES[case.name].append(case) + return + + _EXAMPLE_CASES[case.name] = case + + +def to_snake_case(name): + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +def _make_export_case(m, name, configs): + if not issubclass(m, torch.nn.Module): + raise TypeError("Export case class should be a torch.nn.Module.") + m = m() + + if "description" not in configs: + # Fallback to docstring if description is missing. + assert ( + m.__doc__ is not None + ), f"Could not find description or docstring for export case: {m}" + configs = {**configs, "description": m.__doc__} + return ExportCase(**{**configs, "model": m, "name": name}) + + +def export_case(**kwargs): + """ + Decorator for registering a user provided case into example bank. + """ + + def wrapper(m): + configs = kwargs + module = inspect.getmodule(m) + if module in _MODULES: + raise RuntimeError("export_case should only be used once per example file.") + + assert module is not None + _MODULES.add(module) + normalized_name = to_snake_case(m.__name__) + module_name = module.__name__.split(".")[-1] + if module_name != normalized_name: + raise RuntimeError( + f'Module name "{module.__name__}" is inconsistent with exported program ' + + f'name "{m.__name__}". Please rename the module to "{normalized_name}".' 
+ ) + + case = _make_export_case(m, module_name, configs) + register_db_case(case) + return case + + return wrapper + + +def export_rewrite_case(**kwargs): + def wrapper(m): + configs = kwargs + + parent = configs.pop("parent") + assert isinstance(parent, ExportCase) + key = parent.name + if key not in _EXAMPLE_REWRITE_CASES: + _EXAMPLE_REWRITE_CASES[key] = [] + + configs["example_inputs"] = parent.example_inputs + case = _make_export_case(m, to_snake_case(m.__name__), configs) + _EXAMPLE_REWRITE_CASES[key].append(case) + return case + + return wrapper + + +def normalize_inputs(x: InputsType) -> ExportArgs: + if isinstance(x, tuple): + return ExportArgs(*x) + + assert isinstance(x, ExportArgs) + return x diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__init__.py b/MLPY/Lib/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7085667b2a451ddcba198435be515d97eb2f3f0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,52 @@ +import glob +import importlib +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, +) + + +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py") +] + +# Import all module in the current directory. +from . import * # noqa: F403 + + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d213f48892a5362258d8f17152b5556ae7fe0868 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e56a8106b277a2ad8ff0612ba8fd151f876df81 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a77993402701d38184ac47e9be89aa4327fd00e8 Binary files /dev/null and 
b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c6ea78164a18845c0cc28300d1e9cb2e4bd81c9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3698f59ffd8fcabe91ba99906e1a3037fa04ff6c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59d2b46b7aa2858609ffbc2654578d2f0f3574f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f4280924ecaf5e943f9231f65a9cc467798835 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d09cb5ab75692b0f1c2caaa31e73feaf942735b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0866761fd265fbb590d0bbd2dd0d777a6566926 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1102dd642ab2d978aedf6f15f6174b6f66efb3e3 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e8bfd05e301e55815e0a9ef9bf1b0dd23352be6a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0812ec54b93d1532e48be2780e0a57a757034af8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950d387d7a4393942bb877661fa0d4ec7830ce50 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f583e3cba6fac03e245c4f783c2fc30e03879a7f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f006003516ba2729ada4af41b63607bf01a048b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc0269859d6378d0e2b5ab893414b4dabc1b174 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a816906161146bebd8b8bfd0a533bac91f69015 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aeaaa86ea3a64fcd77ad99557c5b7d50c233a8d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5092b910a24fb8bdbc94c77068b2d7160a094b28 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b4c30b30e79e9a6434262151130c546510de453 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bffab1a442688cdf0893e8a97db107fd7d80b2a7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e45755d4798633e39ffcbb760bc18d296d4f9c75 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..525103ae8bb92f52657fbee87c3180564fdef10f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b6be96b3268c50aefbd19506e6a7c7d89ca2991 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9188363730330871f269050d93cc45f039cbf49c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f29fbb78380b9a20dae18f559b6e23e9c38e5d8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a05eea1696dc488b4e97377e1b9ba6ddd029f207 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c93e96b8263319c4d0e53f36ba2ee60544e328a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386f484904f03b5813ab67f69001edd7cc2efc94 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ec76cadfb60bd7b7e1b4462cdee34a2a7d5b845 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93677443c666ea2e97d22ca4a4bb2429d5da95c7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df6477633a29f491039b60f43ee1d762b9fb7f7f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..661cfcf37e9ebedf65eb778e9b37d53e5075349e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..252b9158bd0f859293dd23f951c1a37ff3e8c847 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/torch_sym_min.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fca22cab8edc29cb9d9657628d83c9f41efb362 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5bfe0b0452c70d96a664419199f657f57938d9e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py b/MLPY/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d20d2ecf482c783f6278923595488c62ac3559 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,24 @@ +import torch +import torch._dynamo as torchdynamo + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.tensor(4)), + tags={"torch.escape-hatch"}, +) +class AssumeConstantResult(torch.nn.Module): + """ + Applying the `assume_constant_result` decorator to burn the result of non-traceable code into the graph as a constant. + """ + + def __init__(self): + super().__init__() + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/autograd_function.py b/MLPY/Lib/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..a9093b370e85a10e808221b5b9d6341530af9c85 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,26 @@ +import torch + +from torch._export.db.case import export_case + + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + 1 + + +@export_case( + example_inputs=(torch.randn(3, 2),), +) +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend + using `allow_in_graph` to mitigate this problem.
+ """ + + def forward(self, x): + return MyAutogradFunction.apply(x) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/class_method.py b/MLPY/Lib/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..aafe70b02cd5441fd27d87bb4f7d2fe1bcc73fbd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,24 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 4),), +) +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. + """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..450f08ff50be2b6a7651bf40526d9a236ecded01 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,46 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + + +@export_case( + example_inputs=(torch.ones(3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self): + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e84f11edc7ccdbd07811f12e9c3b601bf0cf04 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,44 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. 
shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..46070590037719de7047f36a56624c70f28f143b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,63 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(6),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. + ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..e3745271bc911ffa9097bdfd6457fe1e8ba14c07 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.tensor(True), torch.ones(3, 2)), + tags={"torch.cond", "python.closure"}, +) +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. + """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_operands.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..d225c2a39e33ad50bf32ed9ddddd4350e3f51ea8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,39 @@ +import torch + +from torch._export.db.case import export_case +from torch.export import Dim +from functorch.experimental.control_flow import cond + +x = torch.randn(3, 2) +y = torch.ones(2) +dim0_x = Dim("dim0_x") + +@export_case( + example_inputs=(x, y), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, + extra_inputs=(torch.randn(2, 2), torch.ones(2)), + dynamic_shapes={"x": {0: dim0_x}, "y": None}, +) +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def __init__(self): + super().__init__() + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/cond_predicate.py b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..df23cc4df7a52ba8c1fd8b1946b46d31e3d9d248 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,29 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import cond + + +@export_case( + example_inputs=(torch.ones(6, 4, 3),), + tags={ + "torch.cond", + "torch.dynamic-shape", + }, +) +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..f6274acf9dde24214fbf36ec3daad3a6cbf84c58 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,27 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.tensor(4),), + tags={ + "torch.dynamic-value", + "torch.escape-hatch", + }, +) +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide a hint so that we + can trace further. Please look at the constrain_as_value and constrain_as_size APIs. + constrain_as_size is used for values that NEED to be used for constructing + a tensor. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + a = x.item() + torch._constrain_as_size(a, min=0, max=5) + return torch.ones((a, 5)) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..fa32144602c18d5b0456aa6b4f37e4f7457f0e0d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,30 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.tensor(4), torch.randn(5, 5)), + tags={ + "torch.dynamic-value", + "torch.escape-hatch", + }, +) +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide a hint so that we + can trace further. Please look at the constrain_as_value and constrain_as_size APIs. + constrain_as_value is used for values that don't need to be used for constructing + a tensor. + """ + + def __init__(self): + super().__init__() + + def forward(self, x, y): + a = x.item() + torch._constrain_as_value(a, min=0, max=5) + + if a < 6: + return y.sin() + return y.cos() diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/decorator.py b/MLPY/Lib/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e56fb94535fa367864f2e74d74e736619d4be8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,26 @@ +import functools + +import torch + +from torch._export.db.case import export_case + + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(3, 2)), +) +class Decorator(torch.nn.Module): + """ + Decorator calls are inlined into the exported function during tracing.
+ """ + + @test_decorator + def forward(self, x, y): + return x + y diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dictionary.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..aa52d5a91519486d911e417f2cd95cd924381319 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.tensor(4)), + tags={"python.data-structure"}, +) +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened during tracing. + """ + def __init__(self): + super().__init__() + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c79d37d413da7c22e60beb3cfed4b722810527 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.assert"}, +) +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..28cc1eea66307807d2c277bad6667a589c4002c8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,19 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ones(x.shape[0] * 2) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..977e5e95276297fe4dffd5b3f1a2d22603c3ebbd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2, 2),), + tags={"torch.dynamic-shape", "python.control-flow"}, +) +class DynamicShapeIfGuard(torch.nn.Module): + """ + An `if` statement with a backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. However, export will fail if + the dimension is marked as dynamic by a higher-level API.
+ """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..3c09a72f528fd8c3003db9b24178f4430d2bf73b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from functorch.experimental.control_flow import map + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(2)), + tags={"torch.dynamic-shape", "torch.map"}, +) +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def __init__(self): + super().__init__() + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..df35d45193681199307d7b580fb99b8e7f8ba6b9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,24 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel +from torch.export import Dim + +x = torch.ones(3, 2) +dim0_x = Dim("dim0_x") + +@export_case( + example_inputs=(x,), + tags={"torch.dynamic-shape", "python.builtin"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, + dynamic_shapes={"x": {0: dim0_x}}, +) +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x[: round(x.shape[0] / 2)] diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9b50b38ce63f68105a750cca0ebd7357c0239c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,20 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py new file mode 100644 index 0000000000000000000000000000000000000000..c414df8c8dbadd772ffca59ab08c20392005fff2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(10, 10),), + tags={"torch.dynamic-shape"}, +) +class DynamicShapeView(torch.nn.Module): + """ + Dynamic shapes should be propagated to view arguments instead of being + baked into the exported graph. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + new_x_shape = x.size()[:-1] + (2, 5) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/MLPY/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbee5fc57cb18b6ee91a823d488e92e7a1d8b62 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,32 @@ +import torch + +from torch._export.db.case import export_case, ExportArgs, SupportLevel + + +@export_case( + example_inputs=ExportArgs( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)], + mykw0=torch.randn(4), + input0=torch.randn(4), input1=torch.randn(4) + ), + tags={"python.data-structure"}, + support_level=SupportLevel.SUPPORTED, +) +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + def __init__(self): + super().__init__() + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/list_contains.py b/MLPY/Lib/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..0c9eba71529240e13850d6bf033d0c6b550c0ab1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,21 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.dynamic-shape", "python.data-structure", "python.assert"}, +) +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/list_unpack.py b/MLPY/Lib/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..d68c5cf0f2a917a794e54ea91103c168f3c85729 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,27 @@ +from typing import List + +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=([torch.ones(3, 2), torch.tensor(4), torch.tensor(5)],), + tags={"python.control-flow", "python.data-structure"}, +) +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + + def __init__(self): + super().__init__() + + def forward(self, args: List[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. 
+ """ + x, *y = args + return x + y[0] diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py b/MLPY/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..8aca91755613189e46988075a7dfcb1e9547c0d1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,25 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.object-model"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation is not supported. + """ + + def __init__(self): + super().__init__() + self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/nested_function.py b/MLPY/Lib/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c6f90c86b1fb6905edc3874dc929011048a8cc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,27 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2), torch.ones(2)), + tags={"python.closure"}, +) +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + def __init__(self): + super().__init__() + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/null_context_manager.py b/MLPY/Lib/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b9856a9f41d0eae30d070d38fb9eec67f92e0779 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,26 @@ +import contextlib + +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.context-manager"}, +) +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + """ + Null context manager in Python will be traced out. 
+ """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/optional_input.py b/MLPY/Lib/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..2cbf1604c51e85029f7708312f7a5dcbbc0478f5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,19 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.randn(2, 3),), + tags={"python.object-model"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.ones(2, 3)): + if y is not None: + return x + y + return x diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py b/MLPY/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..efc565e7e507940f87da4dcb2290e17892f3050f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,20 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel +from torch.utils import _pytree as pytree + + +@export_case( + example_inputs=({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), + support_level=SupportLevel.SUPPORTED, +) +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + y, spec = pytree.tree_flatten(x) + return y[0] + 1 diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/scalar_output.py b/MLPY/Lib/site-packages/torch/_export/db/examples/scalar_output.py new file mode 100644 index 0000000000000000000000000000000000000000..eca92154efe844f824255995c6eab6071343012f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/scalar_output.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case +from torch.export import Dim + +x = torch.ones(3, 2) +dim1_x = Dim("dim1_x") + +@export_case( + example_inputs=(x,), + tags={"torch.dynamic-shape"}, + dynamic_shapes={"x": {1: dim1_x}}, +) +class ScalarOutput(torch.nn.Module): + """ + Returning scalar values from the graph is supported, in addition to Tensor + outputs. Symbolic shapes are captured and rank is specialized. + """ + def __init__(self): + super().__init__() + + def forward(self, x): + return x.shape[1] + 1 diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py b/MLPY/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdcbaa1a2ee3f840c4f01c1d5d37e27cef52a10 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,29 @@ +from enum import Enum + +import torch + +from torch._export.db.case import export_case + + +class Animal(Enum): + COW = "moo" + + +@export_case( + example_inputs=(torch.ones(3, 2),), +) +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. 
+ """ + + def __init__(self): + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/static_for_loop.py b/MLPY/Lib/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..cfdd70566f839315cdbe62022c63e08c6709add4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,22 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"python.control-flow"}, +) +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + ret = [] + for i in range(10): # constant + ret.append(i + x) + return ret diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/static_if.py b/MLPY/Lib/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..78b43cfd93d4a596c72a5a2d0003bfe0497416c2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,23 @@ +import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.ones(3, 2, 2),), + tags={"python.control-flow"}, +) +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py b/MLPY/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d86efe02decd7d7498d9fc82535a7464c3a0dc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,17 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.randn(3, 2), "attr"), + tags={"python.builtin"}, + support_level=SupportLevel.SUPPORTED, +) +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/torch_sym_min.py b/MLPY/Lib/site-packages/torch/_export/db/examples/torch_sym_min.py new file mode 100644 index 0000000000000000000000000000000000000000..e79a22b66e522a0f88b837c2a9c6bc4d1ffc69f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/torch_sym_min.py @@ -0,0 +1,17 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.operator"}, + support_level=SupportLevel.NOT_SUPPORTED_YET, +) +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. 
+ """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py b/MLPY/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..031328c7dc3afcdde43f40bf2dade03657172ee7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,41 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel, export_rewrite_case + + +class A: + @classmethod + def func(cls, x): + return 1 + x + + +@export_case( + example_inputs=(torch.ones(3, 4),), + tags={"python.builtin"}, + support_level=SupportLevel.SUPPORTED, +) +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + a = A() + return type(a).func(x) + + +@export_rewrite_case(parent=TypeReflectionMethod) +class TypeReflectionMethodRewrite(torch.nn.Module): + """ + Custom object class methods will be inlined. + """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return A.func(x) diff --git a/MLPY/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py b/MLPY/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..43906a88b15e172edb7d5b917b8d9a37b515f11e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py @@ -0,0 +1,18 @@ +import torch + +from torch._export.db.case import export_case, SupportLevel + + +@export_case( + example_inputs=(torch.ones(3, 2),), + tags={"torch.mutation"}, + support_level=SupportLevel.SUPPORTED, +) +class UserInputMutation(torch.nn.Module): + """ + Directly mutate user input in forward + """ + + def forward(self, x): + x.mul_(2) + return x.cos() diff --git a/MLPY/Lib/site-packages/torch/_export/db/gen_example.py b/MLPY/Lib/site-packages/torch/_export/db/gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..bcba6c92ef121ac11e77e657c58babeac4e79ad0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/gen_example.py @@ -0,0 +1,28 @@ +import os +import sys + +import torch._export.db.examples as examples + +TEMPLATE = '''import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.randn(3, 2),), + tags={{}}, +) +def {case_name}(x): + """ + """ + + return +''' + +if __name__ == "__main__": + assert len(sys.argv) == 2 + root_dir = examples.__name__.replace(".", "/") + assert os.path.exists(root_dir) + with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f: + print("Writing to", f.name, "...") + f.write(TEMPLATE.format(case_name=sys.argv[1])) diff --git a/MLPY/Lib/site-packages/torch/_export/db/logging.py b/MLPY/Lib/site-packages/torch/_export/db/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..44f68caff77429c97ce11153e0cf29a2550422a2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/db/logging.py @@ -0,0 +1,2 @@ +def exportdb_error_message(case_name: str): + return "" diff --git a/MLPY/Lib/site-packages/torch/_export/error.py b/MLPY/Lib/site-packages/torch/_export/error.py new file mode 100644 index 0000000000000000000000000000000000000000..12d2e594c11d381912fc357ca1250f9ae151bffd --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/_export/error.py @@ -0,0 +1,56 @@ +from enum import Enum + + +class ExportErrorType(Enum): + # User providing invalid inputs to either tracer, or other public facing APIs + INVALID_INPUT_TYPE = 1 + + # User returning values from their models that we don’t support. + INVALID_OUTPUT_TYPE = 2 + + # Generated IR does not conform to Export IR Specification. + VIOLATION_OF_SPEC = 3 + + # User’s code contains types and functionalities we don’t support. + NOT_SUPPORTED = 4 + + # User's code didn't provide necessary details for us to successfully trace and export. + # For example, we use a lot of decorators and ask users to annotate their model. + MISSING_PROPERTY = 5 + + # User is using an API without proper initialization step. + UNINITIALIZED = 6 + + +def internal_assert(pred: bool, assert_msg: str) -> None: + """ + This is exir's custom assert method. It internally just throws InternalError. + Note that the sole purpose is to throw our own error while maintaining similar syntax + as python assert. + """ + + if not pred: + raise InternalError(assert_msg) + + +class InternalError(Exception): + """ + Raised when an internal invariance is violated in EXIR stack. + Should hint users to report a bug to dev and expose the original + error message. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExportError(Exception): + """ + This type of exception is raised for errors that are directly caused by the user + code. In general, user errors happen during model authoring, tracing, using our public + facing APIs, and writing graph passes. + """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/MLPY/Lib/site-packages/torch/_export/exported_program.py b/MLPY/Lib/site-packages/torch/_export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd7788a3da1403fb6c302cad7d7510096b9f904 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/exported_program.py @@ -0,0 +1,50 @@ +import warnings + + +import torch +import torch.fx + + +# TODO(ycao): This is added to avoid breaking existing code temporarily. +# Remove when migration is done. +from torch.export.graph_signature import ( + ExportBackwardSignature, + ExportGraphSignature, +) + +from torch.export.exported_program import ( + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) + + + +__all__ = [ + "ExportBackwardSignature", + "ExportGraphSignature", + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", +] + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter." 
+ ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm diff --git a/MLPY/Lib/site-packages/torch/_export/non_strict_utils.py b/MLPY/Lib/site-packages/torch/_export/non_strict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84e4cf6a9e2fd60e59b2769572b40ca6f57b7539 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/non_strict_utils.py @@ -0,0 +1,258 @@ +import inspect +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, Union + +import torch +from torch._dynamo.source import ( + AttrSource, + GetItemSource, + LocalSource, + TensorProperty, + TensorPropertySource, +) +from torch._dynamo.variables.builder import TrackedFake +from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim +from torch._guards import Source +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Constraint +from torch.export.graph_signature import CustomObjArgument +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + EqualityConstraint, + ShapeEnv, + StatelessSymbolicContext, +) +from torch.utils._pytree import ( + GetAttrKey, + KeyPath, + MappingKey, + SequenceKey, + tree_map_with_path, +) + + +def key_path_to_source(kp: KeyPath) -> Source: + """ + Given a key path, return the source for the key path. + """ + source: Source = LocalSource("args") + for k in kp: + if isinstance(k, SequenceKey): + source = GetItemSource(source, k.idx) + elif isinstance(k, MappingKey): + source = GetItemSource(source, k.key) + elif isinstance(k, GetAttrKey): + source = AttrSource(source, k.name) + else: + raise ValueError(f"Unknown KeyEntry {k}") + + return source + + +def _is_constant_argument(t): + return t is None or isinstance(t, (int, float, bool, str)) + + +def fakify( + mode: FakeTensorMode, + kp: KeyPath, + t: Any, + t_constraints: Dict[int, Dict[int, Constraint]], + sources: Dict[Tuple[int, int], List[Source]], +): + source = key_path_to_source(kp) + if _is_constant_argument(t) or isinstance(t, torch.ScriptObject): + return t + if not isinstance(t, torch.Tensor): + raise ValueError(f"Unsupported input type {type(t)}") + n_dims = len(t.shape) + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * n_dims, + constraint_sizes=[None] * n_dims, + ) + t_id = id(t) + if t_id in t_constraints: + for i, constraint in t_constraints[t_id].items(): + symbolic_context.constraint_sizes[i] = constraint.constraint_range + symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC + src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) + sources[(t_id, i)].append(src) + mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name + fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) + mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) + return fake + + +def make_fake_params_buffers( + fake_mode: FakeTensorMode, + params_buffers: Dict[str, torch.Tensor], +) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: + faked_params_buffers = {} + for key, value in params_buffers.items(): + faked_params_buffers[key] = fake_mode.from_tensor(value, static_shapes=True) + return faked_params_buffers + + +def make_fake_inputs(nn_module, args, kwargs, constraints): + """ + Given an nn module, example inputs, and constraints, return a new fake mode, + fake inputs created in that mode whose dynamic shape dimensions are constrained + by the 
given ranges, and sources for pairs of dynamic shape dimensions that are + constrained to be equal. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following pre-tracing steps: + # - Fakify inputs. + # - Process input shape equalities. + # In strict, these steps are spread across multiple files: + # - output_graph.py fakifies inputs. + # - [post-tracing] guards.py processes input shape equalities. + + t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) + for constraint in constraints: + t_constraints[constraint.t_id][constraint.dim] = constraint + if constraint.shared is not None: + t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint + + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + + fake_mode = FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), + allow_non_fake_inputs=True, + ) + if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: + raise ValueError( + "Detected fake_mode does not have a shape_env with tracked fakes. " + "If you constructed the module under a FakeTensorMode, " + "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" + ) + + with fake_mode: + original_signature = inspect.signature(nn_module.forward) + sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) + fake_args, fake_kwargs = tree_map_with_path( + lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), + (args, kwargs), + ) + + from sympy import Symbol + + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in constraints: + torch.export.dynamic_shapes._process_equalities( + constraint, + lambda t_id, dim: sources[(t_id, dim)], + fake_mode.shape_env, + source_pairs, + derived_equalities, + phantom_symbols, + ) + + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature + + +def make_constraints( + fake_mode, + equalities_inputs, + original_signature, + gm, +): + """ + Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, + and a graph module, produce guards on the fake mode's shape env (raising constraint + violations if any), solve (to suggest simplifications or fixes), and return the + resulting range constraints and equality constraints. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following post-tracing steps: + # - Produce guards. + # - Solve constraints. + # - Install shape metadata in IR. + # In strict, these steps are spread across multiple files: + # - guards.py produces guards. + # - eval_frame.py solves constraints + # - _trace.py installs shape metadata in IR. 
+ + shape_env = fake_mode.shape_env + placeholders = [tf.fake for tf in shape_env.tracked_fakes] + sources = [tf.source for tf in shape_env.tracked_fakes] + input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] + constraint_violation_error = None + try: + shape_env.produce_guards( + placeholders, + sources, + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + ignore_static=False, + ) + except ConstraintViolationError as e: + constraint_violation_error = e + + shape_env.frozen = True + dim_constraints = shape_env.dim_constraints + if dim_constraints is None: + # Expected when shape_env.produce_guards throws an early constraint violation error. + # There is nothing to solve for in this case. + # TODO(avik): Maybe record the constraint violation error instead and replay later? + assert constraint_violation_error + raise constraint_violation_error + dim_constraints.solve() + dim_constraints.remove_redundant_dynamic_results() + forced_specializations = dim_constraints.forced_specializations() + msg = dim_constraints.prettify_results( + original_signature, constraint_violation_error, forced_specializations + ) + if constraint_violation_error: + constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + elif forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + if constraint_violation_error: + raise constraint_violation_error + + range_constraints = {} + input_dims = defaultdict(list) + free_symbols = set() + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + if _is_constant_argument(node.meta["val"]) or isinstance( + node.meta["val"], CustomObjArgument + ): + continue + for i, d in enumerate(node.meta["val"].shape): + if isinstance(d, torch.SymInt): + # Look up the range constraint for the symbol corresponding to this shape dimension + # and store it indexed by the symbolic expression corresponding to it. + # NOTE(avik): Use node._expr instead of node.expr for the lookup here because + # we want the symbol, not its replacement, which could be an expression. Maybe + # there's a better way to do this, e.g., by (re)computing value ranges for expressions? + range_constraints[d.node.expr] = shape_env.var_to_range[d.node._expr] + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) + free_symbols.update(d.node.expr.free_symbols) + + for symbol in free_symbols: + if symbol not in range_constraints: + # Placeholders can have symbolic shapes that are derived expressions. + # The above code will record direct range constraints for them + # so that we can do runtime assertions. In addition, for serde checks + # we want to record range constraints for their root symbols. 
+ range_constraints[symbol] = shape_env.var_to_range[symbol] + + return range_constraints diff --git a/MLPY/Lib/site-packages/torch/_export/pass_base.py b/MLPY/Lib/site-packages/torch/_export/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..23e187ee6b7213d4c722200bee34946ecc0025bf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/pass_base.py @@ -0,0 +1,435 @@ +import operator +import traceback +import typing +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from functorch.experimental.control_flow import _unstack_pytree +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor, UnsupportedFakeTensorException +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import traceback as fx_traceback +from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.graph import CodeGen +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils import _pytree as pytree + + +__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] + + +Argument = Any +Value = Any +Fn = Callable[..., Any] +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +_TORCH_SYM_OPS: Set[Callable] = { + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, +} + + +class ExportPassBaseError(RuntimeError): + pass + + +class _ExportPassBaseDeprecatedDoNotUse(PassBase): + """ + Interpreter-based pass class to help users maintain the IR spec while writing + transformations. 
+ """ + + @staticmethod + def _create_dummy_node_metadata(): + return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) + + + class ExportTracer(PythonKeyTracer): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: + super().__init__() + self.callback = callback + self.root = torch.nn.Module() + self.graph = torch.fx.Graph() + self.graph.set_codegen(codegen) + self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.submodules: Dict[torch.nn.Module, str] = {} + + def trace(self) -> None: + raise ExportPassBaseError("ExportTracer doesn't support trace().") + + def create_arg(self, a: Argument) -> torch.fx.Node: + if isinstance(a, torch.nn.Module): + if a not in self.submodules: + name_submodule = f"submodule_{len(self.submodules)}" + self.root.add_module(name_submodule, a) + self.submodules[a] = name_submodule + elif isinstance(a, FakeTensor): + if not hasattr(a, "constant") or a.constant is None: + raise ExportPassBaseError(f"Cannot add {a} to graph.") + a = a.constant + node = super().create_arg(a) + if ( + isinstance(a, torch.Tensor) + and isinstance(node, torch.fx.Node) + and node.op == "get_attr" + ): + self.set_metadata(node, a) + self.callback.on_attr(ProxyValue(a, node)) + return node + + def set_metadata( + self, node: torch.fx.Node, value: Argument, + ) -> None: + # propagate the fake tensor or sym nodes + def make_val( + x: Argument, + ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: + if isinstance(x, FakeTensor): + return x + elif isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + # TODO we should allocate static shapes + # for param/buffer values + if isinstance(x, torch.nn.Parameter): + fake_tensor = self.fake_tensor_mode.from_tensor( + x, static_shapes=True + ) + else: + fake_tensor = self.fake_tensor_mode.from_tensor(x) + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + print( + "Fakeifying a Tensor subclass is not supported \ + right now. Instead a TensorMetadata is used." 
+ ) + fake_tensor = None + return fake_tensor + elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): + return x + else: + return None + + node.meta["val"] = pytree.tree_map(make_val, value) + + # Set the tensor_metadata for values that do not have a corresponding FakeTensor + def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: + if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + _ = self.fake_tensor_mode.from_tensor(x) + tensor_meta = None + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + tensor_meta = _extract_tensor_metadata(x) + return tensor_meta + else: + return None + + node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + + class ExportInterpreter(fx.Interpreter): + def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: + super().__init__(gm) + self.callback = callback + self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + + def placeholder( + self, + target: str, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + arg = super().placeholder(target, args, kwargs) + return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) + + def output( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + return self.callback.output(args[0], NodeMetadata(self.node.meta)).data + + def call_function( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> ProxyValue: + meta = NodeMetadata(self.node.meta) + + if target == operator.getitem: + value, key = args + return self.callback.call_getitem(value, key, meta) + elif getattr(target, "__module__", None) in {"_operator", "math"}: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif target in _TORCH_SYM_OPS: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return self.callback.call_operator( + target, + args, + kwargs, + meta, + ) + elif target == torch.ops.higher_order.cond: + pred, true_fn, false_fn, inputs = args + return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) + elif target == torch.ops.higher_order.map_impl: + f, mapped_args, operands = args # type: ignore[assignment] + return self.callback.call_map(f, mapped_args, operands, meta) + # For other unregistered HigherOrderOps, just interpret them blindly + elif isinstance(target, torch._ops.HigherOrderOperator): + return self.callback._fx( + "call_function", + target, + args, + kwargs, + meta, + ) + else: + raise ExportPassBaseError(f"Unsupported target type: {target}") + + def get_attr( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> Argument: + return super().get_attr(target, args, kwargs) + + def call_module( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + ) -> None: + raise ExportPassBaseError("call_module is not supported.") + + def call_method( + self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> None: + raise ExportPassBaseError("call_method is not supported.") + + def run_node(self, n: 
torch.fx.Node) -> Argument: + self.node = n + self.callback.node_debug_str = n.format_node() + return super().run_node(n) + + def __init__(self) -> None: + self.interpreter = torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + self.tracer = self.ExportTracer(self, CodeGen()) + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self._initialized = True + self.node_debug_str: typing.Optional[str] = None + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + args_data, kwargs_data = pytree.tree_map_only( + ProxyValue, lambda x: x.data, (args, kwargs) + ) + res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) + args_proxy, kwargs_proxy = pytree.tree_map_only( + ProxyValue, lambda x: x.proxy, (args, kwargs) + ) + + name = None + if isinstance(target, torch._ops.OpOverload): + name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) + + res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) + res_proxy.node.meta.update(meta.data) + self.tracer.set_metadata(res_proxy.node, res_data) + return ProxyValue(res_data, res_proxy) + + def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: + # TODO(angelayi): Update this with what we decide to do for metadata in + # the exported graph module + if (args := graph_module.meta.get("args", None)) is not None: + return list(args) + + def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: + if "val" in node.meta: + fake = node.meta["val"] + if hasattr(fake, "constant") and fake.constant is not None: + return fake.constant + return fake + elif tensor_meta := node.meta.get("tensor_meta"): + assert self.fake_tensor_mode is not None + return FakeTensor( + self.fake_tensor_mode, + torch.empty( + tensor_meta.shape, + dtype=tensor_meta.dtype, + device="meta", + requires_grad=tensor_meta.requires_grad, + memory_format=tensor_meta.memory_format, + ), + torch.device("cpu"), + ) + elif len(node.users) == 0: + return None + raise ExportPassBaseError( + f"Cannot construct an input for graph module: {graph_module}.", + ) + + return [ + extract_input(node) + for node in graph_module.graph.nodes + if node.op == "placeholder" + ] + + def on_attr(self, attr: ProxyValue) -> None: + pass + + def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: + arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) + arg_proxy.node.meta = meta.data + self.tracer.set_metadata(arg_proxy.node, arg) + return ProxyValue(arg, arg_proxy) + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", op, args, kwargs, meta) + + def call_sym( + self, + target: Fn, + args: Tuple[Argument, ...], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", target, args, {}, meta) + + def call_cond( + self, + pred: ProxyValue, + true_fn: torch.fx.GraphModule, + false_fn: torch.fx.GraphModule, + inputs: List[Argument], + meta: NodeMetadata, + ) -> ProxyValue: + true_branch = self.call_submodule(true_fn, tuple(inputs)) + false_branch = self.call_submodule(false_fn, tuple(inputs)) + assert true_branch is not None + assert false_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.cond, + (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), + {}, + 
meta, + ) + + def call_map( + self, + f: torch.fx.GraphModule, + mapped_args: List[ProxyValue], + operands: List[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + xs = _unstack_pytree([arg.data for arg in mapped_args])[0] + f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) + assert f_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.map_impl, + (f_branch.graph_module, mapped_args, operands), + {}, + meta, + ) + + def call_getitem( + self, value: ProxyValue, key: int, meta: NodeMetadata + ) -> ProxyValue: + return self._fx("call_function", operator.getitem, (value, key), {}, meta) + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + return self._fx("output", "output", (results,), {}, meta) + + def call_submodule( + self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] + ) -> PassResult: + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) + with fx_traceback.preserve_node_meta(): + interpreter.run(*inputs_data) + + new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + self.tracer = prev_tracer + self.interpreter = prev_interpreter + return PassResult( + new_graph_module, + True, + ) + + def call(self, graph_module: fx.GraphModule) -> PassResult: + if not getattr(self, "_initialized", False): + raise ExportPassBaseError( + "ExportPass is not initialized with __init__().", + ) + + inputs = self.inputs(graph_module) + + fake_tensor_mode = None + for i in inputs: + if isinstance(i, FakeTensor): + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." 
+ fake_tensor_mode = i.fake_mode + if fake_tensor_mode is None: + self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) + fake_tensor_mode = nullcontext() # type: ignore[assignment] + dispatcher_mode = nullcontext() # type: ignore[assignment] + else: + fake_tensor_mode.allow_non_fake_inputs = True + self.tracer.fake_tensor_mode = fake_tensor_mode + dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] + self.fake_tensor_mode = self.tracer.fake_tensor_mode + + with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] + result = self.call_submodule(graph_module, tuple(inputs)) + + return result diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/__init__.py b/MLPY/Lib/site-packages/torch/_export/pass_infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dcfaa11cf470664a75df9eb99cac9db6e8a4e4c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73550538234b4300e6936735a9c355bc833d0d26 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da2b6c2f0c5d4ec80ffe0310a4583192eed8b36c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/node_metadata.py b/MLPY/Lib/site-packages/torch/_export/pass_infra/node_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..a83ea3bb9eedadc349f4bdba8cbbf22850bb5afc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/pass_infra/node_metadata.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Set + + +NodeMetadataValue = Any + + +PROTECTED_KEYS: Set[str] = { + "val", + "stack_trace", + "nn_module_stack", + "debug_handle", + "tensor_meta", +} + + +class NodeMetadata: + def __init__(self, data: Dict[str, Any]) -> None: + self.data: Dict[str, Any] = data.copy() + + def __getitem__(self, key: str) -> NodeMetadataValue: + return self.data[key] + + def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: + if key in PROTECTED_KEYS: + raise RuntimeError(f"Could not override node key: {key}") + self.data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.data + + def copy(self) -> "NodeMetadata": + return NodeMetadata(self.data.copy()) diff --git a/MLPY/Lib/site-packages/torch/_export/pass_infra/proxy_value.py b/MLPY/Lib/site-packages/torch/_export/pass_infra/proxy_value.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0b90c8ddc6f0b2bde286b14edaa610394c054a --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/_export/pass_infra/proxy_value.py @@ -0,0 +1,41 @@ +# pyre-strict +from typing import Union + +import torch + + +class ProxyValue: + # pyre-ignore + def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]): + # pyre-ignore + self.data = data + self.proxy_or_node = proxy + + @property + def node(self) -> torch.fx.Node: + if isinstance(self.proxy_or_node, torch.fx.Node): + return self.proxy_or_node + assert isinstance(self.proxy_or_node, torch.fx.Proxy) + return self.proxy_or_node.node + + @property + def proxy(self) -> torch.fx.Proxy: + if not isinstance(self.proxy_or_node, torch.fx.Proxy): + raise RuntimeError( + f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self): + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__init__.py b/MLPY/Lib/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ad040cae5672be1b58bfe523d4fb57e41d2344 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a207bfb87ba5b3c9807d0be604f5f8e257f3479 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7284e2e6f899ad0095460dd0234407d140b997 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7206c4f192ecc2c1a008d964282738dad6cc9d56 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a10ecf49e7586a4fb78d706407d2b148d56b2ce Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..a5a81b99b27430b8ffa3b6963c410b15bf3c093b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65e76302967ab8d0c164923daf337c0b0680b32a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c488b63a14a95f3f06f8ca1fc560c91d636c4d5c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e2d36525e02ca7d32c59a820699a65fa837812e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_sym_size_ops_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b06419e739f00028c83fffe989390d3688a55cdb Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..c8457fb4d736163c9d16a83e4fcd2efd18149e19 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -0,0 +1,231 @@ +import math +import operator +import traceback +from functools import partial +from typing import Callable, Dict, List, NamedTuple, Set + +import sympy + +import torch +import torch.fx +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, ProxyValue, PassResult +from torch.utils._sympy.value_ranges import ValueRanges +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + +__all__ = ["InputDim"] + + +class InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError( + "Export constraints cannot be non-integer expressions" + ) + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class 
_AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBaseDeprecatedDoNotUse): + def __init__( + self, + range_constraints: Dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, proxy, lower, upper, assert_msg): + if lower > -math.inf: + self._insert_assert_async(operator.ge, proxy, lower, assert_msg) + + if upper < math.inf: + self._insert_assert_async(operator.le, proxy, upper, assert_msg) + + def _insert_assert_async(self, operator, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. + """ + self.counter += 1 + cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata()) + cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata()) + super().call_operator( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + self._create_dummy_node_metadata(), + ) + + def call_operator(self, op, args, kwargs, meta) -> ProxyValue: + ret = super().call_operator(op, args, kwargs, meta) + if "val" not in meta: + return ret + + val = meta["val"] + + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. + # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. + def add_assertions(val): + call_backs: List[Callable] = [] + messages: List[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." 
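A torch-free toy version of the deferred-callback pattern described in the comment block above: callbacks and their messages are collected while walking a nested value and only executed back at the top level. Nested tuples stand in for operator outputs, and `print` stands in for proxy construction; every name here is illustrative.

```python
from functools import partial
from typing import Callable, List, Tuple


def collect(value, path: str = "ret") -> Tuple[List[Callable], List[str]]:
    # Recursive walk: gather (callback, message) pairs without running them yet.
    callbacks: List[Callable] = []
    messages: List[str] = []
    if isinstance(value, int):
        callbacks.append(partial(print, f"check {path}"))
        messages.append(f"{path} must satisfy its range constraint")
    elif isinstance(value, tuple):
        for i, item in enumerate(value):
            cbs, msgs = collect(item, f"{path}[{i}]")
            callbacks.extend(cbs)
            messages.extend(msgs)
    return callbacks, messages


cbs, msgs = collect((3, (4, 5)))
for cb, msg in zip(cbs, msgs):  # executed only at the top level of the traversal
    cb()
    print(msg)
```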
+ call_backs.append( + partial(self._assert_range_constraint, lower=min_val, upper=max_val) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + def sym_size_cb(proxy, assert_msg, dim): + dim_proxy = super( + _AddRuntimeAssertionsForInlineConstraintsPass, + self + ).call_operator( + torch.ops.aten.sym_size.int, + (proxy, dim), + {}, + self._create_dummy_node_metadata(), + ) + cb(proxy=dim_proxy, assert_msg=assert_msg) + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(proxy=ret, assert_msg=f"{ret.node}" + msg) + return ret + + def call(self, graph_module): + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + # Add runtime asserts for inline constraints + val = super().call(graph_module) + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. + if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in val.graph_module.graph.nodes: + if not node.meta.get("stack_trace", None): + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + + return PassResult(val.graph_module, val.modified) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: Dict[sympy.Symbol, ValueRanges], +) -> Dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %scalar_tensor = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,), kwargs = {}) + # %_assert_async = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_async.msg: + continue + + scalar_tensor_arg = node.args[0] + if not ( + scalar_tensor_arg.op == "call_function" and + scalar_tensor_arg.target == torch.ops.aten.scalar_tensor.default + ): + continue + + compare_arg = scalar_tensor_arg.args[0] + if not ( + compare_arg.op == "call_function" and + compare_arg.target in (operator.le, operator.ge) and + len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + maybe_symint_arg, compare_int = compare_arg.args + + # x >= 0 will sometimes be canonicalized to -x <= 0, so in some + # cases the operation before the comparison is to multiply by -1. 
We + # can undo the canonicalization here + if ( + maybe_symint_arg.op == "call_function" and + maybe_symint_arg.target == operator.mul and + maybe_symint_arg.args[0] == -1 + ): + maybe_symint_arg = maybe_symint_arg.args[1] + compare_op = operator.ge + compare_int = -1 * compare_int + + if not ( + "val" in maybe_symint_arg.meta and + isinstance(maybe_symint_arg.meta["val"], torch.SymInt) + ): + continue + + symint = maybe_symint_arg.meta["val"].node.expr + if not isinstance(symint, sympy.Symbol): + continue + + if symint not in range_constraints: + raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") + + found_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) + + if compare_arg.target == operator.le: + existing_inline_assertions[symint] = ValueRanges( + lower=found_range.lower, upper=compare_int + ) + elif compare_arg.target == operator.ge: + existing_inline_assertions[symint] = ValueRanges( + lower=compare_int, upper=found_range.upper + ) + + return existing_inline_assertions diff --git a/MLPY/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..102bb87d75441317caa7fa4e0d0ef0ee5c89668c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,66 @@ +import operator + +import torch + +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. + """ + + def __init__(self, specs, sig) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm): + def get_arg_spec(arg): + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." 
+ ) + else: + return ConstantArgument(value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + for i, arg in enumerate(node.args): + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + self.specs[node.kwargs["path"]].inputs.append( + get_arg_spec(arg) + ) + elif kind == "module_call_outputs": + self.specs[node.kwargs["path"]].outputs.append( + get_arg_spec(arg) + ) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target == operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) diff --git a/MLPY/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5178bd6e3a7e6812a95409160728753a49a7c8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -0,0 +1,94 @@ +import copy +from typing import Dict, Optional, Tuple, List + +import torch +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._ops import OpOverload + +aten = torch.ops.aten + +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = { + aten.sym_constrain_range.default: aten._functional_sym_constrain_range, + aten._assert_async.msg: aten._functional_assert_async.msg, +} + + +class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Functionalize ops with side effect in graph module by replacing the op with + functional version of it. A new dependency token (`dep_token`) will be + created and propagated through functional ops to output. + For example: + ``` + def f(x): + sym_constrain_range(x.shape[0], min=1, max=3) + return x.add(3) + ``` + Will be transformed to: + ``` + def f(x): + dep_token0 = _make_dep_token() + dep_token1 = _functional_sym_constrain_range( + x.shape[0], min=1, max=3, dep_token=dep_token0 + ) + + return x.add(3), dep_token1 + ``` + """ + + def __init__(self) -> None: + super().__init__() + self._dep_token: Optional[ProxyValue] = None + self._next_dep_token_index: Optional[int] = None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Early return if no non-functional assertions. 
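The docstring above explains the `dep_token` chaining; the following torch-free sketch shows why threading such a token keeps side-effectful checks alive and correctly ordered in a purely functional pipeline. `functional_assert` is a made-up stand-in, not a torch operator.

```python
def functional_assert(cond: bool, msg: str, dep_token: int) -> int:
    # Consumes the previous token and returns a new one, so the check cannot be
    # dropped or reordered by a purely functional optimizer.
    if not cond:
        raise AssertionError(msg)
    return dep_token + 1


def f(x: int):
    token = 0                                                        # dep_token0
    token = functional_assert(1 <= x <= 3, "x out of range", token)  # dep_token1
    return x + 3, token                                              # token is part of the output


print(f(2))  # (5, 1)
```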
+ if not any( + n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS + for n in graph_module.graph.nodes + ): + return PassResult(graph_module=graph_module, modified=False) + + gm = copy.deepcopy(graph_module) + self._dep_token = None + self._next_dep_token_index = None + return super().call(gm) + + def call_operator( + self, + op: OpOverload, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: + return super().call_operator(op, args, kwargs, meta) + + if self._dep_token is None: + self._dep_token = super().call_operator( + aten._make_dep_token, + args=(), + kwargs={}, + meta=self._create_dummy_node_metadata(), + ) + self._dep_token.node.name = "dep_token0" + self._next_dep_token_index = 1 + + self._dep_token = super().call_operator( + _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], + args=args, + kwargs={**kwargs, "dep_token": self._dep_token}, + meta=meta, + ) + assert self._next_dep_token_index is not None + self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" + self._next_dep_token_index += 1 + + return self._dep_token + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + assert self._dep_token is not None + + return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/MLPY/Lib/site-packages/torch/_export/passes/lift_constants_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/lift_constants_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..253746402d79f9339183ca966fdb75aef1c5c683 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/lift_constants_pass.py @@ -0,0 +1,248 @@ +import collections +from typing import Any, Dict, Union + +import torch +from torch._export.verifier import SpecViolationError +from torch._guards import detect_fake_mode +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors and + ScriptObjects) as keys. We store tensors normally, but ScriptObjects are + stored by hash, because different torch.ScriptObjects can point to the same + underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self): + # Underlying dict that we use to implement this mapping. + self._constant_attrs: Dict[Union[int, torch.Tensor], Any] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. 
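A simplified sketch of the store-by-hash trick described in the `ConstantAttrMap` docstring above, using a stand-in class instead of `torch.ScriptObject`; two wrappers aliasing the same underlying value hash to the same key, so they deduplicate to one entry. All names here are hypothetical.

```python
class FakeScriptObject:
    """Stand-in for a script object wrapper; aliases share the payload's hash."""

    def __init__(self, payload: str) -> None:
        self.payload = payload

    def __hash__(self) -> int:
        return hash(self.payload)


attrs = {}                                # maps hash(obj) -> fully qualified name
a = FakeScriptObject("custom_obj")
b = FakeScriptObject("custom_obj")        # different wrapper, same underlying value
attrs[hash(a)] = "submodule.attr_name"
print(attrs[hash(b)])                     # 'submodule.attr_name' -> deduplicated
```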
+ self._script_object_map: Dict[int, torch.ScriptObject] = {} + + def __getitem__(self, key: Union[torch.Tensor, torch.ScriptObject]) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor)) + return self._constant_attrs[real_key] + + def __setitem__( + self, key: Union[torch.Tensor, torch.ScriptObject], value: Any + ) -> None: + if isinstance(key, torch.ScriptObject): + self._constant_attrs[hash(key)] = value + self._script_object_map[hash(key)] = key + elif isinstance(key, torch.Tensor): + self._constant_attrs[key] = value + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. + parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] + if len(parent_fqn) > 0: + return f"{parent_fqn}.{constant_name}" + else: + return constant_name + + +def lift_constants_pass( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]: + """ + Takes a graph module, graph signature, and modifies them implace to lift any + constants (tensors or custom classes) as inputs to the graph. Returns a + dictionary of names to constants. + + Arguments: + gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. + graph_signature (ExportGraphSignature): This graph signature will be + mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. + constant_attrs (ConstantAttr): A mapping from a constant value to its + fully-qualified path in `gm`. This is used to maintain consistent + location of constants between the original module and the exported + version. + + Returns: + A dictionary of fqn => constant value. + """ + all_constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] = {} + + inputs = graph_signature.input_specs + num_custom_obj = sum( + input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs + ) + num_tensor_constants = sum( + input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, None + for node in gm.graph.nodes: + if node.op == "placeholder" and node.name in graph_signature.user_inputs: + first_user_input = node + break + first_user_input_loc += 1 + + lifted_objs = ConstantAttrMap() + for node in gm.graph.nodes: + if node.op == "get_attr": + constant_val = getattr(gm, node.target) + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. 
+ const_placeholder_node = lifted_objs[constant_val] + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + continue + + # For ScriptObject and Tensor constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. + if isinstance(constant_val, torch.ScriptObject): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = constant_attrs.get(constant_val) + if constant_fqn is not None: + _, _, constant_name = constant_fqn.rpartition(".") + else: + constant_name = f"_lifted_custom_obj{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = constant_attrs.get(constant_val) + if constant_fqn is not None: + _, _, constant_name = constant_fqn.rpartition(".") + else: + constant_name = f"_lifted_tensor_constant{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + elif isinstance(constant_val, torch.fx.GraphModule): + continue + elif "LoweredBackendModule" in type(constant_val).__name__: + continue + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs[constant_val] = const_placeholder_node + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> Dict[str, Union[torch.Tensor, torch.ScriptObject]]: + """When tracing, we produce a graph with an actual 
ScriptObject in the + meta["val"]. Eventually we want to change this behavior, when FakeMode infra + for ScriptObjects lands. + + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for node in gm.graph.nodes: + if "val" not in node.meta or not isinstance( + node.meta["val"], torch.ScriptObject + ): + continue + + old_meta = node.meta["val"] + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants diff --git a/MLPY/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py b/MLPY/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..350e9893991f577c28f97a5c6977019f943d61f1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,26 @@ +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. + """ + + def call(self, graph_module) -> PassResult: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target == torch.ops.aten._assert_async.msg: + assert_async_node = node + if len(assert_async_node.users) > 0: + continue + module.graph.erase_node(assert_async_node) + # the upstream scalar_tensor <- {le, ge} <- sym_size + # linear chain of nodes of nodes is removed by the + # downstream dead code elimination + modified = True + return PassResult(graph_module, modified) diff --git a/MLPY/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..ba62a622ef499b075cd1f339e956ed1a522ecbc9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,141 @@ +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import ( + node_inline_, + node_replace_, + nodes_filter, + nodes_first, + nodes_map, + sequential_split, +) + + +def _is_set_grad_enabled_node(node: torch.fx.Node): + return ( + node + and node.op == "call_function" + and node.target == torch._C._set_grad_enabled + ) + + +def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False): + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target == torch._C._set_grad_enabled + ): + return ( + first_non_ph.args[0] != torch.is_grad_enabled() + if omit_if_same_with_ambient + else True + ) + return False + + +def _replace_with_hop(node: torch.fx.Node): + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node) + if len(set_grad_nodes) 
> 0: + assert len(set_grad_nodes) == 1 + set_grad_node = set_grad_nodes[0] + enable_grad_val = set_grad_node.args[0] + with graph.inserting_before(node): + get_attr_node = graph.get_attr(node.target) + output_node = next(iter(reversed(sub_gm.graph.nodes)), None) + if output_node is not None: + assert len(output_node.args) == 1 + output_args = output_node.args[0] + if isinstance(output_args, (tuple, list)): + call_func_node = graph.call_function( + wrap_with_set_grad_enabled, + (enable_grad_val, get_attr_node, *node.args), + {}, + ) + # Create the metadata + call_func_node.meta["val"] = tuple( + arg.meta["val"] for arg in output_args + ) + node_replace_(node, call_func_node, delete_old=True) + + # Rename the name of getitem nodes to the actual name of its contents + # for passing verifier and better readability, also propagate metadata + for get_item_node in call_func_node.users.keys(): + idx: int = get_item_node.args[1] + output_node = output_args[idx] + get_item_node._rename(output_node.name) + get_item_node.meta = output_node.meta + pass + + elif isinstance(output_args, torch.fx.Node): + call_func_node = graph.create_node( + "call_function", + wrap_with_set_grad_enabled, + (enable_grad_val, get_attr_node, *node.args), + {}, + output_args.name, + ) + call_func_node.meta = output_args.meta + node_replace_(node, call_func_node, delete_old=True) + else: + raise NotImplementedError( + f"repalce_set_grad_with_hop_pass doesnt' support output type {type(output_args)}" + ) + else: + raise NotImplementedError( + "Cannot replace a call_module with a hop if it has no output. This module will gets DCEed." + ) + sub_graph.erase_node(set_grad_node) + + +def _remove_set_grad_and_inline(node: torch.fx.Node): + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + nodes_map( + sub_graph.nodes, + lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n, + ) + node_inline_(node) + + +def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule): + # If there is no set_grad_enabled node, return the original graph module + need_replacing = False + for node in gm.graph.nodes: + if _is_set_grad_enabled_node(node): + need_replacing = True + + if not need_replacing: + return gm + + new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + nodes_map( + list(new_gm.graph.nodes), + lambda node: _maybe_inline_or_replace_with_hop(node) + if node.op == "call_module" + else node, + ) + new_gm.graph.lint() + return new_gm diff --git a/MLPY/Lib/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..ef419ccf13ec12dd0f8356703a86c86541ff7649 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py @@ -0,0 +1,18 @@ +from typing import Dict + +import torch + +replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = { + torch.ops.aten.sym_size: torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default, +} + + +def _replace_sym_size_ops_pass(gm: 
torch.fx.GraphModule): + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target in replacements: + node.target = replacements[node.target] diff --git a/MLPY/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/MLPY/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d7ef8f62ffaf0c02e0f5fdc2de8f742b71b80f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -0,0 +1,71 @@ +from typing import Dict, Optional, Set + +import torch +from torch._ops import OpOverload, OpOverloadPacket, HigherOrderOperator +from torch._export.error import InternalError +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse + + +__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] + + +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = { + torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, +} + +# TODO (tmanlaibaatar) remove this after https://github.com/pytorch/pytorch/pull/100749 +_BLACK_LISTED_OPS: Set[OpOverloadPacket] = { + torch.ops.aten.sym_size, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_numel, +} + +def is_view_op(schema: torch._C.FunctionSchema) -> bool: + if len(schema.arguments) == 0: + return False + alias_info = schema.arguments[0].alias_info + return (alias_info is not None) and (not alias_info.is_write) + + +def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: + if is_view_op(schema) and schema.name.startswith("aten::"): + view_op_name = schema.name.split("::")[1] + view_op_overload = ( + schema.overload_name + if schema.overload_name != "" + else "default" + ) + view_copy_op_name = view_op_name + "_copy" + if not hasattr(torch.ops.aten, view_copy_op_name): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name) + + if not hasattr(view_copy_op_overload_packet, view_op_overload): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + return getattr(view_copy_op_overload_packet, view_op_overload) + + return None + + +class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Our backend expects pure functional operators. For efficiency + purposes, we keep view ops around while functionalizing the exported + program. This pass replaces view ops with view copy ops for backends that + need AOT memory planning. 
+ """ + def call_operator(self, op, args, kwargs, meta): + if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: + return super().call_operator( + (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta + ) + + if op in _BLACK_LISTED_OPS or isinstance(op, HigherOrderOperator): + return super().call_operator(op, args, kwargs, meta) + + if view_copy_op := get_view_copy_of_view_op(op._schema): + return super().call_operator(view_copy_op, args, kwargs, meta) + + return super().call_operator(op, args, kwargs, meta) diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__init__.py b/MLPY/Lib/site-packages/torch/_export/serde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04145db13a2851ad1182c4c119d5e94fd451b0f6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62cb6611cdf644620503aa2d39bd454aaece2d74 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59653df7695d0c15ed021f06f92350d9efa2c9b8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f067437eb665beed578023e9d73142a27727d4d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3891a74a84760fef9b69a38715b5ab32bb35f60e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63055613e48eeb3cd0b0cf704f96116864aff249 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_export/serde/schema.py b/MLPY/Lib/site-packages/torch/_export/serde/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c67046d551664cc8c04bf2f02eafea2e4076a0d6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/schema.py @@ -0,0 +1,346 @@ +# NOTE: This is a placeholder for iterating on export serialization schema design. 
+# Anything is subject to change and no guarantee is provided at this point.
+
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Dict, List, Optional, Tuple
+
+from torch._export.serde.union import _Union
+
+# NOTE: Please update this value if any modifications are made to the schema
+SCHEMA_VERSION = (5, 1)
+TREESPEC_VERSION = 1
+
+
+class ScalarType(IntEnum):
+    UNKNOWN = 0
+    BYTE = 1
+    CHAR = 2
+    SHORT = 3
+    INT = 4
+    LONG = 5
+    HALF = 6
+    FLOAT = 7
+    DOUBLE = 8
+    COMPLEXHALF = 9
+    COMPLEXFLOAT = 10
+    COMPLEXDOUBLE = 11
+    BOOL = 12
+    BFLOAT16 = 13
+
+
+class Layout(IntEnum):
+    Unknown = 0
+    SparseCoo = 1
+    SparseCsr = 2
+    SparseCsc = 3
+    SparseBsr = 4
+    SparseBsc = 5
+    _mkldnn = 6
+    Strided = 7
+
+
+class MemoryFormat(IntEnum):
+    Unknown = 0
+    ContiguousFormat = 1
+    ChannelsLast = 2
+    ChannelsLast3d = 3
+    PreserveFormat = 4
+
+
+@dataclass
+class Device:
+    type: str
+    index: Optional[int] = None
+
+
+@dataclass(repr=False)
+class SymExprHint(_Union):
+    as_int: int
+    as_float: float
+    as_bool: bool
+
+
+# This is for storing the symbolic expressions behind symints/symfloats/symbools
+# For example, we can get something like
+# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
+# if we also have the hint that s0 and s1 are both 2.
+@dataclass
+class SymExpr:
+    expr_str: str
+    hint: Optional[SymExprHint] = None
+
+
+@dataclass(repr=False)
+class SymInt(_Union):
+    as_expr: SymExpr
+    as_int: int
+
+
+@dataclass(repr=False)
+class SymBool(_Union):
+    as_expr: SymExpr
+    as_bool: bool
+
+
+@dataclass
+class TensorMeta:
+    dtype: ScalarType
+    sizes: List[SymInt]
+    requires_grad: bool
+    device: Device
+    strides: List[SymInt]
+    storage_offset: SymInt
+    layout: Layout
+
+
+# In most cases we will use the "as_name" field to store arguments which are
+# SymInts.
+# The "as_int" field is used in the case where we have a list containing a mix
+# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
+# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
+# to the "as_int" field.
+@dataclass(repr=False)
+class SymIntArgument(_Union):
+    as_name: str
+    as_int: int
+
+
+# In most cases we will use the "as_name" field to store arguments which are
+# SymBools.
+# The "as_bool" field is used in the case where we have a list containing a mix
+# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
+# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
+# to the "as_bool" field.
+@dataclass(repr=False)
+class SymBoolArgument(_Union):
+    as_name: str
+    as_bool: bool
+
+
+@dataclass
+class TensorArgument:
+    name: str
+
+
+# This is used for storing the contents of a list which contains optional tensors
+# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
+# type List[OptionalTensorArgument], with tensor values serialized to the
+# "as_tensor" field, and None values serialized to the "as_none" field.
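+# For example, a Tensor?[] argument holding [t0, None, t1] is represented by three
+# OptionalTensorArgument entries: as_tensor="t0", as_none=(), as_tensor="t1"
+# (the "as_tensor" field stores the tensor's name, not the tensor itself).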
+@dataclass(repr=False) +class OptionalTensorArgument(_Union): + as_tensor: str + as_none: Tuple[()] + + +@dataclass +class GraphArgument: + name: str + graph: 'Graph' + + +@dataclass +class CustomObjArgument: + name: str + class_fqn: str + + +# This is actually a union type +@dataclass(repr=False) +class Argument(_Union): + as_none: Tuple[()] + as_tensor: TensorArgument + as_tensors: List[TensorArgument] + as_int: int + as_ints: List[int] + as_float: float + as_floats: List[float] + as_string: str + as_strings: List[str] + as_sym_int: SymIntArgument + as_sym_ints: List[SymIntArgument] + as_scalar_type: ScalarType + as_memory_format: MemoryFormat + as_layout: Layout + as_device: Device + as_bool: bool + as_bools: List[bool] + as_sym_bool: SymBoolArgument + as_sym_bools: List[SymBoolArgument] + as_graph: GraphArgument + as_optional_tensors: List[OptionalTensorArgument] + as_custom_obj: CustomObjArgument + as_operator: str + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: str + arg: Argument + + +@dataclass +class Node: + target: str + inputs: List[NamedArgument] + outputs: List[Argument] + metadata: Dict[str, str] + + +@dataclass +class Graph: + inputs: List[Argument] + outputs: List[Argument] + nodes: List[Node] + tensor_values: Dict[str, TensorMeta] + sym_int_values: Dict[str, SymInt] + sym_bool_values: Dict[str, SymBool] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. + is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Argument + + +@dataclass +class InputToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class InputToBufferSpec: + arg: TensorArgument + buffer_name: str + persistent: bool + + + +@dataclass +class InputToTensorConstantSpec: + arg: TensorArgument + tensor_constant_name: str + + +@dataclass +class InputToCustomObjSpec: + arg: CustomObjArgument + custom_obj_name: str + + +@dataclass(repr=False) +class InputSpec(_Union): + user_input: UserInputSpec + parameter: InputToParameterSpec + buffer: InputToBufferSpec + tensor_constant: InputToTensorConstantSpec + custom_obj: InputToCustomObjSpec + + +@dataclass +class UserOutputSpec: + arg: Argument + + +@dataclass +class LossOutputSpec: + arg: TensorArgument + + +@dataclass +class BufferMutationSpec: + arg: TensorArgument + buffer_name: str + + +@dataclass +class GradientToParameterSpec: + arg: TensorArgument + parameter_name: str + + +@dataclass +class GradientToUserInputSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass +class UserInputMutationSpec: + arg: TensorArgument + user_input_name: str + + +@dataclass(repr=False) +class OutputSpec(_Union): + user_output: UserOutputSpec + loss_output: LossOutputSpec + buffer_mutation: BufferMutationSpec + gradient_to_parameter: GradientToParameterSpec + gradient_to_user_input: GradientToUserInputSpec + user_input_mutation: UserInputMutationSpec + + +@dataclass +class GraphSignature: + input_specs: List[InputSpec] + output_specs: List[OutputSpec] + + +@dataclass +class RangeConstraint: + min_val: int + max_val: int + + +@dataclass +class ModuleCallSignature: + inputs: List[Argument] + outputs: List[Argument] + + # These are serialized by calling 
pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: str + out_spec: str + + +@dataclass +class ModuleCallEntry: + fqn: str + signature: Optional[ModuleCallSignature] = None + + +@dataclass +class GraphModule: + graph: Graph + signature: GraphSignature + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: List[ModuleCallEntry] + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be upadted. +@dataclass +class SchemaVersion: + major: int # Major version number is bumped every time a breaking change is made. + minor: int # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: GraphModule + # Key is the opset namespace (ex. aten), and value is the version number + opset_version: Dict[str, int] + range_constraints: Dict[str, RangeConstraint] + schema_version: SchemaVersion + dialect: str diff --git a/MLPY/Lib/site-packages/torch/_export/serde/schema.yaml b/MLPY/Lib/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23e12619579ce9fe18c298d4d21160b69f2ab33b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,389 @@ +# @generated by update_schema.py +# checksum<<4c9986f3aba283b1746995fff8fe7005b370c7e288adec65c03030349a4bab60>> +Argument: + kind: union + fields: + as_none: + type: Tuple[()] + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str +Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + dialect: + type: str +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: 
str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str +InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: str + as_none: + type: Tuple[()] +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec +RangeConstraint: + kind: struct + fields: + min_val: + type: int + max_val: + type: int +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_float: + type: float + as_bool: + type: bool +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: 
SymInt + layout: + type: Layout +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 5 +- 1 +TREESPEC_VERSION: 1 diff --git a/MLPY/Lib/site-packages/torch/_export/serde/schema_check.py b/MLPY/Lib/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..adee0a3f450412f1252ffeead9833c4945232017 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,285 @@ +import dataclasses +import hashlib +import re +import typing +from enum import IntEnum +from typing import Any, Dict, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class SchemaUpdateError(Exception): + pass + + +def _check(x, msg): + if not x: + raise SchemaUpdateError(msg) + + +def _staged_schema(): + ret: Dict[str, Any] = {} + defs = {} + + def _handle_aggregate(ty): + def dump_type(t): + if isinstance(t, type): + return t.__name__ + elif isinstance(t, str): + assert t in defs + return t + elif o := typing.get_origin(t): + # Lemme know if there's a better way to do this. + if o == list: + head = "List" + elif o == dict: + head = "Dict" + elif o == tuple: + if typing.get_args(t) == (): + return "Tuple[()]" + head = "Tuple" + elif o == Union: + args = typing.get_args(t) + assert len(args) == 2 and args[1] == type(None) + return f"Optional[{dump_type(args[0])}]" + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + return ( + f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]" + ) + elif t == (): + return "()" + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + + def dump_field(f): + t = dump_type(f.type) + ret = {"type": t} + + value = dataclasses.MISSING + if f.default is not dataclasses.MISSING: + value = f.default + elif f.default_factory is not dataclasses.MISSING: + value = f.default_factory() + + if t.startswith("Optional[") and value is not None: + raise AssertionError( + f"Optional field {ty.__name__}.{f.name} must have default value to be None." 
+ ) + + if value is not dataclasses.MISSING: + default = str(value) + ret["default"] = default + return ret + + return {f.name: dump_field(f) for f in dataclasses.fields(ty)} + + def _handle_int_enum(name, ty): + ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + + def _handle_struct(name, ty): + ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)} + + def _handle_union(name, ty): + ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)} + + for name in dir(schema): + if name.startswith("_"): + continue + + value = getattr(schema, name) + + if hasattr(value, "__module__") and value.__module__ != schema.__name__: + continue + + defs[name] = value + + for name, value in defs.items(): + if isinstance(value, type): + if issubclass(value, IntEnum): + _handle_int_enum(name, value) + elif dataclasses.is_dataclass(value): + if issubclass(value, _Union): + _handle_union(name, value) + else: + _handle_struct(name, value) + else: + raise AssertionError(f"Unknown schema type {name}: {value}") + elif isinstance(value, (int, tuple)): + assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") + else: + raise AssertionError(f"Unknown variable {name}: {value}") + + ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in ret["SCHEMA_VERSION"]) + ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert ret["TREESPEC_VERSION"] > 0 + return ret + + +def _diff_schema(dst, src): + additions = {key: src[key] for key in src.keys() - dst.keys()} + subtractions = {key: dst[key] for key in dst.keys() - src.keys()} + + common_keys = src.keys() & dst.keys() + + versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} + common_keys -= versions + + for key in common_keys: + src_kind = src[key]["kind"] + src_fields = src[key]["fields"] + dst_kind = dst[key]["kind"] + dst_fields = dst[key]["fields"] + _check( + src_kind == dst_kind, + f"Type {key} changed kind from {dst_kind} to {src_kind}", + ) + assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) + added_fields = { + key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() + } + subtracted_fields = { + key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() + } + common_fields = src_fields.keys() & dst_fields.keys() + + for field in common_fields: + src_field = src_fields[field] + dst_field = dst_fields[field] + if src_kind == "struct": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + if "default" in src_field and "default" not in dst_field: + added_fields[field] = {} + added_fields[field]["default"] = src_field["default"] + if "default" not in src_field and "default" in dst_field: + subtracted_fields[field] = {} + subtracted_fields[field]["default"] = dst_field["default"] + elif src_kind == "enum": + _check( + src_field == dst_field, + f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", + ) + elif src_kind == "union": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + else: + raise AssertionError(f"Unknown kind {src_kind}: {key}") + if len(added_fields) > 0: + assert key not in additions + additions[key] = {} + additions[key]["fields"] = added_fields + if len(subtracted_fields) > 0: + assert key not in subtractions + subtractions[key] = {} + subtractions[key]["fields"] = subtracted_fields + + return additions, subtractions + + +def 
_hash_schema(s): + return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() + + +@dataclasses.dataclass +class _Commit: + result: Dict[str, Any] + checksum_result: str + path: str + additions: Dict[str, Any] + subtractions: Dict[str, Any] + base: Dict[str, Any] + checksum_base: Optional[str] + + +def update_schema(): + import importlib.resources + + if importlib.resources.is_resource(__package__, "schema.yaml"): + content = importlib.resources.read_text(__package__, "schema.yaml") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) + _check(match is not None, "checksum not found in schema.yaml") + assert match is not None + checksum_base = match.group(1) + from yaml import load, Loader + + dst = load(content, Loader=Loader) + assert isinstance(dst, dict) + else: + checksum_base = None + dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} + + src = _staged_schema() + additions, subtractions = _diff_schema(dst, src) + return _Commit( + result=src, + checksum_result=_hash_schema(src), + path=__package__.replace(".", "/") + "/schema.yaml", + additions=additions, + subtractions=subtractions, + base=dst, + checksum_base=checksum_base, + ) + + +def check(commit: _Commit, force_unsafe: bool = False): + next_version = None + reason = "" + # Step 1: Detect major schema updates. + if len(commit.additions) > 0: + for k, v in commit.additions.items(): + if k not in commit.base: + continue + kind = commit.result[k]["kind"] + fields = v["fields"] + for f, d in fields.items(): + if "default" not in d and kind == "struct": + reason += ( + f"Field {k}.{f} is added to schema.py without a default value as an incomparible change " + + "which requires major version bump.\n" + ) + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + if k not in commit.result: + continue + for f in v["fields"]: + reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if force_unsafe: + reason += "--force-unsafe is used." + next_version = commit.result["SCHEMA_VERSION"] + else: + # Step 2: Detect minor schema updates. 
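+        # Additions and removals that were not already flagged as breaking in
+        # Step 1 are treated as compatible changes and only bump the minor version.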
+ if next_version is None and len(commit.additions) > 0: + for k, v in commit.additions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is added to schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + if next_version is None and len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is removed from schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + + return next_version, reason diff --git a/MLPY/Lib/site-packages/torch/_export/serde/serialize.py b/MLPY/Lib/site-packages/torch/_export/serde/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..35d49430baf4f7a79fa4853ed8f2ff71bbf53f06 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/serialize.py @@ -0,0 +1,2434 @@ +import base64 +import copy +import dataclasses +import heapq +import inspect +import io +import json +import logging +import math +import operator +import typing +import copyreg + +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Any, + Callable, + cast, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Union, +) + +import sympy + +import torch +import torch.export.exported_program as ep +from torch._export.serde.schema import SchemaVersion +from torch._export.verifier import load_verifier +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental import symbolic_shapes +from torch.utils import _pytree as pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.value_ranges import ValueRanges + +from .schema import ( # type: ignore[attr-defined] + Argument, + BufferMutationSpec, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToCustomObjSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + Node, + OptionalTensorArgument, + OutputSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union + + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +from .upgrade import GraphModuleOpUpgrader + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: Dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[FakeTensor, int, torch.SymInt, bool, torch.SymBool, ep.CustomObjArgument] + + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: 
ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16 +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + + +_SYM_INT_OPS = { + operator.mul, + operator.add, + operator.sub, + operator.floordiv, + operator.mod, + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, +} + + +_SYM_BOOL_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + torch.sym_not, +} + + +@dataclass +class SerializedArtifact: + exported_program: Union[ExportedProgram, bytes] + state_dict: bytes + constants: bytes + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, torch.SymInt) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(str(s))) + else: + return SymInt.create(as_expr=SymExpr(str(s), hint=SymExprHint.create(as_int=s.node.hint))) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=str(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. + """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(0), # TODO needs to be fixed. 
+ layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps(_dataclass_to_dict(tensor_meta), cls=EnumEncoder).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor(serialized_tensor_meta: bytes, is_parameter: bool) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert _CURRENT_DESERIALIZER is not None, "Need access to current deserializer state" + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + return fake_tensor + + +def serialize_torch_artifact(artifact: Dict[str, Any]) -> bytes: + assert FakeTensor not in copyreg.dispatch_table, "Refusing to stomp on existing FakeTensor reducer" + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. + torch.save(artifact, buffer) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact(serialized: bytes): + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + artifact = torch.load(buffer) + assert isinstance(artifact, dict) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError( + "Export constraints cannot be non-integer expressions" + ) + + +def _int_to_sympy_int(val) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val == math.inf: + return sympy.oo + if val == -math.inf: + return -sympy.oo + return sympy.Integer(val) + + +def serialize_range_constraints( + range_constraints: Dict[sympy.Symbol, ValueRanges] +) -> Dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower), # type: ignore[arg-type] + _sympy_int_to_int(v.upper), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool: + returns = target._schema.returns + return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType) + + +def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool: + returns = target._schema.returns + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + +@dataclass +class GraphState: + inputs: List[Argument] = field(default_factory=list) + outputs: List[Argument] = field(default_factory=list) + nodes: List[Node] = 
field(default_factory=list) + tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: Dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) + is_single_tensor_return: bool = False + custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + + +class GraphModuleSerializer: + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: List[ep.ModuleCallEntry] + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: Dict[str, torch._C.ScriptObject] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + if isinstance(node.meta['val'], torch.Tensor): + graph_input = Argument.create(as_tensor=TensorArgument(name=node.name)) + self.graph_state.tensor_values[node.name] = serialize_tensor_meta(node.meta["val"]) + elif isinstance(node.meta['val'], torch.SymInt): + raise AssertionError("SymInt graph input is not implemented yet.") + elif isinstance(node.meta['val'], (int, bool, str, float, type(None))): + graph_input = self.serialize_input(node.meta['val']) + elif isinstance(node.meta['val'], ep.CustomObjArgument): + class_fqn = node.meta["val"].class_fqn + graph_input = Argument.create(as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn)) + self.graph_state.custom_obj_values[node.name] = self.serialize_script_obj_meta(node.meta["val"]) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. 
+ return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + if node.target in _SYM_INT_OPS: + assert len(node.kwargs) == 0 + meta_val = node.meta["val"] + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))], + metadata=self.serialize_metadata(node), + ) + elif node.target in _SYM_BOOL_OPS: + assert len(node.kwargs) == 0 + meta_val = node.meta["val"] + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + pass + + def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: + ret = {} + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + + # node.meta["nn_module_stack"] could have two forms: + # 1. (path: str, module_type: 'type'), e.g. + # ('', ) + # 2. (path: str, module_type: str), e.g. + # ('', 'sigmoid.inference.MySimpleModel') + # ExportedProgram directly produced by torch.export() has form 1 + # ExportedProgram deserialized from disk has form 2 + # TODO: This is not ideal, we should fix this. + if isinstance(ty, str): + normalized_ty = ty + else: + normalized_ty = ty.__module__ + "." 
+ ty.__qualname__ + + return path + "," + normalized_ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" + for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" for source_fn in source_fn_st] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + return ret + + def serialize_script_obj_meta(self, script_obj_meta: ep.CustomObjArgument) -> CustomObjArgument: + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: + serialized_args = [] + args_names = inspect.signature(op).parameters.keys() + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument(name=args_name, arg=self.serialize_input(arg)) + ) + return serialized_args + + def serialize_inputs( + self, target: torch._ops.OpOverload, args, kwargs=None + ) -> List[NamedArgument]: + assert isinstance(target, torch._ops.OpOverload) + kwargs = kwargs or {} + serialized_args = [] + for i, schema_arg in enumerate(target._schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(kwargs[schema_arg.name]), + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i]), + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. 
+ """ + inputs = [ + NamedArgument( + name="", + arg=self.serialize_input(a), + ) for a in args + ] + inputs.extend([ + NamedArgument( + name=name, + arg=self.serialize_input(a) + ) for name, a in kwargs.items() + ]) + return inputs + + def is_sym_int_arg(self, arg) -> bool: + return isinstance(arg, int) or ( + isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) and arg.name in self.graph_state.sym_bool_values + ) + + def serialize_input(self, arg) -> Argument: + import torch._inductor.ir as inductor_ir + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError("getattr nodes containing tensors should not appear in the graph") + elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create(as_graph=GraphArgument(name=arg.target, graph=graph)) + else: + raise SerializeError(f"Unsupported getattr attribute {arg.target} with type: {type(attr)}") + elif self.is_sym_int_arg(arg): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=arg.name)) + elif self.is_sym_bool_arg(arg): + return Argument.create(as_sym_bool=SymBoolArgument.create(as_name=arg.name)) + else: + if isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create(as_custom_obj=CustomObjArgument(name=arg.name, class_fqn=arg.meta["val"].class_fqn)) + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, bool): + return Argument.create(as_bool=arg) + elif isinstance(arg, str): + return Argument.create(as_string=arg) + elif isinstance(arg, int): + return Argument.create(as_int=arg) + elif isinstance(arg, float): + return Argument.create(as_float=arg) + elif arg is None: + return Argument.create(as_none=()) + elif isinstance(arg, (list, tuple)): + # Must check bool first, as bool is also treated as int + if all(isinstance(a, bool) for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(isinstance(a, int) for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(isinstance(a, float) for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(isinstance(a, str) for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(isinstance(a, torch.SymInt) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create( + as_sym_ints=[SymIntArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif isinstance(a, int): + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError("getattr nodes containing tensors should not appear in the graph") + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, torch.fx.Node): + return OptionalTensorArgument.create(as_tensor=a.name) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all(isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=()) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create(as_tensor=a.get_name()) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument type: {[type(a) for a in arg]}") + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create(as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg]) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") and # type: ignore[attr-defined] + arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. Please define " + "serialization methods via def_pickle()." + ) + # Custom objects through torchind are serializable with pickle, + # through implementing the .def_pickle function. This should result + # in the object containing a __getstate__ and __setstate__ + # serialize/deserialize function. 
+ custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create(as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn)) + elif isinstance(arg, torch._ops.OpOverload): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError(f"Unsupported argument type: {type(arg)}") + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + if spec.kind == ep.InputKind.USER_INPUT: + return InputSpec.create( + user_input=UserInputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument(name=spec.arg.name, class_fqn=spec.arg.class_fqn), + custom_obj_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec( + arg=self.serialize_argument_spec(spec.arg) + ) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec( + arg=TensorArgument(name=spec.arg.name) + ) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + 
gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, ep.TensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, ep.SymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, ep.ConstantArgument): + return self.serialize_input(x.value) + elif isinstance(x, ep.CustomObjArgument): + return Argument.create(as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn)) + else: + raise AssertionError("TODO") + + def serialize_module_call_signature(self, module_call_signature: ep.ModuleCallSignature) -> ModuleCallSignature: + return ModuleCallSignature( + inputs=[self.serialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.serialize_argument_spec(x) for x in module_call_signature.outputs], + in_spec=treespec_dumps(module_call_signature.in_spec, TREESPEC_VERSION), + out_spec=treespec_dumps(module_call_signature.out_spec, TREESPEC_VERSION), + ) + + def serialize_module_call_graph(self, module_call_graph: List[ep.ModuleCallEntry]) -> List[ModuleCallEntry]: + return [ + ModuleCallEntry( + fqn=entry.fqn, + signature=self.serialize_module_call_signature(entry.signature) if entry.signature else None, + ) for entry in module_call_graph + ] + + def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: + """For a given node, return the dataclass representing its output values. + + [NOTE: Multiple outputs] We handle aggregates differently than FX. For + FX, it looks like: + + x = call_function("multiple_return", ...) + element0 = call_function(getitem, x, 0) + foo = call_function("use_output", element0) + + We do not want the intermediate `getitem` call, so our serialized thing looks like: + + element0, element1, element2 = call_function("multiple_return", ...) + foo = call_function("use_output", element0) + + We want names to be consistent across these two schemes, so that we can + mostly reuse the names coming from FX. This function computes a mapping from + the FX representation to our representation, preserving the names. 
+ """ + assert node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + + assert isinstance(node.target, torch._ops.OpOverload) + returns = node.target._schema.returns + + if len(returns) == 0: + return [] + + meta_val = node.meta["val"] + + def output_node_at_index(node, index): + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + return user + return None + + # Check single value return + if _is_single_tensor_list_return(node.target): + # e.g "-> Tensor[]" + tensor_args = [] + for idx, meta in enumerate(meta_val): + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + elif len(returns) == 1: + return [self.serialize_output(node.name, meta_val)] + + # There are a two possibilities at this point: + # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" + # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" + # + # Either way, start by gathering a list of TensorArguments with the correct names. + # For consistent naming with FX, consult the downstream `getitem` node and + # make sure our outputs have the same name. + + output_arguments = [] + for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): + if meta is None: + assert isinstance(return_schema.real_type, (torch.OptionalType, torch.TensorType)) + # When the return type is annoated as Tensor type, the op can also return an + # undefined Tensor which will be implicitly converted to None in Python. + output_arguments.append(Argument.create(as_none=())) + elif isinstance(meta, FakeTensor): + assert isinstance(return_schema.real_type, torch.TensorType) + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + output_arguments.append(self.serialize_output(name, meta)) + elif isinstance(meta, list): + # for List[Tensor] return type + assert isinstance( + return_schema.real_type, torch.ListType + ) and isinstance( + return_schema.real_type.getElementType(), torch.TensorType + ) + user_node = output_node_at_index(node, idx) + assert user_node is not None + + args = [] + for i, m in enumerate(meta): + if m is None: + continue + sub_user_node = output_node_at_index(user_node, i) + assert sub_user_node is not None, f"No user found at index {i}" + + args.append(self.serialize_tensor_output(sub_user_node.name, m)) + output_arguments.append(Argument.create(as_tensors=args)) + elif isinstance(meta, (int, SymInt)): + user_node = output_node_at_index(node, idx) + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{idx}" + ) + output_arguments.append(self.serialize_output(name, meta)) + else: + raise ValueError(f"Unhandled output type {type(meta)} from node {node.format_node()}") + + return output_arguments + + def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: + """ + For serializing HOO outputs since HOOs do not have a schema. + """ + meta_val = node.meta["val"] + + if isinstance(meta_val, tuple): + # Note: Since we don't have a schema, we just serialize all tuple + # outputs to be a list of values. Even if the output is supposed to + # be a tensor list (Tensor[]), we will serialize it to be a list of + # tensors (Tensor, Tensor, Tensor). 
An exception is that if there's + # a singleton tensor, we will serialize this to be a singleton + # tensor list so that the deserializer knows to insert getitem nodes. + + idx_to_name = {} + for user in node.users: + if user.target is not operator.getitem: + continue + idx_to_name[user.args[1]] = user.name + + for idx in range(len(meta_val)): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. + if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + if len(meta_val) == 1: + tensors = [] + for i, v in enumerate(meta_val): + assert isinstance(v, torch.Tensor) + tensors.append(self.serialize_tensor_output(idx_to_name[i], v)) + return [Argument.create(as_tensors=tensors)] + + else: + return [ + self.serialize_output(idx_to_name[i], element_meta_val) + for i, element_meta_val in enumerate(meta_val) + ] + + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=()) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create(as_tensor=self.serialize_tensor_output(name, meta_val)) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + return Argument.create(as_sym_int=self.serialize_sym_int_output(name, meta_val)) + elif isinstance(meta_val, torch.SymBool): + # e.g "-> SymBool" + return Argument.create(as_sym_bool=self.serialize_sym_bool_output(name, meta_val)) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert user.target is operator.getitem, f"User node {user} of {node} is incorrect" + idx_to_name[user.args[1]] = user.name + + for idx, _ in enumerate(meta_val): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. 
+ if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + arg_list = [] + for i, element_meta_val in enumerate(meta_val): + arg_list.append( + self.serialize_tensor_output(idx_to_name[i], element_meta_val) + ) + + return arg_list + + def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: + assert isinstance(graph_module, torch.fx.GraphModule) + for node in graph_module.graph.nodes: + try: + getattr(self, f"handle_{node.op}")(node) + except Exception as e: + raise SerializeError(f"Failed serializing node {node} in graph: {node.format_node()}") from e + + return Graph( + inputs=self.graph_state.inputs, + nodes=self.graph_state.nodes, + tensor_values=self.graph_state.tensor_values, + sym_int_values=self.graph_state.sym_int_values, + sym_bool_values=self.graph_state.sym_bool_values, + custom_obj_values=self.graph_state.custom_obj_values, + outputs=self.graph_state.outputs, + is_single_tensor_return=self.graph_state.is_single_tensor_return, + ) + + def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule: + graph = self.serialize_graph(graph_module) + + return GraphModule( + graph=graph, + signature=self.serialize_signature(self.graph_signature), + module_call_graph=self.serialize_module_call_graph(self.module_call_graph), + ) + + +class ExportedProgramSerializer: + def __init__(self, opset_version: Optional[Dict[str, int]] = None): + self.opset_version: Dict[str, int] = {} + if opset_version: + self.opset_version.update(opset_version) + if "aten" not in self.opset_version: + self.opset_version["aten"] = torch._C._get_max_operator_version() + + def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact: + """ + Args: + exported_program: Exported Program to serialize + """ + if type(self) == ExportedProgramSerializer: + exported_program._validate() + + gm_serializer = GraphModuleSerializer( + exported_program.graph_signature, + exported_program.module_call_graph + ) + serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) + serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints) + + # TODO: Directly serialize exported_program.constants once + # CustomClassHolders get stored in the ExportedProgram rather than in + # the graph + constants = {} + for n, c in gm_serializer.custom_objs.items(): + constants[n] = c + for n, t in exported_program.constants.items(): + assert n not in constants + constants[n] = t + + serialized_ep = ExportedProgram( + graph_module=serialized_graph_module, + opset_version=self.opset_version, + range_constraints=serialized_range_constraints, + schema_version=SchemaVersion( + major=SCHEMA_VERSION[0], + minor=SCHEMA_VERSION[1], + ), + dialect=exported_program.dialect, + ) + + # Test canonical form is well defined. 
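+        # The return value of canonicalize() is intentionally discarded: the
+        # call only checks that the serialized program can be canonicalized
+        # without raising.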
+ canonicalize(serialized_ep) + + return SerializedArtifact( + serialized_ep, + serialize_torch_artifact(exported_program.state_dict), + serialize_torch_artifact(constants), + ) + + +class GraphModuleDeserializer: + @dataclasses.dataclass + class Result: + graph_module: torch.fx.GraphModule + signature: ep.ExportGraphSignature + module_call_graph: List[ep.ModuleCallEntry] + names_to_symbols: Dict[str, sympy.Symbol] + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] + constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]] + + def __init__(self): + self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} + self.serialized_name_to_meta: Dict[str, MetaType] = {} + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + + @contextmanager + def save_graph_module(self) -> Iterator[None]: + saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + self.serialized_name_to_node = {} + self.serialized_name_to_meta = {} + try: + yield + finally: + self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta = saved + + def deserialize_operator(self, serialized_target: str): + if serialized_target.startswith("_operator"): # TODO(zhxchen17) Follow up on this. + module = operator + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("torch"): + module = torch # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + else: # TODO(zhxchen17) Don't catch all here. + return serialized_target + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: + val = s.value + if s.type == "as_expr": + if val.expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[val.expr_str] + else: + sym = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) + # NOTE(avik): Assumptions on symbols are not explicitly serialized. + # This seems dangerous: it might cause unknown differences in shape env behavior + # on deserialization? Probably deserves a follow-up. + + # Here we force symbols corresponding to SymInts to be at least integers. + # Otherwise some expressions that the shape env would otherwise evaluate to False, + # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + sym = sym.subs({s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols}) + if isinstance(sym, sympy.Symbol): + self.symbol_name_to_symbol[val.expr_str] = sym + + if vr := self.symbol_name_to_range.get(val.expr_str): + symbolic_shapes._constrain_symbol_range( + self.shape_env, + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + else: + # Placeholders, in particular, can have shapes as symbolic expressions. + # We need to populate the shape env with the range constraints of their + # free symbols, otherwise evaluating such expressions will error. 
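+                    # e.g. a placeholder with shape (s0 + 1, 2*s1) introduces the
+                    # free symbols s0 and s1, each of which may have its own range.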
+ self.symbol_name_to_symbol[val.expr_str] = sym + free_symbols = sym.free_symbols + for s in free_symbols: + if s.name not in self.symbol_name_to_symbol: + self.symbol_name_to_symbol[s.name] = s + if vr := self.symbol_name_to_range.get(s.name): + symbolic_shapes._constrain_symbol_range( + self.shape_env, + s, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + + + if val.hint is None: + hint = None + else: + assert val.hint.type == "as_int" + hint = val.hint.value + + return self.shape_env.create_symintnode(sym, hint=hint) + elif s.type == "as_int": + assert isinstance(val, int) + return val + else: + raise SerializeError( + f"SymInt has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: + val = s.value + if s.type == "as_expr": + expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol) + return self.shape_env.create_symboolnode(expr) + elif s.type == "as_bool": + assert isinstance(val, bool) + return val + else: + raise SerializeError( + f"SymBool has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_tensor_meta( + self, + tensor_meta: TensorMeta, + ) -> FakeTensor: + with self.fake_tensor_mode: + return cast( + FakeTensor, + torch.empty_strided( + tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc] + tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc] + device=deserialize_device(tensor_meta.device), + dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], + ), + ) + + def deserialize_script_obj_meta(self, script_obj_meta: CustomObjArgument) -> ep.CustomObjArgument: + return ep.CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def deserialize_graph_output(self, output) -> torch.fx.Node: + if output.type == "as_tensor": + return self.serialized_name_to_node[output.as_tensor.name] + elif output.type == "as_sym_int": + return self.serialized_name_to_node[output.as_sym_int.as_name] + elif output.type == "as_sym_bool": + return self.serialized_name_to_node[output.as_sym_bool.as_name] + else: + raise SerializeError(f"Unable to deserialize output node {output}") + + def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: + # Handle the tensor metas. + for name, tensor_value in serialized_graph.tensor_values.items(): + meta_val = self.deserialize_tensor_meta(tensor_value) + self.serialized_name_to_meta[name] = meta_val + + for name, sym_int_value in serialized_graph.sym_int_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_sym_int(sym_int_value) + + for name, sym_bool_value in serialized_graph.sym_bool_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_sym_bool(sym_bool_value) + + for name, script_obj_meta in serialized_graph.custom_obj_values.items(): + self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta(script_obj_meta) + + # Inputs: convert to placeholder nodes in FX. 
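+        # Tensor, SymInt and custom-object inputs already carry a serialized
+        # name, which is reused for the placeholder; constant inputs
+        # (int/float/bool/str/None) have no name of their own, so they get a
+        # positional "arg{i}" placeholder with the constant stored in meta["val"].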
+ for i, input_ in enumerate(serialized_graph.inputs): + if input_.type in ("as_tensor", "as_sym_int", "as_custom_obj"): + node_name = input_.value.name + placeholder_node = self.graph.placeholder(node_name) + self.sync_fx_node(node_name, placeholder_node) + elif input_.type in ("as_int", "as_float", "as_bool", "as_none", "as_string"): + node_name = f"arg{i}" + placeholder_node = self.graph.placeholder(node_name) + placeholder_node.meta["val"] = self.deserialize_input(input_) + else: + raise SerializeError(f"Invalid input type {input_}") + + # Nodes: convert to call_function nodes. + for serialized_node in serialized_graph.nodes: + try: + target = self.deserialize_operator(serialized_node.target) + self.deserialize_node(serialized_node, target) + + except Exception as e: + raise SerializeError(f"Failed deserializing node {serialized_node}") from e + + # Outputs: convert to a single `output` node. + outputs = [] + for output in serialized_graph.outputs: + outputs.append(self.deserialize_graph_output(output)) + + if serialized_graph.is_single_tensor_return: + assert len(outputs) == 1 + outputs = outputs[0] # type: ignore[assignment] + else: + outputs = tuple(outputs) # type: ignore[assignment] + + output_node = self.graph.output(outputs) + + if serialized_graph.is_single_tensor_return: + output_node.meta["val"] = output_node.args[0].meta["val"] + else: + output_node.meta["val"] = tuple( + arg.meta["val"] for arg in output_node.args[0] + ) + + return self.graph + + def deserialize_node(self, serialized_node: Node, target: Callable) -> None: + if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS: + name = serialized_node.outputs[0].value.as_name + args = self.deserialize_sym_op_inputs(serialized_node.inputs) + + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.deserialize_sym_op_outputs(serialized_node, fx_node) + + elif isinstance(target, torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + # If HOP returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + # + # HOPs don't have schema yet, just check the output lengths and as_tensor attribute + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 and hasattr(serialized_node.outputs[0], "as_tensor") + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + + elif isinstance(target, torch._ops.OpOverload): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with serialized. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. 
+ ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node("call_function", target, args, kwargs, name) + self.deserialize_outputs(serialized_node, fx_node) + else: + raise SerializeError(f"Unsupported target type for node {serialized_node}: {target}") + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None + ) + elif i.type == "parameter": + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument(name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn), + target=i.custom_obj.custom_obj_name, + ) + else: + raise AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> ep.ExportGraphSignature: + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs] + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: bytes, + constants: bytes, + symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + 
allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} + self.symbol_name_to_range = {} if symbol_name_to_range is None else symbol_name_to_range + self.signature = self.deserialize_signature(serialized_graph_module.signature) + self.constants = deserialize_torch_artifact(constants) + self.deserialize_graph(serialized_graph_module.graph) + + module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph) + return GraphModuleDeserializer.Result( + graph_module=ep._create_graph_module_for_export(self.module, self.graph), + signature=self.signature, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, + state_dict=deserialize_torch_artifact(serialized_state_dict), + constants=self.constants, + ) + finally: + _CURRENT_DESERIALIZER = None + + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): + if name in self.serialized_name_to_node: + raise SerializeError(f"Node {name} has already been deserialized before.") + self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta + fx_node.meta["val"] = self.serialized_name_to_meta[name] + + def deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target: torch._ops.OpOverload, serialized_node: Node): + schema_args = target._schema.arguments + actual_args = { + input.name: self.deserialize_input(input.arg) for input in serialized_node.inputs + } + args = [] + kwargs = {} + for schema_arg in schema_args: + is_positional = not schema_arg.has_default_value() and not schema_arg.kwarg_only + if is_positional: + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. 
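+
+        Inputs that carry a name are passed on as keyword arguments; unnamed
+        inputs are passed positionally, in their serialized order.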
+ """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == "as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [] + for arg in value: + result.append(self.serialized_name_to_node[arg.name]) + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ in ("as_sym_ints", "as_sym_bools"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value] + else: + raise SerializeError(f"Unhandled argument {inp}") + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + raise SerializeError(f"Unhandled argument {inp}") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: 
Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif ( + len(serialized_node.outputs) == 1 and + isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument)) + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + def generate_getitem(meta_val, fx_node: torch.fx.Node, arg: Union[TensorArgument, SymIntArgument], idx: int): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + else: + raise AssertionError(f"generate_getitem got unknown argument type {type(arg)}") + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems(meta_val, fx_node: torch.fx.Node, args): + for idx, arg in enumerate(args): + if isinstance(arg, Argument): + arg = arg.value + if isinstance(arg, (TensorArgument, SymIntArgument)): + generate_getitem(meta_val, fx_node, arg, idx) + elif isinstance(arg, (list, tuple)): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + generate_getitems(meta_val[-1], list_output, arg) + list_output.meta.update(deserialized_metadata) + list_output.meta['val'] = meta_val[-1] + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. 
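+        # Nested list outputs additionally get one `getitem` node for the list
+        # itself, whose elements are then unpacked recursively (see
+        # generate_getitems above).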
+ # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: List[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors) + else: + generate_getitems(meta_val, fx_node, serialized_node.outputs) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: + ret: Dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + nn_module_stack = dict( + import_nn_module_stack(*item.split(",")) + for item in nn_module_stack_str.split(ST_DELIMITER) + ) + ret["nn_module_stack"] = nn_module_stack + + if source_fn_st_str := metadata.get("source_fn_stack"): + # Originally serializes to "fx_node_name,op_str" + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st + return ret + + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + else: + return ep.ConstantArgument(value=self.deserialize_input(x)) + + def deserialize_module_call_signature(self, module_call_signature: ModuleCallSignature) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[self.deserialize_argument_spec(x) for x in module_call_signature.inputs], + outputs=[self.deserialize_argument_spec(x) for x in module_call_signature.outputs], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + ) + + def deserialize_module_call_graph(self, module_call_graph: List[ModuleCallEntry]) -> List[ep.ModuleCallEntry]: + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=self.deserialize_module_call_signature(entry.signature) if entry.signature else None, + ) for entry in module_call_graph + ] + + +class ExportedProgramDeserializer: + def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): + self.expected_opset_version: Dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def 
deserialize_range_constraints( + self, + symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: Dict[str, sympy.Symbol], + ) -> Dict[sympy.Symbol, ValueRanges]: + range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning(f"Symbol {k} did not appear in the graph that was deserialized") # noqa: G004 + return range_constraints + + def deserialize( + self, serialized_artifact: SerializedArtifact + ) -> ep.ExportedProgram: + assert isinstance(serialized_artifact.exported_program, ExportedProgram) + + if serialized_artifact.exported_program.schema_version.major != SCHEMA_VERSION[0]: + raise SerializeError( + f"Serialized schema version {serialized_artifact.exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." + ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)) + for k, v in serialized_artifact.exported_program.range_constraints.items() + } + res = ( + GraphModuleDeserializer() + .deserialize( + serialized_artifact.exported_program.graph_module, + serialized_artifact.state_dict, + serialized_artifact.constants, + symbol_name_to_range, + ) + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, res.names_to_symbols, + ) + model_opset_version: Optional[Dict[str, int]] = serialized_artifact.exported_program.opset_version + self._validate_model_opset_version(model_opset_version) + + upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version) + + exported_program = ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=None, + verifier=load_verifier(serialized_artifact.exported_program.dialect), + constants=res.constants, + ) + return upgrader.upgrade(exported_program) + + def _validate_model_opset_version(self, model_opset_version: Optional[Dict[str, int]]): + """Compare model_opset_version with expected_opset_version and raise error if we can't resolve the version + difference. + E.g., model_opset_version = {"aten": 3, "custom": 4} + expected_opset_version = {"aten": 4, "custom": 4} + This means we can use an upgrader for ATen to reconcile the deserialized model. + + The logic of this method: + + For common op namespaces: + 1. if model version < expected version, this case can be handled by upgraders. + 2. if model version > expected version, we need downgraders but not implemented yet. + 3. if model version == expected version, we don't need extra handling. + + For op namespace only in model_opset_version, we should give a warning because it is missing from + expected_opset_version. 
+ """ + if not model_opset_version: + raise RuntimeError("Serialized model should have opset version.") + common_namespaces = {key for key in model_opset_version if key in self.expected_opset_version} + for namespace in common_namespaces: + assert ( + isinstance(model_version := model_opset_version[namespace], int) + ), f"model_opset_version value should be int, got {model_opset_version[namespace]}" + + assert ( + isinstance(compiler_version := self.expected_opset_version[namespace], int) + ), f"expected_opset_version value should be int, got {self.expected_opset_version[namespace]}" + + # TODO(larryliu0820): Add support for upgrader & downgrader + if model_version != compiler_version: + raise NotImplementedError( + f"Model opset version {model_opset_version} doesn't match to compiler opset version " + f"{self.expected_opset_version}! Upgrader/downgrader is not implemented yet." + ) + for namespace in model_opset_version: + if namespace in common_namespaces: + continue + log.warning("Compiler doesn't have a version table for op namespace: {ns}. ", extra={"ns": namespace}) + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): + return base64.b64encode(obj).decode('utf-8') + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + if not (f.default is None and getattr(obj, f.name) is None) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + else: + return obj + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[Dict[str, int]] = None, +) -> SerializedArtifact: + serialized_artifact = ( + ExportedProgramSerializer(opset_version).serialize(exported_program) + ) + assert isinstance(serialized_artifact.exported_program, ExportedProgram) + + + json_program = json.dumps( + _dataclass_to_dict(serialized_artifact.exported_program), cls=EnumEncoder + ) + json_bytes = json_program.encode('utf-8') + artifact = SerializedArtifact( + json_bytes, + serialized_artifact.state_dict, + serialized_artifact.constants + ) + return artifact + + +def _dict_to_dataclass(cls, data): + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." 
+ if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + field_type = cls.__annotations__[_type] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + obj = cls(**data) # type: ignore[assignment] + type_hints = typing.get_type_hints(cls) + for f in dataclasses.fields(cls): + name = f.name + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) + setattr(obj, name, new_field_obj) + return obj + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [ + _dict_to_dataclass(d_type, d) + for d in data + ] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return { + k: _dict_to_dataclass(v_type, v) + for k, v in data.items() + } + return data + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[Dict[str, int]] = None, +) -> ep.ExportedProgram: + assert isinstance(artifact.exported_program, bytes) + exported_program_str = artifact.exported_program.decode('utf-8') + exported_program_dict = json.loads(exported_program_str) + serialized_exported_program = _dict_to_dataclass(ExportedProgram, exported_program_dict) + return ( + ExportedProgramDeserializer(expected_opset_version) + .deserialize( + SerializedArtifact( + serialized_exported_program, + artifact.state_dict, + artifact.constants + ) + ) + ) + + +def _canonicalize_graph(sorted_inputs, sorted_outputs, graph) -> Tuple[Graph, Dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return None + elif a.type == "as_operator": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") + + # Stage 1: Reorder named items. 
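+    # Nodes are re-ordered with a deterministic topological sort: a node becomes
+    # a candidate once all of its inputs are defined, and ties are broken via a
+    # heap keyed on (target, argument ranks), so the result does not depend on
+    # the original value names.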
+ def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: List[int] + ins: int + + graph_inputs: Set[str] = set() + def_table: Dict[str, int] = {} + edges: Dict[int, Edges] = {} + candidates: List[Tuple[str, List[Tuple[str, List[int]]], int]] = [] + rank: Dict[str, int] = {} + ret: List[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + assert isinstance(a.as_tensor, str) + return a.as_tensor + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input , i) + + for idx, node in enumerate(nodes): + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + def add_edge(a): + if s := get_name(a): + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + if s := get_name(a): + return rank[s] + else: + return -1 + + for i in sorted_inputs: + for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + node = nodes[idx] + args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. 
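+    # Values are renamed to "_0", "_1", ... in definition order of the sorted
+    # graph; the resulting mapping is returned so callers (see canonicalize
+    # below) can update the surrounding graph signature to match.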
+ name_table: Dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + assert isinstance(a.as_tensor, str) + a.as_tensor = name_table.get(a.as_tensor, a.as_tensor) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=lambda x: x[0])) + sorted_sym_int_values = dict(sorted(graph.sym_int_values.items(), key=lambda x: x[0])) + sorted_sym_bool_values = dict(sorted(graph.sym_bool_values.items(), key=lambda x: x[0])) + + # Stage 5: Recurse in subgraphs. + counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph = _canonicalize_graph( + a.as_graph.graph.inputs, + a.as_graph.graph.outputs, + a.as_graph.graph + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + ) + return graph, name_table + + +def canonicalize(ep: ExportedProgram) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + + Returns: + ExportedProgram: The canonicalized exported program. 
+ """ + ep = copy.deepcopy(ep) + + opset_version = dict(sorted(ep.opset_version.items(), key=lambda x: x[0])) + range_constraints = dict(sorted(ep.range_constraints.items(), key=lambda x: x[0])) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> Tuple[int, Optional[str], int]: + idx, (arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 3, None, idx + elif spec.type == "loss_output": + return 3, None, idx + elif spec.type == "buffer_mutation": + return 1, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 4, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 5, None, idx + elif spec.type == "user_input_mutation": + return 2, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted(enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input) + sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + + sorted_outs = sorted(enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph(sorted_inputs, sorted_outputs, graph) + + def replace_input(inp): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type in ("as_none", "as_int", "as_float", "as_string", "as_custom_obj"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type in ("as_none", 
"as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + dialect=ep.dialect, + ) diff --git a/MLPY/Lib/site-packages/torch/_export/serde/union.py b/MLPY/Lib/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..57a47a712c2f971b9a474e230f8b54547e2acad3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/union.py @@ -0,0 +1,69 @@ +import functools +from dataclasses import fields +from typing import Hashable, Set + + +class _UnionTag(str): + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = _UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names( + self._cls + ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.lru_cache(maxsize=None) +def _get_field_names(cls) -> Set[str]: + return {f.name for f in fields(cls)} + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." 
+ ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/MLPY/Lib/site-packages/torch/_export/serde/upgrade.py b/MLPY/Lib/site-packages/torch/_export/serde/upgrade.py new file mode 100644 index 0000000000000000000000000000000000000000..121edbe29b8aa20124a28109e35b6e991aa31b40 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/serde/upgrade.py @@ -0,0 +1,201 @@ +import logging +from collections import defaultdict +from typing import Tuple, Dict, Optional, List + +import torch +from torch.export import export +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor +from torch.fx.node import Target, Argument +from torch.library import Library +from torch.utils._pytree import tree_unflatten +import torch._export.exported_program as ep +import re + +lib = Library("aten", "FRAGMENT") +impl_lib = Library("aten", "IMPL") + +log = logging.getLogger(__name__) + + +def get_target_version(versioned_upgrader_name: str) -> int: + """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is + upgrading to version 4.""" + if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): + raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") + + return int(versioned_upgrader_name.split('_')[-1]) + 1 + + +def get_upgraders() -> Dict[str, Tuple[str, str]]: + """Getting upgraders entry map and operator version map and merge them into one dict.""" + upgraders = torch._C._get_upgraders_entry_map() + op_version_map = torch._C._get_operator_version_map() + output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] + for opname, entry_list in op_version_map.items(): + if not entry_list: + raise RuntimeError(f"Op version map has an empty entry for opname {opname}") + entry = entry_list[0] + old_schema = entry.old_schema + upgrader_name = entry.upgrader_name + upgrader_str = upgraders.get(upgrader_name, None) + if not upgrader_str: + raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}") + output[upgrader_name] = (old_schema, upgrader_str) + return output + + +class GraphModuleOpUpgrader: + """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. + To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In + __init__() it does the following: + 1. parse the upgrader list and reorder for upgrading purpose. + 2. register old versions of operators as custom ops. + 3. prepare upgrader passes. + + In `upgrade()` API run these upgrader passes. 
+ + An example of op_upgraders input: + { + "aten::div__Scalar_0_3": ( # versioned op name + "div._Scalar(self: Tensor, other: Scalar)", # old schema + ''' + def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string + if (self.is_floating_point() or isinstance(other, float)): + return self.true_divide_(other) + return self.divide_(other, rounding_mode='trunc') + ''', + ), + }, + + Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the + original TorchScript upgrader). + """ + + class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): + def __init__(self, old_target: Target, new_target: Target): + super().__init__() + self.old_target = old_target + self.new_target = new_target + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op == self.old_target: + return super().call_operator(self.new_target, args, kwargs, meta) + return super().call_operator(op, args, kwargs, meta) + + def __init__( + self, + compiler_opset_version: Optional[Dict[str, int]] = None, + model_opset_version: Optional[Dict[str, int]] = None, + op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, + ): + self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders + self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {} + self.model_opset_version = model_opset_version if model_opset_version else {} + self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes( + self._parse_upgraders(self.op_upgraders)) + + def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]: + """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well + as the upgrader function string literal.""" + # TODO(larryliu0820): Add support for custom ops + op_namespace = "aten" + if not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version: + return [] + model_ver = self.model_opset_version[op_namespace] + curr_ver = self.compiler_opset_version[op_namespace] + + # key is the target version. div__Scalar_0_3 should have a key of 4. + versioned_upgraders: Dict[int, Tuple[str, str]] = {get_target_version(name): v for name, v in + op_upgraders.items()} + target_upgraders: List[Tuple[str, str]] = [] + # we need all upgraders from model_ver + 1 to curr_ver, inclusively + for ver in range(model_ver + 1, curr_ver + 1): + if ver in versioned_upgraders: + target_upgraders.append(versioned_upgraders[ver]) + else: + # we may be able to get away with missing upgraders, if that operator is missing from given graph + # module. + log.warning("Missing an upgrader to upgrade to version {ver}.", extra={"ver": ver}) + + return target_upgraders + + @staticmethod + def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: + """Given a list of upgraders, loop through it from lower version to higher version and create passes for all + upgraders. se torch.Library API to register old ops. Op name will be + __. Register upgraders as CompositeImplicitAutograd kernels. For example: + + lib = Library("aten", "FRAGMENT") + lib.define(old_schema) + + impl_lib = Library("aten", "IMPL") + impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") + + @:var upgraders: a list of tuples. 
The first element of the tuple is the old schema and the second is the + upgrader function literal text. + @:return upgrader passes, order matters + """ + + upgrader_passes = [] + + def register_old_op(name: str, schema: str, impl_str: str): + """Registers an old version operator using impl_name as old op name.""" + lib.define(schema) + try: + exec(impl_str) + except Exception as e: + raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e + impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") + + for (schema, upgrader_str) in upgraders: + upgrader_name = upgrader_str.split('(')[0].split(' ')[-1] + op_name = schema.split('(')[0].split("::")[-1] + schema = schema.replace(op_name, upgrader_name) + try: + register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str) + except RuntimeError as e: + if "with the same name and overload name multiple times" in str(e): + print(f"Registering {upgrader_name} multiple times") + else: + raise RuntimeError from e + old_op_target = getattr(torch.ops.aten, upgrader_name).default + # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the + # "default" at the end. + op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2]) + new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) + # Note that the graph will have op names in the graph, but actually they are of old versions. + upgrader_passes.append( + GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target)) + + return upgrader_passes + + def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram: + """Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of + operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the + upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the + upgrader. After all passes are applied, the exported program will be upgraded to the target version.""" + if not self.upgrader_passes: + return exported_program + + args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"] + args_real_tensors = [torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg for + arg in args] + assert exported_program.call_spec.in_spec is not None + args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec) + assert kwargs == {} + + for _pass in self.upgrader_passes: + upgraded_program = exported_program._transform_do_not_use(_pass) + # NB: we have to retrace the graph_module instead of ep because of some failure. 
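+            # Re-exporting decomposes the registered CompositeImplicitAutograd
+            # upgrader ops back into plain ATen ops at the new version.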
+ exported_program = export(upgraded_program.module(), args, kwargs) + + return exported_program diff --git a/MLPY/Lib/site-packages/torch/_export/utils.py b/MLPY/Lib/site-packages/torch/_export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..65881b4dd625bc3a1c5b5e0de2dfb756ad70d560 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/utils.py @@ -0,0 +1,401 @@ +import dataclasses +import math +import operator +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type + +import torch +from torch._subclasses.fake_tensor import FakeTensor + +from torch.export import ExportedProgram +from torch.utils._pytree import ( + _register_pytree_node, + Context, + FlattenFunc, + FromDumpableContextFn, + KeyPath, + keystr, + MappingKey, + SequenceKey, + ToDumpableContextFn, + UnflattenFunc, +) + + +def _check_input_constraints_for_graph( + input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints +): + def get_keystr(key_path: KeyPath) -> str: + """For a given index into the flat_args, return a human readable string + describing how to access it, e.g. "*args["foo"][0].bar" + """ + # Prefix the keypath with "*args" or "**kwargs" to make it clearer where + # the arguments come from. Ultimately we ought to serialize the + # original arg names for the best error message here. + args_kwargs_key_path = key_path[0] + assert isinstance(args_kwargs_key_path, SequenceKey) + if args_kwargs_key_path.idx == 0: + return f"*args{keystr(key_path[1:])}" + else: + kwarg_key = key_path[1] + assert isinstance(kwarg_key, MappingKey) + name = str(kwarg_key)[1:-1] # get rid of the enclosed [] + return f"{name}{keystr(key_path[2:])}" + + import sympy + + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, + ) + from torch.utils._sympy.solve import try_solve + + if len(flat_args_with_path) != len(input_placeholders): + raise RuntimeError( + "Unexpected number of inputs " + f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" + ) + # NOTE: export already guarantees that the same symbol is used in metadata + # for all InputDims related by equality constraints, so we can just unify + # symbols with given input dimension values to check equality constraints. 
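+    # For example (hypothetical shapes): if two placeholders carry fake shapes (s0, 3)
+    # and (s0, 5) and the runtime args have shapes (4, 3) and (4, 5), checking the first
+    # input binds s0 -> 4 in unification_map, and the second input's dim 0 is then
+    # required to equal that binding.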
+ unification_map: "Dict[sympy.Symbol, Any]" = {} + for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): + node_val = node.meta.get("val") + if isinstance(node_val, FakeTensor): + if not isinstance(arg, torch.Tensor): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", + ) + + if len(node_val.shape) != len(arg.shape): + raise RuntimeError( + f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " + f"(expected {node_val.shape}, got {arg.shape})" + ) + + for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): + # TODO(avik): Assert the following property in the IR verifier: + # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr + if ( + isinstance(node_dim, torch.SymInt) + and len(node_dim.node.expr.free_symbols) == 1 + ): + symbol = next(iter(node_dim.node.expr.free_symbols)) + if symbol in unification_map: + existing_dim = node_dim.node.expr.subs(unification_map) + if arg_dim != existing_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{existing_dim}, but got {arg_dim}", + ) + else: + if ( + isinstance(arg_dim, torch.SymInt) + and not arg_dim.node.expr.is_number + ): + # This can happen when, say, arg is a fake tensor. + # We do not run checks on symbolic shapes of fake inputs as + # such checks can affect the shape env. + pass + else: + solution = try_solve( + sympy.Eq(node_dim.node.expr, arg_dim), symbol + ) + if solution is None: + raise RuntimeError( # noqa: TRY200 + f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " + f"of the form {node_dim.node.expr}, where {symbol} is an integer" + ) + else: + unification_map[symbol] = int(solution[1]) + + if node_dim.node.expr in range_constraints: + min_val, max_val = _convert_range_to_int( + range_constraints[node_dim.node.expr] + ) + # NOTE: we allow dimensions to be 0/1 at runtime + if min_val > 2: + if arg_dim < min_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " + f"{min_val}, but got {arg_dim}", + ) + if max_val < math.inf: + if arg_dim > max_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " + f"{max_val}, but got {arg_dim}", + ) + else: + if arg_dim != node_dim: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " + f"{node_dim}, but got {arg_dim}", + ) + elif isinstance(node_val, (int, float, str)): + if type(arg) != type(node_val) or arg != node_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + + +def register_dataclass_as_pytree_node( + cls: Type[Any], + flatten_fn: Optional[FlattenFunc] = None, + unflatten_fn: Optional[UnflattenFunc] = None, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + return_none_fields: bool = False, +) -> None: + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" + + def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for f in dataclasses.fields(obj): + name, val = f.name, getattr(obj, f.name) + if val is not None or return_none_fields: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] 
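+
+    # e.g. (hypothetical dataclass) for "class Point: x: Tensor; y: Optional[Tensor] = None",
+    # default_flatten_fn(Point(x=t)) returns ([t], [["x"], ["y"]]); "y" goes into none_names
+    # and is restored as None by default_unflatten_fn below, unless return_none_fields=True.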
+ + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn + unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Checks if the given node is a parameter within the exported program + """ + + return node.name in program.graph_signature.inputs_to_parameters + + +def get_param( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.nn.Parameter]: + """ + Returns the parameter associated with the given node in the exported program. + Returns None if the node is not a parameter within the exported program + """ + + if is_param(program, node): + parameter_name = program.graph_signature.inputs_to_parameters[node.name] + return program.state_dict[parameter_name] + + return None + + +def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Checks if the given node is a buffer within the exported program + """ + + return node.name in program.graph_signature.inputs_to_buffers + + +def get_buffer( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the buffer associated with the given node in the exported program. + Returns None if the node is not a buffer within the exported program + """ + + if is_buffer(program, node): + buffer_name = program.graph_signature.inputs_to_buffers[node.name] + if buffer_name in program.graph_signature.non_persistent_buffers: + return program.constants[buffer_name] + else: + return program.state_dict[buffer_name] + + return None + + +def is_lifted_tensor_constant( + program: ExportedProgram, + node: torch.fx.Node, +) -> bool: + """ + Checks if the given node is a lifted tensor constant within the exported program + """ + + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def get_lifted_tensor_constant( + program: ExportedProgram, + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the lifted tensor constant associated with the given node in the exported program. + Returns None if the node is not a lifted tensor constant within the exported program + """ + + if is_lifted_tensor_constant(program, node): + lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ + node.name + ] + return program.constants[lifted_tensor_name] + + return None + + +def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: + """ + Splits the graph module into multiple submodules based on the node_call_back. + The node_call_back should return True if the node is a delimiter. Delimiter will be + the first node in the next submodule. 
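+
+    Example (hypothetical callback): node_call_back=lambda n: n.op == "call_function"
+    makes every call_function node the first node of a new partition, so each such op
+    ends up in its own submodule of the returned graph module.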
+ """ + from torch.fx.passes.split_module import split_module + + split_map = {} + split_id = 0 + for node in gm.graph.nodes: + if node_call_back(node): + split_id += 1 + split_map[node] = split_id + + new_gm = split_module( + gm, + gm, + lambda node: split_map[node], + keep_original_order=True, + keep_original_node_name=True, + ) + # Keep the codegen from original graph module to preserve e.g. pytree info. + new_gm.graph._codegen = gm.graph._codegen + new_gm.recompile() + return new_gm + + +def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """Returns the nodes that match the node_call_back as a list.""" + return [node for node in nodes if node_call_back(node)] + + +def nodes_first( + nodes: List[torch.fx.Node], node_call_back=None +) -> Optional[torch.fx.Node]: + """ + Returns the first node that matches the node_call_back. If no node matches, returns None. + When node_call_back is None, returns the first node in the node list. + """ + ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) + if len(ret) > 0: + return ret[0] + return None + + +def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: + """Returns the number of nodes that match the node_call_back.""" + return len(nodes_filter(nodes, node_call_back)) + + +def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: + """ + Sequentially visit the nodes list and invoke node_call_back on each element. + Returns the nodes list after the node_call_back is invoked on each element. + """ + for node in nodes: + node_call_back(node) + return nodes + + +def node_replace_( + old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False +) -> None: + """ + Replace all uses of old_node with new_node. + """ + old_node.replace_all_uses_with(new_node) + if delete_old: + old_node.users.clear() + old_node.graph.erase_node(old_node) + + +def node_inline_(call_mod_node: torch.fx.Node) -> None: + """ + Inline the submodule of the given node into the parent module. + Note: we only support the case where submodule takes tensors inputs. + """ + assert call_mod_node.op == "call_module" + gm = call_mod_node.graph.owning_module + + assert isinstance(call_mod_node.target, str) + sub_gm = getattr(gm, call_mod_node.target) + + phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") + body = ( + node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") + ) + output = [node for node in sub_gm.graph.nodes if node.op == "output"] + + for ph, arg in zip(phs, call_mod_node.args): + assert isinstance(arg, torch.fx.Node) + node_replace_(ph, arg, delete_old=True) + + with gm.graph.inserting_before(call_mod_node): + for node in body: + new_node = gm.graph.node_copy(node) + node_replace_(node, new_node, delete_old=True) + + if len(output) > 0: + assert len(output) == 1 and len(output[0].args) == 1 + new_output = output[0].args[0] + + if isinstance(new_output, torch.fx.Node): + node_replace_(call_mod_node, new_output, delete_old=True) + elif isinstance(new_output, (list, tuple)): + # Inline the get_item calls for the output node. 
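+            # e.g. if the submodule returned (a, b), the parent graph accesses them as
+            # operator.getitem(call_mod_node, 0) and operator.getitem(call_mod_node, 1);
+            # each such getitem user is rewired below to the matching inlined output node
+            # (new_output[idx]) and the call_module node itself is erased.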
+ get_item_users = nodes_filter( + list(call_mod_node.users.keys()), + lambda node: node.op == "call_function" + and node.target == operator.getitem, + ) + # get_item_node.args[1] is the idx referring to new_output[idx] + nodes_map( + get_item_users, + lambda get_item_node: node_replace_( + get_item_node, + new_output[get_item_node.args[1]], + delete_old=True, + ), + ) + call_mod_node.graph.erase_node(call_mod_node) + else: + raise NotImplementedError( + f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." + ) + else: + call_mod_node.graph.erase_node(call_mod_node) + + gm.delete_all_unused_submodules() + gm.recompile() + return gm diff --git a/MLPY/Lib/site-packages/torch/_export/verifier.py b/MLPY/Lib/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2f18f5f8d9b5abca6dae71b57caf1c3c72079c31 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/verifier.py @@ -0,0 +1,416 @@ +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, Dict, final, List, Optional, Tuple, Type + +import torch +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymIntArgument, + TensorArgument, +) +from torch.fx import GraphModule +from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + def _check_correct_val(val): + if val is None: + return True + elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)): + return True + elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor. 
+ return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +class _VerifierMeta(type): + _registry: Dict[str, Type['Verifier']] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + return ret + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split('.') + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> List: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + math.ceil, + math.floor, + ] + + def allowed_op_types(self) -> Tuple[Type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: + return (torch.fx.GraphModule,) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. + """ + pass + + @final + def check(self, ep: ExportedProgram) -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_getattr_types() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> List: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. 
+ # These will be modeled as HOO later + torch._C._set_grad_enabled + + ) + + if not isinstance(op, _allowed_op_types()): + if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions: + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + if not is_functional(op): + raise SpecViolationError( + f"operator '{op}' is not functional" + ) + self.check_valid_op(op) + + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + if type(attr).__name__ == "LoweredBackendModule": + if _is_type("backend_id", str) \ + and _is_type("processed_bytes", bytes) \ + and _is_type("compile_specs", list) \ + and hasattr(attr, "original_module"): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + + if not isinstance(attr, _allowed_getattr_types()): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"Valid get_attr types: {_allowed_getattr_types()}" + ) + + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.user_inputs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." 
+ ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError( + f"Parameter {param} is not in the state dict." + ) + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if input_spec.persistent is True and buffer not in exported_program.state_dict: + raise SpecViolationError( + f"Buffer {buffer} is not in the state dict." + ) + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + else: + raise SpecViolationError( + f"Unknown InputKind {input_spec.kind}." + ) + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. 
\n" + ) + + num_tokens = len(gs.output_tokens) + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens + mutate_nodes: List[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. " + ) + + +def load_verifier(dialect: str) -> Optional[Type[Verifier]]: + if dialect == "ATEN": + return _VerifierMeta._registry.get(dialect) + return _VerifierMeta._registry[dialect] diff --git a/MLPY/Lib/site-packages/torch/_export/wrappers.py b/MLPY/Lib/site-packages/torch/_export/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..b54fab43f565586425aa20bc0d576a0b5d2304c7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_export/wrappers.py @@ -0,0 +1,114 @@ +from contextlib import contextmanager + +import torch +import torch._custom_ops +from torch._C import DispatchKey +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +_export_tracepoint = HigherOrderOperator("_export_tracepoint") + + +@_export_tracepoint.py_impl(ProxyTorchDispatchMode) +def export_tracepoint_dispatch_mode(mode, *args, **kwargs): + if not mode.enable_tracing: + return _export_tracepoint(*args, **kwargs) + p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) + proxy = mode.tracer.create_proxy( + "call_function", _export_tracepoint, p_args, p_kwargs + ) + return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) + + +@_export_tracepoint.py_impl(FakeTensorMode) +def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): + with mode: + return args + + +@_export_tracepoint.py_functionalize_impl +def export_tracepoint_functional(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) + return ctx.wrap_tensors(out) + + +_export_tracepoint.py_impl(DispatchKey.Autograd)( + 
autograd_not_implemented(_export_tracepoint, deferred_error=True) +) + + +@_export_tracepoint.py_impl(DispatchKey.CPU) +def export_tracepoint_cpu(*args, **kwargs): + return args + + +def _wrap_submodule(mod, path, module_call_specs): + assert isinstance(mod, torch.nn.Module) + assert path != "" + submodule = mod + for name in path.split("."): + if not hasattr(submodule, name): + raise RuntimeError(f"Couldn't find submodule at path {path}") + submodule = getattr(submodule, name) + + def update_module_call_signatures(path, in_spec, out_spec): + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec + module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} + + def check_flattened(flat_args): + for a in flat_args: + if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): + raise AssertionError( + f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" + ) + + def pre_hook(module, args, kwargs): + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + check_flattened(flat_args) + flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) + args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + return args, kwargs + + def post_hook(module, args, kwargs, res): + _, in_spec = pytree.tree_flatten((args, kwargs)) + flat_res, out_spec = pytree.tree_flatten(res) + check_flattened(flat_res) + flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) + update_module_call_signatures(path, in_spec, out_spec) + return pytree.tree_unflatten(flat_res, out_spec) + + pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) + post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) + return pre_handle, post_handle + + +@contextmanager +def _wrap_submodules(f, preserve_signature, module_call_signatures): + handles = [] + + try: + for path in preserve_signature: + handles.extend(_wrap_submodule(f, path, module_call_signatures)) + yield + finally: + for handle in handles: + handle.remove() + + +def _mark_strict_experimental(cls): + def call(self, *args): + return strict_mode(self, args) + + cls.__call__ = call + return cls diff --git a/MLPY/Lib/site-packages/torch/_functorch/__init__.py b/MLPY/Lib/site-packages/torch/_functorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
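For reference, the pre_hook/post_hook pattern used by _wrap_submodule above can be reproduced outside the export machinery. A minimal sketch, assuming hypothetical names (Inner, specs, the "inner" key) that are not part of the torch sources:

import torch
from torch.utils import _pytree as pytree

specs = {}  # path -> {"in_spec": ..., "out_spec": ...}, mirrors module_call_specs

class Inner(torch.nn.Module):
    def forward(self, x, *, scale=1.0):
        return x * scale

mod = Inner()

def pre_hook(module, args, kwargs):
    # Record how (args, kwargs) flatten before the submodule runs.
    _, in_spec = pytree.tree_flatten((args, kwargs))
    specs.setdefault("inner", {})["in_spec"] = in_spec
    return args, kwargs

def post_hook(module, args, kwargs, out):
    # Record how the output flattens after the submodule runs.
    _, out_spec = pytree.tree_flatten(out)
    specs.setdefault("inner", {})["out_spec"] = out_spec
    return out

pre = mod.register_forward_pre_hook(pre_hook, with_kwargs=True)
post = mod.register_forward_hook(post_hook, with_kwargs=True)
mod(torch.ones(2), scale=2.0)
pre.remove()
post.remove()

In _wrap_submodules, one such pair of handles is installed per preserved path and removed in the finally block.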
diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b758d3e2d95c16826b4bf374a7b49d2f740049bf Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb7dee310dc9193b98ee711d6ed526f3b746c40c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3b439abd9c4148fd7ddc73a24776dcc0c53b1c1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a73f004dfb0a0f823eab9914c80edfd2424526 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da0a385c9315d4cb4a3ff907c08343649359bb69 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..662290fc3a2a7f78d1eca7225e0132eeec81a002 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b945c861ce33fadad4918c4d23aeaf619d505ac Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec07be0e799d7bc180b83c93b5ee3c71d0497555 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf417d23ae5f99fca379e890c17744327054dbe3 Binary files /dev/null and 
b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..030f4e0673faa99a864f0f15f9c664fc7ec82942 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b851d56544d09cfc20dc3356af8740b55d33fdc5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db43ae021e74ea48dd2224405531a7a206712bda Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b9c8d00a6ebd133686da291198e9a975289432 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1e1cd7fea0883a2b7f44d4f97740dd7dbb1d037 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f83843bccac71106ad6e9974309907fcfe430566 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c6e9c2b148fc73e937d10d966275366432f5f5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed260af0b0a3c1a88ae51660eb47131a1cfee183 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c51d69d485a98437390f10ec3a8325411908d6e1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ac175edc5e3155ebd92791f828614687fa4c1d0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a6b21a9cd7cc97191f85aa5d51da2918bbcf08a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3c5bdaaf58e96d3e33dfa126b81482f4361ec1d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..755ea62025a77913b08d8087a21b0da20bf5b9cc Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc0934c637a5247067b129c1e673b44a2f7896f1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f50a36592b97bad1c4cab808d51032615fea3e4 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94b1756759610e018a77beacccbb7df56fe8da31 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..062d7de61791353c23a4f6d333cbf3e287429787 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5fca0b24baba8423826ea386e4b757b24198899 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8d16b8e3ab858e4ba3bfeb3fbe2c0c713088c4a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9448ca347a82df52867fbbe47d1ce3c9ea8618a0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae39509cf9419b9b8642fda371a887d79455ba9e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c78339d055cccfa13264fad6e0f1805f5a5edaad Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95ccf683116ac78d922c6d27eccb56e0e97ff78b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748b8621ba23f9c214e7e04328c5ee8b4a8e315e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..c35aa2f6739c1d394afe041b9b95491852775f36 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -0,0 +1,626 @@ +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the analysis here constructs view and mutation metadata from running +a functionalized version of the graph under compilation. 
+""" + +import collections +import logging +from functools import wraps +from typing import Callable, DefaultDict, Dict, List + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch._subclasses.meta_utils import safe_is_leaf +from torch.fx.experimental.symbolic_shapes import is_concrete_int +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) +from .functional_utils import ( + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + has_same_metadata, + to_fun, +) +from .schemas import ( + InputAliasInfo, + MutationType, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .subclass_utils import create_subclass_meta + +from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip + +zip = strict_zip + +log = logging.getLogger(__name__) + + +# This is a version of functionalization that is specifically designed +# for the AOTAutograd use case. +# +# Unlike functorch's variant, this doesn't use the functorch level system, +# instead it directly uses PyTorch's conventional dispatcher to hit the +# functionalization key. In particular, this means that FunctionalTensorWrapper +# can have autograd data stored directly on it. +# +# In typical AOTAutograd usage, the dispatch key order will look like: +# +# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor +# outer tensor inner tensor +# +# Returns: +# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and +# The list of outputs from the forward, but **only** the outputs that we need +# to pass in as tangents into the backward. +# Specifically, aliased outputs from the forward get regenerated, and don't participate +# in the compiled backward function. +def run_functionalized_fw_and_collect_metadata( + f, + *, + keep_input_mutations: bool, + # TODO: refactor to kill this flag + is_train: bool = False, + pre_dispatch: bool = False, +) -> Callable[..., ViewAndMutationMeta]: + memo: Dict[Tensor, Tensor] = {} + + def _to_fun(t): + if isinstance(t, Tensor): + if t in memo: + return memo[t] + r = to_fun(t) + memo[t] = r + return r + else: + return t + + @wraps(f) + def inner(*flat_args): + # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. + assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) + + input_info: List[InputAliasInfo] = [] + output_info: List[OutputAliasInfo] = [] + + prior_grad_enabled = torch.is_grad_enabled() + prior_autocast_states = _get_autocast_states() + + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + # It doesn't matter if we run this under predispatch or not because it is + # only for figuring out metadata + mode = FunctionalTensorMode(_allow_token_discovery=True) + with disable_above, mode: + # precondition: The passed in function already handles unflattening inputs + flattening outputs + flat_f_args = pytree.tree_map(_to_fun, flat_args) + flat_f_outs = f(*flat_f_args) + + if prior_autocast_states != _get_autocast_states(): + raise RuntimeError( + "AOTAutograd does not support tracing graphs that mutate the autocast state. 
" + "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, " + "which will unwind all of their mutations to autocast state before the graph exits. " + "If you encounter this error while using torch.compile, please file a bug." + ) + + # Inspect the state of the input tensor functional wrapper to detect input mutation info + # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version + for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)): + # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in + # strides between the functionalized arg inner tensors and non-functionalized arg inner + # tensors. This is a problem as the inner tensor stride change may not be reflected + # correctly in the outer tensor, so disallow this for now. + mutates_data = has_data_mutation(f_arg) + if ( + mutates_data + and not arg.is_contiguous() + and is_traceable_wrapper_subclass(arg) + ): + raise RuntimeError( + "Mutations on non-contiguous inputs are currently not allowed on " + "tensor subclasses" + ) + + if not isinstance(arg, Tensor): + new_arg = arg + else: + new_arg = from_fun(f_arg) + mutates_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=False + ) + if mutates_metadata and is_traceable_wrapper_subclass(arg): + raise RuntimeError( + "Metadata mutations are currently not allowed on tensor subclasses" + ) + mutates_storage_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=True + ) + mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd( + f_arg + ) + mutations_under_no_grad_or_inference_mode = ( + mutates_data + and are_all_mutations_under_no_grad_or_inference_mode(f_arg) + ) + + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + # TODO: discuss this in the PR. Both supporting this, and detecting + erroring out, + # seem painful to get working. + if mutates_storage_metadata: + mutates_data = False + + requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + + input_info.append( + InputAliasInfo( + is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg), + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutates_storage_metadata=mutates_storage_metadata, + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + requires_grad=requires_grad, + keep_input_mutations=keep_input_mutations, + ) + ) + + # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, + # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view + # to return to the user. Why? 
The backend compiler is free to (incorrectly) not set requires_grad + # on the base tensor, but we are obligated to properly set requires-gradness on the real output. + + inp_storage_refs = { + StorageWeakRef(inpt.untyped_storage()): idx + for idx, inpt in enumerate(flat_f_args) + if isinstance(inpt, Tensor) + } + + # We need inp tensor id's to be able to tell if an outputs **are** inputs. + inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)} + # We need output tensor id's to tell if any output._base` attributes **are** other outputs. + # (This is also a dict because we need to know that output's index, so we can regenerate + # the alias from it). + out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)} + + # Keep track of which outputs alias other outputs + out_tensor_alias_counts: DefaultDict = collections.defaultdict(int) + # This tells us, for a given group of outputs that alias each other, + # whether they e.g. all came from an unbind call + num_aliased_tensors_that_are_multi_output_views: DefaultDict = ( + collections.defaultdict(int) + ) + out_storage_to_tensors: DefaultDict = collections.defaultdict(set) + curr_storage = None + for o in flat_f_outs: + if isinstance(o, torch.Tensor): + curr_storage = StorageWeakRef(o.untyped_storage()) + out_tensor_alias_counts[curr_storage] += 1 + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # This is an optimization on top of the "alias of intermediates" logic, + # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!] + # + # Before describing the optimization: this is important for AOTAutograd to have good + # perf around, multi-output views. HOWEVER: + # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case, + # around using pre-dispatch tracing to partition out a graph so we can faithfully replay all + # views without having to regenerate them at runtime. + # - It's loosely described in this doc (more details will be added soon): + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit + # - Once that change lands, we should just rip out this "optimization", since: + # (1) It will be fully unnecessary + # (2) Although it is only a few lines of code, it is a bit difficult to reason about + # its correctness with the autograd engine in all cases. + # + # + # What is this optimization? Consider the below case: + # def f(x): + # intermediate = x.mul(2) + # # x and intermediate here require grad + # o1, o2, ... o10 = intermediate.unbind(-1) + # return intermediate, o1, o2, ... o10 + # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following: + # (1) return "intermediate as an extra output of the compiled graph + # (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function. + # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know + # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function, + # this information will be hidden. + # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases + # (like their grad_fn, for example, when the autograd engine needs to do view-replay). 
+ # + # However, intermediate_base logic can be bad for backward performance (we sometimes generate + # as_strided calls during the intermediate base logic, which can have a slow backward formula). + # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd? + # + # For a set of outputs of the graph that alias each other, o_1...o_k, consider: + # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0) + # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate), + # **at most** 1 can escape from the graph (e.g. there is not some other graph input/output + # o_other, that aliases these outputs) + # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad. + # This condition is important because it's what causes slowness in the intermediate_base + # codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and + # aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn. + # "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward. + # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta + # of the other aliases? + # + # Claim: No! Consider a few example (which I'm pretty sure cover all cases of mutation w.r.t. autograd): + # (a) What happens if we mutate any of o_1 through o_k directly? + # Autograd raises an error: + # "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is + # the output of a function that returns multiple views. Such functions do not allow the output + # views to be modified inplace. You should replace the inplace operation by an out-of-place one." + # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)? + # Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views. + # (c) What if we mutate o_k under no_grad? + # Autograd raises the same error + # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)? + # Autograd allows this, *but* autograd updates all alias's grad_fn's to be error functions when accessed. + # Autograd raises the same error + # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view? + # We promised that there is at most **one** such alias, e.g. intermediate in the example above. + # You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k + # to be error fn's. + # Since intermediate was the *only* non-multi-output-alias, there are no other aliases + # of `intermediate` around that were produced by the compiled fn and have a valid grad_fn. + # + # Coming back to this optimization: + # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias + # without causing an error in eager mode, we will simple hide the aliasing from autograd during torch.compile + # if all of the above conditions are met. + # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on + # in eager but fail to during torch.compile, but it has the benefit that this code has much better performance. 
+ # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here: + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit, + # then this optimization will probably matter less and might be ok to remove. + is_cur_tensor_multi_out_view = isinstance( + o, FunctionalTensor + ) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined] + o.elem + ) + if is_cur_tensor_multi_out_view: + num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1 + out_storage_to_tensors[curr_storage].add(o) + + # maps the id of an intermediate base to its index in the output of the compiled forward + intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {} + intermediate_bases: List[torch.Tensor] = [] + # Why Do We Care If Storage Changed? + # It's important to understand the implications of storage changes in complex scenarios. Take this example: + # + # def f(x): + # x_storage = x.untyped_storage() + # non_leaf_tensor = torch.ones(4, requires_grad=True).clone() + # + # # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(non_leaf_tensor.untyped_storage()) + # + # out = x.view(-1) + # + # # Restoring x to its original storage, again simulating .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(x_storage) + # + # return out + # + # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing. + # However, due to how set_() and more specificlaly, set is functionalized, is defined to preserve eager semantics, + # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'. + # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated, + # which could lead to issues later in the code. + for o in flat_f_outs: + functional_tensor_storage_changed = isinstance( + o, FunctionalTensor + ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] + o.elem + ) + curr_storage = ( + None + if not isinstance(o, torch.Tensor) + else StorageWeakRef(o.untyped_storage()) + ) + outs_with_identical_metadata_that_require_grad = ( + [] + if not isinstance(o, Tensor) + else [ + curr + for curr in out_storage_to_tensors[curr_storage] + if has_same_metadata(o, curr) + and curr.requires_grad + and o is not curr + ] + ) + + # See Note [Accessing .grad_fn on FunctionalTensor] + # In-place operations on views will trigger a lazy rebase of the autograd graph; + # this runs during access to the .grad_fn. The rebase logic will invoke view ops + # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure + # these op calls succeed. 
+ grad_fn = None + if isinstance(o, Tensor): + with FunctionalTensorMode(): + grad_fn = o.grad_fn + + is_result_of_custom_autograd_fn = False + # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) + # autograd fns + if type(grad_fn).__name__ == "CppFunction": + is_result_of_custom_autograd_fn = True + if isinstance(grad_fn, torch.autograd.function.BackwardCFunction): + is_result_of_custom_autograd_fn = True + + if not isinstance(o, Tensor): + output_type = OutputType.non_alias + base_idx = None + elif ( + curr_storage in inp_storage_refs + and grad_fn is not None + and is_result_of_custom_autograd_fn + ): + output_type = OutputType.custom_function_view + base_idx = None + elif ( + curr_storage in inp_storage_refs + and not functional_tensor_storage_changed + ): + base_idx = inp_storage_refs[curr_storage] + is_input_tensor = id(o) in inp_tensor_ids + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + if ( + grad_fn is not None + and num_aliased_outs_that_are_not_multi_output_views == 0 + ): + # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # In particular, given: + # def f(x): + # return list(x.unbind(0)) + # The main reason we ordinarily try to regenerate these output aliases outside of the + # compiled autograd.Function is because if any of the outputs are later mutated, + # autograd needs to perform view-replay to regenerate them. + # However, autograd does not allow users to mutate multi-output views + # in any way that can change the autograd metadata of other aliases. + # So we hide this aliasing from autograd here. + log.debug( + "Encountered AOTAutograd case: differentiable outputs that \ +alias each other from a multi-output view call" + ) + output_type = OutputType.non_alias + elif is_input_tensor: + output_type = OutputType.is_input + else: + output_type = OutputType.alias_of_input + + # We only need to handle the intermediate base case when both + # the intermediate base and the output require gradients. + # See Note [AOT Autograd: outputs aliasing inputs or intermediates!] + elif o._base is not None and o.requires_grad and o._base.requires_grad: + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + if ( + out_tensor_alias_counts[curr_storage] == 1 + or num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + # Note [Intermediate Bases Optimization] + # Normally if we have an output that aliases an intermediate, + # we need to add the extra "intermediate base" logic further down + # to prevent autograd from yelling at us if the user later tries to + # mutate that output. + # However, the common case here is if we have an output that aliases an intermediate, + # but doesn't alias any other outputs. + # In that case, autograd shouldn't have to worry about the aliasing at all + # (if that output is mutated, there are no other live aliases for autograd to worry about). + # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs. 
+ # So as an optimization, we won't do intermediate base handling in this case. + # Instead, we'll hide the aliasing from autograd using aten._unsafe_view(). + if ( + out_tensor_alias_counts[curr_storage] != 1 + and num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + log.debug( + "Encountered AOTAutograd case: differentiable outputs that alias each other \ +from a multi-output view call" + ) + output_type = OutputType.unsafe_view_alias + base_idx = None + else: + # First, check if o's ._base is an existing output + maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None) + if maybe_existing_out_idx is not None: + # Special case where the output is an alias of a graph intermediate, but that intermediate + # is itself also a user output. + output_type = ( + OutputType.alias_of_intermediate_base_is_user_output + ) + base_idx = maybe_existing_out_idx + else: + # Next, check if o's ._base is an intermediate base that we already returned + maybe_existing_base_output_idx = ( + intermediate_base_tensor_id_to_output_idx.get( + id(o._base), None + ) + ) + if maybe_existing_base_output_idx is not None: + output_type = OutputType.alias_of_intermediate + base_idx = maybe_existing_base_output_idx + else: + # Otherwise, take o._base and explicitly return it as an output in the compiled graph + new_out_idx = len(intermediate_bases) + base_idx = new_out_idx + # Indicate to the logic later on (when we trace the joint) + # that this particular output should get it's ._base appended to the forward graph outputs + output_type = ( + OutputType.alias_of_intermediate_save_as_output + ) + intermediate_base_tensor_id_to_output_idx[ + id(o._base) + ] = new_out_idx + intermediate_bases.append(o._base) + elif ( + # See https://github.com/pytorch/pytorch/issues/100348 for this case. + # This protects against the specific case where a user fn returns (output, output.detach()) + out_tensor_alias_counts[curr_storage] > 1 + and len(outs_with_identical_metadata_that_require_grad) > 0 + and not o.requires_grad + ): + assert len(outs_with_identical_metadata_that_require_grad) > 0 + # In theory we could use any of these tensors to regenerate the aliased outputs from, + # since they all alias each other and have identical metatadata + out_alias = outs_with_identical_metadata_that_require_grad[0] + existing_out_idx = out_tensor_ids[id(out_alias)] + output_type = OutputType.alias_of_intermediate_base_is_user_output + base_idx = existing_out_idx + else: + output_type = OutputType.non_alias + base_idx = None + + if isinstance(o, torch.Tensor): + dynamic_dims = { + i for i, s in enumerate(o.shape) if not is_concrete_int(s) + } + else: + dynamic_dims = None + out_info = OutputAliasInfo( + output_type=output_type, + raw_type=type(o), + base_idx=base_idx, + dynamic_dims=dynamic_dims, + requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, + ) + output_info.append(out_info) + + # See Note [AOT Autograd: Views to avoid tangents aliasing inputs] + def view_avoid_dupes_with_primals(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + return transform_subclass( + t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t) + ) + if isinstance(t, Tensor): + return t.view(t.shape) + return t + + # This analysis function returns *only* the outputs that are meant to be tangents to the backwards. 
+ # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates) + # are *regenerated* later, and not used directly in the autograd graph + f_input_tangents = [ + inp + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + and info.mutates_data + and info.requires_grad + ] + f_output_tangents = [ + o + for o, info in zip(flat_f_outs, output_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases are also included in the backward graph + f_tangents = f_input_tangents + f_output_tangents + intermediate_bases + traced_tangents = pytree.tree_map(from_fun, f_tangents) + traced_tangents = pytree.tree_map( + view_avoid_dupes_with_primals, traced_tangents + ) + user_outs = pytree.tree_map(from_fun, f_output_tangents) + + f_mutated_inputs = [ + inp + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + f_metadata_mutated_inputs = [ + inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata + ] + # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. + # When handling subclasses, we need info about **all** outputs of compiled forward graph, + # so we know precisely which graph outputs to wrap back into tensor subclasses + # Ideally we would refactor this so not have an is_train flag, and have the separate + # inference and training paths decide which inputs/output to ask for subclass info on. + # However, we currently stash indexing information on each SubclassMeta about its order + # in the graph outputs list. + f_fw_graph_outs = list(flat_f_outs) + if is_train or not keep_input_mutations: + f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs + else: + # even when "keep_input_mutations" is True, + # we never keep metadata-only mutations in the fw graph + f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs + if is_train: + f_fw_graph_outs = f_fw_graph_outs + intermediate_bases + fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) + + grad_enabled_mutation = None + if torch.is_grad_enabled() != prior_grad_enabled: + grad_enabled_mutation = torch.is_grad_enabled() + torch.set_grad_enabled( + prior_grad_enabled + ) # Restore the prior state after tracing it + log.debug( + ( + "grad_mode mutation encountered in graph. 
" + "Will emit mutation epilogue, to set grad_mode=%s" + ), + grad_enabled_mutation, + ) + + metadata = ViewAndMutationMeta( + input_info=input_info, + output_info=output_info, + num_intermediate_bases=len(intermediate_bases), + keep_input_mutations=keep_input_mutations, + traced_tangents=traced_tangents, + subclass_inp_meta=create_subclass_meta(flat_args), + subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), + subclass_tangent_meta=create_subclass_meta(traced_tangents), + is_train=is_train, + grad_enabled_mutation=grad_enabled_mutation, + tokens=mode._tokens, + ) + return metadata + + return inner diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e38f5cab3dd8f0482feb364c6cf64995a197f6ce --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -0,0 +1,192 @@ +""" +This module dispatches the graphs to either the forward-only or joint compilation +pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. +""" + +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.fx.experimental.proxy_tensor import make_fx + +from .functional_utils import ( + assert_functional_graph, + propagate_input_mutation_stacktraces, +) +from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta +from .traced_function_transforms import ( + aot_dispatch_subclass, + create_functionalized_fn, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) + +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + + +def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule: + # FunctionalTensorMode must be enabled here. + # See Note [Accessing .grad_fn on FunctionalTensor] + with enable_python_dispatcher(), FunctionalTensorMode( + pre_dispatch=aot_config.pre_dispatch, export=aot_config.is_export + ): + fx_g = make_fx( + f, + decomposition_table=aot_config.decompositions, + record_module_stack=True, + pre_dispatch=aot_config.pre_dispatch, + )(*args) + + return fx_g + + +def aot_dispatch_base_graph( + flat_fn, + flat_args: List[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> Union[Callable, Tuple[Callable, List[Any], Optional[SubclassMeta]]]: + # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. 
+ # The cases that aot_dispatch_base doesn't need to handle include: + # - outputs that are aliases of graph intermediates + # - outputs that are aliases of graph inputs + # While cases that it does need to handle include: + # - input mutations (including when inputs are aliases of each other) + # - input metadata mutations + fn_to_trace = fn_input_mutations_to_outputs( + flat_fn, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + + fn_to_trace, updated_flat_args = create_functionalized_fn( + fn_to_trace, + flat_args, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=False, + ) + + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + maybe_subclass_meta, + ) = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + is_joint_structure=False, + meta=fw_metadata, + fw_only=flat_fn, + ) + + fw_module = _create_graph( + fn_to_trace, + updated_flat_args_subclasses_desugared, + aot_config=aot_config, + ) + + # As long as we opted to remove input mutations, then + # there should be *NO* mutating ops in the graph at this point. + copy_count = assert_functional_graph(fw_module.graph) + + fw_module.graph.eliminate_dead_code() + fw_module.recompile() + + copy_count2 = assert_functional_graph(fw_module.graph) + propagate_input_mutation_stacktraces(fw_module.graph) + + assert copy_count == copy_count2 + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id) + ) + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module.print_readable(print_output=False), + ) + + # TODO: should factor this into a separate function for export that always only returns just the graph. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fw_module + return fw_module, list(updated_flat_args_subclasses_desugared), maybe_subclass_meta + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd_graph( + flat_fn, + flat_args: List[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> Union[Callable, Tuple[Callable, List[Any], Optional[SubclassMeta]]]: + # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations. 
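+    # Hedged example (an assumption, for illustration only) of what ends up in traced_tangents:
+    #     def f(x):               # x: a non-leaf tensor that requires grad
+    #         x.mul_(2)           # data mutation on an input
+    #         return x + 1, x.view(-1)
+    # traced_tangents would roughly be [updated x, x + 1]; the x.view(-1) output aliases an
+    # input, so it is regenerated at runtime instead of receiving its own grad_output.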
+ traced_tangents = pytree.tree_map( + lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, + fw_metadata.traced_tangents, + ) + + joint_inputs = (flat_args, traced_tangents) + + fn_prepared_for_autograd = fn_prepped_for_autograd( + flat_fn, + fw_metadata, + ) + joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config) + + joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn( + joint_fn_to_trace, + joint_inputs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + ) + + subclass_tracing_info = aot_dispatch_subclass( + joint_fn_to_trace, + updated_joint_inputs, + is_joint_structure=True, + meta=fw_metadata, + fw_only=flat_fn, + ) + + joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_joint_inputs = subclass_tracing_info.plain_tensor_args + maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta + + fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) + + # There should be *NO* mutating ops in the graph at this point. + assert_functional_graph(fx_g.graph) + + # Redundant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect + # when we need to manually detach() some inputs in the forward. + # Higher order ops might eventually need to do the same. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fx_g + return fx_g, updated_joint_inputs, maybe_subclass_meta diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1124e4f3a1dfdb1719dcf0732763ee819c9b4c1d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py @@ -0,0 +1,370 @@ +""" +This file contains utilities related to functionalization in AOTAutograd: +1. converting to/from functional tensors +2. detecting Tensor mutations - both metadata and Tensor value +3. regenerating/replaying views from their base +4. checking if a graph is functional i.e. whether it contains any mutation ops +""" + +import torch +from torch import Tensor +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + + +def to_fun(t): + if isinstance(t, Tensor): + if is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. 
+ # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + else: + return FunctionalTensor.to_functional(t) + else: + return t + + +def sync_functional_tensor(t): + if is_traceable_wrapper_subclass(t): + attrs, ctx = t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + sync_functional_tensor(getattr(t, attr)) + else: + torch._sync(t) + + +# When subclasses are involved, t here will usually look something like: +# SubclassA(SubclassB(FunctionalTensor(_to_fun_tensor(FakeTensor)))) +def from_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) # type: ignore[attr-defined] + return t + sync_functional_tensor(t) + return torch._from_functional_tensor(t.elem) + + +def is_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + t_attrs, _ = t.__tensor_flatten__() # type: ignore[attr-defined] + t_inners = [getattr(t, attr) for attr in t_attrs] + any_fun = any(is_fun(x) for x in t_inners) + all_fun = all(is_fun(x) for x in t_inners) + assert any_fun == all_fun + return any_fun + + return isinstance(t, FunctionalTensor) + + +# t here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +def has_data_mutation(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + return any(has_data_mutation(getattr(t, attr)) for attr in attrs) + else: + if isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_has_data_mutation(t.elem) # type: ignore[attr-defined] + return False + + +def are_all_mutations_hidden_from_autograd(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd. 
+ return all( + are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs + ) + elif isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem) + else: + return False + + +def are_all_mutations_under_no_grad_or_inference_mode(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + return all( + are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr)) + for attr in attrs + ) + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t.elem + ) + + +# f_arg here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +# Assumption: arg promises to be the "original" tensor wrapped by f_arg +# Note: "storage mutations" coming from set_() are a type of metadata mutation. So: +# - check_only_storage_mutation=True: only return true if there was a storage mutation +# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation) +def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool): + if is_traceable_wrapper_subclass(f_arg): + attrs, _ = f_arg.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + f_inner_ts = [getattr(f_arg, attr) for attr in attrs] + inner_ts = [getattr(arg, attr) for attr in attrs] + return any( + has_metadata_mutation( + f_inner_t, + inner_t, + check_only_storage_mutation=check_only_storage_mutation, + ) + for f_inner_t, inner_t in zip(f_inner_ts, inner_ts) + ) + else: + if not isinstance(f_arg, torch.Tensor): + assert not isinstance(arg, torch.Tensor) + return False + assert isinstance(f_arg, FunctionalTensor) + assert isinstance(arg, FakeTensor) + + arg_after = torch._from_functional_tensor(f_arg.elem) + # This is true if the current tensor experienced at least one set_() call + maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) # type: ignore[attr-defined] + # However, multiple set_() calls can cancel out. So we also check whether the + # storage of the tensor has changed. + # Note: if an input experienced two set_() calls that cancel out, **and** + # it experiences an data mutation, we pessimistically think that the set_() + # call is necessary here. We could in theory fix this, but this will + # hopefully never happen in user code, and is not needed for fsdp. + same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef( + arg_after.untyped_storage() + ) + has_storage_metadata_mutation = maybe_storage_changed and not same_storages + if check_only_storage_mutation: + return has_storage_metadata_mutation + + # storage metadata mutation is a type of metadata mutation, so return true if we saw one + if has_storage_metadata_mutation: + return True + + maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) # type: ignore[attr-defined] + # This is true if the current tensor experienced at least one metadata mutation. + # So if false, we know there was no metadata mutation + if not maybe_metadata_mutated: + return False + + # However, multi metadata mutations can cancel out. + # So we also check if the concrete sizes/strides on the tensor have changed. 
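+        # Hedged illustration (assumed, not from the original source): two metadata mutations
+        # that cancel out, e.g.
+        #     a.transpose_(0, 1)
+        #     a.transpose_(0, 1)
+        # leave the sizes/strides/storage_offset unchanged, so the checks below report no
+        # metadata mutation even though a metadata mutation was recorded.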
+ same_sizes = arg.shape == arg_after.shape + same_strides = arg.stride() == arg_after.stride() + same_offsets = arg.storage_offset() == arg_after.storage_offset() + has_metadata_mutation_ = maybe_metadata_mutated and not ( + same_sizes and same_strides and same_offsets + ) + # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call. + return has_metadata_mutation_ + + +def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad): + # Try to do view-replay if possible. + # fall back to .as_strided() if we can't. + if target_meta_tensor._base is not None: + # The base that we want to replay our view off of might have a different shape than the view's original base. + b = target_meta_tensor._base + abt = aliased_base_tensor + # Don't unnecessarily call as_strided if nothing changed; as_strided's + # backward is poorly implemented and slow + if abt is not b and ( + abt.size() != b.size() + or abt.stride() != b.stride() + or abt.storage_offset() != b.storage_offset() + ): + reshaped_base_tensor = aliased_base_tensor.as_strided( + b.size(), b.stride(), b.storage_offset() + ) + else: + reshaped_base_tensor = aliased_base_tensor + out = target_meta_tensor._view_func(reshaped_base_tensor) + # This shape mismatch can happen due to a bug in inplace/view handling in autograd. + # Try putting a breakpoint here and running + # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types` + # Also, https://github.com/pytorch/pytorch/issues/49825 + # + # As a stopgap, we'll fall back to as_strided. + if out is not None and out.shape == target_meta_tensor.shape: + if aliased_base_tensor.requires_grad and not target_requires_grad: + out = out.detach() + elif not aliased_base_tensor.requires_grad and target_requires_grad: + out.requires_grad_(True) + return out + size = target_meta_tensor.size() + stride = target_meta_tensor.stride() + storage_offset = target_meta_tensor.storage_offset() + if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex(): + aliased_out = torch.view_as_real(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex(): + aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + else: + aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset) + # For outputs aliasing inputs, we need to check if the requires-gradness has changed. + if aliased_base_tensor.requires_grad and not target_requires_grad: + aliased_out = aliased_out.detach() + elif not aliased_base_tensor.requires_grad and target_requires_grad: + aliased_out.requires_grad_(True) + # For outputs aliasing inputs, we need to check if the dtype has changed. + # as_strided() is the "most generic" view, but it does not cover cross-dtype views + if aliased_out.dtype != target_meta_tensor.dtype: + aliased_out = aliased_out.view(target_meta_tensor.dtype) + return aliased_out + + +def has_same_metadata(t1, t2): + return ( + definitely_true(sym_eq(t1.size(), t2.size())) + and definitely_true(sym_eq(t1.stride(), t2.stride())) + and definitely_true(t1.storage_offset() == t2.storage_offset()) + and t1.is_conj() == t2.is_conj() + and t1.is_neg() == t2.is_neg() + ) + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. 
+# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed +# +# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization +# to confirm that inputs were not mutated when running the user's model with functionalization on. +# But when we have subclass inputs, we can't rely on that: +# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs +# a brand new subclass instance: we are calling __tensor_unflatten__, and going +# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor)) +def was_tensor_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed, +# but shares storage with the old input +def was_tensor_metadata_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg and StorageWeakRef( + arg.untyped_storage() + ) == StorageWeakRef(new_arg.untyped_storage()) + + +# Returns the number of detected copy_ +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + placeholders = set() + copy_count = 0 + # NB: It would also be nice to verify that the mutations all happen at the + # end, but we also do some administrative views after mutations so this + # isn't actually true. (TODO: Could this cause problems for Inductor?) 
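+    # Hedged sketch (assumption; node names and formatting are illustrative) of the only mutation this check accepts:
+    #     %arg0 : placeholder
+    #     %mul  : call_function[target=aten.mul.Tensor](%arg0, 2)
+    #     %copy : call_function[target=aten.copy_.default](%arg0, %mul)
+    # i.e. at most one copy_ per graph input, writing into a placeholder directly; any other op
+    # with a mutable schema trips the assert in the loop below.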
+ for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target is torch.ops.aten.copy_.default: + suffix = True + # Can only copy_ into an input, and can only do so once + assert n.args[0] in placeholders + placeholders.remove(n.args[0]) + copy_count += 1 + else: + assert ( + not n.target._schema.is_mutable + ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return copy_count + + +def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: + placeholders = set() + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target is torch.ops.aten.copy_.default: + # Can only copy_ into an input, and can only do so once + assert n.args[0] in placeholders + placeholders.remove(n.args[0]) + copy_from_node = n.args[1] + # Pre-condition: every node has a "stack_trace" field in its meta, + # but copy_() nodes do not (since we manually added them during functionalization). + # Instead, we manually propagate here. + if "stack_trace" in copy_from_node.meta: + assert "stack_trace" not in n.meta, str(n) + n.meta["stack_trace"] = copy_from_node.meta["stack_trace"] + + +def _check_if_mutation_can_be_in_graph( + keep_input_mutations: bool, + mutates_data, + mutates_metadata, + mutations_hidden_from_autograd, + mutations_under_no_grad_or_inference_mode, + requires_grad, +): + if keep_input_mutations: + return mutates_data and ( + (not mutates_metadata and not requires_grad) + or mutations_hidden_from_autograd + or mutations_under_no_grad_or_inference_mode + ) + return False diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..fe926247820f2ccc9d9063e0e3bb671ac0ce1096 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -0,0 +1,432 @@ +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the following analyses are provided: +1. Refine the view and mutation metadata collected previously - removing duplicate + inputs or mapping views to their bases. +2. We also analyze the function signature for export graphs. 
+""" + +import itertools +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental.symbolic_shapes import is_concrete_int +from .schemas import ( + BackwardSignature, + GraphSignature, + InputAliasInfo, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .utils import strict_zip + +zip = strict_zip + + +def remove_dupe_metadata( + m: ViewAndMutationMeta, + keep_arg_mask: List[bool], + add_dupe_map: List[int], +) -> ViewAndMutationMeta: + assert len(m.input_info) == len(keep_arg_mask) + # Easy invariant: the first argument should never be a dupe (it will be kept) + assert len(keep_arg_mask) > 0 and keep_arg_mask[0] + + # Filter dupe'd mutated inputs out of traced_tangents + num_data_mutations = len([x for x in m.input_info if x.mutates_data]) + other_traced_tangents = m.traced_tangents[num_data_mutations:] + inp_traced_tangents = m.traced_tangents[:num_data_mutations] + filtered_inp_traced_tangents = [ + x + for i, x in enumerate(inp_traced_tangents) + if keep_arg_mask[m.mutated_inp_runtime_indices[i]] + ] + traced_tangents = filtered_inp_traced_tangents + other_traced_tangents + + return ViewAndMutationMeta( + input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]], + # For outputs that are views of inputs, we store the index of the input that the output + # was generated from. Need to update that index to account for removed dupes. + output_info=[ + OutputAliasInfo( + output_type=o.output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], + requires_grad=o.requires_grad, + ) + for o in m.output_info + ], + num_intermediate_bases=m.num_intermediate_bases, + keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + # We are guaranteed not to get here, since dupes are not supported today with subclass inputs. + subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=[], + is_train=m.is_train, + ) + + +# Given our ViewAndMutation metadata, this fn constructs a new set of metadata, +# after adding synthetic base arguments to the function. +# Most of the work in this fn is slogging through all of the metadata corresponding to inputs, +# and updating it with our synthetic base calling convention. +# +# When config.debug_assert is set, we automatically regenerate the metadata +# and compare it to this output for sanity. +# +# In addition to the updated metadata, also return the list of input indices +# that will need to be updated in the synthetic base epilogue + + +# Given our ViewAndMutation metadata, this fn constructs a new set of metadata, +# after adding synthetic base arguments to the function. +# Most of the work in this fn is slogging through all of the metadata corresponding to inputs, +# and updating it with our synthetic base calling convention. +# +# When config.debug_assert is set, we automatically regenerate the metadata +# and compare it to this output for sanity. 
+# +# In addition to the updated metadata, also return the list of input indices +# that will need to be updated in the synthetic base epilogue +def create_synthetic_base_metadata( + m: ViewAndMutationMeta, + # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a + # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata) + synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]], + outer_args: List[Any], + inner_args: List[Any], +) -> Tuple[ViewAndMutationMeta, List[int]]: + # maps inner arg indices to outer arg indices + synthetic_base_to_indices: Dict[int, List[int]] = {} + for inner_idx in range(len(inner_args)): + outer_aliased_indices_of_current_base_arg = [ + outer_idx + for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info) + if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx) + or ( + isinstance(inner_idx_or_tuple, tuple) + and inner_idx_or_tuple[0] == inner_idx + ) + ] + synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg + + # given the requires_grad info on mutated inputs, + # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases. + input_infos = [] + for outer_indices in synthetic_base_to_indices.values(): + # leaf-ness should be all-or-nothing for aliased tensor. + # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf) + any_leaf = any(m.input_info[x].is_leaf for x in outer_indices) + all_leaf = all(m.input_info[x].is_leaf for x in outer_indices) + assert any_leaf == all_leaf + + mutates_data = ( + True + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_data + ) + mutates_metadata = ( + False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_metadata + ) + requires_grad = any(m.input_info[x].requires_grad for x in outer_indices) + mutations_hidden_from_autograd = all( + m.input_info[x].mutations_hidden_from_autograd for x in outer_indices + ) + mutations_under_no_grad_or_inference_mode = all( + m.input_info[x].mutations_under_no_grad_or_inference_mode + for x in outer_indices + ) + + inpt_info = InputAliasInfo( + # If len(outer_indices) > 1, then this input is a synthetic base. + # The invariant is that to the rest of aot autograd, synthetic bases only show up if + # one of their aliases gets a data mutation. And if any of their aliases get metadata + # mutations, they will be hidden from the rest of aot autograd. 
+ mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=all( + m.input_info[x].mutations_hidden_from_autograd for x in outer_indices + ), + mutates_storage_metadata=False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_storage_metadata, + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + is_leaf=any_leaf, + requires_grad=requires_grad, + keep_input_mutations=m.keep_input_mutations, + ) + input_infos.append(inpt_info) + + # Find any inputs that fulfill the following criteria: + # (1) They are part of a synthetic base (because they alias another input, + # and at least one input experiences a data mutation) + # (2) They experience a metadata mutation + outer_aliased_arg_idx_with_metadata_mutations = [ + outer_idx + for outer_idx, inpt_info in enumerate(m.input_info) + if inpt_info.mutates_metadata + and not isinstance(synthetic_base_info[outer_idx], int) + ] + + # grab the original requires grad info on the outputs, except the ones from the mutated inputs + input_metadata_output_info = [ + OutputAliasInfo( + output_type=OutputType.alias_of_input, + raw_type=FunctionalTensor, + dynamic_dims={ + i + for i, s in enumerate(outer_args[outer_idx].shape) + if not is_concrete_int(s) + }, + base_idx=synthetic_base_info[outer_idx][0], # type: ignore[index] + requires_grad=outer_args[outer_idx].requires_grad, + ) + for outer_idx in outer_aliased_arg_idx_with_metadata_mutations + ] + existing_output_infos = [] + for o in m.output_info: + new_base_idx = ( + None + if o.base_idx is None + else ( + synthetic_base_info[o.base_idx] + if isinstance(synthetic_base_info[o.base_idx], int) + else synthetic_base_info[o.base_idx][0] # type: ignore[index] + ) + ) + # If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change + new_output_type = ( + OutputType.alias_of_input + if o.output_type == OutputType.is_input and o.base_idx != new_base_idx + else o.output_type + ) + existing_output_infos.append( + OutputAliasInfo( + output_type=new_output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases + base_idx=new_base_idx, # type: ignore[arg-type] + requires_grad=o.requires_grad, + ) + ) + + inner_mutated_tangents = [ + x + for inner_idx, x in enumerate(inner_args) + if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad + ] + + output_info = existing_output_infos + input_metadata_output_info + # Regenerate traced tangents to include mutated inputs including synthetic bases + traced_tangents = ( + inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :] + ) + + return ( + ViewAndMutationMeta( + input_info=input_infos, + output_info=output_info, + num_intermediate_bases=m.num_intermediate_bases, + keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs. 
+ subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=[], + is_train=m.is_train, + ), + outer_aliased_arg_idx_with_metadata_mutations, + ) + + +def _get_last_mem_address(x): + out = x.storage_offset() + for size, stride in zip(x.size(), x.stride()): + out += (size - 1) * stride + return out + + +# Assumption: x and y are known to share a storage, and we are trying to determine +# if their memory is actually completely disjoint, based on sizes/strides/storage_offset +def _tensors_definitely_do_not_overlap(x, y): + if x is y: + return False + if x.numel() == 0 or y.numel() == 0: + return True + + # Make x always on the left + if x.storage_offset() > y.storage_offset(): + x, y = y, x + # Short-circuit in the "obvious" overlapping case: both tensors are contiguous + if x.is_contiguous() and y.is_contiguous(): + if x.storage_offset() + x.numel() > y.storage_offset(): + # definitely overlap + return False + else: + # definitely no overlap + return True + + # Short-circuit: if last memory address of x is < start of y, then not overlapping. + x_last = _get_last_mem_address(x) + if x_last < y.storage_offset(): + return True + + if x.dim() == 2 and y.dim() == 2 and x.stride(1) == 1 and y.stride(1) == 1: + # This cases is needed for the shampoo optimizer. + # All tensors are 2d (non-contiguous), have the same outer stride, and have an inner stride of 1 + # (so rows are contiguous) + if x.stride(0) == y.stride(0): + offset_delta = y.storage_offset() - x.storage_offset() + if offset_delta < x.size(1): + # definitely overlaps (row 0 of y overlaps with row 0 of x) + # Example: + # base = torch.arange(32).reshape(4, 8) + # x = base.narrow(1, 0, 4) + # x: size=(4, 4), stride=(8, 1), offset=0 + # y = base.narrow(1, 3, 4) + # y: size=(4, 4), stride=(8, 1), offset=3 + return False + x_total_elems_covered = x.stride(0) * (x.size(0) - 1) + x.size(1) + if x_total_elems_covered <= offset_delta: + # definitely does not overlap (last byte of x is before start of y) + # Example: + # x: size=(4, 4), stride=(8, 1), offset=0 (last byte is 27) + # y: size=(4, 4), stride=(8, 1), offset=28 (start byte is 28) + return True + # At this point, we want to check if the 0th row of y + # overlaps with **some** row of x. + # We can check this by shifting y backward by the shared stride, repeatedly, + # until the first row of y is before the first row of x. + # Then we can check if these rows overlap. + # We can accomplish this by modding our offset by the stride. 
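+            # Hedged worked example (assumption), reusing the strides from the example below:
+            # with x.stride(0) == 9, x.size(1) == 4, y.size(1) == 4 and offset_delta == 22:
+            #     offset_delta_mod = 22 % 9 = 4
+            # so y's row 0 starts at column 4 of some row of x, and 4 + 4 <= 9 means it also
+            # ends before the next row of x begins: definitely no overlap.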
+ offset_delta_mod = offset_delta % x.stride(0) + # Example: + # 0 1 2 3 + # 9 10 11 12 + # 18 19 20 21 + # 27 28 29 30 + # x: size=(4, 4), stride=(9, 1), offset=0 + # y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap) + # y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap) + # y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap) + # y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap) + # If the interval [modded_offset, modded_offset + x_size] falls entirely + # without + if offset_delta_mod + y.size(1) <= x.stride(0): + return True + else: + return False + return False + + +def compute_overlapping_inputs(fwd_inputs, aliased_input_indices): + actual_aliased_indices = set() + for j in range(len(aliased_input_indices)): + for i in range(j): + i_ = aliased_input_indices[i] + j_ = aliased_input_indices[j] + if not _tensors_definitely_do_not_overlap(fwd_inputs[i_], fwd_inputs[j_]): + actual_aliased_indices.add(i_) + actual_aliased_indices.add(j_) + return actual_aliased_indices + + +def _graph_input_names(gm): + return [node.name for node in gm.graph.nodes if node.op == "placeholder"] + + +def _graph_output_names(gm): + output_node = next(iter(reversed(gm.graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + return [getattr(return_arg, "name", None) for return_arg in return_args] + + +def create_graph_signature( + fx_g: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + *, + user_args_flat: List[Tensor], + params_and_buffers_flat: List[Tensor], + param_names: List[str], + buffer_names: List[str], + trace_joint: bool, + num_user_fw_outs: Optional[int], + loss_index: Optional[int], +) -> GraphSignature: + # Retrieve graph input names + graph_input_names = _graph_input_names(fx_g) + # Retrieve graph output names + graph_output_names = _graph_output_names(fx_g) + + num_params_buffers = len(param_names) + len(buffer_names) + num_tokens = len(fw_metadata.tokens) + # We have enough restrictions on the graph (no de-duping, synthetic bases, etc), + # Such that # graph inps = # user inps + # params + # buffers + num_user_args = len(graph_input_names) - num_params_buffers - num_tokens + + if trace_joint: + assert num_user_fw_outs is not None + num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices + backward_output_names = graph_output_names[num_fw_outs:] + + grad_index = itertools.count(0) + gradients_to_parameters = { + backward_output_names[next(grad_index)]: param_names[i] + for i, param in enumerate(params_and_buffers_flat) + if param.requires_grad + } + + gradients_to_user_inputs = { + backward_output_names[next(grad_index)]: graph_input_names[ + i + len(params_and_buffers_flat) + ] + for i, user_input in enumerate(user_args_flat) + if user_input.requires_grad + } + + assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len( + backward_output_names + ) + + # Check that we have fully accounted for all graph outputs + backward_signature = BackwardSignature( + gradients_to_parameters, + gradients_to_user_inputs, + graph_output_names[loss_index], + ) + else: + backward_signature = None + num_user_fw_outs = ( + len(graph_output_names) + - fw_metadata.num_mutated_inp_runtime_indices + - num_tokens + ) + + return GraphSignature.from_tracing_metadata( + in_spec=in_spec, + out_spec=out_spec, + graph_input_names=graph_input_names, + graph_output_names=graph_output_names, + 
view_mutation_metadata=fw_metadata, + named_parameters=param_names, + named_buffers=buffer_names, + num_user_inputs=num_user_args, + num_user_outputs=num_user_fw_outs, + loss_index=loss_index, + backward_signature=backward_signature, + ) diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..1044e7757ceb1860c7ce71910f5fc7d158551040 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -0,0 +1,936 @@ +""" +These are the runtime wrappers that are associated with JIT-compiling. + +This includes the forward-only and joint JIT runtime wrappers. + +This module depends heavily on the runtime wrapper building blocks defined +in `runtime_wrappers`. +""" + +import logging +from contextlib import nullcontext +from functools import wraps +from typing import Any, List, Optional + +import torch +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import lazy_format_graph_code +from torch._guards import detect_fake_mode, tracing, TracingContext +from torch._logging import getArtifactLogger, trace_structured +from torch._prims_common import CUDARngStateHelper +from torch._subclasses import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals +from .. import config +from .dispatch_and_compile_graph import ( + aot_dispatch_autograd_graph, + aot_dispatch_base_graph, +) +from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling + +from .runtime_wrappers import ( + aot_dispatch_subclass_wrapper, + create_runtime_wrapper, + functionalized_rng_runtime_epilogue, +) +from .schemas import ( + AOTConfig, + MutationType, + OutputType, + SubclassMeta, + TensorAlias, + ViewAndMutationMeta, +) +from .subclass_utils import ( + compute_inner_mutated_inp_indices_from_subclass_meta, + unwrap_tensor_subclasses, + wrap_tensor_subclasses, +) + +from .utils import ( + _get_symint_hints, + call_func_at_runtime_with_args, + make_boxed_func, + normalize_as_list, + strict_zip, +) + +zip = strict_zip + +log = logging.getLogger(__name__) +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + +aten = torch.ops.aten + + +def _compute_output_meta_with_inductor_strides(fw_module, fwd_output_strides): + out = [n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0])] + # will only be set for inductor + if not fwd_output_strides: + return out + with TracingContext.get().fake_mode.shape_env.suppress_guards(): + for i in range(len(out)): + if not isinstance(out[i], Tensor): + continue + if all(s1 == s2 for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])): + continue + out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i]) + return out + + +def aot_dispatch_base( + flat_fn, + flat_args: List[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +): + fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + disable_amp = torch._C._is_any_autocast_enabled() + context = torch._C._DisableAutocast if disable_amp else nullcontext + fakified_out = None + + with context(), 
track_graph_compiling(aot_config, "inference"): + compiler = ( + aot_config.inference_compiler + if aot_config.inference_compiler is not None + else aot_config.fw_compiler + ) + if config.functionalize_rng_ops: + # Add the seed and offset as example inputs to pass to the compiler + fake_mode = detect_fake_mode() + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) + updated_flat_args.extend([seed, offset]) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw = compiler(fw_module, updated_flat_args) + + # see note: [Returning Fake Tensors on First AOT Autograd Call] + if tracing_context and tracing_context.fakify_first_call: + fakified_out = _compute_output_meta_with_inductor_strides( + fw_module, fwd_output_strides + ) + + # However, create_runtime_wrapper does not expect the rng offsets in the + # output. So, we have to create another wrapper and take out the offset. As + # a result, we have to account for not boxed_call compilers as well. + if not hasattr(compiled_fw, "_boxed_call"): + compiled_fw = make_boxed_func(compiled_fw) + + # Create a wrapper to set up the rng functionalize bits + @wraps(compiled_fw) + def rng_functionalization_wrapper(args): + # see note: [Returning Fake Tensors on First AOT Autograd Call] + nonlocal fakified_out + if fakified_out is not None: + out = fakified_out + fakified_out = None + return out + + # args is a list because compiled_fw is boxed_call + if fw_metadata.is_rng_op_functionalized: + # Add the seed and offset to args + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() + args.extend([seed, offset]) + out = compiled_fw(args) + out = functionalized_rng_runtime_epilogue(fw_metadata, out) + return out + else: + return compiled_fw(args) + + if maybe_subclass_meta is not None: + compiled_fw_func = aot_dispatch_subclass_wrapper( + rng_functionalization_wrapper, + subclass_metas=fw_metadata.subclass_fw_graph_out_meta, + num_fw_outs_saved_for_bw=None, + ) + else: + compiled_fw_func = rng_functionalization_wrapper + + if not hasattr(compiled_fw_func, "_boxed_call"): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + compiled_fn = create_runtime_wrapper( + compiled_fw_func, + runtime_metadata=fw_metadata, + indices_of_inps_to_detach=[], + trace_joint=False, + keep_input_mutations=aot_config.keep_inference_input_mutations, + disable_amp=disable_amp, + ) + + return compiled_fn + + +def aot_dispatch_autograd( + flat_fn, + flat_args: List[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +): + fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( # type: ignore[misc] + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Copied from aot_dispatch_autograd_graph. 
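+    # Hedged note (assumption, mirroring the inference path above): if autocast is active, the
+    # casts were already traced into the joint graph, so the compiled forward/backward are run
+    # with autocast disabled to avoid double-casting, e.g.
+    #     context = torch._C._DisableAutocast if disable_amp else nullcontext
+    #     with context():
+    #         ...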
+ disable_amp = torch._C._is_any_autocast_enabled() + + if aot_config.enable_log: + aot_joint_log.info( + "%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id) + ) + trace_structured( + "aot_joint_graph", + payload_fn=lambda: fx_g.print_readable(print_output=False), # type: ignore[union-attr] + ) + + fakify_first_call = False + fakified_out = None + + with torch.no_grad(): + inner_meta = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + with track_graph_compiling(aot_config, "joint"): + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = ( + compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + ) + num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) + num_inner_fwd_outputs = ( + num_mutated_inp_runtime_indices + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + len( + fw_metadata.tokens + ) # See Note [Side-Effectful Tokens in AOTAutograd] + ) + fw_module, bw_module = aot_config.partition_fn( + fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs + ) + + fw_outs = next(n for n in fw_module.graph.nodes if n.op == "output").args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + symint_outs_saved_for_bw = [ + n for n in fw_outs_saved_for_bw if is_sym_node(n) + ] + fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + _num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + + # Note [Detaching inputs that never need gradients] + # See https://github.com/pytorch/pytorch/issues/97745 + # Suppose we have a function like this that we want to compile: + # + # def f(x, y): + # return torch.mul(x, y.detach()) + # + # What gradients should we compute for x and y? + # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, + # and so we'll compute: + # x_grad_input = y + # y_grad_input = None + # Does this preserve the semantics of eager mode? + # Unfortunately, no. + # Doing the above will cause autograd to **continue** to backprop the autograd tape + # that was generated from constructing y. + # + # This is **different** from what would have happened in eager mode. + # In eager mode, if we backprop through the output of this function, autograd will only traverse + # the bit of the autograd tape corresponding to "x". + # In particular, if a user had previously backpropped through y's autograd tape, + # And then they try to backprop through the output of the above function, + # then we'll hit the dreaded "Trying to backward through the graph a second time" error. + # + # You might think: If autograd sees that a gradient is None, shouldn't it stop early, + # instead of continuing the backprop through the ancestors of that node in the graph? + # + # Autograd has two passes: + # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed + # (2) a second pass that actually goes ahead and executes each node when it becomes ready, + # propagating gradients + # By the time we're executing a node and we see that it produces a None, the set of nodes to execute + # is already locked-in. 
+ # + # The fix: instead, we can recognize statically that the graph we're compiling will never contribute + # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. + # We can do this by manually detach'ing y before sending it through the `CompiledFunction`. + # + # Note that this solution is not bulletproof. + # It's possible to construct a case where eager may or may not have tried to autograd through y, + # depending on the actual grad_outputs that were passed in during the backward. + # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, + # allowing autograd to re-use the graph. + # + # An example of this case is: + # def f(x): + # return x.detach() * 2, x * 3 + # If we were to only backprop through outs[0] in eager, we would stop at the detach() and never send a grad through x. + # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 + # and we will end up with a zero grad at x. + # If we later backprop through the second output, this will also require backprop'ing through x. + # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. + _indices_of_inps_to_detach = [] + bw_outs = next(n for n in bw_module.graph.nodes if n.op == "output").args[0] + + # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" + # optimization even if we have subclass inputs/outputs (we do not handle this today). + # Computing which of our inputs get None gradients is a bit more complicated, + # if any of our inputs are subclasses. Why? + # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. + # (b) The grad_outputs that we AOT computed in our backward graph are the desugared dense tensors, + # so we need to figure out which subclass fw inputs they map to. + if maybe_subclass_meta is None: + assert ( + len(bw_outs) + == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset + ) + for i, (bw_out) in enumerate(bw_outs): + if bw_out is None: + _indices_of_inps_to_detach.append(i) + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id), + ) + aot_graphs_log.info( + "%s", + lazy_format_graph_code("Backward graph", bw_module, aot_config.aot_id), + ) + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module.print_readable(print_output=False), + ) + trace_structured( + "aot_backward_graph", + payload_fn=lambda: bw_module.print_readable(print_output=False), + ) + + with track_graph_compiling(aot_config, "forward"): + # flat_args at this point might still be subclasses- + # make sure to pass the unwrapped fake tensors into the compiler!
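+            # (Editor's note, not part of the upstream file: joint_inputs is assumed
+            # to be the (primals, tangents) pair produced alongside the joint graph,
+            # so joint_inputs[0] below is the list of forward inputs - already
+            # desugared into plain fake tensors - used as example inputs for fw_compiler.)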
+ adjusted_flat_args = joint_inputs[0] + if config.functionalize_rng_ops: + # Update example inputs for the fw_compiler + fake_mode = detect_fake_mode() + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) + adjusted_flat_args.extend([seed, offset]) + # We are not clearing flat_args here because + # 1) There is a check in the debug compiler at the end + # 2) It does not matter as these are fake tensors + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = inner_meta + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args) + if not hasattr(compiled_fw_func, "_boxed_call"): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + # see note: [Returning Fake Tensors on First AOT Autograd Call] + if tracing_context and tracing_context.fakify_first_call: + fakified_out = _compute_output_meta_with_inductor_strides( + fw_module, fwd_output_strides + ) + fakify_first_call = True + + if maybe_subclass_meta is not None: + # Why do we need to pass in num_fw_outs_saved_for_bw? + # See Note: [Partitioner handling for Subclasses, Part 2] + compiled_fw_func = aot_dispatch_subclass_wrapper( + compiled_fw_func, + subclass_metas=fw_metadata.subclass_fw_graph_out_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ) + if not hasattr(compiled_fw_func, "_boxed_call"): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + # NB: It's important to compile backwards ahead of time, as this may + # add extra guards which we need to apply to the Dynamo cache at + # forwards + with track_graph_compiling(aot_config, "backward"): + placeholder_list = fx_placeholder_vals(bw_module) + + forward_saved_for_backwards_strides = None + if fwd_output_strides is not None: + forward_saved_for_backwards_strides = fwd_output_strides[ + inner_meta.tensors_saved_for_backwards_slice + ] + + # saved activations can have different stride to eager if + # the compiler does layout optimization. We should restride the + # tensor passed in for compiling the backward graph using the + # saved tensor's stride. + for i in range(len(placeholder_list)): + ph_arg = placeholder_list[i] + if not isinstance(ph_arg, torch.Tensor): + continue + + if forward_saved_for_backwards_strides is None: + continue + + real_stride = None + # Per all_args calling convention + j = i - len(symint_outs_saved_for_bw) + if 0 <= j < len(forward_saved_for_backwards_strides): + real_stride = forward_saved_for_backwards_strides[j] + if real_stride is None: + continue + + # Comparing ph_arg.stride() with real_stride directly may + # cause dynamic dimensions in ph_arg being specialized to static + # value. Using the hints to avoid that. + if _get_symint_hints(ph_arg.stride()) != real_stride: + # Note that here we use the stride of the real tensor to + # restride a FakeTensor. This does not cause trouble + # for dynamic shape since this code path only get + # executed if layout optimization is enabled. And we + # disable layout optimization for dynamic shape right + # now. + # + # A solution that decide stride order based on real + # tensor's stride and then apply that stride order to + # the FakeTensor does not work smoothly since some + # tensor's layout is not 'dense'. E.g. mixnet_l has a + # tensor with size [8, 64, 112, 112] and strides + # (2408448, 1, 21504, 192). The solution mentioned will + # decide a stride of (802816, 1, 7168, 64) for this + # tensor which is wrong. 
+ placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride) + + compiled_bw_func = None + if len(symint_outs_saved_for_bw): + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + try: + compiled_bw_func = aot_config.bw_compiler( + bw_module, placeholder_list + ) + except Exception: + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) + # Compiled autograd will run the bw_module in the backward pass, + # so recompilation needs to happen anyway if the backward pass is ever + # called. + # + # We do the GraphModule recompilation here because the lazy + # recompilation will cause issues in the backward pass + # with compiled autograd. + # + # Do the _LazyGraphModule.force_recompile here rather than when + # bw_module is first generated by the partitioner, because bw_module.recompile + # may be called in some code path later and cause _LazyGraphModule.forward + # to become the lazy version again. One example is when dynamic shapes are enabled + # upfront: the bw_compiler will be called above, which can cause extra + # graph module recompilation on bw_module. + if torch._dynamo.compiled_autograd.compiled_autograd_enabled_count: + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(bw_module) + + saved_context = TracingContext.try_get() + + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + class CompiledFunction(torch.autograd.Function): + compiled_fw = compiled_fw_func + compiled_bw = compiled_bw_func + metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] + maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta + num_symints_saved_for_bw = _num_symints_saved_for_bw + _compiled_autograd_should_lift = False + _fakify_first_call = fakify_first_call + + @staticmethod + def _compiled_autograd_key(ctx): + return (ctx._autograd_function_id, *ctx.symints) + + @staticmethod + def forward(ctx, *deduped_flat_tensor_args): + args = deduped_flat_tensor_args + if backward_state_indices: + bw_state = args[backward_state_indices[0]] + assert isinstance(bw_state, BackwardState) + ctx._compiled_autograd_backward_state = bw_state + + marked_dirty_inps = [] + for i in fw_metadata.mutated_graph_handled_indices_seen_by_autograd: + arg = deduped_flat_tensor_args[i] + if not (arg.requires_grad and arg.is_leaf): # would error + ctx.mark_dirty(arg) + marked_dirty_inps.append(arg) + + if not CompiledFunction._fakify_first_call: + if CompiledFunction.metadata.is_rng_op_functionalized: + # Add the seed and offset to args + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() + args = (*args, seed, offset) + # There is a pretty complicated calling convention around what the compiled fw returns.
+ # The full list of outputs and their relative order is: + # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) + # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version + # of the original view, and not the synthetic base + + fw_outs = call_func_at_runtime_with_args( + CompiledFunction.compiled_fw, + args, + disable_amp=disable_amp, + ) + else: + nonlocal fakified_out + assert fakified_out is not None + CompiledFunction._fakify_first_call = False + fw_outs = fakified_out + fakified_out = None + + num_outputs = CompiledFunction.metadata.num_outputs + num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased + num_mutated_runtime_inps = ( + CompiledFunction.metadata.num_mutated_inp_runtime_indices + ) + num_tokens = len(CompiledFunction.metadata.tokens) + num_forward_returns = CompiledFunction.metadata.num_forward_returns + num_forward = CompiledFunction.metadata.num_forward + + # Partitioners must put symint arguments at the end separate from tensor arguments + tensors_saved_for_backwards = fw_outs[ + CompiledFunction.metadata.tensors_saved_for_backwards_slice + ] + assert all(isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards) + # See Note [Detaching saved tensors in AOTAutograd] + ctx.save_for_backward( + *( + x.detach() if x._is_view() else x + for x in tensors_saved_for_backwards + ) + ) + symint_outs = fw_outs[ + CompiledFunction.metadata.symints_saved_for_backwards_slice + ] + assert all( + isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) + for x in symint_outs + ), str([type(x) for x in symint_outs]) + ctx.symints = symint_outs + + raw_returns = fw_outs[0 : num_forward_returns + num_tokens] + + # Wrap all autograd.Function.forward() outputs that are aliases + # so that autograd.Function doesn't treat them as tensors + if num_mutated_runtime_inps > 0: + for i, idx in enumerate( + CompiledFunction.metadata.mutated_inp_runtime_indices + ): + # We could make this faster by only looping over inputs with metadata-only mutations + # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. + info = CompiledFunction.metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + raw_returns[i] = TensorAlias(raw_returns[i]) + + if config.debug_assert: + user_mutated_inputs_raw = raw_returns[0:num_mutated_runtime_inps] + mut_inp_infos = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + assert len(user_mutated_inputs_raw) == len(mut_inp_infos) + + if CompiledFunction.metadata.num_unsafe_view_outputs > 0: + for idx in CompiledFunction.metadata.unsafe_view_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + o = raw_returns[raw_return_idx] + raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view( + o, o.shape + ) + + if num_outputs_aliased > 0: + for idx in CompiledFunction.metadata.aliased_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + assert not any( + isinstance(x, TensorAlias) for x in intermediates_raw + ) + + # invariant: intermediate bases always require gradients, so we don't have to + # consider marking them as non-differentiable. 
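+            # (Editor's note, not part of the upstream file: ignoring effect tokens,
+            # raw_returns at this point is laid out as
+            #   [*mutated_inputs, *user_outputs, *intermediate_bases],
+            # which is why the slice below drops the trailing intermediate bases before
+            # deciding which returns to mark as non-differentiable.)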
+ raw_returns_not_including_intermediate_bases = raw_returns[ + : num_mutated_runtime_inps + num_outputs + ] + raw_returns_meta = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + CompiledFunction.metadata.output_info + + fw_outs_not_requiring_grad = [ + x + for (i, x) in enumerate(raw_returns_not_including_intermediate_bases) + if isinstance(x, torch.Tensor) and not raw_returns_meta[i].requires_grad + ] + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + + functionalized_rng_runtime_epilogue( + CompiledFunction.metadata, + fw_outs[num_forward_returns:num_forward], + return_new_outs=False, + ) + return tuple(raw_returns) + tuple(marked_dirty_inps) + + @staticmethod + def backward(ctx, *flat_args): + # Calling convention: we expect a grad_out passed to the backward: + # - for every output of the fw that does *not* alias an input or graph intermediate + # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations) + # - for every graph intermediate that we need to use to generate an output later. + # The other outputs in the autograd.Function.forward that do *not* show up in the backward include: + # - outputs that alias inputs or graph intermediates + # - updated inputs due to metadata-only mutations. + # We need to return them in the forward, but ensure that they all do not get gradients in the backward, + # and we filter them out here before passing the remaining grad_outputs into the compiled backward. + num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases + num_graph_handled_inputs = ( + CompiledFunction.metadata.num_mutated_graph_handled_indices_seen_by_autograd + ) + num_mutated_runtime_inps = ( + CompiledFunction.metadata.num_mutated_inp_runtime_indices + ) + expected_grad_outs = ( + CompiledFunction.metadata.num_outputs + + num_mutated_runtime_inps + + num_intermediate_bases + ) + deterministic = CompiledFunction.metadata.deterministic + global_deterministic = torch.are_deterministic_algorithms_enabled() + if deterministic is not None: + torch._check( + not (not deterministic and global_deterministic), + lambda: ( + "This compiled backward function is being run with " + "torch.use_deterministic_algorithms(True), " + "but it was previously generated during the forward function while " + "torch.use_deterministic_algorithms(False) was set." 
+ ), + ) + + if num_graph_handled_inputs > 0: + flat_args = flat_args[:-num_graph_handled_inputs] + assert len(flat_args) == expected_grad_outs + out_info = CompiledFunction.metadata.output_info + + inp_tangents, out_tangents, intermediate_base_tangents = ( + flat_args[0:num_mutated_runtime_inps], + flat_args[ + num_mutated_runtime_inps : num_mutated_runtime_inps + + CompiledFunction.metadata.num_outputs + ], + flat_args[ + num_mutated_runtime_inps + CompiledFunction.metadata.num_outputs : + ], + ) + # input_info contains info on *every* input, + # but in the backward() we are only given grad outputs for every mutated input. + # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad + input_info = CompiledFunction.metadata.input_info + inp_tangents_filtered = [ + x + for x, info_idx in zip( + inp_tangents, CompiledFunction.metadata.mutated_inp_runtime_indices + ) + if input_info[info_idx].mutates_data + and input_info[info_idx].requires_grad + ] + # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates + out_tangents_filtered = [ + x + for x, info in zip(out_tangents, out_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases always require gradients, and always participate in the backward graph. + flat_bw_args_with_grads = [ + *inp_tangents_filtered, + *out_tangents_filtered, + *intermediate_base_tangents, + ] + num_flat_bw_args_with_grads = len(flat_bw_args_with_grads) + + # sanity asserts + # metadata_only_inps = [ + # x for x, info_idx in zip(inp_tangents, mutated_inp_indices) + # if not input_info[info_idx].mutates_data + # ] + # aliased_outputs = [ + # x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias] + # assert all(x is None for x in metadata_only_inps) + # assert all(x is None for x in aliased_outputs) + + rng_args = [] + if CompiledFunction.metadata.is_rng_op_functionalized: + # Add the seed and offset to args + rng_args = CUDARngStateHelper.get_torch_state_as_tuple() + + all_args = [ + *ctx.symints, + *ctx.saved_tensors, + *flat_bw_args_with_grads, + *rng_args, + ] + del flat_bw_args_with_grads + + tangents_start_idx = ( + len(all_args) - num_flat_bw_args_with_grads - len(rng_args) + ) + tangents_end_idx = len(all_args) - len(rng_args) + + # Note: [AOTAutograd Backward Guards] + # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph. + # Doing so requires us to "guess" about some of the metadata of our grad_outputs. + # + # In particular: if an output to the forward is a plain tensor or a subclass, + # its corresponding grad_output in the backward **may or may not** be + # a plain tensor or a subclass. The main cases are: + # (1) If an output is a plain tensor, its grad_out will also be a plain tensor, + # *unless* the output is used in some subclass compute later in the forward graph, + # which will cause its grad_output to become a subclass + # (2) If an output is a subclass, its grad_out will also be a subclass, + # *unless* the output of the forward did not actually participate in the gradient computation, + # in which case autograd will insert a plain tensor of zeros for the grad_output. + # We could avoid this case with `torch.autograd.Function.set_materialize_grads`, + # although this is not turned on today in AOTAutograd and would require more work.
+ # + # Today, we make a guess on subclass-ness based on the above examples, + # and hard-error in the backward if we guessed wrong. + # + # In the future, we should add backward guards that would allow us to + # properly handle this case instead of erroring: we would need to retrace the backward graph, + # since we might produce an entirely different trace if our grad_outputs are subclass or not. + assert ( + len(CompiledFunction.metadata.output_types) + == num_flat_bw_args_with_grads + ) + grad_output_types = [ + type(x) for x in all_args[-num_flat_bw_args_with_grads:] + ] + # In general, we can add more asserts/guards here for when we partitioned + # with incorrect assumptions about the grad_outputs. + # Normalize FakeTensor -> torch.Tensor + # - during tracing our types are FakeTensor + # - at runtime in the backward our types are torch.Tensor... + # - unless we're running compiled backward, in which case they are also FakeTensor + grad_output_types_ = [ + torch.Tensor if x is FakeTensor else x for x in grad_output_types + ] + assert ( + grad_output_types_ == CompiledFunction.metadata.output_types + ), f"""\ +We incorrectly attempted to compile the backward with incorrect subclass metadata. +If you run into this error, please file an issue. +Expected grad_output types: {str(CompiledFunction.metadata.output_types)} +Got grad_output types: {str(grad_output_types)}""" + + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. + if CompiledFunction.maybe_subclass_metadata is not None: + # Get the number of tangents after unwrapping + len_tangents = len( + unwrap_tensor_subclasses( + all_args[tangents_start_idx:tangents_end_idx], + is_joint_structure=False, + ) + ) + all_args = unwrap_tensor_subclasses(all_args, is_joint_structure=False) + tangents_start_idx = len(all_args) - len_tangents - len(rng_args) + tangents_end_idx = tangents_start_idx + len_tangents + + # Make the tangents contiguous. 
Note that we must do this after subclass desugaring + # because inputs to inductor have to be contiguous + all_args = [ + t.contiguous() + if ( + (tangents_start_idx <= i < tangents_end_idx) + and (not t.is_contiguous()) + ) + else t + for i, t in enumerate(all_args) + ] + + def call_compiled_backward(): + if ctx._is_compiled_autograd_tracing(): + # For compiled autograd, run raw FX graph so that it can be inlined into the larger graph + symints = ctx._get_compiled_autograd_symints() + assert len(symints) == len(ctx.symints) + all_args[: len(symints)] = symints + if backward_state_indices: + assert ctx._compiled_autograd_backward_state.proxy is not None + all_args.append(ctx._compiled_autograd_backward_state) + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + out = normalize_as_list(bw_module(*all_args)) + out = functionalized_rng_runtime_epilogue( + CompiledFunction.metadata, out + ) + return tuple(out) + assert ( + not backward_state_indices + ), "BackwardState requires CompiledAutograd" + ctx.maybe_clear_saved_tensors() + if CompiledFunction.compiled_bw is None: + context = torch._C._DisableAutocast if disable_amp else nullcontext + with tracing(saved_context), context(), track_graph_compiling( + aot_config, "backward" + ): + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list + ) + + out = call_func_at_runtime_with_args( + CompiledFunction.compiled_bw, + all_args, + steal_args=True, + disable_amp=disable_amp, + ) + + out = functionalized_rng_runtime_epilogue( + CompiledFunction.metadata, out + ) + return tuple(out) + + if torch.is_grad_enabled() and any( + t.requires_grad for t in all_args if isinstance(t, torch.Tensor) + ): + # Ensure that the graph is connected, and error if double backward is performed. + # See comment for why once_differentiable is not sufficient: + # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 + class CompiledFunctionBackward(torch.autograd.Function): + # CompiledFunctionBackward is not yet supported in dynamo skipfiles + _compiled_autograd_should_lift = False + + @staticmethod + def forward(ctx, *unused_args): + outs = call_compiled_backward() + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. + if CompiledFunction.maybe_subclass_metadata is not None: + assert ( + CompiledFunction.maybe_subclass_metadata.grad_input_metas + is not None + ) + outs_wrapped = wrap_tensor_subclasses( + outs, + subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas, + ) + return outs_wrapped + return outs + + @staticmethod + def backward(ctx, *args): + raise RuntimeError( + "torch.compile with aot_autograd does not currently support double backward" + ) + + CompiledFunctionBackward._compiled_autograd_key = ( # type: ignore[method-assign] + CompiledFunction._compiled_autograd_key + ) + + # Pass args even though they're unused, so that the graph is built + out = CompiledFunctionBackward.apply(*all_args) + else: + out = call_compiled_backward() + + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. 
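+                # (Editor's note, not part of the upstream file: the block below mirrors what
+                # aot_dispatch_subclass_wrapper does on the forward path - the compiled backward
+                # produces plain dense tensors, and grad_input_metas describes how to re-wrap
+                # them into the tensor-subclass structure expected by the caller.)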
+ if CompiledFunction.maybe_subclass_metadata is not None: + assert ( + CompiledFunction.maybe_subclass_metadata.grad_input_metas + is not None + ) + outs_wrapped = wrap_tensor_subclasses( + out, + subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas, + ) + return outs_wrapped + return out + + compiled_function = create_runtime_wrapper( + CompiledFunction.apply, + runtime_metadata=fw_metadata, + indices_of_inps_to_detach=_indices_of_inps_to_detach, + trace_joint=True, + keep_input_mutations=aot_config.keep_inference_input_mutations, + disable_amp=disable_amp, + ) + + if not config.debug_assert: + return compiled_function + + flat_requires_grad = [ + a.requires_grad if isinstance(a, Tensor) else None for a in flat_args + ] + + @wraps(compiled_function) + def debug_compiled_function(*args): + # TODO: Check aliasing relationships + # TODO: Check strides for metadata mutation + # (NB: ideally, this logic is factored out of this function and + # you move these debug checks there) + + # Check requires grad. Bad case is when we compiled with + # requires_grad = False, but input requires_grad = True + # (vice versa is OK; we compute a gradient and then throw + # it away when it hits the input.) + for i, a in enumerate(args): + can_require_grad = flat_requires_grad[i] + if can_require_grad is None: + assert not isinstance(a, Tensor) + elif not can_require_grad: + assert not a.requires_grad, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would not require grad", + ) + + return compiled_function(*args) + + return debug_compiled_function diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01dfe5f031611bdf1c7a9737d1226a858ba27832 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py @@ -0,0 +1,135 @@ +""" +Contains utils for logging in AOTAutograd, including managing the names of the graphs under +compilation, capturing user-friendly tracebacks, and debug messages. +""" + +import collections +from contextlib import contextmanager +from typing import List, Tuple + +import torch +import torch.fx.traceback as fx_traceback + +# This is a list since looking forward, we can have this arbitrarily nested. +graph_being_compiled: List[str] = [] +# TODO: It would be nice to reset the numbering every time aot_id goes +# up, but this is annoying to do right now (because we don't know if +# an aot_id will come back from the dead), so right now this also happens +# to be a globally unique number too (at the cost of wobbling if you change +# how the graphs compile) +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_aot_compilation_context() -> Tuple[List[str], str, int]: + return list(graph_being_compiled), model_name, nth_graph + + +def get_aot_graph_name() -> str: + """ + Returns the name of the graph being compiled. 
+ """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}" + + +get_graph_being_compiled = get_aot_graph_name + + +@contextmanager +def track_graph_compiling(aot_config, graph_name): + global graph_being_compiled + # TODO: Don't shove the aot_id in here; set it in the context + graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"] + try: + yield + finally: + global nth_graph + nth_graph += 1 + graph_being_compiled = [] + + +# Set up hooks so that during backward the fx's stack_trace is properly set +callback_set = False + + +def setup_stacktrace_preservation_hooks(roots: List): + def iter_graph(roots): + if not roots: + return + seen = set() + q = collections.deque() # type: ignore[var-annotated] + for node in roots: + if node is not None and node not in seen: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _idx in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def get_callback(saved_stack_): + def callback(): + global callback_set + fx_traceback.set_stack_trace(saved_stack_) + callback_set = False + + return callback + + def get_prehook(stack_, seq_nr): + def prehook(grad_output): + global callback_set + + if not callback_set: + torch.autograd.variable.Variable._execution_engine.queue_callback( # type: ignore[attr-defined] + get_callback(fx_traceback.format_stack()) + ) + callback_set = True + + fx_traceback.set_stack_trace(stack_) + fx_traceback.set_grad_fn_seq_nr(seq_nr) + + return prehook + + def get_posthook(special_stack_, seq_nr): + def posthook(grad_input, grad_output): + fx_traceback.set_stack_trace(special_stack_) + fx_traceback.reset_grad_fn_seq_nr() + + return posthook + + for node in iter_graph(roots): + forward_node_stack = node.metadata.get("traceback_", []) + node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr())) + + special_stack = forward_node_stack.copy() + special_stack.append( + "Gradient addition node due to multiple use of tensor around:" + ) + node.register_hook(get_posthook(special_stack, node._sequence_nr())) + + +def describe_input(i, aot_config): + if i < aot_config.num_params_buffers: + return f"parameter/buffer {i}" + else: + return f"input {i - aot_config.num_params_buffers}" + + +def format_guard_bug_msg(aot_config, expected): + return ( + f"At compilation time, graph {aot_config.aot_id} was compiled under the " + f"assumption that {expected}, but at runtime this was not the case. " + "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch." + ) diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..84e8828661c87c739c4115090f062d61bf39f697 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -0,0 +1,1021 @@ +""" +This module defines runtime wrappers, which, based on previous analysis attempts to: +1. process the inputs and outputs +2. apply mutations +3. handle functionalized randomness +4. 
deduplicate inputs and consolidate views into their bases (see input_output_analysis) +""" + +import collections +import pprint +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.dlpack +from torch import Tensor +from torch._guards import DuplicateInputs, TracingContext +from torch._prims_common import CUDARngStateHelper +from torch.multiprocessing.reductions import StorageWeakRef +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata + +from .functional_utils import gen_alias_from_base +from .input_output_analysis import ( + compute_overlapping_inputs, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from .logging_utils import describe_input, format_guard_bug_msg +from .schemas import ( + AOTConfig, + InputAliasInfo, + OutputType, + SubclassCreationMeta, + TensorAlias, + ViewAndMutationMeta, +) +from .subclass_utils import ( + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses, +) + +from .utils import ( + call_func_at_runtime_with_args, + make_boxed_func, + partial_flatten_asdict, + strict_zip, +) + + +zip = strict_zip + + +# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic +# that needs to run after the compiled function. +# +# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime +# epilogue for a forward-only inference graph, or for an autograd.Function.apply function. +# This is because there are some minor differences in how we treat these cases at runtime: +# - resize_() is currently handled in the inference case, but not fully handled in the autograd case. +# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs +def create_runtime_wrapper( + compiled_fn, + *, + runtime_metadata: ViewAndMutationMeta, + indices_of_inps_to_detach: List[int], + trace_joint: bool, + keep_input_mutations: bool, + disable_amp: bool, +): + num_tokens = len(runtime_metadata.tokens) + + if not hasattr(compiled_fn, "_boxed_call"): + compiled_fn = make_boxed_func(compiled_fn) + + def runtime_wrapper(*args): + # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + args = (*[torch.tensor([])] * num_tokens, *args) + + if trace_joint: + args_ = list(args) + # See Note [Detaching inputs that never need gradients] + for idx in indices_of_inps_to_detach: + if isinstance(args_[idx], torch.Tensor): + args_[idx] = args_[idx].detach() + with torch.autograd._force_original_view_tracking(True): + all_outs = call_func_at_runtime_with_args( + compiled_fn, + args_, + disable_amp=disable_amp, + ) + else: + # When we have an inference graph, we run with torch.no_grad. 
+ # It's possible to get an inference graph with inputs that require grad, + # in which case we want to make sure autograd is disabled + # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain about) + if torch.is_grad_enabled(): + with torch.no_grad(): + all_outs = call_func_at_runtime_with_args( + compiled_fn, + args, + disable_amp=disable_amp, + ) + else: + all_outs = call_func_at_runtime_with_args( + compiled_fn, + args, + disable_amp=disable_amp, + ) + + num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices + num_intermediate_bases = runtime_metadata.num_intermediate_bases + + if keep_input_mutations and trace_joint: + num_input_mutations_handled_by_autograd = ( + runtime_metadata.num_mutated_graph_handled_indices_seen_by_autograd + ) + # autograd.Function requires us to return the mutated inputs as extra outputs to the autograd.Function.forward + if num_input_mutations_handled_by_autograd > 0: + all_outs = all_outs[:-num_input_mutations_handled_by_autograd] + + assert ( + len(all_outs) + == num_mutated_runtime_inps + + runtime_metadata.num_outputs + + num_intermediate_bases + + num_tokens + ) + + # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + all_outs = all_outs[num_tokens:] + + # Step 3: After running the compiled fw, apply updates to mutated inputs + num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices + if num_mutations_to_apply > 0: + updated_inputs = all_outs[:num_mutations_to_apply] + fw_outs = all_outs[num_mutations_to_apply:] + + for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): + meta = runtime_metadata.input_info[inpt_idx] + if not meta.mutates_data and not meta.mutates_metadata: + continue + original_inpt = args[inpt_idx] + updated_inpt = updated_inputs[i] + if meta.mutates_storage_metadata: + # mutates_storage_metadata means our input saw an x.set_(y) call. + # What if x **also** saw a data and/or a metadata mutation? + # (1) If the [meta]data mutation occurred after the set_(), + # then there is no need to copy_() the data. + # When we perform x.set_(x_updated), we are guaranteed that + # x_updated already has the final version of the data/metadata + # (2) If a data mutation occurred before the set_(), + # this case seems very difficult to support. + # TODO: discuss on the PR and decide if we want to try to + # either support it, or detect and ban it. + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + with torch.no_grad(): + original_inpt.set_(updated_inpt) + continue + if meta.mutates_metadata and not meta.mutates_data: + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to mutate the metadata of the input + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + if meta.mutates_data and meta.mutates_metadata: + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + assert meta.mutates_data + if meta.is_leaf and original_inpt.requires_grad: + # We can hit this situation in this case: + # def f(x): + # x.detach().mul_(2) + # return x + 1 + # AOTAutograd will see a mutation in the above case, and try to + # apply a copy_() here, in the epilogue.
+ # But if x required gradients, and is a leaf, then autograd + # will yell at us for trying to mutate it. + # However, it's only possible to end up in this scenario (like the above) + # if all of the mutations to the leaf input were non-autograd-tracking mutations + # (aka mutations under no_grad(), or on detached views). + # In that case, we fully want to hide the mutation from autograd, so detaching is ok. + original_inpt.detach().copy_(updated_inpt) + else: + original_inpt.copy_(updated_inpt) + else: + fw_outs = all_outs + + # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of + # compiling them. + if runtime_metadata.num_outputs_aliased > 0: + # The compiled forward also returned intermediate bases. We don't want to return them to the user. + if runtime_metadata.num_intermediate_bases > 0: + fw_outs_no_intermediate_bases = fw_outs[ + : -runtime_metadata.num_intermediate_bases + ] + intermediate_bases = fw_outs[-runtime_metadata.num_intermediate_bases :] + else: + fw_outs_no_intermediate_bases = fw_outs + intermediate_bases = [] + + assert len(fw_outs_no_intermediate_bases) == len( + runtime_metadata.output_info + ) + fw_outs_including_aliases = [] + for i, (o, info) in enumerate( + zip(fw_outs_no_intermediate_bases, runtime_metadata.output_info) + ): + if info.output_type in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ]: + fw_outs_including_aliases.append(o) + continue + if trace_joint: + assert isinstance(o, TensorAlias) + o_ = o.alias + else: + o_ = o + + o_grad = runtime_metadata.output_info[i].requires_grad + if info.output_type == OutputType.alias_of_input: + aliased_base_tensor = args[info.base_idx] # type: ignore[index] + regenerated_out = gen_alias_from_base( + aliased_base_tensor, o_, o_grad + ) + fw_outs_including_aliases.append(regenerated_out) + continue + elif info.output_type == OutputType.is_input: + aliased_base_tensor = args[info.base_idx] # type: ignore[index] + regenerated_out = aliased_base_tensor + fw_outs_including_aliases.append(regenerated_out) + continue + elif info.output_type == OutputType.alias_of_intermediate: + base_tensor_list = intermediate_bases + elif ( + info.output_type == OutputType.alias_of_intermediate_save_as_output + ): + base_tensor_list = intermediate_bases + else: + assert ( + info.output_type + == OutputType.alias_of_intermediate_base_is_user_output + ) + base_tensor_list = fw_outs_no_intermediate_bases + aliased_base_tensor = base_tensor_list[info.base_idx] + # TODO: handle the custom autograd function case here. + # We need a way to check whether a tensor came from a custom autograd fn from python, + # AND a way to replay that custom view fn. 
+ regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad) + fw_outs_including_aliases.append(regenerated_out) + ret_outs = fw_outs_including_aliases + else: + ret_outs = fw_outs + + if runtime_metadata.dynamic_outputs: + for t, o in zip(ret_outs, runtime_metadata.output_info): + if o.dynamic_dims is None: + continue + if hasattr(t, "_dynamo_weak_dynamic_indices"): + t._dynamo_weak_dynamic_indices |= o.dynamic_dims + else: + t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy() + if runtime_metadata.grad_enabled_mutation is not None: + torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation) + return ret_outs + + return runtime_wrapper + + +# Calling convention: If we are running functionalized RNG, then outs consists +# of (user_outs, rng_offset) +def functionalized_rng_runtime_epilogue( + metadata: ViewAndMutationMeta, outs, return_new_outs=True +): + if metadata.is_rng_op_functionalized: + assert metadata.num_outputs_rng_offset == 1 + new_rng_offset = outs[-1] + CUDARngStateHelper.set_new_offset(new_rng_offset) + if return_new_outs: + user_outs = outs[:-1] + return user_outs + else: + return None + return outs + + +# This wrapper handles the AOTDispatch runtime logic for tensor subclasses. +# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor, +# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs). +# This function handles the wrapping and unwrapping of tensor subclasses at runtime. +def aot_dispatch_subclass_wrapper( + runtime_fn: Callable, + *, + subclass_metas: List[Union[int, SubclassCreationMeta]], + num_fw_outs_saved_for_bw: Optional[int], +) -> Callable: + def inner_fn(args): + unwrapped_args = unwrap_tensor_subclasses(args, is_joint_structure=False) + # expectation: runtime_fn is a boxed fn + unwrapped_outs = runtime_fn(unwrapped_args) + wrapped_outs = wrap_tensor_subclasses( + unwrapped_outs, + subclass_metas=subclass_metas, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + is_runtime=True, + ) + return wrapped_outs + + # box it + inner_fn._boxed_call = True # type: ignore[attr-defined] + return inner_fn + + +# MOTIVATION: +# +# When tracing functions for future execution, one must be careful not to pass +# in the same input tensor multiple times (e.g., f(x, x), as this can result +# in graphs that are ONLY valid if you later pass a new tensor in exactly the +# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct +# tensors that alias each other is a different situation that is covered by +# aot_dispatch_deduplicated_autograd). Here are two examples: +# +# (1) Suppose you have a function: +# +# def f(x, y): +# return x + y +# +# If you make_fx(f)(x, x), you will trace out: +# +# def f(x, y): +# return y + y +# +# Oops! +# +# (2) For most tensors x and y, you can compute f's gradient with respect to +# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However, +# if x is y, you will trace out a program that gets incorrect gradients: +# +# >>> x = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + x, (x, x)) +# (tensor([2.]), tensor([2.])) +# +# In other words, the gradient is double-counted. Deduplicating the arguments +# gives you an appropriate gradient: +# +# >>> y = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + y, (x, y)) +# (tensor([1.]), tensor([1.])) +# +# HOW TO DEDUPLICATE: +# +# There are a few strategies, in order of preference: +# +# 1. 
For every duplicate argument to the function, detach it into +# a separate leaf tensor, so that it is no longer duplicated. +# +# PRO: The resulting compiled graph works for any configuration +# of duplicated arguments. +# +# CON: It does not (naively) work if you mutate the metadata of inputs: +# +# def f(x, y): +# x.transpose_(0, 1) +# y.transpose_(0, 2) +# +# x = torch.randn(2, 3, 4) +# f(x, x) +# +# The ordering of the transposes inside f dictates whether or not +# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute +# what metadata mutations should get applied to each input; you need to +# assume they aren't duplicates (what we do today) or preserve +# the original metadata mutations exactly in order, so that they work +# for any duplicate configuration. +# +# CON: It does not (naively) work if you mutate the data of inputs. +# In particular, leaf tensors that require grad cannot be mutated, +# this makes it impossible to differentiate with respect to the original +# base. +# +# 2. For every duplicate argument to the function, remove it, so it is +# no longer part of the "true" signature: +# +# PRO: Implemented naively, it still works for metadata/data mutation. +# +# CON: The resulting compiled graph is duplicate-specialized: it only +# works if future calls duplicate arguments in exactly the same way. +# Horribly, Dynamo doesn't guard on this at the moment. But even if +# it did, you could still end up recompiling a bunch of each duplicate. +# +# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if +# Dynamo's guards are not enough. In practice, this seems to cover +# everything. +# +def aot_wrapper_dedupe( + flat_fn, + flat_args: List[Tensor], + aot_config: AOTConfig, + *, + compiler_fn, + fw_metadata, +): + # Use information about whether or not flat_fn mutates its arguments + # or not to handle dupe args + + # Strategy 1: For any input that is not mutated, we can leafify it if we + # need to remove a duplicate. + leaf_flat_args = [] + args_set = set() + ok = True + + for i, a in enumerate(flat_args): + if not isinstance(a, torch.Tensor): + leaf_flat_args.append(a) + elif a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) + elif ( + not fw_metadata.input_info[i].mutates_data + and not fw_metadata.input_info[i].mutates_metadata + ): + leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) + else: + ok = False + break + + if ok: + return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata) + + if requires_subclass_dispatch(leaf_flat_args, fw_metadata): + raise RuntimeError( + """\ +Encountered duplicate inputs that are mutated in the graph, but at least one input/output +to the graph is a tensor subclass. This is not supported today. You can try to +remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + # export path: ban duplicate inputs for now, add later if requested. + if aot_config.is_export: + raise RuntimeError( + f"""\ +Encountered duplicated inputs that are mutated in the graph you are trying to export. +This functionality is currently not supported. If needed, please file a github issue. + +fw_metadata={str(fw_metadata)} + """ + ) + + # Strategy 2: Duplicate specialize. 
+ # + # In Haskell types, suppose you have: + # + # add_dupe_args :: DedupedArgs -> Args + # remove_dupe_args :: Args -> DedupedArgs + # + # compiler_fn + # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R) + # deped_compiler_fn + # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R) + # + # Then the code below can be written in point-free style as: + # + # deduped_compiler_fn f a c = + # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args + # + # Suppose you have: + # + # [a, b, a, c] + # + # We want: + # + # remove_dupe_args([a, b, a, c]) == [a, b, c] + # add_dupe_args([a, b, c]) == [a, b, a, c] + # + # This is done via (respectively): + # + # seen_args = {a: 0, b: 1, c: 2} + # enumerate(add_dupe_map) = [ # how to get args from the deduped list + # (0, 0), + # (1, 1), + # (2, 0), + # (3, 2), + # ] + # keep_arg_mask = [True, True, False, True] + + seen_args: Dict[Tensor, int] = {} + keep_arg_mask = [] + # Implicitly map duped arg position (list index) to de-duped arg position + add_dupe_map: List[int] = [] + duped_arg_len = len(flat_args) + + j = 0 # index into deduped_flat_args + for t in flat_args: + if isinstance(t, torch.Tensor): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map.append(seen_args[t]) + continue + seen_args[t] = j + + keep_arg_mask.append(True) + add_dupe_map.append(j) + j += 1 + assert ( + len(add_dupe_map) == duped_arg_len + ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + + # NB: Hot path, avoid set lookups here + # TODO: Can avoid the zip here too, probably + def remove_dupe_args(args): + return [t for t, keep in zip(args, keep_arg_mask) if keep] + + def add_dupe_args(args): + return [args[add_dupe_map[i]] for i in range(duped_arg_len)] + + deduped_flat_args = remove_dupe_args(flat_args) + + # Update our input metadata to remove duped input metadata. + updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map) + + if ( + tracing_context := TracingContext.try_get() + and aot_config.aot_autograd_arg_pos_to_source + ): + # TODO(voz): This structure is 1:1, we could consider an alternate structure like + # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there, + # which feels like needless complexity for a tiny bit of efficiency at this point. 
+ for dupe_arg_pos, (kept_pos, keep_arg) in enumerate( + zip(add_dupe_map, keep_arg_mask) + ): + if not keep_arg: + dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[ + dupe_arg_pos + ] + kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos] + tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined] + DuplicateInputs(kept_arg_source, dupe_arg_source) + ) + + @wraps(flat_fn) + def wrapped_flat_fn(*args): + return flat_fn(*add_dupe_args(args)) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*deduped_flat_args) + assert ( + ref_fw_metadata == updated_fw_metadata + ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + + compiled_fn = compiler_fn( + wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata + ) + + if not hasattr(compiled_fn, "_boxed_call"): + compiled_fn = make_boxed_func(compiled_fn) + + @wraps(compiled_fn) + def wrapped_compiled_fn(args): + deduped_args = remove_dupe_args(args) + args.clear() + return compiled_fn(deduped_args) + + wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + # This can be uncommented when we properly guard for duplicates, + # but right now we must not do it. + # if not config.debug_assert: + # return wrapped_compiled_fn + + @wraps(wrapped_compiled_fn) + def debugged_compiled_fn(args): + # Test that the computed remove/add arg functions are an inverse + new_args = add_dupe_args(remove_dupe_args(args)) + seen: Dict[Any, None] = {} + for i, (x, y) in enumerate(zip(new_args, args)): + seen[y] = None + assert x is y, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would be a duplicate of " + f"{describe_input(add_dupe_map[i], aot_config)}", + ) + # This is only an error if there is metadata mutation on both of + # the duped arguments; in this case, we need to know what order + # the metadata mutation applies in. You'll get the correct result + # otherwise, because a graph that assumes distinct inputs works if + # you dupe the inputs (the gradient contributions from each input + # will get summed up appropriately.) + # + # TODO: work out how to setup this assert correctly + """ + assert len(seen) == unique_args, format_guard_bug_msg(aot_config, + f"there would be {unique_args} distinct arguments" + ) + """ + return wrapped_compiled_fn(args) + + debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + return debugged_compiled_fn + + +# This layer handles the situation where you have two inputs that alias each other, +# and one of the inputs is mutated. +# We need to take special care to ensure that the mutation is applied to the other aliases in the graph. +# +# pre-condition: aot_wrapper_dedup has already run. +# (This function will in theory work if there are duplicate args. +# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs +# would cause us to hit that path more frequently). +def aot_wrapper_synthetic_base( + flat_fn, + flat_args: List[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + # Currently, the only reason we need to plumb this bool is because + # the synthetic base code prohibits more cases in the autograd case than the inference case. 
+ needs_autograd: bool, + compiler_fn, +): + is_inference = not needs_autograd + flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + flat_args, + fw_metadata.input_info, + is_inference=is_inference, + ) + # Happy path: we don't need synthetic bases + if synthetic_base_info is None: + return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata) + + # export path: ban synthetic bases for now, add later if requested. + if requires_subclass_dispatch(flat_args, fw_metadata): + raise RuntimeError( + """\ +Encountered aliased inputs that are mutated in the graph, but at least one input/output +to the graph is a tensor subclass. This is not supported today. You can try to +remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + if aot_config.is_export: + raise RuntimeError( + f"""\ +Encountered aliased inputs that are mutated in the graph you are trying to export. +This functionality is currently not supported. If needed, please file a github issue. + +synthetic_base_info={str(synthetic_base_info)} + +fw_metadata={str(fw_metadata)} + """ + ) + + assert len(fw_metadata.input_info) == len(synthetic_base_info) + + # Update our forward metadata to take synthetic bases into account + ( + fw_metadata_updated, + aliased_arg_idx_with_metadata_mutations, + ) = create_synthetic_base_metadata( + fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases + ) + + num_aliased_args_with_metadata_mutations = len( + aliased_arg_idx_with_metadata_mutations + ) + + def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]: + f_args_inner = [] + for inner_idx_or_tuple in synthetic_base_info: + if isinstance(inner_idx_or_tuple, int): + f_args_inner.append(primals[inner_idx_or_tuple]) + else: + inner_base_idx, view_tensor = inner_idx_or_tuple + base = primals[inner_base_idx] + view_arg = gen_alias_from_base( + base, view_tensor, view_tensor.requires_grad + ) + f_args_inner.append(view_arg) + return f_args_inner + + @wraps(flat_fn) + def wrapped_flat_fn(*args): + unpacked_args = _unpack_synthetic_bases(args) + # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases) + # is to relieve the downstream logic from having to reason about mutations on inputs that alias + # each other, by replacing aliased inputs with a synthetic base. + # One area where this breaks down a bit however is if one of those aliased inputs + # experienced a metadata mutation. + # We are now obligated to reapply the metadata mutation directly to the user's input; + # it isn't enough to apply mutations back to the synthetic base in the downstream logic. + # + # The way we handle this is by pretending that those aliased inputs that experience metadata mutations + # are additional outputs in the user's forward function. + # The downstream logic will just treat these as "user outputs that alias inputs". + # However, we will manually grab them at runtime here, use them to reapply the metadata mutation + # to the user inputs, and not return them to the user. 
+ aliased_args_with_metadata_mutations = [ + x + for i, x in enumerate(unpacked_args) + if i in aliased_arg_idx_with_metadata_mutations + ] + if len(aliased_args_with_metadata_mutations) > 0: + return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations + else: + return flat_fn(*unpacked_args) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*flat_args_with_synthetic_bases) + assert ref_fw_metadata == fw_metadata_updated, ( + f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, " + f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}" + ) + + compiled_fn = compiler_fn( + wrapped_flat_fn, + flat_args_with_synthetic_bases, + aot_config, + fw_metadata=fw_metadata_updated, + ) + + if not hasattr(compiled_fn, "_boxed_call"): + compiled_fn = make_boxed_func(compiled_fn) + + @wraps(compiled_fn) + def wrapped_compiled_fn(args): + args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + args, fw_metadata.input_info, is_inference=is_inference + ) + assert synthetic_base_info is not None + aliased_args_w_metadata_mutations = [ + args[i] for i in aliased_arg_idx_with_metadata_mutations + ] + args.clear() + outs = compiled_fn(args_with_synthetic_bases) + if num_aliased_args_with_metadata_mutations > 0: + # This code does not handle **all** input metadata mutations. + # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases + # (which only happens if at least one aliased input experienced a data mutation). + # e.g: + # def f(a, b): + # a.mul_(2) + # b.t_(1, 0) + # f(x.view(2, 2), x.view(2, 2)) + mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:] + user_outs = outs[:-num_aliased_args_with_metadata_mutations] + for inp, mutated_inp in zip( + aliased_args_w_metadata_mutations, mutated_metadata_inps + ): + inp.as_strided_( + mutated_inp.size(), + mutated_inp.stride(), + mutated_inp.storage_offset(), + ) + return user_outs + return outs + + return wrapped_compiled_fn + + +# Note [Handling mutations on an input that aliases other inputs] +# The easiest example to show-case this edge case is here: +# +# def f(a, b): +# a.mul_(2) +# out = a + b +# return out +# b = torch.ones(...) +# a = b.view(-1) +# f(a, b) +# +# In this situation, if a and b happened to be aliased, we need to trace something different! +# Suppose we had b = a.view(-1) +# (In this case, that means that `a._base is b`) +# +# We need to ensure that the aliasing relationship between a and b is preserved. +# We do that detecting the specific situation above (mutate an input that aliases another input), +# and when we do that, we create a synthetic base argument. Then inside of the traced forward, +# we regenerate a and b off of that base. +# The complete example of the transformed function looks like this: +# +# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views +# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph +# def traced_forward(base): +# a = base.as_strided(...) +# b = base.as_strided(...) +# a_updated = a.mul(2) +# base_updated = torch.as_strided_scatter(base, a_updated, ...) +# b_updated = base_updated.as_strided(...) 
+# out = a_updated + b_updated +# return a_updated, out +# +# def compiled_fn(a, b): +# // we detect that a is the "differentiable base" here +# base = a +# // In other situations, we might do either: +# // (1) a and b are both views off of some larger differentiable base +# // assert a._base is b._base and a._base is not None +# // base = a._base +# // (2) a and b both don't require gradients. Create a base from the storage +# // assert a._base is None and b._base is None +# // base = torch.Tensor(a.storage()) +# a_updated, out = traced_forward(base) +# a.copy_(a_updated) +# return out +# +# This function: +# (1) Merges input views into a synthetic base argument, when any of those input views are mutated +# (2) Returns metadata telling the autograd.Function how to modify their arguments properly, +# to respect the new calling convention. +# +# The calling convention is as follows. +# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base. +# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN], +# Where the ordering of the bases is determined from the ordering of the original view args. +# baseA will come before baseB if the earliest original argument coming from baseA +# showed up earlier in the argument list than the earliest original argument coming from baseB. +# +# Example, given some tensors a, b, c, d +# call site: +# f(a, c.view(-1), b.view(-1), b, c, d) +# Modified argument list: +# c_base comes first because the first c view came earlier in arg list than the first b view +# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases +# b_base = torch.Tensor(b.storage()) +# c_base = torch.Tensor(c.storage()) +# f(c_base, b_base, a, d) +def merge_view_inputs( + fwd_inputs: List[Any], + mutated_input_info: List[InputAliasInfo], + *, + # The autograd case currently has more restrictions than the inference case. + is_inference: bool, +) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]: + def _are_differentiable_views(view1, view2): + if view1 is view2: + return True + if view1._base is None and view2._base is None: + return False + if view1._base is view2._base or view1._base is view2 or view1 is view2._base: + return True + return False + + def _same_dtype_views(view1, view2): + if view1.dtype != view2.dtype: + return False + if view1._base is not None and view1.dtype != view1._base.dtype: + return False + if view2._base is not None and view2.dtype != view2._base.dtype: + return False + return True + + assert len(fwd_inputs) == len(mutated_input_info) + storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list) + base_args = [] + other_args = [] + for i, inpt in enumerate(fwd_inputs): + if isinstance(inpt, Tensor): + storage_ref = StorageWeakRef(inpt.untyped_storage()) + storage_ref_to_idx[storage_ref].append(i) + else: + other_args.append(inpt) + # Note [Synthetic Base Info Metadata] + # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. 
+ # It's either: + # - another int (corresponding to the index in the argument list of the element from the outer calling convention) + # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx]) + # idx corresponds to which synthetic base from the outer calling context to view + inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {} + for aliased_input_indices in storage_ref_to_idx.values(): + if len(aliased_input_indices) <= 1 or not any( + # We only care about mutations that affect all aliases, + # so metadata mutations on an input doesn't require us to do synthetic base handling. + mutated_input_info[inpt_idx].mutates_data + for inpt_idx in aliased_input_indices + ): + for curr_idx in aliased_input_indices: + other_args.append(fwd_inputs[curr_idx]) + continue + + # Here, we attempt to do a more complicated check to detect false aliasing + # (e.g. if all the tensors have the same storage, but don't actually overlap) + # In theory, we could have a large group of tensors that all share storages, where only *some* of them + # have overlapping memory. + # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair + # of tensors in the current group that shares a storage is non-overlapping. + aliased_input_indices_no_false_sharing = compute_overlapping_inputs( + fwd_inputs, aliased_input_indices + ) + if len(aliased_input_indices_no_false_sharing) <= 1: + for curr_idx in aliased_input_indices: + other_args.append(fwd_inputs[curr_idx]) + continue + + # We detected an input that was mutated, AND aliases with another input. + # we need to replace this set of aliased inputs with a single synthetic base. + # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases + # and error out. We can fix them later. + # These checks are transitive, so we don't need to check every pair. + for idx1, idx2 in zip( + aliased_input_indices, aliased_input_indices[1:], strict=False + ): + view1 = fwd_inputs[idx1] + view2 = fwd_inputs[idx2] + # The "inputs that are aliased but have different differentiable bases" case + # is more complicated and hopefully pretty rare. Not currently handled. + if not is_inference: + assert _are_differentiable_views( + view1, view2 + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + # Regenerating views when reinterpreting complex / real tensors seems non-trivial, + # not handling for now + assert _same_dtype_views( + view1, view2 + ), "aot_autograd() does not yet handle input mutations on views with different dtypes." + non_none_bases = [ + fwd_inputs[i]._base + for i in aliased_input_indices + if fwd_inputs[i]._base is not None + ] + aliases_with_none_bases = [ + fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None + ] + if len(non_none_bases) == 0: + # Case where none of the aliases have a ._base + # we generate a synthetic base without gradients, and generate views off of it + # We hit this case when we have input tensors to the graph that share a storage, + # but do not have a ._base field. + # Wondering when we hit this case? + # The _base field simply says that autograd knows about the aliasing relationship, + # but sometimes we create tensors which are aliased out of the same storage but guaranteed + # to be disjoint. 
In these cases, we will skip setting up the _base relationship + # for performance reasons (because the fact that the tensors share the same storage + # is unobservable unless you (1) do naughty things with resize_/as_strided + # or (2) look at the storage--as we are doing here.) + # One particular example of this is optimizer steps on the LSTM module: + # LSTM parameters are packed into a contiguous storage for efficiency reasons when + # calling cuDNN kernels, so when these parameters get passed to the optimizer we will + # find they share the same storage, but do not have _base set since they are all disjoint. + # + # NOTE: There is one case where this is unsafe: + # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily + # the same shape as the "actual" base that the tensor came from. + # For the most part this is fine, because we always use as_strided() + # to generate the original aliased inputs again. + # If we were to use view-replay though, this could cause the aliased views + # to have incorrect sizes. + example_idx = aliased_input_indices[0] + example_alias = fwd_inputs[example_idx] + # Note that this function is re-used at both trace time and runtime. + # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor. + synthetic_base = torch.empty( + (0,), dtype=example_alias.dtype, device=example_alias.device + ) + # We don't actually have a convenient way of going from storage -> tensor, + # So using set_() here (we suffer some minor overhead, but this case is rare). + synthetic_base.set_(example_alias.untyped_storage()) + else: + # Case where all of the aliases require gradients, and have the same _base. + synthetic_base = non_none_bases[0] + for other_base in non_none_bases[1:]: + assert ( + other_base is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + for alias in aliases_with_none_bases: + assert ( + alias is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + base_args.append(synthetic_base) + for curr_view_idx in aliased_input_indices: + curr_view = fwd_inputs[curr_view_idx] + base_idx = len(base_args) - 1 + # We store just enough info here so that we can regenerate the view later. + # Regeneration: curr_view._view_func(args[base_idx]) + inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view) + if len(base_args) == 0: + assert len(other_args) == len(fwd_inputs) + # If no synthetic bases are necessary, just return the original inputs. + return fwd_inputs, None + else: + # Otherwise, return: + # (1) The new args according to the updated calling convention: (synthetic_bases, other_args) + # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. + # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. 
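+        # A rough sketch (hypothetical call, not taken from a test): for
+        #   f(a, b.view(-1), b.view(-1))
+        # where at least one of the two views of b is data-mutated, we'd end up with
+        # base_args = [b] (the shared base of the two views) and other_args = [a],
+        # so the returned args are (b, a) and the post-processed meta looks like
+        #   [1, (0, first_view), (0, second_view)]
+        # i.e. original arg 0 (a) now lives at new index 1, while original args 1 and 2
+        # are regenerated as views of new arg 0 (the synthetic base).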
+ args_to_functionalization = base_args + other_args + arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)} + for i, other_arg in enumerate(other_args): + new_idx = len(base_args) + i + old_idx = arg_to_old_idx_map[other_arg] + inner_calling_convention_meta[old_idx] = new_idx + # post process into a list + post_processed_calling_convention_meta: List[ + Union[int, Tuple[int, torch.Tensor]] + ] = [-1 for _ in range(len(inner_calling_convention_meta))] + for k, v in inner_calling_convention_meta.items(): + post_processed_calling_convention_meta[k] = v + # Quick assert: every argument in the inner calling convention should be accounted for. + for x in post_processed_calling_convention_meta: + assert x != -1 + return args_to_functionalization, post_processed_calling_convention_meta diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..25a435f1fa6da6685d2909c9e93ecf2c392a7351 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py @@ -0,0 +1,696 @@ +""" +The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes +input/output types, metadata, config, function signatures etc. +""" + +import collections +import functools +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch._guards import Source +from torch._subclasses import FakeTensor +from torch._subclasses.fake_tensor import is_fake + +from .. import config + +from .functional_utils import _check_if_mutation_can_be_in_graph +from .utils import strict_zip + +zip = strict_zip + +OutputType = Enum( + "OutputType", + ( + # output is not an alias + "non_alias", + # output aliases an input + "alias_of_input", + # output **is** an input tensor + "is_input", + # output has a ._base tensor, which is a graph intermediate. + # We need to return its ._base as a graph output, + # so its requires_grad info is populated correctly. + # Instructs the runtime code to regenerate the current output + # from a base tensor, graph_intermediates[base_idx] + "alias_of_intermediate_save_as_output", + # Same as above; but we don't need to explicitly add its ._base + # as a graph output, because it already **is** a graph output. + "alias_of_intermediate", + # Same as above; but the output's ._base is **already** a user output. + # Instructs the runtime code to regenerate the current output from + # a base tensor, user_outputs[base_idx] + "alias_of_intermediate_base_is_user_output", + # See Note [Intermediate Bases Optimization] + "unsafe_view_alias", + # output is an alias, but has a custom autograd.Function backward. + # In this case, we don't want to do view-replay, since we won't be able to replay the custom function. + # Instead, we'll treat this output "normally", and trace its backward into the graph. + "custom_function_view", + ), +) + + +# This class stores info about every user output. 
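+# A rough sketch (hypothetical function, not taken from this file): for
+#   def f(x): return x, x.view(-1), x + 1
+# the three outputs would be recorded as OutputType.is_input, OutputType.alias_of_input
+# and OutputType.non_alias respectively, with base_idx pointing back at input 0 for the
+# first two and left as None for the last.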
+@dataclass(frozen=True) +class OutputAliasInfo: + # Tells us if this output is: + # (1) a regular (non-aliased) output + # (2) an alias of a forward input + # (3) **is** a forward input (special case of "alias_of_input") + # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward) + # (5) an alias of an intermediate, that explicitly requires returning the intermediate + # as a graph output + # (6) an alias of an intermediate, where that intermediate is also a user output + output_type: OutputType + # The raw type of the output (torch.Tensor, SymInt, etc) + raw_type: type + # If (1) above, then + # - base_idx is None + # If (2) or (3) above, then + # - Tells us that the base of this alias is user_fwd_input[base_idx] + # (This is an index into the inputs *before* we make synthetic bases) + # If (4) or (5) above, then + # - Tells us that the base of this alias is output_graph_intermediates[base_idx] + # here, this refers to the index of the *direct* traced + # If (6) above, then: + # - Tells us that the base of this alias is output_user_fwds[base_idx] + # here, this refers to the index of the *direct* traced + base_idx: Optional[int] + # If it is a Tensor, what the dynamic dims are (otherwise is None) + dynamic_dims: Optional[Set[int]] + # requires_grad + requires_grad: bool + + +class MutationType(Enum): + NOT_MUTATED = 1 + MUTATED_IN_GRAPH = 2 + MUTATED_OUT_GRAPH = 3 + + +# This class tells us info about user inputs. +@dataclass(frozen=True) +class InputAliasInfo: + is_leaf: bool + mutates_data: bool + mutates_metadata: bool + mutations_hidden_from_autograd: bool + mutations_under_no_grad_or_inference_mode: bool + mutates_storage_metadata: bool + requires_grad: bool + keep_input_mutations: bool + + def __post_init__(self): + if self.mutates_storage_metadata: + # For convenience, we guarantee that this is always true. + # In practice, If we call .set_(), then at runtime there is no need + # to additionally fix up the tensor metadata, since our runtime + # call to inp.set_(updated_inp) will already have the right metadata + assert self.mutates_metadata + + @functools.cached_property + def mutation_type(self) -> MutationType: + if (not self.mutates_data) and (not self.mutates_metadata): + return MutationType.NOT_MUTATED + + if _check_if_mutation_can_be_in_graph( + self.keep_input_mutations, + self.mutates_data, + self.mutates_metadata, + self.mutations_hidden_from_autograd, + self.mutations_under_no_grad_or_inference_mode, + self.requires_grad, + ): + return MutationType.MUTATED_IN_GRAPH + + return MutationType.MUTATED_OUT_GRAPH + + +@dataclass +class SubclassCreationMeta: + """ + Used for AOTDispatch. + This dataclass gives us the information we need to reconstruct a tensor subclass + from our flat inputs. + Why is this important? The graph that we'd like to trace out contains flat tensor inputs, + But the user's original model may have subclass inputs and outputs. + So we need to wrap/unwrap subclasses as necessary to translate between the user's + view (subclass inps/outs), and the backend compiler's view (graph with no subclass args). + + Complications arise mostly from the fact that a subclass can hold more than one inner tensor; + So for a given subclass input/output, we need to carefully track which indices map + to the subclass tensor in the corresponding "dense-tensor-only" graph. 
+ """ + + # In the inner graph that only takes in dense tensor inputs, + # this maps to the first index of "tensors that should go in this subclass wrapper" + flat_tensor_start_idx: int + # The number of tensors that live in this subclass wrapper + arg_count: int + # Stores the original subclass itself. + # This is needed because we need the autograd metadata on the original subclass + # (this is guaranteed to be a wrapper subclass that holds a fake tensor, + # so holding onto this at runtime shouldn't leak memory) + original_subclass: torch.Tensor + # meta and inner_keys are produced by the subclass's __tensor_flatten__. + # We need to keep them around along with outer_size / outer_stride to plumb them + # into __tensor_unflatten__. + meta: Any + inner_keys: List[Any] + outer_size: Tuple[int, ...] + outer_stride: Tuple[int, ...] + + def creation_fn(self, all_args, *, is_runtime: bool): + curr_args = all_args[ + self.flat_tensor_start_idx : self.flat_tensor_start_idx + self.arg_count + ] + assert len(curr_args) == len( + self.inner_keys + ), f"inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}" + # NB: Sometimes we have real inner tensors and symbolic metadata. + # TODO: Resolve this so we always have matching real / symbolic tensors / metadata. + out = type(self.original_subclass).__tensor_unflatten__( # type: ignore[attr-defined] + dict(zip(self.inner_keys, curr_args)), + self.meta, + self.outer_size, + self.outer_stride, + ) + if not is_runtime: + # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper + # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass. + # We don't trace through the autograd engine at runtime though, so no need + # to compute this extra metadata then! + torch._mirror_autograd_meta_to(self.original_subclass, out) # type: ignore[attr-defined] + + return out + + def __post_init__(self): + # sanity assert to make sure we don't leak memory + assert is_fake(self.original_subclass) + + +# This class encapsulates all aliasing + mutation info we need about the forward graph +# See a more detailed overview of the edge case handling at +# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit +@dataclass(eq=False) +class ViewAndMutationMeta: + # length = # user inputs + # This gives us info about every input, and what sort of mutation happened to it (if any) + input_info: List[InputAliasInfo] + + # length = # user outputs + # This gives us info about every output (mostly around whether it aliases other tensors) + output_info: List[OutputAliasInfo] + + # length = the number of intermediate bases appended as outputs to the end of the forward graph. + # Note: this is not necessarily the same thing as: + # len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate]) + # Because outputs might share a ._base, or an output's ._base might itself be + # another user output (in both cases, we won't redundantly append bases to the end of the graph) + num_intermediate_bases: int + + # For inference only: instructs us to keep data-only input mutations directly in the graph + keep_input_mutations: bool + + # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) + # + (# intermediate bases) + # These are the FakeTensor (or potential SymInt) outputs that we traced from our + # metadata pass of the user's forward function. 
+ # Their only use today is to pass them as a best-guess for tangents when tracing the joint. + # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis + # pass once, and re-use the output throughout AOTAutograd + traced_tangents: List[Any] + + # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs + # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors, + # Given a (potentially larger) list of plain torch tensors. + + # Taking subclass_inp_meta as an example: + # subclass_inp_meta[i] = j (an int) tells us: + # "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph." + # subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2) + # "The i'th user input is subclass holding two inner tensors, which are + # inputs[3] and inputs[4] of the plain-tensor graph". + + # length = # user inputs + subclass_inp_meta: List[Union[int, SubclassCreationMeta]] + # So, the full set of outputs to the forward graph looks something like: + # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors) + # where the first 3 of those 4 can be subclasses + # (but not saved_for_bw tensors, since these are internal to the compiler + # and not user visible, so there's no point in wrapping/unwrapping them at runtime). + # This list contains subclass information on all of the fw graph outputs + # except for saved_for_bw_tensors. + subclass_fw_graph_out_meta: List[Union[int, SubclassCreationMeta]] + # length = # backward graph inputs + subclass_tangent_meta: List[Union[int, SubclassCreationMeta]] + # TODO: we should kill this + # (need to default it to not break internal) + is_train: bool = False + + num_symints_saved_for_bw: Optional[int] = None + + # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue + # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode + # that is intended to be in effect prior to running the graph, in keeping with + # equivalence to eager mode. It is the responsibility of upstream graph acquisition + # to reset the grad mode to its pre-graph value prior to calling aot_autograd. + grad_enabled_mutation: Optional[bool] = None + + # Keeps track of whether `torch.use_deterministic_algorithms` was turned on + # when the forward was run. If deterministic mode was turned off during the + # forward, but is turned on during the backward call, then an error is + # raised + deterministic: Optional[bool] = None + + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are + # side-effectful operators, FunctionalTensorMode will populate this + # dictionary telling us how many tokens we will need during tracing. + tokens: Dict[Any, torch.Tensor] = field(default_factory=dict) + + def __post_init__(self): + # pre-compute the indices of the inputs that are mutated. + # When keep_input_mutations is set, we don't need to worry about our epilogue + # handling data-only mutations, because we keep them directly in the graph. 
+ + mutated_inp_runtime_indices = [ + i + for i, m in enumerate(self.input_info) + if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH) + ] + + mutated_graph_handled_indices = [ + i + for i, m in enumerate(self.input_info) + if m.mutation_type == MutationType.MUTATED_IN_GRAPH + ] + self.mutated_graph_handled_indices = mutated_graph_handled_indices + self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices) + + mutated_graph_handled_indices_seen_by_autograd = [ + i + for i in mutated_graph_handled_indices + if not self.input_info[i].mutations_hidden_from_autograd + ] + + self.mutated_graph_handled_indices_seen_by_autograd = ( + mutated_graph_handled_indices_seen_by_autograd + ) + self.num_mutated_graph_handled_indices_seen_by_autograd = len( + self.mutated_graph_handled_indices_seen_by_autograd + ) + + aliased_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type + not in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + unsafe_view_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type is OutputType.unsafe_view_alias + ] + + # This is pre-computed in post_init for perf. + # It contains the index of every element + # of input_info that corresponds to a mutation (data or metadata or both) + self.mutated_inp_runtime_indices = mutated_inp_runtime_indices + self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices) + + # This is pre-computed for perf. + # It contains the index of every element + # of output_info that corresponds to an alias (either of an input or intermediate) + self.aliased_out_indices = aliased_out_indices + self.unsafe_view_out_indices = unsafe_view_out_indices + self.num_outputs = len(self.output_info) + self.num_outputs_non_aliased = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + ) + self.num_outputs_aliased_to_inputs = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_input, + OutputType.is_input, + ] + ] + ) + self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices) + self.num_outputs_aliased_to_intermediates = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output, + ] + ] + ) + self.num_outputs_aliased = ( + self.num_outputs_aliased_to_inputs + + self.num_outputs_aliased_to_intermediates + ) + + self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info) + # See Note: [AOTAutograd Backward Guards] + # This is pre-computed for fast asserts on the types of our grad_outputs in the backward. + # Eventually, we should kill this and replace with real backward guards. + # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor) + self.output_types = [ + torch.Tensor if isinstance(x, FakeTensor) else type(x) + for x in self.traced_tangents + ] + + self.is_rng_op_functionalized = config.functionalize_rng_ops + # All of the above metadata is collected by tracing the fw function. + # However, extra outputs for rng offsets behave differently. Both fwd + # and bwd graphs have their own outputs for the total consumed offsets. + # Unlike mutated inputs, we don't have to worry about sending the right + # set of tensors between fwd and bwd. 
Fwd and bwd offsets are + # independent and simpler to handle. Therefore, we track them + # separately. + self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0 + + # Our forward() returns both (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints) + self.num_forward_returns = ( + self.num_mutated_inp_runtime_indices + + self.num_outputs + + self.num_intermediate_bases + ) + # In case of functionalization of rng ops, the fw_module returns one + # additional output for rng offset. This rng offset is used right + # away to advance the rng state, and is not passed on to the raw + # outputs. However, we need to know the exact boundary to identify + # which tensors to be saved for the bwd graph. num_forward captures + # this information. + self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset + + @property + def tensors_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(self.num_forward, -self.num_symints_saved_for_bw) + else: + return slice(self.num_forward, None) + + @property + def symints_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(-self.num_symints_saved_for_bw, None) + else: + return slice(0, 0) # empty slice + + def __eq__(self, other): + if not isinstance(other, ViewAndMutationMeta): + return NotImplemented + return ( + self.input_info == other.input_info + and self.output_info == other.output_info + and self.num_intermediate_bases == other.num_intermediate_bases + and self.keep_input_mutations == other.keep_input_mutations + and self.is_rng_op_functionalized == other.is_rng_op_functionalized + and self.num_outputs_rng_offset == other.num_outputs_rng_offset + and len(self.traced_tangents) == len(other.traced_tangents) + and all( + x.shape == y.shape and x.dtype == y.dtype + for x, y, in zip(self.traced_tangents, other.traced_tangents) + ) + ) + + +@dataclass(eq=False) +class SubclassMeta: + # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses) + # So for example, if the user had a model containing two `TwoTensor` inputs, + # Then `SubclassMeta.fw_metadata.input_infos` would have length 4 here. + fw_metadata: ViewAndMutationMeta + + # Note: [Computing Subclass Metadata about grad_inputs] + # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses + # + # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs? + # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous) + # + # This doesn't really work though. take this example: + # + # def f(DoubleTensor, DenseTensor): + # return DoubleTensor * DenseTensor + # + # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor. + # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs. + # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input) + # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors. + # + # Note that this info **cannot** easily be figured out from ViewAndMutationMeta. + # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed. 
+ # + # See Note: [AOTAutograd Backward Guards] + # This will also eventually require us to install backward guards, + # in case we made incorrect assumptions about the subclass-ness of our grad_outputs + # + # Optional field because we don't compute for inference graphs + grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] + + def __init__(self): + # The fields in this class get set after its construction. + pass + + +# This class exists because: +# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs +# - we only care about the metadata on those aliases, so we can regenerate them. +# We do not want them to participate in the autograd.Function. +# We do that by wrapping them in an opaque class, so the autograd.Function +# does not know to treat them as tensors. +@dataclass(frozen=True) +class TensorAlias: + alias: torch.Tensor + + +@dataclass +class BackwardSignature: + """ + Provides information about the backward section of an exported + joint forward-backward graph. + For a particular fx GraphModule, this class contains information on: + (1) A mapping from each gradient (backwards output) to the parameter + it corresponds to (forward input) + (2) A mapping from each gradient (backwards output) to the user input + it corresponds to (forward input) + (3) Which of the forward outputs corresponds to the loss, that we backprop on. + + Each string name is the `node.name` of the corresponding node in the fx graph. + """ + + gradients_to_parameters: Dict[str, str] + gradients_to_user_inputs: Dict[str, str] + loss_output: str + + +GraphOutputName = NewType("GraphOutputName", str) +GraphInputName = NewType("GraphInputName", str) +FQN = NewType("FQN", str) + + +@dataclass +class GraphSignature: + """ + Provides information about an exported module. + For a particular fx GraphModule, this class contains information on: + (1) Which graph inputs are parameters, buffers, or user inputs + (2) (for params/buffers) a mapping from the name of each graph argument + to its parameter/buffer FQN in the original nn.Module. + (3) If there are input mutations, these are represented as extra outputs + in the fx GraphModule. We provide a mapping from these + extra output names to the names of the actual inputs. + (4) The pytree metadata on how to flatten/unflatten inputs and outputs. + The corresponding FX GraphModule only accepts and returns + pytree-flattened inputs/outputs. + (5) (Optionally) if the FX is a joint forward-backward graph, we provide + a signature on the backward section of the joint graph. + """ + + parameters: List[FQN] + buffers: List[FQN] + + user_inputs: List[GraphInputName] + user_outputs: List[GraphOutputName] + inputs_to_parameters: Dict[GraphInputName, FQN] + inputs_to_buffers: Dict[GraphInputName, FQN] + + # If the user's module mutates a buffer, + # it's represented in the graph as an extra graph output. + # This dict is a mapping from + # "graph outputs that correspond to updated buffers" + # to the FQN names of those mutated buffers. 
+ buffers_to_mutate: Dict[GraphOutputName, FQN] + user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName] + + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + backward_signature: Optional[BackwardSignature] + + input_tokens: List[GraphInputName] + output_tokens: List[GraphOutputName] + + @classmethod + def from_tracing_metadata( + cls, + *, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + graph_input_names: List[str], + graph_output_names: List[str], + view_mutation_metadata: ViewAndMutationMeta, + named_parameters: List[str], + named_buffers: List[str], + num_user_inputs: int, + num_user_outputs: int, + loss_index: Optional[int], + backward_signature: Optional[BackwardSignature], + ) -> "GraphSignature": + graph_inputs = graph_input_names + graph_outputs = graph_output_names + parameters = list(named_parameters) + buffers = list(named_buffers) + num_tokens = len(view_mutation_metadata.tokens) + + # Calling convention assumptions: + # (1) graph inputs = (input_tokens, params, buffers, user_inputs) + # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients) + # (If we are capturing an inference graph, this convention is identical + # except that param_gradients is empty) + # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens + + # Address input calling conventions: + start, stop = 0, num_tokens + input_tokens = graph_inputs[start:stop] + + start, stop = stop, stop + len(parameters) + inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters)) + + start, stop = stop, stop + len(buffers) + inputs_to_buffers = dict( + zip( + graph_inputs[start:stop], + buffers, + ) + ) + + start, stop = stop, stop + num_user_inputs + user_inputs = graph_inputs[start:stop] + + # We should've gone through all the inputs now + assert len(graph_inputs) - stop == 0 + + # Address output calling conventions: + start, stop = 0, num_tokens + output_tokens = graph_outputs[start:stop] + + names = [*input_tokens, *parameters, *buffers, *user_inputs] + mutations = [] + for idx, input_info in enumerate(view_mutation_metadata.input_info): + if input_info.mutates_data: + # Only buffers can be mutated, not parameters + assert idx >= len(parameters) + mutations.append(names[idx + num_tokens]) + + assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices + + start, stop = ( + stop, + stop + view_mutation_metadata.num_mutated_inp_runtime_indices, + ) + outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations)) + + user_inputs_to_mutate = {} + buffers_to_mutate = {} + for output_name, mutation_name in outputs_to_mutations.items(): + if mutation_name in user_inputs: + user_inputs_to_mutate[output_name] = mutation_name + else: + assert mutation_name in buffers + buffers_to_mutate[output_name] = mutation_name + + start, stop = stop, stop + num_user_outputs + user_outputs = graph_outputs[start:stop] + + unused_outputs = len(graph_outputs) - stop + if backward_signature is not None: + unused_outputs -= len(backward_signature.gradients_to_parameters) + len( + backward_signature.gradients_to_user_inputs + ) + assert unused_outputs == 0 + + return GraphSignature( + parameters=parameters, # type: ignore[arg-type] + buffers=buffers, # type: ignore[arg-type] + user_inputs=user_inputs, # type: ignore[arg-type] + user_outputs=user_outputs, # type: ignore[arg-type] + inputs_to_buffers=inputs_to_buffers, # type: ignore[arg-type] + inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type] + 
user_inputs_to_mutate=user_inputs_to_mutate, + buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type] + in_spec=in_spec, + out_spec=out_spec, + backward_signature=backward_signature, + input_tokens=input_tokens, # type: ignore[arg-type] + output_tokens=output_tokens, # type: ignore[arg-type] + ) + + +@dataclass +class AOTConfig: + """ + Configuration for AOTDispatcher + """ + + fw_compiler: Callable + bw_compiler: Callable + partition_fn: Callable + decompositions: Dict[Callable, Callable] + num_params_buffers: int + aot_id: int + keep_inference_input_mutations: bool + is_export: bool = False + no_tangents: bool = False + dynamic_shapes: bool = False + aot_autograd_arg_pos_to_source: Optional[List[Source]] = None + inference_compiler: Optional[Callable] = None + enable_log: bool = True + # this is always false outside of export. + pre_dispatch: bool = False + + def __post_init__(self): + if self.pre_dispatch: + assert self.is_export, "Can only have pre_dispatch IR for export." + + +SubclassTracingInfo = collections.namedtuple( + "SubclassTracingInfo", + ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"], +) diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..563223c3dbbdce4bd1e476cb4bd6db9501a54004 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py @@ -0,0 +1,295 @@ +""" +This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes. +AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher, +and this includes tensor subclasses that implement __torch_dispatch__. +""" + +from typing import Any, List, Optional, Tuple, Union + +import torch.utils._pytree as pytree + +from torch import Tensor +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .schemas import MutationType, SubclassCreationMeta, ViewAndMutationMeta +from .utils import strict_zip + +zip = strict_zip + + +def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: + args_flattened = pytree.arg_tree_leaves(*args) + any_subclass_args = any( + is_traceable_wrapper_subclass(x) + for x in args_flattened + if isinstance(x, Tensor) + ) + from torch._functorch._aot_autograd.schemas import SubclassCreationMeta + + any_subclass_outputs = any( + type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta + ) + # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime. 
+ return any_subclass_args or any_subclass_outputs + + +# Given a flat list of arguments, some of which may be tensor subclasses, +# computes metadata about "how to reconstruct the current list of subclasses, +# if we were given their flattened dense tensors instead" +def create_subclass_meta( + curr_args: Union[List[Any], Tuple[Any, ...]], +) -> List[Union[int, SubclassCreationMeta]]: + idx = 0 + infos: List[Union[int, SubclassCreationMeta]] = [] + for a in curr_args: + if isinstance(a, Tensor) and is_traceable_wrapper_subclass(a): + attrs, meta = a.__tensor_flatten__() # type: ignore[attr-defined] + start_idx = idx + cnt = len(attrs) + curr_cnt = cnt + infos.append( + SubclassCreationMeta( + flat_tensor_start_idx=start_idx, + arg_count=curr_cnt, + original_subclass=a, + meta=meta, + inner_keys=attrs, + outer_size=a.shape, + outer_stride=a.stride(), + ) + ) + else: + infos.append(idx) + cnt = 1 + idx += cnt + return infos + + +# Output structure: +# - List[Tensor] if tracing an inference graph +# - Tuple[List[Tensor], List[Tensor]] if tracing a joint graph. +# This function effectively concats each inner list of subclass tensors +# into a (potentially longer) list of inner tensors. +# +# This function takes in a pytree of arguments and unwraps any tensor subclasses. +# Annoyingly, we can't use pytrees to perform the unwrapping, because unwrapping returns +# a list of tensors that we would then need to concat together. +# Instead, we specialize the logic for the inference vs. joint graph case. +# NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime +def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool): + def concat_inner_tensors_from_subclasses(xs): + xs_inner = [] + for x in xs: + if isinstance(x, Tensor) and is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() # type: ignore[attr-defined] + xs_inner += [getattr(x, attr) for attr in attrs] + else: + xs_inner += [x] + return xs_inner + + if is_joint_structure: + assert isinstance(wrapped_args, tuple) and len(wrapped_args) == 2 + assert isinstance(wrapped_args[0], (tuple, list)) and isinstance( + wrapped_args[1], (tuple, list) + ) + unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args[0]) + unwrapped_args_tangents = concat_inner_tensors_from_subclasses(wrapped_args[1]) + unwrapped_args = (unwrapped_args_fw, unwrapped_args_tangents) + else: + assert isinstance(wrapped_args, (list, tuple)) + unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args) + unwrapped_args = unwrapped_args_fw + return unwrapped_args + + +# Turns a flattened list of tensor arguments into (maybe) subclass tensors. +# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in. 
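+# A rough sketch (using a hypothetical `TwoTensor` subclass that flattens into two inner
+# tensors, like the one mentioned in schemas.py): if the original args were
+#   [TwoTensor(a, b), plain]
+# then unwrap_tensor_subclasses produces [a, b, plain], create_subclass_meta produces
+#   [SubclassCreationMeta(flat_tensor_start_idx=0, arg_count=2, ...), 2]
+# and wrap_tensor_subclasses([a, b, plain], subclass_metas=...) rebuilds
+# [TwoTensor(a, b), plain] by invoking each SubclassCreationMeta's creation_fn.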
+def wrap_tensor_subclasses( + unwrapped_args: Union[Tuple[Any, ...], List[Any]], + *, + subclass_metas: List[Union[int, SubclassCreationMeta]], + num_fw_outs_saved_for_bw: Optional[int] = None, + is_runtime: bool = False, +) -> Tuple[Any, ...]: + wrapped_args = [] + num_args_tallied = 0 + for subclass_meta in subclass_metas: + if isinstance(subclass_meta, int): + wrapped_args.append(unwrapped_args[subclass_meta]) + num_args_tallied += 1 + else: + assert isinstance(subclass_meta, SubclassCreationMeta) + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) + num_args_tallied += subclass_meta.arg_count + + # Note: [Partitioner handling for Subclasses, Part 2] + # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw, + # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them. + # + # When this function is called at runtime in the forward, + # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs. + # + # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen? + # Answer: we do it **inside of our compiled autograd.Function**. + # This seems like morally the right place: autograd happens above subclass desugaring, + # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors. + # + # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph + # into a forward and backward graph, we end up with some activations that show up as extra outputs + # in the compiled forward graph, that are **not** user outputs. + # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses. + # + # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`), + # we computed subclass metadata on every forward output, but this did **not** include activations + # created by the partitioner. + # as a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations), + # but `subclass_metas` will only correspond to subclass metatadata on `user_fw_outs`. + # We then need to make sure that we return (*wrapped_user_fw_outs, *activations). + if num_fw_outs_saved_for_bw is not None: + assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, ( + f"Expected the number actual unwrapped-subclass outputs {len(unwrapped_args)} to equal " + f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of " + f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})" + ) + activations = unwrapped_args[num_args_tallied:] + if isinstance(wrapped_args, tuple) and isinstance(activations, tuple): + return wrapped_args + activations + return tuple(list(wrapped_args) + list(activations)) + else: + assert len(unwrapped_args) == num_args_tallied + return tuple(wrapped_args) + + +# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses. +# This function carefully handles the inference vs. 
joint cases: +# - when is_joint_structure is True, args is (primals, tangents) +# - when is_joint_structure is False, args is [*primals] +def wrap_tensor_subclasses_maybe_joint( + unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta +) -> Union[Tuple[Any, ...], List[Any]]: + # Since this function is re-used for both inference and joint graphs, + if is_joint_structure: + assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2 + assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance( + unwrapped_args[1], (tuple, list) + ) + primals, tangents = unwrapped_args[0], unwrapped_args[1] + wrapped_primals = wrap_tensor_subclasses( + primals, subclass_metas=meta.subclass_inp_meta + ) + wrapped_tangents = wrap_tensor_subclasses( + tangents, subclass_metas=meta.subclass_tangent_meta + ) + return (wrapped_primals, wrapped_tangents) + else: + wrapped_args = wrap_tensor_subclasses( + unwrapped_args, subclass_metas=meta.subclass_inp_meta + ) + return wrapped_args + + +# TODO: UNUSED. delete? +def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMeta: + # input infos + input_info = [] + for inp, subclass_meta in zip(meta.input_info, meta.subclass_inp_meta): + num_inps = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count + for _ in range(num_inps): + input_info.append(inp) + + # output infos + output_info = [] + subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[ + meta.num_mutated_inp_runtime_indices : + ] + if meta.num_intermediate_bases > 0: + subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[ + : -meta.num_intermediate_bases + ] + # sanity assert + assert len(meta.output_info) == len(subclass_out_meta_user_outs_only) + # Assume that the information on the output is shared by all of its inner tensors. + for out, subclass_meta in zip(meta.output_info, subclass_out_meta_user_outs_only): + num_outs = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count + for _ in range(num_outs): + output_info.append(out) + + # A bit hacky, but we don't actually care about all of the metadata here. 
+ # This metadata is used **underneath** both autograd and subclass de-sugaring, + # So all we really care about is stuff like: + # - num inputs/outputs (needed by the partitioner) + # - input mutations (**not** used today, since we don't handle input mutations inside the subclass, + # although we should handle this eventually) + # TODO: add a test case to assert we error when this happens, instead of getting silent correctness + num_intermediate_bases = None + keep_input_mutations = meta.keep_input_mutations + traced_tangents = None + subclass_inp_meta = None + subclass_fw_graph_out_meta = None + subclass_tangent_meta = None + + metadata = ViewAndMutationMeta( + input_info=input_info, # type: ignore[arg-type] + output_info=output_info, # type: ignore[arg-type] + num_intermediate_bases=num_intermediate_bases, # type: ignore[arg-type] + keep_input_mutations=keep_input_mutations, # type: ignore[arg-type] + traced_tangents=traced_tangents, # type: ignore[arg-type] + subclass_inp_meta=subclass_inp_meta, # type: ignore[arg-type] + subclass_fw_graph_out_meta=subclass_fw_graph_out_meta, # type: ignore[arg-type] + subclass_tangent_meta=subclass_tangent_meta, # type: ignore[arg-type] + ) + return metadata + + +def compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata: ViewAndMutationMeta, + inner_metadata: ViewAndMutationMeta, +) -> List[int]: + # Note: [Recomputing subclass mutation handling] + # + # Generally, if a subclass requires grad, its components will not require grad. + # But for the purposes of tracking returned tensors, we should treat those component + # tensors as if they require grad. + # + # For example, if the subclass tensor requires grad and will be mutated in a way that + # requires us to handle the mutation outside of the graph, we need to return it + # from the forward graph. The inner_meta data won't consider the component tensors + # as if they need to be returned, because they don't require grad; but really, we + # should handle those tensors the same way we handle the subclass tensor itself; i.e. + # if we'd include the subclass tensor as part of the outputs, then we should also + # include the component tensors. + # + # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs + # from the outer subclass tensors and propagating + + updated_input_info = [] + inner_idx = 0 + if not fw_metadata.subclass_inp_meta: + # Sometimes we don't have subclass info, e.g. 
synthetic_base codepaths + return inner_metadata.mutated_inp_runtime_indices + assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info) + for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta): + if isinstance(inp_meta, int): + assert outer_idx < len(fw_metadata.input_info) + if inner_metadata is not None: + assert inner_idx < len(inner_metadata.input_info) + assert ( + inner_metadata.input_info[inner_idx] + == fw_metadata.input_info[outer_idx] + ) + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + else: + for _ in range(inp_meta.arg_count): + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + if inner_metadata is not None: + assert len(inner_metadata.input_info) == len(updated_input_info) + + return [ + i + for i, inp in enumerate(updated_input_info) + if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0834c398563ece8bc7ef46aa9dd16ddd1ba638 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -0,0 +1,698 @@ +""" +This module is responsible for transforming functions to be traced into a form +that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) +to handle. + +It does so by: +1. functionalization (including RNG functionalzation) +2. creating a joint graph when required +3. transforming mutations into extra outputs +4. dispatching subclasses +""" + +import warnings +from contextlib import nullcontext +from functools import wraps +from typing import Any, Callable, List, Tuple, Union +from unittest.mock import patch + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker +from torch._guards import detect_fake_mode +from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.symbolic_shapes import definitely_false, sym_eq +from torch.nn.utils import stateless + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .functional_utils import ( + from_fun, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, +) +from .logging_utils import setup_stacktrace_preservation_hooks +from .schemas import ( + AOTConfig, + MutationType, + OutputType, + SubclassMeta, + SubclassTracingInfo, + ViewAndMutationMeta, +) +from .subclass_utils import ( + create_subclass_meta, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from .utils import maybe_to_fresh_input + + +# This function returns a new function that returns mutated inputs as outputs. +# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. +def fn_input_mutations_to_outputs( + fn: Callable, + meta: ViewAndMutationMeta, + keep_data_input_mutations: bool, +) -> Any: + @wraps(fn) + def inner_fn(*args): + outs = fn(*args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. 
+ # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) + mutated_inputs_to_return = [ + x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices + ] + return *mutated_inputs_to_return, *outs + + return inner_fn + + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. +# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. +def fn_prepped_for_autograd( + fn: Callable, + meta: ViewAndMutationMeta, +) -> Any: + @wraps(fn) + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs = fn(*args_maybe_cloned) + assert isinstance(outs, (tuple, list)) + outs = list(outs) + assert len(meta.output_info) == len(outs) + + mutated_inputs_to_return = [ + x + for (i, x) in enumerate(args_maybe_cloned) + if i in meta.mutated_inp_runtime_indices + ] + + intermediate_bases = [] + for i, (o, info) in enumerate(zip(outs, meta.output_info)): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + intermediate_bases.append(o._base) + + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data + and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, Tensor) + and meta.output_info[i].requires_grad + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))] + + out_grad_mask = ( + mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + ) + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) 
+ # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. + for arg in args_maybe_cloned: + if not isinstance(arg, Tensor): + continue + sync_functional_tensor(arg) + + return fw_outs_to_return, out_grad_mask + + return inner_fn + + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any: + def inner_fn(primals: List[Any], tangents: List[Any]): + outs, tangent_mask = fn(*primals) + assert len(tangent_mask) == len(outs) + outs_to_grad = [ + o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent + ] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + # The not definitely_false is also sketchy; if unbacked + # symints are involved, we're just going to assume that the + # decomps setup the base shape correctly + needed_outs.append( + out + if not definitely_false(sym_eq(out.shape, tangent.shape)) + else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + if config.functionalize_rng_ops: + PhiloxStateTracker.mark_beginning_of_backward() + backward_out: Tuple[Tensor, ...] = tuple() + # Call the backwards pass + if grad_primals: + with fx_traceback.preserve_node_meta(): + # for full graph export, we always export a joint graph where we assume no tangents are needed. 
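+ # (With no_tangents, the exported joint is assumed to end in a single scalar output;
+ # the assert below checks for exactly one one-element tangent, and autograd.grad()
+ # is then called without explicit grad_outputs.)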
+ if aot_config.no_tangents: + assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + allow_unused=True, + ) + else: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + return outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads + ] + + def inner_fn_with_anomaly(*args): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.") + with torch.autograd.detect_anomaly(check_nan=False): + return inner_fn(*args) + + return inner_fn_with_anomaly + + +def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any: + # Functionalization of rng ops changes the calling convention of the joint graph. + # It goes from (primals, tangents) to (seed, offset, primals, tangents) + # At runtime, we pass on the current seed and offset. This is hidden from + # the user. + fake_mode = detect_fake_mode() + if fake_mode is None: + fake_mode = nullcontext() + + def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): + out = PhiloxStateTracker.get_state_as_tensor() + return out + + def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"): + PhiloxStateTracker.set_state_from_tensor(x) + + def append_rng_offsets(args): + if trace_joint: + # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset) + return ( + (*args[0], PhiloxStateTracker.get_updated_fwd_offset()), + (*args[1], PhiloxStateTracker.get_updated_bwd_offset()), + ) + else: + # args signature before: Tuple(fwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset) + return (*args, PhiloxStateTracker.get_updated_fwd_offset()) + + def traced_joint( + primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset + ): + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(primals, tangents)) + + def traced_forward(*primals_fwd_seed_fwd_base_offset): + # The signature is (*primals, seed, offset) + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2])) + + if trace_joint: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward") + return traced_joint, ( + *args, + fwd_seed, + fwd_base_offset, + bwd_seed, + bwd_base_offset, + ) + else: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + return traced_forward, (*args, fwd_seed, fwd_base_offset) + + +# This creates the final function that we want to trace using make_fx(), +# in both aot_dispatch_autograd and aot_dispatch_base. 
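+# (aot_dispatch_base traces an inference-only forward graph; aot_dispatch_autograd
+# traces the joint forward + backward graph.)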
+# Preconditions: +# - fn corresponds to the user's fw function +# - fn arguments have been flattened, duplicate arguments have been handled +# - In the returned function, the "primals" arguments *includes* synthetic bases. +# This function does the work of functionalizing the input function, +# and performing copy_() calls at the end of the function if `keep_input_mutations` is set. +# The function returned has signature that is either: +# (1) "traced_fn(primals: List[Any])" if trace_joint is False +# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True +# Returns a new (functionalized) function, and updated arguments to call it with. +def create_functionalized_fn( + fn, + args, + *, + meta: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, +) -> Any: + @wraps(fn) + def _functionalized_f_helper(*args): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + tokens = args[0][: len(meta.tokens)] + actual_args = args[0][len(meta.tokens) :] + args = (actual_args, args[1]) + else: + tokens = args[: len(meta.tokens)] + args = args[len(meta.tokens) :] + assert all(token.numel() == 0 for token in tokens) + + with disable_above: + # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + f_tokens = pytree.tree_map(to_fun, tokens) + + # Populate the current FunctionalTensorMode with the tokens per + # operator. See Note [FunctionalTensorMode is Stateful] + functional_tensor_mode = ( + torch.utils._python_dispatch._detect_functional_mode() + ) + assert functional_tensor_mode is not None + for i, k in enumerate(meta.tokens.keys()): + functional_tensor_mode._tokens[k] = f_tokens[i] + + # Run the joint + f_outs = fn(*f_args) + + # Return both the tokens and the outputs + # See Note [Side-Effectful Tokens in AOTAutograd] + f_outs = (*functional_tensor_mode._tokens.values(), *f_outs) + + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. 
fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for f_inpt, before, after, inpt_info in zip( + f_args[0], primals_before, primals_after, meta.input_info + ): + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if has_data_mutation(f_inpt) and not inpt_info.mutates_data: + assert ( + not inpt_info.requires_grad + ), "Found a graph input that requires_grad and was mutated in the backward. This is not supported" + # Otherwise, put the mutation in the graph + before.copy_(after) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) and not has_data_mutation( + f_inpt + ), "Found an input to the backward that was mutated during the backward pass. This is not supported" + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. + # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. + # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + for i, (inpt_old, inpt_f) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) + ): + if not isinstance(inpt_f, torch.Tensor): + continue + assert is_fun(inpt_f) + inpt_new = from_fun(inpt_f) + if meta.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH: + # We found an input that had a (data-only) mutation. 
+ # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + if meta.input_info[i].mutations_hidden_from_autograd: + # Hidden from autograd = run under no_grad, **and** don't bump VC + with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( + inpt_old + ): + inpt_old.copy_(inpt_new) + elif meta.input_info[i].mutations_under_no_grad_or_inference_mode: + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + with torch.no_grad(): + inpt_old.copy_(inpt_new) + else: + inpt_old.copy_(inpt_new) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. + flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i, outp in enumerate(flat_outs[:num_outs]): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + flat_outs[i] = args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec) + + return pytree.tree_map(from_fun, f_outs) + + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export + def joint_helper(primals, tangents): + return _functionalized_f_helper(primals, tangents) + + helper = joint_helper if trace_joint else _functionalized_f_helper + if config.functionalize_rng_ops: + # Setup the wrapper for functionalization of rng ops + helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint) + + # Additionally pass in tokens as inputs + # See Note [Side-Effectful Tokens in AOTAutograd] + additional_token_inputs = [torch.tensor([])] * len(meta.tokens) + if trace_joint: + args = ([*additional_token_inputs, *args[0]], *args[1:]) + else: + args = [*additional_token_inputs, *args] + + return helper, args + + +# Given a function operating on Subclass -> Subclass, returns an function that operates on Tensor -> Tensor +# Also returns: +# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated) +# - the updated ViewAndMutationMeta for this dense -> dense function. +# The other important arguments are: +# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function. +# when is_joint_structure=False, this is just the forward function. +# - fw_only: this is *always* the forward-only function. +# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions. +# In particular, we need this to tell the partitioner how many dense forward outputs there are. 
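+# Illustrative example (using a hypothetical "TwoTensor" subclass that wraps two dense tensors a, b):
+#   user fn (subclass -> subclass):
+#       def f(x: TwoTensor) -> TwoTensor: ...
+#   traced fn after aot_dispatch_subclass (dense -> dense):
+#       def f_dense(a: Tensor, b: Tensor) -> Tuple[Tensor, Tensor]: ...
+#   The returned SubclassTracingInfo carries the unwrapped (dense) args, the dense trace fn,
+#   and the metadata needed to re-wrap (a, b) back into a TwoTensor at runtime.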
+def aot_dispatch_subclass( + flat_fn_maybe_joint, + args: List[Any], + *, + is_joint_structure: bool, + meta: ViewAndMutationMeta, + fw_only: Callable, +) -> SubclassTracingInfo: + # Skip logic if we don't need to trace through any subclasses + req_subclass_dispatch = requires_subclass_dispatch(args, meta) + if not req_subclass_dispatch: + return SubclassTracingInfo( + plain_tensor_trace_fn=flat_fn_maybe_joint, + plain_tensor_args=args, + maybe_subclass_meta=None, + ) + + # TODO: add subclass guards (later PR). + + # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs). + # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint, + # so we set it later, while we're tracing the joint (see inner_fn() below). + # Another option would be to run our run_functionalized_fw_and_collect_metadata() function + # directly on the joint, but this would hurt compile time (adding yet another pass through the joint). + subclass_meta = SubclassMeta() + + def inner_fn(fn, args, *, use_trace_joint: bool): + # Step 1: wrap tensor inputs into subclasses if necessary + all_args = wrap_tensor_subclasses_maybe_joint( + args, is_joint_structure=use_trace_joint, meta=meta + ) + + # Step 2: call the inner function, with our (maybe subclass) inputs + wrapped_outs = fn(*all_args) + + if use_trace_joint: + # See Note: [Computing Subclass Metadata about grad_inputs] + # We also stash subclass info on our grad_inputs, if we're tracing the joint. + nonlocal subclass_meta + assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2 + # Don't need fw outs since we already have subclass metadata on them + grad_inputs = wrapped_outs[1] + subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs) + + # Step 3: Unwrap any subclass outputs back into dense tensors + unwrapped_outs = unwrap_tensor_subclasses( + wrapped_outs, is_joint_structure=use_trace_joint + ) + return unwrapped_outs + + def joint_fn(primals, tangents): + return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True) + + def fw_fn(*primals): + return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) + + def metadata_fn(*primals): + return inner_fn(fw_only, primals, use_trace_joint=False) + + args_unwrapped = unwrap_tensor_subclasses( + args, is_joint_structure=is_joint_structure + ) + + if is_joint_structure: + primals_unwrapped = args_unwrapped[0] + fn_to_trace = joint_fn + else: + primals_unwrapped = args_unwrapped + fn_to_trace = fw_fn + + # Note: [Partitioner handling for Subclasses, Part 1] + # The way the partitioner works is that: + # (1) we pass is a single graph containing the joint fw/bw, + # where the # of graph outputs corresponds to # fw_outputs + # grad_inputs + # (2) The partitioner accepts an arguments, num_fwd_outputs, + # and assumes that the first "num_fwd_outputs" graph outputs correspond + # to outputs of the forward graph. + # How do tensor subclasses enter the picture? + # the num_fwd_outputs in the final graph is actually non-trivial to compute, + # because it can be influenced by input mutations and intermediate bases. + # So we compute it by inspecting the current ViewAndMutationMeta object. + # However, the original ViewAndMutationMeta that we computed was created + # on the subclass -> subclass graph, + # which can have a different number of outputs than the dense -> dense graph. 
+ # That's why we createa a fresh metadata object on the dense -> dense function here, + # and plumb it back up to the partitioner. + # See Note: [Partitioner handling for Subclasses, Part 2] for more info. + meta_updated = run_functionalized_fw_and_collect_metadata( + metadata_fn, + keep_input_mutations=meta.keep_input_mutations, + is_train=meta.is_train, + )(*primals_unwrapped) + + subclass_meta.fw_metadata = meta_updated + + return SubclassTracingInfo( + plain_tensor_trace_fn=fn_to_trace, + plain_tensor_args=args_unwrapped, + maybe_subclass_meta=subclass_meta, + ) + + +class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run_node(self, n: torch.fx.Node): + import sympy + + result = super().run_node(n) + # TODO: handle Tensor returns + if "example_value" in n.meta: + if isinstance(result, torch.SymInt) and isinstance( + result.node.expr, sympy.Symbol + ): + torch._check(result == n.meta["example_value"]) + + return result + + +def create_functional_call(mod, params_spec, params_len, store_orig_mod=False): + # Redundant with dynamo, but worth having in case this gets invoked elsewhere. + # https://github.com/pytorch/pytorch/issues/103569 + + def functional_call(*args, **kwargs): + with stateless._reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ): + if isinstance(mod, torch.fx.GraphModule): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Anomaly Detection has been enabled." + ) + with torch.autograd.detect_anomaly(check_nan=False): + out = PropagateUnbackedSymInts(mod).run( + *args[params_len:], **kwargs + ) + else: + out = mod(*args[params_len:], **kwargs) + + if not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + # Note [Preserving the nn module stack metadata during export non-strict mode] + # This path is currently only used by the non-strict export flow, + # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph. + # Instead, we stash the original user nn module here, and rely on `make_fx` to grab + # this stashed module and use it to track nn module stack metadata + if store_orig_mod and not hasattr(functional_call, "_orig_mod"): + functional_call._orig_mod = mod # type: ignore[attr-defined] + + return functional_call diff --git a/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6c7858fd8ea9df1bb6aee96c11e505127815eb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py @@ -0,0 +1,226 @@ +""" +Contains various utils for AOTAutograd, including those for handling collections. 
+""" + +import dataclasses +import warnings +from contextlib import nullcontext +from functools import wraps +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import py_sym_types + +KNOWN_TYPES = [ + torch.Tensor, + BackwardState, + int, + str, + float, + bool, + type(None), + *py_sym_types, +] + +original_zip = zip + + +def strict_zip(*iterables, strict=True, **kwargs): + if not strict: + return original_zip(*iterables, **kwargs) + + shortest_length = min(len(it) for it in iterables) + for iterable in iterables: + if len(iterable) != shortest_length: + raise ValueError( + "The iterables have different lengths and strict mode is enabled." + ) + + return original_zip(*iterables, **kwargs) + + +def _get_symint_hints(exprs): + """ + Get the hints of a list/tuple of int/SymInt. + """ + if isinstance(exprs, (list, tuple)): + return type(exprs)(_get_symint_hints(e) for e in exprs) + elif isinstance(exprs, torch.SymInt): + return exprs.node.shape_env.size_hint(exprs.node.expr) + else: + return exprs + + +def partial_flatten_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj): + return { + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) + } + elif isinstance(obj, (list, tuple)): + return obj.__class__([partial_flatten_asdict(item) for item in obj]) + elif isinstance(obj, dict): + return {k: partial_flatten_asdict(v) for k, v in obj.items()} + else: + return obj + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +def _get_autocast_states(): + return [ + torch.is_autocast_enabled(), + torch.is_autocast_cpu_enabled(), + torch.get_autocast_gpu_dtype(), + torch.get_autocast_cpu_dtype(), + torch.is_autocast_cache_enabled(), + ] + + +def make_boxed_func(f): + def g(args): + return f(*args) + + g._boxed_call = True # type: ignore[attr-defined] + return g + + +def make_boxed_compiler(compiler): + @wraps(compiler) + def f(fx_g, inps): + out_f = compiler(fx_g, inps) + fx_g = make_boxed_func(out_f) + return fx_g + + return f + + +def call_func_at_runtime_with_args(f, args, steal_args=False, disable_amp=False): + if not steal_args: + args = list(args) + assert isinstance(args, list) + + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + if hasattr(f, "_boxed_call"): + out = normalize_as_list(f(args)) + else: + # TODO: Please remove soon + # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 + warnings.warn( + "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. " + "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " + "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale." + ) + out = normalize_as_list(f(*args)) + return out + + +# Inspired by autodidax (thanks!) +class PytreeThunk: + spec: Optional[pytree.TreeSpec] = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple: Optional[ + bool + ] = None # if the output spec is a tuple/list, we won't bother unflattening it. 
+ is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec + + def set(self, spec: pytree.TreeSpec) -> None: + assert self.spec is None or self.spec == spec + assert spec is not None + self.spec: pytree.TreeSpec = spec + if self.spec.type in {tuple, list} and all( + child.is_leaf() for child in spec.children_specs + ): + self.is_simple = True + if self.spec.is_leaf(): + self.is_really_simple = True + + def unflatten(self, x: List[Any]) -> Any: + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + assert self.spec is not None + return pytree.tree_unflatten(x, self.spec) + + +# Creates a function that returns flattened inputs and outputs +# Also returns the output tree spec, which is needed to recover the "unflattened" +# output tree structure later. +def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]: + if kwargs is None: + kwargs = {} + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. + nonlocal out_spec + args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + # Can't use functools.wraps here because the wrapper has different + # calling convention + if hasattr(fn, "_orig_mod"): + flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined] + + return flat_fn, out_spec + + +# This function takes in a tensor t, and returns one of t, t.view(), or t.clone(). +# When tracing the joint forward + backward, for any inputs in the graph that are mutated, +# we need to clone them first (and similarly for metadata-only mutations, we need to view them first). +# The idea is that when we trace the backward, we need to pass in the *original* primals +# to autograd.grad(), before they were mutated. +# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them. +# This means that "idx" here represents the index of the (potentially) synthetic base. +# What we need to do is: +# (1) map the current (post-synthetic-base calling convention) input argument index +# to int index pre-synthetic-base-calling-convention. +# (2) There could be multiple, if this index corresponds to a synthetic base +# that has multiple input aliases. +# (3) If any of those corresponding inputs get metadata mutations, then we clone the base. +def maybe_to_fresh_input(idx, t, meta): + if not isinstance(t, torch.Tensor): + return t + if idx in meta.mutated_inp_runtime_indices: + # We only need to bother cloning mutated inputs that participate in autograd. 
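+ # (Data mutations need a full clone() so autograd.grad() sees the pre-mutation values;
+ # metadata-only mutations only need the shallow .view() taken further below.)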
+ mutated_inp_idx = meta.mutated_inp_runtime_indices.index(idx) + if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the mutation + return t.clone() + if meta.input_info[idx] and meta.input_info[idx].mutates_metadata: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the metadata mutation + return t.view(t.shape) + return t diff --git a/MLPY/Lib/site-packages/torch/_functorch/aot_autograd.py b/MLPY/Lib/site-packages/torch/_functorch/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..ec817eb7f1d6d1524e5d470b76df5fdc2fc9774a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/aot_autograd.py @@ -0,0 +1,1246 @@ +# mypy: ignore-errors + +import itertools +from contextlib import nullcontext +from functools import partial, wraps +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch + +import torch +import torch.nn as nn +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo import compiled_autograd +from torch._dynamo.utils import dynamo_timed, preserve_rng_state +from torch._guards import detect_fake_mode +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + ShapeEnv +) +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions +from . import config +from .partitioners import default_partition + +from ._aot_autograd.utils import ( # noqa: F401 + strict_zip, + _get_symint_hints, + KNOWN_TYPES, + partial_flatten_asdict, + normalize_as_list, + _get_autocast_states, + make_boxed_func, + make_boxed_compiler, + call_func_at_runtime_with_args, + create_tree_flattened_fn, + maybe_to_fresh_input, +) +from ._aot_autograd.logging_utils import ( # noqa: F401 + graph_being_compiled, + nth_graph, + model_name, + set_model_name, + get_aot_compilation_context, + get_aot_graph_name, + get_graph_being_compiled, + track_graph_compiling, + callback_set, + setup_stacktrace_preservation_hooks, + describe_input, + format_guard_bug_msg, +) +from ._aot_autograd.functional_utils import ( # noqa: F401 + is_fun, + to_fun, + from_fun, + sync_functional_tensor, + has_metadata_mutation, + has_data_mutation, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + gen_alias_from_base, + assert_functional_graph, + _check_if_mutation_can_be_in_graph, +) +from ._aot_autograd.schemas import ( # noqa: F401 + OutputType, + OutputAliasInfo, + MutationType, + InputAliasInfo, + SubclassCreationMeta, + ViewAndMutationMeta, + SubclassMeta, + TensorAlias, + BackwardSignature, + GraphOutputName, + GraphInputName, + FQN, + GraphSignature, + AOTConfig, +) +from ._aot_autograd.subclass_utils import ( # noqa: F401 + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, + create_metadata_for_subclass, +) +from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 + run_functionalized_fw_and_collect_metadata, +) +from ._aot_autograd.input_output_analysis import ( # noqa: F401 + remove_dupe_metadata, + create_synthetic_base_metadata, + _tensors_definitely_do_not_overlap, + 
compute_overlapping_inputs, + create_graph_signature, +) +from ._aot_autograd.traced_function_transforms import ( # noqa: F401 + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + aot_dispatch_subclass, + create_functional_call, + create_joint, +) +from ._aot_autograd.runtime_wrappers import ( # noqa: F401 + create_runtime_wrapper, + functionalized_rng_runtime_epilogue, + aot_dispatch_subclass_wrapper, + aot_wrapper_dedupe, + aot_wrapper_synthetic_base, + merge_view_inputs, +) +from ._aot_autograd.dispatch_and_compile_graph import ( # noqa: F401 + aot_dispatch_base_graph, + aot_dispatch_autograd_graph, +) +from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 + aot_dispatch_base, + aot_dispatch_autograd, +) + +zip = strict_zip + +# This global counter increments every time we compile a graph with +# AOTAutograd. You can use this to correlate runtime error messages +# with compile time (e.g., if you get an error at runtime saying +# compiled graph 3 failed, you can set a breakpoint at compile time +# for this graph number to investigate further at compile time.) +# +# NB: this is different from get_aot_compilation_context, which tracks +# each underlying graph that is compiled. In contrast, AOT_COUNTER +# corresponds to top-level invocations of aot_module/aot_function; +# one counter is allocated per entire compiled block (but this block +# may involve compiling multiple subgraphs; e.g., for forwards/backwards) +AOT_COUNTER = itertools.count() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation +# that are external to the graph (they show up as side effects in some way when you run the graph). +# +# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some examples functions +# and what they're compiled graphs looks like. +# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them. +# +# Note [AOT Autograd: input data mutations] +# +# If we compile a function that mutates inputs, then those input mutations are real side effects +# that a user expects to see after running the compiled graph. +# However, the graph that we want to send to a backend needs to be *entirely* functional. +# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile +# but we update the graph to return (updated_inputs, user_outputs). +# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals. +# +# Example: original user code: +# def f(x): +# x.mul_(2) +# out = x.mul(3) +# return out +# +# After AOT Autograd compiles, we end up with a: +# (a) compiled graph +# (b) autograd.Function.forward() method, that executes the compiled graph +# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue +# +# The output of (a, b, c) are all written below. +# +# def compiled_forward_graph(x): +# x_updated = x.mul(2) +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated gets a gradient in the compiled backward +# def compiled_backward_graph(grad_x_updated, grad_out): +# grad_x = ... 
+# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.copy_(x_updated) +# return out +# +# Another important thing to note is that updated inputs (due to data mutations) *do* participate +# in the compiled backward graph! Since the compiled forward graph gets N extra outputs +# (due to updated inputs showing up as graph outputs), +# The compiled backward gets an additional N inputs. +# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input +# back to the original input. + + +# Note [AOT Autograd: input metadata mutations] +# +# For the same reason as input mutations, we also don't put input metadata mutations in the graph. +# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph +# +# Example: original user code: +# def f(x): +# x.t_() +# out = x.mul(3) +# return out +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# x_updated = x.t() +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated does *not* get a gradient in the compiled backward +# def compiled_backward_graph(grad_out): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.as_strided_(x_updated) +# return out + + +# Note [AOT Autograd: outputs aliasing inputs or intermediates!] +# +# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates! +# Why? +# (1) autograd.Function.forward() has a limitation, where views that returned in the forward cannot later be mutated. +# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph, +# in an epilogue. +# For outputs that alias inputs, we do the following: +# (a) *still* return the aliased output as a graph output +# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output. +# +# For outputs that alias *intermediates*, we do the following: +# (a) Return the output in the compiled forward, **and** return it's ._base (a graph intermediates) as an output in the forward +# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output). +# You might wonder why we return the aliased output directly in the graph (and making the graph compute it), +# only to not return it and instead generate a fresh alias off of the intermediate, +# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons: +# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call +# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance. +# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides, +# when it has a different set of strides. +# By including the view op directly in the graph, inductor takes that into account when deciding what memory format +# the graph intermediate should be. 
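+# (Illustrative sketch: if the user returns `out = intermediate.transpose(0, 1)`, the compiled
+# forward returns both `intermediate` and `out`, and the epilogue replays the transpose on
+# whatever tensor the backend produced for `intermediate`, rather than reconstructing `out`
+# with an as_strided() call.)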
+# +# Another important thing to note is how our traced backward() graph handles aliases. +# (this applies to outputs aliasing inputs, outputs aliasing intermediates, +# *and* updated inputs returned in the compiled forward due to metadata-only mutations). +# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph +# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly +# at the end of the forward. +# +# Example: original user code: +# def f(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# return out1, out2 +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# # the compiled graph also returns the intermediate +# return out1, out2, intermediate +# +# # intermediate gets a gradient in the compiled backward. +# # both output aliases (out1 and out2) do not. +# def compiled_backward_graph(grad_intermediate): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# out1, out2, intermediate = compiled_forward_graph(x) +# return out1, out2, intermediate +# +# def compiled_wrapper(x): +# out1, out2, intermediate = autograd.Function.apply(x) +# # regenerate out1 from the input +# out1_regenerated = out1._view_func(x) +# # regenerate out1 from the intermediate +# out2_regenerated = out2._view_func(intermediate) +# return out1_regenerated, out2_regenerated + + +# Note [AOT Autograd: mutations to inputs that alias other inputs] +# +# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input. +# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other. +# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias +# given the mutation that occurred. +# +# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input +# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base +# inside of the compiled function. +# +# This logic is fully encapsulated in aot_wrapper_synthetic_base() +# +# Example: original user code: +# def f(x, x_view): +# x.mul_(2) +# out = x * x_view +# return out +# f(x, x.view(-1)) +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(base) +# x = generate_x(base) +# x_view = generate_x_view(base) +# x_updated = x.mul(2) +# x_view_updated = x_updated.view(-1) +# out = x_updated * x_view_updated +# return x_updated, out +# +# # The calling convention change from (aliases) -> (base) happens +# # *outside* of the autograd.Function.forward(). +# # That means the forward() only has 1 input (base), +# # and the backward() only has 1 output (grad_base) +# def compiled_backward_graph(grad_out): +# grad_base = ... +# return grad_base +# +# def autograd.Function.forward(base): +# x_updated, out = compiled_forward_graph(base) +# return x_updated, out +# +# # The compiled wrapper is where we create synthetic bases. +# # The info on which inputs are mutated is also tracked *before* synthetic base creation. 
+# def compiled_wrapper(x, x_view): +# base = merge_view_inputs(x, x_view) +# x_updated, out = autograd.Function.apply(base) +# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view. +# x.copy_(x_updated) +# return out + + +# Note [AOT Autograd: Views to avoid tangents aliasing inputs] +# +# We view every forward output when creating out tangent tensors to handle the problematic +# case in which a subclass does extra aliasing between graph outputs/inputs in a way that +# is not visible above the sublass. +# +# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd, +# we're guaranteed that the tangent tensors that we pass +# into the joint are distinct tensors from the primals. This is because when +# decide which forward outputs to create tangents for, we only create tangents +# for forward outputs that are not aliases of inputs (See Note +# [AOT Autograd: outputs aliasing inputs or intermediates!]). +# +# However, when wrapper tensor subclasses enter the picture, it is possible +# to have an output of the forward that is a subclass that is not an +# input / alias of an input, but one of its inner tensors is an alias! +# NestedTensor is an example: Performing an out-of-place pointwise op on a +# NestedTensor constructs a fresh NestedTensor that holds onto the input's +# offsets tensor directly. +# +# Having tangent tensors that are the same as the (primal) forward inputs, +# can cause problems during tracing as make_fx() will specialize on our +# duplicate inputs: If we passed in the same tensor for primals_1 and +# tangents_1 during tracing, make_fx() will happily sub out all usages of +# tangents_1 with primals_1 in the graph, which is not what we want. +# +# To work around this, we view every forward output when creating out tangent +# tensors so that tangents can never be the same as forward inputs even if +# forward inputs alias forward outputs. + +# Note [Side-Effectful Tokens in AOTAutograd] +# +# We allow some some side-effectful operators in +# the post-AOTAutograd (functional) graph, such as prints and torchbind operations. +# To ensure that these side-effects are compatible to future graph passes that +# assume that the graph is functional, we will thread "effect tokens" to show +# data dependence between these side-effectful operators. Practically speaking, +# effect tokens are just dummy values (torch.tensor([])). The graph would look +# like the following: +# +# def gm(self, token0, reader): +# token1, frame = with_token(ordered_effect_op, (reader,), token0) +# frame = frame * 2 +# token2, frame2 = with_token(ordered_effect_op, (reader,), token1) +# frame2 = frame2 * 2 +# return token2, frame, frame2 +# +# We will pass the token as an input to the graph, thread it through +# side-effectful operators using the `with_effects` high order operator, and then +# return the updated token as an output. +# So the signature of the graph input would look something like +# (*tokens, *params_buffers, *user_inputs), and the signature of the graph +# output would look something like (*tokens, *outputs). 
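+# (At trace time the tokens are just placeholder `torch.tensor([])` values: create_functionalized_fn
+# in traced_function_transforms.py prepends one per known effect (`additional_token_inputs`), and the
+# functionalized wrapper returns the updated tokens as the leading outputs.)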
+ +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +aot_autograd_decompositions = {} + +@dynamo_timed +def create_aot_dispatcher_function( + flat_fn, flat_args: List[Any], aot_config: AOTConfig +): + """ + Traces the forward and backward graphs of the attr:`flat_fn` to generate a + joint graph. The joint graph is an Fx graph with Aten ops. Please refer to + the tracing mechanism to understand the graph capturing details. + + The joint graph is then passed through attr:`partition_fn` to isolate the + forward and backward portions, which are then respectively compiled via the + provided attr:`fw_compiler` and attr:`bw_compiler`. + + The resulting compiled forward and backward graphs are then wrapped up in a + ``torch.autograd.Function`` object. + + The calling convention here is that the first aot_config.num_params_buffers + inputs in flat_args are parameters and buffers, and the rest are inputs. + + We use this to assume that parameters/buffer's shapes don't change. + + Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) + When aot_config.is_export is True, we return an FX graph + metadata + When aot_config.is_export is False, we return an ordinary runtime function + """ + + # This is the main entry point. + # TODO: Chillee argues that dynamo itself should pass in fake tensors to + # the list of arguments when compiling; at the moment we do not do this + + if aot_config.decompositions is None: + aot_config.decompositions = {} + + + aot_config.decompositions = { + **aot_autograd_decompositions, + **aot_config.decompositions, + } + + if config.functionalize_rng_ops: + # Update the decompositions with functionalized random decompositions + aot_config.decompositions = { + **rng_decompositions, + **aot_config.decompositions, + } + + # Check flat_args to see if they're already fake. If so, use that fake + # mode instead. + + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + + python_dispatcher_mode = ( + enable_python_dispatcher() if shape_env is not None else nullcontext() + ) + + with torch.autograd.set_multithreading_enabled( + False + ), preserve_rng_state(), fake_mode, python_dispatcher_mode, PhiloxStateTracker(): + + def process_inputs(flat_args): + def convert(idx, x): + if shape_env is not None: + from torch._dynamo.source import ConstantSource + if isinstance(x, int): + # We always specialize on scalar values in export. 
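+ # (In export we return the raw int below and specialize on it; otherwise the int is
+ # wrapped in a SymInt via shape_env.create_symintnode, so the compiled graph is not
+ # specialized to this particular scalar value.)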
+ if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), + hint=x, + source=source + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs) + return x + + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + return fake_mode.from_tensor( + x, static_shapes=False, symbolic_context=symbolic_context, source=source + ) + return [convert(idx, x) for idx, x in enumerate(flat_args)] + + fake_flat_args = process_inputs(flat_args) + + needs_autograd = ( + any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)) + and torch.is_grad_enabled() + ) + + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. + with patch("torch.cuda.set_rng_state", lambda *args: None): + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + )(*fake_flat_args) + + req_subclass_dispatch = requires_subclass_dispatch(fake_flat_args, fw_metadata) + + if needs_autograd and not any(x.requires_grad for x in fw_metadata.output_info): + # We realized that none of the outputs require grad, + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=needs_autograd, + ) + + + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ +torch.compile is currently being used with tensor subclass inputs: +{','.join([str(type(x)) for x in fake_flat_args])}. 
We are attempting to a compile a graph with two graph outputs +that alias one another, which is currently unsupported in the subclass use case. If you run into this, +please file a github issue""" + + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. + if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError(f"""\ +Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""") + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. + if len([x for x in fw_metadata.input_info if x.requires_grad and x.mutates_data]) != 0: + raise RuntimeError(f"""\ +Found a graph input that requires gradients, and received a mutation. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""") + if req_subclass_dispatch: + raise RuntimeError("""\ +aot_export is not currently supported with traceable tensor subclass. +If you need this feature, please comment on """) + + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError("""\ +Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, +or otherwise set torch._functorch.config.functionalize_rng_ops = False.""") + + # crappy version of dispatcher + # TODO: Do this properly + if needs_autograd: + # For now, aot_dispatch_autograd knows to explicitly return a graph + # when run with export, and an opaque callable otherwise. + # In theory we could factor these out, but I wanted to let the dust + # settle on how functionalized rng fits into export first. + compiler_fn = aot_dispatch_autograd_graph if aot_config.is_export else aot_dispatch_autograd + else: + # aot_dispatch_base_graph contains only the "graph bits", while aot_dispatch_base + # includes some extra work around handling a runtime epilogue. 
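+ # Summary of the four dispatch cases:
+ # export + autograd -> aot_dispatch_autograd_graph (returns the joint fx graph)
+ # export + inference -> aot_dispatch_base_graph (returns the inference-only fx graph)
+ # compile + autograd -> aot_dispatch_autograd (returns a runtime callable)
+ # compile + inference -> aot_dispatch_base (returns a runtime callable)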
+ compiler_fn = aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base + + compiler_fn = partial(aot_wrapper_synthetic_base, compiler_fn=compiler_fn, needs_autograd=needs_autograd) + compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn) + # You can put more passes here + + compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata) + if aot_config.is_export: + # During export, we don't get back a callable - we get back the raw fx graph + # (either a joint or an inference-only graph) + assert isinstance(compiled_fn, torch.fx.GraphModule) + return compiled_fn, fw_metadata + + if not hasattr(compiled_fn, "_boxed_call"): + compiled_fn = make_boxed_func(compiled_fn) + + return compiled_fn + + +def aot_function( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + num_params_buffers: int = 0, + keep_inference_input_mutations: bool = False, + inference_compiler: Optional[Callable] = None, + *, + # Whether or not to trace with dynamic shapes + dynamic=False, + enable_log=True, +) -> Callable: + """ + Traces the forward and backward graph of :attr:`fn` using torch dispatch + mechanism, and then compiles the generated forward and backward graphs + through :attr:`fw_compiler` and :attr:`bw_compiler`. + + :func:`aot_function` traces the forward and backward graph ahead of time, + and generates a joint forward and backward graph. :attr:`partition_fn` is + then used to separate out forward and backward graphs. The partitioner + function can be used to perform optimizations such as recomputation. One can + set `decompositions` dictionary to decompose the operators into a sequence + of core or simpler operators supported by the backend compilers. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Callable): A Python function that takes one ore more arguments. Must + return one or more Tensors. + fw_compiler (Callable): A Python function that accepts an Fx graph with + Aten ops and input args, and returns a Callable that semantically is + equivalent to the input Fx graph. + bw_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + partition_fn (Callable): A Python function that takes a joint forward + and backward graph, and partitions it into separate forward and + backward graphs. + decompositions (Dict): A dictionary to define the decomposition of + larger Aten ops into simpler or core Aten ops. + inference_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. inference_compiler is invoked + if no autograd is needed. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + Returns: + Returns a ``Callable`` that retains the eager behavior of the original + :attr:`fn`, but with forward and backward graph compiled via + :attr:`fw_compile` and :attr:`bw_compile`. + + A simple example usage of :func:`aot_function` is as follows. 
This example + will print the forward and backward graphs of the function ``fn`` + + >>> fn = lambda x : x.sin().cos() + >>> def print_compile_fn(fx_module, args): + >>> print(fx_module) + >>> return fx_module + >>> aot_fn = aot_function(fn, print_compile_fn) + >>> x = torch.randn(4, 5, requires_grad=True) + >>> aot_fn(x) + """ + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic, + aot_autograd_arg_pos_to_source=None, + is_export=False, + no_tangents=False, + enable_log=enable_log, + ) + cached_res = None + + @wraps(fn) + def returned_function(*args, **kwargs): + nonlocal cached_res + # Now flatten the tensor args + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + + # Compile the function and save it in the cache + if cached_res is None: + flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs) + + compiled_fn = create_aot_dispatcher_function( + flat_fn, + flat_args, + aot_config, + ) + cached_res = (compiled_fn, out_spec) + + cached_fn, out_spec = cached_res + out = cached_fn(flat_args) + return out_spec.unflatten(out) + + return returned_function + + +def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: + """ + Traces the forward and backward graph of :attr:`mod` using torch dispatch + tracing mechanism. It is wrapper function, that underneath uses + :func:`aot_function` to perform tracing and compilation. + + :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs + to a new callable which is then compiled through :func:`aot_function`. + + .. warning:: + This API is experimental and likely to change. + + Args: + mod (Callable): A ``nn.Module`` module. + args : args to be passed to :func:`aot_function` + kwargs : kwargs to be passed to :func:`aot_function` + + Returns: + Returns a ``nn.Module`` that retains the eager behavior of the original + :attr:`mod`, but with forward and backward graph compiled. + + """ + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) + + def functional_call(named_params, named_buffers, *args, **kwargs): + params_and_buffers = {**named_params, **named_buffers} + return torch.func.functional_call(mod, params_and_buffers, args, kwargs) + + named_params = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + num_params_buffers = len(named_params) + len(named_buffers) + compiled_f = aot_function( + functional_call, *args, num_params_buffers=num_params_buffers, **kwargs + ) + + class AOTModule(nn.Module): + def __init__(self): + super().__init__() + self.orig_module = mod + + def forward(self, *args, **kwargs): + return compiled_f( + named_params, + named_buffers, + *args, + **kwargs, + ) + + return AOTModule() + + +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + keep_inference_input_mutations=False, + inference_compiler: Optional[Callable] = None, +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. 
For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + params = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + params_len = len(params_flat) + + functional_call = create_functional_call(mod, params_spec, params_len) + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + + seen_sources = set() + + full_args = [] + # First, the params + full_args.extend(params_flat) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.params_flat = params_flat + + aot_autograd_arg_pos_to_source = None + # Then, the params 1:1 mapped sources, if relevant. + if hasattr(mod, "_param_name_to_source"): + aot_autograd_arg_pos_to_source = [] + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + for name in params.keys(): + assert name in mod._param_name_to_source, f"{name} not found." + source = mod._param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + # Next, the input args + full_args.extend(args) + + if hasattr(mod, "graph"): + # Non dynamo entrypoints can get to here... + for i, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if hasattr(node, "_dynamo_source"): + # ... but not here! + if aot_autograd_arg_pos_to_source is None: + aot_autograd_arg_pos_to_source = [] + source = node._dynamo_source + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + if aot_autograd_arg_pos_to_source is not None: + assert len(full_args) == len(aot_autograd_arg_pos_to_source) + + dynamic_shapes = False + for x in full_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=params_len, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, + is_export=False, + no_tangents=False, + ) + + with compiled_autograd.disable(): + compiled_fn = create_aot_dispatcher_function( + functional_call, + full_args, + aot_config, + ) + + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... 
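+    # For reference, a "boxed" callable takes a single list of arguments (and may mutate or
+    # clear that list), roughly:
+    #
+    #   args = [*params_flat, *runtime_args]
+    #   outs = compiled_fn(args)    # boxed: one list in
+    #
+    # while the `forward` wrapper below re-exposes the usual positional convention,
+    # e.g. `outs = forward(x)` for a single runtime input (here `x` is just a placeholder name).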
+ def forward(*runtime_args): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) + + # Just for convenience + forward.zero_grad = mod.zero_grad + forward.named_parameters = mod.named_parameters + forward.named_buffers = mod.named_buffers + + return forward + + +def aot_export_module( + mod: nn.Module, + args, + *, + decompositions: Optional[Dict] = None, + # If true, we'll return a joint forward-backward graph, + # As well as metadata on the loss + gradients in the backward. + trace_joint: bool, + # If trace_joint is True, we expect your module to return a scalar loss. + # Your module can return multiple outputs, so you must specify which output the loss is. + output_loss_index: Optional[int] = None, + pre_dispatch: bool = False, + kwargs=None, +) -> Tuple[torch.fx.GraphModule, GraphSignature]: + """ + This function takes in a module, and returns: + (1) an FX graph that can be exported + (2) some metadata about the graph + + If `trace_joint=True` we will return a joint graph of the forward + backward. + + The traced FX graph will have the following properties compared to the original module: + (1) Inputs and outputs to the module will be pytree-flattened + (2) Parameters and buffers on the module will be lifted into graph inputs, + graph_inputs = (*parameters, *buffers, *user_inputs) + (3) The graph will be fully functionalized + (4) Any input mutations will be converted into additional outputs in the graph, + meaning whoever calls this graph is responsible for applying the mutations + back to the original inputs. + (5) If is_joint is provided the graph will return parameter gradients in addition to user outputs. + The graph output will look like: + graph_outputs = (*updated_inputs, *user_outputs, *param_gradients) + + There are also several restrictions on what modules can use this API. In particular: + (1) If trace_joint is specified, we expect the loss function to be **fused** + into the module forward. One of the outputs to the forward must be a scalar loss, + which is specified with `output_loss_index`. + All other outputs to the forward are presumed to not require gradients. + (2) This API cannot capture optimizers (although in theory we could build an API for this). + (3) Metadata mutations on params/buffers/inputs are banned. + (4) Data mutations on anything that requires gradients are banned (parameters) + (5) If an input is mutated, it is not allowed to alias any other inputs. + (6) Parameters must not be duplicated. + """ + if pre_dispatch and trace_joint: + raise RuntimeError("pre_dispatch is not supported when trace_joint is True.") + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + + params_and_buffers = { + **dict(named_parameters), + **dict(named_buffers), + } + params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) + params_and_buffers_flat = tuple(params_and_buffers_flat) + params_len = len(params_and_buffers_flat) + + kwargs = kwargs or {} + + functional_call = create_functional_call(mod, params_spec, params_len, store_orig_mod=True) + + num_fw_outs = None + + if trace_joint: + # This helper effectively just adds some extra asserts about what the backward will look like: + # Outputs must include a scalar loss, that we compute gradients w.r.t. + # We don't compute gradients w.r.t. anything else: so just in case we detach() + # and other output tensors. 
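+        # A rough usage sketch (assuming a module whose forward returns `(loss, logits)` and a
+        # single example input `x`; both names are placeholders, not part of this API):
+        #
+        #   gm, sig = aot_export_module(mod, (x,), trace_joint=True, output_loss_index=0)
+        #
+        # `gm` is then a joint forward+backward fx graph, and `sig` (a GraphSignature) records
+        # how graph inputs/outputs map to parameters, buffers, user inputs/outputs, and
+        # parameter gradients.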
+ def fn_to_trace(*args): + nonlocal num_fw_outs + out = functional_call(*args) + if output_loss_index is None: + raise RuntimeError("""\ +If trace_joint=Trueit is required that one of your forward outputs must be a scalar loss. +You must specify the which (index) output is the loss with output_loss_index.""") + if isinstance(out, (torch.Tensor)): + out = (out,) + if not isinstance(out, (tuple, list)): + raise RuntimeError(f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}") + + for i, o in enumerate(out): + # We only want to create a backward graph w.r.t. the loss that the user passed in. + # This implies that every other output should not require gradients. + # Instead of making this an error (and forcing the user to detach all other outputs + # of their forward), + # we'll automatically detach them here. + if o.requires_grad and i != output_loss_index: + raise RuntimeError(f"""\ +Found an output of the forward that requires gradients, that was not the scalar loss. +We require all outputs to the forward that are not the scalar loss to not require gradient, +because we will only compute a backward graph against the scalar loss. +You can fix this by calling .detach() on each of your forward outputs that is not the loss. +You specified that output index {output_loss_index} is the loss, but we found that +the output at index {i} requires gradients.""") + out_loss = out[output_loss_index] + num_fw_outs = len(out) + if not out_loss.requires_grad: + raise RuntimeError(f"""\ +The output at index {output_loss_index} was marked as the loss, but it does not require gradients""") + if out_loss.numel() != 1: + raise RuntimeError(f"""\ +We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}""") + return out + ctx = nullcontext + else: + # Run under no_grad, so our tracing machinery only traces an inference graph. + ctx = torch.no_grad + fn_to_trace = functional_call + + full_args = [] + # First, the params + # NB: It is REQUIRED that parameters come first, Inductor infers "fixed" + # parameters by looking at the difference in parameter count outside + # and inside AOTAutograd, and assumes the prefix of arguments are fixed + # arguments + full_args.extend(params_and_buffers_flat) + # Next, the input args + full_args.extend(args) + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + fn_to_trace, + full_args, + decompositions=decompositions, + num_params_buffers=params_len, + no_tangents=True, + pre_dispatch=pre_dispatch, + kwargs=kwargs, + ) + if trace_joint: + def flattened_joint(*args): + # The idea here is that the joint graph that AOTAutograd creates has some strict properties: + # (1) It accepts two arguments (primals, tangents), and pytree_flattens them + # (2) It returns a tuple of (fw_outs, gradients) + # This is a very useful convention for anyone who wants to partition the joint graph + # into a separate forward and backward graph. + # However, + # (1) for people exporting a single joint graph, it would be preferable not to have + # any pytrees in the graph. + # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss, + # and there are therefore no tangents that are needed to run the joint graph. + # (3) AOTAutograd creates a grad_input for every input in the forward, + # including None's for inputs that are not grad-requiring tensors. + # we don't want these in our export graph. 
+ # and there are therefore no tangents that are needed to run the joint graph. + # This function "fixes" both of the above by removing any tangent inputs, + # and removing pytrees from the original FX graph. + fake_tangents = [None for _ in range(metadata.num_outputs + metadata.num_mutated_inp_runtime_indices)] + fw_outs, gradients = fx_g(args, fake_tangents) + assert len(gradients) == len(args) + output_gradients = [] + for i, (a, grad) in enumerate(zip(args, gradients)): + if isinstance(a, torch.Tensor) and a.requires_grad: + assert grad is not None, """\ +Found a parameter that did not receive a gradient. +"This is most likely a bug, but if this needs to be supported please comment on this Github issue: +https://github.com/pytorch/pytorch/issues/101192 +""" + output_gradients.append(grad) + else: + assert grad is None + return *fw_outs, *output_gradients + fx_g = make_fx(flattened_joint)(*full_args) + + user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) + return fx_g, create_graph_signature( + fx_g, + metadata, + in_spec, + out_spec, + user_args_flat=user_args_flat, + params_and_buffers_flat=params_and_buffers_flat, + param_names=list(named_parameters.keys()), + buffer_names=list(named_buffers.keys()), + trace_joint=trace_joint, + num_user_fw_outs=num_fw_outs, + loss_index=output_loss_index, + ) + +def aot_export_joint_simple( + func: Callable, + args, + *, + trace_joint: bool, + # It looks like the main consequence of this API is that for dynamic shapes, + # it will assume that parms/buffers are static. + # With the new inferred dynamic shapes API, maybe this doesn't matter? + num_params_buffers: int = 0, + decompositions: Optional[Dict] = None, +) -> torch.fx.GraphModule: + """ + A simplified version of export. Used by higher order operators. + + This function makes a high-level "no calling convention changes" guarantee: + - If no inputs require grad (so we export an inference graph), + there are *no* calling convention change between the exported graph, and "func". + - If at least one input requires grad (so we trace out and export a joint fw-bw graph), + Then if you were partition the graph into a separate forward and backward graph, + The forward graph will have no calling convention changes compared to "func". + + The above also relies on some strong restrictions around which functions this API accepts: + (1) `args` cannot contain any pytrees (they must have been pytree_flattened already) + (2) `func` cannot mutate any inputs + (3) The outputs of `func` cannot alias any inputs. + + Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops. + """ + if trace_joint: + ctx = nullcontext + else: + # Run under no_grad, so our tracing machinery only traces an inference graph. + ctx = torch.no_grad + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + func, + args, + decompositions=decompositions, + ) + in_spec, _kw_in_spec = in_spec.children_specs + # At this point, we can just directly return the (joint or inference graph) that we traced. + # First though: a bunch of assertions to make sure that our graph doesn't require + # any calling convention changes compared to the original function. + # These restrictions are *in addition to* the general restrictions on export. + + # No input mutations + if len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) != 0: + raise RuntimeError(f"aot_export_joint_simple does not support input mutations. 
{str(metadata)}") + # No output aliasing + if len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) != 0: + raise RuntimeError(f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}") + # No pytrees + if in_spec.is_leaf(): + raise RuntimeError(f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}") + if not all(child.is_leaf() for child in in_spec.children_specs): + raise RuntimeError(f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}") + if out_spec.is_leaf(): + raise RuntimeError(f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}") + if not all(child.is_leaf() for child in out_spec.children_specs): + raise RuntimeError(f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}") + # TODO: we might have to temporarily patch config.functionalize_rng + # so that it doesn't run when we're exporting a higher order op. + + if config.debug_assert: + # Smoke test that after partitioning, we can run the forward without any calling convention changes. + fw_module, bw_module = aot_config.default_partition( # noqa: F821 + fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 + ) + # Attempt to run the fw_module with the original user inputs + fake_mode = detect_fake_mode(args) + if fake_mode is None: + fake_mode = FakeTensorMode() + with fake_mode: + fw_module(*args) + return fx_g + +# Private for now because we aren't providing a contract on what to return +# for joint graphs (we could when there's a clearer use case) +# In the future, we may need to add more export API's that provide their own strong guarantees. +# This is meant as a general helper function for handling various export-y use cases. +def _aot_export_function( + func: Callable, + args, + *, + num_params_buffers: int = 0, + decompositions: Optional[Dict] = None, + # If we're exporting a joint graph and we don't want any tangent inputs in the graph + # (because we are backpropping through a scalar 1 loss), + # we need to explicitly specify not to include tangents in the graph. + # It's not enough just to check that our tangent is a scalar, since we also + # need to know if it is a 1 (no need to make it a graph input), or something else + # (requiring it to be a graph input). + # We don't know this info at trace time though, so we need to make it an explicit config. + no_tangents: bool = False, + pre_dispatch: bool = False, + kwargs=None, +) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]: + kwargs = kwargs or {} + + flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs) + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + dynamic_shapes = False + for x in flat_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + + # The export use case doesn't care about several bits of AOTConfig + # (1) compilers (we just export the graph) + # (2) partitioners (export is only full graph, user can partition themselves) + aot_config = AOTConfig( + fw_compiler=None, + bw_compiler=None, + inference_compiler=None, + partition_fn=None, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + # For now there's no use case involving keeping input mutations in the graph + # (which we can only do in the inference case anyway). + # We can add this later if we need to. 
+ keep_inference_input_mutations=False, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=None, + is_export=True, + no_tangents=no_tangents, + pre_dispatch=pre_dispatch, + ) + + fx_g, meta = create_aot_dispatcher_function( + flat_fn, + flat_args, + aot_config, + ) + return fx_g, meta, in_spec, out_spec.spec + + +compiled_function = aot_function +compiled_module = aot_module diff --git a/MLPY/Lib/site-packages/torch/_functorch/apis.py b/MLPY/Lib/site-packages/torch/_functorch/apis.py new file mode 100644 index 0000000000000000000000000000000000000000..17358cd7bf45b02eef863dfc286036b420f57eba --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/apis.py @@ -0,0 +1,401 @@ +# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can +# trace through functorch transforms. +# Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing +# and there isn't a mechanism to selectively expose only some functions (eg. grad) from a file +# to Dynamo. +from torch._functorch.vmap import (vmap_impl, _check_randomness_arg, + Callable, in_dims_t, out_dims_t, _check_out_dims_is_int_or_int_pytree, + _process_batched_inputs, _chunked_vmap) +from torch._functorch.utils import exposed_in, argnums_t +import functools + +# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, +# sends those into func, and then unwraps the output BatchedTensors. Operations +# on BatchedTensors perform the batched operations that the user is asking for. +# +# vmap's randomness behavior differs from JAX's, which would require a PRNG key +# to be passed everywhere. + + +@exposed_in('torch.func') +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = 'error', + *, + chunk_size=None) -> Callable: + """ + vmap is the vectorizing map; ``vmap(func)`` returns a new function that + maps ``func`` over some dimension of the inputs. Semantically, vmap + pushes the map into PyTorch operations called by ``func``, effectively + vectorizing those operations. + + vmap is useful for handling batch dimensions: one can write a function + ``func`` that runs on examples and then lift it to a function that can + take batches of examples with ``vmap(func)``. vmap can also be used to + compute batched gradients when composed with autograd. + + .. note:: + :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for + convenience. Use whichever one you'd like. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. ``in_dims`` should have a + structure like the inputs. If the ``in_dim`` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If ``out_dims`` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. 
+ chunk_size (None or int): If None (default), apply a single vmap over inputs. + If not None, then compute the vmap :attr:`chunk_size` samples at a time. + Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop. + If you run into memory issues computing the vmap, please try a non-None chunk_size. + + Returns: + Returns a new "batched" function. It takes the same inputs as + ``func``, except each input has an extra dimension at the index + specified by ``in_dims``. It takes returns the same outputs as + ``func``, except each output has an extra dimension at the index + specified by ``out_dims``. + + .. warning: + :func:`vmap` works best with functional-style code. Please do not + perform any side-effects in ``func``, with the exception of + in-place PyTorch operations. Examples of side-effects include mutating + Python data structures and assigning values to variables not captured + in ``func``. + + One example of using :func:`vmap` is to compute batched dot products. PyTorch + doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully + rummaging through docs, use :func:`vmap` to construct a new function. + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) + + :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler + model authoring experience. + + >>> batch_size, feature_size = 3, 5 + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> + >>> def model(feature_vec): + >>> # Very simple linear model with activation + >>> return feature_vec.dot(weights).relu() + >>> + >>> examples = torch.randn(batch_size, feature_size) + >>> result = torch.vmap(model)(examples) + + :func:`vmap` can also help vectorize computations that were previously difficult + or impossible to batch. One example is higher-order gradient computation. + The PyTorch autograd engine computes vjps (vector-Jacobian products). + Computing a full Jacobian matrix for some function f: R^N -> R^N usually + requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`, + we can vectorize the whole computation, computing the Jacobian in a single + call to ``autograd.grad``. 
+ + >>> # Setup + >>> N = 5 + >>> f = lambda x: x ** 2 + >>> x = torch.randn(N, requires_grad=True) + >>> y = f(x) + >>> I_N = torch.eye(N) + >>> + >>> # Sequential approach + >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] + >>> for v in I_N.unbind()] + >>> jacobian = torch.stack(jacobian_rows) + >>> + >>> # vectorized gradient computation + >>> def get_vjp(v): + >>> return torch.autograd.grad(y, x, v) + >>> jacobian = torch.vmap(get_vjp)(I_N) + + :func:`vmap` can also be nested, producing an output with multiple batched dimensions + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) + >>> batched_dot(x, y) # tensor of size [2, 3] + + If the inputs are not batched along the first dimension, ``in_dims`` specifies + the dimension that each inputs are batched along as + + >>> torch.dot # [N], [N] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension + + If there are multiple inputs each of which is batched along different dimensions, + ``in_dims`` must be a tuple with the batch dimension for each input as + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None + + If the input is a Python struct, ``in_dims`` must be a tuple containing a struct + matching the shape of the input: + + >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> input = {'x': x, 'y': y} + >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) + >>> batched_dot(input) + + By default, the output is batched along the first dimension. However, it can be batched + along any dimension by using ``out_dims`` + + >>> f = lambda x: x ** 2 + >>> x = torch.randn(2, 5) + >>> batched_pow = torch.vmap(f, out_dims=1) + >>> batched_pow(x) # [5, 2] + + For any function that uses kwargs, the returned function will not batch the kwargs but will + accept kwargs + + >>> x = torch.randn([2, 5]) + >>> def fn(x, scale=4.): + >>> return x * scale + >>> + >>> batched_pow = torch.vmap(fn) + >>> assert torch.allclose(batched_pow(x), x * 4) + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] + + .. note:: + vmap does not provide general autobatching or handle variable-length + sequences out of the box. + """ + _check_randomness_arg(randomness) + if not (chunk_size is None or chunk_size > 0): + raise ValueError(f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})") + + # @functools.wraps(func) + def wrapped(*args, **kwargs): + return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs) + + return wrapped + + +def chunk_vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = 'error', + chunks=2) -> Callable: + """ + chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes + everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of + chunks at a time. For more details about vectorizing map, see :func:`vmap`. + + .. 
note:: + Please use :func:`vmap` with ``chunk_size`` argument instead of this API. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. ``in_dims`` should have a + structure like the inputs. If the ``in_dim`` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If ``out_dims`` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + chunks (int): Number of chunks to use to split the input data. Default is 2. + If equals to 1 then :func:`vmap` is called. + + Returns: + Returns a new "batched" function. It takes the same inputs as + ``func``, except each input has an extra dimension at the index + specified by ``in_dims``. It takes returns the same outputs as + ``func``, except each output has an extra dimension at the index + specified by ``out_dims``. + """ + _check_randomness_arg(randomness) + + if chunks == 1: + return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness) + + def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_): + flat_args_chunks = tuple( + t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t, ] * chunks_ + for t, in_dim in zip(flat_args_, flat_in_dims_) + ) + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + @functools.wraps(func) + def wrapped_with_chunks(*args, **kwargs): + _check_out_dims_is_int_or_int_pytree(out_dims, func) + _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func) + # Chunk flat arguments + chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks) + + # Apply vmap on chunks + return _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs) + + return wrapped_with_chunks + + +@exposed_in("torch.func") +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """``grad`` operator helps computing gradients of ``func`` with respect to the + input(s) specified by ``argnums``. This operator can be nested to + compute higher-order gradients. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified ``has_aux`` equals ``True``, + function can return a tuple of single-element Tensor and other auxiliary objects: + ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. + ``argnums`` can be single integer or tuple of integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and other + auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute gradients with respect to its inputs. 
By default, the output of + the function is the gradient tensor(s) with respect to the first argument. + If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects + is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with + respect to each ``argnums`` value is returned. + + Example of using ``grad``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> x = torch.randn([]) + >>> cos_x = grad(lambda x: torch.sin(x))(x) + >>> assert torch.allclose(cos_x, x.cos()) + >>> + >>> # Second-order gradients + >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) + >>> assert torch.allclose(neg_sin_x, -x.sin()) + + When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad, vmap + >>> batch_size, feature_size = 3, 5 + >>> + >>> def model(weights, feature_vec): + >>> # Very simple linear model with activation + >>> assert feature_vec.dim() == 1 + >>> return feature_vec.dot(weights).relu() + >>> + >>> def compute_loss(weights, example, target): + >>> y = model(weights, example) + >>> return ((y - target) ** 2).mean() # MSELoss + >>> + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> examples = torch.randn(batch_size, feature_size) + >>> targets = torch.randn(batch_size) + >>> inputs = (weights, examples, targets) + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) + + Example of using ``grad`` with ``has_aux`` and ``argnums``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> def my_loss_func(y, y_pred): + >>> loss_per_sample = (0.5 * y_pred - y) ** 2 + >>> loss = loss_per_sample.mean() + >>> return loss, (y_pred, loss_per_sample) + >>> + >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) + >>> y_true = torch.rand(4) + >>> y_preds = torch.rand(4, requires_grad=True) + >>> out = fn(y_true, y_preds) + >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``grad``. + + Case 1: Using ``torch.no_grad`` inside a function: + + >>> # xdoctest: +SKIP + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP + >>> with torch.no_grad(): + >>> grad(f)(x) + + In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``grad`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + + """ + # To avoid cyclical dependency. + import torch._functorch.eager_transforms as eager_transforms + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) + return wrapper + + +@exposed_in("torch.func") +def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """ + Returns a function to compute a tuple of the gradient and primal, or + forward, computation. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified ``has_aux`` + equals ``True``, function can return a tuple of single-element + Tensor and other auxiliary objects: ``(output, aux)``. 
+ argnums (int or Tuple[int]): Specifies arguments to compute gradients + with respect to. ``argnums`` can be single integer or tuple of + integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and + other auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute a tuple of gradients with respect to its inputs + and the forward computation. By default, the output of the function is + a tuple of the gradient tensor(s) with respect to the first argument + and the primal computation. If specified ``has_aux`` equals + ``True``, tuple of gradients and tuple of the forward computation with + output auxiliary objects is returned. If ``argnums`` is a tuple of + integers, a tuple of a tuple of the output gradients with respect to + each ``argnums`` value and the forward computation is returned. + + See :func:`grad` for examples + """ + from torch._functorch import eager_transforms + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return eager_transforms.grad_and_value_impl(func, argnums, has_aux, args, kwargs) + return wrapper diff --git a/MLPY/Lib/site-packages/torch/_functorch/autograd_function.py b/MLPY/Lib/site-packages/torch/_functorch/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa7b3bc6dcff645c1d167765b3c4b6446b8a4c7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/autograd_function.py @@ -0,0 +1,659 @@ +import torch +from torch._ops import HigherOrderOperator +from torch._C._functorch import TransformType +from torch._functorch.utils import enable_single_level_autograd_function +import torch.utils._pytree as pytree +from torch._C._functorch import ( + _wrap_for_grad, + _unwrap_for_grad, + current_level, +) +from torch._functorch.vmap import ( + wrap_batched, + unwrap_batched, + restore_vmap, + _add_batch_dim, +) +from torch._functorch.apis import vmap +from torch._functorch.vmap import _broadcast_to_and_flatten +from torch.autograd.forward_ad import _set_fwd_grad_enabled +from typing import Any, NamedTuple, Tuple + +# autograd.Function technically runs before the regular PyTorch dispatcher. +# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot) +# work with it. One day we might decide to change this, but until then, +# we need to give the illusion that autograd.Function runs before those things. +# +# We do this by using creating a custom HigherOrderOperator that only functorch +# dispatches specially. +class CustomFunctionHigherOrderOperator(HigherOrderOperator): + def __init__(self): + super().__init__('custom_function_call') + + def __call__(self, autograd_function, *args, **kwargs): + # When custom_function_call is done dispatching through functorch, + # it should just invoke the autograd.Function. This is consistent + # with the autograd.Function behavior of being invoked before the + # PyTorch dispatcher. + # + # This will lead us into trouble later down the line, but this is + # pre-existing. There is an invariant that a function traced by + # make_fx should have the same behavior when provided the same + # Tensor. However, make_fx sees autograd.Function as a composite + # (because autograd.Function happens before the Python dispatch key) + # and only traces the forward pass. 
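+        # In other words (a sketch, with `MyFn` and `x` standing in for any autograd.Function
+        # subclass and its input):
+        #
+        #   custom_function_call(MyFn, x)
+        #     -> dispatched through the functorch py_impl rules below when a transform
+        #        (grad/vmap/jvp/functionalize) is active
+        #     -> plain MyFn.apply(x) otherwise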
+ if torch._C._are_functorch_transforms_active(): + return super().__call__(autograd_function, *args, **kwargs) + return autograd_function.apply(*args, **kwargs) + + +# "custom_function_call" +# This is the mechanism for an autograd.Function that works with functorch transforms. +# It wraps an autograd.Function; interactions with functorch transforms are defined +# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch +# dispatcher. +custom_function_call = CustomFunctionHigherOrderOperator() + + +# The grad rule for custom_function_call is to construct a new _SingleLevelFunction +# (autograd.Function that only works with a single layer (level) of functorch) that: +# - unwraps the inputs +# - redispatches to custom_function_call +# - wraps the outputs +# and whose backward pass calls the original autograd.Function's backward. +# +# Why do we need to redispatch to custom_function_call? +# ----------------------------------------------------- +# This is consistent with how ATen operators work with functorch's grad transform: +# they always redispatch to the original operator. +# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x) +# +# grad1 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin (*) +# - rewrap the outputs on the return +# +# On the redispatch in (*), grad0 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin +# - rewrap the outputs on the return +# +# To "set up the autograd graph", we generate a _SingleLevelFunction +# and apply it. +@custom_function_call.py_impl(TransformType.Grad) +@custom_function_call.py_impl(TransformType.Jvp) +def custom_function_call_grad(interpreter, autograd_function, *operands): + Generated = generate_single_level_function(interpreter, autograd_function) + with enable_single_level_autograd_function(): + flat_out = Generated.apply(*operands) + return flat_out + + +def generate_single_level_function(interpreter, autograd_function): + level = interpreter.level() + + def forward(*operands): + unwrapped_operands = pytree.tree_map_only( + torch.Tensor, + lambda x: _unwrap_for_grad(x, level), + operands) + # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter + # the transform. _SingleLevelFunction will turn off both fwd and bwd + # gradient computation and we need to turn it back on here. + with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower(): + unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands) + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output): + return _wrap_for_grad(output, level) + + return wrap_outputs_maintaining_identity( + unwrapped_output, + unwrapped_operands, + operands, + wrap_fn) + + def setup_context(ctx, inputs, output): + return autograd_function.setup_context(ctx, inputs, output) + + # backward is only used if the transform is TransformType.Grad + def backward(ctx, *grads): + result = autograd_function.backward(ctx, *grads) + return result + + # jvp is only used if the transform is TransformType.Jvp + def jvp(ctx, *tangents): + result = autograd_function.jvp(ctx, *tangents) + return result + + # This is the sequence of magic words to dynamically generate a Subclass with + # a given name. A Tensor's .grad_fn field has a class name that is the original + # autograd.Function's name + Backward, so we do this to generate some + # meaningful name. 
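+    # For example, wrapping a user-defined `MySin(torch.autograd.Function)` yields a subclass
+    # named "MySinGenerated", so tensors it produces report a grad_fn of "MySinGeneratedBackward"
+    # (illustrative names only).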
+ name = f'{autograd_function.__name__}Generated' + Generated = type( + name, + (torch.autograd.function._SingleLevelFunction,), + { + 'forward': staticmethod(forward), + 'backward': staticmethod(backward), + 'jvp': staticmethod(jvp), + 'setup_context': staticmethod(setup_context), + }, + ) + return Generated + +# wrap_outputs_maintaining_identity handles outputs from the vmap, +# backward (vjp), and jvp staticmethod. The way it distinguishes +# between the vmap case and the {backward, jvp} case is if the out_dims +# are specified or not. +# +# NB: we cannot use out_dims=None as the deciding factor. This because +# out_dims=None can still happen in the vmap staticmethod! What the +# user is saying in that case is that their output does not have a +# dimension that is being vmapped over, which is valid. +NO_OUT_DIMS = "not specified" + +# NOTE [mark_dirty object identity check] +# autograd.Function's ctx.mark_dirty expect a returned input +# to have the same object identity as the input. +# Mode-only functorch will greatly simplify this logic. +def wrap_outputs_maintaining_identity( + outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS): + flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs) + flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs) + + unwrapped_input_to_orig_input = { + id(unwrapped): orig + for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs) + } + + flat_outputs, spec = pytree.tree_flatten(outputs) + result = [] + + out_dims_specified = out_dims != NO_OUT_DIMS + + if out_dims_specified: + flat_out_dims = _broadcast_to_and_flatten(out_dims, spec) + # _broadcast_to_and_flatten returns None if it is unable to broadcast. + # TODO: update following link from master to stable once that's out + if flat_out_dims is None: + raise RuntimeError( + f"The autograd.Function's vmap staticmethod returned an " + f"incompatible (output, out_dims) tuple. " + f"Expected out_dims={out_dims} " + f"to be compatible with the structure of `output`. " + f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} " + f"but output has structure {spec}. " + f"For more details, please see " + f"https://pytorch.org/docs/master/notes/extending.func.html" + ) + + for i, output in enumerate(flat_outputs): + if not isinstance(output, torch.Tensor): + result.append(output) + continue + if id(output) in unwrapped_input_to_orig_input: + result.append(unwrapped_input_to_orig_input[id(output)]) + continue + if out_dims_specified: + result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index] + else: + result.append(wrap_fn(output)) + + return pytree.tree_unflatten(result, spec) + + +# NOTE: [functorch vjp and autograd interaction] +# There's an edge case with the functorch vjp and autograd interaction +# that will eventually be fixed by mode-only functorch. +# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper, +# so we (the framework) need to do it manually. Regular PyTorch operators +# automatically do so this is consistent. 
+# +# class MyExp(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.exp() +# +# @staticmethod +# def setup_context(ctx, inputs, output): +# y = output +# ctx.save_for_backward(y) +# +# @staticmethod +# def backward(gy): +# y, = ctx.saved_tensors() +# return MyMul.apply(gy, y) +# +# x = torch.randn([], requires_grad=True) +# gy = torch.randn([], requires_grad=True) +# _, vjp_fn = vjp(MySin.apply, x) +# result = vjp_fn(gy) +# +# MyMul is an autograd.Function that is not shown here. +# It saves a `y` for backward (since gy requires grad). +# +# in vjp_fn(gy), we get: +# > MyMul.apply(gy, GradTensorWrapper(y, level=dead)) +# Because the y that is saved for backward by MyExp is a GradTensorWrapper +# but is now dead since we are outside the vjp context. +# +# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper, +# will automatically unwrap the GradTensorWrapper when applied. +# But since autograd.Function technically sits above the regular PyTorch +# dispatcher, it doesn't get this treatment. So we manually do +# the unwrapping to be consistent with regular PyTorch dispatcher operations. + + +class VmapInfo(NamedTuple): + batch_size: int + randomness: str + + +def has_overriden_vmap_rule(autograd_function): + return autograd_function.vmap is not torch.autograd.Function.vmap + + +def validate_vmap_returns_tuple_of_two_elements(result): + base_error_msg = ( + "Expected the vmap staticmethod to have two returns, an output " + "and out_dims with pytree structure compatible with the output. " + ) + if not isinstance(result, tuple): + raise RuntimeError(base_error_msg + f"Got a {type(result)} instead") + if not len(result) == 2: + raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead") + +@custom_function_call.py_impl(TransformType.Vmap) +def custom_function_call_vmap(interpreter, autograd_function, *operands): + if autograd_function.generate_vmap_rule: + if has_overriden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it has both generate_vmap_rule=True and an overriden vmap " + f"staticmethod. Please set generate_vmap_rule=False or delete " + f"the overriden vmap staticmethod to avoid ambiguity. " + f"For more details, please see " + f"https://pytorch.org/docs/master/notes/extending.func.html") + return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands) + + if not has_overriden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it does not have vmap support. Please override and implement the " + f"vmap staticmethod or set generate_vmap_rule=True. " + f"For more details, please see " + f"https://pytorch.org/docs/master/notes/extending.func.html") + + current_level = interpreter.level() + info = VmapInfo( + batch_size=interpreter.batch_size(), + randomness=interpreter.randomness(), + ) + unwrapped_operands, in_dims = unwrap_batched(operands, current_level) + + # If none of the tensors are batched at the current level, then we skip the + # current level. 
This saves the user from needing to handle this case in + # their vmap staticmethod (and is consistent with our C++ batching rule API) + if pytree.tree_all(lambda dim: dim is None, in_dims): + with interpreter.lower(): + return custom_function_call(autograd_function, *operands) + + with interpreter.lower(): + result = autograd_function.vmap(info, in_dims, *unwrapped_operands) + validate_vmap_returns_tuple_of_two_elements(result) + unwrapped_output, out_dims = result + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output, out_dim): + return output if out_dim is None else _add_batch_dim(output, out_dim, current_level) + + return wrap_outputs_maintaining_identity( + unwrapped_output, + unwrapped_operands, + operands, + wrap_fn, + out_dims=out_dims) + + +def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands): + unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level()) + vmapped_function, get_out_dims = vmapify_autograd_function( + autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()) + + with interpreter.lower(): + output = custom_function_call(vmapped_function, *unwrapped_operands) + + out_dims = get_out_dims() + return wrap_batched(output, out_dims, interpreter.level()) + + +@custom_function_call.py_impl(TransformType.Functionalize) +def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands): + raise RuntimeError("NYI: Functionalize rule for custom_function_call") + + +def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness): + # The following values are saved from the forward() and setup_context() + # and used in backward(). + # Why do we save the values out here instead of on the ctx object? + # - out_dims: There's no way to retrieve this from forward() + # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting + # vmap(vmap( but not completely sure if it is a problem. If we + # assigned those fields to the ctx object, the worry is that they + # get overwritten. + init_val = "not populated" + out_dims = init_val + input_shapes: Any = init_val + saved_tensors_bdims: Any = init_val + + def forward(*operands): + nonlocal out_dims + outputs, out_dims = restore_vmap( + autograd_function.forward, in_dims, batch_size, randomness)(*operands) + return outputs + + def setup_context(ctx, inputs, outputs): + input_shapes_ = None + saved_tensors_bdims_ = None + + def inner(inputs, outputs): + # wrapped_ctx.save_for_backward will: + # - unwrap batchedtensors into (tensor, bdim) + # - save_for_backward(*unwrapped_tensors) + # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims + wrapped_ctx = CtxCustomSave(ctx, current_level()) + autograd_function.setup_context(wrapped_ctx, inputs, outputs) + + # input_shapes are used for reductify later to reduce expanded gradients + # to the correct shape. + # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # for more details + nonlocal input_shapes_ + input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None + for inp in inputs) + nonlocal saved_tensors_bdims_ + saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims + + # See NOTE: [Why do we need to run setup_context under a vmap?] 
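+        # Here restore_vmap(fn, in_dims, batch_size, randomness)(*args) behaves roughly like
+        # vmap(fn)(*args), except that it also returns the out_dims it inferred, i.e. a
+        # (outputs, out_dims) pair, which is why forward() above unpacks two values.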
+ restore_vmap( + inner, + (in_dims, out_dims), + batch_size, + randomness, + )(inputs, outputs) + + nonlocal input_shapes + input_shapes = input_shapes_ + nonlocal saved_tensors_bdims + saved_tensors_bdims = saved_tensors_bdims_ + + def jvp(ctx, *tangents): + assert out_dims != init_val + assert saved_tensors_bdims != init_val + + def jvp_no_context(saved_tensors, tangents): + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.jvp(wrapped_ctx, *tangents) + + tangent_in_dims = get_tangents_in_dims(in_dims, tangents) + out_tangents, out_tangents_dims = restore_vmap( + jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)( + ctx.saved_tensors, tangents) + + result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size) + return result + + def backward(ctx, *grad_outputs): + assert out_dims != init_val + assert input_shapes != init_val + assert saved_tensors_bdims != init_val + + def backward_no_context(inputs): + saved_tensors, grad_outputs = inputs + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.backward(wrapped_ctx, *grad_outputs) + + grad_ins, grad_ins_dims = restore_vmap( + backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)( + (ctx.saved_tensors, grad_outputs)) + result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes) + return result + + name = f'Vmapped{autograd_function.__name__}' + Generated = type( + name, + (torch.autograd.Function,), + { + 'forward': staticmethod(forward), + 'backward': staticmethod(backward), + 'jvp': staticmethod(jvp), + 'setup_context': staticmethod(setup_context), + 'generate_vmap_rule': True + } + ) + + def get_out_dims(): + assert out_dims != init_val + return out_dims + + return Generated, get_out_dims + + +# tangents might be None, so we need to replace +# the corresponding in_dims with None. +def get_tangents_in_dims(input_dims, tangents): + flat_in_dims, spec = pytree.tree_flatten(input_dims) + flat_tangents = pytree.arg_tree_leaves(*tangents) + result = [None if tangent is None else in_dim + for in_dim, tangent in zip(flat_in_dims, flat_tangents)] + return pytree.tree_unflatten(result, spec) + + +# NOTE: [Why do we need to run setup_context under a vmap?] +# Consider the following autograd.Function +# +# class Sum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.sum() +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# ctx.x_shape = inputs[0] +# @staticmethod +# def backward(ctx, gy): +# return gy.expand(ctx.x_shape) +# +# x = torch.randn(B, 4) +# in_dims = 0 +# vmap(Sum.apply, in_dims)(x) +# +# Let’s assume for a moment that we didn’t vmap setup_context in VmappedSum: +# +# class VmappedSum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return vmap(Sum.forward, in_dims)(x) +# +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# Sum.setup_context(ctx, inputs, outputs) +# +# @staticmethod +# def backward(ctx, gy): +# def backward_no_context(gy): +# return gy.expand(ctx.x_shape) +# +# dims = (0,) +# gx = vmap(backward_no_context, dims)(gy) +# return gx +# +# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B], +# and we’re doing: +# +# def backward_no_context(gy): +# return gy.expand([B, 4]) +# +# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]") +# +# This gives us the wrong result (gx has shape [B, B, 4], but it should +# have shape [4]). 
Performing vmap over setup_context means the shape +# saved has shape [4] and leads to a correct result shape for gx. + +# Wraps a ctx object. Forwards all attr accesses to the underlying object +# except for the attrs in _pt_attrs +class WrappedCtx: + _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx') + + def __init__(self, ctx): + if not isinstance(ctx, WrappedCtx): + reserved_attrs = type(self)._pt_reserved_attrs + for name in reserved_attrs: + if not hasattr(ctx, name): + continue + raise RuntimeError( + f'PyTorch reserves the {reserved_attrs} field on ctx. ' + 'Please name your fields on ctx something else to avoid name ' + 'collision.') + self._pt_inner_ctx = ctx + + def __getattr__(self, name): + return getattr(self._pt_inner_ctx, name) + + def __setattr__(self, name, value): + if name in type(self)._pt_reserved_attrs: + self.__dict__[name] = value + return + return setattr(self._pt_inner_ctx, name, value) + +# Wraps ctx to create a new ctx object that overrides saved_tensors. +class CtxWithSavedTensors(WrappedCtx): + _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs) + + def __init__(self, ctx, new_saved_tensors): + super().__init__(ctx) + self._pt_new_saved_tensors = new_saved_tensors + + @property + def saved_tensors(self): + return self._pt_new_saved_tensors + +class CtxCustomSave(WrappedCtx): + _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level', + *WrappedCtx._pt_reserved_attrs) + + def __init__(self, ctx, current_level): + super().__init__(ctx) + self._pt_saved_tensors_bdims = () + self._pt_current_level = current_level + + def save_for_backward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_backward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + def save_for_forward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_forward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + +def reductify(grad_input, grad_input_bdim, input_bdim, batch_size, + target_shape_without_bdim_to_reduce_to=None): + if not isinstance(grad_input, tuple): + grad_input = (grad_input,) + if not isinstance(grad_input_bdim, tuple): + grad_input_bdim = (grad_input_bdim,) + if not isinstance(input_bdim, tuple): + input_bdim = (input_bdim,) + + if target_shape_without_bdim_to_reduce_to is None: + target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,) + result = tuple( + reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape) + for gi, gi_bdim, i_bdim, maybe_ishape in + zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to) + ) + return result + + +def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size, + target_shape_without_bdim_to_reduce_to=None): + if grad_input is None: + return None + + if grad_input_bdim is None and input_bdim is None: + return grad_input + + if grad_input_bdim is not None and input_bdim is None: + return grad_input.sum(grad_input_bdim) + + # NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # For reverse-mode AD, + # given a grad_input and input, it is valid for the user to return a + # grad_input that has a broadcasted shape when compared to the input. + # In this situation, autograd automatically reduces the grad_input to + # the shape of the input. + # + # However, when input_bdim is not None, we have problems. 
+ # + # [example 1] + # grad_input: Tensor[3, 4], input: Tensor[B, 4] + # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # [example 2] + # grad_input: Tensor[3, B, 4], input: Tensor[B, 4] + # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # This means that we need to also reduce the grad_input to the shape of the + # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag; + # if not-None then we do the reducing manually, otherwise, we do not do a reduction. + assert input_bdim is not None + + if grad_input_bdim is None: + grad_input = grad_input.unsqueeze(input_bdim) + new_shape = list(grad_input.shape) + new_shape[input_bdim] = batch_size + grad_input = grad_input.expand(new_shape) + grad_input_bdim = input_bdim + + if target_shape_without_bdim_to_reduce_to is not None: + return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)( + grad_input, target_shape_without_bdim_to_reduce_to) + + if input_bdim != grad_input_bdim: + grad_input = grad_input.movedim(grad_input_bdim, input_bdim) + return grad_input + + +class AutogradFunctionApply(HigherOrderOperator): + def __init__(self): + super().__init__("autograd_function_apply") + + def __call__(self, fwd, bwd, *fwd_args): + saved_values = None + + class ApplyTemplate(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + nonlocal saved_values + output, saved_values = fwd(None, *args) + return output + + @staticmethod + def backward(ctx, *grad): + return bwd(None, *grad, *saved_values) + + return ApplyTemplate.apply(*fwd_args) + + +autograd_function_apply = AutogradFunctionApply() diff --git a/MLPY/Lib/site-packages/torch/_functorch/batch_norm_replacement.py b/MLPY/Lib/site-packages/torch/_functorch/batch_norm_replacement.py new file mode 100644 index 0000000000000000000000000000000000000000..d741bc215dbb3bdffbf6dbe486b38c5494476754 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/batch_norm_replacement.py @@ -0,0 +1,24 @@ +import torch.nn as nn +from torch._functorch.utils import exposed_in + + +def batch_norm_without_running_stats(module: nn.Module): + if isinstance(module, nn.modules.batchnorm._BatchNorm) and module.track_running_stats: + module.running_mean = None + module.running_var = None + module.num_batches_tracked = None + module.track_running_stats = False + + +@exposed_in("torch.func") +def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module: + """ + In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and + setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root` + """ + # base case + batch_norm_without_running_stats(root) + + for obj in root.modules(): + batch_norm_without_running_stats(obj) + return root diff --git a/MLPY/Lib/site-packages/torch/_functorch/benchmark_utils.py b/MLPY/Lib/site-packages/torch/_functorch/benchmark_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cac2701b70591b818743bcf1db72848274637c85 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/benchmark_utils.py @@ -0,0 +1,195 @@ +# mypy: ignore-errors + +import contextlib +import time +import os +import json + +import torch +from torch.profiler import profile, ProfilerActivity + + +def synchronize(): + pass + + +def dump_chrome_trace(f, input, trace_filename, optimize_ctx, activities, num_runs=1, + devices=None, kwargs_for_f=None, 
kwargs_for_profiler=None): + """ + Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx] + [num_runs] times to [trace_filename]. + + [activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA. + Return total runtime without the profiler + + Outputs to trace_filename + """ + + if devices is None: + devices = ["cuda"] + + global synchronize + if devices != ["cpu"] and torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + + if kwargs_for_f is None: + kwargs_for_f = {} + if kwargs_for_profiler is None: + kwargs_for_profiler = {} + + with optimize_ctx: + torch.manual_seed(1337) + for _ in range(5): # warmup runs + f(input, **kwargs_for_f) + synchronize() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + t1 = time.perf_counter() + timing = t1 - t0 + + with profile(activities=activities, **kwargs_for_profiler) as prof: + with optimize_ctx: + synchronize() + torch.manual_seed(1337) + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + prof.export_chrome_trace(trace_filename) + + return timing + + +def get_chrome_trace_events(filename): + f = open(filename) + data = json.load(f) + events = data["traceEvents"] + return events + + +def is_gpu_compute_event(event): + global gpu_pids + return "pid" in event and event["pid"] in gpu_pids and "ph" in event and event["ph"] == "X" + + +def get_sorted_gpu_events(events): + sorted_gpu_events = [] + for event in events: + if not is_gpu_compute_event(event): + continue + sorted_gpu_events.append(event) + return sorted(sorted_gpu_events, key=lambda x: x["ts"]) + + +def get_duration(sorted_gpu_events): + if len(sorted_gpu_events) == 0: + return 0 + event = sorted_gpu_events[0] + current_end_time = event["ts"] + event["dur"] + total_duration = event["dur"] + for event in sorted_gpu_events[1:]: + start_time = max(event["ts"], current_end_time) + end_time = event["ts"] + event["dur"] + total_duration = total_duration + max(end_time - start_time, 0) + current_end_time = max(current_end_time, end_time) + return total_duration + + +def get_sorted_gpu_mm_conv_events(events): + def is_mm_conv_event(event): + return "name" in event and ("gemm" in event["name"] or "conv" in event["name"] + or "cutlass" in event["name"] or "wgrad" in event["name"]) + gpu_events = get_sorted_gpu_events(events) + sorted_events = [] + for event in gpu_events: + if not is_mm_conv_event(event): + continue + sorted_events.append(event) + return sorted_events + + +gpu_pids = [] + + +def compute_utilization(filename: str, total_length: float): + """ + Process the chrome traces outputs by the pytorch profiler to compute GPU Utilization + and percent of times spent on matmul and convolution + + Args: + filename(str): Name of chrome traces file produced by pytorch profiler + + total_length(float): total length of the process without profiler in second + + Return: + tuple: (GPU Utilization, percent of time spent on matmul and convolution) + """ + events = get_chrome_trace_events(filename) + + # get pids of GPU events + global gpu_pids + gpu_pids = [] + for event in events: + if "name" not in event: + continue + if event["name"] == 'process_labels' and "GPU" in event["args"]["labels"]: + gpu_pids.append(event["pid"]) + + total_length = total_length * 1e6 + sorted_gpu_events = get_sorted_gpu_events(events) + utilization = get_duration(sorted_gpu_events) / total_length + + sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events) + 
mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length + + return utilization, mm_conv_utilization + + +def benchmark_utilization(f, input, trace_folder, optimize_ctx=None, trace_file_name="tmp_chrome_trace", num_runs=1): + """ + Benchmark the GPU Utilization and percent of time spent on matmul and convolution operations of + running f(input, **kwargs_for_f) with [optimize_ctx] [num_runs] times. + It will produce a chrome trace file in trace_folder/trace_file_name.json + + Example: + + ``` + def f(a): + return a.sum() + a = torch.rand(2**20, device="cuda") + utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") + ``` + + Args: + f: function to benchmark + + input: input to :attr:`f` + + trace_folder: name of the folder to store the chrome trace + + optimize_ctx: the context in which f will run + + trace_file_name: name of the dumped chrome trace file, default to "tmp_chrome_trace" + + num_runs: number of times to run f, excluding the warm-up runs, default to 1. + + Return: + tuple: (GPU Utilization, percent of time spent on matmul and convolution) + + """ + isExist = os.path.exists(trace_folder) + if not isExist: + os.makedirs(trace_folder) + print("create folder " + trace_folder) + + if optimize_ctx is None: + optimize_ctx = contextlib.nullcontext() + + chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json") + total_length = dump_chrome_trace(f, input, chrome_trace_file_name, optimize_ctx, + [ProfilerActivity.CUDA], num_runs=num_runs, devices="cuda") + utilization, mm_conv_utilization = compute_utilization(chrome_trace_file_name, total_length) + + return utilization, mm_conv_utilization diff --git a/MLPY/Lib/site-packages/torch/_functorch/compile_utils.py b/MLPY/Lib/site-packages/torch/_functorch/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e62358d67b8c01115c3163887216005a92e7f4d8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/compile_utils.py @@ -0,0 +1,97 @@ +# mypy: ignore-errors + + +import torch +import torch.fx as fx +from torch.utils._pytree import tree_flatten +from torch.utils import _pytree as pytree + +aten = torch.ops.aten + + +def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + +rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma, + aten.bernoulli, aten.multinomial, aten.native_dropout, + aten.normal, aten.poisson, aten.binomial, aten.rrelu, + aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm] + + +# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph +def fx_graph_cse(fx_g: torch.fx.graph.Graph): + new_graph = fx.Graph() + env = {} # map from node in the old graph to node in the new graph + hash_env = {} # map from hash to a node in the new graph + token_map = {} # map from hash to token + for n in fx_g.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in rand_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = 
tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, torch.fx.node.Node) and v in env: + arg_list[i] = env[v] + if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)): + arg_list[i] = v.node + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + # We need to add type into hash to avoid situations like: + # hash((primals_2, 1.0)) == hash((primals_2, 1)) + hash_arg = hash((tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + return new_graph + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. + + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + +def get_placeholders(graph): + return list(filter(lambda x: x.op == 'placeholder', graph.nodes)) + +def get_outputs(graph): + for node in graph.nodes: + if node.op == 'output': + return pytree.tree_leaves(node.args[0]) + raise AssertionError("No output node found") diff --git a/MLPY/Lib/site-packages/torch/_functorch/compilers.py b/MLPY/Lib/site-packages/torch/_functorch/compilers.py new file mode 100644 index 0000000000000000000000000000000000000000..409b1bbb8f0979c6a7e3d1756d5bb31a1cbdaa79 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/compilers.py @@ -0,0 +1,441 @@ +# mypy: ignore-errors + +import copy +import logging +import os +import pickle +import random +from contextlib import contextmanager +from functools import partial +from typing import Callable, Union +import sympy + +import torch +from torch import SymInt +import torch.fx as fx +import torch.nn as nn +from torch._decomp import get_decompositions +from torch.fx.experimental.symbolic_shapes import bind_symbols + +from .aot_autograd import aot_function, aot_module, make_boxed_compiler +from .compile_utils import strip_overloads +from .partitioners import ( + default_partition, + draw_graph, + min_cut_rematerialization_partition, +) +import torch.utils._pytree as pytree + + +log = logging.getLogger(__name__) + + +# These canonicalizations are needed here (and not decompositions), as the ops +# we're trying to canonicalize to CompositeImplicitAutograd. 
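+#
+# As a rough illustration (a sketch, nothing in this module executes it): a
+# traced node recorded as
+#
+#     call_function  target=torch.ops.aten._to_copy  kwargs={'dtype': torch.float16}
+#
+# has its target rewritten to torch.ops.aten.to by _canonicalize below.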
+def _canonicalize(fx_g): + for node in fx_g.graph.nodes: + if node.target == torch.ops.aten._to_copy: + node.target = torch.ops.aten.to + fx_g.recompile() + return fx_g + + +@contextmanager +def _disable_jit_autocast(): + old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) + try: + yield + finally: + torch._C._jit_set_autocast_mode(old_jit_autocast_flag) + + +@make_boxed_compiler +def ts_compile(fx_g: fx.GraphModule, inps) -> Callable: + """ + Compiles the :attr:`fx_g` with Torchscript compiler. + + .. warning:: + This API is experimental and likely to change. + + Args: + fx_g(fx.GraphModule): The input Fx graph module to be compiled. + + Returns: + Torch scripted model. + """ + + with _disable_jit_autocast(): + strip_overloads(fx_g) + + for node in fx_g.graph.nodes: + if ( + node.target == torch.ops.aten._to_copy + and len(node.args) == 1 + and len(node.kwargs) == 1 + and "dtype" in node.kwargs + ): + node.target = torch.ops.aten.to + + for node in fx_g.graph.nodes: + new_kwargs = {} + for k, v in node.kwargs.items(): + if isinstance(v, torch.device): + v = v.type + new_kwargs[k] = v + node.kwargs = new_kwargs + + fx_g.graph.lint() + + fx_g.recompile() + + f = torch.jit.script(fx_g) + + torch._C._jit_pass_remove_mutation(f.graph) + + f = torch.jit.freeze(f.eval()) + f = torch.jit.optimize_for_inference(f) + if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps): + f(*inps) + return f + + +def _draw_graph_compile(fx_g, _, name, clear_meta=True): + print(fx_g.code) + draw_graph(fx_g, name, clear_meta=clear_meta) + return fx_g + + +def draw_graph_compile(name): + return make_boxed_compiler( + partial(_draw_graph_compile, name=name) + ) + + +@make_boxed_compiler +def nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler + and can be used to check accuracy. + + .. warning:: + This API is experimental and likely to change. + + """ + return fx_g + +class DebugInterpreter(fx.Interpreter): + def run(self, *args): + self.symbol_mapping = bind_symbols(self.module, *args) + super().run(*args) + + def run_node(self, n): + + def subst_symint(ni): + if not isinstance(ni, SymInt): + return ni + r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping)) + assert r.is_number, r + return int(r) + + def subst_symint_tuple(nis): + return tuple(subst_symint(ni) for ni in nis) + + def check_significant_strides(a, b): + if subst_symint(a.numel()) > 0: + for idx in range(a.ndim): + if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1: + return False + return True + + def check(nv, rv, desc): + assert callable(desc) + assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" + assert subst_symint_tuple(nv.size()) == rv.size(), \ + f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + same_strides = check_significant_strides(nv, rv) + assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" + + r = super().run_node(n) + if 'val' in n.meta: + n_vals, n_spec = pytree.tree_flatten(n.meta['val']) + r_vals, r_spec = pytree.tree_flatten(r) + # TODO: There is some sort of problem where we record that an + # operator returned a tuple/list, and then later it turns out the + # real version of the operator returned a list/tuple. Need to + # figure out what's actually going on here, the error itself is + # harmless enough as we only getitem out the outputs. 
+ # assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}") + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) + """ + return DebugInterpreter(fx_g).run + +@make_boxed_compiler +def simple_ts_compile(fx_g, _): + strip_overloads(fx_g) + f = torch.jit.script(fx_g) + f = torch.jit.freeze(f.eval()) + return f + + +def nnc_jit(f): + return aot_function(f, simple_ts_compile) + + +aten = torch.ops.aten +default_decompositions = { + aten.detach, + aten.gelu_backward, + aten.leaky_relu_backward, + aten.sigmoid_backward, + aten.threshold_backward, + aten.hardtanh_backward, + aten.hardsigmoid_backward, + aten.hardswish_backward, + aten.tanh_backward, + aten.silu_backward, + aten.elu_backward, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.masked_fill.Scalar, + aten.masked_fill.Tensor, + aten.elu, + aten.leaky_relu, + aten.hardtanh, + aten.hardswish, + aten.hardsigmoid, + aten.conj_physical, + aten.is_same_size, +} + +default_decompositions = get_decompositions(default_decompositions) + + +@make_boxed_compiler +def print_compile(fx_g, _): + print(fx_g.code) + return fx_g + + +def memory_efficient_fusion( + fn: Union[Callable, nn.Module], + **kwargs, +): + """ + Wrapper function over :func:`aot_function` and :func:`aot_module` to perform + memory efficient fusion. It uses the + :func:`min_cut_rematerialization_partition` partitioner to perform efficient + recomputation. It uses NVFuser to compile the generated forward and backward + graphs. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` + that takes one ore more arguments. Must return one or more Tensors. + **kwargs: Any other overrides you want to make to the settings + + Returns: + Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior + of the original :attr:`fn`, but whose forward and backward graphs have + gone through recomputation optimizations, and the graphs have been + compiled with nvfuser. 
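+
+    Example (a minimal sketch, not a tested doctest; assumes a CUDA device and
+    that the TorchScript nvfuser backend is available):
+
+    >>> # xdoctest: +SKIP
+    >>> def f(x):
+    >>>     return torch.nn.functional.gelu(x).cos().sum()
+    >>>
+    >>> fused_f = memory_efficient_fusion(f)
+    >>> x = torch.randn(2 ** 16, device="cuda", requires_grad=True)
+    >>> fused_f(x).backward()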
+ + """ + config = { + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + "decompositions": default_decompositions, + } + config.update(kwargs) + if isinstance(fn, torch.nn.Module): + return aot_module(fn, **config) + else: + return aot_function(fn, **config) + + +def debug_compile(fx_g, inps): + fx_g.to_folder("foo") + print( + f""" +############################################################## +# To minimize FX graph, copy and paste the below and run it # +############################################################## + +import torch +import torch.fx as fx +from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess + +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] +from foo import FxModule +mod = FxModule().cuda() + +with torch.jit.fuser("fuser2"): + # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess + minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess) +""" + ) + from foo import FxModule + + FxModule().cuda()(*inps) + + return ts_compile(fx_g, inps) + + +graph_index = 0 + + +def get_inputs(input_data_path): + """ + Return a random input for the given inputs meta generated from _save_fx_default. + """ + inputs = [] + with (open(input_data_path, "rb")) as f: + inputs_meta = pickle.load(f) + inputs = [] + for meta in inputs_meta: + if len(meta) == 1: + type = meta + input = type(random.rand()) + else: + type, shape, stride, dtype, device = meta + if dtype in { + torch.int, + torch.int32, + torch.int64, + torch.bool, + torch.int, + torch.uint8, + int, + float, + }: + input = torch.randint(0, 1, shape, dtype=dtype, device=device) + else: + input = torch.rand(shape, dtype=dtype, device=device) + inputs.append(input) + return inputs + + +def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs): + """ + The forward, backward, and joint computation graph will be stored in + {folder_name}/{current_name}/{current_name}_forward_{graph_index}, + {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and + {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively. + The input shape of the graphs will be stored in the .input files. + These files can be loaded with pickle, + and is a list of format (type, shape, stride, dtype, device). + In the case of type = int or float, it is just (type,). + For joint graph input, it is a nested list [[],[]] + where the two inner lists have the same format. + If dump_example_input is True, example_inputs will be stored in .pt file. 
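+    The .input files can also be fed to get_inputs() above to synthesize
+    random inputs matching the recorded shapes, dtypes and devices.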
+ Since each function might produce multiple graphs, + the graph_index is used to distinguish difference graphs + """ + from functorch.compile import aot_module_simplified + + def get_input_meta(args): + input_meta = [] + if len(args) > 0 and isinstance(args[0], tuple): # joint input + input_meta += get_input_meta(args[0]) + input_meta += get_input_meta(args[1]) + return input_meta + for arg in args: + if type(arg) == int or type(arg) == float: + input_meta.append((type(arg),)) + else: + input_meta.append( + (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device) + ) + return input_meta + + def graph_saver_helper(gm_to_save, args, type_name): + global graph_index + if len(gm_to_save.graph.nodes) == 0: + log.log( + logging.WARNING, + "No nodes in graph {%s}_{%s}_{%s}.", + current_name, + type_name, + graph_index, + ) + return + + gm = copy.deepcopy(gm_to_save) + gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + gm.recompile() + + input_meta = get_input_meta(args) + + os.makedirs(f"{folder_name}/{current_name}", exist_ok=True) + gm.to_folder( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" + ) + pickle.dump( + input_meta, + open( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950 + "wb", + ), + ) # noqa: E501 + if dump_example_input: + torch.save( + args, + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 + ) # noqa: E501 + + def graph_saver_forward(gm, fw_args): + graph_saver_helper(gm, fw_args, "forward") + return gm + + def graph_saver_backward(gm, bw_args): + graph_saver_helper(gm, bw_args, "backward") + global graph_index + graph_index += 1 + return gm + + def graph_saver_joint(gm, joint_args): + graph_saver_helper(gm, joint_args, "joint") + return default_partition(gm, joint_args) + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=graph_saver_forward, + bw_compiler=graph_saver_backward, + partition_fn=graph_saver_joint, + decompositions=default_decompositions, + ) + + +# WARNING: This isn't tested anywhere!! +def graph_dumper_aot(current_name, folder_name, dump_example_input=False): + """ + Dump the forward, backward, and joint computation graph. + Example Usage: + save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False) + optimize_ctx = torchdynamo.optimize( + save_fx_func + ) + with torch.enable_grad(): + with optimize_ctx: + result = forward_and_backward_pass(model, example_inputs) + """ + global graph_index + graph_index = 0 + return partial(_save_fx_default, current_name, folder_name, dump_example_input) diff --git a/MLPY/Lib/site-packages/torch/_functorch/config.py b/MLPY/Lib/site-packages/torch/_functorch/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0e24bdbc7f253a5bc8cb02296098d046441501f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/config.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Global flags for aot autograd +""" +import os +import sys +from typing import TYPE_CHECKING + +# Converts torch rng ops to their functional philox rng equivalents. Note that +# we functionalize only CUDA rng ops today. 
+functionalize_rng_ops = False + +# can be useful for debugging if we are incorrectly creating meta fake tensors +fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True) + +# Enables optional asserts in hotpath code to check for errors. If +# you are seeing weird accuracy problems, try turning this on. +# This is currently off by default as it will harm tracing time, +# but it is on by default for aot_eager. +debug_assert = False + +debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False) + +static_weight_shapes = True + +# Applies CSE to the graph before partitioning +cse = True + +# Restricts the amount of computation AOTAutograd can do. +max_dist_from_bw = 3 + +# Enable aggressive_recomputation in the min-cut algorithm in partitioners to reduce +# memory usage with some penalty of performance. It allows more ops to be considered +# as recomputable except random ops and compute-intensive ops. +aggressive_recomputation = False + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + +# adds patch, save_config, invalid config checks, etc +install_config_module(sys.modules[__name__]) diff --git a/MLPY/Lib/site-packages/torch/_functorch/deprecated.py b/MLPY/Lib/site-packages/torch/_functorch/deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..bf90a602b74d700be2723ca9d731c6ac49d2eebb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/deprecated.py @@ -0,0 +1,125 @@ +import torch._functorch.apis as apis +import torch._functorch.eager_transforms as _impl +import torch._functorch.make_functional as _nn_impl +from torch._functorch.vmap import in_dims_t, out_dims_t +from torch._functorch.eager_transforms import argnums_t +import torch.nn as nn +import textwrap +from typing import Any, Callable, Optional, Tuple, Union +import warnings + +""" +The APIs in this file are exposed as `functorch.*`. They are thin wrappers +around the torch.func.* APIs that have deprecation warnings -- we're trying +to move people to the torch.func.* equivalents. + +NB: We don't use *args, **kwargs in the signatures because that changes the +documentation. +""" + +def get_warning(api, new_api=None, replace_newlines=False): + if new_api is None: + new_api = f'torch.func.{api}' + warning = ( + f"We've integrated functorch into PyTorch. As the final step of the \n" + f"integration, functorch.{api} is deprecated as of PyTorch \n" + f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" + f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n" + f"and/or the torch.func migration guide for more details \n" + f"https://pytorch.org/docs/master/func.migrating.html" + ) + if replace_newlines: + warning = warning.replace("\n", "") + return warning + + +def warn_deprecated(api, new_api=None): + warning = get_warning(api, new_api, replace_newlines=True) + warnings.warn(warning, stacklevel=2) + + +def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): + api_name = functorch_api.__name__ + if torch_func_api is None: + torch_func_api = getattr(_impl, api_name) + # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO + if torch_func_api.__doc__ is None: + return + + warning = get_warning(api_name, new_api_name) + warning_note = "\n.. 
warning::\n\n" + textwrap.indent(warning, " ") + warning_note = textwrap.indent(warning_note, " ") + functorch_api.__doc__ = torch_func_api.__doc__ + warning_note + +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = 'error', + *, + chunk_size=None) -> Callable: + warn_deprecated('vmap', 'torch.vmap') + return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size) + +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + warn_deprecated('grad') + return apis.grad(func, argnums, has_aux) + +def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + warn_deprecated('grad_and_value') + return apis.grad_and_value(func, argnums, has_aux) + +def vjp(func: Callable, *primals, has_aux: bool = False): + warn_deprecated('vjp') + return _impl.vjp(func, *primals, has_aux=has_aux) + +def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False): + warn_deprecated('jvp') + return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux) + +def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False): + warn_deprecated('jacrev') + return _impl.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size, + _preallocate_and_copy=_preallocate_and_copy) + +def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"): + warn_deprecated('jacfwd') + return _impl.jacfwd(func, argnums, has_aux, randomness=randomness) + +def hessian(func, argnums=0): + warn_deprecated('hessian') + return _impl.hessian(func, argnums=argnums) + +def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable: + warn_deprecated('functionalize') + return _impl.functionalize(func, remove=remove) + +def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): + warn_deprecated('make_functional', 'torch.func.functional_call') + return _nn_impl.make_functional(model, disable_autograd_tracking) + +def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False): + warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call') + return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking) + +def combine_state_for_ensemble(models): + warn_deprecated('combine_state_for_ensemble', 'torch.func.stack_module_state') + return _nn_impl.combine_state_for_ensemble(models) + +setup_docs(vmap, apis.vmap, 'torch.vmap') +setup_docs(grad, apis.grad) +setup_docs(grad_and_value, apis.grad_and_value) +setup_docs(vjp) +setup_docs(jvp) +setup_docs(jacrev) +setup_docs(jacfwd) +setup_docs(hessian) +setup_docs(functionalize) +setup_docs(make_functional, _nn_impl.make_functional, + 'torch.func.functional_call') +setup_docs(make_functional_with_buffers, _nn_impl.make_functional, + 'torch.func.functional_call') +setup_docs(combine_state_for_ensemble, _nn_impl.combine_state_for_ensemble, + 'torch.func.stack_module_state') diff --git a/MLPY/Lib/site-packages/torch/_functorch/eager_transforms.py b/MLPY/Lib/site-packages/torch/_functorch/eager_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a3767e0751d9360f7e5cefacf790cbc715f0e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/eager_transforms.py @@ -0,0 +1,1640 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Union, Tuple, List, Any, Optional +import torch +from functools import partial, wraps +import contextlib +from torch.utils._pytree import ( + tree_flatten, + tree_unflatten, + tree_map, + tree_map_only, + tree_map_, + treespec_pprint, +) +from torch.utils import _pytree as pytree +from torch.fx.experimental import const_fold +from torch.fx.experimental.proxy_tensor import make_fx +import torch.autograd.forward_ad as fwAD +from torch._subclasses.functional_tensor import FunctionalTensor + +from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes +from .apis import vmap + +from torch._C._functorch import ( + _wrap_for_grad, + _unwrap_for_grad, + _grad_increment_nesting, + _grad_decrement_nesting, + _jvp_increment_nesting, + _jvp_decrement_nesting, + _wrap_functional_tensor, + _unwrap_functional_tensor, + _func_decrement_nesting, + _func_increment_nesting, + _assert_wrapped_functional, + _propagate_functional_input_mutation, + set_inplace_requires_grad_allowed, + get_inplace_requires_grad_allowed, +) +from torch._functorch.utils import exposed_in, argnums_t + + +def lazy_dynamo_disable(func): + import torch._dynamo + return torch._dynamo.disable(func) + +@contextlib.contextmanager +def enable_inplace_requires_grad(enabled): + prev_state = get_inplace_requires_grad_allowed() + set_inplace_requires_grad_allowed(enabled) + try: + yield + finally: + set_inplace_requires_grad_allowed(prev_state) + + +def _vjp_treespec_compare(primals_out, cotangents): + # Revert this once #116264 gets fixed + _, primals_out_spec = tree_flatten(primals_out) + _, cotangents_spec = tree_flatten(cotangents) + # Dynamo fails to trace operator.ne below. To bypass this limitation, this + # function is not inlined. + if primals_out_spec != cotangents_spec: + raise RuntimeError( + f'Expected pytree structure of cotangents to be the same ' + f'as pytree structure of outputs to the function. 
' + f'cotangents: {treespec_pprint(cotangents_spec)}, ' + f'primal output: {treespec_pprint(primals_out_spec)}') + + +def _set_tensor_requires_grad(x): + # avoid graph-break on x.requires_grad_() + # https://github.com/pytorch/pytorch/pull/110053 + return x.requires_grad_() + +def _create_differentiable(inps, level=None): + def create_differentiable(x): + if isinstance(x, torch.Tensor): + with enable_inplace_requires_grad(True): + return _set_tensor_requires_grad(x) + raise ValueError(f'Thing passed to transform API must be Tensor, ' + f'got {type(x)}') + return tree_map(create_differentiable, inps) + + +def _undo_create_differentiable(inps, level=None): + def unwrap_tensors(x): + if isinstance(x, torch.Tensor): + return _unwrap_for_grad(x, level) + # TODO: Remove the following hack for namedtuples + if isinstance(x, tuple): + return tree_map(unwrap_tensors, tuple(x)) + + raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}") + + return tree_map(unwrap_tensors, inps) + + +def _is_differentiable(maybe_tensor): + if not isinstance(maybe_tensor, torch.Tensor): + return False + return maybe_tensor.requires_grad + + +def _any_differentiable(tensor_or_tuple_of_tensors): + flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) + return any(tuple(map(_is_differentiable, flat_args))) + + +def _wrap_tensor_for_grad(maybe_tensor, level): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + return _wrap_for_grad(maybe_tensor, level) + + +def _wrap_all_tensors(tensor_pytree, level): + return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree) + + +def _as_tuple(val): + if isinstance(val, tuple): + return val + return (val,) + +# Version of autograd.grad that handles outputs that don't depend on inputs + + +def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True): + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad) + if len(result) == 0: + diff_outputs, grad_outputs = (), () + else: + diff_outputs, grad_outputs = zip(*result) + if len(diff_outputs) == 0: + return tuple(torch.zeros_like(inp) for inp in inputs) + grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True) + grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi + for gi, inp in zip(grad_inputs, inputs)) + return grad_inputs + +# NOTE [grad and vjp interaction with no_grad] +# +# def f(x): +# with torch.no_grad(): +# c = x ** 2 +# return x - c +# +# The thing to consider is if enable_grad is on/off before grad gets called. +# +# Case 1: enable_grad is on. +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad. +# +# Case 2: enable_grad is off +# with torch.no_grad(): +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad, but not the +# outer one. This is because `grad` is a "function transform": its result +# should not depend on the result of a context manager outside of `f`. +# +# This gives us the following desired behavior: +# - (nested) grad transforms must obey torch.no_grad inside them +# - (nested) grad transforms should not obey torch.no_grad outside them +# +# To achieve this behavior, upon entering grad/vjp: +# - we save the current ("previous") is_grad_enabled (*) +# - we unconditionally enable grad. 
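+#
+# A concrete illustration of the two cases above (values worked out by hand,
+# as a sketch), using the f defined at the top of this note:
+#
+#   x = torch.tensor(3.)
+#   grad(f)(x)                      # Case 1: tensor(1.), since c is constant
+#                                   # (if the inner no_grad were ignored, this
+#                                   # would be 1 - 2 * x = tensor(-5.))
+#   with torch.no_grad():
+#       grad(f)(x)                  # Case 2: also tensor(1.); the outer
+#                                   # no_grad does not disable the transform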
+# +# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer +# off the stack: +# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad +# active, all subsequent grad transforms must obey it). +# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False, +# then we temporarily restore the previous `is_grad_enabled`. This is +# because we're crossing the boundary from a `grad` outside the +# no_grad to a `grad` inside the no_grad. +# +# NB: vjp has some interesting behavior because the vjp's callable can be called +# under a different grad_mode than the forward computation... +# +# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but +# it respects c10::AutoFwGradMode. We've implemented the same logic for +# our jvp transform (it will have special handling if FwGradMode is disabled). + + +# How do we increment and decrement the nesting? I don't think we can. +@exposed_in("torch.func") +def vjp(func: Callable, *primals, has_aux: bool = False): + """ + Standing for the vector-Jacobian product, returns a tuple containing the + results of ``func`` applied to ``primals`` and a function that, when + given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with + respect to ``primals`` times ``cotangents``. + + Args: + func (Callable): A Python function that takes one or more arguments. Must + return one or more Tensors. + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, vjp_fn)`` tuple containing the output of ``func`` + applied to ``primals`` and a function that computes the vjp of + ``func`` with respect to all ``primals`` using the cotangents passed + to the returned function. If ``has_aux is True``, then instead returns a + ``(output, vjp_fn, aux)`` tuple. + The returned ``vjp_fn`` function will return a tuple of each VJP. 
+ + When used in simple cases, :func:`vjp` behaves the same as :func:`grad` + + >>> x = torch.randn([5]) + >>> f = lambda x: x.sin().sum() + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> grad = vjpfunc(torch.tensor(1.))[0] + >>> assert torch.allclose(grad, torch.func.grad(f)(x)) + + However, :func:`vjp` can support functions with multiple outputs by + passing in the cotangents for each of the outputs + + >>> x = torch.randn([5]) + >>> f = lambda x: (x.sin(), x.cos()) + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + :func:`vjp` can even support outputs being Python structs + + >>> x = torch.randn([5]) + >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} + >>> vjps = vjpfunc(cotangents) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + The function returned by :func:`vjp` will compute the partials with + respect to each of the ``primals`` + + >>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) + >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) + >>> cotangents = torch.randn([5, 5]) + >>> vjps = vjpfunc(cotangents) + >>> assert len(vjps) == 2 + >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) + >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents)) + + ``primals`` are the positional arguments for ``f``. All kwargs use their + default value + + >>> x = torch.randn([5]) + >>> def f(x, scale=4.): + >>> return x * scale + >>> + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc(torch.ones_like(x)) + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``vjp``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP(failing) + >>> with torch.no_grad(): + >>> vjp(f)(x) + + In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``vjp`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + return _vjp_with_argnums(func, *primals, has_aux=has_aux) + + +@contextlib.contextmanager +def grad_increment_nesting(): + try: + grad_level = _grad_increment_nesting() + yield grad_level + finally: + _grad_decrement_nesting() + + +@doesnt_support_saved_tensors_hooks +def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False): + # This is the same function as vjp but also accepts an argnums argument + # All args are the same as vjp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for vjp). Default: None + # + # WARN: Users should NOT call this function directly and should just be calling vjp. + # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers. 
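+    #
+    # For example (a sketch of the intended contract): with primals = (x, y)
+    # and argnums=1, only y is marked differentiable, and the returned vjp_fn
+    # produces just the VJP with respect to y; x is still wrapped for the
+    # grad transform but is treated as a constant.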
+ # + # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev + # + # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs + # for only the primal elements given by argnums. + with grad_increment_nesting() as level: + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + primals = _wrap_all_tensors(primals, level) + # Note for the reviewer: This is extremely odd but it passes the + # assertion "len(self.block_stack) == 1" on symbolic_convert.py + # The equivalent "if argnums is None" fails for some reason + if not isinstance(argnums, int) and not argnums: + diff_primals = _create_differentiable(primals, level) + else: + diff_primals = _slice_argnums(primals, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_primals) + primals_out = func(*primals) + + if has_aux: + if not (isinstance(primals_out, tuple) and len(primals_out) == 2): + raise RuntimeError( + "vjp(f, *primals): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + primals_out, aux = primals_out + aux = _undo_create_differentiable(aux, level) + + flat_primals_out, primals_out_spec = tree_flatten(primals_out) + assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)') + flat_diff_primals, primals_spec = tree_flatten(diff_primals) + results = _undo_create_differentiable(primals_out, level) + + for primal_out in flat_primals_out: + assert isinstance(primal_out, torch.Tensor) + if primal_out.is_floating_point() or primal_out.is_complex(): + continue + raise RuntimeError("vjp(f, ...): All outputs of f must be " + "floating-point or complex Tensors, got Tensor " + f"with dtype {primal_out.dtype}") + + def wrapper(cotangents, retain_graph=True, create_graph=None): + if create_graph is None: + create_graph = torch.is_grad_enabled() + flat_cotangents, cotangents_spec = tree_flatten(cotangents) + _vjp_treespec_compare(primals_out, cotangents) + result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents, + retain_graph=retain_graph, create_graph=create_graph) + return tree_unflatten(result, primals_spec) + + if has_aux: + return results, wrapper, aux + else: + return results, wrapper + + +def _safe_zero_index(x): + assert len(x) == 1 + return x[0] + +# jacrev and jacfwd don't support complex functions +# Helper function to throw appropriate error. +def error_if_complex(func_name, args, is_input): + flat_args = pytree.tree_leaves(args) + for idx, arg in enumerate(flat_args): + if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: + input_or_output = ("inputs" if is_input else "outputs") + err_msg = (f"{func_name}: Expected all {input_or_output} " + f"to be real but received complex tensor at flattened input idx: {idx}") + raise RuntimeError(err_msg) + +@exposed_in("torch.func") +def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using reverse mode autodiff + + .. note:: + Using :attr:`chunk_size=1` is equivalent to computing the jacobian + row-by-row with a for-loop i.e. the constraints of :func:`vmap` are + not applicable. 
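+
+        As a rough illustration of the trade-off: for an output with 1000
+        elements, ``chunk_size=100`` assembles the Jacobian from 10 vmapped
+        vjp calls over 100 basis vectors each, instead of one vmap over all
+        1000 at once.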
+ + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + chunk_size (None or int): If None (default), use the maximum chunk size + (equivalent to doing a single vmap over vjp to compute the jacobian). + If 1, then compute the jacobian row-by-row with a for-loop. + If not None, then compute the jacobian :attr:`chunk_size` rows at a time + (equivalent to doing multiple vmap over vjp). If you run into memory issues computing + the jacobian, please try to specify a non-None chunk_size. + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Jacobian of ``func`` with respect to the arg(s) at + ``argnums``. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> jacobian = jacrev(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + :func:`jacrev` can be composed with vmap to produce batched + Jacobians: + + >>> from torch.func import jacrev, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacrev(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + Additionally, :func:`jacrev` can be composed with itself to produce + Hessians + + >>> from torch.func import jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacrev(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacrev` computes the Jacobian with respect to the first + input. However, it can compute the Jacboian with respect to a different + argument by using ``argnums``: + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to ``argnums`` will compute the Jacobian + with respect to multiple arguments + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + .. 
note:: + Using PyTorch ``torch.no_grad`` together with ``jacrev``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> jacrev(f)(x) + + In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``jacrev`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + """ + if not (chunk_size is None or chunk_size > 0): + raise ValueError("jacrev: `chunk_size` should be greater than 0.") + + @wraps(func) + def wrapper_fn(*args): + error_if_complex("jacrev", args, is_input=True) + vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux) + if has_aux: + output, vjp_fn, aux = vjp_out + else: + output, vjp_fn = vjp_out + + # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs] + flat_output, output_spec = tree_flatten(output) + + error_if_complex("jacrev", flat_output, is_input=False) + + # NB: vjp already checks that all outputs are tensors + # Step 1: Construct grad_outputs by splitting the standard basis + flat_output_numels = tuple(out.numel() for out in flat_output) + + primals = _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + + def compute_jacobian_stacked(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + chunked_results = [] + for flat_basis_chunk in _chunked_standard_basis_for_(flat_output, + flat_output_numels, + chunk_size=chunk_size): + if chunk_size == 1: + # sanity check. + for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = tree_map(lambda t: torch.squeeze(t, 0), flat_basis_chunk) + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + if chunk_size == 1: + flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results) + + chunked_results.append(flat_results) + + if len(chunked_results) == 1: + # Short-circuit if we used a single chunk + return chunked_results[0] + + # Concatenate chunks. + flat_results = [] + # Iterate and concat the jacobians of different + # inputs. + for idx in range(len(flat_primals)): + r = tuple(r_[idx] for r_ in chunked_results) + flat_results.append(torch.cat(r, 0)) + + return flat_results + + def compute_jacobian_preallocate_and_copy(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + out_vec_size = sum(flat_output_numels) + + # Don't pre-allocate if we have a single chunk. + if not (chunk_size is None or chunk_size >= out_vec_size): + stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals] + + for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output, + flat_output_numels, + chunk_size=chunk_size)): + if chunk_size == 1: + # sanity check. 
+ for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk] + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + # Short-circuit if we have a single chunk. + if chunk_size is None or chunk_size >= out_vec_size: + if chunk_size == 1: # and out_vec_size == 1 + # Since we squeezed the output dim + flat_results = tree_map(lambda t: torch.unsqueeze(t, 0), flat_results) + return flat_results + + for r, sr in zip(flat_results, stacked_results): + sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r) + + return stacked_results + + if _preallocate_and_copy: + flat_jacobians_per_input = compute_jacobian_preallocate_and_copy() + else: + flat_jacobians_per_input = compute_jacobian_stacked() + + # Step 2: The returned jacobian is one big tensor per input. In this step, + # we split each Tensor by output. + flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input] + flat_input_flat_output = [ + tuple(split.view(out.shape + primal.shape) + for split, out in zip(splits, flat_output)) + for splits, primal in zip(flat_jacobians_per_input, flat_primals) + ] + + # Step 3: Right now, `jacobian` is a List[List[Tensor]]. + # The outer List corresponds to the number of primals, + # the inner List corresponds to the number of outputs. + # We need to: + # a. Exchange the order of the outer List and inner List + # b. tree_unflatten the inner Lists (which correspond to the primals) + # c. handle the argnums=int case + # d. tree_unflatten the outer List (which corresponds to the outputs) + flat_output_flat_input = tuple(zip(*flat_input_flat_output)) + + flat_output_input = tuple(tree_unflatten(flat_input, primals_spec) + for flat_input in flat_output_flat_input) + + if isinstance(argnums, int): + flat_output_input = tuple(_safe_zero_index(flat_input) + for flat_input in flat_output_input) + output_input = tree_unflatten(flat_output_input, output_spec) + if has_aux: + return output_input, aux + return output_input + return wrapper_fn + +# NOTE: [Computing jacobian with vmap and vjp for multiple outputs] +# +# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). +# It turns out we can compute the jacobian of this function with a single +# call to autograd.grad by using vmap over the correct grad_outputs. +# +# Firstly, one way to compute the jacobian is to stack x**2 and x.sum() +# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) +# +# To get the first row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) +# To get the 2nd row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) +# and so on. +# +# Using vmap, we can vectorize all 4 of these computations into one by +# passing the standard basis for R^4 as the grad_output. +# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). +# +# Now, how do we compute the jacobian *without stacking the output*? +# We can just split the standard basis across the outputs. 
So to +# compute the jacobian of f(x), we'd use +# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) +# The grad_outputs looks like the following: +# ( torch.tensor([[1, 0, 0], +# [0, 1, 0], +# [0, 0, 1], +# [0, 0, 0]]), +# torch.tensor([[0], +# [0], +# [0], +# [1]]) ) +# +# But we're not done yet! +# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) +# returns a Tensor of shape [4, 3]. We have to remember to split the +# jacobian of shape [4, 3] into two: +# - one of shape [3, 3] for the first output +# - one of shape [ 3] for the second output + + +def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + # This function: + # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. + # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. + # - Each chunk corresponds to one tensor. The chunk has the same dtype and + # device as the tensor + # + # For example, with tensor_numels = [1, 2, 1], this function returns: + # ( tensor([[1], tensor([[0, 0], tensor([[0], + # [0], [1, 0], [0], + # [0], [0, 1], [0], + # [0]]) , [0, 0]]) , [1]]) ) + # + # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) + # Precondition: tensors always has at least one element. + # + # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] + # for context behind this function. + # NOTE: Argument `chunk_size` is used to generate chunked basis instead of + # one huge basis matrix. `chunk_size` dictates the maximum size of the + # basis matrix along dim=0. + assert len(tensors) == len(tensor_numels) + assert len(tensors) > 0 + assert chunk_size is None or chunk_size > 0 + total_numel = sum(tensor_numels) + if chunk_size and chunk_size < total_numel: + chunk_numels = get_chunk_sizes(total_numel, chunk_size) + else: # chunk_size is None or chunk_size >= total_numel + chunk_size = total_numel + chunk_numels = [total_numel] + + diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind()) + + for chunk_idx, total_numel in enumerate(chunk_numels): + chunks = tuple(tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels)) + + for chunk, diag_start_idx in zip(chunks, diag_start_indices): + chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1) + chunks = tuple(chunk.view(total_numel, *tensor.shape) + for chunk, tensor in zip(chunks, tensors)) + yield chunks + +def _construct_standard_basis_for(tensors, tensor_numels): + for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + return basis + + +def _validate_and_wrap_argnum(argnum, num_args): + if not isinstance(argnum, int): + raise RuntimeError(f'argnum must be int, got: {type(argnum)}') + if argnum >= 0 and argnum < num_args: + return argnum + if argnum < 0 and argnum >= -num_args: + return argnum + num_args + raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs') + + +def _check_unique_non_empty(argnums): + if isinstance(argnums, tuple): + if len(argnums) == 0: + raise RuntimeError("argnums must be non-empty") + if len(set(argnums)) != len(argnums): + raise RuntimeError(f"argnums elements must be unique, got {argnums}") + + +def _replace_args(old_args, new_args, argnums): + if isinstance(argnums, int): + if len(new_args) != 1: + raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}') + return tuple(new_args[0] if i == argnums else old_args[i] for i in 
range(len(old_args))) + if isinstance(argnums, tuple): + if len(new_args) != len(argnums): + raise RuntimeError( + "new_args should have the same size as argnums. " + f"Argnums size {len(argnums)}, new_args size {len(new_args)}") + + def get_right_elem(i): + return new_args[argnums.index(i)] if i in argnums else old_args[i] + + return tuple(get_right_elem(i) for i in range(len(old_args))) + raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}') + + +def _validate_and_wrap_argnums(argnums, num_args): + if isinstance(argnums, int): + return _validate_and_wrap_argnum(argnums, num_args) + if isinstance(argnums, tuple): + return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums) + raise AssertionError("Should never get here") + + +def _slice_argnums(args, argnums, as_tuple=True): + if not isinstance(argnums, int) and not isinstance(argnums, tuple): + raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}') + argnums = _validate_and_wrap_argnums(argnums, len(args)) + _check_unique_non_empty(argnums) + if isinstance(argnums, int): + if as_tuple: + return (args[argnums],) + else: + return args[argnums] + return tuple(args[i] for i in argnums) + + +JVP_NESTING = 0 + + +@contextlib.contextmanager +def noop(): + yield + + +def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None: + if not isinstance(elts, tuple): + raise RuntimeError( + f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}') + for elt in elts: + if isinstance(elt, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected {argname} to be a tuple of Tensors, got ' + f'a tuple with an element of type {type(elt)}') + if len(elts) == 0: + raise RuntimeError( + f'{api}: Expected {argname} to be a non-empty tuple of Tensors.') + + +def assert_non_empty_tensor_output(output: List[Any], api: str) -> None: + if (len(output) == 1 and output[0] is None) or len(output) < 1: + raise RuntimeError( + f'{api}: Expected f to be a function that has non-empty output (got output = {output})' + ) + for o in output: + if not isinstance(o, torch.Tensor): + raise RuntimeError( + f'{api}: expected f(*primals) to return only tensors' + f', got unsupported type {type(o)}' + ) + + +def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: + if isinstance(output, torch.Tensor): + return + if not isinstance(output, tuple): + raise RuntimeError( + f'{api}: Expected output of f to be a Tensor or Tensors, got ' + f'{type(output)}') + if len(output) == 0: + raise RuntimeError( + f'{api}: Expected output of f to be a non-empty tuple of Tensors.') + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected output of f to be a Tensor or Tensors, got ' + f'{type(out)} as an output') + + +def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None: + if len(output) == 0: + raise RuntimeError( + f'{api}: Expected {argname} to contain at least one Tensor.') + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f'{api}: Expected {argname} to only contain Tensors, got ' + f'{type(out)}') + + +jvp_str = 'jvp(f, primals, tangents)' + + +def safe_unpack_dual(dual, strict): + if not isinstance(dual, torch.Tensor): + raise RuntimeError( + f'{jvp_str}: expected f(*args) to return only tensors' + f', got unsupported type {type(dual)}' + ) + + primal, tangent = fwAD.unpack_dual(dual) + if tangent is None: + if strict: + raise 
RuntimeError( + 'jvp(f, primals, tangents, strict=True): ' + 'The output of f is independent of ' + 'the inputs. This is not allowed with strict=True.') + tangent = torch.zeros_like(primal) + return primal, tangent + + +@exposed_in("torch.func") +def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False): + """ + Standing for the Jacobian-vector product, returns a tuple containing + the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at + ``primals``" times ``tangents``. This is also known as forward-mode autodiff. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + tangents (Tensors): The "vector" for which Jacobian-vector-product is + computed. Must be the same structure and sizes as the inputs to + ``func``. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, jvp_out)`` tuple containing the output of ``func`` + evaluated at ``primals`` and the Jacobian-vector product. + If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + + jvp is useful when you wish to compute gradients of a function R^1 -> R^N + + >>> from torch.func import jvp + >>> x = torch.randn([]) + >>> f = lambda x: x * torch.tensor([1., 2., 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> assert torch.allclose(value, f(x)) + >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) + + :func:`jvp` can support functions with multiple inputs by passing in the + tangents for each of the inputs + + >>> from torch.func import jvp + >>> x = torch.randn(5) + >>> y = torch.randn(5) + >>> f = lambda x, y: (x * y) + >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) + >>> assert torch.allclose(output, x + y) + + """ + + return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux) + + +@doesnt_support_saved_tensors_hooks +def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *, + strict: bool = False, has_aux: bool): + # This is the same function as jvp but also accepts an argnums argument + # Most args are the same as jvp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for jvp). Default: None + # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is + # given by argnums + # + # WARN: Users should NOT call this function directly and should just be calling jvp. + # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers. 
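+    #
+    # (Illustrative example, not part of the original notes: with primals=(x, y) and
+    # argnums=1, `tangents` must be a 1-tuple whose single entry lines up with y in
+    # structure and shape.)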
+ # + # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd + # + # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to + # the primals given by argnums + if not isinstance(primals, tuple): + raise RuntimeError( + f'{jvp_str}: Expected primals to be a tuple. ' + f'E.g. it should be valid to call f(*primals).') + diff_args = primals if argnums is None else _slice_argnums(primals, argnums) + flat_primals, primals_spec = tree_flatten(diff_args) + flat_tangents, tangents_spec = tree_flatten(tangents) + if primals_spec != tangents_spec: + raise RuntimeError( + f'{jvp_str}: Expected primals and tangents to have the same python ' + f'structure. For example, if primals is a tuple of 3 tensors, ' + f'tangents also must be. Got primals with structure {primals_spec} ' + f'and tangents with structure {tangents_spec}') + assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals') + assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents') + + level = _jvp_increment_nesting() + try: + global JVP_NESTING + JVP_NESTING += 1 + with fwAD._set_fwd_grad_enabled(True): + ctx = fwAD.dual_level if JVP_NESTING == 1 else noop + with ctx(): + flat_duals = tuple(fwAD.make_dual(p, t) + for p, t in zip(flat_primals, flat_tangents)) + duals = tree_unflatten(flat_duals, primals_spec) + if argnums is not None: + primals = _wrap_all_tensors(primals, level) + duals = _replace_args(primals, duals, argnums) + result_duals = func(*duals) + if has_aux: + if not (isinstance(result_duals, tuple) and len(result_duals) == 2): + raise RuntimeError( + f"{jvp_str}: output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + result_duals, aux = result_duals + aux = _undo_create_differentiable(aux, level) + + result_duals, spec = tree_flatten(result_duals) + assert_non_empty_tensor_output(result_duals, jvp_str) + + primals_out, tangents_out = \ + zip(*[safe_unpack_dual(dual, strict) for dual in result_duals]) + primals_out = tree_map( + partial(_undo_create_differentiable, level=level), primals_out) + tangents_out = tree_map( + partial(_undo_create_differentiable, level=level), tangents_out) + + primals_out_unflatten = tree_unflatten(primals_out, spec) + tangents_out_unflatten = tree_unflatten(tangents_out, spec) + if has_aux: + return primals_out_unflatten, tangents_out_unflatten, aux + + return primals_out_unflatten, tangents_out_unflatten + finally: + _jvp_decrement_nesting() + JVP_NESTING -= 1 + + +def safe_unflatten(tensor, dim, shape): + if len(shape) == 0: + assert tensor.shape[dim] == 1 + return tensor.squeeze(dim) + return tensor.unflatten(dim, shape) + + +@exposed_in("torch.func") +def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using forward-mode autodiff + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. 
+        randomness(str): Flag indicating what type of randomness to use.
+            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
+            Default: "error"
+
+    Returns:
+        Returns a function that takes in the same inputs as ``func`` and
+        returns the Jacobian of ``func`` with respect to the arg(s) at
+        ``argnums``. If ``has_aux is True``, then the returned function
+        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
+        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
+
+    .. note::
+        You may see this API error out with "forward-mode AD not implemented
+        for operator X". If so, please file a bug report and we will prioritize it.
+        An alternative is to use :func:`jacrev`, which has better operator coverage.
+
+    A basic usage with a pointwise, unary operation will give a diagonal array
+    as the Jacobian
+
+        >>> from torch.func import jacfwd
+        >>> x = torch.randn(5)
+        >>> jacobian = jacfwd(torch.sin)(x)
+        >>> expected = torch.diag(torch.cos(x))
+        >>> assert torch.allclose(jacobian, expected)
+
+    :func:`jacfwd` can be composed with vmap to produce batched
+    Jacobians:
+
+        >>> from torch.func import jacfwd, vmap
+        >>> x = torch.randn(64, 5)
+        >>> jacobian = vmap(jacfwd(torch.sin))(x)
+        >>> assert jacobian.shape == (64, 5, 5)
+
+    If you would like to compute the output of the function as well as the
+    jacobian of the function, use the ``has_aux`` flag to return the output
+    as an auxiliary object:
+
+        >>> from torch.func import jacfwd
+        >>> x = torch.randn(5)
+        >>>
+        >>> def f(x):
+        >>>     return x.sin()
+        >>>
+        >>> def g(x):
+        >>>     result = f(x)
+        >>>     return result, result
+        >>>
+        >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
+        >>> assert torch.allclose(f_x, f(x))
+
+    Additionally, :func:`jacfwd` can be composed with itself or with :func:`jacrev`
+    to produce Hessians
+
+        >>> from torch.func import jacfwd, jacrev
+        >>> def f(x):
+        >>>     return x.sin().sum()
+        >>>
+        >>> x = torch.randn(5)
+        >>> hessian = jacfwd(jacrev(f))(x)
+        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
+
+    By default, :func:`jacfwd` computes the Jacobian with respect to the first
+    input. However, it can compute the Jacobian with respect to a different
+    argument by using ``argnums``:
+
+        >>> from torch.func import jacfwd
+        >>> def f(x, y):
+        >>>     return x + y ** 2
+        >>>
+        >>> x, y = torch.randn(5), torch.randn(5)
+        >>> jacobian = jacfwd(f, argnums=1)(x, y)
+        >>> expected = torch.diag(2 * y)
+        >>> assert torch.allclose(jacobian, expected)
+
+    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
+    with respect to multiple arguments
+
+        >>> from torch.func import jacfwd
+        >>> def f(x, y):
+        >>>     return x + y ** 2
+        >>>
+        >>> x, y = torch.randn(5), torch.randn(5)
+        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
+        >>> expectedX = torch.diag(torch.ones_like(x))
+        >>> expectedY = torch.diag(2 * y)
+        >>> assert torch.allclose(jacobian[0], expectedX)
+        >>> assert torch.allclose(jacobian[1], expectedY)
+
+    """
+    @wraps(func)
+    def wrapper_fn(*args):
+        error_if_complex("jacfwd", args, is_input=True)
+        primals = args if argnums is None else _slice_argnums(args, argnums)
+        flat_primals, primals_spec = tree_flatten(primals)
+        flat_primals_numels = tuple(p.numel() for p in flat_primals)
+        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
+        basis = tree_unflatten(flat_basis, primals_spec)
+
+        def push_jvp(basis):
+            output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
+            # output[0] is the output of `func(*args)`
+            error_if_complex("jacfwd", output[0], is_input=False)
+            if has_aux:
+                _, jvp_out, aux = output
+                return jvp_out, aux
+            _, jvp_out = output
+            return jvp_out
+
+        results = vmap(push_jvp, randomness=randomness)(basis)
+        if has_aux:
+            results, aux = results
+            # aux is in the standard basis format, e.g. NxN matrix
+            # We need to fetch the first element as original `func` output
+            flat_aux, aux_spec = tree_flatten(aux)
+            flat_aux = [value[0] for value in flat_aux]
+            aux = tree_unflatten(flat_aux, aux_spec)
+
+        jac_outs, spec = tree_flatten(results)
+        # Most probably below output check can never raise an error
+        # as jvp should test the output before
+        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')
+
+        jac_outs_ins = tuple(
+            tuple(
+                safe_unflatten(jac_out_in, -1, primal.shape)
+                for primal, jac_out_in in
+                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
+            )
+            for jac_out in jac_outs
+        )
+        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)
+
+        if isinstance(argnums, int):
+            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
+        if has_aux:
+            return tree_unflatten(jac_outs_ins, spec), aux
+        return tree_unflatten(jac_outs_ins, spec)
+    return wrapper_fn
+
+
+@exposed_in("torch.func")
+def hessian(func, argnums=0):
+    """
+    Computes the Hessian of ``func`` with respect to the arg(s) at index
+    ``argnum`` via a forward-over-reverse strategy.
+
+    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
+    a good default for good performance. It is possible to compute Hessians
+    through other compositions of :func:`jacfwd` and :func:`jacrev` like
+    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Hessian with respect to.
+            Default: 0.
+ + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Hessian of ``func`` with respect to the arg(s) at + ``argnums``. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + An alternative is to use ``jacrev(jacrev(func))``, which has better + operator coverage. + + A basic usage with a R^N -> R^1 function gives a N x N Hessian: + + >>> from torch.func import hessian + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hess, torch.diag(-x.sin())) + + """ + return jacfwd(jacrev(func, argnums), argnums) + + +@doesnt_support_saved_tensors_hooks +def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable: + with grad_increment_nesting() as level: + output, aux, grad_input = None, None, None + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + args = _wrap_all_tensors(args, level) + kwargs = _wrap_all_tensors(kwargs, level) + diff_args = _slice_argnums(args, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_args) + + output = func(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + output, aux = output + + if not isinstance(output, torch.Tensor): + raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) ' + f'to return a Tensor, got {type(output)}') + if output.dim() != 0: + raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) ' + 'to return a scalar Tensor, got tensor with ' + f'{output.dim()} dims. 
Maybe you wanted to ' + 'use the vjp or jacrev APIs instead?') + + flat_diff_args, spec = tree_flatten(diff_args) + + # NB: need create_graph so that backward pass isn't run in no_grad mode + flat_outputs = _as_tuple(output) + flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True) + grad_input = tree_unflatten(flat_grad_input, spec) + + grad_input = _undo_create_differentiable(grad_input, level) + output = _undo_create_differentiable(output, level) + if has_aux: + aux = _undo_create_differentiable(aux, level) + + if has_aux: + return grad_input, (output, aux) + return grad_input, output + + +def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs): + results = grad_and_value_impl(func, argnums, has_aux, args, kwargs) + if has_aux: + grad, (_, aux) = results + return grad, aux + grad, _ = results + return grad + +def _maybe_wrap_functional_tensor(maybe_tensor, level, *, _python_functionalize: bool = False): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + wrapped = _wrap_functional_tensor(maybe_tensor, level) + _assert_wrapped_functional(maybe_tensor, wrapped) + if _python_functionalize: + out = FunctionalTensor(wrapped) + torch._mirror_autograd_meta_to(maybe_tensor, out) + return out + return wrapped + + +def _wrap_all_tensors_to_functional(tensor_pytree, level, *, _python_functionalize: bool = False): + return tree_map(partial(lambda x: _maybe_wrap_functional_tensor( + x, level, _python_functionalize=_python_functionalize)), tensor_pytree) + + +def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + if isinstance(maybe_tensor, FunctionalTensor): + maybe_tensor = maybe_tensor.elem + + if not torch._is_functional_tensor(maybe_tensor): + # If it's not a functional tensor, just return it. + # This can happen if we functionalize a fn that returns a global, + # which was never wrapped properly. + return maybe_tensor + # Sync any pending updates on the output tensor + torch._sync(maybe_tensor) + return _unwrap_functional_tensor(maybe_tensor, reapply_views) + + +def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool): + return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree) + + +@exposed_in("torch.func") +def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable: + """ + functionalize is a transform that can be used to remove (intermediate) + mutations and aliasing from a function, while preserving the function's + semantics. + + ``functionalize(func)`` returns a new function with the same semantics + as ``func``, but with all intermediate mutations removed. + Every inplace operation performed on an intermediate tensor: + ``intermediate.foo_()`` + gets replaced by its out-of-place equivalent: + ``intermediate_updated = intermediate.foo()``. + + functionalize is useful for shipping a pytorch program off to + backends or compilers that aren't able to easily represent + mutations or aliasing operators. + + Args: + func (Callable): A Python function that takes one or more arguments. + remove (str): An optional string argument, that takes on either + the value 'mutations' or 'mutations_and_views'. + If 'mutations' is passed in then all mutating operators + will be replaced with their non-mutating equivalents. + If 'mutations_and_views' is passed in, then additionally, all aliasing + operators will be replaced with their non-aliasing equivalents. 
+ Default: 'mutations'. + + Returns: + Returns a new "functionalized" function. It takes the same inputs as + ``func``, and has the same behavior, but any mutations + (and optionally aliasing) performed on intermediate tensors + in the function will be removed. + + functionalize will also remove mutations (and views) that were performed on function inputs. + However to preserve semantics, functionalize will "fix up" the mutations after + the transform has finished running, by detecting if any tensor inputs "should have" + been mutated, and copying the new data back to the inputs if necessary. + + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch.func import functionalize + >>> + >>> # A function that uses mutations and views, but only on intermediate tensors. + >>> def f(a): + ... b = a + 1 + ... c = b.view(-1) + ... c.add_(1) + ... return b + ... + >>> inpt = torch.randn(2) + >>> + >>> out1 = f(inpt) + >>> out2 = functionalize(f)(inpt) + >>> + >>> # semantics are the same (outputs are equivalent) + >>> print(torch.allclose(out1, out2)) + True + >>> + >>> f_traced = make_fx(f)(inpt) + >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> + >>> print(f_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]) + add_ = torch.ops.aten.add_(view, 1); view = None + return add + + >>> print(f_no_mutations_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]); add = None + add_1 = torch.ops.aten.add(view, 1); view = None + view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None + return view_1 + + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view_copy = torch.ops.aten.view_copy(add, [-1]); add = None + add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None + return view_copy_1 + + + >>> # A function that mutates its input tensor + >>> def f(a): + ... b = a.view(-1) + ... b.add_(1) + ... return a + ... + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> # + >>> # All mutations and views have been removed, + >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input + >>> # after the function has completed. + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + view_copy = torch.ops.aten.view_copy(a_1, [-1]) + add = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None + copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None + return view_copy_1 + + + There are a few "failure modes" for functionalize that are worth calling out: + (1) Like other torch.func transforms, `functionalize()` doesn't work with functions + that directly use `.backward()`. The same is true for torch.autograd.grad. + If you want to use autograd, you can compute gradients directly + with `functionalize(grad(f))`. + (2) Like other torch.func transforms, `functionalize()` doesn't work with global state. 
+        If you call `functionalize(f)` on a function that takes views / mutations of
+        non-local state, functionalization will simply no-op and pass the view/mutation
+        calls directly to the backend.
+        One way to work around this is to ensure that any non-local state creation
+        is wrapped into a larger function, which you then call functionalize on.
+    (3) `resize_()` has some limitations: functionalize will only work on programs
+        that use `resize_()` as long as the tensor being resized is not a view.
+    (4) `as_strided()` has some limitations: functionalize will not work on
+        `as_strided()` calls that result in tensors with overlapping memory.
+
+
+    Finally, a helpful mental model for understanding functionalization is that
+    most user pytorch programs are written with the public torch API.
+    When executed, torch operators are generally decomposed into
+    our internal C++ "ATen" API.
+    The logic for functionalization happens entirely at the level of ATen.
+    Functionalization knows how to take every aliasing operator in ATen,
+    and map it to its non-aliasing equivalent
+    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
+    and how to take every mutating operator in ATen,
+    and map it to its non-mutating equivalent
+    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
+    while tracking aliases and mutations out-of-line to know when to fix things up.
+    Information about which ATen operators are aliasing or mutating all comes from
+    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
+    """
+    if remove == 'mutations':
+        reapply_views = True
+    elif remove == 'mutations_and_views':
+        reapply_views = False
+    else:
+        raise RuntimeError(
+            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
+            " Valid options are:\n"
+            " remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
+            " with their out-of-place equivalents.\n"
+            " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
+            " replaced with their non-aliasing counterparts, {view}_copy.\n"
+        )
+
+    @doesnt_support_saved_tensors_hooks
+    @wraps(func)
+    def wrapped(*args, **kwargs):
+        try:
+            func_level = _func_increment_nesting(reapply_views)
+            func_args = _wrap_all_tensors_to_functional(args, func_level)
+            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
+
+            flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
+            flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
+            flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
+            flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)
+
+            func_outputs = func(*func_args, **func_kwargs)
+            outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
+            flat_outputs, func_out_spec = tree_flatten(outputs)
+
+            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
+                if isinstance(a, torch.Tensor):
+                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
+                    torch._sync(a)
+
+            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
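+            # (Illustrative sketch, not from the original comments: for a function like
+            # ``def f(a): a.add_(1)``, the functionalized call computed the add out-of-place,
+            # so the loops below copy the final values back into the caller's original
+            # tensors to preserve ``func``'s input-mutation semantics.)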
+            for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
+                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
+                    _propagate_functional_input_mutation(unwrapped, wrapped)
+            for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
+                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
+                    _propagate_functional_input_mutation(unwrapped, wrapped)
+
+            return outputs
+        finally:
+            _func_decrement_nesting()
+    return wrapped
+
+@exposed_in("torch.func")
+def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
+    '''
+    Returns the value of ``func`` at ``primals`` and linear approximation
+    at ``primals``.
+
+    Args:
+        func (Callable): A Python function that takes one or more arguments.
+        primals (Tensors): Positional arguments to ``func`` that must all be
+            Tensors. These are the values at which the function is linearly approximated.
+
+    Returns:
+        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
+        applied to ``primals`` and a function that computes the jvp of
+        ``func`` evaluated at ``primals``.
+
+    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
+    to achieve this, linearize saves intermediate computation and has higher memory requirements
+    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
+    to compute vmap(jvp) instead of using linearize.
+
+    .. note::
+        linearize evaluates ``func`` twice. Please file an issue for an implementation
+        with a single evaluation.
+
+    Example::
+        >>> import torch
+        >>> from torch.func import linearize
+        >>> def fn(x):
+        ...     return x.sin()
+        ...
+        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
+        >>> jvp_fn(torch.ones(3, 3))
+        tensor([[1., 1., 1.],
+                [1., 1., 1.],
+                [1., 1., 1.]])
+        >>>
+
+    '''
+    # Note: We evaluate `fn` twice.
+    # Once for returning the output and the other while
+    # tracing the graph.
+    # If this becomes a bottleneck, we should update
+    # make_fx such that it also returns the output.
+
+    output = func(*primals)
+    _, output_spec = tree_flatten(output)
+
+    flat_primals, primals_argspec = tree_flatten(primals)
+
+    # tangents for tracing
+    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)
+
+    # function to trace
+    def trace_fn(flat_tangents):
+        with fwAD.dual_level():
+            flat_duals = tuple(fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents))
+            duals = tree_unflatten(flat_duals, primals_argspec)
+            output = func(*duals)
+            tangents = tree_map_only(torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output)
+
+        return tangents
+
+    jvp_graph = make_fx(trace_fn)(flat_tangents)
+    const_folded_jvp_graph = const_fold.split_const_subgraphs(jvp_graph)
+
+    # Hold only the meta-data regarding the primals.
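+    # (Recording shapes/devices/dtypes below, rather than the tensors themselves, lets the
+    # returned ``jvp_fn`` validate its tangents without keeping the primal tensors alive.)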
+ flat_primals_shape = tuple(p.shape for p in flat_primals) + flat_primals_device = tuple(p.device for p in flat_primals) + flat_primals_dtype = tuple(p.dtype for p in flat_primals) + + def forward_ad_checks(flat_tangents): + for idx, t in enumerate(flat_tangents): + if t.shape != flat_primals_shape[idx]: + msg = (f"tangent:{idx} with shape {t.shape} in flattened " + f"pytree doesn't match the shape {flat_primals_shape[idx]} " + "of the corresponding primal.") + raise RuntimeError(msg) + + if t.device != flat_primals_device[idx]: + msg = (f"tangent:{idx} with device {t.device} in flattened " + f"pytree doesn't match the device {flat_primals_device[idx]} " + "of the corresponding primal.") + raise RuntimeError(msg) + + if t.dtype != flat_primals_dtype[idx]: + msg = (f"tangent:{idx} with dtype {t.dtype} in flattened " + f"pytree doesn't match the dtype {flat_primals_dtype[idx]} " + "of the corresponding primal.") + raise RuntimeError(msg) + + # jvp_fn : callable to return + # It takes care of checking the argspec of tangents, + # calling the folded fx graph and unflattening fx graph output + def jvp_fn(*tangents): + flat_tangents, tangent_argspec = tree_flatten(tangents) + if tangent_argspec != primals_argspec: + raise RuntimeError(f"Expected the tangents {tangent_argspec} to have " + f"the same argspec as the primals {primals_argspec}") + + forward_ad_checks(flat_tangents) + + flat_output = const_folded_jvp_graph(*flat_tangents) + # const folded graph can return flat output, + # so transform output. + return tree_unflatten(flat_output, output_spec) + + return output, jvp_fn diff --git a/MLPY/Lib/site-packages/torch/_functorch/functional_call.py b/MLPY/Lib/site-packages/torch/_functorch/functional_call.py new file mode 100644 index 0000000000000000000000000000000000000000..16240d61ffd9b4c138f198adf6be6bddc890e651 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/functional_call.py @@ -0,0 +1,248 @@ +from collections import Counter +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch._functorch.utils import exposed_in + + +@exposed_in("torch.func") +def functional_call( + module: "torch.nn.Module", + parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]], + args: Union[Any, Tuple], + kwargs: Optional[Dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Performs a functional call on the module by replacing the module parameters + and buffers with the provided ones. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the ``parameter_and_buffer_dicts`` input. + + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) + >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. 
+
+    Example::
+
+        >>> a = {'foo': torch.zeros(())}
+        >>> # xdoctest: +SKIP
+        >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
+        >>> print(mod.foo)  # tensor(1.)
+        >>> mod(torch.zeros(()))  # tensor(2.)
+        >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
+        >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
+        >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
+        >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
+
+    An example of passing multiple dictionaries
+
+    .. code-block:: python
+
+        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
+        mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
+        print(mod.weight)  # tensor(...)
+        print(mod.buffer)  # tensor(...)
+        x = torch.randn((1, 1))
+        print(x)
+        functional_call(mod, a, x)  # same as x
+        print(mod.weight)  # same as before functional_call
+
+    And here is an example of applying the grad transform over the parameters
+    of a model.
+
+    .. code-block:: python
+
+        import torch
+        import torch.nn as nn
+        from torch.func import functional_call, grad
+
+        x = torch.randn(4, 3)
+        t = torch.randn(4, 3)
+        model = nn.Linear(3, 3)
+
+        def compute_loss(params, x, t):
+            y = functional_call(model, params, x)
+            return nn.functional.mse_loss(y, t)
+
+        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
+
+    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
+        parameters for better performance and memory usage
+
+        Example::
+
+            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
+            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
+            >>> grad_weights.grad_fn  # None--it's not tracking gradients outside of grad
+
+    This means that the user cannot call ``grad_weights.backward()``. However, if they don't need autograd tracking
+    outside of the transforms, this will result in less memory usage and faster speeds.
+
+    Args:
+        module (torch.nn.Module): the module to call
+        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
+            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
+            be used together
+        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
+        kwargs (dict): keyword arguments to be passed to the module call
+        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
+            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
+            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
+            buffers unless the values passed for both weights are the same. Default: True.
+        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
+            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
+            error. Default: False.
+
+    Returns:
+        Any: the result of calling ``module``.
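+
+    A minimal illustrative sketch (assuming a plain ``nn.Linear``; not part of the original docs):
+
+        >>> # xdoctest: +SKIP
+        >>> lin = torch.nn.Linear(3, 3)
+        >>> zeroed = {k: torch.zeros_like(v) for k, v in lin.state_dict().items()}
+        >>> out = functional_call(lin, zeroed, torch.randn(2, 3))
+        >>> # out is all zeros, since both weight and bias were replaced with zeros;
+        >>> # lin's own parameters are left untouched.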
+    """
+    if isinstance(parameter_and_buffer_dicts, dict):
+        parameters_and_buffers = parameter_and_buffer_dicts
+    elif isinstance(parameter_and_buffer_dicts, Sequence):
+        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
+            raise ValueError(
+                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
+            )
+        all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
+        repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
+        if len(repeated_keys) > 0:
+            raise ValueError(
+                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
+            )
+        parameters_and_buffers = {
+            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
+        }
+    else:
+        raise ValueError(
+            f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
+            f"but got {type(parameter_and_buffer_dicts)}"
+        )
+
+    return nn.utils.stateless._functional_call(
+        module,
+        parameters_and_buffers,
+        args,
+        kwargs,
+        tie_weights=tie_weights,
+        strict=strict,
+    )
+
+
+@exposed_in("torch.func")
+def stack_module_state(
+    models: List[nn.Module],
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """stack_module_state(models) -> params, buffers
+
+    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
+
+    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
+    that stack all of their parameters and buffers together, indexed by name.
+    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
+    autograd history that are unrelated to the original parameters and can be
+    passed directly to an optimizer).
+
+    Here's an example of how to ensemble over a very simple model:
+
+    .. code-block:: python
+
+        num_models = 5
+        batch_size = 64
+        in_features, out_features = 3, 3
+        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
+        data = torch.randn(batch_size, 3)
+
+        def wrapper(params, buffers, data):
+            return torch.func.functional_call(models[0], (params, buffers), data)
+
+        params, buffers = stack_module_state(models)
+        output = vmap(wrapper, (0, 0, None))(params, buffers, data)
+
+        assert output.shape == (num_models, batch_size, out_features)
+
+    When there are submodules, this follows state dict naming conventions
+
+    .. code-block:: python
+
+        import torch.nn as nn
+        class Foo(nn.Module):
+            def __init__(self, in_features, out_features):
+                super().__init__()
+                hidden = 4
+                self.l1 = nn.Linear(in_features, hidden)
+                self.l2 = nn.Linear(hidden, out_features)
+
+            def forward(self, x):
+                return self.l2(self.l1(x))
+
+        num_models = 5
+        in_features, out_features = 3, 3
+        models = [Foo(in_features, out_features) for i in range(num_models)]
+        params, buffers = stack_module_state(models)
+        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
+
+    .. warning::
+        All of the modules being stacked together must be the same (except for
+        the values of their parameters/buffers). For example, they should be in the
+        same mode (training vs eval).
+    """
+    if len(models) == 0:
+        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
+    if not (all(m.training for m in models) or all(not m.training for m in models)):
+        raise RuntimeError(
+            "stack_module_state: Expected all models to have the same training/eval mode."
+        )
+    model0_typ = type(models[0])
+    if not all(type(m) == model0_typ for m in models):
+        raise RuntimeError(
+            "stack_module_state: Expected all models to be of the same class."
+ ) + all_params = [dict(model.named_parameters()) for model in models] + params = { + k: construct_stacked_leaf(tuple(params[k] for params in all_params), k) + for k in all_params[0] + } + all_buffers = [dict(model.named_buffers()) for model in models] + buffers = { + k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k) + for k in all_buffers[0] + } + + return params, buffers + + +def construct_stacked_leaf( + tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str +) -> Tensor: + all_requires_grad = all(t.requires_grad for t in tensors) + none_requires_grad = all(not t.requires_grad for t in tensors) + if not all_requires_grad and not none_requires_grad: + raise RuntimeError( + f"Expected {name} from each model to have the same .requires_grad" + ) + result = torch.stack(tensors) + if all_requires_grad: + result = result.detach().requires_grad_() + return result diff --git a/MLPY/Lib/site-packages/torch/_functorch/fx_minifier.py b/MLPY/Lib/site-packages/torch/_functorch/fx_minifier.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0c3009de5b3af81b687259dfc9c2a9c53f0e84 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/fx_minifier.py @@ -0,0 +1,445 @@ +# mypy: ignore-errors + +import torch.fx as fx +import copy +import torch +import math +import sys +from typing import Callable, List +from functools import wraps, partial +from dataclasses import dataclass +from .compile_utils import get_placeholders, get_outputs +from torch.utils._content_store import ContentStoreWriter +from torch.hub import tqdm +from torch.multiprocessing.reductions import StorageWeakRef +import os + +is_tuple = object() + +@dataclass +class LoadTensorMeta: + size: List[int] + stride: List[int] + dtype: torch.dtype + device: torch.device + +class ConcreteProp(torch.fx.Interpreter): + def __init__(self, mod, *, writer=None, skip_offload=False): + super().__init__(mod) + self.writer = writer + self.skip_offload = skip_offload + self.seen_storages = set() + + def run_node(self, n): + self.pbar.update(1) + r = super().run_node(n) + name = n.name + + if isinstance(r, torch.Tensor): + if self.writer is None: + n.meta['concrete_value'] = r + else: + if StorageWeakRef(r.untyped_storage()) in self.seen_storages: + # Refuse to offload tensors which alias other live + # tensors, because this will violate operator contracts + n.meta['concrete_value'] = None + else: + if not self.skip_offload: + self.writer.write_tensor(os.path.join("eager", name), r) + n.meta['concrete_value'] = LoadTensorMeta( + r.size(), + r.stride(), + r.dtype, + r.device + ) + self.seen_storages.add(StorageWeakRef(r.untyped_storage())) + else: + n.meta['concrete_value'] = is_tuple + + return r + + def propagate(self, *args): + with tqdm( + desc="Saving intermediates for delta debugging", + total=len(self.module.graph.nodes), + disable=self.writer is None + ) as pbar: + self.pbar = pbar + r = super().run(*args) + if not self.skip_offload: + pbar.set_description("Saved! 
To skip next time, run with --skip-saving-eager-intermediates")
+        return r
+
+def is_load_tensor_node(node):
+    return node.op == 'call_function' and node.target is torch.ops.debugprims.load_tensor.default
+
+
+# inplace modifies node/inps
+def _convert_node_to_placeholder(graph, node, inps):
+    if node.op == 'output' or node.op == "placeholder":
+        return False
+
+    if is_load_tensor_node(node):
+        return False
+
+    concrete_val = node.meta.get('concrete_value', None)
+
+    if isinstance(concrete_val, torch.Tensor):
+        node.op = 'placeholder'
+        node.target = node.name
+        node.args = ()
+        node.kwargs = {}
+
+        inps.append(concrete_val)
+        return True
+
+    elif concrete_val is None:
+        return False
+
+    elif concrete_val is is_tuple:
+        r = False
+        for tuple_user in list(node.users):
+            r = _convert_node_to_placeholder(graph, tuple_user, inps) or r
+        # NB: We must not erase the node at this point, because
+        # we are iterating over the nodes and this would change
+        # the iteration order
+        # graph.erase_node(node)
+        return r
+
+    elif isinstance(concrete_val, LoadTensorMeta):
+        node.op = 'call_function'
+        node.target = torch.ops.debugprims.load_tensor.default
+        node.args = (os.path.join("eager", node.name), concrete_val.size, concrete_val.stride)
+        node.kwargs = {
+            'device': concrete_val.device,
+            'dtype': concrete_val.dtype,
+        }
+        return True
+
+    return False
+
+def create_minified_hlo_graph(minified_fx_graph, inputs):
+    """
+    Takes minified FX graph as primary input, and ports it to HLO via StableHLO
+    Provides minified HLO graph as output, and archives it to a local directory
+    """
+    hlo_dir = f"{os.getcwd()}/hlo_files"
+    os.makedirs(hlo_dir, exist_ok=True)
+
+    from torch_xla.stablehlo import save_torch_model_as_stablehlo
+    save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir)
+
+def dump_state(fx_g, inps):
+    print(f"""
+# Working Repro with {len(fx_g.graph.nodes)} nodes
+inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
+inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
+{fx_g.code}
+""")
+
+def is_power_of_two(n):
+    if n == 0:
+        return False
+    return (n & (n - 1)) == 0
+
+@dataclass
+class ReproState:
+    graph: fx.Graph
+    inps: List[torch.Tensor]
+
+    def __post_init__(self):
+        ph_nodes = get_placeholders(self.graph)
+        assert len(ph_nodes) == len(self.inps)
+
+def minifier(
+    fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable = dump_state, *,
+    save_dir=None, offload_to_disk=False, skip_offload=False, skip_sanity=False,
+    max_granularity=None
+):
+    """
+    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.
+
+    Uses 2 main strategies:
+    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
+    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,
+        tries replacing a quarter of the graph, etc.
+
+    >>> # xdoctest: +SKIP(failing)
+    >>> failing_function = fx.symbolic_trace(f)
+    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
+
+    note: module_fails returns True if it fails.
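+
+    A hedged sketch of a ``module_fails`` predicate (``compile_fn`` here is hypothetical):
+
+    >>> # xdoctest: +SKIP
+    >>> def module_fails(fx_g, inps):
+    ...     ref = fx_g(*inps)           # eager reference result
+    ...     res = compile_fn(fx_g)(*inps)  # result from the backend under test
+    ...     return not torch.allclose(ref, res)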
+ """ + assert isinstance(inps, (tuple, list)) + + failing_graph = fail_f.graph + cur_size = len(failing_graph.nodes) + + if max_granularity is not None and not is_power_of_two(max_granularity): + raise RuntimeError(f"max_granularity {max_granularity} not power of two") + + num_queries = 0 + + def deepcopy_fx_graph(fx_graph): + return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph + + + def graph_fails(graph, inps): + nonlocal num_queries + graph = copy.deepcopy(graph) + num_queries += 1 + mod = fx.GraphModule(fail_f, graph) + mod.graph.lint() + return module_fails(mod, inps) + + writer = None + if offload_to_disk: + writer = ContentStoreWriter(save_dir) + + ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps) + if not skip_sanity and not graph_fails(failing_graph, inps): + raise RuntimeError("Input graph did not fail the tester") + print(f"Started off with {cur_size} nodes", file=sys.stderr) + + def _register_strategy(strategy: Callable, name: str): + @wraps(strategy) + def new_func(old_state: ReproState, granularity=1): + print(file=sys.stderr) + print( + f"Strategy: {name} (G: {granularity}) " + f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)", + file=sys.stderr + ) + new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity) + if new_state is not None: + new_nodes = len(new_state.graph.nodes) + old_nodes = len(old_state.graph.nodes) + new_inps = len(new_state.inps) + old_inps = len(old_state.inps) + new_outs = len(get_outputs(new_state.graph)) + old_outs = len(get_outputs(old_state.graph)) + progress_made = False + if new_nodes < old_nodes: + progress_made = True + print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", file=sys.stderr) + if new_inps > old_inps: + progress_made = True + print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs", file=sys.stderr) + if new_outs < old_outs: + progress_made = True + print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs", file=sys.stderr) + + if not progress_made: + raise RuntimeError("Success raised but no progress made?") + + if not graph_fails(new_state.graph, new_state.inps): + print("WARNING: Something went wrong, not applying this minification", file=sys.stderr) + return None + return new_state + else: + print(f"FAIL: {name}", file=sys.stderr) + return None + + return new_func + + def register_strategy(name: str): + return partial(_register_strategy, name=name) + + @register_strategy("Truncate suffix") + def remove_suffix(cur_graph, cur_inps, granularity): + tested = set() + new_graph = fx.Graph() + env = {} + for idx, node in enumerate(cur_graph.nodes): + new_node = new_graph.node_copy(node, lambda x: env[x]) + if node.op not in ['placeholder', 'output']: + # If idx is divisible by (granularity * 2), it would have been checked already. 
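+                # (Illustrative: with granularity=2 this pass tests idx 2, 6, 10, ...;
+                # indices divisible by 4 were already hit by earlier, coarser passes.)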
+ if idx % granularity == 0 and (idx % (granularity * 2) != 0) and idx not in tested: + output_node = new_graph.output((new_node,)) + if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(new_graph, cur_inps): + return ReproState(new_graph, cur_inps) + else: + tested.add(idx) + new_graph.erase_node(output_node) + env[node] = new_node + return None + + @register_strategy("Remove outputs") + def remove_outputs(cur_graph, cur_inps, granularity): + granularity = max(1, granularity // 2) + for idx, node in enumerate(cur_graph.nodes): + node.idx = idx + if node.op == 'output': + output = node + break + + if isinstance(output.args[0], fx.Node): + return None + + output_args = sorted(output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)) + if len(output_args) == 1: + return None + + for idx in range(0, len(output_args), granularity): + output.args = (output_args[:idx] + output_args[idx + granularity:],) + if graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + + def remove_unused_inputs_unchecked(cur_state: ReproState): + cur_graph = cur_state.graph + cur_inps = cur_state.inps + ph_nodes = get_placeholders(cur_graph) + assert len(ph_nodes) == len(cur_inps) + + new_inps = [] + for idx in range(len(ph_nodes)): + if len(ph_nodes[idx].users) == 0: + cur_graph.erase_node(ph_nodes[idx]) + else: + new_inps.append(cur_inps[idx]) + if len(new_inps) < len(cur_inps): + return ReproState(cur_graph, new_inps) + return None + + def remove_unused_inputs_checked(cur_state: ReproState): + new_state = remove_unused_inputs_unchecked(cur_state) + if new_state is not None and graph_fails(new_state.graph, new_state.inps): + return new_state + return None + + def _remove_unused_wrapper(cur_graph, cur_inps, granularity): + return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps)) + + remove_unused_inputs = register_strategy("Remove unused inputs")(_remove_unused_wrapper) + + @register_strategy("Eliminate dead code") + def eliminate_dead_code(cur_graph, cur_inps, granularity): + if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + + def _consolidate_placeholders(cur_graph, inps): + new_graph = fx.Graph() + env = {} + seen_non_placeholder = False + + # Move all placeholders to the front; also, if any load_tensor + # is at the front, convert it into an input (because it can be live + # all the time) + for node in cur_graph.nodes: + if node.op == 'placeholder': + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + elif not seen_non_placeholder and is_load_tensor_node(node): + new_node = new_graph.placeholder(node.name) + env[node] = new_node + inps.append(torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)) + else: + seen_non_placeholder = True + + # Move everyone else + for node in cur_graph.nodes: + if node not in env: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + return new_graph + + @register_strategy("Delta Debugging") + def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity): + num_nodes = len(cur_graph.nodes) + for start_range in range(0, num_nodes, granularity): + is_removing = False + new_graph = deepcopy_fx_graph(cur_graph) + new_inps = cur_inps[:] + end_range = min(num_nodes, start_range + granularity) + for idx in range(start_range, end_range): + new_node = list(new_graph.nodes)[idx] + if _convert_node_to_placeholder(new_graph, new_node, new_inps): + is_removing = True 
+ if not is_removing: + continue + new_graph.eliminate_dead_code() + new_graph = _consolidate_placeholders(new_graph, new_inps) + new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps)) + if new_state is None: + new_state = ReproState(new_graph, new_inps) + if graph_fails(new_state.graph, new_state.inps): + return ReproState(new_state.graph, new_state.inps) + + return None + + @register_strategy("Consolidate Inputs") + def consolidate_inputs(cur_graph, cur_inps, granularity): + old_len = len(cur_inps) + cur_graph = _consolidate_placeholders(cur_graph, cur_inps) + if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + failing_state = ReproState(failing_graph, inps) + + def try_granularity(failing_state, granularity, use_non_granular): + print(f"Trying granularity {granularity}", file=sys.stderr) + + strategies = [] + num_nodes = len(failing_state.graph.nodes) + num_outputs = len(get_outputs(failing_state.graph)) + if num_outputs > num_nodes // 2: + strategies += [remove_outputs] + + if use_non_granular: + strategies += [eliminate_dead_code, remove_unused_inputs, consolidate_inputs] + + strategies += [remove_suffix, delta_debugging] + + for strategy in strategies: + new_state = strategy(failing_state, granularity) + if new_state is not None: + return new_state + return None + + while True: + dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps) + granularity = int(2**(math.floor(math.log2(len(failing_state.graph.nodes))))) + if max_granularity is not None: + granularity = min(max_granularity, granularity) + new_state = try_granularity(failing_state, granularity, use_non_granular=True) + if new_state is not None: + failing_state = new_state + continue + + granularity //= 2 + has_progress = False + while granularity >= 1: + new_state = try_granularity(failing_state, granularity, use_non_granular=False) + if new_state is not None: + failing_state = new_state + has_progress = True + break + granularity //= 2 + if has_progress: + continue + + new_state = remove_outputs(failing_state, 1) + if new_state is not None: + failing_state = new_state + continue + + break + + if not graph_fails(failing_state.graph, failing_state.inps): + raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing") + + print(f"Made {num_queries} queries", file=sys.stderr) + failing_fx = fx.GraphModule(fail_f, failing_state.graph) + + # If XLA debugging environment is enabled, create minified HLO graph as well + if "XLA_HLO_DEBUG" in os.environ: + create_minified_hlo_graph(failing_fx, failing_state.inps) + + dump_state(failing_fx, failing_state.inps) + print("Wrote minimal repro out to repro.py", file=sys.stderr) + return failing_fx, failing_state.inps diff --git a/MLPY/Lib/site-packages/torch/_functorch/make_functional.py b/MLPY/Lib/site-packages/torch/_functorch/make_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..2a393055e24fe28faeadb5cb72f72809dc6ae95d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/make_functional.py @@ -0,0 +1,615 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NoReturn, + Sequence, + Tuple, + Type, + Union, +) + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + +# Utilities to make nn.Module "functional" +# In particular the goal is to be able to provide a function that takes as input +# the parameters and evaluate the nn.Module using fixed inputs. + + +def raise_parameter_tying_error() -> NoReturn: + raise RuntimeError( + "make_functional(module): we don't yet support models that " + "do parameter tying (also sometimes known as weight sharing). " + "Please try to rewrite your model by replacing all instances of the " + "tied parameter with another and/or comment your support in " + "https://github.com/pytorch/functorch/issues/446" + ) + + +def create_names_map( + named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], + tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]], +) -> Dict[str, List[str]]: + """ + named_params is a dictionary of tensors: {'A': A, 'B': B} + tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} + with potentially tied (or 'duplicated') tensors + + This function creates a mapping from the names in named_params to the + names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. + """ + named_params = dict(named_params) + tied_named_params = dict(tied_named_params) + + tensors_dict_keys = set(named_params.keys()) + tied_tensors_dict_keys = set(tied_named_params.keys()) + assert tensors_dict_keys.issubset(tied_tensors_dict_keys) + + tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {} + for key, tensor in named_params.items(): + tensor_to_mapping[tensor] = (key, []) + for key, tensor in tied_named_params.items(): + assert tensor in tensor_to_mapping + tensor_to_mapping[tensor][1].append(key) + return dict(tensor_to_mapping.values()) + + +def _extract_members( + mod: nn.Module, + named_members: Callable[..., Iterable[Tuple[str, Tensor]]], + subclass: Callable[[Tensor], Tensor], +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + all_named_members = tuple(named_members(remove_duplicate=False)) + unique_named_members = tuple(named_members(remove_duplicate=True)) + names_map = create_names_map(unique_named_members, all_named_members) + + # Remove all the members in the model + memo = {} + accessor = NamedMemberAccessor(mod) + for name, p in all_named_members: + if p not in memo: + memo[p] = subclass(torch.empty_like(p, device="meta")) + replacement = memo[p] + accessor.set_tensor(name, replacement) + + if len(unique_named_members) == 0: + names, params = (), () + else: + names, params = zip(*unique_named_members) # type: ignore[assignment] + return params, names, names_map + + +def extract_weights( + mod: nn.Module, +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + """ + This function removes all the Parameters from the model and + return them as a tuple as well as their original attribute names. + The weights must be re-loaded with `load_weights` before the model + can be used again. + Note that this function modifies the model in place and after this + call, mod.parameters() will be empty. 
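+
+    Illustrative sketch of the return convention (assuming a plain
+    ``nn.Linear(3, 3)``; the names shown are simply what ``named_parameters``
+    reports for that module):
+
+    >>> # xdoctest: +SKIP(illustrative)
+    >>> params, names, names_map = extract_weights(nn.Linear(3, 3))
+    >>> names
+    ('weight', 'bias')
+    >>> names_map
+    {'weight': ['weight'], 'bias': ['bias']}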
+ """ + return _extract_members(mod, mod.named_parameters, nn.Parameter) + + +def extract_buffers( + mod: nn.Module, +) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]: + return _extract_members(mod, mod.named_buffers, lambda x: x) + + +def load_weights( + mod: nn.Module, + names: Sequence[str], + params: Sequence[Tensor], + as_params: bool = False, +) -> None: + """ + Reload a set of weights so that `mod` can be used again to perform a forward pass. + Note that the `params` are regular Tensors (that can have history) and so are left + as Tensors. This means that mod.parameters() will still be empty after this call. + """ + accessor = NamedMemberAccessor(mod) + if as_params: + params = [nn.Parameter(p) for p in params] + accessor.set_tensors(names, params) + + +def _swap_state( + mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor] +) -> List[Tensor]: + result: List[Tensor] = [] + accessor = NamedMemberAccessor(mod) + for (_, attr_names), elem in zip(names_map.items(), elems): + for i, attr_name in enumerate(attr_names): + if i == 0: + result.append(accessor.swap_tensor(attr_name, elem)) + else: + accessor.set_tensor(attr_name, elem) + return result + + +def load_buffers( + mod: nn.Module, + names: Sequence[str], + buffers: Sequence[Tensor], + as_params: bool = False, +) -> None: + accessor = NamedMemberAccessor(mod) + accessor.set_tensors(names, buffers) + + +def load_state( + model: nn.Module, + weights: Sequence[Tensor], + weight_names: Sequence[str], + buffers: Sequence[Tensor] = (), + buffer_names: Sequence[str] = (), +) -> nn.Module: + """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model + + load_state takes `weights` and `buffers` and assigns them to the model. + This is the inverse operation of `make_functional_deprecated_v1`. + """ + assert len(weight_names) == len(weights) + load_weights(model, weight_names, weights) + if len(buffers) > 0: + assert len(buffer_names) == len(buffers) + load_buffers(model, buffer_names, buffers) + return model + + +def make_functional_deprecated_v1(model: nn.Module): + """make_functional_deprecated_v1(model) -> weights, func, weight_names + + Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights) + and returns a functional version of the model, `func`. This makes + it so that it is possible use transforms over the parameters of + `model`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, func, _ = make_functional_deprecated_v1(model) + func(weights, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, _, func = make_functional_deprecated_v1(model) + grad_weights = grad(func)(weights, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional_deprecated_v1(model): `model` has buffers. Please use " + "make_functional_with_buffers_deprecated_v1(model) instead." 
+ ) + weights, descriptors, _ = extract_weights(model) + + def fun(weights, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, descriptors, weights) + return mutable_model(*data) + + return weights, fun, descriptors + + +def make_functional_with_buffers_deprecated_v1(model: nn.Module): + """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names + + Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers) + and returns a functional version of the model, `func`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + grad_weights = grad(func)(weights, buffers, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + weights, weight_descriptors, _ = extract_weights(model) + buffers, buf_descriptors, _ = extract_buffers(model) + + def fun(weights, buffers, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, weight_descriptors, weights) + load_buffers(mutable_model, buf_descriptors, buffers) + return mutable_model(*data) + + return weights, buffers, fun, weight_descriptors, buf_descriptors + + +class FunctionalModuleWithBuffers(nn.Module): + """ + This is the callable object returned by :func:`make_functional_with_buffers`. + """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: Tuple[str, ...], + buffer_names: Tuple[str, ...], + param_names_map: Dict[str, List[str]], + buffer_names_map: Dict[str, List[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.buffer_names = buffer_names + + self.all_names_map = dict(param_names_map) + self.all_names_map.update(buffer_names_map) + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, param_names_map = extract_weights(model_copy) + buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return ( + FunctionalModuleWithBuffers( + model_copy, param_names, buffer_names, param_names_map, buffer_names_map + ), + params, + buffers, + ) + + def forward( + self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs + ) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state( + self.stateless_model, + self.all_names_map, + tuple(params) + tuple(buffers), + ) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.all_names_map, old_state) + + +class FunctionalModule(nn.Module): + """ + This is the callable object returned by :func:`make_functional`. 
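+
+    Calling convention (a sketch; it mirrors ``forward`` below): given
+    ``func, params = make_functional(model)``, run ``out = func(params, *inputs)``,
+    where ``params`` is the tuple of tensors returned alongside this object.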
+ """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: Tuple[str, ...], + names_map: Dict[str, List[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.names_map = names_map + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, names_map = extract_weights(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return FunctionalModule(model_copy, param_names, names_map), params + + def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state(self.stateless_model, self.names_map, params) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.names_map, old_state) + + +def make_functional( + model: nn.Module, disable_autograd_tracking: bool = False +) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]: + """make_functional(model, disable_autograd_tracking=False) -> func, params + + Given a ``torch.nn.Module``, :func:`make_functional` extracts the state + (params) and returns a functional version of the model, ``func``. This + makes it so that it is possible use transforms over the parameters of + ``model``. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + func(params, x) + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + + def compute_loss(params, x, t): + y = func(params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, x, t) + + If the model has any buffers, please use :func:`make_functional_with_buffers` instead. + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional(model): `model` has buffers. Please use " + "make_functional_with_buffers(model) instead." 
+ ) + return FunctionalModule._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def make_functional_with_buffers( + model: nn.Module, disable_autograd_tracking: bool = False +) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: + """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers + + Given a ``torch.nn.Module``, make_functional_with_buffers extracts the + state (params and buffers) and returns a functional version of the model + ``func`` that can be invoked like a function. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + func(params, buffers, x) + + And here is an example of applying the grad transform over the parameters + of a model: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + + def compute_loss(params, buffers, x, t): + y = func(params, buffers, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, buffers, x, t) + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + return FunctionalModuleWithBuffers._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def transpose_stack( + tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...] +) -> Tuple[Tensor, ...]: + tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) + results = tuple( + torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors + ) + return results + + +def combine_state_for_ensemble( + models: Sequence[nn.Module], +) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]: + """combine_state_for_ensemble(models) -> func, params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their + parameters and buffers together to make ``params`` and ``buffers``. + Each parameter and buffer in the result will have an additional dimension + of size ``M``. + + :func:`combine_state_for_ensemble` also returns ``func``, a functional + version of one of the models in :attr:`models`. 
One cannot directly run + ``func(params, buffers, *args, **kwargs)`` directly, you probably want to + use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + fmodel, params, buffers = combine_state_for_ensemble(models) + output = vmap(fmodel, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + + This API is subject to change -- we're investigating better ways to + create ensembles and would love your feedback how to improve this. + """ + if len(models) == 0: + raise RuntimeError( + "combine_state_for_ensemble: Expected at least one model, got 0." + ) + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to " + "have the same training/eval mode." + ) + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to be of the same class." + ) + funcs, params, buffers = zip( + *[make_functional_with_buffers(model) for model in models] + ) + params = transpose_stack(params) + buffers = transpose_stack(buffers) + return funcs[0], params, buffers + + +def functional_init( + model_class: Type[nn.Module], + ensemble_shape: Union[Tuple[()], Tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs)) + weights = tuple(make_functional_deprecated_v1(model)[0] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights, fn, names + + return wrapped + + +def functional_init_with_buffers( + model_class: Type[nn.Module], + ensemble_shape: Union[Tuple[()], Tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + ( + _, + _, + fn, + weight_names, + buffer_names, + ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs)) + weights, buffers = zip( + *tuple( + 
make_functional_with_buffers_deprecated_v1(model)[:2] + for model in models + ) + ) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + buffers = tuple(zip(*buffers)) + buffers = tuple(torch.stack(shards).detach() for shards in buffers) + return weights, buffers, fn, weight_names, buffer_names + + return wrapped diff --git a/MLPY/Lib/site-packages/torch/_functorch/partitioners.py b/MLPY/Lib/site-packages/torch/_functorch/partitioners.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bffcaca654eaafcfa15d934008f8b37e629f88 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/partitioners.py @@ -0,0 +1,981 @@ +# mypy: ignore-errors + +from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + hint_int, free_symbols, is_symbol_binding_fx_node, find_symbol_binding_fx_nodes +) +from torch.fx.experimental._backward_state import BackwardState +import torch +import torch.fx as fx +import operator +import math +import torch.utils._pytree as pytree +import copy +import os +import itertools +import sympy +from collections import defaultdict +from torch.fx.passes import graph_drawer +from typing import List, Optional, Set, Tuple, Union +from .compile_utils import fx_graph_cse, get_aten_target +from . import config +import functools + + +AOT_PARTITIONER_DEBUG = config.debug_partitioner + + +def must_recompute(node): + return node.meta.get("recompute", False) + +def has_recomputable_ops(fx_g): + found = False + for node in fx_g.graph.nodes: + if must_recompute(node): + return True + return False + +def has_recomputable_rng_ops(fx_g): + for node in fx_g.graph.nodes: + if must_recompute(node) and hasattr(node.target, "tags") and torch.Tag.nondeterministic_seeded in node.target.tags: + return True + return False + +def sym_node_size(node): + if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): + return 1 + assert isinstance(node.meta["val"], torch.SymFloat) + return 4 + +class InvalidNodeBase: + def __repr__(self): + return "Invalid Node" + + +InvalidNode = InvalidNodeBase() + + +def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): + """ + Given a graph, extracts out a subgraph that takes the specified nodes as + inputs and returns the specified outputs. + + This includes specifying non-placeholder nodes as inputs. + + The general strategy is to initialize all inputs with proxies as we + encounter them, and trace through the graph, only keeping values which take + in valid proxies. Then, all dead code is eliminated. 
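+
+    Informal example (node names are illustrative): for a chain ``a -> b -> c -> d``,
+    calling this with ``inputs=[b]`` and ``outputs=[d]`` produces a graph with a
+    single placeholder ``b`` and the ops producing ``c`` and ``d``; ``a`` and any
+    other nodes not needed to compute ``d`` from ``b`` are dead-code eliminated.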
+ """ + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in inputs: + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + for node in joint_graph.nodes: + if node in inputs: + continue + elif node.op == 'placeholder': + env[node] = InvalidNode + elif node.op == 'call_function': + all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) + all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)] + if any(all_args): + env[node] = InvalidNode + continue + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'get_attr': + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == 'output': + pass + output_values = [] + for x in outputs: + if isinstance(x, fx.Node): + if x not in env: + raise RuntimeError(f"Node {x} couldn't be found in env") + assert not isinstance(env[x], InvalidNodeBase), f"Node {x} was invalid, but is output" + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(output_values) + + new_graph.eliminate_dead_code() + new_graph.lint() + return new_graph + + +def _is_primal(node): + return ( + node.op == "placeholder" + and "tangents" not in node.target + and not _is_bwd_seed_offset(node) + and not _is_fwd_seed_offset(node) + ) + +def _is_tangent(node): + return node.op == "placeholder" and "tangents" in node.target + +def _is_bwd_seed_offset(node): + return node.op == "placeholder" and ("bwd_seed" in node.target or "bwd_base_offset" in node.target) + +def _is_fwd_seed_offset(node): + return node.op == "placeholder" and ("fwd_seed" in node.target or "fwd_base_offset" in node.target) + +def _is_backward_state(node): + return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) + + +def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): + outputs = pytree.arg_tree_leaves(*(node.args for node in joint_module.graph.nodes if node.op == 'output')) + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs + + +def _remove_by_name(saved_values, name): + for saved_value in saved_values: + if saved_value.name == name: + saved_values.remove(saved_value) + break + +def _placeholders(nodes): + # Avoid making an entire pass over the graph if we only care about the input placeholders + result = [] + for node in nodes: + if node.op == 'placeholder': + result.append(node) + else: + break # placeholders are all at the start of graph + return result + + +def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs): + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + placeholders = _placeholders(joint_module.graph.nodes) + primal_inputs = [*filter(_is_primal, placeholders)] + tangent_inputs = [*filter(_is_tangent, placeholders)] + fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] + bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] + backward_state_inputs = [*filter(_is_backward_state, placeholders)] + + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, + bwd_outputs + ) + + for node in _placeholders(bwd_graph.nodes): + assert node.op == 'placeholder' + # 
This is to filter out saved values that don't actually end up being used by the backwards pass + if not node.users: + _remove_by_name(saved_values, node.name) + _remove_by_name(saved_sym_nodes, node.name) + elif _is_backward_state(node): + # BackwardState is saved directly + _remove_by_name(saved_values, node.name) + assert backward_state_inputs + + + # Now that we have the finalized list of saved values, we need to ensure + # we propagate all symbols which are referenced by backwards inputs. + # These are not directly used in the graph but are required for downstream + # sizevar assignment + saved_symbols: Set[sympy.Symbol] = set() + saved_sym_nodes_binding = [] + saved_sym_nodes_derived = [] + + # Some symbols may already be bound in the directly saved_sym_nodes, + # keep track of them so we don't re-bind them + for node in saved_sym_nodes: + symbol = is_symbol_binding_fx_node(node) + if symbol: + saved_symbols.add(symbol) + saved_sym_nodes_binding.append(node) + else: + saved_sym_nodes_derived.append(node) + + # Now go through all of the prospective backward inputs and track any + # other symbols we need to bind + symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) + for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): + if "val" not in node.meta: + continue + new_symbols = free_symbols(node.meta["val"]) - saved_symbols + # NB: Deterministic order please! + for s in sorted(new_symbols, key=lambda s: s.name): + # NB: For well formed graphs, the symbol should always be present, + # but we also have ways to produce ill-formed graphs, e.g., direct + # make_fx usages, so don't choke in this case + if s not in symbol_bindings: + continue + saved_sym_nodes_binding.append(symbol_bindings[s]) + saved_symbols |= new_symbols + + + # Update saved_sym_nodes that are now reordered to have all bindings at + # front. This can also be used later on to figure out the position of saved + # sym nodes in the output of fwd graph. + saved_sym_nodes.clear() + saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) + + # Now, we re-generate the fwd/bwd graphs. + # NB: This might increase compilation time, but I doubt it matters + fwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + primal_inputs + fwd_seed_offset_inputs, + fwd_outputs + saved_values + saved_sym_nodes + ) + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs + backward_state_inputs, + bwd_outputs + ) + + fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) + bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) + return fwd_module, bwd_module + + +def default_partition( + joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the :attr:`joint_module` in a manner that closely resembles the + behavior observed in the original ``.forward()`` and ``.backward()`` of the + callable, i.e., the resulting forward graph contains those operators that + are executed in the original ``.forward()`` callable passed to + :func:`aot_function`. + + The default partitioner collects the operators that are between the forward + inputs and the forward outputs. This helps in finding the tensors which have + to be stashed for the backward pass. These stashed tensors become the output + of the generated forward graph. 
The remaining operators are then placed in + the backward graph. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + if has_recomputable_ops(joint_module): + return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs) + forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'} + saved_values = [] + saved_sym_nodes = [] + + for node in joint_module.graph.nodes: + if node.name not in forward_node_names: + continue + if is_sym_node(node): + # Symints must be kept separate from tensors so that PythonFunction only calls + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes.append(node) + elif ( + 'tensor_meta' not in node.meta + and node.op == 'call_function' + ): + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + users = node.users + assert all(user.target == operator.getitem for user in users) + saved_values.extend(users) + else: + backward_usages = [n for n in node.users if n.name not in forward_node_names] + if 'tensor_meta' in node.meta and all(is_sym_node(n) for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + else: + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) + + return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs) + + +def _prod(x): + s = 1 + for i in x: + s *= i + return s + +def _tensor_nbytes(numel, dtype): + return numel * dtype.itemsize + +def _size_of(node: fx.Node) -> int: + if 'val' in node.meta: + val = node.meta['val'] + if isinstance(val, py_sym_types): + if isinstance(val, torch.SymInt): + return 1 + else: + return 999999 + # NB: The fallback values here are meaningless, maybe we should respect + # torch._inductor.config.unbacked_symint_fallback (but this is a + # layering violation) + elif isinstance(val, (list, tuple)): + return sum(_tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) for n in val if isinstance(n, torch.Tensor)) + elif isinstance(val, torch.Tensor): + return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) + + raise RuntimeError(f"Unknown metadata type {type(val)}") + + # Only needed since we don't always trace with fake tensors. 
+ if 'tensor_meta' in node.meta: + metadata = node.meta['tensor_meta'] + # TODO: What is to_size_hint suppose to be? + numel = _prod(map(to_size_hint, metadata.shape)) # noqa: F821 + dtype = metadata.dtype + else: + return 0 + + return _tensor_nbytes(numel, dtype) + + +# Used for some investigative purposes +def _count_ops(graph): + from collections import defaultdict + cnt = defaultdict(int) + for node in graph.nodes: + if node.op == 'call_function': + cnt[node.target.__name__] += 1 + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) + + +@functools.lru_cache(None) +def pointwise_ops(): + ops = [] + for attr_name in dir(torch.ops.aten): + opoverloadpacket = getattr(torch.ops.aten, attr_name) + if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): + continue + + for overload in opoverloadpacket.overloads(): + op_overload = getattr(opoverloadpacket, overload) + if torch.Tag.pointwise in op_overload.tags: + # currently aot autograd uses packet not overload + ops.append(opoverloadpacket) + break + + return ops + +def get_depth(node, depth_map): + if node in depth_map: + return depth_map[node] + + # Base case + if node.op == "placeholder": + depth_map[node] = 0 + return depth_map[node] + + # Handle output node + if node.op == "output": + args = node.args[0] + for arg in args: + if isinstance(arg, torch.fx.node.Node): + get_depth(arg, depth_map) + return + + # Get the depth of args and set the depth of this node + arg_depths = [get_depth(arg, depth_map) for arg in node.all_input_nodes if isinstance(arg, torch.fx.node.Node)] + # factory ops like full, rand might not have any input args + if len(arg_depths) == 0: + arg_depths = [0] + depth_map[node] = max(arg_depths) + 1 + return depth_map[node] + + +def sort_depths(args, depth_map): + arg_depths = {arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)} + return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) + + +def reordering_to_mimic_autograd_engine(gm): + """ + This pass finds the first bwd node in the graph (by looking at users of + tangents) and then reorders the graph by walking from this node to all the + way to the end of the graph. At each op in this traveral, we insert this op + in a new graph and try to bring only the relevant subgraph from the other + non-bwd edges relevant for this op. This closely mimics the behavior of + autograd engine. + + Why is this pass required in the first place? + + This is an artifact of how partitioners work today. The starting point of + partitioner is a joint graph, which is fwd and then bwd graph. In the case + of checkpointing, we keep portions of fwd graph in their original place in + the joint graph, while obtaining a bwd graph. As a result, the resulting bwd + graph has copies of recomputed fwd subgraphs followed by the original bwd + graph. If we run this naively, this leads to bad memory footprint, because + the fwd subgraphs are live for way longer duration than necessary. This pass + reorders the operations such that we prioritize the ops for the original bwd + graph while only realizing those ops from the fwd graph that are necessary + at any given point in the graph. 
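+
+    Informally (a sketch of the effect, not additional behavior): if the bwd
+    portion of the graph is laid out as ``[recomputed fwd block][original bwd ops]``,
+    this pass emits the original bwd ops in order and pulls in each recomputed
+    fwd node only when its first bwd consumer is emitted, so recomputed
+    activations stay live for a much shorter window.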
+ """ + + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in gm.graph.nodes: + if node.op == "placeholder": + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + + order = {} + for idx, node in enumerate(gm.graph.nodes): + order[node] = idx + + # Populate depth for the nodes. Depth is the distance from the inputs. + depths = {} + output_node = next(node for node in gm.graph.nodes if node.op == "output") + get_depth(output_node, depths) + + def insert_node_in_graph(node): + if node in env: + return env[node] + + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + for arg, _ in sort_depths(node.all_input_nodes, depths): + env[arg] = insert_node_in_graph(arg) + env[node] = new_graph.node_copy(node, lambda x: env[x]) + return env[node] + + # Find first bwd node in the graph + tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) + first_node_in_bwd = None + minimum_order = math.inf + for tangent in tangent_inputs: + for user in tangent.users: + if order[user] < minimum_order: + minimum_order = order[user] + first_node_in_bwd = user + assert first_node_in_bwd is not None + + # Build the graph op-by-op by starting from the node all the way to the end + for node in list(gm.graph.nodes)[order[first_node_in_bwd]:]: + insert_node_in_graph(node) + + # The output node is already built by the traversal. + new_gm = torch.fx.GraphModule(gm, new_graph) + return new_gm + + +def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes): + # During user-driven activation checkpointing, we have to ensure that a rng + # op in fwd yields the same output as the recomputed rng op in the bwd. To + # do this, we use functionalize wrappers to wrap the random ops and share + # rng state between the fwd and bwd graphs. + + # There are 3 main steps to do this + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. + # Step 2 - Modify the fwd pass such that + # 1) Replace rand with run_and_save_rng_state wrapper + # 2) Replace the users of the original op with the output[1] of this op. + # 3) Collect all the rng_state - output[0] of each op, and make them + # output nodes. Special care needs to be taken here because fwd outputs + # has symints at the very end. + # Step 3 - Modify the bwd pass such that + # 1) Add the input nodes just before the tangents for the stashed rng states + # 2) Replace rand with run_with_save_rng_state wrappers + # 3) Use the stashed states as inputs to these ops + + # Unique id to generate name + uid = itertools.count() + + def get_rng_ops(gmod): + random_nodes = {} + for node in gmod.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + random_nodes[node.name] = node + return random_nodes + + def get_device(node): + """ + Check the example value of the node outputs to find the device type. 
+ """ + if "val" not in node.meta: + return None + + candidates = node.meta["val"] + if not isinstance(candidates, tuple): + candidates = (candidates,) + + for candidate in candidates: + if isinstance(candidate, torch.Tensor): + if candidate.device.type == "cuda": + return "cuda" + + return "cpu" + + def get_sample_rng_state(device): + if device == "cuda": + return torch.cuda.get_rng_state() + return torch.get_rng_state() + + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. + joint_graph_rng_ops = get_rng_ops(joint_module) + fw_graph_rng_ops = get_rng_ops(fw_module) + bw_graph_rng_ops = get_rng_ops(bw_module) + recomputable_rng_ops_map = dict() + for node in joint_module.graph.nodes: + if ( + must_recompute(node) + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + base_node = joint_graph_rng_ops[node.name] + fw_node = fw_graph_rng_ops[node.name] + bw_node = bw_graph_rng_ops[node.name] + recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} + + run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state + run_with_rng_state = torch._prims.rng_prims.run_with_rng_state + + for node in bw_module.graph.nodes: + if node.op == "placeholder" and "tangent" in node.name: + bw_tangent_start_node = node + break + + + fw_rng_state_outputs = [] + for base_node, node_pair in recomputable_rng_ops_map.items(): + # Step 2 - Modify the fwd pass such that + fw_node = node_pair["fwd"] + bw_node = node_pair["bwd"] + fw_graph = fw_module.graph + with fw_graph.inserting_before(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + run_and_save_rng, + args=(fw_node.target, *fw_node.args), + kwargs=fw_node.kwargs + ) + state = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 0), kwargs={}) + rng_output = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 1,), kwargs={}) + fw_node.replace_all_uses_with(rng_output) + fw_graph.erase_node(fw_node) + fw_rng_state_outputs.append(state) + + + # Step 3 - Modify the bwd pass such that + bw_graph = bw_module.graph + with bw_graph.inserting_before(bw_tangent_start_node): + state_name = f"rng_state_output_{next(uid)}" + bw_rng_state_node = bw_graph.placeholder(state_name) + bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) + + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + run_with_rng_state, + args=(bw_rng_state_node, bw_node.target, *bw_node.args), + kwargs=bw_node.kwargs + ) + + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + + # Add the rng states in the output of the fwd graph. AOT Autograd assumes + # that symints are at the end of forward graph outputs. So, insert the new + # rng states accordingly. + fw_output_node = next(node for node in fw_module.graph.nodes if node.op == "output") + fw_outputs = fw_output_node.args[0] + sym_node_start_idx = len(fw_outputs) - num_sym_nodes + outputs = fw_outputs[:sym_node_start_idx] + fw_rng_state_outputs + fw_outputs[sym_node_start_idx:] + fw_module.graph.output(outputs) + fw_module.graph.erase_node(fw_output_node) + fw_module.recompile() + bw_module.recompile() + return fw_module, bw_module + + +def cleanup_recompute_tags(joint_module): + """ + If there are two consecutive checkpointed blocks with no operator in + between, we would still want to stash the tensor at the boundary of + checkpointed blocks. 
The following pass makes the last output node + non-recomputable to allow for that. + """ + for node in joint_module.graph.nodes: + if must_recompute(node): + for user in node.users: + if must_recompute(user) and user.meta["recompute"] > node.meta["recompute"]: + node.meta["recompute"] = 0 + return joint_module + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, _joint_inputs, compiler="inductor", recomputable_ops=None, + *, num_fwd_outputs +) -> Tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimination. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + try: + import networkx as nx + except ImportError as e: + raise RuntimeError("Need networkx installed to perform smart recomputation " + "heuristics") from e + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + full_bw_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + name_to_node = {} + for node in joint_module.graph.nodes: + name_to_node[node.name] = node + + def classify_nodes(joint_module): + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == 'placeholder' and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + required_bw_nodes.update(o for o in bwd_outputs if o is not None) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs) + required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes + if node.op != 'output'} + unclaimed_nodes = {node for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes} + return fwd_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, inputs + + orig_fw_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, inputs = classify_nodes(joint_module) + + # networkx blows up on graphs with no required 
backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. + if len(required_bw_nodes) == 0: + return default_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) + + for node in reversed(joint_module.graph.nodes): + if node not in required_fw_nodes: + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + aten = torch.ops.aten + prims = torch.ops.prims + + # compiler == "nvfuser" is the default set of recomputable ops + default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501,B950 + view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + if compiler == "inductor": + default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum] # noqa: E501,B950 + view_ops += [aten.view, aten.slice, aten.permute, aten.t, prims.broadcast_in_dim, aten.expand, aten.as_strided] + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index] + default_recomputable_ops += view_ops + + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [ + method_to_operator(m) + for m in magic_methods + ] + + recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops) + + random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] + compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501,B950 + + fusible_ops = recomputable_ops | set(random_ops) + if AOT_PARTITIONER_DEBUG: + joint_module_ops = { + str(node.target._overloadpacket) + for node in joint_module.graph.nodes + if node.op == "call_function" and hasattr(node.target, "_overloadpacket") + } + ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} + print("Ops banned from 
rematerialization: ", ops_ignored) + print() + + def is_materialized_backwards(node): + cur_nodes = {node} + while len(cur_nodes) > 0: + cur = cur_nodes.pop() + for user in cur.users: + if user not in required_fw_nodes and not is_fusible(cur, user): + return True + if user not in required_fw_nodes and get_aten_target(user) in view_ops: + cur_nodes.add(user) + + return False + + def ban_recomputation(node): + if "recompute" in node.meta: + return node.meta["recompute"] == 0 + elif config.aggressive_recomputation: + ignored_ops = random_ops + compute_intensive_ops + return (node.op == 'call_function' and get_aten_target(node) in ignored_ops) + else: + if node.op != 'call_function': + return False + if get_aten_target(node) not in recomputable_ops: + return True + if node.target == operator.getitem: + return False + if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: + return False + + # If a node *must* be materialized in the backwards pass, then we + # should never recompute it. This is a pretty subtle point. In + # general, the assumption we make is that recomputing a node in the + # backwards pass is "free". However, if a node must be materialized + # in the backwards pass, then recomputing it is never free. + if is_materialized_backwards(node): + return True + + # Arbitrary hack that sometimes seems to help things. The above + # modification appears to have made this heuristic a lot less critical + # for performance. + # TODO: Investigate why this hack helps. + # TODO: Investigate the interaction with compiler assisted + # activation checkpointing. Removing the heuristic improves both + # memory footprint and speedup. + if not graph_has_recomputable_ops: + if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw: + return True + # If the output of an op is 4x smaller (arbitrary choice), + # then we don't allow recomputation. + input_tensors_size = sum(_size_of(i) for i in node.args if isinstance(i, fx.Node)) + output_size = _size_of(node) + return (output_size * 4 < input_tensors_size) + + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops + + def is_materialized(node): + if node.op == 'placeholder': + return True + + return not all(is_fusible(node, user) for user in node.users) + + def get_node_weight(node) -> int: + mem_sz = _size_of(node) + + # Heuristic to bias towards nodes closer to the backwards pass + # Complete guess about current value + mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) + # mem_sz = int(mem_sz + node.dist_from_bw) + + if is_materialized(node): + return mem_sz + else: + return mem_sz * 2 + + nx_graph = nx.DiGraph() + for node in full_bw_graph.nodes: + if node.op == 'output': + continue + + if node in required_bw_nodes: + if node not in inputs: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + # If someone saves a input for backward as-is and backward + # returns that tensor as-is as a grad input, then the node x would + # be both a required_bw_node and an input. In this case we + # (1) connect x_in to to the source, (2) x_out to the sink, and + # (3) assign the proper weight to the x_in-x_out edge, so that + # x would be part of cut nodes. A case where this happens is if + # NestedTensor saves a offset tensor as part of the singleton int + # in sizes. 
+ nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) + + if _is_primal(node) or _is_fwd_seed_offset(node): + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + # If a node can't be recomputed (too expensive or involves randomness), + # we prevent it from being recomputed by adding an inf edge to the source + # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. + if ban_recomputation(node) and node in required_fw_nodes: + nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + + # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. + is_non_tensor_node = (('val' not in node.meta and 'tensor_meta' not in node.meta) or + ('val' in node.meta and not isinstance(node.meta['val'], torch.Tensor))) + + if is_sym_node(node): + weight = sym_node_size(node) + elif is_non_tensor_node: + weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + else: + weight = get_node_weight(node) + + # Creates the weights on the "node" edge + nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) + for user in node.users: + nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) + + try: + cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") + except Exception: + print('Failed to compute min-cut on following graph:') + print('\n'.join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + raise + + reachable, non_reachable = partition + cutset = set() + for u, nbrs in ((n, nx_graph[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + cut_nodes = set() + for node_in, node_out in cutset: + assert node_in[:-3] == node_out[:-4] + node_name = node_in[:-3] + cut_nodes.add(node_name) + + # To make this stuff deterministic + node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]) + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes = list(filter(is_sym_node, saved_values)) + saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols + fw_module, bw_module = _extract_fwd_bwd_modules( + joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs) + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + if AOT_PARTITIONER_DEBUG: + print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9) + fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'} + bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'} + remat_nodes = fw_module_nodes & bw_module_nodes + + counts = defaultdict(int) + for node in fw_module.graph.nodes: + if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'): + counts[str(node.target._overloadpacket)] += 1 + print(f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}") + print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True)) + return fw_module, bw_module + + +def draw_graph( + traced: torch.fx.GraphModule, + fname: 
str, + figname: str = "fx_graph", + clear_meta: bool = True, + prog: Union[str, List[str]] = None, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, +) -> None: + if clear_meta: + new_graph = copy.deepcopy(traced.graph) + traced = fx.GraphModule(traced, new_graph) + for node in traced.graph.nodes: + node.meta = {} + base, ext = os.path.splitext(fname) + if not ext: + ext = ".svg" + print(f"Writing FX graph to file: {base}{ext}") + g = graph_drawer.FxGraphDrawer( + traced, + figname, + parse_stack_trace=parse_stack_trace, + dot_graph_shape=dot_graph_shape, + ) + x = g.get_main_dot_graph() + write_method = getattr(x, "write_" + ext.lstrip(".")) + fname = f"{base}{ext}" + if prog is None: + write_method(fname) + else: + write_method(fname, prog=prog) + + +def draw_joint_graph( + graph: torch.fx.GraphModule, + joint_inputs, + file_name: str = "full_graph.png", + dot_graph_shape: Optional[str] = None, +): + draw_graph(graph, file_name, dot_graph_shape=dot_graph_shape) + return default_partition(graph, joint_inputs) diff --git a/MLPY/Lib/site-packages/torch/_functorch/pyfunctorch.py b/MLPY/Lib/site-packages/torch/_functorch/pyfunctorch.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae0c7c353673d8156d77e70b0f81b1ce2eea83b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/pyfunctorch.py @@ -0,0 +1,252 @@ +from abc import ABC, abstractmethod +import contextlib +from typing import Any, List, Tuple +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + TransformType, + RandomnessType, + CInterpreter, + CGradInterpreterPtr, + CFunctionalizeInterpreterPtr, + CVmapInterpreterPtr, + CJvpInterpreterPtr, + pop_dynamic_layer_stack, + push_dynamic_layer_stack, +) +from torch.autograd.forward_ad import _set_fwd_grad_enabled + +""" +This file contains the functorch integration with PyDispatcher. + +PyDispatcher does not understand functorch's DynamicLayerStack dispatching +logic because it is entirely implemented in C++ in the fallbacks for two +dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable +to directly reuse C++ boxed fallbacks). + +Instead of trying to hammer PyDispatcher into understanding those fallbacks, +we re-implement the logic of peeking the top of the stack for an interpreter, +selecting the interpreter to dispatch on, etc, in Python. This leads to a +simpler design. + +The main difference between C++ functorch and PyDispatcher's functorch logic +is that: +- C++ functorch needs to manually tweak dispatch keys to ping-pong between + DynamicLayerFrontMode and DynamicLayerBackMode. +- PyDispatcher's functorch logic pops an Interpreter from the top of the stack + and asks it to execute the rule associated with the Interpreter. + +In C++ we do the ping-pong because e.g. vmap rules are associated with the +batched DispatchKey, but in PyDispatcher we are able to avoid this by asking +the user to register a batching rule directly to a transform that an +interpreter then invokes. +""" + + +# FuncTorchInterpreter is the Python version of Interpreter (recall that +# the DynamicLayerStack is a stack of interpreters). +# It is a wrapper around the actual C++ Interpreter object. +# +# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h +class FuncTorchInterpreter(ABC): + def __init__(self, cptr: Any): + self._cptr = cptr + + # Process an operation. eg for vmap, this is invoking a batching rule. 
+ # Conceptually this is analogous to Interpreter::process in C++ + @abstractmethod + def process(self, op, args, kwargs): + pass + + # lower an operation from this Interpreter to the next Interpreter on the stack. + # Concretely, this involves temporarily popping the current Interpreter. + # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++ + def lower(self): + return temporarily_pop_interpreter_stack() + + def level(self): + return self._cptr.level() + + def key(self): + return self._cptr.key() + + def get_state(self): + raise NotImplementedError() + + def check_state(self, state): + return state == self.get_state() + + +@contextlib.contextmanager +def temporarily_pop_interpreter_stack(): + try: + saved = pop_dynamic_layer_stack() + yield + finally: + push_dynamic_layer_stack(saved) + + +class VmapInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Vmap + # NOTE: [Interpreter cdata vs cptr] + # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr + # so that we can access methods specific to the vmap interpreter + self._cdata = cdata + self._cptr = CVmapInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Vmap] + return kernel(self, *args, **kwargs) + + def batch_size(self): + return self._cptr.batchSize() + + def randomness(self): + typ = self._cptr.randomness() + if typ == RandomnessType.Error: + return "error" + elif typ == RandomnessType.Same: + return "same" + elif typ == RandomnessType.Different: + return "different" + raise RuntimeError(f"Unknown RandomnessType: {typ}") + + def get_state(self): + return (self.key().name, self.level(), self.randomness()) + + +@contextlib.contextmanager +def nested(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx) + yield contexts + + +class GradInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Grad + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + self._cptr = CGradInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs]) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Grad] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # GradInterpreter has custom lower because of the no_grad interaction + # See NOTE [grad and vjp interaction with no_grad] + # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_grad_mode = self.prev_grad_mode() + if not prev_grad_mode: + return nested(torch.no_grad(), super().lower()) + return super().lower() + + def prev_grad_mode(self): + return self._cptr.prevGradMode() + + def get_state(self): + return (self.key().name, self.level(), self.prev_grad_mode()) + + +class JvpInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Jvp + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + self._cptr = CJvpInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs]) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Jvp] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # Jvp has 
custom lower because of the no_fwd_grad interaction + # See NOTE [grad and vjp interaction with no_grad] for related info. + # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_fwd_grad_mode = self.prev_fwd_grad_mode() + if not prev_fwd_grad_mode: + return nested(_set_fwd_grad_enabled(False), super().lower()) + return super().lower() + + def prev_fwd_grad_mode(self): + return self._cptr.prevFwdGradMode() + + +class FunctionalizeInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Functionalize + self._cdata = cdata + self._cptr = CFunctionalizeInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Functionalize] + return kernel(self, *args, **kwargs) + + def functionalize_add_back_views(self): + return self._cptr.functionalizeAddBackViews() + + +def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter: + key = cinterpreter.key() + if key == TransformType.Grad: + return GradInterpreter(cinterpreter) + if key == TransformType.Vmap: + return VmapInterpreter(cinterpreter) + if key == TransformType.Jvp: + return JvpInterpreter(cinterpreter) + if key == TransformType.Functionalize: + return FunctionalizeInterpreter(cinterpreter) + raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}") + + +def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter: + interpreter = torch._C._functorch.peek_interpreter_stack() + assert interpreter is not None + return coerce_cinterpreter(interpreter) + + +def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]: + cis = torch._C._functorch.get_interpreter_stack() + if cis is None: + return [] + return [coerce_cinterpreter(ci) for ci in cis] + + +def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool: + # There are four possible cases covered here: + # 1. Current stack empty AND stack when generated not empty -> Invalidate + # 2. Current stack not empty AND stack when generated empty -> Invalidate + # 3. Current stack and generated stack empty -> Valid FX graph + # 4. Current stack and generated stack not empty -> Valid if both states match + peek = torch._C._functorch.peek_interpreter_stack() + if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0): + return False + + cis = retrieve_all_functorch_interpreters() + return len(cis) == len(states) and \ + all(ci.check_state(state) for ci, state in zip(cis, states)) + + +def dispatch_functorch(op, args, kwargs): + interpreter = retrieve_current_functorch_interpreter() + # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's + # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers. + # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch + # transforms, so we manually unwrap the dead tensors here. + # This logic won't need to exist when we have mode-only functorch. + args, kwargs = pytree.tree_map_only( + torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)) + return interpreter.process(op, args, kwargs) diff --git a/MLPY/Lib/site-packages/torch/_functorch/python_key.py b/MLPY/Lib/site-packages/torch/_functorch/python_key.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0f6c14124fa1d8a2bd59b198d38e9d4368f5ec --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/python_key.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"] +from torch.fx.experimental.proxy_tensor import make_fx, dispatch_trace, PythonKeyTracer, decompose + +pythonkey_decompose = decompose diff --git a/MLPY/Lib/site-packages/torch/_functorch/pytree_hacks.py b/MLPY/Lib/site-packages/torch/_functorch/pytree_hacks.py new file mode 100644 index 0000000000000000000000000000000000000000..c016206a0267b555a39e2a441b308a3156a656c0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/pytree_hacks.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +# TODO: remove this file when the migration of the pytree utility is done +from torch.utils._pytree import tree_map_, treespec_pprint + + +__all__ = ["tree_map_", "treespec_pprint"] + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "torch._functorch.pytree_hacks is deprecated and will be removed in a future release. " + "Please use torch.utils._pytree instead.", + DeprecationWarning, + ) diff --git a/MLPY/Lib/site-packages/torch/_functorch/top_operators_github_usage.py b/MLPY/Lib/site-packages/torch/_functorch/top_operators_github_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5422f8ffe833bf2f7346eb67474cfeac99af0a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/top_operators_github_usage.py @@ -0,0 +1,625 @@ +# mypy: ignore-errors + +""" +From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 +Try to keep this list in sync with that. 
+""" +top_torch = [ + ("t", 6837449), + ("tensor", 585786), + ("mode", 462182), + ("cat", 394818), + ("max", 368038), + ("zeros", 329495), + ("load", 327756), + ("no_grad", 294694), + ("save", 265130), + ("from_numpy", 243063), + ("manual_seed", 165044), + ("ones", 153696), + ("randn", 150796), + ("stack", 133358), + ("sum", 130772), + ("arange", 98087), + ("rand", 94715), + ("mean", 88546), + ("exp", 73883), + ("zeros_like", 72831), + ("min", 72248), + ("sigmoid", 66798), + ("log", 62135), + ("matmul", 47811), + ("clamp", 45304), + ("sqrt", 44911), + ("abs", 43535), + ("tanh", 42793), + ("empty", 40311), + ("argmax", 38435), + ("bmm", 33984), + ("pow", 33571), + ("norm", 31125), + ("mm", 30995), + ("is_tensor", 29546), + ("ones_like", 29512), + ("nonzero", 28681), + ("full", 28373), + ("unsqueeze", 27911), + ("where", 26585), + ("randperm", 26450), + ("eye", 24342), + ("mul", 23236), + ("topk", 22537), + ("as_tensor", 21967), + ("sort", 21412), + ("squeeze", 20863), + ("randint", 20771), + ("linspace", 20041), + ("add", 19201), + ("transpose", 18663), + ("split", 18325), + ("gather", 17904), + ("set_grad_enabled", 16013), + ("sin", 15669), + ("cos", 15562), + ("div", 15513), + ("index_select", 14866), + ("multinomial", 14331), + ("flatten", 14267), + ("isnan", 14170), + ("randn_like", 13096), + ("eq", 12680), + ("einsum", 12480), + ("round", 12367), + ("floor", 11628), + ("allclose", 11000), + ("reshape", 10605), + ("diag", 10167), + ("chunk", 9581), + ("std", 9379), + ("set_default_tensor_type", 9281), + ("triu", 8559), + ("meshgrid", 8292), + ("set_num_threads", 8126), + ("unique", 7964), + ("full_like", 7780), + ("tril", 7538), + ("dot", 7275), + ("sign", 6943), + ("equal", 6916), + ("normal", 6750), + ("cumsum", 6556), + ("dist", 6058), + ("isfinite", 6030), + ("gt", 5935), + ("set_printoptions", 5888), + ("range", 5491), + ("empty_like", 5351), + ("flip", 5342), + ("masked_select", 5341), + ("bernoulli", 5262), + ("atan", 5253), + ("var", 5247), + ("prod", 5200), + ("erf", 5088), + ("inverse", 5072), + ("addmm", 4854), + ("logsumexp", 4582), + ("fft", 4436), + ("lt", 4421), + ("log2", 4316), + ("enable_grad", 4238), + ("rand_like", 4187), + ("argsort", 3972), + ("seed", 3932), + ("mv", 3547), + ("ger", 3309), + ("ge", 3248), + ("atan2", 3210), + ("ceil", 3202), + ("ne", 3075), + ("bincount", 3063), + ("acos", 3055), + ("rsqrt", 3031), + ("svd", 3029), + ("numel", 3003), + ("log1p", 2840), + ("unbind", 2808), + ("le", 2714), + ("isinf", 2707), + ("cross", 2646), + ("set_default_dtype", 2536), + ("argmin", 2535), + ("sparse_coo_tensor", 2489), + ("log10", 2304), + ("kthvalue", 2192), + ("set_rng_state", 2158), + ("get_rng_state", 1996), + ("get_default_dtype", 1879), + ("det", 1868), + ("qr", 1864), + ("histc", 1852), + ("symeig", 1832), + ("trace", 1801), + ("median", 1795), + ("addcmul", 1751), + ("remainder", 1717), + ("baddbmm", 1693), + ("lgamma", 1665), + ("repeat_interleave", 1598), + ("fmod", 1576), + ("reciprocal", 1575), + ("tan", 1560), + ("initial_seed", 1532), + ("take", 1529), + ("stft", 1487), + ("get_num_threads", 1477), + ("real", 1459), + ("cholesky", 1406), + ("quantize_per_tensor", 1392), + ("diag_embed", 1364), + ("lerp", 1363), + ("asin", 1345), + ("eig", 1333), + ("trunc", 1290), + ("diagonal", 1287), + ("cosh", 1279), + ("rfft", 1269), + ("cumprod", 1260), + ("addr", 1211), + ("roll", 1198), + ("narrow", 1188), + ("digamma", 1172), + ("square", 1163), + ("sinh", 1131), + ("logspace", 1084), + ("broadcast_tensors", 1070), + ("irfft", 1013), + ("frac", 997), + 
("hann_window", 994), + ("solve", 989), + ("logdet", 977), + ("expm1", 968), + ("cdist", 946), + ("addmv", 903), + ("randint_like", 888), + ("tensordot", 888), + ("ifft", 877), + ("true_divide", 854), + ("erfinv", 830), + ("addcdiv", 819), + ("addbmm", 813), + ("renorm", 781), + ("pinverse", 753), + ("isclose", 740), + ("erfc", 729), + ("is_storage", 725), + ("triangular_solve", 723), + ("rot90", 709), + ("logical_not", 686), + ("geqrf", 681), + ("slogdet", 677), + ("lu", 665), + ("hamming_window", 659), + ("orgqr", 651), + ("ormqr", 622), + ("is_floating_point", 602), + ("diagflat", 562), + ("cholesky_solve", 559), + ("tril_indices", 552), + ("chain_matmul", 551), + ("triu_indices", 548), + ("angle", 522), + ("poisson", 505), + ("matrix_power", 485), + ("unique_consecutive", 471), + ("quantize_per_channel", 465), + ("std_mean", 458), + ("bartlett_window", 447), + ("var_mean", 428), + ("lstsq", 421), + ("logical_and", 419), + ("mvlgamma", 411), + ("blackman_window", 400), + ("bitwise_not", 395), + ("cholesky_inverse", 388), + ("as_strided", 384), + ("floor_divide", 353), + ("cartesian_prod", 321), + ("lu_solve", 317), + ("set_flush_denormal", 310), + ("empty_strided", 283), + ("logical_xor", 282), + ("polygamma", 282), + ("logical_or", 280), + ("set_num_interop_threads", 278), + ("combinations", 274), + ("trapz", 270), + ("matrix_rank", 260), + ("lu_unpack", 255), + ("result_type", 244), + ("conj", 231), + ("cummax", 230), + ("lobpcg", 229), + ("bitwise_xor", 217), + ("promote_types", 213), + ("get_num_interop_threads", 211), + ("cummin", 205), + ("bitwise_and", 198), + ("dequantize", 192), + ("bitwise_or", 191), + ("imag", 191), + ("can_cast", 184), + ("istft", 180), + ("compiled_with_cxx11_abi", 159), + ("is_complex", 151), + ("block_diag", 136), + ("pca_lowrank", 124), + ("absolute", 122), + ("svd_lowrank", 108), + ("neg", 2), +] + +top_nn_functional = [ + ("nn.functional.softmax", 10522), + ("nn.functional.relu", 8572), + ("nn.functional.interpolate", 7277), + ("nn.functional.pad", 5207), + ("nn.functional.log_softmax", 4699), + ("nn.functional.normalize", 2338), + ("nn.functional.cross_entropy", 2083), + ("nn.functional.grid_sample", 1970), + ("nn.functional.one_hot", 1967), + ("nn.functional.mse_loss", 1920), + ("nn.functional.conv2d", 1593), + ("nn.functional.dropout", 1516), + ("nn.functional.softplus", 1385), + ("nn.functional.sigmoid", 1128), + ("nn.functional.linear", 1036), + ("nn.functional.gelu", 930), + ("nn.functional.avg_pool2d", 899), + ("nn.functional.max_pool2d", 876), + ("nn.functional.nll_loss", 863), + ("nn.functional.embedding", 737), + ("nn.functional.tanh", 664), + ("nn.functional.leaky_relu", 640), + ("nn.functional.adaptive_avg_pool2d", 633), + ("nn.functional.cosine_similarity", 627), + ("nn.functional.unfold", 609), + ("nn.functional.conv1d", 596), + ("nn.functional.binary_cross_entropy_with_logits", 591), + ("nn.functional.l1_loss", 571), + ("nn.functional.binary_cross_entropy", 492), + ("nn.functional.elu", 416), + ("nn.functional.batch_norm", 413), + ("nn.functional.upsample", 413), + ("nn.functional.fold", 305), + ("nn.functional.affine_grid", 298), + ("nn.functional.max_pool1d", 297), + ("nn.functional.torch", 294), + ("nn.functional.threshold", 263), + ("nn.functional.smooth_l1_loss", 262), + ("nn.functional.pairwise_distance", 253), + ("nn.functional.logsigmoid", 243), + ("nn.functional.adaptive_max_pool2d", 235), + ("nn.functional.relu6", 213), + ("nn.functional.pixel_shuffle", 209), + ("nn.functional.avg_pool3d", 203), + ("nn.functional.bilinear", 
203), + ("nn.functional.conv_transpose2d", 201), + ("nn.functional.gumbel_softmax", 197), + ("nn.functional.max_unpool2d", 196), + ("nn.functional.kl_div", 191), + ("nn.functional.hardtanh", 189), + ("nn.functional.ctc_loss", 185), + ("nn.functional.layer_norm", 178), + ("nn.functional.conv3d", 172), + ("nn.functional.max_unpool3d", 167), + ("nn.functional.hardshrink", 165), + ("nn.functional.hardswish", 156), + ("nn.functional.selu", 156), + ("nn.functional.glu", 155), + ("nn.functional.assert_int_or_pair", 150), + ("nn.functional.hardsigmoid", 146), + ("nn.functional.upsample_bilinear", 146), + ("nn.functional.max_pool3d", 140), + ("nn.functional.adaptive_avg_pool3d", 139), + ("nn.functional.instance_norm", 124), + ("nn.functional.embedding_bag", 122), + ("nn.functional.upsample_nearest", 110), + ("nn.functional.avg_pool1d", 105), + ("nn.functional.prelu", 102), + ("nn.functional.celu", 92), + ("nn.functional.dropout2d", 86), + ("nn.functional.hinge_embedding_loss", 82), + ("nn.functional.softsign", 81), + ("nn.functional.max_unpool1d", 74), + ("nn.functional.silu", 74), + ("nn.functional.softshrink", 70), + ("nn.functional.leaky_relu_", 68), + ("nn.functional.softmin", 67), + ("nn.functional.channel_shuffle", 66), + ("nn.functional.multilabel_margin_loss", 66), + ("nn.functional.dropout3d", 65), + ("nn.functional.multi_margin_loss", 65), + ("nn.functional.lp_pool2d", 64), + ("nn.functional.conv_transpose1d", 62), + ("nn.functional.triplet_margin_loss", 62), + ("nn.functional.tanhshrink", 61), + ("nn.functional.adaptive_max_pool1d", 59), + ("nn.functional.cosine_embedding_loss", 58), + ("nn.functional.multi_head_attention_forward", 58), + ("nn.functional.max_pool1d_with_indices", 53), + ("nn.functional.poisson_nll_loss", 53), + ("nn.functional.margin_ranking_loss", 52), + ("nn.functional.soft_margin_loss", 52), + ("nn.functional.adaptive_max_pool3d", 51), + ("nn.functional.group_norm", 51), + ("nn.functional.local_response_norm", 51), + ("nn.functional.multilabel_soft_margin_loss", 51), + ("nn.functional.relu_", 50), + ("nn.functional.alpha_dropout", 49), + ("nn.functional.feature_alpha_dropout", 49), + ("nn.functional.lp_pool1d", 49), + ("nn.functional.adaptive_max_pool1d_with_indices", 48), + ("nn.functional.adaptive_max_pool2d_with_indices", 48), + ("nn.functional.adaptive_max_pool3d_with_indices", 48), + ("nn.functional.fractional_max_pool2d", 48), + ("nn.functional.fractional_max_pool2d_with_indices", 48), + ("nn.functional.fractional_max_pool3d", 48), + ("nn.functional.fractional_max_pool3d_with_indices", 48), + ("nn.functional.max_pool2d_with_indices", 48), + ("nn.functional.max_pool3d_with_indices", 48), + ("nn.functional.handle_torch_function", 47), + ("nn.functional.has_torch_function", 47), + ("nn.functional.adaptive_avg_pool1d", 43), + ("nn.functional.pdist", 43), + ("nn.functional.rrelu_", 37), + ("nn.functional.elu_", 34), + ("nn.functional.boolean_dispatch", 33), + ("nn.functional.hardtanh_", 26), + ("nn.functional.triplet_margin_with_distance_loss", 23), + ("nn.functional.selu_", 20), + ("nn.functional.pixel_unshuffle", 19), + ("nn.functional.conv_transpose3d", 18), + ("nn.functional.gaussian_nll_loss", 15), + ("nn.functional.has_torch_function_unary", 15), + ("nn.functional.has_torch_function_variadic", 15), + ("nn.functional.celu_", 13), + ("nn.functional.huber_loss", 7), + ("nn.functional.mish", 4), + ("nn.functional.threshold_", 3), + ("nn.functional.grad", 2), + ("nn.functional.conv_tbc", 1), + ("nn.functional.math", 1), +] + +top_nn_module = [ + ("nn.Module", 
927129, None), + ("nn.Linear", 530688, "nn.functional.linear"), + ("nn.Sequential", 384968, None), + ("nn.Conv2d", 383320, "nn.functional.conv2d"), + ("nn.ReLU", 318877, "nn.functional.relu"), + ("nn.BatchNorm2d", 233265, "nn.functional.batch_norm"), + ("nn.Dropout", 179268, "nn.functional.dropout"), + ("nn.ModuleList", 171225, None), + ("nn.Parameter", 153291, None), + ("nn.CrossEntropyLoss", 152696, "nn.functional.cross_entropy"), + ("nn.MaxPool2d", 138619, "nn.functional.max_pool2d"), + ("nn.Embedding", 111844, "nn.functional.embedding"), + ("nn.DataParallel", 104238, None), + ("nn.MSELoss", 82954, "nn.functional.mse_loss"), + ("nn.Sigmoid", 75810, "nn.functional.sigmoid"), + ("nn.LeakyReLU", 65632, "nn.functional.leaky_relu"), + ("nn.BatchNorm1d", 65374, "nn.functional.batch_norm"), + ("nn.Softmax", 65114, "nn.functional.softmax"), + ("nn.Tanh", 59445, "nn.functional.tanh"), + ("nn.AdaptiveAvgPool2d", 59071, "nn.functional.adaptive_avg_pool2d"), + ("nn.AvgPool2d", 58377, "nn.functional.avg_pool2d"), + ("nn.ConvTranspose2d", 57524, "nn.functional.conv_transpose2d"), + ("nn.LSTM", 57411, None), + ("nn.Conv1d", 41108, "nn.functional.conv1d"), + ("nn.LayerNorm", 36089, "nn.functional.layer_norm"), + ("nn.BCELoss", 34005, "nn.functional.binary_cross_entropy"), + ("nn.Upsample", 32527, "nn.functional.interpolate"), + ("nn.BCEWithLogitsLoss", 29944, "nn.functional.binary_cross_entropy_with_logits"), + ("nn.GRU", 25421, None), + ("nn.Dropout2d", 23512, "nn.functional.dropout2d"), + ("nn.LogSoftmax", 22897, "nn.functional.log_softmax"), + ("nn.L1Loss", 22778, "nn.functional.l1_loss"), + ("nn.GroupNorm", 22183, "nn.functional.group_norm"), + ("nn.NLLLoss", 21751, "nn.functional.nll_loss"), + ("nn.Conv3d", 20874, "nn.functional.conv3d"), + ("nn.Identity", 17911, None), + ("nn.InstanceNorm2d", 16426, "nn.functional.instance_norm"), + ("nn.BatchNorm3d", 16378, "nn.functional.batch_norm"), + ("nn.PReLU", 13472, "nn.functional.prelu"), + ("nn.ReLU6", 12622, "nn.functional.relu6"), + ("nn.ELU", 12508, "nn.functional.elu"), + ("nn.LSTMCell", 10885, None), + ("nn.Flatten", 10384, "torch.flatten"), + ("nn.ModuleDict", 10255, None), + ("nn.ReflectionPad2d", 9954, "nn.functional.pad"), + ("nn.MaxPool3d", 9526, "nn.functional.max_pool3d"), + ("nn.MaxPool1d", 9154, "nn.functional.max_pool1d"), + ("nn.RNN", 9154, None), + ("nn.ZeroPad2d", 8847, "nn.functional.pad"), + ("nn.ParameterList", 7702, None), + ("nn.SyncBatchNorm", 6814, None), + ("nn.PixelShuffle", 6571, "nn.functional.pixel_shuffle"), + ("nn.SmoothL1Loss", 6517, "nn.functional.smooth_l1_loss"), + ("nn.Hardswish", 6458, "nn.functional.hardswish"), + ("nn.AdaptiveMaxPool2d", 6071, "nn.functional.adaptive_max_pool2d"), + ("nn.SELU", 6043, "nn.functional.selu"), + ("nn.ConvTranspose3d", 6039, "nn.functional.conv_transpose3d"), + ("nn.GRUCell", 5840, None), + ("nn.ReplicationPad2d", 5600, "nn.functional.pad"), + ("nn.KLDivLoss", 5541, "nn.functional.kl_div"), + ("nn.ConvTranspose1d", 5183, "nn.functional.conv_transpose1d"), + ("nn.Softplus", 5120, "nn.functional.softplus"), + ("nn.SiLU", 4895, "nn.functional.silu"), + ("nn.AvgPool3d", 4523, "nn.functional.avg_pool3d"), + ("nn.CosineSimilarity", 4058, "nn.functional.cosine_similarity"), + ("nn.GELU", 3932, "nn.functional.gelu"), + ("nn.UpsamplingBilinear2d", 3673, "nn.functional.interpolate"), + ("nn.InstanceNorm1d", 3658, "nn.functional.instance_norm"), + ("nn.Transformer", 3604, None), + ("nn.MultiheadAttention", 3435, "nn.functional.multi_head_attention_forward"), + ("nn.AvgPool1d", 3195, 
"nn.functional.avg_pool1d"), + ("nn.Dropout3d", 2964, "nn.functional.dropout3d"), + ("nn.AdaptiveAvgPool3d", 2915, "nn.functional.adaptive_avg_pool3d"), + ("nn.InstanceNorm3d", 2893, "nn.functional.instance_norm"), + ("nn.Hardtanh", 2613, "nn.functional.hardtanh"), + ("nn.MarginRankingLoss", 2568, "nn.functional.margin_ranking_loss"), + ("nn.GLU", 2526, "nn.functional.glu"), + ("nn.AdaptiveAvgPool1d", 2481, "nn.functional.adaptive_avg_pool1d"), + ("nn.EmbeddingBag", 2344, "nn.functional.embedding_bag"), + ("nn.TransformerEncoderLayer", 2292, None), + ("nn.TransformerEncoder", 2091, None), + ("nn.MaxUnpool2d", 2031, "nn.functional.max_unpool2d"), + ("nn.UpsamplingNearest2d", 2004, "nn.functional.interpolate"), + ("nn.ConstantPad1d", 1904, "nn.functional.pad"), + ("nn.ConstantPad2d", 1791, "nn.functional.pad"), + ("nn.CTCLoss", 1789, "nn.functional.ctc_loss"), + ("nn.AdaptiveMaxPool1d", 1713, "nn.functional.adaptive_max_pool1d"), + ("nn.AdaptiveLogSoftmaxWithLoss", 1665, None), + ("nn.Bilinear", 1664, "nn.functional.bilinear"), + ("nn.RNNCell", 1653, None), + ("nn.MultiLabelSoftMarginLoss", 1624, "nn.functional.multilabel_soft_margin_loss"), + ("nn.Unfold", 1452, "nn.functional.unfold"), + ("nn.RReLU", 1431, "nn.functional.rrelu"), + ("nn.CosineEmbeddingLoss", 1357, "nn.functional.cosine_embedding_loss"), + ("nn.LocalResponseNorm", 1331, "nn.functional.local_response_norm"), + ("nn.Softmax2d", 1300, "nn.functional.softmax"), + ("nn.PairwiseDistance", 1241, "nn.functional.pairwise_distance"), + ("nn.LogSigmoid", 1235, "nn.functional.logsigmoid"), + ("nn.TripletMarginLoss", 1230, "nn.functional.triplet_margin_loss"), + ("nn.RNNBase", 1133, None), + ("nn.Threshold", 1043, "nn.functional.threshold"), + ("nn.AdaptiveMaxPool3d", 1025, "nn.functional.adaptive_max_pool3d"), + ("nn.CELU", 1018, "nn.functional.celu"), + ("nn.NLLLoss2d", 966, "nn.functional.nll_loss"), + ("nn.Softsign", 877, "nn.functional.softsign"), + ("nn.ReplicationPad1d", 862, "nn.functional.pad"), + ("nn.SoftMarginLoss", 856, "nn.functional.soft_margin_loss"), + ("nn.ParameterDict", 742, None), + ("nn.ReflectionPad1d", 731, "nn.functional.pad"), + ("nn.Softshrink", 713, "nn.functional.softshrink"), + ("nn.AlphaDropout", 710, "nn.functional.alpha_dropout"), + ("nn.Tanhshrink", 681, "nn.functional.tanhshrink"), + ("nn.PoissonNLLLoss", 676, "nn.functional.poisson_nll_loss"), + ("nn.MaxUnpool3d", 660, "nn.functional.max_unpool3d"), + ("nn.Fold", 630, "nn.functional.fold"), + ("nn.MultiMarginLoss", 622, "nn.functional.multi_margin_loss"), + ("nn.TransformerDecoderLayer", 614, None), + ("nn.TransformerDecoder", 607, None), + ("nn.Hardshrink", 592, "nn.functional.hardshrink"), + ("nn.ConstantPad3d", 582, "nn.functional.pad"), + ("nn.MultiLabelMarginLoss", 580, "nn.functional.multilabel_margin_loss"), + ("nn.LPPool2d", 550, "nn.functional.lp_pool2d"), + ("nn.Softmin", 537, "nn.functional.softmin"), + ("nn.MaxUnpool1d", 518, "nn.functional.max_unpool1d"), + ("nn.FractionalMaxPool2d", 484, "nn.functional.fractional_max_pool2d"), + ("nn.Hardsigmoid", 477, "nn.functional.hardsigmoid"), + ("nn.ReplicationPad3d", 470, "nn.functional.pad"), + ("nn.HingeEmbeddingLoss", 442, "nn.functional.hinge_embedding_loss"), + ("nn.LPPool1d", 386, "nn.functional.lp_pool1d"), + ("nn.FractionalMaxPool3d", 252, "nn.functional.fractional_max_pool3d"), + ("nn.Container", 217, None), + ("nn.Unflatten", 206, "nn.functional.unflatten"), + ("nn.FeatureAlphaDropout", 136, "nn.functional.feature_alpha_dropout"), + ("nn.TripletMarginWithDistanceLoss", 107, 
"nn.functional.triplet_margin_with_distance_loss"), + ("nn.ChannelShuffle", 90, "nn.functional.channel_shuffle"), + ("nn.RNNCellBase", 88, None), + ("nn.LazyLinear", 81, "nn.functional.linear"), + ("nn.UninitializedParameter", 60, None), + ("nn.CrossMapLRN2d", 59, None), + ("nn.GaussianNLLLoss", 55, "nn.functional.gaussian_nll_loss"), + ("nn.PixelUnshuffle", 45, "nn.functional.pixel_unshuffle"), + ("nn.Mish", 31, "nn.functional.mish"), + ("nn.ReflectionPad3d", 22, "nn.functional.pad"), + ("nn.HuberLoss", 18, "nn.functional.huber_loss"), + ("nn.LazyConv2d", 15, None), + ("nn.LazyConv1d", 9, None), + ("nn.LazyConv3d", 8, None), + ("nn.LazyConvTranspose1d", 8, None), + ("nn.LazyConvTranspose2d", 8, None), + ("nn.LazyConvTranspose3d", 8, None), + ("nn.LazyBatchNorm1d", 3, None), + ("nn.LazyBatchNorm2d", 3, None), + ("nn.LazyBatchNorm3d", 3, None), + ("nn.UninitializedBuffer", 3, None), +] + +# No rankings because these are a little hard to get rankings for +method_only_ops = [ + 'bfloat16', + 'bool', + 'byte', + 'char', + 'contiguous', + 'cpu', + 'cuda', + 'detach', + 'double', + 'expand', + 'expand_as', + 'float', + 'get_device', + 'half', + 'hardshrink', + 'index_add', + 'index_copy', + 'index_fill', + 'index_put', + 'int', + 'is_contiguous', + 'is_pinned', + 'is_set_to', + 'is_shared', + 'is_signed', + 'item', + 'long', + 'masked_scatter', + 'masked_fill', + 'narrow_copy', + 'numpy', + 'pin_memory', + 'repeat', + 'reshape_as', + 'select', + 'short', + 'storage_offset', + 'sum_to_size', + 'to', + 'to_mkldnn', + 'tolist', + 'type', + 'type_as', + 'unfold', + 'view', + 'view_as', +] + + +def get_nn_functional_top_list(): + top_nn_functional_ = dict(top_nn_functional) + for _, count, functional_name in top_nn_module: + if functional_name is None: + continue + if functional_name == 'torch.flatten': + continue + if functional_name not in top_nn_functional_: + top_nn_functional_[functional_name] = count + else: + top_nn_functional_[functional_name] += count + + top_nn_functional_ = list(top_nn_functional_.items()) + top_nn_functional_.sort(key=lambda x: x[1], reverse=True) + return top_nn_functional_ + + +usage_count = {} +for k, v in get_nn_functional_top_list(): + usage_count[k] = v +for k, v in top_torch: + usage_count[k] = v diff --git a/MLPY/Lib/site-packages/torch/_functorch/utils.py b/MLPY/Lib/site-packages/torch/_functorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86e0f7d65aab4c49dc5c91f3d3cee0639a0a3b55 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/utils.py @@ -0,0 +1,41 @@ +import contextlib +import torch +from torch._C._functorch import ( + set_single_level_autograd_function_allowed, + get_single_level_autograd_function_allowed, + unwrap_if_dead, +) +from typing import Union, Tuple + +@contextlib.contextmanager +def enable_single_level_autograd_function(): + try: + prev_state = get_single_level_autograd_function_allowed() + set_single_level_autograd_function_allowed(True) + yield + finally: + set_single_level_autograd_function_allowed(prev_state) + +def unwrap_dead_wrappers(args): + # NB: doesn't use tree_map_only for performance reasons + result = tuple( + unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg + for arg in args + ) + return result + +# Allows one to expose an API in a private submodule publicly as per the definition +# in PyTorch's public api policy. +# +# It is a temporary solution while we figure out if it should be the long-term solution +# or if we should amend PyTorch's public api policy. 
The concern is that this approach +# may not be very robust because it's not clear what __module__ is used for. +# However, both numpy and jax overwrite the __module__ attribute of their APIs +# without problem, so it seems fine. +def exposed_in(module): + def wrapper(fn): + fn.__module__ = module + return fn + return wrapper + +argnums_t = Union[int, Tuple[int, ...]] diff --git a/MLPY/Lib/site-packages/torch/_functorch/vmap.py b/MLPY/Lib/site-packages/torch/_functorch/vmap.py new file mode 100644 index 0000000000000000000000000000000000000000..09339f809caf7a64348db8bde1ea174de13b2240 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_functorch/vmap.py @@ -0,0 +1,452 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import contextlib +import functools +import threading +from torch import Tensor +from typing import Any, Callable, Optional, Tuple, Union, List +from torch.utils._pytree import ( + tree_flatten, + tree_unflatten, + tree_map_, + _broadcast_to_and_flatten, + TreeSpec, +) +from functools import partial +import os +import itertools + +from torch._C._functorch import ( + _add_batch_dim, + _remove_batch_dim, + _vmap_decrement_nesting, + _vmap_increment_nesting, + is_batchedtensor, +) + +in_dims_t = Union[int, Tuple] +out_dims_t = Union[int, Tuple[int, ...]] + + +def doesnt_support_saved_tensors_hooks(f): + message = ( + "torch.func transforms don't yet support saved tensor hooks. " + "Please open an issue with your use case." + ) + + @functools.wraps(f) + def fn(*args, **kwargs): + with torch.autograd.graph.disable_saved_tensors_hooks(message): + return f(*args, **kwargs) + return fn + + +# Checks that all args-to-be-batched have the same batch dim size +def _validate_and_get_batch_size( + flat_in_dims: List[Optional[int]], + flat_args: List) -> int: + batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None] + if len(batch_sizes) == 0: + raise ValueError('vmap: Expected at least one Tensor to vmap over') + if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): + raise ValueError( + f'vmap: Expected all tensors to have the same size in the mapped ' + f'dimension, got sizes {batch_sizes} for the mapped dimension') + return batch_sizes[0] + + +def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: + if isinstance(batched_outputs, tuple): + return len(batched_outputs) + return 1 + +# If value is a tuple, check it has length `num_elements`. +# If value is not a tuple, make a tuple with `value` repeated `num_elements` times + + +def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple: + if not isinstance(value, tuple): + return (value,) * num_elements + if len(value) != num_elements: + raise ValueError(error_message_lambda()) + return value + + +def _process_batched_inputs( + in_dims: in_dims_t, args: Tuple, func: Callable +) -> Tuple[int, List[Any], List[Any], TreeSpec]: + if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'expected `in_dims` to be int or a (potentially nested) tuple ' + f'matching the structure of inputs, got: {type(in_dims)}.') + if len(args) == 0: + raise ValueError( + f'vmap({_get_name(func)})(): got no inputs. 
Maybe you forgot to add ' + f'inputs, or you are trying to vmap over a function with no inputs. ' + f'The latter is unsupported.') + + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'in_dims is not compatible with the structure of `inputs`. ' + f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs ' + f'has structure {args_spec}.') + + for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but in_dim must be either ' + f'an integer dimension or None.') + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for an input but the input is of type ' + f'{type(arg)}. We cannot vmap over non-Tensor arguments, ' + f'please use None as the respective in_dim') + if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): + raise ValueError( + f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(): ' + f'Got in_dim={in_dim} for some input, but that input is a Tensor ' + f'of dimensionality {arg.dim()} so expected in_dim to satisfy ' + f'-{arg.dim()} <= in_dim < {arg.dim()}.') + if in_dim is not None and in_dim < 0: + flat_in_dims[i] = in_dim % arg.dim() + + return _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec + +# Creates BatchedTensors for every Tensor in arg that should be batched. +# Returns the (potentially) batched arguments and the batch_size. + + +def _create_batched_inputs( + flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple: + # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] + batched_inputs = [arg if in_dim is None else + _add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args)] + return tree_unflatten(batched_inputs, args_spec) + + +def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim): + + if out_dim is None: + if isinstance(batched_output, torch.Tensor) and is_batchedtensor(batched_output): + raise ValueError( + f'vmap({name}, ...): `{name}` can not return a ' + f'BatchedTensor when out_dim is None' + ) + return batched_output + + # out_dim is non None + if not isinstance(batched_output, torch.Tensor): + raise ValueError(f'vmap({name}, ...): `{name}` must only return ' + f'Tensors, got type {type(batched_output)}. ' + 'Did you mean to set out_dim= to None for output?') + + return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) + + +# Undos the batching (and any batch dimensions) associated with the `vmap_level`. +def _unwrap_batched( + batched_outputs: Union[Tensor, Tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, batch_size: int, func: Callable) -> Tuple: + flat_batched_outputs, output_spec = tree_flatten(batched_outputs) + + def incompatible_error(): + raise ValueError( + f'vmap({_get_name(func)}, ..., out_dims={out_dims})(): ' + f'out_dims is not compatible with the structure of `outputs`. 
' + f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs ' + f'has structure {output_spec}.') + + if isinstance(batched_outputs, torch.Tensor): + # Some weird edge case requires us to spell out the following + # see test_out_dims_edge_case + if isinstance(out_dims, int): + flat_out_dims = [out_dims] + elif isinstance(out_dims, tuple) and len(out_dims) == 1: + flat_out_dims = out_dims + elif out_dims is None: + flat_out_dims = [out_dims] + else: + incompatible_error() + else: + flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) + if flat_out_dims is None: + incompatible_error() + + flat_outputs = [ + _maybe_remove_batch_dim(_get_name(func), batched_output, vmap_level, batch_size, out_dim) + for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims) + ] + return tree_unflatten(flat_outputs, output_spec) + + +def _check_int_or_none(x, func, out_dims): + if isinstance(x, int): + return + if x is None: + return + raise ValueError( + f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be ' + f'an int, None or a python collection of ints representing where in the outputs the ' + f'vmapped dimension should appear.') + + +def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None: + if isinstance(out_dims, int): + return + tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims) + + +def _get_name(func: Callable): + if hasattr(func, '__name__'): + return func.__name__ + + # Not all callables have __name__, in fact, only static functions/methods do. + # A callable created via functools.partial or an nn.Module, to name some + # examples, don't have a __name__. + return repr(func) + + +DECOMPOSITIONS_LOADED = False +DECOMPOSITIONS_LOCK = threading.Lock() +VMAP_DECOMPOSITIONS_LIB = None + +# torch.package, Python 3.11, and torch.jit-less environments are unhappy with +# decompositions. Only load them when needed if possible. +def lazy_load_decompositions(): + global DECOMPOSITIONS_LOADED + if DECOMPOSITIONS_LOADED: + return + + with DECOMPOSITIONS_LOCK: + if DECOMPOSITIONS_LOADED: + return + + if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__): + DECOMPOSITIONS_LOADED = True + return + + # use an alternate way to register an operator into the decomposition table + # _register_jit_decomposition doesn't work for some operators, e.g. 
addr, + # because the Tensor types generated cannot be unioned by torchscript + # decomp should be type OpOverload + global VMAP_DECOMPOSITIONS_LIB + VMAP_DECOMPOSITIONS_LIB = torch.library.Library("aten", "IMPL", "FuncTorchBatched") + + from torch._decomp import decomposition_table + + def _register_python_decomposition_vmap(decomp): + if decomp in decomposition_table: + VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp]) + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + + _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.smooth_l1_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.addr.default) + + DECOMPOSITIONS_LOADED = True + +def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs): + lazy_load_decompositions() + _check_out_dims_is_int_or_int_pytree(out_dims, func) + batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func) + + if chunk_size is not None: + chunks_flat_args = _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size) + return _chunked_vmap(func, flat_in_dims, chunks_flat_args, + args_spec, out_dims, randomness, **kwargs) + + # If chunk_size is not specified. + return _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs + ) + +def get_chunk_sizes(total_elems, chunk_size): + n_chunks = n_chunks = total_elems // chunk_size + chunk_sizes = [chunk_size] * n_chunks + # remainder chunk + remainder = total_elems % chunk_size + if remainder != 0: + chunk_sizes.append(remainder) + return chunk_sizes + +def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size): + split_idxs = (batch_size,) + if chunk_size is not None: + chunk_sizes = get_chunk_sizes(batch_size, chunk_size) + split_idxs = tuple(itertools.accumulate(chunk_sizes)) + + flat_args_chunks = tuple( + t.tensor_split(split_idxs, dim=in_dim) if in_dim is not None else [t, ] * len(split_idxs) + for t, in_dim in zip(flat_args, flat_in_dims) + ) + + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + +def _flatten_chunks_output(chunks_output_): + # chunks_output is a list of chunked outputs + # flatten chunked outputs: + flat_chunks_output = [] + arg_spec = None + for output in chunks_output_: + flat_output, arg_specs = tree_flatten(output) + flat_chunks_output.append(flat_output) + if arg_spec is None: + arg_spec = arg_specs + + # transpose chunk dim and flatten structure + # flat_output_chunks is flat list of chunks + flat_output_chunks = list(zip(*flat_chunks_output)) + return flat_output_chunks, arg_spec + + +def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks): + # concat chunks on out_dim + flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec) + assert len(flat_out_dims) == len(flat_output_chunks) + flat_output = [] + for idx, out_dim in enumerate(flat_out_dims): + flat_output.append(torch.cat(flat_output_chunks[idx], 
dim=out_dim)) + # release tensors + flat_output_chunks[idx] = None + + return flat_output + + +# Applies vmap on chunked_input and returns concatenated output over the chunks. +def _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs): + + chunks_output = [] + rs = torch.get_rng_state() if randomness == "same" else None + for flat_args in chunks_flat_args: + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) + + # The way we compute split the input in `_get_chunked_inputs`, + # we may get a tensor with `0` batch-size. We skip any computation + # in that case. + # Eg. + # >>> chunk_size = 1 + # >>> batch_size = 6 + # >>> t = torch.zeros(batch_size, 1) + # >>> t.tensor_split([1, 2, 3, 4, 5, 6]) + # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), + # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1))) + if batch_size == 0: + continue + + if rs is not None: + torch.set_rng_state(rs) + chunks_output.append( + _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs + ) + ) + + flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output) + + # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`. + # eagerly remove the reference from `chunks_output`. + del chunks_output + + # concat chunks on out_dim + flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks) + + # finally unflatten the output + return tree_unflatten(flat_output, arg_spec) + + +# Vmap refactored helper functions: +def _check_randomness_arg(randomness): + if randomness not in ['error', 'different', 'same']: + raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}") + + +@contextlib.contextmanager +def vmap_increment_nesting(batch_size, randomness): + try: + vmap_level = _vmap_increment_nesting(batch_size, randomness) + yield vmap_level + finally: + _vmap_decrement_nesting() + + +@doesnt_support_saved_tensors_hooks +def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs): + + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec) + batched_outputs = func(*batched_inputs, **kwargs) + return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) + + +# `restore_vmap` is a private helper function. It is vmap but has the following +# differences: +# - instead of returning outputs, it returns an (outputs, out_dims) tuple. +# out_dims is a pytree of same shape as outputs and contains Optional[int] +# specifying where the vmapped dimension, if it exists, is in the corresponding output. +# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped). +# restore_vmap allows for no inputs to have the vmap dimension +# - does no validation on outputs (vmap expects only Tensor outputs) +# restore_vmap allows for return of arbitrary outputs (not just Tensors) +# +# The TL;DR is that restore_vmap is more general than vmap and has a slightly +# different API. The relaxations are so that we can "pause" vmap in the middle +# of its execution and then "restore" it later (this is what we do in +# the generate_vmap_rule=True implementation of autograd.Function). 
+# +# restore_vmap can be technically used in the implementation of vmap, but doing +# that refactor is a bit technically challenging because: +# - vmap couples the tensor-wrapping code with error checking +# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it +# in python because it overlaps with unwrap_batched +@doesnt_support_saved_tensors_hooks +def restore_vmap(func, in_dims, batch_size, randomness): + def inner(*args, **kwargs): + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = wrap_batched(args, in_dims, vmap_level) + batched_outputs = func(*batched_inputs, **kwargs) + return unwrap_batched(batched_outputs, vmap_level) + return inner + + +def wrap_batched(args, bdims, level): + flat_args, spec = tree_flatten(args) + flat_bdims = _broadcast_to_and_flatten(bdims, spec) + assert flat_bdims is not None + result = _create_batched_inputs(flat_bdims, flat_args, level, spec) + return result + + +def unwrap_batched(args, level): + flat_args, spec = tree_flatten(args) + if len(flat_args) == 0: + return args, () + result = [torch._C._functorch._unwrap_batched(arg, level) if isinstance(arg, torch.Tensor) + else (arg, None) for arg in flat_args] + output, bdims = zip(*result) + return tree_unflatten(output, spec), tree_unflatten(bdims, spec) diff --git a/MLPY/Lib/site-packages/torch/_guards.py b/MLPY/Lib/site-packages/torch/_guards.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d8cc55c0eec746678ae3fa922791676f1ae77b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_guards.py @@ -0,0 +1,879 @@ +from __future__ import annotations + +import contextlib + +import dataclasses +import enum +import functools +import logging +import threading +import traceback +import unittest.mock +import weakref +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + NamedTuple, + Optional, + Set, + Tuple, + TYPE_CHECKING, + TypeVar, +) + +import torch +from torch.utils import _pytree as pytree +from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakTensorKeyDictionary + +log = logging.getLogger(__name__) + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + + import sympy + + +""" +torch._guards is the definitional source of truth for general purpose guard structures. + +An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions, +and no guard installation notions here. +""" + + +class CompileId(NamedTuple): + frame_id: int + # This id is per-frame, and counts how many times we've compiled this + # frame. This could have been a global id but having this be per-frame + # gives you a better intuitive sense for how many recompiles have occurred + # so far. 
+ frame_compile_id: int + # TODO: consider also tracking the recompilation count + + def __str__(self): + return f"{self.frame_id}/{self.frame_compile_id}" + + +class TraceId(NamedTuple): + compile_id: CompileId + # This starts off as 0, and every time we restart analysis it goes + # up by one + attempt: int + + def __str__(self): + if self.attempt == 0: + return str(self.compile_id) + else: + return f"{self.compile_id}_{self.attempt}" + + +class GuardSource(enum.Enum): + LOCAL = 0 + GLOBAL = 1 + LOCAL_NN_MODULE = 2 + GLOBAL_NN_MODULE = 3 + CONSTANT = 4 + RANDOM_VALUE = 5 + SHAPE_ENV = 6 + LOCAL_FSDP_MODULE = 7 + GLOBAL_FSDP_MODULE = 8 + BACKWARD_STATE = 9 + EPHEMERAL = 10 + SYNTHETIC_LOCAL = 11 + + def is_fsdp_module(self) -> bool: + return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) + + def is_nn_module(self) -> bool: + return ( + self + in ( + GuardSource.GLOBAL_NN_MODULE, + GuardSource.LOCAL_NN_MODULE, + ) + or self.is_fsdp_module() + ) + + def is_local(self): + return self in ( + GuardSource.LOCAL, + GuardSource.LOCAL_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE, + ) + + +""" +Base class for a "GuardBuilder" role. + +The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little +confusing, as its not a builder, but for the sake of avoiding a lot of renames and keeping the original reference +to torchdynamo's GuardBuilder. + +Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based +on GuardSource's select function. + +There is value in keeping this GuardBuilderBase empty to keep layering clean. +""" + + +class GuardBuilderBase: + pass + + +class ShapeGuard(NamedTuple): + expr: sympy.Expr + stack: CapturedTraceback + + +@dataclasses.dataclass +class Guard: + # originating_source is the source that called the make_guard method to + # construct this guard object. The property name specifies what exactly it + # is the guard is guarding on. The meaning of the name is dependent on the + # create_fn; you must look at the use-site inside create_fn to know what + # name means. + # + # That being said, although you might think this is just a "name", name is + # usually an arbitrary Python expression that will be evaluated with all + # globals (and locals, if you create a LOCAL guard) to extract the Python + # object that we want to perform guard tests on. This evaluation + # typically happens in GuardBuilder.eval. In these cases, name is + # typically produced by originating_source.name() (not to be confused with + # GuardSource - the property source). + # + # Occasionally, name is not a valid Python expression; sometimes + # it is meaningless. Example create_fns that are like this include + # GRAD_MODE and SHAPE_ENV. + originating_source: Source + create_fn: Callable[[GuardBuilderBase, Guard], None] + + # Export only. These values are written to at time of guard check_fn creation. 
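As a small illustration of the identifiers above (values chosen arbitrarily), the string forms encode "frame/compile" plus an attempt suffix only after a restart:

from torch._guards import CompileId, TraceId

cid = CompileId(frame_id=0, frame_compile_id=2)   # frame 0, recompile counter 2
assert str(cid) == "0/2"
assert str(TraceId(cid, attempt=0)) == "0/2"      # first attempt: no suffix
assert str(TraceId(cid, attempt=1)) == "0/2_1"    # analysis restarted once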
+ guard_types: Optional[List[str]] = None + code_list: Optional[List[str]] = None + obj_weakref: Optional[object] = None + guarded_class_weakref: Optional[type] = None + + stack: Optional[CapturedTraceback] = None + user_stack: Optional[traceback.StackSummary] = None + _hash: Optional[int] = None + + def __hash__(self): + if self._hash is None: + self._hash = hash((self.name, self.source, id(self.create_fn))) + return self._hash + + def sort_key(self): + return ( + self.source.value if self.source else -1, + len(self.name), + self.name, + self.inner_create_fn().__code__.co_firstlineno, + ) + + def __lt__(self, other): + return self.sort_key() < other.sort_key() + + def inner_create_fn(self): + if isinstance(self.create_fn, functools.partial): + return self.create_fn.func + else: + return self.create_fn + + @property + def name(self) -> str: + return self.originating_source.name() + + @property + def source(self) -> GuardSource: + return self.originating_source.guard_source() + + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + + def __repr__(self): + s = f""" + {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__} + {{ + 'guard_types': {self.guard_types}, + 'code': {self.code_list}, + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} + 'guarded_class': {self.guarded_class_weakref} + }} + """ + return s + + def __str__(self): + output = f"Name: {repr(self.name)}\n" + source = self.source.name.lower() if self.source else "" + output += f" Source: {source}\n" + output += f" Create Function: {self.inner_create_fn().__name__}\n" + output += f" Guard Types: {self.guard_types}\n" + output += f" Code List: {self.code_list}\n" + output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n" + output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n" + return output + + def create(self, builder: GuardBuilderBase): + try: + return self.create_fn(builder, self) + except Exception: + log.error("Error while creating guard:\n%s", str(self).rstrip()) + if self.stack: + log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip()) + raise + + def is_nn_module(self): + return self.source.is_nn_module() + + def is_fsdp_module(self): + return self.source.is_fsdp_module() + + def is_local(self): + return self.source.is_local() + + def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref): + if not self.guard_types: + self.guard_types = list() + + self.guard_types.append(guard_type) + + assert self.guarded_class_weakref in ( + guarded_class, + None, + ), "Guarded class id must be identical, or None" + self.guarded_class_weakref = guarded_class + + if not self.code_list: + self.code_list = code_list + else: + self.code_list.extend(code_list) + + assert self.obj_weakref in ( + obj_weakref, + None, + ), "Guarded object must be identical, or None" + self.obj_weakref = obj_weakref + + +T = TypeVar("T") + +""" +Parent structure for guard env expressions. 
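For illustration only, a toy Source subclass (hypothetical; the real subclasses live in torch._dynamo.source) showing how a Guard derives its name and source from originating_source:

from torch._guards import GuardSource, Source

class ToyLocalSource(Source):           # hypothetical stand-in for dynamo's LocalSource
    def name(self) -> str:
        return "L['x']"

    def guard_source(self) -> GuardSource:
        return GuardSource.LOCAL

def toy_check(builder, guard):          # stand-in create_fn; real ones are GuardBuilder methods
    pass

g = ToyLocalSource().make_guard(toy_check)
assert g.name == "L['x']"               # comes from originating_source.name()
assert g.source is GuardSource.LOCAL    # comes from originating_source.guard_source()
assert g.is_local() and not g.is_nn_module()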
+A GuardEnvExpr can have any subtype. +Note: All subtypes must be handled exhaustively in +torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError. +""" + + +@dataclasses.dataclass +class GuardEnvExpr: + pass + + +""" +A class representing a pair of duplicate inputs. +input_pos_a and input_pos_b are input positions we have deduped. +""" + + +@dataclasses.dataclass +class DuplicateInputs(GuardEnvExpr): + input_source_a: Source + input_source_b: Source + + def __post_init__(self): + assert self.input_source_a != self.input_source_b + + +""" +Checkpointable is an interface for driving state snapshotting, left purposely vague for now. + +copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that +can also be taken in at restore_graphstate(T) calls. + +When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable +does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet. + +In the future, it will have a closer coupling to a generic Checkpoint management system. +""" + + +class Checkpointable(ABC, Generic[T]): + @abstractmethod + def copy_graphstate(self) -> T: + ... + + @abstractmethod + def restore_graphstate(self, state: T): + ... + + +class GuardsCheckpointState: + """ + The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext + """ + + dynamo_guards: Set[Guard] = set() + + def __init__(self, dynamo_guards): + self.dynamo_guards = dynamo_guards + + def diff(self, other): + """ + Produces a delta against another GuardsCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + Guard type objects. + """ + r = self.dynamo_guards.difference(other.dynamo_guards) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class ModuleContextCheckpointState: + nn_modules: Dict[str, torch.nn.Module] = {} + + def __init__(self, nn_modules): + self.nn_modules = nn_modules + + def diff(self, other): + """ + Produces a delta against another ModuleContextCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + module key names. + """ + r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys())) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class ModuleContext(Checkpointable[ModuleContextCheckpointState]): + def __init__(self): + self.nn_modules: Dict[str, Any] = {} + + def copy_graphstate(self): + return ModuleContextCheckpointState(dict(self.nn_modules)) + + def restore_graphstate(self, state): + assert isinstance(state, ModuleContextCheckpointState) + self.nn_modules = state.nn_modules + + +class GlobalContextCheckpointState: + global_state: Dict[str, Tuple[Callable, ...]] = {} + + def __init__(self, global_states): + self.global_state = global_states + + def diff(self, other): + """ + Produces a delta against another GlobalContextCheckpointState. + + Returns None if no delta is found, otherwise, return a set() of mismatched + global key names. + """ + r = set(self.global_state.keys()).difference(set(other.global_state.keys())) + if len(r) == 0: + return None + return r + + def __eq__(self, other): + return self.diff(other) is None + + +class GlobalContext(Checkpointable[GlobalContextCheckpointState]): + """ + This keeps track of the global torch state during tracing of a function. + For example, torch.is_grad_enabled. 
+ """ + + _supported_global_states = { + "grad_enabled", + "torch_function_enabled", + "autocast_enabled", + "autocast_cpu_enabled", + "autocast_gpu_dtype", + "autocast_cpu_dtype", + "autocast_cache_enabled", + } + + def __init__(self): + self.global_state: Dict[str, Tuple[Callable, ...]] = {} + + def copy_graphstate(self): + return GlobalContextCheckpointState(dict(self.global_state)) + + def restore_graphstate(self, state): + assert isinstance(state, GlobalContextCheckpointState) + self.global_state = state.global_state + assert ( + len(self.global_state) == len(self._supported_global_states) + and set(self.global_state.keys()) == self._supported_global_states + ), "Global state mismatch" + for func, args in self.global_state.values(): + func(args) + + +""" +A GuardsContext is a checkpointable representation of all the guards in the current tracing +context. It's lifecycle is bound 1:1 to the tracing context, and it should never be instantiated +directly outside of it. For passing around internal state representations of this object, +prefer to extract them with copy_graphstate to produce a GuardsCheckpointState. +""" + + +# Like a Set[Guard] but will record the user stack on all guards at the +# time they were installed at their destination +class GuardsSet: + def __init__(self, inner=None): + if inner is None: + inner = set() + self.inner = inner + + def __iter__(self): + return iter(self.inner) + + def __len__(self): + return len(self.inner) + + # Subtraction along with bool is typically used to determine the delta of + # added guards between checkpoints for higher order ops + def __sub__(self, other): + return GuardsSet(self.inner - other.inner) + + def __bool__(self): + return bool(self.inner) + + def add(self, guard: Guard, *, collect_debug_stack=True, skip=0): + if guard in self.inner: + return + if collect_debug_stack: + if guard.stack is None: + guard.stack = CapturedTraceback.extract(skip=1 + skip) + if guard.user_stack is None: + guard.user_stack = TracingContext.extract_stack() + self.inner.add(guard) + + def update(self, *others: Set[Guard]): + for o in others: + for g in o: + self.add(g, skip=1) + + def remove_guards_with_source(self, source): + """Delete all guards with a given source""" + self.inner = {g for g in self.inner if g.originating_source != source} + + +class GuardsContext(Checkpointable[GuardsCheckpointState]): + def __init__(self): + self.dynamo_guards: GuardsSet = GuardsSet() + self.aotautograd_guards: List[GuardEnvExpr] = [] + + def copy_graphstate(self): + return GuardsCheckpointState(set(self.dynamo_guards.inner)) + + def restore_graphstate(self, state): + # NB: "steals" the passed in state + assert isinstance(state, GuardsCheckpointState) + self.dynamo_guards = GuardsSet(state.dynamo_guards) + + +_TLS = threading.local() + +""" +TracingContext is the source of truth for all currently accumulated information +needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems +are open to managing their own TracingContext with that in mind. + +The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid +having to plumb complex subsystems across multiple verticals. + +Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor. +Accessing the current tracing context via +TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how +to plumb objects back up to where frame interpretation happened. 
+ +Note that you can end up with multiple TracingContext for a single compilation +of a frame, as we reset the TracingContext whenever we restart analysis. +CompileContext is a more overarching context that encompasses multiple restarts. +""" + + +class CompileContext: + @staticmethod + def get() -> CompileContext: + assert _TLS.compile_context is not None + return _TLS.compile_context + + @staticmethod + def try_get() -> Optional[CompileContext]: + return getattr(_TLS, "compile_context", None) + + def __init__(self, compile_id): + assert compile_id is None or isinstance(compile_id, CompileId) + self.compile_id: Optional[CompileId] = compile_id + self.attempt = 0 + + @staticmethod + def current_compile_id(): + self = CompileContext.try_get() + if self is None: + return None + return self.compile_id + + @staticmethod + def current_trace_id(): + self = CompileContext.try_get() + if self is None: + return None + if self.compile_id is None: + return None + return TraceId(self.compile_id, self.attempt) + + +class TracingContext: + """ + Provides the currently installed TracingContext, or None. + + Note that it is a staticmethod, and invocations outside of `with tracing()` (see below), are valid but + will return None. + """ + + @staticmethod + def try_get() -> Optional[TracingContext]: + return getattr(_TLS, "tracing_context", None) + + @staticmethod + def get() -> TracingContext: + if ctx := TracingContext.try_get(): + return ctx + raise RuntimeError( + "TracingContext.get() must be called within an ongoing trace." + ) + + def __init__(self, fake_mode): + self.guards_context = GuardsContext() + self.module_context = ModuleContext() + self.global_context = GlobalContext() + self.fake_mode = fake_mode + self.frame_summary_stack = [] + # This is morally part of frame_summary_stack, but it is kept separate + # for clarity. As we process a frame, this variable gets updated + # to keep track of what line we are in the function. We make a + # function call, this gets cleared and the frame location is pushed + # to frame_summary_stack (prepping this variable for the inner frame's + # progress) + self.loc_in_frame = None + # this is only set after aot_autograd + self.fw_metadata = None + self.params_flat = None + # this is for extended return calling convention from backend + # compiler to aot_autograd + # Per output, what the compiler specified stride of the output is, + # or None if no stride is known. This is always the HINT, it + # is never a SymInt (it would be better if it was a SymInt, but + # I can't conveniently get this from Inductor atm. Also, be + # careful not to accidentally induce guards on the SymInt if + # you ever do change this in aot_autograd.py; you should check + # on permutations preferentially.) + self.output_strides: Optional[List[Optional[List[int]]]] = None + # When this is True, whenever we encounter an int in Dynamo tracing, + # we will (1) force unspec it and (2) force it as a size-like unbacked + # integer. This is currently used when processing certain lists of + # ints that are known to be size-like and may have 0/1 entries that we + # must not specialize on. 
+ self.force_unspec_int_unbacked_size_like = False + # See note [Tensor Fakification and Symbol Caching] + self.tensor_to_context = WeakTensorKeyDictionary() + + # If this true, Aot Autograd will return output Fake Tensors with appropiate + # meta on the first invocation + # see note: [Returning Fake Tensors on First AOT Autograd Call] + self.fakify_first_call = False + + def clear(self): + # Look at the note in output_graph.py in function `save_global_state` + # for the context on clearing global context. + self.global_context.global_state = {} + + @staticmethod + @contextmanager + def patch(**kwargs): + prior = {} + ctx = TracingContext.get() + + for key in kwargs.keys(): + # KeyError on invalid entry + prior[key] = getattr(ctx, key) + for key, val in kwargs.items(): + setattr(ctx, key, val) + try: + yield + finally: + for key, val in prior.items(): + setattr(ctx, key, val) + + @staticmethod + def extract_stack(): + self = TracingContext.try_get() + if self is None: + return traceback.StackSummary() + stack = self.frame_summary_stack + if self.loc_in_frame is not None: + stack = stack + [self.loc_in_frame] + return traceback.StackSummary.from_list(stack) + + # Call this when you want to call into some code that isn't necessarily + # associated with the current frame state + @staticmethod + @contextlib.contextmanager + def clear_frame(): + tc = TracingContext.get() + with unittest.mock.patch.object( + tc, "frame_summary_stack", [] + ), unittest.mock.patch.object(tc, "loc_in_frame", None): + try: + yield + except Exception as e: + # Prevent real_stack from getting attached + # + # The invariant is that if an Exception as real_stack, we've + # appropriately attached a user stack and we no longer need to + # attach anything. Because we cannot conveniently interpose + # when an exception is thrown, we instead interpose everywhere + # we set what the user stack is set (using the context + # manager). However, our compiler stack does "tail calls" + # (when it calls into user compiler), at which point the + # parent exception frames would incorrectly attach an + # incorrect frame. + # + # However, if, somehow, someone raised an exception with this + # scope that had a stack (for example, because they are + # restoring the user stack state appropriately as they process + # node by node), we should respect it. Thus, we cannot + # unconditionally set None. 
+ if not hasattr(e, "real_stack"): + e.real_stack = None # type: ignore[attr-defined] + raise + + @staticmethod + @contextlib.contextmanager + def current_frame(frame_summary): + # frame_summary can be None to solely take advantage of real_stack + # attachment to thrown exceptions + tc = TracingContext.get() + if frame_summary is not None: + tc.frame_summary_stack.append(frame_summary) + old = tc.loc_in_frame + tc.loc_in_frame = None + try: + yield + except Exception as e: + if not hasattr(e, "real_stack"): + e.real_stack = tc.extract_stack() # type: ignore[attr-defined] + raise + finally: + if frame_summary is not None: + tc.frame_summary_stack.pop() + tc.loc_in_frame = old + + @staticmethod + @contextlib.contextmanager + def report_output_strides(): + tc = TracingContext.try_get() + if tc is None: + yield None + return + old_output_strides = tc.output_strides + tc.output_strides = [] + try: + yield tc.output_strides + finally: + tc.output_strides = old_output_strides + + @staticmethod + def set_current_loc(filename, lineno, frame_name): + TracingContext.get().loc_in_frame = traceback.FrameSummary( + filename, lineno, frame_name + ) + + +@contextmanager +def compile_context(context: CompileContext): + old_context = getattr(_TLS, "compile_context", None) + _TLS.compile_context = context + try: + yield context + finally: + _TLS.compile_context = old_context + + +@contextmanager +def tracing(context: Optional[TracingContext]): + """ + This function installs the passed in tracing context as a dynamic scoped + global variable. + + Calls to TracingContext.get() while not under a `with tracing()` context + will return None. + """ + old_context = getattr(_TLS, "tracing_context", None) + _TLS.tracing_context = context + try: + yield context + except Exception as e: + if not hasattr(e, "real_stack") and context is not None: + e.real_stack = context.extract_stack() # type: ignore[attr-defined] + raise + finally: + if ( + context is not None + and context.fake_mode is not None + and context.fake_mode.shape_env is not None + ): + context.fake_mode.shape_env.cleanup() + _TLS.tracing_context = old_context + + +# Subclasses can be found in torch/_dynamo/source.py +# TODO(voz): Consider a toplevel torch/_source.py +@dataclasses.dataclass(frozen=True) +class Source: + def is_dict_key(self): + return False + + def is_ephemeral(self): + return False + + def reconstruct(self, codegen): + raise NotImplementedError() + + def guard_source(self) -> GuardSource: + raise NotImplementedError() + + def name(self) -> str: + raise NotImplementedError() + + def make_guard(self, fn) -> Guard: + if self.guard_source() is GuardSource.CONSTANT: + raise NotImplementedError() + return Guard(self, fn) + + def is_nn_module(self) -> bool: + return self.guard_source().is_nn_module() + + def subguards_allowed(self): + """True if you can guard on attributes of this""" + return self.guard_source() != GuardSource.SYNTHETIC_LOCAL + + +# Subclasses can be found in torch/_dynamo/source.py +@dataclasses.dataclass(frozen=True) +class ChainedSource(Source): + base: Source + + def is_dict_key(self): + # Recurse until you either hit a ConstDictKey or a Source + return self.base.is_dict_key() + + def is_ephemeral(self): + return self.base.is_ephemeral() + + +def detect_fake_mode(inputs: Any = None): + """ + Attempts to "detect" what the current fake mode is. If there is one ambiently + available from TracingContext, we preferentially use that. 
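A minimal sketch of installing a TracingContext for a block of work, assuming only the pieces defined in this file plus FakeTensorMode:

import torch
from torch._guards import TracingContext, tracing
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with tracing(TracingContext(fake_mode)):
    # Anywhere under this block, TracingContext.get() resolves to the same
    # object, e.g. to accumulate guards or read the ambient fake mode.
    assert TracingContext.get().fake_mode is fake_mode

# Outside the block, try_get() degrades gracefully to None.
assert TracingContext.try_get() is None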
Otherwise, we + heuristically detect the fake mode via the following sources, in order of + priority: + + - Currently active fake mode on stack + - Fake mode associated with passed in tensors (inputs does not + have to be flattened) + """ + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + fake_modes = [] + + if context := TracingContext.try_get(): + fake_mode = context.fake_mode + if fake_mode is not None: + fake_modes.append((fake_mode, "tracing context", 0)) + + from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + + for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())): + if isinstance(m, FakeTensorMode): + fake_modes.append((m, "active fake mode", i)) + + flat_inputs = pytree.tree_leaves(inputs) + for i, flat_input in enumerate(flat_inputs): + if isinstance(flat_input, FakeTensor): + fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) + + if fake_modes: + fake_mode, desc1, i1 = fake_modes[0] + for m, desc2, i2 in fake_modes[1:]: + assert fake_mode is m, ( + f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n" + f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n" + f"fake mode from {desc2} {i2} allocated at:\n{m.stack}" + ) + return fake_mode + else: + return None + + +def active_fake_mode(): + """ + Inspects the dispatch mode stack for an active fake mode and returns it. + Returns None if no fake mode is active. + """ + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + + for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())): + if isinstance(m, FakeTensorMode): + return m + + return None diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__init__.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..495dfee3daef493bfc660f86ef35684ebe1ffb96 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/__init__.py @@ -0,0 +1 @@ +from .cond import cond diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b44c3453ef0b5cf4ad49bed85131dd15ab6203 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3f0bc470864d4b958acdc3d6716da65e8b5dfd Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50f32b9b59ecfea401079f63bc4e3aa78fb8396d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc new file mode 100644 
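A small sketch of detect_fake_mode's input-based fallback (no ambient TracingContext and nothing on the dispatch-mode stack in this example):

import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
fake_x = fake_mode.from_tensor(torch.randn(2, 2))

# The mode is recovered from the fake tensor inputs themselves.
assert detect_fake_mode([fake_x]) is fake_mode
assert detect_fake_mode(torch.randn(2, 2)) is None   # plain tensors carry no fake mode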
index 0000000000000000000000000000000000000000..bd256f5f70859f52b6aa0b573e6136a4e900e0ea Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..098c82dee9c4ffac5d979b07ff5c62ae03a0bead Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72334da357e1f1478e870c87088beb05de6b5603 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..033f7d0f2433ed9feec15912f0428d517cc001fa Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78fef3654e3168dbf745b265ee1fa52985afd724 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb51c4f6d573c9b037c463625617090f482e5e92 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2fb118c719a4ebe3e4ce939280978e1a37fed4c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c81a8f0af7aad1e005863aff178a78c8841dc39e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42a1d81ca48cd19b489e87fb36a777683f11d005 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py new file mode 100644 index 0000000000000000000000000000000000000000..da7f441c309b379be7e309ad792c0256cec4df14 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py @@ -0,0 +1,261 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +# NOTE: [auto-functionalizing custom ops] +# Users may wish to torch.compile custom ops that mutate their inputs. +# torch.compile will automatically support this op without anyone needing +# to provide a functionalization kernel for it. Here's how. +# +# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> () +# op. First, when FakeTensor sees this op: +# - If the schema says it returns nothing, we can generate a trivial +# FakeTensor rule for it (that returns nothing). +# - Otherwise, the user needs to provide a FakeTensor rule (abstract impl) +# +# Next, when Python FunctionalTensor sees the op, it will functionalize +# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...}) +# HOP and replacing the mutated inputs with corresponding outputs of this HOP. +# This HOP effectively runs the functional version of the op when +# called: it clones inputs that will be mutated, runs the op, and +# then returns (output, Tensors with the new values) + + +class AutoFunctionalized(HigherOrderOperator): + """auto_functionalized(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + + Concretely, it looks at all the arguments that are mutable through + _mutable_op's operator schema, clones those kwargs, runs + `out = _mutable_op(**kwargs)` with the cloned values, and then returns the + operator output concatenated with the cloned values that were mutated. + + We have some restrictions on `_mutable_op`. + See `can_auto_functionalize` for the restrictions. We can likely lift + many of these if users request it. + + The reason why _mutable_op is prefixed with an + underscore is to prevent collisions with kwarg names in **kwargs. + """ + + def __init__(self): + super().__init__("auto_functionalized") + + def __call__( + self, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized = AutoFunctionalized() + + +def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool: + if not isinstance(op, torch._ops.OpOverload): + return False + + if torch._library.utils.is_builtin(op): + # We control the built-ins. 
These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + if not schema.is_mutable: + return False + schema = op._schema + + for arg in schema.arguments: + if arg.alias_info is None: + continue + if not arg.alias_info.is_write: + continue + if type(arg.type) is torch.TensorType: + continue + if ( + type(arg.type) is torch.OptionalType + and type(arg.type.getElementType()) is torch.TensorType + ): + continue + # Not yet supported: other Tensor types. This includes things like + # Tensor[], Tensor?[], Tensor[]?. + return False + + # The returns must not alias anything + for ret in schema.returns: + if ret.alias_info is None and type(ret.type) is torch.TensorType: + continue + # Not yet supported: List[Tensor] return. + return False + return True + + +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: torch._ops.OpOverload, + _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names = get_mutable_arg_names(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: torch._ops.OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + if not mode.enable_tracing: + return auto_functionalized(_mutable_op, **kwargs) + + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +def get_mutable_arg_names(op: torch._ops.OpOverload) -> List[str]: + """ + Returns the list of argument names that get mutated according to the + schema. + """ + mutable_args_names = [ + arg.name + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names + + +def do_auto_functionalize( + op: torch._ops.OpOverload, args: Tuple[Any, ...], kwargs: Dict[str, Any] +) -> Any: + """Functionalizes a call to op(*args, **kwargs) by emitting a call to + `outs = auto_functionalized(op, normalized_kwargs)` + and replacing the mutated (args, kwargs) with the corresponding outputs. + + The normalized_kwargs are just the (args, kwargs), but all in kwarg form. + This makes handling easier for the auto_functionalized HOP. 
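As a hedged sketch of the kind of op this machinery targets, a throwaway mutable custom op (namespace and names invented for illustration) that passes can_auto_functionalize:

import torch
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

lib = torch.library.Library("mylib_demo", "FRAGMENT")   # illustrative namespace
lib.define("inc_(Tensor(a!) x) -> ()")                  # mutates x in place, returns nothing

def inc_impl(x):
    x.add_(1)

lib.impl("inc_", inc_impl, "CPU")

op = torch.ops.mylib_demo.inc_.default
# Mutable tensor argument, no aliasing returns, not a builtin op -> eligible
# to be rewritten as auto_functionalized(op, x=...) during functionalization.
assert can_auto_functionalize(op)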
+ """ + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized( + op, **unwrapped_kwargs # type: ignore[arg-type] + ) + + # List of the name of args that get mutated (according to the schema) + mutable_args_names = get_mutable_arg_names(op) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[ + : -len(mutable_args_names) + ] + unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :] + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + assert isinstance(unwrapped_out, torch.Tensor) + orig_arg = normalized_kwargs[name] + ctx.replace(orig_arg, unwrapped_out) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +@auto_functionalized.py_functionalize_impl +def auto_functionalized_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/cond.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/cond.py new file mode 100644 index 0000000000000000000000000000000000000000..000db491f82cecc0705d61b34e3e146563ec11c9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/cond.py @@ -0,0 +1,349 @@ +import torch +import torch._subclasses.functional_tensor + +import torch.utils._pytree as pytree + +from torch._C import DispatchKey +from torch._C._functorch import ( + _add_batch_dim, + get_unwrapped, + is_batchedtensor, + maybe_get_bdim, +) +from torch._functorch.utils import exposed_in + +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) + +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def cond(pred, true_fn, false_fn, 
operands): + r""" + Conditionally applies `true_fn` or `false_fn`. + + .. warning:: + `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `cond` is structured control flow operator. That is, it is like a Python if-statement, + but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be + capturable using torch.compile and torch.export. + + Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following:: + + def cond(pred, true_branch, false_branch, operands): + if pred: + return true_branch(*operands) + else: + return false_branch(*operands) + + Args: + pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element, + indicating which branch function to apply. + + true_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. + + false_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. The true branch and false branch must + have consistent input and outputs, meaning the inputs have to be + the same, and the outputs have to be the same type and shape. + + operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions. + + Example:: + + def true_fn(x: torch.Tensor): + return x.cos() + def false_fn(x: torch.Tensor): + return x.sin() + return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) + + Restrictions: + - The conditional statement (aka `pred`) must meet one of the following constraints: + + - It's a `torch.Tensor` with only one element, and torch.bool dtype + + - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10` + + - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints: + + - The function signature must match with operands. + + - The function must return a tensor with the same metadata, e.g. shape, + dtype, etc. + + - The function cannot have in-place mutations on inputs or global variables. + (Note: in-place tensor operations such as `add_` for intermediate results + are allowed in a branch) + + .. warning:: + Temporal Limitations: + + - `cond` only supports **inference** right now. Autograd will be supported in the future. + + - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future. + + """ + + if torch.compiler.is_dynamo_compiling(): + return cond_op(pred, true_fn, false_fn, operands) + + def _validate_input(pred, true_fn, false_fn, operands): + if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)): + raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.") + + if isinstance(pred, torch.Tensor) and pred.numel() != 1: + raise RuntimeError( + f"Expected pred to be bool or single-element tensor, but got {pred}." + ) + + if not callable(true_fn) or not callable(false_fn): + raise RuntimeError("Expect both branches to be callbale.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only" + f"consists of tensor leaves, but got {operands}." 
+ ) + + _validate_input(pred, true_fn, false_fn, operands) + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("torch.cond requires dynamo support.") + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile(cond_op, backend="eager", fullgraph=True)( + pred, true_fn, false_fn, operands + ) + + +""" +We're going to define a `cond_op` operation. +In order to do this, we need implementations for each of the dispatch keys. +""" +cond_op = HigherOrderOperator("cond") + + +def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): + assert isinstance( + operands, (list, tuple) + ), "Cond operands must be a list or tuple of tensors" + assert all( + isinstance(o, torch.Tensor) for o in operands + ), "Cond operands must be a list of tensors" + + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands) + false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands) + + true_outs = [] + false_outs = [] + for node in true_graph.graph.nodes: + if node.op == "output": + true_outs.extend(node.args) + + for node in false_graph.graph.nodes: + if node.op == "output": + false_outs.extend(node.args) + + flat_true_outs = pytree.arg_tree_leaves(*true_outs) + flat_false_outs = pytree.arg_tree_leaves(*false_outs) + if len(flat_true_outs) != len(flat_false_outs): + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected to return same number of outputs but got:" + f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)" + f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)" + ) + + for i in range(0, len(flat_true_outs)): + true_out = flat_true_outs[i] + false_out = flat_false_outs[i] + if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]: + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}" + f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}" + ) + + # There are probably better ways - I know that create_arg has some self incrementing name + # magic to it, but since we explicitly have to get the name for register_module, + # I was not sure how to do that. This kinda simulates it. + next_name = None + i = 0 + while not next_name: + candidate = f"true_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + true_name = next_name + false_name = f"false_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, false_name) + + proxy_mode.tracer.root.register_module(true_name, true_graph) + proxy_mode.tracer.root.register_module(false_name, false_graph) + + args = (pred, true_graph, false_graph, operands) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="conditional" + ) + + # At this point, we're *guaranteed* that whether an output came from the + # true or false branch is indistinguishable. So, as this is just for tracing + # purposes, choose the true branch. + + # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in + # a FakeTensorMode error : + # `Current active mode not registered` + # TODO Sometimes the operands are not completely FakeTensor, something seems went wrong in + # dynamo? 
Because of that it runs real computation sometimes and re-triggering downstream dispatch keys. + out = false_fn(*operands) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def cond_op_dense(pred, true_fn, false_fn, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + +cond_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(cond_op, deferred_error=True) +) + + +@cond_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, pred, true_fn, false_fn, operands): + if mode.enable_tracing: + return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands) + else: + return cond_op(pred, true_fn, false_fn, operands) + + +@cond_op.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): + with mode: + true_outs = true_fn(*operands) + flat_true_outs = pytree.tree_leaves(true_outs) + flat_false_outs = pytree.tree_leaves(false_fn(*operands)) + if len(flat_true_outs) != len(flat_false_outs): + raise RuntimeError("Unmatched number of outputs from cond() branches.") + + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + true_meta = _extract_tensor_metadata(true_out) + false_meta = _extract_tensor_metadata(false_out) + if true_meta != false_meta: + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected each tensor to have same metadata but got:" + f"\n {true_fn.__name__} returns {true_meta}" + f"\n {false_fn.__name__} returns {false_meta}" + ) + return true_outs + + +@cond_op.py_functionalize_impl +def cond_func(ctx, pred, true_fn, false_fn, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + unwrapped_pred = ctx.unwrap_tensors(pred) + with ctx.redispatch_to_next() as m: + functional_true = ctx.functionalize(true_fn) + functional_false = ctx.functionalize(false_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for branch in [functional_true, functional_false]: + if _has_potential_branch_input_mutation( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be modifying the input!" + ) + for branch in [true_fn, false_fn]: + if _has_potential_branch_input_alias( + branch, unwrapped_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "One of torch.cond branch might be aliasing the input!" 
+ ) + + cond_return = cond_op( + unwrapped_pred, functional_true, functional_false, unwrapped_inputs + ) + return ctx.wrap_tensors(cond_return) + + +@cond_op.py_impl(torch._C._functorch.TransformType.Vmap) +def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): + assert isinstance( + inputs, (list, tuple) + ), "Cond inputs must be a list or tuple of tensors" + assert all( + isinstance(i, torch.Tensor) for i in inputs + ), "Cond inputs must be a list of tensors" + + pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred + + # unbatched tensors are not vmapped + tensors, in_dims = zip( + *[ + (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None) + for t in inputs + ] + ) + + if is_batchedtensor(pred): + # prepend "pred" and vmap everything + tensors = (pred_,) + tensors + in_dims = (0,) + in_dims + + def fn(p, *args): + t = true_fn(*args) + f = false_fn(*args) + return torch.where(p, t[0], f[0]) + + with interpreter.lower(): + result = torch.vmap(fn, in_dims=in_dims)(*tensors) + + else: + # predicate is known at this stage and it is a boolean expression or a + # tensor with one element. + true_fn = torch.vmap(true_fn, in_dims=in_dims) + false_fn = torch.vmap(false_fn, in_dims=in_dims) + + with interpreter.lower(): + result = cond_op(pred, true_fn, false_fn, tensors) + + if not isinstance(result, tuple): + result = (result,) + lvl = interpreter.level() + return tuple([_add_batch_dim(r, 0, lvl) for r in result]) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/effects.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..c47ba873970465be0acd7df86e5afc09773ec88b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/effects.py @@ -0,0 +1,204 @@ +from enum import Enum +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class _EffectType(Enum): + ORDERED = "Ordered" + + +SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = { + torch.ops.aten._print.default: _EffectType.ORDERED, +} + + +class WithEffects(HigherOrderOperator): + """ + with_effects(token, op, args, kwargs) -> (new_token, op_results) + + This HOP helps ensure ordering between side effectful ops like prints or ops + using torchbind objects. This is needed to ensure a traced graph from + AOTAutograd is functional so that future optimization passes do not reorder + these operators. This is done through threading "effect tokens" through the + graph to enforce data dependence between side effectful ops. + + The tokens are basically dummy values (torch.tensor([])). We create a token + per "effect type", which are enumerated in the _EffectType enum. 
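A small illustration of how effect keys are derived for ordering purposes, using the one entry registered above (aten._print) and an ordinary pure op for contrast:

import torch
from torch._higher_order_ops.effects import _EffectType, get_effect_key, has_effects

print_op = torch.ops.aten._print.default
assert has_effects(print_op, ("hi",), {})
assert get_effect_key(print_op, ("hi",), {}) is _EffectType.ORDERED

# A pure op with no script-object arguments has no effect key and is not
# token-threaded.
add_op = torch.ops.aten.add.Tensor
assert get_effect_key(add_op, (torch.ones(1), torch.ones(1)), {}) is None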
+ """ + + def __init__(self): + super().__init__("with_effects") + + def __call__( + self, + token, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -> Tuple[Any, ...]: + assert isinstance(op, torch._ops.OpOverload) + assert not has_aliasing(op), "Ops with aliasing is not supported" + assert has_effects(op, args, kwargs) + assert isinstance(kwargs, dict) + return super().__call__(token, op, *args, **kwargs) + + +with_effects = WithEffects() + + +def has_aliasing(op: torch._ops.OpOverload): + for arg in op._schema.arguments: + if arg.alias_info is not None: + return True + for arg in op._schema.returns: + if arg.alias_info is not None: + return True + return False + + +def has_effects(op, args, kwargs) -> bool: + return ( + isinstance(op, torch._ops.OpOverload) + and not has_aliasing(op) + and get_effect_key(op, args, kwargs) is not None + ) + + +def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: + if op in SIDE_EFFECTS: + return SIDE_EFFECTS[op] + + for arg in args: + if isinstance(arg, torch.ScriptObject): + return _EffectType.ORDERED + + return None + + +@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) +def with_effects_dense( + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + out = op(*args, **kwargs) + new_token = torch.tensor([]) + if isinstance(out, tuple): + return (new_token, *out) + return (new_token, out) + + +@with_effects.py_impl(FakeTensorMode) +def with_effects_fake( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + with mode: + result = with_effects_dense(token, op, *args, **kwargs) + return result + + +@with_effects.py_impl(ProxyTorchDispatchMode) +def with_effects_proxy( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Tuple[torch.Tensor, ...]: + if not mode.enable_tracing: + return with_effects(token, op, *args, **kwargs) + + with disable_proxy_modes_tracing(): + out = with_effects(token, op, *args, **kwargs) + + proxy_token = mode.tracer.unwrap_proxy(token) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + with_effects, + (proxy_token, op, *proxy_args), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +with_effects.fallthrough(DispatchKey.AutogradCPU) +with_effects.fallthrough(DispatchKey.AutogradCUDA) + + +def handle_effects( + allow_token_discovery: bool, + tokens: Dict[_EffectType, torch.Tensor], + op: torch._ops.OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + """ + Args: + allow_token_discovery: Whether or not we are discovering tokens. If this + is true, we will create a token for every side effect type seen that + does not have a token assigned yet. If this is false, the tokens + should've all been created ahead of time, so we will error if there is + no token mapping to every effect type. + + tokens: Map of effect type to tokens. This is to chain operators of the + same effects together so that they do not get reordered in later + optimization passes. + """ + + # Get a token. 
We can't do `tokens.get(op, torch.tensor([]))` because + # this will create an empty tensor during proxy mode tracing if the token + # doesn't exist. But the tokens should always exist during proxy mode tracing. + key = get_effect_key(op, args, kwargs) + assert key is not None + if key not in tokens: + assert allow_token_discovery, f"Could not find a token for effect {key}" + tokens[key] = torch.tensor([]) + token = tokens[key] + + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type] + unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type] + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + (new_token, *unwrapped_outs) = with_effects( + unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs # type: ignore[arg-type] + ) + + if len(op._schema.returns) == 0: + assert unwrapped_outs[0] is None + unwrapped_outs = None # type: ignore[assignment] + elif len(op._schema.returns) == 1: + assert len(unwrapped_outs) == 1 + unwrapped_outs = unwrapped_outs[0] + else: + assert len(unwrapped_outs) == len(op._schema.returns) + + # Add the newly created token into the tokens map for a following call to + # use this token. + wrapped_token = ctx.wrap_tensors(new_token) + assert isinstance(wrapped_token, torch.Tensor) + tokens[key] = wrapped_token + + return ctx.wrap_tensors(unwrapped_outs) # type: ignore[arg-type] diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/map.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/map.py new file mode 100644 index 0000000000000000000000000000000000000000..07f89ea23c90f41c02051eada86fc771f0bab221 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/map.py @@ -0,0 +1,358 @@ +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun + +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import ( + disable_functional_mode, + FunctionalTensor, +) +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.multiprocessing.reductions import StorageWeakRef + + +# TODO: We add this to prevent dymamo from tracing into map_wrapper, +# remove the wrapper call when it's ready. +class MapWrapper(HigherOrderOperator): + def __call__(self, xs, *args): + return map_wrapper(xs, *args) + + +map = MapWrapper("map") +map_impl = HigherOrderOperator("map_impl") + +dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, +) + + +def create_fw_bw_graph(f, num_mapped_args, *args): + mapped_xs = args[:num_mapped_args] + pos_args = args[num_mapped_args:] + + # Note: We create "clean" environments for make_fx by suspending all dispatch keys + # between Autograd and Python key. 
Currently, we only suspend functionalization but more can be
+    # added when required. We will encounter two problems if we don't suspend functionalization:
+    #
+    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
+    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
+    # However, it's the outer wrapper that the tracer creates proxies for. This causes the tracer to fail to
+    # fetch the proxy for the inputs and to fail to capture any operations on them.
+    #
+    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
+    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
+    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
+    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
+    # Instead, it will create a _tensor_constant as output.
+
+    with suspend_functionalization(), disable_functional_mode():
+        with disable_proxy_modes_tracing():
+
+            def _from_fun(t):
+                if isinstance(t, torch.Tensor):
+                    if t.dtype != torch.bool:
+                        return torch.empty_strided(
+                            t.size(),
+                            t.stride(),
+                            dtype=t.dtype,
+                            requires_grad=t.requires_grad,
+                        )
+                    else:
+                        # clone of a functional tensor produces a functional tensor,
+                        # but we want to avoid that, so we clone a non-functional version
+                        maybe_unfunc_t = t
+                        if isinstance(t, FunctionalTensor):
+                            torch._sync(t)
+                            maybe_unfunc_t = from_fun(t)
+                        elif torch._is_functional_tensor(t):
+                            # need to handle both types of functionalization here:
+                            # these are the tensors that came from the user,
+                            # which could be either FunctionalTensorWrapper or FunctionalTensor
+                            torch._sync(t)
+                            maybe_unfunc_t = torch._from_functional_tensor(t)
+                        return maybe_unfunc_t.clone()
+                return t
+
+            unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
+            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
+
+            example_pos_args = [
+                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
+                for arg in pos_args
+            ]
+            example_flat_out = pytree.tree_map(
+                _from_fun, f(*example_xs, *example_pos_args)
+            )
+            if any(
+                not isinstance(out, torch.Tensor)
+                for out in example_flat_out
+                if out is not None
+            ):
+                raise RuntimeError(
+                    "Expect outputs of map to only contain tensors or None. "
+                    f"Got types {[type(out) for out in example_flat_out]}."
+ ) + example_grad = [_from_fun(out) for out in example_flat_out] + + fw_graph = make_fx(f)(*example_xs, *example_pos_args) + + def joint_f(*example_args): + joint_mapped_args = example_args[:joint_num_mapped] + args = example_args[joint_num_mapped:] + + mapped_input = joint_mapped_args[:num_mapped_args] + mapped_grads = joint_mapped_args[num_mapped_args:] + + def fw_with_masks(*args): + fw_out = f(*args) + return fw_out, [ + True + if isinstance(ret, torch.Tensor) and ret.requires_grad + else False + for ret in fw_out + ] + + joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + input_storage = { + StorageWeakRef(arg._typed_storage()) + for arg in example_args + if isinstance(arg, torch.Tensor) + } + + def maybe_clone(t): + if ( + isinstance(t, torch.Tensor) + and StorageWeakRef(t._typed_storage()) in input_storage + ): + return t.clone() + return t + + return pytree.tree_map(maybe_clone, grads) + + joint_num_mapped = len(example_grad) + len(example_xs) + joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) + return fw_graph, joint_graph + + +def map_wrapper(f, xs, *args): + flat_xs, xs_spec = pytree.tree_flatten(xs) + if not all(isinstance(t, torch.Tensor) for t in flat_xs): + raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.") + + num_mapped_args = len(flat_xs) + shapes = [xs.shape for xs in flat_xs] + leading_dim_size = shapes[0][0] + if leading_dim_size == 0: + raise RuntimeError("Leading dimensions of mapped xs cannot be 0.") + + if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): + raise RuntimeError( + f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}." 
+ ) + + out_spec = None + + def flat_fn(*flat_args): + xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec) + unflattened_out = f(xs, *flat_args[num_mapped_args:]) + flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out) + + nonlocal out_spec + out_spec = tmp_out_spec + return flat_out + + return pytree.tree_unflatten( + map_impl(flat_fn, flat_xs, args), out_spec # type: ignore[arg-type] + ) + + +class MapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): + ctx.save_for_backward(*flat_args) + ctx._joint_graph = joint_graph + ctx._num_mapped_args = num_mapped_args + with torch._C._AutoDispatchBelowAutograd(): + return ( + *map_impl( + fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] + ), + ) + + @staticmethod + def backward(ctx, *flat_grads): + fw_args = ctx.saved_tensors + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + fw_mapped_args + flat_grads, + pos_args, + ) + return None, None, None, *grads + + +def trace_map(proxy_mode, func_overload, f, xs, pos_args): + leading_dim_size = xs[0].shape[0] + + example_input = _unstack_pytree(xs)[0] + body_graph = f + + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args) + + next_name = None + i = 0 + while not next_name: + candidate = f"body_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + proxy_mode.tracer.root.register_module(next_name, body_graph) + + with disable_proxy_modes_tracing(): + example_outs = body_graph(*example_input, *pos_args) + + def expand_tensor(t): + if isinstance(t, torch.Tensor): + return t.expand(leading_dim_size, *t.shape) + return t + + expanded_outs = pytree.tree_map(expand_tensor, example_outs) + + node_args = (body_graph, list(xs), list(pos_args)) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="map_impl" + ) + return track_tensor_tree( + expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +def _unstack_pytree(xs): + flat_xs, inspec = pytree.tree_flatten(xs) + if not all(isinstance(xs, torch.Tensor) for xs in flat_xs): + raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") + + if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): + raise RuntimeError( + f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" + ) + + a = zip(*flat_xs) + + pytrees = [] + for tuple in a: + pytrees.append(pytree.tree_unflatten(tuple, inspec)) + return pytrees + + +def _stack_pytree(pytrees): + flat_out = [] + out_spec = None + for pt in pytrees: + flat_pt, out_spec = pytree.tree_flatten(pt) + flat_out.append(flat_pt) + assert out_spec is not None + b = zip(*flat_out) + stacked_out = [] + for leaves in b: + if all(isinstance(leaf, torch.Tensor) for leaf in leaves): + stacked_out.append(torch.stack(leaves)) + elif all(leaf is None for leaf in leaves): + # Backward graph can return None output when forward inputs doesn't require grad. + # When we eagerly execute backward graph, we need to call _stack_pytree on its output, + # therefore we need to deal with None output. 
+ stacked_out.append(None) # type: ignore[arg-type] + else: + raise RuntimeError(f"Cannot stack {leaves}.") + return pytree.tree_unflatten(stacked_out, out_spec) + + +@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) +def map_dense(f, xs, pos_args): + pytrees = [] + for inp in _unstack_pytree(xs): + pytrees.append(f(*inp, *pos_args)) + return _stack_pytree(pytrees) + + +@map_impl.py_impl(DispatchKey.Autograd) +def map_autograd(f, xs, pos_args): + num_mapped_args = len(xs) + fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + return flat_out + + +@map_impl.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(mode, f, xs, args): + if mode.enable_tracing: + return trace_map(mode, map_impl, f, xs, args) + else: + return map_impl(f, xs, args) + + +@map_impl.py_impl(FakeTensorMode) +def map_fake_tensor_mode(mode, f, xs, args): + with mode: + return map_dense(f, xs, args) + + +@map_impl.py_functionalize_impl +def map_functionalize(ctx, f, xs, pos_args): + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_args = ctx.unwrap_tensors(pos_args) + wrapped_fn = ctx.functionalize(f) + + with ctx.redispatch_to_next(): + with disable_proxy_modes_tracing(): + example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is mutating the input!") + + if _has_potential_branch_input_alias( + f, example_inputs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException("torch.map is aliasing the input!") + + map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args) + return ctx.wrap_tensors(map_return) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/out_dtype.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/out_dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..5f30903e02cfaea25620a189f6d681900aa32dc3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/out_dtype.py @@ -0,0 +1,170 @@ + +import torch +import torch.utils._pytree as pytree +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, + maybe_handle_decomp, +) +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._higher_order_ops.utils import autograd_not_implemented + +# TODO to figure out a more generic approach +ALLOWABLE_OPS = [ + torch.ops.aten.linear.default, + torch.ops.aten.mm.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.convolution.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul.Scalar, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Scalar, +] + + +class OutDtypeOperator(HigherOrderOperator): + """ + The out_dtype operator takes an existing ATen functional operator, an + `out_dtype` argument, and arguments to the original operator, and executes + the original operator and returns a Tensor with the `out_dtype` precision. + This operator does not mandate a compute precision so it allows the + representation to not be opinionated about the exact implementation. + + The general implementation for all operators will be the following: + 1. 
Promote inputs dtypes based on default PyTorch dtype promotion rules,
+            using the dtypes of all input Tensors/Scalars and the `out_dtype`
+            argument.
+        2. Execute the operator
+        3. Cast the output to `out_dtype`
+    """
+
+
+    def __init__(self):
+        super().__init__("out_dtype")
+        # TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
+        # become different (torch._higher_order_ops.out_dtype), which would result
+        # in torch.fx recording the op incorrectly in the graph.
+        self.__module__ = "torch.ops.higher_order"
+
+    def __call__(self, op, output_dtype, *args):
+        if not isinstance(op, torch._ops.OpOverload):
+            raise ValueError("out_dtype's first argument must be an OpOverload")
+        if op._schema.is_mutable:
+            raise ValueError("out_dtype's first argument needs to be a functional operator")
+        if not (
+            len(op._schema.returns) == 1 and
+            isinstance(op._schema.returns[0].type, torch.TensorType)
+        ):
+            raise ValueError(
+                "out_dtype can only be applied to ops that return a single tensor. "
+                f"Instead got {[r.type for r in op._schema.returns]}"
+            )
+
+        if op not in ALLOWABLE_OPS:
+            raise ValueError(
+                f"out_dtype only allows the following operators: {ALLOWABLE_OPS}."
+            )
+
+        res = super().__call__(op, output_dtype, *args)
+
+        return res
+
+
+out_dtype = OutDtypeOperator()
+
+def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
+    # NB: Long-term we should put the decomposition logic into
+    # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
+    # in all HigherOrderOp proxy implementations.
+    r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {})
+    if r is not NotImplemented:
+        return r
+
+    with disable_proxy_modes_tracing():
+        # This is a simplified implementation of this operator just for tracing.
+ # Actual implementation may also first promote the arguments + out = op(*args).to(dtype=output_dtype) + + node_args = (op, output_dtype, *args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="out_dtype" + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd) +def out_dtype_dense( + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args +): + if is_int_mm(op, output_dtype, args): + return torch._int_mm(*args) + return out_dtype_fallback(op, output_dtype, *args) + + +def is_int_mm(op, output_dtype, args): + return ( + op == torch.ops.aten.mm.default and + output_dtype == torch.int32 and + len(args) == 2 and + args[0].dtype == torch.int8 and + args[1].dtype == torch.int8 and + args[0].is_cuda and + args[1].is_cuda + ) + + +def out_dtype_fallback(op, output_dtype, *args): + flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)] + promote_dtype: torch.dtype = elementwise_dtypes( + *flat_inputs, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + )[0] + + casted_args = pytree.tree_map_only( + torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args + ) + res = op(*casted_args).to(dtype=output_dtype) + return res + + +out_dtype.py_impl(DispatchKey.Autograd)(autograd_not_implemented(out_dtype, deferred_error=True)) + + +@out_dtype.py_impl(ProxyTorchDispatchMode) +def out_dtype_proxy( + mode: ProxyTorchDispatchMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args +): + if mode.enable_tracing: + return trace_out_dtype(mode, out_dtype, op, output_dtype, *args) + else: + return out_dtype(op, output_dtype, *args) + + +@out_dtype.py_impl(FakeTensorMode) +def out_dtype_fake_tensor_mode( + mode: FakeTensorMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args +): + with mode: + return out_dtype_dense(op, output_dtype, *args) + + +@out_dtype.py_functionalize_impl +def out_dtype_func(ctx, op, output_dtype, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + + with ctx.redispatch_to_next(): + res = out_dtype(op, output_dtype, *unwrapped_args) + return ctx.wrap_tensors(res) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/strict_mode.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/strict_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..73e20dc817cefcb4ef0f91d4b72f127a738741ba --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/strict_mode.py @@ -0,0 +1,100 @@ +import torch +import torch._subclasses.functional_tensor + +import torch.utils._pytree as pytree + +from torch._C import DispatchKey +from torch._functorch.utils import exposed_in + +from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def strict_mode(callable, operands): + if torch.compiler.is_dynamo_compiling(): + return strict_mode_op(callable, operands) + + with _set_compilation_env(): + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile(strict_mode_op, backend="eager", 
fullgraph=True)( + callable, operands + ) + + +strict_mode_op = HigherOrderOperator("strict_mode") + + +@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def strict_mode_op_dense(callable, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return callable(*operands) + + +strict_mode_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(strict_mode_op, deferred_error=True) +) + + +@strict_mode_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, callable, operands): + if mode.enable_tracing: + return trace_strict_mode(mode, strict_mode_op, callable, operands) + else: + return strict_mode_op(callable, operands) + + +def trace_strict_mode(mode, strict_mode_op, callable, operands): + pre_dispatch = getattr(mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands) + + next_name = None + i = 0 + while not next_name: + candidate = f"strict_graph_{i}" + if hasattr(mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + graph_name = next_name + mode.tracer.root.register_module(graph_name, graph) + + args = (graph, operands) + + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + + out_proxy = mode.tracer.create_proxy( + "call_function", strict_mode_op, proxy_args, {}, name="strict_mode" + ) + + out = graph(*operands) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + +@strict_mode_op.py_impl(FakeTensorMode) +def strict_mode_fake_tensor_mode(mode, callable, operands): + with mode: + true_outs = callable(*operands) + return true_outs + + +@strict_mode_op.py_functionalize_impl +def strict_mode_func(ctx, callable, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + with ctx.redispatch_to_next(): + functional_callable = ctx.functionalize(callable) + + cond_return = strict_mode_op(functional_callable, unwrapped_inputs) + return ctx.wrap_tensors(cond_return) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/torchbind.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/torchbind.py new file mode 100644 index 0000000000000000000000000000000000000000..385054682e6c4188dccff26694bb29f512884a0b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/torchbind.py @@ -0,0 +1,94 @@ +from contextlib import contextmanager + +import torch +from torch._C import DispatchKey # @manual +from torch._functorch._aot_autograd.utils import KNOWN_TYPES +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.node import has_side_effect +from torch.utils import _pytree as pytree + +# The call_torchbind operator represents a method invocation on a torchbind +# object. The calling convention is: +# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) +# We do not expect users to write this operator directly. Instead it will be +# emitted by Dynamo when tracing encounters a torchbind object. +call_torchbind = HigherOrderOperator("call_torchbind") + +# Register this operator as side-effectful with FX. +# TODO: this is not really sufficient. 
While passes (hopefully) check +# Node.is_impure() and make good decisions, we also assume we can execute the +# graph as many times as we want without changing behavior, which is NOT true of +# ops that mutate torchbind object state. +has_side_effect(call_torchbind) + +_orig_scriptmethod_call = torch.ScriptMethod.__call__ + + +def torchbind_method_redispatch(self, *args, **kwargs): + if isinstance(self.raw_owner, torch.ScriptObject): + return call_torchbind(self.raw_owner, self.name, *args, **kwargs) + return _orig_scriptmethod_call(self, *args, **kwargs) + + +@contextmanager +def enable_torchbind_tracing(): + """Context manager that acts as a feature flag to enable torchbind tracing + behavior. Once torchbind tracing has been stabilized, we can remove this and + turn it always on. + """ + try: + KNOWN_TYPES.append(torch.ScriptObject) + torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] + yield + finally: + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." + torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] + + +@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) +def call_torchbind_impl(obj, method, *args, **kwargs): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + + +@call_torchbind.py_impl(ProxyTorchDispatchMode) +def inner(mode, *args, **kwargs): + if mode.enable_tracing: + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + call_torchbind, + proxy_args, + proxy_kwargs, + ) + out = call_torchbind_impl(*args, **kwargs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + else: + return call_torchbind(*args, **kwargs) + + +# TODO: currently we just run the C++ implementation with fake tensors. +# But we should make it possible to register a fake torchbind implementation. 
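# An illustrative usage sketch of the calling convention documented at the top of
# this module. The torchbind class `my_namespace.MyStack` and its `push` method are
# hypothetical stand-ins; only the shape of the calls is taken from the code in
# this file.
import torch
from torch._higher_order_ops.torchbind import call_torchbind, enable_torchbind_tracing

stack = torch.classes.my_namespace.MyStack()  # assumes such a class was registered

with enable_torchbind_tracing():
    # While the context manager is active, ScriptMethod.__call__ is patched, so a
    # plain method call on a ScriptObject is rerouted through the HOP ...
    stack.push(torch.ones(3))
    # ... which is equivalent to invoking the operator explicitly:
    call_torchbind(stack, "push", torch.ones(3))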
+@call_torchbind.py_impl(FakeTensorMode) +def call_torchbind_fake(mode, *args, **kwargs): + with mode: + return call_torchbind_impl(*args, **kwargs) + + +call_torchbind.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(call_torchbind, deferred_error=True) +) + + +@call_torchbind.py_functionalize_impl +def call_torchbind_func(ctx, *args, **kwargs): + args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(call_torchbind(*args, **kwargs)) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..0a41e585079dcba91d0060eb6835e02fd4e801af --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py @@ -0,0 +1,842 @@ +import dataclasses +import logging +import threading +import warnings +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + +log = logging.getLogger("torch._dynamo") + + +############################################################################### +# Kernel Side Table + + +# We cannot put Triton Kernels into the FX graph as the graph nodes +# do not support arbitrary functions. +# Use a side table. +# We use two dicts so that fetching both the kernel and id are O(1) +class KernelSideTable: + id_to_kernel: Dict[int, Any] = dict() + kernel_to_id: Dict[Any, int] = dict() + lock = threading.Lock() + + # Returns index on the table + def add_kernel(self, kernel) -> int: + with self.lock: + if kernel in self.kernel_to_id: + return self.kernel_to_id[kernel] + + idx = len(self.id_to_kernel) + self.id_to_kernel[idx] = kernel + self.kernel_to_id[kernel] = idx + return idx + + # Returns the triton kernel at the given index + def get_kernel(self, idx: int): + # No need to lock here as fetching from dict is atomic + assert idx in self.id_to_kernel + return self.id_to_kernel[idx] + + # Resets the table (only meant to be used in unit tests) + # This is only safe assuming single threaded execution + def reset_table(self) -> None: + self.id_to_kernel = dict() + self.kernel_to_id = dict() + + +kernel_side_table = KernelSideTable() + + +############################################################################### +# Mutation Tracker + + +@dataclasses.dataclass(frozen=True) +class Param: + idx: int + + +@dataclasses.dataclass(frozen=True) +class Intermediate: + idx: int + + def fake(self): + return self.idx < 0 + + +@dataclasses.dataclass(frozen=True) +class Op: + name: str + fn_call_name: Optional[str] + args: List[Union[Param, Intermediate]] + ret: Intermediate = dataclasses.field(repr=False) + + def __post_init__(self): + if self.name == "tt.call": + assert self.fn_call_name is not None + else: + assert self.fn_call_name is None + + +def generate_ttir(kernel, kwargs): + """ + Uses Triton's internal code generation to create TTIR + """ + from triton.compiler.compiler import ASTSource + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + import torch + from torch._subclasses.fake_tensor import FakeTensor 
+ + if isinstance(kernel, Autotuner): + if len(kernel.configs) > 0: + # If we are autotuning, then it doesn't matter which version gets + # picked for tracing purposes, so lets pick the first one + kwargs = {**kwargs, **kernel.configs[0].kwargs} + kernel = kernel.fn + + assert isinstance(kernel, JITFunction) + + if len(kwargs) != len(kernel.arg_names): + raise Exception("Incorrect number of arguments passed to kernel") + + # Replace all SymExprs with a regular value for TTIR generation + # Replace all FakeTensor with real tensors + # These replacements are needed for triton's type, key and config functions + ordered_args: Dict[str, Any] = {} + for name in kernel.arg_names: + a = kwargs[name] + if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)): + ordered_args[name] = 2 + elif isinstance(a, FakeTensor): + ordered_args[name] = torch.empty(2, dtype=a.dtype) + else: + ordered_args[name] = a + + ordered_tensor_names = [ + name for name, arg in ordered_args.items() if isinstance(arg, Tensor) + ] + specialization = kernel._get_config(*ordered_args.values()) + constants = { + i: arg + for i, arg in enumerate(ordered_args.values()) + if not isinstance(arg, Tensor) + } + + # Build kernel signature -- doesn't include constexpr arguments. + signature = { + i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(ordered_args.values()) + if i not in kernel.constexprs + } + + def get_backend(): + from triton.compiler.backends.cuda import CUDABackend + from triton.runtime.driver import driver + + target = driver.get_current_target() + return CUDABackend(target) + + backend = get_backend() + + options = backend.parse_options(dict()) + # triton._C.libtriton.triton.ir.load_dialects(context) + # backend.load_dialects(context) + + src = ASTSource(kernel, signature, constants, specialization) + ttir_module = src.make_ir(options) + if not ttir_module.verify(): + raise Exception("Verification for TTIR module has failed") + + return ttir_module, ordered_tensor_names + + +def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: + """ + Walk the `ttir_module` bottom up to mine the `functions` from + the structured MLIR entities representing the Triton kernel + (mlir::Operation, mlir::Block, mlir::Region). 
+ """ + functions: Dict[str, Dict[Intermediate, List[Op]]] = {} + + # block id --> op result (Intermediate) --> one or more ops + op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict( + lambda: defaultdict(list) + ) + region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list) + block_id_to_block_arg_ids: Dict[int, List[int]] = {} + replacements: Dict[int, Union[Intermediate, Param]] = {} + reindex_map: Dict[int, int] = {} + next_fake_intermediate = 0 + + def reindex(idx): + if idx not in reindex_map: + reindex_map[idx] = len(reindex_map) + return reindex_map[idx] + + def mlir_to_functions(op) -> None: + name: str = op.get_name() + if name == "builtin.module": + # this wraps all tt.func ops + return + + operand_ids: List[int] = [ + reindex(op.get_operand(i).id()) for i in range(op.get_num_operands()) + ] + result_ids: List[int] = [ + reindex(op.get_result(i).id()) for i in range(op.get_num_results()) + ] + + child_block_ids: List[int] = [] + for i in [op.get_region(i).id() for i in range(op.get_num_regions())]: + # as the walk is bottom-up, the region_id_to_block_ids[i] + # must be populated by the time we process the enclosing op + child_block_ids.extend(region_id_to_block_ids[i]) + + parent_block_id = -1 + parent_block = op.get_block() + if parent_block is not None: + parent_block_id = parent_block.id() + if parent_block_id not in block_id_to_block_arg_ids: + block_id_to_block_arg_ids[parent_block_id] = [] + for i in range(parent_block.get_num_arguments()): + block_id_to_block_arg_ids[parent_block_id].append( + reindex(parent_block.get_argument(i).id()), + ) + # the region info is collected via ops' parent blocks to be + # used later when the region's encloding op is traversed + parent_region = parent_block.get_parent() + if parent_region is not None: + region_id_to_block_ids[parent_region.id()].append(parent_block_id) + + nonlocal next_fake_intermediate + + if name == "tt.func": + # for function ops: gather and inline + # the ops from all child blocks + fn_ops = defaultdict(list) + for child_block_id in child_block_ids: + for result, block_fn_ops in op_stack.pop(child_block_id).items(): + for block_fn_op in block_fn_ops: + fn_ops[result].append(block_fn_op) + + # replace the corresponding Intermediates in the + # child op args with the function args (Params) + for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]): + replacements[idx] = Param(i) + + for fn_op_list in fn_ops.values(): + for fn_op in fn_op_list: + for i in range(len(fn_op.args)): + arg = fn_op.args[i] + if isinstance(arg, Intermediate) and arg.idx in replacements: + fn_op.args[i] = replacements[arg.idx] + + # next function capture starts + # with empty replacements + replacements.clear() + + fn_name = op.get_str_attr("sym_name") + functions[fn_name] = fn_ops + elif child_block_ids: + if name in ("scf.if", "scf.for", "scf.while"): + # for blocked control flow ops: inline the enclosed + # ops into the parent block + rewire the last op in + # each child block (yield) to return the scf result + yield_ops = [] + for block_id in child_block_ids: + # the block args used as operands of the ops in the block + # (and nested blocks inlined in the current block by now) + # are replaced by new fake Intermediates to avoid "this + # operand is not returned by anything other op in the fn" + # error in the downstream analysis + for idx in block_id_to_block_arg_ids[block_id]: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + + if block_id in op_stack: + block_ops = 
op_stack.pop(block_id) + if not block_ops: + continue + last_ret, last_ops = block_ops.popitem() + if all(op.name == "scf.yield" for op in last_ops): + # if last_ops are scf.yield, treat them separately + yield_ops.extend(last_ops) + else: + # otherwise, return last_ops to the block + block_ops[last_ret] = last_ops + for op_result, child_ops in block_ops.items(): + op_stack[parent_block_id][op_result].extend(child_ops) + + scf_results = [Intermediate(idx) for idx in result_ids] + for scf_result in scf_results: + for yield_op in yield_ops: + op_stack[parent_block_id][scf_result].append(yield_op) + else: + # TODO(oulgen): add support for tt.reduce + raise Exception( + f"Unknown blocked function: {name}. Can't capture the TTIR." + ) + else: + callee = None + if name == "tt.call": + callee = op.get_flat_symbol_ref_attr("callee") + args: List[Union[Param, Intermediate]] = [ + Intermediate(operand) for operand in operand_ids + ] + block_ops = op_stack[parent_block_id] + if result_ids: + for result_id in result_ids: + res = Intermediate(result_id) + block_ops[res].append(Op(name, callee, args, res)) + else: + next_fake_intermediate -= 1 + fake_res = Intermediate(next_fake_intermediate) + block_ops[fake_res].append(Op(name, callee, args, fake_res)) + + ttir_module.walk(mlir_to_functions) + + return functions + + +def parse_ttir(ttir, kwargs): + """ + Given a Triton emitted TTIR text, this function lexes and parses the + code using a minimal grammar defined inside. During the lexing/parsing, + we drop any constant value and type information as they are not + necessary to us. + Being able to choose what we need makes this not a general purpose TTIR + parser which further makes parsing much simpler. + """ + # TODO(oulgen): + # - Support closures (e.g. "tt.reduce") + + try: + import lark # type: ignore[import-not-found] + from lark import Lark, Transformer, v_args + except ModuleNotFoundError: + warnings.warn( + "Using slow path for user-defined Triton kernels. `pip install lark` to fix this." + ) + raise + + # Ops looks like one of the following forms: + # + # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32> + # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950 + grammar = """ + start: (module_block | loc_line)+ + + loc_line: "#loc" /.+/ NEWLINE + + module_block: "module" "{" func_block+ "}" LOC + + func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func + + ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt + + if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if + for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for + while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while + + condition_stmt: "scf.condition" "(" arg ")" args rest + label_stmt: LABEL ":" "// pred:" LABEL + | LABEL "(" /.+/ NEWLINE + cf_stmt: "cf" "." NAME /.+/ NEWLINE + + op: OP_NAME LOC + | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op + + ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE + divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}" + + args: | "(" ")" | "("? arg ("," arg)* ")"? 
+ + ?arg: INTERMEDIATE + | INTERMEDIATE_CONSTANT + | CONSTANT + | PARAM + | "[" args "]" + | arg_with_index + + ?arg_with_index: arg "#" DIGIT+ + + ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+] + + PARAM.5: "%arg" DIGIT+ + INTERMEDIATE.4: "%" DIGIT+ + INTERMEDIATE_CONSTANT.3: "%" NAME + CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")? + LABEL: "^bb" DIGIT+ + + NAME: (LETTER | DIGIT | "_")+ + NON_CF_NAME: /(?!(cf))/ NAME + FN_NAME: "@" (NAME | ESCAPED_STRING) + OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""? + + LOC.5: "loc(#loc" DIGIT* ")" + + %import common.LETTER + %import common.DIGIT + %import common.WS + %import common.NEWLINE + %import common.ESCAPED_STRING + %import common.FLOAT + %ignore WS + """ + + next_fake_intermediate = 0 + + def convert(token): + if isinstance(token, lark.tree.Tree): + if token.data == "args": + res = [] + for a in token.children: + c = convert(a) + if isinstance(c, list): + res.extend(c) + else: + res.append(c) + return res + elif token.data in {"assign_lhs", "arg_with_index"}: + # Drop length/index qualifier + return convert(token.children[0]) + else: + raise AssertionError(f"Tree node with {token.data}") + + if token is None or ( + isinstance(token, lark.lexer.Token) + and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT") + ): + nonlocal next_fake_intermediate + next_fake_intermediate -= 1 + return Intermediate(next_fake_intermediate) + + assert isinstance(token, lark.lexer.Token) + + if token.type == "INTERMEDIATE": + return Intermediate(int(token.value[len("%") :])) + if token.type == "PARAM": + return Param(int(token.value[len("%arg") :])) + + raise AssertionError(f"{type(token.type)} => {token.value} invalid") + + # In alternative representation, function names are quoted. + # It should be possible to move this into the grammar alltogether. 
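# To make the parser's target concrete, a sketch of what the analysis records for
# the first example statement quoted in the grammar comment above:
#
#     %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr>, tensor<4xi32>
#
# Type annotations are consumed by the `rest` rule and constants are replaced by
# fresh fake Intermediates, so only the SSA value ids survive. This assumes the
# Param/Intermediate/Op dataclasses defined earlier in this file are in scope.
ret = Intermediate(14)
addptr = Op("tt.addptr", None, [Intermediate(13), Intermediate(4)], ret)
# Per-function map (result -> producing ops) that analyze_kernel_mutations
# later walks backwards from the store/atomic sinks.
ops = {ret: [addptr]}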
+ def convert_name(token): + if token is None: + return None + s = token.value + if len(s) > 2 and s[0] == '"' and s[-1] == '"': + return s[1:-1] + return s + + functions: Dict[str, Dict[Intermediate, List[Op]]] = {} + + def extend_dict_list(d1, d2): + for key, values in d2.items(): + d1[key].extend(values) + + @v_args(inline=True) + class TransformOps(Transformer): + def process_op(self, ret, op_name, fn_name, args, *rest): + return Op( + convert_name(op_name), + convert_name(fn_name), + convert(args), + convert(ret), + ) + + def process_func(self, name, _args, *stmts): + ops: Dict[Intermediate, List[Op]] = defaultdict(list) + for e in stmts: + if isinstance(e, Op): + ops[e.ret].append(e) + elif isinstance(e, dict): + extend_dict_list(ops, e) + functions[name.value] = ops + + def _process_scf(self, ret, stmts): + ret = convert(ret) + ops: Dict[Intermediate, List[Op]] = defaultdict(list) + for e in stmts: + if isinstance(e, Op): + if e.name == "scf.yield": + ops[ret].append(Op(e.name, None, e.args, ret)) + else: + ops[e.ret].append(e) + elif isinstance(e, dict): + extend_dict_list(ops, e) + return ops + + def process_if(self, ret, _args, _rest, *stmts): + return self._process_scf(ret, stmts) + + def process_for(self, ret, _args, _rest, *stmts): + return self._process_scf(ret, stmts) + + def process_while(self, ret, _args, _rest, *stmts): + return self._process_scf(ret, stmts) + + parser = Lark( + grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps() + ) + parser.parse(ttir) + return functions + + +class MemoizeWithCycleCheck: + def __init__(self, fn): + self.fn = fn + self.reset() + + def __call__(self, functions, fn_name, num_args): + key = (fn_name, num_args) + if key not in self.cache: + self.cache[key] = None + self.cache[key] = self.fn(functions, fn_name, num_args) + if self.cache[key] is None: + raise Exception("Recursion is not supported") + return self.cache[key] + + def reset(self): + self.cache = {} + + +@MemoizeWithCycleCheck +def analyze_kernel_mutations(functions, fn_name, num_args): + """ + Analyzes the graph to detect all sinks from a predefined list of sinks + by using triton's MemWrite trait list. NOTE: What if triton exposed this? + From each sink, it traverses the CFG backwards to identify all the input + pointers that are mutated. + """ + # Name of mutation op to mutated parameter indices + # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td + # All the OPs that have MemWrite trait. + # What if Triton exposed this? 
+ MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]} + # Ops that we want to bail out on + UNKNOWN_OPS = {"tt.elementwise_inline_asm"} + + stack: List[Union[Param, Intermediate]] = [] + visited = set() + ops = functions[fn_name] + for op_list in ops.values(): + for op in op_list: + if op.name in UNKNOWN_OPS: + raise Exception( + f"ttir analysis hit an op we do not know how to analyze: {op.name}" + ) + + if op.name == "tt.call": + assert op.fn_call_name in functions + mutations = analyze_kernel_mutations( + functions, op.fn_call_name, len(op.args) + ) + stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated) + else: + for idx in MUTATION_OPS.get(op.name, []): + stack.append(op.args[idx]) + + # The following is an iterative DFS algorithm + mutated = [False] * num_args + while stack: + arg = stack.pop() + if arg in visited: + continue + + visited.add(arg) + + if isinstance(arg, Param): + if arg.idx >= num_args: + # This is an argument defined in the kernel, not passed in + continue + mutated[arg.idx] = True + elif isinstance(arg, Intermediate) and not arg.fake(): + for op in ops[arg]: + # Skip arguments to load + if op.name != "tt.load": + stack.extend(op.args) + return mutated + + +def identify_mutated_tensors(kernel, kwargs): + """ + Given a triton kernel and the arguments for this kernel, this function + 1) Retrieves the TTIR converted version of the kernel from Triton's API. + 2) Parses the TTIR and creates a control flow graph + 3) Analyzes the graph to detect all input tensor mutations + """ + + ttir_module = None + functions = None + try: + from torch._dynamo import config + + if not config.optimize_user_defined_triton_kernels: + raise Exception("optimize_user_defined_triton_kernels is False") + + ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs) + + # extract functions from TTIR + if hasattr(ttir_module, "walk"): + # use MLIR bindings exposed by Triton code + functions = ttir_to_functions(ttir_module) + else: + # parse string representation of Triton IR + functions = parse_ttir(str(ttir_module), kwargs) + + assert functions is not None + kernel_name = next(iter(functions.keys())) + # Triton codegen modifies the name + assert kernel.fn.__name__ in kernel_name + # Reset the cache between top level invocations + # The cache for analyze kernel mutations is mainly used for cycle + # detection, so each top level invocation needs a clean cache + analyze_kernel_mutations.reset() + mutations = analyze_kernel_mutations( + functions, kernel_name, len(ordered_tensor_names) + ) + + return [ + ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated + ] + except Exception as e: + import traceback + + warnings.warn( + "Encountered an exception in identify_mutated_tensors, " + "assuming every input is mutated:\n" + "".join( + traceback.TracebackException.from_exception(e).format() # noqa: G001 + ) + ) + if ttir_module is not None: + log.debug("TTIR:\n%s", str(ttir_module)) + if functions is not None: + log.debug("functions:") + for name, fn in functions.items(): + log.debug("===\t%s\t===", name) + for ret, ops in fn.items(): + log.debug("%s\t=>\t%s", ret, ops) + return [key for key, value in kwargs.items() if isinstance(value, Tensor)] + + +############################################################################### +# Triton Kernel Wrappers + + +# Used for wrapping a Triton Kernel +class TritonKernelWrapperMutation(HigherOrderOperator): + def __init__(self): + super().__init__("triton_kernel_wrapper_mutation") + + 
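# A hedged usage sketch of the pieces defined around this point: a kernel is
# registered once in the module-level kernel_side_table (so FX graphs only carry
# an integer index), and the mutation wrapper instantiated just below is called
# with that index, a grid list, and all kernel arguments passed by name.
# `my_add_kernel`, its parameter names, and the grid are illustrative only.
import torch

idx = kernel_side_table.add_kernel(my_add_kernel)  # stable, deduplicated index
assert kernel_side_table.get_kernel(idx) is my_add_kernel

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
triton_kernel_wrapper_mutation(
    kernel_idx=idx,
    grid=[(x.numel() // 128,)],  # a single grid is used directly as the launch grid
    kwargs={"in_ptr": x, "out_ptr": y, "n_elements": x.numel(), "BLOCK_SIZE": 128},
)
# y is mutated in place; the functional wrapper defined further below instead
# clones and returns the tensors reported as mutated by identify_mutated_tensors.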
+triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() + + +# Used for wrapping a Triton Kernel in a functional manner +class TritonKernelWrapperFunctional(HigherOrderOperator): + def __init__(self): + super().__init__("triton_kernel_wrapper_functional") + + +triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() + + +@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs): + from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code + + kernel = kernel_side_table.get_kernel(kernel_idx) + + if len(grid) == 1: + grid_fn = grid[0] + else: + fn_name, code = user_defined_kernel_grid_fn_code( + kernel.fn.__name__, kernel.configs, grid + ) + namespace: Dict[str, Any] = {} + exec(code, namespace) + grid_fn = namespace[fn_name] + + kernel[grid_fn](**kwargs) + + +@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode) +def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel_idx, grid, kwargs): + with mode: + return None + + +def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args): + with disable_proxy_modes_tracing(): + out = func_overload(**node_args) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", + func_overload, + (), + proxy_args, + name=func_overload.__name__ + "_proxy", + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( + mode, *, kernel_idx, grid, kwargs +): + if mode.enable_tracing: + trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_mutation, + {"kernel_idx": kernel_idx, "grid": grid, "kwargs": kwargs}, + ) + else: + triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs) + + return None + + +@triton_kernel_wrapper_mutation.py_functionalize_impl +def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + kernel = kernel_side_table.get_kernel(kernel_idx) + # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each + # other, and one gets mutated in kernel, and later another gets mutated, + # they are no longer equal. Fix this by graph breaking on this condition + # earlier in dynamo. 
+ tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs) + with ctx.redispatch_to_next(): + unwrapped_outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + grid=grid, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + + assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys())) + for key, output_arg in unwrapped_outputs.items(): + if not isinstance(output_arg, Tensor): + continue + input_arg = kwargs[key] + assert isinstance(input_arg, Tensor) + + ctx.replace(input_arg, output_arg) + # indicate that above replace is hidden from autograd + ctx.mark_mutation_hidden_from_autograd(input_arg) + ctx.commit_update(input_arg) + ctx.sync(input_arg) + # sync calls replace_ under the hood, so again indicate that + # this indirect replace is hidden from autograd + ctx.mark_mutation_hidden_from_autograd(input_arg) + return None + + +@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_functional_dense( + *, kernel_idx, grid, kwargs, tensors_to_clone +): + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). + # Requires https://github.com/pytorch/pytorch/issues/109240 + kwargs = { + key: (clone_preserve_strides(val) if key in tensors_to_clone else val) + for key, val in kwargs.items() + } + triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs) + return {key: val for key, val in kwargs.items() if key in tensors_to_clone} + + +@triton_kernel_wrapper_functional.py_impl(FakeTensorMode) +def triton_kernel_wrapper_functional_fake_tensor_mode( + mode, *, kernel_idx, grid, kwargs, tensors_to_clone +): + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). 
+ # Requires https://github.com/pytorch/pytorch/issues/109240 + with mode: + return { + key: clone_preserve_strides(val) + for key, val in kwargs.items() + if key in tensors_to_clone + } + + +@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode( + mode, *, kernel_idx, grid, kwargs, tensors_to_clone +): + if mode.enable_tracing: + return trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_functional, + { + "kernel_idx": kernel_idx, + "grid": grid, + "kwargs": kwargs, + "tensors_to_clone": tensors_to_clone, + }, + ) + else: + return triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + grid=grid, + kwargs=kwargs, + tensors_to_clone=tensors_to_clone, + ) + + +@triton_kernel_wrapper_functional.py_functionalize_impl +def triton_kernel_wrapper_functional_functionalize( + ctx, kernel_idx, grid, kwargs, tensors_to_clone +): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + grid=grid, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + return ctx.wrap_tensors(outputs) + + +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU) + +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/utils.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8673fa6e06981c5f1a2a1c303c0a789af64859 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/utils.py @@ -0,0 +1,183 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator +from torch.fx.experimental.proxy_tensor import make_fx +from torch.multiprocessing.reductions import StorageWeakRef + + +@dataclass +class UnsupportedAliasMutationException(RuntimeError): + reason: str + + +def autograd_not_implemented_inner( + operator: HigherOrderOperator, delayed_error: bool, *args: Any, 
**kwargs: Any +) -> Any: + """If autograd is enabled and any of the arguments require grad this will either + raise an error or return a DelayedError depending on the value of delayed. + + Args: + operator: The HigherOrderOperator to call with the *args and **kwargs with + op_name: The name of the HigherOrderOperator + delayed_error: If True, return a DelayedError instead of raising an error + args: The flattened operands to the HigherOrderOperator + kwargs: The keyword arguments to the HigherOrderOperator + + Raises: + RuntimeError: If autograd is enabled and any of the arguments to the HigherOrderOperator + """ + with torch._C._AutoDispatchBelowAutograd(): + result = operator(*args, **kwargs) + flat_operands = pytree.arg_tree_leaves(*args) + if torch.is_grad_enabled() and any( + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) + ): + if delayed_error: + err_fn = torch._C._functions.DelayedError( + f"Autograd not implemented for {str(operator)}", + 1, + ) + + def fake_requires_grad(tensor): + if torch.is_floating_point(tensor) or torch.is_complex(tensor): + tensor = tensor.detach() + tensor.requires_grad = True + return tensor + + return pytree.tree_map_only( + torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result + ) + else: + raise RuntimeError(f"Autograd not implemented for {str(operator)}") + return result + + +def autograd_not_implemented(op: HigherOrderOperator, deferred_error: bool) -> Callable: + def inner(*args, **kwargs): + return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs) + + return inner + + +def _maybe_run_with_interpreter(fn): + maybe_interpreted_fn = fn + if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta(): + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with fx_traceback.preserve_node_meta(): + return torch.fx.Interpreter(fn).run(*args) + + maybe_interpreted_fn = graph_with_interpreter + return maybe_interpreted_fn + + +# We'll use the current decomposition table to make sure operators in subgraphs are +# decomposed properly. +# We also need to maybe run with interpreter for propagating stack_trace +def reenter_make_fx(fn, pre_dispatch=False): + decomp_table = torch.fx.experimental.proxy_tensor.CURRENT_DECOMPOSITION_TABLE + return make_fx( + _maybe_run_with_interpreter(fn), + decomposition_table=decomp_table, + pre_dispatch=pre_dispatch, + ) + + +@contextmanager +def _set_compilation_env(): + _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag + try: + # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo + # once we are confident fx tracing works with dynamo. + torch.fx._symbolic_trace._is_fx_tracing_flag = False + yield + finally: + torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing + + +def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): + """ + Dispatch-trace the branch with inputs and check if + producing graph has mutable op on the input. This is + bit restrictive as the branch must be traceable. 
+ """ + try: + gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + def _detect_input_mutation(gm): + input_nodes = set() + for node in gm.graph.nodes: + if node.op == "placeholder": + input_nodes.add(node) + if node.op == "call_function": + target = node.target + if ( + isinstance(target, torch._ops.OpOverload) + and target._schema.is_mutable + ): + for arg in node.args: + if arg in input_nodes: + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule): + if _detect_input_mutation(module): + return True + + return False + + return _detect_input_mutation(gm) + + +def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False): + """ + Dispatch-trace the branch with inputs and check if + producing graph has output aliasing the branch input. This is + bit restrictive as the branch must be traceable. + """ + try: + gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + def _detect_input_alias(gm): + input_storages = set() + for node in gm.graph.nodes: + # We need to check existence of "val" because we reuse the logic here + # for map operator, where num_mapped_args is a scalar + # and doesn't have a "val" meta. + if node.op == "placeholder" and "val" in node.meta: + input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) + if node.op == "output": + + def check_alias(out): + if out is not None and "val" in out.meta: + out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) + return out_storage in input_storages + return False + + if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module): + return True + + return False + + return _detect_input_alias(gm) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/while_loop.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/while_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..441560ee191f2cbf83781f32950df697cab555f3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/while_loop.py @@ -0,0 +1,232 @@ +import torch +import torch.utils._pytree as pytree + +from torch._C import DispatchKey + +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class WhileLoopOp(HigherOrderOperator): + def __call__(self, cond_fn, body_fn, operands): + if not isinstance(cond_fn, torch.fx.GraphModule) or not isinstance( + body_fn, torch.fx.GraphModule + ): + raise RuntimeError( + "cond_fn and body_fn must be torch.fx.GraphModule, got " + f"{type(cond_fn)} and {type(body_fn)}" + ) + if not isinstance(operands, tuple): + raise RuntimeError("operands must be a tuple, got " f"{type(operands)}") + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in operands): + raise 
RuntimeError( + "operands must be a tuple of tensors, ints, floats, or bools, got " + f"{operands}" + ) + return super().__call__(cond_fn, body_fn, operands) + + +while_loop_op = HigherOrderOperator("while_loop") + + +def while_loop(cond_fn, body_fn, operands): + r""" + Run body_fn(*operands) while cond_fn(*operands) returns a True scalar tensor. Returns the output of body_fn or + initial operands. + + .. warning:: + `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export. + + `while_loop` is equivalent to the following: + + def while_loop(cond_fn, body_fn, operands): + val = operands + while cond_fn(*val): + val = body_fn(*val) + return val + + Args: + cond_fn (Callable): A callable function that returns a boolean Scalar tensor. + + body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors + + operands (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also + the initial value of states that are carried across iterations. + + Example: + + def cond_fn(iter, x): + return iter.sum() < 10 + + def body_fn(iter, x): + return iter + 1, x.sin() + + while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4))) + + Restrictions: + + - body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs. + + - body_fn and cond_fn must not in-place mutate the operands. A clone before the mutation is required. + + - body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn. + + - body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required. + + .. warning:: + Temporal Limitations: + + - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. + + """ + if torch.compiler.is_dynamo_compiling(): + return while_loop_op(cond_fn, body_fn, operands) + + def _validate_input(cond_fn, body_fn, operands): + if not callable(cond_fn) or not callable(body_fn): + raise RuntimeError("Expect cond_fn and body_fn to be callbale.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only" + f"consists of tensor leaves, but got {operands}." 
+ ) + + _validate_input(cond_fn, body_fn, operands) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + return torch.compile(while_loop_op, backend="eager", fullgraph=True)( + cond_fn, body_fn, operands + ) + + +@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def while_loop_dense(cond_fn, body_fn, operands): + init_val = operands + + def _is_boolean_scalar_tensor(pred): + return ( + isinstance(pred, torch.Tensor) + and pred.size() == torch.Size([]) + and pred.dtype == torch.bool + ) + + if not isinstance(operands, tuple): + raise RuntimeError(f"operands must be a tuple but got {type(operands)}") + + while pred := cond_fn(*init_val): + if not _is_boolean_scalar_tensor(pred): + raise RuntimeError( + f"cond_fn must return a boolean scalar tensor but got {pred}" + ) + out = body_fn(*init_val) + assert isinstance( + out, tuple + ), f"body_fn should return a tuple but got {type(out)}" + assert len(out) == len( + init_val + ), "body_fn should return the same number of elements as operands" + init_val = out + return init_val + + +while_loop_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(while_loop_op, deferred_error=True) +) + + +@while_loop_op.py_impl(ProxyTorchDispatchMode) +def while_loop_tracing(mode, cond_fn, body_fn, operands): + def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands): + pre_dispatch = getattr(proxy_mode, "pre_dispatch", False) + with disable_proxy_modes_tracing(): + cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands) + body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands) + + next_name = None + i = 0 + while not next_name: + candidate = f"while_loop_cond_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + cond_graph_name = next_name + body_graph_name = f"while_loop_body_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, body_graph_name) + + proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph) + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + args = (cond_graph, body_graph, operands) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", while_loop_op, proxy_args, {}, name="while_loop" + ) + + # body_fn return output with the same pytree and tensor meta data as operands + # so we could just return the output after one iteration. + out = body_fn(*operands) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + if mode.enable_tracing: + return _trace_while_loop(mode, while_loop_op, cond_fn, body_fn, operands) + else: + return while_loop_op(cond_fn, body_fn, operands) + + +@while_loop_op.py_impl(FakeTensorMode) +def while_loop_fake_tensor_mode(mode, cond_fn, body_fn, operands): + return body_fn(*operands) + + +@while_loop_op.py_functionalize_impl +def while_loop_func(ctx, cond_fn, body_fn, operands): + unwrapped_operands = ctx.unwrap_tensors(operands) + with ctx.redispatch_to_next() as m: + functional_cond_fn = ctx.functionalize(cond_fn) + functional_body_fn = ctx.functionalize(body_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for fn, fn_name in [ + (functional_cond_fn, "cond_fn"), + (functional_body_fn, "body_fn"), + ]: + if _has_potential_branch_input_mutation( + fn, unwrapped_operands, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + f"torch.while_loop's {fn_name} might be modifying the input!" 
+ ) + + for fn in [functional_cond_fn, functional_body_fn]: + if _has_potential_branch_input_alias( + fn, unwrapped_operands, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + f"torch.while_loop's {fn_name} might be aliasing the input!" + ) + ret = while_loop_op(functional_cond_fn, functional_body_fn, unwrapped_operands) + return ctx.wrap_tensors(ret) diff --git a/MLPY/Lib/site-packages/torch/_higher_order_ops/wrap.py b/MLPY/Lib/site-packages/torch/_higher_order_ops/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..949c24f2abe4700a1d975f4def130f9fd04b2996 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_higher_order_ops/wrap.py @@ -0,0 +1,183 @@ +import inspect +import logging + +import torch +from torch._ops import HigherOrderOperator +from torch.utils.checkpoint import checkpoint, uid +import torch._dynamo.config + +log = logging.getLogger(__name__) + + + +# Used for testing the HigherOrderOperator mechanism +class Wrap(HigherOrderOperator): + def __init__(self): + super().__init__("wrap") + + def __call__(self, func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + result = func(*args, **kwargs) + return result + + return wrapper() + +wrap = Wrap() + +class WrapWithSetGradEnabled(HigherOrderOperator): + def __init__(self): + super().__init__("wrap_with_set_grad_enabled") + + def __call__(self, enable_grad, wrapped_func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + with torch.set_grad_enabled(enable_grad): + return wrapped_func(*args, **kwargs) + return wrapper() + +wrap_with_set_grad_enabled = WrapWithSetGradEnabled() + +class WrapActivationCheckpoint(HigherOrderOperator): + """ + This operator is used to wrap torch.utils.checkpoint. This avoids + TorchDynamo to look into saved tensor hooks and directly passes the control + to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of + AOT tracing torch.utils.checkpoint code, we have a backward graph with + recomputed forward nodes. + + However, we might deprecate this operator soon. The difficulty arises in the + functionalization of rng ops. Today, there are two different + functionalization of rng ops - one at AOT autograd and other at Inductor. + And they are difficult to map to each other. The rng states also complicate + pattern matching in Inductor. Due to the ease of implementation, we are + currently inclined towards functionalization at Inductor level, which means + that duplication/recomputation is done as a compiler pass in the + partitioners. See TagActivationCheckpoint for more information. + """ + def __init__(self): + super().__init__("wrap_activation_checkpoint") + + def __call__(self, function, *args, **kwargs): + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + kwargs["use_reentrant"] = False + kwargs["preserve_rng_state"] = False + # Using interpreter allows preservation of metadata through torch.compile stack. 
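# Editor's sketch (not part of the patch): plain eager use of the non-reentrant
# checkpoint API that WrapActivationCheckpoint routes through. `block` is a made-up
# function for illustration; its intermediate activations are recomputed during
# backward instead of being saved.
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    return torch.relu(x @ x.t())

x = torch.randn(8, 8, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()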
+ with fx_traceback.preserve_node_meta(): + return checkpoint(Interpreter(function).run, *args, **kwargs) + +wrap_activation_checkpoint = WrapActivationCheckpoint() + +class TagActivationCheckpoint(HigherOrderOperator): + """ + This operator is supposed to be used only with torch.compile stack. This + accepts a Fx graph module which needs to be checkpointed. This operator adds + "recomputable" tag to the nodes of the Fx graph that should be recomputed. + + The goal is to: + 1. Avoid using Dynamo to trace through saved tensor hooks. + 2. For selective checkpointing case, let AOTAutograd trace through + saved tensor hooks but has special logic with TorchDispatchMode to override + the usual saved_tensor_hooks fn logic in order to tag the nodes. + 3. Rely on the partitioners to actually duplicate the nodes. + This sits well in the torch.compile stack, because by the time graph + reaches partitioner, inductor has already run its functionalization of rng + ops (by setting fixed seed for each random op, see `replace_random_passes`). + Therefore, the duplication of nodes, by design, respects the rng states in + the forward and recomputed forward in backward. + """ + + def __init__(self): + super().__init__("tag_activation_checkpoint") + + @staticmethod + def divide_kwargs(kwargs): + """ + checkpoint fn can have mixed kwargs between checkpointed fn and + checkpoint fn itself. For example + >> def gn(x, y, z=None): + >> a = torch.matmul(x, y) + >> if z is not None: + >> return torch.matmul(a, z) + >> return a + >> def fn(x, y, z): + >> return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z)) + In the above case, z belongs to checkpointed function gn, but + use_reentrant belongs to the checkpoint function. This function splits + the kwargs into checkpoint_kwargs and gmod_kwargs (or + checkpointed_fn_kwargs). + We do sorting to ensure same graph from run to run for better + debuggability. It is not required for correctness. + """ + ckpt_signature = inspect.signature(checkpoint) + checkpoint_keys = set() + for name in ckpt_signature.parameters: + if name in ("function", "args", "kwargs"): + continue + checkpoint_keys.add(name) + + # `preserve_rng_state` is not a regular kwarg + checkpoint_keys.add("preserve_rng_state") + + checkpoint_kwargs = {name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys} + gmod_kwargs = {name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys} + return checkpoint_kwargs, gmod_kwargs + + def tag_nodes(self, gmod): + unique_graph_id = next(uid) + for node in gmod.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + node.meta["recompute"] = unique_graph_id + return gmod + + def __call__(self, gmod, *args, **kwargs): + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + if "_checkpoint_context_fn" in gmod.meta: + assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \ + "Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile" + log.warning(""" +Detected that context_fn is passed to torch.utils.checkpoint under torch.compile. +Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_). +""") + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. 
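# Editor's sketch (not part of the patch): the kwargs split performed by divide_kwargs
# above, shown standalone -- names appearing in torch.utils.checkpoint.checkpoint's
# signature go to the checkpoint call, everything else to the checkpointed function.
# `mixed` is a hypothetical kwargs dict.
import inspect
from torch.utils.checkpoint import checkpoint

ckpt_names = {
    name
    for name in inspect.signature(checkpoint).parameters
    if name not in ("function", "args", "kwargs")
} | {"preserve_rng_state"}

mixed = {"use_reentrant": False, "z": 3}
checkpoint_kwargs = {k: v for k, v in mixed.items() if k in ckpt_names}
fn_kwargs = {k: v for k, v in mixed.items() if k not in ckpt_names}
# Expected split: {'use_reentrant': False} vs. {'z': 3}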
+ kwargs["use_reentrant"] = False + # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through + # `torch.random.fork_rng` op (which is not supported yet under CUDA). + # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state + # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor + # instead of in AOTAutograd). + kwargs["preserve_rng_state"] = False + kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"] + # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag + # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py. + gmod = self.tag_nodes(gmod) + # Using interpreter allows preservation of metadata through torch.compile stack. + with fx_traceback.preserve_node_meta(): + return checkpoint(Interpreter(gmod).run, *args, **kwargs) + else: + gmod = self.tag_nodes(gmod) + # Using interpreter allows preservation of metadata through torch.compile stack. + # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here + # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile. + # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test) + with fx_traceback.preserve_node_meta(): + return Interpreter(gmod).run(*args) + +tag_activation_checkpoint = TagActivationCheckpoint() diff --git a/MLPY/Lib/site-packages/torch/_inductor/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29d0966ca407eee3107feb137924ca246c8bf8d0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,150 @@ +from typing import Any, Dict, List, Optional + +import torch.fx +import torch.utils._pytree as pytree + +__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. + """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aot_compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + options: Optional[Dict[str, Any]] = None, +) -> str: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. 
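# Editor's sketch (not part of the patch): compiling a standalone FX graph with the
# `compile` entry point defined above, using make_fx to capture the graph without
# TorchDynamo. Requires a working Inductor toolchain at runtime; `f` is a made-up
# function and the calling convention of the result is assumed to match the graph's.
import torch
import torch._inductor
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) + 1

example = torch.randn(4)
gm = make_fx(f)(example)
compiled = torch._inductor.compile(gm, [example])
out = compiled(torch.randn(4))  # expected to behave like gm, potentially faster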
+ + Returns: + Path to the generated shared library + """ + from .compile_fx import compile_fx_aot + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + + return compile_fx_aot( + gm, + example_inputs, + config_patches=options, + ) + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> Dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. + + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: Dict[str, Dict[str, bool]] = { + "default": {}, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + }, + } + return mode_options[mode] if mode else mode_options # type: ignore[return-value] + + +def list_options() -> List[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. + + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: Dict[str, Any] = config.shallow_copy_dict() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." 
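# Editor's sketch (not part of the patch): the mode table above is what torch.compile's
# `mode=` strings expand to, and the same dictionaries can be passed explicitly via
# `options=`. Illustrative only.
import torch
import torch._inductor

opts = torch._inductor.list_mode_options("max-autotune")
# expected: {'max_autotune': True, 'triton.cudagraphs': True}
fn = torch.compile(lambda x: x * 2, options=opts)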
+ from .cudagraph_trees import mark_step_begin + + mark_step_begin() diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2b7ec8f37182fb250dbf3e4587b6704e4bf17d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dfb6b968369008e03efba6b0ca1a1fe2960bb98 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..179dd21e5c284562023d272eaa40fddf635818e8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c583798d146f856bfecebc2ae7616888c7ab189 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02221c291e1257e245ba2eed3b59e5e9d741db48 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fb6a2898d345f539d56b6b31ec947c5f571cef0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3ed23f9890aec154f3baee41a09470f3055d13d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e997f070d47106ff1da57fba32988c8df33167a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a040f1f5a7abe6a598ac127fb9bdebdd9ff7e616 Binary files /dev/null and 
b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5ef0cafbc661712b55743fe0594e417054e6b16 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/coordinate_descent_tuner.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dec02e720073645cbd8923d621cd894542471ba Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5de6b439f891a6a01b0062a89fc49c3ea4d819d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2da29903ee89721451d88323f07d06a3dad925b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7e4a82e50a000f075a5fdc2d9ba735eed10994f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd3a676b9d12e5933ef7b3c3d9a6a71578de690b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5558306c4a66b4524af71882f03ec8a3882ef38 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dcc3d749c8f13d332c1e87e215ede1f67ff9677 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..378403e528e52d88d8f2b7bebd311d4f1df7031f Binary 
files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e421cb07dd72e689707bbd300fc4d4126908d87 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68f461be153419be45bda309e40d6d21b227be6e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..339304400bd2b5eb368c4352657f8ad5e008c607 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e719af4d20b3bdc98c5c7af6c091b4727ed1d5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ir.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ir.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa0f1fd2f7d32cc7fd71c056ece6f85883f3f0b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ir.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/lowering.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/lowering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d48384d16d681589d091555c9e2a4ae5146e6ffc Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/lowering.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b0adcb1e9c505739472e75ac7d1d0c5ed57832 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b4b8056f00d5d6aede3597bb20b5c20c497fc53 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f76c9b94117b7ba712465957ceaea6b75a0026c1 Binary files /dev/null and 
b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35ef5d63bca08a05f5a9ae675ffd88f1b7863f9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2935c40933a47db275600f656d3434e622de6372 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/scheduler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4346f248be45913fefccf40a96d5ade9cd6df848 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/scheduler.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3b62fa63180ae3e5e76aa29515f76049a6a217 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c00066a2b1d1ef038a1981ae830429c258cb0e69 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a5893a8cfe488a90f9ca95bfb09c360e0bfde9e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ed2266fbc463da832d2a87cdb1ca910e0604d3b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5a4778bc17b3187d522f31511b6cd83295abba6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_helpers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..088c6234fd9d34579889ab66defb15660a4a98d1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/triton_heuristics.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd1425fc743095ca9c80886571f7425ff4d47de Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6658f177060f9d2ed34bda0db17be0611b62b63 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c93490f3f0057705110ec80787430a0eb9d4310 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/autotune_process.py b/MLPY/Lib/site-packages/torch/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..6a3bf9fcf23f95982bc9904304dede907d0d6513 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/autotune_process.py @@ -0,0 +1,656 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import logging +import os +import queue +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref, c_size_t, c_void_p +from multiprocessing.process import BaseProcess +from multiprocessing.queues import Queue +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + TYPE_CHECKING, + Union, +) + +import torch +from torch import multiprocessing +from torch._dynamo.testing import rand_strided + +from torch._inductor import ir +from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache + +if TYPE_CHECKING: + from torch._inductor.select_algorithm import TritonTemplateCaller + +from . import config +from .utils import do_bench +from .virtualized import V + +CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" +EXIT_HANDLER_REGISTERED = False + +log = logging.getLogger(__name__) + + +# Used to synchronize between parent and child processes +class Ping: + pass + + +class Pong: + pass + + +@contextlib.contextmanager +def set_cuda_visible_device(device: Optional[int]): + """ + Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the + specified single device. If device is None, don't manipulate the environment. + """ + if device is None: + yield + return + + current = os.environ.get(CUDA_VISIBLE_DEVICES) + os.environ[CUDA_VISIBLE_DEVICES] = str(device) + try: + yield + finally: + if current is None: + del os.environ[CUDA_VISIBLE_DEVICES] + else: + os.environ[CUDA_VISIBLE_DEVICES] = current + + +@dataclasses.dataclass +class TuningProcess: + """ + Abstraction for launching a helper process to benchmark kernels. 
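# Editor's sketch (not part of the patch): a stripped-down version of the
# request/response pattern TuningProcess uses -- a "spawn" child process reads work
# items from one queue and writes results to another, with None as the shutdown
# sentinel. `_worker` and the squaring stand in for running a benchmark request.
import torch.multiprocessing as mp

def _worker(req, resp):
    while True:
        item = req.get()
        if item is None:           # sentinel: terminate
            break
        resp.put(item * item)      # stand-in for "run the benchmark"

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # the CUDA runtime does not survive fork
    req, resp = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=_worker, args=(req, resp))
    p.start()
    req.put(3)
    print(resp.get())              # 9
    req.put(None)
    p.join()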
Spawns + the parent process and uses multiprocessing queues to send benchmark + requests and return results. + """ + + device: Optional[int] = None + process: Optional[BaseProcess] = None + request_queue: Optional[Queue[Any]] = None + response_queue: Optional[Queue[Any]] = None + + @staticmethod + def process_main( + request_queue: Queue[Any], + response_queue: Queue[Any], + ) -> None: + """ + Entry point for the child process. + """ + log.debug( + "Entering TuningProcess child. Visible devices = %s", + os.environ.get(CUDA_VISIBLE_DEVICES), + ) + try: + TuningProcess.workloop(request_queue, response_queue) + except Exception as ex: + log.exception("Exception in TuningProcess: %s", ex) + + @staticmethod + def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None: + """ + Work loop for the benchmarking subprocess. + """ + while True: + obj = request_queue.get() + + if obj is None: + break # None is a sentinel for the child to terminate + elif isinstance(obj, Ping): + response_queue.put(Pong()) + elif isinstance(obj, BenchmarkRequest): + response_queue.put(obj.benchmark()) + else: + raise RuntimeError(f"Invalid request type {type(obj)}") + + def valid(self) -> bool: + """ + True if the sub-process has been initialized. + """ + return ( + self.process is not None + and self.request_queue is not None + and self.response_queue is not None + ) + + def clear(self) -> None: + """ + Reset to an uninitialized state. + """ + self.process = self.request_queue = self.response_queue = None + + def initialize(self) -> None: + """ + Create child process, request/response queues, and do the warm up. + Set the environment to make only the provided GPU device visible + to the process. + """ + if self.valid(): + return + + # cuda runtime does not work with "fork", use "spawn" to start processes. + ctx = multiprocessing.get_context("spawn") + self.request_queue = ctx.Queue() + self.response_queue = ctx.Queue() + + self.process = ctx.Process( + target=self.process_main, + args=( + self.request_queue, + self.response_queue, + ), + ) + assert self.process is not None + with set_cuda_visible_device(self.device): + self.process.start() + + def put(self, obj: Any) -> None: + """ + Push a work item to the child process. + """ + # In case of a prior crash, ensure the subprocess is running + self.initialize() + assert self.request_queue is not None + self.request_queue.put(obj) + + def get(self) -> Any: + """ + Get a response from the child process. + """ + assert self.process is not None + assert self.response_queue is not None + while True: + try: + return self.response_queue.get(timeout=1.0) + except queue.Empty: + status = self.process.exitcode + if status is None: + # child process is still running + continue + # child process crashed + self.clear() + raise + + def terminate(self) -> None: + """ + Signal the child process to terminate. + """ + if self.valid(): + assert self.process is not None + assert self.request_queue is not None + self.request_queue.put(None) + + def wait(self) -> None: + """ + Wait for the child process to exit. + """ + if self.process is not None: + self.process.join() + self.clear() + + +@dataclasses.dataclass +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. 
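# Editor's sketch (not part of the patch): the pool pattern described here, shown
# standalone -- a ThreadPoolExecutor sized to the number of workers, where each thread
# checks a worker handle out of a queue, uses it, and returns it. All names are
# illustrative.
import queue
from concurrent.futures import ThreadPoolExecutor

workers = queue.Queue()
for dev in (0, 1):                 # hypothetical device ids
    workers.put(f"worker-{dev}")

def run(job):
    w = workers.get()              # block until a worker is free
    try:
        return f"{job} ran on {w}" # stand-in for sending the benchmark request
    finally:
        workers.put(w)

with ThreadPoolExecutor(max_workers=2) as ex:
    results = list(ex.map(run, ["job-a", "job-b", "job-c"]))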
+ """ + + processes: Optional[queue.Queue[TuningProcess]] = None + executor: Optional[ThreadPoolExecutor] = None + + def initialize(self) -> None: + """ + Start the child processes. + """ + assert (self.processes is None) == (self.executor is None) + if self.processes is not None: + return + + devices = self.get_device_list() + log.debug("Sub-process autotune device list: %s", devices) + + # Launch the child processes and push a msg to "warm up" + self.processes = queue.Queue() + for device in devices: + p = TuningProcess(device=device) + p.initialize() + p.put(Ping()) + self.processes.put(p) + + # Wait for the initialization to finish + for p in self.processes.queue: + assert isinstance(p.get(), Pong) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. + self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + # Register the exit handler for the parent process so it will terminate + # the child processes. + global EXIT_HANDLER_REGISTERED + if not EXIT_HANDLER_REGISTERED: + EXIT_HANDLER_REGISTERED = True + import atexit + + atexit.register(self.terminate) + + def get_device_list(self) -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + count = torch.cuda.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices + + return list(range(count)) + + def terminate(self) -> None: + """ + Signal all child processes to terminate. + """ + if self.executor is not None: + self.executor.shutdown() + self.executor = None + + if self.processes is not None: + for p in self.processes.queue: + p.terminate() + for p in self.processes.queue: + p.wait() + self.processes = None + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. + """ + assert choice.bmreq is not None + assert self.processes is not None + + process = self.processes.get() + process.put(choice.bmreq) + try: + return process.get() + except queue.Empty: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # set to INF so this choice will be ignored + return float("inf") + finally: + self.processes.put(process) + + def benchmark( + self, + choices: List[TritonTemplateCaller], + ) -> Dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + assert self.processes is not None, "Tuning process pool is not initialized" + assert self.executor is not None + + results = {} + + # Use a ThreadExecutorPool to spread the work across the subprocesses and + # to grab subprocesses as soon as they're free. 
+ for choice, result in zip(choices, self.executor.map(self.target, choices)): + results[choice] = result + + return results + + +tuning_pool = TuningProcessPool() + + +LayoutOrBuffer = Union[ir.Layout, ir.Buffer] + + +@dataclasses.dataclass +class TensorMeta: + device: torch.device + dtype: torch.dtype + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType + offset: int + + @classmethod + def from_irnodes( + cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + ) -> Union[TensorMeta, List[TensorMeta]]: + if isinstance(irnodes, Sequence): + result: List[Any] = [cls.from_irnodes(x) for x in irnodes] + assert all(isinstance(x, TensorMeta) for x in result) + return result + + node = irnodes + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + + dtype = node.get_dtype() + assert dtype is not None + + return TensorMeta( + device=node.get_device(), + dtype=dtype, + sizes=V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + strides=V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + offset=V.graph.sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + def to_tensor(self) -> torch.Tensor: + return rand_strided( + self.sizes, + self.strides, + device=self.device, + dtype=self.dtype, + extra_size=self.offset, + ) + + +@dataclasses.dataclass +class BenchmarkRequest: + """ + Only handle triton template benchmark for now. The extern kernel benchmark + can be done inside the same process since they usually don't cause crash. + + Important: Instances of this class and subclasses have to be serializable + across process boundaries. Do not put CUDA Tensors in here! + """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + ): + # the kernel name defined in the module + self.kernel_name = kernel_name + + if isinstance(input_tensor_meta, TensorMeta): + input_tensor_meta = [input_tensor_meta] + self.input_tensor_meta = input_tensor_meta + + if isinstance(output_tensor_meta, (tuple, list)): + assert len(output_tensor_meta) == 1 + output_tensor_meta = output_tensor_meta[0] + self.output_tensor_meta = output_tensor_meta + + self.extra_args = extra_args + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + raise NotImplementedError() + + def cleanup_run_fn(self) -> None: + pass + + def benchmark( + self, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + debug = log.isEnabledFor(logging.DEBUG) + if debug: + start_ts = time.time() + + # create args and out tensor + if output_tensor is None: + assert len(input_tensors) == 0 + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + output_tensor = self.output_tensor_meta.to_tensor() + + if debug: + create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor) + + if debug: + load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + out = do_bench(fn) + torch.cuda.synchronize() # shake out any CUDA errors + + if debug: + bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + log.debug( + "InChildProcess %s: load %f, create tensor %f, bench 
%f", + str(self), + load_elapse, # type: ignore[possibly-undefined] + create_tensor_elapse, # type: ignore[possibly-undefined] + bench_elapse, + ) + self.cleanup_run_fn() + return out + + +class TestBenchmarkRequest(BenchmarkRequest): + """ + Supports unit testing. Defined in this file so that the TuningProcess + sub-process knows how to unpickle these objects. + """ + + def __init__(self, value: Optional[float] = None) -> None: + self.value = value + + def benchmark( + self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None + ) -> float: + if self.value is None: + raise Exception("Failed to run") + return self.value + + +class TritonBenchmarkRequest(BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + module_path: str, # the path of the module defining the triton kernel + module_cache_key: str, + grid: List[int], + num_stages: int, + num_warps: int, + matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. + ): + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.grid = grid + self.num_stages = num_stages + self.num_warps = num_warps + self.matrix_instr_nonkdim = matrix_instr_nonkdim + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) + + run_method = getattr(mod, self.kernel_name).run + extra_args = list(self.extra_args) + + # Newer version of triton add warmup argument to JITFunction.run. + # This code handles backward-compatibility. + warmup_arg = {} + import inspect + + if "warmup" in inspect.signature(run_method).parameters: + warmup_arg["warmup"] = False + + if torch.version.hip and self.matrix_instr_nonkdim != 0: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + num_stages=self.num_stages, + num_warps=self.num_warps, + matrix_instr_nonkdim=self.matrix_instr_nonkdim, + ) + else: + return functools.partial( + run_method, + *input_tensors, + output_tensor, + *self.extra_args, + grid=self.grid, + **warmup_arg, + num_stages=self.num_stages, + num_warps=self.num_warps, + ) + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" + + +class CUDABenchmarkRequest(BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! 
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ): + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate CUDACodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CUDACodeCache.load(self.source_code, "so") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( + self.source_code, "so" + ) + args = [ + c_void_p(tensor.data_ptr()) + for tensor in list(input_tensors) + [output_tensor] + ] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + self.workspace_size = c_workspace_size.value + # TODO: Support non-zero workspace_size. + assert self.workspace_size == 0, ( + "Things need to be fixed to support non-zero workspace_size: " + "1) max autotune cache needs to store workspace size; " + "2) memory allocation needs to allocate / deallocate workspace correctly; " + ) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + None, # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0 + stream_ptr, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + + +def benchmark_in_sub_process( + choices: List[TritonTemplateCaller], +) -> Dict[TritonTemplateCaller, float]: + """ + Do benchmarking in a subprocess and return the perf number (latency). 
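# Editor's sketch (not part of the patch): how tensors are handed to the compiled
# kernel in CUDABenchmarkRequest.make_run_fn above -- raw data pointers wrapped as
# ctypes void pointers, plus an out-parameter for the workspace size. No real kernel
# is invoked here; the commented call shape is an approximation.
from ctypes import byref, c_size_t, c_void_p
import torch

t = torch.arange(4, dtype=torch.float32)
arg = c_void_p(t.data_ptr())   # what the DLL entry point receives
workspace_size = c_size_t(0)
# a real launch would look roughly like:
#   run_method(arg, ..., byref(workspace_size), None, stream_ptr)
print(arg, workspace_size.value)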
+ """ + return tuning_pool.benchmark(choices) diff --git a/MLPY/Lib/site-packages/torch/_inductor/bounds.py b/MLPY/Lib/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52e523e99ffbe5cd89eee073825c6d0c3c65a1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,124 @@ +import operator +from functools import partial +from typing import Any, Callable, Dict + +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from .ir import InterpreterShim, LoopBody, LoopBodyBlock +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. This may benefit + the case a bounded variable is returned by a kernel and fed into another. + """ + + def __init__(self, loop_body: LoopBody) -> None: + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, v - 1) + if (isinstance(v, int) or v.is_number) + else bound_sympy(v) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + + @cache_on_self + def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: Dict[str, Callable[..., Any]] + ) -> Dict[str, Callable[..., ValueRanges[Expr]]]: + result: Dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules.keys(): + if key == "get_index": + result[key] = self.get_index + elif "masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn(subblock): + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = 
partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: Dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: Dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + def get_index(self, name: Expr) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound diff --git a/MLPY/Lib/site-packages/torch/_inductor/codecache.py b/MLPY/Lib/site-packages/torch/_inductor/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..7b477691fe5971afe19c9f7945def2a5276c4d37 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codecache.py @@ -0,0 +1,2727 @@ +from __future__ import annotations + +import base64 +import copyreg +import dataclasses +import functools +import hashlib +import importlib +import io +import json +import logging +import multiprocessing +import os +import pathlib +import pickle +import pkgutil +import platform +import re +import shlex +import shutil +import signal +import subprocess +import sys +import sysconfig +import tempfile +import textwrap +import threading +import warnings +import weakref +from bisect import bisect_right +from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from copy import copy +from ctypes import c_void_p, cdll, CDLL +from functools import partial +from pathlib import Path +from threading import Thread +from time import sleep, time +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch + +from torch._dynamo.device_interface import ( + get_interface_for_device, + get_registered_device_interfaces, +) +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor import config, exc, metrics +from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.utils import cache_dir, developer_warning, is_linux +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + TensorMetadata, +) +from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv + +if TYPE_CHECKING: + from torch._inductor.graph import GraphLowering + from torch._inductor.select_algorithm import ChoiceCaller + +from torch.hub import _Faketqdm, tqdm + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") + +if config.is_fbcode(): + from triton.fb import build_paths + from triton.fb.build import _run_build_command + + 
from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args, **kwargs): + pass + + def log_global_cache_stats(*args, **kwargs): + pass + + def log_global_cache_vals(*args, **kwargs): + pass + + def use_global_cache() -> bool: + return False + + +LOCK_TIMEOUT = 600 + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + + +def _compile_start() -> None: + global _t0 + if _t0 is None: + _t0 = time() + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0 + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + + +log = logging.getLogger(__name__) + + +def cpp_wrapper_cache_dir(name: str) -> str: + cu_str = ( + "cpu" + if torch.version.cuda is None + else f'cu{torch.version.cuda.replace(".", "")}' + ) + python_version = f"py{sys.version_info.major}{sys.version_info.minor}" + build_folder = f"{python_version}_{cu_str}" + + cpp_wrapper_dir = os.path.join(cache_dir(), build_folder) + cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name) + os.makedirs(cpp_wrapper_build_directory, exist_ok=True) + return cpp_wrapper_build_directory + + +def get_cpp_wrapper_cubin_path_name(): + return "cubin_path" if torch.version.hip is None else "hsaco_path" + + +class CacheBase: + @staticmethod + @functools.lru_cache(None) + def get_system() -> Dict[str, Any]: + try: + import triton + + triton_version = triton.__version__ + except ModuleNotFoundError: + triton_version = None + + try: + system: Dict[str, Any] = { + "device": { + "name": torch.cuda.get_device_properties( + torch.cuda.current_device() + ).name, + }, + "version": { + "cuda": torch.version.cuda, + "triton": triton_version, + }, + } + except (AssertionError, RuntimeError): + # If cuda is not installed, none of the above config is relevant. 
+ system = {} + + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + + @staticmethod + @functools.lru_cache(None) + def get_local_cache_path() -> Path: + return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) + + @staticmethod + @functools.lru_cache(None) + def get_global_cache_path() -> Optional[Path]: + return ( + Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"])) + if config.global_cache_dir is not None + else None + ) + + def __init__(self) -> None: + if not torch.cuda.is_available(): + return + + self.system = CacheBase.get_system() + + self.local_cache_path = CacheBase.get_local_cache_path() + self.global_cache_path = CacheBase.get_global_cache_path() + + def get_local_cache(self) -> Dict[str, Any]: + if not self.local_cache_path.is_file(): + return {} + with open(self.local_cache_path) as local_cache_fp: + local_cache = json.load(local_cache_fp) + return local_cache["cache"] + + def update_local_cache(self, local_cache: Dict[str, Any]) -> None: + if not os.path.exists(self.local_cache_path.parent): + os.makedirs(self.local_cache_path.parent, exist_ok=True) + + write_atomic( + str(self.local_cache_path), + json.dumps({"system": self.system, "cache": local_cache}, indent=4), + ) + + +class LocalCache(CacheBase): + def lookup(self, *keys: str) -> Optional[Dict[str, Any]]: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys: + if key in cache: + sub_cache = cache[key] + else: + return None + + return sub_cache + + def set_value(self, *keys: str, value: Any) -> None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys[0:-1]: + sub_cache.setdefault(key, {}) + sub_cache = sub_cache[key] + sub_cache[keys[-1]] = value + + self.update_local_cache(cache) + + +class PersistentCache(CacheBase): + @functools.lru_cache(None) + def get_global_cache(self): + if self.global_cache_path is None or not self.global_cache_path.is_file(): + return {} + with open(self.global_cache_path) as global_cache_fp: + global_cache = json.load(global_cache_fp) + return global_cache["cache"] + + def lookup( + self, + choices: List[ChoiceCaller], + op: str, + inputs: str, + benchmark: Callable[[Any], Dict[ChoiceCaller, float]], + ) -> Dict[ChoiceCaller, float]: + """ + Check to see if we have benchmarked the given choice callers. For each + choice caller: + + 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. + 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 3. + a. `max_autotune_gemm=True`: benchmark the choice, update + local_cache[op][inputs][choice], and return the benchmark. + b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. 
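+
+        Note that the returned timings may be partial: with
+        `max_autotune_gemm=False`, or when only the global cache is consulted,
+        choices that were never benchmarked are simply absent from the result.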
+ """ + precision = torch.get_float32_matmul_precision() + + log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) + log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) + log_errors = partial( + log_global_cache_errors, self.system, op, inputs, precision + ) + timings = {} + + def check_cache(cache, callback=None) -> bool: + """Check if `cache` contains data for all the choices""" + hit = True + for choice in choices: + choice_hash = choice.hash_key() + if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}): + # cache hit + timings[choice] = cache[op][inputs][precision][choice_hash] + else: + # cache miss + hit = False + break + if callback: + callback(cached=hit) + return hit + + if config.max_autotune or config.max_autotune_gemm: + local_cache = self.get_local_cache() + # check local cache first since it is data specific to the current machine + if not check_cache(local_cache) and not ( + use_global_cache() + and check_cache(self.get_global_cache(), callback=log_stats) + ): + try: + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing + except RuntimeError as e: + # catch and log autotuning failures + log_errors(e) + raise e + + self.update_local_cache(local_cache) + + timings_to_log = { + choice.hash_key(): timings[choice] for choice in choices + } + log_vals(timings_to_log) + elif use_global_cache(): + # only check global cache, not local one + check_cache(self.get_global_cache(), callback=log_stats) + # may have a partial cache hit, where not everything is benchmarked + + return timings + + +def get_lock_dir() -> str: + lock_dir = os.path.join(cache_dir(), "locks") + if not os.path.exists(lock_dir): + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. + return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def code_hash(code: Union[str, bytes], extra: str = ""): + hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") + if extra != "": + hashing_str = hashing_str + b"||" + extra.encode("utf-8") + return "c" + sha256_hash(hashing_str) + + +def get_path( + basename: str, extension: str, specified_dir: str = "" +) -> Tuple[str, str, str]: + if specified_dir: + if os.path.isabs(specified_dir): + subdir = specified_dir + else: + subdir = os.path.join(cache_dir(), specified_dir) + else: + subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{extension}") + return basename, subdir, path + + +def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"): + if hash_type == "code": + return code_hash(content, extra) + if hash_type in ["cubin", "hsaco"]: + return code_hash(repr(content)) + raise AssertionError(f"Unknown hash type {hash_type}") + + +def write( + content: Union[str, bytes], + extension: str, + extra: str = "", + hash_type: str = "code", + specified_dir: str = "", +) -> Tuple[str, str]: + # use striped content to compute hash so we don't end up with different + # hashes just because the content begins/ends with differnet number of + # spaces. 
+ key: str = get_hash(content.strip(), extra, hash_type) + basename, subdir, path = get_path(key, extension, specified_dir) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + if not os.path.exists(path): + write_atomic(path, content) + return basename, path + + +def write_atomic(path: str, content: Union[str, bytes]) -> None: + # Write into temporary file first to avoid conflicts between threads + # Avoid using a named temporary file, as those have restricted permissions + assert isinstance( + content, (str, bytes) + ), "Only strings and byte arrays can be saved in the cache" + path = pathlib.Path(path) + tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" + write_mode = "w" if isinstance(content, str) else "wb" + with tmp_path.open(write_mode) as f: + f.write(content) + tmp_path.rename(path) + + +@dataclasses.dataclass +class TensorMetadataAndValues: + """ + TensorMetadata plus the elements as a list of raw values. + Used for hashing inlined constants. + """ + + tensor_metadata: TensorMetadata + values: List[Any] + + +def _ident(x: Any) -> Any: + return x + + +def _reduce_fake_tensor(t): + """ + See FxGraphCachePickler. Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata(t) + return (_ident, (metadata,)) + + +def _reduce_tensor(t): + """ + See FxGraphCachePickler. Custom reducer to pickle Tensors. + """ + if t.is_mkldnn: + # TODO: These tensors don't currently pickle, so we can't cache a + # compiled graph containing them. Just fail now. If mkldnn tensors + # get pickling support, we can remove this. + raise BypassFxGraphCache() + + # If we see tensors, we know they're constants stored as attributes on + # the GraphModule. See tensor lowering; small constants are inlined. If + # we see a small tensor, therefore, no reference will ultimately remain + # in the generated code. So we need to include its value in the cache key. + # Large constants are effectively treated as inputs and we consider only + # their metadata. + metadata = extract_tensor_metadata(t) + if len(t.shape) == 0 or torch._inductor.graph.GraphLowering.can_inline_constant(t): + return (_ident, (TensorMetadataAndValues(metadata, t.tolist()),)) + else: + return (_ident, (metadata,)) + + +def _reduce_symint(s): + """ + See FxGraphCachePickler. Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and + # not the backed value. We evaluate guards stored with a cached graph + # to ensure a cached entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[FakeTensor] = _reduce_fake_tensor + dispatch_table[torch.Tensor] = _reduce_tensor + dispatch_table[torch.SymInt] = _reduce_symint + + @staticmethod + def dumps(obj) -> bytes: + """ + Pickle an object using the FxGraphCachePickler. + """ + with io.BytesIO() as stream: + pickler = FxGraphCachePickler(stream) + pickler.dump(obj) + return stream.getvalue() + + @staticmethod + def get_hash(obj: Any) -> str: + """ + Serialize an object using the FxGraphCachePickler and return a hash + of the pickled object. 
+ """ + serialized_data = FxGraphCachePickler.dumps(obj) + return sha256_hash(serialized_data) + + +@functools.lru_cache(None) +def get_inductor_code_hash() -> bytes: + """ + Compute a hash of all inductor code modules. Used by the FxGraph cache + so any inductor code changes would result in new cache keys. + """ + inductor_root = os.path.dirname(__file__) + + contents: Dict[str, bytes] = {} + for lib in pkgutil.iter_modules([inductor_root]): + spec = lib.module_finder.find_spec(lib.name, None) + assert spec is not None + module = spec.origin + assert module is not None + with open(module, "rb") as f: + contents[module] = f.read() + + return hashlib.sha256(pickle.dumps(contents)).digest() + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: List[Any] + + +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. + """ + + pass + + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. + """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + ): + self.gm = gm + self.example_inputs = example_inputs + + # Order kwargs so hashing is stable to changes in kwarg order. + self.fx_kwargs = {} + for k in sorted(fx_kwargs): + if k not in self.EXCLUDED_KWARGS: + if type(fx_kwargs[k]) is set: + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. + self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k])) + else: + self.fx_kwargs[k] = fx_kwargs[k] + + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + self.deterministic_algorithms_settings = ( + torch.are_deterministic_algorithms_enabled(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + # Global settings affecting matmul codegen. + self.cuda_matmul_settings = ( + torch.backends.cuda.matmul.allow_tf32, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, + ) + + # Also hash on various system info (including the triton compiler version). + self.torch_version = torch.__version__ + self.system_info = CacheBase.get_system() + + # And the inductor configuration and code. + self.inductor_code_hash = get_inductor_code_hash() + try: + self.inductor_config = config.save_config() + except TypeError as e: + # Some configs options are callables, e.g., post_grad_custom_pre_pass, + # and may not pickle. + log.debug("Can't pickle inductor config: %s", e) + raise BypassFxGraphCache() from e + + def debug_str(self) -> str: + """ + Get a printable string describing in more detail all the attributes + comprising this object. Useful for debugging when one graph hashes + to a different value than another. 
+ """ + + def get_str(obj) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata(obj)) + elif isinstance(obj, bytes): + return "" + else: + return str(obj) + + lines = [] + for attr, obj in vars(self).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = FxGraphCachePickler.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = FxGraphCachePickler.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = FxGraphCachePickler.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return "\n".join(lines) + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], +) -> str: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs) + # The prefix distinguishes among the other kinds of objects we + # cache in this module. + key = "f" + FxGraphCachePickler.get_hash(details) + log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str()) + return key + + +class FxGraphCache: + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. + """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "fxgraph") + + @staticmethod + def _get_tmp_dir_for_key(key: str) -> str: + """ + Return the disk location for a given cache key. + """ + return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) + + @staticmethod + def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]: + """ + Get the SymInt objects from the input list. + """ + return [s for s in inputs if isinstance(s, torch.SymInt)] + + @staticmethod + def _get_shape_env() -> Optional[ShapeEnv]: + """ + Helper to get the shape env from the tracing context. 
+ """ + ctx = torch._guards.TracingContext.try_get() + if not ctx: + return None + return ctx.fake_mode.shape_env + + @staticmethod + def _lookup_graph( + key: str, + example_inputs: List[torch.Tensor], + ) -> Optional[CompiledFxGraph]: + """ + Lookup a compiled graph in the cache by key. On a hit, return the + deserialized CompiledFxGraph object. On a miss, return None. + """ + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + return None + + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + + # Iterate over any entries in the subdir for this key and evaluate + # their guards to determine whether there's a hit. + graph = None + + for path in sorted(os.listdir(subdir)): + with open(os.path.join(subdir, path), "rb") as f: + candidate: CompiledFxGraph = pickle.load(f) + + guards_expr = candidate.guards_expr + if not guards_expr: + # No guards to evaluate, so this is a hit. + graph = candidate + break + + # Evaluate the guard expression in the current context. + symints = FxGraphCache._filter_symints(example_inputs) + + # If there's not a cache hit, we don't want the evaluation to + # affect the current env, e.g., cause the creation of new guards, + # so we evaluate with the hints instead of the symbols. + assert all(has_hint(s) for s in symints) + hints = [hint_int(s) for s in symints] + hit = bool(shape_env.evaluate_guards_expression(guards_expr, hints)) + log.debug( + "fx graph cache key %s evaluating guards for %s with values %s => %s", + key, + guards_expr, + hints, + hit, + ) + if hit: + # Now re-evaluate with the symints to add any guards to the current env. + check = bool(shape_env.evaluate_guards_expression(guards_expr, symints)) + assert check is True + log.debug( + "fx graph cache key %s post-load guards: %s", key, shape_env.guards + ) + graph = candidate + break + + # Increment the cached metrics by the amounts recorded when the FX + # graph was compiled for this cache entry. Pretending these counters + # were incremented normally is useful for testing with the cache enabled. + if graph is not None: + metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas) + + return graph + + @staticmethod + def _save_graph( + key: str, compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor] + ): + """ + Store a serialized CompiledFxGraph on disk. + """ + disk_compiled_graph = copy(compiled_graph) + # Important as compiled models are not pickleable: + disk_compiled_graph.compiled_artifact = None + + # Before serializing, compute the guard expression that will be used to + # ensure that a CompiledFxGraph is valid when loaded from the cache. It's + # sufficient to consider only the SymInt args to the fx graph since the + # Tensor shapes are already captured in the hash for the cache key. Any + # Tensor arg with a symbolic shape will have a SymInt arg for the graph. + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + symints = FxGraphCache._filter_symints(example_inputs) + disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints) + + try: + content = pickle.dumps(disk_compiled_graph) + except Exception as e: + log.debug("fx graph cache unable to serialize compiled graph: %s", e) + counters["inductor"]["fxgraph_cache_pickle_error"] += 1 + return + + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. 
The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content) + + @staticmethod + def _check_can_cache(): + """ + Check some conditions that would preclude caching and raise BypassFxGraphCache + to bypass in case caching is not possible. + """ + if config.freezing or config.aot_inductor.use_runtime_constant_folding: + # Freezing can embed constants that wouldn't be static across runs. + raise BypassFxGraphCache() + + if FxGraphCache._get_shape_env() is None: + # The treatment of guards in the caching implementation requires that + # we have a shape env. + log.debug("fx graph cache no shape env") + raise BypassFxGraphCache() + + @staticmethod + def load( + compile_fx_fn: Callable[..., Any], + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + ): + """ + Load a compiled graph from the cache. If a cached entry does not exist, + compile the graph and save it to the cache. + """ + from filelock import FileLock + + compiled_graph = None + try: + FxGraphCache._check_can_cache() + key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs) + + lock_path = os.path.join(get_lock_dir(), key + ".lock") + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + compiled_graph = FxGraphCache._lookup_graph(key, example_inputs) + if compiled_graph is None: + log.debug("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs) + FxGraphCache._save_graph(key, compiled_graph, example_inputs) + else: + log.debug("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + except BypassFxGraphCache: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + + if not compiled_graph: + compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs) + + return compiled_graph + + @staticmethod + def clear(): + """ + Clear out the on-disk cache. + """ + try: + shutil.rmtree(FxGraphCache._get_tmp_dir()) + except FileNotFoundError: + pass + + +@dataclasses.dataclass +class CompiledFxGraph: + """ + Class holding a compiled FX graph. This is the object serialized on disk + to support FxGraph caching. + """ + + compiled_artifact: Optional[Callable[..., Any]] + current_callable: Optional[Callable[..., Any]] + cache_key: Optional[str] + artifact_path: Optional[str] + cache_linemap: Optional[List[Tuple[int, str]]] + device_types: Set[str] + device_idxs: Set[int] + mutated_inputs: Set[str] + mutated_input_idxs: Set[int] + constants: Dict[str, torch.Tensor] + output_strides: Optional[List[Optional[Tuple[int, ...]]]] + disabled_cudagraphs_reason: Optional[str] + metrics_deltas: metrics.CachedMetricsDeltas + # This is a string representation of an expression we serialize + # with the object so the guards can be evaluated in a different + # context in order to verify the validity of serving a cached + # fx graph. 
The expression must be generated by: + # ShapeEnv.produce_guards_expression() + guards_expr: Optional[str] + + _boxed_call: Optional[bool] = None + + def __init__( + self, + compiled_artifact: Optional[Callable[..., Any]], + graph: GraphLowering, + output_strides: List[Optional[Tuple[int, ...]]], + disabled_cudagraphs_reason: Optional[str], + metrics_deltas: metrics.CachedMetricsDeltas, + ): + self.compiled_artifact = compiled_artifact + self.current_callable = None + self.cache_key = graph.cache_key + self.artifact_path = graph.cache_path + self.cache_linemap = graph.cache_linemap + self.device_types = graph.device_types + self.device_idxs = graph.device_idxs + self.mutated_inputs = graph.mutated_inputs + self.mutated_input_idxs = set(graph.mutated_input_idxs) + self.constants = graph.constants + self.output_strides = output_strides + self.disabled_cudagraphs_reason = disabled_cudagraphs_reason + self.metrics_deltas = metrics_deltas + self.guards_expr = None + + def __call__(self, inputs: List[Any]) -> Any: + return self.get_current_callable()(inputs) + + def get_current_callable(self) -> Callable[..., Any]: + if self.current_callable is None: + # This prevents a circular reference that makes CompiledFxGraph + # get stuck without getting garbage collected + return functools.partial(_run_from_cache, weakref.proxy(self)) + else: + return self.current_callable + + +def _run_from_cache(compiled_graph: CompiledFxGraph, inputs: List[Any]) -> Any: + # We can't really serialize callables that may be C++/Triton/etc., + # so we serialize their disk cache location instead + # TODO: When making an API that can save compiled models e2e to disk + # this will need to be better + if compiled_graph.compiled_artifact is None: + from .codecache import PyCodeCache + + assert compiled_graph.cache_key + assert compiled_graph.artifact_path + compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path( + compiled_graph.cache_key, + compiled_graph.artifact_path, + compiled_graph.cache_linemap, + compiled_graph.constants, + ).call + + return compiled_graph.compiled_artifact(inputs) + + +def cpp_compiler() -> str: + if config.is_fbcode(): + return build_paths.cc() + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + return cpp_compiler_search(search) + + +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler() + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + 
"--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +def is_gcc() -> bool: + return bool(re.search(r"(gcc|g\+\+)", cpp_compiler())) + + +def is_clang() -> bool: + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler())) + + +@functools.lru_cache(None) +def is_apple_clang() -> bool: + cxx = cpp_compiler() + version_string = subprocess.check_output([cxx, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +class VecISA: + _bit_width: int + _macro: str + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # Note [Checking for Vectorized Support in Inductor] + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, + # making the runtime check unnecessary. + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) +#include +#include +#endif + +__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" # noqa: B950 + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self) -> int: + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float) -> int: + return self._dtype_nelements[dtype] + + def build_macro(self) -> str: + return self._macro + + def build_arch_flags(self) -> str: + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + @functools.lru_cache(None) + def __bool__(self) -> bool: + if config.cpp.vec_isa_ok is not None: + return config.cpp.vec_isa_ok + + if config.is_fbcode(): + return True + + key, input_path = write(VecISA._avx_code, "cpp") + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[:-3] + "so" + build_cmd = shlex.split( + cpp_compile_command( + input_path, output_path, warning_all=False, vec_isa=self + ) + ) + try: + # Check build result + compile_file(input_path, output_path, build_cmd) + subprocess.check_call( + [ + sys.executable, + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + stderr=subprocess.DEVNULL, + env={**os.environ, "PYTHONPATH": ":".join(sys.path)}, + ) + except Exception as e: + return False + + return True + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = "-DCPU_CAPABILITY_AVX512" + _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + 
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = "-DCPU_CAPABILITY_AVX2" + _arch_flags = "-mavx2 -mfma" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecZVECTOR(VecISA): + _bit_width = 256 + _macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION" + _arch_flags = "-mvx -mzvector" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "zvector" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = "" + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self) -> bool: # type: ignore[override] + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAVX512(), VecAVX2()] + + +# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.lru_cache(None) +def valid_vec_isa_list() -> List[VecISA]: + if sys.platform != "linux": + return [] + + if platform.machine() == "s390x": + return [VecZVECTOR()] + + isa_list = [] + with open("/proc/cpuinfo") as _cpu_info: + _cpu_info_content = _cpu_info.read() + for isa in supported_vec_isa_list: + if str(isa) in _cpu_info_content and isa: + isa_list.append(isa) + return isa_list + + +def pick_vec_isa() -> VecISA: + if config.is_fbcode(): + return VecAVX2() + + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, it indicates determin the vectorization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa + + +def get_compile_only(compile_only: bool = True) -> str: + return "-c" if compile_only else "" + + +def get_shared(shared: bool = True, compile_only: bool = False) -> str: + if not shared: + return "" + if compile_only: + return "-fPIC" + if platform.system() == "Darwin" and "clang" in cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return "-shared -fPIC -undefined dynamic_lookup" + else: + return "-shared -fPIC" + + +def get_warning_all_flag(warning_all: bool = True) -> str: + return "-Wall" if warning_all else "" + + +def get_glibcxx_abi_build_flags() -> str: + return "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) + + +def cpp_flags() -> str: + flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"] + if is_clang(): + flags.append("-Werror=ignored-optimization-argument") + return " ".join(flags) + + +def cpp_wrapper_flags() -> str: + return "-DTORCH_INDUCTOR_CPP_WRAPPER" + + +def optimization_flags() -> str: + base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG" + base_flags += " -ffast-math -fno-finite-math-only" + if not config.cpp.enable_unsafe_math_opt_flag: + base_flags 
+= " -fno-unsafe-math-optimizations" + if not config.cpp.enable_floating_point_contract_flag: + base_flags += " -ffp-contract=off" + + if config.is_fbcode(): + # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies. + # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths. + # We will fix it later by exposing the lib path. + return base_flags + + if sys.platform == "darwin": + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + # Also, `-march=native` is unrecognized option on M1 + base_flags += " -Xclang" + else: + if platform.machine() == "ppc64le": + base_flags += " -mcpu=native" + else: + base_flags += " -march=native" + + # Internal cannot find libgomp.so + if not config.is_fbcode(): + base_flags += " -fopenmp" + return base_flags + + +def use_custom_generated_macros() -> str: + return "-D C10_USING_CUSTOM_GENERATED_MACROS" + + +def use_fb_internal_macros() -> str: + if config.is_fbcode(): + openmp_lib = build_paths.openmp_lib() + preprocessor_flags = " ".join( + ( + "-D C10_USE_GLOG", + "-D C10_USE_MINIMAL_GLOG", + "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ) + ) + return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}" + else: + return "" + + +def use_standard_sys_dir_headers() -> str: + if config.is_fbcode(): + return "-nostdinc" + else: + return "" + + +@functools.lru_cache(None) +def is_conda_llvm_openmp_installed() -> bool: + try: + command = "conda list llvm-openmp --json" + output = subprocess.check_output(command.split()).decode("utf8") + return len(json.loads(output)) > 0 + except subprocess.SubprocessError: + return False + + +@functools.lru_cache(None) +def homebrew_libomp() -> Tuple[bool, str]: + try: + # check if `brew` is installed + subprocess.check_output(["which", "brew"]) + # get the location of `libomp` if it is installed + # this is the location that `libomp` **would** be installed + # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details + libomp_path = ( + subprocess.check_output(["brew", "--prefix", "libomp"]) + .decode("utf8") + .strip() + ) + # check if `libomp` is installed + omp_available = os.path.exists(libomp_path) + return omp_available, libomp_path + except subprocess.SubprocessError: + return False, "" + + +def get_include_and_linking_paths( + include_pytorch: bool = False, + vec_isa: VecISA = invalid_vec_isa, + cuda: bool = False, + aot_mode: bool = False, +) -> Tuple[List[str], str, str, str, str]: + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = os.path.dirname(build_paths.cuda()) + from torch.utils import cpp_extension + + macros = "" + build_arch_flags = "" + if sys.platform == "linux" and ( + include_pytorch + or vec_isa != invalid_vec_isa + or cuda + or config.cpp.enable_kernel_profile + ): + # Note - We include pytorch only on linux right now. There is more work + # to do to enable OMP build on darwin where PyTorch is built with IOMP + # and we need a way to link to what PyTorch links. + ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")] + lpaths = cpp_extension.library_paths(cuda) + [ + sysconfig.get_config_var("LIBDIR") + ] + + libs = [] + + # No need to manually specify libraries in fbcode. 
+ if not config.is_fbcode(): + libs += ["torch", "torch_cpu"] + libs += ["gomp"] + if not aot_mode: + libs += ["torch_python"] + else: + # internal remote execution is able to find omp, but not gomp + libs += ["omp"] + if aot_mode: + ipaths += [os.path.dirname(cpp_prefix_path())] + if cuda: + # This is a special treatment for Meta internal cuda-12 where all libs + # are in lib/cuda-12 and lib/cuda-12/stubs + for i, path in enumerate(lpaths): + if path.startswith( + os.environ["CUDA_HOME"] + ) and not os.path.exists(f"{path}/libcudart_static.a"): + for root, dirs, files in os.walk(path): + if "libcudart_static.a" in files: + lpaths[i] = os.path.join(path, root) + lpaths.append(os.path.join(lpaths[i], "stubs")) + break + macros = vec_isa.build_macro() + if macros: + if config.is_fbcode() and vec_isa != invalid_vec_isa: + cap = str(vec_isa).upper() + macros = " ".join( + [ + vec_isa.build_arch_flags(), + f"-D CPU_CAPABILITY={cap}", + f"-D CPU_CAPABILITY_{cap}", + f"-D HAVE_{cap}_CPU_DEFINITION", + ] + ) + + if cuda: + if macros is None: + macros = "" + macros += " -D USE_ROCM" if torch.version.hip else " -D USE_CUDA" + + if cuda: + if torch.version.hip is not None: + libs += ["c10_hip", "torch_hip"] + macros += " -D __HIP_PLATFORM_AMD__" + else: + if config.is_fbcode(): + libs += ["cuda"] + else: + libs += ["c10_cuda", "cuda", "torch_cuda"] + build_arch_flags = vec_isa.build_arch_flags() + else: + # Note - this is effectively a header only inclusion. Usage of some header files may result in + # symbol not found, if those header files require a library. + # For those cases, include the lpath and libs command as we do for pytorch above. + # This approach allows us to only pay for what we use. + ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")] + if aot_mode: + ipaths += [os.path.dirname(cpp_prefix_path())] + lpaths = [] + if sys.platform == "darwin": + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not is_apple_clang() + + # check the `OMP_PREFIX` environment first + if os.getenv("OMP_PREFIX") is not None: + header_path = os.path.join(os.getenv("OMP_PREFIX"), "include", "omp.h") # type: ignore[arg-type] + valid_env = os.path.exists(header_path) + if valid_env: + ipaths.append(os.path.join(os.getenv("OMP_PREFIX"), "include")) # type: ignore[arg-type] + lpaths.append(os.path.join(os.getenv("OMP_PREFIX"), "lib")) # type: ignore[arg-type] + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + libs = [] if omp_available else ["omp"] + + # prefer to use openmp from `conda install llvm-openmp` + if not omp_available and os.getenv("CONDA_PREFIX") is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib") # type: ignore[arg-type] + ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include")) # type: ignore[arg-type] + lpaths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs = ["iomp5"] + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + ipaths.append(os.path.join(libomp_path, "include")) + lpaths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at 
compilation error later + else: + libs = ["omp"] if config.is_fbcode() else ["gomp"] + + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 + if not config.abi_compatible: + libs += ["c10"] + lpaths += [cpp_extension.TORCH_LIB_PATH] + + # third party libs + if config.is_fbcode(): + ipaths.append(build_paths.sleef()) + ipaths.append(build_paths.openmp()) + ipaths.append(build_paths.cc_include()) + ipaths.append(build_paths.libgcc()) + ipaths.append(build_paths.libgcc_arch()) + ipaths.append(build_paths.libgcc_backward()) + ipaths.append(build_paths.glibc()) + ipaths.append(build_paths.linux_kernel()) + ipaths.append(build_paths.cuda()) + # We also need to bundle includes with absolute paths into a remote directory + # (later on, we copy the include paths from cpp_extensions into our remote dir) + ipaths.append("include") + + static_link_libs = [] + if aot_mode and cuda and config.is_fbcode(): + # For Meta internal cuda-12, it is recommended to static link cudart + static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"] + + lpaths_str = " ".join(["-L" + p for p in lpaths]) + libs_str = " ".join(static_link_libs + ["-l" + p for p in libs]) + return ipaths, lpaths_str, libs_str, macros, build_arch_flags + + +def cpp_compile_command( + input: Union[str, List[str]], + output: str, + warning_all: bool = True, + shared: bool = True, + include_pytorch: bool = False, + vec_isa: VecISA = invalid_vec_isa, + cuda: bool = False, + aot_mode: bool = False, + compile_only: bool = False, + use_absolute_path: bool = False, +) -> str: + ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths( + include_pytorch, vec_isa, cuda, aot_mode + ) + if isinstance(input, str): + input = [input] + ipaths_str = " ".join(["-I" + p for p in ipaths]) + clang_flags = "" + if config.is_fbcode(): + if aot_mode and not use_absolute_path: + inp_name = input + out_name = output + linker_script = _LINKER_SCRIPT + else: + # We need to copy any absolute-path torch includes + inp_name = [os.path.basename(i) for i in input] + out_name = os.path.basename(output) + linker_script = os.path.basename(_LINKER_SCRIPT) + assert is_clang() + # Use clang runtime instead of libgcc + clang_flags += " --rtlib=compiler-rt" + clang_flags += " -fuse-ld=lld" + clang_flags += f" -Wl,--script={linker_script}" + linker_paths = "-B" + build_paths.glibc_lib() + linker_paths += " -L" + build_paths.glibc_lib() + else: + inp_name = input + out_name = output + linker_paths = "" # let the compiler pick + if compile_only: + libs, lpaths = "", "" + inp_name_str = " ".join(inp_name) + return re.sub( + r"[ \n]+", + " ", + f""" + {cpp_compiler()} {inp_name_str} {get_shared(shared, compile_only)} + {get_warning_all_flag(warning_all)} {cpp_flags()} + {get_glibcxx_abi_build_flags()} + {ipaths_str} {lpaths} {libs} {build_arch_flags} + {macros} {linker_paths} {clang_flags} + {optimization_flags()} + {use_custom_generated_macros()} + {use_fb_internal_macros()} + {use_standard_sys_dir_headers()} + {get_compile_only(compile_only)} + -o {out_name} + """, + ).strip() + + +def run_command_and_check(cmd: str): + cmd = shlex.split(cmd) + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + raise exc.CppCompileError(cmd, e.output) from e + + +@functools.lru_cache(None) +def split_aot_inductor_output_path(path: str) -> Tuple[str, str]: + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(".so"): + return os.path.split(path) + 
else: + return path, "" + + +class CudaKernelParamCache: + cache: Dict[str, Dict[str, str]] = dict() + clear = staticmethod(cache.clear) + + @classmethod + def set(cls, key: str, params: Dict[str, str], cubin: str) -> None: + bin_type = "cubin" if torch.version.hip is None else "hsaco" + _, path = write( + cubin, + bin_type, + hash_type=bin_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + ) + + params[get_cpp_wrapper_cubin_path_name()] = path + + cls.cache[key] = params + + @classmethod + def get(cls, key: str) -> Optional[Dict[str, str]]: + return cls.cache.get(key, None) + + @classmethod + def get_keys(cls): + return cls.cache.keys() + + +class AotCodeCompiler: + @classmethod + def compile( + cls, + graph: GraphLowering, + source_code: str, + serialized_extern_kernel_nodes: Optional[str], + cuda: bool, + ) -> str: + picked_vec_isa = pick_vec_isa() + cpp_command = repr( + cpp_compile_command( + "i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode + ) + ) + fbcode_aot_cpu_re = False + use_absolute_path = False + if config.is_fbcode(): + ld_command = build_paths.ld() + if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU + objcopy_command = build_paths.objcopy_fallback() + fbcode_aot_cpu_re = True + use_absolute_path = True + else: + objcopy_command = build_paths.objcopy() + else: + ld_command = "ld" + objcopy_command = "objcopy" + + ( + specified_output_path, + specified_so_name, + ) = split_aot_inductor_output_path(config.aot_inductor.output_path) + key, input_path = write( + source_code, + "cpp", + extra=cpp_command, + specified_dir=specified_output_path, + ) + + def _compile_consts_linux(consts: bytes) -> str: + _, consts_path = write( + consts, + "bin", + specified_dir=specified_output_path, + ) + + consts_o = os.path.splitext(consts_path)[0] + ".o" + if fbcode_aot_cpu_re: + cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}" + compile_file(consts_path, consts_o, cmd.split()) + os.chmod(consts_o, 0o644) + else: + cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}" + run_command_and_check(cmd) + log.debug("aot constant binary command: %s", cmd) + + cmd = ( + f"{objcopy_command} --rename-section" + " .data=.lrodata,alloc,load,readonly,data,contents" + f" {consts_o} {consts_o}" + ) + log.debug("aot constant obj command: %s", cmd) + run_command_and_check(cmd) + + cmd = f"rm {consts_path}" + log.debug("aot constant bin removal command: %s", cmd) + run_command_and_check(cmd) + + if fbcode_aot_cpu_re: + body = re.sub(r"[\W]", "_", os.path.basename(consts_path)) + else: + body = re.sub(r"[\W]", "_", consts_path) + + symbol_list = [] + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}" + ) + symbol_list.append( + f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}" + ) + log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list)) + for cmd in symbol_list: + run_command_and_check(cmd) + return consts_o + + def _compile_consts_darwin(consts: bytes) -> str: + is_large_consts = len(consts) > 1024 + consts_asm = "\t.section\t__TEXT,__const\n" + consts_asm += "\t.globl\t__binary_constants_bin_start\n" + consts_asm += "__binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + 
# Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += ".globl\t__binary_constants_bin_end\n" + consts_asm += "__binary_constants_bin_end:\n" + _, consts_path = write( + consts_asm, + "S", + specified_dir=specified_output_path, + ) + consts_o = os.path.splitext(consts_path)[0] + ".o" + cmd = f"{cpp_compiler()} -c -o {consts_o} {consts_path}" + run_command_and_check(cmd) + if is_large_consts: + with open(consts_o, "r+b") as f: + f.seek(0) + hdr = f.read(1024) + # Search for magic number and write the actual data over it + start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + assert start_idx != -1 + f.seek(start_idx) + pos = 0 + while pos < len(consts): + rc = f.write(consts[pos:]) + pos += rc + return consts_o + + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + # Currently, this only support serializing extern nodes in fbcode + # Eventually, we should also have a serializer for OSS. + if config.is_fbcode() and serialized_extern_kernel_nodes: + output_json = os.path.splitext(input_path)[0] + ".json" + with open(output_json, "w") as f: + f.write(serialized_extern_kernel_nodes) + + output_so = ( + config.aot_inductor.output_path + if specified_so_name + else os.path.splitext(input_path)[0] + ".so" + ) + + output_o = os.path.splitext(input_path)[0] + ".o" + cmd = cpp_compile_command( + input=input_path, + output=output_o, + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + ) + log.debug("aot compilation command: %s", cmd) + if fbcode_aot_cpu_re: + compile_file(input_path, output_o, cmd.split()) + os.chmod(output_o, 0o644) + else: + run_command_and_check(cmd) + + def _to_bytes(t: torch.Tensor) -> bytes: + # This serializes the tensor's untyped_storage to bytes by accessing + # the raw data of the underlying structure. + import ctypes + + if t.numel() == 0: + return b"" + + t_cpu = t.untyped_storage().cpu() + raw_array = ctypes.cast( + t_cpu.data_ptr(), + ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()), + ) + + return bytes(raw_array.contents) + + aot_constants = b"".join( + _to_bytes(tensor) + for name, tensor in graph.constants.items() + if name not in graph.folded_constants + ) + consts_o = { + "linux": _compile_consts_linux, + "darwin": _compile_consts_darwin, + }[sys.platform](aot_constants) + + cmd = cpp_compile_command( + input=[output_o, consts_o], + output=output_so, + vec_isa=picked_vec_isa, + cuda=cuda, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + log.debug("aot linkage command: %s", cmd) + if fbcode_aot_cpu_re: + compile_file([output_o, consts_o], output_so, cmd.split()) + os.chmod(output_so, 0o755) + else: + run_command_and_check(cmd) + + return output_so + + +# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py. +# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock. +# Cycle goes: +# - CppCodeCache.load() +# - pick_vec_isa() +# - valid_vec_isa_list() +# - VecISA.__bool__() <-- takes out a lock +# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock. 
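+# Illustrative sketch (editor's addition, not part of the upstream file): most
+# caches in this module share one content-addressed pattern -- hash the source
+# text, derive a path under cache_dir() from the hash, and guard the one-time
+# build with a FileLock. The helper below is a hypothetical, minimal rendering
+# of that pattern; `build_artifact` stands in for whatever compile step a
+# concrete cache (e.g. CppCodeCache) actually runs, and nothing here calls it.
+def _toy_content_addressed_build(
+    source: str, build_artifact: Callable[[str, str], None]
+) -> str:
+    from filelock import FileLock
+
+    key = hashlib.sha256(source.encode("utf-8")).hexdigest()
+    out_path = os.path.join(cache_dir(), key[:2], f"{key}.so")
+    os.makedirs(os.path.dirname(out_path), exist_ok=True)
+    with FileLock(out_path + ".lock", timeout=LOCK_TIMEOUT):
+        if not os.path.exists(out_path):
+            build_artifact(source, out_path)  # compile exactly once per key
+    return out_path
+
+
+# (The deadlock note above refers to cpp_prefix_path(), defined next.)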
+@functools.lru_cache +def cpp_prefix_path() -> str: + path = Path(__file__).parent / "codegen/cpp_prefix.h" + with path.open() as f: + content = f.read() + _, filename = write( + content, + "h", + ) + return filename + + +def cpp_prefix() -> str: + filename = cpp_prefix_path() + if config.is_fbcode(): + # We need relative paths, since we bundle up + # everything that we compile into a folder for remote compilation. + return f'#include "{os.path.basename(filename)}"' + else: + return f'#include "{filename}"' + + +# Given a path to an input cpp file and an output path, +# Attempts to compile the file, storing the output in "output_path" +@dynamo_timed +def compile_file( + input_path: Union[str, List[str]], output_path: str, cmd: List[str] +) -> None: + input_paths = [input_path] if isinstance(input_path, str) else input_path + input_files = [ + os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths + ] + try: + if config.is_fbcode(): + # Need to copy our header into the same folder as the sourcecode. + header_path = cpp_prefix_path() + header_name = os.path.basename(header_path) + output_name = os.path.basename(output_path) + # When we build remotely, we need to make sure to carefully copy any files + # that are required during the compilation process into our build directly. + # This is where all of the ATen/c10/Torch includes come from. + torch_includes_path = os.path.join(_TORCH_PATH, "include") + with tempfile.TemporaryDirectory() as tmp_dir: + # Copy everything to tmp compilation folder + shutil.copy(header_path, os.path.join(tmp_dir, header_name)) + shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) + for p, f in zip(input_paths, input_files): + shutil.copy(p, os.path.join(tmp_dir, f)) + dest_include_path = os.path.join(tmp_dir, "include") + shutil.copytree(torch_includes_path, dest_include_path) + # Run the build + output_file_path = _run_build_command(cmd, tmp_dir, output_name) + # Copy output from the build + if os.path.exists(output_path): + os.remove(output_path) + shutil.copy(output_file_path, output_path) + else: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." 
+ ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + + +_libgomp: Optional[CDLL] = None + + +class CppCodeCache: + cache: Dict[str, Union[CDLL, ModuleType]] = {} + clear = staticmethod(cache.clear) + cpp_compile_command_flags: Dict[str, Any] = {} + + @staticmethod + def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: + return cdll.LoadLibrary(path) + + @classmethod + def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: + try: + return cls._load_library_inner(path, key) + except (ImportError, OSError) as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + return cls._load_library_inner(path, key) + if "failed to map segment from shared object" in str(e): + raise OSError( + f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " + "is mounted with noexec (e.g., by default Docker mounts tmp file systems " + f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " + "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." + ) from e + raise + + @classmethod + def load(cls, source_code: str, cuda: bool = False) -> Union[CDLL, ModuleType]: + cls.cpp_compile_command_flags.update({"cuda": cuda}) + picked_vec_isa = pick_vec_isa() + cpp_command = repr( + cpp_compile_command( + "i", "o", vec_isa=picked_vec_isa, **cls.cpp_compile_command_flags + ) + ) + key, input_path = write(source_code, "cpp", extra=cpp_command) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[:-3] + "so" + if not os.path.exists(output_path): + cmd = shlex.split( + cpp_compile_command( + input=input_path, + output=output_path, + vec_isa=picked_vec_isa, + **cls.cpp_compile_command_flags, + ) + ) + compile_file(input_path, output_path, cmd) + cls.cache[key] = cls._load_library(output_path, key) + cls.cache[key].key = key # type: ignore[union-attr] + + return cls.cache[key] + + +# Customized Python binding for cpp kernels +class CppPythonBindingsCodeCache(CppCodeCache): + cache: Dict[str, Union[CDLL, ModuleType]] = {} + clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + # kernels have no dependency on libtorch + "include_pytorch": False, + "shared": True, + } + entry_function = "kernel" + call_entry_function = "kernel(%s);Py_RETURN_NONE;" + extra_parse_arg = "" + suffix_template = textwrap.dedent( + """ + // Python bindings to call %s(): + #define PY_SSIZE_T_CLEAN + #include + #include + #include + + // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. + // We manually link it below to workaround issues with fbcode build. 
+ static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); + + template <typename T> static inline T parse_arg(PyObject* args, size_t n) { + static_assert(std::is_pointer<T>::value, "arg type must be pointer or long"); + return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); + } + template <> inline long parse_arg<long>(PyObject* args, size_t n) { + auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); + if(result == -1 && PyErr_Occurred()) + [[unlikely]] throw std::runtime_error("expected int arg"); + return result; + } + + %s + + static PyObject* %s_py(PyObject* self, PyObject* args) { + try { + if(!PyTuple_CheckExact(args)) + [[unlikely]] throw std::runtime_error("tuple args required"); + if(PyTuple_GET_SIZE(args) != %s) + [[unlikely]] throw std::runtime_error("requires %s args"); + %s + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + } + } + + static PyMethodDef py_methods[] = { + {"%s", %s_py, METH_VARARGS, ""}, + {NULL, NULL, 0, NULL}}; + + static struct PyModuleDef py_module = + {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods}; + + PyMODINIT_FUNC PyInit_%s(void) { + const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); + if(!str_addr) { + PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); + return nullptr; + } + std::istringstream iss(str_addr); + uintptr_t addr = 0; + iss >> addr; + _torchinductor_pyobject_tensor_data_ptr = + reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr); + return PyModule_Create(&py_module); + } + """ + ) + + @classmethod + def _load_library_inner(cls, path: str, key: str) -> ModuleType: + os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( + torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] + ) + return importlib.machinery.ExtensionFileLoader( + f"{key}.{cls.entry_function}", path + ).load_module() # type: ignore[call-arg] + + @classmethod + def load_pybinding( + cls, + argtypes: List[str], + source_code: str, + cuda: bool = False, + num_outputs: int = -1, + ) -> Any: + """ + Wrap a C++ function in fast Python bindings. + + Args: + argtypes: The types of args to ENTRY_FUNCTION(), e.g.
["float*", "long"] + source_code: C++ source code containing a ENTRY_FUNCTION() function + + Returns: + A python version of ENTRY_FUNCTION() + """ + parseargs = ", ".join( + f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" + for n, argtype in enumerate(argtypes) + ) + suffix = cls.suffix_template % ( + cls.entry_function, + cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "", + cls.entry_function, + len(argtypes), + len(argtypes), + cls.call_entry_function % parseargs, + cls.entry_function, + cls.entry_function, + cls.entry_function, + cls.entry_function, + ) + result = cls.load(source_code + suffix, cuda) + assert isinstance(result, ModuleType) + return getattr(result, cls.entry_function) + + +class CppWrapperCodeCache(CppPythonBindingsCodeCache): + cache: Dict[str, Union[CDLL, ModuleType]] = {} + clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + "include_pytorch": True, + "shared": True, + } + entry_function = "inductor_entry_cpp" + call_entry_function = "return THPVariable_WrapList(inductor_entry_cpp(%s));" + extra_parse_arg = textwrap.dedent( + """ + #include + #include + + template <> inline std::vector parse_arg>(PyObject* args, size_t n) { + return THPVariable_UnpackList(PyTuple_GET_ITEM(args, n)); + } + + std::vector inductor_entry_cpp(std::vector&& inputs) { + auto input_handles = unsafe_alloc_new_handles_from_tensors(inputs); + // For outputs, we only allocate a vector to hold returned tensor handles, + // not allocating the actual output tensor storage here + std::vector output_handles(%s); + + try { + inductor_entry_impl(input_handles.data(), output_handles.data()); + } catch(std::exception const& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return {}; + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return {}; + } + + return alloc_tensors_by_stealing_from_handles(output_handles.data(), output_handles.size()); + } + """ + ) + + +class PyCodeCache: + cache: Dict[str, ModuleType] = dict() + linemaps: Dict[str, List[Tuple[Any, ...]]] = dict() + clear = staticmethod(cache.clear) + + @classmethod + def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]: + return write(source_code, "py", extra=extra) + + @classmethod + def load( + cls, + source_code: str, + extra: str = "", + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + key, path = write(source_code, "py", extra=extra) + return cls.load_by_key_path(key, path, linemap, attrs) + + @classmethod + def load_by_key_path( + cls, + key: str, + path: str, + linemap: Optional[List[Tuple[int, str]]] = None, + attrs: Optional[Dict[str, Any]] = None, + ) -> ModuleType: + if linemap is None: + linemap = [] + if key not in cls.cache: + with open(path) as f: + try: + code = compile(f.read(), path, "exec") + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + sys.modules[mod.__name__] = mod + # another thread might set this first + cls.cache.setdefault(key, mod) + # unzip into separate lines/nodes lists + cls.linemaps[path] = list(zip(*linemap)) + + if attrs is not None: + for k, v in attrs.items(): + setattr(mod, k, v) + + return cls.cache[key] + + @classmethod + @functools.lru_cache(None) + def stack_frames_for_code( + cls, path: str, lineno: int + ) -> Optional[List[Dict[str, Any]]]: 
+ if path not in cls.linemaps: + return None + # [(starting_line, ), ...] + lines, nodes = cls.linemaps[path] + p = bisect_right(lines, lineno) + if p == 0: + return None + entry = nodes[p - 1] + if not entry: + return None + + def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]: + # ideally fx stores stack traces as data rather than a string + # but this is not along a performance critical path + regex = r'File "(.+)", line (\d+), in (.+)\n' + matches = re.findall(regex, stack_trace) + return [ + {"filename": f, "line": int(l), "name": n} + for f, l, n in reversed(matches) + ] + + return parse_stack_trace(entry) + + +class TritonCodeCache: + @classmethod + def load(cls, kernel_name: str, source_code: str) -> ModuleType: + mod = PyCodeCache.load(source_code) + return getattr(mod, kernel_name) + + +def _cuda_compiler() -> Optional[str]: + if cuda_env.nvcc_exist(config.cuda.cuda_cxx): + return config.cuda.cuda_cxx + if cuda_env.nvcc_exist(os.getenv("CUDACXX")): + return os.getenv("CUDACXX", "") + if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): + return os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc") + return "nvcc" + + +def _cutlass_include_paths() -> List[str]: + cutlass_path = config.cuda.cutlass_dir + return [ + os.path.join(cutlass_path, "include"), + os.path.join(cutlass_path, "tools/library/include"), + os.path.join(cutlass_path, "tools/library/src"), + os.path.join(cutlass_path, "tools/util/include"), + ] + + +def _cuda_lib_options() -> List[str]: + from torch.utils import cpp_extension + + extra_ldflags: List[str] = [] + if is_linux(): + extra_lib_dir = "lib64" + if not os.path.exists( + cpp_extension._join_cuda_home(extra_lib_dir) + ) and os.path.exists(cpp_extension._join_cuda_home("lib")): + # 64-bit CUDA may be installed in "lib" + # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64" + extra_lib_dir = "lib" + extra_ldflags.append(f"-L{cpp_extension._join_cuda_home(extra_lib_dir)}") + extra_ldflags.append( + f'-L{cpp_extension._join_cuda_home(extra_lib_dir, "stubs")}' + ) + extra_ldflags.append("-lcuda") + extra_ldflags.append("-lcudart") + else: + raise NotImplementedError( + "Unsupported env, failed to find cuda libs! Currently only Linux is supported." + ) + return extra_ldflags + + +def _nvcc_host_compiler_options() -> List[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _nvcc_compiler_options() -> List[str]: + arch = cuda_env.get_cuda_arch() + if arch == "90": + # Required by cutlass compilation. + arch = "90a" + code = [f"sm_{arch}", f"compute_{arch}"] + if config.cuda.enable_cuda_lto: + code += [f"lto_{arch}"] + options = [ + "-t=0", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-w", + f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", + config.cuda.compile_opt_level, + "-std=c++17", + "--expt-relaxed-constexpr", + "-DNDEBUG", + ] + if config.cuda.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + if config.cuda.enable_ptxas_info: + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) 
+ "--source-in-ptx", + ] + ) # Annotate the ptx file with source information + if config.cuda.use_fast_math: + options.extend( + [ + "--use_fast_math", + "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", + ] + ) + return options + + +def cuda_compile_command( + src_files: List[str], + dst_file: str, + dst_file_ext: str, +) -> str: + include_paths = _cutlass_include_paths() + cuda_lib_options = _cuda_lib_options() + nvcc_host_compiler_options = _nvcc_host_compiler_options() + nvcc_compiler_options = _nvcc_compiler_options() + options = ( + nvcc_compiler_options + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in nvcc_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + cuda_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + log.debug("CUDA command: %s", res) + return res + + +class DLLWrapper: + """A wrapper for a dynamic library.""" + + def __init__( + self, + lib_path: str, + ): + self.lib_path = lib_path + self.DLL = cdll.LoadLibrary(lib_path) + self.is_open = True + + def close(self): + if self.is_open: + self._dlclose() + self.is_open = False + + def _dlclose(self): + f_dlclose = None + + if is_linux(): + syms = CDLL(None) + if not hasattr(syms, "dlclose"): + # Apline Linux + syms = CDLL("libc.so") + + if hasattr(syms, "dlclose"): + f_dlclose = syms.dlclose + else: + raise NotImplementedError("Unsupported env, failed to do dlclose!") + + if f_dlclose is not None: + f_dlclose.argtypes = [c_void_p] + f_dlclose(self.DLL._handle) + else: + log.warning( + "dll unloading function was not found, library may not be unloaded properly!" + ) + + def __getattr__(self, name): + if not self.is_open: + raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") + + method = getattr(self.DLL, name) + + def _wrapped_func(*args): + err = method(*args) + if err: + raise RuntimeError(f"Error in function: {method.__name__}") + + return _wrapped_func + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def __del__(self): + self.close() + + +class CUDACodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = dict() + clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cu" + + @classmethod + def write(cls, source_code, dst_file_ext) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile(cls, source_code, dst_file_ext) -> Tuple[str, str, str]: + """ + Compiles CUDA source_code into a file with dst_file_ext extension. 
+ Returns a tuple of dst_file_path, hash_key, source_code_path + """ + + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = cuda_compile_command( + [input_path], output_path, dst_file_ext + ).split(" ") + try: + subprocess.check_output( + cmd, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd, error.output) from error + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas") + ) + if not os.path.exists(ptxas_path): + return + if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = ptxas_path + else: + warnings.warn(f"{ptxas_path} exists but is not an executable") + + +def _worker_compile( + kernel_name: str, source_code: str, cc: int, device: torch.device +) -> None: + device_interface = get_interface_for_device(device.type) + device_interface.Worker.set_device(device.index) + kernel = TritonCodeCache.load(kernel_name, source_code) + kernel.precompile(warm_cache_only_with_cc=cc) + + +def _load_kernel(kernel_name: str, source_code: str) -> ModuleType: + _set_triton_ptxas_path() + kernel = TritonCodeCache.load(kernel_name, source_code) + kernel.precompile() + return kernel + + +class TritonFuture: + kernel: ModuleType + + def __init__( + self, + kernel_name: str, + source_code: str, + future: Future[Any], + ) -> None: + self.kernel_name = kernel_name + self.source_code = source_code + self.future = future + + # @dynamo_utils.dynamo_timed + def result(self) -> ModuleType: + t0 = time() + if hasattr(self, "kernel"): + return self.kernel + # If the worker failed this will throw an exception. + self.future.result() + kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code) + latency = time() - t0 + if latency > 50: + developer_warning( + f"Detected long compilation time of {latency} seconds for kernel name {self.kernel_name}" + ) + developer_warning(self.source_code) + del self.kernel_name, self.source_code, self.future + return kernel + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead +# the workers will have their parent reassigned to the +# init process. 
This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid) -> None: + def run() -> None: + while True: + sleep(1) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + +_watchdog_thread: Optional[Thread] = None + +# Used to keep track of all process pools invoked so far. +_pool_set: Set[ProcessPoolExecutor] = set() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + global _pool_set + for pool in _pool_set: + pool.shutdown() + _pool_set.clear() + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert config.compile_threads > 1 + return ThreadPoolExecutor(config.compile_threads) + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> ProcessPoolExecutor: + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + assert config.compile_threads > 1 + orig_ppid = os.getpid() + + ctx = multiprocessing.get_context(config.worker_start_method) + pool = ProcessPoolExecutor( + config.compile_threads, + mp_context=ctx, + initializer=partial(_async_compile_initializer, orig_ppid), + ) + + global _pool_set + _pool_set.add(pool) + + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + return pool + + @classmethod + def warm_pool(cls) -> None: + if config.compile_threads <= 1: + return + _compile_start() + pool = cls.process_pool() + + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. + + # We force them to start here with some YOLOing of the internal methods. 
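The same warm-start idea can be shown with a plain `ProcessPoolExecutor`, independent of Inductor's pool: submitting (and waiting on) trivial tasks right after construction makes the worker processes spawn while the parent is still small, so later real work does not pay the inflated fork cost. This is only an illustrative sketch using the public `concurrent.futures` API; the method continues below by poking the executor's private startup hooks directly.

```python
# Standalone sketch of warm-starting a process pool early; this uses only the
# public concurrent.futures API, unlike the private-method approach below.
from concurrent.futures import ProcessPoolExecutor


def _noop() -> None:
    return None


def warm_started_pool(num_workers: int = 4) -> ProcessPoolExecutor:
    pool = ProcessPoolExecutor(num_workers)
    # Submitting and waiting on trivial tasks makes the executor spawn its
    # worker processes now, while forking is still cheap.
    for fut in [pool.submit(_noop) for _ in range(num_workers)]:
        fut.result()
    return pool


if __name__ == "__main__":
    pool = warm_started_pool(2)
    print(pool.submit(sum, [1, 2, 3]).result())  # workers are already running
    pool.shutdown()
```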
+ if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(config.compile_threads): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if config.compile_threads <= 1: + return task() + return cls.pool().submit(task) + + @classmethod + def map(cls, fn: Callable[..., Any], seq: List[Any]) -> List[Any]: + if config.compile_threads <= 1 or len(seq) <= 1: + return list(map(fn, seq)) + return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]] + + def triton( + self, kernel_name: str, source_code: str, device_str: str = "cuda" + ) -> Union[TritonFuture, ModuleType]: + _compile_start() + + if config.compile_threads > 1: + device_interface = get_interface_for_device(device_str) + device = torch.device(device_str, device_interface.current_device()) + cc = device_interface.get_compute_capability(device) + future = self.process_pool().submit( + _worker_compile, kernel_name, source_code, cc, device + ) + return TritonFuture(kernel_name, source_code, future) + else: + return _load_kernel(kernel_name, source_code) + + def multi_kernel(self, *args, **kwargs) -> ModuleType: + """ + Async compile the python shim for multi-kernel. + """ + + def task(): + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + return MultiKernelCall(*args, **kwargs) + + return self.submit(task) + + def cpp(self, source_code: str) -> ModuleType: + def task(): + return CppCodeCache.load(source_code).kernel + + return self.submit(task) + + def cpp_pybinding(self, argtypes: List[str], source_code: str) -> ModuleType: + return self.submit( + functools.partial( + CppPythonBindingsCodeCache.load_pybinding, argtypes, source_code + ) + ) + + def cuda(self, source_code, dst_file_ext): + def task(): + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def wait(self, scope: Dict[str, Any]) -> None: + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, TritonFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + if config.compile_threads > 1: + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + if isinstance(result, (Future, TritonFuture)): + scope[key] = result.result() + pbar.update(1) + + _compile_end() + + +if os.environ.get("TORCH_TNT_IN_USE", "0") == "1": + # When TorchTNT is used, calling warm_pool() here will cause the + # compile workers created not being able to be shut down inside + # shutdown_compile_workers(). This may cause significant QPS drop. 
+ log.info("Do not call AsyncCompile.warm_pool() because TorchTNT is in use.") +else: + AsyncCompile.warm_pool() diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f28ee719d08e4899440b3c58de77e45fad9741ae Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..157be12cca2e1adfe810231c509aebd50d0ccae3 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1d507d9746e255ff60d7cc3befa56709975b4e6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81dfe7c70d342ca6064ef9224c83f1c488e85bed Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..776ebba7288dc168e98adbb78fdc49e0af1fe1b4 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cuda.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be6c870c52e1a0237b3123c49fa937e808eb2c83 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6bee7f4a75a6c50ffa74c1850ae7de9e8f3657d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..fa308694c34e7785120e8f2a80f33d8af2721130 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27237f95eaa88518ac8f9f1e0b77e7ebfab0be8f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45caa2e6f8b232c041c9ed35d47011aed4befd6e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_foreach.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25da24eed95159e51ba2bb1ae14df8326763d743 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f8cf841a22d6afa04bac20220b3344c0f691c92 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a646ae6fc021badf00a4dd1cb43cbf29f3daef8a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp b/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09ba9a2733120fd3fc5680cbfac412e3d314ff4e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/implementation.cpp @@ -0,0 +1,87 @@ +// NOTE: Like interface.cpp, this file will be copied into AOTInductor +// generated output. This file is intended to keep implementation +// details separate from the implementation of the AOTI public +// interface. Note also that #includes should go into interface.cpp +// for simplicity of maintenance. 
+ +namespace torch { +namespace aot_inductor { +template +void convert_output_to_handle( + const ArrayRefTensor& output, + AtenTensorHandle& handle) { + handle = output.expensiveCopyToTensor(); +} + +template +void convert_outputs_to_handles_helper( + const std::tuple...>& outputs, + AtenTensorHandle* output_handles, + std::index_sequence) { + (convert_output_to_handle(std::get(outputs), output_handles[Is]), ...); +} +template +void convert_outputs_to_handles( + const std::tuple...>& outputs, + AtenTensorHandle* output_handles) { + convert_outputs_to_handles_helper( + outputs, output_handles, std::make_index_sequence()); +} + +template +void convert_handle_to_arrayref_tensor( + AtenTensorHandle handle, + ArrayRefTensor& input) { + void* data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr)); + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim)); + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel)); + int64_t* sizes; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes)); + int64_t* strides; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides)); + int32_t dtype; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype)); + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type)); + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(handle, &device_index)); + + input = ArrayRefTensor( + MiniArrayRef(reinterpret_cast(data_ptr), numel), + MiniArrayRef(sizes, dim), + MiniArrayRef(strides, dim), + device_type, + device_index); +} + +template +void convert_handles_to_inputs_helper( + AtenTensorHandle* input_handles, + std::tuple...>& inputs, + std::index_sequence) { + (convert_handle_to_arrayref_tensor(input_handles[Is], std::get(inputs)), + ...); +} + +template +void convert_handles_to_inputs( + AtenTensorHandle* input_handles, + std::tuple...>& inputs) { + convert_handles_to_inputs_helper( + input_handles, inputs, std::make_index_sequence()); +} + +template +void assert_numel(const ArrayRefTensor& tensor, int64_t numel) { + if (tensor.numel() != numel) { + std::stringstream err; + err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel(); + throw std::runtime_error(err.str()); + } +} +} // namespace aot_inductor +} // namespace torch diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d88881f3afcaf758442b202224f88bae6d47afe8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,354 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred." 
<< std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) { + aoti_torch_grad_mode_set_enabled(false); + } + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + bool prev_mode; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? "cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0" << std::endl; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + 
reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError 
AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/common.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..843bbd64fc74a7efeea3c8cb0f0adb80049b6e27 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,1755 @@ +import contextlib +import dataclasses +import functools +import itertools +import logging +import operator +import re +from itertools import chain +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy.printing.printer import Printer + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from 
torch.utils._sympy.value_ranges import ValueRanges + +from .. import config, metrics +from ..utils import ( + DeferredLineBase, + do_bench, + free_symbol_startswith, + IndentedBuffer, + sympy_dot, + sympy_index_symbol, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + +if TYPE_CHECKING: + from ..ir import TensorBox + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + + +def data_type_logger(msg): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class WorkspaceArg: + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + """ + + nbytes: sympy.Expr + zero_fill: bool + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.Integer(0) + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: type + wrapper_codegen: type + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] + +device_codegens: Dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name): + raise NotImplementedError() + + def set_device(self, device_idx): + raise NotImplementedError() + + def synchronize(self): + raise NotImplementedError() + + def device_guard(self, device_idx): + raise NotImplementedError() + + +device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. +# +# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. 
+# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, device_scheduling: type, device_wrapper_codegen: type +): + device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen) + + +def get_scheduling_for_device(device: str): + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device(device: str): + return ( + device_codegens[device].wrapper_codegen if device in device_codegens else None + ) + + +def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str): + assert isinstance(device, str) + + if not device_op_overrides_dict.keys(): + from .cuda import device_op_overrides # noqa: F401 + + if device in device_op_overrides_dict.keys(): + return device_op_overrides_dict[device] + + return DeviceOpOverrides() + + +@functools.lru_cache(None) +def boolean_ops(): + return ( + "is_inf", + "is_nan", + "bitwise_xor", + "logical_not", + "signbit", + "le", + "lt", + "ge", + "gt", + "eq", + "ne", + ) + + +DTYPE_TO_COMPUTATION_DTYPE = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +class DataTypePropagation: + def __init__(self, body) -> None: + self.body = body + self.graphs: Dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propogated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propogated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node): + if node.target in boolean_ops(): + return torch.bool + + if node.op == "placeholder": + return None + + if node.target == "output": + # we can infer output node if it only have 1 arg + if len(node.args) != 1: + return None + + if node.target in ( + "to_dtype", + "index_expr", + ): + return node.args[-1] + + if node.target in ( + "rand", + "randn", + ): + return torch.float + + if node.target in ( + "get_index", + "index_expr", + ): + return torch.int64 + + if node.target in ( + "load", + "store", + "store_reduction", + ): + buf_name = node.args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + + if node.target 
== operator.getitem: + return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type] + + assert isinstance(node.target, str) + + if node.target == "reduction": + return node.args[1] + + if node.target == "constant": + return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] # type: ignore[index] + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph): + assert graph.nodes + graph_dtype = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self): + self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body): + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node): + from ..ir import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode) + assert isinstance(node._body, LoopBody) + DataTypePropagation.propagate_loopbody(node._body) + + +class ExprPrinter(Printer): + @staticmethod + def paren(string): + def all_in_parens(string): + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + if ( + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.I) + or re.match(r"^\([^)]*\)$", string, re.I) + or string == "" + ): + return string + # don't put extra parens for strings that are already wrapped in parens + if all_in_parens(string): + return string + return f"({string})" + + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + def _print_Relational(self, expr): + return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mul(self, expr): + return "*".join(map(self.paren, map(self._print, expr.args))) + + def _print_Add(self, expr): + return " + ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_CleanDiv(self, expr): + return self._print_FloorDiv(expr) + + def _print_GreaterThan(self, expr): + # GreaterThan: >= + # StrictlyGreaterThan: > + # Go figure... 
+ return " >= ".join(map(self.paren, map(self._print, expr.args))) + + def _print_align(self, expr): + assert len(expr.args) == 1 + return f"align({self._print(expr.args[0])})" + + +class PythonPrinter(ExprPrinter): + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + mod = self.paren(self.doprint(mod)) + if div != "1": + x = f"({x} // {div})" + return f"{x} % {mod}" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + def _helper_sqrt(self, expr): + return f"math.sqrt({self._print(expr)})" + + def _print_Pow(self, expr): + # Pow() confuses triton + base, exp = expr.args + # NB: Remember this is sizevar computation! You don't typically + # expect to have to do floating point computation including exponents + # in sizevar compute. Instead of adding support for floating + # point pow, you should make upstream retranslate the Sympy expression + # into Tensor expressions earlier and do that instead. + if exp == 0.5: + return self._helper_sqrt(base) + elif exp == -0.5: + return "1/" + self._helper_sqrt(base) + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + def _print_Max(self, expr): + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_cos(self, expr): + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def _print_cosh(self, expr): + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_acos(self, expr): + assert len(expr.args) == 1 + return f"math.acos({self._print(expr.args[0])})" + + def _print_sin(self, expr): + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_sinh(self, expr): + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_asin(self, expr): + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_tan(self, expr): + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_tanh(self, expr): + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_atan(self, expr): + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class OpOverrides: + def __init__(self, parent): + super().__init__() + self._parent = parent + + def __getattr__(self, item): + return getattr(self._parent, item) + + @staticmethod + def 
identity(value): + # used to trigger cse + return value + + @staticmethod + def constant(value, dtype): + return repr(value) + + @staticmethod + def reciprocal(x): + return ops.truediv("1", x) + + @staticmethod + def square(x): + return ops.mul(x, x) + + @staticmethod + def bitwise_not(x): + return f"~{ExprPrinter.paren(x)}" + + @staticmethod + def logical_not(a): + return f"{ExprPrinter.paren(a)} == 0" + + @staticmethod + def bitwise_and(x, y): + return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_or(x, y): + return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_xor(x, y): + return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_left_shift(x, y): + return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_right_shift(x, y): + return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" + + @staticmethod + def remainder(a, b): + r = ops.mod(a, b) + return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r) + + @staticmethod + def load_seed(name, offset): + return ops.load(name, sympy.Integer(offset)) + + @classmethod + def _initialize_pointwise_overrides(cls, target): + assert target in {"triton", "cpp", "cppvec"}, target + + def pointwise_factory_1(impl): + def func(x): + return impl.format(x=x) + + return func + + def pointwise_factory_2(impl): + def func(x, y): + return impl.format(x=x, y=y) + + return func + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if isinstance(impl, str): + nof_args = 2 if "{y}" in impl else 1 + # extend the following dictionary with factory + # functions for a specific number of arguments as + # needed: + factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args] + setattr(cls, funcname, staticmethod(factory(impl))) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: str + triton: Optional[str] = None # None when not impl in libdevice/triton + cppvec: Optional[str] = None # None when not impl in aten/.../vec + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +pointwise_overrides_data: Dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_j0_forward({x})", + triton="libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_j1_forward({x})", + triton="libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_y0_forward({x})", + triton="libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="bessel_y1_forward({x})", + triton="libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_digamma({x})", + cppvec="{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + 
cpp="calc_erfcx({x})", + triton="libdevice.erfcx({x})", + name="special_erfcx", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i0({x})", + triton="libdevice.cyl_bessel_i0({x})", + cppvec="{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i0e({x})", + cppvec="{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i1({x})", + triton="libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_i0_forward({x})", + triton="libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_i1_forward({x})", + triton="libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="calc_polygamma({y}, {x})", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="scaled_modified_bessel_k0_forward({x})", + name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp="laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +# Use mypy to check protocol implemented correctly +def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: + return h + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" + + def __init__(self, name, line): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self): + if all( + self.name not in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ): + return self.line + return None + + def _new_line(self, line): + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 
1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: List[str] + + +class KernelArgs: + @staticmethod + def _lookup(prefix, odict, name): + assert isinstance(name, (str, sympy.Symbol)) + if name not in odict: + odict[name] = f"{prefix}{len(odict)}" + return odict[name] + + def __init__(self, sizevars=None): + self.input_buffers = dict() + self.output_buffers = dict() + self.inplace_buffers = dict() + self.sizevars = sizevars or dict() + self.workspace_arg = None + + def __repr__(self): + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + def _buffer_is_marked_removed(self, name): + return isinstance(name, str) and name.startswith("REMOVED") + + def input(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name, output_name): + assert output_name not in self.inplace_buffers + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + buf = InplacedBuffer( + f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool): + if self.workspace_arg is None: + self.workspace_arg = WorkspaceArg(nbytes, zero_fill) + return "ws_ptr", 0 + + offset = self.workspace_arg.nbytes + zero_fill = zero_fill or self.workspace_arg.zero_fill + self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) + return "ws_ptr", offset + + def seed_offset(self, name, value): + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name): + if str(name) == "seed": + self.sizevars["seed"] = "seed" + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self): + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def wrap_ptr_arg(self, buf, dtype): + return buf + + def wrap_size_arg(self, size): + return str(size) + + def cpp_argdefs(self): + from .cpp import DTYPE_TO_CPP, INDEX_TYPE + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = 
V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert self.workspace_arg is None, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs(self): + arg_defs = [] + call_args = [] + precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + arg_defs.append(inplaced.inner_name) + call_args.append(inplaced.other_names[-1]) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + arg_defs.append(inner) + call_args.append(outer) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(inner) + call_args.append(outer) + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + if self.workspace_arg is not None: + arg_defs.append("ws_ptr") + call_args.append("workspace") + precompile_args.append(self.workspace_arg) + + return arg_defs, call_args, precompile_args + + def aliases(self): + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + for other in inplaced.other_names: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield self.output_buffers[other], inplaced.inner_name + + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or self._buffer_is_marked_removed(buffers[name]) + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. 
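+    # E.g. (illustrative): if "buf1" was written in place over "buf0" (sharing in_out_ptr0) and
+    # "buf2" is a plain output, this returns {"buf1", "buf2"}: the newest name of each inplace
+    # group plus every non-removed output buffer.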
+ def live_output_buffers(self): + live_outs = set() + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__(self, name, bounds: ValueRanges[Any]): + assert isinstance(bounds, ValueRanges) + self.name = name + self.bounds = bounds + + def __str__(self): + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + + def update_on_args(self, name, args, kwargs): + pass + + +class CppWrapperKernelArgs(KernelArgs): + def wrap_ptr_arg(self, buf, dtype): + from .cpp import DTYPE_TO_CPP + + if config.abi_compatible: + # In the abi_compatible model, we just return the buf here. + # We will form correct call args later in wrapper.generate_kernel_all. + return buf + else: + return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"{size}" + + +class CSE: + """Common subexpression elimination""" + + def __init__( + self, + prefix="", + suffix="", + name_prefix="tmp", + iter_buffers=None, + store_cache=None, + reduction_cache=None, + varname_map=None, + ): + self.prefix = prefix + self.suffix = suffix + self.cache = {} + self.name_prefix = name_prefix + self.store_cache = store_cache or {} + self.reduction_cache = reduction_cache or {} + self.iter_buffer_ids = iter_buffers or itertools.count() + self.invalidated_stores = set() + self.varname_map = varname_map or {} + + def invalidate(self, keep_vars: Set[str]): + for name, tmp in list(self.store_cache.items()): + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} + + def clone(self): + # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional + return CSE( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + ) + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write=True, + assignment=True, + ) -> CSEVariable: + if isinstance(expr, OpsValue): + expr = expr.value + + assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + return expr + cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr + var = self.cache.get(cache_key, None) + if not var: + var = self.newvar(bounds) if assignment else None + self.cache[cache_key] = 
var + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + else: + var.bounds = var.bounds.tighten(bounds) + + return var + + def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds) + self.varname_map[var_name] = var + return var + + +class IndirectAssertLine(DeferredLineBase): + def __init__(self, line, assert_fn, var, mask, size_map): + self.var = var + self.mask = mask + self.line = line + self.assert_fn = assert_fn + self.size_map = size_map + + def __call__(self): + size, size_str = self.size_map[(self.var, self.mask)] + + # We assert if we've not been able to prove the bound + assert_min = (self.var.bounds.lower >= 0) != sympy.true + assert_max = (self.var.bounds.upper < size) != sympy.true + + # FooBar interview question + if not (assert_min or assert_max): + return None + elif assert_min and assert_max: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is suported by triton + cond = f"(0 <= {self.var}) & ({self.var} < {size_str})" + cond_print = f"0 <= {self.var} < {size_str}" + elif assert_min: + cond = f"0 <= {self.var}" + cond_print = cond + else: + assert assert_max + cond = f"{self.var} < {size_str}" + cond_print = cond + + if self.mask: + cond = f"({cond}) | ~{self.mask}" + return self.line.format( + assert_fn=self.assert_fn, cond=cond, cond_print=cond_print + ) + + def _new_line(self, line): + return IndirectAssertLine( + line, self.assert_fn, self.var, self.mask, self.size_map + ) + + +class CodeGen: + def __init__(self): + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self): + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class Kernel(CodeGen): + newvar_prefix = "" + suffix = "" + overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None + # TODO: these look dead, but with all the getattr it's hard to tell... 
+ load_format: None = None + store_format: None = None + + def __init__(self, args=None, increase_kernel_count=True): + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse: CSE = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers = set() + self.store_buffer_names = set() + self._load_mask = None + # set in set_current_node + self.current_node = None + self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None + # Upper bounds for indirect_indexing and their str representation + # NB: None, None is never stored in map, but it is the assumed + # "not set" value for the dict + self.indirect_max_sizes: Dict[ + Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]] + ] = {} + + self.removed_buffers = set() + self.inplaced_to_remove = set() + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers = dict() + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name = None + + @contextlib.contextmanager + def set_current_node(self, node): + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers(self, lb, cb=None, sb=None): + if cb is None: + cb = lb + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = cse.clone() + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError() + + def indirect_load(self, name: str, index: sympy.Expr): + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + raise NotImplementedError() + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError() + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + raise NotImplementedError() + + def scan( + self, + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + raise NotImplementedError() + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError() + + @property + def assert_function(self) -> str: + raise NotImplementedError() + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError() + + def __enter__(self): + # TODO: hoist this to top level + class CSEProxy: + self.name = "CSEProxy" + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] 
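+            # Ops without an explicit handler below are forwarded to the parent ops handler; each
+            # result is then run through the CSE cache (do_cse) and annotated with value-range bounds.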
+ def inner(*args, **kwargs): + # TritonTemplateKernel has no current_node + buf_bounds = ValueRanges.unknown() + if hasattr(V.interpreter, "current_node"): + fx_node = V.interpreter.current_node + assert isinstance(self.node_to_bounds, dict) + buf_bounds = self.node_to_bounds.get( + fx_node, ValueRanges.unknown() + ) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + + def do_cse(v): + csevar = self.cse.generate(self.compute, v, bounds=buf_bounds) + csevar.update_on_args(name, args, kwargs) + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def indirect_indexing( + var: CSEVariable, size: sympy.Expr, check: bool = True + ): + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg = var.bounds & ValueRanges(-sympy.oo, -1) + new_bounds = ValueRanges(neg.lower + size, neg.upper + size) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, sympy.oo) + new_bounds = new_bounds | pos + + stm = ops.add(var, self.rename_indexing(size)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, "0") + stm = ops.where(lt, stm, var) + new_var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + new_var.update_on_args("index_wrap", (var,), {}) + var = new_var + + if self.generate_assert(check): + mask = self.load_mask(var) + + # An assertion line may have been written already, if so just + # update the max size. 
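+                    # (keep the tighter bound: the smaller of the previously recorded size and the new one)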
+ map_key = (var, mask) + existing_size, _ = self.indirect_max_sizes.get( + map_key, (None, None) + ) + if existing_size is not None: + size = sympy.Min(size, existing_size) + else: + line = ( + '{assert_fn}({cond}, "index out of bounds: {cond_print}")' + ) + self.compute.writeline( + IndirectAssertLine( + line, + self.assert_function, + var, + mask, + self.indirect_max_sizes, + ) + ) + + self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) + return sympy_index_symbol(str(var)) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_startswith(index, "tmp"): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return store_cache[name] + return self.load(name, index) + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + else: + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + self.cse.store_cache[name] = value + if self.current_node: + for other_name in self.current_node.get_mutations(): + self.cse.store_cache[other_name] = value + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + return self.scan(dtype, combine_fn, value, init) + + @staticmethod + def bucketize( + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Given values (tensor) and offsets_name (reference to the name of a 1D + tensor), calculate the bucket that each value belongs to. + + e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True + return = [ 0, 1, 1, 1, 1, 3, 3, 4]. + + When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. + When right == True, bucket i refers to range [offsets[i], offsets[i+1]). + + Offsets must be non-decreasing or the result is undefined. 
+ """ + return self.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + assert self.overrides + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Note that V.graph.scheduler can be None when codegening triton template + kernels. + """ + if V.graph.scheduler: + V.graph.scheduler.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def generate_assert(self, check): + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + def load_mask(self, var) -> str: + # only the triton kernel requires mask + return "" + + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] # type: ignore[return-value] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if x.name.startswith(("s", "u", "ps")) + or (x.name.startswith("i") and not x.name.startswith("idx")) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + # Load value as mask + is_load_as_mask: bool = False + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + # Load uint8/int8 value as float32 + is_load_int8_as_float: bool = False + + +@functools.lru_cache(None) +def jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Children classes: TritonTemplateCaller, CUDATemplateCaller. + """ + + def __init__(self, name, input_nodes, layout): + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + + def benchmark(self, *args, out) -> float: + algo = self.to_callable() + return do_bench(lambda: algo(*args, out=out)) + + def call_name(self) -> str: + raise NotImplementedError() + + def to_callable(self): + raise NotImplementedError() + + def hash_key(self) -> str: + raise NotImplementedError() + + def output_node(self) -> "TensorBox": + raise NotImplementedError() + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + +class KernelTemplate: + """ + Base class for defining kernel templates. 
+ + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def _template_from_string(source): + env = jinja2_env() + if env is not None: + return env.from_string(source) + return None + + @staticmethod + def _fake_get_dtype(fake_out): + _get_dtype_real = V.graph.get_dtype + + def get_dtype(name): + if name == fake_out.get_name(): + return fake_out.get_dtype() + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str): + self.name = name + + def maybe_append_choice(self, choices, **kwargs): + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(**kwargs)) + except NotImplementedError: + pass + + def generate(self, **kwargs) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError() diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..1bb2282bc9cf036c8734294c9c24a533f756fb23 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,4038 @@ +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import re +import sys +from copy import copy, deepcopy +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._inductor.ir import StorageBox, TensorBox +from torch._prims_common import is_float_dtype +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .. 
import codecache, config, ir, metrics +from ..codegen.wrapper import WrapperCodeGen +from ..optimize_indexing import range_expressable_in_32_bits +from ..scheduler import ( + BaseScheduling, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_fused_kernel_name, + is_welford_reduction, + parallel_num_threads, + sympy_index_symbol, + sympy_product, + sympy_subs, +) + +from ..virtualized import ops, OpsValue, V +from .common import ( + BracesBuffer, + CppWrapperKernelArgs, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + ExprPrinter, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int64: "long", + torch.int32: "int", + torch.int16: "short", + torch.int8: "signed char", + torch.uint64: "unsigned long", + torch.uint32: "unsigned int", + torch.uint16: "unsigned short", + torch.uint8: "unsigned char", + torch.uint32: "unsigned int", + torch.uint64: "unsigned long", + torch.bool: "bool", + torch.bfloat16: "bfloat16", + torch.complex64: "complex64", + torch.float8_e4m3fn: "float8_e4m3fn", + torch.float8_e5m2: "float8_e5m2", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", +} + +INDEX_TYPE = "long" + +NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", +} + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "c10::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def reduction_init(reduction_type, dtype): + if 
dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in {"max", "argmax"}: + return ( + f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()" + ) + if reduction_type in {"min", "argmin"}: + return ( + f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + assert reduction_type not in {"argmin", "argmax"} + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + + return scalar_type + + +def reduction_combine(reduction_type, var, next_value): + if reduction_type == "sum": + return f"{var} + {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in {"argmin", "argmax"}: + return f"{acc}.index" + return acc + + +def is_to_lowp_dtype(expr): + to_exprs = ["cvt_fp32_to_lowp_fp", "c10::convert"] + if any(to_expr in expr for to_expr in to_exprs): + if "half" in expr: + return torch.half + if "bfloat16" in expr: + return torch.bfloat16 + return None + + +def get_lowp_to_fp32_expr(lowp_var, src_dtype, kernel): + if isinstance(kernel, CppVecKernel): + return f"cvt_lowp_fp_to_fp32<{DTYPE_TO_CPP[src_dtype]}>({lowp_var})" + else: + assert isinstance(kernel, CppKernel) + return f"c10::convert({lowp_var})" + + +index_value_name_counter = 1 + + +def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar): + global index_value_name_counter + struct_name = f"IndexValue_{index_value_name_counter}" + index_value_name_counter += 1 + + # A small annoyance, due to it being a little cumbersome to just throw {} into strings + prefix = [ + f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};", + f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};", + ] + + if reduction_type in ["argmax", "argmin"]: + compare_op = "greater_or_nan" if reduction_type == "argmax" else "less_or_nan" + prefix.extend( + [ + "#if !defined(__clang_major__) || __clang_major__ > 9", + f"#pragma omp declare reduction({reduction_type} : {struct_name} :\\", + f" omp_out = {compare_op}(omp_in.value, omp_out.value, omp_in.index, omp_out.index) ? 
omp_in : omp_out)\\", + f"\tinitializer(omp_priv = {{0, {reduction_init(reduction_type, src_dtype)}}})", + "#endif", + ] + ) + + return prefix + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor") + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus") + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): + index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index_vec_simplified, var) + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr): + return f"{int(expr)}L" + + def _print_Where(self, expr): + c = self.paren(self.doprint(expr.args[0])) + p = self.paren(self.doprint(expr.args[1])) + q = self.paren(self.doprint(expr.args[2])) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + if div != 1: + div = self.paren(self.doprint(div)) + if expr.is_integer: + x = f"c10::div_floor_integer({x}, {div})" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.paren(self.doprint(mod)) + return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + if expr.is_integer: + return f"c10::div_floor_integer({x}, {div})" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Pow(self, expr): + # Uses float constants to perform FP div + base, exp = expr.args + base = self._print(base) + + if exp == 0.5 or exp == -0.5: + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + assert exp.is_integer + exp = int(exp) + if exp > 0: + r = "*".join([self.paren(base)] * exp) + elif exp < 0: + r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Rational(self, expr): + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_cos(self, expr): + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_cosh(self, expr): + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_acos(self, expr): + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_sin(self, expr): + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_sinh(self, expr): + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_asin(self, expr): + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_tan(self, expr): + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_tanh(self, expr): + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_atan(self, expr): + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"std::lrint({self._print(expr.args[0])})" + + def 
_print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext: + return node.meta.get(OptimizationContext.key, None) + + +def get_current_node_opt_ctx() -> OptimizationContext: + assert V.interpreter.current_node + return get_opt_ctx(V.interpreter.current_node) + + +class CppVecUnsupportedError(Exception): + pass + + +class CppCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any]): + super().__init__(name, bounds) + self.is_vec = False + self.dtype: Optional[torch.dtype] = None + self.dependent_itervars: Set[sympy.Symbol] = set() + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[1] is index + self._set_dependent_itervars(args[1]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + # NOTE [dtype of CppCSEVariable] + # Deciding dtype according to the current optimization context is not + # always accurate since the dtypes are initialized during dtype propagation + # at the beginning of the codegen. It is possible that some ops are invoked + # during the codegen of the current op and take different dtypes from the + # current op. + # TODO(jgong5): A more accurate way of deciding the dtype of the variables is to + # propagate the dtypes here inside `update_on_args`. + if ( + hasattr(V.interpreter, "current_node") + and get_current_node_opt_ctx() is not None + ): + self.dtype = get_current_node_opt_ctx().dtype + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. 
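+        E.g. (illustrative): for an index like `x0 + 32*tmp3`, the kernel itervar `x0` is added
+        directly and the itervars that the cse variable `tmp3` depends on are merged in.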
+ """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"decltype({a})({a} + {b})" + + @staticmethod + def sub(a, b): + return f"decltype({a})({a} - {b})" + + @staticmethod + def mul(a, b): + return f"decltype({a})({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + if src_dtype in (torch.float16, torch.bfloat16): + # c10::bit_cast requires the source and target have the bitwidth. + # Because the input tensor's dtype could be promoted, e.g. from float16 to + # float, we have to cast the tensor to its original source dtype before + # invoking bit_cast. We also need to convert the bit-casted tensor + # back to float to make sure we keep using higher precision values + # for the rest of the computation. + cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})" + cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})" + return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})" + else: + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + return f"std::signbit({x})" + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, constants + # must be promoted as well + dtype = torch.float32 + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + return f"decltype({a})({a} << {b})" + + @staticmethod + def bitwise_right_shift(a, b): + return f"decltype({a})({a} >> {b})" + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. 
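# The dispatch rule described above can be modeled in isolation. The sketch
# below is illustrative only and does not use the real Inductor classes:
# `Val`, `broadcast`, and the two add implementations are hypothetical
# stand-ins for CppCSEVariable, CppVecKernel.broadcast, and the scalar/vector
# override tables.
from dataclasses import dataclass

@dataclass
class Val:
    expr: str
    is_vec: bool = False

def broadcast(v: Val) -> Val:
    # model of the scalar-to-vector broadcast: splat the scalar across lanes
    return Val(f"Vectorized({v.expr})", is_vec=True)

def scalar_add(a: Val, b: Val) -> Val:
    return Val(f"decltype({a.expr})({a.expr} + {b.expr})")

def vector_add(a: Val, b: Val) -> Val:
    return Val(f"{a.expr} + {b.expr}", is_vec=True)

def add(a: Val, b: Val) -> Val:
    args = [a, b]
    if any(v.is_vec for v in args):
        # mixed scalar/vector inputs: broadcast the scalars, use the vector op
        args = [v if v.is_vec else broadcast(v) for v in args]
        return vector_add(*args)
    # all-scalar inputs: fall back to the scalar override
    return scalar_add(a, b)

print(add(Val("tmp0", is_vec=True), Val("3")).expr)  # tmp0 + Vectorized(3)
print(add(Val("3"), Val("4")).expr)                  # decltype(3)(3 + 4)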
+ # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. + def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and not arg.is_vec + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + # broadcast scalar args to vector if needed + new_args = [] + vec_dtype = vectors[0].dtype + for arg in args: + if isinstance(arg, CppCSEVariable) and not arg.is_vec: + assert isinstance(V.kernel, CppVecKernel) + # align scalar data type to the vector for binary ops + if len(args) == 2 and arg.dtype != vec_dtype: + arg = ops.to_dtype(arg, vec_dtype) + arg = arg.value if isinstance(arg, OpsValue) else arg + # See NOTE [dtype of CppCSEVariable]: we have to fix arg.dtype since + # the dtype from optimization context could be wrong. + assert isinstance(arg, CppCSEVariable) + arg.dtype = vec_dtype + new_arg = V.kernel.broadcast(arg) + new_args.append(new_arg) + else: + new_args.append(arg) + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + return f"to_float_mask({x} == {y})" + + @staticmethod + def ne(x, y): + return f"to_float_mask({x} != {y})" + + @staticmethod + def lt(x, y): + return f"to_float_mask({x} < {y})" + + @staticmethod + def gt(x, y): + return f"to_float_mask({x} > {y})" + + @staticmethod + def le(x, y): + return f"to_float_mask({x} <= {y})" + + @staticmethod + def ge(x, y): + return f"to_float_mask({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod 
+ def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + return f"({a} != 0) & ({b} != 0)" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"({a} != 0) | ({b} != 0)" + + @staticmethod + def logical_xor(a, b): + return f"({a} != 0) ^ ({b} != 0)" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def nextafter(x): + return f"{x}.nextafter()" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + # For real x, asinh(x) = log(x + sqrt(1 + x**2)) + vec_one = f"decltype({x})(1)" + return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" 
+ elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + _t = f"decltype({a})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(b, CppCSEVariable) + if b.dtype != torch.float: + raise CppVecUnsupportedError( + "where with non-float tensor is not supported in vectorized codegen" + ) + return f"decltype({b})::blendv({c}, {b}, {a})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + assert dtype in [ + torch.bool, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ], f"{__name__} does not support {dtype}" + node: torch.fx.Node = V.interpreter.current_node + assert node and isinstance(node, torch.fx.Node) + opt_ctx_x = get_opt_ctx(node.args[1]) + assert opt_ctx_x + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype == torch.bool: + return f"vec_convert_to_mask({x})" + if opt_ctx_x.dtype == torch.bool and dtype in (torch.float, torch.float32): + return f"mask_convert_to_float({x})" + if opt_ctx_x.dtype == torch.bool and dtype in DTYPE_LOWP_FP: + return f"mask_convert_to_lowp<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype == torch.bool and dtype == torch.int64: + return f"mask_convert_to_int64({x})" + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype in DTYPE_LOWP_FP: + return f"cvt_fp32_to_lowp_fp<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype in DTYPE_LOWP_FP and dtype in (torch.float, torch.float32): + return f"cvt_lowp_fp_to_fp32<{DTYPE_TO_CPP[opt_ctx_x.dtype]}>({x})" + if opt_ctx_x.dtype in (torch.uint8, torch.int8) and dtype in ( + torch.float, + torch.float32, + ): + # Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size() + return f"at::vec::convert_int8_to_float({x})" + if opt_ctx_x.dtype in (torch.float, torch.float32) and dtype in ( + torch.uint8, + torch.int8, + ): + # if we already handle the saturation previously. + # * Pattern match of quantization op in the loop body. + # * Skip the explicit saturation and clamp inside at::vec::convert_float_to_int8. 
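# For reference, the saturation that the comment above assumes has already been
# handled by the matched quantization pattern corresponds roughly to the
# following eager-mode computation (an illustrative sketch; the scale and
# zero_point values are made up, and this is not the Inductor lowering itself):
import torch

def quant_to_uint8(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    q = torch.round(x / scale) + zero_point
    # clamp to the uint8 range before the narrowing conversion
    return q.clamp(0, 255).to(torch.uint8)

print(quant_to_uint8(torch.tensor([-1.0, 0.0, 100.0]), scale=0.1, zero_point=128))
# tensor([118, 128, 255], dtype=torch.uint8)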
+ return f"at::vec::convert_float_to_int8<{DTYPE_TO_CPP[dtype]}>({x})" + if opt_ctx_x.dtype == torch.int32 and dtype == torch.float: + return f"at::vec::convert_to_fp_of_same_size({x})" + if opt_ctx_x.dtype == torch.float and dtype == torch.int32: + return f"at::vec::convert_to_int_of_same_size({x})" + if opt_ctx_x.dtype == torch.int64 and dtype == torch.float: + return f"cvt_int64_to_fp32({x})" + if opt_ctx_x.dtype == torch.float and dtype == torch.int64: + return f"cvt_fp32_to_int64({x})" + if opt_ctx_x.dtype == torch.int32 and dtype == torch.int64: + return f"cvt_int32_to_int64({x})" + if opt_ctx_x.dtype == torch.int64 and dtype == torch.int32: + return f"cvt_int64_to_int32({x})" + # TODO(jgong5): support conversion for other types + # currently we only allow load/store torch.uint8 and handle conversion there + return f"({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + body_code = f"{var}()" + body_code_vec = ( + body_code + if result.is_vec + else f"{V.kernel._get_vec_type(torch.float)}({body_code})" + ) + other_code = value_to_cpp(other, "float") + other_code_vec = f"{V.kernel._get_vec_type(torch.float)}({other_code})" + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec or result.is_vec: + if result.dtype != torch.float: + raise CppVecUnsupportedError( + "masked with non-float tensor is not supported in vectorized codegen" + ) + type = f"decltype({body_code_vec})" + float_mask = f"to_float_mask({new_mask})" + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if (all_zero({float_mask}))") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + code.writeline( + f"return {type}::blendv({other_code_vec}, {body_code_vec}, {float_mask});" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
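# The lambda-plus-blendv pattern emitted above has the same semantics as this
# tensor-level model (a minimal sketch, not the generated C++): short-circuit
# to `other` when every mask lane is zero, otherwise blend per lane.
import torch

def masked_model(mask: torch.Tensor, body_value: torch.Tensor, other: float) -> torch.Tensor:
    if not mask.any():  # models the all_zero(float_mask) fast path
        return torch.full_like(body_value, other)
    # models blendv(other_vec, body_vec, float_mask)
    return torch.where(mask, body_value, torch.full_like(body_value, other))

mask = torch.tensor([True, False, True, False])
body_value = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(masked_model(mask, body_value, -1.0))  # tensor([ 1., -1.,  3., -1.])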
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx and opt_ctx.dtype is not None + dtype = opt_ctx.dtype + assert dtype == torch.int32 + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = stride_at_vec_range(index, tiling_var, V.kernel.tiling_factor) + if stride.is_number and not V.kernel.index_indirect_depends_on( + index, tiling_var + ): + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + value = ops.to_dtype(cexpr(index), dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel.load_non_contiguous(None, index, dtype, V.kernel.compute) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None + self.ranges: List[sympy.Expr] = [] + self.itervars: List[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + self.reduction_suffix = IndentedBuffer() + self.reduction_var_map = {} + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: Dict[Tuple[str, str], str] = {} + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def cache_fp32_cse_var_before_lowp_store(self, var_to_store): + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... 
+ node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + + if var_to_store.dtype not in DTYPE_LOWP_FP: + # only need to cache fp32 cse var while var_to_store is lowp data + return + + def find_fp32_var(var, cache): + fp32_cse_var = None + fp32_cse_var_name = None + lowp_dtype = None + for expr, cse_var in cache.items(): + if cse_var == var: + lowp_dtype = is_to_lowp_dtype(expr) + if lowp_dtype: + m = re.search(r"tmp\d+", expr) + assert m + fp32_cse_var_name = m.group() + if fp32_cse_var_name: + for cse_var in cache.values(): + if cse_var.name == fp32_cse_var_name: + fp32_cse_var = cse_var + break + assert fp32_cse_var is not None + return fp32_cse_var, lowp_dtype + + fp32_var, lowp_dtype = find_fp32_var(var_to_store, self.cse.cache) + if fp32_var: + self.cse.cache[ + get_lowp_to_fp32_expr(var_to_store, lowp_dtype, self) + ] = fp32_var + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + if V.graph.get_dtype(name) in [torch.float16]: + line = f"static_cast({line})" + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + self.cache_fp32_cse_var_before_lowp_store(value) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.reduction_var_map[acc] = reduction_type + if argmax_or_argmin: + self.reduction_prefix.writelines( + argmax_argmin_prefix(reduction_type, src_dtype, acc) + ) + compare_op = ( + "greater_or_nan" if reduction_type == "argmax" else "less_or_nan" + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writelines( + [ + f"if(!({compare_op}({acc}.value, {value}, {acc}.index, {cexpr_index(index)}))) {{", + f" {acc}.index = {cexpr_index(index)}; {acc}.value = {value};", + "}", + ], + ) + else: + acc_type = reduction_acc_type(reduction_type, dtype) + + if (reduction_type, acc_type) not in self.reduction_omp_dec: + if RTYPE_TO_CPP[reduction_type] not in NATIVE_OMP_RTYPES: + # Scalar reduction for other reductions are declared by default + self.reduction_prefix.splice( + f"""\ + #pragma omp declare reduction(\ + {RTYPE_TO_CPP[reduction_type]}:{acc_type}:\ + omp_out = {reduction_combine(reduction_type, "omp_out", "omp_in")}) \ + initializer(omp_priv={{{reduction_init(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type] = RTYPE_TO_CPP[ + reduction_type + ] + + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};" + ) + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value)};" + ) + + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol(f"x{n}") for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + threads = parallel_num_threads() + assert self.call_ranges is not None + par_depth = self.decide_parallel_depth( + self.call_ranges[: loop_nest.max_parallel_depth()], threads + ) + with contextlib.ExitStack() as stack: + if par_depth: + if loop_nest.is_reduction_only(): + # need to close 
the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_kernel(kernel): + with contextlib.ExitStack() as stack: + assert kernel + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.preloads) + kernel.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(kernel.loads) + code.splice(kernel.compute) + code.splice(kernel.stores) + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.poststores) + + def get_reduction_code_buffer(loops, is_suffix=True): + for loop in loops: + for kernel in loop.get_kernels(): + if is_suffix: + return kernel.reduction_suffix + else: + return kernel.reduction_prefix + return None + + def gen_loops(loops: List[LoopLevel], in_reduction=False): + with contextlib.ExitStack() as stack_outer: + if loops: + loop = loops[0] + if loop.is_reduction() and not in_reduction: + reduction_prefix = get_reduction_code_buffer( + loops, is_suffix=False + ) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if loop_nest.is_reduction_only() and loop.parallel: + worksharing.parallel(threads) + + for loop in loops: + gen_loop(loop, in_reduction) + + if loops: + loop = loops[0] + if loop_nest.is_reduction_only() and loop.parallel: + worksharing.close() + if loop.is_reduction() and not in_reduction: + code.splice( + get_reduction_code_buffer(loops, is_suffix=True) + ) + + def gen_loop(loop: LoopLevel, in_reduction=False): + with contextlib.ExitStack() as stack: + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + # generate inner loops or loop body + if loop.inner: + gen_loops(loop.inner, loop.is_reduction()) + else: + kernels = loop.get_kernels() + assert len(kernels) == 1 + gen_kernel(kernels[0]) + + stack.enter_context(code.indent()) + if loop_nest.root: + gen_loops(loop_nest.root) + else: + gen_kernel(loop_nest.kernel) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNestWithSplit.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, ranges, threads): + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. 
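# The parallel-depth heuristic above can be modeled standalone. In this sketch
# the thread count, loop-range hints, and min_chunk_size are plain parameters
# with illustrative values rather than the real config/sizevars lookups:
def model_parallel_depth(range_hints, threads, min_chunk_size=4096):
    seq = 1
    for hint in range_hints:
        seq *= hint
    par, depth = 1, 0
    for hint in range_hints:
        if par >= 2 * threads or par == threads:
            break  # enough parallelism gathered from the outer loops
        if seq // threads < min_chunk_size:
            break  # per-thread work would be too small to amortize
        depth += 1
        par *= hint
        seq /= hint
    return depth

print(model_parallel_depth([64, 128, 512], threads=16))  # 1: only the outer loop is parallel
print(model_parallel_depth([4, 8], threads=16))          # 0: not enough work per thread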
+ if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return depth + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor=0, + tiling_idx=-1, + tiling_dtype=torch.float, + ): + super().__init__(args, num_threads) + self.vec_isa = codecache.pick_vec_isa() + assert self.vec_isa + if tiling_factor == 0: + tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype) + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx is not None + load_mask_str = f"to_float_mask({load_mask})" if load_mask else None + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype in (torch.uint8, torch.int8) and opt_ctx.is_load_int8_as_float: + assert self._get_num_vectors(torch.uint8) == 1 + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>::loadu_one_fourth({loadbuf})" + ) + elif opt_ctx.is_load_as_mask: + line = f"flag_to_float_vec({loadbuf})" + elif dtype in DTYPE_LOWP_FP: + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tiling_factor})" + ) + else: + line = ( + f"masked_load({loadbuf}, {load_mask_str})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf})" + ) + return line + + def load_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + ) -> CppCSEVariable: + """ + Load a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. 
`transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :return: a CppCSEVariable that represents the loaded vector. + """ + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data());" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + assert opt_ctx is not None + is_mask = opt_ctx.is_load_as_mask + code = BracesBuffer() + code.writeline("[&]") + with self.swap_buffers(code), code.indent(): + result_type = "float" if is_mask else f"{DTYPE_TO_CPP[dtype]}" + result_size = get_result_size(dtype) + result_declare = ( + f"__at_align__ std::array<{result_type}, {result_size}> tmpbuf;" + ) + code.writeline(result_declare) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if s.name.startswith("tmp") # type: ignore[attr-defined] + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + load_mask = None + if self._load_mask is not None: + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = ( + f"vector_lane_mask_check({self._load_mask}, {itervar_inner})" + ) + else: + load_mask = f"{self._load_mask} != 0" + index = sympy_subs(index, replacements) # type: ignore[arg-type] + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + if codecache.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; {itervar_inner} < {self.tiling_factor}; {itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + rhs = ( + f"{var}[{cexpr_index(index)}]" + if var is not None + else f"{cexpr_index(index)}" + ) + if is_mask: + rhs = f"flag_to_float_scalar({rhs})" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + load_line = self._get_vec_load_line("tmpbuf.data()", 0, 
dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = stride_at_vec_range(index, tiling_var, self.tiling_factor) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + non_contiguous = stride != 1 or self.index_indirect_depends_on( + index, tiling_var + ) + if non_contiguous: + csevar = self.load_non_contiguous(var, index, dtype) + else: + line = self._get_vec_load_line(var, index, dtype, self._load_mask) + csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (name, index), {}) + csevar.is_vec = True + return csevar + + def _get_vec_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + ): + """ + Get a store line str that stores `value` into `var` at `index` of `dtype`. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. + """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + assert index.has(tiling_var), f"index: {index}, tiling_var: {tiling_var}" + var_expr = f"{var} + {cexpr_index(index)}" + stride = stride_at_vec_range(index, tiling_var, self.tiling_factor) + non_contiguous = stride != 1 or self.index_indirect_depends_on( + index, tiling_var + ) + if non_contiguous: + var_expr = "tmpbuf" + if dtype == torch.float: + line = f"{value}.store({var_expr});" + else: + line = f"{value}.store({var_expr}, {self.tiling_factor});" + if non_contiguous: + inner = sympy_index_symbol(f"{tiling_var}_inner") + new_index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=inner + ) + tmp_bufsize = ( + f"{self.tiling_factor}*sizeof(float)/sizeof({DTYPE_TO_CPP[dtype]})" + ) + line = ( + f"{{ __at_align__ {DTYPE_TO_CPP[dtype]} tmpbuf[{tmp_bufsize}]; {line} " + f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++) " + f"{var}[{cexpr_index(new_index)}] = tmpbuf[{inner}]; }}" + ) + return line + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert mode is None + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.output(name) + self.cache_fp32_cse_var_before_lowp_store(value) + index = self.rename_indexing(index) + self.stores.writeline( + DeferredLine( + name, + self._get_vec_store_line(value, var, index, V.graph.get_dtype(name)), + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + assert reduction_type in { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + } + assert dtype == src_dtype + assert dtype in [torch.float, torch.int64] + assert isinstance(value, CppCSEVariable), 
value + + if not value.is_vec: + value = self.broadcast(value) + + acc_type = reduction_acc_type(reduction_type, dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, dtype) + + if (reduction_type, acc_type) not in self.reduction_omp_dec: + if RTYPE_TO_CPP[reduction_type] not in NATIVE_OMP_RTYPES: + # Scalar reduction for other reductions are declared by default + self.reduction_prefix.splice( + f"""\ +#pragma omp declare reduction(\ +{RTYPE_TO_CPP[reduction_type]}:{acc_type}:\ +omp_out = {reduction_combine(reduction_type, "omp_out", "omp_in")}) \ +initializer(omp_priv={{{reduction_init(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type] = RTYPE_TO_CPP[ + reduction_type + ] + + if (reduction_type, acc_type_vec) not in self.reduction_omp_dec: + self.reduction_prefix.splice( + f"""\ +#pragma omp declare reduction(\ +{RTYPE_TO_CPP[reduction_type]}:{acc_type_vec}:\ +omp_out = {self.reduction_combine_vec(reduction_type, "omp_out", "omp_in")}) \ +initializer(omp_priv={{{self.reduction_init_vec(reduction_type, dtype)}}}) + """ + ) + self.reduction_omp_dec[reduction_type, acc_type_vec] = RTYPE_TO_CPP[ + reduction_type + ] + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + acc_vec = f"{acc}_vec" + + self.reduction_var_map[acc_vec] = reduction_type + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};" + ) + self.reduction_prefix.writeline( + f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" + ) + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};" + ) + + tmpvar: Union[str, CSEVariable] + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert ( + self._get_num_vectors(dtype) == 1 + ), "Welford reduction does not support VectorizedN (N>1)" + next_value = f"welford_vec_reduce_all({acc_vec})" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + # Only float reductions are vectorized currently + dtype = torch.float + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + self.reduction_suffix.writeline( + DeferredLine( + name, + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});", + ) + ) + else: + # Vertical reduction + store_lines = [] + if out_dtype != dtype: + if out_dtype in DTYPE_LOWP_FP and dtype == torch.float: + _lowp_fp_tmpvar_vec = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + store_lines = [ + DeferredLine( + name, + f"auto {_lowp_fp_tmpvar_vec} = 
cvt_fp32_to_lowp_fp<{DTYPE_TO_CPP[out_dtype]}>({value});", + ) + ] + value = _lowp_fp_tmpvar_vec + else: + raise AssertionError( + f"Unsupported reduction type from {dtype} to {out_dtype}" + ) + store_lines += [ + DeferredLine( + name, + self._get_vec_store_line(value, var, index, out_dtype), + ) + ] + self.reduction_suffix.writelines(store_lines) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"to_float_mask({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange( + self, index: Union[sympy.Expr, CppCSEVariable], stride: sympy.Symbol + ) -> CppCSEVariable: + if isinstance(index, sympy.Expr): + index = cexpr(index) + else: + assert isinstance(index, CppCSEVariable) + assert not index.is_vec + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(torch.int32)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = torch.int32 + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + scalar_init = reduction_init(reduction_type, dtype) + return f"{vec_type}({scalar_init})" + + def reduction_acc_type_vec(self, reduction_type, dtype): + assert reduction_type not in {"argmin", "argmax"} + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + + return vec_type + + def reduction_combine_vec(self, reduction_type, var, next_value): + if reduction_type == "max": + return f"at::vec::maximum({var}, {next_value})" + elif reduction_type == "min": + return f"at::vec::minimum({var}, {next_value})" + elif reduction_type == "sum": + return f"{var} + {next_value}" + elif reduction_type == "prod": + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + else: + raise NotImplementedError() + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. 
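# The tile-and-transpose scheme described here can be modeled at tensor level
# (a minimal sketch, not the generated C++): preload a TILE x TILE block whose
# contiguous axis is the outer loop variable, transpose it so the inner loop
# walks contiguous memory, run the elementwise compute, then store through a
# transposed poststore.
import torch

TILE = 16
x = torch.arange(64 * 64, dtype=torch.float32).reshape(64, 64)
out = torch.empty(64, 64)

for i in range(0, 64, TILE):
    for j in range(0, 64, TILE):
        # preload: transpose the tile so the inner-most loop is contiguous
        tile = x[i : i + TILE, j : j + TILE].t().contiguous()
        tile = tile * 2.0  # stand-in for the vectorized inner-loop compute
        # poststore: write the tile back through the transposed layout
        out[j : j + TILE, i : i + TILE] = tile

assert torch.equal(out, (x * 2.0).t())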
The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. + + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... + """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__(self, args, num_threads, tiling_factor, tiling_indices, tiling_dtype): + super().__init__( + args, num_threads, tiling_factor, tiling_indices[1], tiling_dtype + ) + self.tiling_indices = tiling_indices + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store(self, name, var, index, is_store): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{factor}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + load_or_store = f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{factor},{factor}>({src}, {ld_src}, {dst}, {ld_dst});" + if is_store: + tile_var = self.cse.newvar() + elif load_or_store not in self.cse.cache: + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.cache[load_or_store] + + if need_define: + define_line = f"{DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}] __attribute__ ((aligned ({factor})));" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (name, index), {}) + assert 
isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + opt_ctx: OptimizationContext = get_current_node_opt_ctx() + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + assert mode is None + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}" + if V.graph.get_dtype(name) in DTYPE_LOWP_FP: + line = f"{value}.store({storebuf}, {self.tiling_factor});" + elif V.graph.get_dtype(name) in (torch.uint8, torch.int8): + line = f"{value}.store({storebuf}, {self.tiling_factor});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + code.writeline( + f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +class CppVecKernelChecker(CppVecKernel): + def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): + super().__init__(args, num_threads, tiling_factor, tiling_idx) + + # Since this kernel is only for checker but does not generate any + # code, so we need to decrease the kernel count. + metrics.generated_kernel_count -= 1 + + # Used to record the graph wrapper code as the wrapper_code status could be + # changed during graph run. + self._orig_wrapper_code = None + + self.simd_vec = True + + self.fast_vec_list = [] + for k, v in CppVecOverrides.__dict__.items(): + if isinstance(v, staticmethod): + self.fast_vec_list.append(k) + self.exit_stack = contextlib.ExitStack() + + # Cache all the load result + self.load_supported_dtypes: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ] + self.store_supported_dtypes: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ] + # Cache the dtypes of the store operation. 
If the store is mixing dtypes, the + # vectorization would not support it as it is hard to determine the vec dtype + self.store_dtypes: List[torch.dtype] = [] + # The dtype is used for vectorization + self.vec_dtype: torch.dtype = torch.float32 + + def disable_vec(self, msg=None): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Disabled vectorization: %s", msg) + self.simd_vec = False + + def is_mask(self, name: str, users: Dict[torch.fx.Node, None]): + load_type = V.graph.get_dtype(name) + if load_type == torch.bool: + return all(user.target in ("where", "masked") for user in users.keys()) + elif load_type in (torch.uint8, torch.int8): + """ + If the load value is torch.uint8/int8, then we only support the loaded + value is as the mask. + """ + if not all( + user.target == "to_dtype" and user.args[-1] == torch.bool + for user in users.keys() + ): + return False + + for to_dtype_node in users.keys(): + assert to_dtype_node.target == "to_dtype" + if not all( + user.target in ("where", "masked") + for user in to_dtype_node.users.keys() + ): + return False + return True + else: + return False + + def is_load_int8_as_float(self, name: str, users: Dict[torch.fx.Node, None]): + """ + Check: + 1. load_type is torch.uint8 or torch.int8 + 2. has 1 user node of target to_dtype + 3. dtype of to_dtype is torch.float + """ + load_type = V.graph.get_dtype(name) + if load_type not in (torch.uint8, torch.int8): + return False + if len(users) == 1: + user = next(iter(users)) + if (user.target == "to_dtype") and (user.args[-1] == torch.float): + return True + return False + return False + + def can_store_fp32_as_int8(self, store_var: str, value_node: torch.fx.Node): + """ + Check: + 1. store_type is torch.uint8/torch.int8 + 2. value_node is of target to_dtype + 3. 
dtype of to_dtype node is torch.uint8/torch.int8 + """ + store_type = V.graph.get_dtype(store_var) + if store_type not in (torch.uint8, torch.int8): + return False + if value_node.target == "to_dtype" and value_node.args[-1] in ( + torch.uint8, + torch.int8, + ): + return True + + return False + + def is_load_integer_scalar_tensor(self, name: str, index: sympy.Expr): + load_dtype = V.graph.get_dtype(name) + buffer = V.graph.get_buffer(name) + return ( + load_dtype in [torch.int32, torch.int64] + and isinstance(buffer, TensorBox) + and isinstance(buffer.data, StorageBox) + and (len(buffer.data.layout.size) == 0) + and (index == 0) + ) + + def load(self, name: str, index: sympy.Expr): + with RecordOptimizationContext(__name__) as node_ctx: + load_dtype = V.graph.get_dtype(name) + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = load_dtype + opt_ctx.is_load_as_mask = self.is_mask(name, node_ctx.get_fx_node().users) + opt_ctx.is_load_int8_as_float = self.is_load_int8_as_float( + name, node_ctx.get_fx_node().users + ) + + var = self.cse.newvar() + + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return var + + if load_dtype in (torch.bool, torch.uint8, torch.int8) and not ( + opt_ctx.is_load_as_mask or opt_ctx.is_load_int8_as_float + ): + if not opt_ctx.is_load_as_mask: + self.disable_vec(f"{load_dtype} not loaded as mask") + elif not opt_ctx.is_load_int8_as_float: + self.disable_vec(f"{load_dtype} not loaded as float") + return var + + if ( + (load_dtype not in self.load_supported_dtypes) + and not self.is_load_integer_scalar_tensor(name, index) + and index.has(self.itervars[self.tiling_idx]) + ): + self.disable_vec(f"{load_dtype} not supported by load") + return var + + return var + + def store(self, name, index, value, mode=None): + with RecordOptimizationContext(__name__) as node_ctx: + if len(self.itervars) == 0: + self.disable_vec("not a loop") + return self.simd_vec + + store_dtype = V.graph.get_dtype(name) + + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = store_dtype + + store_dtype = torch.float if store_dtype == torch.float32 else store_dtype + self.store_dtypes.append(store_dtype) + if store_dtype not in self.store_supported_dtypes: + self.disable_vec(f"{store_dtype} not supported by store") + return self.simd_vec + + if store_dtype in (torch.uint8, torch.int8): + value_node = node_ctx.get_fx_node().all_input_nodes[-1] + if not self.can_store_fp32_as_int8(name, value_node): + self.disable_vec("not support store float32 as uint8/int8") + return self.simd_vec + + assert "buf" in name + index = self.rename_indexing(index) + + if mode: + self.disable_vec(f"store mode: {mode}") + return self.simd_vec + + if index.is_number: + self.disable_vec(f"constant store index: {index}") + return self.simd_vec + + def reduction(self, dtype, src_dtype, reduction_type, value): + if ( + (dtype == torch.float and src_dtype == torch.float) + or (dtype == torch.int64 and src_dtype == torch.int64) + and reduction_type in VECTORIZABLE_RTYPES + ): + pass + else: + self.disable_vec( + f"reduction: dtype {dtype}, src_dtype {src_dtype}, reduction_type {reduction_type}" + ) + if is_welford_reduction(reduction_type): + return tuple([self.simd_vec] * 3) + return self.simd_vec + + def store_reduction(self, name, index, value): + return self.simd_vec + + def is_supported_cmp(self, node: torch.fx.Node): + def get_node_dtype(node): + if type(node) == torch.fx.Node: + opt_ctx: OptimizationContext = 
get_current_node_opt_ctx() + return opt_ctx.dtype if opt_ctx else None + else: + return None + + def get_cmp_dtypes(node: torch.fx.Node): + return get_node_dtype(node.args[-2]), get_node_dtype(node.args[-1]) + + assert len(node.args) >= 2 + # cmp(x, y): y is a magic value like x >= 1 + if type(node.args[-1]) in [int, float]: + return True + # cmp(x, y): x is a magic value like 1 >= y + if type(node.args[-2]) in [int, float]: + return False + + left_dtype, right_dtype = get_cmp_dtypes(node) + if left_dtype is None or right_dtype is None: + # TODO(Eikan): To record, deduce and propagate the data type of every expression. + return True + else: + return left_dtype == right_dtype + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None + # Restore the wrapper_code + V.graph.wrapper_code = self._orig_wrapper_code + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def __enter__(self): + # Record the graph wrapper code. The wrapper_code status could be + # changed during graph run. Regarding this checker, we also need to + # run the graph but we don't expect to change any status that would + # impact the code generation. Hence, we record the graph wrapper code + # and replace it with a dummy wrapper_code and then restore to the + # original one as long as the checker is finished. + self._orig_wrapper_code = V.graph.wrapper_code + V.graph.wrapper_code = WrapperCodeGen() + + parent_handler = V.MockHandler() + + class VecCheckerProxy: + bin_cmp_ops = ["eq", "ne", "le", "ge", "lt", "gt"] + + @staticmethod + def _bin_cmp_op(x, y): + current_node: torch.fx.Node = V.interpreter.current_node + if not self.is_supported_cmp(current_node): + self.disable_vec(f"binary comparison op: {current_node}") + return self.simd_vec + + @staticmethod + def __getattr__(name): # type: ignore[misc] + def inner(*args, **kwargs): + if name in VecCheckerProxy.bin_cmp_ops: + return VecCheckerProxy._bin_cmp_op(args, kwargs) + + if name not in self.fast_vec_list: + self.disable_vec(f"op: {name}") + + parent_val = getattr(parent_handler, name)(*args, **kwargs) + return pytree.tree_map(lambda _: self.simd_vec, parent_val) + + return inner + + @staticmethod + def load(name: str, index: sympy.Expr): + return self.load(name, index) + + @staticmethod + def store(name, index, value, mode=None): + return self.store(name, index, value, mode=mode) + + @staticmethod + def reduction(dtype, src_dtype, reduction_type, value): + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def store_reduction(name, index, value): + return self.store_reduction(name, index, value) + + @staticmethod + def constant(val, dtype): + with RecordOptimizationContext(__name__) as node_ctx: + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + # VecKernel override dtype for constant + # Vectorization only support int32/fp32 now + # So if dtype = int64/fp64, we will cast it to int32/fp32 if possible + i32_iinfo = torch.iinfo(torch.int32) + if ( + dtype == torch.int64 + and val <= i32_iinfo.max + and val >= i32_iinfo.min + ): + opt_ctx.dtype = torch.int32 + + f32_iinfo = torch.finfo(torch.float32) + if dtype == torch.double: + if ( + (val <= f32_iinfo.max and val >= f32_iinfo.min) + or (val == torch.inf) + or (val == -torch.inf) + ): + opt_ctx.dtype = torch.float32 + + supported_dtypes = [ + torch.float32, + torch.int32, + torch.int64, + torch.bfloat16, + torch.float16, + torch.bool, + ] + + if opt_ctx.dtype not in supported_dtypes or ( + opt_ctx.dtype == torch.int32 + 
and not all( + user.target in VecCheckerProxy.bin_cmp_ops + for user in node_ctx.current_node.users + ) + ): + self.disable_vec(f"constant dtype: {opt_ctx.dtype}") + return val + + @staticmethod + def index_expr(expr, dtype): + assert len(self.ranges) == len(self.itervars) + if not len(self.ranges) or not all( + not isinstance(range, sympy.Expr) or sympy.simplify(range).is_number + for range in self.ranges + ): + # if the range value is sympy.Expr, we might could not deduce the accurate loop interval. + self.disable_vec(f"index_expr: {expr}, dtype {dtype}") + return self.cse.newvar() + + def can_use_int32(): + free_symbols = list(expr.free_symbols) + sizes = { + k: v + for k, v in zip(self.itervars, self.ranges) + if k in free_symbols + } + # Trivial case: Range empty + if any(v == 0 for v in sizes.values()): + return True + + vars_ranges = {k: ValueRanges(0, v - 1) for k, v in sizes.items()} + if not vars_ranges or len(vars_ranges) != len(free_symbols): + i32_iinfo = torch.iinfo(torch.int32) + return ( + expr.is_number + and expr <= i32_iinfo.max + and expr >= i32_iinfo.min + ) + expr_ranges = bound_sympy(expr, vars_ranges) + if math.isinf(expr_ranges.lower) or math.isinf(expr_ranges.upper): # type: ignore[arg-type] + return False + # If something takes the values 0..7, we will compare in the loop + # x < 8. As such, for the loop not to overflow in the last iteration, we want + # to check that expr_ranges.upper + 1 is representable as well + return range_expressable_in_32_bits( + ValueRanges( + int(expr_ranges.lower), int(expr_ranges.upper) + 1 # type: ignore[arg-type] + ) + ) + + with RecordOptimizationContext(__name__) as node_ctx: + assert len(self.ranges) == len(self.itervars) + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + if ( + dtype == torch.int64 + and can_use_int32() + and all( + user.target in VecCheckerProxy.bin_cmp_ops + for user in node_ctx.current_node.users + ) + ): + opt_ctx.dtype = torch.int32 + else: + opt_ctx.dtype = dtype + self.disable_vec(f"index_expr: {expr}, dtype {dtype}") + + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def indirect_indexing(index_var, size, check=True): + return sympy_index_symbol(str(index_var)) + + @staticmethod + def masked(mask, body, other): + body() + return self.cse.newvar() + + @staticmethod + def to_dtype(x, dtype, src_dtype=None): + with RecordOptimizationContext(__name__) as node_ctx: + opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() + assert opt_ctx + opt_ctx.dtype = dtype + + cur_node = node_ctx.get_fx_node() + input_value: torch.fx.Node = cur_node.all_input_nodes[1] + if dtype == torch.float: + if input_value.target in [ + "load", + ]: + # Support masked_load for BF16/FP16. Because the legalization will + # insert to_dtype to convert the BF16/FP16 input to FP32. 
+ dtype = ( + V.graph.get_dtype(input_value.args[1]) # type: ignore[arg-type] + if input_value.target == "load" + else input_value.args[-1] + ) + if dtype in [ + torch.float16, + torch.bfloat16, + torch.float, + torch.float64, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ]: + # Convert from dtype to torch.float + pass + else: + self.disable_vec(f"to_dtype: dtype {dtype}") + elif dtype in DTYPE_LOWP_FP: + if not all(usr.target == "store" for usr in cur_node.users): + self.disable_vec( + "to_dtype: bfloat16/float16 expecting users are all stores" + ) + return x + + store_names = [usr.args[1] for usr in cur_node.users] + if not all( + V.graph.get_dtype(name) in [dtype] for name in store_names + ): + self.disable_vec( + "to_dtype: expecting all stores into bfloat16 or float16" + ) + return x + elif dtype == torch.bool: + pass + elif dtype in (torch.uint8, torch.int8): + # Only allow below 2 cases: + # Case 1: to_int8 and store which corresponding to the single quant node + # at last of fusion pattern. + is_to_int8_and_store = all( + usr.target in ["store"] for usr in cur_node.users + ) + # Case 2: to_int8 and to_float which corresponding to pair of quant/dequant node + # at middle of fusion pattern. + is_to_int8_and_to_float = all( + ( + usr.target in ["to_dtype"] + and usr.args[2] == torch.float32 + ) + for usr in cur_node.users + ) + if not (is_to_int8_and_store or is_to_int8_and_to_float): + self.disable_vec(f"to_dtype: dtype {dtype}") + elif dtype in [torch.int64, torch.int32]: + pass + else: + self.disable_vec(f"to_dtype: dtype {dtype}") + return x + + self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + +class CppKernelProxy(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + + def data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, ir.LoopBody): + return True + + _lowp_fp_type: Optional[torch.dtype] = None + + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + + sub_blocks = [scheduler_node._body.root_block] + list( + scheduler_node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + # TODO(Eikan): Regarding get_index and index_expr, we should conclude the + # the data type as well. 
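The constant() and index_expr() handlers above only keep an int64 value in 64 bits when they must: if the value range of the expression over the iteration space, including the `upper + 1` used by the loop-bound comparison, fits in 32 bits, the dtype is narrowed to int32. Below is a minimal sketch of that range test, assuming affine index expressions with non-negative coefficients (the real check bounds arbitrary expressions with bound_sympy over ValueRanges); the helper names are illustrative only.

```python
import sympy

# Toy stand-in for the int32-narrowing test above. We assume an affine
# expression with non-negative coefficients, so its bounds are attained at the
# endpoints of each iteration range.
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def affine_bounds(expr, sizes):
    lo = int(expr.subs({v: 0 for v in sizes}))
    hi = int(expr.subs({v: s - 1 for v, s in sizes.items()}))
    return lo, hi

def expressable_in_32_bits(expr, sizes):
    lo, hi = affine_bounds(expr, sizes)
    # hi + 1 mirrors the "x < upper + 1" loop-bound comparison noted above
    return INT32_MIN <= lo and hi + 1 <= INT32_MAX

d0, d1 = sympy.symbols("d0 d1", integer=True, nonnegative=True)
print(expressable_in_32_bits(1024 * d0 + d1, {d0: 4096, d1: 1024}))     # True
print(expressable_in_32_bits(2**22 * d0 + d1, {d0: 2**10, d1: 2**22}))  # False
```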
+ if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + return False + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + return False + if _lowp_fp_type: + assert ( + _lowp_fp_type == opt_ctx.dtype + ), "scheduler node do not support bf16/fp16 mix" + else: + _lowp_fp_type = opt_ctx.dtype + else: + return False + + scheduler_node._lowp_fp_type = _lowp_fp_type # type: ignore[attr-defined] + return True + + def legalize_lowp_fp_dtype(self, nodes): + def add_to_dtype(sub_graph: torch.fx.Graph): + def is_lowp_fp_load(node: torch.fx.Node): + if node.target not in ["load"]: + return False + assert len(node.args) == 3 + load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + return load_dtype in DTYPE_LOWP_FP + + def is_lowp_fp_store(node: torch.fx.Node): + if node.target != "store": + return False + _, store_var, _, _, _ = node.args + store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type] + return store_dtype in DTYPE_LOWP_FP + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if is_lowp_fp_load(_node): + # No need to promote to float if all users are direct stores + if all(user.target == "store" for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + to_type_node_args = to_type_node.args + _node.replace_all_uses_with(to_type_node) + to_type_node.args = to_type_node_args + metrics.cpp_to_dtype_count += 1 + elif is_lowp_fp_store(_node): + ops, name, _, value_var, _ = _node.args + # No need to promote to float if it is a user of a load which are all directly stored + if value_var.target == "load" and all( + user.target == "store" for user in value_var.users + ): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + (ops, x, _) = _node.args + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. 
+ # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. + to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + else: + pass + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. 
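eliminate_to_dtype above drops a to_dtype node whose consumers all convert to the same dtype again, so only the last conversion survives. The snippet below is not the inductor pass (which rewrites call_method("to_dtype") nodes inside ir.LoopBody sub-graphs); it is a minimal torch.fx illustration of the same peephole applied to ordinary Tensor.to calls, with made-up variable names.

```python
import torch
import torch.fx as fx

def f(x):
    # two back-to-back conversions to the same dtype: the first one is redundant
    return x.to(torch.float32).to(torch.float32) + 1

gm = fx.symbolic_trace(f)
for node in list(gm.graph.nodes):
    is_to = node.op == "call_method" and node.target == "to"
    if is_to and node.users and all(
        u.op == "call_method" and u.target == "to" and u.args[1:] == node.args[1:]
        for u in node.users
    ):
        # every consumer re-converts to the same dtype, so this node is dead
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.code)  # the surviving graph keeps a single .to(torch.float32)
```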
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + def _legalize_lowp_fp(loop_body: ir.LoopBody): + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, ir.LoopBody) + node: SchedulerNode = _node + + def is_memory_copy_scheduler_node(node: SchedulerNode): + op_counts = node.read_writes.op_counts + return ( + len(op_counts) == 2 and "load" in op_counts and "store" in op_counts + ) + + should_legalize = not is_memory_copy_scheduler_node(node) + if should_legalize: + body: ir.LoopBody = node._body + _legalize_lowp_fp(body) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + + assert len(nodes) >= 1 + first_node = nodes[0] + vec_dtype = ( + first_node._lowp_fp_type # type: ignore[attr-defined] + if all( + hasattr(_node, "_lowp_fp_type") + and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined] + for _node in nodes + ) + else torch.float + ) + + kernel_group = self.kernel_group + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for node in nodes: + if node.group[1] in [ + (group, reduction_group), + (group + reduction_group, ()), + ]: + assert not in_suffix + node.run(vars, reduction_vars) + else: + in_suffix = True + assert node.group[1] == ( + group, + (), + ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + node.run(vars, ()) + + scalar_kernel = codegen_kernel(CppKernel) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNestWithSplit.build(scalar_kernel) + + if not self.picked_vec_isa: + return + + def select_tiling_indices(tiling_factor): + all_index = [] + for node in nodes: + rw = dependencies.extract_read_writes(node._body, *node._sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = set() + contig_vars_list = [] + non_contig_stride_const = set() + non_contig_stride_other = set() + for index in all_index: + for var in 
index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(s.name.startswith("s") for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = ( + contig_vars - non_contig_stride_const - non_contig_stride_other + ) + if len(contig_vars) == 0: + # no contiguous vars + return [len(self.itervars) - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == len(self.itervars) - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + def select_tiling(dtype: torch.dtype = torch.float): + # TODO(jgong5): support alternative tiling factors and data types + tiling_factor = self.picked_vec_isa.nelements(dtype=dtype) + tiling_indices = select_tiling_indices(tiling_factor) + if tiling_indices: + could_vec = True + for tiling_indice in tiling_indices: + with CppVecKernelChecker( + deepcopy(self.kernel_group.args), + parallel_num_threads(), + tiling_factor, + tiling_indice, + ) as vec_checker: + run(vec_checker) + could_vec = could_vec and vec_checker.simd_vec + if not could_vec: + break + if could_vec: + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_factors, tiling_indices = select_tiling(vec_dtype) + assert len(tiling_factors) == len(tiling_indices) + try: + if len(tiling_indices) == 1: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype + ) + metrics.generated_cpp_vec_kernel_count += 1 + main_loop, tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + main_loop.set_kernel(vec_kernel) + tail_loop.set_kernel(scalar_kernel) + main_loop.simd_vec = True + tail_loop.simd_omp = True + # We chop the loop into two cubes by the nelements - main loop and tail loop. + # Regarding the main loop, it is straightforward that it could be vectorized with + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized + # as 4(128bits). 
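select_tiling above validates each candidate index with CppVecKernelChecker, and split_with_tiling (defined on LoopLevel/LoopNestWithSplit further below) then chops the loop at the largest multiple of the tiling factor: the main loop steps by the full vector width, while the tail loop walks the remainder and, as the comment notes, can still use half-width vectors. A small arithmetic sketch, with plain range objects standing in for LoopLevels:

```python
# Illustrative only: mirrors the offset/steps arithmetic of split_with_tiling(),
# not the LoopLevel data structure itself.
def split_with_tiling(size: int, factor: int):
    main_end = (size // factor) * factor   # FloorDiv(size, factor) * factor
    main = range(0, main_end, factor)      # vectorized: `factor` lanes per iteration
    tail = range(main_end, size)           # remainder, scalar or half-width vector
    return main, tail

main, tail = split_with_tiling(size=70, factor=16)
print(list(main))  # [0, 16, 32, 48] -> four full 16-lane iterations
print(list(tail))  # [64, 65, 66, 67, 68, 69] -> six leftover elements
```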
+ tail_loop.simd_nelements = tiling_factors[0] // 2 + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + tile2d_kernel = codegen_kernel( + CppTile2DKernel, tiling_factors[0], tiling_indices, vec_dtype + ) + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype + ) + metrics.generated_cpp_vec_kernel_count += 2 + outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + outer_tail_loop.set_kernel(scalar_kernel) + ( + inner_main_loop, + inner_tail_loop, + ) = outer_main_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + inner_main_loop.set_kernel(tile2d_kernel) + inner_tail_loop.set_kernel(vec_kernel) + except CppVecUnsupportedError as e: + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Disabled vectorization: %s", e) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. + MAX_FUSED_KERNEL_ARGS_NUM = 500 + + def __init__(self, scheduler): + self.scheduler = scheduler + self.get_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def get_kernel_group(self): + from .cpp_wrapper_cpu import CppWrapperCpu + + self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup] + if isinstance(V.graph.wrapper_code, CppWrapperCpu): + self.kernel_group = CppWrapperKernelGroup() + else: + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0 + return get_indexing_ranges_exprs(node.snodes[0]) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + extra_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + + return FusedSchedulerNode.fuse(node1, node2) + + def 
_why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? + return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. (s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + def get_buffer(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0 + # use the last scheduler node from the list as it has the most + # relevant indexing expressions + return get_buffer(node.snodes[-1]) + else: + assert isinstance(node, SchedulerNode) + return node.node + + ref_node_buffer = get_buffer(ref_node) + if isinstance(ref_node_buffer, ir.TemplateBuffer): + return False + + assert isinstance(ref_node_buffer, ir.ComputedBuffer) + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + var_ranges1 = ref_node_buffer.get_read_writes().var_ranges + var_ranges2 = node_to_recomp.node.get_read_writes().var_ranges + if var_ranges1 != var_ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def can_fuse_vertical(self, node1, node2): + return self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + + def codegen_nodes(self, nodes: List[SchedulerNode]): + """ + Turn an set of pre-fused nodes into a C++ kernel. 
+ """ + kernel_group = self.kernel_group + + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def flush(self): + self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) + self.get_kernel_group() + self._set_flush_status(False) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, call_args, arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_define_and_call(self, wrapper): + self.stack.close() + if not self.scheduled_nodes: + return + + fused_name = ( + get_fused_kernel_name(self.scheduled_nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + arg_defs, call_args, arg_types = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + code = BracesBuffer() + # TODO: support kernel profile on other platforms + enable_kernel_profile = ( + config.cpp.enable_kernel_profile and sys.platform == "linux" + ) + if enable_kernel_profile: + code.writelines(["#include "]) + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + code.writeline(codecache.cpp_prefix()) + + code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})') + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + + codecache_def = IndentedBuffer() + if not V.graph.cpp_wrapper: + codecache_def.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") + codecache_def.splice(code) + if not V.graph.cpp_wrapper: + codecache_def.writeline("''')") + + codecache_str = codecache_def.getvalue() + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
+ codecache_str = codecache_str.replace("#pragma CMT", "//") + wrapper.define_kernel(kernel_name, codecache_str, cuda=False) + # generate the code to call this + wrapper.generate_kernel_call( + kernel_name, call_args, cuda=False, arg_types=arg_types + ) + + +class CppWrapperKernelGroup(KernelGroup): + def __init__(self): + super().__init__() + self.args = CppWrapperKernelArgs() + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.Integer(0) + steps: sympy.Expr = sympy.Integer(1) + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + reduction_var_map: Optional[Dict[str, str]] = None + parent: Optional["LoopLevel"] = None + # the next inner level of the loop, empty if it is inner-most + # contains >1 LoopLevel if the inner level of loop is split + inner: List["LoopLevel"] = dataclasses.field(default_factory=list) + # kernel assigned to this loop level, only valid when it is a leaf + kernel: Optional[CppKernel] = None + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `codecache.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `codecache.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop level""" + if self.kernel: + return [self.kernel] + kernels = [] + for loop in self.inner: + kernels += loop.get_kernels() + return kernels + + def set_kernel(self, kernel: CppKernel): + """ + Set the kernel under this loop level. No split is allowed under + this loop level. 
+ """ + if not self.inner: + self.kernel = kernel + loop: Optional[LoopLevel] = self + assert loop is not None + if loop.is_reduction(): + loop.reduction_var_map = kernel.reduction_var_map.copy() + loop = loop.parent + while loop is not None and loop.is_reduction(): + assert loop.reduction_var_map is not None + loop.reduction_var_map.update(kernel.reduction_var_map) + loop = loop.parent + return + assert len(self.inner) == 1 + self.inner[0].set_kernel(kernel) + + def get_loops_at(self, depth) -> List["LoopLevel"]: + if depth == 0: + return [self] + else: + loops = [] + for loop in self.inner: + loops += loop.get_loops_at(depth - 1) + return loops + + def is_reduction(self): + return bool(self.reduction_var_map) + + def split_with_tiling(self, depth, factor): + def clone_inner(): + inner = [] + if self.inner: + for loop in self.inner: + inner.append(loop.clone()) + return inner + + def do_split_with_tiling(): + sympy_factor = sympy.Integer(factor) + + offset = FloorDiv(self.size, sympy_factor) * sympy_factor + main_loop = LoopLevel(self.var, offset) + main_loop.steps = sympy_factor + main_loop.parallel = self.parallel + main_loop.collapsed = False + main_loop.reduction_var_map = self.reduction_var_map + main_loop.inner = clone_inner() + if main_loop.inner: + for loop in main_loop.inner: + loop.parent = main_loop + + tail_loop = LoopLevel(self.var, self.size) + tail_loop.offset = offset + tail_loop.parallel = self.parallel + tail_loop.collapsed = False + tail_loop.reduction_var_map = self.reduction_var_map + tail_loop.inner = clone_inner() + if tail_loop.inner: + for loop in tail_loop.inner: + loop.parent = tail_loop + + return main_loop, tail_loop + + if depth == 0: + main_loop, tail_loop = do_split_with_tiling() + parent = self.parent + if parent: + parent.inner = [main_loop, tail_loop] + main_loop.parent = parent + tail_loop.parent = parent + return main_loop, tail_loop + else: + assert len(self.inner) == 1 + return self.inner[0].split_with_tiling(depth - 1, factor) + + def clone(self): + loop = copy(self) + loop.inner = [] + if self.inner: + for inner_loop in self.inner: + inner_loop_clone = inner_loop.clone() + inner_loop_clone.parent = loop + loop.inner.append(inner_loop_clone) + loop.kernel = deepcopy(self.kernel) + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + if self.reduction_var_map: + reduction = " " + " ".join( + f"reduction({RTYPE_TO_CPP[rtype]}:{var})" + for var, rtype in self.reduction_var_map.items() + ) + else: + reduction = "" + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = f"#pragma omp for{reduction} " + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}{reduction}" + elif not self.reduction_var_map and codecache.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNestWithSplit: + """ + 
A loop-nest like structure but with some loop level split along + the loop range into the main tiling loop and the tail. It is built + with the `build` method as a loop nest and then split with + `split_with_tiling` at some depth. + + A typical case is for vectorization where we typically split at the inner-most + loop level. A more complicated case is 2D tiling where we split at + both inner-most and outer levels. + """ + + root: Optional[List[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + root: List[LoopLevel] = [] + levels: List[LoopLevel] = root + loop: Optional[LoopLevel] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size, parent=loop) + if loop_idx >= reduction_depth: + loop.reduction_var_map = kernel.reduction_var_map.copy() + levels.append(loop) + levels = loop.inner + loop_nest = LoopNestWithSplit(root) + if loop: + loop.kernel = kernel + else: + loop_nest.kernel = kernel + return loop_nest + + def __bool__(self): + return bool(self.root) + + def get_loops_at(self, depth) -> List[LoopLevel]: + """Get all the loop levels at the given `depth` (most outer loop has depth 0)""" + loops: List[LoopLevel] = [] + assert self.root is not None + for loop in self.root: + loops += loop.get_loops_at(depth) + return loops + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: + 1) Levels without splitting and + 2) All reduction or non-reduction levels + When the loop is split at the top level, the max depth is 1. + """ + max_depth = 0 + assert self.root is not None + loops = self.root + if len(loops) > 1: + return 1 + is_reduction = loops[0].is_reduction() if loops else False + while len(loops) == 1 and loops[0].is_reduction() == is_reduction: + max_depth += 1 + loops = loops[0].inner + return max_depth + + def is_reduction_only(self): + """ + Whether all the loops are for reduction. Reduction loops + are always the inner most ones. + """ + return ( + self.root is not None and len(self.root) > 0 and self.root[0].is_reduction() + ) + + def mark_parallel(self, par_depth): + assert ( + par_depth <= self.max_parallel_depth() + ), "Parallel depth cannot exceed the maximal allowed parallel depth" + assert self.root is not None + loops = self.root + for loop in loops: + loop.parallel = par_depth + for i in range(1, par_depth): + loops = loops[0].inner + loops[0].collapsed = True + + def split_with_tiling(self, depth, factor): + """ + Split the loop into main and tail loops at given `depth` so that the range + of the main loop has range `floor_div(range, factor) * factor` and + the tail loop handles the remainder. The main loop is tiled + according to the `factor`. 
+ """ + loops = self.get_loops_at(depth) + assert len(loops) == 1 + split_loops = loops[0].split_with_tiling(0, factor) + if depth == 0: + self.root = split_loops + return split_loops diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_prefix.h b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_prefix.h new file mode 100644 index 0000000000000000000000000000000000000000..bfd9a7add180e3d8759ebe170027ee4106e5890e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_prefix.h @@ -0,0 +1,595 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) +#define INDUCTOR_USE_VECTOR_TYPES() 1 +#else +#define INDUCTOR_USE_VECTOR_TYPES() 0 +#endif + +#if INDUCTOR_USE_VECTOR_TYPES() +#include +#include +#include +#endif + +typedef at::Half half; +typedef at::BFloat16 bfloat16; + +typedef at::Float8_e4m3fn float8_e4m3fn; +typedef at::Float8_e5m2 float8_e5m2; + +template +struct Welford { + T mean = T(0); + T m2 = T(0); + T weight = T(0); +}; + + +template +struct IsVecType: std::false_type {}; + +#if INDUCTOR_USE_VECTOR_TYPES() +template +struct IsVecType>: std::true_type {}; +#endif + +template +Welford welford_combine(const Welford &a, const Welford &b) { + if constexpr (!IsVecType::value) { + if (a.weight == 0) { + return b; + } + if (b.weight == 0) { + return a; + } + } + auto delta = b.mean - a.mean; + auto new_weight = a.weight + b.weight; + auto wb_over_w = b.weight / new_weight; + if constexpr (IsVecType::value) { + // Guard against division by zero + wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0)); + } + auto result = Welford{ + a.mean + delta * wb_over_w, + a.m2 + b.m2 + delta * delta * a.weight * wb_over_w, + new_weight + }; + return result; +} + +template +Welford welford_combine(const Welford &acc, T data) { + // Add a single data point + auto delta = data - acc.mean; + auto new_weight = acc.weight + T(1); + auto new_mean = acc.mean + delta / new_weight; + auto new_delta = data - new_mean; + auto result = Welford{ + new_mean, + acc.m2 + delta * new_delta, + new_weight + }; + return result; +} + +// Refer to https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/ +// aten/src/ATen/native/SharedReduceOps.h#L419-L445 +template +inline bool greater_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); +} + +template +inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? 
idx_a < idx_b : (a < b); +} + +#if INDUCTOR_USE_VECTOR_TYPES() +template +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using Vec = at::vec::Vectorized; + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + x.store(array); + for (size_t i = 0; i + n < Vec::size(); i += 2 * n) { + array[i] = array[i + n]; + } + return Vec::loadu(array); +} + +#ifdef CPU_CAPABILITY_AVX2 +inline at::vec::Vectorized vec_shuffle_down(at::vec::Vectorized x, size_t n) { + using vec_t = at::vec::Vectorized; +#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w) + switch (n) { + case 1: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3))); + case 2: + return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2))); + case 4: + return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1))); + } + TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n); +} +#endif + +template +Welford welford_vec_reduce_all(Welford> acc) { + using Vec = at::vec::Vectorized; + for (size_t n = 1; n < Vec::size(); n *= 2) { + auto shuffled = Welford{ + vec_shuffle_down(acc.mean, n), + vec_shuffle_down(acc.m2, n), + vec_shuffle_down(acc.weight, n) + }; + acc = welford_combine(acc, shuffled); + } + + Welford result; + alignas(alignof(Vec)) scalar_t array[Vec::size()]; + acc.mean.store(array); + result.mean = array[0]; + + acc.m2.store(array); + result.m2 = array[0]; + + acc.weight.store(array); + result.weight = array[0]; + + return result; +} +#endif + + +template inline typename std::common_type::type mod(T a, U b) { return a % b; } +template <> inline float mod(float a, float b) { return std::fmod(a, b); } +template <> inline double mod(double a, double b) { return std::fmod(a, b); } + +template +inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a > b ? a : b; +} + +template +inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) { + if (at::_isnan(a)) { + return a; + } + return a < b ? 
a : b; +} + +constexpr float uint32_to_uniform_float(uint32_t value) { + // maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + constexpr float scale = 4.6566127342e-10; + return static_cast(value & 0x7FFFFFFF) * scale; +} + +float normalized_rand_cpu(uint32_t seed, uint32_t offset) { + return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)()); +} + +float randn_cpu(uint32_t seed, uint32_t offset) { + at::Philox4_32 engine(seed, 0, offset); + return engine.randn(10); +} + +int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_t high) { + auto gen = at::Philox4_32(seed, 0, offset); + uint64_t r0 = gen(); + uint64_t r1 = gen(); + uint64_t result = r0 | (r1 << 32); + return static_cast(result % (high - low)) + low; +} + +template struct AsIntegerType { typedef T type; }; +template <> struct AsIntegerType { typedef uint32_t type; }; +template <> struct AsIntegerType { typedef uint64_t type; }; +template <> struct AsIntegerType { typedef uint16_t type; }; + +template +typename std::enable_if::value, T>::type +inline fetch_value(volatile T *addr) { + return *addr; +} + +template +typename std::enable_if::value, T>::type +inline fetch_value(volatile T *addr) { + return T(addr->x, T::from_bits()); +} + +template +typename std::enable_if::value>::type +atomic_add(volatile T *addr, T offset) { + typedef typename AsIntegerType::type alt_type; + + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + + alt_type expected; + + alt_type desired; + + std::atomic *atomic_addr = (std::atomic *)addr; + do { + T val = fetch_value(addr); + reinterpret_cast(&expected)[0] = val; + reinterpret_cast(&desired)[0] = val + offset; + } while (!atomic_addr->compare_exchange_weak(expected, desired, + std::memory_order_relaxed)); +} + +// Since C++20 float is supported by fetch_add, but the performance may not +// better than compare_exchange_weak, which can be checked by microbenchmark +// inductor_cpu_atomic.py +template +typename std::enable_if::value>::type +atomic_add(volatile T *addr, T offset) { + static_assert(sizeof(std::atomic) == sizeof(T), + "std::atomic issue"); + std::atomic *atomic_addr = (std::atomic *)addr; + atomic_addr->fetch_add(offset, std::memory_order_relaxed); +} + +// This function is used to convert bool or uint8 to float mask for +// vectorization. The caller needs to make sure the src represents TRUE/FALSE +// correctly. +template +inline float flag_to_float_scalar(T src) { + float ret; + *(uint32_t*)(&ret) = src ? 
0xFFFFFFFF : 0; + return ret; +} + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) + +inline at::vec::Vectorized masked_load(const float* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_loadu_ps(zero_vec, mmask, src); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + return _mm256_maskload_ps(src, mmask); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src); + return (result & mask); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +inline masked_load(const T* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp = _mm256_mask_loadu_epi16(zero, mmask, src); + return _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + __at_align__ uint32_t mmask[8]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec); + __at_align__ uint16_t result[16]; + for (auto i = 0; i < 8; i++) { + result[i] = mmask[i] == 0xFFFFFFFF ? src[i].x: uint16_t(0); + } + return at::vec::Vectorized::loadu(result); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src, 8); + uint32_t maskdata[8] = { 0 }; + uint16_t maskdata_dest[16] = { 0 }; + mask.store(maskdata); + for (auto i = 0; i < 8; i++) { + maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFFFF: 0; + } + auto maskvector = at::vec::Vectorized::loadu(maskdata_dest); + return (result & maskvector); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +inline masked_load(const T* src, at::vec::Vectorized mask) { +# if defined(CPU_CAPABILITY_AVX512) + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ); + auto zero = _mm_set1_epi8(0); + auto temp = _mm_mask_loadu_epi8(zero, mmask, src); + return _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0); +# elif defined(CPU_CAPABILITY_AVX2) + auto all_ones = _mm256_set1_epi32(0xFFFFFFFF); + auto mmask_vec = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones); + __at_align__ uint32_t mmask[8]; + _mm256_storeu_si256(reinterpret_cast<__m256i*>(mmask), mmask_vec); + __at_align__ T result[32]; + for (auto i = 0; i < 8; i++) { + result[i] = mmask[i] == 0xFFFFFFFF ? src[i]: T(0); + } + return at::vec::Vectorized::loadu(result); +# elif defined(CPU_CAPABILITY_ZVECTOR) + auto result = at::vec::Vectorized::loadu(src, 8); + uint32_t maskdata[8]; + T maskdata_dest[32] = { 0 }; + mask.store(maskdata); + for (auto i = 0; i < 8; i++) { + maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 
0xFF: 0; + } + auto maskvector = at::vec::Vectorized::loadu(maskdata_dest); + return (result & maskvector); +# else +# error Unsupported vectorization CPU capability +# endif +} + +template +inline at::vec::Vectorized flag_to_float_vec(const T* src) { + __at_align__ float dst_tmp[at::vec::Vectorized::size()]; + #pragma unroll + for (int64_t i = 0; i < at::vec::Vectorized::size(); i++) { + dst_tmp[i] = flag_to_float_scalar(src[i]); + } + return at::vec::Vectorized::loadu(dst_tmp); +} + +template +inline at::vec::Vectorized cvt_lowp_fp_to_fp32( + at::vec::Vectorized src) { + at::vec::Vectorized res_vec1(0); + at::vec::Vectorized res_vec2(0); + std::tie(res_vec1, res_vec2) = at::vec::convert_to_float(src); + return res_vec1; +} + +template +inline at::vec::Vectorized cvt_fp32_to_lowp_fp( + at::vec::Vectorized src) { + return at::vec::convert_from_float(src, src); +} + +inline at::vec::Vectorized mask_convert_to_float(at::vec::Vectorized src) { + auto zeros = at::vec::Vectorized(0); + auto ones = at::vec::Vectorized(1); + return at::vec::Vectorized::blendv(zeros, ones, src); +} + +template +inline +typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type +mask_convert_to_lowp(at::vec::Vectorized src) { + auto fp_vec = mask_convert_to_float(src); + return cvt_fp32_to_lowp_fp(fp_vec); +} + +template +inline at::vec::Vectorized vec_convert_to_mask(at::vec::Vectorized src) { + assert( + at::vec::Vectorized::size() == at::vec::Vectorized::size()); + at::vec::Vectorized res_vec(0); + __at_align__ float dst_tmp[at::vec::Vectorized::size()]; + __at_align__ SRC src_tmp[at::vec::Vectorized::size()]; + src.store(src_tmp); + +#pragma unroll + for (int i = 0; i < at::vec::Vectorized::size(); i++) { + *(uint32_t*)(dst_tmp + i) = src_tmp[i] ? 0xFFFFFFFF : 0; + } + + return res_vec.loadu(dst_tmp); +} + +template +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { + return vec_convert_to_mask(src); +} + +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +template <> +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { +#if defined(CPU_CAPABILITY_AVX2) + return at::vec::Vectorized(_mm256_castsi256_ps(src)); +#else + return at::vec::Vectorized(_mm512_castsi512_ps(src)); +#endif +} +#endif + +template <> +inline at::vec::Vectorized to_float_mask(at::vec::Vectorized src) { + return src; +} + +inline at::vec::Vectorized to_float_mask(int src) { + union { + float fmask; + uint32_t imask; + } mask; + mask.imask = src ? 
0xFFFFFFFF : 0; + return at::vec::Vectorized(mask.fmask); +} + +inline bool all_zero(at::vec::Vectorized src) { +# if defined(CPU_CAPABILITY_AVX512) + auto src_int = _mm512_castps_si512(src); + __mmask16 mask = _mm512_test_epi32_mask(src_int, src_int); + return mask == 0; +# elif defined(CPU_CAPABILITY_AVX2) + return _mm256_testz_ps(src, src); +# else + __at_align__ int mask[at::vec::Vectorized::size()]; + src.store(mask); + for (int i = 0; i < at::vec::Vectorized::size(); i++) { + if (mask[i] != 0) { + return false; + } + } + return true; +# endif +} + +inline bool vector_lane_mask_check(at::vec::Vectorized src, int lane) { +# if defined(CPU_CAPABILITY_AVX512) + return _mm512_movepi32_mask(_mm512_castps_si512(src)) & (1 << lane); +# elif defined(CPU_CAPABILITY_AVX2) + return _mm256_movemask_ps(src) & (1 << lane); +# else + __at_align__ int mask[at::vec::Vectorized::size()]; + src.store(mask); + return mask[lane] != 0; +# endif +} + +inline at::vec::Vectorized cvt_int64_to_fp32(at::vec::VectorizedN src) { +# if defined(CPU_CAPABILITY_AVX512) + auto low = _mm512_cvtepi64_ps(src[0]); + auto high = _mm512_cvtepi64_ps(src[1]); + return _mm512_insertf32x8(_mm512_castps256_ps512(low), high, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto low_double = at::vec::convert_to_fp_of_same_size(src[0]); + auto low = _mm256_cvtpd_ps(low_double); + auto high_double = at::vec::convert_to_fp_of_same_size(src[1]); + auto high = _mm256_cvtpd_ps(high_double); + return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1); +# else + constexpr int float_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ float result[float_vec_size]; + __at_align__ int64_t src_buf[int64_vec_size]; + for (int i = 0; i < 2; i++) { + src[i].store(src_buf + i * int64_vec_size); + for (int j = 0; j < int64_vec_size; j++) { + result[i * int64_vec_size + j] = static_cast(src_buf[i * int64_vec_size + j]); + } + } + return at::vec::Vectorized::loadu(result); +# endif +} + +inline at::vec::VectorizedN cvt_fp32_to_int64(at::vec::Vectorized src) { + at::vec::VectorizedN result; +# if defined(CPU_CAPABILITY_AVX512) + result[0] = _mm512_cvt_roundps_epi64(_mm512_castps512_ps256(src), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); + result[1] = _mm512_cvt_roundps_epi64(_mm512_extractf32x8_ps(src, 1), _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); +# elif defined(CPU_CAPABILITY_AVX2) + auto int32_vec = at::vec::convert_to_int_of_same_size(src); + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(int32_vec)); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(int32_vec, 1)); +# else + constexpr int float_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ float src_buf[float_vec_size]; + __at_align__ int64_t result_buf[int64_vec_size]; + src.store(src_buf); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < int64_vec_size; j++) { + result_buf[j] = static_cast(src_buf[i * int64_vec_size + j]); + } + result[i] = at::vec::Vectorized::loadu(result_buf); + } +# endif + return result; +} + +inline at::vec::Vectorized cvt_int64_to_int32(at::vec::VectorizedN src) { +# if defined(CPU_CAPABILITY_AVX512) + auto low = _mm512_cvtepi64_epi32(src[0]); + auto high = _mm512_cvtepi64_epi32(src[1]); + return _mm512_inserti32x8(_mm512_castsi256_si512(low), high, 1); +# elif defined(CPU_CAPABILITY_AVX2) + auto low = _mm256_shuffle_epi32(src[0], _MM_SHUFFLE(2, 0, 2, 0)); + auto high = _mm256_shuffle_epi32(src[1], _MM_SHUFFLE(2, 0, 
2, 0)); + auto low_perm = _mm256_permute4x64_epi64(low, _MM_SHUFFLE(3, 1, 2, 0)); + auto high_perm = _mm256_permute4x64_epi64(high, _MM_SHUFFLE(3, 1, 2, 0)); + return _mm256_blend_epi32(low_perm, high_perm, 0xF0); +# else + constexpr int int32_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ int32_t result[int32_vec_size]; + __at_align__ int64_t src_buf[int64_vec_size]; + for (int i = 0; i < 2; i++) { + src[i].store(src_buf + i * int64_vec_size); + for (int j = 0; j < int64_vec_size; j++) { + result[i * int64_vec_size + j] = static_cast(src_buf[i * int64_vec_size + j]); + } + } + return at::vec::Vectorized::loadu(result); +# endif +} + +inline at::vec::VectorizedN cvt_int32_to_int64(at::vec::Vectorized src) { + at::vec::VectorizedN result; +# if defined(CPU_CAPABILITY_AVX512) + result[0] = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(src)); + result[1] = _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(src, 1)); +# elif defined(CPU_CAPABILITY_AVX2) + result[0] = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(src)); + result[1] = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(src, 1)); +#else + constexpr int int32_vec_size = at::vec::Vectorized::size(); + constexpr int int64_vec_size = at::vec::Vectorized::size(); + __at_align__ int32_t src_buf[int32_vec_size]; + __at_align__ int64_t result_buf[int64_vec_size]; + src.store(src_buf); + for (int i = 0; i < 2; i++) { + for (int j = 0; j < int64_vec_size; j++) { + result_buf[j] = static_cast(src_buf[i * int64_vec_size + j]); + } + result[i] = at::vec::Vectorized::loadu(result_buf); + } +# endif + return result; +} + +inline at::vec::VectorizedN mask_convert_to_int64(at::vec::Vectorized src) { + return cvt_fp32_to_int64(mask_convert_to_float(src)); +} + +inline at::vec::Vectorized to_float_mask(at::vec::VectorizedN src) { + return to_float_mask(cvt_int64_to_int32(src)); +} + +#endif diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..e97b538b3693f910d6630dbf2171807d37cd98b4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,1851 @@ +import functools +import os +import sys +from itertools import count +from typing import List, Optional, Tuple + +import sympy +from sympy import Expr + +import torch +import torch._ops +from .. 
import config, ir + +from ..codecache import CudaKernelParamCache +from ..utils import cache_on_self, sympy_product +from ..virtualized import V +from .common import IndentedBuffer +from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen + + +class CppWrapperCpu(WrapperCodeGen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + super().__init__() + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.open_bracket = "{" + self.closed_bracket = "}" + self.comment = "//" + self.namespace = "at::" + self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()" + self.extern_call_ops = set() + self.size = "sizes()" + self.stride = "strides()" + self.cuda = False + self.supports_intermediate_hooks = False + self.outputs_need_copy = set() + self.kernel_callsite_id = count() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars = set() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices = set() + self.used_cached_dtypes = set() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + + from .cpp import cexpr, CppPrinter + + self.expr_printer = cexpr + + # CppPrinter sometimes calls at::native functions which causes problems in + # the ABI-compatible mode. Currently we are hitting this problem when codegen + # Grid computation expressions, but we my need to fix other size computation + # as well. + class GridExprCppPrinter(CppPrinter): + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + assert expr.is_integer, "Expect integers in GridExprPrinter" + return f"({x}/{div})" + + self.grid_expr_printer = GridExprCppPrinter().doprint + + def generate_kernel_call( + self, + name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + grid_fn: str = "grid", + triton_meta=None, + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + if cuda: + return super().generate_kernel_call( + name, + call_args, + grid, + device_index, + cuda, + triton, + arg_types, + grid_fn, + ) + else: + if config.abi_compatible: + assert arg_types is not None and len(call_args) == len( + arg_types + ), "Mismatch call_args and arg_types in generate_kernel_call" + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"auto* {var_name} = get_data_ptr_wrapper({arg});" + ) + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + self.writeline(self.wrap_kernel_call(name, new_args)) + else: + self.writeline(self.wrap_kernel_call(name, call_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + if V.graph.aot_mode: + for header_cpp_file in ("interface.cpp", "implementation.cpp"): + with open( + os.path.join( + os.path.dirname(__file__), "aoti_runtime", header_cpp_file + ) + ) as f: + self.header.splice(f.read()) + else: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + ''' + """ + ) + + if config.abi_compatible: + if config.c_shim_version == "1": + self.header.splice("#include ") + else: + self.header.splice( + f"#include " + ) + self.header.splice( + """ + #include + #include + #include + """ + ) + if V.graph.aot_mode: + self.header.splice( + """ + #include + """ + ) + else: + self.header.splice( + """ + #include + #include + #include + #include + #include + #include + #include + #include + + #define reinterpret_tensor torch::inductor::_reinterpret_tensor + #define alloc_from_pool torch::inductor::_alloc_from_pool + """ + ) + + self.header.splice("#include ") + + if not V.graph.aot_mode: + self.header.splice( + """ + #include + + using namespace torch::aot_inductor; + """ + ) + + from .memory_planning import ALIGN_BYTES + + # Round up to the nearest multiple of ALIGN_BYTES + # ALIGN_BYTES must be a power of 2 + self.header.splice( + f""" + [[maybe_unused]] static int64_t align(int64_t nbytes) {{ + return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES}; + }} + """ + ) + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = dict() + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. 
+ return + + if V.graph.aot_mode: + self.prefix.writeline("namespace torch {") + self.prefix.writeline("namespace aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + @staticmethod + def get_input_cpp_type(input): + assert config.use_minimal_arrayref_interface + from .cpp import DTYPE_TO_CPP + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: + from .cpp import DTYPE_TO_CPP + + input_cpp_types = ", ".join( + f"{CppWrapperCpu.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + self.prefix.splice(V.graph.const_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + if config.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! 
+ convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + self.prefix.splice( + """ + pybind11::gil_scoped_release release; + """ + ) + + if config.abi_compatible: + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + else: + # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime. + self.prefix.splice( + f""" + auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + from .cpp import DTYPE_TO_CPP + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] + ) + assert ( + dtype is not None + ), "Fails to get the dtype of the sympy.Expr" + cpp_dtype = DTYPE_TO_CPP[dtype] + if config.abi_compatible: + self.prefix.writeline(f"{cpp_dtype} {input_key};") + dtype_str = str(dtype).split(".")[-1] + self.prefix.writeline( + f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});" + ) + else: + self.prefix.writeline( + f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();" + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. 
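+ # For example, with a hypothetical constant named "L__self___weight" at index 0,
+ # the ABI-compatible branch below emits
+ #   auto L__self___weight = constants_->at(0);
+ # while the non-ABI branch dereferences the handle instead:
+ #   auto L__self___weight = *tensor_handle_to_tensor_pointer(constants_->at(0));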
+ if config.abi_compatible: + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + self.prefix.writeline( + f"auto {constants_key} = *tensor_handle_to_tensor_pointer(" + + f"""constants_->at({idx}));""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = inputs[{constants_idx}];" + ) + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? + self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "auto& kernels = static_cast(*this->kernels_.get());" + ) + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + if config.abi_compatible: + code.writeline(f"int64_t* {name}_size;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));" + ) + else: + super().codegen_input_size_var_decl(code, name) + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + if config.abi_compatible: + code.writeline(f"int64_t* {name}_stride;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));" + ) + else: + super().codegen_input_stride_var_decl(code, name) + + def codegen_model_kernels(self): + self.prefix.writeline("namespace {") + self.prefix.writeline( + "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {" + ) + self.prefix.writeline(" public:") + declare_kernel = set(self.src_to_kernel.values()) + declare_kernel.update( + entry[0] for entry in self.user_defined_kernel_cache.values() + ) + if V.graph.const_module: + declare_kernel.update( + V.graph.const_module.wrapper_code.src_to_kernel.values() + ) + for kernel in declare_kernel: + self.prefix.writeline(f" CUfunction {kernel}{{nullptr}};") + self.prefix.writeline("};") + self.prefix.writeline("} // namespace") + + def codegen_model_constructor(self): + """ + // Generated code example + AOTInductorModel::AOTInductorModel() + : AOTInductorModelBase(4, 1) { + inputs_info_[0].name = "input0"; + inputs_info_[0].dtype = "torch.float16"; + ... + constants_info_[0].name = "L__self___weight"; + constants_info_[0].dtype = at::kFloat; + constants_info_[0].offset = 0; + constants_info_[0].data_size = 8192; + constants_info_[0].shape = {64, 32}; + constants_info_[0].stride = {32, 1}; + ... 
+ outputs_info_[0].name = "output0"; + outputs_info_[0].dtype = "torch.float16"; + } + """ + + num_inputs = len(V.graph.graph_inputs) + num_outputs = len(V.graph.graph_outputs) + num_constants = len(V.graph.constants) + self.prefix.splice( + f""" + AOTInductorModel::AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) + : AOTInductorModelBase({num_inputs}, {num_outputs}, {num_constants}, device_str, cubin_dir) {{ + """ + ) + + with self.prefix.indent(): + for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): + assert not isinstance( + inp, sympy.Expr + ), f"input {name=} cannot be symbolic" + self.write_input_output_info("inputs_info_", idx, name) + + for idx, (name, tensor) in enumerate(V.graph.constants.items()): + assert isinstance(tensor, torch.Tensor) + self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""") + self.prefix.writeline( + f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});" + ) + self.prefix.writeline( + f"constants_info_[{idx}].offset = {tensor.storage_offset()};" + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};" + ) + from_folded = "true" if name in V.graph.folded_constants else "false" + self.prefix.writeline( + f"constants_info_[{idx}].from_folded = {from_folded};" + ) + + size_str = ", ".join([str(s) for s in tensor.size()]) + self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};") + + stride_str = ", ".join([str(s) for s in tensor.stride()]) + self.prefix.writeline( + f"constants_info_[{idx}].stride = {{{stride_str}}};" + ) + if name in V.graph.dynamo_flat_name_to_original_fqn: + original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( + name, name + ) + elif name in V.graph.allocated_constant_name: + original_fqn = V.graph.allocated_constant_name[name] + else: + raise AssertionError("original_fqn must be set for constant") + self.prefix.writeline( + f"""constants_info_[{idx}].original_fqn = "{original_fqn}";""" + ) + self.prefix.writeline("update_constants_map(std::move(constants_map));") + self.prefix.writeline("update_constants_array(std::move(constants_array));") + + def escape_string(x): + return ( + x.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\t", "\\t") + ) + + self.prefix.writeline( + f'in_spec_ = "{escape_string(config.aot_inductor.serialized_in_spec)}";' + ) + self.prefix.writeline( + f'out_spec_ = "{escape_string(config.aot_inductor.serialized_out_spec)}";' + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance( + output, sympy.Expr + ), f"output {name=} cannot be symbolic" + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. 
+ _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: List[Optional[Tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert ( + None not in const_index_mapping + ), "Not all constant gets mapped for constant folding graph." + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. + """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + if V.graph.aot_mode and not V.graph.is_const_graph: + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + cached_dtypes_buffer = IndentedBuffer() + if config.abi_compatible: + for dtype in self.used_cached_dtypes: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") + cached_dtypes_buffer.splice(self.prefix) + self.prefix = cached_dtypes_buffer + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False + ): + self.header.splice(f"\n{kernel}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + @cache_on_self + def get_output_refs(self): + return [ + f"torch::tensor({x.codegen_reference(self.wrapper_call)})" + if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible + else x.codegen_reference(self.wrapper_call) + for x in V.graph.graph_outputs + ] + + def generate_return(self, output_refs): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph and 
config.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + for idx, output in enumerate(output_refs): + if config.abi_compatible: + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = ( + f"cached_output_{next(self.cached_output_id)}" + ) + output_value_type = f"std::decay_t(output_arrayref_tensors).data()[0])>" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if output in cst_names: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. + self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if output in cst_names: + # See NOTE(return_constant) above. 
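+ # Constants are not owned by the model instance, so hand the caller a fresh
+ # clone here; non-constant outputs can simply release their handle into
+ # output_handles (see the else branch below).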
+ self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + else: + assert ( + not arr_iface + ), "minimal ArrayRef interface is only supported in ABI-compatible mode" + if output in cst_names: + output_expr = f"{output}.clone()" + # See NOTE(return_constant) above. + else: + output_expr = output + self.wrapper_call.writeline( + f"output_handles[{idx}] = reinterpret_cast(" + + f"new at::Tensor({output_expr}));" + ) + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") + + def generate_before_suffix(self, result): + if not V.graph.is_const_graph: + if V.graph.aot_mode: + result.writeline("} // AOTInductorModel::run_impl") + else: + result.writeline("} // inductor_entry_impl") + + def generate_end(self, result): + if V.graph.aot_mode: + if V.graph.is_const_graph: + result.writeline("} // AOTInductorModel::_const_run_impl") + else: + result.writeline("} // namespace aot_inductor") + result.writeline("} // namespace torch") + return + + result.writeline("'''\n)") + result.splice( + f""" + inductor_entry = CppWrapperCodeCache.load_pybinding( + ["std::vector"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)}) + """ + ) + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + return_str = "return f(args_tensor)" + else: + outputs = [ + f"outputs[{i}]" if self.output_is_tensor[i] else f"outputs[{i}].item()" + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + return_str = f""" + outputs = f(args_tensor) + return {outputs_str} + """ + + args_str = "args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]" + if V.graph.constants: + # Append constants to the input args for cpp wrapper. + # Python wrapper directly gets the value inside the wrapper call + # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__). + # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly. + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + constants_str = f"[{', '.join(V.graph.constants.keys())}]" + args_str += f""" + constants_tensor = {constants_str} + args_tensor.extend(constants_tensor) + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {args_str} + {return_str} + return g + call = _wrap_func(inductor_entry) + """ + ) + + def generate_c_shim_extern_kernel_call(self, kernel, args): + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + self.allow_stack_allocation = False + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + if config.c_shim_version == "1": + shim_fn = f"aoti_torch_{kernel_suffix}" + else: + shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" + + # HACK: val_to_arg_str jams multiple arguments together using a comma. 
If that + # ever breaks, it needs to be reworked to be able to return multiple arguments, + # and the split-on-comma code here needs to be removed. + wrapped_args = [] + for x in args: + pieces = x.split(", ") + for piece in pieces: + # We only really *need* convert_arrayref_tensor_to_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, + # which convert_arrayref_tensor_to_tensor would blindly coerce to int, + # so just avoid wrapping integers. + if not piece.isdigit(): + piece = f"convert_arrayref_tensor_to_tensor({piece})" + wrapped_args.append(piece) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(wrapped_args)}));" + ) + + def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_arg = f"&{output_handle_name}" + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args + [output_arg] + ) + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def generate_extern_kernel_alloc(self, extern_kernel, args): + if config.abi_compatible: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + else: + super().generate_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel(self, fallback_kernel, args): + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert ( + output.indices[0][1] == idx + ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError("unsupported type of {output=}") + args = args + output_args + assert ( + fallback_kernel.abi_compatible_kernel is not None + ), f"abi_compatible_kernel is None for {fallback_kernel.python_kernel_name=}" + self.generate_c_shim_extern_kernel_call( + fallback_kernel.abi_compatible_kernel, args + ) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def generate_fallback_kernel(self, fallback_kernel, args): + if config.abi_compatible: + self.generate_c_shim_fallback_kernel(fallback_kernel, args) + else: + super().generate_fallback_kernel(fallback_kernel, args) + + def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel): + if output_view: + output_as_strided = f"{output_view.codegen_reference()}" + output_name = f"{output_view.get_name()}_as_strided" + self.writeline(f"auto {output_name} = {output_as_strided};") + + args.insert(0, output_name) + else: + args.insert(0, f"{codegen_reference}") + + if config.abi_compatible: + self.generate_c_shim_extern_kernel_call(kernel, args) + else: + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_user_defined_triton_kernel( + self, kernel_name, grid, configs, args, triton_meta + ): + assert len(grid) != 0 + if len(grid) == 1: + grid_decision = 
grid[0] + else: + meta = CudaKernelParamCache.get(kernel_name) + assert meta is not None + grid_decision = None + for i, c in enumerate(configs): + if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()): + grid_decision = grid[i] + break + assert grid_decision is not None + + self.generate_kernel_call( + kernel_name, + args, + grid=grid_decision, + device_index=V.graph.scheduler.current_device.index, + cuda=True, + triton=True, + triton_meta=triton_meta, + ) + + def generate_scatter_fallback( + self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs + ): + # TODO: support other overload for cpp wrapper and remove the below assertions + if config.abi_compatible: + # call the ABI shim function instead of the ATen one + kernel = kernel.replace("at::", "aoti_torch_") + line = f"{kernel}({output}, {','.join(map(str, inputs))}" + if python_kernel_name == "aten.scatter_": + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert ( + reduce is None + ), "Expect reduce to be None for aten.scatter_ with scalar src" + else: + line += f", {','.join(kwargs)}" + line += f"){self.ending}" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + if V.graph.aot_mode and V.graph.cpp_wrapper and config.abi_compatible: + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # tensor prematurely deallocated, thus this std::vector().data() trick here. + indices_str = ( + f"std::vector{{{', '.join(indices)}}}.data()" + ) + args = [x, indices_str, str(len(indices)), values, accumulate] + else: + indices_str = ( + f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + ) + args = [x, indices_str, values, accumulate] + + args.insert(0, x) # set x as the output tensor, this fallback mutates x. + self.writeline(self.wrap_kernel_call(kernel, args)) + + def add_benchmark_harness(self, output): + if V.graph.aot_mode: + return + super().add_benchmark_harness(output) + + def codegen_sizevar(self, x: Expr) -> str: + return self.expr_printer(V.graph.sizevars.simplify(x)) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + if config.abi_compatible: + # in the abi_compatible mode, outputs are returned via arguments + return name + else: + return f"std::get<{index}>({basename})" + + def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + parts = list(map(self.codegen_sizevar, shape)) + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def codegen_dynamic_scalar(self, node): + from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP + + (data,) = (t.codegen_reference() for t in node.inputs) + if config.abi_compatible: + dtype = node.inputs[0].get_dtype() + dtype_str = str(dtype).split(".")[-1] + self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym};") + self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym});") + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.sym)) + else: + if node.is_bool: + self.writeline(f"bool {node.sym} = {data}.item() ? 
1 : 0;") + else: + convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace( + "at::k", "to" + ) + self.writeline(f"auto {node.sym} = {data}.item().{convert_type}();") + + def can_stack_allocate_buffer(self, buffer): + return ( + self.allow_stack_allocation + and buffer.get_device().type == "cpu" + and self.can_prove_buffer_has_static_shape(buffer) + and ir.is_contiguous_strides_for_shape( + buffer.get_stride(), buffer.get_size() + ) + ) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_layout(), ir.MultiOutputLayout) + or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers) + or ( + config.use_minimal_arrayref_interface + and V.graph.aot_mode + and buffer.get_name() in V.graph.graph_inputs + ) + else f"{buffer.get_name()}.reset();" + ) + + def make_free_by_names(self, names_to_del: List[str]): + return " ".join(f"{name}.reset();" for name in names_to_del) + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + if config.abi_compatible: + return f"auto {new_name} = std::move({old_name}); // reuse" + else: + return super().codegen_exact_buffer_reuse(old_name, new_name, del_line) + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline( + 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' + ) + + def write_triton_header_once(self): + pass + + def generate_start_graph(self): + pass + + def generate_end_graph(self): + pass + + def generate_inf_and_nan_checker(self, nodes): + for buf in nodes.get_names(): + # TODO: Add buf name directly into check_inf_and_nan. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + if config.abi_compatible: + self.used_cached_devices.add(device.type) + return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}" + else: + from .cpp import DEVICE_TO_ATEN + + return ( + f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" + if device.index is not None + else f"{DEVICE_TO_ATEN[device.type]}" + ) + + def codegen_dtype(self, dtype): + if config.abi_compatible: + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + else: + from .cpp import DTYPE_TO_ATEN + + return DTYPE_TO_ATEN[dtype] + + @functools.lru_cache(None) + def codegen_int_array_var( + self, + int_array: str, + writer=None, + known_statically=False, + graph=None, # for per-graph caching + ): + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. 
+ # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass + if writer is None: + # The first pass codegen uses `self` as the writer + writer = self + + var = f"int_array_{next(self.int_array_id)}" + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writer.writeline(f"static constexpr int64_t {var}[] = {int_array};") + else: + writer.writeline(f"int64_t {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + if config.abi_compatible: + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + from .cpp import DTYPE_TO_CPP + + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. 
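+ # For example, for a hypothetical contiguous float CPU buffer "buf0" with 64
+ # elements, this path emits roughly
+ #   float buf0_storage[64];
+ #   ArrayRefTensor<float> buf0(buf0_storage, int_array_0, int_array_1,
+ #                              cached_torch_device_type_cpu, this->device_idx_);
+ # where int_array_0 / int_array_1 are the size and stride arrays declared by
+ # codegen_int_array_var above.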
+ self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + if V.graph.aot_mode and device_str.startswith("c10::Device("): + tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)" + else: + tensor_device = device_str + + if device.type == "cpu": + return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});" + if device.type == "cuda": + return ( + f"at::Tensor {name} = at::detail::empty_strided_cuda(" + f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" + ) + return ( + f"{self.declare}{name} = {self.namespace}empty_strided(" + f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" + ) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + if config.abi_compatible: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + pexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + return f"RAIIAtenTensorHandle({tmp_name})" + + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + self.codegen_shape_tuple(shape), + self.codegen_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, data, size_list, stride_list, offset, writer + ) -> str: + dim = str(len(size_list)) + size = self.codegen_shape_tuple(size_list) + stride = self.codegen_shape_tuple(stride_list) + offset = self.codegen_sizevar(offset) + + if config.abi_compatible: + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. 
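+ # As in codegen_int_array_var above, the first-pass codegen uses `self` as the
+ # writer when no explicit writer is passed in.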
+ if writer is None: + writer = self + + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + size, + writer, + known_statically=self.is_statically_known_list_of_ints(size_list), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + stride, + writer, + known_statically=self.is_statically_known_list_of_ints(stride_list), + graph=self.get_codegened_graph(), + ), + offset, + ] + + def gen_reinterpret_call(writer, args): + writer.writeline( + f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" + ) + + if ( + self.can_stack_allocate_buffer(data) + and self.is_statically_known_list_of_ints(size_list) + and self.is_statically_known_list_of_ints(stride_list) + and ir.is_contiguous_strides_for_shape(stride_list, size_list) + ): + gen_reinterpret_call(writer, args) + return tmp_name + + gen_reinterpret_call(writer, args) + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. + # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::vector{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.data() + # ); + # ``` + return f"wrap_with_raii_handle_if_needed({tmp_name})" + else: + args = [data.get_name(), size, stride, offset] + return f"reinterpret_tensor({', '.join(args)})" + + def codegen_device_copy(self, src, dst): + if config.abi_compatible: + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_tensor_copy_(expensive_copy_to_tensor_if_needed({src}), {dst}));" + ) + else: + self.writeline(f"{dst}.copy_({src});") + + def codegen_multi_output(self, name, value): + # in the abi_compatible mode, outputs are retrieved by passing + # output pointers, so we skip its codegen here. + if not config.abi_compatible: + super().codegen_multi_output(name, value) + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): + if config.abi_compatible: + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). 
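+ # For example, for a hypothetical subgraph input "buf1" fed by outer input
+ # "arg0_1", the lines below emit
+ #   AtenTensorHandle buf1_handle;
+ #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(arg0_1, &buf1_handle));
+ #   RAIIAtenTensorHandle buf1(buf1_handle);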
+ self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline( + f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);" + ) + else: + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + src = inner_output.codegen_reference() + if config.abi_compatible: + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. + src = f"std::move({src})" + self.writeline(f"{outer_output} = {src}{self.ending}") + + def codegen_conditional(self, conditional): + name = conditional.get_name() + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + if config.abi_compatible: + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + predicate = f"{conditional.predicate.get_name()}_scalar" + self.writeline(f"bool {predicate};") + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the unrelying scalar bool Tensor + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));" + ) + else: + # in non-ABI-compatible mode, we can codegen the conditional outputs + # as array of at::Tensor instances, as the ir.MultiOutput is codegened + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];") + predicate = f"{conditional.predicate.codegen_reference()}.item()" + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def generate_extern_kernel_args_decl_if_needed( + self, op_overload, raw_args, output_args + ): + arg_types = [x.real_type for x in op_overload._schema.arguments] + return_types = [x.type for x in op_overload._schema.returns] + + new_tensor_args = [] + new_int_args = [] + + def fill_args(arg, arg_type): + static_arg_types = ( + torch.FloatType, + torch.BoolType, + torch.StringType, + torch.Type, + torch.DeviceObjType, + ) + inductor_tensor_buffers = ( + ir.Buffer, + ir.ReinterpretView, + ) + + if isinstance(arg_type, torch.TensorType): + assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}" + new_tensor_args.append(f"{arg.codegen_reference()}") + elif isinstance(arg_type, torch.IntType): + # int + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg + new_int_args.append(self.expr_printer(expr)) + elif isinstance(arg_type, torch.NumberType): + # Scalar of type int + 
assert isinstance(arg, (int, float, bool)) + # Only treat int Scalar as dynamic + if isinstance(arg, int): + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.ListType): + assert isinstance(arg, (list, tuple)) + + # List[Tensor] + if isinstance(arg_type.getElementType(), torch.TensorType): + new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg]) + # List[Optional[Tensor]] + elif isinstance( + arg_type.getElementType(), torch.OptionalType + ) and isinstance( + arg_type.getElementType().getElementType(), torch.TensorType + ): + new_tensor_args.extend( + [f"{a.codegen_reference()}" for a in arg if a is not None] + ) + # List[int] + elif isinstance(arg_type.getElementType(), torch.IntType): + new_int_args.extend([str(a) for a in arg]) + # List[SymInt] + elif isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend( + [self.expr_printer(expr) for expr in expressions] + ) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all( + is_int_type + ), "AOTInductor only supports int scalars of the same type" + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + else: + assert isinstance( + arg_type, static_arg_types # type: ignore[arg-type] + ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg(arg, return_type): + if isinstance(return_type, torch.TensorType): + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance(return_type, (torch.TensorType)): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." 
+ ) + + for output_arg in output_args: + assert output_arg is not None, "Optional return types are not yet supported" + if isinstance(output_arg, (list, tuple)): + for out in output_arg: + fill_output_arg(out, torch.TensorType.get()) + else: + fill_output_arg(output_arg, torch.TensorType.get()) + + return new_tensor_args, new_int_args + + def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + op_overload=None, + raw_args=None, + outputs=None, + ): + if config.is_fbcode(): + assert op_overload is not None + assert raw_args is not None + assert outputs is not None + + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + name, + cpp_kernel_key, + op_overload, + raw_args, + outputs, + ) + else: + return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name, + ) + + def generate_extern_kernel_alloc_and_find_schema_if_needed_oss( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + ): + if cpp_kernel_key not in self.extern_call_ops: + self.writeline( + f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()" + ) + self.writeline( + f'\t.findSchemaOrThrow("{kernel}", "{cpp_kernel_overload_name}")' + ) + self.writeline(f"\t.typed<{cpp_op_schema}>();") + self.extern_call_ops.add(cpp_kernel_key) + + self.writeline( + f"auto {name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});" + ) + + def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode( + self, + name, + cpp_kernel_key, + op_overload, + raw_args, # contains both args and flatten kwargs + outputs, + ): + def extract_output_name(out): + assert out is not None, "None, i.e. 
optional output is not supported" + if isinstance(out, ir.MultiOutput): + return out.get_name() + elif isinstance(out, (list, tuple)): + return type(out)(extract_output_name(o) for o in out) + else: + raise AssertionError(f"Unexpected output: {type(out)}") + + # output_args has the same pytree structure as outputs + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] + + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, raw_args, output_args + ) + + tensor_call_args_str = ", ".join(tensor_call_args) + int_call_args_str = ", ".join(int_call_args) + + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"std::vector{{{int_call_args_str}}}.data(), " + f"{len(tensor_call_args)}, " + f"std::vector{{{tensor_call_args_str}}}.data());" + ) + + self.extern_call_ops.add(cpp_kernel_key) + + def generate_reset_kernel_saved_flags(self): + pass + + def generate_save_uncompiled_kernels(self): + pass + + def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str: + if ( + config.abi_compatible + and not is_legacy_abi + and isinstance(type_, torch.OptionalType) + ): + if val is None: + return "0" # nullptr is not available in C + if not isinstance(type_.getElementType(), torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};") + return f"&{var_name}" + elif config.c_shim_version == "2": + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val) + if "wrap_with_raii_handle_if_needed" in base_handle: + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + self.writeline( + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" + ) + base_handle = tmp_var_name + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") + return f"&{var_name}" + + return self.val_to_arg_str(val) + + def val_to_arg_str(self, val) -> str: + if val is None: + # When None is passed as an argument, it represents an optional that does not contain a value. + if config.abi_compatible: + return "0" # nullptr is not available in C + return "c10::nullopt" + elif isinstance(val, bool): + if config.abi_compatible: + return "1" if val else "0" + else: + return "true" if val else "false" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS + return f"{val}LL" if sys.platform == "darwin" else f"{val}L" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, float) and val in [float("inf"), float("-inf")]: + if val == float("inf"): + return "std::numeric_limits::infinity()" + else: + return "-std::numeric_limits::infinity()" + elif isinstance(val, (list, tuple)): + # FIXME handle embedded optional types? 
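+            # Illustrative example of the two branches below: in non-ABI mode a Python
+            # list is rendered as a C++ brace-initializer, e.g.
+            #     val_to_arg_str([2, 3])  ->  "{2L, 3L}"   (on Linux)
+            # while in ABI-compatible mode the literal is lowered to an int-array
+            # variable plus an explicit length, because std::vector cannot be used there.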
+ result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}" + if config.abi_compatible: + static = self.is_statically_known_list_of_ints(val) + # Need to pass the array length because we can't use std::vector + int_var_array = self.codegen_int_array_var( + result, + known_statically=static, + graph=self.get_codegened_graph(), + ) + return f"{int_var_array}, {len(val)}" + else: + return result + else: + return repr(val) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..36bff25c66c371bc8c64a1ca785bcee7c573b4ee --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cuda.py @@ -0,0 +1,328 @@ +import functools +import os +from itertools import chain, count +from typing import Any, List, Optional, TYPE_CHECKING + +import sympy + +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name + +from .. import config +from ..codecache import CudaKernelParamCache +from ..triton_heuristics import grid as default_grid +from ..virtualized import V +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import SymbolicCallArg + +if TYPE_CHECKING: + from ..graph import GraphLowering + + +def is_int(s: str) -> bool: + # Cpp code gen adds L at the end of ints + # Lets remove it for checking whether we have an int or not + if s and s[-1] == "L": + s = s[:-1] + try: + int(s) + except ValueError: + return False + except TypeError: + return False + return True + + +def is_float(s: str) -> bool: + try: + float(s) + except ValueError: + return False + return True + + +class CppWrapperCuda(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self): + self.device = "cuda" + super().__init__() + self.grid_id = count() + self.cuda = True + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
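+            # (For non-constant graphs, the remainder of this method emits additional
+            # includes plus the CUDA_DRIVER_CHECK macro and the Grid / loadKernel /
+            # launchKernel helpers used by the generated wrapper.)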
+ return + + super().write_header() + + self.header.splice("#include ") + if config.abi_compatible: + self.header.splice( + "#include " + ) + else: + self.header.splice( + """ + #include + #include + #include + """ + ) + + self.header.splice( + """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + ) + + def write_get_raw_stream(self, index, graph=None): + name = f"stream{index}" + self.writeline(f"cudaStream_t {name};") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));" + ) + return name + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + ): + if not cuda: + return super().define_kernel(name, kernel, metadata, cuda) + + def generate(self, is_inference): + self.prefix.writeline("\n") + if not V.graph.aot_mode: + for kernel in chain( + self.src_to_kernel.values(), + [entry[0] for entry in self.user_defined_kernel_cache.values()], + ): + self.prefix.writeline(f"static CUfunction {kernel} = nullptr;") + self.prefix.writeline("\n") + return super().generate(is_inference) + + @functools.lru_cache(None) + def generate_load_kernel_once( + self, + name: str, + mangled_name: str, + cubin_path: str, + shared_mem: int, + graph: "GraphLowering", # for per-graph caching + ): + if V.graph.aot_mode: + self.writeline(f"if (kernels.{name} == nullptr) {{") + self.writeline( + f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);""" + ) + self.writeline("}") + else: + self.writeline(f"if ({name} == nullptr) {{") + self.writeline( + f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});""" + ) + self.writeline("}") + + def generate_args_decl(self, call_args): + dynamic_symbols = V.graph.sizevars.free_symbols() + # TODO: only works for constant now, need type info + new_args = [] + for arg in call_args: + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(arg, (sympy.Integer, sympy.Symbol, SymbolicCallArg)): + self.writeline(f"auto {var_name} = 
{arg};") + elif isinstance(arg, sympy.Expr): + self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") + elif is_int(arg): + self.writeline(f"int {var_name} = {arg};") + elif is_float(arg): + self.writeline(f"float {var_name} = {arg};") + elif any(str(arg) == s.name for s in dynamic_symbols): + self.writeline(f"auto {var_name} = {arg};") + elif arg == "nullptr": + self.writeline(f"auto {var_name} = nullptr;") + elif arg == "c10::nullopt": + self.writeline(f"auto {var_name} = c10::nullopt;") + else: + if config.abi_compatible: + self.writeline(f"CUdeviceptr {var_name};") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" + ) + else: + self.writeline( + f"CUdeviceptr {var_name} = reinterpret_cast({arg}.data_ptr());" + ) + new_args.append(f"&{var_name}") + + return ", ".join(new_args) + + def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True): + """ + Generate grid configs for launching a CUDA kernel using the grid + function from triton_heuristics. + """ + if not cuda: + return grid + assert isinstance(grid, list), f"expected {grid=} to be a list" + grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] + grid_fn = default_grid(*grid) + params = CudaKernelParamCache.get(name) + assert ( + params is not None + ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}" + block_cfg = { + "XBLOCK": params["x_block"], + "YBLOCK": params["y_block"], + "ZBLOCK": params["z_block"], + } + return grid_fn(block_cfg) + + def generate_kernel_call( + self, + name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + grid_fn: str = "grid", + triton_meta=None, + ): + if not cuda: + # Even in CppWrapperCuda, we may see cpp kernels + return super().generate_kernel_call( + name, call_args, grid, device_index, cuda, triton, arg_types + ) + + params = CudaKernelParamCache.get(name) + assert ( + params is not None + ), f"cuda kernel parameters for {name} should already exist at this moment" + mangled_name = params.get("mangled_name", None) + assert mangled_name is not None, "missing mangled_name" + cubin_path = params.get(get_cpp_wrapper_cubin_path_name(), None) + assert cubin_path is not None and os.path.exists( + cubin_path + ), f"cubin file should already exist at this moment: {cubin_path}" + shared_mem = params.get("shared_mem", 0) + + self.generate_load_kernel_once( + name, mangled_name, cubin_path, shared_mem, V.graph + ) + + # args with value 1 are added into equal_to_1 and constants + # in triton_meta (in the Python codegen) which makes them + # inlined in the PTX and compiled CUBIN + if ( + triton_meta is not None + and "configs" in triton_meta + and triton_meta["configs"] + ): + equal_to_1 = triton_meta["configs"][0].equal_to_1 + call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] + + call_args = self.generate_args_decl(call_args) + kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" + self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};") + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device_index, V.graph) + ) + grid_name = f"{name}_grid_{next(self.grid_id)}" + assert isinstance( + grid, (list, tuple) + ), f"expected grid to be a list or tuple but got: {grid=}" + + grid = [V.graph.sizevars.simplify(item) for item in grid] + grid_uses_symbolic_shapes = any(item.free_symbols for item in grid) + grid_args 
= [self.grid_expr_printer(item) for item in grid] + grid_args_str = ", ".join(grid_args) + self.writeline(f"Grid {grid_name} = Grid({grid_args_str});") + + if grid_uses_symbolic_shapes: + self.writeline(f"if ({grid_name}.is_non_zero()) {{") + kernel_var_name = f"kernels.{name}" if V.graph.aot_mode else name + self.writeline( + "launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format( + kernel_var_name, + f"{grid_name}.grid_x", + f"{grid_name}.grid_y", + f"{grid_name}.grid_z", + params["num_warps"], + params["shared_mem"], + kernel_args_var, + stream, + ) + ) + if grid_uses_symbolic_shapes: + self.writeline("}") diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb5020127ee427fac425d3ff076ebb2103061f1e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c349ec0e541d7a736f3605d410525d864bcba4c9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df49d81bcc89546ef4edb54efd7be42cd9b16221 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..430ad5ffcaf314bd398beaf363d83bbe4a62dc17 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ed9580d1a13b2dfd235582418bb21a84a6baf51 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f29bdf2dcc3cd64e186a377dca599a9eb6298ae8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_epilogue_gen.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d848ca1f89b3ce1350b08b2fcc86bff3d93d0f60 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41251202dae9c99699ec7411515fd369fd560239 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..896c08d4bf424aeceba88c2e6be8f615e784ea2c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..3417527a59d4a32ed1b1eeb0b8de3e2111319e5b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,212 @@ +import logging +from typing import cast, List + +from ...._dynamo.utils import counters + +from ... import config, ir +from ...codecache import code_hash, get_path +from ...ir import ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer + +from .cutlass_epilogue_gen import CUTLASSEVTOpNotImplementedError + +log = logging.getLogger(__name__) + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def is_cuda_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template( + node.get_template_node() + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + epilogue_nodes: List[ir.IRNode], + additional_node: ir.IRNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. 
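+        In addition, the candidate node must read the template output (or, if epilogue
+        nodes were already fused, the most recently fused one), must share the template
+        buffer's layout, and must be expressible as a CUTLASS EVT expression; these are
+        exactly the checks performed below.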
+ + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes. + additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. + + """ + if not isinstance(cuda_template_buffer, CUDATemplateBuffer): + return False + if not cuda_template_buffer.template.can_fuse_epilogue: + # The used GEMM op does not support fusing epilogues + return False + if not isinstance(additional_node, ComputedBuffer): + return False + if not isinstance(additional_node.data, Pointwise): + return False + # We can fuse a Pointwise op that depends on the last fused epilogue node + # if any. If there is no epilogue node yet, it needs to depend on the template + # node + node_name = additional_node.get_computed_buffer_name() + if node_name is None: + return False + + if len(epilogue_nodes) == 0: + if cuda_template_buffer.name not in additional_node.get_read_names(): + return False + else: + last_epilogue_node = epilogue_nodes[-1] + assert isinstance(last_epilogue_node, ir.ComputedBuffer) # for mypy + last_epilogue_name = ( + last_epilogue_node.name + if last_epilogue_node.name is not None + else last_epilogue_node.data.name # type: ignore[attr-defined] + ) + if last_epilogue_name not in additional_node.get_read_names(): + return False + if additional_node.layout != cuda_template_buffer.layout: + return False + try: + from torch._inductor.codegen.cuda.cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, + ) + + CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + cast(str, cuda_template_buffer.name), "anything", [additional_node] + ) + CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string( + cast(str, cuda_template_buffer.name), [additional_node] + ) + except CUTLASSEVTOpNotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: + # Likely due to unsupported dtype. + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. 
Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + return True + + @staticmethod + def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> List[ir.IRNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + nodes.remove(template_node) + return [n.node for n in nodes] + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode): + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), [], node2.node + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, SchedulerNode + ): + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node().node, + self._unwrap_epilogue_nodes(fnode1), + node2.node, + ) + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''', 'so')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template( + template_node + ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb, epilogue_ir_nodes) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..581291f3e8e34105ed80b7b7865ffbaa9e962a6a --- /dev/null 
+++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,45 @@ +import functools +import logging +from typing import Optional + +import torch + +from ... import config + +log = logging.getLogger(__name__) + + +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception as e: + log.error("Error getting cuda arch: %s", e) + return None + + +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception as e: + log.error("Error getting cuda version: %s", e) + return None + + +@functools.lru_cache(None) +def nvcc_exist(nvcc_path: str = "nvcc") -> bool: + if nvcc_path is None: + return False + import subprocess + + res = subprocess.call( + ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + return res == 0 diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..330b5c279a1c390eb8e6f96b6de61c3c64449b06 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,374 @@ +import logging +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union + +from ... import ir +from ...autotune_process import CUDABenchmarkRequest +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox +from ...select_algorithm import ChoiceCaller +from ...utils import sympy_product +from ...virtualized import V + +from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType +from ..cpp import CppPrinter, DTYPE_TO_CPP + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__(self, kernel_name): + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: Dict[str, IRNode] = {} + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. 
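+        The emitted C++ only throws (std::runtime_error) when the argument pointer is
+        null while its computed element count is greater than zero, so a null pointer is
+        tolerated for zero-sized inputs.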
+ """ + + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" + + def call_kernel( + self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.WrapperCodeGen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + else: + call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + call_args.append("None") + + if node.get_workspace_size() > 0: + call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())") + else: + call_args.append("None") + + wrapper.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + cuda=True, + triton=False, + ) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. 
+ """ + + if node is None: + return "0" + return str(node.get_layout().offset) + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + + sizes = node.get_size()[start_index : end_index + 1] + if len(sizes) == 0: + return str(default_value) + + val = sympy_product(sizes) + return cexpr(self.rename_indexing(val)) + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + return cexpr(self.rename_indexing(stride)) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], + bmreq: CUDABenchmarkRequest, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark( + *args, output_tensor=out + ) # @TODO: Hack for ensuring that Cutlass Kernel is preferred + + def __str__(self): + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + epilogue_node_names: List[str] = [ + getattr(en, "name", "no_name") + for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr] + ] + epilogue_node_strs: List[str] = [ + str(en) for en in self.info_kwargs.get("epilogue_nodes", []) # type: ignore[union-attr] + ] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "epilogue_node_names": epilogue_node_names, # type: ignore[dict-item] + "epilogue_node_strs": epilogue_node_strs, # type: ignore[dict-item] + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd653556e996229a4700652c67b23aa1338df69 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,242 @@ +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +import sympy + +import torch +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout + +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel + +log = logging.getLogger(__name__) + + +class 
CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + input_reorder: Optional[List[int]] = None, + ): + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + kernel_name = f"cuda_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CUDATemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = CUDATemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. 
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + using bfloat16 = nv_bfloat16; + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in {"1", "1L"}: + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..dd42711d6235bf4f0c294371abe588c20b78aa48 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -0,0 +1,360 @@ +from typing import Dict, List +from unittest.mock import patch + +import sympy + +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise +from torch._inductor.utils import IndentedBuffer, sympy_str + + +# Used as a magic string to indicate an unsupported sympy expression +# became part of generated C++ code. +_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]" + + +def _arg_str(a): + if isinstance(a, sympy.Expr): + # If this return value containting the _MAGIC_SYMPY_ERROR_STRING + # is used as part of the final generated C++ code, + # a CUTLASSEVTOpNotImplementedError is raised to indicate that + # the op could not be converted to a valid EVT expression. 
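+        # Embedding the sentinel (rather than raising immediately) lets rendering
+        # continue; ir_to_evt_string / ir_to_evt_argument_string scan the final string
+        # and raise CUTLASSEVTOpNotImplementedError if the sentinel is still present.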
+ return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')" + return str(a) + + +class CUTLASSEVTOpNotImplementedError(NotImplementedError): + pass + + +class CutlassEVTEpilogueTypeFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) functor declarations. + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + + + """ + + def __init__(self, accumulator_node_name, evt_type_name): + """ + + Initialize an instance of CutlassEVTEpilogueTypeFormatter. + + Parameters: + - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused) + IR graph. + - evt_type_name (str): The output name of the EVT type we are generating. + + """ + self.accumulator_node_name = accumulator_node_name + self.output = IndentedBuffer(0) + self.var_counter = 0 + self.evt_type_name = evt_type_name + self.aliases = dict() + + @staticmethod + def ir_to_evt_string( + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ): + """ + Formats IR nodes into a string representation compatible with Cutlass EVT format. + + Args: + template_output_node_name (str): The name of the template output node. + evt_type_name (str): The name of the EVT type. + epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be + ComputedBuffer nodes wrapping Pointwise nodes. + + Returns: + A string representation of the IR nodes formatted according to the Cutlass EVT format. + """ + formatter = CutlassEVTEpilogueTypeFormatter( + template_output_node_name, evt_type_name + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + if isinstance(node, ComputedBuffer): + pnode = node.data + else: + raise RuntimeError( + "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer" + ) + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + formatter.aliases[node.name] = result + res = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + """ + Resolve V.ops. calls, after this instance has been installed as V.ops handler. 
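+        Each resolved op writes a `using EVT_expr_<n> = ...;` line to the output buffer
+        and returns that alias name, so later ops can refer to earlier sub-expressions.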
+ """ + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + self.var_counter += 1 + varname = f"EVT_expr_{self.var_counter}" + # replace line with a new variable name + self.output.writeline(f"using {varname} = {line};") + return varname + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + # Load an input to an operation. Might be the output of the matmul, the result + # of a previous epilogue node, a constant or (TODO) an auxiliary input. + if name == self.accumulator_node_name: + return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */" + elif name in self.aliases: + return self.aliases[name] + else: + # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */" + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." + ) + + def _op_constant(self, value, dtype): + # Load a constant + if str(dtype) in ("torch.float16", "torch.float32"): + return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast /* value={value}, dtype={dtype} */" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + # Perform a named operation on two inputs + # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops + return f"cutlass::epilogue::fusion::Sm90EVT,{a},{b}>" # noqa: B950 + + def _convert_to_output_dtype(self, a): + # Convert the final output to the dtype of the output buffer + return f"cutlass::epilogue::fusion::Sm90EVT,{a}>" # noqa: B950 + + def _op_to_dtype(self, a, *args, **kwargs): + # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator + # dtype. + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + return a # noqa: B950 + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return f"cutlass::epilogue::fusion::Sm90EVT,{a}, {const_zero}>" # noqa: B950 + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + # Add more ops here... 
+ def getvalue(self, result) -> str: + # Return final result + dtype_converted_expr = self._convert_to_output_dtype( + f"EVT_expr_{self.var_counter}" + ) + self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};") + return self.output.getvalue() + + +class CutlassEVTEpilogueArgumentFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + + + """ + + def __init__(self, accumulator_node_name: str): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method. + + Args: + accumulator_node_name (str): The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + """ + self.accumulator_node_name: str = accumulator_node_name # + self.output: IndentedBuffer = IndentedBuffer(0) # The output buffer for codegen + self.var_counter: int = ( + 0 # used to generate variable names, incremented for each new variable + ) + self.aliases: Dict[str, str] = dict() # Aliases for subexpression functors + + @staticmethod + def ir_to_evt_argument_string( + template_output_node_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + formatter = CutlassEVTEpilogueArgumentFormatter( + template_output_node_name, + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + assert isinstance(node, ComputedBuffer) + pnode = node.data + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + if node.name is not None: + formatter.aliases[node.name] = result + + res: str = formatter.getvalue(result) # type: ignore[possibly-undefined] + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + return line + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + if name == self.accumulator_node_name: + return "{}" + elif name in self.aliases: + return self.aliases[name] + else: + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." 
+ ) + + def _op_constant(self, value, dtype): + if str(dtype) in ("torch.float16", "torch.float32"): + return "{ static_cast(" + str(value) + ") }" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + return f"{{ /*{op}: */ {a}, {b} }}" + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return "{" + str(a) + ", " + const_zero + "}" + + def _op_to_dtype(self, a, dtype, src_dtype=None): + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + assert dtype in ( + "torch.float32", + "torch.float16", + ), f"Unsupported dtype: {dtype}" + assert src_dtype in ( + None, + "torch.float32", + "torch.float16", + ), f"Unsupported source dtype: {src_dtype}" + return a + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + def getvalue(self, result) -> str: + return "{" + str(result) + "}" diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20eb4b3f94a273ba03ec722adde8da8c71143887 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d9cbbc92f6252e3ee83340027a7575908875afe Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..4a34bd7e9d3a68ee694f5f0170056dc9ef67d7e3 --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,186 @@ +from ..cutlass_utils import try_import_cutlass + +if try_import_cutlass(): + import enum + + from cutlass_library.library import * # noqa: F401, F403 + from cutlass_library.gemm_operation import * # noqa: F401, F403 + + # copied / modified from original at + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + # to support EVT similar to + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69 # noqa: B950 + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > + """ + self.gemm_template = """ + using EpilogueScheduleType = ${epilogue_schedule}; + static_assert(cute::is_same_v || + cute::is_same_v, + "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue"); + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementAcc = ${element_accumulator}; + using ElementD = ${element_d}; + ${epilogue_functor}; + using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + EpilogueScheduleType, + ${operation_name}_epilogue_functor + >::CollectiveOp; + + using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + + // Gemm operator ${operation_name} + using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + + // Define named type + struct ${operation_name} : + public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ + ${compile_guard_start} + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + ${compile_guard_end} + """ + + # + def emit(self, operation): + tile_shape = operation.tile_description.tile_shape + warp_count = operation.tile_description.warp_count + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + else: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout" # noqa: B950 + warp_shape = [tile_shape[idx] // 
warp_count[idx] for idx in range(3)] + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], # type: ignore[name-defined] + } + epilogue_functor = SubstituteTemplate( # type: ignore[name-defined] + self.builtin_epilogue_functor_template, values + ) + + elif callable(operation.epilogue_functor): + epilogue_functor = operation.epilogue_functor( + operation.procedural_name() + "_epilogue_functor" + ) + else: + epilogue_functor = str(operation.epilogue_functor) + # + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], # type: ignore[name-defined] + "layout_a": LayoutTag[instance_layout_A], # type: ignore[name-defined] + "element_b": DataTypeTag[operation.B.element], # type: ignore[name-defined] + "layout_b": LayoutTag[instance_layout_B], # type: ignore[name-defined] + "element_c": DataTypeTag[operation.C.element], # type: ignore[name-defined] + "layout_c": LayoutTag[instance_layout_C], # type: ignore[name-defined] + "element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined] + "layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined] + "element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined] + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950 + "arch": "cutlass::arch::Sm%d" % operation.arch, + "tile_shape_m": str(operation.tile_description.tile_shape[0]), + "tile_shape_n": str(operation.tile_description.tile_shape[1]), + "tile_shape_k": str(operation.tile_description.tile_shape[2]), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined] + "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined] + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], # type: ignore[name-defined] + "transform_b": ComplexTransformTag[operation.B.complex_transform], # type: ignore[name-defined] + "math_operation": MathOperationTag[ # type: 
ignore[name-defined] + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), # type: ignore[name-defined] + } + + return SubstituteTemplate(self.gemm_template, values) # type: ignore[name-defined] diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa1bdc941a45d5a7c7e009886182a0f68d11e2a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,258 @@ +import functools +import logging +import os +import sys +from dataclasses import dataclass +from typing import Any, List, Optional + +import sympy + +import torch + +from ...codecache import cache_dir +from ...config import cuda as inductor_cuda_config +from ...ir import Layout +from .cuda_env import get_cuda_arch, get_cuda_version + +log = logging.getLogger(__name__) + + +def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +def _gen_cutlass_file( + file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str +) -> None: + orig_full_path = os.path.abspath(os.path.join(src_dir, file_name)) + text = "" + with open(orig_full_path) as f: + text = f.read() + text = _rename_cutlass_import(text, cutlass_modules) + dst_full_path = os.path.abspath( + os.path.join( + dst_dir, + file_name, + ) + ) + with open(dst_full_path, "w") as f: + f.write(text) + + +@functools.lru_cache(None) +def try_import_cutlass() -> bool: + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. + + cutlass_py_full_path = os.path.abspath( + os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library") + ) + tmp_cutlass_py_full_path = os.path.abspath( + os.path.join(cache_dir(), "torch_cutlass_library") + ) + dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library") + + if os.path.isdir(cutlass_py_full_path): + if tmp_cutlass_py_full_path not in sys.path: + if os.path.exists(dst_link): + assert os.path.islink( + dst_link + ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." 
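+                # Editor's note (illustrative; concrete paths depend on the local
+                # configuration): the assert above and the one below verify that the
+                # cached entry is the expected symlink, i.e. roughly
+                #   <cache_dir()>/torch_cutlass_library/cutlass_library
+                #       -> <inductor_cuda_config.cutlass_dir>/python/cutlass_library
+                # Appending <cache_dir()>/torch_cutlass_library to sys.path then lets
+                # "import cutlass_library.<module>" resolve against that checkout.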
+ assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + cutlass_py_full_path + ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}" + else: + os.makedirs(tmp_cutlass_py_full_path, exist_ok=True) + os.symlink(cutlass_py_full_path, dst_link) + sys.path.append(tmp_cutlass_py_full_path) + try: + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + + return True + + except ImportError as e: + log.debug( + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_py_full_path, + ) + return False + + +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. + """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + + operations = "all" + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@functools.lru_cache(None) +def _gen_ops_cached(arch, version) -> List[Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return list() + arch = _normalize_cuda_arch(arch) + args = CUTLASSArgs(architectures=arch, cuda_version=version) + manifest = cutlass_manifest.Manifest(args) + + if arch == "90": + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + cutlass_generator.GenerateSM80(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + return manifest.operations + + +def gen_ops() -> List[Any]: + """ + Generates all supported CUTLASS operations. + """ + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. 
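+    # Editor's note (illustrative): filter_op() in gemm_template.py uses this
+    # predicate to match candidate op operands against node dtypes, e.g.
+    #
+    #   >>> dtype_match(torch.half, cutlass_library.library.DataType.f16)
+    #   True
+    #   >>> dtype_match(torch.float, cutlass_library.library.DataType.tf32)
+    #   True   # fp32 inputs are also allowed to run on tf32 tensor-core kernels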
+ assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: List[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a list of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + if len(input_torch_dtypes) == 0: + return None + torch_dtype = input_torch_dtypes[0] + for dtype in input_torch_dtypes[1:]: + if torch_dtype != dtype: + raise RuntimeError(f"Unmatched input dtypes: {torch_dtype=}, {dtype=}") + if torch_dtype == torch.half: + if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction: + return torch_dtype + else: + return torch.float + if torch_dtype in {torch.bfloat16, torch.float}: + return torch.float + raise NotImplementedError(f"Unsupported data type: {input_torch_dtypes=}") + + +def get_alignments(torch_dtype: torch.dtype) -> List[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. 
+ """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + if is_static_int(size[-1]) and is_static_int(offset): + alignments = get_alignments(dtype) + for alignment in alignments: + if int(size[-1]) % alignment == 0 and int(offset) % alignment == 0: + return alignment + + return 1 diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..579f340c9af17a598401be2b3906f6741ecd94e4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,18 @@ +from ..common import DeviceOpOverrides, register_device_op_overrides + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx): + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self): + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx): + return f"torch.cuda._DeviceGuard({device_idx})" + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..ea022d4d7019ceb4250841a4ccba385abc660bd6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,706 @@ +import copy +import logging +import re +from typing import cast, Dict, List, Optional, Tuple + +from ...config import cuda as inductor_cuda_config +from ...ir import Buffer, CUDATemplateBuffer, FixedLayout, IRNode, Layout +from ..common import IndentedBuffer + +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, +) + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
+extern "C" { +{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + try { + {{kernel.check_not_null(X)}} + {{kernel.check_not_null(W)}} + {{kernel.check_not_null(Bias)}} + {{kernel.check_not_null(Y)}} + int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + int64_t M = {{kernel.size(X, -2)}}; + int64_t K = {{kernel.size(X, -1)}}; + int64_t N = {{kernel.size(W, -1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + + +GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + + +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. 
+ arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.stride(X, -3), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.stride(W, -3), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}} + }; +""" + +GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.stride(Bias, -3), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.stride(Y, -3), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + + +class CUTLASSGemmTemplate(CUTLASSTemplate): + """ + CUTLASS GEMM template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + can_fuse_epilogue: Optional[bool] = None, + ): + """ + Args: + input_nodes: input nodes of the kernel + layout: layout of the output node + alpha: alpha value of the GEMM operation + beta: beta value of the GEMM operation + input_reorder: reorder of the input nodes + can_fuse_epilogue: If set to True, will only list and use operators capable of flexible epilogue fusions. + If False, it will not use those. If None, both may be listed, but it will not allow fusions. 
+ Defaults to None + """ + super().__init__("cutlass_gemm", input_nodes, layout, input_reorder) + self.alpha = alpha + self.beta = beta + self.can_fuse_epilogue = can_fuse_epilogue + + @staticmethod + def add_cutlass_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + fuseable=True, + non_fuseable=True, + ): + if non_fuseable: + if fuseable: + # list both fuseable and non-fuseable ops, and treat them all as non-fuseable + can_fuse_epilogue = False + else: + can_fuse_epilogue = None + + cutlass_template = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=can_fuse_epilogue, + ) + ops = cutlass_template.gen_ops() + for op in ops: + cutlass_template.maybe_append_choice( + choices, + op=op, + ) + else: + ops = [] + if fuseable: + cutlass_template_evt = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=True, + ) + # This will list only ops capable of EVT fusion + ops_evt = cutlass_template_evt.gen_ops() + for op in ops_evt: + cutlass_template_evt.maybe_append_choice( + choices, + op=op, + ) + else: + ops_evt = [] + log.debug( + "Added %d cutlass gemm configs and %d fuseable gemm configs.", + len(ops), + len(ops_evt), + ) + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + return res + + @staticmethod + def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if torch_layout.stride[-1] == 1: + return cutlass_lib.LayoutType.RowMajor + elif torch_layout.stride[-2] == 1: + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + def layout_match(torch_layout, cutlass_layout) -> bool: + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + alignment = cutlass_utils.get_max_alignment(torch_layout) + if alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def has_tma_epilogue(op) -> bool: + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == 
cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] # noqa: F821 + """ + returns True if the op is capable of flexible epilogue fusions + using epilogue visitor trees. + + See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285 # noqa: B950 + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Universal3x: + return False + if op.epilogue_schedule not in ( + cutlass_lib.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_lib.EpilogueScheduleType.TmaWarpSpecializedCooperative, + ): + return False + + return True + + def render_evt_epilogue_declaration( + self, + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + """Generates the epilogue for the EVT epilogue fusion""" + return CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + template_output_node_name, evt_type_name, epilogue_nodes + ) + + def define_gemm_instance( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + output_buffer_name: str, + epilogue_nodes: Optional[List[IRNode]] = None, + ) -> Tuple[str, str]: + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import ( + EmitGemmUniversal3xInstanceWithEVT, + ) + + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + emitter = EmitGemmUniversal3xInstanceWithEVT() + op.epilogue_functor = lambda epilogue_functor_type_name: self.render_evt_epilogue_declaration( + output_buffer_name, epilogue_functor_type_name, epilogue_nodes + ) + else: + emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + else: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + raise RuntimeError( + "EVT epilogue fusion is not supported for Cutlass 2.x ops." + ) + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + return op_def, op_type + + @staticmethod + def should_swap_XW( + bias: IRNode, + beta: float, + ) -> bool: + return True + + # TODO(ipiszy): Check whether it's necessary to swap X/W. 
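+    # Editor's note (illustrative, not part of the upstream file): the swap is
+    # mathematically safe because Y = X @ W can equally be computed as
+    # Y^T = W^T @ X^T with every operand viewed through transposed strides.
+    # swap_XW() below exchanges A/B and flips the operand layouts, while
+    # render_gemm_arguments() transposes the strides of X, W, Bias and Y and
+    # swaps M/N when should_swap_xw is set, so the TMA epilogue sees the bias
+    # in its preferred column-major form.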
+ # strides = bias.get_stride() + # if strides[-1] != 1: + # return True + # for stride in strides[:-1]: + # if stride != 0: + # return True + # return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # Swap X and W in GemmOperation. + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + + # Only keep GemmUniversal kernels + if op.gemm_kind not in { + cutlass_lib.GemmKind.Universal, + cutlass_lib.GemmKind.Universal3x, + }: + return None + # Filter ops by dtypes. + X = self.input_nodes[0] + W = self.input_nodes[1] + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.C.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + if not ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + # Set bias layout and alignment. 
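+        # Editor's note (illustrative summary of the branch below): CUTLASS 2.x
+        # kernels require the bias layout to match the output layout, whereas
+        # 3.x (Universal3x) kernels accept an independent bias layout; when no
+        # bias input is present, 3.x ops set ElementC to void and 2.x ops reuse
+        # the output layout for C.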
+ if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if op.gemm_kind != cutlass_lib.GemmKind.Universal3x: + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return None + else: + op.C.layout = bias_layout + if not self.set_alignment(Bias.get_layout(), op.C): + return None + else: + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op.C.element = cutlass_lib.DataType.void + else: + op.C.layout = op.D.layout + supports_evt: bool = self.supports_evt(op) + if (self.can_fuse_epilogue is not None) and ( + self.can_fuse_epilogue != supports_evt + ): + return None + if inductor_cuda_config.cutlass_only_evt_capable_ops and not supports_evt: + return None + return op + + def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm] + res: Dict[str, cutlass_gemm_op.GemmOperation] = dict() + num_3x_ops = 0 + num_2x_ops = 0 + for op_dict in ops.values(): + for op_list in op_dict.values(): + for op in op_list: + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + for op in res.values(): + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + num_3x_ops += 1 + else: + num_2x_ops += 1 + log.debug( + "Got cutlass configs: total number of ops: %d, " + "total number of 3x ops: %d, total number of 2x ops: %d", + len(res), + num_3x_ops, + num_2x_ops, + ) + return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs] + + def gemm_mode(self) -> str: + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render_gemm_arguments( + self, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + + if epilogue_template is not None: + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), + new_stride, + old_layout.offset, + ) + return Buffer(node.get_name(), new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, 
**options + ) + else: + arguments = self._template_from_string(GEMM_ARGS_CUTLASS_2X).render( + split_k=1, **options + ) + return arguments + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ) -> str: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + assert self.can_fuse_epilogue and CUTLASSGemmTemplate.supports_evt( + op + ), "op does not support EVT epilogue fusion" + assert ( + template_buffer_node is not None + ), "Template node is required for epilogue fusion" + assert isinstance( + template_buffer_node, CUDATemplateBuffer + ), f"Template node has to be a CUDATemplateBuffer, is type {type(template_buffer_node)}" + assert ( + template_buffer_node.name is not None + ), "Output node has to be a Buffer with a name" + # This is the name of the output of the Matmul, before epilogues are applied. + # it is not necessarily materialized in global memory if we have an epilogue + + template_output_node_name = ( + template_buffer_node.name if template_buffer_node is not None else None + ) + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance( + op, cutlass_gemm_op.GemmOperation + ), "op argument is required and has to be an instance of GemmOperation" + if template_buffer_node is not None: + self.output_node = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + self.output_node = cast(Buffer, epilogue_nodes[-1]) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + epilogue_template: Optional[str] = None + should_swap_xw: bool = False + epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + if Bias is not None and self.has_tma_epilogue(op): + if self.should_swap_XW(Bias, self.beta): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + epilogue_args = ( + CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string( + cast(str, template_output_node_name), epilogue_nodes + ) + ) + epilogue_template = GEMM_ARGS_CUTLASS_3X_EPILOGUE + argument_template = GEMM_ARGS_CUTLASS_3X + else: + # TODO: Support split_k. 
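+            # Editor's note (illustrative): for CUTLASS 2.x GemmUniversal ops the
+            # epilogue stays a plain linear combination, so epilogue_args keeps the
+            # "{alpha, beta}" initializer built above and the kernel arguments are
+            # rendered from GEMM_ARGS_CUTLASS_2X with split_k currently fixed to 1.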
+ argument_template = GEMM_ARGS_CUTLASS_2X + + instance_definition, instance_type = self.define_gemm_instance( + op, cast(str, template_output_node_name), epilogue_nodes + ) + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=epilogue_args, + ) + res = self._template_from_string(GEMM_TEMPLATE).render(**options) + return res diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..ed88fc3a9d824779eb783ccddcdab6ffc557e5b9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,75 @@ +from typing import List + +from ..scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling + +from .triton import TritonScheduling + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. + """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self._scheduler = scheduler + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn(self, sizes): + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes + ) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + return self._triton_scheduling.codegen_nodes(nodes) + + def codegen_sync(self): + return self._triton_scheduling.codegen_sync() + + def flush(self): + return self._triton_scheduling.flush() + + def codegen_foreach(self, *args, **kwargs): + return self._triton_scheduling.codegen_foreach(*args, **kwargs) + + def 
benchmark_fused_nodes(self, nodes): + return self._triton_scheduling.benchmark_fused_nodes(nodes) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/memory_planning.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/memory_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..6f921c527e99409679e714aafe6ec14758b31f7f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,799 @@ +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Dict, Iterable, List, Optional, Protocol + +import sympy + +import torch +from .. import config, ir +from ..utils import cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V + +from .wrapper import ( + AllocateLine, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes): + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr): + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value): + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | ±inf + end: float # int | ±inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. 
+ + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this bool. Return True if + an assignment was made. + """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this in tree""" + raise NotImplementedError() + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError() + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError() + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. + """ + + node: ir.Buffer + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. 
+ """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: + ... + + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. + + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: List[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. 
+ """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: List[str] = dataclasses.field(default_factory=list) + creation_cache: Dict[str, str] = dataclasses.field(default_factory=dict) + + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + # optimization: fuse first allocation and pool creation + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. 
+ """ + + device_to_pools: Dict[torch.device, List[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. + """ + + def __init__(self, node: ir.Buffer): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. 
+ size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range}" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocationLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + pool.names_to_del.extend(self.group.names) + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[List[BufferGroup]] = None + + def plan(self, lines: List[Any]) -> List[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. 
+ """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = set(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. + """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. + """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: List[Allocation] = [] + intermediates: List[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. 
+ """ + seen = set() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = set() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..d788076470203ff068b5b5557dc921f44a4781b8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,413 @@ +import logging +import os +from typing import Any, List + +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled + +from .. import config +from ..codecache import PyCodeCache, TritonFuture +from ..utils import cache_on_self, do_bench +from ..virtualized import V +from .common import TensorArg + +log = logging.getLogger(__name__) + + +def get_kernel_argdefs(kernel): + arg_defs, _, _ = kernel.args.python_argdefs() + return arg_defs + + +def _get_all_args(args_list): + all_args = max(args_list, key=len)[:] + for args in args_list: + assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}" + + return all_args + + +def get_all_kernel_argdefs(kernels): + """ + The logic here must match with `get_all_call_args`. + """ + argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels] + + return _get_all_args(argdefs_list) + + +def get_all_call_args(call_args_list): + """ + Passed in the call_args for each subkernel and return the call_args for the + combined multi-kernel. + + Note an algorithm as follows does not always work: + ``` + all_call_args: Dict[ + Any, None + ] = {} # use a dict rather than set to maintain insertion order + for call_args in call_args_list: + all_call_args.update({arg: None for arg in call_args}) + + all_call_args = list(all_call_args.keys()) + ``` + It will fail if any kernel has the same argument passed in multiple times. + Check test_pass_same_arg_multi_times in test_multi_kernel.py + + Instead, we pick the longest call args and assert that otehr call args are + a subset of it. + """ + return _get_all_args(call_args_list) + + +def get_numel_argdefs(kernel): + numel_argdefs = [] + for tree in kernel.range_trees: + if tree.prefix != "r" or kernel.inside_reduction: + numel_argdefs.append(f"{tree.prefix}numel") + + return numel_argdefs + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicated + multi-kernel for the same set of sub-kernels. + + V.graph.wrapper_code has a reference to MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + + def define_kernel(self, kernels): + """ + Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}". + This has some minor issue. + + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only different is cache eviction policy. 
+ + We should name the multi-kernel differently in these 2 cases. + """ + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi kernel based on the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + wrapper = V.graph.wrapper_code + + kernel_call_def_code = "\n".join( + [ + f""" + def call{idx}(need_clone_args=False): + args = [{', '.join(get_kernel_argdefs(kernels[idx]))}] + if need_clone_args: + args, _ = multi_kernel_call.kernels[{idx}].clone_args(*args) + multi_kernel_call.kernels[{idx}].run(*args, {', '.join(get_numel_argdefs(kernels[idx]))}, grid=grid, stream=stream) + """.format( + idx + ).strip( + "\n" + ) + for idx in range(len(kernels)) + ] + ) + + # add subkernel src code hashes to the multi-kernel source code so changing a + # subkernel implementation will result in a differnt py file for + # multi-kernel. This makes cache implementation straightforward since + # we can decide cache file name based on multi-kernel py file name + # directly. + # + # Without the hash added for subkernels, the cache file may be shared by + # different subkernels which is incorrect. + subkernel_hashes = "\n".join( + f"# subkernel{i} code hash: {kernel.code_hash}" + for i, kernel in enumerate(kernels) + ) + + src_code = f""" +{subkernel_hashes} +def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.join(get_numel_argdefs(kernels[0]))}, grid, stream): +{kernel_call_def_code} + multi_kernel_call.run_with_argless_kernels([call0, call1]) + """ # noqa: B950 line too long + wrapper.header.splice( + f""" + {multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [ + {", ".join(kernel_names)}, + ], + ''' + """ + ) + wrapper.header.splice(src_code) + wrapper.header.splice( + """ + ''' + ) + """ + ) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile time state for multi kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will looks like: + ``` + multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code) + ``` + + Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. + self.args = object() + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. 
+ """ + assert kernel_name == self.kernel_name + call_args_list = [kernel.get_call_args() for kernel in self.kernels] + + all_call_args = get_all_call_args(call_args_list) + grid: List[Any] = [] + + if V.graph.cpp_wrapper: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + picked_kernel = MultiKernelCall.lookup_choice(kernel_name) + kernel_name = self.kernels[picked_kernel].kernel_name + final_call_args = call_args_list[picked_kernel] + else: + final_call_args = all_call_args + + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args_and_grid( + kernel_name, final_call_args, grid + ) + + grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid) + + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + final_call_args, + grid, + V.graph.scheduler.current_device.index, + ) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen = set() + for k in self.kernels: + _, call_args, arg_types = k.args.python_argdefs() + for arg, arg_type in zip(call_args, arg_types): + if arg in seen: + continue + seen.add(arg) + if isinstance(arg_type, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return set.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return set.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. + """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is called at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels, src_code): + assert len(kernels) >= 2 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self._run = PyCodeCache.load(src_code).run + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + py_file_path = self._run.__globals__["__file__"] + return os.path.splitext(py_file_path)[0] + ".picked_kernel" + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if os.path.exists(path): + with open(path) as fd: + self.picked_kernel = int(fd.read()) + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + with open(path, "w") as fd: + fd.write(str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from future. 
+
+        This should be called after parallel compilation is done.
+        If you call this before compilation is done,
+        it may slow down the parallel compilation.
+        """
+        for i, kernel in enumerate(self._kernels):
+            if isinstance(kernel, TritonFuture):
+                self._kernels[i] = kernel.result()
+
+        return self._kernels
+
+    def run(self, *args, **kwargs):
+        self._run(self, *args, **kwargs)
+
+    @staticmethod
+    def benchmark_sub_kernels(kernel_calls):
+        """
+        Benchmark all the sub-kernels and return the execution time
+        (in milliseconds) for each of them.
+
+        Unit tests may mock this method to force a specific kernel to
+        be picked.
+        """
+        return [
+            do_bench(lambda: kernel_call(True), rep=40, fast_flush=True)
+            for kernel_call in kernel_calls
+        ]
+
+    # record_choice and lookup_choice are helper functions for cpp-wrapper
+    # codegen. The first pass uses record_choice to store the choice and
+    # the second pass looks it up by calling lookup_choice.
+    #
+    # An alternative that reuses the multi-kernel cache does not work well,
+    # since during codegen of the second pass it's very hard to know the
+    # path to the cache file. Also, reading the cache file requires some IO,
+    # which can be slower.
+    @staticmethod
+    def record_choice(multi_kernel_name, choice):
+        """
+        Record the multi-kernel choice during the first pass of cpp-wrapper
+        codegen so the second pass can look it up.
+
+        This does nothing if it is not called during codegen.
+        """
+        from torch._inductor.graph import GraphLowering
+
+        if not isinstance(V.graph, GraphLowering):
+            return
+
+        if not V.graph.record_multi_kernel_choice:
+            return
+
+        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
+
+    @staticmethod
+    def lookup_choice(multi_kernel_name):
+        # this should always be done during cpp-wrapper codegen
+        assert V.graph.record_multi_kernel_choice
+        # there should be no miss
+        return V.graph.multi_kernel_to_choice[multi_kernel_name]
+
+    def run_with_argless_kernels(self, kernel_calls):
+        if self.picked_kernel is None:
+            timings = self.benchmark_sub_kernels(kernel_calls)
+            self.picked_kernel = timings.index(min(timings))
+            k0 = self.kernels[0]
+            log.debug(
+                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s.
Timings %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + k0.size_hints, + k0.inductor_meta.get("reduction_hint"), + timings, + ) + + def get_kernel_path(k): + return k.fn.fn.__code__.co_filename + + get_metric_table("persistent_red_perf").add_row( + lambda: { + "kernel1_name": get_kernel_path(self.kernels[0]), + "kernel2_name": get_kernel_path(self.kernels[1]), + "kernel1_latency": timings[0], + "kernel2_latency": timings[1], + "size_hints": k0.size_hints, + "reduction_hint": k0.inductor_meta.get("reduction_hint"), + "speedup": timings[1] / timings[0], + } + ) + + if not self.disable_cache: + self.store_cache() + + if not self._recorded: + self._recorded = True + self.record_choice(self.multi_kernel_name, self.picked_kernel) + kernel_calls[self.picked_kernel]() diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/triton.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..87f61b0710f81fdb64bd80cc0da8b9fbfb77c530 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton.py @@ -0,0 +1,3931 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +from functools import lru_cache +from typing import ( + Any, + Callable, + cast, + Counter, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) + +import sympy + +import torch +import torch._logging + +from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata +from torch._prims_common import is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._triton import has_triton_package + +from ..._dynamo.utils import counters +from .. import config, ir, scheduler +from ..codecache import code_hash, get_path, PyCodeCache +from ..dependencies import Dep, MemoryDep, StarDep, WeakDep +from ..ir import IRNode, ReductionHint, TritonTemplateBuffer +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..triton_heuristics import AutotuneHint +from ..utils import ( + cache_on_self, + do_bench, + get_dtype_size, + get_fused_kernel_name, + get_kernel_metadata, + get_max_y_grid, + green_text, + is_welford_reduction, + next_power_of_2, + Placeholder, + sympy_dot, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, + yellow_text, +) +from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V +from ..wrapper_benchmark import get_kernel_category_by_source_code +from .common import ( + CSE, + CSEVariable, + DeferredLine, + free_symbol_startswith, + IndentedBuffer, + index_prevent_reordering, + Kernel, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, +) +from .multi_kernel import MultiKernel +from .triton_utils import config_of, signature_of, signature_to_meta + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +@lru_cache(None) +def gen_attr_descriptor_import(): + """ + import AttrsDescriptor if the triton version is new enough to have this + class defined. 
+ """ + if not has_triton_package(): + return "" + + import triton.compiler.compiler + + if hasattr(triton.compiler.compiler, "AttrsDescriptor"): + return "from triton.compiler.compiler import AttrsDescriptor" + else: + return "" + + +@lru_cache(None) +def gen_common_triton_imports(): + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor import triton_helpers, triton_heuristics + from torch._inductor.ir import ReductionHint, TileHint + from torch._inductor.triton_helpers import libdevice, math as tl_math + from torch._inductor.triton_heuristics import AutotuneHint + from torch._inductor.utils import instance_descriptor + """ + ) + return imports.getvalue() + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: Set[sympy.Symbol] + mask_str: str + expand_str: Optional[str] + _has_rindex: bool + + def has_mask(self): + return bool(self.mask_vars) + + def has_rindex(self): + return self._has_rindex + + def has_tmpmask(self): + return "tmp" in self.mask_str + + def has_rmask(self): + return "rmask" in self.mask_str + + +@dataclasses.dataclass +class BlockPtrOptions: + constant_offset: sympy.Expr + shape: List[sympy.Expr] + strides: List[sympy.Expr] + block_shape: List[str] + order: List[int] + offsets: List[str] + mask_vars: Set[sympy.Symbol] + reshape_suffix: List[str] + + @staticmethod + def create( + strides: List[sympy.Expr], + constant_offset: sympy.Expr, + range_trees: List[IterationRangesEntry], + mask_vars: Set[sympy.Symbol], + ) -> BlockPtrOptions: + """Helper to create a BlockPtrOptions instance""" + block_shape = [f"{t.prefix.upper()}BLOCK" for t in range_trees] + reshape_suffix = [*block_shape] + + broadcasting_dim = [s == 0 for s in strides] + for i, is_broadcasting in enumerate(broadcasting_dim): + if is_broadcasting: + # drop any stride==0 dimensions for performance + reshape_suffix[i] = "1" + + if V.kernel.no_x_dim: + assert range_trees[0].prefix == "x" + reshape_suffix.pop(0) + + if ( + not V.kernel.inside_reduction + and len(strides) == len(V.kernel.numels) - 1 + and V.kernel.numels[-1] != 1 + ): + # Need to expand rank by 1 to match rank when self.inside_reduction=True + reshape_suffix.append("1") + + def filter(it): + """Removes any broadcasting dims from a given sequence""" + assert len(it) == len(broadcasting_dim) + return [ + item + for item, is_broadcasting in zip(it, broadcasting_dim) + if not is_broadcasting + ] + + return BlockPtrOptions( + constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), + shape=[ + V.graph.sizevars.lookup_precomputed_size(t.numel) + for t in filter(range_trees) + ], + strides=[*map(V.graph.sizevars.lookup_precomputed_size, filter(strides))], + block_shape=filter(block_shape), + order=V.graph.sizevars.guarded_order(filter(strides)), + offsets=filter([f"{t.prefix}offset" for t in range_trees]), + mask_vars=mask_vars, + reshape_suffix=reshape_suffix, + ) + + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should roffset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets[offsets.index("roffset")] = "0" + args = [ + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + 
else name, + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" + + @cache_on_self + def boundary_check(self) -> List[int]: + """List of indices to pass to tl.load(boundary_check=...)""" + check = [] + for i in range(len(self.shape)): + if ( + self.block_shape[i] != "1" + and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type] + and not V.graph.sizevars.statically_known_multiple_of( + self.shape[i], + config.triton.max_block[self.block_shape[i][0]], # type: ignore[arg-type] + ) + and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK") + ): + check.append(i) + return check + + def advance_roffset(self): + """Codegen string to pass to tl.advance(name, ...)""" + advance = ["0"] * len(self.shape) + advance[self.offsets.index("roffset")] = "RBLOCK" + return V.kernel.index_to_str(advance) + + def has_rindex(self): + return "RBLOCK" in self.block_shape + + def has_rmask(self): + return self.has_rindex() + + def has_tmpmask(self): + return False # block_ptr can't do indirect indexing + + def has_mask(self): + return bool(self.boundary_check()) + + +def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): + """Workaround https://github.com/openai/triton/issues/2836""" + assert isinstance(old_shape, list) and isinstance(new_shape, list) + if old_shape == new_shape: + return value + if [s for s in new_shape if s != "1"] != old_shape: + return f"tl.reshape({value}, [{', '.join(new_shape)}])" + # rewrite to [:, None] syntax, which is less buggy + idx = 0 + expand = [] + for size in new_shape: + if idx < len(old_shape) and size == old_shape[idx]: + expand.append(":") + idx += 1 + else: + assert size == "1" + expand.append("None") + assert idx == len(old_shape) + return f"{value}[{', '.join(expand)}]" + + +class TritonPrinter(PythonPrinter): + def _print_floor(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _helper_sqrt(self, expr): + return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"tl.where({c}, {p}, {q})" + + def _print_Min(self, expr): + nargs = len(expr.args) + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"tl.minimum({a}, {b})" + + def _print_Max(self, expr): + nargs = len(expr.args) + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + + return f"tl.maximum({a}, {b})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"tl_math.abs({self._print(expr.args[0])})" + + def _print_cos(self, expr): + assert len(expr.args) == 1 + return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_cosh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_acos(self, expr): + assert len(expr.args) == 1 + return 
f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_sin(self, expr): + assert len(expr.args) == 1 + return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_sinh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_asin(self, expr): + assert len(expr.args) == 1 + return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_tan(self, expr): + assert len(expr.args) == 1 + return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_tanh(self, expr): + assert len(expr.args) == 1 + return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_atan(self, expr): + assert len(expr.args) == 1 + return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}" + + +texpr = TritonPrinter().doprint +pexpr = PythonPrinter().doprint + + +def triton_compute_type(dtype): + triton_type_name = str(dtype).split(".")[-1] + if triton_type_name == "bool": + triton_type_name = "int1" + elif triton_type_name in ("float16", "bfloat16"): + # float16 math is done in float32 inside the kernel + triton_type_name = "float32" + elif triton_type_name == "float8_e4m3fn": + triton_type_name = "float8e4nv" + elif triton_type_name == "float8_e5m2": + triton_type_name = "float8e5" + elif triton_type_name == "float8_e4m3fnuz": + triton_type_name = "float8e4b8" + elif triton_type_name == "float8_e5m2": + triton_type_name = "float8e5b16" + return f"tl.{triton_type_name}" + + +def triton_store_type(dtype): + triton_type_name = str(dtype).split(".")[-1] + if triton_type_name == "bool": + triton_type_name = "int8" + elif triton_type_name == "float8_e4m3fn": + triton_type_name = "float8e4nv" + elif triton_type_name == "float8_e5m2": + triton_type_name = "float8e5" + return f"tl.{triton_type_name}" + + +def triton_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed: + nbits = 64 if dtype == torch.int64 else 32 + return f"tl.int{nbits}" + return triton_compute_type(dtype) + + +def triton_constant(value): + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +class TritonCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any]): + super().__init__(name, bounds) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: Set[str] = set() + + def update_on_args(self, name, args, kwargs): + # When making a variable that is going to be used in indirect indexing + # if a where clause is used it 
should mean that the result is always a + # valid index, so you shouldn't include any of the dependent variables + # in the resulting load mask + if name == "where": + return + for arg in args: + if isinstance(arg, TritonCSEVariable): + self.mask_vars.update(arg.mask_vars) + elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr": + # most of the time index vars don't need masks associated with them + # however, when index vars are used to compute indices for indirect reads + # those reads should subsequently be masked, + self.mask_vars.update({f"{arg.name[0]}mask"}) + + def __repr__(self): + return f"TritonCSEVariable(name={self.name})" + + +class TritonOverrides(OpOverrides): + """Map element-wise ops to Triton""" + + @staticmethod + def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): + def _get_min_elements_per_thread( + src_dtype: torch.dtype, dst_dtype: torch.dtype + ) -> int: + if src_dtype == dst_dtype: + # No data type conversion is needed. No requirements on min_elem_per_thread. + return 0 + + # fp8 data type conversions has min_elem_per_thread requirements. + # Refer to Triton implementations here: + # https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. + fp8_dtypes = { + torch.float8_e4m3fn, + torch.float8_e5m2, + } + # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2. + assert not ( + src_dtype in fp8_dtypes + and dst_dtype in fp8_dtypes + and src_dtype != dst_dtype + ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!" + if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2: + return 4 + if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn: + return 2 + # No requirements on min_elem_per_thread. + return 0 + + if src_dtype is not None: + # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype). + # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions + # in the same kernel. + V.kernel.min_elem_per_thread = max( + _get_min_elements_per_thread(src_dtype, dtype), + V.kernel.min_elem_per_thread, + ) + + if dtype == torch.bool: + return f"({x} != 0)" + elif dtype == torch.uint8: + # to work around llvm uint conversion semantics + # that produces 0's for negative values + return f"{x}.to(tl.int8).to(tl.uint8)" + return f"{x}.to({triton_compute_type(dtype)})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + triton_dtype = triton_compute_type(dtype) + # We may promote float16 or bfloat16 to float32 and cause the + # bitwidth of dtype to be different from the input tensor (i.e. float32). + # In such as case, we will have to convert the input tensor to + # its src_type, perform bitcast, and then convert the bit-casted + # tensor back to float to ensure we use values with the right precision. 
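+        # Illustrative example (hypothetical dtypes): for src_dtype=torch.float16
+        # and dtype=torch.int16 this emits
+        # x.to(tl.float16).to(tl.int16, bitcast=True).to(tl.float32)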
+ if src_dtype in (torch.float16, torch.bfloat16): + triton_src_dtype = str(src_dtype).split(".")[-1] + cast_x = f"{x}.to(tl.{triton_src_dtype})" + cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)" + return f"{cast_x}.to(tl.float32)" + else: + return f"{x}.to({triton_dtype}, bitcast=True)" + + @staticmethod + def _shaped_constant(value, dtype, shape): + type_ = torch._prims_common.dtype_to_type(dtype) + triton_val = triton_constant(type_(value)) + triton_type = triton_compute_type(dtype) + + if triton_type == "tl.float32": + # Float constants are always f32 in triton + return triton_val + + # NOTE: We use a tensor here in order to get the expected type. + # Otherwise, e.g. float64 constants would be trunctated to float32. + return f"tl.full({shape}, {triton_val}, {triton_type})" + + @classmethod + def constant(cls, value, dtype): + return cls._shaped_constant(value, dtype, shape=[]) + + @staticmethod + def abs(x): + return f"tl_math.abs({x})" + + @staticmethod + def libdevice_abs(x): + return f"libdevice.abs({x})" + + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def libdevice_exp(x): + return f"libdevice.exp({x})" + + @staticmethod + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + def sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def libdevice_sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def relu(x): + bug = config.triton.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + # NB: this only triggers runtime error as long as input + # is not all zero + return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' + elif bug == "accuracy": + return f"{x} + 1" + elif bug is None: + return ops.maximum("0", x) + else: + raise AssertionError( + f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"tl.where({a}, {b}, {c})" + + @staticmethod + def cos(x): + return f"tl_math.cos({x})" + + @staticmethod + def libdevice_cos(x): + return f"libdevice.cos({x})" + + @staticmethod + def sin(x): + return f"tl_math.sin({x})" + + @staticmethod + def libdevice_sin(x): + return f"libdevice.sin({x})" + + @classmethod + def index_expr(cls, expr, dtype): + raise NotImplementedError("ops.index_expr not implemented outside a kernel") + + @staticmethod + def masked(mask, body, other): + raise NotImplementedError("ops.masked not implemented outside a kernel") + + @staticmethod + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + def copysign(x, y): + return f"libdevice.copysign({x}, 
{y})" + + @staticmethod + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + def erfinv(x): + return f"libdevice.erfinv({x})" + + @staticmethod + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + offset = f"({offset}).to(tl.uint32)" + return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + raise NotImplementedError("ops.load_seed not implemented outside a kernel") + + @staticmethod + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + def sigmoid(x): + return f"tl.sigmoid({x})" + + @staticmethod + def libdevice_sigmoid(x): + return f"1/(1 + libdevice.exp(-({x})))" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" + + @staticmethod + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + def log(x): + return f"tl_math.log({x})" + + @staticmethod + def libdevice_log(x): + return f"libdevice.log({x})" + + @staticmethod + def isinf(x): + return f"libdevice.isinf({x}).to(tl.int1)" + + @staticmethod + def isnan(x): + return f"libdevice.isnan({x}).to(tl.int1)" + + @staticmethod + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def floordiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Similar to div_floor_kernel_cuda in pytorch core. + # Notice that // in triton behaves as truncdiv instead of floordiv + quot = f"{a} // {b}" + rem = f"{a} % {b}" + return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" + + @staticmethod + def sign(x): + def to_int(s): + return f"{s}.to(tl.int8)" + + left = to_int(ops.lt("0", x)) + right = to_int(ops.lt(x, "0")) + sub = ops.sub(left, right) + return f"{sub}.to({x}.dtype)" + + @staticmethod + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. 
+ # Notice that // in triton behaves as truncdiv instead of floordiv + return f"{a} // {b}" + + @staticmethod + def ceil(x): + return f"libdevice.ceil({x})" + + +TritonOverrides._initialize_pointwise_overrides("triton") + + +# Use mypy to check protocol implemented correctly +def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]: + return h + + +class TritonKernelOverrides(TritonOverrides): + """Map element-wise ops to Triton within a TritonKernel + + Unlike TritonOverrides, these assume the code is going to be inserted into + the body of the main triton kernel and so it may use indexing and mask + variables which are assumed to already be defined in the current scope. + """ + + @classmethod + def constant(cls, value, dtype): + # NOTE: Cannot use shape=[] as it's not supported by triton-rocm + # We could use shape=[1] instead but starting with the correct + # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR. + ndim = V.kernel.triton_tensor_ndim() + shape = [1] * ndim + return cls._shaped_constant(value, dtype, shape=shape) + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + # This is called from CSEProxy.__getattr__, so we'll set the bounds there + var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str) + + if dtype not in {torch.int32, torch.int64}: + var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype)) + var.mask_vars = indexing.mask_vars + return var + + @staticmethod + def masked(mask, body, other): + with V.kernel.mask_loads(mask) as new_mask: + result = body() + + # Take dtype from result to prevent accidental promotion + other = V.kernel.cse.generate( + V.kernel.compute, + f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)", + ) + return ops.where(new_mask, result, other) + + @staticmethod + def load_seed(name, offset): + var = V.kernel.args.input(name) + return ( + f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})" + ) + + @staticmethod + def frexp(x): + cache_key = f"frexp({x})" + if cache_key in V.kernel.cse.cache: + return V.kernel.cse.cache[cache_key] + + mantissa = V.kernel.cse.newvar() + exponent = V.kernel.cse.newvar() + V.kernel.compute.writeline( + f"{mantissa}, {exponent} = triton_helpers.frexp({x})" + ) + V.kernel.cse.cache[cache_key] = (mantissa, exponent) + return (mantissa, exponent) + + +# Use mypy to check protocol implemented correctly +def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]: + return h + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. 
+ """ + + def __init__( + self, + name: str, + var_list: List[sympy.Symbol], + var_ranges: Dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: TritonKernel, + divisor=sympy.Integer(1), + length=sympy.Integer(1), + root: IterationRangesRoot, + ): + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + def symbol(self): + return sympy_index_symbol(self.name) + + +class IterationRangesRoot(IterationRanges): + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: TritonKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + ): + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: Dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (prefix == "r" and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + + def __repr__(self): + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self): + for node in self.nodes.values(): + node.cache_clear() + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries(self, lengths: List[sympy.Expr]): + divisor = sympy.Integer(1) + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return list(reversed(itervars)) + + def construct(self, lengths: List[sympy.Expr]): + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes(self, index: sympy.Expr): + """Figure out vars from this tree used in index""" + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) + divisor = sympy.Integer(1) + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, 
FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return list(reversed(index_vars)), list(reversed(sizes)) + + def ranges_code(self): + assert self.tensor_dim is not None + size = self.kernel.indexing_size_str(self.tensor_dim) + index_dtype = self.kernel.index_dtype + convert = f".to({index_dtype})" if index_dtype != "tl.int32" else "" + return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}" + + def scalar_code(self, value): + index_dtype = self.kernel.index_dtype + ndim = self.kernel.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def get_pid(self): + assert self.grid_dim is not None + key = f"tl.program_id({self.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if ( + self.grid_dim == 1 + and config.triton.max_tiles <= 2 + and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid()) + ): + key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)" + pid = self.pid_cache.get(key, key) + if self.kernel.index_dtype != "tl.int32": + return f"{pid}.to({self.kernel.index_dtype})" + return pid + + def codegen_header(self, code): + x = self.prefix + if self.is_loop: + code.writeline(f"{self.name} = {x}offset + {x}base") + elif self.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{self.name} = {self.ranges_code()}") + code.writeline(f"{x}offset = 0") + else: + if self.tensor_dim is not None: + line = f"{x}offset + {self.ranges_code()}" + else: + line = self.scalar_code(f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", + f"{self.name} = {line}", + ] + ) + code.writeline(f"{x}mask = {self.name} < {x}numel") + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ): + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self): + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name): + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self): + self.codegen.cache_clear() + + def writeline(self, line): + if self.root.is_loop: + V.kernel.indexing_code.writeline(line) + else: + # lift non-reduction stores outside loop + V.kernel.body.writeline(line) + + def _codegen(self): + self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr))) + return self.name + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = 
arg.free_symbols + if len(symbols) > 0 and all(s.name.startswith("s") for s in symbols): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return self.name == other.name + + +class HelperFunctions: + """An ordered set of helper functions.""" + + _templates_seen: Dict[str, str] # Template code to function name + finalized_helpers: List[str] + + def __init__(self): + self._templates_seen = {} + self.finalized_helpers = [] + + def add(self, template_code: str) -> str: + """This accepts a function definition with the function name + left as a format specifier e.g. + + @triton.jit + def {name}(arg0, arg1): + return arg0 + arg1 + + We add the templated code to the function set and return the name + assigned to that function. + + """ + existing_name = self._templates_seen.get(template_code) + if existing_name is not None: + # Don't duplicate existing helpers + return existing_name + + name = f"_triton_helper_fn{len(self.finalized_helpers)}" + self._templates_seen[template_code] = name + self.finalized_helpers.append(template_code.format(name=name)) + return name + + def __iter__(self): + return iter(self.finalized_helpers) + + def __getitem__(self, idx): + return self.finalized_helpers[idx] + + +class TritonKernel(Kernel): + overrides = TritonKernelOverrides # type: ignore[assignment] + sexpr = pexpr + + helper_functions: HelperFunctions + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + min_elem_per_thread=0, + disable_persistent_reduction=False, + ): + if pid_cache is None: + pid_cache = {} + super().__init__() + self.numels = [V.graph.sizevars.simplify(s) for s in groups] + self.mutations: Set[str] = mutations if mutations is not None else set() + self.range_trees: List[IterationRangesRoot] = [] + self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = self.numels[-1] != 1 + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] + self.outside_loop_vars: Set[Any] = set() + self.reduction_hint = reduction_hint + self.index_dtype: str = index_dtype + self.min_elem_per_thread = min_elem_per_thread + self.last_usage: Set[str] = set() + self.block_ptr_id = itertools.count() + # buffer accesses in the kernel + self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) + + self.persistent_reduction: bool = ( + not disable_persistent_reduction + ) and self.should_use_persistent_reduction() + self.no_x_dim = ( + self.reduction_hint == ReductionHint.INNER + and self.persistent_reduction + and len(self.numels) == 2 + and self.numels[-1] >= 256 + ) + self.initialize_range_tree(pid_cache) + + self.helper_functions = HelperFunctions() + + # A set of autotuning hints to pass as part of triton_meta + self.autotune_hints: Set[AutotuneHint] = set() + + # define this in a closure to make cache local to object + @functools.lru_cache(None) + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + return index + + self.simplify_indexing = simplify_indexing + self.code_hash = None + self.triton_meta: Optional[Dict[str, object]] = None + + def need_numel_args(self): + r""" + Indicate whether we 
need provide numel as arguments for the generated + kernel calls in the benchmark. + + Should be true for pointwise/reduction kernels but false for triton + matmul kernels. + """ + return True + + def should_use_persistent_reduction(self) -> bool: + """ + Heuristic to set self.persistent_reduction and add guards + if needed. + """ + if not (self.inside_reduction and config.triton.persistent_reductions): + return False + threshold = { + ReductionHint.INNER: 1024, + }.get(self.reduction_hint, 64) + + # If multi_kernel is enabled, we do more aggressive persistent reduction. + # This may result in some persisent reductions slower than the + # corresponding non-persistent reductions. MultiKernel will do benchmarking + # to pick the faster one. + if config.triton.multi_kernel: + threshold *= 16 + last_numel = self.numels[-1] + if not isinstance(last_numel, (int, sympy.Integer)): + # Not static + return False + hint = V.graph.sizevars.size_hint(last_numel) + if hint > threshold: + return False + # will need to recompile if we cross a larger power of 2 boundary + V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) # type: ignore[arg-type] + return True + + def set_last_usage(self, nodes): + if not self.inside_reduction or self.persistent_reduction: + return + self.last_usage = set( + itertools.chain.from_iterable( + n.last_usage for n in nodes if n is not EnableReduction + ) + ) + + def initialize_range_tree(self, pid_cache): + no_r_dim = not self.inside_reduction or self.numels[-1] == 1 + + prefixes = "zyxr" + active_prefixes = prefixes[-len(self.numels) :] + + grid_dims = "xyz" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyz" + else: + tensor_dims = "xyzr" + + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix == "r" + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + self.numels[i], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + ) + ) + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + tree.codegen_header(self.body) + if self.inside_reduction and self.range_trees[-1].is_loop: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") + + def disable_reduction(self): + should_flush = self.range_trees[-1].is_loop + + @contextlib.contextmanager + def ctx(): + if self.numels[-1] == 1: + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths): + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: 
List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit() + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(size, idx1, idx2): + def getter(flat_vars): + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while ( + current_group < len(remaining) + and sv.size_hint(remaining[current_group]) == 1 + ): + # scroll to next group with remaining elements + current_group += 1 + + if sv.size_hint(size) > sv.size_hint(remaining[current_group]): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit() + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all( + V.graph.sizevars.size_hint(s) == 1 for s in remaining + ), f"failed to set ranges {remaining} {lengths}" + + return new_ranges, return_getters_groups + + @classmethod + def is_compatible( + cls, groups: Iterable[sympy.Expr], lengths: List[List[sympy.Expr]] + ): + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.Integer(1) + + if len(lengths) == len(self.range_trees) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return self.set_ranges(*lengths) + + new_ranges, return_getters_groups = self._split_iteration_ranges( + groups, lengths + ) + itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr): + # tmpX means indirect indexing + return free_symbol_startswith(index, "tmp") + + def is_broadcasted(self, index: sympy.Expr): + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. 
strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels) + ) + + def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in triton code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return texpr(self.rename_indexing(self.codegen_indexing(index))) + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ) -> Union[IndexingOptions, BlockPtrOptions]: + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) 
+ # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + s.name.startswith("s") or s.name.startswith("ps") for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + index = self.simplify_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or var.name.startswith("r") + if override_mask: + pass + elif var.name.startswith("tmp"): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif var.name.startswith(("s", "ps", "i", "u")): + pass + else: + # var is one of xN, yN or rN + assert var.name[0] in "xyr", var.name + mask_vars.add(f"{var.name[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars = set() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/openai/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees(reorder=True) + symbols = [t.symbol() for t in range_trees] + strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] + offset = sympy.Wild("_offset", exclude=symbols) + m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) + # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with + # a tl.reshape the correct block. We will miss these cases today. 
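+                # For example, an index such as 768*xindex + rindex (relative to the
+                # active range trees) matches with strides [768, 1] and offset 0, so
+                # the access can be expressed through a block pointer instead of
+                # explicit pointer arithmetic plus a mask (illustrative case only).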
+ if m: + self.filter_masks(mask_vars) + return BlockPtrOptions.create( + [m[s] for s in strides], + m[offset], + range_trees, + mask_vars, # type: ignore[arg-type] + ) + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + return IndexingOptions(index_str, set(), "None", expand_str, has_rindex) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = {override_mask} + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex) # type: ignore[arg-type] + + def active_range_trees(self, reorder=False): + trees = [ + t for t in self.range_trees if t.prefix != "r" or self.inside_reduction + ] + if reorder and len(trees) > 1: + count = sum(t.prefix in "xyz" for t in trees) + assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ + t.prefix for t in trees[:count] + ] + trees[:count] = reversed(trees[:count]) + return trees + + def filter_masks(self, mask_vars): + for tree in self.range_trees: + # Masks are superfluous if we only have one element + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + continue + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.prefix.upper() not in config.triton.max_block: + continue + max_block = config.triton.max_block[tree.prefix.upper()] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # It's faster to avoid masking at all. But it is sound to always + # mask. 
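+            # For example, if the configured max XBLOCK were 1024 and xnumel were
+            # 4096, every candidate power-of-2 XBLOCK divides xnumel evenly, so
+            # xmask can be dropped without risking out-of-bounds accesses.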
+ if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] + mask_vars.discard(f"{tree.prefix}mask") + + def var_ranges(self): + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + @contextlib.contextmanager + def mask_loads(self, mask): + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + if prior: + mask = self.cse.generate(self.compute, f"{mask} & {prior}") + + self._load_mask = mask + try: + # TODO(jansel): do we need a reshape here? + yield mask + finally: + self._load_mask = prior + + def generate_assert(self, check): + return torch.version.hip is None and super().generate_assert(check) + + def load_mask(self, var): + mask = "" + mask_vars = set(var.mask_vars) + if self._load_mask: + mask_vars.add(self._load_mask) + + if mask_vars: + mask = ( + f"{next(iter(mask_vars))}" + if len(mask_vars) == 1 + else f"({' & '.join(str(v) for v in mask_vars)})" + ) + return mask + + @property + def assert_function(self) -> str: + return "tl.device_assert" + + def get_strides_of_load(self, index: sympy.Expr): + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + def codegen_block_ptr( + self, name: str, var: str, indexing: BlockPtrOptions, other="" + ) -> Tuple[str, Optional[DeferredLine], str]: + advance_block_ptr = None + check = indexing.boundary_check() + if not check: + # workaround https://github.com/openai/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" + if ( + self.inside_reduction + and self.range_trees[-1].is_loop + and indexing.has_rindex() + ): + block_ptr = f"block_ptr{next(self.block_ptr_id)}" + self.body.writeline( + DeferredLine( + name, f"{block_ptr} = {indexing.format(var, roffset=False)}" + ) + ) + advance_block_ptr = DeferredLine( + name, + f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})", + ) + else: + block_ptr = indexing.format(var) + return block_ptr, advance_block_ptr, other + + def codegen_block_ptr_store_line(self, name, 
indexing, block_ptr, value, other=""): + # broadcasting is not implicit for block_ptrs + value = ( + f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})" + ) + # drop any extra size=1 dimensions + value = triton_reshape(value, indexing.reshape_suffix, indexing.block_shape) + # workaround https://github.com/openai/triton/issues/2814 + value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + return f"tl.store({block_ptr}, {value}{other})" + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + + # Keep the variable in cache if were going to reuse it. Equiv., if any of the following hold + # 1) We are doing broadcasting + # 2) It is a non-coalesced load. The intuition is that if it's + # non-coalesced, we will likely load each element multiple times in + # practice. + # 3) It will be used later and it won't be CSE'd. Equiv., if all the following hold + # 3.1) We are in a reduction loop + # 3.2) Its not its last use + # 3.3) This load will not be lifted to the body + # + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + if self.is_broadcasted(original_index): + ep = ", eviction_policy='evict_last'" + elif not is_coalesced: + ep = ", eviction_policy='evict_last'" + elif self.inside_reduction and self.range_trees[-1].is_loop: + if name in self.args.inplace_buffers: + names = set(self.args.inplace_buffers[name].other_names) + else: + names = {name} + last_use = len(names & self.last_usage) > 0 + evict_last = not last_use and (has_rindex or indirect_indexing) + if evict_last: + ep = ", eviction_policy='evict_last'" + else: + ep = ", eviction_policy='evict_first'" + else: + ep = "" + # "other" below is a workaround for https://github.com/openai/triton/issues/737 + # for bool, even though it's likely subject to the same bug, setting `other` leads + # to LLVM errors so we are skipping it for now + if ( + (has_tmpmask or has_rindex) + and V.graph.get_dtype(name) != torch.bool + and indexing.has_mask() + ): + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + line = triton_reshape( + line, indexing.block_shape, indexing.reshape_suffix + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + # Workaround for https://github.com/openai/triton/issues/2151 + # tl.load returns int8 when loading from pointer to int1 + # NOTE: Currently causes hangs on bool UTs for ROCm + line += ".to(tl.int1)" + + if has_tmpmask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indirect_indexing + and not has_rindex + ): + # can lift a common load outside 
of reduction loop + # One exception is when this is an indirect_load. + load_buffer = self.body + else: + load_buffer = self.loads + + result_var = self.cse.generate(load_buffer, line) + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + + # Guard against write-after-read corruption in triton. + # See # https://github.com/openai/triton/issues/1615 + # This triton bug means that a load which is broadcasted over multiple + # warps may see the result of a store that happens later in the triton + # program. The workaround is to add a barrier before storing, which + # enforces that all warps have already read the data. + is_inplace = name in self.args.inplace_buffers + is_broadcasted = self.is_broadcasted(original_index) + if is_inplace and is_broadcasted: + self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + + # Triton performance for bucketize_binary_search is much better when the number + # of threads equals the number of elements. + # If we're trying to use a bucketize kernel, we should make sure that an + # autotuning config with num_elements_per_warp=32 exists. 
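+        # The helper emitted below is assumed to do a per-element binary search of
+        # `values` against the sorted `offsets` boundaries (roughly torch.bucketize
+        # semantics), with `right` selecting which side is taken on an exact
+        # boundary match.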
+ self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32) + + offsets_ptr = self.args.input(offsets_name) + block_size = self.dense_size_str() + offsets_size_str = self.index_to_str(offsets_size) + + if indexing_dtype == torch.int32: + triton_dtype = "tl.int32" + elif indexing_dtype == torch.int64: + triton_dtype = "tl.int64" + else: + raise NotImplementedError( + "Bucketize only supports indexing with int32 and int64" + ) + + result = self.cse.generate( + self.compute, + f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})", # noqa: B950 line too long + ) + + return result + + def reduction_resize(self, value): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + + sizes = [":"] * ndims + sizes[-1] = "None" + return f"{value}[{', '.join(sizes)}]" + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + assert self.inside_reduction + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. 
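+        # e.g. broadcasting the scalar 1 to [XBLOCK, RBLOCK] first means the later
+        # tl.sum over the reduction dim yields RBLOCK (one contribution per
+        # reduction element) instead of 1.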
+ dense_size_str = self.dense_size_str() + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.broadcast_to({v}, {dense_size_str})" + ), + value, + ) + + dim: int + root_op: str + + def final_reduction(value): + use_helper = reduction_type in {"any", "max", "min", "prod"} + module = "triton_helpers" if use_helper else "tl" + if reduction_type in {"max", "min"}: + return self.reduction_resize( + f"{module}.{reduction_type}2({value}, {dim})" + ) + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + + def final_argreduce(buffer, result_var, value, index): + buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp')} + """ + ) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.triton_tensor_ndim() - 1 + acc_type = triton_acc_type(src_dtype) + result_var: Any = self.cse.newvar() + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton_constant, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default)) + + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = str( + self.cse.generate( + self.compute, + f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + ) + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + elif reduction_type == "welford_reduce": + # For persistent reductions, don't bother with + # welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. 
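+            # Instead, take two plain "sum" reductions: mean = sum(x) / rnumel and
+            # m2 = sum((x - mean) ** 2), producing the same (mean, m2, weight)
+            # triple with weight == rnumel.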
+ sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.numels[-1], dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + result_var = (mean, m2, rnumel) + elif reduction_type == "welford_combine": + mean, m2, weight = masked_value + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + mean, m2, weight = (self.cse.newvar() for _ in range(3)) + self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}") + + result_var = tuple( + self.cse.generate(self.compute, self.reduction_resize(var_name)) + for var_name in (mean, m2, weight) + ) + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value) + ) + else: + accumulator = f"_{result_var}" + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton_constant, default) + if not isinstance(default, tuple): + self.body.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.body.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.suffix, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" + self.body.writeline( + f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """ + ) + else: + assert reduction_type == "welford_reduce" + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0 + ) + """ + ) + + self.compute.splice( + f"""\ + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)} + {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)} + """ + ) + + result_mean = result_var + result_m2 = self.cse.newvar() + result_weight = self.cse.newvar() + self.suffix.splice( + f"""\ + {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford( + {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim} + ) + {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')} + 
{result_m2} = {self.reduction_resize(f'{result_m2}_tmp')} + {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')} + """ + ) + result_var = result_mean, result_m2, result_weight + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.suffix.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + assert self.inside_reduction + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + + if isinstance(indexing, BlockPtrOptions): + self.suffix.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + assert isinstance(indexing, IndexingOptions) + self.suffix.writeline( + DeferredLine( + name, + f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})", + ) + ) + + def _lift_helper(self, fn, num_args) -> str: + # Lift IR function into a triton function in the global namespace + helper = IndentedBuffer() + helper.writeline("@triton.jit") + args = [f"arg{n}" for n in range(num_args)] + signature = ", ".join(args) + helper.writeline(f"def {{name}}({signature}):") + + cse = CSE(prefix="", suffix="") + overrides = TritonOverrides(V.MockHandler()) + + class CSEProxy: + def __getattr__(self, name: str) -> Callable[..., CSEVariable]: + def inner(*args, **kwargs): + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + ) + + return inner + + with helper.indent(), V.set_ops_handler(CSEProxy()): + outputs = fn(*args) + helper.writeline(f"return {outputs}") + + return self.helper_functions.add(helper.getvalue()) + + def scan( + self, + dtype: torch.dtype, + combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable], + value: CSEVariable, + init: int, + ) -> CSEVariable: + assert self.inside_reduction + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + + value = self.cse.generate( + self.compute, f"tl.broadcast_to({value}, {self.dense_size_str()})" + ) + + default = triton_constant(init) + dim = self.triton_tensor_ndim() - 1 + acc_type = triton_acc_type(dtype) + cond = " & ".join(masks) + + combine_helper_fn = self._lift_helper(combine_fn, 2) + + def where_cond(value): + if not cond: + return value + default_tensor = self.cse.generate( + self.body, + f"tl.full({[1] * 
self.triton_tensor_ndim()}, {default}, {triton_compute_type(dtype)})", + ) + return self.cse.generate( + self.compute, f"tl.where({cond}, {value}, {default_tensor})" + ) + + if self.persistent_reduction: + masked_value = where_cond(value) + result_var = self.cse.generate( + self.compute, + f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})", + ) + else: + accumulator = self.cse.newvar() + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + reduced_size = f"[{', '.join(reduced_size)}]" + + self.body.writeline( + f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" + ) + + masked_value = where_cond(value) + partial_reduce = self.cse.generate( + self.compute, + self.reduction_resize( + f"tl.reduce({value}, {dim}, {combine_helper_fn})" + ), + ) + acc_next = combine_fn(accumulator, partial_reduce) + partial_scan = self.cse.generate( + self.compute, + f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})", + ) + result_var = self.cse.generate( + self.compute, combine_fn(accumulator, partial_scan) + ) + self.compute.writeline(f"{accumulator} = {acc_next}") + + result_var.mask_vars = masks # type: ignore[attr-defined] + return result_var + + def codegen_body(self): + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if not ( + self.indexing_code + or self.loads + or self.stores + or self.compute + or self.suffix + ): + return + + if self.inside_reduction and self.range_trees[-1].is_loop: + self.body.writeline("for roffset in range(0, rnumel, RBLOCK):") + with self.body.indent(): + # last range tree is always reduction + self.range_trees[-1].codegen_header(self.body) + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + # invalidate any caches that came from inside the reduction loop + self.cse.invalidate(self.outside_loop_vars) + self.range_trees[-1].cache_clear() + else: + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.suffix) + self.indexing_code.clear() + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.suffix.clear() + + def codegen_kernel_benchmark(self, num_gb, grid=None): + result = IndentedBuffer() + argdefs, call_args, signature = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + 
# Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. + if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + if grid is None: + grid = [] + extra_args = [] + extra_args_str = None + for tree in self.active_range_trees(): + expr = pexpr(V.graph.sizevars.size_hint(tree.numel)) + extra_args.append(expr) + if tree.prefix != "r": + grid.append(expr) + if self.need_numel_args(): + extra_args_str = ", ".join(map(str, extra_args)) + ", " + else: + extra_args_str = "" + grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})" + else: + grid_arg = f"grid={grid}" + index = V.graph.scheduler.current_device.index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline("from triton.testing import do_bench") + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + from torch._inductor.triton_heuristics import grid, split_scan_grid + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. 
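+
+        For example (illustrative numbers): a pointwise kernel that reads one
+        1M-element fp32 buffer and writes another is counted as roughly
+        2 * 1e6 * 4 bytes, and an in-place argument is counted twice because
+        it is both read and written.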
+ """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _ = self.args.python_argdefs() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in self.buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint(arg_numel) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. + indices: Set[Any] = set() + no_index_dep_count = 0 + for dep in self.buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def _get_heuristic(self): + if self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction" + elif self.inside_reduction: + return "reduction" + return "pointwise" + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + size_hints = [] + for numel in self.numels: + numel_hint = V.graph.sizevars.symbolic_hint(numel) + if not isinstance(numel_hint, (int, sympy.Integer)): + # This default heuristic hint was picked carefully: it is + # large, to ensure that we don't shrink the block size (since + # if you don't have many elements, it'd be wasteful to pick a + # large block size). Since we don't know how many elements we + # might have, we should be OK with some inefficiency to make + # sure we handle the large case well. 8192 is the largest + # block size we support, so we pick that. + # + # If we have a better hint for unbacked SymInts (e.g., because + # a user told us, or we are tracking upper bounds) we could + # use that here. 
+ size_hint = 8192 + else: + size_hint = next_power_of_2(int(numel_hint)) + size_hints.append(size_hint) + + if not self.inside_reduction: + size_hints.pop() + + heuristics = self._get_heuristic() + + if name is None: + code.splice(gen_common_triton_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature = self.args.python_argdefs() + # maps actual expression to SizeArg if it is in sizevars replacements + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + # mypy is unhappy about the sympy.Expr + # type for the key of the dict below + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype + ) + triton_meta = { + "signature": triton_meta_signature, + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + } + + inductor_meta = { + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "no_x_dim": self.no_x_dim, + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + for tree in self.active_range_trees(): + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + triton_meta_signature[len(argdefs)] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(f"{tree.prefix}numel") + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + triton_meta["configs"] = [config_of(signature)] + + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + + self.triton_meta = triton_meta + + for tree in self.range_trees: + if tree.prefix == "r" and self.persistent_reduction: + # RBLOCK for persistent_reduction is defined in codegen_static_numels + continue + if tree.tensor_dim is None: + continue + argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") + + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + if self.inside_reduction: + reduction_hint = self.reduction_hint + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + def codegen_static_numels(self, code): + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. 
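+
+        For persistent reductions, RBLOCK itself is also pinned here (e.g.
+        rnumel = 768 yields RBLOCK: tl.constexpr = 1024 via next_power_of_2),
+        which is why the r tree is skipped when emitting the constexpr argdefs
+        in codegen_kernel.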
+ """ + for tree in self.range_trees: + if tree.prefix != "r" or self.inside_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + + if tree.prefix == "r" and self.persistent_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + val = int(simplified_tree_numel) + else: + continue + val = next_power_of_2(val) + code.writeline(f"RBLOCK: tl.constexpr = {val}") + + if tree.prefix == "x" and self.no_x_dim: + code.writeline("XBLOCK: tl.constexpr = 1") + + def triton_tensor_ndim(self): + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i): + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> List[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if tree.prefix != "r" or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self): + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def _get_grid_fn(self): + return "grid" + + def add_numel_to_call_args_and_grid(self, name, call_args, grid): + # TODO(jansel): if there are constants, we shouldn't bother passing them as args + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree) + + if tree.prefix != "r" or self.inside_reduction: + call_args.append(expr) + if tree.grid_dim is not None: + grid.append(expr) + + def get_call_args(self): + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + + return call_args + + def call_kernel(self, name: str, node: Optional[IRNode] = None): + wrapper = V.graph.wrapper_code + call_args = self.get_call_args() + grid: List[Any] = [] + self.add_numel_to_call_args_and_grid(name, call_args, grid) + current_device = V.graph.scheduler.current_device + + if self.args.workspace_arg is not None: + ws = self.args.workspace_arg + wrapper.generate_workspace_allocation( + ws.nbytes, current_device, ws.zero_fill + ) + + grid = wrapper.generate_default_grid(name, grid) + wrapper.generate_kernel_call( + name, + call_args, + grid, + current_device.index, + cuda=True, + triton=True, + grid_fn=self._get_grid_fn(), + triton_meta=self.triton_meta, + ) + + if self.args.workspace_arg is not None: + wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.python_argdefs() + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg_type, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. 
+ """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, signature = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.get_buffer(arg_name) + if buf and len(buf.layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in buf.layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(buf.layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order(V.graph.get_buffer(name).layout.stride) + if V.graph.get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).layout.size + if V.graph.get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + msg = yellow_text( + f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def create_cse_var(self, *args, **kwargs): + return TritonCSEVariable(*args, **kwargs) + + +class TritonScheduling(BaseScheduling): + def __init__(self, scheduler): + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. + """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + + if node1.is_template(): + # Only allow fusion for TritonTemplates for now. + # Fusion for CUDATemplates are not supported. 
+ is_triton_template = isinstance(node1.node, TritonTemplateBuffer) + if not is_triton_template: + why("node1 is not TritonTemplateBuffer") + return is_triton_template + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = self.select_tiling( + node1.get_nodes(), numel1 + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: List[Any] = [] + current_loop_writes: Set[str] = set() + + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + current_loop_reduced_writes = set() + current_loop_has_writes = False + done = set() + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def schedule_node_in_loop(n): + nonlocal current_loop_has_writes + done.add(n) + node_schedule.append(n) + current_loop_has_writes = True + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + current_loop_reduced_writes.add(n.get_name()) + + @contextlib.contextmanager + def end_current_reduction_loop(): + nonlocal current_loop_has_writes + if current_loop_has_writes: + # flush out any other runnable nodes to reduce number of loops + for other_node in nodes[index + 1 :]: + if ( + node not in done + and fits_in_main_body(other_node) + and not (current_loop_reduced_writes & other_node.ancestors) + ): + schedule_node_in_loop(node) + + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + yield + node_schedule.append(EnableReduction) + current_loop_reduced_writes.clear() + current_loop_has_writes = False + + for index, node in 
enumerate(nodes): + if node in done: + continue + done.add(node) + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not current_loop_reduced_writes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(current_loop_reduced_writes) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_nodes(self, nodes: List[scheduler.SchedulerNode]): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + buf_accesses = collections.defaultdict(list) + for node in nodes: + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) + + @staticmethod + def reduction_hint(node): + assert node.is_reduction() + if all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] + ) -> bool: + int_max = torch.iinfo(torch.int32).max + size_hint = V.graph.sizevars.size_hint + has_hint = V.graph.sizevars.shape_env.has_hint + + def within_32bit(e): + # Allow for unhinted e as long as we can still statically prove + # (e.g., via ValueRanges) that it is still in bounds + if V.graph.sizevars.is_expr_static_and_true(e <= int_max): + return True + # Otherwise, the hint MUST exist and be in range + return has_hint(e) and size_hint(e) <= int_max + + if not within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if not isinstance(buf.get_layout(), ir.MultiOutputLayout) + ] + + if not all(within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + @staticmethod + def select_index_dtype(node_schedule, numel, reduction_numel): + # Gather all used buffer names + buffer_names = set() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + + buffer_names.update(node.get_names()) + buffer_names.update(node.used_buffer_names()) + + # Get buffers objects + def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: + if name in V.graph.name_to_buffer: + return V.graph.name_to_buffer[name] + elif name in V.graph.graph_inputs: + return V.graph.graph_inputs[name] + elif name in V.graph.constants: + data = 
V.graph.constants[name] + return ir.ConstantBuffer( + name, + ir.FixedLayout( + data.device, data.dtype, *V.graph.static_sizes_strides(data) + ), + ) + raise RuntimeError(f"Failed to find buffer matching name {name}") + + buffers = [_get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. + total_numel = numel * reduction_numel + + if TritonScheduling.can_use_32bit_indexing(total_numel, buffers): + return "tl.int32" + return "tl.int64" + + def get_kernel_args(self, node_schedule, numel, reduction_numel): + reductions = list( + filter( + lambda n: n not in (EnableReduction, DisableReduction) + and n.is_reduction(), + node_schedule, + ) + ) + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + + mutations = set() + for node in node_schedule: + if hasattr(node, "get_mutations"): + mutations.update(node.get_mutations()) + + index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) + + return reduction_hint_val, mutations, index_dtype + + def codegen_comment(self, node_schedule): + wrapper = V.graph.wrapper_code + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + if origins: + wrapper.writeline(origins) + + if config.debug_fusion: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ) + + if not any( + isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule + ): + # We probably should look what are the nodes inside a foreach + # schedule node + node_names = [ + n.get_name() + for n in node_schedule + if isinstance(n, BaseSchedulerNode) + ] + wrapper.writeline( + f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" + ) + + def codegen_node_schedule( + self, node_schedule, buf_accesses, numel, reduction_numel + ): + from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel + + tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, reduction_numel + ) + + is_split_scan = any( + isinstance(node, BaseSchedulerNode) and node.is_split_scan() + for node in node_schedule + ) + kernel_type = TritonSplitScanKernel if is_split_scan else TritonKernel + kernel_args = tiled_groups + kernel_kwargs = { + "reduction_hint": reduction_hint_val, + "mutations": mutations, + "index_dtype": index_dtype, + } + kernel = kernel_type( + *kernel_args, + **kernel_kwargs, + ) + kernel.buf_accesses = buf_accesses + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + kernel_name = self.define_kernel(src_code, node_schedule) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + + if kernel.persistent_reduction and config.triton.multi_kernel: + kernel2 = TritonKernel( + *kernel_args, + **kernel_kwargs, + disable_persistent_reduction=True, + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel2) + with V.set_kernel_handler(kernel2): + src_code2 = kernel2.codegen_kernel() + kernel_name2 = self.define_kernel(src_code2, node_schedule) + 
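+            # Keep both variants (persistent and looped) so the multi-kernel
+            # dispatcher can pick the faster one at run time.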
kernel2.kernel_name = kernel_name2 + kernel2.code_hash = code_hash(src_code2) + + final_kernel = MultiKernel([kernel, kernel2]) + else: + final_kernel = kernel # type: ignore[assignment] + + with V.set_kernel_handler(final_kernel): + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernel.args.live_output_buffers() + for node in node_schedule: + if not isinstance(node, scheduler.BaseSchedulerNode): + continue + name = node.get_name() + if name not in live_outs: + continue + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + stack = contextlib.ExitStack() + kernel.set_last_usage(current_reduction_nodes(node_schedule)) + + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + node.decide_inplace_update() + for i, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def define_kernel(self, src_code, node_schedule): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
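+        # Rewrite any "#pragma CMT" markers as plain "#" comments so the emitted
+        # source stays valid Python.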
+ src_code = src_code.replace("#pragma CMT", "#") + + basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', device_str='{V.graph.scheduler.current_device.type}')" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name + + def codegen_template( + self, template_node, epilogue_nodes, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + with kernel: + if not only_gen_src_code: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + partial_code = render() + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + + # finalize must be called after adding epilogue above + with V.set_kernel_handler(kernel): + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + node_schedule = [template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) + assert kernel.meta is not None, "meta is None" + grid = kernel.grid_fn(*grid_args, kernel.meta) + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel_name = self.define_kernel(src_code, node_schedule) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.scheduler.free_buffers() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def codegen_foreach(self, foreach_node): + from .triton_foreach import ForeachKernel + + for partitions_with_metadata in ForeachKernel.horizontal_partition( + foreach_node.get_subkernel_nodes(), self + ): + kernel = ForeachKernel() + for nodes, tiled_groups, numel, rnumel in partitions_with_metadata: + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + ( + reduction_hint_val, + mutations, + index_dtype, + ) = self.get_kernel_args(node_schedule, numel, rnumel) + + subkernel = kernel.create_sub_kernel( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel( + node_schedule, + subkernel, + ) + + with V.set_kernel_handler(subkernel): + for node in node_schedule: + if node not in 
(EnableReduction, DisableReduction): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, [foreach_node]) + self.codegen_comment([foreach_node]) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.scheduler.free_buffers() + + @staticmethod + @functools.lru_cache(32) + def candidate_tilings(node): + ranges, reduction_ranges = node.get_ranges() + if len(ranges) <= 1: + return () + + rw = node.pointwise_read_writes() + assert len(rw.range_vars) == len(ranges) + + # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) + ] + write_names = {dep.name for dep in rw.writes} + + tilings: List[CandidateTiling] = [] + + for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + tiled_groups = ( + V.graph.sizevars.simplify(sympy_product(ranges[:split])), + V.graph.sizevars.simplify(sympy_product(ranges[split:])), + ) + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append(CandidateTiling(tiled_groups, score, dep.name)) + return tilings + + @classmethod + def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + if reduction_numel != 1 or config.triton.max_tiles <= 1: + # TODO(jansel): should we tile reductions? 
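+            # Reductions (and max_tiles <= 1) keep the flat (numel, reduction_numel)
+            # blocking; the hint below only fires when a candidate pointwise tiling
+            # exists, i.e. the reduction likely runs over non-contiguous dims.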
+ # do perf hint here if stride-1 dim is not being reduced + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if len(cls.candidate_tilings(node)) > 0: + perf_hint_log.info("reduction over non-contiguous dims") + break + return (numel, reduction_numel) + + seen_names = set() + candidate_tiles: Counter[Any] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for tiling in cls.candidate_tilings(node): + if tiling.name in seen_names: + continue + seen_names.add(tiling.name) + candidate_tiles[tiling.tiling] += tiling.score + + ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] + + if config.triton.max_tiles >= 3: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + # Add one 3D tiling choice + for i in range(1, len(ranked_tilings)): + a0, a1 = ranked_tilings[0] + b0, b1 = ranked_tilings[i] + if V.graph.sizevars.size_hint(a1 - b1) == 0: + continue + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + a0, a1 = ranked_tilings[i] + b0, b1 = ranked_tilings[0] + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if V.graph.sizevars.statically_known_multiple_of(a1, b1): + tiling = (a0, FloorDiv(a1, b1), b1) + ranked_tilings = [tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + for tiled_groups in ranked_tilings: + new_groups = (*tiled_groups, reduction_numel) + if all( + TritonKernel.is_compatible(new_groups, node.get_ranges()) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ): + return new_groups + + return (numel, reduction_numel) + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def benchmark_fused_nodes(self, nodes): + # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. 
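+        # Rough flow: generate the kernel source, load it via PyCodeCache, time it
+        # with do_bench, and cache the measured ms next to the generated file
+        # (*.kernel_perf).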
+ for n in nodes: + n.last_usage = set() + + if not nodes[0].is_template(): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + + tiled_groups = self.select_tiling(node_schedule, numel, rnumel) + reduction_hint_val, mutations, index_dtype = self.get_kernel_args( + node_schedule, numel, rnumel + ) + + kernel = TritonKernel( + *tiled_groups, + reduction_hint=reduction_hint_val, + mutations=mutations, + index_dtype=index_dtype, + ) + + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + else: + template_node = nodes[0] + epilogue_nodes = nodes[1:] + + with config.patch("benchmark_kernel", True): + src_code = self.codegen_template( + template_node, epilogue_nodes, only_gen_src_code=True + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + def store_cache(): + path = cache_file_path() + with open(path, "w") as fd: + fd.write(str(ms)) + + log.debug( + "kernel src code for %s written to: %s", + {n.get_name() for n in nodes}, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms, mod.__file__ + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)[0]) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. + ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)[0])) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + {n.get_name() for n in nodes}, + ms, + ) + store_cache() + return ms, mod.__file__ + + +@dataclasses.dataclass +class CandidateTiling: + tiling: Tuple[sympy.Expr, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class DisableReduction: + """ + Marker to invoke `kernel.disable_reduction()`. This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction: + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule): + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. 
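+        Used by select_tiling so that nodes outside the main reduction loop
+        do not influence the chosen tiling.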
+ """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node + + +class CantSplit(Exception): + pass diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_foreach.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_foreach.py new file mode 100644 index 0000000000000000000000000000000000000000..8698731a6ce13b28932a6ccda50c38e71ba0f0b4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_foreach.py @@ -0,0 +1,250 @@ +import itertools +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Tuple + +from sympy import Integer + +import torch + +from .. import metrics +from ..scheduler import SchedulerNode +from ..utils import ceildiv, Placeholder +from ..virtualized import V +from .common import IndentedBuffer, Kernel +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, signature_to_meta + + +@dataclass +class PartitionState: + partitions: List[ + List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]] + ] + cur_partition: List[ + Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer] + ] + cur_count: int + + def finalize(self): + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ForeachKernel(Kernel): + MAX_NUM_ARGS = 250 # number where I would no longer get triton errors + + @staticmethod + def _update_partition(partition_state, node_rw_count, node_info): + if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def horizontal_partition(subkernel_nodes, triton_scheduling): + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + assert len(subkernel_nodes) >= 1 + + partition_state_1d = PartitionState([], [], 0) + yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + fused_nodes = node.get_nodes() + _, (numel, rnumel) = max( + fused_nodes, key=lambda x: int(x.is_reduction()) + ).group + tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel) + node_info = fused_nodes, tiled_groups, numel, rnumel + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + if tiled_groups[1] == 1: + ForeachKernel._update_partition( + partition_state_1d, read_write_count, node_info + ) + else: + y_elem = tiled_groups[0] + partition_state_2d = yelem_to_partition_state_2d[y_elem] + ForeachKernel._update_partition( + partition_state_2d, read_write_count, node_info + ) + + partition_state_1d.finalize() + all_partitions = partition_state_1d.partitions + for partition_state_2d in yelem_to_partition_state_2d.values(): + partition_state_2d.finalize() + all_partitions.extend(partition_state_2d.partitions) + + return all_partitions + + def __init__(self): + super().__init__() + 
self.blocking_2d = False + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.sub_kernels = [] + self.iter_vars_count = itertools.count() + self.x_block_count = 0 + self.y_block_count = 0 + + def get_block_size(self): + if self.blocking_2d: + return self.block_size_2d + else: + return self.block_size_1d + + @staticmethod + def codegen_pid_offsets(code, block_count, lower_bound, prefix): + if block_count == 0: + code.splice(f"{prefix}pid_offset = {prefix}pid") + else: + code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}") + + def codegen_pid_range(self, code, x_elems): + num_x_blocks = ceildiv(x_elems, self.get_block_size()) + upper_bound_x_pid = self.x_block_count + num_x_blocks + lower_bound_x_pid = self.x_block_count + + if self.x_block_count == 0: + cond = "if" + else: + cond = "elif" + + x_pid_bounds_check = ( + f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}" + ) + code.splice(f"{cond} {x_pid_bounds_check}:") + + with code.indent(): + ForeachKernel.codegen_pid_offsets( + code, num_x_blocks, lower_bound_x_pid, "x" + ) + self.x_block_count += num_x_blocks + + def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint): + sub_kernel = TritonKernel( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache={ + "tl.program_id(0)": "xpid_offset", + "tl.program_id(1)": "ypid", + }, + reduction_hint=reduction_hint, + ) + if self.blocking_2d: + assert len(groups) == 3 + + self.blocking_2d |= groups[1] != 1 and len(groups) == 3 + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + def jit_lines(self): + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + _, _, signature = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=size_dtype), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + return f""" + @triton_heuristics.foreach( + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def grid(self): + return ( + self.x_block_count, + ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d) + if self.blocking_2d + else 1, + 1, + ) + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + argdefs, _, _ = self.args.python_argdefs() + code.splice(self.jit_lines()) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + + with code.indent(): + code.splice("xpid = tl.program_id(0)") + if self.blocking_2d: + code.splice("ypid = tl.program_id(1)") + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + + for sub_kernel in self.sub_kernels: + assert len(sub_kernel.numels) <= 3 + # TODO mlazos: support dynamic shapes + numel_ind = 0 if not self.blocking_2d else 1 + self.codegen_pid_range(code, 
int(sub_kernel.numels[numel_ind])) + with code.indent(): + if self.blocking_2d: + code.splice(f"ynumel = {sub_kernel.numels[0]}") + code.splice(f"xnumel = {sub_kernel.numels[1]}") + else: + code.splice(f"xnumel = {sub_kernel.numels[0]}") + + sub_kernel.codegen_body() + code.splice(sub_kernel.body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + return code.getvalue() + + def call_kernel(self, code, name: str): + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + if V.graph.cpp_wrapper: + V.graph.wrapper_code.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + grid=self.grid(), + ) + else: + # TODO: refactor generate_kernel_call + call_args_str = ", ".join(call_args) + stream_name = code.write_get_raw_stream( + V.graph.scheduler.current_device.index + ) + code.writeline( + f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})" + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d935dc196d0c23e358df469b241dfbc089d700 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,180 @@ +import functools + +from typing import Optional, Set + +from torch._inductor import config, ir + +from torch._inductor.codegen.triton import ( + IterationRangesRoot, + triton_compute_type, + TritonKernel, + TritonKernelOverrides, +) + +from torch._prims_common import prod + +from torch.utils._sympy.functions import CeilDiv + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. 
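+    Roughly: each block reduces its chunk, publishes the partial aggregate to the
+    workspace, then combines the published prefixes of earlier blocks (decoupled
+    look-back) before finishing its local scan.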
+ + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + *groups, + index_dtype: str, + mutations: Optional[Set[str]] = None, + reduction_hint=ir.ReductionHint.DEFAULT, + min_elem_per_thread=0, + ): + super().__init__( + *groups, + index_dtype=index_dtype, + mutations=mutations, + pid_cache=None, + reduction_hint=reduction_hint, + min_elem_per_thread=min_elem_per_thread, + ) + self.no_x_dim = True + + def initialize_range_tree(self, pid_cache): + prefixes = "yxr" + assert len(self.numels) <= len( + prefixes + ), "z dimension not supported for split scan" + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = "rxy" + for numel, prefix in zip(self.numels, active_prefixes): + is_reduction = prefix == "r" + tensor_dim = 0 if is_reduction else None + grid_dim = grid_dims.find(prefix) + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numel, + prefix, + grid_dim, + self, + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + ) + ) + for tree in self.range_trees: + tree.codegen_header(self.body) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtype, combine_fn, value, init): + import triton.language as tl + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads) + cse_compute = functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" + ) + + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + + value = cse_compute(f"{value}.to({compute_type})") + value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") + init = cse_compute(f"tl.full([], {init}, {compute_type})") + if masks: + cond = " & ".join(masks) + masked_value = cse_compute(TritonKernelOverrides.where(cond, value, init)) + else: + masked_value = value + + combine_helper_fn = self._lift_helper(combine_fn, 2) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + + block_sum = cse_compute( + f"tl.reduce({masked_value}, {dim}, {combine_helper_fn})" + ) + exclusive_prefix = self.cse.newvar() + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + 
{block_sum}, + {self.range_trees[-1].get_pid()}, + {combine_helper_fn}, + {init}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.range_trees[-1].get_pid()}, + {combine_helper_fn}, + {init}, + DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({masked_value}, {dim}, {combine_helper_fn})" + ) + return cse_compute(f"{combine_helper_fn}({exclusive_prefix}, {block_scan})") + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_fn(self): + return "split_scan_grid" diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_utils.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c12fded8a2f0b0d602fc9365bc3abb527540640 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, List, Optional + +import torch + +from .. import config +from ..utils import _type_of, instance_descriptor +from ..virtualized import V +from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg + + +def signature_of(arg: KernelArgType, *, size_dtype: str) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/openai/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + tye = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + tye = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + tye = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + tye = "*fp8e5b16" + else: + tye = _type_of(arg.dtype) + if V.graph.is_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_tye = tye.lstrip("*") + if new_tye in ["fp16", "bf16"]: + return "fp32" + else: + return new_tye + else: + return tye + if isinstance(arg, SizeArg): + if arg.expr is None: + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. 
+ return "*i8" + elif isinstance(arg.expr, float): + return "fp32" + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return "*i8" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def signature_to_meta( + signature: List[KernelArgType], + *, + size_dtype: str, + indices: Optional[List[int]] = None, +) -> Dict[int, str]: + if indices is None: + indices = list(range(len(signature))) + return { + i: signature_of(arg, size_dtype=size_dtype) + for i, arg in zip(indices, signature) + } + + +def config_of( + args: List[KernelArgType], + *, + indices: Optional[List[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type] + ) + return offset_aligned and not V.graph.scheduler.is_unaligned_buffer( + x.buffer + ) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with + # _maybe_evaluate_static... + if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment) # type: ignore[arg-type] + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + divisible_by_8 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=8, include_tensor=False) + ) + + equal_to_1 = tuple( + i + for i, arg in zip(indices, args) + if isinstance(arg, SizeArg) + and arg.expr is not None + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + # ids_of_folded_args is set from equal_to_1 + # and None args by the Triton compiler + ids_of_folded_args = tuple(equal_to_1) + + return instance_descriptor( + divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8 + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/codegen/wrapper.py b/MLPY/Lib/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..afb4976ade34e58600dcabe6e5ba5ff24e3e3fff --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,1543 @@ +import collections +import contextlib +import dataclasses +import functools +import inspect +import operator +import re +from itertools import count +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import sympy +from sympy import Expr + +import torch +import torch._ops +from torch._dynamo.utils import counters, dynamo_timed + +from torch._inductor.codegen.multi_kernel import MultiKernelState +from 
torch.fx.experimental.symbolic_shapes import SymTypes +from torch.fx.node import _get_qualified_name +from torch.utils._sympy.singleton_int import SingletonInt + +from .. import codecache, config, ir +from ..ir import ReinterpretView +from ..utils import ( + cache_on_self, + get_benchmark_name, + LineContext, + sympy_product, + sympy_str, +) +from ..virtualized import V +from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter +from .triton_utils import config_of, signature_to_meta + +if TYPE_CHECKING: + import triton + + from ..graph import GraphLowering + + +pexpr = PythonPrinter().doprint + + +ReuseKey = Tuple[torch.device, torch.dtype, str] + + +def buffer_reuse_key(node: ir.Buffer) -> ReuseKey: + return ( + node.get_device(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())), + ) + + +def convert_arg_type(arg: torch.Argument) -> str: + from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP + + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(arg.real_type) # type: ignore[attr-defined] + + if python_type == "Tensor": + # Conversions rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func + if arg.alias_info is not None and arg.alias_info.is_write: + return f"at::{python_type}&" + else: + return f"at::{python_type} const&" + + if python_type in PYTHON_TO_CPP: + cpp_type = PYTHON_TO_CPP[python_type] + return cpp_type + + # Convert args of container types e.g. Optional[*] + for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items(): + container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type) + if len(container_match) == 1: + contained_type = container_match[0] + assert ( + contained_type in PYTHON_TO_CPP + ), f"unsupported {py_container} type in convert_arg_type: {contained_type}" + cpp_contained_type = PYTHON_TO_CPP[contained_type] + return f"{cpp_container}<{cpp_contained_type}>" + + raise AssertionError(f"unsupport python_type: {python_type}") + + +def convert_return_type(ret: torch.Argument) -> str: + # use x.real_type instead of x.type so that we get ScalarType instead of int + python_type = repr(ret.real_type) # type: ignore[attr-defined] + python_to_cpp = { + "Tensor": "at::Tensor", + "List[Tensor]": "std::vector", + } + + cpp_type = python_to_cpp.get(python_type, None) + assert cpp_type is not None, f"NYI return type: {python_type}" + # An output aliasing an input is returned by reference only when it's a + # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output + # aliases the input tensor, but the op returns a vector by value. 
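+    # e.g. an in-place overload whose return aliases its input maps to
+    # "at::Tensor&" below, while list-of-Tensor returns stay by value.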
+ if python_type == "Tensor" and ret.alias_info is not None: + cpp_type += "&" + return cpp_type + + +def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str: + args = kernel._schema.arguments + returns = kernel._schema.returns + + num_returns = len(returns) + assert num_returns > 0, "must have at least one return value" + + if num_returns == 1: + cpp_return_value = convert_return_type(returns[0]) + elif num_returns > 1: + tuple_returns = ", ".join([convert_return_type(r) for r in returns]) + cpp_return_value = f"std::tuple<{tuple_returns}>" + + cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args] + return f"{cpp_return_value}({', '.join(cpp_arg_type)})" # type: ignore[possibly-undefined] + + +# TODO: Move to a well known place +TritonMetaParams = Dict[str, int] +TritonGrid = Union[ + Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: List["triton.Config"], + grids: List[TritonGrid], + wrapper: Optional["WrapperCodeGen"] = None, +) -> Tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid(grid: TritonGrid): + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + return wrapper.codegen_shape_tuple(sympy_grid) + + fn_name = f"grid_wrapper_for_{name}" + output.writeline(f"def {fn_name}(meta):") + with output.indent(): + if len(grids) == 1: + grid = determine_grid(grids[0]) + output.writeline(f"return {grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen = set() + for grid, c in zip(grids, configs): + guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()] + guards = " and ".join(guards) + grid = determine_grid(grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + output.writeline(statement) + + return fn_name, output.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: str + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something comfortably smaller than the smallest for now. 
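+# (1024 * 100 bytes = 100 KiB, comfortably below even the 512 KB macOS default.)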
+MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: Dict[ + ReuseKey, List[FreeIfNotReusedLine] + ] = collections.defaultdict(list) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> "FreeIfNotReusedLine": + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: "FreeIfNotReusedLine") -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + pass + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: "WrapperCodeGen" + graph: "GraphLowering" + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: "WrapperCodeGen" + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. + if self.last_seen_device_guard_index is None: + if config.abi_compatible: + code.writeline( + "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" + ) + else: + code.writeline( + "at::cuda::CUDAStreamGuard stream_guard(" + + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" + ) + else: + assert ( + self.last_seen_device_guard_index == self.device_idx + ), "AOTInductor only supports running on one CUDA device" + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"AOTICudaGuard device_guard({self.device_idx});" + if config.abi_compatible + else f"at::cuda::CUDAGuard device_guard({self.device_idx});" + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: "WrapperCodeGen" + + def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine": + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + pass + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. 
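+        e.g. ``AllocateLine(node=buf0)``.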
+ """ + args: List[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: ir.Buffer + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) + + if self.node.get_device().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: ir.Buffer + is_reused: bool = False + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if isinstance(self.node.layout, (ir.AliasedLayout, ir.MultiOutputLayout)): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + +@dataclasses.dataclass +class ReuseLine(MemoryPlanningLine): + node: ir.Buffer + reused_as: ir.Buffer + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + +class NullLine(MemoryPlanningLine): + pass + + +BufferName = str + + +class WrapperCodeGen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. 
+ """ + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: Dict[str, str] = {} + self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set() + self.lines: List[Union[MemoryPlanningLine, LineContext]] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.open_bracket = "[" + self.closed_bracket = "]" + self.comment = "#" + self.namespace = "" + self.none_str = "None" + self.size = "size()" + self.stride = "stride()" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.expr_printer = pexpr + self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} + self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol + self.allow_stack_allocation: Optional[bool] = None + self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {} + self.computed_sizes: Set[sympy.Symbol] = set() + + # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [V.graph] + + self.write_header() + self.write_prefix() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated: Set[BufferName] = set() + self.freed: Set[BufferName] = set() + + # maps from reusing buffer to reused buffer + self.reuses: Dict[BufferName, BufferName] = dict() + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.lru_cache(None) + def add_import_once(line: str) -> None: + self.header.writeline(line) + + self.add_import_once = add_import_once + self._metas: Dict[str, str] = {} + self.multi_kernel_state = MultiKernelState() + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + self.header.splice( + f""" + from ctypes import c_void_p, c_long + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + + from torch import device, empty_strided + from {codecache.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + alloc_from_pool = torch.ops.inductor._alloc_from_pool + reinterpret_tensor = torch.ops.inductor._reinterpret_tensor + async_compile = AsyncCompile() + + """ + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + self.header.splice( + """ + import triton + import 
triton.language as tl + from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph + {} + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def add_meta_once(self, meta: TritonMetaParams) -> str: + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> List[str]: + return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs] + + def mark_output_type(self) -> None: + return + + def codegen_input_size_asserts(self) -> None: + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_shape_tuple(buf.get_size()) + stride = self.codegen_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_prefix(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + + def call(args): + """ + ) + with self.prefix.indent(): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + if V.graph.graph_inputs: + lhs = ", ".join(V.graph.graph_input_names) + if len(V.graph.graph_input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes a graph as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. 
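+    # (editor's note) in the default Python wrapper this emits a line of the
+    # form `stream0 = get_raw_stream(0)`; the result is cached per device index
+    # and per graph instance via the functools.lru_cache wrapping in __init__.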
+ def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + self.write_triton_header_once() + name = f"stream{device_idx}" + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + + def generate_return(self, output_refs: List[str]) -> None: + if output_refs: + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_end(self, result: IndentedBuffer) -> None: + return + + def generate_fallback_kernel(self, fallback_kernel, args): + self.generate_extern_kernel_alloc(fallback_kernel, args) + + def generate_extern_kernel_alloc(self, extern_kernel, args): + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. + ending = f".clone(){ending}" + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel): + if output_view: + args.append(f"out={output_view.codegen_reference()}") + else: + args.append(f"out={codegen_reference}") + self.writeline(f"{kernel}({', '.join(args)})") + + def generate_user_defined_triton_kernel( + self, kernel_name, grid, configs, args, triton_meta + ): + grid, code = user_defined_kernel_grid_fn_code( + kernel_name, configs, grid, wrapper=self + ) + # Must happen after free symbols are already codegened + # Emit the grid wrapper function right before the call + for line in code.split("\n"): + self.writeline(line) + + stream_name = self.write_get_raw_stream( + V.graph.scheduler.current_device.index, V.graph + ) + self.writeline( + f"{kernel_name}.run({', '.join(args)}, grid={grid}, stream={stream_name})" + ) + + def generate_scatter_fallback( + self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs + ): + line = f"{kernel}({','.join(map(str, inputs))}" + if kernel == "aten.scatter_": + if reduce: + line += f", reduce={repr(reduce)}" + else: + line += ", ".join([""] + kwargs) + line += f"){self.ending}" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + 
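+
+    # Editor's note (illustrative sketch, not upstream code): the generate_*
+    # helpers above and below write plain Python source into the wrapper's
+    # call() body. The emitted code typically looks like the lines sketched
+    # here; the buffer/argument names (buf0, arg0_1, stream0) and the fused
+    # Triton kernel name are hypothetical and depend on the lowered graph.
+    #
+    #   buf0 = empty_strided_cuda((64, 64), (64, 1), torch.float32)
+    #   extern_kernels.mm(arg0_1, arg1_1, out=buf0)
+    #   stream0 = get_raw_stream(0)
+    #   triton_poi_fused_relu_0.run(buf0, 4096, grid=grid(4096), stream=stream0)
+    #   del arg0_1, arg1_1
+    #   return (buf0, )
+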
def generate_extern_kernel_alloc_and_find_schema_if_needed( + self, + name, + kernel, + codegen_args, + cpp_op_schema, + cpp_kernel_key, + cpp_kernel_overload_name="", + op_overload=None, + raw_args=None, + outputs=None, + ): + self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})") + + def generate_inf_and_nan_checker(self, node): + # TODO: Add check for python too. + pass + + @dynamo_timed + def generate(self, is_inference): + if config.profile_bandwidth: + self.write_triton_header_once() + result = IndentedBuffer() + result.splice(self.header) + + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.generate_profiler_mark_wrapper_call(stack) + if config.profile_bandwidth: + self.generate_start_graph() + + # We disable planning during training because it presently increases peak memory consumption. + if is_inference and config.memory_planning: + self.memory_plan() + # TODO: integrate memory planning & stack allocation? + self.allow_stack_allocation = False + else: + self.memory_plan_reuse() + + if config.triton.store_cubin: + self.generate_reset_kernel_saved_flags() + + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + + output_refs = self.get_output_refs() + self.mark_output_type() + if config.triton.debug_sync_graph: + self.wrapper_call.writeline(V.graph.device_ops.synchronize()) + + if config.profile_bandwidth: + self.generate_end_graph() + + if config.triton.store_cubin: + self.generate_save_uncompiled_kernels() + + self.generate_return(output_refs) + + self.finalize_prefix() + result.splice(self.prefix) + + with result.indent(): + result.splice(self.wrapper_call) + + self.generate_before_suffix(result) + result.splice(self.suffix) + + self.generate_end(result) + + self.add_benchmark_harness(result) + + return result.getvaluewithlinemap() + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + self.allow_stack_allocation = ( + self.allow_stack_allocation is not False + and config.allow_stack_allocation + and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE + ) + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}") + + def 
codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + code.writeline( + f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}" + ) + + def codegen_inputs( + self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox] + ): + """Assign all symbolic shapes to locals""" + + @functools.lru_cache(None) + def sizeof(name): + self.codegen_input_size_var_decl(code, name) + return f"{name}_size" + + @functools.lru_cache(None) + def strideof(name): + self.codegen_input_stride_var_decl(code, name) + return f"{name}_stride" + + # Assign all symbolic shapes needed to local variables + needed = V.graph.sizevars.free_symbols() + + def is_expr(x): + return isinstance(x[1], sympy.Expr) + + graph_inputs_expr = list(filter(is_expr, graph_inputs.items())) + graph_inputs_tensors = list( + filter(lambda x: not is_expr(x), graph_inputs.items()) + ) + + for name, shape in graph_inputs_expr: + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] + if shape in needed: + needed.remove(shape) # type: ignore[arg-type] + code.writeline(f"{self.declare}{shape} = {name}{self.ending}") + + for name, value in graph_inputs_tensors: + shapes = value.get_size() + for dim, shape in enumerate(shapes): + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] + if shape in needed: + needed.remove(shape) # type: ignore[arg-type] + code.writeline( + f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" + ) + + for name, value in graph_inputs_tensors: + shapes = value.get_stride() + for dim, shape in enumerate(shapes): + shape = V.graph.sizevars.simplify(shape) # type: ignore[arg-type] + if shape in needed: + needed.remove(shape) # type: ignore[arg-type] + code.writeline( + f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" + ) + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and sym.name.startswith("ps"): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline( + f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}" + ) + + def finalize_prefix(self): + pass + + def codegen_python_sizevar(self, x: Expr) -> str: + return pexpr(V.graph.sizevars.simplify(x)) + + def codegen_sizevar(self, x: Expr) -> str: + return self.codegen_python_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + return f"{basename}[{index}]" + + def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + parts = list(map(self.codegen_python_sizevar, shape)) + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + + def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + return self.codegen_python_shape_tuple(shape) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + str(dtype), + self.codegen_shape_tuple(shape), + self.codegen_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view(self, data, size, stride, offset, writer) -> str: + size = self.codegen_shape_tuple(size) + stride = self.codegen_shape_tuple(stride) + offset = self.codegen_sizevar(offset) + return f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" + + def codegen_device_copy(self, src, dst): + self.writeline(f"{dst}.copy_({src})") + + def codegen_multi_output(self, name, value): + 
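+        # (editor's note) in the Python wrapper, declare/ending are empty
+        # strings, so this emits e.g. `buf2 = buf1[0]` to bind one output of a
+        # multi-output op to its own name; C++ wrapper subclasses override
+        # declare/ending to turn the same template into a declaration/statement.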
self.writeline(f"{self.declare}{name} = {value}{self.ending}") + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + if node.is_bool: + self.writeline(f"{node.sym} = 1 if {data}.item() else 0") + else: + self.writeline(f"{node.sym} = {data}.item()") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + self.writeline(f"{node.get_name()} = None") + + def benchmark_compiled_module(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_fake_input( + name, value.size(), value.stride(), value.device, value.dtype + ) + + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + # Inductor should only work with dense -> dense graph, and + # SingletonInts belong to metadata that should only live on + # the subclass. + continue + if isinstance(value, sympy.Expr): # Don't need to add symbolic + add_expr_input(name, V.graph.sizevars.size_hint(value)) + else: + shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()] + stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()] + add_fake_input( + name, shape, stride, value.get_device(), value.get_dtype() + ) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return print_performance(fn, times=times, repeat=repeat)") + + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "from torch._inductor.wrapper_benchmark import compiled_module_main", + f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", + ] + ) + + def define_kernel( + self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + ): + metadata_comment = f"{metadata}\n" if metadata else "" + self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}") + + def define_user_defined_triton_kernel(self, kernel, configs, kwargs): + original_name = kernel.__name__ + + from .common import KernelArgType, SizeArg, TensorArg + + signature: List[KernelArgType] = [] + constants: Dict[int, Any] = {} + non_constant_indices = [] + equal_to_1_arg_idx: List[int] = [] + for idx, key in enumerate(kernel.arg_names): + if key not in kwargs: + continue + arg = kwargs[key] + if idx in kernel.constexprs: + constants[idx] = arg + else: + non_constant_indices.append(idx) + if isinstance(arg, ir.Buffer): + signature.append( + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), 
+ ) + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + signature.append( + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ) + ) + else: + signature.append(SizeArg(key, arg)) + if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type] + equal_to_1_arg_idx.append(idx) + index_dtype = "tl.int32" + triton_meta = { + "signature": signature_to_meta( + signature, + size_dtype=index_dtype, + indices=non_constant_indices, + ), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. + # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **{idx: 1 for idx in equal_to_1_arg_idx}, + }, + "configs": [ + config_of( + signature, + indices=non_constant_indices, + ) + ], + } + + # Distinguish between different functions using function id + cache_key: List[Any] = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key = tuple(cache_key) + + if cache_key in self.user_defined_kernel_cache: + return self.user_defined_kernel_cache[cache_key] + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + from .triton import gen_common_triton_imports + + compile_wrapper.splice(gen_common_triton_imports()) + + inductor_meta = { + "kernel_name": name, + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + + configs = [ + { + "kwargs": config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + for config in configs + ] + + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={configs!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + from triton import JITFunction + + symbols_included = {original_name} + + def traverse(cur_kernel): + for symbol_name in cur_kernel.fn.__code__.co_names: + if symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool)): + compile_wrapper.newline() + 
compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + symbols_included.add(symbol_name) + + traverse(kernel) + + compile_wrapper.writeline( + f"''', device_str='{V.graph.scheduler.current_device.type}')" + ) + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + return name, triton_meta + + def generate_numel_expr(self, kernel_name: str, tree): + expr = f"{kernel_name}_{tree.prefix}numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + return SymbolicCallArg(expr, tree.numel) + + def generate_workspace_allocation(self, nbytes, device, zero_fill): + line = self.make_allocation( + "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) + ) + self.writeline(line) + if zero_fill: + self.writeline(f"workspace.zero_(){self.ending}") + + def wrap_kernel_call(self, name, call_args): + return f"{name}({', '.join(call_args)}){self.ending}" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline("from torch.profiler import record_function") + self.wrapper_call.writeline( + f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) + + def generate_start_graph(self): + self.wrapper_call.writeline("start_graph()") + + def generate_end_graph(self): + self.wrapper_call.writeline("end_graph()") + + def generate_reset_kernel_saved_flags(self): + self.wrapper_call.splice( + """ + for kernel in globals().values(): + if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner): + kernel.cuda_kernel_saved = False + """ + ) + + def generate_save_uncompiled_kernels(self): + """ + Precompile and save the CUBINs of the Triton kernels that haven't + been precompiled and saved as a side effect of running the generated + JIT model (Python wrapper). This can happen when the model contains + control flow: only one pass through the control flow operators covers + the kernels that are saved, the remaining kernels are not launched, + hence not saved. The main purpose of this codegen is to compile and + save the Triton kernels outside the active control flow path for + subsequent AOTInductor code generation and compilation. 
+ """ + self.wrapper_call.splice( + """ + for kernel in globals().values(): + if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner): + if not kernel.cuda_kernel_saved: + if len(kernel.launchers) == 0: + kernel.precompile() + kernel.save_cuda_kernel( + grid=(0, 0, 0), # use dummy grid + stream="stream", # use dummy stream + launcher=kernel.launchers[0], + ) + """ + ) + + def generate_default_grid(self, name: str, grid_args: List[Any]): + return grid_args + + def generate_kernel_call( + self, + name, + call_args, + grid=None, + device_index=None, + cuda=True, + triton=True, + arg_types=None, + grid_fn: str = "grid", + triton_meta=None, + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + if cuda: + call_args_str = ", ".join(pexpr(item) for item in call_args) + stream_name = self.write_get_raw_stream( + V.graph.scheduler.current_device.index, V.graph + ) + if triton: + grid_str = ", ".join(pexpr(item) for item in grid) + grid_str = f"{grid_fn}({grid_str})" + self.writeline( + f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) + else: + stream_ptr = f"c_void_p({stream_name})" + self.writeline(f"{name}.{name}({call_args_str}, {stream_ptr})") + else: + self.writeline(self.wrap_kernel_call(name, call_args)) + + def writeline(self, line): + self.lines.append(line) + + def enter_context(self, ctx): + self.lines.append(LineContext(ctx)) + + def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str: + raise NotImplementedError() + + def val_to_arg_str(self, s): + if isinstance(s, SymTypes): + return pexpr(sympy.expand(repr(s))) + elif isinstance(s, sympy.Expr): + return pexpr(s) + elif isinstance(s, (tuple, list)): + + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s)) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) + elif isinstance(s, (ir.Buffer, ReinterpretView)): + return s.codegen_reference() + else: + return repr(s) + + # The following methods are for memory management + def make_buffer_allocation(self, buffer): + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + stride = tuple(buffer.get_stride()) + return self.make_allocation(buffer.get_name(), device, dtype, shape, stride) + + def make_allocation(self, name, device, dtype, shape, stride): + if device.type in ("cpu", "cuda"): + # optimized path for faster allocations, saving ~2us versus the stuff below + return ( + f"{name} = empty_strided_{device.type}(" + f"{self.codegen_shape_tuple(shape)}, " + f"{self.codegen_shape_tuple(stride)}, " + f"{dtype})" + ) + # all other devices: + return ( + f"{name} = empty_strided(" + f"{self.codegen_shape_tuple(shape)}, " + f"{self.codegen_shape_tuple(stride)}, " + f"device='{device.type}', dtype={dtype})" + ) + + def make_tensor_alias(self, new_name, old_name, comment=""): + return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" + + def make_buffer_free(self, buffer): + return f"del {buffer.get_name()}" + + def make_free_by_names(self, names_to_del: List[str]): + return f"del {', '.join(name for name in names_to_del)}" + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return 
f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" + + def make_buffer_reuse(self, old, new, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + if old_name in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call + ) + if reinterpret_view in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse" + + def codegen_deferred_allocation(self, name, layout): + self.writeline( + DeferredLine( + name, + f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} " + f"{self.comment} alias", + ) + ) + + def codegen_allocation(self, buffer): + assert ( + buffer.get_workspace_size() == 0 + ), "Only support zero workspace size for now!" + + name = buffer.get_name() + + if name in V.graph.removed_buffers or name in self.allocated: + return + self.allocated.add(name) + if isinstance( + buffer, + (ir.ExternKernelAlloc, ir.MultiOutput), + ): + return + + layout = buffer.get_layout() + if isinstance(layout, ir.MutationLayout): + return + if isinstance(layout, ir.AliasedLayout): + assert isinstance( + layout.view, ir.ReinterpretView + ), f"unexpected {type(layout.view)}: {layout.view}" + self.codegen_allocation(layout.view.data) + self.codegen_deferred_allocation(name, layout) + return + + self.writeline(AllocateLine(self, buffer)) + + def codegen_free(self, buffer): + assert ( + buffer.get_workspace_size() == 0 + ), "Only support zero workspace size for now!" + + name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, ir.InputBuffer): + self.writeline(self.make_buffer_free(buffer)) + return + + if not self.can_reuse(buffer): + return + self.freed.add(name) + + self.writeline(FreeIfNotReusedLine(self, buffer)) + + def can_reuse(self, input_buffer, output_buffer=None): + name = input_buffer.get_name() + if ( + name in V.graph.removed_buffers + or name in V.graph.graph_inputs + or name in V.graph.constants + or name in V.graph.never_reuse_buffers + or name in self.freed + ): + return False + + return True + + def did_reuse(self, buffer, reused_buffer): + # Check whether a given buffer was reused by a possible reuser in the wrapper codegen + # Can be consulted from inside ir codegen, e.g. 
to determine whether a copy is needed + return ( + buffer.get_name() in self.reuses + and self.reuses[buffer.get_name()] == reused_buffer.get_name() + ) + + def codegen_inplace_reuse(self, input_buffer, output_buffer): + assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer) + self.codegen_allocation(input_buffer) + self.freed.add(input_buffer.get_name()) + self.allocated.add(output_buffer.get_name()) + self.reuses[output_buffer.get_name()] = input_buffer.get_name() + self.writeline(ReuseLine(self, input_buffer, output_buffer)) + + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCpu, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): + self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_conditional(self, conditional): + name = conditional.get_name() + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {conditional.predicate.codegen_reference()}.item():") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + val = V.graph._shape_env._maybe_evaluate_static(x) + return int(x) + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = WrapperCodeGen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size()) + + @staticmethod + def can_prove_buffer_has_static_shape(buffer): + return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None diff --git a/MLPY/Lib/site-packages/torch/_inductor/comm_analysis.py 
b/MLPY/Lib/site-packages/torch/_inductor/comm_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..0f7f1d8f336bbc980afb949ee06492cfdcc04635 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/comm_analysis.py @@ -0,0 +1,273 @@ +import math +from enum import IntEnum + +import sympy + +import torch +from . import ir + +from .utils import get_dtype_size, sympy_product +from .virtualized import V + + +class NCCL_COLL(IntEnum): + ALL_REDUCE = 0 + ALL_GATHER = 1 + REDUCE_SCATTER = 2 + + +class NVIDIA_GPU_TYPE(IntEnum): + VOLTA = 0 + AMPERE = 1 + HOPPER = 2 + + +def get_gpu_type() -> NVIDIA_GPU_TYPE: + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" + if "V100" in gpu_info: + return NVIDIA_GPU_TYPE.VOLTA + elif "A100" in gpu_info: + return NVIDIA_GPU_TYPE.AMPERE + elif "H100" in gpu_info: + return NVIDIA_GPU_TYPE.HOPPER + else: + # for other gpu types, assume Ampere + return NVIDIA_GPU_TYPE.AMPERE + + +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if isinstance(node, ir._CollectiveKernel): + kernel_name = node.python_kernel_name + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + else: + raise Exception(f"Unsupported collective kernel: {kernel_name}") + + if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)): + return NCCL_COLL.ALL_REDUCE + elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)): + return NCCL_COLL.ALL_GATHER + elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)): + return NCCL_COLL.REDUCE_SCATTER + else: + raise Exception(f"Unsupported collective type: {node}") + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + shape = inp.layout.size + numel = sympy_product(inp.layout.size) + if isinstance(numel, sympy.Integer): + # For ease of testing + numel = int(numel) + else: + numel = V.graph.sizevars.size_hint(numel) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if type(node) == ir._CollectiveKernel: + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + elif isinstance(node, ir.CollectiveKernel): + return node.constant_args[2] # type: ignore[attr-defined] + else: + raise TypeError(f"Unsupported collective type: {node}") + + +#################################################################################################################### +# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +#################################################################################################################### + + +class NCCL_HW(IntEnum): + NVLINK = 0 + PCI = 1 + NET = 2 + + +class NCCL_ALGO(IntEnum): + TREE = 0 + RING = 1 + + +class NCCL_PROTO(IntEnum): + # The ordering and enum values here matches original in + # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28 + # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990 + LL = 0 # Low-latency + # LL128 = 1 # Low-latency 128-byte + # SIMPLE = 2 + + +# Latencies in us +# len(NCCL_ALGO) x 
len(NCCL_PROTO) +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + num_gpus_per_node = 8 + group_size = get_collective_group_size(node) + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + coll = get_collective_type(node) + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. + # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. 
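+    # (editor's note, illustrative) in the single-node case below, busBw works
+    # out to nChannels * bwIntra = 2 * bwIntra before the refinements that
+    # follow; a multi-node job uses the slower inter-node bandwidth (bwInter)
+    # instead, since each synchronized ring step is bottlenecked by the
+    # slowest link.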
+ bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + hw = intraHw if nNodes == 1 else NCCL_HW.NET + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. + netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + return transport_ns + latency_ns + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ diff --git a/MLPY/Lib/site-packages/torch/_inductor/comms.py b/MLPY/Lib/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..9974fa428976b5b3283f828fec1b2c75534672ea --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/comms.py @@ -0,0 +1,363 @@ +# pyre-strict + +from typing import List + +import torch + +from . import config, ir, scheduler +from .dependencies import WeakDep +from .utils import tuple_sorted + +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + + +def sink_waits( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Greedily moves waits as late as possible (i.e. until we reach a use). Optimal in terms of + communication overlap. + """ + new_order = [] + cur_waits = set() + for snode in snodes: + if isinstance(snode.node, ir.Wait): + cur_waits.add(snode) + else: + for wait in tuple_sorted(cur_waits): + if snode in wait.node_users: + new_order.append(wait) + cur_waits.remove(wait) + new_order.append(snode) + new_order.extend(tuple_sorted(cur_waits)) + return new_order + + +def raise_comms( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Greedily moves comms as early as possible (i.e. until we reach an input). + Optimal in terms of communication overlap. + + TODO: We might want to adjust this in the future to account for memory limitations. + e.g. 
when we are compiling FSDP, this heuristics will cause the all-gathers to be prefetched as soon as possible, + which is the beginning of the forwards pass. We'll have to either do a special pass for FSDP, + or we'll want to redo this pass with memory considerations so we handle the FSDP case in a general way. + """ + new_order_reversed: List["scheduler.BaseSchedulerNode"] = [] + cur_comms: List["scheduler.BaseSchedulerNode"] = [] + for snode in reversed(snodes): + if isinstance(snode.node, ir.CollectiveKernel): + cur_comms.append(snode) + else: + for comm in cur_comms: + assert len(comm.inverse_users) > 0 + while len(cur_comms) > 0 and any( + snode in comm.inverse_users for comm in cur_comms + ): + comm = cur_comms.pop(0) + new_order_reversed.append(comm) + new_order_reversed.append(snode) + assert len(cur_comms) <= 1 + new_order_reversed.extend(tuple_sorted(cur_comms)) + return new_order_reversed[::-1] + + +def get_ancestors(node): + ancestors = set() + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.inverse_users: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return ancestors + + +def get_descendants(node): + descendants = set() + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.node_users: + if inp not in descendants: + descendants.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return descendants + + +def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]): + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). + TODO: Come up with a better approach + """ + comm_nodes = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)] + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + comm_nodes[i].add_fake_dep(WeakDep(comm_nodes[i - 1].get_name())) + + +def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None: + assert not any(isinstance(snode.node, ir.CollectiveKernel) for snode in snodes) + + +def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def reorder_compute_for_overlap( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + """ + Decides a global ordering of all compute and communication nodes, + assuming that we already have a global ordering of communication nodes. + + Overall scheduling procedure is: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. 
+ """ + final_order = [] + + comm_nodes = [] + for snode in snodes: + if isinstance(snode.node, ir.CollectiveKernel): + comm_nodes.append(snode) + if len(comm_nodes) == 0: + # if there is no comm nodes, return the current order + return snodes + + comm_ancestors = {node: get_ancestors(node) for node in comm_nodes} + comm_descendants = {node: get_descendants(node) for node in comm_nodes} + + indeg = dict.fromkeys(snodes, 0) + for snode in snodes: + for user in snode.node_users: + if user in indeg: + indeg[user] += 1 + ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0} + + unscheduled_nodes = set() + unscheduled_nodes = set(snodes) + + def schedule_node(snode): + """ + Schedule a single node. + """ + assert snode in unscheduled_nodes + assert snode in ready_to_schedule_nodes + ready_to_schedule_nodes.remove(snode) + unscheduled_nodes.remove(snode) + final_order.append(snode) + for user in tuple_sorted(snode.node_users): + if user in indeg: + indeg[user] -= 1 + if indeg[user] == 0: + ready_to_schedule_nodes.add(user) + + def schedule_nodes(snodes): + """ + Schedules all nodes in `snodes` in an arbitrary topologically valid order. + """ + all_nodes = set(snodes) + assert all(node in unscheduled_nodes for node in all_nodes) + while len(all_nodes) > 0: + # NOTE: since model graph is always a DAG and does not have circular dependency inside, + # there should be at least one node that is a "free node" (i.e. indeg == 0), + # hence infinite loop is not possible. But we check here just to be safe. + progress = False + for node in tuple_sorted(all_nodes): + if node in ready_to_schedule_nodes: + schedule_node(node) + all_nodes.remove(node) + progress = True + if not progress: + raise Exception( + "Unable to find a free node (indeg == 0). This is an impossible state to reach. " + "Please report a bug to PyTorch." + ) + + # First, schedule all compute nodes that are required by first comm node, + # as well as the first comm node itself. + assert len(comm_nodes) > 0 + schedule_nodes( + list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]], + ) + + rolled_over_compute_cost = 0 + for idx in range(1, len(comm_ancestors)): + # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule + # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`, + # to run at the same time with comm `idx-1`. + needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & ( + comm_ancestors[comm_nodes[idx]] - comm_descendants[comm_nodes[idx - 1]] + ) + assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes) + + total_compute_runtime_cost = rolled_over_compute_cost + sum( + [ + estimate_op_runtime(node) + for node in needed_by_next_comm_and_ready_compute_nodes + ] + ) + prev_comm_runtime_cost = estimate_op_runtime(comm_nodes[idx - 1]) + schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes)) + + # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done. + # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx`. + # We prioritize compute nodes that are needed sooner. + step1_runtime_cost = total_compute_runtime_cost + if step1_runtime_cost >= prev_comm_runtime_cost: + pass + else: + # Find all ready to schedule compute nodes that do not depend on comm `idx-1`. 
+ ready_to_schedule_compute_nodes = tuple_sorted( + ready_to_schedule_nodes - comm_descendants[comm_nodes[idx - 1]] + ) + assert_no_comm_nodes(ready_to_schedule_compute_nodes) + + def earliest_comm_descendant(node): + for idx in range(len(comm_nodes)): + if node in comm_ancestors[comm_nodes[idx]]: + return idx + return len(comm_nodes) + + # Prioritize compute nodes that are needed sooner. + ready_to_schedule_compute_nodes = sorted( + ready_to_schedule_compute_nodes, key=earliest_comm_descendant + ) + + for snode in ready_to_schedule_compute_nodes: + if total_compute_runtime_cost >= prev_comm_runtime_cost: + # If accumulated compute runtime cost is greater than comm `idx-1` runtime cost, + # it means we have maximized overlap for comm `idx-1`, and hence we stop looking + # for more compute to schedule. + break + compute_runtime_cost = estimate_op_runtime(snode) + # If we're not able to leverage more than half of this + # node's compute to overlap, we skip it. + # TODO: Smarter heuristics here + if ( + prev_comm_runtime_cost - total_compute_runtime_cost + ) <= compute_runtime_cost / 2: + continue + schedule_node(snode) + total_compute_runtime_cost += compute_runtime_cost + rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost + + # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`. + needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[comm_nodes[idx]] + schedule_nodes(list(needed_by_next_comm_nodes)) + + # Step 4: We schedule comm `idx`. + schedule_nodes([comm_nodes[idx]]) + + is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0 + # The idea here is that if there are no compute nodes from Step 3 + # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes + # in Step 2 to overlap with the next comm, since they're not required to finish + # before the next comm starts. + if is_prev_comm_blocking_next_comm: + rolled_over_compute_cost = 0 + else: + rolled_over_compute_cost = rollable_compute_cost # type: ignore[assignment] + + schedule_nodes(unscheduled_nodes) + return final_order + + +def node_summary(snode): + detail = "" + if isinstance(snode.node, ir.ExternKernelOut): + detail = f" ({snode.node.python_kernel_name})" + out_tensor_info = "" + if ( + hasattr(snode.node, "layout") + and hasattr(snode.node.layout, "size") + and hasattr(snode.node.layout, "stride") + ): + out_tensor_info = ( + f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})" + ) + node_name = "" + if hasattr(snode.node, "name"): + node_name = snode.node.name + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})" + + +def visualize_overlap(order): + total_est_runtime: float = 0.0 + cur_comm_node = None + for snode in order: + if cur_comm_node is None: + if isinstance(snode.node, ir.CollectiveKernel): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif isinstance(snode.node, ir.Wait): + raise Exception( + "Wait is not expected when there is no collective running" + ) + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + else: # cur_comm_node is not None + if isinstance(snode.node, ir.CollectiveKernel): + raise Exception( + "Found two collectives running at the same time. 
" + "`visualize_overlap` needs to be updated to handle this case" + ) + elif isinstance(snode.node, ir.Wait): # end of this comm op + overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 + cur_comm_node = None + else: # overlapped compute op + overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004 + overlap_log.debug( + f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: List["scheduler.BaseSchedulerNode"], +) -> List["scheduler.BaseSchedulerNode"]: + order = snodes + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + order = p(order) # type: ignore[operator] + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug(str(e)) + return order diff --git a/MLPY/Lib/site-packages/torch/_inductor/compile_fx.py b/MLPY/Lib/site-packages/torch/_inductor/compile_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f86d481e322d8a28fe8312d2b891fdae93124c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/compile_fx.py @@ -0,0 +1,1451 @@ +import contextlib +import functools +import logging +import os +import sys +import time +import warnings +from itertools import count + +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Tuple, + Union, +) +from unittest import mock + +from functorch.compile import min_cut_rematerialization_partition + +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo import ( + compiled_autograd, + config as dynamo_config, + logging as dynamo_logging, + utils as dynamo_utils, +) +from torch._dynamo.utils import ( + counters, + detect_fake_mode, + lazy_format_graph_code, + optimus_scuba_log, +) +from torch._functorch.aot_autograd import aot_export_module, make_boxed_func +from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache +from torch._inductor.cudagraph_utils import BoxedDeviceIndex + +from torch._inductor.debug import save_args_for_compile_fx_inner +from torch._inductor.utils import BoxedBool, count_tangents +from torch._logging import trace_structured +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import signpost_event +from torch.fx.passes.fake_tensor_prop import FakeTensorProp + +from .._dynamo.backends.common import aot_autograd +from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] +from ..fx.graph import _PyTreeCodeGen +from . 
import config, metrics +from .debug import DebugContext +from .decomposition import select_decomp_table +from .fx_passes.joint_graph import joint_graph_passes +from .fx_passes.post_grad import post_grad_passes, view_to_reshape +from .fx_passes.pre_grad import pre_grad_passes +from .graph import GraphLowering +from .ir import ExternKernelNode +from .utils import get_dtype_size, has_incompatible_cudagraph_ops, output_node +from .virtualized import V + +if config.is_fbcode(): + from torch._inductor.fb.utils import time_and_log +else: + # no-op decorator + def time_and_log(attr: str, extra_loggings: Optional[Dict[str, str]] = None): + return dynamo_utils.identity + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs") +ALIGNMENT = 16 + + +# copy_ fails when trying to write to tensors with memory overlap, +# for expanded dimensions (a dimension which used to have size 1 -> ?) +# we can select one element from that dimension and write to it +# to achieve writing to all values of that dimension of the input tensor +def get_expanded_dims(t): + if not isinstance(t, torch.Tensor): + return None + return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] + + +def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor: + for expanded_dim in expanded_dims: + t = torch.ops.aten.slice(t, expanded_dim, 0, 1) + return t + + +def complex_memory_overlap(t: torch.Tensor) -> bool: + # if torch._debug_has_internal_overlap thinks this tensor potentially has + # memory overlap internally, let's dig deeper to find out whether it's true. + t = index_expanded_dims(t, get_expanded_dims(t)) + if torch._debug_has_internal_overlap(t) != 0: + strides = t.stride() + sizes = t.shape + indices = list(range(len(strides))) + indices = [x for _, x in sorted(zip(strides, indices))] + for i in range(len(strides)): + prev_stride = 1 if i == 0 else strides[indices[i - 1]] + prev_size = 1 if i == 0 else sizes[indices[i - 1]] + if strides[indices[i]] < prev_stride * prev_size: + return True + return False + + +@functools.lru_cache(None) +def _step_logger(): + return dynamo_logging.get_step_logger(log) + + +@functools.lru_cache(None) +def _warn_tf32_disabled(): + if ( + torch.cuda.is_available() + and not torch.backends.cuda.matmul.allow_tf32 + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. " + "Consider setting `torch.set_float32_matmul_precision('high')` for better performance." 
+ ) + + +def _unlift_graph(mod, gm, graph_signature): + from torch.export.unflatten import _assign_attr, _AttrKind + + state_dict = {} + for name, param in mod.named_parameters(remove_duplicate=False): + state_dict[name] = param + _assign_attr( + param, + gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + for name, buffer in mod.named_buffers(remove_duplicate=False): + state_dict[name] = buffer + _assign_attr( + buffer, + gm, + name, + attr_kind=_AttrKind.BUFFER, + ) + + placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + lifted_inputs = [] + for node in placeholder_nodes: + node_name = node.name + if node_name in graph_signature.inputs_to_parameters: + lifted_inputs.append(graph_signature.inputs_to_parameters[node_name]) + elif node_name in graph_signature.inputs_to_buffers: + lifted_inputs.append(graph_signature.inputs_to_buffers[node_name]) + else: + assert node_name in graph_signature.user_inputs + lifted_inputs.append(None) + + from torch.export._unlift import _unlift + + outputs = list(gm.graph.nodes)[-1].args[0] + mutated_outputs = [] + for out in outputs: + if out in graph_signature.buffers_to_mutate: + mutated_outputs.append(graph_signature.buffers_to_mutate[out.name]) + else: + mutated_outputs.append(None) + + unlifted_gm = _unlift( + gm, + lifted_inputs, + mutated_outputs, + pytree.LeafSpec(), + None, + state_dict, + {}, + ) + return unlifted_gm + + +def _get_subgraph_names(gm): + for node in gm.graph.nodes: + if node.target == torch.ops.higher_order.cond: + true_subgraph_name = node.args[1].name + false_subgraph_name = node.args[2].name + yield true_subgraph_name + yield false_subgraph_name + + +def _recursive_pre_grad_passes(gm, example_inputs): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing None here + new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs) + + +def _recursive_joint_graph_passes(gm): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_joint_graph_passes(subgraph) + joint_graph_passes(gm) + + +def _recursive_post_grad_passes(gm, is_inference: bool = False): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) + + +def split_const_gm( + gm: torch.fx.GraphModule, +) -> Tuple[torch.fx.GraphModule, Dict[str, int]]: + """ + This function takes an GraphModule input "gm". + The gm will be split into 2 components, + 1) const_gm, which consists the subgraph of gm that can be constant folded. + 2) gm (being inplace modified,) which returns the graph after constant folding. + + const_output_index is a mapping of corresponding node name from gm to the + output index of const_gm. 
+ Returns (const_gm, const_output_index) + """ + from torch._inductor.constant_folding import ( + CONST_MODULE_TAG, + META_TAG, + MODULE_TAG, + replace_node_with_constant, + run_and_get_constant_graph, + ) + + const_gm = run_and_get_constant_graph(gm) + const_result = const_gm() + + const_outputs = { + x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0]) + } + + to_erase_node = [] + to_replace_node = [] + const_output_index = {} + for node in gm.graph.nodes: + if node.name in const_outputs: + to_replace_node.append(node) + elif node.meta[META_TAG] == CONST_MODULE_TAG: + to_erase_node.append(node) + + for node in to_replace_node: + new_const_name = "_FOLDED_CONST_" + node.name + replace_node_with_constant( + gm, + node, + const_result[const_outputs[node.name]], + new_const_name, + ) + const_output_index[new_const_name] = const_outputs[node.name] + for node in to_erase_node[::-1]: + if node.users: + for n in node.users: + assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty." + else: + gm.graph.erase_node(node) + gm.recompile() + + return const_gm, const_output_index + + +def is_tf32_warning_applicable(gm: torch.fx.GraphModule): + aten = torch.ops.aten + tf32_ops = { + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + } + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target in tf32_ops + and isinstance(node.meta.get("val", None), torch.Tensor) + and node.meta["val"].dtype == torch.float32 + and node.meta["val"].device.type == "cuda" + ): + return True + return False + + +@DebugContext.wrap +def count_bytes_inner( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + num_fixed: int = 0, + **kwargs, +): + shape_env = _shape_env_from_inputs(example_inputs) + fake_mode = fake_tensor_prop(gm, example_inputs) + + with V.set_fake_mode(fake_mode): + _recursive_post_grad_passes(gm, False) + + graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) + with V.set_graph_handler(graph), V.set_real_inputs(example_inputs): + graph.run(*example_inputs) + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.nodes_num_elem += nodes_num_elem + metrics.node_runtimes += node_runtimes + return make_boxed_func(gm.forward) + + +def fake_tensor_prop( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + force_allow_non_fake_inputs: bool = False, +): + """ + If we can not detect fake mode from the context of inputs, create one. + + The created fake mode will be returned. 
+ """ + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) + ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + return fake_mode + + +# pass config dict back to user +def get_patched_config_dict(config_patches=None) -> Dict[str, Any]: + with config.patch(config_patches): + return config.get_config_copy() + + +@DebugContext.wrap +@torch.utils._python_dispatch._disable_current_modes() +@time_and_log( + attr="compilation time (in seconds)", + extra_loggings={"config_dict": str(get_patched_config_dict())}, +) +# Need this decorator for compile_fx_inner even if we already have one for +# compile_fx. The reason is the compilation for backward graph may happen after +# compile_fx return and we may want to use the _LazyGraphModule for compiling +# the backward graph as well. +@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module) +@dynamo_utils.dynamo_timed(phase_name="inductor_compile") +def compile_fx_inner( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + num_fixed: int = 0, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + user_visible_outputs: FrozenSet[str] = frozenset(), + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + """ + Inductor API that compiles a single graph. + + If you change the argument list for this function, make sure you + also update the call to save_args_for_compile_fx_inner below accordingly. + """ + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: + # trigger the real recompilation for _LazyGraphModule before returning + # the forward method. 
+ from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(gm) + return make_boxed_func(gm.forward) + + assert isinstance( + next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) + ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + + if config.save_args: + save_args_for_compile_fx_inner( + gm, + example_inputs, + cudagraphs=cudagraphs, + num_fixed=num_fixed, + is_backward=is_backward, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + is_inference=is_inference, + boxed_forward_device_index=boxed_forward_device_index, + user_visible_outputs=user_visible_outputs, + layout_opt=layout_opt, + ) + + if cudagraphs is None: + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # Inputs to fx_codegen_and_compile + # Anything that affects codegen should go here, so if the signature + # of fx_codegen_and_compile changes, the dict should be updated accordingly + graph_kwargs = { + "cudagraphs": cudagraphs, + "num_fixed": num_fixed, + "is_backward": is_backward, + "graph_id": graph_id, + "cpp_wrapper": cpp_wrapper, + "aot_mode": aot_mode, + "is_inference": is_inference, + "user_visible_outputs": user_visible_outputs, + "layout_opt": layout_opt, + "extern_node_serializer": extern_node_serializer, + } + + start = time.time() + + if config.fx_graph_cache and not aot_mode: + compiled_graph = FxGraphCache.load( + fx_codegen_and_compile, gm, example_inputs, graph_kwargs + ) + else: + compiled_graph = fx_codegen_and_compile( + gm, example_inputs, **graph_kwargs # type: ignore[arg-type] + ) + + log.debug("FX codegen and compilation took %.3fs", time.time() - start) + + # check cudagraph disabling reasons from inductor lowering + if cudagraphs and compiled_graph.disabled_cudagraphs_reason: + perf_hint_log.warning( + "skipping cudagraphs due to %s", compiled_graph.disabled_cudagraphs_reason + ) + BoxedBool.disable(cudagraphs) + + # Return the output strides to the caller via TracingContext + context = torch._guards.TracingContext.try_get() + if context is not None and context.output_strides is not None: + assert len(context.output_strides) == 0 + context.output_strides.extend(compiled_graph.output_strides) + + if aot_mode: + return compiled_graph + + if cudagraphs: + # output args are tuple of first argument + output = output_node(gm) + assert len(output.args) == 1 + stack_traces = [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + + complex_memory_overlap_inputs = any( + complex_memory_overlap(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ) + + from torch._inductor.cudagraph_utils import check_for_mutation + + has_mutation_str = check_for_mutation(gm, compiled_graph, num_fixed) + has_mutation = has_mutation_str is not None + + if has_mutation: + compiled_graph.disabled_cudagraphs_reason = has_mutation_str + + cudagraph_tests = [ + (not has_mutation, "mutated inputs"), + (not has_incompatible_cudagraph_ops(gm), "incompatible ops"), + (not complex_memory_overlap_inputs, "complex memory overlap"), + ( + all( + isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs + ), + "non-Tensor inputs", + ), + ] + cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b] + + if not cudagraph_fail_reasons: + if not config.triton.cudagraph_trees: + # Force specialize all inputs so that CUDA graphs will work + for t in example_inputs: + if isinstance(t, torch.SymInt): + int(t) # guard + + if ( + boxed_forward_device_index is not None + 
and not is_inference + and not is_backward + ): + boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) + + compiled_graph.current_callable = cudagraphify( + compiled_graph.get_current_callable(), + example_inputs, + static_input_idxs=range(num_fixed), + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(compiled_graph.constants.values()), + ) + else: + BoxedBool.disable(cudagraphs) + + # See [Backward Generation Handling] + # if cudagraph'd the forward and set the device, we need to let the cudagraph manager + # know we are we running the backward even if we will not run it in cudagraphs + if is_backward and config.triton.cudagraph_trees: + assert boxed_forward_device_index is not None + assert boxed_forward_device_index.value is not None + compiled_graph_callable = compiled_graph.get_current_callable() + + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_forward_device_index.value, create_if_none_exists=False + ) + # should already exist from forward + assert manager is not None + + def compiled_artifact(new_inputs): + manager.set_to_running_backward() + return compiled_graph_callable(new_inputs) + + compiled_graph.current_callable = compiled_artifact + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + perf_hint_log.warning(compiled_graph.disabled_cudagraphs_reason) + else: + perf_hint_log.warning( + "skipping cudagraphs due to %s", cudagraph_fail_reasons + ) + + # cudagraphs does its own aligning of inputs + if not cudagraphs: + new_callable = align_inputs( + compiled_graph.get_current_callable(), example_inputs, range(num_fixed) + ) + if new_callable is not compiled_graph.get_current_callable(): + compiled_graph.current_callable = new_callable + + _step_logger()( + logging.INFO, + "torchinductor done compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + + # aot autograd needs to know to pass in inputs as a list + compiled_graph._boxed_call = True + return compiled_graph + + +def fx_codegen_and_compile( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + num_fixed: int = 0, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + user_visible_outputs: FrozenSet[str] = frozenset(), + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, +) -> Union[CompiledFxGraph, str]: + if is_tf32_warning_applicable(gm): + _warn_tf32_disabled() + + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000)) + + _step_logger()( + logging.INFO, + "torchinductor compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + V.debug.fx_graph(gm, example_inputs) + # TODO: Should we actually dump this? It should be redundant with the aot + # structured logs... + # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False)) + + shape_env = _shape_env_from_inputs(example_inputs) + + # Convert view to reshape in the graph. This is necessary primarily for + # layout optimization. 
Do it unconditionally for uniformity. + # + # It's needed because when we do layout optimization, an contiguous tensor + # in eager mode may becomes a channels last tensor. A view op previously + # can be applied to the contiguous tensor may not be able to be applied + # on the channels tensor any more. An error like + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + # will be printed. + # + # Replace view op to reshape op in this case. + # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this. + # + # Also this has to be done before FakeTensorProp below to avoid the failed + # .view() call. + view_to_reshape(gm) + + # It is safe to run FakeTensorProp under no_grad because by the time + # we're in inductor, we assume that AOTAutograd has already "taken care" + # of autograd, so there should be no more autograd-related API's in the + # graph. + with torch.no_grad(): + fake_mode = fake_tensor_prop(gm, example_inputs) + + # pattern matcher passes might not preserve striding information + # on node.meta["val"]. if in the future we rely on these being + # correct we will need to fix. + + with V.set_fake_mode(fake_mode): + # has some issues with memory in training + _recursive_post_grad_passes(gm, is_inference=is_inference) + V.debug.fx_graph_transformed(gm, example_inputs) + post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm)) + trace_structured( + "inductor_post_grad_graph", + payload_fn=lambda: gm.print_readable(print_output=False), + ) + optimus_scuba_log["inductor_post_grad"] = counters["inductor"] + signpost_event( + "optimus", + "compile_fx.post_grad_passes", + optimus_scuba_log, + ) + + with V.set_fake_mode(fake_mode): + const_output_index = None + const_graph = None + const_code = None + + if aot_mode and config.aot_inductor.use_runtime_constant_folding: + const_gm, const_output_index = split_const_gm(gm) + + const_graph = GraphLowering( + const_gm, + example_inputs=[], + shape_env=shape_env, + num_static_inputs=num_fixed, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_const_graph=True, + ) + with V.set_graph_handler(const_graph): + assert cpp_wrapper, "AOT mode only supports C++ wrapper" + const_graph.run() + + const_code, _ = const_graph.codegen_with_cpp_wrapper() + + graph = GraphLowering( + gm, + # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. + # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, + # we currently use fake tensors and defake them later. 
+ example_inputs=example_inputs, + shape_env=shape_env, + num_static_inputs=num_fixed, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + user_visible_outputs=user_visible_outputs, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + const_output_index=const_output_index, + const_code=const_code, + const_module=const_graph, + ) + with V.set_graph_handler(graph): + graph.run(*example_inputs) + output_strides: List[Optional[Tuple[int, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext + for out in graph.graph_outputs: + if hasattr(out, "layout"): + output_strides.append( + tuple( + V.graph.sizevars.size_hint(s) for s in out.layout.stride + ) + ) + else: + output_strides.append(None) + + metrics_helper = metrics.CachedMetricsHelper() + compiled_fn = graph.compile_to_fn() + + if V.aot_compilation is True: + return compiled_fn + + if cudagraphs and not V.graph.disable_cudagraphs_reason: + from torch._inductor.cudagraph_utils import ( + check_lowering_disable_cudagraph, + ) + + V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph( + V.graph.device_node_mapping + ) + + compiled_graph = CompiledFxGraph( + compiled_fn, + graph, + output_strides, + V.graph.disable_cudagraphs_reason, + metrics_helper.get_deltas(), + ) + + return compiled_graph + + +def clone_preserve_strides(x: torch.Tensor): + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.as_strided(x, (needed_size,), (1,)).clone() + return torch.as_strided(buffer, x.size(), x.stride()) + + +def copy_misaligned_inputs( + new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int] +) -> None: + for i in check_inputs_idxs: + if new_inputs[i].data_ptr() % ALIGNMENT: + new_inputs[i] = clone_preserve_strides(new_inputs[i]) + + +def get_input_idxs_to_check( + inputs: Union[List[torch.Tensor], Sequence[int]], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + def is_aligned(storage_offset, dtype): + return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0 + + ids_to_check = [] + for i, input in enumerate(inputs): + if ( + isinstance(input, torch.Tensor) + and ( + i not in static_input_idxs + or not is_aligned(input.storage_offset(), input.dtype) + ) + and input.device.type == "cuda" + ): + ids_to_check.append(i) + return ids_to_check + + +def align_inputs_from_check_idxs( + model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int] +): + if len(inputs_to_check) == 0: + return model + + def run(new_inputs): + copy_misaligned_inputs(new_inputs, inputs_to_check) + return model(new_inputs) + + return run + + +def align_inputs( + model: Callable[[List[torch.Tensor]], Any], + inputs: List[torch.Tensor], + static_input_idxs: Sequence[int] = (), +): + inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs) + return align_inputs_from_check_idxs(model, inputs_to_check) + + +@dynamo_utils.dynamo_timed +def cudagraphify( + model: torch.fx.GraphModule, + inputs: List[torch.Tensor], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + stack_traces: List[Optional[str]], + is_backward: bool, + is_inference: bool, + constants: Tuple[torch.Tensor, ...] 
= (), +): + from torch._inductor.cudagraph_trees import ( + cudagraphify_impl as new_cudagraphify_impl, + ) + + cudagraphify_fn: Callable[..., Any] + if config.triton.cudagraph_trees: + cudagraphify_fn = functools.partial( + new_cudagraphify_impl, + device_index=device_index, + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=constants, + ) + else: + cudagraphify_fn = cudagraphify_impl + + # if using fake tensors, defer cudagraphs until we get real inputs at runtime + if not any(isinstance(inp, FakeTensor) for inp in inputs): + return cudagraphify_fn(model, inputs, static_input_idxs) + + compiled_fn = None + + def run(new_inputs): + nonlocal compiled_fn + if compiled_fn is None: + with dynamo_utils.preserve_rng_state(): + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) + return compiled_fn(new_inputs) + + return run + + +def remove_unaligned_input_idxs( + inputs: Union[List[torch.Tensor], Sequence[int]], + static_input_idxs: Sequence[int], +): + """ + We require all inputs to be aligned, so introduce a copy for any + that aren't. + """ + aligned_static_input_idxs = [] + for idx, input in zip(static_input_idxs, inputs): + if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: + aligned_static_input_idxs.append(idx) + if len(aligned_static_input_idxs) != len(static_input_idxs): + return aligned_static_input_idxs + return static_input_idxs + + +def static_input(x: torch.Tensor): + """ + Copy and input while preserving strides + """ + # TODO(jansel): figure out why this version doesn't work: + # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device) + return torch.as_strided(buffer, x.size(), x.stride()) + + +def index_expanded_dims_and_copy_( + dst: torch.Tensor, + src: torch.Tensor, + expanded_dims: List[int], +): + "Index into expanded dimensions of both dst and src then copy_" + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + dst.copy_(src) + + +def cudagraphify_impl( + model: torch.fx.GraphModule, + inputs: List[torch.Tensor], + static_input_idxs: Sequence[int] = (), +): + """ + Assumes inputs[static_input_idxs[i]] are always the same memory address + """ + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + assert isinstance(inputs, list) + + inps_expanded_dims = [ + get_expanded_dims(x) if idx not in static_input_idxs else [] + for idx, x in enumerate(inputs) + ] + + # allocate static tensor inputs + static_inputs = [ + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + for idx, x in enumerate(inputs) + ] + + # copy over input values for fresh allocations + for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): + if isinstance(x, torch.Tensor) and idx not in static_input_idxs: + index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model + with torch.cuda.stream(stream): + model(list(static_inputs)) + stream.synchronize() + 
torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): + static_outputs = model(list(static_inputs)) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + if config.size_asserts: + + def run(new_inputs): + assert len(static_inputs) == len(new_inputs) + for idx, (dst, src, expanded_dims) in enumerate( + zip(static_inputs, new_inputs, inps_expanded_dims) + ): + if not isinstance(dst, torch.Tensor): + pass + elif idx in static_input_idxs: + assert dst.data_ptr() == src.data_ptr() + else: + # TODO - could make one single op of multiple slices + # and avoid dispatch. + # Could also pre-index the `dst` tensors + index_expanded_dims_and_copy_(dst, src, expanded_dims) + new_inputs.clear() + graph.replay() + return static_outputs + + else: + copy_indices = [ + idx for idx in range(len(static_inputs)) if idx not in static_input_idxs + ] + + def run(new_inputs): + for idx in copy_indices: + expanded_dims = inps_expanded_dims[idx] + index_expanded_dims_and_copy_( + static_inputs[idx], new_inputs[idx], expanded_dims + ) + new_inputs.clear() + graph.replay() + return static_outputs + + return align_inputs_from_check_idxs(run, check_input_idxs) + + +def compile_fx_aot( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, +): + config_patches: Dict[str, Any] = ( + {"cpp_wrapper": True} + if config_patches is None + else {**config_patches, "cpp_wrapper": True} + ) + if ( + "aot_inductor.output_path" not in config_patches + and not config.aot_inductor.output_path + ): + config_patches = { + **config_patches, + "aot_inductor.output_path": code_hash(model_.code), + } + + extern_node_serializer = config_patches.pop("extern_node_serializer", None) + with V.set_aot_compilation(True): + compiled_lib_path = compile_fx( + model_, + example_inputs_, + inner_compile=functools.partial( + inner_compile, + aot_mode=True, + extern_node_serializer=extern_node_serializer, + ), + config_patches=config_patches, + ) + assert os.path.exists( + compiled_lib_path + ), f"AOTInductor compiled library does not exist at {compiled_lib_path}" + return compiled_lib_path + + +_graph_counter = count(0) + + +def fw_compiler_freezing( + aot_autograd_model: torch.fx.GraphModule, + aot_example_inputs: List[torch.Tensor], + dynamo_model: torch.fx.GraphModule, + num_example_inputs: int, + inner_compile: Callable[..., Any], + cudagraphs: BoxedBool, + graph_id: int, + forward_device: BoxedDeviceIndex, +): + from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze + + # partition_fn won't be called + _recursive_joint_graph_passes(aot_autograd_model) + + layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) + if layout_opt: + # make sure meta['val'] is properly setup + fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) + convert_conv_weights_to_channels_last(aot_autograd_model) + + opt_model, preserved_arg_indices = freeze( + dynamo_model, + aot_autograd_model, + aot_example_inputs, # type: ignore[arg-type] + ) + + aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] + num_fixed = len(preserved_arg_indices) - num_example_inputs + + fake_mode = detect_fake_mode(aot_example_inputs) + + # for freezing, all graph outputs should be user visible + *_, 
model_outputs_node = opt_model.graph.nodes + model_outputs = model_outputs_node.args[0] + user_visible_outputs = [ + n.name for n in model_outputs if isinstance(n, torch.fx.Node) + ] + + # constant params will be real tensors, not fake + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None: + params_flat = tracing_context.params_flat + assert params_flat is not None + for i in range(len(params_flat)): + if i not in preserved_arg_indices: + params_flat[i] = None + + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): + optimized_function = inner_compile( + opt_model, + aot_example_inputs, + num_fixed=num_fixed, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=True, + boxed_forward_device_index=forward_device, + layout_opt=layout_opt, + user_visible_outputs=user_visible_outputs, + ) + + # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper + # that drops constant-ified params + if V.aot_compilation is True: + return optimized_function + + def wrapper(args): + args_new = [args[i] for i in preserved_arg_indices] + args.clear() + return optimized_function(args_new) + + wrapper._boxed_call = True # type: ignore[attr-defined] + + return wrapper + + +@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module) +def compile_fx( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile: Callable[..., Any] = compile_fx_inner, + config_patches: Optional[Dict[str, Any]] = None, + decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None, +): + """Main entrypoint to a compile given FX graph""" + if config_patches: + with config.patch(config_patches): + return compile_fx( + model_, + example_inputs_, + # need extra layer of patching as backwards is compiled out of scope + inner_compile=config.patch(config_patches)(inner_compile), + decompositions=decompositions, + ) + + if config.cpp_wrapper: + with config.patch( + { + "cpp_wrapper": False, + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, + "triton.store_cubin": True, + } + ), V.set_real_inputs(example_inputs_): + inputs_ = example_inputs_ + if isinstance(model_, torch.fx.GraphModule): + fake_inputs = [ + node.meta.get("val") + for node in model_.graph.nodes + if node.op == "placeholder" + ] + if all(v is not None for v in fake_inputs): + # Validate devices before switching to fake tensors. + for idx, fi, i in zip(count(), fake_inputs, inputs_): + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." 
+ ) + inputs_ = fake_inputs + return compile_fx( + model_, + inputs_, + inner_compile=functools.partial(inner_compile, cpp_wrapper=True), + decompositions=decompositions, + ) + + recursive_compile_fx = functools.partial( + compile_fx, + inner_compile=inner_compile, + decompositions=decompositions, + ) + + if not graph_returns_tuple(model_): + return make_graph_return_tuple( + model_, + example_inputs_, + recursive_compile_fx, + ) + + if isinstance(model_, torch.fx.GraphModule): + if isinstance(model_.graph._codegen, _PyTreeCodeGen): + # this graph is the result of dynamo.export() + return handle_dynamo_export_graph( + model_, + example_inputs_, + recursive_compile_fx, + ) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + optimus_scuba_log["inductor_pre_grad"] = counters["inductor"] + signpost_event( + "optimus", + "compile_fx.pre_grad_passes", + optimus_scuba_log, + ) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_): + return flatten_graph_inputs( + model_, + example_inputs_, + recursive_compile_fx, + ) + + assert not config._raise_error_for_testing + num_example_inputs = len(example_inputs_) + cudagraphs = BoxedBool(config.triton.cudagraphs) + forward_device = BoxedDeviceIndex(None) + + graph_id = next(_graph_counter) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + + @dynamo_utils.dynamo_timed + def fw_compiler_base( + model: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + is_inference: bool, + ): + if is_inference: + # partition_fn won't be called + _recursive_joint_graph_passes(model) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + user_visible_outputs = set() + + if config.keep_output_stride: + *_, model_outputs_node = model.graph.nodes + assert model_outputs_node.op == "output" + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + if isinstance(model_, torch.fx.GraphModule): + *_, orig_model_outputs_node = model_.graph.nodes + assert orig_model_outputs_node.op == "output" + orig_model_outputs, _ = pytree.tree_flatten( + orig_model_outputs_node.args + ) + num_orig_model_outputs = len(orig_model_outputs) + else: + num_orig_model_outputs = num_model_outputs + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. + orig_output_end_idx = original_output_start_index + num_orig_model_outputs + # Sanity chec: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. 
Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + user_visible_outputs = { + n.name + for n in model_outputs[original_output_start_index:orig_output_end_idx] + if isinstance(n, torch.fx.Node) + } + + return inner_compile( + model, + example_inputs, + num_fixed=fixed, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=is_inference, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) + + fw_compiler = functools.partial(fw_compiler_base, is_inference=False) + + if config.freezing and not torch.is_grad_enabled(): + inference_compiler = functools.partial( + fw_compiler_freezing, + dynamo_model=model_, + num_example_inputs=num_example_inputs, + inner_compile=inner_compile, + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, + ) + else: + inference_compiler = functools.partial(fw_compiler_base, is_inference=True) + + def partition_fn(graph, joint_inputs, **kwargs): + _recursive_joint_graph_passes(graph) + return min_cut_rematerialization_partition( + graph, joint_inputs, **kwargs, compiler="inductor" + ) + + @dynamo_utils.dynamo_timed + @dynamo_utils.maybe_cprofile + def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + fixed = count_tangents(model) + return inner_compile( + model, + example_inputs, + num_fixed=fixed, + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + ) + + # TODO: can add logging before/after the call to create_aot_dispatcher_function + # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func + # once torchdynamo is merged into pytorch + + fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode( + allow_non_fake_inputs=True + ) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + if V.aot_compilation is True: + gm, graph_signature = aot_export_module( + model_, example_inputs_, trace_joint=False, decompositions=decompositions + ) + unlifted_gm = _unlift_graph(model_, gm, graph_signature) + if "dynamo_flat_name_to_original_fqn" in model_.meta: + unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[ + "dynamo_flat_name_to_original_fqn" + ] + with V.set_fake_mode(fake_mode), compiled_autograd.disable(): + return inference_compiler(unlifted_gm, example_inputs_) + + with V.set_fake_mode(fake_mode), torch._guards.tracing( + tracing_context + ), compiled_autograd.disable(): + return aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + )(model_, example_inputs_) + + +def _shape_env_from_inputs(inputs: List[torch.Tensor]): + shape_env = None + fake_mode = detect_fake_mode(inputs) + + # TODO(voz): It would be nice to enable this assert, but there are lots of tests that + # pass in real inputs for now. + # if len(inputs) > 0: + # assert fake_mode is not None, breakpoint() + + if fake_mode is not None: + return fake_mode.shape_env + + # When there are no tensor inputs, get shape_env from the first SymInt. + for input in inputs: + if isinstance(input, torch.SymInt): + return input.node.shape_env + + # TODO(voz): Should we always have one anyway? 
+ return None + + +def graph_returns_tuple(gm: torch.fx.GraphModule): + """True if a FX graph returns a tuple""" + if not isinstance(gm, torch.fx.GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args, **kwargs): + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): + """ + Mutate inputs so that they are flat and wrap gm such that it + accepts those inputs. This is only needed for graphs not created + by torchdynamo that take bumpy inputs. + """ + inputs, spec = pytree.tree_flatten(inputs) + + class GmWrapper(torch.nn.Module): + def __init__(self): + super().__init__() + self.gm = gm + + def forward(self, *args): + args: List[Any] = list(args) + return self.gm(*pytree.tree_unflatten(args, spec)) + + compiled_fn = compile_gm(GmWrapper(), inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args): + # note this doesn't check the spec, assuming it is the same + return compiled_fn(*pytree.arg_tree_leaves(*args)) + + return wrapper + + +def handle_dynamo_export_graph( + gm: torch.fx.GraphModule, + inputs: List[torch.Tensor], + compile_gm: Callable[..., Any], +): + """ + `torch._dynamo.export` embeds pytrees in the FX graph codegen object, + convert that to a normal FX graph so inductor can compile it. 
+ """ + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) + + @functools.wraps(compiled_fn) + def wrapper(*args): + return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) + + return wrapper diff --git a/MLPY/Lib/site-packages/torch/_inductor/config.py b/MLPY/Lib/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..845c9ad6bae9e55dc247e857d57672327bc5ff0e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/config.py @@ -0,0 +1,752 @@ +import os # noqa: C101 +import sys +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING + +import torch + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +# add some debug printouts +debug = False + +# add inf and NaN checkers +debug_check_inf_and_nan = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# use fx aot graph codegen cache +fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1" + +# use cpp wrapper instead of python wrapper +cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# codegen cpp wrapper code in an ABI compatible mode +abi_compatible = ( + os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1" +) + +c_shim_version = os.environ.get( + "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2" +) + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool for both intermediates and outputs +memory_pool = os.environ.get("TORCHINDUCTOR_MEMORY_POOL", "intermediates") + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates +epilogue_fusion = True + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# def my_custom_pre_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... +# +# def my_custom_post_pass(graph: torch.fx.graph.Graph): +# # my custom graph optimization pass +# ... 
+# +# torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass +# torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass +post_grad_custom_pre_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None +post_grad_custom_post_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pregrad pass. Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Optimize away split cat patterns (Experimental) +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad group/batch fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { + "batch_linear": {}, + "batch_linear_lhs": {}, + "batch_layernorm": {}, + "batch_tanh": {}, + "batch_relu": {}, + "batch_sigmoid": {}, +} + +# Post grad group/batch fusion and options, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: Dict[str, Dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down RBLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, +# enable torch._inductor.kernel.mm.tuned_mixed_mm fused kernel. +# Autotune will compare perf with normal cast->then->mm option +use_mixed_mm = False + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers,about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation. +# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: Dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# for pattern torch.mm(a, b.to(dtype)) with cuda tensors, always use +# torch._inductor.kernel.mm.tuned_mixed_mm's fused kernel. +# Autotune will not compare with normal cast->then->mm option. 
+# (if force_mixed_mm is true, the use_mixed_mm flag will be ignored) +force_mixed_mm = False + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +reorder_for_compute_comm_overlap_passes = [ + "reorder_compute_for_overlap", + "sink_waits", + "raise_comms", +] + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# enable autotune local cache +use_autotune_local_cache = True + +# enable autotune remote cache +use_autotune_remote_cache = ( + os.environ.get("TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE") == "1" +) + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor. +# CUTLASS: Cutlass templates and kernels. 
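+# Illustrative sketch (not executed here): the backend list can be narrowed or
+# widened either through the environment variable read just below, e.g.
+#   TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=ATEN,TRITON,CUTLASS
+# or by assigning to this module before compiling, mirroring the
+# post_grad_custom_pre_pass example earlier in this file:
+#   torch._inductor.config.max_autotune_gemm_backends = "ATEN,TRITON"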
+max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" +).upper() + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. +keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since kernel with mixed layout inputs may run much slower then one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion. 
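+# A minimal usage sketch (the environment variables are the ones read immediately
+# below; "train.py" is a placeholder for any torch.compile workload):
+#   TORCHINDUCTOR_DEBUG_FUSION=1 TORCHINDUCTOR_BENCHMARK_FUSION=1 python train.py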
+debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# max number of inputs to generate cat as a pointwise op with masked laods +max_pointwise_cat_inputs = 8 + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# Enable split reductions for better utilization when the dimension +# being reduced over is large (by splitting it) +split_reductions = True + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# constant folding on the joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# The multiprocessing start method to use for inductor workers in the codecache. +# TODO: fork is not safe in a multithreaded environment, we should evaluate changing +# the default to spawn. +worker_start_method = "fork" + + +def decide_compile_threads(): + """ + Here are the precedence to decide compile_threads + 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's win32 platform or it's a fbcode build + 3. 
decide by the number of CPU cores + """ + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + elif sys.platform == "win32" or is_fbcode(): + return 1 + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + return min(32, cpu_count) + + +compile_threads = decide_compile_threads() + +# gemm autotuning global cache dir +if is_fbcode(): + from libfb.py import parutil + + try: + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except ValueError: + global_cache_dir = None +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output = os.environ.get("TORCHINDUCTOR_PROFILE_OUTPUT", None) + +# TODO: remove later +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests +# should be run with this flag both on and off to make sure we have coverage. +allow_stack_allocation: bool = ( + os.environ.get("TORCHINDUCTOR_STACK_ALLOCATION", "1") == "1" +) + +# Enables an alternate DSO interface (the "minimal ArrayRef interface") intended +# to maximize performance for use cases that it can accommodate at the expense of +# generality. In brief: +# - inputs and outputs are ArrayRefTensor (note that strides are required, but the +# tensor must be contiguous) +# - constant handling is unchanged because it is not a per-inference-iteration bottleneck +# +# When the DSO is generated in this mode, the usual interface will also be supported, +# but performance for that interface may be degraded. 
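+# A hedged sketch of opting in, using the config_patches argument of
+# compile_fx_aot defined in compile_fx.py within this same diff ("gm" and
+# "example_inputs" are placeholders for the caller's graph module and inputs):
+#   from torch._inductor.compile_fx import compile_fx_aot
+#   lib_path = compile_fx_aot(
+#       gm,
+#       example_inputs,
+#       config_patches={"use_minimal_arrayref_interface": True},
+#   )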
+use_minimal_arrayref_interface: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + + +# config specific to codegen/cpp.py +class cpp: + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = True + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. + dynamic_threads = False + + simdlen: Optional[int] = None + min_chunk_size = 4096 + cxx = ( + None, # download gcc12 from conda-forge if conda is installed + # "g++-12", + # "g++-11", + # "g++-10", + # "clang++", + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + # "g++.par", + ) + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = False + + # enable weight prepacking to get a better performance; may lead to large memory footprint + weight_prepack = True + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. + vec_isa_ok: Optional[bool] = None + + # similar to config.triton.descriptive_names + descriptive_names = "original_aten" + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = 16 + + # Make scatter_reduce fallback when reduce is sum to avoid performance regression + # using atomic_add. + fallback_scatter_reduce_sum = True + + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = False + + # Use ffp-contract when compiling + enable_floating_point_contract_flag = False + + +# config specific to codegen/triton.py +class triton: + # Use cudagraphs on output code + cudagraphs = False + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. + debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # limit tiling dimensions + max_tiles = 2 + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # should we stop a fusion to allow better tiling? + tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. 
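+    # Illustrative usage sketch (the env vars are the ones read below and in the
+    # trace class; "train.py" is a placeholder for any torch.compile workload):
+    #   TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_COMPILE_DEBUG=1 python train.py
+    # which makes individual kernels easier to tell apart in the debug trace.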
+ unique_kernel_names = os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1" + + # should we put op names in kernel names + # False: No special names (just triton__1, triton__2, etc.) + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names = "original_aten" + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # 0/False: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel = int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")) + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = True + + # theses are not enforced, but they are used by asserts in triton_heuristics.py + # NOTE: mobilevit_s in timm_models required X to be set to the higher value 2048 + + # Max RBLOCK will be large for multi-kernel since we do more aggressive + # persistent reduction. + max_block = { + "X": 2048, + "Y": 1024, + "Z": 1024, + "R": 4096 * (16 if multi_kernel else 1), + } + + # Minimum RBLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config spilling a small amount + # of registers being benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/openai/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. + # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/store + use_block_ptr = False + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + +class aot_inductor: + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. + output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + + # Serialized tree spec for flattening inputs + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + serialized_out_spec = "" + + # flag to decide whether to create a submodule for constant graph. + use_runtime_constant_folding: bool = False + + +class cuda: + # CUDA arch to use for CUDA template kernel compilation. + # e.g. 
"70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2)CUDACXX environment variable + # 3)CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # If set to True, it will ensure that only GEMM ops capable of + # epilogue fusion via CUTLASS Epilogue Visitor Trees ( EVT ) + # are enabled for the CUTLASS backend. + cutlass_only_evt_capable_ops: bool = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # Save debug information to a temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like lables + # to workaround the above failure. 
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overriden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results: bool = False + + +_save_config_ignore = { + # workaround: "Can't pickle " + "trace.upload_tar", +} + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + +from torch.utils._config_module import install_config_module + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/MLPY/Lib/site-packages/torch/_inductor/constant_folding.py b/MLPY/Lib/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..31b0a71b8008bd317e56c6864c1ad65901da8fe5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,264 @@ +import collections +from typing import Any, Callable, Dict, Optional + +import torch +import torch.utils._pytree as pytree + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm, + skip_constructors=False, + ): + super().__init__(gm) + self.node_replacements: Dict[torch.fx.Node, Any] = {} + self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.node.Node): + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) + + for node in reversed(self.module.graph.nodes): + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + 
pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + pytree.tree_map_only(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if self.unknown_value in flattened_inputs: + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self): + env = {} + for n in self.module.graph.nodes: + if n.op == "placeholder": + env[n] = self.unknown_value + return super().run(initial_env=env) + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.nodes: + if node.op == "get_attr" and len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_graph_tag(gm: 
torch.fx.GraphModule): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.nodes: + if node.op == "get_attr": + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/MLPY/Lib/site-packages/torch/_inductor/coordinate_descent_tuner.py b/MLPY/Lib/site-packages/torch/_inductor/coordinate_descent_tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..ba64b1df60fe71faf548fbf00e56fa7c53bb6907 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/coordinate_descent_tuner.py @@ -0,0 +1,315 @@ +import copy +import itertools +import logging +from typing import Callable, Optional + +from torch.utils._triton import has_triton +from .utils import red_text, triton_config_to_hashable + +if has_triton(): + import triton +else: + triton = None + +from . import config as inductor_config + +log = logging.getLogger(__name__) + + +def get_field(config, name): + if name == "num_warps": + return config.num_warps + elif name == "num_stages": + return config.num_stages + else: + return config.kwargs.get(name, None) + + +def set_field(config, name, value): + if name == "num_warps": + config.num_warps = value + elif name == "num_stages": + config.num_stages = value + else: + config.kwargs[name] = value + + +class CoordescTuner: + """ + The coordinate descent tuner. Tune one field/coordinate at a time. + + TODO will it be necessary to tune multiple fields simultaneously. + + + TODO: what if both increasing and decreasing a field can improve perf. + i.e., there are multiple local optima.. 
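+
+    Illustrative sketch of one descent step (the concrete numbers are only an
+    example): starting from a config with XBLOCK=64, the tuner benchmarks the
+    power-of-two neighbours XBLOCK=32 and XBLOCK=128 (see get_neighbour_values);
+    for num_stages the step is +/-1 rather than halving/doubling. Whenever a
+    neighbour beats the incumbent by more than the 0.1% threshold in
+    has_improvement, it becomes the new best config, and the outer loop in
+    autotune() repeats until no tunable field improves further.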
+ """ + + def __init__(self, is_mm=False, name="unknown", size_hints=None): + self.is_mm = is_mm # we will tune num_stages for mm + self.cached_benchmark_results = {} + self.name = name + self.size_hints = size_hints + + def get_xmax(self): + xmax = inductor_config.triton.max_block["X"] + if self.size_hints and len(self.size_hints) > 0: + xmax = min(xmax, self.size_hints[0]) + return xmax + + def get_ymax(self): + ymax = inductor_config.triton.max_block["Y"] + if self.size_hints and len(self.size_hints) > 1: + ymax = min(ymax, self.size_hints[1]) + return ymax + + def get_zmax(self): + zmax = inductor_config.triton.max_block["Z"] + if self.size_hints and len(self.size_hints) > 2: + zmax = min(zmax, self.size_hints[2]) + return zmax + + def get_rmax(self): + if self.size_hints and len(self.size_hints) > 0: + return self.size_hints[-1] # the last one is for reduction + else: + # large enough. We should not pick this large RBLOCK anyway + return 2**30 + + def get_warpsmax(self): + # Currently, CUDA has a maximum of 1024 threads, so 32 is the max + # number of warps. + return 1024 // 32 + + def cache_benchmark_result(self, config, timing): + self.cached_benchmark_results[triton_config_to_hashable(config)] = timing + + def lookup_in_cache(self, config): + return self.cached_benchmark_results.get(triton_config_to_hashable(config)) + + def call_func(self, func, config): + found = self.lookup_in_cache(config) + if found is not None: + log.debug(" CACHED") + return found + timing = func(config) + self.cache_benchmark_result(config, timing) + return timing + + @property + def tunable_fields(self): + out = [ + "XBLOCK", + "YBLOCK", + "ZBLOCK", + # NOTE: we should not tune RBLOCK for persistent reduction. + # We rely on the fact that persistent reduction's triton.Config + # does not have the RBLOCK field to guarantee that. + "RBLOCK", + # the following 3 are for mm + "BLOCK_M", + "BLOCK_N", + "BLOCK_K", + "num_warps", + ] + if self.is_mm: + out.append("num_stages") + + return out + + def value_too_large(self, name, val): + if name == "XBLOCK": + return val > self.get_xmax() + if name == "YBLOCK": + return val > self.get_ymax() + if name == "ZBLOCK": + return val > self.get_zmax() + if name == "RBLOCK": + return val > self.get_rmax() + if name == "num_warps": + return val > self.get_warpsmax() + + return False + + def get_neighbour_values(self, name, orig_val, radius=1, include_self=False): + """ + Get neighbour values in 'radius' steps. The original value is not + returned as it's own neighbour. + """ + assert radius >= 1 + + def update(cur_val, inc=True): + if name == "num_stages": + if inc: + return cur_val + 1 + else: + return cur_val - 1 + else: + if inc: + return cur_val * 2 + else: + return cur_val // 2 + + out = [] + # increment loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, True) + if self.value_too_large(name, cur_val): + break + out.append(cur_val) + + # decrement loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, False) + if cur_val <= 0: + break + out.append(cur_val) + + if include_self: + out.append(orig_val) + return out + + @staticmethod + def has_improvement(baseline, test): + threshold = 0.001 # 0.1% + return test is not None and test < baseline * (1 - threshold) + + def check_all_tuning_directions( + self, + func: Callable[["triton.Config"], float], + best_config, + best_timing, + ): + """ + Check all directions. We only do this once the regular coordinate + descent tuning find no better choices any more. 
+ We only have a few tunable fields, so this should be fine. + """ + candidate_values_list = [] + effective_fields = [] + for field in self.tunable_fields: + old_value = get_field(best_config, field) + if old_value is None: + continue + candidate_values = self.get_neighbour_values( + field, + old_value, + radius=inductor_config.coordinate_descent_search_radius, + include_self=True, + ) + candidate_values_list.append(candidate_values) + effective_fields.append(field) + + choices = itertools.product(*candidate_values_list) + improved = False + for choice in choices: + assert len(choice) == len(effective_fields) + candidate_config = copy.deepcopy(best_config) + for new_val, field in zip(choice, effective_fields): + set_field(candidate_config, field, new_val) + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config = candidate_config + best_timing = candidate_timing + + return improved, best_config, best_timing + + def compare_config(self, func, candidate_config, best_config, best_timing): + """ + Check if candidate_config is better than best_config. + + Return a touple of (compare_result, candidate_timing). + compare_result is true iff candidate_config is better. + """ + log.debug("Try config %s", candidate_config) + try: + candidate_timing = self.call_func(func, candidate_config) + except Exception as e: + log.debug("Got exception %s", e) + return False, float("inf") + + if self.has_improvement(best_timing, candidate_timing): + log.debug( + "Tune from %s %f -> %s %f", + best_config, + best_timing, + candidate_config, + candidate_timing, + ) + + return True, candidate_timing + return False, candidate_timing + + def autotune( + self, + func: Callable[["triton.Config"], float], + baseline_config: "triton.Config", + baseline_timing: Optional[float] = None, + ) -> "triton.Config": + if baseline_timing is None: + baseline_timing = self.call_func(func, baseline_config) + + log.debug("= Do coordinate descent tuning for %s =", self.name) + log.debug( + "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing + ) + improved = True + best_config = baseline_config + best_timing = baseline_timing + tunable_fields = self.tunable_fields + + while improved: + improved = False + + for name in tunable_fields: + cur_val = get_field(best_config, name) + # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None + if cur_val is None: + continue + + # It's possible that candidate_values is empty. + # E.g., if XBLOCK is 1 initially and size_hint for x is also 1. + # We would not try either larger or smaller XBLOCK in this case. + candidate_values = self.get_neighbour_values(name, cur_val) + + for next_val in candidate_values: + candidate_config = copy.deepcopy(best_config) + set_field(candidate_config, name, next_val) + + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config, best_timing = candidate_config, candidate_timing + + if not improved and inductor_config.coordinate_descent_check_all_directions: + old_best_timing = best_timing + improved, best_config, best_timing = self.check_all_tuning_directions( + func, best_config, best_timing + ) + + if improved: + msg = red_text( + "Coordinate descend tuning found improvement of %.3fx by looking in all directions." 
+ ) + log.debug( + msg, + old_best_timing / best_timing, + ) + + log.debug( + "Improve from %s %f -> %s %f, %.3fx", + baseline_config, + baseline_timing, + best_config, + best_timing, + baseline_timing / best_timing, + ) + + return best_config diff --git a/MLPY/Lib/site-packages/torch/_inductor/cudagraph_trees.py b/MLPY/Lib/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..f567001e9fbffa8987de7141a124dac3da3d621a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2159 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. 
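+
+A hedged illustration of the resulting tree shape (reusing the A/B/B' names from
+above):
+
+    A -- B      # one recorded child of A
+    A -- B'     # another child, recorded after a later replay of A
+
+Both children are recorded into A's memory pool, so peak usage stays at
+max(mem(A, B), mem(A, B')).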
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict + +from enum import auto, Enum +from typing import ( + Any, + Callable, + cast, + Dict, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch.fx +from torch import Tensor +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.types import _bool +from torch.utils import _pytree as pytree +from torch.utils.weak import TensorWeakRef + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled as _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def _set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + id: int + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + id: int + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: Tuple[torch.Tensor, ...] + + +def clear_cublass_cache(): + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
+ """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager(): + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying(): + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording(): + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording(): + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. + + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index): + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. + self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self): + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldnt set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self): + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self): + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. 
+ + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]): + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." + + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees(): + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local, attr_name): + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int): + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists=True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def cudagraphify_impl(model, inputs, static_input_idxs, *args, **kwargs): + fn_cache: Dict[Tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + del inputs + + def deferred_cudagraphify(inputs): + int_key = get_ints(inputs) + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + log.info("recording cudagraph tree for symint key 
%s", int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + fn = align_inputs_from_check_idxs(fn, inputs_to_check=check_input_idxs) + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +def cudagraphify( + model, + inputs, + static_input_idxs=(), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: Tuple[torch.Tensor, ...] = (), +): + manager = get_container(device_index).get_tree_manager() + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], None]] = None, + ): + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. + """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None): + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata): + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self): + self.extra_ref_check = None + + def expired(self): + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + # if extra_ref_check is not None we expect an additional reference + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + return (stor_count - (self.extra_ref_check is not None)) == 0 + + def __repr__(self): + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[Tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def 
_use_cuda_memory_pool_manager(device, mem_pool, stream): + """ + Context manager to use cuda graph pool for new allocations. If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. + """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = Tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = List[List[bool]] + +StackTraces = List[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. + + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. + - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent, + cuda_graphs_pool: Tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + ): + self.wrapped_function = wrapped_function + self.parent = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: List[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + + def run(self, new_inputs): + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created + # storages in path_live_weakrefs. 
+ existing_path_data_ptrs = { + t.data_ptr() for t in self.path_live_weakrefs() if t() + } + + def get_non_cudagraph_inps(): + non_cudagraph_inps = set() + for t in itertools.chain(new_inputs, self.wrapped_function.constants): + if ( + isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ): + non_cudagraph_inps.add(t.untyped_storage().data_ptr()) + return non_cudagraph_inps + + non_cudagraph_inps = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with torch.cuda.device( + self.device_index + ), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), get_history_recording(): + out = self.wrapped_function.model(new_inputs) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o): + return ( + o is not None + and isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage().data_ptr() not in non_cudagraph_inps + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = self.path_live_weakrefs() + new_storages = [ + t for t in out_refs if t.data_ptr() not in non_cudagraph_inps + ] + check_memory_pool(self.device_index, self.cuda_graphs_pool, new_storages) + + return out + + @property + def _path_from_root(self): + nodes = [] + node = self + while node: + nodes.append(node) + node = node.parent + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output + + def all_outputs_are_dead(self): + return not list(self.path_live_weakrefs()) + + +# Aliases for List that say what the indices denote +InputList = List # input indexes +OutputList = List # output indexes +LevelList = List # levels (distance from root of tree) + + +class OutputAliasInfo: + pass + + +class _UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + pass + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the graph output aliases an output of a prior graph" + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex): + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index): + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). 
A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. + + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. + + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: List[Tensor], + cuda_graphs_pool: Tuple[int, int], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + ): + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. 
+ + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[StackTraces] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: List[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + self.static_input_idxs: List[int] = list( + set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs) + ) + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + ( + inputs[i].data_ptr() + if isinstance(inputs[i], torch.Tensor) and i in self.static_input_idxs + else None + ) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. + + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: List[List[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of Tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. + self.expected_dead_indices_before_graph: List[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: List[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: List[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. 
For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: InputList[Union[Tensor, int]] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! + # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. + + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. + # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. + self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + self.recording_outputs: Optional[ + OutputList[Union[torch.Tensor, int]] + ] = self._record(wrapped_function.model, recording_inputs) + self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. 
+ assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_input(self, idx, dst, src): + expanded_dims = self.expanded_dims[idx] + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + # TODO - one jit kernel across multiple inputs + dst.copy_(src) + + def run_first_inputs(self, new_inputs): + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + return outputs + + def run(self, new_inputs): + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + assert len(self.static_input_data_ptrs) == len(new_inputs) + # NB: this ranges over non-static inputs too + for idx, data_ptr in enumerate(self.static_input_data_ptrs): + if idx in self.cudagraph_managed_idxs: + continue + if not isinstance(new_inputs[idx], torch.Tensor): + pass + elif data_ptr is not None: + # static input, e.g., parameter + assert data_ptr == new_inputs[idx].data_ptr() + else: + # non-static input, need to copy it into CUDA graph + dst = self.reconstructed_inputs[idx] + src = new_inputs[idx] + self._copy_input(idx, dst, src) + + new_inputs.clear() + self.run_graph() + + outputs = self.reconstruct_outputs() + self.debug_check_invariants_after_invocation() + + return outputs + + def reconstruct_outputs(self): + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are also cleared in checkpointing, so if we checkpoint this node + # and then execute it again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: List[Optional[Union[int, torch.Tensor]]] = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[Dict[str, Any], int, None], + ) -> Union[UntypedStorage, 
None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> List[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self): + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self): + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model, inputs): + "Record the model" + + def static_input_iter(): + for i in self.wrapped_function.static_input_idxs: + if isinstance( + inputs[i], torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(inputs[i]): + yield inputs[i] + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isnt set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with preserve_rng_state(), torch.cuda.device( + self.device + ), clear_cublas_manager(), torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), get_history_recording(): + static_outputs = model(inputs) + + # running model should reclaim memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + return static_outputs + + def _add_first_outputs( + self, + outputs, + static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper], + ): + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: Dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in 
range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ), + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len( + outputs + ), "Wrong number of stack traces passed in" + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex): + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self): + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert 
storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. + # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. + + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i): + self_loc = self_ref() + if self_loc is None: + return False + return self_loc.get_output_refcount(i) == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index): + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self): + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self): + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent + + @property + def _path_from_root(self): + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor): + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. + data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: List[PathOutputIndex], + output_refs: List[List[Optional[StorageWeakRefWrapper]]], + ): + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode): + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: List[List[bool]], curr: List[List[bool]] + ) -> List[PathOutputIndex]: + "Find indices where the two lists differ." 
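+ # e.g. prev=[[True, True]], curr=[[True, False]] -> [(0, 1)], i.e. output 1 of node 0 + # was live before and has since died.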
+ dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: List[List[Optional[StorageWeakRefWrapper]]], + ) -> List[List[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex] + ): + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = set() + live_storage_weak_ptrs = set() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self): + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self): + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> List[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. 
+ """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self): + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self): + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self): + "Clear the path state in this current executing node" + # this doesnt actually do anything right now, leaving it as placeholder + pass + + @staticmethod + def _tensor_metadata(x, ignore_storage_offset=True): + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make the storage resizable ? + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: Dict[str, Any], storage=None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) + + def create_storage(self, metadata): + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs + ) -> List[Union[torch.Tensor, int]]: + """ + Allocate inputs for non static, non cudagraph managraphed managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: List[Union[Tensor, int]] = [] + + with warnings.catch_warnings(record=True), torch.cuda.device( + self.device + ), _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, int) + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + # copy over and clear non recording input + self._copy_input(i, recording_inputs[-1], inp) + inputs[i] = None + del inp + else: + recording_inputs.append(inp) + + return recording_inputs + + def check_invariants(self, inputs: List[Tensor]) -> bool: + """ + Checks if this node can be run. The same pattern of tensor liveness and tensors + managed in the cudagraph private pool must remain stable. 
+ """ + + # previously managed data pointers remain stable + for idx in self.cudagraph_managed_idxs: + if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]: + return False + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + return False + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. + for idx in self.cudagraph_managed_idxs: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. Please file an issue.", + ) + return True + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id): + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id, live_only=True): + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames): + formatted_traceback = [] + + for entry in frames: + formatted_traceback.append( + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + ) + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool(device, pool_id, live_storages_ptrs: List[StorageWeakRefWrapper]): + assert all( + isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs + ) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages): + return + + # at this point we are past the fast-path. 
we have seen rare cases where a tensor is dead, + # but hasn't been gc'd yet, and gives a false positive for allocated_not_in_live_storages + gc.collect() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + # a non-empty dict means there are live blocks in the pool not accounted for by any output + if allocated_not_in_live_storages: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be NONE if there is no current live memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. + """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + checks required invariants, and manages warmups of graphs. + + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor lifespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded to `GenerationTracker.generation`, tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are partway through training + the model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup run in the cuda graph memory pool. As with recording, warmup needs the state of live tensors + to be accurately reflected, so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int): + # roots are functions which have no dependencies on another node. I.e., + # when they are first invoked, none of their inputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency.
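+ # Illustrative example (roughly): if fn_b is invoked while outputs of fn_a's node are still + # live, fn_b's recording becomes a child of fn_a's node; once the cudagraph pool has no live + # allocations, the next invocation starts a new root.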
+ self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {} + + self.warmed_up_functions: Set[FunctionID] = set() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: Set[FunctionID] = set() + torch._C._set_cached_tensors_enabled(True) + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would use the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with warnings.catch_warnings(record=True), torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # whether the current node is in a state of warmup, recording, or execution. If + # there is no current node the state will be ExecutionState.NONE. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function. Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocator. If the generation has been incremented, + # this will also be None, as we ignore those allocations. + self.current_node: Optional[CUDAGraphNode] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. We are willing to ignore live outputs + # of a previous generation when checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: Dict[FunctionID, CompilationMode] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for all pending backwards to have been invoked. + # Instead we wait for a single backward + # invocation. Triggering a backward pass typically doesn't lead to another torch.compile + # invocation, making it less likely for the generation to increase between multiple + # backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...)
+ # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + + def run(self, new_inputs: List[Tensor], function_id: FunctionID): + assert self.graph is not None, "Running CUDAGraph after shutdown" + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + mode = self.id_to_mode[function_id] + if mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self): + self.running_forwards_with_pending_backwards = False + + def _run(self, new_inputs: List[Tensor], function_id: FunctionID): + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) or self.in_warmup: + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. + # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + if child.check_invariants(new_inputs): + return self.execute_node(child, new_inputs) + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self): + """ + Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. 
+ """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]: + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node(self, node: CUDAGraphNode, new_inputs) -> List[Optional[Tensor]]: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager(self, new_inputs, function_id: FunctionID): + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + ) + self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + ) -> Tuple[Callable[..., Any], List[Optional[Tensor]]]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + static_input_idxs, + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + ) + self.id_to_mode[id] = mode + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self): + return self.path_state == ExecutionState.RECORDING + + @property + def in_warmup(self): + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in self.roots.values(): + yield from nodes + + @property + def current_node(self): + return self._current_node + + @current_node.setter + def current_node(self, value): + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self): + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if 
MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step(): + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def in_new_torch_compile_invocation(self): + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. + """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. + """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID): + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID): + "Warn if we in a potential loop where we are unable to hit fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = { + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, (self.current_node,)) + if n.parent is not None + } + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + def dealloc_current_path_weakrefs(self): + # TODO: we could also allow the these weak refs to continue to be allocated, + # but that adds some complications. 
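+ # What follows: walk every node on the current path, attach an informative access-error message + # to any output tensor that is still alive in user code, then free each live cudagraph-pool + # storage exactly once (outputs may alias, so deduplicate by data_ptr).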
+ for node in self.current_node._path_from_root: + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + stack_trace = ( + stack_trace.strip() + if stack_trace + else "[Could not find stack trace]" + ) + msg = ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + torch._C._set_storage_access_error_msg(ten, msg) + + deleted = set() + for storage_ref in self.current_node.path_live_weakrefs(): + if storage_ref() and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + torch._C._free_And_Remove_DeleterFn(storage_ref()) + + def clear_current_path_state_and_set_to_none(self): + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self): + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existent live storages. + """ + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: List[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + live_storages_weak_refs = [t() for t in live_storages_wrappers] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages, live_storages_weak_refs + ) + + # NB: deduplicate aliased outputs + for ptr in set(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + assert wrapper() + assert torch._C._has_Standard_Deleter(wrapper()) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> List[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + return [t() for t in self.current_node.path_live_weakrefs()] diff --git a/MLPY/Lib/site-packages/torch/_inductor/cudagraph_utils.py b/MLPY/Lib/site-packages/torch/_inductor/cudagraph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94c38b6ff4cd4cf7cf1de809daffd69e17f3752c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/cudagraph_utils.py @@ -0,0 +1,105 @@ +import dataclasses +from typing import Dict, Iterable, Optional + +import torch +from torch._inductor.codecache import CompiledFxGraph + + +def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]: + # reinplaced 
uses might have a single, non-copy_ use + if len(placeholder_node.users) == 1: + return next(iter(placeholder_node.users)).meta.get("stack_trace", None) + + for use in placeholder_node.users: + if use.target == torch.ops.aten.copy_.default: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + + return None + + +def format_default_skip_message(reason: str) -> str: + return f"skipping cudagraphs due to {reason}" + + +def get_mutation_stack_trace( + gm: torch.fx.GraphModule, mutation_indices: Iterable[int] +) -> str: + stack_trace: Optional[str] = "" + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + + for idx in mutation_indices: + placeholder = placeholders[idx] + if stack_trace := get_mutating_use_stack_trace(placeholder): + break + + if stack_trace: + msg = f"skipping cudagraphs due to mutation on input. Found from : \n {stack_trace}" + return msg + + return format_default_skip_message("mutated inputs") + + +def check_for_mutation( + gm: torch.fx.GraphModule, compiled_graph: CompiledFxGraph, num_fixed: int +) -> Optional[str]: + default_msg = format_default_skip_message("mutated inputs") + + # doesnt work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + # checking if mutation is only on parameters/static inputs + mutation_indices = [ + idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed + ] + has_mutation = len(mutation_indices) != 0 + if not has_mutation: + return None + + return get_mutation_stack_trace(gm, mutation_indices) + + else: + has_mutation = len(compiled_graph.mutated_inputs) != 0 + return None if not has_mutation else default_msg + + +def get_use_stack_trace(node) -> Optional[str]: + for use in node.users: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + return None + + +def check_multiple_devices_or_any_cpu_nodes( + device_node_mapping: Dict[torch.device, torch.fx.Node] +) -> Optional[str]: + if cpu_node := device_node_mapping.get(torch.device("cpu")): + if stack_trace := get_use_stack_trace(cpu_node): + return format_default_skip_message( + f"cpu device. 
Found from : \n {stack_trace}" + ) + + return format_default_skip_message("cpu device") + + if ( + len(device_node_mapping) == 1 + and next(iter(device_node_mapping.keys())).type == "cuda" + ): + return None + + keys_repr = (repr(key) for key in device_node_mapping.keys()) + return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") + + +def check_lowering_disable_cudagraph( + device_node_mapping: Dict[torch.device, torch.fx.Node] +): + return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) + + +@dataclasses.dataclass +class BoxedDeviceIndex: + value: Optional[int] + + def set(self, device_idx: Optional[int]): + assert device_idx is None or isinstance(device_idx, int) + self.value = device_idx diff --git a/MLPY/Lib/site-packages/torch/_inductor/debug.py b/MLPY/Lib/site-packages/torch/_inductor/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..72fa404a54a5d8e4d27ddee8d9c53ab311153a59 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/debug.py @@ -0,0 +1,655 @@ +import collections +import contextlib +import cProfile +import dataclasses +import functools +import itertools +import logging +import os +import os.path +import pickle +import pstats +import shutil +import subprocess +from typing import Any, Dict, List, Optional +from unittest.mock import patch + +from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled + +import torch +from torch import fx as fx + +from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug +from torch._dynamo.utils import get_debug_dir +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.fx.passes.tools_common import legalize_graph +from torch.utils._pytree import tree_map + +from . import config, ir # noqa: F811, this is needed +from .scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + OutputNode, + SchedulerNode, +) +from .virtualized import V + +log = logging.getLogger(__name__) + +SchedulerNodeList = List[Any] +BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) +GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] + + +@functools.lru_cache(None) +def has_dot() -> bool: + try: + subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE) + return True + except subprocess.SubprocessError: + return False + + +def draw_buffers(nodes: List[BaseSchedulerNode], print_graph=False, fname=None): + """ + Draw a graph in fname.svg. 
+ """ + if not has_dot(): + log.warning("draw_buffers() requires `graphviz` package") + return + + if fname is None: + fname = get_graph_being_compiled() + + graph = create_fx_from_snodes(nodes) + + for node in graph.nodes: + if "fusion_meta" not in node.meta: + continue + group = node.meta["fusion_meta"].group + if isinstance(group, tuple): + if isinstance(group[1], int): + group = (group[1],) + else: + group = group[1] + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + node.meta["tensor_meta"] = metadata + + if print_graph: + print(graph) + + gm = GraphModule({}, graph) + legalize_graph(gm) + gm.graph.lint() + draw_graph( + gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape + ) + + +def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + + def get_fake_func(name): + def func1(*args): + return 0 + + func1.__name__ = name + return func1 + + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) + + buf_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + outputs = [] + group: Any = None + # create call_function node for each Buffer and Kernel + for snode in snodes: + if snode.is_extern(): + node_type = "extern" + group = node_type + elif snode.is_template(): + node_type = "template" + group = node_type + elif isinstance(snode, NopKernelSchedulerNode): + node_type = "nop" + group = node_type + elif isinstance(snode, SchedulerNode): + node_type = "compute" + group = snode.group + elif isinstance(snode, FusedSchedulerNode): + node_type = "fused" + group = snode.group + else: + raise RuntimeError("Unknown node type") + + fused_name = torch._inductor.utils.get_fused_kernel_name( + snode.get_nodes(), "original_aten" + ) + func_name = f"{node_type}: {fused_name}" + node_func = get_fake_func(func_name) + kwargs = {} + if hasattr(snode, "get_device"): + kwargs = {"device": snode.get_device()} + fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) + + def in_output(snode): + if isinstance(snode, FusedSchedulerNode): + return any(in_output(x) for x in snode.snodes) + return any(isinstance(user.node, OutputNode) for user in snode.users) + + if in_output(snode): + outputs.append(fx_node) + name = snode.get_name() + fx_node.name = name + + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) + + if isinstance(snode, FusedSchedulerNode): + for x in snode.snodes: + buf_to_fx_node[x.get_name()] = fx_node + buf_to_fx_node[name] = fx_node + + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in snodes: + name = snode.get_name() + deps = snode.read_writes.reads + + fx_node = buf_to_fx_node[name] + new_args = [] + for dep in deps: + if dep.name in buf_to_fx_node: + dep_node = buf_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) + buf_to_fx_node[dep.name] = dep_node + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + return graph + + +def update_orig_fx_node_name_to_buf_name( + nodes: SchedulerNodeList, + node_name_to_buf_name: Dict[str, str], + parent_buf_name: Optional[str] = None, + n_origins: int = 0, +): + if nodes is None: + return + for node in nodes: + # for FusedSchedulerNode, traverse 
recursively into get_nodes() + buf_name = node.get_name() + children_nodes = node.get_nodes() + if children_nodes is not None and len(children_nodes) > 1: + update_orig_fx_node_name_to_buf_name( + children_nodes, + node_name_to_buf_name, + buf_name if parent_buf_name is None else parent_buf_name, + ) + continue + else: + assert len(children_nodes) == 1 and children_nodes[0] == node + + ir_node = node.node + if ir_node is None or ir_node.origins is None: + continue + for origin in ir_node.origins: + node_name = origin.name + # when buf1 and buf2 both have origin=node1 + # we draw node1 according to buf1 + if node_name not in node_name_to_buf_name: + node_name_to_buf_name[node_name] = ( + buf_name if parent_buf_name is None else parent_buf_name + ) + + +def get_node_name_to_buf_meta(node_name_to_buf_name: Dict[str, str]): + buf_name_to_n_node = {} + for node_name, buf_name in node_name_to_buf_name.items(): + if buf_name not in buf_name_to_n_node: + buf_name_to_n_node[buf_name] = {node_name} + else: + buf_name_to_n_node[buf_name].add(node_name) + + node_name_to_buf_meta = {} + for node_name, buf_name in node_name_to_buf_name.items(): + n_node = len(buf_name_to_n_node[buf_name]) + node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node) + return node_name_to_buf_meta + + +def annotate_orig_fx_with_snodes( + gm: torch.fx.GraphModule, snodes: SchedulerNodeList +) -> None: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + node_name_to_buf_name: Dict[str, str] = {} + update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) + if node_name_to_buf_name is None: + return + node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name) + for node in gm.graph.nodes: + if node.name in node_name_to_buf_meta: + node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name) + + +@contextlib.contextmanager +def enable_aot_logging(): + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + import torch._functorch.aot_autograd + + log = logging.getLogger(torch._functorch.aot_autograd.__name__) + + stack = contextlib.ExitStack() + if not compile_debug: + try: + yield + finally: + stack.close() + return + + # Enable all graphs to be logged to a file by setting the flags to True + # and the log level of the file logger to DEBUG + stack.enter_context(patch("functorch.compile.config.debug_partitioner", True)) + + path = os.path.join(get_debug_dir(), "torchinductor") + os.makedirs(path, exist_ok=True) + + fh = logging.FileHandler( + os.path.join( + path, + f"aot_{get_aot_graph_name()}_debug.log", + ) + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(fh) + try: + yield + finally: + log.removeHandler(fh) + stack.close() + + +class DebugContext: + _counter = itertools.count() + + @staticmethod + def wrap(fn): + @functools.wraps(fn) + def inner(*args, **kwargs): + with DebugContext(): + return fn(*args, **kwargs) + + return wrap_compiler_debug(inner, compiler_name="inductor") + + @staticmethod + def create_debug_dir(folder_name: str) -> Optional[str]: + debug_dir = config.trace.debug_dir or get_debug_dir() + for n in DebugContext._counter: + dirname = os.path.join( + debug_dir, + "torchinductor", + f"{folder_name}.{n}", + ) + if not os.path.exists(dirname): + os.makedirs(dirname) + return dirname + return None + + def __init__(self): + self._prof = None + self._path = None + self._stack = contextlib.ExitStack() + + def copy(self, new_path: str): + if not 
self._path: + return + assert new_path.endswith(".debug"), new_path + if os.path.exists(new_path): + shutil.rmtree(new_path) + try: + shutil.copytree(self._path, new_path) + self._path = new_path + except OSError: + log.warning( + "Failed to copy debug files from %s to %s", self._path, new_path + ) + pass + + def fopen(self, filename: str, write_mode: str = "w", *args, **kwargs): + assert self._path + return open(os.path.join(self._path, filename), write_mode, *args, **kwargs) + + @contextlib.contextmanager + def fopen_context(self, filename: str, write_mode: str = "w", *args, **kwargs): + assert self._path + with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f: + yield f + + def filename(self, suffix: str): + assert self._path + return os.path.join(self._path, suffix) + + def upload_tar(self): + if config.trace.upload_tar is not None: + import tarfile + + assert self._path + tar_file = os.path.join( + self._path, f"{os.path.basename(self._path)}.tar.gz" + ) + with tarfile.open(tar_file, "w:gz") as tar: + tar.add(self._path, arcname=os.path.basename(self._path)) + config.trace.upload_tar(tar_file) + + def __enter__(self): + if config.debug: + log = logging.getLogger("torch._dynamo") + prev_level = log.level + log.setLevel(logging.DEBUG) + + def reset_log_level(level): + log.setLevel(level) + + self._stack.callback(reset_log_level, prev_level) + + self._stack.enter_context(V.set_debug_handler(self)) + + if not config.trace.enabled: + return + + self._path = self.create_debug_dir(get_aot_graph_name()) + + if config.trace.debug_log: + self._setup_log_capture("debug.log", logging.DEBUG) + if config.trace.info_log: + self._setup_log_capture("info.log", logging.INFO) + if config.trace.compile_profile: + self._prof = cProfile.Profile() + self._prof.enable() + + def _setup_log_capture(self, filename: str, level: int): + log = logging.getLogger("torch._inductor") + fd = self._stack.enter_context(self.fopen(filename)) + ch = logging.StreamHandler(fd) + ch.setLevel(level) + ch.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(ch) + log.setLevel(min(log.level, level)) + self._stack.callback(log.removeHandler, ch) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._prof: + self._prof.disable() + self._save_profile_data() + + if self._path: + self.upload_tar() + log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path) + self._stack.close() + + def _save_profile_data(self): + assert self._prof + self._prof.dump_stats(self.filename("compile.prof")) + with self.fopen("compile.stats") as fd: + stats = pstats.Stats(self._prof, stream=fd) + stats.strip_dirs() + stats.sort_stats("cumtime") + stats.print_stats(100) + stats.sort_stats("tottime") + stats.print_stats(100) + + def __getattr__(self, name): + if config.trace.enabled and getattr(config.trace, name): + try: + return getattr(DebugFormatter(self), name) + except Exception: + log.warning("Ignoring exception in debug code", exc_info=True) + else: + + def ignored(*args, **kwargs): + pass + + return ignored + + +class DebugFormatter: + def __init__(self, handler): + self.fopen = handler.fopen + self.fopen_context = handler.fopen_context + self.filename = handler.filename + self.handler = handler + + def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): + with self.fopen("fx_graph_runnable.py") as fd: + save_graph_repro(fd, gm, inputs, "inductor") + + with self.fopen("fx_graph_readable.py") as fd: + 
fd.write(gm.print_readable(print_output=False)) + + def fx_graph_transformed( + self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor] + ): + with self.fopen("fx_graph_transformed.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def ir_pre_fusion(self, nodes: SchedulerNodeList): + self._write_ir("ir_pre_fusion.txt", nodes) + + def ir_post_fusion(self, nodes: SchedulerNodeList): + self._write_ir("ir_post_fusion.txt", nodes) + + def _write_ir(self, filename: str, nodes: SchedulerNodeList): + with self.fopen(filename) as fd: + log.info("Writing debug ir to %s", fd.name) + for node in nodes: + fd.write(node.debug_str()) + fd.write("\n\n\n") + + def graph_diagram(self, nodes: SchedulerNodeList): + draw_buffers(nodes, fname=self.filename("graph_diagram.svg")) + + def draw_orig_fx_graph(self, gm: torch.fx.GraphModule, nodes: SchedulerNodeList): + annotate_orig_fx_with_snodes(gm, nodes) + draw_graph( + gm, + fname=self.filename("orig_fx_graph_diagram.svg"), + clear_meta=False, + prog=GRAPHVIZ_COMMAND_SCALABLE, + parse_stack_trace=True, + dot_graph_shape=config.trace.dot_graph_shape, + ) + + def output_code(self, filename): + shutil.copy(filename, self.filename("output_code.py")) + + def log_autotuning_results( + self, + name: str, + input_nodes: List[ir.IRNode], + timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 + elapse: float, + ): + import json + + from .ir import FixedLayout + + def build_node_info(node: ir.IRNode): + if hasattr(node, "name"): + node_name = node.name + else: + node_name = "" + node_info = { + "name": node_name, + "type": type(node).__name__, + } + try: + layout = node.get_layout() + if isinstance(layout, FixedLayout): + offset = 0 + try: + offset = int(layout.offset) + except Exception: + try: + offset = V.graph.sizevars.size_hint( + layout.offset, fallback=0 + ) + except Exception: + pass + static_layout = FixedLayout( + layout.device, + dtype=layout.dtype, + size=list(V.graph.sizevars.size_hints(layout.size)), + stride=list(V.graph.sizevars.size_hints(layout.stride)), + offset=offset, + ) + node_info["layout"] = str(static_layout) + else: + node_info["layout"] = str(node.get_layout()) + except Exception as e: + pass + try: + node_info["dtype"] = str(node.get_dtype()) + except Exception as e: + pass + try: + node_info["device"] = str(node.get_device()) + except Exception as e: + pass + try: + node_info["stride"] = str( + V.graph.sizevars.size_hints(node.get_stride()) + ) + except Exception as e: + pass + try: + node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size())) + except Exception as e: + pass + try: + node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel())) + except Exception as e: + pass + if hasattr(node, "data") and isinstance(node.data, ir.IRNode): + node_info["data"] = build_node_info(node.data) + return node_info + + general_properties = { + "op_name": name, + "cuda_device_name": torch.cuda.get_device_name(), + "cuda_device_count": torch.cuda.device_count(), + "input_nodes": [build_node_info(node) for node in input_nodes], + "autotuning_time": elapse, + } + with self.fopen_context( + "autotuning_result_json_list.txt", "at", encoding="utf-8" + ) as fd: + for caller, time in timings.items(): + info_dict = dict(caller.info_dict()) + info_dict.update(general_properties) + info_dict["benchmark_result"] = time + json.dump(info_dict, fd) + fd.write("\n") + + +@dataclasses.dataclass +class TensorMetadataHolder: + tensor_metadata: TensorMetadata + device: torch.device + + +save_args_cnt = 
itertools.count() + + +def save_args_for_compile_fx_inner(*args, **kwargs): + """ + This function is used to save arguments for a compile_fx_inner function call + to the file system. Later on one can replay the compile_fx_inner call + with the saved arguments using load_args_and_run_compile_fx_inner. + """ + + folder = "/tmp/inductor_saved_args" + if not os.path.exists(folder): + os.mkdir(folder) + + def handle_tensor(x): + """ + Pickle FakeTensor will result in error: + AttributeError: Can't pickle local object 'WeakValueDictionary.__init__..remove' + + Convert all Tensor to metadata. This may also makes pickle faster. + """ + if isinstance(x, torch.Tensor): + return TensorMetadataHolder(_extract_tensor_metadata(x), x.device) + else: + return x + + args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs)) + + fn_name = "compile_fx_inner" + path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl" + with open(path, "wb") as f: + pickle.dump((args_to_save, kwargs_to_save), f) + + if log.isEnabledFor(logging.DEBUG): + message = f""" +Arguments for a compile_fx_inner call is saved to {path}. To replay the call, +run the following: + +from torch._inductor.debug import load_args_and_run_compile_fx_inner +load_args_and_run_compile_fx_inner({path!r}) + """ + # call print rather than log.debug. log.debug will print message + # prefix for each line which makes the code snippet harder to be + # copied. + # Not a big deal since the code is already been guarded by checking + # the log level. + print(message) + + +def load_args_and_run_compile_fx_inner(path: str): + from torch._inductor.compile_fx import compile_fx_inner + + with open(path, "rb") as f: + args, kwargs = pickle.load(f) + + def handle_tensor(x): + if isinstance(x, TensorMetadataHolder): + return torch._dynamo.testing.rand_strided( + x.tensor_metadata.shape, + x.tensor_metadata.stride, + x.tensor_metadata.dtype, + x.device, + ) + else: + return x + + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + with fake_mode, config.patch("save_args", False): + args, kwargs = tree_map(handle_tensor, (args, kwargs)) + return compile_fx_inner(*args, **kwargs) diff --git a/MLPY/Lib/site-packages/torch/_inductor/decomposition.py b/MLPY/Lib/site-packages/torch/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..f7f5de93b07c209075921f3b940e66cb4ba08fb4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/decomposition.py @@ -0,0 +1,678 @@ +import functools +import logging +import math +import sys +import typing +from typing import Optional + +import torch +import torch._decomp as decomp +import torch._prims_common as utils +import torch.ao.quantization.fx._decomposed +from torch._decomp import ( + core_aten_decompositions, + get_decompositions, + remove_decompositions, +) +from torch._decomp.decompositions import ( + _grid_sampler_2d as decomp_grid_sampler_2d, + pw_cast_for_opmath, +) +from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._higher_order_ops.out_dtype import out_dtype +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + type_to_dtype, +) + +from . 
import config, inductor_prims + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed + +inductor_decompositions = get_decompositions( + [ + aten._adaptive_avg_pool2d_backward, + aten.arange, + aten.bitwise_and_, + aten.bitwise_or_, + aten.clamp_min_, + aten.dist, + aten.empty_like, + aten.flip, + aten.gelu, + aten.hardtanh, + aten.index_select, + aten.lcm, + aten.leaky_relu, + aten.linalg_vector_norm, + aten._log_softmax, + aten.max_pool2d_with_indices_backward, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten.native_batch_norm, + aten.native_group_norm, + aten.native_layer_norm, + aten.nll_loss2d_backward, + aten._softmax, + aten.sin_, + aten.sqrt_, + out_dtype, + aten._to_copy, + aten.tril_indices, + aten.triu_indices, + aten.upsample_bilinear2d.vec, + ] +) +decompositions = {**core_aten_decompositions(), **inductor_decompositions} + +# Remove unwanted decompositions included via the core ATen decompositions from +# the Inductor decomp table. +decomps_to_exclude = [ + aten._unsafe_index, + aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py + aten.clamp_max, + aten.clamp_min, + aten.glu, # inductor lowers this directly + aten.split.Tensor, # inductor lowers this directly + aten.squeeze, # inductor lowers this directly + aten.sum, # inductor lowers this directly + aten.unbind, # inductor lowers this directly +] + +remove_decompositions(decompositions, decomps_to_exclude) + + +def register_decomposition(ops): + for op in [ops] if callable(ops) else ops: + if op in decompositions: + log.warning("duplicate decomp: %s", ops) + return decomp.register_decomposition(ops, decompositions) + + +# TODO: for now, inductor doesn't handle asserts +# because the condition is symbool -> tensor in the graph. +@register_decomposition([aten._assert_async.msg]) +def assert_async_msg_decomp(tensor, msg): + return + + +# Following `assert_async_msg_decomp` and implement as non-op. +@register_decomposition([aten._functional_assert_async.msg]) +def functional_assert_async_msg_decomp(tensor, msg): + return + + +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size(symbol, *, min=None, max=None): + return + + +@register_decomposition([aten.clamp]) +@pw_cast_for_opmath +def clamp(x, min=None, max=None): + if min is not None: + x = x.clamp_min(min) + if max is not None: + x = x.clamp_max(max) + return x + + +@register_decomposition([aten.full]) +def full(size, fill_value, **kwargs): + dtype = kwargs.get("dtype") + if dtype is None: + kwargs["dtype"] = type_to_dtype(type(fill_value)) + return aten.full(size, fill_value, **kwargs) + return NotImplemented + + +# Not really sure how to put this into the main library. 
PrimTorch wants +# empty_permuted to go to the prim, and typically users don't really want +# to decompose to empty_strided (but inductor is OK with it, because we are +# cool with strides and everything goes to empty_strided) +@register_decomposition([aten.empty_permuted.default]) +def empty_permuted(size, physical_layout, **kwargs): + perm = [0] * len(size) + for p, l in enumerate(physical_layout): + perm[l] = p + return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm) + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + if not output_mask[2] or grad_output.device.type != "cuda": + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + +@register_decomposition([aten.log2]) +def log2(x): + return torch.log(x) * (1.0 / math.log(2.0)) + + +@register_decomposition([aten.round.decimals]) +def round_dec(x, decimals=0): + ten_pow_decimals = 10.0**decimals + return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) + + +@register_decomposition([aten.bmm]) +@pw_cast_for_opmath +def bmm(self, batch2): + if config.coordinate_descent_tuning: + if self.shape[1] == 1 or batch2.shape[2] == 1: + out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) + return out + if self.device.type == "cpu": + if self.size(1) == 1 and batch2.size(-1) == 1: + return torch.sum( + self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True + ).unsqueeze(1) + return NotImplemented + + +@register_decomposition([aten.addmm]) +@pw_cast_for_opmath +def addmm(self, mat1, mat2, beta=1, alpha=1): + if self.device.type == "cpu": + if mat1.size(0) == 1 and mat2.size(-1) == 1: + out = torch.sum( + mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return alpha * out + beta * self + if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16: + out = (mat1.T * mat2).sum(dim=0, keepdim=True) + return alpha * out + beta * self + return NotImplemented + + +@register_decomposition([aten.mm]) +@pw_cast_for_opmath +def mm(self, input2): + from torch.fx.experimental.symbolic_shapes import ( + definitely_true, + guard_size_oblivious, + ) + + # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. 
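+    # (The coordinate-descent branch below rewrites a matrix-vector product as a
+    # broadcasted multiply followed by a sum over the shared dimension.)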
+ # todo: Look into why and fix it (hopefully) + if config.coordinate_descent_tuning: + if self.shape[0] == 1 or input2.shape[1] == 1: + return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) + if self.device.type == "cpu": + if ( + guard_size_oblivious(self.size(-1) == 1) + and guard_size_oblivious(self.size(0) > 0) + and guard_size_oblivious(input2.size(0) == 1) + and (self.dtype == input2.dtype) + and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32) + ): + return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) + if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious( + input2.size(-1) == 1 + ): + return torch.sum( + self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return NotImplemented + + +# This pass does two things: +# - Eliminate cat when there is only one tensor input +# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we +# don't remove ALL empty tensors, only the naughty ones) +@register_decomposition([aten.cat.default]) +def cat(tensors, dim=0): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + def non_empty_tensor(x): + # For better or worse, this is a valid cat: + # + # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)]) + # + # We'd like to eliminate naughtiness like this for downstream passes + # like split_cat. The easiest way is to just drop such inputs + # (guarding that they are non-zero). + # + # Is it permissible for this filtering to be size-oblivious? A case + # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0 + # happened to be zero, we would have liked to have filtered it out. + # But actually, the ONLY way this could have passed is if u0 == 0, + # so by the time we get here we have already installed a deferred + # runtime assert forcing u0 to be zero. So if this hasn't happened, + # we know that the unbacked SymInt has appropriate size and there are + # no problems. 
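+        # For example, cat([randn(2, 3), randn(0), randn(4, 3)]) drops the 1-D
+        # empty tensor and redispatches on the remaining two inputs.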
+ return len(x.shape) != 1 or guard_size_oblivious(x.shape[0] > 0) + + filtered_tensors = list(filter(non_empty_tensor, tensors)) + + if len(filtered_tensors) == 1: + return filtered_tensors[0].clone() + elif 1 < len(filtered_tensors) < len(tensors): + # on the first call, when we remove empty tensors, we redispatch recursively + return aten.cat.default(filtered_tensors, dim) + # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed) + return NotImplemented + + +@register_decomposition([aten.angle]) +def angle(x): + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + _, dtype = elementwise_dtypes( + x, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device) + ret = torch.where(x < 0, pi, 0.0) + return torch.where(torch.isnan(x), float("nan"), ret) + + +@register_decomposition([aten.add]) +def add(x, y, *, alpha=None): + x_is_complex_tensor = torch.is_tensor(x) and x.is_complex() + y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() + if not x_is_complex_tensor or not y_is_complex_tensor: + return NotImplemented + z = y + if alpha is not None: + z = alpha * y + complex_type = torch.promote_types(x.dtype, y.dtype) + return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type) + + +@register_decomposition([aten.conj_physical]) +def conj_physical(self): + assert not self.is_complex(), "TODO: implement this" + return self + + +@register_decomposition([aten.lift, aten.detach_]) +def lift(self): + return self + + +@register_decomposition([aten.bernoulli.default]) +def bernoulli(self, *, generator=None): + assert generator is None + return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) + + +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self, other): + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self, other): + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition(aten.amax) +def amax(self, dim=None, keepdim=False): + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin(self, dim=None, keepdim=False): + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy(self, dim, start, length): + return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.expand_copy]) +def expand_copy(self, size, *, implicit=False): + return aten.expand(self, size, implicit=implicit).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default(self, size): + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype(self, dtype): + return self.to(dtype).clone() + + +def get_like_layout( + tensor: torch.Tensor, memory_format: Optional[torch.memory_format] +) -> torch.memory_format: + # TODO: _to_copy tensor to stride permutation + if memory_format is torch.preserve_format or memory_format is None: + return utils.suggest_memory_format(tensor) + else: + return memory_format + + +@register_decomposition(aten.rand_like) +def rand_like(self, *, 
dtype=None, device=None, memory_format=None, **kwargs): + return torch.rand( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randn_like) +def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs): + return torch.randn( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.full_like) +def full_like( + self, + fill_value, + *, + dtype=None, + layout=None, + device=None, + pin_memory=False, + requires_grad=False, + memory_format=torch.preserve_format, +): + return torch.full( + [*self.size()], + fill_value, + dtype=dtype or self.dtype, + layout=layout or self.layout, + device=device or self.device, + requires_grad=requires_grad, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.default) +def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs): + return aten.randint.low( + 0, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.low_dtype) +def randint_like_low( + self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs +): + return aten.randint.low( + low, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint.default) +def randint(high, size, **kwargs): + return aten.randint.low(0, high, size, **kwargs) + + +# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is +# scale and zero_point is scalar or scalar tensor +@register_decomposition(quantized_decomposed.quantize_per_tensor.default) +def quantize_per_tensor_default_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is +# scale and zero_point is scalar or scalar tensor +@register_decomposition(quantized_decomposed.dequantize_per_tensor.default) +def dequantize_per_tensor_default_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return (input.to(torch.float32) - zero_point) * scale + + +@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor) +def quantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor) +def dequantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + 
quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to( + torch.float32 + ) + + +@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) +def q_embedding_bag_byte_unpack_decomp(packed): + def bitcast_u8_to_f32(u8): + x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3)) + if sys.byteorder == "little": + return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None] + else: + return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None] + + scales = bitcast_u8_to_f32(packed[..., -8:-4]) + offsets = bitcast_u8_to_f32(packed[..., -4:]) + return packed[..., :-8].to(torch.float32) * scales + offsets + + +@register_decomposition([aten.grid_sampler_2d]) +@pw_cast_for_opmath +def grid_sampler_2d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> torch.Tensor: + # We do not expand the grid (_expand_grid=False) on cpu for performance reasons + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + _expand_grid = not ( + a.device == torch.device("cpu") + and interpolation_mode == 0 + and a.is_contiguous(memory_format=torch.contiguous_format) + ) + + output = decomp_grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + _expand_grid=_expand_grid, + ) + return output + + +@register_decomposition(aten._foreach_addcmul.Scalar) +def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1): + return aten._foreach_add.List( + self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_addcdiv.Scalar) +def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1): + return aten._foreach_add.List( + self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_lerp.Scalar) +def _foreach_lerp_scalar(start_tensors, end_tensors, weight): + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.Scalar( + aten._foreach_sub.List(end_tensors, start_tensors), weight + ), + ) + + +@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd) +@register_decomposition(aten.miopen_batch_norm) +def miopen_batch_norm( + input: torch.Tensor, + weight: torch.Tensor, + bias: typing.Optional[torch.Tensor], + running_mean: typing.Optional[torch.Tensor], + running_var: typing.Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + + if training: + return (a, b, c) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + ) + + +@functools.lru_cache(None) +def fast_random_decomps(): + return {**decompositions, **extra_random_decomps} + + +def select_decomp_table(): + """decomps can change based on config""" + if config.fallback_random: + return decompositions + return fast_random_decomps() + + +@register_decomposition(aten.masked_scatter) +def masked_scatter(self, mask, 
source): + if self.device.type == "cuda": + # This two-step algorithm is the same as eager CUDA, for eager CPU we + # use a 1-shot serial iteration. + self, mask = aten.broadcast_tensors([self, mask]) + source_idx = mask.reshape(-1).cumsum(0) - 1 + return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source) + return NotImplemented + + +@register_decomposition(quantized_decomposed.choose_qparams.tensor) +def choose_qparams_tensor( + input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype +): + min_val, max_val = torch.aminmax(input) + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.max(scale, torch.Tensor([eps])) + zero_point = quant_min - torch.round(min_val / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale.to(torch.float64), zero_point.to(torch.int64) + + +@register_decomposition(aten.put) +def put(self, index, source, accumulate=False): + flattened = self.flatten() + flattened = torch.index_put( + flattened, [index], source.reshape(index.shape), accumulate + ) + return flattened.reshape(self.shape) + + +@register_decomposition(aten.put_) +def put_(self, index, source, accumulate=False): + out = aten.put(self, index, source, accumulate=accumulate) + return self.copy_(out) diff --git a/MLPY/Lib/site-packages/torch/_inductor/dependencies.py b/MLPY/Lib/site-packages/torch/_inductor/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..7733ea29fb079ee35cd09aa2b94287cb94ec37a7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/dependencies.py @@ -0,0 +1,506 @@ +import collections +import dataclasses +import itertools +import logging +import re +import typing +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch + +import sympy + +import torch +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + +from .codegen.common import index_prevent_reordering +from .utils import ( + get_dtype_size, + reduction_num_outputs, + sympy_index_symbol, + sympy_str, + sympy_subs, + VarRanges, +) +from .virtualized import OpsHandler, ReductionType, V + +log = logging.getLogger(__name__) +is_indirect = re.compile(r"indirect|tmp").search +Dep = Union["MemoryDep", "StarDep", "WeakDep"] + + +class MemoryDep(typing.NamedTuple): + name: str + index: sympy.Expr # type: ignore[assignment] + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] 
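+    # name: the buffer being read/written; index: the symbolic address expression;
+    # var_names/size: the iteration variables that may appear in `index` together
+    # with their ranges (see the `ranges` property below).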
+ + def __repr__(self): + return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})" + + @property + def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]: + """{c0: 128, c1: 512, ...}""" + return dict(zip(self.var_names, self.size)) + + def get_numel(self) -> sympy.Expr: + if self.is_indirect(): + numel = V.graph.get_numel(self.name) + else: + vars = set(self.index.free_symbols) + numel = sympy.Integer(1) + for var, size in zip(self.var_names, self.size): + if var in vars: + numel = numel * size + return numel + + def rename(self, renames: Dict[str, str]) -> "MemoryDep": + if self.name in renames: + return MemoryDep( + renames[self.name], self.index, var_names=self.var_names, size=self.size + ) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return isinstance(self.index, sympy.Symbol) and self.index in self.var_names + + def is_scalar(self) -> bool: + if isinstance(self.index, sympy.Symbol): + return self.index not in self.var_names and not self.is_indirect() + return isinstance(self.index, (int, sympy.Integer)) + + def is_indirect(self) -> bool: + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] + + +class StarDep(typing.NamedTuple): + # depends on the entire buffer + name: str + + @property + def index(self): + raise NotImplementedError("StarDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return V.graph.get_numel(self.name) + + def rename(self, renames: Dict[str, str]) -> "StarDep": + if self.name in renames: + return StarDep(renames[self.name]) + return self + + def numbytes_hint(self): + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + + def has_unbacked_symbols(self): + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return False + + def is_scalar(self) -> bool: + return False + + def is_indirect(self) -> bool: + return False + + +# Used for tracking mutation ordering +# if A reads a buffer and B mutates it +# B must be ordered after A +# +# It is weak because if it turns out A's read is never used, we can still +# eliminate it +class WeakDep(typing.NamedTuple): + name: str + + @property + def index(self): + raise NotImplementedError("WeakDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return sympy.Integer(1) + + def rename(self, renames: Dict[str, str]) -> "WeakDep": + if self.name in renames: + return WeakDep(renames[self.name]) + return self + + def numbytes_hint(self): + return 1 # Purely inserted for ordering, not an actual dep + + def has_unbacked_symbols(self): + return False + + def is_contiguous(self) -> bool: + return False + + +class IndexExprDep(typing.NamedTuple): + index: sympy.Expr # type: ignore[assignment] + var_names: Tuple[sympy.Symbol, ...] + size: Tuple[sympy.Expr, ...] 
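+
+# Illustrative sketch (not part of the dependency tracking itself): the dep types
+# above are plain value objects, so buffer renaming maps uniformly over all of
+# them, and MemoryDep can answer simple layout questions from its index alone.
+def _example_deps():  # hypothetical helper for illustration; never called
+    c0 = sympy.Symbol("c0")
+    dep = MemoryDep("buf0", c0, (c0,), (sympy.Integer(128),))
+    assert dep.is_contiguous() and dep.ranges == {c0: 128}
+    renames = {"buf0": "buf0_v2"}
+    return [d.rename(renames) for d in (dep, StarDep("buf0"), WeakDep("buf0"))]
+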
+ + +@dataclasses.dataclass +class ReadWrites: + reads: Set[Dep] + writes: Set[Dep] + index_exprs: Set[IndexExprDep] + range_vars: Optional[List[sympy.Expr]] = None + var_ranges: Optional[VarRanges] = None + op_counts: typing.Counter[str] = dataclasses.field( + default_factory=collections.Counter + ) + + def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": + return ReadWrites( + {dep.rename(renames) for dep in self.reads}, + {dep.rename(renames) for dep in self.writes}, + self.index_exprs, + self.range_vars, + self.var_ranges, + op_counts=self.op_counts, + ) + + def with_read(self, dep: Dep) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep)) + return ReadWrites( + set.union(self.reads, {dep}), + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + op_counts=self.op_counts, + ) + + def merge(self, other: "ReadWrites"): + reads = set.union(self.reads, other.reads) + writes = set.union(self.writes, other.writes) + index_exprs = set.union(self.index_exprs, other.index_exprs) + op_counts = collections.Counter(self.op_counts) + op_counts.update(other.op_counts) + return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts) + + @staticmethod + def merge_list(read_writes: List["ReadWrites"]): + all_writes = set.union(*[rw.writes for rw in read_writes]) + all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes + all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes]) + + op_counts: typing.Counter[Any] = collections.Counter() + for rw in read_writes: + op_counts.update(rw.op_counts) + + return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts) + + def remove_reads(self, rem_reads): + return ReadWrites( + self.reads - rem_reads, + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + op_counts=self.op_counts, + ) + + def reads_and_writes(self): + return itertools.chain(self.reads, self.writes) + + +class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool): + super().__init__() + self._reads: Set[Dep] = set() + self._writes: Set[MemoryDep] = set() + self._index_exprs: Set[IndexExprDep] = set() + self._var_ranges: VarRanges = var_ranges + self._normalize: bool = normalize + + def canonicalize( + self, index: sympy.Expr + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + if not self._normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = tuple( + k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 + ) + sizes = tuple(v for v in sizes if v != 1) + return index, var_names, sizes # type: ignore[return-value] + + # Try to further simplify the indexes even if simplify_loops didn't + # convert it to the simplest form because of the interference from + # different indexing formulas. 
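+        # The steps below: simplify each range, collapse loops that always move
+        # together, then renumber the surviving variables with the canonical "c"
+        # prefix so that equivalent indexing expressions compare equal (e.g.
+        # d0*512 + d1 over {d0: 128, d1: 512} can collapse to a single c0 of
+        # range 65536).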
+ free_symbols = index.free_symbols + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + index_vars = [*var_ranges.keys()] + sizes = tuple(var_ranges.values()) + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, + sizes, + index_prevent_reordering([index], index_vars, sizes), + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + new_vars, add_var = var_builder(canonicalization_prefix()) + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + index = sympy_subs(sympy.expand(index), replacement) + + new_vars = [*new_vars.keys()] + new_sizes = [*new_sizes] + free_symbols = index.free_symbols + while new_vars and new_vars[-1] not in free_symbols: + # Reduction has last (reduced) dim in its sizes, but + # downstream users won't. Normalize this away. + new_vars.pop() + new_sizes.pop() + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + + def load(self, name: str, index: sympy.Expr) -> str: + self._reads.add(MemoryDep(name, *self.canonicalize(index))) + return f"load({name}, {sympy_str(index)})" + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + return self.load(name, sympy.Integer(index)) + + def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str: + self._writes.add(MemoryDep(name, *self.canonicalize(index))) + return f"store({name}, {sympy_str(index)}, {value}, {mode})" + + def store_reduction(self, name: str, index, value) -> str: + return self.store(name, index, f"store_reduction({value})") + + def index_expr(self, index: sympy.Expr, dtype) -> str: + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) + return f"index_expr({sympy_str(index)}, {dtype})" + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + self._reads.add(StarDep(offsets_name)) + return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" + + +class _OpCounter: + """Shim to count how many times each op is used""" + + def __init__(self, inner): + super().__init__() + self.parent_handler = inner + self._op_counts: typing.Counter[Any] = collections.Counter() + + def __getattr__(self, name): + self._op_counts[name] += 1 + return getattr(self.parent_handler, name) + + +class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool): + parent_handler = _RecordLoadStoreInner( + var_ranges=var_ranges, normalize=normalize + ) + parent_handler = _OpCounter(parent_handler) + super().__init__(parent_handler=parent_handler) + + +def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: + cnt = itertools.count() + var_ranges: VarRanges = dict() + + def add_var(length: sympy.Expr) -> sympy.Symbol: + v = sympy_index_symbol(f"{prefix}{next(cnt)}") + var_ranges[v] = length + return v + + return var_ranges, add_var + + +def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str): + var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Symbol]] = [] + for size in argsizes: + args.append(list(map(add_var, size))) + return args, var_ranges + + +def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"): + from .ir import SqueezeView + 
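+    # SqueezeView.squeezer drops size-1 dimensions; the returned reindex callable
+    # maps the squeezed variables back into the original positions so callers
+    # still receive one index expression per original dimension.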
+ var_ranges, add_var = var_builder(prefix) + args: List[List[sympy.Expr]] = [] + new_sizes: List[List[sympy.Expr]] = [] + for size in argsizes: + new_size, reindex = SqueezeView.squeezer(size) + new_sizes.append(new_size) + args.append(reindex(list(map(add_var, new_size)))) + return args, var_ranges + + +def extract_read_writes( + fn: Callable[..., Any], + *argsizes: Tuple[sympy.Expr, ...], + normalize: bool = False, + prefix: str = "d", +): + args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args) + + if normalize: + range_vars = [] # Number of vars could differ due to normalization + else: + range_vars = list(itertools.chain.from_iterable(args)) + + inner = rw.parent_handler.parent_handler + return ReadWrites( + set(inner._reads), + set(inner._writes), + inner._index_exprs, + range_vars, + var_ranges, + rw.parent_handler._op_counts, + ) + + +def extract_input_node_reduction_ranges( + input_node: "torch._inductor.ir.TensorBox", +) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: + """ + Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. + It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. + In this case, reduction_sizes of the Reduction nodes need to be the same. + Otherwise returns (None, None). + """ + + from .ir import ComputedBuffer, Loops + + if isinstance(input_node.data, ComputedBuffer): + # Input node has already been realized. Return its size and reduction_size. + size = input_node.get_size() + reduction_size = input_node.get_reduction_size() + if len(reduction_size) > 0: + return (size, reduction_size) + else: + return (None, None) + + if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] + # Other IRNodes do not have reduction_ranges. + return (None, None) + + # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? + # The current method still uses reduction ranges from the dependent realized node, which is not ideal. + # Is there a way to check whether there are permutations inbetween? 
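+    # Walk the read set breadth-first until realized ComputedBuffers with a
+    # reduction size are found; if any two of them disagree on size or
+    # reduction_size, give up and return (None, None).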
+ reads = input_node.get_reads() + reduction_size = None + size = None + while reduction_size is None and len(reads) > 0: + seen = set() + new_reads = [] + for read in reads: + if not isinstance(read, MemoryDep): + continue + if read.name in seen: + continue + seen.add(read.name) + buffer = V.graph.get_buffer(read.name) + if buffer is None: + continue + if ( + isinstance(buffer, ComputedBuffer) + and len(buffer.get_reduction_size()) > 0 + ): + if reduction_size is None: + reduction_size = buffer.get_reduction_size() + size = buffer.get_size() + elif ( + reduction_size != buffer.get_reduction_size() + or size != buffer.get_size() + ): + return (None, None) + else: + new_reads.extend(buffer.get_reads()) + if reads == new_reads: + return (size, reduction_size) + else: + reads = new_reads + return (size, reduction_size) + + +def canonicalization_prefix(): + return "c" + + +# ops handler which computes all the free unbacked symbols for an IR +class FreeUnbackedSymbolsOpsHandler: + symbols: Set[sympy.Symbol] + + def __init__(self): + self.symbols = set() + + def __getattr__(self, name: str) -> Callable[..., Any]: + def inner(*args, **kwargs): + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= free_unbacked_symbols(a) + + return inner + + def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol: + assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) + self.symbols |= free_unbacked_symbols(size) + return sympy_index_symbol(f"({str(index_var)})") + + def frexp(self, x): + return (None,) * 2 + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[None, Tuple[None, ...]], + ) -> Union[None, Tuple[None, ...]]: + num_values = reduction_num_outputs(reduction_type) + return (None,) * num_values if num_values > 1 else None + + +def _typecheck_FreeUnbackedSymbolsOpsHandler( + h: FreeUnbackedSymbolsOpsHandler, +) -> OpsHandler[None]: + return h + + +def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None): + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + handler = FreeUnbackedSymbolsOpsHandler() + # NB: I cargo culted the allow_indexing patch here, I don't understand why + # people do this all over + with V.set_ops_handler(handler), patch.object( + FlexibleLayout, "allow_indexing", True + ): + fn(*args) + return handler.symbols diff --git a/MLPY/Lib/site-packages/torch/_inductor/exc.py b/MLPY/Lib/site-packages/torch/_inductor/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb73e3f21cee556757f4cc21a003d8c6583cf6d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/exc.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import os +import tempfile +import textwrap +from functools import lru_cache + +if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": + + @lru_cache(None) + def _record_missing_op(target): + with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: + fd.write(str(target) + "\n") + +else: + + def _record_missing_op(target): # type: ignore[misc] + pass + + +class OperatorIssue(RuntimeError): + @staticmethod + def operator_str(target, args, kwargs): + lines = [f"target: {target}"] + [ + f"args[{i}]: {arg}" for i, arg in enumerate(args) + ] + if kwargs: + lines.append(f"kwargs: {kwargs}") + return textwrap.indent("\n".join(lines), " ") + + +class 
MissingOperatorWithoutDecomp(OperatorIssue): + def __init__(self, target, args, kwargs): + _record_missing_op(target) + super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") + + +class MissingOperatorWithDecomp(OperatorIssue): + def __init__(self, target, args, kwargs): + _record_missing_op(target) + super().__init__( + f"missing decomposition\n{self.operator_str(target, args, kwargs)}" + + textwrap.dedent( + f""" + + There is a decomposition available for {target} in + torch._decomp.get_decompositions(). Please add this operator to the + `decompositions` list in torch._inductor.decompositions + """ + ) + ) + + +class LoweringException(OperatorIssue): + def __init__(self, exc: Exception, target, args, kwargs): + super().__init__( + f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" + ) + + +class InvalidCxxCompiler(RuntimeError): + def __init__(self): + from . import config + + super().__init__( + f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}" + ) + + +class CppWrapperCodeGenError(RuntimeError): + def __init__(self, msg: str): + super().__init__(f"C++ wrapper codegen error: {msg}") + + +class CppCompileError(RuntimeError): + def __init__(self, cmd: list[str], output: str): + if isinstance(output, bytes): + output = output.decode("utf-8") + + super().__init__( + textwrap.dedent( + """ + C++ compile error + + Command: + {cmd} + + Output: + {output} + """ + ) + .strip() + .format(cmd=" ".join(cmd), output=output) + ) + + +class CUDACompileError(CppCompileError): + pass diff --git a/MLPY/Lib/site-packages/torch/_inductor/freezing.py b/MLPY/Lib/site-packages/torch/_inductor/freezing.py new file mode 100644 index 0000000000000000000000000000000000000000..8de18a38e140187af47d0fe5781cbaa10a37e533 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/freezing.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +import itertools +import logging + +import weakref +from typing import Any, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code +from torch._functorch.aot_autograd import MutationType +from torch._functorch.compile_utils import fx_graph_cse +from torch._inductor.constant_folding import constant_fold, replace_node_with_constant + +from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from torch._inductor.fx_passes.post_grad import view_to_reshape + +from . import config + +aten = torch.ops.aten +prims = torch.ops.prims + +log = logging.getLogger(__name__) + + +def replace_params_with_constants( + gm: torch.fx.GraphModule, + flat_params: list[Any], + fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, +) -> List[int]: + """ + Replaces the parameters of a PyTorch GraphModule with constants wherever possible. + Returns a list of indices representing the input parameters that were not converted to constants. 
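+    Parameters that are mutated inside the graph or whose storage is aliased by a
+    graph output are kept as runtime inputs rather than folded into constants.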
+ """ + params = [node for node in gm.graph.nodes if node.op == "placeholder"] + fake_inp_nodes = params[: len(params)] + preserved_arg_indices = [] + aliased_input_args = [ + out_info.base_idx + for out_info in fw_metadata.output_info + if out_info.base_idx is not None + ] + + # TODO (tmanlaibaatar) figure out why this is different + # from mutated_inp_runtime_indices + mutated_inps = [ + i + for i, m in enumerate(fw_metadata.input_info) + if m.mutation_type + in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] + + for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): + if i in mutated_inps or i in aliased_input_args: + preserved_arg_indices.append(i) + continue + replace_node_with_constant(gm, node, real_input) + # add on non param inputs + preserved_arg_indices.extend(range(len(flat_params), len(params))) + # is this necessary ? + gm.recompile() + return preserved_arg_indices + + +def freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: List[torch._subclasses.FakeTensor], +) -> Tuple[torch.fx.GraphModule, List[int]]: + """ + Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation + and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. + + Assumes that this function is run in dynamo tracing post aot_autograd. + + Args: + dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule. + aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen. + example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process. + + Returns: + Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices + of the inputs that were preserved (not turned into constants). + """ + # We have convert conv's weight to channels last which may meet error for .view + # when doing fake_tensor_prop. So we need to convert view to reshape first. + # See the details in fx_codegen_and_compile of compile_fx.py. + view_to_reshape(aot_autograd_gm) + + if tracing_context := torch._guards.TracingContext.try_get(): + fw_metadata = tracing_context.fw_metadata + params_flat = tracing_context.params_flat + assert fw_metadata is not None and params_flat is not None + + preserved_arg_indices = replace_params_with_constants( + aot_autograd_gm, params_flat, fw_metadata + ) + else: + inputs = [ + node for node in aot_autograd_gm.graph.nodes if node.op == "placeholder" + ] + preserved_arg_indices = list(range(len(inputs))) + + # TODO - further restrict cse ? 
right now needed to dedup aliasing ops + cse_graph = fx_graph_cse(aot_autograd_gm.graph) + aot_autograd_gm.graph = cse_graph + aot_autograd_gm.recompile() + + aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] + freezing_passes(aot_autograd_gm, aot_example_inputs) + + constant_fold(aot_autograd_gm) + # invalidate nn Modules + if config.freezing_discard_parameters: + invalidate_eager_modules() + discard_traced_gm_params(dynamo_gm) + + log.debug("%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm)) + + return aot_autograd_gm, preserved_arg_indices + + +class ErasedTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, name, owning_mod): + return super().__new__(cls, elem.to(device="meta")) + + def __init__(self, elem, name: Optional[str], mod): + self.erased_name = name + self.owning_mod_ref = weakref.ref(mod) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + erased_tensors = [ + e + for e in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(e, ErasedTensor) + ] + assert len(erased_tensors) > 0 + e = erased_tensors[0] + + raise RuntimeError( + f"Trying to run Pytorch Eager Module after Dynamo Freezing. " + "The original parameters have been discarded for memory efficiency. " + f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}" + ) + + +@torch.utils._python_dispatch._disable_current_modes() +def invalidate_eager_modules(): + for mod in torch._guards.TracingContext.get().module_context.nn_modules.values(): + if not isinstance(mod, torch.nn.Module): + continue + + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +@torch.utils._python_dispatch._disable_current_modes() +def discard_traced_gm_params(mod: torch.fx.GraphModule): + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True # type: ignore[attr-defined] + setattr(mod, attr_name, e_t) + + +def enforce_output_layout(gm: torch.fx.GraphModule): + """ + Make sure the output node's layout does not change due to compiler optimizations + by adding aten.as_strided nodes with the expected strides. + + Only used for inference so we can assume all graph outputs are model outputs. + """ + *_, output_node = gm.graph.nodes + out_list = output_node.args[0] + with gm.graph.inserting_before(output_node): + for n in out_list: + if not isinstance( + n.meta["val"], torch.Tensor + ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]): + continue + + # add a node to enforce eager layout + ft = n.meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n, ft.stride()) + ) + + # can not call + # n.replace_all_uses_with(new_node) + # since it will replace the usage of n in new_node itself. 
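+            # replace_input_with rewrites only the output node's reference to n,
+            # which is the only use we want to redirect here.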
+ output_node.replace_input_with(n, new_node) + + gm.graph.lint() + gm.recompile() + + +def enforce_as_strided_input_layout(gm: torch.fx.GraphModule): + """ + Make sure the as_strided node's input's layout does not change due to compiler + optimizations, because the as_strided strides info depends on input tensor stride info. + """ + + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops] + for n in strided_nodes: + with gm.graph.inserting_before(n): + # add a node to enforce eager layout + ft = n.args[0].meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n.args[0], ft.stride()) + ) + n.replace_input_with(n.args[0], new_node) + + gm.graph.lint() + gm.recompile() + + +@dynamo_timed +def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): + """ + Convert 4d convolution weight tensor to channels last format. + + This pass is performed before freezing so the added nodes can be constant + folded by freezing. + """ + convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default] + for conv in convs: + weight_node = conv.args[1] + if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ + "val" + ].is_contiguous(memory_format=torch.channels_last): + # not a 4d tensor or already channels last, skip + continue + + with gm.graph.inserting_before(conv): + new_node = gm.graph.call_function( + aten.clone.default, + (weight_node,), + {"memory_format": torch.channels_last}, + ) + conv.replace_input_with(weight_node, new_node) + + enforce_as_strided_input_layout(gm) + enforce_output_layout(gm) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69e044de5a3058c82c64b5292cf3fa706d1cce1e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..924e386148a9938ba3b5f7034ff8a31f0f1ee9ce Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a48ddef1ef6c206faac222726830caf1ee79ac5b Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3a4e5bf3aa83f2607103270eda6aa1643a378de2 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd2133ff8a669fb2af9ba4535feed0e5aac4604 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e80fe654a13df12421a0c59d54b1e4301388092 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e50be63a3df14bb9ac9165f6d858567b5a042d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09fdb6f5771f0820bc5a00348e8f4d73514d38c9 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194c7707eab000aa10cb4e829ac4a87662f9d700 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74b731fae158acbed29e2469d14a53c08870b4a4 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..123054a5c65babbb349efcc567783f3e535c6d13 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e6d5670dd7b046ca3360dff168549191bdba8f0e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbe921a293f754e0de2b57631f02882a6e4b4890 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c9fb96f886c6a092ff2e9b97cbaf3f48d004d11 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0185aa822214a328c054a19d2ad420cdb250c051 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583ed21c481c06ed3d6f4ec3184969fd5a4340c0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa3e8ae176ccc7ff6575c268d8cc50bae4b54dd5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..006cb41e80248fd95827223b3226bfeb07080f1e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31484150a6e7c7a20b4d6e5c598a03a679020d28 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..0f088b3af0bb4d41257ce3836f8bf067373feaa4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,277 @@ +import functools +import itertools 
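+
+# This pass folds an elementwise add/sub/mul/div with a constant tensor into the
+# weights/bias of the preceding convolution, so the pointwise op can disappear
+# after constant folding during freezing.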
+ +import torch +from ..._dynamo.utils import counters + +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype_conv(conv): + conv_dtype = conv.meta["val"].dtype + if conv_dtype not in (torch.float16, torch.bfloat16): + return + + if not len(conv.users) == 1: + return + + conv_user = next(iter(conv.users.keys())) + if not isinstance(conv_user.meta["val"], torch.Tensor): + return + + if not conv_user.meta["val"].dtype == torch.float32: + return + + while conv_user.target in _binary_ops: + if not len(conv_user.users) == 1: + return + + conv_user = next(iter(conv_user.users.keys())) + + if not ( + conv_user.target == prims.convert_element_type.default + and conv_user.args[1] == conv_dtype + ): + return + + conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype + + +def mark_mixed_dtype_allowed_convs(gm): + """ + Mark convolutions which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for node in gm.graph.nodes: + if node.target is aten.convolution.default: + mark_mixed_dtype_conv(node) + + +def recover_original_precision_folded_convs(gm): + """ + After binary folding conv weights and biases to a higher dtype, recover the original precision they were in. + """ + graph = gm.graph + convs = [node for node in graph.nodes if node.target is aten.convolution.default] + for node in convs: + orig_dtype = node.meta.get("_allow_conv_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for idx in [1, 2]: + old_input = node.args[idx] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.lru_cache(None) +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _computation_ops = [aten.convolution.default] + _computation_calls = [CallFunction(aten.convolution.default, *_conv_args, _users=1)] + + """ + In order to fuse add/sub/mul/div with conv, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcast to w/ weight/bias tensor shape + - broadcast to the conv output shape + It needs to have a shape that can resize to weight/bias + tensor shape because we need to run the op with the conv + weights/bias without changing their sizes. + It needs to broadcast to the conv output shape so that we do + accidentally change the shape of op output by pre-fusing it + compared to eager. + The only dimension value shared by weight/bias/conv output + is they all contain a dim with value = channels-out. In the + conv output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...]. 
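+            # e.g. weight [64, 3, 3, 3] with other [64, 1, 1] passes (its
+            # leading dim matches channels-out and every other dim is 1),
+            # while other [64, 3, 1] fails the check below.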
+ for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(conv_node.args[1].users) == 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + else: + # TODO: support scalar case + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + computation_node = binary_node.args[0] + other = binary_node.args[1] + if binary_node.args[0].target not in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + if binary_node.args[0].target == aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape): + # TODO: support scalar case + if other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target == aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = 
resize_scalar_or_tensor_to_shape( + graph, other, tuple(weight_broadcast_shape) + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, other, (weight_meta_value.size(0),) + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + computation_node = ( + binary_node.args[0] + if binary_node.args[0].target in _computation_ops + else binary_node.args[1] + ) + graph = match.graph + with graph.inserting_before(binary_node): + # TODO: support linear? + assert computation_node.target == aten.convolution.default + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + binary_node.replace_all_uses_with(new_computation_node) + new_computation_node.meta.update(computation_node.meta) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..10d1aaf3c8bf6d66d44ff9bac13d6c3a6bb3d818 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,221 @@ +import logging +from typing import List, Optional + +import torch +from torch import Tensor +from torch._dynamo.utils import counters +from torch._inductor import utils + +from ..pattern_matcher import ( + Arg, + CallFunction, + config_flag, + Ignored, + Match, + register_graph_pattern, +) +from .post_grad import decompose_mm_pass + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def should_decompose_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + return ( + torch._inductor.config.decompose_mem_bound_mm + and check_device(mat1, mat2) + and not utils.any_is_symbolic(mat1, mat2, input) + ) + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if not should_decompose_common(mat1, mat2): + return False + else: + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if mat1.shape[0] < MIN_FIRST_DIMENSION_DECOMPOSITION: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + if (mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION) + ( + mat1.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + (mat2.shape[2] < MAX_OTHER_DIMENSION_DECOMPOSITION) < 2: + return False + return True + + +def should_decompose_mm(mat1, mat2) -> bool: + if 
is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat2.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def should_decompose_mmt(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[0] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat1.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def should_decompose_mm_largek(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + return ( + should_decompose_common(mat1, mat2) + and len(mat1.shape) == 2 + and len(mat2.shape) == 2 + and mat1.shape[1] >= MIN_FIRST_DIMENSION_DECOMPOSITION + and mat1.shape[0] < MAX_OTHER_DIMENSION_DECOMPOSITION + and mat2.shape[1] < MAX_OTHER_DIMENSION_DECOMPOSITION + ) + + +def is_node_meta_valid(node: torch.fx.Node): + return "val" in node.meta + + +def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2) + mat1 + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, CallFunction(aten.permute, Arg(), Ignored()), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mmt( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0) + + if should_decompose_mmt(mat1, mat2): + counters["inductor"]["decompose_mmt"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mm( + match: Match, + mat1: 
torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=decompose_mm_pass, + extra_check=config_flag("decompose_mem_bound_mm"), +) +def decompose_mm_large_k( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + mat1 = mat1.permute(1, 0) + return torch.sum(mat1[:, :, None] * mat2[:, None, :], dim=0) + + if should_decompose_mm_largek(mat1, mat2): + counters["inductor"]["decompose_mm_large_k"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + return diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py new file mode 100644 index 0000000000000000000000000000000000000000..0df666affd75ab02b3b62c67d8fdd65bc3171fb5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +from typing import Union + +import torch +from torch.fx.experimental.proxy_tensor import py_sym_types, SymBool, SymFloat, SymInt + + +@dataclass +class _SymExprHash: + """ + Hash for a py_sym_types that will use the underlying sympy expression + """ + + sym_obj: Union[SymInt, SymFloat, SymBool] + + def __hash__(self) -> int: + return hash((type(self.sym_obj), self.sym_obj.node.expr)) + + def __eq__(self, value) -> bool: + if not isinstance(value, _SymExprHash): + return False + return self.sym_obj.node.expr == value.sym_obj.node.expr + + +class _SymHashingDict: + """ + Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse + existing sym proxies. + + SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, + fallback to symnodes. + """ + + def __init__(self): + self.sym_hash_dict = {} + + def __setitem__(self, key, value): + self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) + + def __getitem__(self, key): + return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] + + def __contains__(self, key): + return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict + + def get(self, key, default=None): + return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) + + def _wrap_to_sym_expr_hash(self, key): + return _SymExprHash(key) if isinstance(key, py_sym_types) else key + + +def dedupe_symints(graph: torch.fx.Graph): + """ + Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. + + We only dedupe from graph inputs to avoid adding a potential dependency in the forward + from the backward. 
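+    For example, a node that re-derives a symbolic expression already produced
+    by an earlier node has its uses redirected to that earlier node and is then
+    erased; non-placeholder nodes are only recorded as canonical when all of
+    their inputs are themselves resolvable from input symints.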
+ + """ + + sym_dict = _SymHashingDict() + resolvable_from_input_symints = set() + + for node in graph.nodes: + val = node.meta.get("val", None) + if val is None or not isinstance(val, py_sym_types): + continue + + if node.op == "placeholder": + resolvable_from_input_symints.add(node) + sym_dict[val] = node + elif existing_node := sym_dict.get(val): + node.replace_all_uses_with(existing_node) + graph.erase_node(node) + elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): + sym_dict[val] = node + resolvable_from_input_symints.add(node) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8712cf39a2a2e27b4e14c927324f3d2826a4e502 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn + +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import CallModuleVarArgs, Match, register_graph_pattern + +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. + """ + + assert bn.running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. 
+ target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. 
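+            # (fetching the live modules via `get_attr` means the fused call
+            # sees the current parameters and running stats, rather than baking
+            # them into the graph as constants)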
+ conv_get_node = graph.create_node( + op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr] + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..02f679925c534798e4ee6a11ad25e2f256b80196 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,212 @@ +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop +from ..._dynamo.utils import counters + +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. + if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(gm) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. 
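+    # (hence the gate on config.layout_optimization below, next to the mkldnn
+    # and weight_prepack checks)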
+ if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=binary_folding_pass, + ) + + +@functools.lru_cache(None) +def addmm_patterns_init(): + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + + def check_concat_weights(match): + weights = [ + match.kwargs["w1"], + match.kwargs["w2"], + ] + if "w3" in match.kwargs: + weights.append(match.kwargs["w3"]) + + return all( + w.op == "get_attr" and w.meta["val"].shape == weights[0].meta["val"].shape + for w in weights + ) + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + matmul_fuse_pattern, + matmul_replacement, + [val(), val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + matmul_fuse_pattern_two, + matmul_replacement_two, + [val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + addmm_fuse_pattern_second, + addmm_fuse_replacement_second, + [val() for _ in range(7)], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) diff 
--git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a47ec3558d114009135664e20025af2dd4cb32e7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py @@ -0,0 +1,786 @@ +import functools +import inspect +import logging +import math + +import torch +from ..._dynamo.utils import counters +from ..pattern_matcher import ( + filter_nodes, + fwd_only, + joint_fwd_bwd, + register_replacement, +) + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +def _sfdp_pattern_1(query, key, value, inv_scale): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_1(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_2(query, key, value, scale_factor): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(scale_factor) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_2(query, key, value, scale_factor): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale_factor) + .softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_5(query, key, value, attn_mask): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + # attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ value + + +def _sfdp_replacement_5(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +def _sfdp_replacement_6(query, key, 
value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_7(query, key, value, dropout_p): + # in real workloads inputs to matmul are permuted + # causing matmul to expand to a series of expand and clone calls + # we want the same to happen during pattern tracing + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_7(query, key, value, dropout_p): + # sdpa prefers inputs in permuted format + # it makes a copy to put them in this format + # if they aren't already + # to make replacement efficient ensure that inputs to sdpa + # are in required order + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return aten.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_8(query, key, value): + # no dropout version of pattern 7 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_8(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return aten.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_9(query, key, value, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_9(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return aten.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_10(query, key, value): + # no dropout version of 9 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_10(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return aten.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def 
_sfdp_pattern_11(query, key, value, inv_scale): + # Mainly for huggingface models + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v) + + +def _sfdp_replacement_11(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + +def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_13(query, key, value, dropout_p): + attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p) + return torch.bmm(attn_weight, value) + + +def _sfdp_replacement_13(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + dropout_p=dropout_p, + scale=1.0, + ).squeeze(0) + + +def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): + # for BertLarge + # Permutations are needed to create clones in graph. + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask) + .softmax(dim=-1) + .matmul(v) + ) + + +def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): + # for DistilBert + # Permutations are needed to create clones in graph. 
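+    # (as noted for the earlier patterns: real workloads feed permuted q/k/v
+    # into matmul, which materializes expand/clone nodes, and tracing the
+    # pattern the same way keeps it in sync with what the matcher sees)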
+ q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v + + +def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in aten.scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): + # for BertLarge with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + torch.nn.functional.dropout( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax( + dim=-1 + ), + dropout_p, + ) + .to(dtype=query.dtype) + .matmul(v) + ) + + +def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p): + # for DistilBert with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in aten.scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return aten.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_params_check(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_mask_node = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. 
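+    # (patterns without an explicit mask produce no aten.add.Tensor node here,
+    # so they skip the mask checks entirely)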
+ if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + ): + return False + return True + + +def _sfdp_extra_check(scale_factor_op, disable_cuda=False): + def fn(match): + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? + if not isinstance(scale_factor, (float, int)): + return False + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + return _sfdp_params_check(match) + + return fn + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def _get_sfdp_patterns(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + # attn_mask + b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) + # inv_scale + c_inp = functools.partial(torch.tensor, 2.0, device=device) + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + d = {"dropout_p": 0.113377} + + # we could also generate all these patterns in 3d.. TODO + g_3d_inp = functools.partial( + torch.empty, (1024, 128, 128), device=device, requires_grad=True + ) + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. 
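+    # (this is why some candidates below are registered twice, once with the
+    # regular g()/m() inputs and once with the *_bs1 variants)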
+ g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + + # softmax will generate a dtype conversion on inputs if they are in half, + # but will not in float, so we generate a pattern for both + for dtype in [torch.float, torch.half]: + g = functools.partial(g_inp, dtype=dtype) + b = functools.partial(b_inp, dtype=dtype) + m = functools.partial(m_inp, dtype=dtype) + m_float = functools.partial(m_inp, dtype=torch.float) + c = functools.partial(c_inp, dtype=dtype) + g_3d = functools.partial(g_3d_inp, dtype=dtype) + g_bs1 = functools.partial(g_bs1_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float) + + candidates = [ + ( + _sfdp_pattern_1, + _sfdp_replacement_1, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_2, + _sfdp_replacement_2, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_3, + _sfdp_replacement_3, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_4, + _sfdp_replacement_4, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_5, + _sfdp_replacement_5, + [g(), g(), g(), b()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_6, + _sfdp_replacement_6, + [g(), g(), g(), b()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_7, + _sfdp_replacement_7, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_8, + _sfdp_replacement_8, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_9, + _sfdp_replacement_9, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_10, + _sfdp_replacement_10, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_11, + _sfdp_replacement_11, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_12, + _sfdp_replacement_12, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_13, + _sfdp_replacement_13, + [g_3d(), g_3d(), g_3d()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_14, + _sfdp_replacement_14, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_15, + _sfdp_replacement_15, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_17, + _sfdp_replacement_17, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ] + mask_fp32_patterns = ["pattern_16"] + if dtype == torch.half: + # Add inputs of bf16 q/k/v and fp32 mask, for models like albert. 
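+            # (these reuse _sfdp_pattern_16 with a float32 mask; the registered
+            # pattern names pick up the "_mask_fp32" suffix further below)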
+ candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + + for pattern, replacement, args, workaround, extra_check in candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + if ( + any(p in name for p in mask_fp32_patterns) + and args[3].dtype == torch.float32 + ): + name += "_mask_fp32" + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + } + + if workaround: + assert len(workaround) == 1 and "dropout_p" in workaround + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + } + + +@functools.lru_cache(None) +def _sfdp_init(): + from .serialized_patterns.central_index import get_serialized_pattern + + for key, register_replacement_kwargs in _get_sfdp_patterns(): + search_fn_pattern = get_serialized_pattern(key) + register_replacement( + **register_replacement_kwargs, search_fn_pattern=search_fn_pattern + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..fbfb6aa4631d0fafd5c10f89a8bf00b17e6ea684 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py @@ -0,0 +1,1059 @@ +import collections +import logging +import operator +from collections import OrderedDict +from typing import ( + Any, + DefaultDict, + Deque, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, +) + +import torch +from torch._dynamo.utils import counters + +from .. 
import config +from ..pattern_matcher import ( + CallFunctionVarArgs, + get_arg_value, + stable_topological_sort, +) + +try: + # importing this will register fbgemm lowerings for inductor + import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 + + has_fbgemm = True +except Exception: + has_fbgemm = False + pass + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + +MIN_FUSE_SET_SIZE = 5 +MAX_FUSE_SET_SIZE = 300 +MAX_FUSE_SEARCH_DEPTH = 5 +# The maximum tensor size that can go into the fusion group +MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 + +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = {operator.getitem} + + +default_graph_search_options = { + "min_fuse_set_size": MIN_FUSE_SET_SIZE, + "max_fuse_set_size": MAX_FUSE_SET_SIZE, + "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, + "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, +} + +graph_search_options = default_graph_search_options + + +def update_stack_example_value(node, metadata, dim=0, op=torch.stack): + """ + Update the example value of the node in the graph to enable followup split cat opt. + """ + if node is not None and hasattr(node, "meta"): + if op == torch.stack: + example_value = torch.stack(metadata, dim=dim) + elif op == torch.unbind: + example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] + else: + return + node.meta["example_value"] = example_value + + +def update_pointwise_example_value(pointwise_node, input, other, op): + """ + Update the example value of the add node in the graph to enable followup split cat opt. + """ + if pointwise_node is not None and hasattr(pointwise_node, "meta"): + if op == torch.add: + example_value = torch.add(input, other) + elif op == torch.mul: + example_value = torch.mul(input, other) + else: + return + pointwise_node.meta["example_value"] = example_value + + +class GroupBatchFusionBase: + def __init__(self, **kwargs): + self.graph_search_options = kwargs.pop( + "graph_search_options", default_graph_search_options + ) + + def match(self, node): + raise NotImplementedError("match called on base") + + def fuse(self, graph, subset): + raise NotImplementedError("fuse called on base") + + +PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict() +POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = dict() + + +def register_fusion(name: str, pre_grad=True): + def decorator(fusion_cls: GroupBatchFusionBase): + if pre_grad: + PRE_GRAD_FUSIONS[name] = fusion_cls + else: + POST_GRAD_FUSIONS[name] = fusion_cls + return fusion_cls + + return decorator + + +def list_group_batch_fusions(pre_grad=True) -> List[str]: + if pre_grad: + return list(PRE_GRAD_FUSIONS.keys()) + else: + return list(POST_GRAD_FUSIONS.keys()) + + +def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any: + unsqueezed_inputs = [] + for input_tensor in input_tensors: + unsqueezed_input = graph.call_function( + aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} + ) + unsqueezed_inputs.append(unsqueezed_input) + stacked_inputs = graph.call_function( + aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} + ) + return stacked_inputs + + +class GroupFusion(GroupBatchFusionBase): + """ + Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. + """ + + pass + + +class BatchFusion(GroupBatchFusionBase): + """ + Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. 
+ """ + + pass + + +class BatchPointwiseOpsFusionFactory(BatchFusion): + def __init__(self, op, **kwargs): + super().__init__(**kwargs) + self.op = op + + +@register_fusion("batch_linear_post_grad", pre_grad=False) +class PostGradBatchLinearFusion(BatchFusion): + """ + Fuse ops in a batch way in post grad (aten level). + """ + + def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + return ( + node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value] + ) + + def _is_input_2d(self, input: torch.fx.Node) -> bool: + input_shapes = input.meta["tensor_meta"].shape + return ( + len(input_shapes) == 2 + and isinstance(input_shapes[0], int) + and isinstance(input_shapes[1], int) + ) + + def match(self, node: torch.fx.Node) -> Optional[Tuple[str, int, int, int, bool]]: + if CallFunctionVarArgs(aten.mm).match(node): + input_m, weight_m = node.args + bias_m = None + + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias_m, input_m, weight_m = node.args + else: + return None + + # only handle the cases where inputs are 2D tensors + if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] + return None + m, k = input_m.meta["tensor_meta"].shape # type: ignore[union-attr] + n = weight_m.meta["tensor_meta"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear", m, k, n, bias_m is not None) + return batch_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_nodes = [] + + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + elif CallFunctionVarArgs(aten.mm.default).match(node): + input, weight = node.args + bias = None + batch_nodes.append(node) + batch_inputs.append(input) # type: ignore[possibly-undefined] + batch_weights.append(weight) # type: ignore[possibly-undefined] + batch_biases.append(bias) # type: ignore[possibly-undefined] + + with graph.inserting_before(subset[-1]): + fused_inputs = decompose_stack(graph, batch_inputs) + fused_weights = decompose_stack(graph, batch_weights) + fused_bmm = graph.call_function( + aten.bmm, + args=(fused_inputs, fused_weights), + ) + + for i, original_mm in enumerate(batch_nodes): + has_bias = False + with graph.inserting_after(fused_bmm): + new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) + if batch_biases[i]: + has_bias = True + new_bias_add = graph.call_function( + aten.add, args=((batch_biases[i], new_mm)) + ) + new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] + original_mm.replace_all_uses_with(new_mm_cont) + new_mm_cont.meta.update(original_mm.meta) + graph.erase_node(original_mm) + + +@register_fusion("group_linear", pre_grad=False) +class GroupLinearFusion(GroupFusion): + def _addmm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["tensor_meta"].shape # type: ignore[union-attr] + return ( + node.kwargs.get("beta", 1.0) == 1.0 + and node.kwargs.get("alpha", 1.0) == 1.0 + and len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def _mm_node_can_be_fused(self, node: 
torch.fx.Node): + input_shape = node.args[0].meta["tensor_meta"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["tensor_meta"].shape # type: ignore[union-attr] + return ( + len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]: + if CallFunctionVarArgs(aten.mm.default).match( + node + ) and self._mm_node_can_be_fused(node): + group_key = ("group_linear", True) + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias = node.args[0] + group_key = ("group_linear", bias is None) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + group_inputs = [] + group_weights = [] + group_biases = [] + group_nodes = [] + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + else: + assert CallFunctionVarArgs(aten.mm.default).match(node) + input, weight = node.args + bias = None + + group_nodes.append(node) + group_inputs.append(input) + group_weights.append(weight) + group_biases.append(bias) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + group_biases: Optional[List[Any]] + + with graph.inserting_before(subset[0]): + fused_mm = graph.call_function( + torch.ops.fbgemm.gmm.default, + args=(group_inputs, group_weights, group_biases), + kwargs={"smart_fused": True}, + ) + + for i, original_mm in enumerate(group_nodes): + with graph.inserting_after(fused_mm): + new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) + original_mm.replace_all_uses_with(new_mm) + new_mm.meta.update(original_mm.meta) + graph.erase_node(original_mm) + + +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise operator (e.g., add, mul) in post grad pass. 
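+    As the match/fuse methods below show, candidates are keyed on shape, dtypes, alpha
+    and rounding_mode; the fused form stacks both operand lists, applies the op once,
+    then selects each per-node result back out.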
+ """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def _pointwise_node_can_be_fused(self, node: torch.fx.Node): + # note: we only consider the case where the inputs are tensors + # for mixed precision training, we need to make sure the inputs + # of the aten.cat when do the stack should be the same dtype + # otherwise, the output of the aten.cat may be not the same as + # its inputs, and cause dtype not same error in mm or addmm + input, other = node.args + return ( + input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape # type: ignore[union-attr] + if hasattr(input, "meta") + and hasattr(other, "meta") + and "tensor_meta" in input.meta # type: ignore[union-attr] + and "tensor_meta" in other.meta # type: ignore[union-attr] + else False + ) + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(self.op).match( + node + ) and self._pointwise_node_can_be_fused(node): + alpha = node.kwargs.get("alpha", 1.0) + rounding_mode = node.kwargs.get("rounding_mode", None) + input, other = node.args + shape = list(input.meta["tensor_meta"].shape) # type: ignore[union-attr] + group_key = ( + "batch_" + self.op.__name__.lower() + "_post_grad", + str(shape), + str(input.meta["tensor_meta"].dtype), # type: ignore[union-attr] + str(other.meta["tensor_meta"].dtype), # type: ignore[union-attr] + str(alpha), + str(rounding_mode), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_inputs, batch_others = [], [] + alpha = subset[0].kwargs.get("alpha", 1.0) + + for node in subset: + input, other = node.args + batch_inputs.append(input) + batch_others.append(other) + + with graph.inserting_before(subset[0]): + stack_inputs = decompose_stack(graph, batch_inputs) + stack_others = decompose_stack(graph, batch_others) + + batch_op = graph.call_function( + self.op, + args=(stack_inputs, stack_others), + kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, + ) + for i, original_add in enumerate(subset): + with graph.inserting_after(batch_op): + new_add = graph.call_function( + torch.ops.aten.select, args=((batch_op, 0, i)) + ) + original_add.replace_all_uses_with(new_add) + new_add.meta.update(original_add.meta) + graph.erase_node(original_add) + + +@register_fusion("batch_linear_lhs") +class BatchLinearLHSFusion(BatchFusion): + """ + Batch linear left-hand side fusion. This pass tries to fuse the following patterns: + + torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) + -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) + + We have a separate pass to eliminate contiguous transpose in a generic way. 
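+    With illustrative shapes: for x of shape [B, K] and weights w_i of shape [N_i, K],
+    the fused mm/addmm produces a [B, N_1 + ... + N_n] result, which is then split back
+    along dim 1 into the original per-linear outputs.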
+ """ + + def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]: + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + bias = get_arg_value(node, 2, "bias") + group_key = ("batch_linear_lhs", bias is None, input) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_input = None + batch_weights = [] + batch_biases = [] + split_sections = [] + for node in subset: + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + batch_nodes.append(node) + if batch_input is None: + batch_input = input + else: + assert batch_input is input + batch_weights.append(weight) + if bias: + batch_biases.append(bias) + split_sections.append(weight.meta["example_value"].shape[0]) + + with graph.inserting_before(subset[0]): + cat_weights = graph.call_function( + torch.cat, args=(batch_weights,), kwargs={"dim": 0} + ) + transposed_weights = graph.call_function( + torch.transpose, args=(cat_weights, 0, 1) + ) + if len(batch_biases) > 0: + cat_biases = graph.call_function( + torch.cat, args=(batch_biases,), kwargs={"dim": 0} + ) + fused_lhs = graph.call_function( + torch.addmm, + args=(cat_biases, batch_input, transposed_weights), + ) + else: + fused_lhs = graph.call_function( + torch.mm, + args=(batch_input, transposed_weights), + ) + fused_lhs_list = graph.call_function( + torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} + ) + + for i, node in enumerate(batch_nodes): + with graph.inserting_after(fused_lhs_list): + new_node = graph.call_function( + operator.getitem, args=(fused_lhs_list, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + + +def is_node_meta_valid(node: Optional[torch.fx.Node]): + if node is None: + return True + if "example_value" not in node.meta: + return False + return True + + +def is_linear_node_can_be_fused(node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + return ( + is_node_meta_valid(node) + and is_node_meta_valid(input) + and is_node_meta_valid(weight) + and len(input.meta["example_value"].shape) == 2 + and len(weight.meta["example_value"].shape) == 2 + ) + + +@register_fusion("batch_linear") +class PreGradBatchLinearFusion(BatchFusion): + """ + Batch linear fusion in pre grad pass. 
+ Fuse linear with same size with torch.baddmm + """ + + def _getitem_args(self, getitem_node: torch.fx.Node): + if getitem_node.target != operator.__getitem__ or ( + getitem_node.op != "call_function" + ): + return None + return getitem_node.args[0] + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + group_key = ( + "batch_linear_pre_grad", + self._getitem_args(input), + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape), + bias is None, + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_inputs_metadata = [] + batch_weights_metadata = [] + batch_biases_metadata = [] + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + weight = get_arg_value(node, 1, "weight") + batch_weights.append(weight) + batch_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 2, "bias") + batch_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + batch_biases_metadata.append(bias.meta["example_value"]) + + with graph.inserting_before(subset[0]): + stack_inputs = graph.call_function( + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + stack_weights = graph.call_function( + torch.stack, args=(batch_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weights, batch_weights_metadata) + transpose_weight = graph.call_function( + torch.transpose, args=(stack_weights, 1, 2) + ) + if all(bias is None for bias in batch_biases): + bmm = graph.call_function( + torch.bmm, + args=(stack_inputs, transpose_weight), + ) + else: + stack_biases = graph.call_function( + torch.stack, args=(batch_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_biases, batch_biases_metadata) + unsqueeze_biases = graph.call_function( + torch.unsqueeze, args=(stack_biases, 1) + ) + bmm = graph.call_function( + torch.baddbmm, + args=(unsqueeze_biases, stack_inputs, transpose_weight), + ) + + bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) + for i, linear in enumerate(batch_nodes): + with graph.inserting_after(bmm): + getitem = graph.call_function(operator.getitem, args=(bmm, i)) + linear.replace_all_uses_with(getitem) + getitem.meta.update(linear.meta) + graph.erase_node(linear) + + +@register_fusion("batch_layernorm") +class BatchLayernormFusion(BatchFusion): + """ + Batch layer norm fusion in pre grad pass + """ + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 2, "weight") + bias = get_arg_value(node, 3, "bias") + group_key = ( + ( + "batch_layernorm", + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape) + if weight is not None + else "", + str(bias.meta["example_value"].shape) if bias is not None else "", + str(get_arg_value(node, 1, "normalized_shape")), + str(get_arg_value(node, 4, "eps")), + ) + if "example_value" in input.meta + and is_node_meta_valid(weight) + and 
is_node_meta_valid(bias) + else None + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + group_inputs = [] + group_shapes = [] + group_weights = [] + group_biases = [] + group_epss = [] + group_nodes = [] + group_inputs_metadata = [] + group_biases_metadata = [] + group_weights_metadata = [] + for node in subset: + group_nodes.append(node) + input = get_arg_value(node, 0, "input") + group_inputs.append(input) + group_inputs_metadata.append(input.meta["example_value"]) + group_shapes.append(get_arg_value(node, 1, "normalized_shape")) + weight = get_arg_value(node, 2, "weight") + group_weights.append(weight) + if weight is not None and hasattr(weight, "meta"): + group_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 3, "bias") + group_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + group_biases_metadata.append(bias.meta["example_value"]) + eps = get_arg_value(node, 4, "eps") + if eps is None: + eps = 1e-5 + group_epss.append(eps) + stack_dim = -1 - len(group_shapes[-1]) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + group_biases: Optional[List[Any]] + if all(weight is None for weight in group_weights): + group_weights = None # type: ignore[assignment] + group_weights: Optional[List[Any]] + assert all( + eps == group_epss[0] for eps in group_epss + ), "all epsilon values must be equal" + + with graph.inserting_before(subset[0]): + stack_input = graph.call_function( + torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} + ) + update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) + if group_weights is not None: + stack_weight = graph.call_function( + torch.stack, args=(group_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weight, group_weights_metadata) + else: + stack_weight = None + if group_biases is not None: + stack_bias = graph.call_function( + torch.stack, args=(group_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_bias, group_biases_metadata) + else: + stack_bias = None + + batch_layer_norm = graph.call_function( + torch.nn.functional.layer_norm, + args=(stack_input, group_shapes[-1]), + kwargs={"eps": group_epss[-1]}, + ) + batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] + + if group_weights is not None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + elif group_weights is not None and group_biases is None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + elif group_weights is None and group_biases is not None: + previous_batch_layer_norm_meta = 
batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + + batch_layer_norm_unbind = graph.call_function( + torch.unbind, + args=(batch_layer_norm,), + kwargs={"dim": stack_dim}, + ) + update_stack_example_value( + batch_layer_norm_unbind, + batch_layer_norm.meta["example_value"], + op=torch.unbind, + dim=stack_dim, + ) + + for i, node in enumerate(group_nodes): + with graph.inserting_after(batch_layer_norm_unbind): + new_node = graph.call_function( + operator.getitem, args=(batch_layer_norm_unbind, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + + +class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch poinwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. + We fuse it in random place, and the introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + group_key = ( + "batch_" + self.op.__name__.lower() + "_pre_grad", + str(input.meta["example_value"].shape), + str(node.kwargs.get("inplace", False)), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): + stack_inputs = graph.call_function( + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + if self.op == torch.nn.functional.relu: + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, + ) + else: + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + ) + unbind_op = graph.call_function( + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + + +@register_fusion("batch_tanh") +class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.tanh, **kwargs) + + +@register_fusion("batch_sigmoid") +class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.sigmoid, **kwargs) + + +@register_fusion("batch_relu") +class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nn.functional.relu, **kwargs) + + +@register_fusion("batch_aten_add", pre_grad=False) +class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.add.Tensor, **kwargs) + + +@register_fusion("batch_aten_sub", pre_grad=False) +class 
BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.sub.Tensor, **kwargs) + + +@register_fusion("batch_aten_div", pre_grad=False) +class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.div.Tensor, **kwargs) + + +@register_fusion("batch_aten_mul", pre_grad=False) +class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs): + super().__init__(aten.mul.Tensor, **kwargs) + + +class _OrderedSet: + def __init__(self, param=None): + if param: + self.rep = OrderedDict({k: None for k in param}) + else: + self.rep = OrderedDict() + + def __contains__(self, o): + return o in self.rep + + def __len__(self): + return self.rep.__len__() + + def append(self, o): + self.rep[o] = None + + def __iter__(self): + return self.rep.keys().__iter__() + + +def find_independent_subset_greedy( + node_list: Iterable[torch.fx.Node], + graph_search_options: Dict[str, Any], +) -> Iterator[Iterable[torch.fx.Node]]: + """ + Yields a list of subsets of `node_list` where no element in the subset + depends on any other element in the subset. This results in a set of + independent nodes which can be fused together. + + The order of `node_list` is preserved within each subset so we can benefit + from split-cat elimination in later passes. + + During iteration it is only safe to mutate the graph by changing the nodes + that have been returned. + + graph_search_options: + - min_fuse_set_size: Minimum size of the subset to consider. Subsets below + this size will be ignored. + - max_fuse_set_size: Maximum size of the subset to consider. Subsets will + be broken to be at most this size. + """ + + # Compute all the children of `node` which are members of + # `interesting_nodes`. + def find_dependent_nodes(node, interesting_nodes): + visited_node_set: Set[torch.fx.Node] = {node} + dep_set: Set[torch.fx.Node] = set() + + work = [node] + while work: + node = work.pop() + for input_node in node.all_input_nodes: + if input_node in interesting_nodes: + dep_set.add(input_node) + + if input_node not in visited_node_set: + visited_node_set.add(input_node) + work.append(input_node) + + return dep_set + + min_fuse_set_size = graph_search_options["min_fuse_set_size"] + max_fuse_set_size = graph_search_options["max_fuse_set_size"] + + # node_list needs to be a set because we only track the nodes that are left + # in it (and we want to do the `in` on a set, not a list). But we want to + # keep the correct order. + node_list = _OrderedSet(node_list) + + cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {} + while node_list: + subset: List[torch.fx.Node] = [] + subset_deps: Set[torch.fx.Node] = set() + + next_round_node_list = _OrderedSet() + for node in node_list: + if len(subset) >= max_fuse_set_size or node in subset_deps: + next_round_node_list.append(node) + continue + + dep_set = cache.pop(node, None) + if dep_set is None: + dep_set = find_dependent_nodes(node, node_list) + + if not dep_set.intersection(subset): + subset.append(node) + subset_deps.update(dep_set) + else: + next_round_node_list.append(node) + cache[node] = dep_set + + if len(subset) >= min_fuse_set_size: + # Careful here - the caller uses the subsets to fuse nodes together + # so we need to clear any cache entry that contains one of the + # returned nodes because the dependency list could be different + # (larger) after the merge. 
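+            # Keep only cache entries whose recorded dependencies are disjoint from
+            # the yielded subset; dropped entries are simply recomputed on a later
+            # round if their node is revisited.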
+ cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} + yield subset + + node_list = next_round_node_list + + +def get_fusion_candidates( + rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node] +) -> DefaultDict[Any, List[torch.fx.Node]]: + """ + Search fusion candidates for a specific rule using BFS starting from the root node. + We only search the subgraph within graph_search_options["max_fuse_search_depth"]. + """ + q: Deque[Tuple[int, torch.fx.Node]] = collections.deque() + + candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict( + list + ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + + visited_set: Set[torch.fx.Node] = set() + + for next_node in root_node.all_input_nodes: + q.append((1, next_node)) + visited_set.add(next_node) + + while len(q) > 0: + depth, node = q.popleft() + + if node in fused_set: + continue + + key = rule.match(node) + if key is not None: + candidate_nodes = candidate_dict[key] + if node not in candidate_nodes: + candidate_nodes.append(node) + else: + if depth < rule.graph_search_options["max_fuse_search_depth"]: + for next_node in node.all_input_nodes: + if next_node not in visited_set: + visited_set.add(next_node) + q.append((depth + 1, next_node)) + + return candidate_dict + + +def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): + stable_topological_sort(graph) # type: ignore[arg-type] + fused_set: Set[torch.fx.Node] = set() + + for node in reversed(graph.nodes): + candidates = get_fusion_candidates(rule, node, fused_set) + + for key, candidate_nodes in candidates.items(): + if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: + continue + + for subset in find_independent_subset_greedy( + candidate_nodes, rule.graph_search_options + ): + rule.fuse(graph, subset) + fused_set.update(subset) + if isinstance(rule, GroupFusion): + counters["inductor"]["group_fusion"] += 1 + elif isinstance(rule, BatchFusion): + counters["inductor"]["batch_fusion"] += 1 + else: + counters["inductor"]["unknown_group_batch_fusion"] += 1 + + log.debug( + f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 + ) + + +def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True): + fusions: List[GroupBatchFusionBase] = [] + for name, options in config_options.items(): + fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] + _options = graph_search_options.copy() + _options.update(options) + fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] + return fusions + + +def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): + fusions: List[GroupBatchFusionBase] = [] + # we keep all current pre grad fusions to keep + # current implementation, will remove this later + if pre_grad: + fusions += generate_fusion_from_config( + config.pre_grad_fusion_options, pre_grad=True + ) + else: + fbgemm_fusion_keys = [ + x + for x in config.post_grad_fusion_options + if config.post_grad_fusion_options[x].get("require_fbgemm", False) + ] + fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in fbgemm_fusion_keys + } + non_fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in config.post_grad_fusion_options.keys() + if fusion not in fbgemm_fusion_keys + } + fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) + if has_fbgemm: + fusions += 
generate_fusion_from_config(fbgemm_fusions, pre_grad=False) + + for rule in fusions: + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3e292db6922a21a9cfc0f73bb03508ad07f7b4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,341 @@ +import logging +import typing +from collections import Counter +from typing import Dict, List, Set + +import torch +import torch._guards +from torch._inductor.constant_folding import ConstantFolder +from torch.multiprocessing.reductions import StorageWeakRef + +from .. import config +from ..pattern_matcher import ( + CallFunction, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from .replace_random import replace_random_passes + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() + + +@init_once_fakemode +def lazy_init(): + from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init + from .pad_mm import _pad_mm_init + + _pad_mm_init() + _sfdp_init() + _misc_patterns_init() + + +@torch.utils._python_dispatch._disable_current_modes() +def remove_no_ops( + gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node] +): + "Removes no-ops: (+ 0, - 0, * 1, / 1)" + aten = torch.ops.aten + graph = gm.graph + + def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): + if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): + return False + for field in fields: + if getattr(t1, field) != getattr(t2, field): + return False + return True + + def replace_no_op(node, replace_input_index): + replacement = node.args[replace_input_index] + + # https://github.com/pytorch/pytorch/issues/86128 causes + # non-Tensor inputs even for ops with only Tensor inputs. 
+ # TODO - decompose/type promote to avoid this + if not all(isinstance(arg, torch.fx.Node) for arg in node.args): + return + + if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): + if fake_tensors_eq( + node.meta["val"], + replacement.meta["val"], + ("shape", "device"), + ): + with graph.inserting_after(node): + replacement = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(replacement, node.meta["val"].dtype), + ) + else: + return + + node.replace_all_uses_with(replacement) + replacement.meta.update(node.meta) + graph.erase_node(node) + + for node in graph.nodes: + if node.op != "call_function": + continue + + # TODO handle Tensor-Scalar adds, it's a different schema + if node.target == aten.add.Tensor and len(node.args) == 2: + if ( + not any(e in zeros for e in node.args) + or node.kwargs.get("alpha", 1) != 1 + ): + continue + + replace_index = 1 if node.args[0] in zeros else 0 + replace_no_op(node, replace_index) + + elif node.target == aten.sub.Tensor and len(node.args) == 2: + if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: + continue + + replace_no_op(node, 0) + + elif node.target == aten.mul.Tensor and len(node.args) == 2: + if not any(e in ones for e in node.args): + continue + + replace_input_index = 1 if node.args[0] in ones else 0 + replace_no_op(node, replace_input_index) + + elif ( + node.target == aten.div.Tensor + and len(node.args) == 2 + and node.args[1] in ones + ): + replace_no_op(node, 0) + + +@torch.utils._python_dispatch._disable_current_modes() +def remove_redundant_views(gm: torch.fx.GraphModule): + """ + Removes redundant views by reusing existing ones. + """ + + # A dictionary mapping a tensor to all aliased views. + views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {} + graph = gm.graph + + for node in graph.nodes: + if node.op != "call_function": + continue + + if node.target != torch.ops.aten.view.dtype: + continue + + src = node.args[0] + to_type = node.args[1] + existing_views = views.get(src) + is_needed = True + + if existing_views: + # Replace the view with the an existing view if available. + alias = existing_views.get(to_type) + if alias: + is_needed = False + node.replace_all_uses_with(alias) + alias.meta.update(node.meta) + graph.erase_node(node) + else: + from_type = src.meta["val"].dtype + existing_views = {from_type: src} + views[src] = existing_views + + if is_needed: + # Save the new alias but do not replace existing one. + existing_views.setdefault(to_type, node) + views[node] = existing_views + + # Clean up unused views. + while True: + unused_views = [alias for alias in views if not alias.users] + if len(unused_views) == 0: + break + for unused in unused_views: + views.pop(unused) + graph.erase_node(unused) + + +class UniformValueConstantFolder(ConstantFolder): + """ + Runs constant folding and replaces tensors that have a unifrom value + with a tensor constructor call: aten.full([shape], value, ...) 
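+    For example, a folded tensor that is 2.0 everywhere with shape [8, 8] can later be
+    re-materialized as aten.full([8, 8], 2.0, ...) rather than being kept as a baked-in
+    constant buffer.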
+ """ + + def __init__(self, gm, skip_constructors=False): + super().__init__(gm, skip_constructors) + self.node_storages_ptrs: Dict[torch.fx.Node, int] = {} + self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {} + # we may constant fold a tensor which in the graph has a sym size + # see: [constant folding refining of symints] + self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {} + + def insertable_tensor_check(self, t: torch.Tensor) -> bool: + # TODO - we could also Tensors which get replaced with arange here + return ( + t.numel() != 0 + and bool((t == t.flatten()[0]).all()) + and torch._C._has_storage(t) + and t.layout == torch.strided + ) + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor.flatten()[0].item() + self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) + shape = list(tensor.shape) + assert all(type(dim) is int for dim in shape) + self.node_replacements_shapes[node] = shape + + +@torch.utils._python_dispatch._disable_current_modes() +def constant_fold_uniform_value(gm: torch.fx.GraphModule): + "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." + aten = torch.ops.aten + + # Constant folding can leak memory, especially with repeated compilation, so we are only going to + # remove constants which can be replaced with a constructor. + cf = UniformValueConstantFolder(gm) + cf.run() + + node_replacements = cf.node_replacements + + # note: [constant folding refining of symints] + # constant folding will partially evaluate a graph such that values which have dependencies which + # are entirely known at compile time may also become compile time constants. in some cases, + # this will include symints which we had not yet previously deduced are guaranteed a + # constant value and is then deduced in constant folding. 
an example is: + # unbacked_symint_eq_11 = torch.full((), 11).item() + # torch.full((unbacked_symint_eq_11,), 0) + node_replacements_shapes = cf.node_replacements_shapes + + graph = gm.graph + + zeros = set() + ones = set() + + # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, + # so just constant-ify if a Tensor is unaliased + constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() + + for node in cf.node_replacements: + constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 + + for node, value in node_replacements.items(): + # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now + # hasn't shown up to be important yet + fake_tensor = node.meta["val"] + if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): + continue + + if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: + continue + + with graph.inserting_after(node): + # the conversion from tensor and back to value can be lossy, just use the original full ctor value + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + value = node.args[1] + + # refines symints, see [constant folding refining of symints] above + for runtime_size, compile_time_size in zip( + node_replacements_shapes[node], fake_tensor.shape + ): + torch._check(runtime_size == compile_time_size) + + # zeros, and ones just get traced into full, so we insert those + new_node = graph.call_function( + aten.full.default, + args=(node_replacements_shapes[node], value), + kwargs={ + "dtype": fake_tensor.dtype, + "layout": torch.strided, + "device": fake_tensor.device, + "pin_memory": False, + }, + ) + + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if value == 0: + zeros.add(new_node) + elif value == 1: + ones.add(new_node) + + remove_no_ops(gm, zeros, ones) + remove_redundant_views(gm) + + +def joint_graph_passes(graph: torch.fx.GraphModule): + """ + Run FX transformations on the joint forwards+backwards graph. 
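+    Concretely: uniform-value constant folding (if enabled), the registered
+    pattern-matcher passes, and the random-op replacement passes (unless
+    config.fallback_random is set); if any matches were applied, the graph is
+    re-sorted, linted and recompiled.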
+ """ + lazy_init() + count = 0 + + if config.joint_graph_constant_folding: + constant_fold_uniform_value(graph) + + if config.pattern_matcher: + count += patterns.apply(graph.graph) # type: ignore[arg-type] + + if not config.fallback_random: + count += replace_random_passes(graph) + + if count: + stable_topological_sort(graph.graph) + graph.graph.lint() + graph.recompile() + return graph + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + KeywordArg("arg"), + KeywordArg("dtype1"), + ), + KeywordArg("dtype2"), + ), + pass_dict=patterns, +) +def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): + """Remove chain of dtype conversions often created by AMP""" + graph = match.graph + node = match.output_node() + allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64} + if dtype1 in allowed and dtype2 in allowed: + repl = graph.call_function( + torch.ops.prims.convert_element_type.default, (arg, dtype2) + ) + repl.meta.update(node.meta) + node.replace_all_uses_with(repl) + match.erase_nodes(graph) + + +@register_graph_pattern( + CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + pass_dict=patterns, +) +def pointless_view(match: Match, arg, size): + """Remove no-op view""" + graph = match.graph + node = match.output_node() + arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] + if size == arg_size: + node.replace_all_uses_with(node.args[0]) + match.erase_nodes(graph) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..e91fdd6611af037d7e855489d990ec31f0c490cf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,130 @@ +import functools + +from typing import Dict, Set, Tuple + +import torch +from torch._dynamo.utils import counters + +from torch._ops import OpOverload, OpOverloadPacket +from ..pattern_matcher import fwd_only, register_replacement + +aten = torch.ops.aten + + +@functools.lru_cache(None) +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. 
Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + pattern = register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: Dict[str, Tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: Dict[str, str] + cache: Dict["torch.fx.graph.Target", Set[str]] + + def __init__(self): + self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. 
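+                # for the remaining torch-level ops, the code below rewrites
+                # numpy-style kwargs via numpy_compat, e.g. axis -> dim and
+                # keepdims -> keepdim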
+ continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = set() + for sig in signatures: + for param_name in sig.parameters.keys(): + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..81da87fd503a8850d367ba9c7308e757ec976919 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1204 @@ +import functools +import operator +from functools import reduce +from typing import Any, Tuple + +import torch + +from torch.fx.experimental.symbolic_shapes import has_free_symbols + +from .. import ir + +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..virtualized import ops +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, +) + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, 
computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def _hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + 
[ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, fn): + binary_nodes = filter_nodes(match.nodes, fn) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. 
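+        # aten.add/aten.sub take an optional alpha that scales the second operand
+        # (a +/- alpha * b); the fused lowerings below assume the plain case, so
+        # reject any node that carries an explicit alpha other than 1.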
+ if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + if any( + get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size() + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. + def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = set() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert ( + len(_binary_node.all_input_nodes) == 2 + ), "Binary node should have 2 input nodes." 
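+                # of the two inputs, one is the computation op (conv/linear) output
+                # and the other is the extra tensor; return whichever is not at
+                # _other_index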
+ _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + if isinstance(_other.data, ir.View): + return _can_be_inplace(_other.data) + else: + return not ( + isinstance(_other.data, ir.ReinterpretView) + or isinstance( + _other.get_layout(), (ir.MutationLayout, ir.AliasedLayout) + ) + ) + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + # Make sure the other is not an alias or mutation(fx side doesn't has such info). 
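+            # realize() forces `other` into a concrete buffer so _can_be_inplace can
+            # inspect its layout; views/aliases fall back to the out-of-place fusion op.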
+ other.realize() + if not _can_be_inplace(other): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + 
_register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. 
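+        # Illustrative shapes for this rewrite (hypothetical, not traced from a
+        # real model): x [2, 16, 64] -reshape-> [32, 64] -linear-> [32, 128]
+        # -reshape-> [2, 16, 128]. Because the reshapes only flatten/unflatten
+        # the leading dims (2 * 16 == 32), the whole chain can be replaced by a
+        # single mkldnn._linear_pointwise call on the original 3-D input.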
+ @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=1, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + dynamic_shapes = not all( + isinstance(x, int) for x in ([reshape_1[0]] + reshape_2[:-1]) + ) + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + if dynamic_shapes: + # TODO: Haozhe investigate how add guard here + return + else: + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size(reshape_2[:-1]) + can_remove_reshape = can_remove_reshape and ( + reduce(operator.mul, reshape_2[:-1]) == reshape_1[0] + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + weight_meta = linear_node.args[1].meta.get("val") + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(0) + ) + + # convert linear+bias to a single linear for applying fusion path. 
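+        # Sketch of the match handled below (illustrative): aten.add.Tensor(
+        # mkldnn._linear_pointwise(x, w, None, ...), bias) with a 1-D bias of
+        # size w.size(0) is folded into mkldnn._linear_pointwise(x, w, bias, ...);
+        # is_linear_add_bias guards that the original linear had no bias.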
+ @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=1, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes(graph) + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not mkldnn._is_mkldnn_bf16_supported() + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not mkldnn._is_mkldnn_fp16_supported() + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 4 + ): + return False + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. + if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. 
+ """ + linear_node = match.output_node() + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + if linear_node.args[weight_idx].op != "get_attr": + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + batch_size = input_meta_value.shape[0] + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not mkldnn._is_mkldnn_bf16_supported(): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not mkldnn._is_mkldnn_fp16_supported(): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + conv_node = match.output_node() + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_weight_op = mkldnn._reorder_convolution_weight + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + if not has_free_symbols(input_size): + packed_weight_inputs = ( + (args[1],) + tuple(constant_args) + (input_size,) + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. 
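+                    # i.e. the original weight node is passed through unchanged and
+                    # no ahead-of-time reorder node is inserted into the graph.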
+ packed_weight_node = args[1] + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + input = args[0] + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + + @register_freezing_graph_pattern( + CallFunction(aten.addmm.default, Arg(), Arg(), Arg()), + extra_check=_is_packable_linear, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target == aten.mm.default else args[1] + bias = None if linear_node.target == aten.mm.default else args[0] + weight = args[1] if linear_node.target == aten.mm.default else args[2] + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert ( + is_lp_weight or mkldnn._is_mkldnn_acl_supported() + ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + packed_weight_op = ( + mkldnn._reorder_linear_weight + if (is_lp_weight or mkldnn._is_mkldnn_acl_supported()) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + packed_linear_inputs: Tuple[Any, ...] 
= (input, packed_weight_node) + if is_lp_weight or mkldnn._is_mkldnn_acl_supported(): + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + packed_linear_node = graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.lru_cache(None) + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. + # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and not torch.ops.mkldnn._is_mkldnn_acl_supported() + ): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + + @functools.lru_cache(None) + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e08cef8c47ed39a2cbac29b0bcc4aaaa479c37a0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,210 @@ +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim + +from .. import config + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. 
+def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(set(dict_base.keys())) != len(set(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base.keys(): + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. 
Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception as e: + logger.exception("Exception %s when compare gradients", e) + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. Numerical result : %s", + res, + ) + except Exception as e: + logger.exception( + "Exception %s when optimizer is added to check parameter names", e + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8d8b00a596ef6da021b14e4580bfcad192726e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,567 @@ +import functools +from typing import List, Optional, Set, Union + +import torch +from torch import Tensor +from torch._inductor import utils +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch +from torch.utils._triton import has_triton + +from ..pattern_matcher import ( + fwd_only, + joint_fwd_bwd, + Match, + MatchContext, + register_replacement, +) +from ..utils import is_view + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. 
+_skip_do_bench_times = False + + +def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args(*arg_names): + def decorator(func): + def wrapper(match): + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16: + return 8 + elif x.dtype == torch.float32 or x.dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def _result_layout_affects_graph_output(match: Match) -> bool: + """ + Check if the matched GEMM operation potentially affects the graph output strides. + returns True if the matched op's output buffer does not pass through functions which certainly + redefine the memory layout before being part of the graph output. + """ + + if match.ctx is not None: + assert isinstance(match.ctx, MatchContext) + search_node: torch.fx.Node = match.output_node() + else: + return True + + assert search_node is not None + seen: Set[torch.fx.Node] = set() + + def find_output(node: torch.fx.Node, is_start_node=False): + if not isinstance(node, torch.fx.Node): + return False + if node in seen: + return False + seen.add(node) + if node.op == "output": + return True + if node.op != "call_function": + return False + if not is_start_node and ( + (not isinstance(node.target, torch._ops.OpOverload)) + or (not is_view(node.target)) + ): + return False + if node.users is not None and len(node.users) > 0: + for n in node.users: + if find_output(n): + return True + return False + + return find_output(search_node, True) + + +def should_pad_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. 
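+    # Illustrative example: a size like (s0, 4096) where s0 carries a hint is
+    # accepted, whereas a fully symbolic size such as (s0, s1) is rejected by
+    # valid_shape_and_stride below.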
+ def valid_shape_and_stride(t: Optional[Tensor]) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimentions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def addmm_replace( + input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0 +) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + + if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0: + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def pad_addmm( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta=1.0, + alpha=1.0, +): + # addmm decomp with padding will go through pad_addmm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None and k_padded_length == 0: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + elif m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + if k_padded_length != 0: + return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha) + elif n_padded_length != 0: + return addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[ + :, :-n_padded_length + ] + else: + return 
addmm_replace(input, mat1, mat2, beta=beta, alpha=alpha)[ + :-m_padded_length, : + ] + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we extimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.lru_cache(None) +def get_pad_cache(): + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key): + return get_pad_cache().lookup(key) + + +def set_cached_should_pad(key, value): + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None +) -> str: + def tensor_key(t): + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32 + ) + key = ( + tensor_key(mat1), + tensor_key(mat2), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + return str(key) + + +def should_pad_bench( + mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None +) -> bool: + if not has_triton(): + return False + + do_bench = functools.partial( + utils.do_bench, + warmup=5, + ) + + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + elif op is torch.ops.aten.bmm: + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_symbols(ds): + return [d if isinstance(d, int) else d.node.hint for d in ds] + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = realize_symbols(t.size()) + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + ori_time = do_bench( + lambda: op(mat1, mat2), + ) + else: + if input is not 
None: + input = realize_tensor(input) + ori_time = do_bench( + lambda: op(input, mat1, mat2), + ) + + mat1_pad = torch.randn_like(mat1) + mat2_pad = torch.randn_like(mat2) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda: + input_pad = torch.randn_like(input) + pad_time = do_bench( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + elif op is torch.ops.aten.mm: + pad_time = do_bench( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + else: + pad_time = do_bench( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + ), + ) + + # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + should_pad = _skip_do_bench_times or ori_time > pad_time * 1.1 + set_cached_should_pad(key, should_pad) + + return should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + mat1, mat2, torch.ops.aten.mm + ) + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + + return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length) + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> Tensor: + # mm_replace will go through pad_mm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length] + else: + mat1 = pad_dim(mat1, m_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :] + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + if ( + torch._inductor.config.keep_output_stride + and _result_layout_affects_graph_output(match) + ): + return False + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + mat1, mat2, torch.ops.aten.bmm + ) + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + + if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0: + return pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length) + + return aten.bmm(mat1, mat2) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + 
m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> Tensor: + # bmm_replace will go through pad_bmm multiple times if multiple dimensions are needed to be padded + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 2) + mat2 = pad_dim(mat2, k_padded_length, 1) + + return aten.bmm(mat1, mat2) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 2) + return aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous() + else: + mat1 = pad_dim(mat1, m_padded_length, 1) + return aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous() + + +@functools.lru_cache(None) +def _pad_mm_init(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + mm_pattern, + mm_replace, + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + bmm_pattern, + bmm_replace, + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + addmm_pattern, + addmm_replace, + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + register_replacement( + pattern, + replacement, + args, + joint_fwd_bwd, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + register_replacement( + pattern, + replacement, + args, + fwd_only, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..a10893ae1a574896c3e977710fb451dab2ea2b22 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1100 @@ +import copy +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from typing import Any, Dict, List, Optional, Set, Union + +from sympy import Expr + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters, optimus_scuba_log + +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype + +from torch._utils_internal import upload_graph +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + +from .. 
import config, ir, pattern_matcher +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage + +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from ..utils import decode_device, is_pointwise_use +from ..virtualized import V +from .group_batch_fusion import group_batch_fusion_passes +from .reinplace import reinplace_inplaceable_ops + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] +# patterns applied only in inference +inference_patterns = PatternMatcherPass() +decompose_mm_pass = PatternMatcherPass() + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. + """ + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + reorder_for_locality(gm.graph) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if config.post_grad_custom_pre_pass is not None: + config.post_grad_custom_pre_pass(gm.graph) + + if config.pattern_matcher: + lazy_init() + inductor_before_change = copy.deepcopy(counters["inductor"]) + group_batch_fusion_passes(gm.graph, pre_grad=False) + if counters["inductor"] != inductor_before_change: + optimus_scuba_log["group_batch_fusion_post_grad"] = upload_graph(gm.graph) + remove_noop_ops(gm.graph) + for patterns in pass_patterns: + patterns.apply(gm.graph) # type: ignore[arg-type] + if is_inference: + inference_patterns.apply(gm.graph) # type: ignore[arg-type] + decompose_mm_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.post_grad_custom_post_pass is not None: + config.post_grad_custom_post_pass(gm.graph) + + stable_topological_sort(gm.graph) + + move_constructors_to_cuda(gm.graph) + + fake_tensor_updater.incremental_update() + + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + reinplace_inplaceable_ops(gm.graph) + decompose_auto_functionalized(gm.graph) + + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + +def reorder_for_locality(graph: torch.fx.Graph): + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = set() + + # only reorder nodes before the first copy_ in the graph. 
+ # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesnt work well with mutation + first_copy = next( + ( + node + for node in graph.nodes + if node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + ), + None, + ) + past_mutating_epilogue = True if first_copy is None else False + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1): + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, extra_check, pass_dict=pass_patterns[pass_number] + ) + + +################################################################################ +# Actual patterns below this point. +# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape + *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape + if k1 != k2: + return False + + *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape + *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +def cuda_and_enabled_mixed_mm(match): + return (config.use_mixed_mm or config.force_mixed_mm) and getattr( + match.kwargs["mat1"].meta.get("val"), "is_cuda", False + ) + + +def cuda_and_enabled_mixed_mm_and_not_int8(match): + return ( + cuda_and_enabled_mixed_mm(match) + and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False) + and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8) + != torch.int8 + ) # bitshift numerics in triton and pytorch don't match for torch.int8 + + +""" + this is intended to be used to unpack a [K,N] int4 tensor from a [K/2, N] uint4x2 tensor + (where the int4 and uint4x2 are represented with int8 and uint8 respectively) + where every other row of the int4 is packed with the row above it as: + uint4x2[k,n] = (8+int4[2*k,n])+(8+int4[2*k+1,n])<<4 + + unpack formulas: + int4[2*k,n]=(uint4x2[k,n] & 0xF) - 8 + int4[2*k+1,n]=(uint4x2[k,n] >> 4) - 8 + + thus matching on unpack formula: + torch.mm(mat1, torch.cat((mat2 & 0xF, mat2>>4),1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8)) + + note: although the unpack formula in pytorch and the triton kernel is designed for a uint8 mat2, the behavior + of the kernel matches the pytorch formula for all dtypes except torch.int8 + where the bitwise numerics in triton do not match those in pytorch. 
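+
+    worked example of the formulas above: packing the int4 pair (-3, 5) gives
+    uint4x2 = (8 + (-3)) + ((8 + 5) << 4) = 5 + 208 = 213; unpacking yields
+    (213 & 0xF) - 8 = -3 and (213 >> 4) - 8 = 5, recovering the original pair.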
+""" + + +@register_lowering_pattern( + CallFunction( + aten.mm.default, + KeywordArg("mat1"), + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + aten.bitwise_and.Scalar, + KeywordArg("mat2"), + 0xF, + ), + CallFunction( + aten.__rshift__.Scalar, + KeywordArg("mat2"), + 4, + ), + ), + 1, + ), + KeywordArg("mat2_mm_shape"), + ), + KeywordArg("mat2_dtype"), + ), + 8, + ), + ), + extra_check=cuda_and_enabled_mixed_mm_and_not_int8, +) +def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype): + return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm( + mat1, mat2, mat2_mm_shape, mat2_dtype + ) + + +""" + torch.mm(mat1, mat2.to(mat2_dtype)) +""" + + +@register_lowering_pattern( + CallFunction( + aten.mm, + KeywordArg("mat1"), + CallFunction( + prims.convert_element_type.default, + KeywordArg("mat2"), + KeywordArg("mat2_dtype"), + ), + ), + extra_check=cuda_and_enabled_mixed_mm, +) +def mixed_mm(match: Match, mat1, mat2, mat2_dtype): + return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + with V.fake_mode: + match.replace_by_example(repl, list(shape)) + + +def shape_of_mm(a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + +@register_lowering_pattern( + CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()), +) +def cat_mm(match, inputs, dim): + return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm) + + +@register_lowering_pattern( + CallFunction( + aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg() + ), +) +def cat_addmm(match, inputs, dim): + def shape_of(bias, a, b): + m, _ = a.get_size() + _, n = b.get_size() + return [m, n] + + return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of) + + +def cat_tuned_op(match, inputs, dim, *, op, shape_of): + """ + Memory planning to remove cat. We can't use the stock memory + planner since autotuning matmuls needs to know the output layout. + """ + if len(inputs) == 1: + return op(*inputs[0]) + + # TODO(jansel): rewrite this as a bmm? 
+ if dim < 0: + dim += len(shape_of(*inputs[0])) + assert dim in (0, 1) + notdim = 1 - dim + + new_size: Optional[Union[List[Expr], List[int]]] = None + offsets_start = [] + offsets_end = [] + + # compute output sizes + for i in range(len(inputs)): + shape = shape_of(*inputs[i]) + if new_size is None: + new_size = shape + else: + new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload] + shape[notdim], new_size[notdim] + ) + new_size[dim] += shape[dim] + offsets_start.append(new_size[dim] - shape[dim]) + offsets_end.append(new_size[dim]) + + assert new_size is not None + dtype = functools.reduce( + torch.promote_types, + [x.get_dtype() for x in itertools.chain.from_iterable(inputs)], + ) + device = inputs[0][0].get_device() + kernel = ir.ConcatKernel( + name=None, + layout=ir.FixedLayout(device, dtype, new_size), + inputs=[], + ) + kernel_tensor = ir.TensorBox.create(kernel) + + for i in range(len(inputs)): + dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i]) + src = op(*inputs[i], layout=dst.get_layout()).data.data + assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer)) + src.layout = ir.AliasedLayout(dst) + kernel.inputs.append(src) + + kernel.name = V.graph.register_buffer(kernel) + kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs) + return kernel_tensor + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. 
+ if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = { + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + } + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != set(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: Dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + if start == 0 and end >= 2**63 - 1 and step == 1: + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + if start == 0 and end >= 2**63 - 1 and step == 1: + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + +@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) 
+def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view) +def view_noop(arg, size): + return arg.shape == size + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = set() + input_storages = set() + output_storages = set() + + for node in graph.nodes: + if node.op == "placeholder": + inputs.add(node) + input_storages.add(get_node_storage(node)) + else: + break + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + for out in output_node.args[0]: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def decompose_auto_functionalized(graph): + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + pass_dict=graph_pass, + ) + def replacement(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. 
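+        # For example (hypothetical names), pytree.tree_flatten(((x,), {"out": y}))
+        # returns ([x, y], spec), and pytree.tree_unflatten([x, y], spec) rebuilds
+        # ((x,), {"out": y}) inside decomp below.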
+ def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs) + + with V.fake_mode: + match.replace_by_example(decomp, flat_args, run_dce=False) + + graph_pass.apply(graph) + for node in graph.nodes: + if node.target is torch.ops.higher_order.auto_functionalized: + raise AssertionError("auto_functionalized was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. 
+ """ + for nd in gm.graph.nodes: + if nd.target == torch.ops.aten.view.default: + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not inp.meta["val"].is_cuda: + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): + def repl(inp, x1, x2): + return x1 @ x2 + inp + + with V.fake_mode: + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + with V.fake_mode: + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +@register_lowering_pattern( + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +@register_lowering_pattern( + CallFunction( + aten.mul, + CallFunction( + aten._int_mm, + Arg(), + Arg(), + ), + Arg(), + ), + check_shape_cuda_and_fused_int_mm_mul_enabled, +) +def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None): + return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype) + + +class ConstructorMoverPass: + def __init__(self, target: str, allow_outputs: bool = False) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + """ + + self.target = target + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. 
+ """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + + return False + + def get_node_device(self, node: fx.Node) -> Optional[torch.device]: + """ + Get the device of a node. + """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: Dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = set() + constructors = [] + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if not node.kwargs.get("device") == torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + for node in movable_constructors: + kwargs = node.kwargs.copy() + kwargs["device"] = next(iter(target_devices)) + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: List[fx.Node] + ) -> Set[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to cuda + cannot_move_to_cuda: Set[fx.Node] = set() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to cuda, the other cannot as well. 
+ # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = { + c: {c} for c in constructors + } + + def make_dependencies_equivalent( + set1: Set[fx.Node], set2: Set[fx.Node] + ) -> Set[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: List[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_cuda.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a cuda + # tensor. we can convert its cpu input to cuda without making further changes + node_device = self.get_node_device(user) + if ( + self.allow_cpu_device(user) + and node_device + and node_device.type == self.target + ): + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_cuda.update(constructor_dependencies[node]) + + all_cannot_move_to_cuda = cannot_move_to_cuda.copy() + for constructor in cannot_move_to_cuda: + all_cannot_move_to_cuda.update(equal_constructor_sets[constructor]) + + return set(constructors) - all_cannot_move_to_cuda + + +def move_constructors_to_cuda(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to cuda when safe + """ + ConstructorMoverPass("cuda")(graph) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..bed3b3229003399fdcfd979bd7e0ce79ef3786fe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py @@ -0,0 +1,611 @@ +import copy +import logging +from typing import List, Optional + +import torch +import torch.nn as nn +from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log +from torch._utils_internal import upload_graph +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights + +from .. 
import config + +from ..fx_utils import matches_module_function_pattern +from ..pattern_matcher import ( + init_once_fakemode, + PatternMatcherPass, + stable_topological_sort, +) +from ..utils import is_cpu_device, pass_execution_and_save +from .group_batch_fusion import group_batch_fusion_passes +from .misc_patterns import numpy_compat_normalization + +log = logging.getLogger(__name__) + +normalization_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="normalization_pass" +) +merge_splits_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="merge_splits_pass" +) +split_cat_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="split_cat_pass" +) +unbind_stack_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="unbind_stack_pass" +) +efficient_conv_bn_eval_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="efficient_conv_bn_eval_pass" +) +merge_getitem_cat_pass = PatternMatcherPass( + prevent_match_across_mutations=True, pass_name="merge_getitem_cat_pass" +) + +fuse_split_linear_add_pass = PatternMatcherPass( + prevent_match_across_mutations=True, + pass_name="fuse_split_linear_add_pass", +) +fuse_chunk_squeeze_cat_pass = PatternMatcherPass( + prevent_match_across_mutations=True, + pass_name="fuse_chunk_squeeze_cat_pass", +) +remove_reshape_pass = PatternMatcherPass( + prevent_match_across_mutations=True, + pass_name="remove_reshape_pass", +) + +# based on predispatch aten IR +normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True) +merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True) +split_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True) +unbind_stack_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True) +merge_getitem_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True) + + +def fuse_parallel_linear_pass(graph): + return None + + +def remove_split_ops(graph, shape_prop): + return None + + +pattern_matcher_passes: List[PatternMatcherPass] = [ + normalization_pass, + merge_getitem_cat_pass, + merge_splits_pass, + split_cat_pass, + unbind_stack_pass, + efficient_conv_bn_eval_pass, +] +pattern_matcher_passes_aten: List[PatternMatcherPass] = [ + merge_getitem_cat_pass_aten, + merge_splits_pass_aten, + split_cat_pass_aten, + unbind_stack_pass_aten, +] + + +@init_once_fakemode +def lazy_init(): + from . import efficient_conv_bn_eval, split_cat # noqa: F401 # noqa: F401 + + if config.is_fbcode(): + from . import fb # type: ignore[attr-defined] # noqa: F401 + + +def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None): + """ + Apply passes on the input FX graph using Torch IR. + + WARNING: + The IR before grad is not functional or normalized, so it is harder + to write passes on this IR. Passes must be safe with respect to + aliasing and mutation and need to handle all possible arg schemas. + + Consider adding a new pass to post_grad.py or joint_graph.py which + are after functionalization and normalization. 
+ """ + if config.pattern_matcher: + lazy_init() + if hasattr( + config, "fx_passes_numeric_check" + ) and config.fx_passes_numeric_check.get("pre_grad", False): + gm_before_fx_passes = gm.__copy__() + # explicitly run with predispatch atenIR based passes + if config.is_predispatch: + + def shape_prop(mod) -> None: + ShapeProp( + gm=mod, + fake_mode=detect_fake_mode(example_inputs), + ).propagate(*example_inputs) + + # normalization pass + pass_execution_and_save( + normalization_pass_aten.apply, + gm, + "[Pre grad(predispatch IR)]Apply normalization pass", + ) + pass_execution_and_save( + group_batch_fusion_passes, + gm, + "[Pre grad(predispatch IR)] Apply group_batch_fusion", + ) + pass_execution_and_save( + fuse_chunk_squeeze_cat_pass.apply, + gm, + "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass", + ) + pass_execution_and_save( + fuse_split_linear_add_pass.apply, + gm, + "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass", + ) + + log.debug( + "[Pre grad(predispatch IR)]Before split cat in pre grad pass. graph: %s", + gm.graph, + ) + for ind, pattern_matcher_pass_aten in enumerate( + pattern_matcher_passes_aten + ): + pass_execution_and_save( + pattern_matcher_pass_aten.apply, + gm, + f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}", + ) + pass_execution_and_save( + remove_reshape_pass.apply, + gm, + "[Pre grad(predispatch IR)] Apply remove_reshape_pass", + ) + pass_execution_and_save( + fuse_parallel_linear_pass, + gm, + "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass", + ) + pass_execution_and_save( + lambda graph: remove_split_ops(graph.owning_module, shape_prop), + gm, + "[Pre grad(predispatch IR)] Apply remove_split_ops", + ) + shape_prop(gm) + + else: + # We only log the graph with changes to avoid the excessive compilation time + # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/ + if example_inputs is not None: + gm = fuse_fx(gm, example_inputs) + numpy_compat_normalization(gm.graph) + inductor_before_change = copy.deepcopy(counters["inductor"]) + group_batch_fusion_passes(gm.graph, pre_grad=True) + if counters["inductor"] != inductor_before_change: + optimus_scuba_log["group_batch_fusion_pre_grad"] = upload_graph( + gm.graph + ) + for pattern_matcher_pass in pattern_matcher_passes: + inductor_before_change = copy.deepcopy(counters["inductor"]) + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if counters["inductor"] != inductor_before_change: + optimus_scuba_log[ + f"split_cat_pattern_{pattern_matcher_pass.pass_name}_pre_grad" + ] = upload_graph(gm.graph) + + if config.pre_grad_custom_pass is not None: + config.pre_grad_custom_pass(gm.graph) + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + if ( + config.pattern_matcher + and hasattr(config, "fx_passes_numeric_check") + and config.fx_passes_numeric_check.get("pre_grad", False) + and example_inputs is not None + ): + from .numeric_utils import numeric_check_if_enabled + + gm_after_fx_passes = gm.__copy__() + numeric_check_if_enabled( + gm_before_fx_passes, # type: ignore[possibly-undefined] + gm_after_fx_passes, + example_inputs, + config.fx_passes_numeric_check.get("num_iterations", 1), + config.fx_passes_numeric_check.get("precision", 1e-4), + ) + + return gm + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: + is_cpu = is_cpu_device(example_inputs) + + fake_mode = detect_fake_mode(example_inputs) + + gm = sink_cat_after_pointwise(gm) + if config.permute_fusion and not is_cpu: + # 
For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + gm = linear_permute_fusion(gm) + gm = permute_linear_fusion(gm) + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled() or not is_cpu: + return gm + if config.freezing: + gm = remove_identity(gm) + gm = fuse_conv_bn(gm) + return gm + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: + """ + Fuses Convolution/BN layers for inference purposes. + """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + for pattern in modules_patterns: + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + fused_conv = fuse_conv_bn_eval(conv, bn) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + gm.graph.lint() + for pattern in module_function_patterns: + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. 
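+ # F.batch_norm(input, running_mean, running_var, weight, bias, training,
+ # momentum, eps) takes 8 positional arguments; the code below reads
+ # training from args[5] and eps from args[7], so calls using kwargs are skipped.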
+ if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + bn_args_is_constant = all( + n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn_running_mean, + bn_running_var, + bn_eps, + bn_weight, + bn_bias, + ) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + gm.graph.lint() + gm.recompile() + + return gm + + +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["weight"] # type: ignore[return-value] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] # type: ignore[return-value] + else: + return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["other"] # type: ignore[return-value] + + +def check_permute(node: torch.fx.Node) -> bool: + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type] + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[union-attr] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def one_user(node): + users = list(node.users) + return users[0] if len(users) == 1 else None + + def is_view(node): + view = {"view"} + return node.op == "call_method" and node.target in view + + def is_pointwise_unary(node): + pointwise = {torch.relu, torch.tanh, "relu", "tanh"} + return node.op in {"call_function", "call_method"} and 
node.target in pointwise + + g = module.graph + for node in g.nodes: + if node.op != "call_function" or node.target != torch.cat: + continue + + cat_or_view = node + while True: + user = one_user(cat_or_view) + if not user or not is_view(user): + break + cat_or_view = user + + if user and is_pointwise_unary(user): + with g.inserting_before(node): + + def cat_args(tensors, dim=0): + return tensors, dim + + tensors, dim = cat_args(*node.args, **node.kwargs) + new_tensors = [ + g.create_node(user.op, user.target, args=(arg,), kwargs=user.kwargs) + for arg in tensors + ] + new_cat = g.create_node( + "call_function", torch.cat, args=(new_tensors, dim) + ) + user.replace_all_uses_with(cat_or_view) + node.replace_all_uses_with(new_cat) + g.erase_node(user) + g.erase_node(node) + g.lint() + module.recompile() + return module + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if ( + node.op == "call_method" + and node.target == "permute" + and check_permute(node) + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(weight, input.transpose(-1, -2)) + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and node.target == torch.nn.functional.linear: + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target == torch.bmm or node.target == torch.matmul + ): + normalized = NormalizedMatmulNode(node) + input_A_node = normalized.get_input() + input_B_node = normalized.get_other() + input_A = input_A_node + input_B = input_B_node + Atrans = Btrans = False + if ( + input_A_node.op == "call_method" + and 
input_A_node.target == "permute" + and check_permute(input_A_node) + ): + Atrans = True + if len(input_A_node.args) > 0: + input_A = input_A_node.args[0] # type: ignore[assignment] + else: + input_A = input_A_node.kwargs["input"] # type: ignore[assignment] + + if ( + input_B_node.op == "call_method" + and input_B_node.target == "permute" + and check_permute(input_B_node) + ): + Btrans = True + if len(input_B_node.args) > 0: + input_B = input_B_node.args[0] # type: ignore[assignment] + else: + input_B = input_B_node.kwargs["input"] # type: ignore[assignment] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(input_A, input_B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if Atrans and len(input_A_node.users) == 0: + module.graph.erase_node(input_A_node) + if Btrans and len(input_B_node.users) == 0: + module.graph.erase_node(input_B_node) + + module.graph.lint() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(input.transpose(-1, -2), weight.t()) + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul( + A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool +) -> torch.Tensor: + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/quantization.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..4af6622a719245a019f31cb95e14035bc619c68d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,1980 @@ +import copy +import functools +import itertools +import math +import operator +from typing import Any, Tuple + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. 
+""" + + +def _may_generate_pattern_with_dtype_convert(pattern, dtype=Arg(), dtype_convert=True): + if dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +""" +dequantize activation: + x = x.to(fp32) + x = x - zero_point + x = x * scale +""" +dequantize_per_tensor_activation_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("x"), + KeywordArg("x_dq_dtype"), + ), + KeywordArg("x_zp"), + ), + KeywordArg("x_scale"), +) + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_dequantize_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), # x_scale + KeywordArg("x_zp"), # x_zp + KeywordArg("packed_weight"), # packed_weight + KeywordArg("w_scale"), # w_scale + KeywordArg("w_zp"), # w_zp + KeywordArg("b"), # bias + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("inv_output_scale"), # inv_output_scale = 1.0 + KeywordArg("output_zero_point"), # output_zero_point = 0 + KeywordArg("output_dtype"), # output_dtype = None + KeywordArg("attr"), # attr = "none" + Arg(), # scalars + Arg(), # algorithm + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + ) + + +dequantize_accum_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("accum"), + KeywordArg("accum_dq_dtype"), + ), 
+ KeywordArg("accum_zp"), + ), + KeywordArg("accum_scale"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + int8_mixed_bf16_with_inplace_add=False, +): + binary_pattern = CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + int8_mixed_bf16_with_inplace_add, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + if unary_post_op == aten.hardtanh.default: + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + if unary_post_op == aten.hardswish.default: + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, + CallFunction(aten.add, computation_call, 3), + 0, + ), + 6, + ), + ), + 6, + ) + else: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, dtype=torch.float32): + """ + quantize output: + output = round(output * o_inv_scale) + output = output + zero_point + output = clamp_min(output, 0) + output = clamp_max(output, 127) + output = output.to(uint8) + """ + assert dtype in [torch.float32, torch.bfloat16] + quantized_op_output_pattern_pt2e = CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.clamp_max.default, + CallFunction( + aten.clamp_min.default, + CallFunction( + aten.add.Tensor, + CallFunction( + aten.round.default, + CallFunction( + aten.mul.Tensor, + _may_generate_pattern_with_dtype_convert( + computation_call, + KeywordArg("autocast_output_quant_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("o_inv_scale"), + ), + ), + KeywordArg("o_zp"), + ), + KeywordArg("o_qmin"), + ), + KeywordArg("o_qmax"), + ), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value == expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv2d_optimization_pattern(output_dtype): + def fn(match): + if output_dtype is not None: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv2d_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + unary_attr, + original_pattern_output_dtype=torch.float32, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + 
kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + assert output_dtype in [None, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32-out qconv in weight prepack phase + assert ( + kwargs["attr"] == "none" + ) # Expected no post op fused in weight prepack phase + if unary_attr.op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + unary_attr.scalars_attr = [min_value, max_value] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_unary_matcher_count"] += 1 + counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(output_dtype): + def fn(match): + if output_dtype is not None: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _register_quantized_linear_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + unary_attr, + original_pattern_output_dtype=torch.float32, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ) + counters["inductor"]["qlinear_unary_matcher_count"] += 1 + counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype): + # Check if it's a valid Conv Binary Pattern: + # * qconv2d_pointwise should only has one users + # * Extra input of binary node comes from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute 
node. + def fn(match): + compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0] + # qconv2d_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype is not None: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( + extra_input_of_binary_node.target != aten.mul.Tensor + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["accum"] + if output_dtype is None + else match.kwargs["accum_after_dequant"] + ) + if ( + len( + _get_remaining_users( + extra_input_of_pattern, + compute_node, + ) + ) + > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, + output_dtype, + binary_unary_attr, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = ( + kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"] + ) + accum_scale = kwargs["accum_scale"] if output_dtype is None else 1.0 + accum_zp = kwargs["accum_zp"] if output_dtype is None else 0 + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 + o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace( + accum + ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." 
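+ # The "sum" binary post op writes the convolution result into accum's
+ # buffer, which is why accum is realized above and must not be an alias
+ # of (or a mutation on) another tensor.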
+ + computation_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_unary_attr.binary_op_name, + binary_unary_attr.alpha, + binary_unary_attr.unary_op_name, + binary_unary_attr.scalars_attr, + binary_unary_attr.algorithm_attr, + ) + counters["inductor"]["qconv2d_binary_matcher_count"] += 1 + counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_fusion(): + class UnaryAttr: + def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # QConv2d + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + conv_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + get_dequantize_qconv_pt2e_pattern(1), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default + ), + dtype=original_pattern_output_dtype, + ), + UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default + ), + dtype=original_pattern_output_dtype, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + None, # output_dtype, None is the default value for int8 output + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.relu.default + ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(1), aten.hardtanh.default + ), + UnaryAttr("hardswish", [], ""): generate_pattern_with_unary( + get_dequantize_qconv_pt2e_pattern(2), aten.hardswish.default + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_quantized_conv_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qconv2d_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + # Priority 1 to match: QLinear Unary pattern with int8 output + 
linear_unary_replace_patterns = { + UnaryAttr("none", [], ""): generate_pattern_with_output_quant( + qlinear_pattern, + dtype=original_pattern_output_dtype, + ), + UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + dtype=original_pattern_output_dtype, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + None, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + UnaryAttr("relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_quantized_linear_lowering( + patterns, + 2, # pass_number + torch.ops.onednn.qlinear_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype + unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, + ) + + +def _register_quantization_binary_fusion(): + class BinaryUnaryAttr: + def __init__( + self, + binary_op_name: str, + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ): + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + binary_replace_patterns = { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + dtype=torch.bfloat16 + if int8_mixed_bf16_with_inplace_add + else torch.float32, + ), + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + dtype=torch.bfloat16 + if int8_mixed_bf16_with_inplace_add + else torch.float32, + ), + } + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + None, # output_dtype + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + aten.relu.default, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_quantized_conv_binary_lowering( + patterns, + 0, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Note that for 
int8-mixed-bf16 and non-inplace add, because we have + # q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs, + # the output dtype will be float32. + # For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output. + torch.bfloat16, + binary_unary_attr, # binary_unary_attr + ) + else: + _register_quantized_conv_binary_lowering( + patterns, + 1, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + torch.float32, + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = { + BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + ), + } + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_quantized_conv_binary_lowering( + patterns, + 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Same output dtype setting as conv-add-relu pattern + torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32, + binary_unary_attr, # binary_unary_attr + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs["stride"] if ("stride" in kwargs) else None + padding = kwargs["padding"] if ("padding" in kwargs) else 0 + dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 + ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. 
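+ # For example, nn.MaxPool2d(3, stride=2) exports with only kernel_size and
+ # stride, while nn.MaxPool2d(3, stride=2, padding=1) also passes padding,
+ # so each argument combination gets its own pattern below.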
+ # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + dequantize_per_tensor_activation_pattern, + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor) + zero_points = [node.args[1] for node in sub_nodes] + add_nodes = filter_nodes(match.nodes, aten.add.Tensor) + assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(add_nodes[0].args[1]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor) + # We need to find mul node at output since the scale value is reciprocal to input scale. + # Mul node at output should connect to cat node directly. + scales = [ + ( + mul_node.args[1] + if mul_node.args[0].target is check_node # type: ignore[union-attr] + else 1.0 / mul_node.args[1] # type: ignore[operator] + ) + for mul_node in mul_nodes + ] + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 
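+ # Each entry is the captured argument list of one dequant sub-pattern, so
+ # input[0] below recovers the original uint8 tensor and the per-input
+ # scale/zero_point args are dropped in favor of a single quantized cat.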
+ uint8_inputs = [input[0] for input in inputs] + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + Arg(), + Arg(), + ), + Arg(), + ), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + dequantize_per_tensor_activation_pattern, + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _register_quantization_lowerings(): + _register_quantization_unary_fusion() + _register_quantization_binary_fusion() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + aten.mul.Tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + mul_node = ( + dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- mul + ) + else: + mul_node = ( + dequant_pattern_end_node # pattern: linear <- mul + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- mul + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + if ( + mul_node.target is aten.mul.Tensor + and sub_node.target is aten.sub.Tensor + and to_fp32_node.target is prims.convert_element_type.default + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an 
int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. + assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert ( + source_node.op == "call_function" + ), "clone_to_new_node only support node.op call_function" + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dtype convert to float32 + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + aten.mul.Tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * aten.mul + # * aten.sub + # * prims.convert_element_type.default (to_fp32) + def _find_first_node_in_dequant_pattern(_node): + if ( + _node.target is prims.convert_element_type.default + and _node.args[1] == torch.float32 + ): + # For a dequant pattern, we expect the start node is a to_fp32 node + return _node + else: + assert ( + len(_node.args) >= 1 + ), "In in dequant pattern, each node should have more than 1 arg." + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv2d_pattern(dtype): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. 
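+        # The activation dequant chain inspected below is
+        # to_fp32 (convert_element_type) -> sub (zero point) -> mul (scale),
+        # with an extra to_bf16 node between mul and the convolution when
+        # dtype is torch.bfloat16.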
+ conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 4 + ): + # Only support conv2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + if dtype == torch.float32: + mul_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + mul_node = convert_to_bf16.args[0] + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv2d_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + mul_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + mul_node = convert_to_bf16.args[0] # type: ignore[union-attr] + sub_node = mul_node.args[0] # type: ignore[union-attr] + to_fp32_node = sub_node.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. 
+ x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # inv_output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv2d_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined] + # Erase the dequant pattern + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 
+ # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> mul + mul_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> mul + activation_to_bf16_node = act_reshape_node.args[0] + mul_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + mul_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + mul_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> mul + mul_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> mul + activation_to_bf16_node = linear_node.args[input_index] + mul_node = activation_to_bf16_node.args[0] + return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. 
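+        # The dequant chain walked below mirrors the conv case:
+        # to_fp32 (convert_element_type) -> sub -> mul, optionally followed by
+        # to_bf16 and a reshape/expand node depending on dtype and input shape
+        # (see _get_linear_dq_mul_node above).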
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + + ( + mul_node, + _, + _, + _, + ) = _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + mul_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_mul_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = 
wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: Tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern 
= _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two + ) + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ) + for dtype, input_dim_exceeds_two in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. 
+ # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + dequantize_per_tensor_activation_pattern, + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, input_dim_exceeds_two + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 
+ _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/ + # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias in itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +@functools.lru_cache(None) +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..ab939087a72c8c0a33e402ba8a118a04402b4783 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py @@ -0,0 +1,537 @@ +import operator +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_functional +from torch._inductor import inductor_prims +from torch._inductor.fx_utils import get_node_storage, is_node_realized +from torch._inductor.lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, +) +from torch._inductor.virtualized import V +from torch.fx.immutable_collections import immutable_dict +from torch.fx.passes.reinplace import _is_view_op +from torch.utils import _pytree as pytree + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class InplaceableOp: + inplace_op: Callable[..., Any] + mutated_arg: int + extra_check: Callable[[torch.fx.Node], bool] = lambda node: True + + +_SCATTER_OP_TO_VIEW = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} +_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()} + + +def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (args, kwargs), + ) + with V.fake_mode: + fake_result = fn(*fake_args, **fake_kwargs) + + node = graph.call_function(fn, args, kwargs) + node.meta["val"] = fake_result + return node + + +@dataclass +class ViewOp: + target: torch._ops.OpOverload + args: Tuple[Any, ...] 
+ kwargs: Dict[str, Any] + + +def _inplace_generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp] +) -> torch.Tensor: + tmp = inp + for view in view_ops: + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (view.args, view.kwargs), + ) + tmp = view.target(tmp, *fake_args, **fake_kwargs) + tmp.copy_(src) + return inp + + +def _generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp] +) -> torch.Tensor: + out = inp.clone() + return _inplace_generalized_scatter(out, src, view_ops) + + +def _decompose_scatter_functional_helper( + graph: torch.fx.Graph, + inp: torch.Tensor, + src: torch.Tensor, + view_ops: List[ViewOp], +) -> torch.fx.Node: + view_op, view_ops_tail = view_ops[0], view_ops[1:] + + if view_ops_tail: + view = graph_call_function( + graph, view_op.target, inp, *view_op.args, **view_op.kwargs + ) + src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment] + + return graph_call_function( + graph, + _VIEW_OP_TO_SCATTER[view_op.target], + inp, + src, + *view_op.args, + **view_op.kwargs, + ) + + +def _decompose_scatter_functional( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter to a sequence of view_scatter operations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + view = aten.slice(inp, 0, 0, 10) + view_updated = aten.slice_scatter(view, src, 1, 10, -10) + inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10) + """ + assert node.target is _generalized_scatter + inp, src, view_ops = node.args + return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type] + + +def _decompose_scatter_mutating( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter using mutations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + inp_updated = aten.clone(inp) + slice1 = aten.slice(inp_updated, 0, 0, 10) + slice2 = aten.slice(slice1, 1, 10, -10) + slice2.copy_(src) + + """ + assert node.target in (_generalized_scatter, _inplace_generalized_scatter) + inp, src, view_ops = node.args + assert not node.kwargs + + if node.target is _generalized_scatter: + inp = graph_call_function(graph, aten.clone, inp) + + tmp = inp + for view in view_ops: # type: ignore[union-attr] + tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + + graph_call_function(graph, aten.copy_.default, tmp, src) + return inp # type: ignore[return-value] + + +# View ops whose view_scatter op is lowered into mutations anyway, +# so is never a pessimisation to decompose. +_ALWAYS_MUTATING_SCATTER_OPS = { + aten.as_strided.default, + aten.diagonal.default, +} + + +def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: + _, _, view_ops = node.args + return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr] + + +def should_reinplace_scatter(node: torch.fx.Node) -> bool: + """Choose between mutating and functional scatter decompositions + + Reinplacing view scatter ops can be pessimising as it blocks fusion with the + input or output tensor computations. However, it is still profitable if the + input and output would have been realized anyway. 
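+    For example, if the scatter input is a graph input that is copied back into
+    (mutated) in the epilogue, both input and output are realized regardless, so
+    the mutating decomposition is preferred.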
+ + """ + inp, src, view_ops = node.args + + # Mutating scatter ops unconditionally realize input and output + if scatter_always_uses_mutation(node): + return True + + if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type] + return True + + # If the output is copied back into the input, this forces both to be + # realized as the output is a user of the input + if inp.op == "placeholder" and any( # type: ignore[union-attr] + user.target is aten.copy_.default and user.args[0] is inp for user in node.users + ): + return True + + # Otherwise, assume fusions will make functional variants profitable + return False + + +def decompose_generalized_scatter(graph: torch.fx.Graph) -> None: + """Replace _generalized_scatter with normal aten ops""" + for node in graph.nodes: + if node.target not in (_generalized_scatter, _inplace_generalized_scatter): + continue + + use_mutation = ( + node.target is _inplace_generalized_scatter + or scatter_always_uses_mutation(node) + ) + + with graph.inserting_before(node): + if use_mutation: + new_node = _decompose_scatter_mutating(graph, node) + else: + new_node = _decompose_scatter_functional(graph, node) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None: + """ + This canonicalizes view scatter ops into a generalized form, defined as: + def scatter(inp, src, views): + tmp = inp.clone() + for view in views: + tmp = view(tmp) + tmp.copy_(src) + + We also fuse consecutive view scatter ops of the form + a = scatter(view2(self), src, [view1]) + b = scatter(self, a, [view2]) + which can be rewritten as + b = scatter(self, src, [view2, view1]) + a = view2(b) + + This is both more efficient as we only do a single scatter, and also + easier to reinplace since there is only one use of `self` + """ + + node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {} + node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list) + + def handle_views(node: torch.fx.Node): + inp = node.args[0] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + node_to_view_op[node] = [ + *node_to_view_op[inp], # type: ignore[index] + ViewOp( + node.target, # type: ignore[arg-type] + args=node.args[1:], + kwargs=node.kwargs, + ), + ] + + def handle_view_scatter(node: torch.fx.Node): + assert len(node.args) >= 2 + inp, src = node.args[:2] + + scatter_view_op = ViewOp( + _SCATTER_OP_TO_VIEW[node.target], + args=node.args[2:], + kwargs=node.kwargs, + ) + + def can_fuse(): + if src.target is not _generalized_scatter: # type: ignore[union-attr] + return False + src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + + inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type] + return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index] + *node_to_view_op[inp], # type: ignore[index] + scatter_view_op, + ] + + if not can_fuse(): + with graph.inserting_before(node): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src, + [scatter_view_op], + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + return + + src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + with graph.inserting_before(src): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src_src, + [scatter_view_op, *src_scatter_view_op], # type: ignore[misc] + ) + 
node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if src.users: # type: ignore[union-attr] + new_src = graph_call_function( + graph, + _SCATTER_OP_TO_VIEW[node.target], + new_node, + *node.args[2:], + **node.kwargs, + ) + + handle_views(new_src) + src.replace_all_uses_with(new_src) # type: ignore[union-attr] + + graph.erase_node(src) + + for node in graph.nodes: + if _is_view_op(node.target): + handle_views(node) + elif node.target in _SCATTER_OP_TO_VIEW: + handle_view_scatter(node) + + +inplaceable_ops = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), + _generalized_scatter: InplaceableOp( + _inplace_generalized_scatter, + 0, + extra_check=should_reinplace_scatter, + ), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = {triton_kernel_wrapper_functional} + + +# Operators that don't depend on the tensor data +META_ONLY_OPS = { + aten.sym_size.int, + aten.sym_stride.int, + aten.sym_numel.default, + aten.sym_storage_offset.default, +} + + +def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: + """ + Reinplaces in-placeable operations. + If there are no uses of a view of the mutated arg after the current node, + it is possible to inplace the op. + This above algorithm could be justified by observing side effects. While + we traverse the graph in forwards direction, only latter nodes could view + side effects of the current node. If the current node is not used later as + well as no view of this node is used later in the graph, then it is safe to + inplace as there would be no way to observe the side effects. + This condition is slightly different for graph inputs where they can only + be inplaced if the above condition is true and there's a copy_ in the + epilogue that signals that the caller wants to observe the mutation. 
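+    For example, a graph that computes out = aten.index_put(x, idx, v) on a graph
+    input x and ends with x.copy_(out) can be rewritten to call
+    aten.index_put_(x, idx, v) directly, with the trailing copy_ removed.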
+ """ + + copy_args_to_copy_nodes = {} + mutated_inputs = set() + storage_to_nodes = defaultdict(list) + node_order: Dict[Any, int] = {} + for i, node in enumerate(reversed(graph.nodes)): + node_order[node] = len(graph.nodes) - i - 1 + storage_to_nodes[get_node_storage(node)].append(node) + if node.target == aten.copy_.default and node.args[0].op == "placeholder": + dst = node.args[0] + src = node.args[1] + # If the target is a getitem and it indexes a possible clone, + # then skip over it + if src.target == operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) + or (src.args[0].target == torch.ops.higher_order.auto_functionalized) + ): + src = src.args[0] + + copy_args_to_copy_nodes[(dst, src)] = node + + mutated_inputs.add(node.args[0]) + + def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node): + node_loc = node_order[node] + copy_node_loc = node_order[copy_node] if copy_node is not None else None + + def is_meta_only_user(node): + if _is_view_op(node.target): + return all(is_meta_only_user(u) for u in node.users) + return node.target in META_ONLY_OPS + + for view in shared_view_nodes: + for user in view.users: + user_loc = node_order[user] + # Skip all users before node + if user_loc <= node_loc: + continue + # Ignore uses after the copy_ epilogue node, where the input + # has already been mutated anyway + if copy_node_loc is not None and copy_node_loc <= user_loc: + continue + # Reinplacing does not change shape metadata + if is_meta_only_user(user): + continue + return True + return False + + def can_inplace(node, mutated_arg): + if isinstance(mutated_arg, (list, tuple)): + return all(can_inplace(node, arg) for arg in mutated_arg) + + if get_node_storage(mutated_arg) is None: + return False + shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + if mutated_arg.op == "placeholder": + if not ( + copy_node := copy_args_to_copy_nodes.get((mutated_arg, node), False) + ): + return False + + if any_use_of_views_after_node( + node, shared_view_nodes, copy_node=copy_node + ): + return False + + return True + elif any(view.op == "placeholder" for view in shared_view_nodes): + # If mutated arg is view of any of the inputs of the graph, + # do not allow for inplacing. + # This would require more sophisticated algorithm to handle + return False + else: + return not any_use_of_views_after_node( + node, shared_view_nodes, copy_node=None + ) + + replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {} + + def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs): + tensors_to_clone: List[str] = [] + for arg in old_tensors_to_clone: + assert arg in kwargs + mutated_arg = kwargs[arg] + if can_inplace(node, mutated_arg): + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + for user in node.users: + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg + else: + tensors_to_clone.append(arg) + return tensors_to_clone + + for node in graph.nodes: + if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: + mutated_arg = node.args[inplaceable_op.mutated_arg] + if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node): + # TODO(yifu): this doesn't properly remove copy epilogues for + # ops that mutate multiple inputs. 
Need to revise the copy + # node tracking logic to support the case. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized: + _mutable_op = node.args[0] + from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names + + tensors_to_clone = get_mutable_arg_names(_mutable_op) + # Don't try to reinplace Optional[Tensor] args that are None. + tensors_to_clone = [ + t for t in tensors_to_clone if node.kwargs[t] is not None + ] + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + tensors_to_clone, node.kwargs + ) + + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = tensors_to_clone + elif node.target in inplaceable_triton_ops: + # inplaceable_triton_ops take an additional argument called + # tensors_to_clone which contain a list of tensors to clone + # This pass iterates over them and sees which ones are safe + # to eliminate (i.e. no longer need the clones) + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + node.kwargs["tensors_to_clone"], node.kwargs["kwargs"] + ) + + kwargs = dict(node.kwargs) + kwargs["tensors_to_clone"] = tensors_to_clone + node.kwargs = immutable_dict(kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + replace_dict[copy_node] = copy_node.args[0] + + node.target = inplaceable_op.inplace_op + for node, replacement in replace_dict.items(): + while replacement in replace_dict: + replacement = replace_dict[replacement] + replace_dict[node] = replacement + + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + +def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: + canonicalize_view_scatter_ops(graph) + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..d3bd47f93d3d8a655af9d6606dd3570ce58957b3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,139 @@ +import collections +import logging + +import torch + +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from .. 
import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + match.replace_by_example(replacement, [size]) + + +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + match.replace_by_example(replacement, [size]) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1951dcb0b2a98735214df0ece98a6aa8db20258
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d14aedbc748390cd63f71c0bcddd3bc54512fb16
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..118cae4b91c5c17f6172342f6e0841cc41d44728
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc49d65fcac1cb24a7fc77bea4cc0f94009929dd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6eccbeec1b2fba1e2be3d52bad5ad0e6ca6fe5f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f55ad1b7e534ba4a302abc6f9e8b43c65de86e1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5331184e8c5af2cbe8357224f862cd50ceaa005a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97ff30d7a5e93a1432ba74e688e5ed3786a488f5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1855f808287b6141c9a8acae6feb017620312865
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2285a440852168e2fcddd00de3971e29d0e21abd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3db56343d61e311ef4a18e9d9d5d0d881eddf127
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bc12d667d2b39ef1c75ab3cb2bdaee1d9116288
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65488e3dd89c841169c5c1cf8a83b24ab0ceaf76
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e09a0a4b547d1acc8aeaebcb7b22a77a8193805d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b00473d6be5787f2db16951b631bfc9adcaeb8d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71929133477589e09a6f24a2b58f0e50ce83cd76
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f670d9e4fa2460b35cb25f745ec5793a7fa33b53
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fc35b514d02f3e54c1e1a576f6d56c4101b7113
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c65099dc221b736a973115b3fba5dedd49caf5a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/central_index.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cabcce55572f2cf6c655bf9001dc9955aa962aa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
@@ -0,0 +1,182 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
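# For orientation (a hedged sketch, not part of the generated pattern
# definitions; q, k, v, inv_scale and the 3-D shapes are illustrative): the
# expressions serialized in this module describe the aten graph that
# scaled-dot-product attention decomposes into, so the inductor pattern
# matcher can recognize the subgraph and swap in a fused kernel. Roughly:
#
#     scores = torch.bmm(q, k.transpose(-2, -1)) / inv_scale
#     scores = scores - scores.amax(dim=-1, keepdim=True)   # stable softmax
#     attn = scores.exp()
#     attn = attn / attn.sum(dim=-1, keepdim=True)
#     out = torch.bmm(attn, v)
#
# The _training variants below additionally capture the backward graph (the
# nodes fed by tangents_1), while the _inference variants match only the
# forward subgraph.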
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = 
CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..f7202fa6c6ac7a649c8ab03717782d7d85f62acc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,213 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = 
CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = 
CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = 
CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..29f4f100f755b663e035c1927668b13cad3a1ef8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,212 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
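# A brief reading guide for the pattern DSL used below (a sketch based on
# torch._inductor.pattern_matcher; the example at the end is illustrative
# and not part of the generated file):
#
#   * CallFunction(op, *args, **kwargs) matches a graph node that calls op
#     and whose inputs in turn match args.
#   * KeywordArg('query') binds whatever appears at that position to the
#     name 'query' for later use.
#   * Ignored() matches any input without constraining it.
#   * _users=2 constrains how many consumers the matched node has.
#   * MultiOutputPattern([...]) groups the forward output with the gradient
#     outputs so a training pattern is matched as a whole.
#
#     # e.g. this would match `x * <anything>` and bind the first input:
#     # pat = CallFunction(aten.mul.Tensor, KeywordArg('x'), Ignored())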
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 
= CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +permute_default = 
CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = 
CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, 
permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..67030a6f9735a3e56814f5a1b86de7e4b0f85b2f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,232 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = 
CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), 
Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
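# In this reduced-precision ("half") variant the softmax above is wrapped in
# prims.convert_element_type nodes: the scaled logits are upcast (typically
# to fp32) for amax/exp/sum and cast back just below. The rand/gt.Scalar
# mask created at the top of the pattern is then multiplied in to model
# dropout on the attention weights. As a hedged eager-mode sketch (dropout_p
# and the shape of attn come from whatever graph is matched):
#
#     keep = torch.rand_like(attn) > dropout_p           # aten.gt.Scalar
#     attn = attn * keep * (1.0 / (1.0 - dropout_p))     # the two aten.mul.Tensor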
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = 
CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..20062a0d75db069326c85fc3a81a8b024a04abcf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,142 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, sub_Tensor_1, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, sub_Tensor_1) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value')) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = 
CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, clone_default, KeywordArg('value')) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e7045190031ebb1fd7f78ab3e0726fab2d062e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,218 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = 
CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, 
convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..ffcc566249470c1fd15659cf8f407b915188affc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,236 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = 
CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = 
CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, 
bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) 
+view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..9d44b570a0eb401069408d72947c6b9b9e5b36b2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,635 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = 
CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, 
permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, 
Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, 
permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = 
CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, 
Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, 
Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = 
CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, 
convert_element_type_default_1, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, 
amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, 
KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..a5fa9a19d545f959188bfac12e7e932e89e02ad2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,256 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = 
CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, sub_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, 
KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, 
where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default_3, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) 
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_3 = CallFunction(aten.expand.default, clone_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, 
expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e7d69f6121a65907b068edbb5a0507b9c5046e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,182 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = 
CallFunction(aten.alias.default, div_Tensor) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, 
Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_1, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_3, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = 
CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..895b921f9ee3f733f6e33befa0181848cebab503 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,202 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = 
CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, 
Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) 
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff3b6644ab50271233860505beaafb2de6b9a13 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,202 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, clone_default, alias_default_3, _users=2) +sum_dim_IntList_1 = 
CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6) +mul_Tensor_7 = CallFunction(aten.mul.Tensor, sub_Tensor_1, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), 
_users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +clone_default = CallFunction(aten.clone.default, mul_Tensor_4, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_5, mul_Tensor_6) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +mul_Tensor_7 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = 
CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1ca2ef9321d837deb360cc1d1d5e6dd9359781 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,186 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, convert_element_type_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) 
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3dff8e414a77b65a8d3679ce8ed0baf04a1cca --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,206 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default, 
alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default = CallFunction(aten.clone.default, div_Tensor_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, 
permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +clone_default = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, clone_default, Ignored()) +alias_default = CallFunction(aten.alias.default, convert_element_type_default_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, alias_default_3, Ignored(), _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_4, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = 
CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +clone_default = CallFunction(aten.clone.default, convert_element_type_default_1) +expand_default_2 = CallFunction(aten.expand.default, clone_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..a277750c844e1e0faeb485635a49b0d0643e7191 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,233 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, 
convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default 
= CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = 
CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = 
CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..69eefd322686b5e80260250de7ff6057605eac24 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,213 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) 
+alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) 
+mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..c0cdb933a8e9bca2c02a391917a634f10c40f287 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,233 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +view_default_8 = CallFunction(aten.view.default, sub_Tensor_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, 
view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +clone_default_3 = CallFunction(aten.clone.default, mul_Tensor_3, memory_format=torch.contiguous_format) +alias_default = CallFunction(aten.alias.default, div_Tensor_1) +alias_default_1 = CallFunction(aten.alias.default, alias_default) +alias_default_2 = CallFunction(aten.alias.default, alias_default_1) +alias_default_3 = CallFunction(aten.alias.default, alias_default_2, _users=2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, clone_default_3, alias_default_3, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, alias_default_3, sum_dim_IntList_1) +sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor_4, mul_Tensor_5) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, sub_Tensor_1, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = 
CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +clone_default_2 = CallFunction(aten.clone.default, div_Tensor_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, clone_default_2, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_3 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored()) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py 
b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4a11ee604d7103f1669816c846012786b90ac6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/central_index.py @@ -0,0 +1,114 @@ +# mypy: ignore-errors + +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python +# torchgen/fuse_attention_patterns/gen_attention_patterns.py +from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference, _sfdp_pattern_1_half_training, _sfdp_pattern_1_half_inference) +from ._sfdp_pattern_2 import (_sfdp_pattern_2_training, _sfdp_pattern_2_inference, _sfdp_pattern_2_half_training, _sfdp_pattern_2_half_inference) +from ._sfdp_pattern_3 import (_sfdp_pattern_3_training, _sfdp_pattern_3_inference, _sfdp_pattern_3_half_training, _sfdp_pattern_3_half_inference) +from ._sfdp_pattern_4 import (_sfdp_pattern_4_training, _sfdp_pattern_4_inference, _sfdp_pattern_4_half_training, _sfdp_pattern_4_half_inference) +from ._sfdp_pattern_5 import (_sfdp_pattern_5_training, _sfdp_pattern_5_inference, _sfdp_pattern_5_half_training, _sfdp_pattern_5_half_inference) +from ._sfdp_pattern_6 import (_sfdp_pattern_6_training, _sfdp_pattern_6_inference, _sfdp_pattern_6_half_training, _sfdp_pattern_6_half_inference) +from ._sfdp_pattern_7 import (_sfdp_pattern_7_training, _sfdp_pattern_7_inference, _sfdp_pattern_7_half_training, _sfdp_pattern_7_half_inference) +from ._sfdp_pattern_8 import (_sfdp_pattern_8_training, _sfdp_pattern_8_inference, _sfdp_pattern_8_half_training, _sfdp_pattern_8_half_inference) +from ._sfdp_pattern_9 import (_sfdp_pattern_9_training, _sfdp_pattern_9_inference, _sfdp_pattern_9_half_training, _sfdp_pattern_9_half_inference) +from ._sfdp_pattern_10 import (_sfdp_pattern_10_training, _sfdp_pattern_10_inference, _sfdp_pattern_10_half_training, _sfdp_pattern_10_half_inference) +from ._sfdp_pattern_11 import (_sfdp_pattern_11_training, _sfdp_pattern_11_inference, _sfdp_pattern_11_half_training, _sfdp_pattern_11_half_inference) +from ._sfdp_pattern_12 import (_sfdp_pattern_12_training, _sfdp_pattern_12_inference, _sfdp_pattern_12_half_training, _sfdp_pattern_12_half_inference) +from ._sfdp_pattern_13 import (_sfdp_pattern_13_training, _sfdp_pattern_13_inference, _sfdp_pattern_13_half_training, _sfdp_pattern_13_half_inference) +from ._sfdp_pattern_14 import (_sfdp_pattern_14_training, _sfdp_pattern_14_inference, _sfdp_pattern_14_half_training, _sfdp_pattern_14_half_inference) +from ._sfdp_pattern_15 import (_sfdp_pattern_15_training, _sfdp_pattern_15_inference, _sfdp_pattern_15_half_training, _sfdp_pattern_15_half_inference) +from ._sfdp_pattern_16 import (_sfdp_pattern_16_training, _sfdp_pattern_16_inference, _sfdp_pattern_16_bs1_training, _sfdp_pattern_16_bs1_inference, _sfdp_pattern_16_half_training, _sfdp_pattern_16_half_inference, _sfdp_pattern_16_half_bs1_training, _sfdp_pattern_16_half_bs1_inference, _sfdp_pattern_16_half_mask_fp32_training, _sfdp_pattern_16_half_mask_fp32_inference, _sfdp_pattern_16_half_mask_fp32_bs1_training, _sfdp_pattern_16_half_mask_fp32_bs1_inference) +from ._sfdp_pattern_17 import (_sfdp_pattern_17_training, _sfdp_pattern_17_inference, _sfdp_pattern_17_half_training, _sfdp_pattern_17_half_inference) + +central_index = { + '_sfdp_pattern_1_training': _sfdp_pattern_1_training, + '_sfdp_pattern_1_inference': _sfdp_pattern_1_inference, + '_sfdp_pattern_2_training': 
_sfdp_pattern_2_training, + '_sfdp_pattern_2_inference': _sfdp_pattern_2_inference, + '_sfdp_pattern_3_training': _sfdp_pattern_3_training, + '_sfdp_pattern_3_inference': _sfdp_pattern_3_inference, + '_sfdp_pattern_4_training': _sfdp_pattern_4_training, + '_sfdp_pattern_4_inference': _sfdp_pattern_4_inference, + '_sfdp_pattern_5_training': _sfdp_pattern_5_training, + '_sfdp_pattern_5_inference': _sfdp_pattern_5_inference, + '_sfdp_pattern_6_training': _sfdp_pattern_6_training, + '_sfdp_pattern_6_inference': _sfdp_pattern_6_inference, + '_sfdp_pattern_7_training': _sfdp_pattern_7_training, + '_sfdp_pattern_7_inference': _sfdp_pattern_7_inference, + '_sfdp_pattern_8_training': _sfdp_pattern_8_training, + '_sfdp_pattern_8_inference': _sfdp_pattern_8_inference, + '_sfdp_pattern_9_training': _sfdp_pattern_9_training, + '_sfdp_pattern_9_inference': _sfdp_pattern_9_inference, + '_sfdp_pattern_10_training': _sfdp_pattern_10_training, + '_sfdp_pattern_10_inference': _sfdp_pattern_10_inference, + '_sfdp_pattern_11_training': _sfdp_pattern_11_training, + '_sfdp_pattern_11_inference': _sfdp_pattern_11_inference, + '_sfdp_pattern_12_training': _sfdp_pattern_12_training, + '_sfdp_pattern_12_inference': _sfdp_pattern_12_inference, + '_sfdp_pattern_13_training': _sfdp_pattern_13_training, + '_sfdp_pattern_13_inference': _sfdp_pattern_13_inference, + '_sfdp_pattern_14_training': _sfdp_pattern_14_training, + '_sfdp_pattern_14_inference': _sfdp_pattern_14_inference, + '_sfdp_pattern_15_training': _sfdp_pattern_15_training, + '_sfdp_pattern_15_inference': _sfdp_pattern_15_inference, + '_sfdp_pattern_16_training': _sfdp_pattern_16_training, + '_sfdp_pattern_16_inference': _sfdp_pattern_16_inference, + '_sfdp_pattern_16_bs1_training': _sfdp_pattern_16_bs1_training, + '_sfdp_pattern_16_bs1_inference': _sfdp_pattern_16_bs1_inference, + '_sfdp_pattern_17_training': _sfdp_pattern_17_training, + '_sfdp_pattern_17_inference': _sfdp_pattern_17_inference, + '_sfdp_pattern_1_half_training': _sfdp_pattern_1_half_training, + '_sfdp_pattern_1_half_inference': _sfdp_pattern_1_half_inference, + '_sfdp_pattern_2_half_training': _sfdp_pattern_2_half_training, + '_sfdp_pattern_2_half_inference': _sfdp_pattern_2_half_inference, + '_sfdp_pattern_3_half_training': _sfdp_pattern_3_half_training, + '_sfdp_pattern_3_half_inference': _sfdp_pattern_3_half_inference, + '_sfdp_pattern_4_half_training': _sfdp_pattern_4_half_training, + '_sfdp_pattern_4_half_inference': _sfdp_pattern_4_half_inference, + '_sfdp_pattern_5_half_training': _sfdp_pattern_5_half_training, + '_sfdp_pattern_5_half_inference': _sfdp_pattern_5_half_inference, + '_sfdp_pattern_6_half_training': _sfdp_pattern_6_half_training, + '_sfdp_pattern_6_half_inference': _sfdp_pattern_6_half_inference, + '_sfdp_pattern_7_half_training': _sfdp_pattern_7_half_training, + '_sfdp_pattern_7_half_inference': _sfdp_pattern_7_half_inference, + '_sfdp_pattern_8_half_training': _sfdp_pattern_8_half_training, + '_sfdp_pattern_8_half_inference': _sfdp_pattern_8_half_inference, + '_sfdp_pattern_9_half_training': _sfdp_pattern_9_half_training, + '_sfdp_pattern_9_half_inference': _sfdp_pattern_9_half_inference, + '_sfdp_pattern_10_half_training': _sfdp_pattern_10_half_training, + '_sfdp_pattern_10_half_inference': _sfdp_pattern_10_half_inference, + '_sfdp_pattern_11_half_training': _sfdp_pattern_11_half_training, + '_sfdp_pattern_11_half_inference': _sfdp_pattern_11_half_inference, + '_sfdp_pattern_12_half_training': _sfdp_pattern_12_half_training, + '_sfdp_pattern_12_half_inference': 
_sfdp_pattern_12_half_inference, + '_sfdp_pattern_13_half_training': _sfdp_pattern_13_half_training, + '_sfdp_pattern_13_half_inference': _sfdp_pattern_13_half_inference, + '_sfdp_pattern_14_half_training': _sfdp_pattern_14_half_training, + '_sfdp_pattern_14_half_inference': _sfdp_pattern_14_half_inference, + '_sfdp_pattern_15_half_training': _sfdp_pattern_15_half_training, + '_sfdp_pattern_15_half_inference': _sfdp_pattern_15_half_inference, + '_sfdp_pattern_16_half_training': _sfdp_pattern_16_half_training, + '_sfdp_pattern_16_half_inference': _sfdp_pattern_16_half_inference, + '_sfdp_pattern_16_half_bs1_training': _sfdp_pattern_16_half_bs1_training, + '_sfdp_pattern_16_half_bs1_inference': _sfdp_pattern_16_half_bs1_inference, + '_sfdp_pattern_17_half_training': _sfdp_pattern_17_half_training, + '_sfdp_pattern_17_half_inference': _sfdp_pattern_17_half_inference, + '_sfdp_pattern_16_half_mask_fp32_training': _sfdp_pattern_16_half_mask_fp32_training, + '_sfdp_pattern_16_half_mask_fp32_inference': _sfdp_pattern_16_half_mask_fp32_inference, + '_sfdp_pattern_16_half_mask_fp32_bs1_training': _sfdp_pattern_16_half_mask_fp32_bs1_training, + '_sfdp_pattern_16_half_mask_fp32_bs1_inference': _sfdp_pattern_16_half_mask_fp32_bs1_inference, +} + + +def get_serialized_pattern(key): + import torch._inductor # noqa: F401 + from torch._inductor import config + if config.fallback_random: + return None + + # TODO - could add more validation that the same set of decomps used when + # tracing SDPA are also used in current context. softmax, dropout, etc + # decomp use is stable so not an issue in practice. + return central_index.get(key) diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..5f02e1ec5d90523caa1b772c0a565a67c6f973dc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,1537 @@ +import itertools +import logging +import operator +from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union + +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + config_flag, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid +from .pre_grad import ( + merge_getitem_cat_pass, + merge_splits_pass, + normalization_pass, + split_cat_pass, + unbind_stack_pass, +) + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = Tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = Tuple[int, int] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... 
\ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def remove_split_with_size_one( + graph: torch.fx.Graph, + node: torch.fx.Node, + input: torch.fx.Node, +): + # find the grand children of the split_node + next_users = find_next_users(node) + user = next(iter(node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(node) + + counters["inductor"]["remove_split_with_size_one"] += 1 + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if "example_value" not in split_node.meta: + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
+ return + # remove the dummy split whose split sections size is one + if len(split_sections) == 1: + remove_split_with_size_one(graph, split_node, split_input) + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if "example_value" not in input.meta: + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "example_value" not in tensor.meta: + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + 
new_cat_node = graph.call_function( + torch.cat, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["split_cat_norm"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if "example_value" not in tensor.meta: + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["split_cat_norm"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=normalization_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + match.graph.erase_node(squeeze_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. 
+ """ + + def __init__(self, arg, sizes, func=torch.split): + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = set() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=merge_splits_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: List[int], + next_split_sections: List[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = first_split.kwargs["dim"] # type: ignore[union-attr] + + to_remove = [] + + with graph.inserting_before(first_split): + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + first_split_num_to_user = { + user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + # We use the num of users to check the getitems to be merged. 
+ for next_split_num in range(len(node.users.keys())): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["consecutive_split_merged"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] + ) -> List[List[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. 
This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in {torch.cat, torch.stack}: + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> List[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = set(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> List[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = set(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: List[Union[torch.fx.Node, int]] + ) -> List[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + ranges = set() + for user_node, user_inputs in zip(next_users, user_inputs_list): + ranges |= { + user_input + for user_input in user_inputs + if isinstance(user_input, tuple) + } + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. 
+ return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = split_node.kwargs["dim"] + split_sections = split_node.args[1] + transform_params_list: List[List[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in {torch.cat, torch.stack}: + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + 1 + ] + # All sections should be equal + if len(set(subset_split_sections)) != 1: + return None + + num_splits = len(subset_split_sections) + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + split_ranges: List[_Range], + ) -> List[List[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 
+ """ + split_input = split_node.args[0] + split_dim = split_node.kwargs["dim"] + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + new_split.meta.update(split_node.meta) + counters["inductor"]["scmerge_split_added"] += 1 + with graph.inserting_after(new_split): + split_items = [ + graph.call_function(operator.getitem, args=(new_split, i)) + for i in range(len(split_ranges)) + ] + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list_new, + transform_params_list: List[List[_TransformParam]], + ): + split_dim = split_node.kwargs["dim"] + + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in {torch.cat, torch.stack}: + # Change the args and kwargs of non-cat/stack nodes. 
Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed = [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack = [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + to_stack = [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + continue + + if unflatten_params: + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + if movedim_params: + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + if flatten_params: + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_inputs_new_transformed.append(user_input_new) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + user_inputs_new_transformed.append(stacked_input) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta.update(user_node.meta) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in {torch.cat, torch.stack}: + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. 
+ """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + num_unbind = ( # type: ignore[operator] + max(getitem_node.args[1] for getitem_node in unbind_node.users.keys()) + 1 # type: ignore[operator, union-attr, type-var] + ) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: List[int], + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + unbind_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = unbind_node.kwargs["dim"] + transform_params_list: List[List[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1): + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. 
Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_squeeze_replaced"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=unbind_stack_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) 
+@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def simplify_split_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... \ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: + return False + return True + + +def remove_zeros(split_sections: List[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: List[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_getitem_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. 
Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + indices = [] + for arg in cat_user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["getitem_cat_merged"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=split_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def mutate_cat_node(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["cat_mutated"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, list(range(indices[0])) + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, indices + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["cat_mutated"] += 1 + + +# noqa: W605 +# ############The pattern to be optimized is######### +# split_node (dim=1) +# / ... \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / +# stack (dim=0) -> user=1, getitems to be consecutive +# | +# tahn -> user=1 +# | +# unbind (dim=0) +# | + +# ################After transformation############# +# split_node (dim=1) +# / ... 
/ \ +# getitem getitem getitem -> user=1 +# | +# tahn +# | +# split +# | + + +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + tensors=getitem_split, + dim=Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +@register_graph_pattern( + CallFunction( + torch.tanh, + CallFunction( + torch.stack, + getitem_split, + Ignored(), + ), + ), + pass_dict=merge_getitem_cat_pass, + extra_check=config_flag("split_cat_fx_passes"), +) +def merge_stack_tahn_unbind(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. Create a new copy of it + split_sections = list(split_sections) + for user in next_users: + # stack user only has one user + if user.target == torch.stack: + stack_dim = get_arg_value(user, 1, "dim") or 0 + unbind_user = find_next_users(user)[0] + if unbind_user.target != torch.unbind: + continue + unbind_dim = get_arg_value(unbind_user, 1, "dim") or 0 + # stack and unbind should have the same dim + # check the all getitems in the user from the same node + # check all the getitems only has single user + if ( + stack_dim != unbind_dim + or not has_same_parent_node(user) + or not all(len(arg.users) == 1 for arg in user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be stacked + indices = [] + split_sections_for_unbind = [] + for arg in user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + split_sections_for_unbind.append(split_sections[arg.args[1]]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of stack user, only keep the first getitem + user.update_arg(0, user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, index, assignment] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + replace_unbind_with_split = graph.call_function( + torch.split, + args=(unbind_user.args[0], split_sections_for_unbind), + kwargs={"dim": split_dim}, + ) + unbind_user.replace_all_uses_with(replace_unbind_with_split) + replace_unbind_with_split.meta.update(unbind_user.meta) + # remove getitem and split, 
stack + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [unbind_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + graph.erase_node(user) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["stack_tahn_unbind_merged"] += 1 diff --git a/MLPY/Lib/site-packages/torch/_inductor/fx_utils.py b/MLPY/Lib/site-packages/torch/_inductor/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed09b35b6fbba51eb61076045cd56371dd4b6de --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/fx_utils.py @@ -0,0 +1,220 @@ +import operator +from collections import defaultdict +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map +from .virtualized import V + + +# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. +# Works for length 2 patterns with 1 module and 1 function/method. +def matches_module_function_pattern( + pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: Dict[str, torch.nn.modules.Module], +) -> bool: + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function or call_method + if node.op != "call_function" and node.op != "call_method": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +class FakeTensorUpdater: + """ + The main idea here is that it's difficult to maintain accurate fake + tensors (our primary form of metadata) for each node in our graph as we + transform it. + + The most reliable way to obtain this information is by rerunning + faketensor propagation. However, in general, faketensor propagation is + fairly expensive. So, instead we'd like to only rerun faketensor + propagation on nodes that have changed. + + In order to detect which nodes have changed, we first hash its node, + target, and argument lists (which are immutable in FX). + + Then, whenever we call incremental_update, we check which FX nodes have a + new hash, and recompute the faketensor metadata for that node. 
Then, we + continue to recursively compute the faketensors for all users until the + fake tensors stop changing. + """ + + def __init__(self, graph: torch.fx.Graph): + self.processed_hashes = set() + self.graph = graph + + for node in self.graph.nodes: + self.processed_hashes.add(self.hash_node(node)) + + def hash_node(self, node: torch.fx.Node): + # todo(chilli): Not a great hash function + return (node, node.target, id(node.args), id(node.kwargs)) + + def incremental_update(self): + processed = set() + existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) + for node in self.graph.nodes: + existing_storages[get_node_storage(node)] += 1 + + def is_intlist_same(new, old): + return statically_known_true(sym_eq(new, old)) + + def is_fake_tensor_same(new, old): + if type(new) != type(old): + return False + if isinstance(new, (list, tuple)): + if len(new) != len(old): + return False + return all( + is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) + ) + assert isinstance(new, torch.Tensor) + if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: + return False + if new.layout == torch.strided and ( + not is_intlist_same(new.stride(), old.stride()) + or not statically_known_true( + new.storage_offset() == old.storage_offset() + ) + ): + return False + + if get_storage(new) == get_storage(old): + return True + + # This is the case where it returns a completely fresh storage that's used nowhere else. + if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + ): + return True + return False + + for node in self.graph.nodes: + if self.hash_node(node) in self.processed_hashes: + continue + + def is_aten_node(node): + return node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ) + + if not is_aten_node(node): + continue + + processing = [node] + while len(processing) > 0: + updating_node = processing.pop() + if updating_node in processed: + continue + if is_aten_node(updating_node): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(updating_node) + if not is_valid: + continue + with V.fake_mode: + new_fake_tensor = updating_node.target(*args, **kwargs) + if "val" in updating_node.meta and is_fake_tensor_same( + new_fake_tensor, updating_node.meta["val"] + ): + continue + updating_node.meta["val"] = new_fake_tensor + + # todo(chilli): This code path is not exercised by our existing + # tests - add a test + existing_storages[get_node_storage(new_fake_tensor)] += 1 + processed.add(updating_node) + processing.extend(updating_node.users) + + self.processed_hashes.add(self.hash_node(updating_node)) + + +def get_storage(t: torch.Tensor) -> int: + return t.untyped_storage()._cdata + + +def get_node_storage(node: torch.fx.Node) -> Optional[int]: + if "val" not in node.meta: + return None + if not isinstance(node.meta["val"], torch.Tensor): + return None + if not torch._C._has_storage(node.meta["val"]): + return None + return get_storage(node.meta["val"]) + + +def get_fake(x): + if isinstance(x, torch.fx.Node): + if "val" not in x.meta: + return x + return x.meta["val"] + return x + + +def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]: + """ + First value returns a boolean if any of the input nodes don't have a faketensor. 
+ """ + args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) + if any( + isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) + ): + return False, args, kwargs + return True, args, kwargs + + +def is_node_realized(node: torch.fx.Node) -> bool: + """Returns true if a node is always realized when lowered to inductor IR. + + NOTE: This may return some false negatives. e.g. it doesn't + handle buffers realized heuristically during lowering, or + buffers realized indirectly through view ops. + """ + from torch._inductor.lowering import fallbacks, needs_realized_inputs + + def is_buffer(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target is operator.getitem: + # For nodes with multiple outputs, we get the fx graph: + # foo = torch.ops.aten.foo(...) + # getitem = foo[0] + # getitem_1 = foo[1] + # where we need to check if foo is a fallback kernel + return is_buffer(node.args[0]) # type: ignore[arg-type] + return node.op in ("placeholder", "output") or node.target in fallbacks + + if is_buffer(node): + return True + + def realizes_inputs(node: torch.fx.Node) -> bool: + return node.op == "output" or node.target in needs_realized_inputs + + if any(realizes_inputs(user) for user in node.users): + return True + + # Otherwise, assume node isn't realized + return False diff --git a/MLPY/Lib/site-packages/torch/_inductor/graph.py b/MLPY/Lib/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..ed21547c5da3cc2259c5026632440c169309a3a6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/graph.py @@ -0,0 +1,1324 @@ +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple + +import sympy + +import torch +import torch._logging +import torch.fx +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._logging import LazyString, trace_structured +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes +from torch.utils._mode_utils import no_dispatch + +from . 
import config, ir +from .codegen.common import ( + DeviceOpOverrides, + get_device_op_overrides, + get_scheduling_for_device, + get_wrapper_codegen_for_device, + register_backend_for_device, +) +from .codegen.cpp_wrapper_cpu import CppWrapperCpu +from .codegen.cpp_wrapper_cuda import CppWrapperCuda +from .codegen.wrapper import WrapperCodeGen +from .exc import ( + CppWrapperCodeGenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .ir import ( + Constant, + FixedLayout, + InputBuffer, + Pointwise, + Reduction, + StorageBox, + TensorBox, +) +from .lowering import ( + constrain_to_fx_strides, + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + layout_constraints, + lowerings, + make_fallback, + needs_realized_inputs, + unsupported_output_tensor, +) +from .sizevars import SizeVarAllocator +from .utils import convert_shape_to_inductor, gather_origins, get_sympy_Expr_dtype +from .virtualized import V + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") + + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args, **kwargs): + pass + + +def supported_dtype_of_cpp_wrapper(dtype, cuda): + supported_dtype = { + torch.float32, + torch.float64, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + torch.complex32, + torch.complex64, + torch.complex128, + torch.float16, + } + if cuda: + supported_dtype.add(torch.float8_e4m3fn) + supported_dtype.add(torch.float8_e5m2) + supported_dtype.add(torch.float8_e4m3fnuz) + supported_dtype.add(torch.float8_e5m2fnuz) + + return dtype in supported_dtype + + +def may_get_constant_buffer_dtype(constant_buffer): + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op): + magic_ops = {method_to_operator(m) for m in magic_methods} + return op in magic_ops + + +def getattr_recursive(obj, target): + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: List[ir.IRNode] + + def symbolic_sizes_strides(self, ex: torch.Tensor): + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? 
+ # NB: This is using the legacy default behavior from + # create_symbolic_sizes_strides_storage_offset but we hope we can + # just delete this entirely + source = ConstantSource( + f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" + ) + ( + size, + stride, + _, + ) = self._shape_env.create_symbolic_sizes_strides_storage_offset( + ex, + source, + ) + + size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] + stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] + return size, stride + + def static_sizes_strides(self, ex: torch.Tensor): + """ + Primarily used to weights + """ + size = [sympy.Integer(i) for i in ex.size()] + stride = [sympy.Integer(i) for i in ex.stride()] + return size, stride + + def init_backend_registration(self): + if get_scheduling_for_device("cpu") is None: + from .codegen.cpp import CppScheduling + + register_backend_for_device("cpu", CppScheduling, WrapperCodeGen) + + if get_scheduling_for_device("cuda") is None: + from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + register_backend_for_device("cuda", CUDACombinedScheduling, WrapperCodeGen) + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Optional[List[torch.Tensor]] = None, + shape_env=None, + num_static_inputs=None, + graph_id=None, + cpp_wrapper=False, + aot_mode=False, + user_visible_outputs=frozenset(), + layout_opt=None, + extern_node_serializer=None, + is_inference=False, + is_const_graph=False, + const_output_index=None, + const_code=None, + const_module=None, + name=None, + ): + super().__init__(gm) + + self.example_inputs = example_inputs + self.layout_opt = ( + layout_opt + if layout_opt is not None + else self.decide_layout_opt(gm, is_inference=is_inference) + ) + self.num_channels_last_conv = 0 + self.is_inference = is_inference + self.is_const_graph = is_const_graph + self.const_code = const_code + self.const_module = const_module + + self.extra_traceback = False # we do our own error wrapping + if shape_env is None: + shape_env = ShapeEnv() + self.reuse_shape_env = False + else: + self._shape_env = shape_env + self.reuse_shape_env = True + self._shape_env = shape_env + self.sizevars = SizeVarAllocator(shape_env) + self.graph_input_names: List[str] = [] + self.graph_inputs: Dict[str, TensorBox] = {} + self.graph_inputs_original: Dict[str, InputBuffer] = {} + self.device_types: Set[str] = ( + const_module.device_types if const_module else set() + ) + self.device_idxs: Set[int] = const_module.device_idxs if const_module else set() + self.cuda = False + self.buffers: List[ir.Buffer] = [] + self.const_output_index: Dict[str, int] = ( + const_output_index if const_output_index else {} + ) + self.folded_constants: Set[str] = ( + set(const_output_index.keys()) if const_output_index else set() + ) + self.constants: Dict[str, torch.Tensor] = ( + const_module.constants if const_module else {} + ) + self.constant_reprs: Dict[str, str] = {} + self.removed_buffers: Set[str] = set() + self.removed_inplace_buffers: Set[str] = set() + self.mutated_buffers: Set[str] = set() + self.never_reuse_buffers: Set[str] = set() + self.inplaced_to_remove: Set[str] = set() + self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] + self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment] + # See `ProxyExecutor Design Note` in ir.py for more details + self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] + 
self.extern_node_serializer: Optional[ + Callable[[List[ir.ExternKernelNode]], Any] + ] = extern_node_serializer + self.current_node: torch.fx.Node = None # type: ignore[assignment] + self.num_static_inputs = num_static_inputs + self.lists: Dict[str, List[str]] = {} + self.mutated_inputs: Set[str] = set() + self.mutated_input_idxs: List[int] = [] + self.name_to_buffer: Dict[str, ir.Buffer] = {} + self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) + self.creation_time = time.time() + self.name = name + self.cpp_wrapper = cpp_wrapper + + # record multi_kernel choice for cpp_wrapper so the second pass knows + # which sub-kernel is picked. Copy cpp_wrapper to another variable + # since cpp_wrapper flag is set to false for the first pass of codegen. + self.record_multi_kernel_choice = cpp_wrapper + self.multi_kernel_to_choice: Dict[str, int] = {} + + self.aot_mode = aot_mode + self.graph_id = graph_id + self.scheduler: "torch._inductor.scheduler.Scheduler" = None # type: ignore[assignment] + self.nodes_prefer_channels_last = ( + self.find_nodes_prefer_channels_last() if self.layout_opt else set() + ) + self._warned_fallback = {"aten.convolution_backward"} + self.user_visible_outputs = user_visible_outputs + self.cache_key: str = "" # This is the cache key for the compiled artifact + self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored + self.cache_linemap: List[ + Tuple[int, str] + ] = ( + [] + ) # This is the linemap used by the profiler to mark custom compiled kernels getting run + # Used if lowering encounters cases where cudagraphs are not supported + self.disable_cudagraphs_reason: Optional[str] = None + + # only keeping one node per device for stack trace purposes + self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {} + self.orig_gm: torch.fx.GraphModule = gm.__copy__() + self.dynamo_flat_name_to_original_fqn = self.module.meta.get( + "dynamo_flat_name_to_original_fqn", {} + ) + self.allocated_constant_name = ( + const_module.allocated_constant_name if const_module is not None else {} + ) + self.init_backend_registration() + + @staticmethod + def decide_layout_opt(gm, *, is_inference) -> bool: + """ + Decide if we should enable layout optimization for this graph based on + heuristics. + """ + if not config.layout_optimization: + return False + + if config.force_layout_optimization: + return True + + conv_nodes = [ + n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default + ] + nconv = len(conv_nodes) + + if nconv == 0: + return False + + # For cpu backend and mkldnn enabled, we always use channels_last for better performance. + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and all( + n.args[idx].meta["val"].device == torch.device("cpu") + for n in conv_nodes + for idx in [0, 1] + ) + ): + return True + + # Following models are skipped due to this: + # jx_nest_base + # volo_d1_224 + if len(list(gm.graph.nodes)) >= 300 * nconv: + log.debug("Skipped layout opt because only a few conv") + return False + + if any( + has_free_symbols(n.args[idx].meta["val"]) + for n in conv_nodes + for idx in [0, 1] + ): + log.debug( + "See perf regression with dynamic shape. 
Follow up in https://github.com/pytorch/pytorch/issues/102670" + ) + return False + + def is_grouped(n): + return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1 + + def is_in_out_channel(n): + return ( + n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) + and n.args[1].meta["val"].size(2) > 1 + ) + + def is_small_channel(n): + return ( + n.args[1].meta["val"].size(0) <= 64 + and n.args[1].meta["val"].size(1) <= 64 + ) + + # only grouped convolutions benchmarked as slower in conv samples for inference only + if is_inference: + from torch.utils.flop_counter import FlopCounterMode + + flop_counts: Dict[str, float] = defaultdict(float) + for node in conv_nodes: + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( + node + ) + + if success: + with FlopCounterMode(display=False) as flop_counter_mode: + with V.fake_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") + + # average benchmarked channels last speedup / slowdown, < 1 is speedup. + # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ + # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb + GROUPED_MULTIPLIER = 1.358 + DEFAULT_MULTIPLIER = 0.823 + IN_OUT_MULTIPLIER = 0.725 + SMALL_MULTIPLIER = 0.783 + + total_flops = sum(flop_counts.values()) + # TODO - get different values per hardware + weighted_flops = ( + flop_counts["grouped"] * GROUPED_MULTIPLIER + + flop_counts["small"] * SMALL_MULTIPLIER + + flop_counts["in_out"] * IN_OUT_MULTIPLIER + + flop_counts["default"] * DEFAULT_MULTIPLIER + ) + do_layout_opt = weighted_flops <= total_flops + if not do_layout_opt: + log.debug( + "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", + total_flops, + weighted_flops, + ) + return do_layout_opt + + # Channels last layout can dramatically hurt grouped conv perf. E.g. + # Conv with arguments like + # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 2} + # slows down 31x using channels last.. + + # But a lot of timm models use depthwise separable convolution which will + # result in grouped convolution with in-channel size == 1. + # For those grouped convolution, channels last still helps a lot. + # E.g. + # Conv with arguments + # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 58} + # get 1.86x speedup with channels last layout. + # + # The following heuristics skip using channels-last if the model contains + # grouped convolution with in-channels > 1. + if any(map(is_grouped, conv_nodes)): + log.debug( + "Skip layout opt because found grouped convolution with >1 in_channels!" + ) + return False + + # For some models that contain convolution with larger in-channel than out-channel, applying + # channels last hurts performance. 
+ # Following models are skipped due to this: + # - pytorch_unet + # - phlippe_densenet (slightly worse) + # - Background_Matting (1.22x -> 0.821x) + # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x) + if any(map(is_in_out_channel, conv_nodes)): + log.debug( + "Skip layout opt because some convolutions have smaller out_channel" + ) + return False + + # Following models are skipped due to this: + # - functorch_maml_omniglot + if all(map(is_small_channel, conv_nodes)): + log.debug("Skip layout opt because all convolution channels are too small") + return False + + return True + + def qualify_name(self, name: str) -> str: + """Prepend the given name with the graph name if any.""" + if self.name is not None: + return f"{self.name}_{name}" + return name + + def make_subgraph( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + subgraph_name: str, + ) -> "GraphLowering": + """ + Make a subgraph of the current graph with all inherited + parts, except the graph module (`gm`) and `example_inputs`. + The subgraphs are lowered separately, but intended to be + inlined in the parent graph's codegening. Hence the need + for maintaining the same `shape_env` and other properties. + The subgraph name is qualified by the parent graph's name. + """ + return GraphLowering( + gm=gm, + example_inputs=example_inputs, + shape_env=self._shape_env, + cpp_wrapper=self.cpp_wrapper, + aot_mode=self.aot_mode, + extern_node_serializer=self.extern_node_serializer, + is_inference=self.is_inference, + name=self.qualify_name(subgraph_name), + ) + + def find_nodes_prefer_channels_last(self): + """ + The rule to decide if an node prefer channels last is simple. + 1. if it's input/output of a convolution + 2. if one of its user prefers channels last + + We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; + Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers + channels last. + + Consider the scenario: conv -> batch-norm -> relu -> conv + Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: + 1. the output of batch-norm should be channels last initially since its input is a conv's output. + Forcing the batch-norm's output to be contiguous results in the first copy + 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. + We need convert it to channels last layout which results in the second copy. + With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies + can be saved. + """ + output_set = set() + for n in reversed(self.module.graph.nodes): + if n.target == torch.ops.aten.convolution.default: + output_set.add(n) + continue + + for user in n.users: + if user in output_set: + output_set.add(n) + break + + # need a second pass to add downstream nodes of those channel last nodes to the sets. + # This pass is especially needed to avoid mix-layout kernel inputs in backward pass. + # + # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned + # from the fwd graph. Without this second pass, we will force relu's output to be contiguous. + # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last + # tensors and passed to a kernel. + # + # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x. 
+ # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x . + # This also helps the following models: + # - res2net101_26w_4s + # - res2net50_14w_8s + # - sebotnet33ts_256 + for n in self.module.graph.nodes: + if n in output_set: + for child in n.users: + output_set.add(child) + + return output_set + + def warn_fallback(self, name): + if name not in self._warned_fallback: + self._warned_fallback.add(name) + perf_hint_log.info("Using FallbackKernel: %s", name) + + def add_device_info(self, device: torch.device): + self.device_types.add(device.type) + if device.index is not None: + self.device_idxs.add(device.index) + if V.graph.current_node and device not in self.device_node_mapping: + self.device_node_mapping[device] = V.graph.current_node + + @property + def fake_mode(self): + return V.fake_mode + + def get_buffer(self, buffer_name: str): + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name] + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name] + return None + + def get_dtype(self, buffer_name: str): + if buffer_name in self.constants: + return self.constants[buffer_name].dtype + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name].get_dtype() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_dtype() + m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) + raise KeyError(f"could not find {buffer_name}") + + def get_numel(self, buffer_name: str): + from .ir import MultiOutputLayout + + if buffer_name in self.constants: + return self.constants[buffer_name].numel() + if buffer_name in self.name_to_buffer: + buf = self.name_to_buffer[buffer_name] + if isinstance(getattr(buf, "layout", None), MultiOutputLayout): + return 1 + return buf.get_numel() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_numel() + raise KeyError(f"could not find {buffer_name}") + + @dynamo_timed + def run(self, *args): + return super().run(*args) + + def register_buffer(self, buffer: ir.Buffer): + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + self.name_to_buffer[name] = buffer + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 + if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements(): + self.add_device_info(buffer.get_device()) + return name + + def register_list(self, buffer_names: List[str]): + name = self.qualify_name("list_" + "_".join(buffer_names)) + self.lists[name] = buffer_names + return name + + def register_users_of(self, node_output): + def register(value): + if isinstance(value, (list, tuple)): + for x in value: + register(x) + if isinstance(value, ir.IRNode): + if ( + not hasattr(value, "data") + or not isinstance(value.data, ir.IRNode) + or not ( + hasattr(value.data, "data") + and isinstance(value.data.data, ir.IRNode) + ) + ): + return + + for read_name in value.get_read_names(): + self.name_to_users[read_name].append(value) + + register(node_output) + + def mark_buffer_mutated(self, name: str): + """ + When a buffer is mutated we need to make sure all the reads to + the old version are realized before the mutation happens. 
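+        Concretely, the name is added to mutated_buffers and every IR node
+        recorded in name_to_users[name] is realized, so those pending reads
+        are not fused past the mutation.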
+ """ + assert isinstance(name, str) + self.mutated_buffers.add(name) + + if name not in self.name_to_users: + return + + for user in self.name_to_users[name]: + user.realize() + + def add_tensor_constant(self, data, name=None): + def allocate(name): + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and torch.eq(data, value).all() + ): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + return name + + new_name = allocate(name) + self.allocated_constant_name[new_name] = name + + return TensorBox.create( + ir.ConstantBuffer( + new_name, + FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), + ) + ) + + def constant_name(self, name: str, device_override: Optional[torch.device]): + """ + We AOT copy constants to the devices they are needed on. + If device_override doesn't match the constant's device, then + copy it and return a different name. + """ + if self.constants[name].device == device_override or device_override is None: + return name + alt_name = f"{name}_{device_override.type}{device_override.index or 0}" + if alt_name not in self.constants: + self.constants[alt_name] = self.constants[name].to(device_override) + return alt_name + + def placeholder(self, target: str, args, kwargs): + example = super().placeholder(target, args, kwargs) + self.graph_input_names.append(target) + if isinstance(example, SymTypes): + expr = example.node.expr + self.graph_inputs[target] = expr + return expr + elif isinstance(example, (int, bool, float)): + expr = sympy.sympify(example) + self.graph_inputs[target] = expr + return expr + if isinstance(example, BackwardState): + # Ignored arg, must be unused + # Alternately we could filter this out in AotAutograd + return None + assert isinstance(example, torch.Tensor), example + # todo(chilli): We can remove the last check once we turn buffers into + # static shape tensors. That's a hack to workaround Inductor believing + # the buffer should be static but us passing in a fake tensor with + # symbolic shapes. 
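+        # Illustration (hypothetical shapes): a weight of size (64, 3) with no
+        # symbolic dims gets static sizes [Integer(64), Integer(3)] from
+        # static_sizes_strides(), while a dynamically shaped input may instead
+        # get [s0, 3] with s0 allocated by symbolic_sizes_strides() above.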
+ if not example._has_symbolic_sizes_strides: + # the first N inputs are weights + sizes, strides = self.static_sizes_strides(example) + else: + sizes, strides = self.symbolic_sizes_strides(example) + # TODO(jansel): handle input aliasing + target = self.qualify_name(target) + tensor = TensorBox.create( + InputBuffer( + target, + FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + self.graph_inputs[target] = tensor + self.graph_inputs_original[target] = tensor.data.data + self.add_device_info(example.device) + return tensor + + def call_function(self, target, args, kwargs): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + if hasattr(target, "_inductor_lowering_function"): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + def get_custom_op_layout_constraints(target, args, kwargs): + # Custom operations that require preserving stride order + # which run through implicit fallback must constrain their + # arguments' fx strides + layout_constraint = None + if torch._C.Tag.needs_fixed_stride_order in target.tags: + # We have to set the current args because call_function will immediately + # evaluate this lowering after creating the fallback, without evaluating + # the layout constraint + args, kwargs = constrain_to_fx_strides( + self.current_node, *args, **kwargs + ) + # Also register the layout constraint so when the fallback + # is used again, we can constrain the args to the same layout + layout_constraint = constrain_to_fx_strides + return layout_constraint, args, kwargs + + if target not in lowerings: + assert isinstance( + target, torch._ops.OpOverload + ), f"{target} is not an OpOverload" + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target) + elif config.implicit_fallbacks: + layout_constraint, args, kwargs = get_custom_op_layout_constraints( + target, args, kwargs + ) + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target, layout_constraint) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) + out = lowerings[target](*args, **kwargs) + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None + + @staticmethod + def can_inline_constant(t: torch.Tensor) -> bool: + """ + True if this is a small constant attr that will be inlined. 
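+        Currently this means a 1-D tensor with at most 8 elements.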
+ """ + return len(t.shape) == 1 and t.shape[0] <= 8 + + def get_attr(self, target, args, kwargs): + # this is a constant + value = getattr_recursive(self.module, target) + + if isinstance(value, torch.fx.GraphModule): + return ir.Subgraph(name=target, graph_module=value) + + if ( + config.aot_inductor.use_runtime_constant_folding + or config.always_keep_tensor_constants + or unsupported_output_tensor(value) + ): + return self.add_tensor_constant(value, target) + + with no_dispatch(): + if value.shape == (): + return Constant(value.item(), value.dtype, value.device) + if self.can_inline_constant(value): + # tensor lowering has constant inlining logic + from .lowering import tensor + + return tensor(value.tolist(), dtype=value.dtype, device=value.device) + + return self.add_tensor_constant(value, target) + + def call_module(self, target, args, kwargs): + raise AssertionError() + + def call_method(self, target, args, kwargs): + raise AssertionError() + + def output(self, target, args, kwargs): + result = super().output(target, args, kwargs) + assert isinstance(result, (tuple, list)), type(result) + assert all( + isinstance( + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + sympy.logic.boolalg.Boolean, + int, + ), + ) + for x in result + ), result + self.graph_outputs = [ir.ExternKernel.realize_input(x) for x in result] + value: ir.IRNode + for name, value in self.graph_inputs.items(): + assert isinstance( + value, (TensorBox, sympy.Expr) + ), f"Unsupported inductor graph input type: {type(value)}" + if not isinstance(value, TensorBox): + continue + value.realize() + assert isinstance(value, TensorBox) + value = value.data + assert isinstance(value, ir.StorageBox) + value_storage_box = value + value = value.data + if not isinstance(value, InputBuffer) or value.get_name() != name: + # one of our inputs was mutated, need to turn that into a copy + ir.MutationLayout.realize_into(value, self.graph_inputs_original[name]) + # replace output with mutated input + try: + ind = self.graph_outputs.index(value_storage_box) + self.graph_outputs[ind] = self.graph_inputs_original[name] + except ValueError: + pass + + self.finalize() + log.debug( + "Force channels last inputs for %d conv for the current graph with id %d", + self.num_channels_last_conv, + self.graph_id if self.graph_id is not None else -1, + ) + + def finalize(self): + for buf in self.buffers: + buf.decide_layout() + + @contextmanager + def set_current_node(self, node: torch.fx.Node): + old = self.current_node + try: + self.current_node = node + yield + finally: + self.current_node = old + + def run_node(self, n: torch.fx.Node): + def debug(msg): + log.debug("lowering %s %s", LazyString(n.format_node), msg) + + origins = {n} + if n.op == "call_function": + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ir.IRNode.current_origins(origins), self.set_current_node( + n + ), V.set_current_node(n): + if ( + n.op == "call_function" + and n.target is not operator.getitem + and fallback_node_due_to_unsupported_type(n) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, **kwargs # type: ignore[possibly-undefined] + ) + elif n.op == "call_function" and n.target in layout_constraints: + debug("layout_constraints") + args, kwargs = layout_constraints[n.target](n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) + elif is_magic_method(n.target): + # TODO: this is sus, it 
probably should be handled in the + # lowerings themselves similarly to sym_size/sym-stride + debug("is_magic_method") + if isinstance(n.meta["val"], torch.SymInt): + result = n.meta["val"].node.expr + else: + result = super().run_node(n) + else: + debug("") + result = super().run_node(n) + + # require the same stride order for dense outputs, + # 1. user-land view() will not throw because inductor + # output different strides than eager + # long term the solution is to make view() always succeed + # with infallible strides. + # 2: as_strided ops, we need make sure its input has same size/stride with + # eager model to align with eager behavior. + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + is_output = any(user.op == "output" for user in n.users) + is_input_for_as_strided = any( + user.target in as_strided_ops for user in n.users + ) + if ( + is_output + and isinstance(result, TensorBox) + and isinstance(result.data, ir.BaseView) + ): + # Realize so that outputs are correctly aliased + result.realize() + + if (is_output or is_input_for_as_strided) and isinstance( + n.meta["val"], torch.Tensor + ): + strides = n.meta["val"].stride() + dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]) + # requiring a stride order for a non-dense output wouldn't + # recreate the same strides, and would fail with view, defer for now. + if dense and len(strides): + stride_order = ir.get_stride_order(strides) + if ( + len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and n.name not in self.user_visible_outputs + and not is_input_for_as_strided + ): + stride_order = ir.NHWC_STRIDE_ORDER + result = ir.ExternKernel.require_stride_order(result, stride_order) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. + num_users = len(set(n.users)) + if num_users > 1 and isinstance(result, TensorBox): + for user in n.users: + if user.target in needs_realized_inputs: + result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometimes result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. 
+ need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, ir.get_stride_order(n.meta["val"].stride()) + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. + if curr.has_large_inner_fn(): + result.realize() + + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. 
+ if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): + if isinstance(result.data.data, ir.Loops): + result.data.data.origin_node = n + elif isinstance(result.data.data, ir.Buffer): + result.data.data.origin_node = n + if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( + result.data.data.data, ir.Loops + ): + result.data.data.data.origin_node = n + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, ir.MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], ir.Buffer): + result.data.data.inputs[0].origin_node = n + + self.register_users_of(result) + + return result + + def validate_can_generate_cpp_wrapper(self): + if config.disable_cpp_codegen: + raise CppWrapperCodeGenError("C++ codegen is disabled") + + if sys.platform not in ["linux", "darwin"]: + raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}") + + for value in self.graph_inputs.values(): + dtype = None + if isinstance(value, TensorBox): + dtype = value.get_dtype() + elif isinstance( + value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ): + dtype = may_get_constant_buffer_dtype(value) + + if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): + raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") + + def init_wrapper_code(self): + self.cuda = "cuda" in self.device_types + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() + self.wrapper_code = CppWrapperCuda() if self.cuda else CppWrapperCpu() + else: + device_types = self.device_types.copy() + device_types.discard("cpu") + # TODO(Eikan): Only support mixing cpu and other device now. + assert len(device_types) <= 1, "Does not support mixing {}".format( + "+".join(device_types) + ) + only_cpu = len(device_types) == 0 + device_type = "cpu" if only_cpu else device_types.pop() + + self.device_ops = get_device_op_overrides(device_type) + wrapper_code_gen_cls = get_wrapper_codegen_for_device(device_type) + assert ( + wrapper_code_gen_cls is not None + ), f"Device {device_type} not supported" + self.wrapper_code = wrapper_code_gen_cls() + + if self.const_module: + # If we have const module, we could reuse the kernels + # This could avoid duplication and save time on doing recompilation (if Triton.) + self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter + self.wrapper_code.src_to_kernel = ( + self.const_module.wrapper_code.src_to_kernel + ) + + def codegen_with_cpp_wrapper(self): + """ + For CPU, the cpp wrapper codegen is done in one pass. + For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python + wrapper code and run it to generate autotuned kernel binaries in the first pass; and then + generate cpp wrapper code and compile it to a dynamic library in the second pass. 
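+
+        Roughly, the GPU path below amounts to:
+
+            self.cpp_wrapper = False
+            compiled = self.compile_to_module().call  # pass 1: python wrapper
+            compiled(real_inputs)                      # run once to autotune
+            self.cpp_wrapper = True
+            return self.codegen()                      # pass 2: cpp wrapper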
+ """ + if "cuda" in self.device_types: + # first pass + self.cpp_wrapper = False + compiled = self.compile_to_module().call + + def materialize(x): + if isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance( + x, torch.Tensor + ), "Unknown type when creating real inputs" + str(type(x)) + return x + + if tracing_context := torch._guards.TracingContext.try_get(): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + real_inputs = [materialize(x) for x in V.real_inputs] + + with torch.utils._python_dispatch._disable_current_modes(): + assert self.example_inputs is not None + compiled(real_inputs) + del real_inputs + + # second pass + # TODO: reuse self.scheduler from the first pass to speed up the second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.inplaced_to_remove.clear() + return self.codegen() + else: + # cpu + return self.codegen() + + def codegen(self): + from .scheduler import Scheduler + + self.init_wrapper_code() + + self.scheduler = Scheduler(self.buffers) + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.scheduler.codegen() + return self.wrapper_code.generate(self.is_inference) + + def codegen_subgraph(self, parent_graph): + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kerenls). The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + from .scheduler import Scheduler + + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + + self.scheduler = Scheduler(self.buffers) + self.scheduler.codegen() + + def count_bytes(self): + from .scheduler import Scheduler + + scheduler = Scheduler(self.buffers) + + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + return total_bytes, node_counts, node_runtimes + + @dynamo_timed(phase_name="code_gen") + def compile_to_module(self): + from .codecache import PyCodeCache + + code, linemap = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + linemap = [(line_no, node.stack_trace) for line_no, node in linemap] + key, path = PyCodeCache.write(code) + mod = PyCodeCache.load_by_key_path( + key, path, linemap=linemap, attrs=self.constants + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. 
Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.debug("Output code: \n%s", code) + trace_structured( + "inductor_output_code", + lambda: {"filename": mod.__file__}, + payload_fn=lambda: code, + ) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + return mod + + def compile_to_fn(self): + if self.aot_mode: + from .codecache import AotCodeCompiler + + assert self.cpp_wrapper, "AOT mode only supports C++ wrapper" + code, linemap = self.codegen_with_cpp_wrapper() + output_code_log.debug("Output code: \n%s", code) + + serialized_extern_kernel_nodes = None + if ( + config.is_fbcode() + and self.extern_kernel_nodes + and self.extern_node_serializer + ): + serialized_extern_kernel_nodes = self.extern_node_serializer( + self.extern_kernel_nodes + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + # Directly return the file path with the compiled code + return AotCodeCompiler.compile( + self, code, serialized_extern_kernel_nodes, cuda=self.cuda + ) + else: + return self.compile_to_module().call + + def get_output_names(self): + return [ + node.get_name() + for node in self.graph_outputs + if not isinstance(node, ir.NoneAsConstantBuffer) + and not isinstance(node, ir.ShapeAsConstantBuffer) + ] + + def is_unspec_arg(self, name: str): + # dynamo wraps unspec variable as 0d CPU tensor, + # need to convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs.keys() + and self.graph_inputs[name].get_numel() == 1 + and self.graph_inputs[name].get_device().type == "cpu" + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/hooks.py b/MLPY/Lib/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..761e5156553dc9aa71bb925972125bd8e0eda31c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,28 @@ +import contextlib +from typing import Callable, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + finally: + INTERMEDIATE_HOOKS = hooks diff --git a/MLPY/Lib/site-packages/torch/_inductor/index_propagation.py b/MLPY/Lib/site-packages/torch/_inductor/index_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..a5f60972fe422c58cb52a778e9cd0ac2c7cb49ed --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/index_propagation.py @@ -0,0 +1,277 @@ +"""This file implements the IndexPropagation ops handler, which wraps an +underlying handler to add a limited form of constant propagation, as well as +propagation of sympy expressions downstream of ops.index_expr calls. 
+ +For example, say we have the IR: + + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) + +The underlying handler would just see: + + ops.load("buf0", x * 2) + +This is limited by the set of operators handled in the sympy expression +printers. So simple operations like minimum and maximum cannot be translated to +SymPy expressions yet, despite sympy.Min and sympy.Max existing. + +""" +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union + +import sympy + +from typing_extensions import TypeAlias + +import torch +from torch._prims_common import is_boolean_dtype, is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where + + +@dataclass +class TypedExpr: + """A SymPy expression with associated type""" + + expr: sympy.Expr + dtype: torch.dtype + + +class SymPyOps: + """An ops handler where all IR values are SymPy expressions + + When a value cannot be represented as a SymPy expression, the method is + either not defined, or returns NotImplemented + + """ + + @staticmethod + def identity(value: Any) -> Any: + return value + + @staticmethod + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + if is_boolean_dtype(dtype): + expr = sympy.Integer(bool(value)) + elif is_integer_dtype(dtype): + expr = sympy.Integer(int(value)) + else: + expr = sympy.Float(float(value)) + return TypedExpr(expr, dtype) + + @staticmethod + def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]: + if isinstance(value, int): + value = sympy.Integer(value) + return TypedExpr(value, dtype) + + @staticmethod + def to_dtype( + value: Any, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None + ) -> Union[int, TypedExpr]: + if isinstance(value.expr, (sympy.Integer, sympy.Float)): + return SymPyOps.constant(value.expr, dtype) + elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype): + return SymPyOps.index_expr(value.expr, dtype) + else: + # TODO: Inductor doesn't handle floating point in sympy expressions well at the moment + return NotImplemented + + @staticmethod + def square(x: TypedExpr) -> TypedExpr: + return TypedExpr(x.expr * x.expr, x.dtype) + + @staticmethod + def add(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr + y.expr, result_type) + + @staticmethod + def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr - y.expr, result_type) + + @staticmethod + def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr * y.expr, result_type) + + @staticmethod + def neg(x: TypedExpr) -> TypedExpr: + return TypedExpr(-x.expr, x.dtype) + + @staticmethod + def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + return TypedExpr(FloorDiv(x.expr, y.expr), result_type) + + @staticmethod + def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + + 
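+    # Taken together, these handlers fold the example from the module
+    # docstring: index_expr(x, int32) -> TypedExpr(x, int32),
+    # constant(2, int32) -> TypedExpr(2, int32), and mul() combines them into
+    # TypedExpr(2*x, int32), so the wrapped handler ultimately just sees
+    # ops.load("buf0", x * 2).
+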
@staticmethod + def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + # In these cases, remainder in Python == remainder in C++, so this transformation + # is sound + if ( + x.expr.is_nonnegative is not None + and x.expr.is_nonnegative == y.expr.is_positive + ): + result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr) + return TypedExpr(result_expr, result_type) + return NotImplemented + + @staticmethod + def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Min(x.expr, y.expr), result_type) + + @staticmethod + def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Max(x.expr, y.expr), result_type) + + +@dataclass +class IndexPropVar: + value: Any # Either an IR value, or TypedExpr if is_symbolic is true + is_symbolic: bool = False + + @staticmethod + def new_symbolic(expr: TypedExpr) -> "IndexPropVar": + return IndexPropVar(expr, is_symbolic=True) + + def __post_init__(self): + assert not self.is_symbolic or isinstance( + self.value, TypedExpr + ), "Symbolic IndexPropVar must contain a TypedExpr" + + +IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]] + + +class IndexPropagation: + """Ops wrapper that tries to propagate constant and index_expr values through the computation. + + This aims to maximize the compile time simplification possible, and convert + indirect indexing from arange into normal static indexing. + + """ + + def __init__(self, inner: Any): + self._inner = inner + + def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any: + # Construct a new constant/index_expr from the SymPy expression + if isinstance(expr, sympy.Integer): + return self._inner.constant(int(expr), dtype) + elif expr.is_number: + return self._inner.constant(float(expr), dtype) + return self._inner.index_expr(expr, dtype) + + def unwrap(self, a: Union[Any, IndexPropVar]) -> Any: + if isinstance(a, (list, tuple)): + return tuple(self.unwrap(v) for v in a) + + if not isinstance(a, IndexPropVar): + return a + + # Prefer the sympy representation if possible + if a.is_symbolic: + return self.materialize_expr(a.value.expr, a.value.dtype) + + return a.value + + def wrap(self, a) -> IndexPropResult: + if isinstance(a, (list, tuple)): + return tuple(self.wrap(v) for v in a) + return IndexPropVar(a) + + @overload + def fallback( + self, + name: Literal["indirect_indexing"], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> IndexPropVar: + ... + + @overload + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + ... 
+ + def fallback( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Fallback to the wrapped handler + new_args = [self.unwrap(a) for a in args] + new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()} + return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) + + def propagate_sympy( + self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> IndexPropResult: + # Build a new SymPy expression from this ops call + def unwrap(a: Union[Any, IndexPropVar]) -> Any: + if not isinstance(a, IndexPropVar): + return a + return a.value + + new_args = [unwrap(a) for a in args] + new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} + new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs) + is_valid_expr = new_expr is not NotImplemented and ( + # Inductor doesn't expect floating point in sympy expressions, but + # allow floating point constants to be propagated + isinstance(new_expr.expr, sympy.Number) + or new_expr.expr.is_integer + ) + if not is_valid_expr: + return self.fallback(name, args, kwargs) + return IndexPropVar.new_symbolic(new_expr) + + def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: + def inner(*args: Any, **kwargs: Any) -> IndexPropResult: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) + + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) + + return self.propagate_sympy(name, args, kwargs) + + return inner + + def indirect_indexing( + self, index: Union[Any, IndexPropVar], size: Any, check: bool = True + ) -> Any: + # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE + # for SymPy expressions, so we don't want to repeat idx too much + + # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here + if isinstance(index, IndexPropVar) and index.is_symbolic: + # If we are turning a indirect indexing into direct, we need to wrap it. 
+ index = index.value.expr + return index + Where(index >= 0, 0, size) + return self.fallback("indirect_indexing", (index, size, check), {}).value diff --git a/MLPY/Lib/site-packages/torch/_inductor/inductor_prims.py b/MLPY/Lib/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..df99fb4b5ca4bdc132c8dde28b641e05b789c702 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import logging +from typing import Optional, Sequence + +import torch +from torch import _prims, Tensor + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index], + doc="Extract a single seed from the result of inductor_seeds()", +) +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +masked_scatter_with_index = make_prim( + "inductor_masked_scatter_with_index(Tensor input, Tensor mask, Tensor source_idx, Tensor source) -> Tensor", + lambda input_tensor, mask, index, source: torch.masked_scatter( + input_tensor, mask, source + ), + doc="masked_scatter with precomputed indices", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) 
self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) diff --git a/MLPY/Lib/site-packages/torch/_inductor/ir.py b/MLPY/Lib/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6d7a959f9c37ed1f4eba980203a81cc6378812 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/ir.py @@ -0,0 +1,8064 @@ +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import re +import textwrap +import traceback +from contextlib import nullcontext +from enum import Enum +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer + +import torch._export.serde.schema as export_schema + +import torch._logging + +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, + StrideType, +) +from torch._subclasses.fake_tensor import get_schema_info +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymTypes +from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .dependencies import ( + extract_free_unbacked_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .ops_handler import OpCounterCSE +from .utils import ( + argsort, + cache_on_self, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + get_kernel_metadata, + is_dynamic, + pad_listlike, + sympy_dot, + sympy_index_symbol, + sympy_product, + sympy_subs, +) +from .virtualized import ops, V + +if TYPE_CHECKING: + from .graph import GraphLowering + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. 
But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. +""" + + +def validate_ir(node_or_nodes): + def _check_tensorbox(nodes): + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + torch._inductor.ir.ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + ), + ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name): + assert isinstance(name, str) + + def fn(*args, **kwargs): + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order): + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index): + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order): + def reindex(index): + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing(reindex1, reindex2): + def reindex(index): + return reindex1(reindex2(index)) + + return reindex + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] + + +def stride_order2fill_order(order): + """ + Convert stride order to fill order + For channel last format, + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order(seq: Sequence[int]) -> List[int]: + """ + Convert strides to stride order + """ + sorted_idx: List[int] = argsort(seq) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +def ir_node_to_tensor(x, guard_shape=True): + if x is None: + return None + + shape_fn: Callable[[Expr], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] + else: + stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + stride = convert_shape_to_symint(stride) + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def 
may_convert_to_optional(value): + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {c10::nullopt} instead of {} + return [None] + return value + + +def get_device_type(x): + if getattr(x, "get_device", None): + return get_device_type(x.get_device()) + if isinstance(x, torch.device): + return x.type + return None + + +def is_triton(x): + return get_device_type(x) == "cuda" + + +def is_cpu(x): + return get_device_type(x) == "cpu" + + +class IRNode: + _current_origins: ClassVar[Set[Any]] = set() + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: Set[torch.fx.Node]): + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + def __post_init__(self): + self.origins = set(self._current_origins) + self.traceback = traceback.format_stack() if config.debug_ir_traceback else None + + def get_traceback(self): + return self.traceback + + def common_repr(self): + origins = f"origins={getattr(self, 'origins', '')}" + if len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." + return [origins] + + def str_helper(self, lines): + lines = lines + self.common_repr() + lines = indent(",\n".join(map(str, lines))) + return f"{type(self).__name__}(\n{lines}\n)" + + def is_user_of(self, name): + return name in self.get_read_names() + + @cache_on_self + def get_read_names(self): + return {dep.name for dep in self.get_reads()} + + def get_dtype(self): + return self.dtype + + def get_layout(self): + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def get_size(self): + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + def get_numel(self): + return sympy_product(self.get_size()) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def realize(self): + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. + """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer=None): + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions + # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of + # the code dynamically check for defined attributes. 
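+    # Concrete subclasses seen below (e.g. Loops, BaseView, ReinterpretView) supply the
+    # subset of these attributes/methods that they actually support at runtime.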
+ get_device: Callable[[], torch.device] + dtype: torch.dtype + get_name: Callable[[], str] + get_reads: Callable[[], Any] + get_stride: Callable[[], Any] + get_storage_numel: Callable[[], Any] + has_exceeded_max_reads: Callable[[], bool] + make_loader: Callable[[], Callable[[Any], Any]] + make_indexer: Callable[[], Callable[[Any], Any]] + mark_reuse: Callable[[int], None] + realize_hint: Callable[[], None] + get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]] + + +@dataclasses.dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: List[Expr] + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return set().union( + *(free_unbacked_symbols(e) for e in self.ranges), + self.inner_fn_free_unbacked_symbols(), + ) + + def __str__(self, names=("ranges",)): + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + __repr__ = __str__ + + def get_device(self): + return self.device + + def get_origin_node(self): + return self.origin_node + + def get_size(self): + return self.ranges + + def get_pointwise_size(self): + return self.ranges + + def is_extern(self): + return False + + @classmethod + def create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + r.origin_node = origin_node + r.traceback = ( + tb or traceback.format_stack() if config.debug_ir_traceback else None + ) + return TensorBox.create(r) + + @staticmethod + def _index(ranges, prefix="i"): + return [ + sympy.Integer(0) if s == 1 else sympy_index_symbol(f"{prefix}{n}") + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self): + from .ir import FlexibleLayout + + opcounter = OpCounterCSE(V.MockHandler()) + + with V.set_ops_handler(opcounter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + result = self.inner_fn(*self.inner_fn_args()) + return opcounter.op_count + + def inner_fn_args(self): + return (self._index(self.ranges),) + + def inner_fn_str(self): + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self): + return self.inner_fn_opcount() > config.realize_opcount_threshold + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + return extract_free_unbacked_symbols(self.inner_fn, index) + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_reduction_size(self): + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" + ) + + def get_reduction_type(self): + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device): + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" 
+ ) + + +def nop_loader_fn(idx, *, dtype): + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +class Pointwise(Loops): + def make_loader(self): + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def get_reduction_size(self): + return [] + + def get_reduction_type(self): + return None + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store(output_name, indexer(vars), loader(vars)) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.dtype, loader, self.ranges) + + +@dataclasses.dataclass +class Scatter(Pointwise): + output_indexer: Callable[[List[Expr]], Expr] + scatter_mode: Optional[str] = None + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device, + self.dtype, + loader, + self.ranges, + self.output_indexer, + self.scatter_mode, + ) + + def store_output(self, output_name, indexer, vars): + loader = self.make_loader() + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +REDUCTION_COMBINE_FN = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn(reduction_type, dtype): + if reduction_type in REDUCTION_COMBINE_FN: + combine_fn = REDUCTION_COMBINE_FN[reduction_type] + elif reduction_type in {"argmax", "argmin"}: + + def combine_fn(a, b): + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + + equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + mask = ops.logical_or( + mask, ops.logical_and(equal, ops.lt(a_index, b_index)) + ) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + elif reduction_type == "welford_combine": + + def combine_fn(a, b): + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + else: + raise NotImplementedError(f"unknown reduction_type={reduction_type}") + + return combine_fn + + +@dataclasses.dataclass +class Reduction(Loops): + reduction_ranges: List[Expr] + reduction_type: str + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self): + return Loops.__str__( # type: ignore[call-arg] + self, names=("ranges", "reduction_ranges", "reduction_type") + 
) + + def __repr__(self): + return self.__str__() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return super().get_unbacked_symbol_uses() | set().union( + *(free_unbacked_symbols(e) for e in self.reduction_ranges) + ) + + def get_reduction_size(self): + return self.reduction_ranges + + def get_reduction_type(self): + return self.reduction_type + + def store_reduction(self, output_name, indexer, vars, reduction_vars): + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + return ops.store_reduction(output_name, indexer(vars), value) + + def index_length(self): + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, "r") + return (index, rindex) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, "r") + return extract_free_unbacked_symbols(self.inner_fn, index, rindex) + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device, + self.dtype, + loader, + self.ranges, + self.reduction_ranges, + self.reduction_type, + self.src_dtype, + ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + def _is_static(x): + return isinstance(x, (int, sympy.Integer)) + + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = ( + is_triton(device) + and reduction_type + not in { + "argmax", + "argmin", + } + and config.split_reductions + # We don't support unbacked symints + and _is_static(reduction_numel_hint) + and _is_static(numel_hint) + ) + if not should_split: + return ReductionHint.DEFAULT, 1 + + device_interface = get_interface_for_device(get_device_type(device)) + num_sm = device_interface.Worker.get_device_properties( + device + ).multi_processor_count + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + + def inner_reduction_splits(reduction_numel_hint, numel_hint): + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + num_warps = 8 + num_threads = 32 * num_warps + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer 
even splits, but never smalle than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + + def outer_reduction_splits(reduction_numel_hint, numel_hint): + # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128 + # extend to even smaller number of outputs + num_warps = 8 + num_threads = num_warps * 32 + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if ( + len(ranges) == 0 + and input_node is not None + and isinstance(input_node, TensorBox) + ): + # Only handles the case where keep_dim = False. + # Otherwise, we need to propagate reduction dim info to the stage where + # the intermediate loader of the first Reduction is generated. + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly. 
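+                        # A split value of -1 is a sentinel: Reduction.create() routes it to
+                        # create_multilayer_existing_ranges(), which reuses the producer's ranges
+                        # instead of computing a fresh split.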
+ return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + ReductionHint.DEFAULT, + ) + + def get_read_indices(r): + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = buf.layout.stride + buf.decide_layout() + if buf.layout.stride != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + i = V.graph.sizevars.simplify_with_ranges(i, ranges) + strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) + outer = all(s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype): + """Convert inner_fn from a reduction to an pointwise""" + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index): + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + if reduction_type in ("argmin", "argmax"): + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() + + def value_fn(index, rindex): + rindex = [sympy.expand(i) for i in rindex] + return ( + inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ): + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + 
if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index): + return ops.constant(0, dst_dtype) + + else: + + def fn(index): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create(device, dst_dtype, fn, ranges) + + if ( + isinstance(reduction_numel, sympy.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ): + return Pointwise.create( + device, + dst_dtype, + cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node # type: ignore[arg-type] + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + return TensorBox.create( + Reduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @staticmethod + def default_accumulator(reduction_type, dtype): + if reduction_type in {"max", "argmax"}: + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return 0 + else: + return torch.iinfo(dtype).min + if reduction_type in {"min", "argmin"}: + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return 1 + else: + return torch.iinfo(dtype).max + + return { + "sum": 0, + "prod": 1, + "xor_sum": 0, + "any": 0, + "welford_reduce": (0, 0, 0), + "welford_combine": (0, 0, 0), + }[reduction_type] + + @staticmethod + def default_value(reduction_type, dtype): + if reduction_type == "welford_reduce": + return 0 + 
return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: int, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def _multilayer_wrap_loader( + cls, + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default, + ): + reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + def wrapper_fn(index, reduction_index): + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body(): + return loader(new_index, reindex([indices])) + + if need_mask: + mask = ops.lt( + ops.index_expr(indices, torch.int32), + ops.index_expr(reduction_numel, torch.int32), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ): + assert len(original_ranges) == 0, f"{original_ranges}= is not equal to []" + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn(index, reduction_index): + return loader([], reindex(tuple(index) + tuple(reduction_index))) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel. 
keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn(index, reduction_index): + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device, + dst_dtype, + intermediate_fn, + original_ranges, + new_ranges[len(original_ranges) :], + reduction_type, + src_dtype, + reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, reduction_ranges, reduction_numel, split, block_size, default + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: List[Expr], + original_reduction_ranges: List[Expr], + new_ranges: List[Expr], + new_reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + default, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +def num_reduction_outputs(reduction_type): + return 3 if "welford" in reduction_type else 1 + + +class WelfordReduction(Reduction): + output_index: int + + def __init__( + self, + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_index, + ): + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader(idx, reduction_idx): + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device, + dtype, + loader, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + ) + self.output_index = output_index + + def 
store_reduction(self, output_name, indexer, vars, reduction_vars): + values = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + value = values[self.output_index] + return ops.store_reduction(output_name, indexer(vars), value) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ): + assert reduction_type in {"welford_reduce", "welford_combine"} + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val): + def inner_fn(idx): + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy(loader): + def inner_fn(idx): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, sympy.Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) != 1 + # ): + # return Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value(reduction_type, dtype): + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + split: int, + reduction_hint: ReductionHint, + ): + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.is_expr_static_and_true( + sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] + ) + + if need_mask and reduction_type != "welford_combine": + # If we need mask, then 
"welford_reduce" doesn't work because + # masked inputs shouldn't count towards the welford weight + + def constant(idx, reduction_idx, value): + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], # type: ignore[list-item] + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + i_loaders = [i.make_loader() for i in intermediates] + + def intermediate_loader_fn(index, reduction_index, loader): + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], # type: ignore[list-item] + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + reduction_hint, + ) + + +@dataclasses.dataclass +class Scan(Loops): + scan_ranges: List[Expr] + size: List[Expr] + combine_fn: Callable[..., Any] + reindex: Callable[[List[Expr], List[Expr]], List[Expr]] + reduction_hint: ReductionHint + init: int + + # HACK we mimick reduction + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? 
If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_unbacked_symbol_uses() + | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) + | set().union(*(free_unbacked_symbols(e) for e in self.size)) + ) + + def __post_init__(self): + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction(self, output_name, indexer, vars, scan_vars): + idx = self.reindex(vars, scan_vars) + value = self.inner_fn(idx) + result = ops.scan(self.dtype, self.combine_fn, value, self.init) + return ops.store(output_name, indexer(idx), result) + + def get_reduction_type(self): + # return self.scan_op + return "custom" + + def get_reduction_size(self): + return self.scan_ranges + + def get_size(self): + return self.size + + def get_pointwise_size(self): + return self.ranges + + def index_length(self): + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, "r") + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_unbacked_symbols(self): + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, "r") + idx = self.reindex(index, rindex) + return extract_free_unbacked_symbols(self.inner_fn, idx) + + @classmethod + def create( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + size: List[Expr], + axis: int, + combine_fn: Callable[..., Any], + init: Any, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ) -> Optional["TensorBox"]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if device.type != "cuda": + # TODO: CPU support + return None + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + # Scan with a single element is just a copy + if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=size, + ) + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtype, + inner_fn=inner_fn, + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan if num_splits <= 1 else SplitScan + + if num_splits > 1 and torch.version.hip is not None: + # Fallback for split-scan on ROCm + return None + + def reindex(index, scan_index): + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + result = TensorBox.create( + scan_type( + device=device, + dtype=dtype, + inner_fn=inner_fn, + size=size, + ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + init=init, + reduction_hint=reduction_hint, + ) + ) + result.realize() + return result + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[List[Expr]], Any], + axis: int, + pointwise_ranges: List[Expr], + scan_ranges: List[Expr], + combine_fn: Callable[..., Any], + scan_numel: Expr, + ): + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx, reduction_idx): + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + 
reduction_ranges=scan_ranges, + reduction_type="sum", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codgen on CUDA. +@dataclasses.dataclass +class SplitScan(Scan): + pass + + +def is_storage_and_layout(x): + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x): + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None): + """Try to simplify x into a StorageBox and a Layout""" + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + ) + if isinstance(x, StorageBox) and isinstance(x.data, Buffer): + if freeze: + if want_contiguous: + x.data.freeze_layout() + assert x.data.layout.is_contiguous() + elif stride_order is not None: + x.data.freeze_layout_with_stride_order(stride_order) + else: + x.data.decide_layout() + return x, x.data.layout + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +as_contiguous_storage_and_layout = functools.partial( + as_storage_and_layout, want_contiguous=True +) + + +def is_stride_order_storage_and_layout(x, stride_order): + try: + buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +@dataclasses.dataclass +class BaseView(IRNode): + data: IRNode + + def get_unbacked_symbol_uses(self): + return self.data.get_unbacked_symbol_uses() + + def make_reindexer(self): + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self): + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx): + return inner(reindex(idx)) + + return indexer + + def make_loader(self): + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx): + return inner(reindex(idx)) + + return loader + + @property + def dtype(self): + return self.data.dtype + + def get_layout(self): + return self.data.get_layout() + + def get_device(self): + return self.data.get_device() + + def get_origin_node(self): + return None + + def get_name(self): + return self.data.get_name() + + def get_pointwise_size(self): + return self.get_size() + + def mark_reuse(self, users): + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self): + return self.data.has_exceeded_max_reads() + + def realize(self): + return self.data.realize() + + def realize_hint(self): + return self.data.realize_hint() + + def get_storage_numel(self): + return self.data.get_storage_numel() + + def is_extern(self): + return self.data.is_extern() # type: ignore[attr-defined] + + def get_reads(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def unwrap_view(self): + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device): + """Move this to a given device. 
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise(device, self.get_dtype(), loader, self.get_size()) + + +@dataclasses.dataclass +class ExpandView(BaseView): + size: List[Expr] + + @staticmethod + def _normalize_size(x, new_size): + """Replace `-1` with correct sizes""" + new_size = list(map(sympy.expand, new_size)) + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or old_size[i] == 1: + pass + else: + # Expect broadcast compatibility + new_size[i] = V.graph.sizevars.expect_equals( + new_size[i], + old_size[i], + msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}", + ) + return new_size + + @classmethod + def create(cls, x, new_size): + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append(stride if size != 1 else sympy.Integer(0)) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return ExpandView(x, new_size) + + def get_size(self): + return self.size + + def make_reindexer(self): + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex(index): + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.Integer(0) + return index + + return reindex + + +@dataclasses.dataclass +class PermuteView(BaseView): + dims: List[Expr] + + @classmethod + def create(cls, x, dims): + dims = cls._map_neg_dims(dims) + assert set(dims) == set(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + return PermuteView(x, dims) + + @classmethod + def _map_neg_dims(cls, dims): + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self): + assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims))) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer(self): + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] + assert set(inv) == set(range(len(self.dims))) + + def reindex(index): + return [index[i] for i in inv] + + return reindex + + +class SqueezeView(BaseView): + @classmethod + def create(cls, x, *, dim=None): + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert isinstance(dim, int), "expected integer dim argument" + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + if size != 1: + new_size.append(size) + 
new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + if dim is None: + # redirect to a generic view + return View.create(x, [s for s in x.get_size() if s != 1]) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) + + @staticmethod + def squeezer(size: Tuple[sympy.Expr, ...]): + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index = [sympy.Integer(0)] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data): + raise AssertionError("use SqueezeView.create()") + + +@dataclasses.dataclass +class GenericView(BaseView): + size: List[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): + return self.reindex + + def reindex_str(self): + index_old = [sympy_index_symbol(f"i{n}") for n in range(len(self.size))] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self): + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create(cls, x, new_size, reindex): + return cls(x, list(new_size), reindex) + + def get_size(self): + return self.size + + +@dataclasses.dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx, size): + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x, new_size): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + return cls(x, list(new_size), fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ExternKernel.realize_input(x) + + storage, old_layout = as_contiguous_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + return ReinterpretView(storage, new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return cls(x, list(new_size), reindex) + + @staticmethod + def resolve_negative_size(old_size, new_size): + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + 
old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.Integer(1) + new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) + break + + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer(cls, old_size, new_size): + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size) + except (AssertionError, IndexError): + # optimistic algorithm failed, lets do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer(old_size, new_size): + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + vars = [sympy_index_symbol(f"view{i}") for i in range(len(new_size))] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.Integer(0)) + stack_new.append((var, size_new)) # re-add + elif size_new == 1: + stack_old.append(size_old) # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.Integer(1) + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.guard_equals(size_new, size_old) + else: + raise AssertionError() + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] + view_expr.append(sympy.Integer(0)) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] + + view_expr.reverse() + assert len(view_expr) == len(old_size) + + def reindex(index): + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] + + return reindex + + +@dataclasses.dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: "Layout" + + def __post_init__(self): + super().__post_init__() + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + + def __str__(self): + return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self): + return self.data.get_name() + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return None + + @property + def dtype(self): + return self.layout.dtype + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + 
return list(self.layout.stride) + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.get_name(), indexer(index)) + + return loader + + def make_indexer(self): + return self.layout.make_indexer() + + def get_layout(self): + return self.layout + + def freeze_layout(self): + pass + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return ( + free_unbacked_symbols(self.layout.size) + | free_unbacked_symbols(self.layout.stride) + | free_unbacked_symbols(self.layout.offset) + ) + + def codegen_reference(self, writer=None): + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer, + ) + + +class SliceView(View): + @classmethod + def normalize_start_end(cls, x, dim, start, end): + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + + def clamp(x, lower, upper): + return sympy.Min(sympy.Max(x, lower), upper) + + else: + + def clamp(x, lower, upper): + return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) + + def clamp_wrap(val, lower, upper, default): + if val is None: + return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create(cls, x, dim, start, end, step=1): + step = sympy.expand(step) + assert step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + return ReinterpretView(storage, new_layout) + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(x, size=new_size, reindex=reindex) + + +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self): + return () + + def get_device(self): + return self.device + + def get_origin_node(self): + return None + + def mark_reuse(self, users): + pass + + def has_exceeded_max_reads(self): + return False + + def get_reads(self): + return () + + def is_extern(self): + return False + + +@dataclasses.dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self): + pass + + def constant_to_device(self, device): + return Constant(self.value, self.dtype, device) + + +@dataclasses.dataclass +class 
IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self): + def loader(index): + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device): + return IndexingConstant(self.index, self.dtype, device) + + +def is_contiguous_strides_for_shape(stride, shape): + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) + + +@dataclasses.dataclass +class Layout(IRNode): + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: Optional[Sequence[Union[Expr, int]]], + offset: Expr = Integer(0), + ): + assert stride is None or len(size) == len( + stride + ), f"size={size}, stride={stride}" + self.device = device + self.dtype = dtype + assert all(isinstance(s, (Expr, int)) for s in size) + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride + + def __str__(self): + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + return ( + f"{type(self).__name__}('{self.device.type}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" + ) + + __repr__ = __str__ + + def is_contiguous(self): + return is_contiguous_strides_for_shape(self.stride, self.size) + + def is_channels_last_contiguous(self): + ndim = len(self.size) + if ndim not in [4, 5]: + return False + for left, right, size in zip( + self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type] + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self): + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(self.size)), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order): + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they dont affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i]) + # check if it is in ascending order + for i in range(len(order) - 1): + if stride_ordered[i] > stride_ordered[i + 1]: + return False + return True + + def is_channels_last_stride_ordered(self): + # create channels_last order(NCHW, NCDHW, the C is the first order). 
+ order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + def as_fixed(self): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + ) + + def make_indexer(self): + assert ( + FlexibleLayout.allow_indexing + ), f"convert {type(self).__name__} to FixedLayout first" + return self.as_fixed().make_indexer() + + def __eq__(self, other) -> bool: + return ( + self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Union[List[Expr], List[int]], + stride: Optional[Sequence[Union[Expr, int]]] = None, + offset: Union[Expr, int] = Integer(0), + ): + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + super().__init__( + device, + dtype, + size, # type: ignore[arg-type] + stride, + offset, # type: ignore[arg-type] + ) + + def make_indexer(self): + """A closure containing math to read a given element""" + + def indexer(index): + assert len(index) == len(self.stride) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" + + allow_indexing = False + + @staticmethod + def contiguous_strides(sizes): + if len(sizes) == 0: + return [] + reversed_strides = [sympy.Integer(1)] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes, order): + """ + Create a stride based on the order the dimensions should be filled in. + + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert set(range(len(sizes))) == set(order) + next_stride = sympy.Integer(1) + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes, order): + """ + Create a stride based on the sorted order of a permuted range. 
+ + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert set(range(len(sizes))) == set(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def same_ordered(sizes, stride): + """ + Create a stride that has the same stride order as given stride + + For example, if given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + def as_stride_order(self, order): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride_ordered(self.size, order), + self.offset, + ) + + def as_fill_order(self, order): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.fill_ordered(self.size, order), + self.offset, + ) + + def as_same_order(self, stride): + return FixedLayout( + self.device, + self.dtype, + self.size, + self.same_ordered(self.size, stride), + self.offset, + ) + + def __init__(self, device, dtype, size, stride_order=None): + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides) + + +class AliasedLayout(Layout): + """Shares the same storage as another tensor""" + + def __init__(self, view: Union[BaseView, "TensorBox"]): + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self): + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self): + offset = self.view.get_layout().offset + if offset == 0: + return True + from .compile_fx import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] + + +class NoneLayout(IRNode): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. 
+ # + # If you have an ir.Node with NoneLayout, you probably need to setup + # dependencies manually in scheduler + + def __init__(self, device): + self.device = device + self.size = [0] + self.stride = [0] + + def storage_size(self): + return 0 + + def as_fixed(self): + return self + + +class MutationLayout(Layout): + def __init__(self, target: IRNode): + super().__init__( + target.get_device(), + target.get_dtype(), + target.get_size(), + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @Layout.stride.getter # type: ignore[attr-defined] + def stride(self): + return self.real_layout().stride + + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> "Buffer": + def unwrap_views(target): + if isinstance(target, MutationLayout): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance(result, Buffer), "MutationLayout must refer to a buffer" + return result + + def real_layout(self): + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayout(dst) + return src.data + + def as_fixed(self): + return self + + def make_indexer(self): + return self.target.make_indexer() + + +@dataclasses.dataclass +class Buffer(IRNode): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: Layout + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this! 
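+ #
+ # Editorial note (illustrative, not part of upstream torch): a Buffer's
+ # make_loader() simply composes its layout's indexer with ops.load, so a
+ # hypothetical FixedLayout with size=[2, 3], stride=[3, 1], offset=0 maps
+ # (i, j) -> 3*i + j and the loader emits ops.load(name, 3*i + j). The
+ # snippet below is an assumed usage sketch for illustration only:
+ #
+ #     layout = FixedLayout(torch.device("cpu"), torch.float32, [2, 3])
+ #     indexer = layout.make_indexer()
+ #     indexer([1, 2])  # -> 5  (offset 0 + 1*3 + 2*1)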
+ + def __post_init__(self): + super().__post_init__() + self.origin_node = None + + def make_indexer(self): + return self.layout.make_indexer() + + def get_name(self) -> str: + assert self.name + return self.name + + def get_device(self): + return self.layout.device + + def get_origin_node(self): + return self.origin_node + + @property + def dtype(self): + return getattr(self.layout, "dtype", None) + + def get_size(self): + return list(self.layout.size) + + def get_stride(self): + return list(self.layout.stride) + + def get_offset(self): + return self.layout.offset + + def get_layout(self): + return self.layout + + def get_storage_numel(self): + return self.get_numel() + + def is_extern(self): + return False + + def freeze_layout(self): + if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order) + + def freeze_layout_with_fill_order(self, order): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride): + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def is_zero_elements(self): + return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] + + def make_loader(self): + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index): + indexer = self.layout.make_indexer() + return ops.load(self.name, indexer(index)) + + return loader + + def is_no_op(self): + return False + + def codegen_reference(self, writer=None): + return self.get_name() + + def decide_layout(self): + pass + + def get_alias_names(self): + if isinstance(self.layout, AliasedLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self): + if isinstance(self.layout, MutationLayout): + return [self.layout.target.get_name()] + return () + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ) + + def get_reads(self): + return self.get_read_writes().reads + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + """ + Returns the unbacked symbols which are defined by this IR node, + because this is a data-dependent IR node, or item() + """ + # So this is a little unusual. In principle, you could imagine + # defining a MultiOutputLayout buffer so that it DOES define + # unbacked symints. However, we can't easily tell what symints + # such a buffer defines, because MultiOutputLayout doesn't actually + # define any useful information about what it returns. + # + # An easier and better approach is to delay the symint allocation + # to the MultiOutput IR nodes, which are when we actually extract + # out the buffers and know what their sizes are. + # + # There are two subleties here: + # + # 1. Suppose you have a kernel that produces out1: (i0,), out2: (i0,) + # Both of these actually count as defs! The scheduler will just + # arbitrarily pick one of these as the canonical definer and + # ensure it stays live. It's not a big deal if we pick the + # wrong one because tuple accesses are cheap, and all this means + # is we accidentally keep a MultiOutput node live when it wasn't + # strictly necessary. + # + # 2. 
Suppose you have a MultiOutput buffer whose size is (i0,), but + # the MultiOutputLayout buffer it is projecting from isn't actually + # dynamic; it has i0 as one of the arguments. We cannot tell this + # directly from MultiOutput, we have to look at the input buffer's + # uses to work this out. No big deal. + if isinstance(self.layout, (NoneLayout, MultiOutputLayout)): + return set() + + # This kernel defines all unbacked symbols... that it didn't get in as + # arguments! + defs = ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + ) + return defs - self.get_unbacked_symbol_uses() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + """ + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + """ + return set() + + def codegen_unbacked_symbol_defs(self, wrapper): + # NB: If it is possible for other ir node types to return unbacked + # symints, you need to make sure their codegen calls this method. + # Don't forget to update get_unbacked_symbol_defs too. + symbols_to_define = self.get_unbacked_symbol_defs() + for i, s in enumerate(self.get_size()): + if s in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}" + ) + symbols_to_define.remove(s) + for i, s in enumerate(self.get_stride()): + if s in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}" + ) + symbols_to_define.remove(s) + if (s := self.get_offset()) in symbols_to_define: + wrapper.writeline( + f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.storage_offset(){wrapper.ending}" + ) + symbols_to_define.remove(s) + assert ( + not symbols_to_define + ), f"unbacked symint {s} not written out, check comment above" + + def realize(self): + pass + + def get_workspace_size(self): + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. + """ + return 0 + + def should_allocate(self): + # Returns False by default. 
+ return False + + +class InputBuffer(Buffer): + pass + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self): + def loader(index): + indexer = self.layout.make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device): + return ConstantBuffer( + V.graph.constant_name(self.get_name(), device), self.layout + ) + + +class NoneAsConstantBuffer(IRNode): + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return set() + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.none_str + + +class ShapeAsConstantBuffer(IRNode): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return free_unbacked_symbols(self.shape) + + def codegen_reference(self, writer=None): + return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) + + +@dataclasses.dataclass +class ComputedBuffer(Buffer): + data: Loops + + def get_computed_buffer_name(self): + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exist, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + @cache_on_self + def num_reads(self): + return len(self.get_read_writes().reads) + + def get_read_writes(self): + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), + self.data.get_reduction_size(), + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), + ) + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # there way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first. 
+ return ( + free_unbacked_symbols(self.get_size()) + | free_unbacked_symbols(self.get_stride()) + | free_unbacked_symbols(self.get_offset()) + | self.data.get_unbacked_symbol_uses() + ) + + def make_loader(self): + # Inline constants and index_expressions + if ( + hasattr(self.data, "make_loader") + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + ): + # can be inlined + return self.data.make_loader() + return super().make_loader() + + def get_store_function(self): + indexer = self.layout.as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self): + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. + """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + reads_bufs = [ + V.graph.name_to_buffer[r.name] + if r.name in V.graph.name_to_buffer.keys() + else None + for r in reads + ] + # only consider reads to buffer of same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs( + r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} + ) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, Scan): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + def get_default_sizes_body(self): + args, var_ranges = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + ) + index_vars = [] + reduce_vars: List[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + """ + This is a main place where we do loop transformations in a + backend-agnostic way. 
+ + Here we: + 1) Remove any 1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node compatible with other nodes. + """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + reads_bufs = [ + V.graph.name_to_buffer[reads_name] + if reads_name in V.graph.name_to_buffer.keys() + else None + for reads_name in body.reads_name2expr.keys() + ] + memory_addrs = [ + *body.reads_name2expr.values(), + *body.writes_name2expr.values(), + ] + + # the reordering_reindex in reads' simplify_reorder_and_tile + reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) + for i, reads_buf in enumerate(reads_bufs): + if isinstance(reads_buf, ComputedBuffer) and hasattr( + reads_buf, "iter_reordering_reindex" + ): + reordering_reindex[i] = reads_buf.iter_reordering_reindex # type: ignore[has-type] + + def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): + sizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs, reordering_reindex + ) + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + x_vars = prune(x_vars) + # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas) + # x_vars = prune(x_vars) + # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs) + reindex = fuse_reindexing(reindex1, reindex2) + return sizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( + index_vars, support_vars, index_size, reordering_reindex + ) + reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size + ) + + # remember the reordering if not have loop collapse. 
+ if len(iter_ranges) == len(index_vars): + self.iter_reordering_reindex = iter_reordering_reindex + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, reduce_ranges, prefix="z" + ) + body = LoopBody( + body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( + index_vars, + support_vars, + sizes, + memory_addrs, + reordering_reindex=None, + priority_idx=None, + ): + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + # consider both layout(strides) and reordering(reordering_reindex) + if reordering_reindex is not None: + for i in range(len(memory_addrs)): + try: + strides[i] = reordering_reindex[i](strides[i]) + # if len(order) != len(strides), do not reorder + except AssertionError: + pass + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_reduction_size(self): + return self.data.get_reduction_size() + + def get_reduction_type(self): + return self.data.get_reduction_type() + + def is_no_op(self): + return self.data.is_zero_elements() + + def should_allocate(self): + return True + + def constant_to_device(self, device): + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(Buffer): + """ + Represents a Triton (in the future other type) of template operator + that we can fuse an epilogue onto. 
+ """ + + def __init__(self, layout, inputs, make_kernel_render): + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + + def get_read_writes(self): + return self.normalized_read_writes() + + def normalized_read_writes(self): + name = self.get_name() + indexer = self.layout.make_indexer() + + def dummy(index, rindex): + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=True + ) + deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs} + return deps + + def get_reduction_size(self): + return 1 + + def get_reduction_type(self): + return None + + def is_no_op(self): + return False + + def should_allocate(self): + return True + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + return ( + ( + self.get_size(), + (), + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + pass + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: "CUDATemplate", # type: ignore[name-defined] # noqa: F821 + ): + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 + + +@dataclasses.dataclass +class InputsKernel(Buffer): + inputs: List[Buffer] + + def get_read_writes_input(self, x): + return dependencies.StarDep(x.get_name()) + + def get_read_writes(self): + star_dep = [] + for input in self.inputs: + if isinstance(input, list): + star_dep.extend([self.get_read_writes_input(x) for x in input]) + else: + star_dep.append(self.get_read_writes_input(input)) + + return dependencies.ReadWrites( + set(star_dep), + {dependencies.StarDep(self.get_name())}, + set(), + [], + None, + op_counts=collections.Counter(), + ) + + @classmethod + def unwrap_storage_for_input(cls, x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): + inputs_new = [] + for x in inputs: + if isinstance(x, list): + x = [InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self): + return True + + +class NopKernel(InputsKernel): + def is_no_op(self): + return True + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. 
+ """ + + @classmethod + def create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if ( + isinstance(layout, FixedLayout) + and layout.is_channels_last_contiguous() + ): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + buffer_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and inputs[i].get_device().type == "cuda" + and not is_dynamic(input_buffer) + ): + buffer_names.append(input_buffer.get_name()) + + if len(buffer_names) > 1: + V.graph.register_list(buffer_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + + return kernel + + @classmethod + def can_realize_into_without_copy(cls, src): + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data) + + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. 
+ if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(storage, layout) + assert isinstance(dst, ReinterpretView), dst + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = AliasedLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self): + return True + + +@dataclasses.dataclass +class ExternKernel(InputsKernel): + constant_args: Tuple[Any, ...] = () + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. + ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[List[Dict[str, Any]]] = None + kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None + + def __init__( + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + name, + layout, + inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.python_kernel_name = python_kernel_name + self.cpp_kernel_name = cpp_kernel_name + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.op_overload = op_overload + self.collect_arg_kwarg_properties() + + def collect_arg_kwarg_properties(self): + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g. 
type and default value, to help with the cpp wrapper codegen + if ( + isinstance(self.op_overload, torch._ops.OpOverload) + and not self.ordered_kwargs_for_cpp_kernel + ): + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.kwarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + if x.kwarg_only + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + + def decide_layout(self): + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment(self, wrapper): + origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.writeline(origin_str) + + def codegen(self, wrapper): + raise NotImplementedError() + + def get_kernel_name(self): + return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name + + @staticmethod + def copy_input(x): + pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel(cls, kernel, *args, **kwargs): + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + tensor_args = [] + non_tensor_args: List[Any] = [] + for arg in args_flat: + is_arg_tensor.append(isinstance(arg, IRNode)) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # We don't have generic shape formulas, so just burn in the + # shapes and run an example input. 
+ # TODO(jansel): replace this with dynamic shape formulas + example_args = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + if x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, tuple)) + else example_output + ) + for t in example_out_li: + if isinstance(t, torch.Tensor) and t.is_sparse: + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + # TODO: Unconditionally do this, not just when example_output has + # unbacked symbols + if maybe_free_unbacked_symbols(example_output): + example_output = V.graph.current_node.meta["val"] + + return example_output, tensor_args, non_tensor_args, unflatten_args + + @classmethod + def convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, BaseView) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x.unwrap_view().freeze_layout() + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError() + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + + @classmethod + def realize_input(cls, x): + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(x) + if isinstance(x, Constant): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return ReinterpretView(cls.realize_input(x.data), x.get_layout()) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x): + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return 
cls.copy_input(x) + + @classmethod + def require_stride_order(cls, x, order): + if x.get_numel() == 0: # Layout doesn't matter + return x + + # require x to have the layout as strided_ordered as order + if is_storage_and_layout(x): + while isinstance(x.get_layout(), AliasedLayout): + x = x.get_layout().view + if isinstance(x.get_layout(), FlexibleLayout): + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, freeze=True, want_contiguous=False, stride_order=order + ) + return x + elif isinstance( + x.get_layout(), FixedLayout + ) and x.get_layout().is_stride_ordered(order): + return x + elif isinstance(x.get_layout(), MutationLayout): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayout's real layout shouldn't be FlexibleLayout" + ) + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + return cls.require_stride_order(x, order) + except NotImplementedError: + pass + x = cls.copy_input(x) + as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) + assert is_stride_order_storage_and_layout(x, order) + return x + + @classmethod + def require_channels_last(cls, x): + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): + return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + + def apply_constraint(self): + pass + + def codegen_const_args(self): + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): + args = [] + for i, x in enumerate(self.inputs): + if isinstance(x, list): + names = [i.codegen_reference() for i in x] + codegen_reference = f'[{", ".join(names)}]' + args.append(codegen_reference) + else: + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len( + self.arg_properties + ), "Invalid arg_properties accessing" + type_ = self.arg_properties[i].get("type") + args.append( + V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + type_, x, self.is_legacy_abi_kernel() + ) + ) + else: + args.append(x.codegen_reference()) + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name): + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if self.kwarg_properties and self.kwarg_properties.get(arg_name): + return self.kwarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + else: + raise AssertionError(f"{arg_name} not in self.kwarg_properties") + + def is_legacy_abi_kernel(self): + return False + + def codegen_kwargs(self): + if V.graph.cpp_wrapper: + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + v = self.get_kwargs_value(arg_name) + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.kwarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.kwarg_properties and arg_name in self.kwarg_properties + else None + ) + kwargs.append( + V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] + type_, v, 
self.is_legacy_abi_kernel() + ) + ) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] + for k, v in self.kwargs.items() + ] + return kwargs + + def codegen_size_asserts(self, wrapper): + if config.size_asserts and not V.graph.cpp_wrapper: + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride})" + ) + + def get_group_stride(self): + """ + get output sizes and strides, for template_codegen + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self): + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] + return index, tuple(new_sizes) + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + r = set() + for arg in self.constant_args: + r |= maybe_free_unbacked_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_free_unbacked_symbols(arg) + return r + + def __str__(self): + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@dataclasses.dataclass +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + wrapper.generate_extern_kernel_out( + self.output_view, + self.codegen_reference(), + args, + self.get_kernel_name(), + ) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device): + limits = torch.iinfo(torch.int64) 
+ super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + cpp_kernel_name="at::randint_out", + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper): + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def __init__( + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ): + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return False + + def apply_constraint(self): + raise NotImplementedError + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_configs(self): + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + if isinstance(kernel, Autotuner): + configs = kernel.configs + kernel = kernel.fn + return kernel, configs + + def codegen(self, wrapper): + kernel, configs = self.get_kernel_and_configs() + + # Definition of kernel + new_name, triton_meta = wrapper.define_user_defined_triton_kernel( + kernel, configs, self.kwargs + ) + + args = self.codegen_kwargs() + if V.graph.cpp_wrapper: + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs] + + # Call to kernel + self.codegen_comment(wrapper) + wrapper.generate_user_defined_triton_kernel( + new_name, + self.grid, + configs, + args, + triton_meta, + ) + + def should_allocate(self): + return False + + def has_side_effects(self): + # UserDefinedTritonKernel does not return anything, but rather + # modifies input in place, do not let it get DCEd + return True + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def get_mutation_names(self): + return [] + + def __init__(self, *, kernel_idx, grid, kernel_args): + inputs = [] + kwargs = dict() + constant_args = [] + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + device = inputs[0].get_device() + + super().__init__( + None, + NoneLayout(device), # type: ignore[arg-type] + inputs, + tuple(constant_args), + kwargs, + ) + self.name = V.graph.register_buffer(self) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, _ = self.get_kernel_and_configs() + # If we are autotuning, not all arguments will be passed + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + mark_node_as_mutating( + self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)] + ) + + def get_alias_names(self): + return [i.get_name() for i in self.inputs] + + +def 
mark_node_as_mutating(cur_buffer, *mutated_ops): + """ + Allows ops in mutated_ops to be marked as being mutated as well as + indicates to the scheduler that these ops depend on cur_buffer. + """ + for op in mutated_ops: + assert isinstance(op, IRNode), op + V.graph.mark_buffer_mutated(op.get_name()) + assert hasattr(op, "layout") + MutationOutput(op.layout, op, cur_buffer) + + +class MutationOutput(ExternKernel): + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def __init__(self, layout, input, parent): + super().__init__(None, layout, [input, parent], ()) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return False + + def is_no_op(self): + return True + + def has_side_effects(self): + return True + + def get_alias_names(self): + return [self.inputs[0].get_name()] + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (x,) = (t.codegen_reference() for t in self.inputs) + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__(self, x, *constant_args): + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage([x]), + constant_args, + ) + self.name = V.graph.register_buffer(self) + self.python_kernel_name = "aten.bernoulli_" + self.cpp_kernel_name = ( + "aoti_torch_bernoulli_" + if config.abi_compatible + else "at::native::bernoulli_" + ) + mark_node_as_mutating(self, x) + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + (dst, src, non_blocking) = self.codegen_args() + wrapper.writeline( + f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__( + self, + layout, + inputs, + constant_args, + ): + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name=( + "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" + ), + ) + self.name = V.graph.register_buffer(self) + + @classmethod + def create(cls, dst, src, non_blocking: bool = False): + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(dst.get_device()), # type: ignore[arg-type] + inputs, + constant_args, + ) + mark_node_as_mutating(result, dst) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper): + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def has_side_effects(self): + return True + + 
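+# Editorial note (illustrative sketch, not part of upstream torch): the
+# mutation-style fallbacks below all follow the same recipe -- NoneLayout so
+# no output buffer is allocated, register_buffer() for a name, and
+# mark_node_as_mutating() so the scheduler orders the op after the existing
+# users of the mutated input. A minimal hypothetical subclass of
+# MutatingFirstArgExternKernel (the class and kernel name here are only an
+# example, not something defined elsewhere in this file) would look like:
+#
+#     class ZeroFillFallback(MutatingFirstArgExternKernel):
+#         def __init__(self, x):
+#             super().__init__(
+#                 None,
+#                 NoneLayout(x.get_device()),
+#                 self.unwrap_storage([x]),
+#             )
+#             self.name = V.graph.register_buffer(self)
+#             self.python_kernel_name = "aten.zero_"
+#             mark_node_as_mutating(self, x)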
+class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable, new_size): + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(variable.get_device()), # type: ignore[arg-type] + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + V.graph.never_reuse_buffers.add(variable.data.get_name()) + mark_node_as_mutating(self, variable) + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. + It also handle the case `src` being a scalar properly. + """ + + def codegen(self, wrapper): + reduce = self.kwargs["reduce"] + if V.graph.cpp_wrapper: + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + if self.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in self.inputs) + else: + (x, index) = (t.codegen_reference() for t in self.inputs) + src = self.constant_args[1] + wrapper.generate_scatter_fallback( + x, + [x, self.constant_args[0], index, src], + self.get_kernel_name(), + self.python_kernel_name, + self.src_is_tensor, + reduce, + self.codegen_kwargs(), + ) + + def should_allocate(self): + return False + + def get_cpp_kernel(self): + reduce = self.kwargs["reduce"] + if self.python_kernel_name == "aten.scatter_": + if self.src_is_tensor: + kernel = ( + "at::scatter_out" if reduce is None else "at::scatter_reduce_out" + ) + else: + assert ( + reduce is None + ), "Expect reduce to be None for aten.scatter_ with scalar src" + kernel = "at::scatter_out" + else: + assert ( + reduce is not None + ), "Expect reduce to be not None for aten.scatter_reduce_" + kernel = "at::scatter_reduce_out" + return kernel + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__( + self, + op_overload, + python_kernel_name, + x, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ): + assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"} + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: Tuple[Any, ...] 
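+        # `src` may be either a TensorBox or a plain Python scalar: in the
+        # tensor case it is realized and passed as a real input, while in the
+        # scalar case it is folded into constant_args alongside `dim`.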
+ if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=python_kernel_name, + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + self.cpp_kernel_name = self.get_cpp_kernel() + self.name = V.graph.register_buffer(self) + mark_node_as_mutating(self, x) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper): + (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) + indices = [] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(self.indices): + if self.indices[i] is not None: + indices.append(next(iter_valid_indices)) + else: + indices.append(V.graph.wrapper_code.none_str) + + wrapper.generate_index_put_fallback( + self.get_kernel_name(), x, indices, values, *self.codegen_const_args() + ) + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + def __init__(self, op_overload, x, indices, values, accumulate): + self.indices = indices + valid_indices = [i for i in indices if i is not None] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = ( + "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" + ) + super().__init__( + None, + NoneLayout(x.get_device()), # type: ignore[arg-type] + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + self.name = V.graph.register_buffer(self) + mark_node_as_mutating(self, x) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x, device): + if ( + not x.is_extern() + and all( + (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep)) + for r in x.get_reads() + ) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + ) + + def codegen(self, wrapper): + args = self.codegen_args() + assert len(args) == 1 + if self.output_view: + wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference()) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + # TODO: handle bools carefully + def __init__(self, sym, data): + data.realize() + super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] + if isinstance(sym, sympy.Symbol): + self.sym = sym + self.is_bool = False + else: + # Special case for boolean. 
For Reasons(TM), we don't represent + # boolean variables directly in sympy; instead, we generate an + # indicator integer variable which we then convert to a boolean by + # testing i0 == 1. We have to identify the underlying indicator + # variable, and then bind i0 to the appropriate integer value + # based on the runtime boolean. + assert isinstance(sym, sympy.Eq), sym + assert isinstance(sym.args[0], sympy.Symbol), sym + assert sym.args[1] == 1, sym + self.sym = sym.args[0] + self.is_bool = True + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return {self.sym} + + def codegen(self, wrapper): + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self): + return () + + def should_allocate(self): + return False + + def __init__(self, scalar, msg): + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + # InputsKernel(inputs) + [], + ) # type: ignore[arg-type] + self.scalar = scalar + self.msg = msg + + def has_side_effects(self): + return True + + def get_unbacked_symbol_uses(self): + return free_unbacked_symbols(self.scalar) + + def codegen(self, wrapper): + if V.graph.cpp_wrapper: + pass + else: + wrapper.writeline( + f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:" + ) + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@dataclasses.dataclass +class ExternKernelNode: + name: str + node: export_schema.Node + + +has_c_shim = { + aten._embedding_bag.default, + aten._fft_c2c.default, + aten._scaled_dot_product_efficient_attention.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_mm.default, + aten.addmm.out, + aten.bmm.out, + aten.copy_.default, + aten.mm.out, + aten.repeat_interleave.Tensor, + aten.nonzero.default, + aten.view.dtype, + aten.view_as_real.default, +} + + +def get_aten_cpp_kernel_name(kernel): + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) + assert ( + isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten" + ), "Invalid aten kernel" + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + return f"at::_ops::{opname}::call" + + +class FallbackKernel(ExternKernelAlloc): + args_default_value: List[Dict[str, Any]] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + ): + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by pass each individual + # output through the abi-compatible interface. 
+ self.outputs: Sequence[Any] = [] + self.use_runtime_dispatch = False + self.abi_compatible_kernel = None + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), + ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + V.graph.warn_fallback(self.python_kernel_name) + + # args that are aliased + self.alias_names: List[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: List[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt-in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support three types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. + # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + schema_args = schema.arguments + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info, arg): + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) + is_optional_tensor = isinstance( + info.type, torch.OptionalType + ) and isinstance(info.type.getElementType(), torch.TensorType) + if is_optional_tensor or isinstance(info.type, torch.TensorType): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. + assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + # can_auto_functionalize already filters out mutable List[Tensor]. + # We can support this in the future, but this is very uncommon. 
+ assert isinstance(info.type, torch.TensorType) or is_optional_tensor + self.alias_names.append(arg.get_name()) + if info.alias_info.is_write: + mark_node_as_mutating(self, arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def set_cpp_kernel(self, kernel): + from .codegen.wrapper import get_cpp_op_schema + + assert ( + not kernel._schema.is_mutable + ), f"mutable {kernel.__name__} is not supported with cpp_wrapper" + + # These checks are here because ops that return aliasing tensors will + # return type Tensor& instead of Tensor, but codegen will always write + # type Tensor on the LHS. + def is_not_write(arg): + return arg.alias_info is None or not arg.alias_info.is_write + + assert all( + is_not_write(x) for x in kernel._schema.arguments + ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper" + assert all( + is_not_write(x) for x in kernel._schema.returns + ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper" + + self.cpp_kernel_name = kernel._schema.name + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + + self.cpp_op_schema = get_cpp_op_schema(kernel) + self.init_args_default_value(kernel._schema) + + def is_legacy_abi_kernel(self): + return ( + config.c_shim_version == "1" + and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name) + ) + + def init_args_default_value(self, schema): + self.args_default_value = [ + { + "name": x.name, + "type": x.real_type, + "value": x.default_value, + } + for x in schema.arguments + if not x.kwarg_only + ] + + def get_pos_arg_value(self, pos, kwargs): + # positional args may be provided in kwargs + pos_arg_name = self.args_default_value[pos]["name"] + if pos_arg_name in kwargs: + log.debug( + "Found argument %s with value %s from kwargs", + pos_arg_name, + kwargs[pos_arg_name], + ) + return kwargs[pos_arg_name] + + assert hasattr( + self, "args_default_value" + ), "self.args_default_value has to be provided" + assert pos < len( + self.args_default_value + ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}" + arg_default_value = self.args_default_value[pos]["value"] + log.debug( + "Use default value %s for argument %s", arg_default_value, pos_arg_name + ) + return arg_default_value + + def codegen_args(self): + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + # Now we setup abi_compatible_kernel after self.python_kernel_name + # and kwargs are adjusted appropriately. 
+ # For sdpa, we need the v2 version since v1 didn't consider optional arg + # FIXME: no need to do this after we switch to the torchgen-ed C shim + self.abi_compatible_kernel = ( + f"{self.cpp_kernel_name}_v2" + if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"} + and config.c_shim_version == "1" + else self.cpp_kernel_name + ) + + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = [ + V.graph.wrapper_code.val_to_cpp_arg_str( + param.real_type, x, self.is_legacy_abi_kernel() + ) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being set. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + if V.graph.cpp_wrapper and hasattr(self, "args_default_value"): + self.fill_non_provided_args(args, kwargs, convert_val_to_str=True) + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device(tensor_args, example_output): + if tensor_args: + return tensor_args[0].get_device() + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + devices = {FallbackKernel.find_device(None, x) for x in example_output} + # Remove None + devices = [device for device in devices if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if device.type == "cuda": + return device + return devices[0] + return None + + def has_side_effects(self): + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + return False + return get_schema_info(self.op_overload).is_mutable() + + def get_alias_names(self): + return self.alias_names + + def get_mutation_names(self): + assert len(self.mutation_names) <= 1 + return self.mutation_names + + def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert hasattr(self, "args_default_value") + n_args = len(args) + n_pos_args = len(self.args_default_value) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + pos_args = [ + self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args) + ] + if convert_val_to_str: + pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args] + args.extend(pos_args) + return args + + # ProxyExecutor Design Note + # We export the ExternFallbackNodes (for custom ops) into a serialized file + # and run it with a host side proxy executor to address the ABI problem + # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. 
+ # Detailed design doc can be found at + # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + def export_extern_kernel_node(self): + assert isinstance(self, FallbackKernel) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel + ] + + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] + + # serialize_outputs + def handle_single_output(return_type, output): + if isinstance(return_type, torch.TensorType): + # For single Tensor + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + target = self.op_overload + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, self.outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + assert isinstance(self.outputs, tuple) + assert len(returns) == len(self.outputs) + output_arguments = [ + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper): + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + if ( + config.is_fbcode() + and kernel not in has_c_shim + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update gen_aoti_c_shim.py. + and config.c_shim_version == "1" + ): + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + self.set_cpp_kernel(kernel) + else: + self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel) + schema = kernel._schema + self.init_args_default_value(schema) + else: + self.python_kernel_name = str(kernel) + + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + # For non-aten OpOverload, i.e. 
custom ops + if V.graph.cpp_wrapper: + self.use_runtime_dispatch = True + self.set_cpp_kernel(kernel) + else: + self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" # type: ignore[union-attr] + + if self.use_runtime_dispatch: + self.codegen_comment(wrapper) + + exported_args = None + args = None + if config.is_fbcode() and V.graph.cpp_wrapper: + exported_args = self.export_extern_kernel_node() + else: + args = [*self.codegen_args(), *self.codegen_kwargs()] + + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + self.op_overload, + exported_args, + self.outputs, + ) + else: + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + V.graph.wrapper_code.generate_fallback_kernel(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor): + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() + ) + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, *args, **kwargs) + + device = cls.find_device(tensor_args, example_output) + assert device, "Not sure where to find device info" + + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + def generate_output(output, indices): + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + return MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert ( + output is None + ), f"FallbackKernel output type {type(output)} is not supported" + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] + else: + packed.outputs = [outputs] + return outputs + + def apply_constraint(self): + return super().apply_constraint() + + +@dataclasses.dataclass +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self): + return False + + def get_alias_names(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].get_name()] + + def __init__( + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + ): + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + ) + + +@dataclasses.dataclass +class MultiOutputLayout(IRNode): + device: torch.device + + +class MultiOutput(ExternKernel): + # Given an input MultiOutputLayout buffer, indexes out an actual buffer + # from that result. 
This doesn't actually produce multiple outputs, + # that's MultiOutputLayout! + def codegen_list_tuple_access(self, basename, indices): + if len(indices) > 0: + itype, i = indices[0] + if itype == list: + return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif itype == tuple: + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = V.graph.wrapper_code.codegen_tuple_access( + basename, self.get_name(), str(i) + ) + return self.codegen_list_tuple_access(tuple_access, indices[1:]) + elif itype == dict: + return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type") + else: + return basename + + def codegen(self, wrapper): + wrapper.codegen_multi_output( + self.get_name(), + self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), + ) + self.codegen_unbacked_symbol_defs(wrapper) + + def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + self.indices = indices + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return self.inputs[0].get_unbacked_symbol_uses() + + def should_allocate(self): + return False + + def get_alias_names(self): + return [ + inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) and len(inp.get_alias_names()) > 0 + ] + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] 
+ def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, int) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + output_stride = make_channels_last_strides_for(output_size) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. 
+ # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = make_contiguous_strides_for(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_key = "convolution_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + 
cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ): + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise_", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary_" + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, inputs[1]) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + 
constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkl._mkl_linear", + cpp_kernel_name="mkl::_mkl_linear", + ) + self.cpp_kernel_key = "mkl_linear" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const c10::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = make_contiguous_strides_for(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [None, batch_size] + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, b, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "linear_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + @classmethod + def create(cls, x, y, w, b, attr): + x = 
cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if b is not None: + b = cls.require_contiguous(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, b) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", + cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", + ) + self.cpp_kernel_key = "convolution_transpose_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="aten.mkldnn_rnn_layer", + cpp_kernel_name="at::mkldnn_rnn_layer", + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. 
When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return make_contiguous_strides_for(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + make_contiguous_strides_for(hy_shape), + make_contiguous_strides_for(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_key = "qconv2d_pointwise" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + 
weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + stride_: List[int], + padding_: List[int], + dilation_: List[int], + groups: int, + o_inv_scale: float, + output_zero_point: int, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + x_scale, + x_zp, + o_inv_scale, + output_zero_point, + output_dtype, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "qconv2d_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + accum_scale, + accum_zp, + o_inv_scale, + 
o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-16:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale, + x_zp, + accum: "TensorBox", + accum_scale, + accum_zp, + weight: "TensorBox", # packed_weight + w_scale, + w_zp, + bias: "TensorBox", + stride_: List[int], + padding_: List[int], + dilation_: List[int], + groups: int, + o_inv_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups, + transposed, + output_padding, + ) + + accum = cls.require_stride_order(accum, req_stride_order) + inputs.append(accum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zp.realize() + inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + x_scale, + x_zp, + accum_scale, + accum_zp, + o_inv_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(accum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, accum) + + # Return accum since it has been inplace changed. 
+ return packed.inputs[packed.idx_for_inplace_sum] + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.default" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" + self.cpp_kernel_key = "qlinear_pointwise" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + std::string post_op_name, + torch::List> post_op_args, + std::string post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + ( + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.get_kernel_name(), + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + x_scale: float, + x_zp: int, + weight: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zp: "TensorBox", + bias: "TensorBox", + o_inv_scale: float, + output_zero_point: int, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + x, + weight, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): + x_scale.realize() + x_zp.realize() + inputs = inputs + [x_scale, x_zp] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zp, int) + constant_args = constant_args + [x_scale, x_zp] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zp.realize() + 
inputs = inputs + [w_scale, w_zp] + constant_args = constant_args + [ + o_inv_scale, + output_zero_point, + output_dtype, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def __getattr__(self, name): + fn = getattr(self.data, name) + if callable(fn): + return fn + raise AttributeError(f"{type(self.data).__name__}.{name} not callable") + + def realize(self): + return self.data.realize() + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + return self.data.get_unbacked_symbol_uses() + + def codegen_reference(self, writer=None): + return self.data.codegen_reference(writer) + + @property + def layout(self): + return self.data.layout # type: ignore[attr-defined] + + def get_layout(self): + return self.layout + + def get_size(self): + return self.data.get_size() + + @property + def dtype(self): + return self.data.dtype + + def __str__(self): + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data): + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + def is_input_buffer(self): + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def realize(self): + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() + assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self): + """ + Called on buffers we expect to be forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.num_reads() > 1 + and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() + ): + self.realize() + + def has_exceeded_max_reads(self): + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + ) + + def mark_reuse(self, users): + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. 
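+        Realizing converts the pending Pointwise/Reduction into a ComputedBuffer,
+        so the value is materialized once instead of being recomputed by each user.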
+        """
+
+        def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
+            """
+            The heuristic for realizing reused result of heavy ops on cpu
+            """
+            heavy_ops = ["exp"]  # a list of heavy ops
+            fn_str = loops.inner_fn_str()
+            return any((op + "(") in fn_str for op in heavy_ops)
+
+        if (
+            users > 1
+            and isinstance(self.data, (Pointwise, Reduction))
+            and (
+                self.num_reads() > config.realize_reads_threshold
+                or self.has_large_inner_fn()
+                or (is_cpu(self.data) and should_realize_on_cpu(self.data))
+            )
+        ):
+            self.realize()
+
+    @cache_on_self
+    def num_reads(self):
+        data = self.data
+        if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
+            return 1
+        if isinstance(data, ComputedBuffer):
+            read_writes = data.get_read_writes()
+        else:
+            assert isinstance(data, (Pointwise, Reduction)), type(data)
+            read_writes = ComputedBuffer(
+                name=None,
+                layout=FlexibleLayout(
+                    device=data.get_device(),
+                    dtype=data.get_dtype(),
+                    size=data.get_size(),
+                ),
+                data=data,
+            ).get_read_writes()
+        return len(read_writes.reads)
+
+    @cache_on_self
+    def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self):
+        # Skip the check for non Pointwise instances
+        return (
+            (sum(read.index != 0 for read in self.data.get_reads()) > 1)
+            if isinstance(self.data, Pointwise)
+            and all(
+                not isinstance(read, dependencies.StarDep)
+                for read in self.data.get_reads()
+            )
+            else True
+        )
+
+
+@dataclasses.dataclass
+class Subgraph(IRNode):
+    name: str
+    graph_module: torch.fx.GraphModule
+    graph: Optional["GraphLowering"] = None
+
+
+@dataclasses.dataclass
+class Conditional(ExternKernel):
+    predicate: Optional[DynamicScalar] = None
+    operands: Optional[List[TensorBox]] = None
+    true_subgraph: Optional[Subgraph] = None
+    false_subgraph: Optional[Subgraph] = None
+    outputs: Optional[List[MultiOutput]] = None
+
+    def __init__(
+        self,
+        predicate: DynamicScalar,
+        operands: List[TensorBox],
+        true_subgraph: Subgraph,
+        false_subgraph: Subgraph,
+        layout: MultiOutputLayout,
+    ):
+        self.predicate = predicate
+        self.operands = operands
+        self.true_subgraph = true_subgraph
+        self.false_subgraph = false_subgraph
+
+        super().__init__(
+            name=None,
+            layout=layout,  # type: ignore[arg-type]
+            inputs=[predicate, *operands],  # type: ignore[list-item]
+        )
+
+        self.name = V.graph.register_buffer(self)
+
+    @classmethod
+    def create(
+        cls,
+        predicate: TensorBox,
+        true_fn: Subgraph,
+        false_fn: Subgraph,
+        operands: List[TensorBox],
+    ):
+        predicate = cls.realize_input(predicate)
+        operands = [cls.realize_input(x) for x in operands]
+
+        fx_operands = V.graph.current_node.args[-1]
+        fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
+
+        for subgraph in (true_fn, false_fn):
+            if subgraph.graph is None:
+                # create and lower subgraphs
+                subgraph.graph = V.graph.make_subgraph(
+                    gm=subgraph.graph_module,
+                    example_inputs=fake_operands,
+                    subgraph_name=subgraph.name,
+                )
+                with V.set_graph_handler(subgraph.graph):
+                    subgraph.graph.run(*fake_operands)
+
+        true_outputs = true_fn.graph.graph_outputs  # type: ignore[union-attr]
+        false_outputs = false_fn.graph.graph_outputs  # type: ignore[union-attr]
+
+        def _aliased_buffers(outputs):
+            buffers = [
+                output.unwrap_view() if isinstance(output, ReinterpretView) else output
+                for output in outputs
+            ]
+            # assuming the same buffer is represented by the same IRNode object
+            return len({id(buffer) for buffer in buffers}) < len(outputs)
+
+        for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
+            if 
_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. " + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): + assert to.get_size() == fo.get_size(), (i, to, fo) + assert to.get_stride() == fo.get_stride(), (i, to, fo) + assert to.get_device() == fo.get_device(), (i, to, fo) + assert to.get_dtype() == fo.get_dtype(), (i, to, fo) + assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) + + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + # use predicate device for consistent codegen-ing + layout=MultiOutputLayout(predicate.get_device()), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, output in enumerate(true_outputs) + ] + + conditional.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_conditional(self) + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. 
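+    Captured indexing expressions are interned as index0, index1, ... and are
+    looked up at interpretation time through the "get_index" submodule.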
+ """ + + def __init__(self, fn, args, var_ranges): + super().__init__() + self.var_ranges = var_ranges + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.reads = [] + self.writes = [] + self.reads_name2expr = {} + self.writes_name2expr = {} + self.other = [] + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.root_block = LoopBodyBlock(self, fn, args) + self.indexing = None + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def add_index_expr(self, expr: sympy.Expr, category, buf_name): + getattr(self, category).append(expr) + if buf_name is not None: + getattr(self, f"{category}_name2expr")[buf_name] = expr + if expr not in self.indexing_exprs_name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + return self.indexing_exprs_name[expr] + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + name = f"indirect{len(self.indirect_vars)}" + var = sympy_index_symbol(name) + self.indirect_vars.append(var) + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def __call__(self, *indices): + index = list(itertools.chain.from_iterable(indices)) + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all(v not in self.var_ranges for v in index) + replacements = dict(zip(self.var_ranges.keys(), index)) + self.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + result = self.root_block() + self.indexing = None + return result + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. 
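+    Subblocks are registered via LoopBody.add_submodule and appear in the traced
+    FX graph as call_module nodes.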
+ """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr, category, buf_name=None): + return tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, category, buf_name),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, "reads", name) + return self._inner.load(name, index) + + def store(self, name, index, value, mode=None): + index = add_index(index, "writes", name) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index(index, "writes", name) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, "other") + return self._inner.index_expr(index, dtype) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + offsets_size = add_index(offsets_size, "other") + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + + subblock: LoopBodyBlock + + def shim(mask, other): + return V.ops.masked(mask, subblock, other) + + name = self.body.add_submodule(shim, "masked_subblock") + subblock = LoopBodyBlock(self.body, masked_body, []) + self.body.subblocks[name] = subblock + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, combine_fn: Callable[..., Any], value_proxy, init_proxy + ): + def shim(dtype, value, init): + return V.ops.scan(dtype, combine_fn, value, init) + + name = self.body.add_submodule(shim, "scan") + return tracer.create_proxy( + "call_module", name, (dtype_proxy, value_proxy, init_proxy), {} + ) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. 
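+                The emitted call_module invokes set_indirect, which substitutes the
+                fresh indirect variable into the captured indexing expressions via
+                LoopBody.replace_indirect.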
+ """ + + var = self.body.add_indirect(size) + + def set_indirect(new_var): + self.body.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check) + ) + + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation(handler) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + +class Wait(ExternKernelAlloc): + """ + Wait should not be used by itself. It should always be constructed in tandem + with a collective op that produces a work to wait on. + """ + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__(layout, inputs, constant_args) + + def should_allocate(self): + return False + + def codegen(self, wrapper): + from .codegen.wrapper import ReuseLine + + wrapper.add_import_once( + "from torch.distributed._functional_collectives_impl import _wait_tensor" + ) + (input_collective,) = (t.codegen_reference() for t in self.inputs) + wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})") + + # wait op still needs to produce a 'buffer' that represents the tensor output. + # this is a symbolic gesture, and it gets handled by WrapperCodegen. + # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective') + # to a new name (`self.get_name()`) and `del`s the old name. + wrapper.writeline(ReuseLine(wrapper, self.inputs[0], self, delete_old=False)) + + @classmethod + def create(cls, collective_op: "TensorBox"): + # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream + collective_op.decide_layout() + return Wait( + layout=AliasedLayout(collective_op), + inputs=[collective_op], + ) + + def get_alias_names(self): + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].codegen_reference()] + + def get_mutation_names(self): + # The generated `_wait_tensor` op mutates the input tensor + return [self.inputs[0].codegen_reference()] + + +class CollectiveKernel(ExternKernel): + """ + Each collective should follow the pattern: + - extend InPlaceCollectiveKernel or OutOfPlaceCollectiveKernel. 
+ - the kernel delegates into c10d processgroup, which returns a 'work' obj + - the work obj is registered via _register_tensor_work so it can be waited on later + """ + + def __init__(self, layout, inputs, constant_args): + super().__init__(None, layout, inputs, constant_args) + self.name = V.graph.register_buffer(self) + + def should_emit_register_tensor_work(self): + return True + + def should_emit_find_or_create_pg(self): + return True + + def codegen_collective(self, wrapper, output_name, input_names): + # factor so the boilerplate can be handled in CollectiveKernel.codegen + raise NotImplementedError("Must implement") + + def codegen_output(self, wrapper, output_name, input_names): + # factor so the boilerplate can be handled in CollectiveKernel.codegen + raise NotImplementedError("Must implement") + + @classmethod + def wrap_inputs_as_inplace(cls, inputs): + def wrap_input(var): + op = InPlaceHint( + FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var + ) + return TensorBox.create(op) + + return list(map(wrap_input, inputs)) + + def codegen(self, wrapper): + wrapper.add_import_once("import torch.distributed as dist") + wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d") + wrapper.add_import_once( + "import torch.distributed._functional_collectives_impl as fun_col_impl" + ) + # extract references to our args in string form for codegen output + input_names = [t.codegen_reference() for t in self.inputs] + output_name = self.get_name() + tag, ranks, group_size = self.constant_args + + if self.should_emit_find_or_create_pg(): + # TODO: avoid more than one ref of the same pg (even though they are cached inside the api) + wrapper.writeline( + f"{output_name}_pg = c10d._find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})" + ) + + self.codegen_output(wrapper, output_name, input_names) + self.codegen_collective(wrapper, output_name, input_names) + if self.should_emit_register_tensor_work(): + wrapper.writeline( + f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)" + ) + + +class InPlaceCollectiveKernel(CollectiveKernel): + """ + InPlaceCollectiveKernel are those with in-out arguments such as all_reduce. + Extend this kernel if your collective needs to modify its inputs in-place. + """ + + def __init__(self, layout, inputs, constant_args): + super().__init__(layout, inputs, constant_args) + + def should_allocate(self): + return False + + def has_side_effects(self): + return True + + def codegen_output(self, wrapper, output_name, input_names): + if len(input_names) > 1: + wrapper.writeline(f"{output_name} = [{','.join(input_names)}] ") + else: + wrapper.writeline(f"{output_name} = {input_names[0]}") + + +class OutOfPlaceCollectiveKernel(CollectiveKernel): + """ + OutOfPlaceCollectiveKernel are those that allocate their + outputs and leave their inputs inplace, such as all_gather. + """ + + def __init__(self, layout, inputs, outputs, constant_args): + super().__init__(layout, inputs + outputs, constant_args) + self.outputs = outputs + self.original_inputs = inputs + # NOTE: As seen in issue #108780, output buffers of out-of-place collectives + # could be incorrectly reused. As a safety measure, here we just ban the reuse of them. + # TODO: A better fix is to figure out how to propagate the aliases properly, + # so that the buffer is only reused after all its users have consumed it. 
+ for x in self.outputs: + V.graph.never_reuse_buffers.add(x.name) + + def should_allocate(self): + return False + + def has_side_effects(self): + return True + + def codegen_output(self, wrapper, output_name, input_names): + input_names = [t.codegen_reference() for t in self.original_inputs] + wrapper.writeline(f"{output_name}_inputs = [{','.join(input_names)}]") + wrapper.writeline(f"{output_name} = [{','.join(x.name for x in self.outputs)}]") + + @classmethod + def create_output_buffers(cls, inputs, size_cb=None): + outputs = [] + for input in inputs: + new_size = input.get_size() + if size_cb is not None: + size_cb(new_size) + # new_size[0] *= group_size + + buff = OutputBuffer( + layout=FlexibleLayout( + device=input.get_device(), + dtype=input.get_dtype(), + size=new_size, + ), + ) + outputs.append(buff) + return outputs + + @classmethod + def create_output_nodes(cls, coll, output_buffers): + return [ + MultiOutputNoSizeAssert( + out_t.layout, + coll, + f"[{i}]", + ) + for i, out_t in enumerate(output_buffers) + ] + + +class InPlaceHint(ExternKernel): + """ + Helper OP to encode an in/out argument that tries to make it inplace whenever possible. + Wrap the input of your inplace op to enable this behavior. + + The design is based on two key decisions: + - this node is responsible for allocating the in/out buffer used by the collective. + This is controlled by the ``should_allocate`` method that returns True here and + False for the collective node + - The scheduler special-case this node and enable it to reuse its input. + """ + + def codegen(self, wrapper): + input_name = self.inputs[0].codegen_reference() + output_name = self.get_name() + if not wrapper.did_reuse(self, self.inputs[0]): + wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse") + + def __init__(self, layout, input): + input = self.realize_input(input) + super().__init__(None, layout, self.unwrap_storage([input]), ()) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return True + + +class OutputBuffer(ExternKernel): + """ + Represent the output buffer used by ops that require multiple of them + """ + + def __init__(self, layout): + super().__init__(name=None, layout=layout, inputs=[]) + self.name = V.graph.register_buffer(self) + + def should_allocate(self): + return True + + def codegen(self, wrapper): + wrapper.writeline(f"# collective out buffer {self.name}") + + +class MultiOutputNoSizeAssert(MultiOutput): + """ + Extract partial output from a multi-output OP. + Works like MultiOutput but doesn't assert size. This must be a property guaranteed by the op emitting this. 
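+    The ``index`` argument is a literal subscript string such as "[0]" that codegen
+    appends verbatim to the producing kernel's name.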
+ """ + + def __init__(self, layout, input, index): + super().__init__(layout, input, []) + self.index = index + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}" + ) + + +class Broadcast(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, src): + super().__init__(layout, inputs, constant_args) + self.src = src + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, x: "TensorBox", src: int, tag: str, ranks: List[int], group_size: int + ): + inplace_inputs = cls.wrap_inputs_as_inplace([x]) + packed = Broadcast( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + src=src, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.broadcast(" + f"{output_name}, async_op=True, group={output_name}_pg, src={self.src})" + ) + + +class AllReduceCoalesced(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, reduce_op): + super().__init__(layout, inputs, constant_args) + self.reduce_op = reduce_op + + def should_allocate(self): + return False + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inplace_inputs = cls.wrap_inputs_as_inplace(inputs) + packed = AllReduceCoalesced( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_reduce_coalesced(" + f"{output_name}, " + f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), " + f"group={output_name}_pg, " + "async_op=True)" + ) + + +class AllReduce(InPlaceCollectiveKernel): + def __init__(self, layout, inputs, constant_args, reduce_op): + super().__init__(layout, inputs, constant_args) + self.reduce_op = reduce_op + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int + ): + inplace_inputs = cls.wrap_inputs_as_inplace([x]) + + packed = AllReduce( + layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] + inputs=inplace_inputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + mark_node_as_mutating(packed, inplace_inputs[0]) + return inplace_inputs[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_reduce(" + f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" + ) + + +class AllGatherIntoTensor(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args): + super().__init__(layout, inputs, outputs, constant_args) + 
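+    # all_gather_into_tensor concatenates one shard per rank along dim 0, so
+    # create() sizes the output buffer group_size times larger in that dimension.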
+ @classmethod + def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + new_size[0] *= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllGatherIntoTensor( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.all_gather_into_tensor(" + f"{output_name}[0], {output_name}_inputs[0], async_op=True, group={output_name}_pg)" + ) + + +class ReduceScatterTensor(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args, reduce_op): + super().__init__(layout, inputs, outputs, constant_args) + self.reduce_op = reduce_op + + @classmethod + def create( + cls, + x: "TensorBox", + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + new_size[0] //= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = ReduceScatterTensor( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = dist.reduce_scatter_tensor(" + f"{output_name}[0], {output_name}_inputs[0], " + f"async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" + ) + + +class AllGatherIntoTensorCoalesced(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args): + super().__init__(layout, inputs, outputs, constant_args) + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x) for x in inputs] + + def compute_size(new_size): + new_size[0] *= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllGatherIntoTensorCoalesced( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + ) + + return outputs + # return cls.create_output_nodes(packed, outputs) + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback(" + f"output_tensors={output_name}, " + f"input_tensors={output_name}_inputs, " + f"group={output_name}_pg, " + "async_op=True)" + ) + + +class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel): + def __init__(self, layout, inputs, outputs, constant_args, reduce_op): + super().__init__(layout, inputs, outputs, constant_args) + self.reduce_op = reduce_op + + @classmethod + def create( + cls, + inputs: List["TensorBox"], + reduce_op: str, + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x) for x in inputs] + + def compute_size(new_size): + new_size[0] //= group_size + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + _ = ReduceScatterTensorCoalesced( + layout=layout, + inputs=inputs, + outputs=outputs, + 
constant_args=[tag, ranks, group_size], + reduce_op=reduce_op, + ) + + return outputs + + def codegen_collective(self, wrapper, output_name, input_names): + wrapper.writeline( + f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback(" + f"output_tensors={output_name}, " + f"input_tensors={output_name}_inputs, " + f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), " + f"group={output_name}_pg, " + "async_op=True)" + ) + + +# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel. +class _CollectiveKernel(FallbackKernel): + def should_allocate(self): + return False + + def has_side_effects(self): + return True + + # This is identical to FallbackKernel.set_cpp_kernel(), minus the + # part that checks against input aliasing and mutation. + def set_cpp_kernel(self, kernel): + from .codegen.wrapper import get_cpp_op_schema + + self.cpp_kernel_name = kernel._schema.name + self.cpp_kernel_overload_name = kernel._schema.overload_name + self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] + + self.cpp_op_schema = get_cpp_op_schema(kernel) + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in kernel._schema.arguments if x.kwarg_only + ] + + # NOTE: [In-Place Collective Safety] + # Between the initiation and completion of an in-place collective, the + # input buffers are subject to both volatile reads and volatile writes. + # They must not be read, written to or reused by another kernel. To ensure + # the constraints, we model collective -> wait_tensor as as two-step + # mutation of the input buffers. + @classmethod + def create_inplace( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ) -> None: + cpp_kernel_name = kernel._name + python_kernel_name = cpp_kernel_name.replace("::", ".") + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + for tensor_arg in tensor_args: + tensor_arg.realize() + + packed = cls( + NoneLayout(tensor_args[0].get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name + + def mark_mutation(x): + if isinstance(x.data, BaseView): + x = x.data.unwrap_view() + MutationOutput(x.layout, x, packed) + + pytree.tree_map(lambda inp: mark_mutation(inp), inputs) + + # NOTE: [Out-of-Place Collective Safety] + # Between the initiation and completion of an out-of-place collective: + # + # Input buffers: + # - Are subject to volatile reads + # - Can be read by another kernel + # - Must not be written to or reused by another kernel + # + # Output buffers: + # - Are subject to volatile writes + # - Must not be read, written to or reused by another kernel + # + # To ensure the safety of input buffers without sacrificing read + # availability, we add input buffers as read deps of wait_tensor kernels. + # + # To ensure the safety of output buffers, we model wait_tensor as a + # mutation to the output buffer. Note we also assumes the user program being + # correct and the output buffer is not consumed by kernels other than + # wait_tensor. + # + # TODO(yifu): add a pre-grad pass to validate the correctness of collective + # usage in the user program. 
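+    # Concretely (see _WaitKernel below): create_out_of_place builds the collective
+    # kernel, _WaitKernel.get_read_writes adds the collective's inputs as StarDep
+    # reads, and _WaitKernel.create_wait wraps the collective's output buffer in a
+    # MutationOutput so nothing else reuses it before the wait completes.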
+ @classmethod + def create_out_of_place( + cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs + ): + cpp_kernel_name = kernel._name + python_kernel_name = cpp_kernel_name.replace("::", ".") + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + packed = cls( + MultiOutputLayout(device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name + packed.outputs = [packed] + return packed + + +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel, inp: TensorBox) -> None: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + ) = cls.process_kernel(kernel, inp) + packed = cls( + NoneLayout(inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + if isinstance(inp.data, BaseView): + inp = inp.data.unwrap_view() + MutationOutput(inp.layout, inp, packed) + + def get_read_writes(self): + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. 
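+        # Each volatile read becomes a StarDep, preventing the scheduler from
+        # reusing or overwriting the collective's input buffers before this wait.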
+ volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s): + if isinstance(s, (SymTypes, sympy.Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r = set() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return set() + + +class AllToAllSingle(OutOfPlaceCollectiveKernel): + def __init__( + self, + layout, + inputs, + outputs, + constant_args, + output_split_sizes, + input_split_sizes, + ): + super().__init__(layout, inputs, outputs, constant_args) + self.output_split_sizes = output_split_sizes + self.input_split_sizes = input_split_sizes + + def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: + r = set() + if self.output_split_sizes is not None: + r |= free_unbacked_symbols(self.output_split_sizes) + if self.input_split_sizes is not None: + r |= free_unbacked_symbols(self.input_split_sizes) + return r + + @classmethod + def create( + cls, + x: "TensorBox", + output_split_sizes: Optional[List[Expr]], + input_split_sizes: Optional[List[Expr]], + tag: str, + ranks: List[int], + group_size: int, + ): + inputs = [cls.realize_input(x)] + + def compute_size(new_size): + if output_split_sizes is not None: + new_size[0] = sum(output_split_sizes) + + outputs = cls.create_output_buffers(inputs, compute_size) + + layout = MultiOutputLayout(inputs[0].get_device()) + + packed = AllToAllSingle( + layout=layout, + inputs=inputs, + outputs=outputs, + constant_args=[tag, ranks, group_size], + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + ) + return cls.create_output_nodes(packed, outputs)[0] + + def codegen_collective(self, wrapper, output_name, input_names): + tag, ranks, group_size = self.constant_args + + # TODO: might be necessary to do some pretty printing on + # split sizes + wrapper.writeline( + f"{output_name}_work = dist.all_to_all_single(" + f"{output_name}[0], {output_name}_inputs[0], " + f"output_split_sizes={self.output_split_sizes}, " + f"input_split_sizes={self.input_split_sizes}, " + f"group={output_name}_pg, async_op=True)" + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__init__.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba0fe9754f43c39cd3adc2187ccb68c1890605c8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/__init__.py @@ -0,0 +1 @@ +from . 
import mm, mm_common, mm_plus_mm, unpack_mixed_mm diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecfbb31f76acf7f98eccde4f1522dcdf30ba925e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dfcc05734cb791b52ffd6afb2c0b3457ada52ef Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce2e57376b2c52e5acf129ac4209964ed86aded1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c0b6f01ec764f5bdbfbf46fdcb6e38611b0b859 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07d09e94e2b2b50f5a6334fe0ddcfd2554e195e5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2613aafe4b2712785e4dd75043d223b54d17180 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c704a30764f355123981e6c4fcd8a86ad0b1d8d4 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_inductor/kernel/__pycache__/unpack_mixed_mm.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/bmm.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..a09c730afa6f9cc6b3b72552b737d328665e66fb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,128 @@ +import torch + +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ceildiv as cdiv, use_aten_gemm_kernels, use_triton_template + +from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options + +aten = torch.ops.aten + + 
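+# Launch grid for the bmm template: axis 0 tiles the M x N output, axis 1 is the
+# batch dimension (read as tl.program_id(1) inside the kernel source below).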
+def bmm_grid(b, m, n, meta): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out") + + +@register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + + +# Don't register this since it is slower than decomposing it +# @register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, + input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git 
a/MLPY/Lib/site-packages/torch/_inductor/kernel/conv.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a419432cbe0f10a2ce9512e9fe3ccf9e569f4b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/conv.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import functools +import logging +from typing import cast, List, Optional, Sequence, Tuple, TypedDict + +import torch +from .. import config, ir +from ..ir import TensorBox + +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ( + ceildiv, + is_ones, + is_zeros, + pad_listlike, + sympy_product, + use_triton_template, +) +from ..virtualized import V +from .mm_common import filtered_configs + +log = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +def conv_grid(n, c, h, w, meta): + return ( + ceildiv(n * h * w, meta["BLOCK_M"]), + ceildiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +# List of dictionaries to store the kernel configs. Configs that evaluate to true +# will be utilised on the target platform +kernel_configs = [ + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + {"config": (64, 256, 16, 2, 4), "cond": True}, + {"config": (256, 64, 16, 2, 4), "cond": True}, + {"config": (1024, 16, 16, 1, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 256, 32, 2, 8), "cond": True}, + {"config": (256, 64, 32, 2, 8), "cond": True}, +] + +# Create filtered list of configs based on conv +platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in kernel_configs + if config["cond"] +) + +# On ROCm convert num_stages to 1 as pipelining provides no benefit +if torch.version.hip: + platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) for config in platform_configs + ) + +conv_configs = functools.partial( + filtered_configs, + configs=platform_configs, +) + +LOOP_BODY = """ + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +""" +This is a relatively simple conv implementation that can likely be +improved. 
Many alternate conv versions can be found here: +https://github.com/pytorch/torchdynamo/pull/971 +""" +conv2d_template = TritonTemplate( + name="convolution", + grid=conv_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_H = {{size("X", 2)}} + IN_W = {{size("X", 3)}} + OUT_C = {{size(None, 1)}} + OUT_H = {{size(None, 2)}} + OUT_W = {{size(None, 3)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xh = {{stride("X", 2)}} + stride_xw = {{stride("X", 3)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wh = {{stride("W", 2)}} + stride_ww = {{stride("W", 3)}} + + nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = nhw % OUT_W + nh = nhw // OUT_W + idx_y_h = nh % OUT_H + idx_n = nh // OUT_H + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY + + """ +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (ijk % BLOCK_K_COUNT) * BLOCK_K + ij = ijk // BLOCK_K_COUNT + i = ij // KERNEL_W + j = ij % KERNEL_W + """ + + LOOP_BODY + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +aten_convolution = ExternKernelChoice( + torch.convolution, + "at::convolution", + has_out_variant=False, + op_overload=aten.convolution.default, +) + + +def conv1x1_via_mm(x, w, *, out): + w = torch.squeeze(torch.squeeze(w, -1), -1) + return torch.matmul( + x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) + ) + + +aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) + + +class ConvLayoutParams(TypedDict): + stride: tuple[int, ...] + padding: tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] 
+ groups: int + + +def conv_layout( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, +) -> ir.Layout: + """Determine output layout for a convolution""" + with V.graph.fake_mode: + output = torch.ops.aten.convolution( + ir.ir_node_to_tensor(x, guard_shape=True), + ir.ir_node_to_tensor(weight, guard_shape=True), + ir.ir_node_to_tensor(bias, guard_shape=True), + stride, + tuple(V.graph.sizevars.size_hint(p) for p in padding), # type: ignore[arg-type] + dilation, + transposed, + tuple(V.graph.sizevars.size_hint(p) for p in output_padding), # type: ignore[arg-type] + groups, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] + + return ir.FixedLayout( + x.get_device(), + x.get_dtype(), + sizes, + stride, + ) + + +def channels_last_order(rank): + order = list(reversed(range(rank))) + order.insert(1, order.pop(-1)) + return order + + +def convert_1x1_conv_to_mm(x, weight, bias): + # special case for 1x1 convolution, which is actually just a matmul + rank = len(weight.get_size()) + for _ in range(rank - 2): + weight = L[aten.squeeze](weight, dim=-1) + weight = L[aten.permute](weight, [1, 0]) + + if x.get_size()[0] != 1: + x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) + else: + x.realize() + x.freeze_layout() + + x_permute = list(range(rank)) + x_permute.append(x_permute.pop(1)) + x = L[aten.permute](x, x_permute) + *sizes, in_chan = x.get_size() + x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) + if bias is None: + result = L[aten.mm](x, weight) + else: + result = L[aten.addmm](bias, x, weight) + result = L[aten.reshape](result, [*sizes, -1]) + result_permute = list(range(rank)) + result_permute.insert(1, result_permute.pop(-1)) + return L[aten.permute](result, result_permute) + + +@register_lowering(aten.convolution) +def convolution( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, +): + stride = tuple(stride) + padding = tuple(padding) + dilation = tuple(dilation) + output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.evaluate_static_shape(groups) + assert isinstance(groups, int) + kwargs: ConvLayoutParams = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + if len(x.get_size()) == len(weight.get_size()) - 1: + # add batch dimension to simplify rest of function + return L[aten.squeeze]( + convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), + dim=0, + ) + + out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( + weight.get_size() + ) + ndim = len(kernel_shape) + stride = pad_listlike(stride, ndim) + padding = pad_listlike(padding, ndim) + dilation = pad_listlike(dilation, ndim) + output_padding = pad_listlike(output_padding, ndim) + + def channels_last_conv(): + if V.graph.layout_opt and ndim == 2: + return True + + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + return req_stride_order == ir.NHWC_STRIDE_ORDER + + autotuning_gemm = config.max_autotune or config.max_autotune_gemm + + 
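+    # Non-transposed 1x1 kernels with unit stride/dilation, zero padding and
+    # groups == 1 are rewritten as a plain matmul (convert_1x1_conv_to_mm) when
+    # conv_1x1_as_mm or GEMM autotuning is enabled below.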
if ( + (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) + and is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + and groups == 1 + ): + return convert_1x1_conv_to_mm(x, weight, bias) + + if bias is not None and ir.get_device_type(x) != "cpu": + # peel off the bias, cudnn is slower with it + result = convolution(x, weight, None, **kwargs) + return L[aten.add]( + result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) + ) + + x.realize() + weight.realize() + + # ndim can be 1 for convolution in models such as demucs + # TODO: check if it's beneficial to convert Conv1d to Conv2d and then + # apply channels last. + if V.graph.layout_opt and ndim == 2: + V.graph.num_channels_last_conv += 1 + x = ir.ExternKernel.require_channels_last(x) + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) + layout = conv_layout(x, weight, None, **kwargs) + else: + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) + + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] + bias.realize() + bias.freeze_layout() + V.graph.sizevars.evaluate_static_shapes(bias.get_size()) + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] + + if ( + use_triton_template(layout) + # templates only support these: + and ndim == 2 + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + # there are some odd models where this check fails (e.g. 
shufflenet_v2_x1_0) + and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] + ): + if ( + is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and groups == 1 + ): + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + ): + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/openai/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + + return autotune_select_algorithm("convolution", choices, args, layout) + + +@register_lowering(aten._convolution) +def _convolution( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, +): + return convolution( + x, weight, bias, stride, padding, dilation, transposed, output_padding, groups + ) + + +def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): + assert fx_node.target == torch.ops.aten.convolution.default + if V.graph.layout_opt: + return args, kwargs + else: + return constrain_to_fx_strides(fx_node, *args, **kwargs) + + +add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/mm.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..f1375501d570c56e12983eceb06b0aac6b17e042 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,312 @@ +import functools +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch._inductor.virtualized import V +from .. 
import config as inductor_config +from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ( + use_aten_gemm_kernels, + use_cutlass_template, + use_max_autotune, + use_triton_template, +) +from .mm_common import ( + addmm_epilogue, + int8_mm_configs, + mm_args, + mm_configs, + mm_grid, + mm_options, +) + +log = logging.getLogger(__name__) +aten = torch.ops.aten + +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm") + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. 
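+
+    Roughly, the dispatch below is:
+        inp.stride(0) == 0 or inp.size(0) == 1  ->  torch.addmm(inp[0], mat1, mat2, ...)
+        otherwise                               ->  torch.addmm(inp, mat1, mat2, ...)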
+ """ + if inp.stride(0) == 0 or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = [aten_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + + if m * n != 0 and use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if m * n != 0 and use_cutlass_template(layout): + CUTLASSGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + from torch._inductor.ir import FixedLayout, FlexibleLayout + + if ( + len(choices) == 1 + and use_aten_gemm_kernels() + and isinstance(layout, FixedLayout) + ): + # If we are not autotuning, we can swap to a FlexibleLayout + # in order to get fusion optimizations to kick in, e.g. ConcatFusion + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = [aten_mm.bind((mat1, mat2), layout)] + + return autotune_select_algorithm("mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + if m * n != 0 and use_triton_template(layout, enable_int32=True): + # TODO: Re-enable eager mode implementation once cuBLAS is fixed + choices = [] + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + if m * n == 0 or not use_max_autotune(): + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + if use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if use_cutlass_template(layout): + CUTLASSGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + 
[mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + fuseable=False, + ) + + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + + +def fallback_mixed_mm(mat1, mat2, *, out): + return torch.mm(mat1, mat2.to(mat1.dtype), out=out) + + +aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None) + + +@functools.lru_cache(None) +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def tuned_mixed_mm(mat1, mat2, mat2_dtype): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None) + choices = [aten_fallback_mixed_mm.bind((mat1, mat2), layout)] + if ( + mat1.layout.dtype != torch.float32 and not mat2.layout.is_contiguous() + ) or _is_sm7x_or_older_gpu(layout.device.index): + # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) + return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout) + if inductor_config.force_mixed_mm: + choices = [] + b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") + has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) + for config in mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + return autotune_select_algorithm("mixed_mm", choices, [mat1, mat2], layout) + + +# This op is a special case of the int_mm op which we use based on the pattern +# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent +# realization of the int32 _int_mm output by forcing fusion with the mul op. +# This is only used when config.force_fuse_int_mm_with_mul = True +def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None): + out_dtype = ( + torch.promote_types(mat3.get_dtype(), torch.int32) + if out_dtype is None + else out_dtype + ) + m, n, k, layout, mat1, mat2, mat3 = mm_args( + mat1, mat2, mat3, layout=layout, out_dtype=out_dtype + ) + choices: List[Dict[Any, Any]] = [] + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3), + layout=layout, + **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"), + suffix_args=1, + epilogue_fn=V.ops.mul, + ) + return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout) diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_common.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..0edc9b9b5ba645b236f9885d63a5f61be57cec31 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,262 @@ +import functools +import logging +from typing import cast, List, Tuple + +import sympy + +import torch +from torch._inductor.select_algorithm import realize_inputs +from torch._inductor.virtualized import V + +from .. 
import config as inductor_config +from ..utils import ceildiv as cdiv, next_power_of_2 + +log = logging.getLogger(__name__) + + +def triton_config(num_stages, num_warps, **kwargs): + from triton import Config + + return Config(kwargs, num_stages=num_stages, num_warps=num_warps) + + +def filtered_configs( + m: int, + n: int, + k: int, + configs: List[Tuple[int, int, int, int, int]], + has_int8_tensor=False, +): + """Heuristic to shrink configs when they are bigger than the input size""" + + # According to https://github.com/openai/triton/issues/2156#issuecomment-1695897424 + # it's safer to use at least [32, 32] block size for int8/uint8 + # tensors + min_block_size = 32 if has_int8_tensor else 16 + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + used = set() + for block_m, block_n, block_k, num_stages, num_warps in configs: + # shrink configs for small sizes + block_m = max(min(block_m, m), min_block_size) + block_n = max(min(block_n, n), min_block_size) + block_k = max(min(block_k, k), min_block_size) + # each warp computes 16x16 tile = 256 + num_warps = min(num_warps, block_m * block_n // 256) + if torch.version.hip: + for matrix_instr_nonkdim in [0, 16]: + if matrix_instr_nonkdim != 0 and ( + block_m % matrix_instr_nonkdim != 0 + or block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + if ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) not in used: + used.add( + ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) + ) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=matrix_instr_nonkdim, + ) + else: + if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used: + used.add((block_m, block_n, block_k, num_stages, num_warps, 0)) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + ) + + +# List of dictionaries to store the kernel configs. 
Configs that evaluate to true +# will be utilised on the target platform +mm_kernel_configs = [ + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, + {"config": (64, 64, 16, 2, 4), "cond": True}, + {"config": (32, 32, 16, 1, 2), "cond": True}, +] + +int8_mm_kernel_configs = [ + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + # {"config": (32, 32, 128, 2, 4), "cond": True}, + # {"config": (64, 64, 16, 2, 4), "cond": True}, + # {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None}, + {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, +] + +# Create filtered list of configs based on cond evaluation + + +mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mm_kernel_configs + if config["cond"] +) +int8_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in int8_mm_kernel_configs + if config["cond"] +) + +# On ROCm convert num_stages to 1 as pipelining provides no benefit +if torch.version.hip: + mm_platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) + for config in mm_platform_configs + ) + int8_platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) + for config in mm_platform_configs + ) + +mm_configs = functools.partial( + filtered_configs, + configs=mm_platform_configs, +) + +int8_mm_configs = functools.partial( + filtered_configs, + configs=int8_platform_configs, +) + + +def mm_grid(m, n, meta): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None): + """ + Common options to matmul triton templates. 
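+
+    Roughly: returns the kwargs consumed by the mm templates (GROUP_M, EVEN_K,
+    ALLOW_TF32, ACC_TYPE, B_PROLOGUE_CAST_TYPE) merged with the block sizes and
+    num_stages/num_warps of the chosen config.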
+ """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) + == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + ACC_TYPE=acc_type(layout.dtype), + B_PROLOGUE_CAST_TYPE=b_prologue_cast_type, + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + +def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..8d547532c606aae2ade4058134207ad4f1bb5c51 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,235 @@ +import functools + +import torch + +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid, mm_options + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, 
BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +@functools.lru_cache(None) +def mm_configs(): + import triton + + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform + mm_triton_configs = [ + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 3, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 16, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, + "num_stages": 1, + "num_warps": 8, + "cond": torch.version.hip is None, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, + "num_stages": 1, + "num_warps": 2, + "cond": True, + }, + ] + + # Filter out configs in which cond evaluates to true + # On ROCm convert num_stages to 1 as pipelining provides no benefit + if torch.version.hip: + filtered_configs = [ + triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"]) + for c in mm_triton_configs + if c["cond"] + ] + else: + filtered_configs = [ + triton.Config( + c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"] + ) + for c in mm_triton_configs + if c["cond"] + ] + + return filtered_configs + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + 
or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/openai/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout1): + for config in mm_configs(): + # see https://github.com/openai/triton/issues/1298 + # BLOCK_K = K causes llvm error + if config.kwargs["BLOCK_K"] < k1: + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py b/MLPY/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf03a9ffb0b5127cb480fccfa6255893f0efc40e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py @@ -0,0 +1,82 @@ +import logging +from typing import List + +from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate +from .mm_common import mm_args, mm_configs, mm_grid, mm_options + +log = logging.getLogger(__name__) + +uint4x2_mixed_mm_template = TritonTemplate( + name="uint4x2_mixed_mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn) + b_shifts = 4*(rk%2) + b_subs = 8*(1-(rk%2)) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
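+        # Unpack the int4 values: each byte of B packs two 4-bit entries, so
+        # even k indices shift by 0 (low nibble) and odd ones by 4 (high
+        # nibble); the 0xF mask keeps the nibble and subtracting 8 recenters
+        # the unsigned range [0, 15] onto the signed range [-8, 7].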
+ b = ((b >> b_shifts[:, None]) & 0xF) - 8 + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K//2 * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True) + choices: List[ChoiceCaller] = [] + b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") + for config in mm_configs(m, n, k): + uint4x2_mixed_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout) diff --git a/MLPY/Lib/site-packages/torch/_inductor/lowering.py b/MLPY/Lib/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..8987a72cabfde8f1c60d06acf4aaae184240a4af --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,6006 @@ +import functools +import itertools +import logging +import os +import warnings +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional, + triton_kernel_wrapper_mutation, +) +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from .._dynamo.utils import import_submodule + +from . 
import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_pointwise_use, + pad_listlike, + parallel_num_threads, + sympy_product, +) +from .virtualized import ops, V + +log = logging.getLogger(__name__) +lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +layout_constraints: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +fallbacks: Set[torch._ops.OpOverload] = set() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs: Set[torch._ops.OpOverload] = set() +foreach_ops: Set[torch._ops.OpOverload] = set() +inplace_foreach_ops: Set[torch._ops.OpOverload] = set() +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = dict() +quantized_decomposed = torch.ops.quantized_decomposed + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, tuple, set)): + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + needs_realized_inputs.add(getattr(fn, overload)) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + layout_constraints[getattr(fn, overload)] = constraint + else: + layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten.upsample_bicubic2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? 
+ # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Expr)): + return inp + else: + assert hasattr(inp, "get_dtype") + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool): + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + if (type_promotion_kind or convert_input_to_bool) and indices: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME that's a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Expr)) or hasattr(a, "dtype") + ] + dtype = get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind + ) + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(arg.value, dtype, args[indices[0]].get_device()) + else: + return arg + + args = [promote(a) for a in args] + if broadcast and indices: + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + + return args + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. 
+ + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]] = list(args) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = args[0] + + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + # kwargs tensors not supported yet unless it's a fallback op + assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all( + fn in fallbacks for fn in aten_fn + ) + + args = transform_args( + args, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, +): + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. 
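+
+    For example (sketch): broadcasting (s0, 1) with (1, s1) yields (s0, s1)
+    without adding guards, while broadcasting (s0,) with (s1,) installs a
+    guard that s0 == s1 and keeps the simpler expression.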
+ """ + output = [] + for x, y in itertools.zip_longest( + reversed(a), reversed(b), fillvalue=sympy.Integer(1) + ): + if y == 1: + output.append(x) + elif x == 1: + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert ( + override_return_dtype is None or type_promotion_kind is None + ), "only one of override_return_dtype or type_promotion_kind may be given" + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Expr): + return ir.IndexingConstant(x, dtype, decode_device(None)) + else: + return ir.Constant(x, dtype, decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) + ) + ) + elif isinstance(x, sympy.Expr): + out.append( + ExpandView.create( + IndexingConstant(x, ex.get_dtype(), ex.get_device()), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_cuda_float64=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: List[TensorBox], alpha=None): + if triton_fallback is not None and any(map(is_triton, inputs)): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_cuda = decode_device(inputs[0].get_device()).type == "cuda" + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif override_fn_when_cuda_float64 and is_cuda and dtype == torch.float64: + return override_fn_when_cuda_float64(*[load(index) for load in loaders]) + else: + return fn(*[load(index) for load in loaders]) + + if not override_device: + device = None + for i in inputs: + if i.get_device().type == "cuda": + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def 
inner(*inputs: List[List[TensorBox]], alpha=1): + # group by device, whether any of the inputs are dynamic, and whether their types match + # (proxy for type promotion) + def group_args(arg_pairs): + out = defaultdict(list) + for i, args in enumerate(arg_pairs): + use_foreach = not is_dynamic(*args) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert ( + device is not None + ), "foreach op should have at least one tensor arg" + out[(device, use_foreach)].append((i, args)) + return out + + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + realize_outputs = True + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert ( + a_list_input is not None + ), "at least one input must be a list to a foreach op" + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + buffer_list = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if device.type == "cuda" and use_foreach and realize_outputs: + buffer_list.append(output.realize()) + + if buffer_list: + V.graph.register_list(buffer_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decompostion doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + raise NotImplementedError( + f"bitcast {x_dtype} to different bitwidth type {dtype} is not supported yet." + ) + + def _to_dtype_bitcast(x): + # Because we may promote tensor type from float16 or bfloat16 + # to float, we will need to pass the original src dtype (i.e. 
x_dtype), + # which is used for correctly constructing type conversion before bitcast, + # which requires the bitwidth of the input tensor type is the same as the + # target type. + return ops.to_dtype_bitcast(x, dtype, x_dtype) + + return make_pointwise(_to_dtype_bitcast, override_return_dtype=dtype)(x) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype, copy=True) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device): + return to_device(x, device, copy=True) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = canonicalize_dims(len(x.get_size()), dim) + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) 
== tuple(sizes): + return x + + if not any(V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not any( + V.graph.sizevars.shape_env.is_unbacked_symint(s) for s in sizes + ): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return expand(x, new_size) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.Integer(0) + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views 
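+        # (as_strided is defined relative to the underlying storage, so unwrap
+        # back to the base buffer; it is then realized and reinterpreted with
+        # the requested size/stride via ir.ReinterpretView below)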
+ x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(storage, new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + idx_load[dim] -= inputs_ranges[i][0] + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, 
quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + if all(input.get_dtype() in [torch.int8, torch.uint8] for input in inputs): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. 
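# Illustrative aside, not part of the vendored file: the per-channel quantize and
# dequantize lowerings above compute, with one scale/zero_point per slice along
# `axis`, q = clamp(round(x / scale) + zero_point, quant_min, quant_max) and
# x_hat = (q - zero_point) * scale. A small eager sketch with arbitrary values:
import torch

x = torch.randn(4, 3)                                 # axis = 1: one scale per column
scales = torch.tensor([0.05, 0.10, 0.20])
zero_points = torch.tensor([0.0, 1.0, -2.0])
q = torch.clamp(torch.round(x / scales) + zero_points, -128, 127).to(torch.int8)
x_hat = (q.to(torch.float32) - zero_points) * scales  # dequantized approximation of x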
+ for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reducutions into computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. 
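# Illustrative aside, not part of the vendored file: pointwise_cat above turns the
# concat into a single pointwise kernel by comparing each output index against the
# cumulative input ranges and loading the matching input with a shifted index.
# The same selection logic written with eager ops for two 1-D inputs:
import torch

a, b = torch.arange(4.0), torch.arange(2.0)
idx = torch.arange(a.numel() + b.numel())
in_a = idx < a.numel()
out = torch.where(
    in_a,
    a[idx.clamp(max=a.numel() - 1)],       # index clamped into a's range
    b[(idx - a.numel()).clamp(min=0)],     # index shifted into b's range
)
assert torch.equal(out, torch.cat([a, b]))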
+ if inputs[0].get_device().type == "cpu" or fusable_reduction: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount() + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users) + all_pointwise_inputs = all(should_lower_cat_input(inp) for inp in inputs) + any_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs) + + if all_pointwise_inputs or (any_pointwise_inputs and pointwise_uses): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = max(min(original_shape[dim1] + offset, original_shape[dim2]), 0) + else: + diag_size = max(min(original_shape[dim1], original_shape[dim2] - offset), 0) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + if isinstance(sizes, sympy.Expr): + # 
TODO: We don't have to guard on sizes per se, but the number + # of splits must stay constant + sizes = V.graph.sizevars.evaluate_static_shape(sizes) + if isinstance(sizes, (int, sympy.Integer)): + sizes = [sizes] * ((x_size + sizes - 1) // sizes) + result = [] + start = 0 + for size in sizes: + end = start + size + result.append(slice_(x, dim, start, end)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [] + for i in range(x_size): + result.append(select(x, dim, i)) + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint(dim_size) > 0: + x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.Integer(1)) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + assert isinstance(dim, int) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv2d_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + 
unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm + ): + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): + return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr)) + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + accum: TensorBox, + accum_scale, + accum_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, 
torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype convertion here. + accum = to_dtype(accum, output_dtype) + return TensorBox.create( + ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.QLinearPointwisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + if torch._C.has_mkl: + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: TensorBox, + batch_size, + ): + result = TensorBox.create( + ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) + else: + pass + + +register_onednn_fusion_ops() + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + return pytree.tree_map( + TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + return handler + + +@functools.lru_cache(None) +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. +def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + if parent and parent.target in ( + torch.ops.aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ): + return False + _warn_complex_not_supported() + return True + return False + + +def unsupported_output_tensor(t: torch._subclasses.FakeTensor, parent=None): + "Do not support writing tensor but can read from it" + if unsupported_input_tensor(t, parent): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. 
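# Illustrative aside, not part of the vendored file: the skip check below walks every
# FakeTensor leaf in node.meta["val"] (which may be a nested structure) and falls
# back to eager when any leaf is unsupported, e.g. complex. A minimal sketch of the
# same traversal over an arbitrary pytree of tensors:
import torch
import torch.utils._pytree as pytree

meta_val = {"out": (torch.empty(2, dtype=torch.complex64), torch.empty(3))}
has_unsupported = any(
    isinstance(t, torch.Tensor) and t.is_complex()
    for t in pytree.tree_leaves(meta_val)
)
assert has_unsupported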
+ if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(node, parent, is_output): + if not isinstance(node, torch.fx.Node): + return False + + if "val" not in node.meta: + return False + + for meta in pytree.tree_leaves(node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, parent): + return True + else: + if unsupported_input_tensor(meta, parent): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, node, is_output=False): + return True + + return check_skip_condition(node, node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True): + assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." + " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. 
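# Illustrative aside, not part of the vendored file: philox_rand_offset above returns
# the number of elements drawn, so each call advances the Philox counter by numel.
# A sketch of that bookkeeping with a hypothetical next_offset helper:
import torch

def next_offset(shape, current):
    numel = 1
    for s in shape:
        numel *= s
    return current + torch.tensor(numel, dtype=torch.int64)

offset = torch.tensor(0, dtype=torch.int64)
offset = next_offset((4, 3), offset)  # counter advanced by 12 random values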
+ # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + x.realize() + ir.InplaceBernoulliFallback(x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError() + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, 
decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + low, + high, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +@register_lowering(aten.bucketize, type_promotion_kind=None) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not (is_triton(input) and is_triton(boundaries)): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). 
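# Illustrative aside, not part of the vendored file: ops.bucketize mirrors eager
# torch.bucketize, which binary-searches each value in the 1-D sorted boundaries
# tensor and returns its insertion index (right controls tie handling):
import torch

boundaries = torch.tensor([1.0, 3.0, 5.0])
vals = torch.tensor([0.5, 3.0, 6.0])
assert torch.bucketize(vals, boundaries).tolist() == [0, 1, 3]
assert torch.bucketize(vals, boundaries, right=True).tolist() == [0, 2, 3]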
+ boundaries.realize() + boundaries_size = boundaries.get_size()[0] + boundaries_loader = boundaries.make_loader() + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + boundaries.get_name(), + boundaries_size, + index_dtype, + right, + ) + + return indices + + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = { + "torchvision::roi_align", +} + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + if not meta_val.is_cuda: + return arg + + stride_order = ir.get_stride_order(meta_val.stride()) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
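# Illustrative aside, not part of the vendored file: the constraint below accepts a
# realized tensor only if its last dimension is dense (stride 1) and every leading
# stride is a multiple of ALIGNMENT. A sketch of that check on an eager tensor,
# using a hypothetical helper and assuming ALIGNMENT = 8 as set below:
import torch

def is_aligned_for_sdpa(t, alignment=8):
    *lead, last = t.stride()
    return last == 1 and all(s % alignment == 0 for s in lead)

x = torch.empty(2, 4, 32, 64)   # contiguous strides (8192, 2048, 64, 1)
assert is_aligned_for_sdpa(x)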
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + def is_aligned_realized_tensor(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + try: + arg.get_stride() + if is_aligned_realized_tensor(arg): + return arg + except AttributeError: + pass + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return arg + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten.index_reduce) # @pearu +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten.avg_pool3d) # @isuruf +make_fallback(aten.fractional_max_pool3d) # @isuruf +make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) +make_fallback(aten.cummax) # @isuruf +make_fallback(aten.cummin) # @isuruf + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? +make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl) + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) +# See resize_storage_bytes +make_fallback(aten.resize) +make_fallback(aten.resize_) +make_fallback(aten.resize_as) +make_fallback(aten.resize_as_) + + +# 2) Medium +make_fallback(aten.max_unpool2d) +make_fallback(aten.max_unpool3d) +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten.addmv, warn=False) +make_fallback(aten._addmm_activation, warn=False) + +# Need templated kernel. 
Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_dense_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.avg_pool3d_backward) +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + +# Implmented / Half implemented +# Scans. 
Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) +make_fallback(aten._scaled_mm.default, constrain_to_fx_strides) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. +@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(x, layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + 
x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Expr): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + 
dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + sym = V.graph.current_node.meta["val"].node.expr + buffer = ir.DynamicScalar(sym, data) + buffer.name = V.graph.register_buffer(buffer) + return sym + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + buffer = ir.AssertScalar(data, msg) + # This buffer isn't used by anyone (it returns None), so we must explicitly register it + buffer.name = V.graph.register_buffer(buffer) + return buffer + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Expr): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! 
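# Illustrative aside, not part of the vendored file: small constant lists (at most 8
# scalars) are inlined by the tensor lowering above as a binary search of nested
# where(index < mid, ...) selects instead of materializing a constant buffer. The
# same selection tree written with eager ops over an index tensor:
import torch

data = [3.0, 1.0, 4.0, 1.0, 5.0]

def select(index, start, end):
    if end - start == 1:
        return torch.full_like(index, data[start], dtype=torch.float32)
    mid = (end - start) // 2 + start
    return torch.where(index < mid, select(index, start, mid), select(index, mid, end))

idx = torch.arange(len(data))
assert torch.equal(select(idx, 0, len(data)), torch.tensor(data))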
+ for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). + """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, device, dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data.ranges = [0] * len(size) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if 
device is None: + device = x.get_device() + return empty_strided( + size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + assert index.get_dtype() == torch.int64 + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + if len(idx) != 0: + idx[dim] = ops.indirect_indexing(index_loader(idx), size[dim]) + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. 
In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check=check, + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. 
Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put) +def index_put(x, indices, values, accumulate=False): + return index_put_(clone(x), indices, values, accumulate) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_(clone(x), indices, values, accumulate, check=False) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + deterministic = torch.are_deterministic_algorithms_enabled() + if is_triton(values) and (accumulate or deterministic): + msg = ( + "index put with accumulate." + if not deterministic + else "deterministic index put." + ) + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=True) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=False) + + +def needs_fallback_due_to_atomic_add_limitations(dtype): + # tl.atomic_add does NOT support the following types + return dtype in {torch.int64, torch.bool, torch.bfloat16} + + +def index_put_impl_(self, indices, values, accumulate, check): + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in {torch.bool, torch.uint8} + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert 
isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +@register_lowering( + inductor_prims.masked_scatter_with_index, type_promotion_kind=None, broadcast=False +) +def masked_scatter_with_index(self, mask, source_idx, source): + self_flat, mask_flat, source_flat = (view(x, (-1,)) for x in (self, mask, source)) + + assert self.get_size() == mask.get_size() + assert mask.get_dtype() in {torch.bool, torch.uint8} + + self_loader = self_flat.make_loader() + mask_loader = mask_flat.make_loader() + source_idx_loader = source_idx.make_loader() + source_loader = source_flat.make_loader() + source_numel = source.get_numel() + + def inner_fn(idx): + self_val = self_loader(idx) + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + + def load_source_val(): + source_idx_val = source_idx_loader(idx) + i = ops.indirect_indexing(source_idx_val, source_numel) + return source_loader([i]) + + source_val = ops.masked(mask_val, load_source_val, 0) + return ops.where(mask_val, source_val, self_val) + + result_flat = Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=self_flat.get_size(), + ) + return view(result_flat, self.get_size()) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + fn, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + reduce_ty = "add" if fn == "aten.scatter_" else "sum" + if ( + reduce not in {None, reduce_ty} + or ( + isinstance(src, TensorBox) + and src.get_device().type == torch.device("cuda").type + and needs_fallback_due_to_atomic_add_limitations(src.get_dtype()) + ) + or ( + fn == "aten.scatter_reduce_" + and reduce == "sum" + and isinstance(src, TensorBox) + and src.get_device() == torch.device("cpu") + and config.cpp.fallback_scatter_reduce_sum + and (config.cpp.dynamic_threads or parallel_num_threads() != 1) + ) + or (reduce == reduce_ty and self.get_dtype() in {torch.bool, torch.int64}) + or torch.are_deterministic_algorithms_enabled() + ): + ir.ScatterFallback( + V.graph.current_node.target, + fn, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return 
self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in {None, "add", "multiply"} + + fallback_result = scatter_fallback( + "aten.scatter_", self, dim, index, src, reduce=reduce + ) + + if fallback_result: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} + + fallback_result = scatter_fallback( + "aten.scatter_reduce_", + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim] + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + None, + ir.MutationLayout(self), + scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + + if ndim == 0: + self = view(self, []) + 
return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: Tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(aten.upsample_bicubic2d.default) +def upsample_bicubic2d_default( + x, + output_size, + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + x.realize_hint() + x_loader = x.make_loader() + + N, C, iH, iW = x.get_size() + oH, oW = output_size + + iH = V.graph.sizevars.evaluate_static_shape(iH) + iW = V.graph.sizevars.evaluate_static_shape(iW) + + def get_int_dtype(maxval): + if maxval > torch.iinfo(torch.int32).max: + return torch.int64 + return torch.int32 + + def compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1) / 
(out_size - 1) if out_size > 1 else 0 + else: + return 1 / scale if scale is not None and scale > 0 else in_size / out_size + + def compute_source_index(scale, dst_index, align_corners): + dst_index_ie = ops.index_expr(dst_index, torch.float32) + scale = ops.constant(scale, torch.float32) + if align_corners: + return ops.mul(scale, dst_index_ie) + else: + half = ops.constant(0.5, torch.float32) + return scale * (dst_index_ie + half) - half + + def cubic_convolution1(x, A): + _Ap2, _Ap3, _1 = _create_constants(A + 2, A + 3, 1, dtype=torch.float32) + return (_Ap2 * x - _Ap3) * x * x + _1 + + def cubic_convolution2(x, A): + _A, _4A, _5A, _8A = _create_constants( + A, 4 * A, 5 * A, 8 * A, dtype=torch.float32 + ) + return ((_A * x - _5A) * x + _8A) * x - _4A + + def get_cubic_upsample_coefficients(t): + A = -0.75 + _1 = ops.constant(1.0, torch.float32) + c0 = cubic_convolution2(ops.add(t, _1), A) + c1 = cubic_convolution1(t, A) + + x2 = ops.sub(_1, t) + c2 = cubic_convolution1(x2, A) + c3 = cubic_convolution2(ops.add(x2, _1), A) + return (c0, c1, c2, c3) + + def cubic_interp1d(xs, t): + cs = get_cubic_upsample_coefficients(t) + # dot product between xs and cs + return xs[0] * cs[0] + xs[1] * cs[1] + xs[2] * cs[2] + xs[3] * cs[3] + + height_scale = compute_scale(iH, oH, align_corners, scales_h) + width_scale = compute_scale(iW, oW, align_corners, scales_h) + + def clamp(v, min, max): + return ops.maximum(min, ops.minimum(max, v)) + + def fn(idx): + n, c, oy, ox = idx + + real_x = compute_source_index(width_scale, ox, align_corners) + in_x = ops.floor(real_x) + t_x = ops.sub(real_x, in_x) + + real_y = compute_source_index(height_scale, oy, align_corners) + in_y = ops.floor(real_y) + t_y = ops.sub(real_y, in_y) + + def load_bounded(fy, fx): + # TODO(Lezcano) Here we may not need to set-up a device_size + _0 = ops.constant(0, torch.int32) + iHm1 = ops.constant(iH - 1, torch.int32) + iWm1 = ops.constant(iW - 1, torch.int32) + iy = ops.indirect_indexing(clamp(fy, _0, iHm1), iH, check=False) + ix = ops.indirect_indexing(clamp(fx, _0, iWm1), iW, check=False) + return x_loader([n, c, iy, ix]) + + iy = ops.to_dtype(in_y, get_int_dtype(iH + 1)) + ix = ops.to_dtype(in_x, get_int_dtype(iW + 1)) + iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)) + ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)) + + def get_x_interp(y): + coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs) + return cubic_interp1d(coeffs_x, t_x) + + coeffs_y = tuple(get_x_interp(y) for y in iys_ofs) + return cubic_interp1d(coeffs_y, t_y) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)], + ) + + +@register_lowering(aten.reflection_pad1d_backward) +@register_lowering(aten.reflection_pad2d_backward) +@register_lowering(aten.reflection_pad3d_backward) +def _reflection_padnd_backward(grad_output, x, padding): + dim = len(padding) // 2 + + dhw = [h - 1 for h in x.get_size()[-dim:]] + grad_loader = grad_output.make_loader() + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + def fn(idx): + b = idx[:-dim] + xyz = idx[-dim:] + + def load_from_output(x): + return grad_loader([*b, *x]) + + def index_range_condition(index_range): + i, lb, ub = index_range + i = ops.index_expr(i, torch.int32) + lb = ops.index_expr(lb, torch.int64) + ub = ops.index_expr(ub, torch.int64) + return ops.and_(ops.ge(i, lb), ops.le(i, ub)) + + # Areas after reflection: + # 
+ # top-left | top | top-right + # ----------------------------------------- + # left | center | right + # ----------------------------------------- + # bottom-left | bottom | bottom-right + # + # The center area is the original matrix. Other areas are reflections. + + center = [xyz[i] + padding_left[i] for i in range(dim)] + left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] + right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] + + # Accumulate gradients from different areas + # If some of the padding is negative, center load is not always valid + range_c = [ + (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) + for i in range(dim) + ] + cond = functools.reduce( + ops.and_, [index_range_condition(range_c[i]) for i in range(dim)] + ) + grad = ops.masked(cond, lambda: load_from_output(center), 0.0) + + def accumulate(grad, out, index_ranges): + # If the upper bound is less than the lower bound, we can get rid of one accumulation. + # This happens when the padding size is zero. + for i in range(dim): + upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] + if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: + return grad + cond = functools.reduce( + ops.and_, + [index_range_condition(index_range) for index_range in index_ranges], + ) + g = ops.masked(cond, lambda: load_from_output(out), 0.0) + return ops.add(grad, g) + + for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): + if area == tuple([0] * dim): + # center, this is already done. + continue + + outs = [] + index_ranges = [] + + for i in range(dim): + if area[i] == 0: + out = center[i] + index_range = range_c[i] + elif area[i] == -1: + out = left_reflect[i] + index_range = (xyz[i], 1, padding_left[i]) + elif area[i] == 1: + out = right_reflect[i] + index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) + + outs.append(out) # type: ignore[possibly-undefined] + index_ranges.append(index_range) # type: ignore[possibly-undefined] + + grad = accumulate(grad, outs, index_ranges) + + return grad + + return Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in 
zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition_2d(x, fill_value, padding=None, pad_fill_value=1.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + padding_h = padding[0] if padding else 0 + padding_w = padding[1] if padding else 0 + + def load(index): + *prefix, ih, iw = index + + mask = ops.and_( + range_mask(ih, h + padding_h, -padding_h), + range_mask(iw, w + padding_w, -padding_w), + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition_2d(x, pad_fill_value)( + [*prefix, ih, iw] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, ih, iw]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): + x_out = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] + ) + + if ceil_mode: + x_alt = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +fallback_max_pool2d_with_indices = fallback_handler( + aten.max_pool2d_with_indices.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) +def max_pool2d_with_indices( + x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + if padding[0] or padding[1] or ceil_mode1 or 
ceil_mode2: + x_loader = constant_boundary_condition_2d(x, float("-inf")) + else: + x_loader = x.make_loader() + + new_size = list(batch) + [h_out, w_out] + window_size = kernel_size[0] * kernel_size[1] + + if window_size > 25 or any(d != 1 for d in dilation): + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode + ) + + def fn(idx, return_index): + *prefix, bh, bw = idx + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + ih = bh * stride[0] + ih - padding[0] + iw = bw * stride[1] + iw - padding[1] + val = x_loader([*prefix, ih, iw]) + if return_index: + index = ops.index_expr(ih * w + iw, torch.int64) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + r1 = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + r2 = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return r1, r2 + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + try: + gO_stride = grad_output.get_stride() + except AttributeError: + # some classes don't have `get_stride` + # TODO will need a better way of determining if inputs are channels-last + gO_stride = None + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + try: + x_stride = x.get_stride() + except AttributeError: + x_stride = None + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + autotune = ( + config.coordinate_descent_tuning + or config.max_autotune + or config.max_autotune_pointwise + ) + if any(d != 1 for d in dilation) or (is_channels_last and not autotune): + # don't codegen channels-last when autotune is not enabled, it's very slow + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices.realize_hint() + + *batch, height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + 
grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + [ + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ] + ) + w_window_size = max( + [ + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ] + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + return Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def pad_adaptive_loader(x, pad_val=0.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns): + h_start_index_fn, w_start_index_fn = start_index_fns + h_end_index_fn, w_end_index_fn = end_index_fns + + def fn_sum(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = 
w_end_index_fn(bw) + + total = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + return fn_sum + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + + fn_sum = _adaptive_pooling_idx_sum( + [h_kernel_max, w_kernel_max], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader): + # NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d + # Look into refactoring/deduplication after #116418 is merged. 
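+    # Adaptive pooling maps output index i along a dimension with input size
+    # in_sz and output size out_sz to the input window
+    #   [floor(i * in_sz / out_sz), ceil((i + 1) * in_sz / out_sz))
+    # e.g. in_sz=5, out_sz=3 gives the (possibly overlapping) windows
+    # [0, 2), [1, 4), [3, 5). start_index/end_index below express exactly this
+    # via FloorDiv, so the bounds remain valid sympy expressions for symbolic shapes.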
+ h_in, w_in = in_sizes + h_out, w_out = out_sizes + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + def fn_max(idx): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + if return_index: + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + return fn_max + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return max_pool2d_with_indices(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
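+        # The window loops in _adaptive_pooling_idx_max are unrolled at lowering
+        # time (one load per (ih, iw) pair), so windows larger than 25 elements
+        # (i.e. beyond 5x5) would generate a very large pointwise body; fall back
+        # to the ATen kernel instead.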
+ return fallback_adaptive_max_pool2d(x, output_size) + + inner_func_max_val = _adaptive_pooling_idx_max( + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + return_index=False, + loader=pad_adaptive_loader(x, float("-inf")), + ) + + inner_func_max_idx = _adaptive_pooling_idx_max( + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + return_index=True, + loader=pad_adaptive_loader(x, float("-inf")), + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_func_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_func_max_idx, + ranges=new_size, + ) + return rv, ri + + +fallback_fractional_max_pool2d = fallback_handler( + aten.fractional_max_pool2d.default, add_to_fallback_set=False +) + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + alpha = (in_sz - kernel_sz) / (out_sz - 1) + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + alpha_expr = ops.index_expr(alpha, samples.get_dtype()) + seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( + sample * alpha_expr + ) + seq_i = ops.to_dtype(seq_i, torch.int64) + + mask = ops.lt( + i_expr, + ops.index_expr(out_sz - 1, torch.int64), + ) + return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + x.realize_hint() + *batch, inp_h, inp_w = x.get_size() + kernel_h, kernel_w = kernel_size + h_out, w_out = output_size + + if kernel_h * kernel_w >= 25: + return fallback_fractional_max_pool2d( + x, kernel_size, output_size, random_samples + ) + + gen_offsets_for_dim = functools.partial( + _fractional_pooling_offsets, + samples=random_samples, + in_sz=[inp_h, inp_w], + out_sz=output_size, + kernel_sz=kernel_size, + ) + + h_index_fn = gen_offsets_for_dim(dim=0) + w_index_fn = gen_offsets_for_dim(dim=1) + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + + h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) + w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) + if return_index: + index = ops.index_expr( + (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 + ) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where( + ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex + ) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + new_size = list(batch) + [h_out, w_out] + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return rv, ri + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, 
scales_h=None, scales_w=None +): + x.realize_hint() + + *batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, out_dim) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h) + h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h) + + w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w) + w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w) + + fn_sum = _adaptive_pooling_idx_sum( + [h_kernel_max, w_kernel_max], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: + x_loader = constant_boundary_condition_2d(x, 0.0) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = kernel_size[0] * kernel_size[1] + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + *prefix, bh, bw = idx + total = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + ih = bh * stride[0] + ih - padding[0] + iw = bw * stride[1] + iw - padding[1] + val = loader([*prefix, ih, iw]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + if divisor_override: + scale = 1 / divisor_override + else: + scale = 1.0 / (kernel_size[0] * kernel_size[1]) + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + ones_loader = constant_boundary_condition_2d( + ones_like(x), 0.0, padding if count_include_pad else None + ) + + def fn(idx): + # TODO(jansel): optimize to do `int(x 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and 
axis[i] == 0) + assert len(set(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.Integer(1) + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, Reduction + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling 
welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if use_two_step_variance(x, axis=axis, keepdim=keepdim) + else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = 
next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayout.realize_into(val, changed_data, unsafe_alias=unsafe_alias) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. 
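+# For example, with integer operands truncation division rounds toward zero, so
+# -7 / 2 yields -3 here, whereas aten.div with rounding_mode="floor" (handled by
+# div_mode above) yields -7 // 2 == -4; with floating-point operands this is
+# true division, so -7.0 / 2.0 == -3.5.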
+@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + return make_pointwise(_rsqrt)(x) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=ops.add, init=0) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=ops.mul, init=1) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a, b): + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + result = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper, init=float("-inf")) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", 
override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + 
type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.nextafter) + +from .codegen.common import pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) +register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) 
+foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. +def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, 
aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + tracing_context = torch._guards.TracingContext.try_get() + assert ( + tracing_context is None or a in tracing_context.fake_mode.shape_env.var_to_range + ) + return a + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_(*, kernel_idx, grid, kwargs): + ir.UserDefinedTritonKernel(kernel_idx=kernel_idx, grid=grid, kernel_args=kwargs) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(triton_kernel_wrapper_functional) +def triton_kernel_wrap(*, kernel_idx, grid, kwargs, tensors_to_clone): + new_kwargs = {} + for name, value in kwargs.items(): + if isinstance(value, ir.TensorBox): + x = value.data + has_non_rv_views = False + while isinstance(x, ir.BaseView): + if not isinstance(x, ir.ReinterpretView): + has_non_rv_views = True + break + x = x.data + if has_non_rv_views: + # we realize the inputs wrapped into any view which is not + # ReinterpretView to convert them into ReinterpretView during + # realization; all views being ReinterpretView is assumed by + # the downstream code (e.g., preserving ReinterpretView in + # cloning; layout should be available in mutation marking) + value = ir.TensorBox(ir.ExternKernel.realize_input(value)) + if name in tensors_to_clone: + value = clone_preserve_reinterpret_view(value) + new_kwargs[name] = value + + return triton_kernel_wrap_(kernel_idx=kernel_idx, grid=grid, kwargs=new_kwargs) + + +@register_lowering(torch.ops.higher_order.cond) +def cond(pred, true_fn, false_fn, operands): + if is_triton(pred) or any(map(is_triton, operands)): + msg = "control flow operator: torch.cond." 
+ if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +try: + import torch.distributed._functional_collectives + + c10d_functional = torch.ops.c10d_functional + + @register_lowering(c10d_functional.wait_tensor) + def wait(input): + return TensorBox.create(ir.Wait.create(input)) + + @register_lowering(c10d_functional.broadcast) + def broadcast(input, src, tag, ranks, group_size): + return ir.Broadcast.create(input, src, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_reduce) + def allreduce(input, reduce_op, tag, ranks, group_size): + return ir.AllReduce.create(input, reduce_op, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_gather_into_tensor) + def all_gather_into_tensor(shard, tag, ranks, group_size): + return TensorBox.create( + ir.AllGatherIntoTensor.create( + ir.ExternKernel.require_contiguous(shard), tag, ranks, group_size + ) + ) + + @register_lowering(c10d_functional.reduce_scatter_tensor) + def reduce_scatter_tensor(input, reduce_op, tag, ranks, group_size): + return TensorBox.create( + ir.ReduceScatterTensor.create(input, reduce_op, tag, ranks, group_size) + ) + + @register_lowering(c10d_functional.all_reduce_coalesced) + def all_reduce_coalesced(input, reduce_op, tag, ranks, group_size): + return ir.AllReduceCoalesced.create(input, reduce_op, tag, ranks, group_size) + + @register_lowering(c10d_functional.all_gather_into_tensor_coalesced) + def all_gather_into_tensor_coalesced(self, tag, ranks, group_size): + result = ir.AllGatherIntoTensorCoalesced.create(self, tag, ranks, group_size) + return list(map(TensorBox.create, result)) + + @register_lowering(c10d_functional.reduce_scatter_tensor_coalesced) + def reduce_scatter_tensor_coalesced(self, reduceOp, tag, ranks, group_size): + result = ir.ReduceScatterTensorCoalesced.create( + self, reduceOp, tag, ranks, group_size + ) + return list(map(TensorBox.create, result)) + + @register_lowering(c10d_functional.all_to_all_single) + def all_to_all_single( + self, output_split_sizes, input_split_sizes, tag, ranks, group_size + ): + return TensorBox.create( + ir.AllToAllSingle.create( + self, output_split_sizes, input_split_sizes, tag, ranks, group_size + ) + ) + + _c10d_functional = torch.ops._c10d_functional + + @register_lowering(_c10d_functional.all_reduce) + def _all_reduce(inp, reduce_op, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_) + def _all_reduce_(inp, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_lowering(_c10d_functional.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_lowering(_c10d_functional.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + 
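# --- Editor's note: illustrative sketch only, not part of the vendored file above.
# The _c10d_functional lowerings follow one pattern: the out-of-place variant
# clones its input and then reuses the in-place collective kernel on the
# clone, so only the in-place path needs a kernel implementation. The toy
# functions below (toy_all_reduce, toy_all_reduce_) are hypothetical stand-ins
# that mirror that flow, not the real ir._CollectiveKernel API.
import copy

def toy_all_reduce_(buf, op="sum"):
    # stand-in for ir._CollectiveKernel.create_inplace: mutates buf in place
    if op == "sum":
        buf[:] = [sum(buf)] * len(buf)
    return buf

def toy_all_reduce(buf, op="sum"):
    out = copy.copy(buf)      # the clone(inp) step in the out-of-place lowering
    toy_all_reduce_(out, op)  # reuse the in-place "kernel" on the clone
    return out

data = [1, 2, 3]
assert toy_all_reduce(data) == [6, 6, 6]
assert data == [1, 2, 3]  # the original buffer is left untouched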
@register_lowering(_c10d_functional.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.all_gather_into_tensor_coalesced) + def _all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + ) + + @register_lowering(_c10d_functional.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_lowering(_c10d_functional.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + _c10d_functional.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + ) + + @register_lowering(_c10d_functional.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.broadcast_) + def _broadcast_(inp, src, group_name): + ir._CollectiveKernel.create_inplace( + _c10d_functional.broadcast_.default, inp, src, group_name + ) + return inp + + @register_lowering(_c10d_functional.wait_tensor) + def _wait_tensor(inp): + ir._WaitKernel.create_wait(_c10d_functional.wait_tensor.default, inp) + return inp + +except ImportError: + log.info( + "Inductor support for distributed collectives depends on building torch.distributed" + ) + +# populate lowerings defined in kernel/* +from . import kernel + +import_submodule(kernel) + +from . 
import quantized_lowerings + +quantized_lowerings.register_quantized_ops() diff --git a/MLPY/Lib/site-packages/torch/_inductor/metrics.py b/MLPY/Lib/site-packages/torch/_inductor/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..915b602d6f3b197cc297b7c1814855a023dfd3b4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/metrics.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import csv +import inspect +import os +import re +from dataclasses import dataclass +from functools import lru_cache + +from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union + +from torch._inductor import config +from torch._inductor.utils import get_benchmark_name + +# Prevent circular import +if TYPE_CHECKING: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ExternKernelSchedulerNode, + NopKernelSchedulerNode, + SchedulerNode, + ) + +# counter for tracking how many kernels have been generated +generated_kernel_count = 0 +generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem: List[ + Tuple[ + Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode], + int, + ] +] = [] +node_runtimes: List[Tuple[BaseSchedulerNode, float]] = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 + +# counters for tracking to_dtype inserted +cpp_to_dtype_count = 0 + +# counters for tracking cpp_wrapper disabled +disable_cpp_wrapper = 0 + + +# reset all counters +def reset(): + global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + global cpp_to_dtype_count + global disable_cpp_wrapper + + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + node_runtimes.clear() + ir_nodes_pre_fusion = 0 + cpp_to_dtype_count = 0 + disable_cpp_wrapper = 0 + + +@dataclass +class CachedMetricsDeltas: + """ + The subset of metrics we want update across cache hits, e.g., the + FxGraphCache. + """ + + generated_kernel_count: int + generated_cpp_vec_kernel_count: int + ir_nodes_pre_fusion: int + cpp_to_dtype_count: int + + +class CachedMetricsHelper: + """ + A helper class to help calculate and apply counter deltas for those + metrics we want to save with cache entries (e.g., FxGraphCache) and + apply on a cache hit. 
+ """ + + def __init__(self): + global generated_kernel_count + global generated_cpp_vec_kernel_count + global ir_nodes_pre_fusion + global cpp_to_dtype_count + + self.generated_kernel_count = generated_kernel_count + self.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + self.ir_nodes_pre_fusion = ir_nodes_pre_fusion + self.cpp_to_dtype_count = cpp_to_dtype_count + + def get_deltas(self) -> CachedMetricsDeltas: + global generated_kernel_count + global generated_cpp_vec_kernel_count + global ir_nodes_pre_fusion + global cpp_to_dtype_count + + return CachedMetricsDeltas( + generated_kernel_count - self.generated_kernel_count, + generated_cpp_vec_kernel_count - self.generated_cpp_vec_kernel_count, + ir_nodes_pre_fusion - self.ir_nodes_pre_fusion, + cpp_to_dtype_count - self.cpp_to_dtype_count, + ) + + @staticmethod + def apply_deltas(delta: CachedMetricsDeltas): + global generated_kernel_count + global generated_cpp_vec_kernel_count + global ir_nodes_pre_fusion + global cpp_to_dtype_count + + generated_kernel_count += delta.generated_kernel_count + generated_cpp_vec_kernel_count += delta.generated_cpp_vec_kernel_count + ir_nodes_pre_fusion += delta.ir_nodes_pre_fusion + cpp_to_dtype_count += delta.cpp_to_dtype_count + + +REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} + + +@dataclass +class MetricTable: + table_name: str + column_names: List[str] + + num_rows_added: int = 0 + + def add_row(self, row_fn): + if self.table_name not in enabled_metric_tables(): + return + + row_dict = row_fn() + assert len(self.column_names) == len( + row_dict + ), f"{len(self.column_names)} v.s. {len(row_dict)}" + assert set(self.column_names) == set( + row_dict.keys() + ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}" + + row = [ + get_benchmark_name(), + ] + row += [row_dict[column_name] for column_name in self.column_names] + self._write_row(row) + + def output_filename(self): + return f"metric_table_{self.table_name}.csv" + + def write_header(self): + filename = self.output_filename() + with open(filename, "w") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(["model_name"] + self.column_names) + + def _write_row(self, row): + filename = self.output_filename() + if self.num_rows_added == 0 and not os.path.exists(filename): + self.write_header() + + self.num_rows_added += 1 + + for idx, orig_val in enumerate(row): + if isinstance(orig_val, float): + new_val = f"{orig_val:.6f}" + elif orig_val is None: + new_val = "" + else: + new_val = orig_val + row[idx] = new_val + + with open(filename, "a") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(row) + + @staticmethod + def register_table(name, column_names): + table = MetricTable(name, column_names) + REGISTERED_METRIC_TABLES[name] = table + + +MetricTable.register_table( + "slow_fusion", + [ + "kernel1_path", + "kernel1_latency", + "kernel2_path", + "kernel2_latency", + "fused_kernel_path", + "fused_kernel_latency", + "slow_down_ratio", + ], +) + +# track the fusion statistics for each graph +MetricTable.register_table( + "graph_stats", + [ + "graph_id", + "num_nodes_before_fusion", + "num_nodes_after_fusion", + ], +) + +# track the perf difference between persistent reduction and non-persistent +# reductions +MetricTable.register_table( + "persistent_red_perf", + [ + "kernel1_name", + "kernel2_name", + "kernel1_latency", + "kernel2_latency", + "size_hints", + "reduction_hint", + "speedup", + ], +) + +# Log metadata for pointwise/reduction kernels. 
E.g., model name, kernel path, numel, rnumel, reduction hint +MetricTable.register_table( + "kernel_metadata", + [ + "kernel_name", + "kernel_path", + "kernel_category", # pointwise/reduction/foreach etc. + "size_hints", + "reduction_hint", + "line_of_code", + "num_load", + "num_store", + "num_for_loop", + "num_atomic_add", + "num_args", + # xyz numel can be different to size_hints since size_hints are rounded + # up to the nearest power of 2. + # Inductor kernel will burn in the xyz numel in kernel code for static + # shape kernels. + # Logging them will be helpful to find unaligned shape for reduction + "xnumel", + "ynumel", + "rnumel", + "kernel_args_num_gb", + ], +) + + +def _parse_kernel_fn_code(kernel_module_code): + """ + The kernel_module_code is the python module that contains kernel function code. + kernel function is the proper triton kernel function annotated with + @triton.jit + """ + from .codecache import PyCodeCache + from .wrapper_benchmark import get_triton_kernel + + mod = PyCodeCache.load(kernel_module_code) + kernel = get_triton_kernel(mod) + # kernel is a CachingAutotune; kernel.fn is the JITFunction; + # kernel.fn.fn is the function being decorate by triton.jit + return inspect.getsource(kernel.fn.fn) + + +def _parse_kernel_line_of_code(proper_kernel_fn_code): + """ + Return the line of code for the kernel excluding the decorators. + """ + return len(proper_kernel_fn_code.splitlines()) + + +def _parse_size_hints(kernel_module_code, kernel_category): + if kernel_category == "foreach": + # foreach kernel does not have size_hints + return None + m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) + assert m, "size_hints missing!" + return m.group(1) + + +def _parse_reduction_hint(kernel_category, kernel_module_code): + if kernel_category not in ("reduction", "persistent_reduction"): + return None + m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) + assert m, "reduction_hint not found in kernel source code!" + return m.group(1) + + +def _count_pattern(proper_kernel_fn_code, pattern): + return proper_kernel_fn_code.count(pattern) + + +def _count_args(proper_kernel_fn_code): + def_line = proper_kernel_fn_code.splitlines()[0] + assert def_line.startswith("def ") + start_idx = def_line.index("(") + end_idx = def_line.index("):") + decl_csv = def_line[start_idx + 1 : end_idx] + comps = decl_csv.split(",") + return len(comps) + + +def _parse_proper_kernel_fn_code(kernel_fn_code): + """ + Skip decorators. + """ + start_pos = kernel_fn_code.index("def ") + return kernel_fn_code[start_pos:] + + +def _parse_numel(proper_kernel_fn_code, numel_arg_name): + m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) + if m: + return int(m.group(1)) + else: + return None + + +def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category): + """ + inductor meta looks like: + inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, + """ + m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) + if m: + return float(m.group(1)) + else: + """ + There are a few cases that kernel_num_gdb field can be missing: + 1. the field will be missing if config.benchmark_kernel and + config.profile_bandwidth are false + 2. even if config.benchmark_kernel or config.profile_bandwidth is true. + foreach kernel does not have kernel_num_gb field in the metadata + """ + return None + + +def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code): + """ + An utility to log kernel metadata. 
We may parse metadata from kernel source code here. + + It's fine to parse the generated kernel code here since the logging is + disabled by default. It would hurt compilation time. + """ + from .wrapper_benchmark import get_kernel_category_by_source_code + + kernel_category = get_kernel_category_by_source_code(kernel_module_code) + reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) + size_hints = _parse_size_hints(kernel_module_code, kernel_category) + kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) + + proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) + + # the line of code excluding the decortors + kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) + + get_metric_table("kernel_metadata").add_row( + lambda: { + "kernel_name": kernel_name, + "kernel_path": kernel_path, + "kernel_category": kernel_category, + "size_hints": size_hints, + "reduction_hint": reduction_hint, + "line_of_code": kernel_line_of_code, + "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), + "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), + "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), + "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), + "num_args": _count_args(proper_kernel_fn_code), + "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), + "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), + "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), + "kernel_args_num_gb": _parse_kernel_args_num_gb( + kernel_fn_code, kernel_category + ), + } + ) + + +def purge_old_log_files(): + """ + Purge the old log file at the beginning when the benchmark script runs. + Should do it in the parent process rather than the child processes running + each individual model. 
+ """ + for name, table in REGISTERED_METRIC_TABLES.items(): + if name in enabled_metric_tables(): + filename = table.output_filename() + if os.path.exists(filename): + os.unlink(filename) + + table.write_header() + + +@lru_cache +def enabled_metric_tables() -> Set[str]: + config_str = config.enabled_metric_tables + + enabled = set() + for name in config_str.split(","): + name = name.strip() + if not name: + continue + assert ( + name in REGISTERED_METRIC_TABLES + ), f"Metric table name {name} is not registered" + enabled.add(name) + return enabled + + +def is_metric_table_enabled(name): + return name in enabled_metric_tables() + + +def get_metric_table(name): + assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" + return REGISTERED_METRIC_TABLES[name] diff --git a/MLPY/Lib/site-packages/torch/_inductor/ops_handler.py b/MLPY/Lib/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..a20e9848a3244a752ddff2e8aa258115ec6ddcc6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,655 @@ +import itertools +from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union +from unittest.mock import patch + +import sympy +from typing_extensions import Protocol + +import torch +import torch.utils._pytree as pytree +from torch.fx.graph import inplace_methods, magic_methods +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add"]] +ReductionType = Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "xor_sum", +] + + +def _arg_str(a) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# NB: This is not done as a parent class, because our ops handlers +# implementations make heavy use of __getattr__ magic, and pre-existing +# stubs for methods would interfere with this mechanism. +# +# TODO: A superclass that does desugaring for operations like +# reciprocal/square might be useful. +class OpsHandler(Protocol[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all of the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. 
+ + Handlers are often defined using ``__getattr__`` metaprogramming, which means + that you cannot declare that a type implements a protocol by inheriting from + it (as the type stubs count as attribute declarations and impede the getattr + magic method from being called). Instead, define a function that casts an + argument of your type to the protocol, which is sufficient to induce mypy to + test that the protocol is implemented correctly. Search for ``_typecheck_`` + in this file to see some examples. If you see an obscure error where a + class doesn't implement a Protocol, but mypy doesn't say why, check to see + that ``__getattr__`` is typed correctly (typically, it is not possible to + type ``__getattr__`` without typing it as ``Callable[..., Any]``) + """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + ... + + def load_seed(self, name: str, offset: T): + """Computes inductor_prims.lookup_seed.""" + ... + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + ... + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + ... + + def randint64(self, seed: T, offset: T, low: T, high: T) -> T: + """Computes inductor_prims.randint. offset has dtype int32.""" + ... + + def masked(self, mask: T, body: Callable[[], T], other: T) -> T: + """ + Computes body, but only perform loads/stores if the boolean mask + evaluates to true. For example, you would use this if you needed to + perform an indirect load that may not be valid on some elements; + without masking, invalid accesses can cause IMAs. When mask is true, + the result is the result of body; otherwise it is other. + + Contrast this with ops.where, which can multiplex between two values + that have been unconditionally computed. + """ + ... + + def where(self, condition: T, input: T, other: T) -> T: + """ + Computes torch.where: when condition is true, return input; otherwise return other. + """ + ... + + def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: + """ + Converts a sympy expression into a scalar of type dtype. expr is typically + an indexing expression, thus the name; however, it can also be used in + non-indexing situations. + """ + ... + + def to_dtype( + self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None + ) -> T: + """ + Convert x to dtype. src_dtype can be optionally set to specify what the original + dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). + """ + ... + + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: + """ + Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) + src_dtype must be the original type of x. + """ + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operations are only available in a "kernel" context. Check + # torch._inductor.codegen.common.CSEProxy for their typical implementation + # in op handler (routing to their respective implementations in the kernel + # handler) + # + # Importantly, inside a kernel, indexing and mask variables are available + # in scope, which are typically used by sympy.Expr indexing. + + def indirect_indexing( + self, x: T, size: sympy.Expr, check: bool = True + ) -> sympy.Expr: + """ + Convert an integral x into a sympy.Expr that can be subsequently used in + indexing computation. 
'size' represents an upper bound on the what valid + indexes can be; when 'check' is True, we check that the x is in bounds. + + NB: This is typically mandatory to implement for any analysis, because you + MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). + """ + ... + + def load(self, name: str, index: sympy.Expr) -> T: + """ + Load from the memory location 'name', offset by some indexing expression 'index'. + """ + ... + + def store( + self, + name: str, + index: sympy.Expr, + value: T, + mode: StoreMode = None, + ) -> None: + """ + Store 'value' to the memory location 'name' offset by 'expr'. If + specified, 'mode' can require the store to be an atomic addition. + """ + ... + + # TODO: Better explain how the "collective" semantics of these ops; + # remember that the input value is a scalar, you can't reduce on it in the + # traditional sense! + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: T, + ) -> Union[T, Tuple[T, ...]]: + """ + Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype', + using 'dtype' as the accumulation dtype for the reduction. The result + is an intermediate computation which should be stored to the final + location using 'ops.store_reduction'. + + Valid reduction types are . For Welford reduction types, this + function returns multiple outputs; consult reduction_num_outputs to + determine the amount in metaprogramming applications. + """ + ... + + # TODO: in practice, this seems to actually return None, but not returning + # a T makes common __getattr__ idioms not type correctly. Figure out if + # this should be returning something. + def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T: + """ + Store the fully accumulated result of 'reduction' to the memory + location 'name' offset by 'expr'. + """ + ... + + def scan( + self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int + ) -> T: + """ + Perform an associative scan on 'value'. + """ + # TODO: Improve the description with some pseudocode + ... + + def bucketize( + self, + values: T, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> T: + # See [Note: Inductor bucketize op] + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The following ops have semantics that correspond exactly to the torch + # operation with the same corresponding name. + + def abs(self, x0: T) -> T: + ... + + def exp(self, x0: T) -> T: + ... + + def exp2(self, x0: T) -> T: + ... + + def expm1(self, x0: T) -> T: + ... + + def sqrt(self, x0: T) -> T: + ... + + def relu(self, x0: T) -> T: + ... + + def minimum(self, x0: T, x1: T) -> T: + ... + + def maximum(self, x0: T, x1: T) -> T: + ... + + def cos(self, x0: T) -> T: + ... + + def sin(self, x0: T) -> T: + ... + + def lgamma(self, x0: T) -> T: + ... + + def erf(self, x0: T) -> T: + ... + + def cosh(self, x0: T) -> T: + ... + + def sinh(self, x0: T) -> T: + ... + + def acos(self, x0: T) -> T: + ... + + def acosh(self, x0: T) -> T: + ... + + def asin(self, x0: T) -> T: + ... + + def asinh(self, x0: T) -> T: + ... + + def atan2(self, x0: T, x1: T) -> T: + ... + + def atan(self, x0: T) -> T: + ... + + def atanh(self, x0: T) -> T: + ... + + def copysign(self, x0: T, x1: T) -> T: + ... + + def erfc(self, x0: T) -> T: + ... + + def erfinv(self, x0: T) -> T: + ... + + def frexp(self, x0: T): + ... + + def hypot(self, x0: T, x1: T) -> T: + ... 
+ + def log10(self, x0: T) -> T: + ... + + def nextafter(self, x0: T, x1: T) -> T: + ... + + def logical_and(self, x0: T, x1: T) -> T: + ... + + def logical_not(self, x0: T) -> T: + ... + + def logical_or(self, x0: T, x1: T) -> T: + ... + + def logical_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_and(self, x0: T, x1: T) -> T: + ... + + def bitwise_not(self, x0: T) -> T: + ... + + def bitwise_or(self, x0: T, x1: T) -> T: + ... + + def bitwise_xor(self, x0: T, x1: T) -> T: + ... + + def bitwise_left_shift(self, x0: T, x1: T) -> T: + ... + + def bitwise_right_shift(self, x0: T, x1: T) -> T: + ... + + def rsqrt(self, x0: T) -> T: + ... + + def log1p(self, x0: T) -> T: + ... + + def tan(self, x0: T) -> T: + ... + + def tanh(self, x0: T) -> T: + ... + + def sigmoid(self, x0: T) -> T: + ... + + def signbit(self, x0: T) -> T: + ... + + def fmod(self, x0: T, x1: T) -> T: + ... + + def log(self, x0: T) -> T: + ... + + def isinf(self, x0: T) -> T: + ... + + def isnan(self, x0: T) -> T: + ... + + def round(self, x0: T) -> T: + ... + + def floor(self, x0: T) -> T: + ... + + def sign(self, x0: T) -> T: + ... + + def to_int(self, x0: T) -> T: + ... + + def trunc(self, x0: T) -> T: + ... + + def truncdiv(self, x0: T, x1: T) -> T: + ... + + def ceil(self, x0: T) -> T: + ... + + def neg(self, x0: T) -> T: + ... + + def reciprocal(self, x0: T) -> T: + ... + + def eq(self, x0: T, x1: T) -> T: + ... + + def ne(self, x0: T, x1: T) -> T: + ... + + def lt(self, x0: T, x1: T) -> T: + ... + + def gt(self, x0: T, x1: T) -> T: + ... + + def le(self, x0: T, x1: T) -> T: + ... + + def ge(self, x0: T, x1: T) -> T: + ... + + def add(self, x0: T, x1: T) -> T: + ... + + def sub(self, x0: T, x1: T) -> T: + ... + + def mul(self, x0: T, x1: T) -> T: + ... + + def floordiv(self, x0: T, x1: T) -> T: + ... + + def truediv(self, x0: T, x1: T) -> T: + ... + + def div(self, x0: T, x1: T) -> T: + ... + + def mod(self, x0: T, x1: T) -> T: + ... + + def pow(self, x0: T, x1: T) -> T: + ... + + def and_(self, x0: T, x1: T) -> T: + ... + + def or_(self, x0: T, x1: T) -> T: + ... + + def xor(self, x0: T, x1: T) -> T: + ... + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # In CUDA, optimized implementations of other mathematical operations are + # offered separately via libdevice for double precision computation (in + # Triton, these go to tl.math rather than tl). We lower to these + # operators when doing FP64 on CUDA. Note that some operators + # unconditional go to tl.math. + # + # TODO(ezyang): Is this really the best way to do this? What if we have + # abs internally route to tl.math automatically when given a double + # precision input? One reason is that when doing codegen, we often don't + # know what the dtype of the inputs are! (In principle we do know, but + # for many analyses it's not conveniently available.) + + def libdevice_abs(self, x0: T) -> T: + ... + + def libdevice_exp(self, x0: T) -> T: + ... + + def libdevice_sqrt(self, x0: T) -> T: + ... + + def libdevice_cos(self, x0: T) -> T: + ... + + def libdevice_sin(self, x0: T) -> T: + ... + + def libdevice_sigmoid(self, x0: T) -> T: + ... + + def libdevice_log(self, x0: T) -> T: + ... 
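# --- Editor's note: illustrative sketch only, not part of the vendored file above.
# As the OpsHandler docstring explains, concrete handlers are usually built with
# __getattr__ metaprogramming rather than by subclassing the Protocol, and a small
# "_typecheck_" cast function is what lets mypy verify protocol conformance.
# RecordingHandler below is a hypothetical minimal handler in that style
# (MockHandler, which follows, is the real in-tree analogue).
class RecordingHandler:
    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Every op (add, mul, load, ...) resolves to this generic handler,
        # which records the call and returns a fresh symbolic name.
        def inner(*args, **kwargs):
            rendered = ", ".join(map(str, args))
            self.calls.append(f"{name}({rendered})")
            return f"tmp{len(self.calls) - 1}"

        return inner


def _typecheck_RecordingHandler(h: "RecordingHandler") -> "OpsHandler[str]":
    # The cast alone is enough for mypy to check that the protocol is satisfied.
    return h


h = RecordingHandler()
result = h.add(h.constant(1.0, None), h.constant(2.0, None))  # dtype ignored by this toy
assert h.calls == ["constant(1.0, None)", "constant(2.0, None)", "add(tmp0, tmp1)"]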
+ + +class MockHandler: + def __getattr__(self, name): + if name == "name": + return "MockHandler" + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fargs.extend(f"{k}={v}" for k, v in kwargs.items()) + return f"ops.{name}({', '.join(fargs)})" + + return inner + + @staticmethod + def masked(mask, body, other) -> str: + return f"ops.masked({mask}, {body()}, {other})" + + @staticmethod + def frexp(x): + return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") + + @staticmethod + def indirect_indexing(index_var, size, check=True) -> sympy.Symbol: + return sympy_index_symbol(f"({str(index_var)})") + + @classmethod + def _init_cls(cls): + def make_handler(format_string): + @staticmethod # type: ignore[misc] + def inner(*args): + return format_string.format(*args) + + return inner + + for name, format_string in itertools.chain( + magic_methods.items(), inplace_methods.items() + ): + setattr(cls, name, make_handler(format_string)) + + +MockHandler._init_cls() + + +# Use mypy to check protocol implemented correctly +def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: + return h + + +class KernelFormatterHandler: + def __init__(self, parent_handler): + self.parent_handler = parent_handler + self.output = IndentedBuffer(1) + self.var_counter = itertools.count() + + @staticmethod + def ir_to_string(ir_fn, index, rindex=None) -> str: + from .ir import FlexibleLayout + from .virtualized import V + + args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter.output.indent(-1): + formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter.output.writeline(f"{lhs} = {name}") + + with V.set_ops_handler(formatter), patch.object( + FlexibleLayout, "allow_indexing", True + ): + result = ir_fn(*args) + return formatter.getvalue(result) + + def __getattr__(self, name) -> Callable[..., Any]: + def inner(*args, **kwargs): + line = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return line + + def write(line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self.output.writeline(f"{varname} = {line}") + return varname + + return pytree.tree_map(write, line) + + return inner + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[str, Tuple[str, ...]], + ) -> Union[str, Tuple[str, ...]]: + line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) + num_values = reduction_num_outputs(reduction_type) + varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] + self.output.writeline(f"{','.join(varnames)} = {line}") + return tuple(varnames) if num_values > 1 else varnames[0] + + def getvalue(self, result): + self.output.writeline(f"return {result}") + return self.output.getvalue() + + +# Use mypy to check protocol implemented correctly +def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: + return h + + +class WrapperHandler(Generic[T]): + def __init__(self, inner: OpsHandler[T]): + self._inner = inner + + def __getattr__(self, item): + return getattr(self._inner, item) + + +# Use mypy to check protocol implemented correctly +def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> 
OpsHandler[T]: + return h + + +class OpCounterCSE: + """Shim to count how many ops are used""" + + def __init__(self, inner): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names = {} + + def __getattr__(self, name): + def inner(*args, **kwargs): + val = getattr(self.parent_handler, name)(*args, **kwargs) + if name == "indirect_indexing": + return val + + def count(val): + if val not in self.var_names: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + else: + return self.var_names[val] + + return pytree.tree_map(count, val) + + return inner + + +def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: + return h diff --git a/MLPY/Lib/site-packages/torch/_inductor/optimize_indexing.py b/MLPY/Lib/site-packages/torch/_inductor/optimize_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..b2438f68d16e42610a02062f67bcbfdf47564917 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/optimize_indexing.py @@ -0,0 +1,118 @@ +import math + +import sympy + +import torch +from torch.utils._sympy.value_ranges import ValueRanges +from .ir import LoopBody +from .utils import dominated_nodes + + +def val_expressable_in_32_bits(val): + if getattr(val, "is_Boolean", False): + return True + + if isinstance(val, sympy.Expr): + assert val.is_number + if val.is_Integer or val.is_Boolean: + val = int(val) + else: + val = float(val) + + # bound within mantissa + if isinstance(val, float): + return val <= (2**24) and val >= -(2**24) + + if isinstance(val, int): + iinfo = torch.iinfo(torch.int32) + return val <= iinfo.max and val >= iinfo.min + + raise Exception(f"Unexpected value {val}") + + +def range_expressable_in_32_bits(range): + return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( + range.upper + ) + + +def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals): + # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, + # then it's precision is set for that chain of uses, and we don't need to consider those + # dominated values + def skip_filter(node): + return node.target == "to_dtype" and node.args[2] in ( + torch.int32, + torch.float32, + torch.float64, + ) + + # TODO - there are dominated uses whose dtype does not depend on whether + # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to + # int32 without changing the output precision of the node. this case hasn't shown up + for dominated in dominated_nodes([node], skip_filter): + if dominated.target in ["store", "output"]: + continue + + if isinstance(dominated.target, str) and "set_indirect" in dominated.target: + idx = int(dominated.target[len("set_indirect") :]) + indirect_var = indirect_vars[idx] + + # We check that we can compute all the indices it's involved in with int32 + for index, expr in indices.items(): + if indirect_var in expr.free_symbols: + index_val = replacement_vals[index] + + if math.isinf(index_val.lower) or math.isinf(index_val.upper): + return + + # all indices are integers, so make sure that we + # use the bounds of integers instead of floats. + # TODO - not sure if we should be doing int/float casts while tracing, + # might interfere with sympy. 
+ + index_val_int = ValueRanges[sympy.Expr]( + int(index_val.lower), int(index_val.upper) + ) + if not range_expressable_in_32_bits(index_val_int): + return + + if not range_expressable_in_32_bits(bounds[dominated]): + return + + args = list(node.args) + args[2] = torch.int32 + node.args = tuple(args) + + +def indexing_dtype_strength_reduction(loop_body: LoopBody): + """ + Performs Value Range Analysis on LoopBody's fx graph to reduce precision of + intermediaries from int64 to int32 + """ + bv = loop_body.bounds() + + int64_dtype_nodes = [ + node + for node in loop_body.get_nodes() + if ( + node.target == "to_dtype" + and node.args[2] == torch.int64 + and node not in bv.unbounded_vars + ) + ] + if not int64_dtype_nodes: + return + + bounds = bv.get_bounds() + + # TODO - if dominated node of one to_dtype is not expressible in int32, + # we should short circuit another to_dtype node if that node also dominates + for node in int64_dtype_nodes: + try_to_reduce_precision( + node, + bounds, + loop_body.indirect_vars, + loop_body.indexing_exprs, + bv.replacement_vals, + ) diff --git a/MLPY/Lib/site-packages/torch/_inductor/pattern_matcher.py b/MLPY/Lib/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..016cf680973ba93cf07ff5f1638f1d816edd8bb7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,1524 @@ +from __future__ import annotations + +import dataclasses +import functools +import inspect +import itertools +import logging +import operator +import os +import re +from collections import defaultdict +from typing import ( + Any, + Callable, + DefaultDict, + Dict, + Iterable, + List, + NoReturn, + Optional, + Set, + Union, +) + +from typing_extensions import TypeGuard + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from torch._prims_common import is_integer_dtype +from torch.fx import Node +from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +from torch.fx.immutable_collections import immutable_dict, immutable_list + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensorMode +from ..fx import Transformer +from . import config +from .decomposition import select_decomp_table +from .lowering import fallback_node_due_to_unsupported_type + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +Constant = Any +NodeOrConstant = Union[Constant, torch.fx.Node] + + +class Multiple: + pass + + +# Sentinel indicating multiple quantities can be matched +MULTIPLE = Multiple() + + +class Match: + """ + Represents a successfully matched pattern. 
+ """ + + def __init__(self, pattern: PatternExpr, args=None, kwargs=None): + super().__init__() + self.pattern = pattern + # The input nodes that must be passed in to the result + self.args = args or [] + self.kwargs = kwargs or {} + # The nodes matched in this expression + self.nodes: List[torch.fx.Node] = [] + # Mapping CallFunction to the node.target + self.targets: Dict[_TargetExpr, torch.fx.node.Target] = {} + self.ctx: Optional[MatchContext] = None + self.replacement_graph: Optional[torch.fx.Graph] = None + + @property + def graph(self) -> torch.fx.Graph: + assert self.ctx + return self.ctx.graph + + def extend(self, other: Match): + if self.kwargs: + for key in set(self.kwargs.keys()) & set(other.kwargs.keys()): + if self.kwargs[key] != other.kwargs[key]: + raise FailedMatch("kwarg mismatch: {}", key) + self.args.extend(other.args) + self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self): + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self, graph: torch.fx.Graph): + for n in reversed(self.nodes): + if not n._erased: + graph.erase_node(n) + + def output_nodes(self) -> List[Optional[torch.fx.Node]]: + assert self.ctx + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph(self, replacement_graph, args): + assert self.ctx + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example(self, replacement_fn, args, trace_fn=None, run_dce=True): + assert self.ctx + if trace_fn is None: + trace_fn = functools.partial(fwd_only, run_dce=run_dce) + replacement = trace_fn( + replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + ) + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) + + +class FailedMatch(RuntimeError): + def __init__(self, format_string, *args, **kwargs): + self.format_string = format_string + # We want to construct error messages lazily instead of eagerly, as + # constructing them eagerly can significantly worsen compile times. + if len(format_string) > 200: + raise RuntimeError( + f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}" + ) + self.args = args + self.kwargs = kwargs + + def __str__(self): + return self.format_string.format(*self.args, **self.kwargs) + + def __bool__(self): + return False + + +def is_match(m: Union[Match, FailedMatch]) -> TypeGuard[Match]: + """ + TypeGuards cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeGuard. + """ + return bool(m) + + +class MatchContext: + """ + State needed while running PatternExpr._match(). 
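+
+    ``pattern_to_node`` memoizes sub-pattern matches, so a pattern object that
+    is reused in several places must bind to the same FX node each time.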
+ """ + + def __init__( + self, + outputs: List[Optional[PatternExpr]], + pattern_to_node: Optional[Dict[PatternExpr, Node]] = None, + *, + graph: torch.fx.Graph, + ): + self.outputs = outputs + self.pattern_to_node = {} if pattern_to_node is None else pattern_to_node + self.graph = graph + self.exclusive_node_set: List[NodeOrConstant] = [] + + def match(self, pattern, node): + """wrapper to check reused nodes in patterns""" + if pattern in self.pattern_to_node: + if self.pattern_to_node[pattern] == node: + return Match(pattern) # already checked this node + else: + return FailedMatch("repeated pattern differs") + m = pattern._match(node, self) + assert pattern not in self.pattern_to_node + self.pattern_to_node[pattern] = node if m else None + m.ctx = self + return m + + def filter_multi_user_patterns(self): + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr: + """ + Base class for types of patterns + """ + + def _match( + self, node: torch.fx.Node, ctx: MatchContext + ) -> Union[Match, FailedMatch]: + raise NotImplementedError() + + def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self): + return self.__class__.__name__ + "()" + + def find_anchor_nodes(self, ctx: MatchContext, searched): + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self) # matches anything + + def __repr__(self): + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter): + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + def __init__(self, name: str): + super().__init__() + self.name = name + + def __repr__(self): + return f"KeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + return Match(self, kwargs={self.name: node}) # matches anything + + +class ExclusiveKeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. 
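+    Unlike KeywordArg, a given node may only be captured once per match
+    (tracked via ctx.exclusive_node_set).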
+ """ + + def __init__(self, name): + super().__init__() + self.name = name + + def __repr__(self): + return f"ExclusiveKeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext): + if node in ctx.exclusive_node_set: + return FailedMatch("exclusive arg appears twice") + + ctx.exclusive_node_set.append(node) + return Match(self, kwargs={self.name: node}) # matches anything + + +class _TargetExpr(PatternExpr): + """ + Base class for filtering match by node.target + """ + + op: Optional[str] = None + + def __init__(self, fns, users=1): + if not self.op: + raise NotImplementedError("Shouldn't directly use _BaseNodeMatch") + super().__init__() + fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) + for fn in list(fns): + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + self.fns: List[Union[Callable[..., Any], str]] = fns + self.fns_set: Set[Union[Callable[..., Any], str]] = set(fns) + self.users: Union[int, Multiple] = users + + def fns_repr(self) -> str: + first_repr = self.fns[0] + if not isinstance(first_repr, str): + first_repr = first_repr.__name__ + + if len(self.fns) > 1: + return f"[{first_repr}, ...]" + elif self.fns[0] is getattr(torch, first_repr, None): + return f"torch.{first_repr}" + elif isinstance(self.fns[0], torch._ops.OpOverload): + return str(self.fns[0]) + else: + return first_repr + + def __repr__(self): + return f"{self.__class__.__name__}({self.fns_repr()})" + + def has_multiple_users(self) -> bool: + return isinstance(self.users, Multiple) or self.users > 1 + + def find_anchor_nodes(self, ctx: MatchContext, searched): + raise NotImplementedError() + + def _match_fns(self, node: torch.fx.Node): + return ( + isinstance(node, torch.fx.Node) + and node.op == self.op + and extract_target(node) in self.fns_set + ) + + def _match_users(self, node: torch.fx.Node, ctx: MatchContext): + return ( + self in ctx.outputs + or self.users is MULTIPLE + or len(node.users) == self.users + ) + + +class _TargetArgsExpr(_TargetExpr): + """ + Base class for filtering match by node.{target,args,kwargs} + """ + + def __init__(self, fns, *args, _users=1, **kwargs): + super().__init__(fns, _users) + self.args = tuple(args) + self.kwargs = dict(kwargs) + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten(args, kwargs: Dict[Any, Any]): + return (*args, *kwargs.values()), (len(args), *kwargs.keys()) + + @staticmethod + def pytree_flatten(args, kwargs: Dict[Any, Any]): + def norm_spec(s: pytree.TreeSpec): + if s.type is None: + return s + mapping = {immutable_list: list, tuple: list, immutable_dict: dict} + return pytree.TreeSpec( + mapping.get(s.type, s.type), + s.context, + list(map(norm_spec, s.children_specs)), + ) + + flat, spec = pytree.tree_flatten([args, kwargs]) + spec = norm_spec(spec) + return flat, spec + + def __repr__(self): + args = [ + self.fns_repr(), + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + return f"{self.__class__.__name__}({', '.join(args)})" + + def pretty_print(self, pp: PatternPrettyPrinter): + args = [ + self.fns_repr(), + *(pp.pretty_print(x) for x in self.args), + *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], + ] + if isinstance(self.users, Multiple): + 
args.append("_users=MULTIPLE") + elif self.users > 1: + args.append(f"_users={self.users}") + + joiner_str = ", " + return f"{self.__class__.__name__}({joiner_str.join(args)})" + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + if not self._match_fns(node) or len(node.args) != len(self.args): + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users {}", self) + + _args = node.args + _kwargs = node.kwargs + if len(_kwargs) < len(self.kwargs): + from torch.fx.operator_schemas import normalize_function + + normalized_args_and_kwargs = normalize_function( + node.target, node.args, node.kwargs + ) + + if normalized_args_and_kwargs is None: + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + else: + _args, _kwargs = normalized_args_and_kwargs + if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs): + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + else: + return FailedMatch( + "function_mismatch: node={}, pattern={}", node, self + ) + else: + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + + node_items, node_spec = self.flatten(_args, _kwargs) + self_items, self_spec = self.flat_args_kwargs + if node_spec != self_spec: + return FailedMatch("args_structure {} {}", node_spec, self_spec) + assert len(node_items) == len(self_items) + + m = Match(self) + for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + if isinstance(pattern, PatternExpr): + child_match = ctx.match(pattern, child_node) + if not child_match: + return child_match + m.extend(child_match) + elif isinstance(child_node, torch.fx.Node) or child_node != pattern: + return FailedMatch( + "constant_args: {} {!r}!={pattern!r}", node, child_node + ) + m.nodes.append(node) + m.targets[self] = node.target + return m + + def find_anchor_nodes(self, ctx: MatchContext, searched): + """ + This is used when we are matching a pattern with multiple outputs. + There is a partial match (stored in ctx) and we want to walk + this pattern to find a connection to an already-matched node. + + Yields candidate nodes that `self._match` might like. 
+ """ + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + return + + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +class CallFunction(_TargetArgsExpr): + """ + Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` + """ + + op = "call_function" + + +class CallMethod(_TargetArgsExpr): + """ + Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` + """ + + op = "call_method" + + +class CallModule(_TargetArgsExpr): + """ + Matches a call_module node in the FX graphs: `module(*args, **kwargs)` + """ + + op = "call_module" + + +class _TargetExprVarArgs(_TargetExpr): + """ + Matches a call_function node with any arguments which are passed into the pattern + """ + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + if not self._match_fns(node): + return FailedMatch("function_mismatch") + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users") + + m = Match(self) + m.nodes.append(node) + m.targets[self] = node.target + m.args.extend(node.args) + m.kwargs.update(node.kwargs) + return m + + +class CallFunctionVarArgs(_TargetExprVarArgs): + op = "call_function" + + +class CallMethodVarArgs(_TargetExprVarArgs): + op = "call_method" + + +class CallModuleVarArgs(_TargetExprVarArgs): + op = "call_module" + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern: PatternExpr, partial=False): + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + self.partial = partial + + def __repr__(self): + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: List[torch.fx.Node], ctx: MatchContext): # type: ignore[override] + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(self) + # Propagating patterns with multiple users will ensure we don't revisit + # the same nodes + pattern_to_node = ctx.filter_multi_user_patterns() + matched = False + for i, child_node in enumerate(node): + child_ctx = MatchContext( + ctx.outputs, pattern_to_node, graph=child_node.graph + ) + child_match = child_ctx.match(self.pattern, child_node) + pattern_to_node = child_ctx.filter_multi_user_patterns() + if not child_match: + if not self.partial: + return FailedMatch("list[{}]: {}", i, child_match) + continue + matched = True + m.extend(child_match.bundle()) + if not matched: + return FailedMatch("list: no_match") + return m.bundle() + + +class MultiOutputPattern(PatternExpr): + def __init__(self, outputs): + super().__init__() + assert all(isinstance(x, (PatternExpr, type(None))) for x in outputs), outputs + self.outputs: List[Optional[PatternExpr]] = outputs + + @property + def fns(self): + assert self.outputs[0] and hasattr(self.outputs[0], "fns") + return self.outputs[0].fns + + def __repr__(self): + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter): + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = ctx.match(self.outputs[0], node) + if not m: + return m 
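+        # The remaining outputs are located by walking users of nodes already
+        # bound in ctx (see _match_from_anchors below) rather than by scanning
+        # the whole graph again.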
+ + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not child_match: + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors(self, pattern, ctx): + prior = dict(ctx.pattern_to_node) + m = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, set()): + m = ctx.match(pattern, node) + if m: + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + +class RepeatedExpr(PatternExpr): + """ + Checks for a repeated pattern. Useful for repeated operations after a node such as `split` or `unbind` + """ + + def __init__(self, inner_pattern: PatternExpr): + super().__init__() + assert hasattr(inner_pattern, "fns") + self.inner_pattern = inner_pattern + + @property + def fns(self): + return self.inner_pattern.fns + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = ctx.match(self.inner_pattern, node) + if not m: + return m + ctx.pattern_to_node.pop( + self.inner_pattern, + ) + # Check all anchor nodes match the pattern + for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, set()): + anchor_m = MatchContext([self], graph=node.graph).match( + self.inner_pattern, anchor_node + ) + if not anchor_m: + return anchor_m + m.extend(anchor_m) + return m + + +class PatternPrettyPrinter: + """ + Serializes Patterns to executable python. + XXX: currently only used and tested for fuse attention patterns. May not cover + all patterns. + """ + + def __init__(self): + self.namespace = torch.fx.graph._Namespace() + self.memoized_objs_names: Dict[PatternExpr, str] = {} + self.memoized_objs_pp: Dict[PatternExpr, str] = {} + + @staticmethod + def run(obj: PatternExpr, output_name="output"): + """ + Serializes obj to python code with obj written out to `output_name` + """ + + pp = PatternPrettyPrinter() + assert hasattr(obj, "pretty_print") + out_str = obj.pretty_print(pp=pp) + + output = [] + for key in pp.memoized_objs_names: + output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}") + + output.append(f"{output_name} = {out_str}") + + return "\n".join(output) + + def pretty_print(self, obj): + if isinstance(obj, _TargetArgsExpr): + if memoized_name := self.memoized_objs_names.get(obj): + return memoized_name + else: + return self.memoize(obj) + if hasattr(obj, "pretty_print"): + return obj.pretty_print(self) + + return repr(obj) + + def memoize(self, obj): + obj_str = obj.pretty_print(self) + obj_name = obj.fns_repr() + for prefix in ("aten.", "torch.", "prims."): + obj_name = obj_name.replace(prefix, "") + + tmp_name = self.namespace.create_name(obj_name, None) + self.memoized_objs_names[obj] = tmp_name + self.memoized_objs_pp[obj] = obj_str + return tmp_name + + +@dataclasses.dataclass +class PatternEntry: + pattern: PatternExpr + extra_check: Callable[[Match], bool] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + raise NotImplementedError() + + def register(self, pass_dicts, target=None, prepend=False): + if target is None: + assert hasattr(self.pattern, "fns") + for fn in self.pattern.fns: + self.register(pass_dicts, fn, prepend=prepend) + elif isinstance(pass_dicts, (dict, PatternMatcherPass)): + if prepend: + pass_dicts[target].insert(0, self) + else: + pass_dicts[target].append(self) + else: + 
for x in pass_dicts: + self.register(x, target, prepend=prepend) + + +@dataclasses.dataclass +class LoweringPatternEntry(PatternEntry): + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) + with graph.inserting_before(node): + replacement = graph.call_function(handler, tuple(match.args), match.kwargs) + replacement.meta.update(node.meta) + node.replace_all_uses_with(replacement) + assert match.nodes[-1] is node + match.erase_nodes(graph) + + +@dataclasses.dataclass +class GraphPatternEntry(PatternEntry): + """ + A pattern that runs a function on the FX graph + """ + + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + with graph.inserting_before(node): + self.handler(match, *match.args, **match.kwargs) + + +@dataclasses.dataclass +class ReplacementPatternEntry(PatternEntry): + normalize_args: Callable[..., List[Any]] + + @staticmethod + def replace_with_graph( + match: Match, + graph: torch.fx.Graph, + replacement_graph: torch.fx.Graph, + args: List[Any], + ): + output_nodes = match.output_nodes() + first_node = output_nodes[0] + + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + if node.op == "call_function": + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + result = graph.call_function(target, args, kwargs) + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=lambda tup: tup[0])[1] + + def percolate_tags(node, recompute_tag, input_stops): + queue = [node] + visited = set() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta["recompute"] = recompute_tag + queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + replacement = Replacer(replacement_graph).run(*args) + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node): + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace(old, new): + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with(None) + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). 
+ # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + if "recompute" in old.meta: + percolate_tags(new, old.meta["recompute"], args) + + old.replace_all_uses_with(new) + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... + old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError("can't handle") + replace(user, new[idx]) + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes(graph) + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node): + self.replace_with_graph( + match, + graph, + match.replacement_graph, # type: ignore[arg-type] + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match): + return True + + +def log_trace_failure(search_fn, e): + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def register_replacement( + search_fn, + replace_fn, + example_inputs: Iterable[Any], + trace_fn: Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + pass_dicts, + extra_check=_return_true, + scalar_workaround=(), + exclusive_arg_names=(), + search_fn_pattern=None, +): + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. + + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + def check_fn(match: Match): + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? 
argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + sym_args: List[torch.SymInt] = [] + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + guard_size_oblivious(v != a) for a in sym_args + ): + sym_args.append(v) + + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new): + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + specific_pattern_match = specific_pattern.match(match.output_nodes()[0]) # type: ignore[arg-type] + if specific_pattern_match and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] + return True + return False + + def normalize_args(**kwargs): + args = [] + for name in argnames_static: + args.append(kwargs.pop(name)) + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with functorch_config.patch(functionalize_rng_ops=False): + requires_grad: List[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern = gen_pattern( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + + pattern_repr = PatternPrettyPrinter.run(pattern) + assert pattern_repr not in _seen_patterns + 
_seen_patterns.add(pattern_repr)
+        pattern = ReplacementPatternEntry(
+            pattern=pattern,
+            extra_check=check_fn,
+            normalize_args=normalize_args,
+        )
+        pattern.register(pass_dicts)
+        return pattern.pattern
+
+
+@functorch_config.patch(functionalize_rng_ops=False)
+def gen_pattern(
+    search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
+) -> PatternExpr:
+    argnames = [*inspect.signature(search_fn).parameters.keys()]
+
+    if scalar_workaround == ():
+        scalar_workaround = {}
+    flat_inputs = []
+    input_idx = 0  # Positional arguments index
+
+    for argname in argnames:
+        if argname in scalar_workaround:
+            flat_inputs.append(scalar_workaround[argname])
+        else:
+            flat_inputs.append(example_inputs[input_idx])
+            input_idx += 1
+
+    search_gm = trace_fn(search_fn, flat_inputs)
+    return fx_to_pattern(
+        search_gm,
+        ignore_types=(int, float, list, torch.device, torch.dtype),
+        argnames=argnames,
+        scalar_workaround=scalar_workaround,
+        exclusive_arg_names=exclusive_arg_names,
+    )
+
+
+def register_lowering_pattern(
+    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
+):
+    """
+    Register an aten to inductor IR replacement pattern. The decorated
+    function is saved and then called at lowering time, allowing direct
+    pattern-to-inductor-IR conversion.
+    """
+
+    def decorator(handler):
+        assert callable(handler)
+        LoweringPatternEntry(
+            pattern=pattern, extra_check=extra_check, handler=handler
+        ).register(pass_dict, prepend=prepend)
+        handler._inductor_lowering_function = True
+        return handler
+
+    return decorator
+
+
+def register_graph_pattern(
+    pattern: PatternExpr, extra_check=_return_true, *, pass_dict, prepend=False
+):
+    """
+    Register a pattern that runs a function on the FX graph, allowing
+    custom transformation code.
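+
+    Illustrative usage only (``my_patterns`` is a hypothetical
+    PatternMatcherPass, not something defined in this module):
+
+        @register_graph_pattern(
+            CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")),
+            pass_dict=my_patterns,
+        )
+        def handle_add(match, x, y):
+            ...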
+ """ + + def decorator(handler): + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, enter_functional_autocast, etc +_mutation_op_re = re.compile(r"_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_)") + + +def is_mutation_op(node: torch.fx.Node) -> bool: + if node.op == "call_function": + if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr] + return True + elif node.op == "call_method": + if _mutation_op_re.search(node.target): # type: ignore[union-attr, arg-type] + return True + return node.kwargs.get("out") is not None + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.GraphModule): + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +class PatternMatcherPass: + def __init__( + self, prevent_match_across_mutations=False, pass_name: Optional[str] = None + ): + super().__init__() + self.patterns: DefaultDict[ + torch.fx.node.Target, List[PatternEntry] + ] = defaultdict(list) + self.prevent_match_across_mutations = prevent_match_across_mutations + self.pass_name = pass_name + + def __getitem__(self, item: torch.fx.node.Target) -> List[PatternEntry]: + return self.patterns[item] + + def apply(self, graph: torch.fx.GraphModule) -> int: + if not self.patterns: + return 0 + if isinstance(graph, torch.fx.GraphModule): + graph = graph.graph + if self.prevent_match_across_mutations: + if should_compute_mutation_region_ids(graph): + compute_mutation_region_ids(graph) + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + for node in reversed(graph.nodes): + target = extract_target(node) + if ( + node.op in ["call_function", "call_method", "call_module"] + and target in self.patterns + ): + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. 
+ # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[target]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + self.prevent_match_across_mutations + and is_match(m) + and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined] + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) # type: ignore[arg-type] + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self): + self.patterns.clear() + + +def _not_implemented(*args, **kwargs) -> NoReturn: + raise NotImplementedError() + + +def fx_to_pattern( + gm, + ignore_types=(), + argnames=(), + scalar_workaround=(), + exclusive_arg_names=(), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. + """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg(x): + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + call_method = _not_implemented + call_module = _not_implemented + get_attr = _not_implemented + + def placeholder(self, target, args, kwargs): + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + def call_function(self, target, args, kwargs): + args, kwargs = pytree.tree_map(process_arg, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 
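+                # (tree_map above only rewrites leaves, so a list of Ignored()
+                # survives it; this second pass collapses such a list into a
+                # single Ignored())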
+ args = [process_arg(a) for a in args] + kwargs = {k: process_arg(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n): + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + assert len(rv) == len(n.args[0]) + for r, arg in zip(rv, n.args[0]): + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + pattern = Converter(gm).run() + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only(fn, args, *, run_dce=True) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + mode = ( + "real" if not torch._inductor.utils.any_is_symbolic(*args) else "symbolic" + ) + gm = make_fx(fn, select_decomp_table(), tracing_mode=mode)(*args) + if run_dce: + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn, args) -> torch.fx.GraphModule: + """Build a normalized training graph, for use with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph(joint_graph, inputs, **kwargs): + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, handler=pointless_view, extra_check=_return_true + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) # type: ignore[arg-type] + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: + args: List[torch.fx.node.Argument] = list() + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph): + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = set() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. + cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. 
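+            # (reversed() preserves their original relative order, since
+            # `pending` is consumed from the end)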
+ pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]): + """Wrapper around lazy init functions in fx_passes/""" + + @functools.lru_cache(None) + @functools.wraps(fn) + def lazy_init(): + counters_ref = counters["inductor"].copy() + + with torch._guards.tracing( + None + ), maybe_disable_fake_tensor_mode(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters["inductor"] = counters_ref + + return result + + return lazy_init + + +def config_flag(name): + """Function for extra_check to put pass behind a flag""" + + def flag_check(match): + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node): + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + return new_node + + return CopyGraph(input_graph).transform() + + +_seen_patterns: Set[str] = set() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +): + return ( + node.args[arg_number] + if len(node.args) > arg_number + else node.kwargs.get(kwarg_name) # type: ignore[arg-type] + ) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn) -> List[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: Node): + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + return getattr(node.graph.owning_module, node.target).__class__ # type: ignore[arg-type] + return node.target diff --git a/MLPY/Lib/site-packages/torch/_inductor/quantized_lowerings.py b/MLPY/Lib/site-packages/torch/_inductor/quantized_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..39d7c233d986c53aed900a66f32a10908cf2494e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/quantized_lowerings.py @@ -0,0 +1,15 @@ +import torch + + +def register_quantized_ops(): + from . import lowering + + quantized = torch.ops.quantized + + lowering.add_needs_realized_inputs( + [ + quantized.max_pool2d, + ] + ) + + lowering.make_fallback(quantized.max_pool2d) diff --git a/MLPY/Lib/site-packages/torch/_inductor/scheduler.py b/MLPY/Lib/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0484dad1fbf7e1ab2050069e083b21ea222bbcfb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,2445 @@ +import collections +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +from typing import ( + Any, + Counter, + DefaultDict, + Dict, + Generic, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) + +import sympy + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._triton import has_triton + +from . 
import comms, config, dependencies, ir, metrics +from .codegen.common import get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .sizevars import SimplifyIndexing +from .utils import ( + cache_on_self, + cmp, + free_symbol_has, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + green_text, + is_collective, + is_wait, + red_text, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["node1", "node2", "reason", "args"] + reason: str + args: Tuple[Any, ...] + + def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"): + self.node1 = node1 + self.node2 = node2 + + def __call__(self, reason, *args): + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self): + return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( + self.reason % self.args + ) + + +def pformat(obj): + if isinstance(obj, set): + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' '*4)}" + return result + + +class OutputNode: + def __init__(self, dep): + self.unmet_dependencies = {dep} + self.inverse_users = [] + + def is_reduction(self): + return False + + def get_alias_names(self): + return () + + def get_name(self): + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps(node, name_to_fused_node): + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
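+
+    For example (names are illustrative): if B carries a WeakDep on A purely
+    for mutation ordering and B is later fused with a node that reads A
+    directly, the WeakDep duplicates a real dependency and can be dropped.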
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + name_to_dep_count[name_to_fused_node[dep.name].get_name()] += 1 + + def should_prune(dep): + if isinstance(dep, WeakDep): + is_redundant = ( + name_to_dep_count[name_to_fused_node[dep.name].get_name()] > 0 + ) + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[dep.name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)} + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel +kernel_name_to_op = { + "extern_kernels.convolution": torch.ops.aten.convolution, + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, +} + + +class BaseSchedulerNode: + def __init__(self, scheduler: "Scheduler", node: ir.Buffer): + self.scheduler: Scheduler = scheduler + self.node: ir.Buffer = node + self.users: List[NodeUser] = [] + self.inverse_users: List[BaseSchedulerNode] = [] + self.node_users: List[BaseSchedulerNode] = [] + self.set_read_writes(node.get_read_writes()) + self.ancestors: Set[str] = set() + self.min_order: int + self.max_order: int + self.last_usage: Set[ + str + ] = set() # buffers that won't be used after this kernel + self.written = False + + def __repr__(self): + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + lines = [ + f"{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})", + f"{name}.writes = {pformat(self.read_writes.writes)}", + f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", + f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", + f"{name}.users = {self.users}", + ] + try: + lines += [ + self.debug_str_extra(), + ] + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return "\n".join(lines).rstrip() + + def debug_str_extra(self) -> str: + return "" + + def log_details(self): + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def update_mutated_names(self, renames: Dict[str, str]): + self.set_read_writes(self.read_writes.rename(renames)) + + def add_mutation_dep(self, dep): + self.set_read_writes(self.read_writes.with_read(dep)) + + def add_fake_dep(self, dep): + self.set_read_writes(self.read_writes.with_read(dep)) + + def set_users(self, users: List["NodeUser"]): + # deduplicate + result: Dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + result[id(use.node)] = use + self.users = list(result.values()) + + def set_last_usage( + self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] + ): + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = {mutation_real_name.get(k, k) for k in used_buffers} + self.last_usage = used_buffers - future_used_buffers + + def get_aliases(self): + 
return self.node.get_alias_names() + + def get_mutations(self): + return self.node.get_mutation_names() + + def has_aliasing_or_mutation(self): + return bool(self.get_aliases() or self.get_mutations()) + + def set_read_writes(self, rw: dependencies.ReadWrites): + self.read_writes: dependencies.ReadWrites = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def op_counts(self): + return self.read_writes.op_counts + + def used_buffer_names(self) -> Set[str]: + return { + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + } + + def used_or_aliased_buffer_names(self) -> Set[str]: + used_names = set() + + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes): + used_names.add(dep.name) + if V.graph.name_to_buffer.get(dep.name): + layout = V.graph.name_to_buffer[dep.name].get_layout() + # needed to avoid deallocating aliased buffer + # if there are still uses of aliases ahead + if isinstance(layout, ir.AliasedLayout): + used_names.add(layout.view.data.get_name()) + return used_names + + def prune_deps(self): + self.unmet_dependencies = { + dep + for dep in self.unmet_dependencies + if dep.name not in self.scheduler.available_buffer_names + } + + def prune_weak_deps(self): + # Prune weak dependencies on buffers that have been removed + def should_prune(dep): + return isinstance(dep, WeakDep) and dep.name in V.graph.removed_buffers + + to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)} + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps(self, name_to_fused_node): + _prune_redundant_deps(self, name_to_fused_node) + + def get_name(self) -> str: + return self.node.get_name() + + def get_first_name(self) -> str: + return self.get_name() + + def get_names(self) -> Set[str]: + return {self.get_name()} + + def get_nodes(self) -> Sequence["BaseSchedulerNode"]: + return [self] + + def get_device(self): + return self.node.get_device() + + def is_reduction(self): + return False + + def is_split_scan(self): + return False + + def is_template(self): + return False + + def is_extern(self): + return False + + def is_foreach(self): + return False + + def can_inplace(self, read_dep: dependencies.MemoryDep): + return False + + def has_side_effects(self): + return False + + def decide_inplace_update(self): + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. + """ + if not self.node.should_allocate(): + return + + if isinstance(self, (SchedulerNode,)) and ( + self.node.get_alias_names() or self.node.get_mutation_names() + ): + return + + if ( + ( + isinstance(self, (SchedulerNode,)) + # o what have i done. 
lets make this an api + or ( + isinstance(self, ExternKernelSchedulerNode) + and isinstance(self.node, (ir.AllReduce, ir.InPlaceHint)) + ) + ) + and config.inplace_buffers + and ( + not isinstance(V.kernel, torch._inductor.codegen.triton.TritonKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + ): + from .codegen.wrapper import buffer_reuse_key + + ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) + + for read in ordered_reads: + input_node: Optional[ + BaseSchedulerNode + ] = self.scheduler.name_to_node.get(read.name) + if input_node and V.graph.wrapper_code.can_reuse(input_node, self): + assert input_node.users is not None + remaining_uses = [ + x + for x in input_node.users + if x.node.get_name() + not in self.scheduler.available_buffer_names + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and not isinstance( + input_node.node.get_layout(), + ( + ir.MultiOutputLayout, + ir.MutationLayout, + ir.AliasedLayout, + ), + ) + and not ( + isinstance( + input_node.node, (ir.FallbackKernel, ir.MultiOutput) + ) + and len(input_node.node.get_alias_names()) > 0 + ) + and buffer_reuse_key(input_node.node) + == buffer_reuse_key(self.node) + ): + # hacky check for if V.kernel is a real kernel or NullHandler + if hasattr(V.kernel, "args"): + # if there isn't a triton kernel, then we don't need to call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) + V.kernel.args.make_inplace( + input_node.get_name(), self.get_name() + ) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.triton.TritonKernel + ): + V.kernel.mutations.add(input_node.get_name()) + V.kernel.mutations.add(self.get_name()) + + # update last usage of reused node + self.last_usage.discard(input_node.get_name()) + + V.kernel.inplace_update_buffers[ + self.get_name() + ] = input_node.get_name() + break + + def allocate(self): + if not self.node.should_allocate(): + return + + if isinstance(self, (SchedulerNode,)) and ( + self.node.get_alias_names() or self.node.get_mutation_names() + ): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + V.graph.wrapper_code.codegen_inplace_reuse( + self.scheduler.name_to_node[ + V.kernel.inplace_update_buffers[self.get_name()] + ].node, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self): + # There's no real allocated buffer, no need to free it + if isinstance(self.node.layout, ir.NoneLayout): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def codegen_originating_info(self, buffer, only_once=True): + if not config.comment_origin: + return + + if only_once and self.written: + return + origins = self.node.origins + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? 
+ out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.split("|")[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + buffer.writelines(out_lines) + self.written = True + + def get_read_write_buffers_sizes(self) -> int: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. 
The buffer size + """ + if isinstance(self, NopKernelSchedulerNode): + return 0 + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + return 0 + + if isinstance(self, SchedulerNode): + node_numel = V.graph.sizevars.size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]) + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + for dep in self.read_writes.reads | self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = {dep.name for dep in self.read_writes.reads} + writes = {dep.name for dep in self.read_writes.writes} + + def is_materialized(buf, snodes): + users = self.scheduler.name_to_node[buf].users + buf_uses = {user.node for user in users} + return len(buf_uses - set(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = { + dep for dep in writes if not is_materialized(dep, self.snodes) + } + writes = writes - removed_buffers + reads = reads - removed_buffers + node_bytes = 0 + + for buf_name in reads | writes: + buf_accessed_elems = sum([node_numel for dep in buf_accesses[buf_name]]) + buf: Union[ir.Buffer, ir.TensorBox] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_elems(buf): + return V.graph.sizevars.size_hint(sympy_product(buf.get_size())) + + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_node[buf.get_name()].users + buf_elems = sum(get_buf_elems(user.node.node) for user in users) + else: + buf_elems = get_buf_elems(buf) + + node_bytes += min(buf_elems, buf_accessed_elems) * get_dtype_size( + buf.get_dtype() + ) + + return node_bytes + + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + layout = None + dtype = None + if not hasattr(self, "node") or not self.node: + assert isinstance( + self, (FusedSchedulerNode, ForeachKernelSchedulerNode) + ), f"{type(self)=}" + assert self.snodes + if not self.snodes[0].node: + return 0 + layout = self.snodes[0].node.get_layout() + dtype = self.snodes[0].node.get_dtype() + else: + layout = self.node.get_layout() + dtype = self.node.get_dtype() + + if "cuda" != layout.device.type: + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + return estimate_nccl_collective_runtime(self.node) + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. 
+ return 0 + + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + except Exception: + return 0 + + if isinstance(self, ExternKernelSchedulerNode): + assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" + op = kernel_name_to_op.get( + getattr(self.node, "python_kernel_name", ""), None + ) + + # if there is a resolved op, dry-run using fake mode and record flop count + if op is not None: + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils.flop_counter import FlopCounterMode + + with FakeTensorMode(), FlopCounterMode( + display=False + ) as flop_counter_mode: + from .ir import ir_node_to_tensor + + fake_inputs = [ + ir_node_to_tensor(input, guard_shape=False) + for input in self.node.inputs + ] + cls = self.node.__class__ + cls.process_kernel(op, *fake_inputs, **self.node.kwargs) + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_flops = flop_counter_mode.get_total_flops() + counted_bytes = self.get_read_write_buffers_sizes() + compute_time = (factor * counted_flops / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) + + elif isinstance(self, FusedSchedulerNode) or isinstance( + self.node, ComputedBuffer + ): + # Return estimated runtime in nanoseconds (bytes / gbps) + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + + return 0 + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self): + return True + + def has_side_effects(self): + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + def can_inplace(self, read_dep: dependencies.MemoryDep): + if self.get_aliases() or self.is_template(): + return False + + if read_dep.name not in self.scheduler.name_to_node: + # don't allow reuse of an 'input' buffer, we don't own it + # (would this have been fixed if I tracked mutations properly above?) 
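+            # NOTE (editorial comment, not part of the upstream source): graph inputs
+            # are owned by the caller, so writing a kernel's output into one of them
+            # in place could clobber memory the user still references; the scheduler
+            # therefore only reuses buffers it created itself.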
+ return False + if not isinstance( + self.node, (torch._inductor.ir.AllReduce, torch._inductor.ir.InPlaceHint) + ): + # TODO make this a property of the IR + return False + + if len(self.read_writes.writes) == 1: + write_dep = next(iter(self.read_writes.writes)) + numel_diff = read_dep.get_numel() - write_dep.get_numel() + return V.graph.sizevars.simplify(numel_diff) == 0 + + return False + + +class NopKernelSchedulerNode(BaseSchedulerNode): + pass + + +class SchedulerNode(BaseSchedulerNode): + def __init__( + self, + scheduler: "Scheduler", + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ): + super().__init__(scheduler, node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + ): + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints + ) + + group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn + self.group = (self.node.get_device(), group_fn(self._sizes)) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes(self.node.normalized_read_writes()) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=True + ) + ) + + def recompute_size_and_body( + self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] + ): + self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + if self.get_aliases(): + lines.append(f"{name}.aliases = {pformat(self.get_aliases())}") + if self.get_mutations(): + lines.append(f"{name}.mutations = {pformat(self.get_mutations())}") + if isinstance(self._body, ir.LoopBody): + lines.append(f"class {name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + return "\n".join(lines) + + def get_ranges(self): + return self._sizes + + def is_reduction(self): + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return bool(self.node.get_reduction_type()) + + def is_split_scan(self): + assert isinstance( + self.node, (ir.ComputedBuffer, ir.TemplateBuffer) + ), f"{type(self.node)=}" + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self): + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self): + return self.node if self.is_template() else None + + def run(self, *index_vars): + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def mark_run(self): + self.allocate() + + def ranges_from_index_vars(self, index_vars): + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars): + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with V.set_ops_handler( + SimplifyIndexing(V.get_ops_handler(), var_ranges) + ), V.kernel.set_current_node(self): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + def pointwise_read_writes(self): + """ + Get the memory dependencies in the non-reduction axis. 
+ """ + sizes, reduction_sizes = self._sizes + + def fn(index): + return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) + + return dependencies.extract_read_writes(fn, sizes) + + def can_inplace(self, read_dep: dependencies.MemoryDep): + if self.get_aliases() or self.is_template(): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> Set[str]: + buffers_store_as_atomic_add = set() + if isinstance(self._body, ir.LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + def has_atomic_add(self, check_buf): + return check_buf in self._get_atomic_add_buffers() + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. + """ + + @classmethod + def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance( + node2, (SchedulerNode, FusedSchedulerNode) + ) + return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type] + + def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]): + # NB: No need to call super().__init__() because we don't need to re-use any of its logic. 
+ self.snodes = snodes + self.scheduler = scheduler + self.node: ir.Buffer = None # type: ignore[assignment] + self.users: List[NodeUser] = [] + self.inverse_users = [] + self.node_users = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + self.ancestors = set.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + self.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + self.unmet_dependencies = { + dep + for dep in set.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in self.get_names() + } - self.read_writes.writes + self.min_order = min([x.min_order for x in self.snodes]) + self.max_order = max([x.max_order for x in self.snodes]) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_names(self) -> Set[str]: + return set.union(*[x.get_names() for x in self.snodes]) + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def set_last_usage( + self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] + ): + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: Set[str] = set() + for node in reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) # type: ignore[arg-type] + + @cache_on_self + def used_buffer_names(self) -> Set[str]: + return set.union(*[x.used_buffer_names() for x in self.snodes]) + + @cache_on_self + def used_or_aliased_buffer_names(self) -> Set[str]: + return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes]) + + def get_nodes(self) -> List[SchedulerNode]: + return self.snodes + + def __repr__(self): + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self): + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_split_scan(self): + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self): + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self): + for node in self.snodes: + if node.is_template(): + return node + return None + + def get_device(self): + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self): + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + @cache_on_self + def op_counts(self): + op_counts: Counter[str] = collections.Counter() + for node in self.snodes: + op_counts.update(node.op_counts()) + return op_counts + + def has_atomic_add(self, check_buf): + return any( + ( + isinstance(sub_schedule_node1, SchedulerNode) + and sub_schedule_node1.has_atomic_add(check_buf) + ) + for sub_schedule_node1 in self.get_nodes() + ) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: Dict[str, str]): + raise NotImplementedError + + def add_mutation_dep(self, name): + raise NotImplementedError + + def set_users(self, users: List["NodeUser"]): + 
raise NotImplementedError + + def get_aliases(self): + raise NotImplementedError + + def get_mutations(self): + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.MemoryDep): + raise NotImplementedError + + def allocate(self): + raise NotImplementedError + + def can_free(self): + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + lines = [ + f"{name}: {type(self).__name__}({node_typestr})", + f"{name}.writes = {pformat(self.read_writes.writes)}", + f"{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}", + f"{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}", + f"{name}.users = {self.users}", + ] + try: + lines += [ + self.debug_str_extra(), + ] + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return "\n".join(lines).rstrip() + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """Scheduler node which consists of a list of scheduler nodes that each operate on a + distinct tensor in a list of tensors.""" + + def get_consumer_subnode_for(self, producer): + if producer.get_name() in self.read_to_node: + return self.read_to_node[producer.get_name()] + + return None + + def get_producer_subnode_for(self, consumer): + for rd in consumer.read_writes.reads: + if rd.name in self.name_to_node: + return self.name_to_node[rd.name] + + return None + + @classmethod + def can_fuse(cls, producer, consumer): + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse(cls, producer, consumer): + assert producer.is_foreach() or consumer.is_foreach() + prev_node_1 = None + prev_node_2 = None + if producer.is_foreach() and consumer.is_foreach(): + fused_nodes = [ + FusedSchedulerNode.fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer_subnode = producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + 
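+                    # NOTE (editorial comment, not part of the upstream source):
+                    # prev_node_1/prev_node_2 record the pre-existing foreach node and
+                    # the freshly fused subnode so __init__ can take its incremental
+                    # path, merging just those two nodes' read_writes instead of
+                    # recomputing everything from all subnodes.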
prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined] + + def __init__( + self, + scheduler: "Scheduler", + nodes: List[SchedulerNode], + prev_node_1=None, + prev_node_2=None, + ): + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, nodes) + + for node in nodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = nodes + self.node: ir.Buffer = None # type: ignore[assignment] + self.users: List[NodeUser] = [] + + self.set_read_writes( + dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = { + dep + for dep in set.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_names() + } - self.read_writes.writes + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + foreach_node = prev_node_1 if prev_node_1.is_foreach() else prev_node_2 + other_node = prev_node_2 if prev_node_1.is_foreach() else prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_names(): + self.name_to_node[name] = other_node + + self.group = (nodes[0].get_device(), "foreach") + + self.origins: Set[torch.fx.Node] = set() + + def mark_run(self): + raise NotImplementedError + + def codegen(self): + assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" + self.node.get_store_function()(self.node.make_loader()()) + + def can_free(self): + return NotImplementedError + + def is_foreach(self): + return True + + def get_subkernel_nodes(self): + """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists. + These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self): + """Returns all nodes contained in this kernel, unpacking fused nodes into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self): + return self.snodes[0].get_first_name() + + def prune_redundant_deps(self, name_to_fused_node): + _prune_redundant_deps(self, name_to_fused_node) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +def pick_loop_order(stride_lengths, sizes, priority_idx=()): + """ + A heuristic to decide loop iteration orders. This has not been well + tuned and may be something we should autotune. 
+ """ + + @functools.cmp_to_key + def index_cmp(a, b): + if sizes[a] == 1 or sizes[b] == 1: + # 1-sizes don't matter, just move them to the end + return cmp(sizes[a] == 1, sizes[b] == 1) + + stride_len_a = [sl[a] for sl in stride_lengths] + stride_len_b = [sl[b] for sl in stride_lengths] + + # equivalent to + # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() + a_first = sum( + sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + b_first = sum( + sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + if a_first > b_first: + return -1 + if b_first > a_first: + return 1 + + # otherwise contiguous + return cmp(b, a) + + order = list(reversed(range(len(stride_lengths[0])))) + if len(priority_idx) > 0: + # if we have priority node, only use that node's order + stride_lengths = [stride_lengths[pi] for pi in priority_idx] + if config.pick_loop_orders: + order.sort(key=index_cmp) + return order + + +@dataclasses.dataclass +class NodeUser: + node: BaseSchedulerNode + can_inplace: bool = False + + # A weak user must be scheduled after a given node, but doesn't actually + # use the result + is_weak: bool = False + + def __hash__(self): + return hash((self.node.get_name(), self.can_inplace, self.is_weak)) + + def __eq__(self, other): + return ( + self.get_name() == other.get_name() + and self.can_inplace == other.can_inplace + and self.is_weak == other.is_weak + ) + + def get_name(self): + return self.node.get_name() + + def merge(self, other: "NodeUser") -> "NodeUser": + assert self.node is other.node + return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +class Scheduler: + @dynamo_timed + def __init__(self, nodes): + super().__init__() + self.backends = {} + self.fuse_cache = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + + self.nodes = [] + self.available_buffer_names = { + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + } + + self.nodes = [self.create_scheduler_node(n) for n in nodes] + + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + self.name_to_node: Dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_fused_node: Dict[ + str, BaseSchedulerNode + ] = dict() # set in fuse_nodes() + + # mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name = {} + + # We handle mutation by renaming modified versions of the same + # buffer in the dependency graph to prevent cycles. 
+ # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames = {} + + self.compute_dependencies() + self.topological_sort_schedule() + self.dead_node_elimination() + if config.reorder_for_compute_comm_overlap: + comms.decide_global_ordering_of_comms(self.nodes) + self.compute_ancestors() + + metrics.ir_nodes_pre_fusion += len(self.nodes) + V.debug.ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.create_foreach_nodes() + self.topological_sort_schedule() + self.logged_slow_fusion = set() + self.fuse_nodes() + if config.reorder_for_compute_comm_overlap: + # Refresh node_users and inverse_users to reflect fused nodes + self.compute_node_users() + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + self.compute_last_usage() + V.debug.ir_post_fusion(self.nodes) + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.current_device: torch.device = None # type: ignore[assignment] + self.buffer_names_to_free = set() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + ) + + def debug_draw_graph(self): + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label): + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node): + assert ( + node.origins is not None + ), "All nodes passed to scheduling must have an origin" + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise NotImplementedError(node) + + def create_foreach_nodes(self): + removed_node_names = set() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) + snodes = [self.name_to_node[name] for name in names] + + fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type] + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + fe_nodes + + def compute_dependencies(self): + """ + Create dependency edges between nodes, handling aliasing and + mutation properly. + """ + + T = TypeVar("T") + + class DedupList(Generic[T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. 
+ Normally one could use a set/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. + """ + + def __init__(self, items=None, membership=None): + self.items = items or list() + self.membership = membership or set() + + def append(self, node_user: T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: "DedupList[T]") -> "DedupList[T]": + new_membership = set.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node1 in self.nodes: + node1_name = node1.get_name() + for node2_name in node1.get_aliases(): + if node1_name in name_to_users and node2_name in name_to_users: + # merge the two + list1 = name_to_users[node1_name] + list2 = name_to_users[node2_name] + combined = list1 + list2 + for key in name_to_users.keys(): + if name_to_users[key] is list1 or name_to_users[key] is list2: + name_to_users[key] = combined + elif node1_name in name_to_users: + name_to_users[node2_name] = name_to_users[node1_name] + else: + name_to_users[node1_name] = name_to_users[node2_name] + + def rename(n): + if n in self.mutation_renames: + return rename(self.mutation_renames[n]) + return n + + def dep_closure(node_name): + reachable_names = {node_name} + node = self.name_to_node[node_name] + write_dep = next(iter(node.read_writes.writes)) + for read_dep in node.read_writes.reads: + if ( + read_dep.name in self.name_to_node + and isinstance(read_dep, dependencies.MemoryDep) + and isinstance(write_dep, dependencies.MemoryDep) + and read_dep.index == write_dep.index + and read_dep.size == write_dep.size + ): + reachable_names.update(dep_closure(read_dep.name)) + return reachable_names + + def add_user(used_by_name, user_node, can_inplace=False, is_weak=False): + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + unbacked_symbol_to_origin_node = {} + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. 
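+                # NOTE (editorial comment, not part of the upstream source):
+                # hypothetical example: if buf3 is produced by a data-dependent op
+                # (say nonzero()) that defines an unbacked symint u0, buf3 is recorded
+                # as u0's origin here, and any later node whose
+                # get_unbacked_symbol_uses() contains u0 gains a StarDep("buf3") below,
+                # forcing it to be scheduled after buf3.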
+ if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node + + unbacked_symbol_uses = sorted( + node.node.get_unbacked_symbol_uses(), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node}" + node.add_fake_dep(StarDep(unbacked_symbol_to_origin_node[s].get_name())) + + # a node will mutate either 0 or 1 buffers + assert len(node.get_mutations()) <= 1 + for alt_name in node.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_mutation_dep(StarDep(alt_name)) + for other_node in name_to_users[alt_name].items: + # this node must run after all prior readers + other_name = rename(other_node.get_name()) + known_dep_node_names = dep_closure(node.get_name()) + if other_name not in known_dep_node_names: + # If this node already directly or indirectly depends on other_node, + # we don't need to insert an extra dep. + node.add_mutation_dep(WeakDep(other_name)) + add_user(other_name, node, is_weak=True) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + is_weak = isinstance(read, WeakDep) + add_user(read.name, node, node.can_inplace(read), is_weak) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for alt_name in node.get_mutations(): + self.mutation_renames[rename(alt_name)] = node.get_name() + self.mutation_renames[alt_name] = node.get_name() + self.mutation_real_name[node.get_name()] = self.mutation_real_name.get( + alt_name, alt_name + ) + + # make sure outputs aren't dead-code-eliminated + for node_name in V.graph.get_output_names(): + log.debug("scheduling output %s", node_name) + add_user(node_name, OutputNode(StarDep(node_name))) + + # make sure unbacked symints aren't dead-code-eliminated + for node in V.graph.graph_outputs: + for s in node.get_unbacked_symbol_uses(): + assert ( + s in unbacked_symbol_to_origin_node + ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + node_name = unbacked_symbol_to_origin_node[s].node.name + log.debug("scheduling output %s for unbacked symint %s", node_name, s) + add_user(node_name, OutputNode(StarDep(node_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users information onto the nodes + for node in self.nodes: + node.set_users(name_to_users[node.get_name()].items) + + # populate inverse_users + for node in self.nodes: + for user in node.users: + user.node.inverse_users.append(node) + + def compute_node_users(self): + # set up buffer name to (fused)snode mapping + buf_to_snode = {} + for node in self.nodes: + if isinstance(node, FusedSchedulerNode): + for x in node.snodes: + buf_to_snode[x.get_name()] = node + buf_to_snode[node.get_name()] = node + + for node in self.nodes: + node.node_users = [] + node.inverse_users = [] + + # compute inverse_users + for node in self.nodes: + inverse_users = [] + for dep in node.unmet_dependencies: + assert dep.name in buf_to_snode + dep_node = buf_to_snode[dep.name] + inverse_users.append(dep_node) + 
node.inverse_users = inverse_users + + # compute node_users + # TODO: ideally, we should deduplicate .users and .node_users, + # but currently .users contains extra information that's difficult to + # extract into a standalone container. + node_to_users: Dict[BaseSchedulerNode, List[BaseSchedulerNode]] = {} + for node in self.nodes: + for inverse_user in node.inverse_users: + node_to_users.setdefault(inverse_user, []).append(node) + for node, users in node_to_users.items(): + node.node_users = users + + def dead_node_elimination(self): + """ + Remove any nodes without users + """ + again = True # repeat until a fixed point + while again: + updated_nodes = [] + for node in self.nodes: + + def can_eliminate_user(user: NodeUser): + return user.is_weak or user.get_name() in V.graph.removed_buffers + + can_eliminate = not node.has_side_effects() and all( + can_eliminate_user(u) for u in node.users + ) + + if not can_eliminate: + updated_nodes.append(node) + else: + # dead code + log.debug("removed dead node: %s", node.get_name()) + V.graph.removed_buffers.add(node.get_name()) + + again = len(self.nodes) > len(updated_nodes) + self.nodes = updated_nodes + + # Prune any WeakDeps no longer needed + for node in self.nodes: + node.prune_weak_deps() + + def topological_sort_schedule(self): + """ + Ensure self.nodes is in topologically sorted order + """ + seen: Set[ir.Buffer] = set() + name_to_node: Dict[str, ir.Buffer] = dict() + result: List[ir.Buffer] = [] + + def visit(n): + if n not in seen: + seen.add(n) + for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): + visit(name_to_node[dep.name]) + result.append(n) + + for node in self.nodes: + for name in node.get_names(): + name_to_node[name] = node + for node in self.nodes: + visit(node) + self.nodes = result + + def compute_ancestors(self): + """ + Populate each node.ancestors + """ + # note self.nodes is topologically sorted + name_to_ancestors: Dict[str, Set[str]] = {} + for node in self.nodes: + ancestors = set() + for dep in node.unmet_dependencies: + ancestors.add(dep.name) + ancestors |= name_to_ancestors[dep.name] + name_to_ancestors[node.get_name()] = ancestors + node.ancestors = ancestors + + for order, node in enumerate(self.nodes): + node.min_order = order + node.max_order = order + + def fuse_nodes(self): + """ + Mutates self.nodes to combine nodes into FusedSchedulerNodes. + """ + for i in range(10): + old_len = len(self.nodes) + fusion_log.debug( + "===== attempting fusion (%d/10): %d nodes =====", i + 1, old_len + ) + self.fuse_nodes_once() + new_len = len(self.nodes) + fusion_log.debug( + "completed fusion round (%d/10): fused %d nodes into %d nodes\n", + i + 1, + old_len, + new_len, + ) + if new_len == old_len or new_len == 1: + fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) + break + + def benchmark_fused_nodes(self, nodes): + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + assert len(nodes) > 0 + device = nodes[0].get_device() + V.graph.scheduler = self + self.current_device = device + backend = self.get_backend(device) + return backend.benchmark_fused_nodes(nodes) + + def speedup_by_fusion(self, node1, node2): + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. 
+ """ + if not config.benchmark_fusion: + return True + + if ( + node1.is_template() + and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + or node1.is_foreach() + or node2.is_foreach() + ): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + + # don't support benchmark fusion for CPU right now. + if device.type == "cpu": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = node_list_1 + node_list_2 + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. + if any( + hasattr(n.node, "data") + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list_fused + ): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + try: + ms1, path1 = self.benchmark_fused_nodes(node_list_1) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + ms2, path2 = self.benchmark_fused_nodes(node_list_2) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + except CompilationError as e: + # workaround triton issue: https://github.com/openai/triton/issues/2151 + if "Loop-carried variable" in str(e): + return True # allow fusion + else: + raise + + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_names(), + node2.get_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_names(), + node2.get_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + return ms_fused < ms1 + ms2 + + def fuse_nodes_once(self): + """ + Mutates self.nodes to combine nodes into FusedSchedulerNodes. 
+ + This relies on two key functions to control the logic: + - self.can_fuse(): checks if a fusion is legal + - self.score_fusion(): assigns priority to a given fusion + """ + fused_nodes = set(self.nodes) + for node1, node2 in self.get_possible_fusions(): + node1 = self.name_to_fused_node[node1.get_first_name()] + node2 = self.name_to_fused_node[node2.get_first_name()] + if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( + node1, node2 + ): + if not self.speedup_by_fusion(node1, node2): + continue + fusion_log.debug( + "fusing %s with %s", node1.get_name(), node2.get_name() + ) + + # above can_fuse asserts that node2 has the same device + device = node1.get_device() + node3 = self.get_backend(device).fuse(node1, node2) + fused_nodes.remove(node1) + fused_nodes.remove(node2) + fused_nodes.add(node3) + self.name_to_fused_node.update( + {n.get_name(): node3 for n in node3.get_nodes()} + ) + self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) + self.topological_sort_schedule() + self.prune_redundant_deps() + + def prune_redundant_deps(self): + for node in self.nodes: + node.prune_redundant_deps(self.name_to_fused_node) + + def get_possible_fusions(self): + """ + Helper to find all legal fusion opportunities, sorted by self.score_fusion() + """ + possible_fusions = [] + seen = set() + + def check_all_pairs(nodes): + for node1_index, node1 in enumerate(nodes): + for node2 in nodes[node1_index + 1 :]: + key = (node1, node2) + if key in seen: + continue + seen.add(key) + + if self.can_fuse(node1, node2): + possible_fusions.append(key) + elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( + node2, node1 + ): + # foreach fusions and epilogue fusions are order dependent + possible_fusions.append((node2, node1)) + + buffer_names_grouping = collections.defaultdict(list) + for node in self.nodes: + for buf in node.used_buffer_names(): + buffer_names_grouping[buf].append(node) + for node_grouping in buffer_names_grouping.values(): + check_all_pairs(node_grouping) + + if config.aggressive_fusion: + group_grouping = collections.defaultdict(list) + for node in self.nodes: + group = getattr(node, "group", None) + if group: + group_grouping[group].append(node) + for node_grouping in group_grouping.values(): + check_all_pairs(node_grouping) + + possible_fusions.sort(key=self.score_fusion_key, reverse=True) + fusion_log.debug("found %d possible fusions", len(possible_fusions)) + return possible_fusions + + def will_fusion_create_cycle(self, node1, node2): + """ + Finds whether there's a path from node1 to node2 (or vice-versa) + caused indirectly by other fusions. + """ + + def found_path(node): + # only fused nodes can introduce new ancestors. + if isinstance(node, FusedSchedulerNode) and node not in visited: + visited.add(node) + if node.get_names().issubset(combined_ancestors): + # All fusion outputs are in ancestors of node1 and node2, thus + # cannot introduce new path: + # + # 1. if output is neither descendent of node1 or node2, the + # output cannot introduce a path + # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be + # on path(node1->node2), hence it cannot be ancestor of node2 + # 3. 
due to [acyclic]: if WLOG output is descendent of node1, it cannot be + # ancestor of node1 + return False + else: + # continue DFS of new ancestors introduced by the fusion + return bool(combined_names & node.ancestors) or any( + found_path(self.name_to_fused_node[n]) + for n in node.ancestors - combined_ancestors + ) + return False + + visited = set() + combined_names = node1.get_names() | node2.get_names() + combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names + cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) + if cycle: + WhyNoFuse(node1, node2)("will create cycle") + return cycle + + def can_fusion_increase_peak_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ): + """ + This function prevents fusion for nodes that can increase memory + footprint. This problem is more common in horizontal fusion, where nodes + that are far apart in the original order get fused, lengthening the live + intervals of tensors. This is very evident in models with activation + checkpointing, where the recomputed nodes from different checkpointed + regions get fused and significantly increase the memory footprint. + + The current attempt is a quick, possibly hacky, heuristic to prevent the + fusion of nodes that are far away in the original order. + + A better but difficult to implement heurisitic would be to use live + intervals of the buffers, find region of peak pressure in the original + program and prevent fusion that crosses that peak region. We might need + special care or good approximation in this implementation, as fusion of + node changes live intervals, and re-computing live intervals and peak + memory after each fusion can introduce large compilation overhead. + """ + proximity_score = max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return proximity_score > 64 + + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Determine if it is possible to combine node1 and node2 into a + single fused node. + """ + + if node1 is node2: + return False + + why = WhyNoFuse(node1, node2) + + if ( + isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node1.is_template() + ): + why("node1 is extern or nop") + return False + if ( + isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node2.is_template() + ): + why("node2 is extern or nop") + return False + + if node2.get_names() & node1.ancestors: + why("node1 must go before node2") + return False + + if ( + isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + and isinstance(node2, SchedulerNode) + and isinstance(node2._body, ir.LoopBody) + ): + # Fix issue: https://github.com/pytorch/pytorch/issues/108963 + # Check: + # If node2 reads a buf which is a mutation buf of node1(SchedulerNode) or among nodes in node1(FusedSchedulerNode), + # we will get the corresponding mutation buf and check if this mutation buf is stored by atomic_add mode. + # If True, we will disable the fusion of node1 and node2. 
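+            # NOTE (editorial comment, not part of the upstream source): hypothetical
+            # example: if node1 stores into buf0 with mode="atomic_add" and node2 reads
+            # a name that mutation_renames resolves to buf0, the check below refuses
+            # the fusion (see the issue referenced above for the motivating bug).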
+ if any( + ( + node2_used_buf in self.mutation_renames + and node1.has_atomic_add(self.mutation_renames[node2_used_buf]) + ) + for node2_used_buf in node2._body.reads_name2expr.keys() + ): + return False + + if node2.is_template(): + why("templates can only fuse epilogues") + return False + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + device = node1.get_device() + device2 = node2.get_device() + if device != device2: + why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + why("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + why("exceeds max fusion") + return False # heuristic not needed for correctness + + if node1.get_names() & node2.ancestors: + # node2 depends on node1 outputs + if not self.can_fuse_vertical(node1, node2): + return False + return self.get_backend(device).can_fuse_vertical(node1, node2) + else: # nodes don't depend on each other, but may have common reads + if self.can_fusion_increase_peak_memory(node1, node2): + why("will increase peak memory") + return False + return self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical(self, node1, node2): + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + + We also disable fusion of a write subsequent to a read if the reads + and writes do not align. + """ + node1_names = node1.get_names() + computed_deps = set() + why = WhyNoFuse(node1, node2) + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(read: Dep, write: Dep): + return ( + self.mutation_renames.get(read.name, read.name) == write.name + and (isinstance(read, MemoryDep) and isinstance(write, MemoryDep)) + and not free_symbol_has(read.index, "tmp") + and not free_symbol_has(write.index, "tmp") + and read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + + for rd in node2.unmet_dependencies: + for cd in node1.read_writes.writes: + if fusable_read_and_write(rd, cd): + computed_deps.add(rd) + + remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps} + if remaining_deps & node1_names: + # MemoryDeps didn't match and read different locations of the same buffer. 
+ # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + for name in remaining_deps: + if node1_names & self.name_to_fused_node[name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + # similar to can_inplace, if we are going to fuse a write subsequent to a read + # require that the indexing and size is the same + for write in node2.read_writes.writes: + for read in node1.read_writes.reads: + if write.name != self.mutation_renames.get(read.name, read.name): + continue + + # bail on StarDep + if not fusable_read_and_write(read=read, write=write): + why("fusing a write into a read with different indexing formula") + return False + + return True + + def score_fusion(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Assign a score (higher comes first) to the fusion of node1 + and node2. When different fusions conflict with each other, + this is the way we decide what order to run them in. + + Our current score is based on: + - Estimate of the saved memory operations + - Fusions closer together in original order + """ + memory_score = self.score_fusion_memory(node1, node2) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return ( + node1.is_template() == config.epilogue_fusion_first and memory_score > 0, + node1.is_reduction() == node2.is_reduction() and memory_score > 0, + memory_score, + proximity_score, + ) + + def score_fusion_memory(self, node1, node2): + """ + The first term in our fusion score that estimates number of saved memory operations. + """ + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + common_memory_deps = { + dep for dep in common_memory_deps if not dep.has_unbacked_symbols() + } + return sum(dep.numbytes_hint() for dep in common_memory_deps) + + def score_fusion_key(self, nodes): + """ + Shim for list.sort(key=...) + """ + node1, node2 = nodes + return self.score_fusion(node1, node2) + + def compute_last_usage(self): + """ + Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) + """ + + future_used_buffers = set() + for node_name in V.graph.get_output_names(): + future_used_buffers.add(node_name) + + for node in reversed(self.nodes): + node.set_last_usage(future_used_buffers, self.mutation_real_name) + future_used_buffers.update(node.last_usage) + + def free_buffers(self): + """Free any buffers that are no longer needed""" + for name in sorted( + self.buffer_names_to_free + - V.graph.removed_buffers + - V.graph.wrapper_code.freed + ): + if name in self.name_to_node: + node = self.name_to_node[name] + if node.can_free(): + V.graph.wrapper_code.codegen_free(node.node) + elif name in V.graph.graph_inputs: + storage = V.graph.graph_inputs[name].data + assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + V.graph.wrapper_code.codegen_free(storage.data) + + self.buffer_names_to_free.clear() + + def remove_kernel_local_buffers(self): + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. 
+ """ + + # V.kernel.store_buffer_names should represent the set of nodes + # get fused + fused_node_names = V.kernel.store_buffer_names + names_to_remove = [] + for out_buf in V.kernel.store_buffer_names: + users = self.name_to_node[out_buf].users + assert users is not None + users = {user.get_name() for user in users if not user.is_weak} + if users.issubset(fused_node_names): + names_to_remove.append(out_buf) + + def remove_filter(n): + return ( + n not in V.kernel.must_keep_buffers + and n not in V.kernel.args.input_buffers + and n not in self.mutation_renames + and n not in self.mutation_real_name + ) + + names_to_remove = list(filter(remove_filter, names_to_remove)) + + for name in names_to_remove: + if name in V.kernel.args.inplace_buffers: + buf = V.kernel.args.inplace_buffers[name] + if isinstance(buf, str) and buf.startswith("REMOVED"): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + V.kernel.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name): + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + V.kernel.args.output_buffers[name] = "REMOVED" + V.kernel.removed_buffers.add(name) + + def remove_inplace_buffer(self, name): + log.debug("removing_inplace_buffer(%r)", name) + inner_name = V.kernel.args.inplace_buffers[name].inner_name + V.kernel.args.inplace_buffers[name] = inner_name.replace( + "in_out_ptr", "REMOVED" + ) + V.kernel.removed_buffers.add(name) + + def flush(self): + for backend in self.backends.values(): + backend.flush() + self.free_buffers() + + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode): + assert isinstance(scheduler_node, ExternKernelSchedulerNode) + # 'decide_inplace_update' stores the inplace update decisions in + # the current kernel from where 'allocate' retrieve those decisions. + # We have to make sure there is a non-NULL kernel handler to store + # those inplace update decisions. + with V.set_kernel_handler(Kernel(increase_kernel_count=False)): + scheduler_node.decide_inplace_update() + scheduler_node.allocate() + node = scheduler_node.node + assert isinstance(node, ir.ExternKernel), f"{type(node)=}" + node.codegen(V.graph.wrapper_code) + self.free_buffers() + + def create_backend(self, device: torch.device): + assert ( + device.type != "cuda" or device.index is not None + ), f"{device} should have been normalized in lowering" + V.graph.add_device_info(device) + + device_scheduling = get_scheduling_for_device(device.type) + if device_scheduling is None: + raise RuntimeError(f"Unsupported device type: {device.type}") + + if device.type == "cuda" and not has_triton(): + device_props = torch.cuda.get_device_properties(device) + if device_props.major < 7: + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + else: + raise RuntimeError( + "Cannot find a working triton installation. 
More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) + + return device_scheduling(self) + + def get_backend(self, device: torch.device): + if device not in self.backends: + self.backends[device] = self.create_backend(device) + return self.backends[device] + + def enter_context(self, node): + def get_order(n): + if n not in self.origin_to_index: + self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.origin_to_index[n] + + # Use a dict to have ordering + origins = { + (get_order(e), e): None for n in node.get_nodes() for e in n.node.origins + } + origins = list(origins.keys()) + if origins: + _, last = max(origins, key=operator.itemgetter(0)) + V.graph.wrapper_code.enter_context(last) + + @dynamo_timed + def codegen(self): + for node in self.nodes: + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception as e: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) + + self.enter_context(node) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if ( + device != self.current_device + or node.is_extern() + or node.is_template() + ): + self.flush() + if device != self.current_device: + if device.type == "cuda": + if self.current_device and self.current_device.type == "cuda": + V.graph.wrapper_code.codegen_device_guard_exit() + assert device.index is not None, "device should have an index" + V.graph.wrapper_code.codegen_device_guard_enter(device.index) + elif self.current_device and self.current_device.type == "cuda": + V.graph.wrapper_code.codegen_device_guard_exit() + self.current_device = device + + self.buffer_names_to_free.update(node.last_usage) + + if node.is_template(): + node, *epilogue = node.get_nodes() + self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined] + elif node.is_extern(): + self.codegen_extern_call(node) + elif node.is_foreach(): + self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined] + elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): + self.get_backend(device).codegen_nodes(node.get_nodes()) # type: ignore[possibly-undefined] + else: + assert isinstance(node, NopKernelSchedulerNode) + node.allocate() + + if config.debug_check_inf_and_nan: + V.graph.wrapper_code.generate_inf_and_nan_checker(node) + + if config.triton.debug_sync_kernel: + self.get_backend(device).codegen_sync() # type: ignore[possibly-undefined] + + self.available_buffer_names.update(node.get_names()) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if self.get_backend(device).ready_to_flush(): + self.flush() + + if self.current_device and self.current_device.type == "cuda": + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. 
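+            # NOTE (editorial comment, not part of the upstream source): this closes
+            # the device guard opened by codegen_device_guard_enter() when the first
+            # CUDA node was emitted, keeping the generated wrapper's indentation
+            # balanced before the final flush().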
+ V.graph.wrapper_code.codegen_device_guard_exit() + + self.flush() + + def is_unaligned_buffer(self, buf_name): + if buf_name in V.graph.graph_inputs or buf_name in V.graph.constants: + # all graph inputs or constants are assumed to be aligned + return False + node = self.name_to_node[buf_name] + layout = node.node.get_layout() + if isinstance(layout, ir.AliasedLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +class BaseScheduling: + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Check whether node1 and node2 can be vertically fused or not. + """ + raise NotImplementedError() + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Check whether node1 and node2 can be horizontally fused or not. + """ + raise NotImplementedError() + + def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + """ + Fuse two nodes + """ + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def group_fn(self, sizes): + """ + Process the iteration sizes in case a transformation needs to be applied. + """ + raise NotImplementedError() + + def codegen_template( + self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Given a template node, generate a kernel. + + This function is only available for triton now. If the third-party backend behaves as a sub-class + of TritonScheduling, it can override it or reuse it. + """ + raise NotImplementedError() + + def codegen_nodes(self, nodes: List[SchedulerNode]): + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError() + + def codegen_sync(self): + """ + Generate synchronization code for the kernel. This method depends on the hardware characteristics. + """ + raise NotImplementedError() + + def ready_to_flush(self) -> bool: + """ + Check whether the backend is requesting the scheduler to flush the generated kernel. + If not supported, please return False. + """ + return False + + def flush(self): + """ + Flush the generated kernel and python wrapper code to the source code file. + """ + raise NotImplementedError() + + def benchmark_fused_nodes(self, nodes): + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError() diff --git a/MLPY/Lib/site-packages/torch/_inductor/select_algorithm.py b/MLPY/Lib/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..5865c3d9d0b9a04b509f49e1f562296928352c05 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,1156 @@ +import builtins +import functools +import inspect +import itertools +import logging +import operator +import sys +import textwrap +import time +from concurrent.futures import ThreadPoolExecutor +from io import StringIO + +from typing import Any, Callable, Dict, List, Optional, Union +from unittest.mock import patch + +import sympy + +import torch +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters, identity, preserve_rng_state + +from . 
import config, ir +from .autotune_process import TensorMeta, TritonBenchmarkRequest +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import ( + ChoiceCaller, + IndentedBuffer, + KernelTemplate, + PrimitiveInfoType, +) +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TritonKernel, + TritonPrinter, + TritonScheduling, +) +from .codegen.triton_utils import config_of, signature_to_meta +from .exc import CUDACompileError +from .utils import ( + do_bench, + get_dtype_size, + Placeholder, + sympy_dot, + sympy_product, + unique, +) +from .virtualized import V + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: Dict[str, Any] = dict() +PRINT_AUTOTUNE = True +DEBUG = False + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. + """ + + def __init__(self, code, replacement_hooks): + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize(self): + code = self.code + assert code is not None, "can only be called once" + self.code = None + for key, fn in self.replacement_hooks.items(): + code = code.replace(key, fn()) + return code + + +class TritonTemplateKernel(TritonKernel): + def __init__( + self, + kernel_name, + input_nodes, + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + use_jit=True, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + *, + index_dtype, + ): + super().__init__( + sympy_product(output_node.get_size()), + sympy.Integer(1), + index_dtype=index_dtype, + ) + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} + self.defines = defines + self.kernel_name = kernel_name + self.template_mask = None + self.use_jit = use_jit + self.num_stages = num_stages + self.num_warps = num_warps + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + self.epilogue_fn = epilogue_fn + self.render_hooks = dict() + self.triton_meta: Optional[Dict[str, object]] = None + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. 
+ """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size()) + numel = functools.reduce(operator.mul, size) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature = self.args.python_argdefs() + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index] + triton_meta["constants"][arg_num] = 1 # type: ignore[index] + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + return f""" + @triton_heuristics.template( + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + @triton.jit + """ + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. + """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. + for name in argnames: + input_node = self.named_input_nodes[name] + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):") + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def size(self, name: str, index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. 
+ """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_stride()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_stride()[index] + return texpr(self.rename_indexing(val)) + + def store_output(self, indices, val, mask): + """ + Hook called from template code to store the final output + (if the buffer hasn't been optimized away), then append any + epilogue fusions. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, str) + assert self.template_mask is None + indices = list(map(TritonPrinter.paren, indices)) + index_symbols = [sympy.Symbol(x) for x in indices] + lengths = [V.graph.sizevars.simplify(s) for s in self.output_node.get_size()] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name( + "xindex" + ) + self.template_mask = mask + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()(index_symbols) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex") + + epilogue_args = [val] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_args.append(input_node.make_loader()(index_symbols)) + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + ) + self.codegen_body() + + def hook(): + # more stuff might have been added since the codegen_body above + self.codegen_body() + return textwrap.indent(self.body.getvalue(), " ").strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render(self, template, kwargs): + return PartialRender( + template.render(**self.template_env(), **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.named_input_nodes[name].get_stride() + indices = list(map(TritonPrinter.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask})" + + def template_env(self): + """ + Generate the namespace visible in the template. 
+ """ + return { + fn.__name__: fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.make_load, + ] + } + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. + """ + return super().indexing( + index, + dense_indexing=False, + copy_shape=self.template_mask, + override_mask=self.template_mask, + block_ptr=block_ptr, + ) + + def initialize_range_tree(self, pid_cache): + super().initialize_range_tree(pid_cache) + # ignore default codegen + self.body.clear() + self.indexing_code.clear() + + def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): + wrapper = V.graph.wrapper_code + _, call_args, _ = self.args.python_argdefs() + call_args = [str(a) for a in call_args] + + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + if isinstance(call_args[i], sympy.Symbol): + call_args[i] = texpr(call_args[i]) + + if V.graph.cpp_wrapper: + # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime + # if any dynamic dimension is involved. We rely on the Python version + # of the grid function to generate those grid configs, which may contain + # symbolic values. The wrapper will use cexpr to print out C++ code + # appropriately for the grid configs. + grid_args = [V.graph.sizevars.simplify(s) for s in self.call_sizes] + [ + self.meta + ] + grid = self.grid_fn(*grid_args) + + wrapper.generate_kernel_call( + name, + call_args, + device_index=V.graph.scheduler.current_device.index, + grid=grid, + triton_meta=self.triton_meta, + ) + else: + stream_name = wrapper.write_get_raw_stream( + V.graph.scheduler.current_device.index + ) + + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + + grid_call = [ + texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes + ] + [meta] + grid_call = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})" + wrapper.writeline( + f"{name}.run({', '.join(call_args)}, grid={grid_call}, stream={stream_name})" + ) + + +@functools.lru_cache(None) +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class TritonTemplate(KernelTemplate): + index_counter = itertools.count() + all_templates: Dict[str, "TritonTemplate"] = dict() + + def __init__(self, name: str, grid: Any, source: str, debug=False): + super().__init__(name) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + self.all_templates[name] = self + self.debug = debug + + def generate( + self, + input_nodes, + layout, + num_stages, + num_warps, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + **kwargs, + ): + assert self.template, "requires jinja2" + defines = StringIO() + for name, val in kwargs.items(): + defines.write(f" {name} : tl.constexpr = {val}\n") + defines = defines.getvalue() + + fake_out = ir.Buffer("buf_out", layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) + + kernel_options = dict( + 
input_nodes=input_nodes, + defines=defines, + num_stages=num_stages, + num_warps=num_warps, + grid_fn=self.grid, + meta=kwargs, + call_sizes=layout.size, + prefix_args=prefix_args, + suffix_args=suffix_args, + epilogue_fn=epilogue_fn, + index_dtype="tl.int32", + ) + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(fake_out) + ), TritonTemplateKernel( + kernel_name=kernel_name, + output_node=fake_out, + use_jit=True, + **kernel_options, + ) as kernel: + try: + code = kernel.render(self.template, kwargs).finalize() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + extra = ( + "-".join( + [ + *[ + f"{kwarg}={repr(kwargs[kwarg])}" + for kwarg in sorted(kwargs.keys()) + ], + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + + "-" + ) + mod = PyCodeCache.load(code, extra) + _, call_args, _ = kernel.args.python_argdefs() + + expected_args = list(unique(x.get_name() for x in input_nodes)) + expected_args.extend([fake_out.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]), + fallback=config.unbacked_symint_fallback, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + def make_kernel_render(out_node): + kernel = TritonTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + use_jit=False, + **kernel_options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + layout.size, + fallback=config.unbacked_symint_fallback, + ), + kwargs, + ) + bmreq = TritonBenchmarkRequest( + module_path=mod.__file__, + module_cache_key=mod.key, + kernel_name=kernel_name, + grid=grid, + extra_args=extra_args, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + input_nodes, + layout, + make_kernel_render, + extra.strip("-").replace("-", ", "), + bmreq, + log_info={ + "tile_shape": str( + ( + kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "acc_type": str(kwargs.get("ACC_TYPE", None)), + }, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + ): + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), "duplicate extern kernel" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.lru_cache(None) + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: 
+ parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + +class TritonTemplateCaller(ChoiceCaller): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + debug_extra, + bmreq, + log_info: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.make_kernel_render = make_kernel_render + self.debug_extra = debug_extra + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: Dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "grid": str(self.bmreq.grid), + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + + def benchmark(self, *args, out): + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def __str__(self): + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + ) + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + +class ExternKernelCaller(ChoiceCaller): + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ): + super().__init__(choice.name, input_nodes, layout) + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def __str__(self): + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + if self.has_out_variant: + return super().benchmark(*args, out=out) + else: + algo = self.to_callable() + out_new = algo(*args) + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + return do_bench(lambda: algo(*args)) + + def to_callable(self): + fn = self.choice.to_callable() + if self.kwargs: + return functools.partial(fn, **self.kwargs) + else: + return fn + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if config.abi_compatible and self.choice.use_fallback_kernel: + assert ( + self.choice.op_overload is not None + ), "Please provide an op_overload to use ir.FallbackKernel" + inner = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + 
ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "extern", + "kernel_call_name": self.choice.call_name(), + } + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str): + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class AlgorithmSelectorCache(PersistentCache): + def __call__( + self, + name, + choices: List[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # TODO(nmacchioni): remove once CI tests are fixed + choices = [choice for choice in choices if choice is not None] + if len(choices) == 0: + raise RuntimeError( + "No choices to select, please consider adding ATEN into max_autotune_gemm_backends " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + @functools.lru_cache(None) + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + def precompile(choices): + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + return + num_workers = min( + config.compile_threads, + torch.get_num_threads(), + len(choices), + ) + if num_workers <= 0: + return + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = executor.map( + lambda c: c.precompile(), + [c for c in choices if hasattr(c, "precompile")], + timeout=precompilation_timeout_seconds, + ) + try: + iterator = iter(futures) + while True: + try: + next(iterator) + except CUDACompileError: + log.error( # noqa: G201 + "CUDA Compilation error", exc_info=True + ) + except TimeoutError: + log.warning( + f"Precompilation timed out after {precompilation_timeout_seconds} seconds." # noqa: G004 + ) + except StopIteration: + pass + executor.shutdown(wait=True) + + def autotune(choices): + try: + precompile(choices) + except TimeoutError: + log.warning( + "Precompilation phase took longer than timeout allowed. 
Continuing" + ) + pass + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + from .autotune_process import tuning_pool + + # do the optional warmup + tuning_pool.initialize() + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + repr([self.key_of(x) for x in input_nodes]), + autotune, + ) + autotune_elapse = time.time() - autotune_start_ts + if timings == {} or choices[0] not in timings: + return choices[0].output_node() + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results(name, input_nodes, timings, autotune_elapse) + selected_choice = builtins.min(timings, key=timings.__getitem__).output_node() + log.debug("selected choice: %s", str(selected_choice)) + return selected_choice + + @classmethod + def make_benchmark_fn( + cls, + choices, + input_nodes, + layout, + input_gen_fns=None, + ): + if input_gen_fns is None: + input_gen_fns = {} + + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + for input_node in input_nodes + ] + + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + if DEBUG: + print(f"{len(choices)} tuning requests:") + + def debug_str(): + def tensor_repr(x): + return ( + f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, " + f"dtype={x.dtype!r}, device={x.device.type!r})" + ) + + lines = [ + "inputs = [", + ] + for x in example_inputs: + lines.append(f" {tensor_repr(x)},") + lines += ["]", f"out = {tensor_repr(out)}", ""] + return "\n".join(lines) + + def benchmark_choice_in_current_process(choice): + out.zero_() + if isinstance(choice, ExternKernelCaller): + # aten kernels want the offset baked in for sliced tensors + result = choice.benchmark(*example_inputs_extern, out=out_extern) + else: + # triton templates want the base pointer for sliced tensors + result = choice.benchmark(*example_inputs, out=out) + if VERIFY: + torch.testing.assert_close(out_extern, expected, **VERIFY) + torch.cuda.synchronize() # shake out any CUDA errors + return result + + def benchmark_in_current_process(choices): + timings = {} + for choice in choices: + try: + timing = benchmark_choice_in_current_process(choice) + except CUDACompileError as e: + log.warning( + "CUDA compilation error: \n%s. 
\nIgnore this choice.", str(e) + ) + timing = float("inf") + except RuntimeError as e: + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + log.warning(msg) + timing = float("inf") + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + raise ErrorFromChoice(msg, choice, debug_str()) # noqa: TRY200 + except AssertionError as e: + raise AssertionError( # noqa: TRY200 + f"Incorrect result from choice {choice}\n\n{e}" + ) + + timings[choice] = timing + + return timings + + def benchmark_in_sub_process(choices): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. + extern = [c for c in choices if isinstance(c, ExternKernelCaller)] + triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] + + timings = benchmark_in_current_process(extern) + timings.update(autotune_process.benchmark_in_sub_process(triton)) + return timings + + benchmark = ( + benchmark_in_sub_process + if config.autotune_in_subproc + else benchmark_in_current_process + ) + + return benchmark + + @staticmethod + def log_results( + name: str, + input_nodes: List[ir.IRNode], + timings: Dict[ChoiceCaller, float], + elapse: float, + ): + V.debug.log_autotuning_results(name, input_nodes, timings, elapse) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), fallback=config.unbacked_symint_fallback + ), + ) + ) + for n in input_nodes + ] + ) + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 + top_k = sorted(timings, key=timings.__getitem__)[:n] + best = top_k[0] + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + for choice in top_k: + result = timings[choice] + if result: + sys.stderr.write( + f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + sys.stderr.write(f"{autotune_type_str} AUTOTUNE takes {elapse:.4f} seconds\n") + + @staticmethod + def benchmark_example_value(node): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer("fake", node) + # triton templates want the base tensor. + if isinstance(node, ir.BaseView): + node = node.unwrap_view() + # preserve rng states to avoid the rand_strided call below changes + # the rng states for the real model code. + with preserve_rng_state(): + return rand_strided( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + device=node.get_device(), + dtype=node.get_dtype(), + extra_size=node.layout.offset, + ) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. 
+ """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . import lowering # noqa: F401 diff --git a/MLPY/Lib/site-packages/torch/_inductor/sizevars.py b/MLPY/Lib/site-packages/torch/_inductor/sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0faf55efa24630c6de37e96304ece7ca3bcf8d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/sizevars.py @@ -0,0 +1,643 @@ +import functools +import itertools +import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union + +import sympy +from sympy import Expr + +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import bound_sympy + +from .utils import sympy_index_symbol, sympy_subs, VarRanges +from .virtualized import V + +log = logging.getLogger(__name__) + + +# This class is a little awkward, because ShapeEnv is doing most of the heavy +# lifting and in some cases we should be directly passing through to ShapeEnv, +# but there is some extra inductor logic that needs to be handled here +class SizeVarAllocator: + def __init__(self, shape_env=None): + super().__init__() + if shape_env is None: + shape_env = ShapeEnv() + self.shape_env = shape_env + self.var_to_val = self.shape_env.var_to_val + self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements + # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. + # The basic idea is if we have some complicated sympy expression + # f(s0), we may choose to precompute it on the host and then replace + # all occurrences of that sympy expression with ps0, so that when we + # codegen we simply reference ps0 directly without repeating + # f(s0). Unlike regular size variables, ps variables cannot be + # guarded upon; so if we are asked to guard on a Sympy expression + # which potentially could have already had a precomputed replacement + # on it, we are obligated to invert the precomputed replacements + # (inv_precomputed_replacements). 
+ self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict() + self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict() + self.stride_vars = self.make_stride_vars_cache() + self.simplify_with_ranges = self.make_simplify_with_ranges_cache() + self._simplify_loops = self.make_simplify_loops_cache() + + def simplify(self, expr: Expr): + return sympy.expand(expr).xreplace(self.replacements) + + def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]: + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: Dict[Tuple[Any, ...], Expr] = dict() + replacement_count = len(self.replacements) + + def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr: + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (expr, *var_ranges.items()) + result = cache.get(key, None) + if result is None: + result = self._simplify_with_ranges(expr, var_ranges) + cache[key] = result + return result + + return simplify_with_ranges + + def make_simplify_loops_cache(self): + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: Dict[Tuple[Any, ...], Any] = dict() + replacement_count = len(self.replacements) + + def simplify_loops(index_vars, sizes, index_formulas): + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (*index_vars, *sizes, *index_formulas) + result = cache.get(key, None) + if result is None: + result = self._simplify_loops_impl(index_vars, sizes, index_formulas) + cache[key] = result + return result + + return simplify_loops + + def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr: + """ + Simplify indexing expression with knowledge of the ranges of + iteration variables. 
+ """ + + expr = join_dimensions(self.simplify(expr)) + original_expr = expr + + def remove_zero_terms(base, divisor): + """Symbols smaller than the divisor are zero""" + for v in base.free_symbols: + if v in var_ranges: + # var smaller than divisor can be removed + # if the rest is guaranteed to be multiple of divisor + rest = sympy.Wild("_rest", exclude=[v]) + m = base.match(v + rest) + if m and v not in m[rest].free_symbols: + gcd = sympy.gcd(m[rest], divisor) + if gcd == divisor: + if self.statically_known_leq(var_ranges[v], divisor): + base = m[rest] + return base + + def visit_indexing_div(base, divisor): + return FloorDiv(remove_zero_terms(base, divisor), divisor) + + def visit_modular_indexing(base, divisor, modulus): + base = remove_zero_terms(base, divisor) + base_pos = True + if isinstance(base, ModularIndexing): + # for modular indexing, biggest values from the ranges don't necessarily result in + # the biggest result, the biggest result is modulus - 1 + base_s = base.args[2] - 1 + elif not base.has(ModularIndexing): + # actual iteration range is to size-1 + iter_ranges_zero = {k: 0 for k, v in var_ranges.items()} + base_lowest = sympy_subs(base, iter_ranges_zero) + if self.statically_known_leq(0, base_lowest): # type: ignore[arg-type] + # can't replace with indexing div if base can be negative + base_pos = True + else: + base_pos = False + iter_ranges = {k: v - 1 for k, v in var_ranges.items()} + base_s = sympy_subs(base, iter_ranges) + else: + base_s = base + if self.statically_known_lt(base_s, modulus * divisor) and base_pos: + return FloorDiv(base, divisor) + return ModularIndexing(base, divisor, modulus) + + if expr.has(ModularIndexing): + expr = expr.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + + if expr.has(FloorDiv): + expr = expr.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_indexing_div, + ) + + if expr != original_expr: + return self._simplify_with_ranges(expr, var_ranges) + return expr + + def _simplify_loops_impl( + self, index_vars: List[sympy.Symbol], sizes, index_formulas + ): + """ + Try to remove as many axis from loop iterations as possible, by: + 1) removing size==1 dimensions + 2) fuse contiguous dimensions into a single loop + If channel_last = True, we will prevent the last dim fused with other dims + """ + sizes = list(map(self.simplify, sizes)) + + strides = [self.stride_vars(x, index_vars) for x in index_formulas] + assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) + + for i in range(len(sizes)): + if sizes[i] == 1: + # remove dim + sizes[i] = None + + def can_merge_dims(a, b): + for k in range(len(strides)): + if self.simplify(strides[k][a] * sizes[a]) == self.simplify( + strides[k][b] + ): + # approximate test passed, try sound version + va = index_vars[a] + vb = index_vars[b] + v = sympy_index_symbol("_merge_tester") + expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0}) + expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v}) + if self.simplify(expr1) == self.simplify(expr2): + continue + return False + return True + + changed = True + while changed: + changed = False + for i, j in itertools.product( + reversed(range(len(sizes))), reversed(range(len(sizes))) + ): + if i == j or sizes[i] is None or sizes[j] is None: + continue + if can_merge_dims(i, j): + changed = True + sizes[i] = sizes[i] * sizes[j] + sizes[j] = None + + def reindex(index): + it = list(reversed(index)) + new_index = [] + for 
size in sizes: + if size is None: + new_index.append(sympy.Integer(0)) + else: + new_index.append(it.pop()) + assert not it + return new_index + + def prune(index): + assert len(index) == len(sizes) + return [i for i, s in zip(index, sizes) if s is not None] + + return [x for x in sizes if x is not None], reindex, prune + + # Note - [On Statically Known] + # + # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system + # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was + # true, we add a guard and return True, otherwise, False. + # + # def maybe_guard_foo(args): + # if size_hinted_check(args): + # return False # No guard, no optim + # guard(args) # Make a guard + # return True # Safe to apply optimization + # + # The prior system incurred a guard, and green lit an optimization. + # + # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the + # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we + # return False. + # + # def maybe_guard_foo(args): + # if all_static(args): + # return True # Safe to apply optimization + # else: + # return False # No guard, no optim + + # See Note - [On Statically Known] + + def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool: + if expr in (True, False): + return bool(expr) + + try: + simplified = self.shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr) + + return False + + def statically_known_equals(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + # See Note - [On Statically Known] + def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. + """ + if len(left) != len(right): + return False + if all(self.statically_known_equals(l, r) for l, r in zip(left, right)): + return True + return False + + # See Note - [On Statically Known] + def statically_known_leq(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_lt(self, left: Expr, right: Expr) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. + """ + expr = left < right + return self.is_expr_static_and_true(expr) + + # See Note - [On Statically Known] + def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + expr = sympy.Eq(numerator % denominator, 0) + return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. 
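+    # For example (illustrative usage only): call guard_equals(a, b) when the
+    # equality is already known to hold and only needs to be recorded as a
+    # guard, but call evaluate_min(a, b) when you still need to learn which
+    # operand is smaller and want to guard on that answer.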
+ + def guard_equals(self, left: Expr, right: Expr) -> Expr: + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) + return left + + def guard_leq(self, left: Expr, right: Expr) -> None: + return self.guard_lt(left, right + 1) + + def guard_lt(self, left: Expr, right: Expr) -> None: + assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) + + def expect_true(self, expr: Expr, *, msg: str) -> None: + expr = sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + self.shape_env.defer_runtime_assert(expr, msg, fx_node=None) + + def expect_equals(self, left: Expr, right: Expr, *, msg: str) -> Expr: + # Prefer returning the expression without unbacked symints + if self.shape_env.is_unbacked_symint(left): + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] + return right + elif self.shape_env.is_unbacked_symint(right): + self.expect_true(sympy.Eq(left, right), msg=msg) # type: ignore[arg-type] + return left + else: + return self.guard_equals(left, right) + + def guarded_order(self, seq): + """ + Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. + Used for generating block_ptrs. + """ + seq = [*map(self.remove_precomputed_replacements, seq)] + seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)] + seq.sort() + order = [-1] * len(seq) + last_var = None + for new_index, (_, orig_index, var) in enumerate(seq): + order[orig_index] = new_index + if last_var is not None: + self.guard_leq(last_var, var) + last_var = var + return order + + # The evaluate functions evaluate some symbolic sympy expression + # (NB: not necessarily an Expr) and return what the concrete result + # is, guarding on the expression being that result + + # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b) + # as this will ensure that you actually have a sympy'ified expression, + # and will prevent you from incorrectly writing evaluate_expr(a == b) + # which does the wrong thing if a or b is a sympy expression + def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool: + assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) + return self.shape_env.evaluate_expr(sympy.sympify(left)) + + def evaluate_min(self, left: Expr, right: Expr) -> Expr: + """return the smaller of left and right, and guard on that choice""" + lv = self.size_hint(left) + rv = self.size_hint(right) + if lv <= rv: + self.guard_leq(left, right) + return left + else: + self.guard_leq(right, left) + return right + + def evaluate_max(self, left: Expr, right: Expr) -> Expr: + """return the larger of left and right, and guard on that choice""" + # Always choose the opposite of eval min for consistency + # This means min(a, b) and max(a, b) produce the same guards + min_val = self.evaluate_min(left, right) + return right if min_val is left else left + + def evaluate_static_shape(self, left: Expr) -> int: + right = self.size_hint(left) + self.guard_equals(left, sympy.Integer(right)) + return int(right) + + def evaluate_static_shapes(self, left: List[Expr]) -> List[int]: + return [self.evaluate_static_shape(x) for x in left] + + def remove_precomputed_replacements(self, expr: Expr) -> Expr: + if any(s.name.startswith("ps") for s in 
expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + return expr + + def symbolic_hint(self, expr: Expr) -> Expr: + # Substitute all hints into expr, but leave unbacked symints alone + if not isinstance(expr, Expr): + assert isinstance(expr, int) + return expr + free_symbols = expr.free_symbols + if not free_symbols: + return int(expr) # type: ignore[return-value] + expr = self.remove_precomputed_replacements(expr) + return sympy_subs(expr, self.var_to_val) + + def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int: + out = self.symbolic_hint(expr) + if not isinstance(out, (int, sympy.Integer)) and fallback is not None: + # Use the provided heuristic fallback hint + sym_vrs = { + s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols + } + if all(vr is not None for vr in sym_vrs.values()): + expr_vr = bound_sympy(expr, sym_vrs) # type: ignore[arg-type] + lower = self.size_hint(expr_vr.lower) # type: ignore[arg-type] + upper = self.size_hint(expr_vr.upper) # type: ignore[arg-type] + fallback = min(max(fallback, lower), upper) + return fallback + try: + return int(out) + except Exception: + log.debug("failed on: %s", out) + raise + + def size_hints( + self, + exprs: Iterable[Expr], + *, + fallback: Optional[int] = None, + ) -> Tuple[int, ...]: + return tuple(self.size_hint(x, fallback=fallback) for x in exprs) + + def _lru_cache(self, fn, maxsize=None): + """ + Wrapper around functools.lru_cache that clears when replacements + has been invalidated. + """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: List[sympy.Symbol], + support_vars: Optional[List[sympy.Symbol]] = None, + ) -> List[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol] + ) -> List[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.Integer(0) for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.Integer(0) + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.Integer(0)) + else: + # TODO(jansel): should we use sympy.diff here? 
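+                # The stride is recovered as a finite difference: substitute
+                # v=1 and v=0 into index_dim and subtract, so a plain strided
+                # index such as 10*i0 + i1 (an illustrative example) yields a
+                # stride of 10 with respect to i0.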
+ strides.append( + sympy_subs(index_dim, {v: sympy.Integer(1)}) + - sympy_subs(index_dim, {v: sympy.Integer(0)}) + ) + return strides + + def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: List[sympy.Symbol], + support_vars: Optional[List[sympy.Symbol]] = None, + ) -> List[int]: + for v in index.free_symbols: + if v.name.startswith("indirect"): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol(f"ps{len(self.precomputed_replacements)}") + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> Set[sympy.Symbol]: + return set(self.var_to_val.keys()) - set(self.replacements.keys()) + + +def join_dimensions(expr: Expr) -> Expr: + if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): + return expr # fast exit path + return _join_dimensions_cached(expr) + + +@functools.lru_cache(256) +def _join_dimensions_cached(expr: Expr) -> Expr: + """ + ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) + becomes + ModularIndexing(i0, 1, 128) + ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32) + becomes i0 + + + This type of pattern can come from view operations + """ + assert isinstance(expr, sympy.Add) + + scale = sympy.Wild("scale", exclude=[0]) + base = sympy.Wild("base") + divisor = sympy.Wild("divisor") + mod1 = sympy.Wild("modulus") + mod2 = sympy.Wild("modulus2") + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to + simplify ModularIndexing/FloorDiv. 
+ """ + + def __init__(self, inner, var_ranges: VarRanges): + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[ + [Expr], Expr + ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) diff --git a/MLPY/Lib/site-packages/torch/_inductor/test_case.py b/MLPY/Lib/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..545c33dc4b952d5352e61caee1bcc7429458e007 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,53 @@ +import contextlib +import tempfile +import unittest + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) + +from torch._inductor import config + + +def run_tests(needs=()): + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. + """ + + _stack: contextlib.ExitStack + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context(config.patch({"fx_graph_cache": True})) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._stack.close() + + def setUp(self): + super().setUp() + + # For all tests, mock the tmp directory populated by the inductor + # FxGraphCache, both for test isolation and to avoid filling disk. + self._inductor_cache_tmp_dir = tempfile.TemporaryDirectory() + self._inductor_cache_get_tmp_dir_patch = unittest.mock.patch( + "torch._inductor.codecache.FxGraphCache._get_tmp_dir" + ) + mock_get_dir = self._inductor_cache_get_tmp_dir_patch.start() + mock_get_dir.return_value = self._inductor_cache_tmp_dir.name + + def tearDown(self): + super().tearDown() + + # Clean up the FxGraphCache tmp dir. 
+ self._inductor_cache_get_tmp_dir_patch.stop() + self._inductor_cache_tmp_dir.cleanup() diff --git a/MLPY/Lib/site-packages/torch/_inductor/test_operators.py b/MLPY/Lib/site-packages/torch/_inductor/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..9d41e9c1e6b41b27b468c8d48199c4bbfe792706 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/test_operators.py @@ -0,0 +1,24 @@ +import torch.library +from torch import Tensor +from torch.autograd import Function + +_test_lib_def = torch.library.Library("_inductor_test", "DEF") +_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) + +_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") +for dispatch_key in ("CPU", "CUDA", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + +class Realize(Function): + @staticmethod + def forward(ctx, x): + return torch.ops._inductor_test.realize(x) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/MLPY/Lib/site-packages/torch/_inductor/triton_helpers.py b/MLPY/Lib/site-packages/torch/_inductor/triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3773267cd6a0c062037f60f9e9943873d1fddaf6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/triton_helpers.py @@ -0,0 +1,344 @@ +import triton +import triton.language as tl + +# In the latest triton, math functions were shuffled around into different modules: +# https://github.com/openai/triton/pull/3172 +if hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math +else: + libdevice = tl.math + math = tl + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def _prod_accumulate(a, b): + return a * b + + +@triton.jit +def prod(input, axis): + return tl.reduce(input, axis, _prod_accumulate) + + +@triton.jit +def minimum(a, b): + mask = a < b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def min2(a, dim): + return tl.reduce(a, dim, minimum) + + +@triton.jit +def max2(a, dim): + return tl.reduce(a, dim, maximum) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def min_with_index(value, index, dim): + return tl.reduce((value, index), dim, 
minimum_with_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def welford_reduce(value, mean, m2, weight, first_iteration): + if first_iteration: + new_weight = tl.full(weight.shape, 1, weight.dtype) + new_mean = value + new_m2 = tl.zeros_like(m2) + else: + delta = value - mean + new_weight = weight + 1 + new_mean = mean + delta / new_weight + new_m2 = m2 + delta * (value - new_mean) + return new_mean, new_m2, new_weight + + +@triton.jit +def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def welford(mean, m2, weight, dim): + return tl.reduce((mean, m2, weight), dim, welford_combine) + + +@triton.jit +def device_assert_then(cond, msg, r): + tl.device_assert(cond, msg) + return r + + +@triton.jit +def randint64(seed, offset, low, high): + r0, r1, r2, r3 = tl.randint4x(seed, offset) + r0 = r0.to(tl.uint64) + r1 = r1.to(tl.uint64) + result = r0 | (r1 << 32) + size = high - low + result = result % size.to(tl.uint64) + result = result.to(tl.int64) + low + return result + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def any(a, dim): + return tl.reduce(a, dim, _any_combine) + + +@triton.jit +def bucketize_binary_search( + values, # 1D tensor + offsets_ptr, + indexing_dtype, + right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op] + OFFSETS_SIZE: int, + BLOCK_SHAPE, # tuple/list of block shape +): + """ + See [Note: Inductor bucketize op] + """ + + low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype) + high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype) + + full_range = OFFSETS_SIZE + 1 + while full_range > 1: + mid = (high + low) // 2 + mask = mid < OFFSETS_SIZE + bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask) + if right: + is_above = values >= bucket_upper_bound + else: + is_above = values > bucket_upper_bound + + low = tl.where(is_above & mask, mid + 1, low) + high = tl.where(is_above, high, mid) + + full_range = (full_range + 1) // 2 + + return low + + +@triton.jit +def pack_value_flag( + value, + flag, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) + return flag.to(DTYPE_PACK) | (uv << bitwidth) + + +@triton.jit +def unpack_value( + pack, + DTYPE_VALUE, + DTYPE_VALUE_AS_UINT, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) + return value_uint.to(DTYPE_VALUE, bitcast=True) + + +@triton.jit +def unpack_flag(pack, DTYPE_FLAG): + return pack.to(DTYPE_FLAG) + + +@triton.jit +def exclusive_scan_decoupled_lookback( + scratch_base, + block_value, + index, + combine_fn, + init, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + """Compute exclusive scan of a scalar value between blocks + 
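+    That is, block ``index`` receives ``combine_fn`` folded over all earlier
+    blocks' values (``init`` for block 0); the block's own value is excluded.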
+ Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identiy of combine_fn + DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` + DTYPE_PACK: Unsigned type twice the width of block_value + + NOTE: This function is limited to values which are 32-bits or less. + """ + DTYPE_VALUE = block_value.dtype + pack = pack_value_flag( + block_value, + tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + + exclusive_prefix = init + test_target = index - 1 + while test_target >= 0: + # tl.atomic_load + flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) + while flag == 0: + pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") + flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) + + value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) + exclusive_prefix = combine_fn(value, exclusive_prefix) + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + pack = pack_value_flag( + inclusive_prefix, + tl.full([], 2, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + return exclusive_prefix + + +@triton.jit +def exclusive_scan_decoupled_lookback_64( + scratch_base, block_value, index, combine_fn, init +): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block, must be 64-bits wide + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identiy of combine_fn + """ + block_value_u64 = block_value.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 1, block_value_u64) + tl.debug_barrier() + flag_one = tl.full([], 1, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") + + exclusive_prefix = init + test_target = index - 1 + while test_target >= 0: + flag = tl.full([], 0, tl.uint64) + while flag == 0: + flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") + + value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) + value = value_u64.to(block_value.dtype, bitcast=True) + exclusive_prefix = combine_fn(value, exclusive_prefix) + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) + tl.debug_barrier() + flag_two = tl.full([], 2, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") + + return exclusive_prefix + + +@triton.jit +def frexp(x): + # TODO(isuruf): use inline_asm_elementwise here + y = libdevice.ilogb(x) + 1 + exponent = tl.where(x == 0, 0, y) + 
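+    # For nonzero x this matches frexp: x == mantissa * 2**exponent with |mantissa| in [0.5, 1).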
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) + return mantissa, exponent diff --git a/MLPY/Lib/site-packages/torch/_inductor/triton_heuristics.py b/MLPY/Lib/site-packages/torch/_inductor/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..a0589405e555857aeb14813dc96d66204fc08ddf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/triton_heuristics.py @@ -0,0 +1,1527 @@ +import builtins +import copy +import functools +import hashlib +import inspect +import json +import logging +import math +import operator +import os +import os.path +import re +import threading +from enum import auto, Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple + +import torch + +import torch.autograd.profiler as autograd_profiler +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import dynamo_timed, get_first_attr +from torch.utils._triton import has_triton_package + +from . import config +from .codecache import cache_dir, CudaKernelParamCache +from .coordinate_descent_tuner import CoordescTuner + +from .ir import ReductionHint, TileHint +from .utils import ( + ceildiv, + conditional_product, + create_bandwidth_info_str, + do_bench, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_config_to_hashable, +) + + +log = logging.getLogger(__name__) + +if has_triton_package(): + import triton + from triton import Config + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import KernelInterface + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None +else: + Config = object + triton = None + KernelInterface = object + OutOfResources = object + ASTSource = None + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + + +class AutotuneHint(Enum): + ELEMENTS_PER_WARP_32 = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +def autotune_hints_to_configs( + hints: Set[AutotuneHint], size_hints, block_size: int +) -> List[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. + """ + xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...] 
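+    # Each xyz tuple selected below becomes one extra candidate Config, built with num_elements_per_warp=32.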
+ configs = [] + + for hint in hints: + if hint == AutotuneHint.ELEMENTS_PER_WARP_32: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + for xyz in xyz_options: + configs.append( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=32, + ) + ) + + return configs + + +def disable_pointwise_autotuning(): + # Autotuning can give different benchmarking results from run to run, and + # therefore we disable autotuning when use_deterministic flag is on. + if torch.are_deterministic_algorithms_enabled(): + return True + return not config.triton.autotune_pointwise + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. + """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + self.fn = fn + self.triton_meta = triton_meta + self.inductor_meta = {} if inductor_meta is None else inductor_meta + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.configs = configs + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + + # Align the default design that default as cuda + self.device_type = ( + triton_meta["device_type"] if "device_type" in triton_meta else "cuda" + ) + self.gpu_device = get_interface_for_device(self.device_type) + + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.launchers = [] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = os.path.join( + cache_dir(), + "triton", + str(self.triton_meta.get("device", 0)), + ) + + self.size_hints = size_hints + self.coordesc_tuner = CoordescTuner( + is_mm=False, name=self.fn.__name__, size_hints=size_hints + ) + + # pre-create the profiler context manager to reduce latency + self.record_function_ctx = torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel") + ) + + def precompile(self, warm_cache_only_with_cc=None): + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + if not self.configs: + raise RuntimeError("No triton configs are available") + + for c in self.configs: + try: + compiled_binary, launcher = self._precompile_config( + c, warm_cache_only_with_cc + ) + except OutOfResources: + # Skip the config if we run out of resource + continue + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + + if len(self.launchers) == 0: + raise RuntimeError( + "No valid triton configs. 
Report a fatal compilation error" + ) + + seen_configs = set(self.configs) + + device_prop = self.gpu_device.Worker.get_device_properties( + self.triton_meta["device"] + ) + if ( + config.dynamic_scale_rblock + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for AMDGPU as Triton is not ready to return n_regs for a compiled_binary. + and torch.version.hip is None + and device_prop.major >= 8 + ): + for triton_config, compiled_binary in zip( + self.configs, compiled_binaries + ): + assert len(self.size_hints) == 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + rblock = triton_config.kwargs["RBLOCK"] + total_block = (self.size_hints[0] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblock is not too small + if rblock <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce RBLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. + # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * 32 + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tigher upper bound can reveal more optimization opportunities. + max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if ( + total_block + <= max_blocks_per_sm * device_prop.multi_processor_count + ): + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + new_config.kwargs["RBLOCK"] = rblock // 2 + if new_config in seen_configs: + continue + seen_configs.add(new_config) + self.launchers.append( + self._precompile_config(new_config, warm_cache_only_with_cc)[1] + ) + self.configs = None + + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + for k, v in cfg.kwargs.items(): + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + compile_meta["debug"] = ( + config.assert_indirect_indexing and torch.version.hip is None + ) + + # Setting device_type="hip" required on ROCm to pass down to triton + compile_meta["device_type"] = ( + self.device_type if torch.version.hip is None else "hip" + ) + + if warm_cache_only_with_cc: + cc = warm_cache_only_with_cc + else: + # Use device_type 'cuda' for both cuda and hip devices to retrieve + # the compute capability. 
+ device_type = self.device_type if torch.version.hip is None else "cuda" + device_id = compile_meta["device"] + device = torch.device(device_type, device_id) + cc = self.gpu_device.get_compute_capability(device) + + compile_meta["cc"] = cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + target = (compile_meta["device_type"], cc) + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + + if warm_cache_only_with_cc: + return ( + triton.compile(*compile_args, **compile_kwargs), + None, + ) + + # load binary to the correct device + with self.gpu_device.device(compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + self.gpu_device.synchronize(self.gpu_device.current_device()) + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + binary._init_handles() + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": binary.launch_enter_hook, + "launch_exit_hook": binary.launch_exit_hook, + "metadata": binary.metadata, + "torch": torch, + "set_device": self.gpu_device.set_device, + "current_device": self.gpu_device.current_device, + } + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + scope["function"] = get_first_attr(binary, "function", "cu_function") + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + scope["shared"] = binary_shared + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + runner(grid_0, grid_1, grid_2, num_warps, + *cta_args, shared, + stream, function, + launch_enter_hook, + launch_exit_hook, + metadata, + {', '.join(call_args)}) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = config.triton.store_cubin + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def bench(self, launcher, *args, grid, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs wiht spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; 
(ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. + if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + + def kernel_call(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid, + stream=stream, + ) + + return do_bench(kernel_call, rep=40, fast_flush=True) + + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: + from .compile_fx import clone_preserve_strides + + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + + cloned_kwargs: Dict[str, Any] = {} + for name, arg in kwargs.items(): + if name in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_kwargs[name] = clone_preserve_strides(arg) + else: + cloned_kwargs[name] = arg + + return cloned_args, cloned_kwargs + + @dynamo_timed + def benchmark_all_configs(self, *args, **kwargs): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + timings = self.benchmark_all_configs(*args, **kwargs) + self.launchers = [builtins.min(timings, key=timings.get)] + if self.save_cache_hook: + self.save_cache_hook(self.launchers[0].config) + + def save_cuda_kernel(self, grid, stream, launcher): + if callable(grid): + grid_x, grid_y, grid_z = grid(launcher.config.kwargs) + else: + grid_x, grid_y, grid_z = grid + + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"], + "grid_x": grid_x, + "grid_y": grid_y, + "grid_z": grid_z, + "x_block": launcher.config.kwargs.get("XBLOCK", 1), + "y_block": launcher.config.kwargs.get("YBLOCK", None), + "z_block": launcher.config.kwargs.get("ZBLOCK", None), + "num_warps": launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps, + "shared_mem": launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared, + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": launcher.config.kwargs, + } + + if torch.version.hip is None: + 
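+            # CUDA path: the compiled cubin bytes are available directly in the binary's asm dict.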
CudaKernelParamCache.set(key, params, launcher.bin.asm["cubin"]) + else: + # There is some divergence between CUDA and ROCm here. + # On ROCm's triton we only have the the path to the binary, not the binary itself. + # For ROCm we will copy the binary to the new location instead of writing to file + import pathlib + + launcher.bin.asm["hsaco"] = pathlib.Path( + launcher.bin.asm["hsaco_path"] + ).read_bytes() + CudaKernelParamCache.set(key, params, launcher.bin.asm["hsaco"]) + + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate descnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. + """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + + cloned_args, _ = self.clone_args(*args) + config2launcher = {launcher.config: launcher} + + def benchmark_one_config(config): + with self.lock: + _, launcher = self._precompile_config(config, None) + config2launcher[config] = launcher + + out = self.bench(launcher, *cloned_args, **kwargs) + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "RBLOCK" in launcher.config.kwargs + ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK" + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook(best_config, found_by_coordesc=True) + return config2launcher.get(best_config) + + def run(self, *args, grid, stream, **kwargs): + if len(self.launchers) != 1: + if len(self.launchers) == 0: + self.precompile() + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid, **kwargs) + + if ( + not getattr(self.launchers[0].config, "found_by_coordesc", False) + and config.coordinate_descent_tuning + ): + self.launchers = [ + self.coordinate_descent_tuning( + self.launchers[0], *args, grid=grid, **kwargs + ) + ] + + (launcher,) = self.launchers + if launcher.store_cubin: + self.save_cuda_kernel(grid, stream, launcher) + + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs} + ) + + # guard the record_function_ctx and only call it if profiling is currently + # in progress, to reduce latency when profiler is not turned on. Note that + # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional; + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. 
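+        # record_function_ctx was pre-created in __init__, so only the enter/exit cost is paid here.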
+ if autograd_profiler._is_profiler_enabled: + with self.record_function_ctx: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: List[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s" + ) + print(summary_str) + print() + output_file = config.profile_bandwidth_output + if output_file is not None: + # sort perf numbers in descending order, i.e. placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.debug("Save profile bandwidth results to %s", output_file) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms/overall_time*100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning( + "failed to write profile bandwidth result into %s: %s", + output_file, + e, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, grid=grid) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def hash_configs(configs: List[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def load_cached_autotuning( + best_config, + configs_hash: str, + configs: List[Config], +): + if best_config is None: + 
return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. + """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + save_cache_hook: Optional[Callable[[Any, Any], Any]] + inductor_meta = {} if inductor_meta is None else inductor_meta + + # on disk caching logic and/or remote caching + if filename is not None and (len(configs) > 1 or config.coordinate_descent_tuning): + configs_hash = hash_configs(configs) + + cache_filename = None + remote_cache = None + remote_cache_key = None + if config.use_autotune_local_cache: + cache_filename = os.path.splitext(filename)[0] + ".best_config" + if config.use_autotune_remote_cache or ( + config.is_fbcode() + and torch._utils_internal.justknobs_check( + "pytorch/autotune_remote_cache:enable" + ) + ): + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is not None: + key = backend_hash + configs_hash + "autotune-best-config" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if config.is_fbcode(): + remote_cache = ( + triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend( + key, is_autotune=True + ) + ) + else: + remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + else: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + + best_config = None + if cache_filename is not None and os.path.exists(cache_filename): + with open(cache_filename) as fd: + best_config = json.loads(fd.read()) + elif remote_cache is not None and remote_cache_key is not None: + cache_outs = remote_cache.get([remote_cache_key]) + cache_out = cache_outs.get(remote_cache_key, None) + best_config = json.loads(cache_out) if cache_out else None + + best_config = load_cached_autotuning(best_config, configs_hash, configs) + if best_config: + configs = [best_config] + + def save_cache_hook(cfg, found_by_coordesc=False): + data = json.dumps( + { + **cfg.kwargs, + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + "configs_hash": configs_hash, + "found_by_coordesc": found_by_coordesc, + } + ) + if cache_filename is not None: + with open(cache_filename, "w") as fd: + fd.write(data) + if remote_cache is not None and remote_cache_key is not None: + remote_cache.put(remote_cache_key, data) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + 
log.debug("Save %s tuning result to %s", type_str, cache_filename) + + else: + save_cache_hook = None + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if config.profile_bandwidth: + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=config.profile_bandwidth_regex, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + ) + + return decorator + + +def unique_configs(configs: List[Config]): + """Remove duplicate configurations""" + seen = set() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = config.triton.max_block[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." + ) + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. 
+ """ + # Ideally we want to read this from some device config + + # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK + size_hints = list(reversed(size_hints)) + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + if y: + y = min(y, size_hints[1]) + if z: + z = min(z, size_hints[2]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints[0], config.triton.max_block["X"]) and ( + x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints[1], config.triton.max_block["Y"]) + and ( + y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints[2], config.triton.max_block["Z"]) + and ( + z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + num_warps = next_power_of_2( + min(max(conditional_product(x, y, z) // num_elements_per_warp, 1), 8) + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + num_warps = max(num_warps, 4) if conditional_product(x, y, z) >= 128 else num_warps + xnumel = size_hints[0] + ynumel = size_hints[1] if y else None + znumel = size_hints[2] if z else None + + # Increase x to satisfy min_elem_per_thread requirements. + block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_reduction(size_hints, x, r, num_stages=1, num_warps=None) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + + target = conditional_product(x, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + r = min(r, size_hints[1]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, r) < target: + x *= 2 + while r < size_hints[1] and conditional_product(x, r) < target: + r *= 2 + + cfg = {"XBLOCK": x, "RBLOCK": r} + if num_warps is None: + num_warps = conditional_product(x, r) // 128 + num_warps = next_power_of_2(min(max(num_warps, 2), 8)) + check_config(cfg, xnumel=size_hints[0]) + assert ( + r <= config.triton.max_block["R"] + ), f"increase config.triton.MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. 
Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. + """ + + target = conditional_product(x, y, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + y = min(y, size_hints[1]) + r = min(r, size_hints[2]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, y, r) < target: + x *= 2 + while r < size_hints[2] and conditional_product(x, y, r) < target: + r *= 2 + while y < size_hints[1] and conditional_product(x, y, r) < target: + y *= 2 + + cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r} + num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8)) + check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1]) + assert ( + r <= config.triton.max_block["R"] + ), f"increase config.triton.MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. + """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", set()), size_hints, bs + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + if len(size_hints) == 1: + if disable_pointwise_autotuning() and not ( + config.max_autotune or config.max_autotune_pointwise + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, bs)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + else: + return cached_autotune( + size_hints, + [ + triton_config_with_settings( + size_hints, bs, num_elements_per_warp=256 + ), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + if len(size_hints) == 2: + if (disable_pointwise_autotuning() or tile_hint == TileHint.SQUARE) and not ( + config.max_autotune or config.max_autotune_pointwise + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 32, 32)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return cached_autotune( + size_hints, + [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + if len(size_hints) == 3: + if disable_pointwise_autotuning(): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 16, 16, 16)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return 
cached_autotune( + size_hints, + [ + triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + raise NotImplementedError(f"size_hints: {size_hints}") + + +def _reduction_configs( + *, size_hints: List[int], inductor_meta: Dict[str, Any] +) -> List[Config]: + reduction_hint = inductor_meta.get("reduction_hint", None) + assert len(size_hints) == 2 + rnumel = size_hints[-1] + + contiguous_config = triton_config_reduction( + size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048) + ) + outer_config = triton_config_reduction(size_hints, 64, 8) + tiny_config = triton_config_reduction( + size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048) + ) + if config.max_autotune or config.max_autotune_pointwise: + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(): + return [triton_config_reduction(size_hints, 32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + triton_config_reduction(size_hints, 64, 64), + triton_config_reduction(size_hints, 8, 512), + # halve the XBLOCK/RBLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. 
https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + triton_config_reduction(size_hints, 64, 4, num_warps=8), + ] + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + rnumel = size_hints[-1] + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + xnumel, rnumel = size_hints + + configs = [ + triton_config_reduction(size_hints, xblock, rnumel) + for xblock in (1, 8, 32, 128) + if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel) + ] + + # TODO(jansel): we should be able to improve these heuristics + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = [ + triton_config_reduction( + size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel + ) + ] + for c in configs: + # we don't need RBLOCK for persistent reduction + c.kwargs.pop("RBLOCK") + + if disable_pointwise_autotuning(): + configs = configs[:1] + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + rnumel = size_hints[-1] + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + + # Fixup configs to enforce the minimum RBLOCK size + min_rblock = config.triton.min_split_scan_rblock + for cfg in configs: + if cfg.kwargs["RBLOCK"] < min_rblock: + cfg.kwargs["RBLOCK"] = min_rblock + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None): + """ + Compile a triton template + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, filename=None, 
inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + defaults = inspect.signature(triton.Config).parameters + default_num_stages = defaults["num_stages"].default + default_num_warps = defaults["num_warps"].default + + if len(configs) == 0: + configs = [ + triton.Config( + {}, num_stages=default_num_stages, num_warps=default_num_warps + ) + ] + else: + configs = [ + triton.Config( + c.get("kwargs", {}), + num_stages=c.get("num_stages", default_num_stages), + num_warps=c.get("num_warps", default_num_warps), + ) + for c in configs + ] + + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def grid(*numels): + """Helper function to compute triton grids""" + if len(numels) == 1: + xnumel, ynumel, znumel = numels[0], None, None + elif len(numels) == 2: + xnumel, ynumel, znumel = numels[1], numels[0], None + elif len(numels) == 3: + xnumel, ynumel, znumel = numels[2], numels[1], numels[0] + else: + raise AssertionError(f"invalid size for numels {len(numels)}") + + def get_grid_dim(numel, block): + if numel is None: + return 1 + if block is None: + return numel + return ceildiv(numel, block) + + max_grid_dims = config.triton.max_tiles + + def grid_fn(meta): + x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1)) + y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None)) + + MAX_Y_GRID = get_max_y_grid() + if znumel is None and max_grid_dims <= 2: + div = ceildiv(y_grid, MAX_Y_GRID) + y_grid = y_grid // div + z_grid = div + else: + z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None)) + torch._check( + y_grid <= MAX_Y_GRID, + lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. 
File issue", + ) + + return ( + x_grid, + y_grid, + z_grid, + ) + + return grid_fn + + +def split_scan_grid(xnumel, rnumel): + def grid_fn(meta): + assert meta.get("XBLOCK", 1) == 1 + return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1) + + return grid_fn diff --git a/MLPY/Lib/site-packages/torch/_inductor/utils.py b/MLPY/Lib/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..693d80727bc5dd956a2f280a54ba1df082a439d3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/utils.py @@ -0,0 +1,1428 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import getpass +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import re +import shutil +import sys +import tempfile +import textwrap +import time +import unittest +from dataclasses import fields +from datetime import datetime +from io import StringIO +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Set, + TypeVar, + Union, + ValuesView, +) +from unittest import mock + +import sympy +from typing_extensions import Concatenate, ParamSpec + +import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.autograd import DeviceType +from torch.autograd.profiler_util import EventList +from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing +from . import config + +log = logging.getLogger(__name__) + +_T = TypeVar("_T") +VarRanges = Dict[sympy.Expr, sympy.Expr] + + +def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. + """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for i in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. 
" + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +def do_bench(*args, **kwargs): + @functools.lru_cache(None) + def load_triton(): + try: + # NB: Lazily load triton, as importing triton is slow + # see https://github.com/openai/triton/issues/1599 + from triton.testing import do_bench as triton_do_bench + except ImportError as exc: + raise NotImplementedError("requires Triton") from exc + + # triton PR https://github.com/openai/triton/pull/1513 change the + # quantile fields name from 'percentiles' to 'quantiles' + # and change the default value from (0.5, 0.2, 0.8) to None. + # This may break inductor since a caller expects a tuple may get a item. + # + # Add a wrapper to maintain the same behavior for inductor. + # Maybe we should have own implementation of this function? + return triton_do_bench, ( + "quantiles" + if inspect.signature(triton_do_bench).parameters.get("quantiles") + is not None + else "percentiles" + ) + + triton_do_bench, quantile_field_name = load_triton() + + if quantile_field_name not in kwargs: + kwargs[quantile_field_name] = (0.5, 0.2, 0.8) + return triton_do_bench(*args, **kwargs)[0] + + +@functools.lru_cache(None) +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + + +def conditional_product(*args): + return functools.reduce(operator.mul, [x for x in args if x]) + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type != "cpu" and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it): + return functools.reduce(operator.mul, it, sympy.Integer(1)) + + +def sympy_dot(seq1, seq2): + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(numer, denom) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # --amp --only YituTechConvBert --dynamic-shapes + assert isinstance(numer, int) and isinstance( + denom, int + ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" + return -(numer // -denom) + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def _type_of(key): 
+ # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + for v in list(tys.values()): + tys[v] = v + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]] +) -> List[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. + """ + return [ + i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst + ] + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]] +) -> List[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). + """ + from .virtualized import V + + return [ + i + if isinstance(i, int) + else int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + for i in lst + ] + + +def is_view(op: torch._ops.OpOverload): + """ + Does this op overload have aliasing + """ + assert isinstance(op, torch._ops.OpOverload) + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use(use): + if not use.op == "call_function": + return False + + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + if use.target is operator.getitem or is_view(use.target): + return all(is_pointwise_use(u) for u in use.users) + + return torch.Tag.pointwise in use.target.tags + + +def gen_gm_and_inputs(target, args, kwargs): + g = torch.fx.Graph() + g_args = [] + a_args = [] + for n, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + g_args.append(g.placeholder(f"arg{n}")) + a_args.append(arg) + else: + g_args.append(arg) + assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) + node = g.call_function(target, tuple(g_args), kwargs) + if ( + len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, a_args + + +def synchronize(device: str = "cuda"): + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + fn, args=(), 
times=10, repeat=10, baseline=1.0, device: str = "cuda" +): + timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) + took = torch.median(timings) / times + print(f"{took/baseline:.6f}") + return took + + +def precompute_method(obj: Any, method: str): + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: List[str]): + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a, b) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x, size): + if len(x) == 1: + return type(x)([x[0]]) * size + else: + return x + + +# Used to ensure that iterating over a set is deterministic +def tuple_sorted(x): + if len(x) == 0: + return [] + + def sort_func(elem): + if isinstance(elem, str): + return elem + else: + # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, + # but we are not able to do isinstance assert because of circular dependency + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = ParamSpec("P") +RV = TypeVar("RV", covariant=True) + + +class CachedMethod(Generic[P, RV], Protocol): + @staticmethod + def clear_cache(self) -> None: + ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: + ... + + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + key = f"__{fn.__name__}_cache" + + @functools.wraps(fn) + def wrapper(self): + if not hasattr(self, key): + setattr(self, key, fn(self)) + return getattr(self, key) + + def clear_cache(self): + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def aggregate_origins(node_schedule): + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + set(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return set() + + +def get_fused_kernel_name(node_schedule, descriptive_names): + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(set(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. 
post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) + else: + sources.append(source_fn[1].__name__) + sources = sorted(set(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + sources = sources + return "_".join(["fused"] + sources) + + +def get_kernel_metadata(node_schedule, wrapper): + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + key = str(node.meta["original_aten"]._overloadpacket) + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0][0] + from_node_dict[key].append(node.name) + metadata = ( + f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], " + f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]" + ) + # trace back to original node here + detailed_metadata = [] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], skip_filter=None +) -> Set[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within initial_queue""" + initial_queue = list(initial_queue) + dominated_set = set(initial_queue) + + while initial_queue: + node = initial_queue.pop() + for user in node.users: + if skip_filter and skip_filter(user): + continue + if user not in dominated_set: + dominated_set.add(user) + initial_queue.append(user) + + return dominated_set + + +def gather_origins(args, kwargs): + import itertools + + from . import ir + + def is_unrealized_node(n): + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return set(itertools.chain(*arg_origins, *kwarg_origins)) + + +def sympy_str(expr: sympy.Expr) -> str: + """ + Normal sympy str is very slow, this is a lot faster. The result are + somewhat worse, as it doesn't do as much simplification. So don't + use this for final codegen. + """ + if isinstance(expr, sympy.Symbol): + return expr.name + if isinstance(expr, sympy.Add): + return " + ".join(map(sympy_str, expr.args)) + if isinstance(expr, sympy.Mul): + return " * ".join(map(sympy_str, expr.args)) + + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): + return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" + return str(expr) + + +def sympy_index_symbol(name: str) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. 
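+    # Rough examples, inferred from this function rather than upstream docs:
+    #   sympy_index_symbol("x0") -> Symbol("x0", integer=True, nonnegative=True)
+    #   sympy_index_symbol("s0") would trip the assert below, because "s"-prefixed
+    #   names are reserved for shape/stride symbols allocated before Inductor.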
+ assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) + + +def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. + """ + + def to_symbol(replaced, replacement): + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) # type: ignore[attr-defined] + + +def free_symbol_has(index: sympy.Expr, pattern: str): + return any(pattern in v.name for v in index.free_symbols) # type: ignore[attr-defined] + + +def is_symbolic(a: Any) -> bool: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +def has_incompatible_cudagraph_ops(gm): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + forbidden_set = { + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.multinomial.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + } + if torch.are_deterministic_algorithms_enabled(): + forbidden_set.update( + { + "aten._unsafe_index_put.default", + "aten.index_put.default", + "aten.index_put_.default", + "aten.scatter.src", + "aten.scatter.reduce", + "aten.scatter.value_reduce", + "aten.scatter_add_", + "aten.scatter_add.default", + "aten.scatter_reduce.two", + "aten.scatter_reduce_.two", + "aten.scatter_reduce.two_out", + } + ) + for node in gm.graph.nodes: + if str(node.target) in forbidden_set: + return True + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return True + return False + + +def output_node(gm: torch.fx.GraphModule): + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +# Attempt to import AttrsDescriptor from Triton +try: + from triton.compiler.compiler import AttrsDescriptor + + attrs_descriptor_available = True + # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor + attr_desc_fields = {f.name for f in fields(AttrsDescriptor)} + ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields + divisible_by_8_available = "divisible_by_8" in attr_desc_fields +except ImportError: + attrs_descriptor_available = False + +# Define `instance_descriptor` function with clear 
conditional handling +if attrs_descriptor_available: + + def instance_descriptor( + divisible_by_16=None, + equal_to_1=None, + ids_of_folded_args=None, + divisible_by_8=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor + if ids_of_folded_args_available: + kwargs["ids_of_folded_args"] = ids_of_folded_args + if divisible_by_8_available: + kwargs["divisible_by_8"] = divisible_by_8 + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + instance_descriptor = collections.namedtuple( # type: ignore[no-redef] + "instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[tuple(), tuple(), tuple(), tuple()], + ) + + +@functools.lru_cache(None) +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + cache_dir = os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +@contextlib.contextmanager +def fresh_inductor_cache(cache_entries=None): + """ + Contextmanager that provides a clean tmp cachedir for inductor. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. + """ + with tempfile.TemporaryDirectory() as inductor_cache_dir: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + + +def argsort(seq) -> List[int]: + # preserve original order for equal strides + getter = seq.__getitem__ + a_r = range(len(seq)) + return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent=0): + self._lines = [] + self._indent = initial_indent + + def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: + buf = StringIO() + p = 1 + linemap = [] + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + linemap.append((p, line.context)) + continue + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + line.count("\n") + return buf.getvalue(), linemap + + def getvalue(self) -> str: + v, _ = self.getvaluewithlinemap() + return v + + def getrawvalue(self) -> str: + buf = StringIO() + for line in self._lines: + if isinstance(line, DeferredLineBase): + line = line() + if line is None: + continue + elif isinstance(line, LineContext): + continue + assert isinstance(line, str) + # backslash implies line continuation + if 
line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self): + self._lines.clear() + + def __bool__(self): + return bool(self._lines) + + def prefix(self): + return " " * (self._indent * self.tabwidth) + + def newline(self): + self.writeline("\n") + + def writeline(self, line): + if isinstance(line, LineContext): + self._lines.append(line) + elif isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset=1): + self._indent += offset + + def do_unindent(self, offset=1): + self._indent -= offset + + def splice(self, other_code, strip=False): + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + for line in other_code._lines: + if not isinstance(line, LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = other_code.rstrip() + for line in other_code.split("\n"): + self.writeline(line) + + def __repr__(self): + return f"{type(self)}({self.getvalue()})" + + +class DeferredLineBase: + """A line that can be 'unwritten' at a later time""" + + def __init__(self, line): + if not line.strip(): + line = "" + self.line = line + + def __call__(self) -> Optional[str]: + """Returns either self.line or None to indicate the line has been 'unwritten'""" + raise NotImplementedError() + + def _new_line(self, line: str) -> DeferredLineBase: + """Returns a new deferred line with the same condition""" + raise NotImplementedError() + + def with_prefix(self, prefix): + return self._new_line(f"{prefix}{self.line}") + + def lstrip(self): + return self._new_line(self.line.lstrip()) + + def __getitem__(self, index): + return self._new_line(self.line[index]) + + def __bool__(self): + return bool(self.line) + + def __len__(self): + return len(self.line) + + +@functools.lru_cache(None) +def is_big_gpu(index): + sms = torch.cuda.get_device_properties(index).multi_processor_count + if sms < 80: # V100 + log.warning("not enough SMs to use max_autotune_gemm mode") + return False + return True + + +def use_max_autotune() -> bool: + return ( + config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache + ) + + +def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: + return ( + use_max_autotune() + and layout.device.type == "cuda" + and layout.dtype in allowed_layout_dtypes + and is_big_gpu(layout.device.index or 0) + ) + + +def _use_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") + ] + + +def use_triton_template(layout, *, enable_int32=False): + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + if enable_int32: + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + return _use_template_for_cuda(layout, 
layout_dtypes) and _use_autotune_backend( + "TRITON" + ) + + +def use_cutlass_template(layout): + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( + "CUTLASS" + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." + ) + return False + return res + + +def use_aten_gemm_kernels(): + return not use_max_autotune() or _use_autotune_backend("ATEN") + + +class DebugDirManager: + counter = itertools.count(0) + prev_debug_name: str + + def __init__(self): + self.id = next(DebugDirManager.counter) + + def __enter__(self): + self.prev_debug_name = torch._dynamo.config.debug_dir_root + self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" + torch._dynamo.config.debug_dir_root = self.new_name + + def __exit__(self, *args): + shutil.rmtree(self.new_name) + torch._dynamo.config.debug_dir_root = self.prev_debug_name + + +def run_and_get_code(fn, *args, **kwargs): + from .graph import GraphLowering + + compile_to_module = GraphLowering.compile_to_module + source_codes = [] + + def patched_compile_to_module(self): + mod = compile_to_module(self) + with open(mod.__file__) as f: + source_codes.append(f.read()) + return mod + + # If FX code caching is enabled, a hit prevents getting the code. + with config.patch({"fx_graph_cache": False}): + with mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ): + torch._dynamo.reset() + result = fn(*args, **kwargs) + return result, source_codes + + +def run_and_get_triton_code(fn, *args, **kwargs): + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert ( + 1 <= len(source_codes) <= 2 + ), f"expected one or two code outputs got {len(source_codes)}" + return source_codes[0] + + +@contextlib.contextmanager +def override_lowering(aten_op, override_fn): + """ + Override the lowering of aten_op with override_fn. + The first argument of override_fn is the original lowering fn. + """ + from torch._inductor import lowering + + orig_fn = lowering.lowerings[aten_op] + try: + lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) + yield + finally: + lowering.lowerings[aten_op] = orig_fn + + +def add_scheduler_init_hook(pre_fn, post_fn=None): + """ + Add hook functions to be called at the beginning and end of Scheduler.__init__. + Used for unit tests. + """ + from torch._inductor.scheduler import Scheduler + + orig_fn = Scheduler.__init__ + + def wrapper(scheduler, nodes): + pre_fn(scheduler, nodes) + out = orig_fn(scheduler, nodes) + if post_fn: + post_fn(scheduler, nodes) + return out + + return unittest.mock.patch.object(Scheduler, "__init__", wrapper) + + +def developer_warning(msg): + """ + Warnings that will be actionable for PyTorch developers, but not + end users. Allows us to easily disable them in stable releases but + keep them on for nightly builds. + """ + if config.developer_warnings: + log.warning(msg) + else: + log.info(msg) + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. 
+ + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. + """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True): + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_benchmark_name(): + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. --only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + +def is_ones(items): + return all(x == 1 for x in items) + + +def is_zeros(items): + return all(x == 0 for x in items) + + +def is_cpu_device(inputs): + return all( + item.device == torch.device("cpu") + for item in inputs + if isinstance(item, torch.Tensor) + ) + + +def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: + assert isinstance( + val, sympy.Expr + ), "only support sympy.Expr as input to get_sympy_Expr_dtype" + if val.is_integer: # type: ignore[attr-defined] + return torch.int64 + else: + return torch.float64 + + +@contextlib.contextmanager +def maybe_profile(should_profile, *args, **kwargs): + if should_profile: + with torch.profiler.profile(*args, **kwargs) as p: + yield p + else: + yield + + +def triton_config_to_hashable(cfg): + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. 
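+
+    A rough usage sketch (``cfg`` is assumed to be a ``triton.Config``-like
+    object exposing ``kwargs``, ``num_warps`` and ``num_stages``)::
+
+        key = triton_config_to_hashable(cfg)
+        best_time_per_config = {key: measured_ms}  # usable as a dict key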
+ """ + items = sorted(cfg.kwargs.items()) + items.append(("num_warps", cfg.num_warps)) + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def parallel_num_threads(): + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + +HAS_COLORAMA = True +try: + import colorama +except ImportError: + HAS_COLORAMA = False + + +def _color_text(msg, color): + if not HAS_COLORAMA: + return msg + + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + + +def green_text(msg): + return _color_text(msg, "green") + + +def yellow_text(msg): + return _color_text(msg, "yellow") + + +def red_text(msg): + return _color_text(msg, "red") + + +def blue_text(msg): + return _color_text(msg, "blue") + + +@functools.lru_cache(None) +def get_device_tflops(dtype): + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + + assert dtype in (torch.float16, torch.bfloat16, torch.float32) + + if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): + # Triton API change in https://github.com/openai/triton/pull/2293 + from torch._utils_internal import max_clock_rate + + sm_clock = max_clock_rate() + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype, sm_clock) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32, sm_clock) + else: + return get_max_simd_tflops(torch.float32, sm_clock) + else: + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32) + else: + return get_max_simd_tflops(torch.float32) + + +@functools.lru_cache(None) +def get_gpu_dram_gbps(): + from triton.testing import get_dram_gbps + + return get_dram_gbps() + + +def is_welford_reduction(reduction_type): + return reduction_type.startswith("welford") + + +def reduction_num_outputs(reduction_type): + return 3 if is_welford_reduction(reduction_type) else 1 + + +def get_max_y_grid(): + return 65535 + + +def is_linux() -> bool: + return platform.system() == "Linux" + + +def has_free_symbols(itr: Iterable[Any]): + return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) + + +def is_dynamic(*args): + from . import ir + + for t in args: + if isinstance(t, ir.TensorBox): + if has_free_symbols(t.data.get_size()) or ( + hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) + ): + return True + elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): + assert hasattr(t, "get_size") and hasattr(t, "get_stride") + if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): + return True + elif not isinstance(t, ir.IRNode): + continue + else: + raise TypeError(f"unexpected type for is_dynamic {type(t)}") + + return False + + +# Placeholder strings used in triton codegen. +class Placeholder(enum.Enum): + # The placeholder for the actual name of a triton kernel. + # e.g. for "def triton_" it would be "triton_" + KERNEL_NAME = "KERNEL_NAME" + + # The descriptive name of the triton kernel; when unique_kernel_names = False, this + # placeholder will be replaced with a string with more information. 
+ DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" + + +def pass_execution_and_save(func, gm, msg): + from .pattern_matcher import stable_topological_sort + + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + before_io = io.StringIO() + after_io = io.StringIO() + print(f"Before:\n{gm.graph}", file=f) + print(gm.graph, file=before_io) + start_time = datetime.now() + func(gm.graph) + time_elapsed = datetime.now() - start_time + # recompile graph + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + print(f"After:\n{gm.graph}", file=f) + print(gm.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + log.info( + "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", + msg, + f.name, + t, + time_elapsed, + ) + + +def is_collective(node): + from . import ir + + return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel + + +def is_wait(node): + from . import ir + + return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel + + +def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): + "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" + num_rng_seed_offset_inputs = ( + 2 if torch._functorch.config.functionalize_rng_ops else 0 + ) + return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs + + +def count_tangents(fx_g: torch.fx.GraphModule): + """ + Infers which inputs are static for a backwards graph + """ + + def is_saved_tensor(x): + return ( + "tangents" not in x.name + and "bwd_seed" not in x.name + and "bwd_base_offset" not in x.name + ) + + arg_count = 0 + static_arg_idxs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if is_saved_tensor(n): + static_arg_idxs.append(arg_count) + arg_count += 1 + + assert static_arg_idxs == list(range(len(static_arg_idxs))) + return len(static_arg_idxs) + + +@dataclasses.dataclass +class BoxedBool: + value: bool + + def __bool__(self): + return self.value + + @staticmethod + def disable(obj): + if isinstance(obj, BoxedBool): + obj.value = False + return obj + return False + + +@contextlib.contextmanager +def collect_defined_kernels(kernel_list): + from .codegen.wrapper import WrapperCodeGen + + orig_define_kernel = WrapperCodeGen.define_kernel + + def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): + nonlocal kernel_list + kernel_list.append(kernel_code) + return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) + + with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): + yield diff --git a/MLPY/Lib/site-packages/torch/_inductor/virtualized.py b/MLPY/Lib/site-packages/torch/_inductor/virtualized.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd641b40eea6e1b935ddd7a87d4e4208ab023f8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/virtualized.py @@ -0,0 +1,351 @@ +""" +This file provides a number of "global" variables/handlers that are actually +thread local and dynamically scoped, with Inductor patching them to various +implementations depending on the situation. + +These handlers are interacted with in a fairly stylized way. Typically, +we will import V from this module:: + + from .virtualized import V + +Various handlers are accessible as attributes on this module; for example, +you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with +a number. 
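+
+For instance, a rough sketch (it assumes a graph handler has already been
+installed and that ``expr`` is some sympy expression)::
+
+    from torch._inductor.virtualized import V
+
+    hint = V.graph.sizevars.size_hint(expr)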
+ +There are a few distinct usage patterns for virtualized global variables: + +1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``. + Use ``V.set_current_node`` to change what the current node is while we're + executing some region of code, so code inside that region can query ``V.current_node`` + to find out what it is. This is often more convenient than manually threading + the current node as an argument through all call stacks. + +2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a + given ``compile_fx`` invocation, these typically don't change, but they are + associated with some internal state so they cannot just be global functions. + We install these objects at the beginning of compilation and then you can + conveniently access them without having to pass them around. + +3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``. + A commonly used IR in Inductor is define-by-run: instead of maintaining + explicit syntax data structures, we instead represent loop bodies as + callable functions, which internally invoke operations defined on + ``V.ops``. To perform semantic analysis, print or code generate these + operations, we dynamically patch ``V.ops`` with an alternate handler with + the intended semantics and then run the callable function. For example, to + extract out a traditional (FX) graph representation of the define-by-run + IR, simply install a handler that records each ``ops`` call to a graph. + + TODO: Define a parent class / protocol that defines all of the operations + V.ops is expected to support. + +It is typically an error to access a virtualized global without having installed +an appropriate handler (you will get a NullHandler), although in some cases we +provide a default implementation. + +One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is +ubiquitous enough to have its own top level variable, so you will typically see +``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not +equivalent; the former interface supports arithmetic overloads like ``x + y`` +instead of forcing ``ops.add(x, y)``, so it should be preferred. + +Some operators are seemingly unused, but they are implicitly used by ops_wrapper. +In particular, we typically have an operator for every basic pointwise PyTorch operation +supported. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager, contextmanager +from threading import local +from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union + +from .ops_handler import ( # noqa: F401 + KernelFormatterHandler, + MockHandler, + OpsHandler, + ReductionType, + StoreMode, + WrapperHandler, +) + +if TYPE_CHECKING: + import torch + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.ir import InterpreterShim + from torch._subclasses import FakeTensorMode + +threadlocal = local() + +T = TypeVar("T") + + +class NullHandler: + """ + Sentinel indicating that a global variable is unset ala None. Typically, + attempting to access the global variable before it's set is an error, but with + NullHandler it won't fail until you try to access an attribute on it. + """ + + pass + + +class Virtualized(Generic[T]): + """ + Implements a global variable that redirects via thread local variable + (NB: construct this class to create the global variable; this is not + a singleton class!) 
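+
+    A rough usage sketch (``MyHandler`` is illustrative, not an upstream class)::
+
+        _my_var: Virtualized[MyHandler] = Virtualized("my_var", NullHandler)
+
+        with _my_var._set_handler(MyHandler()):
+            ...  # code in this block sees MyHandler() via _my_var._get_handler()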
+ + This allows us to swap in different op implementations in codegen. + + NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is + the default value of the variable), we sometimes use these variables to + store other things, like booleans. + """ + + def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]): + self._key: str = f"__torchinductor_{vname}" + self._default = default + + def _set_handler(self, value: T) -> AbstractContextManager[None]: + prior = self._get_handler() + setattr(threadlocal, self._key, value) + + @contextmanager + def ctx(): + try: + yield + finally: + self._set_handler(prior) + + return ctx() + + def _get_handler(self) -> T: + try: + return getattr(threadlocal, self._key) + except AttributeError: + # TODO: To be honest, I feel we probably should just error in this + # case, instead of making a null handler that will probably error + # when you getattr on it + return self._default() # type: ignore[return-value] + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_handler(), name) + + +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. + """ + + def __init__(self): + super().__init__() + self.removed_buffers = set() + self.inplaced_to_remove = set() + self.index_dtype = "tl.int64" + + +_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler) +_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) +_kernel: Virtualized[NullKernelHandler] = Virtualized( + "kernel", NullKernelHandler +) # TODO: improve type +_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) +_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) +_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) +_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) + + +class OpsValue: + """The return type of most ops calls. + + This exists so we can overload magic methods, and write mathematical + expressions much more fluently. 
So instead of + + ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1) + + we can write + + (_Ap2 * x - _Ap3) * x * x + _1 + + """ + + value: Any + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"OpsValue({self.value!r})" + + def __add__(self, other): + return ops.add(self, other) + + def __mul__(self, other): + return ops.mul(self, other) + + def __sub__(self, other): + return ops.sub(self, other) + + def __neg__(self): + return ops.neg(self) + + def __truediv__(self, other): + return ops.truediv(self, other) + + def __floordiv__(self, other): + return ops.floordiv(self, other) + + def __mod__(self, other): + return ops.mod(self, other) + + def __pow__(self, other): + return ops.pow(self, other) + + def __lt__(self, other): + return ops.lt(self, other) + + def __le__(self, other): + return ops.le(self, other) + + def __eq__(self, other): + return ops.eq(self, other) + + def __ne__(self, other): + return ops.ne(self, other) + + def __gt__(self, other): + return ops.gt(self, other) + + def __ge__(self, other): + return ops.ge(self, other) + + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __or__(self, other): + return ops.bitwise_or(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) + + def __invert__(self): + return ops.bitwise_not(self) + + def __rshfit__(self, n): + return ops.bitwise_right_shift(self, n) + + def __lshift__(self, n): + return ops.bitwise_left_shift(self, n) + + +class OpsWrapper: + """This wraps any returned IR values into an `OpsValue` instance, so that we + can overload the magic methods for writing mathematical expressions fluently. + """ + + def __getattr__(self, name): + def inner(*args, **kwargs): + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) + + return inner + + @staticmethod + def _unwrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsWrapper._unwrap(v) for v in x) + if isinstance(x, OpsValue): + return x.value + return x + + @staticmethod + def _wrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsValue(v) for v in x) + return OpsValue(x) + + @staticmethod + def indirect_indexing(index, size, check=True): + # Returns a sympy value, not IR value + index = OpsWrapper._unwrap(index) + return _ops.indirect_indexing(index, size, check) + + +ops = OpsWrapper() + + +class _V: + MockHandler = MockHandler + KernelFormatterHandler = KernelFormatterHandler + WrapperHandler = WrapperHandler + + set_ops_handler: Callable[[Any], Any] = _ops._set_handler + get_ops_handler: Callable[[], Any] = _ops._get_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler + get_real_inputs: Callable[[], Any] = _real_inputs._get_handler + set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler + get_fake_mode: Callable[[], Any] = _fake_mode._get_handler + set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler + set_debug_handler: Callable[[Any], Any] = _debug._set_handler + set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler + set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler + get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler + set_current_node: Callable[[Any], Any] = _current_node._set_handler + 
get_current_node: Callable[[], Any] = _current_node._get_handler + + @property + def ops(self) -> OpsHandler[Any]: + """The operator handler specific to the current codegen task""" + return _ops._get_handler() + + @property + def graph(self) -> GraphLowering: + """The graph currently being generated""" + return _graph._get_handler() + + @property + def real_inputs(self): + """non-fake example inputs""" + return _real_inputs._get_handler() + + @property + def fake_mode(self): + """The graph currently being generated""" + return _fake_mode._get_handler() + + @property + def kernel(self): + """The kernel currently being generated""" + return _kernel._get_handler() + + @property + def debug(self): + return _debug._get_handler() + + @property + def interpreter(self): + return _interpreter._get_handler() + + @property + def aot_compilation(self): + return _aot_compilation._get_handler() + + @property + def current_node(self): + return _current_node._get_handler() + + +V = _V() diff --git a/MLPY/Lib/site-packages/torch/_inductor/wrapper_benchmark.py b/MLPY/Lib/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..730bdeaf2b927a225b4867615a7c2c1efc8f1ecd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,299 @@ +import dataclasses +import tempfile +from collections import defaultdict + +import torch +from torch.autograd import DeviceType +from .utils import create_bandwidth_info_str, do_bench, get_num_bytes + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code): + """ + Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod): + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. + """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod): + from torch._inductor.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels(benchmark_name, benchmark_all_configs): + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. 
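+
+    Hypothetical call site (the benchmark name here is made up)::
+
+        benchmark_all_kernels("my_benchmark", benchmark_all_configs=False)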
+ """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_key, kernel_mod in PyCodeCache.cache.items(): + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str(ms, n_regs, n_spills, shared, prefix=""): + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True) + assert ( + len(triton_kernel.launchers) == 1 + ), "Autotuner should have selected the best config" + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclasses.dataclass +class ProfileEvent: + category: str + key: str + self_cuda_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +def parse_profile_event_list(benchmark_name, event_list, wall_time_ms, nruns): + def get_self_cuda_time(ev): + """ + ev.self_cuda_time_total is in microsecond. Convert to millisecond. 
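+
+        For example, with nruns=10 and ev.self_cuda_time_total == 15000 (us),
+        this returns 1.5 (ms per run).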
+ """ + return ev.self_cuda_time_total / 1000 / nruns + + all_events = defaultdict(list) + + def add_event(ev, category): + profile_ev = ProfileEvent( + category=category, + key=ev.key, + self_cuda_time_ms=get_self_cuda_time(ev), + count=ev.count / nruns, # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category, profile_events): + from tabulate import tabulate + + profile_events.sort(key=lambda ev: ev.self_cuda_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_cuda_time_ms + percent = f"{ev.self_cuda_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_cuda_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, headers=["Kernel", "Self CUDA TIME (ms)", "Count", "Percent"] + ) + ) + return total_time + + def report(): + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert set(all_events.keys()).issubset( + set(category_list) + ), f"{list(all_events.keys())}" + + per_category_wall_time = {} + total_cuda_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_cuda_ms += _time + + gpu_busy_percent = f"{total_cuda_ms / wall_time_ms * 100:.2f}%" + print(f"\nPercent of time when GPU is busy: {gpu_busy_percent}") + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! + # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, GPU_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {gpu_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +def compiled_module_main(benchmark_name, benchmark_compiled_module_fn): + """ + This is the function called in __main__ block of a compiled module. 
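+
+    A compiled module typically ends with something along these lines (a
+    sketch only; the exact generated code may differ)::
+
+        if __name__ == "__main__":
+            from torch._inductor.wrapper_benchmark import compiled_module_main
+            compiled_module_main("my_benchmark", benchmark_compiled_module)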
+ """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = 10 + repeat = 10 + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if not args.profile: + return + + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_cuda_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat + ) diff --git a/MLPY/Lib/site-packages/torch/_jit_internal.py b/MLPY/Lib/site-packages/torch/_jit_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..1de2be2ad72b3d235bfe160198a54f2020670d8c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_jit_internal.py @@ -0,0 +1,1510 @@ +""" +The weak_script annotation needs to be here instead of inside torch/jit/ so it +can be used in other places in torch/ (namely torch.nn) without running into +circular dependency problems +""" + +import ast +import builtins +import collections +import contextlib +import enum +import inspect +import io +import pickle +import sys +import threading +import types +import typing +import warnings +import weakref +from textwrap import dedent +from typing import ( # noqa: F401 + Any, + Callable, + Dict, + Final, + ForwardRef, + Generic, + get_args, # new in 3.8 + get_origin, # new in 3.8 + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import torch + +# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. +# Explicitly ask to import `torch.distributed.__init__` first. +# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. +import torch.distributed.rpc +import torch.package._mangling as package_mangling +from torch._awaits import _Await +from torch._C import _Await as CAwait, Future as CFuture +from torch._sources import fake_range, get_source_lines_and_file, parse_def +from torch.futures import Future + +IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9) +IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10) + +BuiltinUnionType: Union[Type, Tuple[Type, ...]] +if sys.version_info >= (3, 10): + # NOTE: IS_PY310_PLUS doesn't work with mypy. + # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks + BuiltinUnionType = types.UnionType +else: + BuiltinUnionType = () # trick: this makes isinstance short circuit. 
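+    # With an empty tuple, isinstance(x, BuiltinUnionType) is simply always
+    # False on Python < 3.10, so callers can avoid a separate version check.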
+ +LockType: Type +try: + import _thread + + LockType = _thread.LockType +except ImportError: + import _dummy_thread # type: ignore[import-not-found] + + LockType = _dummy_thread.LockType + +# Wrapper functions that can call either of 2 functions depending on a boolean +# argument +boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = ( + weakref.WeakKeyDictionary() +) # noqa: T484 + + +FAKE_FILENAME_PREFIX = "__torch_jit_dataclass" + + +class SourceLoader: + def __init__(self): + self.content = {} + + def cache(self, fn, source): + self.content[fn] = source + + def get_source(self, fn): + return self.content.get(fn) + + +loader = SourceLoader() + + +def createResolutionCallbackFromEnv(lookup_base): + """ + Creates a resolution callback that will look up qualified names in an + environment, starting with `lookup_base` for the base of any qualified + names, then proceeding down the lookup chain with the resolved object. + + You should not use this directly, it should only be used from the other + createResolutionCallbackFrom* functions. + """ + + def lookupInModule(qualified_name, module): + if "." in qualified_name: + base, remaining_pieces = qualified_name.split(".", maxsplit=1) + module_value = getattr(module, base) + return lookupInModule(remaining_pieces, module_value) + else: + return getattr(module, qualified_name) + + def parseNestedExpr(expr, module) -> Tuple[Any, int]: + i = 0 + while i < len(expr) and expr[i] not in (",", "[", "]"): + i += 1 + + # Special case logic for the empty Tuple as a subscript (used + # in the type annotation `Tuple[()]`) + if expr[:i] == "()": + return (), i + + base = lookupInModule(expr[:i].strip(), module) + assert base is not None, f"Unresolvable type {expr[:i]}" + if i == len(expr) or expr[i] != "[": + return base, i + + assert expr[i] == "[" + parts = [] + while expr[i] != "]": + part_len = 0 + i += 1 + part, part_len = parseNestedExpr(expr[i:], module) + parts.append(part) + i += part_len + if len(parts) > 1: + return base[tuple(parts)], i + 1 + else: + return base[parts[0]], i + 1 + + def parseExpr(expr, module): + try: + value, len_parsed = parseNestedExpr(expr, module) + assert len_parsed == len( + expr + ), "whole expression was not parsed, falling back to c++ parser" + return value + except Exception: + """ + The python resolver fails in several cases in known unit tests, and is intended + to fall back gracefully to the c++ resolver in general. For example, python 2 style + annotations which are frequent in our unit tests often fail with types e.g. int not + resolvable from the calling frame. + """ + return None + + return lambda expr: parseExpr(expr, lookup_base) + + +def createResolutionCallbackFromFrame(frames_up: int = 0): + """ + Creates a function which, given a string variable name, + returns the value of the variable in the scope of the caller of + the function which called createResolutionCallbackFromFrame (by default). + + This is used to enable access in-scope Python variables inside + TorchScript fragments. + + frames_up is number of additional frames to go up on the stack. + The default value is 0, which correspond to the frame of the caller + of createResolutionCallbackFromFrame. Also for example, if frames_up is set + to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame + will be taken. 
+ + For example, the following program prints 2:: + + def bar(): + cb = createResolutionCallbackFromFrame(1) + print(cb("foo")) + + def baz(): + foo = 2 + bar() + + baz() + """ + frame = inspect.currentframe() + i = 0 + while i < frames_up + 1: + assert frame is not None + frame = frame.f_back + i += 1 + + assert frame is not None + f_locals = frame.f_locals + f_globals = frame.f_globals + + class env: + def __getattr__(self, key): + if key in f_locals: + return f_locals[key] + elif key in f_globals: + return f_globals[key] + elif key in dir(builtins): + return getattr(builtins, key) + + return createResolutionCallbackFromEnv(env()) + + +def get_closure(fn): + """ + Get a dictionary of closed over variables from a function + """ + captures = {} + captures.update(fn.__globals__) + + for index, captured_name in enumerate(fn.__code__.co_freevars): + captures[captured_name] = fn.__closure__[index].cell_contents + + return captures + + +# [local resolution in python] +# Depending on where a variable is defined, and where it is used, we may +# or may not be able to recover its value when recursively compiling a +# script function. Remember in the general case, a module or function is +# first defined and then later scripted. This means we do not have a +# chance to capture the active frames when the function is defined. Hence any +# name resolution has to happen later on the created closure. The way +# python captures type annotations restricts what we can recover. The +# follow example illustrates the different cases: +# +# class MyGlobalClass: +# ... +# def my_local_scope(): +# @torch.jit.script +# class MyClass: +# ... +# @torch.jit.script +# class MyClassUsedAsVar: +# ... +# def eg(x: MyClass, y: MyGlobalClass): +# a_local_capture : Foo +# return MyClassUsedAsVar(x) +# +# MyGlobalClass is defined in the __globals__ dictionary of function +# 'eg', so it is always recoverable. my_local_scope introduces a new local +# variable scope in the function. Classes defined here are only visible as +# local variables. For the case of MyClassUsedAsVar, it is captured +# because it is used as a variable inside the body of the function, and we +# can resolve it using the captures returned from `get_closure`. However, +# the type annotations are not captured by the closure. In Python +# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as +# annotations on `eg``, but starting in Python 4.0, they will represented as +# strings and no longer present. Furthermore, since the body of `eg` does +# not reference those names, they do not appear in the list of closed over +# variables. In Python 2.x, type annotations are in comments, leading to a +# similar situation where their definitions are not available. We anticipate +# that most users will not run into this issue because their modules and +# functions will be defined at a global scope like MyGlobalClass. In cases +# where they are not, it is possible to work around issues by declaring the +# values global in the function. +# In Python 3.9 declaring class as global will make it invisible to +# `inspect.getsource`, see https://bugs.python.org/issue42666 . +# This could be worked around by manualy adding it to `global()` dictionary. 
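# --- Illustrative sketch (not from the vendored file) ------------------------
# get_closure() above merges fn.__globals__ with the cell contents of fn's
# free variables; that is what makes an object that is used as a *value*
# (MyClassUsedAsVar in the note above) recoverable at script time. The helper
# names below are hypothetical:

def _make_scale_example(factor):
    def scale(x):
        return x * factor
    return scale

_scale_example = _make_scale_example(3)
_captures_example = get_closure(_scale_example)
assert _captures_example["factor"] == 3             # closed-over cell content
assert "_make_scale_example" in _captures_example   # pulled in via __globals__
# ------------------------------------------------------------------------------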
+ + +def createResolutionCallbackFromClosure(fn): + """ + Create a resolutionCallback by introspecting the function instead of + looking up the stack for the enclosing scope + """ + closure = get_closure(fn) + + class closure_lookup: + # This is a class since `closure` is a dict and it's easier in + # `env_helper` if everything just works with `getattr` calls + def __getattr__(self, key): + if key in closure: + return closure[key] + elif hasattr(typing, key): + return getattr(typing, key) + elif hasattr(builtins, key): + return getattr(builtins, key) + return None + + return createResolutionCallbackFromEnv(closure_lookup()) + + +def can_compile_class(cls) -> bool: + # If any of the functions on a type don't have a code object, this type can't + # be compiled and is probably a builtin / bound from C + if is_ignored_fn(cls): + return False + + # Ignore the following list of built-in classes. + ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) + if issubclass(cls, ignored_builtin_classes): + return False + + names = cls.__dict__ + fns = [ + getattr(cls, name) + for name in names + if inspect.isroutine(getattr(cls, name, None)) + ] + has_code = [hasattr(fn, "__code__") for fn in fns] + return all(has_code) + + +def get_callable_argument_names(fn) -> List[str]: + """ + Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`. + Returns an empty list when other types of arguments are present. + + This is used by `torch.jit.trace` to assign meaningful argument names to + traced functions and modules. + + Args: + fn: A callable. + Returns: + Argument names: List[str] + """ + # inspect.signature may fail, give up in that case. + try: + callable_signature = inspect.signature(fn) + except Exception: + return [] + + argument_names = [] + for name, param in callable_signature.parameters.items(): + # All four other types of arguments do not map to individual values + # with a keyword as name. + if not param.kind == param.POSITIONAL_OR_KEYWORD: + continue + + argument_names.append(name) + + return argument_names + + +def get_annotation_str(annotation): + """ + Convert an AST node containing a type annotation to the string present in the source + that represents the same annotation. + """ + if isinstance(annotation, ast.Name): + return annotation.id + elif isinstance(annotation, ast.Attribute): + return ".".join([get_annotation_str(annotation.value), annotation.attr]) + elif isinstance(annotation, ast.Subscript): + # In Python3.9+ subscript indicies are not wrapped in ast.Index + subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined] + return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" + elif isinstance(annotation, ast.Tuple): + return ",".join([get_annotation_str(elt) for elt in annotation.elts]) + elif isinstance(annotation, (ast.Constant, ast.NameConstant)): + return f"{annotation.value}" + + # If an AST node is not handled here, it's probably handled in ScriptTypeParser. + return None + + +def get_type_hint_captures(fn): + """ + Get a dictionary containing type resolution mappings necessary to resolve types + for the literal annotations on 'fn'. These are not considered to be closed-over by fn + and must be obtained separately (e.g. using this function). + + Args: + fn: A callable. + Returns: + A Dict[str, Any] containing a mapping from the literal annotations used on + fn to the Python objects they refer to. + """ + # First, try to get the source of the function. 
We'll need to parse it to find the actual string names + # that were used to annotate the types, since inspect.signature() will only return the class object that + # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict. + # This may happen in cases where the function is synthesized dynamically at runtime. + src = loader.get_source(fn) + if src is None: + src = inspect.getsource(fn) + + # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated + # types are strings. These are only understood by TorchScript in the context of a type annotation + # that refers to a class in its own definition, but trying to include a mapping for this in the result + # function would cause infinite recursion because the class is currently being compiled. + # In addition, there is logic in ScriptTypeParser to handle this. + signature = inspect.signature(fn) + name_to_type = { + name: parameter.annotation + for name, parameter in signature.parameters.items() + if parameter.annotation is not inspect.Parameter.empty + and not isinstance(parameter.annotation, str) + } + + # Then, get the literal type annotations from the function declaration + # by source inspection. This accounts for the case in which aliases are used + # to annotate the arguments (e.g device_t = torch.device, and then d: device_t). + # frontend.py cannot be used here because it includes _jit_internal, so use ast instead. + a = ast.parse(dedent(src)) + if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef): + raise RuntimeError(f"Expected {fn} to be a function") + f = a.body[0] + + # Prepare a dictionary of source annotation -> type, which will be the final result of this function, + # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping + # them to the type object corresponding to the annotation via name_to_type using the parameter name. + annotation_to_type = {} + + for arg in f.args.args: + # Get the source type annotation string for this argument if possible. + arg_annotation_str = ( + get_annotation_str(arg.annotation) if arg.annotation else None + ) + + # If the argument has no annotation or get_annotation_str cannot convert it to a string, + # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle + # this in the latter case. + if arg_annotation_str is None: + continue + + # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not + # be present in name_to_type is that the annotation itself is a string and not a type object + # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this. + arg_name = arg.arg + if arg_name in name_to_type: + annotation_to_type[arg_annotation_str] = name_to_type[arg_name] + + # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations, + # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type + # of the annotation cannot be a string. 
+ literal_return_annotation = get_annotation_str(f.returns) + valid_literal_annotation = literal_return_annotation is not None + return_annotation = signature.return_annotation + valid_return_annotation_type = ( + return_annotation is not inspect.Parameter.empty + and not isinstance(return_annotation, str) + ) + if valid_literal_annotation and valid_return_annotation_type: + annotation_to_type[literal_return_annotation] = return_annotation + + return annotation_to_type + + +def createResolutionCallbackForClassMethods(cls): + """ + This looks at all the methods defined in a class and pulls their closed-over + variables into a dictionary and uses that to resolve variables. + """ + # cls is a type here, so `ismethod` is false since the methods on the type + # aren't bound to anything, so Python treats them as regular functions + fns = [ + getattr(cls, name) + for name in cls.__dict__ + if inspect.isroutine(getattr(cls, name)) + ] + # Skip built-ins, as they do not have global scope nor type hints + # Needed to support `enum.Enum` derived classes in Python-3.11 + # That adds `_new_member_` property which is an alias to `__new__` + fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")] + captures = {} + + for fn in fns: + captures.update(get_closure(fn)) + captures.update(get_type_hint_captures(fn)) + + def lookup_in_class(key): + if key in captures: + return captures[key] + else: + return getattr(builtins, key, None) + + return lookup_in_class + + +def boolean_dispatch( + arg_name, arg_index, default, if_true, if_false, module_name, func_name +): + """ + Dispatches to either of 2 script functions based on a boolean argument. + In TorchScript, the boolean argument must be constant so that the correct + function to use can be determined at compile time. + """ + + def fn(*args, **kwargs): + dispatch_flag = default + if arg_name in kwargs: + dispatch_flag = kwargs[arg_name] + elif arg_index < len(args): + dispatch_flag = args[arg_index] + + if dispatch_flag: + return if_true(*args, **kwargs) + else: + return if_false(*args, **kwargs) + + if if_true.__doc__ is None and if_false.__doc__ is not None: + doc = if_false.__doc__ + if_true.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is not None: + doc = if_true.__doc__ + if_false.__doc__ = doc + elif if_false.__doc__ is None and if_true.__doc__ is None: + # neither function has a docstring + doc = None + else: + raise RuntimeError("only one function can have a docstring") + fn.__doc__ = doc + + if module_name is not None: + fn.__module__ = module_name + if func_name is not None: + fn.__name__ = func_name + + boolean_dispatched[fn] = { + "if_true": if_true, + "if_false": if_false, + "index": arg_index, + "default": default, + "arg_name": arg_name, + } + return fn + + +class FunctionModifiers: + """ + Used to denote the behavior of a function in TorchScript. See export() and + ignore() for details. 
+ """ + + UNUSED = "unused (ignored and replaced with raising of an exception)" + IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" + EXPORT = "export (compile this function even if nothing calls it)" + DEFAULT = "default (compile if called from a exported function / forward)" + COPY_TO_SCRIPT_WRAPPER = ( + "if this method is not scripted, copy the python method onto the scripted model" + ) + _DROP = "_drop (function is fully ignored, declaration can be unscriptable)" + + +def export(fn): + """ + This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a + :class:`ScriptModule` and should be compiled. + + ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator. + Functions and methods called from ``forward`` are compiled as they are seen + by the compiler, so they do not need this decorator either. + + Example (using ``@torch.jit.export`` on a method): + + .. testcode:: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + def implicitly_compiled_method(self, x): + return x + 99 + + # `forward` is implicitly decorated with `@torch.jit.export`, + # so adding it here would have no effect + def forward(self, x): + return x + 10 + + @torch.jit.export + def another_forward(self, x): + # When the compiler sees this call, it will compile + # `implicitly_compiled_method` + return self.implicitly_compiled_method(x) + + def unused_method(self, x): + return x - 20 + + # `m` will contain compiled methods: + # `forward` + # `another_forward` + # `implicitly_compiled_method` + # `unused_method` will not be compiled since it was not called from + # any compiled methods and wasn't decorated with `@torch.jit.export` + m = torch.jit.script(MyModule()) + """ + fn._torchscript_modifier = FunctionModifiers.EXPORT + return fn + + +def unused(fn): + """ + This decorator indicates to the compiler that a function or method should + be ignored and replaced with the raising of an exception. This allows you + to leave code in your model that is not yet TorchScript compatible and still + export your model. + + Example (using ``@torch.jit.unused`` on a method):: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + def __init__(self, use_memory_efficient): + super().__init__() + self.use_memory_efficient = use_memory_efficient + + @torch.jit.unused + def memory_efficient(self, x): + import pdb + pdb.set_trace() + return x + 10 + + def forward(self, x): + # Use not-yet-scriptable memory efficient mode + if self.use_memory_efficient: + return self.memory_efficient(x) + else: + return x + 10 + + m = torch.jit.script(MyModule(use_memory_efficient=False)) + m.save("m.pt") + + m = torch.jit.script(MyModule(use_memory_efficient=True)) + # exception raised + m(torch.rand(100)) + """ + if isinstance(fn, property): + prop = fn + setattr( # noqa: B010 + prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + if prop.fset: + setattr( # noqa: B010 + prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED + ) + + return prop + + fn._torchscript_modifier = FunctionModifiers.UNUSED + return fn + + +# No op context manager from python side +class _IgnoreContextManager(contextlib.AbstractContextManager): + def __init__(self, **kwargs): + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + pass + + +def ignore(drop=False, **kwargs): + """ + This decorator indicates to the compiler that a function or method should + be ignored and left as a Python function. 
This allows you to leave code in + your model that is not yet TorchScript compatible. If called from TorchScript, + ignored functions will dispatch the call to the Python interpreter. Models with ignored + functions cannot be exported; use :func:`@torch.jit.unused ` instead. + + Example (using ``@torch.jit.ignore`` on a method):: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + @torch.jit.ignore + def debugger(self, x): + import pdb + pdb.set_trace() + + def forward(self, x): + x += 10 + # The compiler would normally try to compile `debugger`, + # but since it is `@ignore`d, it will be left as a call + # to Python + self.debugger(x) + return x + + m = torch.jit.script(MyModule()) + + # Error! The call `debugger` cannot be saved since it calls into Python + m.save("m.pt") + + Example (using ``@torch.jit.ignore(drop=True)`` on a method): + + .. testcode:: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + @torch.jit.ignore(drop=True) + def training_method(self, x): + import pdb + pdb.set_trace() + + def forward(self, x): + if self.training: + self.training_method(x) + return x + + m = torch.jit.script(MyModule()) + + # This is OK since `training_method` is not saved, the call is replaced + # with a `raise`. + m.save("m.pt") + + .. testcleanup:: + + import os + os.remove('m.pt') + """ + + if callable(drop): + # used without any args, so drop is actually a function + # @torch.jit.ignore + # def fn(...): + fn = drop + fn._torchscript_modifier = FunctionModifiers.IGNORE + return fn + + if not isinstance(drop, bool): + raise RuntimeError( + "Argument to @torch.jit.ignore must be a bool or " + f"a function but got {drop}" + ) + + # for backwards compat + drop_on_export = kwargs.pop("drop_on_export", None) + if drop_on_export: + warnings.warn( + "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) + + drop = drop_on_export + elif drop: + warnings.warn( + "ignore(True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", + category=FutureWarning, + ) + + def decorator(fn): + if drop: + fn._torchscript_modifier = FunctionModifiers.UNUSED + else: + fn._torchscript_modifier = FunctionModifiers.IGNORE + return fn + + return decorator + + +def _drop(fn): + fn._torchscript_modifier = FunctionModifiers._DROP + return fn + + +def _copy_to_script_wrapper(fn): + fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER + return fn + + +def module_has_exports(mod): + for name in dir(mod): + if hasattr(mod, name): + item = getattr(mod, name) + if callable(item): + if get_torchscript_modifier(item) is FunctionModifiers.EXPORT: + return True + return False + + +# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you +# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to +# allow JIT'd code to still be covered. 
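# Illustrative note (not from upstream): export/unused/ignore/_drop above only
# stamp `_torchscript_modifier` onto the function object; the helpers that
# follow (should_drop, is_ignored_fn, get_torchscript_modifier) read that
# stamp back. For a hypothetical function:
#
#     @unused
#     def _not_yet_scriptable(x):
#         return x
#
#     get_torchscript_modifier(_not_yet_scriptable)  # FunctionModifiers.UNUSED
#     should_drop(_not_yet_scriptable)               # True
#     is_ignored_fn(_not_yet_scriptable)             # True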
+def should_drop(fn) -> bool: + attr = get_torchscript_modifier(fn) + if attr is None: + return False + return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP + + +def is_ignored_fn(fn) -> bool: + mod = get_torchscript_modifier(fn) + return ( + mod is FunctionModifiers.UNUSED + or mod is FunctionModifiers.IGNORE + or mod is FunctionModifiers._DROP + ) + + +def _is_drop_fn(fn) -> bool: + mod = get_torchscript_modifier(fn) + return mod is FunctionModifiers._DROP + + +def is_static_fn(cls, fn) -> bool: + return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod) + + +def get_static_fn(cls, fn): + return inspect.getattr_static(cls, fn).__func__ + + +def get_torchscript_modifier(fn): + if not callable(fn): + return None + if hasattr(fn, "__func__"): + fn = fn.__func__ + return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT) + + +def copy_torchscript_modifier(orig, new) -> None: + attr = get_torchscript_modifier(orig) + if attr is None: + return + new._torchscript_modifier = attr + + +# overloading registration +# overloads get registered in this file, and compiled in torch/jit/__init__.py +# so that they can be imported in nn/functional.py without an import cycle + +# qualified_name => list[overload_functions] +_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484 + + +_OVERLOAD_EXAMPLE = """ +Example usage of overload function: +@torch.jit._overload +def my_function(x: type0) -> type0: # decl 1 + pass + +@torch.jit._overload +def my_function(x: type1) -> type1: # decl 2 + pass + +def my_function(x): # implementation + if isinstance(x, type0): + return x + elif isinstance(x, type1): + return x +""" + + +def get_overload_no_implementation_error_message(kind, obj): + sourcelines, file_lineno, filename = get_source_lines_and_file(obj) + return ( + f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make ' + f"sure a definition is provided and defined after all overload declarations.\n" + f'File "{filename}", line {file_lineno}:\n' + + "".join(sourcelines) + + "\n" + + _OVERLOAD_EXAMPLE + ) + + +def _check_overload_body(func): + try: + parsed_def = parse_def(func) + except OSError as e: + # Parsing the function definition can raise an OSError if source is unavailable. + # Since this is just an initial check, just raise a warning if this is the case. + warnings.warn( + f"Unable to retrieve source for @torch.jit._overload function: {func}." 
+ ) + return + + body = parsed_def.ast.body[0].body + + def is_pass(x): + return isinstance(x, ast.Pass) + + def is_ellipsis(x): + return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis) + + if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])): + msg = ( + "Only `pass` statement or `...` can be the body of overload declaration:\n" + ) + msg += "\n".join(parsed_def.source.split("\n")[:3]) + msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE + raise RuntimeError(msg) + + +def _overload(func): + _check_overload_body(func) + qual_name = _qualified_name(func) + global _overloaded_fns + fn_overload_list = _overloaded_fns.get(qual_name) + if fn_overload_list is None: + fn_overload_list = [] + _overloaded_fns[qual_name] = fn_overload_list + fn_overload_list.append(func) + return func + + +def _get_fn_overloads(qual_name): + return _overloaded_fns.get(qual_name) + + +def _clear_fn_overloads(qual_name) -> None: + del _overloaded_fns[qual_name] + + +def get_class_name_lineno(method) -> Tuple[str, int]: + current_frame = inspect.currentframe() + + # one for the get_class_name call, one for _overload_method call + for i in range(2): + assert ( + current_frame is not None + ) # assert current frame is not an Optional[FrameType] + current_frame = current_frame.f_back + + assert current_frame is not None # same here + class_name = current_frame.f_code.co_name + line_no = current_frame.f_code.co_firstlineno + return class_name, line_no + + +# At the point the decorator is applied to class methods the method +# has no reference to its owning class. _qualified_name would not include +# the class it is defined in, so any methods with the same name in the same file +# would have the same _qualified_name, even if they were defined in different +# classes. This problem only exists in python 2. 
+# We get around this problem by looking at the stack frame and identifying +# the class name, and throwing an error whenever overloads are used +# when modules of the same name are in the same file + +# qualified_name => class name => list[overload_functions] +_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484 + + +# (qualified_name, class name) => class_fileno +_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {} + + +def _overload_method(func): + _check_overload_body(func) + qual_name = _qualified_name(func) + global _overloaded_methods + class_name_map = _overloaded_methods.get(qual_name, None) + if class_name_map is None: + class_name_map = {} + _overloaded_methods[qual_name] = class_name_map + + class_name, line_no = get_class_name_lineno(func) + method_overloads = class_name_map.get(class_name, None) + if method_overloads is None: + method_overloads = [] + class_name_map[class_name] = method_overloads + _overloaded_method_class_fileno[(qual_name, class_name)] = line_no + else: + existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)] + if existing_lineno != line_no: + raise RuntimeError( + "Cannot currently overload the same method name in two different" + " classes with the same name in the same module" + ) + + method_overloads.append(func) + return func + + +def _get_overloaded_methods(method, mod_class): + # TODO: __name__ not set for submodules in recursive script + if not hasattr(method, "__name__"): + return None + qual_name = _qualified_name(method) + class_name_map = _overloaded_methods.get(qual_name, None) + if class_name_map is None: + return None + overloads = class_name_map.get(mod_class.__name__, None) + if overloads is None: + return None + + method_line_no = get_source_lines_and_file(method)[1] + mod_class_fileno = get_source_lines_and_file(mod_class)[1] + mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) + if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): + raise Exception( + "Overloads are not useable when a module is redeclared within the same file: " + + str(method) + ) + return overloads + + +def is_tuple(ann) -> bool: + if ann is Tuple: + raise_error_container_parameter_missing("Tuple") + + # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple: + return True + return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple) + + +def is_list(ann) -> bool: + if ann is List: + raise_error_container_parameter_missing("List") + + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list: + return True + return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list) + + +def is_dict(ann) -> bool: + if ann is Dict: + raise_error_container_parameter_missing("Dict") + + if not hasattr(ann, "__module__"): + return False + + ann_origin = get_origin(ann) + if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict: + return True + return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict) + + +def is_union(ann): + if ann is Union: + raise_error_container_parameter_missing("Union") + + return isinstance(ann, BuiltinUnionType) or ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and 
(get_origin(ann) is Union) + ) + + +def is_optional(ann): + if ann is Optional: + raise_error_container_parameter_missing("Optional") + + def is_optional_as_optional(ann): + return ( + hasattr(ann, "__module__") + and ann.__module__ == "typing" + and (get_origin(ann) is Optional) + ) + + def is_union_as_optional(ann): + ann_args = get_args(ann) + return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args) + + return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) + + +def is_future(ann) -> bool: + if ann is Future: + raise RuntimeError( + "Attempted to use Future without a " + "contained type. Please add a contained type, e.g. " + "Future[int]" + ) + return get_origin(ann) is Future + + +def is_await(ann) -> bool: + if ann is _Await: + return True + return get_origin(ann) is _Await + + +if torch.distributed.rpc.is_available(): + from torch._C._distributed_rpc import PyRRef + from torch.distributed.rpc import RRef + + def is_rref(ann) -> bool: + if ann is RRef: + raise RuntimeError( + "Attempted to use RRef without a " + "contained type. Please add a contained type, e.g. " + "RRef[int]" + ) + return get_origin(ann) is RRef + + def is_rref_instance(obj) -> bool: + return isinstance(obj, PyRRef) + +else: + + def is_rref_instance(obj) -> bool: + # If the RPC module doesn't exist then RRefs don't exist either. + return False + + +def is_final(ann) -> bool: + return ( + hasattr(ann, "__module__") + and ann.__module__ in {"typing", "typing_extensions"} + and (get_origin(ann) is Final or isinstance(ann, type(Final))) + ) + + +# allows BroadcastingList instance to be subscriptable +class BroadcastingListCls: + def __getitem__(self, types): + return + + +# mypy doesn't support parameters on types, so we have to explicitly type each +# list size +BroadcastingList1 = BroadcastingListCls() +for i in range(2, 7): + globals()[f"BroadcastingList{i}"] = BroadcastingList1 + + +def is_scripting() -> bool: + r""" + Function that returns True when in compilation and False otherwise. This + is useful especially with the @unused decorator to leave code in your + model that is not yet TorchScript compatible. + .. testcode:: + + import torch + + @torch.jit.unused + def unsupported_linear_op(x): + return x + + def linear(x): + if torch.jit.is_scripting(): + return torch.linear(x) + else: + return unsupported_linear_op(x) + """ + return False + + +# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. +def _qualified_name(obj, mangle_name=True) -> str: + # This special case allows us to override the qualified name on a type. + # It's currently used in conjunction with tracing, where we create a + # fake module to filter only supported attributes. However, since this + # new type is defined as a local class, we need a mechanism to override + # its qualname so it appears correctly in the TorchScript system. This, + # we set '_jit_override_qualname' with the original traced module's + # qualified name, which is picked up here + if hasattr(obj, "_jit_override_qualname"): + return obj._jit_override_qualname + # short-circuit in cases where the object already has a known qualified name + if isinstance(obj, torch._C.ScriptFunction): + return obj.qualified_name + + if getattr(obj, "__name__", None): + name = obj.__name__ + # Enum classes do not have `__name__` attr, instead they have `name`. 
+ elif isinstance(obj, enum.Enum): + name = obj.name + else: + raise RuntimeError("Could not get name of python class object") + + if name == "": + name = "_lambda" # make name a valid identifier + + module_name = obj.__module__ + + # If the module is actually a torchbind module, then we should short circuit + if module_name == "torch._classes": + return obj.qualified_name + + # The Python docs are very clear that `__module__` can be None, but I can't + # figure out when it actually would be. + if module_name is None: + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + "__module__ can't be None." + ) + + # if getattr(sys.modules[module_name], name) is not obj: + # raise RuntimeError(f"Could not get qualified name for class '{name}': " + # f"the attr {name} on module {module_name} is not the class") + + # torch.package and TorchScript have separate mangling schemes to avoid + # name collisions from multiple packages. To avoid them interfering with + # each other, normalize the package manging here. + if package_mangling.is_mangled(module_name): + module_name = module_name.replace("<", "_") + module_name = module_name.replace(">", "_") + + # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h + # does not need mangle the python class name. + if mangle_name: + # __main__ is a builtin module, so rewrite it to "__torch__". + if module_name == "__main__": + module_name = "__torch__" + else: + # Everything else gets a "__torch__" prefix to avoid name collisions + # with the names of user values. + module_name = "__torch__." + module_name + + if "." in name: + raise RuntimeError( + f"Could not get qualified name for class '{name}': " + f"'{name}' is not a valid identifier" + ) + + return module_name + "." + name + + +def _try_get_dispatched_fn(fn): + if not callable(fn): + return None + return boolean_dispatched.get(fn) + + +def _get_named_tuple_properties( + obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None +): + if loc is None: + loc = fake_range() + + assert issubclass(obj, tuple) and hasattr(obj, "_fields") + if hasattr(obj, "_field_defaults"): + defaults = [ + obj._field_defaults[field] + for field in obj._fields + if field in obj._field_defaults + ] + else: + defaults = [] + # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function + # Also, annotations from base class are not inherited so they need to be queried explicitly + if sys.version_info[:2] < (3, 10): + obj_annotations = getattr(obj, "__annotations__", {}) + else: + obj_annotations = inspect.get_annotations(obj) + if len(obj_annotations) == 0 and hasattr(obj, "__base__"): + obj_annotations = inspect.get_annotations(obj.__base__) + + annotations = [] + for field in obj._fields: + if field in obj_annotations: + field_type = obj_annotations[field] + # [Note: ForwardRef annotations in NamedTuple attributes] + # NamedTuple types are slightly different from normal types. + # + # Normally, annotations are evaluted like this (during jit.script): + # 1. Load strings of python code into c++ and parse. + # 2. Get annotations as strings + # 3. Use the PythonResolver's resolution callback (rcb) to convert + # the string into a python object + # 4. We call into annotations.py:ann_to_type to convert python obj + # from step 3 into a type that torchscript understands. + # + # NamedTuples are more complicated, because it has sub-types. 
+ # Normally, once we have the NamedTuple type object from #3, + # we can just look at the annotation literal values and use + # ann_to_type directly on them. + # + # But sometimes, users will annotate with string literals, e.g. + # x: 'int' + # This also happens with PEP563 (from __forward__ import annotations) + # + # These annotations appear in the annotation dict as ForwardRef('int'). + # + # Then, we need to convert the string into a python object. This + # requires having local context for custom objects or imported types. + # rcb() is what gives us this. So, we plumb rcb through the stack so + # it can be used in this context for the if block below. + # + # FAQ: + # - Why do we need this special handling for NamedTuple but string + # annotations work fine for normal types? Normally, we parse the + # string directly and then call rcb() directly from C++. + # - Why not use ForwardRef._evaluate? For that, we need globals() + # and locals() for the local context where the NamedTuple was defined. + # rcb is what lets us look up into these. So, basically rcb does the + # hard work for us. + if isinstance(field_type, ForwardRef) and rcb is not None: + rcb_type = rcb(field_type.__forward_arg__) + # rcb returns None if it can't find anything. + if rcb_type is None: + raise ValueError( + f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}." + f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858." + f" Issue occurred at {loc.highlight()}" + ) + field_type = rcb_type + the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb) + annotations.append(the_type) + else: + annotations.append(torch._C.TensorType.getInferred()) + return type(obj).__name__, obj._fields, annotations, defaults + + +def _create_named_tuple( + t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...] +): + TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] + return TupleType(*t) + + +@contextlib.contextmanager +def _disable_emit_hooks(): + hooks = torch._C._jit_get_emit_hooks() + torch._C._jit_set_emit_hooks(None, None) + try: + yield + finally: + torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) + + +def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811 + def __enter__(self) -> None: + self.hooks = torch._C._jit_get_emit_hooks() + torch._C._jit_set_emit_hooks(None, None) + + def __exit__(self, *args) -> None: + torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) + + +def _is_exception(obj) -> bool: + if not inspect.isclass(obj): + return False + return issubclass(obj, Exception) + + +def raise_error_container_parameter_missing(target_type) -> None: + if target_type == "Dict": + raise RuntimeError( + "Attempted to use Dict without " + "contained types. Please add contained type, e.g. " + "Dict[int, int]" + ) + raise RuntimeError( + f"Attempted to use {target_type} without a " + "contained type. Please add a contained type, e.g. 
" + f"{target_type}[int]" + ) + + +def check_args_exist(target_type) -> None: + if target_type is List or target_type is list: + raise_error_container_parameter_missing("List") + elif target_type is Tuple or target_type is tuple: + raise_error_container_parameter_missing("Tuple") + elif target_type is Dict or target_type is dict: + raise_error_container_parameter_missing("Dict") + elif target_type is None or target_type is Optional: + raise_error_container_parameter_missing("Optional") + + +def check_empty_containers(obj) -> None: + if obj == [] or obj == {} or obj == (): + warnings.warn( + "The inner type of a container is lost when " + "calling torch.jit.isinstance in eager mode. For " + "example, List[int] would become list and " + "therefore falsely return True for List[float] or" + " List[str]." + ) + + +# supports List/Dict/Tuple and Optional types +# TODO support future +def container_checker(obj, target_type) -> bool: + origin_type = get_origin(target_type) + check_args_exist(target_type) + if origin_type is None: + return False + elif origin_type is list or origin_type is List: + check_empty_containers(obj) + if not isinstance(obj, list): + return False + arg_type = get_args(target_type)[0] + arg_origin = get_origin(arg_type) + for el in obj: + # check if nested container, ex: List[List[str]] + if arg_origin: # processes nested container, ex: List[List[str]] + if not container_checker(el, arg_type): + return False + elif not isinstance(el, arg_type): + return False + return True + elif origin_type is Dict or origin_type is dict: + check_empty_containers(obj) + if not isinstance(obj, dict): + return False + key_type = get_args(target_type)[0] + val_type = get_args(target_type)[1] + for key, val in obj.items(): + # check if keys are of right type + if not isinstance(key, key_type): + return False + val_origin = get_origin(val_type) + if val_origin: + if not container_checker(val, val_type): + return False + elif not isinstance(val, val_type): + return False + return True + elif origin_type is Tuple or origin_type is tuple: + check_empty_containers(obj) + if not isinstance(obj, tuple): + return False + arg_types = get_args(target_type) + if len(obj) != len(arg_types): + return False + for el, el_type in zip(obj, arg_types): + el_origin = get_origin(el_type) + if el_origin: + if not container_checker(el, el_type): + return False + elif not isinstance(el, el_type): + return False + return True + elif origin_type is Union or issubclass( + origin_type, BuiltinUnionType + ): # also handles Optional + if obj is None: # check before recursion because None is always fine + return True + inner_types = get_args(target_type) + for t in inner_types: + t_origin = get_origin(t) + if t_origin: + return container_checker(obj, t) + elif isinstance(obj, t): + return True + return False + + +def _isinstance(obj, target_type) -> bool: + if isinstance(target_type, collections.abc.Container): + if not isinstance(target_type, tuple): + raise RuntimeError( + "The second argument to " + "`torch.jit.isinstance` must be a type " + "or a tuple of types" + ) + for t_type in target_type: + if _isinstance(obj, t_type): + return True + return False + + origin_type = get_origin(target_type) + if origin_type: + return container_checker(obj, target_type) + + # Check to handle non-typed optional origin returns as none instead + # of as optional in 3.7-3.8 + check_args_exist(target_type) + + # handle non-containers + return isinstance(obj, target_type) + + +class _TensorExtractor(pickle.Pickler): + def __init__(self, 
*args, tensors: List[torch.Tensor], **kwargs): + super().__init__(*args, **kwargs) + self.tensors = tensors + + def persistent_id(self, obj): + if isinstance(obj, torch.Tensor): + self.tensors.append(obj) + return "" + # Since we just want to extract tensors, we don't mind if an object is + # unpicklable if it doesn't contain tensors, as we can just ignore/skip + # it. To play it safe, we only do so for common objects that we're sure + # don't contain tensors. Feel free to add new types here. Note also that + # even if a type isn't listed here this won't block users, since thet + # can just add a __getstate__ or __reduce__ method to their class. + if isinstance(obj, LockType): + return "" + # Futures and RRefs don't technically contain a value, they just offer + # the means to access a value. + if isinstance(obj, CFuture) or is_rref_instance(obj): + return "" + if isinstance(obj, CAwait): + return "" + if isinstance(obj, torch.cuda.Event): + return "" + if isinstance(obj, threading.Thread): + return "" + return None + + +def _extract_tensors(obj): + r""" + This function is exclusively called from C++. + See ``torch/csrc/jit/python/python_ivalue.h``. + + It extracts the tensors contained in the given object, through pickling. + """ + tensors: List[torch.Tensor] = [] + extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors) + extractor.dump(obj) + return tensors + + +# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass +# that were previously dropped. To preserve the behavior, explicitly drop them there + +if sys.version_info > (3, 10): + _drop(enum.Enum.__new__) + _drop(enum.Enum.__format__) + _drop(enum.Enum.__repr__) + _drop(enum.Enum.__str__) diff --git a/MLPY/Lib/site-packages/torch/_lazy/__init__.py b/MLPY/Lib/site-packages/torch/_lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee827fe4809f28fed7b065d2b96e752342367288 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/__init__.py @@ -0,0 +1,55 @@ +import threading + +import torch._C._lazy +from torch.utils._pytree import tree_flatten, tree_unflatten + +from .closure import add_step_closure, run_step_closures + + +def mark_step(device: str = "", wait=False): + """Triggers a mark step, which amounts to + - collecting a group of 'live' lazy tensors to index into the compilation cache + (lowering/compiling their IR graphs if not cached) + - kicking off execution of the compiled function + - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator) + """ + # TODO(whc) expand this to include backend hooks and align with XLA backend needs + torch._C._lazy._mark_step(device, [], wait=wait) + + run_step_closures() + + +def wait_device_ops(devices=None): + """Waits for all the async operations on the given devices to complete. + Args: + devices (string..., optional): The devices whose async ops need to be waited + for. If empty, all the local devices will be waited for. + """ + if devices is None: + devices = [] + torch._C._lazy._wait_device_ops(devices=devices) + + +def sync_multi(tensors, devices): + """ + Sync the list of lazy tensors so there IR get lowered for the activate backend + and the compiled computation graph get cached. 
+ """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00ec282163633e78fefe472ce9f34e5f0d6ac677 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c480f891805f4a51d555b076142b0261ef8e8f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4423b7f56de2eed1b9e352a0c4729715af5d596d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47d5f6ce4d6f774ac1fd1b743f9c37386bb565af Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de58b0bf9fee6d04e155abc7691cbb8ddfa450e7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82bb35b4c30aa95185176038f6bad51adeff3aee Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99c0033eea7df0940b20f046836e02935a74cde7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f6ccc21144e85d001c1228c05ffe1e2b6bc6354 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc differ diff --git 
a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103f8461985813f5371e3c638ed4ca149094db58 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3809b1575ad2339ab71ef139a6bd8e24c9ff7f7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d0c66bdf88f129f1e6bbc443fb98d8fe2c8112 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_lazy/closure.py b/MLPY/Lib/site-packages/torch/_lazy/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c00f2814692b61eae9f62d5c53085ec7663aa4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/closure.py @@ -0,0 +1,134 @@ +import os +import threading +from queue import Empty as EmptyQueue, Queue + +from torch._lazy.device_context import get_device_context + + +class ClosureHandler: + def __init__(self): + pass + + def run(self, closure): + """Run closure function + + Args: + closure: callable function to run + """ + closure() + + def __call__(self, closures): + for closure in closures: + self.run(closure) + + +class AsyncClosureHandler(ClosureHandler): + """Handler for Asynchronous Step Closures + Args: + max_queue_size: The maximum length of the closure queue after which + the training loop will block until closures are evaluated. + By default, a reasonable limit of a maximum of 100 on the queue. + This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment + variable. 
+ """ + + def __init__(self, max_queue_size=100): + super().__init__() + self._closure_queue: Queue = Queue( + int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size)) + ) + self._closure_exception: Queue = Queue() + self._closure_lock = threading.Lock() + self._closure_event_loop_finished = threading.Event() + self._closure_event_loop = None + + def start_event_loop(self): + """Start closure event loop if not started""" + if self._closure_event_loop is None: + + def event_loop(): + # Run loop until closure event is set and closure queue is empty + while True: + try: + closure = self._closure_queue.get(block=True, timeout=3) + closure() + self._closure_queue.task_done() + except EmptyQueue: + with self._closure_lock: + if self._closure_queue.empty(): + self._closure_event_loop_finished.set() + return + except Exception as e: + self._closure_exception.put(e) + return + + self._closure_event_loop = threading.Thread(target=event_loop) + self._closure_event_loop.start() + + def run(self, closure): + with self._closure_lock: + self._closure_queue.put(closure, block=True) + if ( + self._closure_event_loop is None + or not self._closure_event_loop.is_alive() + ): + try: + e = self._closure_exception.get(block=False) + raise RuntimeError( + "Cannot run asynchronous closure due to previously raised exception" + ) from e + except EmptyQueue: + self._closure_event_loop = None + self.start_event_loop() + + +def add_step_closure(closure, args=(), run_async=False): + """Adds a closure to the list of the ones to be run at the end of the step. + Many times during model training there is the need to print/report (print to + console, post to tensorboard, etc...) information which require the content of + intermediary tensors to be inspected. + Inspecting different tensors content in different points of the model code + requires many executions and typically causes performance issues. + Adding a step closure will ensure that it will be run after the barrier, when + all the live tensors will be already materialized to device data. + Live tensors which will include the ones captured by the closure arguments. + So using `add_step_closure()` will ensure a single execution will be + performed, even when multiple closures are queued, requiring multiple tensors + to be inspected. + Step closures will be run sequentially in the order they have been queued. + Note that even though using this API the execution will be optimized, it is + advised to throttle the printing/reporting events once every N steps. + Args: + closure (callable): The function to be called. + args (tuple): The arguments to be passed to the closure. + run_async: If True, run the closure asynchronously. 
+ """ + devctx = get_device_context() + closures_type = "async_step_closures" if run_async else "step_closures" + step_closures = getattr(devctx, closures_type, None) + if step_closures is None: + step_closures = [] + setattr(devctx, closures_type, step_closures) + step_closures.append(lambda a=args: closure(*a)) + + +def run_step_closures(): + devctx = get_device_context() + async_step_closures = getattr(devctx, "async_step_closures", None) + if async_step_closures is not None: + devctx.async_step_closures = [] + async_closure_handler = getattr(devctx, "async_closure_handler", None) + if async_closure_handler is None: + async_closure_handler = AsyncClosureHandler() + devctx.async_closure_handler = async_closure_handler + async_closure_handler(async_step_closures) + + step_closures = getattr(devctx, "step_closures", None) + if step_closures is not None: + devctx.step_closures = [] + closure_handler = getattr(devctx, "closure_handler", None) + if closure_handler is None: + closure_handler = ClosureHandler() + devctx.closure_handler = closure_handler + closure_handler(step_closures) + return devctx diff --git a/MLPY/Lib/site-packages/torch/_lazy/computation.py b/MLPY/Lib/site-packages/torch/_lazy/computation.py new file mode 100644 index 0000000000000000000000000000000000000000..747e009ab85d5e8ac1048b9a4cd0a7e7a34111f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/computation.py @@ -0,0 +1,26 @@ +import torch._C._lazy +import torch._C._lazy_ts_backend + + +def get_tensors_ts_device_data_node(tensors): + """Return tensor ids and eager tensors for DeviceData nodes in the + IR for the passed in lazy tensors. + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + + +def get_graph_hash(tensors): + """Return the graph hash for the passed in lazy tensors""" + return torch._C._lazy._get_graph_hash(tensors) + + +def run_cached_graph(hash_str, graph_inputs): + """Running the cached computation graph with the given inputs + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/MLPY/Lib/site-packages/torch/_lazy/config.py b/MLPY/Lib/site-packages/torch/_lazy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c204f1cd4aea9ab63039f69557b4aeec58ee0a8d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/config.py @@ -0,0 +1,16 @@ +import torch._C._lazy + + +def get_force_fallback(): + """Get the config used to force LTC fallback""" + return torch._C._lazy._get_force_fallback() + + +def set_force_fallback(configval): + """Set the config used to force LTC fallback""" + torch._C._lazy._set_force_fallback(configval) + + +def set_reuse_ir(val: bool): + """Set the config to reuse IR nodes for faster tracing""" + torch._C._lazy._set_reuse_ir(val) diff --git a/MLPY/Lib/site-packages/torch/_lazy/debug.py b/MLPY/Lib/site-packages/torch/_lazy/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..201674767b8c692352a87b5cd66748270c9d5210 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/debug.py @@ -0,0 +1,21 @@ +import torch._C._lazy + + +def render_ir_graph(tensors): + """Return a text dump of the LTC IR graph in dot format for the tensors. 
+ The text can be processed by tools like dot to be rendered in pdf,png etc.""" + return torch._C._lazy._get_tensors_dot(tensors) + + +def dump_ir(tensors, ir_format): + """Return a dump of the tensors in the specified format. + Valid format are + - text: for LTC IR + - backend: for the activate backend IR + """ + if ir_format == "text": + return torch._C._lazy._get_tensors_text(tensors) + elif ir_format == "backend": + return torch._C._lazy._get_tensors_backend(tensors) + else: + raise RuntimeError(f"Unrecognized IR format: {ir_format}") diff --git a/MLPY/Lib/site-packages/torch/_lazy/device_context.py b/MLPY/Lib/site-packages/torch/_lazy/device_context.py new file mode 100644 index 0000000000000000000000000000000000000000..1332f4e9d7dddefa00369702131977d77a3933db --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/device_context.py @@ -0,0 +1,25 @@ +import threading +from typing import Any, Dict + +import torch._C._lazy + + +class DeviceContext: + _CONTEXTS: Dict[str, Any] = dict() + _CONTEXTS_LOCK = threading.Lock() + + def __init__(self, device): + self.device = device + + +def get_device_context(device=None): + if device is None: + device = torch._C._lazy._get_default_device_type() + else: + device = str(device) + with DeviceContext._CONTEXTS_LOCK: + devctx = DeviceContext._CONTEXTS.get(device, None) + if devctx is None: + devctx = DeviceContext(device) + DeviceContext._CONTEXTS[device] = devctx + return devctx diff --git a/MLPY/Lib/site-packages/torch/_lazy/extract_compiled_graph.py b/MLPY/Lib/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..ecac11f9cb13aca2caf43e4bf25b75ce7a81d25e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,223 @@ +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions + +debug = os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. 
We should get + # the tensor from method arguments + graph_input_ivalues: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. + + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: List[List[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: Dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. + """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). + # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. 
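# Example (sketch): ReturnValueHandler (defined above) re-duplicating a
# deduplicated output list, as in the `return a, a` case from its docstring.
# A plain object stands in for the lazy tensor and a string for the eager one.
from torch._lazy.extract_compiled_graph import ReturnValueHandler

lazy_a = object()                                   # hypothetical lazy tensor
handler = ReturnValueHandler([lazy_a, lazy_a])      # duplicated lazy outputs
eager_outputs = ["eager_a"]                         # compiled graph returns it once
assert handler.duplicate_eager_tensors(eager_outputs) == ["eager_a", "eager_a"]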
+ kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. + """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. + lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/MLPY/Lib/site-packages/torch/_lazy/ir_cache.py b/MLPY/Lib/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..63cf09d13b2345210dfb06c33ac77d0dea5d6296 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,13 @@ +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. 
+ """ + return torch._C._lazy._clear_ir_cache() diff --git a/MLPY/Lib/site-packages/torch/_lazy/metrics.py b/MLPY/Lib/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..662f77fb65d21e5297c4df23b923cb2efa47a655 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,21 @@ +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/MLPY/Lib/site-packages/torch/_lazy/tensor_factory_functions.py b/MLPY/Lib/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..01ffe07b101eb3a8713f8d5d0c6ececf452b696e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,48 @@ +import torch + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. +""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. 
+ # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/MLPY/Lib/site-packages/torch/_lazy/ts_backend.py b/MLPY/Lib/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6ea374121c16284382c258eb5e90f51094061d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,6 @@ +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/MLPY/Lib/site-packages/torch/_library/__init__.py b/MLPY/Lib/site-packages/torch/_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..526349fcaa97c2b8b850a2b33672c8a2eb98b894 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_library/__init__.py @@ -0,0 +1,3 @@ +import torch._library.abstract_impl +import torch._library.simple_registry +import torch._library.utils diff --git a/MLPY/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c0e67703f3c51291bdd255a67f0e5ea84fba83a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_library/__pycache__/abstract_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_library/__pycache__/abstract_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9131026ad7cf59304fd634e10143f665d20588dc Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_library/__pycache__/abstract_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d86083153a2de1f15460cc1cea9cf9c92465ca50 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e1eab717dcf37cddef1290cf9aa4fb43703435e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_library/abstract_impl.py b/MLPY/Lib/site-packages/torch/_library/abstract_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef7cbf9fcafc92085b226dbecaddb4b52a990f0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_library/abstract_impl.py @@ -0,0 +1,206 @@ +import contextlib +import functools +import warnings +from typing import Callable, Optional + +import torch +from torch._library.utils import Kernel, RegistrationHandle + + +class AbstractImplHolder: + """A holder where one can register an abstract impl to.""" + + def __init__(self, qualname: str): + self.qualname: str = qualname + self.kernel: Optional[Kernel] = None + self.lib: Optional[torch.library.Library] = None + + def register(self, func: Callable, source: str) -> RegistrationHandle: + """Register an abstract impl. 
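# Example (sketch): the user-facing flow that ends up in this holder. The
# operator "mylib::add_one" and its kernels are made up; the pattern mirrors
# the impl_abstract example further below. register() also installs the
# abstract impl as the op's Meta kernel, so meta/fake tensors work.
import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("mylib::add_one(Tensor x) -> Tensor")

@torch.library.impl_abstract("mylib::add_one")
def add_one_abstract(x):
    # shape/dtype propagation only; no data access in an abstract impl
    return torch.empty_like(x)

@torch.library.impl(lib, "add_one", "CPU")
def add_one_cpu(x):
    return x + 1

out = torch.ops.mylib.add_one(torch.empty(3, device="meta"))
assert out.shape == (3,) and out.device.type == "meta"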
+ + Returns a RegistrationHandle that one can use to de-register this + abstract impl. + """ + if self.kernel is not None: + raise RuntimeError( + f"impl_abstract(...): the operator {self.qualname} " + f"already has an abstract impl registered at " + f"{self.kernel.source}." + ) + if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): + raise RuntimeError( + f"impl_abstract(...): the operator {self.qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call " + f"impl_abstract." + ) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + self.qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"impl_abstract(...): the operator {self.qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to " + f"DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an abstract " + f"impl; " + f"instead, the operator will decompose into its constituents " + f"and those " + f"can have abstract impls defined on them." + ) + + # Store the kernel in this holder + self.kernel = Kernel(func, source) + + # Also register the abstract impl to Meta key + if self.lib is None: + ns = self.qualname.split("::")[0] + self.lib = torch.library.Library(ns, "FRAGMENT") + meta_kernel = construct_meta_kernel(self.qualname, self) + self.lib.impl(self.qualname, meta_kernel, "Meta") + + def deregister_abstract_impl(): + if self.lib: + self.lib._destroy() + self.lib = None + self.kernel = None + + return RegistrationHandle(deregister_abstract_impl) + + +def construct_meta_kernel( + qualname: str, abstract_impl_holder: AbstractImplHolder +) -> Callable: + assert abstract_impl_holder.kernel is not None + + @functools.wraps(abstract_impl_holder.kernel.func) + def meta_kernel(*args, **kwargs): + assert abstract_impl_holder.kernel is not None + source = abstract_impl_holder.kernel.source + + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname} (implemented at {source})" + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior." + ) + + with set_ctx_getter(error_on_ctx): + return abstract_impl_holder.kernel(*args, **kwargs) + + return meta_kernel + + +def get_none(): + return None + + +global_ctx_getter: Callable = get_none + + +@contextlib.contextmanager +def set_ctx_getter(ctx_getter): + global global_ctx_getter + prev = global_ctx_getter + try: + global_ctx_getter = ctx_getter + yield + finally: + global_ctx_getter = prev + + +class AbstractImplCtx: + """ + Context object for writing abstract implementations for custom operators. + """ + + def __init__(self, _shape_env, _op): + self._shape_env = _shape_env + self._op = _op + + def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: + warnings.warn( + "create_unbacked_symint is deprecated, please use new_dynamic_size instead" + ) + return self.new_dynamic_size(min=min, max=max) + + def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: + """Constructs a new symint (symbolic int) representing a data-dependent value. 
+ + This is useful for writing the abstract implementation (which is necessary + for torch.compile) for a CustomOp where an output Tensor has a size + that depends on the data of the input Tensors. + + Args: + min (int): A statically known inclusive lower bound for this symint. Default: 0 + max (Optional[int]): A statically known inclusive upper bound for this + symint. Default: None + + .. warning: + + It is important that the ``min`` and ``max`` (if not None) values are set + correctly, otherwise, there will be undefined behavior under + torch.compile. The default value of ``min`` is 2 due to torch.compile + specializing on 0/1 sizes. + + You must also verify that your implementation on concrete Tensors + (e.g. CPU/CUDA) only returns Tensors where the size that corresponds + to the symint also has respects these constraint. + The easiest way to do this is to add an assertion in the CPU/CUDA/etc + implementation that the size follows these bounds. + + Example:: + + >>> # An operator with data-dependent output shape + >>> lib = torch.library.Library("mymodule", "FRAGMENT") + >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") + >>> + >>> @torch.library.impl_abstract("mymodule::custom_nonzero") + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> @torch.library.impl(lib, "custom_nonzero", "CPU") + >>> def custom_nonzero_cpu(x): + >>> x_np = x.numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + + """ + if ( + self._shape_env is None + or not self._shape_env.allow_dynamic_output_shape_ops + ): + raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) + + if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): + raise ValueError( + f"ctx.new_dynamic_size(min={min}, max={max}): expected " + f"min and max to be statically known ints but got SymInt. " + f"This is not supported." + ) + + if min < 0: + raise ValueError( + f"ctx.new_dynamic_size(min={min}, ...): expected min to be " + f"greater than or equal to 0: this API can only create " + f"non-negative sizes." + ) + + result = self._shape_env.create_unbacked_symint() + torch.fx.experimental.symbolic_shapes._constrain_range_for_size( + result, min=min, max=max + ) + return result diff --git a/MLPY/Lib/site-packages/torch/_library/simple_registry.py b/MLPY/Lib/site-packages/torch/_library/simple_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6653eed7a2cacb8f3cfba1f7445c8ca65e176f9b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_library/simple_registry.py @@ -0,0 +1,43 @@ +from .abstract_impl import AbstractImplHolder + +__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] + + +class SimpleLibraryRegistry: + """Registry for the "simple" torch.library APIs + + The "simple" torch.library APIs are a higher-level API on top of the + raw PyTorch DispatchKey registration APIs that includes: + - abstract impl + + Registrations for these APIs do not go into the PyTorch dispatcher's + table because they may not directly involve a DispatchKey. 
For example, + the abstract impl is a Python function that gets invoked by FakeTensor. + Instead, we manage them here. + + SimpleLibraryRegistry is a mapping from a fully qualified operator name + (including the overload) to SimpleOperatorEntry. + """ + + def __init__(self): + self._data = {} + + def find(self, qualname: str) -> "SimpleOperatorEntry": + if qualname not in self._data: + self._data[qualname] = SimpleOperatorEntry(qualname) + return self._data[qualname] + + +singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() + + +class SimpleOperatorEntry: + """This is 1:1 to an operator overload. + + The fields of SimpleOperatorEntry are Holders where kernels can be + registered to. + """ + + def __init__(self, qualname: str): + self.qualname: str = qualname + self.abstract_impl: AbstractImplHolder = AbstractImplHolder(qualname) diff --git a/MLPY/Lib/site-packages/torch/_library/utils.py b/MLPY/Lib/site-packages/torch/_library/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7fdd52b67b73c48e575f1d06b26052530f931f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_library/utils.py @@ -0,0 +1,158 @@ +import dataclasses +import inspect +import sys +from typing import Any, Callable, Tuple + +import torch + + +@dataclasses.dataclass +class Kernel: + """Models a (function, source location)""" + + func: Callable + source: str + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class RegistrationHandle: + """Does something when someone calls .destroy() on it""" + + def __init__(self, on_destroy: Callable): + self._on_destroy = on_destroy + + def destroy(self) -> None: + self._on_destroy() + + +def get_source(stacklevel: int) -> str: + """Get a string that represents the caller. + + Example: "/path/to/foo.py:42" + + Use stacklevel=1 to get the caller's source + Use stacklevel=2 to get the caller's caller's source + etc. + """ + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + source = f"{frame.filename}:{frame.lineno}" + return source + + +def parse_namespace(qualname: str) -> Tuple[str, str]: + splits = qualname.split("::") + if len(splits) != 2: + raise ValueError( + f"Expected `qualname` to be of the form " + f'"namespace::name", but got {qualname}. ' + f"The qualname passed to the torch.library APIs must consist " + f"of a namespace and a name, e.g. aten::sin" + ) + return splits[0], splits[1] + + +def lookup_op(qualname: str) -> torch._ops.OpOverloadPacket: + namespace, name = parse_namespace(qualname) + if "." in name: + name, overload = name.split(".") + else: + overload = "default" + ns = getattr(torch.ops, namespace) + packet = getattr(ns, name) + return getattr(packet, overload) + + +def is_builtin(op: torch._ops.OpOverload) -> bool: + assert isinstance(op, torch._ops.OpOverload) + return op.namespace in {"aten", "prim", "prims"} + + +def is_functional_schema(schema: Any) -> bool: + """Check if the schema is functional. 
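# Example (sketch): the lookup helpers above resolving a fully qualified
# operator name.
import torch
from torch._library.utils import is_builtin, lookup_op, parse_namespace

assert parse_namespace("aten::sin") == ("aten", "sin")
op = lookup_op("aten::sin")   # no overload given, so ".default" is used
assert is_builtin(op)         # aten/prim/prims count as builtin namespaces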
+ + An operator is functional if: + - it does not mutate any of its inputs + - it does not return a view on any of its inputs + - it has at least one return + """ + + # Lazy import because not all PyTorch builds have torchgen + from torchgen.model import FunctionSchema, SchemaKind + + assert isinstance(schema, (str, FunctionSchema)) + if isinstance(schema, str): + schema = FunctionSchema.parse(schema) + + if schema.kind() != SchemaKind.functional: + return False + rets = schema.returns + is_non_mutating_view = len(rets) > 0 and any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + return False + if not schema.returns: + return False + return True + + +def mutates_and_returns_first_arg(op: torch._ops.OpOverload): + """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg. + + TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this, + but not all PyTorch builds have torchgen (due to the yaml dependency being weird). + Figure this out. + + Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a) + """ + if op.namespace != "aten": + return False + schema = op._schema + if not len(schema.returns) == 1: + return False + if schema.returns[0].alias_info is None: + return False + alias_set = schema.returns[0].alias_info.after_set + if len(alias_set) != 1: + return False + loc = next(iter(alias_set)) + if len(schema.arguments) < 1: + return False + first_arg = schema.arguments[0] + if first_arg.alias_info is None: + return False + if not first_arg.alias_info.is_write: + return False + alias_set = first_arg.alias_info.after_set + if len(alias_set) != 1: + return False + if loc != next(iter(alias_set)): + return False + for arg in schema.arguments[1:]: + if arg.alias_info is not None: + return False + return True + + +def zip_schema(schema, args, kwargs): + """zips schema.arguments and (args, kwargs) together. + + Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload: + that is, kwargs must be keyword-only arguments and default values may be omitted. + """ + assert len(schema.arguments) >= len(args) + len(kwargs) + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + yield info, kwargs[info.name] + continue + if i >= len(args): + # args that are equal to their default values are not populated + # if they are followed by args that are equal to their defaults. + # Skip these. + continue + yield info, args[i] + return diff --git a/MLPY/Lib/site-packages/torch/_linalg_utils.py b/MLPY/Lib/site-packages/torch/_linalg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6de7fd4f72e0bd3e756cb63cfa25283b50a693c4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_linalg_utils.py @@ -0,0 +1,164 @@ +"""Various linear algebra utility methods for internal use. + +""" + +from typing import Optional, Tuple + +import torch +from torch import Tensor + + +def is_sparse(A): + """Check if tensor A is a sparse tensor""" + if isinstance(A, torch.Tensor): + return A.layout == torch.sparse_coo + + error_str = "expected Tensor" + if not torch.jit.is_scripting(): + error_str += f" but got {type(A)}" + raise TypeError(error_str) + + +def get_floating_dtype(A): + """Return the floating point dtype of tensor A. + + Integer types map to float32. 
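# Example (sketch): the schema predicates above on a functional and an
# in-place schema. is_functional_schema needs torchgen, which ships with
# full PyTorch builds.
import torch
from torch._library.utils import is_functional_schema, mutates_and_returns_first_arg

assert is_functional_schema("sin(Tensor self) -> Tensor")
assert not is_functional_schema("sin_(Tensor(a!) self) -> Tensor(a!)")
assert mutates_and_returns_first_arg(torch.ops.aten.add_.Tensor)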
+ """ + dtype = A.dtype + if dtype in (torch.float16, torch.float32, torch.float64): + return dtype + return torch.float32 + + +def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: + """Multiply two matrices. + + If A is None, return B. A can be sparse or dense. B is always + dense. + """ + if A is None: + return B + if is_sparse(A): + return torch.sparse.mm(A, B) + return torch.matmul(A, B) + + +def conjugate(A): + """Return conjugate of tensor A. + + .. note:: If A's dtype is not complex, A is returned. + """ + if A.is_complex(): + return A.conj() + return A + + +def transpose(A): + """Return transpose of a matrix or batches of matrices.""" + ndim = len(A.shape) + return A.transpose(ndim - 1, ndim - 2) + + +def transjugate(A): + """Return transpose conjugate of a matrix or batches of matrices.""" + return conjugate(transpose(A)) + + +def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: + """Return bilinear form of matrices: :math:`X^T A Y`.""" + return matmul(transpose(X), matmul(A, Y)) + + +def qform(A: Optional[Tensor], S: Tensor): + """Return quadratic form :math:`S^T A S`.""" + return bform(S, A, S) + + +def basis(A): + """Return orthogonal basis of A columns.""" + return torch.linalg.qr(A).Q + + +def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: + """Return eigenpairs of A with specified ordering.""" + if largest is None: + largest = False + E, Z = torch.linalg.eigh(A, UPLO="U") + # assuming that E is ordered + if largest: + E = torch.flip(E, dims=(-1,)) + Z = torch.flip(Z, dims=(-1,)) + return E, Z + + +# These functions were deprecated and removed +# This nice error message can be removed in version 1.13+ +def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed.\n" + "Please use the `torch.linalg.matrix_rank` function instead. " + "The parameter 'symmetric' was renamed in `torch.linalg.matrix_rank()` to 'hermitian'." + ) + + +def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.solve` is deprecated in favor of `torch.linalg.solve`. " + "`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n" + "To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.\n" + "X = torch.solve(B, A).solution " + "should be replaced with:\n" + "X = torch.linalg.solve(A, B)" + ) + + +def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n" + "`torch.linalg.lstsq` has reversed arguments and does not return the QR decomposition in " + "the returned tuple (although it returns other information about the problem).\n\n" + "To get the QR decomposition consider using `torch.linalg.qr`.\n\n" + "The returned solution in `torch.lstsq` stored the residuals of the solution in the " + "last m - n columns of the returned value whenever m > n. 
In torch.linalg.lstsq, " + "the residuals are in the field 'residuals' of the returned named tuple.\n\n" + "The unpacking of the solution, as in\n" + "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n" + "should be replaced with:\n" + "X = torch.linalg.lstsq(A, B).solution" + ) + + +def _symeig( + input, eigenvectors=False, upper=True, *, out=None +) -> Tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "The default behavior has changed from using the upper triangular portion of the matrix by default " + "to using the lower triangular portion.\n\n" + "L, _ = torch.symeig(A, upper=upper) " + "should be replaced with:\n" + "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n" + "and\n\n" + "L, V = torch.symeig(A, eigenvectors=True) " + "should be replaced with:\n" + "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')" + ) + + +def eig( + self: Tensor, eigenvectors: bool = False, *, e=None, v=None +) -> Tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. " + "`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors " + "mimicking complex tensors.\n\n" + "L, _ = torch.eig(A) " + "should be replaced with:\n" + "L_complex = torch.linalg.eigvals(A)\n\n" + "and\n\n" + "L, V = torch.eig(A, eigenvectors=True) " + "should be replaced with:\n" + "L_complex, V_complex = torch.linalg.eig(A)" + ) diff --git a/MLPY/Lib/site-packages/torch/_lobpcg.py b/MLPY/Lib/site-packages/torch/_lobpcg.py new file mode 100644 index 0000000000000000000000000000000000000000..d686337f5e059a51247f9f93cbda9a3f8a9382a5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lobpcg.py @@ -0,0 +1,1167 @@ +"""Locally Optimal Block Preconditioned Conjugate Gradient methods. +""" +# Author: Pearu Peterson +# Created: February 2020 + +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor +from . import _linalg_utils as _utils +from .overrides import handle_torch_function, has_torch_function + + +__all__ = ["lobpcg"] + + +def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U): + # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0 + F = D.unsqueeze(-2) - D.unsqueeze(-1) + F.diagonal(dim1=-2, dim2=-1).fill_(float("inf")) + F.pow_(-1) + + # A.grad = U (D.grad + (U^T U.grad * F)) U^T + Ut = U.mT.contiguous() + res = torch.matmul( + U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut) + ) + + return res + + +def _polynomial_coefficients_given_roots(roots): + """ + Given the `roots` of a polynomial, find the polynomial's coefficients. + + If roots = (r_1, ..., r_n), then the method returns + coefficients (a_0, a_1, ..., a_n (== 1)) so that + p(x) = (x - r_1) * ... * (x - r_n) + = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0 + + Note: for better performance requires writing a low-level kernel + """ + poly_order = roots.shape[-1] + poly_coeffs_shape = list(roots.shape) + # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... 
+ a_1 * x + a_0, + # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)}, + # but we insert one extra coefficient to enable better vectorization below + poly_coeffs_shape[-1] += 2 + poly_coeffs = roots.new_zeros(poly_coeffs_shape) + poly_coeffs[..., 0] = 1 + poly_coeffs[..., -1] = 1 + + # perform the Horner's rule + for i in range(1, poly_order + 1): + # note that it is computationally hard to compute backward for this method, + # because then given the coefficients it would require finding the roots and/or + # calculating the sensitivity based on the Vieta's theorem. + # So the code below tries to circumvent the explicit root finding by series + # of operations on memory copies imitating the Horner's method. + # The memory copies are required to construct nodes in the computational graph + # by exploting the explicit (not in-place, separate node for each step) + # recursion of the Horner's method. + # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. + poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs + out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1) + out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow( + -1, poly_order - i + 1, i + 1 + ) + poly_coeffs = poly_coeffs_new + + return poly_coeffs.narrow(-1, 1, poly_order + 1) + + +def _polynomial_value(poly, x, zero_power, transition): + """ + A generic method for computing poly(x) using the Horner's rule. + + Args: + poly (Tensor): the (possibly batched) 1D Tensor representing + polynomial coefficients such that + poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and + poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n + + x (Tensor): the value (possible batched) to evalate the polynomial `poly` at. + + zero_power (Tensor): the representation of `x^0`. It is application-specific. + + transition (Callable): the function that accepts some intermediate result `int_val`, + the `x` and a specific polynomial coefficient + `poly[..., k]` for some iteration `k`. + It basically performs one iteration of the Horner's rule + defined as `x * int_val + poly[..., k] * zero_power`. + Note that `zero_power` is not a parameter, + because the step `+ poly[..., k] * zero_power` depends on `x`, + whether it is a vector, a matrix, or something else, so this + functionality is delegated to the user. + """ + + res = zero_power.clone() + for k in range(poly.size(-1) - 2, -1, -1): + res = transition(res, x, poly[..., k]) + return res + + +def _matrix_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) matrix input `x`. + Check out `_polynomial_value` function for more details. + """ + + # matrix-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = x.matmul(curr_poly_val) + res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1)) + return res + + if zero_power is None: + zero_power = torch.eye( + x.size(-1), x.size(-1), dtype=x.dtype, device=x.device + ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1)) + + return _polynomial_value(poly, x, zero_power, transition) + + +def _vector_polynomial_value(poly, x, zero_power=None): + """ + Evaluates `poly(x)` for the (batched) vector input `x`. + Check out `_polynomial_value` function for more details. 
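# Example (sketch): sanity-checking the private helpers above. For roots
# (1, 2), (x - 1)(x - 2) = x^2 - 3x + 2, i.e. coefficients (2, -3, 1) in
# ascending order, and Horner evaluation at a matrix A gives A^2 - 3A + 2I.
import torch
from torch._lobpcg import (
    _matrix_polynomial_value,
    _polynomial_coefficients_given_roots,
)

roots = torch.tensor([1.0, 2.0])
coeffs = _polynomial_coefficients_given_roots(roots)
assert torch.allclose(coeffs, torch.tensor([2.0, -3.0, 1.0]))

A = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
pA = _matrix_polynomial_value(coeffs, A)
assert torch.allclose(pA, A @ A - 3 * A + 2 * torch.eye(2))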
+ """ + + # vector-aware Horner's rule iteration + def transition(curr_poly_val, x, poly_coeff): + res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val) + return res + + if zero_power is None: + zero_power = x.new_ones(1).expand(x.shape) + + return _polynomial_value(poly, x, zero_power, transition) + + +def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): + # compute a projection operator onto an orthogonal subspace spanned by the + # columns of U defined as (I - UU^T) + Ut = U.mT.contiguous() + proj_U_ortho = -U.matmul(Ut) + proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1) + + # compute U_ortho, a basis for the orthogonal complement to the span(U), + # by projecting a random [..., m, m - k] matrix onto the subspace spanned + # by the columns of U. + # + # fix generator for determinism + gen = torch.Generator(A.device) + + # orthogonal complement to the span(U) + U_ortho = proj_U_ortho.matmul( + torch.randn( + (*A.shape[:-1], A.size(-1) - D.size(-1)), + dtype=A.dtype, + device=A.device, + generator=gen, + ) + ) + U_ortho_t = U_ortho.mT.contiguous() + + # compute the coefficients of the characteristic polynomial of the tensor D. + # Note that D is diagonal, so the diagonal elements are exactly the roots + # of the characteristic polynomial. + chr_poly_D = _polynomial_coefficients_given_roots(D) + + # the code belows finds the explicit solution to the Sylvester equation + # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U + # and incorporates it into the whole gradient stored in the `res` variable. + # + # Equivalent to the following naive implementation: + # res = A.new_zeros(A.shape) + # p_res = A.new_zeros(*A.shape[:-1], D.size(-1)) + # for k in range(1, chr_poly_D.size(-1)): + # p_res.zero_() + # for i in range(0, k): + # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2) + # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t()) + # + # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity + # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g, + # and we need to compute g(U_grad, A, U, D) + # + # The naive implementation is based on the paper + # Hu, Qingxi, and Daizhan Cheng. + # "The polynomial solution to the Sylvester matrix equation." + # Applied mathematics letters 19.9 (2006): 859-864. + # + # We can modify the computation of `p_res` from above in a more efficient way + # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2) + # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2) + # + ... 
+ # + A.matrix_power(k - 1) U_grad * chr_poly_D[k] + # Note that this saves us from redundant matrix products with A (elimination of matrix_power) + U_grad_projected = U_grad + series_acc = U_grad_projected.new_zeros(U_grad_projected.shape) + for k in range(1, chr_poly_D.size(-1)): + poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D) + series_acc += U_grad_projected * poly_D.unsqueeze(-2) + U_grad_projected = A.matmul(U_grad_projected) + + # compute chr_poly_D(A) which essentially is: + # + # chr_poly_D_at_A = A.new_zeros(A.shape) + # for k in range(chr_poly_D.size(-1)): + # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k) + # + # Note, however, for better performance we use the Horner's rule + chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A) + + # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t + chr_poly_D_at_A_to_U_ortho = torch.matmul( + U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho) + ) + # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its + # Cholesky decomposition and then use `torch.cholesky_solve` for better stability. + # Cholesky decomposition requires the input to be positive-definite. + # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if + # 1. `largest` == False, or + # 2. `largest` == True and `k` is even + # under the assumption that `A` has distinct eigenvalues. + # + # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite + chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1 + chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky( + chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho + ) + + # compute the gradient part in span(U) + res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) + + # incorporate the Sylvester equation solution into the full gradient + # it resides in span(U_ortho) + res -= U_ortho.matmul( + chr_poly_D_at_A_to_U_ortho_sign + * torch.cholesky_solve( + U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L + ) + ).matmul(Ut) + + return res + + +def _symeig_backward(D_grad, U_grad, A, D, U, largest): + # if `U` is square, then the columns of `U` is a complete eigenspace + if U.size(-1) == U.size(-2): + return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U) + else: + return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest) + + +class LOBPCGAutogradFunction(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, + ) -> Tuple[Tensor, Tensor]: + # makes sure that input is contiguous for efficiency. + # Note: autograd does not support dense gradients for sparse input yet. 
+ A = A.contiguous() if (not A.is_sparse) else A + if B is not None: + B = B.contiguous() if (not B.is_sparse) else B + + D, U = _lobpcg( + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + + ctx.save_for_backward(A, B, D, U) + ctx.largest = largest + + return D, U + + @staticmethod + def backward(ctx, D_grad, U_grad): + A_grad = B_grad = None + grads = [None] * 14 + + A, B, D, U = ctx.saved_tensors + largest = ctx.largest + + # lobpcg.backward has some limitations. Checks for unsupported input + if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]): + raise ValueError( + "lobpcg.backward does not support sparse input yet." + "Note that lobpcg.forward does though." + ) + if ( + A.dtype in (torch.complex64, torch.complex128) + or B is not None + and B.dtype in (torch.complex64, torch.complex128) + ): + raise ValueError( + "lobpcg.backward does not support complex input yet." + "Note that lobpcg.forward does though." + ) + if B is not None: + raise ValueError( + "lobpcg.backward does not support backward with B != I yet." + ) + + if largest is None: + largest = True + + # symeig backward + if B is None: + A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest) + + # A has index 0 + grads[0] = A_grad + # B has index 2 + grads[2] = B_grad + return tuple(grads) + + +def lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, +) -> Tuple[Tensor, Tensor]: + """Find the k largest (or smallest) eigenvalues and the corresponding + eigenvectors of a symmetric positive definite generalized + eigenvalue problem using matrix-free LOBPCG methods. + + This function is a front-end to the following LOBPCG algorithms + selectable via `method` argument: + + `method="basic"` - the LOBPCG method introduced by Andrew + Knyazev, see [Knyazev2001]. A less robust method, may fail when + Cholesky is applied to singular input. + + `method="ortho"` - the LOBPCG method with orthogonal basis + selection [StathopoulosEtal2002]. A robust method. + + Supported inputs are dense, sparse, and batches of dense matrices. + + .. note:: In general, the basic method spends least time per + iteration. However, the robust methods converge much faster and + are more stable. So, the usage of the basic method is generally + not recommended but there exist cases where the usage of the + basic method may be preferred. + + .. warning:: The backward method does not support sparse and complex inputs. + It works only when `B` is not provided (i.e. `B == None`). + We are actively working on extensions, and the details of + the algorithms are going to be published promptly. + + .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not. + To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric + in first-order optimization routines, prior to running `lobpcg` + we do the following symmetrization map: `A -> (A + A.t()) / 2`. + The map is performed only when the `A` requires gradients. 
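# Example (sketch): calling the public front-end on a small symmetric
# positive-definite matrix. The assert tolerance is deliberately loose;
# LOBPCG is iterative and the accuracy depends on the stopping criterion
# described below.
import torch

torch.manual_seed(0)
Q = torch.randn(9, 9, dtype=torch.float64)
A = Q @ Q.mT + 9 * torch.eye(9, dtype=torch.float64)   # SPD, m = 9
E, X = torch.lobpcg(A, k=2, largest=True)                # m >= 3 * k holds
ref = torch.flip(torch.linalg.eigvalsh(A)[-2:], dims=(0,))
assert torch.allclose(E, ref, rtol=1e-3, atol=1e-3)
assert X.shape == (9, 2)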
+ + Args: + + A (Tensor): the input tensor of size :math:`(*, m, m)` + + B (Tensor, optional): the input tensor of size :math:`(*, m, + m)`. When not specified, `B` is interpreted as + identity matrix. + + X (tensor, optional): the input tensor of size :math:`(*, m, n)` + where `k <= n <= m`. When specified, it is used as + initial approximation of eigenvectors. X must be a + dense tensor. + + iK (tensor, optional): the input tensor of size :math:`(*, m, + m)`. When specified, it will be used as preconditioner. + + k (integer, optional): the number of requested + eigenpairs. Default is the number of :math:`X` + columns (when specified) or `1`. + + n (integer, optional): if :math:`X` is not specified then `n` + specifies the size of the generated random + approximation of eigenvectors. Default value for `n` + is `k`. If :math:`X` is specified, the value of `n` + (when specified) must be the number of :math:`X` + columns. + + tol (float, optional): residual tolerance for stopping + criterion. Default is `feps ** 0.5` where `feps` is + smallest non-zero floating-point number of the given + input tensor `A` data type. + + largest (bool, optional): when True, solve the eigenproblem for + the largest eigenvalues. Otherwise, solve the + eigenproblem for smallest eigenvalues. Default is + `True`. + + method (str, optional): select LOBPCG method. See the + description of the function above. Default is + "ortho". + + niter (int, optional): maximum number of iterations. When + reached, the iteration process is hard-stopped and + the current approximation of eigenpairs is returned. + For infinite iteration but until convergence criteria + is met, use `-1`. + + tracker (callable, optional) : a function for tracing the + iteration process. When specified, it is called at + each iteration step with LOBPCG instance as an + argument. The LOBPCG instance holds the full state of + the iteration process in the following attributes: + + `iparams`, `fparams`, `bparams` - dictionaries of + integer, float, and boolean valued input + parameters, respectively + + `ivars`, `fvars`, `bvars`, `tvars` - dictionaries + of integer, float, boolean, and Tensor valued + iteration variables, respectively. + + `A`, `B`, `iK` - input Tensor arguments. + + `E`, `X`, `S`, `R` - iteration Tensor variables. + + For instance: + + `ivars["istep"]` - the current iteration step + `X` - the current approximation of eigenvectors + `E` - the current approximation of eigenvalues + `R` - the current residual + `ivars["converged_count"]` - the current number of converged eigenpairs + `tvars["rerr"]` - the current state of convergence criteria + + Note that when `tracker` stores Tensor objects from + the LOBPCG instance, it must make copies of these. + + If `tracker` sets `bvars["force_stop"] = True`, the + iteration process will be hard-stopped. + + ortho_iparams, ortho_fparams, ortho_bparams (dict, optional): + various parameters to LOBPCG algorithm when using + `method="ortho"`. + + Returns: + + E (Tensor): tensor of eigenvalues of size :math:`(*, k)` + + X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)` + + References: + + [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal + Preconditioned Eigensolver: Locally Optimal Block Preconditioned + Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2), + 517-541. (25 pages) + https://epubs.siam.org/doi/abs/10.1137/S1064827500366124 + + [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng + Wu. 
(2002) A Block Orthogonalization Procedure with Constant + Synchronization Requirements. SIAM J. Sci. Comput., 23(6), + 2165-2182. (18 pages) + https://epubs.siam.org/doi/10.1137/S1064827500370883 + + [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming + Gu. (2018) A Robust and Efficient Implementation of LOBPCG. + SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages) + https://epubs.siam.org/doi/abs/10.1137/17M1129830 + + """ + + if not torch.jit.is_scripting(): + tensor_ops = (A, B, X, iK) + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + lobpcg, + tensor_ops, + A, + k=k, + B=B, + X=X, + n=n, + iK=iK, + niter=niter, + tol=tol, + largest=largest, + method=method, + tracker=tracker, + ortho_iparams=ortho_iparams, + ortho_fparams=ortho_fparams, + ortho_bparams=ortho_bparams, + ) + + if not torch._jit_internal.is_scripting(): + if A.requires_grad or (B is not None and B.requires_grad): + # While it is expected that `A` is symmetric, + # the `A_grad` might be not. Therefore we perform the trick below, + # so that `A_grad` becomes symmetric. + # The symmetrization is important for first-order optimization methods, + # so that (A - alpha * A_grad) is still a symmetric matrix. + # Same holds for `B`. + A_sym = (A + A.mT) / 2 + B_sym = (B + B.mT) / 2 if (B is not None) else None + + return LOBPCGAutogradFunction.apply( + A_sym, + k, + B_sym, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + else: + if A.requires_grad or (B is not None and B.requires_grad): + raise RuntimeError( + "Script and require grads is not supported atm." + "If you just want to do the forward, use .detach()" + "on A and B before calling into lobpcg" + ) + + return _lobpcg( + A, + k, + B, + X, + n, + iK, + niter, + tol, + largest, + method, + tracker, + ortho_iparams, + ortho_fparams, + ortho_bparams, + ) + + +def _lobpcg( + A: Tensor, + k: Optional[int] = None, + B: Optional[Tensor] = None, + X: Optional[Tensor] = None, + n: Optional[int] = None, + iK: Optional[Tensor] = None, + niter: Optional[int] = None, + tol: Optional[float] = None, + largest: Optional[bool] = None, + method: Optional[str] = None, + tracker: None = None, + ortho_iparams: Optional[Dict[str, int]] = None, + ortho_fparams: Optional[Dict[str, float]] = None, + ortho_bparams: Optional[Dict[str, bool]] = None, +) -> Tuple[Tensor, Tensor]: + # A must be square: + assert A.shape[-2] == A.shape[-1], A.shape + if B is not None: + # A and B must have the same shapes: + assert A.shape == B.shape, (A.shape, B.shape) + + dtype = _utils.get_floating_dtype(A) + device = A.device + if tol is None: + feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype] + tol = feps**0.5 + + m = A.shape[-1] + k = (1 if X is None else X.shape[-1]) if k is None else k + n = (k if n is None else n) if X is None else X.shape[-1] + + if m < 3 * n: + raise ValueError( + f"LPBPCG algorithm is not applicable when the number of A rows (={m})" + f" is smaller than 3 x the number of requested eigenpairs (={n})" + ) + + method = "ortho" if method is None else method + + iparams = { + "m": m, + "n": n, + "k": k, + "niter": 1000 if niter is None else niter, + } + + fparams = { + "tol": tol, + } + + bparams = {"largest": True if largest is None else largest} + + if method == "ortho": + if ortho_iparams is not None: + iparams.update(ortho_iparams) + if ortho_fparams is not None: + fparams.update(ortho_fparams) + if 
ortho_bparams is not None: + bparams.update(ortho_bparams) + iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3) + iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3) + fparams["ortho_tol"] = fparams.get("ortho_tol", tol) + fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol) + fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol) + bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False) + + if not torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[method-assign] + + if len(A.shape) > 2: + N = int(torch.prod(torch.tensor(A.shape[:-2]))) + bA = A.reshape((N,) + A.shape[-2:]) + bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None + bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None + bE = torch.empty((N, k), dtype=dtype, device=device) + bXret = torch.empty((N, m, k), dtype=dtype, device=device) + + for i in range(N): + A_ = bA[i] + B_ = bB[i] if bB is not None else None + X_ = ( + torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i] + ) + assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n)) + iparams["batch_index"] = i + worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker) + worker.run() + bE[i] = worker.E[:k] + bXret[i] = worker.X[:, :k] + + if not torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] + + return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k)) + + X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X + assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n)) + + worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker) + + worker.run() + + if not torch.jit.is_scripting(): + LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[method-assign] + + return worker.E[:k], worker.X[:, :k] + + +class LOBPCG: + """Worker class of LOBPCG methods.""" + + def __init__( + self, + A: Optional[Tensor], + B: Optional[Tensor], + X: Tensor, + iK: Optional[Tensor], + iparams: Dict[str, int], + fparams: Dict[str, float], + bparams: Dict[str, bool], + method: str, + tracker: None, + ) -> None: + # constant parameters + self.A = A + self.B = B + self.iK = iK + self.iparams = iparams + self.fparams = fparams + self.bparams = bparams + self.method = method + self.tracker = tracker + m = iparams["m"] + n = iparams["n"] + + # variable parameters + self.X = X + self.E = torch.zeros((n,), dtype=X.dtype, device=X.device) + self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device) + self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device) + self.tvars: Dict[str, Tensor] = {} + self.ivars: Dict[str, int] = {"istep": 0} + self.fvars: Dict[str, float] = {"_": 0.0} + self.bvars: Dict[str, bool] = {"_": False} + + def __str__(self): + lines = ["LOPBCG:"] + lines += [f" iparams={self.iparams}"] + lines += [f" fparams={self.fparams}"] + lines += [f" bparams={self.bparams}"] + lines += [f" ivars={self.ivars}"] + lines += [f" fvars={self.fvars}"] + lines += [f" bvars={self.bvars}"] + lines += [f" tvars={self.tvars}"] + lines += [f" A={self.A}"] + lines += [f" B={self.B}"] + lines += [f" iK={self.iK}"] + lines += [f" X={self.X}"] + lines += [f" E={self.E}"] + r = "" + for line in lines: + r += line + "\n" + return r + + def update(self): + """Set and update iteration variables.""" + if self.ivars["istep"] == 0: + X_norm = float(torch.norm(self.X)) + iX_norm = X_norm**-1 + A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) 
* iX_norm + B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm + self.fvars["X_norm"] = X_norm + self.fvars["A_norm"] = A_norm + self.fvars["B_norm"] = B_norm + self.ivars["iterations_left"] = self.iparams["niter"] + self.ivars["converged_count"] = 0 + self.ivars["converged_end"] = 0 + + if self.method == "ortho": + self._update_ortho() + else: + self._update_basic() + + self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1 + self.ivars["istep"] = self.ivars["istep"] + 1 + + def update_residual(self): + """Update residual R from A, B, X, E.""" + mm = _utils.matmul + self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E + + def update_converged_count(self): + """Determine the number of converged eigenpairs using backward stable + convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018]. + + Users may redefine this method for custom convergence criteria. + """ + # (...) -> int + prev_count = self.ivars["converged_count"] + tol = self.fparams["tol"] + A_norm = self.fvars["A_norm"] + B_norm = self.fvars["B_norm"] + E, X, R = self.E, self.X, self.R + rerr = ( + torch.norm(R, 2, (0,)) + * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1 + ) + converged = rerr < tol + count = 0 + for b in converged: + if not b: + # ignore convergence of following pairs to ensure + # strict ordering of eigenpairs + break + count += 1 + assert ( + count >= prev_count + ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease" + self.ivars["converged_count"] = count + self.tvars["rerr"] = rerr + return count + + def stop_iteration(self): + """Return True to stop iterations. + + Note that tracker (if defined) can force-stop iterations by + setting ``worker.bvars['force_stop'] = True``. + """ + return ( + self.bvars.get("force_stop", False) + or self.ivars["iterations_left"] == 0 + or self.ivars["converged_count"] >= self.iparams["k"] + ) + + def run(self): + """Run LOBPCG iterations. + + Use this method as a template for implementing LOBPCG + iteration scheme with custom tracker that is compatible with + TorchScript. + """ + self.update() + + if not torch.jit.is_scripting() and self.tracker is not None: + self.call_tracker() + + while not self.stop_iteration(): + self.update() + + if not torch.jit.is_scripting() and self.tracker is not None: + self.call_tracker() + + @torch.jit.unused + def call_tracker(self): + """Interface for tracking iteration process in Python mode. + + Tracking the iteration process is disabled in TorchScript + mode. In fact, one should specify tracker=None when JIT + compiling functions using lobpcg. + """ + # do nothing when in TorchScript mode + pass + + # Internal methods + + def _update_basic(self): + """ + Update or initialize iteration variables when `method == "basic"`. 
+ """ + mm = torch.matmul + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] + + if self.ivars["istep"] == 0: + Ri = self._get_rayleigh_ritz_transform(self.X) + M = _utils.qform(_utils.qform(self.A, self.X), Ri) + E, Z = _utils.symeig(M, largest) + self.X[:] = mm(self.X, mm(Ri, Z)) + self.E[:] = E + np = 0 + self.update_residual() + nc = self.update_converged_count() + self.S[..., :n] = self.X + + W = _utils.matmul(self.iK, self.R) + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + else: + S_ = self.S[:, nc:ns] + Ri = self._get_rayleigh_ritz_transform(S_) + M = _utils.qform(_utils.qform(self.A, S_), Ri) + E_, Z = _utils.symeig(M, largest) + self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc])) + self.E[nc:] = E_[: n - nc] + P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc])) + np = P.shape[-1] + + self.update_residual() + nc = self.update_converged_count() + self.S[..., :n] = self.X + self.S[:, n : n + np] = P + W = _utils.matmul(self.iK, self.R[:, nc:]) + + self.ivars["converged_end"] = ns = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + def _update_ortho(self): + """ + Update or initialize iteration variables when `method == "ortho"`. + """ + mm = torch.matmul + ns = self.ivars["converged_end"] + nc = self.ivars["converged_count"] + n = self.iparams["n"] + largest = self.bparams["largest"] + + if self.ivars["istep"] == 0: + Ri = self._get_rayleigh_ritz_transform(self.X) + M = _utils.qform(_utils.qform(self.A, self.X), Ri) + E, Z = _utils.symeig(M, largest) + self.X = mm(self.X, mm(Ri, Z)) + self.update_residual() + np = 0 + nc = self.update_converged_count() + self.S[:, :n] = self.X + W = self._get_ortho(self.R, self.X) + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + else: + S_ = self.S[:, nc:ns] + # Rayleigh-Ritz procedure + E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest) + + # Update E, X, P + self.X[:, nc:] = mm(S_, Z[:, : n - nc]) + self.E[nc:] = E_[: n - nc] + P = mm( + S_, + mm( + Z[:, n - nc :], + _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])), + ), + ) + np = P.shape[-1] + + # check convergence + self.update_residual() + nc = self.update_converged_count() + + # update S + self.S[:, :n] = self.X + self.S[:, n : n + np] = P + W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np]) + ns = self.ivars["converged_end"] = n + np + W.shape[-1] + self.S[:, n + np : ns] = W + + def _get_rayleigh_ritz_transform(self, S): + """Return a transformation matrix that is used in Rayleigh-Ritz + procedure for reducing a general eigenvalue problem :math:`(S^TAS) + C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T + S^TAS Ri) Z = Z E` where `C = Ri Z`. + + .. 
note:: In the original Rayleight-Ritz procedure in + [DuerschEtal2018], the problem is formulated as follows:: + + SAS = S^T A S + SBS = S^T B S + D = () ** -1/2 + R^T R = Cholesky(D SBS D) + Ri = D R^-1 + solve symeig problem Ri^T SAS Ri Z = Theta Z + C = Ri Z + + To reduce the number of matrix products (denoted by empty + space between matrices), here we introduce element-wise + products (denoted by symbol `*`) so that the Rayleight-Ritz + procedure becomes:: + + SAS = S^T A S + SBS = S^T B S + d = () ** -1/2 # this is 1-d column vector + dd = d d^T # this is 2-d matrix + R^T R = Cholesky(dd * SBS) + Ri = R^-1 * d # broadcasting + solve symeig problem Ri^T SAS Ri Z = Theta Z + C = Ri Z + + where `dd` is 2-d matrix that replaces matrix products `D M + D` with one element-wise product `M * dd`; and `d` replaces + matrix product `D M` with element-wise product `M * + d`. Also, creating the diagonal matrix `D` is avoided. + + Args: + S (Tensor): the matrix basis for the search subspace, size is + :math:`(m, n)`. + + Returns: + Ri (tensor): upper-triangular transformation matrix of size + :math:`(n, n)`. + + """ + B = self.B + mm = torch.matmul + SBS = _utils.qform(B, S) + d_row = SBS.diagonal(0, -2, -1) ** -0.5 + d_col = d_row.reshape(d_row.shape[0], 1) + # TODO use torch.linalg.cholesky_solve once it is implemented + R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True) + return torch.linalg.solve_triangular( + R, d_row.diag_embed(), upper=True, left=False + ) + + def _get_svqb( + self, U: Tensor, drop: bool, tau: float # Tensor # bool # float + ) -> Tensor: + """Return B-orthonormal U. + + .. note:: When `drop` is `False` then `svqb` is based on the + Algorithm 4 from [DuerschPhD2015] that is a slight + modification of the corresponding algorithm + introduced in [StathopolousWu2002]. + + Args: + + U (Tensor) : initial approximation, size is (m, n) + drop (bool) : when True, drop columns that + contribution to the `span([U])` is small. + tau (float) : positive tolerance + + Returns: + + U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size + is (m, n1), where `n1 = n` if `drop` is `False, + otherwise `n1 <= n`. + + """ + if torch.numel(U) == 0: + return U + UBU = _utils.qform(self.B, U) + d = UBU.diagonal(0, -2, -1) + + # Detect and drop exact zero columns from U. While the test + # `abs(d) == 0` is unlikely to be True for random data, it is + # possible to construct input data to lobpcg where it will be + # True leading to a failure (notice the `d ** -0.5` operation + # in the original algorithm). To prevent the failure, we drop + # the exact zero columns here and then continue with the + # original algorithm below. + nz = torch.where(abs(d) != 0.0) + assert len(nz) == 1, nz + if len(nz[0]) < len(d): + U = U[:, nz[0]] + if torch.numel(U) == 0: + return U + UBU = _utils.qform(self.B, U) + d = UBU.diagonal(0, -2, -1) + nz = torch.where(abs(d) != 0.0) + assert len(nz[0]) == len(d) + + # The original algorithm 4 from [DuerschPhD2015]. + d_col = (d**-0.5).reshape(d.shape[0], 1) + DUBUD = (UBU * d_col) * _utils.transpose(d_col) + E, Z = _utils.symeig(DUBUD) + t = tau * abs(E).max() + if drop: + keep = torch.where(E > t) + assert len(keep) == 1, keep + E = E[keep[0]] + Z = Z[:, keep[0]] + d_col = d_col[keep[0]] + else: + E[(torch.where(E < t))[0]] = t + + return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5) + + def _get_ortho(self, U, V): + """Return B-orthonormal U with columns are B-orthogonal to V. + + .. 
note:: When `bparams["ortho_use_drop"] == False` then + `_get_ortho` is based on the Algorithm 3 from + [DuerschPhD2015] that is a slight modification of + the corresponding algorithm introduced in + [StathopolousWu2002]. Otherwise, the method + implements Algorithm 6 from [DuerschPhD2015] + + .. note:: If all U columns are B-collinear to V then the + returned tensor U will be empty. + + Args: + + U (Tensor) : initial approximation, size is (m, n) + V (Tensor) : B-orthogonal external basis, size is (m, k) + + Returns: + + U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`) + such that :math:`V^T B U=0`, size is (m, n1), + where `n1 = n` if `drop` is `False, otherwise + `n1 <= n`. + """ + mm = torch.matmul + mm_B = _utils.matmul + m = self.iparams["m"] + tau_ortho = self.fparams["ortho_tol"] + tau_drop = self.fparams["ortho_tol_drop"] + tau_replace = self.fparams["ortho_tol_replace"] + i_max = self.iparams["ortho_i_max"] + j_max = self.iparams["ortho_j_max"] + # when use_drop==True, enable dropping U columns that have + # small contribution to the `span([U, V])`. + use_drop = self.bparams["ortho_use_drop"] + + # clean up variables from the previous call + for vkey in list(self.fvars.keys()): + if vkey.startswith("ortho_") and vkey.endswith("_rerr"): + self.fvars.pop(vkey) + self.ivars.pop("ortho_i", 0) + self.ivars.pop("ortho_j", 0) + + BV_norm = torch.norm(mm_B(self.B, V)) + BU = mm_B(self.B, U) + VBU = mm(_utils.transpose(V), BU) + i = j = 0 + stats = "" + for i in range(i_max): + U = U - mm(V, VBU) + drop = False + tau_svqb = tau_drop + for j in range(j_max): + if use_drop: + U = self._get_svqb(U, drop, tau_svqb) + drop = True + tau_svqb = tau_replace + else: + U = self._get_svqb(U, False, tau_replace) + if torch.numel(U) == 0: + # all initial U columns are B-collinear to V + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j + return U + BU = mm_B(self.B, U) + UBU = mm(_utils.transpose(U), BU) + U_norm = torch.norm(U) + BU_norm = torch.norm(BU) + R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype) + R_norm = torch.norm(R) + # https://github.com/pytorch/pytorch/issues/33810 workaround: + rerr = float(R_norm) * float(BU_norm * U_norm) ** -1 + vkey = f"ortho_UBUmI_rerr[{i}, {j}]" + self.fvars[vkey] = rerr + if rerr < tau_ortho: + break + VBU = mm(_utils.transpose(V), BU) + VBU_norm = torch.norm(VBU) + U_norm = torch.norm(U) + rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1 + vkey = f"ortho_VBU_rerr[{i}]" + self.fvars[vkey] = rerr + if rerr < tau_ortho: + break + if m < U.shape[-1] + V.shape[-1]: + # TorchScript needs the class var to be assigned to a local to + # do optional type refinement + B = self.B + assert B is not None + raise ValueError( + "Overdetermined shape of U:" + f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold" + ) + self.ivars["ortho_i"] = i + self.ivars["ortho_j"] = j + return U + + +# Calling tracker is separated from LOBPCG definitions because +# TorchScript does not support user-defined callback arguments: +LOBPCG_call_tracker_orig = LOBPCG.call_tracker + + +def LOBPCG_call_tracker(self): + self.tracker(self) diff --git a/MLPY/Lib/site-packages/torch/_logging/__init__.py b/MLPY/Lib/site-packages/torch/_logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a388324056a62d7b70e75c7284c0ae3f79c06a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_logging/__init__.py @@ -0,0 +1,16 @@ +# Top level logging module for torch logging +# Design doc: 
https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# +# Simple setup for onboarding (see above doc for more detail): +# 1. register any top-level log qualified name for your module in torch._logging._registrations (see there for examples) +# 2. register any artifacts ( below) in torch._logging._registrations +# a. call getArtifactLogger(__name__, ) at your logging site instead of the standard logger to log your artifact +import torch._logging._registrations +from ._internal import ( + _init_logs, + DEFAULT_LOGGING, + getArtifactLogger, + LazyString, + set_logs, + trace_structured, + warning_once, +) diff --git a/MLPY/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4825548a3f3211f4faa64d0c2a0b473488f7ad23 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ad80d979c054f1c8f5968320d6078162b28ec02 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7480f6096e877f0df667072af729ce7e195d6660 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e1e97fc6e2fff82e9dca353fcf68ed8ac33a99 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_logging/_internal.py b/MLPY/Lib/site-packages/torch/_logging/_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..4fad1c394cf704f0147c7e43927cec5b8c4e29b1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_logging/_internal.py @@ -0,0 +1,1085 @@ +import functools +import hashlib +import itertools +import json +import logging +import os +import os.path +import re +import tempfile +from dataclasses import dataclass, field +from importlib import __import__ +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from weakref import WeakSet + +log = logging.getLogger(__name__) + +# This is a synthetic logger which doesn't correspond to an actual logger, +# but handles all of our "tracing" logging, which is structured and doesn't go +# to stderr but always goes to a dedicated log file. We don't put these +# loggers in the classic module hierarchy, because we don't want a suppression +# of logs to also cause a trace to get suppressed (traces typically are not +# collected, unless we are in prod, in which case they always are collected.) +# +# TODO: Maybe we should allow for some sub-hierarchy so you can control which +# traces you want to collect, for performance reasons. 
+# +# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit +trace_log = logging.getLogger("torch.__trace") + +DEFAULT_LOG_LEVEL = logging.WARNING +LOG_ENV_VAR = "TORCH_LOGS" +LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT" +LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT" +TRACE_ENV_VAR = "TORCH_TRACE" + + +@dataclass +class LogRegistry: + # shorthand name to log qualified name + # Note: this only contains loggers registered + # from register_log + # e.g. "dynamo" -> "torch._dynamo" + log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict) + + # artifact logger qualified names, + # this is populated lazily, as calls to getArtifactLogger + # currently formatted as .__ + # e.g. "torch._dynamo.convert_frame.__guards" + artifact_log_qnames: Set[str] = field(default_factory=set) + + # child logs of registered logs if specified via open + # registration by the user (ie placing "torch._dynamo.output_graph" in the env var) + # these need to be tracked so their levels can be reset properly + # e.g. "torch._dynamo.output_graph" + child_log_qnames: Set[str] = field(default_factory=set) + + # artifact names, populated by register_artifact + # e.g. "guards" + artifact_names: Set[str] = field(default_factory=set) + + # Artifacts that should be visible by default in the error message + visible_artifacts: Set[str] = field(default_factory=set) + + # A short description of each artifact + artifact_descriptions: Dict[str, str] = field(default_factory=dict) + + # artifacts which are not displayed unless explicitly named in the + # settings. Ex. output_code is NOT displayed even if the inductor + # log level is set to DEBUG. It must be explicitly named in the settings + off_by_default_artifact_names: Set[str] = field(default_factory=set) + + # logging format string for artifacts + artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict) + + def is_artifact(self, name): + return name in self.artifact_names + + def is_log(self, alias): + return alias in self.log_alias_to_log_qnames + + # register a log with an alias + def register_log(self, alias, log_qnames: Union[str, List[str]]): + if isinstance(log_qnames, str): + log_qnames = [log_qnames] + self.log_alias_to_log_qnames[alias] = log_qnames + + # register an artifact name + def register_artifact_name( + self, name, description, visible, off_by_default, log_format + ): + self.artifact_names.add(name) + if visible: + self.visible_artifacts.add(name) + self.artifact_descriptions[name] = description + + # if off by default, don't enable it + # when log_name's log_level is set to DEBUG + if off_by_default: + self.off_by_default_artifact_names.add(name) + + if log_format is not None: + self.artifact_log_formatters[name] = logging.Formatter(log_format) + + # register the qualified name of an artifact log + # this is needed to know which logs need to be reset + # whenever the log_state is changed + def register_artifact_log(self, artifact_log_qname): + self.artifact_log_qnames.add(artifact_log_qname) + + def register_child_log(self, log_qname): + self.child_log_qnames.add(log_qname) + + # flattens all the qnames together (TODO: consider memoizing?) 
+ def get_log_qnames(self) -> Set[str]: + return { + qname + for qnames in self.log_alias_to_log_qnames.values() + for qname in qnames + } + + def get_artifact_log_qnames(self): + return set(self.artifact_log_qnames) + + def get_child_log_qnames(self): + return set(self.child_log_qnames) + + def is_off_by_default(self, artifact_qname): + return artifact_qname in self.off_by_default_artifact_names + + +@dataclass +class LogState: + # qualified log names -> currently set log level + log_qname_to_level: Dict[str, str] = field(default_factory=dict) + + # the set of currently enabled artifacts + artifact_names: Set[str] = field(default_factory=set) + + def enable_artifact(self, artifact_name): + self.artifact_names.add(artifact_name) + + def is_artifact_enabled(self, name): + return name in self.artifact_names + + def enable_log(self, log_qnames, log_level): + if isinstance(log_qnames, str): + log_qnames = [log_qnames] + for log_qname in log_qnames: + self.log_qname_to_level[log_qname] = log_level + + def get_log_level_pairs(self): + """Returns all qualified module names for which the user requested + explicit logging settings. + + .. warning: + + This function used to return all loggers, regardless of whether + or not the user specified them or not; it now only returns logs + which were explicitly mentioned by the user (and torch, which + always is implicitly requested when we initialize our logging + subsystem.) + """ + return self.log_qname_to_level.items() + + def clear(self): + self.log_qname_to_level.clear() + self.artifact_names.clear() + + +log_registry = LogRegistry() +log_state = LogState() + +# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING) +DEFAULT_LOGGING = { + "dynamo": logging.DEBUG, + "aot": logging.DEBUG, + "inductor": logging.DEBUG, + "ddp_graphs": True, + "graph_breaks": True, + "guards": True, + "recompiles": True, + "dynamic": logging.INFO, +} + + +def set_logs( + *, + all: Optional[int] = None, + dynamo: Optional[int] = None, + aot: Optional[int] = None, + autograd: Optional[int] = None, + dynamic: Optional[int] = None, + inductor: Optional[int] = None, + distributed: Optional[int] = None, + dist_c10d: Optional[int] = None, + dist_ddp: Optional[int] = None, + dist_fsdp: Optional[int] = None, + onnx: Optional[int] = None, + bytecode: bool = False, + aot_graphs: bool = False, + aot_joint_graph: bool = False, + ddp_graphs: bool = False, + graph: bool = False, + graph_code: bool = False, + graph_breaks: bool = False, + graph_sizes: bool = False, + guards: bool = False, + recompiles: bool = False, + recompiles_verbose: bool = False, + trace_source: bool = False, + trace_call: bool = False, + output_code: bool = False, + schedule: bool = False, + perf_hints: bool = False, + post_grad_graphs: bool = False, + onnx_diagnostics: bool = False, + fusion: bool = False, + overlap: bool = False, + export: Optional[int] = None, + modules: Optional[Dict[str, Union[int, bool]]] = None, + cudagraphs: bool = False, + sym_node: bool = False, +): + """ + Sets the log level for individual components and toggles individual log + artifact types. + + .. warning:: This feature is a prototype and may have compatibility + breaking changes in the future. + + .. note:: The ``TORCH_LOGS`` environment variable has complete precedence + over this function, so if it was set, this function does nothing. + + A component is a set of related features in PyTorch. All of the log + messages emitted from a given component have their own log levels. 
If the + log level of a particular message has priority greater than or equal to its + component's log level setting, it is emitted. Otherwise, it is suppressed. + This allows you to, for instance, silence large groups of log messages that + are not relevant to you and increase verbosity of logs for components that + are relevant. The expected log level values, ordered from highest to lowest + priority, are: + + * ``logging.CRITICAL`` + * ``logging.ERROR`` + * ``logging.WARNING`` + * ``logging.INFO`` + * ``logging.DEBUG`` + * ``logging.NOTSET`` + + See documentation for the Python ``logging`` module for more information on + log levels: ``_ + + An artifact is a particular type of log message. Each artifact is assigned + to a parent component. A component can emit many different kinds of + artifacts. In general, an artifact is emitted if either its corresponding + setting in the argument list below is turned on or if its parent component + is set to a log level less than or equal to the log level of the artifact. + + Keyword args: + all (:class:`Optional[int]`): + The default log level for all components. Default: ``logging.WARN`` + + dynamo (:class:`Optional[int]`): + The log level for the TorchDynamo component. Default: ``logging.WARN`` + + aot (:class:`Optional[int]`): + The log level for the AOTAutograd component. Default: ``logging.WARN`` + + autograd (:class:`Optional[int]`): + The log level for autograd. Default: ``logging.WARN`` + + inductor (:class:`Optional[int]`): + The log level for the TorchInductor component. Default: ``logging.WARN`` + + dynamic (:class:`Optional[int]`): + The log level for dynamic shapes. Default: ``logging.WARN`` + + distributed (:class:`Optional[int]`): + Whether to log c10d communication operations and other debug info from PyTorch Distributed components. + Default: ``logging.WARN`` + + dist_c10d (:class:`Optional[int]`): + Whether to log c10d communication operations related debug info in PyTorch Distributed components. + Default: ``logging.WARN`` + + dist_ddp (:class:`Optional[int]`): + Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components. + Default: ``logging.WARN`` + + dist_fsdp (:class:`Optional[int]`): + Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components. + Default: ``logging.WARN`` + + onnx (:class:`Optional[int]`): + The log level for the ONNX exporter component. Default: ``logging.WARN`` + + bytecode (:class:`bool`): + Whether to emit the original and generated bytecode from TorchDynamo. + Default: ``False`` + + aot_graphs (:class:`bool`): + Whether to emit the graphs generated by AOTAutograd. Default: ``False`` + + aot_joint_graph (:class:`bool`): + Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` + + inductor (:class:`Optional[int]`): + Whether to log information from inductor cudagraphs. Default: ``logging.WARN`` + + ddp_graphs (:class:`bool`): + Whether to emit graphs generated by DDPOptimizer. Default: ``False`` + + graph (:class:`bool`): + Whether to emit the graph captured by TorchDynamo in tabular format. + Default: ``False`` + + graph_code (:class:`bool`): + Whether to emit the python source of the graph captured by TorchDynamo. + Default: ``False`` + + graph_breaks (:class:`bool`): + Whether to emit the graph breaks encountered by TorchDynamo. + Default: ``False`` + + graph_sizes (:class:`bool`): + Whether to emit tensor sizes of the graph captured by TorchDynamo. 
+ Default: ``False`` + + guards (:class:`bool`): + Whether to emit the guards generated by TorchDynamo for each compiled + function. Default: ``False`` + + recompiles (:class:`bool`): + Whether to emit a guard failure reason and message every time + TorchDynamo recompiles a function. Default: ``False`` + + recompiles_verbose (:class:`bool`): + Whether to emit all guard failure reasons when TorchDynamo recompiles + a function, even those that are not actually run. Default: ``False`` + + trace_source (:class:`bool`): + Whether to emit when TorchDynamo begins tracing a new line. Default: ``False`` + + trace_call (:class:`bool`): + Whether to emit detailed line location when TorchDynamo creates an FX node + corresponding to function call. Python 3.11+ only. Default: ``False`` + + output_code (:class:`bool`): + Whether to emit the TorchInductor output code. Default: ``False`` + + schedule (:class:`bool`): + Whether to emit the TorchInductor schedule. Default: ``False`` + + perf_hints (:class:`bool`): + Whether to emit the TorchInductor perf hints. Default: ``False`` + + post_grad_graphs (:class:`bool`): + Whether to emit the graphs generated by after post grad passes. Default: ``False`` + + onnx_diagnostics (:class:`bool`): + Whether to emit the ONNX exporter diagnostics in logging. Default: ``False`` + + fusion (:class:`bool`): + Whether to emit detailed Inductor fusion decisions. Default: ``False`` + + overlap (:class:`bool`): + Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False`` + + sym_node (:class:`bool`): + Whether to emit debug info for various SymNode opterations. Default: ``False`` + + export (:class:`Optional[int]`): + The log level for export. Default: ``logging.WARN`` + + modules (dict): + This argument provides an alternate way to specify the above log + component and artifact settings, in the format of a keyword args + dictionary given as a single argument. There are two cases + where this is useful (1) if a new log component or artifact has + been registered but a keyword argument for it has not been added + to this function and (2) if the log level for an unregistered module + needs to be set. This can be done by providing the fully-qualified module + name as the key, with the log level as the value. Default: ``None`` + + + Example:: + + >>> # xdoctest: +SKIP + >>> import logging + + # The following changes the "dynamo" component to emit DEBUG-level + # logs, and to emit "graph_code" artifacts. 
+ + >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True) + + # The following enables the logs for a different module + + >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG}) + """ + # ignore if env var is set + if LOG_ENV_VAR in os.environ: + log.warning( + "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs" + ) + return + + log_state.clear() + + modules = modules or {} + + def _set_logs(**kwargs): + for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr] + if val is None: + continue + + if log_registry.is_artifact(alias): + if not isinstance(val, bool): + raise ValueError( + f"Expected bool to enable artifact {alias}, received {val}" + ) + + if val: + log_state.enable_artifact(alias) + elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames: + if val not in logging._levelToName: + raise ValueError( + f"Unrecognized log level for log {alias}: {val}, valid level values " + f"are: {','.join([str(k) for k in logging._levelToName.keys()])}" + ) + + log_state.enable_log( + log_registry.log_alias_to_log_qnames.get(alias, alias), val + ) + else: + raise ValueError( + f"Unrecognized log or artifact name passed to set_logs: {alias}" + ) + + _init_logs() + + _set_logs( + torch=all, + dynamo=dynamo, + aot=aot, + autograd=autograd, + inductor=inductor, + dynamic=dynamic, + bytecode=bytecode, + aot_graphs=aot_graphs, + aot_joint_graph=aot_joint_graph, + ddp_graphs=ddp_graphs, + distributed=distributed, + dist_c10d=dist_c10d, + dist_ddp=dist_ddp, + dist_fsdp=dist_fsdp, + graph=graph, + graph_code=graph_code, + graph_breaks=graph_breaks, + graph_sizes=graph_sizes, + guards=guards, + recompiles=recompiles, + recompiles_verbose=recompiles_verbose, + trace_source=trace_source, + trace_call=trace_call, + output_code=output_code, + schedule=schedule, + perf_hints=perf_hints, + post_grad_graphs=post_grad_graphs, + onnx=onnx, + onnx_diagnostics=onnx_diagnostics, + fusion=fusion, + overlap=overlap, + sym_node=sym_node, + export=export, + cudagraphs=cudagraphs, + ) + + +def get_loggers(): + """ + Returns: a list of all registered loggers + """ + return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()] + + +def register_log(setting_name, log_name): + """ + Enables a log to be controlled by the env var and user API with the setting_name + Args: + setting_name: the shorthand name used in the env var and user API + log_name: the log name that the setting_name is associated with + """ + log_registry.register_log(setting_name, log_name) + + +def register_artifact( + setting_name, description, visible=False, off_by_default=False, log_format=None +): + """ + Enables an artifact to be controlled by the env var and user API with name + Args: + setting_name: the shorthand name used in the env var and user API + description: A description of what this outputs + visible: Whether it gets suggested to users by default + off_by_default: whether this artifact should be logged when the ancestor loggers + are enabled at level DEBUG + """ + log_registry.register_artifact_name( + setting_name, description, visible, off_by_default, log_format + ) + + +def getArtifactLogger(module_qname, artifact_name): + if artifact_name not in log_registry.artifact_names: + raise ValueError( + f"Artifact name: {repr(artifact_name)} not registered," + f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations." 
+ ) + qname = module_qname + f".__{artifact_name}" + log = logging.getLogger(qname) + log.artifact_name = artifact_name # type: ignore[attr-defined] + log_registry.register_artifact_log(qname) + configure_artifact_log(log) + return log + + +INCR_VERBOSITY_CHAR = "+" +DECR_VERBOSITY_CHAR = "-" +VERBOSITY_REGEX = ( + "(" + + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)]) + + "?)" +) + + +def configure_artifact_log(log): + # If the artifact is off by default, then it should only be logged when explicitly + # enabled; set propagate to False so that this artifact is not propagated + # to its ancestor logger + if log_registry.is_off_by_default(log.artifact_name): + log.propagate = False + + # enable artifact logging when explicitly enabled + if log_state.is_artifact_enabled(log.artifact_name): + log.setLevel(logging.DEBUG) + log.propagate = True + + +# match a comma separated list of loggable names (whitespace allowed after commas) +def _gen_settings_regex(): + return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?") + + +def _validate_settings(settings): + return re.fullmatch(_gen_settings_regex(), settings) is not None + + +def help_message(verbose=False): + def pad_to(s, length=30): + assert len(s) <= length + return s + " " * (length - len(s)) + + if verbose: + printed_artifacts = log_registry.artifact_names + else: + printed_artifacts = log_registry.visible_artifacts + + if verbose: + heading = "All registered names" + else: + heading = "Visible registered names (use TORCH_LOGS='+help' for full list)" + lines = ( + ["all"] + + sorted(log_registry.log_alias_to_log_qnames.keys()) + + sorted( + [ + f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}" + for name in printed_artifacts + ] + ) + ) + setting_info = " " + "\n ".join(lines) + examples = """ +Examples: + TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to + logging.DEBUG and AOT to logging.INFO + + TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to + logging.ERROR and TorchInductor to logging.DEBUG + + TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact + + TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo + to logging.DEBUG and enable the schedule artifact + + TORCH_LOGS="+some.random.module,schedule" will set the log level of + some.random.module to logging.DEBUG and enable the schedule artifact + + TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format + string will set the output format + Valid keys are "levelname", "message", "pathname", "levelno", "lineno", + "filename" and "name". + + TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as + well. This is useful when the output is long. +""" # flake8: noqa: B950 + msg = f""" +TORCH_LOGS Info +{examples} + +{heading} +{setting_info} +""" + return msg + + +def _invalid_settings_err_msg(settings, verbose=False): + valid_settings = ", ".join( + ["all"] + + list(log_registry.log_alias_to_log_qnames.keys()) + + list(log_registry.artifact_names) + ) + msg = f""" +Invalid log settings: {settings}, must be a comma separated list of fully +qualified module names, registered log names or registered artifact names. 
+For more info on various settings, try TORCH_LOGS="help" +Valid settings: +{valid_settings} +""" + return msg + + +@functools.lru_cache +def _parse_log_settings(settings): + if settings == "": + return dict() + + if settings == "help": + raise ValueError(help_message(verbose=False)) + elif settings == "+help": + raise ValueError(help_message(verbose=True)) + if not _validate_settings(settings): + raise ValueError(_invalid_settings_err_msg(settings)) + + settings = re.sub(r"\s+", "", settings) + log_names = settings.split(",") + + def get_name_level_pair(name): + clean_name = name.replace(INCR_VERBOSITY_CHAR, "") + clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "") + + if name[0] == INCR_VERBOSITY_CHAR: + level = logging.DEBUG + elif name[0] == DECR_VERBOSITY_CHAR: + level = logging.ERROR + else: + level = logging.INFO + + return clean_name, level + + log_state = LogState() + + for name in log_names: + name, level = get_name_level_pair(name) + + if name == "all": + name = "torch" + + if log_registry.is_log(name): + assert level is not None + log_qnames = log_registry.log_alias_to_log_qnames[name] + log_state.enable_log(log_qnames, level) + elif log_registry.is_artifact(name): + log_state.enable_artifact(name) + elif _is_valid_module(name): + if not _has_registered_parent(name): + log_registry.register_log(name, name) + else: + log_registry.register_child_log(name) + log_state.enable_log(name, level) + else: + raise ValueError(_invalid_settings_err_msg(settings)) + + return log_state + + +def _is_valid_module(qname): + try: + __import__(qname) + return True + except ImportError: + return False + + +def _update_log_state_from_env(): + global log_state + log_setting = os.environ.get(LOG_ENV_VAR, None) + if log_setting is not None: + log_state = _parse_log_settings(log_setting) + + +def _has_registered_parent(log_qname): + cur_log = logging.getLogger(log_qname) + + registered_log_qnames = log_registry.get_log_qnames() + + while cur_log.parent: + if cur_log.name in registered_log_qnames: + return True + cur_log = cur_log.parent + + return False + + +# apply custom formats to artifacts when necessary +class TorchLogsFormatter(logging.Formatter): + def __init__(self, *, trace: bool = False): + super().__init__() + self._is_trace = trace + + def format(self, record): + artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None) + if artifact_name is not None: + artifact_formatter = log_registry.artifact_log_formatters.get( + artifact_name, None + ) + if artifact_formatter is not None: + return artifact_formatter.format(record) + + record.message = record.getMessage() + record.asctime = self.formatTime(record, "%m%d %H:%M:%S") + + # exception handling - copied from logging.Formatter.format + s = record.message + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if s[-1:] != "\n": + s = s + "\n" + s = s + record.exc_text + if record.stack_info: + if s[-1:] != "\n": + s = s + "\n" + s = s + self.formatStack(record.stack_info) + + record.rankprefix = "" + if not self._is_trace and dist.is_available() and dist.is_initialized(): + record.rankprefix = f"[rank{dist.get_rank()}]:" + + record.traceid = "" + if ( + not self._is_trace + and (trace_id := torch._guards.CompileContext.current_trace_id()) + is not None + ): + record.traceid = f" [{trace_id}]" + + glog_level_to_abbr = { + "DEBUG": "V", # V is for VERBOSE 
in glog + "INFO": "I", + "WARNING": "W", + "ERROR": "E", + "CRITICAL": "C", + } + + shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname) + + record.artifactprefix = "" + if artifact_name is not None: + record.artifactprefix = f" [__{artifact_name}]" + + prefix = ( + f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} " + f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:" + f"{record.lineno}]{record.traceid}{record.artifactprefix}" + ) + if self._is_trace: + assert s == "" + r = f"{prefix} {json.dumps(record.metadata)}" + if record.payload is not None: + r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) + return r + else: + lines = s.split("\n") + return "\n".join(f"{prefix} {l}" for l in lines) + + +def _default_formatter(): + fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None) + if fmt is None: + return TorchLogsFormatter() + else: + if fmt in ("short", "basic"): + fmt = logging.BASIC_FORMAT + return logging.Formatter(fmt) + + +DEFAULT_FORMATTER = _default_formatter() + + +def _setup_handlers(create_handler_fn, log): + debug_handler = _track_handler(create_handler_fn()) + debug_handler.setFormatter(DEFAULT_FORMATTER) + debug_handler.setLevel(logging.DEBUG) + log.addHandler(debug_handler) + + +handlers = WeakSet() # type: ignore[var-annotated] + + +# mark handlers that we've created +# so we don't modify user handlers +def _track_handler(handler): + handlers.add(handler) + return handler + + +def _is_torch_handler(handler): + return handler in handlers + + +# clears all torch handlers on specified loggers +def _clear_handlers(log): + to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)] + for handler in to_remove: + log.removeHandler(handler) + + +def _reset_logs(): + # reset all registered logs + for log_qname in log_registry.get_log_qnames(): + log = logging.getLogger(log_qname) + log.setLevel(logging.WARNING) + log.propagate = False + _clear_handlers(log) + + # reset all artifact and child logs + for artifact_log_qname in itertools.chain( + log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames() + ): + log = logging.getLogger(artifact_log_qname) + log.setLevel(logging.NOTSET) + log.propagate = True + + trace_log.propagate = False + _clear_handlers(trace_log) + + +def _get_log_state(): + return log_state + + +def _set_log_state(state): + global log_state + log_state = state + + +def _init_logs(log_file_name=None): + _reset_logs() + _update_log_state_from_env() + + out = os.environ.get(LOG_OUT_ENV_VAR, None) + if out is not None: + log_file_name = out + + # First, reset all known (registered) loggers to NOTSET, so that they + # respect their parent log level + for log_qname in log_registry.get_log_qnames(): + # But not the top level torch level: this defaults to WARNING so + # that our log messages don't leak to the lower levels + if log_qname == "torch": + continue + log = logging.getLogger(log_qname) + log.setLevel(logging.NOTSET) + + # Now, for all loggers which the user requested to have non-standard + # logging behavior, modify their log levels + for log_qname, level in log_state.get_log_level_pairs(): + log = logging.getLogger(log_qname) + log.setLevel(level) + + # Finally, setup handlers for all registered loggers + for log_qname in log_registry.get_log_qnames(): + log = logging.getLogger(log_qname) + _setup_handlers( + logging.StreamHandler, + log, + ) + + if log_file_name is not None: + _setup_handlers( + lambda: 
logging.FileHandler(log_file_name), + log, + ) + + # configure artifact loggers, note: this must happen last + # since the levels of ancestor loggers are taken into account + for artifact_log_qname in log_registry.get_artifact_log_qnames(): + log = logging.getLogger(artifact_log_qname) + configure_artifact_log(log) + + # Setup handler for the special trace_log, with different default + # configuration + trace_dir_name = os.environ.get(TRACE_ENV_VAR, None) + # This handler may remove itself if trace_dir_name is None and we are not + # actually in an FB environment. This allows us to defer actually + # initializing it until we actually need to log anything. This is + # important because JK initializes a C++ singleton, which will pork our + # process if we subsequently fork. + handler = LazyTraceHandler(trace_dir_name) + # This log is ALWAYS at debug level. We will additionally test if there + # are any handlers before deciding to actually call logging on this. Do + # not manually call + trace_log.setLevel(logging.DEBUG) + trace_log_handler = _track_handler(handler) + trace_log_handler.setFormatter(TorchLogsFormatter(trace=True)) + trace_log.addHandler(trace_log_handler) + + +class LazyTraceHandler(logging.StreamHandler): + """Like FileHandler, but the file is allocated lazily only upon the first log message""" + + def __init__(self, root_dir: Optional[str]): + # This is implemented in the same way that delay is implemented on + # FileHandler + self.root_dir = root_dir + logging.Handler.__init__(self) + self.stream = None + self._builtin_open = open + + # cloned from FileHandler in cpython + def close(self): + self.acquire() + try: + try: + if self.stream: + try: + self.flush() + finally: + stream = self.stream + self.stream = None + if hasattr(stream, "close"): + stream.close() + finally: + # Issue #19523: call unconditionally to + # prevent a handler leak when delay is set + # Also see Issue #42378: we also rely on + # self._closed being set to True there + logging.StreamHandler.close(self) + finally: + self.release() + + def emit(self, record): + if self.stream is None: + ok = False + if self.root_dir is None: + TRACE_LOG_DIR = "/logs" + open_func = self._builtin_open + + import torch.version as torch_version + + if hasattr(torch_version, "git_version"): + log.info("LazyTraceHandler: disabled because not fbcode") + elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"): + log.info( + "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False" + ) + elif not os.path.exists(TRACE_LOG_DIR): + log.info( + "LazyTraceHandler: disabled because %s does not exist", + TRACE_LOG_DIR, + ) + elif not os.access(TRACE_LOG_DIR, os.W_OK): + log.info( + "LazyTraceHandler: disabled because %s is not writeable", + TRACE_LOG_DIR, + ) + else: + self.root_dir = TRACE_LOG_DIR + + if self.root_dir is not None: + os.makedirs(self.root_dir, exist_ok=True) + ranksuffix = "" + if dist.is_available() and dist.is_initialized(): + ranksuffix = f"rank_{dist.get_rank()}_" + self.stream = tempfile.NamedTemporaryFile( + mode="w+", + suffix=".log", + prefix=f"dedicated_log_torch_trace_{ranksuffix}", + dir=self.root_dir, + delete=False, + ) + log.info("LazyTraceHandler: logging to %s", self.stream.name) + else: + # We go poof, remove and no-op + trace_log.removeHandler(self) + return + if self.stream: + super().emit(record) + + +@functools.lru_cache(None) +def warning_once(logger_obj, *args, **kwargs): + """ + This function is similar to `logger.warning()`, but will emit the 
warning with the same message only once + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. + The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to + another type of cache that includes the caller frame information in the hashing function. + """ + logger_obj.warning(*args, **kwargs) + + +class LazyString: + def __init__(self, func, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + def __str__(self): + return self.func(*self.args, **self.kwargs) + + +def trace_structured( + name: str, + # NB: metadata expected to be dict so adding more info is forward compatible + # Tuple[str, int] is a special case for string interning + metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict, + *, + payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, + suppress_context: bool = False, +): + """ + metadata is an arbitrary JSON compatible struct, but it's expected to not be + too long (e.g., less than 1MB) + + payload is an arbitrary string, which can be arbitrarily long (but expected to have + newlines so no lines are too long) + """ + assert "name" not in ["rank", "frame_id", "frame_compile_id", "attempt"] + assert callable( + metadata_fn + ), f"metadata_fn should be callable, but got {type(metadata_fn)}" + assert callable( + payload_fn + ), f"payload_fn should be callable, but got {type(payload_fn)}" + # trace_log never propagates and is ALWAYS DEBUG, so also check that there + # are handlers instead of checking the log level + if trace_log.handlers: + record: Dict[str, object] = {} + record[name] = metadata_fn() + if not suppress_context: + # TODO: Actually, the rank probably should just be emitted once at + # the top, and not repeatedly spammed in all the logs, since it + # never changes and we assume no interleaving + if dist.is_available() and dist.is_initialized(): + record["rank"] = dist.get_rank() + if ( + trace_id := torch._guards.CompileContext.current_trace_id() + ) is not None: + record["frame_id"] = trace_id.compile_id.frame_id + record["frame_compile_id"] = trace_id.compile_id.frame_compile_id + record["attempt"] = trace_id.attempt + payload = payload_fn() + if payload is not None: + if not isinstance(payload, str): + if isinstance(payload, list): + # special case to look better + payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]" + else: + # force newlines so we are unlikely to overflow line limit + payload = json.dumps(payload, indent=0) + h = hashlib.md5() + h.update(payload.encode("utf-8")) + record["has_payload"] = h.hexdigest() + trace_log.debug( + "", extra={"metadata": record, "payload": payload}, stacklevel=2 + ) + + +import torch._guards +import torch._utils_internal +import torch.distributed as dist diff --git a/MLPY/Lib/site-packages/torch/_logging/_registrations.py b/MLPY/Lib/site-packages/torch/_logging/_registrations.py new file mode 100644 index 0000000000000000000000000000000000000000..ad33a92eca3deaec4dbade3723e5f24ab805d048 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_logging/_registrations.py @@ -0,0 +1,134 @@ +# flake8: noqa: B950 +from ._internal import register_artifact, register_log + +DYNAMIC = ["torch.fx.experimental.symbolic_shapes", "torch.fx.experimental.sym_node"] +DISTRIBUTED = [ + "torch.distributed", + "torch._dynamo.backends.distributed", + "torch.nn.parallel.distributed", +] + +register_log("dynamo", ["torch._dynamo", *DYNAMIC]) 
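+# Example of how these registrations are consumed (a minimal, hypothetical
+# usage sketch; only the alias and artifact names registered in this file are real):
+#
+#     TORCH_LOGS="+dynamo,aot"   # env var form: "+" selects DEBUG, no prefix selects INFO
+#
+#     import logging, torch
+#     torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)
+#
+# Each register_log() call maps a short alias to one or more logger qualified
+# names; register_artifact() declares a named artifact that can be toggled
+# independently of its parent component's log level.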
+register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"]) +register_log("autograd", "torch.autograd") +register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"]) + +register_artifact( + "cudagraphs", + "Logs information from wrapping inductor generated code with cudagraphs.", +) + +register_log("dynamic", DYNAMIC) +register_log("torch", "torch") +register_log("distributed", DISTRIBUTED) +register_log( + "dist_c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] +) +register_log( + "dist_ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] +) +register_log("dist_fsdp", ["torch.distributed.fsdp"]) +register_log("onnx", "torch.onnx") +register_log("export", ["torch._dynamo", "torch.export", *DYNAMIC]) + +register_artifact( + "guards", + "This prints the guards for every compiled Dynamo frame. It does not tell you where the guards come from.", + visible=True, +) +register_artifact("verbose_guards", "", off_by_default=True) +register_artifact( + "bytecode", + "Prints the original and modified bytecode from Dynamo. Mostly useful if you're debugging our bytecode generation in Dynamo.", + off_by_default=True, +) +register_artifact( + "graph", + "Prints the dynamo traced graph (prior to AOTDispatch) in a table. If you prefer python code use `graph_code` instead. ", +) +register_artifact("graph_code", "Like `graph`, but gives you the Python code instead.") +register_artifact( + "graph_sizes", "Prints the sizes of all FX nodes in the dynamo graph." +) +register_artifact( + "trace_source", + "As we execute bytecode, prints the file name / line number we are processing and the actual source code. Useful with `bytecode`", +) +register_artifact( + "trace_call", + "Like trace_source, but it will give you the per-expression blow-by-blow if your Python is recent enough.", +) +register_artifact( + "aot_graphs", + "Prints the FX forward and backward graph generated by AOTDispatch, after partitioning. Useful to understand what's being given to Inductor", + visible=True, +) +register_artifact( + "aot_joint_graph", + "Print FX joint graph from AOTAutograd, prior to partitioning. Useful for debugging partitioning", +) +register_artifact( + "post_grad_graphs", + "Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes", +) +register_artifact( + "compiled_autograd", + "Prints various logs in compiled_autograd, including but not limited to the graphs. Useful for debugging compiled_autograd.", + visible=True, +) +register_artifact( + "ddp_graphs", + "Only relevant for compiling DDP. DDP splits into multiple graphs to trigger comms early. This will print each individual graph here.", +) +register_artifact( + "recompiles", + "Prints the reason why we recompiled a graph. Very, very useful.", + visible=True, +) +register_artifact( + "recompiles_verbose", + "Prints all guard checks that fail during a recompilation. " + "At runtime, Dynamo will stop at the first failed check for each failing guard. " + "So not all logged failing checks are actually ran by Dynamo.", + visible=True, + off_by_default=True, +) +register_artifact( + "graph_breaks", + "Prints whenever Dynamo decides that it needs to graph break (i.e. create a new graph). 
Useful for debugging why torch.compile has poor performance", + visible=True, +) +register_artifact( + "not_implemented", + "Prints log messages whenever we return NotImplemented in a multi-dispatch, letting you trace through each object we attempted to dispatch to", +) +register_artifact( + "output_code", + "Prints the code that Inductor generates (either Triton or C++)", + off_by_default=True, + visible=True, +) +register_artifact( + "schedule", + "Inductor scheduler information. Useful if working on Inductor fusion algo", + off_by_default=True, +) +register_artifact("perf_hints", "", off_by_default=True) +register_artifact("onnx_diagnostics", "", off_by_default=True) +register_artifact( + "fusion", + "Detailed Inductor fusion decisions. More detailed than 'schedule'", + off_by_default=True, +) +register_artifact( + "overlap", + "Detailed Inductor compute/comm overlap decisions", + off_by_default=True, +) +register_artifact( + "sym_node", + "Logs extra info for various SymNode operations", + off_by_default=True, +) + +register_artifact("custom_format_test_artifact", "Testing only", log_format="") diff --git a/MLPY/Lib/site-packages/torch/_logging/structured.py b/MLPY/Lib/site-packages/torch/_logging/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..26f9600a93a32482317124155a091cb98f495ed6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_logging/structured.py @@ -0,0 +1,37 @@ +""" +Utilities for converting data types into structured JSON for dumping. +""" + +import traceback +from typing import Dict, Sequence + +import torch._logging._internal + + +INTERN_TABLE: Dict[str, int] = {} + + +def intern_string(s: str) -> int: + r = INTERN_TABLE.get(s, None) + if r is None: + r = len(INTERN_TABLE) + INTERN_TABLE[s] = r + torch._logging._internal.trace_structured( + "str", lambda: (s, r), suppress_context=True + ) + return r + + +def from_traceback(tb: Sequence[traceback.FrameSummary]) -> object: + r = [] + for frame in tb: + # dict naming convention here coincides with + # python/combined_traceback.cpp + r.append( + { + "line": frame.lineno, + "name": frame.name, + "filename": intern_string(frame.filename), + } + ) + return r diff --git a/MLPY/Lib/site-packages/torch/_lowrank.py b/MLPY/Lib/site-packages/torch/_lowrank.py new file mode 100644 index 0000000000000000000000000000000000000000..6e458d4198dccfcf88b0da570f437cdee4e3aa47 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_lowrank.py @@ -0,0 +1,298 @@ +"""Implement various linear algebra algorithms for low rank matrices. +""" + +__all__ = ["svd_lowrank", "pca_lowrank"] + +from typing import Optional, Tuple + +import torch +from torch import Tensor +from . import _linalg_utils as _utils +from .overrides import handle_torch_function, has_torch_function + + +def get_approximate_basis( + A: Tensor, q: int, niter: Optional[int] = 2, M: Optional[Tensor] = None +) -> Tensor: + """Return tensor :math:`Q` with :math:`q` orthonormal columns such + that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is + specified, then :math:`Q` is such that :math:`Q Q^H (A - M)` + approximates :math:`A - M`. + + .. note:: The implementation is based on the Algorithm 4.4 from + Halko et al, 2009. + + .. note:: For an adequate approximation of a k-rank matrix + :math:`A`, where k is not known in advance but could be + estimated, the number of :math:`Q` columns, q, can be + choosen according to the following criteria: in general, + :math:`k <= q <= min(2*k, m, n)`. For large low-rank + matrices, take :math:`q = k + 5..10`. 
If k is + relatively small compared to :math:`min(m, n)`, choosing + :math:`q = k + 0..2` may be sufficient. + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int): the dimension of subspace spanned by :math:`Q` + columns. + + niter (int, optional): the number of subspace iterations to + conduct; ``niter`` must be a + nonnegative integer. In most cases, the + default value 2 is more than enough. + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, 1, n)`. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + """ + + niter = 2 if niter is None else niter + m, n = A.shape[-2:] + dtype = _utils.get_floating_dtype(A) + matmul = _utils.matmul + + R = torch.randn(n, q, dtype=dtype, device=A.device) + + # The following code could be made faster using torch.geqrf + torch.ormqr + # but geqrf is not differentiable + A_H = _utils.transjugate(A) + if M is None: + Q = torch.linalg.qr(matmul(A, R)).Q + for i in range(niter): + Q = torch.linalg.qr(matmul(A_H, Q)).Q + Q = torch.linalg.qr(matmul(A, Q)).Q + else: + M_H = _utils.transjugate(M) + Q = torch.linalg.qr(matmul(A, R) - matmul(M, R)).Q + for i in range(niter): + Q = torch.linalg.qr(matmul(A_H, Q) - matmul(M_H, Q)).Q + Q = torch.linalg.qr(matmul(A, Q) - matmul(M, Q)).Q + + return Q + + +def svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Return the singular value decomposition ``(U, S, V)`` of a matrix, + batches of matrices, or a sparse matrix :math:`A` such that + :math:`A \approx U diag(S) V^T`. In case :math:`M` is given, then + SVD is computed for the matrix :math:`A - M`. + + .. note:: The implementation is based on the Algorithm 5.1 from + Halko et al, 2009. + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + .. note:: The input is assumed to be a low-rank matrix. + + .. note:: In general, use the full-rank SVD implementation + :func:`torch.linalg.svd` for dense matrices due to its 10-fold + higher performance characteristics. The low-rank SVD + will be useful for huge sparse matrices that + :func:`torch.linalg.svd` cannot handle. + + Args:: + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of A. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2 + + M (Tensor, optional): the input tensor's mean of size + :math:`(*, 1, n)`. + + References:: + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). 
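+
+    Example (a minimal usage sketch; the random matrix, the rank ``q=8`` and the
+    reconstruction line are illustrative and not part of the reference)::
+
+        >>> # xdoctest: +SKIP
+        >>> A = torch.randn(100, 40)                 # stand-in for a (nearly) low-rank input
+        >>> U, S, V = torch.svd_lowrank(A, q=8, niter=2)
+        >>> A_approx = U @ torch.diag(S) @ V.mT      # rank-q approximation, A ~ U diag(S) V^T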
+ + """ + if not torch.jit.is_scripting(): + tensor_ops = (A, M) + if not set(map(type, tensor_ops)).issubset( + (torch.Tensor, type(None)) + ) and has_torch_function(tensor_ops): + return handle_torch_function( + svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M + ) + return _svd_lowrank(A, q=q, niter=niter, M=M) + + +def _svd_lowrank( + A: Tensor, + q: Optional[int] = 6, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + q = 6 if q is None else q + m, n = A.shape[-2:] + matmul = _utils.matmul + if M is None: + M_t = None + else: + M_t = _utils.transpose(M) + A_t = _utils.transpose(A) + + # Algorithm 5.1 in Halko et al 2009, slightly modified to reduce + # the number conjugate and transpose operations + if m < n or n > q: + # computing the SVD approximation of a transpose in + # order to keep B shape minimal (the m < n case) or the V + # shape small (the n > q case) + Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) + Q_c = _utils.conjugate(Q) + if M is None: + B_t = matmul(A, Q_c) + else: + B_t = matmul(A, Q_c) - matmul(M, Q_c) + assert B_t.shape[-2] == m, (B_t.shape, m) + assert B_t.shape[-1] == q, (B_t.shape, q) + assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape + U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) + V = Vh.mH + V = Q.matmul(V) + else: + Q = get_approximate_basis(A, q, niter=niter, M=M) + Q_c = _utils.conjugate(Q) + if M is None: + B = matmul(A_t, Q_c) + else: + B = matmul(A_t, Q_c) - matmul(M_t, Q_c) + B_t = _utils.transpose(B) + assert B_t.shape[-2] == q, (B_t.shape, q) + assert B_t.shape[-1] == n, (B_t.shape, n) + assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape + U, S, Vh = torch.linalg.svd(B_t, full_matrices=False) + V = Vh.mH + U = Q.matmul(U) + + return U, S, V + + +def pca_lowrank( + A: Tensor, q: Optional[int] = None, center: bool = True, niter: int = 2 +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Performs linear Principal Component Analysis (PCA) on a low-rank + matrix, batches of such matrices, or sparse matrix. + + This function returns a namedtuple ``(U, S, V)`` which is the + nearly optimal approximation of a singular value decomposition of + a centered matrix :math:`A` such that :math:`A = U diag(S) V^T`. + + .. note:: The relation of ``(U, S, V)`` to PCA is as follows: + + - :math:`A` is a data matrix with ``m`` samples and + ``n`` features + + - the :math:`V` columns represent the principal directions + + - :math:`S ** 2 / (m - 1)` contains the eigenvalues of + :math:`A^T A / (m - 1)` which is the covariance of + ``A`` when ``center=True`` is provided. + + - ``matmul(A, V[:, :k])`` projects data to the first k + principal components + + .. note:: Different from the standard SVD, the size of returned + matrices depend on the specified rank and q + values as follows: + + - :math:`U` is m x q matrix + + - :math:`S` is q-vector + + - :math:`V` is n x q matrix + + .. note:: To obtain repeatable results, reset the seed for the + pseudorandom number generator + + Args: + + A (Tensor): the input tensor of size :math:`(*, m, n)` + + q (int, optional): a slightly overestimated rank of + :math:`A`. By default, ``q = min(6, m, + n)``. + + center (bool, optional): if True, center the input tensor, + otherwise, assume that the input is + centered. + + niter (int, optional): the number of subspace iterations to + conduct; niter must be a nonnegative + integer, and defaults to 2. 
+ + References:: + + - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding + structure with randomness: probabilistic algorithms for + constructing approximate matrix decompositions, + arXiv:0909.4061 [math.NA; math.PR], 2009 (available at + `arXiv `_). + + """ + + if not torch.jit.is_scripting(): + if type(A) is not torch.Tensor and has_torch_function((A,)): + return handle_torch_function( + pca_lowrank, (A,), A, q=q, center=center, niter=niter + ) + + (m, n) = A.shape[-2:] + + if q is None: + q = min(6, m, n) + elif not (q >= 0 and q <= min(m, n)): + raise ValueError( + f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}" + ) + if not (niter >= 0): + raise ValueError(f"niter(={niter}) must be non-negative integer") + + dtype = _utils.get_floating_dtype(A) + + if not center: + return _svd_lowrank(A, q, niter=niter, M=None) + + if _utils.is_sparse(A): + if len(A.shape) != 2: + raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor") + c = torch.sparse.sum(A, dim=(-2,)) / m + # reshape c + column_indices = c.indices()[0] + indices = torch.zeros( + 2, + len(column_indices), + dtype=column_indices.dtype, + device=column_indices.device, + ) + indices[0] = column_indices + C_t = torch.sparse_coo_tensor( + indices, c.values(), (n, 1), dtype=dtype, device=A.device + ) + + ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device) + M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t)) + return _svd_lowrank(A, q, niter=niter, M=M) + else: + C = A.mean(dim=(-2,), keepdim=True) + return _svd_lowrank(A - C, q, niter=niter, M=None) diff --git a/MLPY/Lib/site-packages/torch/_meta_registrations.py b/MLPY/Lib/site-packages/torch/_meta_registrations.py new file mode 100644 index 0000000000000000000000000000000000000000..f7776aae7803b1da33cf59f8f7a5609c3ac2a9a0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_meta_registrations.py @@ -0,0 +1,6253 @@ +import math +from enum import Enum +from functools import partial +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch._prims_common as utils +from torch import SymBool, SymFloat, Tensor +from torch._decomp import ( + _add_op_to_registry, + _convert_out_params, + global_decomposition_table, + meta_table, +) +from torch._ops import OpOverload +from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND +from torch._prims_common import ( + corresponding_complex_dtype, + corresponding_real_dtype, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + make_contiguous_strides_for, + TensorLike, +) + +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _resize_output_check, + _safe_copy_out, + out_wrapper, +) +from torch._refs import _broadcast_shapes, _maybe_broadcast +from torch.utils import _pytree as pytree + + +aten = torch.ops.aten + +_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta") + + +def register_meta(op): + def wrapper(fn): + fn = _convert_out_params(fn) + + def register(op): + _add_op_to_registry(meta_table, op, fn) + + pytree.tree_map_(register, op) + return fn + + return wrapper + + +def elementwise_meta( + *args, + type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND, +): + # Perform type promotion, as this is expected from prim_metafunction + _, result_dtype = utils.elementwise_dtypes( + *args, + type_promotion_kind=type_promotion, + ) + args = [_maybe_convert_to_dtype(x, result_dtype) for x in args] + + # Broadcast + args = 
_maybe_broadcast(*args) + + # Perform prim checks + return _prim_elementwise_meta( + *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + +def toRealValueType(dtype): + from_complex = { + torch.complex32: torch.half, + torch.cfloat: torch.float, + torch.cdouble: torch.double, + } + return from_complex.get(dtype, dtype) + + +def check_inplace_broadcast(self_shape, *args_shape): + broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape)) + torch._check( + broadcasted_shape == self_shape, + lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}", + ) + + +@register_meta([aten.linspace, aten.logspace]) +@out_wrapper() +def meta_linspace_logspace( + start, + end, + steps, + base=None, + dtype=None, + device=None, + layout=torch.strided, + pin_memory=False, + requires_grad=False, +): + if isinstance(start, torch.Tensor): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + if isinstance(end, torch.Tensor): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + + if any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + return torch.empty( + (steps,), # type: ignore[arg-type] + dtype=dtype, + layout=layout, + device="meta", + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_meta([aten.take.default, aten.take.out]) +@out_wrapper() +def meta_take(self, index): + # Type and device checks + torch._check( + index.dtype == torch.long, + lambda: f"take(): Expected a long tensor for index, but got {index.dtype}", + ) + # Index checks + torch._check_index( + not (self.numel() == 0 and index.numel() != 0), + lambda: "take(): tried to take from an empty tensor", + ) + return self.new_empty(index.shape) + + +@register_meta([aten.linalg_cross.default, aten.linalg_cross.out]) +@out_wrapper() +def linalg_cross(self, other, *, dim=-1): + x_d = self.ndim + y_d = other.ndim + torch._check( + x_d == y_d, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + self.size(dim) == 3 and other.size(dim) == 3, + lambda: ( + f"linalg.cross: inputs dimension {dim} must have length 3. 
" + f"Got {self.size(dim)} and {other.size(dim)}" + ), + ) + out_shape = _broadcast_shapes(self.shape, other.shape) + return self.new_empty(out_shape) + + +@register_meta(aten.linalg_matrix_exp) +@out_wrapper() +def linalg_matrix_exp(self): + squareCheckInputs(self, "linalg.matrix_exp") + checkFloatingOrComplex(self, "linalg.matrix_exp") + return torch.empty_like(self, memory_format=torch.contiguous_format) + + +@register_meta( + [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out] +) +@out_wrapper("values", "indices") +def cummaxmin(self, dim): + values = torch.empty(self.shape, device=self.device, dtype=self.dtype) + indices = torch.empty(self.shape, device=self.device, dtype=torch.int64) + if self.numel() != 0 and self.ndim != 0: + # Checks that dim is within bounds + maybe_wrap_dim(dim, self.ndim) + return values, indices + + +@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out]) +@out_wrapper() +def logcumsumexp(self, dim): + # Checks that dim is within bounds + maybe_wrap_dim(dim, self.ndim) + return torch.empty_like(self).contiguous() + + +# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp +def _exec_fft(out, self, out_sizes, dim, forward): + ndim = self.ndim + signal_ndim = len(dim) + batch_dims = ndim - signal_ndim + + # Permute dimensions so batch dimensions come first, and in stride order + dim_permute = list(range(ndim)) + + is_transformed_dim = [False for _ in range(ndim)] + for d in dim: + is_transformed_dim[d] = True + + # std::partition + left, right = [], [] + for d in dim_permute: + if not is_transformed_dim[d]: + left.append(d) + else: + right.append(d) + dim_permute = left + right + batch_end = len(left) + + self_strides = self.stride() + tmp = dim_permute[:batch_end] + tmp.sort(key=lambda x: self_strides[x], reverse=True) + dim_permute = tmp + dim_permute[batch_end:] + input = self.permute(dim_permute) + + # Collapse batch dimensions into a single dimension + batched_sizes = [-1] + list(input.shape[batch_dims:]) + input = input.reshape(batched_sizes) + + batch_size = input.size(0) + batched_sizes[0] = batch_size + batched_out_sizes = batched_sizes + for i in range(len(dim)): + batched_out_sizes[i + 1] = out_sizes[dim[i]] + out = out.reshape(batched_out_sizes) + + # Reshaping to original batch shape and inverting the dimension permutation + out_strides = [0 for _ in range(ndim)] + batch_numel = 1 + i = batch_dims - 1 + while i >= 0: + out_strides[dim_permute[i]] = batch_numel * out.stride(0) + batch_numel *= out_sizes[dim_permute[i]] + i -= 1 + for i in range(batch_dims, ndim): + out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims)) + return out.as_strided(out_sizes, out_strides, out.storage_offset()) + + +# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp +# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp +@register_meta([aten._fft_c2c.default, aten._fft_c2c.out]) +@out_wrapper() +def meta_fft_c2c(self, dim, normalization, forward): + assert self.dtype.is_complex + + out_sizes = self.shape + output = self.new_empty(out_sizes) + + if not dim: + return output + + sorted_dims = dim[:] + self_strides = self.stride() + sorted_dims.sort(key=lambda x: self_strides[x], reverse=True) + output = _exec_fft(output, self, out_sizes, sorted_dims, forward) + + return output + + +@register_meta([aten._fft_r2c.default, aten._fft_r2c.out]) +@out_wrapper() +def meta_fft_r2c(self, dim, normalization, onesided): + assert self.dtype.is_floating_point + output_sizes = list(self.size()) + 
+ if onesided: + last_dim = dim[-1] + last_dim_halfsize = (output_sizes[last_dim] // 2) + 1 + output_sizes[last_dim] = last_dim_halfsize + + return self.new_empty( + output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype) + ) + + +@register_meta(aten.randperm.generator_out) +def meta_randperm(n, *, generator=None, out): + return _maybe_resize_out(out, torch.Size([n])) + + +@register_meta(aten.randperm.default) +def meta_randperm_default( + n, *, dtype=torch.long, layout=None, device=None, pin_memory=None +): + return torch.empty( + n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.randint.default) +def meta_randint( + high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None +): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.randint.low) +def meta_randint_low( + low, + high, + size, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, +): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.rand.default) +def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta([aten._fft_c2r.default, aten._fft_c2r.out]) +@out_wrapper() +def meta_fft_c2r(self, dim, normalization, lastdim): + assert self.dtype.is_complex + output_sizes = list(self.size()) + output_sizes[dim[-1]] = lastdim + return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype)) + + +@register_meta(aten.copy_.default) +def meta_copy_(self, src, non_blocking=False): + # This code simulates the original decomp from inductor, + # which runs most of the meta checks that we care about. + # In theory, we should make this more robust by carefully + # auditing our C++ copy_() kernel and copying the checks here. 
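meta_copy_ continues below with two checks: the written-to tensor must not overlap itself in memory, and src must broadcast to self's shape. An illustrative eager-mode sketch of both behaviours (not part of the patch; exact error wording may vary by version):

import torch

# A broadcast copy that is fine: (3,) expands to (2, 3)
dst = torch.empty(2, 3)
dst.copy_(torch.ones(3))

# Writing into an expanded (self-overlapping) view is rejected, mirroring the
# internal-overlap check performed below
overlapping = torch.zeros(3).expand(2, 3)
try:
    overlapping.copy_(torch.ones(2, 3))
except RuntimeError as err:
    print(err)  # "... more than one element of the written-to tensor refers to a single memory location"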
+ + if torch._debug_has_internal_overlap(self) == 1: # 1 == MemOverlap::Yes + raise RuntimeError( + "more than one element of the written-to tensor refers to a single memory location" + ) + + if isinstance(src, Tensor): + intermediate = src.to(self, non_blocking) + if self.size() != intermediate.size(): + aten.expand_copy.default(intermediate, self.size()) + return self + + +def inferUnsqueezeGeometry(tensor, dim): + result_sizes = list(tensor.size()) + result_strides = list(tensor.stride()) + new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim] + result_sizes.insert(dim, 1) + result_strides.insert(dim, new_stride) + return result_sizes, result_strides + + +@register_meta(aten.unsqueeze_.default) +def meta_unsqueeze_(self, dim): + dim = maybe_wrap_dim(dim, self.dim() + 1) + g_sizes, g_strides = inferUnsqueezeGeometry(self, dim) + self.as_strided_(g_sizes, g_strides) + return self + + +@register_meta(aten._sparse_semi_structured_linear) +def meta_sparse_structured_linear( + input: Tensor, + weight: Tensor, + _meta: Tensor, + bias: Optional[Tensor] = None, + _activation_opt: Optional[str] = None, + out_dtype: Optional[torch.dtype] = None, +): + output_sizes = list(input.shape) + if bias is not None: + assert weight.size(0) == bias.size(0), "output size mismatch" + assert weight.size(1) == input.size(-1) / 2 + output_sizes[-1] = weight.size(0) + + # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375 + # We assume that we have already squashed the inputs into a 2-D tensor + # Then, as the output is transposed, we need to propagate the transposed + # stride information to the output tensor + assert len(input.shape) == 2, "we can only handle the squashed input case" + transposed_strides = (1, input.size(0)) + + if out_dtype is not None: + assert ( + input.dtype == torch.int8 and out_dtype == torch.int32 + ), "out_dtype is only supported for i8i8->i32 linear operator" + output = input.new_empty( + output_sizes, + dtype=input.dtype if out_dtype is None else out_dtype, + ).as_strided(output_sizes, transposed_strides) + + return output + + +@register_meta(aten._cslt_sparse_mm) +def meta__cslt_sparse_mm( + compressed_A: torch.Tensor, + dense_B: torch.Tensor, + bias: Optional[Tensor] = None, + alpha: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + transpose_result: bool = False, +): + assert dense_B.dtype in { + torch.float32, + torch.float16, + torch.bfloat16, + torch.int8, + }, "_cslt_sparse_mm only supports fp16, bf16, and int8" + assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype" + assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs" + + is_int8_input_type = compressed_A.dtype == torch.int8 + compression_factor = 10 if is_int8_input_type else 9 + k = dense_B.size(0) + n = dense_B.size(1) + m = (compressed_A.numel() * 16) // (compression_factor * k) + if bias is not None: + assert m == bias.size(0) + + if out_dtype is not None: + assert is_int8_input_type and out_dtype in { + torch.float16, + torch.bfloat16, + torch.int32, + }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul" + output_shape = (n, m) if transpose_result else (m, n) + result = dense_B.new_empty(output_shape, dtype=out_dtype) + return result + + +@register_meta(aten.index_reduce.default) +def meta_index_reduce( + self: Tensor, + dim: int, + index: Tensor, + source: torch.Tensor, + reduce: str, + *, + include_self: bool = True, +) -> Tensor: + return torch.empty_like(self, 
memory_format=torch.contiguous_format) + + +@register_meta(aten.index_reduce_.default) +def meta_index_reduce_( + self: Tensor, + dim: int, + index: Tensor, + source: torch.Tensor, + reduce: str, + *, + include_self: bool = True, +) -> Tensor: + return self + + +# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py +@out_wrapper() +@register_meta(aten.index_select.default) +def meta_index_select(self, dim, index): + result_size = list(self.size()) + if self.dim() > 0: + result_size[dim] = index.numel() + return self.new_empty(result_size) + + +@register_meta(aten.segment_reduce.default) +def meta_segment_reduce( + data: Tensor, + reduce: str, + *, + lengths: Optional[Tensor] = None, + indices: Optional[Tensor] = None, + offsets: Optional[Tensor] = None, + axis: int = 0, + unsafe: bool = False, + initial=None, +) -> Tensor: + if indices is not None: + raise NotImplementedError( + "segment_reduce(): indices based reduction is not supported yet." + ) + + def segment_reduce_lengths_tensor(lengths_shape): + return torch.empty( + lengths_shape + data.shape[axis + 1 :], + dtype=data.dtype, + device="meta", + memory_format=torch.contiguous_format, + ) + + if lengths is not None: + return segment_reduce_lengths_tensor(lengths.shape) + # FIXME should probably check that lengths and offset aren't both set, but + # the ATen implementation neglects this too + if offsets is not None: + # lengths == torch.diff(offsets) + lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,) + return segment_reduce_lengths_tensor(lengths_shape) + raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.") + + +@register_meta([aten.max.default, aten.max.unary_out]) +@out_wrapper() +def meta_max(self): + return self.new_empty(()) + + +@register_meta(aten.max.dim) +def meta_max_dim(self, dim, keepdim=False): + dim = utils.reduction_dims(self.shape, (dim,)) + output_shape = _compute_reduction_shape(self, dim, keepdim) + return ( + self.new_empty(output_shape), + self.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta([aten.min.default, aten.min.unary_out]) +@out_wrapper() +def meta_min(self): + return self.new_empty(()) + + +@register_meta(aten.min.dim) +def meta_min_dim(self, dim, keepdim=False): + dim = utils.reduction_dims(self.shape, (dim,)) + output_shape = _compute_reduction_shape(self, dim, keepdim) + return ( + self.new_empty(output_shape), + self.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta(aten.angle.default) +def meta_angle(self): + if self.is_complex(): + result_dtype = corresponding_real_dtype(self.dtype) + else: + _, result_dtype = elementwise_dtypes( + self, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + return torch.empty_like(self, dtype=result_dtype) + + +@register_meta(aten.angle.out) +def meta_angle_out(self, out): + torch._resize_output_(out, self.size(), self.device) + return out.copy_(torch.angle(self)) + + +@register_meta(aten._assert_async.default) +def assert_async(val): + return + + +@register_meta(aten._assert_async.msg) +def assert_async_meta(val, assert_msg): + return + + +@register_meta(aten._print.default) +def print_meta(s): + return + + +@register_meta(aten._make_dep_token.default) +def make_dep_token( + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + return torch.empty([], device="meta") + + +@register_meta(aten.sym_constrain_range.default) +def sym_constrain_range(size, min=None, max=None): + # 
Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import constrain_range + + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") + constrain_range(size, min=min, max=max) + + +@register_meta(aten._functional_sym_constrain_range.default) +def functional_sym_constrain_range(size, min=None, max=None, dep_token=None): + aten.sym_constrain_range(size, min=min, max=max) + return dep_token + + +@register_meta(aten.sym_constrain_range_for_size.default) +def sym_constrain_range_for_size(size, min=None, max=None): + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + if isinstance(size, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat or Symbool is nyi") + _constrain_range_for_size(size, min=min, max=max) + + +@register_meta(aten._functional_sym_constrain_range_for_size.default) +def functional_sym_constrain_range_for_size(size, min, max, dep_token): + aten.sym_constrain_range_for_size(size, min=min, max=max) + return dep_token + + +@register_meta(aten._functional_assert_async.msg) +def functional_assert_async_meta(val, assert_msg, dep_token): + return dep_token + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def squareCheckInputs(self: Tensor, f_name: str): + assert ( + self.dim() >= 2 + ), f"{f_name}: The input tensor must have at least 2 dimensions." + assert self.size(-1) == self.size( + -2 + ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" + + +# Validates input shapes and devices +# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve) +# From aten/src/ATen/native/LinearAlgebraUtils.h +def linearSolveCheckInputs( + self: Tensor, + A: Tensor, + name: str, +): + torch._check( + self.device == A.device, + lambda: ( + f"Expected b and A to be on the same device, but found b on " + f"{self.device} and A on {A.device} instead." + ), + ) + + torch._check( + self.dtype == A.dtype, + lambda: ( + f"Expected b and A to have the same dtype, but found b of type " + f"{self.dtype} and A of type {A.dtype} instead." + ), + ) + + torch._check( + A.size(-1) == A.size(-2), + lambda: ( + f"A must be batches of square matrices, " + f"but they are {A.size(-2)} by {A.size(-1)} matrices" + ), + ) + + torch._check( + A.size(-1) == self.size(-2), + lambda: ( + f"Incompatible matrix sizes for {name}: each A " + f"matrix is {A.size(-1)} by {A.size(-1)}" + f" but each b matrix is {self.size(-2)} by {self.size(-1)}" + ), + ) + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkFloatingOrComplex( + t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True +): + dtype = t.dtype + torch._check( + t.is_floating_point() or t.is_complex(), + lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}", + ) + if not allow_low_precision_dtypes: + torch._check( + dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble), + lambda: f"{f_name}: Low precision dtypes not supported. 
Got {dtype}", + ) + + +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"): + torch._check( + A.dim() >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def checkInputsSolver( + A: Tensor, + B: Tensor, + left: bool, + f_name: str, +): + squareCheckInputs(A, f_name) + checkIsMatrix(B, f_name) + torch._check( + A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1), + lambda: ( + f"{f_name}: Incompatible shapes of A and B for the equation " + f"{'AX = B' if left else 'XA = B'}" + f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})" + ), + ) + + +def checkSameDevice( + fn_name: str, result: Tensor, input: Tensor, result_name: str = "result" +): + torch._check( + result.device == input.device, + lambda: ( + f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got " + f"{result_name} on {result.device} and input on {input.device}" + ), + ) + + +def checkUplo(UPLO: str): + UPLO_uppercase = UPLO.upper() + torch._check( + len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"), + lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}", + ) + + +@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues]) +@out_wrapper("eigenvalues", "eigenvectors") +def meta__linalg_eigh( + A: Tensor, + UPLO: str = "L", + compute_v: bool = True, +): + squareCheckInputs(A, "linalg.eigh") + checkUplo(UPLO) + + shape = list(A.shape) + if compute_v: + vecs = A.new_empty(shape) + vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False)) + else: + vecs = A.new_empty([0]) + + shape.pop() + vals = A.new_empty(shape, dtype=toRealValueType(A.dtype)) + + return vals, vecs + + +@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out]) +@out_wrapper() +def meta__linalg_eigvals(input: Tensor) -> Tensor: + squareCheckInputs(input, "linalg.eigvals") + complex_dtype = ( + input.dtype + if utils.is_complex_dtype(input.dtype) + else utils.corresponding_complex_dtype(input.dtype) + ) + return input.new_empty(input.shape[:-1], dtype=complex_dtype) + + +@register_meta([aten.linalg_eig]) +@out_wrapper("eigenvalues", "eigenvectors") +def meta_linalg_eig(input: Tensor): + squareCheckInputs(input, "linalg.eig") + complex_dtype = ( + input.dtype + if utils.is_complex_dtype(input.dtype) + else utils.corresponding_complex_dtype(input.dtype) + ) + values = input.new_empty(input.shape[:-1], dtype=complex_dtype) + vectors = input.new_empty(input.shape, dtype=complex_dtype) + return values, vectors + + +def cloneBatchedColumnMajor(src: Tensor) -> Tensor: + return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1) + + +@register_meta(aten._cholesky_solve_helper) +@out_wrapper() +def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor: + return cloneBatchedColumnMajor(self) + + +@register_meta(aten.cholesky_solve) +@out_wrapper() +def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor: + torch._check( + self.ndim >= 2, + lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead", + ) + torch._check( + A.ndim >= 2, + lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead", + ) + self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name( + self, A, "cholesky_solve" + ) + return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper) + + +@register_meta(aten.cholesky) +@out_wrapper() +def 
cholesky(self: Tensor, upper: bool = False) -> Tensor: + if self.numel() == 0: + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + squareCheckInputs(self, "cholesky") + return cloneBatchedColumnMajor(self) + + +@register_meta(aten.cholesky_inverse) +@out_wrapper() +def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor: + squareCheckInputs(self, "cholesky_inverse") + return cloneBatchedColumnMajor(self) + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_cholesky_ex.default) +def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False): + squareCheckInputs(A, "linalg.cholesky") + checkFloatingOrComplex(A, "linalg.cholesky") + + A_shape = A.shape + ndim = len(A_shape) + + # L + L_strides = make_contiguous_strides_for(A_shape, False) + L = A.new_empty(A_shape) + L.as_strided_(A_shape, L_strides) + + # infos + infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32) + return L, infos + + +@register_meta( + [aten.linalg_householder_product.default, aten.linalg_householder_product.out] +) +@out_wrapper() +def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor: + torch._check( + input.ndim >= 2, + lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.", + ) + torch._check( + input.size(-2) >= input.size(-1), + lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]", + ) + torch._check( + input.size(-1) >= tau.size(-1), + lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]", + ) + + torch._check( + input.ndim - tau.ndim == 1, + lambda: ( + f"torch.linalg.householder_product: Expected tau to have one dimension less than input, " + f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + if input.ndim > 2: + expected_batch_tau_shape = input.shape[:-2] + actual_batch_tau_shape = tau.shape[:-1] + torch._check( + actual_batch_tau_shape == expected_batch_tau_shape, + lambda: ( + f"torch.linalg.householder_product: Expected batch dimensions of tau to be " + f"equal to input.shape[:-2], but got {actual_batch_tau_shape}" + ), + ) + + torch._check( + tau.dtype == input.dtype, + lambda: ( + f"torch.linalg.householder_product: tau dtype {tau.dtype}" + f" does not match input dtype {input.dtype}" + ), + ) + checkSameDevice("torch.linalg.householder_product", tau, input, "tau") + + return torch.empty_strided( + size=input.shape, + stride=make_contiguous_strides_for(input.shape, row_major=False), + dtype=input.dtype, + device=input.device, + ) + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_inv_ex.default) +def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False): + squareCheckInputs(A, "linalg.inv_ex") + checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False) + + L = A.new_empty(A.shape) + L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False)) + + infos = A.new_empty(A.shape[:-2], dtype=torch.int32) + return L, infos + + +@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out]) +@out_wrapper("LD", "pivots", "info") +def linalg_ldl_factor_ex_meta( + self: Tensor, + *, + hermitian: bool = False, + check_errors: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: + squareCheckInputs(self, "torch.linalg.ldl_factor_ex") + checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex") + LD = torch.empty_strided( + size=self.shape, + 
stride=make_contiguous_strides_for(self.shape, row_major=False), + dtype=self.dtype, + device=self.device, + ) + pivots = self.new_empty(self.shape[:-1], dtype=torch.int) + info = self.new_empty(self.shape[:-2], dtype=torch.int) + return LD, pivots, info + + +@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out]) +@out_wrapper() +def linalg_ldl_solve_meta( + LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False +) -> Tensor: + squareCheckInputs(LD, "torch.linalg.ldl_solve") + checkFloatingOrComplex(LD, "torch.linalg.ldl_solve") + linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve") + torch._check( + B.ndim >= 2, + lambda: ( + f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, " + f"but it has {B.ndim} dimensions instead" + ), + ) + expected_pivots_shape = LD.shape[:-1] + torch._check( + expected_pivots_shape == pivots.shape, + lambda: ( + f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, " + f"but got pivots with shape {pivots.shape} instead" + ), + ) + torch._check( + utils.is_integer_dtype(pivots.dtype), + lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}", + ) + torch._check( + LD.dtype == B.dtype, + lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}", + ) + B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD) + return torch.empty_strided( + size=B_broadcast_size, + stride=make_contiguous_strides_for(B_broadcast_size, row_major=False), + dtype=B.dtype, + device=B.device, + ) + + +@register_meta([aten.linalg_lu.default, aten.linalg_lu.out]) +@out_wrapper("P", "L", "U") +def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + torch._check( + A.ndim >= 2, + lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead", + ) + + sizes = list(A.shape) + m = sizes[-2] + n = sizes[-1] + k = min(m, n) + + sizes[-1] = m + if pivot: + P = A.new_empty(sizes) + else: + P = A.new_empty([0]) + + sizes[-1] = k + L = A.new_empty(sizes) + + sizes[-2] = k + sizes[-1] = n + U = A.new_empty(sizes) + return P, L, U + + +@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out]) +@out_wrapper("LU", "pivots", "info") +def linalg_lu_factor_ex_meta( + A: Tensor, *, pivot: bool = True, check_errors: bool = False +) -> Tuple[Tensor, Tensor, Tensor]: + torch._check( + A.ndim >= 2, + lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. 
Got size: {A.shape} instead", + ) + + sizes = list(A.shape) + m = sizes[-2] + n = sizes[-1] + + LU = torch.empty_strided( + size=sizes, + stride=make_contiguous_strides_for(sizes, row_major=False), + dtype=A.dtype, + device=A.device, + ) + + # Sets sizes to the size of pivots + sizes.pop() + sizes[-1] = min(m, n) + pivots = A.new_empty(sizes, dtype=torch.int) + + # Sets sizes to the size of info + sizes.pop() + info = A.new_empty(sizes, dtype=torch.int) + + return LU, pivots, info + + +@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out]) +@out_wrapper() +def linalg_lu_solve_meta( + LU: Tensor, + pivots: Tensor, + B: Tensor, + *, + left: bool = True, + adjoint: bool = False, +) -> Tensor: + # dtype + checkFloatingOrComplex(LU, "torch.linalg.lu_solve") + torch._check( + LU.dtype == B.dtype, + lambda: ( + f"linalg.lu_solve: Expected LU and B to have the same dtype, " + f"but found LU of type {LU.dtype} and B of type {B.dtype} instead" + ), + ) + torch._check( + pivots.dtype == torch.int, + lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32", + ) + + # matrix shapes + squareCheckInputs(LU, "torch.linalg.lu_solve") + checkInputsSolver(LU, B, left, "linalg.lu_solve") + torch._check( + LU.size(-1) == pivots.size(-1), + lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix", + ) + + # batches + torch._check( + LU.shape[:-1] == pivots.shape, + lambda: ( + f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, " + f"but got pivots with shape {pivots.shape} instead" + ), + ) + + B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU) + + result = torch.empty_strided( + size=B_broadcast_size, + stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left), + dtype=B.dtype, + device=B.device, + ) + + if result.numel() != 0 and not left: + if result.is_complex(): + result = result.conj() + + return result + + +@register_meta(aten.lu_unpack) +@out_wrapper("P", "L", "U") +def lu_unpack_meta( + LU: Tensor, + pivots: Tensor, + unpack_data: bool = True, + unpack_pivots: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + torch._check( + LU.ndim >= 2, + lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. 
Got size: {LU.shape} instead", + ) + if unpack_pivots: + torch._check( + pivots.dtype == torch.int32, + lambda: ( + "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n" + "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor" + ), + ) + sizes = list(LU.shape) + m = sizes[-2] + n = sizes[-1] + k = min(m, n) + sizes[-1] = m + if unpack_pivots: + P = LU.new_empty(sizes) + else: + P = LU.new_empty([0]) + if unpack_data: + sizes[-1] = k + L = LU.new_empty(sizes) + sizes[-2] = k + sizes[-1] = n + U = LU.new_empty(sizes) + else: + L = LU.new_empty([0]) + U = LU.new_empty([0]) + return P, L, U + + +# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced) +def _parse_qr_mode(mode: str) -> Tuple[bool, bool]: + if mode == "reduced": + compute_q = True + reduced = True + elif mode == "complete": + compute_q = True + reduced = False + elif mode == "r": + compute_q = False + reduced = True # this is actually irrelevant in this mode + else: + torch._check( + False, + lambda: ( + f"qr received unrecognized mode '{mode}' " + f"but expected one of 'reduced' (default), 'r', or 'complete'" + ), + ) + return compute_q, reduced # type: ignore[possibly-undefined] + + +@register_meta([aten.linalg_qr.default, aten.linalg_qr.out]) +@out_wrapper("Q", "R") +def linalg_qr_meta( + A: Tensor, + mode: str = "reduced", +) -> Tuple[Tensor, Tensor]: + checkIsMatrix(A, "linalg.qr") + checkFloatingOrComplex(A, "linalg.qr") + + compute_q, reduced_mode = _parse_qr_mode(mode) + + m = A.shape[-2] + n = A.shape[-1] + k = min(m, n) + + if compute_q: + Q_shape = list(A.shape) + Q_shape[-1] = k if reduced_mode else m + Q = A.new_empty(Q_shape) + Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False)) + else: + Q = A.new_empty([0]) + + # For readability + R_shape = list(A.shape) + R_shape[-2] = k if reduced_mode or not compute_q else m + R = A.new_empty(R_shape) + R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False)) + return Q, R + + +@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign]) +@out_wrapper("sign", "logabsdet", "LU", "pivots") +def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + squareCheckInputs(A, "linalg.slogdet") + checkFloatingOrComplex(A, "linalg.slogdet", False) + shape = A.shape + sign = A.new_empty(shape[:-2]) + logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype)) + LU = torch.empty_strided( + size=shape, + stride=make_contiguous_strides_for(shape, False), + dtype=A.dtype, + device=A.device, + ) + pivots = A.new_empty(shape[:-1], dtype=torch.int32) + return sign, logabsdet, LU, pivots + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml +@register_meta(aten._linalg_svd.default) +def _linalg_svd_meta( + A: Tensor, + full_matrices: bool = False, + compute_uv: bool = True, + driver: Optional[str] = None, +): + checkIsMatrix(A, "linalg.svd") + checkFloatingOrComplex(A, "linalg.svd") + + batch_dims = list(A.shape[:-2]) + m = A.shape[-2] + n = A.shape[-1] + k = min(m, n) + + if compute_uv: + U_shape = batch_dims + [m, m if full_matrices else k] + U = A.new_empty(U_shape) + U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False)) + + V_shape = batch_dims + [n if full_matrices else k, n] + V = A.new_empty(V_shape) + # NB: This checks for CUDA since there is no way to check for cuSolver. 
+ # Also, this might not work correctly on CPU when fake_device is not + # available as device_hint just defaults to CUDA in that case. See + # _linalg_svd meta in core. + is_cuda = device_hint(A) == "cuda" + V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda)) + else: + # doesn't matter + U = A.new_empty([0]) + V = A.new_empty([0]) + + # S is always real, even when A is complex. + S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype)) + return U, S, V + + +def _linalg_broadcast_batch_dims( + arg1: Tensor, arg2: Tensor +) -> Tuple[List[int], List[int]]: + # broadcast the batch dimensions of arg1 and arg2. + arg1_batch_sizes = arg1.shape[:-2] + arg2_batch_sizes = arg2.shape[:-2] + expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes) + + arg1_expand_size = list(expand_batch_portion) + arg1_expand_size += [arg1.size(-2), arg1.size(-1)] + + arg2_expand_size = list(expand_batch_portion) + arg2_expand_size += [arg2.size(-2), arg2.size(-1)] + return arg1_expand_size, arg2_expand_size + + +def _linalg_broadcast_batch_dims_name( + arg1: Tensor, arg2: Tensor, name: Optional[str] +) -> Tuple[Tensor, Tensor]: + # If there's no name we assume we don't want to check the errors + if name: + linearSolveCheckInputs(arg1, arg2, name) + + arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2) + + arg1_broadcasted = ( + arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size) + ) + arg2_broadcasted = ( + arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size) + ) + return arg1_broadcasted, arg2_broadcasted + + +def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool: + expected_batched_rhs_shape = input.shape[:-1] + vector_case = other.ndim == 1 or ( + input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape + ) + return vector_case + + +@register_meta(aten._linalg_solve_ex) +def _linalg_solve_ex( + A: Tensor, + B: Tensor, + *, + left: bool = True, + check_errors: bool = False, + result: Optional[Tensor] = None, + LU: Optional[Tensor] = None, + pivots: Optional[Tensor] = None, + info: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + checkFloatingOrComplex(A, "linalg.solve") + torch._check( + A.dtype == B.dtype, + lambda: ( + f"linalg.solve: Expected A and B to have the same dtype, but found A of type " + f"{A.dtype} and B of type {B.dtype} instead" + ), + ) + vector_case = linalg_solve_is_vector_rhs(A, B) + B_ = B.unsqueeze(-1) if vector_case else B + checkInputsSolver(A, B_, left, "linalg.solve") + B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A) + torch._check( + left or not vector_case, + lambda: ( + "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. 
" + "In this case linalg.solve is equivalent to B / A.squeeze(-1)" + ), + ) + result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape + result_ = torch.empty_strided( + size=result_shape, + stride=make_contiguous_strides_for(result_shape, not left), + dtype=B.dtype, + device=B.device, + ) + shape = A.shape + ndim = A.ndim + LU_ = torch.empty_strided( + size=shape, + stride=make_contiguous_strides_for(shape, False), + dtype=A.dtype, + device=A.device, + ) + pivots_ = A.new_empty(shape[:-1], dtype=torch.int32) + info_ = A.new_empty(shape[:-2], dtype=torch.int32) + out = (result, LU, pivots, info) + res = (result_, LU_, pivots_, info_) + if all(x is not None for x in out): + for r, o in zip(res, out): + # resize and copy operations are done in-place + _maybe_resize_out(o, r.shape) # type: ignore[arg-type] + # strides are not copied in out_wrapper + o.as_strided_(r.shape, r.stride()) # type: ignore[union-attr] + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False) # type: ignore[arg-type] + return res + + +@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out]) +def linalg_solve_triangular_meta( + A: Tensor, + B: Tensor, + *, + upper: bool, + left: bool = True, + unitriangular: bool = False, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + out = A.new_empty([0]) + assert isinstance(out, TensorLike) + checkInputsSolver(A, B, left, "linalg.solve_triangular") + B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None) + avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj() + if avoid_copy_A: + out = _maybe_resize_out(out, B_.shape) + else: + # reimplementation of resize_output with result F-contig + if _resize_output_check(out, B_.shape): + out.resize_(B_.transpose(-2, -1).shape) + out.transpose_(-2, -1) + return out # type: ignore[return-value] + + +@register_meta(aten.triangular_solve) +@out_wrapper("solution", "cloned_coefficient") +def triangular_solve_meta( + self: Tensor, + A: Tensor, + upper: bool = True, + transpose: bool = False, + unitriangular: bool = False, +) -> Tuple[Tensor, Tensor]: + torch._check( + self.ndim >= 2, + lambda: ( + f"torch.triangular_solve: Expected b to have at least 2 dimensions, " + f"but it has {self.ndim} dimensions instead" + ), + ) + torch._check( + A.ndim >= 2, + lambda: ( + f"torch.triangular_solve: Expected A to have at least 2 dimensions, " + f"but it has {A.ndim} dimensions instead" + ), + ) + + linearSolveCheckInputs(self, A, "triangular_solve") + + if A.layout == torch.strided: + self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A) + solution = torch.empty_strided( + size=self_broadcast_size, + stride=make_contiguous_strides_for(self_broadcast_size, row_major=False), + dtype=self.dtype, + device=self.device, + ) + cloned_coefficient = torch.empty_strided( + size=A_broadcast_size, + stride=make_contiguous_strides_for(A_broadcast_size, row_major=False), + dtype=A.dtype, + device=A.device, + ) + elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr: + solution = torch.empty_like(self) + cloned_coefficient = self.new_empty([0]) + else: + torch._check(False, lambda: "triangular_solve: Got an unexpected layout.") + return solution, cloned_coefficient # type: ignore[possibly-undefined] + + +# From aten/src/ATen/native/LinearAlgebra.cpp +@register_meta(aten._linalg_det.default) +def _linalg_det_meta(A): + squareCheckInputs(A, "linalg.det") + checkFloatingOrComplex(A, "linalg.det") + + det = A.new_empty(A.shape[:-2]) + + LU = A.new_empty(A.shape) + 
LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False)) + + pivots = A.new_empty(A.shape[:-1], dtype=torch.int32) + return det, LU, pivots + + +@register_meta(aten.ormqr) +@out_wrapper() +def ormqr( + input: Tensor, + tau: Tensor, + other: Tensor, + left: bool = True, + transpose: bool = False, +) -> Tensor: + torch._check( + input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions." + ) + torch._check( + other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions." + ) + + left_size_condition = -2 if left else -1 + torch._check( + other.shape[left_size_condition] >= tau.shape[-1], + lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]", + ) + torch._check( + other.shape[left_size_condition] == input.shape[-2], + lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]", + ) + + torch._check( + tau.shape[-1] <= input.shape[-1], + lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]", + ) + + torch._check( + input.ndim - tau.ndim == 1, + lambda: ( + f"torch.ormqr: Expected tau to have one dimension less than input, " + f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + torch._check( + input.ndim == other.ndim, + lambda: ( + f"torch.ormqr: Expected other to have the same number of dimensions as input, " + f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}" + ), + ) + + if input.ndim > 2: + expected_batch_shape = input.shape[:-2] + actual_batch_tau_shape = tau.shape[:-1] + torch._check( + actual_batch_tau_shape == expected_batch_shape, + lambda: ( + f"torch.ormqr: Expected batch dimensions of tau to be " + f"equal to input.shape[:-2], but got {actual_batch_tau_shape}" + ), + ) + + actual_batch_other_shape = other.shape[:-2] + torch._check( + actual_batch_other_shape == expected_batch_shape, + lambda: ( + f"torch.ormqr: Expected batch dimensions of other to be " + f"equal to input.shape[:-2], but got {actual_batch_other_shape}" + ), + ) + + torch._check( + tau.dtype == input.dtype, + lambda: ( + f"torch.ormqr: Expected input and tau to have the same dtype, " + f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}" + ), + ) + torch._check( + other.dtype == input.dtype, + lambda: ( + f"torch.ormqr: Expected input and other to have the same dtype, " + f"but input has dtype {input.dtype} and other has dtype {other.dtype}" + ), + ) + + checkSameDevice("torch.ormqr", tau, input, "tau") + checkSameDevice("torch.ormqr", other, input, "other") + + return torch.empty_strided( + size=other.shape, + stride=make_contiguous_strides_for(other.shape, row_major=False), + dtype=other.dtype, + device=other.device, + ) + + +def _padding_check_valid_input(input, padding, *, dim): + torch._check( + len(padding) == 2 * dim, + lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}", + ) + + input_dim = input.ndim + + is_batch_mode = input_dim == (dim + 2) + + valid_batch_mode = is_batch_mode + valid_non_batch_mode = not is_batch_mode + + if is_batch_mode: + # allow batch size of 0-dim. + for d in range(1, input_dim): + valid_batch_mode = valid_batch_mode and input.size(d) != 0 + else: + for d in range(0, input_dim): + valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0 + + # allow empty batch size but not other dimensions. 
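The padding helpers that follow compute each output size as the input size plus the two padding amounts, with reflection padding additionally required to be strictly smaller than the padded dimension. An illustrative eager-mode check (not part of the patch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 5)                  # (batch, planes, width)
y = F.pad(x, (2, 3), mode="reflect")      # routed to reflection_pad1d
print(y.shape)                            # torch.Size([1, 2, 10]) = 5 + 2 + 3

# Reflection padding must be smaller than the padded dimension
try:
    F.pad(x, (5, 0), mode="reflect")
except RuntimeError as err:
    print(err)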
+ torch._check( + valid_batch_mode or valid_non_batch_mode, + lambda: ( + f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size " + f"and other non-zero dimensions for input, but got: {input.shape}" + ), + ) + + +def _pad1d_common(input, padding, *, is_reflection): + dim_plane = 0 + dim_w = 1 + nbatch = 1 + + if input.ndim == 3: + nbatch = input.size(0) + dim_w += 1 + dim_plane += 1 + + _padding_check_valid_input(input, padding, dim=1) + + pad_l, pad_r = padding + + nplane = input.size(dim_plane) + input_w = input.size(dim_w) + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1, + lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}", + ) + + if input.ndim == 2: + return input.new_empty((nplane, output_w)) + else: + return input.new_empty((nbatch, nplane, output_w)) + + +@register_meta(aten.reflection_pad1d) +@out_wrapper() +def meta_reflection_pad1d(input, padding): + return _pad1d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad1d) +@out_wrapper() +def meta_replication_pad1d(input, padding): + return _pad1d_common(input, padding, is_reflection=False) + + +def _pad1d_backward_common(grad_output, input, padding, *, is_reflection): + dim_w = 1 + if not is_reflection: + torch._check(len(padding) == 2, lambda: "padding size is expected to be 2") + + if input.ndim == 3: + dim_w += 1 + + pad_l, pad_r = padding + + input_w = input.size(dim_w) + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. 
Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + + return input.new_empty(input.shape) + + +@register_meta(aten.reflection_pad1d_backward) +@out_wrapper("grad_input") +def meta_reflection_pad1d_backward(grad_output, input, padding): + return _pad1d_backward_common(grad_output, input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad1d_backward) +@out_wrapper("grad_input") +def meta_replication_pad1d_backward(grad_output, input, padding): + return _pad1d_backward_common(grad_output, input, padding, is_reflection=False) + + +def _pad2d_common(input, padding, *, is_reflection): + dim_w = 2 + dim_h = 1 + dim_slices = 0 + nbatch = 1 + + _padding_check_valid_input(input, padding, dim=2) + + ndim = input.ndim + if ndim == 4: + nbatch = input.size(0) + dim_w += 1 + dim_h += 1 + dim_slices += 1 + + pad_l, pad_r, pad_t, pad_b = padding + + nplane = input.size(dim_slices) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + torch._check( + pad_t < input_h and pad_b < input_h, + lambda: ( + f"Argument #6: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1 or output_h >= 1, + lambda: ( + f"input (H: {input_h} W: {input_w}) is too small. " + f"Calculated output H: {output_h} W: {output_w}" + ), + ) + + if input.ndim == 3: + return input.new_empty((nplane, output_h, output_w)) + else: + return input.new_empty((nbatch, nplane, output_h, output_w)) + + +@register_meta(aten.reflection_pad2d) +@out_wrapper() +def meta_reflection_pad2d(input, padding): + return _pad2d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad2d) +@out_wrapper() +def meta_replication_pad2d(input, padding): + return _pad2d_common(input, padding, is_reflection=False) + + +@register_meta( + [ + aten.reflection_pad2d_backward.default, + aten.reflection_pad2d_backward.grad_input, + aten.replication_pad2d_backward.default, + aten.replication_pad2d_backward.grad_input, + ] +) +@out_wrapper("grad_input") +def meta_pad2d_backward(grad_output, self, padding): + dim_w = 2 + dim_h = 1 + dim_plane = 0 + nbatch = 1 + + self_shape = self.shape + if self.dim() == 4: + nbatch = self_shape[0] + dim_w += 1 + dim_h += 1 + dim_plane += 1 + + pad_l, pad_r, pad_t, pad_b = padding + + nplane = self_shape[dim_plane] + input_h = self_shape[dim_h] + input_w = self_shape[dim_w] + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + torch._check( + output_h == grad_output.size(dim_h), + lambda: f"grad_output height unexpected. 
Expected: {output_h}, Got: {grad_output.size(dim_h)}", + ) + return self.new_empty(self.shape) + + +def _pad3d_common(input, padding, *, is_reflection): + dim_w = 3 + dim_h = 2 + dim_d = 1 + dim_plane = 0 + + _padding_check_valid_input(input, padding, dim=3) + + batch_mode = input.ndim == 5 + if batch_mode: + nbatch = input.size(0) + dim_w += 1 + dim_h += 1 + dim_d += 1 + dim_plane += 1 + + pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding + + nplane = input.size(dim_plane) + input_d = input.size(dim_d) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_d = input_d + pad_f + pad_bk + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + if is_reflection: + torch._check( + pad_l < input_w and pad_r < input_w, + lambda: ( + f"Argument #4: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" + ), + ) + torch._check( + pad_t < input_h and pad_b < input_h, + lambda: ( + f"Argument #6: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}" + ), + ) + torch._check( + pad_f < input_d and pad_bk < input_d, + lambda: ( + f"Argument #8: Padding size should be less than the corresponding input dimension, " + f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}" + ), + ) + + torch._check( + output_w >= 1 or output_h >= 1 or output_d >= 1, + lambda: ( + f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. " + f"Calculated output D: {output_d} H: {output_h} W: {output_w}" + ), + ) + + if batch_mode: + return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined] + else: + return input.new_empty((nplane, output_d, output_h, output_w)) + + +@register_meta(aten.reflection_pad3d) +@out_wrapper() +def meta_reflection_pad3d(input, padding): + return _pad3d_common(input, padding, is_reflection=True) + + +@register_meta(aten.replication_pad3d) +@out_wrapper() +def meta_replication_pad3d(input, padding): + return _pad3d_common(input, padding, is_reflection=False) + + +@register_meta( + [ + aten.reflection_pad3d_backward.default, + aten.reflection_pad3d_backward.grad_input, + aten.replication_pad3d_backward.default, + aten.replication_pad3d_backward.grad_input, + ] +) +@out_wrapper("grad_input") +def meta_pad3d_backward(grad_output, input, padding): + torch._check(len(padding) == 6, lambda: "padding size is expected to be 6") + assert input.ndim > 3 + assert grad_output.ndim == input.ndim + + dim_w = 3 + dim_h = 2 + dim_d = 1 + + if input.ndim == 5: + dim_w += 1 + dim_h += 1 + dim_d += 1 + + pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding + + input_d = input.size(dim_d) + input_h = input.size(dim_h) + input_w = input.size(dim_w) + output_d = input_d + pad_f + pad_bk + output_h = input_h + pad_t + pad_b + output_w = input_w + pad_l + pad_r + + torch._check( + output_w == grad_output.size(dim_w), + lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}", + ) + torch._check( + output_h == grad_output.size(dim_h), + lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}", + ) + torch._check( + output_d == grad_output.size(dim_d), + lambda: f"grad_output depth unexpected. 
Expected: {output_d}, Got: {grad_output.size(dim_d)}", + ) + + return input.new_empty(input.shape) + + +@register_meta(aten._pdist_forward) +@out_wrapper() +def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor: + torch._check( + self.is_contiguous(), lambda: "_pdist_forward requires contiguous input" + ) + n = self.size(0) + if n <= 1: + return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format) # type: ignore[call-overload] + else: + return self.new_empty((n * (n - 1) // 2,)).to( + memory_format=torch.legacy_contiguous_format + ) # type: ignore[call-overload] + + +@register_meta(aten._pdist_backward) +@out_wrapper() +def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor: + torch._check( + self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous" + ) + torch._check( + pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous" + ) + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + + +@register_meta([aten.baddbmm.default, aten.baddbmm.out]) +@out_wrapper() +def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): + dim1 = batch1.size(0) + dim2 = batch1.size(1) + dim3 = batch2.size(2) + self = self.expand((dim1, dim2, dim3)) + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( + self.dtype == batch1.dtype == batch2.dtype, + lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", + ) + batch1_sizes = batch1.shape + batch2_sizes = batch2.shape + bs = batch1_sizes[0] + contraction_size = batch1_sizes[2] + torch._check( + batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, + lambda: ( + f"Expected size for first two dimensions of batch2 tensor to be: " + f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]." 
+ ), + ) + return self.new_empty(self.size()) + + +@register_meta([aten.bernoulli.default, aten.bernoulli.out]) +@out_wrapper() +def meta_bernoulli(self, *, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self).contiguous() + + +@register_meta(aten.bernoulli_.float) +def meta_bernoulli_(self, p=0.5, generator=None): + return self + + +@register_meta(aten.bernoulli.p) +def meta_bernoulli_p(self, p=0.5, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self).contiguous() + + +@register_meta(aten._fused_moving_avg_obs_fq_helper.default) +def meta__fused_moving_avg_obs_fq_helper( + self, + observer_on, + fake_quant_on, + running_min, + running_max, + scale, + zero_point, + averaging_const, + quant_min, + quant_max, + ch_axis, + per_row_fake_quant=False, + symmetric_quant=False, +): + torch._check( + ch_axis < self.dim(), + lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()", + ) + mask = torch.empty_like(self, dtype=torch.bool) + return (torch.empty_like(self), mask) + + +@register_meta(aten.mm) +@out_wrapper() +def meta_mm(a, b): + torch._check(a.dim() == 2, lambda: "a must be 2D") + torch._check(b.dim() == 2, lambda: "b must be 2D") + N, M1 = a.shape + M2, P = b.shape + torch._check( + M1 == M2, + lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].", + ) + return a.new_empty(N, P) + + +def _compute_reduction_shape(self, dims, keepdim): + if keepdim: + return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim)) + + return utils.compute_reduction_output_shape(self.shape, dims) + + +# FakeTensors (meta tensors with a device) will report device as meta +# when running meta kernels. Here, access the "fake device" of FakeTensor if it +# exists so meta kernels which have diverge per device will be more +# accurate when run with FakeTensors +def device_hint(tensor) -> "str": + if isinstance(tensor, torch._subclasses.FakeTensor): + return tensor.fake_device.type + else: + return "cuda" # default to cuda + + +def calc_conv_nd_return_shape( + input_tensor: torch.Tensor, + weight: torch.Tensor, + stride: Union[List[int], int], + padding: Union[List[int], int], + dilation: Union[List[int], int], + is_transposed: bool, + groups: int, + output_padding: Optional[Union[List[int], int]] = None, +): + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + Returns: + The output length + """ + return (ln + 2 * p - d * (k - 1) - 1) // s + 1 + + def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: + """ + Formula to apply to calculate the length of some dimension of the output + if transposed convolution is used. 
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + + Args: + ln: length of the dimension + p: padding in that dim + d: dilation in that dim + k: kernel size in that dim + s: stride in that dim + op: output padding in that dim + + Returns: + The output length + """ + return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 + + kernel_size = weight.shape[2:] + dims = input_tensor.shape[2:] + if is_transposed: + out_channels = groups * weight.shape[1] + else: + out_channels = weight.shape[0] + if weight.shape[1] * groups != input_tensor.shape[1]: + raise RuntimeError("Invalid channel dimensions") + + ret_shape = [input_tensor.shape[0], out_channels] + if isinstance(stride, IntLike): + stride = [stride] * len(dims) + elif len(stride) == 1: + stride = [stride[0]] * len(dims) + + if isinstance(padding, IntLike): + padding = [padding] * len(dims) + elif len(padding) == 1: + padding = [padding[0]] * len(dims) + + if isinstance(dilation, IntLike): + dilation = [dilation] * len(dims) + elif len(dilation) == 1: + dilation = [dilation[0]] * len(dims) + + output_padding_list: Optional[List[int]] = None + if output_padding: + if isinstance(output_padding, IntLike): + output_padding_list = [output_padding] * len(dims) + elif len(output_padding) == 1: + output_padding_list = [output_padding[0]] * len(dims) + else: + output_padding_list = output_padding + + for i in range(len(dims)): + # If output_padding is present, we are dealing with a transposed convolution + if output_padding_list: + ret_shape.append( + _formula_transposed( + dims[i], + padding[i], + dilation[i], + kernel_size[i], + stride[i], + output_padding_list[i], + ) + ) + else: + ret_shape.append( + _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) + ) + + return ret_shape + + +def is_channels_last(ten): + return torch._prims_common.suggest_memory_format(ten) == torch.channels_last + + +@register_meta(aten.convolution.default) +def meta_conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, +): + def pick_memory_format(): + if device_hint(input_tensor) == "cuda": + if is_channels_last(input_tensor) or is_channels_last(weight): + return torch.channels_last + else: + if is_channels_last(input_tensor): + return torch.channels_last + if input_tensor.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif input_tensor.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + shape_out = calc_conv_nd_return_shape( + input_tensor, + weight, + stride, + padding, + dilation, + is_transposed, + groups, + output_padding if is_transposed else None, + ) + + input_channels_dim = 1 + output_channels_dim = 1 + if input_tensor.size(input_channels_dim) == 0: + shape_out[output_channels_dim] = 0 + + out = input_tensor.new_empty(shape_out) + out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] + return out + + +if torch._C._has_mkldnn: + _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library( + "mkldnn", "IMPL", "Meta" + ) + + @register_meta(torch.ops.mkldnn._convolution_pointwise.default) + def meta_mkldnn_convolution_default( + input_tensor, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + shape_out = calc_conv_nd_return_shape( + input_tensor, weight, stride, padding, dilation, False, groups, [] + ) + 
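+        # Illustrative check of the non-transposed branch of _formula above (example numbers,
+        # not from the source): a spatial extent of 32 with padding 1, dilation 1, kernel 3 and
+        # stride 2 gives (32 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 16 for that entry of shape_out.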
out = input_tensor.new_empty(shape_out) + out_memory_format = torch.channels_last + out = out.to(memory_format=out_memory_format) # type: ignore[call-overload] + return out + + @register_meta(torch.ops.mkldnn._linear_pointwise.default) + def meta_linear_pointwise_default( + input_tensor, weight, bias, attr, scalars, algorithm + ): + return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0])) + + if torch._C.has_mkl: + _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library( + "mkl", "IMPL", "Meta" + ) + + @register_meta(torch.ops.mkl._mkl_linear) + def meta_mkl_linear( + input_tensor, + packed_weight, + orig_weight, + bias, + batch_size, + ): + return input_tensor.new_empty( + (*input_tensor.shape[:-1], orig_weight.shape[0]) + ) + + _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library( + "onednn", "IMPL", "Meta" + ) + + @register_meta(torch.ops.onednn.qconv2d_pointwise.default) + def meta_qconv2d_pointwise( + x, + x_scale, + x_zp, + w, # prepacked_weight + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + shape_out = calc_conv_nd_return_shape( + x, + w, + stride, + padding, + dilation, + False, + groups, + None, + ) + assert output_dtype in [torch.float32, torch.bfloat16] + out = x.new_empty(shape_out, dtype=output_dtype) + out = out.to(memory_format=torch.channels_last) + return out + + @register_meta(torch.ops.onednn.qlinear_pointwise.default) + @register_meta(torch.ops.onednn.qlinear_pointwise.tensor) + def meta_qlinear_pointwise( + x, + x_scale, + x_zp, + w, + w_scale, + w_zp, + bias, + output_scale, + output_zero_point, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + output_shape = list(x.shape) + # The weight has been transposed during the qlinear weight prepack process. + output_shape[-1] = w.shape[1] + assert output_dtype in [torch.float32, torch.bfloat16] + out = x.new_empty(output_shape, dtype=output_dtype) + return out + + _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library( + "quantized", "IMPL", "Meta" + ) + + @register_meta(torch.ops.quantized.max_pool2d) + def meta_quantized_max_pool2d( + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, + ): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = torch.channels_last + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ) + + +# from check_dim_size() in aten/src/ATen/TensorUtils.cpp. 
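+# Illustrative example with hypothetical values: for t of shape (2, 3, 4), check_dim_size(t, 3, 1, 3)
+# passes (rank 3, size 3 at dim 1), while check_dim_size(t, 3, 1, 5) trips the torch._check below.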
+def check_dim_size(tensor, dim, dim_size, size): + torch._check( + tensor.dim() == dim and tensor.shape[dim_size] == size, + lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, " + + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}", + ) + + +@register_meta(aten.avg_pool2d.default) +def meta_avg_pool2d( + input, + kernel_size, + stride=(), + padding=(0,), + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + def unpack(name, val): + torch._check( + len(val) in [1, 2], + lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints", + ) + H = val[0] + W = H if len(val) == 1 else val[1] + return H, W + + kH, kW = unpack("kernel_size", kernel_size) + torch._check( + len(stride) in [0, 1, 2], + lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + if len(stride) == 0: + dH, dW = kH, kW + elif len(stride) == 1: + dH, dW = stride[0], stride[0] + else: + dH, dW = unpack("stride", stride) + + padH, padW = unpack("padding", padding) + + torch._check( + divisor_override is None or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) + + memory_format = utils.suggest_memory_format(input) + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + 1, + 1, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, + ) + + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ) + + +# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h. +def avg_pool2d_backward_shape_check( + input, + gradOutput, + nbatch, + kH, + kW, + dH, + dW, + padH, + padW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, +): + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + 1, + 1, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, + ) + + ndim = input.dim() + nOutputPlane = nInputPlane + + check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane) + check_dim_size(gradOutput, ndim, ndim - 2, outputHeight) + check_dim_size(gradOutput, ndim, ndim - 1, outputWidth) + + +# Don't override the C++ registration. +@register_meta(aten.avg_pool2d_backward.default) +def meta_avg_pool2d_backward( + gradOutput_, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func. 
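+    # The unpacking below appears to mirror the eager defaults: kernel_size=[3] expands to
+    # kH = kW = 3, an empty stride falls back to the kernel size (dH, dW = kH, kW), and
+    # padding=[1] expands to padH = padW = 1.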
+ torch._check( + len(kernel_size) == 1 or len(kernel_size) == 2, + lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints", + ) + kH = kernel_size[0] + kW = kH if len(kernel_size) == 1 else kernel_size[1] + torch._check( + len(stride) == 0 or len(stride) == 1 or len(stride) == 2, + lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + dH = kH if len(stride) == 0 else stride[0] + dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1] + torch._check( + len(padding) == 1 or len(padding) == 2, + lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints", + ) + padH = padding[0] + padW = padH if len(padding) == 1 else padding[1] + + torch._check( + divisor_override is None or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + input_size = input.shape + nbatch = input_size[-4] if input.dim() == 4 else 1 + nInputPlane = input_size[-3] + inputHeight = input_size[-2] + inputWidth = input_size[-1] + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) + + mem_format = utils.suggest_memory_format(input) + + avg_pool2d_backward_shape_check( + input, + gradOutput_, + nbatch, + kH, + kW, + dH, + dW, + padH, + padW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + mem_format, + ) + + return torch.empty( + input_size, + dtype=input.dtype, + device=input.device, + memory_format=mem_format, + ) + + +@register_meta(aten.avg_pool3d) +@out_wrapper() +def meta_avg_pool3d( + input, + kernel_size, + stride=(), + padding=(0,), + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", + ) + padT = padding[0] + padH = padT if len(padding) == 1 else padding[1] + padW = padT if len(padding) == 1 else padding[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + torch._check( + not divisor_override or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nbatch = input.size(0) + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) + oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) + owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + padT, + padH, + padW, + 1, + 1, + 1, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "avg_pool3d()", + check_input_size=True, + ) + + if input.ndim == 4: + return input.new_empty((nslices, otime, oheight, owidth)) + else: + return input.new_empty((nbatch, nslices, otime, oheight, owidth)) 
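+
+# Quick sanity check of pooling_output_shape as used by the avg_pool3d metas above and below
+# (illustrative numbers): itime=7, kT=3, padT=1, dT=2, dilation 1, ceil_mode=False gives
+# (7 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 4 output frames, the standard floor-mode pooling result.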
+ + +@register_meta(aten.avg_pool3d_backward) +@out_wrapper("grad_input") +def meta_avg_pool3d_backward( + grad_output, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", + ) + padT = padding[0] + padH = padT if len(padding) == 1 else padding[1] + padW = padT if len(padding) == 1 else padding[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + torch._check( + not divisor_override or divisor_override != 0, + lambda: "divisor must be not zero", + ) + + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) + oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) + owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) + + avg_pool3d_backward_shape_check( + input, + grad_output, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + padT, + padH, + padW, + itime, + iheight, + iwidth, + otime_for_shape_check, + oheight_for_shape_check, + owidth_for_shape_check, + "avg_pool3d_backward()", + ) + + return input.new_empty(input.shape) + + +@register_meta(aten._adaptive_avg_pool2d.default) +def meta_adaptive_avg_pool2d(self, output_size): + torch._check( + self.ndim == 3 or self.ndim == 4, + lambda: f"Expected 3D or 4D tensor, but got {self.shape}", + ) + output_shape = self.shape[:-2] + tuple(output_size) + memory_format = utils.suggest_memory_format(self) + # need to set memory_format to preserve the memory format of the input + # channel last input should have channel last output + return torch.empty( + output_shape, + dtype=self.dtype, + device=self.device, + memory_format=memory_format, + ) + + +@register_meta(aten._adaptive_avg_pool3d.default) +def meta_adaptive_avg_pool3d(self, output_size): + torch._check( + self.ndim == 4 or self.ndim == 5, + lambda: f"Expected 4D or 5D tensor, but got {self.shape}", + ) + return self.new_empty(self.shape[:-3] + tuple(output_size)) + + +@register_meta(aten._adaptive_avg_pool2d_backward.default) +def meta__adaptive_avg_pool2d_backward(grad_out, self): + ndim = grad_out.ndim + for i in range(1, ndim): + torch._check( + grad_out.size(i) > 0, + lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ + size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", + ) + torch._check( + ndim == 3 or ndim == 4, + lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", + ) + torch._check( + self.dtype == grad_out.dtype, + lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", + ) + memory_format = torch.contiguous_format + if is_channels_last(self): + 
memory_format = torch.channels_last + return self.new_empty(self.shape).to(memory_format=memory_format) + + +@register_meta(aten._adaptive_avg_pool3d_backward) +@out_wrapper("grad_input") +def meta__adaptive_avg_pool3d_backward(grad_output, self): + _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward") + return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) + + +def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str): + ndim = grad_output.ndim + for i in range(1, ndim): + torch._check( + grad_output.size(i) > 0, + lambda: ( + f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, " + f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty" + ), + ) + + +@register_meta(aten.adaptive_max_pool2d) +@out_wrapper("out", "indices") +def meta_adaptive_max_pool2d(input, output_size): + ndim = input.ndim + torch._check( + ndim in (3, 4), + lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}", + ) + for i in range(1, ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + torch._check( + len(output_size) == 2, + lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2", + ) + + dimH = 1 + sizeB = 1 + sizeD = 0 + + if input.ndim == 4: + sizeB = input.size(0) + dimH += 1 + + sizeD = input.size(dimH - 1) + osizeH, osizeW = output_size + + if input.ndim == 3: + out_shape = (sizeD, osizeH, osizeW) + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + return out, indices + else: + out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment] + memory_format = utils.suggest_memory_format(input) + out = input.new_empty(out_shape).to(memory_format=memory_format) + indices = input.new_empty(out_shape, dtype=torch.int64).to( + memory_format=memory_format + ) + return out, indices + + +@register_meta(aten.adaptive_max_pool2d_backward) +@out_wrapper("grad_input") +def meta_adaptive_max_pool2d_backward(grad_output, input, indices): + ndim = grad_output.ndim + torch._check( + ndim in (3, 4), + lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}", + ) + + _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward") + + torch._check( + input.dtype == grad_output.dtype, + lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}", + ) + + memory_format = utils.suggest_memory_format(input) + return input.new_empty(input.shape).to(memory_format=memory_format) + + +@register_meta(aten.adaptive_max_pool3d) +@out_wrapper("out", "indices") +def meta_adaptive_max_pool3d(input, output_size): + ndim = input.ndim + torch._check( + ndim in (4, 5), + lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}", + ) + for i in range(1, ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + torch._check( + len(output_size) == 3, + lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3", + ) + + dimD = 0 + sizeB = 1 + sizeD = 0 + + if ndim == 5: + sizeB = input.size(0) + dimD += 1 + + sizeD = input.size(dimD) + osizeT, osizeH, 
osizeW = output_size + + if ndim == 4: + out_shape = (sizeD, osizeT, osizeH, osizeW) + else: + out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment] + + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + + return out, indices + + +@register_meta(aten.adaptive_max_pool3d_backward) +@out_wrapper("grad_input") +def meta_adaptive_max_pool3d_backward(grad_output, input, indices): + _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward") + return input.new_empty(input.shape) + + +@register_meta(aten.repeat_interleave.Tensor) +def meta_repeat_interleave_Tensor(repeats, output_size=None): + if output_size is None: + raise RuntimeError("cannot repeat_interleave a meta tensor without output_size") + return repeats.new_empty(output_size) + + +@register_meta([aten.complex.default, aten.complex.out]) +@out_wrapper() +def meta_complex(real, imag): + assert real.dtype.is_floating_point + assert imag.dtype.is_floating_point + out_shape = _broadcast_shapes(real.shape, imag.shape) + return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) + + +@register_meta([aten.nonzero_static.default, aten.nonzero_static.out]) +@out_wrapper() +def nonzero_static(self, *, size: int, fill_value: int = -1): + return self.new_empty((size, self.dim()), dtype=torch.long) + + +@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor]) +def meta_index_Tensor(self, indices): + torch._check(bool(indices), lambda: "at least one index must be provided") + # aten::index is the internal advanced indexing implementation + # checkIndexTensorTypes and expandTensors + result: List[Optional[Tensor]] = [] + for i, index in enumerate(indices): + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int, torch.int8, torch.bool], + lambda: "tensors used as indices must be long, int, byte or bool tensors", + ) + if index.dtype in [torch.int8, torch.bool]: + nonzero = index.nonzero() + k = len(result) + torch._check_index( + k + index.ndim <= self.ndim, + lambda: f"too many indices for tensor of dimension {self.ndim}", + ) + for j in range(index.ndim): + torch._check_index( + index.shape[j] == self.shape[k + j], + lambda: f"The shape of the mask {index.shape} at index {i} " + f"does not match the shape of the indexed tensor {self.shape} at index {k + j}", + ) + result.append(nonzero.select(1, j)) + else: + result.append(index) + else: + result.append(index) + indices = result + torch._check( + len(indices) <= self.ndim, + lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})", + ) + # expand_outplace + import torch._refs as refs # avoid import cycle in mypy + + indices = list(refs._maybe_broadcast(*indices)) + # add missing null tensors + while len(indices) < self.ndim: + indices.append(None) + + # hasContiguousSubspace + # true if all non-null tensors are adjacent + # See: + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency + state = 0 + has_contiguous_subspace = False + for index in indices: + if state == 0: + if index is not None: + state = 1 + elif state == 1: + if index is None: + state = 2 + else: + if index is not None: + break + else: + has_contiguous_subspace = True + + # transposeToFront + # This is the logic that causes the newly inserted dimensions to show up + # at the beginning of the tensor, if they're not 
contiguous + if not has_contiguous_subspace: + dims = [] + transposed_indices = [] + for i, index in enumerate(indices): + if index is not None: + dims.append(i) + transposed_indices.append(index) + for i, index in enumerate(indices): + if index is None: + dims.append(i) + transposed_indices.append(index) + self = self.permute(dims) + indices = transposed_indices + + # AdvancedIndex::AdvancedIndex + # Now we can assume the indices have contiguous subspace + # This is simplified from AdvancedIndex which goes to more effort + # to put the input and indices in a form so that TensorIterator can + # take them. If we write a ref for this, probably that logic should + # get implemented + before_shape: List[int] = [] + after_shape: List[int] = [] + replacement_shape: List[int] = [] + for dim, index in enumerate(indices): + if index is None: + if replacement_shape: + after_shape.append(self.shape[dim]) + else: + before_shape.append(self.shape[dim]) + else: + replacement_shape = list(index.shape) + return self.new_empty(before_shape + replacement_shape + after_shape) + + +@register_meta([aten.convolution_backward.default]) +def meta_convolution_backward( + grad_output_, + input_, + weight_, + bias_sizes_opt, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + # High level logic taken from slow_conv3d_backward_cpu which should + # be representative of all convolution_backward impls + backend_grad_input = None + backend_grad_weight = None + backend_grad_bias = None + + if output_mask[0]: + backend_grad_input = grad_output_.new_empty(input_.size()) + if output_mask[1]: + backend_grad_weight = grad_output_.new_empty(weight_.size()) + if output_mask[2]: + backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) + + return (backend_grad_input, backend_grad_weight, backend_grad_bias) + + +@register_meta([aten.addbmm.default, aten.addbmm.out]) +@out_wrapper() +def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): + dim1 = batch1.size(1) + dim2 = batch2.size(2) + self = self.expand((dim1, dim2)) + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( + batch1.size(0) == batch2.size(0), + lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}", + ) + torch._check( + batch1.size(2) == batch2.size(1), + lambda: ( + f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} " + f"and {batch2.size(1)}x{batch2.size(2)})" + ), + ) + torch._check( + self.size(0) == dim1 and self.size(1) == dim2, + lambda: "self tensor does not match matmul output shape", + ) + return self.new_empty(self.size()) + + +def register_meta_foreach(ops): + def wrapper(fn): + def register(op): + op_name = str(op).split(".")[1] + scalar_op = getattr(aten, op_name.replace("_foreach_", "")) + + _add_op_to_registry( + meta_table, + op, + partial( + fn, + _scalar_op=scalar_op, + ), + ) + + pytree.tree_map_(register, ops) + return fn + + return wrapper + + +@register_meta_foreach( + [ + aten._foreach_abs, + aten._foreach_acos, + aten._foreach_asin, + aten._foreach_atan, + aten._foreach_ceil, + aten._foreach_cos, + aten._foreach_cosh, + aten._foreach_erf, + aten._foreach_erfc, + aten._foreach_exp, + aten._foreach_expm1, + aten._foreach_frac, + aten._foreach_floor, + aten._foreach_lgamma, + aten._foreach_log, + aten._foreach_log10, + aten._foreach_log1p, + aten._foreach_log2, + aten._foreach_neg, + aten._foreach_norm, + 
aten._foreach_reciprocal, + aten._foreach_round, + aten._foreach_sigmoid, + aten._foreach_sign, + aten._foreach_sin, + aten._foreach_sinh, + aten._foreach_sqrt, + aten._foreach_tan, + aten._foreach_tanh, + aten._foreach_trunc, + aten._foreach_zero, + aten._foreach_add, + aten._foreach_sub, + aten._foreach_mul, + aten._foreach_div, + aten._foreach_clamp_min, + aten._foreach_clamp_max, + aten._foreach_lerp, + ], +) +def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs): + torch._check( + isinstance(args[0], list), + lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."), + ) + + nelem = len(args[0]) + torch._check( + nelem > 0, + lambda: ("Tensor list must have at least one tensor."), + ) + + nlists = 1 + for iarg, arg in enumerate(args[1:]): + if isinstance(arg, list): + nlists += 1 + torch._check( + len(arg) == nelem, + lambda: ( + f"self and argument-{iarg+2} must match in length, " + f"but got {nelem} and {len(arg)}." + ), + ) + elif isinstance(arg, Tensor): + torch._check( + arg.dim() == 0 and arg.numel() == 1, + lambda: ( + "scalar tensor expected to be 0 dim but it has " + f"{arg.dim()} dimensions and {arg.numel()} elements." + ), + ) + else: + break + + result = [] + for elem in range(nelem): + each_args = [args[i][elem] for i in range(nlists)] + result.append(_scalar_op(*each_args, *args[nlists:], **kwargs)) + + return result + + +@register_meta_foreach( + [ + aten._foreach_abs_, + aten._foreach_acos_, + aten._foreach_asin_, + aten._foreach_atan_, + aten._foreach_ceil_, + aten._foreach_cos_, + aten._foreach_cosh_, + aten._foreach_erf_, + aten._foreach_erfc_, + aten._foreach_exp_, + aten._foreach_expm1_, + aten._foreach_frac_, + aten._foreach_floor_, + aten._foreach_lgamma_, + aten._foreach_log_, + aten._foreach_log10_, + aten._foreach_log1p_, + aten._foreach_log2_, + aten._foreach_neg_, + aten._foreach_reciprocal_, + aten._foreach_round_, + aten._foreach_sigmoid_, + aten._foreach_sign_, + aten._foreach_sin_, + aten._foreach_sinh_, + aten._foreach_sqrt_, + aten._foreach_tan_, + aten._foreach_tanh_, + aten._foreach_trunc_, + aten._foreach_zero_, + aten._foreach_add_, + aten._foreach_sub_, + aten._foreach_mul_, + aten._foreach_div_, + aten._foreach_clamp_min_, + aten._foreach_clamp_max_, + aten._foreach_lerp_, + aten._foreach_copy_, + ] +) +def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs): + _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs) + return + + +@register_meta([aten._foreach_pow.ScalarAndTensor]) +def meta__foreach_pow_scalar_and_tensor(self, exponent): + # Only foreach_pow has a ScalarAndTensor method and needs special + # handling because it does not work with _meta_foreach_out_of_place. + torch._check( + isinstance(exponent, List), + lambda: f"exponent must be a tensor list but got {type(exponent)}", + ) + return [torch.empty_like(e) for e in exponent] + + +def _check_foreach_binop_tensor_lists(self, other): + torch._check( + isinstance(self, List) and isinstance(other, List), + lambda: ( + "The first two arguments of must be List[Tensor], " + f"but got {type(self)} and {type(other)}." + ), + ) + torch._check( + len(self) > 0 and len(self) == len(other), + lambda: ( + "self and other must be non-empty and match in length, " + f"but got {len(self)} and {len(other)}." + ), + ) + + +@register_meta( + [ + aten._foreach_maximum, + aten._foreach_minimum, + ] +) +def meta__foreach_binop_scalar(*args): + # aten.maximum(Tensor, Scalar) does not exist. 
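+    # clamp_min is presumably only a stand-in elementwise op here: the meta result is never
+    # computed, so any binary op with the same broadcasting and type-promotion rules produces
+    # the right output shapes/dtypes for both _foreach_maximum and _foreach_minimum.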
+ return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min) + + +@register_meta( + [ + aten._foreach_maximum_, + aten._foreach_minimum_, + ] +) +def meta__foreach_binop__scalar(*args): + # aten.maximum(Tensor, Scalar) does not exist + _meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_) + return + + +@register_meta( + [ + aten._foreach_addcdiv.Scalar, + aten._foreach_addcmul.Scalar, + ] +) +def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1): + # forach_addcdiv and addcdiv have different signatures and + # cannot use _meta_foreach_out_of_place. + torch._check( + all(isinstance(l, List) for l in [self, tensor1, tensor2]), + lambda: ( + "All arguments must be List[Tensor], " + f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" + ), + ) + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( + len(self) == len(tensor1) and len(self) == len(tensor2), + lambda: "All input tensor lists must have the same length", + ) + + return [torch.empty_like(s) for s in self] + + +@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor]) +def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars): + torch._check( + all(isinstance(l, List) for l in [self, tensor1, tensor2]) + and isinstance(scalars, torch.Tensor), + lambda: ( + "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, " + f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}" + ), + ) + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( + len(self) == len(tensor1) and len(self) == len(tensor2), + lambda: "All input tensor lists must have the same length", + ) + + +@register_meta( + [ + aten._foreach_addcdiv_.Scalar, + aten._foreach_addcmul_.Scalar, + ] +) +def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1): + torch._check( + all(isinstance(l, List) for l in [self, tensor1, tensor2]), + lambda: ( + "All arguments of _foreach_addc*_ must be List[Tensor], " + f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}" + ), + ) + torch._check(len(self) > 0, lambda: "input tensor list must not be empty.") + torch._check( + len(self) == len(tensor1) and len(self) == len(tensor2), + lambda: "All input tensor lists must have the same length", + ) + + +@register_meta([aten._fused_adam_.default]) +def meta__fused_adam_( + self, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None, +): + for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: + torch._check( + isinstance(l, List), + lambda: f"exponent must be a tensor list but got {type(l)}", + ) + + +@register_meta([aten._fused_adam.default]) +def meta__fused_adam( + self, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None, +): + for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: + torch._check( + isinstance(l, List), + lambda: f"exponent must be a tensor list but got {type(l)}", + ) + + def empty_like_list(tensor_list): + return [torch.empty_like(t) for t in tensor_list] + + return ( + empty_like_list(self), + empty_like_list(grads), + empty_like_list(exp_avgs), + empty_like_list(exp_avg_sqs), + empty_like_list(max_exp_avg_sqs), + ) + + +@register_meta([aten._int_mm]) 
+@out_wrapper() +def meta__int_mm(a, b): + torch._check(a.dim() == 2, lambda: "a must be a 2D tensor") + torch._check(b.dim() == 2, lambda: "b must be a 2D tensor") + torch._check( + a.dtype is torch.int8, + lambda: f"expected self to be int8, got {a.dtype}", + ) + torch._check( + b.dtype is torch.int8, + lambda: f"expected mat2 to be int8, got {b.dtype}", + ) + torch._check( + a.size(1) == b.size(0), + lambda: ( + f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} " + f"and {b.size(0)}x{b.size(1)})" + ), + ) + return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32) + + +@register_meta([aten._convert_weight_to_int4pack]) +def meta__convert_weight_to_int4pack(w, inner_k_tiles): + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + n = w.size(0) + k = w.size(1) + return w.new_empty( + ( + n // 8, + k // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ) + + +@register_meta([aten._weight_int4pack_mm]) +def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 4, lambda: "w must be a 4D tensor") + torch._check( + x.dtype is torch.bfloat16, + lambda: f"expected x to be bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype) + + +@register_meta([aten._weight_int8pack_mm]) +def meta__weight_int8pack_mm(x, w, q_scales): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check( + x.dtype is torch.bfloat16, + lambda: f"expected x to be bf16, got {x.dtype}", + ) + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + w.dtype is torch.int8, + lambda: f"expected w to be int8, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + +@register_meta(aten._cdist_forward.default) +def meta_cdist_forward(x1, x2, p, compute_mode): + torch._check( + x1.dim() >= 2, + lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", + ) + torch._check( + x2.dim() >= 2, + lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", + ) + torch._check( + x1.size(-1) == x2.size(-1), + lambda: f"X1 and X2 must have the same number of columns. 
X1: {x1.size(-1)} X2: {x2.size(-1)}", + ) + torch._check( + utils.is_float_dtype(x1.dtype), + lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", + ) + torch._check( + utils.is_float_dtype(x2.dtype), + lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", + ) + torch._check(p >= 0, lambda: "cdist only supports non-negative p values") + torch._check( + compute_mode in (None, 1, 2), + lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", + ) + r1 = x1.size(-2) + r2 = x2.size(-2) + batch_tensor1 = x1.shape[:-2] + batch_tensor2 = x2.shape[:-2] + output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) + output_shape.extend([r1, r2]) + return x1.new_empty(output_shape) + + +@register_meta(aten._cdist_backward) +@out_wrapper() +def meta_cdist_backward(grad, x1, x2, p, cdist): + c1 = x1.shape[-1] + r1 = x1.shape[-2] + r2 = x2.shape[-2] + batch_tensor1 = x1.shape[:-2] + batch_tensor2 = x2.shape[:-2] + expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) + tensor1_expand_size = expand_batch_portion.copy() + tensor1_expand_size.extend([r1, c1]) + batch_product = math.prod(expand_batch_portion) + if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0: + return torch.zeros_like(x1) + if tensor1_expand_size != list(x1.shape): + x1 = x1.expand(tensor1_expand_size) + return torch.empty_like(x1, memory_format=torch.contiguous_format) + + +# NB: This meta function accepts non-meta arguments! When this behavior +# was originally introduced this was accidental, but it is now load bearing +# as people are using this so that they can conveniently test code involving +# embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module) +@register_meta(aten._embedding_bag.default) +def meta_embedding_bag( + weight, + indices, + offsets, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, +): + torch._check( + indices.dtype in (torch.long, torch.int), + lambda: f"expected indices to be long or int, got {indices.dtype}", + ) + torch._check( + offsets.dtype in (torch.long, torch.int), + lambda: f"expected offsets to be long or int, got {offsets.dtype}", + ) + torch._check( + utils.is_float_dtype(weight.dtype), + lambda: f"expected weight to be floating point type, got {weight.dtype}", + ) + + num_bags = offsets.size(0) + if include_last_offset: + torch._check( + num_bags >= 1, + lambda: "include_last_offset: numBags should be at least 1", + ) + num_bags -= 1 + + output = weight.new_empty(num_bags, weight.size(1)) + MODE_SUM, MODE_MEAN, MODE_MAX = range(3) + + if per_sample_weights is not None: + torch._check( + mode == MODE_SUM, + lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", + ) + torch._check( + per_sample_weights.dtype == weight.dtype, + lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", + ) + torch._check( + per_sample_weights.ndim == 1, + lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", + ) + torch._check( + per_sample_weights.numel() == indices.numel(), + lambda: ( + f"expected per_sample_weights.numel() ({per_sample_weights.numel()} " + f"to be the same as indices.numel() ({indices.numel()})" + ), + ) + + def is_fast_path_index_select_scale(src, scale, output, padding_idx): + return ( + is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1 + ) + + def 
is_fast_path_index_select(src, output, padding_idx): + return ( + (src.dtype == torch.float or src.dtype == torch.half) + and src.stride(1) == 1 + and output.stride(1) == 1 + and padding_idx < 0 + ) + + def is_fast_path(src, scale, output, padding_idx): + if scale is not None: + return is_fast_path_index_select_scale(src, scale, output, padding_idx) + else: + return is_fast_path_index_select(src, output, padding_idx) + + if device_hint(offsets) != "cpu": + offset2bag = indices.new_empty(indices.size(0)) + bag_size = indices.new_empty(offsets.size()) + if mode == MODE_MAX: + max_indices = indices.new_empty(num_bags, weight.size(1)) + else: + max_indices = indices.new_empty(0) + else: + fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx) + if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum: + offset2bag = offsets.new_empty(indices.size(0)) + else: + offset2bag = offsets.new_empty(0) + bag_size = offsets.new_empty(num_bags) + # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp + numBags = offsets.shape[0] + if mode == MODE_MAX: + if include_last_offset: + torch._check( + numBags >= 1, + lambda: "include_last_offset: numBags should be at least 1", + ) + numBags -= 1 + max_indices = offsets.new_empty(numBags, weight.shape[1]) + else: + max_indices = offsets.new_empty(bag_size.size()) + return output, offset2bag, bag_size, max_indices + + +@register_meta(aten._embedding_bag_forward_only.default) +def meta_embedding_bag_forward_only(weight, indices, offsets, *args): + output, offset2bag, bag_size, max_indices = meta_embedding_bag( + weight, indices, offsets, *args + ) + if device_hint(offsets) == "cpu": + bag_size = offsets.new_empty(offsets.size()) + return output, offset2bag, bag_size, max_indices + + +def _get_reduction_dtype(input, dtype, promote_int_to_long=True): + # if specified, dtype takes precedence + if dtype: + return dtype + + if input.dtype.is_floating_point or input.dtype.is_complex: + return input.dtype + elif promote_int_to_long: + return torch.long + + return input.dtype + + +@register_meta([aten.nansum.default, aten.nansum.out]) +@out_wrapper() +def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): + output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True) + dims = utils.reduction_dims(input.shape, dims) + output_shape = _compute_reduction_shape(input, dims, keepdim) + return input.new_empty(output_shape, dtype=output_dtype) + + +@register_meta([aten.median.default, aten.nanmedian.default]) +def meta_median(input): + output_shape = utils.compute_reduction_output_shape( + input.shape, tuple(range(input.dim())) + ) + return input.new_empty(output_shape) + + +@register_meta( + [ + aten.median.dim, + aten.median.dim_values, + aten.nanmedian.dim, + aten.nanmedian.dim_values, + aten.mode.default, + aten.mode.values, + ] +) +@out_wrapper("values", "indices") +def meta_median_mode_dim(input, dim=-1, keepdim=False): + if device_hint(input) == "cuda": + utils.alert_not_deterministic("median CUDA with indices output") + dim = utils.reduction_dims(input.shape, (dim,)) + output_shape = _compute_reduction_shape(input, dim, keepdim) + return ( + input.new_empty(output_shape), + input.new_empty(output_shape, dtype=torch.long), + ) + + +@register_meta(aten.logical_not_.default) +def meta_logical_not_(self): + return self + + +@register_meta(aten.repeat.default) +def meta_repeat(self, repeats): + torch._check( + len(repeats) >= self.dim(), + lambda: "Number of dimensions of repeat dims can not be smaller than 
number of dimensions of tensor", + ) + # Add new leading dimensions to the tensor if the + # number of target dimensions is larger than the + # number of source dimensions. + num_new_dimensions = len(repeats) - self.dim() + padded_size = (1,) * num_new_dimensions + tuple(self.shape) + target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))] + return self.new_empty(target_size) + + +@register_meta(aten.zero_.default) +def meta_zero_(self): + return self + + +@register_meta( + [ + aten.mul_.Scalar, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.div_.Tensor, + aten.logical_and_.default, + aten.logical_or_.default, + aten.logical_xor_.default, + ], +) +def meta_binop_inplace(self, other): + if isinstance(other, torch.Tensor): + check_inplace_broadcast(self.shape, other.shape) + return self + + +@register_meta( + [ + aten.add_.Scalar, + aten.sub_.Scalar, + aten.add_.Tensor, + aten.sub_.Tensor, + ], +) +def meta_binop_inplace_alpha(self, other, alpha=1): + if isinstance(other, torch.Tensor): + check_inplace_broadcast(self.shape, other.shape) + return self + + +@register_meta([aten.round.default, aten.round.decimals]) +def meta_round(self, **kwargs): + return elementwise_meta( + self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +def shift_dtype_check(fn_name, self, val): + torch._check( + utils.is_integer_dtype(self.dtype), + lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}", + ) + if isinstance(val, torch.Tensor): + torch._check( + utils.is_integer_dtype(val.dtype), + lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}", + ) + else: + torch._check( + isinstance(val, IntLike), + lambda: f"{fn_name}: Expected shift value to be an int. Got {val}", + ) + + +@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar]) +def meta_rshifts(self, other): + shift_dtype_check("rshift", self, other) + return elementwise_meta( + self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar]) +def meta_lshifts(self, other): + shift_dtype_check("lshift", self, other) + return elementwise_meta( + self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.zero.default) +def meta_zero(self): + return self.new_empty(self.shape) + + +@register_meta([aten.fill_.Tensor, aten.fill_.Scalar]) +def meta_fill_(self, val): + return self + + +@register_meta([aten.fill.Tensor, aten.fill.Scalar]) +def meta_fill(self, val): + return torch.empty_like(self) + + +@register_meta(aten.relu_.default) +def meta_relu_(self): + return self + + +@register_meta([aten.index_put.default, aten._unsafe_index_put.default]) +def meta_index_put(self, indices, values, accumulate=False): + return torch.empty_like(self) + + +@register_meta(aten.masked_fill_.Scalar) +def meta_masked_fill_(self, mask, value): + check_inplace_broadcast(self.shape, mask.shape) + return self + + +@register_meta(aten.masked_scatter_) +def meta_masked_scatter_(self, mask, source): + torch._check( + mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8" + ) + torch._check( + self.dtype == source.dtype, + lambda: "masked_scatter: expected self and source to have same " + "dtypes but got {self.dtype} and {source.dtype}", + ) + return self + + +@register_meta(aten.masked_scatter) +@out_wrapper() +def meta_masked_scatter(self, mask, source): + self, mask = _maybe_broadcast(self, mask) + output = torch.empty_like(self, 
memory_format=torch.contiguous_format) + return meta_masked_scatter_(output, mask, source) + + +@register_meta(aten.masked_scatter_backward) +def meta_masked_scatter_backward(self, mask, sizes): + return self.new_empty(sizes) + + +@register_meta(aten.index_put_.default) +def meta_index_put_(self, indices, values, accumulate=False): + return self + + +@register_meta(aten.alias.default) +def meta_alias(self): + return self.view(self.shape) + + +def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): + torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") + torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + + batch1_sizes = batch1.size() + batch2_sizes = batch2.size() + + bs = batch1_sizes[0] + contraction_size = batch1_sizes[2] + res_rows = batch1_sizes[1] + res_cols = batch2_sizes[2] + output_size = (bs, res_rows, res_cols) + + torch._check( + batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, + lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}" + f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].", + ) + + # TODO: handle out + + output = batch2.new_empty(output_size) + + if not is_bmm and self_baddbmm is not None: + torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor") + torch._check( + self_baddbmm.size() == output_size, + lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}", + ) + + return output + + +@register_meta(aten.bmm.default) +def meta_bmm(self, mat2): + return common_meta_baddbmm_bmm(self, mat2, True) + + +def div_rtn(x, y): + q = x // y + r = x % y + # WARNING: explicit bool conversion here is necessary; + # would be fixed by SymBool + if r != 0 and (bool(r < 0) != bool(y < 0)): + q -= 1 + return q + + +def pooling_output_shape_pad_lr( + inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode +): + outputSize = ( + div_rtn( + inputSize + + pad_l + + pad_r + - dilation * (kernelSize - 1) + - 1 + + (stride - 1 if ceil_mode else 0), + stride, + ) + + 1 + ) + if ceil_mode: + if (outputSize - 1) * stride >= inputSize + pad_l: + outputSize -= 1 + return outputSize + + +def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode): + torch._check(stride != 0, lambda: "stride should not be zero") + torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}") + torch._check( + pad <= ((kernelSize - 1) * dilation + 1) // 2, + lambda: ( + f"pad should be at most half of effective kernel size, but got pad={pad}, " + f"kernel_size={kernelSize} and dilation={dilation}" + ), + ) + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode + ) + + +def pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, +): + ndim = input.dim() + nOutputPlane = nInputPlane + + torch._check( + kW > 0 and kH > 0, + lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}", + ) + torch._check( + dW > 0 and dH > 0, + lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}", + ) + torch._check( + dilationH > 0 and dilationW > 0, + lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}", + ) + + valid_dims = input.size(1) != 0 and input.size(2) != 0 + + if memory_format == torch.channels_last: + torch._check( + ndim == 4 and 
valid_dims and input.size(3) != 0, + lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: {input.size()}", + ) + else: + torch._check( + (ndim == 3 and input.size(0) != 0 and valid_dims) + or (ndim == 4 and valid_dims and input.size(3) != 0), + lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", + ) + + torch._check( + kW // 2 >= padW and kH // 2 >= padH, + lambda: "pad should be smaller than or equal to half of kernel size, but got " + f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", + ) + + torch._check( + outputWidth >= 1 and outputHeight >= 1, + lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). " + f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). " + "Output size is too small", + ) + + +def pool3d_shape_check( + input: Tensor, + nslices: int, + kT: int, + kH: int, + kW: int, + dT: int, + dH: int, + dW: int, + pT: int, + pH: int, + pW: int, + dilationT: int, + dilationH: int, + dilationW: int, + itime: int, + iheight: int, + iwidth: int, + otime: int, + oheight: int, + owidth: int, + fn_name: str, + check_input_size: bool = False, +): + ndim = input.ndim + + torch._check( + kT > 0 and kW > 0 and kH > 0, + lambda: ( + f"kernel size should be greater than zero, but got " + f"kT: {kT}, kH: {kH}, kW: {kW}" + ), + ) + torch._check( + dT > 0 and dW > 0 and dH > 0, + lambda: ( + f"stride should be greater than zero, but got " + f"dT: {dT}, dH: {dH}, dW: {dW}" + ), + ) + torch._check( + dilationT > 0 and dilationW > 0 and dilationH > 0, + lambda: ( + f"dilation should be greater than zero, but got " + f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}" + ), + ) + + torch._check( + ndim in (4, 5), + lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}", + ) + + for i in range(ndim): + if ndim == 5 and i == 0: + # size of batch-dim can be 0. + continue + torch._check( + input.size(i) > 0, + lambda: ( + f"{fn_name}: Expected input's non-batch dimensions to have positive length," + f" but input has a shape of {input.shape}" + f" and non-batch dimension {input.size(i)} has length zero!" + ), + ) + + if check_input_size: # AveragePool3d + torch._check( + itime >= kT and iheight >= kH and iwidth >= kW, + lambda: ( + f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than " + f"kernel size (kT: {kT} kH: {kH} kW: {kW})" + ), + ) + + torch._check( + kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, + lambda: ( + f"pad should be smaller than or equal to half of kernel size, but got " + f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}" + ), + ) + + torch._check( + otime >= 1 and owidth >= 1 and oheight >= 1, + lambda: ( + f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). " + f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). 
" + f"Output size is too small" + ), + ) + + +def max_pool3d_backward_shape_check( + input, + grad_output, + indices, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, +): + ndim = input.ndim + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, + ) + + check_dim_size(grad_output, ndim, ndim - 4, nslices) + check_dim_size(grad_output, ndim, ndim - 3, otime) + check_dim_size(grad_output, ndim, ndim - 2, oheight) + check_dim_size(grad_output, ndim, ndim - 1, owidth) + + check_dim_size(indices, ndim, ndim - 4, nslices) + check_dim_size(indices, ndim, ndim - 3, otime) + check_dim_size(indices, ndim, ndim - 2, oheight) + check_dim_size(indices, ndim, ndim - 1, owidth) + + +def avg_pool3d_backward_shape_check( + input: Tensor, + grad_output: Tensor, + nslices: int, + kT: int, + kH: int, + kW: int, + dT: int, + dH: int, + dW: int, + pT: int, + pH: int, + pW: int, + itime: int, + iheight: int, + iwidth: int, + otime: int, + oheight: int, + owidth: int, + fn_name: str, +): + ndim = input.ndim + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + 1, + 1, + 1, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + fn_name, + True, + ) + + check_dim_size(grad_output, ndim, ndim - 4, nslices) + check_dim_size(grad_output, ndim, ndim - 3, otime) + check_dim_size(grad_output, ndim, ndim - 2, oheight) + check_dim_size(grad_output, ndim, ndim - 1, owidth) + + +def max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode +): + # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp + def unpack(name, val): + torch._check( + len(val) in [1, 2], + lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints", + ) + H = val[0] + W = H if len(val) == 1 else val[1] + return H, W + + kH, kW = unpack("kernel_size", kernel_size) + + torch._check( + len(stride) in [0, 1, 2], + lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints", + ) + if len(stride) == 0: + dH, dW = kH, kW + else: + dH, dW = unpack("stride", stride) + + padH, padW = unpack("padding", padding) + dilationH, dilationW = unpack("dilation", dilation) + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) + + memory_format = utils.suggest_memory_format(input) + if memory_format == torch.channels_last: + torch._check( + input.dim() == 4, + lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout", + ) + elif memory_format == torch.contiguous_format: + torch._check( + input.dim() in [3, 4], + lambda: "non-empty 3D or 4D (batch mode) tensor expected for input", + ) + else: + torch._check( + False, + lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", + ) + + outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + pool2d_shape_check( + input, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format, + ) + + return nInputPlane, outputHeight, outputWidth + + +@register_meta(aten.max_pool2d_with_indices_backward.default) +def meta_max_pool2d_with_indices_backward( + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + torch._check( + self.dtype == grad_output.dtype, + lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", + ) + + nOutputPlane = nInputPlane + ndim = self.ndim + + def _check_dim_size(t): + check_dim_size(t, ndim, ndim - 3, nOutputPlane) + check_dim_size(t, ndim, ndim - 2, outputHeight) + check_dim_size(t, ndim, ndim - 1, outputWidth) + + _check_dim_size(grad_output) + _check_dim_size(indices) + + memory_format = utils.suggest_memory_format(self) + return torch.empty( + self.shape, + dtype=self.dtype, + device=self.device, + memory_format=memory_format, + ) + + +@register_meta(aten.max_pool2d_with_indices.default) +def meta_max_pool2d_with_indices( + input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +): + ( + nInputPlane, + outputHeight, + outputWidth, + ) = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = utils.suggest_memory_format(input) + if input.dim() == 3: + size = [nInputPlane, outputHeight, outputWidth] + else: + size = [nbatch, nInputPlane, outputHeight, outputWidth] + return ( + torch.empty( + size, + dtype=input.dtype, + device=input.device, + memory_format=memory_format, + ), + torch.empty( + size, + dtype=torch.int64, + device=input.device, + memory_format=memory_format, + ), + ) + + +@register_meta(aten.fractional_max_pool2d.default) +def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples): + torch._check( + self_.ndim in (3, 4), + lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}", + ) + ndim = self_.ndim + + for d in range(ndim - 3, ndim): + torch._check( + self_.size(d) > 0, + f"fractional_max_pool2d: Expected input to have non-zero " + f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty", + ) + + # the check and message are out of sync, but this matches the structured meta + torch._check( + len(kernel_size) == 2, + lambda: "fractional_max_pool2d: kernel_size must" + "either be a single int or tuple of Ints", + ) + torch._check( + len(output_size) == 2, + lambda: "fractional_max_pool2d: output_size must " + "either be a single int or tuple of Ints", + ) + + input_channels = self_.size(-3) + input_height = self_.size(-2) + input_width = self_.size(-1) + if ndim == 4: + input_batch = self_.size(0) + else: + input_batch = 1 + + torch._check( + self_.dtype == random_samples.dtype, + lambda: "Expect _random_samples to have the same dtype as input", + ) + torch._check( + random_samples.ndim == 3, + lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}", + ) + + n = 
random_samples.size(0) + c = random_samples.size(1) + d = random_samples.size(2) + torch._check( + n >= input_batch, + "Expect _random_samples.size(0) no less then input batch size.", + ) + torch._check( + c == input_channels, + lambda: "Expect _random_samples.size(1) equals to input channel size.", + ) + torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.") + + torch._check( + output_size[0] + kernel_size[0] - 1 <= input_height, + lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}", + ) + torch._check( + output_size[1] + kernel_size[1] - 1 <= input_width, + lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}", + ) + + if self_.dim() == 4: + size = [input_batch, input_channels, output_size[0], output_size[1]] + else: + size = [input_channels, output_size[0], output_size[1]] + + return ( + torch.empty( + size, + dtype=self_.dtype, + device=self_.device, + ), + torch.empty( + size, + dtype=torch.int64, + device=self_.device, + ), + ) + + +@register_meta(aten.max_unpool2d) +@out_wrapper() +def meta_max_unpool2d(self_, indices, output_size): + utils.alert_not_deterministic("max_unpooling2d_forward_out") + + torch._check( + indices.dtype == torch.int64, + lambda: f"elements in indices should be type int64 but got: {indices.dtype}", + ) + torch._check( + len(output_size) == 2, + lambda: ( + f"There should be exactly two elements (height, width) in output_size, " + f"but got {len(output_size)} elements." + ), + ) + + oheight, owidth = output_size + + torch._check( + self_.ndim in (3, 4), + lambda: ( + f"Input to max_unpooling2d should be a 3d or 4d Tensor, " + f"but got a tensor with {self_.ndim} dimensions." + ), + ) + torch._check( + self_.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, self_.ndim): + torch._check( + self_.size(i) > 0, + lambda: ( + f"max_unpooling2d(): " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {self_.shape} with dimension {i} being empty." + ), + ) + + self = self_.contiguous() + + if self_.ndim == 3: + nchannels = self.size(0) + result = self.new_empty((nchannels, oheight, owidth)) + else: + nbatch = self.size(0) + nchannels = self.size(1) + result = self.new_empty((nbatch, nchannels, oheight, owidth)) + + return result + + +def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name): + torch._check( + indices.dtype == torch.int64, lambda: "elements in indices should be type int64" + ) + torch._check( + input.ndim in (4, 5), + lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.", + ) + torch._check( + len(output_size) == 3, + lambda: ( + f"There should be exactly three elements (depth, height, width) in output_size, " + f"but got {len(output_size)} elements." 
+ ), + ) + torch._check( + len(stride) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.", + ) + torch._check( + len(padding) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.", + ) + torch._check( + input.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({input.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, input.ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"{fn_name}: " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {input.shape} with dimension {i} being empty." + ), + ) + + torch._check( + stride[0] > 0 and stride[1] > 0 and stride[2] > 0, + lambda: f"strides should be greater than zero, but got stride: {stride}", + ) + + +@register_meta(aten.max_unpool3d) +@out_wrapper() +def meta_max_unpool3d(self_, indices, output_size, stride, padding): + utils.alert_not_deterministic("max_unpooling3d_forward_out") + + _max_unpooling3d_shape_check( + self_, indices, output_size, stride, padding, "max_unpooling3d()" + ) + + self = self_.contiguous() + + odepth, oheight, owidth = output_size + + if self_.ndim == 4: + nchannels = self.size(0) + result = self.new_empty((nchannels, odepth, oheight, owidth)) + else: + nbatch = self.size(0) + nchannels = self.size(1) + result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth)) + + return result + + +@register_meta(aten.max_pool3d_with_indices) +@out_wrapper("out", "indices") +def meta_max_pool3d_with_indices( + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", + ) + pT = padding[0] + pH = pT if len(padding) == 1 else padding[1] + pW = pT if len(padding) == 1 else padding[2] + + torch._check( + len(dilation) in (1, 3), + lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", + ) + dilationT = dilation[0] + dilationH = dilationT if len(dilation) == 1 else dilation[1] + dilationW = dilationT if len(dilation) == 1 else dilation[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + nbatch = input.size(-5) if input.ndim == 5 else 1 + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode) + oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode) + owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode) + + pool3d_shape_check( + input, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + 
itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "max_pool3d_with_indices()", + ) + + channels_last = ( + input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d + ) + if input.ndim == 4: + input_channels_last_check = input.unsqueeze(0) + channels_last = ( + not input_channels_last_check.is_contiguous() + ) and input_channels_last_check.is_contiguous( + memory_format=torch.channels_last_3d + ) + out_shape = (nslices, otime, oheight, owidth) + else: + out_shape = (nbatch, nslices, otime, oheight, owidth) # type: ignore[assignment] + + out = input.new_empty(out_shape) + indices = input.new_empty(out_shape, dtype=torch.int64) + + if channels_last: + out = out.to(memory_format=torch.channels_last_3d) + indices = indices.to(memory_format=torch.channels_last_3d) + + return out, indices + + +@register_meta(aten.max_pool3d_with_indices_backward) +@out_wrapper("grad_input") +def meta_max_pool3d_with_indices_backward( + grad_output, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): + torch._check( + len(kernel_size) in (1, 3), + lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", + ) + kT = kernel_size[0] + kH = kT if len(kernel_size) == 1 else kernel_size[1] + kW = kT if len(kernel_size) == 1 else kernel_size[2] + + torch._check( + not stride or len(stride) in (1, 3), + lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", + ) + dT = kT if not stride else stride[0] + dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) + dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) + + torch._check( + len(padding) in (1, 3), + lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", + ) + pT = padding[0] + pH = pT if len(padding) == 1 else padding[1] + pW = pT if len(padding) == 1 else padding[2] + + torch._check( + len(dilation) in (1, 3), + lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", + ) + dilationT = dilation[0] + dilationH = dilationT if len(dilation) == 1 else dilation[1] + dilationW = dilationT if len(dilation) == 1 else dilation[2] + + torch._check( + input.ndim in (4, 5), + lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", + ) + + nslices = input.size(-4) + itime = input.size(-3) + iheight = input.size(-2) + iwidth = input.size(-1) + + otime = grad_output.size(-3) + oheight = grad_output.size(-2) + owidth = grad_output.size(-1) + + max_pool3d_backward_shape_check( + input, + grad_output, + indices, + nslices, + kT, + kH, + kW, + dT, + dH, + dW, + pT, + pH, + pW, + dilationT, + dilationH, + dilationW, + itime, + iheight, + iwidth, + otime, + oheight, + owidth, + "max_pool3d_with_indices_backward()", + ) + + channels_last = ( + input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d + ) + if input.ndim == 4: + input_channels_last_check = input.unsqueeze(0) + channels_last = ( + not input_channels_last_check.is_contiguous() + ) and input_channels_last_check.is_contiguous( + memory_format=torch.channels_last_3d + ) + + grad_input = input.new_empty(input.shape) + + if channels_last: + grad_input = grad_input.to(memory_format=torch.channels_last_3d) + + return grad_input + + +def check_grid_sampler_common(input: Tensor, grid: Tensor): + torch._check( + input.device == grid.device, + lambda: ( + f"grid_sampler(): expected input and grid to be on same device, but input " + f"is on {input.device} and grid is on 
{grid.device}" + ), + ) + torch._check( + input.layout == torch.strided and grid.layout == torch.strided, + lambda: ( + f"grid_sampler(): expected input and grid to have torch.strided layout, but " + f"input has {input.layout} and grid has {grid.layout}" + ), + ) + torch._check( + input.shape[0] == grid.shape[0], + lambda: ( + f"grid_sampler(): expected grid and input to have same batch size, but got " + f"input with sizes {input.shape} and grid with sizes {grid.shape}" + ), + ) + torch._check( + grid.shape[-1] == input.ndim - 2, + lambda: ( + f"grid_sampler(): expected grid to have size {input.ndim - 2} in last " + f"dimension, but got grid with sizes {grid.shape}" + ), + ) + + for i in range(2, input.ndim): + torch._check( + input.shape[i] > 0, + lambda: ( + f"grid_sampler(): expected input to have non-empty spatial dimensions, " + f"but input has sizes {input.shape} with dimension {i} being empty" + ), + ) + + +class GridSamplerInterpolation(Enum): + BILINEAR = 0 + NEAREST = 1 + BICUBIC = 2 + + +def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int): + torch._check( + input.ndim == 5 and input.ndim == grid.ndim, + lambda: ( + f"grid_sampler(): expected 5D input and grid with same number of " + f"dimensions, but got input with sizes {input.shape}" + f" and grid with sizes {grid.shape}" + ), + ) + torch._check( + not ( + input.ndim == 5 + and interpolation_mode == GridSamplerInterpolation.BICUBIC.value + ), + lambda: "grid_sampler(): bicubic interpolation only supports 4D input", + ) + + +@register_meta(aten.grid_sampler_2d_backward.default) +def grid_sampler_2d_backward_meta( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format) + return (grad_input, grad_grid) + + +@register_meta(aten.grid_sampler_3d) +@out_wrapper() +def grid_sampler_3d( + input, + grid, + interpolation_mode, + padding_mode, + align_corners, +): + check_grid_sampler_common(input, grid) + check_grid_sampler_3d(input, grid, interpolation_mode) + N = input.shape[0] + C = input.shape[1] + out_D = grid.shape[1] + out_H = grid.shape[2] + out_W = grid.shape[3] + return input.new_empty((N, C, out_D, out_H, out_W)) + + +@register_meta(aten.grid_sampler_3d_backward) +@out_wrapper("grad_input", "grad_grid") +def grid_sampler_3d_backward( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + check_grid_sampler_common(input, grid) + check_grid_sampler_3d(input, grid, interpolation_mode) + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like( + input, memory_format=torch.legacy_contiguous_format + ) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format) + return grad_input, grad_grid + + +@register_meta([aten.full.default]) +def full(size, fill_value, *args, **kwargs): + dtype = kwargs.get("dtype", None) + if not dtype: + dtype = utils.get_dtype(fill_value) + kwargs["dtype"] = dtype + return torch.empty(size, *args, **kwargs) + + +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + if layout == torch.sparse_coo: + 
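+        # Sparse path: build an empty sparse COO tensor, resize it to self's shape
+        # (reusing self's sparse/dense dims when self is itself sparse), and mark it coalesced.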
torch._check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 0) + + res._coalesced_(True) + return res + res = aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + # device can be not "meta" + res.fill_(0) + return res + + +@register_meta(aten.select.int) +def meta_select(self, dim, index): + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + torch._check_index( + not (-index > size or index >= size), + lambda: f"select(): index {index} out of range for tensor of size " + f"{self.size()} at dimension {dim}", + ) + + index = index if index >= 0 else index + size + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = self.storage_offset() + index * new_stride[dim] + del new_size[dim] + del new_stride[dim] + + return self.as_strided(new_size, new_stride, new_storage_offset) + + +@register_meta(aten.select_scatter.default) +def meta_select_scatter(self, src, dim, index): + return utils.clone_preserve_strides(self) + + +@register_meta(aten.slice_scatter.default) +def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1): + return utils.clone_preserve_strides(self) + + +# TODO: Deduplicate this with canonicalize_dim +def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): + if dim_post_expr <= 0: + assert wrap_scalar + dim_post_expr = 1 + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})" + if dim < 0: + dim += dim_post_expr + return dim + + +def ensure_nonempty_size(t, dim): + return 1 if t.dim() == 0 else t.shape[dim] + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def gather_shape_check(self, dim, index): + self_dims = max(self.dim(), 1) + index_dims = max(index.dim(), 1) + torch._check( + self_dims == index_dims, + lambda: "Index tensor must have the same number of dimensions as input tensor", + ) + for i in range(self_dims): + if i != dim: + torch._check( + ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), + lambda: f"Size does not match at dimension {i} expected index {index.shape}" + + f" to be smaller than self {self.shape} apart from dimension {dim}", + ) + + +@register_meta(aten.gather.default) +def meta_gather(self, dim, index, sparse_grad=False): + wrapped_dim = maybe_wrap_dim(dim, self.dim()) + is_index_empty = index.numel() == 0 + if not is_index_empty: + torch._check( + index.dtype == torch.long, + lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}", + ) + gather_shape_check(self, wrapped_dim, index) + return self.new_empty(index.shape) + + +# From aten/src/ATen/native/TensorAdvancedIndexing.cpp +def get_operator_enum(reduce_, use_new_options=False): + if use_new_options: + if reduce_ == "sum": + return "REDUCE_ADD" + elif reduce_ == "prod": + return "REDUCE_MULTIPLY" + elif reduce_ == "mean": + return "REDUCE_MEAN" + elif reduce_ == "amax": + return "REDUCE_MAXIMUM" + elif 
reduce_ == "amin": + return "REDUCE_MINIMUM" + torch._check( + False, + lambda: "reduce argument must be either sum, prod, mean, amax or amin.", + ) + return + else: + if reduce_ == "add": + return "REDUCE_ADD" + elif reduce_ == "multiply": + return "REDUCE_MULTIPLY" + torch._check(False, lambda: "reduce argument must be either add or multiply.") + return + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def scatter_gather_dtype_check(method_name, self, index, src_opt=None): + if index.numel() != 0: + torch._check( + index.dtype == torch.long, + lambda: f"{method_name}(): Expected dtype int64 for index", + ) + + if src_opt is not None: + torch._check( + self.dtype == src_opt.dtype, + lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype", + ) + + +def ensure_nonempty_dim(dim): + return max(dim, 1) + + +# From aten/src/ATen/native/ScatterGatherChecks.h +def scatter_shape_check(self, dim, index, src_opt=None): + if index.numel() == 0: + return + torch._check( + ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), + lambda: "Index tensor must have the same number of dimensions as self tensor", + ) + + is_wrong_shape = False + self_dims = ensure_nonempty_dim(self.dim()) + + # Check: index.size(d) <= self.size(d) for all d != dim + for d in range(self_dims): + index_d_size = ensure_nonempty_size(index, d) + if d == dim: + continue + if index_d_size > ensure_nonempty_size(self, d): + is_wrong_shape = True + break + + # Check: index.size(d) <= src.size(d) for all d if src is Tensor + if not is_wrong_shape and src_opt is not None: + for d in range(self_dims): + index_d_size = ensure_nonempty_size(index, d) + if index_d_size > ensure_nonempty_size(src_opt, d): + is_wrong_shape = True + break + + if src_opt is not None: + torch._check( + ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), + lambda: "Index tensor must have the same number of dimensions as self tensor", + ) + torch._check( + not is_wrong_shape, + lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" + + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}", + ) + else: + torch._check( + not is_wrong_shape, + lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" + + f" apart from dimension {dim}", + ) + + +# From aten/src/ATen/native/TensorAdvancedIndexing.cpp +def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False): + wrapped_dim = maybe_wrap_dim(dim, self.dim()) + scatter_gather_dtype_check("scatter", self, index, src) + scatter_shape_check(self, wrapped_dim, index, src) + if reduce_ is not None: + # Check if we have a valid reduce operator. 
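+        # get_operator_enum errors out (via torch._check) when reduce_ is not one of the supported reduction names.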
+ get_operator_enum(reduce_, use_new_options) + + +@register_meta(aten.scatter_add.default) +def meta_scatter_add(self, dim, index, src): + scatter_meta_impl(self, dim, index, src, "add") + return self.new_empty(self.shape) + + +@register_meta(aten.scatter_add_) +def meta_scatter_add_(self, dim, index, src): + scatter_meta_impl(self, dim, index, src, "add") + return self + + +@register_meta( + [ + aten.scatter.src, + aten.scatter.value, + aten.scatter.reduce, + aten.scatter.value_reduce, + ] +) +@out_wrapper() +def meta_scatter(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self.new_empty(self.shape) + + +@register_meta( + [ + aten.scatter_.src, + aten.scatter_.value, + aten.scatter_.reduce, + aten.scatter_.value_reduce, + ] +) +def meta_scatter_(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_backward, + ] +) +def meta__scaled_dot_product_flash_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + scale: Optional[float] = None, +): + grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) + grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) + grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2) + return grad_q, grad_k, grad_v + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_for_cpu, + ] +) +def meta__scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + head_dim = query.size(3) + + attention = torch.empty( + (batch_size, max_seqlen_batch_q, num_heads, head_dim), + dtype=query.dtype, + device=query.device, + ).transpose(1, 2) + logsumexp = torch.empty( + ( + batch_size, + max_seqlen_batch_q, + num_heads, + ), + dtype=torch.float, + device=query.device, + ).transpose(1, 2) + return ( + attention, + logsumexp, + ) + + +@register_meta( + [ + aten._scaled_dot_product_flash_attention_for_cpu_backward, + ] +) +def meta__scaled_dot_product_flash_attention_for_cpu_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + dropout_p: float, + is_causal: bool, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +): + # cpus's grad layout is different from cuda's, + # i.e. 
(batch_size, seq_len,num_heads, head_dim) + batch_size = query.size(0) + num_heads = query.size(1) + head_dim = query.size(3) + len_q = query.size(2) + len_k = key.size(2) + + grad_q = torch.empty_permuted( + (batch_size, num_heads, len_q, head_dim), + (0, 2, 1, 3), + dtype=query.dtype, + device=query.device, + ) + grad_k = torch.empty_permuted( + (batch_size, num_heads, len_k, head_dim), + (0, 2, 1, 3), + dtype=key.dtype, + device=key.device, + ) + grad_v = torch.empty_permuted( + (batch_size, num_heads, len_k, head_dim), + (0, 2, 1, 3), + dtype=value.dtype, + device=value.device, + ) + + return grad_q, grad_k, grad_v + + +@register_meta( + [ + aten._scaled_dot_product_efficient_attention_backward, + ] +) +def meta__scaled_dot_product_efficient_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + out: Tensor, + logsumexp: Tensor, + philox_seed: Tensor, + philox_offset: Tensor, + dropout_p: float, + grad_input_mask: List[bool], + is_causal: bool = False, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_q = query.size(2) + head_dim = query.size(3) + head_dim_v = value.size(3) + + max_k = key.size(2) + + grad_q = torch.empty_permuted( + (batch_size, num_heads, max_q, head_dim), + (0, 2, 1, 3), + dtype=query.dtype, + device=query.device, + ) + grad_k = torch.empty_permuted( + (batch_size, num_heads, max_k, head_dim), + (0, 2, 1, 3), + dtype=key.dtype, + device=key.device, + ) + grad_v = torch.empty_permuted( + (batch_size, num_heads, max_k, head_dim_v), + (0, 2, 1, 3), + dtype=value.dtype, + device=value.device, + ) + grad_bias = None + if attn_bias is not None and grad_input_mask[3]: + lastDim = attn_bias.size(-1) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + new_sizes = list(attn_bias.size()) + new_sizes[-1] = lastDimAligned + grad_bias = torch.empty( + new_sizes, dtype=attn_bias.dtype, device=attn_bias.device + ) + grad_bias = grad_bias[..., :lastDim] + + return grad_q, grad_k, grad_v, grad_bias + + +@register_meta( + [ + aten._flash_attention_backward, + ] +) +def meta__flash_attention_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + scale: Optional[float] = None, +): + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + return grad_query, grad_key, grad_value + + +@register_meta( + [ + aten._efficient_attention_backward, + ] +) +def meta__efficient_attention_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + logsumexp: Tensor, + dropout_p: float, + philox_seed: Tensor, + philox_offset: Tensor, + custom_mask_type: int, + bias_requires_grad: bool, + scale: Optional[float] = None, + num_splits_key: Optional[int] = None, +): + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + if bias is not None: + lastDim = bias.size(-1) + lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 + new_sizes = list(bias.size()) + new_sizes[-1] = lastDimAligned + grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device) 
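+        # grad_bias was allocated with its last dimension rounded up to a multiple of 16;
+        # slice back down to bias's original last-dim width.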
+ grad_bias = grad_bias[..., :lastDim] + else: + grad_bias = torch.empty((), device=query.device) + + return grad_query, grad_key, grad_value, grad_bias + + +@register_meta([aten._scaled_mm.default]) +def meta_scaled_mm( + self: torch.Tensor, + mat2: torch.Tensor, + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + scale_a: Optional[torch.Tensor] = None, + scale_b: Optional[torch.Tensor] = None, + scale_result: Optional[torch.Tensor] = None, + use_fast_accum: bool = False, +): + def is_row_major(stride): + return stride[0] > stride[1] and stride[1] == 1 + + def is_col_major(shape, stride): + return stride[0] == 1 and stride[1] == shape[0] + + def is_fp8_type(dtype): + return dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ) + + torch._check( + self.dim() == 2 and mat2.dim() == 2, + lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", + ) + torch._check( + is_row_major(self.stride()), + lambda: "self must be row_major", + ) + torch._check( + is_col_major(mat2.shape, mat2.stride()), + lambda: "mat2 must be col_major", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(0) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) + torch._check( + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", + ) + torch._check( + is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), + lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", + ) + _out_dtype = out_dtype if out_dtype is not None else self.dtype + return torch.empty( + self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device + ), torch.empty((), dtype=torch.float32, device=self.device) + + +@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out]) +@out_wrapper() +def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self.new_empty(self.shape) + + +@register_meta(aten.scatter_reduce_.two) +def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self + + +@register_meta([aten.multinomial.default, aten.multinomial.out]) +@out_wrapper() +def meta_multinomial(input, num_samples, replacement=False, *, generator=None): + torch._check( + 0 < input.dim() <= 2, + lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}", + ) + if input.dim() == 1: + return torch.empty(num_samples, dtype=torch.long, device=input.device) + return torch.empty( + input.size(0), num_samples, dtype=torch.long, device=input.device + ) + + +def multiply_integers(vs): + r = 1 + for v in vs: + r *= v + return r + + +def upsample_common_check(input_size, output_size, num_spatial_dims): + torch._check( + len(output_size) == num_spatial_dims, + lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}", + ) + expected_input_dims = num_spatial_dims + 2 # N, C, ... 
+ torch._check( + len(input_size) == expected_input_dims, + lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}", + ) + + torch._check( + all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size), + lambda: f"Input and output sizes should be greater than 0, but got " + f"input size {input_size} and output size {output_size}", + ) + + nbatch, channels = input_size[:2] + return (nbatch, channels, *output_size) + + +@register_meta( + [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default] +) +def upsample_nearest1d(input, output_size, scales=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=1 + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +@register_meta( + [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default] +) +def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=2 + ) + output = input.new_empty(full_output_size) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + _, n_channels, _, _ = input.shape + if input.device.type == "cuda" and n_channels < 4: + memory_format = torch.contiguous_format + + output = output.contiguous(memory_format=memory_format) + + return output + + +@register_meta( + [ + aten.upsample_nearest2d_backward.default, + aten._upsample_nearest_exact2d_backward.default, + ] +) +def upsample_nearest2d_backward( + grad_output: Tensor, + output_size: Sequence[Union[int, torch.SymInt]], + input_size: Sequence[Union[int, torch.SymInt]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + full_output_size = upsample_common_check( + input_size, output_size, num_spatial_dims=2 + ) + torch._check( + grad_output.ndim == 4, + lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}", + ) + for i in range(4): + torch._check( + grad_output.size(i) == full_output_size[i], + lambda: ( + f"Expected grad_output to have the same shape as output;" + f" output.size({i}) = {full_output_size[i]}" + f" but got grad_output.size({i}) = {grad_output.size(i)}" + ), + ) + + return grad_output.new_empty(input_size).to( + memory_format=utils.suggest_memory_format(grad_output) + ) # type: ignore[call-overload] + + +@register_meta( + [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default] +) +def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None): + torch._check( + input.numel() != 0 or multiply_integers(input.size()[1:]), + lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}", + ) + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=3 + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +@register_meta( + [ + aten.sort.default, + aten.sort.stable, + 
aten.sort.values, + aten.sort.values_stable, + ] +) +def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None): + v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) + if values is not None and indices is not None: + assert isinstance(values, TensorLike) + assert isinstance(indices, TensorLike) + # Makes sure values and indices have the same strides. For cases where + # these have different shapes, like (5, 10, 5) and (0) in msort. + out_shape = v.shape + out_stride = v.stride() + values = _maybe_resize_out(values, out_shape) + indices = _maybe_resize_out(indices, out_shape) + values.as_strided_(out_shape, out_stride) + indices.as_strided_(out_shape, out_stride) + _safe_copy_out(copy_from=v, copy_to=values) # type: ignore[arg-type] + _safe_copy_out(copy_from=i, copy_to=indices) # type: ignore[arg-type] + return values, indices + return v, i + + +@register_meta(aten.argsort.stable) +def meta_argsort(self, *, stable, dim=-1, descending=False): + return meta_sort(self, stable=stable, dim=dim, descending=descending)[1] + + +def rnn_cell_checkSizes( + input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden +): + torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") + torch._check( + input_gates.shape == hidden_gates.shape, + lambda: f"{input_gates.shape} != {hidden_gates.shape}", + ) + gates_size = input_gates.size(1) + if input_bias is not None: + torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1") + torch._check( + input_bias.numel() == gates_size, + lambda: f"{input_bias.numel()} != {gates_size}", + ) + torch._check( + input_bias.shape == hidden_bias.shape, + lambda: f"{input_bias.shape} != {hidden_bias.shape}", + ) + torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2") + expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor + torch._check( + prev_hidden.numel() == expected_prev_hidden_numel, + lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})", + ) + torch._check( + all( + x.device == input_gates.device + for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] + ), + lambda: "expected all inputs to be same device", + ) + + +@register_meta(aten._thnn_fused_lstm_cell.default) +def _thnn_fused_lstm_cell_meta( + input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None +): + rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx) + workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format) + hy = torch.empty_like(cx, memory_format=torch.contiguous_format) + cy = torch.empty_like(cx, memory_format=torch.contiguous_format) + return (hy, cy, workspace) + + +@register_meta(aten._cudnn_rnn.default) +def _cudnn_rnn( + input, + weight, + weight_stride0, + weight_buf, + hx, + cx, + mode, + hidden_size, + proj_size, + num_layers, + batch_first, + dropout, + train, + bidirectional, + batch_sizes, + dropout_state, +): + is_input_packed = len(batch_sizes) != 0 + if is_input_packed: + seq_length = len(batch_sizes) + mini_batch = batch_sizes[0] + batch_sizes_sum = input.shape[0] + else: + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + batch_sizes_sum = -1 + + num_directions = 2 if bidirectional else 1 + out_size = proj_size if proj_size != 0 else hidden_size + if is_input_packed: + out_shape = [batch_sizes_sum, out_size * num_directions] + else: + out_shape = ( + 
[mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) + output = input.new_empty(out_shape) + + cell_shape = [num_layers * num_directions, mini_batch, hidden_size] + if cx is None: + cy = torch.empty(0, device=input.device) + else: + cy = cx.new_empty(cell_shape) + + hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) + + # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) + reserve_shape = 0 if train else 0 + reserve = input.new_empty(reserve_shape, dtype=torch.uint8) + + return output, hy, cy, reserve, weight_buf + + +@register_meta(aten.mkldnn_rnn_layer.default) +def mkldnn_rnn_layer( + input, + w0, + w1, + w2, + w3, + hx_, + cx_, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, +): + seq_length = input.shape[1] if batch_first else input.shape[0] + mini_batch = input.shape[0] if batch_first else input.shape[1] + output_chanels = hidden_size + out_shape = ( + [mini_batch, seq_length, output_chanels] + if batch_first + else [seq_length, mini_batch, output_chanels] + ) + output = input.new_empty(out_shape) + if hx_ is None: + hy = torch.empty(0, device=input.device) + else: + hy = hx_.new_empty(hx_.shape) + if cx_ is None: + cy = torch.empty(0, device=input.device) + else: + cy = cx_.new_empty(cx_.shape) + workspace = torch.empty(0, device=input.device, dtype=torch.uint8) + return output, hy, cy, workspace + + +def zero_numel_check_dims(self, dim, fn_name): + if self.ndim == 0: + torch._check_index( + dim == 0 or dim == -1, + lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", + ) + else: + torch._check_index( + self.size(dim) != 0, + lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", + ) + + +# From aten/src/ATen/native/ReduceOps.cpp +def check_argmax_argmin(name, self, dim): + if dim is not None: + dim = maybe_wrap_dim(dim, self.dim()) + zero_numel_check_dims(self, dim, name) + else: + torch._check( + self.numel() != 0, + lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", + ) + + +@register_meta([aten.argmax.default, aten.argmin.default]) +def argmax_argmin_meta(self, dim=None, keepdim=False): + check_argmax_argmin("argmax", self, dim) + dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) + shape = _compute_reduction_shape(self, dims, keepdim) + return self.new_empty(shape, dtype=torch.int64) + + +@register_meta(aten.scalar_tensor.default) +def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.topk.default) +def topk_meta(self, k, dim=-1, largest=True, sorted=True): + # From aten/src/ATen/native/Sorting.cpp + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + torch._check( + k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1), + lambda: "selected index k out of range", + ) + sliceSize = 1 if self.dim() == 0 else self.size(dim) + torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") + + topKSize = list(self.shape) + if len(topKSize) > 0: + topKSize[dim] = k + return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) + + +legacy_contiguous_memory_format = torch.contiguous_format + + +# From aten/src/ATen/native/cuda/RNN.cu +def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace): + defined_grad = 
grad_hy if grad_hy is not None else grad_cy + torch._check(defined_grad.dim() == 2, lambda: "") + exp_size = defined_grad.size() + if grad_hy is not None: + torch._check(grad_hy.size() == exp_size, lambda: "") + if grad_cy is not None: + torch._check(grad_cy.size() == exp_size, lambda: "") + torch._check(cx.size() == exp_size, lambda: "") + torch._check(cy.size() == exp_size, lambda: "") + torch._check(workspace.dim() == 2, lambda: "") + torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "") + + +# From aten/src/ATen/native/cuda/RNN.cu +@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default) +def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias): + if grad_hy is None and grad_cy is None: + return None, None, None + checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace) + grad_gates = torch.empty_like( + workspace, memory_format=legacy_contiguous_memory_format + ) + grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format) + grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None + return grad_gates, grad_cx, grad_bias + + +# From aten/src/ATen/native/mps/operations/Linear.mm +@register_meta(aten.linear_backward.default) +def linear_backward(input_, grad_output_, weight_, output_mask): + grad_input = None + grad_weight = None + grad_bias = None + if output_mask[0]: + grad_input = grad_output_.new_empty(input_.size()) + if output_mask[1] or output_mask[2]: + grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1))) + grad_bias = grad_output_.new_empty(grad_output_.size(-1)) + return (grad_input, grad_weight, grad_bias) + + +@register_meta(aten.pixel_shuffle.default) +def meta_pixel_shuffle(self, upscale_factor): + assert ( + len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0 + ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}" + + def is_channels_last(ten): + return torch._prims_common.suggest_memory_format(ten) == torch.channels_last + + def pick_memory_format(): + if is_channels_last(self): + if device_hint(self) == "cuda": + return torch.contiguous_format + else: + return torch.channels_last + elif self.is_contiguous(memory_format=torch.contiguous_format): + return torch.contiguous_format + elif self.is_contiguous(memory_format=torch.preserve_format): + return torch.preserve_format + + C = self.shape[-3] // (upscale_factor * upscale_factor) + Hr = self.shape[-2] * upscale_factor + Wr = self.shape[-1] * upscale_factor + out_shape = (*self.shape[:-3], C, Hr, Wr) + + out = self.new_empty(out_shape) + out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] + return out + + +@register_meta(aten.mkldnn_rnn_layer_backward.default) +def mkldnn_rnn_layer_backward( + input, + weight0, + weight1, + weight2, + weight3, + hx_, + cx_tmp, + output, + hy_, + cy_, + grad_output_r_opt, + grad_hy_r_opt, + grad_cy_r_opt, + reverse, + mode, + hidden_size, + num_layers, + has_biases, + train, + bidirectional, + batch_sizes, + batch_first, + workspace, +): + diff_x = input.new_empty(input.shape) + diff_hx = hx_.new_empty(hx_.shape) + diff_cx = cx_tmp.new_empty(cx_tmp.shape) + diff_w1 = weight0.new_empty(weight0.shape) + diff_w2 = weight1.new_empty(weight1.shape) + diff_b = weight2.new_empty(weight2.shape) + return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx + + +@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out]) +@out_wrapper() +def meta_bucketize(self, boundaries, 
*, out_int32=False, right=False): + return torch.empty_like( + self, dtype=torch.int32 if out_int32 else torch.int64 + ).contiguous() + + +@register_meta( + [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default] +) +def meta_upsample_bimode2d_aa( + input, output_size, align_corners, scales_h=None, scales_w=None +): + full_output_size = upsample_common_check( + input.size(), output_size, num_spatial_dims=2 + ) + torch._check( + input.numel() != 0 or all(size > 0 for size in input.size()[1:]), + lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", + ) + return input.new_empty(full_output_size).to( + memory_format=utils.suggest_memory_format(input) + ) + + +# From aten/src/ATen/native/cuda/AmpKernels.cu +@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default) +def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): + torch._check( + found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor." + ) + torch._check( + inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor." + ) + torch._check( + found_inf.dtype.is_floating_point, + lambda: "found_inf must be a float tensor.", + ) + torch._check( + inv_scale.dtype.is_floating_point, + lambda: "inv_scale must be a float tensor.", + ) + + +# From aten/src/ATen/native/UnaryOps.cpp +@register_meta([aten.nan_to_num.default, aten.nan_to_num.out]) +@out_wrapper() +def nan_to_num(self, nan=None, posinf=None, neginf=None): + result_size = list(self.size()) + return self.new_empty(result_size) + + +@register_meta(torch.ops.aten.transpose_) +def transpose_(self, dim0, dim1): + assert self.layout not in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" + + ndims = self.ndim + + dim0 = maybe_wrap_dim(dim0, ndims) + dim1 = maybe_wrap_dim(dim1, ndims) + + if dim0 == dim1: + return self + + size = list(self.size()) + stride = list(self.stride()) + + stride[dim0], stride[dim1] = stride[dim1], stride[dim0] + size[dim0], size[dim1] = size[dim1], size[dim0] + + self.as_strided_(size, stride) + return self + + +@register_meta(torch.ops.aten.t_) +def t_(self): + ndims = self.ndim + + if self.is_sparse: + sparse_dim = self.sparse_dim() + dense_dim = self.dense_dim() + assert ( + sparse_dim <= 2 and dense_dim == 0 + ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950 + else: + assert ( + self.dim() <= 2 + ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" + + return transpose_(self, 0, 0 if ndims < 2 else 1) + + +@register_meta(aten.searchsorted) +@out_wrapper() +def meta_searchsorted( + sorted_sequence, self, *, out_int32=False, right=False, side=None, sorter=None +): + dtype = torch.int32 if out_int32 else torch.int64 + if isinstance(self, torch.Tensor): + return torch.empty_like(self, dtype=dtype).contiguous() + else: # Scalar + return torch.empty((), dtype=dtype, device=sorted_sequence.device) + + +def _check_for_unsupported_isin_dtype(dtype): + torch._check( + dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64], + lambda: f"Unsupported input type encountered for isin(): {dtype}", + ) + + +@register_meta(aten.isin) +@out_wrapper() +def meta_isin(elements, test_elements, *, assume_unique=False, invert=False): + torch._check( + isinstance(elements, Tensor) or isinstance(test_elements, Tensor), + 
lambda: "At least one of elements and test_elements must be a Tensor.", + ) + if not isinstance(elements, Tensor): + elements = torch.tensor(elements, device=test_elements.device) + + if not isinstance(test_elements, Tensor): + test_elements = torch.tensor(test_elements, device=elements.device) + + _check_for_unsupported_isin_dtype(elements.dtype) + _check_for_unsupported_isin_dtype(test_elements.dtype) + return torch.empty_like(elements, dtype=torch.bool) + + +@register_meta(aten.polygamma) +@out_wrapper() +def meta_polygamma(n: int, self: Tensor) -> Tensor: + torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.") + _, result_dtype = elementwise_dtypes( + self, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + return torch.empty_like(self, dtype=result_dtype) + + +def _create_unary_float_meta_func(func): + @register_meta(func) + @out_wrapper() + def _f(x): + return elementwise_meta( + x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + return _f + + +def _create_binary_float_meta_func(func): + @register_meta(func) + @out_wrapper() + def _f(x, y): + return elementwise_meta( + x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + return _f + + +_create_unary_float_meta_func(aten.special_airy_ai) +_create_unary_float_meta_func(aten.special_bessel_y0) +_create_unary_float_meta_func(aten.special_bessel_y1) +_create_unary_float_meta_func(aten.special_modified_bessel_i0) +_create_unary_float_meta_func(aten.special_modified_bessel_i1) +_create_unary_float_meta_func(aten.special_modified_bessel_k0) +_create_unary_float_meta_func(aten.special_modified_bessel_k1) +_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0) +_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1) + + +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v) +_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v) +_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w) +_create_binary_float_meta_func(aten.special_hermite_polynomial_h) +_create_binary_float_meta_func(aten.special_hermite_polynomial_he) +_create_binary_float_meta_func(aten.special_laguerre_polynomial_l) +_create_binary_float_meta_func(aten.special_legendre_polynomial_p) + + +# We must also trigger meta registrations from PrimTorch ref +# decompositions +import torch._refs +import torch._refs.nn.functional +import torch._refs.special + + +def activate_meta(): + activate_meta_table = {} + + # For a given op, we pick the most specific decomp function from + # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd + for type in ["meta", "post_autograd", "pre_autograd"]: + registry = global_decomposition_table[type] + + for opo in registry: + if opo not in activate_meta_table: + activate_meta_table[opo] = registry[opo] + + for op_overload, fn in activate_meta_table.items(): + # Don't register meta for HigherOrderOp's decomp. + # We can reconsider this in the future, but in general, + # the way you do a meta for a HigherOrderOp is different from + # OpOverload. 
+ if isinstance(op_overload, torch._ops.HigherOrderOperator): + continue + assert isinstance(op_overload, OpOverload) + + op_overload.py_impl(torch._C.DispatchKey.Meta)(fn) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_overload.name(), "CompositeImplicitAutograd" + ): + # Internally, we shouldn't be registering meta kernels for any operators that + # have CompositeImplicitAutograd kernels. + # Instead, we should be letting those decompositions run, and writing meta kernels + # only for the base operators. + if op_overload in global_decomposition_table["meta"]: + raise RuntimeError( + f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't " + "register meta function for it. Instead, we should let the decomposition run and write " + "meta kernels for the base operators." + ) + pass + elif op_overload.is_view: + # Attempting to register a python meta kernel for a view operator. + # We shouldn't do this, because the output will report as not having aliased storages. + # All view ops have meta kernels in C++ today, so we should use those instead. + pass + elif op_overload.name() in { + "aten::empty_strided", # causing infinite recursion, test_meta.py + "aten::clone", # causing infinite recursion + "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950 + "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950 + "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950 + "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950 + "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950 + }: + pass + else: + if "mkldnn::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn) + elif "mkl::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn) + elif "onednn::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn) + elif "quantized::" in op_overload.name(): + _meta_lib_dont_use_me_use_register_meta_for_quantized.impl( + op_overload, fn + ) + else: + _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn) + + +activate_meta() diff --git a/MLPY/Lib/site-packages/torch/_namedtensor_internals.py b/MLPY/Lib/site-packages/torch/_namedtensor_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..c073e7c12fc8795aa5fa4c375eb91c383a9cd666 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_namedtensor_internals.py @@ -0,0 +1,157 @@ +from collections import OrderedDict + +""" +This file contains helper functions that implement experimental functionality +for named tensors in python. All of these are experimental, unstable, and +subject to change or deletion. +""" + + +def check_serializing_named_tensor(tensor): + if tensor.has_names(): + raise RuntimeError( + "NYI: Named tensors don't support serialization. Please drop " + "names via `tensor = tensor.rename(None)` before serialization." 
+ ) + + +def build_dim_map(tensor): + """Returns a map of { dim: dim_name } where dim is a name if the dim is named + and the dim index otherwise.""" + return OrderedDict( + [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)] + ) + + +def unzip_namedshape(namedshape): + if isinstance(namedshape, OrderedDict): + namedshape = namedshape.items() + if not hasattr(namedshape, "__iter__") and not isinstance(namedshape, tuple): + raise RuntimeError( + f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}" + ) + if len(namedshape) == 0: + raise RuntimeError("Expected namedshape to non-empty.") + return zip(*namedshape) + + +def namer_api_name(inplace): + if inplace: + return "rename_" + else: + return "rename" + + +def is_ellipsis(item): + return item == Ellipsis or item == "..." + + +def single_ellipsis_index(names, fn_name): + ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] + if len(ellipsis_indices) >= 2: + raise RuntimeError( + f"{fn_name}: More than one Ellipsis ('...') found in names (" + f"{names}). This function supports up to one Ellipsis." + ) + if len(ellipsis_indices) == 1: + return ellipsis_indices[0] + return None + + +def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): + return names[numel_pre_glob : len(names) - numel_post_glob] + + +def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names): + globbed_names = expand_single_ellipsis( + ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names + ) + return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :] + + +def resolve_ellipsis(names, tensor_names, fn_name): + """ + Expands ... inside `names` to be equal to a list of names from `tensor_names`. + """ + ellipsis_idx = single_ellipsis_index(names, fn_name) + if ellipsis_idx is None: + return names + return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names) + + +def update_names_with_list(tensor, names, inplace): + # Special case for tensor.rename(None) + if len(names) == 1 and names[0] is None: + return tensor._update_names(None, inplace) + + return tensor._update_names( + resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace + ) + + +def update_names_with_mapping(tensor, rename_map, inplace): + dim_map = build_dim_map(tensor) + for old_dim in rename_map.keys(): + new_dim = rename_map[old_dim] + if old_dim in dim_map.keys(): + dim_map[old_dim] = new_dim + else: + raise RuntimeError( + f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim " + f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist" + ) + return tensor._update_names(tuple(dim_map.values()), inplace) + + +def update_names(tensor, names, rename_map, inplace): + """There are two usages: + + tensor.rename(*names) returns a view on tensor with named dims `names`. + `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`, + then it is expanded greedily to be equal to the corresponding names from + `tensor.names`. + + For example, + ``` + >>> # xdoctest: +SKIP + >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> x.rename('...', 'height', 'width').names + ('N', 'C', 'height', 'width') + + >>> # xdoctest: +SKIP + >>> x.rename('batch', '...', 'width').names + ('batch', 'C', 'H', 'width') + + ``` + + tensor.rename(**rename_map) returns a view on tensor that has rename dims + as specified in the mapping `rename_map`. 
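+    Every key of `rename_map` must be an existing dimension name in
+    `tensor.names`; attempting to rename a dimension that does not exist
+    raises a RuntimeError.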
+ + For example, + ``` + >>> # xdoctest: +SKIP + >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> x.rename(W='width', H='height').names + ('N', 'C', 'height', 'width') + + ``` + + Finally, tensor.rename has an in-place version called tensor.rename_. + """ + has_names = len(names) > 0 + has_rename_pairs = bool(rename_map) + if has_names and has_rename_pairs: + raise RuntimeError( + f"{namer_api_name(inplace)}: This function takes either positional " + f"args or keyword args, but not both. Use tensor.{namer_api_name(inplace)}(*names) " + f"to name dims and tensor.{namer_api_name(inplace)}(**rename_map) to rename " + "dims." + ) + + # Special case for tensor.rename(*[]), which is valid for a 0 dim tensor. + if not has_names and not has_rename_pairs: + return update_names_with_list(tensor, names, inplace) + + if has_names: + return update_names_with_list(tensor, names, inplace) + return update_names_with_mapping(tensor, rename_map, inplace) diff --git a/MLPY/Lib/site-packages/torch/_numpy/__init__.py b/MLPY/Lib/site-packages/torch/_numpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd4b4e7f7481fc5023cbfe62f10a7881e5257fa --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/__init__.py @@ -0,0 +1,30 @@ +# mypy: ignore-errors + +from . import fft, linalg, random +from ._dtypes import * # noqa: F403 +from ._funcs import * # noqa: F403 +from ._getlimits import finfo, iinfo +from ._ndarray import ( + array, + asarray, + ascontiguousarray, + can_cast, + from_dlpack, + ndarray, + newaxis, + result_type, +) +from ._ufuncs import * # noqa: F403 +from ._util import AxisError, UFuncTypeError + +# from . import testing + +alltrue = all +sometrue = any + +inf = float("inf") +nan = float("nan") +from math import pi, e # isort: skip + +False_ = False +True_ = True diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0267877707d8803f23d1817ce9f4eccc6064a36 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..084c7755d6277bb907a698597f3e3178f2e43c96 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f322d82433813988ad140c1052ce97772c01b21 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dbc2c16828f670246508d56b57db107afad6930 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc 
b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab17e2c8921d38e11a130ffae9c1f60fac334e77 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afe5c2d6fbc24545f1d15db5c720854d594d2a38 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b423dadbb266156725ef5062425e7dfdd38ac561 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dbd2d6f0643875b6ae053d70144e98da41a4ef6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efb9a2073b550fee22ca704e1e3535f21c4e58ba Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a11705d885176995c08a6e4d9fb08531747d6f0d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5a8ac5c96b085427e414ce26dba2dffab409376 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33f85bef0d85ba93a1306972ac71da3a7b62703a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..518ce42ce37d5e38009f8c5fd2b143e5de206d4e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..869e7f810ad8d103128fc928d9b89e73a10a9152 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..320fabb3fdf81a33a6adfbd0a1fc9ab9e41e9ec6 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6bfe97ec013d44375263d02357b5c6c85f184a Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5f63a73bac41ac2049c1a591221eb822c9940af Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py b/MLPY/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..1636dfe34b3cbff5fb5a0be400935a87706d4dab --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py @@ -0,0 +1,86 @@ +# mypy: ignore-errors + +"""Export torch work functions for binary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module. +""" + +import torch + +from torch import ( # noqa: F401 + add, # noqa: F401 + arctan2, # noqa: F401 + bitwise_and, # noqa: F401 + bitwise_left_shift as left_shift, # noqa: F401 + bitwise_or, # noqa: F401 + bitwise_right_shift as right_shift, # noqa: F401 + bitwise_xor, # noqa: F401 + copysign, # noqa: F401 + divide, # noqa: F401 + eq as equal, # noqa: F401 + float_power, # noqa: F401 + floor_divide, # noqa: F401 + fmax, # noqa: F401 + fmin, # noqa: F401 + fmod, # noqa: F401 + gcd, # noqa: F401 + greater, # noqa: F401 + greater_equal, # noqa: F401 + heaviside, # noqa: F401 + hypot, # noqa: F401 + lcm, # noqa: F401 + ldexp, # noqa: F401 + less, # noqa: F401 + less_equal, # noqa: F401 + logaddexp, # noqa: F401 + logaddexp2, # noqa: F401 + logical_and, # noqa: F401 + logical_or, # noqa: F401 + logical_xor, # noqa: F401 + maximum, # noqa: F401 + minimum, # noqa: F401 + multiply, # noqa: F401 + nextafter, # noqa: F401 + not_equal, # noqa: F401 + pow as power, # noqa: F401 + remainder, # noqa: F401 + remainder as mod, # noqa: F401 + subtract, # noqa: F401 + true_divide, # noqa: F401 +) + +from . import _dtypes_impl, _util + + +# work around torch limitations w.r.t. 
numpy +def matmul(x, y): + # work around: + # - RuntimeError: expected scalar type Int but found Double + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool' + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + dtype = _dtypes_impl.result_type_impl(x, y) + is_bool = dtype == torch.bool + is_half = (x.dtype == torch.float16 or y.dtype == torch.float16) and ( + x.is_cpu or y.is_cpu + ) + + work_dtype = dtype + if is_bool: + work_dtype = torch.uint8 + if is_half: + work_dtype = torch.float32 + + x = _util.cast_if_needed(x, work_dtype) + y = _util.cast_if_needed(y, work_dtype) + + result = torch.matmul(x, y) + + if work_dtype != dtype: + result = result.to(dtype) + + return result + + +# a stub implementation of divmod, should be improved after +# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch +def divmod(x, y): + return x // y, x % y diff --git a/MLPY/Lib/site-packages/torch/_numpy/_casting_dicts.py b/MLPY/Lib/site-packages/torch/_numpy/_casting_dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4176512321f696b02e9d77336e1cb3c769c2b4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_casting_dicts.py @@ -0,0 +1,881 @@ +# mypy: ignore-errors + +import torch + +# These two dicts are autogenerated with autogen/gen_dtypes.py, +# using numpy version 1.23.5. + +_can_cast_dict = { + "no": { + torch.float16: { + torch.float16: True, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: True, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: True, + torch.int32: 
False, + torch.int64: False, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: False, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: True, + }, + }, + "equiv": { + torch.float16: { + torch.float16: True, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: True, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: True, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: False, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: 
False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: True, + }, + }, + "safe": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: False, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, + "same_kind": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + 
torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, + "unsafe": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.float32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.float64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + 
torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.complex64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.complex128: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, +} + + +_result_type_dict = { + torch.float16: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.float16, + torch.int8: torch.float16, + torch.int16: torch.float32, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float16, + }, + torch.float32: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.float32, + torch.int8: torch.float32, + torch.int16: torch.float32, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float32, + }, + torch.float64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.float64, + torch.int8: torch.float64, + torch.int16: torch.float64, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float64, + }, + torch.complex64: { + torch.float16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: 
torch.complex128, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.complex64, + torch.int8: torch.complex64, + torch.int16: torch.complex64, + torch.int32: torch.complex128, + torch.int64: torch.complex128, + torch.bool: torch.complex64, + }, + torch.complex128: { + torch.float16: torch.complex128, + torch.float32: torch.complex128, + torch.float64: torch.complex128, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.complex128, + torch.int8: torch.complex128, + torch.int16: torch.complex128, + torch.int32: torch.complex128, + torch.int64: torch.complex128, + torch.bool: torch.complex128, + }, + torch.uint8: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint8, + torch.int8: torch.int16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.uint8, + }, + torch.int8: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.int16, + torch.int8: torch.int8, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int8, + }, + torch.int16: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.int16, + torch.int8: torch.int16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int16, + }, + torch.int32: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.int32, + torch.int8: torch.int32, + torch.int16: torch.int32, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int32, + }, + torch.int64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.int64, + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bool: torch.int64, + }, + torch.bool: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.bool, + }, +} diff --git a/MLPY/Lib/site-packages/torch/_numpy/_dtypes.py b/MLPY/Lib/site-packages/torch/_numpy/_dtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..4195f7a32d010ff91b5aa6ae80be589673f06ae4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_dtypes.py @@ -0,0 +1,434 @@ +# mypy: ignore-errors + +""" Define analogs of numpy dtypes supported by pytorch. +Define the scalar types and supported dtypes and numpy <--> torch dtype mappings. +""" +import builtins + +import torch + +from . 
import _dtypes_impl + + +# ### Scalar types ### + + +class generic: + name = "generic" + + def __new__(cls, value): + # NumPy scalars are modelled as 0-D arrays + # so a call to np.float32(4) produces a 0-D array. + + from ._ndarray import asarray, ndarray + + if isinstance(value, str) and value in ["inf", "nan"]: + value = {"inf": torch.inf, "nan": torch.nan}[value] + + if isinstance(value, ndarray): + return value.astype(cls) + else: + return asarray(value, dtype=cls) + + +################## +# abstract types # +################## + + +class number(generic): + name = "number" + + +class integer(number): + name = "integer" + + +class inexact(number): + name = "inexact" + + +class signedinteger(integer): + name = "signedinteger" + + +class unsignedinteger(integer): + name = "unsignedinteger" + + +class floating(inexact): + name = "floating" + + +class complexfloating(inexact): + name = "complexfloating" + + +_abstract_dtypes = [ + "generic", + "number", + "integer", + "signedinteger", + "unsignedinteger", + "inexact", + "floating", + "complexfloating", +] + +# ##### concrete types + +# signed integers + + +class int8(signedinteger): + name = "int8" + typecode = "b" + torch_dtype = torch.int8 + + +class int16(signedinteger): + name = "int16" + typecode = "h" + torch_dtype = torch.int16 + + +class int32(signedinteger): + name = "int32" + typecode = "i" + torch_dtype = torch.int32 + + +class int64(signedinteger): + name = "int64" + typecode = "l" + torch_dtype = torch.int64 + + +# unsigned integers + + +class uint8(unsignedinteger): + name = "uint8" + typecode = "B" + torch_dtype = torch.uint8 + + +# floating point + + +class float16(floating): + name = "float16" + typecode = "e" + torch_dtype = torch.float16 + + +class float32(floating): + name = "float32" + typecode = "f" + torch_dtype = torch.float32 + + +class float64(floating): + name = "float64" + typecode = "d" + torch_dtype = torch.float64 + + +class complex64(complexfloating): + name = "complex64" + typecode = "F" + torch_dtype = torch.complex64 + + +class complex128(complexfloating): + name = "complex128" + typecode = "D" + torch_dtype = torch.complex128 + + +class bool_(generic): + name = "bool_" + typecode = "?" + torch_dtype = torch.bool + + +# name aliases +_name_aliases = { + "intp": int64, + "int_": int64, + "intc": int32, + "byte": int8, + "short": int16, + "longlong": int64, # XXX: is this correct? + "ubyte": uint8, + "half": float16, + "single": float32, + "double": float64, + "float_": float64, + "csingle": complex64, + "singlecomplex": complex64, + "cdouble": complex128, + "cfloat": complex128, + "complex_": complex128, +} +# We register float_ = float32 and so on +for name, obj in _name_aliases.items(): + vars()[name] = obj + + +# Replicate this NumPy-defined way of grouping scalar types, +# cf tests/core/test_scalar_methods.py +sctypes = { + "int": [int8, int16, int32, int64], + "uint": [uint8], + "float": [float16, float32, float64], + "complex": [complex64, complex128], + "others": [bool_], +} + + +# Support mappings/functions + +_names = {st.name: st for cat in sctypes for st in sctypes[cat]} +_typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]} +_torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]} + + +_aliases = { + "u1": uint8, + "i1": int8, + "i2": int16, + "i4": int32, + "i8": int64, + "b": int8, # XXX: srsly? 
+ "f2": float16, + "f4": float32, + "f8": float64, + "c8": complex64, + "c16": complex128, + # numpy-specific trailing underscore + "bool_": bool_, +} + + +_python_types = { + int: int64, + float: float64, + complex: complex128, + builtins.bool: bool_, + # also allow stringified names of python types + int.__name__: int64, + float.__name__: float64, + complex.__name__: complex128, + builtins.bool.__name__: bool_, +} + + +def sctype_from_string(s): + """Normalize a string value: a type 'name' or a typecode or a width alias.""" + if s in _names: + return _names[s] + if s in _name_aliases.keys(): + return _name_aliases[s] + if s in _typecodes: + return _typecodes[s] + if s in _aliases: + return _aliases[s] + if s in _python_types: + return _python_types[s] + raise TypeError(f"data type {s!r} not understood") + + +def sctype_from_torch_dtype(torch_dtype): + return _torch_dtypes[torch_dtype] + + +# ### DTypes. ### + + +def dtype(arg): + if arg is None: + arg = _dtypes_impl.default_dtypes().float_dtype + return DType(arg) + + +class DType: + def __init__(self, arg): + # a pytorch object? + if isinstance(arg, torch.dtype): + sctype = _torch_dtypes[arg] + elif isinstance(arg, torch.Tensor): + sctype = _torch_dtypes[arg.dtype] + # a scalar type? + elif issubclass_(arg, generic): + sctype = arg + # a dtype already? + elif isinstance(arg, DType): + sctype = arg._scalar_type + # a has a right attribute? + elif hasattr(arg, "dtype"): + sctype = arg.dtype._scalar_type + else: + sctype = sctype_from_string(arg) + self._scalar_type = sctype + + @property + def name(self): + return self._scalar_type.name + + @property + def type(self): + return self._scalar_type + + @property + def kind(self): + # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html + return _torch_dtypes[self.torch_dtype].name[0] + + @property + def typecode(self): + return self._scalar_type.typecode + + def __eq__(self, other): + if isinstance(other, DType): + return self._scalar_type == other._scalar_type + try: + other_instance = DType(other) + except TypeError: + return False + return self._scalar_type == other_instance._scalar_type + + @property + def torch_dtype(self): + return self._scalar_type.torch_dtype + + def __hash__(self): + return hash(self._scalar_type.name) + + def __repr__(self): + return f'dtype("{self.name}")' + + __str__ = __repr__ + + @property + def itemsize(self): + elem = self.type(1) + return elem.tensor.element_size() + + def __getstate__(self): + return self._scalar_type + + def __setstate__(self, value): + self._scalar_type = value + + +typecodes = { + "All": "efdFDBbhil?", + "AllFloat": "efdFD", + "AllInteger": "Bbhil", + "Integer": "bhil", + "UnsignedInteger": "B", + "Float": "efd", + "Complex": "FD", +} + + +# ### Defaults and dtype discovery + + +def set_default_dtype(fp_dtype="numpy", int_dtype="numpy"): + """Set the (global) defaults for fp, complex, and int dtypes. + + The complex dtype is inferred from the float (fp) dtype. It has + a width at least twice the width of the float dtype, + i.e., it's complex128 for float64 and complex64 for float32. + + Parameters + ---------- + fp_dtype + Allowed values are "numpy", "pytorch" or dtype_like things which + can be converted into a DType instance. + Default is "numpy" (i.e. float64). + int_dtype + Allowed values are "numpy", "pytorch" or dtype_like things which + can be converted into a DType instance. + Default is "numpy" (i.e. int64). 
+
+    Returns
+    -------
+    The old default dtype state: a namedtuple with attributes ``float_dtype``,
+    ``complex_dtype`` and ``int_dtype``. These attributes store *pytorch*
+    dtypes.
+
+    Notes
+    -----
+    This function has a side effect: it sets the global state with the provided dtypes.
+
+    The complex dtype has a bit width of at least twice the width of the float
+    dtype, i.e. it's complex128 for float64 and complex64 for float32.
+
+    """
+    if fp_dtype not in ["numpy", "pytorch"]:
+        fp_dtype = dtype(fp_dtype).torch_dtype
+    if int_dtype not in ["numpy", "pytorch"]:
+        int_dtype = dtype(int_dtype).torch_dtype
+
+    if fp_dtype == "numpy":
+        float_dtype = torch.float64
+    elif fp_dtype == "pytorch":
+        float_dtype = torch.float32
+    else:
+        float_dtype = fp_dtype
+
+    complex_dtype = {
+        torch.float64: torch.complex128,
+        torch.float32: torch.complex64,
+        torch.float16: torch.complex64,
+    }[float_dtype]
+
+    if int_dtype in ["numpy", "pytorch"]:
+        int_dtype = torch.int64
+    else:
+        int_dtype = int_dtype
+
+    new_defaults = _dtypes_impl.DefaultDTypes(
+        float_dtype=float_dtype, complex_dtype=complex_dtype, int_dtype=int_dtype
+    )
+
+    # set the new global state and return the old state
+    old_defaults = _dtypes_impl.default_dtypes
+    _dtypes_impl._default_dtypes = new_defaults
+    return old_defaults
+
+
+def issubclass_(arg, klass):
+    try:
+        return issubclass(arg, klass)
+    except TypeError:
+        return False
+
+
+def issubdtype(arg1, arg2):
+    # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
+
+    # We also accept strings even if NumPy doesn't as dtypes are serialized as their
+    # string representation in dynamo's graph
+    def str_to_abstract(t):
+        if isinstance(t, str) and t in _abstract_dtypes:
+            return globals()[t]
+        return t
+
+    arg1 = str_to_abstract(arg1)
+    arg2 = str_to_abstract(arg2)
+
+    if not issubclass_(arg1, generic):
+        arg1 = dtype(arg1).type
+    if not issubclass_(arg2, generic):
+        arg2 = dtype(arg2).type
+    return issubclass(arg1, arg2)
+
+
+__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"]
+__all__ += list(_names.keys())  # noqa: PLE0605
+__all__ += list(_name_aliases.keys())  # noqa: PLE0605
+__all__ += _abstract_dtypes  # noqa: PLE0605
diff --git a/MLPY/Lib/site-packages/torch/_numpy/_dtypes_impl.py b/MLPY/Lib/site-packages/torch/_numpy/_dtypes_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..834e585c89e719f93460075fbbc2aaeca4b487a0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/_numpy/_dtypes_impl.py
@@ -0,0 +1,216 @@
+# mypy: ignore-errors
+
+"""Dtypes/scalar type implementations with torch dtypes.
+
+Here `dtype` is always a torch.dtype, this module knows nothing about
+scalar types, wrapper dtypes or anything like that. PyTorch only.
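+
+A rough usage sketch of the promotion helper defined below (behavior follows
+the autogenerated tables in `_casting_dicts.py`):
+
+    >>> # xdoctest: +SKIP
+    >>> result_type_impl(torch.ones(2, dtype=torch.int8), torch.ones(2, dtype=torch.uint8))
+    torch.int16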
+""" +from collections import namedtuple + +import torch + +# defaults : mimic NumPy, allow user control +DefaultDTypes = namedtuple( + "DefaultDTypes", ["float_dtype", "complex_dtype", "int_dtype"] +) + +# a global state +# We set it the first time we call default_dtypes() to avoid importing +# torch._dynamo.config and create a circular reference +_default_dtypes = None + + +def default_dtypes(): + global _default_dtypes + if _default_dtypes is None: + import torch._dynamo.config as config + + _default_dtypes = DefaultDTypes( + float_dtype=getattr(torch, config.numpy_default_float), + complex_dtype=getattr(torch, config.numpy_default_complex), + int_dtype=getattr(torch, config.numpy_default_int), + ) + assert isinstance(_default_dtypes.float_dtype, torch.dtype) + assert isinstance(_default_dtypes.complex_dtype, torch.dtype) + assert isinstance(_default_dtypes.int_dtype, torch.dtype) + return _default_dtypes + + +def get_default_dtype_for(dtype): + """Default scalar type given sctype category.""" + if dtype == torch.bool: + return dtype + if dtype.is_complex: + return default_dtypes().complex_dtype + if dtype.is_floating_point: + return default_dtypes().float_dtype + # else, it must be (some) integer + return default_dtypes().int_dtype + + +from . import _casting_dicts as _cd + + +def can_cast_impl(from_torch_dtype, to_torch_dtype, casting): + return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype] + + +def result_type_impl(*tensors): + # NB: torch dtypes here + dtyp = tensors[0].dtype + if len(tensors) == 1: + return dtyp + + for curr in tensors[1:]: + dtyp = _cd._result_type_dict[dtyp][curr.dtype] + + return dtyp + + +def python_type_for_torch(dtyp): + """Get a python scalar type a torch dtype""" + if dtyp.is_floating_point: + typ = float + elif dtyp.is_complex: + typ = complex + elif dtyp == torch.bool: + typ = bool + else: + typ = int + return typ + + +# ### NEP 50 helpers ### + +_SCALAR_TYPES = (int, bool, float, complex) + +_SCALAR_AND_SYMBOLIC_TYPES = ( + *_SCALAR_TYPES, + torch.SymInt, + torch.SymFloat, + torch.SymBool, +) + +_NEP50_FUNCS_TENSOR_ONLY = ( + "minimum", + "maximum", + "logaddexp", + "logaddexp2", + "lcm", + "gcd", + "hypot", + "heaviside", + "fmod", + "fmin", + "fmax", + "copysign", + "arctan2", +) + + +def is_scalar(x): + return isinstance(x, _SCALAR_TYPES) + + +def is_scalar_or_symbolic(x): + return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES) + + +def _dtype_for_scalar(py_type): + return { + bool: torch.bool, + torch.SymBool: torch.bool, + int: torch.int64, + torch.SymInt: torch.int64, + float: torch.float64, + torch.SymFloat: torch.float64, + complex: torch.complex128, + }[py_type] + + +def _dtype_for_scalar_or_tensor(x): + return x.dtype if isinstance(x, torch.Tensor) else _dtype_for_scalar(type(x)) + + +def is_float_or_fp_tensor(x): + return _dtype_for_scalar_or_tensor(x).is_floating_point + + +def is_complex_or_complex_tensor(x): + return _dtype_for_scalar_or_tensor(x).is_complex + + +def _category(dtype): + return { + torch.bool: 0, + torch.SymBool: 0, + # int + torch.uint8: 1, + torch.int8: 1, + torch.int16: 1, + torch.int32: 1, + torch.int64: 1, + torch.SymInt: 1, + # float + torch.float16: 2, + torch.float32: 2, + torch.float64: 2, + torch.SymFloat: 2, + # complex + torch.complex64: 3, + torch.complex128: 3, + }[dtype] + + +def nep50_to_tensors(x1, x2, handle_weaks, function_name): + """If either of inputs is a python scalar, type-promote with NEP 50.""" + + def to_tensor(scalar, dtype=None): + if dtype is None: + dtype = 
_dtype_for_scalar(type(scalar)) + dtype = get_default_dtype_for(dtype) + return torch.as_tensor(scalar, dtype=dtype) + + x1_is_weak = not isinstance(x1, torch.Tensor) + x2_is_weak = not isinstance(x2, torch.Tensor) + if not handle_weaks or (x1_is_weak and x2_is_weak): + x1 = to_tensor(x1) if x1_is_weak else x1 + x2 = to_tensor(x2) if x2_is_weak else x2 + return x1, x2 + + # scalar tensor: NEP 50 + assert x1_is_weak != x2_is_weak + + weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1) + + # find the dtype for the weak's type + weak_dtype = _dtype_for_scalar(type(weak)) + + cat_weak = _category(weak_dtype) + cat_not_weak = _category(not_weak.dtype) + + dt = not_weak.dtype if cat_weak <= cat_not_weak else None + + # special-case complex + float32 + if weak_dtype.is_complex and not_weak.dtype == torch.float32: + dt = torch.complex64 + + # detect overflows: in PyTorch, uint8(-1) wraps around to 255, + # while NEP50 mandates an exception. + # + # Note that we only check if each element of the binop overflows, + # not the result. Consider, e.g. `uint8(100) + 200`. Operands are OK + # in uint8, but the result overflows and wrap around 255. + # Numpy emits a RuntimeWarning, PyTorch does not, and we do not either. + if cat_weak == 1 and cat_not_weak == 1: + # integers + iinfo = torch.iinfo(not_weak.dtype) + if not (iinfo.min <= weak <= iinfo.max): + raise OverflowError( + f"Python integer {weak} out of bounds for {not_weak.dtype}" + ) + if weak_dtype != dt or function_name in _NEP50_FUNCS_TENSOR_ONLY: + # finally, can make `weak` into a 0D tensor, if both parameters are required to be tensor. + weak = to_tensor(weak, dt) + + return (weak, not_weak) if x1_is_weak else (not_weak, weak) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_funcs.py b/MLPY/Lib/site-packages/torch/_numpy/_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d46fef08b33c8c4fd826a8bc22e4242893645d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_funcs.py @@ -0,0 +1,75 @@ +# mypy: ignore-errors + +import inspect +import itertools + +from . import _funcs_impl, _reductions_impl +from ._normalizations import normalizer + +# _funcs_impl.py contains functions which mimic NumPy's eponymous equivalents, +# and consume/return PyTorch tensors/dtypes. +# They are also type annotated. +# Pull these functions from _funcs_impl and decorate them with @normalizer, which +# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`. 
+# - Maps NumPy dtypes to PyTorch dtypes +# - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple +# - Implements the semantics for the `out=` arg +# - Wraps back the outputs into `torch._numpy.ndarrays` + + +def _public_functions(mod): + def is_public_function(f): + return inspect.isfunction(f) and not f.__name__.startswith("_") + + return inspect.getmembers(mod, is_public_function) + + +# We fill in __all__ in the loop below +__all__ = [] + +# decorate implementer functions with argument normalizers and export to the top namespace +for name, func in itertools.chain( + _public_functions(_funcs_impl), _public_functions(_reductions_impl) +): + if name in ["percentile", "quantile", "median"]: + decorated = normalizer(func, promote_scalar_result=True) + elif name == "einsum": + # normalized manually + decorated = func + else: + decorated = normalizer(func) + + decorated.__qualname__ = name + decorated.__name__ = name + vars()[name] = decorated + __all__.append(name) + + +""" +Vendored objects from numpy.lib.index_tricks +""" + + +class IndexExpression: + """ + Written by Konrad Hinsen + last revision: 1999-7-23 + + Cosmetic changes by T. Oliphant 2001 + """ + + def __init__(self, maketuple): + self.maketuple = maketuple + + def __getitem__(self, item): + if self.maketuple and not isinstance(item, tuple): + return (item,) + else: + return item + + +index_exp = IndexExpression(maketuple=True) +s_ = IndexExpression(maketuple=False) + + +__all__ += ["index_exp", "s_"] diff --git a/MLPY/Lib/site-packages/torch/_numpy/_funcs_impl.py b/MLPY/Lib/site-packages/torch/_numpy/_funcs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..d470076ee51163c7c1d26ce1bfc3d6139f9f6fa0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_funcs_impl.py @@ -0,0 +1,2053 @@ +# mypy: ignore-errors + +"""A thin pytorch / numpy compat layer. + +Things imported from here have numpy-compatible signatures but operate on +pytorch tensors. +""" +# Contents of this module ends up in the main namespace via _funcs.py +# where type annotations are used in conjunction with the @normalizer decorator. +from __future__ import annotations + +import builtins +import itertools +import operator +from typing import Optional, Sequence + +import torch + +from . 
import _dtypes_impl, _util +from ._normalizations import ( + ArrayLike, + ArrayLikeOrScalar, + CastingModes, + DTypeLike, + NDArray, + NotImplementedType, + OutArray, +) + + +def copy( + a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False +): + return a.clone() + + +def copyto( + dst: NDArray, + src: ArrayLike, + casting: Optional[CastingModes] = "same_kind", + where: NotImplementedType = None, +): + (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting) + dst.copy_(src) + + +def atleast_1d(*arys: ArrayLike): + res = torch.atleast_1d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def atleast_2d(*arys: ArrayLike): + res = torch.atleast_2d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def atleast_3d(*arys: ArrayLike): + res = torch.atleast_3d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def _concat_check(tup, dtype, out): + if tup == (): + raise ValueError("need at least one array to concatenate") + + """Check inputs in concatenate et al.""" + if out is not None and dtype is not None: + # mimic numpy + raise TypeError( + "concatenate() only takes `out` or `dtype` as an " + "argument, but both were provided." + ) + + +def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): + """Figure out dtypes, cast if necessary.""" + + if out is not None or dtype is not None: + # figure out the type of the inputs and outputs + out_dtype = out.dtype.torch_dtype if dtype is None else dtype + else: + out_dtype = _dtypes_impl.result_type_impl(*tensors) + + # cast input arrays if necessary; do not broadcast them agains `out` + tensors = _util.typecast_tensors(tensors, out_dtype, casting) + + return tensors + + +def _concatenate( + tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind" +): + # pure torch implementation, used below and in cov/corrcoef below + tensors, axis = _util.axis_none_flatten(*tensors, axis=axis) + tensors = _concat_cast_helper(tensors, out, dtype, casting) + return torch.cat(tensors, axis) + + +def concatenate( + ar_tuple: Sequence[ArrayLike], + axis=0, + out: Optional[OutArray] = None, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(ar_tuple, dtype, out=out) + result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting) + return result + + +def vstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.vstack(tensors) + + +row_stack = vstack + + +def hstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.hstack(tensors) + + +def dstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + # XXX: in numpy 1.24 dstack does not have dtype and casting keywords + # but {h,v}stack do. Hence add them here for consistency. 
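+    # A rough sketch of the extra keywords (assuming `import torch._numpy as np`):
+    #   >>> # xdoctest: +SKIP
+    #   >>> np.dstack((np.ones(3, dtype="int16"), np.zeros(3)), dtype="float32").dtype
+    #   dtype("float32")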
+ _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.dstack(tensors) + + +def column_stack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords + # but row_stack does. (because row_stack is an alias for vstack, really). + # Hence add these keywords here for consistency. + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.column_stack(tensors) + + +def stack( + arrays: Sequence[ArrayLike], + axis=0, + out: Optional[OutArray] = None, + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(arrays, dtype, out=out) + + tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting) + result_ndim = tensors[0].ndim + 1 + axis = _util.normalize_axis_index(axis, result_ndim) + return torch.stack(tensors, axis=axis) + + +def append(arr: ArrayLike, values: ArrayLike, axis=None): + if axis is None: + if arr.ndim != 1: + arr = arr.flatten() + values = values.flatten() + axis = arr.ndim - 1 + return _concatenate((arr, values), axis=axis) + + +# ### split ### + + +def _split_helper(tensor, indices_or_sections, axis, strict=False): + if isinstance(indices_or_sections, int): + return _split_helper_int(tensor, indices_or_sections, axis, strict) + elif isinstance(indices_or_sections, (list, tuple)): + # NB: drop split=..., it only applies to split_helper_int + return _split_helper_list(tensor, list(indices_or_sections), axis) + else: + raise TypeError("split_helper: ", type(indices_or_sections)) + + +def _split_helper_int(tensor, indices_or_sections, axis, strict=False): + if not isinstance(indices_or_sections, int): + raise NotImplementedError("split: indices_or_sections") + + axis = _util.normalize_axis_index(axis, tensor.ndim) + + # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n + l, n = tensor.shape[axis], indices_or_sections + + if n <= 0: + raise ValueError() + + if l % n == 0: + num, sz = n, l // n + lst = [sz] * num + else: + if strict: + raise ValueError("array split does not result in an equal division") + + num, sz = l % n, l // n + 1 + lst = [sz] * num + + lst += [sz - 1] * (n - num) + + return torch.split(tensor, lst, axis) + + +def _split_helper_list(tensor, indices_or_sections, axis): + if not isinstance(indices_or_sections, list): + raise NotImplementedError("split: indices_or_sections: list") + # numpy expects indices, while torch expects lengths of sections + # also, numpy appends zero-size arrays for indices above the shape[axis] + lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] + num_extra = len(indices_or_sections) - len(lst) + + lst.append(tensor.shape[axis]) + lst = [ + lst[0], + ] + [a - b for a, b in zip(lst[1:], lst[:-1])] + lst += [0] * num_extra + + return torch.split(tensor, lst, axis) + + +def array_split(ary: ArrayLike, indices_or_sections, axis=0): + return _split_helper(ary, indices_or_sections, axis) + + +def split(ary: ArrayLike, indices_or_sections, axis=0): + return _split_helper(ary, indices_or_sections, axis, strict=True) + + +def hsplit(ary: ArrayLike, indices_or_sections): + if ary.ndim == 0: + raise ValueError("hsplit only works on arrays of 1 or more dimensions") + axis = 1 if ary.ndim > 1 else 0 + return _split_helper(ary, indices_or_sections, axis, strict=True) + + +def vsplit(ary: ArrayLike, 
indices_or_sections): + if ary.ndim < 2: + raise ValueError("vsplit only works on arrays of 2 or more dimensions") + return _split_helper(ary, indices_or_sections, 0, strict=True) + + +def dsplit(ary: ArrayLike, indices_or_sections): + if ary.ndim < 3: + raise ValueError("dsplit only works on arrays of 3 or more dimensions") + return _split_helper(ary, indices_or_sections, 2, strict=True) + + +def kron(a: ArrayLike, b: ArrayLike): + return torch.kron(a, b) + + +def vander(x: ArrayLike, N=None, increasing=False): + return torch.vander(x, N, increasing) + + +# ### linspace, geomspace, logspace and arange ### + + +def linspace( + start: ArrayLike, + stop: ArrayLike, + num=50, + endpoint=True, + retstep=False, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or retstep or not endpoint: + raise NotImplementedError + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + # XXX: raises TypeError if start or stop are not scalars + return torch.linspace(start, stop, num, dtype=dtype) + + +def geomspace( + start: ArrayLike, + stop: ArrayLike, + num=50, + endpoint=True, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or not endpoint: + raise NotImplementedError + base = torch.pow(stop / start, 1.0 / (num - 1)) + logbase = torch.log(base) + return torch.logspace( + torch.log(start) / logbase, + torch.log(stop) / logbase, + num, + base=base, + ) + + +def logspace( + start, + stop, + num=50, + endpoint=True, + base=10.0, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or not endpoint: + raise NotImplementedError + return torch.logspace(start, stop, num, base=base, dtype=dtype) + + +def arange( + start: Optional[ArrayLikeOrScalar] = None, + stop: Optional[ArrayLikeOrScalar] = None, + step: Optional[ArrayLikeOrScalar] = 1, + dtype: Optional[DTypeLike] = None, + *, + like: NotImplementedType = None, +): + if step == 0: + raise ZeroDivisionError + if stop is None and start is None: + raise TypeError + if stop is None: + # XXX: this breaks if start is passed as a kwarg: + # arange(start=4) should raise (no stop) but doesn't + start, stop = 0, start + if start is None: + start = 0 + + # the dtype of the result + if dtype is None: + dtype = ( + _dtypes_impl.default_dtypes().float_dtype + if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step)) + else _dtypes_impl.default_dtypes().int_dtype + ) + work_dtype = torch.float64 if dtype.is_complex else dtype + + # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager. + if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)): + raise NotImplementedError + + if (step > 0 and start > stop) or (step < 0 and start < stop): + # empty range + return torch.empty(0, dtype=dtype) + + result = torch.arange(start, stop, step, dtype=work_dtype) + result = _util.cast_if_needed(result, dtype) + return result + + +# ### zeros/ones/empty/full ### + + +def empty( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.empty(shape, dtype=dtype) + + +# NB: *_like functions deliberately deviate from numpy: it has subok=True +# as the default; we set subok=False and raise on anything else. 
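+#
+# Illustrative consequence (assuming the normalizer wrapping applied in _funcs.py):
+#   empty_like(x)              -> works; subclassing is simply ignored
+#   empty_like(x, subok=True)  -> expected to raise NotImplementedError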
+ + +def empty_like( + prototype: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.empty_like(prototype, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def full( + shape, + fill_value: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if isinstance(shape, int): + shape = (shape,) + if dtype is None: + dtype = fill_value.dtype + if not isinstance(shape, (tuple, list)): + shape = (shape,) + return torch.full(shape, fill_value, dtype=dtype) + + +def full_like( + a: ArrayLike, + fill_value, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + # XXX: fill_value broadcasts + result = torch.full_like(a, fill_value, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def ones( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.ones(shape, dtype=dtype) + + +def ones_like( + a: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.ones_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def zeros( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.zeros(shape, dtype=dtype) + + +def zeros_like( + a: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.zeros_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +# ### cov & corrcoef ### + + +def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True): + """Prepare inputs for cov and corrcoef.""" + + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 + if y_tensor is not None: + # make sure x and y are at least 2D + ndim_extra = 2 - x_tensor.ndim + if ndim_extra > 0: + x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape) + if not rowvar and x_tensor.shape[0] != 1: + x_tensor = x_tensor.mT + x_tensor = x_tensor.clone() + + ndim_extra = 2 - y_tensor.ndim + if ndim_extra > 0: + y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape) + if not rowvar and y_tensor.shape[0] != 1: + y_tensor = y_tensor.mT + y_tensor = y_tensor.clone() + + x_tensor = _concatenate((x_tensor, y_tensor), axis=0) + + return x_tensor + + +def corrcoef( + x: ArrayLike, + y: Optional[ArrayLike] = None, + rowvar=True, + bias=None, + ddof=None, + *, + dtype: Optional[DTypeLike] = None, +): + if bias is not None or ddof is not None: + # deprecated in NumPy + raise NotImplementedError + xy_tensor = _xy_helper_corrcoef(x, y, rowvar) + + is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + xy_tensor = _util.cast_if_needed(xy_tensor, dtype) + result = torch.corrcoef(xy_tensor) + + if is_half: + result = result.to(torch.float16) + + return result + + +def cov( + m: ArrayLike, + y: 
Optional[ArrayLike] = None, + rowvar=True, + bias=False, + ddof=None, + fweights: Optional[ArrayLike] = None, + aweights: Optional[ArrayLike] = None, + *, + dtype: Optional[DTypeLike] = None, +): + m = _xy_helper_corrcoef(m, y, rowvar) + + if ddof is None: + ddof = 1 if bias == 0 else 0 + + is_half = (m.dtype == torch.float16) and m.is_cpu + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + m = _util.cast_if_needed(m, dtype) + result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights) + + if is_half: + result = result.to(torch.float16) + + return result + + +def _conv_corr_impl(a, v, mode): + dt = _dtypes_impl.result_type_impl(a, v) + a = _util.cast_if_needed(a, dt) + v = _util.cast_if_needed(v, dt) + + padding = v.shape[0] - 1 if mode == "full" else mode + + if padding == "same" and v.shape[0] % 2 == 0: + # UserWarning: Using padding='same' with even kernel lengths and odd + # dilation may require a zero-padded copy of the input be created + # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.) + raise NotImplementedError("mode='same' and even-length weights") + + # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights + aa = a[None, :] + vv = v[None, None, :] + + result = torch.nn.functional.conv1d(aa, vv, padding=padding) + + # torch returns a 2D result, numpy returns a 1D array + return result[0, :] + + +def convolve(a: ArrayLike, v: ArrayLike, mode="full"): + # NumPy: if v is longer than a, the arrays are swapped before computation + if a.shape[0] < v.shape[0]: + a, v = v, a + + # flip the weights since numpy does and torch does not + v = torch.flip(v, (0,)) + + return _conv_corr_impl(a, v, mode) + + +def correlate(a: ArrayLike, v: ArrayLike, mode="valid"): + v = torch.conj_physical(v) + return _conv_corr_impl(a, v, mode) + + +# ### logic & element selection ### + + +def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0): + if x.numel() == 0: + # edge case allowed by numpy + x = x.new_empty(0, dtype=int) + + int_dtype = _dtypes_impl.default_dtypes().int_dtype + (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe") + + return torch.bincount(x, weights, minlength) + + +def where( + condition: ArrayLike, + x: Optional[ArrayLikeOrScalar] = None, + y: Optional[ArrayLikeOrScalar] = None, + /, +): + if (x is None) != (y is None): + raise ValueError("either both or neither of x and y should be given") + + if condition.dtype != torch.bool: + condition = condition.to(torch.bool) + + if x is None and y is None: + result = torch.where(condition) + else: + result = torch.where(condition, x, y) + return result + + +# ###### module-level queries of object properties + + +def ndim(a: ArrayLike): + return a.ndim + + +def shape(a: ArrayLike): + return tuple(a.shape) + + +def size(a: ArrayLike, axis=None): + if axis is None: + return a.numel() + else: + return a.shape[axis] + + +# ###### shape manipulations and indexing + + +def expand_dims(a: ArrayLike, axis): + shape = _util.expand_shape(a.shape, axis) + return a.view(shape) # never copies + + +def flip(m: ArrayLike, axis=None): + # XXX: semantic difference: np.flip returns a view, torch.flip copies + if axis is None: + axis = tuple(range(m.ndim)) + else: + axis = _util.normalize_axis_tuple(axis, m.ndim) + return torch.flip(m, axis) + + +def flipud(m: ArrayLike): + return torch.flipud(m) + + +def fliplr(m: ArrayLike): + return torch.fliplr(m) + + +def rot90(m: ArrayLike, k=1, axes=(0, 1)): + axes = 
_util.normalize_axis_tuple(axes, m.ndim) + return torch.rot90(m, k, axes) + + +# ### broadcasting and indices ### + + +def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False): + return torch.broadcast_to(array, size=shape) + + +# This is a function from tuples to tuples, so we just reuse it +from torch import broadcast_shapes + + +def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False): + return torch.broadcast_tensors(*args) + + +def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"): + ndim = len(xi) + + if indexing not in ["xy", "ij"]: + raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.") + + s0 = (1,) * ndim + output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)] + + if indexing == "xy" and ndim > 1: + # switch first and second axis + output[0] = output[0].reshape((1, -1) + s0[2:]) + output[1] = output[1].reshape((-1, 1) + s0[2:]) + + if not sparse: + # Return the full N-D matrix (not only the 1-D vector) + output = torch.broadcast_tensors(*output) + + if copy: + output = [x.clone() for x in output] + + return list(output) # match numpy, return a list + + +def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False): + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791 + dimensions = tuple(dimensions) + N = len(dimensions) + shape = (1,) * N + if sparse: + res = tuple() + else: + res = torch.empty((N,) + dimensions, dtype=dtype) + for i, dim in enumerate(dimensions): + idx = torch.arange(dim, dtype=dtype).reshape( + shape[:i] + (dim,) + shape[i + 1 :] + ) + if sparse: + res = res + (idx,) + else: + res[i] = idx + return res + + +# ### tri*-something ### + + +def tril(m: ArrayLike, k=0): + return torch.tril(m, k) + + +def triu(m: ArrayLike, k=0): + return torch.triu(m, k) + + +def tril_indices(n, k=0, m=None): + if m is None: + m = n + return torch.tril_indices(n, m, offset=k) + + +def triu_indices(n, k=0, m=None): + if m is None: + m = n + return torch.triu_indices(n, m, offset=k) + + +def tril_indices_from(arr: ArrayLike, k=0): + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + # Return a tensor rather than a tuple to avoid a graphbreak + return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k) + + +def triu_indices_from(arr: ArrayLike, k=0): + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + # Return a tensor rather than a tuple to avoid a graphbreak + return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k) + + +def tri( + N, + M=None, + k=0, + dtype: Optional[DTypeLike] = None, + *, + like: NotImplementedType = None, +): + if M is None: + M = N + tensor = torch.ones((N, M), dtype=dtype) + return torch.tril(tensor, diagonal=k) + + +# ### equality, equivalence, allclose ### + + +def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): + dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False): + dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def _tensor_equal(a1, a2, equal_nan=False): + # Implementation of array_equal/array_equiv. 
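+    # Shapes must match exactly here; array_equiv (below) broadcasts its inputs
+    # before delegating to this helper, while array_equal does not.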
+ if a1.shape != a2.shape: + return False + cond = a1 == a2 + if equal_nan: + cond = cond | (torch.isnan(a1) & torch.isnan(a2)) + return cond.all().item() + + +def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False): + return _tensor_equal(a1, a2, equal_nan=equal_nan) + + +def array_equiv(a1: ArrayLike, a2: ArrayLike): + # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not + try: + a1_t, a2_t = torch.broadcast_tensors(a1, a2) + except RuntimeError: + # failed to broadcast => not equivalent + return False + return _tensor_equal(a1_t, a2_t) + + +def nan_to_num( + x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None +): + # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble' + if x.is_complex(): + re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf) + im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf) + return re + 1j * im + else: + return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +# ### put/take_along_axis ### + + +def take( + a: ArrayLike, + indices: ArrayLike, + axis=None, + out: Optional[OutArray] = None, + mode: NotImplementedType = "raise", +): + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + idx = (slice(None),) * axis + (indices, ...) + result = a[idx] + return result + + +def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) + axis = _util.normalize_axis_index(axis, arr.ndim) + return torch.take_along_dim(arr, indices, axis) + + +def put( + a: NDArray, + indices: ArrayLike, + values: ArrayLike, + mode: NotImplementedType = "raise", +): + v = values.type(a.dtype) + # If indices is larger than v, expand v to at least the size of indices. Any + # unnecessary trailing elements are then trimmed. + if indices.numel() > v.numel(): + ratio = (indices.numel() + v.numel() - 1) // v.numel() + v = v.unsqueeze(0).expand((ratio,) + v.shape) + # Trim unnecessary elements, regardless if v was expanded or not. Note + # np.put() trims v to match indices by default too. 
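+    # Worked example (hypothetical sizes): with len(indices) == 3 and len(v) == 2,
+    # ratio == 2, so v is expanded to shape (2, 2), then flattened and trimmed to
+    # the first 3 elements -- i.e. values repeat cyclically, as with np.put.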
+ if indices.numel() < v.numel(): + v = v.flatten() + v = v[: indices.numel()] + a.put_(indices, v) + return None + + +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) + axis = _util.normalize_axis_index(axis, arr.ndim) + + indices, values = torch.broadcast_tensors(indices, values) + values = _util.cast_if_needed(values, arr.dtype) + result = torch.scatter(arr, axis, indices, values) + arr.copy_(result.reshape(arr.shape)) + return None + + +def choose( + a: ArrayLike, + choices: Sequence[ArrayLike], + out: Optional[OutArray] = None, + mode: NotImplementedType = "raise", +): + # First, broadcast elements of `choices` + choices = torch.stack(torch.broadcast_tensors(*choices)) + + # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`: + # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939) + idx_list = [ + torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1)) + for i, dim in enumerate(choices.shape) + ] + + idx_list[0] = a + return choices[idx_list].squeeze(0) + + +# ### unique et al ### + + +def unique( + ar: ArrayLike, + return_index: NotImplementedType = False, + return_inverse=False, + return_counts=False, + axis=None, + *, + equal_nan: NotImplementedType = True, +): + (ar,), axis = _util.axis_none_flatten(ar, axis=axis) + axis = _util.normalize_axis_index(axis, ar.ndim) + + result = torch.unique( + ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis + ) + + return result + + +def nonzero(a: ArrayLike): + return torch.nonzero(a, as_tuple=True) + + +def argwhere(a: ArrayLike): + return torch.argwhere(a) + + +def flatnonzero(a: ArrayLike): + return torch.flatten(a).nonzero(as_tuple=True)[0] + + +def clip( + a: ArrayLike, + min: Optional[ArrayLike] = None, + max: Optional[ArrayLike] = None, + out: Optional[OutArray] = None, +): + return torch.clamp(a, min, max) + + +def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None): + return torch.repeat_interleave(a, repeats, axis) + + +def tile(A: ArrayLike, reps): + if isinstance(reps, int): + reps = (reps,) + return torch.tile(A, reps) + + +def resize(a: ArrayLike, new_shape=None): + # implementation vendored from + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497 + if new_shape is None: + return a + + if isinstance(new_shape, int): + new_shape = (new_shape,) + + a = a.flatten() + + new_size = 1 + for dim_length in new_shape: + new_size *= dim_length + if dim_length < 0: + raise ValueError("all elements of `new_shape` must be non-negative") + + if a.numel() == 0 or new_size == 0: + # First case must zero fill. The second would have repeats == 0. 
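+        # e.g. resize(empty_array, (2, 3)) gives a (2, 3) block of zeros, matching
+        # the np.resize source this is vendored from.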
+ return torch.zeros(new_shape, dtype=a.dtype) + + repeats = -(-new_size // a.numel()) # ceil division + a = concatenate((a,) * repeats)[:new_size] + + return reshape(a, new_shape) + + +# ### diag et al ### + + +def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): + axis1 = _util.normalize_axis_index(axis1, a.ndim) + axis2 = _util.normalize_axis_index(axis2, a.ndim) + return torch.diagonal(a, offset, axis1, axis2) + + +def trace( + a: ArrayLike, + offset=0, + axis1=0, + axis2=1, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype) + return result + + +def eye( + N, + M=None, + k=0, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + if M is None: + M = N + z = torch.zeros(N, M, dtype=dtype) + z.diagonal(k).fill_(1) + return z + + +def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None): + return torch.eye(n, dtype=dtype) + + +def diag(v: ArrayLike, k=0): + return torch.diag(v, k) + + +def diagflat(v: ArrayLike, k=0): + return torch.diagflat(v, k) + + +def diag_indices(n, ndim=2): + idx = torch.arange(n) + return (idx,) * ndim + + +def diag_indices_from(arr: ArrayLike): + if not arr.ndim >= 2: + raise ValueError("input array must be at least 2-d") + # For more than d=2, the strided formula is only valid for arrays with + # all dimensions equal, so we check first. + s = arr.shape + if s[1:] != s[:-1]: + raise ValueError("All dimensions of input must be of equal length") + return diag_indices(s[0], arr.ndim) + + +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): + if a.ndim < 2: + raise ValueError("array must be at least 2-d") + if val.numel() == 0 and not wrap: + a.fill_diagonal_(val) + return a + + if val.ndim == 0: + val = val.unsqueeze(0) + + # torch.Tensor.fill_diagonal_ only accepts scalars + # If the size of val is too large, then val is trimmed + if a.ndim == 2: + tall = a.shape[0] > a.shape[1] + # wrap does nothing for wide matrices... + if not wrap or not tall: + # Never wraps + diag = a.diagonal() + diag.copy_(val[: diag.numel()]) + else: + # wraps and tall... leaving one empty line between diagonals?! + max_, min_ = a.shape + idx = torch.arange(max_ - max_ // (min_ + 1)) + mod = idx % min_ + div = idx // min_ + a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()] + else: + idx = diag_indices_from(a) + # a.shape = (n, n, ..., n) + a[idx] = val[: a.shape[0]] + + return a + + +def vdot(a: ArrayLike, b: ArrayLike, /): + # 1. torch only accepts 1D arrays, numpy flattens + # 2. torch requires matching dtype, while numpy casts (?) 
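+    # e.g. np.vdot of two (2, 2) operands reduces to a length-4 dot product, so
+    # both tensors are flattened below before calling torch.vdot.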
+ t_a, t_b = torch.atleast_1d(a, b) + if t_a.ndim > 1: + t_a = t_a.flatten() + if t_b.ndim > 1: + t_b = t_b.flatten() + + dtype = _dtypes_impl.result_type_impl(t_a, t_b) + is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu) + is_bool = dtype == torch.bool + + # work around torch's "dot" not implemented for 'Half', 'Bool' + if is_half: + dtype = torch.float32 + elif is_bool: + dtype = torch.uint8 + + t_a = _util.cast_if_needed(t_a, dtype) + t_b = _util.cast_if_needed(t_b, dtype) + + result = torch.vdot(t_a, t_b) + + if is_half: + result = result.to(torch.float16) + elif is_bool: + result = result.to(torch.bool) + + return result + + +def tensordot(a: ArrayLike, b: ArrayLike, axes=2): + if isinstance(axes, (list, tuple)): + axes = [[ax] if isinstance(ax, int) else ax for ax in axes] + + target_dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, target_dtype) + b = _util.cast_if_needed(b, target_dtype) + + return torch.tensordot(a, b, dims=axes) + + +def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): + dtype = _dtypes_impl.result_type_impl(a, b) + is_bool = dtype == torch.bool + if is_bool: + dtype = torch.uint8 + + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + if a.ndim == 0 or b.ndim == 0: + result = a * b + else: + result = torch.matmul(a, b) + + if is_bool: + result = result.to(torch.bool) + + return result + + +def inner(a: ArrayLike, b: ArrayLike, /): + dtype = _dtypes_impl.result_type_impl(a, b) + is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu) + is_bool = dtype == torch.bool + + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + elif is_bool: + dtype = torch.uint8 + + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + result = torch.inner(a, b) + + if is_half: + result = result.to(torch.float16) + elif is_bool: + result = result.to(torch.bool) + return result + + +def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): + return torch.outer(a, b) + + +def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None): + # implementation vendored from + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685 + if axis is not None: + axisa, axisb, axisc = (axis,) * 3 + + # Check axisa and axisb are within bounds + axisa = _util.normalize_axis_index(axisa, a.ndim) + axisb = _util.normalize_axis_index(axisb, b.ndim) + + # Move working axis to the end of the shape + a = torch.moveaxis(a, axisa, -1) + b = torch.moveaxis(b, axisb, -1) + msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)" + if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): + raise ValueError(msg) + + # Create the output array + shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape) + if a.shape[-1] == 3 or b.shape[-1] == 3: + shape += (3,) + # Check axisc is within bounds + axisc = _util.normalize_axis_index(axisc, len(shape)) + dtype = _dtypes_impl.result_type_impl(a, b) + cp = torch.empty(shape, dtype=dtype) + + # recast arrays as dtype + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + # create local aliases for readability + a0 = a[..., 0] + a1 = a[..., 1] + if a.shape[-1] == 3: + a2 = a[..., 2] + b0 = b[..., 0] + b1 = b[..., 1] + if b.shape[-1] == 3: + b2 = b[..., 2] + if cp.ndim != 0 and cp.shape[-1] == 3: + cp0 = cp[..., 0] + cp1 = cp[..., 1] + cp2 = cp[..., 2] + + if a.shape[-1] == 2: + if b.shape[-1] == 2: + # a0 * b1 - a1 * b0 + 
cp[...] = a0 * b1 - a1 * b0 + return cp + else: + assert b.shape[-1] == 3 + # cp0 = a1 * b2 - 0 (a2 = 0) + # cp1 = 0 - a0 * b2 (a2 = 0) + # cp2 = a0 * b1 - a1 * b0 + cp0[...] = a1 * b2 + cp1[...] = -a0 * b2 + cp2[...] = a0 * b1 - a1 * b0 + else: + assert a.shape[-1] == 3 + if b.shape[-1] == 3: + cp0[...] = a1 * b2 - a2 * b1 + cp1[...] = a2 * b0 - a0 * b2 + cp2[...] = a0 * b1 - a1 * b0 + else: + assert b.shape[-1] == 2 + cp0[...] = -a2 * b1 + cp1[...] = a2 * b0 + cp2[...] = a0 * b1 - a1 * b0 + + return torch.moveaxis(cp, -1, axisc) + + +def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): + # Have to manually normalize *operands and **kwargs, following the NumPy signature + # We have a local import to avoid poluting the global space, as it will be then + # exported in funcs.py + from ._ndarray import ndarray + from ._normalizations import ( + maybe_copy_to, + normalize_array_like, + normalize_casting, + normalize_dtype, + wrap_tensors, + ) + + dtype = normalize_dtype(dtype) + casting = normalize_casting(casting) + if out is not None and not isinstance(out, ndarray): + raise TypeError("'out' must be an array") + if order != "K": + raise NotImplementedError("'order' parameter is not supported.") + + # parse arrays and normalize them + sublist_format = not isinstance(operands[0], str) + if sublist_format: + # op, str, op, str ... [sublistout] format: normalize every other argument + + # - if sublistout is not given, the length of operands is even, and we pick + # odd-numbered elements, which are arrays. + # - if sublistout is given, the length of operands is odd, we peel off + # the last one, and pick odd-numbered elements, which are arrays. + # Without [:-1], we would have picked sublistout, too. + array_operands = operands[:-1][::2] + else: + # ("ij->", arrays) format + subscripts, array_operands = operands[0], operands[1:] + + tensors = [normalize_array_like(op) for op in array_operands] + target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype + + # work around 'bmm' not implemented for 'Half' etc + is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors) + if is_half: + target_dtype = torch.float32 + + is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] + if is_short_int: + target_dtype = torch.int64 + + tensors = _util.typecast_tensors(tensors, target_dtype, casting) + + from torch.backends import opt_einsum + + try: + # set the global state to handle the optimize=... 
argument, restore on exit + if opt_einsum.is_available(): + old_strategy = torch.backends.opt_einsum.strategy + old_enabled = torch.backends.opt_einsum.enabled + + # torch.einsum calls opt_einsum.contract_path, which runs into + # https://github.com/dgasmith/opt_einsum/issues/219 + # for strategy={True, False} + if optimize is True: + optimize = "auto" + elif optimize is False: + torch.backends.opt_einsum.enabled = False + + torch.backends.opt_einsum.strategy = optimize + + if sublist_format: + # recombine operands + sublists = operands[1::2] + has_sublistout = len(operands) % 2 == 1 + if has_sublistout: + sublistout = operands[-1] + operands = list(itertools.chain.from_iterable(zip(tensors, sublists))) + if has_sublistout: + operands.append(sublistout) + + result = torch.einsum(*operands) + else: + result = torch.einsum(subscripts, *tensors) + + finally: + if opt_einsum.is_available(): + torch.backends.opt_einsum.strategy = old_strategy + torch.backends.opt_einsum.enabled = old_enabled + + result = maybe_copy_to(out, result) + return wrap_tensors(result) + + +# ### sort and partition ### + + +def _sort_helper(tensor, axis, kind, order): + if tensor.dtype.is_complex: + raise NotImplementedError(f"sorting {tensor.dtype} is not supported") + (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) + axis = _util.normalize_axis_index(axis, tensor.ndim) + + stable = kind == "stable" + + return tensor, axis, stable + + +def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): + # `order` keyword arg is only relevant for structured dtypes; so not supported here. + a, axis, stable = _sort_helper(a, axis, kind, order) + result = torch.sort(a, dim=axis, stable=stable) + return result.values + + +def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): + a, axis, stable = _sort_helper(a, axis, kind, order) + return torch.argsort(a, dim=axis, stable=stable) + + +def searchsorted( + a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None +): + if a.dtype.is_complex: + raise NotImplementedError(f"searchsorted with dtype={a.dtype}") + + return torch.searchsorted(a, v, side=side, sorter=sorter) + + +# ### swap/move/roll axis ### + + +def moveaxis(a: ArrayLike, source, destination): + source = _util.normalize_axis_tuple(source, a.ndim, "source") + destination = _util.normalize_axis_tuple(destination, a.ndim, "destination") + return torch.moveaxis(a, source, destination) + + +def swapaxes(a: ArrayLike, axis1, axis2): + axis1 = _util.normalize_axis_index(axis1, a.ndim) + axis2 = _util.normalize_axis_index(axis2, a.ndim) + return torch.swapaxes(a, axis1, axis2) + + +def rollaxis(a: ArrayLike, axis, start=0): + # Straight vendor from: + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259 + # + # Also note this function in NumPy is mostly retained for backwards compat + # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing) + # so let's not touch it unless hard pressed. + n = a.ndim + axis = _util.normalize_axis_index(axis, n) + if start < 0: + start += n + msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" + if not (0 <= start < n + 1): + raise _util.AxisError(msg % ("start", -n, "start", n + 1, start)) + if axis < start: + # it's been removed + start -= 1 + if axis == start: + # numpy returns a view, here we try returning the tensor itself + # return tensor[...] 
+ return a + axes = list(range(0, n)) + axes.remove(axis) + axes.insert(start, axis) + return a.view(axes) + + +def roll(a: ArrayLike, shift, axis=None): + if axis is not None: + axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) + if not isinstance(shift, tuple): + shift = (shift,) * len(axis) + return torch.roll(a, shift, axis) + + +# ### shape manipulations ### + + +def squeeze(a: ArrayLike, axis=None): + if axis == (): + result = a + elif axis is None: + result = a.squeeze() + else: + if isinstance(axis, tuple): + result = a + for ax in axis: + result = a.squeeze(ax) + else: + result = a.squeeze(axis) + return result + + +def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"): + # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh) + newshape = newshape[0] if len(newshape) == 1 else newshape + return a.reshape(newshape) + + +# NB: cannot use torch.reshape(a, newshape) above, because of +# (Pdb) torch.reshape(torch.as_tensor([1]), 1) +# *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int + + +def transpose(a: ArrayLike, axes=None): + # numpy allows both .transpose(sh) and .transpose(*sh) + # also older code uses axes being a list + if axes in [(), None, (None,)]: + axes = tuple(reversed(range(a.ndim))) + elif len(axes) == 1: + axes = axes[0] + return a.permute(axes) + + +def ravel(a: ArrayLike, order: NotImplementedType = "C"): + return torch.flatten(a) + + +def diff( + a: ArrayLike, + n=1, + axis=-1, + prepend: Optional[ArrayLike] = None, + append: Optional[ArrayLike] = None, +): + axis = _util.normalize_axis_index(axis, a.ndim) + + if n < 0: + raise ValueError(f"order must be non-negative but got {n}") + + if n == 0: + # match numpy and return the input immediately + return a + + if prepend is not None: + shape = list(a.shape) + shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1 + prepend = torch.broadcast_to(prepend, shape) + + if append is not None: + shape = list(a.shape) + shape[axis] = append.shape[axis] if append.ndim > 0 else 1 + append = torch.broadcast_to(append, shape) + + return torch.diff(a, n, axis=axis, prepend=prepend, append=append) + + +# ### math functions ### + + +def angle(z: ArrayLike, deg=False): + result = torch.angle(z) + if deg: + result = result * (180 / torch.pi) + return result + + +def sinc(x: ArrayLike): + return torch.sinc(x) + + +# NB: have to normalize *varargs manually +def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1): + N = f.ndim # number of dimensions + + varargs = _util.ndarrays_to_tensors(varargs) + + if axis is None: + axes = tuple(range(N)) + else: + axes = _util.normalize_axis_tuple(axis, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0): + # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0) + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + distances = torch.as_tensor(distances) + if distances.ndim == 0: + continue + elif distances.ndim != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError( + "when 1d, distances must match " + "the length of the corresponding dimension" + ) + if not (distances.dtype.is_floating_point or distances.dtype.is_complex): + distances = distances.double() + + diffx = 
torch.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype + if _dtypes_impl.python_type_for_torch(otype) in (int, bool): + # Convert to floating point. + # First check if f is a numpy integer type; if so, convert f to float64 + # to avoid modular arithmetic when computing the changes in f. + f = f.double() + otype = torch.float64 + + for axis, ax_dx in zip(axes, dx): + if f.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical gradient, " + "at least (edge_order + 1) elements are required." + ) + # result allocation + out = torch.empty_like(f, dtype=otype) + + # spacing for the current axis (NB: np.ndim(ax_dx) == 0) + uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx) + else: + dx1 = ax_dx[0:-1] + dx2 = ax_dx[1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = [1] * N + shape[axis] = -1 + a = a.reshape(shape) + b = b.reshape(shape) + c = c.reshape(shape) + # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = ax_dx if uniform_spacing else ax_dx[0] + # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = ax_dx if uniform_spacing else ax_dx[-1] + # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / ax_dx + b = 2.0 / ax_dx + c = -0.5 / ax_dx + else: + dx1 = ax_dx[0] + dx2 = ax_dx[1] + a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = -dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / ax_dx + b = -2.0 / ax_dx + c = 1.5 / ax_dx + else: + dx1 = ax_dx[-2] + dx2 = ax_dx[-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = -(dx2 + dx1) / (dx1 * dx2) + c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + 
c * f[-1] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals + + +# ### Type/shape etc queries ### + + +def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None): + if a.is_floating_point(): + result = torch.round(a, decimals=decimals) + elif a.is_complex(): + # RuntimeError: "round_cpu" not implemented for 'ComplexFloat' + result = torch.complex( + torch.round(a.real, decimals=decimals), + torch.round(a.imag, decimals=decimals), + ) + else: + # RuntimeError: "round_cpu" not implemented for 'int' + result = a + return result + + +around = round +round_ = round + + +def real_if_close(a: ArrayLike, tol=100): + if not torch.is_complex(a): + return a + if tol > 1: + # Undocumented in numpy: if tol < 1, it's an absolute tolerance! + # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577 + tol = tol * torch.finfo(a.dtype).eps + + mask = torch.abs(a.imag) < tol + return a.real if mask.all() else a + + +def real(a: ArrayLike): + return torch.real(a) + + +def imag(a: ArrayLike): + if a.is_complex(): + return a.imag + return torch.zeros_like(a) + + +def iscomplex(x: ArrayLike): + if torch.is_complex(x): + return x.imag != 0 + return torch.zeros_like(x, dtype=torch.bool) + + +def isreal(x: ArrayLike): + if torch.is_complex(x): + return x.imag == 0 + return torch.ones_like(x, dtype=torch.bool) + + +def iscomplexobj(x: ArrayLike): + return torch.is_complex(x) + + +def isrealobj(x: ArrayLike): + return not torch.is_complex(x) + + +def isneginf(x: ArrayLike, out: Optional[OutArray] = None): + return torch.isneginf(x) + + +def isposinf(x: ArrayLike, out: Optional[OutArray] = None): + return torch.isposinf(x) + + +def i0(x: ArrayLike): + return torch.special.i0(x) + + +def isscalar(a): + # We need to use normalize_array_like, but we don't want to export it in funcs.py + from ._normalizations import normalize_array_like + + try: + t = normalize_array_like(a) + return t.numel() == 1 + except Exception: + return False + + +# ### Filter windows ### + + +def hamming(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.hamming_window(M, periodic=False, dtype=dtype) + + +def hanning(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.hann_window(M, periodic=False, dtype=dtype) + + +def kaiser(M, beta): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype) + + +def blackman(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.blackman_window(M, periodic=False, dtype=dtype) + + +def bartlett(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.bartlett_window(M, periodic=False, dtype=dtype) + + +# ### Dtype routines ### + +# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666 + + +array_type = [ + [torch.float16, torch.float32, torch.float64], + [None, torch.complex64, torch.complex128], +] +array_precision = { + torch.float16: 0, + torch.float32: 1, + torch.float64: 2, + torch.complex64: 1, + torch.complex128: 2, +} + + +def common_type(*tensors: ArrayLike): + is_complex = False + precision = 0 + for a in tensors: + t = a.dtype + if 
iscomplexobj(a): + is_complex = True + if not (t.is_floating_point or t.is_complex): + p = 2 # array_precision[_nx.double] + else: + p = array_precision.get(t, None) + if p is None: + raise TypeError("can't get common type for non-numeric array") + precision = builtins.max(precision, p) + if is_complex: + return array_type[1][precision] + else: + return array_type[0][precision] + + +# ### histograms ### + + +def histogram( + a: ArrayLike, + bins: ArrayLike = 10, + range=None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + if normed is not None: + raise ValueError("normed argument is deprecated, use density= instead") + + if weights is not None and weights.dtype.is_complex: + raise NotImplementedError("complex weights histogram.") + + is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex) + is_w_int = weights is None or not weights.dtype.is_floating_point + if is_a_int: + a = a.double() + + if weights is not None: + weights = _util.cast_if_needed(weights, a.dtype) + + if isinstance(bins, torch.Tensor): + if bins.ndim == 0: + # bins was a single int + bins = operator.index(bins) + else: + bins = _util.cast_if_needed(bins, a.dtype) + + if range is None: + h, b = torch.histogram(a, bins, weight=weights, density=bool(density)) + else: + h, b = torch.histogram( + a, bins, range=range, weight=weights, density=bool(density) + ) + + if not density and is_w_int: + h = h.long() + if is_a_int: + b = b.long() + + return h, b + + +def histogram2d( + x, + y, + bins=10, + range: Optional[ArrayLike] = None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821 + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + + try: + N = len(bins) + except TypeError: + N = 1 + + if N != 1 and N != 2: + bins = [bins, bins] + + h, e = histogramdd((x, y), bins, range, normed, weights, density) + + return h, e[0], e[1] + + +def histogramdd( + sample, + bins=10, + range: Optional[ArrayLike] = None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + # have to normalize manually because `sample` interpretation differs + # for a list of lists and a 2D array + if normed is not None: + raise ValueError("normed argument is deprecated, use density= instead") + + from ._normalizations import normalize_array_like, normalize_seq_array_like + + if isinstance(sample, (list, tuple)): + sample = normalize_array_like(sample).T + else: + sample = normalize_array_like(sample) + + sample = torch.atleast_2d(sample) + + if not (sample.dtype.is_floating_point or sample.dtype.is_complex): + sample = sample.double() + + # bins is either an int, or a sequence of ints or a sequence of arrays + bins_is_array = not ( + isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins) + ) + if bins_is_array: + bins = normalize_seq_array_like(bins) + bins_dtypes = [b.dtype for b in bins] + bins = [_util.cast_if_needed(b, sample.dtype) for b in bins] + + if range is not None: + range = range.flatten().tolist() + + if weights is not None: + # range=... 
is required : interleave min and max values per dimension + mm = sample.aminmax(dim=0) + range = torch.cat(mm).reshape(2, -1).T.flatten() + range = tuple(range.tolist()) + weights = _util.cast_if_needed(weights, sample.dtype) + w_kwd = {"weight": weights} + else: + w_kwd = {} + + h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd) + + if bins_is_array: + b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)] + + return h, b + + +# ### odds and ends + + +def min_scalar_type(a: ArrayLike, /): + # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288 + + from ._dtypes import DType + + if a.numel() > 1: + # numpy docs: "For non-scalar array a, returns the vector’s dtype unmodified." + return DType(a.dtype) + + if a.dtype == torch.bool: + dtype = torch.bool + + elif a.dtype.is_complex: + fi = torch.finfo(torch.float32) + fits_in_single = a.dtype == torch.complex64 or ( + fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max + ) + dtype = torch.complex64 if fits_in_single else torch.complex128 + + elif a.dtype.is_floating_point: + for dt in [torch.float16, torch.float32, torch.float64]: + fi = torch.finfo(dt) + if fi.min <= a <= fi.max: + dtype = dt + break + else: + # must be integer + for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: + # Prefer unsigned int where possible, as numpy does. + ii = torch.iinfo(dt) + if ii.min <= a <= ii.max: + dtype = dt + break + + return DType(dtype) + + +def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs): + if mode != "constant": + raise NotImplementedError + value = kwargs.get("constant_values", 0) + # `value` must be a python scalar for torch.nn.functional.pad + typ = _dtypes_impl.python_type_for_torch(array.dtype) + value = typ(value) + + pad_width = torch.broadcast_to(pad_width, (array.ndim, 2)) + pad_width = torch.flip(pad_width, (0,)).flatten() + + return torch.nn.functional.pad(array, tuple(pad_width), value=value) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_getlimits.py b/MLPY/Lib/site-packages/torch/_numpy/_getlimits.py new file mode 100644 index 0000000000000000000000000000000000000000..75036ce6ab4b0b417be7ea0a308ec19018304fd4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_getlimits.py @@ -0,0 +1,15 @@ +# mypy: ignore-errors + +import torch + +from . import _dtypes + + +def finfo(dtyp): + torch_dtype = _dtypes.dtype(dtyp).torch_dtype + return torch.finfo(torch_dtype) + + +def iinfo(dtyp): + torch_dtype = _dtypes.dtype(dtyp).torch_dtype + return torch.iinfo(torch_dtype) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_ndarray.py b/MLPY/Lib/site-packages/torch/_numpy/_ndarray.py new file mode 100644 index 0000000000000000000000000000000000000000..0e81a9cec8578a1366f1e0c93e5f0aec04e69dad --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_ndarray.py @@ -0,0 +1,591 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import builtins +import math +import operator +from typing import Sequence + +import torch + +from . 
import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util +from ._normalizations import ( + ArrayLike, + normalize_array_like, + normalizer, + NotImplementedType, +) + +newaxis = None + +FLAGS = [ + "C_CONTIGUOUS", + "F_CONTIGUOUS", + "OWNDATA", + "WRITEABLE", + "ALIGNED", + "WRITEBACKIFCOPY", + "FNC", + "FORC", + "BEHAVED", + "CARRAY", + "FARRAY", +] + +SHORTHAND_TO_FLAGS = { + "C": "C_CONTIGUOUS", + "F": "F_CONTIGUOUS", + "O": "OWNDATA", + "W": "WRITEABLE", + "A": "ALIGNED", + "X": "WRITEBACKIFCOPY", + "B": "BEHAVED", + "CA": "CARRAY", + "FA": "FARRAY", +} + + +class Flags: + def __init__(self, flag_to_value: dict): + assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + self._flag_to_value = flag_to_value + + def __getattr__(self, attr: str): + if attr.islower() and attr.upper() in FLAGS: + return self[attr.upper()] + else: + raise AttributeError(f"No flag attribute '{attr}'") + + def __getitem__(self, key): + if key in SHORTHAND_TO_FLAGS.keys(): + key = SHORTHAND_TO_FLAGS[key] + if key in FLAGS: + try: + return self._flag_to_value[key] + except KeyError as e: + raise NotImplementedError(f"{key=}") from e + else: + raise KeyError(f"No flag key '{key}'") + + def __setattr__(self, attr, value): + if attr.islower() and attr.upper() in FLAGS: + self[attr.upper()] = value + else: + super().__setattr__(attr, value) + + def __setitem__(self, key, value): + if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys(): + raise NotImplementedError("Modifying flags is not implemented") + else: + raise KeyError(f"No flag key '{key}'") + + +def create_method(fn, name=None): + name = name or fn.__name__ + + def f(*args, **kwargs): + return fn(*args, **kwargs) + + f.__name__ = name + f.__qualname__ = f"ndarray.{name}" + return f + + +# Map ndarray.name_method -> np.name_func +# If name_func == None, it means that name_method == name_func +methods = { + "clip": None, + "nonzero": None, + "repeat": None, + "round": None, + "squeeze": None, + "swapaxes": None, + "ravel": None, + # linalg + "diagonal": None, + "dot": None, + "trace": None, + # sorting + "argsort": None, + "searchsorted": None, + # reductions + "argmax": None, + "argmin": None, + "any": None, + "all": None, + "max": None, + "min": None, + "ptp": None, + "sum": None, + "prod": None, + "mean": None, + "var": None, + "std": None, + # scans + "cumsum": None, + "cumprod": None, + # advanced indexing + "take": None, + "choose": None, +} + +dunder = { + "abs": "absolute", + "invert": None, + "pos": "positive", + "neg": "negative", + "gt": "greater", + "lt": "less", + "ge": "greater_equal", + "le": "less_equal", +} + +# dunder methods with right-looking and in-place variants +ri_dunder = { + "add": None, + "sub": "subtract", + "mul": "multiply", + "truediv": "divide", + "floordiv": "floor_divide", + "pow": "power", + "mod": "remainder", + "and": "bitwise_and", + "or": "bitwise_or", + "xor": "bitwise_xor", + "lshift": "left_shift", + "rshift": "right_shift", + "matmul": None, +} + + +def _upcast_int_indices(index): + if isinstance(index, torch.Tensor): + if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8): + return index.to(torch.int64) + elif isinstance(index, tuple): + return tuple(_upcast_int_indices(i) for i in index) + return index + + +# Used to indicate that a parameter is unspecified (as opposed to explicitly +# `None`) +class _Unspecified: + pass + + +_Unspecified.unspecified = _Unspecified() + +############################################################### +# ndarray class # 
+############################################################### + + +class ndarray: + def __init__(self, t=None): + if t is None: + self.tensor = torch.Tensor() + elif isinstance(t, torch.Tensor): + self.tensor = t + else: + raise ValueError( + "ndarray constructor is not recommended; prefer" + "either array(...) or zeros/empty(...)" + ) + + # Register NumPy functions as methods + for method, name in methods.items(): + fn = getattr(_funcs, name or method) + vars()[method] = create_method(fn, method) + + # Regular methods but coming from ufuncs + conj = create_method(_ufuncs.conjugate, "conj") + conjugate = create_method(_ufuncs.conjugate) + + for method, name in dunder.items(): + fn = getattr(_ufuncs, name or method) + method = f"__{method}__" + vars()[method] = create_method(fn, method) + + for method, name in ri_dunder.items(): + fn = getattr(_ufuncs, name or method) + plain = f"__{method}__" + vars()[plain] = create_method(fn, plain) + rvar = f"__r{method}__" + vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar) + ivar = f"__i{method}__" + vars()[ivar] = create_method( + lambda self, other, fn=fn: fn(self, other, out=self), ivar + ) + + # There's no __idivmod__ + __divmod__ = create_method(_ufuncs.divmod, "__divmod__") + __rdivmod__ = create_method( + lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__" + ) + + # prevent loop variables leaking into the ndarray class namespace + del ivar, rvar, name, plain, fn, method + + @property + def shape(self): + return tuple(self.tensor.shape) + + @property + def size(self): + return self.tensor.numel() + + @property + def ndim(self): + return self.tensor.ndim + + @property + def dtype(self): + return _dtypes.dtype(self.tensor.dtype) + + @property + def strides(self): + elsize = self.tensor.element_size() + return tuple(stride * elsize for stride in self.tensor.stride()) + + @property + def itemsize(self): + return self.tensor.element_size() + + @property + def flags(self): + # Note contiguous in torch is assumed C-style + return Flags( + { + "C_CONTIGUOUS": self.tensor.is_contiguous(), + "F_CONTIGUOUS": self.T.tensor.is_contiguous(), + "OWNDATA": self.tensor._base is None, + "WRITEABLE": True, # pytorch does not have readonly tensors + } + ) + + @property + def data(self): + return self.tensor.data_ptr() + + @property + def nbytes(self): + return self.tensor.storage().nbytes() + + @property + def T(self): + return self.transpose() + + @property + def real(self): + return _funcs.real(self) + + @real.setter + def real(self, value): + self.tensor.real = asarray(value).tensor + + @property + def imag(self): + return _funcs.imag(self) + + @imag.setter + def imag(self, value): + self.tensor.imag = asarray(value).tensor + + # ctors + def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): + if order != "K": + raise NotImplementedError(f"astype(..., order={order} is not implemented.") + if casting != "unsafe": + raise NotImplementedError( + f"astype(..., casting={casting} is not implemented." 
+ ) + if not subok: + raise NotImplementedError(f"astype(..., subok={subok} is not implemented.") + if not copy: + raise NotImplementedError(f"astype(..., copy={copy} is not implemented.") + torch_dtype = _dtypes.dtype(dtype).torch_dtype + t = self.tensor.to(torch_dtype) + return ndarray(t) + + @normalizer + def copy(self: ArrayLike, order: NotImplementedType = "C"): + return self.clone() + + @normalizer + def flatten(self: ArrayLike, order: NotImplementedType = "C"): + return torch.flatten(self) + + def resize(self, *new_shape, refcheck=False): + # NB: differs from np.resize: fills with zeros instead of making repeated copies of input. + if refcheck: + raise NotImplementedError( + f"resize(..., refcheck={refcheck} is not implemented." + ) + if new_shape in [(), (None,)]: + return + + # support both x.resize((2, 2)) and x.resize(2, 2) + if len(new_shape) == 1: + new_shape = new_shape[0] + if isinstance(new_shape, int): + new_shape = (new_shape,) + + if builtins.any(x < 0 for x in new_shape): + raise ValueError("all elements of `new_shape` must be non-negative") + + new_numel, old_numel = math.prod(new_shape), self.tensor.numel() + + self.tensor.resize_(new_shape) + + if new_numel >= old_numel: + # zero-fill new elements + assert self.tensor.is_contiguous() + b = self.tensor.flatten() # does not copy + b[old_numel:].zero_() + + def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified): + if dtype is _Unspecified.unspecified: + dtype = self.dtype + if type is not _Unspecified.unspecified: + raise NotImplementedError(f"view(..., type={type} is not implemented.") + torch_dtype = _dtypes.dtype(dtype).torch_dtype + tview = self.tensor.view(torch_dtype) + return ndarray(tview) + + @normalizer + def fill(self, value: ArrayLike): + # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and + # error out on D > 0 arrays + self.tensor.fill_(value) + + def tolist(self): + return self.tensor.tolist() + + def __iter__(self): + return (ndarray(x) for x in self.tensor.__iter__()) + + def __str__(self): + return ( + str(self.tensor) + .replace("tensor", "torch.ndarray") + .replace("dtype=torch.", "dtype=") + ) + + __repr__ = create_method(__str__) + + def __eq__(self, other): + try: + return _ufuncs.equal(self, other) + except (RuntimeError, TypeError): + # Failed to convert other to array: definitely not equal. 
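+            # Returning an all-False array of self.shape also makes __ne__ (defined
+            # below as ~(self == other)) evaluate to all-True for such operands.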
+ falsy = torch.full(self.shape, fill_value=False, dtype=bool) + return asarray(falsy) + + def __ne__(self, other): + return ~(self == other) + + def __index__(self): + try: + return operator.index(self.tensor.item()) + except Exception as exc: + raise TypeError( + "only integer scalar arrays can be converted to a scalar index" + ) from exc + + def __bool__(self): + return bool(self.tensor) + + def __int__(self): + return int(self.tensor) + + def __float__(self): + return float(self.tensor) + + def __complex__(self): + return complex(self.tensor) + + def is_integer(self): + try: + v = self.tensor.item() + result = int(v) == v + except Exception: + result = False + return result + + def __len__(self): + return self.tensor.shape[0] + + def __contains__(self, x): + return self.tensor.__contains__(x) + + def transpose(self, *axes): + # np.transpose(arr, axis=None) but arr.transpose(*axes) + return _funcs.transpose(self, axes) + + def reshape(self, *shape, order="C"): + # arr.reshape(shape) and arr.reshape(*shape) + return _funcs.reshape(self, shape, order=order) + + def sort(self, axis=-1, kind=None, order=None): + # ndarray.sort works in-place + _funcs.copyto(self, _funcs.sort(self, axis, kind, order)) + + def item(self, *args): + # Mimic NumPy's implementation with three special cases (no arguments, + # a flat index and a multi-index): + # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702 + if args == (): + return self.tensor.item() + elif len(args) == 1: + # int argument + return self.ravel()[args[0]] + else: + return self.__getitem__(args) + + def __getitem__(self, index): + tensor = self.tensor + + def neg_step(i, s): + if not (isinstance(s, slice) and s.step is not None and s.step < 0): + return s + + nonlocal tensor + tensor = torch.flip(tensor, (i,)) + + # Account for the fact that a slice includes the start but not the end + assert isinstance(s.start, int) or s.start is None + assert isinstance(s.stop, int) or s.stop is None + start = s.stop + 1 if s.stop else None + stop = s.start + 1 if s.start else None + + return slice(start, stop, -s.step) + + if isinstance(index, Sequence): + index = type(index)(neg_step(i, s) for i, s in enumerate(index)) + else: + index = neg_step(0, index) + index = _util.ndarrays_to_tensors(index) + index = _upcast_int_indices(index) + return ndarray(tensor.__getitem__(index)) + + def __setitem__(self, index, value): + index = _util.ndarrays_to_tensors(index) + index = _upcast_int_indices(index) + + if not _dtypes_impl.is_scalar(value): + value = normalize_array_like(value) + value = _util.cast_if_needed(value, self.tensor.dtype) + + return self.tensor.__setitem__(index, value) + + take = _funcs.take + put = _funcs.put + + def __dlpack__(self, *, stream=None): + return self.tensor.__dlpack__(stream=stream) + + def __dlpack_device__(self): + return self.tensor.__dlpack_device__() + + +def _tolist(obj): + """Recursively convert tensors into lists.""" + a1 = [] + for elem in obj: + if isinstance(elem, (list, tuple)): + elem = _tolist(elem) + if isinstance(elem, ndarray): + a1.append(elem.tensor.tolist()) + else: + a1.append(elem) + return a1 + + +# This is the ideally the only place which talks to ndarray directly. +# The rest goes through asarray (preferred) or array. 
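[Editor's note] For orientation, a minimal usage sketch of the wrapper layer added above. It is illustrative only: it assumes this module is importable as torch._numpy._ndarray (the path implied by the `from ._ndarray import ...` statements later in this diff) and that _dtypes.dtype accepts string names such as "float32"; everything else follows directly from the definitions shown here.

    import torch
    # module path assumed from the `from ._ndarray import ...` references in this diff
    from torch._numpy._ndarray import array, asarray, ndarray, result_type

    a = asarray([[1, 2], [3, 4]])              # wraps a torch.Tensor; no copy when input is already an ndarray
    assert isinstance(a, ndarray)
    assert isinstance(a.tensor, torch.Tensor)  # the underlying Tensor is exposed as .tensor
    print(a.shape, a.ndim)                     # (2, 2) 2 -- .shape is returned as a plain tuple

    b = array(a, dtype="float32")              # array(...) copies by default; "float32" assumed accepted by _dtypes.dtype
    print(b.dtype)                             # a _dtypes.dtype wrapper around torch.float32

    print(result_type(a, b))                   # promotes across inputs via _dtypes_impl.result_type_impl

As the comment above notes, asarray is the preferred no-copy entry point; array(...) is the copying constructor, and constructing ndarray directly is discouraged.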
+ + +def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None): + if subok is not False: + raise NotImplementedError("'subok' parameter is not supported.") + if like is not None: + raise NotImplementedError("'like' parameter is not supported.") + if order != "K": + raise NotImplementedError() + + # a happy path + if ( + isinstance(obj, ndarray) + and copy is False + and dtype is None + and ndmin <= obj.ndim + ): + return obj + + if isinstance(obj, (list, tuple)): + # FIXME and they have the same dtype, device, etc + if obj and all(isinstance(x, torch.Tensor) for x in obj): + # list of arrays: *under torch.Dynamo* these are FakeTensors + obj = torch.stack(obj) + else: + # XXX: remove tolist + # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists + obj = _tolist(obj) + + # is obj an ndarray already? + if isinstance(obj, ndarray): + obj = obj.tensor + + # is a specific dtype requested? + torch_dtype = None + if dtype is not None: + torch_dtype = _dtypes.dtype(dtype).torch_dtype + + tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin) + return ndarray(tensor) + + +def asarray(a, dtype=None, order="K", *, like=None): + return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0) + + +def ascontiguousarray(a, dtype=None, *, like=None): + arr = asarray(a, dtype=dtype, like=like) + if not arr.tensor.is_contiguous(): + arr.tensor = arr.tensor.contiguous() + return arr + + +def from_dlpack(x, /): + t = torch.from_dlpack(x) + return ndarray(t) + + +def _extract_dtype(entry): + try: + dty = _dtypes.dtype(entry) + except Exception: + dty = asarray(entry).dtype + return dty + + +def can_cast(from_, to, casting="safe"): + from_ = _extract_dtype(from_) + to_ = _extract_dtype(to) + + return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting) + + +def result_type(*arrays_and_dtypes): + tensors = [] + for entry in arrays_and_dtypes: + try: + t = asarray(entry).tensor + except (RuntimeError, ValueError, TypeError): + dty = _dtypes.dtype(entry) + t = torch.empty(1, dtype=dty.torch_dtype) + tensors.append(t) + + torch_dtype = _dtypes_impl.result_type_impl(*tensors) + return _dtypes.dtype(torch_dtype) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_normalizations.py b/MLPY/Lib/site-packages/torch/_numpy/_normalizations.py new file mode 100644 index 0000000000000000000000000000000000000000..f2167e25554d782b30c7e445ae52560fb92ed397 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_normalizations.py @@ -0,0 +1,258 @@ +# mypy: ignore-errors + +""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on. +""" +from __future__ import annotations + +import functools +import inspect +import operator +import typing + +import torch + +from . import _dtypes, _dtypes_impl, _util + +ArrayLike = typing.TypeVar("ArrayLike") +Scalar = typing.Union[int, float, complex, bool] +ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar] + +DTypeLike = typing.TypeVar("DTypeLike") +AxisLike = typing.TypeVar("AxisLike") +NDArray = typing.TypeVar("NDArray") +CastingModes = typing.TypeVar("CastingModes") +KeepDims = typing.TypeVar("KeepDims") + +# OutArray is to annotate the out= array argument. +# +# This one is special is several respects: +# First, It needs to be an NDArray, and we need to preserve the `result is out` +# semantics. Therefore, we cannot just extract the Tensor from the out array. +# So we never pass the out array to implementer functions and handle it in the +# `normalizer` below. 
+# Second, the out= argument can be either keyword or positional argument, and +# as a positional arg, it can be anywhere in the signature. +# To handle all this, we define a special `OutArray` annotation and dispatch on it. +# +OutArray = typing.TypeVar("OutArray") + +try: + from typing import NotImplementedType +except ImportError: + NotImplementedType = typing.TypeVar("NotImplementedType") + + +def normalize_array_like(x, parm=None): + from ._ndarray import asarray + + return asarray(x).tensor + + +def normalize_array_like_or_scalar(x, parm=None): + if _dtypes_impl.is_scalar_or_symbolic(x): + return x + return normalize_array_like(x, parm) + + +def normalize_optional_array_like_or_scalar(x, parm=None): + if x is None: + return None + return normalize_array_like_or_scalar(x, parm) + + +def normalize_optional_array_like(x, parm=None): + # This explicit normalizer is needed because otherwise normalize_array_like + # does not run for a parameter annotated as Optional[ArrayLike] + return None if x is None else normalize_array_like(x, parm) + + +def normalize_seq_array_like(x, parm=None): + return tuple(normalize_array_like(value) for value in x) + + +def normalize_dtype(dtype, parm=None): + # cf _decorators.dtype_to_torch + torch_dtype = None + if dtype is not None: + dtype = _dtypes.dtype(dtype) + torch_dtype = dtype.torch_dtype + return torch_dtype + + +def normalize_not_implemented(arg, parm): + if arg != parm.default: + raise NotImplementedError(f"'{parm.name}' parameter is not supported.") + + +def normalize_axis_like(arg, parm=None): + from ._ndarray import ndarray + + if isinstance(arg, ndarray): + arg = operator.index(arg) + return arg + + +def normalize_ndarray(arg, parm=None): + # check the arg is an ndarray, extract its tensor attribute + if arg is None: + return arg + + from ._ndarray import ndarray + + if not isinstance(arg, ndarray): + raise TypeError(f"'{parm.name}' must be an array") + return arg.tensor + + +def normalize_outarray(arg, parm=None): + # almost normalize_ndarray, only return the array, not its tensor + if arg is None: + return arg + from ._ndarray import ndarray + + # Dynamo can pass torch tensors as out arguments, + # wrap it in an ndarray before processing + if isinstance(arg, torch.Tensor): + arg = ndarray(arg) + + if not isinstance(arg, ndarray): + raise TypeError(f"'{parm.name}' must be an array") + return arg + + +def normalize_casting(arg, parm=None): + if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]: + raise ValueError( + f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')" + ) + return arg + + +normalizers = { + "ArrayLike": normalize_array_like, + "ArrayLikeOrScalar": normalize_array_like_or_scalar, + "Optional[ArrayLike]": normalize_optional_array_like, + "Sequence[ArrayLike]": normalize_seq_array_like, + "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar, + "Optional[NDArray]": normalize_ndarray, + "Optional[OutArray]": normalize_outarray, + "NDArray": normalize_ndarray, + "Optional[DTypeLike]": normalize_dtype, + "AxisLike": normalize_axis_like, + "NotImplementedType": normalize_not_implemented, + "Optional[CastingModes]": normalize_casting, +} + + +def maybe_normalize(arg, parm): + """Normalize arg if a normalizer is registered.""" + normalizer = normalizers.get(parm.annotation, None) + return normalizer(arg, parm) if normalizer else arg + + +# ### Return value helpers ### + + +def maybe_copy_to(out, result, promote_scalar_result=False): + # NB: here out is either an ndarray or 
None + if out is None: + return result + elif isinstance(result, torch.Tensor): + if result.shape != out.shape: + can_fit = result.numel() == 1 and out.ndim == 0 + if promote_scalar_result and can_fit: + result = result.squeeze() + else: + raise ValueError( + f"Bad size of the out array: out.shape = {out.shape}" + f" while result.shape = {result.shape}." + ) + out.tensor.copy_(result) + return out + elif isinstance(result, (tuple, list)): + return type(result)( + maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result) + ) + else: + raise AssertionError() # We should never hit this path + + +def wrap_tensors(result): + from ._ndarray import ndarray + + if isinstance(result, torch.Tensor): + return ndarray(result) + elif isinstance(result, (tuple, list)): + result = type(result)(wrap_tensors(x) for x in result) + return result + + +def array_or_scalar(values, py_type=float, return_scalar=False): + if return_scalar: + return py_type(values.item()) + else: + from ._ndarray import ndarray + + return ndarray(values) + + +# ### The main decorator to normalize arguments / postprocess the output ### + + +def normalizer(_func=None, *, promote_scalar_result=False): + def normalizer_inner(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + sig = inspect.signature(func) + params = sig.parameters + first_param = next(iter(params.values())) + + # NumPy's API does not have positional args before variadic positional args + if first_param.kind == inspect.Parameter.VAR_POSITIONAL: + args = [maybe_normalize(arg, first_param) for arg in args] + else: + # NB: extra unknown arguments: pass through, will raise in func(*args) below + args = ( + tuple( + maybe_normalize(arg, parm) + for arg, parm in zip(args, params.values()) + ) + + args[len(params.values()) :] + ) + + kwds = { + name: maybe_normalize(arg, params[name]) if name in params else arg + for name, arg in kwds.items() + } + + result = func(*args, **kwds) + + # keepdims + bound_args = None + if "keepdims" in params and params["keepdims"].annotation == "KeepDims": + # keepdims can be in any position so we need sig.bind + bound_args = sig.bind(*args, **kwds).arguments + if bound_args.get("keepdims", False): + # In this case the first arg is the initial tensor and + # the second arg is (optionally) the axis + tensor = args[0] + axis = bound_args.get("axis") + result = _util.apply_keepdims(result, axis, tensor.ndim) + + # out + if "out" in params: + # out can be in any position so we need sig.bind + if bound_args is None: + bound_args = sig.bind(*args, **kwds).arguments + out = bound_args.get("out") + result = maybe_copy_to(out, result, promote_scalar_result) + result = wrap_tensors(result) + + return result + + return wrapped + + if _func is None: + return normalizer_inner + else: + return normalizer_inner(_func) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_reductions_impl.py b/MLPY/Lib/site-packages/torch/_numpy/_reductions_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0bba1bc12eb6de98259867132eb2d51fb9f941 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_reductions_impl.py @@ -0,0 +1,456 @@ +# mypy: ignore-errors + +""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc +in the 'public' layer. + +Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc +""" +from __future__ import annotations + +import functools +from typing import Optional + +import torch + +from . 
import _dtypes_impl, _util +from ._normalizations import ( + ArrayLike, + AxisLike, + DTypeLike, + KeepDims, + NotImplementedType, + OutArray, +) + + +def _deco_axis_expand(func): + """ + Generically handle axis arguments in reductions. + axis is *always* the 2nd arg in the function so no need to have a look at its signature + """ + + @functools.wraps(func) + def wrapped(a, axis=None, *args, **kwds): + if axis is not None: + axis = _util.normalize_axis_tuple(axis, a.ndim) + + if axis == (): + # So we insert a length-one axis and run the reduction along it. + # We cannot return a.clone() as this would sidestep the checks inside the function + newshape = _util.expand_shape(a.shape, axis=0) + a = a.reshape(newshape) + axis = (0,) + + return func(a, axis, *args, **kwds) + + return wrapped + + +def _atleast_float(dtype, other_dtype): + """Return a dtype that is real or complex floating-point. + + For inputs that are boolean or integer dtypes, this returns the default + float dtype; inputs that are complex get converted to the default complex + dtype; real floating-point dtypes (`float*`) get passed through unchanged + """ + if dtype is None: + dtype = other_dtype + if not (dtype.is_floating_point or dtype.is_complex): + return _dtypes_impl.default_dtypes().float_dtype + return dtype + + +@_deco_axis_expand +def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False): + return a.count_nonzero(axis) + + +@_deco_axis_expand +def argmax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + *, + keepdims: KeepDims = False, +): + if a.is_complex(): + raise NotImplementedError(f"argmax with dtype={a.dtype}.") + + axis = _util.allow_only_single_axis(axis) + + if a.dtype == torch.bool: + # RuntimeError: "argmax_cpu" not implemented for 'Bool' + a = a.to(torch.uint8) + + return torch.argmax(a, axis) + + +@_deco_axis_expand +def argmin( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + *, + keepdims: KeepDims = False, +): + if a.is_complex(): + raise NotImplementedError(f"argmin with dtype={a.dtype}.") + + axis = _util.allow_only_single_axis(axis) + + if a.dtype == torch.bool: + # RuntimeError: "argmin_cpu" not implemented for 'Bool' + a = a.to(torch.uint8) + + return torch.argmin(a, axis) + + +@_deco_axis_expand +def any( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + axis_kw = {} if axis is None else {"dim": axis} + return torch.any(a, **axis_kw) + + +@_deco_axis_expand +def all( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + axis_kw = {} if axis is None else {"dim": axis} + return torch.all(a, **axis_kw) + + +@_deco_axis_expand +def amax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + if a.is_complex(): + raise NotImplementedError(f"amax with dtype={a.dtype}") + + return a.amax(axis) + + +max = amax + + +@_deco_axis_expand +def amin( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + if a.is_complex(): + raise NotImplementedError(f"amin with dtype={a.dtype}") + + 
return a.amin(axis) + + +min = amin + + +@_deco_axis_expand +def ptp( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, +): + return a.amax(axis) - a.amin(axis) + + +@_deco_axis_expand +def sum( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + assert dtype is None or isinstance(dtype, torch.dtype) + + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + + axis_kw = {} if axis is None else {"dim": axis} + return a.sum(dtype=dtype, **axis_kw) + + +@_deco_axis_expand +def prod( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + + axis_kw = {} if axis is None else {"dim": axis} + return a.prod(dtype=dtype, **axis_kw) + + +product = prod + + +@_deco_axis_expand +def mean( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + dtype = _atleast_float(dtype, a.dtype) + + axis_kw = {} if axis is None else {"dim": axis} + result = a.mean(dtype=dtype, **axis_kw) + + return result + + +@_deco_axis_expand +def std( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + ddof=0, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + in_dtype = dtype + dtype = _atleast_float(dtype, a.dtype) + tensor = _util.cast_if_needed(a, dtype) + result = tensor.std(dim=axis, correction=ddof) + return _util.cast_if_needed(result, in_dtype) + + +@_deco_axis_expand +def var( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + ddof=0, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + in_dtype = dtype + dtype = _atleast_float(dtype, a.dtype) + tensor = _util.cast_if_needed(a, dtype) + result = tensor.var(dim=axis, correction=ddof) + return _util.cast_if_needed(result, in_dtype) + + +# cumsum / cumprod are almost reductions: +# 1. no keepdims +# 2. 
axis=None flattens + + +def cumsum( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + if dtype is None: + dtype = a.dtype + + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + + return a.cumsum(axis=axis, dtype=dtype) + + +def cumprod( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + if dtype is None: + dtype = a.dtype + + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + + return a.cumprod(axis=axis, dtype=dtype) + + +cumproduct = cumprod + + +def average( + a: ArrayLike, + axis=None, + weights: ArrayLike = None, + returned=False, + *, + keepdims=False, +): + if weights is None: + result = mean(a, axis=axis) + wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype) + else: + if not a.dtype.is_floating_point: + a = a.double() + + # axis & weights + if a.shape != weights.shape: + if axis is None: + raise TypeError( + "Axis must be specified when shapes of a and weights differ." + ) + if weights.ndim != 1: + raise TypeError( + "1D weights expected when shapes of a and weights differ." + ) + if weights.shape[0] != a.shape[axis]: + raise ValueError( + "Length of weights not compatible with specified axis." + ) + + # setup weight to broadcast along axis + weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape) + weights = weights.swapaxes(-1, axis) + + # do the work + result_dtype = _dtypes_impl.result_type_impl(a, weights) + numerator = sum(a * weights, axis, dtype=result_dtype) + wsum = sum(weights, axis, dtype=result_dtype) + result = numerator / wsum + + # We process keepdims manually because the decorator does not deal with variadic returns + if keepdims: + result = _util.apply_keepdims(result, axis, a.ndim) + + if returned: + if wsum.shape != result.shape: + wsum = torch.broadcast_to(wsum, result.shape).clone() + return result, wsum + else: + return result + + +# Not using deco_axis_expand as it assumes that axis is the second arg +def quantile( + a: ArrayLike, + q: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + overwrite_input=False, + method="linear", + keepdims: KeepDims = False, + *, + interpolation: NotImplementedType = None, +): + if overwrite_input: + # raise NotImplementedError("overwrite_input in quantile not implemented.") + # NumPy documents that `overwrite_input` MAY modify inputs: + # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile + # Here we choose to work out-of-place because why not. + pass + + if not a.dtype.is_floating_point: + dtype = _dtypes_impl.default_dtypes().float_dtype + a = a.to(dtype) + + # edge case: torch.quantile only supports float32 and float64 + if a.dtype == torch.float16: + a = a.to(torch.float32) + + if axis is None: + a = a.flatten() + q = q.flatten() + axis = (0,) + else: + axis = _util.normalize_axis_tuple(axis, a.ndim) + + # FIXME(Mario) Doesn't np.quantile accept a tuple? + # torch.quantile does accept a number. If we don't want to implement the tuple behaviour + # (it's deffo low prio) change `normalize_axis_tuple` into a normalize_axis index above. 
+ axis = _util.allow_only_single_axis(axis) + + q = _util.cast_if_needed(q, a.dtype) + + return torch.quantile(a, q, axis=axis, interpolation=method) + + +def percentile( + a: ArrayLike, + q: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + overwrite_input=False, + method="linear", + keepdims: KeepDims = False, + *, + interpolation: NotImplementedType = None, +): + # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32 + if _dtypes_impl.python_type_for_torch(q.dtype) == int: + q = q.to(_dtypes_impl.default_dtypes().float_dtype) + qq = q / 100.0 + + return quantile( + a, + qq, + axis=axis, + overwrite_input=overwrite_input, + method=method, + keepdims=keepdims, + interpolation=interpolation, + ) + + +def median( + a: ArrayLike, + axis=None, + out: Optional[OutArray] = None, + overwrite_input=False, + keepdims: KeepDims = False, +): + return quantile( + a, + torch.as_tensor(0.5), + axis=axis, + overwrite_input=overwrite_input, + out=out, + keepdims=keepdims, + ) diff --git a/MLPY/Lib/site-packages/torch/_numpy/_ufuncs.py b/MLPY/Lib/site-packages/torch/_numpy/_ufuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..139aa89ebc5016dad0c522efaa2970c308f2df59 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_ufuncs.py @@ -0,0 +1,334 @@ +# mypy: ignore-errors + +from __future__ import annotations + +from typing import Optional + +import torch + +from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util +from ._normalizations import ( + ArrayLike, + ArrayLikeOrScalar, + CastingModes, + DTypeLike, + normalizer, + NotImplementedType, + OutArray, +) + + +def _ufunc_postprocess(result, out, casting): + if out is not None: + result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting) + result = torch.broadcast_to(result, out.shape) + return result + + +# ############# Binary ufuncs ###################### + +_binary = [ + name + for name in dir(_binary_ufuncs_impl) + if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"] +] + + +NEP50_FUNCS = ( + "add", + "subtract", + "multiply", + "floor_divide", + "true_divide", + "divide", + "remainder", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "hypot", + "arctan2", + "logaddexp", + "logaddexp2", + "heaviside", + "copysign", + "fmax", + "minimum", + "fmin", + "maximum", + "fmod", + "gcd", + "lcm", + "pow", +) + + +def deco_binary_ufunc(torch_func): + """Common infra for binary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. 
+ """ + + @normalizer + def wrapped( + x1: ArrayLikeOrScalar, + x2: ArrayLikeOrScalar, + /, + out: Optional[OutArray] = None, + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, + ): + if dtype is not None: + + def cast(x, dtype): + if isinstance(x, torch.Tensor): + return _util.typecast_tensor(x, dtype, casting) + else: + return torch.as_tensor(x, dtype=dtype) + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + else: + x1, x2 = _dtypes_impl.nep50_to_tensors( + x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__ + ) + + result = torch_func(x1, x2) + + return _ufunc_postprocess(result, out, casting) + + wrapped.__qualname__ = torch_func.__name__ + wrapped.__name__ = torch_func.__name__ + + return wrapped + + +# matmul's signature is _slightly_ different from other ufuncs: +# - no where=... +# - additional axis=..., axes=... +# - no NEP50 scalars in or out +@normalizer +def matmul( + x1: ArrayLike, + x2: ArrayLike, + /, + out: Optional[OutArray] = None, + *, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, + axes: NotImplementedType = None, + axis: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + + result = _binary_ufuncs_impl.matmul(x1, x2) + + result = _ufunc_postprocess(result, out, casting) + return result + + +# ldexp casting is special : the dtype of the result == dtype of the 1st arg +@normalizer +def ldexp( + x1: ArrayLikeOrScalar, + x2: ArrayLikeOrScalar, + /, + out: Optional[OutArray] = None, + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, +): + if dtype is not None: + if isinstance(x1, torch.Tensor): + x1 = _util.typecast_tensor(x1, dtype, casting) + else: + x1 = torch.as_tensor(x1, dtype=dtype) + else: + if not isinstance(x1, torch.Tensor): + x1 = torch.as_tensor(x1) + x1 = _util.cast_int_to_float(x1) + + x2 = torch.as_tensor(x2) + # the second arg must be integer + if _dtypes_impl._category(x2.dtype) != 1: + raise ValueError("ldexp 2nd arg must be integer") + + result = _binary_ufuncs_impl.ldexp(x1, x2) + + if x1.dtype == torch.float16: + # torch.ldexp(f16, int) -> f32, undo it + result = result.to(torch.float16) + + return _ufunc_postprocess(result, out, casting) + + +# nin=2, nout=2 +@normalizer +def divmod( + x1: ArrayLike, + x2: ArrayLike, + out1: Optional[OutArray] = None, + out2: Optional[OutArray] = None, + /, + out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None), + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, +): 
+ # make sure we either have no out arrays at all, or there is either + # out1, out2, or out=tuple, but not both + num_outs = sum(x is not None for x in [out1, out2]) + if num_outs == 1: + raise ValueError("both out1 and out2 need to be provided") + elif num_outs == 2: + o1, o2 = out + if o1 is not None or o2 is not None: + raise TypeError( + "cannot specify 'out' as both a positional and keyword argument" + ) + else: + out1, out2 = out + + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + + quot, rem = _binary_ufuncs_impl.divmod(x1, x2) + + quot = _ufunc_postprocess(quot, out1, casting) + rem = _ufunc_postprocess(rem, out2, casting) + return quot, rem + + +# +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py +# +for name in _binary: + ufunc = getattr(_binary_ufuncs_impl, name) + vars()[name] = deco_binary_ufunc(ufunc) + + +def modf(x, /, *args, **kwds): + quot, rem = divmod(x, 1, *args, **kwds) + return rem, quot + + +_binary = _binary + ["divmod", "modf", "matmul", "ldexp"] + + +# ############# Unary ufuncs ###################### + + +_unary = [ + name + for name in dir(_unary_ufuncs_impl) + if not name.startswith("_") and name != "torch" +] + + +# these are ufunc(int) -> float +_fp_unary = [ + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "cbrt", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "exp2", + "expm1", + "log", + "log10", + "log1p", + "log2", + "rad2deg", + "radians", + "reciprocal", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +] + + +def deco_unary_ufunc(torch_func): + """Common infra for unary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. + """ + + @normalizer + def wrapped( + x: ArrayLike, + /, + out: Optional[OutArray] = None, + *, + where=True, + casting: Optional[CastingModes] = "same_kind", + order="K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature=None, + extobj=None, + ): + if dtype is not None: + x = _util.typecast_tensor(x, dtype, casting) + + if torch_func.__name__ in _fp_unary: + x = _util.cast_int_to_float(x) + + result = torch_func(x) + result = _ufunc_postprocess(result, out, casting) + return result + + wrapped.__qualname__ = torch_func.__name__ + wrapped.__name__ = torch_func.__name__ + + return wrapped + + +# +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py +# +for name in _unary: + ufunc = getattr(_unary_ufuncs_impl, name) + vars()[name] = deco_unary_ufunc(ufunc) + + +__all__ = _binary + _unary # noqa: PLE0605 diff --git a/MLPY/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py b/MLPY/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8678f87816a36cae55bec7d525bf514f5149c3f5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py @@ -0,0 +1,73 @@ +# mypy: ignore-errors + +"""Export torch work functions for unary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `_numpy/_ufuncs.py` module. 
+""" + +import torch + +from torch import ( # noqa: F401 + absolute as fabs, # noqa: F401 + arccos, # noqa: F401 + arccosh, # noqa: F401 + arcsin, # noqa: F401 + arcsinh, # noqa: F401 + arctan, # noqa: F401 + arctanh, # noqa: F401 + bitwise_not, # noqa: F401 + bitwise_not as invert, # noqa: F401 + ceil, # noqa: F401 + conj_physical as conjugate, # noqa: F401 + cos, # noqa: F401 + cosh, # noqa: F401 + deg2rad, # noqa: F401 + deg2rad as radians, # noqa: F401 + exp, # noqa: F401 + exp2, # noqa: F401 + expm1, # noqa: F401 + floor, # noqa: F401 + isfinite, # noqa: F401 + isinf, # noqa: F401 + isnan, # noqa: F401 + log, # noqa: F401 + log10, # noqa: F401 + log1p, # noqa: F401 + log2, # noqa: F401 + logical_not, # noqa: F401 + negative, # noqa: F401 + rad2deg, # noqa: F401 + rad2deg as degrees, # noqa: F401 + reciprocal, # noqa: F401 + round as fix, # noqa: F401 + round as rint, # noqa: F401 + sign, # noqa: F401 + signbit, # noqa: F401 + sin, # noqa: F401 + sinh, # noqa: F401 + sqrt, # noqa: F401 + square, # noqa: F401 + tan, # noqa: F401 + tanh, # noqa: F401 + trunc, # noqa: F401 +) + + +# special cases: torch does not export these names +def cbrt(x): + return torch.pow(x, 1 / 3) + + +def positive(x): + return +x + + +def absolute(x): + # work around torch.absolute not impl for bools + if x.dtype == torch.bool: + return x + return torch.absolute(x) + + +# TODO set __name__ and __qualname__ +abs = absolute +conj = conjugate diff --git a/MLPY/Lib/site-packages/torch/_numpy/_util.py b/MLPY/Lib/site-packages/torch/_numpy/_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c147bd30550972625a0b8fb424568096654c047c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/_util.py @@ -0,0 +1,261 @@ +# mypy: ignore-errors + +"""Assorted utilities, which do not need anything other then torch and stdlib. +""" + +import operator + +import torch + +from . import _dtypes_impl + + +# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504 +def is_sequence(seq): + if isinstance(seq, str): + return False + try: + len(seq) + except Exception: + return False + return True + + +class AxisError(ValueError, IndexError): + pass + + +class UFuncTypeError(TypeError, RuntimeError): + pass + + +def cast_if_needed(tensor, dtype): + # NB: no casting if dtype=None + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype) + return tensor + + +def cast_int_to_float(x): + # cast integers and bools to the default float dtype + if _dtypes_impl._category(x.dtype) < 2: + x = x.to(_dtypes_impl.default_dtypes().float_dtype) + return x + + +# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h +def normalize_axis_index(ax, ndim, argname=None): + if not (-ndim <= ax < ndim): + raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}") + if ax < 0: + ax += ndim + return ax + + +# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378 +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + Used internally by multi-axis-checking logic. + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. 
+ ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + """ + # Optimization to speed-up the most common cases. + if type(axis) not in (tuple, list): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) + if not allow_duplicate and len(set(axis)) != len(axis): + if argname: + raise ValueError(f"repeated axis in `{argname}` argument") + else: + raise ValueError("repeated axis") + return axis + + +def allow_only_single_axis(axis): + if axis is None: + return axis + if len(axis) != 1: + raise NotImplementedError("does not handle tuple axis") + return axis[0] + + +def expand_shape(arr_shape, axis): + # taken from numpy 1.23.x, expand_dims function + if type(axis) not in (list, tuple): + axis = (axis,) + out_ndim = len(axis) + len(arr_shape) + axis = normalize_axis_tuple(axis, out_ndim) + shape_it = iter(arr_shape) + shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] + return shape + + +def apply_keepdims(tensor, axis, ndim): + if axis is None: + # tensor was a scalar + shape = (1,) * ndim + tensor = tensor.expand(shape).contiguous() + else: + shape = expand_shape(tensor.shape, axis) + tensor = tensor.reshape(shape) + return tensor + + +def axis_none_flatten(*tensors, axis=None): + """Flatten the arrays if axis is None.""" + if axis is None: + tensors = tuple(ar.flatten() for ar in tensors) + return tensors, 0 + else: + return tensors, axis + + +def typecast_tensor(t, target_dtype, casting): + """Dtype-cast tensor to target_dtype. + + Parameters + ---------- + t : torch.Tensor + The tensor to cast + target_dtype : torch dtype object + The array dtype to cast all tensors to + casting : str + The casting mode, see `np.can_cast` + + Returns + ------- + `torch.Tensor` of the `target_dtype` dtype + + Raises + ------ + ValueError + if the argument cannot be cast according to the `casting` rule + + """ + can_cast = _dtypes_impl.can_cast_impl + + if not can_cast(t.dtype, target_dtype, casting=casting): + raise TypeError( + f"Cannot cast array data from {t.dtype} to" + f" {target_dtype} according to the rule '{casting}'" + ) + return cast_if_needed(t, target_dtype) + + +def typecast_tensors(tensors, target_dtype, casting): + return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors) + + +def _try_convert_to_tensor(obj): + try: + tensor = torch.as_tensor(obj) + except Exception as e: + mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." + raise NotImplementedError(mesg) # noqa: TRY200 + return tensor + + +def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): + """The core logic of the array(...) function. + + Parameters + ---------- + obj : tensor_like + The thing to coerce + dtype : torch.dtype object or None + Coerce to this torch dtype + copy : bool + Copy or not + ndmin : int + The results as least this many dimensions + is_weak : bool + Whether obj is a weakly typed python scalar. + + Returns + ------- + tensor : torch.Tensor + a tensor object with requested dtype, ndim and copy semantics. 
+ + Notes + ----- + This is almost a "tensor_like" coersion function. Does not handle wrapper + ndarrays (those should be handled in the ndarray-aware layer prior to + invoking this function). + """ + if isinstance(obj, torch.Tensor): + tensor = obj + else: + # tensor.dtype is the pytorch default, typically float32. If obj's elements + # are not exactly representable in float32, we've lost precision: + # >>> torch.as_tensor(1e12).item() - 1e12 + # -4096.0 + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32)) + try: + tensor = _try_convert_to_tensor(obj) + finally: + torch.set_default_dtype(default_dtype) + + # type cast if requested + tensor = cast_if_needed(tensor, dtype) + + # adjust ndim if needed + ndim_extra = ndmin - tensor.ndim + if ndim_extra > 0: + tensor = tensor.view((1,) * ndim_extra + tensor.shape) + + # copy if requested + if copy: + tensor = tensor.clone() + + return tensor + + +def ndarrays_to_tensors(*inputs): + """Convert all ndarrays from `inputs` to tensors. (other things are intact)""" + from ._ndarray import ndarray + + if len(inputs) == 0: + return ValueError() + elif len(inputs) == 1: + input_ = inputs[0] + if isinstance(input_, ndarray): + return input_.tensor + elif isinstance(input_, tuple): + result = [] + for sub_input in input_: + sub_result = ndarrays_to_tensors(sub_input) + result.append(sub_result) + return tuple(result) + else: + return input_ + else: + assert isinstance(inputs, tuple) # sanity check + return ndarrays_to_tensors(inputs) diff --git a/MLPY/Lib/site-packages/torch/_numpy/fft.py b/MLPY/Lib/site-packages/torch/_numpy/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..ac26d8bc787c90023cd6b0e7a4b9abcb336dee92 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/fft.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import functools + +import torch + +from . 
import _dtypes_impl, _util +from ._normalizations import ArrayLike, normalizer + + +def upcast(func): + """NumPy fft casts inputs to 64 bit and *returns 64-bit results*.""" + + @functools.wraps(func) + def wrapped(tensor, *args, **kwds): + target_dtype = ( + _dtypes_impl.default_dtypes().complex_dtype + if tensor.is_complex() + else _dtypes_impl.default_dtypes().float_dtype + ) + tensor = _util.cast_if_needed(tensor, target_dtype) + return func(tensor, *args, **kwds) + + return wrapped + + +@normalizer +@upcast +def fft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.fft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def ifft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ifft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def rfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.rfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def irfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.irfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def fftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.fftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def ifftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.ifftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.rfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.irfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.fft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.ifft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.rfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.irfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def hfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.hfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def ihfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ihfft(a, n, dim=axis, norm=norm) + + +@normalizer +def fftfreq(n, d=1.0): + return torch.fft.fftfreq(n, d) + + +@normalizer +def rfftfreq(n, d=1.0): + return torch.fft.rfftfreq(n, d) + + +@normalizer +def fftshift(x: ArrayLike, axes=None): + return torch.fft.fftshift(x, axes) + + +@normalizer +def ifftshift(x: ArrayLike, axes=None): + return torch.fft.ifftshift(x, axes) diff --git a/MLPY/Lib/site-packages/torch/_numpy/linalg.py b/MLPY/Lib/site-packages/torch/_numpy/linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..a17808c5b454f122d817ede08377f18413686f2c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/linalg.py @@ -0,0 +1,239 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import functools +import math +from typing import Sequence + +import torch + +from . 
import _dtypes_impl, _util +from ._normalizations import ArrayLike, KeepDims, normalizer + + +class LinAlgError(Exception): + pass + + +def _atleast_float_1(a): + if not (a.dtype.is_floating_point or a.dtype.is_complex): + a = a.to(_dtypes_impl.default_dtypes().float_dtype) + return a + + +def _atleast_float_2(a, b): + dtyp = _dtypes_impl.result_type_impl(a, b) + if not (dtyp.is_floating_point or dtyp.is_complex): + dtyp = _dtypes_impl.default_dtypes().float_dtype + + a = _util.cast_if_needed(a, dtyp) + b = _util.cast_if_needed(b, dtyp) + return a, b + + +def linalg_errors(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + try: + return func(*args, **kwds) + except torch._C._LinAlgError as e: + raise LinAlgError(*e.args) # noqa: TRY200 + + return wrapped + + +# ### Matrix and vector products ### + + +@normalizer +@linalg_errors +def matrix_power(a: ArrayLike, n): + a = _atleast_float_1(a) + return torch.linalg.matrix_power(a, n) + + +@normalizer +@linalg_errors +def multi_dot(inputs: Sequence[ArrayLike], *, out=None): + return torch.linalg.multi_dot(inputs) + + +# ### Solving equations and inverting matrices ### + + +@normalizer +@linalg_errors +def solve(a: ArrayLike, b: ArrayLike): + a, b = _atleast_float_2(a, b) + return torch.linalg.solve(a, b) + + +@normalizer +@linalg_errors +def lstsq(a: ArrayLike, b: ArrayLike, rcond=None): + a, b = _atleast_float_2(a, b) + # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991 + # on CUDA, only `gels` is available though, so use it instead + driver = "gels" if a.is_cuda or b.is_cuda else "gelsd" + return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver) + + +@normalizer +@linalg_errors +def inv(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.inv(a) + return result + + +@normalizer +@linalg_errors +def pinv(a: ArrayLike, rcond=1e-15, hermitian=False): + a = _atleast_float_1(a) + return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian) + + +@normalizer +@linalg_errors +def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None): + a, b = _atleast_float_2(a, b) + return torch.linalg.tensorsolve(a, b, dims=axes) + + +@normalizer +@linalg_errors +def tensorinv(a: ArrayLike, ind=2): + a = _atleast_float_1(a) + return torch.linalg.tensorinv(a, ind=ind) + + +# ### Norms and other numbers ### + + +@normalizer +@linalg_errors +def det(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.det(a) + + +@normalizer +@linalg_errors +def slogdet(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.slogdet(a) + + +@normalizer +@linalg_errors +def cond(x: ArrayLike, p=None): + x = _atleast_float_1(x) + + # check if empty + # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + if x.numel() == 0 and math.prod(x.shape[-2:]) == 0: + raise LinAlgError("cond is not defined on empty arrays") + + result = torch.linalg.cond(x, p=p) + + # Convert nans to infs (numpy does it in a data-dependent way, depending on + # whether the input array has nans or not) + # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + return torch.where(torch.isnan(result), float("inf"), result) + + +@normalizer +@linalg_errors +def matrix_rank(a: ArrayLike, tol=None, hermitian=False): + a = _atleast_float_1(a) + + if a.ndim < 2: + return int((a != 0).any()) + + if tol is None: + # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885 + atol = 0 + rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps + 
else: + atol, rtol = tol, 0 + return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian) + + +@normalizer +@linalg_errors +def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False): + x = _atleast_float_1(x) + return torch.linalg.norm(x, ord=ord, dim=axis) + + +# ### Decompositions ### + + +@normalizer +@linalg_errors +def cholesky(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.cholesky(a) + + +@normalizer +@linalg_errors +def qr(a: ArrayLike, mode="reduced"): + a = _atleast_float_1(a) + result = torch.linalg.qr(a, mode=mode) + if mode == "r": + # match NumPy + result = result.R + return result + + +@normalizer +@linalg_errors +def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False): + a = _atleast_float_1(a) + if not compute_uv: + return torch.linalg.svdvals(a) + + # NB: ignore the hermitian= argument (no pytorch equivalent) + result = torch.linalg.svd(a, full_matrices=full_matrices) + return result + + +# ### Eigenvalues and eigenvectors ### + + +@normalizer +@linalg_errors +def eig(a: ArrayLike): + a = _atleast_float_1(a) + w, vt = torch.linalg.eig(a) + + if not a.is_complex() and w.is_complex() and (w.imag == 0).all(): + w = w.real + vt = vt.real + return w, vt + + +@normalizer +@linalg_errors +def eigh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigh(a, UPLO=UPLO) + + +@normalizer +@linalg_errors +def eigvals(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.eigvals(a) + if not a.is_complex() and result.is_complex() and (result.imag == 0).all(): + result = result.real + return result + + +@normalizer +@linalg_errors +def eigvalsh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigvalsh(a, UPLO=UPLO) diff --git a/MLPY/Lib/site-packages/torch/_numpy/random.py b/MLPY/Lib/site-packages/torch/_numpy/random.py new file mode 100644 index 0000000000000000000000000000000000000000..57155b7bf9f081366dac3cfe706bc5b0c7231a2d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/random.py @@ -0,0 +1,191 @@ +# mypy: ignore-errors + +"""Wrapper to mimic (parts of) np.random API surface. + +NumPy has strict guarantees on reproducibility etc; here we don't give any. + +Q: default dtype is float64 in numpy + +""" +from __future__ import annotations + +import functools +from math import sqrt +from typing import Optional + +import torch + +from . 
import _dtypes_impl, _util +from ._normalizations import array_or_scalar, ArrayLike, normalizer + + +__all__ = [ + "seed", + "random_sample", + "sample", + "random", + "rand", + "randn", + "normal", + "choice", + "randint", + "shuffle", + "uniform", +] + + +def use_numpy_random(): + # local import to avoid ref cycles + import torch._dynamo.config as config + + return config.use_numpy_random_stream + + +def deco_stream(func): + @functools.wraps(func) + def inner(*args, **kwds): + if not use_numpy_random(): + return func(*args, **kwds) + else: + import numpy + + from ._ndarray import ndarray + + f = getattr(numpy.random, func.__name__) + + # numpy funcs accept numpy ndarrays, unwrap + args = tuple( + arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args + ) + kwds = { + key: val.tensor.numpy() if isinstance(val, ndarray) else val + for key, val in kwds.items() + } + + value = f(*args, **kwds) + + # `value` can be either numpy.ndarray or python scalar (or None) + if isinstance(value, numpy.ndarray): + value = ndarray(torch.as_tensor(value)) + + return value + + return inner + + +@deco_stream +def seed(seed=None): + if seed is not None: + torch.random.manual_seed(seed) + + +@deco_stream +def random_sample(size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).uniform_() + return array_or_scalar(values, return_scalar=size == ()) + + +def rand(*size): + if size == (): + size = None + return random_sample(size) + + +sample = random_sample +random = random_sample + + +@deco_stream +def uniform(low=0.0, high=1.0, size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).uniform_(low, high) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def randn(*size): + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.randn(size, dtype=dtype) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def normal(loc=0.0, scale=1.0, size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).normal_(loc, scale) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def shuffle(x): + # no @normalizer because we do not cast e.g. 
lists to tensors + from ._ndarray import ndarray + + if isinstance(x, torch.Tensor): + tensor = x + elif isinstance(x, ndarray): + tensor = x.tensor + else: + raise NotImplementedError("We do not random.shuffle lists in-place") + + perm = torch.randperm(tensor.shape[0]) + xp = tensor[perm] + tensor.copy_(xp) + + +@deco_stream +def randint(low, high=None, size=None): + if size is None: + size = () + if not isinstance(size, (tuple, list)): + size = (size,) + if high is None: + low, high = 0, low + values = torch.randint(low, high, size=size) + return array_or_scalar(values, int, return_scalar=size == ()) + + +@deco_stream +@normalizer +def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): + # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch + if a.numel() == 1: + a = torch.arange(a) + + # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises + + # number of draws + if size is None: + num_el = 1 + elif _util.is_sequence(size): + num_el = 1 + for el in size: + num_el *= el + else: + num_el = size + + # prepare the probabilities + if p is None: + p = torch.ones_like(a) / a.shape[0] + + # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973 + atol = sqrt(torch.finfo(p.dtype).eps) + if abs(p.sum() - 1.0) > atol: + raise ValueError("probabilities do not sum to 1.") + + # actually sample + indices = torch.multinomial(p, num_el, replacement=replace) + + if _util.is_sequence(size): + indices = indices.reshape(size) + + samples = a[indices] + + return samples diff --git a/MLPY/Lib/site-packages/torch/_numpy/testing/__init__.py b/MLPY/Lib/site-packages/torch/_numpy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02051d8eedc61437cf8bc0d2a85fd4f46ce1b692 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/testing/__init__.py @@ -0,0 +1,19 @@ +# mypy: ignore-errors + +from .utils import ( + _gen_alignment_data, + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, + assert_equal, + assert_raises_regex, + assert_warns, + HAS_REFCOUNT, + IS_WASM, + suppress_warnings, +) + +# from .testing import assert_allclose # FIXME diff --git a/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..272d490cf4914e39770e08ddc323ead636216d79 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7a137063da7030b007650f70a6eb3f82e199427 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_numpy/testing/utils.py b/MLPY/Lib/site-packages/torch/_numpy/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1b1b683f460865fd9086b4190cb6a96d44298f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_numpy/testing/utils.py @@ -0,0 +1,2390 @@ +# mypy: ignore-errors + +""" +Utility function to facilitate testing. 
+ +""" +import contextlib +import gc +import operator +import os +import platform +import pprint +import re +import shutil +import sys +import warnings +from functools import wraps +from io import StringIO +from tempfile import mkdtemp, mkstemp +from warnings import WarningMessage + +import torch._numpy as np +from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray + +__all__ = [ + "assert_equal", + "assert_almost_equal", + "assert_approx_equal", + "assert_array_equal", + "assert_array_less", + "assert_string_equal", + "assert_", + "assert_array_almost_equal", + "build_err_msg", + "decorate_methods", + "print_assert_equal", + "verbose", + "assert_", + "assert_array_almost_equal_nulp", + "assert_raises_regex", + "assert_array_max_ulp", + "assert_warns", + "assert_no_warnings", + "assert_allclose", + "IgnoreException", + "clear_and_catch_warnings", + "temppath", + "tempdir", + "IS_PYPY", + "HAS_REFCOUNT", + "IS_WASM", + "suppress_warnings", + "assert_array_compare", + "assert_no_gc_cycles", + "break_cycles", + "IS_PYSTON", +] + + +verbose = 0 + +IS_WASM = platform.machine() in ["wasm32", "wasm64"] +IS_PYPY = sys.implementation.name == "pypy" +IS_PYSTON = hasattr(sys, "pyston_version_info") +HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON + + +def assert_(val, msg=""): + """ + Assert that works in release mode. + Accepts callable msg to allow deferring evaluation until failure. + + The Python built-in ``assert`` does not work when executing code in + optimized mode (the ``-O`` flag) - no byte-code is generated for it. + + For documentation on usage, refer to the Python documentation. + + """ + __tracebackhide__ = True # Hide traceback for py.test + if not val: + try: + smsg = msg() + except TypeError: + smsg = msg + raise AssertionError(smsg) + + +def gisnan(x): + return np.isnan(x) + + +def gisfinite(x): + return np.isfinite(x) + + +def gisinf(x): + return np.isinf(x) + + +def build_err_msg( + arrays, + err_msg, + header="Items are not equal:", + verbose=True, + names=("ACTUAL", "DESIRED"), + precision=8, +): + msg = ["\n" + header] + if err_msg: + if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): + msg = [msg[0] + " " + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + # r_func = partial(array_repr, precision=precision) + r_func = ndarray.__repr__ + else: + r_func = repr + + try: + r = r_func(a) + except Exception as exc: + r = f"[repr failed for <{type(a).__name__}>: {exc}]" + if r.count("\n") > 3: + r = "\n".join(r.splitlines()[:3]) + r += "..." + msg.append(f" {names[i]}: {r}") + return "\n".join(msg) + + +def assert_equal(actual, desired, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. 
+ + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + + Examples + -------- + >>> np.testing.assert_equal([4,5], [4,6]) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal: + item=1 + ACTUAL: 5 + DESIRED: 6 + + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. + + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + num_nones = sum([actual is None, desired is None]) + if num_nones == 1: + raise AssertionError(f"Not equal: {actual} != {desired}") + elif num_nones == 2: + return True + # else, carry on + + if isinstance(actual, np.DType) or isinstance(desired, np.DType): + result = actual == desired + if not result: + raise AssertionError(f"Not equal: {actual} != {desired}") + else: + return True + + if isinstance(desired, str) and isinstance(actual, str): + assert actual == desired + return + + if isinstance(desired, dict): + if not isinstance(actual, dict): + raise AssertionError(repr(type(actual))) + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in desired.keys(): + if k not in actual: + raise AssertionError(repr(k)) + assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) + return + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) + return + + from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit + + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg, verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except (ValueError, TypeError): + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError(msg) # noqa: TRY200 + + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + # handle signed zero specially for floats + array_actual = np.asarray(actual) + array_desired = np.asarray(desired) + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + + try: + # Explicitly use __eq__ for comparison, gh-2552 + if not (desired == 
actual): + raise AssertionError(msg) + + except (DeprecationWarning, FutureWarning) as e: + # this handles the case when the two types are not even comparable + if "elementwise == comparison" in e.args[0]: + raise AssertionError(msg) # noqa: TRY200 + else: + raise + + +def print_assert_equal(test_string, actual, desired): + """ + Test if two objects are equal, and print an error message if test fails. + + The test is performed with ``actual == desired``. + + Parameters + ---------- + test_string : str + The message supplied to AssertionError. + actual : object + The object to test for equality against `desired`. + desired : object + The expected result. + + Examples + -------- + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Test XYZ of func xyz failed + ACTUAL: + [0, 1] + DESIRED: + [0, 2] + + """ + __tracebackhide__ = True # Hide traceback for py.test + import pprint + + if not (actual == desired): + msg = StringIO() + msg.write(test_string) + msg.write(" failed\nACTUAL: \n") + pprint.pprint(actual, msg) + msg.write("DESIRED: \n") + pprint.pprint(desired, msg) + raise AssertionError(msg.getvalue()) + + +def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies that the elements of `actual` and `desired` satisfy. + + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + decimal : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> from torch._numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 + + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) + Traceback (most recent call last): + ... 
+ AssertionError: + Arrays are not almost equal to 9 decimals + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 6.666699636781459e-09 + Max relative difference: 2.8571569790287484e-09 + x: torch.ndarray([1.0000, 2.3333], dtype=float64) + y: torch.ndarray([1.0000, 2.3333], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import imag, iscomplexobj, ndarray, real + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + def _build_err_msg(): + header = "Arrays are not almost equal to %d decimals" % decimal + return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_almost_equal(actualr, desiredr, decimal=decimal) + assert_almost_equal(actuali, desiredi, decimal=decimal) + except AssertionError: + raise AssertionError(_build_err_msg()) # noqa: TRY200 + + if isinstance(actual, (ndarray, tuple, list)) or isinstance( + desired, (ndarray, tuple, list) + ): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(_build_err_msg()) + else: + if not desired == actual: + raise AssertionError(_build_err_msg()) + return + except (NotImplementedError, TypeError): + pass + if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): + raise AssertionError(_build_err_msg()) + + +def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to significant + digits. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + Given two numbers, check that they are approximately equal. + Approximately equal is defined as the number of significant digits + that agree. + + Parameters + ---------- + actual : scalar + The object to check. + desired : scalar + The expected object. + significant : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP + ... 
significant=8) + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP + ... significant=8) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal to 8 significant digits: + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 + + the evaluated condition that raises the exception is + + >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) + True + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + (actual, desired) = map(float, (actual, desired)) + if desired == actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + scale = 0.5 * (np.abs(desired) + np.abs(actual)) + scale = np.power(10, np.floor(np.log10(scale))) + try: + sc_desired = desired / scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual / scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg( + [actual, desired], + err_msg, + header="Items are not equal to %d significant digits:" % significant, + verbose=verbose, + ) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except (TypeError, NotImplementedError): + pass + if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): + raise AssertionError(msg) + + +def assert_array_compare( + comparison, + x, + y, + err_msg="", + verbose=True, + header="", + precision=6, + equal_nan=True, + equal_inf=True, + *, + strict=False, +): + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import all, array, asarray, bool_, inf, isnan, max + + x = asarray(x) + y = asarray(y) + + def array2string(a): + return str(a) + + # original array for output formatting + ox, oy = x, y + + def func_assert_same_pos(x, y, func=isnan, hasval="nan"): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + x_id = func(x) + y_id = func(y) + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implementations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if (x_id == y_id).all().item() is not True: + msg = build_err_msg( + [x, y], + err_msg + "\nx and y %s location mismatch:" % (hasval), + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. 
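+        # The caller uses the returned value both as a boolean mask
+        # (``x[~flagged]`` when it has ndim > 0) and as a plain truth value
+        # (``elif flagged:``), so 0-d results are collapsed to a scalar bool_.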
+ if isinstance(x_id, bool) or x_id.ndim == 0: + return bool_(x_id) + elif isinstance(y_id, bool) or y_id.ndim == 0: + return bool_(y_id) + else: + return y_id + + try: + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if not cond: + if x.shape != y.shape: + reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" + else: + reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" + msg = build_err_msg( + [x, y], + err_msg + reason, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + + flagged = bool_(False) + + if equal_nan: + flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") + + if equal_inf: + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == +inf, hasval="+inf" + ) + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == -inf, hasval="-inf" + ) + + if flagged.ndim > 0: + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return + + val = comparison(x, y) + + if isinstance(val, bool): + cond = val + reduced = array([val]) + else: + reduced = val.ravel() + cond = reduced.all() + + # The below comparison is a hack to ensure that fully masked + # results, for which val.ravel().all() returns np.ma.masked, + # do not trigger a failure (np.ma.masked != True evaluates as + # np.ma.masked, which is falsy). + if not cond: + n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements + remarks = [ + f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" + ] + + # with errstate(all='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError, RuntimeError): + error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) + max_abs_error = max(error) + remarks.append( + "Max absolute difference: " + array2string(max_abs_error.item()) + ) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + remarks.append( + "Max relative difference: " + array2string(max_rel_error.item()) + ) + + err_msg += "\n" + "\n".join(remarks) + msg = build_err_msg( + [ox, oy], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + except ValueError: + import traceback + + efmt = traceback.format_exc() + header = f"error during assertion:\n\n{efmt}\n\n{header}" + + msg = build_err_msg( + [x, y], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise ValueError(msg) # noqa: TRY200 + + +def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): + """ + Raises an AssertionError if two array_like objects are not equal. + + Given two array_like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. 
In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + The usual caution for verifying equality with floating point numbers is + advised. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. This behaviour can be disabled with the `strict` parameter. + + Examples + -------- + The first assert does not raise an exception: + + >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], + ... [np.exp(0),2.33333, np.nan]) + + Use `assert_allclose` or one of the nulp (number of floating point values) + functions for these cases instead: + + >>> np.testing.assert_allclose([1.0,np.pi,np.nan], + ... [1, np.sqrt(np.pi)**2, np.nan], + ... rtol=1e-10, atol=0) + + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: torch.ndarray([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: torch.ndarray(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes dtype("int64"), dtype("float32") mismatch) + x: torch.ndarray([2, 2, 2]) + y: torch.ndarray([2., 2., 2.]) + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__eq__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not equal", + strict=strict, + ) + + +def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. + + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. 
An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + decimal : int, optional + Desired precision, default is 6. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + the first assert does not raise an exception + + >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], + ... [1.0,2.333,np.nan]) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33339,np.nan], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 5.999999999994898e-05 + Max relative difference: 2.5713661239633743e-05 + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33333, 5], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + x and y nan location mismatch: + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import any as npany, float_, issubdtype, number, result_type + + def compare(x, y): + try: + if npany(gisinf(x)) or npany(gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not (xinfid == yinfid).all(): + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except (TypeError, NotImplementedError): + pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.0) + y = asanyarray(y, dtype) + z = abs(x - y) + + if not issubdtype(z.dtype, number): + z = z.astype(float_) # handle object arrays + + return z < 1.5 * 10.0 ** (-decimal) + + assert_array_compare( + compare, + x, + y, + err_msg=err_msg, + verbose=verbose, + header=("Arrays are not almost equal to %d decimals" % decimal), + precision=decimal, + ) + + +def assert_array_less(x, y, err_msg="", verbose=True): + """ + Raises an AssertionError if two array_like objects are not ordered by less + than. + + Given two array_like objects, check that the shape is equal and all + elements of the first object are strictly smaller than those of the + second object. An exception is raised at shape mismatch or incorrectly + ordered values. Shape mismatch does not raise if an object has zero + dimension. In contrast to the standard usage in numpy, NaNs are + compared, no assertion is raised if both objects have NaNs in the same + positions. + + + + Parameters + ---------- + x : array_like + The smaller object to check. + y : array_like + The larger object to compare. 
+ err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_array_equal: tests objects for equality + assert_array_almost_equal: test objects for equality up to precision + + + + Examples + -------- + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 1.0 + Max relative difference: 0.5 + x: torch.ndarray([1., 1., nan], dtype=float64) + y: torch.ndarray([1., 2., nan], dtype=float64) + + >>> np.testing.assert_array_less([1.0, 4.0], 3) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 2.0 + Max relative difference: 0.6666666666666666 + x: torch.ndarray([1., 4.], dtype=float64) + y: torch.ndarray(3) + + >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + (shapes (3,), (1,) mismatch) + x: torch.ndarray([1., 2., 3.], dtype=float64) + y: torch.ndarray([4]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__lt__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not less-ordered", + equal_inf=False, + ) + + +def assert_string_equal(actual, desired): + """ + Test if two strings are equal. + + If the given strings are equal, `assert_string_equal` does nothing. + If they are not equal, an AssertionError is raised, and the diff + between the strings is shown. + + Parameters + ---------- + actual : str + The string to test for equality against the expected string. + desired : str + The expected string. + + Examples + -------- + >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP + >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP + Traceback (most recent call last): + File "", line 1, in + ... + AssertionError: Differences in strings: + - abc+ abcd? + + + """ + # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test + import difflib + + if not isinstance(actual, str): + raise AssertionError(repr(type(actual))) + if not isinstance(desired, str): + raise AssertionError(repr(type(desired))) + if desired == actual: + return + + diff = list( + difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) + ) + diff_list = [] + while diff: + d1 = diff.pop(0) + if d1.startswith(" "): + continue + if d1.startswith("- "): + l = [d1] + d2 = diff.pop(0) + if d2.startswith("? "): + l.append(d2) + d2 = diff.pop(0) + if not d2.startswith("+ "): + raise AssertionError(repr(d2)) + l.append(d2) + if diff: + d3 = diff.pop(0) + if d3.startswith("? 
"): + l.append(d3) + else: + diff.insert(0, d3) + if d2[2:] == d1[2:]: + continue + diff_list.extend(l) + continue + raise AssertionError(repr(d1)) + if not diff_list: + return + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" + if actual != desired: + raise AssertionError(msg) + + +import unittest + + +class _Dummy(unittest.TestCase): + def nop(self): + pass + + +_d = _Dummy("nop") + + +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): + """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Alternatively, can be used as a context manager like `assert_raises`. + + Notes + ----- + .. versionadded:: 1.9.0 + + """ + __tracebackhide__ = True # Hide traceback for py.test + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) + + +def decorate_methods(cls, decorator, testmatch=None): + """ + Apply a decorator to all methods in a class matching a regular expression. + + The given decorator is applied to all public methods of `cls` that are + matched by the regular expression `testmatch` + (``testmatch.search(methodname)``). Methods that are private, i.e. start + with an underscore, are ignored. + + Parameters + ---------- + cls : class + Class whose methods to decorate. + decorator : function + Decorator to apply to methods + testmatch : compiled regexp or str, optional + The regular expression. Default value is None, in which case the + nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) + is used. + If `testmatch` is a string, it is compiled to a regular expression + first. + + """ + if testmatch is None: + testmatch = re.compile(r"(?:^|[\\b_\\.%s-])[Tt]est" % os.sep) + else: + testmatch = re.compile(testmatch) + cls_attr = cls.__dict__ + + # delayed import to reduce startup time + from inspect import isfunction + + methods = [_m for _m in cls_attr.values() if isfunction(_m)] + for function in methods: + try: + if hasattr(function, "compat_func_name"): + funcname = function.compat_func_name + else: + funcname = function.__name__ + except AttributeError: + # not a function + continue + if testmatch.search(funcname) and not funcname.startswith("_"): + setattr(cls, funcname, decorator(function)) + return + + +def _assert_valid_refcount(op): + """ + Check that ufuncs don't mishandle refcount of object `1`. + Used in a few regression tests. + """ + if not HAS_REFCOUNT: + return True + + import gc + + import numpy as np + + b = np.arange(100 * 100).reshape(100, 100) + c = b + i = 1 + + gc.disable() + try: + rc = sys.getrefcount(i) + for j in range(15): + d = op(b, c) + assert_(sys.getrefcount(i) >= rc) + finally: + gc.enable() + del d # for pyflakes + + +def assert_allclose( + actual, + desired, + rtol=1e-7, + atol=0, + equal_nan=True, + err_msg="", + verbose=True, + check_dtype=False, +): + """ + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Given two array_like objects, check that their shapes and all elements + are equal (but see the Notes for the special handling of a scalar). An + exception is raised if the shapes mismatch or any values conflict. In + contrast to the standard usage in numpy, NaNs are compared like numbers, + no assertion is raised if both objects have NaNs in the same positions. 
+ + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note + that ``allclose`` has different default values). It compares the difference + between `actual` and `desired` to ``atol + rtol * abs(desired)``. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_array_almost_equal_nulp, assert_array_max_ulp + + Notes + ----- + When one of `actual` and `desired` is a scalar and the other is + array_like, the function checks that each element of the array_like + object is equal to the scalar. + + Examples + -------- + >>> x = [1e-5, 1e-3, 1e-1] + >>> y = np.arccos(np.cos(x)) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + def compare(x, y): + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + actual, desired = asanyarray(actual), asanyarray(desired) + header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" + + if check_dtype: + assert actual.dtype == desired.dtype + + assert_array_compare( + compare, + actual, + desired, + err_msg=str(err_msg), + verbose=verbose, + header=header, + equal_nan=equal_nan, + ) + + +def assert_array_almost_equal_nulp(x, y, nulp=1): + """ + Compare two arrays relatively to their spacing. + + This is a relatively robust method to compare two arrays whose amplitude + is variable. + + Parameters + ---------- + x, y : array_like + Input arrays. + nulp : int, optional + The maximum number of unit in the last place for tolerance (see Notes). + Default is 1. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the spacing between `x` and `y` for one or more elements is larger + than `nulp`. + + See Also + -------- + assert_array_max_ulp : Check that all items of arrays differ in at most + N Units in the Last Place. + spacing : Return the distance between x and the nearest adjacent number. + + Notes + ----- + An assertion is raised if the following condition is not met:: + + abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) + + Examples + -------- + >>> x = np.array([1., 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: X and Y are not equal to 1 ULP (max is 2) + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ax = np.abs(x) + ay = np.abs(y) + ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) + if not np.all(np.abs(x - y) <= ref): + if np.iscomplexobj(x) or np.iscomplexobj(y): + msg = "X and Y are not equal to %d ULP" % nulp + else: + max_nulp = np.max(nulp_diff(x, y)) + msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp) + raise AssertionError(msg) + + +def assert_array_max_ulp(a, b, maxulp=1, dtype=None): + """ + Check that all items of arrays differ in at most N Units in the Last Place. 
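+
+    Here one ULP (unit in the last place) is the spacing between a floating
+    point value and the next representable value of the same dtype.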
+ + Parameters + ---------- + a, b : array_like + Input arrays to be compared. + maxulp : int, optional + The maximum number of units in the last place that elements of `a` and + `b` can differ. Default is 1. + dtype : dtype, optional + Data-type to convert `a` and `b` to if given. Default is None. + + Returns + ------- + ret : ndarray + Array containing number of representable floating point numbers between + items in `a` and `b`. + + Raises + ------ + AssertionError + If one or more elements differ by more than `maxulp`. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + See Also + -------- + assert_array_almost_equal_nulp : Compare two arrays relatively to their + spacing. + + Examples + -------- + >>> a = np.linspace(0., 1., 100) + >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ret = nulp_diff(a, b, dtype) + if not np.all(ret <= maxulp): + raise AssertionError( + f"Arrays are not almost equal up to {maxulp:g} " + f"ULP (max difference is {np.max(ret):g} ULP)" + ) + return ret + + +def nulp_diff(x, y, dtype=None): + """For each item in x and y, return the number of representable floating + points between them. + + Parameters + ---------- + x : array_like + first input array + y : array_like + second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. + + Returns + ------- + nulp : array_like + number of representable floating point numbers between each item in x + and y. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). 
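+
+    The count is obtained by reinterpreting the binary representation of each
+    float as a signed-magnitude integer (see `integer_repr`) and taking the
+    absolute difference of those integers.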
+ + Examples + -------- + # By definition, epsilon is the smallest number such as 1 + eps != 1, so + # there should be exactly one ULP between 1 and 1 + eps + >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP + 1.0 + """ + import numpy as np + + if dtype: + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) + else: + x = np.asarray(x) + y = np.asarray(y) + + t = np.common_type(x, y) + if np.iscomplexobj(x) or np.iscomplexobj(y): + raise NotImplementedError("_nulp not implemented for complex array") + + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan + + if not x.shape == y.shape: + raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") + + def _diff(rx, ry, vdt): + diff = np.asarray(rx - ry, dtype=vdt) + return np.abs(diff) + + rx = integer_repr(x) + ry = integer_repr(y) + return _diff(rx, ry, t) + + +def _integer_repr(x, vdt, comp): + # Reinterpret binary representation of the float as sign-magnitude: + # take into account two-complement representation + # See also + # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + rx = x.view(vdt) + if not (rx.size == 1): + rx[rx < 0] = comp - rx[rx < 0] + else: + if rx < 0: + rx = comp - rx + + return rx + + +def integer_repr(x): + """Return the signed-magnitude interpretation of the binary representation + of x.""" + import numpy as np + + if x.dtype == np.float16: + return _integer_repr(x, np.int16, np.int16(-(2**15))) + elif x.dtype == np.float32: + return _integer_repr(x, np.int32, np.int32(-(2**31))) + elif x.dtype == np.float64: + return _integer_repr(x, np.int64, np.int64(-(2**63))) + else: + raise ValueError(f"Unsupported dtype {x.dtype}") + + +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): + """ + Fail unless the given callable throws the specified warning. + + A warning of class warning_class should be thrown by the callable when + invoked with arguments args and keyword arguments kwargs. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + warning_class : class + The class defining the warning that `func` is expected to throw. + func : callable, optional + Callable to test + *args : Arguments + Arguments for `func`. + **kwargs : Kwargs + Keyword arguments for `func`. + + Returns + ------- + The value returned by `func`. + + Examples + -------- + >>> import warnings + >>> def deprecated_func(num): + ... warnings.warn("Please upgrade", DeprecationWarning) + ... return num*num + >>> with np.testing.assert_warns(DeprecationWarning): + ... 
assert deprecated_func(4) == 16 + >>> # or passing a func + >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) + >>> assert ret == 16 + """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter("always") + yield + if len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError(f"Got warnings{name_str}: {l}") + + +def assert_no_warnings(*args, **kwargs): + """ + Fail if the given callable produces any warnings. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + The value returned by `func`. + + """ + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) + + +def _gen_alignment_data(dtype=float32, type="binary", max_size=24): + """ + generator producing data with different alignment and offsets + to test simd vectorization + + Parameters + ---------- + dtype : dtype + data type to produce + type : string + 'unary': create data for unary operations, creates one input + and output array + 'binary': create data for unary operations, creates two input + and output array + max_size : integer + maximum size of data to produce + + Returns + ------- + if type is 'unary' yields one output, one input array and a message + containing information on the data + if type is 'binary' yields one output array, two input array and a message + containing information on the data + + """ + ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" + bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" + for o in range(3): + for s in range(o + 2, max(o + 3, max_size)): + if type == "unary": + + def inp(): + return arange(s, dtype=dtype)[o:] + + out = empty((s,), dtype=dtype)[o:] + yield out, inp(), ufmt % (o, o, s, dtype, "out of place") + d = inp() + yield d, d, ufmt % (o, o, s, dtype, "in place") + yield out[1:], inp()[:-1], ufmt % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp()[1:], ufmt % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") + yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") + if type == "binary": + + def inp1(): + return arange(s, dtype=dtype)[o:] + + inp2 = inp1 + out = empty((s,), dtype=dtype)[o:] + yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") + d = inp1() + yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") + d = inp2() + yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") + yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[:-1], 
inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ) + + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP + ... modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') + ... # do something that raises a warning but ignore those in + ... 
# np.core.fromnumeric + """ + + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super().__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super().__enter__() + + def __exit__(self, *exc_info): + super().__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings: + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + https://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. 
All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass + """ + + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings("always", category=category, message=message) + else: + module_regex = module.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=category, message=message, module=module_regex + ) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.I), module, record) + ) + else: + self._suppressions.append( + (category, message, re.compile(message, re.I), module, record) + ) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. 
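+
+        A minimal usage sketch (the warning-producing code is only indicated
+        by ``...`` and is assumed to emit one matching warning)::
+
+            with suppress_warnings() as sup:
+                log = sup.record(UserWarning, "some text")
+                ...  # code expected to emit the matching UserWarning once
+            assert len(log) == 1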
+ """ + return self._filter( + category=category, message=message, module=module, record=True + ) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings("always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=cat, message=mess, module=module_regex + ) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarning( + self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs + ): + for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ + ::-1 + ]: + if issubclass(category, cat) and pattern.match(message.args[0]) is not None: + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + +@contextlib.contextmanager +def _assert_no_gc_cycles_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + + # not meaningful to test if there is no refcounting + if not HAS_REFCOUNT: + yield + return + + assert_(gc.isenabled()) + gc.disable() + gc_debug = gc.get_debug() + try: + for i in range(100): + if gc.collect() == 0: + break + else: + raise RuntimeError( + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?" 
+ ) + + gc.set_debug(gc.DEBUG_SAVEALL) + yield + # gc.collect returns the number of unreachable objects in cycles that + # were found -- we are checking that no cycles were created in the context + n_objects_in_cycles = gc.collect() + objects_in_cycles = gc.garbage[:] + finally: + del gc.garbage[:] + gc.set_debug(gc_debug) + gc.enable() + + if n_objects_in_cycles: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError( + "Reference cycles were found{}: {} objects were collected, " + "of which {} are shown below:{}".format( + name_str, + n_objects_in_cycles, + len(objects_in_cycles), + "".join( + "\n {} object with id={}:\n {}".format( + type(o).__name__, + id(o), + pprint.pformat(o).replace("\n", "\n "), + ) + for o in objects_in_cycles + ), + ) + ) + + +def assert_no_gc_cycles(*args, **kwargs): + """ + Fail if the given callable produces any reference cycles. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_gc_cycles(): + do_something() + + .. versionadded:: 1.15.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + Nothing. The result is deliberately discarded to ensure that all cycles + are found. + + """ + if not args: + return _assert_no_gc_cycles_context() + + func = args[0] + args = args[1:] + with _assert_no_gc_cycles_context(name=func.__name__): + func(*args, **kwargs) + + +def break_cycles(): + """ + Break reference cycles by calling gc.collect + Objects can call other objects' methods (for instance, another object's + __del__) inside their own __del__. On PyPy, the interpreter only runs + between calls to gc.collect, so multiple calls are needed to completely + release all cycles. + """ + + gc.collect() + if IS_PYPY: + # a few more, just to make sure all the finalizers are called + gc.collect() + gc.collect() + gc.collect() + gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = "NPY_AVAILABLE_MEM" + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError( # noqa: TRY200 + f"Invalid environment variable {env_var}: {exc}" + ) + + msg = ( + f"{free_bytes/1e9} GB memory required, but environment variable " + f"NPY_AVAILABLE_MEM={env_value} set" + ) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ( + "Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test." + ) + mem_free = -1 + else: + msg = ( + f"{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available" + ) + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) 
to float""" + suffixes = { + "": 1, + "b": 1, + "k": 1000, + "m": 1000**2, + "g": 1000**3, + "t": 1000**4, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "tb": 1000**4, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + } + + size_re = re.compile( + r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), re.I + ) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError(f"value {size_str!r} not a valid size") + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith("linux"): + info = {} + with open("/proc/meminfo") as f: + for line in f: + p = line.split() + info[p[0].strip(":").lower()] = int(p[1]) * 1024 + + if "memavailable" in info: + # Linux >= 3.14 + return info["memavailable"] + else: + return info["memfree"] + info["cached"] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, "gettrace"): + return func + else: + + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + + return wrapper + + +def _get_glibc_version(): + try: + ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] + except Exception as inst: + ver = "0.0" + + return ver + + +_glibcver = _get_glibc_version() + + +def _glibc_older_than(x): + return _glibcver != "0.0" and _glibcver < x diff --git a/MLPY/Lib/site-packages/torch/_ops.py b/MLPY/Lib/site-packages/torch/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..95774269e3ccb86ba4decffa94e960189f9f2aff --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_ops.py @@ -0,0 +1,1037 @@ +import contextlib +import ctypes +import importlib +import inspect +import sys +import types +from typing import Any, Callable, Dict, Set, Type, Union + +import torch._C +import torch.utils._pytree as pytree +from torch import _utils_internal +from torch._functorch.pyfunctorch import dispatch_functorch +from torch.utils._python_dispatch import TorchDispatchMode + +# Query `hasattr` only once. + +_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") + + +@contextlib.contextmanager +def dl_open_guard(): + """ + Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a + shared library to load custom operators. + """ + if not _SET_GLOBAL_FLAGS: + yield + return + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + try: + yield + finally: + sys.setdlopenflags(old_flags) + + +class OperatorBase: + """ + Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator + (which represents Python-only operators that are unrepresentable in TorchScript). + """ + + def __init__(self): + # The dispatch cache precomputes a mapping of dispatch key that the + # dispatcher wants to dispatch to, to an actual implementation of the + # dispatch key. Confusingly, the actual implementation could *also* be a + # dispatch key, but in this case, this refers to the C++ kernel that + # was registered to some dispatch key. 
Aliases are permitted in the + # latter but not the former; for example, you might lookup the + # entry for AutogradCPU, and this maps you to the Autograd key for + # the generic autograd kernel that works for all devices. Since this + # is the Python dispatcher, you can also put an arbitrary Python + # callable to call instead. This handler gets precisely the + # args/kwargs that the operator was __call__'ed with. + # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp + # for use with OpOverload; cache lookup is done entirely from C++ + # for speed. + # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! + self._dispatch_cache: Dict[ + torch._C.DispatchKey, Union[torch._C.DispatchKey, Callable[..., Any]] + ] = {} + + # This table allows you to override the behavior of a particular + # dispatch key to call a custom Python function, rather than the + # ordinary C++ configured behavior. This is the raison d'etre of + # Python dispatcher: to let you program the dispatcher from Python + # in case you need something unusual, and don't want to clobber + # the existing registrations using the Python operator registration + # API. + self.py_kernels: Dict[torch._C.DispatchKey, Callable[..., Any]] = {} + + # This table allows you to override the behavior of a particular + # operator for a particular TorchDispatchMode. In practice, + # we are using this mostly for ProxyTensorMode. Modes can be + # thought of as an open world extension of dispatch keys, so it + # makes sense that you should be able to register them, the same + # way you can register dispatch keys. + self.python_key_mode_table: Dict[ + Type[TorchDispatchMode], Callable[..., Any] + ] = {} + + # This table allows you to override the behavior of functorch + # transformations. NB: this currently only does something for + # HigherOrderOperator + self.functorch_table = {} + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def has_kernel_for_dispatch_key(self, k): + return k in self.py_kernels + + def has_kernel_for_any_dispatch_key(self, ks): + for k in self.py_kernels: + if not torch._C._dispatch_is_alias_key(k) and ks.has(k): + return True + return False + + def py_impl(self, k): + def inner(fn): + if inspect.isclass(k) and issubclass(k, TorchDispatchMode): + assert k not in self.python_key_mode_table + # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys? + self.python_key_mode_table[k] = fn + self._dispatch_cache.clear() + return fn + + if isinstance(k, torch._C._functorch.TransformType): + assert k not in self.functorch_table + self.functorch_table[k] = fn + return fn + + assert isinstance(k, torch._C.DispatchKey) + assert ( + k != torch._C.DispatchKey.Python + ), "Please register a mode for the torch._C.DispatchKey.Python key instead." 
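The decorator returned by py_impl is how Python kernels get attached to an operator. A minimal sketch, assuming an importable torch build and a made-up HigherOrderOperator name (the class is defined further down in this file); the registration simply lands in py_kernels and clears the dispatch cache:

import torch
from torch._ops import HigherOrderOperator

my_hop = HigherOrderOperator("my_hop_example")  # hypothetical operator name

@my_hop.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
def my_hop_dense(fn, x):
    # Dense-tensor kernel for the hypothetical op: just call the wrapped function.
    return fn(x)

assert torch._C.DispatchKey.CompositeExplicitAutograd in my_hop.py_kernels

A complete operator would also register kernels (or explicit fallthroughs) for the other keys dispatch can land on, such as Autograd and the tracing modes.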
+ + if k in self.py_kernels: + raise RuntimeError( + f"Trying to override a python impl for {k} on operator {self.name()}" + ) + self.py_kernels[k] = fn + self._dispatch_cache.clear() + return fn + + return inner + + # Registers an implementation to all **3** variants of functionalization that we have: + # - DispatchKey.Functionalize + # - functorch.TransformType.Functionalize + # - FunctionalTensorMode + # Example: + # @py_functionalize_impl + # def functionalize_rule(ctx, inner_f, *args): + # args_unwrapped = ctx.unwrap_tensors(args) + # with ctx.redispatch_to_next(): + # out = ctx.functionalize(inner_f)(*args_unwrapped) + # return ctx.wrap_tensors(out) + def py_functionalize_impl(self, fn): + from torch._subclasses.functional_tensor import ( + CppFunctionalizeAPI as _CppFunctionalizeAPI, + FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI, + PythonFunctionalizeAPI as _PythonFunctionalizeAPI, + ) + + # Construct our three flavors of functionalization, + # each of which have slightly different wrap/unwrap/redispatch policies + def functionalize_dk_fn(*args, **kwargs): + return fn(_CppFunctionalizeAPI(), *args, **kwargs) + + def functionalize_dispatch_mode_fn(mode, *args, **kwargs): + return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs) + + def functionalize_functorch_fn(interpreter, *args, **kwargs): + return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) + + self.py_impl(torch._C.DispatchKey.Functionalize)(functionalize_dk_fn) + self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)( + functionalize_dispatch_mode_fn + ) + self.py_impl(torch._C._functorch.TransformType.Functionalize)( + functionalize_functorch_fn + ) + + return fn + + def name(self): + raise NotImplementedError() + + +is_included_in_alias = torch._C._dispatch_is_included_in_alias + +DispatchKey = torch._C.DispatchKey + + +# Equivalent to computeDispatchTableEntryWithDebug +def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] + # 1. (Direct) operator registration + if op.has_kernel_for_dispatch_key(k): + return k + # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available + cand = DispatchKey.CompositeExplicitAutogradNonFunctional + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + return cand + # 2.2 Use CompositeExplicitAutograd kernel if available + cand = DispatchKey.CompositeExplicitAutograd + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + return cand + has_backend_kernel = op.has_kernel_for_any_dispatch_key( + torch._C._dispatch_get_backend_keyset_from_autograd(k) + ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd) + # 2.3. Use CompositeImplicitAutograd kernel if available + cand = DispatchKey.CompositeImplicitAutogradNestedTensor + if ( + (k != DispatchKey.Undefined and is_included_in_alias(k, cand)) + and op.has_kernel_for_dispatch_key(cand) + and not has_backend_kernel + ): + return cand + cand = DispatchKey.CompositeImplicitAutograd + if ( + k == DispatchKey.Undefined or is_included_in_alias(k, cand) + ) and op.has_kernel_for_dispatch_key(cand): + if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key( + torch._C._dispatch_autogradother_backends + ): + raise RuntimeError("ambiguous autogradother kernel") + elif not has_backend_kernel: + return cand + # 2.4. 
For autograd backend keys, use kernel from DispatchKey::Autograd if available + cand = DispatchKey.Autograd + if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): + return cand + # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available + cand = DispatchKey.FuncTorchBatchedDecomposition + if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): + return cand + # Backend fallback + if torch._C._dispatch_has_backend_fallback(k): + # The dispatch key itself will implicitly route to backend fallback. + # This is probably not great for the pure Python implementation. + return k + raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") + + +_higher_order_ops: Dict[str, "HigherOrderOperator"] = {} + +_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [ + DispatchKey.PythonDispatcher, # type: ignore[attr-defined] + DispatchKey.PythonTLSSnapshot, # type: ignore[attr-defined] + DispatchKey.ADInplaceOrView, + DispatchKey.BackendSelect, + DispatchKey.AutocastCPU, # type: ignore[attr-defined] + DispatchKey.AutocastCUDA, # type: ignore[attr-defined] +] + + +class HigherOrderOperator(OperatorBase): + # The HigherOrderOperator will appear as torch.ops.higher_order.{name} + # + # If you're creating a new HigherOrderOperator, please do not change the + # default. Adding operators to the global torch.ops namespace is a bad + # practice due to name collisions. + def __init__(self, name): + super().__init__() + self._name = name + + # Make _OPNamespace not scream, this whole name based association needs a good hard look + self.__name__ = name + _higher_order_ops[name] = self + self._ns = "higher_order" + + # For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to + # torch._ops.higher_order. + # For an instance of subclass of HigherOrderOperator (e.g. customized higher order op), + # the __module__ attribute will be kept unchanged. + if self.__class__ is HigherOrderOperator: + self_name_space = "." + self.namespace if self.namespace else "" + self.__module__ = self.__module__ + self_name_space + self.non_fallthrough_keys = torch._C._dispatch_keyset_full() + + for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS: + self.fallthrough(dispatch_key) + + # [NOTE] We have to register pre-dispatch key implementation + # because sometimes HOP use aot-dispatch tracing to detect certaion + # mutations. This is problematic when we are functionalizing HOP + # during pre-dispatch because when the inner tracer starts, it will see + # that PreDispatch key is still active. In that case, we just redispatch + # it to next key. This is only safe to do when PreDispatch key stack has no + # active modes. + # TODO (tmanlaibaatar) Make it generic fallback mechanism + def _(*args, **kwargs): + if _len_torch_dispatch_stack_pre_dispatch() == 0: + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(DispatchKey.PreDispatch) + ): + return self(*args, **kwargs) + raise AssertionError( + """ + Can't directly invoke HOP implementation at PreDispatch key + if there are active modes on PreDispatch mode stack. 
+ """ + ) + + self.py_impl(torch._C.DispatchKey.PreDispatch)(_) + + def py_impl(self, k): + if isinstance(k, torch._C.DispatchKey) and not self.non_fallthrough_keys.has(k): + self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) + return super().py_impl(k) + + @property + def namespace(self): + return self._ns + + def fallthrough(self, dispatch_key): + self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key) + + def dispatch(self, dispatch_key, *args, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + if dispatch_key in self._dispatch_cache: + kernel = self._dispatch_cache[dispatch_key] + assert not isinstance(kernel, torch._C.DispatchKey) + return kernel(*args, **kwargs) + + if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode: + return dispatch_functorch(self, args, kwargs) + + if dispatch_key == torch._C.DispatchKey.Python: + # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc + from torch.utils._python_dispatch import _pop_mode_temporarily + + curr_mode = _get_current_dispatch_mode() + assert ( + curr_mode is not None + ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." + assert ( + type(curr_mode) in self.python_key_mode_table + ), f"Current active mode {curr_mode} not registered" + handler = self.python_key_mode_table[type(curr_mode)] + with _pop_mode_temporarily() as mode: + return handler(mode, *args, **kwargs) + + functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] + if functionality_key == torch._C.DispatchKey.PreDispatch: + from torch.utils._python_dispatch import _pop_mode_temporarily + + # The check for Python in the exclude set is so we properly respect `with no_dispatch()` + # calls inside of a mode. + if ( + _len_torch_dispatch_stack_pre_dispatch() > 0 + ) and not torch._C._dispatch_tls_is_dispatch_key_excluded( + DispatchKey.Python + ): + curr_mode = _get_current_dispatch_mode_pre_dispatch() + assert ( + curr_mode is not None + ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode." + assert ( + type(curr_mode) in self.python_key_mode_table + ), f"Current active mode {curr_mode} not registered" + handler = self.python_key_mode_table[type(curr_mode)] + with _pop_mode_temporarily(functionality_key) as mode: + return handler(mode, *args, **kwargs) + + final_key = resolve_key(self, dispatch_key) + + # This can current fail due to backend fallbacks. You just have to + # register them by hand for HigherOrderOperator. + if final_key not in self.py_kernels: + raise NotImplementedError( + f"could not find kernel for HigherOrderOperator {self._name} " + f"at dispatch key {final_key} (resolved from {dispatch_key})" + ) + self._dispatch_cache[dispatch_key] = self.py_kernels[final_key] + kernel = self.py_kernels[final_key] + # It's illegal to register DispatchKey to py_kernels, since there's no + # C++ kernel to call into + assert not isinstance(kernel, torch._C.DispatchKey) + return kernel(*args, **kwargs) + + def __call__(self, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. 
+ import torch._dynamo + from torch._dynamo import disable + + @disable + def wrapper(): + flat_args = _to_flat_tuple(args, kwargs) + if torch.overrides.has_torch_function(flat_args): + return torch.overrides.handle_torch_function( + self, flat_args, *args, **kwargs + ) + + dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) + return self.dispatch( + dispatch_key_set.highestPriorityTypeId(), *args, **kwargs + ) + + return wrapper() + + def __str__(self): + return f"{self.name()}" + + def name(self): + return self._name + + +def _to_flat_tuple(args, kwargs): + return pytree.arg_tree_leaves(*args, **kwargs) + + +def _compute_keyset(args, kwargs, non_fallthrough_keys): + tensors = _get_tensors(args, kwargs) + return key_extractor(tensors, non_fallthrough_keys) + + +def _get_tensors(args, kwargs): + flat_all = _to_flat_tuple(args, kwargs) + tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)] + return tuple(tensor_args) + + +# Note - this should maintain identical impl to the C++ dispatcher key extraction logic +# at ATen/core/dispatch/DispatchKeyExtractor.h +def key_extractor(tensors, key_mask): + key_set = torch._C._dispatch_tls_local_include_set() + for tensor in tensors: + key_set = key_set | torch._C._dispatch_keys(tensor) + key_set = key_set - torch._C._dispatch_tls_local_exclude_set() + key_set = key_set & key_mask + return key_set + + +# Mode stack for PreDispatchKey +# it should always have two keys with +# priority given to FunctionalTensorMode and +# then ProxyTorchDispatchMode. It means that +# slot 0 belongs to ProxyTorchDispatchMode and +# slot 1 belongs to FunctionalTensorMode. +class _ModeStackStateForPreDispatch: + def __init__(self): + self.__infra_modes = [None, None] + + def set(self, index, mode): + assert index < len(self.__infra_modes) + self.__infra_modes[index] = mode + + def get(self, index): + assert index < len(self.__infra_modes) + return self.__infra_modes[index] + + def count(self): + return len([i for i in self.__infra_modes if i is not None]) + + +_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch() + + +def unset_mode_pre_dispatch(mode_key): + current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch() + assert mode_key in ( + torch._C._TorchDispatchModeKey.PROXY, + torch._C._TorchDispatchModeKey.FUNCTIONAL, + ) + if mode_key == torch._C._TorchDispatchModeKey.PROXY: + current_mode = current_mode_stack_pre_dispatch.get(0) + mode_stack_state_for_pre_dispatch().set(0, None) + return current_mode + else: + current_mode = current_mode_stack_pre_dispatch.get(1) + mode_stack_state_for_pre_dispatch().set(1, None) + return current_mode + + +def _set_mode_pre_dispatch(mode): + from torch._subclasses.functional_tensor import FunctionalTensorMode + from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode + + assert isinstance(mode, (FunctionalTensorMode, ProxyTorchDispatchMode)) + if isinstance(mode, FunctionalTensorMode): + current_mode = mode_stack_state_for_pre_dispatch().get(1) + assert current_mode is None + mode_stack_state_for_pre_dispatch().set(1, mode) + return + + current_mode = mode_stack_state_for_pre_dispatch().get(0) + assert current_mode is None + mode_stack_state_for_pre_dispatch().set(0, mode) + + +def _pop_mode_from_pre_dispatch(): + mode_stack = mode_stack_state_for_pre_dispatch() + if mode_stack.get(1) is not None: + res = mode_stack.get(1) + mode_stack.set(1, None) + return res + + if mode_stack.get(0) is not None: + res = mode_stack.get(0) + mode_stack.set(0, None) + return 
res + + raise AssertionError("Trying to pop empty mode stack") + + +def _len_torch_dispatch_stack_pre_dispatch(): + return mode_stack_state_for_pre_dispatch().count() + + +def _get_dispatch_mode_pre_dispatch(mode_key): + assert mode_key in ( + torch._C._TorchDispatchModeKey.PROXY, + torch._C._TorchDispatchModeKey.FUNCTIONAL, + ) + if mode_key == torch._C._TorchDispatchModeKey.PROXY: + return mode_stack_state_for_pre_dispatch().get(0) + return mode_stack_state_for_pre_dispatch().get(1) + + +def _get_current_dispatch_mode_pre_dispatch(): + stack_len = mode_stack_state_for_pre_dispatch().count() + if stack_len == 2: + return mode_stack_state_for_pre_dispatch().get(1) + if stack_len == 1: + return ( + mode_stack_state_for_pre_dispatch().get(1) + if mode_stack_state_for_pre_dispatch().get(1) is not None + else mode_stack_state_for_pre_dispatch().get(0) + ) + return None + + +def mode_stack_state_for_pre_dispatch(): + global _mode_stack_state_for_pre_dispatch + return _mode_stack_state_for_pre_dispatch + + +cached_ops: Set["OpOverload"] = set() + + +def add_cached_op(op_overload): + global cached_ops + cached_ops.add(op_overload) + + +def reset_cached_ops(): + global cached_ops + cached_ops.clear() + + +def get_cached_ops(): + global cached_ops + return cached_ops + + +# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. +# You can obtain an OpOverload object through attribute query on OpOverloadPacket. +class OpOverload(OperatorBase): + def __init__(self, overloadpacket, op, op_dk, schema, tags): + super().__init__() + self._op = op + self._op_dk = op_dk + self._schema = schema + self._overloadpacket = overloadpacket + self._tags = tags + self._overloadname = ( + "default" if schema.overload_name == "" else schema.overload_name + ) + self._name = self._schema.name + if schema.overload_name: + self._name += "." + schema.overload_name + self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}" + self.__module__ = overloadpacket.__module__ + op.__module__ = overloadpacket.__module__ + self.__qualname__ = self._name + self.__annotations__ = {} + + # If the OpOverload was constructed from a Library.def in Python. + self._defined_in_python = self.__qualname__ in torch.library._defs + + # Logic replicated from aten/src/ATen/native/MathBitsFallback.h + is_write = None + for a in self._schema.arguments: + if a.alias_info is None: + continue + if is_write is None: + is_write = a.alias_info.is_write + else: + # We will conservatively call mixed mutable/non-mutable + # aliased inputs as NOT a view + is_write = a.alias_info.is_write or is_write + self.is_view = is_write is not None and not is_write + + # it's a no-op since OpOverload object is immutable and must be unique for a given op overload. + def __deepcopy__(self, memo=None): + return self + + def __repr__(self): + return "".format( + *self._schema.name.split("::"), self._overloadname + ) + + def __call__(self_, *args, **kwargs): # noqa: B902 + # use `self_` to avoid naming collide with aten ops arguments that + # are named "self". This way, all the aten ops can be called by kwargs. 
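As the comment notes, calling an OpOverload lands directly in the bound C++ op via the return just below. Seen from user code, with a stock aten op:

import torch

overload = torch.ops.aten.add.Tensor              # an OpOverload
print(overload.name())                            # aten::add.Tensor
print(overload(torch.ones(2), torch.ones(2)))     # goes straight into the bound C++ op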
+ return self_._op(*args, **kwargs) + + def __hash__(self): + return hash(self._op) + + # `my_namespace.my_op_name.overload_name` + def __str__(self): + return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) + + def has_kernel_for_dispatch_key(self, k): + return super().has_kernel_for_dispatch_key( + k + ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) + + def has_kernel_for_any_dispatch_key(self, ks): + return torch._C._dispatch_has_kernel_for_any_dispatch_key( + self.name(), ks + ) or super().has_kernel_for_any_dispatch_key(ks) + + @property + def namespace(self): + return self._schema.name.split("::")[0] + + def _handle(self): + return torch._C._dispatch_find_schema_or_throw( + self._schema.name, self._schema.overload_name + ) + + def decompose(self, *args, **kwargs): + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if dk in self.py_kernels: + # NB: This branch is not too necessary anymore, because we can + # apply Python CompositeImplicitAutograd *before* tracing + # using Python dispatcher (also taking advantage of the autograd + # formula). But it's included for completeness + return self.py_kernels[dk](*args, **kwargs) + elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): + return self._op_dk(dk, *args, **kwargs) + else: + return NotImplemented + + # Remove a dispatch key from the dispatch cache. This will force it to get + # recomputed the next time. Does nothing + # WARNING: if you register a dispatch key to py_kernels of an OpOverload, + # calling _del_dispatch on that key is NOT sufficient to apply your change, + # because a single registration may affect MULTIPLE dispatch keys (e.g., + # registering Autograd affects AutogradCPU). del_dispatch is to be used + # only if you are specifically modifying how get_dispatch handles a + # particular input 'key'. + def _uncache_dispatch(self, key): + self._dispatch_cache.pop(key, None) + + # This implements the pre-computation logic for the Python dispatcher. + def _get_dispatch(self, key): + # This is only called upon a cache miss + assert key not in self._dispatch_cache, f"{self} {key}" + + if key == torch._C.DispatchKey.Python: + if not self.python_key_mode_table: + self._dispatch_cache[key] = key + add_cached_op(self) + return key + + def handler(*args, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + # TODO: We also need to handle tensor subclasses here + # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. + curr_mode = type(_get_current_dispatch_mode()) + assert ( + curr_mode is not None + ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." + if curr_mode not in self.python_key_mode_table: + # TODO: This path is slow, should generally encourage this + # case to not happen + return self._op_dk(key, *args, **kwargs) + # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key. + return self.python_key_mode_table[curr_mode](*args, **kwargs) + + self._dispatch_cache[key] = handler + add_cached_op(self) + return handler + + functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined] + if functionality_key == torch._C.DispatchKey.PreDispatch: + curr_stack_len = _len_torch_dispatch_stack_pre_dispatch() + # The check for Python in the exclude set is so we properly respect `with no_dispatch()` + # calls inside of a mode. 
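The decompose() method defined earlier in this class can be exercised directly. A sketch, assuming aten::matmul still carries a CompositeImplicitAutograd kernel; overloads without one return NotImplemented:

import torch

a, b = torch.randn(2, 3), torch.randn(3, 4)
out = torch.ops.aten.matmul.default.decompose(a, b)
print(out is NotImplemented)  # expected False while matmul remains a composite op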
+ if ( + curr_stack_len > 0 + and not torch._C._dispatch_tls_is_dispatch_key_excluded( + DispatchKey.Python + ) + ): + + def handler(*args, **kwargs): + @contextlib.contextmanager + def _temporarily_pop_modes_from_pre_dispatch(): + top_mode = _pop_mode_from_pre_dispatch() + try: + yield top_mode + finally: + _set_mode_pre_dispatch(top_mode) + + with _temporarily_pop_modes_from_pre_dispatch() as curr_mode: + assert isinstance(curr_mode, TorchDispatchMode) + overload_types = [] + args_flattened, _ = torch.utils._pytree.tree_flatten( + (args, kwargs.values()) + ) + for a in args_flattened: + # TODO: need to double check the semantics of the "types" argument to torch_dispatch. + # It's generated in PyInterpreter.cpp, but seems to be generated in two places, + # where in one case we only include tensors with the python key, and in another + # we include **all** tensors. + if isinstance(a, torch.Tensor) and torch._C._dispatch_keys( + a + ).has(torch._C.DispatchKey.Python): + overload_types.append(type(a)) + # TODO: check that I got these args correct (in C++, we pass in "0000"??) + + return curr_mode.__torch_dispatch__( + self, overload_types, args, kwargs + ) + + # Note [Not Caching Per-Dispatch-Key Mode Handlers] + # Note that we're not caching this handler. There isn't really a point, since the slow bit + # is the handler itself (in python). + # Also, not caching means that we don't have to reset the cache when any existing + # modes go out of scope (which in of itself takes time to loop through all operators). + return handler + + final_key = resolve_key(self, key) + + # See Note [Not Caching Per-Dispatch-Key Mode Handlers] + cache_result = key != torch._C.DispatchKey.PreDispatch + + # TODO: We could potentially have lots of debugging wrappers against + # dispatch keys; design some general registration mechanism instead of + # having if statement for each of them + if key == torch._C.DispatchKey.Functionalize: + import torch._dispatch.python as pydispatch + + if pydispatch.CROSSREF_FUNCTIONALIZE: + handler = pydispatch.make_crossref_functionalize(self, final_key) + if cache_result: + self._dispatch_cache[key] = handler + add_cached_op(self) + return handler + + # print(self, key, final_key) + r = self.py_kernels.get(final_key, final_key) + if cache_result: + self._dispatch_cache[key] = r + add_cached_op(self) + return r + + def name(self): + return self._name + + @property + def overloadpacket(self): + return self._overloadpacket + + @property + def op(self): + return self._op + + @property + def tags(self): + return self._tags + + # TODO: add more methods to expose information about input and output arguments + + +# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator +# You can obtain an OpOverload object through attribute query. +class OpOverloadPacket: + def __init__(self, qualified_op_name, op_name, op, overload_names): + # These attributes are accessible on the object through the properties + # defined below but are immutable + self._qualified_op_name = qualified_op_name + self.__name__ = op_name + self._op = op + self._overload_names = overload_names + self._dir = [] + + # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op. 
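The _get_dispatch machinery above is only consulted on a Python-dispatcher cache miss; it is typically exercised through torch._dispatch.python.enable_python_dispatcher. A minimal sketch:

import torch
from torch._dispatch.python import enable_python_dispatcher

with enable_python_dispatcher():
    # Ordinary calls still work; cache misses now flow through _get_dispatch above.
    torch.ops.aten.add.Tensor(torch.ones(2), torch.ones(2))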
+ def __deepcopy__(self, memo=None): + return self + + def __repr__(self): + return "".format( + *self._qualified_op_name.split("::") + ) + + def __hash__(self): + return hash(self._op) + + def __str__(self): + return "{}.{}".format(*self._qualified_op_name.split("::")) + + @property + def op(self): + return self._op + + def __getattr__(self, key): + # It is not a valid op_name when __file__ is passed in + if key == "__file__": + return "torch.ops" + + # ensure that query for dunder attributes that does not exist on + # opoverloadpacket but instead exists on the self._op object does not unnecessarily call + # `_get_operation_overload` (which is an expensive operation). + # This is done to prevent any potential slowdown. This list can be extended + # if there exists other attributes like `__name__` that only exist on self._op and not on the + # opoverloadpacket. + # This is ok since we are guaranteed that an overload name for an aten op can't start with '__' + try: + if key.startswith("__"): + return getattr(self._op, key) + except AttributeError: + # for consistency because it seems weird to + # throw an attribute error with a message containing + # an object name different from the one the attribute + # query was performed on. + raise AttributeError( + f"'{str(self)}' can't have an overload name beginning with '__' and the " + f"underlying op {str(self._op)} has no attribute {key} either." + ) from None + + try: + # This is ok since we are guaranteed that an overload name for an aten op can't be 'default' + use_key = "" if key == "default" else key + # TODO: disallow access to overloads registered by JIT + op_, op_dk_, tags = torch._C._get_operation_overload( + self._qualified_op_name, use_key + ) + schema = torch._C._get_schema(self._qualified_op_name, use_key) + overload = OpOverload(self, op_, op_dk_, schema, tags) + # cache the overload object + setattr(self, key, overload) + self._dir.append(key) + return overload + except RuntimeError: + raise AttributeError( + f"The underlying op of '{str(self)}' has no overload name '{key}'" + ) from None + + def __iter__(self): + return iter(self._dir) + + def __call__(self_, *args, **kwargs): # noqa: B902 + # use `self_` to avoid naming collide with aten ops arguments that + # named "self". This way, all the aten ops can be called by kwargs. + + # overloading __call__ to ensure torch.ops.foo.bar() + # is still callable from JIT + # We save the function ptr as the `op` attribute on + # OpOverloadPacket to access it here. + return self_._op(*args, **(kwargs or {})) + + # TODO: use this to make a __dir__ + def overloads(self): + return [n if n else "default" for n in self._overload_names] + + +# Resolution of torch.fn is different from torch.ops.aten.fn +# torch.fn uses the Python argparser, matches with the +# appropriate schema, and calls into the unboxed version of the method +# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. +# JIT creates a stack of all the overloads and then tries to match the +# correct one at runtime and always calls into the boxed version of the method +# Autograd codegen creates VariableType, TracerType, +# inplace or view type and python bindings. +# Aten codegen generates tensor methods for the tensor class. + +# _OpNamespace is a subclass of ModuleType because the torch script +# allows attribute lookups on modules only. 
Since we want torch.ops.foo.bar() +# to work from script, we need to ensure ops and foo are modules + + +class _OpNamespace(types.ModuleType): + """ + An op namespace to dynamically bind Operators into Python. + + Say a user has created a custom Operator called "my_namespace::my_op". To + call this op, the user will write torch.ops.my_namespace.my_op(...). + At startup, this operation will not yet be bound into Python. Instead, the + following sequence of magic tricks will occur: + 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method + on the `torch.ops` object, which will create a new `_OpNamespace` + object called `my_namespace` and set it as an attribute on the `ops` + object. + 2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on + the `my_namespace` object, which will retrieve the operation via + `torch.get_operation`, a function bound from C++, and then in a similar + fashion bind this new object onto the `my_namespace` object. + 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation + and subsequent accesses will incur no further lookup (the namespace and + operation will already exist). + """ + + def __init__(self, name): + super().__init__("torch.ops." + name) + self.name = name + self._dir = [] + + def __iter__(self): + return iter(self._dir) + + def __getattr__(self, op_name): + # It is not a valid op_name when __file__ is passed in + if op_name == "__file__": + return "torch.ops" + elif op_name in ["__origin__", "__self__"]: + raise AttributeError( + f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'" + ) + + # Get the op `my_namespace::my_op` if available. This will also check + # for overloads and raise an exception if there are more than one. + namespace_name = self.name + qualified_op_name = f"{namespace_name}::{op_name}" + try: + op, overload_names = torch._C._jit_get_operation(qualified_op_name) + if op is None: + raise AttributeError( + f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + ) + except RuntimeError as e: + # Turn this into AttributeError so getattr(obj, key, default) + # works (this is called by TorchScript with __origin__) + raise AttributeError( + f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" + ) from e + + # let the script frontend know that op is identical to the builtin op + # with qualified_op_name + torch.jit._builtins._register_builtin(op, qualified_op_name) + op.__module__ = self.__module__ + "." + namespace_name + opoverloadpacket = OpOverloadPacket( + qualified_op_name, op_name, op, overload_names + ) + opoverloadpacket.__module__ = self.__module__ + "." + namespace_name + # cache the opoverloadpacket to ensure that each op corresponds to + # a unique OpOverloadPacket object + setattr(self, op_name, opoverloadpacket) + self._dir.append(op_name) + return opoverloadpacket + + +class _PyOpNamespace(_OpNamespace): + def __init__(self, name, ops): + super().__init__(name) + self._ops = ops + + def __getattr__(self, name): + # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object. 
+ op = self._ops.get(name, None) + if op is None: + raise AttributeError( + f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'" + ) + setattr(self, name, op) + return op + + +class _Ops(types.ModuleType): + __file__ = "_ops.py" + + def __init__(self): + super().__init__("torch.ops") + self.loaded_libraries = set() + self._higher_order_op_namespace = _PyOpNamespace( + "torch.ops.higher_order", _higher_order_ops + ) + self._dir = [] + + def __getattr__(self, name): + # Check if the name is a HigherOrderOperator + if name == "higher_order": + return self._higher_order_op_namespace + + # Here we are creating `torch.ops.my_namespace` + namespace = _OpNamespace(name) + setattr(self, name, namespace) + self._dir.append(name) + return namespace + + def __iter__(self): + return iter(self._dir) + + def import_module(self, module): + """ + Imports a Python module that has torch.library registrations. + + Generally, to extend PyTorch with custom operators, a user will + create a Python module whose import triggers registration of + the custom operators via a torch.ops.load_library call or a call + to one or more torch.library.* APIs. + + It is unexpected for Python modules to have side effects, so some + linters and formatters will complain. Use this API to import Python + modules that contain these torch.library side effects. + + Args: + module (str): The name of the Python module to import + + """ + importlib.import_module(module) + + def load_library(self, path): + """ + Loads a shared library from the given path into the current process. + + The library being loaded may run global initialization code to register + custom operators with the PyTorch JIT runtime. This allows dynamically + loading custom operators. For this, you should compile your operator + and the static registration code into a shared library object, and then + call ``torch.ops.load_library('path/to/libcustom.so')`` to load the + shared object. + + After the library is loaded, it is added to the + ``torch.ops.loaded_libraries`` attribute, a set that may be inspected + for the paths of all libraries loaded using this function. + + Args: + path (str): A path to a shared library to load. + """ + if torch._running_with_deploy(): + return + + path = _utils_internal.resolve_library_path(path) + with dl_open_guard(): + # Import the shared library into the process, thus running its + # static (global) initialization code in order to register custom + # operators with the JIT. 
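This dlopen step is what dl_open_guard near the top of this file exists for: RTLD_GLOBAL is set only around the load so the library's static registration code can resolve torch symbols. A sketch with a hypothetical library path:

import torch

# The path is hypothetical; the library is expected to contain TORCH_LIBRARY
# registration code that runs when the shared object is loaded.
torch.ops.load_library("build/libcustom_ops.so")
print(torch.ops.loaded_libraries)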
+ ctypes.CDLL(path) + self.loaded_libraries.add(path) + + +# The ops "namespace" +ops = _Ops() diff --git a/MLPY/Lib/site-packages/torch/_prims/__init__.py b/MLPY/Lib/site-packages/torch/_prims/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28af648500b507672f07811a38faef549844d794 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims/__init__.py @@ -0,0 +1,3031 @@ +import contextlib +import itertools +import operator +import weakref +from enum import Enum +from functools import partial, reduce +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + +import torch + +import torch._prims_common as utils +import torch.library +from torch import sym_float, Tensor, TypedStorage +from torch._C import _get_default_device +from torch._prims.debug_prims import register_debug_prims +from torch._prims.rng_prims import register_rng_prims +from torch._prims_common import ( + Dim, + DimsSequenceType, + DimsType, + IntLike, + Number, + NumberType, + RETURN_TYPE, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + type_to_dtype, +) +from torch._prims_common.wrappers import backwards_not_supported +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.overrides import handle_torch_function, has_torch_function +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +prim = torch.library.Library("prims", "DEF") +prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd") +prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect") +prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd") +prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta") + +# Experimental module containing prototype "primitive" operations. + +__all__ = [ + # + # Common datastructures and helpers + # + "RETURN_TYPE", + # + # Elementwise unary prims + # + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", + "cos", + "cosh", + "bessel_i0", + "bessel_i0e", + "bessel_i1", + "bessel_i1e", + "bessel_j0", + "bessel_j1", + "bitwise_not", + "cbrt", + "ceil", + "conj_physical", + "digamma", + "erf", + "erf_inv", + "erfc", + "erfcx", + "exp", + "expm1", + "exp2", + "fill", + "floor", + "imag", + "isfinite", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "ndtri", + "neg", + "real", + "reciprocal", + "round", + "sign", + "signbit", + "sin", + "sinh", + "spherical_bessel_j0", + "sqrt", + "tan", + "tanh", + "trunc", + # + # Elementwise binary prims + # + "add", + "atan2", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + # 'complex', # needs custom meta + "div", + "eq", + "fmax", + "fmin", + "fmod", + "frexp", + "gcd", + "ge", + "gt", + "hypot", + "igamma", + "igammac", + "le", + "lt", + "maximum", + "minimum", + "mul", + "ne", + "nextafter", + "pow", + "remainder", + "rsqrt", + "shift_left", + "shift_right_arithmetic", + "shift_right_logical", # not implemented + "sub", + "zeta", + # + # View prims + # + "as_strided", + "broadcast_in_dim", + "collapse_view", + "conj", + "expand_dims", + "slice", + "slice_in_dim", # implemented using slice -- make this a ref? 
+ "split_dim", + "squeeze", + "transpose", + "view_of", + "view_element_type", + # + # Functionalized view mutations + # + "as_strided_scatter", + # + # Shape prims + # + "collapse", + "cat", + "reshape", + "rev", + # + # Conditional prims + # + "where", + # + # Data conversion and movement prims + # + "clone", + "convert_element_type", + "device_put", + "item", + "maximum_value", + "minimum_value", + "copy_strided", + # + # Inplace prims + # + "copy_to", + "resize", + # "_set", # Commented out, see note below + # + # Reduction prims + # + "amax", + "amin", + "prod", + "sum", + "xor_sum", + "var", + # + # Tensor Creation Prims + # + "empty_strided", + "empty_permuted", + "scalar_tensor", + "iota", + # + # Linear algebra (linalg) Prims + # + "svd", + # + # Randomness Prims + # + "normal", + "_uniform_helper", + # + # FFT prims + # + "fft_r2c", + "fft_c2c", + "fft_c2r", +] + + +def TensorMeta( + tensorlike: Optional[Union[NumberType, torch.Tensor]] = None, + *, + shape: Optional[ShapeType] = None, + strides: Optional[StrideType] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, +): + if isinstance(tensorlike, Number): + assert not shape and (shape is None or isinstance(shape, Sequence)) + assert not strides and (strides is None or isinstance(strides, Sequence)) + inferred_shape: Tuple[int, ...] = () + inferred_strides: Tuple[int, ...] = () + inferred_dtype = type_to_dtype(type(tensorlike)) + inferred_device = torch.device("cpu") + # TODO: This looks wrong, a number that is wrapped into a tensor + # needs to behave differently than a scalar tensor for type + # promotion purposes + elif tensorlike is not None: + assert isinstance(tensorlike, torch.Tensor) + inferred_shape = tuple(tensorlike.shape) + inferred_strides = tuple(tensorlike.stride()) + inferred_dtype = tensorlike.dtype + inferred_device = tensorlike.device + else: + # If no tensorlike "example" is given then all metadata + # must be provided explicitly + assert shape is not None + assert strides is not None + assert dtype is not None + assert device is not None + + shape = inferred_shape if shape is None else tuple(shape) # type: ignore[possibly-undefined] + strides = inferred_strides if strides is None else tuple(strides) # type: ignore[possibly-undefined] + dtype = inferred_dtype if dtype is None else dtype # type: ignore[possibly-undefined] + device = inferred_device if device is None else device # type: ignore[possibly-undefined] + + if isinstance(device, str): + device = torch.device(device) + + return torch.empty_strided(shape, strides, dtype=dtype, device=device) + + +def _make_prim( + *, + schema: str, + return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]], + meta: Callable, + impl_aten: Callable, + doc: str, + tags: Optional[Sequence[torch.Tag]] = None, +): + """ + Creates a primitive operation. + + """ + + prim.define(schema, tags=torch.Tag.pt2_compliant_tag) + + def _prim_impl(*args, **kwargs): + # always run the meta function because aten implementation will + # typically accept more inputs (e.g., it will do promotion and + # broadcasting) which we want to reject + meta(*args, **kwargs) + return impl_aten(*args, **kwargs) + + # Right now prims don't support autograd (we can and should add an + # argument that provides an implementation for backward here.) 
Because we + # don't have derivative formulas, we must setup a custom autograd function + # that raises an error if backwards is invoked + def _autograd_impl(*args, **kwargs): + return backwards_not_supported(_prim)(*args, **kwargs) + + def _backend_select_impl(*args, **kwargs): + if kwargs.get("device") and kwargs["device"].type == "meta": + return meta(*args, **kwargs) + if any(isinstance(x, torch.device) and x.type == "meta" for x in args): + return meta(*args, **kwargs) + else: + return _prim_impl(*args, **kwargs) + + name = schema.split("(")[0] + prim_impl.impl(name, _prim_impl) + prim_autograd_impl.impl(name, _autograd_impl) + prim_meta_impl.impl(name, meta) + + _prim_packet = getattr(torch._ops.ops.prims, name) + _prim = _prim_packet.default + if tags: + _prim._tags = tags + + from torch._subclasses.fake_tensor import contains_tensor_types + + if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str( + _prim + ) in [ + # See https://github.com/pytorch/pytorch/issues/103532 + "prims.device_put.default" + ]: + prim_backend_select_impl.impl(name, _backend_select_impl) + + for p in (_prim_packet, _prim): + p.__doc__ = doc + p.return_type = return_type # type: ignore[attr-defined] + + p.schema = schema + p.prim_impl = _prim_impl + p.prim_meta_impl = meta + p.impl_aten = impl_aten + + return _prim + + +class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + + +# TODO: implement dtype validation here, too, or on the corresponding refs +def _prim_elementwise_meta( + *args, + type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, + args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None, +) -> FakeTensor: + """ + Meta function for elementwise operations that produce outputs in the same dtype + as their inputs. + + Stride logic is currently incorrect. + """ + + assert len(args) > 0 + + utils.check_same_dtype(*args) + + args_ = list(args) + if args_with_fixed_dtypes is not None: + args_ = list(args_with_fixed_dtypes) + args_ + + utils.check_same_device(*args_, allow_cpu_scalar_tensors=True) + utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True) + + l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_) + shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True) + + # Acquires the dtype + dtype = None + scalar_type = None + for arg in args: + if isinstance(arg, TensorLike): + if not utils.is_cpu_scalar_tensor(arg): + dtype = arg.dtype + break + else: + dtype = arg.dtype + elif isinstance(arg, Number): + scalar_type = type(arg) + + if dtype is None and scalar_type is not None: + dtype = utils.type_to_dtype(scalar_type) + + # Acquires the device (if it exists) or number + device = None + number = None + for arg in args_: + if isinstance(arg, TensorLike): + if utils.is_cpu_scalar_tensor(arg): + if device is None: + device = arg.device + # keep going, in case there is a cuda tensor later + else: + device = arg.device + break + + elif isinstance(arg, Number): + if number is None: + number = arg + + # NOTE: type promotion behavior here is mostly hidden from tests because + # references will typically handle the type promotion properly even if this doesn't + # (but getting it wrong will cause too many casts to be inserted in traces!) 
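As the comments in _make_prim and here spell out, the meta function always runs before the aten implementation precisely so that broadcasting and implicit type promotion are rejected rather than silently accepted. A small sketch of that strictness, assuming a standard build where the prims below are registered:

import torch

a = torch.randn(3)
b = torch.randn(3)
print(torch.ops.prims.add(a, b))               # ok: same shape, dtype and device

try:
    torch.ops.prims.add(torch.randn(3, 1), b)  # aten add would broadcast; the prim refuses
except Exception as e:                         # exact error type comes from the shape check
    print("rejected:", e)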
+ if device is not None: + assert dtype is not None + if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT: + dtype = dtype + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL: + dtype = torch.bool + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT: + if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype): + dtype = torch.get_default_dtype() + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: + if utils.is_complex_dtype(dtype): + dtype = utils.corresponding_real_dtype(dtype) + else: + dtype = dtype + + assert shape is not None + return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype) # type: ignore[return-value] + + # Number case + # TODO: fix number type promotion (bool, complex->float) + + # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) + seen_float = False + if isinstance(number, (torch.SymInt, torch.SymFloat)): + for a in args: + assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" + seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) + if seen_float: + number = sym_float(number) + + return TensorMeta(number) # type: ignore[arg-type] + + +def _complex_only_elementwise_meta(*args, **kwargs): + torch._check( + utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported" + ) + return _prim_elementwise_meta(*args, **kwargs) + + +def _make_elementwise_unary_prim( + name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs +): + """ + Creates an elementwise unary prim. + """ + + return _make_prim( + schema=f"{name}(Tensor self) -> Tensor", + meta=partial(_prim_elementwise_meta, type_promotion=type_promotion), + return_type=RETURN_TYPE.NEW, + **kwargs, + ) + + +def _make_elementwise_binary_prim( + name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs +): + """ + Creates an elementwise binary prim. 
+ """ + + return _make_prim( + schema=f"{name}(Tensor self, Tensor other) -> Tensor", + meta=partial(_prim_elementwise_meta, type_promotion=type_promotion), + return_type=RETURN_TYPE.NEW, + **kwargs, + ) + + +def _not_impl(*args, **kwargs): + raise NotImplementedError + + +# +# Elementwise unary operations +# + + +abs = _make_elementwise_unary_prim( + "abs", + impl_aten=torch.abs, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) + +acos = _make_elementwise_unary_prim( + "acos", + impl_aten=torch.acos, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +acosh = _make_elementwise_unary_prim( + "acosh", + impl_aten=torch.acosh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +asin = _make_elementwise_unary_prim( + "asin", + impl_aten=torch.asin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +asinh = _make_elementwise_unary_prim( + "asinh", + impl_aten=torch.asinh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atan = _make_elementwise_unary_prim( + "atan", + impl_aten=torch.atan, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atanh = _make_elementwise_unary_prim( + "atanh", + impl_aten=torch.atanh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +cos = _make_elementwise_unary_prim( + "cos", + impl_aten=torch.cos, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +cosh = _make_elementwise_unary_prim( + "cosh", + impl_aten=torch.cosh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_j0 = _make_elementwise_unary_prim( + "bessel_j0", + impl_aten=torch.special.bessel_j0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_j1 = _make_elementwise_unary_prim( + "bessel_j1", + impl_aten=torch.special.bessel_j1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i0 = _make_elementwise_unary_prim( + "bessel_i0", + impl_aten=torch.i0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i0e = _make_elementwise_unary_prim( + "bessel_i0e", + impl_aten=torch.special.i0e, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i1 = _make_elementwise_unary_prim( + "bessel_i1", + impl_aten=torch.special.i1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i1e = _make_elementwise_unary_prim( + "bessel_i1e", + impl_aten=torch.special.i1e, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_not = _make_elementwise_unary_prim( + "bitwise_not", + impl_aten=torch.bitwise_not, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _cbrt_aten(a: torch.Tensor) -> Tensor: + torch._check( + not a.is_complex(), + lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)", + ) + # Returns the real cubic root of the number. + # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number + # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i} + # which is a complex number. 
+ # For more info see the section Note in + # https://en.cppreference.com/w/cpp/numeric/math/cbrt + return torch.copysign(torch.pow(a.abs(), 1 / 3), a) + + +cbrt = _make_elementwise_unary_prim( + "cbrt", + impl_aten=_cbrt_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ceil = _make_elementwise_unary_prim( + "ceil", + impl_aten=torch.ceil, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType: + if not input.dtype.is_complex: + raise RuntimeError("prims.conj_physical is only defined for complex dtypes") + + strides = utils.compute_elementwise_output_strides(input) + return TensorMeta(input, strides=strides) + + +conj_physical = _make_prim( + schema="conj_physical(Tensor self) -> Tensor", + meta=_conj_physical_meta, + impl_aten=torch._conj_physical, + doc="Returns the physical conjugation of a complex tensor", + return_type=RETURN_TYPE.NEW, +) + + +def _clone_meta( + input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + if memory_format != torch.preserve_format: + return torch.empty( + input.shape, + dtype=input.dtype, + layout=input.layout, + device=input.device, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + strides = utils.compute_elementwise_output_strides(input) + return torch.empty_strided( + input.shape, + strides, + dtype=input.dtype, + layout=input.layout, + device=input.device, + ) + + +clone = _make_prim( + schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + meta=_clone_meta, + impl_aten=torch.clone, + doc="Returns the copy of a tensor", + return_type=RETURN_TYPE.NEW, +) + +digamma = _make_elementwise_unary_prim( + "digamma", + impl_aten=torch.digamma, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erf = _make_elementwise_unary_prim( + "erf", + impl_aten=torch.erf, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erf_inv = _make_elementwise_unary_prim( + "erf_inv", + impl_aten=torch.special.erfinv, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erfc = _make_elementwise_unary_prim( + "erfc", + impl_aten=torch.special.erfc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erfcx = _make_elementwise_unary_prim( + "erfcx", + impl_aten=torch.special.erfcx, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +exp = _make_elementwise_unary_prim( + "exp", + impl_aten=torch.exp, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +expm1 = _make_elementwise_unary_prim( + "expm1", + impl_aten=torch.special.expm1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +exp2 = _make_elementwise_unary_prim( + "exp2", + impl_aten=torch.special.exp2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType: + return _prim_elementwise_meta( + a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + +# NOTE: fill uses _make_prim directly because it has a value parameter +fill = _make_prim( + schema="fill(Tensor self, Scalar value) -> Tensor", + return_type=RETURN_TYPE.NEW, + meta=_fill_meta, + impl_aten=torch.fill, + doc="", +) + +floor = _make_elementwise_unary_prim( + "floor", + impl_aten=torch.floor, + doc="", + 
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +imag = _make_prim( + schema="imag(Tensor self) -> Tensor", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.imag, + doc="", +) + +isfinite = _make_elementwise_unary_prim( + "isfinite", + impl_aten=torch.isfinite, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +lgamma = _make_elementwise_unary_prim( + "lgamma", + impl_aten=torch.lgamma, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log = _make_elementwise_unary_prim( + "log", + impl_aten=torch.log, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log1p = _make_elementwise_unary_prim( + "log1p", + impl_aten=torch.log1p, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log2 = _make_elementwise_unary_prim( + "log2", + impl_aten=torch.log2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log10 = _make_elementwise_unary_prim( + "log10", + impl_aten=torch.log10, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +real = _make_prim( + schema="real(Tensor self) -> Tensor", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.real, + doc="", +) + +reciprocal = _make_elementwise_unary_prim( + "reciprocal", + impl_aten=torch.reciprocal, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ndtri = _make_elementwise_unary_prim( + "ndtri", + impl_aten=torch.special.ndtri, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +neg = _make_elementwise_unary_prim( + "neg", + impl_aten=torch.neg, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +round = _make_elementwise_unary_prim( + "round", + impl_aten=torch.round, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +rsqrt = _make_elementwise_unary_prim( + "rsqrt", + impl_aten=torch.rsqrt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sign = _make_elementwise_unary_prim( + "sign", + impl_aten=torch.sign, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +signbit = _make_elementwise_unary_prim( + "signbit", + impl_aten=torch.signbit, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sin = _make_elementwise_unary_prim( + "sin", + impl_aten=torch.sin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sinh = _make_elementwise_unary_prim( + "sinh", + impl_aten=torch.sinh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +spherical_bessel_j0 = _make_elementwise_unary_prim( + "spherical_bessel_j0", + impl_aten=torch.special.spherical_bessel_j0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sqrt = _make_elementwise_unary_prim( + "sqrt", + impl_aten=torch.sqrt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +tan = _make_elementwise_unary_prim( + "tan", + impl_aten=torch.tan, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +tanh = _make_elementwise_unary_prim( + "tanh", + impl_aten=torch.tanh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +trunc = 
_make_elementwise_unary_prim( + "trunc", + impl_aten=torch.trunc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +# +# Elementwise binary operations +# + +add = _make_elementwise_binary_prim( + name="add", + impl_aten=torch.add, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atan2 = _make_elementwise_binary_prim( + name="atan2", + impl_aten=torch.atan2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_and = _make_elementwise_binary_prim( + "bitwise_and", + impl_aten=torch.bitwise_and, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_or = _make_elementwise_binary_prim( + "bitwise_or", + impl_aten=torch.bitwise_or, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_xor = _make_elementwise_binary_prim( + "bitwise_xor", + impl_aten=torch.bitwise_xor, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +# TODO: complex needs a special meta to account for its float -> complex behavior +# complex = _make_elementwise_binary_prim( +# impl_aten=torch.complex, +# doc="", +# ) + + +# div prim performs truncation division on integer inputs +# and true division for floating and complex inputs +def _div_aten(a, b): + is_integral = isinstance(a, (bool, int, torch.SymInt)) or ( + isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype) + ) + + if is_integral: + return torch.div(a, b, rounding_mode="trunc") + else: + return torch.true_divide(a, b) + + +div = _make_elementwise_binary_prim( + "div", + impl_aten=_div_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +eq = _make_elementwise_binary_prim( + "eq", + impl_aten=torch.eq, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +fmax = _make_elementwise_binary_prim( + "fmax", + impl_aten=torch.fmax, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +fmin = _make_elementwise_binary_prim( + "fmin", + impl_aten=torch.fmin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +fmod = _make_elementwise_binary_prim( + "fmod", + impl_aten=torch.fmod, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +gcd = _make_elementwise_binary_prim( + "gcd", + impl_aten=torch.gcd, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +ge = _make_elementwise_binary_prim( + "ge", + impl_aten=torch.ge, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +gt = _make_elementwise_binary_prim( + "gt", + impl_aten=torch.gt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +hypot = _make_elementwise_binary_prim( + "hypot", + impl_aten=torch.hypot, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +igamma = _make_elementwise_binary_prim( + "igamma", + impl_aten=torch.special.gammainc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +igammac = _make_elementwise_binary_prim( + "igammac", + impl_aten=torch.special.gammaincc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +le = _make_elementwise_binary_prim( + "le", + impl_aten=torch.le, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +lt = _make_elementwise_binary_prim( + "lt", + impl_aten=torch.lt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + + 
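+# Illustrative note (not upstream code): the comparison prims above use the
+# ALWAYS_BOOL promotion kind, meaning the result dtype is torch.bool regardless
+# of the input dtypes, while DEFAULT-kind prims keep the promoted input dtype.
+# For example, assuming the usual eager semantics of the underlying ATen ops:
+#   >>> torch.lt(torch.tensor([1, 2]), torch.tensor([2, 2])).dtype   # torch.bool
+#   >>> torch.add(torch.tensor([1, 2]), torch.tensor([2, 2])).dtype  # torch.int64
+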
+# Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs +def _maximum_aten( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +) -> TensorLikeType: + if isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + + return torch.maximum(a, b) # type: ignore[arg-type] + + +maximum = _make_elementwise_binary_prim( + "maximum", + impl_aten=_maximum_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _minimum_aten( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +) -> TensorLikeType: + if isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + + return torch.minimum(a, b) # type: ignore[arg-type] + + +minimum = _make_elementwise_binary_prim( + "minimum", + impl_aten=_minimum_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +mul = _make_elementwise_binary_prim( + "mul", + impl_aten=torch.mul, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ne = _make_elementwise_binary_prim( + "ne", + impl_aten=torch.ne, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +nextafter = _make_elementwise_binary_prim( + "nextafter", + impl_aten=torch.nextafter, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +pow = _make_elementwise_binary_prim( + "pow", + impl_aten=torch.pow, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +remainder = _make_elementwise_binary_prim( + "remainder", + impl_aten=torch.remainder, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +shift_left = _make_elementwise_binary_prim( + "shift_left", + impl_aten=torch.bitwise_left_shift, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +shift_right_arithmetic = _make_elementwise_binary_prim( + "shift_right_arithmetic", + impl_aten=torch.bitwise_right_shift, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +shift_right_logical = _not_impl + +sub = _make_elementwise_binary_prim( + "sub", + impl_aten=torch.sub, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +zeta = _make_elementwise_binary_prim( + "zeta", + impl_aten=torch.special.zeta, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +# +# View operations +def _as_strided_meta( + a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int +) -> TensorLikeType: + assert len(size) == len(stride) + assert storage_offset >= 0 + utils.validate_strides(stride) + utils.validate_shape(size) + + if reduce(operator.mul, size) == 0: + # NOTE: This special case is to avoid having to acquire the storage below + # as_strided to shapes with no elements are trivially valid, so it's OK + pass + elif isinstance(a, torch.Tensor): + utils.check_in_bounds_for_storage( + a._typed_storage(), size, stride, storage_offset + ) + + return torch.as_strided(a, size, stride, storage_offset) + + +def _as_strided_aten( + a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int +) -> Tensor: + return torch.as_strided(a, size, stride, 
storage_offset) + + +_as_strided_doc = """ + Creates a view of the tensor with the given shape (size), strides (stride) and + storage offset (storage_offset). +""" + +as_strided = _make_prim( + schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)", + meta=_as_strided_meta, + impl_aten=_as_strided_aten, + return_type=RETURN_TYPE.VIEW, + doc=_as_strided_doc, +) + + +def _broadcast_in_dim_meta( + a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int] +): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Type checks + assert isinstance(a, TensorLike) + assert isinstance(shape, Sequence) + assert isinstance(broadcast_dimensions, Sequence) + + # every dimension must be accounted for + assert a.ndim == len(broadcast_dimensions) + + # broadcast shape must have weakly more dimensions + assert len(shape) >= a.ndim + + # broadcast_dimensions must be an ascending sequence + # (no relative reordering of dims) of integers and + # each dimension must be within the new shape + def _greater_than_reduce(acc, x): + assert isinstance(x, Dim) + assert x > acc + assert x < len(shape) + + return x + + reduce(_greater_than_reduce, broadcast_dimensions, -1) + + # shape must be broadcastable to + for idx, new_idx in enumerate(broadcast_dimensions): + if not guard_size_oblivious(a.shape[idx] == 1): + torch._check( + a.shape[idx] == shape[new_idx], + lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}", + ) + + new_strides = [] + original_idx = 0 + for idx in range(len(shape)): + if idx in broadcast_dimensions: + # Assigns a stride of zero to dimensions + # which were actually broadcast + if guard_size_oblivious(a.shape[original_idx] != shape[idx]): + new_strides.append(0) + else: + new_strides.append(a.stride()[original_idx]) + original_idx = original_idx + 1 + else: + if guard_size_oblivious(shape[idx] != 1): + new_strides.append(0) + elif original_idx == a.ndim: + new_strides.append(1) + else: + new_strides.append(a.stride()[original_idx] * a.size()[original_idx]) + + return a.as_strided(shape, new_strides, a.storage_offset()) + + +def _broadcast_in_dim_aten(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = v.unsqueeze(idx) + + return v.expand(shape) + + +_broadcast_in_dim_doc = """ + Creates a view of a with the specified shape. + + Allows adding dimensions of any length and broadcasting + dimensions of length one in a to any length. + + The location of the broadcast dimensions must be specified + using the broadcast_dimensions argument. Changing the + relative order of dimensions is not supported. 
+ """ + +broadcast_in_dim = _make_prim( + schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)", + meta=_broadcast_in_dim_meta, + impl_aten=_broadcast_in_dim_aten, + return_type=RETURN_TYPE.VIEW, + doc=_broadcast_in_dim_doc, +) + + +def _validate_collapse_args(a: Tensor, start: int, end: int) -> None: + # Special-case for zero dimensional tensors + ndim = max(1, a.dim()) + utils.validate_idx(ndim, start) + utils.validate_idx(ndim, end) + + # Verifies end is strictly greater than start + # (Collapse requires a non-empty interval) + torch._check_value( + end >= start, + lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!", + ) + + +def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]: + """ + Returns the shape of a with dims in [start, end) merged into a single dimension. + """ + # Special-case for zero dimensional tensors + shape = (1,) if len(shape) == 0 else tuple(shape) + + dim_length = 1 + for s in shape[start : end + 1]: + dim_length = dim_length * s + + return shape[0:start] + (dim_length,) + shape[end + 1 :] + + +def _collapse_view_helper( + a: TensorLikeType, start: int, end: int +) -> Tuple[Optional[ShapeType], Optional[StrideType]]: + assert isinstance(a, TensorLike) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + _validate_collapse_args(a, start, end) + + # Special-case for zero dimensional tensors + if a.ndim == 0: + shape = (1,) + strides = (1,) + else: + shape = a.shape # type: ignore[assignment] + strides = a.stride() # type: ignore[assignment] + + if a.ndim == 0 or (end == start): + return shape, strides + + length = shape[end] + stride = strides[end] + for idx in range(end - 1, start - 1, -1): + if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious( + shape[idx + 1] == 0 + ): + length = 0 + stride = 0 + break + + if guard_size_oblivious(shape[idx] == 1): + continue + + length = length * shape[idx] + stride = min(stride, strides[idx]) + + if ( + guard_size_oblivious(a.numel() > 0) + and guard_size_oblivious(shape[idx + 1] != 1) + and not guard_size_oblivious( + strides[idx] == strides[idx + 1] * shape[idx + 1] + ) + ): + return None, None + + new_shape = shape[:start] + (length,) + shape[end + 1 :] + new_strides = strides[:start] + (stride,) + strides[end + 1 :] + + # NOTE: when the input has no elements it's restrided as if it were contiguous + if guard_size_oblivious(a.numel() == 0): + new_strides = utils.make_contiguous_strides_for(new_shape) + + return new_shape, new_strides + + +def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType: + new_shape, new_strides = _collapse_view_helper(a, start, end) + + if new_shape is None: + msg = "Attempting to view a collapsed tensor, but no such view exists!" + raise ValueError(msg) + + assert new_strides is not None + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor: + new_shape = _collapsed_shape(a.shape, start, end) + return a.view(new_shape) + + +_collapse_view_doc = """ + Creates a view of a with the dimensions between + start (inclusive) and end (exclusive) merged into a + single dimension. + + If it's not possible to take such a view then an error + is thrown. See collapse instead. + + The dimensions can be merged if and only if + they are all "nested" with each other. 
That is, they all + have the property that + + stride[i] = stride[i+1] * shape[i+1] + + for all i in [start, end - 1). + """ + +collapse_view = _make_prim( + schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)", + meta=_collapse_view_meta, + impl_aten=_collapse_view_aten, + return_type=RETURN_TYPE.VIEW, + doc=_collapse_view_doc, +) + + +def _conj_meta(a: TensorLikeType) -> TensorLikeType: + if not a.dtype.is_complex: + raise RuntimeError("Expected complex dtype in prims.conj") + out = a.as_strided(a.shape, a.stride(), a.storage_offset()) + torch._C._set_conj(out, not a.is_conj()) + return out + + +_conj_doc = """ +Returns a conjugated view of the original tensor +""" + +conj = _make_prim( + schema="conj(Tensor(a) a) -> Tensor(a)", + meta=_conj_meta, + impl_aten=torch.conj, + return_type=RETURN_TYPE.VIEW, + doc=_conj_doc, +) + + +def expand_dims( + a: TensorLikeType, dimensions: DimsSequenceType, ndim=None +) -> TensorLikeType: + """ + Creates a view of a with a.ndim + len(dimensions) dimensions, with new + dimensions of length one at the dimensions specified by dimensions. + """ + if ndim is not None: + # TODO: this is only here to support the unsqueeze ref + dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type] + else: + dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type] + if len(set(dims)) != len(dims): + msg = f"Received duplicate dimensions to expand in {str(dimensions)}" + raise ValueError(msg) + + new_shape = list(a.shape) + for idx in dims: + new_shape.insert(idx, 1) + + broadcast_dimensions = [ + idx for idx in range(len(new_shape)) if idx not in dimensions + ] + return broadcast_in_dim(a, new_shape, broadcast_dimensions) + + +# Note: saves the Python slice object because we're about to clobber its name with the slice prim +pyslice: Type[slice] = slice # type: ignore[has-type] + + +def _slice_meta( + a: TensorLikeType, + start_indices: DimsSequenceType, + limit_indices: DimsSequenceType, + strides: Optional[StrideType] = None, +) -> TensorLikeType: + _strides = strides if strides is not None else [1] * len(start_indices) + + if a.ndim != len(start_indices): + msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!" + raise ValueError(msg) + + if a.ndim != len(limit_indices): + msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!" + raise ValueError(msg) + + if a.ndim != len(_strides): + msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(limit_indices)}!" + raise ValueError(msg) + + for x, y in zip(start_indices, a.shape): + if x < 0: + msg = f"Attempting to slice a tensor with a negative start index of {x}!" + raise ValueError(msg) + if x > y: + msg = ( + f"Attempting to slice a tensor but a start index in {start_indices} is greater than" + f" the length of its corresponding dimension in shape {a.shape}" + ) + raise ValueError(msg) + + for x, y, z in zip(limit_indices, a.shape, start_indices): + if x < 0: + msg = f"Attempting to slice a tensor with a negative stop index of {x}!" 
+ raise ValueError(msg) + if x > y: + msg = ( + f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of " + f" its corresponding dimension in shape {a.shape}" + ) + raise ValueError(msg) + if x < z: + msg = ( + f"Attempting to slice a tensor but a start index in {x} is greater than " + f" its corresponding stop index {z}" + ) + + for x in _strides: + if x <= 0: + msg = f"Attempting to slice a tensor with a non-positive step of {x}!" + raise ValueError(msg) + + new_shape = [] + for x, y, z in zip(start_indices, limit_indices, _strides): + new_shape.append(1 + (y - x - 1) // z) + + new_strides = [] + for x, y in zip(a.stride(), _strides): + new_strides.append(x * y) + + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +def _slice_aten( + a: Tensor, + start_indices: DimsSequenceType, + limit_indices: DimsSequenceType, + strides: Optional[StrideType] = None, +) -> Tensor: + _strides = strides if strides is not None else [1] * len(start_indices) + + slices = [] + for start, stop, step in zip(start_indices, limit_indices, _strides): + slices.append(pyslice(start, stop, step)) + + return operator.getitem(a, slices) # type: ignore[call-overload] + + +_slice_doc = """ + Creates a view of a "bounding box" within the tensor. + + The bounding box is specified independently in each of the tensor's dimensions. + start_indices and limit_indices describe the box's boundaries for their corresponding + dimensions. If strides is specified then they specify the step size between elements + in their corresponding dimension. + + This operation is analogous to slicing in NumPy, but does not permit slices where + the stop indices are less than the start indices. + """ + +slice = _make_prim( + schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)", + meta=_slice_meta, + impl_aten=_slice_aten, + return_type=RETURN_TYPE.VIEW, + doc=_slice_doc, +) + + +def _slice_in_dim_meta( + a: TensorLikeType, + start_index: int, + limit_index: int, + stride: int = 1, + axis: int = 0, +) -> TensorLikeType: + if axis < 0: + msg = f"slice_in_dim: received a negative axis {axis}" + raise ValueError(msg) + if axis >= a.ndim: + msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor" + raise ValueError(msg) + + if start_index < 0: + msg = f"slice_in_dim: received a negative start_index {start_index}" + raise ValueError(msg) + + if start_index > a.shape[axis]: + msg = f"slice_in_dim: start_index is greater than the length {start_index} of dimension {axis}" + raise ValueError(msg) + + if limit_index > a.shape[axis]: + msg = f"slice_in_dim: limit_index is greater than the length {limit_index} of dimension {axis}" + raise ValueError(msg) + + if limit_index < start_index: + msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}" + raise ValueError(msg) + + if stride < 0: + msg = f"slice_in_dim: received a non-positive stride of {stride}!" 
+ raise ValueError(msg) + + start_indices = [0] * a.ndim + limit_indices = list(a.shape) + strides = [1] * a.ndim + + start_indices[axis] = start_index + limit_indices[axis] = limit_index + strides[axis] = stride + + return _slice_meta(a, start_indices, limit_indices, strides) + + +def _slice_in_dim_aten( + a: Tensor, + start_index: int, + limit_index: int, + stride: int = 1, + axis: int = 0, +) -> Tensor: + start_indices = [0] * a.ndim + limit_indices = list(a.shape) + strides = [1] * a.ndim + + start_indices[axis] = start_index + limit_indices[axis] = limit_index + strides[axis] = stride + + return slice(a, start_indices, limit_indices, strides) + + +_slice_in_dim_doc = """ + Convenience wrapper for slicing just one dimension using slice. + """ + +# TODO: make stride SymInt +slice_in_dim = _make_prim( + schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)", + meta=_slice_in_dim_meta, + impl_aten=_slice_in_dim_aten, + return_type=RETURN_TYPE.VIEW, + doc=_slice_in_dim_doc, +) + + +def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType: + assert isinstance(a, TensorLike) + utils.validate_idx(a.ndim, dim) + utils.validate_dim_length(outer_length) + + # Verifies the dim can be split with the specified lhs_length + inner_length = a.shape[dim] // outer_length + + if (a.shape[dim] % outer_length) != 0: + msg = "Attempting to split dimension of length {}, but outer length of {} divides it with a remainder!".format( + a.shape[dim], outer_length + ) + raise ValueError(msg) + + new_shape: List[int] = [] + new_strides: List[int] = [] + for idx in range(a.ndim): + if idx == dim: + new_shape.extend((outer_length, inner_length)) + new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx])) + else: + new_shape.append(a.shape[idx]) + new_strides.append(a.stride()[idx]) + + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor: + inner_length = a.shape[dim] // outer_length + new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :] + + return a.view(new_shape) + + +_split_dim_doc = """ + Creates a view of a with the given dimension (of length l) split + into two dimensions, with the outer of the two having + length outer_length and the inner of the two having computed + length inner_length such outer_length * inner_length = l. + """ + +# TODO: consider renaming split_dim_view +split_dim = _make_prim( + schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)", + meta=_split_dim_meta, + impl_aten=_split_dim_aten, + return_type=RETURN_TYPE.VIEW, + doc=_split_dim_doc, +) + + +# Note: allows dimensions to be specified redundantly +def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType: + assert isinstance(a, TensorLike) + + for idx in dimensions: + utils.validate_idx(a.ndim, idx) + assert a.shape[idx] == 1 + + new_shape = [] + new_strides = [] + for idx in range(len(a.shape)): + if idx in dimensions: + continue + + new_shape.append(a.shape[idx]) + new_strides.append(a.stride()[idx]) + + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +_squeeze_doc = """ + Creates a view of the tensor with the specified dimensions removed. + + The removed dimensions must each have length one. 
+ """ + +squeeze = _make_prim( + schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)", + meta=_squeeze_meta, + impl_aten=torch.squeeze, + return_type=RETURN_TYPE.VIEW, + doc=_squeeze_doc, +) + + +def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType: + if a.ndim != len(permutation): + msg = "Attempting to permute a tensor of rank {}, but received a permutation of length {}!".format( + a.ndim, len(permutation) + ) + raise ValueError(msg) + + if not utils.is_valid_permutation(a.ndim, permutation): + msg = f"Received an invalid permutation, {permutation}!" + raise ValueError(msg) + + new_shape = [0] * a.ndim + new_strides = [0] * a.ndim + for idx, dim in enumerate(permutation): + new_shape[idx] = a.shape[dim] + new_strides[idx] = a.stride()[dim] + + return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset()) + + +def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor: + return torch.permute(a, permutation) + + +_transpose_doc = """ + Creates a view of the tensor with its dimensions permuted. + + The length of the permutation must be the rank of the tensor, + and each element of the permutation specifies the new order + for the corresponding dimension. + """ + +transpose = _make_prim( + schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)", + meta=_transpose_meta, + impl_aten=_transpose_aten, + return_type=RETURN_TYPE.VIEW, + doc=_transpose_doc, +) + + +def _view_of_meta(a: TensorLikeType) -> TensorLikeType: + return a.as_strided(a.shape, a.stride(), a.storage_offset()) + + +def _view_of_aten(a: Tensor) -> Tensor: + return a.view(a.shape) + + +_view_of_doc = """ + Creates a view of the tensor. + """ + +view_of = _make_prim( + schema="view_of(Tensor(a) a) -> Tensor", + meta=_view_of_meta, + impl_aten=_view_of_aten, + return_type=RETURN_TYPE.VIEW, + doc=_view_of_doc, +) + + +def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + return a.view(dtype) + + +def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: + return a.view(dtype) + + +_view_element_type_doc = """ + Creates a view of the tensor with a different dtype. + """ + +view_element_type = _make_prim( + schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor", + meta=_view_element_type_meta, + impl_aten=_view_element_type_aten, + return_type=RETURN_TYPE.VIEW, + doc=_view_element_type_doc, +) + +# +# Functionalized view mutations +# + + +def _as_strided_scatter_meta( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: int, +) -> TensorLikeType: + utils.validate_shape(size) + utils.validate_strides(stride) + + required_size = utils.compute_required_storage_length(size, stride, storage_offset) + torch._check( + input.numel() >= required_size, + lambda: ( + f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} " + f" and itemsize {input.element_size()} requiring a storage size of " + f"{required_size * input.element_size()} are out of bounds " + f"for storage of size {input.numel() * input.element_size()}" + ), + ) + torch._check( + utils.is_same_shape(src.shape, size), + lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}", + ) + + return utils.clone_preserve_strides(input) + + +_as_strided_scatter_doc = """ + Creates a new tensor equivalent to ``out = input.clone()`` after mutation by + ``out.as_strided(size, stride, storage_offset).copy_(src)``. 
+""" + +as_strided_scatter = _make_prim( + schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor", + meta=_as_strided_scatter_meta, + impl_aten=torch.as_strided_scatter, + return_type=RETURN_TYPE.NEW, + doc=_as_strided_scatter_doc, +) + + +# +# Shape operations +# + + +def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor: + # Special-case for zero dimensional tensors + _validate_collapse_args(a, start, end) + new_shape = _collapsed_shape(a.shape, start, end) + return a.new_empty(new_shape) + + +def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor: + new_shape = _collapsed_shape(a.shape, start, end) + out = a.new_empty(new_shape) + with torch.no_grad(): + out.view_as(a).copy_(a) + return out + + +_collapse_doc = """ +Collapse a span of neighboring dimensions into one. + +See collapse_view for the corresponding view operation. +""" +collapse = _make_prim( + schema="collapse(Tensor a, int start, int end) -> Tensor", + meta=_collapse_meta, + impl_aten=_collapse_aten, + return_type=RETURN_TYPE.NEW, + doc=_collapse_doc, +) + + +# TODO: review stride logic +# NB: unlike torch.cat, this is more strict about empty tensors and dim is +# never negative +def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType: + # Verifies same shape (except in the concat dimension) + assert dim >= 0 + shape = tensors[0].shape + concat_length = 0 + for tensor_idx, tensor in enumerate(tensors): + assert len(shape) == len(tensor.shape) + for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): + if idx == dim: + concat_length = concat_length + length + else: + torch._check( + length == common_length, + lambda: f"Sizes of tensors must match except in dimension {dim}. " + f"Expected {common_length} but got {length} for tensor number " + f"{tensor_idx} in the list", + ) + + new_shape = list(tensors[0].shape).copy() + new_shape[dim] = concat_length + return TensorMeta( + tensors[0], + shape=new_shape, + strides=utils.make_contiguous_strides_for(new_shape), + ) + + +def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor: + return torch.cat(tensors, dim) + + +_cat_doc = """ + Concatenates tensors along the specified dimension. + + The tensors' shapes must have the same rank and same length for other dimensions. + """ + +cat = _make_prim( + schema="cat(Tensor[] tensors, int dim) -> Tensor", + meta=_cat_meta, + impl_aten=_cat_aten, + return_type=RETURN_TYPE.NEW, + doc=_cat_doc, +) + + +def _reshape_meta(a: TensorLikeType, shape: ShapeType): + assert isinstance(a, TensorLike) + utils.validate_shape(shape) + + # Validates the tensor and the requested shape have the + # same number of elements + numel = reduce(operator.mul, shape) + if numel != a.numel(): + msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!" + raise ValueError(msg) + + return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape)) + + +def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor: + return a.reshape(shape).contiguous().clone() + + +_reshape_doc = """ + Creates a contiguous tensor with the specified shape + containing a copy of the data in a. 
+ """ +reshape = _make_prim( + schema="reshape(Tensor a, SymInt[] shape) -> Tensor", + meta=_reshape_meta, + impl_aten=_reshape_aten, + return_type=RETURN_TYPE.NEW, + doc=_reshape_doc, +) + + +def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + utils.validate_dimension_indices(a.ndim, dims) + return torch.empty_like(a, memory_format=torch.preserve_format) + + +_rev_doc = """ + Reverses the order of elements along the given dimensions. + """ + +rev = _make_prim( + schema="rev(Tensor a, int[] dims) -> Tensor", + meta=_rev_meta, + impl_aten=torch.flip, + return_type=RETURN_TYPE.NEW, + doc=_rev_doc, +) + +# +# Conditional prims +# + + +def _where_meta( + pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType +) -> TensorLikeType: + return _prim_elementwise_meta( + a, + b, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, + args_with_fixed_dtypes=(pred,), + ) + + +_where_doc = """ + Selects elements from a and b according to pred. + + Where pred is true the result contains the element from a, and + where pred is false the result contains the element from b. + """ + +where = _make_prim( + schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor", + meta=_where_meta, + impl_aten=torch.where, + return_type=RETURN_TYPE.NEW, + doc=_where_doc, +) + + +# +# Type conversions +# +def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + # Type checks + assert isinstance(a, TensorLike) + assert isinstance(dtype, torch.dtype) + + # dtype conversion preserves dense strides + if torch._prims_common.is_non_overlapping_and_dense(a): + strides = a.stride() + else: + strides = utils.compute_elementwise_output_strides(a) + + return TensorMeta(a, strides=strides, dtype=dtype) + + +def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: + # Propagates requires grad when possible + if not utils.is_grad_dtype(dtype): + requires_grad = False + else: + # TODO: update meta objects so this can be acquired directly + try: + requires_grad = a.requires_grad + except Exception as e: + requires_grad = False + + result = torch.empty_like( + a, device=a.device, dtype=dtype, requires_grad=requires_grad + ) + with torch.no_grad(): + return copy_to(result, a) + + +_convert_element_type_doc = """ + Creates a copy of a tensor with the given dtype. + """ + +convert_element_type = _make_prim( + schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor", + meta=_convert_element_type_meta, + impl_aten=_convert_element_type_aten, + return_type=RETURN_TYPE.NEW, + doc=_convert_element_type_doc, + tags=(torch.Tag.pointwise,), +) + + +def _device_put_meta( + a: TensorLikeType, device: Union[str, torch.device] +) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(device, (str, torch.device)) + + return TensorMeta(a, device=utils.canonicalize_device(device)) + + +def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor: + return a.to(device) + + +_device_put_doc = """ + Creates a copy of a tensor on the given device. 
+ """ + +device_put = _make_prim( + schema="device_put(Tensor a, Device device) -> Tensor", + meta=_device_put_meta, + impl_aten=_device_put_aten, + return_type=RETURN_TYPE.NEW, + doc=_device_put_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _item_meta(a: TensorLikeType) -> FakeTensor: + number_type = utils.dtype_to_type(a.dtype) + return TensorMeta(number_type(-1)) + + +_item_doc = """ + Converts a tensor with one element to a Python number. +""" + +# TODO: create a new return type for scalars? +# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +item = _make_prim( + schema="item(Tensor a) -> Scalar", + meta=_item_meta, + impl_aten=torch.Tensor.item, + return_type=RETURN_TYPE.NEW, + doc=_item_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor: + number_type = utils.dtype_to_type(dtype) + return TensorMeta(number_type(-1)) + + +def _maximum_value_aten(dtype: torch.dtype): + if dtype == torch.bool: + return True + elif dtype.is_complex or dtype.is_floating_point: + return torch.finfo(dtype).max + else: + return torch.iinfo(dtype).max + + +_maximum_value_doc = """ + Return the maximum finite value for a dtype. +""" + +# TODO: create a new return type for scalars? +# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +maximum_value = _make_prim( + schema="maximum_value(ScalarType dtype) -> Scalar", + meta=_maximum_value_meta, + impl_aten=_maximum_value_aten, + return_type=RETURN_TYPE.NEW, + doc=_maximum_value_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor: + number_type = utils.dtype_to_type(dtype) + return TensorMeta(number_type(-1)) + + +def _minimum_value_aten(dtype: torch.dtype): + if dtype == torch.bool: + return False + elif dtype.is_complex or dtype.is_floating_point: + return torch.finfo(dtype).min + else: + return torch.iinfo(dtype).min + + +_minimum_value_doc = """ + Return the minimum finite value for a dtype. +""" + +# TODO: create a new return type for scalars? +# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +minimum_value = _make_prim( + schema="minimum_value(ScalarType dtype) -> Scalar", + meta=_minimum_value_meta, + impl_aten=_minimum_value_aten, + return_type=RETURN_TYPE.NEW, + doc=_minimum_value_doc, +) + +# +# Inplace operators +# + + +def _copy_to_meta(a: TensorLikeType, b: TensorLikeType): + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + # Validates the cast is safe + # TODO: move this as an option on the reference + # a_typ = utils.dtype_to_type(a.dtype) + # b_typ = utils.dtype_to_type(b.dtype) + # if a_typ is not utils.get_higher_type(a_typ, b_typ): + # raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!") + + # Validates the tensors have the same number of elements + if a.numel() != b.numel(): + msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!" + raise RuntimeError(msg) + + return a + + +def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor: + return a.copy_(b) + + +_copy_to_doc = """ + Copies the data in b to a and returns the modified a. 
+ """ + +# TODO: Remove safe casting and implement on reference instead +copy_to = _make_prim( + schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)", + meta=_copy_to_meta, + impl_aten=_copy_to_aten, + return_type=RETURN_TYPE.INPLACE, + doc=_copy_to_doc, +) + + +def _copy_strided_meta(a: TensorLikeType, stride: ShapeType): + assert isinstance(a, TensorLike) + return torch.empty_strided( + a.shape, + stride, + dtype=a.dtype, + layout=a.layout, + device=a.device, + requires_grad=a.requires_grad, + ) + + +def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor: + out = torch.empty_strided( + a.size(), + stride=stride, + dtype=a.dtype, + layout=a.layout, + device=a.device, + requires_grad=a.requires_grad, + ) + out.copy_(a) + return out + + +_copy_strided_doc = """ + Copies the data in a to a new tensor, the new tensor has same shape with a size, but has different stride. + """ + + +copy_strided = _make_prim( + schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor", + meta=_copy_strided_meta, + impl_aten=_copy_strided_aten, + return_type=RETURN_TYPE.NEW, + doc=_copy_strided_doc, +) + + +def _resize_meta(a: TensorLikeType, shape: ShapeType): + return a.resize_(shape) + + +def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor: + return a.resize_(shape) + + +_resize_doc = """ + Gives a tensor with no elements a new shape, returning the modified tensor. + + The tensor's strides are contiguous and its values are unitialized. + """ + +# TODO: review support arbitrary resizes +resize = _make_prim( + schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)", + meta=_resize_meta, + impl_aten=_resize_aten, + return_type=RETURN_TYPE.INPLACE, + doc=_resize_doc, +) + + +def _reduction_meta(inp, dims, *, output_dtype=None): + """ + Meta function for single output reduction operations + Stride logic is incorrect + """ + assert isinstance(inp, TensorLike) + if output_dtype is None: + output_dtype = inp.dtype + output_shape = utils.compute_reduction_output_shape(inp.shape, dims) + return TensorMeta( + shape=output_shape, + strides=utils.make_contiguous_strides_for(output_shape), + dtype=output_dtype, + device=inp.device, + ) + + +def _var_reduction_meta(inp, dims, *, correction): + if utils.is_complex_dtype(inp.dtype): + output_dtype = utils.corresponding_real_dtype(inp.dtype) + else: + output_dtype = inp.dtype + return _reduction_meta(inp, dims, output_dtype=output_dtype) + + +_sum_doc = """ + Computes the sum of elements in the input tensor over the list of dimensions + specified in the dim argument + """ +_xor_sum_doc = """ + Computes the xor sum of elements in the input tensor over the list of dimensions + specified in the dim argument + """ +_prod_doc = """ + Computes the product of elements in the input tensor over the list of dimensions + specified in the dim argument + """ +_amax_doc = """ + Computes the maximum value of elements in the input tensor over the list of dimensions + specified in the dim argument + """ +_amin_doc = """ + Computes the minimum value of elements in the input tensor over the list of dimensions + specified in the dim argument + """ +_var_doc = """ + Computes the biased variance of x over the list of dimensions specified in the dim argument + """ + + +def _make_reduction_prim(name: str, impl_aten, doc): + """Creates a reduction prim.""" + return _make_prim( + schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? 
output_dtype=None) -> Tensor", + meta=_reduction_meta, + impl_aten=impl_aten, + return_type=RETURN_TYPE.NEW, + doc=doc, + ) + + +def _make_var_reduction_prim(name: str, impl_aten, doc): + """Creates a reduction prim.""" + return _make_prim( + schema=f"{name}(Tensor inp, int[]? dims, *, float correction, ScalarType? output_dtype=None) -> Tensor", + meta=_var_reduction_meta, + impl_aten=impl_aten, + return_type=RETURN_TYPE.NEW, + doc=doc, + ) + + +sum = _make_reduction_prim( + name="sum", + impl_aten=torch.sum, + doc=_sum_doc, +) + + +def _xor_sum_aten( + inp: TensorLikeType, + dims: Optional[DimsSequenceType], + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + raise NotImplementedError("xor_sum only implemented with inductor") + + +xor_sum = _make_reduction_prim( + name="xor_sum", + impl_aten=_xor_sum_aten, + doc=_xor_sum_doc, +) + + +def _prod_aten( + inp: TensorLikeType, + dims: Optional[DimsSequenceType], + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + if dims is not None: + for d in sorted(dims, reverse=True): + assert d >= 0 + inp = torch.prod(inp, d, dtype=dtype) + return inp + else: + return torch.prod(inp, dims, dtype=dtype) + + +prod = _make_reduction_prim( + name="prod", + impl_aten=_prod_aten, + doc=_prod_doc, +) + +var = _make_var_reduction_prim( + name="var", + impl_aten=torch.var, + doc=_var_doc, +) + +amax = _make_reduction_prim( + name="amax", + impl_aten=torch.amax, + doc=_amax_doc, +) + +amin = _make_reduction_prim( + name="amin", + impl_aten=torch.amin, + doc=_amin_doc, +) + + +_iota_doc = """ + Constructs a 1-D tensor t where ``t[i] == start + i * step``. +""" + + +# TODO: layout, pin_memory, memory_format +# TODO: model requires_grad on TensorMeta +def _iota_meta( + length: int, + *, + start: int, + step: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + torch._check( + utils.is_integer_dtype(dtype), + lambda: "prims.iota only supports integer dtypes", + ) + torch._check(step != 0, lambda: "step must be nonzero") + return torch.empty( + length, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def _iota_aten( + length: int, + *, + start: int, + step: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + end = start + length * step + return torch.arange( + start, end, step, dtype=dtype, device=device, requires_grad=requires_grad + ) + + +iota = _make_prim( + schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + return_type=RETURN_TYPE.NEW, + meta=_iota_meta, + impl_aten=_iota_aten, + doc=_iota_doc, +) + + +# TODO: layout, pin_memory, memory_format +# TODO: model requires_grad on TensorMeta +def _empty_meta( + shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _empty_aten( + shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool +) -> Tensor: + return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad) + + +_empty_doc = """ + Creates a tensor with uninitialized values and the specified shape, dtype, and device. 
+""" + +empty = _make_prim( + schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_empty_meta, + impl_aten=_empty_aten, + return_type=RETURN_TYPE.NEW, + doc=_empty_doc, +) + + +def _empty_strided_meta( + shape: ShapeType, + strides: StrideType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +_empty_strided_doc = """ + Creates a tensor with uninitialized values. +""" + +# TODO: add layout, pin_memory +empty_strided = _make_prim( + schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + return_type=RETURN_TYPE.NEW, + meta=_empty_strided_meta, + impl_aten=torch.empty_strided, + doc=_empty_strided_doc, +) + + +def _empty_permuted_meta( + shape: ShapeType, + physical_layout: DimsSequenceType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout]) + dim = len(shape) + torch._check( + len(physical_layout) == dim, + lambda: ( + "Number of dimensions in the tensor input does not match the " + f"length of the physical layout; i.e. len(size) = {dim} " + f"is not equal to len(physical_layout) = {len(physical_layout)}" + ), + ) + strides = [0] * len(shape) + seen_dims = set() + for p, l in enumerate(physical_layout): + torch._check( + 0 <= l < dim, + lambda: ( + f"Dimension out of range (expected to be between 0 and {dim - 1}, but got " + f"{l} at index {p}). NB: negative dims " + "not currently supported; file an issue if you want it." + ), + ) + torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed") + strides[l] = p_strides[p] + seen_dims.add(l) + return TensorMeta( + shape=shape, + strides=strides, + dtype=dtype, + device=device, + ) + + +_empty_permuted_doc = """ + Creates a tensor with uninitialized values according to some physical layout, + that is guaranteed to be non-overlapping and dense. +""" + +# TODO: add layout, pin_memory +empty_permuted = _make_prim( + schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + return_type=RETURN_TYPE.NEW, + meta=_empty_permuted_meta, + impl_aten=torch.empty_permuted, + doc=_empty_permuted_doc, +) + + +def _full_meta( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _full_aten( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> Tensor: + # Note that Mypy thinks torch.full can't accept a complex fill_value + return torch.full( + shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + ) + + +_full_doc = """ + Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device. 
+""" + +# TODO: add layout +full = _make_prim( + schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_full_meta, + impl_aten=_full_aten, + return_type=RETURN_TYPE.NEW, + doc=_full_doc, +) + + +def _full_like_meta( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + strides = utils.compute_elementwise_output_strides(a) + if a.numel() == 0: + strides = a.stride() + + return TensorMeta(a, strides=strides, dtype=dtype, device=device) + + +def _full_like_aten( + a: Tensor, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> Tensor: + # Note that Mypy thinks torch.full can't accept a complex fill_value + return torch.full_like( + a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + ) + + +_full_like_doc = """ + Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the + given tensor by default. The dtype and device settings can be overridden + by specifying them explicitly. +""" + +full_like = _make_prim( + schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_full_like_meta, + impl_aten=_full_like_aten, + return_type=RETURN_TYPE.NEW, + doc=_full_like_doc, +) + + +def _scalar_tensor_meta( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> TensorLikeType: + shape: ShapeType = [] + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device) + + +def _scalar_tensor_aten( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> Tensor: + if isinstance(scalar, complex) and ( + dtype is None or not utils.is_complex_dtype(dtype) + ): + raise TypeError("Complex scalar requires complex tensor dtype.") + # Note that Mypy thinks torch.scalar can't accept a complex scalar + return torch.scalar_tensor(scalar, dtype=dtype, device=device) # type: ignore[arg-type] + + +_scalar_tensor_doc = """ + Wraps a Number into a Tensor with the specified dtype and device. +""" + +# TODO: add layout and pin_memory support +scalar_tensor = _make_prim( + schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? 
device=None) -> Tensor", + meta=_scalar_tensor_meta, + impl_aten=_scalar_tensor_aten, + return_type=RETURN_TYPE.NEW, + doc=_scalar_tensor_doc, +) + + +# +# Linear algebra (linalg) prims +# + + +def _svd_meta( + A: TensorLikeType, *, full_matrices: bool +) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]: + utils.check_is_matrix(A, "linalg.svd") + utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False) + + A_shape = A.shape + batch = A_shape[:-2] + m, n = A_shape[-2:] + k = min(m, n) + + shape_U = batch + (m, m if full_matrices else k) + strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False) + U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device) + + shape_S = batch + (k,) + strides_S = utils.make_contiguous_strides_for(shape_S) + S = TensorMeta( + shape=shape_S, + strides=strides_S, + dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype, + device=A.device, + ) + + shape_Vh = batch + (n if full_matrices else k, n) + # The CPU backend returns V, but the cuSolver backend returns V^H + # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend + is_cuda = A.device.type == "cuda" + strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda) + Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device) + # Also makes sure this is CUDA or HIP: + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip + if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available(): + Vh = Vh.conj() + return U, S, Vh + + +def _svd_aten( + A: TensorLikeType, *, full_matrices: bool +) -> Tuple[Tensor, Tensor, Tensor]: + return torch.linalg.svd(A, full_matrices=full_matrices) + + +_svd_doc = """ + Returns the SVD of a matrix or batch of matrices. + + The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned. +""" + +svd = _make_prim( + schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)", + meta=_svd_meta, + impl_aten=_svd_aten, + return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW), + doc=_svd_doc, +) + + +# +# Randomness Prims +# + + +def _normal_meta( + shape: ShapeType, + *, + mean: Union[float, complex], + std: float, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, + generator: Optional[torch.Generator] = None, +) -> TensorLikeType: + torch._check( + std >= 0.0, + lambda: f"expected non-negative standard deviation, but got std={std}", + ) + + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}", + ) + + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _normal_aten( + shape: ShapeType, + *, + mean: Union[float, complex], + std: float, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, + generator: Optional[torch.Generator] = None, +) -> Tensor: + a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad) + with torch.no_grad(): + # NOTE: normal_ is incorrectly annotated to expect mean to be a float + a.normal_(mean, std, generator=generator) # type: ignore[arg-type] + return a + + +_normal_doc = """ + Constructs a tensor filled with values drawn from a normal distribution with the specified mean + and standard deviation. + + Only supports floating-point types. 
+""" + +normal = _make_prim( + schema=( + "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" # noqa: B950 + ), + return_type=RETURN_TYPE.NEW, + meta=_normal_meta, + impl_aten=_normal_aten, + doc=_normal_doc, +) + + +def _uniform_meta( + shape: ShapeType, + *, + low: float, + high: float, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _uniform_aten( + shape: ShapeType, + *, + low: float, + high: float, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, +) -> Tensor: + a = torch.empty(shape, dtype=dtype, device=device) + a.uniform_(low, high, generator=generator) + return a + + +_uniform_doc = """ + Constructs a tensor filled with values drawn uniformly from low to high. +""" + +# TODO: we should more seriously review randomness modeling and prims +_uniform_helper = _make_prim( + schema=( + "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor" + ), + return_type=RETURN_TYPE.NEW, + meta=_uniform_meta, + impl_aten=_uniform_aten, + doc=_uniform_doc, +) + +# +# FFT prims +# + + +def _fft_r2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + if onesided: + last_dim = dim[-1] + shape[last_dim] = shape[last_dim] // 2 + 1 + + dtype = utils.corresponding_complex_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_r2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_r2c(input, dim, normalization, onesided) + + +_fft_r2c_doc = """ + Performs a real to complex Fast Fourier Transform +""" + + +fft_r2c = _make_prim( + schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor", + meta=_fft_r2c_meta, + impl_aten=_fft_r2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_r2c_doc, +) + + +def _fft_c2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = input.shape + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta( + shape=shape, strides=strides, dtype=input.dtype, device=input.device + ) + + +def _fft_c2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2c(input, dim, normalization, forward) + + +_fft_c2c_doc = """ + Performs either a Fast Fourier Transform, or its inverse +""" + + +fft_c2c = _make_prim( + schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor", + meta=_fft_c2c_meta, + impl_aten=_fft_c2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2c_doc, +) + + +def _fft_c2r_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + shape[dim[-1]] = last_dim_size + 
dtype = utils.corresponding_real_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_c2r_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2r(input, dim, normalization, last_dim_size) + + +_fft_c2r_doc = """ + Performs a complex to real Inverse Fast Fourier Transform +""" + + +fft_c2r = _make_prim( + schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor", + meta=_fft_c2r_meta, + impl_aten=_fft_c2r_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2r_doc, +) + + +def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]: + torch._check( + self.dtype.is_floating_point, + lambda: "torch.frexp() only supports floating-point dtypes", + ) + return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32) + + +frexp = _make_prim( + schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)", + meta=_frexp_meta, + return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW), + impl_aten=torch.frexp, + doc="", +) + +register_rng_prims() +register_debug_prims() diff --git a/MLPY/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..218174476de6f83f84715c688b87f49650e96e96 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9a7c59eb66a5b2d42282a69a0eebe343390ef2f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48173561973e21da3eb29fe285b82f4e2304bacb Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..061186fd921544dd67f5a5c0f8a2199da265c8b7 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334089afc3d976529926efe748cc60c3d99d5a5e Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims/context.py b/MLPY/Lib/site-packages/torch/_prims/context.py new file mode 100644 index 0000000000000000000000000000000000000000..8cf0104178568444168792cea4ffcfbd6c516357 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims/context.py @@ -0,0 +1,144 @@ +import functools +from contextlib import nullcontext +from typing import Any, Callable, Dict, Optional, Sequence + 
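+# Note: the torch._refs* imports below are needed so that torch_to_refs_map()
+# can walk each reference module's __all__ when it builds the torch -> _refs
+# mapping consumed by TorchRefsMode.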
+import torch + +import torch._decomp +import torch._prims + +import torch._refs +import torch._refs.nn +import torch._refs.nn.functional +import torch._refs.special +import torch.overrides + +from torch._prims_common import torch_function_passthrough + + +@functools.lru_cache(None) +def torch_to_refs_map(): + """ + Mapping of torch API functions to torch._refs functions. + E.g. torch_to_refs_map()[torch.add] == torch._refs.add + """ + modules = [ + (torch, torch._refs), + (torch.nn, torch._refs.nn), + (torch.nn.functional, torch._refs.nn.functional), + (torch.special, torch._refs.special), + (torch.fft, torch._refs.fft), + (torch.linalg, torch._refs.linalg), + ] + r: Dict[Any, Any] = { + torch.Tensor.__invert__: torch._refs.bitwise_not, + torch.Tensor.__xor__: torch._refs.bitwise_xor, + torch.Tensor.__and__: torch._refs.bitwise_and, + torch.Tensor.__or__: torch._refs.bitwise_or, + torch.Tensor.__eq__: torch._refs.eq, + torch.Tensor.__rsub__: torch._refs.rsub, + torch.Tensor.__rtruediv__: torch._refs.rtruediv, + torch.Tensor.__floordiv__: torch._refs.floor_divide, + torch.Tensor.__rfloordiv__: torch._refs.rfloordiv, + torch.Tensor.__pow__: torch._refs.pow, + torch.Tensor.__rpow__: torch._refs.rpow, + torch.Tensor.new_empty: torch._refs.new_empty, + torch.Tensor.new_full: torch._refs.new_full, + torch.Tensor.new_zeros: torch._refs.new_zeros, + torch.Tensor.new_ones: torch._refs.new_ones, + torch.Tensor.fill_: torch._refs.fill_, + torch.Tensor.zero_: torch._refs.zero_, + torch.Tensor.to: torch._refs.to, + torch.Tensor.sum_to_size: torch._refs.sum_to_size, + # TODO: Should these methods be mapped some other way? + torch.Tensor.copy_: torch._prims.copy_to, + torch.Tensor.resize: torch._prims.resize, + } + for mod_torch, mod_refs in modules: + for s in mod_refs.__all__: # type: ignore[attr-defined] + r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s) + + # Support remapping torch.Tensor.foo to _refs.foo + for s in dir(torch.Tensor): + if s in torch._refs.__all__: + r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s) + + # Support conversions + for s in torch._refs._conversions.__all__: + tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s) + r[tensor_attr] = torch._refs._conversions.__dict__.get(s) + + return r + + +@functools.lru_cache(None) +def all_prims(): + """ + Set of all prim functions, e.g., torch._prims.add in all_prims() + """ + return {torch._prims.__dict__.get(s) for s in torch._prims.__all__} + + +class TorchRefsMode(torch.overrides.TorchFunctionMode): + """ + Switches the interpretation of torch.* functions and Tensor methods to + use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.) + + >>> # xdoctest: +SKIP + >>> with TorchRefsMode(): + ... torch.add(x, y) # calls torch._refs.add(x, y) + + By default, this context manager will fall back on the torch.* if the + ref does not exist; set strict=True to error if this occurs. + If the ref exists we still would like to fall back on the torch.* sometimes, + this behavior can be customized by passing a function to should_fallback_fn. 
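+
+    The fallback predicate receives (mode, orig_func, func, args, kwargs) and
+    returns True to run the original torch.* callable instead of the ref. An
+    illustrative sketch (keep_add_eager is a hypothetical predicate):
+
+    >>> # xdoctest: +SKIP
+    >>> def keep_add_eager(mode, orig_func, func, args, kwargs):
+    ...     return orig_func is torch.add
+    >>> with TorchRefsMode(should_fallback_fn=keep_add_eager):
+    ...     torch.mul(x, y)  # calls torch._refs.mul; torch.add would stay eager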
+ """ + + def __init__( + self, + strict=False, + should_fallback_fn=lambda *_: False, + prims_mode_cls=nullcontext, + ): + self.strict = strict + self.should_fallback_fn = should_fallback_fn + self.prims_mode_cls = prims_mode_cls + + def __torch_function__( + self, + orig_func: Callable, + types: Sequence, + args: Sequence[Any] = (), + kwargs: Optional[Dict] = None, + ): + if kwargs is None: + kwargs = {} + # For primitive operations, run them as is without interception + # Unless we are in prims_mode, in which case we want to use nvprims + if orig_func in torch_function_passthrough or orig_func in all_prims(): + with self.prims_mode_cls(): + return orig_func(*args, **kwargs) + mapping = torch_to_refs_map() + func = mapping.get(orig_func, None) + + # For torch.ops.aten.*, use registered decompositions from torch._decomp + # torch._decomp.decomposition_table provides a mapping from + # torch.ops.aten.* to torch._refs or torch._decomp.decompositions + # implementations. + # There're other ways to implement this functionality, + # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417 + if func is None and isinstance(orig_func, torch._ops.OpOverload): + func = torch._decomp.decomposition_table.get(orig_func, None) + + if func is not None: + # If the ref exists query whether we should use it or not + if self.should_fallback_fn(self, orig_func, func, args, kwargs): + return orig_func(*args, **kwargs) + # torch calls inside func should be interpreted as refs calls + with self: + return func(*args, **kwargs) + if self.strict: + raise RuntimeError( + f"no _refs support for {torch.overrides.resolve_name(orig_func)}" + ) + return orig_func(*args, **kwargs) diff --git a/MLPY/Lib/site-packages/torch/_prims/debug_prims.py b/MLPY/Lib/site-packages/torch/_prims/debug_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd450f08fadb3c4aeb615230fd1dbadf3921e6c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims/debug_prims.py @@ -0,0 +1,59 @@ +import contextlib +from typing import Optional, Sequence + +import torch +from torch._custom_op.impl import custom_op +from torch.utils._content_store import ContentStoreReader + +LOAD_TENSOR_READER: Optional[ContentStoreReader] = None + + +@contextlib.contextmanager +def load_tensor_reader(loc): + global LOAD_TENSOR_READER + assert LOAD_TENSOR_READER is None + # load_tensor is an "op", and we will play merry hell on + # Inductor's memory planning if we return a tensor that + # aliases another tensor that we previously returned from + # an operator. So unlike standard ContentStoreReader use, + # we disable the cache so that you always get fresh storages + # (no aliasing for you!) + LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False) + try: + yield + finally: + LOAD_TENSOR_READER = None + + +def register_debug_prims(): + @custom_op("debugprims::load_tensor") + def load_tensor( # type: ignore[empty-body] + name: str, + size: Sequence[int], + stride: Sequence[int], + *, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + ... 
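+    # The decorated stub above only declares the debugprims::load_tensor schema;
+    # the factory registered below provides the behavior: synthesize a random
+    # tensor when no reader is installed, otherwise read the named tensor back
+    # from the active ContentStoreReader.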
+ + @load_tensor.impl_factory() + def load_tensor_factory(name, size, stride, dtype, device): + if LOAD_TENSOR_READER is None: + from torch._dynamo.testing import rand_strided + + return rand_strided(size, stride, dtype, device) + else: + from torch._dynamo.utils import clone_input + + # device argument here takes care of coercion + r = LOAD_TENSOR_READER.read_tensor(name, device=device) + assert list(r.size()) == size, f"{r.size()} != {size}" + assert list(r.stride()) == stride, f"{r.stride()} != {stride}" + assert r.device == device, f"{r.device} != {device}" + + # Unlike the other properties, we will do coercions for dtype + # mismatch + if r.dtype != dtype: + r = clone_input(r, dtype=dtype) + return r diff --git a/MLPY/Lib/site-packages/torch/_prims/executor.py b/MLPY/Lib/site-packages/torch/_prims/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..415738ae14bbfd910fda03a9b68981845eca9c1f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims/executor.py @@ -0,0 +1,60 @@ +from typing import Callable, Optional + +from torch._prims.context import TorchRefsMode + +from torch.fx import GraphModule +from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx + + +def execute( + gm: GraphModule, + *args, + executor: str = "aten", + executor_parameters: Optional[dict] = None, +): + """ + Prototype ATen executor. + + Just executes the context's graph. + """ + + if executor == "aten": + return gm.forward(*args) + + msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten." + raise ValueError(msg) + + +def make_traced(fn: Callable): + """ + Returns a function that, when called, will + trace its torch operations to prims and then + execute those prims on the requested trace executor + (possibly lowering them to that trace executor first). + + Only supports the torch operations defined in _torch_to_reference_map + in context.py and operations with positional args. All args must + be tensors. + In the near future all these restrictions will be lifted. 
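+    Note that execute() above only accepts executor="aten", so that is the
+    only value this function currently supports as well.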
+ + Example usage: + + def foo(a, b): + return torch.add(a, b) + + traced_foo = make_traced(foo) + + a = torch.randn((1, 2, 3, 4, 5), device='cuda') + b = torch.randn((1, 2, 3, 4, 5), device='cuda') + result = traced_foo(a, b, executor='aten') + """ + + def _traced(*args, executor="aten", **kwargs): + # TODO: caching + wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs) + + with TorchRefsMode(): + gm = make_fx(wrapped)(all_args) + return execute(gm, all_args, executor=executor) + + return _traced diff --git a/MLPY/Lib/site-packages/torch/_prims/rng_prims.py b/MLPY/Lib/site-packages/torch/_prims/rng_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7e6e2d3e6e8a90d206816df6fc42cfd6d9e30f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims/rng_prims.py @@ -0,0 +1,268 @@ +from typing import Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch import _prims +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator + +from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for +from torch._prims_common.wrappers import backwards_not_supported +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.types import _device, _dtype + + +rngprim_namespace = "rngprims" +rngprim = torch.library.Library(rngprim_namespace, "DEF") +rngprim_impl = torch.library.Library( + rngprim_namespace, "IMPL", "CompositeExplicitAutograd" +) +rngprim_autograd_impl = torch.library.Library(rngprim_namespace, "IMPL", "Autograd") +rngprim_meta_impl = torch.library.Library(rngprim_namespace, "IMPL", "Meta") + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." + ) + + +def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None): + rngprim.define(schema) + rngprim_impl.impl(name, impl_aten) + rngprim_meta_impl.impl(name, impl_meta) + + prim_packet = getattr(torch._ops.ops.rngprims, name) + prim = prim_packet.default + if tags: + prim._tags = tags + + rngprim_autograd_impl.impl(name, backwards_not_supported(prim)) + + for p in (prim_packet, prim): + p.__doc__ = doc + p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] + + p.schema = schema + p.impl_aten = impl_aten + p.prim_meta_impl = impl_meta + + +# Philox rand offsets could be shared in future with other philox ops, so +# keeping these functions in global scope. +def philox_rand_offset_meta( + shape: torch.Size, +): + return _prims.TensorLike(torch.tensor(0, dtype=torch.int64)) + + +def philox_rand_offset( + shape: torch.Size, +): + # For impl, look at the function calc_execution_policy in the file + # aten/src/ATen/native/cuda/DistributionTemplates.h. 
The impl was copied at + # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6 + numel_scalar = 1 + for dim_size in shape: + numel_scalar *= dim_size + numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64) + + block_size = 256 + unroll = 4 + curand4_engine_calls = 4 + device_property = torch.cuda.get_device_properties(torch.cuda.current_device()) + blocks_per_sm = device_property.max_threads_per_multi_processor // block_size + grid_size = (numel + block_size - 1) // block_size + grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm) + offset = ( + (numel - 1) // (block_size * grid_size * unroll) + 1 + ) * curand4_engine_calls + return offset + + +def register_philox_rand(): + name = "philox_rand" + schema = "philox_rand(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950 + + def _philox_rand_meta( + shape: torch.Size, + seed: torch.Tensor, + offset: torch.Tensor, + stride: Optional[Tuple[int, ...]], + device: _device, + dtype: _dtype, + ): + # stride arg will be useful for distributed usecase. Currently, its unused. + assert stride is None + stride = make_contiguous_strides_for(shape) + random_values = _prims.TensorMeta( + shape=shape, strides=stride, dtype=dtype, device=device + ) + offset = philox_rand_offset_meta(shape) + return (random_values, offset) + + def _philox_rand( + shape: torch.Size, + seed: torch.Tensor, + offset: torch.Tensor, + stride: Optional[Tuple[int, ...]], + device: _device, + dtype: _dtype, + ): + # stride arg will be useful for distributed usecase. Currently, its unused. + assert stride is None + if device.type == "cpu": + devices = [] + else: + devices = [device] + + if device.type != "cuda": + raise throw_on_non_cuda(device) + + with torch.random.fork_rng(devices): + CUDARngStateHelper.set_torch_state_tensor(seed, offset) + random_values = torch.rand(shape, device=device, dtype=dtype) + + return random_values, philox_rand_offset(shape) + + register_rng_prim( + name=name, + schema=schema, + impl_aten=_philox_rand, + impl_meta=_philox_rand_meta, + doc="Philox based stateless rand operator", + tags=(torch.Tag.nondeterministic_seeded,), + ) + + +def get_device(args, kwargs): + if kwargs.get("device"): + device = kwargs.get("device") + if isinstance(device, str): + device = torch.device(device) + return device.type + + devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)} + if any(dev == "cuda" for dev in devices): + return "cuda" + elif any(dev == "cpu" for dev in devices): + return "cpu" + return None + + +def register_run_and_save_rng_state_op(): + run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state") + + run_and_save_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(run_and_save_rng_state, deferred_error=True) + ) + + @run_and_save_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, **kwargs): + return torch.cuda.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(op, *args, **kwargs): + return torch.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, **kwargs): + impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(FakeTensorMode) + def 
impl_fake_tensor_mode(mode, op, *args, **kwargs): + # Check device to call the right impl + with mode: + return impl_backend_select(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, op, *args, **kwargs): + if mode.enable_tracing: + out = impl_backend_select(op, *args, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args)) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + else: + return run_and_save_rng_state(op, *args, **kwargs) + + return run_and_save_rng_state + + +def register_run_with_rng_state_op(): + run_with_rng_state = HigherOrderOperator("run_with_rng_state") + + run_with_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(run_with_rng_state, deferred_error=True) + ) + + @run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(rng_state, op, *args, **kwargs): + current_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state.cpu()) + out = op(*args, **kwargs) + torch.cuda.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(rng_state, op, *args, **kwargs): + current_state = torch.get_rng_state() + torch.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): + if mode.enable_tracing: + with disable_proxy_modes_tracing(): + out = run_with_rng_state(rng_state, op, *args, **kwargs) + proxy_args = pytree.tree_map( + mode.tracer.unwrap_proxy, (rng_state, op, *args) + ) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_with_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + else: + return run_with_rng_state(rng_state, op, *args, **kwargs) + + @run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(rng_state, op, *args, **kwargs): + impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(rng_state, op, *args, **kwargs) + + @run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs): + # Skip setting the set_rng_state as it does not work well with fake tensors. + # And it does not matter for the fake tensor mode. 
+ with mode: + return op(*args, **kwargs) + + return run_with_rng_state + + +run_and_save_rng_state = register_run_and_save_rng_state_op() +run_with_rng_state = register_run_with_rng_state_op() + + +def register_rng_prims(): + register_philox_rand() diff --git a/MLPY/Lib/site-packages/torch/_prims_common/__init__.py b/MLPY/Lib/site-packages/torch/_prims_common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..494c94ffe2d852bb723dd22dd1b1c5e2fbdd5a22 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims_common/__init__.py @@ -0,0 +1,1985 @@ +from __future__ import annotations + +import operator +import warnings +import weakref + +from contextlib import nullcontext +from enum import Enum +from functools import cmp_to_key, reduce +from typing import ( + Any, + Callable, + cast, + List, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +from typing_extensions import TypeAlias + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + + import sympy + +import torch +from torch import sym_float, sym_int, sym_max + + +ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]] +StrideType: TypeAlias = Union[List[int], Tuple[int, ...]] +DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]] +DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] +NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]] +# TODO: This needs a lot more type annotations +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] +NumberType: TypeAlias = Union[bool, int, float, complex] +RealNumberType: TypeAlias = Union[bool, int, float] + +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat) +# I don't call it Integral because numbers.Integral includes bool, but IntLike +# does not +Dim = int +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) +IntWithoutSymInt = int +FloatWithoutSymFloat = float +DeviceLikeType: TypeAlias = Union[str, torch.device, int] +Tensor = torch.Tensor + + +torch_function_passthrough = { + torch.device, + torch.sym_not, + torch.sym_float, + torch.sym_int, + torch.sym_max, + torch.sym_min, + torch._sym_sqrt, # type: ignore[attr-defined] + torch.sym_ite, + torch.Tensor.dim, + torch.Tensor.ndim.__get__, # type: ignore[attr-defined] + torch.Tensor.numel, + torch.Tensor.size, + torch.Tensor.storage_offset, + torch.Tensor.stride, + torch.Tensor.dtype.__get__, # type: ignore[attr-defined] + torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] + torch.Tensor.shape.__get__, # type: ignore[attr-defined] + torch.Tensor.device.__get__, # type: ignore[attr-defined] + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.layout.__get__, # type: ignore[attr-defined] + torch.Tensor.is_contiguous, + # For TorchRefsMode only + torch.Tensor.__format__, + torch.Tensor.__repr__, + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] +} + + +TensorLikeType = torch.Tensor +TensorLike = torch.Tensor +TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] +TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType] + +CustomOutParamAnnotation = "__custom_out_param__" + + +def same_shape(a: ShapeType, b: ShapeType, *, 
allow_rhs_unbacked=False) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if len(a) != len(b): + return False + + for x, y in zip(a, b): + if allow_rhs_unbacked: + # TODO: We should check that the symbols are consistent + # with each other + if isinstance(y, torch.SymInt): + continue + # NB: Naively, you would not expect to have to do an oblivious guard + # here because there is seemingly no broadcasting here, but in fact we + # use this in some situations to determine if we need to do an expand + # on the tensor because they don't line up, so you can definitely end + # up trying to prove u0 != 1 in this situation. See + # python test/test_proxy_tensor.py -k test_cumsum_unbacked + if guard_size_oblivious(x != y): + return False + + return True + + +def _maybe_get_pytype(t): + if t is torch.SymFloat: + return float + elif t is torch.SymInt: + return int + elif t is torch.SymBool: + return bool + else: + return t + + +# TODO: look at using torch.testing.assert_close instead with an option +# to just compare metadata +def compare_tensor_meta( + a: TensorLikeType, + b: TensorLikeType, + check_strides=False, + *, + allow_rhs_unbacked=False, + check_conj=True, +): + """ + Checks that two tensor likes have the same shape, + dtype and device. + + In the future this will validate additional metadata, like + strides. + """ + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked): + msg = f"Shapes {a.shape} and {b.shape} are not equal!" + raise AssertionError(msg) + + if a.dtype != b.dtype: + msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!" + raise AssertionError(msg) + + if a.device != b.device: + # Handles special cuda:0 vs cuda case + # TODO: we should review why this happens and see about fixing it + if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( + str(b.device) == "cuda:0" or str(b.device) == "cuda" + ): + pass + else: + msg = f"Devices {a.device} and {b.device} are not equal!" + raise AssertionError(msg) + + # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 + if check_strides: + same_strides, idx = check_significant_strides(a, b) + if not same_strides: + msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" + raise RuntimeError(msg) + + if a.storage_offset() != b.storage_offset(): + msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" + raise RuntimeError(msg) + + if check_conj: + if a.is_conj() != b.is_conj(): + raise RuntimeError( + f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}" + ) + + if a.is_neg() != b.is_neg(): + raise RuntimeError( + f"Neg mismatch! 
is_neg is set to {a.is_neg()} and {b.is_neg()}" + ) + + +def _check_strides_helper( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True +) -> Tuple[bool, Optional[int]]: + # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch + # See https://github.com/pytorch/pytorch/issues/77553 + # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 + # and for tensors with more than one element + if ( + not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" + ) and a.numel() > 0: + for idx in range(a.ndim): + check = not significant_only or a.shape[idx] > 1 + if a.stride()[idx] != b.stride()[idx] and check: + return False, idx + + return True, None + + +def check_significant_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) + + +def check_all_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) + + +# This function is equivalent to compute_contiguous() from TensorImpl.cpp +def is_contiguous(a: TensorLikeType) -> bool: + """ + Tests whether a tensor is contiguous or not. + + Tensors are contiguous when they have no elements, + one element, or when they have "nested" strides. + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(a.numel() < 2): + return True + + expected_stride = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): + # Skips checking strides when a dimension has length 1 + if guard_size_oblivious(x == 1): + continue + + if y != expected_stride: + return False + expected_stride = expected_stride * x + + return True + + +# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp +def is_channels_last_contiguous_2d(a: Tensor) -> bool: + # NHWC or not channels last 2D contiguous + if a.ndim != 4: + return False + + expected_stride = 1 + for idx in (1, 3, 2, 0): + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def is_channels_last_contiguous_3d(a: Tensor) -> bool: + # NDHWC or not channels last 3D contiguous + if a.ndim != 5: + return False + + expected_stride = 1 + for idx in (1, 4, 3, 2, 0): + length = a.shape[idx] + if length == 1: + continue + + stride = a.stride()[idx] + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +_memory_formats = { + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, +} + + +def validate_memory_format(memory_format: torch.memory_format): + torch._check( + memory_format in _memory_formats, + lambda: f"Received unknown memory format {memory_format}!", + ) + + +def is_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + validate_memory_format(memory_format) + + if memory_format == torch.contiguous_format: + return is_contiguous(a) + if memory_format == torch.channels_last: + return is_channels_last_contiguous_2d(a) + if memory_format == torch.channels_last_3d: + return is_channels_last_contiguous_3d(a) + + torch._check( + False, + lambda: f"is_contiguous received unsupported memory format {memory_format}", + ) + + +# NOTE: that tensors 
with no elements and channels last is ??? +def is_channels_last_contiguous(a: Tensor) -> bool: + """ + True when a tensor is channels-last contiguous. + + This requires that: + + - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions + - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the + stride of the 'C' dimension (Cs) is 1 and the strides corresponding to + each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are + "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, + for example. + """ + return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if a.is_sparse: + return False + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if is_contiguous(a) or is_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + # + # This sort is done in a size-oblivious way, which helps if we do a + # comparison like 2048*u0 > u0; we just want this to return True + # (and not worry about what if u0 is zero). + class K(NamedTuple): + size: int + stride: int + + def __lt__(self, other): + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + + lengths_and_strides = sorted(map(K, a.shape, a.stride())) + + expected_stride = 1 + for length, stride in lengths_and_strides: + if guard_size_oblivious(length == 1): + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +# NOTE: Based on the implementation in TensorIterator.cpp, but note that +# the note [Computing output strides] is incorrect, because it +# says that strides will be preserved even if they are not +# "non overlapping and dense", but this is incorrect. The +# output of elementwise operations are always given +# non overlapping and dense strides. +# This is also INCORRECT because it does not model TensorIterator's +# short-circuit, which can cause different strides. +def compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=False +) -> List[int]: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not _skip_checks and len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" 
+ raise ValueError(msg) + + if not _skip_checks: + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + if not _skip_checks: + tensors = tuple( + a + for a in tensors + if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return [] + + # Short-circuits for shapes with zero or one dimensions + # TODO: are these necessary? + ndim = tensors[0].ndim + if ndim == 0: + return [] + if ndim == 1: + return [0] + + # Short-circuits if contiguous, following the fake fast path. + # This reduces the number of guards we end up making + # TODO: do channels last too + is_contiguous = True + for t in tensors: + is_contiguous = is_contiguous and t.is_contiguous( + memory_format=torch.contiguous_format + ) + + if is_contiguous: + return list(range(ndim)) + + shape = tensors[0].shape + + def should_swap(idx_a, idx_b): + for tensor in tensors: + stride_a = tensor.stride()[idx_a] + stride_b = tensor.stride()[idx_b] + + if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( + stride_b == 0 + ): + continue + + if guard_size_oblivious(stride_a < stride_b): + return -1 + + if guard_size_oblivious(stride_a > stride_b): + return 1 + + # stride_a == stride_b + if guard_size_oblivious(shape[idx_a] > shape[idx_b]): + return 1 + + # Note: this case is hit if all strides are zero, + # or all strides are equal and all dimensions have the same length + return 0 + + # The "sort" order for the permutation is back-to-front, but + # the natural order for permutations is front-to-back. Do the + # sorting back-to-front and then reverse it on output. + # + # also, note this returns the logical to physical shape permutation + perm = list(reversed(range(ndim))) + + # insertion sort with support for ambiguous comparisons + for i in range(1, ndim): + dim1 = i + for dim0 in reversed(range(i)): + comparison = should_swap(perm[dim0], perm[dim1]) + if comparison > 0: + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + dim1 = dim0 + elif comparison < 0: + break + + return list(reversed(perm)) + + +def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: + """ + Computes the output strides for elementwise operations. + """ + if len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" 
+ raise ValueError(msg) + + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + tensors = tuple( + a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return () + + ndim = tensors[0].ndim + shape = tensors[0].shape + + if ndim == 0: + return () + if ndim == 1: + return (1,) + + logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=True + ) + permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical + + new_strides = make_contiguous_strides_for(permuted_shape) + permuted_strides = apply_perm( + new_strides, invert_perm(logical_to_physical_perm) + ) # to logical + + return tuple(permuted_strides) + + +# Identity permutation is [0, 1, 2] +def apply_perm(inp, perm): + ndim = len(inp) + permuted_inp = [-1] * ndim + for idx, x in enumerate(perm): + permuted_inp[idx] = inp[x] + return permuted_inp + + +def invert_perm(perm): + ndim = len(perm) + new_perm = [-1] * ndim + for idx, x in enumerate(perm): + new_perm[x] = idx + return new_perm + + +# +# Common helper functions +# + + +def validate_dim_length(length: int): + """ + Validates that an object represents a valid + dimension length. + """ + + if isinstance(length, (int, torch.SymInt)): + torch._check_is_size(length) + else: + # sometimes called with sympy expression by inductor + assert length >= 0 + + +def validate_shape(shape: ShapeType): + """ + Validates that a sequence represents a valid shape. + """ + + assert isinstance(shape, Sequence), type(shape) + for l in shape: + validate_dim_length(l) + + +def validate_strides(strides: StrideType): + """ + Verifies the object specifies valid strides. + """ + + assert isinstance(strides, Sequence) + for stride in strides: + assert stride >= 0 + + +def validate_idx(rank: int, idx: int): + """ + Validates that idx is a valid index for the given shape. + Assumes the index is already canonicalized. + """ + + assert isinstance(idx, Dim) + assert isinstance(rank, Dim) + + assert idx >= 0 and idx < rank or idx == 0 + + +def validate_dimension_indices(rank: int, indices: DimsSequenceType): + for idx in indices: + validate_idx(rank, idx) + + +def validate_exclusive_idx(rank: int, ex_idx: int): + """ + Validates that ex_idx is a valid exclusive index + for the given shape. + """ + + assert isinstance(ex_idx, Dim) + assert isinstance(rank, Dim) + assert ex_idx > 0 and ex_idx <= rank + + +# "Wraps" a dim (up to one time) for the given rank, allowing dims to be +# specified using negative indices. If `wrap_scalar` is true then scalar +# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, +# idx should be in the range [-rank, rank-1]. 
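+# For illustration (informal): with rank=4, canonicalize_dim maps -1 -> 3 and
+# 2 -> 2; with rank=0 and wrap_scalar=True, both -1 and 0 map to 0; and
+# canonicalize_dim(4, 4) raises IndexError.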
+def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: + if rank < 0: + msg = f"Rank cannot be negative but got {rank}" + raise IndexError(msg) + + if rank == 0: + if not wrap_scalar: + msg = f"Dimension specified as {idx} but tensor has no dimensions" + raise IndexError(msg) + rank = 1 + + if idx >= 0 and idx < rank: + return idx + + if idx < 0: + _idx = idx + rank + else: + _idx = idx + + if _idx < 0 or _idx >= rank: + # Same error message as in aten/src/ATen/WrapDimUtils.h:49 + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})" + raise IndexError(msg) + + return _idx + + +# Takes a dimension or sequence of dimensions and "wraps" them, +# mapping negative offsets to positive ones +@overload +def canonicalize_dims( + rank: int, indices: Sequence[int], wrap_scalar: bool = True +) -> Tuple[int, ...]: + pass + + +@overload +def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: + pass + + +def canonicalize_dims(rank, indices, wrap_scalar=True): + if isinstance(indices, Dim): + return canonicalize_dim(rank, indices, wrap_scalar) + + return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices) + + +def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: + """ + Validates that perm is a permutation of length rank. + """ + + if not isinstance(perm, Sequence): + return False + + if not (tuple(sorted(perm)) == tuple(range(0, rank))): + return False + + return True + + +def is_same_shape(a: Sequence, b: Sequence) -> bool: + """ + Compares two shapes a and b, returning True if they are the same + (their ranks and corresponding lengths match) and False otherwise. + """ + + return tuple(a) == tuple(b) + + +def is_cpu_scalar_tensor(a: Any) -> bool: + return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" + + +def check_same_device(*args, allow_cpu_scalar_tensors): + """ + Checks that all Tensors in args have the same device. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True + """ + # Short-circuits if all (one or fewer) arguments are trivially on the same device + if len(args) <= 1: + return + + # Note: cannot initialize device to the first arg's device (it may not have one) + device = None + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if device is None: + device = arg.device + + if device != arg.device: + msg = ( + "Tensor on device " + + str(arg.device) + + " is not on the expected device " + + str(device) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same device, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +def canonicalize_device(device: DeviceLikeType) -> torch.device: + if isinstance(device, torch.device): + return device + + assert isinstance(device, str) + return torch.device(device) + + +# Asserts if any of the following are true: +# - a non-scalar or non-Tensor is given +# - the shape of any tensors is distinct +def check_same_shape(*args, allow_cpu_scalar_tensors: bool): + """ + Checks that all Tensors in args have the same shape. 
+ + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices + """ + shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + msg = f"Shape {arg.shape} is not the expected shape {shape}!" + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same shape, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Acquires a common shape, if it exists, from one or more tensor arguments, +# filtering number arguments +def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: + shape = None + scalar_shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + scalar_shape = arg.shape + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + return None + else: + return None + + return shape if shape is not None else scalar_shape + + +# Extracts dimensions that might be passed either as a list/tuple or as varargs. +# A typical case is Tensor.permute . +def extract_dims_from_varargs( + dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]] +) -> DimsSequenceType: + if dims and isinstance(dims[0], Sequence): + assert len(dims) == 1 + dims = cast(Tuple[DimsSequenceType], dims) + return dims[0] + else: + return cast(DimsSequenceType, dims) + + +def extract_shape_from_varargs( + shape: Union[ShapeType, Tuple[ShapeType]], + validate=True, +) -> Tuple[int, ...]: + """ + Returns a shape from varargs. + + In PyTorch, operations that accept shapes often accept them as varargs, like + foo(*shape). However a user can pass the shape as a sequence of integers, + like this: + + foo(1, 2, 3) + + or as a sequence of integers + + foo((1, 2, 3)) + + In the first case shape will be a tuple of integers, and in the second case it's a tuple + containing a tuple of integers. This validates those inputs and canonicalizes them + to a tuple of integers. + """ + + # Handles tuple unwrapping + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = shape[0] + + if validate: + validate_shape(shape) # type: ignore[arg-type] + return shape # type: ignore[return-value] + + +def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]: + ndim = max(len(a), len(b)) + expandedSizes = [0] * ndim + + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = len(a) - 1 - offset + dimB = len(b) - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + torch._check( + (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1), + lambda: ( + f"The size of tensor a ({sizeA}) must match the size of " + f"tensor b ({sizeB}) at non-jagged dimension {i}" + ), + ) + + # 1s map to the other size (even 0) + expandedSizes[i] = sizeB if sizeA == 1 else sizeA + + return tuple(expandedSizes) + + +def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: + """ + Infers the size of a dim with size -1, if it exists. + Also checks that new shape is compatible with the number of elements. 
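+    For example (informally): infer_size((2, -1, 4), 24) returns (2, 3, 4),
+    while infer_size((5, -1), 12) raises because 12 is not divisible by 5.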
+ """ + dim = None + newsize = 1 + for i, d in enumerate(shape): + if d == -1: + torch._check(dim is None, lambda: "only one dimension can be inferred") + dim = i + elif d >= 0: + newsize *= d + else: + torch._check(False, lambda: f"invalid shape dimension {d}") + if dim is None: + torch._check( + numel == newsize, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + else: + from torch.fx.experimental.symbolic_shapes import definitely_true + + torch._check( + newsize != 0, + lambda: ( + f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " + f"unspecified dimension size -1 can be any value and is ambiguous" + if definitely_true(numel == 0) + else f"shape '{list(shape)}' is invalid for input of size {numel}" + ), + ) + torch._check( + numel % newsize == 0, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + # Convert to list to produce a compatible error message with core + # PyTorch, which prints sequences in square brackets. + shape = list(shape) + shape[dim] = numel // newsize + # NB: This is pretty important when you have unbacked SymInts. + # Suppose you have (i0, 12) resizing into (2, -1, 12). The old + # range for i0 is typically [2, inf], which means if you divide + # by two the new range should be [1, inf]. But this is bad news + # if you have an unbacked SymInt: we need to reapply the unsound + # assumption that the size is >= 2. + torch._check_is_size(shape[dim]) + return tuple(shape) + + +_integer_dtypes = ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) +_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) + + +def is_boolean_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype is torch.bool + + +def is_integer_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _integer_dtypes + + +def is_low_precision_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _low_precision_dtypes + + +def is_float_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype.is_floating_point + + +def is_complex_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _complex_dtypes + + +def is_grad_dtype(dtype: torch.dtype) -> bool: + """ + Checks if the dtype can require a gradient. + """ + return dtype.is_floating_point or is_complex_dtype(dtype) + + +_complex_to_real_dtype_map = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +_real_to_complex_dtype_map = { + torch.float16: torch.complex32, + torch.bfloat16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: torch.complex128, +} + + +def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: + return _complex_to_real_dtype_map[dtype] + + +def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: + return _real_to_complex_dtype_map[dtype] + + +def dtype_to_type(dtype: torch.dtype) -> type: + """ + Computes the corresponding Python type (AKA "type kind") for the + given dtype. 
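+    For example, dtype_to_type(torch.float32) is float, dtype_to_type(torch.int64)
+    is int, and dtype_to_type(torch.complex64) is complex.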
+ """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return bool + if dtype in _integer_dtypes: + return int + if dtype.is_floating_point: + return float + if dtype in _complex_dtypes: + return complex + + raise ValueError("Invalid dtype!") + + +def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: + """ + Computes the corresponding Python type constructor for the + given dtype. + """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return lambda x: bool(x) + if dtype in _integer_dtypes: + return sym_int + if dtype.is_floating_point: + return sym_float + if dtype in _complex_dtypes: + # TODO: type error here is real, replace with sym_complex + return lambda x: complex(x) # type: ignore[arg-type] + + raise ValueError("Invalid dtype!") + + +def type_to_dtype(typ: type) -> torch.dtype: + """ + Computes the corresponding dtype for a Number type. + """ + + assert isinstance(typ, type) + + if typ is bool: + return torch.bool + if typ in [int, torch.SymInt]: + return torch.long + if typ in [float, torch.SymFloat]: + return torch.get_default_dtype() + # TODO: sym_complex_float? + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError("Invalid type!") + + +def get_dtype(x: Union[torch.Tensor, NumberType]): + if isinstance(x, torch.Tensor): + return x.dtype + else: + return type_to_dtype(type(x)) + + +_ordered_types = (bool, int, float, complex) + + +def check_fp_or_complex( + dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True +): + """ + Checks whether the input is floating point or complex. + If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 + """ + torch._check( + is_float_dtype(dtype) or is_complex_dtype(dtype), + lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", + ) + torch._check( + allow_low_precision_dtypes or not is_low_precision_dtype(dtype), + lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", + ) + + +def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): + torch._check( + len(A.shape) >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def get_higher_type(a: type, b: type) -> type: + """ + Returns the higher of the two given Number types. + + The types are ordered bool -> int -> float -> complex. + """ + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + # Type checking + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + if a is b: + return a + + for typ in _ordered_types: + if a is typ: + return b + if b is typ: + return a + + raise ValueError("Unknown Python scalar type!") + + +# Returns the higher of two torch datatypes a and b or, if the two +# are not ordered relative to each other, the next +# higher datatype +def get_higher_dtype( + a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], + b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], +) -> Optional[torch.dtype]: + """ + Computes the "lowest" datatype that is weakly + "higher" than both a and b. 
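+    For example, get_higher_dtype(torch.int64, torch.float16) is torch.float16,
+    while get_higher_dtype(torch.uint8, torch.int8) is torch.int16 because the
+    two are not ordered relative to each other.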
+ """ + + # Type checking + assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) + assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) + + def _extract_dtype( + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] + ) -> Optional[torch.dtype]: + if x is None: + return None + if isinstance(x, torch.dtype): + return x + if isinstance(x, TensorLike): + return x.dtype + if isinstance(x, Number): + return type_to_dtype(type(x)) + + raise RuntimeError("Unexpected type given to _extract_dtype!") + + a, b = _extract_dtype(a), _extract_dtype(b) + + if a is b: + return a + + if a is None: + return b + + if b is None: + return a + + ordered_datatypes = ( + (torch.bool,), + (torch.uint8, torch.int8), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.float16, torch.bfloat16), + (torch.float32,), + (torch.float64,), + (torch.complex32,), + (torch.complex64,), + (torch.complex128,), + ) + + for idx, dtypes in enumerate(ordered_datatypes): + if a in dtypes and b in dtypes: + return ordered_datatypes[idx + 1][0] + if a in dtypes: + return b + if b in dtypes: + return a + + raise RuntimeError("Unexpected termination!") + + +def check_pin_memory(pin_memory: bool): + torch._check_not_implemented( + not pin_memory, lambda: "PrimTorch does not support pinned memory" + ) + + +def check_layout(layout: torch.layout): + torch._check_not_implemented( + layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" + ) + + +# TODO: maybe unify with can_cast_to? +def is_weakly_lesser_type(a: type, b: type) -> bool: + """ + Compares two types, a and b, returning True if a is weakly "less" than b. + + The comparison is determined by the following type ordering: bool, int, float, complex. + """ + + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + for typ in _ordered_types: + if a == typ: + return True + if b == typ: + return False + + raise RuntimeError("Unexpected termination!") + + +def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: + for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): + if fn(cast_to): + return True + if fn(cast_from): + return False + + raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!") + + +def check_same_dtype(*args): + """ + Checks that all Tensors in args have the same device and that all Numbers have the + same corresponding Python type. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensors objects in args have different dtypes + - two Number objects in args have different types + - there are Tensors and Numbers in args, and one of those Tensors corresponding + Python types is different from the type of one of those Numbers + """ + full_dtype = None + scalar_type = None + + for arg in args: + if isinstance(arg, Number): + # Scalar type checking is disabled (and may be removed in the future) + continue + # if scalar_type is None: + # scalar_type = type(arg) + + # if scalar_type is not type(arg): + # msg = ( + # "Scalar of type " + # + str(type(arg)) + # + " is not the expected type of " + # + str(scalar_type) + # + "!" 
+ # ) + # raise RuntimeError(msg) + elif isinstance(arg, TensorLike): + if full_dtype is None: + full_dtype = arg.dtype + if scalar_type is None: + scalar_type = dtype_to_type(arg.dtype) + + if full_dtype is not arg.dtype: + msg = ( + "Tensor with dtype " + + str(arg.dtype) + + " is not the expected dtype of " + + str(full_dtype) + + "!" + ) + raise RuntimeError(msg) + + arg_type = dtype_to_type(arg.dtype) + if arg_type is not scalar_type: + msg = ( + "Tensor with corresponding Python type " + + str(arg_type) + + " is not the expected type of " + + str(scalar_type) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same dtype, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Maps datatypes to their computation types for elementwise operations +_computation_dtype_map = { + torch.bfloat16: torch.float32, + torch.float16: torch.float32, + torch.complex32: torch.complex64, +} + + +def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: + return _computation_dtype_map.get(dtype, dtype) + + +_cpu_acc_type_map = { + torch.bfloat16: torch.float64, + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.complex32: torch.complex128, + torch.complex64: torch.complex128, +} + + +def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: + # Equivalent to at::toAccumulateType, prefer computation_dtype where possible + if device.type == "cpu": + return _cpu_acc_type_map.get(dtype, dtype) + else: + return get_computation_dtype(dtype) + + +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) + + +class REDUCTION_OUTPUT_TYPE_KIND(Enum): + SAME = (0,) + COMPLEX_TO_FLOAT = (1,) # for complex types outputs corresponding real type + KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean + ALWAYS_BOOL = (3,) + + +# Describes the return type of the primitive: +# +# - NEW, a new tensor is created +# - VIEW, a view of an input tensor is returned +# - INPLACE, one or more input tensors is modified +# +# these descriptors are mututally exclusive and exhaustive. +class RETURN_TYPE(Enum): + NEW = (0,) + VIEW = (1,) + INPLACE = (2,) + + +# TODO: when NumberType contains the sym types, can simplify this +def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type: + if isinstance(x, torch.SymInt): + return int + elif isinstance(x, torch.SymFloat): + return float + else: + return type(x) + + +def expr_type(x: sympy.Expr) -> Type: + if x.is_integer: # type: ignore[attr-defined] + return int + else: + # NB: Not strictly correct, but we don't support SymPy complex or bool. + return float + + +# TODO: document type promotion kinds +def elementwise_dtypes( + *_args, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, +) -> Tuple[torch.dtype, torch.dtype]: + """ + Computes the computation and result dtypes for elementwise type promotion + on the given arguments and with the given elementwise type promotion kind. + + Note that not all inputs to an elementwise operation necessarily participate in type promotion. + For example, the "alpha" parameter of torch.add does not participate in type promotion, + although it may be cast to the Python type corresponding to the computation dtype that + the type promotion algorithm determines. 
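+
+    As a concrete illustration (assuming the default dtype is torch.float32):
+    for an int32 tensor and an int64 tensor, each with one or more dimensions,
+    DEFAULT promotion yields computation and result dtypes (torch.int64,
+    torch.int64), while INT_TO_FLOAT yields (torch.float32, torch.float32).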
+ + Default elementwise type promotion, which all other type promotion kinds tweak (see below), + first decides which of four ordered types to use: + + bool -> integer -> floating point -> complex + + The selected type is the "lowest" type in the above list such that all number arguments + have a weakly "lower" type and all tensor arguments have a weakly lower corresponding + type for their dtype. + + Once the type is determined, the particular result dtype is found. The dtypes are + partially ordered as follows: + + bool -> uint8, int8 -> int16 -> int32 -> int64 -> + float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 + + The result dtype is selected by: + - if no tensor's dtype has the same corresponding type as the one selected, + then the result dtype is the (default) dtype corresponding to the selected type + (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype) + - if the result type is complex then the dtype is: + - the default complex dtype if there are no floating point or complex tensors + - if there are floating point or complex tensors with one or more dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + (for example, double + cfloat -> cdouble) + - if there are only floating point or complex tensors with zero dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + - if the first two cases do not apply, the result dtype is the highest dtype among + all tensors with one or more dimensions of the output type, and if there are no such + tensors then it's the highest dtype among all tensors with zero dimensions of the output type + (for example, long + half -> half, even if the half tensor has zero dimensions) + + The "corresponding complex dtypes" are: + float16 -> complex32 + bfloat16 -> complex64 + float32 -> complex64 + float64 -> complex128 + complex32 -> complex32 + complex64 -> complex64 + complex128 -> complex128 + + The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation + dtype by mapping low precision floating point and complex dtypes as follows: + + float16 -> float32 + bfloat16 -> float32 + complex32 -> complex64 + + This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the + computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels + which perform no mathematical operations on their tensors (see below for examples). + + The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype, + and computation dtypes to the appropriate op math dtype. + + The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this + mapping: + + complex32 -> float16 + complex64 -> float32 + complex128 -> float64 + + Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does. + + The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long. + + The ALWAYS_BOOL type promotion kind always sets the result dtype to bool. 
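+
+    As an illustration (assuming the default floating point dtype is torch.float32),
+    combining an integer tensor with a Python float under two of the kinds gives:
+
+        elementwise_dtypes(torch.ones(2, dtype=torch.long), 2.0,
+                           type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
+        # -> (torch.float32, torch.float32)
+
+        elementwise_dtypes(torch.ones(2, dtype=torch.long), 2.0,
+                           type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
+        # -> (torch.float32, torch.bool)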
+ + Example operators for each type promotion option: + DEFAULT : add + NO_OPMATH : where, nextafter, cat + INT_TO_FLOAT : sin + COMPLEX_TO_FLOAT : abs + BOOL_TO_LONG : pow + ALWAYS_BOOL : eq + + """ + + args = tuple(x for x in _args if x is not None) + + highest_type: type = bool + + # Import sympy locally, as importing it eagerly at a module level is too slow + # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589 + import sympy + + for x in args: + if not isinstance(x, (Number, TensorLike, sympy.Expr)): + msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" + raise ValueError(msg) + + if isinstance(x, Number): + highest_type = get_higher_type(highest_type, number_type(x)) + elif isinstance(x, sympy.Expr): + highest_type = get_higher_type(highest_type, expr_type(x)) + else: + # x is a TensorLike + highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype)) + + result_dtype = None + + def _find_highest_dtype_filtered( + args, filter, *, float_as_complex=False + ) -> Optional[torch.dtype]: + zero_dim_tensor_dtype = None + one_plus_dim_tensor_dtype = None + for x in args: + if isinstance(x, TensorLike) and filter(x.dtype): + _dtype = x.dtype + if float_as_complex and is_float_dtype(_dtype): + _dtype = corresponding_complex_dtype(_dtype) + if x.ndim == 0: + zero_dim_tensor_dtype = get_higher_dtype( + zero_dim_tensor_dtype, _dtype + ) + else: + # x.ndim > 0 + one_plus_dim_tensor_dtype = get_higher_dtype( + one_plus_dim_tensor_dtype, _dtype + ) + + # Prefers dtype of tensors with one or more dimensions + if one_plus_dim_tensor_dtype is not None: + return one_plus_dim_tensor_dtype + + return zero_dim_tensor_dtype + + if highest_type is float: + result_dtype = _find_highest_dtype_filtered(args, is_float_dtype) + result_dtype = ( + torch.get_default_dtype() if result_dtype is None else result_dtype + ) + elif highest_type is complex: + result_dtype = _find_highest_dtype_filtered( + args, + lambda x: is_float_dtype(x) or is_complex_dtype(x), + float_as_complex=True, + ) + if result_dtype is None: + result_dtype = corresponding_complex_dtype(torch.get_default_dtype()) + elif highest_type is int: + result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype) + result_dtype = torch.long if result_dtype is None else result_dtype + else: + # highest_type is bool + result_dtype = torch.bool + + if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH: + return result_dtype, result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT: + if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype): + result_dtype = torch.get_default_dtype() + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: + # NOTE: computation can still occur in a complex dtype + computation_dtype = get_computation_dtype(result_dtype) + if is_complex_dtype(result_dtype): + result_dtype = corresponding_real_dtype(result_dtype) + return computation_dtype, result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG: + if is_boolean_dtype(result_dtype): + return torch.long, torch.long + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: + return get_computation_dtype(result_dtype), 
torch.bool + else: + raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}") + + +def reduction_dtypes( + arg, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.dtype, Optional[torch.dtype]]: + # even though some reductions, like amin or amax, don't strictly require type promotion, + # all the math ops (including comparisons) are still defined only for a computation type, + # so promotion will still happen. We are doing it explicitly here + inp_dtype = dtype if dtype is not None else arg.dtype + computation_dtype = get_computation_dtype(inp_dtype) + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME + or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ): + result_dtype = dtype if dtype else arg.dtype + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + and is_complex_dtype(result_dtype) + ): + result_dtype = corresponding_real_dtype(result_dtype) + elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: + result_dtype = None + else: # ALWAYS_BOOL + result_dtype = torch.bool + return computation_dtype, result_dtype + + +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides +def make_contiguous_strides_for( + shape: ShapeType, row_major: bool = True +) -> Tuple[int, ...]: + """ + Returns the strides of a contiguous tensor if row_major + If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices + This is often used when calling external libraries like BLAS/LAPACK/cuSolver... + """ + # contiguous_strides from c10/util/strides.h + validate_shape(shape) + if not shape: + return () + + from torch.fx.experimental.symbolic_shapes import is_nested_int + + multiplier = 1 + strides = [] + for l in reversed(shape): + strides.append(multiplier) + multiplier *= l if is_nested_int(l) else sym_max(l, 1) + + result = tuple(reversed(strides)) + + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h + if row_major: + return result + else: + if len(shape) < 2: + return result + return result[:-2] + (1, max(shape[-2], 1)) + + +def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 3, + lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", + ) + + multiplier = 1 + strides = [0] * 3 + for idx in (1, -1, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? 
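+    # Illustrative example: for an NCHW shape of (2, 3, 4, 5) the loop below
+    # yields strides (60, 1, 15, 3), i.e. the channels dimension varies fastest,
+    # matching torch.channels_last.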
+ torch._check( + len(shape) == 4, + lambda: "Only tensors of rank 4 can use the channels_last memory format", + ) + + multiplier = 1 + strides = [0] * 4 + for idx in (1, -1, -2, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: + torch._check( + len(shape) == 5, + lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", + ) + + multiplier = 1 + strides = [0] * 5 + for idx in (1, -1, -2, -3, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: + ndim = len(shape) if isinstance(shape, Sequence) else 1 + if ndim == 3: + return make_channels_last_1d_strides_for(shape) + elif ndim == 4: + return make_channels_last_2d_strides_for(shape) + elif ndim == 5: + return make_channels_last_3d_strides_for(shape) + else: + raise RuntimeError( + f"no channels last format strides exist in {ndim} dimensions" + ) + + +def compute_reduction_output_shape( + shape: ShapeType, dimensions: Sequence +) -> Tuple[int, ...]: + for idx in dimensions: + validate_idx(len(shape), idx) + + new_shape = [] + for idx in range(len(shape)): + if idx in dimensions: + continue + + new_shape.append(shape[idx]) + + return tuple(new_shape) + + +def validate_no_repeating_dims(dims: Sequence): + if len(dims) != len(set(dims)): + raise RuntimeError("duplicate value in the list of dims") + + +def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]: + if dims is None: + return tuple(range(len(shape))) + dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) + validate_no_repeating_dims(dims) + return dims + + +def set_correction( + unbiased: Optional[bool] = None, + correction: Optional[NumberType] = None, +) -> float: + if correction is not None and unbiased is not None: + raise RuntimeError("cannot specify both correction and unbiased arguments") + elif correction is None and unbiased is None: + correction = 1.0 + elif correction is None and unbiased is not None: + correction = 0.0 if unbiased is False else 1.0 + # NB: we don't actually support symint here, but it's harmless to accept + if not isinstance(correction, (IntLike, FloatLike)): + raise ValueError("correction argument should be integer or float") + if correction < 0: + raise ValueError("correction argument should be non-negative") + return sym_float(correction) + + +def compute_required_storage_length( + shape: ShapeType, strides: StrideType, storage_offset: int +) -> int: + """Computes the minimum storage size to hold the given tensor geometry. 
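+
+    For a tensor with at least one element this is
+
+        1 + storage_offset + sum((shape[i] - 1) * strides[i] for i in range(len(shape)))
+
+    and it is 0 when the shape contains a zero-length dimension.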
+ + Example + ======= + + This is the size of a newly allocated tensor's storage, in units of elements + + >>> t = torch.empty((10, 20)) + >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) + 200 + + >>> # xdoctest: +SKIP(failing) + >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size == t.storage().size() + True + + A valid tensor may have a larger storage size, but never smaller + + >>> slice = torch.empty(100)[20:40] + >>> slice.storage().size() + 100 + + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Short-circuits if the shape has no elements + if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): + return 0 + + max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) + # +1 to account for the first element which offsets are taken from + return 1 + storage_offset + max_offset + + +def check_in_bounds_for_storage( + a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int +): + """ + Determines if the given shape, strides, and offset are valid for the given storage. + """ + + required_length = compute_required_storage_length(shape, strides, storage_offset) + if a.size() < required_length: + msg = ( + "Can't view a storage of size {} with an offset of {}, shape of {}, and strides of {}, " + "which requires a storage of size {}".format( + a.size(), storage_offset, str(shape), str(strides), required_length + ) + ) + raise ValueError(msg) + + +# NOTE: This function should ideally be removed, but some Meta internal models +# packaged with `torch.package` are using it, so it will have to be removed +# at some point in the future when those models no longer use this function. +def check( + b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError +) -> None: + """ + Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails. + Error message is a callable producing a string (to avoid wasting time + string formatting in non-error case, and also to make it easier for torchdynamo + to trace.) + + .. note:: This function is planned for removal in the future. Please use + `torch._check*` functions instead. + """ + warnings.warn( + DeprecationWarning( + "'torch._prims_common.check' will be removed in the future. 
Please use " + "'torch._check*' functions instead" + ) + ) + torch._check_with(exc_type, b, s) + + +# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in +# c10/core/MemoryFormat.h into one function +def are_strides_like_channels_last( + shape: Sequence[int], strides: Sequence[int] +) -> bool: + ndim = len(shape) + + if ndim == 4: + # Check for channels_last_2d + dim_order = [1, 3, 2, 0] + elif ndim == 5: + # Check for channels_last_3d + dim_order = [1, 4, 3, 2, 0] + else: + return False + + if strides[1] == 0: + return False + + min = 0 + for d in dim_order: + if shape[d] == 0: + return False + if strides[d] < min: + return False + if d == 0 and min == strides[1]: + return False + min = strides[d] + if strides[d] > 1: + min *= shape[d] + return True + + +def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: + if x.layout != torch.strided: + return torch.contiguous_format + + if are_strides_like_channels_last(x.shape, x.stride()): + return torch.channels_last if x.ndim == 4 else torch.channels_last_3d + + return torch.contiguous_format + + +def prod(xs: Sequence[NumberType]) -> NumberType: + """Product of elements in input sequence. Returns 1 for empty sequence""" + return reduce(operator.mul, xs, 1) + + +def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: + """Checks if a shape can be expanded to another shape. + This is equivalent to checking if the two shapes are broadcastable. + """ + # This is a Python implementation of + # aten/src/ATen/ExpandUtils.h:is_expandable_to + if len(shape) > len(desired): + return False + for i in range(len(shape)): + if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1: + return False + return True + + +def mask_tensor(mask: TensorLikeType, t: TensorLikeType): + """ + Similar to torch.where(mask, t, 0) but if t is boolean, + result is also boolean and not promoted to int. + """ + # torch.where(mask, t, False) is equivalent + # but feels hacky and might break in the future + if t.dtype is torch.bool: + return mask.logical_and(t) + else: + return torch.where(mask, t, 0) + + +def get_aten_op(fn: Callable, name: str): + """ + Given the __module__ of reference and its name, it returns + (our best guess of) the ATen name of the associated operation + + Note: In ATen, the __name__ of a function within a module often + starts by the module name. E.g. linalg_eigh, or special_zeta + """ + module = fn.__module__ + prefix = "torch._refs" + assert module.startswith(prefix) + module = module[len(prefix) :] + # We want to go from .special / .nn.functional + # to special and special_ / nn_functional_ + if module: + module = module[1:] + module = module.replace(".", "_") + module = module + "_" + return getattr(torch._ops.ops.aten, f"{module}{name}") + + +def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: + return dtype if dtype is not None else torch.get_default_dtype() + + +def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType: + return device if device is not None else torch.device("cpu") + + +def layout_or_default(layout: Optional[torch.layout]) -> torch.layout: + return layout if layout is not None else torch.strided + + +def clone_preserve_strides(x): + needed_size = compute_required_storage_length( + x.size(), x.stride(), x.storage_offset() + ) + # Our eager implementations for *_scatter ops are all primitives w.r.t autograd, + # so these as_strided() calls are not seen by autograd. + # We need to mimic this behavior in our ref/prim implementations. 
+ # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided" + # We should revisit this when we add a compositional as_strided op, + # and also as part of https://github.com/pytorch/pytorch/issues/90507 + try: + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, True + ) + buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone() + return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old + ) + + +def alert_not_deterministic(caller: str): + if torch.are_deterministic_algorithms_enabled(): + if torch.is_deterministic_algorithms_warn_only_enabled(): + warnings.warn( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True, warn_only=True)'. " + f"You can file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." + ) + else: + torch._check( + False, + lambda: ( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True)'. You can turn off " + f"determinism just for this operation, or you can use the " + f"'warn_only=True' option, if that's acceptable for your application. " + f"You can also file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." + ), + ) + + +class CUDARngStateHelper: + @staticmethod + def get_torch_state_as_tuple(fake_mode=nullcontext()): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available") + + with fake_mode: + seed = torch.tensor(torch.cuda.initial_seed()) + offset = torch.tensor(torch.cuda._get_rng_state_offset()) + return seed, offset + + @staticmethod + def set_torch_state_tensor(seed, offset): + # Rng state is [64-bit seed, 64-bit offset] + seed_portion = seed.reshape([1]).view(torch.uint8) + offset_portion = offset.reshape([1]).view(torch.uint8) + new_state = torch.cat([seed_portion, offset_portion]) + torch.cuda.set_rng_state(new_state) + + @staticmethod + def set_new_offset(relative_offset): + torch.cuda._set_rng_state_offset(relative_offset.item()) diff --git a/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f307d8e07c12a89f9867b402fd9205f7cb1205d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..986cf751a0b536ff60c669a06a358fd8437da660 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_prims_common/wrappers.py b/MLPY/Lib/site-packages/torch/_prims_common/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..bafeb88d67391f9ac9ffcbd9f89e34014cddba31 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_prims_common/wrappers.py @@ -0,0 +1,401 @@ +import inspect +import warnings 
+from functools import wraps +from itertools import chain + +from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple + +import torch +import torch._prims_common as utils +from torch._prims_common import ( + CustomOutParamAnnotation, + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_flatten, tree_unflatten + + +@overload +def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + pass + + +@overload +def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType: + pass + + +@overload +def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence: + pass + + +@overload +def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None: + pass + + +# TODO: implement ref.cast with an option to enforce safe casting +def _maybe_convert_to_dtype(a, dtype): + if isinstance(a, TensorLike): + if a.dtype != dtype: + return a.to(dtype) + return a + if isinstance(a, Number): + return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type] + if isinstance(a, Sequence): + return tuple(_maybe_convert_to_dtype(x, dtype) for x in a) + # Passthrough None because some functions wrapped with type promotion + # wrapper might have optional args + if a is None: + return None + + raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!") + + +def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: + if not isinstance(a, Number): + msg = f"Found unknown type {type(a)} when trying to convert scalars!" + raise ValueError(msg) + if not utils.is_weakly_lesser_type(type(a), typ): + msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!" + raise ValueError(msg) + + return typ(a) + + +def _annotation_has_type(*, typ, annotation): + if hasattr(annotation, "__args__"): + for a in annotation.__args__: + if _annotation_has_type(typ=typ, annotation=a): + return True + return False + + return typ is annotation + + +class elementwise_type_promotion_wrapper: + """ + Adds elementwise type promotion to a Python reference implementation. + + Takes two kwargs, type_promoting_args and type_promotion_kind. + + type_promoting_args must be a string Sequence specifiying the argument names of all + arguments that participate in type promotion (and should be type promoted). If the + arg specifies a Sequence-type then every element of the Sequence will participate in + type promotion. + + type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND. + See its documentation for details. + + The return_dtype will be coerced to the wrapped function's dtype arg if it is available and + not None. + + Other type promotion behavior, like validating the Python type of scalar arguments, must + be handled separately. 
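+
+    A minimal usage sketch (illustrative only; my_add is a hypothetical reference):
+
+        @elementwise_type_promotion_wrapper(
+            type_promoting_args=("a", "b"),
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+        )
+        def my_add(a, b):
+            # a and b arrive converted to the computation dtype; the returned
+            # tensor is converted to the promoted result dtype by the wrapper.
+            return a + b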
+ """ + + def __init__( + self, + *, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, + type_promoting_args: Optional[Sequence[str]] = None, + ): + self.type_promoting_arg_names = type_promoting_args + self.type_promotion_kind = type_promotion_kind + + def __call__(self, fn: Callable) -> Callable: + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + type_promoting_args = tuple( + bound.arguments[x] + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + ) + + flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) + compute_dtype, result_dtype = utils.elementwise_dtypes( + *flattened_type_promoting_args, + type_promotion_kind=self.type_promotion_kind, + ) + + promoted_args = { + x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + } + bound.arguments.update(promoted_args) + + result = fn(**bound.arguments) + + # Override the return_dtype if a dtype arg is present and not None + if "dtype" in bound.arguments: + maybe_dtype = bound.arguments["dtype"] + if maybe_dtype: # dtype cannot be None + result_dtype = maybe_dtype + + if isinstance(result, TensorLike): + return _maybe_convert_to_dtype(result, result_dtype) + if isinstance(result, Sequence): + return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result) + raise AssertionError(f"Unhandled result type: {type(result)}") + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn + + +# Returns True if resize is necessary +def _resize_output_check(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return False + if out.numel() != 0: + msg = ( + f"An output with one or more elements was resized since it had shape {str(out.shape)} " + "which does not match the required output shape {str(shape)}. " + "This behavior is deprecated, and in a future PyTorch release outputs will not " + "be resized unless they have zero elements. " + "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." 
+ ) + warnings.warn(msg) + return True + + +# TODO: handle tuples of tensors +def _maybe_resize_out(out: TensorLikeType, shape: ShapeType): + if _resize_output_check(out, shape): + return out.resize_(shape) + else: + return out + + +def _safe_copy_out( + *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False +): + # Checks same device + if copy_from.device != copy_to.device: + msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format( + copy_from.device, copy_to.device + ) + raise RuntimeError(msg) + + # Checks safe cast + if exact_dtype: + torch._check( + copy_from.dtype == copy_to.dtype, + lambda: f"Expected out tensor to have dtype {copy_from.dtype} " + f"but got {copy_to.dtype} instead", + ) + else: + torch._check( + utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), + lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " + "but this can't be cast because it is not safe!", + ) + + return copy_to.copy_(copy_from) + + +def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False): + # The wrapped function needs to convert the output parameters to ensure + # compatibility between the Python API (which always uses "out" as the + # parameter name and may be a tuple) and the Aten API (which may have + # multiple output parameters and use different parameter names such as + # "grad_input", "indices" or "values".) + + default_out_names = ("out",) + if len(out_names) == 0: + # Use default in out name + out_names = default_out_names + + is_tensor = len(out_names) == 1 + + def _out_wrapper(fn: Callable) -> Callable: + """ + Adds the out parameter to a Python reference. + """ + out_type = ( + TensorLikeType + if is_tensor + else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))] + ) + return_type = ( + TensorLikeType + if is_tensor + else NamedTuple( + f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] + ) + ) + + sig = inspect.signature(fn) + factory_kwargs = ("device", "dtype") + is_factory_fn = all(p in sig.parameters for p in factory_kwargs) + + @wraps(fn) + def _fn(*args, out=None, **kwargs): + if is_factory_fn and out is not None: + for k in factory_kwargs: + out_attr = getattr(out, k) + if k not in kwargs: + kwargs[k] = out_attr + if pass_is_out: + result = fn(*args, is_out=(out is not None), **kwargs) + else: + result = fn(*args, **kwargs) + assert ( + isinstance(result, TensorLike) + and is_tensor + or isinstance(result, Tuple) # type: ignore[arg-type] + and len(result) == len(out_names) + ) + if out is not None: + # Naively you might expect this assert to be true, but + # it's not: + # + # assert type(out) == type(result) + # + # The reason is that functions under this wrapper can + # get registered to the Meta dispatch key, and that + # means they can be executed in a context where tensor + # subclasses are disabled (with no_dispatch), which is a + # handy way for an is-a tensor subclass (e.g., + # FakeTensor) to have the normal meta backend create a + # meta tensor, to be wrapped once it gets returned. + # In this situation, you will get a FakeTensor as + # the output tensor, but not the result--which will + # be a normal meta tensor, but this is perfectly + # harmless. 
+ if is_tensor: + assert isinstance(out, TensorLike) + # These two operations are done in-place + _maybe_resize_out(out, result.shape) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + assert isinstance(out, Tuple) # type: ignore[arg-type] + torch._check_type( + len(out) == len(result), + lambda: f"expected tuple of {len(result)} elements but got {len(out)}", + ) + for r, o in zip(result, out): + # These two operations are done in-place + _maybe_resize_out(o, r.shape) + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + out = result + # mypy does not see through the definition of out_type given that it's in a different scope + return out if is_tensor else return_type(*out) # type: ignore[operator] + + out_param = inspect.Parameter( + "out", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_type, + ) + # Mark that the function now returns a tuple + assert isinstance(sig.return_annotation, str) or sig.return_annotation in ( + sig.empty, + out_type, + ) + params = chain(sig.parameters.values(), (out_param,)) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=return_type # type: ignore[arg-type] + ) + + _fn.__annotations__ = fn.__annotations__ + _fn.__annotations__["out"] = out_type + _fn.__annotations__["return"] = return_type + + # In the special case of having a single tensor out parameter with a + # name other than out, add a special annotation to name the parameter + if is_tensor and out_names != default_out_names: + _fn.__annotations__[CustomOutParamAnnotation] = out_names[0] + + # Add an indicator attribute that can be used in special cases + # where having a function wrapped by `out_wrapper` is not desirable e.g. + # jit + _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] + + return _fn + + return _out_wrapper + + +def _maybe_remove_out_wrapper(fn: Callable): + return inspect.unwrap( + fn, + stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"), + ) + + +def backwards_not_supported(prim): + def redispatch_prim(args, kwargs): + with torch._C._AutoDispatchBelowAutograd(): + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + return prim(*args, **kwargs) + + class BackwardsNotSupported(torch.autograd.Function): + @staticmethod + def forward(ctx, args_spec, *flat_args): + args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type] + return redispatch_prim(args, kwargs) + + @staticmethod + def backward(ctx, *args): + raise RuntimeError("backwards not supported on prim") + + @wraps(prim) + def _autograd_impl(*args, **kwargs): + flat_args, args_spec = tree_flatten((args, kwargs)) + if torch.is_grad_enabled() and any( + a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) + ): + # TODO: There is a subtle bug here: prims like copy_to + # return their input argument after mutating it; and custom + # autograd function will incorrectly turn the result into + # a view which will fail test_python_ref_executor tests. + # At the moment, we sidestep this by observing that the + # unit tests don't ever try to run the executor with + # autograd, so we don't exercise the buggy case, but if + # you ever want to feed autograd through this, be aware + # of it! 
We need a way of properly implementing autograd + # for mutating operations in Python to do this. + return BackwardsNotSupported.apply(args_spec, *flat_args) + else: + return redispatch_prim(args, kwargs) + + return _autograd_impl + + +# TODO: when tracing this will add torch tensors and not TensorMeta objects +# to the trace -- we should fix this by adding a tracing context and NumberMeta classes +# TODO: this wrapper is currently untested +def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable: + """ + Allows unary operators that accept tensors to work with Python numbers. + """ + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], Number): + dtype = utils.type_to_dtype(type(args[0])) + args_ = list(args) + args_[0] = torch.tensor(args[0], dtype=dtype) + result = fn(*args_, **kwargs) + assert isinstance(result, torch.Tensor) + return result.item() + + return fn(*args, **kwargs) + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn diff --git a/MLPY/Lib/site-packages/torch/_python_dispatcher.py b/MLPY/Lib/site-packages/torch/_python_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..fb40dae036840fb44662d7ee13173f1c671be15e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_python_dispatcher.py @@ -0,0 +1,181 @@ +import re + +import torch._C as C + + +""" +PythonDispatcher class is a thin python-binding to C++ dispatcher and it +is designed to show how dispatcher precompute works. In particular, +it shows for a certain op `foo`, what the computed dispatch table looks +like after user register their kernels to certains dispatch keys. + +In the real C++ dispatcher we support many dispatch keys for different +functionalities. For simplicity PythonDispatcher only supports dispatch +keys for a single example of each use case. These use cases are listed below: + +- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & + autograd kernel in pytorch core library. + E.g. CPU, CUDA +- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific + inference kernels, but they share the same autograd kernel specified in AutogradOther. + E.g. FPGA, SparseCsrCPU +- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd + kernel defined in pytorch core library. Backend owner is responsible for registering both + inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. + E.g. XLA, XPU, MPS +- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc. + Kernels registered to this key MUST work for inference for all backends. +- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther. + Kernels registered to this key MUST work for autograd for all backends. +- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd + Kernels registered to this key MUST work for both inference + autograd for all backends. + +Note we only allow registrations to alias keys inside pytorch core library. E.g +you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd +kernel from torch-xla extension, instead you should upstream the kernel into +pytorch/pytorch repo so that it's available for all backends and continuously +tested even without the extension. 
+ +Usage: + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"]) + print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend. + # For more debugging information + # print(dispatcher.keys()) + # print(dispatcher.registrations()) + # print(dispatcher.rawRegistrations()) + # print(dispatcher.rawDispatchTable()) +PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table. +This file only provides the simplified API for developers, relevant test code is located in +test/test_dispatch.py +""" + + +class PythonDispatcher: + namespace = "__test__" + name = "foo" + # fmt: off + runtime_keys = [ + "CPU", "AutogradCPU", + "FPGA", "AutogradOther", + "XLA", "AutogradXLA", + "Lazy", "AutogradLazy", + ] + # fmt: on + alias_keys = [ + "CompositeExplicitAutograd", + "Autograd", + "CompositeImplicitAutograd", + ] + supported_keys = runtime_keys + alias_keys + + def __init__(self): + C._dispatch_check_invariants(self.name) # type: ignore[attr-defined] + self.ref = C._dispatch_library("FRAGMENT", self.namespace, "") + self.ref.def_("foo(Tensor x) -> Tensor") + + """ + Returns a list of dispatch keys supported by PythonDispatcher. + You can register kernels to these keys. + """ + + def keys(self): + return self.supported_keys + + """ + Register kernels to the target dispatchKeys. + dispatchKeys(list[str]): a list of dispatch keys that you want to register + your own kernel. Note that you don't need to write the kernel yourself in + this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is + automatically generated and registered. + """ + + def register(self, dispatchKeys): + # Overriden is not supported and triggers a warning in C++ dispatcher. + if len(set(dispatchKeys)) != len(dispatchKeys): + raise RuntimeError( + f"Overriden is not allowed but found duplicates in {dispatchKeys}." + ) + # We currently forbid this in codegen instead of C++ dispatcher. + if ( + "CompositeImplicitAutograd" in dispatchKeys + and "CompositeExplicitAutograd" in dispatchKeys + ): + raise RuntimeError( + "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed." + ) + for key in dispatchKeys: + if key not in self.supported_keys: + raise RuntimeError( + f"{key} is not supported, please select a dispatch key in {self.supported_keys}." + ) + self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key) + + """ + Helper function to format (key, kernel). + """ + + def _format_line(self, key, kernel): + return f"{key:<15} {kernel}\n" + + """ + Helper function to print a table header. + """ + + def _format_header(self, header): + s = f""" +{header} +""" + s += self._format_line("key", "kernel") + s += "---------------------------\n" + return s + + """ + Returns raw output of all registration info for debugging only. + Use registrations() for a simplified version. + """ + + def rawRegistrations(self): + return C._dispatch_dump(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] + + """ + Returns raw output of computed dispatch table for debugging only. + Use dispatchTable() for a simplified version. + """ + + def rawDispatchTable(self): + return C._dispatch_dump_table(f"{self.namespace}::{self.name}") # type: ignore[attr-defined] + + """ + Returns a table(str) including all the registrations from users. + Note this includes registrations to both runtime keys and alias keys. 
+ """ + + def registrations(self): + output = self._format_header("Registered Kernels") + state = self.rawRegistrations() + state_entries = state.split("\n") + for line in state_entries: + first = line.split(":")[0] + if any(first.startswith(k) for k in self.supported_keys): + kernel = line.split("::")[0].split(" ")[1] + output += self._format_line(first, kernel) + return output + + """ + Returns the computed dispatch table(str). Note this only include + runtime keys, registrations to alias keys have been decoded to their + mapped runtime keys. + """ + + def dispatchTable(self): + output = self._format_header("Computed Dispatch Table") + table = self.rawDispatchTable() + table_entries = table.split("\n") + regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)") + for line in table_entries: + k = line.split(":")[0] + if k in self.runtime_keys: + entry = regex.sub("[", line) + output += self._format_line(k, entry.split(": ")[1]) + return output diff --git a/MLPY/Lib/site-packages/torch/_refs/__init__.py b/MLPY/Lib/site-packages/torch/_refs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a96d1f82d8156c14cb8ceb254d6057127b1f053 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/__init__.py @@ -0,0 +1,6443 @@ +import builtins +import collections +import inspect +import itertools +import math +import operator +import warnings + +from collections.abc import Iterable +from enum import Enum +from functools import partial, reduce, singledispatch, wraps +from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union + +import torch + +import torch._prims as prims +import torch._prims_common as utils +from torch import sym_float, sym_int +from torch._prims_common import ( + DeviceLikeType, + Dim, + DimsSequenceType, + DimsType, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + FloatWithoutSymFloat, + IntLike, + is_weakly_lesser_type, + Number, + NumberType, + RealNumberType, + REDUCTION_OUTPUT_TYPE_KIND, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + TensorOrNumberLikeType, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) + +# Experimental module containing prototype Python references for existing +# PyTorch operations. 
+ +__all__ = [ + # + # Elementwise Unary References + # + "abs", + "acos", + "acosh", + "asinh", + "asin", + "atan", + "atanh", + "bitwise_not", + # "cbrt", # No corresponding torch operation + "ceil", + "conj_physical", + "cos", + "cosh", + "count_nonzero", + "deg2rad", + "digamma", + "erf", + "erfinv", + "erfc", + "exp", + "expm1", + "exponential", + "exp2", + "fill", + "fill_", + "floor", + "frac", + "geometric", + "index_add", + "index_copy", + "index_copy_", + "index_select", + "index_fill", + "index_fill_", + "isfinite", + "isinf", + "isposinf", + "isneginf", + "isnan", + "isreal", + "i0", + "lerp", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "log_normal", + "log_softmax", + "mvlgamma", + "norm", + "normal", + "nan_to_num", + "neg", + "positive", + "rad2deg", + "reciprocal", + "round", # TODO: model kwargs + "sigmoid", + "sgn", + "sign", + "signbit", + "sin", + "sinc", + "sinh", + "softmax", + "sqrt", + "square", + "tan", + "tanh", + "trace", + "trunc", + # + # Elementwise Binary References + # + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "clamp_min", + "clamp_max", + "copysign", + "div", + "eq", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "ge", + "gt", + "heaviside", + "hypot", + "igamma", + "igammac", + "imag", + "isclose", + "lcm", + # 'ldexp', + "le", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logsumexp", + "lt", + # 'max', # implement with reductions + "maximum", + # 'min', # implement with reductions + "minimum", + "mul", + "ne", + "nextafter", + # 'polar', # abs, cos, sin + "pow", + "real", + "rpow", + "remainder", + "rsub", + "rtruediv", + "rfloordiv", + "sub", + "true_divide", + "trunc_divide", + "xlogy", + # + # Elementwise Ternary References + # + "addcdiv", + "addcmul", + "clamp", + # + # Conditional references + # + "masked_fill", + "masked_fill_", + "where", + # + # Data conversion and movement references + # + "clone", + "copy_to", # TODO: add OpInfo (or implement .to) + "item", + "to", + # + # Reduction ops + # + "all", + "amax", + "amin", + "any", + "cumsum", + "cumprod", + "mean", + "dot", + "vdot", + "std", + "std_mean", + "sum", + "sum_to_size", + "prod", + "var", + "var_mean", + # + # Linear algebra ops + # + "addr", + # + # View & Shape Ops + # + "alias", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "as_strided", + "as_strided_scatter", + "block_diag", + "broadcast_shapes", + "broadcast_tensors", + "broadcast_to", + "cat", + "chunk", + "column_stack", + "conj", + "constant_pad_nd", + "contiguous", + "diag_embed", + "diag", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "dsplit", + "dstack", + "expand", + "expand_as", + "flatten", + "flip", + "fliplr", + "flipud", + "hsplit", + "hstack", + "meshgrid", + "movedim", + "narrow", + "narrow_copy", + "native_group_norm", + "native_layer_norm", + "permute", + "ravel", + "repeat", + "reshape", + "reshape_as", + "roll", + "rot90", + "rsqrt", + "stack", + "swap_axes", # alias for transpose + "squeeze", + "t", + "T", + "take_along_dim", + "tensor_split", + "transpose", + "unfold", + "unfold_copy", + "unsqueeze", + "view", + "view_as", + "vsplit", + "vstack", + "view_as_complex", + "unflatten", + "unbind", + "triu", + "tril", + "triu_indices", + "tril_indices", + # + # Tensor Creation + # + "arange", + "cauchy", + "empty", + "empty_like", + "empty_permuted", + "empty_strided", + "eye", + "full", + "full_like", + "linspace", + 
"logspace", + "new_empty", + "new_empty_strided", + "new_full", + "new_ones", + "new_zeros", + "ones", + "ones_like", + "randn", + "scalar_tensor", + "zero", + "zeros", + "zeros_like", + # + # Test-related functions + # + "allclose", + "equal", + # + # Statistical operations + # + "bucketize", + # + # Misc + # + "is_complex", + "renorm", + "stft", + "istft", +] + +Tensor = torch.Tensor +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] +aten = torch._ops.ops.aten + +# Note that the docstrings for the public methods from this file are in +# torch/_torch_docs.py + + +def is_noncontiguous_supported(device): + if device is not None and device.type == "hpu": + return False + return True + + +def handle_noncontiguous_outputs(input_tlist, output): + device = None + from torch._subclasses.fake_tensor import FakeTensor + + for t in input_tlist: + if isinstance(t, FakeTensor): + device = t.fake_device + break + + if not is_noncontiguous_supported(device): + output = output.contiguous() + + return output + + +def _broadcast_shapes(*_shapes): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x + for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape = [ + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" + ) + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + if common_shape[idx] != shape[idx]: + raise RuntimeError( + f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}" + ) + + return common_shape + + +def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): + # Computes common shape + common_shape = _broadcast_shapes( + *(t.shape if isinstance(t, TensorLike) else None for t in args) + ) + + def __maybe_broadcast(x, shape): + if x is None: + return None + elif isinstance(x, Number): + return x + elif isinstance(x, TensorLike): + if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): + return x + + if not utils.same_shape(x.shape, common_shape): + return x.expand(common_shape) + + return x + else: + raise RuntimeError( + "Unexpected type when broadcasting: " + str(type(x)) + "!" 
+ ) + + return tuple(__maybe_broadcast(x, common_shape) for x in args) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +# +# Elementwise unary references +# + +infer_aten_op = object() + + +# TODO: add type promotion support +def _make_elementwise_unary_reference( + type_promotion_kind, + *, + aten_op=infer_aten_op, + extra_meta=None, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op + + @wraps(prim) + @out_wrapper() + @elementwise_unary_scalar_wrapper + @elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=type_promotion_kind, + ) + def _ref(a: TensorLikeType) -> TensorLikeType: + if extra_meta is not None: + extra_meta(a) + + output = prim(a) + return handle_noncontiguous_outputs([a], output) + + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, prim.__name__) + if aten_op is not None: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +def _make_alias(fn, name): + """ + This function defines an alias of another function and sets its __name__ argument. + It also sets its __module__ argument to the module of the caller. + Note that when naïvely doing `alias = fn`, we have that `alias.__name__ == "fn"`, and + `alias.__module__ == fn.__module__`. + """ + + def _fn(*args, **kwargs): + return fn(*args, **kwargs) + + _fn.__name__ = name + _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"] # type: ignore[union-attr] + return _fn + + +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(aten, inplace_name))(_fn) + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... 
+ from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT) +def abs(a): + return prims.abs(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acos(a): + return prims.acos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acosh(a): + return prims.acosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asin(a): + return prims.asin(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asinh(a): + return prims.asinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atan(a): + return prims.atan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atanh(a): + return prims.atanh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def bitwise_not(a): + return prims.bitwise_not(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def ceil(a): + return prims.ceil(a) + + +@register_decomposition(aten.is_complex) +def is_complex(input: TensorLikeType): + return utils.is_complex_dtype(input.dtype) + + +@register_decomposition(aten.conj_physical) +@out_wrapper() +def conj_physical(input: TensorLikeType): + if not utils.is_complex_dtype(input.dtype): + return input + return prims.conj_physical(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cos(a): + return prims.cos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cosh(a): + return prims.cosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def digamma(a): + return prims.digamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erf(a): + return prims.erf(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfinv(a): + return prims.erf_inv(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfc(a): + return prims.erfc(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp(a): + return prims.exp(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def expm1(a): + return prims.expm1(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp2(a): + return prims.exp2(a) + + +# Fill has its own implementation because it has a value parameter +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a,"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(value, Number) + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(value), python_type): + msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + + return prims.fill(a, value) + + +def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType: + r = prims.fill(a, value) + prims.copy_to(a, r) + return a + + +@register_decomposition(aten.zero) +@out_wrapper() +def zero(input: TensorLikeType) -> TensorLikeType: + return torch.zeros_like(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def floor(a): + return prims.floor(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def frac(x: TensorLikeType) -> TensorLikeType: + trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x)) + return torch.sub(x, trunc_x) + + +# imag does not use _make_elementwise_unary_reference because it does not support out +def imag(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + torch._check( + utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." + ) + return prims.imag(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isfinite(a: TensorLikeType) -> TensorLikeType: + if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype): + return prims.isfinite(a) + + return ones_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isinf(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a))) + if utils.is_float_dtype(a.dtype): + return torch.abs(a) == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isposinf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isneginf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("-inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isnan(a: TensorLikeType) -> TensorLikeType: + return prims.ne(a, a) + + +# alias +mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isreal(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.imag(a) == 0 + return torch.ones_like(a, dtype=torch.bool) + + +# TODO: if this is special maybe it should be defined there and imported here? 
+@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0 +) +def i0(a): + return prims.bessel_i0(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def lgamma(a): + return prims.lgamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log(a): + return prims.log(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log1p(a): + return prims.log1p(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log2(a): + return prims.log2(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log10(a): + return prims.log10(a) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logsumexp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logsumexp( + self: TensorLikeType, dim: DimsType, keepdim: bool = False +) -> TensorLikeType: + if not isinstance(dim, Iterable): + dim = (dim,) + if self.numel() == 0: + return torch.sum(torch.exp(self), dim, keepdim).log() + maxes = torch.amax(self, dim, keepdim=True) + maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) + maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) + result = torch.sum(torch.exp(self - maxes), dim, keepdim) + return result.log().add(maxes_squeezed) + + +@register_decomposition(aten.nan_to_num) +@out_wrapper() +def nan_to_num( + a: TensorLikeType, + nan: Optional[NumberType] = 0.0, + posinf: Optional[NumberType] = None, + neginf: Optional[NumberType] = None, +) -> TensorLikeType: + assert isinstance(a, TensorLike) + + if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + return a.clone() + + if nan is None: + nan = 0.0 + + if posinf is None: + posinf = torch.finfo(a.dtype).max + + if neginf is None: + neginf = torch.finfo(a.dtype).min + + result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload] + result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload] + result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload] + return result + + +def _neg_meta(a: TensorLikeType): + torch._check( + a.dtype is not torch.bool, + lambda: ( + "Negation, the `-` operator, on a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` " + "operator instead." + ), + ) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta +) +def neg(a): + return prims.neg(a) + + +# positive does not use _make_elementwise_unary_reference because it does not support out +# CompositeImplicitAutograd - don't register decomp +def positive(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if a.dtype is torch.bool: + msg = "positive does not support bool tensors." 
+ raise RuntimeError(msg) + return a + + +# real does not use _make_elementwise_unary_reference because it does not support out +def real(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if utils.is_complex_dtype(a.dtype): + return prims.real(a) + return a + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def reciprocal(a): + return prims.reciprocal(a) + + +@register_decomposition(aten.round) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType: + if decimals == 0: + return prims.round(a) + else: + ten_pow = 10**decimals + ten_neg_pow = 10 ** (-decimals) + return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rsqrt(a): + return prims.rsqrt(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sigmoid(a: TensorLikeType) -> TensorLikeType: + return true_divide(1, add(1, exp(neg(a)))) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def sgn(a): + if utils.is_complex_dtype(a.dtype): + a_abs = a.abs() + return torch.where(a_abs == 0, 0, a / a_abs) + else: + return a.sign() + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def sign(a): + return prims.sign(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def signbit(a): + return prims.signbit(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sin(a): + return prims.sin(a) + + +# Autograd note: This will give the right first derivative at zero (by chance), +# but not the right second derivative +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinc(a): + a = math.pi * a + return torch.where(a == 0, 1, torch.sin(a) / a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinh(a): + return prims.sinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sqrt(a): + return prims.sqrt(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, + aten_op=None, # CompositeImplicitAutograd, +) +def square(a: TensorLikeType) -> TensorLikeType: + return mul(a, a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tan(a): + return prims.tan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tanh(a): + return prims.tanh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def trunc(a): + return prims.trunc(a) + + +# TODO: register this as a real ref/decomposition once TorchInductor supports complex! 
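+# Illustrative sketch for two of the refs above (doctest-style comment only; the torch._refs
+# module path is assumed, and the printed values are the expected eager results):
+#
+#   >>> torch._refs.round(torch.tensor([1.2345]), decimals=2)  # round(a * 10**2) * 10**-2
+#   tensor([1.2300])
+#   >>> torch._refs.sgn(torch.tensor([3 + 4j]))                # a / abs(a) for complex input
+#   tensor([0.6000+0.8000j])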
+def view_as_complex(self: TensorLikeType) -> TensorLikeType: + input_dtype = self.dtype + torch._check( + utils.is_float_dtype(input_dtype), + lambda: f"view_as_complex is only supported for floating point" + f"tensors, but got a tensor of scalar type: {input_dtype}", + ) + sizes = self.size() + torch._check( + len(sizes) != 0, + lambda: "Input tensor must have one or more dimensions", + ) + torch._check( + sizes[-1] == 2, + lambda: "Tensor must have a last dimension of size 2", + ) + + old_strides = self.stride() + torch._check( + old_strides[-1] == 1, + lambda: "Tensor must have a last dimension with stride 1", + ) + dims = old_strides[:-1] + torch._check( + py_all(stride % 2 == 0 for stride in dims), + lambda: "Tensor must have a stride divisible by 2 for all but last dimension", + ) + torch._check( + self.storage_offset() % 2 == 0, + lambda: "Tensor must have a storage_offset divisible by 2", + ) + return prims.view_element_type( + self, utils.corresponding_complex_dtype(input_dtype) + ).squeeze(-1) + + +def _make_elementwise_binary_reference( + type_promotion_kind, + aten_op=infer_aten_op, + name=None, + has_out=True, + supports_lhs_python_scalar=True, + supports_rhs_python_scalar=True, + supports_two_python_scalars=False, + should_register_decomposition=True, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op, name + if name is None: + name = prim.__name__ + + @wraps(prim) + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + def _ref( + a: Union[Tensor, NumberType], + b: Union[Tensor, NumberType], + ) -> Tensor: + torch._check_value( + supports_lhs_python_scalar or not isinstance(a, Number), + lambda: f"{name}: Received a lhs Python scalar to an elementwise binary " + "operation that does not accept lhs scalars!", + ) + torch._check_value( + supports_rhs_python_scalar or not isinstance(b, Number), + lambda: f"{name}: Received a rhs Python scalar to an elementwise binary " + "operation that does not accept rhs scalars!", + ) + torch._check_value( + supports_two_python_scalars + or not (isinstance(a, Number) and isinstance(b, Number)), + lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", + ) + a, b = _maybe_broadcast(a, b) + output = prim(a, b) + return handle_noncontiguous_outputs([a, b], output) + + if has_out: + _ref = out_wrapper()(_ref) + + _ref.__name__ = name + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, name) + if aten_op is not None and should_register_decomposition: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +# Add has its own implementation because it has an alpha argument +@register_decomposition(aten.add) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def add( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: Optional[NumberType] = None, +): + """ + Reference implementation of torch.add + """ + + a, b = _maybe_broadcast(a, b) + + if alpha is not None: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if python_type != bool and not utils.is_weakly_lesser_type( + type(alpha), python_type + ): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" 
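+            # Illustrative, hypothetical call: add(torch.ones(2, dtype=torch.int64), 1, alpha=0.5)
+            # reaches this branch, since a float alpha cannot be safely cast to the integer result type.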
+ raise ValueError(msg) + if isinstance(b, TensorLike): + b = prims.mul(b, alpha) + else: + b = b * alpha + + output = prims.add(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def atan2(a, b): + return prims.atan2(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_and(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_left(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_or(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_right_arithmetic(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_xor(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, +) +def copysign( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + if isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + msg = "Expected divisor (b) to be on the same device ({}) as dividend (a), but it is found on {}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + return where(signbit(b), neg(abs(a)), abs(a)) + + +# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + + +@register_decomposition(aten.div) +@out_wrapper() +def div( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + rounding_mode: Optional[str] = None, +): + """ + Reference implementation of torch.div + """ + if rounding_mode is None: + return true_divide(a, b) + elif rounding_mode == "trunc": + return trunc_divide(a, b) + elif rounding_mode == "floor": + return floor_divide(a, b) + else: + msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}." 
+ raise ValueError(msg) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.eq(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) +def pow( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> TensorLikeType: + assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType) + + if isinstance(b, Number): + if b == 1.0: + return a.clone() # type: ignore[return-value,union-attr] + elif b == 2.0: + return a * a # type: ignore[return-value] + elif b == 0.5: + return torch.sqrt(a) # type: ignore[arg-type] + elif isinstance(a, Number): + if a == 1.0: + return torch.fill(b, True) + if a == 2.0 and ( + utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype) + ): + return torch.exp2(b) + + return prims.pow(a, b) + + +# Float power has its own implementation because it has unique type promotion. +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def float_power( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> Tensor: + if isinstance(a, Number) and isinstance(b, Number): + raise ValueError( + "Receive two Number inputs to an elementwise binary operation!" + ) + + # Handles type promotion + dtype = utils.get_higher_dtype(a, b) + assert dtype is not None + if utils.is_complex_dtype(dtype): + dtype = torch.complex128 + else: + dtype = torch.float64 + + # Float power has the following contiguous cast behavior to be + # consistent with its C++ impl + a = _maybe_convert_to_dtype(a, dtype) + b = _maybe_convert_to_dtype(b, dtype) + + a, b = _maybe_broadcast(a, b) + return pow(a, b) + + +# >>> a = torch.tensor(-0.2500, dtype=torch.float64) +# tensor(-0.250000000000000, dtype=torch.float64) +# +# >>> b = torch.tensor(-0.0010, dtype=torch.float64) +# tensor(-0.001000000000000, dtype=torch.float64) +# +# Note: In this case, casting float to double will expand the float mantissa with zeros, +# while creating a double generates a distinct mantissa. +# >>> torch.tensor(-0.001).to(dtype=torch.float64) +# tensor(-0.001000000047497, dtype=torch.float64) +# +# Floor Division +# The difference is caused because torch.remainder(a, b) = -0.001. +# +# >>> torch.floor(torch.true_divide(a, b)) +# tensor(250., dtype=torch.float64) +# +# >>> torch.div(a, b, rounding_mode='floor') +# tensor(249., dtype=torch.float64) +# +# Definition: a // b = (a - remainder(a, b)) / b +# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b) +# tensor(249., dtype=torch.float64) +# +# For reference, see CPython's implementation: +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, + should_register_decomposition=False, +) +def floor_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + # Wrap scalars because some references only accept tensor arguments. 
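+    # Illustrative: floor_divide(5, 2) wraps both Python scalars as 0-d tensors, while a mixed
+    # call such as floor_divide(torch.tensor([5.0]), 2) wraps only the scalar, reusing the
+    # tensor operand's dtype and device.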
+ if isinstance(a, Number) and isinstance(b, Number): + a = scalar_tensor(a) + b = scalar_tensor(b) + elif isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Number) and isinstance(b, Tensor): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + if a.device == torch.device("cpu"): + msg = "Expected divisor (b) to be on the same device ({}) as dividend (a), but it is found on {}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + else: + b = prims.device_put(b, device=a.device) + + assert isinstance(a, Tensor) and isinstance(b, Tensor) + dtype = a.dtype + if utils.is_float_dtype(dtype): + return _floor_divide_float(a, b) + elif utils.is_integer_dtype(dtype): + return _floor_divide_integer(a, b) + else: + torch._check(False, lambda: f"{dtype} not supported for floor_divide") + + +def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: + a, b = _maybe_broadcast(a, b) + + if not a.dtype.is_signed: + return prims.div(a, b) + + # Convert truncation to flooring: + offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + + +def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: + mod = fmod(a, b) + div = true_divide(sub(a, mod), b) + + # Ensure that the remainder has the same sign as denominator + different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0)) + non_zero_remainder = ne(mod, 0) + mask = bitwise_and(non_zero_remainder, different_signed_inputs) + div = where(mask, sub(div, 1), div) + + # Map quotient to nearest integer value + floor_div = floor(div) + mask = gt(sub(div, floor_div), 0.5) + floor_div = where(mask, add(floor_div, 1), floor_div) + + basic_div = true_divide(a, b) + zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device) + + # If quotient is zero, copy signbit from true_divide quotient + floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div)) + + # If denominator is zero, then follow true_divide behavior + return where(ne(b, 0), floor_div, basic_div) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=True, +) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + + +@register_decomposition(aten.frexp) +@out_wrapper("mantissa", "exponent") +def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]: + return torch.return_types.frexp(prims.frexp(self)) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + + +@_make_elementwise_binary_reference( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: + input_eq_zero = torch.eq(input, 0) + input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input)) + zeros_and_ones = torch.where(input_lt_zero, 0, 1) + output = torch.where(input_eq_zero, values, zeros_and_ones) + return output + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) + + +def _check_close_args( + name: str, + a: TensorLikeType, + b: TensorLikeType, + rtol: float, + atol: float, +) -> None: + torch._check_value( + a.dtype == b.dtype, + lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!", + ) + torch._check( + rtol >= 0, + lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!", + ) + torch._check( + atol >= 0, + lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!", + ) + + +# CompositeImplicitAutograd - don't register decomp +def isclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> TensorLikeType: + _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol) + + close = eq(a, b) + if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)): + close = logical_or(close, logical_and(isnan(a), isnan(b))) + + # Note: In case of zero tolerances the closeness inequality degenerates to an equality check. + # In this case, the short-circuit prevents false positives as detailed in the paragraph below. + if atol == 0 and rtol == 0: + return close + + # Note [closeness error computation] + # atol and rtol are provided as doubles, so the computation + # rtol * other will produce a float or complex tensor. + # When the difference (self - other) is compared to it then the + # tensor representing the difference will also be cast to float or complex. + # However, since (self - other) in uint8 is very likely to produce a + # negative value, this moves the cast forward so the difference is + # always computed in a float or complex type. 
+ # If the values of the integer tensors cannot be exactly represented + # by the default scalar type then this may cause an incorrect result. + if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype): + a = prims.convert_element_type(a, torch.get_default_dtype()) + b = prims.convert_element_type(b, torch.get_default_dtype()) + + allowed_error = add(atol, abs(mul(b, rtol))) + actual_error = abs(sub(a, b)) + + # Computes finite closeness + result = logical_or( + close, logical_and(isfinite(actual_error), le(actual_error, allowed_error)) + ) + + return result + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): + dtype = a.dtype + # promoting to int32 to maintain 100% consistency with C++ and to + # prevent overflow in case of int8 and int16 + promote_to_int = dtype in (torch.int8, torch.int16) + if promote_to_int: + a = prims.convert_element_type(a, torch.int32) + b = prims.convert_element_type(b, torch.int32) + + g = torch.gcd(a, b) + # Avoid division by zero in case gcd(0, 0) == 0 + g = torch.where(g == 0, 1, g) + res = torch.abs(prims.div(a, g) * b) + return res if not promote_to_int else prims.convert_element_type(res, dtype) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + # Nb. this implementation does not distribute the gradients evenly when a == b + mask = torch.real(a) >= torch.real(b) + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and( + torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b) + ) + if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype): + # are you wondering what this bunch of codes are for? edge cases! + neg_min_mask = torch.real(min_) < 0 + inf_vals = torch.where( + neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_)) + ) + non_nan_vals = torch.where( + inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_)) + ) + # the type for full_like does not include tensor yet + nan_mask = torch.isnan(min_) + return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload] + else: + return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_))) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + torch._check( + not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)), + lambda: "logaddexp2 doesn't support complex dtypes", + ) + # Nb. 
this implementation does not distribute the gradients evenly when a == b + mask = a >= b + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and(torch.isinf(a), a == b) + inv_log_2 = 1.0 / math.log(2) + result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2 + return torch.where(inf_mask, a, result) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_and(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a & b + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def logical_not(a: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + return a == 0 + return ~a + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return bitwise_or(a, b) + + +# TODO: skip unnecessary conversion of long to float +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_xor(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a ^ b + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + + +# reverse sub +@register_decomposition(aten.rsub) +@out_wrapper() +def rsub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + alpha: NumberType = 1, +): + if isinstance(a, Number): + msg = "Received a Number for the first argument, but expected a Tensor" + raise ValueError(msg) + + return torch.sub(b, a, alpha=alpha) + + +# TODO: consider refactoring this with add impl +# sub has its own implementation 
because it has an alpha argument +@register_decomposition(aten.sub) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def sub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: NumberType = 1, +): + """ + Reference implementation of torch.sub + """ + + a, b = _maybe_broadcast(a, b) + + if alpha != 1: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + if isinstance(b, torch.Tensor): + b = prims.mul(b, alpha) + else: + # Carefully not to use prims.mul if b is a scalar / symint. + # prims.mul always returns a tensor, + # which will mess with type promotion. + b = b * alpha + + output = prims.sub(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) + + +@register_decomposition(aten.xlogy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
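+    # Illustrative: with this convention xlogy(0, b) evaluates to 0 even when b == 0 (rather than
+    # 0 * log(0) == nan), while a NaN in b still propagates because the final where() checks isnan(b).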
+ if isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + dtype = utils.get_dtype(a) + if utils.is_integer_dtype(dtype): + return prims.div(a, b) + + return trunc(prims.div(a, b)) + + +# +# Elementwise Ternary References +# + + +@register_decomposition(aten.addcdiv) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def addcdiv( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcdiv + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 / tensor2 + + +@register_decomposition(aten.addcmul) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addcmul( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcmul + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 * tensor2 + + +@register_decomposition(aten.clamp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "min", "max"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def clamp( + a: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + # NOTE: grad behavior with implementation `where` is not consistent on `nan` + if min is None and max is None: + msg = "clamp called but both min and max are none!" + raise ValueError(msg) + if min is not None: + a_isnan = torch.isnan(a) + condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type] + # we should also propagate `nan` coming from boundaries. However, that's + # not necessary since `ge` would already `False` when either operands has + # a `nan`. 
So this line below is redundant + # `condition = bitwise_and(condition, bitwise_not(isnan(min)))` + a = torch.where(condition, a, min) # type: ignore[arg-type] + if max is not None: + a_isnan = torch.isnan(a) + # same as above, no need to adjust `nan` from `max` + condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type] + a = torch.where(condition, a, max) # type: ignore[arg-type] + + return a + + +@register_decomposition(aten.clamp_min) +@out_wrapper() +def clamp_min( + self: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, min=min) # type: ignore[arg-type] + + +@register_decomposition(aten.clamp_max) +@out_wrapper() +def clamp_max( + self: TensorLikeType, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, max=max) # type: ignore[arg-type] + + +# +# Conditional references +# + + +# https://pytorch.org/docs/stable/generated/torch.where.html +# TODO: implement alternate where +@register_decomposition(aten.where) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def where( + pred: Tensor, + a: Optional[TensorOrNumberLikeType] = None, + b: Optional[TensorOrNumberLikeType] = None, +): + """ """ + + if a is None or b is None: + raise NotImplementedError + + utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) + torch._check( + pred.dtype is torch.bool, + lambda: f"expected predicate to be bool, got {pred.dtype}", + ) + + pred, a, b = _maybe_broadcast(pred, a, b) + return prims.where(pred, a, b) + + +# +# Data Movement References +# +@register_decomposition(aten.clone) +@out_wrapper() +def clone( + a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + result = prims.clone(a, memory_format=memory_format) + return result + + +def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): + if not allow_cross_device and a.device != b.device: + msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format( + b.device, a.device + ) + raise RuntimeError(msg) + + return prims.copy_to(a, b) + + +@register_decomposition(aten.item) +def item(a: TensorLikeType) -> NumberType: + if a.numel() != 1: + msg = f"Can't convert a tensor with {a.numel()} elements to a number!" + raise ValueError(msg) + + # NOTE: explicit conversion is necessary for bool! + # See https://github.com/pytorch/pytorch/issues/78071 + number_type = utils.dtype_to_type(a.dtype) + return number_type(prims.item(a)) + + +# fast path when `to` returns an alias to input. 
This mimics the same function in aten +def _to_will_alias( + a: TensorLikeType, + device: Optional[DeviceLikeType] = None, + dtype: Optional[torch.dtype] = None, + copy: Optional[bool] = None, + layout: Optional[torch.layout] = None, + memory_format: Optional[torch.memory_format] = None, + pin_memory: Optional[bool] = False, + non_blocking: bool = False, # not using non_blocking +) -> bool: + return ( + not copy + and (device is None or a.device == device) + and (dtype is None or a.dtype == dtype) + and (layout is None or a.layout == layout) + # is_pinned issue #84925 + # and (pin_memory is None or pin_memory == a.is_pinned()) + and ( + memory_format is None + or memory_format == torch.preserve_format + or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) + ) + ) + + +@singledispatch +def _to_dispatch(*args, **kwargs): + raise NotImplementedError + + +@_to_dispatch.register +def _to_device( + device: torch.device, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "device": device, + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_device_str( + device: str, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "device": torch.device(device), + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_dtype( + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_other( + other: Tensor, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + device = other.device + dtype = other.dtype + layout = other.layout + # is_pinned issue #84925 + # pin_memory = other.is_pinned() + kwargs = { + "device": device, + "dtype": dtype, + "layout": layout, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +# remove to_kwargs that is already present in `a` +def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict): + options_to_check = ["dtype", "device", "layout", "memory_format"] + # "device" option could be passed a str instead torch.device + if "device" in to_kwargs and isinstance(to_kwargs["device"], str): + to_kwargs["device"] = torch.device(to_kwargs["device"]) + + for kw in options_to_check: + if kw in to_kwargs: + if ( + (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format) + or ( + kw == "device" + and to_kwargs[kw].type == a.device.type + and ( + not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index + ) + ) + or ( + getattr(a, kw, None) == to_kwargs[kw] + ) # this also handles {"memory_format": None} + ): + to_kwargs.pop(kw) + + +def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType: + # handled dispatch via positional arguments + if len(args) != 0: + kwargs = _to_dispatch(*args, **kwargs) + + # TODO: is_pinned is not currently supported in refs or fake_tensor + # https://github.com/pytorch/pytorch/issues/84925 + assert "pin_memory" not 
in kwargs + _canonicalize_to_arguments(a, kwargs) + + if _to_will_alias(a, **kwargs): + return a + + copy = kwargs.pop("copy") if "copy" in kwargs else False + non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False + + # short-circuit to `prims.convert_element_type` when `to` is just a dtype change + if ( + (copy or (kwargs.get("dtype", a.dtype) != a.dtype)) + and (not non_blocking) + and ("memory_format" not in kwargs) + and ("device" not in kwargs) + and ("layout" not in kwargs) + # is_pinned issue #84925 + # and ("pin_memory" not in kwargs) + ): + return prims.convert_element_type(a, kwargs.get("dtype", a.dtype)) + + result = torch.empty_like(a, **kwargs) + # TODO: non_blocking should be handled by `copy_to` + copy_to(result, a) + return result + + +# +# Reduction references +# + + +def _reduction( + a: TensorLikeType, + prim: Callable, + *, + has_identity: bool = True, + accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only + dims: Optional[DimsType] = None, + keepdims: bool = False, + dtype: Optional[torch.dtype] = None, # should be specified for ops that support it + out: Optional[Tensor] = None, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, +) -> TensorLikeType: # it is usually SAME, but I want + # ref writers to actually think about what to put here + assert isinstance(a, TensorLike) + if a.ndim > 64: + raise RuntimeError( + f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + + if out is not None: + assert isinstance(out, TensorLike) + if dtype is not None: + # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms + if dtype != out.dtype: + raise RuntimeError( + "dtype argument and out dtype must match in reduction" + ) + if not accepts_dim_tuple: + assert dims is None or isinstance(dims, Dim) + if isinstance(dims, Dim): + dims = (dims,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dims) + if not has_identity: + valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims) + if not valid_shape: + raise RuntimeError( + "reducing over zero-size dimension for reduction operation without identity" + ) + computation_dtype, result_dtype = utils.reduction_dtypes( + a, output_dtype_kind, dtype + ) + a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign] + result = prim(a, dims) + if keepdims: + output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)] + broadcast_dims = [i for i in range(a.ndim) if i not in dims] + result = prims.broadcast_in_dim(result, output_shape, broadcast_dims) + + if out is not None: + assert result_dtype is not None + if dtype is not None and result_dtype != out.dtype: + raise RuntimeError( + "Expected the dtype of reduction result and out to match" + ) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + + if result.dtype != result_dtype and result_dtype is not None: + result = prims.convert_element_type(result, result_dtype) + + return result + + +def _make_copy_from_view(fn): + """ + Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. 
torch.diagonal_copy) + """ + name = fn.__name__ + fn = out_wrapper()(fn) + + def _fn(*args, out=None, **kwargs): + result = fn(*args, out=out, **kwargs) + if out is None: + return result.clone(memory_format=torch.contiguous_format) + return result + + copy_name = f"{name}_copy" + _fn.__name__ = copy_name + _fn = register_decomposition(getattr(aten, copy_name))(_fn) + return _fn + + +# Saves Python all +py_all = all + + +@register_decomposition(aten.all) +@out_wrapper() +def all( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) + + if a.dtype == torch.uint8: + result = result.to(dtype=torch.uint8) + + return result + + +# Saves Python any +py_any = any + + +@register_decomposition(aten.any) +@out_wrapper() +def any( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + a_ = _maybe_convert_to_dtype(a, torch.bool) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + result = a_.clone() + else: + result = a_.sum(dim=dim, keepdim=keepdim).ne(False) + + # Preserves uint8 -- probably a legacy mask thing + if a.dtype is torch.uint8: + return prims.convert_element_type(result, torch.uint8) + + return result + + +@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out]) +def sum( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def sum_to_size( + a: Tensor, + *shape, +) -> Tensor: + shape = utils.extract_shape_from_varargs(shape, validate=False) + torch._check( + utils.is_expandable_to(shape, a.shape), + lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', + ) + # In ATen scalar tensors are sent through sum and the result is returned as + # type promoted + if utils.is_same_shape(shape, a.shape) and len(shape) > 0: + return prims.view_of(a) + leading_dims = a.ndim - len(shape) + reduce_dims = tuple(range(leading_dims)) + tuple( + i + for i in range(leading_dims, len(shape)) + if shape[i - leading_dims] == 1 and a.shape[i] != 1 + ) + return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None) + + +@register_decomposition(aten.prod) +def prod( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + keepdim: bool = False, + *, + dtype=None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.prod, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amin) +def amin( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: 
Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amin, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amax) +def amax( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amax, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def _dim_var_dispatch(dim=None, unbiased=None): + # There's the following overload of torch.var: + # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + # We need to explicitly convert bool dims to unbiased arg + if unbiased is None and isinstance(dim, bool): + unbiased = dim + dim = None + return dim, unbiased + + +@register_decomposition(aten.var) +@out_wrapper() +def var( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + result = _reduction( + a, + partial(prims.var, correction=correction), + dims=dim, + keepdims=keepdim, + dtype=None, + out=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + +@register_decomposition(aten.std) +@out_wrapper() +def std( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var = torch.var(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return _maybe_convert_to_dtype(a_std, dtype) + + +@register_decomposition(aten.mean) +def mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype=None, + out=None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + orig_dtype = dtype + if dtype is None: + dtype = a.dtype + # can't use out wrapper because of this argument + torch._check( + out is None or out.dtype == dtype, + lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", + ) + result = _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=None, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + ) + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: ( + f"mean(): could not infer output dtype. " + f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " + f"a floating point or complex dtype. 
Got: {dtype}" + ), + ) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] + nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) + result = true_divide(result, nelem) + result_dtype = a.dtype if dtype is None else dtype + result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign] + if out is not None: + assert isinstance(out, TensorLike) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + return result + + +@register_decomposition(aten.std_mean) +@out_wrapper("out0", "out1") +def std_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + *, + unbiased: Optional[bool] = None, + keepdim: bool = False, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + original_dtype = a.dtype + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return ( + _maybe_convert_to_dtype(a_std, dtype), + _maybe_convert_to_dtype(a_mean, original_dtype), + ) + + +@register_decomposition(aten.var_mean) +@out_wrapper("out0", "out1") +def var_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + v = var(a, dim, unbiased, keepdim, correction=correction) + m = mean(a, dim, keepdim) + return v, m + + +@register_decomposition(aten.addr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "vec1", "vec2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addr( + self: TensorLikeType, + vec1: TensorLikeType, + vec2: TensorLikeType, + *, + beta: NumberType = 1, + alpha: NumberType = 1, +) -> TensorLikeType: + torch._check( + vec1.ndim == 1, + lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", + ) + torch._check( + vec2.ndim == 1, + lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", + ) + self = self.expand(vec1.shape[0], vec2.shape[0]) + if utils.is_boolean_dtype(self.dtype): + # Integers are accepted for booleans + torch._check( + is_weakly_lesser_type(type(beta), int), + lambda: f"expected bool/int beta but got {type(beta)}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), int), + lambda: f"expected bool/int alpha but got {type(beta)}", + ) + if not beta: + return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False) + else: + return torch.logical_or( + self, + torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), + ) + else: + torch._check( + is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(beta)} to {self.dtype}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", + ) + if beta == 0: + # This means NaNs from self are dropped if beta is zero + return alpha * torch.outer(vec1, vec2) + else: + return beta * self + alpha * torch.outer(vec1, vec2) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_1d( + arg: 
Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_1d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) + return res if len(res) > 1 else res[0] + + +# Helper function with assert to avoid MyPy error +# of incompatible type passed to unsqueeze +def _unsqueeze_atleast( + at_least_fn: Callable, dim: int, arg: TensorLikeType +) -> TensorLikeType: + arg_ = at_least_fn(arg) + assert isinstance(arg_, TensorLike) + return unsqueeze(arg_, dim) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_2d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_2d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) + res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +# CompositeImplicitAutograd - don't register decomp +def atleast_3d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_3d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) + res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +def as_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = ( + storage_offset if storage_offset is not None else a.storage_offset() + ) + return prims.as_strided(a, size, stride, storage_offset_int) + + +@register_decomposition(aten.as_strided_scatter) +@out_wrapper() +def as_strided_scatter( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = 0 if storage_offset is None else storage_offset + return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) + + +def broadcast_shapes(*shapes) -> ShapeType: + return torch.Size(_broadcast_shapes(*shapes)) + + +@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta) +def broadcast_tensors(*tensors) -> List[TensorLikeType]: + if len(tensors) == 1 and not isinstance(tensors[0], Tensor): + tensors = tensors[0] + return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) + + +# CompositeImplicitAutograd - don't register decomp +def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: + start = len(size) - len(a.shape) + dims = tuple(range(start, len(a.shape) + start)) + return prims.broadcast_in_dim(a, size, dims) + + +@register_decomposition(aten.cat) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("tensors",), + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + format = f + assert format is not None + return format + + if len(tensors) == 0: + msg = "cat expects at least one tensor, but received zero!" + raise ValueError(msg) + + for tensor in tensors: + assert isinstance(tensor, TensorLike) + + utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # This is a bit tricky. Naively, you would expect to just pick one + # arbitrary tensor and check that all tensors match this tensor. However, + # there is legacy behavior which says that if you have a 1-D empty tensor + # (0,), this is permissible. So you can't assume that all the tensors + # have same dimensionality, and you can't assume that the first tensor is + # the correct stencil. + # + # We'll implement this in a few passes. First, we will try to infer the + # ndim of the cat output. If this ndim != 1, then we know that all ndim = + # 1 inputs must be empty, or are errors. If this ndim == 1, then life + # is easy (the legacy special case coincides with regular handling). + # + # NB: The regular implementation of cat just filters out empty inputs, + # but we do it slightly different here for better handling for unbacked + # SymInts + + example = None + for i, t in enumerate(tensors): + if example is None: + if t.ndim != 1: + example = t + else: + if t.ndim != 1: + torch._check( + t.ndim == example.ndim, + lambda: "Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for " + f"tensor number {i} in the list", + ) + + if example is None: + # example is None if everything is 1-D. If so, just arbitrarily pick + # the first one + example = tensors[0] + + shape = example.shape + filtered = [] + for tensor_idx, tensor in enumerate(tensors): + if len(shape) != len(tensor.shape): + assert tensor.ndim == 1 # we've already checked this above + # Don't suggest the legacy behavior in the error message + torch._check( + tensor.shape[0] == 0, + lambda: f"Number of dimensions of tensors must match. 
" + f"Expected {example.ndim}-D tensors, but got 1-D for " + f"tensor number {tensor_idx} in the list", + ) + else: + # Remove inputs that are 1-D, zero size + if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0): + continue + # Don't bother checking size match, prims.cat will handle it + filtered.append(tensor) + + memory_format = cat_compute_output_memory_format(tensors) + + if len(filtered) == 0: + t = tensors[0] + + # TODO: fix this to work with meta tensors + try: + requires_grad = any(x.requires_grad for x in tensors) + except Exception: + requires_grad = False + + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + dim = utils.canonicalize_dim(filtered[0].ndim, dim) + utils.validate_idx(filtered[0].ndim, dim) + + return prims.cat(filtered, dim).clone(memory_format=memory_format) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def column_stack(tensors: TensorSequenceType) -> TensorLikeType: + aligned_tensors = tuple( + x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors + ) + return cat(aligned_tensors, 1) + + +def conj(input: TensorLikeType) -> TensorLikeType: + if not utils.is_complex_dtype(input.dtype): + return input + if input.is_sparse: + return torch.conj_physical(input) + return prims.conj(input) + + +# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp +@register_decomposition(aten.constant_pad_nd) +@out_wrapper() +def constant_pad_nd( + input: TensorLikeType, pad: List[int], value: NumberType = 0 +) -> TensorLikeType: + torch._check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + torch._check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + c_input = input + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] < 0: + c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) + + if pad[pad_idx + 1] < 0: + c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) + + # if none of the pads are positive we can just return the result + if builtins.all(p <= 0 for p in pad): + return c_input.clone() + + new_shape = list(input_sizes[:l_diff]) + + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + torch._check( + new_dim > 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. 
Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + memory_format = utils.suggest_memory_format(input) + output = torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=memory_format, + ) + + if value == 0 and input.dtype == torch.bool: + value = False + # torch.fill isn't typed to allow complex values + output = torch.fill(output, value) # type: ignore[arg-type] + + c_output = output + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] > 0: + c_output = c_output.narrow( + i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] + ) + if pad[pad_idx + 1] > 0: + c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) + + prims.copy_to(c_output, c_input) + return output + + +def contiguous( + a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format +) -> Tensor: + torch._check( + memory_format != torch.preserve_format, + lambda: "preserve memory format is unsupported by the contiguous operator", + ) + + if utils.is_contiguous_for_memory_format(a, memory_format=memory_format): + return a + + return torch.clone(a, memory_format=memory_format) + + +@out_wrapper() +def dstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") + aligned_tensors = atleast_3d(*tensors) + return cat(aligned_tensors, 2) + + +@register_decomposition(aten.expand) +def expand(a: Tensor, *shape) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # NOTE: cannot use utils.extract_shape_from_varargs here + # because that also validates the shape, but the shape + # given to expand may be "invalid" + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = tuple(shape[0]) + + torch._check( + len(shape) >= len(a.shape), + lambda: "expand: the requested shape has too few dimensions!", + ) + + offset = len(shape) - len(a.shape) + shape_ = list(shape) + for idx, x in enumerate(a.shape): + offset_idx = idx + offset + requested_length = shape[offset_idx] + torch._check( + guard_size_oblivious(requested_length == x) + or guard_size_oblivious(x == 1) + or requested_length == -1, + lambda: f"expand: attempting to expand a dimension of length {x}!", + ) + + shape_[offset_idx] = requested_length if requested_length != -1 else x + + # At this point shape must be valid + utils.validate_shape(shape_) + + return prims.broadcast_in_dim( + a, shape_, tuple(range(offset, len(a.shape) + offset)) + ) + + +# CompositeImplicitAutograd - don't register decomp +def expand_as(a: Tensor, b: Tensor) -> Tensor: + return a.expand(b.shape) + + +def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]: + if chunks <= 0: + msg = f"Expected at least one chunk, but got {chunks}!" 
+ raise ValueError(msg) + + dim = utils.canonicalize_dim(a.ndim, dim) + length = a.shape[dim] + chunk_size = math.ceil(length / chunks) + full_chunks = math.floor(length / chunk_size) + tail_chunk_size = length % chunk_size + + result = [] + for i in range(full_chunks): + result.append(narrow(a, dim, i * chunk_size, chunk_size)) + + if tail_chunk_size != 0: + result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size)) + + return tuple(result) + + +# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless +# a 0D tensor is flattened, in which case it's returned in 1D) +# CompositeImplicitAutograd - don't register decomp +def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: + start_dim = utils.canonicalize_dim(a.ndim, start_dim) + end_dim = utils.canonicalize_dim(a.ndim, end_dim) + + # Short-circuits on no-op + if start_dim == end_dim and a.ndim != 0: + return a + + # Tries to take a view + # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) + new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim) + if new_shape is not None: + return prims.collapse_view(a, start_dim, end_dim) + + # Makes a copy if it can't make a view + return prims.collapse(a, start_dim, end_dim) + + +@register_decomposition(aten.flip) +@out_wrapper() +def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + if not isinstance(dims, tuple) and not isinstance(dims, list): + raise ValueError("dims has to be a sequence of ints") + dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment] + utils.validate_no_repeating_dims(dims) + return prims.rev(a, dims) + + +# CompositeImplicitAutograd - don't register decomp +def fliplr(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 2: + raise RuntimeError("Input must be >= 2-d.") + + return flip(a, (1,)) + + +# CompositeImplicitAutograd - don't register decomp +def flipud(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 1: + raise RuntimeError("Input must be >= 1-d.") + + return flip(a, (0,)) + + +# CompositeImplicitAutograd - don't register decomp +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + torch._check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + torch._check(length >= 0, lambda: "narrow(): length must be non-negative.") + dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + torch._check_with( + IndexError, + -dim_length <= start and start <= dim_length, # type: ignore[arg-type] + lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})", + ) + if start < 0: + start = start + dim_length + torch._check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) + return prims.slice_in_dim(a, start, start + length, axis=dim) + + +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. 
+narrow_copy = _make_copy_from_view(narrow) + + +def _normalize( + a: Tensor, norm_dims: DimsType, eps: float +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes mean and 1/std of a tensor along norm_dims. + + Used as a helper function for normalization layers. + + Args: + a (Tensor): input tensor + norm_dims (DimsType): dimensions to normalize over + eps (float): epsilon for numerical stability + + Returns: + out (Tensor): normalized tensor. + mean (Tensor): mean of the tensor along norm_dims. + rstd (Tensor): 1/std of the tensor along norm_dims. + """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) + computation_dtype = utils.get_computation_dtype(a.dtype) + a_acc = _maybe_convert_to_dtype(a, computation_dtype) + assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean + biased_var, mean = torch.var_mean( + a_acc, dim=norm_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + out = (a - mean) * rstd + return out, mean, rstd + + +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) + out = out.view(input.shape) + + broadcast_dims = [0] + list(range(2, input.ndim)) + unsqueeze_bias = None + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + unsqueeze_weight = None + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + + if unsqueeze_weight is not None: + out = out * unsqueeze_weight + if unsqueeze_bias is not None: + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = torch.squeeze(mean, reduction_dims) + rstd = torch.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + +@register_decomposition(aten.native_layer_norm) +@out_wrapper("out0", "out1", "out2") +def native_layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + normalized_ndim = len(normalized_shape) + torch._check( + normalized_ndim >= 1, + lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " + + "containing at least one element, but got normalized_shape = " + + str(normalized_shape), + ) + # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False + # 
while torch.Size([1, 2, 3]) == (1, 2, 3) is True + # therefore we use tuple(normalized_shape) + torch._check( + weight is None or weight.shape == tuple(normalized_shape), + lambda: "Expected weight to be of same shape as normalized_shape, but got " + + "weight of shape " + + str(weight.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + bias is None or bias.shape == tuple(normalized_shape), + lambda: "Expected bias to be of same shape as normalized_shape, but got " + + "bias of shape " + + str(bias.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + input.ndim >= normalized_ndim + and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), + lambda: "Given normalized_shape=" + + str(normalized_shape) + + ", expected input with shape " + + str(normalized_shape) + + ", but got input of size " + + str(input.shape), + ) + + input = input.contiguous() + if weight is not None: + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + + axis = input.ndim - normalized_ndim + reduction_dims = list(range(axis, input.ndim)) + out, mean, rstd = _normalize(input, reduction_dims, eps) + + if weight is None and bias is not None: + out = out + bias + elif weight is not None and bias is None: + out = out * weight + elif weight is not None and bias is not None: + out = out * weight + bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + if input.device.type == "cpu": + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + return (out, mean, rstd) + + +# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. 
+# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu +@register_decomposition(aten.permute) +def permute(a: TensorLikeType, *dims) -> TensorLikeType: + _permutation = utils.canonicalize_dims( + a.ndim, utils.extract_dims_from_varargs(dims) + ) + return prims.transpose(a, _permutation) + + +@register_decomposition(aten.renorm) +@out_wrapper() +def renorm( + input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType +) -> TensorLikeType: + torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued") + torch._check(p > 0, lambda: "renorm: non-positive norm not supported") + torch._check( + not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued" + ) + torch._check( + maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}" + ) + ndim = input.ndim + torch._check( + ndim > 1, + lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions", + ) + + dim = utils.canonicalize_dim(ndim, dim) + reduce_dims = list(range(ndim)) + del reduce_dims[dim] + + # For half and bfloat16, calculate norm in float precision then cast + # normalization factor to half + acc_type = utils.get_computation_dtype(input.dtype) + if acc_type != input.dtype: + norm = torch.linalg.vector_norm( + input, p, reduce_dims, keepdim=True, dtype=acc_type + ) + else: + norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True) + + eps = 1e-7 + norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + if acc_type != input.dtype: + norm_factor = prims.convert_element_type(norm_factor, input.dtype) + return (input * norm_factor).contiguous() + + +# CompositeImplicitAutograd - don't register decomp +@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd) +def stft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"stft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + if return_complex is None: + return_complex_ = input.is_complex() or ( + window is not None and utils.is_complex_dtype(window.dtype) + ) + torch._check( + return_complex_, + ( + "stft requires the return_complex parameter be given for real inputs, " + + "and will further require that return_complex=True in a future PyTorch release." 
+ ), + ) + else: + return_complex_ = return_complex + + torch._check( + utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), + lambda: "stft expected a tensor of floating point or complex values", + ) + torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor") + + original_ndim = input.ndim + if original_ndim == 1: + input = input.unsqueeze(0) + + if center: + extra_dims = 3 - input.ndim + pad_amount = n_fft // 2 + extended_shape = [*itertools.repeat(1, extra_dims), *input.shape] + input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode) + input = input.view(input.size()[extra_dims:]) + + batch = input.size(0) + length = input.size(1) + torch._check( + 0 < n_fft <= length, + lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}", + ) + torch._check( + hop_length_ > 0, + lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}", + ) + torch._check( + 0 < win_length_ <= n_fft, + lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}", + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: ( + f"expected a 1D window tensor of size equal to win_length={win_length_}, " + + f"but got window with size {window.shape}" # type: ignore[union-attr] + ), + ) + + if win_length_ < n_fft: + if window is None: + window = torch.ones(win_length_, dtype=input.dtype, device=input.device) + left = (n_fft - win_length_) // 2 + window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) + + input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) + if window is not None: + input = input * window + + complex_fft = utils.is_complex_dtype(input.dtype) + onesided = onesided if onesided is not None else not complex_fft + norm = "ortho" if normalized else None + if onesided: + torch._check( + not complex_fft, + lambda: "Cannot have onesided output if window or input is complex", + ) + out = torch.fft.rfft(input, dim=-1, norm=norm) + else: + out = torch.fft.fft(input, dim=-1, norm=norm) + + out.transpose_(1, 2) + + if original_ndim == 1: + out = out.squeeze_(0) + + return out if return_complex_ else torch.view_as_real(out) + + +# CompositeImplicitAutograd - don't register decomp +@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def istft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex=False, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"istft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + torch._check( + utils.is_complex_dtype(input.dtype), + lambda: ( + "istft input and window must be on the same device but got self on " + + f"{input.device} and window on {window.device}" # type: ignore[union-attr] + ), + ) + n_frames = input.size(-1) + fft_size = input.size(-2) + + expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1) + torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty") + torch._check( + 2 <= input.ndim <= 3, + lambda: f"istft expected a tensor with 2 or 3 dimensions, but got 
{input.ndim}", + ) + onesided_ = onesided if onesided is not None else fft_size != n_fft + + if onesided_: + torch._check( + n_fft // 2 + 1 == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}" + ), + ) + else: + torch._check( + n_fft == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft when onesided=False, but got {fft_size}", + ), + ) + + torch._check( + 0 < hop_length_ <= win_length_, + lambda: "istft expected 0 < hop_length <= win_length", + ) + torch._check( + 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft" + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: "Invalid window shape. window has to be 1D and length of `win_length`", + ) + + if window is None: + real_dtype = utils.corresponding_real_dtype(input.dtype) + window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device) + else: + window_ = window + + if win_length_ != n_fft: + left = (n_fft - win_length_) // 2 + window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0) + + original_ndim = input.ndim + if input.ndim == 2: + input = input.unsqueeze(0) + + input = input.transpose(1, 2) + norm = "ortho" if normalized else None + if return_complex: + torch._check( + not onesided_, + lambda: "cannot have onesided output if window or input is complex", + ) + input = torch.fft.ifft(input, dim=-1, norm=norm) + else: + torch._check( + window is None or not utils.is_complex_dtype(window.dtype), + lambda: "Complex windows are incompatible with return_complex=False", + ) + if not onesided_: + input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1) + input = torch.fft.irfft(input, dim=-1, norm=norm) + + assert input.size(2) == n_fft + + y_tmp = input * window_.view([1, 1, n_fft]) + y = aten.unfold_backward( + y_tmp, + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + window_envelop = aten.unfold_backward( + window_.pow(2).expand((1, n_frames, n_fft)), + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + + assert expected_output_signal_len == y.size(1) + assert expected_output_signal_len == window_envelop.size(1) + + start = n_fft // 2 if center else 0 + if length is not None: + end = start + length + elif center: + end = expected_output_signal_len - n_fft // 2 + else: + end = expected_output_signal_len + + length = max(0, end - start) + y = y.narrow(dim=1, start=start, length=length) + window_envelop = window_envelop.narrow(dim=1, start=start, length=length) + + window_envelop_lowest = window_envelop.abs().min().lt(1e-11) + torch._check( + not window_envelop_lowest.item(), + lambda: "window overlap add min less than 1e-11", + ) + + y = y / window_envelop + if original_ndim == 2: + y = y.squeeze(0) + + if end > expected_output_signal_len: + warnings.warn( + "The length of signal is shorter than the length parameter. Result is being " + + "padded with zeros in the tail. 
Please check your center and hop_length settings" + ) + y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0) + return y + + +# Get the new shape and stride after applying unfold to an input tensor +def _get_unfold_shape_stride( + a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int +): + a_ndim = len(a_shape) + dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True) + max_size = 1 if a_ndim == 0 else a_shape[dim] + last_stride = 1 if a_ndim == 0 else a_stride[dim] + + torch._check( + size <= max_size, + lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", + ) + + torch._check( + step > 0, + lambda: f"Step is {step} but must be > 0", + ) + + shape = list(a_shape) + strides = list(a_stride) + shape.append(size) + strides.append(last_stride) + if dim < a_ndim: + shape[dim] = (shape[dim] - size) // step + 1 + strides[dim] *= step + return shape, strides + + +@register_decomposition(aten.repeat) +@out_wrapper() +def repeat(a: Tensor, *repeat_shape) -> Tensor: + repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) + torch._check( + len(repeat_shape) >= len(a.shape), + lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", + ) + + if len(repeat_shape) == 0: + return torch.clone(a) + + num_new_dimensions = len(repeat_shape) - a.ndim + padded_shape = [1] * num_new_dimensions + for dim_size in a.shape: + padded_shape.append(dim_size) + + target_shape = tuple( + padded_size * repeat_size + for padded_size, repeat_size in zip(padded_shape, repeat_shape) + ) + + # return an empty tensor if one of the repeat_shape dimensions is zero + if 0 in repeat_shape: + return torch.empty( + target_shape, + dtype=a.dtype, + device=a.device, + requires_grad=a.requires_grad, + memory_format=utils.suggest_memory_format(a), + ) + + urtensor_shape = target_shape + urtensor_stride = utils.make_contiguous_strides_for(target_shape) + for dim, dim_size in enumerate(padded_shape): + # repeat each dimension by using unfold_copy operation + urtensor_shape, urtensor_stride = _get_unfold_shape_stride( + urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1) + ) + + # derive permute order by sorting urtensor strides + enumerated_stride = list(enumerate(urtensor_stride)) + enumerated_stride.sort(key=lambda item: item[1], reverse=True) + permute_order, sorted_stride = zip(*enumerated_stride) + + # add new and expand dimensions according to urtensor + repeat_xtensor = a.expand(urtensor_shape) + + # clone tensor to concretize expanded dimensions + cloned_result = torch.clone(repeat_xtensor) + + # transpose axis so strides are in sorted order + permuted_result = cloned_result.permute(permute_order) + + # reshape to get contiguous tensor with correct target shape + return permuted_result.reshape(target_shape) + + +def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Short-circuits if shape is the same + if guard_size_oblivious(sym_eq(tuple(a.shape), tuple(shape))): + return prims.view_of(a) + + # Special-cases tensors with no elements + if guard_size_oblivious(a.numel() == 0): + return 
as_strided(a, shape, utils.make_contiguous_strides_for(shape)) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + assert length == 1 + _a = unsqueeze(_a, -1) + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + assert length == 1 + _a = squeeze(_a, -1) + return _a + + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape + + # NOTE [Reshape Algorithm] + # This algorithm works by attempting to greedily construct the desired dimensions in + # the output shape, left to right. It does this by, conceptually, accumulating + # dimensions of the original tensor, also left to right, until the dimension + # can be constructed using prims.split_dim. + # The algorithm also has special handling for tail squeezes/unsqueezes, like + # if a reshape from (5, 5) to (5, 5, 1) or vice versa. + # + # This algorithm does not flatten the original tensor and then split dims as appropriate + # because that would create copies more often than this algorithm. flatten is the only + # operation below which can create a view or a copy, and while it prefers creating + # views it may sometimes create a copy if the tensor's strides do not permit a view. + # As a result, this algorithm tries to minimize flattening. + # + # Note that a better version of this algorithm may exist. Regions which could be + # flattened without creating a copy can be identified in advance, and that might + # allow fewer flatten calls or faster short-circuiting to make a copy. + idx = 0 + a_ = a + for length in shape: + # Handles tail unsqueezes + if idx >= a_.ndim: + assert length == 1 + last_dim = a_.ndim - 1 + # NOTE: using split_dim instead of unsqueeze may seem silly here, + # but it's necessary to get the strides correct + a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim]) + idx = idx + 1 + continue + + # Skips dimensions that are already the correct length + if guard_size_oblivious(length == a_.shape[idx]): + idx = idx + 1 + continue + + # Gathers enough original dimensions such that this new dimension can be created + # Note that this accumulation will terminate because we've verified a and the shape + # specify the same number of elements above + accum = a_.shape[idx] + end = idx + while guard_size_oblivious(accum % length != 0): + end = end + 1 + accum = accum * a_.shape[end] + if end != idx: + # NOTE: in this case multiple dimensions must be flatten to create the desired dimension + # This flattening is why reshape sometimes creates a copy -- because flattening + # may return a view of a copy + + # Checks if collapse can be a view and short-circuits to copying reshape if it can't + new_shape, new_strides = prims._collapse_view_helper(a_, idx, end) + if new_shape is None: + if allow_copy: + return prims.reshape(a, shape) + + msg = "Cannot view a tensor with shape {} and strides {} as a tensor with shape {}!".format( + a.shape, a.stride(), shape + ) + raise ValueError(msg) + + a_ = flatten(a_, idx, end) + + # Splits the (possibly flattened) dimension to create the desired dim length + if guard_size_oblivious(accum != length): + a_ = prims.split_dim(a_, idx, length) + + idx = idx + 1 + + # Squeezes tail + while idx < a_.ndim: + assert a_.shape[idx] == 1 + a_ = squeeze(a_, idx) + + return a_ + + +# CompositeImplicitAutograd - don't register decomp +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)) Function call +# torch.reshape 
doesn't support unpacked shapes +def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=True) + + +# CompositeImplicitAutograd - don't register decomp +def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.reshape(other.size()) + + +@register_decomposition(aten.roll) +@out_wrapper() +def roll( + a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple() +) -> TensorLikeType: + """Reference implementation of :func:`torch.roll`.""" + dims = utils.canonicalize_dims(a.ndim, dims) + # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1 + if not isinstance(shifts, Iterable): + shifts = (shifts,) + if not isinstance(dims, Iterable): + dims = (dims,) + + # Avoid modulo by zero + if a.numel() == 0: + # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors + return a.clone() + + if a.dim() == 0 and len(dims) > 0: + raise IndexError( + f"Dimension specified as {dims[0]} but tensor has no dimensions" + ) + + len_shifts = len(shifts) + len_dims = len(dims) + if len_shifts != 1 or len_dims != 1: + if len_shifts == 0: + raise RuntimeError("`shifts` required") + # Takes care of the case when dims is not specified (default) + # By default, the tensor is flattened before shifting, after which the original shape is restored + if len_dims == 0 and len_shifts == 1: + return torch.roll(torch.flatten(a), shifts, 0).view(a.shape) + if len_shifts != len_dims: + raise RuntimeError( + f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" + ) + assert len_dims > 1 + tail_shifts = shifts[1:] + tail_dims = dims[1:] + first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) + return torch.roll(first_dim_rolled, tail_shifts, tail_dims) + + # This path is taken when only one dimension is rolled + # For example to get `first_dim_rolled` above + dim = dims[0] + size = a.shape[dim] + start = (size - shifts[0]) % size + idx = torch.arange(size, device=a.device) + return a.index_select(dim, torch.fmod(start + idx, size)) + + +@register_decomposition(aten.rot90) +@out_wrapper() +def rot90( + a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1) +) -> TensorLikeType: + """Reference implementation of :func:`torch.rot90`.""" + if len(dims) != 2: + raise RuntimeError( + f"expected total rotation dims == 2, but got dims = {len(dims)}" + ) + if a.ndim < 2: + raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}") + + # Do this after the initial checks to be compatible with the behavior in + # core. 
+ dims = utils.canonicalize_dims(a.ndim, dims) + + if dims[0] == dims[1]: + raise RuntimeError( + f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}" + ) + k = k % 4 # Rotation direction is from the second towards the first axis for k < 0 + if k == 1: + return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1]) + elif k == 2: + return torch.flip(a, dims) + elif k == 3: + return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) + else: + return clone(a, memory_format=torch.contiguous_format) + + +def _check_stack_inputs(tensors: TensorSequenceType) -> None: + entry_shape = tensors[0].shape + for i in range(1, len(tensors)): + assert tensors[i].shape == entry_shape, ( + f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0" + f"and {tensors[i].shape} at entry {i}" + ) + + +@register_decomposition(aten.stack) +@out_wrapper() +def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + assert len(tensors) > 0, "stack expects a non-empty TensorList" + wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) + # Refs need sparse support to check other condition + if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse: + _check_stack_inputs(tensors) + result_sizes = list(tensors[0].shape) + result_sizes.insert(wrapped_dim, len(tensors)) + out = torch.cat(tensors, wrapped_dim) + return out.view(result_sizes) + + # If dim == tensors[0].ndim, view cannot efficiently handle it + return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + if a.numel() == 0: + a_exp = exp(a_) + else: + a_max = amax(a_, dim, keepdim=True) + a_exp = exp(a_ - a_max) + return _maybe_convert_to_dtype( + true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype + ) # type: ignore[return-value] + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def hstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") + aligned_tensors = atleast_1d(*tensors) + if aligned_tensors[0].ndim == 1: + return cat(aligned_tensors, 0) + return cat(aligned_tensors, 1) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def vstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") + aligned_tensors = atleast_2d(*tensors) + return cat(aligned_tensors, 0) + + +# CompositeImplicitAutograd - don't register decomp +def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: + dim = utils.canonicalize_dim(a.ndim, dim) + torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") + return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) + + +@register_decomposition(aten.unbind) +def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: + dim = utils.canonicalize_dim(t.ndim, dim) + torch._check_index( + len(t.shape) > 0, + lambda: "Dimension specified as 0 but tensor has no dimensions", + ) + if t.shape[dim] == 0: + return tuple() + else: + return tuple( + torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) + ) + + 
+@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return x.clone(memory_format=torch.contiguous_format).index_copy_( + dim, index, tensor + ) + + +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + y = x.unsqueeze(0) if x.ndim == 0 else x + idx = (slice(None),) * dim + (index,) + y[idx] = tensor + return x + + +@register_decomposition(aten.index_fill) +@out_wrapper() +def index_fill( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=False) + + +@register_decomposition(aten.index_fill_) +def index_fill_( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=True) + + +def _index_fill( + x: TensorLike, + dim: int, + index: TensorLike, + value: Union[NumberType, TensorLike], + *, + inplace: bool, +): + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if isinstance(value, TensorLike): + torch._check( + value.ndim == 0, + lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] + f"Got a tensor with {value.ndim} dimensions.", + ) # type: ignore[arg-type] + else: + value = torch.scalar_tensor( + value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type] + ) + + # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them + zero_dim = x.ndim == 0 + y = x.unsqueeze(0) if zero_dim else x + # index_copy does not broadcast on value so we have to do it manually + shape = list(y.shape) + shape[dim] = index.numel() + value = value.expand(shape) + index_copy = Tensor.index_copy_ if inplace else torch.index_copy + out = index_copy(y, dim, index, value) # type: ignore[operator] + if inplace: + return x + else: + if zero_dim: + # The clone is necessary so that it returns a fresh tensor rather than a view + out = out.squeeze(0).clone() + # index_fill preserves the strides. 
index_copy always returns contiguous tensors + if out.stride() != x.stride(): + new_out = torch.empty_like(x) + new_out.copy_(out) + out = new_out + return out + + +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + # index_add always returns a new contiguous tensor + return x.clone(memory_format=torch.contiguous_format).index_add_( + dim, index, tensor, alpha=alpha # type: ignore[arg-type] + ) + + +@register_decomposition(aten.index_select) +@out_wrapper() +def index_select(x: TensorLike, dim: int, index: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if index.ndim == 0: + index = index.unsqueeze(0) + if x.ndim == 0: + # Treat scalars as elements of \R^1 + # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction + return torch.empty_like(x).index_copy(0, index, x.expand_as(index)) + + idx = (slice(None),) * dim + (index,) + return x[idx] + + +@register_decomposition(aten.squeeze.dims) +def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if dim is None: + dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) + return prims.squeeze(a, dims) if dims else prims.view_of(a) + + ndim = a.ndim + dim = utils.canonicalize_dims(ndim, dim) + dims = (dim,) if isinstance(dim, Dim) else dim + # Short-circuits if the tensor has no dimensions + if ndim == 0: + assert len(dims) == 0 or dims == (0,) + return prims.view_of(a) + + # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 + dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1)) + if len(dims) == 0: + return prims.view_of(a) + if len(dims) == 1: + return prims.squeeze(a, dims) + dims_list = list(dims) + dims_list = sorted(dims_list, reverse=True) + for i in dims_list: + a = squeeze(a, i) + return a + + +# Note: does not work with TensorMetas because of data-dependent control-flow +# CompositeImplicitAutograd - don't register decomp +def tensor_split( + a: TensorLikeType, + indices_or_sections: Union[Tensor, DimsType], + dim: int = 0, +) -> Tuple[TensorLikeType, ...]: + _dim = utils.canonicalize_dim(a.ndim, dim) + if a.ndim == 0: + msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!" 
+ raise ValueError(msg) + + # If indices_or_sections is a tensor, it must be a CPU Long tensor + if isinstance(indices_or_sections, TensorLike): + if not indices_or_sections.device.type == "cpu": + msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {}".format( + indices_or_sections.device + ) + raise ValueError(msg) + if indices_or_sections.dtype != torch.long: + msg = "tensor_split: if indices_or_sections is a tensor it must have long dtype, " + f" but received one with dtype {indices_or_sections.dtype}" + raise ValueError(msg) + + # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length + if isinstance(indices_or_sections, IntLike) or ( + isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0 + ): + sections: int = ( + indices_or_sections # type: ignore[assignment] + if isinstance(indices_or_sections, Number) + else indices_or_sections.item() + ) + + if sections <= 0: + msg = f"tensor_split: number of sections must be greater than 0, but was {sections}" + raise ValueError(msg) + + splits = [] + dim_size = a.shape[_dim] + min_split_size = math.floor(dim_size / sections) + num_splits_one_extra = dim_size % sections + start_idx = 0 + for split_idx in range(sections): + split_size = ( + min_split_size + 1 + if (split_idx < num_splits_one_extra) + else min_split_size + ) + s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim) + splits.append(s) + start_idx = start_idx + split_size + + return tuple(splits) + # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits + else: + indices = indices_or_sections + if isinstance(indices_or_sections, TensorLike): + if indices_or_sections.ndim != 1: + msg = "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, " + f"but received a tensor with {indices_or_sections.ndim} dimensions" + raise ValueError(msg) + + indices = indices_or_sections.tolist() + + splits = [] + start_idx = 0 + for x in indices: + splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim)) + start_idx = x + splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim)) + return tuple(splits) + + +# CompositeImplicitAutograd - don't register decomp +def hsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> Tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 1, + lambda: ( + "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + dim = 0 if a.ndim == 1 else 1 + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[dim] % split_size == 0), + lambda: ( + "torch.hsplit attempted to split along dimension " + + str(dim) + + ", but the size of the dimension " + + str(a.shape[dim]) + + " is not divisible by the split_size " + + str(split_size) + + "!" + ), + ) + return tensor_split(a, split_size, dim) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "hsplit(): received an invalid combination of arguments. 
" + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, dim) + + +# CompositeImplicitAutograd - don't register decomp +def vsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> Tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 2, + lambda: ( + "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[0] % split_size == 0), + lambda: ( + f"torch.vsplit attempted to split along dimension 0" + f", but the size of the dimension " + f"{a.shape[0]}" + f" is not divisible by the split_size " + f"{split_size}" + f"!" + ), + ) + return tensor_split(a, split_size, 0) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "vsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, 0) + + +@register_decomposition(aten.diag.out) +@out_wrapper() +def diag( + self: TensorLikeType, + offset: int = 0, +) -> TensorLikeType: + ndim = self.dim() + torch._check( + ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" + ) + if ndim == 1: + return torch.diag_embed(self, offset) + else: + return torch.diagonal_copy(self, offset) + + +@register_decomposition(aten.diagonal_scatter) +@out_wrapper() +def diagonal_scatter( + input: TensorLikeType, + src: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + out = utils.clone_preserve_strides(input) + diag = out.diagonal(offset, dim1, dim2) + torch._check( + diag.shape == src.shape, + lambda: "expected src to have a size equal to the diagonal of the input." 
+ f"Got {src.shape} for a diagonal of shape {diag.shape}", + ) + copy_to(diag, src) + return out + + +@register_decomposition(aten.diagonal) +def diagonal( + self: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.diagonal + """ + num_dims = self.dim() + dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + storage_offset = self.storage_offset() + + if offset >= 0: + diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0) + else: + diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0) + + if diag_size > 0: + if offset >= 0: + storage_offset += offset * self.stride()[dim2] + else: + storage_offset -= offset * self.stride()[dim1] + + sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)] + sizes.append(diag_size) + + strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)] + strides.append(self.stride()[dim1] + self.stride()[dim2]) + + result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset) + + return result + + +diagonal_copy = _make_copy_from_view(diagonal) + + +@register_decomposition(aten.diag_embed) +@out_wrapper() +def diag_embed( + t: TensorLikeType, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + """ + Reference implementation of torch.diag_embed + """ + # convert from negative dims + rank = t.ndim + 1 + dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) + dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) + + # as per the docs, exchanging dims is equivalent to changing the sign of + # offset + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + # as per the docs, the size of last dim is placed at dim1 and dim2 + last_dim = t.size(-1) + + if offset != 0: + # add padding to match the new size + t_shape = list(t.shape) + t_shape[-1] = builtins.abs(offset) + z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False) + pair = (z, t) if offset > 0 else (t, z) + t = torch.cat(pair, dim=-1) + # make sure the diagonal always has the same size + last_dim += builtins.abs(offset) + + # preserve original data, but place 1 at dim1 and move last dim to dim2 + t = t.unsqueeze(dim1).movedim(-1, dim2) + + # generate ranges shifting indices based on offset + a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64) + b_range = torch.arange( + offset, last_dim + offset, device=t.device, dtype=torch.int64 + ) + + # broadcast + cond = a_range == b_range.unsqueeze(-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] + cond = cond.reshape(cond_shape) + + # aten.diag_embed always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(cond, t).contiguous() + + +@register_decomposition(aten.block_diag) +@out_wrapper() +def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType: + """ + Reference implementation of torch.block_diag + """ + tensors_2d = [ + tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors + ] + + ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d) + device = tensors_2d[0].device + + result = [] + + col_start = 0 + for i, tensor in 
enumerate(tensors_2d): + torch._check( + tensor.dim() == 2, + lambda: "Input tensors must have 2 or fewer dimensions. " + f"Input {i} has {tensor.dim()} dimensions", + ) + torch._check( + tensor.device == device, + lambda: "Input tensors must all be on the same device. " + f"Input 0 is on device {device} and input {i} is on device {tensor.device}.", + ) + row, col = tensor.shape + left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype) + right = torch.zeros( + (row, ncols - col_start - col), device=device, dtype=tensor.dtype + ) + result += [torch.cat((left, tensor, right), dim=1)] + col_start += col + + return torch.cat(result, dim=0) + + +def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType: + """ + This is used as an input to PythonRefInfo. `torch.block_diag` + expects arguments splatted, but `aten.block_diag` expects only + one argument that is a list of Tensors. + """ + return _block_diag_iterable(tensors) + + +# CompositeImplicitAutograd - don't register decomp +def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: + if a.ndim < 3: + raise RuntimeError( + f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" + ) + if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): + raise RuntimeError( + "torch.dsplit attempted to split along dimension 2, " + + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" + ) + return tensor_split(a, sections, 2) + + +@register_decomposition(aten.t.default) +def t(a: TensorLikeType): + # TODO: Add sparse support + # if a.is_sparse: + # sparse_dim = a.sparse_dim() + # dense_dim = a.dense_dim() + # if not (sparse_dim <= 2 and dense_dim == 0): + # raise RuntimeError( + # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and" + # f"{dense_dim} dense dimensions" + # ) + if a.ndim > 2: + raise RuntimeError( + f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D" + ) + return torch.transpose(a, 0, 0 if a.ndim < 2 else 1) + + +# CompositeImplicitAutograd - don't register decomp +def T(a: TensorLikeType) -> TensorLikeType: + # n != 2 && n != 0 is deprecated in regular PyTorch. + torch._check( + a.ndim in (0, 2), + lambda: ( + "The use of `x.T` on tensors of dimension other than 0 or 2 " + "to reverse their shape is not supported." 
+ ), + ) + return a.t() + + +@register_decomposition(aten.alias) +def alias(a: TensorLikeType) -> TensorLikeType: + return prims.view_of(a) + + +@register_decomposition(aten.transpose) +def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: + _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] + + if a.ndim <= 1 or dim0 == dim1: + return aten.alias.default(a) + + _permutation = list(range(0, a.ndim)) + _permutation[_dim0] = _dim1 + _permutation[_dim1] = _dim0 + return torch.permute(a, _permutation) + + +# Aliases for transpose +swap_axes = transpose + + +@register_decomposition(aten.unfold) +def unfold( + self: TensorLikeType, dimension: int, size: int, step: int +) -> TensorLikeType: + shape, strides = _get_unfold_shape_stride( + self.shape, self.stride(), dimension, size, step + ) + return self.as_strided(shape, strides) + + +@register_decomposition(aten.unfold_copy) +@out_wrapper() +def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): + return self.unfold(dimension, size, step).clone( + memory_format=torch.contiguous_format + ) + + +def _cumsumprod_common( + func, + init, + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # We implement all the kwargs of a reduction. ATen just handles dtype + # nb. This decomposition may not be as efficient as a backend-specific implementation + ndim = a.ndim + dim = utils.canonicalize_dim(ndim, dim) + if ndim == 0: + return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out) + a = a.unsqueeze(dim + 1) + rg = torch.arange(a.shape[dim], device=a.device) + mask = rg.unsqueeze(1) <= rg + for _ in range(ndim - dim - 1): + mask = mask.unsqueeze(-1) + masked_a = torch.where(mask, a, init) + return func(masked_a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumsum) +def cumsum( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumprod) +def cumprod( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out) + + +# Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(aten.unsqueeze) +def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: + # Note that unsqueeze canonicalizes with rank + 1 because it allows + # a new innermost dimension to be specified + ndim = a.ndim + 1 + dim = utils.canonicalize_dim(ndim, dim) + return prims.expand_dims(a, (dim,), ndim=ndim) + + +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view +# doesn't support unpacked shapes +# TODO: Turn this into a decomposition (currently fails on reshape meta tests) +@register_decomposition(aten.view.default) +def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=False) + + +# CompositeImplicitAutograd - don't register decomp +def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.view(other.size()) + + +# CompositeImplicitAutograd - don't register decomp +def ravel(a: TensorLikeType) -> TensorLikeType: + return reshape(a, (-1,)) + + +# 
CompositeImplicitAutograd - don't register decomp +# missing ref impl. for aten.gather +@out_wrapper() +def take_along_dim( + a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + torch._check( + a.ndim == indices.ndim, + lambda: ( + "torch.take_along_dim(): input and indices should have the same " + f"number of dimensions, but got {a.ndim} dimensions for input, and " + f"{indices.ndim} dimensions for indices" + ), + ) + + torch._check( + utils.is_integer_dtype(indices.dtype), + lambda: ( + "torch.take_along_dim(): dtype of indices should be int but got " + f"{indices.dtype} instead" + ), + ) + + if dim is None: + return torch.gather(a.view(-1), 0, indices.view(-1)) + else: + self_sizes = list(a.shape) + self_sizes[dim] = indices.size(dim) + broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size()) + indices_broadcast = broadcast_to(indices, broadcast_shape) + + indices_sizes = list(indices.shape) + indices_sizes[dim] = a.size(dim) + broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) + self_broadcast = broadcast_to(a, broadcast_shape) + + return torch.gather(self_broadcast, dim, indices_broadcast) + + +@out_wrapper() +def empty( + *shape, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLikeType: + torch._check( + memory_format != torch.preserve_format, + lambda: "torch.empty: the Preserve memory format is not supported", + ) + + shape = utils.extract_shape_from_varargs(shape) + + if memory_format == torch.contiguous_format: + strides = utils.make_contiguous_strides_for(shape) + elif memory_format == torch.channels_last_3d: + strides = utils.make_channels_last_3d_strides_for(shape) + else: # memory_format == torch.channels_last + torch._check( + memory_format == torch.channels_last, + lambda: f"torch.empty: received an unknown memory format {memory_format}!", + ) + strides = utils.make_channels_last_2d_strides_for(shape) + + return torch.empty_strided( + shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@out_wrapper() +def empty_permuted( + shape, + physical_layout, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + return prims.empty_permuted( + shape, + physical_layout, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_empty) +@out_wrapper() +def new_empty( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.new_empty_strided) +@out_wrapper() +def new_empty_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> 
TensorLikeType: + """ + Reference implementation of torch.Tensor.new_empty_strided + """ + + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty_strided( + size, + stride, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.zeros.default) +@out_wrapper() +def zeros( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + False if dtype == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_zeros) +@out_wrapper() +def new_zeros( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.ones.default) +@out_wrapper() +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + True if dtype == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_ones) +@out_wrapper() +def new_ones( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_full) +@out_wrapper() +def new_full( + a: TensorLikeType, + size: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +@register_decomposition(aten.empty_like) +@out_wrapper() +def empty_like( + a: TensorLikeType, + *, + dtype: 
Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + if memory_format != torch.preserve_format: + return torch.empty( + a.shape, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + logical_to_physical_perm = ( + utils.compute_elementwise_output_logical_to_physical_perm(a) + ) + # identity perm is [2, 1, 0] + return torch.empty_permuted( + a.shape, + logical_to_physical_perm, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition([aten.arange.start_step, aten.arange.start_out]) +@out_wrapper() +def arange( + start: NumberType = 0, + end: Optional[NumberType] = None, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + device = torch.device(utils.device_or_default(device)) + + assert not isinstance(start, complex) + assert not isinstance(end, complex) + assert not isinstance(step, complex) + + # Case: torch.arange(5) + if end is None: + end = start + start = 0 + torch._check(step != 0, lambda: "step must be nonzero") + if step > 0: + torch._check( + end >= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + elif step < 0: + torch._check( + end <= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + + def is_finite(x): + return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) + + torch._check( + is_finite(start) and is_finite(end), + lambda: f"unsupported range: {start} -> {end}", + ) + torch._check( + is_finite(step), + lambda: f"step must be finite but got {step}", + ) + + if dtype is None: + args = (start, end, step) + integer_args = builtins.all(isinstance(arg, IntLike) for arg in args) + dtype = torch.int64 if integer_args else torch.get_default_dtype() + + is_integer = utils.is_integer_dtype(dtype) + if is_integer: + xstart = sym_int(start) + xend = sym_int(end) + xstep = sym_int(step) + + # For int64 we truncate arguments to int before calculating length, but + # other integral dtypes we don't. Weird... but needed to match ATen shapes. + if dtype == torch.int64: + # Uses floordiv to avoid ceil in inductor. 
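+        # The floordiv expression below reproduces ceil((end - start) / step) with pure
+        # integer arithmetic; sgn is the sign of xstep, so the same formula also handles
+        # negative steps. Illustrative example: arange(0, 10, 3) gives
+        # length = (10 - 0 + 3 - 1) // 3 = 4, i.e. the values [0, 3, 6, 9].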
+ sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined] + length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined] + else: + length = math.ceil((end - start) / step) + + if is_integer: + return prims.iota( + length, + start=xstart, # type: ignore[possibly-undefined] + step=xstep, # type: ignore[possibly-undefined] + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + computation_dtype = utils.get_acc_type(dtype, device) + index = prims.iota( + length, + start=0, + step=1, + dtype=torch.int64, + device=device, + requires_grad=False, + ) + index = _maybe_convert_to_dtype(index, computation_dtype) + result = start + step * index + result = _maybe_convert_to_dtype(result, dtype) + + if requires_grad: + result.requires_grad_(True) + return result + + +@register_decomposition(aten.lerp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("start", "end", "weight"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): + inputs = [start, end] + if isinstance(weight, Number): + weight = start.new_full((), weight) # type: ignore[arg-type] + else: + inputs.append(weight) + assert isinstance(weight, Tensor) # mypy + # We implement it this way for numerical stability. We assume (in the stability optimisation) + # that 0 <= weight <= 1. We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: + # If weight.abs() >= 0.5: + # return (1 - weight) * (start - end) + end + mask = weight.abs() >= 0.5 + coeff = torch.where(mask, weight - 1, weight) + base = torch.where(mask, end, start) + output = coeff * (end - start) + base + # make sure the decomposition output's stride is same as non-decomposition path. 
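+    # (compute_elementwise_output_strides below infers the strides the eager elementwise
+    # kernel would produce for these inputs, and copy_strided rematerializes the result
+    # with those strides when they differ.)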
+ stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) + if output.stride() != stride: + output = prims.copy_strided(output, stride) + + return handle_noncontiguous_outputs(inputs, output) + + +@register_decomposition(aten.linspace) +@out_wrapper() +def linspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, torch.float64) + if isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, torch.float64) + + if py_any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + factory_kwargs = { + "layout": layout, + "device": device, + "pin_memory": pin_memory, + "requires_grad": requires_grad, + } + if steps == 0: + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + if isinstance(start, TensorLikeType): + return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start) # type: ignore[arg-type] + else: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes + rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type] + + # Small types need to be computed in higher precision as this is, at heart, an associative scan + dtype_red = ( + torch.int64 + if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)) + else dtype + ) + computation_dtype, _ = utils.reduction_dtypes( + rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red + ) + cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype) + + # We implement torch.lerp without performing rg / (steps - 1) explicitly + # With this we get out[0] == start, out[-1] == end + step = (end - start) / (steps - 1) + out = torch.where( + rg < steps / 2, + start + step * cast_rg(rg), # type: ignore[arg-type,operator] + end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator] + ) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logspace) +@out_wrapper() +def logspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + base: NumberType = 
10, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy doesn't have this cast + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, FloatLike): + start = sym_int(start) + elif isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, dtype) + if isinstance(end, FloatLike): + end = sym_int(end) + elif isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, dtype) + + if py_any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + dtype = default_complex_dtype + _dtype = None # torch.linspace will update the correct dtype + else: + _dtype = torch.float64 + + assert not isinstance(base, complex) # for mypy + if base < 0: + raise NotImplementedError + ret = torch.linspace( # type: ignore[misc] + start, # type: ignore[arg-type] + end, # type: ignore[arg-type] + steps, # type: ignore[arg-type] + dtype=_dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value] + + +@overload +def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): + pass + + +@overload +def meshgrid(*tensors: TensorLikeType, indexing: str): + pass + + +@register_decomposition(aten.meshgrid) +def meshgrid( + *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]], + indexing: str, +) -> List[TensorLikeType]: + # This ref simultaneously handles two overloads (see stubs above) + # The `indexing` argument is currently optional for torch.meshgrid, but we + # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 + if isinstance(tensors[0], (list, tuple)): + assert len(tensors) == 1 + tensors = tuple(tensors[0]) + + torch._check( + py_all(isinstance(a, TensorLike) for a in tensors), + lambda: "meshgrid expects its inputs to be tensors", + ) + + torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") + + for i in range(len(tensors) - 1): + torch._check( + tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same dtype", + ) + torch._check( + tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same device", + ) + + swap_first_and_second_tensors = False + if indexing == "xy": + swap_first_and_second_tensors = len(tensors) >= 2 + if swap_first_and_second_tensors: + tensors = (tensors[1], tensors[0], *tensors[2:]) + else: + torch._check( + indexing == "ij", + lambda: ( + 'torch.meshgrid: indexing must be one of "xy" or "ij", ' + f"but received: {indexing}" + ), + ) + + result_shape: List[int] = [] + for t in tensors: + assert isinstance(t, TensorLike) # mypy + torch._check( + t.ndim == 0 or t.ndim == 1, + lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", + ) + result_shape.append(t.numel()) + + grids: List[TensorLikeType] = [] + for i, t in enumerate(tensors): + 
assert isinstance(t, TensorLike) # mypy + if t.ndim == 0: + t = t.view((1,)) + grids.append(prims.broadcast_in_dim(t, result_shape, (i,))) + + if swap_first_and_second_tensors: + # Swap outputs if we originally swapped at the beginning + grids[0], grids[1] = grids[1], grids[0] + + return grids + + +# CompositeImplicitAutograd - don't register decomp +def movedim( + input: TensorLikeType, + source: Union[int, DimsSequenceType], + destination: Union[int, DimsSequenceType], +) -> TensorLikeType: + """ + Reference implementation of torch.movedim + """ + if type(source) is int: + source = (source,) + if type(destination) is int: + destination = (destination,) + + # Converts to list to produce a compatible error message with core PyTorch, + # which prints sequences in square brackets. + torch._check( + len(source) == len(destination), # type: ignore[arg-type] + lambda: ( + "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] + f"({list(source)} dims) should contain the same number " # type: ignore[arg-type] + f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type] + ), + ) + + rank = input.ndim + ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type] + ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type] + + sss = set(ss) + dss = set(ds) + + # See above on why this converts to list in error messages. + torch._check( + len(ss) == len(sss), + lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] + ) + torch._check( + len(ds) == len(dss), + lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] + ) + + m = dict(zip(ds, ss)) + dims = [] + si = 0 # source index + for di in range(rank): + # check if the destination index is in the mapping + s = m.get(di) + if s is not None: + # insert source index if found + dims.append(s) + else: + # insert source index sequentially, skipping indices from the mapping + while si in sss: + si += 1 + dims.append(si) + si += 1 + + result = torch.permute(input, tuple(dims)) + + return result + + +# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints +@register_decomposition(aten.empty_strided) +@out_wrapper() +def empty_strided( + shape: Union[ShapeType, Tuple[ShapeType]], + strides: StrideType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # Layout == strided, pin_memory is False + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + shape = utils.extract_shape_from_varargs(shape) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + return prims.empty_strided( + shape, + strides, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.eye) +@out_wrapper() +def eye( + n: int, + m: Optional[int] = None, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, # TODO: unused +) -> TensorLikeType: + """ + Reference implementation of torch.eye + """ + if m is None: + m = n + + torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") + torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got 
{m}") + + range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) + range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) + + cond = range_n.unsqueeze(-1) == range_m + if dtype is torch.bool: + return cond + else: + one = torch.ones( + (1,), + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=False, + ) + return torch.where(cond, one, 0) + # TODO: Use requires_grad. All refs taking the requires_grad kwarg must + # return a leaf tensor. + # result.requires_grad_(requires_grad) + + +@register_decomposition([aten.full.default, aten.full.out]) +@out_wrapper() +def full( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + + e = empty( + shape, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return torch.fill(e, fill_value) # type: ignore[arg-type] + + +def full_like( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + e = torch.empty_like( + a, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + return fill(e, fill_value) + + +@register_decomposition(aten.zeros_like) +@out_wrapper() +def zeros_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.ones_like) +@out_wrapper() +def ones_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.randn.default) +@out_wrapper() +def randn( + *shape, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_pin_memory(pin_memory) + + shape_ = utils.extract_shape_from_varargs(shape) + + dtype = utils.dtype_or_default(dtype) + device = utils.device_or_default(device) + + return prims.normal( + shape_, + 
mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def scalar_tensor( + a: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) + device = device if device is not None else torch.device("cpu") + return prims.scalar_tensor(a, dtype=dtype, device=device) + + +# +# Randomness References +# + + +def _uniform_helper( + shape: ShapeType, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + *, + dtype: torch.dtype, + device: DeviceLikeType, +) -> TensorLikeType: + utils.validate_shape(shape) + + assert isinstance(low, Number) + assert isinstance(high, Number) + low = sym_float(low) + high = sym_float(high) + + assert isinstance(dtype, torch.dtype) + device = utils.canonicalize_device(device) + + return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) + + +@register_decomposition(aten.masked_fill) +@out_wrapper() +def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType): + python_type = utils.dtype_to_type(a.dtype) + if isinstance(value, Number): + value_type = type(value) + else: + # NOTE: Could not use value = item(value) as it resulted in + # RuntimeError: Cannot cast FakeTensor(cpu) to number + value_ndim = value.ndim + torch._check( + value_ndim == 0, + lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", + ) + # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise. + is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu" + torch._check( + is_cpu_scalar or value.device == a.device, + lambda: "Expected `value` to be on same device as `a`", + ) + value_type = utils.dtype_to_type(value.dtype) + + if value_type is complex: + # only downcasting from complex to lower type is not allowed. + # We allow casting `value` to lower type for other case + # Eg. float -> int. 
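+        # (Illustrative: filling a float tensor with the int 3 is allowed and the value is
+        # cast to float, while a complex fill value such as 3+0j is rejected by the check below.)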
+ # Ref: https://github.com/pytorch/pytorch/issues/79195 + torch._check( + utils.is_weakly_lesser_type(value_type, python_type), + lambda: f"could not convert to type {python_type} without overflow", + ) + + # Since `where` allows type-promotion, + # cast value to correct type before passing to `where` + value = _maybe_convert_to_dtype(value, a.dtype) + r = torch.where(mask, value, a) # type: ignore[arg-type] + + # aten.mask_fill always return a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return r.contiguous() + + +@register_decomposition(aten.masked_fill_) +def masked_fill_( + a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType +) -> TensorLikeType: + b = torch.masked_fill(a, mask, value) # type: ignore[arg-type] + a.copy_(b) + return a + + +# CompositeImplicitAutograd - don't register decomp +def allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + return bool( + torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + ) + + +def equal(a: TensorLikeType, b: TensorLikeType) -> bool: + utils.check_same_device(a, b, allow_cpu_scalar_tensors=False) + utils.check_same_dtype(a, b) + + # Shape check + if a.ndim != b.ndim: + return False + + for x, y in zip(a.shape, b.shape): + if x != y: + return False + + # Short-circuits if there are no elements to validate + if a.numel() == 0: + return True + + return item(all(eq(a, b))) # type: ignore[return-value] + + +@register_decomposition(aten.norm) +@out_wrapper(exact_dtype=True) +def norm( + input: TensorLikeType, + p: Optional[Union[float, str]] = "fro", + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # In these cases we compute the "Frobenius norm" + if ( + p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2) + ) or p is None: + p = 2 + if isinstance(dim, Dim): + dim = [dim] + if isinstance(p, str): + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if dim is None: + dim = tuple(range(input.ndim)) + return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype) + + +@register_decomposition(aten.trace) +@out_wrapper() +def trace(self: TensorLikeType) -> TensorLikeType: + torch._check( + self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" + ) + return torch.sum(torch.diag(self, 0)) + + +def _make_r_binary_op(base_op): + def rop( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + ) -> TensorLikeType: + return base_op(b, a) + + return rop + + +rtruediv = _make_r_binary_op(true_divide) +rfloordiv = _make_r_binary_op(floor_divide) +rpow = _make_r_binary_op(pow) + + +@register_decomposition(aten.triu) +@out_wrapper() +def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) >= diagonal + + # aten.triu always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return 
utils.mask_tensor(mask, a).contiguous() + + +@register_decomposition(aten.tril) +@out_wrapper() +def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) <= diagonal + + # aten.tril always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h +# The components of the matrix that belong to the lower triangle with offset +# form a pentagon that can be broken down into a top trapezoid and a bottom +# rectangle. For the implementation of tril_indices, we need the sizes of +# both of these, as well as the length of the top side of the trapezoid. +def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) + m_last_row = max(0, min(col, row + offset)) + n_row_all = max(0, min(row, row + offset)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size = max(0, diff_row * col) + + return trapezoid_size, rectangle_size, m_first_row + + +def _trilu_checks( + name: str, + row: int, + col: int, + dtype: torch.dtype, + layout: torch.layout, + pin_memory: bool, +): + torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") + torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") + torch._check( + dtype in (torch.int32, torch.int64), + lambda: f"\"{name}\" not implemented for '{dtype}'", + ) + + +# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu +@register_decomposition(aten.tril_indices) +@out_wrapper() +def tril_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset) + row_offset = max(0, -offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # first we do the indices for top trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = m_first_row - 0.5 + row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1)) + col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + # then bottom rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset) + col_inds2 = xs2 % col + + return torch.stack( + (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2))) + ) + + +# Similar to _get_tril_sizes above, but here there is a top trapezoid and +# a bottom rectangle instead. 
Note that you can't reduce this to +# _get_tril_sizes(col, row, -offset) because that would correspond to +# decomposing into a left trapezoid and right rectangle. +def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = max(0, col - offset) if offset > 0 else col + + # Number of elements in top rectangle + rectangle_size = max(0, min(row, -offset) * col) + + # Number of elements in bottom trapezoid + trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1) + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + trapezoid_size = triu_size - rectangle_size + + return trapezoid_size, rectangle_size, m_first_row + + +@register_decomposition(aten.triu_indices) +@out_wrapper() +def triu_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset) + col_offset = max(0, offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # indices for top rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + col_inds2 = xs2 % col + + # bottom trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = -0.5 - m_first_row + row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1)) + col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + if col: + row_inds1 = row_inds1 + (rectangle_size // col) + col_inds1 = col_inds1 + col_offset + + return torch.stack( + (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1))) + ) + + +@register_decomposition(aten.bucketize) +@out_wrapper(exact_dtype=True) +def bucketize( + a: TensorLikeType, + boundaries: TensorLikeType, + *, + out_int32: bool = False, + right: bool = False, +): + torch._check( + boundaries.dim() == 1, + lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", + ) + + out_dtype = torch.int32 if out_int32 else torch.int64 + n_boundaries = boundaries.shape[-1] + if n_boundaries == 0: + return torch.zeros_like(a) + # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) + # each element of `a` belongs to. 
We use binary search to achieve logarithimic complexity, + # but each step of the search is done "in parallel" over all elements of `a` + # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end + start = torch.zeros(a.shape, device=a.device, dtype=torch.int64) + end = start + n_boundaries + # Max depth of the binary search + # Since we can't break out of the loop at different points for different elements of a, + # we just do the max amount of iterations that binary search requires and add condition + # tensor (cond_update below) to stop updating once the search terminates + + # For first iteration through loop we can skip some checks, we have separate implementation + mid = start + (end - start) // 2 + mid_val = boundaries[mid] + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where(cond_mid, start, mid + 1) + + if n_boundaries > 1: + cond_update = torch.ones_like(a, dtype=torch.bool) + niters = int(math.log2(n_boundaries)) + for _ in range(niters): + end = torch.where(cond_mid & cond_update, mid, end) + cond_update = start < end + # start might end up pointing to 1 past the end, we guard against that + mid = torch.where(cond_update, start + (end - start) // 2, 0) + mid_val = boundaries[mid] + # If right is true, the buckets are closed on the *left* + # (i.e., we are doing the equivalent of std::upper_bound in C++) + # Otherwise they are closed on the right (std::lower_bound) + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where((~cond_mid) & cond_update, mid + 1, start) + + return start.to(dtype=out_dtype) + + +@register_decomposition(aten.cauchy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def cauchy(self, median=0, sigma=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Cauchy distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + sigma > 0.0, + lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", + ) + return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5)) + + +@register_decomposition(aten.exponential) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def exponential(self, rate=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Exponential distribution is a continuous probability distribution. 
\ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + rate > 0.0, + lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", + ) + return -1 / rate * torch.log1p(-torch.rand_like(self)) + + +@register_decomposition(aten.geometric) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def geometric(self, p, generator=None): + assert generator is None + # TODO: fix inductor rand_like for integer, bool dtypes + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"geometric not implemented for {self.dtype}", + ) + torch._check( + 0 < p and p < 1, + lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", + ) + return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1 + + +@register_decomposition(aten.log_normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def log_normal(self, mean=1, std=2, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"log_normal not implemented for {self.dtype}", + ) + torch._check( + 0 < std, + lambda: f"log_normal_ expects std > 0.0, but found std={std}", + ) + return torch.exp(std * torch.randn_like(self) + mean) + + +# TODO: add support for functionalization aten.normal_functional +# NOTE: the device and dtype will be ignored when shape is None +@register_decomposition(aten.normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=( + "mean", + "std", + ), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def normal( + mean=0, + std=1, + size=None, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + assert layout is None or layout == torch.strided + + if not isinstance(std, TensorLike): + torch._check( + std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}" + ) + + if size is None: + tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike)) + torch._check( + len(tensors) > 0, + lambda: "normal expects that either mean or std is a tensor, or size is defined", + ) + torch._check( + layout is None and pin_memory is None, + lambda: "Cannot pass layout, or pin_memory without size", + ) + + size = _broadcast_shapes(*(t.shape for t in tensors)) + dtype = tensors[0].dtype + device = tensors[0].device + else: + torch._check( + not isinstance(mean, TensorLike) and not isinstance(std, TensorLike), + lambda: "normal expects mean and std to be scalars when size is defined", + ) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + normal_samples = prims.normal( + size, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=False, + generator=generator, + ) + return std * normal_samples + mean + + +@register_decomposition(aten.normal_) +def normal_(self, mean=0, std=1, *, generator=None): + return normal(mean, std, self.shape, out=self, generator=generator) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rad2deg(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "rad2deg is not supported for complex tensors.", + ) 
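+    # 180 / pi: multiplying radians by this constant converts them to degrees,
+    # e.g. rad2deg(pi / 2) == 90.0.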
+ M_180_PI = 57.295779513082320876798154814105170332405472466564 + return self * M_180_PI + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def deg2rad(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "deg2rad is not supported for complex tensors.", + ) + M_PI_180 = 0.017453292519943295769236907684886127134428718885417 + return self * M_PI_180 + + +@register_decomposition(aten.count_nonzero) +@out_wrapper() +def count_nonzero(self, dim: Optional[DimsType] = None): + return (self != 0).sum(dim) + + +def _dot_check(self, other): + torch._check( + self.dim() == 1 and other.dim() == 1, + lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", + ) + + def numel_error(): + return ( + f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the" + f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" + ) + + torch._check(self.numel() == other.numel(), numel_error) + + +@register_decomposition(aten.dot) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def dot(self, other): + if self.is_complex(): + if self.is_conj(): + if other.is_conj(): + return torch.dot(self.conj(), other.conj()).conj() + else: + return torch.vdot(self.conj(), other) + elif other.is_conj(): + return torch.vdot(other.conj(), self) + + _dot_check(self, other) + return (self * other).sum() + + +@register_decomposition(aten.vdot) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vdot(self, other): + if not self.is_complex(): + return torch.dot(self, other) + + if self.is_conj(): + if other.is_conj(): + return torch.vdot(other.conj(), self.conj()) + else: + return torch.dot(self.conj(), other) + elif other.is_conj(): + return torch.dot(self, other.conj()).conj() + + _dot_check(self, other) + # The decomposition fails if you do self.conj()... 
not sure why + return (self.conj_physical() * other).sum() + + +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +add_ = _make_inplace(add) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +bitwise_and_ = _make_inplace(bitwise_and) +bitwise_left_shift_ = _make_inplace(bitwise_left_shift) +bitwise_not_ = _make_inplace(bitwise_not) +bitwise_or_ = _make_inplace(bitwise_or) +bitwise_right_shift_ = _make_inplace(bitwise_right_shift) +bitwise_xor_ = _make_inplace(bitwise_xor) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +cumprod_ = _make_inplace(cumprod) +deg2rad_ = _make_inplace(deg2rad) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = _make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +gcd_ = _make_inplace(gcd) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +i0_ = _make_inplace(i0) +lcm_ = _make_inplace(lcm) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_not_ = _make_inplace(logical_not) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mul_ = _make_inplace(mul) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +rad2deg_ = _make_inplace(rad2deg) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = _make_inplace(sqrt) +square_ = _make_inplace(square) +sub_ = _make_inplace(sub) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) +cauchy_ = _make_inplace(cauchy) +exponential_ = _make_inplace(exponential) +geometric_ = _make_inplace(geometric) +log_normal_ = _make_inplace(log_normal) +zero_ = _make_inplace(zero) + + +# xref: isStorage in torch/csrc/DynamicTypes.cpp +def _isStorage(obj): + return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage)) + + +# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp +def _compute_sizes(seq, scalar_type): + MAX_DIMS = 128 + is_storage = _isStorage(seq) + sizes = [] + # TODO: this is 
inaccurate, we actually test PySequence_Check + while isinstance(seq, (list, tuple)): + length = len(seq) + if is_storage: + length //= scalar_type.itemsize + sizes.append(length) + if len(sizes) > MAX_DIMS: + raise ValueError(f"too many dimensions '{type(seq).__name__}'") + if length == 0: + break + try: + handle = seq[0] + except Exception: + raise ValueError( # noqa: TRY200 + f"could not determine the shape of object type '{type(seq).__name__}'" + ) + seq = handle + + return sizes + + +# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp +def _infer_scalar_type(obj): + if isinstance(obj, FloatLike): + return torch.get_default_dtype() + if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful! + return torch.int64 + if isinstance(obj, bool): + return torch.bool + if isinstance(obj, complex): + default_dtype = torch.get_default_dtype() + if default_dtype is torch.float: + return torch.cfloat + elif default_dtype is torch.double: + return torch.cdouble + else: + raise RuntimeError("invalid default scalar type for complex") + if isinstance(obj, torch.Tensor): + return obj.dtype + if isinstance(obj, str): + raise TypeError(f"new(): invalid data type '{type(obj).__name__}'") + # TODO: this is inaccurate, we actually test PySequence_Check + if isinstance(obj, (list, tuple)): + scalarType = None + length = len(obj) + # match NumPy semantics, except use default tensor type instead of + # double. + if length == 0: + return torch.get_default_dtype() + for i in range(length): + cur_item = obj[i] + # TODO: test this + """ + if cur_item is obj: + raise TypeError("new(): self-referential lists are incompatible") + """ + item_scalarType = _infer_scalar_type(cur_item) # recurse! + if scalarType is not None: + scalarType = torch.promote_types(scalarType, item_scalarType) + else: + scalarType = item_scalarType + if scalarType is torch.cdouble: + # this won't change (unless we hit undefined, but that will + # fail later) + return scalarType + return scalarType + raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}") + + +# Analogous to recursive_store +# xref: recursive_store in torch/csrc/utils/tensor_new.cpp +def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType): + if isinstance(obj, Tensor) and obj.ndim <= 1: + obj = obj.item() + # fall through into next case + if isinstance(obj, Number): + return torch.scalar_tensor(obj, dtype=scalarType) + + seq = obj + return torch.stack([_recursive_build(scalarType, item) for item in seq]) + + +# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp +def _internal_new_from_data( + options, + scalar_type, + device_opt, + data, + copy_variables, + copy_numpy, + type_inference, + pin_memory=False, +): + if isinstance(data, torch.Tensor): + torch._check( + not pin_memory, lambda: "Can't pin tensor constructed from a variable" + ) + var = data + if copy_variables: + var = var.detach() + inferred_scalar_type = var.dtype if type_inference else scalar_type + device = device_opt if device_opt is not None else var.device + return var.to( + device=device, + dtype=inferred_scalar_type, + non_blocking=False, + copy=copy_variables, + ) + + # TODO + if hasattr(data, "__cuda_array_interface__"): + return NotImplemented + + # TODO: test for numpy input with PyArray_Check + + device = device_opt if device_opt is not None else options["device"] + inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type + + # NB: Don't need to avoid tracing, as we aren't going to do any manual + # pointer filling 
tricks + if _isStorage(data): + return NotImplemented + else: + if torch.device(device).type == "meta": + return NotImplemented + + # In the C implementation, we would directly start poking the memory + # of a freshly allocated CPU tensor. Here, we're going to do an + # alternate, heinously slow implementation: turn each individual + # scalar into a tensor, and then repeatedly cat them together + tensor = _recursive_build(inferred_scalar_type, data) + + tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False) + + # NB: lift_fresh is not needed, because we built the tensor from scalars + # guaranteeing a fresh tensor in this case + return tensor + + +# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp +def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False): + # TODO (or not): support names kwarg + if isinstance(data, torch.Tensor): + warnings.warn( + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " + "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)" + ) + type_inference = dtype is None + new_tensor = _internal_new_from_data( + # device="cpu" because that's what you get with torch.tensor(2) no + # device by default + {"device": "cpu"}, # TODO: use torch.get_default_tensor_type + dtype if dtype is not None else torch.get_default_dtype(), + device, + data, + copy_variables=True, + copy_numpy=True, + type_inference=type_inference, + pin_memory=pin_memory, + ) + new_tensor.detach_() + new_tensor.requires_grad_(requires_grad) + return new_tensor + + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + +import torch._refs._conversions +import torch._refs.fft +import torch._refs.linalg +import torch._refs.nn.functional +import torch._refs.special diff --git a/MLPY/Lib/site-packages/torch/_refs/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0df7830c61db3cf2f46d2eb8e5f742c8e243913d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f3c25efae588e4d76c5341b9db360d74fd57fac Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e8d69b19c2775cdb761fac83b8f2feb82c1137 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/_conversions.py b/MLPY/Lib/site-packages/torch/_refs/_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..ce345330e5676ef107060ed93dc50bc9ddd5ebcb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/_conversions.py @@ -0,0 +1,118 @@ +import torch 
+import torch._prims_common as utils + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +from torch._prims_common import TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes + +# Data conversion references. +# +# Note: this module breaks the usual _refs to torch naming scheme where +# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not +# part of _refs/__init__.py to avoid name clashes with Python builtin types +# (like int). + +__all__ = [ + # dtypes + "bfloat16", + "bool", + "byte", + "cdouble", + "cfloat", + "chalf", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + # misc + "complex", + "polar", +] + + +def _make_conversion_method(name: str, dtype: torch.dtype): + def fn( + self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format + ) -> TensorLikeType: + return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload] + + fn.__name__ = name + return fn + + +bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) + +bool = _make_conversion_method("bool", torch.bool) + +byte = _make_conversion_method("byte", torch.uint8) + +cdouble = _make_conversion_method("cdouble", torch.cdouble) + +cfloat = _make_conversion_method("cfloat", torch.cfloat) + +chalf = _make_conversion_method("chalf", torch.complex32) + +char = _make_conversion_method("char", torch.int8) + +double = _make_conversion_method("double", torch.double) + +float = _make_conversion_method("float", torch.float) + +half = _make_conversion_method("half", torch.half) + +int = _make_conversion_method("int", torch.int) + +long = _make_conversion_method("long", torch.long) + +short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch._ops.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. +@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + torch._check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + torch._check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result + + +@register_decomposition(torch._ops.ops.aten.polar) +# Note: polar has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
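+# In effect polar(abs, angle) builds abs * exp(i*angle): real = abs*cos(angle), imag = abs*sin(angle).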
+@out_wrapper(exact_dtype=True) +def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: + result = torch.complex(abs, angle) + result.real = abs * torch.cos(angle) + result.imag = abs * torch.sin(angle) + return result diff --git a/MLPY/Lib/site-packages/torch/_refs/fft.py b/MLPY/Lib/site-packages/torch/_refs/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..47500148f8428462ab9b46f4d767047e12276e46 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/fft.py @@ -0,0 +1,590 @@ +import math + +from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +from torch._decomp import register_decomposition +from torch._prims_common import DimsType, ShapeType, TensorLikeType +from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper + +__all__ = [ + # Transforms + "fft", + "fft2", + "fftn", + "hfft", + "hfft2", + "hfftn", + "rfft", + "rfft2", + "rfftn", + "ifft", + "ifft2", + "ifftn", + "ihfft", + "ihfft2", + "ihfftn", + "irfft", + "irfft2", + "irfftn", + # Helpers + "fftshift", + "ifftshift", +] + +NormType = Union[None, Literal["forward", "backward", "ortho"]] +_NORM_VALUES = {None, "forward", "backward", "ortho"} +aten = torch._ops.ops.aten + + +def _apply_norm( + x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool +) -> TensorLikeType: + """Apply normalization to the un-normalized FFT result""" + torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + + if norm == "ortho": + return x * (1 / math.sqrt(signal_numel)) + + normalize = (not forward and (norm is None or norm == "backward")) or ( + forward and norm == "forward" + ) + return x * (1 / signal_numel) if normalize else x + + +def _promote_type_fft( + dtype: torch.dtype, require_complex: bool, device: torch.device +) -> torch.dtype: + """Helper to promote a dtype to one supported by the FFT primitives""" + if dtype.is_complex: + return dtype + + # Promote integral to default float type + if not dtype.is_floating_point: + dtype = torch.get_default_dtype() + + allowed_types = [torch.float32, torch.float64] + maybe_support_half = device.type in ["cuda", "meta"] + + if maybe_support_half: + allowed_types.append(torch.float16) + torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") + + if require_complex: + dtype = utils.corresponding_complex_dtype(dtype) + + return dtype + + +def _maybe_promote_tensor_fft( + t: TensorLikeType, require_complex: bool = False +) -> TensorLikeType: + """Helper to promote a tensor to a dtype supported by the FFT primitives""" + cur_type = t.dtype + new_type = _promote_type_fft(cur_type, require_complex, t.device) + return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] + + +def _resize_fft_input( + x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...] +) -> TensorLikeType: + """ + Fixes the shape of x such that x.size(dims[i]) == sizes[i], + either by zero-padding, or by slicing x starting from 0. 
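+ For example, with dims=(0,) and sizes=(8,) a length-5 signal gains three
+ trailing zeros, while sizes=(3,) keeps only x[:3] along that dimension.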
+ """ + assert len(dims) == len(sizes) + must_copy = False + x_sizes = x.shape + pad_amount = [0] * len(x_sizes) * 2 + for i in range(len(dims)): + if sizes[i] == -1: + continue + + if x_sizes[dims[i]] < sizes[i]: + must_copy = True + pad_idx = len(pad_amount) - 2 * dims[i] - 1 + pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] + + if x_sizes[dims[i]] > sizes[i]: + x = x.narrow(dims[i], 0, sizes[i]) + + return torch.constant_pad_nd(x, pad_amount) if must_copy else x + + +def _fft_c2r( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to real FFT (irfft or hfft)""" + input = _maybe_promote_tensor_fft(input, require_complex=True) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + if n is not None: + input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) + + if forward: + input = torch.conj(input) + + output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size) + return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward) + + +def _fft_r2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, + onesided: bool, +) -> TensorLikeType: + """Common code for performing any real to complex FFT (rfft or ihfft)""" + torch._check( + not input.dtype.is_complex, + lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", + ) + input = _maybe_promote_tensor_fft(input) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_r2c(input, dim=dims, onesided=onesided) + ret = _apply_norm(ret, norm, dim_size, forward) + return ret if forward else torch.conj(ret) + + +def _fft_c2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to complex FFT (fft or ifft)""" + torch._check( + input.dtype.is_complex, + lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", + ) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_c2c(input, dim=dims, forward=forward) + return _apply_norm(ret, norm, dim_size, forward) + + +@register_decomposition(aten.fft_fft) +@out_wrapper() +def fft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("fft", input, n, dim, norm, forward=True) + else: + return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False) + + +@register_decomposition(aten.fft_ifft) +@out_wrapper() +def ifft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("ifft", 
input, n, dim, norm, forward=False) + else: + return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False) + + +@register_decomposition(aten.fft_rfft) +@out_wrapper() +def rfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) + + +@register_decomposition(aten.fft_irfft) +@out_wrapper() +def irfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("irfft", input, n, dim, norm, forward=False) + + +@register_decomposition(aten.fft_hfft) +@out_wrapper() +def hfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("hfft", input, n, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ihfft) +@out_wrapper() +def ihfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) + + +class _ShapeAndDims(NamedTuple): + shape: Tuple[int, ...] + dims: Tuple[int, ...] + + +def _canonicalize_fft_shape_and_dim_args( + input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] +) -> _ShapeAndDims: + """Convert the shape and dim arguments into a canonical form where neither are optional""" + input_dim = input.ndim + input_sizes = input.shape + + if dim is not None: + if not isinstance(dim, Sequence): + dim = (dim,) + ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) + + # Check dims are unique + torch._check( + len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique" + ) + + if shape is not None: + if not isinstance(shape, Sequence): + shape = (shape,) + + # Has shape, might have dim + torch._check( + dim is None or len(dim) == len(shape), + lambda: "When given, dim and shape arguments must have the same length", + ) + transform_ndim = len(shape) + + torch._check( + transform_ndim <= input_dim, + lambda: f"Got shape with {transform_ndim} values but input tensor " + f"only has {input_dim} dimensions.", + ) + + # If shape is given, dims defaults to the last len(shape) dimensions + if dim is None: + ret_dims = tuple(range(input_dim - transform_ndim, input_dim)) + + # Translate any -1 values in shape to the default length + ret_shape = tuple( + s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] + ) + elif dim is None: + # No shape, no dim + ret_dims = tuple(range(input_dim)) + ret_shape = tuple(input_sizes) + else: + # No shape, has dim + ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined] + + for n in ret_shape: + torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified") + + return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined] + + +def _prod(xs: Iterable[int]) -> int: + """Compute product of a list""" + prod = 1 + for x in xs: + prod *= x + return prod + + +def _fftn_c2c( + function_name: str, + input: TensorLikeType, + shape: Tuple[int, ...], + dim: Tuple[int, ...], + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" + torch._check( + input.dtype.is_complex, + lambda: f"{function_name} expects a complex input tensor, " + f"but got {input.dtype}", + ) + x = 
_resize_fft_input(input, dim, shape) + output = prims.fft_c2c(x, dim=dim, forward=forward) + return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward) + + +@register_decomposition(aten.fft_fftn) +@out_wrapper() +def fftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("fftn", x, shape, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ifftn) +@out_wrapper() +def ifftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) + + +@register_decomposition(aten.fft_rfftn) +@out_wrapper() +def rfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_r2c(input, dim=dim, onesided=True) + return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) + + +@register_decomposition(aten.fft_ihfftn) +@out_wrapper() +def ihfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) + + if len(dim) == 1: + tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) + return prims.conj(tmp) + + tmp = prims.conj_physical(tmp) + tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) + return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) + + +class _CanonicalizeC2rReturn(NamedTuple): + shape: Tuple[int, ...] + dim: Tuple[int, ...] 
+ last_dim_size: int + + +def _canonicalize_fft_c2r_shape_and_dim_args( + fname: str, + input: TensorLikeType, + s: Optional[ShapeType], + dim: Optional[DimsType], +) -> _CanonicalizeC2rReturn: + """Canonicalize shape and dim arguments for n-dimensional c2r transforms, + as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") + + if s is None or s[-1] == -1: + last_dim_size = 2 * (input.shape[dim[-1]] - 1) + else: + last_dim_size = shape[-1] + + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + shape_list = list(shape) + shape_list[-1] = last_dim_size // 2 + 1 + return _CanonicalizeC2rReturn( + shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size + ) + + +@register_decomposition(aten.fft_irfftn) +@out_wrapper() +def irfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "irfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size) + return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False) + + +@register_decomposition(aten.fft_hfftn) +@out_wrapper() +def hfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "hfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input + tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) + tmp = prims.conj_physical(tmp) + out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) + return _apply_norm(out, norm, last_dim_size, forward=True) + + +@register_decomposition(aten.fft_fft2) +@out_wrapper() +def fft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.fftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ifft2) +@out_wrapper() +def ifft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_rfft2) +@out_wrapper() +def rfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_irfft2) +@out_wrapper() +def irfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_hfft2) +@out_wrapper() +def hfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return 
torch.fft.hfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ihfft2) +@out_wrapper() +def ihfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) + + +def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]: + """Convert Optional[DimsType] to a simple list, defaulting to all dimensions""" + if dim is None: + return list(range(x.ndim)) + elif not isinstance(dim, Sequence): + return [dim] + else: + return list(dim) + + +@register_decomposition(aten.fft_fftshift) +def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [input.shape[d] // 2 for d in dims] + return torch.roll(input, shift, dims) + + +@register_decomposition(aten.fft_ifftshift) +def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [(input.shape[d] + 1) // 2 for d in dims] + return torch.roll(input, shift, dims) diff --git a/MLPY/Lib/site-packages/torch/_refs/linalg/__init__.py b/MLPY/Lib/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..492f43e840909c6236f98c3e138022ff8317d9be --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,308 @@ +from functools import partial + +from typing import List, Optional, Tuple, Union + +import torch + +import torch._prims as prims + +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + NumberType, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. 
Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + "without narrowing to the specified dtype ({dtype})", + ) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + # Checks + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + if x.numel() == 0 and (ord < 0.0 or ord == float("inf")): + torch._check( + dim is not None and len(dim) != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + shape = x.shape + assert dim is not None # mypy does not seem to be able to see through check? 
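+ # Each reduced dimension must itself be non-empty: the requested norm has no identity element.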
+ for d in dim: + torch._check( + shape[d] != 0, + lambda: f"linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + dim[0] != dim[1], + lambda: "linalg.matrix_norm: dims must be different. 
Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return max_min( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. 
Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) diff --git a/MLPY/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4340d23e2be53fba4f6dd5dd4613cb4281b8f57c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/nn/__init__.py b/MLPY/Lib/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..970be144221489803f5ff4fcbe500037775ca79e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1,3 @@ +from typing import List + +__all__: List[str] = [] diff --git a/MLPY/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef53da8d86a863fa1bff3e5a1b262b24400454be Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/nn/functional/__init__.py b/MLPY/Lib/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0a32d1c7c0da310a3d5e965ebacfed6be163b9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1230 @@ +import math +from functools import wraps +from typing import Callable, Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", + "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + "nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + 
"softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn): + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, inplace=False, **kwargs): + if inplace: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + return fn(a, *args, inplace=False, out=a, **kwargs) + else: + return fn(a, *args, inplace=False, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. 
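+ Expects input of shape (N, C, *) with C divisible by num_groups; the trailing
+ dimensions are flattened before dispatching to torch.native_group_norm.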
+ """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. 
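+ # e.g. softmax(x) without an explicit dim raises here; pass an explicit dim such as dim=-1 instead.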
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + lambd >= 0, + lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in 
("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. 
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, −target * (input1 − input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not 
the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classe, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. 
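+ # e.g. a segmentation-style input (N, C, H, W) with target (N, H, W) is
+ # flattened to (N, C, H*W) and (N, H*W) before taking the 3-D path below.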
+ # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" 
+ ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. +def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. 
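Illustrative sketch of the distance swap described above, using the public torch.pairwise_distance; the points are made up so that the negative sits closer to the positive than to the anchor.

import torch

anchor = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[1.0, 0.0]])
negative = torch.tensor([[1.1, 0.0]])  # closer to the positive than to the anchor

d_ap = torch.pairwise_distance(anchor, positive)    # ~1.0
d_an = torch.pairwise_distance(anchor, negative)    # ~1.1
d_pn = torch.pairwise_distance(positive, negative)  # ~0.1

# with swap=True the smaller of d_an and d_pn is used, which tightens the loss
d_neg = torch.minimum(d_an, d_pn)
loss = torch.clamp_min(1.0 + d_ap - d_neg, 0)       # margin = 1.0
assert d_neg < d_an and loss > 0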
+ if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." 
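A quick numerical comparison of the two gelu branches defined a little further above (the exact erf form and the tanh approximation); purely illustrative, with an arbitrary input range.

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)
exact = F.gelu(x)                       # erf-based, approximate="none"
approx = F.gelu(x, approximate="tanh")  # tanh approximation
# on this range the approximation tracks the exact form to well under 1e-2
assert torch.allclose(exact, approx, atol=5e-3)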
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, tuple() if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p 
values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/MLPY/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4f31ffa83a3615eccde9a75c3d2158a155b9b1 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_refs/special/__init__.py b/MLPY/Lib/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d57e327f049b8555b90011f7f588927f925832c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,236 @@ +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs + +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, 
out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / √2 + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.clamp(self, lo, hi) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
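Illustrative values for the convention implemented below (made-up inputs): a == 0 forces the result to 0 even where log1p(b) would blow up, while NaN in b always propagates.

import torch

a = torch.tensor([0.0, 0.0, 2.0, 1.0])
b = torch.tensor([-1.0, 5.0, 3.0, float("nan")])
out = torch.special.xlog1py(a, b)
# -> [0., 0., 2*log(4), nan]
assert out[0] == 0 and out[1] == 0
assert torch.allclose(out[2], 2 * torch.log(torch.tensor(4.0)))
assert torch.isnan(out[3])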
+ if isinstance(a, TensorLike) and isinstance(b, Number): + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / √2 + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/MLPY/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b7bd6bb560e0a1ed99fe5127728b63fb63f951d Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_sources.py b/MLPY/Lib/site-packages/torch/_sources.py new file mode 100644 index 0000000000000000000000000000000000000000..c5342e5ce12f0e13889fcb944e1df913ea5ccc39 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_sources.py @@ -0,0 +1,137 @@ +import ast +import functools +import inspect +from textwrap import dedent +from typing import Any, List, NamedTuple, Optional, Tuple + +from torch._C import 
ErrorReport +from torch._C._jit_tree_views import SourceRangeFactory + + +def get_source_lines_and_file( + obj: Any, + error_msg: Optional[str] = None, +) -> Tuple[List[str], int, Optional[str]]: + """ + Wrapper around inspect.getsourcelines and inspect.getsourcefile. + + Returns: (sourcelines, file_lino, filename) + """ + filename = None # in case getsourcefile throws + try: + filename = inspect.getsourcefile(obj) + sourcelines, file_lineno = inspect.getsourcelines(obj) + except OSError as e: + msg = ( + f"Can't get source for {obj}. TorchScript requires source access in " + "order to carry out compilation, make sure original .py files are " + "available." + ) + if error_msg: + msg += "\n" + error_msg + raise OSError(msg) from e + + return sourcelines, file_lineno, filename + + +def normalize_source_lines(sourcelines: List[str]) -> List[str]: + """ + This helper function accepts a list of source lines. It finds the + indentation level of the function definition (`def`), then it indents + all lines in the function body to a point at or greater than that + level. This allows for comments and continued string literals that + are at a lower indentation than the rest of the code. + Args: + sourcelines: function source code, separated into lines by + the '\n' character + Returns: + A list of source lines that have been correctly aligned + """ + + def remove_prefix(text, prefix): + return text[text.startswith(prefix) and len(prefix) :] + + # Find the line and line number containing the function definition + idx = None + for i, l in enumerate(sourcelines): + if l.lstrip().startswith("def"): + idx = i + break + + # This will happen when the function is a lambda- we won't find "def" anywhere in the source + # lines in that case. Currently trying to JIT compile a lambda will throw an error up in + # `parse_def()`, but we might want to handle this case in the future. + if idx is None: + return sourcelines + + # Get a string representing the amount of leading whitespace + fn_def = sourcelines[idx] + whitespace = fn_def.split("def")[0] + + # Add this leading whitespace to all lines before and after the `def` + aligned_prefix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx] + ] + aligned_suffix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :] + ] + + # Put it together again + aligned_prefix.append(fn_def) + return aligned_prefix + aligned_suffix + + +# Thin wrapper around SourceRangeFactory to store extra metadata +# about the function-to-be-compiled. 
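A small usage sketch for the helpers above; torch._sources is a private module, so this is illustrative only and the input lines are made up.

from torch._sources import normalize_source_lines

lines = [
    "    def f(x):\n",
    "# stray comment at column 0\n",
    "        return x\n",
]
aligned = normalize_source_lines(lines)
# the comment is re-indented to match the leading whitespace of the `def` line
assert aligned[1] == "    # stray comment at column 0\n"
assert aligned[0] == lines[0] and aligned[2] == lines[2]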
+class SourceContext(SourceRangeFactory): + def __init__( + self, + source, + filename, + file_lineno, + leading_whitespace_len, + uses_true_division=True, + funcname=None, + ): + super().__init__(source, filename, file_lineno, leading_whitespace_len) + self.uses_true_division = uses_true_division + self.filename = filename + self.funcname = funcname + + +@functools.lru_cache(maxsize=None) +def make_source_context(*args): + return SourceContext(*args) + + +def fake_range(): + return SourceContext("", None, 0, 0).make_raw_range(0, 1) + + +class ParsedDef(NamedTuple): + ast: ast.Module + ctx: SourceContext + source: str + filename: Optional[str] + file_lineno: int + + +def parse_def(fn): + sourcelines, file_lineno, filename = get_source_lines_and_file( + fn, ErrorReport.call_stack() + ) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): + raise RuntimeError( + f"Expected a single top-level function: {filename}:{file_lineno}" + ) + leading_whitespace_len = len(source.split("\n", 1)[0]) - len( + dedent_src.split("\n", 1)[0] + ) + ctx = make_source_context( + source, filename, file_lineno, leading_whitespace_len, True, fn.__name__ + ) + return ParsedDef(py_ast, ctx, source, filename, file_lineno) diff --git a/MLPY/Lib/site-packages/torch/_storage_docs.py b/MLPY/Lib/site-packages/torch/_storage_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..a32bb6a1222355fd550a3e23cc8e0ef376eeb4c6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_storage_docs.py @@ -0,0 +1,43 @@ +"""Adds docstrings to Storage functions""" + +import torch._C +from torch._C import _add_docstr as add_docstr + + +storage_classes = [ + "StorageBase", +] + + +def add_docstr_all(method, docstr): + for cls_name in storage_classes: + cls = getattr(torch._C, cls_name) + try: + add_docstr(getattr(cls, method), docstr) + except AttributeError: + pass + + +add_docstr_all( + "from_file", + """ +from_file(filename, shared=False, size=0) -> Storage + +Creates a CPU storage backed by a memory-mapped file. + +If ``shared`` is ``True``, then memory is shared between all processes. +All changes are written to the file. If ``shared`` is ``False``, then the changes on +the storage do not affect the file. + +``size`` is the number of elements in the storage. If ``shared`` is ``False``, +then the file must contain at least ``size * sizeof(Type)`` bytes +(``Type`` is the type of storage, in the case of an ``UnTypedStorage`` the file must contain at +least ``size`` bytes). If ``shared`` is ``True`` the file will be created if needed. 
+ +Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the storage +""", +) diff --git a/MLPY/Lib/site-packages/torch/_streambase.py b/MLPY/Lib/site-packages/torch/_streambase.py new file mode 100644 index 0000000000000000000000000000000000000000..db9fc14b892a5e102a3dba12d3bf54465d123aaa --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_streambase.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod + + +class _StreamBase(ABC): + r"""Base stream class abstraction for multi backends Stream to herit from""" + + @abstractmethod + def wait_event(self, event): + raise NotImplementedError() + + @abstractmethod + def wait_stream(self, stream): + raise NotImplementedError() + + @abstractmethod + def record_event(self, event=None): + raise NotImplementedError() + + @abstractmethod + def query(self): + raise NotImplementedError() + + @abstractmethod + def synchronize(self): + raise NotImplementedError() + + @abstractmethod + def __eq__(self, stream): + raise NotImplementedError() + + +class _EventBase(ABC): + r"""Base Event class abstraction for multi backends Event to herit from""" + + @abstractmethod + def wait(self, stream=None): + raise NotImplementedError() + + @abstractmethod + def query(self): + raise NotImplementedError() + + @abstractmethod + def synchronize(self): + raise NotImplementedError() diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__init__.py b/MLPY/Lib/site-packages/torch/_subclasses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2eff305df6928c21d603dc08a4f22fb45a8859 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/__init__.py @@ -0,0 +1,18 @@ +import torch + +from torch._subclasses.fake_tensor import ( + DynamicOutputShapeException, + FakeTensor, + FakeTensorMode, + UnsupportedFakeTensorException, +) + +from torch._subclasses.fake_utils import CrossRefFakeMode + +__all__ = [ + "FakeTensor", + "FakeTensorMode", + "UnsupportedFakeTensorException", + "DynamicOutputShapeException", + "CrossRefFakeMode", +] diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..038616017e4fbb3e5b5996e16e34199bfca315b5 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e32be5069551ec5881f86c5b5886337da5770e8 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4bcc1734f75c445c9c867a2c66f53e4df6285f0 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3ea6742717cebd081ecd75bb1ec5d46b8cddf134 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25aec2437065e447e4e26ed09d8a260fba19ddca Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3db173b2508722dbd0889f02ab91b0f91bec4bfc Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62fd0e70bdb90698332b0252e9e0220366ec8bef Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_subclasses/fake_impls.py b/MLPY/Lib/site-packages/torch/_subclasses/fake_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..c7421b485abdbfe8b9ca15e5b7cf16b754b4ce92 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/fake_impls.py @@ -0,0 +1,1061 @@ +# mypy: ignore-errors + +import functools +import itertools +import math +import sys +from typing import Callable, Union + +import torch +import torch._custom_op +import torch._logging + +from torch._ops import OpOverload +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, +) + +from torch._subclasses.fake_tensor import ( + DataDependentOutputException, + DynamicOutputShapeException, + FakeTensor, + in_kernel_invocation_manager, + run_fallback_kernel, + UnsupportedOperatorException, +) +from torch.fx.operator_schemas import normalize_function + +from torch.utils._stats import count_label + +pytree = torch.utils._pytree + +__all__ = [ + "op_implementations_checks", + "get_fast_op_impls", + "stride_incorrect_op", + "has_meta", +] + +op_implementations_dict = {} +op_implementations_checks = [] + + +aten = torch._ops.ops.aten + + +def ordered_set(*items): + return dict.fromkeys(items, True) + + +# This function indicates if the backend device +# supports non-contiguous tensors +def is_noncontiguous_supported(device): + if device.type == "hpu": + return False + return True + + +_like_tensor_constructors = ordered_set( + aten.empty_like.default, + aten.empty_like.out, + aten.full_like.default, + aten.full_like.out, + aten.ones_like.default, + aten.ones_like.out, + aten.rand_like.default, + aten.rand_like.out, + aten.randn_like.default, + aten.randn_like.out, + aten.randint_like.default, + aten.randint_like.out, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.zeros_like.default, + aten.zeros_like.out, + aten.new_empty.default, + aten.new_empty.out, + aten.new_empty_strided.default, + aten.new_empty_strided.out, + aten.new_full.default, + 
aten.new_full.out, + aten.new_zeros.default, + aten.new_zeros.out, + aten.new_ones.default, + aten.new_ones.out, +) + + +_device_not_kwarg_ops = ordered_set( + aten._resize_output_.default, + aten._nested_tensor_from_tensor_list.default, + aten._nested_tensor_from_tensor_list.out, + aten.pin_memory.default, + aten.is_pinned.default, + aten.to.device, + aten.to.prim_Device, + aten._pin_memory.default, + aten._pin_memory.out, + aten._resize_output.default, + aten._resize_output.out, +) + +# this op is never actually used +_non_kwarg_device_constructors = (aten._list_to_tensor,) + + +def contains_tensor_types(type): + tensor_type = torch._C.TensorType.get() + return type.isSubtypeOf(tensor_type) or any( + contains_tensor_types(e) for e in type.containedTypes() + ) + + +@functools.lru_cache(None) +def _is_tensor_constructor(func: OpOverload): + assert isinstance(func, OpOverload) + schema = func._schema + if any(contains_tensor_types(arg.type) for arg in schema.arguments): + return False + # TODO: no real reason to restrict multiple outputs + return ( + len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() + ) + + +def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): + def impl_decorator(op_impl): + if isinstance(run_impl_check, OpOverload): + assert ( + run_impl_check not in op_implementations_dict + ), f"duplicate registration: {run_impl_check}" + op_implementations_dict[run_impl_check] = op_impl + elif isinstance(run_impl_check, (list, tuple)): + for op in run_impl_check: + register_op_impl(op)(op_impl) + else: + assert callable(run_impl_check) + op_implementations_checks.append((run_impl_check, op_impl)) + + return op_impl + + return impl_decorator + + +@register_op_impl(op_implementations_dict.__contains__) +def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): + return op_implementations_dict[func](fake_mode, func, *args, **kwargs) + + +@register_op_impl(_is_tensor_constructor) +@register_op_impl([*_like_tensor_constructors]) +def constructors(fake_mode, func, *args, **kwargs): + assert func not in _non_kwarg_device_constructors + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + if "names" in kwargs: + raise UnsupportedOperatorException( + "torch.compile doesn't support named tensors" + ) + + if func in _like_tensor_constructors: + default_device = new_kwargs["input"].device + # TODO: file issue + args = (new_kwargs.pop("input"),) + else: + # cpu is default device if none is specified + default_device = torch.device("cpu") + args = () + out_device = new_kwargs.pop("device", None) + out_device = out_device if out_device is not None else default_device + new_kwargs["device"] = torch.device("meta") + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? 
hmmm) + with in_kernel_invocation_manager(fake_mode): + r = func(*args, **new_kwargs) + return FakeTensor(fake_mode, r, out_device) + + +@register_op_impl(aten.to.prim_Device) +@register_op_impl(aten.to.device) +def non_kwarg_to(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ) + input_device = new_kwargs["device"] + out_device = input_device if input_device else new_kwargs["input"].device + new_kwargs["device"] = torch.device("meta") + inp = new_kwargs.pop("input") + with in_kernel_invocation_manager(fake_mode): + r = func(inp, **new_kwargs) + # TODO: I think this does the wrong thing if r is inp + return fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, r, out_device + ) + + +def stride_incorrect_op(op): + if op.namespace not in ("aten", "prims"): + return False + if op is aten._fft_c2c.default: + return False + + op_name = op.name() + if "fft" in op_name: + return True + return False + + +# These operators have meta implementations with incorrect strides +@register_op_impl(stride_incorrect_op) +def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs): + # This is a workaround for meta implmentations with incorrect strides + + def is_symbolic(x): + if isinstance(x, FakeTensor): + return x._has_symbolic_sizes_strides + if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return True + return False + + # For static shapes, we can fall back to eager for the real strides + if fake_mode.allow_fallback_kernels: + require_dynamic = any( + is_symbolic(x) for x in itertools.chain(args, kwargs.values()) + ) + if not require_dynamic: + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None) + + raise UnsupportedOperatorException(func) + + +# Dont default to default device handling, +# since the device of `the_template` is ignored +@register_op_impl(aten.resize_as_.default) +def resize_as_(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + return func(*args, **kwargs) + + +@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) +def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): + # TODO: remove me + return constructors(fake_mode, func, *args, **kwargs) + + +# index.Tensor data-dependent in only some conditions +@register_op_impl( + lambda func: torch.Tag.dynamic_output_shape in func.tags + and func + not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor] +) +def dyn_shape(fake_mode, func, *args, **kwargs): + raise DynamicOutputShapeException(func) + + +@register_op_impl(aten.repeat_interleave.Tensor) +def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): + if output_size is None: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + raise DynamicOutputShapeException(func) + + output_size = fake_mode.shape_env.create_unbacked_symint() + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(output_size) + # TODO: consider a memo + return repeats.new_empty(output_size) + + +@register_op_impl(torch.ops.aten._local_scalar_dense.default) +def local_scalar_dense(fake_mode, func, arg): + if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs: + # Without symints/symfloats, cannot handle this + raise 
DataDependentOutputException(func) + if is_float_dtype(arg.dtype): + return fake_mode.shape_env.create_unbacked_symfloat() + elif is_integer_dtype(arg.dtype): + return fake_mode.shape_env.create_unbacked_symint() + elif is_boolean_dtype(arg.dtype): + return fake_mode.shape_env.create_unbacked_symbool() + else: + raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") + + +@register_op_impl(torch.ops.aten.nonzero.default) +def nonzero(fake_mode, func, arg): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + if arg.nonzero_memo is None: + nnz = fake_mode.shape_env.create_unbacked_symint() + + # This is unsound, but it works well in practice + # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# + # TODO: Add a config knob to turn off this unsound behavior + # + # NB: If numel < 2, the bounds here might be COMPLETELY + # disjoint with what can actually occur. But this is fine: + # remember, the hypothesis is that if your later code works + # with N >= 2, it will work with N = 1 and N = 0. + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(arg.numel()): + # Don't upgrade the range if numel is less than two, since we then + # have an empty range which makes things go explodey. We also + # don't allow for 2 because that would specialize the unbacked + # SymInt to 2, which is also likely to be buggy. + if arg.numel() > 2: + maxval = int(arg.numel()) + + _constrain_range_for_size(nnz, max=maxval) + + arg._nonzero_memo = nnz + arg._nonzero_memo_vc = arg._version + + return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64) + + +@register_op_impl(torch.ops.aten.masked_select.default) +def masked_select(fake_mode, func, self, mask): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + nnz = fake_mode.shape_env.create_unbacked_symint() + + # see nonzero for commentary + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(self.numel()): + if self.numel() > 2: + maxval = int(self.numel()) + + _constrain_range_for_size(nnz, max=maxval) + + return self.new_empty((nnz,)) + + +# NB: this must be ordered after local_scalar_dense +@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags) +def data_dep(fake_mode, func, *args, **kwargs): + raise DataDependentOutputException(func) + + +# Bool Indices get Expanded as Masks +# See: IndexingUtils.h:expandTensors +def check_no_bool_index_tensors(func, self, indices): + for index in indices: + if index is not None and index.dtype in (torch.bool, torch.uint8): + raise DynamicOutputShapeException(func) + + +def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + if not is_noncontiguous_supported(out_device): + out = 
out.new_empty(out.shape) + + if out is new_kwargs["input"]: + return out # copy_ + return FakeTensor(fake_mode, out, out_device) + + +_is_builtin_namespaces = ordered_set("aten", "prims", "prim") + + +def is_builtin(op): + return op.namespace in _is_builtin_namespaces + + +def has_meta(func): + return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") + + +@register_op_impl( + lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func) +) +def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): + tensor_lists = [] + for arg in itertools.chain(args, kwargs.values()): + if ( + isinstance(arg, (list, tuple)) + and len(arg) + and isinstance(arg[0], torch.Tensor) + ): + tensor_lists.append(arg) + + try: + with in_kernel_invocation_manager(fake_mode): + out_meta = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + return NotImplemented + + if not out_meta: + return out_meta + + assert tensor_lists + out_fake = [] + + for i, meta_t in enumerate(out_meta): + device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) + out_fake.append( + fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, meta_t, device + ) + ) + + return out_fake + + +# Dont default to default device handling, +# Since op can take in non-zero sized cpu +# index tensors with cuda self +@register_op_impl(aten.index.Tensor) +def index_tensor(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_index_Tensor + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + # ensure nonzero call goes to fake tensor + with fake_mode: + out = meta_index_Tensor(*args, **kwargs) + return out.to(out_device) + + +# Can take mixed meta/non-meta arguments; the meta registration +# will roughly do the right thing even when given real devices +@register_op_impl(aten._embedding_bag.default) +def embedding_bag(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_embedding_bag + + with fake_mode: + return meta_embedding_bag(*args, **kwargs) + + +# takes in multiple-devices, dont default to default device handling +@register_op_impl(aten._unsafe_index_put.default) +@register_op_impl(aten.copy.default) +@register_op_impl(aten.copy_.default) +@register_op_impl(aten.slice_scatter.default) +def multi_device_op_default(fake_mode, func, *args, **kwargs): + return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + + +# same with multi_device_op_default, but return the input +@register_op_impl(aten.copy.out) +@register_op_impl(aten.slice_scatter.out) +def multi_device_op_out(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + return new_kwargs["input"] + + +@register_op_impl(aten.index_put.default) +@register_op_impl(aten.index_put_.default) +def index_put_impl(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + values = new_kwargs["values"] + self_device = new_kwargs["input"].fake_device + torch._check( + self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1), + lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})", + ) + + out = 
run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + if func is aten.index_put_.default: + return new_kwargs["input"] + else: + return out + + +@register_op_impl(aten._nested_tensor_from_tensor_list.default) +@register_op_impl(aten._nested_tensor_from_tensor_list.out) +def nested_tensors_unsupported(fake_mode, func, *args, **kwargs): + raise UnsupportedOperatorException( + "torch.compile does not support strided NestedTensor" + ) + + +@register_op_impl( + [ + x + for x in _device_not_kwarg_ops + if x + not in ( + # these are already registered elsewhere + aten.to.device, + aten.to.prim_Device, + aten._nested_tensor_from_tensor_list.default, + aten._nested_tensor_from_tensor_list.out, + ) + ] +) +def nyi(fake_mode, func, *args, **kwargs): + assert func not in _device_not_kwarg_ops, f"NYI: {func}" + + +@register_op_impl([aten.convolution.default, aten.convolution_backward.default]) +def conv(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + device = kwargs["input"].fake_device + # need to re-enable mode so the tensors report fake device + with fake_mode: + # if the input is unsqueezed is done in Convolution.cpp we get segfault + k = kwargs["weight"].ndim + batch = kwargs["input"].shape[0] + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import has_hint + + if not has_hint(batch): + # TODO: We can make this a little more faithful with best effort + # channels last detection (but only if it's statically obvious!) + mem_fmt = None + elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: + mem_fmt = None + else: + if func is aten.convolution.default: + conv_backend = torch._C._select_conv_backend(**kwargs) + else: + conv_backend = torch._C._select_conv_backend( + kwargs["input"], + kwargs["weight"], + bias=None, + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + transposed=kwargs["transposed"], + output_padding=kwargs["output_padding"], + groups=kwargs["groups"], + bias_sizes=kwargs["bias_sizes"], + ) + mem_fmt = torch._C._conv_determine_backend_memory_format( + kwargs["input"], kwargs["weight"], conv_backend + ) + + def convert(t, mem_fmt): + if t is None: + return t + if mem_fmt is not None: + t = t.to(memory_format=mem_fmt) + return FakeTensor(fake_mode, t, device) + + with in_kernel_invocation_manager(fake_mode): + out = func(**kwargs) + + if func is aten.convolution.default: + return convert(out, mem_fmt) + else: + return ( + convert(out[0], mem_fmt), + convert(out[1], mem_fmt), + convert(out[2], None), + ) + + +@register_op_impl(aten._scaled_dot_product_flash_attention.default) +def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + query = kwargs["query"] + key = kwargs["key"] + return_debug_mask = kwargs["return_debug_mask"] + # unused: value, dropout_p, is_causal, scale + + def convert_tensor(t, device): + return FakeTensor(fake_mode, t, device) + + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + head_dim = query.size(3) + max_seqlen_batch_k = key.size(2) + + query_t = query.transpose(1, 2) + # empty_like already returns a fake tensor so we don't need to convert it + attention = torch.empty_like(query_t).transpose(1, 2) + logsumexp = convert_tensor( + torch.empty( + (batch_size, num_heads, 
max_seqlen_batch_q), + dtype=torch.float, + device="meta", + ), + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = convert_tensor( + torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device="meta", + ), + device=query.device, + ) + else: + debug_mask = convert_tensor( + torch.empty(0, dtype=query.dtype, device="meta"), + query.device, + ) + + # Note [Seed and Offset]: device for seed and offset below depends on whether we are + # capturing or not, but at the time of tracing we don't know if we + # are going to use cudagraphs or not, so we return meta tensors here + # it's possible we'll need to have some special handling in inductor for sdpa + + return ( + attention, + logsumexp, + None, + None, + max_seqlen_batch_q, + max_seqlen_batch_k, + convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), + convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), + debug_mask, + ) + + +@register_op_impl(aten._scaled_dot_product_efficient_attention.default) +def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + compute_log_sumexp = kwargs["compute_log_sumexp"] + # unused: attn_bias, dropout_p, is_causal, scale + + def convert_tensor(t, device): + return FakeTensor(fake_mode, t, device) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = convert_tensor( + torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), + query.device, + ) + + logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 + logsum_exp = convert_tensor( + torch.empty( + (B, num_heads, logsumexp_dim), + dtype=torch.float, + device="meta", + ), + query.device, + ) + + res = res.transpose(1, 2) + + # See Note [Seed and Offset]: + seed = convert_tensor( + torch.empty((), dtype=torch.long, device="meta"), query.device + ) + offset = convert_tensor( + torch.empty((), dtype=torch.long, device="meta"), query.device + ) + + return res, logsum_exp, seed, offset + + +@register_op_impl(aten._flash_attention_forward.default) +def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + query = kwargs["query"] + key = kwargs["key"] + cum_seq_q = kwargs["cum_seq_q"] + cum_seq_k = kwargs["cum_seq_k"] + max_q = kwargs["max_q"] + max_k = kwargs["max_k"] + return_debug_mask = kwargs["return_debug_mask"] + # unused: value, dropout_p, is_causal, scale + + def convert_tensor(t, device): + return FakeTensor(fake_mode, t, device) + + # NB: there are two underlying paths: + # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) + # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total + # includes all batch item sequences. 
cum_seq_q / cum_seq_k contain offsets into total + batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 + max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q + max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k + num_heads = query.size(-2) + head_dim = query.size(-1) + + # Cuda Path + # note: empty_like already returns a fake tensor, we don't need to wrap it + attention = torch.empty_like(query) + logsumexp = convert_tensor( + torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device="meta", + ), + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = convert_tensor( + torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device="meta", + ), + query.device, + ) + else: + debug_mask = convert_tensor( + torch.empty(0, dtype=query.dtype, device="meta"), + query.device, + ) + + # See Note [Seed and Offset]: + return ( + attention, + logsumexp, + convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), + convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), + debug_mask, + ) + + +@register_op_impl(aten._efficient_attention_forward.default) +def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + cu_seqlens_q = kwargs["cu_seqlens_q"] + max_seqlen_q = kwargs["max_seqlen_q"] + max_seqlen_k = kwargs["max_seqlen_k"] + compute_log_sumexp = kwargs["compute_log_sumexp"] + # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k + + def convert_tensor(t, device): + return FakeTensor(fake_mode, t, device) + + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = convert_tensor( + torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), + query.device, + ) + + logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B + actual_max_seqlen_q = M + if cu_seqlens_q is not None: + assert max_seqlen_q is not None + actual_max_seqlen_q = max_seqlen_q + actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N + logsumexp_dim = ( + math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 + ) + logsum_exp = convert_tensor( + torch.empty( + (logsumexp_batch_dim, num_heads, logsumexp_dim), + dtype=torch.float, + device="meta", + ), + query.device, + ) + + # See Note [Seed and Offset]: + seed = convert_tensor( + torch.empty((), dtype=torch.long, device="meta"), query.device + ) + offset = convert_tensor( + torch.empty((), dtype=torch.long, device="meta"), query.device + ) + + return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k + + +FAST_OP_IMPLEMENTATIONS = {} + + +# Unlike register_op_impl, these don't do the slow iteration for +# run_impl_check, and these run BEFORE decompositions +def register_fast_op_impl(func: OpOverload): + def impl_decorator(op_impl): + FAST_OP_IMPLEMENTATIONS[func] = op_impl + return op_impl + + return impl_decorator + + +# infer_size_impl in ExpandUtils +def infer_size(a, b): + from 
torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + # NB: It is very important to test for broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + torch._check( + guard_size_oblivious(sizeA == 1) + or guard_size_oblivious(sizeB == 1) + or sizeA == sizeB, + lambda: f"The size of tensor a ({sizeA}) " + f"must match the size of tensor b ({sizeB}) " + f"at non-singleton dimension {i})", + ) + expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA + return tuple(expandedSizes) + + +def make_fast_binary_impl(slow_ref): + def fast_binary_impl(mode, *args, **kwargs): + def slow(msg): + count_label(f"slow {msg}") + with mode: + return slow_ref(*args, **kwargs) + + count_label("attempt fast") + + # Fast path (based off of TensorIterator fast path). + # Unfortunately, there is no way to easily deduplicate + # this with either the TensorIterator C++ implementation + # (which we don't want to SymIntify, and also the algorithm + # here is slightly different from TensorIterator to allow + # for broadcasting), nor the PrimTorch implementation + # (which does not actually implement a fast path.) 
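The infer_size helper above implements ordinary broadcasting of two shapes; the public torch.broadcast_shapes follows the same rule. A quick illustrative check with made-up shapes:

import torch

# trailing dimensions are aligned; a size of 1 broadcasts against any size
assert torch.broadcast_shapes((2, 1, 4), (3, 1)) == torch.Size([2, 3, 4])
assert torch.broadcast_shapes((5,), ()) == torch.Size([5])

# mismatched non-singleton sizes are rejected, as in the torch._check above
try:
    torch.broadcast_shapes((2, 3), (4,))
    raise AssertionError("expected a broadcasting error")
except RuntimeError:
    pass  # sizes 3 and 4 cannot be broadcast together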
+ + operands = args + + # compute_shape + has_scalars = False + has_tensors = False + final_shape = None + for op in operands: + shape = op.shape if isinstance(op, torch.Tensor) else () + if len(shape) == 0: + has_scalars = True + else: + has_tensors = True + if final_shape is None: + final_shape = shape + # TODO: Minor optimization: track if the shapes + # were equal so you can skip the equality check + # below if unnecessary + final_shape = infer_size(final_shape, shape) + assert final_shape is not None + + # Do some extra safety checks to see if the output + # stride is obvious + for op in operands: + if ( + isinstance(op, torch.Tensor) + and len(op.shape) == len(final_shape) + and op.shape == final_shape + ): + break + else: + return slow("both tensors nontrivially broadcast") + + # compute_types + cpu = torch.device("cpu") + common_device = cpu + common_dtype = None + output_dtype = None + has_different_input_dtypes = False + for op in operands: + if not isinstance(op, torch.Tensor): + # Use elementwise_dtypes for the tricky case + has_different_input_dtypes = True + continue + if common_device == cpu and not op.device.type == "cpu": + common_device = op.device + # Slightly simplified here as target_dtype cannot vary + if common_dtype is None: + common_dtype = op.dtype + elif common_dtype != op.dtype: + has_different_input_dtypes = True + + if has_different_input_dtypes: + # compute promotion + # TODO: we don't need the compute type + _, common_dtype = elementwise_dtypes( + *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + # check all tensors on same device + # cpu scalars are assumed allow + current_cpu_scalars_on_non_cpu = 0 + max_cpu_scalars_on_non_cpu = 1 # hard coded atm + for op in operands: + if not isinstance(op, torch.Tensor): + continue + if common_device != cpu and op.dim() == 0 and op.device == cpu: + if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu: + return slow("error") + current_cpu_scalars_on_non_cpu += 1 + elif op.device != common_device: + return slow("error") + + # compute_fast_setup_type + is_contiguous = True + is_channels_last = True + # TODO: is_non-overlapping_and_dense (not bound from Python + # no inplace, no out, everything defined + + if is_noncontiguous_supported(common_device): + for op in operands: + if not isinstance(op, torch.Tensor): + continue + is_contiguous = is_contiguous and op.is_contiguous( + memory_format=torch.contiguous_format + ) + is_channels_last = is_channels_last and op.is_contiguous( + memory_format=torch.channels_last + ) + if is_contiguous: + # do contiguous + count_label("fast is_contiguous") + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.contiguous_format, + ), + device=common_device, + ) + if is_channels_last: + count_label("fast channels_last") + # do channels last + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.channels_last, + ), + device=common_device, + ) + + return slow("no contiguity match") + + return fast_binary_impl + + +@functools.lru_cache(None) +def get_fast_op_impls(): + import torch._refs + + register_fast_op_impl(torch.ops.aten.add.Tensor)( + make_fast_binary_impl(torch._refs.add) + ) + register_fast_op_impl(torch.ops.aten.sub.Tensor)( + make_fast_binary_impl(torch._refs.sub) + ) + register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] + 
register_fast_op_impl(torch.ops.aten.div.Tensor)( + make_fast_binary_impl(torch._refs.div) + ) + return FAST_OP_IMPLEMENTATIONS diff --git a/MLPY/Lib/site-packages/torch/_subclasses/fake_tensor.py b/MLPY/Lib/site-packages/torch/_subclasses/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b9cacbdebf4912c8d0d2bba9575f0933666139 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/fake_tensor.py @@ -0,0 +1,1819 @@ +# mypy: ignore-errors + +import contextlib +import functools +import logging +import os +import traceback +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, TypeVar +from weakref import ReferenceType + +import torch +import torch._custom_op +import torch._logging +from torch._C._functorch import is_functorch_wrapped_tensor + +from torch._guards import Source +from torch._ops import OpOverload +from torch._prims_common import suggest_memory_format +from torch._subclasses.meta_utils import ( + assert_eq, + assert_metadata_eq, + is_sparse_any, + is_sparse_compressed, + MetaConverter, +) +from torch._utils import render_call +from torch.fx.operator_schemas import normalize_function +from torch.multiprocessing.reductions import StorageWeakRef +from torch.overrides import TorchFunctionMode +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) + +from torch.utils._pytree import PyTree, tree_map +from torch.utils._stats import count +from torch.utils.weak import WeakIdRef + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +DimList = List + +log = logging.getLogger(__name__) + +# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186 +# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105 +try: + not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") +except ValueError as e: + if "'not_implemented' not registered" in str(e): + import logging as not_implemented_log + else: + raise e + +pytree = torch.utils._pytree +T = TypeVar("T") +TensorWeakRef = Any + +aten = torch._ops.ops.aten + +CONSTANT_NUMEL_LIMIT = 1 + +RECURSION_COUNT = 0 + + +# Small helper that increments recursion count, and +# resets it when the object goes out of scope. Useful +# if you don't want to increase indentation which is +# what a context manager would do. 
+class IncrementRecursionCount: + def __init__(self): + global RECURSION_COUNT + RECURSION_COUNT += 1 + + def __del__(self): + global RECURSION_COUNT + RECURSION_COUNT -= 1 + + +@dataclass +class UnsupportedFakeTensorException(RuntimeError): + reason: str + + +@dataclass +class DynamicOutputShapeException(RuntimeError): + func: OpOverload + + +@dataclass +class DataDependentOutputException(RuntimeError): + func: OpOverload + + +@dataclass +class UnsupportedOperatorException(RuntimeError): + func: OpOverload + + +def ordered_set(*items): + return dict.fromkeys(items, True) + + +@contextlib.contextmanager +def unset_fake_temporarily(): + old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + try: + yield old + finally: + if old is not None: + torch._C._set_dispatch_mode(old) + + +def is_fake(x): + if isinstance(x, FakeTensor): + return True + if is_traceable_wrapper_subclass(x): + attrs, _ = type(x).__tensor_flatten__(x) + flattened_tensors = [getattr(x, attr) for attr in attrs] + # need to recurse because we could have nested subclasses + all_fake = all(is_fake(x) for x in flattened_tensors) + any_fake = any(is_fake(x) for x in flattened_tensors) + assert all_fake == any_fake, "got mixed fake and real tensors!" + return all_fake + elif isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views) + return is_fake(unwrapped) + elif isinstance(x, torch.Tensor) and is_functorch_wrapped_tensor(x): + unwrapped = torch._C._functorch.get_unwrapped(x) + return is_fake(unwrapped) + return False + + +def maybe_get_fake_mode(t): + if isinstance(t, FakeTensor): + return t.fake_mode + if is_traceable_wrapper_subclass(t): + inner_tensor_names, _ = t.__tensor_flatten__() + modes = [ + maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names + ] + m = modes[0] + assert all(m is x for x in modes) + return m + elif isinstance(t, torch.Tensor) and torch._is_functional_tensor(t): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views) + return maybe_get_fake_mode(unwrapped) + elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t): + unwrapped = torch._C._functorch.get_unwrapped(t) + return maybe_get_fake_mode(unwrapped) + return None + + +@functools.lru_cache(None) +def get_schema_info(func): + return torch._C._SchemaInfo(func._schema) # type: ignore[attr-defined] + + +# many of the decompositions registered to torch/_prims do not at the moment model +# aliasing or strides, so as an incremental step, just enable the decompositions in +# torch/_decomp/decompositions.py. 
+# decomps are used for aot autograd tracing so we would like to unify on their +# implementation and add additional testing to them +@functools.lru_cache(None) +def torch_decomp_decompositions(func): + from torch._decomp import decomposition_table + + decompositions = torch._decomp.decompositions + # Note that the function in the decomposition table might be + # different from the one in the module because of the difference + # in out handling in aten API and torch public API + return decomposition_table[func].__module__.startswith( + "torch._decomp" + ) and decomposition_table[func].__name__ in dir(decompositions) + + +def tree_flatten_only(ty: Type[T], tree: PyTree): + flat_vals = pytree.tree_leaves(tree) + return [elem for elem in flat_vals if isinstance(elem, ty)] + + +# Similar to `MetaConverter`, this is a class for converting +# multiple tensors into fake tensors which share the same view/storage +# structure. Like `MetaConverter`, it uses `WeakIdRef` to +# hold a weak reference for all memoized tensors. +class FakeTensorConverter: + @property + def tensor_memo(self): + return self.meta_converter.tensor_memo + + meta_converter: MetaConverter + constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]] + + def __init__(self): + self.meta_converter = MetaConverter() + + # map from to storage to corresponding constant tensors + self.constant_storage_mapping = {} + + def add_constant_storage_mapping(self, fake_tensor): + # when you have a constant, aliased tensor: + # const_tensor.add_(torch.rand([1])) + # all aliases of it must become no longer const + assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None + weak_st = StorageWeakRef(fake_tensor.constant._typed_storage()) + + # we need a map from a weak storage to all of its corresponding + # constant tensors. 
python doesn't have the weak value equivalent + # of defaultdict(list), so we are using a WeakValueDictionary as one + if weak_st not in self.constant_storage_mapping: + self.constant_storage_mapping[weak_st] = [] + self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor)) + + def invalidate_constant_aliases(self, tensor): + assert not isinstance(tensor, FakeTensor) + + weak_st = StorageWeakRef(tensor._typed_storage()) + if weak_st not in self.constant_storage_mapping: + return + + for weak_tensor_ref in self.constant_storage_mapping[weak_st]: + ten = weak_tensor_ref() + if ten is not None: + ten._fix_weakref() + ten.constant = None + + del self.constant_storage_mapping[weak_st] + + def _get_memo(self, t): + if WeakIdRef(t) in self.tensor_memo: + out = self.tensor_memo[WeakIdRef(t)] + out._fix_weakref() + return out + return None + + def set_tensor_memo(self, t, v): + th = WeakIdRef(t) + + # hold a weak ref to self, otherwise it will be kept alive + # by the del_ten closure + self_weak_ref = weakref.ref(self) + + def del_ten(): + self_ref = self_weak_ref() + if self_ref is None: + return + # on shutdown, th may not be in memo + self_ref.tensor_memo.pop(th, None) + + weakref.finalize(t, del_ten) + self.tensor_memo[th] = v + + def from_real_tensor( + self, + fake_mode, + t, + make_constant=False, + shape_env=None, + *, + source=None, + symbolic_context=None, + memoized_only=False, + ): + # see note [Tensor Fakification and Symbol Caching] + if not symbolic_context and not source and shape_env: + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + if memoized_only: + return None + existing_device = t.device + # not yet supported in metatensors + if t.is_quantized: + raise UnsupportedFakeTensorException("quantized nyi in meta tensors") + if type(t) is torch.nn.Parameter: + assert not make_constant + + def mk_fake_tensor(make_meta_t): + # NB: don't use in_kernel_invocation_manager. to + # ensure FakeTensor can internally do constant computation + # as necessary. Invocation manager is "more correct" as + # it works for more operators in make_meta_t, but + # invariant is that make_meta_t only calls factories + # for which it is not strictly necessary to use the + # invocation manager (I think!) + with no_dispatch(): + return FakeTensor( + fake_mode, + make_meta_t(), + existing_device, + constant=t if make_constant else None, + ) + + out = self.meta_converter( + t, + shape_env=shape_env, + callback=mk_fake_tensor, + source=source, + symbolic_context=symbolic_context, + ) + if out is NotImplemented: + raise UnsupportedFakeTensorException("meta converter nyi") + if make_constant: + self.add_constant_storage_mapping(out) + # NB: meta_converter set the memo + return out + + # If you specify the device, it MUST be a meta tensor. + def from_meta_and_device(self, fake_mode, t, device): + assert ( + t.device.type == "meta" + ), f"tensor's device must be `meta`, got {t.device.type} instead" + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + out = FakeTensor(fake_mode, t, device) + self.set_tensor_memo(t, out) + return out + + # You can have a real tensor that you need to convert into a fake tensor. + # If you have a meta tensor already, call from_meta_and_device. 
+ # + # You're allowed to pass a meta tensor to be turned into a fake + # tensor; although an odd thing to do, this can occur if you're doing + # cross ref testing and the inner test is already operating on meta tensors. + def __call__( + self, + fake_mode, + t, + *, + make_constant=False, + shape_env=None, + source=None, + symbolic_context=None, + memoized_only=False, + ): + return self.from_real_tensor( + fake_mode, + t, + make_constant, + shape_env=shape_env, + source=source, + symbolic_context=symbolic_context, + memoized_only=memoized_only, + ) + + +@functools.lru_cache(None) +def init_cuda_context(): + # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first + if torch.cuda.is_available(): + torch.empty(1, device="cuda") if torch.version.hip is None else torch.zeros( + 1, device="cuda" + ) + + +@contextlib.contextmanager +def in_kernel_invocation_manager(fake_mode): + # See: note [Fake Tensor Dispatch Keys] + prev_in_kernel = fake_mode.in_kernel_invocation + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}" + + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + fake_mode.in_kernel_invocation = True + torch._C._set_meta_in_tls_dispatch_include(True) + try: + yield + finally: + fake_mode.in_kernel_invocation = prev_in_kernel + torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel) + del guard + + +# Return if the function allows Python numbers to bind to Tensors +def should_allow_numbers_as_tensors(func: OpOverload): + return torch._C._should_allow_numbers_as_tensors( + func.name().split("::")[-1].split(".")[0] + ) + + +class FakeTensorConfig: + debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1" + + +class FakeTensor(torch.Tensor): + """ + Meta tensors give you the ability to run PyTorch code without having to + actually do computation through tensors allocated on a `meta` device. + Because the device is `meta`, meta tensors do not model device propagation. + FakeTensor extends MetaTensors to also carry an additional `fake_device` + which tracks devices that would have been used. + """ + + fake_device: torch.device + fake_mode: "FakeTensorMode" + constant: Optional[torch.Tensor] + + # This memorizes the unbacked SymInt representing the number of nonzero + # elements in this tensor. This is helpful if you do something like + # x[mask] and y[mask]; mask.nonzero() gets repeatedly called and should + # give a consistent unbacked SymInt. It needs to be invalidated in the + # same way constant is. + # TODO: Generalize this as needed, e.g., into a trie of memos + _nonzero_memo: Optional[torch.SymInt] + _nonzero_memo_vc: Optional[int] + + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FAKE + + @property + def nonzero_memo(self): + if self._nonzero_memo is None: + return None + # Version counter based tracking isn't 100% sound but it's close + # enough + if self._nonzero_memo_vc != self._version: + self._nonzero_memo = None + return None + return self._nonzero_memo + + @property + def device(self): + if self.fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return self.fake_device + + # Note: [Fake Tensor Dispatch Keys] + # In order to model the behavior of device-specific autocast + # and autograd logic, we update the dispatch keys of FakeTensors + # to reflect their fake device. 
This includes the BackendComponent + # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent + # related Autocast and Autograd keys. __torch__dispatch__ sits below + # Autocast and Autograd, and is only invoked when we are at the + # kernel for the BackendComponent. Then, we add Meta to the + # thread-local dispatch include set to hit the meta kernel + # instead of the kernel of the BackendComponent for the fake device. + # The `device_for_backend_keys` does that below + # NOTE: this probably will not do the right thing for backends + # that have dispatch keys which are higher than the "meta" key: + # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189 + + # We don't support named tensors; graph break + @property + def names(self): + raise UnsupportedFakeTensorException( + "torch.compile doesn't support named tensors" + ) + + @staticmethod + def __new__(cls, fake_mode, elem, device, constant=None): + self = torch.Tensor._make_subclass( + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, + ) + + assert elem.device.type == "meta", elem.device.type + device = device if isinstance(device, torch.device) else torch.device(device) + # NB: it is fine, if a little confusing, for device to be meta + # (we are faking a meta tensor in that case). However, it often + # indicates some sort of confusion (e.g., you accidentally passed + # in a meta tensor when you should have passed in the real tensor). + # So by default we disallow meta, and if you are working in a situation + # where it is helpful (e.g., crossref testing) you can turn it back + # on + if not fake_mode.allow_meta: + assert device.type != "meta" + # normalize device. + if device.type == "cuda": + init_cuda_context() + + if ( + device.type + in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()] + and device.index is None + ): + device = torch.device( + f"{device.type}:{getattr(torch, device.type).current_device()}" + ) + self.fake_device = device # type: ignore[attr-defined] + self.fake_mode = fake_mode # type: ignore[attr-defined] + self.constant = constant # type: ignore[attr-defined] + self._nonzero_memo = None # type: ignore[attr-defined] + self._nonzero_memo_vc = None # type: ignore[attr-defined] + + if FakeTensorConfig.debug: + import traceback + + self._debug_trace = traceback.extract_stack() # type: ignore[attr-defined] + return self + + # In some circumstances, a conventional torch.Tensor constructor + # will get rewritten to call into FakeTensor. We must provide an + # __init__ method that can accept the Python interpreters initialization + # in such a situation; we must also be able to handle direct fake + # tensor construction via FakeTensor(). + # + # In particular, the __init__ call will look funny in the following case: + # + # with FakeTensorMode(): + # x = torch.Tensor([1, 2, 3]) + # + # this desugars into: + # + # with FakeTensorMode(): + # x = torch.Tensor.__new__([1, 2, 3]) + # # NB: x is a fake tensor, because of the mode! + # x.__init__([1, 2, 3]) # not the normal fake tensor args! 
+ # + def __init__(self, *args, **kwargs): + super().__init__() + + @staticmethod + def from_tensor(t, fake_mode): + return fake_mode.from_tensor(t) + + @classmethod + @count + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # need to handle here to avoid infinite recursion + # see [in_kernel_invocation] + if func == torch.ops.prim.device.default: + assert len(args) == 1 and isinstance(args[0], FakeTensor) + if args[0].fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return args[0].fake_device + + # Because fake mode can return NotImplemented (if it sees a subclass + # it doesn't know how to deal with), this test here is important + # because the next dispatch after a fake mode will attempt to use + # subclasses of tensors to dispatch, and any FakeTensor arguments + # will be considered eligible. + unrecognized_types = [ + t for t in types if not issubclass(t, FakeTensor) and t is not torch.Tensor + ] + if unrecognized_types: + not_implemented_log.debug( + "FakeTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + fake_mode = None + for arg in pytree.arg_tree_leaves(*args, **kwargs): + if isinstance(arg, FakeTensor): + fake_mode = arg.fake_mode + break + + assert fake_mode is not None + + # If the fake mode is already active, don't try to reapply it! + # NotImplemented is the right thing to return here, because the + # typical situation this can occur is if ProxyTensorMode returned a + # NotImplemented because of a not implemented subclass; we may have + # unluckily attempted to hit FakeTensor's dispatch first, + # NotImplemented lets us keep chaining until we find the actual + # subclass + maybe_cur_fake_mode = torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FAKE + ) + if maybe_cur_fake_mode: + not_implemented_log.debug( + "FakeTensor mode already active: %s in %s", + fake_mode, + maybe_cur_fake_mode, + ) + return NotImplemented + + with fake_mode: # type: ignore[attr-defined] + return func(*args, **kwargs) + + @staticmethod + def _find_common_device(func, flat_args) -> Tuple[torch.device, bool]: + # Returns: (common_device, has_scalar_only_inputs) + + # cpu - zero-dim tensors can be called in cuda kernels, + # so overwrite the common_device if it the only existing + # device comes from a cpu zero-dim tensor + common_device = None + has_scalar_only_inputs = False + is_cpu_zero_dim = None + + def cpu_zero_dim(t): + return t.device.type == "cpu" and t.dim() == 0 + + def merge_devices(t): + nonlocal common_device + nonlocal is_cpu_zero_dim + if not isinstance(t, FakeTensor): + return + + if common_device is None: + common_device = t.device + is_cpu_zero_dim = cpu_zero_dim(t) + return + + t_is_cpu_zero_dim = cpu_zero_dim(t) + if t.device == common_device: + if is_cpu_zero_dim: + is_cpu_zero_dim = t_is_cpu_zero_dim + return + + # mismatching devices ! + # if current tensor is cpu 0 dim, defer to existing device + if t_is_cpu_zero_dim: + return + + # current device is from cpu 0 dim tensor, overwrite + if is_cpu_zero_dim: + common_device = t.device + is_cpu_zero_dim = t_is_cpu_zero_dim + return + + # mismatching devices of non-zero dim tensors, throw + # This might be valid behavior and need to be explicitly modeled, e.g. 
reshape_as + raise RuntimeError( + f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}" + ) + + for arg in flat_args: + merge_devices(arg) + + # some functions that allow Python numbers to bind to Tensors + # if we have failed to find a device, and we're running one of these operators, + # we must have scalar only inputs + if should_allow_numbers_as_tensors(func) and common_device is None: + # ops with scalar only inputs always have result on cpu + has_scalar_only_inputs = True + common_device = torch.device("cpu") + + assert common_device is not None, f"Could not find common device for {func}" + + return common_device, has_scalar_only_inputs + + # We must handle tolist in a special way for FakeTensors here in the case + # where tolist is called from torch dispatch for tensor subclasses. + # Ordinarily, if a program calls .tolist compiling still works because there is + # special handling in dynamo, but for tensor subclasses if .tolist is called + # inside torch dispatch, the .tolist call may be directly on a FakeTensor. + # This would result in an error since wrapper subclasses don't have storage. + # To avoid this, we handle the FakeTensor case by (1) specializing on the size + # of the tensor to create the output Python list, and (2) creating unbacked + # symints for each element of the list. + def tolist(self): + assert self.dim() == 1, "NYI for higher dims" + shape_env = self.fake_mode.shape_env + out = [] + # Specialize on the length of the list + for _ in range(self.shape[0]): + s = shape_env.create_unbacked_symint() + # max value? + torch._constrain_as_size(s, min=2) + out.append(s) + return out + + +@dataclass(frozen=True) +class TensorMetadata: + """ + The Tensor metadata relevant to hashing FakeTensors when caching. + """ + + dtype: torch.dtype + shape: torch.Size + stride: Tuple[Any, ...] + device: torch.device + layout: torch.layout + memory_format: Optional[torch.memory_format] + storage_offset: int + requires_grad: bool + is_quantized: bool + is_conj: bool + is_neg: bool + is_inference: bool + is_sparse: bool # read: is sparse COO + is_coalesced: Optional[bool] + dense_dim: Optional[int] + sparse_dim: Optional[int] + + +def extract_tensor_metadata(t: torch.Tensor) -> "TensorMetadata": + """ + Extract the TensorMetadata of a tensor. + """ + memory_format = suggest_memory_format(t) + if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format): + memory_format = None + + return TensorMetadata( + dtype=t.dtype, + shape=t.shape, + stride=t.stride() if t.layout == torch.strided else (), + device=t.device, + layout=t.layout, + memory_format=memory_format, + storage_offset=t.storage_offset(), + requires_grad=t.requires_grad, + is_quantized=t.is_quantized, + is_conj=t.is_conj(), + is_neg=t.is_neg(), + is_inference=t.is_inference(), + is_sparse=t.is_sparse, + is_coalesced=t.is_coalesced() if t.is_sparse else None, + dense_dim=t.dense_dim() if t.is_sparse else None, + sparse_dim=t.sparse_dim() if t.is_sparse else None, + ) + + +@dataclass(frozen=True) +class _ShapeEnvSettings: + """ + Encapsulates all shape env settings that could potentially affect + FakeTensor dispatch. Used when creating dispatch cache keys. 
+ """ + + allow_scalar_outputs: bool + allow_dynamic_output_shape_ops: bool + assume_static_by_default: bool + specialize_zero_one: bool + duck_shape: bool + + def __init__(self, env: "ShapeEnv"): + # Initialize this way because the class is frozen (to enable hashing): + object.__setattr__(self, "allow_scalar_outputs", env.allow_scalar_outputs) + object.__setattr__( + self, "allow_dynamic_output_shape_ops", env.allow_dynamic_output_shape_ops + ) + object.__setattr__( + self, "assume_static_by_default", env.assume_static_by_default + ) + object.__setattr__(self, "specialize_zero_one", env.specialize_zero_one) + object.__setattr__(self, "duck_shape", env.duck_shape) + + +class _DispatchCacheKey(list): + """ + Key for the FakeTensor dispatch cache. Inspired by (copied from) + _HashedSeq from the functools.lru_cache implementation. + """ + + __slots__ = "hashvalue" # noqa: PLC0205 + + def __init__(self, tup, hash=hash): + self[:] = tup + self.hashvalue = hash(tup) + + def __hash__(self): + return self.hashvalue + + +@dataclass(frozen=True) +class _DispatchCacheEntry: + """ + Entry type for the FakeTensor dispatch cache. Accounts for two possibilities: + 1) The op is inplace, and a hit means we need to alias the argument at a given + index. 2) We need to synthesize a new FakeTensor given tensor metadata. For view + ops, we further capture the index of the arg to alias. + """ + + inplace_idx: Optional[int] = None + metadata: Optional[TensorMetadata] = None + view_idx: Optional[int] = None + + +@dataclass(frozen=True) +class _BypassDispatchCache(Exception): + """ + Signals cases that should skip FakeTensor caching. + """ + + reason: str + + +@dataclass(frozen=True) +class DispatchCacheInfo: + """ + Information about the state of the FakeTensor dispatch cache. + """ + + hits: int + misses: int + bypasses: Dict[str, int] + size: int + + +# We keep one instantiation of `fake_tensor_converter` active +# for the duration of `with FakeTensorMode()`. +# This allows accurate storage aliasing across invocation of +# different operators. While this will keep all freshly allocated +# tensors alive during `FakeTensorMode`, there will no be no +# new allocations of Tensors which have non-meta storage so +# memory should not significantly increase. 
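+#
+# A minimal usage sketch (illustrative only):
+#
+#     with FakeTensorMode() as mode:
+#         x = torch.empty(2, 3)      # a FakeTensor backed by a meta tensor
+#         y = x + 1                  # dispatches to meta kernels; no real compute
+#     assert isinstance(y, FakeTensor) and tuple(y.shape) == (2, 3)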
+ + +class FakeTensorMode(TorchDispatchMode): + cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {} + cache_hits: int = 0 + cache_misses: int = 0 + cache_bypasses = defaultdict(int) + + def __init__( + self, + *, + allow_fallback_kernels=True, + allow_non_fake_inputs=False, + shape_env=None, + static_shapes=None, + ): + log.debug("create_mode 0x%x", id(self)) + self.allow_fallback_kernels = allow_fallback_kernels + self.fake_tensor_converter = FakeTensorConverter() + if static_shapes is not None: + self.static_shapes = static_shapes + else: + self.static_shapes = shape_env is None + + import torch._dynamo.config + import torch._functorch.config + + self.allow_meta = torch._functorch.config.fake_tensor_allow_meta + self.cache_enabled = torch._dynamo.config.fake_tensor_cache_enabled + self.cache_crosscheck_enabled = ( + torch._dynamo.config.fake_tensor_cache_crosscheck_enabled + ) + + # A flag that controls, whether we want to invoke ops on mix of + # real weights/global variables and fake inputs + self.allow_non_fake_inputs = allow_non_fake_inputs + + # [in_kernel_invocation] + # when FakeTensor is invoked in user code, .device should return + # the fake_device of the tensor so that code such as as `if x.is_cuda` + # or torch.zeros([10, 10], device=x.device) continues to execute as if + # the FakeTensor were real. However, within kernel execution, we return + # the `Meta` device because all computation within the kernels should + # behave as if the Tensors are on meta devices. Kernels should allocate + # new tensors on meta devices, and checks like `is_meta` should return true. + # within python refs, we always return the real device by defining + # the device property + self.in_kernel_invocation = False + + # True if we enter'ed and actually enabled fake tensor mode, + # false if it was a no-op. Not thread safe but neither is + # in_kernel_invocation + # If another fake mode was already active when we enter, we also stash it here. + # That way when we exit, we know to re-enable the previous fake mode. + self.enter_stack: List[Tuple[bool, Optional[FakeTensorMode]]] = [] + + self.shape_env = shape_env + + self.stack = "".join(traceback.format_stack()) + + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.FAKE + + # Typically, there is only one fake tensor mode and you test for it by + # doing an isinstance test. However, in some situations, there might be + # TWO fake tensor modes. The canonical example of this is exporting + # a fake model: there is an outer fake mode created by the user, and + # an inner fake mode created by Dynamo. The two phase process is required + # because the outer fake mode typically won't have a ShapeEnv, even if + # the user is interested in exporting with dynamic shapes (so the inner + # fake mode will actually have a ShapeEnv and swap in symbolic sizes.) + # + # In this case, it's insufficient to test only one FakeTensor: you need + # to distinguish between our fake tensor and other fake tensors. That's + # what this function does. + def is_our_fake(self, t): + return isinstance(t, FakeTensor) and t.fake_mode is self + + @count + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + # FakeTensorMode should not be set when we're inside of it. 
+ assert ( + torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None + ), func + try: + return self.dispatch(func, types, args, kwargs) + except TypeError: + log.exception("fake tensor raised TypeError") + raise + + # No-op if FakeTensorMode is already in use + def __enter__(self): + maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key) + if self is not maybe_prev_fake_mode: + self.enter_stack.append((True, maybe_prev_fake_mode)) + return super().__enter__() + else: + # no-op (still need to re-set the fake mode though since we unset it) + torch._C._set_dispatch_mode(self) + self.enter_stack.append((False, None)) + return self + + def __exit__(self, a, b, c): + live, maybe_prev_fake_mode = self.enter_stack.pop() + if live: + out = super().__exit__(a, b, c) + # Re-enable the previous fake mode, if there was one. + if maybe_prev_fake_mode is not None: + torch._C._set_dispatch_mode(maybe_prev_fake_mode) + + @classmethod + def cache_info(cls) -> DispatchCacheInfo: + """ + Query the state of the dispatch cache. + """ + return DispatchCacheInfo( + FakeTensorMode.cache_hits, + FakeTensorMode.cache_misses, + dict(FakeTensorMode.cache_bypasses), + len(FakeTensorMode.cache), + ) + + @classmethod + def cache_clear(cls): + """ + Clear the dispatch cache. + """ + cls.cache_hits = 0 + cls.cache_misses = 0 + cls.cache_bypasses.clear() + cls.cache.clear() + + def _cached_dispatch_impl( + self, + func: OpOverload, + types: Tuple[Any, ...], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ): + """ + Lookup a cache entry for the given arguments. If none exists, dispatch + and cache the result (if the result is eligible for caching). + """ + output = unassigned = object() + try: + key = self._cache_key(func, args, kwargs) + entry = FakeTensorMode.cache.get(key, None) + if entry is not None: + output = self._output_from_cache_entry(entry, func, args) + FakeTensorMode.cache_hits += 1 + if self.cache_crosscheck_enabled: + # For debugging / testing: Validate that the output synthesized + # from the cache matches the output created by normal dispatch. + self._crosscheck_cache_output(output, func, types, args, kwargs) + else: + output = self._dispatch_impl(func, types, args, kwargs) + entry = self._make_cache_entry(key, func, args, kwargs, output) + FakeTensorMode.cache[key] = entry + FakeTensorMode.cache_misses += 1 + except _BypassDispatchCache as e: + FakeTensorMode.cache_bypasses[e.reason] += 1 + + if output is unassigned: + output = self._dispatch_impl(func, types, args, kwargs) + + return output + + def _cache_key( + self, + func: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> _DispatchCacheKey: + """ + Create a cache key given the dispatch args. Raises _BypassDispatchCache + for any situation that precludes caching. + """ + # Avoid caching for any ops that would require a more sophisticated + # caching implementation, e.g., data dependent ops or ops that modify + # the inputs. 
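+        # (For instance, ops tagged data_dependent_output such as
+        # aten._local_scalar_dense, or dynamic_output_shape ops such as
+        # aten.nonzero, cannot be reconstructed from cached metadata alone.)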
+ if torch.Tag.data_dependent_output in func.tags: + raise _BypassDispatchCache("data dependent output") + + if torch.Tag.dynamic_output_shape in func.tags: + raise _BypassDispatchCache("dynamic output shape") + + if torch.Tag.inplace_view in func.tags: + raise _BypassDispatchCache("inplace view") + + if func == aten._unsafe_view.default: + raise _BypassDispatchCache("unsafe view") + + if func in self.lift_fns: + raise _BypassDispatchCache("lift") + + if not torch._library.utils.is_builtin(func): + raise _BypassDispatchCache("non-builtin") + + # In order to handle storage aliasing, we need to establish the alias + # for any view op on a cache hit. But CompositeImplicitAutograd ops may + # or may not alias the input, so just punt on caching these. + if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + raise _BypassDispatchCache("CompositeImplicitAutograd") + + key_values = ( + func, + # Translate any FakeTensor args to metadata. + self._prep_args_for_hash(args) if args else (), + self._prep_args_for_hash(kwargs) if kwargs else (), + # Capture the default_dtype mode since that can affect the output tensor, + # e.g., when operating on constant float values. + torch.get_default_dtype(), + # Capture the current device to support, e.g., cache tensor creation, + # where there isn't necessarily a tensor to take the device from. + torch._C._get_default_device(), + # We want to create tensors from cached metadata only when the inference + # mode is the same. + torch.is_inference_mode_enabled(), + # Shape env settings could affect behavior. One example seen in the wild: + # Disasllowing dynamic shapes can introduce a DynamicOutputShapeException + # where it wasn't seen on a previous instance of the same op. + _ShapeEnvSettings(self.shape_env) if self.shape_env else None, + ) + return _DispatchCacheKey(key_values) + + def _prep_args_for_hash(self, args: Any) -> Any: + """ + Translate the provided args into a form suitable for caching at FakeTensor + dispatch, i.e., convert unhashable types like lists & dicts into tuples and + convert FakeTensors into metadata. Raises _BypassDispatchCache to signal + unsupported cases that should bypass caching. + """ + if isinstance(args, dict): + args = list(args.keys()) + list(args.values()) + + result = [] + for arg in args: + if isinstance(arg, FakeTensor): + if not self.is_our_fake(arg): + raise _BypassDispatchCache("not our fake") + if arg._has_symbolic_sizes_strides: + raise _BypassDispatchCache("symbolic shape") + if arg.constant is not None: + raise _BypassDispatchCache("constant attribute") + if arg.is_sparse: + raise _BypassDispatchCache("sparse tensor") + if is_sparse_compressed(arg): + raise _BypassDispatchCache("sparse compressed tensor") + result.append(extract_tensor_metadata(arg)) + elif isinstance(arg, torch.Tensor): + raise _BypassDispatchCache("non-fake tensor") + elif isinstance(arg, (torch.SymBool, torch.SymInt, torch.SymFloat)): + raise _BypassDispatchCache("symbolic shape") + elif isinstance(arg, (list, tuple, dict)): + result.extend(self._prep_args_for_hash(arg)) + else: + # It's important to capture the type of the arg since, e.g., 1 and 1.0 + # hash to the same value, but can produce different dtypes for the + # output tensor. 
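+                # (e.g., adding the Python int 1 to an integer tensor keeps an
+                # integral dtype, while adding the float 1.0 promotes the
+                # result to a floating-point dtype)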
+ result.append((type(arg), arg)) + + return tuple(result) + + def _make_cache_entry( + self, + key: _DispatchCacheKey, + func: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + output: FakeTensor, + ) -> _DispatchCacheEntry: + """ + Make a cache entry object for the given 'output' Tensor. Raises + _BypassDispatchCache if the output tensor has characteristics that + prevent caching it. + """ + # Some ops return tuples of Tensors, but it's rare, so avoid + # the complexity of caching other types. + if not isinstance(output, FakeTensor): + raise _BypassDispatchCache("non-FakeTensor output") + + # Avoid caching FakeTensors with constants attached since those + # can be invalidated. + if output.constant is not None: + raise _BypassDispatchCache("constant attribute") + + # TODO: support caching sparse outputs? + if output.is_sparse: + raise _BypassDispatchCache("sparse output") + + if is_sparse_compressed(output): + raise _BypassDispatchCache("sparse compressed output") + + # Can an in-place op really reference a kwarg? If so, then we need + # to extend the implementation to handle it. + for kval in kwargs.values(): + if id(kval) == id(output): + raise _BypassDispatchCache("kwarg aliases output") + + # If this is an in-place op, the entry records which input arg is aliased. + for idx in range(len(args)): + if id(args[idx]) == id(output): + return _DispatchCacheEntry( + inplace_idx=idx, metadata=None, view_idx=None + ) + + # Otherwise, create an entry that records the output tensor's metadata. + view_idx = None + if func.is_view: + idxs = [i for i, t in enumerate(args) if isinstance(t, torch.Tensor)] + assert len(idxs) == 1 + view_idx = idxs[0] + + metadata = extract_tensor_metadata(output) + entry = _DispatchCacheEntry( + inplace_idx=None, metadata=metadata, view_idx=view_idx + ) + + # N.B.: Some checks for bypassing the cache would be performed on the + # output tensor synthesized from the cached metadata. As an optimization, + # we can synthesize a tensor here and do the checks on that instance. + # This approach keeps the (more frequent) cache-hit path as lightweight + # as possible. + synth_output = self._output_from_cache_entry(entry, func, args) + + # Make sure the dispatch_key_set from the synthesized output tensor will + # be the same. + synth_key_set = torch._C._dispatch_key_set(synth_output) + key_set = torch._C._dispatch_key_set(output) + if synth_key_set != key_set: + raise _BypassDispatchCache("dispatch_key_set mismatch") + + return entry + + def _output_from_cache_entry( + self, entry: _DispatchCacheEntry, func: OpOverload, args: Tuple[Any, ...] + ) -> FakeTensor: + """ + Create a new FakeTensor from the cache entry. + """ + if entry.inplace_idx is not None: + # This is an in-place op; return the aliased arg. + return args[entry.inplace_idx] + + # Synthesize a new FakeTensor with the cached metadata. + metadata = entry.metadata + assert not metadata.is_sparse + + empty = torch.empty_strided( + metadata.shape, + metadata.stride, + dtype=metadata.dtype, + layout=metadata.layout, + device="meta", + requires_grad=metadata.requires_grad, + ) + + if metadata.is_conj: + torch._C._set_conj(empty, True) + if metadata.is_neg: + torch._C._set_neg(empty, True) + + if func.is_view: + # For view ops, the storage should be the same as the tensor input. 
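+            # (i.e., a cache-hit view output must alias the same untyped
+            # storage as its input tensor, mirroring the aliasing that the
+            # real view op would create)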
+ storage = args[entry.view_idx].untyped_storage() + with in_kernel_invocation_manager(self): + empty.set_( + storage, metadata.storage_offset, metadata.shape, metadata.stride + ) + elif metadata.storage_offset != 0: + storage = empty.untyped_storage() + with in_kernel_invocation_manager(self): + empty.set_( + storage, metadata.storage_offset, metadata.shape, metadata.stride + ) + + return FakeTensor(self, empty, metadata.device) + + def _crosscheck_cache_output( + self, + output: FakeTensor, + func: OpOverload, + types: Tuple[Any, ...], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ): + """ + Helper to validate that the output synthesized from the cache matches + the output created by normal dispatch. + """ + try: + true_output = self._dispatch_impl(func, types, args, kwargs) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}: Dispatch raised={e}" + ) from e + try: + assert_metadata_eq(assert_eq, true_output, output) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}" + ) from e + + def dispatch(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + with no_dispatch(): + log.debug("%s %s %s", func, args, kwargs) + + if func in _DISPATCH_META_HANDLERS: + return _DISPATCH_META_HANDLERS[func](args) + + if log.getEffectiveLevel() <= logging.DEBUG: + log.debug( + "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func + ) + # NOTE: incr is intentionally unused for a RAII pattern + incr = IncrementRecursionCount() + + # Some attribute queries that can be serviced directly + # See Note [is_coalesced is dispatched] + if func in _DISPATCH_HANDLE_DIRECTLY: + # NB: no_dispatch is ok here too, this func is very simple + with in_kernel_invocation_manager(self): + return func(*args, **kwargs) + + if self.cache_enabled: + return self._cached_dispatch_impl(func, types, args, kwargs) + else: + return self._dispatch_impl(func, types, args, kwargs) + + def _dispatch_impl(self, func, types, args, kwargs): + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + flat_arg_fake_tensors = [ + t for t in flat_args if isinstance(t, FakeTensor) and self.is_our_fake(t) + ] + has_symbolic_sizes = any( + i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors + ) or any(isinstance(a, torch.SymInt) for a in flat_args) + + converter = self.fake_tensor_converter + + def maybe_to_constant(t): + if isinstance(t, FakeTensor) and self.is_our_fake(t): + return t.constant + else: + return t + + # To constant propagate through these functions: + # 1, If this is a lift due to a torch.tensor call, + # the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point. + # (Note that you can always call a lift fn manually, so we do + # have to check if there are any fake tensors!) 
+ # 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div + if (func in self.lift_fns and not flat_arg_fake_tensors) or ( + should_allow_numbers_as_tensors(func) + and not has_symbolic_sizes + and not flat_arg_fake_tensors + ): + assert all( + t.constant is not None for t in flat_arg_fake_tensors + ), f"{func} should not have fake inputs without constants" + const_flat_args = [maybe_to_constant(a) for a in flat_args] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + out = func(*const_args, **const_kwargs) + if type(out) is torch.Tensor and self.may_turn_const(out): + # NB: not in_kernel_invocation_manager because we're doing real + # compute here + # NB: no_dispatch() here is VERY DANGEROUS (like, segfault + # dangerous) if this is actually a wrapper subclass tensor, + # therefore the exact type test above + with no_dispatch(): + out = out.clone() + return converter(self, out, make_constant=True) + + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + unrecognized_types = self.check_for_subclass(flat_args) + if unrecognized_types: + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + # if we are in the dispatch mode, we will enter this function even if the inputs + # are not FakeTensors. For now, throw if any non-Fake Tensor inputs + # and just support constructors. + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + if func in self.lift_fns: + assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}" + + if type(args[0]) is torch.Tensor: + return converter(self, args[0]) + + # Recompute flat_arg_fake_tensors here again in case some of the inputs + # were real tensors and fakified in validate_and_convert_non_fake_tensors + (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors( + func, converter, flat_args, args_spec + ) + del args, kwargs # Invalidated + + # The current constant handling only support tracing systems + # (aot autograd, torchdynamo) where each operation is run consecutively. + # Because each operation is run in order, we can trace out and support + # sequences like: x = torch.tensor(0.); y = x.add_(1) + # Whenver a constant is written to but with inputs that cannot be evaluated + # statically, such as random_(), we invalidate all constants that alias the input + # We will rely on functionalization for use of fake tensors constants as persistent + # objects on an FX Graph. 
+ + # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view + all_constant = all(e.constant is not None for e in flat_arg_fake_tensors) + if ( + torch.Tag.nondeterministic_seeded not in func.tags + and torch.Tag.inplace_view not in func.tags + and all_constant + and len(flat_arg_fake_tensors) != 0 + and not has_symbolic_sizes + ): + const_flat_args = [maybe_to_constant(a) for a in flat_args] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + + # NB: not in_kernel_invocation_manager(self) as we want to do REAL + # compute + with no_dispatch(): + out = func(*const_args, **const_kwargs) + + flat_out = pytree.tree_leaves(out) + flat_out_tensors = [t for t in flat_out if isinstance(t, torch.Tensor)] + all_constant = all(self.may_turn_const(t) for t in flat_out_tensors) + + if all_constant: + return pytree.tree_map_only( + torch.Tensor, + lambda t: converter(self, t, make_constant=True), + out, + ) + + # we weren't able to turn outputs to constants, + # so invalidate all constants that might be aliases of the outputs + for ten in flat_out_tensors: + converter.invalidate_constant_aliases(ten) + + # we are falling through to running non constant tensors, any input constant that + # is written to must be invalidated + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) + + # Try for fastpath + if has_symbolic_sizes: + fast_impl = get_fast_op_impls().get(func) + if fast_impl is not None: + return fast_impl(self, *args, **kwargs) + + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table + + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table + + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) + ) + ): + with self: + return decomposition_table[func](*args, **kwargs) + + with self: + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + # prims already wrap FakeTensor inputs to FakeTensor outputs + # and do device logic, we dont need do anything but run them + # and ensure that Meta kernels are dispatched to (see) + # Fake Tensor Dispatch Keys + # TODO - we should be use the prim aten impl + # TODO - fix prims complex ops + if ( + "prims::" in func._schema.name + and hasattr(func, "prim_meta_impl") + and not stride_incorrect_op(func) + ): + with self: + return func.prim_meta_impl(*args, **kwargs) + + # Users can register FakeTensor rules for custom operators + # Call them if they exist. 
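+        # (These are the abstract/meta kernels registered from Python for
+        # custom ops, e.g. via torch.library.impl_abstract in this version.)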
+ maybe_abstract_impl = torch._library.simple_registry.singleton.find( + func.name() + ).abstract_impl.kernel + if maybe_abstract_impl: + ctx = torch._library.abstract_impl.AbstractImplCtx(self.shape_env, func) + with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self: + result = maybe_abstract_impl(*args, **kwargs) + return result + + # special handling for funcs registered through `register_op_impl`, + # e.g., manipulating args on constructor calls to construct meta tensors + # and then afterwards wrapping them to a FakeTensor + for run_impl_check, op_impl in op_implementations_checks: + if run_impl_check(func): + op_impl_out = op_impl(self, func, *args, **kwargs) + if op_impl_out != NotImplemented: + return op_impl_out + + def maybe_run_unsafe_fallback(error=None): + # We infer the meta of a custom ops that return None to just + # return None. custom ops are not allowed to mutate metadata + # of their inputs, so this is safe. + if can_generate_trivial_abstract_impl(func): + return None + # no meta kernel registered, fallback to kernel for the device + if has_symbolic_sizes or not self.can_run_unsafe_fallback(func): + raise UnsupportedOperatorException(func) + if error is None: + error = UnsupportedOperatorException(func) + return run_fallback_kernel(self, func, flat_args, args_spec, error) + + # Optimization: If there is no Meta kernel, it takes a surprisingly long + # amount of time to catch the NotImplementedError, so we check it here. + if not has_meta(func): + return maybe_run_unsafe_fallback() + + # run kernel registered to meta for func, which include + # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) + # It's possible that the kernel will return NotImplementedError + try: + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + return maybe_run_unsafe_fallback(not_implemented_error) + + return self.wrap_meta_outputs_with_default_device_logic( + r, func, flat_args, device=kwargs.get("device") + ) + + # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators + # outside of the pytorch/pytorch library! Any pre-existing things here + # are either in the pytorch/pytorch library or have been grandfathered in. + # The fallback does not always work and MAY CRASH and emit unreadable error messages + # so it should not be allowed by default. + _can_run_unsafe_fallback_allowed_namespaces = ordered_set( + "debugprims", + "prims", + "aten", + "xla", + "vision", + "torchtext", + "torchaudio", + "quantized", + ) + + def can_run_unsafe_fallback(self, func: OpOverload): + if not self.allow_fallback_kernels: + return False + # It's OK to try the fallback for built-in ops (e.g. aten, prims) + # because we control and test these but the fallback leads to unexpected behavior + # in user-defined custom ops + return ( + func.namespace in self._can_run_unsafe_fallback_allowed_namespaces + or func.name() == "fbgemm::gmm" + ) + + # [subclass inputs] + # Suppose we enable fake tensor mode. This means that fake tensor + # mode will run first. But what if we do an operation that + # involves a tensor subclass that will desugar into normal tensor + # operations? Without returning NotImplemented, fake tensor mode will run first, + # decide that a conversion was made (since there was a non fake + # tensor argument), and report an error that converting non + # fake tensor is not supported. 
What we actually wanted to happen + # was to give the subclass a chance to figure out what it wants to + # before erroring out. Returning NotImplemented here allows this. + def check_for_subclass(self, flat_args): + def check(x): + return ( + isinstance(x, torch.Tensor) + and not isinstance(x, FakeTensor) + and type(x) is not torch.Tensor + and type(x) is not torch.nn.Parameter + ) + + return [type(x) for x in flat_args if check(x)] + + def validate_and_convert_non_fake_tensors( + self, func, converter, flat_args, args_spec + ): + """ + Checks if the list of tensors are fake tensors. + If not, try to convert them to fake tensors. + Returns the original args, kwargs, and a flattened list of (args, kwargs) that are fake tensors. + """ + flat_arg_fake_tensors = [] + + def validate(x): + if not isinstance(x, torch.Tensor): + return x + + nonlocal flat_arg_fake_tensors + if not self.is_our_fake(x): + if torch.Tag.inplace_view in func.tags: + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise Exception( + f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}" + ) + if not self.allow_non_fake_inputs: + if isinstance(x, FakeTensor) and x.fake_mode is not self: + raise AssertionError("Mixing fake modes NYI") + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise Exception( + f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode " + f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}" + ) + + x = converter(self, x) + + flat_arg_fake_tensors.append(x) + return x + + validated_args = [validate(a) for a in flat_args] + return validated_args, flat_arg_fake_tensors + + def wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device): + converter = self.fake_tensor_converter + + # Lazily initialized, in case there are no tensor returns + common_device = None + has_scalar_only_inputs = False + + def wrap(e): + nonlocal common_device + nonlocal has_scalar_only_inputs + + if isinstance(e, torch.Tensor) and common_device is None: + ( + common_device, + has_scalar_only_inputs, + ) = FakeTensor._find_common_device(func, flat_args) + + if self.is_our_fake(e): + torch._check( + e.device == common_device, + lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}", + ) + + if ( + isinstance(e, torch.Tensor) + and not self.is_our_fake(e) + and converter is not None + ): + if has_scalar_only_inputs: + # Under FakeTensorMode, op accepts scalar only inputs, such as aten.add/sub/mul/div, + # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details. + # We thus directly convert real tensor to fake tensor. 
+ return converter(self, e) + else: + return converter.from_meta_and_device( + self, e, device or common_device + ) + else: + return e + + return tree_map(wrap, r) + + _cpp_meta_supports_symint = ordered_set( + aten.empty.memory_format, + aten.empty_strided.default, + aten.as_strided_scatter.default, + aten.as_strided.default, + aten.as_strided_.default, + aten.zeros.default, + aten.detach.default, + aten.view_as_real.default, + aten.view_as_complex.default, + aten.set_.source_Storage_storage_offset, + aten._sparse_coo_tensor_with_dims_and_tensors.default, + ) + + def cpp_meta_supports_symint(self, func): + if torch.Tag.view_copy in func.tags: + return True + return func in self._cpp_meta_supports_symint + + lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default) + + def may_turn_const(self, t): + return ( + t.numel() <= CONSTANT_NUMEL_LIMIT + and not t.is_sparse + and not self.is_our_fake(t) + and not t.device.type == "meta" + ) + + def invalidate_written_to_constants( + self, func, flat_arg_fake_tensors, args, kwargs + ): + any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) + schema_info = get_schema_info(func) + if any_constant and schema_info.is_mutable(): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + for k, v in new_kwargs.items(): + k = k if (k != "input" or schema_info.has_argument(k)) else "self" + if ( + self.is_our_fake(v) + and schema_info.is_mutable(k) + and v.constant is not None + ): + self.fake_tensor_converter.invalidate_constant_aliases(v.constant) + + def from_tensor( + self, + tensor, + *, + static_shapes=None, + source: Optional[Source] = None, + symbolic_context=None, + # Setting this flag will force FakeTensorMode to return `None` if attempting to convert a tensor we have not + # seen before. 
+ memoized_only=False,
+ ):
+ shape_env = self.shape_env
+ if static_shapes is None:
+ static_shapes = self.static_shapes
+ if static_shapes:
+ assert (
+ symbolic_context is None
+ ), "cannot set both static_shapes and symbolic_context"
+ shape_env = None
+ # see note [Tensor Fakification and Symbol Caching]
+ if not symbolic_context and not source and not static_shapes:
+ if tracing_context := torch._guards.TracingContext.try_get():
+ if tensor in tracing_context.tensor_to_context:
+ symbolic_context = tracing_context.tensor_to_context[tensor]
+ source = symbolic_context.tensor_source
+ return self.fake_tensor_converter(
+ self,
+ tensor,
+ shape_env=shape_env,
+ source=source,
+ symbolic_context=symbolic_context,
+ memoized_only=memoized_only,
+ )
+
+
+# NB: returns fake tensors
+def run_fallback_kernel(
+ fake_mode, func, flat_args, args_spec, orig_not_implemented_exception
+):
+ # these should all be supported, just to be safe
+ # avoid fallback for operators which inplace modify metadata
+ # because the input fake tensors would be unmodified
+ if torch.Tag.inplace_view in func.tags:
+ raise orig_not_implemented_exception
+
+ inp_impls = {}
+
+ # Don't use in_kernel_invocation_manager(fake_mode) as we want to do
+ # REAL compute (not with meta device)
+ with no_dispatch():
+
+ def to_real_tensor(e):
+ if fake_mode.is_our_fake(e):
+ out = torch.zeros_like(e, device=e.fake_device)
+ if e.is_sparse:
+ out._coalesced_(e.is_coalesced())
+ inp_impls[id(out)] = e
+ return out
+ return e
+
+ flat_args = [to_real_tensor(a) for a in flat_args]
+ args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
+
+ r = func(*args, **kwargs)
+
+ tensor_impls = set()
+ storages = set()
+
+ for e in flat_args:
+ if isinstance(e, torch.Tensor):
+ if not e.is_sparse:
+ storages.add(e._typed_storage()._cdata)
+
+ # TODO: also check metadata change on inputs
+ # proper aliasing/metadata relationship between outputs and inputs will
+ # not be set up, because of conversion to device, unless we can reuse an
+ # input impl
+
+ def map_out(e):
+ if id(e) not in inp_impls and (
+ isinstance(e, torch.Tensor)
+ and not e.is_sparse
+ and e._typed_storage()._cdata in storages
+ ):
+ raise orig_not_implemented_exception
+
+ if isinstance(e, torch.Tensor):
+ if id(e) in inp_impls:
+ return inp_impls[id(e)]
+ else:
+ return fake_mode.fake_tensor_converter(fake_mode, e)
+ else:
+ return e
+
+ return pytree.tree_map(map_out, r)
+
+
+def can_generate_trivial_abstract_impl(op: torch._ops.OpOverload) -> bool:
+ assert isinstance(op, torch._ops.OpOverload)
+ if torch._library.utils.is_builtin(op):
+ # We control the built-ins. These may (in rare cases)
+ # do input metadata mutation (which we have banned on custom ops)
+ return False
+ schema = op._schema
+ # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
+ if not schema.is_mutable:
+ return False
+ if len(schema.returns) > 0:
+ return False
+ # If the op returns nothing, then it has a trivial abstract impl.
+ return True + + +# Just for use to allow copying a module to fake tensors, +# does not apply elsewhere +class FakeCopyMode(TorchFunctionMode): + def __init__(self, fake_mode): + self.fake_mode = fake_mode + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # clone will get called in Parameter deepcopy + if func == torch._C.TensorBase.clone: + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) + elif func == torch.Tensor.__deepcopy__: + assert len(args) == 2 and len(kwargs) == 0 + tensor, memo = args + + if id(tensor) in memo: + return memo[id(tensor)] + + out = self.fake_mode.from_tensor(tensor, static_shapes=True) + memo[id(tensor)] = out + return out + else: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +def _device_handler(args): + # NB: Don't use is_our_fake, just serve the fake information + # as is. Notice we don't use 'self'; we use args[0].fake_mode + # because they may not be the same. It would also be possible + # to return NotImplemented here, in which case the FakeTensor + # handler on args[0] would handle it, but we're being nice and + # short-circuiting quickly. + assert len(args) == 1 and isinstance(args[0], FakeTensor) + if args[0].fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return args[0].fake_device + + +_DISPATCH_META_HANDLERS = { + torch.ops.prim.device.default: _device_handler, + torch.ops.aten.size.default: lambda args: tuple(int(s) for s in args[0].size()), + torch.ops.aten.stride.default: lambda args: tuple(int(s) for s in args[0].stride()), + torch.ops.aten.storage_offset.default: lambda args: int(args[0].storage_offset()), +} + +_DISPATCH_HANDLE_DIRECTLY = ordered_set( + torch.ops.aten.is_coalesced.default, + torch.ops.aten.dense_dim.default, + torch.ops.aten.sparse_dim.default, +) + +from torch._subclasses.fake_impls import ( # noqa: F401 + _device_not_kwarg_ops, # noqa: F401 + _is_tensor_constructor, # noqa: F401 + _like_tensor_constructors, # noqa: F401 + contains_tensor_types, # noqa: F401 + get_fast_op_impls, + has_meta, + op_implementations_checks, + stride_incorrect_op, +) diff --git a/MLPY/Lib/site-packages/torch/_subclasses/fake_utils.py b/MLPY/Lib/site-packages/torch/_subclasses/fake_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7fc4e85946b49ee558ac1898dfc8b41e5d5849 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/fake_utils.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +import functools +import warnings +from typing import Callable, Union + +import torch +import torch.utils._pytree as pytree +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import ( + FakeTensorMode, + tree_flatten_only, + UnsupportedFakeTensorException, +) +from torch.utils._python_dispatch import TorchDispatchMode + + +aten = torch._ops.ops.aten + + +def outputs_alias_inputs(outputs, inputs): + input_storages = { + inp._typed_storage()._cdata + for inp in tree_flatten_only(torch.Tensor, inputs) + if torch._C._has_storage(inp) + } + return any( + torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages + for out in tree_flatten_only(torch.Tensor, outputs) + ) + + +def outputs_are_inputs(outputs, inputs): + input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)} + return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs)) + + +def output_alias_each_other(outputs): + storages = set() + for out in 
tree_flatten_only(torch.Tensor, outputs): + if not torch._C._has_storage(out): + continue + stor = out._typed_storage()._cdata + if stor in storages: + return True + storages.add(stor) + return False + + +def is_sdpa_error(func, idx, e): + if ( + ( + func is aten._scaled_dot_product_flash_attention.default + or func is aten._flash_attention_forward.default + ) + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True + if ( + ( + func is aten._scaled_dot_product_efficient_attention.default + or func is aten._efficient_attention_forward.default + ) + and idx in (2, 3) + and "Devices" in repr(e) + ): + return True + return False + + +class CrossRefFakeMode(TorchDispatchMode): + def __init__( + self, + ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None, + *, + check_strides=True, + check_aliasing=True, + ): + self.ignore_op_fn = ( + ignore_op_fn if ignore_op_fn is not None else lambda fn: False + ) + self.check_strides = check_strides + self.check_aliasing = check_aliasing + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + fake_r = None + + # empty_like excluded for now due to sparse complex + # aten._to_dense.default this one is getting called with csc + if ( + func + not in ( + aten.lift_fresh.default, + aten.lift_fresh_copy.default, + aten.set_.source_Storage_storage_offset, + ) + and not self.ignore_op_fn(func) + and torch.Tag.dynamic_output_shape not in func.tags + and torch.Tag.inplace_view not in func.tags + and torch.Tag.data_dependent_output not in func.tags + ): + # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + try: + # TODO: enable_python_dispatcher() here + with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode: + fake_args, fake_kwargs = pytree.tree_map_only( + torch.Tensor, + functools.partial(fake_mode.from_tensor, static_shapes=True), + (args, kwargs), + ) + with warnings.catch_warnings(): + fake_r = func(*fake_args, **fake_kwargs) + except UnsupportedFakeTensorException: + pass + + context = ( + f"When comparing the output of {func} on FakeTensor and concrete Tensors, " + f"found" + ) + r = func(*args, **kwargs) + if fake_r is not None: + r_flat = pytree.tree_leaves(r) + f_flat = pytree.tree_leaves(fake_r) + assert len(f_flat) == len( + r_flat + ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" + + if self.check_aliasing: + r_aliasing = outputs_alias_inputs(r, (args, kwargs)) + f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs)) + assert ( + r_aliasing == f_aliasing + ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}" + + r_identity_eq = outputs_are_inputs(r, (args, kwargs)) + f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs)) + assert ( + r_identity_eq == f_identity_eq + ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}" + + r_output_alias_each_other = output_alias_each_other(r) + f_output_alias_each_other = output_alias_each_other(fake_r) + assert r_output_alias_each_other == f_output_alias_each_other, ( + f"{context} mismatch in outputs_alias_each_other check " + f"{f_output_alias_each_other} != {r_output_alias_each_other}" + ) + + for idx, (r_out, fake_out) in enumerate( + zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) + ): + r_is_ten = isinstance(r_out, torch.Tensor) + assert r_is_ten == isinstance( + fake_out, torch.Tensor + ), f"{context} mismatched number of 
tensor outputs" + if r_is_ten: + assert r_out.requires_grad == fake_out.requires_grad, ( + f"{context} mismatched requires_grad-ness of outputs. " + f"This usually means that you have added autograd support " + f"for your operator at a dispatch key other than Autograd, " + f"which will lead to problems" + ) + if torch._C._has_storage(r_out): + r_offset = r_out.storage_offset() + f_offset = fake_out.storage_offset() + assert ( + r_offset == f_offset + ), f"{context} mismatched storage offset" + + try: + torch._prims.utils.compare_tensor_meta( + r_out, + fake_out, + check_strides=self.check_strides, + allow_rhs_unbacked=True, + ) + except Exception as e: + if is_sdpa_error(func, idx, e): + continue + error_message = ( + f"{context} mismatched tensor metadata: {e}" + if len(r_flat) == 1 + else f"{context} mismatched tensor metadata for output[{idx}]: {e}" + ) + raise RuntimeError(error_message) from e + return r diff --git a/MLPY/Lib/site-packages/torch/_subclasses/functional_tensor.py b/MLPY/Lib/site-packages/torch/_subclasses/functional_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..67b5215a6a2d4195e4cbf0f7e209565237e42acd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/functional_tensor.py @@ -0,0 +1,653 @@ +import contextlib +from abc import ABC, abstractmethod +from typing import Any, Callable, ContextManager, Dict, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._C import _functionalization_reapply_views_tls as _reapply_views +from torch._ops import _get_dispatch_mode_pre_dispatch +from torch.utils._python_dispatch import ( + _detect_functional_mode, + _disable_infra_mode, + return_and_correct_aliasing, + TorchDispatchMode, +) + +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + + +class FunctionalTensor(torch.Tensor): + """ + Functional tensors represent tensors that will remove mutations + from a program. If you perform a mutable operation on a functional tensor, + it will re-dispatch to the functional variant of that operation. + + Historically, functionalization is implemented in C++ in the dispatcher. + This class is a lightweight python shim around the C++ functionalization logic. + + FunctionalTensor is required to be used with a corresponding + FunctionalTensormode active, because it relies + on using the mode for dispatch (which can properly handle factory functions). + """ + + elem: torch.Tensor + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + + # Note: The reason we add these extra keys to our FunctionalTensor subclass + # is to mirror the behavior of C++ functionalization (we can choose to change this + # later, as long as it doesn't break anything). + # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor + # to the wrapper, excluding functorch and python dispatch keys. + # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy, + # except that they don't include ZeroTensor so I'm manually adding it in. + _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add( + torch._C.DispatchKey.ZeroTensor + ) + + # These are all aten ops that correspond to metadata queries. + # We want FunctionalTensor to be able to handle them directly. 
+ metadata_fns = [ + torch.ops.aten.is_contiguous.default, # type: ignore[has-type] + torch.ops.aten.is_contiguous.memory_format, # type: ignore[has-type] + torch.ops.aten.is_strides_like_format.default, # type: ignore[has-type] + torch.ops.aten.is_non_overlapping_and_dense.default, # type: ignore[has-type] + torch.ops.aten.size.default, # type: ignore[has-type] + torch.ops.aten.sym_size.default, # type: ignore[has-type] + torch.ops.aten.stride.default, # type: ignore[has-type] + torch.ops.aten.sym_stride.default, # type: ignore[has-type] + torch.ops.aten.storage_offset.default, # type: ignore[has-type] + torch.ops.aten.sym_storage_offset.default, # type: ignore[has-type] + torch.ops.aten.numel.default, # type: ignore[has-type] + torch.ops.aten.sym_numel.default, # type: ignore[has-type] + torch.ops.aten.dim.default, # type: ignore[has-type] + torch.ops.prim.device.default, # type: ignore[has-type] + ] + + # These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing + # TODO (tmanlaibaatar) make it a tag + maybe_aliasing_or_mutating_ops = [ + torch.ops.aten.dropout.default, # type: ignore[has-type] + torch.ops.aten.batch_norm.default, # type: ignore[has-type] + torch.ops.aten.native_batch_norm.default, # type: ignore[has-type] + torch.ops.aten._batch_norm_impl_index.default, # type: ignore[has-type] + torch.ops.aten.cudnn_batch_norm.default, # type: ignore[has-type] + torch.ops.aten.miopen_batch_norm.default, # type: ignore[has-type] + ] + + def __new__(cls, elem): + assert torch._is_functional_tensor(elem) + + # In general, we'd like our functional tensor subclass to only be in charge of functionalization, + # and defer to the inner subclass for all other functionality. + # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback + # until after we redispatch to our inner ZeroTensor. + # However, there are a few keys that we need to mirror between the inner and outer tensors. + # Conjugate + # Negative + # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`. + # We **need** calls to is_conj() to return the same thing on the outer and inner tensors, + # Because user code / framework code that branches like so needs to do the same thing + # when it sees the outer FunctionalTensor: + # if (x.is_conj()) { + # return at::view_as_real(x.resolve_conj()); + # } else { + # return at::view_as_real(x); + # } + extra_dispatch_keys = ( + FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem) + ) + + out = torch.Tensor._make_wrapper_subclass( # type: ignore[arg-type, attr-defined] + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 
+ cls,
+ elem.shape, # sizes
+ elem.stride(), # strides
+ elem.storage_offset(), # storage_offset
+ None, # memory_format
+ elem.dtype, # dtype
+ elem.layout, # layout
+ elem.device, # device
+ False, # pin_memory
+ elem.requires_grad, # requires_grad
+ "sizes", # dispatch_sizes_strides_policy
+ False, # dispatch_device
+ False, # dispatch_layout
+ extra_dispatch_keys, # _extra_dispatch_keys
+ )
+ out.elem = elem
+ return out
+
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+ unrecognized_types = [
+ t
+ for t in types
+ if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
+ ]
+ if unrecognized_types:
+ not_implemented_log.debug(
+ "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
+ )
+ return NotImplemented
+
+ if kwargs is None:
+ kwargs = {}
+
+ # FunctionalTensor needs to plumb all metadata requests to the inner tensor.
+ # In theory we don't have to do this - but if we want to service metadata requests here,
+ # we need to carefully make sure all metadata is accurate (including metadata mutations)
+ if func in FunctionalTensor.metadata_fns:
+ # All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry
+ # about the problem of keeping metadata in sync between the wrapper and inner tensor.
+ # This also alleviates us from having to manually handle metadata mutations on the wrapper.
+ assert len(kwargs) == 0
+ if func in [
+ torch.ops.aten.is_strides_like_format.default,
+ torch.ops.aten.is_contiguous.memory_format,
+ ]:
+ assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
+ return func(args[0].elem, args[1])
+ assert len(args) == 1 and isinstance(args[0], FunctionalTensor)
+
+ return func(args[0].elem)
+ # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
+ # - _make_wrapper_subclass requires a __torch_dispatch__
+ # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
+ # which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
+ # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
+ # which causes every subclass created above autograd to have autograd view metadata
+ # (in addition to also being a FunctionalTensorWrapper).
+ raise RuntimeError(
+ "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
+ )
+
+ def __repr__(self):
+ return f"FunctionalTensor({repr(self.elem)})"
+
+ @staticmethod
+ def to_functional(x):
+ # We will do the wrapping for the user.
+ assert not torch._is_functional_tensor(x)
+ # The only autograd metadata we care about on the FunctionalTensor is:
+ # - requires_grad (so autograd runs)
+ # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
+ # this is handled by FunctionalTensor.to_functional
+ x_functional = torch._to_functional_tensor(x)
+ # Technically the FunctionalTensorMode here is unnecessary,
+ # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing.
+ # _mirror_autograd_meta_to queries tensor sizes, + # and otherwise the sym_size() call will go to the proxy mode before hitting + # FunctionalTensor.__torch_dispatch__ + + functional_mode = _detect_functional_mode() + assert functional_mode is not None + + with functional_mode: + torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined] + out = FunctionalTensor(x_functional) + torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined] + return out + + def from_functional(self): + torch._sync(self) + return torch._from_functional_tensor(self.elem) + + def replace_(self, output) -> None: + torch._functionalize_replace(self.elem, output) + + def commit_update(self) -> None: + torch._functionalize_commit_update(self.elem) + + def sync(self) -> None: + torch._functionalize_sync(self.elem) + + def mark_mutation_hidden_from_autograd(self) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(self.elem) + + def tolist(self) -> Any: + if self.elem.dim() == 0: + return self.elem.item() + elif self.elem.dim() == 1: + return [elem.item() for elem in self.elem] + else: + return [elem.tolist() for elem in self.elem] + + +class FunctionalTensorMode(TorchDispatchMode): + def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False): + self.export = export + self.is_on_stack = False + self.enter_stack = [] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + self.pre_dispatch = pre_dispatch + # This will be turned off later for pre-dispatch functionalization + self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined] + # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep + # track of the ordering between side effectful operations. + self._tokens: Dict[Any, torch.Tensor] = {} + + # Functionalization runs twice in AOTAutograd, once in + # `run_functionalized_fw_and_collect_metadata` to collect metadata to + # see which tensors need to be functionalized and discover how many + # tokens we need, and another time in `make_fx` which does the actual + # tracing to replace ops with their functional variants and handling + # side-effectful ops. In the second stage there should be no token + # discovery. This flag distinguishes between the two stages. 
+ self._allow_token_discovery = _allow_token_discovery + + # No-op if FunctionalTensorMode is already in use + def __enter__(self): + def _get_prev_mode(): + if self._dispatch_key == torch._C.DispatchKey.PreDispatch: + return _get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + return torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + + if _get_prev_mode() is None: + self.enter_stack.append(True) + return super().__enter__() + else: + self.enter_stack.append(False) + return self + + def __exit__(self, a, b, c): + is_on_stack = self.enter_stack.pop() + if is_on_stack: + super().__exit__(a, b, c) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t not in [torch.Tensor, FunctionalTensor] + ] + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + def _can_decompose(func): + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832 + # We never decompose dropout in export + if self.export and func == torch.ops.aten.dropout.default: + return False + # TODO (tmanlaibaatar) + # Eventually, we don't want to decompose any aten op at all + # but there is a safety and coverage gap that we need to close + # before that. + # + # (1) the "safety" is what we are risking with this PR + # (we are blindly taking every op that advertises as + # functional and sending it to the functional fallback. + # We risk silent correctness if we have an op that lies about its schema, + # that we didn't manually hardcode above) Therefore we always decompose them + # (2) the "not every composite inplace op has a functional variant" is a coverage gap, + # but not really a safety risk, since we'll loudly error when we try to generate + # functionalization kernels for these new (composite) inplace/view ops. But until we + # establish such gap more concretely, we still decompose them + if self._dispatch_key is not None: + # it is unsafe to not decompose ops that claim to be functional but actually aren't + if func in FunctionalTensor.maybe_aliasing_or_mutating_ops: + return True + # only decompose view or inplace mutating ops + alias_info = len( + [i for i in func._schema.arguments if i.alias_info is not None] + ) + return alias_info != 0 or func._schema.is_mutable + return True + + if ( + func not in FunctionalTensor.metadata_fns + and _can_decompose(func) + # Not all funcs from __torch_dispatch__ are actual dispatcher ops, + # e.g. prim.device + and torch._C._dispatch_has_kernel(func.name()) + ): + with self: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + def assert_is_functional(x): + assert torch._is_functional_tensor(x) + + def wrap(x): + # Only wrap our outputs in subclasses if the inner functionalization call + # also wrapped outputs into FunctionalTensorWrappers. + # When can this happen? e.g. 
`torch.div(2, 2)` + assert not isinstance(x, FunctionalTensor) + if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): + return FunctionalTensor(x) + return x + + def unwrap(x): + return x.elem + + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize, + ) + + if can_auto_functionalize( + func + ) and not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ): + if self.pre_dispatch: + raise NotImplementedError( + "Auto functionalization is not supported on pre-dispatch tracing" + ) + return do_auto_functionalize(func, args, kwargs) + + from torch._higher_order_ops.effects import handle_effects, has_effects + + if has_effects(func, args, kwargs): + assert not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ) + return handle_effects( + self._allow_token_discovery, self._tokens, func, args, kwargs + ) + + args_unwrapped, kwargs_unwrapped = pytree.tree_map_only( + FunctionalTensor, unwrap, (args, kwargs) + ) + + # Expectation: functionalization should not **already** be enabled above our mode. + # Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization + # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper. + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + include_to_set = ( + torch._C._dispatch_tls_local_include_set() + | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + exclude_to_set = ( + torch._C._dispatch_tls_local_exclude_set().remove( + torch._C.DispatchKey.Functionalize + ) + - FunctionalTensor._extra_dispatch_keys + ) + + # All we want to do here is re-use the existing C++ functionalization logic. + # This requires swizzling our TLS dispatch keys so that the Functionalize key is active. + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + try: + # By default for python functionalization (for AOTAutograd), we reapply views. + old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined] + + # Sometimes these functions cannot be directly dispatched to functionalize key + # because args are sometimes not functional tensors for some reason? + if func in FunctionalTensor.metadata_fns: + outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped) + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + else: + # When we dispatch to the C++ functionalization kernel, we might need to jump back to the + # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath + # FunctionalTensorMode. 
If we call func() directly, we would need to exclude PreDispatch + # from the TLS in order to avoid infinite looping, but this would prevent us from coming + # back to PreDispatch later + outs_unwrapped = func._op_dk( + torch._C.DispatchKey.Functionalize, + *args_unwrapped, + **kwargs_unwrapped, + ) + # We don't allow any mutation on result of dropout + if self.export and func == torch.ops.aten.dropout.default: + torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + finally: + torch._disable_functionalization() + torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined] + + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + + if ( + # If no outputs are our functional subclass, then don't try to fix up aliasing + not any( + isinstance(x, FunctionalTensor) + for x in pytree.tree_leaves(outs_wrapped) + ) + # Since lift_fresh lifts its argument into a functional tensor, we can skip the + # aliasing correction step. Otherwise, we would be setting the storage of a + # lifted tensor to that of an unlifted tensor. + # Ref: https://github.com/pytorch/pytorch/issues/111506 + or func == torch.ops.aten.lift_fresh.default + ): + return outs_wrapped + # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing. + # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects. + # Use this util to figure out the right thing to return. + # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for. + return return_and_correct_aliasing(func, args, kwargs, outs_wrapped) + + +@contextlib.contextmanager +def disable_functional_mode(): + return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + + +# This is similar to torch.func.functionalize, but: +# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass). +# One important advantage to using this mode is that it will let us +# run functionalization underneath __torch_dispatch__, +# which we need in AOTAutograd. +# - Doing so means that it does not automatically compose with other +# functorch transforms, since these transforms always run above __torch_dispatch__. +# That's why this util lives here, and not in functorch. 
+def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()): + # TODO: pull these from aot autograd + def to_fun(t): + if isinstance(t, torch.Tensor): + return FunctionalTensor.to_functional(t) + return t + + def from_fun(t): + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) + return t + torch._sync(t) + return torch._from_functional_tensor(t.elem) + + def inner(*args, **kwargs): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + with disable_above, mode: + func_args = pytree.tree_map_only(torch.Tensor, to_fun, args) + func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs) + func_outputs = func(*func_args, **func_kwargs) + outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs) + + return outputs + + return inner + + +class BaseFunctionalizeAPI(ABC): + @abstractmethod + def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + pass + + @abstractmethod + def unwrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + pass + + @abstractmethod + def functionalize(self, inner_f: Callable) -> Callable: + pass + + @abstractmethod + def redispatch_to_next(self) -> ContextManager: + pass + + @abstractmethod + def replace(self, input_tensor, output_tensor) -> None: + pass + + @abstractmethod + def commit_update(self, tensor) -> None: + pass + + @abstractmethod + def sync(self, tensor) -> None: + pass + + @abstractmethod + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + pass + + +class PythonFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__( + self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False + ) -> None: + super().__init__() + self.mode = mode if mode else FunctionalTensorMode() + self.pre_dispatch = pre_dispatch + + def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + with self.mode: + return torch.utils._pytree.tree_map_only( + torch.Tensor, FunctionalTensor.to_functional, args + ) + + def unwrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + return torch.utils._pytree.tree_map_only( + FunctionalTensor, FunctionalTensor.from_functional, args + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return dispatch_functionalize(inner_f, self.mode) + + def redispatch_to_next(self) -> ContextManager: + # [NOTE] We don't do anything here because at the time + # we exercise this path, we would have already popped the + # FunctionalTensorMode from mode stack. Since FunctionalTensorMode + # is now stateful, it is better to explicitly pass in correct mode + # directly instead of globally setting it. 
+ return contextlib.nullcontext() + + def replace(self, input_tensor, output_tensor) -> None: + assert isinstance(input_tensor, FunctionalTensor) + assert not isinstance(output_tensor, FunctionalTensor) + input_tensor.replace_(output_tensor) + + def commit_update(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.commit_update() + + def sync(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.sync() + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.mark_mutation_hidden_from_autograd() + + +class CppFunctionalizeAPI(BaseFunctionalizeAPI): + def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=0) + + def unwrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views()) + + def functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize(inner_f) + + def redispatch_to_next(self) -> ContextManager: + return torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) + + +class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__(self, interpreter): + self.interpreter = interpreter + + def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=self.interpreter.level()) + + def unwrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional( + args, reapply_views=self.interpreter.functionalize_add_back_views() + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize( + inner_f, + remove="mutations_and_views" + if self.interpreter.functionalize_add_back_views() + else "mutations", + ) + + def redispatch_to_next(self) -> ContextManager: + return self.interpreter.lower() + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) diff --git a/MLPY/Lib/site-packages/torch/_subclasses/meta_utils.py b/MLPY/Lib/site-packages/torch/_subclasses/meta_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..984b46d78c651fc059f3361ba8b53b25695efe93 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/meta_utils.py @@ -0,0 +1,987 @@ +import contextlib +import warnings +import weakref +from typing import ContextManager, Dict, List, 
Optional, Tuple, TYPE_CHECKING + +import torch +from torch._C._functorch import ( + _add_batch_dim, + _unwrap_functional_tensor, + _wrap_functional_tensor, + current_level, + get_unwrapped, + is_batchedtensor, + is_functorch_wrapped_tensor, + is_gradtrackingtensor, + maybe_get_bdim, + maybe_get_level, + peek_interpreter_stack, + TransformType, +) +from torch._guards import Source + +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) +from torch.utils.weak import WeakIdRef + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch.fx.experimental.symbolic_shapes import SymbolicContext + +DimList = List + + +def safe_is_leaf(t): + try: + return t.is_leaf + except RuntimeError: + # inference mode can trigger this + return False + + +def safe_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return t.grad + + +def assert_eq(a, b): + assert a == b, f"{a} != {b}" + + +def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False): + def go(m1, m2): + assert_eq(m1.dtype, m2.dtype) + if not skip_symbolic: + assert_eq(m1.shape, m2.shape) + assert_eq(m1.requires_grad, m2.requires_grad) + assert_eq(m1.is_leaf, m2.is_leaf) + assert_eq(m1.grad_fn is None, m2.grad_fn is None) + assert_eq(m1.is_sparse, m2.is_sparse) + assert_eq(m1.is_inference(), m2.is_inference()) + assert_eq(m1.is_conj(), m2.is_conj()) + assert_eq(m1.is_neg(), m2.is_neg()) + assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None) + if safe_grad(m1) is not None: + go(safe_grad(m1), safe_grad(m2)) + if m1.is_sparse: + assert_eq(m1.dense_dim(), m2.dense_dim()) + assert_eq(m1.sparse_dim(), m2.sparse_dim()) + assert_eq(m1.is_coalesced(), m2.is_coalesced()) + else: + if not skip_symbolic: + assert_eq(m1.stride(), m2.stride()) + assert_eq(m1.storage_offset(), m2.storage_offset()) + assert_eq(m1._is_view(), m2._is_view()) + if m1._is_view(): + go(m1._base, m2._base) + # TODO: test if is resizable (no direct query for this atm) + # TODO: audit AutogradMeta to see if it matches + # TODO: test forward AD + + return go(m1, m2) + + +def is_sparse_coo(t): + return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo + + +def is_sparse_compressed(t): + return isinstance(t, torch.Tensor) and t.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + + +def is_sparse_any(t): + return is_sparse_coo(t) or is_sparse_compressed(t) + + +# This is a class for converting multiple tensors into meta tensors which +# share the same view/storage structure. The operation model is you allocate +# one of these, and then call it repeatedly on all the tensors you want to +# convert. It's important to use the same object for tensors you want to +# share storage because this is how we correlate shared storages to the same +# meta storages. This class will hold weak references to cached tenosrs +# and tensor storages. 
+class MetaConverter:
+ def __init__(self):
+ self.storage_memo = {}
+ self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+ self.maybe_storages_to_delete = []
+ self.check_expired_frequency = 128
+ self.check_expired_count = 0
+ self.hit = 0
+ self.miss = 0
+ self.del_hook = None
+ self.arg_cnt = 0
+
+ def successful(self):
+ return self.hit > 0 and self.miss == 0
+
+ def check_for_expired_weak_storages(self):
+ new_li = []
+ stor_to_delete = []
+ for obj in self.maybe_storages_to_delete:
+ if not obj.expired():
+ new_li.append(obj)
+ else:
+ stor_to_delete.append(obj)
+ for obj in stor_to_delete:
+ self.storage_memo.pop(obj, None)
+ self.maybe_storages_to_delete = new_li
+
+ # if for some reason we have acquired many storages which have not expired
+ # even though a tensor with their storage has expired (aliasing or otherwise)
+ # check for expired storages less often so as to bound the amount of work we
+ # do checking for expired storages
+ self.check_expired_frequency = max(
+ self.check_expired_frequency, len(self.maybe_storages_to_delete)
+ )
+
+ def get_tensor_memo(self, t):
+ return self.tensor_memo.get(WeakIdRef(t), None)
+
+ def set_tensor_memo(self, t, v):
+ # hold a weak ref to self, otherwise it will be kept alive
+ # by the del_ten closure
+ self_weak_ref = weakref.ref(self)
+ if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
+ weak_st = None
+ else:
+ weak_st = StorageWeakRef(t._typed_storage())
+ tensor_ref_key = WeakIdRef(t)
+
+ def del_ten():
+ # tensor outlives the converter
+ self_ref = self_weak_ref()
+ if self_ref is None:
+ return
+ # on shutdown, tensor_ref_key may not be in memo
+ self_ref.tensor_memo.pop(tensor_ref_key, None)
+ if weak_st and weak_st.expired():
+ self_ref.storage_memo.pop(weak_st, None)
+ elif weak_st is not None:
+ # [expired-storages]
+ # NB: even though the tensor has died,
+ # the deallocation of its storage can take longer,
+ # even when the storage has no other uses/views.
+ # In this case, the StorageWeakRef object will be kept alive
+ # longer than it needs to be, however the storage itself
+ # will be deallocated. We retain the possibly dead storages
+ # and periodically check if any of them are expired and
+ # can be freed.
+ self_ref.maybe_storages_to_delete.append(weak_st)
+
+ weakref.finalize(t, del_ten)
+ self.tensor_memo[tensor_ref_key] = v
+
+ # NB: doesn't actually return a storage, because meta storage is
+ # not supported
+ def meta_storage(self, s, callback):
+ # NB: TypedStorage is freshly allocated and cannot be used as hash
+ # key index.
+
+ # Use a Weak Ref to s in order to not leak memory
+ swr = StorageWeakRef(s)
+ if swr not in self.storage_memo:
+ self.storage_memo[swr] = callback(
+ lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
+ ).untyped_storage()
+ return self.storage_memo[swr]
+
+ # This function assumes that it's possible to do the conversion
+ # NB: name here is used in a conventional way by Dynamo; it corresponds
+ # precisely to the Source.name() of the tensor we're fakeifying and
+ # corresponds to a valid Python expression. When we construct sub-names
+ # as part of this process, we will maintain this invariant! (Even though
+ # other users of this may not need this property to be upheld.)
+ def meta_tensor( + self, + t, + shape_env=None, + callback=lambda t: t(), + source: Optional[Source] = None, + symbolic_context: Optional["SymbolicContext"] = None, + ): + if source is None: + from torch._dynamo.source import ConstantSource + + # TODO: make a dedicated UnknownSource for this? + source = ConstantSource( + f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" + ) + + # This indicates you set no_dispatch() before calling into this + # function. This is an error: we may be creating fake tensors and + # will perform operations on them which need fake tensor mode to + # be active. You will segfault if you are in a no_dispatch() block. + assert not torch._C._dispatch_tls_local_exclude_set().has( + torch._C.DispatchKey.Python + ) + arg_cnt = self.arg_cnt + self.arg_cnt += 1 + + # When we make as_strided calls, we end up generating a guard + # that the new as_strided tensor is in bounds for the old storage + # for the base (since as_strided calls can "bust" out of their + # bounding box.) This guard is unnecessary: if a user is able + # to provide us a tensor with the view base setup this way, we + # don't need to produce a guard, because the fact that they + # were able to produce the view base means its in bounds. + # + # Now, ordinarily, this guard would be harmless. However, the + # generated guard refers to variables bound on the base variable. + # At the moment, Dynamo doesn't actually guard on x._base, because + # according to Voz this results in a lot of spurious invalidations, + # and also if the user doesn't directly make use of _base, its + # pointless anyway (because programs should be parametric over + # whether or not the input tensor is a view or not--unless you're + # mutating the input, but that's a whole 'nother ballgame). So + # for expediency, we suppress these guards so we don't have to + # deal with this (yet, anyway.) + # + # NB: An old version of this code suppressed guards for ALL operations + # happening during meta conversion, not just as_strided calls. + # This is too aggressive: we do duck sizing and 0/1 simplification + # as we allocate variables, and we do need to register guards for + # these cases. + maybe_suppress = contextlib.nullcontext + if shape_env is not None: + maybe_suppress = shape_env.suppress_guards + + def sym_sizes_strides_storage_offset( + t, src, symbolic_context=symbolic_context + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + if shape_env is not None: + fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t) + if fake_mode is not None and fake_mode.shape_env is shape_env: + # Don't reallocate the sizes; the shape envs are the same, + # so reuse the old sizes/strides/etc + return (t.size(), t.stride(), t.storage_offset()) + else: + return shape_env.create_symbolic_sizes_strides_storage_offset( + t, + src, + symbolic_context=symbolic_context, + ) + else: + assert symbolic_context is None + return (t.size(), t.stride(), t.storage_offset()) + + def empty_create(inner_t, inner_src, symbolic_context=symbolic_context): + ( + inner_sizes, + inner_strides, + inner_storage_offset, + ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) + return torch.empty_strided( + inner_sizes, + inner_strides, + dtype=inner_t.dtype, + device="meta", + ) + + # Creates a subclass instance with empty inner tensors according to the specified + # symbolic context. 
+ def empty_create_subclass( + t, + outer_size, + outer_stride, + symbolic_context=symbolic_context, + callback=callback, + source=source, + ): + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext + + assert symbolic_context is None or isinstance( + symbolic_context, SubclassSymbolicContext + ) + + # Note: transform_subclass will use __tensor_unflatten__ to generate + # a fresh subclass wrapper with outer sizes / strides according to the + # outer symbolic context (passed in to this function). Inner size / stride + # / storage offset symbols are allocated according to the appropriate inner + # symbolic contexts, after which the checks in transform_subclass() will + # relate them to the outer metadata as possible. + return transform_subclass( + t, + lambda attr, inner_t: callback( + lambda: empty_create( + inner_t, + AttrSource(source, attr), + symbolic_context=( + None + if symbolic_context is None + else symbolic_context.inner_contexts[attr] + ), + ) + ), + outer_size=outer_size, + outer_stride=outer_stride, + ) + + # Returns an all-dynamic symbolic context used for metafying the given tensor with + # fully dynamic dims. This is useful when fake-ifying intermediate tensors in + # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we + # don't want to over-specialize during view replay. + def all_dynamic_symbolic_context(t, source, shape_env, callback): + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import ( + DimDynamic, + StatelessSymbolicContext, + SubclassSymbolicContext, + SymbolicContext, + ) + + view_base_context: Optional[SymbolicContext] = None + if t._is_view(): + view_base_context = all_dynamic_symbolic_context( + t._base, AttrSource(source, "_base"), shape_env, callback + ) + + t_symbolic_context: SymbolicContext + t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.dim() + if is_traceable_wrapper_subclass(t): + inner_contexts: Dict[str, SymbolicContext] = {} + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + assert isinstance(attr, str) + inner = getattr(t, attr) + inner_contexts[attr] = all_dynamic_symbolic_context( + inner, AttrSource(source, attr), shape_env, callback + ) + t_symbolic_context = SubclassSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.dim(), + inner_contexts=inner_contexts, + tensor_source=source, + view_base_context=view_base_context, + ) + else: + t_symbolic_context = StatelessSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.dim(), + view_base_context=view_base_context, + ) + + return t_symbolic_context + + # Returns a fake-ified version of an input view tensor t, given an already fake-ified + # base. At a high level, we want two things: + # 1. fake_t should have the same view relationship to the given fake base as the + # input t has to its _base. + # 2. fake_t should have symbolic sizes / strides / storage offset according to the + # appropriate symbolic context (i.e. from the automatic dynamic algorithm). + # + # We currently take different strategies across view types: + # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an + # as_strided() call on the fake-ified base, passing symbolic metadata. + # * For views involving subclasses, perform view replay using view funcs to + # achieve (1). It's necessary for (2) to swap out any closed-over state in + # the view funcs with symbolicized SymInts and fake-ified tensors. 
Doing this + # avoids specialization (and thus over-eager simplification of symbols) that + # could occur during view replay on the fake-ified base. + # + # Examples: + # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled + # with an as_strided() call on the fake base passing symbolic metadata. + # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg + # is made symbolic to avoid invalid specialization and view replay is then + # done to reconstruct the view. + # * _nested_from_jagged(values, offsets) is a dense -> subclass view + # that returns a subclass instance from a dense values tensor. The offsets + # tensor is closed over in the view func, as it can be considered view metadata. + # First, the offsets tensor is fake-ified according to the inner symbolic + # context and with the correct relationship to the outer size / stride metadata. + # Then view replay is done, swapping in the fake offsets so the view replay output + # is fully fake with no invalid specialization. + def view_from_base(base, t, source=source, shape_env=shape_env): + # fake-ify t's metadata according to the outer symbolic context + (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( + t, source + ) + if not is_traceable_wrapper_subclass( + t + ) and not is_traceable_wrapper_subclass(base): + # Dense -> Dense view case uses as_strided() to construct view relationship. + # TODO: Change this logic to use view replay for consistency? + # It's likely there is no view func available. + return base.as_strided(sizes, strides, storage_offset) + + from torch._dynamo.source import EphemeralSource + from torch.fx.experimental.symbolic_shapes import sym_eq + + def symint_visitor_fn(s): + if shape_env is None: + return s + + # NB: The symbol here is expected to be simplified out because we a priori + # allocate inner and outer symbols according to the appropriate symbolic + # contexts and prefer those over this symbol during symbol simplification + # (via usage of EphemeralSource below). This -shouldn't- happen, but if + # this symbol somehow leaks out beyond the view tensor's shape metadata, our + # assumption of it being simplified out will fail and it may be guarded on, + # which will hard error. + sym_source = EphemeralSource("symint_visitor_fn") + symbol = shape_env.create_symbol(s, sym_source) + return shape_env.create_symintnode(symbol, hint=s, source=sym_source) + + real_to_fake_mapping = {} + if is_traceable_wrapper_subclass(t): + # Fake-ify t naively here; this is only done so we can get fake-ified inner + # tensors with the correct relationships to the outer sizes / strides for use + # in view replay. It's done beforehand here because it's not easy to do when + # visiting tensors one-by-one during view replay. + # + # Example: + # Consider a Dense -> NJT view. NJT has (values, offsets) components and we + # want a view of values with the offsets closed over. As the offsets component + # is needed to describe the output view, it's important that it's fakeified + # correctly. + fake_t = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + attrs, _ = fake_t.__tensor_flatten__() + for attr in attrs: + real_to_fake_mapping[getattr(t, attr)] = getattr(fake_t, attr) + + def tensor_visitor_fn( + visited_t, shape_env=shape_env, callback=callback, source=source + ): + # It's possible to close over an undefined tensor (e.g. NJT's lengths). 
+ if visited_t is None: + return None + + # Fake inner tensors of view subclasses will come from the mapping built above. + fake_visited_t = real_to_fake_mapping.get(visited_t, None) + if fake_visited_t is not None: + return fake_visited_t + + # For other closed-over tensor state, fake-ify it as all dynamic with an + # ephemeral source. This avoids invalid specialization during view replay. + # If we find that in practice the usage of ephemeral sources isn't enough + # to guarantee that we don't have guards on these symbols, we may need to + # explicitly suppress guards (as is done for _base in the dense -> dense + # view case). + temp_source = EphemeralSource("tensor_visitor_fn") + return self.meta_tensor( + visited_t, + shape_env, + callback, + source=temp_source, + symbolic_context=all_dynamic_symbolic_context( + visited_t, temp_source, shape_env, callback + ), + ) + + # Replay the view, swapping out any non-symbolic SymInts or real tensors + # for symbolic SymInts or fake tensors. + fake_t = t._view_func_unsafe(base, symint_visitor_fn, tensor_visitor_fn) + + # Ensure the output has symbolic shapes according to the outer symbolic context. + # These checks should simplify out any symbols created for closed-over view func + # SymInts. + torch._check(sym_eq(fake_t.size(), sizes)) + torch._check(sym_eq(fake_t.stride(), strides)) + torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) + return fake_t + + # see expired-storages + self.check_expired_count += 1 + if self.check_expired_count >= self.check_expired_frequency: + self.check_for_expired_weak_storages() + self.check_expired_count = 0 + + if self.get_tensor_memo(t) is None: + with torch.inference_mode(t.is_inference()): + if t.is_sparse: + is_leaf = safe_is_leaf(t) + + # The lambda function below is similar to + # `t.to(device='meta')` except the latter + # preserves nnz value + r = callback( + lambda: torch.ops.aten._sparse_coo_tensor_with_dims( + t.sparse_dim(), + t.dense_dim(), + t.shape, + dtype=t.dtype, + layout=torch.sparse_coo, + device="meta", + ) + ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + # Note [is_coalesced is dispatched] + # Strangely enough, is_coalesced() is a dispatched operator, + # which means that it will get caught by fake tensor mode. + # Ordinarily this would error, but there's some logic in + # fake tensor ensure this doesn't happen. 
+ r._coalesced_(t.is_coalesced()) + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + with torch.enable_grad(): + r = r.clone() + r._coalesced_(t.is_coalesced()) + elif is_sparse_compressed(t): + is_leaf = safe_is_leaf(t) + + def mk_meta(): + nnz = 0 + batch_dim = t.ndim - t.sparse_dim() - t.dense_dim() + batch_size = t.shape[:batch_dim] + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + index_dtype = t.crow_indices().dtype + compressed_indices = torch.empty( + t.crow_indices().shape, device="meta", dtype=index_dtype + ) + plain_indices = torch.empty( + (*t.col_indices().shape[:-1], nnz), + device="meta", + dtype=index_dtype, + ) + else: + index_dtype = t.ccol_indices().dtype + compressed_indices = torch.empty( + t.ccol_indices().shape, device="meta", dtype=index_dtype + ) + plain_indices = torch.empty( + (*t.row_indices().shape[:-1], nnz), + device="meta", + dtype=index_dtype, + ) + values_shape = t.values().shape + values = torch.empty( + ( + *values_shape[:batch_dim], + nnz, + *values_shape[batch_dim + 1 :], + ), + dtype=t.dtype, + device="meta", + ) + return torch.ops.aten.sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + t.shape, + layout=t.layout, + dtype=t.dtype, + device="meta", + ) + + # `mk_meta()` is similar to `t.to(device='meta'))` + # except `to('meta')` preserves nnz value while + # `mk_meta` result has nnz == 0. + r = callback(mk_meta) + + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + with torch.enable_grad(): + r = r.clone() + elif t.is_nested and not is_traceable_wrapper_subclass(t): + # TODO: Handle this better in Dynamo? + # There are checks there now, but this can still be triggered by a dense + # tensor graph input that is a view of a strided NT. 
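+ # Example (informal): a dense view whose ._base is a strided-layout nested
+ # tensor (e.g. one built with torch.nested.nested_tensor using the default
+ # layout) reaches this branch when that base gets converted.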
+ from torch._dynamo.exc import unimplemented + + unimplemented( + "strided nested tensors are not supported by meta conversion" + ) + elif t.is_mkldnn: + is_leaf = safe_is_leaf(t) + sizes, strides, _storage_offset = sym_sizes_strides_storage_offset( + t, source + ) + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" + ) + ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + with torch.enable_grad(): + r = r.clone() + elif is_functorch_wrapped_tensor(t): + if t._is_view(): + from torch._dynamo.exc import unimplemented + + unimplemented( + "view functorch tensors are not supported by meta conversion" + ) + + # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) + # in a FakeTensor + def _to_fake_tensor(t): + if is_batchedtensor(t): + ft = _to_fake_tensor(get_unwrapped(t)) + lvl = maybe_get_level(t) + bdim = maybe_get_bdim(t) + r = _add_batch_dim(ft, bdim, lvl) + elif is_gradtrackingtensor(t): + disable_functorch = torch._C._DisableFuncTorch + with disable_functorch(): + ft = _to_fake_tensor(get_unwrapped(t)) + lvl = torch._C._functorch.maybe_get_level(t) + r = torch._C._functorch._wrap_for_grad(ft, lvl) + + is_leaf = safe_is_leaf(t) + if t.requires_grad and safe_is_leaf(r): + r.requires_grad = True + elif t.requires_grad and not is_leaf: + with torch.enable_grad(): + r = r.clone() + else: + sizes = t.size() + strides = t.stride() + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ) + ) + return r + + r = _to_fake_tensor(t) + + elif t._is_view(): + # Construct views in two steps: recursively meta-fy their + # base, and then create view(s) off that. NB: doing it + # directly from storage is WRONG because this won't cause + # version counters to get shared. + assert t._is_view() + + base_symbolic_context = None + if shape_env and symbolic_context is not None: + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + ) + + assert isinstance(symbolic_context, StatelessSymbolicContext) + # NB: This should generally be set when the input is a view, + # but the exception right now is for fake-ifying grads, which is + # a work in progress. + if symbolic_context.view_base_context is not None: + base_symbolic_context = symbolic_context.view_base_context + + base = self.meta_tensor( + t._base, + shape_env, + callback, + source=torch._dynamo.source.AttrSource(source, "_base"), + symbolic_context=base_symbolic_context, + ) + + def is_c_of_r(complex_dtype, real_dtype): + return ( + utils.is_complex_dtype(complex_dtype) + and utils.corresponding_real_dtype(complex_dtype) + == real_dtype + ) + + # In some situations, MetaConverter may be called in a + # context where autograd is disabled. For the _is_view + # assert to pass, we have to setup the autograd view + # metadata anyway. Do this by reenabling the + # ADInplaceOrView key. This is kind of a hack. + old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, False + ) + try: + if base.dtype == t.dtype: + pass + elif is_c_of_r(base.dtype, t.dtype): + base = torch.view_as_real(base) + elif is_c_of_r(t.dtype, base.dtype): + base = torch.view_as_complex(base) + else: + # This is not guaranteed to succeed. 
If it fails, it + # means there is another dtype-converting view function + # that hasn't been handled here + base = base.view(t.dtype) + + # This is very tricky. Naively, you might expect this + # to hold: + # + # if t.requires_grad and not safe_is_leaf(t) + # assert t._base.requires_grad + # + # But it's not true! As you can see in the following + # program: + # + # x = torch.zeros(4) + # y = x.view(1, 4) + # y.requires_grad = True + # z = y.view(1, 1, 4) + # assert z._base is x + # + # So we may have to do *two* views out of the base to + # recreate this situation. + if safe_is_leaf(t): + # Leaf views that track view metadata are created by + # creating a view inside a no_grad block + with torch.no_grad(), maybe_suppress(): + r = view_from_base(base, t) + # As it's a leaf, we can directly assign requires_grad + r.requires_grad = t.requires_grad + else: + if t._base.requires_grad == t.requires_grad: + # Easy case, just run the view op + with torch.enable_grad(), maybe_suppress(): + r = view_from_base(base, t) + + # NB: We don't actaully faithfully replicate + # autograd connectivity, but that doesn't matter + # today. See following for more info: + # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 + else: + # Obscure case. Create a leaf view and give it the + # correct requires_grad, then do the final view. + # NB: Can't have a non-leaf without requiring grad! + assert t.requires_grad + with torch.no_grad(): + mid = base.view(base.shape) + mid.requires_grad = t.requires_grad + with torch.enable_grad(), maybe_suppress(): + r = view_from_base(mid, t) + # The CreationMeta influences whether or not inplace + # mutation is an error or not. So we need to make + # sure we properly propagate this as well. + torch._C._autograd._set_creation_meta( + r, torch._C._autograd._get_creation_meta(t) + ) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old_exclude + ) + + else: + is_leaf = safe_is_leaf(t) + + ( + sizes, + strides, + storage_offset, + ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) + + # If we have a subclass that desugars into dense tensors, + # perform our callback on each inner tensor. + if is_traceable_wrapper_subclass(t): + r = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + else: + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ) + ) + + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = t.requires_grad + if not is_leaf: + # Fake up some autograd history. 
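+ # NB: cloning under enable_grad() below gives r a grad_fn, so the result is
+ # a non-leaf like t; the particular history does not matter, only that one
+ # exists.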
+ with torch.enable_grad(): + # preserve_format is the default, but we want to + # emphasize how important it is to preserve + # format here + r = r.clone(memory_format=torch.preserve_format) + + # Graph-Break for wrapped tensors + if not ( + is_batchedtensor(t) or is_gradtrackingtensor(t) + ) and torch._C._functorch.is_functorch_wrapped_tensor(t): + return NotImplemented + + s = t.untyped_storage() + swr = StorageWeakRef(s) + if swr not in self.storage_memo and ( + r.is_nested + or ( + r.stride() == strides + and r.storage_offset() == storage_offset + ) + ): + # You're normal and happy, install the fresh storage into the memo + self.storage_memo[swr] = r.untyped_storage() + else: + # You're in crazy town; somehow you gave us a tensor + # that wasn't a view, but had nonzero storage offset, + # nontrivial strides (such that clone() couldn't + # preserve them), or already aliases with another + # tensor's storage. The most typical way to end + # up here is with set_. So use set_ to bludgeon this + # in. + r_s = self.meta_storage(s, callback=callback) + # NB: In principle, this should always work, but there + # is some subtle difference in the autograd metadata + # that means we will backprop the set_ call, even if + # r is declared as an input to grad. + # See https://github.com/pytorch/pytorch/issues/87956 + # for the reproducer. + # NB: The in_kernel_invocation_manager here is necessary + # for fake tensor. If we run the set_ call with fake + # tensor on, r will improperly report that it is NOT a + # meta tensor but a cpu tensor, and then the set_ call + # will fail due to device mismatch. no_dispatch() is + # not enough, because the fake tensor will still claim + # to be a CPU tensor and you'll end up in the CPU + # kernel. Arguably this is a hack; a cleaner way to + # solve this is to have a FakeStorage concept which + # would report it's CPU device--no problem now! But + # this is difficult to do because we don't have storage + # subclasses. Relevant test is + # DynamicShapesFunctionTests::test_add_dynamic_shapes in + # test/dynamo/test_dynamic_shapes.py + maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() + from torch._subclasses.fake_tensor import ( + in_kernel_invocation_manager, + maybe_get_fake_mode, + ) + + mb_fake_mode = maybe_get_fake_mode(r) + if mb_fake_mode is not None: + maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) + with maybe_fake_mgr, torch.no_grad(): + r.set_(r_s, storage_offset, sizes, strides) + + if safe_grad(t) is not None: + from torch._dynamo.source import AttrSource + + # TODO: Use a valid grad-specific symbolic context instead of recycling + # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). + r.grad = self.meta_tensor( + safe_grad(t), + shape_env, + callback, + source=AttrSource(source, "grad"), + symbolic_context=symbolic_context, + ) + torch._C._set_conj(r, t.is_conj()) + torch._C._set_neg(r, t.is_neg()) + # This can be skipped if necessary for performance reasons + assert_metadata_eq(assert_eq, t, r, skip_symbolic=True) + self.set_tensor_memo(t, r) + + return self.get_tensor_memo(t) + + def __call__( + self, + t, + shape_env=None, + *, + callback=lambda t: t(), + source=None, + symbolic_context=None, + ): + # TODO: zero tensors? 
We appear to have eliminated them by + # excluding complex for now + + if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t): + if t.device.type != "xla" and any( + [ + t.is_quantized, + t._is_view() and t._base is not None and t._base.is_sparse, + torch._is_functional_tensor(t), + t.device.type in ("lazy"), + # We need a way to test if a tensor is batched but there + # is no official APi to do it + # torch._C._is_batched(t), + ] + ): + # TODO: sparse should support meta + # NB technically to('meta') does work but our logging + # instrumentation will see the meta conversions and the + # tests all break so we just exclude this. In any case + # the to conversion isn't really right anyhow. + + if torch._is_functional_tensor(t) and t.device.type != "lazy": + if t._is_view(): + raise RuntimeError( + "Cannot safely fakify a view because this process drops the view information right now." + ) + + st = peek_interpreter_stack() + assert ( + st is None or st.key() == TransformType.Functionalize + ), "Expect st to be either None or have Functionalize transform key." + if st is None: + # the case of AOTAutograd + torch._sync(t) + unwrap_t = torch._from_functional_tensor(t) + with torch._dispatch.python.suspend_functionalization(): + fake_t = self.meta_tensor( + unwrap_t, + shape_env=shape_env, + callback=callback, + source=source, + symbolic_context=symbolic_context, + ) + out = torch._to_functional_tensor(fake_t) + torch._mirror_autograd_meta_to(fake_t, out) + return out + else: + # torch.func.functionalize + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrap_t = _unwrap_functional_tensor(t, reapply_views) + pop_st_ctx = ( + torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack() + ) + with pop_st_ctx: + fake_t = self.meta_tensor( + unwrap_t, + shape_env=shape_env, + callback=callback, + source=source, + symbolic_context=symbolic_context, + ) + return _wrap_functional_tensor(fake_t, current_level()) + self.miss += 1 + return NotImplemented + else: + self.hit += 1 + + disable_functorch = torch._C._DisableFuncTorch + with disable_functorch(): + r = self.meta_tensor( + t, + shape_env=shape_env, + callback=callback, + source=source, + symbolic_context=symbolic_context, + ) + if type(t) is torch.nn.Parameter: + # NB: Cannot directly use Parameter constructor + # because that would force a detach, not desirable + r._is_param = True + return r + elif torch.overrides.is_tensor_like(t): + self.miss += 1 + return NotImplemented + else: + # non-Tensor types don't count as hit or miss + return t + + +import torch._prims_common as utils diff --git a/MLPY/Lib/site-packages/torch/_subclasses/schema_check_mode.py b/MLPY/Lib/site-packages/torch/_subclasses/schema_check_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..95b2ca093e6acf75ba5fb7ad1d136ccbda992666 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_subclasses/schema_check_mode.py @@ -0,0 +1,198 @@ +# mypy: ignore-errors + +from collections import namedtuple +from copy import deepcopy +from itertools import combinations + +import torch +from torch.fx.operator_schemas import normalize_function +from torch.testing._internal.jit_utils import clone_inputs +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map + +# Named Tuples used within SchemaCheckMode +Mutation = namedtuple("Mutation", ["op_name", "arg_name"]) +Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"]) + +# Simplified 
naming for C++ classes +SchemaArgument = torch._C._SchemaArgument +SchemaArgType = torch._C._SchemaArgType +SchemaInfo = torch._C._SchemaInfo + +# This TorchDispatchMode Subclass is used to verify op schemas +# This TorchDispatchMode Scubclass currently: +# - Records the called ops +# - Checks for mutations on all inputs +# - Checks for aliasing on all inputs + + +class SchemaCheckMode(TorchDispatchMode): + def __init__(self): + # Information recorded for testing purposes. For example: + # - incorrect schemas + # - overly conservative schemas + self.ops = [] + self.mutated = [] + self.aliasing = [] + + def reset_cache(self): + self.ops.clear() + self.mutated.clear() + self.aliasing.clear() + + def display_ops(self): + print(*self.ops, sep=",") + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def bitwise_equal(lhs, rhs): + if lhs.is_quantized: + # TODO: This is only OK if can't have NaN quantized; idk if + # this is actually true + return torch.equal(lhs, rhs) + else: + return torch.allclose(lhs, rhs, equal_nan=True) + + def has_mutated(before, after, md): + are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor + if ( + are_tensors + and before.layout != torch.sparse_csr + and after.layout != torch.sparse_csr + ): + return not ( + before.size() == after.size() + and bitwise_equal(before, after) + and md[0] == after.stride() + and md[1] == after._typed_storage()._cdata + ) + return False + + def has_aliased(lhs, rhs): + try: + return torch._C._overlaps(lhs, rhs) + except Exception as exception: + if str(exception).startswith("Cannot inspect value of type "): + return False + else: + raise exception + + def standardize_name(name): + return name if name != "self" else "input" + + def unwrap(e): + if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: + try: + return e.elem + except AttributeError as t: + return e + return e + + def parse_metadata(e): + if isinstance(e, torch.Tensor): + if not type(e) == torch.Tensor: + try: + current = e.elem + return ( + deepcopy(current.stride()), + current._typed_storage()._cdata, + ) + except AttributeError as t: + return None + # Sparse CSR tensors do not have strides or storage + elif e.layout != torch.sparse_csr: + return (deepcopy(e.stride()), e._typed_storage()._cdata) + return None + + self.ops.append(func._schema.name) + + # Clone and process arguments and outputs + pre_arguments = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ).kwargs + + c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values()))) + cloned_arguments = { + name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args + } + cloned_metadata = { + name: [ + parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name)) + ] + for name in pre_arguments + } + + out = func(*args, **kwargs) + arguments = { + name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments + } + tuple_out = out if isinstance(out, tuple) else (out,) + tuple_out = tree_map(unwrap, tuple_out) + + schema_info = SchemaInfo(func._schema) + schema_info.add_argument_values(pre_arguments) + + # Process arguments with outputs + for i in range(len(func._schema.arguments)): + arg = func._schema.arguments[i] + name = standardize_name(arg.name) + if arguments.get(name) is not None: + before = cloned_arguments.get(name) + md = cloned_metadata.get(name) + after = arguments.get(name) + for j in range(len(tuple_out)): + # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe) + 
unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split") + if ( + has_aliased(tuple_out[j], after) + and func._schema.name not in unsafe_ops + ): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, j), + SchemaArgument(SchemaArgType.input, i), + ): + raise RuntimeError( + f"Argument {name} is not defined to alias output but was aliasing" + ) + else: + self.aliasing.append( + Aliasing(func._schema.name, name, f"output_{j}") + ) + if after is tuple_out[j] and isinstance(after, torch.Tensor): + # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs. + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ) and func not in [ + torch.ops.aten.lift.default, + torch.ops.aten.lift_fresh.default, + ]: + raise RuntimeError( + f"""\ +Dispatcher operators below autograd are not allowed to directly return inputs. +However, we found that `outputs[{str(j)}] is {name}""" + ) + if any( + has_mutated(a, b, c) + for a, b, c in zip( + pytree.tree_leaves(before), pytree.tree_leaves(after), md + ) + ): + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ): + raise RuntimeError( + f"Argument {name} is not defined as mutable but was mutated" + ) + else: + self.mutated.append(Mutation(func._schema.name, name)) + + # Aliasing between outputs + for i, j in combinations(range(len(func._schema.returns)), 2): + if has_aliased(tuple_out[i], tuple_out[j]): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, i), + SchemaArgument(SchemaArgType.output, j), + ): + raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly") + + return out diff --git a/MLPY/Lib/site-packages/torch/_tensor.py b/MLPY/Lib/site-packages/torch/_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..85a466bd155bf9a40475f31af7ae32c3a45b3393 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_tensor.py @@ -0,0 +1,1543 @@ +import copyreg +import enum +import functools +import warnings +from collections import OrderedDict +from copy import deepcopy +from numbers import Number +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch._C as _C +import torch.utils.hooks as hooks +from torch._namedtensor_internals import ( + check_serializing_named_tensor, + is_ellipsis, + resolve_ellipsis, + single_ellipsis_index, + unzip_namedshape, + update_names, +) +from torch.overrides import ( + get_default_nowrap_functions, + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) +from torch.utils.dlpack import DLDeviceType + + +def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): + assigned = functools.WRAPPER_ASSIGNMENTS + + @functools.wraps(f, assigned=assigned) + def wrapped(*args, **kwargs): + try: + # See https://github.com/pytorch/pytorch/issues/75462 + if has_torch_function(args): + return handle_torch_function(wrapped, args, *args, **kwargs) + return f(*args, **kwargs) + except TypeError: + return NotImplemented + + return wrapped + + +# Should not be used, this is kept only for BC of loading old serialized Tensor subclasses +def _rebuild_from_type(func, type, args, dict): + if type is Tensor: + return func(*args) + + ret = func(*args).as_subclass(type) + ret.__dict__ = dict + return ret + + +def _rebuild_from_type_v2(func, new_type, args, state): + ret = func(*args) + if type(ret) is not new_type: + ret = ret.as_subclass(new_type) + # Tensor does define __setstate__ even though it doesn't define + # __getstate__. 
So only use __setstate__ if it is NOT the one defined + # on Tensor + if ( + getattr(ret.__class__, "__setstate__", Tensor.__setstate__) + is not Tensor.__setstate__ + ): + ret.__setstate__(state) + else: + ret = torch._utils._set_obj_state(ret, state) + return ret + + +# NB: If you subclass Tensor, and want to share the subclassed class +# across processes, you must also update torch/multiprocessing/reductions.py +# to define a ForkingPickler serialization mode for the class. +# +# NB: If you add a new method to Tensor, you must update +# torch/_C/__init__.pyi.in to add a type annotation for your method; +# otherwise, it will not show up in autocomplete. +class Tensor(torch._C.TensorBase): + def __deepcopy__(self, memo): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) + if not self.is_leaf: + raise RuntimeError( + "Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment. " + "If you were attempting to deepcopy a module, this may be because " + "of a torch.nn.utils.weight_norm usage, " + "see https://github.com/pytorch/pytorch/pull/103001" + ) + if id(self) in memo: + return memo[id(self)] + with torch.no_grad(): + # TODO: skipping storage copy is wrong for meta, as meta + # does accurate alias tracking; however, the code below + # doesn't work because of + # https://github.com/pytorch/pytorch/issues/47442 + # Update the test in test_serialization if you remove 'meta' from here + if ( + self.is_sparse + or self.device.type + in ["lazy", "xla", "mtia", "mps", "ort", "meta", "ipu"] + or ( + not torch._C._has_storage(self) + and self.device.type == torch._C._get_privateuse1_backend_name() + ) + or (type(self) is not Tensor and self.data_ptr() == 0) + ): + new_tensor = self.clone() + if type(new_tensor) is not type(self): + raise RuntimeError( + "The default implementation of __deepcopy__() for wrapper subclasses " + "only works for subclass types that implement clone() and for which " + "cloning returns another instance of the same subclass. You should either " + "properly implement clone() for your subclass or override __deepcopy__() " + "if it is intended behavior for clone() to return an instance of a " + "different type." 
+ ) + else: + new_storage = self._typed_storage()._deepcopy(memo) + if self.is_quantized: + # quantizer_params can be different type based on torch attribute + quantizer_params: Union[ + Tuple[torch.qscheme, float, int], + Tuple[torch.qscheme, Tensor, Tensor, int], + ] + if self.qscheme() == torch.per_tensor_affine: + quantizer_params = ( + self.qscheme(), + self.q_scale(), + self.q_zero_point(), + ) + elif self.qscheme() in ( + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ): + quantizer_params = ( + self.qscheme(), + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis(), + ) + else: + raise RuntimeError( + f"Unsupported qscheme {self.qscheme()} in deepcopy" + ) + # TODO: Once we decide to break serialization FC, no longer + # need to wrap with TypedStorage + new_tensor = torch._utils._rebuild_qtensor( + torch.storage.TypedStorage( + wrap_storage=new_storage._untyped_storage, + dtype=self.dtype, + _internal=True, + ), + self.storage_offset(), + self.size(), + self.stride(), + quantizer_params, + self.requires_grad, + self._backward_hooks, + ) + if type(new_tensor) is not type(self): + raise RuntimeError( + "The default implementation of __deepcopy__() for quantized tensors " + "expects the tensor returned by torch._utils._rebuild_qtensor() to " + "match the type of the instance being copied. If you encounter this, " + "please open an issue on PyTorch's GitHub." + ) + else: + new_tensor = self.new_empty([]) + if type(new_tensor) is not type(self): + raise RuntimeError( + "The default implementation of __deepcopy__() for non-wrapper subclasses " + "only works for subclass types that implement new_empty() and for which " + "that function returns another instance of the same subclass. You should " + "either properly implement new_empty() for your subclass or override " + "__deepcopy__() if it is intended behavior for new_empty() to return " + "an instance of a different type." + ) + new_tensor.set_( + new_storage, self.storage_offset(), self.size(), self.stride() + ) + if self.is_conj(): + new_tensor = new_tensor.conj_physical() + if self.is_neg(): + new_tensor = new_tensor.neg() + if self.requires_grad: + new_tensor.requires_grad_() + if self.grad is not None: + new_tensor.grad = self.grad.__deepcopy__(memo) + + if type(self) is not Tensor: + if type(new_tensor) is not type(self): + raise RuntimeError( + "Type of deepcopy result does not match the type of the source tensor. " + "If you encounter this, please open an issue on PyTorch's GitHub." + ) + + # Plain Tensors don't have slots + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo)) + + new_tensor.__dict__ = deepcopy(self.__dict__, memo) + + memo[id(self)] = new_tensor + return new_tensor + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + if type(self) is Tensor and not state: + # Fast path for regular tensor without Python state. + return self._reduce_ex_internal(proto) + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) + func, args = self._reduce_ex_internal(proto) + return (_rebuild_from_type_v2, (func, type(self), args, state)) + + def storage(self): + r""" + storage() -> torch.TypedStorage + + Returns the underlying :class:`TypedStorage`. + + .. warning:: + + :class:`TypedStorage` is deprecated. 
It will be removed in the future, and + :class:`UntypedStorage` will be the only storage class. To access the + :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.storage, (self,), self) + + torch.storage._warn_typed_storage_removal(stacklevel=2) + return self._typed_storage() + + # For internal use only, to avoid raising deprecation warning + def _typed_storage(self): + untyped_storage = self.untyped_storage() + return torch.TypedStorage( + wrap_storage=untyped_storage, dtype=self.dtype, _internal=True + ) + + def _reduce_ex_internal(self, proto): + check_serializing_named_tensor(self) + # See Note [Don't serialize hooks] + torch.utils.hooks.warn_if_has_hooks(self) + backward_hooks: Dict[Any, Any] = OrderedDict() + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, ORT Tensors. + # We considered a few options: + # 1. CPU tensor can't be used here. + # Otherwise in torch.load CPU storage is reconstructed with randomly + # initialized data, moved onto backend device, and then storage is updated + # to the serialized content. This works perfectly for CPU/CUDA but not these backends; + # their tensors are disconnected with storage so they don't get the update. + # 2. Python list is not a good fit due to performance reason. + # `tolist()` converts every single element in the tensor into python objects + # and serialize them one by one. + if self.device.type in ["xla", "mtia", "ort"] or ( + not torch._C._has_storage(self) + and self.device.type == torch._C._get_privateuse1_backend_name() + ): + # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't + # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, + # this would reconstruct the BFloat16 tensor from numpy. + numpy_tensor = ( + self.cpu().numpy() + if self.dtype != torch.bfloat16 + else self.cpu().to(torch.float32).numpy() + ) + return ( + torch._utils._rebuild_device_tensor_from_numpy, + (numpy_tensor, self.dtype, str(self.device), self.requires_grad), + ) + if self.device.type == "meta": + # NB: This implementation BREAKS storage sharing. Current + # hypothesis is that no one cares for meta tensors. 
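+ # Example (informal): two meta tensors that aliased the same storage before
+ # serialization come back as independent tensors, since only dtype, shape,
+ # stride and requires_grad are captured below.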
+ arg_meta = ( + self.dtype, + tuple(self.size()), + self.stride(), + self.requires_grad, + ) + return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) + if self.is_quantized: + # quantizer_params can be different type based on torch attribute + quantizer_params: Union[ + Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] + ] + if self.qscheme() == torch.per_tensor_affine: + quantizer_params = ( + torch.per_tensor_affine, + self.q_scale(), + self.q_zero_point(), + ) + elif self.qscheme() in ( + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ): + # convert scales and zero points to tuple to avoid recursive calls + # when/if we get multi-axis quantized tensors in the future, the shape + # is recoverable from the main tensor shape + quantizer_params = ( + torch.per_channel_affine, + self.q_per_channel_scales(), + self.q_per_channel_zero_points(), + self.q_per_channel_axis(), + ) + else: + raise RuntimeError( + f"Serialization is not supported for tensors of type {self.qscheme()}" + ) + # TODO: Once we decide to break serialization FC, no longer + # need to wrap with TypedStorage + args_qtensor = ( + torch.storage.TypedStorage( + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, + ), + self.storage_offset(), + tuple(self.size()), + self.stride(), + quantizer_params, + self.requires_grad, + backward_hooks, + ) + return (torch._utils._rebuild_qtensor, args_qtensor) + elif self.is_sparse: + if self.layout == torch.sparse_coo: + args_sparse = ( + self.layout, + (self._indices(), self._values(), self.size(), self.is_coalesced()), + ) + else: + raise NotImplementedError( + f"sparse tensor __reduce_ex__ for layout `{self.layout}`" + ) + return (torch._utils._rebuild_sparse_tensor, args_sparse) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + self.crow_indices(), + self.col_indices(), + ) + else: + compressed_indices, plain_indices = ( + self.ccol_indices(), + self.row_indices(), + ) + args_sparse_compressed = ( + self.layout, + ( + compressed_indices, + plain_indices, + self.values(), + self.size(), + ), + ) + return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) + elif self.is_nested: + args_nested = ( + # NB: values() currently returns the storage as a buffer in an unsafe way. + # Ideally, we'd use a private API for this instead. TODO: Switch to this if + # we ever get around to adding it. 
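+ # NB: the rebuild args below are (buffer from values(), nested sizes,
+ # nested strides, storage offsets); _rebuild_nested_tensor reassembles the
+ # nested tensor from these.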
+ self.values(), + self._nested_tensor_size(), + self._nested_tensor_strides(), + self._nested_tensor_storage_offsets(), + ) + return (torch._utils._rebuild_nested_tensor, args_nested) + elif ( + self.data_ptr() == 0 + and type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + else: + v3_dtypes = [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, + torch.bits8, + torch.bits16, + torch.bits1x8, + torch.bits2x4, + torch.bits4x2, + torch.complex32, + ] + if self.dtype in v3_dtypes: + rebuild_func = torch._utils._rebuild_tensor_v3 + storage = self.untyped_storage() + else: + # TODO: Once we decide to break serialization FC, no longer + # need to wrap with TypedStorage + rebuild_func = torch._utils._rebuild_tensor_v2 # type: ignore[assignment] + storage = torch.storage.TypedStorage( + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, + ) # type: ignore[assignment] + args = ( + storage, + self.storage_offset(), + tuple(self.size()), + self.stride(), + self.requires_grad, + backward_hooks, + ) # previously was self._backward_hooks + + if isinstance(storage, torch.storage.UntypedStorage): + args = args + (self.dtype,) # type: ignore[assignment] + + metadata = torch._utils.get_tensor_metadata(self) + if metadata: + args = args + (metadata,) # type: ignore[assignment] + + return (rebuild_func, args) + + def __setstate__(self, state): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__setstate__, (self,), self, state) + # Warning: this method is NOT called when you torch.load() a tensor; + # that is managed by _rebuild_tensor_v2 + if not self.is_leaf: + raise RuntimeError("__setstate__ can be only called on leaf Tensors") + if len(state) == 4: + # legacy serialization of Tensor + self.set_(*state) + return + elif len(state) == 5: + # legacy serialization of Variable + self.data = state[0] + state = (state[3], state[4], state[2]) + # The setting of _backward_hooks is expected to be a no-op. + # See Note [Don't serialize hooks] + self.requires_grad, _, self._backward_hooks = state + + def __repr__(self, *, tensor_contents=None): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.__repr__, (self,), self, tensor_contents=tensor_contents + ) + # All strings are unicode in Python 3. + return torch._tensor_str._str(self, tensor_contents=tensor_contents) + + def backward( + self, gradient=None, retain_graph=None, create_graph=False, inputs=None + ): + r"""Computes the gradient of current tensor wrt graph leaves. + + The graph is differentiated using the chain rule. If the tensor is + non-scalar (i.e. its data has more than one element) and requires + gradient, the function additionally requires specifying ``gradient``. + It should be a tensor of matching type and location, that contains + the gradient of the differentiated function w.r.t. ``self``. + + This function accumulates gradients in the leaves - you might need to zero + ``.grad`` attributes or set them to ``None`` before calling it. + See :ref:`Default gradient layouts` + for details on the memory layout of accumulated gradients. + + .. 
note:: + + If you run any forward ops, create ``gradient``, and/or call ``backward`` + in a user-specified CUDA stream context, see + :ref:`Stream semantics of backward passes`. + + .. note:: + + When ``inputs`` are provided and a given input is not a leaf, + the current implementation will call its grad_fn (though it is not strictly needed to get this gradients). + It is an implementation detail on which the user should not rely. + See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + Args: + gradient (Tensor or None): Gradient w.r.t. the + tensor. If it is a tensor, it will be automatically converted + to a Tensor that does not require grad unless ``create_graph`` is True. + None values can be specified for scalar Tensors or ones that + don't require grad. If a None value would be acceptable then + this argument is optional. + retain_graph (bool, optional): If ``False``, the graph used to compute + the grads will be freed. Note that in nearly all cases setting + this option to True is not needed and often can be worked around + in a much more efficient way. Defaults to the value of + ``create_graph``. + create_graph (bool, optional): If ``True``, graph of the derivative will + be constructed, allowing to compute higher order derivative + products. Defaults to ``False``. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the attr::tensors. + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.backward, + (self,), + self, + gradient=gradient, + retain_graph=retain_graph, + create_graph=create_graph, + inputs=inputs, + ) + torch.autograd.backward( + self, gradient, retain_graph, create_graph, inputs=inputs + ) + + def register_hook(self, hook): + r"""Registers a backward hook. + + The hook will be called every time a gradient with respect to the + Tensor is computed. The hook should have the following signature:: + + hook(grad) -> Tensor or None + + + The hook should not modify its argument, but it can optionally return + a new gradient which will be used in place of :attr:`grad`. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. + + Example:: + + >>> v = torch.tensor([0., 0., 0.], requires_grad=True) + >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient + >>> v.backward(torch.tensor([1., 2., 3.])) + >>> v.grad + + 2 + 4 + 6 + [torch.FloatTensor of size (3,)] + + >>> h.remove() # removes the hook + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.register_hook, (self,), self, hook) + if not self.requires_grad: + raise RuntimeError( + "cannot register a hook on a tensor that doesn't require gradient" + ) + if self._backward_hooks is None: + self._backward_hooks = OrderedDict() + if self.grad_fn is not None: + self.grad_fn._register_hook_dict(self) + handle = hooks.RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_post_accumulate_grad_hook(self, hook): + r"""Registers a backward hook that runs after grad accumulation. 
+ + The hook will be called after all gradients for a tensor have been accumulated, + meaning that the .grad field has been updated on that tensor. The post + accumulate grad hook is ONLY applicable for leaf tensors (tensors without a + .grad_fn field). Registering this hook on a non-leaf tensor will error! + + The hook should have the following signature:: + + hook(param: Tensor) -> None + + Note that, unlike other autograd hooks, this hook operates on the tensor + that requires grad and not the grad itself. The hook can in-place modify + and access its Tensor argument, including its .grad field. + + This function returns a handle with a method ``handle.remove()`` + that removes the hook from the module. + + .. note:: + See :ref:`backward-hooks-execution` for more information on how when this hook + is executed, and how its execution is ordered relative to other hooks. Since + this hook runs during the backward pass, it will run in no_grad mode (unless + create_graph is True). You can use torch.enable_grad() to re-enable autograd + within the hook if you need it. + + Example:: + + >>> v = torch.tensor([0., 0., 0.], requires_grad=True) + >>> lr = 0.01 + >>> # simulate a simple SGD update + >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) + >>> v.backward(torch.tensor([1., 2., 3.])) + >>> v + tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) + + >>> h.remove() # removes the hook + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.register_post_accumulate_grad_hook, (self,), self, hook + ) + if not self.requires_grad: + raise RuntimeError( + "cannot register a hook on a tensor that doesn't require gradient" + ) + if self.grad_fn is not None: + raise RuntimeError( + "post accumulate grad hooks cannot be registered on non-leaf tensors" + ) + if self._post_accumulate_grad_hooks is None: + self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict() + handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks) + self._post_accumulate_grad_hooks[handle.id] = hook + return handle + + def reinforce(self, reward): + def trim(str): + return "\n".join([line.strip() for line in str.split("\n")]) + + raise RuntimeError( + trim( + r"""reinforce() was removed. + Use torch.distributions instead. + See https://pytorch.org/docs/master/distributions.html + + Instead of: + + probs = policy_network(state) + action = probs.multinomial() + next_state, reward = env.step(action) + action.reinforce(reward) + action.backward() + + Use: + + probs = policy_network(state) + # NOTE: categorical is equivalent to what used to be called multinomial + m = torch.distributions.Categorical(probs) + action = m.sample() + next_state, reward = env.step(action) + loss = -m.log_prob(action) * reward + loss.backward() + """ + ) + ) + + detach = _C._add_docstr( + _C.TensorBase.detach, + r""" + Returns a new Tensor, detached from the current graph. + + The result will never require gradient. + + This method also affects forward mode AD gradients and the result will never + have forward mode AD gradients. + + .. note:: + + Returned Tensor shares the same storage with the original one. + In-place modifications on either of them will be seen, and may trigger + errors in correctness checks. + """, + ) + + detach_ = _C._add_docstr( + _C.TensorBase.detach_, + r""" + Detaches the Tensor from the graph that created it, making it a leaf. + Views cannot be detached in-place. 
+ + This method also affects forward mode AD gradients and the result will never + have forward mode AD gradients. + """, + ) + + def is_shared(self): + r"""Checks if tensor is in shared memory. + + This is always ``True`` for CUDA tensors. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.is_shared, (self,), self) + return self._typed_storage()._is_shared() + + def share_memory_(self): + r"""Moves the underlying storage to shared memory. + + This is a no-op if the underlying storage is already in shared memory + and for CUDA tensors. Tensors in shared memory cannot be resized. + + See :meth:`torch.UntypedStorage.share_memory_` for more details. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.share_memory_, (self,), self) + self._typed_storage()._share_memory_() + return self + + def module_load(self, other, assign=False): + r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`. + + Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. + + It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the + value in the state dictionary with the corresponding key, this method defines + how ``other`` is remapped before being swapped with ``self`` via + :func:`~torch.utils.swap_tensors`` in ``module.load_state_dict()``. + + .. note:: + This method should always return a new object that is not ``self`` or ``other``. + For example, the default implementation returns ``self.copy_(other).detach()`` + if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``. + + Args: + other (Tensor): value in state dict with key corresponding to ``self`` + assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict` + + """ + if has_torch_function_variadic(self, other): + return handle_torch_function( + Tensor.module_load, (self, other), self, other, assign=assign + ) + + if assign: + return other.detach() + else: + return self.copy_(other).detach() + + def __reversed__(self): + r"""Reverses the tensor along dimension 0.""" + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__reversed__, (self,), self) + if self.dim() == 0: + return self + else: + return self.flip(0) + + def norm( + self, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + dtype=None, + ): + r"""See :func:`torch.norm`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype + ) + return torch.norm(self, p, dim, keepdim, dtype=dtype) + + def solve(self, other): + from ._linalg_utils import solve + + return solve(self, other) + + def lstsq(self, other): + from ._linalg_utils import lstsq + + return lstsq(self, other) + + def eig(self, eigenvectors=False): + from ._linalg_utils import eig + + return eig(self, eigenvectors=eigenvectors) + + def symeig(self, eigenvectors=False): + from ._linalg_utils import _symeig + + return _symeig(self, eigenvectors=eigenvectors) + + def lu(self, pivot=True, get_infos=False): + r"""See :func:`torch.lu`""" + # If get_infos is True, then we don't need to check for errors and vice versa + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos + ) + + LU, pivots, infos = torch._lu_with_info( + self, pivot=pivot, check_errors=(not get_infos) + ) + if get_infos: + return LU, pivots, infos + else: + return 
LU, pivots + + def stft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + ): + r"""See :func:`torch.stft` + + .. warning:: + This function changed signature at version 0.4.1. Calling with + the previous signature may cause error or return incorrect result. + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.stft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + return_complex=return_complex, + ) + return torch.stft( + self, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + normalized, + onesided, + return_complex=return_complex, + ) + + def istft( + self, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: "Optional[Tensor]" = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex: bool = False, + ): + r"""See :func:`torch.istft`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.istft, + (self,), + self, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + return_complex=return_complex, + ) + return torch.istft( + self, + n_fft, + hop_length, + win_length, + window, + center, + normalized, + onesided, + length, + return_complex=return_complex, + ) + + def resize(self, *sizes): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.resize, (self,), self, *sizes) + warnings.warn("non-inplace resize is deprecated") + from torch.autograd._functions import Resize + + return Resize.apply(self, sizes) + + def resize_as(self, tensor): + if has_torch_function_variadic(self, tensor): + return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) + warnings.warn("non-inplace resize_as is deprecated") + from torch.autograd._functions import Resize + + return Resize.apply(self, tensor.size()) + + def split(self, split_size, dim=0): + r"""See :func:`torch.split`""" + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.split, (self,), self, split_size, dim=dim + ) + if isinstance(split_size, Tensor): + try: + split_size = int(split_size) + except ValueError: + pass + + if isinstance(split_size, (int, torch.SymInt)): + return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined] + else: + return torch._VF.split_with_sizes(self, split_size, dim) + + def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): + r"""Returns the unique elements of the input tensor. + + See :func:`torch.unique` + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.unique, + (self,), + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + return torch.unique( + self, + sorted=sorted, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + + def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): + r"""Eliminates all but the first element from every consecutive group of equivalent elements. 
+ + See :func:`torch.unique_consecutive` + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.unique_consecutive, + (self,), + self, + return_inverse=return_inverse, + return_counts=return_counts, + dim=dim, + ) + return torch.unique_consecutive( + self, return_inverse=return_inverse, return_counts=return_counts, dim=dim + ) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rsub__(self, other): + return _C._VariableFunctions.rsub(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rdiv__(self, other): + return self.reciprocal() * other + + __rtruediv__ = __rdiv__ + __itruediv__ = _C.TensorBase.__idiv__ + + __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C.TensorBase.pow + ) + __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( + _C.TensorBase.pow_ + ) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rmod__(self, other): + return torch.remainder(other, self) + + def __format__(self, format_spec): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__format__, (self,), self, format_spec) + if self.dim() == 0 and not self.is_meta and type(self) is Tensor: + return self.item().__format__(format_spec) + return object.__format__(self, format_spec) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rpow__(self, other): + return torch.pow(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __floordiv__(self, other): + return torch.floor_divide(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rfloordiv__(self, other): + return torch.floor_divide(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rlshift__(self, other): + return torch.bitwise_left_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rrshift__(self, other): + return torch.bitwise_right_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rmatmul__(self, other): + return torch.matmul(other, self) + + __pos__ = _C.TensorBase.positive + __neg__ = _C.TensorBase.neg + __abs__ = _C.TensorBase.abs + + def __len__(self): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__len__, (self,), self) + if self.dim() == 0: + raise TypeError("len() of a 0-d tensor") + if torch._C._get_tracing_state(): + warnings.warn( + "Using len to get tensor shape might cause the trace to be incorrect. " + "Recommended usage would be tensor.shape[0]. " + "Passing a tensor of different shape might lead to errors or silently give " + "incorrect results.", + category=torch.jit.TracerWarning, + stacklevel=2, + ) + return self.shape[0] + + def __iter__(self): + # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a + # generator and don't eagerly perform all the indexes. This could + # save us work, and also helps keep trace ordering deterministic + # (e.g., if you zip(*hiddens), the eager map will force all the + # indexes of hiddens[0] before hiddens[1], while the generator + # map will interleave them.) + # NB: We have intentionally skipped __torch_function__ dispatch here. + # See gh-54457 + if self.dim() == 0: + raise TypeError("iteration over a 0-d tensor") + if torch._C._get_tracing_state(): + warnings.warn( + "Iterating over a tensor might cause the trace to be incorrect. 
" + "Passing a tensor of different shape won't change the number of " + "iterations executed (and might lead to errors or silently give " + "incorrect results).", + category=torch.jit.TracerWarning, + stacklevel=2, + ) + return iter(self.unbind(0)) + + def __hash__(self): + # Do NOT handle __torch_function__ here as user's default + # implementation that handle most functions will most likely do it wrong. + # It can be easily overridden by defining this method on the user + # subclass if needed. + return id(self) + + def __dir__(self): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dir__, (self,), self) + tensor_methods = dir(self.__class__) + tensor_methods.remove("volatile") # deprecated + attrs = list(self.__dict__.keys()) + keys = tensor_methods + attrs + + # property only available dense, cuda tensors + if (not self.is_cuda) or self.is_sparse: + keys.remove("__cuda_array_interface__") + + return sorted(keys) + + # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` + __array_priority__ = 1000 # prefer Tensor ops over numpy ones + + def __array__(self, dtype=None): + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype) + if dtype is None: + return self.numpy() + else: + return self.numpy().astype(dtype, copy=False) + + # Wrap Numpy array again in a suitable tensor when done, to support e.g. + # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` + def __array_wrap__(self, array): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.__array_wrap__, (self,), self, array=array + ) + if array.dtype == bool: + # Workaround, torch has no built-in bool tensor + array = array.astype("uint8") + return torch.from_numpy(array) + + def __contains__(self, element): + r"""Check if `element` is present in tensor + + Args: + element (Tensor or scalar): element to be checked + for presence in current tensor" + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__contains__, (self,), self, element) + if isinstance( + element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool) + ): + # type hint doesn't understand the __contains__ result array + return (element == self).any().item() # type: ignore[union-attr] + + raise RuntimeError( + f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}." + ) + + @property + def __cuda_array_interface__(self): + """Array view description for cuda tensors. + + See: + https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html + """ + if has_torch_function_unary(self): + # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 + return handle_torch_function(Tensor.__cuda_array_interface__.__get__, (self,), self) # type: ignore[attr-defined] + + # raise AttributeError for unsupported tensors, so that + # hasattr(cpu_tensor, "__cuda_array_interface__") is False. + if not self.is_cuda: + raise AttributeError( + "Can't get __cuda_array_interface__ on non-CUDA tensor type: %s " + "If CUDA data is required use tensor.cuda() to copy tensor to device memory." + % self.type() + ) + + if self.is_sparse: + raise AttributeError( + "Can't get __cuda_array_interface__ on sparse type: %s " + "Use Tensor.to_dense() to convert to a dense tensor first." + % self.type() + ) + + # RuntimeError, matching tensor.__array__() behavior. 
+        if self.requires_grad:
+            raise RuntimeError(
+                "Can't get __cuda_array_interface__ on Variable that requires grad. "
+                "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
+            )
+
+        # CUDA devices are little-endian and tensors are stored in native byte
+        # order. 1-byte entries are endian-agnostic.
+        typestr = {
+            torch.complex64: "<c8",
+            torch.complex128: "<c16",
+            torch.bfloat16: "<f2",
+            torch.float16: "<f2",
+            torch.float32: "<f4",
+            torch.float64: "<f8",
+            torch.uint8: "|u1",
+            torch.int8: "|i1",
+            torch.int16: "<i2",
+            torch.int32: "<i4",
+            torch.int64: "<i8",
+            torch.bool: "|b1",
+        }[self.dtype]
+
+        itemsize = self.element_size()
+
+        shape = tuple(self.shape)
+        if self.is_contiguous():
+            # __cuda_array_interface__ v2 requires the strides to be omitted
+            # (either not set or set to None) for C-contiguous arrays.
+            strides = None
+        else:
+            strides = tuple(s * itemsize for s in self.stride())
+        data_ptr = self.data_ptr() if self.numel() > 0 else 0
+        data = (data_ptr, False)  # read-only is false
+
+        return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2)
+
+    def storage_type(self):
+        r"""storage_type() -> type
+
+        Returns the type of the underlying storage.
+
+        """
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.storage_type, (self,), self)
+
+        torch.storage._warn_typed_storage_removal()
+
+        return self._typed_storage()._get_legacy_storage_class()
+
+    def refine_names(self, *names):
+        r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
+
+        Refining is a special case of renaming that "lifts" unnamed dimensions.
+        A ``None`` dim can be refined to have any name; a named dim can only be
+        refined to have the same name.
+
+        Because named tensors can coexist with unnamed tensors, refining names
+        gives a nice way to write named-tensor-aware code that works with both
+        named and unnamed tensors.
+
+        :attr:`names` may contain up to one Ellipsis (``...``).
+        The Ellipsis is expanded greedily; it is expanded in-place to fill
+        :attr:`names` to the same length as ``self.dim()`` using names from the
+        corresponding indices of ``self.names``.
+
+        Python 2 does not support Ellipsis but one may use a string literal
+        instead (``'...'``).
+
+        Args:
+            names (iterable of str): The desired names of the output tensor. May
+                contain up to one Ellipsis.
+
+        Examples::
+
+            >>> imgs = torch.randn(32, 3, 128, 128)
+            >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
+            >>> named_imgs.names
+            ('N', 'C', 'H', 'W')
+
+            >>> tensor = torch.randn(2, 3, 5, 7, 11)
+            >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
+            >>> tensor.names
+            ('A', None, None, 'B', 'C')
+
+        .. warning::
+            The named tensor API is experimental and subject to change.
+
+        """
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.refine_names, (self,), self, *names)
+        names = resolve_ellipsis(names, self.names, "refine_names")
+        return super().refine_names(names)
+
+    def align_to(self, *names):
+        r"""Permutes the dimensions of the :attr:`self` tensor to match the order
+        specified in :attr:`names`, adding size-one dims for any new names.
+
+        All of the dims of :attr:`self` must be named in order to use this method.
+        The resulting tensor is a view on the original tensor.
+
+        All dimension names of :attr:`self` must be present in :attr:`names`.
+        :attr:`names` may contain additional names that are not in ``self.names``;
+        the output tensor has a size-one dimension for each of those new names.
+
+        :attr:`names` may contain up to one Ellipsis (``...``).
+        The Ellipsis is expanded to be equal to all dimension names of :attr:`self`
+        that are not mentioned in :attr:`names`, in the order that they appear
+        in :attr:`self`.
+
+        Python 2 does not support Ellipsis but one may use a string literal
+        instead (``'...'``).
+
+        Args:
+            names (iterable of str): The desired dimension ordering of the
+                output tensor. May contain up to one Ellipsis that is expanded
+                to all unmentioned dim names of :attr:`self`.
+ + Examples:: + + >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) + >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') + + # Move the F and E dims to the front while keeping the rest in order + >>> named_tensor.align_to('F', 'E', ...) + + .. warning:: + The named tensor API is experimental and subject to change. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.align_to, (self,), self, *names) + ellipsis_idx = single_ellipsis_index(names, "align_to") + if ellipsis_idx is None: + return super().align_to(names) + return super().align_to( + [name for name in names if not is_ellipsis(name)], ellipsis_idx + ) + + def unflatten(self, dim, sizes): + r""" + unflatten(dim, sizes) -> Tensor + + See :func:`torch.unflatten`. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes) + + if not sizes: + raise RuntimeError("unflatten: sizes must be non-empty") + + names = None + if isinstance(sizes, OrderedDict) or ( + isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list)) + ): + names, sizes = unzip_namedshape(sizes) + return super().unflatten(dim, sizes, names) + else: + return super().unflatten(dim, sizes) + + def rename_(self, *names, **rename_map): + """In-place version of :meth:`~Tensor.rename`.""" + + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.rename_, (self,), self, *names, **rename_map + ) + + # Note [rename_ / rename API] + # The Python API for these is different from the C++ API. In Python: + # 1) tensor.rename(*names) takes a vararglist of names + # 2) tensor.rename(**rename_map) takes a map of names to rename. + # C++ is static, making it difficult to implement similar behavior. + return update_names(self, names, rename_map, inplace=True) + + def rename(self, *names, **rename_map): + """Renames dimension names of :attr:`self`. + + There are two main usages: + + ``self.rename(**rename_map)`` returns a view on tensor that has dims + renamed as specified in the mapping :attr:`rename_map`. + + ``self.rename(*names)`` returns a view on tensor, renaming all + dimensions positionally using :attr:`names`. + Use ``self.rename(None)`` to drop names on a tensor. + + One cannot specify both positional args :attr:`names` and keyword args + :attr:`rename_map`. + + Examples:: + + >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) + >>> renamed_imgs = imgs.rename(N='batch', C='channels') + >>> renamed_imgs.names + ('batch', 'channels', 'H', 'W') + + >>> renamed_imgs = imgs.rename(None) + >>> renamed_imgs.names + (None, None, None, None) + + >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width') + >>> renamed_imgs.names + ('batch', 'channel', 'height', 'width') + + .. warning:: + The named tensor API is experimental and subject to change. + + """ + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.rename, (self,), self, *names, **rename_map + ) + + # See Note [rename_ / rename API] + return update_names(self, names, rename_map, inplace=False) + + def to_sparse_coo(self): + """Convert a tensor to :ref:`coordinate format `. + + Examples:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_coo() + >>> sparse._nnz() + 25 + + """ + return self.to_sparse() + + def dim_order(self): + """ + + dim_order() -> tuple + + Returns a tuple of int describing the dim order or physical layout of :attr:`self`. 
+ + Args: + None + + Dim order represents how dimensions are laid out in memory, + starting from the outermost to the innermost dimension. + + Example:: + >>> torch.empty((2, 3, 5, 7)).dim_order() + (0, 1, 2, 3) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order() + (0, 2, 3, 1) + + .. warning:: + The dim_order tensor API is experimental and subject to change. + + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.dim_order, (self,), self) + + import torch._prims_common as utils + + return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self)) + + def _update_names(self, names, inplace): + if has_torch_function_unary(self): + return handle_torch_function( + Tensor._update_names, (self,), self, names, inplace + ) + + # See Note [rename_ / rename API] + if inplace: + return super().rename_(names) + else: + return super().rename(names) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + This __torch_function__ implementation wraps subclasses such that + methods called on subclasses return a subclass instance instead of + a ``torch.Tensor`` instance. + + One corollary to this is that you need coverage for torch.Tensor + methods if implementing __torch_function__ for subclasses. + + We recommend always calling ``super().__torch_function__`` as the base + case when doing the above. + + While not mandatory, we recommend making `__torch_function__` a classmethod. + """ + if kwargs is None: + kwargs = {} + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + with _C.DisableTorchFunctionSubclass(): + ret = func(*args, **kwargs) + if func in get_default_nowrap_functions(): + return ret + else: + return _convert(ret, cls) + + __torch_dispatch__ = _C._disabled_torch_dispatch_impl + + def __dlpack__(self, stream=None): + """ + Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ + of the current tensor to be exported to other libraries. + + This function will be called from the `from_dlpack` method + of the library that will consume the capsule. `from_dlpack` passes the current + stream to this method as part of the specification. + + Args: + stream (integer or None): An optional Python integer representing a + pointer to a CUDA stream. The current stream is synchronized with + this stream before the capsule is created, and since the capsule + shares its storage with the tensor this make it safe to access from + both streams. If None or -1 is passed then no synchronization is performed. + If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for + synchronization. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) + + # DLPack capsules can't capture all of PyTorch's semantics, + # so we prohibit exporting tensors that would lose their properties like + # requires_grad and having the conjugate bit set. + if self.requires_grad: + raise RuntimeError( + "Can't export tensors that require gradient, use tensor.detach()" + ) + if self.is_conj(): + raise RuntimeError("Can't export tensors with the conjugate bit set") + if self.layout != torch.strided: + raise RuntimeError( + "Can't export tensors with layout other than torch.strided" + ) + + if stream is not None and type(stream) is not int: + # Stream pointers in CUDA/ROCm are uniquely numbered and can + # be retrieved from their integer value. 
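+            # For context, a minimal round trip that exercises this method might
+            # look like the following sketch (illustrative; assumes a CUDA device):
+            #
+            #     x = torch.arange(3, device="cuda")
+            #     raw = torch.cuda.current_stream().cuda_stream  # stream as an int
+            #     capsule = x.__dlpack__(stream=raw)
+            #     y = torch.utils.dlpack.from_dlpack(capsule)    # shares x's memory
+            #
+            # Anything other than an int (or None) for ``stream`` is rejected here: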
+ raise TypeError("stream must be ``int`` or ``none``") + elif stream is not None and stream != -1: + if self.device.type == "cuda": + # NB: This logic handles the special case values for default + # streams and must be kept in sync with from_dlpack in + # torch/utils/dlpack.py + if stream == 1 and torch.version.hip is None: + stream = torch.cuda.default_stream() + elif stream == 0 and torch.version.hip is not None: + stream = torch.cuda.default_stream() + else: + stream = torch.cuda.ExternalStream(stream) + # Only synchronize on different streams + sync_stream = torch.cuda.current_stream() + if stream != sync_stream: + event = torch.cuda.Event() + event.record(sync_stream) + stream.wait_event(event) + return torch.to_dlpack(self) + + def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dlpack_device__, (self,), self) + device = self.device + idx = device.index if device.index is not None else 0 + torch_device_type = device.type + if torch_device_type == "cuda" and torch.version.hip is not None: + device_type = DLDeviceType.kDLROCM + elif torch_device_type == "cpu" and self.is_pinned(): + device_type = DLDeviceType.kDLCPUPinned + elif torch_device_type == "cuda": + device_type = DLDeviceType.kDLGPU + elif torch_device_type == "cpu": + device_type = DLDeviceType.kDLCPU + elif self.device.type == "xpu": + device_type = DLDeviceType.kDLOneAPI + else: + raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") + return (device_type, idx) + + __module__ = "torch" + + +def _convert(ret, cls): + if cls is Tensor: + return ret + + if isinstance(ret, Tensor) and not isinstance(ret, cls): + ret = ret.as_subclass(cls) + + if isinstance(ret, (tuple, list)): + # Also handles things like namedtuples + ret = type(ret)(_convert(r, cls) for r in ret) + + return ret diff --git a/MLPY/Lib/site-packages/torch/_tensor_docs.py b/MLPY/Lib/site-packages/torch/_tensor_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fa09eaf864be4301db09f40bbe854d702e4555 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_tensor_docs.py @@ -0,0 +1,6976 @@ +"""Adds docstrings to Tensor functions""" + +import torch._C +from torch._C import _add_docstr as add_docstr +from torch._torch_docs import parse_kwargs, reproducibility_notes + + +def add_docstr_all(method, docstr): + add_docstr(getattr(torch._C.TensorBase, method), docstr) + + +common_args = parse_kwargs( + """ + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. +""" +) + +new_common_args = parse_kwargs( + """ + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+""" +) + +add_docstr_all( + "new_tensor", + """ +new_tensor(data, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a new Tensor with :attr:`data` as the tensor data. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +.. warning:: + + :func:`new_tensor` always copies :attr:`data`. If you have a Tensor + ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` + or :func:`torch.Tensor.detach`. + If you have a numpy array and want to avoid a copy, use + :func:`torch.from_numpy`. + +.. warning:: + + When data is a tensor `x`, :func:`new_tensor()` reads out 'the data' from whatever it is passed, + and constructs a leaf variable. Therefore ``tensor.new_tensor(x)`` is equivalent to ``x.clone().detach()`` + and ``tensor.new_tensor(x, requires_grad=True)`` is equivalent to ``x.clone().detach().requires_grad_(True)``. + The equivalents using ``clone()`` and ``detach()`` are recommended. + +Args: + data (array_like): The returned Tensor copies :attr:`data`. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones((2,), dtype=torch.int8) + >>> data = [[0, 1], [2, 3]] + >>> tensor.new_tensor(data) + tensor([[ 0, 1], + [ 2, 3]], dtype=torch.int8) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "new_full", + """ +new_full(size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + fill_value (scalar): the number to fill the output tensor with. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones((2,), dtype=torch.float64) + >>> tensor.new_full((3, 4), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416]], dtype=torch.float64) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "new_empty", + """ +new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with uninitialized data. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "new_empty_strided", + """ +new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with +uninitialized data. By default, the returned Tensor has the same +:class:`torch.dtype` and :class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. 
+ +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty_strided((2, 3), (3, 1)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "new_ones", + """ +new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with ``1``. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "new_zeros", + """ +new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a Tensor of size :attr:`size` filled with ``0``. +By default, the returned Tensor has the same :class:`torch.dtype` and +:class:`torch.device` as this tensor. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + +Keyword args: + {dtype} + {device} + {requires_grad} + {layout} + {pin_memory} + +Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + +""".format( + **new_common_args + ), +) + +add_docstr_all( + "abs", + r""" +abs() -> Tensor + +See :func:`torch.abs` +""", +) + +add_docstr_all( + "abs_", + r""" +abs_() -> Tensor + +In-place version of :meth:`~Tensor.abs` +""", +) + +add_docstr_all( + "absolute", + r""" +absolute() -> Tensor + +Alias for :func:`abs` +""", +) + +add_docstr_all( + "absolute_", + r""" +absolute_() -> Tensor + +In-place version of :meth:`~Tensor.absolute` +Alias for :func:`abs_` +""", +) + +add_docstr_all( + "acos", + r""" +acos() -> Tensor + +See :func:`torch.acos` +""", +) + +add_docstr_all( + "acos_", + r""" +acos_() -> Tensor + +In-place version of :meth:`~Tensor.acos` +""", +) + +add_docstr_all( + "arccos", + r""" +arccos() -> Tensor + +See :func:`torch.arccos` +""", +) + +add_docstr_all( + "arccos_", + r""" +arccos_() -> Tensor + +In-place version of :meth:`~Tensor.arccos` +""", +) + +add_docstr_all( + "acosh", + r""" +acosh() -> Tensor + +See :func:`torch.acosh` +""", +) + +add_docstr_all( + "acosh_", + r""" +acosh_() -> Tensor + +In-place version of :meth:`~Tensor.acosh` +""", +) + +add_docstr_all( + "arccosh", + r""" +acosh() -> Tensor + +See :func:`torch.arccosh` +""", +) + +add_docstr_all( + "arccosh_", + r""" +acosh_() -> Tensor + +In-place version of :meth:`~Tensor.arccosh` +""", +) + +add_docstr_all( + "add", + r""" +add(other, *, alpha=1) -> Tensor + +Add a scalar or tensor to :attr:`self` tensor. If both :attr:`alpha` +and :attr:`other` are specified, each element of :attr:`other` is scaled by +:attr:`alpha` before being used. 
+ +When :attr:`other` is a tensor, the shape of :attr:`other` must be +:ref:`broadcastable ` with the shape of the underlying +tensor + +See :func:`torch.add` +""", +) + +add_docstr_all( + "add_", + r""" +add_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.add` +""", +) + +add_docstr_all( + "addbmm", + r""" +addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addbmm` +""", +) + +add_docstr_all( + "addbmm_", + r""" +addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addbmm` +""", +) + +add_docstr_all( + "addcdiv", + r""" +addcdiv(tensor1, tensor2, *, value=1) -> Tensor + +See :func:`torch.addcdiv` +""", +) + +add_docstr_all( + "addcdiv_", + r""" +addcdiv_(tensor1, tensor2, *, value=1) -> Tensor + +In-place version of :meth:`~Tensor.addcdiv` +""", +) + +add_docstr_all( + "addcmul", + r""" +addcmul(tensor1, tensor2, *, value=1) -> Tensor + +See :func:`torch.addcmul` +""", +) + +add_docstr_all( + "addcmul_", + r""" +addcmul_(tensor1, tensor2, *, value=1) -> Tensor + +In-place version of :meth:`~Tensor.addcmul` +""", +) + +add_docstr_all( + "addmm", + r""" +addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addmm` +""", +) + +add_docstr_all( + "addmm_", + r""" +addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addmm` +""", +) + +add_docstr_all( + "addmv", + r""" +addmv(mat, vec, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addmv` +""", +) + +add_docstr_all( + "addmv_", + r""" +addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addmv` +""", +) + +add_docstr_all( + "sspaddmm", + r""" +sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.sspaddmm` +""", +) + +add_docstr_all( + "smm", + r""" +smm(mat) -> Tensor + +See :func:`torch.smm` +""", +) + +add_docstr_all( + "addr", + r""" +addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.addr` +""", +) + +add_docstr_all( + "addr_", + r""" +addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.addr` +""", +) + +add_docstr_all( + "align_as", + r""" +align_as(other) -> Tensor + +Permutes the dimensions of the :attr:`self` tensor to match the dimension order +in the :attr:`other` tensor, adding size-one dims for any new names. + +This operation is useful for explicit broadcasting by names (see examples). + +All of the dims of :attr:`self` must be named in order to use this method. +The resulting tensor is a view on the original tensor. + +All dimension names of :attr:`self` must be present in ``other.names``. +:attr:`other` may contain named dimensions that are not in ``self.names``; +the output tensor has a size-one dimension for each of those new names. + +To align a tensor to a specific order, use :meth:`~Tensor.align_to`. 
+ +Examples:: + + # Example 1: Applying a mask + >>> mask = torch.randint(2, [127, 128], dtype=torch.bool).refine_names('W', 'H') + >>> imgs = torch.randn(32, 128, 127, 3, names=('N', 'H', 'W', 'C')) + >>> imgs.masked_fill_(mask.align_as(imgs), 0) + + + # Example 2: Applying a per-channel-scale + >>> def scale_channels(input, scale): + >>> scale = scale.refine_names('C') + >>> return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names=('C',)) + >>> imgs = torch.rand(32, 128, 128, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(32, num_channels, 128, 128, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 128, 128, 128, names=('N', 'C', 'H', 'W', 'D')) + + # scale_channels is agnostic to the dimension order of the input + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + +.. warning:: + The named tensor API is experimental and subject to change. + +""", +) + +add_docstr_all( + "all", + r""" +all(dim=None, keepdim=False) -> Tensor + +See :func:`torch.all` +""", +) + +add_docstr_all( + "allclose", + r""" +allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +See :func:`torch.allclose` +""", +) + +add_docstr_all( + "angle", + r""" +angle() -> Tensor + +See :func:`torch.angle` +""", +) + +add_docstr_all( + "any", + r""" +any(dim=None, keepdim=False) -> Tensor + +See :func:`torch.any` +""", +) + +add_docstr_all( + "apply_", + r""" +apply_(callable) -> Tensor + +Applies the function :attr:`callable` to each element in the tensor, replacing +each element with the value returned by :attr:`callable`. + +.. note:: + + This function only works with CPU tensors and should not be used in code + sections that require high performance. 
+""", +) + +add_docstr_all( + "asin", + r""" +asin() -> Tensor + +See :func:`torch.asin` +""", +) + +add_docstr_all( + "asin_", + r""" +asin_() -> Tensor + +In-place version of :meth:`~Tensor.asin` +""", +) + +add_docstr_all( + "arcsin", + r""" +arcsin() -> Tensor + +See :func:`torch.arcsin` +""", +) + +add_docstr_all( + "arcsin_", + r""" +arcsin_() -> Tensor + +In-place version of :meth:`~Tensor.arcsin` +""", +) + +add_docstr_all( + "asinh", + r""" +asinh() -> Tensor + +See :func:`torch.asinh` +""", +) + +add_docstr_all( + "asinh_", + r""" +asinh_() -> Tensor + +In-place version of :meth:`~Tensor.asinh` +""", +) + +add_docstr_all( + "arcsinh", + r""" +arcsinh() -> Tensor + +See :func:`torch.arcsinh` +""", +) + +add_docstr_all( + "arcsinh_", + r""" +arcsinh_() -> Tensor + +In-place version of :meth:`~Tensor.arcsinh` +""", +) + +add_docstr_all( + "as_strided", + r""" +as_strided(size, stride, storage_offset=None) -> Tensor + +See :func:`torch.as_strided` +""", +) + +add_docstr_all( + "as_strided_", + r""" +as_strided_(size, stride, storage_offset=None) -> Tensor + +In-place version of :meth:`~Tensor.as_strided` +""", +) + +add_docstr_all( + "atan", + r""" +atan() -> Tensor + +See :func:`torch.atan` +""", +) + +add_docstr_all( + "atan_", + r""" +atan_() -> Tensor + +In-place version of :meth:`~Tensor.atan` +""", +) + +add_docstr_all( + "arctan", + r""" +arctan() -> Tensor + +See :func:`torch.arctan` +""", +) + +add_docstr_all( + "arctan_", + r""" +arctan_() -> Tensor + +In-place version of :meth:`~Tensor.arctan` +""", +) + +add_docstr_all( + "atan2", + r""" +atan2(other) -> Tensor + +See :func:`torch.atan2` +""", +) + +add_docstr_all( + "atan2_", + r""" +atan2_(other) -> Tensor + +In-place version of :meth:`~Tensor.atan2` +""", +) + +add_docstr_all( + "arctan2", + r""" +arctan2(other) -> Tensor + +See :func:`torch.arctan2` +""", +) + +add_docstr_all( + "arctan2_", + r""" +atan2_(other) -> Tensor + +In-place version of :meth:`~Tensor.arctan2` +""", +) + +add_docstr_all( + "atanh", + r""" +atanh() -> Tensor + +See :func:`torch.atanh` +""", +) + +add_docstr_all( + "atanh_", + r""" +atanh_(other) -> Tensor + +In-place version of :meth:`~Tensor.atanh` +""", +) + +add_docstr_all( + "arctanh", + r""" +arctanh() -> Tensor + +See :func:`torch.arctanh` +""", +) + +add_docstr_all( + "arctanh_", + r""" +arctanh_(other) -> Tensor + +In-place version of :meth:`~Tensor.arctanh` +""", +) + +add_docstr_all( + "baddbmm", + r""" +baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +See :func:`torch.baddbmm` +""", +) + +add_docstr_all( + "baddbmm_", + r""" +baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.baddbmm` +""", +) + +add_docstr_all( + "bernoulli", + r""" +bernoulli(*, generator=None) -> Tensor + +Returns a result tensor where each :math:`\texttt{result[i]}` is independently +sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have +floating point ``dtype``, and the result will have the same ``dtype``. + +See :func:`torch.bernoulli` +""", +) + +add_docstr_all( + "bernoulli_", + r""" +bernoulli_(p=0.5, *, generator=None) -> Tensor + +Fills each location of :attr:`self` with an independent sample from +:math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral +``dtype``. + +:attr:`p` should either be a scalar or tensor containing probabilities to be +used for drawing the binary random number. 
+ +If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor +will be set to a value sampled from +:math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have +floating point ``dtype``. + +See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` +""", +) + +add_docstr_all( + "bincount", + r""" +bincount(weights=None, minlength=0) -> Tensor + +See :func:`torch.bincount` +""", +) + +add_docstr_all( + "bitwise_not", + r""" +bitwise_not() -> Tensor + +See :func:`torch.bitwise_not` +""", +) + +add_docstr_all( + "bitwise_not_", + r""" +bitwise_not_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_not` +""", +) + +add_docstr_all( + "bitwise_and", + r""" +bitwise_and() -> Tensor + +See :func:`torch.bitwise_and` +""", +) + +add_docstr_all( + "bitwise_and_", + r""" +bitwise_and_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_and` +""", +) + +add_docstr_all( + "bitwise_or", + r""" +bitwise_or() -> Tensor + +See :func:`torch.bitwise_or` +""", +) + +add_docstr_all( + "bitwise_or_", + r""" +bitwise_or_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_or` +""", +) + +add_docstr_all( + "bitwise_xor", + r""" +bitwise_xor() -> Tensor + +See :func:`torch.bitwise_xor` +""", +) + +add_docstr_all( + "bitwise_xor_", + r""" +bitwise_xor_() -> Tensor + +In-place version of :meth:`~Tensor.bitwise_xor` +""", +) + +add_docstr_all( + "bitwise_left_shift", + r""" +bitwise_left_shift(other) -> Tensor + +See :func:`torch.bitwise_left_shift` +""", +) + +add_docstr_all( + "bitwise_left_shift_", + r""" +bitwise_left_shift_(other) -> Tensor + +In-place version of :meth:`~Tensor.bitwise_left_shift` +""", +) + +add_docstr_all( + "bitwise_right_shift", + r""" +bitwise_right_shift(other) -> Tensor + +See :func:`torch.bitwise_right_shift` +""", +) + +add_docstr_all( + "bitwise_right_shift_", + r""" +bitwise_right_shift_(other) -> Tensor + +In-place version of :meth:`~Tensor.bitwise_right_shift` +""", +) + +add_docstr_all( + "broadcast_to", + r""" +broadcast_to(shape) -> Tensor + +See :func:`torch.broadcast_to`. +""", +) + +add_docstr_all( + "logical_and", + r""" +logical_and() -> Tensor + +See :func:`torch.logical_and` +""", +) + +add_docstr_all( + "logical_and_", + r""" +logical_and_() -> Tensor + +In-place version of :meth:`~Tensor.logical_and` +""", +) + +add_docstr_all( + "logical_not", + r""" +logical_not() -> Tensor + +See :func:`torch.logical_not` +""", +) + +add_docstr_all( + "logical_not_", + r""" +logical_not_() -> Tensor + +In-place version of :meth:`~Tensor.logical_not` +""", +) + +add_docstr_all( + "logical_or", + r""" +logical_or() -> Tensor + +See :func:`torch.logical_or` +""", +) + +add_docstr_all( + "logical_or_", + r""" +logical_or_() -> Tensor + +In-place version of :meth:`~Tensor.logical_or` +""", +) + +add_docstr_all( + "logical_xor", + r""" +logical_xor() -> Tensor + +See :func:`torch.logical_xor` +""", +) + +add_docstr_all( + "logical_xor_", + r""" +logical_xor_() -> Tensor + +In-place version of :meth:`~Tensor.logical_xor` +""", +) + +add_docstr_all( + "bmm", + r""" +bmm(batch2) -> Tensor + +See :func:`torch.bmm` +""", +) + +add_docstr_all( + "cauchy_", + r""" +cauchy_(median=0, sigma=1, *, generator=None) -> Tensor + +Fills the tensor with numbers drawn from the Cauchy distribution: + +.. math:: + + f(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2} + +.. note:: + Sigma (:math:`\sigma`) is used to denote the scale parameter in Cauchy distribution. 
+""", +) + +add_docstr_all( + "ceil", + r""" +ceil() -> Tensor + +See :func:`torch.ceil` +""", +) + +add_docstr_all( + "ceil_", + r""" +ceil_() -> Tensor + +In-place version of :meth:`~Tensor.ceil` +""", +) + +add_docstr_all( + "cholesky", + r""" +cholesky(upper=False) -> Tensor + +See :func:`torch.cholesky` +""", +) + +add_docstr_all( + "cholesky_solve", + r""" +cholesky_solve(input2, upper=False) -> Tensor + +See :func:`torch.cholesky_solve` +""", +) + +add_docstr_all( + "cholesky_inverse", + r""" +cholesky_inverse(upper=False) -> Tensor + +See :func:`torch.cholesky_inverse` +""", +) + +add_docstr_all( + "clamp", + r""" +clamp(min=None, max=None) -> Tensor + +See :func:`torch.clamp` +""", +) + +add_docstr_all( + "clamp_", + r""" +clamp_(min=None, max=None) -> Tensor + +In-place version of :meth:`~Tensor.clamp` +""", +) + +add_docstr_all( + "clip", + r""" +clip(min=None, max=None) -> Tensor + +Alias for :meth:`~Tensor.clamp`. +""", +) + +add_docstr_all( + "clip_", + r""" +clip_(min=None, max=None) -> Tensor + +Alias for :meth:`~Tensor.clamp_`. +""", +) + +add_docstr_all( + "clone", + r""" +clone(*, memory_format=torch.preserve_format) -> Tensor + +See :func:`torch.clone` +""".format( + **common_args + ), +) + +add_docstr_all( + "coalesce", + r""" +coalesce() -> Tensor + +Returns a coalesced copy of :attr:`self` if :attr:`self` is an +:ref:`uncoalesced tensor `. + +Returns :attr:`self` if :attr:`self` is a coalesced tensor. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. +""", +) + +add_docstr_all( + "contiguous", + r""" +contiguous(memory_format=torch.contiguous_format) -> Tensor + +Returns a contiguous in memory tensor containing the same data as :attr:`self` tensor. If +:attr:`self` tensor is already in the specified memory format, this function returns the +:attr:`self` tensor. + +Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. +""", +) + +add_docstr_all( + "copy_", + r""" +copy_(src, non_blocking=False) -> Tensor + +Copies the elements from :attr:`src` into :attr:`self` tensor and returns +:attr:`self`. + +The :attr:`src` tensor must be :ref:`broadcastable ` +with the :attr:`self` tensor. It may be of a different data type or reside on a +different device. + +Args: + src (Tensor): the source tensor to copy from + non_blocking (bool): if ``True`` and this copy is between CPU and GPU, + the copy may occur asynchronously with respect to the host. For other + cases, this argument has no effect. 
+""", +) + +add_docstr_all( + "conj", + r""" +conj() -> Tensor + +See :func:`torch.conj` +""", +) + +add_docstr_all( + "conj_physical", + r""" +conj_physical() -> Tensor + +See :func:`torch.conj_physical` +""", +) + +add_docstr_all( + "conj_physical_", + r""" +conj_physical_() -> Tensor + +In-place version of :meth:`~Tensor.conj_physical` +""", +) + +add_docstr_all( + "resolve_conj", + r""" +resolve_conj() -> Tensor + +See :func:`torch.resolve_conj` +""", +) + +add_docstr_all( + "resolve_neg", + r""" +resolve_neg() -> Tensor + +See :func:`torch.resolve_neg` +""", +) + +add_docstr_all( + "copysign", + r""" +copysign(other) -> Tensor + +See :func:`torch.copysign` +""", +) + +add_docstr_all( + "copysign_", + r""" +copysign_(other) -> Tensor + +In-place version of :meth:`~Tensor.copysign` +""", +) + +add_docstr_all( + "cos", + r""" +cos() -> Tensor + +See :func:`torch.cos` +""", +) + +add_docstr_all( + "cos_", + r""" +cos_() -> Tensor + +In-place version of :meth:`~Tensor.cos` +""", +) + +add_docstr_all( + "cosh", + r""" +cosh() -> Tensor + +See :func:`torch.cosh` +""", +) + +add_docstr_all( + "cosh_", + r""" +cosh_() -> Tensor + +In-place version of :meth:`~Tensor.cosh` +""", +) + +add_docstr_all( + "cpu", + r""" +cpu(memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in CPU memory. + +If this object is already in CPU memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + {memory_format} + +""".format( + **common_args + ), +) + +add_docstr_all( + "count_nonzero", + r""" +count_nonzero(dim=None) -> Tensor + +See :func:`torch.count_nonzero` +""", +) + +add_docstr_all( + "cov", + r""" +cov(*, correction=1, fweights=None, aweights=None) -> Tensor + +See :func:`torch.cov` +""", +) + +add_docstr_all( + "corrcoef", + r""" +corrcoef() -> Tensor + +See :func:`torch.corrcoef` +""", +) + +add_docstr_all( + "cross", + r""" +cross(other, dim=None) -> Tensor + +See :func:`torch.cross` +""", +) + +add_docstr_all( + "cuda", + r""" +cuda(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in CUDA memory. + +If this object is already in CUDA memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination GPU device. + Defaults to the current CUDA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "ipu", + r""" +ipu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in IPU memory. + +If this object is already in IPU memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination IPU device. + Defaults to the current IPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "xpu", + r""" +xpu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of this object in XPU memory. 
+ +If this object is already in XPU memory and on the correct device, +then no copy is performed and the original object is returned. + +Args: + device (:class:`torch.device`): The destination XPU device. + Defaults to the current XPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "logcumsumexp", + r""" +logcumsumexp(dim) -> Tensor + +See :func:`torch.logcumsumexp` +""", +) + +add_docstr_all( + "cummax", + r""" +cummax(dim) -> (Tensor, Tensor) + +See :func:`torch.cummax` +""", +) + +add_docstr_all( + "cummin", + r""" +cummin(dim) -> (Tensor, Tensor) + +See :func:`torch.cummin` +""", +) + +add_docstr_all( + "cumprod", + r""" +cumprod(dim, dtype=None) -> Tensor + +See :func:`torch.cumprod` +""", +) + +add_docstr_all( + "cumprod_", + r""" +cumprod_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumprod` +""", +) + +add_docstr_all( + "cumsum", + r""" +cumsum(dim, dtype=None) -> Tensor + +See :func:`torch.cumsum` +""", +) + +add_docstr_all( + "cumsum_", + r""" +cumsum_(dim, dtype=None) -> Tensor + +In-place version of :meth:`~Tensor.cumsum` +""", +) + +add_docstr_all( + "data_ptr", + r""" +data_ptr() -> int + +Returns the address of the first element of :attr:`self` tensor. +""", +) + +add_docstr_all( + "dequantize", + r""" +dequantize() -> Tensor + +Given a quantized Tensor, dequantize it and return the dequantized float Tensor. +""", +) + +add_docstr_all( + "dense_dim", + r""" +dense_dim() -> int + +Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. + +.. note:: + Returns ``len(self.shape)`` if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. +""", +) + +add_docstr_all( + "diag", + r""" +diag(diagonal=0) -> Tensor + +See :func:`torch.diag` +""", +) + +add_docstr_all( + "diag_embed", + r""" +diag_embed(offset=0, dim1=-2, dim2=-1) -> Tensor + +See :func:`torch.diag_embed` +""", +) + +add_docstr_all( + "diagflat", + r""" +diagflat(offset=0) -> Tensor + +See :func:`torch.diagflat` +""", +) + +add_docstr_all( + "diagonal", + r""" +diagonal(offset=0, dim1=0, dim2=1) -> Tensor + +See :func:`torch.diagonal` +""", +) + +add_docstr_all( + "diagonal_scatter", + r""" +diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor + +See :func:`torch.diagonal_scatter` +""", +) + +add_docstr_all( + "as_strided_scatter", + r""" +as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor + +See :func:`torch.as_strided_scatter` +""", +) + +add_docstr_all( + "fill_diagonal_", + r""" +fill_diagonal_(fill_value, wrap=False) -> Tensor + +Fill the main diagonal of a tensor that has at least 2-dimensions. +When dims>2, all dimensions of input must be of equal length. +This function modifies the input tensor in-place, and returns the input tensor. + +Arguments: + fill_value (Scalar): the fill value + wrap (bool): the diagonal 'wrapped' after N columns for tall matrices. 
+ +Example:: + + >>> a = torch.zeros(3, 3) + >>> a.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + >>> b = torch.zeros(7, 3) + >>> b.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> c = torch.zeros(7, 3) + >>> c.fill_diagonal_(5, wrap=True) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + +""", +) + +add_docstr_all( + "floor_divide", + r""" +floor_divide(value) -> Tensor + +See :func:`torch.floor_divide` +""", +) + +add_docstr_all( + "floor_divide_", + r""" +floor_divide_(value) -> Tensor + +In-place version of :meth:`~Tensor.floor_divide` +""", +) + +add_docstr_all( + "diff", + r""" +diff(n=1, dim=-1, prepend=None, append=None) -> Tensor + +See :func:`torch.diff` +""", +) + +add_docstr_all( + "digamma", + r""" +digamma() -> Tensor + +See :func:`torch.digamma` +""", +) + +add_docstr_all( + "digamma_", + r""" +digamma_() -> Tensor + +In-place version of :meth:`~Tensor.digamma` +""", +) + +add_docstr_all( + "dim", + r""" +dim() -> int + +Returns the number of dimensions of :attr:`self` tensor. +""", +) + +add_docstr_all( + "dist", + r""" +dist(other, p=2) -> Tensor + +See :func:`torch.dist` +""", +) + +add_docstr_all( + "div", + r""" +div(value, *, rounding_mode=None) -> Tensor + +See :func:`torch.div` +""", +) + +add_docstr_all( + "div_", + r""" +div_(value, *, rounding_mode=None) -> Tensor + +In-place version of :meth:`~Tensor.div` +""", +) + +add_docstr_all( + "divide", + r""" +divide(value, *, rounding_mode=None) -> Tensor + +See :func:`torch.divide` +""", +) + +add_docstr_all( + "divide_", + r""" +divide_(value, *, rounding_mode=None) -> Tensor + +In-place version of :meth:`~Tensor.divide` +""", +) + +add_docstr_all( + "dot", + r""" +dot(other) -> Tensor + +See :func:`torch.dot` +""", +) + +add_docstr_all( + "element_size", + r""" +element_size() -> int + +Returns the size in bytes of an individual element. 
+ +Example:: + + >>> torch.tensor([]).element_size() + 4 + >>> torch.tensor([], dtype=torch.uint8).element_size() + 1 + +""", +) + +add_docstr_all( + "eq", + r""" +eq(other) -> Tensor + +See :func:`torch.eq` +""", +) + +add_docstr_all( + "eq_", + r""" +eq_(other) -> Tensor + +In-place version of :meth:`~Tensor.eq` +""", +) + +add_docstr_all( + "equal", + r""" +equal(other) -> bool + +See :func:`torch.equal` +""", +) + +add_docstr_all( + "erf", + r""" +erf() -> Tensor + +See :func:`torch.erf` +""", +) + +add_docstr_all( + "erf_", + r""" +erf_() -> Tensor + +In-place version of :meth:`~Tensor.erf` +""", +) + +add_docstr_all( + "erfc", + r""" +erfc() -> Tensor + +See :func:`torch.erfc` +""", +) + +add_docstr_all( + "erfc_", + r""" +erfc_() -> Tensor + +In-place version of :meth:`~Tensor.erfc` +""", +) + +add_docstr_all( + "erfinv", + r""" +erfinv() -> Tensor + +See :func:`torch.erfinv` +""", +) + +add_docstr_all( + "erfinv_", + r""" +erfinv_() -> Tensor + +In-place version of :meth:`~Tensor.erfinv` +""", +) + +add_docstr_all( + "exp", + r""" +exp() -> Tensor + +See :func:`torch.exp` +""", +) + +add_docstr_all( + "exp_", + r""" +exp_() -> Tensor + +In-place version of :meth:`~Tensor.exp` +""", +) + +add_docstr_all( + "exp2", + r""" +exp2() -> Tensor + +See :func:`torch.exp2` +""", +) + +add_docstr_all( + "exp2_", + r""" +exp2_() -> Tensor + +In-place version of :meth:`~Tensor.exp2` +""", +) + +add_docstr_all( + "expm1", + r""" +expm1() -> Tensor + +See :func:`torch.expm1` +""", +) + +add_docstr_all( + "expm1_", + r""" +expm1_() -> Tensor + +In-place version of :meth:`~Tensor.expm1` +""", +) + +add_docstr_all( + "exponential_", + r""" +exponential_(lambd=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements drawn from the PDF (probability density function): + +.. math:: + + f(x) = \lambda e^{-\lambda x}, x > 0 + +.. note:: + In probability theory, exponential distribution is supported on interval [0, :math:`\inf`) (i.e., :math:`x >= 0`) + implying that zero can be sampled from the exponential distribution. + However, :func:`torch.Tensor.exponential_` does not sample zero, + which means that its actual support is the interval (0, :math:`\inf`). + + Note that :func:`torch.distributions.exponential.Exponential` is supported on the interval [0, :math:`\inf`) and can sample zero. +""", +) + +add_docstr_all( + "fill_", + r""" +fill_(value) -> Tensor + +Fills :attr:`self` tensor with the specified value. 
+""", +) + +add_docstr_all( + "floor", + r""" +floor() -> Tensor + +See :func:`torch.floor` +""", +) + +add_docstr_all( + "flip", + r""" +flip(dims) -> Tensor + +See :func:`torch.flip` +""", +) + +add_docstr_all( + "fliplr", + r""" +fliplr() -> Tensor + +See :func:`torch.fliplr` +""", +) + +add_docstr_all( + "flipud", + r""" +flipud() -> Tensor + +See :func:`torch.flipud` +""", +) + +add_docstr_all( + "roll", + r""" +roll(shifts, dims) -> Tensor + +See :func:`torch.roll` +""", +) + +add_docstr_all( + "floor_", + r""" +floor_() -> Tensor + +In-place version of :meth:`~Tensor.floor` +""", +) + +add_docstr_all( + "fmod", + r""" +fmod(divisor) -> Tensor + +See :func:`torch.fmod` +""", +) + +add_docstr_all( + "fmod_", + r""" +fmod_(divisor) -> Tensor + +In-place version of :meth:`~Tensor.fmod` +""", +) + +add_docstr_all( + "frac", + r""" +frac() -> Tensor + +See :func:`torch.frac` +""", +) + +add_docstr_all( + "frac_", + r""" +frac_() -> Tensor + +In-place version of :meth:`~Tensor.frac` +""", +) + +add_docstr_all( + "frexp", + r""" +frexp(input) -> (Tensor mantissa, Tensor exponent) + +See :func:`torch.frexp` +""", +) + +add_docstr_all( + "flatten", + r""" +flatten(start_dim=0, end_dim=-1) -> Tensor + +See :func:`torch.flatten` +""", +) + +add_docstr_all( + "gather", + r""" +gather(dim, index) -> Tensor + +See :func:`torch.gather` +""", +) + +add_docstr_all( + "gcd", + r""" +gcd(other) -> Tensor + +See :func:`torch.gcd` +""", +) + +add_docstr_all( + "gcd_", + r""" +gcd_(other) -> Tensor + +In-place version of :meth:`~Tensor.gcd` +""", +) + +add_docstr_all( + "ge", + r""" +ge(other) -> Tensor + +See :func:`torch.ge`. +""", +) + +add_docstr_all( + "ge_", + r""" +ge_(other) -> Tensor + +In-place version of :meth:`~Tensor.ge`. +""", +) + +add_docstr_all( + "greater_equal", + r""" +greater_equal(other) -> Tensor + +See :func:`torch.greater_equal`. +""", +) + +add_docstr_all( + "greater_equal_", + r""" +greater_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.greater_equal`. +""", +) + +add_docstr_all( + "geometric_", + r""" +geometric_(p, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements drawn from the geometric distribution: + +.. math:: + + P(X=k) = (1 - p)^{k - 1} p, k = 1, 2, ... + +.. note:: + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`, whereas + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`. +""", +) + +add_docstr_all( + "geqrf", + r""" +geqrf() -> (Tensor, Tensor) + +See :func:`torch.geqrf` +""", +) + +add_docstr_all( + "ger", + r""" +ger(vec2) -> Tensor + +See :func:`torch.ger` +""", +) + +add_docstr_all( + "inner", + r""" +inner(other) -> Tensor + +See :func:`torch.inner`. +""", +) + +add_docstr_all( + "outer", + r""" +outer(vec2) -> Tensor + +See :func:`torch.outer`. 
+""", +) + +add_docstr_all( + "hypot", + r""" +hypot(other) -> Tensor + +See :func:`torch.hypot` +""", +) + +add_docstr_all( + "hypot_", + r""" +hypot_(other) -> Tensor + +In-place version of :meth:`~Tensor.hypot` +""", +) + +add_docstr_all( + "i0", + r""" +i0() -> Tensor + +See :func:`torch.i0` +""", +) + +add_docstr_all( + "i0_", + r""" +i0_() -> Tensor + +In-place version of :meth:`~Tensor.i0` +""", +) + +add_docstr_all( + "igamma", + r""" +igamma(other) -> Tensor + +See :func:`torch.igamma` +""", +) + +add_docstr_all( + "igamma_", + r""" +igamma_(other) -> Tensor + +In-place version of :meth:`~Tensor.igamma` +""", +) + +add_docstr_all( + "igammac", + r""" +igammac(other) -> Tensor +See :func:`torch.igammac` +""", +) + +add_docstr_all( + "igammac_", + r""" +igammac_(other) -> Tensor +In-place version of :meth:`~Tensor.igammac` +""", +) + +add_docstr_all( + "indices", + r""" +indices() -> Tensor + +Return the indices tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See also :meth:`Tensor.values`. + +.. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. +""", +) + +add_docstr_all( + "get_device", + r""" +get_device() -> Device ordinal (Integer) + +For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. +For CPU tensors, this function returns `-1`. + +Example:: + + >>> x = torch.randn(3, 4, 5, device='cuda:0') + >>> x.get_device() + 0 + >>> x.cpu().get_device() + -1 +""", +) + +add_docstr_all( + "values", + r""" +values() -> Tensor + +Return the values tensor of a :ref:`sparse COO tensor `. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See also :meth:`Tensor.indices`. + +.. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. +""", +) + +add_docstr_all( + "gt", + r""" +gt(other) -> Tensor + +See :func:`torch.gt`. +""", +) + +add_docstr_all( + "gt_", + r""" +gt_(other) -> Tensor + +In-place version of :meth:`~Tensor.gt`. +""", +) + +add_docstr_all( + "greater", + r""" +greater(other) -> Tensor + +See :func:`torch.greater`. +""", +) + +add_docstr_all( + "greater_", + r""" +greater_(other) -> Tensor + +In-place version of :meth:`~Tensor.greater`. +""", +) + +add_docstr_all( + "has_names", + r""" +Is ``True`` if any of this tensor's dimensions are named. Otherwise, is ``False``. +""", +) + +add_docstr_all( + "hardshrink", + r""" +hardshrink(lambd=0.5) -> Tensor + +See :func:`torch.nn.functional.hardshrink` +""", +) + +add_docstr_all( + "heaviside", + r""" +heaviside(values) -> Tensor + +See :func:`torch.heaviside` +""", +) + +add_docstr_all( + "heaviside_", + r""" +heaviside_(values) -> Tensor + +In-place version of :meth:`~Tensor.heaviside` +""", +) + +add_docstr_all( + "histc", + r""" +histc(bins=100, min=0, max=0) -> Tensor + +See :func:`torch.histc` +""", +) + +add_docstr_all( + "histogram", + r""" +histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + +See :func:`torch.histogram` +""", +) + +add_docstr_all( + "index_add_", + r""" +index_add_(dim, index, source, *, alpha=1) -> Tensor + +Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` +tensor by adding to the indices in the order given in :attr:`index`. For example, +if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\ th row of +``source`` is subtracted from the ``j``\ th row of :attr:`self`. 
+ +The :attr:`dim`\ th dimension of ``source`` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +For a 3-D tensor the output is given as:: + + self[index[i], :, :] += alpha * src[i, :, :] # if dim == 0 + self[:, index[i], :] += alpha * src[:, i, :] # if dim == 1 + self[:, :, index[i]] += alpha * src[:, :, i] # if dim == 2 + +Note: + {forward_reproducibility_note} + +Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (Tensor): the tensor containing values to add + +Keyword args: + alpha (Number): the scalar multiplier for ``source`` + +Example:: + + >>> x = torch.ones(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_add_(0, index, t) + tensor([[ 2., 3., 4.], + [ 1., 1., 1.], + [ 8., 9., 10.], + [ 1., 1., 1.], + [ 5., 6., 7.]]) + >>> x.index_add_(0, index, t, alpha=-1) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "index_copy_", + r""" +index_copy_(dim, index, tensor) -> Tensor + +Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting +the indices in the order given in :attr:`index`. For example, if ``dim == 0`` +and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the +``j``\ th row of :attr:`self`. + +The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +.. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + +Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + +Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) +""", +) + +add_docstr_all( + "index_fill_", + r""" +index_fill_(dim, index, value) -> Tensor + +Fills the elements of the :attr:`self` tensor with value :attr:`value` by +selecting the indices in the order given in :attr:`index`. + +Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + +Example:: + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) +""", +) + +add_docstr_all( + "index_put_", + r""" +index_put_(indices, values, accumulate=False) -> Tensor + +Puts values from the tensor :attr:`values` into the tensor :attr:`self` using +the indices specified in :attr:`indices` (which is a tuple of Tensors). The +expression ``tensor.index_put_(indices, values)`` is equivalent to +``tensor[indices] = values``. Returns :attr:`self`. 
+ +If :attr:`accumulate` is ``True``, the elements in :attr:`values` are added to +:attr:`self`. If accumulate is ``False``, the behavior is undefined if indices +contain duplicate elements. + +Args: + indices (tuple of LongTensor): tensors used to index into `self`. + values (Tensor): tensor of same dtype as `self`. + accumulate (bool): whether to accumulate into self +""", +) + +add_docstr_all( + "index_put", + r""" +index_put(indices, values, accumulate=False) -> Tensor + +Out-place version of :meth:`~Tensor.index_put_`. +""", +) + +add_docstr_all( + "index_reduce_", + r""" +index_reduce_(dim, index, source, reduce, *, include_self=True) -> Tensor + +Accumulate the elements of ``source`` into the :attr:`self` +tensor by accumulating to the indices in the order given in :attr:`index` +using the reduction given by the ``reduce`` argument. For example, if ``dim == 0``, +``index[i] == j``, ``reduce == prod`` and ``include_self == True`` then the ``i``\ th +row of ``source`` is multiplied by the ``j``\ th row of :attr:`self`. If +:obj:`include_self="True"`, the values in the :attr:`self` tensor are included +in the reduction, otherwise, rows in the :attr:`self` tensor that are accumulated +to are treated as if they were filled with the reduction identites. + +The :attr:`dim`\ th dimension of ``source`` must have the same size as the +length of :attr:`index` (which must be a vector), and all other dimensions must +match :attr:`self`, or an error will be raised. + +For a 3-D tensor with :obj:`reduce="prod"` and :obj:`include_self=True` the +output is given as:: + + self[index[i], :, :] *= src[i, :, :] # if dim == 0 + self[:, index[i], :] *= src[:, i, :] # if dim == 1 + self[:, :, index[i]] *= src[:, :, i] # if dim == 2 + +Note: + {forward_reproducibility_note} + +.. note:: + + This function only supports floating point tensors. + +.. warning:: + + This function is in beta and may change in the near future. + +Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (FloatTensor): the tensor containing values to accumulate + reduce (str): the reduction operation to apply + (:obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + +Keyword args: + include_self (bool): whether the elements from the ``self`` tensor are + included in the reduction + +Example:: + + >>> x = torch.empty(5, 3).fill_(2) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2, 0]) + >>> x.index_reduce_(0, index, t, 'prod') + tensor([[20., 44., 72.], + [ 2., 2., 2.], + [14., 16., 18.], + [ 2., 2., 2.], + [ 8., 10., 12.]]) + >>> x = torch.empty(5, 3).fill_(2) + >>> x.index_reduce_(0, index, t, 'prod', include_self=False) + tensor([[10., 22., 36.], + [ 2., 2., 2.], + [ 7., 8., 9.], + [ 2., 2., 2.], + [ 4., 5., 6.]]) +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "index_select", + r""" +index_select(dim, index) -> Tensor + +See :func:`torch.index_select` +""", +) + +add_docstr_all( + "sparse_mask", + r""" +sparse_mask(mask) -> Tensor + +Returns a new :ref:`sparse tensor ` with values from a +strided tensor :attr:`self` filtered by the indices of the sparse +tensor :attr:`mask`. The values of :attr:`mask` sparse tensor are +ignored. :attr:`self` and :attr:`mask` tensors must have the same +shape. + +.. note:: + + The returned sparse tensor might contain duplicate values if :attr:`mask` + is not coalesced. 
It is therefore advisable to pass ``mask.coalesce()`` + if such behavior is not desired. + +.. note:: + + The returned sparse tensor has the same indices as the sparse tensor + :attr:`mask`, even when the corresponding values in :attr:`self` are + zeros. + +Args: + mask (Tensor): a sparse tensor whose indices are used as a filter + +Example:: + + >>> nse = 5 + >>> dims = (5, 5, 2, 2) + >>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)), + ... torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse) + >>> V = torch.randn(nse, dims[2], dims[3]) + >>> S = torch.sparse_coo_tensor(I, V, dims).coalesce() + >>> D = torch.randn(dims) + >>> D.sparse_mask(S) + tensor(indices=tensor([[0, 0, 0, 2], + [0, 1, 4, 3]]), + values=tensor([[[ 1.6550, 0.2397], + [-0.1611, -0.0779]], + + [[ 0.2326, -1.0558], + [ 1.4711, 1.9678]], + + [[-0.5138, -0.0411], + [ 1.9417, 0.5158]], + + [[ 0.0793, 0.0036], + [-0.2569, -0.1055]]]), + size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo) +""", +) + +add_docstr_all( + "inverse", + r""" +inverse() -> Tensor + +See :func:`torch.inverse` +""", +) + +add_docstr_all( + "isnan", + r""" +isnan() -> Tensor + +See :func:`torch.isnan` +""", +) + +add_docstr_all( + "isinf", + r""" +isinf() -> Tensor + +See :func:`torch.isinf` +""", +) + +add_docstr_all( + "isposinf", + r""" +isposinf() -> Tensor + +See :func:`torch.isposinf` +""", +) + +add_docstr_all( + "isneginf", + r""" +isneginf() -> Tensor + +See :func:`torch.isneginf` +""", +) + +add_docstr_all( + "isfinite", + r""" +isfinite() -> Tensor + +See :func:`torch.isfinite` +""", +) + +add_docstr_all( + "isclose", + r""" +isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +See :func:`torch.isclose` +""", +) + +add_docstr_all( + "isreal", + r""" +isreal() -> Tensor + +See :func:`torch.isreal` +""", +) + +add_docstr_all( + "is_coalesced", + r""" +is_coalesced() -> bool + +Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor +` that is coalesced, ``False`` otherwise. + +.. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + +See :meth:`coalesce` and :ref:`uncoalesced tensors `. +""", +) + +add_docstr_all( + "is_contiguous", + r""" +is_contiguous(memory_format=torch.contiguous_format) -> bool + +Returns True if :attr:`self` tensor is contiguous in memory in the order specified +by memory format. + +Args: + memory_format (:class:`torch.memory_format`, optional): Specifies memory allocation + order. Default: ``torch.contiguous_format``. +""", +) + +add_docstr_all( + "is_pinned", + r""" +Returns true if this tensor resides in pinned memory. +""", +) + +add_docstr_all( + "is_floating_point", + r""" +is_floating_point() -> bool + +Returns True if the data type of :attr:`self` is a floating point data type. +""", +) + +add_docstr_all( + "is_complex", + r""" +is_complex() -> bool + +Returns True if the data type of :attr:`self` is a complex data type. +""", +) + +add_docstr_all( + "is_inference", + r""" +is_inference() -> bool + +See :func:`torch.is_inference` +""", +) + +add_docstr_all( + "is_conj", + r""" +is_conj() -> bool + +Returns True if the conjugate bit of :attr:`self` is set to true. +""", +) + +add_docstr_all( + "is_neg", + r""" +is_neg() -> bool + +Returns True if the negative bit of :attr:`self` is set to true. +""", +) + +add_docstr_all( + "is_signed", + r""" +is_signed() -> bool + +Returns True if the data type of :attr:`self` is a signed data type. 
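+
+For example (the dtypes below are chosen purely for illustration)::
+
+    >>> torch.tensor([1, 2, 3]).is_signed()
+    True
+    >>> torch.tensor([1, 2, 3], dtype=torch.uint8).is_signed()
+    False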
+""", +) + +add_docstr_all( + "is_set_to", + r""" +is_set_to(tensor) -> bool + +Returns True if both tensors are pointing to the exact same memory (same +storage, offset, size and stride). +""", +) + +add_docstr_all( + "item", + r""" +item() -> number + +Returns the value of this tensor as a standard Python number. This only works +for tensors with one element. For other cases, see :meth:`~Tensor.tolist`. + +This operation is not differentiable. + +Example:: + + >>> x = torch.tensor([1.0]) + >>> x.item() + 1.0 + +""", +) + +add_docstr_all( + "kron", + r""" +kron(other) -> Tensor + +See :func:`torch.kron` +""", +) + +add_docstr_all( + "kthvalue", + r""" +kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.kthvalue` +""", +) + +add_docstr_all( + "ldexp", + r""" +ldexp(other) -> Tensor + +See :func:`torch.ldexp` +""", +) + +add_docstr_all( + "ldexp_", + r""" +ldexp_(other) -> Tensor + +In-place version of :meth:`~Tensor.ldexp` +""", +) + +add_docstr_all( + "lcm", + r""" +lcm(other) -> Tensor + +See :func:`torch.lcm` +""", +) + +add_docstr_all( + "lcm_", + r""" +lcm_(other) -> Tensor + +In-place version of :meth:`~Tensor.lcm` +""", +) + +add_docstr_all( + "le", + r""" +le(other) -> Tensor + +See :func:`torch.le`. +""", +) + +add_docstr_all( + "le_", + r""" +le_(other) -> Tensor + +In-place version of :meth:`~Tensor.le`. +""", +) + +add_docstr_all( + "less_equal", + r""" +less_equal(other) -> Tensor + +See :func:`torch.less_equal`. +""", +) + +add_docstr_all( + "less_equal_", + r""" +less_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.less_equal`. +""", +) + +add_docstr_all( + "lerp", + r""" +lerp(end, weight) -> Tensor + +See :func:`torch.lerp` +""", +) + +add_docstr_all( + "lerp_", + r""" +lerp_(end, weight) -> Tensor + +In-place version of :meth:`~Tensor.lerp` +""", +) + +add_docstr_all( + "lgamma", + r""" +lgamma() -> Tensor + +See :func:`torch.lgamma` +""", +) + +add_docstr_all( + "lgamma_", + r""" +lgamma_() -> Tensor + +In-place version of :meth:`~Tensor.lgamma` +""", +) + +add_docstr_all( + "log", + r""" +log() -> Tensor + +See :func:`torch.log` +""", +) + +add_docstr_all( + "log_", + r""" +log_() -> Tensor + +In-place version of :meth:`~Tensor.log` +""", +) + +add_docstr_all( + "log10", + r""" +log10() -> Tensor + +See :func:`torch.log10` +""", +) + +add_docstr_all( + "log10_", + r""" +log10_() -> Tensor + +In-place version of :meth:`~Tensor.log10` +""", +) + +add_docstr_all( + "log1p", + r""" +log1p() -> Tensor + +See :func:`torch.log1p` +""", +) + +add_docstr_all( + "log1p_", + r""" +log1p_() -> Tensor + +In-place version of :meth:`~Tensor.log1p` +""", +) + +add_docstr_all( + "log2", + r""" +log2() -> Tensor + +See :func:`torch.log2` +""", +) + +add_docstr_all( + "log2_", + r""" +log2_() -> Tensor + +In-place version of :meth:`~Tensor.log2` +""", +) + +add_docstr_all( + "logaddexp", + r""" +logaddexp(other) -> Tensor + +See :func:`torch.logaddexp` +""", +) + +add_docstr_all( + "logaddexp2", + r""" +logaddexp2(other) -> Tensor + +See :func:`torch.logaddexp2` +""", +) + +add_docstr_all( + "log_normal_", + r""" +log_normal_(mean=1, std=2, *, generator=None) + +Fills :attr:`self` tensor with numbers samples from the log-normal distribution +parameterized by the given mean :math:`\mu` and standard deviation +:math:`\sigma`. Note that :attr:`mean` and :attr:`std` are the mean and +standard deviation of the underlying normal distribution, and not of the +returned distribution: + +.. 
math:: + + f(x) = \dfrac{1}{x \sigma \sqrt{2\pi}}\ e^{-\frac{(\ln x - \mu)^2}{2\sigma^2}} +""", +) + +add_docstr_all( + "logsumexp", + r""" +logsumexp(dim, keepdim=False) -> Tensor + +See :func:`torch.logsumexp` +""", +) + +add_docstr_all( + "lt", + r""" +lt(other) -> Tensor + +See :func:`torch.lt`. +""", +) + +add_docstr_all( + "lt_", + r""" +lt_(other) -> Tensor + +In-place version of :meth:`~Tensor.lt`. +""", +) + +add_docstr_all( + "less", + r""" +lt(other) -> Tensor + +See :func:`torch.less`. +""", +) + +add_docstr_all( + "less_", + r""" +less_(other) -> Tensor + +In-place version of :meth:`~Tensor.less`. +""", +) + +add_docstr_all( + "lu_solve", + r""" +lu_solve(LU_data, LU_pivots) -> Tensor + +See :func:`torch.lu_solve` +""", +) + +add_docstr_all( + "map_", + r""" +map_(tensor, callable) + +Applies :attr:`callable` for each element in :attr:`self` tensor and the given +:attr:`tensor` and stores the results in :attr:`self` tensor. :attr:`self` tensor and +the given :attr:`tensor` must be :ref:`broadcastable `. + +The :attr:`callable` should have the signature:: + + def callable(a, b) -> number +""", +) + +add_docstr_all( + "masked_scatter_", + r""" +masked_scatter_(mask, source) + +Copies elements from :attr:`source` into :attr:`self` tensor at positions where +the :attr:`mask` is True. Elements from :attr:`source` are copied into :attr:`self` +starting at position 0 of :attr:`source` and continuing in order one-by-one for each +occurrence of :attr:`mask` being True. +The shape of :attr:`mask` must be :ref:`broadcastable ` +with the shape of the underlying tensor. The :attr:`source` should have at least +as many elements as the number of ones in :attr:`mask`. + +Args: + mask (BoolTensor): the boolean mask + source (Tensor): the tensor to copy from + +.. note:: + + The :attr:`mask` operates on the :attr:`self` tensor, not on the given + :attr:`source` tensor. + +Example: + + >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) + >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]]) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter_(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + +""", +) + +add_docstr_all( + "masked_fill_", + r""" +masked_fill_(mask, value) + +Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is +True. The shape of :attr:`mask` must be +:ref:`broadcastable ` with the shape of the underlying +tensor. + +Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with +""", +) + +add_docstr_all( + "masked_select", + r""" +masked_select(mask) -> Tensor + +See :func:`torch.masked_select` +""", +) + +add_docstr_all( + "matrix_power", + r""" +matrix_power(n) -> Tensor + +.. note:: :meth:`~Tensor.matrix_power` is deprecated, use :func:`torch.linalg.matrix_power` instead. 
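+
+For example, the recommended replacement can be invoked as follows (the
+matrix below is chosen purely for illustration)::
+
+    >>> a = 2 * torch.eye(2)
+    >>> torch.linalg.matrix_power(a, 3)
+    tensor([[8., 0.],
+            [0., 8.]])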
+ +Alias for :func:`torch.linalg.matrix_power` +""", +) + +add_docstr_all( + "matrix_exp", + r""" +matrix_exp() -> Tensor + +See :func:`torch.matrix_exp` +""", +) + +add_docstr_all( + "max", + r""" +max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + +See :func:`torch.max` +""", +) + +add_docstr_all( + "amax", + r""" +amax(dim=None, keepdim=False) -> Tensor + +See :func:`torch.amax` +""", +) + +add_docstr_all( + "maximum", + r""" +maximum(other) -> Tensor + +See :func:`torch.maximum` +""", +) + +add_docstr_all( + "fmax", + r""" +fmax(other) -> Tensor + +See :func:`torch.fmax` +""", +) + +add_docstr_all( + "argmax", + r""" +argmax(dim=None, keepdim=False) -> LongTensor + +See :func:`torch.argmax` +""", +) + +add_docstr_all( + "argwhere", + r""" +argwhere() -> Tensor + +See :func:`torch.argwhere` +""", +) + +add_docstr_all( + "mean", + r""" +mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + +See :func:`torch.mean` +""", +) + +add_docstr_all( + "nanmean", + r""" +nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor + +See :func:`torch.nanmean` +""", +) + +add_docstr_all( + "median", + r""" +median(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.median` +""", +) + +add_docstr_all( + "nanmedian", + r""" +nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.nanmedian` +""", +) + +add_docstr_all( + "min", + r""" +min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + +See :func:`torch.min` +""", +) + +add_docstr_all( + "amin", + r""" +amin(dim=None, keepdim=False) -> Tensor + +See :func:`torch.amin` +""", +) + +add_docstr_all( + "minimum", + r""" +minimum(other) -> Tensor + +See :func:`torch.minimum` +""", +) + +add_docstr_all( + "aminmax", + r""" +aminmax(*, dim=None, keepdim=False) -> (Tensor min, Tensor max) + +See :func:`torch.aminmax` +""", +) + +add_docstr_all( + "fmin", + r""" +fmin(other) -> Tensor + +See :func:`torch.fmin` +""", +) + +add_docstr_all( + "argmin", + r""" +argmin(dim=None, keepdim=False) -> LongTensor + +See :func:`torch.argmin` +""", +) + +add_docstr_all( + "mm", + r""" +mm(mat2) -> Tensor + +See :func:`torch.mm` +""", +) + +add_docstr_all( + "mode", + r""" +mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + +See :func:`torch.mode` +""", +) + +add_docstr_all( + "movedim", + r""" +movedim(source, destination) -> Tensor + +See :func:`torch.movedim` +""", +) + +add_docstr_all( + "moveaxis", + r""" +moveaxis(source, destination) -> Tensor + +See :func:`torch.moveaxis` +""", +) + +add_docstr_all( + "mul", + r""" +mul(value) -> Tensor + +See :func:`torch.mul`. +""", +) + +add_docstr_all( + "mul_", + r""" +mul_(value) -> Tensor + +In-place version of :meth:`~Tensor.mul`. +""", +) + +add_docstr_all( + "multiply", + r""" +multiply(value) -> Tensor + +See :func:`torch.multiply`. +""", +) + +add_docstr_all( + "multiply_", + r""" +multiply_(value) -> Tensor + +In-place version of :meth:`~Tensor.multiply`. +""", +) + +add_docstr_all( + "multinomial", + r""" +multinomial(num_samples, replacement=False, *, generator=None) -> Tensor + +See :func:`torch.multinomial` +""", +) + +add_docstr_all( + "mv", + r""" +mv(vec) -> Tensor + +See :func:`torch.mv` +""", +) + +add_docstr_all( + "mvlgamma", + r""" +mvlgamma(p) -> Tensor + +See :func:`torch.mvlgamma` +""", +) + +add_docstr_all( + "mvlgamma_", + r""" +mvlgamma_(p) -> Tensor + +In-place version of :meth:`~Tensor.mvlgamma` +""", +) + +add_docstr_all( + "narrow", + r""" +narrow(dimension, start, length) -> Tensor + +See :func:`torch.narrow`. 
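+
+A short illustration (values chosen only for this sketch)::
+
+    >>> x = torch.arange(1, 10).reshape(3, 3)
+    >>> x.narrow(0, 0, 2)
+    tensor([[1, 2, 3],
+            [4, 5, 6]])
+    >>> x.narrow(1, 1, 2)
+    tensor([[2, 3],
+            [5, 6],
+            [8, 9]])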
+""", +) + +add_docstr_all( + "narrow_copy", + r""" +narrow_copy(dimension, start, length) -> Tensor + +See :func:`torch.narrow_copy`. +""", +) + +add_docstr_all( + "ndimension", + r""" +ndimension() -> int + +Alias for :meth:`~Tensor.dim()` +""", +) + +add_docstr_all( + "nan_to_num", + r""" +nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor + +See :func:`torch.nan_to_num`. +""", +) + +add_docstr_all( + "nan_to_num_", + r""" +nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor + +In-place version of :meth:`~Tensor.nan_to_num`. +""", +) + +add_docstr_all( + "ne", + r""" +ne(other) -> Tensor + +See :func:`torch.ne`. +""", +) + +add_docstr_all( + "ne_", + r""" +ne_(other) -> Tensor + +In-place version of :meth:`~Tensor.ne`. +""", +) + +add_docstr_all( + "not_equal", + r""" +not_equal(other) -> Tensor + +See :func:`torch.not_equal`. +""", +) + +add_docstr_all( + "not_equal_", + r""" +not_equal_(other) -> Tensor + +In-place version of :meth:`~Tensor.not_equal`. +""", +) + +add_docstr_all( + "neg", + r""" +neg() -> Tensor + +See :func:`torch.neg` +""", +) + +add_docstr_all( + "negative", + r""" +negative() -> Tensor + +See :func:`torch.negative` +""", +) + +add_docstr_all( + "neg_", + r""" +neg_() -> Tensor + +In-place version of :meth:`~Tensor.neg` +""", +) + +add_docstr_all( + "negative_", + r""" +negative_() -> Tensor + +In-place version of :meth:`~Tensor.negative` +""", +) + +add_docstr_all( + "nelement", + r""" +nelement() -> int + +Alias for :meth:`~Tensor.numel` +""", +) + +add_docstr_all( + "nextafter", + r""" +nextafter(other) -> Tensor +See :func:`torch.nextafter` +""", +) + +add_docstr_all( + "nextafter_", + r""" +nextafter_(other) -> Tensor +In-place version of :meth:`~Tensor.nextafter` +""", +) + +add_docstr_all( + "nonzero", + r""" +nonzero() -> LongTensor + +See :func:`torch.nonzero` +""", +) + +add_docstr_all( + "nonzero_static", + r""" +nonzero_static(input, *, size, fill_value=-1) -> Tensor + +Returns a 2-D tensor where each row is the index for a non-zero value. +The returned Tensor has the same `torch.dtype` as `torch.nonzero()`. + +Args: + input (Tensor): the input tensor to count non-zero elements. + +Keyword args: + size (int): the size of non-zero elements expected to be included in the out + tensor. Pad the out tensor with `fill_value` if the `size` is larger + than total number of non-zero elements, truncate out tensor if `size` + is smaller. The size must be a non-negative integer. + fill_value (int): the value to fill the output tensor with when `size` is larger + than the total number of non-zero elements. Default is `-1` to represent + invalid index. 
+ +Example: + + # Example 1: Padding + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 4 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([[ 0, 0], + [ 1, 0], + [ 1, 1], + [ -1, -1]], dtype=torch.int64) + + # Example 2: Truncating + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([[ 0, 0], + [ 1, 0]], dtype=torch.int64) + + # Example 3: 0 size + >>> input_tensor = torch.tensor([10]) + >>> static_size = 0 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([], size=(0, 1), dtype=torch.int64) + + # Example 4: 0 rank input + >>> input_tensor = torch.tensor(10) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size = static_size) + tensor([], size=(2, 0), dtype=torch.int64) +""", +) + +add_docstr_all( + "norm", + r""" +norm(p=2, dim=None, keepdim=False) -> Tensor + +See :func:`torch.norm` +""", +) + +add_docstr_all( + "normal_", + r""" +normal_(mean=0, std=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with elements samples from the normal distribution +parameterized by :attr:`mean` and :attr:`std`. +""", +) + +add_docstr_all( + "numel", + r""" +numel() -> int + +See :func:`torch.numel` +""", +) + +add_docstr_all( + "numpy", + r""" +numpy(*, force=False) -> numpy.ndarray + +Returns the tensor as a NumPy :class:`ndarray`. + +If :attr:`force` is ``False`` (the default), the conversion +is performed only if the tensor is on the CPU, does not require grad, +does not have its conjugate bit set, and is a dtype and layout that +NumPy supports. The returned ndarray and the tensor will share their +storage, so changes to the tensor will be reflected in the ndarray +and vice versa. + +If :attr:`force` is ``True`` this is equivalent to +calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``. +If the tensor isn't on the CPU or the conjugate or negative bit is set, +the tensor won't share its storage with the returned ndarray. +Setting :attr:`force` to ``True`` can be a useful shorthand. + +Args: + force (bool): if ``True``, the ndarray may be a copy of the tensor + instead of always sharing memory, defaults to ``False``. 
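+
+Example (a minimal sketch of the default storage-sharing behavior; the
+values are illustrative only)::
+
+    >>> t = torch.zeros(3)
+    >>> a = t.numpy()
+    >>> a
+    array([0., 0., 0.], dtype=float32)
+    >>> t[0] = 1.
+    >>> a
+    array([1., 0., 0.], dtype=float32)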
+""", +) + +add_docstr_all( + "orgqr", + r""" +orgqr(input2) -> Tensor + +See :func:`torch.orgqr` +""", +) + +add_docstr_all( + "ormqr", + r""" +ormqr(input2, input3, left=True, transpose=False) -> Tensor + +See :func:`torch.ormqr` +""", +) + +add_docstr_all( + "permute", + r""" +permute(*dims) -> Tensor + +See :func:`torch.permute` +""", +) + +add_docstr_all( + "polygamma", + r""" +polygamma(n) -> Tensor + +See :func:`torch.polygamma` +""", +) + +add_docstr_all( + "polygamma_", + r""" +polygamma_(n) -> Tensor + +In-place version of :meth:`~Tensor.polygamma` +""", +) + +add_docstr_all( + "positive", + r""" +positive() -> Tensor + +See :func:`torch.positive` +""", +) + +add_docstr_all( + "pow", + r""" +pow(exponent) -> Tensor + +See :func:`torch.pow` +""", +) + +add_docstr_all( + "pow_", + r""" +pow_(exponent) -> Tensor + +In-place version of :meth:`~Tensor.pow` +""", +) + +add_docstr_all( + "float_power", + r""" +float_power(exponent) -> Tensor + +See :func:`torch.float_power` +""", +) + +add_docstr_all( + "float_power_", + r""" +float_power_(exponent) -> Tensor + +In-place version of :meth:`~Tensor.float_power` +""", +) + +add_docstr_all( + "prod", + r""" +prod(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.prod` +""", +) + +add_docstr_all( + "put_", + r""" +put_(index, source, accumulate=False) -> Tensor + +Copies the elements from :attr:`source` into the positions specified by +:attr:`index`. For the purpose of indexing, the :attr:`self` tensor is treated as if +it were a 1-D tensor. + +:attr:`index` and :attr:`source` need to have the same number of elements, but not necessarily +the same shape. + +If :attr:`accumulate` is ``True``, the elements in :attr:`source` are added to +:attr:`self`. If accumulate is ``False``, the behavior is undefined if :attr:`index` +contain duplicate elements. + +Args: + index (LongTensor): the indices into self + source (Tensor): the tensor containing values to copy from + accumulate (bool): whether to accumulate into self + +Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> src.put_(torch.tensor([1, 3]), torch.tensor([9, 10])) + tensor([[ 4, 9, 5], + [ 10, 7, 8]]) +""", +) + +add_docstr_all( + "put", + r""" +put(input, index, source, accumulate=False) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.put_`. +`input` corresponds to `self` in :meth:`torch.Tensor.put_`. +""", +) + +add_docstr_all( + "qr", + r""" +qr(some=True) -> (Tensor, Tensor) + +See :func:`torch.qr` +""", +) + +add_docstr_all( + "qscheme", + r""" +qscheme() -> torch.qscheme + +Returns the quantization scheme of a given QTensor. +""", +) + +add_docstr_all( + "quantile", + r""" +quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + +See :func:`torch.quantile` +""", +) + +add_docstr_all( + "nanquantile", + r""" +nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + +See :func:`torch.nanquantile` +""", +) + +add_docstr_all( + "q_scale", + r""" +q_scale() -> float + +Given a Tensor quantized by linear(affine) quantization, +returns the scale of the underlying quantizer(). +""", +) + +add_docstr_all( + "q_zero_point", + r""" +q_zero_point() -> int + +Given a Tensor quantized by linear(affine) quantization, +returns the zero_point of the underlying quantizer(). +""", +) + +add_docstr_all( + "q_per_channel_scales", + r""" +q_per_channel_scales() -> Tensor + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns a Tensor of scales of the underlying quantizer. 
It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. +""", +) + +add_docstr_all( + "q_per_channel_zero_points", + r""" +q_per_channel_zero_points() -> Tensor + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns a tensor of zero_points of the underlying quantizer. It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. +""", +) + +add_docstr_all( + "q_per_channel_axis", + r""" +q_per_channel_axis() -> int + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns the index of dimension on which per-channel quantization is applied. +""", +) + +add_docstr_all( + "random_", + r""" +random_(from=0, to=None, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with numbers sampled from the discrete uniform +distribution over ``[from, to - 1]``. If not specified, the values are usually +only bounded by :attr:`self` tensor's data type. However, for floating point +types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every +value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` +will be uniform in ``[0, 2^53]``. +""", +) + +add_docstr_all( + "rad2deg", + r""" +rad2deg() -> Tensor + +See :func:`torch.rad2deg` +""", +) + +add_docstr_all( + "rad2deg_", + r""" +rad2deg_() -> Tensor + +In-place version of :meth:`~Tensor.rad2deg` +""", +) + +add_docstr_all( + "deg2rad", + r""" +deg2rad() -> Tensor + +See :func:`torch.deg2rad` +""", +) + +add_docstr_all( + "deg2rad_", + r""" +deg2rad_() -> Tensor + +In-place version of :meth:`~Tensor.deg2rad` +""", +) + +add_docstr_all( + "ravel", + r""" +ravel() -> Tensor + +see :func:`torch.ravel` +""", +) + +add_docstr_all( + "reciprocal", + r""" +reciprocal() -> Tensor + +See :func:`torch.reciprocal` +""", +) + +add_docstr_all( + "reciprocal_", + r""" +reciprocal_() -> Tensor + +In-place version of :meth:`~Tensor.reciprocal` +""", +) + +add_docstr_all( + "record_stream", + r""" +record_stream(stream) + +Marks the tensor as having been used by this stream. When the tensor +is deallocated, ensure the tensor memory is not reused for another tensor +until all work queued on :attr:`stream` at the time of deallocation is +complete. + +.. note:: + + The caching allocator is aware of only the stream where a tensor was + allocated. Due to the awareness, it already correctly manages the life + cycle of tensors on only one stream. But if a tensor is used on a stream + different from the stream of origin, the allocator might reuse the memory + unexpectedly. Calling this method lets the allocator know which streams + have used the tensor. + +.. warning:: + + This method is most suitable for use cases where you are providing a + function that created a tensor on a side stream, and want users to be able + to make use of the tensor without having to think carefully about stream + safety when making use of them. These safety guarantees come at some + performance and predictability cost (analogous to the tradeoff between GC + and manual memory management), so if you are in a situation where + you manage the full lifetime of your tensors, you may consider instead + manually managing CUDA events so that calling this method is not necessary. 
+ In particular, when you call this method, on later allocations the + allocator will poll the recorded stream to see if all operations have + completed yet; you can potentially race with side stream computation and + non-deterministically reuse or fail to reuse memory for an allocation. + + You can safely use tensors allocated on side streams without + :meth:`~Tensor.record_stream`; you must manually ensure that + any non-creation stream uses of a tensor are synced back to the creation + stream before you deallocate the tensor. As the CUDA caching allocator + guarantees that the memory will only be reused with the same creation stream, + this is sufficient to ensure that writes to future reallocations of the + memory will be delayed until non-creation stream uses are done. + (Counterintuitively, you may observe that on the CPU side we have already + reallocated the tensor, even though CUDA kernels on the old tensor are + still in progress. This is fine, because CUDA operations on the new + tensor will appropriately wait for the old operations to complete, as they + are all on the same stream.) + + Concretely, this looks like this:: + + with torch.cuda.stream(s0): + x = torch.zeros(N) + + s1.wait_stream(s0) + with torch.cuda.stream(s1): + y = some_comm_op(x) + + ... some compute on s0 ... + + # synchronize creation stream s0 to side stream s1 + # before deallocating x + s0.wait_stream(s1) + del x + + Note that some discretion is required when deciding when to perform + ``s0.wait_stream(s1)``. In particular, if we were to wait immediately + after ``some_comm_op``, there wouldn't be any point in having the side + stream; it would be equivalent to have run ``some_comm_op`` on ``s0``. + Instead, the synchronization must be placed at some appropriate, later + point in time where you expect the side stream ``s1`` to have finished + work. This location is typically identified via profiling, e.g., using + Chrome traces produced + :meth:`torch.autograd.profiler.profile.export_chrome_trace`. If you + place the wait too early, work on s0 will block until ``s1`` has finished, + preventing further overlapping of communication and computation. If you + place the wait too late, you will use more memory than is strictly + necessary (as you are keeping ``x`` live for longer.) For a concrete + example of how this guidance can be applied in practice, see this post: + `FSDP and CUDACachingAllocator + `_. +""", +) + +add_docstr_all( + "remainder", + r""" +remainder(divisor) -> Tensor + +See :func:`torch.remainder` +""", +) + +add_docstr_all( + "remainder_", + r""" +remainder_(divisor) -> Tensor + +In-place version of :meth:`~Tensor.remainder` +""", +) + +add_docstr_all( + "renorm", + r""" +renorm(p, dim, maxnorm) -> Tensor + +See :func:`torch.renorm` +""", +) + +add_docstr_all( + "renorm_", + r""" +renorm_(p, dim, maxnorm) -> Tensor + +In-place version of :meth:`~Tensor.renorm` +""", +) + +add_docstr_all( + "repeat", + r""" +repeat(*sizes) -> Tensor + +Repeats this tensor along the specified dimensions. + +Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + +.. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. 
+ +Args: + sizes (torch.Size or int...): The number of times to repeat this tensor along each + dimension + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) +""", +) + +add_docstr_all( + "repeat_interleave", + r""" +repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + +See :func:`torch.repeat_interleave`. +""", +) + +add_docstr_all( + "requires_grad_", + r""" +requires_grad_(requires_grad=True) -> Tensor + +Change if autograd should record operations on this tensor: sets this tensor's +:attr:`requires_grad` attribute in-place. Returns this tensor. + +:func:`requires_grad_`'s main use case is to tell autograd to begin recording +operations on a Tensor ``tensor``. If ``tensor`` has ``requires_grad=False`` +(because it was obtained through a DataLoader, or required preprocessing or +initialization), ``tensor.requires_grad_()`` makes it so that autograd will +begin to record operations on ``tensor``. + +Args: + requires_grad (bool): If autograd should record operations on this tensor. + Default: ``True``. + +Example:: + + >>> # Let's say we want to preprocess some saved weights and use + >>> # the result as new weights. + >>> saved_weights = [0.1, 0.2, 0.3, 0.25] + >>> loaded_weights = torch.tensor(saved_weights) + >>> weights = preprocess(loaded_weights) # some function + >>> weights + tensor([-0.5503, 0.4926, -2.1158, -0.8303]) + + >>> # Now, start to record operations done to weights + >>> weights.requires_grad_() + >>> out = weights.pow(2).sum() + >>> out.backward() + >>> weights.grad + tensor([-1.1007, 0.9853, -4.2316, -1.6606]) + +""", +) + +add_docstr_all( + "reshape", + r""" +reshape(*shape) -> Tensor + +Returns a tensor with the same data and number of elements as :attr:`self` +but with the specified shape. This method returns a view if :attr:`shape` is +compatible with the current shape. See :meth:`torch.Tensor.view` on when it is +possible to return a view. + +See :func:`torch.reshape` + +Args: + shape (tuple of ints or int...): the desired shape + +""", +) + +add_docstr_all( + "reshape_as", + r""" +reshape_as(other) -> Tensor + +Returns this tensor as the same shape as :attr:`other`. +``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``. +This method returns a view if ``other.sizes()`` is compatible with the current +shape. See :meth:`torch.Tensor.view` on when it is possible to return a view. + +Please see :meth:`reshape` for more information about ``reshape``. + +Args: + other (:class:`torch.Tensor`): The result tensor has the same shape + as :attr:`other`. +""", +) + +add_docstr_all( + "resize_", + r""" +resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + +Resizes :attr:`self` tensor to the specified size. If the number of elements is +larger than the current storage size, then the underlying storage is resized +to fit the new number of elements. If the number of elements is smaller, the +underlying storage is not changed. Existing elements are preserved but any new +memory is uninitialized. + +.. warning:: + + This is a low-level method. The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). 
For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + +.. note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + +Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + +Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) +""", +) + +add_docstr_all( + "resize_as_", + r""" +resize_as_(tensor, memory_format=torch.contiguous_format) -> Tensor + +Resizes the :attr:`self` tensor to be the same size as the specified +:attr:`tensor`. This is equivalent to ``self.resize_(tensor.size())``. + +Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``tensor.size()``. + +""", +) + +add_docstr_all( + "rot90", + r""" +rot90(k, dims) -> Tensor + +See :func:`torch.rot90` +""", +) + +add_docstr_all( + "round", + r""" +round(decimals=0) -> Tensor + +See :func:`torch.round` +""", +) + +add_docstr_all( + "round_", + r""" +round_(decimals=0) -> Tensor + +In-place version of :meth:`~Tensor.round` +""", +) + +add_docstr_all( + "rsqrt", + r""" +rsqrt() -> Tensor + +See :func:`torch.rsqrt` +""", +) + +add_docstr_all( + "rsqrt_", + r""" +rsqrt_() -> Tensor + +In-place version of :meth:`~Tensor.rsqrt` +""", +) + +add_docstr_all( + "scatter_", + r""" +scatter_(dim, index, src, *, reduce=None) -> Tensor + +Writes all values from the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index` tensor. For each value in :attr:`src`, its output +index is specified by its index in :attr:`src` for ``dimension != dim`` and by +the corresponding value in :attr:`index` for ``dimension = dim``. + +For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + +This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + +:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. + +Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be +between ``0`` and ``self.size(dim) - 1`` inclusive. + +.. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + +.. 
note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +Additionally accepts an optional :attr:`reduce` argument that allows +specification of an optional reduction operation, which is applied to all +values in the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index`. For each value in :attr:`src`, the reduction +operation is applied to an index in :attr:`self` which is specified by +its index in :attr:`src` for ``dimension != dim`` and by the corresponding +value in :attr:`index` for ``dimension = dim``. + +Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` +is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + +Reducing with the addition operation is the same as using +:meth:`~torch.Tensor.scatter_add_`. + +.. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + +Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + +Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + +.. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + +Writes the value from :attr:`value` into :attr:`self` at the indices +specified in the :attr:`index` tensor. This operation is equivalent to the previous version, +with the :attr:`src` tensor filled entirely with :attr:`value`. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + +Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ +Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) +""", +) + +add_docstr_all( + "scatter_add_", + r""" +scatter_add_(dim, index, src) -> Tensor + +Adds all values from the tensor :attr:`src` into :attr:`self` at the indices +specified in the :attr:`index` tensor in a similar fashion as +:meth:`~torch.Tensor.scatter_`. For each value in :attr:`src`, it is added to +an index in :attr:`self` which is specified by its index in :attr:`src` +for ``dimension != dim`` and by the corresponding value in :attr:`index` for +``dimension = dim``. + +For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + +:attr:`self`, :attr:`index` and :attr:`src` should have same number of +dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all +dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions +``d != dim``. Note that ``index`` and ``src`` do not broadcast. + +Note: + {forward_reproducibility_note} + +.. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and add, can be + either empty or of the same dimensionality as ``src``. When empty, the + operation returns ``self`` unchanged. + src (Tensor): the source elements to scatter and add + +Example:: + + >>> src = torch.ones((2, 5)) + >>> index = torch.tensor([[0, 1, 2, 0, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[1., 0., 0., 1., 1.], + [0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.]]) + >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[2., 0., 0., 1., 1.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 1., 1.]]) + +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "scatter_reduce_", + r""" +scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor + +Reduces all values from the :attr:`src` tensor to the indices specified in +the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction +defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, +:obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an +index in :attr:`self` which is specified by its index in :attr:`src` for +``dimension != dim`` and by the corresponding value in :attr:`index` for +``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self` +tensor are included in the reduction. + +:attr:`self`, :attr:`index` and :attr:`src` should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. + +For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the +output is given as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + +Note: + {forward_reproducibility_note} + +.. 
note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + +.. warning:: + + This function is in beta and may change in the near future. + +Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and reduce. + src (Tensor): the source elements to scatter and reduce + reduce (str): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + include_self (bool): whether elements from the :attr:`self` tensor are + included in the reduction + +Example:: + + >>> src = torch.tensor([1., 2., 3., 4., 5., 6.]) + >>> index = torch.tensor([0, 1, 0, 1, 2, 1]) + >>> input = torch.tensor([1., 2., 3., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum") + tensor([5., 14., 8., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False) + tensor([4., 12., 5., 4.]) + >>> input2 = torch.tensor([5., 4., 3., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax") + tensor([5., 6., 5., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False) + tensor([3., 6., 5., 2.]) + + +""".format( + **reproducibility_notes + ), +) + +add_docstr_all( + "select", + r""" +select(dim, index) -> Tensor + +See :func:`torch.select` +""", +) + +add_docstr_all( + "select_scatter", + r""" +select_scatter(src, dim, index) -> Tensor + +See :func:`torch.select_scatter` +""", +) + +add_docstr_all( + "slice_scatter", + r""" +slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor + +See :func:`torch.slice_scatter` +""", +) + +add_docstr_all( + "set_", + r""" +set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + +Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, +:attr:`self` tensor will share the same storage and have the same size and +strides as :attr:`source`. Changes to elements in one tensor will be reflected +in the other. + +If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying +storage, offset, size, and stride. + +Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. 
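+
+Example (a brief sketch of the tensor-source overload; values are
+illustrative only)::
+
+    >>> src = torch.arange(6.).reshape(2, 3)
+    >>> t = torch.tensor([])
+    >>> t.set_(src)
+    tensor([[0., 1., 2.],
+            [3., 4., 5.]])
+    >>> t[0, 0] = 100.
+    >>> src[0, 0]
+    tensor(100.)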
+""", +) + +add_docstr_all( + "sigmoid", + r""" +sigmoid() -> Tensor + +See :func:`torch.sigmoid` +""", +) + +add_docstr_all( + "sigmoid_", + r""" +sigmoid_() -> Tensor + +In-place version of :meth:`~Tensor.sigmoid` +""", +) + +add_docstr_all( + "logit", + r""" +logit() -> Tensor + +See :func:`torch.logit` +""", +) + +add_docstr_all( + "logit_", + r""" +logit_() -> Tensor + +In-place version of :meth:`~Tensor.logit` +""", +) + +add_docstr_all( + "sign", + r""" +sign() -> Tensor + +See :func:`torch.sign` +""", +) + +add_docstr_all( + "sign_", + r""" +sign_() -> Tensor + +In-place version of :meth:`~Tensor.sign` +""", +) + +add_docstr_all( + "signbit", + r""" +signbit() -> Tensor + +See :func:`torch.signbit` +""", +) + +add_docstr_all( + "sgn", + r""" +sgn() -> Tensor + +See :func:`torch.sgn` +""", +) + +add_docstr_all( + "sgn_", + r""" +sgn_() -> Tensor + +In-place version of :meth:`~Tensor.sgn` +""", +) + +add_docstr_all( + "sin", + r""" +sin() -> Tensor + +See :func:`torch.sin` +""", +) + +add_docstr_all( + "sin_", + r""" +sin_() -> Tensor + +In-place version of :meth:`~Tensor.sin` +""", +) + +add_docstr_all( + "sinc", + r""" +sinc() -> Tensor + +See :func:`torch.sinc` +""", +) + +add_docstr_all( + "sinc_", + r""" +sinc_() -> Tensor + +In-place version of :meth:`~Tensor.sinc` +""", +) + +add_docstr_all( + "sinh", + r""" +sinh() -> Tensor + +See :func:`torch.sinh` +""", +) + +add_docstr_all( + "sinh_", + r""" +sinh_() -> Tensor + +In-place version of :meth:`~Tensor.sinh` +""", +) + +add_docstr_all( + "size", + r""" +size(dim=None) -> torch.Size or int + +Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, +the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. +If ``dim`` is specified, returns an int holding the size of that dimension. + +Args: + dim (int, optional): The dimension for which to retrieve the size. + +Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + +""", +) + +add_docstr_all( + "shape", + r""" +shape() -> torch.Size + +Returns the size of the :attr:`self` tensor. Alias for :attr:`size`. + +See also :meth:`Tensor.size`. + +Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.shape + torch.Size([3, 4, 5]) + +""", +) + +add_docstr_all( + "sort", + r""" +sort(dim=-1, descending=False) -> (Tensor, LongTensor) + +See :func:`torch.sort` +""", +) + +add_docstr_all( + "msort", + r""" +msort() -> Tensor + +See :func:`torch.msort` +""", +) + +add_docstr_all( + "argsort", + r""" +argsort(dim=-1, descending=False) -> LongTensor + +See :func:`torch.argsort` +""", +) + +add_docstr_all( + "sparse_dim", + r""" +sparse_dim() -> int + +Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. + +.. note:: + Returns ``0`` if :attr:`self` is not a sparse tensor. + +See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. +""", +) + +add_docstr_all( + "sparse_resize_", + r""" +sparse_resize_(size, sparse_dim, dense_dim) -> Tensor + +Resizes :attr:`self` :ref:`sparse tensor ` to the desired +size and the number of sparse and dense dimensions. + +.. note:: + If the number of specified elements in :attr:`self` is zero, then + :attr:`size`, :attr:`sparse_dim`, and :attr:`dense_dim` can be any + size and positive integers such that ``len(size) == sparse_dim + + dense_dim``. 
+ + If :attr:`self` specifies one or more elements, however, then each + dimension in :attr:`size` must not be smaller than the corresponding + dimension of :attr:`self`, :attr:`sparse_dim` must equal the number + of sparse dimensions in :attr:`self`, and :attr:`dense_dim` must + equal the number of dense dimensions in :attr:`self`. + +.. warning:: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. If :attr:`self` is non-empty + sparse tensor, the desired size cannot be smaller than the + original size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions +""", +) + +add_docstr_all( + "sparse_resize_and_clear_", + r""" +sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor + +Removes all specified elements from a :ref:`sparse tensor +` :attr:`self` and resizes :attr:`self` to the desired +size and the number of sparse and dense dimensions. + +.. warning: + Throws an error if :attr:`self` is not a sparse tensor. + +Args: + size (torch.Size): the desired size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions +""", +) + +add_docstr_all( + "sqrt", + r""" +sqrt() -> Tensor + +See :func:`torch.sqrt` +""", +) + +add_docstr_all( + "sqrt_", + r""" +sqrt_() -> Tensor + +In-place version of :meth:`~Tensor.sqrt` +""", +) + +add_docstr_all( + "square", + r""" +square() -> Tensor + +See :func:`torch.square` +""", +) + +add_docstr_all( + "square_", + r""" +square_() -> Tensor + +In-place version of :meth:`~Tensor.square` +""", +) + +add_docstr_all( + "squeeze", + r""" +squeeze(dim=None) -> Tensor + +See :func:`torch.squeeze` +""", +) + +add_docstr_all( + "squeeze_", + r""" +squeeze_(dim=None) -> Tensor + +In-place version of :meth:`~Tensor.squeeze` +""", +) + +add_docstr_all( + "std", + r""" +std(dim=None, *, correction=1, keepdim=False) -> Tensor + +See :func:`torch.std` +""", +) + +add_docstr_all( + "storage_offset", + r""" +storage_offset() -> int + +Returns :attr:`self` tensor's offset in the underlying storage in terms of +number of storage elements (not bytes). + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5]) + >>> x.storage_offset() + 0 + >>> x[3:].storage_offset() + 3 + +""", +) + +add_docstr_all( + "untyped_storage", + r""" +untyped_storage() -> torch.UntypedStorage + +Returns the underlying :class:`UntypedStorage`. +""", +) + +add_docstr_all( + "stride", + r""" +stride(dim) -> tuple or int + +Returns the stride of :attr:`self` tensor. + +Stride is the jump necessary to go from one element to the next one in the +specified dimension :attr:`dim`. A tuple of all strides is returned when no +argument is passed in. Otherwise, an integer value is returned as the stride in +the particular dimension :attr:`dim`. + +Args: + dim (int, optional): the desired dimension in which stride is required + +Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + +""", +) + +add_docstr_all( + "sub", + r""" +sub(other, *, alpha=1) -> Tensor + +See :func:`torch.sub`. +""", +) + +add_docstr_all( + "sub_", + r""" +sub_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.sub` +""", +) + +add_docstr_all( + "subtract", + r""" +subtract(other, *, alpha=1) -> Tensor + +See :func:`torch.subtract`. +""", +) + +add_docstr_all( + "subtract_", + r""" +subtract_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.subtract`. 
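+
+A brief illustration of the in-place update and the ``alpha`` scaling
+(values are illustrative only)::
+
+    >>> a = torch.tensor([10., 10.])
+    >>> a.subtract_(torch.tensor([1., 2.]), alpha=2)
+    tensor([8., 6.])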
+""", +) + +add_docstr_all( + "sum", + r""" +sum(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.sum` +""", +) + +add_docstr_all( + "nansum", + r""" +nansum(dim=None, keepdim=False, dtype=None) -> Tensor + +See :func:`torch.nansum` +""", +) + +add_docstr_all( + "svd", + r""" +svd(some=True, compute_uv=True) -> (Tensor, Tensor, Tensor) + +See :func:`torch.svd` +""", +) + +add_docstr_all( + "swapdims", + r""" +swapdims(dim0, dim1) -> Tensor + +See :func:`torch.swapdims` +""", +) + +add_docstr_all( + "swapdims_", + r""" +swapdims_(dim0, dim1) -> Tensor + +In-place version of :meth:`~Tensor.swapdims` +""", +) + +add_docstr_all( + "swapaxes", + r""" +swapaxes(axis0, axis1) -> Tensor + +See :func:`torch.swapaxes` +""", +) + +add_docstr_all( + "swapaxes_", + r""" +swapaxes_(axis0, axis1) -> Tensor + +In-place version of :meth:`~Tensor.swapaxes` +""", +) + +add_docstr_all( + "t", + r""" +t() -> Tensor + +See :func:`torch.t` +""", +) + +add_docstr_all( + "t_", + r""" +t_() -> Tensor + +In-place version of :meth:`~Tensor.t` +""", +) + +add_docstr_all( + "tile", + r""" +tile(dims) -> Tensor + +See :func:`torch.tile` +""", +) + +add_docstr_all( + "to", + r""" +to(*args, **kwargs) -> Tensor + +Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are +inferred from the arguments of ``self.to(*args, **kwargs)``. + +.. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + +Here are the ways to call ``to``: + +.. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + {memory_format} + +.. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking`, tries to convert asynchronously with respect to + the host if possible, e.g., converting a CPU Tensor with pinned memory to a + CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + {memory_format} + +.. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. When :attr:`non_blocking`, tries to convert + asynchronously with respect to the host if possible, e.g., converting a CPU + Tensor with pinned memory to a CUDA Tensor. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. 
+ +Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') +""".format( + **common_args + ), +) + +add_docstr_all( + "byte", + r""" +byte(memory_format=torch.preserve_format) -> Tensor + +``self.byte()`` is equivalent to ``self.to(torch.uint8)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "bool", + r""" +bool(memory_format=torch.preserve_format) -> Tensor + +``self.bool()`` is equivalent to ``self.to(torch.bool)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "char", + r""" +char(memory_format=torch.preserve_format) -> Tensor + +``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "bfloat16", + r""" +bfloat16(memory_format=torch.preserve_format) -> Tensor +``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "double", + r""" +double(memory_format=torch.preserve_format) -> Tensor + +``self.double()`` is equivalent to ``self.to(torch.float64)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "float", + r""" +float(memory_format=torch.preserve_format) -> Tensor + +``self.float()`` is equivalent to ``self.to(torch.float32)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "cdouble", + r""" +cdouble(memory_format=torch.preserve_format) -> Tensor + +``self.cdouble()`` is equivalent to ``self.to(torch.complex128)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "cfloat", + r""" +cfloat(memory_format=torch.preserve_format) -> Tensor + +``self.cfloat()`` is equivalent to ``self.to(torch.complex64)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "chalf", + r""" +chalf(memory_format=torch.preserve_format) -> Tensor + +``self.chalf()`` is equivalent to ``self.to(torch.complex32)``. See :func:`to`. + +Args: + {memory_format} + """.format( + **common_args + ), +) + +add_docstr_all( + "half", + r""" +half(memory_format=torch.preserve_format) -> Tensor + +``self.half()`` is equivalent to ``self.to(torch.float16)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "int", + r""" +int(memory_format=torch.preserve_format) -> Tensor + +``self.int()`` is equivalent to ``self.to(torch.int32)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "int_repr", + r""" +int_repr() -> Tensor + +Given a quantized Tensor, +``self.int_repr()`` returns a CPU Tensor with uint8_t as data type that stores the +underlying uint8_t values of the given Tensor. 
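+
+Example (a minimal sketch; the scale and zero_point below are chosen only
+for illustration)::
+
+    >>> x = torch.quantize_per_tensor(
+    ...     torch.tensor([1., 2., 3.]), scale=1.0, zero_point=0, dtype=torch.quint8)
+    >>> x.int_repr()
+    tensor([1, 2, 3], dtype=torch.uint8)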
+""", +) + + +add_docstr_all( + "long", + r""" +long(memory_format=torch.preserve_format) -> Tensor + +``self.long()`` is equivalent to ``self.to(torch.int64)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "short", + r""" +short(memory_format=torch.preserve_format) -> Tensor + +``self.short()`` is equivalent to ``self.to(torch.int16)``. See :func:`to`. + +Args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr_all( + "take", + r""" +take(indices) -> Tensor + +See :func:`torch.take` +""", +) + +add_docstr_all( + "take_along_dim", + r""" +take_along_dim(indices, dim) -> Tensor + +See :func:`torch.take_along_dim` +""", +) + +add_docstr_all( + "tan", + r""" +tan() -> Tensor + +See :func:`torch.tan` +""", +) + +add_docstr_all( + "tan_", + r""" +tan_() -> Tensor + +In-place version of :meth:`~Tensor.tan` +""", +) + +add_docstr_all( + "tanh", + r""" +tanh() -> Tensor + +See :func:`torch.tanh` +""", +) + +add_docstr_all( + "softmax", + r""" +softmax(dim) -> Tensor + +Alias for :func:`torch.nn.functional.softmax`. +""", +) + +add_docstr_all( + "tanh_", + r""" +tanh_() -> Tensor + +In-place version of :meth:`~Tensor.tanh` +""", +) + +add_docstr_all( + "tolist", + r""" +tolist() -> list or number + +Returns the tensor as a (nested) list. For scalars, a standard +Python number is returned, just like with :meth:`~Tensor.item`. +Tensors are automatically moved to the CPU first if necessary. + +This operation is not differentiable. + +Examples:: + + >>> a = torch.randn(2, 2) + >>> a.tolist() + [[0.012766935862600803, 0.5415473580360413], + [-0.08909505605697632, 0.7729271650314331]] + >>> a[0,0].tolist() + 0.012766935862600803 +""", +) + +add_docstr_all( + "topk", + r""" +topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor) + +See :func:`torch.topk` +""", +) + +add_docstr_all( + "to_dense", + r""" +to_dense(dtype=None, *, masked_grad=True) -> Tensor + +Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. + +Keyword args: + {dtype} + masked_grad (bool, optional): If set to ``True`` (default) and + :attr:`self` has a sparse layout then the backward of + :meth:`to_dense` returns ``grad.sparse_mask(self)``. + +Example:: + + >>> s = torch.sparse_coo_tensor( + ... torch.tensor([[1, 1], + ... [0, 2]]), + ... torch.tensor([9, 10]), + ... size=(3, 3)) + >>> s.to_dense() + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) +""", +) + +add_docstr_all( + "to_sparse", + r""" +to_sparse(sparseDims) -> Tensor + +Returns a sparse copy of the tensor. PyTorch supports sparse tensors in +:ref:`coordinate format `. + +Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + +Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + +.. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + +Returns a sparse tensor with the specified layout and blocksize. 
If +the :attr:`self` is strided, the number of dense dimensions could be +specified, and a hybrid sparse tensor will be created, with +`dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch +dimension. + +.. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + +Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + +""", +) + +add_docstr_all( + "to_sparse_csr", + r""" +to_sparse_csr(dense_dim=None) -> Tensor + +Convert a tensor to compressed row storage format (CSR). Except for +strided tensors, only works with 2D tensors. If the :attr:`self` is +strided, then the number of dense dimensions could be specified, and a +hybrid CSR tensor will be created, with `dense_dim` dense dimensions +and `self.dim() - 2 - dense_dim` batch dimension. + +Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csr() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csr(dense_dim=2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csr) + +""", +) + +add_docstr_all( + "to_sparse_csc", + r""" +to_sparse_csc() -> Tensor + +Convert a tensor to compressed column storage (CSC) format. 
Except +for strided tensors, only works with 2D tensors. If the :attr:`self` +is strided, then the number of dense dimensions could be specified, +and a hybrid CSC tensor will be created, with `dense_dim` dense +dimensions and `self.dim() - 2 - dense_dim` batch dimension. + +Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csc() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csc(dense_dim=2) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csc) + +""", +) + +add_docstr_all( + "to_sparse_bsr", + r""" +to_sparse_bsr(blocksize, dense_dim) -> Tensor + +Convert a tensor to a block sparse row (BSR) storage format of given +blocksize. If the :attr:`self` is strided, then the number of dense +dimensions could be specified, and a hybrid BSR tensor will be +created, with `dense_dim` dense dimensions and `self.dim() - 2 - +dense_dim` batch dimension. + +Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + +Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsr = sparse.to_sparse_bsr((5, 5)) + >>> sparse_bsr.col_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsr((2, 1), 1) + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsr) + +""", +) + +add_docstr_all( + "to_sparse_bsc", + r""" +to_sparse_bsc(blocksize, dense_dim) -> Tensor + +Convert a tensor to a block sparse column (BSC) storage format of +given blocksize. If the :attr:`self` is strided, then the number of +dense dimensions could be specified, and a hybrid BSC tensor will be +created, with `dense_dim` dense dimensions and `self.dim() - 2 - +dense_dim` batch dimension. + +Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSC tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. 
+ +Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsc = sparse.to_sparse_bsc((5, 5)) + >>> sparse_bsc.row_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsc((2, 1), 1) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 1, 0]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsc) + +""", +) + +add_docstr_all( + "to_mkldnn", + r""" +to_mkldnn() -> Tensor +Returns a copy of the tensor in ``torch.mkldnn`` layout. + +""", +) + +add_docstr_all( + "trace", + r""" +trace() -> Tensor + +See :func:`torch.trace` +""", +) + +add_docstr_all( + "transpose", + r""" +transpose(dim0, dim1) -> Tensor + +See :func:`torch.transpose` +""", +) + +add_docstr_all( + "transpose_", + r""" +transpose_(dim0, dim1) -> Tensor + +In-place version of :meth:`~Tensor.transpose` +""", +) + +add_docstr_all( + "triangular_solve", + r""" +triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + +See :func:`torch.triangular_solve` +""", +) + +add_docstr_all( + "tril", + r""" +tril(diagonal=0) -> Tensor + +See :func:`torch.tril` +""", +) + +add_docstr_all( + "tril_", + r""" +tril_(diagonal=0) -> Tensor + +In-place version of :meth:`~Tensor.tril` +""", +) + +add_docstr_all( + "triu", + r""" +triu(diagonal=0) -> Tensor + +See :func:`torch.triu` +""", +) + +add_docstr_all( + "triu_", + r""" +triu_(diagonal=0) -> Tensor + +In-place version of :meth:`~Tensor.triu` +""", +) + +add_docstr_all( + "true_divide", + r""" +true_divide(value) -> Tensor + +See :func:`torch.true_divide` +""", +) + +add_docstr_all( + "true_divide_", + r""" +true_divide_(value) -> Tensor + +In-place version of :meth:`~Tensor.true_divide_` +""", +) + +add_docstr_all( + "trunc", + r""" +trunc() -> Tensor + +See :func:`torch.trunc` +""", +) + +add_docstr_all( + "fix", + r""" +fix() -> Tensor + +See :func:`torch.fix`. +""", +) + +add_docstr_all( + "trunc_", + r""" +trunc_() -> Tensor + +In-place version of :meth:`~Tensor.trunc` +""", +) + +add_docstr_all( + "fix_", + r""" +fix_() -> Tensor + +In-place version of :meth:`~Tensor.fix` +""", +) + +add_docstr_all( + "type", + r""" +type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor +Returns the type if `dtype` is not provided, else casts this object to +the specified type. + +If this is already of the correct type, no copy is performed and the +original object is returned. + +Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. +""", +) + +add_docstr_all( + "type_as", + r""" +type_as(tensor) -> Tensor + +Returns this tensor cast to the type of the given tensor. + +This is a no-op if the tensor is already of the correct type. 
This is +equivalent to ``self.type(tensor.type())`` + +Args: + tensor (Tensor): the tensor which has the desired type +""", +) + +add_docstr_all( + "unfold", + r""" +unfold(dimension, size, step) -> Tensor + +Returns a view of the original tensor which contains all slices of size :attr:`size` from +:attr:`self` tensor in the dimension :attr:`dimension`. + +Step between two slices is given by :attr:`step`. + +If `sizedim` is the size of dimension :attr:`dimension` for :attr:`self`, the size of +dimension :attr:`dimension` in the returned tensor will be +`(sizedim - size) / step + 1`. + +An additional dimension of size :attr:`size` is appended in the returned tensor. + +Args: + dimension (int): dimension in which unfolding happens + size (int): the size of each slice that is unfolded + step (int): the step between each slice + +Example:: + + >>> x = torch.arange(1., 8) + >>> x + tensor([ 1., 2., 3., 4., 5., 6., 7.]) + >>> x.unfold(0, 2, 1) + tensor([[ 1., 2.], + [ 2., 3.], + [ 3., 4.], + [ 4., 5.], + [ 5., 6.], + [ 6., 7.]]) + >>> x.unfold(0, 2, 2) + tensor([[ 1., 2.], + [ 3., 4.], + [ 5., 6.]]) +""", +) + +add_docstr_all( + "uniform_", + r""" +uniform_(from=0, to=1, *, generator=None) -> Tensor + +Fills :attr:`self` tensor with numbers sampled from the continuous uniform +distribution: + +.. math:: + f(x) = \dfrac{1}{\text{to} - \text{from}} +""", +) + +add_docstr_all( + "unsqueeze", + r""" +unsqueeze(dim) -> Tensor + +See :func:`torch.unsqueeze` +""", +) + +add_docstr_all( + "unsqueeze_", + r""" +unsqueeze_(dim) -> Tensor + +In-place version of :meth:`~Tensor.unsqueeze` +""", +) + +add_docstr_all( + "var", + r""" +var(dim=None, *, correction=1, keepdim=False) -> Tensor + +See :func:`torch.var` +""", +) + +add_docstr_all( + "vdot", + r""" +vdot(other) -> Tensor + +See :func:`torch.vdot` +""", +) + +add_docstr_all( + "view", + r""" +view(*shape) -> Tensor + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`shape`. + +The returned tensor shares the same data and must have the same number +of elements, but may have a different size. For a tensor to be viewed, the new +view size must be compatible with its original size and stride, i.e., each new +view dimension must either be a subspace of an original dimension, or only span +across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following +contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + +.. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + +Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` +without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a +:meth:`view` can be performed, it is advisable to use :meth:`reshape`, which +returns a view if the shapes are compatible, and copies (equivalent to calling +:meth:`contiguous`) otherwise. + +Args: + shape (torch.Size or int...): the desired size + +Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + +.. 
method:: view(dtype) -> Tensor + :noindex: + +Returns a new tensor with the same data as the :attr:`self` tensor but of a +different :attr:`dtype`. + +If the element size of :attr:`dtype` is different than that of ``self.dtype``, +then the size of the last dimension of the output will be scaled +proportionally. For instance, if :attr:`dtype` element size is twice that of +``self.dtype``, then each pair of elements in the last dimension of +:attr:`self` will be combined, and the size of the last dimension of the output +will be half that of :attr:`self`. If :attr:`dtype` element size is half that +of ``self.dtype``, then each element in the last dimension of :attr:`self` will +be split in two, and the size of the last dimension of the output will be +double that of :attr:`self`. For this to be possible, the following conditions +must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + +Additionally, if the element size of :attr:`dtype` is greater than that of +``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + +If any of the above conditions are not met, an error is thrown. + +.. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + +Args: + dtype (:class:`torch.dtype`): the desired dtype + +Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) +""", +) + +add_docstr_all( + "view_as", + r""" +view_as(other) -> Tensor + +View this tensor as the same size as :attr:`other`. +``self.view_as(other)`` is equivalent to ``self.view(other.size())``. + +Please see :meth:`~Tensor.view` for more information about ``view``. + +Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. 
+
+""",
+)
+
+add_docstr_all(
+    "expand",
+    r"""
+expand(*sizes) -> Tensor
+
+Returns a new view of the :attr:`self` tensor with singleton dimensions expanded
+to a larger size.
+
+Passing -1 as the size for a dimension means not changing the size of
+that dimension.
+
+A tensor can also be expanded to a larger number of dimensions, and the
+new ones will be appended at the front. For the new dimensions, the
+size cannot be set to -1.
+
+Expanding a tensor does not allocate new memory, but only creates a
+new view on the existing tensor where a dimension of size one is
+expanded to a larger size by setting the ``stride`` to 0. Any dimension
+of size 1 can be expanded to an arbitrary value without allocating new
+memory.
+
+Args:
+    *sizes (torch.Size or int...): the desired expanded size
+
+.. warning::
+
+    More than one element of an expanded tensor may refer to a single
+    memory location. As a result, in-place operations (especially ones that
+    are vectorized) may result in incorrect behavior. If you need to write
+    to the tensors, please clone them first.
+
+Example::
+
+    >>> x = torch.tensor([[1], [2], [3]])
+    >>> x.size()
+    torch.Size([3, 1])
+    >>> x.expand(3, 4)
+    tensor([[ 1,  1,  1,  1],
+            [ 2,  2,  2,  2],
+            [ 3,  3,  3,  3]])
+    >>> x.expand(-1, 4)   # -1 means not changing the size of that dimension
+    tensor([[ 1,  1,  1,  1],
+            [ 2,  2,  2,  2],
+            [ 3,  3,  3,  3]])
+""",
+)
+
+add_docstr_all(
+    "expand_as",
+    r"""
+expand_as(other) -> Tensor
+
+Expand this tensor to the same size as :attr:`other`.
+``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``.
+
+Please see :meth:`~Tensor.expand` for more information about ``expand``.
+
+Args:
+    other (:class:`torch.Tensor`): The result tensor has the same size
+        as :attr:`other`.
+""",
+)
+
+add_docstr_all(
+    "sum_to_size",
+    r"""
+sum_to_size(*size) -> Tensor
+
+Sum ``this`` tensor to :attr:`size`.
+:attr:`size` must be broadcastable to ``this`` tensor size.
+
+Args:
+    size (int...): a sequence of integers defining the shape of the output tensor.
+""",
+)
+
+
+add_docstr_all(
+    "zero_",
+    r"""
+zero_() -> Tensor
+
+Fills :attr:`self` tensor with zeros.
+""", +) + +add_docstr_all( + "matmul", + r""" +matmul(tensor2) -> Tensor + +See :func:`torch.matmul` +""", +) + +add_docstr_all( + "chunk", + r""" +chunk(chunks, dim=0) -> List of Tensors + +See :func:`torch.chunk` +""", +) + +add_docstr_all( + "unsafe_chunk", + r""" +unsafe_chunk(chunks, dim=0) -> List of Tensors + +See :func:`torch.unsafe_chunk` +""", +) + +add_docstr_all( + "unsafe_split", + r""" +unsafe_split(split_size, dim=0) -> List of Tensors + +See :func:`torch.unsafe_split` +""", +) + +add_docstr_all( + "tensor_split", + r""" +tensor_split(indices_or_sections, dim=0) -> List of Tensors + +See :func:`torch.tensor_split` +""", +) + +add_docstr_all( + "hsplit", + r""" +hsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.hsplit` +""", +) + +add_docstr_all( + "vsplit", + r""" +vsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.vsplit` +""", +) + +add_docstr_all( + "dsplit", + r""" +dsplit(split_size_or_sections) -> List of Tensors + +See :func:`torch.dsplit` +""", +) + +add_docstr_all( + "stft", + r""" +stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor + +See :func:`torch.stft` +""", +) + +add_docstr_all( + "istft", + r""" +istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor + +See :func:`torch.istft` +""", +) + +add_docstr_all( + "det", + r""" +det() -> Tensor + +See :func:`torch.det` +""", +) + +add_docstr_all( + "where", + r""" +where(condition, y) -> Tensor + +``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. +See :func:`torch.where` +""", +) + +add_docstr_all( + "logdet", + r""" +logdet() -> Tensor + +See :func:`torch.logdet` +""", +) + +add_docstr_all( + "slogdet", + r""" +slogdet() -> (Tensor, Tensor) + +See :func:`torch.slogdet` +""", +) + +add_docstr_all( + "unbind", + r""" +unbind(dim=0) -> seq + +See :func:`torch.unbind` +""", +) + +add_docstr_all( + "pin_memory", + r""" +pin_memory() -> Tensor + +Copies the tensor to pinned memory, if it's not already pinned. +""", +) + +add_docstr_all( + "pinverse", + r""" +pinverse() -> Tensor + +See :func:`torch.pinverse` +""", +) + +add_docstr_all( + "index_add", + r""" +index_add(dim, index, source, *, alpha=1) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_add_`. +""", +) + +add_docstr_all( + "index_copy", + r""" +index_copy(dim, index, tensor2) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_copy_`. +""", +) + +add_docstr_all( + "index_fill", + r""" +index_fill(dim, index, value) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.index_fill_`. +""", +) + +add_docstr_all( + "scatter", + r""" +scatter(dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_` +""", +) + +add_docstr_all( + "scatter_add", + r""" +scatter_add(dim, index, src) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_add_` +""", +) + +add_docstr_all( + "scatter_reduce", + r""" +scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` +""", +) + +add_docstr_all( + "masked_scatter", + r""" +masked_scatter(mask, tensor) -> Tensor + +Out-of-place version of :meth:`torch.Tensor.masked_scatter_` + +.. note:: + + The inputs :attr:`self` and :attr:`mask` + :ref:`broadcast `. 
+
+Example:
+
+    >>> self = torch.tensor([0, 0, 0, 0, 0])
+    >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]])
+    >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
+    >>> self.masked_scatter(mask, source)
+    tensor([[0, 0, 0, 0, 1],
+            [2, 3, 0, 4, 5]])
+
+""",
+)
+
+add_docstr_all(
+    "xlogy",
+    r"""
+xlogy(other) -> Tensor
+
+See :func:`torch.xlogy`
+""",
+)
+
+add_docstr_all(
+    "xlogy_",
+    r"""
+xlogy_(other) -> Tensor
+
+In-place version of :meth:`~Tensor.xlogy`
+""",
+)
+
+add_docstr_all(
+    "masked_fill",
+    r"""
+masked_fill(mask, value) -> Tensor
+
+Out-of-place version of :meth:`torch.Tensor.masked_fill_`
+""",
+)
+
+add_docstr_all(
+    "grad",
+    r"""
+This attribute is ``None`` by default and becomes a Tensor the first time a call to
+:func:`backward` computes gradients for ``self``.
+The attribute will then contain the gradients computed and future calls to
+:func:`backward` will accumulate (add) gradients into it.
+""",
+)
+
+add_docstr_all(
+    "retain_grad",
+    r"""
+retain_grad() -> None
+
+Enables this Tensor to have its :attr:`grad` populated during
+:func:`backward`. This is a no-op for leaf tensors.
+""",
+)
+
+add_docstr_all(
+    "retains_grad",
+    r"""
+Is ``True`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+populated during :func:`backward`, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "requires_grad",
+    r"""
+Is ``True`` if gradients need to be computed for this Tensor, ``False`` otherwise.
+
+.. note::
+
+    The fact that gradients need to be computed for a Tensor does not mean that the :attr:`grad`
+    attribute will be populated; see :attr:`is_leaf` for more details.
+
+""",
+)
+
+add_docstr_all(
+    "is_leaf",
+    r"""
+All Tensors that have :attr:`requires_grad` which is ``False`` will be leaf Tensors by convention.
+
+For Tensors that have :attr:`requires_grad` which is ``True``, they will be leaf Tensors if they were
+created by the user. This means that they are not the result of an operation and so
+:attr:`grad_fn` is None.
+
+Only leaf Tensors will have their :attr:`grad` populated during a call to :func:`backward`.
+To get :attr:`grad` populated for non-leaf Tensors, you can use :func:`retain_grad`.
+
+Example::
+
+    >>> a = torch.rand(10, requires_grad=True)
+    >>> a.is_leaf
+    True
+    >>> b = torch.rand(10, requires_grad=True).cuda()
+    >>> b.is_leaf
+    False
+    # b was created by the operation that cast a cpu Tensor into a cuda Tensor
+    >>> c = torch.rand(10, requires_grad=True) + 2
+    >>> c.is_leaf
+    False
+    # c was created by the addition operation
+    >>> d = torch.rand(10).cuda()
+    >>> d.is_leaf
+    True
+    # d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
+    >>> e = torch.rand(10).cuda().requires_grad_()
+    >>> e.is_leaf
+    True
+    # e requires gradients and has no operations creating it
+    >>> f = torch.rand(10, requires_grad=True, device="cuda")
+    >>> f.is_leaf
+    True
+    # f requires grad, has no operation creating it
+
+
+""",
+)
+
+add_docstr_all(
+    "names",
+    r"""
+Stores names for each of this tensor's dimensions.
+
+``names[idx]`` corresponds to the name of tensor dimension ``idx``.
+Names are either a string if the dimension is named or ``None`` if the
+dimension is unnamed.
+
+Dimension names may contain characters or underscore. Furthermore, a dimension
+name must be a valid Python variable name (i.e., does not start with underscore).
+
+Tensors may not have two named dimensions with the same name.
+
+.. warning::
+    The named tensor API is experimental and subject to change.
+
+""",
+)
+
+add_docstr_all(
+    "is_cuda",
+    r"""
+Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_cpu",
+    r"""
+Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_xla",
+    r"""
+Is ``True`` if the Tensor is stored on an XLA device, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_ipu",
+    r"""
+Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_xpu",
+    r"""
+Is ``True`` if the Tensor is stored on the XPU, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_quantized",
+    r"""
+Is ``True`` if the Tensor is quantized, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_meta",
+    r"""
+Is ``True`` if the Tensor is a meta tensor, ``False`` otherwise. Meta tensors
+are like normal tensors, but they carry no data.
+""",
+)
+
+add_docstr_all(
+    "is_mps",
+    r"""
+Is ``True`` if the Tensor is stored on the MPS device, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_sparse",
+    r"""
+Is ``True`` if the Tensor uses sparse COO storage layout, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "is_sparse_csr",
+    r"""
+Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise.
+""",
+)
+
+add_docstr_all(
+    "device",
+    r"""
+Is the :class:`torch.device` where this Tensor is.
+""",
+)
+
+add_docstr_all(
+    "ndim",
+    r"""
+Alias for :meth:`~Tensor.dim()`
+""",
+)
+
+add_docstr_all(
+    "itemsize",
+    r"""
+Alias for :meth:`~Tensor.element_size()`
+""",
+)
+
+add_docstr_all(
+    "nbytes",
+    r"""
+Returns the number of bytes consumed by the "view" of elements of the Tensor
+if the Tensor does not use sparse storage layout.
+Defined to be :meth:`~Tensor.numel()` * :meth:`~Tensor.element_size()`
+""",
+)
+
+add_docstr_all(
+    "T",
+    r"""
+Returns a view of this tensor with its dimensions reversed.
+
+If ``n`` is the number of dimensions in ``x``,
+``x.T`` is equivalent to ``x.permute(n-1, n-2, ..., 0)``.
+
+.. warning::
+    The use of :func:`Tensor.T` on tensors of dimension other than 2 to reverse their shape
+    is deprecated and it will throw an error in a future release. Consider :attr:`~.Tensor.mT`
+    to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse
+    the dimensions of a tensor.
+""",
+)
+
+add_docstr_all(
+    "H",
+    r"""
+Returns a view of a matrix (2-D tensor) conjugated and transposed.
+
+``x.H`` is equivalent to ``x.transpose(0, 1).conj()`` for complex matrices and
+``x.transpose(0, 1)`` for real matrices.
+
+.. seealso::
+
+        :attr:`~.Tensor.mH`: An attribute that also works on batches of matrices.
+""",
+)
+
+add_docstr_all(
+    "mT",
+    r"""
+Returns a view of this tensor with the last two dimensions transposed.
+
+``x.mT`` is equivalent to ``x.transpose(-2, -1)``.
+""",
+)
+
+add_docstr_all(
+    "mH",
+    r"""
+Accessing this property is equivalent to calling :func:`adjoint`.
+""",
+)
+
+add_docstr_all(
+    "adjoint",
+    r"""
+adjoint() -> Tensor
+
+Alias for :func:`adjoint`
+""",
+)
+
+add_docstr_all(
+    "real",
+    r"""
+Returns a new tensor containing real values of the :attr:`self` tensor for a complex-valued input tensor.
+The returned tensor and :attr:`self` share the same underlying storage.
+
+Returns :attr:`self` if :attr:`self` is a real-valued tensor.
+ +Example:: + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + +""", +) + +add_docstr_all( + "imag", + r""" +Returns a new tensor containing imaginary values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +.. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + +Example:: + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + +""", +) + +add_docstr_all( + "as_subclass", + r""" +as_subclass(cls) -> Tensor + +Makes a ``cls`` instance with the same data pointer as ``self``. Changes +in the output mirror changes in ``self``, and the output stays attached +to the autograd graph. ``cls`` must be a subclass of ``Tensor``. +""", +) + +add_docstr_all( + "crow_indices", + r""" +crow_indices() -> IntTensor + +Returns the tensor containing the compressed row indices of the :attr:`self` +tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. +The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1) +and of type ``int32`` or ``int64``. When using MKL routines such as sparse +matrix multiplication, it is necessary to use ``int32`` indexing in order +to avoid downcasting and potentially losing information. + +Example:: + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.crow_indices() + tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32) + +""", +) + +add_docstr_all( + "col_indices", + r""" +col_indices() -> IntTensor + +Returns the tensor containing the column indices of the :attr:`self` +tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. +The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz()) +and of type ``int32`` or ``int64``. When using MKL routines such as sparse +matrix multiplication, it is necessary to use ``int32`` indexing in order +to avoid downcasting and potentially losing information. + +Example:: + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.col_indices() + tensor([0, 1, 2, 3, 4], dtype=torch.int32) + +""", +) + +add_docstr_all( + "to_padded_tensor", + r""" +to_padded_tensor(padding, output_size=None) -> Tensor +See :func:`to_padded_tensor` +""", +) diff --git a/MLPY/Lib/site-packages/torch/_tensor_str.py b/MLPY/Lib/site-packages/torch/_tensor_str.py new file mode 100644 index 0000000000000000000000000000000000000000..ffed793f56286b58d9a0c1711706738ea5a0d96c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_tensor_str.py @@ -0,0 +1,697 @@ +import contextlib +import dataclasses +import math +import textwrap +from typing import Any, Dict, Optional + +import torch +from torch import inf + + +@dataclasses.dataclass +class __PrinterOptions: + precision: int = 4 + threshold: float = 1000 + edgeitems: int = 3 + linewidth: int = 80 + sci_mode: Optional[bool] = None + + +PRINT_OPTS = __PrinterOptions() + + +# We could use **kwargs, but this will give better docs +def set_printoptions( + precision=None, + threshold=None, + edgeitems=None, + linewidth=None, + profile=None, + sci_mode=None, +): + r"""Set options for printing. Items shamelessly taken from NumPy + + Args: + precision: Number of digits of precision for floating point output + (default = 4). 
+ threshold: Total number of array elements which trigger summarization + rather than full `repr` (default = 1000). + edgeitems: Number of array items in summary at beginning and end of + each dimension (default = 3). + linewidth: The number of characters per line for the purpose of + inserting line breaks (default = 80). Thresholded matrices will + ignore this parameter. + profile: Sane defaults for pretty printing. Can override with any of + the above options. (any one of `default`, `short`, `full`) + sci_mode: Enable (True) or disable (False) scientific notation. If + None (default) is specified, the value is defined by + `torch._tensor_str._Formatter`. This value is automatically chosen + by the framework. + + Example:: + + >>> # Limit the precision of elements + >>> torch.set_printoptions(precision=2) + >>> torch.tensor([1.12345]) + tensor([1.12]) + >>> # Limit the number of elements shown + >>> torch.set_printoptions(threshold=5) + >>> torch.arange(10) + tensor([0, 1, 2, ..., 7, 8, 9]) + >>> # Restore defaults + >>> torch.set_printoptions(profile='default') + >>> torch.tensor([1.12345]) + tensor([1.1235]) + >>> torch.arange(10) + tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + """ + if profile is not None: + if profile == "default": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + elif profile == "short": + PRINT_OPTS.precision = 2 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 2 + PRINT_OPTS.linewidth = 80 + elif profile == "full": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = inf + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + + if precision is not None: + PRINT_OPTS.precision = precision + if threshold is not None: + PRINT_OPTS.threshold = threshold + if edgeitems is not None: + PRINT_OPTS.edgeitems = edgeitems + if linewidth is not None: + PRINT_OPTS.linewidth = linewidth + PRINT_OPTS.sci_mode = sci_mode + + +def get_printoptions() -> Dict[str, Any]: + r"""Gets the current options for printing, as a dictionary that + can be passed as ``**kwargs`` to set_printoptions(). + """ + return dataclasses.asdict(PRINT_OPTS) + + +@contextlib.contextmanager +def printoptions(**kwargs): + r"""Context manager that temporarily changes the print options. Accepted + arguments are same as :func:`set_printoptions`.""" + old_kwargs = get_printoptions() + set_printoptions(**kwargs) + try: + yield + finally: + set_printoptions(**old_kwargs) + + +def tensor_totype(t): + dtype = torch.float if t.is_mps else torch.double + return t.to(dtype=dtype) + + +class _Formatter: + def __init__(self, tensor): + self.floating_dtype = tensor.dtype.is_floating_point + self.int_mode = True + self.sci_mode = False + self.max_width = 1 + + with torch.no_grad(): + tensor_view = tensor.reshape(-1) + + if not self.floating_dtype: + for value in tensor_view: + value_str = f"{value}" + self.max_width = max(self.max_width, len(value_str)) + + else: + nonzero_finite_vals = torch.masked_select( + tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0) + ) + + if nonzero_finite_vals.numel() == 0: + # no valid number, do nothing + return + + # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. 
+ nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs()) + nonzero_finite_min = tensor_totype(nonzero_finite_abs.min()) + nonzero_finite_max = tensor_totype(nonzero_finite_abs.max()) + + for value in nonzero_finite_vals: + if value != torch.ceil(value): + self.int_mode = False + break + + if self.int_mode: + # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites + # to indicate that the tensor is of floating type. add 1 to the len to account for this. + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{value:.0f}" + self.max_width = max(self.max_width, len(value_str) + 1) + else: + # Check if scientific representation should be used. + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + + if PRINT_OPTS.sci_mode is not None: + self.sci_mode = PRINT_OPTS.sci_mode + + def width(self): + return self.max_width + + def format(self, value): + if self.floating_dtype: + if self.sci_mode: + ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value) + elif self.int_mode: + ret = f"{value:.0f}" + if not (math.isinf(value) or math.isnan(value)): + ret += "." 
+ else: + ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + else: + ret = f"{value}" + return (self.max_width - len(ret)) * " " + ret + + +def _scalar_str(self, formatter1, formatter2=None): + if formatter2 is not None: + real_str = _scalar_str(self.real, formatter1) + imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip() + # handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(self.item()) + + +def _vector_str(self, indent, summarize, formatter1, formatter2=None): + # length includes spaces and comma between elements + element_length = formatter1.width() + 2 + if formatter2 is not None: + # width for imag_formatter + an extra j for complex + element_length += formatter2.width() + 1 + + elements_per_line = max( + 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) + ) + + def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): + if formatter2 is not None: + real_str = formatter1.format(val.real) + imag_str = (formatter2.format(val.imag) + "j").lstrip() + # handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(val) + + if summarize and not PRINT_OPTS.edgeitems: + # Deal with edge case that negative zero is zero + data = ["..."] + elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + data = ( + [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()] + + [" ..."] + + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()] + ) + else: + data = [_val_formatter(val) for val in self.tolist()] + + data_lines = [ + data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line) + ] + lines = [", ".join(line) for line in data_lines] + return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]" + + +# formatter2 is only used for printing complex tensors. +# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real +# and tensor.imag respesectively +def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None): + dim = self.dim() + + if dim == 0: + return _scalar_str(self, formatter1, formatter2) + + if dim == 1: + return _vector_str(self, indent, summarize, formatter1, formatter2) + + if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + slices = ( + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, PRINT_OPTS.edgeitems) + ] + + ["..."] + + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self)) + ] + ) + else: + slices = [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, self.size(0)) + ] + + tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices) + return "[" + tensor_str + "]" + + +def _tensor_str(self, indent): + if self.numel() == 0: + return "[]" + + if self.has_names(): + # There are two main codepaths (possibly more) that tensor printing goes through: + # - tensor data can fit comfortably on screen + # - tensor data needs to be summarized + # Some of the codepaths don't fully support named tensors, so we send in + # an unnamed tensor to the formatting code as a workaround. 
+ self = self.rename(None) + + summarize = self.numel() > PRINT_OPTS.threshold + + if self._is_zerotensor(): + self = self.clone() + + # handle the negative bit + if self.is_neg(): + self = self.resolve_neg() + + if self.dtype in [ + torch.float16, + torch.bfloat16, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ]: + self = self.float() + + if self.dtype is torch.complex32: + self = self.cfloat() + + if self.dtype.is_complex: + # handle the conjugate bit + self = self.resolve_conj() + real_formatter = _Formatter( + get_summarized_data(self.real) if summarize else self.real + ) + imag_formatter = _Formatter( + get_summarized_data(self.imag) if summarize else self.imag + ) + return _tensor_str_with_formatter( + self, indent, summarize, real_formatter, imag_formatter + ) + else: + formatter = _Formatter(get_summarized_data(self) if summarize else self) + return _tensor_str_with_formatter(self, indent, summarize, formatter) + + +def _add_suffixes(tensor_str, suffixes, indent, force_newline): + tensor_strs = [tensor_str] + last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1 + for suffix in suffixes: + suffix_len = len(suffix) + if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: + tensor_strs.append(",\n" + " " * indent + suffix) + last_line_len = indent + suffix_len + force_newline = False + else: + tensor_strs.append(", " + suffix) + last_line_len += suffix_len + 2 + tensor_strs.append(")") + return "".join(tensor_strs) + + +def get_summarized_data(self): + dim = self.dim() + if dim == 0: + return self + if dim == 1: + if self.size(0) > 2 * PRINT_OPTS.edgeitems: + return torch.cat( + (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) + ) + else: + return self + if not PRINT_OPTS.edgeitems: + return self.new_empty([0] * self.dim()) + elif self.size(0) > 2 * PRINT_OPTS.edgeitems: + start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] + end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))] + return torch.stack([get_summarized_data(x) for x in (start + end)]) + else: + return torch.stack([get_summarized_data(x) for x in self]) + + +def _str_intern(inp, *, tensor_contents=None): + if torch._C._functorch.is_functorch_wrapped_tensor(inp): + return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents) + is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter + if inp.is_nested: + prefix = "nested_tensor(" + elif is_plain_tensor: + prefix = "tensor(" + else: + prefix = f"{type(inp).__name__}(" + indent = len(prefix) + suffixes = [] + custom_contents_provided = tensor_contents is not None + if custom_contents_provided: + tensor_str = tensor_contents + + # This is used to extract the primal value and thus disable the forward AD + # within this function. + # TODO(albanD) This needs to be updated when more than one level is supported + self, tangent = torch.autograd.forward_ad.unpack_dual(inp) + + # Note [Print tensor device]: + # A general logic here is we only print device when it doesn't match + # the device specified in default tensor type. + # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus + # torch._C._get_default_device() only returns either cpu or cuda. + # In other cases, we don't have a way to set them as default yet, + # and we should always print out device for them. 
+ if ( + self.device.type != torch._C._get_default_device() + or ( + self.device.type == "cuda" + and torch.cuda.current_device() != self.device.index + ) + or (self.device.type == "mps") + ): + suffixes.append("device='" + str(self.device) + "'") + + # Tensor printing performs tensor operations like slice, indexing, etc to make it in a + # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence, + # to avoid compilations, copying the tensor to cpu before printing. + if self.device.type in ["xla", "lazy", "ipu", "mtia"]: + self = self.to("cpu") + + # TODO: add an API to map real -> complex dtypes + _default_complex_dtype = ( + torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat + ) + has_default_dtype = self.dtype in ( + torch.get_default_dtype(), + _default_complex_dtype, + torch.int64, + torch.bool, + ) + if self.is_sparse: + suffixes.append("size=" + str(tuple(self.shape))) + from torch._subclasses.fake_tensor import FakeTensor + + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + indices_prefix = "indices=tensor(" + indices = self._indices().detach() + if is_meta: + indices_str = "..." + else: + indices_str = _tensor_str(indices, indent + len(indices_prefix)) + if indices.numel() == 0 or is_meta: + indices_str += ", size=" + str(tuple(indices.shape)) + values_prefix = "values=tensor(" + values = self._values().detach() + if is_meta: + values_str = "..." + else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0 or is_meta: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + indices_prefix + + indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + from torch._subclasses.fake_tensor import FakeTensor + + suffixes.append("size=" + str(tuple(self.shape))) + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + compressed_indices_method, plain_indices_method = { + torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + }[self.layout] + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + cdimname, pdimname = "row", "column" + else: + cdimname, pdimname = "column", "row" + compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor(" + compressed_indices = compressed_indices_method(self).detach() + if is_meta: + compressed_indices_str = "..." + else: + compressed_indices_str = _tensor_str( + compressed_indices, indent + len(compressed_indices_prefix) + ) + if compressed_indices.numel() == 0 or is_meta: + compressed_indices_str += ", size=" + str( + tuple(compressed_indices.shape) + ) + plain_indices_prefix = f"{pdimname[:3]}_indices=tensor(" + plain_indices = plain_indices_method(self).detach() + if is_meta: + plain_indices_str = "..." 
+ else: + plain_indices_str = _tensor_str( + plain_indices, indent + len(plain_indices_prefix) + ) + if plain_indices.numel() == 0 or is_meta: + plain_indices_str += ", size=" + str(tuple(plain_indices.shape)) + values_prefix = "values=tensor(" + values = self.values().detach() + if is_meta: + values_str = "..." + else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0 or is_meta: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + compressed_indices_prefix + + compressed_indices_str + + "),\n" + + " " * indent + + plain_indices_prefix + + plain_indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.is_quantized: + suffixes.append("size=" + str(tuple(self.shape))) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + suffixes.append("quantization_scheme=" + str(self.qscheme())) + if ( + self.qscheme() == torch.per_tensor_affine + or self.qscheme() == torch.per_tensor_symmetric + ): + suffixes.append("scale=" + str(self.q_scale())) + suffixes.append("zero_point=" + str(self.q_zero_point())) + elif ( + self.qscheme() == torch.per_channel_affine + or self.qscheme() == torch.per_channel_symmetric + or self.qscheme() == torch.per_channel_affine_float_qparams + ): + suffixes.append("scale=" + str(self.q_per_channel_scales())) + suffixes.append("zero_point=" + str(self.q_per_channel_zero_points())) + suffixes.append("axis=" + str(self.q_per_channel_axis())) + if not custom_contents_provided: + tensor_str = _tensor_str(self.dequantize(), indent) + elif self.is_nested: + if not custom_contents_provided: + + def indented_str(s, indent): + return "\n".join(f" {line}" for line in s.split("\n")) + + strs = ",\n".join( + indented_str(str(t), indent + 1) + for t in torch.ops.aten.unbind.int(self, 0) + ) + tensor_str = f"[\n{strs}\n]" + elif torch._is_functional_tensor(self): + prefix = "_to_functional_tensor(" + tensor_str = repr(torch._from_functional_tensor(self)) + else: + # Circular import problem, so we import it here + from torch._subclasses.fake_tensor import FakeTensor + + if self.is_meta or isinstance(self, FakeTensor): + suffixes.append("size=" + str(tuple(self.shape))) + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + # TODO: This implies that ellipses is valid syntax for allocating + # a meta tensor or FakeTensor, which it could be, but it isn't right now + if not custom_contents_provided: + tensor_str = "..." + else: + if self.numel() == 0 and not self.is_sparse: + # Explicitly print the shape if it is not (0,), to match NumPy behavior + if self.dim() != 1: + suffixes.append("size=" + str(tuple(self.shape))) + + # In an empty tensor, there are no elements to infer if the dtype + # should be int64, so it must be shown explicitly. + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + tensor_str = "[]" + else: + if not PRINT_OPTS.edgeitems: + suffixes.append("size=" + str(tuple(self.shape))) + + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + + if not custom_contents_provided: + if self.layout != torch.strided: + tensor_str = _tensor_str(self.to_dense(), indent) + else: + tensor_str = _tensor_str(self, indent) + + if self.layout != torch.strided: + suffixes.append("layout=" + str(self.layout)) + + # Use inp here to get the original grad_fn and not the one generated by the forward grad + # unpacking. 
+ grad_fn_name = None + try: + grad_fn = inp.grad_fn + except RuntimeError: + # Accessing the grad_fn calls rebasing logic which would cause an error + # if that tensor is a view created in no-grad mode modified in-place in + # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968 + grad_fn_name = "Invalid" + + if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined] + grad_fn_name = type(grad_fn).__name__ + if grad_fn_name == "CppFunction": + grad_fn_name = grad_fn.name().rsplit("::", 1)[-1] + + if grad_fn_name is not None: + suffixes.append(f"grad_fn=<{grad_fn_name}>") + elif inp.requires_grad: + suffixes.append("requires_grad=True") + + if self.has_names(): + suffixes.append(f"names={self.names}") + + if tangent is not None: + suffixes.append(f"tangent={tangent}") + + string_repr = _add_suffixes( + prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse # type: ignore[possibly-undefined] + ) + + # Check if this instance is flagged as a parameter and change the repr accordingly. + # Unfortunately, this function has to be aware of this detail. + # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future, + # this should be done for those as well to produce a valid repr. + if isinstance(self, torch.nn.Parameter) and not is_plain_tensor: + string_repr = f"Parameter({string_repr})" + + return string_repr + + +def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None): + level = torch._C._functorch.maybe_get_level(tensor) + assert level != -1 + + if torch._C._functorch.is_functionaltensor(tensor): + # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure + # that it's up to date first + torch._sync(tensor) + + value = torch._C._functorch.get_unwrapped(tensor) + value_repr = repr(value) + + indented_value_repr = textwrap.indent(value_repr, " " * 4) + if torch._C._functorch.is_batchedtensor(tensor): + bdim = torch._C._functorch.maybe_get_bdim(tensor) + assert bdim != -1 + return ( + f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n" + f"{indented_value_repr}\n" + f")" + ) + if torch._C._functorch.is_gradtrackingtensor(tensor): + return ( + f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")" + ) + if torch._C._functorch.is_functionaltensor(tensor): + return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})" + + raise ValueError("We don't know how to print this, please file us an issue") + + +def _str(self, *, tensor_contents=None): + with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes(): + guard = torch._C._DisableFuncTorch() + return _str_intern(self, tensor_contents=tensor_contents) diff --git a/MLPY/Lib/site-packages/torch/_torch_docs.py b/MLPY/Lib/site-packages/torch/_torch_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..a95c21974800ca3f889a6850d97e3e71d2856edb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_torch_docs.py @@ -0,0 +1,14192 @@ +"""Adds docstrings to functions defined in the torch._C module.""" + +import re + +import torch._C +from torch._C import _add_docstr as add_docstr + + +def parse_kwargs(desc): + r"""Map a description of args to a dictionary of {argname: description}. 
+ + Input: + (' weight (Tensor): a weight tensor\n' + + ' Some optional description') + Output: { + 'weight': \ + 'weight (Tensor): a weight tensor\n Some optional description' + } + """ + # Split on exactly 4 spaces after a newline + regx = re.compile(r"\n\s{4}(?!\s)") + kwargs = [section.strip() for section in regx.split(desc)] + kwargs = [section for section in kwargs if len(section) > 0] + return {desc.split(" ")[0]: desc for desc in kwargs} + + +def merge_dicts(*dicts): + """Merge dictionaries into a single dictionary.""" + return {x: d[x] for d in dicts for x in d} + + +common_args = parse_kwargs( + """ + input (Tensor): the input tensor. + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. +""" +) + +reduceops_common_args = merge_dicts( + common_args, + parse_kwargs( + """ + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + keepdim (bool): whether the output tensor has :attr:`dim` retained or not. +""" + ), +) + +multi_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ + dim (int or tuple of ints): the dimension or dimensions to reduce. +""" + ), + { + "keepdim_details": """ +If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 (or ``len(dim)``) fewer dimension(s). +""" + }, + { + "opt_dim": """ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. +""" + }, +) + +single_dim_common = merge_dicts( + reduceops_common_args, + parse_kwargs( + """ + dim (int): the dimension to reduce. +""" + ), + { + "keepdim_details": """If :attr:`keepdim` is ``True``, the output tensor is of the same size +as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the output tensor having 1 fewer dimension than :attr:`input`.""" + }, +) + +factory_common_args = merge_dicts( + common_args, + parse_kwargs( + """ + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. 
+ check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. +""" + ), + { + "sparse_factory_device_note": """\ +.. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor.""" + }, +) + +factory_like_common_args = parse_kwargs( + """ + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. +""" +) + +factory_data_common_args = parse_kwargs( + """ + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. +""" +) + +tf32_notes = { + "tf32_note": """This operator supports :ref:`TensorFloat32`.""" +} + +rocm_fp16_notes = { + "rocm_fp16_note": """On certain ROCm devices, when using float16 inputs this module will use \ +:ref:`different precision` for backward.""" +} + +reproducibility_notes = { + "forward_reproducibility_note": """This operation may behave nondeterministically when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "backward_reproducibility_note": """This operation may produce nondeterministic gradients when given tensors on \ +a CUDA device. See :doc:`/notes/randomness` for more information.""", + "cudnn_reproducibility_note": """In some circumstances when given tensors on a CUDA device \ +and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is \ +undesirable, you can try to make the operation deterministic (potentially at \ +a performance cost) by setting ``torch.backends.cudnn.deterministic = True``. 
\ +See :doc:`/notes/randomness` for more information.""", +} + +sparse_support_notes = { + "sparse_beta_warning": """ +.. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request.""", +} + +add_docstr( + torch.abs, + r""" +abs(input, *, out=None) -> Tensor + +Computes the absolute value of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = |\text{input}_{i}| +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) +""".format( + **common_args + ), +) + +add_docstr( + torch.absolute, + r""" +absolute(input, *, out=None) -> Tensor + +Alias for :func:`torch.abs` +""", +) + +add_docstr( + torch.acos, + r""" +acos(input, *, out=None) -> Tensor + +Computes the inverse cosine of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arccos, + r""" +arccos(input, *, out=None) -> Tensor + +Alias for :func:`torch.acos`. +""", +) + +add_docstr( + torch.acosh, + r""" +acosh(input, *, out=None) -> Tensor + +Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + +Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arccosh, + r""" +arccosh(input, *, out=None) -> Tensor + +Alias for :func:`torch.acosh`. +""", +) + +add_docstr( + torch.index_add, + r""" +index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor + +See :meth:`~Tensor.index_add_` for function description. +""", +) + +add_docstr( + torch.index_copy, + r""" +index_copy(input, dim, index, source, *, out=None) -> Tensor + +See :meth:`~Tensor.index_add_` for function description. +""", +) + +add_docstr( + torch.index_reduce, + r""" +index_reduce(input, dim, index, source, reduce, *, include_self=True, out=None) -> Tensor + +See :meth:`~Tensor.index_reduce_` for function description. +""", +) + +add_docstr( + torch.add, + r""" +add(input, other, *, alpha=1, out=None) -> Tensor + +Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + +.. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number): the tensor or number to add to :attr:`input`. + +Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. 
+ {out} + +Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.addbmm, + r""" +addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices stored +in :attr:`batch1` and :attr:`batch2`, +with a reduced add step (all matrix multiplications get accumulated +along the first dimension). +:attr:`input` is added to the final result. + +:attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the +same number of matrices. + +If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a +:math:`(b \times m \times p)` tensor, :attr:`input` must be +:ref:`broadcastable ` with a :math:`(n \times p)` tensor +and :attr:`out` will be a :math:`(n \times p)` tensor. + +.. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` +must be real numbers, otherwise they should be integers. + +{tf32_note} + +{rocm_fp16_note} + +Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + input (Tensor): matrix to be added + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.addcdiv, + r""" +addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + +Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, +multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + +.. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + +.. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} +""" + + r""" + +The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be +:ref:`broadcastable `. + +For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be +a real number, otherwise an integer. 
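+
+For example, with float inputs the result is simply ``input + value * tensor1 / tensor2``
+computed elementwise; an illustrative, hand-checkable case::
+
+    >>> torch.addcdiv(torch.tensor([1.0]), torch.tensor([6.0]), torch.tensor([3.0]), value=0.5)
+    tensor([2.])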
+ +Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + +Keyword args: + value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}` + {out} + +Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.addcmul, + r""" +addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + +Performs the element-wise multiplication of :attr:`tensor1` +by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` +and adds it to :attr:`input`. + +.. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i +""" + + r""" +The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be +:ref:`broadcastable `. + +For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be +a real number, otherwise an integer. + +Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + +Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + {out} + +Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.addmm, + r""" +addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. +The matrix :attr:`input` is added to the final result. + +If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a +:math:`(m \times p)` tensor, then :attr:`input` must be +:ref:`broadcastable ` with a :math:`(n \times p)` tensor +and :attr:`out` will be a :math:`(n \times p)` tensor. + +:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between +:attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + +.. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. + +This operation has support for arguments with :ref:`sparse layouts`. If +:attr:`input` is sparse the result will have the same layout and if :attr:`out` +is provided it must have the same layout as :attr:`input`. 
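+
+For strided (dense) inputs the call reduces to ``beta * input + alpha * (mat1 @ mat2)``;
+an illustrative, hand-checkable case::
+
+    >>> torch.addmm(torch.ones(2, 2), torch.eye(2), torch.eye(2), alpha=2)
+    tensor([[3., 1.],
+            [1., 3.]])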
+ +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes + ), +) + +add_docstr( + torch.adjoint, + r""" +adjoint(Tensor) -> Tensor +Returns a view of the tensor conjugated and with the last two dimensions transposed. + +``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and +to ``x.transpose(-2, -1)`` for real tensors. + +Example:: + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) +""", +) + +add_docstr( + torch.sspaddmm, + r""" +sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + +Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor +:attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + +Note: This function is equivalent to :func:`torch.addmm`, except +:attr:`input` and :attr:`mat1` are sparse. + +Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + {out} +""".format( + **common_args + ), +) + +add_docstr( + torch.smm, + r""" +smm(input, mat) -> Tensor + +Performs a matrix multiplication of the sparse matrix :attr:`input` +with the dense matrix :attr:`mat`. + +Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied +""", +) + +add_docstr( + torch.addmv, + r""" +addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a matrix-vector product of the matrix :attr:`mat` and +the vector :attr:`vec`. +The vector :attr:`input` is added to the final result. + +If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of +size `m`, then :attr:`input` must be +:ref:`broadcastable ` with a 1-D tensor of size `n` and +:attr:`out` will be 1-D tensor of size `n`. + +:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between +:attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + +.. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. 
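+
+As an illustrative, hand-checkable case with the defaults ``beta=1, alpha=1``,
+the call below computes ``input + mat @ vec``::
+
+    >>> torch.addmv(torch.zeros(2), torch.eye(2), torch.tensor([1., 2.]))
+    tensor([1., 2.])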
+ +Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) +""".format( + **common_args + ), +) + +add_docstr( + torch.addr, + r""" +addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` +and adds it to the matrix :attr:`input`. + +Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the +outer product between :attr:`vec1` and :attr:`vec2` and the added matrix +:attr:`input` respectively. + +.. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector +of size `m`, then :attr:`input` must be +:ref:`broadcastable ` with a matrix of size +:math:`(n \times m)` and :attr:`out` will be a matrix of size +:math:`(n \times m)`. + +Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{vec1}} \otimes \text{{vec2}}` (:math:`\alpha`) + {out} + +Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.allclose, + r""" +allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool + +This function checks if :attr:`input` and :attr:`other` satisfy the condition: + +.. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert +""" + + r""" +elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to +`numpy.allclose `_ + +Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` + +Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True +""", +) + +add_docstr( + torch.all, + r""" +all(input) -> Tensor + +Tests if all elements in :attr:`input` evaluate to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. 
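+
+For numeric inputs, any nonzero element evaluates to `True` and zeros evaluate to
+`False` (illustrative)::
+
+    >>> torch.all(torch.tensor([1.0, 0.5]))
+    tensor(True)
+    >>> torch.all(torch.tensor([1.0, 0.0]))
+    tensor(False)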
+ +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + +.. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.any, + r""" +any(input) -> Tensor + +Tests if any element in :attr:`input` evaluates to `True`. + +.. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + +Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + +.. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + +For each row of :attr:`input` in the given dimension :attr:`dim`, +returns `True` if any element in the row evaluate to `True` and `False` otherwise. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.angle, + r""" +angle(input, *, out=None) -> Tensor + +Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + +.. math:: + \text{out}_{i} = angle(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +.. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + +Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) +""".format( + **common_args + ), +) + +add_docstr( + torch.as_strided, + r""" +as_strided(input, size, stride, storage_offset=None) -> Tensor + +Create a view of an existing `torch.Tensor` :attr:`input` with specified +:attr:`size`, :attr:`stride` and :attr:`storage_offset`. + +.. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.expand`, + to setting a view's strides manually with `as_strided`, as this + function's behavior depends on the implementation of a tensor's storage. + The constructed view of the storage must only refer to elements within + the storage or a runtime error will be thrown, and if the view is + "overlapped" (with multiple indices referring to the same element in + memory) its behavior is undefined. 
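+
+As an illustrative sketch of how :attr:`size` and :attr:`stride` index the underlying
+storage (element ``[i][j]`` reads offset ``storage_offset + i*stride[0] + j*stride[1]``)::
+
+    >>> base = torch.arange(6.)
+    >>> torch.as_strided(base, (2, 2), (3, 1))
+    tensor([[0., 1.],
+            [3., 4.]])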
+ +Args: + {input} + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + +Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.as_tensor, + r""" +as_tensor(data, dtype=None, device=None) -> Tensor + +Converts :attr:`data` into a tensor, sharing data and preserving autograd +history if possible. + +If :attr:`data` is already a tensor with the requested dtype and device +then :attr:`data` itself is returned, but if :attr:`data` is a +tensor with a different dtype or device then it's copied as if using +`data.to(dtype=dtype, device=device)`. + +If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a +tensor is constructed using :func:`torch.from_numpy`. + +.. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + +Args: + {data} + {dtype} + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + + +Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) +""".format( + **factory_data_common_args + ), +) + +add_docstr( + torch.asin, + r""" +asin(input, *, out=None) -> Tensor + +Returns a new tensor with the arcsine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arcsin, + r""" +arcsin(input, *, out=None) -> Tensor + +Alias for :func:`torch.asin`. +""", +) + +add_docstr( + torch.asinh, + r""" +asinh(input, *, out=None) -> Tensor + +Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arcsinh, + r""" +arcsinh(input, *, out=None) -> Tensor + +Alias for :func:`torch.asinh`. +""", +) + +add_docstr( + torch.atan, + r""" +atan(input, *, out=None) -> Tensor + +Returns a new tensor with the arctangent of the elements of :attr:`input`. + +.. 
math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arctan, + r""" +arctan(input, *, out=None) -> Tensor + +Alias for :func:`torch.atan`. +""", +) + +add_docstr( + torch.atan2, + r""" +atan2(input, other, *, out=None) -> Tensor + +Element-wise arctangent of :math:`\text{{input}}_{{i}} / \text{{other}}_{{i}}` +with consideration of the quadrant. Returns a new tensor with the signed angles +in radians between vector :math:`(\text{{other}}_{{i}}, \text{{input}}_{{i}})` +and vector :math:`(1, 0)`. (Note that :math:`\text{{other}}_{{i}}`, the second +parameter, is the x-coordinate, while :math:`\text{{input}}_{{i}}`, the first +parameter, is the y-coordinate.) + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. + +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arctan2, + r""" +arctan2(input, other, *, out=None) -> Tensor +Alias for :func:`torch.atan2`. +""", +) + +add_docstr( + torch.atanh, + r""" +atanh(input, *, out=None) -> Tensor + +Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + +Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + +.. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) +""".format( + **common_args + ), +) + +add_docstr( + torch.arctanh, + r""" +arctanh(input, *, out=None) -> Tensor + +Alias for :func:`torch.atanh`. +""", +) + +add_docstr( + torch.asarray, + r""" +asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor + +Converts :attr:`obj` to a tensor. + +:attr:`obj` can be one of: + +1. a tensor +2. a NumPy array or a NumPy scalar +3. a DLPack capsule +4. an object that implements Python's buffer protocol +5. a scalar +6. a sequence of scalars + +When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, +by default, not require a gradient, have the same datatype as :attr:`obj`, be on the +same device, and share memory with it. These properties can be controlled with the +:attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. +If the returned tensor is of a different datatype, on a different device, or a copy is +requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` +is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is +also a tensor with an autograd history then the returned tensor will have the same history. 
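+
+For instance, requesting a different datatype forces a copy (illustrative)::
+
+    >>> a = torch.tensor([1, 2, 3])
+    >>> b = torch.asarray(a, dtype=torch.float32)
+    >>> b
+    tensor([1., 2., 3.])
+    >>> a.data_ptr() == b.data_ptr()
+    False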
+ +When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's +buffer protocol then the buffer is interpreted as an array of bytes grouped according to +the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is +passed then the default floating point datatype is used, instead.) The returned tensor +will have the specified datatype (or default floating point datatype if none is specified) +and, by default, be on the CPU device and share memory with the buffer. + +When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on +the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will +be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + +When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the +returned tensor will, by default, infer its datatype from the scalar values, be on the +current default device, and not share its memory. + +.. seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + +Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. + Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. 
+ +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) +""", +) + +add_docstr( + torch.baddbmm, + r""" +baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices in :attr:`batch1` +and :attr:`batch2`. +:attr:`input` is added to the final result. + +:attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same +number of matrices. + +If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a +:math:`(b \times m \times p)` tensor, then :attr:`input` must be +:ref:`broadcastable ` with a +:math:`(b \times n \times p)` tensor and :attr:`out` will be a +:math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the +same as the scaling factors used in :meth:`torch.addbmm`. + +.. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. +""" + + r""" +For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and +:attr:`alpha` must be real numbers, otherwise they should be integers. + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + +Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{batch1}} \mathbin{{@}} \text{{batch2}}` (:math:`\alpha`) + {out} + +Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.bernoulli, + r""" +bernoulli(input, *, generator=None, out=None) -> Tensor + +Draws binary random numbers (0 or 1) from a Bernoulli distribution. + +The :attr:`input` tensor should be a tensor containing probabilities +to be used for drawing the binary random number. +Hence, all values in :attr:`input` have to be in the range: +:math:`0 \leq \text{input}_i \leq 1`. + +The :math:`\text{i}^{th}` element of the output tensor will draw a +value :math:`1` according to the :math:`\text{i}^{th}` probability value given +in :attr:`input`. + +.. 
math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) +""" + + r""" +The returned :attr:`out` tensor only has values 0 or 1 and is of the same +shape as :attr:`input`. + +:attr:`out` can have integral ``dtype``, but :attr:`input` must have floating +point ``dtype``. + +Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + +Keyword args: + {generator} + {out} + +Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.bincount, + r""" +bincount(input, weights=None, minlength=0) -> Tensor + +Count the frequency of each value in an array of non-negative ints. + +The number of bins (size 1) is one larger than the largest value in +:attr:`input` unless :attr:`input` is empty, in which case the result is a +tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least +:attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size +:attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, +``out[n] += weights[i]`` if :attr:`weights` is specified else +``out[n] += 1``. + +Note: + {backward_reproducibility_note} + +Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. + +Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + +Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) +""".format( + **reproducibility_notes + ), +) + +add_docstr( + torch.bitwise_not, + r""" +bitwise_not(input, *, out=None) -> Tensor + +Computes the bitwise NOT of the given input tensor. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical NOT. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) +""".format( + **common_args + ), +) + +add_docstr( + torch.bmm, + r""" +bmm(input, mat2, *, out=None) -> Tensor + +Performs a batch matrix-matrix product of matrices stored in :attr:`input` +and :attr:`mat2`. + +:attr:`input` and :attr:`mat2` must be 3-D tensors each containing +the same number of matrices. + +If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a +:math:`(b \times m \times p)` tensor, :attr:`out` will be a +:math:`(b \times n \times p)` tensor. + +.. 
math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i +""" + + r""" +{tf32_note} + +{rocm_fp16_note} + +.. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + +Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + +Keyword Args: + {out} + +Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes + ), +) + +add_docstr( + torch.bitwise_and, + r""" +bitwise_and(input, other, *, out=None) -> Tensor + +Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical AND. + +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_or, + r""" +bitwise_or(input, other, *, out=None) -> Tensor + +Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical OR. + +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_xor, + r""" +bitwise_xor(input, other, *, out=None) -> Tensor + +Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of +integral or Boolean types. For bool tensors, it computes the logical XOR. + +Args: + input: the first input tensor + other: the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_left_shift, + r""" +bitwise_left_shift(input, other, *, out=None) -> Tensor + +Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. +The input tensor must be of integral type. This operator supports +:ref:`broadcasting to a common shape ` and +:ref:`type promotion `. + +The operation applied is: + +.. 
math:: + \text{{out}}_i = \text{{input}}_i << \text{{other}}_i + +Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) +""".format( + **common_args + ), +) + +add_docstr( + torch.bitwise_right_shift, + r""" +bitwise_right_shift(input, other, *, out=None) -> Tensor + +Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. +The input tensor must be of integral type. This operator supports +:ref:`broadcasting to a common shape ` and +:ref:`type promotion `. +In any case, if the value of the right operand is negative or is greater +or equal to the number of bits in the promoted left operand, the behavior is undefined. + +The operation applied is: + +.. math:: + \text{{out}}_i = \text{{input}}_i >> \text{{other}}_i + +Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) +""".format( + **common_args + ), +) + +add_docstr( + torch.broadcast_to, + r""" +broadcast_to(input, shape) -> Tensor + +Broadcasts :attr:`input` to the shape :attr:`\shape`. +Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + +Args: + {input} + shape (list, tuple, or :class:`torch.Size`): the new shape. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.stack, + r""" +stack(tensors, dim=0, *, out=None) -> Tensor + +Concatenates a sequence of tensors along a new dimension. + +All tensors need to be of the same size. + +.. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. + +Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> x = torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + >>> x + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x.size() + torch.Size([2, 2, 3]) + >>> x = torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> x = torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> x = torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.hstack, + r""" +hstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence horizontally (column wise). 
+ +This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.vstack, + r""" +vstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence vertically (row wise). + +This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + + +""".format( + **common_args + ), +) + +add_docstr( + torch.dstack, + r""" +dstack(tensors, *, out=None) -> Tensor + +Stack tensors in sequence depthwise (along third axis). + +This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + + +""".format( + **common_args + ), +) + +add_docstr( + torch.tensor_split, + r""" +tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + +Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, +along dimension :attr:`dim` according to the indices or number of sections specified +by :attr:`indices_or_sections`. This function is based on NumPy's +:func:`numpy.array_split`. + +Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + +Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) +""", +) + +add_docstr( + torch.chunk, + r""" +chunk(input, chunks, dim=0) -> List of Tensors + +Attempts to split a tensor into the specified number of chunks. Each chunk is a view of +the input tensor. + + +.. note:: + + This function may return fewer than the specified number of chunks! + +.. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + +If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, +all returned chunks will be the same size. +If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, +all returned chunks will be the same size, except the last one. +If such division is not possible, this function may return fewer +than the specified number of chunks. + +Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + +Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) +""", +) + +add_docstr( + torch.unsafe_chunk, + r""" +unsafe_chunk(input, chunks, dim=0) -> List of Tensors + +Works like :func:`torch.chunk` but without enforcing the autograd restrictions +on inplace modification of the outputs. + +.. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. +""", +) + +add_docstr( + torch.unsafe_split, + r""" +unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + +Works like :func:`torch.split` but without enforcing the autograd restrictions +on inplace modification of the outputs. + +.. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. +""", +) + +add_docstr( + torch.hsplit, + r""" +hsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors +horizontally according to :attr:`indices_or_sections`. 
Each split is a view of +:attr:`input`. + +If :attr:`input` is one dimensional this is equivalent to calling +torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is +zero), and if :attr:`input` has two or more dimensions it's equivalent to calling +torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), +except that if :attr:`indices_or_sections` is an integer it must evenly divide +the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.hsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + +Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + +""", +) + +add_docstr( + torch.vsplit, + r""" +vsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors +vertically according to :attr:`indices_or_sections`. Each split is a view of +:attr:`input`. + +This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) +(the split dimension is 0), except that if :attr:`indices_or_sections` is an integer +it must evenly divide the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.vsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + +Example:: + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + +""", +) + +add_docstr( + torch.dsplit, + r""" +dsplit(input, indices_or_sections) -> List of Tensors + +Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors +depthwise according to :attr:`indices_or_sections`. Each split is a view of +:attr:`input`. + +This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) +(the split dimension is 2), except that if :attr:`indices_or_sections` is an integer +it must evenly divide the split dimension or a runtime error will be thrown. + +This function is based on NumPy's :func:`numpy.dsplit`. + +Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. 
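
The stated equivalence with :func:`torch.tensor_split` along ``dim=2`` can be
checked directly (an illustrative sketch only, assuming a 3-D input ``t``):

.. code:: python

    t = torch.arange(16.0).reshape(2, 2, 4)
    left = torch.dsplit(t, 2)
    right = torch.tensor_split(t, 2, dim=2)
    assert all(torch.equal(a, b) for a, b in zip(left, right))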
+ +Example:: + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + +""", +) + +add_docstr( + torch.can_cast, + r""" +can_cast(from, to) -> bool + +Determines if a type conversion is allowed under PyTorch casting rules +described in the type promotion :ref:`documentation `. + +Args: + from (dtype): The original :class:`torch.dtype`. + to (dtype): The target :class:`torch.dtype`. + +Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False +""", +) + +add_docstr( + torch.corrcoef, + r""" +corrcoef(input) -> Tensor + +Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, +where rows are the variables and columns are the observations. + +.. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + +.. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + +Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + +Returns: + (Tensor) The correlation coefficient matrix of the variables. + +.. seealso:: + + :func:`torch.cov` covariance matrix. + +Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) +""", +) + +add_docstr( + torch.cov, + r""" +cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + +Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are +the variables and columns are the observations. + +A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains +the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents +a single variable (Scalar or 1D) then its variance is returned. + +The sample covariance of the variables :math:`x` and :math:`y` is given by: + +.. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + +where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and +:math:`\delta N` is the :attr:`correction`. + +If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance +is calculated, which is given by: + +.. 
math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + +where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is +provided, or :math:`w = f \times a` if both are provided, and +:math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not +provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + +Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + +Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. + These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + +Returns: + (Tensor) The covariance matrix of the variables. + +.. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. + +Example:: + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) +""", +) + +add_docstr( + torch.cat, + r""" +cat(tensors, dim=0, *, out=None) -> Tensor + +Concatenates the given sequence of :attr:`seq` tensors in the given dimension. +All tensors must either have the same shape (except in the concatenating +dimension) or be a 1-D empty tensor with size ``(0,)``. + +:func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` +and :func:`torch.chunk`. + +:func:`torch.cat` can be best understood via examples. + +.. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + +Args: + tensors (sequence of Tensors): any python sequence of tensors of the same type. + Non-empty tensors provided must have the same shape, except in the + cat dimension. 
+ dim (int, optional): the dimension over which the tensors are concatenated + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.concat, + r""" +concat(tensors, dim=0, *, out=None) -> Tensor + +Alias of :func:`torch.cat`. +""", +) + +add_docstr( + torch.concatenate, + r""" +concatenate(tensors, axis=0, out=None) -> Tensor + +Alias of :func:`torch.cat`. +""", +) + +add_docstr( + torch.ceil, + r""" +ceil(input, *, out=None) -> Tensor + +Returns a new tensor with the ceil of the elements of :attr:`input`, +the smallest integer greater than or equal to each element. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +.. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.real, + r""" +real(input) -> Tensor + +Returns a new tensor containing real values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.imag, + r""" +imag(input) -> Tensor + +Returns a new tensor containing imaginary values of the :attr:`self` tensor. +The returned tensor and :attr:`self` share the same underlying storage. + +.. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.view_as_real, + r""" +view_as_real(input) -> Tensor + +Returns a view of :attr:`input` as a real tensor. For an input complex tensor of +:attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new +real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 +represents the real and imaginary components of complex numbers. + +.. warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.view_as_complex, + r""" +view_as_complex(input) -> Tensor + +Returns a view of :attr:`input` as a complex tensor. 
For an input complex +tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a +new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last +dimension of the input tensor is expected to represent the real and imaginary +components of complex numbers. + +.. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + +Args: + {input} + +Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) +""".format( + **common_args + ), +) + +add_docstr( + torch.reciprocal, + r""" +reciprocal(input, *, out=None) -> Tensor + +Returns a new tensor with the reciprocal of the elements of :attr:`input` + +.. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + +.. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) +""".format( + **common_args + ), +) + +add_docstr( + torch.cholesky, + r""" +cholesky(input, upper=False, *, out=None) -> Tensor + +Computes the Cholesky decomposition of a symmetric positive-definite +matrix :math:`A` or for batches of symmetric positive-definite matrices. + +If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and +the decomposition has the form: + +.. math:: + + A = U^TU + +If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and +the decomposition has the form: + +.. math:: + + A = LL^T + +If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite +matrices, then the returned tensor will be composed of upper-triangular Cholesky factors +of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned +tensor will be composed of lower-triangular Cholesky factors of each of the individual +matrices. + +.. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + +Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. 
Default: ``False`` + +Keyword args: + out (Tensor, optional): the output matrix + +Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) +""", +) + +add_docstr( + torch.cholesky_solve, + r""" +cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + +Computes the solution of a system of linear equations with complex Hermitian +or real symmetric positive-definite lhs given its Cholesky decomposition. + +Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, +and :math:`L` its Cholesky decomposition such that: + +.. math:: + + A = LL^{\text{H}} + +where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, +and the transpose when :math:`L` is real-valued. + +Returns the solution :math:`X` of the following linear system: + +.. math:: + + AX = B + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices +then the output has the same batch dimensions. + +Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) +""", +) + +add_docstr( + torch.cholesky_inverse, + r""" +cholesky_inverse(L, upper=False, *, out=None) -> Tensor + +Computes the inverse of a complex Hermitian or real symmetric +positive-definite matrix given its Cholesky decomposition. + +Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, +and :math:`L` its Cholesky decomposition such that: + +.. math:: + + A = LL^{\text{H}} + +where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, +and the transpose when :math:`L` is real-valued. + +Computes the inverse matrix :math:`A^{-1}`. 
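
For intuition, the same quantity can also be obtained by solving against an
identity right-hand side (a rough equivalence, assuming a single non-batched
lower-triangular factor ``L``; the two results agree only up to floating point
error and this is not the actual implementation):

.. code:: python

    n = L.shape[-1]
    A_inv = torch.cholesky_solve(torch.eye(n, dtype=L.dtype), L)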
+ +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :math:`A` is a batch of matrices +then the output has the same batch dimensions. + +Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) +""", +) + +add_docstr( + torch.clone, + r""" +clone(input, *, memory_format=torch.preserve_format) -> Tensor + +Returns a copy of :attr:`input`. + +.. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + +Args: + {input} + +Keyword args: + {memory_format} +""".format( + **common_args + ), +) + +add_docstr( + torch.clamp, + r""" +clamp(input, min=None, max=None, *, out=None) -> Tensor + +Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. +Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + +.. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + +If :attr:`min` is ``None``, there is no lower bound. +Or, if :attr:`max` is ``None`` there is no upper bound. +""" + + r""" + +.. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + +Args: + {input} + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.clip, + r""" +clip(input, min=None, max=None, *, out=None) -> Tensor + +Alias for :func:`torch.clamp`. +""", +) + +add_docstr( + torch.column_stack, + r""" +column_stack(tensors, *, out=None) -> Tensor + +Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + +Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` +in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. 
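
The reshaping rule above can be pictured with a rough Python equivalent
(an illustrative sketch only; ``column_stack_sketch`` is a hypothetical helper,
not the actual implementation):

.. code:: python

    def column_stack_sketch(tensors):
        # 0-D and 1-D tensors become (numel, 1) columns; others are kept as-is
        cols = [t.reshape(t.numel(), 1) if t.dim() <= 1 else t for t in tensors]
        return torch.hstack(cols)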
+ +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.complex, + r""" +complex(real, imag, *, out=None) -> Tensor + +Constructs a complex tensor with its real part equal to :attr:`real` and its +imaginary part equal to :attr:`imag`. + +Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + +Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + +Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + +""", +) + +add_docstr( + torch.polar, + r""" +polar(abs, angle, *, out=None) -> Tensor + +Constructs a complex tensor whose elements are Cartesian coordinates +corresponding to the polar coordinates with absolute value :attr:`abs` and angle +:attr:`angle`. + +.. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + +.. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + +""" + + r""" +Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + +Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + +Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) +""", +) + +add_docstr( + torch.conj_physical, + r""" +conj_physical(input, *, out=None) -> Tensor + +Computes the element-wise conjugate of the given :attr:`input` tensor. +If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + +.. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + +.. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + +.. 
math:: + \text{out}_{i} = conj(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) +""".format( + **common_args + ), +) + +add_docstr( + torch.conj, + r""" +conj(input) -> Tensor + +Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, +this function just returns :attr:`input`. + +.. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + +.. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True +""".format( + **common_args + ), +) + +add_docstr( + torch.resolve_conj, + r""" +resolve_conj(input) -> Tensor + +Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, +else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False +""".format( + **common_args + ), +) + +add_docstr( + torch.resolve_neg, + r""" +resolve_neg(input) -> Tensor + +Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, +else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + +Args: + {input} + +Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False +""".format( + **common_args + ), +) + +add_docstr( + torch.copysign, + r""" +copysign(input, other, *, out=None) -> Tensor + +Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + +.. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +and integer and float inputs. + +Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. 
+ +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + +.. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + +""".format( + **common_args + ), +) + +add_docstr( + torch.cos, + r""" +cos(input, *, out=None) -> Tensor + +Returns a new tensor with the cosine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \cos(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) +""".format( + **common_args + ), +) + +add_docstr( + torch.cosh, + r""" +cosh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic cosine of the elements of +:attr:`input`. + +.. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. +""".format( + **common_args + ), +) + +add_docstr( + torch.cross, + r""" +cross(input, other, dim=None, *, out=None) -> Tensor + + +Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` +and :attr:`other`. + +Supports input of float, double, cfloat and cdouble dtypes. Also supports batches +of vectors, for which it computes the product along the dimension :attr:`dim`. +In this case, the output has the same batch dimensions as the inputs. + +.. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + +.. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + +Args: + {input} + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. 
+ +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.logcumsumexp, + r""" +logcumsumexp(input, dim, *, out=None) -> Tensor +Returns the logarithm of the cumulative summation of the exponentiation of +elements of :attr:`input` in the dimension :attr:`dim`. + +For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{{logcumsumexp}}(x)_{{ij}} = \log \sum\limits_{{j=0}}^{{i}} \exp(x_{{ij}}) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cummax, + r""" +cummax(input, dim, *, out=None) -> (Tensor, LongTensor) +Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of +elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index +location of each maximum value found in the dimension :attr:`dim`. + +.. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cummin, + r""" +cummin(input, dim, *, out=None) -> (Tensor, LongTensor) +Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of +elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index +location of each maximum value found in the dimension :attr:`dim`. + +.. 
math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cumprod, + r""" +cumprod(input, dim, *, dtype=None, out=None) -> Tensor + +Returns the cumulative product of elements of :attr:`input` in the dimension +:attr:`dim`. + +For example, if :attr:`input` is a vector of size N, the result will also be +a vector of size N, with elements. + +.. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {dtype} + {out} + +Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.cumsum, + r""" +cumsum(input, dim, *, dtype=None, out=None) -> Tensor + +Returns the cumulative sum of elements of :attr:`input` in the dimension +:attr:`dim`. + +For example, if :attr:`input` is a vector of size N, the result will also be +a vector of size N, with elements. + +.. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + +Args: + {input} + dim (int): the dimension to do the operation over + +Keyword args: + {dtype} + {out} + +Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.count_nonzero, + r""" +count_nonzero(input, dim=None) -> Tensor + +Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. +If no dim is specified then all non-zeros in the tensor are counted. + +Args: + {input} + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + +Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) +""".format( + **reduceops_common_args + ), +) + +add_docstr( + torch.dequantize, + r""" +dequantize(tensor) -> Tensor + +Returns an fp32 Tensor by dequantizing a quantized Tensor + +Args: + tensor (Tensor): A quantized Tensor + +.. 
function:: dequantize(tensors) -> sequence of Tensors + :noindex: + +Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + +Args: + tensors (sequence of Tensors): A list of quantized Tensors +""", +) + +add_docstr( + torch.diag, + r""" +diag(input, diagonal=0, *, out=None) -> Tensor + +- If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. +- If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + +The argument :attr:`diagonal` controls which diagonal to consider: + +- If :attr:`diagonal` = 0, it is the main diagonal. +- If :attr:`diagonal` > 0, it is above the main diagonal. +- If :attr:`diagonal` < 0, it is below the main diagonal. + +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +.. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + +Examples: + +Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + +Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) +""".format( + **common_args + ), +) + +add_docstr( + torch.diag_embed, + r""" +diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + +Creates a tensor whose diagonals of certain 2D planes (specified by +:attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. +To facilitate creating batched diagonal matrices, the 2D planes formed by +the last two dimensions of the returned tensor are chosen by default. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +The size of the new matrix will be calculated to make the specified diagonal +of the size of the last input dimension. +Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` +and :attr:`dim2` matters. Exchanging them is equivalent to changing the +sign of :attr:`offset`. + +Applying :meth:`torch.diagonal` to the output of this function with +the same arguments yields a matrix identical to input. However, +:meth:`torch.diagonal` has different default dimensions, so those +need to be explicitly specified. + +Args: + {input} Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: -1. 
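
The round trip with :meth:`torch.diagonal` described above can be spelled out
explicitly (a small sketch; the explicit ``dim1``/``dim2`` arguments are needed
because the two functions use different defaults):

.. code:: python

    a = torch.randn(2, 3)
    assert torch.equal(torch.diagonal(torch.diag_embed(a), dim1=-2, dim2=-1), a)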
+ +Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) +""".format( + **common_args + ), +) + + +add_docstr( + torch.diagflat, + r""" +diagflat(input, offset=0) -> Tensor + +- If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. +- If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Args: + {input} + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + +Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.diagonal, + r""" +diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + +Returns a partial view of :attr:`input` with the its diagonal elements +with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension +at the end of the shape. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Applying :meth:`torch.diag_embed` to the output of this function with +the same arguments yields a diagonal matrix with the diagonal entries +of the input. However, :meth:`torch.diag_embed` has different default +dimensions, so those need to be explicitly specified. + +Args: + {input} Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + +.. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
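
For instance, the batch diagonal mentioned in the note above reduces a batch of
matrices to a batch of diagonals (a brief sketch):

.. code:: python

    batch = torch.randn(4, 3, 3)
    diags = torch.diagonal(batch, dim1=-2, dim2=-1)   # shape (4, 3)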
+ +Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a, 0) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.diagonal_scatter, + r""" +diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` along +the diagonal elements of :attr:`input`, with respect to :attr:`dim1` +and :attr:`dim2`. + +This function returns a tensor with fresh storage; it does not +return a view. + +The argument :attr:`offset` controls which diagonal to consider: + +- If :attr:`offset` = 0, it is the main diagonal. +- If :attr:`offset` > 0, it is above the main diagonal. +- If :attr:`offset` < 0, it is below the main diagonal. + +Args: + {input} Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + +.. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + +Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.as_strided_scatter, + r""" +as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + +Embeds the values of the :attr:`src` tensor into :attr:`input` along +the elements corresponding to the result of calling +input.as_strided(size, stride, storage_offset). + +This function returns a tensor with fresh storage; it does not +return a view. + +Args: + {input} + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + +.. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + +Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.diff, + r""" +diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + +Computes the n-th forward difference along the given dimension. + +The first-order differences are given by `out[i] = input[i + 1] - input[i]`. 
Higher-order +differences are calculated by using :func:`torch.diff` recursively. + +Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.digamma, + r""" +digamma(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.digamma`. +""", +) + +add_docstr( + torch.dist, + r""" +dist(input, other, p=2) -> Tensor + +Returns the p-norm of (:attr:`input` - :attr:`other`) + +The shapes of :attr:`input` and :attr:`other` must be +:ref:`broadcastable `. + +Args: + {input} + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + +Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) + >>> torch.dist(x, y, 1) + tensor(2.6537) +""".format( + **common_args + ), +) + +add_docstr( + torch.div, + r""" +div(input, other, *, rounding_mode=None, out=None) -> Tensor + +Divides each element of the input ``input`` by the corresponding element of +:attr:`other`. + +.. math:: + \text{{out}}_i = \frac{{\text{{input}}_i}}{{\text{{other}}_i}} + +.. note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. +Always promotes integer types to the default scalar type. + +Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + +Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. + * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + {out} + +Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... 
[-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.divide, + r""" +divide(input, other, *, rounding_mode=None, out=None) -> Tensor + +Alias for :func:`torch.div`. +""", +) + +add_docstr( + torch.dot, + r""" +dot(input, other, *, out=None) -> Tensor + +Computes the dot product of two 1D tensors. + +.. note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + +Args: + input (Tensor): first tensor in the dot product, must be 1D. + other (Tensor): second tensor in the dot product, must be 1D. + +Keyword args: + {out} + +Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) +""".format( + **common_args + ), +) + +add_docstr( + torch.vdot, + r""" +vdot(input, other, *, out=None) -> Tensor + +Computes the dot product of two 1D vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +.. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + +.. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + +Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + +Keyword args: +""" + + rf""" +.. note:: {common_args["out"]} +""" + + r""" + +Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) +""", +) + +add_docstr( + torch.eq, + r""" +eq(input, other, *, out=None) -> Tensor + +Computes element-wise equality + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.equal, + r""" +equal(input, other) -> bool + +``True`` if two tensors have the same size and elements, ``False`` otherwise. + +Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True +""", +) + +add_docstr( + torch.erf, + r""" +erf(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erf`. 
+""", +) + +add_docstr( + torch.erfc, + r""" +erfc(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erfc`. +""", +) + +add_docstr( + torch.erfinv, + r""" +erfinv(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.erfinv`. +""", +) + +add_docstr( + torch.exp, + r""" +exp(input, *, out=None) -> Tensor + +Returns a new tensor with the exponential of the elements +of the input tensor :attr:`input`. + +.. math:: + y_{i} = e^{x_{i}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.exp2, + r""" +exp2(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.exp2`. +""", +) + +add_docstr( + torch.expm1, + r""" +expm1(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.expm1`. +""", +) + +add_docstr( + torch.eye, + r""" +eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + +Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + +Keyword arguments: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + +Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.floor, + r""" +floor(input, *, out=None) -> Tensor + +Returns a new tensor with the floor of the elements of :attr:`input`, +the largest integer less than or equal to each element. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +.. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.floor_divide, + r""" +floor_divide(input, other, *, out=None) -> Tensor + +.. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + +Computes :attr:`input` divided by :attr:`other`, elementwise, and floors +the result. + +.. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + +""" + + r""" + +Supports broadcasting to a common shape, type promotion, and integer and float inputs. + +Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) +""".format( + **common_args + ), +) + +add_docstr( + torch.fmod, + r""" +fmod(input, other, *, out=None) -> Tensor + +Applies C++'s `std::fmod `_ entrywise. +The result has the same sign as the dividend :attr:`input` and its absolute value +is less than that of :attr:`other`. + +This function may be defined in terms of :func:`torch.div` as + +.. 
code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + +.. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + +.. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + +Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + +Keyword args: + {out} + +Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.frac, + r""" +frac(input, *, out=None) -> Tensor + +Computes the fractional portion of each element in :attr:`input`. + +.. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + +Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) +""", +) + +add_docstr( + torch.frexp, + r""" +frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + +Decomposes :attr:`input` into mantissa and exponent tensors +such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + +The range of mantissa is the open interval (-1, 1). + +Supports float inputs. + +Args: + input (Tensor): the input tensor + + +Keyword args: + out (tuple, optional): the output tensors + +Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) +""", +) + +add_docstr( + torch.from_numpy, + r""" +from_numpy(ndarray) -> Tensor + +Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + +The returned tensor and :attr:`ndarray` share the same memory. Modifications to +the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned +tensor is not resizable. + +It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, +``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, +``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, +and ``bool``. + +.. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + +Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) +""", +) + +add_docstr( + torch.frombuffer, + r""" +frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + +Creates a 1-dimensional :class:`Tensor` from an object that implements +the Python buffer protocol. + +Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of +the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` +elements. 
+ +Note that either of the following must be true: + +1. :attr:`count` is a positive non-zero number, and the total number of bytes +in the buffer is more than :attr:`offset` plus :attr:`count` times the size +(in bytes) of :attr:`dtype`. + +2. :attr:`count` is negative, and the length (number of bytes) of the buffer +subtracted by the :attr:`offset` is a multiple of the size (in bytes) of +:attr:`dtype`. + +The returned tensor and buffer share the same memory. Modifications to +the tensor will be reflected in the buffer and vice versa. The returned +tensor is not resizable. + +.. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + +.. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + +.. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + +Args: + buffer (object): a Python object that exposes the buffer interface. + +Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + {requires_grad} + +Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.from_file, + r""" +from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + +Creates a CPU tensor with a storage backed by a memory-mapped file. + +If ``shared`` is True, then memory is shared between processes. All changes are written to the file. +If ``shared`` is False, then changes to the tensor do not affect the file. + +``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain +at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + +.. note:: + Only CPU tensors can be mapped to files. + +.. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. 
+ + +Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + +Keyword args: + {dtype} + {layout} + {device} + {pin_memory} + +Example:: + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """.format( + **factory_common_args + ), +) + +add_docstr( + torch.flatten, + r""" +flatten(input, start_dim=0, end_dim=-1) -> Tensor + +Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` +are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. +The order of elements in :attr:`input` is unchanged. + +Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, +or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can +be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the +flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + +.. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + +Args: + {input} + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + +Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.unflatten, + r""" +unflatten(input, dim, sizes) -> Tensor + +Expands a dimension of the input tensor over multiple dimensions. + +.. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + +Args: + {input} + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + +Returns: + A View of input with the specified dimension unflattened. + +Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) +""".format( + **common_args + ), +) + +add_docstr( + torch.gather, + r""" +gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + +Gathers values along an axis specified by `dim`. + +For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + +:attr:`input` and :attr:`index` must have the same number of dimensions. +It is also required that ``index.size(d) <= input.size(d)`` for all +dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. +Note that ``input`` and ``index`` do not broadcast against each other. 
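+
+For example, in the common 2-D case the rule above reduces to::
+
+    out[i][j] = input[index[i][j]][j]  # if dim == 0
+    out[i][j] = input[i][index[i][j]]  # if dim == 1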
+ +Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + +Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + +Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) +""", +) + + +add_docstr( + torch.gcd, + r""" +gcd(input, other, *, out=None) -> Tensor + +Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + +Both :attr:`input` and :attr:`other` must have integer types. + +.. note:: + This defines :math:`gcd(0, 0) = 0`. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) +""".format( + **common_args + ), +) + +add_docstr( + torch.ge, + r""" +ge(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \geq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.greater_equal, + r""" +greater_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.ge`. +""", +) + +add_docstr( + torch.gradient, + r""" +gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + +Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in +one or more dimensions using the `second-order accurate central differences method +`_ and +either first or second order estimates at the boundaries. + +The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not +specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates +to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional +:attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and +:math:`g(1, 2, 3)\ == input[1, 2, 3]`. + +When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. +This is detailed in the "Keyword Arguments" section below. + +The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is +accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be +improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative +is estimated using `Taylor's theorem with remainder `_. +Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring +it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + +.. 
math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + +Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + +.. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + +.. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + +The value of each partial derivative at the boundary points is computed differently. See edge_order below. + +Args: + input (``Tensor``): the tensor that represents the values of the function + +Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + +Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + +""", +) + +add_docstr( + torch.geqrf, + r""" +geqrf(input, *, out=None) -> (Tensor, Tensor) + +This is a low-level function for calling LAPACK's geqrf directly. This function +returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + +Computes a QR decomposition of :attr:`input`. +Both `Q` and `R` matrices are stored in the same output tensor `a`. +The elements of `R` are stored on and above the diagonal. +Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` +are stored below the diagonal. +The results of this function can be used together with :func:`torch.linalg.householder_product` +to obtain the `Q` matrix or +with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, +for an efficient matrix-matrix multiplication. + +See `LAPACK documentation for geqrf`_ for further details. + +.. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + +Args: + input (Tensor): the input matrix + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + +.. _LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + +""", +) + +add_docstr( + torch.inner, + r""" +inner(input, other, *, out=None) -> Tensor + +Computes the dot product for 1D tensors. For higher dimensions, sums the product +of elements from :attr:`input` and :attr:`other` along their last dimension. + +.. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. 
+ + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + +Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + +Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. + +Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) +""", +) + +add_docstr( + torch.outer, + r""" +outer(input, vec2, *, out=None) -> Tensor + +Outer product of :attr:`input` and :attr:`vec2`. +If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of +size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + +.. note:: This function does not :ref:`broadcast `. + +Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + +Keyword args: + out (Tensor, optional): optional output matrix + +Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) +""", +) + +add_docstr( + torch.ger, + r""" +ger(input, vec2, *, out=None) -> Tensor + +Alias of :func:`torch.outer`. + +.. warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. +""", +) + +add_docstr( + torch.get_default_dtype, + r""" +get_default_dtype() -> torch.dtype + +Get the current default floating point :class:`torch.dtype`. + +Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + +""", +) + +add_docstr( + torch.get_num_threads, + r""" +get_num_threads() -> int + +Returns the number of threads used for parallelizing CPU operations +""", +) + +add_docstr( + torch.get_num_interop_threads, + r""" +get_num_interop_threads() -> int + +Returns the number of threads used for inter-op parallelism on CPU +(e.g. in JIT interpreter) +""", +) + +add_docstr( + torch.gt, + r""" +gt(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} > \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. 
+ +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + +Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.greater, + r""" +greater(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.gt`. +""", +) + +add_docstr( + torch.histc, + r""" +histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + +Computes the histogram of a tensor. + +The elements are sorted into equal width bins between :attr:`min` and +:attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and +maximum values of the data are used. + +Elements lower than min and higher than max and ``NaN`` elements are ignored. + +Args: + {input} + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + +Keyword args: + {out} + +Returns: + Tensor: Histogram represented as a tensor + +Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.histogram, + r""" +histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + +Computes a histogram of the values in a tensor. + +:attr:`bins` can be an integer or a 1D tensor. + +If :attr:`bins` is an int, it specifies the number of equal-width bins. +By default, the lower and upper range of the bins is determined by the +minimum and maximum elements of the input tensor. The :attr:`range` +argument can be provided to specify a range for the bins. + +If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges +including the rightmost edge. It should contain at least 2 elements +and its elements should be increasing. + +Args: + {input} + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + +Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + {out} (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + +Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + +Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) +""".format( + **common_args + ), +) + +add_docstr( + torch.histogramdd, + r""" +histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + +Computes a multi-dimensional histogram of the values in a tensor. 
+ +Interprets the elements of an input tensor whose innermost dimension has size N +as a collection of N-dimensional points. Maps each of the points into a set of +N-dimensional bins and returns the number of points (or total weight) in each bin. + +:attr:`input` must be a tensor with at least 2 dimensions. +If input has shape (M, N), each of its M rows defines a point in N-dimensional space. +If input has three or more dimensions, all but the last dimension are flattened. + +Each dimension is independently associated with its own strictly increasing sequence +of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D +tensors. Alternatively, bin edges may be constructed automatically by passing a +sequence of integers specifying the number of equal-width bins in each dimension. + +For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + +:attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + +If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences +of bin edges. Each 1D tensor should contain a strictly increasing sequence with at +least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying +the left and right edges of all bins. Every bin is exclusive of its left edge. Only +the rightmost bin is inclusive of its right edge. + +If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins +in each dimension. By default, the leftmost and rightmost bin edges in each dimension +are determined by the minimum and maximum elements of the input tensor in the +corresponding dimension. The :attr:`range` argument can be provided to manually +specify the leftmost and rightmost bin edges in each dimension. + +If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + +.. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + +Args: + {input} + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. +Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. +Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. 
+ +Example:: + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + +""".format( + **common_args + ), +) +# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 +torch.histogramdd.__module__ = "torch" + +add_docstr( + torch.hypot, + r""" +hypot(input, other, *, out=None) -> Tensor + +Given the legs of a right triangle, return its hypotenuse. + +.. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. +""" + + r""" +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.i0, + r""" +i0(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.i0`. +""", +) + +add_docstr( + torch.igamma, + r""" +igamma(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.gammainc`. +""", +) + +add_docstr( + torch.igammac, + r""" +igammac(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.gammaincc`. +""", +) + +add_docstr( + torch.index_select, + r""" +index_select(input, dim, index, *, out=None) -> Tensor + +Returns a new tensor which indexes the :attr:`input` tensor along dimension +:attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + +The returned tensor has the same number of dimensions as the original tensor +(:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length +of :attr:`index`; other dimensions have the same size as in the original tensor. + +.. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + +Args: + {input} + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.inverse, + r""" +inverse(input, *, out=None) -> Tensor + +Alias for :func:`torch.linalg.inv` +""", +) + +add_docstr( + torch.isin, + r""" +isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + +Tests if each element of :attr:`elements` is in :attr:`test_elements`. 
Returns +a boolean tensor of the same shape as :attr:`elements` that is True for elements +in :attr:`test_elements` and False otherwise. + +.. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + +Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + +Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + +Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) +""", +) + +add_docstr( + torch.isinf, + r""" +isinf(input) -> Tensor + +Tests if each element of :attr:`input` is infinite +(positive or negative infinity) or not. + +.. note:: + Complex values are infinite when their real or imaginary part is + infinite. + +Args: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + +Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.isposinf, + r""" +isposinf(input, *, out=None) -> Tensor +Tests if each element of :attr:`input` is positive infinity or not. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.isneginf, + r""" +isneginf(input, *, out=None) -> Tensor +Tests if each element of :attr:`input` is negative infinity or not. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.isclose, + r""" +isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + +Returns a new tensor with boolean elements representing if each element of +:attr:`input` is "close" to the corresponding element of :attr:`other`. +Closeness is defined as: + +.. math:: + \lvert \text{input} - \text{other} \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other} \rvert +""" + + r""" + +where :attr:`input` and :attr:`other` are finite. Where :attr:`input` +and/or :attr:`other` are nonfinite they are close if and only if +they are equal, with NaNs being considered equal to each other when +:attr:`equal_nan` is True. + +Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + +Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) +""", +) + +add_docstr( + torch.isfinite, + r""" +isfinite(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element is `finite` or not. + +Real values are finite when they are not NaN, negative infinity, or infinity. +Complex values are finite when both their real and imaginary parts are finite. + +Args: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + +Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.isnan, + r""" +isnan(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element of :attr:`input` +is NaN or not. Complex values are considered NaN when either their real +and/or imaginary part is NaN. + +Arguments: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + +Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.isreal, + r""" +isreal(input) -> Tensor + +Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. +All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + +Arguments: + {input} + +Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + +Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) +""".format( + **common_args + ), +) + +add_docstr( + torch.is_floating_point, + r""" +is_floating_point(input) -> (bool) + +Returns True if the data type of :attr:`input` is a floating point data type i.e., +one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + +Args: + {input} +""".format( + **common_args + ), +) + +add_docstr( + torch.is_complex, + r""" +is_complex(input) -> (bool) + +Returns True if the data type of :attr:`input` is a complex data type i.e., +one of ``torch.complex64``, and ``torch.complex128``. + +Args: + {input} +""".format( + **common_args + ), +) + +add_docstr( + torch.is_grad_enabled, + r""" +is_grad_enabled() -> (bool) + +Returns True if grad mode is currently enabled. +""".format( + **common_args + ), +) + +add_docstr( + torch.is_inference_mode_enabled, + r""" +is_inference_mode_enabled() -> (bool) + +Returns True if inference mode is currently enabled. +""".format( + **common_args + ), +) + +add_docstr( + torch.is_inference, + r""" +is_inference(input) -> (bool) + +Returns True if :attr:`input` is an inference tensor. + +A non-view tensor is an inference tensor if and only if it was +allocated during inference mode. A view tensor is an inference +tensor if and only if the tensor it is a view of is an inference tensor. + +For details on inference mode please see +`Inference Mode `_. + +Args: + {input} +""".format( + **common_args + ), +) + +add_docstr( + torch.is_conj, + r""" +is_conj(input) -> (bool) + +Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. 
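+
+.. note::
+    A quick illustrative check (assuming a complex input, for which :func:`torch.conj` returns a lazily conjugated view)::
+
+        >>> x = torch.tensor([1 + 1j])
+        >>> torch.is_conj(torch.conj(x))
+        True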
+ +Args: + {input} +""".format( + **common_args + ), +) + +add_docstr( + torch.is_nonzero, + r""" +is_nonzero(input) -> (bool) + +Returns True if the :attr:`input` is a single element tensor which is not equal to zero +after type conversions. +i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or +``torch.tensor([False])``. +Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case +of sparse tensors). + +Args: + {input} + +Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous +""".format( + **common_args + ), +) + +add_docstr( + torch.kron, + r""" +kron(input, other, *, out=None) -> Tensor + +Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + +If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a +:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a +:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + +.. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + +where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. +If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + +Supports real-valued and complex-valued inputs. + +.. note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + +Arguments: + input (Tensor) + other (Tensor) + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) +""", +) + +add_docstr( + torch.kthvalue, + r""" +kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th +smallest element of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each element found. + +If :attr:`dim` is not given, the last dimension of the `input` is chosen. 
+ +If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors +are the same size as :attr:`input`, except in the dimension :attr:`dim` where +they are of size 1. Otherwise, :attr:`dim` is squeezed +(see :func:`torch.squeeze`), resulting in both the :attr:`values` and +:attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + +.. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + +Args: + {input} + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + {keepdim} + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + +Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.lcm, + r""" +lcm(input, other, *, out=None) -> Tensor + +Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + +Both :attr:`input` and :attr:`other` must have integer types. + +.. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) +""".format( + **common_args + ), +) + +add_docstr( + torch.ldexp, + r""" +ldexp(input, other, *, out=None) -> Tensor + +Multiplies :attr:`input` by 2 ** :attr:`other`. + +.. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i +""" + + r""" + +Typically this function is used to construct floating point numbers by multiplying +mantissas in :attr:`input` with integral powers of two created from the exponents +in :attr:`other`. + +Args: + {input} + other (Tensor): a tensor of exponents, typically integers. + +Keyword args: + {out} + +Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + + +""".format( + **common_args + ), +) + +add_docstr( + torch.le, + r""" +le(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \leq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + +Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.less_equal, + r""" +less_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.le`. 
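+
+A brief illustrative example (mirrors the :func:`torch.le` example)::
+
+    >>> torch.less_equal(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
+    tensor([[True, False], [True, True]])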
+""", +) + +add_docstr( + torch.lerp, + r""" +lerp(input, end, weight, *, out=None) + +Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based +on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + +.. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) +""" + + r""" +The shapes of :attr:`start` and :attr:`end` must be +:ref:`broadcastable `. If :attr:`weight` is a tensor, then +the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + +Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + +Keyword args: + {out} + +Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) +""".format( + **common_args + ), +) + +add_docstr( + torch.lgamma, + r""" +lgamma(input, *, out=None) -> Tensor + +Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + +.. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| +""" + + """ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) +""".format( + **common_args + ), +) + +add_docstr( + torch.linspace, + r""" +linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly +spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + +.. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) +""" + + """ + +From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + +Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + +Keyword arguments: + {out} + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + {layout} + {device} + {requires_grad} + + +Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.log, + r""" +log(input, *, out=None) -> Tensor + +Returns a new tensor with the natural logarithm of the elements +of :attr:`input`. + +.. 
math:: + y_{i} = \log_{e} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) +""".format( + **common_args + ), +) + +add_docstr( + torch.log10, + r""" +log10(input, *, out=None) -> Tensor + +Returns a new tensor with the logarithm to the base 10 of the elements +of :attr:`input`. + +.. math:: + y_{i} = \log_{10} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.log1p, + r""" +log1p(input, *, out=None) -> Tensor + +Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + +.. math:: + y_i = \log_{e} (x_i + 1) +""" + + r""" +.. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) +""".format( + **common_args + ), +) + +add_docstr( + torch.log2, + r""" +log2(input, *, out=None) -> Tensor + +Returns a new tensor with the logarithm to the base 2 of the elements +of :attr:`input`. + +.. math:: + y_{i} = \log_{2} (x_{i}) +""" + + r""" + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.logaddexp, + r""" +logaddexp(input, other, *, out=None) -> Tensor + +Logarithm of the sum of exponentiations of the inputs. + +Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful +in statistics where the calculated probabilities of events may be so small as to +exceed the range of normal floating point numbers. In such cases the logarithm +of the calculated probability is stored. This function allows adding +probabilities stored in such a fashion. + +This op should be disambiguated with :func:`torch.logsumexp` which performs a +reduction on a single tensor. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} + +Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) +""".format( + **common_args + ), +) + +add_docstr( + torch.logaddexp2, + r""" +logaddexp2(input, other, *, out=None) -> Tensor + +Logarithm of the sum of exponentiations of the inputs in base-2. + +Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See +:func:`torch.logaddexp` for more details. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword arguments: + {out} +""".format( + **common_args + ), +) + +add_docstr( + torch.xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.special.xlogy`. 
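+
+A brief illustrative example (computes ``input * log(other)``, returning 0 where ``input`` is 0)::
+
+    >>> x = torch.tensor([0., 1., 2.])
+    >>> y = torch.tensor([0., 2., 3.])
+    >>> torch.xlogy(x, y)
+    tensor([0.0000, 0.6931, 2.1972])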
+""", +) + +add_docstr( + torch.logical_and, + r""" +logical_and(input, other, *, out=None) -> Tensor + +Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. + +Args: + {input} + other (Tensor): the tensor to compute AND with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_not, + r""" +logical_not(input, *, out=None) -> Tensor + +Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool +dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_or, + r""" +logical_or(input, other, *, out=None) -> Tensor + +Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. + +Args: + {input} + other (Tensor): the tensor to compute OR with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.logical_xor, + r""" +logical_xor(input, other, *, out=None) -> Tensor + +Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are +treated as ``True``. 
+ +Args: + {input} + other (Tensor): the tensor to compute XOR with + +Keyword args: + {out} + +Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) +""".format( + **common_args + ), +) + +add_docstr( + torch.logspace, + """ +logspace(start, end, steps, base=10.0, *, \ + out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" + +Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly +spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to +:math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale +with base :attr:`base`. That is, the values are: + +.. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) +""" + + """ + + +From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + +Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + +Keyword arguments: + {out} + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.logsumexp, + r""" +logsumexp(input, dim, keepdim=False, *, out=None) + +Returns the log of summed exponentials of each row of the :attr:`input` +tensor in the given dimension :attr:`dim`. The computation is numerically +stabilized. + +For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. 
math:: + \text{{logsumexp}}(x)_{{i}} = \log \sum_j \exp(x_{{ij}}) + +{keepdim_details} + +Args: + {input} + {opt_dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.lt, + r""" +lt(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} < \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + +Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.lu_unpack, + r""" +lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + +.. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. + +Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + +Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + +Returns: + A namedtuple ``(P, L, U)`` + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + +""".format( + **common_args + ), +) + +add_docstr( + torch.less, + r""" +less(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.lt`. +""", +) + +add_docstr( + torch.lu_solve, + r""" +lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + +Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted +LU factorization of A from :func:`~linalg.lu_factor`. + +This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + +.. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. 
code:: python + + X = linalg.lu_solve(LU, pivots, B) + +Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + +Keyword args: + {out} + +Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) +""".format( + **common_args + ), +) + +add_docstr( + torch.masked_select, + r""" +masked_select(input, mask, *, out=None) -> Tensor + +Returns a new 1-D tensor which indexes the :attr:`input` tensor according to +the boolean mask :attr:`mask` which is a `BoolTensor`. + +The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need +to match, but they must be :ref:`broadcastable `. + +.. note:: The returned tensor does **not** use the same storage + as the original tensor + +Args: + {input} + mask (BoolTensor): the tensor containing the binary mask to index with + +Keyword args: + {out} + +Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) +""".format( + **common_args + ), +) + +add_docstr( + torch.matrix_power, + r""" +matrix_power(input, n, *, out=None) -> Tensor + +Alias for :func:`torch.linalg.matrix_power` +""", +) + +add_docstr( + torch.matrix_exp, + r""" +matrix_exp(A) -> Tensor + +Alias for :func:`torch.linalg.matrix_exp`. +""", +) + +add_docstr( + torch.max, + r""" +max(input) -> Tensor + +Returns the maximum value of all elements in the ``input`` tensor. + +.. warning:: + This function produces deterministic (sub)gradients unlike ``max(dim=0)`` + +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + +.. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each maximum value found +(argmax). + +If ``keepdim`` is ``True``, the output tensors are of the same size +as ``input`` except in the dimension ``dim`` where they are of size 1. +Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting +in the output tensors having 1 fewer dimension than ``input``. + +.. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + +Args: + {input} + {dim} + {keepdim} Default: ``False``. 
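# Sketch (illustrative, not part of the original docstring): the two outputs of
# max(dim) agree with torch.argmax and with gathering values at those indices.
import torch

a = torch.randn(4, 4)
values, indices = torch.max(a, dim=1)
assert torch.equal(indices, torch.argmax(a, dim=1))
assert torch.equal(values, a.gather(1, indices.unsqueeze(1)).squeeze(1))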
+ +Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + +.. function:: max(input, other, *, out=None) -> Tensor + :noindex: + +See :func:`torch.maximum`. + +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.maximum, + r""" +maximum(input, other, *, out=None) -> Tensor + +Computes the element-wise maximum of :attr:`input` and :attr:`other`. + +.. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) +""".format( + **common_args + ), +) + +add_docstr( + torch.fmax, + r""" +fmax(input, other, *, out=None) -> Tensor + +Computes the element-wise maximum of :attr:`input` and :attr:`other`. + +This is like :func:`torch.maximum` except it handles NaNs differently: +if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. +Only if both elements are NaN is NaN propagated. + +This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and floating-point inputs. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) +""".format( + **common_args + ), +) + +add_docstr( + torch.amax, + r""" +amax(input, dim, keepdim=False, *, out=None) -> Tensor + +Returns the maximum value of each slice of the :attr:`input` tensor in the given +dimension(s) :attr:`dim`. + +.. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.argmax, + r""" +argmax(input) -> LongTensor + +Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + +This is the second value returned by :meth:`torch.max`. See its +documentation for the exact semantics of this method. + +.. note:: If there are multiple maximal values then the indices of the first maximal value are returned. 
+ +Args: + {input} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + +.. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + +Returns the indices of the maximum values of a tensor across a dimension. + +This is the second value returned by :meth:`torch.max`. See its +documentation for the exact semantics of this method. + +Args: + {input} + {dim} If ``None``, the argmax of the flattened input is returned. + {keepdim} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.argwhere, + r""" +argwhere(input) -> Tensor + +Returns a tensor containing the indices of all non-zero elements of +:attr:`input`. Each row in the result contains the indices of a non-zero +element in :attr:`input`. The result is sorted lexicographically, with +the last index changing the fastest (C-style). + +If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor +:attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +.. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. + +Args: + {input} + +Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) +""", +) + +add_docstr( + torch.mean, + r""" +mean(input, *, dtype=None) -> Tensor + +Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + +Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + +.. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + +Returns the mean value of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, +reduce over all of them. + +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {dtype} + {out} + +.. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.nanmean, + r""" +nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes the mean of all `non-NaN` elements along the specified dimensions. + +This function is identical to :func:`torch.mean` when there are no `NaN` values +in the :attr:`input` tensor. 
In the presence of `NaN`, :func:`torch.mean` will +propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the +`NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + +{keepdim_details} + +Args: + {input} + {opt_dim} + {keepdim} + +Keyword args: + {dtype} + {out} + +.. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + +Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.median, + r""" +median(input) -> Tensor + +Returns the median of the values in :attr:`input`. + +.. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + +.. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + +.. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + +By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + +If :attr:`keepdim` is ``True``, the output tensors are of the same size +as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the outputs tensor having 1 fewer dimension than :attr:`input`. + +.. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + +.. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
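# Illustrative sketch of the note above (not from the original file): for an even
# number of elements the lower median is returned; quantile(0.5) averages instead.
import torch
t = torch.tensor([1., 2., 3., 4.])
torch.median(t)          # tensor(2.)     -- lower of the two middle values
torch.quantile(t, 0.5)   # tensor(2.5000) -- mean of the two middle values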
+ +Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.nanmedian, + r""" +nanmedian(input) -> Tensor + +Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. +When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, +while this function will return the median of the non-``NaN`` elements in :attr:`input`. +If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + +Args: + {input} + +Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + +.. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` +in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values +found in the dimension :attr:`dim`. + +This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has +one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the +median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + +Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.quantile, + r""" +quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + +Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + +To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location +of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with +indices ``i`` and ``j`` in the sorted order, result is computed according to the given +:attr:`interpolation` method as follows: + +- ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. +- ``lower``: ``a``. +- ``higher``: ``b``. +- ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). +- ``midpoint``: ``(a + b) / 2``. 
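# Worked sketch of the index mapping and 'linear' interpolation described above
# (not part of the original docstring); assumes a sorted 1-D input so the
# arithmetic is easy to follow.
import torch

x = torch.tensor([0., 1., 2., 3.])
q = 0.6
pos = q * (x.numel() - 1)      # 1.8 -> lies between indices i=1 and j=2
i = int(pos)
fraction = pos - i             # 0.8
linear = x[i] + (x[i + 1] - x[i]) * fraction
assert torch.isclose(linear, torch.quantile(x, q, interpolation='linear'))  # both 1.8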
+ +If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size +equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + +.. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + +Args: + {input} + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + {dim} + {keepdim} + +Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + {out} + +Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.nanquantile, + r""" +nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + +This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, +computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did +not exist. If all values in a reduced row are ``NaN`` then the quantiles for +that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + +Args: + {input} + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + {dim} + {keepdim} + +Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + {out} + +Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.min, + r""" +min(input) -> Tensor + +Returns the minimum value of all elements in the :attr:`input` tensor. + +.. warning:: + This function produces deterministic (sub)gradients unlike ``min(dim=0)`` + +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + +.. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + +Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each minimum value found +(argmin). 
+ +If :attr:`keepdim` is ``True``, the output tensors are of the same size as +:attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in +the output tensors having 1 fewer dimension than :attr:`input`. + +.. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + +.. function:: min(input, other, *, out=None) -> Tensor + :noindex: + +See :func:`torch.minimum`. +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.minimum, + r""" +minimum(input, other, *, out=None) -> Tensor + +Computes the element-wise minimum of :attr:`input` and :attr:`other`. + +.. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) +""".format( + **common_args + ), +) + +add_docstr( + torch.fmin, + r""" +fmin(input, other, *, out=None) -> Tensor + +Computes the element-wise minimum of :attr:`input` and :attr:`other`. + +This is like :func:`torch.minimum` except it handles NaNs differently: +if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. +Only if both elements are NaN is NaN propagated. + +This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and floating-point inputs. + +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) +""".format( + **common_args + ), +) + +add_docstr( + torch.amin, + r""" +amin(input, dim, keepdim=False, *, out=None) -> Tensor + +Returns the minimum value of each slice of the :attr:`input` tensor in the given +dimension(s) :attr:`dim`. + +.. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices, + - ``amax``/``amin`` evenly distributes gradient between equal values, + while ``max(dim)``/``min(dim)`` propagates gradient only to a single + index in the source tensor. 
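# Sketch of the gradient note above (illustrative, not part of the original
# docstring): with tied minima, amin splits the gradient while min(dim) sends it
# all to the single returned index.
import torch

x = torch.tensor([[2., 1., 1.]], requires_grad=True)
torch.amin(x, dim=1).sum().backward()
x.grad   # roughly tensor([[0.0000, 0.5000, 0.5000]]) -- split across ties

y = torch.tensor([[2., 1., 1.]], requires_grad=True)
torch.min(y, dim=1).values.sum().backward()
y.grad   # e.g. tensor([[0., 1., 0.]]) -- all gradient to the returned index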
+ +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.aminmax, + r""" +aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + +Computes the minimum and maximum values of the :attr:`input` tensor. + +Args: + input (Tensor): + The input tensor + +Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + +Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + +Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + +.. note:: + NaN values are propagated to the output if at least one value is NaN. + +.. seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + +Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) +""", +) + +add_docstr( + torch.argmin, + r""" +argmin(input, dim=None, keepdim=False) -> LongTensor + +Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + +This is the second value returned by :meth:`torch.min`. See its +documentation for the exact semantics of this method. + +.. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + +Args: + {input} + {dim} If ``None``, the argmin of the flattened input is returned. + {keepdim} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.mm, + r""" +mm(input, mat2, *, out=None) -> Tensor + +Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + +If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a +:math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + +.. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. 
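# Sketch (not part of the original docstring): mm agrees with matmul for plain
# 2-D inputs, but unlike matmul it accepts no batch dimensions.
import torch
a, b = torch.randn(2, 3), torch.randn(3, 4)
assert torch.allclose(torch.mm(a, b), torch.matmul(a, b))
# torch.mm(torch.randn(5, 2, 3), torch.randn(5, 3, 4))  # error; use matmul or bmm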
+ +Supports strided and sparse 2-D tensors as inputs, autograd with +respect to strided inputs. + +This operation has support for arguments with :ref:`sparse layouts`. +If :attr:`out` is provided it's layout will be used. Otherwise, the result +layout will be deduced from that of :attr:`input`. + +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + +Keyword args: + {out} + +Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes + ), +) + +add_docstr( + torch.hspmm, + r""" +hspmm(mat1, mat2, *, out=None) -> Tensor + +Performs a matrix multiplication of a :ref:`sparse COO matrix +` :attr:`mat1` and a strided matrix :attr:`mat2`. The +result is a (1 + 1)-dimensional :ref:`hybrid COO matrix +`. + +Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + +Keyword args: + {out} +""".format( + **common_args + ), +) + +add_docstr( + torch.matmul, + r""" +matmul(input, other, *, out=None) -> Tensor + +Matrix product of two tensors. + +The behavior depends on the dimensionality of the tensors as follows: + +- If both tensors are 1-dimensional, the dot product (scalar) is returned. +- If both arguments are 2-dimensional, the matrix-matrix product is returned. +- If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. +- If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. +- If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + +This operation has support for arguments with :ref:`sparse layouts`. In particular the +matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions +as :func:`torch.mm` + +{sparse_beta_warning} + +{tf32_note} + +{rocm_fp16_note} + +.. 
note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. + +Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + +Keyword args: + {out} + +Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + +""".format( + **common_args, **tf32_notes, **rocm_fp16_notes, **sparse_support_notes + ), +) + +add_docstr( + torch.mode, + r""" +mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + +Returns a namedtuple ``(values, indices)`` where ``values`` is the mode +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`, i.e. a value which appears most often +in that row, and ``indices`` is the index location of each mode value found. + +By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + +If :attr:`keepdim` is ``True``, the output tensors are of the same size as +:attr:`input` except in the dimension :attr:`dim` where they are of size 1. +Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting +in the output tensors having 1 fewer dimension than :attr:`input`. + +.. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + +Example:: + + >>> b = torch.tensor( + [[0, 0, 0, 2, 0, 0, 2], + [0, 3, 0, 0, 2, 0, 1], + [2, 2, 2, 0, 0, 0, 3], + [2, 2, 3, 0, 1, 1, 0], + [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.mul, + r""" +mul(input, other, *, out=None) -> Tensor + +Multiplies :attr:`input` by :attr:`other`. + + +.. math:: + \text{out}_i = \text{input}_i \times \text{other}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number) - the tensor or number to multiply input by. 
+ +Keyword args: + {out} + +Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.multiply, + r""" +multiply(input, other, *, out=None) + +Alias for :func:`torch.mul`. +""", +) + +add_docstr( + torch.multinomial, + r""" +multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + +Returns a tensor where each row contains :attr:`num_samples` indices sampled +from the multinomial (a stricter definition would be multivariate, +refer to torch.distributions.multinomial.Multinomial for more details) +probability distribution located in the corresponding row +of tensor :attr:`input`. + +.. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + +Indices are ordered from left to right according to when each was sampled +(first samples are placed in first column). + +If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + +If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape +:math:`(m \times \text{{num\_samples}})`. + +If replacement is ``True``, samples are drawn with replacement. + +If not, they are drawn without replacement, which means that when a +sample index is drawn for a row, it cannot be drawn again for that row. + +.. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + +Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + +Keyword args: + {generator} + {out} + +Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 4) # ERROR! + RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, + not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320 + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) +""".format( + **common_args + ), +) + +add_docstr( + torch.mv, + r""" +mv(input, vec, *, out=None) -> Tensor + +Performs a matrix-vector product of the matrix :attr:`input` and the vector +:attr:`vec`. + +If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of +size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + +.. note:: This function does not :ref:`broadcast `. 
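# Small sketch (not from the original file): for a single matrix and vector,
# mv agrees with matmul; batched inputs need matmul, since mv does not broadcast.
import torch
mat = torch.randn(2, 3)
vec = torch.randn(3)
assert torch.allclose(torch.mv(mat, vec), torch.matmul(mat, vec))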
+ +Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + +Keyword args: + {out} + +Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) +""".format( + **common_args + ), +) + +add_docstr( + torch.mvlgamma, + r""" +mvlgamma(input, p, *, out=None) -> Tensor + +Alias for :func:`torch.special.multigammaln`. +""", +) + +add_docstr( + torch.movedim, + r""" +movedim(input, source, destination) -> Tensor + +Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` +to the position(s) in :attr:`destination`. + +Other dimensions of :attr:`input` that are not explicitly moved remain in +their original order and appear at the positions not specified in :attr:`destination`. + +Args: + {input} + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + +Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.moveaxis, + r""" +moveaxis(input, source, destination) -> Tensor + +Alias for :func:`torch.movedim`. + +This function is equivalent to NumPy's moveaxis function. + +Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.swapdims, + r""" +swapdims(input, dim0, dim1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.swapaxes, + r""" +swapaxes(input, axis0, axis1) -> Tensor + +Alias for :func:`torch.transpose`. + +This function is equivalent to NumPy's swapaxes function. + +Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.narrow, + r""" +narrow(input, dim, start, length) -> Tensor + +Returns a new tensor that is a narrowed version of :attr:`input` tensor. 
The +dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The +returned tensor and :attr:`input` tensor share the same underlying storage. + +Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + +Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) +""", +) + +add_docstr( + torch.narrow_copy, + r""" +narrow_copy(input, dim, start, length, *, out=None) -> Tensor + +Same as :meth:`Tensor.narrow` except this returns a copy rather +than shared storage. This is primarily for sparse tensors, which +do not have a shared-storage narrow method. + +Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + +.. seealso:: + + :func:`torch.narrow` for a non copy variant + +""".format( + **common_args + ), +) + +add_docstr( + torch.nan_to_num, + r""" +nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + +Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` +with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. +By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the +greatest finite value representable by :attr:`input`'s dtype, and negative infinity +is replaced with the least finite value representable by :attr:`input`'s dtype. + +Args: + {input} + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. 
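# Sketch (not part of the original docstring): the default replacements come
# straight from the dtype's finite range, as described above.
import torch
x = torch.tensor([float('nan'), float('inf'), float('-inf')])
out = torch.nan_to_num(x)
assert out[0] == 0
assert out[1] == torch.finfo(x.dtype).max   # greatest finite float32 value
assert out[2] == torch.finfo(x.dtype).min   # least finite float32 value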
+ +Keyword args: + {out} + +Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.ne, + r""" +ne(input, other, *, out=None) -> Tensor + +Computes :math:`\text{input} \neq \text{other}` element-wise. +""" + + r""" + +The second argument can be a number or a tensor whose shape is +:ref:`broadcastable ` with the first argument. + +Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + +Keyword args: + {out} + +Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + +Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.not_equal, + r""" +not_equal(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.ne`. +""", +) + +add_docstr( + torch.neg, + r""" +neg(input, *, out=None) -> Tensor + +Returns a new tensor with the negative of the elements of :attr:`input`. + +.. math:: + \text{out} = -1 \times \text{input} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) +""".format( + **common_args + ), +) + +add_docstr( + torch.negative, + r""" +negative(input, *, out=None) -> Tensor + +Alias for :func:`torch.neg` +""", +) + +add_docstr( + torch.nextafter, + r""" +nextafter(input, other, *, out=None) -> Tensor + +Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + +The shapes of ``input`` and ``other`` must be +:ref:`broadcastable `. + +Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.nonzero, + r""" +nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + +.. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + +**When** :attr:`as_tuple` **is** ``False`` **(default)**: + +Returns a tensor containing the indices of all non-zero elements of +:attr:`input`. Each row in the result contains the indices of a non-zero +element in :attr:`input`. The result is sorted lexicographically, with +the last index changing the fastest (C-style). 
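# Sketch of the as_tuple note above (not part of the original docstring): the
# tuple form can be used directly for advanced indexing to recover the values.
import torch
x = torch.tensor([[0., 1.], [2., 0.]])
idx = torch.nonzero(x, as_tuple=True)   # (tensor([0, 1]), tensor([1, 0]))
x[idx]                                  # tensor([1., 2.])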
+ +If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor +:attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +**When** :attr:`as_tuple` **is** ``True``: + +Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, +each containing the indices (in that dimension) of all non-zero elements of +:attr:`input` . + +If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` +tensors of size :math:`z`, where :math:`z` is the total number of +non-zero elements in the :attr:`input` tensor. + +As a special case, when :attr:`input` has zero dimensions and a nonzero scalar +value, it is treated as a one-dimensional tensor with one element. + +Args: + {input} + +Keyword args: + out (LongTensor, optional): the output tensor containing indices + +Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + +Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) +""".format( + **common_args + ), +) + +add_docstr( + torch.normal, + r""" +normal(mean, std, *, generator=None, out=None) -> Tensor + +Returns a tensor of random numbers drawn from separate normal distributions +whose mean and standard deviation are given. + +The :attr:`mean` is a tensor with the mean of +each output element's normal distribution + +The :attr:`std` is a tensor with the standard deviation of +each output element's normal distribution + +The shapes of :attr:`mean` and :attr:`std` don't need to match, but the +total number of elements in each tensor need to be the same. + +.. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + +.. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + +Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + +Keyword args: + {generator} + {out} + +Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + +.. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the means are shared among all drawn +elements. + +Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + +Keyword args: + {out} + +Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + +.. 
function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the standard deviations are shared among +all drawn elements. + +Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + +Keyword args: + out (Tensor, optional): the output tensor + +Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + +.. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + +Similar to the function above, but the means and standard deviations are shared +among all drawn elements. The resulting tensor has size given by :attr:`size`. + +Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + +Keyword args: + {out} + +Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.numel, + r""" +numel(input) -> int + +Returns the total number of elements in the :attr:`input` tensor. + +Args: + {input} + +Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + +""".format( + **common_args + ), +) + +add_docstr( + torch.ones, + r""" +ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with the scalar value `1`, with the shape defined +by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword arguments: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.ones_like, + r""" +ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor filled with the scalar value `1`, with the same size as +:attr:`input`. ``torch.ones_like(input)`` is equivalent to +``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + +Args: + {input} + +Keyword arguments: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.orgqr, + r""" +orgqr(input, tau) -> Tensor + +Alias for :func:`torch.linalg.householder_product`. +""", +) + +add_docstr( + torch.ormqr, + r""" +ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + +Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + +Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, +where `Q` is represented using Householder reflectors `(input, tau)`. +See `Representation of Orthogonal or Unitary Matrices`_ for further details. 
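# Sketch (not part of the original docstring): combining geqrf with ormqr applies
# the implicit Q (here its transpose) without ever materializing Q -- Q^T A
# recovers the R factor stored in the upper triangle of the geqrf output.
import torch

A = torch.randn(4, 3)
h, tau = torch.geqrf(A)                        # Householder representation of Q; R in triu(h)
QtA = torch.ormqr(h, tau, A, transpose=True)   # left=True (default), op(Q) = Q^T
assert torch.allclose(QtA, h.triu(), atol=1e-5)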
+
+If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`.
+When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`.
+It has size :math:`n \times n` otherwise.
+If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op.
+
+Supports inputs of float, double, cfloat and cdouble dtypes.
+Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions.
+
+.. seealso::
+        :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q`
+        from the QR decomposition.
+
+.. note::
+    This function supports backward but it is only fast when ``(input, tau)`` do not require gradients
+    and/or ``tau.size(-1)`` is very small.
+
+Args:
+    input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions
+                    and `mn` equals `m` or `n` depending on :attr:`left`.
+    tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions.
+    other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+    left (bool): controls the order of multiplication.
+    transpose (bool): controls whether the matrix `Q` is conjugate transposed or not.
+
+Keyword args:
+    out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`.
+
+.. _Representation of Orthogonal or Unitary Matrices:
+    https://www.netlib.org/lapack/lug/node128.html
+""",
+)
+
+add_docstr(
+    torch.permute,
+    r"""
+permute(input, dims) -> Tensor
+
+Returns a view of the original tensor :attr:`input` with its dimensions permuted.
+
+Args:
+    {input}
+    dims (tuple of int): The desired ordering of dimensions
+
+Example::
+
+    >>> x = torch.randn(2, 3, 5)
+    >>> x.size()
+    torch.Size([2, 3, 5])
+    >>> torch.permute(x, (2, 0, 1)).size()
+    torch.Size([5, 2, 3])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.poisson,
+    r"""
+poisson(input, generator=None) -> Tensor
+
+Returns a tensor of the same size as :attr:`input` with each element
+sampled from a Poisson distribution with rate parameter given by the corresponding
+element in :attr:`input` i.e.,
+
+.. math::
+    \text{{out}}_i \sim \text{{Poisson}}(\text{{input}}_i)
+
+:attr:`input` must be non-negative.
+
+Args:
+    input (Tensor): the input tensor containing the rates of the Poisson distribution
+
+Keyword args:
+    {generator}
+
+Example::
+
+    >>> rates = torch.rand(4, 4) * 5  # rate parameter between 0 and 5
+    >>> torch.poisson(rates)
+    tensor([[9., 1., 3., 5.],
+            [8., 6., 6., 0.],
+            [0., 4., 5., 3.],
+            [2., 1., 4., 2.]])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.polygamma,
+    r"""
+polygamma(n, input, *, out=None) -> Tensor
+
+Alias for :func:`torch.special.polygamma`.
+""",
+)
+
+add_docstr(
+    torch.positive,
+    r"""
+positive(input) -> Tensor
+
+Returns :attr:`input`.
+Throws a runtime error if :attr:`input` is a bool tensor.
+"""
+    + r"""
+Args:
+    {input}
+
+Example::
+
+    >>> t = torch.randn(5)
+    >>> t
+    tensor([ 0.0090, -0.2262, -0.0682, -0.2866,  0.3940])
+    >>> torch.positive(t)
+    tensor([ 0.0090, -0.2262, -0.0682, -0.2866,  0.3940])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.pow,
+    r"""
+pow(input, exponent, *, out=None) -> Tensor
+
+Takes the power of each element in :attr:`input` with :attr:`exponent` and
+returns a tensor with the result.
+
+:attr:`exponent` can be either a single ``float`` number or a `Tensor`
+with the same number of elements as :attr:`input`.
+ +When :attr:`exponent` is a scalar value, the operation applied is: + +.. math:: + \text{out}_i = x_i ^ \text{exponent} + +When :attr:`exponent` is a tensor, the operation applied is: + +.. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} +""" + + r""" +When :attr:`exponent` is a tensor, the shapes of :attr:`input` +and :attr:`exponent` must be :ref:`broadcastable `. + +Args: + {input} + exponent (float or tensor): the exponent value + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + +.. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + +:attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. +The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + +The operation applied is: + +.. math:: + \text{{out}}_i = \text{{self}} ^ {{\text{{exponent}}_i}} + +Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + +Keyword args: + {out} + +Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.float_power, + r""" +float_power(input, exponent, *, out=None) -> Tensor + +Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. +If neither input is complex returns a ``torch.float64`` tensor, +and if one or more inputs is complex returns a ``torch.complex128`` tensor. + +.. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + +Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + +Keyword args: + {out} + +Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) +""".format( + **common_args + ), +) + +add_docstr( + torch.prod, + r""" +prod(input, *, dtype=None) -> Tensor + +Returns the product of all elements in the :attr:`input` tensor. + +Args: + {input} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + +.. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the product of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. 
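+
+For instance, with ``keepdim=True`` the reduced dimension is retained with size 1
+(an illustrative sketch; the printed formatting may differ slightly)::
+
+    >>> a = torch.tensor([[1., 2.], [3., 4.]])
+    >>> torch.prod(a, 1, keepdim=True)
+    tensor([[ 2.],
+            [12.]])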
+ +{keepdim_details} + +Args: + {input} + {dim} + {keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) +""".format( + **single_dim_common + ), +) + +add_docstr( + torch.promote_types, + r""" +promote_types(type1, type2) -> dtype + +Returns the :class:`torch.dtype` with the smallest size and scalar kind that is +not smaller nor of lower kind than either `type1` or `type2`. See type promotion +:ref:`documentation ` for more information on the type +promotion logic. + +Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + +Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long +""", +) + +add_docstr( + torch.qr, + r""" +qr(input, some=True, *, out=None) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, +and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` +with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and +:math:`R` being an upper triangular matrix or batch of upper triangular matrices. + +If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. +Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + +.. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + +.. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + +.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + +Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. 
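+
+For a tall matrix, the two modes differ only in the shapes they return (an
+illustrative sketch of the dimensions listed above)::
+
+    >>> A = torch.randn(5, 3)
+    >>> q, r = torch.qr(A)              # reduced, some=True
+    >>> q.shape, r.shape
+    (torch.Size([5, 3]), torch.Size([3, 3]))
+    >>> q, r = torch.qr(A, some=False)  # complete
+    >>> q.shape, r.shape
+    (torch.Size([5, 5]), torch.Size([5, 3]))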
+ +Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True +""", +) + +add_docstr( + torch.rad2deg, + r""" +rad2deg(input, *, out=None) -> Tensor + +Returns a new tensor with each of the elements of :attr:`input` +converted from angles in radians to degrees. + +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.deg2rad, + r""" +deg2rad(input, *, out=None) -> Tensor + +Returns a new tensor with each of the elements of :attr:`input` +converted from angles in degrees to radians. + +Args: + {input} + +Keyword arguments: + {out} + +Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.heaviside, + r""" +heaviside(input, values, *, out=None) -> Tensor + +Computes the Heaviside step function for each element in :attr:`input`. +The Heaviside step function is defined as: + +.. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} +""" + + r""" + +Args: + {input} + values (Tensor): The values to use where :attr:`input` is zero. + +Keyword arguments: + {out} + +Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.rand, + """ +rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, \ +requires_grad=False, pin_memory=False) -> Tensor +""" + + r""" +Returns a tensor filled with random numbers from a uniform distribution +on the interval :math:`[0, 1)` + +The shape of the tensor is defined by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
+ +Keyword args: + {generator} + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.rand_like, + r""" +rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` that is filled with +random numbers from a uniform distribution on the interval :math:`[0, 1)`. +``torch.rand_like(input)`` is equivalent to +``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.randint, + """ +randint(low=0, high, size, \\*, generator=None, out=None, \ +dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with random integers generated uniformly +between :attr:`low` (inclusive) and :attr:`high` (exclusive). + +The shape of the tensor is defined by the variable argument :attr:`size`. + +.. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + +Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + +Keyword args: + {generator} + {out} + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + + +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.randint_like, + """ +randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same shape as Tensor :attr:`input` filled with +random integers generated uniformly between :attr:`low` (inclusive) and +:attr:`high` (exclusive). + +.. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + +Args: + {input} + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.randn, + """ +randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +pin_memory=False) -> Tensor +""" + + r""" + +Returns a tensor filled with random numbers from a normal distribution +with mean `0` and variance `1` (also called the standard normal +distribution). + +.. math:: + \text{{out}}_{{i}} \sim \mathcal{{N}}(0, 1) + +For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and +unit variance as + +.. 
math:: + \text{{out}}_{{i}} \sim \mathcal{{CN}}(0, 1) + +This is equivalent to separately sampling the real :math:`(\operatorname{{Re}})` and imaginary +:math:`(\operatorname{{Im}})` part of :math:`\text{{out}}_i` as + +.. math:: + \operatorname{{Re}}(\text{{out}}_{{i}}) \sim \mathcal{{N}}(0, \frac{{1}}{{2}}),\quad + \operatorname{{Im}}(\text{{out}}_{{i}}) \sim \mathcal{{N}}(0, \frac{{1}}{{2}}) + +The shape of the tensor is defined by the variable argument :attr:`size`. + + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {generator} + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + +.. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.randn_like, + r""" +randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` that is filled with +random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the +sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to +``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.randperm, + """ +randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, \ +device=None, requires_grad=False, pin_memory=False) -> Tensor +""" + + r""" +Returns a random permutation of integers from ``0`` to ``n - 1``. + +Args: + n (int): the upper bound (exclusive) + +Keyword args: + {generator} + {out} + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.tensor, + r""" +tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + +.. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.clone().detach()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.clone().detach().requires_grad_(True)``. + +.. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + +Args: + {data} + +Keyword args: + {dtype} + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. 
+ {requires_grad} + {pin_memory} + + +Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) +""".format( + **factory_data_common_args + ), +) + +add_docstr( + torch.range, + r""" +range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` +with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is +the gap between two values in the tensor. + +.. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. +""" + + r""" +.. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + +Args: + start (float): the starting value for the set of points. Default: ``0``. + end (float): the ending value for the set of points + step (float): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: + {out} + {dtype} If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.arange, + r""" +arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` +with values from the interval ``[start, end)`` taken with common difference +:attr:`step` beginning from `start`. + +Note that non-integer :attr:`step` is subject to floating point rounding errors when +comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` +in such cases. + +.. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} +""" + + r""" +Args: + start (Number): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number): the gap between each pair of adjacent points. Default: ``1``. + +Keyword args: + {out} + {dtype} If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. 
+ {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.ravel, + r""" +ravel(input) -> Tensor + +Return a contiguous flattened tensor. A copy is made only if needed. + +Args: + {input} + +Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) +""".format( + **common_args + ), +) + +add_docstr( + torch.remainder, + r""" +remainder(input, other, *, out=None) -> Tensor + +Computes +`Python's modulus operation `_ +entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value +is less than that of :attr:`other`. + +It may also be defined in terms of :func:`torch.div` as + +.. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer and float inputs. + +.. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + +.. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + +Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + +Keyword args: + {out} + +Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) +""".format( + **common_args + ), +) + +add_docstr( + torch.renorm, + r""" +renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + +Returns a tensor where each sub-tensor of :attr:`input` along dimension +:attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower +than the value :attr:`maxnorm` + +.. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + +Args: + {input} + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + +Keyword args: + {out} + +Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.reshape, + r""" +reshape(input, shape) -> Tensor + +Returns a tensor with the same data and number of elements as :attr:`input`, +but with the specified shape. When possible, the returned tensor will be a view +of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs +with compatible strides can be reshaped without copying, but you should not +depend on the copying vs. viewing behavior. + +See :meth:`torch.Tensor.view` on when it is possible to return a view. + +A single dimension may be -1, in which case it's inferred from the remaining +dimensions and the number of elements in :attr:`input`. 
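+
+For instance, a ``-1`` entry is inferred from the element count (an illustrative
+sketch; the printed formatting may differ slightly)::
+
+    >>> torch.reshape(torch.arange(6.), (2, -1))
+    tensor([[0., 1., 2.],
+            [3., 4., 5.]])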
+
+Args:
+    input (Tensor): the tensor to be reshaped
+    shape (tuple of int): the new shape
+
+Example::
+
+    >>> a = torch.arange(4.)
+    >>> torch.reshape(a, (2, 2))
+    tensor([[ 0.,  1.],
+            [ 2.,  3.]])
+    >>> b = torch.tensor([[0, 1], [2, 3]])
+    >>> torch.reshape(b, (-1,))
+    tensor([ 0,  1,  2,  3])
+""",
+)
+
+
+add_docstr(
+    torch.result_type,
+    r"""
+result_type(tensor1, tensor2) -> dtype
+
+Returns the :class:`torch.dtype` that would result from performing an arithmetic
+operation on the provided input tensors. See type promotion :ref:`documentation `
+for more information on the type promotion logic.
+
+Args:
+    tensor1 (Tensor or Number): an input tensor or number
+    tensor2 (Tensor or Number): an input tensor or number
+
+Example::
+
+    >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0)
+    torch.float32
+    >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1))
+    torch.uint8
+""",
+)
+
+add_docstr(
+    torch.row_stack,
+    r"""
+row_stack(tensors, *, out=None) -> Tensor
+
+Alias of :func:`torch.vstack`.
+""",
+)
+
+add_docstr(
+    torch.round,
+    r"""
+round(input, *, decimals=0, out=None) -> Tensor
+
+Rounds elements of :attr:`input` to the nearest integer.
+
+For integer inputs, this follows the array API convention of returning a
+copy of the input tensor. The output has the same dtype as the input.
+
+.. note::
+    This function implements the "round half to even" rule to
+    break ties when a number is equidistant from two
+    integers (e.g. `round(2.5)` is 2).
+
+    When the :attr:`decimals` argument is specified, the
+    algorithm used is similar to NumPy's `around`. This
+    algorithm is fast but inexact and it can easily
+    overflow for low precision dtypes.
+    E.g. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`.
+
+.. seealso::
+    :func:`torch.ceil`, which rounds up.
+    :func:`torch.floor`, which rounds down.
+    :func:`torch.trunc`, which rounds towards zero.
+
+Args:
+    {input}
+    decimals (int): Number of decimal places to round to (default: 0).
+        If decimals is negative, it specifies the number of positions
+        to the left of the decimal point.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7)))
+    tensor([ 5., -2.,  9., -8.])
+
+    >>> # Values equidistant from two integers are rounded towards the
+    >>> # nearest even value (zero is treated as even)
+    >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5]))
+    tensor([-0., 0., 2., 2.])
+
+    >>> # A positive decimals argument rounds to that decimal place
+    >>> torch.round(torch.tensor([0.1234567]), decimals=3)
+    tensor([0.1230])
+
+    >>> # A negative decimals argument rounds to the left of the decimal
+    >>> torch.round(torch.tensor([1200.1234567]), decimals=-3)
+    tensor([1000.])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.rsqrt,
+    r"""
+rsqrt(input, *, out=None) -> Tensor
+
+Returns a new tensor with the reciprocal of the square-root of each of
+the elements of :attr:`input`.
+
+.. math::
+    \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}}
+"""
+    + r"""
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.randn(4)
+    >>> a
+    tensor([-0.0370,  0.2970,  1.5420, -0.9105])
+    >>> torch.rsqrt(a)
+    tensor([    nan,  1.8351,  0.8053,     nan])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.scatter,
+    r"""
+scatter(input, dim, index, src) -> Tensor
+
+Out-of-place version of :meth:`torch.Tensor.scatter_`
+""",
+)
+
+add_docstr(
+    torch.scatter_add,
+    r"""
+scatter_add(input, dim, index, src) -> Tensor
+
+Out-of-place version of :meth:`torch.Tensor.scatter_add_`
+""",
+)
+
+add_docstr(
+    torch.scatter_reduce,
+    r"""
+scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor
+
+Out-of-place version of :meth:`torch.Tensor.scatter_reduce_`
+""",
+)
+
+add_docstr(
+    torch.select,
+    r"""
+select(input, dim, index) -> Tensor
+
+Slices the :attr:`input` tensor along the selected dimension at the given index.
+This function returns a view of the original tensor with the given dimension removed.
+
+.. note:: If :attr:`input` is a sparse tensor and returning a view of
+          the tensor is not possible, a RuntimeError exception is
+          raised. If this is the case, consider using the
+          :func:`torch.select_copy` function.
+
+Args:
+    {input}
+    dim (int): the dimension to slice
+    index (int): the index to select with
+
+.. note::
+
+    :meth:`select` is equivalent to slicing. For example,
+    ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and
+    ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.select_scatter,
+    r"""
+select_scatter(input, src, dim, index) -> Tensor
+
+Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index.
+This function returns a tensor with fresh storage; it does not create a view.
+
+Args:
+    {input}
+    src (Tensor): The tensor to embed into :attr:`input`
+    dim (int): the dimension to insert the slice into.
+    index (int): the index to select with
+
+.. note::
+
+    :attr:`src` must be of the proper size in order to be embedded
+    into :attr:`input`. Specifically, it should have the same shape as
+    ``torch.select(input, dim, index)``
+
+Example::
+
+    >>> a = torch.zeros(2, 2)
+    >>> b = torch.ones(2)
+    >>> a.select_scatter(b, 0, 0)
+    tensor([[1., 1.],
+            [0., 0.]])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.slice_scatter,
+    r"""
+slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor
+
+Embeds the values of the :attr:`src` tensor into :attr:`input` at the given
+dimension.
+This function returns a tensor with fresh storage; it does not create a view.
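+
+Conceptually (a sketch of the semantics, not the actual implementation), the result is
+the same as cloning :attr:`input` and assigning :attr:`src` into the selected slice::
+
+    >>> a = torch.zeros(4, 4)
+    >>> b = torch.ones(4, 2)
+    >>> out = torch.slice_scatter(a, b, dim=1, start=0, end=4, step=2)
+    >>> ref = a.clone()
+    >>> ref[:, 0:4:2] = b
+    >>> torch.equal(out, ref)
+    True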
+
+Args:
+    {input}
+    src (Tensor): The tensor to embed into :attr:`input`
+    dim (int): the dimension to insert the slice into
+    start (Optional[int]): the start index of where to insert the slice
+    end (Optional[int]): the end index of where to insert the slice
+    step (int): the step between the slice indices that are written to
+
+Example::
+
+    >>> a = torch.zeros(8, 8)
+    >>> b = torch.ones(2, 8)
+    >>> a.slice_scatter(b, start=6)
+    tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0.],
+            [1., 1., 1., 1., 1., 1., 1., 1.],
+            [1., 1., 1., 1., 1., 1., 1., 1.]])
+
+    >>> b = torch.ones(8, 2)
+    >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2)
+    tensor([[0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.],
+            [0., 0., 1., 0., 1., 0., 0., 0.]])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.set_flush_denormal,
+    r"""
+set_flush_denormal(mode) -> bool
+
+Disables denormal floating numbers on CPU.
+
+Returns ``True`` if your system supports flushing denormal numbers and it
+successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal`
+is supported on x86 architectures supporting SSE3 and on the AArch64 architecture.
+
+Args:
+    mode (bool): Controls whether to enable flush denormal mode or not
+
+Example::
+
+    >>> torch.set_flush_denormal(True)
+    True
+    >>> torch.tensor([1e-323], dtype=torch.float64)
+    tensor([ 0.], dtype=torch.float64)
+    >>> torch.set_flush_denormal(False)
+    True
+    >>> torch.tensor([1e-323], dtype=torch.float64)
+    tensor(9.88131e-324 *
+           [ 1.0000], dtype=torch.float64)
+""",
+)
+
+add_docstr(
+    torch.set_num_threads,
+    r"""
+set_num_threads(int)
+
+Sets the number of threads used for intraop parallelism on CPU.
+
+.. warning::
+    To ensure that the correct number of threads is used, set_num_threads
+    must be called before running eager, JIT or autograd code.
+""",
+)
+
+add_docstr(
+    torch.set_num_interop_threads,
+    r"""
+set_num_interop_threads(int)
+
+Sets the number of threads used for interop parallelism
+(e.g. in JIT interpreter) on CPU.
+
+.. warning::
+    Can only be called once and before any inter-op parallel work
+    is started (e.g. JIT execution).
+""",
+)
+
+add_docstr(
+    torch.sigmoid,
+    r"""
+sigmoid(input, *, out=None) -> Tensor
+
+Alias for :func:`torch.special.expit`.
+""",
+)
+
+add_docstr(
+    torch.logit,
+    r"""
+logit(input, eps=None, *, out=None) -> Tensor
+
+Alias for :func:`torch.special.logit`.
+""",
+)
+
+add_docstr(
+    torch.sign,
+    r"""
+sign(input, *, out=None) -> Tensor
+
+Returns a new tensor with the signs of the elements of :attr:`input`.
+
+.. math::
+    \text{out}_{i} = \operatorname{sgn}(\text{input}_{i})
+"""
+    + r"""
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.tensor([0.7, -1.2, 0., 2.3])
+    >>> a
+    tensor([ 0.7000, -1.2000,  0.0000,  2.3000])
+    >>> torch.sign(a)
+    tensor([ 1., -1.,  0.,  1.])
+""".format(
+        **common_args
+    ),
+)
+
+add_docstr(
+    torch.signbit,
+    r"""
+signbit(input, *, out=None) -> Tensor
+
+Tests if each element of :attr:`input` has its sign bit set or not.
+ +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + +.. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + +""".format( + **common_args + ), +) + +add_docstr( + torch.sgn, + r""" +sgn(input, *, out=None) -> Tensor + +This function is an extension of torch.sign() to complex tensors. +It computes a new tensor whose elements have +the same angles as the corresponding elements of :attr:`input` and +absolute values (i.e. magnitudes) of one for complex tensors and +is equivalent to torch.sign() for non-complex tensors. + +.. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) +""".format( + **common_args + ), +) + +add_docstr( + torch.sin, + r""" +sin(input, *, out=None) -> Tensor + +Returns a new tensor with the sine of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sin(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) +""".format( + **common_args + ), +) + +add_docstr( + torch.sinc, + r""" +sinc(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.sinc`. +""", +) + +add_docstr( + torch.sinh, + r""" +sinh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic sine of the elements of +:attr:`input`. + +.. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + +.. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. +""".format( + **common_args + ), +) + +add_docstr( + torch.sort, + r""" +sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + +Sorts the elements of the :attr:`input` tensor along a given dimension +in ascending order by value. + +If :attr:`dim` is not given, the last dimension of the `input` is chosen. + +If :attr:`descending` is ``True`` then the elements are sorted in descending +order by value. + +If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving +the order of equivalent elements. + +A namedtuple of (values, indices) is returned, where the `values` are the +sorted values and `indices` are the indices of the elements in the original +`input` tensor. + +Args: + {input} + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ +Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + +Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) +""".format( + **common_args + ), +) + +add_docstr( + torch.argsort, + r""" +argsort(input, dim=-1, descending=False, stable=False) -> Tensor + +Returns the indices that sort a tensor along a given dimension in ascending +order by value. + +This is the second value returned by :meth:`torch.sort`. See its documentation +for the exact semantics of this method. + +If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving +the order of equivalent elements. If ``False``, the relative order of values +which compare equal is not guaranteed. ``True`` is slower. + +Args: + {input} + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.msort, + r""" +msort(input, *, out=None) -> Tensor + +Sorts the elements of the :attr:`input` tensor along its first dimension +in ascending order by value. + +.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.sparse_compressed_tensor, + r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """ + r"""*, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, +CSC, BSR, or BSC - ` with specified values at +the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse +matrix multiplication operations in Compressed Sparse format are +typically faster than that for sparse tensors in COO format. 
Make you +have a look at :ref:`the note on the data type of the indices +`. + +{sparse_factory_device_note} + +Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {requires_grad} + {check_invariants} + +Example:: + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_csr_tensor, + r"""sparse_csr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified +values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations +in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look +at :ref:`the note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. 
Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {requires_grad} + {check_invariants} + +Example:: + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_csc_tensor, + r"""sparse_csc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) +` with specified values at the given +:attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix +multiplication operations in CSC format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. 
Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {requires_grad} + {check_invariants} + +Example:: + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_bsr_tensor, + r"""sparse_bsr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) +` with specified 2-dimensional blocks at the given +:attr:`crow_indices` and :attr:`col_indices`. Sparse matrix +multiplication operations in BSR format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {requires_grad} + {check_invariants} + +Example:: + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... 
torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_bsc_tensor, + r"""sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor + +Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse +Column)) ` with specified 2-dimensional blocks at the +given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix +multiplication operations in BSC format are typically faster than that +for sparse tensors in COO format. Make you have a look at :ref:`the +note on the data type of the indices `. + +{sparse_factory_device_note} + +Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + {requires_grad} + {check_invariants} + +Example:: + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sparse_coo_tensor, + r"""sparse_coo_tensor(indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + +Constructs a :ref:`sparse tensor in COO(rdinate) format +` with specified values at the given +:attr:`indices`. + +.. note:: + + This function returns an :ref:`uncoalesced tensor + ` when :attr:`is_coalesced` is + unspecified or ``None``. + +{sparse_factory_device_note} + +Args: + indices (array_like): Initial data for the tensor. 
Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` + internally. The indices are the coordinates of the non-zero values in the matrix, and thus + should be two-dimensional where the first dimension is the number of tensor dimensions and + the second dimension is the number of non-zero values. + values (array_like): Initial values for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not + provided the size will be inferred as the minimum size big enough to hold all non-zero + elements. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if None, infers data type from :attr:`values`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + {requires_grad} + {check_invariants} + is_coalesced (bool, optional): When``True``, the caller is + responsible for providing tensor indices that correspond to a + coalesced tensor. If the :attr:`check_invariants` flag is + False, no error will be raised if the prerequisites are not + met and this will lead to silently incorrect results. To force + coalescion please use :meth:`coalesce` on the resulting + Tensor. + Default: None: except for trivial cases (e.g. nnz < 2) the + resulting Tensor has is_coalesced set to ``False```. + +Example:: + + >>> i = torch.tensor([[0, 1, 1], + ... [2, 0, 2]]) + >>> v = torch.tensor([3, 4, 5], dtype=torch.float32) + >>> torch.sparse_coo_tensor(i, v, [2, 4]) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 4), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v) # Shape inference + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v, [2, 4], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_coo) + + # Create an empty sparse tensor with the following invariants: + # 1. sparse_dim + dense_dim = len(SparseTensor.shape) + # 2. SparseTensor._indices().shape = (sparse_dim, nnz) + # 3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + +.. 
_torch.sparse: https://pytorch.org/docs/stable/sparse.html +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.sqrt, + r""" +sqrt(input, *, out=None) -> Tensor + +Returns a new tensor with the square-root of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) +""".format( + **common_args + ), +) + +add_docstr( + torch.square, + r""" +square(input, *, out=None) -> Tensor + +Returns a new tensor with the square of the elements of :attr:`input`. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) +""".format( + **common_args + ), +) + +add_docstr( + torch.squeeze, + r""" +squeeze(input, dim=None) -> Tensor + +Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + +For example, if `input` is of shape: +:math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` +will be of shape: :math:`(A \times B \times C \times D)`. + +When :attr:`dim` is given, a squeeze operation is done only in the given +dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, +``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` +will squeeze the tensor to the shape :math:`(A \times B)`. + +.. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + +.. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + +Args: + {input} + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + +Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) +""".format( + **common_args + ), +) + +add_docstr( + torch.std, + r""" +std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the standard deviation over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as + +.. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {dim} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. 
versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {keepdim} + {out} + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.std_mean, + r""" +std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the standard deviation and mean over the dimensions specified by +:attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or +``None`` to reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as + +.. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. + +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {keepdim} + {out} + +Returns: + A tuple (std, mean) containing the standard deviation and mean. + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.sub, + r""" +sub(input, other, *, alpha=1, out=None) -> Tensor + +Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + +.. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i +""" + + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + +Keyword args: + alpha (Number): the multiplier for :attr:`other`. + {out} + +Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) +""".format( + **common_args + ), +) + +add_docstr( + torch.subtract, + r""" +subtract(input, other, *, alpha=1, out=None) -> Tensor + +Alias for :func:`torch.sub`. +""", +) + +add_docstr( + torch.sum, + r""" +sum(input, *, dtype=None) -> Tensor + +Returns the sum of all elements in the :attr:`input` tensor. + +Args: + {input} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + +.. 
function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the sum of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, +reduce over all of them. + +{keepdim_details} + +Args: + {input} + {opt_dim} + {keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.nansum, + r""" +nansum(input, *, dtype=None) -> Tensor + +Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + +Args: + {input} + +Keyword args: + {dtype} + +Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + +.. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + +Returns the sum of each row of the :attr:`input` tensor in the given +dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. +If :attr:`dim` is a list of dimensions, reduce over all of them. + +{keepdim_details} + +Args: + {input} + {opt_dim} + {keepdim} + +Keyword args: + {dtype} + +Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + 1.0 + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) + >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.svd, + r""" +svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition of either a matrix or batch of +matrices :attr:`input`. The singular value decomposition is represented as a +namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. +where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, +and the conjugate transpose of `V` for complex inputs. +If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also +batched with the same batch dimensions as :attr:`input`. + +If :attr:`some` is `True` (default), the method returns the reduced singular +value decomposition. In this case, if the last two dimensions of :attr:`input` are +`m` and `n`, then the returned `U` and `V` matrices will contain only +`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is `False`, the returned `U` and `V` will be +zero-filled matrices of shape `(m, m)` and `(n, n)` +respectively, and the same device as :attr:`input`. The argument :attr:`some` +has no effect when :attr:`compute_uv` is `False`. + +Supports :attr:`input` of float, double, cfloat and cdouble data types. +The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will +always be real-valued, even if :attr:`input` is complex. + +.. warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. 
code:: python
+
+        S = torch.linalg.svdvals(A)
+
+.. note:: Differences with :func:`torch.linalg.svd`:
+
+             * :attr:`some` is the opposite of
+               :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that the
+               default value for both is `True`, so the default behavior is
+               effectively the opposite.
+             * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns
+               `Vh`, that is, :math:`V^{\text{H}}`.
+             * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled
+               tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns
+               empty tensors.
+
+.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
+          then the singular values of each matrix in the batch are returned in descending order.
+
+.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`.
+
+.. note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]`
+          and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors
+          can be arbitrary bases of the corresponding subspaces.
+
+.. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd`
+          (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously,
+          on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243
+          and later, and MAGMA's routine `gesdd` on earlier versions of CUDA.
+
+.. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will
+          be represented as a column-major matrix (i.e. Fortran-contiguous).
+
+.. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not
+             have zero or repeated singular values.
+
+.. warning:: If the distance between any two singular values is close to zero, the gradients with respect to
+             `U` and `V` will be numerically unstable, as they depend on
+             :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix
+             has small singular values, as these gradients also depend on :math:`S^{-1}`.
+
+.. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique,
+             as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column.
+             The same happens when :attr:`input` has repeated singular values, where one may multiply
+             the columns of the spanning subspace in `U` and `V` by a rotation matrix
+             and `the resulting vectors will span the same subspace`_.
+             Different platforms, like NumPy, or inputs on different device types,
+             may produce different `U` and `V` tensors.
+
+Args:
+    input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more
+                    batch dimensions consisting of `(m, n)` matrices.
+    some (bool, optional): controls whether to compute the reduced or full decomposition, and
+                           consequently, the shape of returned `U` and `V`. Default: `True`.
+    compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`.
+ +Keyword args: + out (tuple, optional): the output tuple of tensors + +Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + +.. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) +""", +) + + +add_docstr( + torch.t, + r""" +t(input) -> Tensor + +Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 +and 1. + +0-D and 1-D tensors are returned as is. When input is a 2-D tensor this +is equivalent to ``transpose(input, 0, 1)``. + +Args: + {input} + +Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + +See also :func:`torch.transpose`. +""".format( + **common_args + ), +) + +add_docstr( + torch.flip, + r""" +flip(input, dims) -> Tensor + +Reverse the order of an n-D tensor along given axis in dims. + +.. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + +Args: + {input} + dims (a list or tuple): axis to flip on + +Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.fliplr, + r""" +fliplr(input) -> Tensor + +Flip tensor in the left/right direction, returning a new tensor. + +Flip the entries in each row in the left/right direction. +Columns are preserved, but appear in a different order than before. + +Note: + Requires the tensor to be at least 2-D. + +.. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + +Args: + input (Tensor): Must be at least 2-dimensional. + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.flipud, + r""" +flipud(input) -> Tensor + +Flip tensor in the up/down direction, returning a new tensor. 
+ +Flip the entries in each column in the up/down direction. +Rows are preserved, but appear in a different order than before. + +Note: + Requires the tensor to be at least 1-D. + +.. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + +Args: + input (Tensor): Must be at least 1-dimensional. + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.roll, + r""" +roll(input, shifts, dims=None) -> Tensor + +Roll the tensor :attr:`input` along the given dimension(s). Elements that are +shifted beyond the last position are re-introduced at the first position. If +:attr:`dims` is `None`, the tensor will be flattened before rolling and then +restored to the original shape. + +Args: + {input} + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.rot90, + r""" +rot90(input, k=1, dims=[0,1]) -> Tensor + +Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. +Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + +Args: + {input} + k (int): number of times to rotate. Default value is 1 + dims (a list or tuple): axis to rotate. Default value is [0, 1] + +Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.take, + r""" +take(input, index) -> Tensor + +Returns a new tensor with the elements of :attr:`input` at the given indices. +The input tensor is treated as if it were viewed as a 1-D tensor. The result +takes the same shape as the indices. + +Args: + {input} + index (LongTensor): the indices into tensor + +Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) +""".format( + **common_args + ), +) + +add_docstr( + torch.take_along_dim, + r""" +take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + +Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + +If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. 
+ +Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, +are designed to work with this function. See the examples below. + +.. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + +Args: + {input} + indices (tensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. + +Keyword args: + {out} + +Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.tan, + r""" +tan(input, *, out=None) -> Tensor + +Returns a new tensor with the tangent of the elements of :attr:`input`. + +.. math:: + \text{out}_{i} = \tan(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) +""".format( + **common_args + ), +) + +add_docstr( + torch.tanh, + r""" +tanh(input, *, out=None) -> Tensor + +Returns a new tensor with the hyperbolic tangent of the elements +of :attr:`input`. + +.. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) +""" + + r""" +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) +""".format( + **common_args + ), +) + +add_docstr( + # torch.softmax doc str. Point this to torch.nn.functional.softmax + torch.softmax, + r""" +softmax(input, dim, *, dtype=None) -> Tensor + +Alias for :func:`torch.nn.functional.softmax`. +""", +) + +add_docstr( + torch.topk, + r""" +topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + +Returns the :attr:`k` largest elements of the given :attr:`input` tensor along +a given dimension. + +If :attr:`dim` is not given, the last dimension of the `input` is chosen. + +If :attr:`largest` is ``False`` then the `k` smallest elements are returned. + +A namedtuple of `(values, indices)` is returned with the `values` and +`indices` of the largest `k` elements of each row of the `input` tensor in the +given dimension `dim`. + +The boolean option :attr:`sorted` if ``True``, will make sure that the returned +`k` elements are themselves sorted + +Args: + {input} + k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + +Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + +Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) +""".format( + **common_args + ), +) + +add_docstr( + torch.trace, + r""" +trace(input) -> Tensor + +Returns the sum of the elements of the diagonal of the input 2-D matrix. + +Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) 
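+
+    >>> # A quick illustrative check (same x as above): the trace equals
+    >>> # the sum of the entries on the main diagonal.
+    >>> torch.trace(x) == torch.diagonal(x).sum()
+    tensor(True)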
+""", +) + +add_docstr( + torch.transpose, + r""" +transpose(input, dim0, dim1) -> Tensor + +Returns a tensor that is a transposed version of :attr:`input`. +The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + +If :attr:`input` is a strided tensor then the resulting :attr:`out` +tensor shares its underlying storage with the :attr:`input` tensor, so +changing the content of one would change the content of the other. + +If :attr:`input` is a :ref:`sparse tensor ` then the +resulting :attr:`out` tensor *does not* share the underlying storage +with the :attr:`input` tensor. + +If :attr:`input` is a :ref:`sparse tensor ` with compressed +layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments +:attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must +both be sparse dimensions. The batch dimensions of a sparse tensor are the +dimensions preceding the sparse dimensions. + +.. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + +Args: + {input} + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + +Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + +See also :func:`torch.t`. +""".format( + **common_args + ), +) + +add_docstr( + torch.triangular_solve, + r""" +triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + +Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` +and multiple right-hand sides :math:`b`. + +In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular +(or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. + +`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are +batches of 2D matrices. If the inputs are batches, then returns +batched outputs `X` + +If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and +:attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, +the result may contain `NaN` s. + +Supports input of float, double, cfloat and cdouble data types. + +.. warning:: + + :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular` + and will be removed in a future PyTorch release. + :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a + copy of one of the inputs. + + ``X = torch.triangular_solve(B, A).solution`` should be replaced with + + .. code:: python + + X = torch.linalg.solve_triangular(A, B) + +Args: + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``. + transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``, + and `op(A) = A` if it is ``False``. Default: ``False``. 
+ unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. + +Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + +Returns: + A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` + is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` + (or whatever variant of the system of equations, depending on the keyword arguments.) + +Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + torch.return_types.triangular_solve( + solution=tensor([[ 1.7841, 2.9046, -2.5405], + [ 1.9320, 0.9270, -1.2826]]), + cloned_coefficient=tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) +""", +) + +add_docstr( + torch.tril, + r""" +tril(input, diagonal=0, *, out=None) -> Tensor + +Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + +The lower triangular part of the matrix is defined as the elements on and +below the diagonal. + +The argument :attr:`diagonal` controls which diagonal to consider. If +:attr:`diagonal` = 0, all elements on and below the main diagonal are +retained. A positive value includes just as many diagonals above the main +diagonal, and similarly a negative value excludes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where +:math:`d_{1}, d_{2}` are the dimensions of the matrix. +""" + + r""" +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) +""".format( + **common_args + ), +) + +# docstr is split in two parts to avoid format mis-captureing :math: braces '{}' +# as common args. 
+add_docstr( + torch.tril_indices, + r""" +tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + +Returns the indices of the lower triangular part of a :attr:`row`-by- +:attr:`col` matrix in a 2-by-N Tensor, where the first row contains row +coordinates of all indices and the second row contains column coordinates. +Indices are ordered based on rows and then columns. + +The lower triangular part of the matrix is defined as the elements on and +below the diagonal. + +The argument :attr:`offset` controls which diagonal to consider. If +:attr:`offset` = 0, all elements on and below the main diagonal are +retained. A positive value includes just as many diagonals above the main +diagonal, and similarly a negative value excludes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` +where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + +.. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. +""" + + r""" +Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + {device} + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + +Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.triu, + r""" +triu(input, diagonal=0, *, out=None) -> Tensor + +Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + +The upper triangular part of the matrix is defined as the elements on and +above the diagonal. + +The argument :attr:`diagonal` controls which diagonal to consider. If +:attr:`diagonal` = 0, all elements on and above the main diagonal are +retained. A positive value excludes just as many diagonals above the main +diagonal, and similarly a negative value includes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where +:math:`d_{1}, d_{2}` are the dimensions of the matrix. 
+""" + + r""" +Args: + {input} + diagonal (int, optional): the diagonal to consider + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) +""".format( + **common_args + ), +) + +# docstr is split in two parts to avoid format mis-capturing :math: braces '{}' +# as common args. +add_docstr( + torch.triu_indices, + r""" +triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + +Returns the indices of the upper triangular part of a :attr:`row` by +:attr:`col` matrix in a 2-by-N Tensor, where the first row contains row +coordinates of all indices and the second row contains column coordinates. +Indices are ordered based on rows and then columns. + +The upper triangular part of the matrix is defined as the elements on and +above the diagonal. + +The argument :attr:`offset` controls which diagonal to consider. If +:attr:`offset` = 0, all elements on and above the main diagonal are +retained. A positive value excludes just as many diagonals above the main +diagonal, and similarly a negative value includes just as many diagonals below +the main diagonal. The main diagonal are the set of indices +:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` +where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + +.. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. +""" + + r""" +Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + +Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, ``torch.long``. + {device} + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. 
+ +Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.true_divide, + r""" +true_divide(dividend, divisor, *, out) -> Tensor + +Alias for :func:`torch.div` with ``rounding_mode=None``. +""", +) + +add_docstr( + torch.trunc, + r""" +trunc(input, *, out=None) -> Tensor + +Returns a new tensor with the truncated integer values of +the elements of :attr:`input`. + +For integer inputs, follows the array-api convention of returning a +copy of the input tensor. + +Args: + {input} + +Keyword args: + {out} + +Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) +""".format( + **common_args + ), +) + +add_docstr( + torch.fake_quantize_per_tensor_affine, + r""" +fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + +Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, +:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + +.. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + +Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + +Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + +Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) +""", +) + +add_docstr( + torch.fake_quantize_per_channel_affine, + r""" +fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + +Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, +:attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + +.. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + +Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + +Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + +Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) +""", +) + +add_docstr( + torch.fix, + r""" +fix(input, *, out=None) -> Tensor + +Alias for :func:`torch.trunc` +""", +) + +add_docstr( + torch.unsqueeze, + r""" +unsqueeze(input, dim) -> Tensor + +Returns a new tensor with a dimension of size one inserted at the +specified position. + +The returned tensor shares the same underlying data with this tensor. + +A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` +can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` +applied at :attr:`dim` = ``dim + input.dim() + 1``. + +Args: + {input} + dim (int): the index at which to insert the singleton dimension + +Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) +""".format( + **common_args + ), +) + +add_docstr( + torch.var, + r""" +var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` +can be a single dimension, list of dimensions, or ``None`` to reduce over all +dimensions. + +The variance (:math:`\sigma^2`) is calculated as + +.. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {keepdim} + {out} + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + +.. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.var_mean, + r""" +var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the variance and mean over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The variance (:math:`\sigma^2`) is calculated as + +.. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} + +Args: + {input} + {opt_dim} + +Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + {keepdim} + {out} + +Returns: + A tuple (var, mean) containing the variance and mean. + +Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + +""".format( + **multi_dim_common + ), +) + +add_docstr( + torch.zeros, + r""" +zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Returns a tensor filled with the scalar value `0`, with the shape defined +by the variable argument :attr:`size`. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.zeros_like, + r""" +zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns a tensor filled with the scalar value `0`, with the same size as +:attr:`input`. ``torch.zeros_like(input)`` is equivalent to +``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.empty, + """ +empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, \ +memory_format=torch.contiguous_format) -> Tensor + +Returns a tensor filled with uninitialized data. 
The shape of the tensor is +defined by the variable argument :attr:`size`. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + {memory_format} + +Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.empty_like, + r""" +empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + +Returns an uninitialized tensor with the same size as :attr:`input`. +``torch.empty_like(input)`` is equivalent to +``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + {input} + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} + +Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.empty_strided, + r""" +empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + +.. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. 
+ +Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.empty_permuted, + r""" +empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + +Creates an uninitialized, non-overlapping and dense tensor with the +specified :attr:`size`, with :attr:`physical_layout` specifying how the +dimensions are physically laid out in memory (each logical dimension is listed +from outermost to innermost). :attr:`physical_layout` is a generalization +of NCHW/NHWC notation: if each dimension is assigned a number according to +what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` +while NHWC is ``(0, 2, 3, 1)``. Equivalently, the strides of the output +tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` +(notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + +Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense +tensor with no overlaps. If possible, prefer using this function over +:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + +.. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + +Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {pin_memory} + +Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.full, + r""" +full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The +tensor's dtype is inferred from :attr:`fill_value`. + +Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. 
+ +Keyword args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) +""".format( + **factory_common_args + ), +) + +add_docstr( + torch.full_like, + """ +full_like(input, fill_value, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ +memory_format=torch.preserve_format) -> Tensor + +Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. +``torch.full_like(input, fill_value)`` is equivalent to +``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + +Args: + {input} + fill_value: the number to fill the output tensor with. + +Keyword args: + {dtype} + {layout} + {device} + {requires_grad} + {memory_format} +""".format( + **factory_like_common_args + ), +) + +add_docstr( + torch.det, + r""" +det(input) -> Tensor + +Alias for :func:`torch.linalg.det` +""", +) + +add_docstr( + torch.where, + r""" +where(condition, input, other, *, out=None) -> Tensor + +Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + +The operation is defined as: + +.. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} +""" + + r""" +.. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + +Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + +Keyword args: + {out} + +Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + +Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + +.. function:: where(condition) -> tuple of LongTensor + :noindex: + +``torch.where(condition)`` is identical to +``torch.nonzero(condition, as_tuple=True)``. + +.. note:: + See also :func:`torch.nonzero`. +""".format( + **common_args + ), +) + +add_docstr( + torch.logdet, + r""" +logdet(input) -> Tensor + +Calculates log determinant of a square matrix or batches of square matrices. + +It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has +a negative determinant. + +.. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. 
complex) square matrices. + +Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + +Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) +""", +) + +add_docstr( + torch.slogdet, + r""" +slogdet(input) -> (Tensor, Tensor) + +Alias for :func:`torch.linalg.slogdet` +""", +) + +add_docstr( + torch.pinverse, + r""" +pinverse(input, rcond=1e-15) -> Tensor + +Alias for :func:`torch.linalg.pinv` +""", +) + +add_docstr( + torch.hann_window, + """ +hann_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Hann window function. + +.. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.hann_window(L, periodic=True)`` equal to +``torch.hann_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.hamming_window, + """ +hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Hamming window function. + +.. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.hamming_window(L, periodic=True)`` equal to +``torch.hamming_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + +.. note:: + This is a generalized version of :meth:`torch.hann_window`. 
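+
+For instance, the ``periodic``/symmetric relationship described above can be
+checked numerically (a small illustrative check, using default settings)::
+
+    >>> torch.allclose(torch.hamming_window(10, periodic=True),
+    ...                torch.hamming_window(11, periodic=False)[:-1])
+    True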
+""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.bartlett_window, + """ +bartlett_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Bartlett window function. + +.. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.bartlett_window(L, periodic=True)`` equal to +``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. +""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.blackman_window, + """ +blackman_window(window_length, periodic=True, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Blackman window function. + +.. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + +where :math:`N` is the full window size. + +The input :attr:`window_length` is a positive integer controlling the +returned window size. :attr:`periodic` flag determines whether the returned +window trims off the last duplicate value from the symmetric window and is +ready to be used as a periodic window with functions like +:meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in +above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have +``torch.blackman_window(L, periodic=True)`` equal to +``torch.blackman_window(L + 1, periodic=False)[:-1])``. + +.. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. 
+""" + + r""" +Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + +Keyword args: + {dtype} Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +Returns: + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.kaiser_window, + """ +kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + + r""" +Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + +Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and +``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, +where ``L`` is the :attr:`window_length`. This function computes: + +.. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + +Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling +``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. +The :attr:`periodic` argument is intended as a helpful shorthand +to produce a periodic window as input to functions like :func:`torch.stft`. + +.. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + +""" + + r""" +Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + +Keyword args: + {dtype} + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + {device} + {requires_grad} + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.vander, + """ +vander(x, N=None, increasing=False) -> Tensor +""" + + r""" +Generates a Vandermonde matrix. + +The columns of the output matrix are elementwise powers of the input vector :math:`x^{{(N-1)}}, x^{{(N-2)}}, ..., x^0`. +If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{{(N-1)}}`. Such a +matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + +Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + +Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{{(N-1)}}`, + the second :math:`x^{{(N-2)}}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{{(N-1)}}`. 
+ +Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + +""".format( + **factory_common_args + ), +) + + +add_docstr( + torch.unbind, + r""" +unbind(input, dim=0) -> seq + +Removes a tensor dimension. + +Returns a tuple of all slices along a given dimension, already without it. + +Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + +Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) +""", +) + + +add_docstr( + torch.combinations, + r""" +combinations(input, r=2, with_replacement=False) -> seq + +Compute combinations of length :math:`r` of the given tensor. The behavior is similar to +python's `itertools.combinations` when `with_replacement` is set to `False`, and +`itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + +Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + +Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. + +Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + +""", +) + +add_docstr( + torch.trapezoid, + r""" +trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + +Computes the `trapezoidal rule `_ along +:attr:`dim`. By default the spacing between elements is assumed to be 1, but +:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be +used to specify arbitrary spacing along :attr:`dim`. + + +Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, +the default computation is + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + +When :attr:`dx` is specified the computation becomes + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + +effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, +assuming :attr:`x` is also a one-dimensional tensor with +elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + +.. math:: + \begin{aligned} + \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + +When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. +The broadcasting behavior of this function is as follows when their sizes are different. 
For both :attr:`x` +and :attr:`y`, the function computes the difference between consecutive elements along +dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have +the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. +After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. +See the examples below for details. + +.. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + +Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + +Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + +Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) +""", +) + +add_docstr( + torch.trapz, + r""" +trapz(y, x, *, dim=-1) -> Tensor + +Alias for :func:`torch.trapezoid`. +""", +) + +add_docstr( + torch.cumulative_trapezoid, + r""" +cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + +Cumulatively computes the `trapezoidal rule `_ +along :attr:`dim`. By default the spacing between elements is assumed to be 1, but +:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be +used to specify arbitrary spacing along :attr:`dim`. + +For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` +and this function is that, :func:`torch.trapezoid` returns a value for each integration, +where as this function returns a cumulative value for every spacing within the integration. 
This +is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + +Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + +Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + +Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. + >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) +""", +) + +add_docstr( + torch.repeat_interleave, + r""" +repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + +Repeat elements of a tensor. + +.. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + +Args: + {input} + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + +Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + +Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. 
+ +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + +If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be +`tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, +`1` appears `n2` times, `2` appears `n3` times, etc. + +.. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + +Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + +Args: + repeats (Tensor): The number of repetitions for each element. + +Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + +Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + +""".format( + **common_args + ), +) + +add_docstr( + torch.tile, + r""" +tile(input, dims) -> Tensor + +Constructs a tensor by repeating the elements of :attr:`input`. +The :attr:`dims` argument specifies the number of repetitions +in each dimension. + +If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then +ones are prepended to :attr:`dims` until all dimensions are specified. +For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` +is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + +Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` +specifies, then :attr:`input` is treated as if it were unsqueezed at +dimension zero until it has as many dimensions as :attr:`dims` specifies. +For example, if :attr:`input` has shape (4, 2) and :attr:`dims` +is (3, 3, 2, 2), then :attr:`input` is treated as if it had the +shape (1, 1, 4, 2). + +.. note:: + + This function is similar to NumPy's tile function. + +Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) +""", +) + +add_docstr( + torch.quantize_per_tensor, + r""" +quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + +Converts a float tensor to a quantized tensor with given scale and zero point. + +Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + +Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
+ +Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) +""", +) + +add_docstr( + torch.quantize_per_tensor_dynamic, + r""" +quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + +Converts a float tensor to a quantized tensor with scale and zero_point calculated +dynamically based on the input. + +Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8`` + reduce_range (bool): a flag to indicate whether to reduce the range of quantized + data by 1 bit, it's required to avoid instruction overflow for some hardwares + +Returns: + Tensor: A newly (dynamically) quantized tensor + +Example:: + + >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False) + >>> print(t) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941, + zero_point=85) + >>> t.int_repr() + tensor([ 0, 85, 170, 255], dtype=torch.uint8) +""", +) + +add_docstr( + torch.quantize_per_channel, + r""" +quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + +Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + +Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + +Returns: + Tensor: A newly quantized tensor + +Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) +""", +) + + +add_docstr( + torch.quantized_batch_norm, + r""" +quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor + +Applies batch normalization on a 4D (NCHW) quantized tensor. + +.. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + +Arguments: + input (Tensor): quantized tensor + weight (Tensor): float tensor that corresponds to the gamma, size C + bias (Tensor): float tensor that corresponds to the beta, size C + mean (Tensor): float mean value in batch normalization, size C + var (Tensor): float tensor for variance, size C + eps (float): a value added to the denominator for numerical stability. + output_scale (float): output quantized tensor scale + output_zero_point (int): output quantized tensor zero_point + +Returns: + Tensor: A quantized tensor with batch normalization applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2) + tensor([[[[-0.2000, -0.2000], + [ 1.6000, -0.2000]], + + [[-0.4000, -0.4000], + [-0.4000, 0.6000]]], + + + [[[-0.2000, -0.2000], + [-0.2000, -0.2000]], + + [[ 0.6000, -0.4000], + [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2) +""", +) + + +add_docstr( + torch.quantized_max_pool1d, + r""" +quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + +Applies a 1D max pooling over an input quantized tensor composed of several input planes. + +Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + +Returns: + Tensor: A quantized tensor with max_pool1d applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) +""", +) + + +add_docstr( + torch.quantized_max_pool2d, + r""" +quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + +Applies a 2D max pooling over an input quantized tensor composed of several input planes. 
+ +Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + +Returns: + Tensor: A quantized tensor with max_pool2d applied. + +Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) +""", +) + + +add_docstr( + torch.Generator, + r""" +Generator(device='cpu') -> Generator + +Creates and returns a generator object that manages the state of the algorithm which +produces pseudo random numbers. Used as a keyword argument in many :ref:`inplace-random-sampling` +functions. + +Arguments: + device (:class:`torch.device`, optional): the desired device for the generator. + +Returns: + Generator: An torch.Generator object. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> g_cpu = torch.Generator() + >>> g_cuda = torch.Generator(device='cuda') +""", +) + + +add_docstr( + torch.Generator.set_state, + r""" +Generator.set_state(new_state) -> void + +Sets the Generator state. + +Arguments: + new_state (torch.ByteTensor): The desired state. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu_other = torch.Generator() + >>> g_cpu.set_state(g_cpu_other.get_state()) +""", +) + + +add_docstr( + torch.Generator.get_state, + r""" +Generator.get_state() -> Tensor + +Returns the Generator state as a ``torch.ByteTensor``. + +Returns: + Tensor: A ``torch.ByteTensor`` which contains all the necessary bits + to restore a Generator to a specific point in time. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.get_state() +""", +) + + +add_docstr( + torch.Generator.manual_seed, + r""" +Generator.manual_seed(seed) -> Generator + +Sets the seed for generating random numbers. Returns a `torch.Generator` object. Any 32-bit integer is a valid seed. + +Arguments: + seed (int): The desired seed. Value must be within the inclusive range + `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError + is raised. Negative inputs are remapped to positive values with the formula + `0xffff_ffff_ffff_ffff + seed`. + +Returns: + Generator: An torch.Generator object. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.manual_seed(2147483647) +""", +) + + +add_docstr( + torch.Generator.initial_seed, + r""" +Generator.initial_seed() -> int + +Returns the initial seed for generating random numbers. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.initial_seed() + 2147483647 +""", +) + + +add_docstr( + torch.Generator.seed, + r""" +Generator.seed() -> int + +Gets a non-deterministic random number from std::random_device or the current +time and uses it to seed a Generator. + +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.seed() + 1516516984916 +""", +) + + +add_docstr( + torch.Generator.device, + r""" +Generator.device -> device + +Gets the current device of the generator. 
+ +Example:: + + >>> g_cpu = torch.Generator() + >>> g_cpu.device + device(type='cpu') +""", +) + +add_docstr( + torch._assert_async, + r""" +_assert_async(tensor) -> void + +Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, +this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for +CUDA tensors, we DO NOT synchronize and you may only find out the assertion +failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for +testing invariants in CUDA tensors without giving up performance. This function +is NOT intended to be used for regular error checking, as it will trash your CUDA +context if the assert fails (forcing you to restart your PyTorch process.) + +Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. +""", +) + +add_docstr( + torch.searchsorted, + r""" +searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + +Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the +corresponding values in :attr:`values` were inserted before the indices, when sorted, the order +of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. +Return a new tensor with the same size as :attr:`values`. More formally, +the returned index satisfies the following rules: + +.. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + +Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + +Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. 
+ out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + +Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) +""", +) + +add_docstr( + torch.bucketize, + r""" +bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + +Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the +boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size +as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that +this behavior is opposite the behavior of +`numpy.digitize `_. +More formally, the returned index satisfies the following rules: + +.. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + +Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + +Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of :attr:`boundaries` (one pass the last index). + In other words, if False, gets the lower bound index for each value in :attr:`input` + from :attr:`boundaries`. If True, gets the upper bound index instead. + Default value is False. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + +Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) +""", +) + +add_docstr( + torch.view_as_real_copy, + r""" +Performs the same operation as :func:`torch.view_as_real`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.view_as_complex_copy, + r""" +Performs the same operation as :func:`torch.view_as_complex`, but all output tensors +are freshly created instead of aliasing the input. 
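+
+Example (a minimal sketch; the values are only illustrative, the point being
+that the result does not share memory with the input)::
+
+    >>> x = torch.zeros(2, 2)
+    >>> y = torch.view_as_complex_copy(x)
+    >>> y += 1
+    >>> x  # left unchanged, unlike with torch.view_as_complex
+    tensor([[0., 0.],
+            [0., 0.]])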
+""", +) + +add_docstr( + torch.as_strided_copy, + r""" +Performs the same operation as :func:`torch.as_strided`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.diagonal_copy, + r""" +Performs the same operation as :func:`torch.diagonal`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.expand_copy, + r""" +Performs the same operation as :func:`torch.expand`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.permute_copy, + r""" +Performs the same operation as :func:`torch.permute`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.select_copy, + r""" +Performs the same operation as :func:`torch.select`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.detach_copy, + r""" +Performs the same operation as :func:`torch.detach`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.slice_copy, + r""" +Performs the same operation as :func:`torch.slice`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.split_copy, + r""" +Performs the same operation as :func:`torch.split`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.split_with_sizes_copy, + r""" +Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.squeeze_copy, + r""" +Performs the same operation as :func:`torch.squeeze`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.t_copy, + r""" +Performs the same operation as :func:`torch.t`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.transpose_copy, + r""" +Performs the same operation as :func:`torch.transpose`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.unsqueeze_copy, + r""" +Performs the same operation as :func:`torch.unsqueeze`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.indices_copy, + r""" +Performs the same operation as :func:`torch.indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.values_copy, + r""" +Performs the same operation as :func:`torch.values`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.crow_indices_copy, + r""" +Performs the same operation as :func:`torch.crow_indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.col_indices_copy, + r""" +Performs the same operation as :func:`torch.col_indices`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.unbind_copy, + r""" +Performs the same operation as :func:`torch.unbind`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.view_copy, + r""" +Performs the same operation as :func:`torch.view`, but all output tensors +are freshly created instead of aliasing the input. 
+""", +) + +add_docstr( + torch.unfold_copy, + r""" +Performs the same operation as :func:`torch.unfold`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +add_docstr( + torch.alias_copy, + r""" +Performs the same operation as :func:`torch.alias`, but all output tensors +are freshly created instead of aliasing the input. +""", +) + +for unary_base_func_name in ( + "exp", + "sqrt", + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", + "erfc", + "expm1", + "floor", + "log", + "log10", + "log1p", + "log2", + "neg", + "tan", + "tanh", + "sin", + "sinh", + "round", + "lgamma", + "frac", + "reciprocal", + "sigmoid", + "trunc", + "zero", +): + unary_foreach_func_name = f"_foreach_{unary_base_func_name}" + if hasattr(torch, unary_foreach_func_name): + add_docstr( + getattr(torch, unary_foreach_func_name), + rf""" +{unary_foreach_func_name}(self: List[Tensor]) -> List[Tensor] + +Apply :func:`torch.{unary_base_func_name}` to each Tensor of the input list. + """, + ) + unary_inplace_foreach_func_name = f"{unary_foreach_func_name}_" + if hasattr(torch, unary_inplace_foreach_func_name): + add_docstr( + getattr(torch, unary_inplace_foreach_func_name), + rf""" +{unary_inplace_foreach_func_name}(self: List[Tensor]) -> None + +Apply :func:`torch.{unary_base_func_name}` to each Tensor of the input list. + """, + ) diff --git a/MLPY/Lib/site-packages/torch/_utils.py b/MLPY/Lib/site-packages/torch/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8f51000ed63dd93dae1fd22d1bd19d507e127b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_utils.py @@ -0,0 +1,937 @@ +import copyreg +import functools +import sys +import traceback +import warnings +from collections import defaultdict +from typing import Any, DefaultDict, List, Optional + +import torch + + +def _type(self, dtype=None, non_blocking=False, **kwargs): + """Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (type or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs) + if dtype is None: + return self.__module__ + "." + self.__class__.__name__ + + if isinstance(dtype, str): + dtype = _import_dotted_name(dtype) + if dtype == type(self): + return self + if self.is_sparse: + if not dtype.is_sparse: + raise RuntimeError("Cannot cast sparse tensor to dense tensor") + new_module_name = dtype.__module__.replace(".sparse", "") + new_values_type_name = new_module_name + "." 
+ dtype.__name__ + new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking) + new_indices_type_name = new_module_name + ".LongTensor" + new_indices = torch.Tensor._indices(self).type( + new_indices_type_name, non_blocking + ) + return dtype(new_indices, new_values, self.size()) + if dtype.is_sparse: + raise RuntimeError("Cannot cast dense tensor to sparse tensor") + return dtype(self.size()).copy_(self, non_blocking) + + +def _hpu(self, device=None, non_blocking=False, **kwargs): + """Returns a copy of this object in HPU memory. + + If this object is already in HPU memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination HPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. + """ + non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs) + hpu = getattr(torch, "hpu", None) + assert hpu is not None, "HPU device module is not loaded" + if self.is_hpu: + if device is None: + device = hpu.current_device() + if self.get_device() == device: + return self + else: + if device is None: + device = -1 + with hpu.device(device): + assert not self.is_sparse, "sparse storage is not supported for HPU tensors" + untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu")) + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + +def _cuda(self, device=None, non_blocking=False, **kwargs): + """Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination GPU id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. + """ + non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs) + if self.is_cuda: + if device is None: + device = torch.cuda.current_device() + if self.get_device() == device: + return self + else: + if device is None: + device = -1 + with torch.cuda.device(device): + if self.is_sparse: + new_type = getattr(torch.cuda.sparse, self.__class__.__name__) + indices = torch.Tensor._indices(self).cuda(device, non_blocking) + values = torch.Tensor._values(self).cuda(device, non_blocking) + return new_type(indices, values, self.size()) + else: + untyped_storage = torch.UntypedStorage( + self.size(), device=torch.device("cuda") + ) + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + +def _get_async_or_non_blocking(function_name, non_blocking, kwargs): + """Return the non-blocking flag given the function name and kwargs. + + Args: + function_name (str): the name of the function being used. + non_blocking (bool): the default value. + **kwargs (dict): the kwargs passed to the function. 
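+
+    Example (an illustrative sketch; any key other than ``async`` raises a
+    ``TypeError``)::
+
+        >>> _get_async_or_non_blocking("cuda", False, {})
+        False
+        >>> _get_async_or_non_blocking("cuda", False, {"async": True})  # warns, then remaps
+        True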
+ """ + if not kwargs: + return non_blocking + if len(kwargs) != 1 or "async" not in kwargs: + message = "{}() got an unexpected keyword argument '{}'" + argument = list(kwargs.keys()).pop() + raise TypeError(message.format(function_name, argument)) + warnings.warn("'async' is deprecated; use 'non_blocking'") + return kwargs["async"] + + +# Note [Don't serialize hooks] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Since time immemorial, we have serialized the backward hooks associated with +# variables. This kind of half-worked--Python can pickle global functions +# (but not closures!)--but there were problems. +# +# - It's fragile. If you serialize a backward hook into a saved +# model, and then you rename the function associated with the hook, +# now your saved model is broken and you can't load it anymore. +# +# - It's not actually used. The standard recommendation is to +# serialize the *state_dict* of a model, not the model itself +# (since this is more stable to code changes affecting the model +# serialization), and the state dict saves "data" only, thus +# stripping the backward hooks. In some cases, hooks are +# essential to the well-functioning of a model (e.g., DDP), +# but DDP already manages readding the hooks! +# +# - We didn't serialize them in many cases. Prior to #10220, we +# were dropping backward hooks in ForkingPickler. We "fixed" this +# to be convenient with other serialization sites, but lack of +# serializing backward hooks wasn't actually the root cause of +# the bug. +# +# With these cases in mind, we have decided that a better strategy +# is to just NOT serialize hooks at all. +# +# Since this is a BC-breaking change, we should warn when we previously +# serialized a hook, but no longer do so. This will be done by adding a special +# sentinel property to hooks will be used to suppress this warning. If a hook +# has the property _torch_serialize_ignore, we will not emit a warning if we +# attempt to serialize a Tensor with this hook attached to it. +# +# By the way, when _backward_hooks is skipped, we must give an EMPTY +# OrderedDict(), if you pass a None you'll run afoul #12219. + + +# TODO: Once we decide to break serialization FC, `storage` no longer needs to +# be a TypedStorage +def _rebuild_tensor(storage, storage_offset, size, stride): + # first construct a tensor with the correct dtype/device + t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device) + return t.set_(storage._untyped_storage, storage_offset, size, stride) + + +def get_tensor_metadata(tensor): + # Tensor's Metadata for serializing. + # Currently, this only returns a dict[string, bool] specifing whether + # `conj` or `neg` bit is set. + assert isinstance(tensor, torch.Tensor) + return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined] + + +def set_tensor_metadata(tensor, metadata): + # See `get_tensor_metadata` above + assert isinstance(metadata, dict) + assert isinstance(tensor, torch.Tensor) + torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined] + + +def _rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None +): + tensor = _rebuild_tensor(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + if metadata: + set_tensor_metadata(tensor, metadata) + + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. 
See Note [Don't serialize hooks] + tensor._backward_hooks = backward_hooks + return tensor + + +def _rebuild_tensor_v3( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + dtype, + metadata=None, +): + t = torch.empty( + (0,), + dtype=dtype, + device=storage._untyped_storage.device, + requires_grad=requires_grad, + ) + t.set_(storage._untyped_storage, storage_offset, size, stride) + if metadata: + set_tensor_metadata(t, metadata) + t._backward_hooks = backward_hooks + return t + + +_sparse_tensors_to_validate: List["torch.Tensor"] = [] + + +# In _legacy_load() in serialization.py we unpickle storages after the sparse +# tensors have been already unpickled. Those storages contain data necessary for +# validating sparse tensors: indices and values. That's why sparse tensors are +# first unpickled without any validation, and then this function is called just +# before _legacy_load() returns, so that all the sparse tensors can be validated +# in bulk. +# +# The same procedure must be followed by _load() in serialization.py because due +# to Pickler semantics, we have to use the same (non-validating) function for +# unpickling sparse tensors, regardless of the caller. +def _validate_loaded_sparse_tensors(): + try: + for t in _sparse_tensors_to_validate: + if t.layout is torch.sparse_coo: + torch._validate_sparse_coo_tensor_args( + t._indices(), t._values(), t.size(), t.is_coalesced() + ) + elif t.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + # TODO: Validation currently involves an expensive traversal + # on CPU, which may include a device transfer. + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + t.crow_indices(), + t.col_indices(), + ) + else: + compressed_indices, plain_indices = ( + t.ccol_indices(), + t.row_indices(), + ) + torch._validate_sparse_compressed_tensor_args( + compressed_indices, plain_indices, t.values(), t.size(), t.layout + ) + else: + raise NotImplementedError( + f"_validate_loaded_sparse_tensors for layout `{t.layout}`" + ) + + finally: + _sparse_tensors_to_validate.clear() + + +def _rebuild_sparse_tensor(layout, data): + """ + Rebuilds a sparse tensor from its sparse storage representation. + + Args: + layout (str): The sparse storage layout of the tensor. + data (tuple): The tensor's sparse storage representation. 
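+
+    A minimal sketch for the COO layout, where ``data`` is
+    ``(indices, values, size, is_coalesced)`` (the older 3-tuple form without
+    ``is_coalesced`` is also accepted)::
+
+        >>> indices = torch.tensor([[0, 1], [1, 0]])
+        >>> values = torch.tensor([1.0, 2.0])
+        >>> t = _rebuild_sparse_tensor(torch.sparse_coo, (indices, values, (2, 2), None))
+        >>> t.to_dense()
+        tensor([[0., 1.],
+                [2., 0.]])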
+ """ + if layout == torch.sparse_coo: + if len(data) == 3: + # For BC: + indices, values, size = data + is_coalesced = None + else: + indices, values, size, is_coalesced = data + result = torch.sparse_coo_tensor( + indices, values, size, check_invariants=False, is_coalesced=is_coalesced + ) + _sparse_tensors_to_validate.append(result) + return result + + elif layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + compressed_indices, plain_indices, values, size = data + result = torch.sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + size, + layout=layout, + check_invariants=False, + ) + _sparse_tensors_to_validate.append(result) + return result + + raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}") + + +def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets): + return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets) + + +def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): + tensor = torch.from_numpy(data).to(dtype=dtype, device=device) + tensor.requires_grad = requires_grad + return tensor + + +# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch +_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy + + +def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): + return torch.empty_strided( + size, stride, dtype=dtype, device="meta", requires_grad=requires_grad + ) + + +def _rebuild_wrapper_subclass( + cls, dtype, size, stride, storage_offset, layout, device, requires_grad +): + return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + size, + strides=stride, + storage_offset=storage_offset, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +# TODO: Once we decide to break serialization FC, `storage` no longer needs to +# be a TypedStorage +def _rebuild_qtensor( + storage, + storage_offset, + size, + stride, + quantizer_params, + requires_grad, + backward_hooks, +): + qscheme = quantizer_params[0] + if qscheme == torch.per_tensor_affine: + _, scale, zero_point = quantizer_params + tensor = torch._empty_affine_quantized( + size, + scale=scale, + zero_point=zero_point, + dtype=storage.dtype, + device=storage.device, + ) + elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams): + _, scales, zero_points, axis = quantizer_params + if type(scales) is list and type(zero_points) is list: + if qscheme == torch.per_channel_affine: + scales = torch.tensor(scales, dtype=torch.double, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.long, device=storage.device + ) + else: + scales = torch.tensor(scales, dtype=torch.float, device=storage.device) + zero_points = torch.tensor( + zero_points, dtype=torch.float, device=storage.device + ) + tensor = torch._empty_per_channel_affine_quantized( + size, + scales=scales, + zero_points=zero_points, + axis=axis, + dtype=storage.dtype, + device=storage.device, + ) + else: + raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}") + tensor.set_(storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. 
See Note [Don't serialize hooks] + tensor._backward_hooks = backward_hooks + return tensor + + +def _rebuild_parameter(data, requires_grad, backward_hooks): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + return param + + +def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + # Restore state on Parameter like python attr. + param = _set_obj_state(param, state) + return param + + +def _get_obj_state(obj): + # Get the state of the python subclass + # This loosely mimicks the function on the object class but since Tensor do not inherit + # from it, we cannot call that function directly + # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 + # Note that starting with Python 3.11, this `__getstate__` is always defined and thus + # the else branch will never be taken. + getstate_fn = getattr(obj, "__getstate__", None) + if getstate_fn: + state = getstate_fn() + else: + slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] + if slots_to_save: + state = ( + obj.__dict__, + { + name: getattr(obj, name) + for name in slots_to_save + if hasattr(obj, name) + }, + ) + else: + state = obj.__dict__ + + return state + + +def _set_obj_state(obj, state): + if isinstance(state, tuple): + if not len(state) == 2: + raise RuntimeError(f"Invalid serialized state: {state}") + dict_state = state[0] + slots_state = state[1] + else: + dict_state = state + slots_state = None + + # Starting with Python 3.11, the __dict__ attribute is lazily created + # and is serialized as None when not needed. + if dict_state: + for k, v in dict_state.items(): + setattr(obj, k, v) + + if slots_state: + for k, v in slots_state.items(): + setattr(obj, k, v) + return obj + + +def _import_dotted_name(name): + components = name.split(".") + obj = __import__(components[0]) + for component in components[1:]: + obj = getattr(obj, component) + return obj + + +def _flatten_dense_tensors(tensors): + """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of + same dense type. + + Since inputs are dense, the resulting tensor will be a concatenated 1D + buffer. Element-wise operation on this buffer will be equivalent to + operating individually. + + Args: + tensors (Iterable[Tensor]): dense tensors to flatten. + + Returns: + A contiguous 1D buffer containing input tensors. + """ + return torch._C._nn.flatten_dense_tensors(tensors) + + +def _flatten_sparse_tensors(tensors): + """Flatten sparse tensors into two contiguous 1D buffers, one of indices and + one of values. Assume tensors are of same sparse type. + + Args: + tensors (Iterable[Tensor]): sparse tensors to flatten. + + Returns: + A tuple of two contiguous 1D buffers, one containing input tensors' + indices and the other containing the values. 
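+
+    A small illustrative sketch with a single coalesced COO tensor::
+
+        >>> t = torch.tensor([[0., 1.], [2., 0.]]).to_sparse()
+        >>> flat_indices, flat_values = _flatten_sparse_tensors([t])
+        >>> flat_indices
+        tensor([0, 1, 1, 0])
+        >>> flat_values
+        tensor([1., 2.])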
+ """ + flat_indices = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._indices(t) for t in tensors] + ) + flat_values = torch._C._nn.flatten_dense_tensors( + [torch.Tensor._values(t) for t in tensors] + ) + return flat_indices, flat_values + + +def _unflatten_dense_tensors(flat, tensors): + """View a flat buffer using the sizes of tensors. Assume that tensors are of + same dense type, and that flat is given by _flatten_dense_tensors. + + Args: + flat (Tensor): flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): dense tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened dense tensors with sizes same as tensors and values from + flat. + """ + return torch._C._nn.unflatten_dense_tensors(flat, tensors) + + +def _unflatten_sparse_tensors(flat, tensors): + """View flat buffer (containing indices and values) using the sizes of + tensors. Assume that tensors are of same sparse type, and that flat is given + by _flatten_sparse_tensors. + + Args: + flat (tuple(Tensor, Tensor)): flattened indices and values of sparse + tensors to unflatten. + tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to + unflatten flat. + + Returns: + Unflattened sparse tensors with sizes same as tensors and values from + flat. + """ + flat_indices, flat_values = flat + indices = torch._C._nn.unflatten_dense_tensors( + flat_indices, [torch.Tensor._indices(t) for t in tensors] + ) + values = torch._C._nn.unflatten_dense_tensors( + flat_values, [torch.Tensor._values(t) for t in tensors] + ) + outputs = [] + for t, i, v in zip(tensors, indices, values): + outputs.append(t.new(i, v, t.size())) + return tuple(outputs) + + +def _reorder_tensors_as(tensors, ordered_tensors): + """Assume that tensors are of same order as ordered_tensors within their + types, e.g., from _take_tensors. Reorder them to be of same order as + ordered_tensors. + + Args: + tensors (Iterable[Tensor]): tensors to be reordered. They should be of + the same order as ordered_tensors within their own types. + ordered_tensors (Iterable[Tensor]): tensors whose order will be the + reference. + + Returns: + Ordered tuple of tensors with contents from tensors and order of + ordered_tensors. + """ + type_dict = defaultdict(list) + for tensor in tensors: + type_dict[tensor.type()].append(tensor) + type_dict_ = {t: iter(coll) for t, coll in type_dict.items()} + return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors) + + +def _take_tensors(tensors, size_limit): + """Group tensors into chunks. This generator yields a chunk at each time, + each containing tensors of same type up to certain byte limit in total size. + + Args: + tensors (Sequence): A sequence of tensors to be separated into chunks. + size_limit (int): The limit of each chunk in bytes. + + Yields: + Blocks of tensors of same type and within size_limit. The yielded + tensors are only ordered as the original sequence within its types. 
+ """ + buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0]) + for tensor in tensors: + t = tensor.type() + if tensor.is_sparse: + indices = torch.Tensor._indices(tensor) + values = torch.Tensor._values(tensor) + size = ( + indices.numel() * indices.element_size() + + values.numel() * values.element_size() + ) + else: + size = tensor.numel() * tensor.element_size() + buf_and_size = buf_dict[t] + if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: + yield buf_and_size[0] + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append(tensor) + buf_and_size[1] += size + for buf, _ in buf_dict.values(): + if len(buf) > 0: + yield buf + + +# annotation decorator to get annotations in a way that is compatible +# with both Python 2 and 3 +def annotate(ret, **kwargs): + def dec(fun): + fun.__annotations__ = dict(kwargs) + fun.__annotations__["return"] = ret + return fun + + return dec + + +def render_call(fn, args, kwargs): + str_fn = torch.overrides.resolve_name(fn) + if str_fn is None: + str_fn = str(fn) + + str_args: List[str] = [] + with torch._tensor_str.printoptions(threshold=0, edgeitems=0): + str_args.extend(repr(a) for a in args) + str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items()) + r = f"{str_fn}({', '.join(str_args)})" + return r + + +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + + +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + + def __repr__(self): + return self + + +class ExceptionWrapper: + r"""Wraps an exception plus traceback to communicate across threads""" + + def __init__(self, exc_info=None, where="in background"): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. 
+ msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except TypeError: + # If the exception takes multiple arguments, don't try to + # instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception + + +def _get_available_device_type(): + if torch.cuda.is_available(): + return "cuda" + if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined] + return "xpu" + custom_backend_name = torch._C._get_privateuse1_backend_name() + custom_device_mod = getattr(torch, custom_backend_name, None) + if custom_device_mod and custom_device_mod.is_available(): + return custom_backend_name + # add more available device types here + return None + + +def _get_device_attr(get_member): + device_type = _get_available_device_type() + if device_type and device_type.lower() == "cuda": + return get_member(torch.cuda) + if device_type and device_type.lower() == "xpu": + return get_member(torch.xpu) # type: ignore[attr-defined] + if device_type == torch._C._get_privateuse1_backend_name(): + return get_member(getattr(torch, device_type)) + # add more available device types here + return None + + +def _get_current_device_index(): + # current device index + return _get_device_attr(lambda m: m.current_device()) + + +def _get_all_device_indices(): + # all device index + return _get_device_attr(lambda m: list(range(m.device_count()))) + + +def _get_devices_properties(device_ids): + # all device properties + return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids] + + +def get_current_device_index() -> int: + r"""Checks if there are CUDA devices available and + returns the device index of the current default CUDA device. + Returns -1 in case there are no CUDA devices available. + Arguments: ``None`` + """ + if torch.cuda.device_count() > 0: + return torch.cuda.current_device() + return -1 + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Gets the device index from :attr:`device`, which can be a torch.device + object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + has index. Note that for a device without a specified index, + i.e., ``torch.device('xxx')``, this will return the current default + device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default + device of the supported runtime platform if :attr:`optional` is ``True``. + i.e., the current default CUDA device will be returned if CUDA runtime is supported. + """ + if isinstance(device, str): + device = torch.device(device) + device_idx: Optional[int] = None + if isinstance(device, torch.device): + if not allow_cpu and device.type == "cpu": + raise ValueError(f"Expected a non cpu device, but got: {device}") + device_idx = -1 if device.type == "cpu" else device.index + if isinstance(device, int): + device_idx = device + if device_idx is None: + if optional: + # The eager API _get_current_device_index uses `lambda` functions which are + # not supported in JIT and hence not scriptable. 
The JIT equivalent API to get + # the current device index is `get_current_device_index()` which can + # be scripted. We use is_scripting to check the mode we are in and call the + # appropriate API. + if torch.jit.is_scripting(): + device_idx = get_current_device_index() + else: + device_idx = _get_current_device_index() + else: + raise ValueError( + f"Expected a torch.device with a specified index or an integer, but got:{device}" + ) + return device_idx + + +def _handle_complex(tensor): + """ + Returns a real view of a tensor if complex dtype else just the tensor + need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule + """ + return ( + torch.view_as_real(tensor) + if not isinstance(tensor, torch.nn.UninitializedParameter) + and tensor.is_complex() + else tensor + ) + + +def _element_size(dtype): + """ + Returns the element size for a dtype, in bytes + """ + if not isinstance(dtype, torch.dtype): + raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}") + + if dtype.is_complex: + return torch.finfo(dtype).bits >> 2 + elif dtype.is_floating_point: + return torch.finfo(dtype).bits >> 3 + elif dtype == torch.bool: + # NOTE: torch.bool is not supported in torch.iinfo() + return 1 + else: + return torch.iinfo(dtype).bits >> 3 + + +class _ClassPropertyDescriptor: + def __init__(self, fget, fset=None): + self.fget = fget + + def __get__(self, instance, owner=None): + if owner is None: + owner = type(instance) + return self.fget.__get__(instance, owner)() + + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + return _ClassPropertyDescriptor(func) + + +def is_compiling() -> bool: + """ + Indicates whether we are tracing/compiling with torch.compile() or torch.export(). + + TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling(). + """ + return torch.compiler.is_compiling() + + +def _functionalize_sync(t): + # This code lives in python instead of C++ since conditioning on a certain python subclass + # is much more of a pain in C++. + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(t, FunctionalTensor): + # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called + # when we sync our inner tensor. + # Why? + # (1) If there are input mutations in the graph, then they will be re-applied during + # AOTAutograd when we call _sync() from inside of our functionalization kernels. + # (2) _sync() causes us to regenerate our updated the tensor from the updated base, + # which dispatches to a bunch of view ops + # (3) The input to these view ops is our inner FunctionalTensorWrapper + # (since the sync was called from C++), not the python FunctionalTensor + # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts + # the view op, since it will see an input that is a C++ FunctionalTensorWrapper + # (aka a normal torch.Tensor) instead of a python `FunctionalTensor). 
+ maybe_functional_mode = torch._C._unset_dispatch_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + try: + torch._functionalize_sync(t.elem) # type: ignore[attr-defined] + finally: + if maybe_functional_mode is not None: + torch._C._set_dispatch_mode(maybe_functional_mode) + else: + torch._functionalize_sync(t) # type: ignore[attr-defined] + + +@functools.lru_cache(2) +def _get_device_module(device_type: str): + device_module = getattr(torch, device_type, None) + if device_module is None: + raise RuntimeError( + f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'." + ) + return device_module + + +def _dummy_type(name: str) -> type: + def get_err_fn(is_init: bool): + def err_fn(obj, *args, **kwargs): + if is_init: + class_name = obj.__class__.__name__ + else: + class_name = obj.__name__ + raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") + + return err_fn + + return type( + name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} + ) + + +class _LazySeedTracker: + # Since seeding is memory-less, only track the latest seed. + # Note: `manual_seed_all` followed by `manual_seed` overwrites + # the seed on current device. We track the order of **latest** + # calls between these two API. + def __init__(self): + self.manual_seed_all_cb = None + self.manual_seed_cb = None + self.call_order = [] + + def queue_seed_all(self, cb, traceback): + self.manual_seed_all_cb = (cb, traceback) + # update seed_all to be latest + self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] + + def queue_seed(self, cb, traceback): + self.manual_seed_cb = (cb, traceback) + # update seed to be latest + self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] + + def get_calls(self) -> List: + return self.call_order diff --git a/MLPY/Lib/site-packages/torch/_utils_internal.py b/MLPY/Lib/site-packages/torch/_utils_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..200d07bb05661a3d03bad1faae607be939c5055d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_utils_internal.py @@ -0,0 +1,138 @@ +import functools +import logging +import os +import sys +import tempfile +from typing import Any, Dict + +import torch + +log = logging.getLogger(__name__) + + +# this arbitrary-looking assortment of functionality is provided here +# to have a central place for overrideable behavior. The motivating +# use is the FB build environment, where this source file is replaced +# by an equivalent. + +if torch._running_with_deploy(): + # __file__ is meaningless in the context of frozen torch used in torch deploy. + # setting empty torch_parent should allow below functions to operate without crashing, + # but it's unclear if there is a valid use case for them in the context of deploy. 
+ torch_parent = "" +else: + if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + else: + torch_parent = os.path.dirname(os.path.dirname(__file__)) + + +def get_file_path(*path_components: str) -> str: + return os.path.join(torch_parent, *path_components) + + +def get_file_path_2(*path_components: str) -> str: + return os.path.join(*path_components) + + +def get_writable_path(path: str) -> str: + if os.access(path, os.W_OK): + return path + return tempfile.mkdtemp(suffix=os.path.basename(path)) + + +def prepare_multiprocessing_environment(path: str) -> None: + pass + + +def resolve_library_path(path: str) -> str: + return os.path.realpath(path) + + +def throw_abstract_impl_not_imported_error(opname, module, context): + if module in sys.modules: + raise NotImplementedError( + f"{opname}: We could not find the abstract impl for this operator. " + ) + else: + raise NotImplementedError( + f"{opname}: We could not find the abstract impl for this operator. " + f"The operator specified that you may need to import the '{module}' " + f"Python module to load the abstract impl. {context}" + ) + + +# Meta only, see +# https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/ +# +# This will cause an event to get logged to Scuba via the signposts API. You +# can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs +# we log to subsystem "torch", and the category and name you provide here. +# Each of the arguments translate into a Scuba column. We're still figuring +# out local conventions in PyTorch, but category should be something like +# "dynamo" or "inductor", and name should be a specific string describing what +# kind of event happened. +# +# Killswitch is at +# https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event +def signpost_event(category: str, name: str, parameters: Dict[str, Any]): + log.info("%s %s: %r", category, name, parameters) + + +def log_compilation_event(metrics): + log.info("%s", metrics) + + +def upload_graph(graph): + pass + + +def set_pytorch_distributed_envs_from_justknobs(): + pass + + +def log_export_usage(**kwargs): + pass + + +def justknobs_check(name: str) -> bool: + """ + This function can be used to killswitch functionality in FB prod, + where you can toggle this value to False in JK without having to + do a code push. In OSS, we always have everything turned on all + the time, because downstream users can simply choose to not update + PyTorch. (If more fine-grained enable/disable is needed, we could + potentially have a map we lookup name in to toggle behavior. But + the point is that it's all tied to source code in OSS, since there's + no live server to query.) + + This is the bare minimum functionality I needed to do some killswitches. + We have a more detailed plan at + https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit + In particular, in some circumstances it may be necessary to read in + a knob once at process start, and then use it consistently for the + rest of the process. Future functionality will codify these patterns + into a better high level API. + + WARNING: Do NOT call this function at module import time, JK is not + fork safe and you will break anyone who forks the process and then + hits JK again. 
+ """ + return True + + +@functools.lru_cache(None) +def max_clock_rate(): + from triton.testing import nvsmi + + return nvsmi(["clocks.max.sm"])[0] + + +TEST_MASTER_ADDR = "127.0.0.1" +TEST_MASTER_PORT = 29500 +# USE_GLOBAL_DEPS controls whether __init__.py tries to load +# libtorch_global_deps, see Note [Global dependencies] +USE_GLOBAL_DEPS = True +# USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load +# _C.so with RTLD_GLOBAL during the call to dlopen. +USE_RTLD_GLOBAL_WITH_LIBTORCH = False diff --git a/MLPY/Lib/site-packages/torch/_vendor/__init__.py b/MLPY/Lib/site-packages/torch/_vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MLPY/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ded06d8d1b17b21b3c63c8196eac5f8f2cb4409f Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/__init__.py b/MLPY/Lib/site-packages/torch/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2ad1ca0a2bf4d73bb6dc5252c3407dd0f20d14 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf156ffc6b1ee368f76d67efaca417d5ec52172 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2dd495706ff043fd13f2fa6af28f972ce0e3309 Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc3b025be8309c8ad5d34b52bf4c9095fa1fd6c Binary files /dev/null and b/MLPY/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc differ diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/_structures.py b/MLPY/Lib/site-packages/torch/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc91962d80e24f98b76d0da1d765fc78b0a1dcb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 
@@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/MLPY/Lib/site-packages/torch/_vendor/packaging/version.py b/MLPY/Lib/site-packages/torch/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cca483cee045aa1acfa9f5cf27c0331cc532aa --- /dev/null +++ b/MLPY/Lib/site-packages/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] 
+ + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +_VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
<pre>                                          # pre-release
+            [-_\.]?
+            (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P<pre_n>[0-9]+)?
+        )?
+        (?P<post>                                         # post release
+            (?:-(?P<post_n1>[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?P<post_l>post|rev|r)
+                [-_\.]?
+                (?P<post_n2>[0-9]+)?
+            )
+        )?
+        (?P<dev>                                          # dev release
+            [-_\.]?
+            (?P<dev_l>dev)
+            [-_\.]?
+            (?P<dev_n>[0-9]+)?
+        )?
+    )
+    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
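+# For example, callers are expected to embed and compile the pattern themselves,
+# mirroring the `_regex` used by the Version class below (a minimal sketch):
+#
+#     _check = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+#     assert _check.match("1.0.post1+local.7") is not None
+#     assert _check.match("not a version") is None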
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    <Version('1.0a5')>
+    >>> v2
+    <Version('1.0')>
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        <Version('1.0.0')>
+        """
+        return f""
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be rounded-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
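+# For example, a minimal usage sketch of the class above (assuming this vendored
+# copy is imported as torch._vendor.packaging.version):
+#
+#     from torch._vendor.packaging.version import Version
+#
+#     releases = [Version("1.0.post1"), Version("1.0"), Version("1.0rc1"), Version("1.0.dev0")]
+#     assert sorted(releases) == [
+#         Version("1.0.dev0"), Version("1.0rc1"), Version("1.0"), Version("1.0.post1")
+#     ]
+#     assert Version("1!0.5") > Version("2.0")  # a higher epoch outranks the release segment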
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
+
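+# For example, the helper above normalizes alternate spellings and implicit numbers
+# (a sketch of the expected behaviour):
+#
+#     assert _parse_letter_version("alpha", None) == ("a", 0)
+#     assert _parse_letter_version("preview", "2") == ("rc", 2)
+#     assert _parse_letter_version(None, "1") == ("post", 1)  # implicit post release, e.g. "1.0-1"
+#     assert _parse_letter_version(None, None) is None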
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll reverse the list, drop all the now-leading
+    # zeros until we come to something non-zero, then re-reverse it back into the
+    # correct order, make it a tuple, and use that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: CmpPrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: CmpPrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: CmpPrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: CmpLocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
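+# For example, the key construction above produces the PEP 440 ordering described
+# in the comments (a sketch exercised through the public Version class):
+#
+#     assert Version("1.0.dev0") < Version("1.0a0")     # dev-only releases sort before pre-releases
+#     assert Version("1.0") < Version("1.0+abc")        # no local segment sorts before any local segment
+#     assert Version("1.0+abc") < Version("1.0+abc.1")  # a shorter local segment sorts first on a prefix match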
diff --git a/MLPY/Lib/site-packages/torch/_vmap_internals.py b/MLPY/Lib/site-packages/torch/_vmap_internals.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a2541bb41d7be8dab651d8b5757e3589ec3fd3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/_vmap_internals.py
@@ -0,0 +1,237 @@
+import functools
+import warnings
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
+
+in_dims_t = Union[int, Tuple]
+out_dims_t = Union[int, Tuple[int, ...]]
+
+
+# Checks that all args-to-be-batched have the same batch dim size
+def _validate_and_get_batch_size(
+    flat_in_dims: List[Optional[int]], flat_args: List
+) -> int:
+    batch_sizes = [
+        arg.size(in_dim)
+        for in_dim, arg in zip(flat_in_dims, flat_args)
+        if in_dim is not None
+    ]
+    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
+        raise ValueError(
+            f"vmap: Expected all tensors to have the same size in the mapped "
+            f"dimension, got sizes {batch_sizes} for the mapped dimension"
+        )
+    return batch_sizes[0]
+
+
+def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
+    if isinstance(batched_outputs, tuple):
+        return len(batched_outputs)
+    return 1
+
+
+# If value is a tuple, check it has length `num_elements`.
+# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
+def _as_tuple(
+    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
+) -> Tuple:
+    if not isinstance(value, tuple):
+        return (value,) * num_elements
+    if len(value) != num_elements:
+        raise ValueError(error_message_lambda())
+    return value
+
+
+# Creates BatchedTensors for every Tensor in arg that should be batched.
+# Returns the (potentially) batched arguments and the batch_size.
+def _create_batched_inputs(
+    in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
+) -> Tuple[Tuple, int]:
+    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
+        raise ValueError(
+            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): "
+            f"expected `in_dims` to be int or a (potentially nested) tuple "
+            f"matching the structure of inputs, got: {type(in_dims)}."
+        )
+    if len(args) == 0:
+        raise ValueError(
+            f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add "
+            f"inputs, or you are trying to vmap over a function with no inputs. "
+            f"The latter is unsupported."
+        )
+
+    flat_args, args_spec = tree_flatten(args)
+    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
+    if flat_in_dims is None:
+        raise ValueError(
+            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): "
+            f"in_dims is not compatible with the structure of `inputs`. "
+            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
+            f"has structure {args_spec}."
+        )
+
+    for arg, in_dim in zip(flat_args, flat_in_dims):
+        if not isinstance(in_dim, int) and in_dim is not None:
+            raise ValueError(
+                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): "
+                f"Got in_dim={in_dim} for an input but in_dim must be either "
+                f"an integer dimension or None."
+            )
+        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
+            raise ValueError(
+                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): "
+                f"Got in_dim={in_dim} for an input but the input is of type "
+                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
+                f"please use None as the respective in_dim"
+            )
+        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
+            raise ValueError(
+                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): "
+                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
+                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
+                f"0 <= in_dim < {arg.dim()}."
+            )
+
+    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
+    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
+    batched_inputs = [
+        arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
+        for in_dim, arg in zip(flat_in_dims, flat_args)
+    ]
+    return tree_unflatten(batched_inputs, args_spec), batch_size
+
+
+# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
+def _unwrap_batched(
+    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
+    out_dims: out_dims_t,
+    vmap_level: int,
+    batch_size: int,
+    func: Callable,
+    allow_none_pass_through: bool = False,
+) -> Tuple:
+    num_outputs = _num_outputs(batched_outputs)
+    out_dims_as_tuple = _as_tuple(
+        out_dims,
+        num_outputs,
+        lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
+        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
+    )
+
+    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
+    # There is something wrong with our type bindings for functions that begin
+    # with '_', see #40397.
+    if isinstance(batched_outputs, Tensor):
+        out_dim = out_dims_as_tuple[0]
+        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
+    if allow_none_pass_through:
+        return tuple(
+            (
+                torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
+                if out is not None
+                else None
+            )
+            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
+        )
+    else:
+        return tuple(
+            torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
+            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
+        )
+
+
+# Checks that `fn` returned one or more Tensors and nothing else.
+# NB: A python function that return multiple arguments returns a single tuple,
+# so we are effectively checking that `outputs` is a single Tensor or a tuple of
+# Tensors.
+def _validate_outputs(outputs: Any, func: Callable) -> None:
+    if isinstance(outputs, Tensor):
+        return
+    if not isinstance(outputs, tuple):
+        raise ValueError(
+            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
+            f"Tensors, got type {type(outputs)} as the return."
+        )
+    for idx, output in enumerate(outputs):
+        if isinstance(output, Tensor):
+            continue
+        raise ValueError(
+            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
+            f"Tensors, got type {type(output)} for return {idx}."
+        )
+
+
+def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
+    if isinstance(out_dims, int):
+        return
+    if not isinstance(out_dims, tuple) or not all(
+        isinstance(out_dim, int) for out_dim in out_dims
+    ):
+        raise ValueError(
+            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
+            f"an int or a tuple of int representing where in the outputs the "
+            f"vmapped dimension should appear."
+        )
+
+
+def _get_name(func: Callable):
+    if hasattr(func, "__name__"):
+        return func.__name__
+
+    # Not all callables have __name__, in fact, only static functions/methods do.
+    # A callable created via functools.partial or an nn.Module, to name some
+    # examples, doesn't have a __name__.
+    return repr(func)
+
+
+# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
+# sends those into func, and then unwraps the output BatchedTensors. Operations
+# on BatchedTensors perform the batched operations that the user is asking for.
+def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
+    """
+    Please use torch.vmap instead of this API.
+    """
+    warnings.warn(
+        "Please use torch.vmap instead of torch._vmap_internals.vmap. ",
+        stacklevel=2,
+    )
+    return _vmap(func, in_dims, out_dims)
+
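+# For example, the public torch.vmap API recommended above maps a per-example
+# function over a leading batch dimension (a minimal sketch):
+#
+#     x = torch.randn(8, 3)
+#     y = torch.randn(8, 3)
+#     batched_dot = torch.vmap(torch.dot)  # maps over dim 0 of every input by default
+#     assert batched_dot(x, y).shape == (8,)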
+
+# A version of vmap but without the initial "experimental prototype" warning
+def _vmap(
+    func: Callable,
+    in_dims: in_dims_t = 0,
+    out_dims: out_dims_t = 0,
+    allow_none_pass_through: bool = False,
+) -> Callable:
+    # The `allow_none_pass_through` argument is a temporary workaround and may be removed.
+    # Currently it enables us to wrap calls into the autograd engine (`autograd.grad`),
+    # which may return None if any of the inputs are unused. See the issue discussing this:
+    # https://github.com/facebookresearch/functorch/issues/159.
+    @functools.wraps(func)
+    def wrapped(*args):
+        _check_out_dims_is_int_or_int_tuple(out_dims, func)
+        vmap_level = torch._C._vmapmode_increment_nesting()
+        try:
+            batched_inputs, batch_size = _create_batched_inputs(
+                in_dims, args, vmap_level, func
+            )
+            batched_outputs = func(*batched_inputs)
+            if not allow_none_pass_through:
+                _validate_outputs(batched_outputs, func)
+            return _unwrap_batched(
+                batched_outputs,
+                out_dims,
+                vmap_level,
+                batch_size,
+                func,
+                allow_none_pass_through=allow_none_pass_through,
+            )
+        finally:
+            torch._C._vmapmode_decrement_nesting()
+
+    return wrapped
diff --git a/MLPY/Lib/site-packages/torch/_weights_only_unpickler.py b/MLPY/Lib/site-packages/torch/_weights_only_unpickler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf2d467522c237df56e7abd4dd7af0274fbb8ea1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/_weights_only_unpickler.py
@@ -0,0 +1,306 @@
+# Unpickler restricted to loading only state dicts
+# Restrict constructing types to a list defined in _get_allowed_globals()
+# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
+# Restrict APPEND/APPENDS to `list`
+# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary
+# defined by `_get_allowed_globals()` method, that contains:
+# - torch types (Storage, dtypes, Tensor, `torch.Size`),
+# - `torch._utils._rebuild` functions.
+# - `torch.nn.Parameter`
+# - `collections.OrderedDict`
+
+# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
+# Expected to be useful for loading PyTorch model weights
+# For example:
+# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
+# buf = io.BytesIO(data)
+# weights = torch.load(buf, weights_only = True)
+
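+# A round-trip sketch of the intended use, loading a plain state dict through this
+# restricted unpickler (assumes only the standard io module in addition to torch):
+#
+#     import io
+#     buf = io.BytesIO()
+#     torch.save(torch.nn.Linear(2, 2).state_dict(), buf)
+#     buf.seek(0)
+#     state = torch.load(buf, weights_only=True)
+#     assert set(state.keys()) == {"weight", "bias"}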
+import functools as _functools
+from collections import OrderedDict
+from pickle import (
+    APPEND,
+    APPENDS,
+    BINFLOAT,
+    BINGET,
+    BININT,
+    BININT1,
+    BININT2,
+    BINPERSID,
+    BINPUT,
+    BINUNICODE,
+    BUILD,
+    bytes_types,
+    decode_long,
+    EMPTY_DICT,
+    EMPTY_LIST,
+    EMPTY_SET,
+    EMPTY_TUPLE,
+    GLOBAL,
+    LONG1,
+    LONG_BINGET,
+    LONG_BINPUT,
+    MARK,
+    NEWFALSE,
+    NEWOBJ,
+    NEWTRUE,
+    NONE,
+    PROTO,
+    REDUCE,
+    SETITEM,
+    SETITEMS,
+    SHORT_BINSTRING,
+    STOP,
+    TUPLE,
+    TUPLE1,
+    TUPLE2,
+    TUPLE3,
+    UnpicklingError,
+)
+from struct import unpack
+from sys import maxsize
+from typing import Any, Dict, List
+
+import torch
+
+
+# Unpickling machinery
+@_functools.lru_cache(maxsize=1)
+def _get_allowed_globals():
+    rc: Dict[str, Any] = {
+        "collections.OrderedDict": OrderedDict,
+        "torch.nn.parameter.Parameter": torch.nn.Parameter,
+        "torch.serialization._get_layout": torch.serialization._get_layout,
+        "torch.Size": torch.Size,
+        "torch.Tensor": torch.Tensor,
+    }
+    # dtype
+    for t in [
+        torch.complex32,
+        torch.complex64,
+        torch.complex128,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fnuz,
+        torch.float16,
+        torch.float32,
+        torch.float64,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    ]:
+        rc[str(t)] = t
+    # Tensor classes
+    for tt in torch._tensor_classes:
+        rc[f"{tt.__module__}.{tt.__name__}"] = tt
+    # Storage classes
+    for ts in torch._storage_classes:
+        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
+            # Wrap legacy storage types in a dummy class
+            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
+                ts.__name__
+            )
+        else:
+            rc[f"{ts.__module__}.{ts.__name__}"] = ts
+    # Rebuild functions
+    for f in [
+        torch._utils._rebuild_parameter,
+        torch._utils._rebuild_tensor,
+        torch._utils._rebuild_tensor_v2,
+        torch._utils._rebuild_tensor_v3,
+        torch._utils._rebuild_sparse_tensor,
+        torch._utils._rebuild_meta_tensor_no_storage,
+        torch._utils._rebuild_nested_tensor,
+    ]:
+        rc[f"torch._utils.{f.__name__}"] = f
+
+    # Handles Tensor Subclasses, Tensor's with attributes.
+    # NOTE: It calls into above rebuild functions for regular Tensor types.
+    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
+    return rc
+
+
+class Unpickler:
+    def __init__(self, file, *, encoding: str = "bytes"):
+        self.encoding = encoding
+        self.readline = file.readline
+        self.read = file.read
+        self.memo: Dict[int, Any] = {}
+
+    def load(self):
+        """Read a pickled object representation from the open file.
+
+        Return the reconstituted object hierarchy specified in the file.
+        """
+        self.metastack = []
+        self.stack: List[Any] = []
+        self.append = self.stack.append
+        read = self.read
+        readline = self.readline
+        while True:
+            key = read(1)
+            if not key:
+                raise EOFError
+            assert isinstance(key, bytes_types)
+            # Risky operators
+            if key[0] == GLOBAL[0]:
+                module = readline()[:-1].decode("utf-8")
+                name = readline()[:-1].decode("utf-8")
+                full_path = f"{module}.{name}"
+                if full_path in _get_allowed_globals():
+                    self.append(_get_allowed_globals()[full_path])
+                else:
+                    raise RuntimeError(f"Unsupported class {full_path}")
+            elif key[0] == NEWOBJ[0]:
+                args = self.stack.pop()
+                cls = self.stack.pop()
+                if cls is not torch.nn.Parameter:
+                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
+                self.append(torch.nn.Parameter(*args))
+            elif key[0] == REDUCE[0]:
+                args = self.stack.pop()
+                func = self.stack[-1]
+                if func not in _get_allowed_globals().values():
+                    raise RuntimeError(
+                        f"Trying to call reduce for unrecognized function {func}"
+                    )
+                self.stack[-1] = func(*args)
+            elif key[0] == BUILD[0]:
+                state = self.stack.pop()
+                inst = self.stack[-1]
+                if type(inst) is torch.Tensor:
+                    # Legacy unpickling
+                    inst.set_(*state)
+                elif type(inst) is torch.nn.Parameter:
+                    inst.__setstate__(state)
+                elif type(inst) is OrderedDict:
+                    inst.__dict__.update(state)
+                else:
+                    raise RuntimeError(
+                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
+                    )
+            # Stack manipulation
+            elif key[0] == APPEND[0]:
+                item = self.stack.pop()
+                list_obj = self.stack[-1]
+                if type(list_obj) is not list:
+                    raise RuntimeError(
+                        f"Can only append to lists, but got {type(list_obj)}"
+                    )
+                list_obj.append(item)
+            elif key[0] == APPENDS[0]:
+                items = self.pop_mark()
+                list_obj = self.stack[-1]
+                if type(list_obj) is not list:
+                    raise RuntimeError(
+                        f"Can only extend lists, but got {type(list_obj)}"
+                    )
+                list_obj.extend(items)
+            elif key[0] == SETITEM[0]:
+                (v, k) = (self.stack.pop(), self.stack.pop())
+                self.stack[-1][k] = v
+            elif key[0] == SETITEMS[0]:
+                items = self.pop_mark()
+                for i in range(0, len(items), 2):
+                    self.stack[-1][items[i]] = items[i + 1]
+            elif key[0] == MARK[0]:
+                self.metastack.append(self.stack)
+                self.stack = []
+                self.append = self.stack.append
+            elif key[0] == TUPLE[0]:
+                items = self.pop_mark()
+                self.append(tuple(items))
+            elif key[0] == TUPLE1[0]:
+                self.stack[-1] = (self.stack[-1],)
+            elif key[0] == TUPLE2[0]:
+                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
+            elif key[0] == TUPLE3[0]:
+                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
+            # Basic types construction
+            elif key[0] == NONE[0]:
+                self.append(None)
+            elif key[0] == NEWFALSE[0]:
+                self.append(False)
+            elif key[0] == NEWTRUE[0]:
+                self.append(True)
+            elif key[0] == EMPTY_TUPLE[0]:
+                self.append(())
+            elif key[0] == EMPTY_LIST[0]:
+                self.append([])
+            elif key[0] == EMPTY_DICT[0]:
+                self.append({})
+            elif key[0] == EMPTY_SET[0]:
+                self.append(set())
+            elif key[0] == BININT[0]:
+                self.append(unpack("<i", read(4))[0])
+            elif key[0] == BINFLOAT[0]:
+                self.append(unpack(">d", self.read(8))[0])
+            elif key[0] == BINUNICODE[0]:
+                strlen = unpack("<I", read(4))[0]
+                if strlen > maxsize:
+                    raise RuntimeError("String is too long")
+                strval = str(read(strlen), "utf-8", "surrogatepass")
+                self.append(strval)
+            elif key[0] == SHORT_BINSTRING[0]:
+                strlen = read(1)[0]
+                strdata = read(strlen)
+                if self.encoding != "bytes":
+                    strdata = strdata.decode(self.encoding, "strict")
+                self.append(strdata)
+            elif key[0] == BINPERSID[0]:
+                pid = self.stack.pop()
+                # Only allow persistent load of storage
+                if type(pid) is not tuple and type(pid) is not int:
+                    raise RuntimeError(
+                        f"persistent_load id must be tuple or int, but got {type(pid)}"
+                    )
+                if (
+                    type(pid) is tuple
+                    and len(pid) > 0
+                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
+                ):
+                    raise RuntimeError(
+                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
+                    )
+                self.append(self.persistent_load(pid))
+            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
+                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
+                self.append(self.memo[idx])
+
+    When entering an autocast-enabled region, Tensors may be any type.
+    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.
+
+    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
+    computation(s).  Backward passes under autocast are not recommended.
+    Backward ops run in the same type that autocast used for corresponding forward ops.
+
+    Example for CUDA Devices::
+
+        # Creates model and optimizer in default precision
+        model = Net().cuda()
+        optimizer = optim.SGD(model.parameters(), ...)
+
+        for input, target in data:
+            optimizer.zero_grad()
+
+            # Enables autocasting for the forward pass (model + loss)
+            with torch.autocast(device_type="cuda"):
+                output = model(input)
+                loss = loss_fn(output, target)
+
+            # Exits the context manager before backward()
+            loss.backward()
+            optimizer.step()
+
+    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
+    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).
+
+    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::
+
+        class AutocastModel(nn.Module):
+            ...
+            @torch.autocast(device_type="cuda")
+            def forward(self, input):
+                ...
+
+    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
+    After returning to an autocast-disabled region, using them with floating-point
+    Tensors of different dtypes may cause type mismatch errors.  If so, cast the Tensor(s)
+    produced in the autocast region back to ``float32`` (or other dtype if desired).
+    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
+    and incurs no additional overhead.
+
+    CUDA Example::
+
+        # Creates some tensors in default dtype (here assumed to be float32)
+        a_float32 = torch.rand((8, 8), device="cuda")
+        b_float32 = torch.rand((8, 8), device="cuda")
+        c_float32 = torch.rand((8, 8), device="cuda")
+        d_float32 = torch.rand((8, 8), device="cuda")
+
+        with torch.autocast(device_type="cuda"):
+            # torch.mm is on autocast's list of ops that should run in float16.
+            # Inputs are float32, but the op runs in float16 and produces float16 output.
+            # No manual casts are required.
+            e_float16 = torch.mm(a_float32, b_float32)
+            # Also handles mixed input types
+            f_float16 = torch.mm(d_float32, e_float16)
+
+        # After exiting autocast, calls f_float16.float() to use with d_float32
+        g_float32 = torch.mm(d_float32, f_float16.float())
+
+    CPU Training Example::
+
+        # Creates model and optimizer in default precision
+        model = Net()
+        optimizer = optim.SGD(model.parameters(), ...)
+
+        for epoch in epochs:
+            for input, target in data:
+                optimizer.zero_grad()
+
+                # Runs the forward pass with autocasting.
+                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+                    output = model(input)
+                    loss = loss_fn(output, target)
+
+                loss.backward()
+                optimizer.step()
+
+
+    CPU Inference Example::
+
+        # Creates model in default precision
+        model = Net().eval()
+
+        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+            for input in data:
+                # Runs the forward pass with autocasting.
+                output = model(input)
+
+    CPU Inference Example with Jit Trace::
+
+        class TestModel(nn.Module):
+            def __init__(self, input_size, num_classes):
+                super().__init__()
+                self.fc1 = nn.Linear(input_size, num_classes)
+            def forward(self, x):
+                return self.fc1(x)
+
+        input_size = 2
+        num_classes = 2
+        model = TestModel(input_size, num_classes).eval()
+
+        # For now, we suggest to disable the Jit Autocast Pass,
+        # As the issue: https://github.com/pytorch/pytorch/issues/75956
+        torch._C._jit_set_autocast_mode(False)
+
+        with torch.cpu.amp.autocast(cache_enabled=False):
+            model = torch.jit.trace(model, torch.randn(1, input_size))
+        model = torch.jit.freeze(model)
+        # Models Run
+        for _ in range(3):
+            model(torch.randn(1, input_size))
+
+    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
+    please file an issue.
+
+    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
+    Locally disabling autocast can be useful, for example, if you want to force a subregion
+    to run in a particular ``dtype``.  Disabling autocast gives you explicit control over
+    the execution type.  In the subregion, inputs from the surrounding region
+    should be cast to ``dtype`` before use::
+
+        # Creates some tensors in default dtype (here assumed to be float32)
+        a_float32 = torch.rand((8, 8), device="cuda")
+        b_float32 = torch.rand((8, 8), device="cuda")
+        c_float32 = torch.rand((8, 8), device="cuda")
+        d_float32 = torch.rand((8, 8), device="cuda")
+
+        with torch.autocast(device_type="cuda"):
+            e_float16 = torch.mm(a_float32, b_float32)
+            with torch.autocast(device_type="cuda", enabled=False):
+                # Calls e_float16.float() to ensure float32 execution
+                # (necessary because e_float16 was created in an autocasted region)
+                f_float32 = torch.mm(c_float32, e_float16.float())
+
+            # No manual casts are required when re-entering the autocast-enabled region.
+            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
+            g_float16 = torch.mm(d_float32, f_float32)
+
+    The autocast state is thread-local.  If you want it enabled in a new thread, the context manager or decorator
+    must be invoked in that thread.  This affects :class:`torch.nn.DataParallel` and
+    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
+    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
+
+    Args:
+        device_type(str, required):  Device type to use. Possible values are: 'cuda', 'cpu', 'xpu', 'ipu', 'hpu' and 'xla'.
+                                     The type is the same as the `type` attribute of a :class:`torch.device`.
+                                     Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
+        enabled(bool, optional):  Whether autocasting should be enabled in the region.
+            Default: ``True``
+        dtype(torch_dtype, optional):  Data type for ops run under autocast, e.g. torch.float16 or torch.bfloat16.
+        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.
+            Default: ``True``
+    """
+
+    def __init__(
+        self,
+        device_type: str,
+        dtype: Optional[_dtype] = None,
+        enabled: bool = True,
+        cache_enabled: Optional[bool] = None,
+    ):
+        if torch._jit_internal.is_scripting():
+            self._enabled = enabled
+            self.device = device_type
+            self.fast_dtype = dtype
+            # TODO: support get_autocast_gpu/cpu_dtype
+            assert dtype is not None
+            return
+        self.device = device_type
+        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
+        if self.device == "cuda":
+            self.fast_dtype = torch.get_autocast_gpu_dtype()
+        elif self.device == "cpu":
+            self.fast_dtype = torch.get_autocast_cpu_dtype()
+        elif self.device == "xpu":
+            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
+        elif self.device == "ipu":
+            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
+        elif self.device == "hpu":
+            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
+        elif self.device == "xla":
+            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
+        elif self.device == self.custom_backend_name:
+            necessary_funcs = [
+                "is_autocast_enabled",
+                "set_autocast_enabled",
+                "get_autocast_dtype",
+                "set_autocast_dtype",
+                "get_amp_supported_dtype",
+            ]
+            message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
+            message += "registered a module or  the module miss some necessary funcs. The backend should register "
+            message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
+            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
+            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
+            message += (
+                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
+            )
+
+            assert hasattr(torch, self.custom_backend_name), message
+            self.custom_device_mod = getattr(torch, self.custom_backend_name)
+            for func in necessary_funcs:
+                assert hasattr(self.custom_device_mod, func), (
+                    message + f"But the func `{func}` is missing. \n"
+                )
+
+            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
+        else:
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
+        self._cache_enabled = torch.is_autocast_cache_enabled()
+        if (
+            enabled
+            and torch.cuda.amp.common.amp_definitely_not_available()
+            and self.device == "cuda"
+        ):
+            warnings.warn(
+                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
+            )
+            enabled = False
+        if dtype is not None:
+            self.fast_dtype = dtype
+        if cache_enabled is not None:
+            self._cache_enabled = cache_enabled
+
+        if self.device == "cpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype and enabled:
+                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "xpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "ipu":
+            supported_dtypes = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtypes:
+                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "hpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == self.custom_backend_name:
+            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
+            if self.fast_dtype not in supported_dtype:
+                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
+                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "cuda":
+            if (
+                enabled
+                and self.fast_dtype == torch.bfloat16
+                and not torch.cuda.is_bf16_supported()
+            ):
+                raise RuntimeError(
+                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                )
+        elif self.device == "xla":
+            supported_dtype = [torch.float16, torch.bfloat16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += (
+                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        self._enabled = enabled
+
+    def __enter__(self):
+        if torch._jit_internal.is_scripting():
+            assert self.fast_dtype is not None
+            return self
+
+        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
+        if self.device == "cpu":
+            self.prev = torch.is_autocast_cpu_enabled()
+            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
+            torch.set_autocast_cpu_enabled(self._enabled)
+            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
+            torch.autocast_increment_nesting()
+        elif self.device == "xpu":
+            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
+            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
+            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
+            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
+            torch.autocast_increment_nesting()
+        elif self.device == "ipu":
+            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
+            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
+            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
+            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
+            torch.autocast_increment_nesting()
+        elif self.device == "hpu":
+            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
+            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
+            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
+            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
+            torch.autocast_increment_nesting()
+        elif self.device == "xla":
+            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
+            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
+            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
+            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
+            torch.autocast_increment_nesting()
+        elif self.device == self.custom_backend_name:
+            self.prev = self.custom_device_mod.is_autocast_enabled()
+            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
+            self.custom_device_mod.set_autocast_enabled(self._enabled)
+            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
+            torch.autocast_increment_nesting()
+        else:
+            self.prev = torch.is_autocast_enabled()
+            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
+            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
+            torch.set_autocast_enabled(self._enabled)
+            torch.autocast_increment_nesting()
+        torch.set_autocast_cache_enabled(self._cache_enabled)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
+        if torch._jit_internal.is_scripting():
+            return
+
+        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
+        if self.device == "cpu":
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.set_autocast_cpu_enabled(self.prev)
+            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
+        elif self.device == "xpu":
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "ipu":
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "hpu":
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "xla":
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == self.custom_backend_name:
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            self.custom_device_mod.set_autocast_enabled(self.prev)
+            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
+        else:
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.set_autocast_enabled(self.prev)
+            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
+        return False
+
+    def __call__(self, func):
+        if torch._jit_internal.is_scripting():
+            return func
+        return autocast_decorator(self, func)
+
+
+# These functions aren't meant for public usage.
+# They are what we trace into a graph during pre_dispatch tracing
+# when we encounter an autocast context manager.
+def _enter_autocast(*vals):
+    # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
+    if torch._C._is_torch_function_mode_enabled():
+        return torch.overrides.handle_torch_function(
+            torch.amp._enter_autocast, [], *vals
+        )
+    mode = torch.amp.autocast(*vals)
+    mode.__enter__()
+    return mode
+
+
+def _exit_autocast(mode):
+    if torch._C._is_torch_function_mode_enabled():
+        return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
+    mode.__exit__(None, None, None)
diff --git a/MLPY/Lib/site-packages/torch/amp/grad_scaler.py b/MLPY/Lib/site-packages/torch/amp/grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73fab2274046679ed47c267bf57f52705c2d09c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/amp/grad_scaler.py
@@ -0,0 +1,681 @@
+from __future__ import annotations
+
+import inspect
+import warnings
+from collections import abc, defaultdict
+from enum import Enum
+from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union
+
+import torch
+
+
+__all__ = ["OptState", "GradScaler"]
+
+
+class _MultiDeviceReplicator:
+    """Lazily serves copies of a tensor to requested devices.
+
+    Copies are cached per-device.
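+
+    A minimal sketch of the intended (internal) use, shown for illustration::
+
+        rep = _MultiDeviceReplicator(torch.full((), 2.0))
+        a = rep.get(torch.device("cpu"))  # first request for a device makes a copy
+        b = rep.get(torch.device("cpu"))  # later requests reuse the cached copy
+        assert a is b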
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+    def get(self, device: torch.device) -> torch.Tensor:
+        retval = self._per_device_tensors.get(device, None)
+        if retval is None:
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
+            self._per_device_tensors[device] = retval
+        return retval
+
+
+# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
+# as well as associated "enum" values.  We prefer defining these at top level because:
+# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
+# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
+#   causes a circular reference, which we'd rather avoid.
+class OptState(Enum):
+    READY = 0
+    UNSCALED = 1
+    STEPPED = 2
+
+
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
+    return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
+class GradScaler:
+    """An instance ``scaler`` of :class:`GradScaler`.
+
+    It helps perform the steps of gradient scaling conveniently.
+
+    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
+    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
+    * ``scaler.update()`` updates ``scaler``'s scale factor.
+
+    Example::
+
+        # Creates a GradScaler once at the beginning of training.
+        scaler = GradScaler()
+
+        for epoch in epochs:
+            for input, target in data:
+                optimizer.zero_grad()
+                output = model(input)
+                loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                scaler.scale(loss).backward()
+
+                # scaler.step() first unscales gradients of the optimizer's params.
+                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
+    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
+    and multiple losses/optimizers.
+
+    ``scaler`` dynamically estimates the scale factor each iteration.  To minimize gradient underflow,
+    a large scale factor should be used.  However, ``float16`` values can "overflow" (become inf or NaN) if
+    the scale factor is too large.  Therefore, the optimal scale factor is the largest factor that can be used
+    without incurring inf or NaN gradient values.
+    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
+    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
+
+    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
+      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
+
+    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
+      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
+      ``growth_factor``.
+
+    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
+    value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
+    iterations.  After that, step skipping should occur rarely (once every few hundred or thousand iterations).
+
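+    For intuition, the effect of :meth:`update` on the scale can be sketched as follows
+    (illustrative only; the actual update is performed on-device by ``torch._amp_update_scale_``)::
+
+        if found_inf:  # an inf/NaN gradient was found this iteration
+            scale *= backoff_factor
+            growth_tracker = 0
+        else:
+            growth_tracker += 1
+            if growth_tracker == growth_interval:
+                scale *= growth_factor
+                growth_tracker = 0
+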
+    Args:
+        device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'.
+            The type is the same as the `type` attribute of a :class:`torch.device`.
+            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
+        init_scale (float, optional, default=2.**16):  Initial scale factor.
+        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
+            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
+            :meth:`update` if inf/NaN gradients occur in an iteration.
+        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+        enabled (bool, optional):  If ``False``, disables gradient scaling. :meth:`step` simply
+            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+            Default: ``True``
+    """
+
+    def __init__(
+        self,
+        device: str = "cuda",
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
+        self._device = device
+        self._enabled = enabled
+        if self._device == "cuda":
+            if enabled and torch.cuda.amp.common.amp_definitely_not_available():
+                warnings.warn(
+                    "torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling."
+                )
+                self._enabled = False
+
+        if self._enabled:
+            assert growth_factor > 1.0, "The growth factor must be > 1.0."
+            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
+
+            self._init_scale = init_scale
+            # self._scale will be lazily initialized during the first call to scale()
+            self._scale: Optional[torch.Tensor] = None
+            self._growth_factor = growth_factor
+            self._backoff_factor = backoff_factor
+            self._growth_interval = growth_interval
+            self._init_growth_tracker = 0
+            # self._growth_tracker will be lazily initialized during the first call to scale()
+            self._growth_tracker: Optional[torch.Tensor] = None
+            self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
+                _refresh_per_optimizer_state
+            )
+
+    def _check_scale_growth_tracker(
+        self, funcname: str
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
+        assert self._scale is not None, (
+            f"Attempted {funcname} but _scale is None.  " + fix
+        )
+        assert self._growth_tracker is not None, (
+            f"Attempted {funcname} but _growth_tracker is None.  " + fix
+        )
+        return (self._scale, self._growth_tracker)
+
+    def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
+        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
+        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full(
+            (), self._init_growth_tracker, dtype=torch.int32, device=dev
+        )
+
+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
+    def scale(
+        self,
+        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
+        """
+        Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
+        Args:
+            outputs (Tensor or iterable of Tensors):  Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        stash: List[
+            _MultiDeviceReplicator
+        ] = []  # holds a reference that can be overwritten by apply_scale
+
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
+            if isinstance(val, torch.Tensor):
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            if isinstance(val, abc.Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, (list, tuple)):
+                    return type(val)(iterable)
+                return iterable
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
+
+    def _unscale_grads_(
+        self,
+        optimizer: torch.optim.Optimizer,
+        inv_scale: torch.Tensor,
+        found_inf: torch.Tensor,
+        allow_fp16: bool,
+    ) -> Dict[torch.device, torch.Tensor]:
+        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
+        per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
+        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
+        # There could be hundreds of grads, so we'd like to iterate through them just once.
+        # However, we don't know their devices or dtypes in advance.
+
+        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
+        # Google says mypy struggles with defaultdict type annotations.
+        per_device_and_dtype_grads: Dict[
+            torch.device, Dict[torch.dtype, List[torch.Tensor]]
+        ] = defaultdict(lambda: defaultdict(list))
+        with torch.no_grad():
+            for group in optimizer.param_groups:
+                for param in group["params"]:
+                    assert isinstance(param, torch.Tensor)
+                    if param.grad is None:
+                        continue
+                    if (not allow_fp16) and param.grad.dtype == torch.float16:
+                        raise ValueError("Attempting to unscale FP16 gradients.")
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
+                    else:
+                        to_unscale = param.grad
+
+                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
+                    per_device_and_dtype_grads[to_unscale.device][
+                        to_unscale.dtype
+                    ].append(to_unscale)
+
+            for device, per_dtype_grads in per_device_and_dtype_grads.items():
+                for grads in per_dtype_grads.values():
+                    torch._amp_foreach_non_finite_check_and_unscale_(
+                        grads,
+                        per_device_found_inf.get(device),
+                        per_device_inv_scale.get(device),
+                    )
+
+        return per_device_found_inf._per_device_tensors
+
+    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
+        """
+        Divides ("unscales") the optimizer's gradient tensors by the scale factor.
+
+        :meth:`unscale_` is optional, serving cases where you need to
+        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
+        between the backward pass(es) and :meth:`step`.
+        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
+
+        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
+
+            ...
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+            scaler.step(optimizer)
+            scaler.update()
+
+        Args:
+            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.
+
+        .. note::
+            :meth:`unscale_` does not incur a CPU-GPU sync.
+
+        .. warning::
+            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
+            and only after all gradients for that optimizer's assigned parameters have been accumulated.
+            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
+
+        .. warning::
+            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
+        """
+        if not self._enabled:
+            return
+
+        self._check_scale_growth_tracker("unscale_")
+
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+        if optimizer_state["stage"] is OptState.UNSCALED:
+            raise RuntimeError(
+                "unscale_() has already been called on this optimizer since the last update()."
+            )
+        elif optimizer_state["stage"] is OptState.STEPPED:
+            raise RuntimeError("unscale_() is being called after step().")
+
+        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
+        inv_scale = self._scale.double().reciprocal().float()
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
+
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+            optimizer, inv_scale, found_inf, False
+        )
+        optimizer_state["stage"] = OptState.UNSCALED
+
+    def _maybe_opt_step(
+        self,
+        optimizer: torch.optim.Optimizer,
+        optimizer_state: Dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Optional[float]:
+        retval: Optional[float] = None
+        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
+            retval = optimizer.step(*args, **kwargs)
+        return retval
+
+    def step(
+        self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
+    ) -> Optional[float]:
+        """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN.
+
+        :meth:`step` carries out the following two operations:
+
+        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
+            earlier in the iteration).  As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
+        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
+            gradients.  Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
+
+        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
+
+        Returns the return value of ``optimizer.step(*args, **kwargs)``.
+
+        Args:
+            optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
+            args:  Any arguments.
+            kwargs:  Any keyword arguments.
+
+        .. warning::
+            Closure use is not currently supported.
+        """
+        if not self._enabled:
+            return optimizer.step(*args, **kwargs)
+
+        if "closure" in kwargs:
+            raise RuntimeError(
+                "Closure use is not currently supported if GradScaler is enabled."
+            )
+
+        self._check_scale_growth_tracker("step")
+
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+        if optimizer_state["stage"] is OptState.STEPPED:
+            raise RuntimeError(
+                "step() has already been called since the last update()."
+            )
+
+        retval: Optional[float] = None
+
+        if getattr(optimizer, "_step_supports_amp_scaling", False):
+            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
+            # The contract with custom optimizers is that their step() should accept an additional,
+            # optional grad_scaler kwarg.  We append self to the kwargs so the custom optimizer has full information:
+            # it can query its own state, invoke unscale_ on itself, etc.
+            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler` argument
+            # to `Optimizer.step`. The new behavior will instead add two Tensor attributes, `grad_scale`
+            # and `found_inf`, to the passed optimizer so that the optimizer can use them
+            # to skip the parameter update or unscale gradients before updating parameters in
+            # the fused kernel, e.g. `FusedAdamMathFunctor`.
+            # Under the new behavior, `GradScaler._check_inf_per_device` is called here if the state is
+            # `OptState.READY`, while the method is expected to be called on the user's side, i.e. by their optimizers.
+            kwargs_ = kwargs
+            has_grad_scaler_kwarg = (
+                "grad_scaler" in inspect.signature(optimizer.step).parameters
+            )
+            if has_grad_scaler_kwarg:
+                warnings.warn(
+                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
+                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
+                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
+                    FutureWarning,
+                )
+                kwargs_.update({"grad_scaler": self})
+            else:
+                if optimizer_state["stage"] is OptState.READY:
+                    self._check_inf_per_device(optimizer)
+                scaler = self._get_scale_async()
+                assert scaler is not None
+                found_inf = cast(
+                    torch.Tensor,
+                    sum(
+                        [
+                            t.to(scaler.device, non_blocking=True)
+                            for t in optimizer_state["found_inf_per_device"].values()
+                        ]
+                    ),
+                )
+                optimizer.grad_scale = (  # type: ignore[attr-defined]
+                    None if optimizer_state["stage"] == OptState.UNSCALED else scaler
+                )
+                optimizer.found_inf = found_inf  # type: ignore[attr-defined]
+            retval = optimizer.step(*args, **kwargs_)
+            optimizer_state["stage"] = OptState.STEPPED
+            if not has_grad_scaler_kwarg:
+                del optimizer.grad_scale  # type: ignore[attr-defined]
+                del optimizer.found_inf  # type: ignore[attr-defined]
+            return retval
+
+        if optimizer_state["stage"] is OptState.READY:
+            self.unscale_(optimizer)
+
+        assert (
+            len(optimizer_state["found_inf_per_device"]) > 0
+        ), "No inf checks were recorded for this optimizer."
+
+        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
+
+        optimizer_state["stage"] = OptState.STEPPED
+
+        return retval
+
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
+        """Update the scale factor.
+
+        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
+        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
+        the scale is multiplied by ``growth_factor`` to increase it.
+
+        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
+        used directly; it is used to fill GradScaler's internal scale tensor. So if
+        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
+        affect the scale GradScaler uses internally.)
+
+        Args:
+            new_scale (float or :class:`torch.Tensor`, optional, default=None):  New scale factor.
+
+        .. warning::
+            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
+            been invoked for all optimizers used this iteration.
+
+        .. warning::
+            For performance reasons, we do not check the scale factor value to avoid synchronizations,
+            so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
+            you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
+            bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
+        """
+        if not self._enabled:
+            return
+
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")
+
+        if new_scale is not None:
+            assert self._scale is not None
+            # Accept a new user-defined scale.
+            if isinstance(new_scale, float):
+                self._scale.fill_(new_scale)
+            else:
+                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
+                    torch.FloatTensor with requires_grad=False."
+                assert new_scale.device.type == self._device, reason
+                assert new_scale.numel() == 1, reason
+                assert new_scale.requires_grad is False, reason
+                self._scale.copy_(new_scale)
+        else:
+            # Consume shared inf/nan data collected from optimizers to update the scale.
+            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
+            found_infs = [
+                found_inf.to(device=_scale.device, non_blocking=True)
+                for state in self._per_optimizer_states.values()
+                for found_inf in state["found_inf_per_device"].values()
+            ]
+
+            assert len(found_infs) > 0, "No inf checks were recorded prior to update."
+
+            found_inf_combined = found_infs[0]
+            if len(found_infs) > 1:
+                for i in range(1, len(found_infs)):
+                    found_inf_combined += found_infs[i]
+
+            torch._amp_update_scale_(
+                _scale,
+                _growth_tracker,
+                found_inf_combined,
+                self._growth_factor,
+                self._backoff_factor,
+                self._growth_interval,
+            )
+
+        # To prepare for next iteration, clear the data collected from optimizers this iteration.
+        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+
+    def _get_scale_async(self) -> Optional[torch.Tensor]:
+        return self._scale
+
+    def get_scale(self) -> float:
+        """Return a Python float containing the current scale, or 1.0 if scaling is disabled.
+
+        .. warning::
+            :meth:`get_scale` incurs a CPU-GPU sync.
+        """
+        if self._enabled:
+            return (
+                self._init_scale
+                if (scale := self._get_scale_async()) is None
+                else cast(float, scale.item())
+            )
+        return 1.0
+
+    def get_growth_factor(self) -> float:
+        r"""Return a Python float containing the scale growth factor."""
+        return self._growth_factor
+
+    def set_growth_factor(self, new_factor: float) -> None:
+        r"""Set a new scale growth factor.
+
+        Args:
+            new_factor (float):  Value to use as the new scale growth factor.
+        """
+        self._growth_factor = new_factor
+
+    def get_backoff_factor(self) -> float:
+        r"""Return a Python float containing the scale backoff factor."""
+        return self._backoff_factor
+
+    def set_backoff_factor(self, new_factor: float) -> None:
+        r"""Set a new scale backoff factor.
+
+        Args:
+            new_factor (float):  Value to use as the new scale backoff factor.
+        """
+        self._backoff_factor = new_factor
+
+    def get_growth_interval(self) -> int:
+        r"""Return a Python int containing the growth interval."""
+        return self._growth_interval
+
+    def set_growth_interval(self, new_interval: int) -> None:
+        r"""Set a new growth interval.
+
+        Args:
+            new_interval (int):  Value to use as the new growth interval.
+        """
+        self._growth_interval = new_interval
+
+    def _get_growth_tracker(self) -> int:
+        if self._enabled:
+            return (
+                self._init_growth_tracker
+                if self._growth_tracker is None
+                else cast(int, self._growth_tracker.item())
+            )
+        return 0
+
+    def is_enabled(self) -> bool:
+        r"""Return a bool indicating whether this instance is enabled."""
+        return self._enabled
+
+    def state_dict(self) -> Dict[str, Any]:
+        r"""Return the state of the scaler as a :class:`dict`.
+
+        It contains five entries:
+
+        * ``"scale"`` - a Python float containing the current scale
+        * ``"growth_factor"`` - a Python float containing the current growth factor
+        * ``"backoff_factor"`` - a Python float containing the current backoff factor
+        * ``"growth_interval"`` - a Python int containing the current growth interval
+        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
+
+        If this instance is not enabled, returns an empty dict.
+
+        .. note::
+           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
+           should be called after :meth:`update`.
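+
+        A checkpointing sketch (illustrative; ``model``, ``optimizer`` and ``PATH`` are placeholders)::
+
+            checkpoint = {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "scaler": scaler.state_dict(),
+            }
+            torch.save(checkpoint, PATH)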
+        """
+        if self._enabled:
+            return {
+                "scale": self.get_scale(),
+                "growth_factor": self._growth_factor,
+                "backoff_factor": self._backoff_factor,
+                "growth_interval": self._growth_interval,
+                "_growth_tracker": self._get_growth_tracker(),
+            }
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        r"""Load the scaler state.
+
+        If this instance is disabled, :meth:`load_state_dict` is a no-op.
+
+        Args:
+           state_dict(dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
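+
+        A restoring sketch (illustrative; ``PATH`` is a placeholder for a checkpoint
+        saved with :meth:`state_dict`, as in the sketch there)::
+
+            checkpoint = torch.load(PATH)
+            scaler.load_state_dict(checkpoint["scaler"])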
+        """
+        if not self._enabled:
+            return
+
+        if len(state_dict) == 0:
+            raise RuntimeError(
+                "The source state dict is empty, possibly because it was saved "
+                "from a disabled instance of GradScaler."
+            )
+
+        self._init_scale = cast(float, state_dict["scale"])
+        if self._scale is not None:
+            self._scale.fill_(state_dict["scale"])
+        self._growth_factor = cast(float, state_dict["growth_factor"])
+        self._backoff_factor = cast(float, state_dict["backoff_factor"])
+        self._growth_interval = cast(int, state_dict["growth_interval"])
+        self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
+        if self._growth_tracker is not None:
+            self._growth_tracker.fill_(state_dict["_growth_tracker"])
+
+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        if self._enabled:
+            assert len(self._per_optimizer_states) == 0, (
+                "A GradScaler instance may only be pickled at the beginning "
+                "of an iteration, or at the end after scaler.update()."
+            )
+            # Pickling _scale and _growth_tracker Tensors directly triggers
+            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
+            # so instead, we set the unpickled instance up to reinitialize them lazily.
+            state["_init_scale"] = self.get_scale()
+            state["_init_growth_tracker"] = self._get_growth_tracker()
+            state["_scale"] = None
+            state["_growth_tracker"] = None
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+
+    def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
+        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
+
+        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
+
+        self._per_optimizer_states[id(optimizer)][
+            "found_inf_per_device"
+        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
+
+        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
+
+    def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
+        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
diff --git a/MLPY/Lib/site-packages/torch/ao/__init__.py b/MLPY/Lib/site-packages/torch/ao/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d07d31dfa156371aceefffd993bb34afb7042db
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/__init__.py
@@ -0,0 +1,16 @@
+# torch.ao is a package with a lot of interdependencies.
+# We will use lazy import to avoid cyclic dependencies here.
+
+
+__all__ = [
+    "nn",
+    "ns",
+    "quantization",
+    "pruning",
+]
+
+def __getattr__(name):
+    if name in __all__:
+        import importlib
+        return importlib.import_module("." + name, __name__)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
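+
+# Note: with a module-level __getattr__ (PEP 562), a subpackage such as
+# torch.ao.quantization is imported lazily on first attribute access (unless it
+# was already imported), instead of eagerly when `import torch.ao` runs.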
diff --git a/MLPY/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33c6e4c4f69972d81ecc4fff7891838b6f7c15eb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..126fe0e0ce0589f3d069296d642445c596024ab6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/__init__.py
@@ -0,0 +1,19 @@
+# We are exposing all subpackages to the end-user.
+# Because of possible inter-dependencies, we want to avoid cyclic
+# imports, so we implement a lazy version as per
+# https://peps.python.org/pep-0562/
+
+import importlib
+
+__all__ = [
+    "intrinsic",
+    "qat",
+    "quantizable",
+    "quantized",
+    "sparse",
+]
+
+def __getattr__(name):
+    if name in __all__:
+        return importlib.import_module("." + name, __name__)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd12db298c1797de70b121da5f62a0b29bf940ae
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19abcc6c5a918bc1cf45618ee4aaa02631cdb11e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__init__.py
@@ -0,0 +1,36 @@
+from .modules import *  # noqa: F403
+from .modules.fused import _FusedModule  # noqa: F403
+
+# # Subpackages
+# from . import qat  # noqa: F403
+# from . import quantized  # noqa: F403
+
+__all__ = [
+    'ConvBn1d',
+    'ConvBn2d',
+    'ConvBn3d',
+    'ConvBnReLU1d',
+    'ConvBnReLU2d',
+    'ConvBnReLU3d',
+    'ConvReLU1d',
+    'ConvReLU2d',
+    'ConvReLU3d',
+    'LinearReLU',
+    'BNReLU2d',
+    'BNReLU3d',
+    'LinearBn1d',
+    'LinearLeakyReLU',
+    'LinearTanh',
+    'ConvAdd2d',
+    'ConvAddReLU2d',
+]
+
+# We are exposing all subpackages to the end-user.
+# Because of possible inter-dependencies, we want to avoid
+# cyclic imports, so we implement lazy imports
+# as per https://peps.python.org/pep-0562/
+def __getattr__(name):
+    if name in __all__:
+        import importlib
+        return importlib.import_module("." + name, __name__)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e073ed05116639d374aa96434d646fbafbcac1b8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3256e90e3e1c94f0bba202bf1741adce7e072f23
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__init__.py
@@ -0,0 +1,38 @@
+from .fused import _FusedModule  # noqa: F401
+from .fused import ConvBn1d
+from .fused import ConvBn2d
+from .fused import ConvBn3d
+from .fused import ConvBnReLU1d
+from .fused import ConvBnReLU2d
+from .fused import ConvBnReLU3d
+from .fused import ConvReLU1d
+from .fused import ConvReLU2d
+from .fused import ConvReLU3d
+from .fused import LinearReLU
+from .fused import BNReLU2d
+from .fused import BNReLU3d
+from .fused import LinearBn1d
+from .fused import LinearLeakyReLU
+from .fused import LinearTanh
+from .fused import ConvAdd2d
+from .fused import ConvAddReLU2d
+
+__all__ = [
+    'ConvBn1d',
+    'ConvBn2d',
+    'ConvBn3d',
+    'ConvBnReLU1d',
+    'ConvBnReLU2d',
+    'ConvBnReLU3d',
+    'ConvReLU1d',
+    'ConvReLU2d',
+    'ConvReLU3d',
+    'LinearReLU',
+    'BNReLU2d',
+    'BNReLU3d',
+    'LinearBn1d',
+    'LinearLeakyReLU',
+    'LinearTanh',
+    'ConvAdd2d',
+    'ConvAddReLU2d',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..907df24381036865500530497fa7ad89509a805b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0d282ff392c7f7891f7f9656f937ae5a31ccb68
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..36285ab3d4cc107dd23ea3a49617c5b56e4cc366
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/modules/fused.py
@@ -0,0 +1,160 @@
+import torch
+from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
+           'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
+           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']
+
+# Used for identifying intrinsic modules used in quantization
+class _FusedModule(torch.nn.Sequential):
+    pass
+
+class ConvReLU1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv1d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, relu):
+        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
+        super().__init__(conv, relu)
+
+class ConvReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, relu):
+        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
+        super().__init__(conv, relu)
+
+class ConvReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv3d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, relu):
+        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}'
+        super().__init__(conv, relu)
+
+class LinearReLU(_FusedModule):
+    r"""This is a sequential container which calls the Linear and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, linear, relu):
+        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU, \
+            'Incorrect types for input modules{}{}'.format(
+                type_before_parametrizations(linear), type_before_parametrizations(relu))
+        super().__init__(linear, relu)
+
+class ConvBn1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn):
+        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
+        super().__init__(conv, bn)
+
+class ConvBn2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn):
+        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
+        super().__init__(conv, bn)
+
+class ConvBnReLU1d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn, relu):
+        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and \
+            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
+            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
+        super().__init__(conv, bn, relu)
+
+class ConvBnReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn, relu):
+        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and \
+            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
+            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
+        super().__init__(conv, bn, relu)
+
+class ConvBn3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn):
+        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d, \
+            f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}'
+        super().__init__(conv, bn)
+
+class ConvBnReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn, relu):
+        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and \
+            type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
+            .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu))
+        super().__init__(conv, bn, relu)
+
+
+class BNReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, batch_norm, relu):
+        assert type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU, \
+            'Incorrect types for input modules{}{}'.format(
+                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
+        super().__init__(batch_norm, relu)
+
+class BNReLU3d(_FusedModule):
+    r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, batch_norm, relu):
+        assert type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU, \
+            'Incorrect types for input modules{}{}'.format(
+                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
+        super().__init__(batch_norm, relu)
+
+
+class LinearBn1d(_FusedModule):
+    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, linear, bn):
+        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d, \
+            f'Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}'
+        super().__init__(linear, bn)
+
+class LinearLeakyReLU(_FusedModule):
+    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, linear, leaky_relu):
+        assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, \
+            f'Incorrect types for input modules{type(linear)}{type(leaky_relu)}'
+        super().__init__(linear, leaky_relu)
+
+class LinearTanh(_FusedModule):
+    r"""This is a sequential container which calls the Linear and Tanh modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, linear, tanh):
+        assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \
+            f'Incorrect types for input modules{type(linear)}{type(tanh)}'
+        super().__init__(linear, tanh)
+
+class ConvAdd2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d modules with extra Add.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, add):
+        super().__init__(conv)
+        self.add = add
+
+    def forward(self, x1, x2):
+        return self.add(self[0](x1), x2)
+
+class ConvAddReLU2d(_FusedModule):
+    r"""This is a sequential container which calls the Conv2d, add, Relu.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, add, relu):
+        super().__init__(conv)
+        self.add = add
+        self.relu = relu
+
+    def forward(self, x1, x2):
+        return self.relu(self.add(self[0](x1), x2))
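+
+# Editor's note: an illustrative sketch (not part of the upstream file). The
+# containers above are thin nn.Sequential wrappers used as fusion markers, so
+# they can be built directly from ordinary float modules and run eagerly. The
+# callable passed as ConvAdd2d's `add` argument is assumed here to be torch.add.
+if __name__ == "__main__":
+    fused = ConvBnReLU2d(Conv2d(3, 8, 3, padding=1), BatchNorm2d(8), ReLU())
+    x = torch.randn(1, 3, 16, 16)
+    assert fused(x).shape == (1, 8, 16, 16)    # behaves like nn.Sequential(conv, bn, relu)
+
+    # ConvAdd2d keeps the second operand out of the Sequential chain and applies
+    # the supplied binary op inside forward().
+    conv_add = ConvAdd2d(Conv2d(3, 8, 3, padding=1), torch.add)
+    assert conv_add(x, fused(x)).shape == (1, 8, 16, 16)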
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb27d182b411f4afab4c5012c8c31fca410ee8d5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..673e1d6d401d3dde3f49863751c122d1de786e82
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py
@@ -0,0 +1,31 @@
+from .linear_relu import LinearReLU
+from .linear_fused import LinearBn1d
+from .conv_fused import (
+    ConvBn1d,
+    ConvBn2d,
+    ConvBn3d,
+    ConvBnReLU1d,
+    ConvBnReLU2d,
+    ConvBnReLU3d,
+    ConvReLU1d,
+    ConvReLU2d,
+    ConvReLU3d,
+    update_bn_stats,
+    freeze_bn_stats,
+)
+
+__all__ = [
+    "LinearReLU",
+    "LinearBn1d",
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+    "ConvBn1d",
+    "ConvBn2d",
+    "ConvBn3d",
+    "ConvBnReLU1d",
+    "ConvBnReLU2d",
+    "ConvBnReLU3d",
+    "update_bn_stats",
+    "freeze_bn_stats",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..557a13bcacc4c6f335bdd06fca42a1822dc29352
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5b21a1166ff83ddcaecfb4513b0309b8ade9700
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..553ef90a6ce481bf0553253cd9f1be9959d3b428
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7339710e6e4e4f84876c5b9e208486c1fa17ba1c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd7632734f137998df1e07372ab29f447dd8e2ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py
@@ -0,0 +1,825 @@
+import math
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.qat as nnqat
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.utils import fuse_conv_bn_weights
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.nn.parameter import Parameter
+from typing import TypeVar
+
+__all__ = ['ConvBn1d', 'ConvBnReLU1d', 'ConvReLU1d', 'ConvBn2d', 'ConvBnReLU2d', 'ConvReLU2d', 'ConvBn3d',
+           'ConvBnReLU3d', 'ConvReLU3d', 'update_bn_stats', 'freeze_bn_stats']
+_BN_CLASS_MAP = {
+    1: nn.BatchNorm1d,
+    2: nn.BatchNorm2d,
+    3: nn.BatchNorm3d,
+}
+
+
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
+
+
+class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
+
+    _version = 2
+    _FLOAT_MODULE = MOD
+
+    def __init__(self,
+                 # ConvNd args
+                 in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, transposed, output_padding,
+                 groups,
+                 bias,
+                 padding_mode,
+                 # BatchNormNd args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None,
+                 dim=2):
+        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
+                                         stride, padding, dilation, transposed,
+                                         output_padding, groups, False, padding_mode)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.freeze_bn = freeze_bn if self.training else True
+        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
+        self.weight_fake_quant = self.qconfig.weight()
+        if bias:
+            self.bias = Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_bn_parameters()
+
+        # this needs to be called after reset_bn_parameters,
+        # as they modify the same state
+        if self.training:
+            if freeze_bn:
+                self.freeze_bn_stats()
+            else:
+                self.update_bn_stats()
+        else:
+            self.freeze_bn_stats()
+
+        self._enable_slow_path_for_better_numerical_stability = False
+
+    def reset_running_stats(self):
+        self.bn.reset_running_stats()
+
+    def reset_bn_parameters(self):
+        self.bn.reset_running_stats()
+        init.uniform_(self.bn.weight)
+        init.zeros_(self.bn.bias)
+        # note: below is actually for conv, not BN
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+    def reset_parameters(self):
+        super().reset_parameters()
+
+    def update_bn_stats(self):
+        self.freeze_bn = False
+        self.bn.training = True
+        return self
+
+    def freeze_bn_stats(self):
+        self.freeze_bn = True
+        self.bn.training = False
+        return self
+
+    def _forward(self, input):
+        if self._enable_slow_path_for_better_numerical_stability:
+            return self._forward_slow(input)
+        return self._forward_approximate(input)
+
+    def _forward_approximate(self, input):
+        """Approximated method to fuse conv and bn. It requires only one forward pass.
+        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
+        """
+        assert self.bn.running_var is not None
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+        # using zero bias here since the bias for original conv
+        # will be added later
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
+        else:
+            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
+        conv = self._conv_forward(input, scaled_weight, zero_bias)
+        conv_orig = conv / scale_factor.reshape(bias_shape)
+        if self.bias is not None:
+            conv_orig = conv_orig + self.bias.reshape(bias_shape)
+        conv = self.bn(conv_orig)
+        return conv
+
+    def _forward_slow(self, input):
+        """
+        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
+        It requires two forward passes but handles the case bn.weight == 0
+
+        Conv: Y = WX + B_c
+        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
+
+        Batch statistics:
+          mean_Y = Y.mean()
+                 = Y0.mean() + B_c
+          var_Y = (Y - mean_Y)^2.mean()
+                = (Y0 - Y0.mean())^2.mean()
+        BN (r: bn.weight, beta: bn.bias):
+          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
+            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
+
+        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
+          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
+            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
+
+        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
+          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
+
+        QAT with fused conv bn:
+          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
+          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
+        """
+
+        assert self.bn.running_var is not None
+        assert self.bn.running_mean is not None
+
+        # using zero bias here since the bias for original conv
+        # will be added later
+        zero_bias = torch.zeros(self.out_channels, device=self.weight.device, dtype=input.dtype)
+
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+
+        if self.bn.training:
+            # needed to compute batch mean/std
+            conv_out = self._conv_forward(input, self.weight, zero_bias)
+            # update bn statistics
+            with torch.no_grad():
+                conv_out_bias = (
+                    conv_out if self.bias is None else conv_out + self.bias.reshape(bias_shape)
+                )
+                self.bn(conv_out_bias)
+
+        # fused conv + bn without bias using bn running statistics
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        scaled_weight = self.weight_fake_quant(
+            self.weight * scale_factor.reshape(weight_shape)
+        )
+        # fused conv without bias for inference: (r * W / running_std) * X
+        conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
+
+        if self.bn.training:
+            avg_dims = [0] + list(range(2, len(self.weight.shape)))
+            batch_mean = conv_out.mean(avg_dims)  # type: ignore[possibly-undefined]
+            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
+                avg_dims
+            )
+            batch_std = torch.sqrt(batch_var + self.bn.eps)
+
+            # scale to use batch std in training mode
+            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
+            unscale_factor = running_std / batch_std
+            conv_bn *= unscale_factor.reshape(bias_shape)
+
+            fused_mean = batch_mean
+            fused_std = batch_std
+        else:
+            fused_mean = self.bn.running_mean - (self.bias if self.bias is not None else 0)
+            fused_std = running_std
+
+        # fused bias = beta - r * mean / std
+        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
+        conv_bn += fused_bias.reshape(bias_shape)
+
+        # HACK to let conv bias participate in loss to avoid DDP error (parameters
+        #   were not used in producing loss)
+        if self.bias is not None:
+            conv_bn += (self.bias - self.bias).reshape(bias_shape)
+
+        return conv_bn
+
+    def extra_repr(self):
+        # TODO(jerryzh): extend
+        return super().extra_repr()
+
+    def forward(self, input):
+        return self._forward(input)
+
+    def train(self, mode=True):
+        """
+        Batchnorm's training behavior is controlled by the self.training flag. Prevent
+        changing it if BN is frozen. This makes sure that calling `model.train()`
+        on a model with a frozen BN will behave properly.
+        """
+        self.training = mode
+        if not self.freeze_bn:
+            for module in self.children():
+                module.train(mode)
+        return self
+
+    # ===== Serialization version history =====
+    #
+    # Version 1/None
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- gamma : Tensor
+    #   |--- beta : Tensor
+    #   |--- running_mean : Tensor
+    #   |--- running_var : Tensor
+    #   |--- num_batches_tracked : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- bn : Module
+    #        |--- weight : Tensor (moved from v1.self.gamma)
+    #        |--- bias : Tensor (moved from v1.self.beta)
+    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
+    #        |--- running_var : Tensor (moved from v1.self.running_var)
+    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if version is None or version == 1:
+            # BN related parameters and buffers were moved into the BN module for v2
+            v2_to_v1_names = {
+                'bn.weight': 'gamma',
+                'bn.bias': 'beta',
+                'bn.running_mean': 'running_mean',
+                'bn.running_var': 'running_var',
+                'bn.num_batches_tracked': 'num_batches_tracked',
+            }
+            for v2_name, v1_name in v2_to_v1_names.items():
+                if prefix + v1_name in state_dict:
+                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
+                    state_dict.pop(prefix + v1_name)
+                elif prefix + v2_name in state_dict:
+                    # there was a brief period where forward compatibility
+                    # for this module was broken (between
+                    # https://github.com/pytorch/pytorch/pull/38478
+                    # and https://github.com/pytorch/pytorch/pull/38820)
+                    # and modules emitted the v2 state_dict format while
+                    # specifying that version == 1. This patches the forward
+                    # compatibility issue by allowing the v2 style entries to
+                    # be used.
+                    pass
+                elif strict:
+                    missing_keys.append(prefix + v2_name)
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module or qparams_dict
+
+            Args: `mod` a float module, either produced by torch.ao.quantization utilities
+            or directly from the user
+        """
+        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
+        # has no __name__ (code is fine though)
+        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        qconfig = mod.qconfig
+        conv, bn = mod[0], mod[1]
+        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
+                         conv.stride, conv.padding, conv.dilation,
+                         conv.groups, conv.bias is not None,
+                         conv.padding_mode,
+                         bn.eps, bn.momentum,
+                         False,
+                         qconfig)
+        qat_convbn.weight = conv.weight
+        qat_convbn.bias = conv.bias
+        qat_convbn.bn.weight = bn.weight
+        qat_convbn.bn.bias = bn.bias
+        qat_convbn.bn.running_mean = bn.running_mean
+        qat_convbn.bn.running_var = bn.running_var
+        # mypy error: Cannot determine type of 'num_batches_tracked'
+        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[has-type]
+        return qat_convbn
+
+    def to_float(self):
+        cls = type(self)
+        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.bias is not None,
+            self.padding_mode)
+        conv.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            conv.bias = torch.nn.Parameter(self.bias.detach())
+
+        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
+            # fuse bn into conv
+            assert self.bn.running_var is not None and self.bn.running_mean is not None
+            conv.weight, conv.bias = fuse_conv_bn_weights(
+                conv.weight,
+                conv.bias,
+                self.bn.running_mean,
+                self.bn.running_var,
+                self.bn.eps,
+                self.bn.weight,
+                self.bn.bias
+            )
+
+        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
+            modules = []
+            modules.append(conv)
+            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
+            modules.append(relu)
+            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
+            conv_relu.train(self.training)
+            return conv_relu
+        else:
+            conv.train(self.training)
+            return conv
+
+class ConvBn1d(_ConvBnNd, nn.Conv1d):
+    r"""
+    A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d` and
+    :class:`torch.nn.BatchNorm1d`.
+
+    Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_BN_MODULE = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE: None = None
+    _FLOAT_MODULE = nni.ConvBn1d
+    _FLOAT_CONV_MODULE = nn.Conv1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+                           padding, dilation, False, _single(0), groups, bias, padding_mode,
+                           eps, momentum, freeze_bn, qconfig, dim=1)
+
+class ConvBnReLU1d(ConvBn1d):
+    r"""
+    A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv1d` and
+    :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    # base class defines _FLOAT_MODULE as "ConvBn1d"
+    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_BN_MODULE = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU1d
+
+    def __init__(self,
+                 # Conv1d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm1d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias,
+                         padding_mode, eps, momentum,
+                         freeze_bn,
+                         qconfig)
+
+    def forward(self, input):
+        return F.relu(ConvBn1d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
+    r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv1d` and
+    :class:`~torch.nn.BatchNorm1d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU1d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_BN_MODULE: None = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='zeros',
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride=stride, padding=padding, dilation=dilation,
+                         groups=groups, bias=bias, padding_mode=padding_mode,
+                         qconfig=qconfig)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+class ConvBn2d(_ConvBnNd, nn.Conv2d):
+    r"""
+    A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv2d` and
+    :class:`torch.nn.BatchNorm2d`.
+
+    Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBn2d
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE: None = None
+
+    def __init__(self,
+                 # ConvNd args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm2d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
+                           padding, dilation, False, _pair(0), groups, bias, padding_mode,
+                           eps, momentum, freeze_bn, qconfig, dim=2)
+
+class ConvBnReLU2d(ConvBn2d):
+    r"""
+    A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv2d` and
+    :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    # base class defines _FLOAT_MODULE as "ConvBn2d"
+    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU2d
+
+    def __init__(self,
+                 # Conv2d args
+                 in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=None,
+                 padding_mode='zeros',
+                 # BatchNorm2d args
+                 # num_features: out_channels
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias,
+                         padding_mode, eps, momentum,
+                         freeze_bn,
+                         qconfig)
+
+    def forward(self, input):
+        return F.relu(ConvBn2d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
+    r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv2d` and
+    :class:`~torch.nn.BatchNorm2d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_BN_MODULE: None = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1,
+                 bias=True, padding_mode='zeros',
+                 qconfig=None):
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride=stride, padding=padding, dilation=dilation,
+                         groups=groups, bias=bias, padding_mode=padding_mode,
+                         qconfig=qconfig)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+class ConvBn3d(_ConvBnNd, nn.Conv3d):
+    r"""
+    A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv3d` and
+    :class:`torch.nn.BatchNorm3d`.
+
+    Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBn3d
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE: None = None
+
+    def __init__(
+        self,
+        # ConvNd args
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=None,
+        padding_mode="zeros",
+        # BatchNorm3d args
+        # num_features: out_channels
+        eps=1e-05,
+        momentum=0.1,
+        # affine: True
+        # track_running_stats: True
+        # Args for this module
+        freeze_bn=False,
+        qconfig=None,
+    ):
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        _ConvBnNd.__init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            False,
+            _triple(0),
+            groups,
+            bias,
+            padding_mode,
+            eps,
+            momentum,
+            freeze_bn,
+            qconfig,
+            dim=3,
+        )
+
+class ConvBnReLU3d(ConvBn3d):
+    r"""
+    A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
+    attached with FakeQuantize modules for weight,
+    used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Conv3d` and
+    :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
+
+    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    # module class after fusing bn into conv
+    _FUSED_FLOAT_MODULE = nni.ConvReLU3d
+
+    def __init__(
+        self,
+        # Conv3d args
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=None,
+        padding_mode="zeros",
+        # BatchNorm3d args
+        # num_features: out_channels
+        eps=1e-05,
+        momentum=0.1,
+        # affine: True
+        # track_running_stats: True
+        # Args for this module
+        freeze_bn=False,
+        qconfig=None,
+    ):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            eps,
+            momentum,
+            freeze_bn,
+            qconfig,
+        )
+
+    def forward(self, input):
+        return F.relu(ConvBn3d._forward(self, input))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
+    r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
+    FakeQuantize modules for weight for
+    quantization aware training.
+
+    We combined the interface of :class:`~torch.nn.Conv3d` and
+    :class:`~torch.nn.BatchNorm3d`.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    _FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
+    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_BN_MODULE: None = None
+    _FLOAT_RELU_MODULE = nn.ReLU
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+        qconfig=None,
+    ):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            qconfig=qconfig,
+        )
+        assert qconfig, "qconfig must be provided for QAT module"
+        self.qconfig = qconfig
+        self.weight_fake_quant = self.qconfig.weight()
+
+    def forward(self, input):
+        return F.relu(
+            self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+        )
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+def update_bn_stats(mod):
+    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
+        mod.update_bn_stats()
+
+def freeze_bn_stats(mod):
+    if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
+        mod.freeze_bn_stats()
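+
+# Editor's note: an illustrative sketch (not part of the upstream file) of the
+# intended QAT flow for the classes above. It assumes the usual helper
+# torch.ao.quantization.get_default_qat_qconfig; any QConfig whose weight()
+# returns a fake-quant module would work the same way.
+if __name__ == "__main__":
+    from torch.ao.quantization import get_default_qat_qconfig
+
+    # Float fused module (the _FLOAT_MODULE expected by from_float) ...
+    float_mod = nni.ConvBnReLU2d(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
+    float_mod.qconfig = get_default_qat_qconfig("fbgemm")
+
+    # ... swapped for its QAT counterpart, which fake-quantizes the bn-scaled weight.
+    qat_mod = ConvBnReLU2d.from_float(float_mod)
+    out = qat_mod(torch.randn(2, 3, 16, 16))
+    assert out.shape == (2, 8, 16, 16)
+
+    # Near the end of training, batch-norm statistics are typically frozen:
+    qat_mod.apply(freeze_bn_stats)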
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
new file mode 100644
index 0000000000000000000000000000000000000000..abcbfdcb2a38ea9d5a2f46c33133cf3ae57ece84
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.parameter import Parameter
+from torch.nn.utils.fusion import fuse_linear_bn_weights
+
+__all__ = [
+    "LinearBn1d",
+]
+
+class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
+    r"""
+    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
+    with FakeQuantize modules for weight, used in quantization aware training.
+
+    We combined the interface of :class:`torch.nn.Linear` and
+    :class:`torch.nn.BatchNorm1d`.
+
+    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
+    to default.
+
+    Attributes:
+        freeze_bn:
+        weight_fake_quant: fake quant module for weight
+
+    """
+    def __init__(self,
+                 # Linear args
+                 in_features, out_features, bias=True,
+                 # BatchNorm1d args
+                 # num_features: out_features
+                 eps=1e-05, momentum=0.1,
+                 # affine: True
+                 # track_running_stats: True
+                 # Args for this module
+                 freeze_bn=False,
+                 qconfig=None):
+        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.freeze_bn = freeze_bn if self.training else True
+        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
+        self.weight_fake_quant = self.qconfig.weight()
+        if bias:
+            self.bias = Parameter(torch.empty(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_bn_parameters()
+
+        # this needs to be called after reset_bn_parameters,
+        # as they modify the same state
+        if self.training:
+            if freeze_bn:
+                self.freeze_bn_stats()
+            else:
+                self.update_bn_stats()
+        else:
+            self.freeze_bn_stats()
+
+    def reset_running_stats(self):
+        self.bn.reset_running_stats()
+
+    def reset_bn_parameters(self):
+        self.bn.reset_running_stats()
+        init.uniform_(self.bn.weight)
+        init.zeros_(self.bn.bias)
+
+    def reset_parameters(self):
+        super().reset_parameters()
+
+    def update_bn_stats(self):
+        self.freeze_bn = False
+        self.bn.training = True
+        return self
+
+    def freeze_bn_stats(self):
+        self.freeze_bn = True
+        self.bn.training = False
+        return self
+
+    def forward(self, input):
+        assert self.bn.running_var is not None
+
+        # Scale the linear weights by BN's running statistics to reduce
+        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
+        # for motivation.
+        #
+        # Instead of
+        #
+        #   x1 = F.linear(x0, fq(w), b)
+        #   x2 = self.bn(x1)
+        #
+        # We have
+        #
+        #   # scale the weight by previous batch's running statistics
+        #   scale_factor = bn.w / bn.running_std_from_prev_batch
+        #   # do the linear transformation without bias
+        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
+        #   # reverse the scaling and add original bias
+        #   x1_orig = x1_scaled / scale_factor + b
+        #   x2 = self.bn(x1_orig)
+
+        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
+        scale_factor = self.bn.weight / running_std
+        weight_shape = [1] * len(self.weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(self.weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
+        if self.bias is not None:
+            zero_bias = torch.zeros_like(self.bias)
+        else:
+            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
+        linear_out = F.linear(input, scaled_weight, zero_bias)
+        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
+        if self.bias is not None:
+            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
+        bn_out = self.bn(linear_out_orig)
+        return bn_out
+
+    def train(self, mode=True):
+        """
+        Batchnorm's training behavior is controlled by the self.training flag. Prevent
+        changing it if BN is frozen. This makes sure that calling `model.train()`
+        on a model with a frozen BN will behave properly.
+        """
+        self.training = mode
+        if not self.freeze_bn:
+            for module in self.children():
+                module.train(mode)
+        return self
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module or qparams_dict
+
+            Args: `mod` a float module, either produced by torch.ao.quantization
+            utilities or directly from the user
+        """
+        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
+            '.from_float only works for ' + nni.LinearBn1d.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        qconfig = mod.qconfig
+        linear, bn = mod[0], mod[1]
+        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
+                           bn.eps, bn.momentum,
+                           False, qconfig)
+        qat_linearbn.weight = linear.weight
+        qat_linearbn.bias = linear.bias
+        qat_linearbn.bn.weight = bn.weight
+        qat_linearbn.bn.bias = bn.bias
+        qat_linearbn.bn.running_mean = bn.running_mean
+        qat_linearbn.bn.running_var = bn.running_var
+        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
+        return qat_linearbn
+
+    def to_float(self):
+        linear = torch.nn.Linear(self.in_features, self.out_features)
+        assert self.bn.running_var is not None and self.bn.running_mean is not None
+        linear.weight, linear.bias = fuse_linear_bn_weights(
+            self.weight,
+            self.bias,
+            self.bn.running_mean,
+            self.bn.running_var,
+            self.bn.eps,
+            self.bn.weight,
+            self.bn.bias)
+        return linear
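+
+# Editor's note: an illustrative sketch (not part of the upstream file) showing
+# the from_float / to_float round trip for LinearBn1d. It assumes the usual
+# helper torch.ao.quantization.get_default_qat_qconfig for the qconfig.
+if __name__ == "__main__":
+    from torch.ao.quantization import get_default_qat_qconfig
+
+    float_mod = nni.LinearBn1d(nn.Linear(16, 32), nn.BatchNorm1d(32))
+    float_mod.qconfig = get_default_qat_qconfig("fbgemm")
+
+    qat_mod = LinearBn1d.from_float(float_mod)     # QAT module defined above
+    out = qat_mod(torch.randn(8, 16))              # weight scaled by bn stats, then fake-quantized
+    assert out.shape == (8, 32)
+
+    fused_linear = qat_mod.to_float()              # bn folded back into a plain nn.Linear
+    assert isinstance(fused_linear, torch.nn.Linear)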
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d28073322abd3d0c0cc61636466d9e50fb80ce7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py
@@ -0,0 +1,48 @@
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+
+class LinearReLU(nnqat.Linear, nni._FusedModule):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules, attached with
+    FakeQuantize modules for weight, used in
+    quantization aware training.
+
+    We adopt the same interface as :class:`torch.nn.Linear`.
+
+    Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = nn.qat.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True,
+                 qconfig=None):
+        super().__init__(in_features, out_features, bias, qconfig)
+
+    def forward(self, input):
+        return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    def to_float(self):
+        linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
+        linear.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            linear.bias = torch.nn.Parameter(self.bias.detach())
+        relu = torch.nn.ReLU()
+        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f18b3aa317a68bcef55db3b0b837e83224833b23
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py
@@ -0,0 +1,14 @@
+from .modules import *  # noqa: F403
+
+__all__ = [
+    'BNReLU2d',
+    'BNReLU3d',
+    'ConvReLU1d',
+    'ConvReLU2d',
+    'ConvReLU3d',
+    'LinearReLU',
+    'LinearLeakyReLU',
+    'LinearTanh',
+    'ConvAdd2d',
+    'ConvAddReLU2d',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db171c4562b5809419294c6c01d78b7e7fa9f9dd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3549cf40d305aa16f62c6f7f28ef7465e0cc09a1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1832b330e841f595524b3d83b015a4e4795deda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py
@@ -0,0 +1,6 @@
+import torch
+from .linear_relu import LinearReLU
+
+__all__ = [
+    'LinearReLU',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f5844282a5fc0e6f0f5298227cd968f73ebeb51
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47e726d5a9de42f5a509ce2437cccc74795e6db2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e7c02eec9161d9d1d78f06cfde43e4818e07146
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py
@@ -0,0 +1,55 @@
+import torch
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.intrinsic as nni
+
+__all__ = [
+    "LinearReLU"
+]
+
+class LinearReLU(nnqd.Linear):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules that can be used
+    for dynamic quantization.
+    Supports both FP16 and INT8 quantization.
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.dynamic.Linear
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._packed_params.dtype == torch.qint8:
+            # TODO check if we should set reduce_range = True by default here
+            Y = torch.ops.quantized.linear_relu_dynamic(
+                x, self._packed_params._packed_params, reduce_range=True)
+        elif self._packed_params.dtype == torch.float16:
+            Y = torch.ops.quantized.linear_relu_dynamic_fp16(
+                x, self._packed_params._packed_params)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!')
+        return Y.to(x.dtype)
+
+    def _get_name(self):
+        return 'DynamicQuantizedLinearReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qlinear_relu):
+        return super().from_reference(ref_qlinear_relu[0])
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa94c98c12d6de780f0c4f8688f258d5179ff959
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py
@@ -0,0 +1,17 @@
+from .linear_relu import LinearReLU, LinearLeakyReLU, LinearTanh
+from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
+from .bn_relu import BNReLU2d, BNReLU3d
+from .conv_add import ConvAdd2d, ConvAddReLU2d
+
+__all__ = [
+    'LinearReLU',
+    'ConvReLU1d',
+    'ConvReLU2d',
+    'ConvReLU3d',
+    'BNReLU2d',
+    'BNReLU3d',
+    'LinearLeakyReLU',
+    'LinearTanh',
+    'ConvAdd2d',
+    'ConvAddReLU2d',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6045a6f089769c206c0f384acb761d92505add2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c13c0ee110cb71a2f28355e1c873b9f9d91d9a8c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b71e9e66c7519528b6bf78b20395d7920d62efc1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cb2503f047bbcadf63fe4e63b5c4eb895fd0ade
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bc43786fa64b0a5a32983ed2322580055f48ec7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e6bfe9e52e28c30569bd758958b75965a1affb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py
@@ -0,0 +1,82 @@
+
+import torch
+import torch.ao.nn.intrinsic
+import torch.ao.nn.intrinsic.qat
+import torch.ao.nn.quantized as nnq
+
+__all__ = [
+    "BNReLU2d",
+    "BNReLU3d"
+]
+
+class BNReLU2d(nnq.BatchNorm2d):
+    r"""
+    A BNReLU2d module is a fused module of BatchNorm2d and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.BatchNorm2d
+
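+    Example (an illustrative sketch; it assumes a per-tensor quantized
+    ``(N, C, H, W)`` input such as one produced by ``torch.quantize_per_tensor``)::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.BNReLU2d(32)
+        >>> x = torch.quantize_per_tensor(torch.randn(1, 32, 8, 8), 0.1, 0, torch.quint8)
+        >>> y = m(x)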
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
+        super().__init__(num_features, eps=eps, momentum=momentum, device=device, dtype=dtype)
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        return torch.ops.quantized.batch_norm2d_relu(
+            input, self.weight, self.bias, self.running_mean,
+            self.running_var, self.eps, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedBNReLU2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        # TODO: Add qat support for BNReLU2d
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, bn_relu, output_scale, output_zero_point):
+        return super().from_reference(bn_relu[0], output_scale, output_zero_point)
+
+class BNReLU3d(nnq.BatchNorm3d):
+    r"""
+    A BNReLU3d module is a fused module of BatchNorm3d and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.BatchNorm3d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
+        super().__init__(num_features, eps=eps, momentum=momentum, device=device, dtype=dtype)
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        return torch.ops.quantized.batch_norm3d_relu(
+            input, self.weight, self.bias, self.running_mean,
+            self.running_var, self.eps, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedBNReLU3d'
+
+    @classmethod
+    def from_float(cls, mod):
+        # TODO: Add qat support for BNReLU3d
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, bn_relu, output_scale, output_zero_point):
+        return super().from_reference(bn_relu[0], output_scale, output_zero_point)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..632dd1832af380fd74d01dba4768fa2e5c154ca9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py
@@ -0,0 +1,93 @@
+import torch
+import torch.ao.nn.intrinsic
+import torch.ao.nn.intrinsic.qat
+import torch.nn.functional as F
+import torch.ao.nn.quantized as nnq
+
+_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
+
+class ConvAdd2d(nnq.Conv2d):
+    r"""
+    A ConvAdd2d module is a fused module of Conv2d and Add
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv2d
+
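+    Example (an illustrative sketch; both tensors are assumed to be per-tensor
+    quantized and, with ``padding=1``, to share the conv output shape)::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.ConvAdd2d(16, 16, 3, padding=1)
+        >>> x = torch.quantize_per_tensor(torch.randn(1, 16, 8, 8), 0.1, 0, torch.quint8)
+        >>> extra = torch.quantize_per_tensor(torch.randn(1, 16, 8, 8), 0.1, 0, torch.quint8)
+        >>> y = m(x, extra)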
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input, extra_input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv2d_add(
+            input, extra_input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvAdd2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
+
+class ConvAddReLU2d(nnq.Conv2d):
+    r"""
+    A ConvAddReLU2d module is a fused module of Conv2d, Add and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv2d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input, extra_input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv2d_add_relu(
+            input, extra_input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvAddReLU2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c920844733b9cad6d2d52e6bcac09fc581da335b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
@@ -0,0 +1,175 @@
+
+import torch
+import torch.ao.nn.intrinsic
+import torch.ao.nn.intrinsic.qat
+import torch.nn.functional as F
+import torch.ao.nn.quantized as nnq
+
+from torch.nn.utils import fuse_conv_bn_weights
+
+__all__ = [
+    "ConvReLU1d",
+    "ConvReLU2d",
+    "ConvReLU3d",
+]
+
+_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
+
+# TODO: factor out the common parts to ConvNd
+class ConvReLU1d(nnq.Conv1d):
+    r"""
+    A ConvReLU1d module is a fused module of Conv1d and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv1d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 3:
+            raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv1d_relu(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvReLU1d'
+
+    @classmethod
+    def from_float(cls, mod):
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
+            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
+            mod.weight, mod.bias = fuse_conv_bn_weights(
+                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                mod.bn.eps, mod.bn.weight, mod.bn.bias)
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \
+            "BatchNorm1d should be fused into Conv1d before converting to reference module"
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
+
+class ConvReLU2d(nnq.Conv2d):
+    r"""
+    A ConvReLU2d module is a fused module of Conv2d and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv2d
+
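+    Example (an illustrative sketch; in practice this module is usually created
+    by ``torch.ao.quantization.convert``, direct construction with a quantized
+    input is shown only for clarity)::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3)
+        >>> x = torch.quantize_per_tensor(torch.randn(1, 16, 32, 32), 0.1, 0, torch.quint8)
+        >>> y = m(x)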
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv2d_relu(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvReLU2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
+            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
+            mod.weight, mod.bias = fuse_conv_bn_weights(
+                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                mod.bn.eps, mod.bn.weight, mod.bn.bias)
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, \
+            "BatchNorm2d should be fused into Conv2d before converting to reference module"
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
+
+
+class ConvReLU3d(nnq.Conv3d):
+    r"""
+    A ConvReLU3d module is a fused module of Conv3d and ReLU
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Conv3d
+
+    """
+    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias,
+            padding_mode=padding_mode, device=device, dtype=dtype)
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return torch.ops.quantized.conv3d_relu(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedConvReLU3d'
+
+    @classmethod
+    def from_float(cls, mod):
+        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
+            assert mod.bn.running_var is not None and mod.bn.running_mean is not None
+            mod.weight, mod.bias = fuse_conv_bn_weights(
+                mod.weight,
+                mod.bias,
+                mod.bn.running_mean,
+                mod.bn.running_var,
+                mod.bn.eps,
+                mod.bn.weight,
+                mod.bn.bias,
+            )
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, \
+            "BatchNorm3d should be fused into Conv3d before converting to reference module"
+        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..08fb6b51bdca30bd8473d21b96dc884314e71899
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
@@ -0,0 +1,177 @@
+import torch
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.intrinsic as nni
+from torch.ao.nn.quantized.modules.utils import _quantize_weight
+
+__all__ = [
+    "LinearReLU",
+    "LinearLeakyReLU",
+    "LinearTanh",
+]
+
+class LinearReLU(nnq.Linear):
+    r"""
+    A LinearReLU module fused from Linear and ReLU modules
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Linear
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.LinearReLU(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.linear_relu(
+            x, self._packed_params._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedLinearReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
+        return super().from_reference(ref_linear_relu[0], output_scale, output_zero_point)
+
+class LinearLeakyReLU(nnq.Linear):
+    r"""
+    For onednn backend only.
+
+    A LinearLeakyReLU module fused from Linear and LeakyReLU modules.
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Linear
+        + negative_slope
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.LinearLeakyReLU(20, 30, 0.01)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearLeakyReLU  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+        self.negative_slope = negative_slope
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.linear_leaky_relu(
+            x, self._packed_params._packed_params, self.scale, self.zero_point, self.negative_slope)
+
+    def _get_name(self):
+        return 'QuantizedLinearLeakyReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        activation_post_process = mod.activation_post_process
+        leaky_relu = mod[1]
+        mod = mod[0]
+        weight_post_process = mod.qconfig.weight()
+        weight_post_process(mod.weight)
+        dtype = weight_post_process.dtype
+        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
+        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
+        qlinear_leaky_relu = cls(
+            mod.in_features,
+            mod.out_features,
+            leaky_relu.negative_slope,
+            dtype=dtype)
+        qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
+        qlinear_leaky_relu.scale = float(act_scale)
+        qlinear_leaky_relu.zero_point = int(act_zp)
+        return qlinear_leaky_relu
+
+    @classmethod
+    def from_reference(cls, ref_mod, output_scale, output_zero_point):
+        linear = ref_mod[0]
+        leaky_relu = ref_mod[1]
+        qlinear_leaky_relu = cls(
+            linear.in_features,
+            linear.out_features,
+            leaky_relu.negative_slope)
+        qweight = linear.get_quantized_weight()
+        qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
+        qlinear_leaky_relu.scale = float(output_scale)
+        qlinear_leaky_relu.zero_point = int(output_zero_point)
+        return qlinear_leaky_relu
+
+class LinearTanh(nnq.Linear):
+    r"""
+    A LinearTanh module fused from Linear and Tanh modules
+
+    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
+
+    Attributes:
+        Same as torch.ao.nn.quantized.Linear
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = torch.ao.nn.intrinsic.quantized.LinearTanh(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _FLOAT_MODULE = nni.LinearTanh  # type: ignore[assignment]
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias, dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.linear_tanh(
+            x, self._packed_params._packed_params, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedLinearTanh'
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        activation_post_process = mod.activation_post_process
+        mod = mod[0]
+        weight_post_process = mod.qconfig.weight()
+        weight_post_process(mod.weight)
+        dtype = weight_post_process.dtype
+        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
+        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
+        qlinear_tanh = cls(
+            mod.in_features,
+            mod.out_features,
+            dtype=dtype)
+        qlinear_tanh.set_weight_bias(qweight, mod.bias)
+        qlinear_tanh.scale = float(act_scale)
+        qlinear_tanh.zero_point = int(act_zp)
+        return qlinear_tanh
+
+    @classmethod
+    def from_reference(cls, ref_mod, output_scale, output_zero_point):
+        linear = ref_mod[0]
+        qlinear_tanh = cls(
+            linear.in_features,
+            linear.out_features)
+        qweight = linear.get_quantized_weight()
+        qlinear_tanh.set_weight_bias(qweight, linear.bias)
+        qlinear_tanh.scale = float(output_scale)
+        qlinear_tanh.zero_point = int(output_zero_point)
+        return qlinear_tanh
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6be35e6be423632a6fab8ae3224c4a7f17e1498e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e31c3e0fe6bab106332ccfd4324b8b8918ba260f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f656409ea408920a9eb2f4d28ccf0aeb0f9b50ba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py
@@ -0,0 +1,3 @@
+from .linear import Linear
+
+__all__ = ["Linear"]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3291390ebb9e9a4bab69aba78dcca9ddd063065
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c33276cc34add1850b60e773a70747ff6da3dea8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..fccb87e291c632481985786c26b06752ad03c8ce
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py
@@ -0,0 +1,25 @@
+import torch
+
+__all__ = ["Linear"]
+
+class Linear(torch.ao.nn.qat.Linear):
+    r"""
+    A linear module attached with FakeQuantize modules for weight,
+    used for dynamic quantization aware training.
+
+    We adopt the same interface as `torch.nn.Linear`, please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
+    for documentation.
+
+    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
+    default.
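+
+    Example (an illustrative sketch; it assumes ``default_dynamic_qat_qconfig``
+    from ``torch.ao.quantization.qconfig``, whose activation fake-quant is
+    memoryless and therefore passes the check below)::
+
+        >>> # xdoctest: +SKIP
+        >>> from torch.ao.quantization.qconfig import default_dynamic_qat_qconfig
+        >>> m = torch.ao.nn.qat.dynamic.Linear(20, 30, qconfig=default_dynamic_qat_qconfig)
+        >>> output = m(torch.randn(128, 20))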
+    """
+
+    def __init__(self, in_features, out_features, bias=True,
+                 qconfig=None, device=None, dtype=None) -> None:
+        super().__init__(in_features, out_features, bias, qconfig, device, dtype)
+        if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
+            raise ValueError(
+                "Dynamic QAT requires a memoryless observer." +
+                "This means a MovingAverage observer with averaging constant equal to 1"
+            )
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee80f17bb2660e8ce319854bb9273b8e7f4f909
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__init__.py
@@ -0,0 +1,14 @@
+from .linear import Linear
+from .conv import Conv1d
+from .conv import Conv2d
+from .conv import Conv3d
+from .embedding_ops import EmbeddingBag, Embedding
+
+__all__ = [
+    "Linear",
+    "Conv1d",
+    "Conv2d",
+    "Conv3d",
+    "Embedding",
+    "EmbeddingBag",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bc95d50b6a08acc4900ace8fef0f690caedfb51
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a9743e17ae5d15868a2de189badd275df0034e5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af75f4498f422ea8c63e1eef35f172476af803c6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1318814a07b6690ed64460369004ccce246d6f0a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/conv.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..c201f917a1db4fd007cde0e2c6c039f997299130
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/conv.py
@@ -0,0 +1,270 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.ao.nn.intrinsic import _FusedModule
+from typing import Tuple, TypeVar, Union
+from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
+
+__all__ = [
+    "Conv1d",
+    "Conv2d",
+    "Conv3d"
+]
+
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
+
+class _ConvNd(nn.modules.conv._ConvNd):
+
+    _FLOAT_MODULE = MOD
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Tuple[int, ...],
+                 stride: Tuple[int, ...],
+                 padding: Tuple[int, ...],
+                 dilation: Tuple[int, ...],
+                 transposed: bool,
+                 output_padding: Tuple[int, ...],
+                 groups: int,
+                 bias: bool,
+                 padding_mode: str,
+                 qconfig=None,
+                 device=None,
+                 dtype=None) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
+                                         stride, padding, dilation, transposed,
+                                         output_padding, groups, bias, padding_mode, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input):
+        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+
+    @staticmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+            Args:
+               `mod`: a float module, either produced by torch.ao.quantization utilities
+               or directly from the user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, (
+            "qat."
+            + cls.__name__
+            + ".from_float only works for "
+            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
+        )
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        if issubclass(type(mod), _FusedModule):
+            mod = mod[0]  # type: ignore[index]
+        qconfig = mod.qconfig
+        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+                       stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
+                       groups=mod.groups, bias=mod.bias is not None,
+                       padding_mode=mod.padding_mode, qconfig=qconfig)
+        qat_conv.weight = mod.weight
+        qat_conv.bias = mod.bias
+        return qat_conv
+
+    def to_float(self):
+        """ This works for both single qat conv, and the qat conv - relu modules
+        to convert the qat module to a floating point module
+        """
+        cls = type(self)
+        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,  # type: ignore[arg-type]
+            self.stride,  # type: ignore[arg-type]
+            self.padding,  # type: ignore[arg-type]
+            self.dilation,  # type: ignore[arg-type]
+            self.groups,
+            self.bias is not None,
+            self.padding_mode)
+        conv.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            conv.bias = torch.nn.Parameter(self.bias.detach())
+        # conv relu
+        if issubclass(cls, _FusedModule):
+            modules = [conv]
+            assert hasattr(cls, "_FLOAT_RELU_MODULE")
+            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
+            modules.append(relu)
+            fused = cls._FLOAT_MODULE(*modules)  # type: ignore[arg-type, attr-defined, operator]
+            fused.train(self.training)
+            return fused
+        else:
+            return conv
+
+class Conv1d(_ConvNd, nn.Conv1d):
+    r"""
+    A Conv1d module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as :class:`~torch.nn.Conv1d`
+
+    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+    """
+    _FLOAT_MODULE = nn.Conv1d
+    _FLOAT_CONV_MODULE = nn.Conv1d
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: Union[str, _size_1_t] = 0,
+                 dilation: _size_1_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros',
+                 qconfig=None,
+                 device=None,
+                 dtype=None) -> None:
+        kernel_size_ = _single(kernel_size)
+        stride_ = _single(stride)
+        padding_ = padding if isinstance(padding, str) else _single(padding)
+        dilation_ = _single(dilation)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size_,
+            stride=stride_,
+            padding=padding_,
+            dilation=dilation_,
+            transposed=False,
+            output_padding=_single(0),
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            qconfig=qconfig,
+            device=device,
+            dtype=dtype)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(cls, mod)
+
+class Conv2d(_ConvNd, nn.Conv2d):
+    r"""
+    A Conv2d module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.Conv2d`, please see
+    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
+    for documentation.
+
+    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
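+
+    Example (an illustrative sketch; it uses the default fbgemm QAT qconfig,
+    any qconfig that provides a weight fake-quant would work)::
+
+        >>> # xdoctest: +SKIP
+        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
+        >>> m = torch.ao.nn.qat.Conv2d(16, 33, 3, stride=2, qconfig=qconfig)
+        >>> output = m(torch.randn(20, 16, 50, 100))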
+    """
+    _FLOAT_MODULE = nn.Conv2d
+    _FLOAT_CONV_MODULE = nn.Conv2d
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_2_t,
+                 stride: _size_2_t = 1,
+                 padding: Union[str, _size_2_t] = 0,
+                 dilation: _size_2_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros',
+                 qconfig=None,
+                 device=None,
+                 dtype=None) -> None:
+        kernel_size_ = _pair(kernel_size)
+        stride_ = _pair(stride)
+        padding_ = padding if isinstance(padding, str) else _pair(padding)
+        dilation_ = _pair(dilation)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size_,
+            stride=stride_,
+            padding=padding_,
+            dilation=dilation_,
+            transposed=False,
+            output_padding=_pair(0),
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            qconfig=qconfig,
+            device=device,
+            dtype=dtype)
+
+    def forward(self, input):
+        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(cls, mod)
+
+class Conv3d(_ConvNd, nn.Conv3d):
+    r"""
+    A Conv3d module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.Conv3d`, please see
+    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
+    for documentation.
+
+    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for weight
+    """
+    _FLOAT_MODULE = nn.Conv3d
+    _FLOAT_CONV_MODULE = nn.Conv3d
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_3_t,
+                 stride: _size_3_t = 1,
+                 padding: Union[str, _size_3_t] = 0,
+                 dilation: _size_3_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros',
+                 qconfig=None,
+                 device=None,
+                 dtype=None) -> None:
+        kernel_size_ = _triple(kernel_size)
+        stride_ = _triple(stride)
+        padding_ = padding if isinstance(padding, str) else _triple(padding)
+        dilation_ = _triple(dilation)
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size_,
+            stride=stride_,
+            padding=padding_,
+            dilation=dilation_,
+            transposed=False,
+            output_padding=_triple(0),
+            groups=groups,
+            bias=bias,
+            padding_mode=padding_mode,
+            qconfig=qconfig,
+            device=device,
+            dtype=dtype)
+
+    def forward(self, input):
+        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(cls, mod)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..35172bcd4df39fb8d946ca624aa774e8f3ab6475
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/embedding_ops.py
@@ -0,0 +1,143 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['Embedding', 'EmbeddingBag']
+
+class Embedding(nn.Embedding):
+    r"""
+    An embedding module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.Embedding`, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
+    for documentation.
+
+    Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
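+
+    Example (an illustrative sketch; it assumes ``default_embedding_qat_qconfig``
+    from ``torch.ao.quantization.qconfig``, which supplies the required
+    ``per_channel_affine_float_qparams`` weight observer)::
+
+        >>> # xdoctest: +SKIP
+        >>> from torch.ao.quantization.qconfig import default_embedding_qat_qconfig
+        >>> m = torch.ao.nn.qat.Embedding(10, 12, qconfig=default_embedding_qat_qconfig)
+        >>> output = m(torch.randint(0, 10, (5,)))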
+    """
+    _FLOAT_MODULE = nn.Embedding
+
+    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
+                 max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
+                 sparse=False, _weight=None, device=None, dtype=None, qconfig=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
+                         norm_type, scale_grad_by_freq, sparse, _weight,
+                         **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(qconfig.weight().qscheme)
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input) -> Tensor:
+        return F.embedding(input, self.weight_fake_quant(self.weight), self.padding_idx,
+                           self.max_norm, self.norm_type, self.scale_grad_by_freq,
+                           self.sparse)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+            Args: `mod` a float module, either produced by torch.ao.quantization utilities
+            or directly from the user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
+        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(weight_qscheme)
+
+        qconfig = mod.qconfig
+        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.padding_idx,
+                                mod.max_norm, mod.norm_type, mod.scale_grad_by_freq,
+                                mod.sparse, mod.weight, qconfig=qconfig)
+
+        return qat_embedding_bag
+
+    def to_float(self):
+        embedding_bag = torch.nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx,
+                                           self.max_norm, self.norm_type, self.scale_grad_by_freq,
+                                           self.sparse, None)
+        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
+        embedding_bag.train(self.training)
+        return embedding_bag
+
+class EmbeddingBag(nn.EmbeddingBag):
+    r"""
+    An embedding bag module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
+    for documentation.
+
+    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
+    """
+    _FLOAT_MODULE = nn.EmbeddingBag
+
+    def __init__(self, num_embeddings, embedding_dim, max_norm=None,
+                 norm_type=2.0, scale_grad_by_freq=False, mode='mean',
+                 sparse=False, _weight=None, include_last_offset=False,
+                 padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
+                         scale_grad_by_freq, mode, sparse, _weight,
+                         include_last_offset, padding_idx, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(qconfig.weight().qscheme)
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
+        return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
+                               self.max_norm, self.norm_type,
+                               self.scale_grad_by_freq, self.mode, self.sparse,
+                               per_sample_weights, self.include_last_offset,
+                               self.padding_idx)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+            Args: `mod` a float module, either produced by torch.ao.quantization utilities
+            or directly from the user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
+        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(weight_qscheme)
+
+        qconfig = mod.qconfig
+        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
+                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
+                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
+
+        return qat_embedding_bag
+
+    def to_float(self):
+        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
+                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                              None, self.include_last_offset, self.padding_idx)
+        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
+        embedding_bag.train(self.training)
+        return embedding_bag
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..7986f0c9f5237d3eef6f6b4717c7593e290828e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/qat/modules/linear.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.nn.intrinsic import LinearReLU
+from torch.nn.utils.parametrize import (
+    is_parametrized,
+    type_before_parametrizations,
+    transfer_parametrizations_and_params,
+)
+
+__all__ = [
+    "Linear"
+]
+
+class Linear(nn.Linear):
+    r"""
+    A linear module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.Linear`, please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
+    for documentation.
+
+    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
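+
+    Example (an illustrative sketch; the fbgemm default QAT qconfig is only one
+    possible choice of weight fake-quant)::
+
+        >>> # xdoctest: +SKIP
+        >>> qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
+        >>> m = torch.ao.nn.qat.Linear(20, 30, qconfig=qconfig)
+        >>> output = m(torch.randn(128, 20))
+        >>> print(output.size())
+        torch.Size([128, 30])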
+    """
+    _FLOAT_MODULE = nn.Linear
+
+    def __init__(self, in_features, out_features, bias=True,
+                 qconfig=None, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(in_features, out_features, bias, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input):
+        return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module or qparams_dict
+            Args: `mod` a float module, either produced by torch.ao.quantization utilities
+            or directly from user
+        """
+        assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, (
+            " qat."
+            + cls.__name__
+            + ".from_float only works for "
+            + cls._FLOAT_MODULE.__name__
+        )
+        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
+        assert mod.qconfig, "Input float module must have a valid qconfig"
+        if type_before_parametrizations(mod) == LinearReLU:
+            mod = mod[0]
+
+        qconfig = mod.qconfig
+        qat_linear = cls(mod.in_features, mod.out_features, bias=mod.bias is not None, qconfig=qconfig)
+
+        if is_parametrized(mod, "weight"):
+            transfer_parametrizations_and_params(mod, qat_linear, "weight")
+        else:
+            qat_linear.weight = mod.weight
+
+        if is_parametrized(mod, "bias"):
+            transfer_parametrizations_and_params(mod, qat_linear, "bias")
+        else:
+            qat_linear.bias = mod.bias
+
+        return qat_linear
+
+    def to_float(self):
+        linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
+        linear.weight = torch.nn.Parameter(self.weight.detach())
+        if self.bias is not None:
+            linear.bias = torch.nn.Parameter(self.bias.detach())
+        linear.train(self.training)
+        return linear
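
For context, a minimal sketch of how the QAT Linear above is typically created and converted back. In practice `torch.ao.quantization.prepare_qat` performs this swap; calling `from_float` directly just shows the contract. The fbgemm backend choice is an assumption.

    import torch
    import torch.nn as nn
    from torch.ao.nn.qat import Linear as QATLinear
    from torch.ao.quantization import get_default_qat_qconfig

    # Hypothetical illustration, not part of the file above.
    float_mod = nn.Linear(8, 4)
    float_mod.qconfig = get_default_qat_qconfig("fbgemm")  # assumed backend

    qat_mod = QATLinear.from_float(float_mod)  # weight passes through weight_fake_quant in forward()
    y = qat_mod(torch.randn(2, 8))

    restored = qat_mod.to_float()              # plain nn.Linear with detached weight/bias
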
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68685e82dbb1a5a5a23f6c3415982ef3bf9dea66
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e55197cb4e3a166e09790606b7503224705b0c2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__init__.py
@@ -0,0 +1,9 @@
+from .activation import MultiheadAttention
+from .rnn import LSTM
+from .rnn import LSTMCell
+
+__all__ = [
+    'LSTM',
+    'LSTMCell',
+    'MultiheadAttention',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f074ffc402f90e2b90838a04be3e108352356ca
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42964880df47e392ed82199e77d530624c23ae34
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c90913c7b8d712de624e148c54efa6a83165e592
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..90975fbb3de65e9556cc71533abb1b7c03e7a304
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/activation.py
@@ -0,0 +1,465 @@
+import torch
+import torch.jit  # this is needed to avoid a circular import
+from torch import nn
+import torch.nn.functional as nnF
+
+from torch import Tensor
+from typing import Optional, Tuple
+
+import warnings
+
+__all__ = [
+    "MultiheadAttention"
+]
+
+class MultiheadAttention(nn.MultiheadAttention):
+    _FLOAT_MODULE = nn.MultiheadAttention
+
+    r"""Quantizable implementation of the MultiheadAttention.
+
+    Note::
+        Please, refer to :class:`~torch.nn.MultiheadAttention` for more
+        information
+
+    Allows the model to jointly attend to information from different
+    representation subspaces.
+    See reference: Attention Is All You Need
+
+    The original MHA module is not quantizable.
+    This reimplements it by explicitly instantiating the linear layers.
+
+    .. math::
+        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
+        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+        bias: add bias as module parameter. Default: True.
+        add_bias_kv: add bias to the key and value sequences at dim=0.
+        add_zero_attn: add a new batch of zeros to the key and
+                       value sequences at dim=1.
+        kdim: total number of features in key. Default: None.
+        vdim: total number of features in value. Default: None.
+        batch_first: If ``True``, then the input and output tensors are provided
+            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+
+    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
+    to :attr:`embed_dim` such that query, key, and value have the same
+    number of features.
+
+    Examples::
+
+        >>> import torch.ao.nn.quantizable as nnqa
+        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+
+    Note::
+        Please, follow the quantization flow to convert the quantizable MHA.
+    """
+    __constants__ = ['batch_first']
+
+    def __init__(self, embed_dim: int, num_heads: int,
+                 dropout: float = 0., bias: bool = True,
+                 add_bias_kv: bool = False, add_zero_attn: bool = False,
+                 kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(embed_dim, num_heads, dropout,
+                         bias, add_bias_kv,
+                         add_zero_attn, kdim, vdim, batch_first,
+                         **factory_kwargs)
+        self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
+        self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
+        self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
+        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]
+
+        # Functionals
+        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
+        # note: importing torch.ao.nn.quantized at top creates a circular import
+
+        # Quant/Dequant
+        self.quant_attn_output = torch.ao.quantization.QuantStub()
+        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
+        self.dequant_q = torch.ao.quantization.DeQuantStub()
+        self.dequant_k = torch.ao.quantization.DeQuantStub()
+        self.dequant_v = torch.ao.quantization.DeQuantStub()
+
+    def _get_name(self):
+        return 'QuantizableMultiheadAttention'
+
+    @classmethod
+    def from_float(cls, other):
+        assert type(other) == cls._FLOAT_MODULE
+        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
+        # Note: dropout is carried over from the float module below (it is not reset to 0.0)
+        observed = cls(other.embed_dim, other.num_heads, other.dropout,
+                       (other.in_proj_bias is not None),
+                       (other.bias_k is not None),
+                       other.add_zero_attn, other.kdim, other.vdim,
+                       other.batch_first)
+        observed.bias_k = other.bias_k
+        observed.bias_v = other.bias_v
+        observed.qconfig = other.qconfig
+
+        # Set the linear weights
+        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
+        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
+        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
+        if other._qkv_same_embed_dim:
+            # Slice the packed in_proj parameters into separate Q/K/V params
+            bias = other.in_proj_bias
+            _start = 0
+            _end = _start + other.embed_dim
+            weight = other.in_proj_weight[_start:_end, :]
+            if bias is not None:
+                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
+            observed.linear_Q.weight = torch.nn.Parameter(weight,
+                                                          weight.requires_grad)
+            observed.linear_Q.bias = bias
+
+            bias = other.in_proj_bias
+            _start = _end
+            _end = _start + other.embed_dim
+            weight = other.in_proj_weight[_start:_end, :]
+            if bias is not None:
+                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
+            observed.linear_K.weight = torch.nn.Parameter(weight,
+                                                          weight.requires_grad)
+            observed.linear_K.bias = bias
+
+            bias = other.in_proj_bias
+            _start = _end
+            weight = other.in_proj_weight[_start:, :]
+            if bias is not None:
+                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
+            observed.linear_V.weight = torch.nn.Parameter(weight,
+                                                          weight.requires_grad)
+            observed.linear_V.bias = bias
+        else:
+            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
+            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
+            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
+            if other.in_proj_bias is None:
+                observed.linear_Q.bias = None  # type: ignore[assignment]
+                observed.linear_K.bias = None  # type: ignore[assignment]
+                observed.linear_V.bias = None  # type: ignore[assignment]
+            else:
+                observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
+                observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
+                observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
+        observed.eval()
+        # Explicit prepare
+        observed = torch.ao.quantization.prepare(observed, inplace=True)
+        return observed
+
+    @torch.jit.unused
+    def dequantize(self):
+        r"""Utility to convert the quantized MHA back to float.
+
+        The motivation for this is that it is not trivial to convert the weights
+        from the format that is used in the quantized version back to the
+        float format.
+        """
+        fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
+                                (self.linear_Q._weight_bias()[1] is not None),
+                                (self.bias_k is not None),
+                                self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
+        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
+        if self.bias_k is not None:
+            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
+        if self.bias_v is not None:
+            fp.bias_v = nn.Parameter(self.bias_v.dequantize())
+
+        # Set the linear weights
+        # Note: Because the linear layers are quantized, mypy does not know how
+        # to deal with them -- might need to ignore the typing checks.
+        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
+        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
+        fp.out_proj.weight = nn.Parameter(w.dequantize())
+        if b is not None:
+            fp.out_proj.bias = nn.Parameter(b)
+
+        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
+        wQ = wQ.dequantize()
+        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
+        wK = wK.dequantize()
+        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
+        wV = wV.dequantize()
+        if fp._qkv_same_embed_dim:
+            # Copy the dequantized Q/K/V weights back into the packed in_proj parameters
+            _start = 0
+            _end = _start + fp.embed_dim
+            fp.in_proj_weight[_start:_end, :] = wQ
+            if fp.in_proj_bias is not None:
+                assert all(bQ == 0)
+                fp.in_proj_bias[_start:_end] = bQ
+
+            _start = _end
+            _end = _start + fp.embed_dim
+            fp.in_proj_weight[_start:_end, :] = wK
+            if fp.in_proj_bias is not None:
+                assert all(bK == 0)
+                fp.in_proj_bias[_start:_end] = bK
+
+            _start = _end
+            fp.in_proj_weight[_start:, :] = wV
+            if fp.in_proj_bias is not None:
+                assert all(bV == 0)
+                fp.in_proj_bias[_start:] = bV
+        else:
+            fp.q_proj_weight = nn.Parameter(wQ)
+            fp.k_proj_weight = nn.Parameter(wK)
+            fp.v_proj_weight = nn.Parameter(wV)
+            if fp.in_proj_bias is None:
+                self.linear_Q.bias = None
+                self.linear_K.bias = None
+                self.linear_V.bias = None
+            else:
+                fp.in_proj_bias[0:fp.embed_dim] = bQ
+                fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
+                fp.in_proj_bias[(fp.embed_dim * 2):] = bV
+
+        return fp
+
+
+    @classmethod
+    def from_observed(cls, other):
+        # The whole flow is float -> observed -> quantized
+        # This class does float -> observed only
+        # See nn.quantized.MultiheadAttention
+        raise NotImplementedError("It looks like you are trying to prepare an "
+                                  "MHA module. Please, see "
+                                  "the examples on quantizable MHAs.")
+
+    def forward(self,
+                query: Tensor,
+                key: Tensor,
+                value: Tensor,
+                key_padding_mask: Optional[Tensor] = None,
+                need_weights: bool = True,
+                attn_mask: Optional[Tensor] = None,
+                average_attn_weights: bool = True,
+                is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
+        r"""
+    Note::
+        Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
+        information
+
+    Args:
+        query, key, value: map a query and a set of key-value pairs to an output.
+            See "Attention Is All You Need" for more details.
+        key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention. When given a binary mask and a value is True,
+            the corresponding value on the attention layer will be ignored.
+        need_weights: output attn_output_weights.
+        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+    Shape:
+        - Inputs:
+        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+          the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
+        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+          the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
+        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+          the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
+        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+          If a BoolTensor is provided, the positions with the
+          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+          S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
+          positions. If a BoolTensor is provided, positions with ``True``
+          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+          is provided, it will be added to the attention weight.
+        - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
+          Default: ``False``.
+        - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
+          heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
+          effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
+
+        - Outputs:
+        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+          E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
+        - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
+          across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
+          S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+          head of shape :math:`(N, num_heads, L, S)`.
+        """
+        return self._forward_impl(query, key, value, key_padding_mask,
+                                  need_weights, attn_mask, average_attn_weights,
+                                  is_causal)
+
+    def _forward_impl(self,
+                      query: Tensor,
+                      key: Tensor,
+                      value: Tensor,
+                      key_padding_mask: Optional[Tensor] = None,
+                      need_weights: bool = True,
+                      attn_mask: Optional[Tensor] = None,
+                      average_attn_weights: bool = True,
+                      is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
+        # This version will not deal with the static key/value pairs.
+        # Keeping it here for future changes.
+        #
+        # TODO: This method has some duplicate lines with the
+        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
+        static_k = None
+        static_v = None
+
+        if attn_mask is not None and is_causal:
+            raise AssertionError("Only allow causal mask or attn_mask")
+
+        if is_causal:
+            raise AssertionError("causal mask not supported by AO MHA module")
+
+        if self.batch_first:
+            query, key, value = (x.transpose(0, 1) for x in (query, key, value))
+
+        tgt_len, bsz, embed_dim_to_check = query.size()
+        assert self.embed_dim == embed_dim_to_check
+        # allow MHA to have different sizes for the feature dimension
+        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = self.embed_dim // self.num_heads
+        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        scaling = float(head_dim) ** -0.5
+
+        q = self.linear_Q(query)
+        k = self.linear_K(key)
+        v = self.linear_V(value)
+
+        q = self.q_scaling_product.mul_scalar(q, scaling)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+                attn_mask = attn_mask.to(torch.bool)
+            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
+                f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}'
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
+                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
+            else:
+                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+            key_padding_mask = key_padding_mask.to(torch.bool)
+        if self.bias_k is not None and self.bias_v is not None:
+            if static_k is None and static_v is None:
+
+                # Explicitly assert that bias_k and bias_v are not None
+                # in a way that TorchScript can understand.
+                bias_k = self.bias_k
+                assert bias_k is not None
+                bias_v = self.bias_v
+                assert bias_v is not None
+
+                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+                if attn_mask is not None:
+                    attn_mask = nnF.pad(attn_mask, (0, 1))
+                if key_padding_mask is not None:
+                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
+            else:
+                assert static_k is None, "bias cannot be added to static key."
+                assert static_v is None, "bias cannot be added to static value."
+        else:
+            assert self.bias_k is None
+            assert self.bias_v is None
+
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
+        if k is not None:
+            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
+        if v is not None:
+            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
+
+        if static_k is not None:
+            assert static_k.size(0) == bsz * self.num_heads
+            assert static_k.size(2) == head_dim
+            k = static_k
+
+        if static_v is not None:
+            assert static_v.size(0) == bsz * self.num_heads
+            assert static_v.size(2) == head_dim
+            v = static_v
+
+        src_len = k.size(1)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            src_len += 1
+            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
+            if k.is_quantized:
+                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
+            k = torch.cat([k, k_zeros], dim=1)
+            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:])
+            if v.is_quantized:
+                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
+            v = torch.cat([v, v_zeros], dim=1)
+
+            if attn_mask is not None:
+                attn_mask = nnF.pad(attn_mask, (0, 1))
+            if key_padding_mask is not None:
+                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
+
+        # Leaving the quantized zone here
+        q = self.dequant_q(q)
+        k = self.dequant_k(k)
+        v = self.dequant_v(v)
+        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float('-inf'),
+            )
+            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_output_weights = nnF.softmax(
+            attn_output_weights, dim=-1)
+        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
+        if self.batch_first:
+            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+        else:
+            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+
+        # Reentering the quantized zone
+        attn_output = self.quant_attn_output(attn_output)
+        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
+        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
+        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            if average_attn_weights:
+                attn_output_weights = attn_output_weights.mean(dim=1)
+            return attn_output, attn_output_weights
+        else:
+            return attn_output, None
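
For context, a rough sketch of the float -> observed step that the class comments above describe for the quantizable MultiheadAttention; converting the observed module to `torch.ao.nn.quantized.MultiheadAttention` is then handled by the regular `convert` step on the enclosing model. The qconfig, backend, and tensor shapes below are assumptions for illustration.

    import torch
    import torch.nn as nn
    import torch.ao.nn.quantizable as nnqa
    from torch.ao.quantization import get_default_qconfig

    # Hypothetical illustration, not part of the file above.
    embed_dim, num_heads = 16, 4
    float_mha = nn.MultiheadAttention(embed_dim, num_heads)
    float_mha.qconfig = get_default_qconfig("fbgemm")  # assumed backend

    observed = nnqa.MultiheadAttention.from_float(float_mha)  # also runs prepare() on the result

    # Calibration pass with representative data: (seq_len, batch, embed_dim)
    q = k = v = torch.randn(5, 2, embed_dim)
    attn_out, attn_weights = observed(q, k, v)
    # torch.ao.quantization.convert on the enclosing model would then swap this
    # module for torch.ao.nn.quantized.MultiheadAttention.
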
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec4a532e548abc536a534dee906f06d454cd71a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantizable/modules/rnn.py
@@ -0,0 +1,411 @@
+import numbers
+from typing import Optional, Tuple
+import warnings
+
+import torch
+from torch import Tensor
+
+"""
+We will recreate all the RNN modules here, as we require them to be decomposed
+into their building blocks so that they can be observed.
+"""
+
+__all__ = [
+    "LSTMCell",
+    "LSTM"
+]
+
+class LSTMCell(torch.nn.Module):
+    r"""A quantizable long short-term memory (LSTM) cell.
+
+    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
+
+    Examples::
+
+        >>> import torch.ao.nn.quantizable as nnqa
+        >>> rnn = nnqa.LSTMCell(10, 20)
+        >>> input = torch.randn(6, 10)
+        >>> hx = torch.randn(3, 20)
+        >>> cx = torch.randn(3, 20)
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx, cx = rnn(input[i], (hx, cx))
+        ...     output.append(hx)
+    """
+    _FLOAT_MODULE = torch.nn.LSTMCell
+
+    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.input_size = input_dim
+        self.hidden_size = hidden_dim
+        self.bias = bias
+
+        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
+        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
+        self.gates = torch.ao.nn.quantized.FloatFunctional()
+
+        self.input_gate = torch.nn.Sigmoid()
+        self.forget_gate = torch.nn.Sigmoid()
+        self.cell_gate = torch.nn.Tanh()
+        self.output_gate = torch.nn.Sigmoid()
+
+        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
+        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
+        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
+
+        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()
+
+        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
+        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
+        self.hidden_state_dtype: torch.dtype = torch.quint8
+        self.cell_state_dtype: torch.dtype = torch.quint8
+
+    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
+        if hidden is None or hidden[0] is None or hidden[1] is None:
+            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
+        hx, cx = hidden
+
+        igates = self.igates(x)
+        hgates = self.hgates(hx)
+        gates = self.gates.add(igates, hgates)
+
+        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
+
+        input_gate = self.input_gate(input_gate)
+        forget_gate = self.forget_gate(forget_gate)
+        cell_gate = self.cell_gate(cell_gate)
+        out_gate = self.output_gate(out_gate)
+
+        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
+        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
+        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
+        cy = fgate_cx_igate_cgate
+
+        # TODO: make this tanh a member of the module so its qparams can be configured
+        tanh_cy = torch.tanh(cy)
+        hy = self.ogate_cy.mul(out_gate, tanh_cy)
+        return hy, cy
+
+    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
+        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
+        if is_quantized:
+            (h_scale, h_zp) = self.initial_hidden_state_qparams
+            (c_scale, c_zp) = self.initial_cell_state_qparams
+            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
+            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
+        return h, c
+
+    def _get_name(self):
+        return 'QuantizableLSTMCell'
+
+    @classmethod
+    def from_params(cls, wi, wh, bi=None, bh=None):
+        """Uses the weights and biases to create a new LSTM cell.
+
+        Args:
+            wi, wh: Weights for the input and hidden layers
+            bi, bh: Biases for the input and hidden layers
+        """
+        assert (bi is None) == (bh is None)  # Either both None or both have values
+        input_size = wi.shape[1]
+        hidden_size = wh.shape[1]
+        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
+                   bias=(bi is not None))
+        cell.igates.weight = torch.nn.Parameter(wi)
+        if bi is not None:
+            cell.igates.bias = torch.nn.Parameter(bi)
+        cell.hgates.weight = torch.nn.Parameter(wh)
+        if bh is not None:
+            cell.hgates.bias = torch.nn.Parameter(bh)
+        return cell
+
+    @classmethod
+    def from_float(cls, other):
+        assert type(other) == cls._FLOAT_MODULE
+        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
+        observed = cls.from_params(other.weight_ih, other.weight_hh,
+                                   other.bias_ih, other.bias_hh)
+        observed.qconfig = other.qconfig
+        observed.igates.qconfig = other.qconfig
+        observed.hgates.qconfig = other.qconfig
+        return observed
+
+
+class _LSTMSingleLayer(torch.nn.Module):
+    r"""A single one-directional LSTM layer.
+
+    The difference between a layer and a cell is that the layer can process a
+    sequence, while the cell only expects an instantaneous value.
+    """
+    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
+
+    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
+        result = []
+        seq_len = x.shape[0]
+        for i in range(seq_len):
+            hidden = self.cell(x[i], hidden)
+            result.append(hidden[0])  # type: ignore[index]
+        result_tensor = torch.stack(result, 0)
+        return result_tensor, hidden
+
+    @classmethod
+    def from_params(cls, *args, **kwargs):
+        cell = LSTMCell.from_params(*args, **kwargs)
+        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
+        layer.cell = cell
+        return layer
+
+
+class _LSTMLayer(torch.nn.Module):
+    r"""A single bi-directional LSTM layer."""
+    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
+                 batch_first: bool = False, bidirectional: bool = False,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.batch_first = batch_first
+        self.bidirectional = bidirectional
+        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
+        if self.bidirectional:
+            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
+
+    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
+        if self.batch_first:
+            x = x.transpose(0, 1)
+        if hidden is None:
+            hx_fw, cx_fw = (None, None)
+        else:
+            hx_fw, cx_fw = hidden
+        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
+        if self.bidirectional:
+            if hx_fw is None:
+                hx_bw = None
+            else:
+                hx_bw = hx_fw[1]
+                hx_fw = hx_fw[0]
+            if cx_fw is None:
+                cx_bw = None
+            else:
+                cx_bw = cx_fw[1]
+                cx_fw = cx_fw[0]
+            if hx_bw is not None and cx_bw is not None:
+                hidden_bw = hx_bw, cx_bw
+        if hx_fw is None and cx_fw is None:
+            hidden_fw = None
+        else:
+            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
+        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
+
+        if hasattr(self, 'layer_bw') and self.bidirectional:
+            x_reversed = x.flip(0)
+            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
+            result_bw = result_bw.flip(0)
+
+            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
+            if hidden_fw is None and hidden_bw is None:
+                h = None
+                c = None
+            elif hidden_fw is None:
+                (h, c) = torch.jit._unwrap_optional(hidden_bw)
+            elif hidden_bw is None:
+                (h, c) = torch.jit._unwrap_optional(hidden_fw)
+            else:
+                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
+                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
+        else:
+            result = result_fw
+            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]
+
+        if self.batch_first:
+            result.transpose_(0, 1)
+
+        return result, (h, c)
+
+    @classmethod
+    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
+        r"""
+        There is no FP equivalent of this class. This function is here just to
+        mimic the behavior of the `prepare` step within the `torch.ao.quantization`
+        flow.
+        """
+        assert hasattr(other, 'qconfig') or (qconfig is not None)
+
+        input_size = kwargs.get('input_size', other.input_size)
+        hidden_size = kwargs.get('hidden_size', other.hidden_size)
+        bias = kwargs.get('bias', other.bias)
+        batch_first = kwargs.get('batch_first', other.batch_first)
+        bidirectional = kwargs.get('bidirectional', other.bidirectional)
+
+        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
+        layer.qconfig = getattr(other, 'qconfig', qconfig)
+        wi = getattr(other, f'weight_ih_l{layer_idx}')
+        wh = getattr(other, f'weight_hh_l{layer_idx}')
+        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
+        bh = getattr(other, f'bias_hh_l{layer_idx}', None)
+
+        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
+
+        if other.bidirectional:
+            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
+            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
+            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
+            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
+            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
+        return layer
+
+
+class LSTM(torch.nn.Module):
+    r"""A quantizable long short-term memory (LSTM).
+
+    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
+
+    Attributes:
+        layers : instances of the `_LSTMLayer`
+
+    .. note::
+        To access the weights and biases, you need to access them per layer.
+        See examples below.
+
+    Examples::
+
+        >>> import torch.ao.nn.quantizable as nnqa
+        >>> rnn = nnqa.LSTM(10, 20, 2)
+        >>> input = torch.randn(5, 3, 10)
+        >>> h0 = torch.randn(2, 3, 20)
+        >>> c0 = torch.randn(2, 3, 20)
+        >>> output, (hn, cn) = rnn(input, (h0, c0))
+        >>> # To get the weights:
+        >>> # xdoctest: +SKIP
+        >>> print(rnn.layers[0].weight_ih)
+        tensor([[...]])
+        >>> print(rnn.layers[0].weight_hh)
+        AssertionError: There is no reverse path in the non-bidirectional layer
+    """
+    _FLOAT_MODULE = torch.nn.LSTM
+
+    def __init__(self, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True,
+                 batch_first: bool = False, dropout: float = 0.,
+                 bidirectional: bool = False,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = float(dropout)
+        self.bidirectional = bidirectional
+        self.training = False  # Default to eval mode. If we want to train, we will explicitly set it to training mode.
+        num_directions = 2 if bidirectional else 1
+
+        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
+                isinstance(dropout, bool):
+            raise ValueError("dropout should be a number in range [0, 1] "
+                             "representing the probability of an element being "
+                             "zeroed")
+        if dropout > 0:
+            warnings.warn("dropout option for quantizable LSTM is ignored. "
+                          "If you are training, please, use nn.LSTM version "
+                          "followed by `prepare` step.")
+            if num_layers == 1:
+                warnings.warn("dropout option adds dropout after all but last "
+                              "recurrent layer, so non-zero dropout expects "
+                              f"num_layers greater than 1, but got dropout={dropout} "
+                              f"and num_layers={num_layers}")
+
+        layers = [_LSTMLayer(self.input_size, self.hidden_size,
+                             self.bias, batch_first=False,
+                             bidirectional=self.bidirectional, **factory_kwargs)]
+        for layer in range(1, num_layers):
+            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
+                                     self.bias, batch_first=False,
+                                     bidirectional=self.bidirectional,
+                                     **factory_kwargs))
+        self.layers = torch.nn.ModuleList(layers)
+
+    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
+        if self.batch_first:
+            x = x.transpose(0, 1)
+
+        max_batch_size = x.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        if hidden is None:
+            zeros = torch.zeros(num_directions, max_batch_size,
+                                self.hidden_size, dtype=torch.float,
+                                device=x.device)
+            zeros.squeeze_(0)
+            if x.is_quantized:
+                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
+                                                  zero_point=0, dtype=x.dtype)
+            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
+        else:
+            hidden_non_opt = torch.jit._unwrap_optional(hidden)
+            if isinstance(hidden_non_opt[0], Tensor):
+                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
+                                               max_batch_size,
+                                               self.hidden_size)
+                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
+                                               max_batch_size,
+                                               self.hidden_size)
+                hxcx = [(hx[idx].squeeze(0), cx[idx].squeeze(0)) for idx in range(self.num_layers)]
+            else:
+                hxcx = hidden_non_opt
+
+        hx_list = []
+        cx_list = []
+        for idx, layer in enumerate(self.layers):
+            x, (h, c) = layer(x, hxcx[idx])
+            hx_list.append(torch.jit._unwrap_optional(h))
+            cx_list.append(torch.jit._unwrap_optional(c))
+        hx_tensor = torch.stack(hx_list)
+        cx_tensor = torch.stack(cx_list)
+
+        # We are creating another dimension for bidirectional case
+        # need to collapse it
+        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
+        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])
+
+        if self.batch_first:
+            x = x.transpose(0, 1)
+
+        return x, (hx_tensor, cx_tensor)
+
+    def _get_name(self):
+        return 'QuantizableLSTM'
+
+    @classmethod
+    def from_float(cls, other, qconfig=None):
+        assert isinstance(other, cls._FLOAT_MODULE)
+        assert (hasattr(other, 'qconfig') or qconfig)
+        observed = cls(other.input_size, other.hidden_size, other.num_layers,
+                       other.bias, other.batch_first, other.dropout,
+                       other.bidirectional)
+        observed.qconfig = getattr(other, 'qconfig', qconfig)
+        for idx in range(other.num_layers):
+            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
+                                                         batch_first=False)
+
+        # Prepare the model
+        if other.training:
+            observed.train()
+            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
+        else:
+            observed.eval()
+            observed = torch.ao.quantization.prepare(observed, inplace=True)
+        return observed
+
+    @classmethod
+    def from_observed(cls, other):
+        # The whole flow is float -> observed -> quantized
+        # This class does float -> observed only
+        raise NotImplementedError("It looks like you are trying to convert a "
+                                  "non-quantizable LSTM module. Please, see "
+                                  "the examples on quantizable LSTMs.")
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd57a11b784d7fc16ec305418426c1c23f6b0e39
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__init__.py
@@ -0,0 +1,38 @@
+from . import functional
+from .modules import *  # noqa: F403
+from .modules import MaxPool2d
+
+__all__ = [
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+    'DeQuantize',
+    'ELU',
+    'Embedding',
+    'EmbeddingBag',
+    'GroupNorm',
+    'Hardswish',
+    'InstanceNorm1d',
+    'InstanceNorm2d',
+    'InstanceNorm3d',
+    'LayerNorm',
+    'LeakyReLU',
+    'Linear',
+    'LSTM',
+    'MultiheadAttention',
+    'Quantize',
+    'ReLU6',
+    'Sigmoid',
+    'Softmax',
+    'Dropout',
+    'PReLU',
+    # Wrapper modules
+    'FloatFunctional',
+    'FXFloatFunctional',
+    'QFunctional',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e37528bdcf1a4ff91be2f8f9c77f0fdaffbbca8a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf33864f6c48e4f2a914f687c382bd313f94565c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc6df8afce25c62a5707136bc46cab16c49a83c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b342ef72a0f1335167c79039010f765779dd2849
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0992314f6465b6e79721c91f323e4a98fbfbd18b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py
@@ -0,0 +1,19 @@
+
+from .linear import Linear
+from .rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell
+from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+
+__all__ = [
+    'Linear',
+    'LSTM',
+    'GRU',
+    'LSTMCell',
+    'RNNCell',
+    'GRUCell',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97b32fce8c6cf380cf2d8a29428b1682d2b75f5d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19b99ec11b651a7eb4f93f05644b987f29d94d64
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12063b5dedb9af9083db51c28bc7c43c49c8645b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf6665b83f1d77e22a8ac86fc201618adbbea501
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9b078a18fc86ae4603396338ffbc774383b8b8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py
@@ -0,0 +1,399 @@
+r"""Dynamically quantized convolution modules."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch._ops import ops
+from torch.nn.common_types import _size_1_t
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
+import torch.ao.nn.quantized as nnq
+import warnings
+
+__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
+
+
+class Conv1d(nnq.Conv1d):
+    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.Conv1d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv1d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
+        >>> input = torch.randn(20, 16, 100)
+        >>> output = m(input)
+
+    """
+
+    _FLOAT_MODULE = nn.Conv1d
+    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
+    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: _size_1_t = 0,
+                 dilation: _size_1_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros',
+                 device=None,
+                 dtype=None,
+                 reduce_range=True):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = padding if isinstance(padding, str) else _single(padding)
+        dilation = _single(dilation)
+
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConv1d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 3:
+            raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range)
+
+
+class Conv2d(nnq.Conv2d):
+    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.Conv2d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv2d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With square kernels and equal stride
+        >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> # non-square kernels and unequal stride and with padding and dilation
+        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
+        >>> input = torch.randn(20, 16, 50, 100)
+        >>> output = m(input)
+
+    """
+    _FLOAT_MODULE = nn.Conv2d
+    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
+    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConv2d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv2d_dynamic(
+            input, self._packed_params, reduce_range)
+
+
+class Conv3d(nnq.Conv3d):
+    r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv3d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With square kernels and equal stride
+        >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
+        >>> # non-square kernels and unequal stride and with padding and dilation
+        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
+        >>> input = torch.randn(20, 16, 56, 56, 56)
+        >>> output = m(input)
+
+    """
+    _FLOAT_MODULE = nn.Conv3d
+    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
+    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
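+        # nnq.Conv3d exposes _init() for subclasses; call it directly so that
+        # transposed=False and output_padding=_triple(0) can be passed explicitly
+        # along with the usual convolution parameters.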
+        super()._init(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConv3d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv3d_dynamic(
+            input, self._packed_params, reduce_range)
+
+
+class ConvTranspose1d(nnq.ConvTranspose1d):
+    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose1d`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv1d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose1d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With square kernels and equal stride
+        >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> output = m(input)
+        >>> # exact output size can also be specified as an argument
+        >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(input)
+        >>> h.size()
+        torch.Size([1, 16, 6])
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose1d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConvTranspose1d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 3:
+            raise ValueError("Input shape must be `(N, C, L)`!")
+        return torch.ops.quantized.conv_transpose1d_dynamic(
+            input, self._packed_params, reduce_range)
+
+
+class ConvTranspose2d(nnq.ConvTranspose2d):
+    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose2d`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv2d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose2d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With square kernels and equal stride
+        >>> m = nndq.ConvTranspose2d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nndq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> output = m(input)
+        >>> # exact output size can also be specified as an argument
+        >>> downsample = nndq.Conv2d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nndq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(input)
+        >>> h.size()
+        torch.Size([1, 16, 6, 6])
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose2d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConvTranspose2d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        return ops.quantized.conv_transpose2d_dynamic(
+            input, self._packed_params, reduce_range)
+
+
+class ConvTranspose3d(nnq.ConvTranspose3d):
+    r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose3d`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv3d`.
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose3d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> # With cubic kernels and equal stride
+        >>> m = nndq.ConvTranspose3d(16, 33, 3, stride=2)
+        >>> # non-cubic kernels and unequal stride and with padding
+        >>> m = nndq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
+        >>> output = m(input)
+        >>> # exact output size can also be specified as an argument
+        >>> downsample = nndq.Conv3d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nndq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(input)
+        >>> h.size()
+        torch.Size([1, 16, 6, 6, 6])
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12, 12, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose3d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        warnings.warn(
+            "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
+                self._get_name()
+            )
+        )
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedConvTranspose3d'
+
+    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
+        return ops.quantized.conv_transpose3d_dynamic(
+            input, self._packed_params, reduce_range)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..407ecd9abbcf3c8d10057f5c6843de95ffabea4c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py
@@ -0,0 +1,132 @@
+import torch
+import torch.ao.nn.quantized as nnq
+from torch.ao.nn.quantized.modules.utils import _quantize_weight
+import torch.ao.nn.intrinsic as nni
+
+__all__ = [
+    "Linear",
+]
+
+
+class Linear(nnq.Linear):
+    r"""
+    A dynamic quantized linear module with floating point tensors as inputs and outputs.
+    We adopt the same interface as `torch.nn.Linear`; please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
+
+    Similar to :class:`torch.nn.Linear`, attributes will be randomly
+    initialized at module creation time and will be overwritten later.
+
+    Attributes:
+        weight (Tensor): the non-learnable quantized weights of the module which are of
+                         shape :math:`(\text{out\_features}, \text{in\_features})`.
+        bias (Tensor): the non-learnable floating point bias of the module of shape
+                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
+                       the values are initialized to zero.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> m = nn.quantized.dynamic.Linear(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    # version used in this class is different from the parent class nnq.Linear
+    _version = 4
+
+    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
+        super().__init__(in_features, out_features, bias_, dtype=dtype)
+        # We don't muck around with buffers or attributes or anything here
+        # to keep the module simple. *Everything* is simply a Python attribute.
+        # Serialization logic is explicitly handled in the serialization and
+        # deserialization methods below.
+        self.version = 4
+
+    def forward(self, x):
+        # Note that we can handle self.bias == None case.
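+        # The dynamic op computes the activation scale/zero_point from x at run
+        # time; for qint8 packed weights, serialization versions >= 4 also enable
+        # reduce_range to avoid overflow on some x86 (fbgemm) kernels.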
+        if self._packed_params.dtype == torch.qint8:
+            if self.version is None or self.version < 4:
+                Y = torch.ops.quantized.linear_dynamic(
+                    x, self._packed_params._packed_params)
+            else:
+                Y = torch.ops.quantized.linear_dynamic(
+                    x, self._packed_params._packed_params, reduce_range=True)
+        elif self._packed_params.dtype == torch.float16:
+            Y = torch.ops.quantized.linear_dynamic_fp16(
+                x, self._packed_params._packed_params)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
+        return Y.to(x.dtype)
+
+    def _get_name(self):
+        return 'DynamicQuantizedLinear'
+
+    def extra_repr(self):
+        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
+            self.in_features, self.out_features, self._packed_params.dtype
+        )
+        if self._packed_params.dtype == torch.qint8:
+            extra_repr_str += f', qscheme={self.weight().qscheme()}'
+        return extra_repr_str
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        self.version = version
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a dynamic quantized module from a float module or qparams_dict
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+                          utilities or provided by the user
+        """
+        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+                         torch.ao.nn.intrinsic.modules.fused.LinearReLU, torch.ao.nn.qat.dynamic.Linear]
+
+        assert type(mod) in float_modules, \
+            'nn.quantized.dynamic.Linear.from_float only works for one of ' + \
+            str([float_mod.__name__ for float_mod in float_modules])
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        if type(mod) == nni.LinearReLU:
+            mod = mod[0]
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
+            weight_observer = mod.qconfig.weight()
+        else:
+            # We have the circular import issues if we import the qconfig in the beginning of this file:
+            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+            # import until we need it.
+            from torch.ao.quantization.qconfig import default_dynamic_qconfig
+            weight_observer = default_dynamic_qconfig.weight()
+        dtype = weight_observer.dtype
+        assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
+            f"dynamic quantized linear are qint8 and float16 got: {dtype}"
+        weight_observer(mod.weight)
+        if dtype == torch.qint8:
+            qweight = _quantize_weight(mod.weight.float(), weight_observer)
+        elif dtype == torch.float16:
+            qweight = mod.weight.float()
+        else:
+            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
+        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
+        qlinear.set_weight_bias(qweight, mod.bias)
+        return qlinear
+
+    @classmethod
+    def from_reference(cls, ref_qlinear):
+        """ Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
+        module
+        Args:
+            ref_qlinear (Module): a reference quantized  module, either produced by
+            torch.ao.quantization functions or provided by the user
+        """
+        qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features, dtype=ref_qlinear.weight_dtype)
+        qweight = ref_qlinear.get_quantized_weight()
+        bias = ref_qlinear.bias
+        qlinear.set_weight_bias(qweight, bias)
+        return qlinear
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b305ce61818fefdfd366736d6f12ceb61765ac
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py
@@ -0,0 +1,1096 @@
+import numbers
+import warnings
+
+import torch
+import torch.nn as nn
+from torch import Tensor  # noqa: F401
+from torch._jit_internal import Tuple, Optional, List, Union, Dict  # noqa: F401
+from torch.nn.utils.rnn import PackedSequence
+from torch.ao.nn.quantized.modules.utils import _quantize_weight
+
+__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
+           'GRUCell', "apply_permutation"]
+
+
+def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
+    return tensor.index_select(dim, permutation)
+
+
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
+    warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
+    return _apply_permutation(tensor, permutation, dim)
+
+
+def pack_weight_bias(qweight, bias, dtype):
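+    # Prepack a (qint8-quantized or float16) weight together with its float bias
+    # into the opaque, backend-specific format consumed by the dynamic
+    # linear/RNN kernels.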
+
+    if dtype == torch.qint8:
+        # for each layer, for each direction we need to quantize and pack
+        # weights and pack parameters in this order:
+        #
+        #   w_ih, w_hh
+        packed_weight = \
+            torch.ops.quantized.linear_prepack(qweight, bias)
+
+        return packed_weight
+    else:
+        # for each layer, for each direction we need to quantize and pack
+        # weights and pack parameters in this order:
+        #
+        #   packed_ih, packed_hh, b_ih, b_hh
+        packed_weight = torch.ops.quantized.linear_prepack_fp16(
+            qweight, bias)
+
+        return packed_weight
+
+
+class PackedParameter(torch.nn.Module):
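+    # Thin wrapper around an opaque packed parameter (e.g. the torchbind
+    # cell-params object returned by the quantized prepack ops) so that it can
+    # live inside an nn.ModuleList and round-trip through state_dict via the
+    # hooks below.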
+    def __init__(self, param):
+        super().__init__()
+        self.param = param
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'param'] = self.param
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self.param = state_dict[prefix + 'param']
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+
+class RNNBase(torch.nn.Module):
+
+    _FLOAT_MODULE = nn.RNNBase
+
+    _version = 2
+
+    def __init__(self, mode, input_size, hidden_size,
+                 num_layers=1, bias=True, batch_first=False,
+                 dropout=0., bidirectional=False, dtype=torch.qint8):
+        super().__init__()
+
+        self.mode = mode
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.batch_first = batch_first
+        self.dropout = float(dropout)
+        self.bidirectional = bidirectional
+        self.dtype = dtype
+        self.version = 2
+        self.training = False
+        num_directions = 2 if bidirectional else 1
+
+        # "type: ignore" is required since ints and Numbers are not fully comparable
+        # https://github.com/python/mypy/issues/8566
+        if not isinstance(dropout, numbers.Number) \
+                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore[operator]
+            raise ValueError("dropout should be a number in range [0, 1] "
+                             "representing the probability of an element being "
+                             "zeroed")
+        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
+            warnings.warn("dropout option adds dropout after all but last "
+                          "recurrent layer, so non-zero dropout expects "
+                          f"num_layers greater than 1, but got dropout={dropout} and "
+                          f"num_layers={num_layers}")
+
+        if mode == 'LSTM':
+            gate_size = 4 * hidden_size
+        elif mode == 'GRU':
+            gate_size = 3 * hidden_size
+        else:
+            raise ValueError("Unrecognized RNN mode: " + mode)
+
+        _all_weight_values = []
+        for layer in range(num_layers):
+            for direction in range(num_directions):
+                layer_input_size = input_size if layer == 0 else hidden_size * num_directions
+
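+                # The weights and biases created here are random placeholders
+                # with dummy quantization parameters; the real observed and
+                # quantized values are installed later by from_float(),
+                # set_weight_bias(), or a state_dict load.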
+                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
+                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
+                b_ih = torch.randn(gate_size).to(torch.float)
+                b_hh = torch.randn(gate_size).to(torch.float)
+                if dtype == torch.qint8:
+                    w_ih = torch.quantize_per_tensor(w_ih, scale=0.1, zero_point=0, dtype=torch.qint8)
+                    w_hh = torch.quantize_per_tensor(w_hh, scale=0.1, zero_point=0, dtype=torch.qint8)
+                    packed_ih = \
+                        torch.ops.quantized.linear_prepack(w_ih, b_ih)
+                    packed_hh = \
+                        torch.ops.quantized.linear_prepack(w_hh, b_hh)
+                    if self.version is None or self.version < 2:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, b_ih, b_hh)
+                    else:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, b_ih, b_hh, True)
+                else:
+                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
+                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
+                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
+                        packed_ih, packed_hh)
+
+                _all_weight_values.append(PackedParameter(cell_params))
+        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
+
+    def _get_name(self):
+        return 'DynamicQuantizedRNN'
+
+    def extra_repr(self):
+        s = '{input_size}, {hidden_size}'
+        if self.num_layers != 1:
+            s += ', num_layers={num_layers}'
+        if self.bias is not True:
+            s += ', bias={bias}'
+        if self.batch_first is not False:
+            s += ', batch_first={batch_first}'
+        if self.dropout != 0:
+            s += ', dropout={dropout}'
+        if self.bidirectional is not False:
+            s += ', bidirectional={bidirectional}'
+        return s.format(**self.__dict__)
+
+    def __repr__(self):
+        # We don't want to show `ModuleList` children, hence custom
+        # `__repr__`. This is the same as nn.Module.__repr__, except the check
+        # for the `PackedParameter` and `nn.ModuleList`.
+        # You should still override `extra_repr` to add more info.
+        extra_lines = []
+        extra_repr = self.extra_repr()
+        # empty string will be split into list ['']
+        if extra_repr:
+            extra_lines = extra_repr.split('\n')
+        child_lines = []
+        for key, module in self._modules.items():
+            if isinstance(module, (PackedParameter, nn.ModuleList)):
+                continue
+            mod_str = repr(module)
+            mod_str = nn.modules.module._addindent(mod_str, 2)
+            child_lines.append('(' + key + '): ' + mod_str)
+        lines = extra_lines + child_lines
+
+        main_str = self._get_name() + '('
+        if lines:
+            # simple one-liner info, which most builtin Modules will use
+            if len(extra_lines) == 1 and not child_lines:
+                main_str += extra_lines[0]
+            else:
+                main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+        main_str += ')'
+        return main_str
+
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
+        expected_input_dim = 2 if batch_sizes is not None else 3
+        if input.dim() != expected_input_dim:
+            raise RuntimeError(
+                f'input must have {expected_input_dim} dimensions, got {input.dim()}')
+        if self.input_size != input.size(-1):
+            raise RuntimeError(
+                f'input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}')
+
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
+        if batch_sizes is not None:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input.size(0) if self.batch_first else input.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        expected_hidden_size = (self.num_layers * num_directions,
+                                mini_batch, self.hidden_size)
+        return expected_hidden_size
+
+    def check_hidden_size(
+        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+        msg: str = 'Expected hidden size {}, got {}'
+    ) -> None:
+        if hx.size() != expected_hidden_size:
+            raise RuntimeError(msg.format(
+                expected_hidden_size, list(hx.size())))
+
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
+        self.check_input(input, batch_sizes)
+        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+        self.check_hidden_size(hidden, expected_hidden_size,
+                               msg='Expected hidden size {}, got {}')
+
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
+        if permutation is None:
+            return hx
+        return _apply_permutation(hx, permutation)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        self.version = version
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    def set_weight_bias(self, weight_bias_dict):
+
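+        # Rebuild the packed cell parameters from a flat dict keyed by
+        # "weight_ih_l{layer}{suffix}" / "bias_ih_l{layer}{suffix}" (and their
+        # "hh" counterparts), mirroring the packing logic in __init__. This is
+        # the entry point used by from_reference().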
+        def weight_bias_name(ihhh, layer, suffix):
+            weight_name = f"weight_{ihhh}_l{layer}{suffix}"
+            bias_name = f"bias_{ihhh}_l{layer}{suffix}"
+            return weight_name, bias_name
+
+        num_directions = 2 if self.bidirectional else 1
+        # TODO: dedup with __init__ of RNNBase
+        _all_weight_values = []
+        for layer in range(self.num_layers):
+            for direction in range(num_directions):
+                suffix = "_reverse" if direction == 1 else ""
+                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
+                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
+                w_ih = weight_bias_dict[w_ih_name]
+                b_ih = weight_bias_dict[b_ih_name]
+                w_hh = weight_bias_dict[w_hh_name]
+                b_hh = weight_bias_dict[b_hh_name]
+                if w_ih.dtype == torch.qint8:
+                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
+                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
+                    if self.version is None or self.version < 2:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, b_ih, b_hh)
+                    else:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, b_ih, b_hh, True)
+                else:
+                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
+                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
+                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
+                        packed_ih, packed_hh)
+
+                _all_weight_values.append(PackedParameter(cell_params))
+        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) in {torch.nn.LSTM,
+                             torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
+        assert hasattr(
+            mod,
+            'qconfig'
+        ), 'Input float module must have qconfig defined'
+
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
+            weight_observer_method = mod.qconfig.weight
+        else:
+            # We have the circular import issues if we import the qconfig in the beginning of this file:
+            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+            # import until we need it.
+            from torch.ao.quantization.qconfig import default_dynamic_qconfig
+            weight_observer_method = default_dynamic_qconfig.weight
+
+        dtype = weight_observer_method().dtype
+        supported_scalar_types = [torch.qint8, torch.float16]
+        if dtype not in supported_scalar_types:
+            raise RuntimeError(f'Unsupported dtype for dynamic RNN quantization: {dtype}')
+        # RNNBase can be either LSTM or GRU
+        qRNNBase: Union[LSTM, GRU]
+        if mod.mode == 'LSTM':
+            qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
+                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
+        elif mod.mode == 'GRU':
+            qRNNBase = GRU(mod.input_size, mod.hidden_size, mod.num_layers,
+                           mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
+        else:
+            raise NotImplementedError('Only LSTM/GRU is supported for QuantizedRNN for now')
+
+        num_directions = 2 if mod.bidirectional else 1
+
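+        # The packing code below unconditionally reads bias_ih/bias_hh, so float
+        # modules created with bias=False are not supported here.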
+        assert mod.bias
+
+        _all_weight_values = []
+        for layer in range(qRNNBase.num_layers):
+            for direction in range(num_directions):
+                suffix = '_reverse' if direction == 1 else ''
+
+                def retrieve_weight_bias(ihhh):
+                    weight_name = f'weight_{ihhh}_l{layer}{suffix}'
+                    bias_name = f'bias_{ihhh}_l{layer}{suffix}'
+                    weight = getattr(mod, weight_name)
+                    bias = getattr(mod, bias_name)
+                    return weight, bias
+
+                weight_ih, bias_ih = retrieve_weight_bias('ih')
+                weight_hh, bias_hh = retrieve_weight_bias('hh')
+
+                if dtype == torch.qint8:
+                    def quantize_and_pack(w, b):
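+                        # Observe the float weight to collect min/max statistics,
+                        # quantize it with the resulting qparams, then prepack it
+                        # into the backend-specific format.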
+                        weight_observer = weight_observer_method()
+                        weight_observer(w)
+                        qweight = _quantize_weight(w.float(), weight_observer)
+                        packed_weight = \
+                            torch.ops.quantized.linear_prepack(qweight, b)
+                        return packed_weight
+                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
+                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
+                    if qRNNBase.version is None or qRNNBase.version < 2:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, bias_ih, bias_hh)
+                    else:
+                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
+                            packed_ih, packed_hh, bias_ih, bias_hh, True)
+
+                elif dtype == torch.float16:
+                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
+                        weight_ih.float(), bias_ih)
+                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
+                        weight_hh.float(), bias_hh)
+
+                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
+                        packed_ih, packed_hh)
+                else:
+                    raise RuntimeError('Unsupported dtype specified for dynamic quantized LSTM!')
+
+                _all_weight_values.append(PackedParameter(cell_params))
+        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)
+
+        return qRNNBase
+
+    def _weight_bias(self):
+        # Returns a dict of weights and biases
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
+        count = 0
+        num_directions = 2 if self.bidirectional else 1
+        for layer in range(self.num_layers):
+            for direction in range(num_directions):
+                suffix = '_reverse' if direction == 1 else ''
+                key_name1 = f'weight_ih_l{layer}{suffix}'
+                key_name2 = f'weight_hh_l{layer}{suffix}'
+                # packed weights are part of torchbind class, CellParamsSerializationType
+                # Within the packed weight class, the weight and bias are accessible as Tensors
+                packed_weight_bias = self._all_weight_values[count].param.__getstate__()[0][4]
+                weight_bias_dict['weight'][key_name1] = packed_weight_bias[0].__getstate__()[0][0]
+                weight_bias_dict['weight'][key_name2] = packed_weight_bias[1].__getstate__()[0][0]
+                key_name1 = f'bias_ih_l{layer}{suffix}'
+                key_name2 = f'bias_hh_l{layer}{suffix}'
+                weight_bias_dict['bias'][key_name1] = packed_weight_bias[0].__getstate__()[0][1]
+                weight_bias_dict['bias'][key_name2] = packed_weight_bias[1].__getstate__()[0][1]
+                count = count + 1
+        return weight_bias_dict
+
+    def get_weight(self):
+        return self._weight_bias()['weight']
+
+    def get_bias(self):
+        return self._weight_bias()['bias']
+
+
+class LSTM(RNNBase):
+    r"""
+    A dynamic quantized LSTM module with floating point tensors as inputs and outputs.
+    We adopt the same interface as `torch.nn.LSTM`; please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> rnn = nn.LSTM(10, 20, 2)
+        >>> input = torch.randn(5, 3, 10)
+        >>> h0 = torch.randn(2, 3, 20)
+        >>> c0 = torch.randn(2, 3, 20)
+        >>> output, (hn, cn) = rnn(input, (h0, c0))
+    """
+    _FLOAT_MODULE = nn.LSTM
+
+    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__('LSTM', *args, **kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedLSTM'
+
+    def forward_impl(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            zeros = torch.zeros(self.num_layers * num_directions,
+                                max_batch_size, self.hidden_size,
+                                dtype=input.dtype, device=input.device)
+            hx = (zeros, zeros)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes they are passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+
+        _all_params = ([m.param for m in self._all_weight_values])
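+        # use_dynamic=True selects the dynamically quantized LSTM kernels, so the
+        # activation quantization parameters are derived from the data on each
+        # call rather than being fixed ahead of time.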
+        if batch_sizes is None:
+            result = torch.quantized_lstm(input, hx, _all_params, self.bias, self.num_layers,
+                                          float(self.dropout), self.training, self.bidirectional,
+                                          self.batch_first, dtype=self.dtype, use_dynamic=True)
+        else:
+            result = torch.quantized_lstm(input, batch_sizes, hx, _all_params, self.bias,
+                                          self.num_layers, float(self.dropout), self.training,
+                                          self.bidirectional, dtype=self.dtype, use_dynamic=True)
+        output = result[0]
+        hidden = result[1:]
+
+        return output, hidden
+
+    @torch.jit.export
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        batch_sizes = None
+        max_batch_size = input.size(0) if self.batch_first else input.size(1)
+        sorted_indices = None
+        unsorted_indices = None
+
+        output, hidden = self.forward_impl(
+            input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch.jit.export
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
+        max_batch_size = int(batch_sizes[0])
+
+        output_, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+
+        output = PackedSequence(output_, batch_sizes,
+                                sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    # "type: ignore" is required due to issue #43072
+    def permute_hidden(  # type: ignore[override]
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
+        if permutation is None:
+            return hx
+        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
+
+    # "type: ignore" is required due to issue #43072
+    def check_forward_args(  # type: ignore[override]
+        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
+    ) -> None:
+        self.check_input(input, batch_sizes)
+        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+
+        self.check_hidden_size(hidden[0], expected_hidden_size,
+                               'Expected hidden[0] size {}, got {}')
+        self.check_hidden_size(hidden[1], expected_hidden_size,
+                               'Expected hidden[1] size {}, got {}')
+
+    @torch.jit.ignore
+    def forward(self, input, hx=None):
+        if isinstance(input, PackedSequence):
+            return self.forward_packed(input, hx)
+        else:
+            return self.forward_tensor(input, hx)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_mod):
+        assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 "
+        "exists in LSTM, may need to relax the assumption to support the use case"
+        qmod = cls(
+            ref_mod.input_size,
+            ref_mod.hidden_size,
+            ref_mod.num_layers,
+            ref_mod.bias,
+            ref_mod.batch_first,
+            ref_mod.dropout,
+            ref_mod.bidirectional,
+            # assuming there is layer 0, which should be OK
+            ref_mod.weight_ih_l0_dtype,
+        )
+        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
+        return qmod
+
+
+class GRU(RNNBase):
+    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
+
+
+    For each element in the input sequence, each layer computes the following
+    function:
+
+    .. math::
+        \begin{array}{ll}
+            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
+            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
+            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
+            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
+        \end{array}
+
+    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
+    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
+    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
+    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
+    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
+
+    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
+    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
+    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
+    variable which is :math:`0` with probability :attr:`dropout`.
+
+    Args:
+        input_size: The number of expected features in the input `x`
+        hidden_size: The number of features in the hidden state `h`
+        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+            would mean stacking two GRUs together to form a `stacked GRU`,
+            with the second GRU taking in outputs of the first GRU and
+            computing the final results. Default: 1
+        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+            Default: ``True``
+        batch_first: If ``True``, then the input and output tensors are provided
+            as (batch, seq, feature). Default: ``False``
+        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+            GRU layer except the last layer, with dropout probability equal to
+            :attr:`dropout`. Default: 0
+        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
+
+    Inputs: input, h_0
+        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
+          of the input sequence. The input can also be a packed variable length
+          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
+          for details.
+        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+          containing the initial hidden state for each element in the batch.
+          Defaults to zero if not provided. If the RNN is bidirectional,
+          num_directions should be 2, else it should be 1.
+
+    Outputs: output, h_n
+        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
+          containing the output features h_t from the last layer of the GRU,
+          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
+          given as the input, the output will also be a packed sequence.
+          For the unpacked case, the directions can be separated
+          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+          with forward and backward being direction `0` and `1` respectively.
+
+          Similarly, the directions can be separated in the packed case.
+        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+          containing the hidden state for `t = seq_len`
+
+          Like *output*, the layers can be separated using
+          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
+
+    Shape:
+        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
+          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
+        - Input2: :math:`(S, N, H_{out})` tensor
+          containing the initial hidden state for each element in the batch, where
+          :math:`S=\text{num\_layers} * \text{num\_directions}` and
+          :math:`H_{out}=\text{hidden\_size}`. Defaults to zero if not provided.
+          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
+        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
+        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
+          for each element in the batch
+
+    Attributes:
+        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
+            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
+            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
+        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
+            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
+        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
+            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
+        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
+            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
+
+    .. note::
+        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+        where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+    .. note::
+        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
+        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
+        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
+        `W` and addition of bias:
+
+        .. math::
+            \begin{aligned}
+                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
+            \end{aligned}
+
+        This is in contrast to the PyTorch implementation, which applies the
+        Hadamard product after :math:`W_{hn} h_{(t-1)}`
+
+        .. math::
+            \begin{aligned}
+                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
+            \end{aligned}
+
+        This implementation differs on purpose for efficiency.
+
+    .. include:: ../cudnn_persistent_rnn.rst
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> rnn = nn.GRU(10, 20, 2)
+        >>> input = torch.randn(5, 3, 10)
+        >>> h0 = torch.randn(2, 3, 20)
+        >>> output, hn = rnn(input, h0)
+    """
+    _FLOAT_MODULE = nn.GRU
+
+    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__('GRU', *args, **kwargs)
+
+    def _get_name(self):
+        return 'DynamicQuantizedGRU'
+
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
+        self.check_input(input, batch_sizes)
+        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+
+        self.check_hidden_size(hidden, expected_hidden_size,
+                               'Expected hidden size {}, got {}')
+
+    def forward_impl(
+        self, input: Tensor, hx: Optional[Tensor],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            zeros = torch.zeros(self.num_layers * num_directions,
+                                max_batch_size, self.hidden_size,
+                                dtype=input.dtype, device=input.device)
+            hx = zeros
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes they are passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+
+        _all_params = ([m.param for m in self._all_weight_values])
+        if batch_sizes is None:
+            result = torch.quantized_gru(input,
+                                         hx,
+                                         _all_params,
+                                         self.bias,
+                                         self.num_layers,
+                                         self.dropout,
+                                         self.training,
+                                         self.bidirectional,
+                                         self.batch_first)
+        else:
+            result = torch.quantized_gru(input,
+                                         batch_sizes,
+                                         hx,
+                                         _all_params,
+                                         self.bias,
+                                         self.num_layers,
+                                         self.dropout,
+                                         self.training,
+                                         self.bidirectional)
+        output = result[0]
+        hidden = result[1]
+
+        return output, hidden
+
+
+    @torch.jit.export
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        batch_sizes = None
+        max_batch_size = input.size(0) if self.batch_first else input.size(1)
+        sorted_indices = None
+        unsorted_indices = None
+
+        output, hidden = self.forward_impl(
+            input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch.jit.export
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tensor] = None
+    ) -> Tuple[PackedSequence, Tensor]:
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
+        max_batch_size = int(batch_sizes[0])
+        output_, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+
+        output = PackedSequence(output_, batch_sizes,
+                                sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def permute_hidden(
+        self, hx: Tensor, permutation: Optional[Tensor]
+    ) -> Tensor:
+        if permutation is None:
+            return hx
+        return _apply_permutation(hx, permutation)
+
+    @torch.jit.ignore
+    def forward(self, input, hx=None):
+        if isinstance(input, PackedSequence):
+            return self.forward_packed(input, hx)
+        else:
+            return self.forward_tensor(input, hx)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+    @classmethod
+    def from_reference(cls, ref_mod):
+        assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 "
+        "exists in LSTM, may need to relax the assumption to support the use case"
+        qmod = cls(
+            ref_mod.input_size,
+            ref_mod.hidden_size,
+            ref_mod.num_layers,
+            ref_mod.bias,
+            ref_mod.batch_first,
+            ref_mod.dropout,
+            ref_mod.bidirectional,
+            # assuming there is layer 0, which should be OK
+            ref_mod.weight_ih_l0_dtype,
+        )
+        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
+        return qmod
+
+class RNNCellBase(torch.nn.Module):
+    # _FLOAT_MODULE = nn.CellRNNBase
+    __constants__ = ['input_size', 'hidden_size', 'bias']
+
+    def __init__(self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+        self.weight_dtype = dtype
+        if bias:
+            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
+            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
+        else:
+            self.register_parameter('bias_ih', None)
+            self.register_parameter('bias_hh', None)
+
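+        # Placeholder weights: created randomly and, for qint8, quantized with
+        # dummy qparams (scale=1, zero_point=0) just so they can be prepacked;
+        # from_float()/from_reference() install properly observed and quantized
+        # weights afterwards.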
+        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
+        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
+        if dtype == torch.qint8:
+            weight_ih = torch.quantize_per_tensor(weight_ih, scale=1, zero_point=0, dtype=torch.qint8)
+            weight_hh = torch.quantize_per_tensor(weight_hh, scale=1, zero_point=0, dtype=torch.qint8)
+
+        if dtype == torch.qint8:
+            # for each layer, for each direction we need to quantize and pack
+            # weights and pack parameters in this order:
+            #
+            #   w_ih, w_hh
+            packed_weight_ih = \
+                torch.ops.quantized.linear_prepack(weight_ih, self.bias_ih)
+            packed_weight_hh = \
+                torch.ops.quantized.linear_prepack(weight_hh, self.bias_hh)
+        else:
+            # for each layer, for each direction we need to quantize and pack
+            # weights and pack parameters in this order:
+            #
+            #   packed_ih, packed_hh, b_ih, b_hh
+            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
+                weight_ih, self.bias_ih)
+            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
+                weight_hh, self.bias_hh)
+
+        self._packed_weight_ih = packed_weight_ih
+        self._packed_weight_hh = packed_weight_hh
+
+    def _get_name(self):
+        return 'DynamicQuantizedRNNBase'
+
+    def extra_repr(self):
+        s = '{input_size}, {hidden_size}'
+        if 'bias' in self.__dict__ and self.bias is not True:
+            s += ', bias={bias}'
+        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+            s += ', nonlinearity={nonlinearity}'
+        return s.format(**self.__dict__)
+
+    def check_forward_input(self, input):
+        if input.size(1) != self.input_size:
+            raise RuntimeError(
+                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}")
+
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
+        if input.size(0) != hx.size(0):
+            raise RuntimeError(
+                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}")
+
+        if hx.size(1) != self.hidden_size:
+            raise RuntimeError(
+                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}")
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) in {torch.nn.LSTMCell,
+                             torch.nn.GRUCell,
+                             torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \
+                                 only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'
+        assert hasattr(
+            mod, 'qconfig'), 'Input float module must have qconfig defined'
+
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
+            weight_observer_method = mod.qconfig.weight
+        else:
+            # We have the circular import issues if we import the qconfig in the beginning of this file:
+            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+            # import until we need it.
+            from torch.ao.quantization.qconfig import default_dynamic_qconfig
+            weight_observer_method = default_dynamic_qconfig.weight
+
+        dtype = weight_observer_method().dtype
+        supported_scalar_types = [torch.qint8, torch.float16]
+        if dtype not in supported_scalar_types:
+            raise RuntimeError(f'Unsupported dtype for dynamic RNN quantization: {dtype}')
+
+        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
+
+        if type(mod) == torch.nn.LSTMCell:
+            qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
+        elif type(mod) == torch.nn.GRUCell:
+            qRNNCellBase = GRUCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
+        elif type(mod) == torch.nn.RNNCell:
+            qRNNCellBase = RNNCell(mod.input_size, mod.hidden_size, bias=mod.bias, nonlinearity=mod.nonlinearity, dtype=dtype)
+        else:
+            raise NotImplementedError('Only LSTMCell, GRUCell and RNNCell \
+            are supported for QuantizedRNN for now')
+
+        assert mod.bias
+
+        def _observe_and_quantize_weight(weight):
+            if dtype == torch.qint8:
+                weight_observer = weight_observer_method()
+                weight_observer(weight)
+                qweight = _quantize_weight(weight.float(), weight_observer)
+                return qweight
+            else:
+                return weight.float()
+
+        qRNNCellBase._packed_weight_ih = pack_weight_bias(_observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype)
+        qRNNCellBase._packed_weight_hh = pack_weight_bias(_observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype)
+        return qRNNCellBase
+
+    @classmethod
+    def from_reference(cls, ref_mod):
+        assert hasattr(ref_mod, "weight_ih_dtype"), (
+            "We are assuming weight_ih exists in the reference module; "
+            "the assumption may need to be relaxed to support other use cases")
+        if hasattr(ref_mod, "nonlinearity"):
+            qmod = cls(
+                ref_mod.input_size,
+                ref_mod.hidden_size,
+                ref_mod.bias,
+                ref_mod.nonlinearity,
+                dtype=ref_mod.weight_ih_dtype
+            )
+        else:
+            qmod = cls(
+                ref_mod.input_size,
+                ref_mod.hidden_size,
+                ref_mod.bias,
+                dtype=ref_mod.weight_ih_dtype
+            )
+        weight_bias_dict = {
+            "weight": {
+                "weight_ih": ref_mod.get_quantized_weight_ih(),
+                "weight_hh": ref_mod.get_quantized_weight_hh(),
+            },
+            "bias": {
+                "bias_ih": ref_mod.bias_ih,
+                "bias_hh": ref_mod.bias_hh,
+            }
+        }
+        qmod.set_weight_bias(weight_bias_dict)
+        return qmod
+
+    def _weight_bias(self):
+        # Returns a dict of weights and biases
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
+        w1, b1 = self._packed_weight_ih.__getstate__()[0]
+        w2, b2 = self._packed_weight_hh.__getstate__()[0]
+        # TODO: these can be simplified to one level? e.g. using weight_ih as key
+        # directly
+        weight_bias_dict['weight']['weight_ih'] = w1
+        weight_bias_dict['weight']['weight_hh'] = w2
+        weight_bias_dict['bias']['bias_ih'] = b1
+        weight_bias_dict['bias']['bias_hh'] = b2
+        return weight_bias_dict
+
+    def get_weight(self):
+        return self._weight_bias()['weight']
+
+    def get_bias(self):
+        return self._weight_bias()['bias']
+
+    def set_weight_bias(self, weight_bias_dict):
+        # TODO: these can be simplified to one level? e.g. using weight_ih as key
+        # directly
+        self._packed_weight_ih = pack_weight_bias(
+            weight_bias_dict["weight"]["weight_ih"],
+            weight_bias_dict["bias"]["bias_ih"],
+            self.weight_dtype)
+        self._packed_weight_hh = pack_weight_bias(
+            weight_bias_dict["weight"]["weight_hh"],
+            weight_bias_dict["bias"]["bias_hh"],
+            self.weight_dtype)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + '_packed_weight_ih'] = self._packed_weight_ih
+        destination[prefix + '_packed_weight_hh'] = self._packed_weight_hh
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self._packed_weight_ih = state_dict.pop(prefix + '_packed_weight_ih')
+        self._packed_weight_hh = state_dict.pop(prefix + '_packed_weight_hh')
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+
+class RNNCell(RNNCellBase):
+    r"""An Elman RNN cell with tanh or ReLU non-linearity.
+    A dynamic quantized RNNCell module with floating point tensor as inputs and outputs.
+    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
+    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> rnn = nn.RNNCell(10, 20)
+        >>> input = torch.randn(6, 3, 10)
+        >>> hx = torch.randn(3, 20)
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+    """
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
+
+    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8):
+        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
+        self.nonlinearity = nonlinearity
+
+    def _get_name(self):
+        return 'DynamicQuantizedRNNCell'
+
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        self.check_forward_input(input)
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        self.check_forward_hidden(input, hx, '')
+        if self.nonlinearity == "tanh":
+            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
+                input, hx,
+                self._packed_weight_ih, self._packed_weight_hh,
+                self.bias_ih, self.bias_hh)
+        elif self.nonlinearity == "relu":
+            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
+                input, hx,
+                self._packed_weight_ih, self._packed_weight_hh,
+                self.bias_ih, self.bias_hh)
+        else:
+            ret = input  # TODO: remove when jit supports exception flow
+            raise RuntimeError(
+                f"Unknown nonlinearity: {self.nonlinearity}")
+        return ret
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+
+class LSTMCell(RNNCellBase):
+    r"""A long short-term memory (LSTM) cell.
+
+    A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs.
+    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
+    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> rnn = nn.LSTMCell(10, 20)
+        >>> input = torch.randn(6, 3, 10)
+        >>> hx = torch.randn(3, 20)
+        >>> cx = torch.randn(3, 20)
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx, cx = rnn(input[i], (hx, cx))
+        ...     output.append(hx)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]
+
+    def _get_name(self):
+        return 'DynamicQuantizedLSTMCell'
+
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
+        self.check_forward_input(input)
+        if hx is None:
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            hx = (zeros, zeros)
+        self.check_forward_hidden(input, hx[0], '[0]')
+        self.check_forward_hidden(input, hx[1], '[1]')
+        return torch.ops.quantized.quantized_lstm_cell_dynamic(
+            input, hx,
+            self._packed_weight_ih, self._packed_weight_hh,
+            self.bias_ih, self.bias_hh)
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
+
+
+class GRUCell(RNNCellBase):
+    r"""A gated recurrent unit (GRU) cell
+
+    A dynamic quantized GRUCell module with floating point tensor as inputs and outputs.
+    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
+    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP
+        >>> rnn = nn.GRUCell(10, 20)
+        >>> input = torch.randn(6, 3, 10)
+        >>> hx = torch.randn(3, 20)
+        >>> output = []
+        >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+    """
+
+    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
+        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)
+
+    def _get_name(self):
+        return 'DynamicQuantizedGRUCell'
+
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        self.check_forward_input(input)
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        self.check_forward_hidden(input, hx, '')
+        return torch.ops.quantized.quantized_gru_cell_dynamic(
+            input, hx,
+            self._packed_weight_ih, self._packed_weight_hh,
+            self.bias_ih, self.bias_hh,
+        )
+
+    @classmethod
+    def from_float(cls, mod):
+        return super().from_float(mod)
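+
+# A minimal usage sketch (illustrative only): the dynamic cell modules above are
+# normally produced by `torch.ao.quantization.quantize_dynamic` rather than
+# constructed directly. Assuming a CPU quantized engine (fbgemm or qnnpack):
+#
+#     >>> float_model = torch.nn.Sequential(torch.nn.LSTMCell(10, 20))
+#     >>> qmodel = torch.ao.quantization.quantize_dynamic(
+#     ...     float_model, {torch.nn.LSTMCell}, dtype=torch.qint8)
+#     >>> hx = cx = torch.zeros(3, 20)
+#     >>> hx, cx = qmodel[0](torch.randn(3, 10), (hx, cx))  # DynamicQuantizedLSTMCell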
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/functional.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91cbb1b603deaa8da3db139273a9cd7ca9f8452
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/functional.py
@@ -0,0 +1,644 @@
+r""" Functional interface (quantized)."""
+from typing import List, Optional
+import warnings
+
+import torch
+from torch import Tensor
+from torch.nn.modules.utils import _pair, _triple
+from torch.jit.annotations import BroadcastingList2
+
+from .modules.utils import _pair_from_first
+
+# Although some of the functions and docstrings are mirrored from torch.nn,
+# we keep them here so that future changes can be made independently.
+
+__all__ = [
+    "avg_pool2d",
+    "avg_pool3d",
+    "adaptive_avg_pool2d",
+    "adaptive_avg_pool3d",
+    "conv1d",
+    "conv2d",
+    "conv3d",
+    "interpolate",
+    "linear",
+    "max_pool1d",
+    "max_pool2d",
+    "celu",
+    "leaky_relu",
+    "hardtanh",
+    "hardswish",
+    "threshold",
+    "elu",
+    "hardsigmoid",
+    "clamp",
+    "upsample",
+    "upsample_bilinear",
+    "upsample_nearest",
+]
+
+def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
+               count_include_pad=True, divisor_override=None):
+    r"""
+    Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
+    :math:`sH \times sW` steps. The number of output features is equal to the number of
+    input planes.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape.
+
+    Args:
+        input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
+        kernel_size: size of the pooling region. Can be a single number or a
+          tuple `(kH, kW)`
+        stride: stride of the pooling operation. Can be a single number or a
+          tuple `(sH, sW)`. Default: :attr:`kernel_size`
+        padding: implicit zero paddings on both sides of the input. Can be a
+          single number or a tuple `(padH, padW)`. Default: 0
+        ceil_mode: when True, will use `ceil` instead of `floor` in the formula
+            to compute the output shape. Default: ``False``
+        count_include_pad: when True, will include the zero-padding in the
+            averaging calculation. Default: ``True``
+        divisor_override: if specified, it will be used as divisor, otherwise
+             size of the pooling region will be used. Default: None
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
+    return torch.nn.functional.avg_pool2d(input, kernel_size, stride, padding,
+                                          ceil_mode, count_include_pad,
+                                          divisor_override)
+
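+# A brief usage sketch for `avg_pool2d` above (assumption: quantized CPU
+# kernels are available; the output re-uses the input's scale and zero_point):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> x = torch.randn(1, 3, 8, 8)
+#     >>> qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
+#     >>> qy = qF.avg_pool2d(qx, kernel_size=2)
+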
+def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
+               count_include_pad=True, divisor_override=None):
+    r"""
+    Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size
+    :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
+    input planes.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    Args:
+        input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
+        kernel_size: size of the pooling region. Can be a single number or a
+          tuple `(kD, kH, kW)`
+        stride: stride of the pooling operation. Can be a single number or a
+          tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
+        padding: implicit zero paddings on both sides of the input. Can be a
+          single number or a tuple `(padD, padH, padW)`. Default: 0
+        ceil_mode: when True, will use `ceil` instead of `floor` in the formula
+            to compute the output shape. Default: ``False``
+        count_include_pad: when True, will include the zero-padding in the
+            averaging calculation. Default: ``True``
+        divisor_override: if specified, it will be used as divisor, otherwise
+             size of the pooling region will be used. Default: None
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
+    return torch.nn.functional.avg_pool3d(input, kernel_size, stride, padding,
+                                          ceil_mode, count_include_pad,
+                                          divisor_override)
+
+def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
+    r"""
+    Applies a 2D adaptive average pooling over a quantized input signal composed
+    of several quantized input planes.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape.
+
+    Args:
+        output_size: the target output size (single integer or
+                     double-integer tuple)
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!")
+    return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
+
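+# A brief usage sketch for `adaptive_avg_pool2d` above (assumption: quantized
+# CPU kernels are available; output qparams match the input's):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 0, torch.quint8)
+#     >>> qy = qF.adaptive_avg_pool2d(qx, output_size=(1, 1))
+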
+def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
+    r"""
+    Applies a 3D adaptive average pooling over a quantized input signal composed
+    of several quantized input planes.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape.
+
+    Args:
+        output_size: the target output size (single integer or
+                     triple-integer tuple)
+    """
+    if not input.is_quantized:
+        raise ValueError(
+            "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!")
+    return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
+
+def conv1d(input, weight, bias,
+           stride=1, padding=0, dilation=1, groups=1,
+           padding_mode='zeros',
+           scale=1.0, zero_point=0,
+           dtype=torch.quint8):
+    r"""
+    Applies a 1D convolution over a quantized 1D input composed of several input
+    planes.
+
+    See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape.
+
+    Args:
+        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
+        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)`
+        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
+        stride: the stride of the convolving kernel. Can be a single number or a
+          tuple `(sW,)`. Default: 1
+        padding: implicit paddings on both sides of the input. Can be a
+          single number or a tuple `(padW,)`. Default: 0
+        dilation: the spacing between kernel elements. Can be a single number or
+          a tuple `(dW,)`. Default: 1
+        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+          number of groups. Default: 1
+        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
+        scale: quantization scale for the output. Default: 1.0
+        zero_point: quantization zero_point for the output. Default: 0
+        dtype: quantization data type to use. Default: ``torch.quint8``
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> from torch.ao.nn.quantized import functional as qF
+        >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
+        >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
+        >>> bias = torch.randn(33, dtype=torch.float)
+        >>>
+        >>> scale, zero_point = 1.0, 0
+        >>> dtype_inputs = torch.quint8
+        >>> dtype_filters = torch.qint8
+        >>>
+        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
+        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
+        >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
+    """  # noqa: E501
+    if padding_mode != 'zeros':
+        raise NotImplementedError("Only zero-padding is supported!")
+    if input.dtype != torch.quint8:
+        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
+    if weight.dtype != torch.qint8:
+        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
+    if input.ndim != 3:
+        raise ValueError("Input shape must be `(N, C, L)`!")
+    stride = _pair_from_first(stride)
+    padding = _pair_from_first(padding)
+    dilation = _pair_from_first(dilation)
+
+    packed_params = torch.ops.quantized.conv1d_prepack(
+        weight, bias, stride, padding, dilation, groups)
+    return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
+
+def conv2d(input, weight, bias,
+           stride=1, padding=0, dilation=1, groups=1,
+           padding_mode='zeros',
+           scale=1.0, zero_point=0,
+           dtype=torch.quint8):
+    r"""
+    Applies a 2D convolution over a quantized 2D input composed of several input
+    planes.
+
+    See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape.
+
+    Args:
+        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
+        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
+        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
+        stride: the stride of the convolving kernel. Can be a single number or a
+          tuple `(sH, sW)`. Default: 1
+        padding: implicit paddings on both sides of the input. Can be a
+          single number or a tuple `(padH, padW)`. Default: 0
+        dilation: the spacing between kernel elements. Can be a single number or
+          a tuple `(dH, dW)`. Default: 1
+        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+          number of groups. Default: 1
+        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
+        scale: quantization scale for the output. Default: 1.0
+        zero_point: quantization zero_point for the output. Default: 0
+        dtype: quantization data type to use. Default: ``torch.quint8``
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> from torch.ao.nn.quantized import functional as qF
+        >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
+        >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
+        >>> bias = torch.randn(8, dtype=torch.float)
+        >>>
+        >>> scale, zero_point = 1.0, 0
+        >>> dtype_inputs = torch.quint8
+        >>> dtype_filters = torch.qint8
+        >>>
+        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
+        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
+        >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
+    """  # noqa: E501
+    if padding_mode != 'zeros':
+        raise NotImplementedError("Only zero-padding is supported!")
+    if input.dtype != torch.quint8:
+        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
+    if weight.dtype != torch.qint8:
+        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
+    if input.ndim != 4:
+        raise ValueError("Input shape must be `(N, C, H, W)`!")
+    stride = _pair(stride)
+    padding = _pair(padding)
+    dilation = _pair(dilation)
+
+    packed_params = torch.ops.quantized.conv2d_prepack(
+        weight, bias, stride, padding, dilation, groups)
+    return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)
+
+def conv3d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1,
+           padding_mode='zeros', scale=1.0, zero_point=0, dtype=torch.quint8):
+    r"""
+    Applies a 3D convolution over a quantized 3D input composed of several input
+    planes.
+
+    See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape.
+
+    Args:
+        input: quantized input tensor of shape
+          :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
+        weight: quantized filters of shape
+          :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
+        bias: **non-quantized** bias tensor of shape
+          :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
+        stride: the stride of the convolving kernel. Can be a single number or a
+          tuple `(sD, sH, sW)`. Default: 1
+        padding: implicit paddings on both sides of the input. Can be a
+          single number or a tuple `(padD, padH, padW)`. Default: 0
+        dilation: the spacing between kernel elements. Can be a single number or
+          a tuple `(dD, dH, dW)`. Default: 1
+        groups: split input into groups, :math:`\text{in\_channels}` should be
+          divisible by the number of groups. Default: 1
+        padding_mode: the padding mode to use. Only "zeros" is supported for
+          quantized convolution at the moment. Default: "zeros"
+        scale: quantization scale for the output. Default: 1.0
+        zero_point: quantization zero_point for the output. Default: 0
+        dtype: quantization data type to use. Default: ``torch.quint8``
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> from torch.ao.nn.quantized import functional as qF
+        >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
+        >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
+        >>> bias = torch.randn(8, dtype=torch.float)
+        >>>
+        >>> scale, zero_point = 1.0, 0
+        >>> dtype_inputs = torch.quint8
+        >>> dtype_filters = torch.qint8
+        >>>
+        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
+        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
+        >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
+    """  # noqa: E501
+    if padding_mode != 'zeros':
+        raise NotImplementedError("Only zero-padding is supported!")
+    if input.dtype != torch.quint8:
+        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
+    if weight.dtype != torch.qint8:
+        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
+    if input.ndim != 5:
+        raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+    stride = _triple(stride)
+    padding = _triple(padding)
+    dilation = _triple(dilation)
+
+    packed_params = torch.ops.quantized.conv3d_prepack(
+        weight, bias, stride, padding, dilation, groups)
+    return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point)
+
+def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
+    r"""Down/up samples the input to either the given :attr:`size` or the given
+    :attr:`scale_factor`
+
+    See :func:`torch.nn.functional.interpolate` for implementation details.
+
+    The input dimensions are interpreted in the form:
+    `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    .. note:: Only 2D/3D input is supported for quantized inputs
+
+    .. note:: Only the following modes are supported for the quantized inputs:
+
+        - `bilinear`
+        - `nearest`
+
+    Args:
+        input (Tensor): the input tensor
+        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+            output spatial size.
+        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
+        mode (str): algorithm used for upsampling:
+            ``'nearest'`` | ``'bilinear'``
+        align_corners (bool, optional): Geometrically, we consider the pixels of the
+            input and output as squares rather than points.
+            If set to ``True``, the input and output tensors are aligned by the
+            center points of their corner pixels, preserving the values at the corner pixels.
+            If set to ``False``, the input and output tensors are aligned by the corner
+            points of their corner pixels, and the interpolation uses edge value padding
+            for out-of-boundary values, making this operation *independent* of input size
+            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
+            is ``'bilinear'``.
+            Default: ``False``
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.interpolate' must be quantized!")
+    return torch.nn.functional.interpolate(input, size, scale_factor, mode,
+                                           align_corners)
+
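+# A brief usage sketch for `interpolate` above (assumption: quantized CPU
+# kernels are available; only 'nearest' and 'bilinear' modes are supported):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), 0.1, 0, torch.quint8)
+#     >>> qy = qF.interpolate(qx, scale_factor=2, mode='nearest')
+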
+def linear(
+    input: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
+    scale: Optional[float] = None, zero_point: Optional[int] = None
+) -> Tensor:
+    r"""
+    Applies a linear transformation to the incoming quantized data:
+    :math:`y = xA^T + b`.
+    See :class:`~torch.ao.nn.quantized.Linear`
+
+    .. note::
+
+      Current implementation packs weights on every call, which has penalty on performance.
+      If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`.
+
+    Args:
+      input (Tensor): Quantized input of type `torch.quint8`
+      weight (Tensor): Quantized weight of type `torch.qint8`
+      bias (Tensor): None or fp32 bias of type `torch.float`
+      scale (double): output scale. If None, derived from the input scale
+      zero_point (long): output zero point. If None, derived from the input zero_point
+
+    Shape:
+        - Input: :math:`(N, *, in\_features)` where `*` means any number of
+          additional dimensions
+        - Weight: :math:`(out\_features, in\_features)`
+        - Bias: :math:`(out\_features)`
+        - Output: :math:`(N, *, out\_features)`
+    """
+    if scale is None:
+        scale = input.q_scale()
+    if zero_point is None:
+        zero_point = input.q_zero_point()
+    _packed_params = torch.ops.quantized.linear_prepack(weight, bias)
+    return torch.ops.quantized.linear(input, _packed_params, scale, zero_point)
+
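+# A brief usage sketch for `linear` above (assumption: an fbgemm/qnnpack
+# qengine is active; weight is per-tensor qint8, activation is quint8):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 4), 0.1, 0, torch.quint8)
+#     >>> qw = torch.quantize_per_tensor(torch.randn(3, 4), 0.1, 0, torch.qint8)
+#     >>> qy = qF.linear(qx, qw, bias=torch.randn(3), scale=0.2, zero_point=0)
+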
+def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
+               ceil_mode=False, return_indices=False):
+    r"""Applies a 1D max pooling over a quantized input signal composed of
+    several quantized input planes.
+
+    .. note:: The input quantization parameters are propagated to the output.
+
+    See :class:`~torch.ao.nn.quantized.MaxPool1d` for details.
+    """
+    if return_indices:
+        raise NotImplementedError("return_indices is not yet implemented!")
+    if stride is None:
+        stride = torch.jit.annotate(List[int], [])
+    return torch.nn.functional.max_pool1d(input, kernel_size, stride, padding,
+                                          dilation, ceil_mode=ceil_mode, return_indices=return_indices)
+
+def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
+               ceil_mode=False, return_indices=False):
+    r"""Applies a 2D max pooling over a quantized input signal composed of
+    several quantized input planes.
+
+    .. note:: The input quantization parameters are propagated to the output.
+
+    See :class:`~torch.ao.nn.quantized.MaxPool2d` for details.
+    """
+    if return_indices:
+        raise NotImplementedError("return_indices is not yet implemented!")
+    if stride is None:
+        stride = torch.jit.annotate(List[int], [])
+    return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding,
+                                          dilation, ceil_mode=ceil_mode, return_indices=return_indices)
+
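+# A brief usage sketch for `max_pool2d` above (assumption: quantized CPU
+# kernels are available; output keeps the input's quantization parameters):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 0, torch.quint8)
+#     >>> qy = qF.max_pool2d(qx, kernel_size=2, stride=2)
+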
+def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor:
+    r"""celu(input, scale, zero_point, alpha=1.) -> Tensor
+
+    Applies the quantized CELU function element-wise.
+
+    .. math::
+        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))
+
+    Args:
+        input: quantized input
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.celu' must be quantized!")
+    return torch.ops.quantized.celu(input, scale, zero_point, alpha)
+
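+# A brief usage sketch for `celu` above (assumption: quantized CPU kernels are
+# available; `scale`/`zero_point` here set the output quantization parameters):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
+#     >>> qy = qF.celu(qx, scale=0.1, zero_point=128, alpha=1.0)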
+
+def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False,
+               scale: Optional[float] = None, zero_point: Optional[int] = None):
+    r"""
+    leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor
+
+    Quantized version of :func:`~torch.nn.functional.leaky_relu`.
+
+    Applies element-wise,
+    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
+
+    Args:
+        input: Quantized input
+        negative_slope: The slope of the negative input
+        inplace: Inplace modification of the input tensor
+        scale, zero_point: Scale and zero point of the output tensor.
+
+    See :class:`~torch.nn.LeakyReLU` for more details.
+    """
+    if scale is not None and zero_point is not None:
+        assert not inplace, "Cannot rescale with `inplace`"
+        output = torch._empty_affine_quantized(
+            input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype)
+        torch._C._nn.leaky_relu(input, negative_slope, out=output)
+        return output
+    if inplace:
+        result = torch._C._nn.leaky_relu_(input, negative_slope)
+    else:
+        result = torch._C._nn.leaky_relu(input, negative_slope)
+    return result
+
+def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
+    r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`.
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
+    if inplace:
+        return torch._C._nn.hardtanh_(input, min_val, max_val)
+    return torch._C._nn.hardtanh(input, min_val, max_val)
+
+def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor:
+    r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`.
+
+    Args:
+        input: quantized input
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.hardswish' must be quantized!")
+    return torch._ops.ops.quantized.hardswish(input, scale, zero_point)
+
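+# A brief usage sketch for `hardswish` above (assumption: quantized CPU
+# kernels are available; scale/zero_point describe the output tensor):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
+#     >>> qy = qF.hardswish(qx, scale=0.1, zero_point=128)
+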
+def threshold(input: Tensor, threshold: float, value: float) -> Tensor:
+    r"""Applies the quantized version of the threshold function element-wise:
+
+    .. math::
+        x = \begin{cases}
+                x & \text{if~} x > \text{threshold} \\
+                \text{value} & \text{otherwise}
+            \end{cases}
+
+    See :class:`~torch.nn.Threshold` for more details.
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.threshold' must be quantized!")
+    if threshold is None:
+        raise ValueError("Input to 'threshold' must be specified!")
+    if value is None:
+        raise ValueError("Input to 'value' must be specified!")
+    return torch._ops.ops.quantized.threshold(input, threshold, value)
+
+def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.) -> Tensor:
+    r"""This is the quantized version of :func:`~torch.nn.functional.elu`.
+
+    Args:
+        input: quantized input
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+        alpha: the alpha constant
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.elu' must be quantized!")
+    return torch.ops.quantized.elu(input, scale, zero_point, alpha)
+
+def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
+    r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`.
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
+    if inplace:
+        return torch._C._nn.hardsigmoid_(input)  # type: ignore[attr-defined]
+    return torch._C._nn.hardsigmoid(input)
+
+def clamp(input: Tensor, min_: float, max_: float) -> Tensor:
+    r"""float(input, min\_, max\_) -> Tensor
+
+    Applies the clamp function element-wise.
+    See :class:`~torch.ao.nn.quantized.clamp` for more details.
+
+    Args:
+        input: quantized input
+        min_: minimum value for clamping
+        max_: maximum value for clamping
+    """
+    if not input.is_quantized:
+        raise ValueError("Input to 'quantized.clamp' must be quantized!")
+    return torch.clamp(input, min_, max_)
+
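+# A brief usage sketch for `clamp` above (assumption: quantized CPU kernels are
+# available; min_/max_ are floating-point thresholds and the output keeps the
+# input's quantization parameters):
+#
+#     >>> from torch.ao.nn.quantized import functional as qF
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
+#     >>> qy = qF.clamp(qx, -0.5, 0.5)
+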
+def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
+    r"""Upsamples the input to either the given :attr:`size` or the given
+    :attr:`scale_factor`
+
+    .. warning::
+        This function is deprecated in favor of
+        :func:`torch.ao.nn.quantized.functional.interpolate`.
+        This is equivalent to ``nn.quantized.functional.interpolate(...)``.
+
+    See :func:`torch.nn.functional.interpolate` for implementation details.
+
+    The input dimensions are interpreted in the form:
+    `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    .. note:: Only 2D input is supported for quantized inputs
+
+    .. note:: Only the following modes are supported for the quantized inputs:
+
+        - `bilinear`
+        - `nearest`
+
+    Args:
+        input (Tensor): quantized input tensor
+        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+            output spatial size.
+        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer.
+        mode (str): algorithm used for upsampling:
+            ``'nearest'`` | ``'bilinear'``
+        align_corners (bool, optional): Geometrically, we consider the pixels of the
+            input and output as squares rather than points.
+            If set to ``True``, the input and output tensors are aligned by the
+            center points of their corner pixels, preserving the values at the corner pixels.
+            If set to ``False``, the input and output tensors are aligned by the corner
+            points of their corner pixels, and the interpolation uses edge value padding
+            for out-of-boundary values, making this operation *independent* of input size
+            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
+            is ``'bilinear'``.
+            Default: ``False``
+
+    .. warning::
+        With ``align_corners = True``, the linearly interpolating modes
+        (`bilinear`) don't proportionally align the
+        output and input pixels, and thus the output values can depend on the
+        input size. This was the default behavior for these modes up to version
+        0.3.1. Since then, the default behavior is ``align_corners = False``.
+        See :class:`~torch.nn.Upsample` for concrete examples on how this
+        affects the outputs.
+    """
+    warnings.warn("nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.")
+    return interpolate(input, size, scale_factor, mode, align_corners)
+
+def upsample_bilinear(input, size=None, scale_factor=None):
+    r"""Upsamples the input, using bilinear upsampling.
+
+    .. warning::
+        This function is deprecated in favor of
+        :func:`torch.ao.nn.quantized.functional.interpolate`.
+        This is equivalent to
+        ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    .. note:: Only 2D inputs are supported
+
+    Args:
+        input (Tensor): quantized input
+        size (int or Tuple[int, int]): output spatial size.
+        scale_factor (int or Tuple[int, int]): multiplier for spatial size
+    """
+    # DeprecationWarning is ignored by default
+    warnings.warn("nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.")
+    return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True)
+
+def upsample_nearest(input, size=None, scale_factor=None):
+    r"""Upsamples the input, using nearest neighbours' pixel values.
+
+    .. warning::
+        This function is deprecated in favor of
+        :func:`torch.ao.nn.quantized.functional.interpolate`.
+        This is equivalent to ``nn.quantized.functional.interpolate(..., mode='nearest')``.
+
+    .. note:: The input quantization parameters propagate to the output.
+
+    .. note:: Only 2D inputs are supported
+
+    Args:
+        input (Tensor): quantized input
+        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
+            size.
+        scale_factor (int): multiplier for spatial size. Has to be an integer.
+    """
+    # DeprecationWarning is ignored by default
+    warnings.warn("nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.")
+    return interpolate(input, size, scale_factor, mode='nearest')
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd65cb5415fe4b5152e05957b92eefd97b49d63
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__init__.py
@@ -0,0 +1,131 @@
+import torch
+
+# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable`
+# packages. However, the `quantizable` package uses "lazy imports"
+# to avoid circular dependency.
+# Hence we need to include it here to make sure it is resolved before
+# they are used in the modules.
+import torch.ao.nn.quantizable
+
+from torch.nn.modules.pooling import MaxPool2d
+
+from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
+from .dropout import Dropout
+from .batchnorm import BatchNorm2d, BatchNorm3d
+from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \
+    InstanceNorm2d, InstanceNorm3d
+from .conv import Conv1d, Conv2d, Conv3d
+from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+from .linear import Linear
+from .embedding_ops import Embedding, EmbeddingBag
+from .rnn import LSTM
+
+from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
+
+__all__ = [
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+    'DeQuantize',
+    'ELU',
+    'Embedding',
+    'EmbeddingBag',
+    'GroupNorm',
+    'Hardswish',
+    'InstanceNorm1d',
+    'InstanceNorm2d',
+    'InstanceNorm3d',
+    'LayerNorm',
+    'LeakyReLU',
+    'Linear',
+    'LSTM',
+    'MultiheadAttention',
+    'Quantize',
+    'ReLU6',
+    'Sigmoid',
+    'Softmax',
+    'Dropout',
+    'PReLU',
+    # Wrapper modules
+    'FloatFunctional',
+    'FXFloatFunctional',
+    'QFunctional',
+]
+
+class Quantize(torch.nn.Module):
+    r"""Quantizes an incoming tensor
+
+    Args:
+     `scale`: scale of the output Quantized Tensor
+     `zero_point`: zero_point of output Quantized Tensor
+     `dtype`: data type of output Quantized Tensor
+     `factory_kwargs`: Dictionary of kwargs used for configuring initialization
+         of internal buffers. Currently, `device` and `dtype` are supported.
+         Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}`
+         will initialize internal buffers as type `torch.float64` on the current CUDA device.
+         Note that `dtype` only applies to floating-point buffers.
+
+    Examples::
+        >>> t = torch.tensor([[1., -1.], [1., -1.]])
+        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
+        >>> qm = Quantize(scale, zero_point, dtype)
+        >>> # xdoctest: +SKIP
+        >>> qt = qm(t)
+        >>> print(qt)
+        tensor([[ 1., -1.],
+                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        super().__init__()
+        self.register_buffer('scale', torch.tensor([scale], **factory_kwargs))
+        self.register_buffer('zero_point',
+                             torch.tensor([zero_point], dtype=torch.long,
+                                          **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
+        self.dtype = dtype
+
+    def forward(self, X):
+        return torch.quantize_per_tensor(X, float(self.scale),
+                                         int(self.zero_point), self.dtype)
+
+    @staticmethod
+    def from_float(mod):
+        assert hasattr(mod, 'activation_post_process')
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype)
+
+    def extra_repr(self):
+        return f'scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}'
+
+
+class DeQuantize(torch.nn.Module):
+    r"""Dequantizes an incoming tensor
+
+    Examples::
+        >>> input = torch.tensor([[1., -1.], [1., -1.]])
+        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
+        >>> qm = Quantize(scale, zero_point, dtype)
+        >>> # xdoctest: +SKIP
+        >>> quantized_input = qm(input)
+        >>> dqm = DeQuantize()
+        >>> dequantized = dqm(quantized_input)
+        >>> print(dequantized)
+        tensor([[ 1., -1.],
+                [ 1., -1.]], dtype=torch.float32)
+    """
+
+    def forward(self, Xq):
+        return Xq.dequantize()
+
+    @staticmethod
+    def from_float(mod):
+        return DeQuantize()
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8eb16b0882963cd4151a846c098a3e629fc2ba6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e14731901c9be9c9309ea18edb29a019dbace8cd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eaefaf306a5618e8303f6dbfdc409815acda3a4d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8682f6872ee504869f523fecfb6f2d2907936572
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da9707959dbffccefcd07ab8c0d997d886842e61
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bacc8b57e5e6877b7b7bcc0d860496aeae5f9c26
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..924b5a5e3811dae0d22a7dad8899c47ec05813a5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..becb30764e289271fd0a1bd13533d8fecf4c9eec
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cc9d5e8491e49051cf4b0e84bbfbb240af12f68
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfc14a983b768b3e268c30a00aab04e3fa5ab50a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27ab7a8041cc6a8c6ac74753db50fa20391280a1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..dae043965f25978d786209fa86b454c70d040b28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/activation.py
@@ -0,0 +1,302 @@
+import torch
+from warnings import warn
+__all__ = [
+    "ReLU6",
+    "Hardswish",
+    "ELU",
+    "LeakyReLU",
+    "Sigmoid",
+    "Softmax",
+    "MultiheadAttention",
+    "PReLU"
+]
+
+class ReLU6(torch.nn.ReLU):
+    r"""Applies the element-wise function:
+
+    :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
+    zero_point, and :math:`q(6)` is the quantized representation of number 6.
+
+    Args:
+        inplace: can optionally do the operation in-place. Default: ``False``
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: ../scripts/activation_images/ReLU6.png
+
+    Examples::
+
+        >>> m = nn.quantized.ReLU6()
+        >>> input = torch.randn(2)
+        >>> # xdoctest: +SKIP
+        >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
+        >>> output = m(input)
+    """
+    def __init__(self, inplace=False):
+        super().__init__(inplace)
+        self.inplace = inplace
+
+    def forward(self, input):
+        return torch.ops.quantized.relu6(input, self.inplace)
+
+    def _get_name(self):
+        return 'QuantizedReLU6'
+
+    @staticmethod
+    def from_float(mod):
+        return ReLU6(mod.inplace)
+
+class Hardswish(torch.nn.Hardswish):
+    r"""This is the quantized version of :class:`~torch.nn.Hardswish`.
+
+    Args:
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+    """
+    def __init__(self, scale, zero_point, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedHardswish'
+
+    @staticmethod
+    def from_float(mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        return Hardswish(float(scale), int(zero_point))
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(float(scale), int(zero_point))
+
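+# A brief usage sketch for the quantized Hardswish module above (assumption:
+# quantized CPU kernels are available; scale/zero_point describe the output):
+#
+#     >>> m = Hardswish(scale=0.1, zero_point=128)
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
+#     >>> qy = m(qx)
+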
+class ELU(torch.nn.ELU):
+    r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
+
+    Args:
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+        alpha: the alpha constant
+    """
+    def __init__(self, scale, zero_point, alpha=1.):
+        super().__init__(alpha)
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def forward(self, input):
+        return torch.ao.nn.quantized.functional.elu(
+            input, self.scale, self.zero_point, self.alpha)
+
+    def _get_name(self):
+        return 'QuantizedELU'
+
+    @staticmethod
+    def from_float(mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        return ELU(float(scale), int(zero_point), mod.alpha)
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(float(scale), int(zero_point), mod.alpha)
+
+class LeakyReLU(torch.nn.LeakyReLU):
+    r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
+
+    Args:
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+        negative_slope: Controls the angle of the negative slope. Default: 1e-2
+    """
+    def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
+                 inplace: bool = False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(negative_slope, inplace)
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.leaky_relu(
+            input, self.negative_slope, self.inplace, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedLeakyReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
+
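+# A brief usage sketch for the quantized LeakyReLU module above (assumption:
+# quantized CPU kernels are available):
+#
+#     >>> m = LeakyReLU(scale=0.1, zero_point=128, negative_slope=0.05)
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
+#     >>> qy = m(qx)
+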
+class Sigmoid(torch.nn.Sigmoid):
+    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
+
+    Args:
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+    """
+
+    def __init__(self, output_scale: float, output_zero_point: int):
+        super().__init__()
+        self.output_scale = output_scale
+        self.output_zero_point = output_zero_point
+
+    def forward(self, input):
+        return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
+        return cls(float(output_scale), int(output_zero_point))
+
+class Softmax(torch.nn.Softmax):
+    r"""This is the quantized version of :class:`~torch.nn.Softmax`.
+
+    Args:
+        dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
+        scale: quantization scale of the output tensor
+        zero_point: quantization zero point of the output tensor
+    """
+    def __init__(self, dim=None, scale=1.0, zero_point=0):
+        super().__init__()
+        self.dim = dim
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def forward(self, input):
+        dim = self.dim
+        if dim is None:
+            stacklevel = 3
+            # Note: adding the mypy ignore on _get_softmax_dim seems less bad
+            # than making `_get_softmax_dim` an official API.
+            dim = torch.nn.functional._get_softmax_dim(  # type: ignore[attr-defined]
+                "softmax", input.dim(), stacklevel)
+        return torch.ops.quantized.softmax(
+            input, dim, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedSoftmax'
+
+    @staticmethod
+    def from_float(mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        return Softmax(mod.dim, float(scale), int(zero_point))
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(mod.dim, float(scale), int(zero_point))
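+
+# A minimal usage sketch for QuantizedSoftmax (illustrative only; it assumes the
+# torch.ops.quantized.softmax kernel is available in the active backend). Softmax
+# outputs lie in [0, 1], so scale=1/256 with zero_point=0 is a reasonable quint8
+# output parametrization:
+#
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 5), scale=0.1, zero_point=128, dtype=torch.quint8)
+#     >>> m = Softmax(dim=1, scale=1.0 / 256.0, zero_point=0)
+#     >>> qy = m(qx)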
+
+
+class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
+    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
+
+    def _get_name(self):
+        return "QuantizedMultiheadAttention"
+
+    @classmethod
+    def from_float(cls, other):
+        # The whole flow is float -> observed -> quantized
+        # This class does observed -> quantized only
+        raise NotImplementedError("It looks like you are trying to convert a "
+                                  "non-observed MHA module. Please see "
+                                  "the examples on quantizable MHAs.")
+
+    @classmethod
+    def from_observed(cls, other):
+        converted = torch.ao.quantization.convert(other, mapping=None,
+                                                  inplace=False,
+                                                  remove_qconfig=True,
+                                                  convert_custom_config_dict=None)
+        converted.__class__ = cls
+        # Remove the parameters for the bias_k and bias_v to quantize them
+        # TODO: This is a potential source of accuracy drop.
+        #       quantized cat takes the scale and zp of the first
+        #       element, which might lose the precision in the bias_k
+        #       and the bias_v (which are cat'ed with k/v being first).
+        if converted.bias_k is not None:
+            bias_k = converted._parameters.pop('bias_k')
+            sc, zp = torch._choose_qparams_per_tensor(bias_k,
+                                                      reduce_range=False)
+            bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
+            setattr(converted, 'bias_k', bias_k)  # noqa: B010
+
+        if converted.bias_v is not None:
+            bias_v = converted._parameters.pop('bias_v')
+            sc, zp = torch._choose_qparams_per_tensor(bias_v,
+                                                      reduce_range=False)
+            bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
+            setattr(converted, 'bias_v', bias_v)  # noqa: B010
+
+        del converted.in_proj_weight
+        del converted.in_proj_bias
+
+        return converted
+
+class PReLU(torch.nn.Module):
+    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
+
+    Args:
+        output_scale: quantization scale of the output tensor
+        output_zero_point: quantization zero point of the output tensor
+        num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
+    """
+    def __init__(self, output_scale: float, output_zero_point: int,
+                 num_parameters: int = 1) -> None:
+        super().__init__()
+        self.num_parameters = num_parameters
+        self.scale = output_scale
+        self.zero_point = output_zero_point
+        w = torch.randn(num_parameters, dtype=torch.float)
+        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
+        self.set_weight(qw)
+
+    def set_weight(self, w: torch.Tensor) -> None:
+        self.weight = w
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedPReLU'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+        float_wt = mod.weight.float()
+        observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
+        wt_scale, wt_zp = observer.calculate_qparams()
+        qweight = torch.quantize_per_tensor(
+            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+        qprelu.set_weight(qweight)
+        return qprelu
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
+        float_wt = mod.weight.float()
+        observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
+        wt_scale, wt_zp = observer.calculate_qparams()
+        qweight = torch.quantize_per_tensor(
+            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
+        qprelu.set_weight(qweight)
+        return qprelu
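+
+# A minimal usage sketch for QuantizedPReLU (illustrative only). A directly
+# constructed module holds a randomly initialized quantized weight, so in a real
+# workflow the module is normally produced via from_float / from_reference:
+#
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1, zero_point=128, dtype=torch.quint8)
+#     >>> m = PReLU(output_scale=0.1, output_zero_point=128, num_parameters=1)
+#     >>> qy = m(qx)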
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cfc51ccf73c2ab44571604f7759ed008c0be29b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/batchnorm.py
@@ -0,0 +1,106 @@
+import torch
+import torch.ao.nn.intrinsic as nni
+
+__all__ = [
+    "BatchNorm2d",
+    "BatchNorm3d"
+]
+
+class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
+        self.register_buffer('scale', torch.tensor(1.0, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))
+
+    @staticmethod
+    def from_float(cls, mod):
+        activation_post_process = mod.activation_post_process
+        if type(mod) == cls._NNI_BN_RELU_MODULE:
+            mod = mod[0]
+        scale, zero_point = activation_post_process.calculate_qparams()
+        new_mod = cls(mod.num_features, mod.eps)
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+        new_mod.running_mean = mod.running_mean
+        new_mod.running_var = mod.running_var
+        new_mod.scale = scale
+        new_mod.zero_point = zero_point
+        return new_mod
+
+    @classmethod
+    def from_reference(cls, bn, output_scale, output_zero_point):
+        qbn = cls(
+            bn.num_features,
+            bn.eps,
+            bn.momentum,
+            device=bn.weight.device,
+            dtype=bn.weight.dtype
+        )
+        qbn.weight = bn.weight
+        qbn.bias = bn.bias
+        qbn.running_mean = bn.running_mean
+        qbn.running_var = bn.running_var
+        qbn.scale = output_scale
+        qbn.zero_point = output_zero_point
+        return qbn
+
+class BatchNorm2d(_BatchNorm):
+    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
+    """
+
+    _NNI_BN_RELU_MODULE = nni.BNReLU2d
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedBatchNorm2d'
+
+    def _check_input_dim(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # disabling this since this is not symbolically traceable
+        # self._check_input_dim(input)
+        return torch.ops.quantized.batch_norm2d(
+            input, self.weight, self.bias, self.running_mean,
+            self.running_var, self.eps, self.scale, self.zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        return _BatchNorm.from_float(cls, mod)
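+
+# A minimal usage sketch for QuantizedBatchNorm2d (illustrative only; in practice
+# the module is produced by from_float / from_reference, which copy the float
+# statistics and the observed output qparams instead of the defaults used here):
+#
+#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), scale=0.1, zero_point=128, dtype=torch.quint8)
+#     >>> m = BatchNorm2d(3)
+#     >>> qy = m(qx)  # output uses the default scale=1.0, zero_point=0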
+
+class BatchNorm3d(_BatchNorm):
+    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
+    """
+
+    _NNI_BN_RELU_MODULE = nni.BNReLU3d
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedBatchNorm3d'
+
+    def _check_input_dim(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # disabling this since this is not symbolically traceable
+        # self._check_input_dim(input)
+        return torch.ops.quantized.batch_norm3d(
+            input, self.weight, self.bias, self.running_mean,
+            self.running_var, self.eps, self.scale, self.zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        return _BatchNorm.from_float(cls, mod)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..0234874aa6b1027a9cf3cfe505fe1fa076c38e68
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/conv.py
@@ -0,0 +1,945 @@
+r"""Quantized convolution modules."""
+
+from typing import Optional, List, TypeVar
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+
+from torch._ops import ops
+from torch.nn.common_types import _size_1_t
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.nn.utils import fuse_conv_bn_weights
+
+from .utils import _quantize_weight, WeightedQuantizedModule
+
+__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
+
+_SUPPORTED_PADDING = {
+    'zeros',
+    'reflect'
+}
+
+
+def _reverse_repeat_padding(padding: List[int]) -> List[int]:
+    _reversed_padding_repeated_twice: List[int] = []
+    N = len(padding)
+    for idx in range(N):
+        for _ in range(2):
+            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
+    return _reversed_padding_repeated_twice
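+
+# For example, _reverse_repeat_padding([1, 2]) returns [2, 2, 1, 1]: the last
+# spatial dimension comes first and each amount is repeated twice, matching the
+# argument layout expected by F.pad.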
+
+
+class _ConvNd(WeightedQuantizedModule):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        # All subclasses have this signature - See PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros',
+              device=None,
+              dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+
+        if in_channels % groups != 0:
+            raise ValueError('in_channels must be divisible by groups')
+        if out_channels % groups != 0:
+            raise ValueError('out_channels must be divisible by groups')
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.transposed = transposed
+        self.output_padding = output_padding
+        self.groups = groups
+        if padding_mode not in _SUPPORTED_PADDING:
+            raise ValueError(f"'padding_mode' {padding_mode} is not supported by quantized convolution")
+        self.padding_mode = padding_mode
+        # Initialize as NCHW. set_weight will internally transpose to NHWC.
+        if self.transposed:
+            weight_shape = [in_channels, out_channels // self.groups]
+        else:
+            weight_shape = [out_channels, in_channels // self.groups]
+        qweight = torch._empty_affine_quantized(
+            weight_shape + list(kernel_size),
+            scale=1, zero_point=0, dtype=torch.qint8,
+            **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
+        bias_float = (
+            torch.zeros(out_channels, dtype=torch.float,
+                        **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
+
+        self.set_weight_bias(qweight, bias_float)
+        self.scale = 1.0
+        self.zero_point = 0
+
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
+    def extra_repr(self):
+        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+             ', stride={stride}, scale={scale}, zero_point={zero_point}')
+        if self.padding != (0,) * len(self.padding):
+            s += ', padding={padding}'
+        if self.dilation != (1,) * len(self.dilation):
+            s += ', dilation={dilation}'
+        if self.output_padding != (0,) * len(self.output_padding):
+            s += ', output_padding={output_padding}'
+        if self.groups != 1:
+            s += ', groups={groups}'
+        if self.bias() is None:
+            s += ', bias=False'
+        return s.format(**self.__dict__)
+
+    # ===== Serialization methods =====
+    # The special consideration here is that we have to unpack the weights into
+    # their regular QTensor form for serialization. Packed weights should not
+    # live outside the process in which they were created; rather, they should be
+    # derived from the QTensor weight.
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #
+    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
+    #   self
+    #   |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        (w, b) = self._weight_bias()
+        destination[prefix + 'weight'] = w
+        destination[prefix + 'bias'] = b
+        destination[prefix + 'scale'] = torch.tensor(self.scale)
+        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
+
+    @torch.jit.export
+    def __getstate__(self):
+        (w, b) = self._weight_bias()
+        return (
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.transposed,
+            self.output_padding,
+            self.groups,
+            self.padding_mode,
+            w,
+            b,
+            self.scale,
+            self.zero_point,
+            self.training
+        )
+
+    # ===== Deserialization methods =====
+    # Counterpart to the serialization methods, we must pack the serialized
+    # QTensor weight into its packed format for use by the FBGEMM ops.
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self.set_weight_bias(
+            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
+        state_dict.pop(prefix + 'weight')
+        state_dict.pop(prefix + 'bias')
+        self.scale = float(state_dict[prefix + 'scale'])
+        state_dict.pop(prefix + 'scale')
+        self.zero_point = int(state_dict[prefix + 'zero_point'])
+        state_dict.pop(prefix + 'zero_point')
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False, missing_keys,
+            unexpected_keys, error_msgs)
+
+    @torch.jit.export
+    def __setstate__(self, state):
+        self.in_channels = state[0]
+        self.out_channels = state[1]
+        self.kernel_size = state[2]
+        self.stride = state[3]
+        self.padding = state[4]
+        self.dilation = state[5]
+        self.transposed = state[6]
+        self.output_padding = state[7]
+        self.groups = state[8]
+        self.padding_mode = state[9]
+        self.set_weight_bias(state[10], state[11])
+        self.scale = state[12]
+        self.zero_point = state[13]
+        self.training = state[14]
+
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        torch.nn.Module.__init__(new_instance)
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
+    def __copy__(self):
+        return self.__deepcopy__({})
+
+    @classmethod
+    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
+        r"""Creates a qconv object and returns it.
+        """
+        if weight_post_process is None:
+            weight_post_process = mod.qconfig.weight()
+        weight_post_process(mod.weight)
+        assert weight_post_process.dtype == torch.qint8, \
+            'Weight observer must have a dtype of qint8'
+        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+                    mod.stride, mod.padding, mod.dilation, mod.groups,
+                    mod.bias is not None, mod.padding_mode)
+        qconv.set_weight_bias(qweight, mod.bias)
+        if activation_post_process is None or activation_post_process.dtype == torch.float:
+            return qconv  # dynamic quantization doesn't need scale/zero_point
+        else:
+            act_scale, act_zp = activation_post_process.calculate_qparams()
+            qconv.scale = float(act_scale)
+            qconv.zero_point = int(act_zp)
+            return qconv
+
+    @staticmethod
+    def from_float(cls, mod):
+        if hasattr(mod, "weight_fake_quant"):
+            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
+            # ".from_float only works for " + cls.__QAT_MODULE.__name__
+            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
+                mod.weight, mod.bias = fuse_conv_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
+            assert hasattr(mod, "activation_post_process"), \
+                "Input QAT module must have observer attached"
+            weight_post_process = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            assert type(mod) == cls._FLOAT_MODULE, \
+                " nnq." + cls.__name__ + ".from_float only works for " + \
+                cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
+            assert hasattr(mod, "qconfig"), \
+                "Input float module must have qconfig defined."
+            activation_post_process = None if not hasattr(
+                mod, "activation_post_process") else mod.activation_post_process
+            if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
+                mod = mod[0]
+            weight_post_process = mod.qconfig.weight()
+        return cls.get_qconv(mod, activation_post_process, weight_post_process)
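+
+    # Typical eager-mode flow that ends up calling from_float (illustrative sketch
+    # only; `float_model` and `calib_batch` are placeholder names and the qconfig
+    # choice depends on the target backend):
+    #
+    #     >>> float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
+    #     >>> prepared = torch.ao.quantization.prepare(float_model)
+    #     >>> prepared(calib_batch)  # run observers to collect activation statistics
+    #     >>> quantized = torch.ao.quantization.convert(prepared)  # swaps nn.ConvNd for nnq.ConvNd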
+
+    @classmethod
+    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
+        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
+        Args:
+            ref_qconv (Module): a reference quantized  module, either produced by torch.ao.quantization
+                                utilities or provided by the user
+            output_scale (float): scale for output Tensor
+            output_zero_point (int): zero point for output Tensor
+        """
+        qconv = cls(
+            ref_qconv.in_channels,
+            ref_qconv.out_channels,
+            ref_qconv.kernel_size,  # type: ignore[arg-type]
+            ref_qconv.stride,  # type: ignore[arg-type]
+            ref_qconv.padding,  # type: ignore[arg-type]
+            ref_qconv.dilation,  # type: ignore[arg-type]
+            ref_qconv.groups,
+            ref_qconv.bias is not None,  # type: ignore[arg-type]
+            ref_qconv.padding_mode,
+            device=ref_qconv.weight.device,
+            dtype=ref_qconv.weight.dtype)
+        qweight = ref_qconv.get_quantized_weight()
+        qconv.set_weight_bias(qweight, ref_qconv.bias)
+        qconv.scale = float(output_scale)
+        qconv.zero_point = int(output_zero_point)
+        return qconv
+
+
+class Conv1d(_ConvNd):
+    r"""Applies a 1D convolution over a quantized input signal composed of
+    several quantized input planes.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv1d`.
+
+    .. note::
+        Only `zeros` and `reflect` are supported for the :attr:`padding_mode` argument.
+
+    .. note::
+        Only `torch.quint8` is supported for the input data type.
+
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv1d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
+        >>> input = torch.randn(20, 16, 100)
+        >>> # quantize input to quint8
+        >>> # xdoctest: +SKIP
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
+        ...                                     dtype=torch.quint8)
+        >>> output = m(q_input)
+
+    """
+
+    _FLOAT_MODULE = nn.Conv1d
+    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
+    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
+    _NNI_CONV_ADD_MODULE: None = None
+    _NNI_CONV_ADD_RELU_MODULE: None = None
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: _size_1_t = 0,
+                 dilation: _size_1_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = 'zeros',
+                 device=None,
+                 dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = padding if isinstance(padding, str) else _single(padding)
+        dilation = _single(dilation)
+
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super()._init(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _single(0), groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConv1d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv1d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv1d_prepack(
+                w, b, self.stride, _pair(0), self.dilation,
+                self.groups)
+
+    def _weight_bias(self):
+        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
+        return w, b
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 3:
+            raise ValueError("Input shape must be `(N, C, L)`!")
+        if self.padding_mode != 'zeros':
+            # Padding in Conv1d is stored as (p, p), need to get (p,)
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Creates a quantized module from a float module or qparams_dict.
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+              utilities or provided by the user
+        """
+        return _ConvNd.from_float(cls, mod)
+
+
+class Conv2d(_ConvNd):
+    r"""Applies a 2D convolution over a quantized input signal composed of
+    several quantized input planes.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv2d`.
+
+    .. note::
+        Only `zeros` and `reflect` are supported for the :attr:`padding_mode` argument.
+
+    .. note::
+        Only `torch.quint8` is supported for the input data type.
+
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv2d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> # With square kernels and equal stride
+        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> # non-square kernels and unequal stride and with padding and dilation
+        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
+        >>> input = torch.randn(20, 16, 50, 100)
+        >>> # quantize input to quint8
+        >>> # xdoctest: +SKIP
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(q_input)
+
+    """
+    _FLOAT_MODULE = nn.Conv2d
+    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
+    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
+    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
+    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super()._init(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConv2d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv2d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv2d_prepack(
+                w, b, self.stride, _pair(0), self.dilation, self.groups)
+
+    def _weight_bias(self):
+        return self._packed_params.unpack()
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv2d(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Creates a quantized module from a float module or qparams_dict.
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+              utilities or provided by the user
+        """
+        return _ConvNd.from_float(cls, mod)
+
+
+class Conv3d(_ConvNd):
+    r"""Applies a 3D convolution over a quantized input signal composed of
+    several quantized input planes.
+
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.Conv3d`.
+
+    .. note::
+        Only `zeros` is supported for the :attr:`padding_mode` argument.
+
+    .. note::
+        Only `torch.quint8` is supported for the input data type.
+
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+
+    See :class:`~torch.nn.Conv3d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> # With square kernels and equal stride
+        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
+        >>> # non-square kernels and unequal stride and with padding and dilation
+        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
+        >>> input = torch.randn(20, 16, 56, 56, 56)
+        >>> # quantize input to quint8
+        >>> # xdoctest: +SKIP
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(q_input)
+
+    """
+    _FLOAT_MODULE = nn.Conv3d
+    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
+    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
+    _NNI_CONV_ADD_MODULE: None = None
+    _NNI_CONV_ADD_RELU_MODULE: None = None
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros', device=None, dtype=None):
+        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super()._init(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConv3d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        if self.padding_mode == 'zeros':
+            self._packed_params = torch.ops.quantized.conv3d_prepack(
+                w, b, self.stride, self.padding, self.dilation, self.groups)
+        else:
+            self._packed_params = torch.ops.quantized.conv3d_prepack(
+                w, b, self.stride, _triple(0), self.dilation, self.groups)
+
+    def _weight_bias(self):
+        return self._packed_params.unpack()
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
+        if self.padding_mode != 'zeros':
+            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
+            input = F.pad(input, _reversed_padding_repeated_twice,
+                          mode=self.padding_mode)
+        return ops.quantized.conv3d(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Creates a quantized module from a float module or qparams_dict.
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+              utilities or provided by the user
+        """
+        return _ConvNd.from_float(cls, mod)
+
+# === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
+
+
+class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, transposed, output_padding,
+                 groups, bias, padding_mode, device=None, dtype=None):
+        if padding_mode != 'zeros':
+            raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}')
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super()._init(
+            in_channels, out_channels, kernel_size, stride,
+            padding, dilation, transposed, output_padding,
+            groups, bias, padding_mode, **factory_kwargs)
+
+    def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
+        res = torch.jit.annotate(List[int], [])
+        for kdx in range(len(kernel_size)):
+            pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
+            res.append(pad)
+        return res
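+
+    # For example, with kernel_size=[3], dilation=[1] and padding=[1] this
+    # returns [1 * (3 - 1) - 1] == [1].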
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Creates a quantized module from a float module or qparams_dict.
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+              utilities or provided by the user
+        """
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+              cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
+        assert type(mod) == cls._FLOAT_MODULE, msg
+        assert hasattr(mod, 'qconfig'), \
+            'Input float module must have qconfig defined.'
+        weight_post_process = mod.qconfig.weight()
+        weight_post_process(mod.weight)
+        assert weight_post_process.dtype == torch.qint8, \
+            'Weight observer must have a dtype of qint8'
+        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
+                    mod.stride, mod.padding, mod.output_padding, mod.groups,
+                    mod.bias is not None, mod.dilation, mod.padding_mode)
+        qconv.set_weight_bias(qweight, mod.bias)
+        if not hasattr(mod, "activation_post_process") or mod.activation_post_process.dtype == torch.float:
+            return qconv  # dynamic quantization doesn't need scale/zero_point
+        else:
+            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
+            qconv.scale = float(act_scale)
+            qconv.zero_point = int(act_zp)
+            return qconv
+
+    @staticmethod
+    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
+        Args:
+            ref_qconvt (Module): a reference quantized  module, either produced by torch.ao.quantization
+                                 utilities or provided by the user
+            output_scale (float): scale for output Tensor
+            output_zero_point (int): zero point for output Tensor
+        """
+        qconv = cls(
+            ref_qconvt.in_channels,
+            ref_qconvt.out_channels,
+            ref_qconvt.kernel_size,  # type: ignore[arg-type]
+            ref_qconvt.stride,  # type: ignore[arg-type]
+            ref_qconvt.padding,  # type: ignore[arg-type]
+            ref_qconvt.output_padding,  # type: ignore[arg-type]
+            ref_qconvt.groups,
+            ref_qconvt.bias is not None,  # type: ignore[arg-type]
+            ref_qconvt.dilation,  # type: ignore[arg-type]
+            ref_qconvt.padding_mode,
+            device=ref_qconvt.weight.device,
+            dtype=ref_qconvt.weight.dtype)
+        qweight = ref_qconvt.get_quantized_weight()
+        qconv.set_weight_bias(qweight, ref_qconvt.bias)
+        qconv.scale = float(output_scale)
+        qconv.zero_point = int(output_zero_point)
+        return qconv
+
+
+class ConvTranspose1d(_ConvTransposeNd):
+    r"""Applies a 1D transposed convolution operator over an input image
+    composed of several input planes.
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose1d`.
+
+    .. note:: Currently only the QNNPACK engine is implemented.
+        Please set `torch.backends.quantized.engine = 'qnnpack'`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.Conv1d`
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose1d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> torch.backends.quantized.engine = 'qnnpack'
+        >>> from torch.ao.nn import quantized as nnq
+        >>> # With square kernels and equal stride
+        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> input = torch.randn(20, 16, 50)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(q_input)
+        >>> # exact output size can be also specified as an argument
+        >>> input = torch.randn(1, 16, 12)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(q_input)
+        >>> h.size()
+        torch.Size([1, 16, 6])
+        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose1d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _single(kernel_size)
+        stride = _single(stride)
+        padding = _single(padding)
+        dilation = _single(dilation)
+        output_padding = _single(output_padding)
+
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConvTranspose1d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
+            w, b, self.stride, self.padding, self.output_padding, self.dilation,
+            self.groups)
+
+    def _weight_bias(self):
+        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
+        return w, b
+
+    def weight(self):
+        (w, _) = self._weight_bias()
+        return w
+
+    def bias(self):
+        (_, b) = self._weight_bias()
+        return b
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 3:
+            raise ValueError("Input shape must be `(N, C, L)`!")
+        return torch.ops.quantized.conv_transpose1d(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
+
+
+class ConvTranspose2d(_ConvTransposeNd):
+    r"""Applies a 2D transposed convolution operator over an input image
+    composed of several input planes.
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose2d`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.Conv2d`
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose2d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> # QNNPACK or FBGEMM as backend
+        >>> torch.backends.quantized.engine = 'qnnpack'
+        >>> # With square kernels and equal stride
+        >>> import torch.ao.nn.quantized as nnq
+        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
+        >>> # non-square kernels and unequal stride and with padding
+        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+        >>> input = torch.randn(20, 16, 50, 100)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(q_input)
+        >>> # exact output size can be also specified as an argument
+        >>> input = torch.randn(1, 16, 12, 12)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(q_input)
+        >>> h.size()
+        torch.Size([1, 16, 6, 6])
+        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose2d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        output_padding = _pair(output_padding)
+
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConvTranspose2d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
+            w, b, self.stride, self.padding, self.output_padding, self.dilation,
+            self.groups)
+
+    def _weight_bias(self):
+        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
+        return w, b
+
+    def weight(self):
+        (w, _) = self._weight_bias()
+        return w
+
+    def bias(self):
+        (_, b) = self._weight_bias()
+        return b
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 4:
+            raise ValueError("Input shape must be `(N, C, H, W)`!")
+        return ops.quantized.conv_transpose2d(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
+
+
+class ConvTranspose3d(_ConvTransposeNd):
+    r"""Applies a 3D transposed convolution operator over an input image
+    composed of several input planes.
+    For details on input arguments, parameters, and implementation see
+    :class:`~torch.nn.ConvTranspose3d`.
+
+    .. note:: Currently only the FBGEMM engine is implemented.
+        Please set `torch.backends.quantized.engine = 'fbgemm'`.
+
+    For special notes, please see :class:`~torch.ao.nn.quantized.Conv3d`
+
+    Attributes:
+        weight (Tensor):     packed tensor derived from the learnable weight
+                             parameter.
+        scale (Tensor):      scalar for the output scale
+        zero_point (Tensor): scalar for the output zero point
+    See :class:`~torch.nn.ConvTranspose3d` for other attributes.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> torch.backends.quantized.engine = 'fbgemm'
+        >>> from torch.ao.nn import quantized as nnq
+        >>> # With cubic kernels and equal stride
+        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
+        >>> # non-cubic kernels and unequal stride and with padding
+        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
+        >>> input = torch.randn(20, 16, 50, 100, 100)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(q_input)
+        >>> # exact output size can be also specified as an argument
+        >>> input = torch.randn(1, 16, 12, 12, 12)
+        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
+        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
+        >>> h = downsample(q_input)
+        >>> h.size()
+        torch.Size([1, 16, 6, 6, 6])
+        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
+        >>> output = upsample(h, output_size=input.size())
+        >>> output.size()
+        torch.Size([1, 16, 12, 12, 12])
+    """
+
+    _FLOAT_MODULE = nn.ConvTranspose3d
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0, groups=1, bias=True,
+                 dilation=1, padding_mode='zeros', device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        kernel_size = _triple(kernel_size)
+        stride = _triple(stride)
+        padding = _triple(padding)
+        dilation = _triple(dilation)
+        output_padding = _triple(output_padding)
+
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            True, output_padding, groups, bias, padding_mode, **factory_kwargs)
+
+    def _get_name(self):
+        return 'QuantizedConvTranspose3d'
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
+            w, b, self.stride, self.padding, self.output_padding, self.dilation,
+            self.groups)
+
+    def _weight_bias(self):
+        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
+        return w, b
+
+    def weight(self):
+        (w, _) = self._weight_bias()
+        return w
+
+    def bias(self):
+        (_, b) = self._weight_bias()
+        return b
+
+    def forward(self, input):
+        # Temporarily using len(shape) instead of ndim due to JIT issue
+        # https://github.com/pytorch/pytorch/issues/23890
+        if len(input.shape) != 5:
+            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
+        return ops.quantized.conv_transpose3d(
+            input, self._packed_params, self.scale, self.zero_point)
+
+    @classmethod
+    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ace7f68a58326ef635ba29b65c172509bcd5c9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/dropout.py
@@ -0,0 +1,27 @@
+import torch
+
+__all__ = ['Dropout']
+
+class Dropout(torch.nn.Dropout):
+    r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.
+        It is a placeholder that lets models which applied dropout to fp32
+        tensors also work with quantized tensors in train and eval mode.
+
+    Args:
+        p: probability of an element to be zeroed
+        inplace: can optionally do the operation in-place. Default: ``False``
+    """
+
+    def forward(self, input):
+        return input
+
+    def _get_name(self):
+        return 'QuantizedDropout'
+
+    @classmethod
+    def from_float(cls, mod):
+        return cls(mod.p, mod.inplace)
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(mod.p, mod.inplace)
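+
+# A minimal usage sketch (illustrative only): the quantized Dropout is a pure
+# passthrough, so it returns its quantized input unchanged in both train and eval:
+#
+#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1, zero_point=128, dtype=torch.quint8)
+#     >>> m = Dropout(p=0.5)
+#     >>> assert torch.equal(m(qx).dequantize(), qx.dequantize())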
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..377007e64c3907677c5ca2804c69bc0a41c2b256
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py
@@ -0,0 +1,295 @@
+import torch
+import torch.nn as nn
+from torch import Tensor  # noqa: F401
+from torch._jit_internal import Optional, List  # noqa: F401
+
+from .utils import _hide_packed_params_repr
+from .utils import _quantize_weight
+
+__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']
+
+class EmbeddingPackedParams(torch.nn.Module):
+    _version = 1
+
+    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
+        super().__init__()
+        self.dtype = dtype
+        if self.dtype in [torch.quint8, torch.quint4x2]:
+            scales = torch.ones(num_embeddings, dtype=torch.float)
+            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
+            wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales,
+                                                           zero_points=zero_points,
+                                                           axis=0, dtype=self.dtype)
+            self.set_weight(wq)
+        else:
+            raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')
+
+    @torch.jit.export
+    def set_weight(self, weight: torch.Tensor) -> None:
+        if self.dtype in [torch.quint8, torch.quint4x2]:
+            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
+        else:
+            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')
+
+
+    @torch.jit.export
+    def _weight(self):
+        if self.dtype in [torch.quint8, torch.quint4x2]:
+            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
+        else:
+            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')
+
+    def forward(self, x):
+        return x
+
+    # Version 1
+    #   self
+    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
+    #   |--- dtype : torch.dtype
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'dtype'] = self.dtype
+        destination[prefix + '_packed_weight'] = self._weight()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self.dtype = state_dict[prefix + 'dtype']
+        state_dict.pop(prefix + 'dtype')
+
+        weight = state_dict[prefix + '_packed_weight']
+        state_dict.pop(prefix + '_packed_weight')
+        self.set_weight(weight)
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    def __repr__(self):
+        return self._weight().__repr__()
+
+class Embedding(torch.nn.Module):
+    r"""
+    A quantized Embedding module with quantized packed weights as inputs.
+    We adopt the same interface as `torch.nn.Embedding`; please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.
+
+    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
+    initialized at module creation time and will be overwritten later
+
+    Attributes:
+        weight (Tensor): the non-learnable quantized weights of the module of
+                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
+
+    Examples::
+        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
+        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
+        >>> output = m(indices)
+        >>> print(output.size())
+        torch.Size([9, 12])
+
+    """
+    _version = 1
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.dtype = dtype
+
+        if _weight is None:
+            scales = torch.ones(num_embeddings, dtype=torch.float)
+            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
+            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
+                                                                scales=scales, zero_points=zero_points,
+                                                                axis=0, dtype=torch.quint8)
+        else:
+            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+                'Shape of weight does not match num_embeddings and embedding_dim'
+            qweight = _weight
+
+        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
+        self._packed_params.set_weight(qweight)
+
+    def forward(self, indices: Tensor) -> Tensor:
+        if self.dtype == torch.quint4x2:
+            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
+        else:
+            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
+
+    def _get_name(self):
+        return 'QuantizedEmbedding'
+
+    def __repr__(self):
+        return _hide_packed_params_repr(self, EmbeddingPackedParams)
+
+    def extra_repr(self):
+        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
+            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
+        )
+
+        return extra_repr_str
+
+    def set_weight(self, w: torch.Tensor) -> None:
+        self._packed_params.set_weight(w)
+
+    def weight(self):
+        return self._packed_params._weight()
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a quantized embedding module from a float module
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+                          utilities or provided by the user
+        """
+        if hasattr(mod, 'weight_fake_quant'):
+            assert type(mod) == torch.ao.nn.qat.Embedding, 'nnq.' + cls.__name__ + '.from_float ' + \
+                'with fake quant only works for ' + torch.ao.nn.qat.Embedding.__name__
+            weight_observer = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
+                nn.Embedding.__name__
+            assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
+            from torch.ao.quantization import float_qparams_weight_only_qconfig
+            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
+                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
+            else:
+                weight_observer = float_qparams_weight_only_qconfig.weight()
+
+        dtype = weight_observer.dtype
+        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
+        assert is_float_qparams_qconfig, \
+            'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'
+
+        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
+            f'The only supported dtypes for nnq.Embedding are torch.quint8 and torch.quint4x2, got {dtype}'
+
+        # Run the observer to calculate qparams.
+        weight_observer(mod.weight)
+        qweight = _quantize_weight(mod.weight.float(), weight_observer)
+
+        # Create quantized Embedding module and pass in the quantized weight
+        # (dtype is propagated so quint4x2 weights dispatch to the 4-bit op in forward)
+        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
+        qembedding.set_weight(qweight)
+        return qembedding
+
+    @classmethod
+    def from_reference(cls, ref_embedding):
+        qembedding = cls(
+            ref_embedding.num_embeddings,
+            ref_embedding.embedding_dim,
+            ref_embedding.padding_idx,
+            ref_embedding.max_norm,
+            ref_embedding.norm_type,
+            ref_embedding.scale_grad_by_freq,
+            ref_embedding.sparse,
+            ref_embedding.get_quantized_weight(),
+            ref_embedding.weight_dtype,
+        )
+        return qembedding
+
+class EmbeddingBag(Embedding):
+    r"""
+    A quantized EmbeddingBag module with quantized packed weights as inputs.
+    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.
+
+    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
+    initialized at module creation time and will be overwritten later.
+
+    Attributes:
+        weight (Tensor): the non-learnable quantized weights of the module of
+                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
+
+    Examples::
+        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
+        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
+        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
+        >>> output = m(indices, offsets)
+        >>> print(output.size())
+        torch.Size([5, 12])
+
+    """
+    _version = 1
+
+    def __init__(self, num_embeddings: int, embedding_dim: int,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
+                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
+        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
+
+        self.mode = mode
+        self.pruned_weights = False
+        self.include_last_offset = include_last_offset
+        self.dtype = dtype
+
+    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
+                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
+        if self.dtype == torch.quint4x2:
+            return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
+                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
+                                                          self.include_last_offset)
+        else:
+            return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
+                                                          self.pruned_weights, per_sample_weights, compressed_indices_mapping,
+                                                          self.include_last_offset)
+
+    def _get_name(self):
+        return 'QuantizedEmbeddingBag'
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a quantized embedding_bag module from a float module
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+                          utilities or provided by the user
+        """
+        if hasattr(mod, 'weight_fake_quant'):
+            weight_observer = mod.weight_fake_quant
+        else:
+            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
+                nn.EmbeddingBag.__name__
+            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
+            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
+            if mod.qconfig is not None and mod.qconfig.weight is not None:  # type: ignore[union-attr]
+                weight_observer = mod.qconfig.weight()  # type: ignore[union-attr, operator]
+            else:
+                weight_observer = float_qparams_weight_only_qconfig.weight()
+
+        dtype = weight_observer.dtype
+        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
+        assert is_float_qparams_qconfig, \
+            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'
+
+        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
+            f'The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2, got {dtype}'
+
+        # Run the observer to calculate qparams.
+        weight_observer(mod.weight)
+        qweight = _quantize_weight(mod.weight.float(), weight_observer)
+
+        # Create quantized EmbeddingBag module and pass in the quantized weight
+        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
+        qembedding_bag.set_weight(qweight)
+        return qembedding_bag
+
+    @classmethod
+    def from_reference(cls, ref_embedding_bag):
+        qembedding_bag = cls(
+            ref_embedding_bag.num_embeddings,
+            ref_embedding_bag.embedding_dim,
+            ref_embedding_bag.max_norm,
+            ref_embedding_bag.norm_type,
+            ref_embedding_bag.scale_grad_by_freq,
+            ref_embedding_bag.mode,
+            ref_embedding_bag.sparse,
+            ref_embedding_bag.get_quantized_weight(),
+            ref_embedding_bag.include_last_offset,
+            ref_embedding_bag.weight_dtype,
+        )
+        return qembedding_bag
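+
+# A minimal usage sketch for ``EmbeddingBag.from_float`` above, assuming the float
+# module carries ``float_qparams_weight_only_qconfig`` as its qconfig.
+#
+#     >>> import torch
+#     >>> import torch.nn as nn
+#     >>> from torch.ao.quantization import float_qparams_weight_only_qconfig
+#     >>> float_eb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
+#     >>> float_eb.qconfig = float_qparams_weight_only_qconfig
+#     >>> q_eb = EmbeddingBag.from_float(float_eb)   # packs a quint8 weight
+#     >>> out = q_eb(torch.tensor([1, 2, 3]), torch.tensor([0, 2]))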
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7751c2533d5d9201b320ed2c6e4196efa2eafaf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/functional_modules.py
@@ -0,0 +1,249 @@
+from typing import List
+
+import torch
+from torch import Tensor
+from torch._ops import ops
+
+__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional']
+
+class FloatFunctional(torch.nn.Module):
+    r"""State collector class for float operations.
+
+    The instance of this class can be used instead of the ``torch.`` prefix for
+    some operations. See example usage below.
+
+    .. note::
+
+        This class does not provide a ``forward`` hook. Instead, you must use
+        one of the underlying functions (e.g. ``add``).
+
+    Examples::
+
+        >>> f_add = FloatFunctional()
+        >>> a = torch.tensor(3.0)
+        >>> b = torch.tensor(4.0)
+        >>> f_add.add(a, b)  # Equivalent to ``torch.add(a, b)``
+
+    Valid operation names:
+        - add
+        - cat
+        - mul
+        - add_relu
+        - add_scalar
+        - mul_scalar
+    """
+    def __init__(self):
+        super().__init__()
+        self.activation_post_process = torch.nn.Identity()
+
+    def forward(self, x):
+        raise RuntimeError("FloatFunctional is not intended to use the " +
+                           "'forward'. Please use the underlying operation")
+
+    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.add(x, y)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.add(Tensor, float)``"""
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = torch.add(x, y)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.mul(x, y)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = torch.mul(x, y)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
+
+    r"""Operation equivalent to ``torch.cat``"""
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
+        r = torch.cat(x, dim=dim)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``relu(torch.add(x,y))``"""
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.add(x, y)
+        r = torch.nn.functional.relu(r)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.matmul(x, y)
+        r = self.activation_post_process(r)
+        return r
+
+class FXFloatFunctional(torch.nn.Module):
+    r""" module to replace FloatFunctional module before FX graph mode quantization,
+    since activation_post_process will be inserted in top level module directly
+
+    Valid operation names:
+        - add
+        - cat
+        - mul
+        - add_relu
+        - add_scalar
+        - mul_scalar
+    """
+    def forward(self, x):
+        raise RuntimeError("FloatFunctional is not intended to use the " +
+                           "'forward'. Please use the underlying operation")
+
+    r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.add(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.add(Tensor, float)``"""
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = torch.add(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.mul(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = torch.mul(x, y)
+        return r
+
+    r"""Operation equivalent to ``torch.cat``"""
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
+        r = torch.cat(x, dim=dim)
+        return r
+
+    r"""Operation equivalent to ``relu(torch.add(x,y))``"""
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.add(x, y)
+        r = torch.nn.functional.relu(r)
+        return r
+
+    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.matmul(x, y)
+        return r
+
+class QFunctional(torch.nn.Module):
+    r"""Wrapper class for quantized operations.
+
+    The instance of this class can be used instead of the
+    ``torch.ops.quantized`` prefix. See example usage below.
+
+    .. note::
+
+        This class does not provide a ``forward`` hook. Instead, you must use
+        one of the underlying functions (e.g. ``add``).
+
+    Examples::
+
+        >>> q_add = QFunctional()
+        >>> # xdoctest: +SKIP
+        >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
+        >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
+        >>> q_add.add(a, b)  # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
+
+    Valid operation names:
+        - add
+        - cat
+        - mul
+        - add_relu
+        - add_scalar
+        - mul_scalar
+    """
+    def __init__(self):
+        super().__init__()
+        self.scale = 1.0
+        self.zero_point = 0
+        self.activation_post_process = torch.nn.Identity()
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'scale'] = torch.tensor(self.scale)
+        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+
+        self.scale = float(state_dict.pop(prefix + 'scale'))
+        self.zero_point = int(state_dict.pop(prefix + 'zero_point'))
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    def _get_name(self):
+        return 'QFunctional'
+
+    def extra_repr(self):
+        return f'scale={self.scale}, zero_point={self.zero_point}'
+
+    def forward(self, x):
+        raise RuntimeError("Functional is not intended to use the " +
+                           "'forward'. Please use the underlying operation")
+
+    r"""Operation equivalent to ``torch.ops.quantized.add``"""
+    def add(self, x: Tensor, y: Tensor) -> Tensor:
+        r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
+    def add_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = ops.quantized.add_scalar(x, y)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
+    def mul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
+    def mul_scalar(self, x: Tensor, y: float) -> Tensor:
+        r = ops.quantized.mul_scalar(x, y)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.cat``"""
+    def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
+        r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
+    def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
+        r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
+        r = self.activation_post_process(r)
+        return r
+
+    r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) == FloatFunctional, \
+            "QFunctional.from_float expects an instance of FloatFunctional"
+        scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
+        new_mod = QFunctional()
+        new_mod.scale = float(scale)
+        new_mod.zero_point = int(zero_point)
+        return new_mod
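+
+# A minimal usage sketch for ``QFunctional.from_float`` above, assuming the
+# FloatFunctional has been given a real observer (e.g. by torch.ao.quantization
+# prepare) in place of the default ``nn.Identity`` so ``calculate_qparams()`` exists.
+#
+#     >>> ff = FloatFunctional()
+#     >>> ff.activation_post_process = torch.ao.quantization.MinMaxObserver()
+#     >>> _ = ff.add(torch.randn(4), torch.randn(4))   # calibration pass
+#     >>> qf = QFunctional.from_float(ff)
+#     >>> qf.scale, qf.zero_point                      # populated from the observer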
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..e84f6da2f68ffde2a5291e16ebb2e12c14cef09b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/linear.py
@@ -0,0 +1,303 @@
+from collections.abc import Iterable
+import torch
+
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+from torch.nn.utils.fusion import fuse_linear_bn_weights
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from typing import Optional
+
+from .utils import _quantize_weight, _hide_packed_params_repr, WeightedQuantizedModule
+
+__all__ = ['LinearPackedParams', 'Linear']
+
+
+class LinearPackedParams(torch.nn.Module):
+    _version = 3
+
+    def __init__(self, dtype=torch.qint8):
+        super().__init__()
+        self.dtype = dtype
+        if self.dtype == torch.qint8:
+            wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
+        elif self.dtype == torch.float16:
+            wq = torch.zeros([1, 1], dtype=torch.float)
+        self.set_weight_bias(wq, None)  # type: ignore[possibly-undefined]
+
+    @torch.jit.export
+    def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
+        if self.dtype == torch.qint8:
+            self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
+        elif self.dtype == torch.float16:
+            self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
+
+
+    @torch.jit.export
+    def _weight_bias(self):
+        if self.dtype == torch.qint8:
+            return torch.ops.quantized.linear_unpack(self._packed_params)
+        elif self.dtype == torch.float16:
+            return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
+        else:
+            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
+
+    def forward(self, x):
+        return x
+
+    # Version 1
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #   |--- dtype : torch.dtype
+    #
+    # Version 3
+    #   self
+    #   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
+    #                         of LinearPackedParams
+    #   |--- dtype : torch.dtype
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'dtype'] = self.dtype
+        destination[prefix + '_packed_params'] = self._weight_bias()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if version is None or version < 2:
+            self.dtype = torch.qint8
+        else:
+            self.dtype = state_dict[prefix + 'dtype']
+            state_dict.pop(prefix + 'dtype')
+
+        if version is None or version < 3:
+            self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
+            state_dict.pop(prefix + 'weight')
+            state_dict.pop(prefix + 'bias')
+
+        if version == 3:
+            weight, bias = state_dict[prefix + '_packed_params']
+            state_dict.pop(prefix + '_packed_params')
+            self.set_weight_bias(weight, bias)
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+
+    def __repr__(self):
+        return self._weight_bias().__repr__()
+
+
+class Linear(WeightedQuantizedModule):
+    r"""
+    A quantized linear module with quantized tensor as inputs and outputs.
+    We adopt the same interface as `torch.nn.Linear`; please see
+    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
+
+    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
+    initialized at module creation time and will be overwritten later.
+
+    Attributes:
+        weight (Tensor): the non-learnable quantized weights of the module of
+                         shape :math:`(\text{out\_features}, \text{in\_features})`.
+        bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
+                If :attr:`bias` is ``True``, the values are initialized to zero.
+        scale: `scale` parameter of output Quantized Tensor, type: double
+        zero_point: `zero_point` parameter for output Quantized Tensor, type: long
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
+        >>> m = nn.quantized.Linear(20, 30)
+        >>> input = torch.randn(128, 20)
+        >>> # xdoctest: +SKIP
+        >>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 30])
+    """
+    _version = 3
+    _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
+
+    def __init__(self, in_features, out_features, bias_=True,
+                 dtype=torch.qint8):
+        super().__init__()
+        # We don't muck around with buffers or attributes or anything here
+        # to keep the module simple. *Everything* is simply a Python attribute.
+        # Serialization logic is explicitly handled in the serialization and
+        # deserialization methods below.
+        self.in_features = in_features
+        self.out_features = out_features
+        bias = None
+        if bias_:
+            bias = torch.zeros(out_features, dtype=torch.float)
+
+        if dtype == torch.qint8:
+            qweight = torch._empty_affine_quantized(
+                [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
+        elif dtype == torch.float16:
+            qweight = torch.zeros([out_features, in_features], dtype=torch.float)
+        else:
+            raise RuntimeError('Unsupported dtype specified for quantized Linear!')
+
+        self._packed_params = LinearPackedParams(dtype)
+        self._packed_params.set_weight_bias(qweight, bias)
+        self.scale = 1.0
+        self.zero_point = 0
+
+    def _get_name(self):
+        return 'QuantizedLinear'
+
+    def extra_repr(self):
+        return 'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'.format(
+            self.in_features, self.out_features, self.scale, self.zero_point, self.weight().qscheme()
+        )
+
+    def __repr__(self):
+        return _hide_packed_params_repr(self, LinearPackedParams)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.quantized.linear(
+            x, self._packed_params._packed_params, self.scale, self.zero_point)
+
+    # ===== Serialization methods =====
+    # The special consideration here is that we have to unpack the weights into their
+    # regular QTensor form for serialization. Packed weights should not live
+    # outside the process in which they were created; rather, they should be derived
+    # from the QTensor weight.
+    #
+    # Version 1
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- weight : Tensor
+    #   |--- bias : Tensor
+    #
+    # Version 2
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- _packed_params : Module
+    #        |--- weight : Tensor
+    #        |--- bias : Tensor
+    #
+    # Version 3
+    #   self
+    #   |--- scale : float
+    #   |--- zero_point : int
+    #   |--- _packed_params : Module
+    #        |--- _packed_params : (Tensor, Tensor) representing weight, bias
+    #                              of LinearPackedParams C++ struct
+    #
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'scale'] = torch.tensor(self.scale)
+        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
+
+    # ===== Deserialization methods =====
+    # Counterpart to the serialization methods, we must pack the serialized QTensor
+    # weight into its packed format for use by the FBGEMM ops.
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self.scale = float(state_dict[prefix + 'scale'])
+        state_dict.pop(prefix + 'scale')
+
+        self.zero_point = int(state_dict[prefix + 'zero_point'])
+        state_dict.pop(prefix + 'zero_point')
+
+        version = local_metadata.get('version', None)
+
+        if version is None or version == 1:
+            # We moved the parameters into a LinearPackedParameters submodule
+            weight = state_dict.pop(prefix + 'weight')
+            bias = state_dict.pop(prefix + 'bias')
+            state_dict.update({prefix + '_packed_params.weight': weight,
+                               prefix + '_packed_params.bias': bias})
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False,
+            missing_keys, unexpected_keys, error_msgs)
+
+    # Function rather than property to make sure that JIT serialization doesn't
+    # register this as an attribute
+    def _weight_bias(self):
+        return self._packed_params._weight_bias()
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
+        self._packed_params.set_weight_bias(w, b)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a quantized module from an observed float module
+
+        Args:
+            mod (Module): a float module, either produced by torch.ao.quantization
+                          utilities or provided by the user
+        """
+        if hasattr(mod, 'weight_fake_quant'):
+            if type_before_parametrizations(mod) == nniqat.LinearBn1d:
+                mod.weight, mod.bias = fuse_linear_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
+            weight_post_process = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            # This function does not participate in JIT, so it is OK to ignore
+            # the type mismatch in assignment. Also, mypy has an issue with
+            # iterables not being implemented, so we are ignoring those too.
+            if not isinstance(cls._FLOAT_MODULE, Iterable):
+                cls._FLOAT_MODULE = [cls._FLOAT_MODULE]  # type: ignore[assignment]
+            supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE])  # type: ignore[attr-defined]
+            error_msg = f'nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}'
+            assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg  # type: ignore[attr-defined]
+            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+            activation_post_process = mod.activation_post_process
+            if type_before_parametrizations(mod) == nni.LinearReLU:
+                mod = mod[0]
+            weight_post_process = mod.qconfig.weight()
+        weight_post_process(mod.weight)
+        dtype = weight_post_process.dtype
+        act_scale, act_zp = activation_post_process.calculate_qparams()
+        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
+        qlinear = cls(mod.in_features,
+                      mod.out_features,
+                      dtype=dtype)
+        qlinear.set_weight_bias(qweight, mod.bias)
+        qlinear.scale = float(act_scale)
+        qlinear.zero_point = int(act_zp)
+        return qlinear
+
+    @classmethod
+    def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
+        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
+
+        Args:
+            ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
+                          utilities or provided by the user
+            output_scale (float): scale for output Tensor
+            output_zero_point (int): zero point for output Tensor
+        """
+        qlinear = cls(
+            ref_qlinear.in_features,
+            ref_qlinear.out_features)
+        qweight = ref_qlinear.get_quantized_weight()
+        qlinear.set_weight_bias(qweight, ref_qlinear.bias)
+
+        qlinear.scale = float(output_scale)
+        qlinear.zero_point = int(output_zero_point)
+        return qlinear
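+
+# A minimal usage sketch for the quantized Linear above, assuming the standard
+# eager-mode flow (observe with ``prepare``, calibrate, then ``convert``, which
+# calls ``Linear.from_float`` under the hood) and an available fbgemm backend.
+#
+#     >>> import torch
+#     >>> from torch.ao.quantization import get_default_qconfig, prepare, convert
+#     >>> model = torch.nn.Sequential(torch.nn.Linear(20, 30))
+#     >>> model.qconfig = get_default_qconfig('fbgemm')
+#     >>> prepared = prepare(model)
+#     >>> _ = prepared(torch.randn(8, 20))          # calibration
+#     >>> quantized = convert(prepared)             # model[0] becomes QuantizedLinear
+#     >>> qx = torch.quantize_per_tensor(torch.randn(8, 20), 0.1, 0, torch.quint8)
+#     >>> y = quantized(qx)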
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8ecb0baa4548a0dd94b05d810cc982ffe900ee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/normalization.py
@@ -0,0 +1,199 @@
+import torch
+
+__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
+
+class LayerNorm(torch.nn.LayerNorm):
+    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+
+    """
+
+    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
+                 elementwise_affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
+                         **factory_kwargs)
+        self.weight = weight
+        self.bias = bias
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.layer_norm(
+            input, self.normalized_shape, weight=self.weight, bias=self.bias,
+            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedLayerNorm'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = cls(
+            mod.normalized_shape, mod.weight, mod.bias, float(scale),
+            int(zero_point), mod.eps, mod.elementwise_affine)
+        return new_mod
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(
+            mod.normalized_shape, mod.weight, mod.bias, float(scale),
+            int(zero_point), mod.eps, mod.elementwise_affine)
+
+class GroupNorm(torch.nn.GroupNorm):
+    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+
+    """
+    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
+
+    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
+                 affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
+        self.weight = weight
+        self.bias = bias
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.group_norm(
+            input, self.num_groups, self.weight, self.bias, self.eps, self.scale,
+            self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedGroupNorm'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = cls(
+            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, mod.affine)
+        return new_mod
+
+class InstanceNorm1d(torch.nn.InstanceNorm1d):
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+
+    """
+    def __init__(self, num_features, weight, bias, scale, zero_point,
+                 eps=1e-5, momentum=0.1, affine=False,
+                 track_running_stats=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
+        self.weight = weight
+        self.bias = bias
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.instance_norm(
+            input, self.weight, self.bias, self.eps, self.scale,
+            self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedInstanceNorm1d'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
+        return new_mod
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
+
+class InstanceNorm2d(torch.nn.InstanceNorm2d):
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+
+    """
+    def __init__(self, num_features, weight, bias, scale, zero_point,
+                 eps=1e-5, momentum=0.1, affine=False,
+                 track_running_stats=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
+        self.weight = weight
+        self.bias = bias
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.instance_norm(
+            input, self.weight, self.bias, self.eps, self.scale,
+            self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedInstanceNorm2d'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
+        return new_mod
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
+
+class InstanceNorm3d(torch.nn.InstanceNorm3d):
+    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.
+
+    Additional args:
+        * **scale** - quantization scale of the output, type: double.
+        * **zero_point** - quantization zero point of the output, type: long.
+
+    """
+    def __init__(self, num_features, weight, bias, scale, zero_point,
+                 eps=1e-5, momentum=0.1, affine=False,
+                 track_running_stats=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
+        self.weight = weight
+        self.bias = bias
+        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
+        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
+
+    def forward(self, input):
+        return torch.ops.quantized.instance_norm(
+            input, self.weight, self.bias, self.eps, self.scale,
+            self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedInstanceNorm3d'
+
+    @classmethod
+    def from_float(cls, mod):
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
+        return new_mod
+
+    @classmethod
+    def from_reference(cls, mod, scale, zero_point):
+        return cls(
+            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
+            mod.eps, affine=mod.affine)
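+
+# A minimal usage sketch for the quantized LayerNorm above, assuming direct
+# construction from float affine parameters plus chosen output qparams.
+#
+#     >>> float_ln = torch.nn.LayerNorm(4)
+#     >>> q_ln = LayerNorm(float_ln.normalized_shape, float_ln.weight, float_ln.bias,
+#     ...                  scale=0.1, zero_point=0)
+#     >>> x = torch.quantize_per_tensor(torch.randn(2, 4), 0.05, 0, torch.quint8)
+#     >>> y = q_ln(x)   # quantized output carrying scale=0.1, zero_point=0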
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf10fe4d97ffef46bc38c261cab21f269f98a25
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/rnn.py
@@ -0,0 +1,51 @@
+import torch
+
+__all__ = [
+    "LSTM",
+]
+
+class LSTM(torch.ao.nn.quantizable.LSTM):
+    r"""A quantized long short-term memory (LSTM).
+
+    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
+
+    Attributes:
+        layers : instances of the `_LSTMLayer`
+
+    .. note::
+        To access the weights and biases, you need to access them per layer.
+        See examples in :class:`~torch.ao.nn.quantizable.LSTM`
+
+    Examples::
+        >>> # xdoctest: +SKIP
+        >>> custom_module_config = {
+        ...     'float_to_observed_custom_module_class': {
+        ...         nn.LSTM: nn.quantizable.LSTM,
+        ...     },
+        ...     'observed_to_quantized_custom_module_class': {
+        ...         nn.quantizable.LSTM: nn.quantized.LSTM,
+        ...     }
+        ... }
+        >>> tq.prepare(model, prepare_custom_module_class=custom_module_config)
+        >>> tq.convert(model, convert_custom_module_class=custom_module_config)
+    """
+    _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM  # type: ignore[assignment]
+
+    def _get_name(self):
+        return 'QuantizedLSTM'
+
+    @classmethod
+    def from_float(cls, *args, **kwargs):
+        # The whole flow is float -> observed -> quantized
+        # This class does observed -> quantized only
+        raise NotImplementedError("It looks like you are trying to convert a "
+                                  "non-observed LSTM module. Please, see "
+                                  "the examples on quantizable LSTMs.")
+
+    @classmethod
+    def from_observed(cls, other):
+        assert type(other) == cls._FLOAT_MODULE  # type: ignore[has-type]
+        converted = torch.ao.quantization.convert(other, inplace=False,
+                                                  remove_qconfig=True)
+        converted.__class__ = cls
+        return converted
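+
+# A minimal sketch of the observed -> quantized step that ``from_observed`` performs,
+# assuming ``observed`` is an already-calibrated ``torch.ao.nn.quantizable.LSTM``
+# (for example, produced by the ``tq.prepare`` call shown in the docstring above).
+#
+#     >>> quantized_lstm = LSTM.from_observed(observed)
+#     >>> quantized_lstm._get_name()
+#     'QuantizedLSTM'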
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31f792351be933c70d8b77dd7a993ed26663d07
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/modules/utils.py
@@ -0,0 +1,117 @@
+import abc
+import torch
+import itertools
+import collections
+from torch.nn.modules.module import _addindent
+
+__all__ = [
+    "WeightedQuantizedModule",
+]
+
+class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
+    """Wrapper for quantized modules than can be lowered from reference modules."""
+    @classmethod
+    @abc.abstractmethod
+    def from_reference(cls, ref_module, output_scale, output_zero_point):
+        raise NotImplementedError
+
+def _get_weight_observer(observer):
+    # FakeQuantize observer
+    if hasattr(observer, "activation_post_process"):
+        observer = observer.activation_post_process
+    # UniformQuantizationObserverBase observer
+    return observer
+
+def _needs_weight_clamping(observer, dtype):
+    observer = _get_weight_observer(observer)
+    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
+        info = torch.iinfo(dtype)
+        return observer.quant_min > info.min or observer.quant_max < info.max
+    return False
+
+def _clamp_weights(qweight, observer, scale, zp):
+    if not _needs_weight_clamping(observer, qweight.dtype):
+        return qweight
+
+    observer = _get_weight_observer(observer)
+    min_, max_ = observer.quant_min, observer.quant_max
+
+    # Doing this because we can't use torch.ops.quantized.clamp() with a per_channel qscheme yet.
+    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
+    qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
+    qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)
+
+    if observer.qscheme in [torch.per_tensor_symmetric,
+                            torch.per_tensor_affine]:
+        qweight = torch._make_per_tensor_quantized_tensor(qw_int, scale.item(), zp.item())
+    elif observer.qscheme in [torch.per_channel_symmetric,
+                              torch.per_channel_affine,
+                              torch.per_channel_affine_float_qparams]:
+        qweight = torch._make_per_channel_quantized_tensor(qw_int, scale, zp, axis=observer.ch_axis)
+    else:
+        raise ValueError("Unexpected qscheme " + observer.qscheme)
+    return qweight
+
+def _quantize_weight(float_wt, observer):
+    wt_scale, wt_zp = observer.calculate_qparams()
+    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
+        qweight = torch.quantize_per_tensor(
+            float_wt,
+            float(wt_scale), int(wt_zp), torch.qint8)
+        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
+    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
+        wt_axis = observer.ch_axis
+        qweight = torch.quantize_per_channel(
+            float_wt,
+            wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
+        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
+    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
+        qweight = torch.quantize_per_channel(
+            float_wt,
+            wt_scale.to(torch.float), wt_zp.to(torch.float), observer.ch_axis, observer.dtype)
+        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
+    else:
+        raise ValueError("Unexpected qscheme " + observer.qscheme)
+    return qweight
+
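+# A minimal sketch of how ``_quantize_weight`` above is typically driven, assuming a
+# per-channel weight observer from ``torch.ao.quantization.observer``.
+#
+#     >>> from torch.ao.quantization.observer import PerChannelMinMaxObserver
+#     >>> obs = PerChannelMinMaxObserver(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
+#     >>> w = torch.randn(30, 20)
+#     >>> _ = obs(w)                       # record per-channel min/max
+#     >>> qw = _quantize_weight(w, obs)    # per-channel torch.qint8 tensor
+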
+def _ntuple_from_first(n):
+    """Converts the argument to a tuple of size n
+    with the first element repeated."""
+    def parse(x):
+        while isinstance(x, collections.abc.Sequence):
+            if len(x) == n:
+                break
+            x = x[0]
+        return tuple(itertools.repeat(x, n))
+    return parse
+
+def _hide_packed_params_repr(self, params):
+    # We don't want to show `PackedParams` children, hence the custom
+    # `__repr__`. This is the same as nn.Module.__repr__, except for the check
+    # that skips the `params` module.
+    extra_lines = []
+    extra_repr = self.extra_repr()
+    # empty string will be split into list ['']
+    if extra_repr:
+        extra_lines = extra_repr.split('\n')
+    child_lines = []
+    for key, module in self._modules.items():
+        if isinstance(module, params):
+            continue
+        mod_str = repr(module)
+        mod_str = _addindent(mod_str, 2)
+        child_lines.append('(' + key + '): ' + mod_str)
+    lines = extra_lines + child_lines
+
+    main_str = self._get_name() + '('
+    if lines:
+        # simple one-liner info, which most builtin Modules will use
+        if len(extra_lines) == 1 and not child_lines:
+            main_str += extra_lines[0]
+        else:
+            main_str += '\n  ' + '\n  '.join(lines) + '\n'
+
+    main_str += ')'
+    return main_str
+
+_pair_from_first = _ntuple_from_first(2)
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db56382dc7634a09159e944b9dbbc234ae1e16b6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__init__.py
@@ -0,0 +1,18 @@
+from .modules import *  # noqa: F403
+
+__all__ = [
+    'Linear',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+    'RNNCell',
+    'LSTMCell',
+    'GRUCell',
+    'LSTM',
+    'GRU',
+    'Embedding',
+    'EmbeddingBag',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d07bccfe77e6a5286bf3ef719482bed0250ff00
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d2f5206435155cd6074cfbab55ee7a70f6d43fc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py
@@ -0,0 +1,21 @@
+from .linear import Linear
+from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
+from .rnn import RNNCell, LSTMCell, GRUCell, LSTM, GRU
+from .sparse import Embedding, EmbeddingBag
+
+__all__ = [
+    'Linear',
+    'Conv1d',
+    'Conv2d',
+    'Conv3d',
+    'ConvTranspose1d',
+    'ConvTranspose2d',
+    'ConvTranspose3d',
+    'RNNCell',
+    'LSTMCell',
+    'GRUCell',
+    'LSTM',
+    'GRU',
+    'Embedding',
+    'EmbeddingBag',
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1377534f853a6fb8c78733417dd60b4a87ad9ca
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84d177f82b9b0f452304f73acce77a7c31516b52
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..698626b0b0603706f7c83c898196aaf8962d02d2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..047155df3cc65dffc936511bc7caf86784a7bd05
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1beec7cfd6c27fa993417a45ef1fac866108821
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..478ff47be8d56270e06c4bf1994ee60882c8b729
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd0dde9d99c6bc6d5bbd0e9bbe8fac9c233fb5f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/conv.py
@@ -0,0 +1,318 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Dict, Any, List
+from torch.nn.common_types import _size_1_t
+from .utils import ReferenceQuantizedModule
+
+__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
+
+class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
+    """ A reference version of nn.quantized.Conv2d
+        we will not pack the parameters in this module, since weight packing is an
+        optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
+        this is useful when user want to use this module in other backends like Glow.
+    """
+    __annotations__ = {"bias": Optional[torch.Tensor]}
+    _IS_REFERENCE = True
+
+    @staticmethod
+    def from_float(cls, float_conv, weight_qparams):
+        qref_conv = cls(
+            float_conv.in_channels,
+            float_conv.out_channels,
+            float_conv.kernel_size,  # type: ignore[arg-type]
+            float_conv.stride,  # type: ignore[arg-type]
+            float_conv.padding,  # type: ignore[arg-type]
+            float_conv.dilation,  # type: ignore[arg-type]
+            float_conv.groups,
+            float_conv.bias is not None,  # type: ignore[arg-type]
+            float_conv.padding_mode,
+            device=float_conv.weight.device,
+            dtype=float_conv.weight.dtype,
+            weight_qparams=weight_qparams)
+        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
+        if float_conv.bias is not None:
+            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
+        return qref_conv
+
+class Conv1d(_ConvNd, nn.Conv1d):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: _size_1_t = 0,
+                 dilation: _size_1_t = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 padding_mode: str = "zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv1d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv1d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv1d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv1d
+        """
+        weight_quant_dequant = self.get_weight()
+        result = F.conv1d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, self.dilation, self.groups)
+        return result
+
+    def _get_name(self):
+        return "QuantizedConv1d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv2d(_ConvNd, nn.Conv2d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros',
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv2d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv2d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv2d
+        """
+        weight_quant_dequant = self.get_weight()
+        result = F.conv2d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, self.dilation, self.groups)
+        return result
+
+    def _get_name(self):
+        return "QuantizedConv2d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class Conv3d(_ConvNd, nn.Conv3d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode="zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.Conv3d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, bias, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv3d ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv3d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv3d
+        """
+        weight_quant_dequant = self.get_weight()
+        result = F.conv3d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, self.dilation, self.groups)
+        return result
+
+    def _get_name(self):
+        return "QuantizedConv3d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvNd.from_float(cls, float_conv, weight_qparams)
+
+class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
+    """ A reference version of nn.quantized.ConvTranspose2d
+        we will not pack the parameters in this module, since weight packing is an
+        optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
+        this is useful when user want to use this module in other backends like Glow.
+    """
+    @staticmethod
+    def from_float(cls, float_conv, weight_qparams):
+        qref_conv = cls(
+            float_conv.in_channels,
+            float_conv.out_channels,
+            float_conv.kernel_size,  # type: ignore[arg-type]
+            float_conv.stride,  # type: ignore[arg-type]
+            float_conv.padding,  # type: ignore[arg-type]
+            float_conv.output_padding,  # type: ignore[arg-type]
+            float_conv.groups,
+            float_conv.bias is not None,  # type: ignore[arg-type]
+            float_conv.dilation,  # type: ignore[arg-type]
+            float_conv.padding_mode,
+            device=float_conv.weight.device,
+            dtype=float_conv.weight.dtype,
+            weight_qparams=weight_qparams)
+        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
+        if float_conv.bias is not None:
+            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
+        return qref_conv
+
+
+class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: _size_1_t,
+                 stride: _size_1_t = 1,
+                 padding: _size_1_t = 0,
+                 output_padding: _size_1_t = 0,
+                 groups: int = 1,
+                 bias: bool = True,
+                 dilation: _size_1_t = 1,
+                 padding_mode: str = "zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.ConvTranspose1d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv_transpose1d ---
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv_transpose1d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv1d
+        """
+
+        assert isinstance(self.padding, tuple)
+        # One cannot replace List by Tuple or Sequence in "_output_padding" because
+        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+        output_padding = self._output_padding(
+            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
+
+        weight_quant_dequant = self.get_weight()
+        result = F.conv_transpose1d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, output_padding, self.groups, self.dilation)
+        return result
+
+    def _get_name(self):
+        return "QuantizedConvTranspose1d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
+
+class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0,
+                 groups=1, bias=True, dilation=1,
+                 padding_mode='zeros',
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+
+        nn.ConvTranspose2d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv_transpose2d ---
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv_transpose2d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv2d
+        """
+        assert isinstance(self.padding, tuple)
+        # One cannot replace List by Tuple or Sequence in "_output_padding" because
+        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+
+        output_padding = self._output_padding(
+            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
+
+        weight_quant_dequant = self.get_weight()
+        result = F.conv_transpose2d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, output_padding, self.groups, self.dilation)
+
+        return result
+
+    def _get_name(self):
+        return "QuantizedConvTranspose2d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
+
+class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, output_padding=0,
+                 groups=1, bias=True, dilation=1,
+                 padding_mode="zeros",
+                 device=None,
+                 dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None):
+        nn.ConvTranspose3d.__init__(
+            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
+            groups, bias, dilation, padding_mode, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.conv_transpose3d ---
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.conv_transpose3d --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized conv3d
+        """
+
+        assert isinstance(self.padding, tuple)
+        # One cannot replace List by Tuple or Sequence in "_output_padding" because
+        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+        output_padding = self._output_padding(
+            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
+
+        weight_quant_dequant = self.get_weight()
+        result = F.conv_transpose3d(
+            x, weight_quant_dequant, self.bias, self.stride,
+            self.padding, output_padding, self.groups, self.dilation)
+        return result
+
+    def _get_name(self):
+        return "QuantizedConvTranspose3d(Reference)"
+
+    @classmethod
+    def from_float(cls, float_conv, weight_qparams):
+        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
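+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (not part of the upstream module): build a reference
+    # quantized Conv2d from a float nn.Conv2d. The per-tensor weight qparams below
+    # are arbitrary placeholder values chosen only for demonstration; in the FX
+    # workflow they would normally come from observers.
+    float_conv = torch.nn.Conv2d(3, 8, kernel_size=3)
+    demo_weight_qparams = {
+        "qscheme": torch.per_tensor_affine,
+        "dtype": torch.qint8,
+        "scale": 0.02,
+        "zero_point": 0,
+    }
+    ref_conv = Conv2d.from_float(float_conv, demo_weight_qparams)
+    # The weight is fake-quantized on every forward call; activations stay float.
+    out = ref_conv(torch.randn(1, 3, 32, 32))
+    print(ref_conv._get_name(), tuple(out.shape))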
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec5def7ad1ae00c47b68742219504be9cf06eb3a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/linear.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Dict, Any
+from .utils import ReferenceQuantizedModule
+
+__all__ = ['Linear']
+
+class Linear(nn.Linear, ReferenceQuantizedModule):
+    """ A reference quantized linear module that fits into the FX
+    Graph Mode Quantization workflow
+    activation will be floating point Tensor, we will store floating
+    point weight as well in the module, but in forward we'll quantize
+    and dequantize the weight before running the floating point functional
+    linear operator.
+    """
+    _IS_REFERENCE = True
+
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias_: bool = True,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+            weight_qparams: Optional[Dict[str, Any]] = None):
+        super().__init__(in_features, out_features, bias_, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def _get_name(self):
+        return "QuantizedLinear(Reference)"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        we have:
+        w(float) -- quant - dequant \
+        x(float) ------------- F.linear ---
+
+        In the full model, we will see
+        w(float) -- quant - *dequant \
+        x -- quant --- *dequant --  *F.linear --- *quant - dequant
+        and the backend should be able to fuse the ops with `*` into a quantized linear
+        """
+        weight_quant_dequant = self.get_weight()
+        result = F.linear(x, weight_quant_dequant, self.bias)
+        return result
+
+    @classmethod
+    def from_float(cls, float_linear, weight_qparams):
+        qref_linear = Linear(
+            float_linear.in_features, float_linear.out_features,
+            float_linear.bias is not None, device=float_linear.weight.device,
+            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
+        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
+        if float_linear.bias is not None:
+            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
+        return qref_linear
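+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (not part of the upstream module): convert a float
+    # nn.Linear into the reference quantized Linear. The qparams are arbitrary
+    # placeholder values; in the FX workflow they would normally come from observers.
+    float_linear = nn.Linear(16, 4)
+    demo_weight_qparams = {
+        "qscheme": torch.per_tensor_affine,
+        "dtype": torch.qint8,
+        "scale": 0.05,
+        "zero_point": 0,
+    }
+    ref_linear = Linear.from_float(float_linear, demo_weight_qparams)
+    y = ref_linear(torch.randn(2, 16))
+    print(ref_linear._get_name(), tuple(y.shape))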
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..98fa588e012178bc6af29529c13c9193a5b75213
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py
@@ -0,0 +1,614 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from .utils import _quantize_and_dequantize_weight
+from .utils import _quantize_weight
+from typing import Optional, Dict, Any, Tuple
+from torch import _VF
+from torch.nn.utils.rnn import PackedSequence
+
+__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']
+
+def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
+    return tensor.index_select(dim, permutation)
+
+def _get_weight_and_quantization_params(module, wn):
+    weight = getattr(module, wn)
+    params = [weight]
+    for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]:
+        if hasattr(module, param_name):
+            param = getattr(module, param_name)
+        else:
+            param = None
+        params.append(param)
+    return params
+
+def get_quantized_weight(module, wn):
+    if not hasattr(module, wn):
+        return None
+    params = _get_weight_and_quantization_params(module, wn)
+    weight = _quantize_weight(*params)
+    return weight
+
+def _get_quantize_and_dequantized_weight(module, wn):
+    if not hasattr(module, wn):
+        return None
+    params = _get_weight_and_quantization_params(module, wn)
+    weight = _quantize_and_dequantize_weight(*params)
+    return weight
+
+class RNNCellBase(nn.RNNCellBase):
+    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
+                 device=None, dtype=None, weight_qparams_dict=None) -> None:
+        super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype)
+        # TODO(jerryzh168): maybe make this arg a required arg
+        if weight_qparams_dict is None:
+            weight_qparams = {
+                "qscheme": torch.per_tensor_affine,
+                "dtype": torch.quint8,
+                "scale": 1.0,
+                "zero_point": 0
+            }
+            weight_qparams_dict = {
+                "weight_ih": weight_qparams,
+                "weight_hh": weight_qparams,
+                "is_decomposed": False,
+            }
+        assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
+        self._init_weight_qparams_dict(weight_qparams_dict, device)
+
+    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
+        assert weight_qparams_dict is not None
+        self.is_decomposed = weight_qparams_dict["is_decomposed"]
+        for key, weight_qparams in weight_qparams_dict.items():
+            if key == "is_decomposed":
+                continue
+            # TODO: refactor the duplicated code to utils.py
+            weight_qscheme = weight_qparams["qscheme"]
+            weight_dtype = weight_qparams["dtype"]
+            setattr(self, key + "_qscheme", weight_qscheme)
+            setattr(self, key + "_dtype", weight_dtype)
+            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
+            if weight_qscheme is not None:
+                scale = weight_qparams["scale"]
+                scale_tensor = scale.clone().detach() \
+                    if isinstance(scale, torch.Tensor) else \
+                    torch.tensor(scale, dtype=torch.float, device=device)
+                self.register_buffer(key + "_scale", scale_tensor)
+                zp = weight_qparams["zero_point"]
+                zp_tensor = zp.clone().detach() \
+                    if isinstance(zp, torch.Tensor) else \
+                    torch.tensor(zp, dtype=torch.int, device=device)
+                self.register_buffer(key + "_zero_point", zp_tensor)
+                if weight_qscheme == torch.per_channel_affine:
+                    axis = weight_qparams["axis"]
+                    axis_tensor = axis.clone().detach() \
+                        if isinstance(axis, torch.Tensor) else \
+                        torch.tensor(axis, dtype=torch.int, device=device)
+                    self.register_buffer(key + "_axis", axis_tensor)
+                else:
+                    # added for TorchScriptability, not used
+                    self.register_buffer(
+                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
+                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
+
+    def _get_name(self):
+        return "QuantizedRNNCellBase(Reference)"
+
+    def get_quantized_weight_ih(self):
+        return get_quantized_weight(self, "weight_ih")
+
+    def get_quantized_weight_hh(self):
+        return get_quantized_weight(self, "weight_hh")
+
+    def get_weight_ih(self):
+        return _get_quantize_and_dequantized_weight(self, "weight_ih")
+
+    def get_weight_hh(self):
+        return _get_quantize_and_dequantized_weight(self, "weight_hh")
+
+class RNNCell(RNNCellBase):
+    """
+    We store weight_qparams for all the weights (weight_ih and weight_hh);
+    callers need to pass in a `weight_qparams_dict` that maps from a weight name,
+    e.g. weight_ih, to the weight_qparams for that weight.
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
+        self.nonlinearity = nonlinearity
+
+    def _get_name(self):
+        return "QuantizedRNNCell(Reference)"
+
+    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
+    # and remove duplicated code, same for the other two Cell modules
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        assert input.dim() in (1, 2), \
+            f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            hx = hx.unsqueeze(0) if not is_batched else hx
+
+        if self.nonlinearity == "tanh":
+            ret = _VF.rnn_tanh_cell(
+                input, hx,
+                self.get_weight_ih(), self.get_weight_hh(),
+                self.bias_ih, self.bias_hh,
+            )
+        elif self.nonlinearity == "relu":
+            ret = _VF.rnn_relu_cell(
+                input, hx,
+                self.get_weight_ih(), self.get_weight_hh(),
+                self.bias_ih, self.bias_hh,
+            )
+        else:
+            ret = input  # TODO: remove when jit supports exception flow
+            raise RuntimeError(
+                f"Unknown nonlinearity: {self.nonlinearity}")
+
+        if not is_batched:
+            ret = ret.squeeze(0)
+
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.nonlinearity,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class LSTMCell(RNNCellBase):
+    """
+    We store weight_qparams for all the weights (weight_ih and weight_hh);
+    callers need to pass in a `weight_qparams_dict` that maps from a weight name,
+    e.g. weight_ih, to the weight_qparams for that weight.
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
+
+    def _get_name(self):
+        return "QuantizedLSTMCell(Reference)"
+
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
+        assert input.dim() in (1, 2), \
+            f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            hx = (zeros, zeros)
+        else:
+            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
+
+        ret = _VF.lstm_cell(
+            input, hx,
+            self.get_weight_ih(), self.get_weight_hh(),
+            self.bias_ih, self.bias_hh,
+        )
+
+        if not is_batched:
+            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class GRUCell(RNNCellBase):
+    """
+    We store weight_qparams for all the weights (weight_ih and weight_hh);
+    callers need to pass in a `weight_qparams_dict` that maps from a weight name,
+    e.g. weight_ih, to the weight_qparams for that weight.
+    """
+    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
+                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
+        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
+
+    def _get_name(self):
+        return "QuantizedGRUCell(Reference)"
+
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        assert input.dim() in (1, 2), \
+            f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
+        is_batched = input.dim() == 2
+        if not is_batched:
+            input = input.unsqueeze(0)
+
+        if hx is None:
+            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            hx = hx.unsqueeze(0) if not is_batched else hx
+
+        ret = _VF.gru_cell(
+            input, hx,
+            self.get_weight_ih(), self.get_weight_hh(),
+            self.bias_ih, self.bias_hh,
+        )
+
+        if not is_batched:
+            ret = ret.squeeze(0)
+
+        return ret
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.bias,
+            mod.weight_ih.device,
+            mod.weight_ih.dtype,
+            weight_qparams_dict)
+        ref_mod.weight_ih = mod.weight_ih
+        ref_mod.weight_hh = mod.weight_hh
+        ref_mod.bias_ih = mod.bias_ih
+        ref_mod.bias_hh = mod.bias_hh
+        return ref_mod
+
+class RNNBase(nn.RNNBase):
+    def __init__(self, mode: str, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
+                 dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
+                 device=None, dtype=None,
+                 weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__(
+            mode, input_size, hidden_size, num_layers, bias, batch_first, dropout,
+            bidirectional, proj_size, device, dtype
+        )
+        # TODO(jerryzh168): maybe make this arg a required arg
+        if weight_qparams_dict is None:
+            weight_qparams = {
+                'qscheme': torch.per_tensor_affine,
+                'dtype': torch.quint8,
+                'scale': 1.0,
+                'zero_point': 0
+            }
+            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
+            for wn in self._flat_weights_names:
+                if wn.startswith("weight"):
+                    weight_qparams_dict[wn] = weight_qparams
+        self._init_weight_qparams_dict(weight_qparams_dict, device)
+
+    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
+        self.is_decomposed = weight_qparams_dict["is_decomposed"]
+        for key, weight_qparams in weight_qparams_dict.items():
+            if key == "is_decomposed":
+                continue
+            weight_qscheme = weight_qparams["qscheme"]
+            weight_dtype = weight_qparams["dtype"]
+            setattr(self, key + "_qscheme", weight_qscheme)
+            setattr(self, key + "_dtype", weight_dtype)
+            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
+                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
+            if weight_qscheme is not None:
+                self.register_buffer(
+                    key + "_scale",
+                    torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
+                self.register_buffer(
+                    key + "_zero_point",
+                    torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
+                if weight_qscheme == torch.per_channel_affine:
+                    self.register_buffer(
+                        key + "_axis",
+                        torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
+                else:
+                    # added for TorchScriptability, not used
+                    self.register_buffer(
+                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
+                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
+
+class LSTM(RNNBase):
+    """ Reference Quantized LSTM Module
+    We store weight_qparams for all the weights in _flat_weights; callers need to pass in
+    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
+    to the weight_qparams for that weight.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__('LSTM', *args, **kwargs)
+
+    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
+    def permute_hidden(self,  # type: ignore[override]
+                       hx: Tuple[Tensor, Tensor],
+                       permutation: Optional[Tensor]
+                       ) -> Tuple[Tensor, Tensor]:
+        if permutation is None:
+            return hx
+        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
+
+    def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
+        if batch_sizes is not None:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input.size(0) if self.batch_first else input.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        expected_hidden_size = (self.num_layers * num_directions,
+                                mini_batch, self.hidden_size)
+        return expected_hidden_size
+
+    # In the future, we should prevent mypy from applying contravariance rules here.
+    # See torch/nn/modules/module.py::_forward_unimplemented
+    def check_forward_args(self,  # type: ignore[override]
+                           input: Tensor,
+                           hidden: Tuple[Tensor, Tensor],
+                           batch_sizes: Optional[Tensor],
+                           ):
+        self.check_input(input, batch_sizes)
+        self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
+                               'Expected hidden[0] size {}, got {}')
+        self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
+                               'Expected hidden[1] size {}, got {}')
+
+    def get_quantized_weight_bias_dict(self):
+        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
+        e.g.
+        {
+          "weight_ih_l0": quantized_weight,
+          "bias_ih_l0": unquantized_bias,
+          ...
+        }
+        """
+        quantized_weight_bias_dict = {}
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                if wn.startswith("weight"):
+                    weight_or_bias = get_quantized_weight(self, wn)
+                else:
+                    weight_or_bias = getattr(self, wn)
+            else:
+                weight_or_bias = None
+            quantized_weight_bias_dict[wn] = weight_or_bias
+        return quantized_weight_bias_dict
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                weight = getattr(self, wn)
+                if wn.startswith("weight"):
+                    params = _get_weight_and_quantization_params(self, wn)
+                    weight = _quantize_and_dequantize_weight(*params)
+            else:
+                weight = None
+            flat_weights.append(weight)
+        return flat_weights
+
+    def forward(self, input, hx=None):  # noqa: F811
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        batch_sizes = None
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
+            h_zeros = torch.zeros(self.num_layers * num_directions,
+                                  max_batch_size, real_hidden_size,
+                                  dtype=input.dtype, device=input.device)
+            c_zeros = torch.zeros(self.num_layers * num_directions,
+                                  max_batch_size, self.hidden_size,
+                                  dtype=input.dtype, device=input.device)
+            hx = (h_zeros, c_zeros)
+        else:
+            if batch_sizes is None:  # If not PackedSequence input.
+                if is_batched:  # type: ignore[possibly-undefined]
+                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
+                        msg = ("For batched 3-D input, hx and cx should "
+                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                else:
+                    if hx[0].dim() != 2 or hx[1].dim() != 2:
+                        msg = ("For unbatched 2-D input, hx and cx should "
+                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
+                        raise RuntimeError(msg)
+                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
+
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+                              self.dropout, self.training, self.bidirectional, self.batch_first)
+        else:
+            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+                              self.num_layers, self.dropout, self.training, self.bidirectional)
+        output = result[0]
+        hidden = result[1:]
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            if not is_batched:  # type: ignore[possibly-undefined]
+                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def _get_name(self):
+        return "QuantizedLSTM(Reference)"
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.num_layers,
+            mod.bias,
+            mod.batch_first,
+            mod.dropout,
+            mod.bidirectional,
+            weight_qparams_dict=weight_qparams_dict)
+        for wn in mod._flat_weights_names:
+            setattr(ref_mod, wn, getattr(mod, wn))
+        return ref_mod
+
+class GRU(RNNBase):
+    """ Reference Quantized GRU Module
+    We store weight_qparams for all the weights in _flat_weights; callers need to pass in
+    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
+    to the weight_qparams for that weight.
+    """
+    def __init__(self, *args, **kwargs):
+        if 'proj_size' in kwargs:
+            raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
+        super().__init__('GRU', *args, **kwargs)
+
+    def get_quantized_weight_bias_dict(self):
+        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
+        e.g.
+        {
+          "weight_ih_l0": quantized_weight,
+          "bias_ih_l0": unquantized_bias,
+          ...
+        }
+        """
+        quantized_weight_bias_dict = {}
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                if wn.startswith("weight"):
+                    weight_or_bias = get_quantized_weight(self, wn)
+                else:
+                    weight_or_bias = getattr(self, wn)
+            else:
+                weight_or_bias = None
+            quantized_weight_bias_dict[wn] = weight_or_bias
+        return quantized_weight_bias_dict
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for wn in self._flat_weights_names:
+            if hasattr(self, wn):
+                weight = getattr(self, wn)
+                if wn.startswith("weight"):
+                    params = _get_weight_and_quantization_params(self, wn)
+                    weight = _quantize_and_dequantize_weight(*params)
+            else:
+                weight = None
+            flat_weights.append(weight)
+        return flat_weights
+
+    def forward(self, input, hx=None):  # noqa: F811
+        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py;
+        # the only change is that self._flat_weights is replaced by self.get_flat_weights().
+        # TODO: maybe we can inherit from that class and define get_flat_weights
+        # as a @property? That might interfere with TorchScript; if that requirement is
+        # removed in the future we should be able to do this.
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = int(batch_sizes[0])
+        else:
+            batch_sizes = None
+            assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
+            is_batched = input.dim() == 3
+            batch_dim = 0 if self.batch_first else 1
+            if not is_batched:
+                input = input.unsqueeze(batch_dim)
+                if hx is not None:
+                    if hx.dim() != 2:
+                        raise RuntimeError(
+                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
+                    hx = hx.unsqueeze(1)
+            else:
+                if hx is not None and hx.dim() != 3:
+                    raise RuntimeError(
+                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            hx = torch.zeros(self.num_layers * num_directions,
+                             max_batch_size, self.hidden_size,
+                             dtype=input.dtype, device=input.device)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
+                             self.dropout, self.training, self.bidirectional, self.batch_first)
+        else:
+            result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
+                             self.num_layers, self.dropout, self.training, self.bidirectional)
+        output = result[0]
+        hidden = result[1]
+
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            if not is_batched:  # type: ignore[possibly-undefined]
+                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
+                hidden = hidden.squeeze(1)
+
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def _get_name(self):
+        return "QuantizedGRU(Reference)"
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams_dict):
+        ref_mod = cls(
+            mod.input_size,
+            mod.hidden_size,
+            mod.num_layers,
+            mod.bias,
+            mod.batch_first,
+            mod.dropout,
+            mod.bidirectional,
+            weight_qparams_dict=weight_qparams_dict)
+        for wn in mod._flat_weights_names:
+            setattr(ref_mod, wn, getattr(mod, wn))
+        return ref_mod
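+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (not part of the upstream module): wrap a float
+    # nn.LSTM as a reference quantized LSTM. Passing weight_qparams_dict=None lets
+    # RNNBase fill in default per-tensor affine qparams for every flat weight.
+    float_lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
+    ref_lstm = LSTM.from_float(float_lstm, weight_qparams_dict=None)
+    seq = torch.randn(2, 5, 8)  # (batch, seq_len, input_size)
+    output, (h_n, c_n) = ref_lstm(seq)
+    print(ref_lstm._get_name(), tuple(output.shape), tuple(h_n.shape))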
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2d6141d9de6321f89ac691404c060c821cd7719
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from .utils import ReferenceQuantizedModule
+from typing import Optional, Dict, Any
+
+__all__ = ['Embedding', 'EmbeddingBag']
+
+class Embedding(nn.Embedding, ReferenceQuantizedModule):
+    """ A reference quantized Embedding module that fits into the
+    FX Graph Mode Quantization workflow, activation will be floating point Tensor,
+    we will store floating point weight as well in the module, but in forward we'll
+    quantize and dequantize the weight before running the floating point functional
+    embedding operator.
+    """
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 sparse: bool = False, _weight: Optional[Tensor] = None,
+                 device=None, dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
+                         norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def _get_name(self):
+        return "QuantizedEmbedding(Reference)"
+
+    def forward(self, input: Tensor) -> Tensor:
+        weight_quant_dequant = self.get_weight()
+        return F.embedding(
+            input, weight_quant_dequant, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams):
+        return cls(
+            mod.num_embeddings,
+            mod.embedding_dim,
+            mod.padding_idx,
+            mod.max_norm,
+            mod.norm_type,
+            mod.scale_grad_by_freq,
+            mod.sparse,
+            mod.weight,
+            mod.weight.device,
+            mod.weight.dtype,
+            weight_qparams)
+
+class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
+    """ A reference quantized EmbeddingBag module that fits into the
+    FX Graph Mode Quantization workflow, activation will be floating point Tensor,
+    we will store floating point weight as well in the module, but in forward we'll
+    quantize and dequantize the weight before running the floating point functional
+    embedding operator.
+    """
+    def __init__(self, num_embeddings: int, embedding_dim: int,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
+                 include_last_offset: bool = False, padding_idx: Optional[int] = None,
+                 device=None, dtype=None,
+                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
+                         scale_grad_by_freq, mode, sparse, _weight, include_last_offset,
+                         padding_idx, device, dtype)
+        self._init_weight_qparams(weight_qparams, device)
+
+    def _get_name(self):
+        return "QuantizedEmbedding(Reference)"
+
+    def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
+        weight_quant_dequant = self.get_weight()
+        return F.embedding_bag(input, weight_quant_dequant, offsets,
+                               self.max_norm, self.norm_type,
+                               self.scale_grad_by_freq, self.mode, self.sparse,
+                               per_sample_weights, self.include_last_offset,
+                               self.padding_idx)
+
+    @classmethod
+    def from_float(cls, mod, weight_qparams):
+        return cls(
+            mod.num_embeddings,
+            mod.embedding_dim,
+            mod.max_norm,
+            mod.norm_type,
+            mod.scale_grad_by_freq,
+            mod.mode,
+            mod.sparse,
+            mod.weight,
+            mod.include_last_offset,
+            mod.padding_idx,
+            mod.weight.device,
+            mod.weight.dtype,
+            weight_qparams
+        )
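+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (not part of the upstream module): build a reference
+    # quantized Embedding with per-channel (per-row) weight qparams. The scale and
+    # zero_point values below are arbitrary placeholders.
+    import torch
+
+    float_emb = nn.Embedding(10, 4)
+    demo_weight_qparams = {
+        "qscheme": torch.per_channel_affine,
+        "dtype": torch.qint8,
+        "scale": torch.full((10,), 0.1),
+        "zero_point": torch.zeros(10, dtype=torch.int),
+        "axis": 0,
+    }
+    ref_emb = Embedding.from_float(float_emb, demo_weight_qparams)
+    out = ref_emb(torch.tensor([[1, 2, 3]]))
+    print(ref_emb._get_name(), tuple(out.shape))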
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..502f77496c4b658ed521e27404af99e4e7cab4b6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/quantized/reference/modules/utils.py
@@ -0,0 +1,323 @@
+import torch
+import typing
+
+__all__ = [
+    "ReferenceQuantizedModule",
+]
+
+class ReferenceQuantizedModule(torch.nn.Module):
+    def _init_weight_qparams(self, weight_qparams, device):
+        if weight_qparams is None:
+            weight_qparams = {
+                "qscheme": torch.per_tensor_affine,
+                "dtype": torch.quint8,
+                "scale": 1.0,
+                "zero_point": 0
+            }
+        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
+        self.weight_dtype = weight_qparams["dtype"]
+        assert self.weight_qscheme in [
+            None, torch.per_tensor_affine, torch.per_channel_affine,
+            torch.per_channel_affine_float_qparams], \
+            Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")
+        if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
+            zero_point_dtype = weight_qparams["zero_point"].dtype if \
+                isinstance(weight_qparams["zero_point"], torch.Tensor) else \
+                torch.int
+            w_scale = weight_qparams["scale"]
+            w_scale_tensor = w_scale.clone().detach() \
+                if isinstance(w_scale, torch.Tensor) \
+                else torch.tensor(w_scale, dtype=torch.float, device=device)
+            self.register_buffer("weight_scale", w_scale_tensor)
+            w_zp = weight_qparams["zero_point"]
+            w_zp_tensor = w_zp.clone().detach() \
+                if isinstance(w_zp, torch.Tensor) \
+                else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
+            self.register_buffer("weight_zero_point", w_zp_tensor)
+            if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
+                w_axis = weight_qparams["axis"]
+                w_axis_tensor = w_axis.clone().detach() \
+                    if isinstance(w_axis, torch.Tensor) \
+                    else torch.tensor(w_axis, dtype=torch.int, device=device)
+                self.register_buffer("weight_axis", w_axis_tensor)
+            else:
+                # added for TorchScriptability, not used
+                self.register_buffer(
+                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
+        else:
+            # added for TorchScriptability, and for torch.float
+            self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
+            self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
+            self.register_buffer(
+                "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
+        self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
+        # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
+        # for capturing `.item` operations
+        self.weight_axis_int: int = self.weight_axis.item()  # type: ignore[operator, assignment]
+        self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min", None)
+        self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max", None)
+
+    def get_weight(self):
+        """
+        Fake quantize (quantize and dequantize) the weight with
+        the quantization parameters for the weight. This is used to
+        simulate the numerics of the quantized weight in a quantized
+        model.
+        """
+        # suppress mypy warning
+        assert isinstance(self.weight_scale, torch.Tensor)
+        assert isinstance(self.weight_zero_point, torch.Tensor)
+        if self.is_decomposed:
+            return _quantize_and_dequantize_weight_decomposed(
+                self.weight,  # type: ignore[arg-type]
+                self.weight_qscheme,
+                self.weight_dtype,
+                self.weight_scale,
+                self.weight_zero_point,
+                self.weight_axis_int,
+                self.weight_quant_min,
+                self.weight_quant_max)
+        else:
+            return _quantize_and_dequantize_weight(
+                self.weight,  # type: ignore[arg-type]
+                self.weight_qscheme,
+                self.weight_dtype,
+                self.weight_scale,
+                self.weight_zero_point,
+                self.weight_axis_int)
+
+    def get_quantized_weight(self):
+        # suppress mypy warning
+        assert isinstance(self.weight_scale, torch.Tensor)
+        assert isinstance(self.weight_zero_point, torch.Tensor)
+        # assert isinstance(self.weight_axis, torch.Tensor)
+        if self.is_decomposed:
+            return _quantize_weight_decomposed(
+                self.weight,  # type: ignore[arg-type]
+                self.weight_qscheme,
+                self.weight_dtype,
+                self.weight_scale,
+                self.weight_zero_point,
+                self.weight_axis_int,
+                self.weight_quant_min,
+                self.weight_quant_max)
+        else:
+            return _quantize_weight(
+                self.weight,  # type: ignore[arg-type]
+                self.weight_qscheme,
+                self.weight_dtype,
+                self.weight_scale,
+                self.weight_zero_point,
+                self.weight_axis_int)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        _save_weight_qparams(
+            destination, prefix, self.weight_qscheme, self.weight_dtype,
+            self.weight_scale, self.weight_zero_point, self.weight_axis)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        for key in _get_weight_qparam_keys(state_dict, prefix):
+            setattr(self, key, state_dict[prefix + key])
+            state_dict.pop(prefix + key)
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False,
+            missing_keys, unexpected_keys, error_msgs)
+
+def _quantize_weight_decomposed(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis: int,
+        weight_quant_min: typing.Optional[int],
+        weight_quant_max: typing.Optional[int],
+) -> torch.Tensor:
+    _DTYPE_TO_QVALUE_BOUNDS = {
+        torch.uint8: (0, 255),
+        torch.int8: (-128, 127),
+        torch.int32: (-(2**31), 2**31 - 1),
+    }
+    # TODO: add a util function for converting qdtype to dtype
+    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
+        torch.quint8: torch.uint8,
+        torch.qint8: torch.int8,
+        torch.qint32: torch.int32,
+    }
+    if weight_qscheme == torch.per_tensor_affine:
+        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
+            if weight_quant_min is None or weight_quant_max is None:
+                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
+            weight = torch.ops.quantized_decomposed.quantize_per_tensor(
+                weight,
+                weight_scale,
+                weight_zero_point,
+                weight_quant_min,
+                weight_quant_max,
+                weight_dtype_
+            )
+            return weight
+    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
+        # TODO: torch.quint4x2 is not supported
+        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
+            if weight_quant_min is None or weight_quant_max is None:
+                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
+            weight = torch.ops.quantized_decomposed.quantize_per_channel(
+                weight,
+                weight_scale,
+                weight_zero_point,
+                weight_axis,
+                weight_quant_min,
+                weight_quant_max,
+                weight_dtype_)  # type: ignore[arg-type]
+            return weight
+    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
+
+def _dequantize_weight_decomposed(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis: int,
+        weight_quant_min: typing.Optional[int],
+        weight_quant_max: typing.Optional[int],
+) -> torch.Tensor:
+    # TODO: get the quant_min and quant_max from activation_post_process
+    _DTYPE_TO_QVALUE_BOUNDS = {
+        torch.uint8: (0, 255),
+        torch.int8: (-128, 127),
+        torch.int32: (-(2**31), 2**31 - 1),
+    }
+    # TODO: add a util function for converting qdtype to dtype
+    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
+        torch.quint8: torch.uint8,
+        torch.qint8: torch.int8,
+        torch.qint32: torch.int32,
+    }
+    weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
+    if weight_quant_min is None or weight_quant_max is None:
+        weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
+    if weight_qscheme == torch.per_tensor_affine:
+        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
+                weight,
+                weight_scale,
+                weight_zero_point,
+                weight_quant_min,
+                weight_quant_max,
+                weight_dtype_
+            )
+            return weight
+    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
+        # TODO: torch.quint4x2 is not supported
+        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            weight = torch.ops.quantized_decomposed.dequantize_per_channel(
+                weight,
+                weight_scale,
+                weight_zero_point,
+                weight_axis,
+                weight_quant_min,
+                weight_quant_max,
+                weight_dtype_)  # type: ignore[arg-type]
+            return weight
+    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
+
+def _quantize_weight(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis_int: int
+) -> torch.Tensor:
+    if weight_dtype == torch.float16:
+        weight = weight.to(weight_dtype)
+        return weight
+
+    if weight_qscheme == torch.per_tensor_affine:
+        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
+            return weight
+    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
+        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
+            weight = torch.quantize_per_channel(
+                weight, weight_scale,
+                weight_zero_point, weight_axis_int, weight_dtype)  # type: ignore[arg-type]
+            return weight
+    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
+
+def _quantize_and_dequantize_weight_decomposed(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis_int: int,
+        weight_quant_min: typing.Optional[int],
+        weight_quant_max: typing.Optional[int],
+) -> torch.Tensor:
+    """ Quantize and then dequantize the weight based on
+    the quantization parameters
+    """
+    if weight_qscheme in [
+            torch.per_tensor_affine,
+            torch.per_channel_affine,
+            torch.per_channel_affine_float_qparams]:
+        weight_quant = _quantize_weight_decomposed(
+            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int,
+            weight_quant_min, weight_quant_max)
+        weight_dequant = _dequantize_weight_decomposed(
+            weight_quant, weight_qscheme, weight_dtype, weight_scale, weight_zero_point,
+            weight_axis_int, weight_quant_min, weight_quant_max)
+    else:
+        weight_dequant = weight
+    return weight_dequant
+
+def _quantize_and_dequantize_weight(
+        weight: torch.Tensor,
+        weight_qscheme: torch.qscheme,
+        weight_dtype: torch.dtype,
+        weight_scale: torch.Tensor,
+        weight_zero_point: torch.Tensor,
+        weight_axis_int: int
+) -> torch.Tensor:
+    """ Quantize and then dequantize the weight based on
+    the quantization parameters
+    """
+    if weight_qscheme in [
+            torch.per_tensor_affine,
+            torch.per_channel_affine,
+            torch.per_channel_affine_float_qparams]:
+        weight_quant = _quantize_weight(
+            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
+        weight_dequant = weight_quant.dequantize()
+    else:
+        weight_dequant = weight
+    return weight_dequant
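+
+# A minimal usage sketch (illustrative only, not part of the library):
+# fake-quantize a float weight per channel and get back a float tensor of the
+# same shape, e.g.
+#
+#     w = torch.randn(4, 8)
+#     scale = torch.full((4,), 0.1)
+#     zero_point = torch.zeros(4, dtype=torch.int64)
+#     w_dq = _quantize_and_dequantize_weight(
+#         w, torch.per_channel_affine, torch.qint8, scale, zero_point,
+#         weight_axis_int=0)
+#     assert w_dq.shape == w.shape and w_dq.dtype == torch.float32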
+
+def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis):
+    destination[prefix + "weight_qscheme"] = weight_qscheme
+    destination[prefix + "weight_dtype"] = weight_dtype
+    if weight_qscheme is not None:
+        destination[prefix + "weight_scale"] = weight_scale
+        destination[prefix + "weight_zero_point"] = weight_zero_point
+        if weight_qscheme == torch.per_channel_affine:
+            destination[prefix + "weight_axis"] = weight_axis
+
+def _get_weight_qparam_keys(
+        state_dict: typing.Dict[str, typing.Any],
+        prefix: str):
+    keys = ["weight_qscheme", "weight_dtype"]
+    weight_qscheme = state_dict[prefix + "weight_qscheme"]
+    if weight_qscheme is not None:
+        keys.append("weight_scale")
+        keys.append("weight_zero_point")
+        if weight_qscheme == torch.per_channel_affine:
+            keys.append("weight_axis")
+    return keys
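+
+# Illustrative note (not part of the library): for a per-channel affine qint8
+# weight, _save_weight_qparams stores entries under keys such as
+# "<prefix>weight_qscheme", "<prefix>weight_dtype", "<prefix>weight_scale",
+# "<prefix>weight_zero_point" and "<prefix>weight_axis", and
+# _get_weight_qparam_keys returns the corresponding key names (without the
+# prefix) so a caller can pop them back out of a state dict.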
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..308f4d5b55c44b6749a5ede08d5ccc40a09b7bca
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/__init__.py
@@ -0,0 +1 @@
+from . import quantized
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fac7e7d34ad9937514fe9bb19e71be2fc368f5b5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b030a9c22844690d61bb20af96bc26960968f01
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__init__.py
@@ -0,0 +1,10 @@
+from torch.ao.nn.sparse.quantized import dynamic
+
+from .linear import Linear
+from .linear import LinearPackedParams
+
+__all__ = [
+    "dynamic",
+    "Linear",
+    "LinearPackedParams",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4ae3e84e4fc8af7986698914eb78a02f79c12e2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..474eb8e50dd07082f1de8bbc2f7e0eb7506ab294
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fad593ee14add0029a6078f8dc3e1f58920e39c5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a961737c6c5944b99b5341c55d12fbd95c603e8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py
@@ -0,0 +1,5 @@
+from .linear import Linear
+
+__all__ = [
+    "Linear",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ecb1c9b4de541dd32269a3f2fd226afc829cf92
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3ed9c1b3d07b7223b7dd0f857cbec2ee7e03372
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..855c7ad391de4bf69b0246f58d44acefca2d064b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py
@@ -0,0 +1,139 @@
+from typing import Optional
+
+import torch
+import torch.ao.nn.intrinsic as nni
+
+from torch.ao.nn.sparse.quantized import linear
+from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern
+from torch.ao.nn.quantized.modules.utils import _quantize_weight, _hide_packed_params_repr
+
+__all__ = ['Linear']
+
+class Linear(torch.nn.Module):
+    r"""
+    A dynamically quantized sparse linear module with float tensor as inputs and outputs.
+    """
+    _version = 1
+    _op_type = "sparse_dynamic"
+    _FLOAT_MODULE = torch.nn.Linear
+
+    def __init__(self, in_features, out_features, row_block_size, col_block_size, bias=True, dtype=torch.qint8):
+        super().__init__()
+
+        if dtype != torch.qint8:
+            raise NotImplementedError("Only QINT8 is supported for Sparse Quantized Linear Dynamic")
+
+        self.in_features = in_features
+        self.out_features = out_features
+
+        if bias:
+            bias = torch.zeros(self.out_features, dtype=torch.float)
+        else:
+            bias = None
+
+        qweight = torch._empty_affine_quantized([out_features, in_features],
+                                                scale=1, zero_point=0, dtype=torch.qint8)
+        self._packed_params = linear.LinearPackedParams(row_block_size=row_block_size,
+                                                        col_block_size=col_block_size,
+                                                        dtype=dtype)
+        self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size)
+
+    def _get_name(self):
+        return 'SparseQuantizedDynamicLinear'
+
+    def extra_repr(self):
+        return f'in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}'
+
+    def __repr__(self):
+        return _hide_packed_params_repr(self, linear.LinearPackedParams)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'op_type'] = self._op_type
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        op_type = state_dict[prefix + 'op_type']
+        assert op_type == self._op_type, \
+            f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]"
+        state_dict.pop(prefix + 'op_type')
+
+        version = local_metadata.get('version', None)
+        assert version <= self._version
+
+        # Is this code valid? In the old quantization flow it seemed to be used
+        # to load older models
+        weight = state_dict.pop(prefix + 'weight')
+        bias = state_dict.pop(prefix + 'bias')
+        state_dict.update({prefix + '_packed_params.weight': weight,
+                           prefix + '_packed_params.bias': bias})
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def _weight_bias(self):
+        return self._packed_params._weight_bias()
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor],
+                        row_block_size: Optional[int], col_block_size: Optional[int]) -> None:
+        assert row_block_size is not None and col_block_size is not None
+        self.out_features = w.shape[0]
+        self.in_features = w.shape[1]
+        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a quantized sparse dynamic module from a float module.
+
+        We only care about the convert at this stage, no need for observers just yet.
+        """
+        assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        # TODO: Need to add options to qconfig to avoid the calibration.
+        # TODO: Add calibration for the sparsity
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        if type(mod) == nni.LinearReLU:
+            mod = mod[0]
+        if mod.qconfig is not None and mod.qconfig.weight is not None:
+            weight_observer = mod.qconfig.weight()
+        else:
+            # We have the circular import issues if we import the qconfig in the beginning of this file:
+            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+            # import until we need it.
+            from torch.ao.quantization.qconfig import default_dynamic_qconfig
+            weight_observer = default_dynamic_qconfig.weight()
+
+        # It is important to multiply by the mask BEFORE calling the `weight_observer`
+        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
+        weight = mod.weight
+        if getattr(mod.qconfig, 'mask', False):
+            weight = mod.qconfig.mask * mod.weight
+
+        weight_observer(weight)
+        dtype = weight_observer.dtype
+        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+        w_sc, w_zp = weight_observer.calculate_qparams()
+        if isinstance(w_zp, torch.Tensor):
+            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
+        else:
+            assert w_zp == 0, 'Weight zero point must map to 0'
+        qweight = _quantize_weight(weight.float(), weight_observer)
+
+        row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
+        qlinear = cls(mod.in_features,
+                      mod.out_features,
+                      row_block_size,
+                      col_block_size,
+                      dtype=dtype)
+        qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
+        return qlinear
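+
+# A minimal conversion sketch (illustrative only, not part of the library;
+# assumes a build where the torch.ops.sparse kernels are registered):
+#
+#     float_linear = torch.nn.Linear(16, 8)
+#     float_linear.qconfig = torch.ao.quantization.default_dynamic_qconfig
+#     with LinearBlockSparsePattern(row_block_size=1, col_block_size=4):
+#         sparse_qlinear = Linear.from_float(float_linear)
+#     y = sparse_qlinear(torch.randn(2, 16))  # float in, float out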
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9723e1760c9e88417928bcf545d5dccbad20e3a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/linear.py
@@ -0,0 +1,197 @@
+from typing import Optional
+
+import torch
+from torch.ao.nn.quantized.modules.utils import _quantize_weight, _hide_packed_params_repr
+
+__all__ = ['LinearPackedParams', 'Linear']
+
+# TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430)
+class LinearPackedParams(torch.nn.Module):
+    _version = 1
+
+    def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8):
+        super().__init__()
+
+        if dtype != torch.qint8:
+            raise NotImplementedError("Linear prepacking only supports QINT8")
+        self.dtype = dtype
+        wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
+        self.set_weight_bias(wq, None, row_block_size, col_block_size)
+
+    def _get_name(self):
+        return "SparseQuantizedLinearPackedParams"
+
+    @torch.jit.export
+    def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor],
+                        row_block_size: Optional[int], col_block_size: Optional[int]) -> None:
+        assert row_block_size is not None and col_block_size is not None
+        self._packed_params = torch.ops.sparse.qlinear_prepack(weight, bias, row_block_size, col_block_size)
+
+    @torch.jit.export
+    def _weight_bias(self):
+        (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack(self._packed_params)
+        return (weight, bias, block_sizes[0], block_sizes[1])
+
+    def forward(self, x):
+        return x
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'dtype'] = self.dtype
+        destination[prefix + '_packed_params'] = self._weight_bias()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        assert version <= self._version
+
+        self.dtype = state_dict.pop(prefix + 'dtype')
+        weight, bias, row_block_size, col_block_size = state_dict.pop(prefix + '_packed_params')
+        self.set_weight_bias(weight, bias, row_block_size, col_block_size)
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
+                                      missing_keys, unexpected_keys, error_msgs)
+
+    @torch.jit.export
+    def __getstate__(self):
+        return self._packed_params, self.training, self.dtype
+
+    @torch.jit.export
+    def __setstate__(self, state):
+        (self._packed_params, self.training, self.dtype) = state
+
+    def __repr__(self):
+        return self._weight_bias().__repr__()
+
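+# Illustrative sketch (not part of the library; assumes a build where the
+# torch.ops.sparse kernels are registered): pack a quantized weight and read
+# it back.
+#
+#     wq = torch.quantize_per_tensor(torch.randn(8, 16), scale=0.1,
+#                                    zero_point=0, dtype=torch.qint8)
+#     params = LinearPackedParams(row_block_size=1, col_block_size=4)
+#     params.set_weight_bias(wq, None, row_block_size=1, col_block_size=4)
+#     weight, bias, row_bs, col_bs = params._weight_bias()
+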
+# TODO (zaf): Inherit from `quantized.Linear` (T83294430)
+class Linear(torch.nn.Module):
+    r"""
+    A quantized sparse linear module with quantized tensor as inputs and outputs.
+    """
+    _version = 1
+    _FLOAT_MODULE = torch.nn.Linear
+
+    def __init__(self, in_features, out_features, row_block_size, col_block_size, bias=True, dtype=torch.qint8):
+        super().__init__()
+
+        if dtype != torch.qint8:
+            raise NotImplementedError("Only QINT8 is supported for Sparse Quantized Linear")
+
+        self.in_features = in_features
+        self.out_features = out_features
+
+        if bias:
+            bias = torch.zeros(self.out_features, dtype=torch.float)
+        else:
+            bias = None
+
+        qweight = torch._empty_affine_quantized([out_features, in_features],
+                                                scale=1, zero_point=0, dtype=torch.qint8)
+        self._packed_params = LinearPackedParams(row_block_size=row_block_size,
+                                                 col_block_size=col_block_size,
+                                                 dtype=dtype)
+        self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size)
+        self.scale = 1.0
+        self.zero_point = 0
+
+    @classmethod
+    def _get_name(cls):
+        return 'SparseQuantizedLinear'
+
+    def extra_repr(self):
+        return 'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'.format(
+            self.in_features, self.out_features, self.scale, self.zero_point, self.weight().qscheme()
+        )
+
+    def __repr__(self):
+        return _hide_packed_params_repr(self, LinearPackedParams)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.ops.sparse.qlinear(x, self._packed_params._packed_params, self.scale, self.zero_point)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'scale'] = torch.tensor(self.scale)
+        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        self.scale = float(state_dict[prefix + 'scale'])
+        state_dict.pop(prefix + 'scale')
+
+        self.zero_point = int(state_dict[prefix + 'zero_point'])
+        state_dict.pop(prefix + 'zero_point')
+
+        op_type = int(state_dict[prefix + 'op_type'])
+        state_dict.pop(prefix + 'op_type')
+
+        version = local_metadata.get('version', None)
+        assert version <= self._version
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, False,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def _weight_bias(self):
+        return self._packed_params._weight_bias()
+
+    def weight(self):
+        return self._weight_bias()[0]
+
+    def bias(self):
+        return self._weight_bias()[1]
+
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor],
+                        row_block_size: Optional[int], col_block_size: Optional[int]) -> None:
+        assert row_block_size is not None and col_block_size is not None
+        self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a quantized sparse module from a float module.
+
+        We only care about the convert at this stage, no need for observers just yet.
+
+        TODO(zaf): Need to add the sparse params to the qconfig
+        """
+        assert type(mod) == cls._FLOAT_MODULE, cls._get_name() + \
+            '.from_float only works for ' + cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'sparse_params'), \
+            ('Expecting the Linear to have `sparse_params`. Make sure you have provided arguments '
+             'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.')
+        sparse_block_shape = mod.sparse_params.get('sparse_block_shape', None)  # type: ignore[operator, union-attr]
+        assert isinstance(sparse_block_shape, (tuple, list))
+        assert len(sparse_block_shape) == 2
+        # TODO: Need to add options to qconfig to avoid the calibration.
+        # TODO: Add calibration for the sparsity
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        activation_post_process = mod.activation_post_process
+        weight_post_process = mod.qconfig.weight()  # type: ignore[operator, union-attr]
+
+        # Assumption is that the weight is already sparsified by the
+        # `sparsifier.convert`
+        weight = mod.weight
+
+        weight_post_process(weight)
+        dtype = weight_post_process.dtype
+        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[operator, union-attr]
+        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+        w_sc, w_zp = weight_post_process.calculate_qparams()
+        if isinstance(w_zp, torch.Tensor):
+            assert not torch.any(w_zp.bool()), "All weight zero points must map to 0"
+        else:
+            assert w_zp == 0, 'Weight zero point must map to 0'
+        qweight = _quantize_weight(weight.float(), weight_post_process)
+
+        row_block_size = mod.sparse_params['sparse_block_shape'][0]  # type: ignore[index]
+        col_block_size = mod.sparse_params['sparse_block_shape'][1]  # type: ignore[index]
+        qlinear = cls(mod.in_features,
+                      mod.out_features,
+                      row_block_size,
+                      col_block_size,
+                      dtype=dtype)
+        qlinear.set_weight_bias(qweight, mod.bias,
+                                row_block_size, col_block_size)  # type: ignore[arg-type]
+        qlinear.scale = float(act_scale)
+        qlinear.zero_point = int(act_zp)
+        return qlinear
diff --git a/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7338d26eac828b43b9198b5cadec1b3f5e386b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/nn/sparse/quantized/utils.py
@@ -0,0 +1,42 @@
+import threading
+
+__all__ = [
+    "LinearBlockSparsePattern"
+]
+
+def _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size):
+    return (row_block_size == 1 and col_block_size == 4) or \
+           (row_block_size == 8 and col_block_size == 1)
+
+# This is a stop-gap measure, as the current flow does not allow a
+# module-specific block sparse pattern.
+# In fact, there is no way to convey the sparse pattern via the module config
+# of the quantization flow, so the global context is used to convey the
+# sparsity pattern.
+# Once the flow supports it, this should be removed.
+class LinearBlockSparsePattern:
+    rlock = threading.RLock()
+    row_block_size = 1
+    col_block_size = 4
+    prev_row_block_size = 1
+    prev_col_block_size = 4
+
+    def __init__(self, row_block_size=1, col_block_size=4):
+        assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size)
+        LinearBlockSparsePattern.rlock.acquire()
+        LinearBlockSparsePattern.prev_row_block_size = LinearBlockSparsePattern.row_block_size
+        LinearBlockSparsePattern.prev_col_block_size = LinearBlockSparsePattern.col_block_size
+        LinearBlockSparsePattern.row_block_size = row_block_size
+        LinearBlockSparsePattern.col_block_size = col_block_size
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, backtrace):
+        LinearBlockSparsePattern.row_block_size = LinearBlockSparsePattern.prev_row_block_size
+        LinearBlockSparsePattern.col_block_size = LinearBlockSparsePattern.prev_col_block_size
+        LinearBlockSparsePattern.rlock.release()
+
+    @staticmethod
+    def block_size():
+        return LinearBlockSparsePattern.row_block_size, LinearBlockSparsePattern.col_block_size
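+
+# Illustrative usage (not part of the library): temporarily override the
+# global block sparse pattern while converting a module, then restore the
+# previous one on exit.
+#
+#     with LinearBlockSparsePattern(row_block_size=8, col_block_size=1):
+#         row_bs, col_bs = LinearBlockSparsePattern.block_size()  # (8, 1)
+#     # after the `with` block, the previous pattern (default 1x4) is restored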
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/__init__.py b/MLPY/Lib/site-packages/torch/ao/ns/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3a4c7c21ff04220eb10136a55f79a7967672f09
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97595300ceaaa0a90efe7d1be6283cee121d74c9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca71c88be3d1ef7cbc2c769ba7043f91a1155499
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite.py b/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite.py
new file mode 100644
index 0000000000000000000000000000000000000000..582708217a89b8c8b964bdef0d48382e5d7ef257
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite.py
@@ -0,0 +1,526 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+from torch.ao.quantization import prepare
+from typing import Dict, List, Optional, Any, Union, Callable, Set
+
+from torch.ao.quantization.quantization_mappings import (
+    get_default_compare_output_module_list,
+)
+
+NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
+    nnqd.Linear,
+    nnq.Linear,
+    nnqd.LSTM,
+    nn.LSTM,
+}
+
+
+def _find_match(
+    str_list: Union[Dict[str, Any], List[str]], key_str: str,
+    postfix: str,
+) -> Optional[str]:
+    split_str = key_str.split(".")
+    if split_str[-1] == postfix:
+        match_string = "".join(key_str.split(".")[0:-1])
+        for s2 in str_list:
+            pattern1 = "".join(s2.split(".")[0:-1])
+            pattern2 = "".join(s2.split(".")[0:-2])
+            if match_string == pattern1:
+                return s2
+            if match_string == pattern2:
+                return s2
+
+        # For matching "fc.weight" and "fc._packed_params._packed_params"
+        if postfix == "_packed_params":
+            match_string = "".join(key_str.split(".")[0:-2])
+            if len(match_string) == 0:
+                return None
+            for s2 in str_list:
+                pattern1 = "".join(s2.split(".")[0:-1])
+                pattern2 = "".join(s2.split(".")[0:-2])
+                if match_string == pattern1:
+                    return s2
+                if match_string == pattern2:
+                    return s2
+        return None
+    else:
+        return None
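+
+# Illustrative note (not part of the library): matching is done on the module
+# prefix rather than the full key, e.g. a quantized key "fc.weight" with
+# postfix "weight" matches the first float key whose stripped prefix is "fc",
+# and "fc._packed_params._packed_params" with postfix "_packed_params" is
+# likewise matched back to a float key under the "fc" module.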
+
+
+def compare_weights(
+    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare the weights of the float module with its corresponding quantized
+    module. Return a dict with key corresponding to module names and each entry being
+    a dictionary with two keys 'float' and 'quantized', containing the float and
+    quantized weights. This dict can be used to compare and compute the quantization
+    error of the weights of float and quantized models.
+
+    Example usage::
+
+        wt_compare_dict = compare_weights(
+            float_model.state_dict(), qmodel.state_dict())
+        for key in wt_compare_dict:
+            print(
+                key,
+                compute_error(
+                    wt_compare_dict[key]['float'],
+                    wt_compare_dict[key]['quantized'].dequantize()
+                )
+            )
+
+    Args:
+        float_dict: state dict of the float model
+        quantized_dict: state dict of the quantized model
+
+    Return:
+        weight_dict: dict with key corresponding to module names and each entry being
+        a dictionary with two keys 'float' and 'quantized', containing the float and
+        quantized weights
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
+    weight_dict: Dict[str, Dict] = {}
+    for key in quantized_dict:
+        match_key = _find_match(float_dict, key, "weight")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key]
+            continue
+
+        # For matching "fc.weight" and "fc._packed_params._packed_params"
+        match_key = _find_match(float_dict, key, "_packed_params")
+        if match_key is not None:
+            weight_dict[key] = {}
+            weight_dict[key]["float"] = float_dict[match_key]
+            weight_dict[key]["quantized"] = quantized_dict[key][0]
+
+        # For LSTM
+        split_str = key.split(".")
+        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
+            layer = split_str[-2]
+            module_name = ".".join(split_str[:-3])
+            float_weight_ih_key = module_name + ".weight_ih_l" + layer
+            float_weight_hh_key = module_name + ".weight_hh_l" + layer
+            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
+                weight_dict[key] = {}
+                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
+                )
+                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
+                weight_dict[key]["quantized"] = (
+                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
+                )
+
+    return weight_dict
+
+
+def _get_logger_dict_helper(
+    mod: nn.Module, target_dict: Dict[str, Any],
+    prefix: str = "",
+) -> None:
+    r"""This is the helper function for get_logger_dict
+
+    Args:
+        mod: module we want to save all logger stats
+        prefix: prefix for the current module
+        target_dict: the dictionary used to save all logger stats
+    """
+
+    def get_prefix(prefix):
+        return prefix if prefix == "" else prefix + "."
+
+    for name, child in mod.named_children():
+        if isinstance(child, Logger):
+            target_dict[get_prefix(prefix) + "stats"] = child.stats
+            break
+
+    for name, child in mod.named_children():
+        module_prefix = get_prefix(prefix) + name if prefix else name
+        _get_logger_dict_helper(child, target_dict, module_prefix)
+
+
+def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
+    r"""Traverse the modules and save all logger stats into target dict.
+    This is mainly used for quantization accuracy debug.
+
+    Type of loggers supported:
+        ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
+        OutputLogger: used to log the outputs of the modules
+
+    Args:
+        mod: module we want to save all logger stats
+        prefix: prefix for the current module
+
+    Return:
+        target_dict: the dictionary used to save all logger stats
+
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
+
+    target_dict: Dict[str, Dict] = {}
+    _get_logger_dict_helper(mod, target_dict, prefix)
+    return target_dict
+
+
+class Logger(nn.Module):
+    r"""Base class for stats logging
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.stats = {}
+        # We only insert observer if the op is quantized with static quantization,
+        # which is identified by activation_observer.dtype == quint8.  This is needed
+        # when attaching Logger as observer for FX mode
+        self.dtype = torch.quint8
+
+    def forward(self, x):
+        """
+        """  # blank docblock to make autodoc happy
+        pass
+
+
+class ShadowLogger(Logger):
+    r"""Class used in Shadow module to record the outputs of the original and
+    shadow modules.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.stats["float"] = []
+        self.stats["quantized"] = []
+
+    def forward(self, x, y):
+        """
+        """  # blank docblock to make autodoc happy
+        if len(x) > 1:
+            x = x[0]
+        if len(y) > 1:
+            y = y[0]
+        self.stats["quantized"].append(x.detach())
+        self.stats["float"].append(y.detach())
+
+
+class OutputLogger(Logger):
+    r"""Class used to log the outputs of the module
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.stats["tensor_val"] = []
+
+
+    def forward(self, x):
+        """
+        """  # blank docblock to make autodoc happy
+        self.stats["tensor_val"].append(x)
+        return x
+
+
+def _convert_tuple_to_list(t: Any) -> Any:
+    return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
+
+
+def _dequantize_tensor_list(t: Any) -> Any:
+    return (
+        [_dequantize_tensor_list(x) for x in t]
+        if type(t) is list
+        else t.dequantize()
+        if t.is_quantized
+        else t
+    )
+
+
+class Shadow(nn.Module):
+    r"""Shadow module attaches the float module to its matching quantized module
+    as the shadow. Then it uses Logger module to process the outputs of both
+    modules.
+
+    Args:
+        q_module: module quantized from float_module that we want to shadow
+        float_module: float module used to shadow q_module
+        logger_cls: type of logger used to process the outputs of q_module and
+            float_module. ShadowLogger or custom loggers can be used.
+    """
+
+    def __init__(self, q_module, float_module, logger_cls):
+        super().__init__()
+        self.orig_module = q_module
+        self.shadow_module = float_module
+        self.dequant = nnq.DeQuantize()
+        self.logger = logger_cls()
+
+    def forward(self, *x) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        xl = _convert_tuple_to_list(x)
+        output = self.orig_module(*xl)
+        xl_float = _dequantize_tensor_list(xl)
+        shadow_output = self.shadow_module(*xl_float)
+        self.logger(output, shadow_output)
+        return output
+
+    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.add(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.add(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.add_scalar(x, y)
+        x = x.dequantize()
+        shadow_output = self.shadow_module.add_scalar(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.mul(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.mul(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.mul_scalar(x, y)
+        x = x.dequantize()
+        shadow_output = self.shadow_module.mul_scalar(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.cat(x, dim)
+        x = [y.dequantize() for y in x]
+        shadow_output = self.shadow_module.cat(x, dim)
+        self.logger(output, shadow_output)
+        return output
+
+    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        """  # blank docblock to make autodoc happy
+        output = self.orig_module.add_relu(x, y)
+        x = x.dequantize()
+        y = y.dequantize()
+        shadow_output = self.shadow_module.add_relu(x, y)
+        self.logger(output, shadow_output)
+        return output
+
+
+def prepare_model_with_stubs(
+    float_module: nn.Module, q_module: nn.Module,
+    module_swap_list: Set[type], logger_cls: Callable,
+) -> None:
+    r"""Prepare the model by attaching the float module to its matching quantized
+    module as the shadow if the float module type is in module_swap_list.
+
+    Example usage::
+
+        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
+        q_model(data)
+        ob_dict = get_logger_dict(q_model)
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        module_swap_list: list of float module types to attach the shadow
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")
+
+    float_module_children = {}
+    for name, mod in float_module.named_children():
+        float_module_children[name] = mod
+
+    reassign = {}
+    for name, mod in q_module.named_children():
+
+        if name not in float_module_children:
+            continue
+
+        float_mod = float_module_children[name]
+
+        if type(float_mod) not in module_swap_list:
+            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
+
+        # Insert shadow module only if the module is not of the same type as
+        # the floating point module
+        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
+            reassign[name] = Shadow(mod, float_mod, logger_cls)
+
+    for key, value in reassign.items():
+        q_module._modules[key] = value
+
+def _is_identical_module_type(mod1, mod2):
+    # Compare whether two modules contain the same sequence of submodule types
+    mod1_module_types = [type(mod) for mod in mod1.modules()]
+    mod2_module_types = [type(mod) for mod in mod2.modules()]
+    return mod1_module_types == mod2_module_types
+
+
+
+def compare_model_stub(
+    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
+    *data, logger_cls=ShadowLogger
+) -> Dict[str, Dict]:
+    r"""Compare quantized module in a model with its floating point counterpart,
+    feeding both of them the same input. Return a dict with key corresponding to
+    module names and each entry being a dictionary with two keys 'float' and
+    'quantized', containing the output tensors of quantized and its matching
+    float shadow module. This dict can be used to compare and compute the module
+    level quantization error.
+
+    This function first calls prepare_model_with_stubs() to swap the quantized
+    modules that we want to compare with Shadow modules. A Shadow module takes a
+    quantized module, its corresponding float module and a logger as input, and
+    creates an internal forward path so that the float module shadows the
+    quantized module while sharing the same input. The logger is customizable;
+    the default logger is ShadowLogger, which saves the outputs of the quantized
+    module and the float module so they can be used to compute the module-level
+    quantization error.
+
+    Example usage::
+
+        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
+        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
+        for key in ob_dict:
+            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        module_swap_list: list of float module types at which shadow modules will
+            be attached.
+        data: input data used to run the prepared q_model
+        logger_cls: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
+    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
+    q_model(*data)
+    ob_dict = get_logger_dict(q_model)
+    return ob_dict
+
+
+def get_matching_activations(
+    float_module: nn.Module, q_module: nn.Module,
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Find the matching activation between float and quantized modules.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+
+    Return:
+        act_dict: dict with key corresponding to quantized module names and each
+        entry being a dictionary with two keys 'float' and 'quantized', containing
+        the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
+    float_dict = get_logger_dict(float_module)
+    quantized_dict = get_logger_dict(q_module)
+    act_dict: Dict[str, Dict] = {}
+    for key in quantized_dict:
+        if len(quantized_dict[key]["tensor_val"]) == 0:
+            continue
+        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
+        if match_key is not None:
+            act_dict[key] = {}
+            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
+            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
+    return act_dict
+
+
+def prepare_model_outputs(
+    float_module: nn.Module,
+    q_module: nn.Module,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> None:
+    r"""Prepare the model by attaching the logger to both float module
+    and quantized module if they are in the allow_list.
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+
+    qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
+    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
+    prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})
+    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
+    prepare(
+        q_module,
+        inplace=True,
+        allow_list=allow_list,
+        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+        prepare_custom_config_dict={}
+    )
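+
+# Illustrative flow (not part of the library): this is essentially what
+# compare_model_outputs() below does under the hood.
+#
+#     prepare_model_outputs(float_model, q_model)
+#     float_model(*data)
+#     q_model(*data)
+#     act_compare_dict = get_matching_activations(float_model, q_model)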
+
+
+def compare_model_outputs(
+    float_model: nn.Module,
+    q_model: nn.Module,
+    *data,
+    logger_cls=OutputLogger,
+    allow_list=None
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    r"""Compare output activations between float and quantized models at
+    corresponding locations for the same input. Return a dict with key corresponding
+    to quantized module names and each entry being a dictionary with two keys
+    'float' and 'quantized', containing the activations of quantized model and
+    float model at matching locations. This dict can be used to compare and
+    compute the propagation quantization error.
+
+    Example usage::
+
+        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
+        for key in act_compare_dict:
+            print(
+                key,
+                compute_error(
+                    act_compare_dict[key]['float'],
+                    act_compare_dict[key]['quantized'].dequantize()
+                )
+            )
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        data: input data used to run the prepared float_model and q_model
+        logger_cls: type of logger to be attached to float_module and q_module
+        allow_list: list of module types to attach logger
+
+    Return:
+        act_compare_dict: dict with key corresponding to quantized module names
+        and each entry being a dictionary with two keys 'float' and 'quantized',
+        containing the matching float and quantized activations
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
+    if allow_list is None:
+        allow_list = get_default_compare_output_module_list()
+    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
+    float_model(*data)
+    q_model(*data)
+    act_compare_dict = get_matching_activations(float_model, q_model)
+    return act_compare_dict
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py b/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce1e5e574b66281024cd64bac940fcf8a7fa8bc3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/_numeric_suite_fx.py
@@ -0,0 +1,1025 @@
+"""
+This module contains tooling to compare weights and activations
+across models. Example usage::
+
+    import copy
+    import torch
+    import torch.ao.quantization.quantize_fx as quantize_fx
+    import torch.ao.ns._numeric_suite_fx as ns
+
+    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
+    mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
+    # We convert a copy because we need the original prepared model
+    # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
+    mq = quantize_fx.convert_fx(copy.deepcopy(mp))
+
+    #
+    # Comparing weights
+    #
+
+    # extract weight pairs
+    weight_comparison = ns.extract_weights('a', mp, 'b', mq)
+
+    # add SQNR for each comparison, inplace
+    ns.extend_logger_results_with_comparison(
+        weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
+        'sqnr')
+
+    # weight_comparison contains the weights from `mp` and `mq` stored
+    # in pairs, and can be used for further analysis.
+
+
+    #
+    # Comparing activations, with error propagation
+    #
+
+    # add loggers
+    mp_ns, mq_ns = ns.add_loggers(
+        'a', copy.deepcopy(mp),
+        'b', copy.deepcopy(mq),
+        ns.OutputLogger)
+
+    # send an example datum to capture intermediate activations
+    datum = torch.randn(1, 1, 1, 1)
+    mp_ns(datum)
+    mq_ns(datum)
+
+    # extract intermediate activations
+    act_comparison = ns.extract_logger_info(
+        mp_ns, mq_ns, ns.OutputLogger, 'b')
+
+    # add SQNR for each comparison, inplace
+    ns.extend_logger_results_with_comparison(
+        act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
+        'sqnr')
+
+    # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
+    # in pairs, and can be used for further analysis.
+
+    #
+    # Comparing activations, without error propagation
+    #
+
+    # create shadow model
+    mp_shadows_mq = ns.add_shadow_loggers(
+        'a', copy.deepcopy(mp),
+        'b', copy.deepcopy(mq),
+        ns.OutputLogger)
+
+    # send an example datum to capture intermediate activations
+    datum = torch.randn(1, 1, 1, 1)
+    mp_shadows_mq(datum)
+
+    # extract intermediate activations
+    shadow_act_comparison = ns.extract_shadow_logger_info(
+        mp_shadows_mq, ns.OutputLogger, 'b')
+
+    # add SQNR for each comparison, inplace
+    ns.extend_logger_results_with_comparison(
+        shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
+        'sqnr')
+
+    # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored
+    # in pairs, and can be used for further analysis.
+
+"""
+
+import collections
+
+import torch
+import torch.nn as nn
+import torch.ao.quantization.quantize_fx as quantize_fx
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+from torch.ao.ns.fx.mappings import (
+    get_base_name_to_sets_of_related_ops,
+)
+from torch.ao.ns.fx.graph_matcher import (
+    get_matching_subgraph_pairs,
+    get_type_a_related_to_b,
+)
+
+from .fx.weight_utils import (
+    extract_weight_from_node,
+)
+
+from .fx.graph_passes import (
+    add_loggers_to_model,
+    create_a_shadows_b,
+)
+
+from .fx.utils import (
+    rekey_logger_info_on_node_name_of_model,
+    maybe_add_missing_fqns,
+    get_target_type_str,
+)
+
+from .fx.ns_types import (
+    NSSingleResultValuesType,
+    NSResultsType,
+    NSNodeTargetType,
+)
+from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.fx.match_utils import _find_matches
+from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
+from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization import QConfigMapping
+from torch.ao.ns.fx.n_shadows_utils import (
+    OutputProp,
+    _get_dedup_subgraphs,
+    SHADOW_WRAPPER_NODE_NAME_PREFIX,
+    group_results_by_subgraph,
+    create_results_comparison,
+    print_n_shadows_summary,
+    create_n_transformed_and_logged_copies_of_subgraph,
+    create_add_loggers_graph,
+    extract_weight_comparison,
+)
+from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
+
+from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
+
+RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+
+class OutputLogger(nn.Module):
+    """
+    Base class for capturing intermediate values.
+    """
+    stats: List[torch.Tensor]
+    stats_rnn: List[RNNReturnType]
+
+    # Mark as impure so that calls to it will not be removed during DCE.
+    _is_impure = True
+
+    def __init__(
+        self,
+        ref_node_name: str,
+        prev_node_name: str,
+        model_name: str,
+        ref_name: str,
+        prev_node_target_type: str,
+        ref_node_target_type: str,
+        results_type: str,
+        index_within_arg: int,
+        index_of_arg: int,
+        fqn: Optional[str],
+        qconfig_str: Optional[str] = '',
+    ):
+        super().__init__()
+        self.stats: List[torch.Tensor] = []
+        self.stats_rnn: List[RNNReturnType] = []
+
+        # name of the node which was responsible for adding this logger
+        # Note:
+        # - if we are logging node outputs, this is the same as prev_node_name
+        # - if we are logging node inputs, this is the name of the node
+        #   whose input this logger is logging.
+        #
+        # example, where logger1 is logging input of op1 and logger2 is logging
+        #    the output of op1:
+        #
+        #  x1 -> logger1 -> op1 -> logger2 -> x2
+        #
+        # in this example,
+        #   - logger1's prev_node_name is x1 and ref_node_name is op1
+        #   - logger2's prev_node_name is op1 and ref_node_name is op1
+        self.ref_node_name = ref_node_name
+        # name of the node whose output this Logger is capturing
+        self.prev_node_name = prev_node_name
+
+        # name of the model from which the node originated from
+        self.model_name = model_name
+        # reference name, used to match loggers from separate models
+        # to each other
+        self.ref_name = ref_name
+        # type of the target of the node whose output this logger is logging
+        self.prev_node_target_type = prev_node_target_type
+        # type of the target of the node which was responsible for adding this
+        # logger
+        self.ref_node_target_type = ref_node_target_type
+        # what kind of values are inside of stats
+        self.results_type = results_type
+        # index of this node within the arg of the input/output node
+        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
+        self.index_within_arg = index_within_arg
+        # index of this node within the args of the input/output node
+        # for example, in add(x1, x2), x2 would have index_of_arg == 1
+        self.index_of_arg = index_of_arg
+        # fully qualified name
+        self.fqn = fqn
+        # if loggers are added before prepare_fx, we may not want to
+        # collect results during calibration, only results after convert_fx,
+        # so we add a flag to control whether this logger collects data
+        self.enabled = True
+        # string representation of qconfig
+        self.qconfig_str = qconfig_str
+        # this can be turned off to reduce memory usage during calibration
+        self.save_activations = True
+
+    # Note: cannot annotate the type of x because TorchScript does not support
+    #   the Union type.
+    def forward(self, x):
+        """
+        """  # blank docblock to make autodoc happy
+        # TODO(future PR): consider designing this better, as the difference
+        # between these two flags is subtle and not obvious.
+        if not self.enabled:
+            return x
+        if not self.save_activations:
+            return x
+        # TODO(future PR): consider refactoring this to better reuse the parent
+        # class
+        if isinstance(x, torch.Tensor):
+            self.stats.append(x.detach())
+        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
+            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
+            self.stats_rnn.append(new_res)
+        return x
+
+    def __repr__(self):
+        clean_dict = {
+            k: v
+            for k, v in self.__dict__.items()
+            # skip nn.Module keys
+            if (k != 'training') and not k.startswith('_')
+        }
+        return f"OutputLogger({clean_dict})"
+
+
+class OutputComparisonLogger(OutputLogger):
+    """
+    Same as OutputLogger, but also requires the original activation
+    in order to calculate the comparison at calibration time
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # TODO(future PR): make the comparison function configurable
+        self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
+        self.comparison_fn_name = 'sqnr'
+        # precalculated comparisons of logger output versus reference
+        self.comparisons = []
+
+    def forward(self, x, x_ref):
+        """
+        """  # blank docblock to make autodoc happy
+        if not self.enabled:
+            return x
+        assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported'
+        if self.save_activations:
+            # save the activation, for debugging
+            self.stats.append(x.detach())
+        # save the comparison
+        self.comparisons.append(self.comparison_fn(x, x_ref))
+        return x
+
+    def __repr__(self):
+        clean_dict = {
+            k: v
+            for k, v in self.__dict__.items()
+            # skip nn.Module keys
+            if (k != 'training') and not k.startswith('_')
+        }
+        return f"OutputComparisonLogger({clean_dict})"
+
+
+class NSTracer(quantize_fx.QuantizationTracer):
+    """
+    Just like a regular FX quantization tracer, but treats observers and fake_quantize
+    modules as leaf modules.
+    """
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
+        """
+        """  # blank docblock to make autodoc happy
+        if isinstance(m, torch.ao.quantization.ObserverBase):
+            return True
+        elif isinstance(m, torch.ao.quantization.FakeQuantizeBase):
+            return True
+        return super().is_leaf_module(m, module_qualified_name)
+
+
+def _extract_weights_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument: List[Tuple[Node, str]],
+    results: NSResultsType,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> None:
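+    # Populates `results` in place. Approximate layout of the nested dict built
+    # below (illustrative sketch; the concrete results_type string comes from
+    # NSSingleResultValuesType.WEIGHT.value):
+    #
+    #   results[ref_name][results_type][model_name] = [extracted_weight]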
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
+    for node, ref_name in nodes_and_names_to_instrument:
+        res_type = NSSingleResultValuesType.WEIGHT.value
+        extracted_weight = extract_weight_from_node(
+            node, model, op_to_type_to_weight_extraction_fn)
+        if extracted_weight:
+            if ref_name not in results:
+                results[ref_name] = {res_type: {}}
+            results[ref_name][res_type][model_name] = [extracted_weight]
+
+
+def _extract_weights_impl(
+    model_name_a: str,
+    gm_a: GraphModule,
+    model_name_b: str,
+    gm_b: GraphModule,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+
+    # split the subgraph pairs into one data structure for each model
+    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
+    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
+    for match_name, match in matched_subgraph_pairs.items():
+        subgraph_a, subgraph_b = match
+        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
+        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
+
+    # populate the results, one model at a time
+    results: NSResultsType = {}
+    _extract_weights_one_model(
+        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
+        op_to_type_to_weight_extraction_fn)
+    _extract_weights_one_model(
+        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
+        op_to_type_to_weight_extraction_fn)
+
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+
+    # rekey on names of nodes in gm_b
+    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
+
+    return results
+
+
+def extract_weights(
+    model_name_a: str,
+    model_a: nn.Module,
+    model_name_b: str,
+    model_b: nn.Module,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> NSResultsType:
+    """
+    Extract weights from model A and model B, and return a comparison.
+
+    Args:
+        model_name_a: string name of model A to use in results
+        model_a: model A
+        model_name_b: string name of model B to use in results
+        model_b: model B
+        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
+        unmatchable_types_map: optional override of unmatchable types, subject to change
+        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
+            from a type, subject to change
+
+    Return:
+        NSResultsType, containing the weight comparisons
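+
+    Example (illustrative sketch, not part of the original docs; ``m`` is a
+    hypothetical float model and ``mq`` its hypothetical quantized copy)::
+
+        results = extract_weights('fp32', m, 'int8', mq)
+        extend_logger_results_with_comparison(
+            results, 'fp32', 'int8', torch.ao.ns.fx.utils.compute_sqnr, 'sqnr')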
+    """
+
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
+    if base_name_to_sets_of_related_ops is None:
+        base_name_to_sets_of_related_ops = \
+            get_base_name_to_sets_of_related_ops()
+    type_a_related_to_b = \
+        get_type_a_related_to_b(base_name_to_sets_of_related_ops)
+
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
+    if maybe_model_a_node_name_to_scope is not None:
+        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
+    if maybe_model_b_node_name_to_scope is not None:
+        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
+    return _extract_weights_impl(
+        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
+
+
+def _add_loggers_one_model(
+    model_name: str,
+    model: GraphModule,
+    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
+    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
+    logger_cls: Callable,
+) -> nn.Module:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
+
+    # TODO(future PR): do not observe nodes we do not care
+    #   about (both fp32, denylist, etc)
+    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
+    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
+    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
+        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
+    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
+        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
+
+    model = add_loggers_to_model(
+        model, node_to_instrument_inputs_to_ref_name,
+        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
+    return model
+
+
+def _add_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b,
+        base_name_to_sets_of_related_ops, unmatchable_types_map)
+    nodes_and_names_to_instrument_inputs_a = []
+    nodes_and_names_to_instrument_inputs_b = []
+    nodes_and_names_to_instrument_outputs_a = []
+    nodes_and_names_to_instrument_outputs_b = []
+    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
+        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
+        # Note: for matching inputs we use start_node, such as observing
+        # the input of linear in linear-relu
+        if should_log_inputs:
+            nodes_and_names_to_instrument_inputs_a.append(
+                (subgraph_a.start_node, match_name, ref_node_type_a))
+            nodes_and_names_to_instrument_inputs_b.append(
+                (subgraph_b.start_node, match_name, ref_node_type_b))
+        # Note: for matching activations we always use end_node,
+        # such as observing the output of relu in linear-relu
+        nodes_and_names_to_instrument_outputs_a.append(
+            (subgraph_a.end_node, match_name, ref_node_type_a))
+        nodes_and_names_to_instrument_outputs_b.append(
+            (subgraph_b.end_node, match_name, ref_node_type_b))
+
+    new_model_a = _add_loggers_one_model(
+        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
+        nodes_and_names_to_instrument_outputs_a, logger_cls)
+    new_model_b = _add_loggers_one_model(
+        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
+        nodes_and_names_to_instrument_outputs_b, logger_cls)
+    return (new_model_a, new_model_b)
+
+
+def add_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs : bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Tuple[nn.Module, nn.Module]:
+    """
+    Instrument model A and model B with loggers.
+
+    Args:
+        name_a: string name of model A to use in results
+        model_a: model A
+        name_b: string name of model B to use in results
+        model_b: model B
+        logger_cls: class of Logger to use
+        should_log_inputs: whether to log inputs in addition to outputs
+        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
+        unmatchable_types_map: optional override of unmatchable types, subject to change
+
+    Return:
+        Returns a tuple of (model_a_with_loggers, model_b_with_loggers).  Modifies both models in place.
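+
+    Example (illustrative sketch, not part of the original docs; ``m``, ``mq``
+    and ``datum`` are hypothetical)::
+
+        m_ns, mq_ns = add_loggers('fp32', m, 'int8', mq, OutputLogger)
+        m_ns(datum)
+        mq_ns(datum)
+        results = extract_logger_info(m_ns, mq_ns, OutputLogger, 'int8')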
+    """
+
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
+    if maybe_model_a_node_name_to_scope is not None:
+        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
+    if maybe_model_b_node_name_to_scope is not None:
+        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
+    return _add_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def _extract_logger_info_one_model(
+    model: nn.Module,
+    results: NSResultsType,
+    logger_cls: Callable,
+) -> None:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
+    for gm_name, mod in model.named_modules():
+        # TODO(future PR): better check when scripted
+        is_logger = (
+            isinstance(mod, logger_cls)  # type: ignore[arg-type]
+            or (
+                isinstance(mod, torch.jit.RecursiveScriptModule)
+                and mod.original_name == 'OutputLogger'
+            )
+        )
+        if is_logger:
+            key = mod.ref_name
+            if key not in results:
+                results[key] = {}
+            assert mod.model_name not in results[key], \
+                f"{mod.model_name} is already present in results"
+            if mod.results_type not in results[key]:
+                results[key][mod.results_type] = {}
+            if mod.model_name not in results[key][mod.results_type]:
+                results[key][mod.results_type][mod.model_name] = []
+            stats_to_use = mod.stats
+            if len(mod.stats_rnn) > 0:
+                stats_to_use = mod.stats_rnn
+            data = {
+                'type': mod.results_type,
+                'values': stats_to_use,
+                'ref_node_name': mod.ref_node_name,
+                'ref_node_target_type': mod.ref_node_target_type,
+                'prev_node_name': mod.prev_node_name,
+                'prev_node_target_type': mod.prev_node_target_type,
+                'index_within_arg': mod.index_within_arg,
+                'index_of_arg': mod.index_of_arg,
+                'fqn': mod.fqn,
+                'qconfig_str': mod.qconfig_str,
+            }
+            if hasattr(mod, 'comparisons'):
+                data['comparisons'] = mod.comparisons
+                data['comparison_fn_name'] = mod.comparison_fn_name
+            else:
+                data['comparisons'] = []
+                data['comparison_fn_name'] = ''
+            results[key][mod.results_type][mod.model_name].append(data)
+            # ensure the list stays sorted
+            results[key][mod.results_type][mod.model_name].sort(
+                key=lambda res:
+                f"{res['index_of_arg']}:{res['index_within_arg']}"
+            )
+
+
+# TODO(future PR): align on naming
+# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
+def extract_logger_info(
+    model_a: nn.Module,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Traverse all loggers in `model_a` and `model_b`, and extract the logged
+    information.
+
+    Args:
+        model_a: model A
+        model_b: model B
+        logger_cls: class of Logger to use
+        model_name_to_use_for_layer_names: string name of model to use for
+          layer names in the output
+
+    Return:
+        NSResultsType, containing the logged comparisons
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
+    results: NSResultsType = {}
+    for model in (model_a, model_b):
+        _extract_logger_info_one_model(model, results, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+    # rekey on the name of model b
+    results = rekey_logger_info_on_node_name_of_model(
+        results, model_name_to_use_for_layer_names)
+    return results
+
+
+def _add_shadow_loggers_impl(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
+    matched_subgraph_pairs = get_matching_subgraph_pairs(
+        gm_a, gm_b, base_name_to_sets_of_related_ops,
+        unmatchable_types_map)
+    gm_a_shadows_b = create_a_shadows_b(
+        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
+        should_log_inputs=should_log_inputs,
+        node_type_to_io_type_map=node_type_to_io_type_map)
+    return gm_a_shadows_b
+
+
+def add_shadow_loggers(
+    name_a: str,
+    model_a: nn.Module,
+    name_b: str,
+    model_b: nn.Module,
+    logger_cls: Callable,
+    should_log_inputs: bool = False,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> nn.Module:
+    """
+    Instrument model A and model B with shadow loggers.
+
+    Args:
+        name_a: string name of model A to use in results
+        model_a: model A
+        name_b: string name of model B to use in results
+        model_b: model B
+        logger_cls: class of Logger to use
+        should_log_inputs: whether to log inputs
+        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
+        node_type_to_io_type_map: optional override of node type to IO dtype mapping, subject to change
+        unmatchable_types_map: optional override of unmatchable types, subject to change
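+
+    Example (illustrative sketch, not part of the original docs; ``m``, ``mq``
+    and ``datum`` are hypothetical)::
+
+        shadow_model = add_shadow_loggers(
+            'fp32', m, 'int8', mq, OutputComparisonLogger)
+        shadow_model(datum)
+        results = extract_shadow_logger_info(
+            shadow_model, OutputComparisonLogger, 'int8')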
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
+    # TODO(future PR): expose these
+    skipped_module_names: List[str] = []
+    skipped_module_classes: List[Callable] = []
+    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
+    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
+    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
+    maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
+    if maybe_model_a_node_name_to_scope is not None:
+        gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
+    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
+    maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
+    if maybe_model_b_node_name_to_scope is not None:
+        gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
+    return _add_shadow_loggers_impl(
+        name_a, gm_a, name_b, gm_b, logger_cls,
+        should_log_inputs=should_log_inputs,
+        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+        node_type_to_io_type_map=node_type_to_io_type_map,
+        unmatchable_types_map=unmatchable_types_map)
+
+
+def extract_shadow_logger_info(
+    model_a_shadows_b: nn.Module,
+    logger_cls: Callable,
+    model_name_to_use_for_layer_names: str,
+) -> NSResultsType:
+    """
+    Traverse all loggers in a shadow model, and extract the logged
+    information.
+
+    Args:
+        model_a_shadows_b: shadow model
+        logger_cls: class of Logger to use
+        model_name_to_use_for_layer_names: string name of model to use for
+          layer names in the output
+
+    Return:
+        NSResultsType, containing the logged comparisons
+    """
+    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
+    results: NSResultsType = collections.defaultdict(dict)
+    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
+    # fill in missing fqn entries
+    maybe_add_missing_fqns(results)
+    # rekey on the name of model b
+    results = rekey_logger_info_on_node_name_of_model(
+        results, model_name_to_use_for_layer_names)
+    return dict(results)
+
+
+def extend_logger_results_with_comparison(
+    results: NSResultsType,
+    model_name_1: str,
+    model_name_2: str,
+    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    comparison_name: str,
+) -> None:
+    """
+    Compares the logged values from `model_name_2` against the corresponding
+    values in `model_name_1`, using `comparison_fn`. Records the result
+    in `model_name_2`'s results under `comparison_name`. Modifies `results` in place.
+
+    Args:
+        results: the result data structure from `extract_logger_info` or
+          `extract_shadow_logger_info`.
+        model_name_1: string name of model 1
+        model_name_2: string name of model 2
+        comparison_fn: function to compare two Tensors
+        comparison_name: string name to use for the comparison in the
+          output results
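+
+    Example (illustrative sketch; assumes ``results`` contains entries for
+    models named ``'fp32'`` and ``'int8'``)::
+
+        extend_logger_results_with_comparison(
+            results, 'fp32', 'int8', torch.ao.ns.fx.utils.compute_sqnr, 'sqnr')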
+    """
+    for results_type_to_results in results.values():
+        for model_name_to_results in results_type_to_results.values():
+            assert model_name_1 in model_name_to_results, \
+                f"{model_name_1} not found in results"
+            assert model_name_2 in model_name_to_results, \
+                f"{model_name_2} not found in results"
+
+            results_1 = model_name_to_results[model_name_1]
+            results_2 = model_name_to_results[model_name_2]
+
+            for result_2 in results_2:
+                index_within_arg_2 = result_2['index_within_arg']
+                index_of_arg_2 = result_2['index_of_arg']
+                # find corresponding result_1
+                result_1 = None
+                for cur_result_1 in results_1:
+                    index_within_arg_1 = cur_result_1['index_within_arg']
+                    index_of_arg_1 = cur_result_1['index_of_arg']
+                    if (
+                        (index_within_arg_1 == index_within_arg_2) and
+                        (index_of_arg_1 == index_of_arg_2)
+                    ):
+                        result_1 = cur_result_1
+                        break
+                assert result_1 is not None
+
+                values_1 = result_1['values']
+                values_2 = result_2['values']
+                result_2[comparison_name] = []
+                for value_1, value_2 in zip(values_1, values_2):
+                    comparison_result = comparison_fn(value_1, value_2)
+                    result_2[comparison_name].append(comparison_result)
+
+def prepare_n_shadows_model(
+    model: torch.nn.Module,
+    example_inputs: Any,
+    qconfig_multi_mapping: QConfigMultiMapping,
+    backend_config: BackendConfig,
+    custom_prepare_fn: Optional[Callable] = None,
+    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
+    custom_tracer: Any = None,
+) -> GraphModule:
+    """
+    Given a model with a graph with M ops such as
+
+    .. code::
+
+      args_kwargs_m -> op_m -> output_m
+
+    and a set of N qconfigs for each op, creates a new model, with
+    the subgraph of each `op_m` transformed into
+
+    .. code::
+
+           |---------> op_m_n -> log_m_n
+           |                     /
+      args_kwargs_m ---------> op_m -> log_m_0
+
+    Where op_m_n is op_m wrapped in a submodule and transformed with
+    qconfig_n, and its inner graph looks like
+
+    .. code::
+
+      args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
+                  /
+      kwargs_m ---
+
+    This is useful for testing different quantization configurations for
+    multiple layers in a single pass through the model.
+
+    High level TODOs for future PRs:
+    * figure out a better way to name the output structure
+    * return a results data structure instead of printing it out
+    * add examples to docblocks
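+
+    Example (illustrative sketch, not part of the original docs; ``m``,
+    ``example_inputs``, ``qconfig_multi_mapping`` and ``backend_config``
+    are hypothetical)::
+
+        mp = prepare_n_shadows_model(
+            m, example_inputs, qconfig_multi_mapping, backend_config)
+        mp(*example_inputs)  # calibration
+        mq = convert_n_shadows_model(mp)
+        mq(*example_inputs)
+        results = extract_results_n_shadows_model(mq)
+        print_comparisons_n_shadows_model(results)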
+    """
+
+    if custom_tracer is None:
+        tracer = quantize_fx.QuantizationTracer([], [])
+    else:
+        tracer = custom_tracer
+    mt = torch.fx.GraphModule(model, tracer.trace(model))
+    # this is necessary to ensure logger FQNs get populated
+    mt._node_name_to_scope = tracer.node_name_to_scope
+
+    # run example input propagation, we need this to call prepare_fx on
+    # individual subgraphs
+    output_prop = OutputProp(mt)
+    output_prop.propagate(*example_inputs)
+
+    # Find the set of subgraphs in the original graph which we need to
+    # consider.
+    modules = dict(mt.named_modules(remove_duplicate=False))
+    patterns = _get_pattern_to_quantize_handlers(backend_config)
+    root_node_getter_mapping = \
+        get_fusion_pattern_to_root_node_getter(backend_config)
+    standalone_module_names: List[str] = []
+    standalone_module_classes: List[Type] = []
+    custom_module_classes: List[Type] = []
+    matches = _find_matches(
+        mt.graph, modules, patterns, root_node_getter_mapping,
+        standalone_module_names, standalone_module_classes, custom_module_classes)
+    subgraphs_dedup: Dict[str, List[Node]] = \
+        _get_dedup_subgraphs(matches)
+
+    # generate node to qconfig for each subgraph
+    # TODO(future PR): deduplicate repeating entries
+    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
+    for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
+        node_name_to_qconfig = _generate_node_name_to_qconfig(
+            mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
+        list_of_node_name_to_qconfig.append(node_name_to_qconfig)
+
+    # For each region in the model, do the following:
+    #   For each qconfig for that region, do the following:
+    #     1. create a copy of the region wrapped in a module
+    #     2. pass original args, original kwargs, and expected output to module
+    #     3. add an output comparison logger and hook it up to compare
+    #        actual output to expected output
+    #     4. run `prepare_fx` on the module
+    for (subgraph_idx, (match_name, nodes_in_this_subgraph)) in \
+            enumerate(subgraphs_dedup.items()):
+        create_n_transformed_and_logged_copies_of_subgraph(
+            mt, subgraph_idx, match_name, nodes_in_this_subgraph,
+            qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
+            custom_prepare_fn, custom_prepare_kwargs  # type: ignore[arg-type]
+        )
+
+    return mt
+
+# TODO(future PR): we should rethink the names of all the PNP APIs
+def _prepare_n_shadows_add_loggers_model(
+    model: torch.nn.Module,
+    example_inputs: Any,
+    qconfig_mapping: QConfigMapping,
+    backend_config: BackendConfig,
+) -> torch.nn.Module:
+    r"""
+    Note: this API is not recommended for wide usage; it is only
+    provided for customers who need to migrate from the `add_loggers`
+    API.
+
+    This creates a model which provides logging for the following
+    problem: if we quantize `model` with `qconfig_mapping` and feed
+    the same input through both models, we want to log the comparisons
+    of the corresponding intermediate layers.
+
+    The problem is solved with a single model.  Specifically, we
+    partition `model` into N subgraphs, create a copy of each relevant
+    subgraph, wrap it in a module, apply the quantization API to that
+    module, and hook up loggers to measure the comparisons.
+
+    Example starting graph:
+
+    .. code::
+
+      x0 -> op0 -> x1 -> op1 -> x2
+
+    Example config: quantize op0 to int8, do nothing to op1.
+    The following graph will be created:
+
+    .. code::
+
+      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
+       \                        \                           \       # noqa: W605
+         ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog
+
+    Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
+    to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
+    and clog is a comparison logger.
+    """
+
+    tracer = quantize_fx.QuantizationTracer([], [])
+    mt = torch.fx.GraphModule(model, tracer.trace(model))
+    # this is necessary to ensure logger FQNs get populated
+    mt._node_name_to_scope = tracer.node_name_to_scope
+
+    # run example input propagation, we need this to call prepare_fx on
+    # individual subgraphs
+    output_prop = OutputProp(mt)
+    output_prop.propagate(*example_inputs)
+
+    # Find the set of subgraphs in the original graph which we need to
+    # consider.
+    modules = dict(mt.named_modules(remove_duplicate=False))
+    patterns = _get_pattern_to_quantize_handlers(backend_config)
+    root_node_getter_mapping = \
+        get_fusion_pattern_to_root_node_getter(backend_config)
+    standalone_module_names: List[str] = []
+    standalone_module_classes: List[Type] = []
+    custom_module_classes: List[Type] = []
+    matches = _find_matches(
+        mt.graph, modules, patterns, root_node_getter_mapping,
+        standalone_module_names, standalone_module_classes, custom_module_classes)
+    subgraphs_dedup: Dict[str, List[Node]] = \
+        _get_dedup_subgraphs(matches)
+
+    # generate node to qconfig for each subgraph
+    node_name_to_qconfig = _generate_node_name_to_qconfig(
+        mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
+
+    # Now, mutate the graph to be the add_loggers graph with propagation
+    # error.
+    create_add_loggers_graph(
+        mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)
+
+    return mt
+
+# TODO(future PR): we should rethink the names of all the PNP APIs
+def _n_shadows_compare_weights(
+    model: torch.nn.Module,
+    example_inputs: Any,
+    qconfig_mapping: QConfigMapping,
+    backend_config: BackendConfig,
+) -> NSResultsType:
+    """
+    Note: this API is not recommended for wide usage; it is only
+    provided for customers who need to migrate from the `add_loggers`
+    API.
+    """
+    qconfig_multi_mapping = \
+        QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping])
+    mp = prepare_n_shadows_model(
+        model, example_inputs, qconfig_multi_mapping, backend_config)
+    # passing inputs through the model is necessary to populate
+    # observers which observe weights with real values
+    mp(*example_inputs)
+    mq = convert_n_shadows_model(mp)
+    weight_comparison = extract_weight_comparison(mq)
+    return weight_comparison
+
+# TODO(future PR): consider aligning API signature with other similar quantization
+# functions (enable_fake_quant, etc)
+def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
+    """
+    Sets the `enabled` setting on a `model`'s loggers
+    """
+    for name, child in model.named_modules():
+        if isinstance(child, OutputLogger):
+            child.enabled = enabled
+
+# TODO(future PR): consider aligning API signature with other similar quantization
+# functions (enable_fake_quant, etc)
+def loggers_set_save_activations(
+    model: torch.nn.Module,
+    save_activations: bool,
+) -> None:
+    """
+    Sets the `save_activations` setting on a `model`'s loggers
+    """
+    for name, child in model.named_modules():
+        if isinstance(child, OutputLogger):
+            child.save_activations = save_activations
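+
+# Typical calibration pattern using the two toggles above (illustrative sketch;
+# `mp` is a hypothetical model instrumented with OutputLogger modules and
+# `calibration_data` a hypothetical iterable of example inputs):
+#
+#   loggers_set_save_activations(mp, False)  # reduce memory during calibration
+#   for datum in calibration_data:
+#       mp(datum)
+#   loggers_set_enabled(mp, False)  # stop logging once calibration is done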
+
+def convert_n_shadows_model(
+    model: GraphModule,
+    custom_convert_fn: Optional[Callable] = None,
+    custom_convert_kwargs: Optional[Dict[str, Any]] = None
+) -> GraphModule:
+    """
+    Given a model from `prepare_n_shadows_model`, runs `convert_fx`
+    on each shadow submodule.
+    """
+    for node in model.graph.nodes:
+        # TODO(future PR): consider matching in a safer way than
+        # node name string match
+        if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
+            orig_mod = getattr(model, node.name)
+            if custom_convert_fn is None:
+                converted_mod = torch.ao.quantization.quantize_fx.convert_fx(
+                    orig_mod)
+            else:
+                if custom_convert_kwargs is None:
+                    custom_convert_kwargs = {}
+                converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs)
+            setattr(model, node.name, converted_mod)
+
+    return model
+
+def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
+    """
+    Extracts logger results from `model`.
+    """
+    results: NSResultsType = {}
+    _extract_logger_info_one_model(model, results, OutputLogger)
+    return results
+
+def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
+    """
+    Prints a summary of extracted `results`.
+    """
+    results_grouped = group_results_by_subgraph(results)
+    results_comparison = create_results_comparison(results_grouped)
+    print_n_shadows_summary(results_comparison)
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__init__.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b3f3add42f99ba5508751ca43f6b0f3a91f58be
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3fd6e092e73b9013bde6a53429a8766943f2999
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..341d4a94cd925dd65042b94283735a5eda45365e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50bced046c6aa05bb99e87c0cd43dbb48922d097
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b81696a6b94b2e21319adb56db6152937829e437
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae626398dde9004facd2d3cb027d56306f94190f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7ad29d69044dfe584061d8c1501a5305d7942fa
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04620e9fd53b062d183406a917d64e9e3adbbf5e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3608f7f6c6a6c4a40c5e9c42c05a1253ba917b04
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..207d02d704dc62ee42b2bc7ba31001c11a0a4dc3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28b1607880575fbca4248c22d5e57642d7c16c2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_matcher.py
@@ -0,0 +1,460 @@
+import collections
+import enum
+
+import torch
+toq = torch.ops.quantized
+
+from torch.fx import GraphModule
+from torch.fx.graph import Graph, Node
+
+from torch.ao.quantization.utils import getattr_from_fqn
+from .ns_types import NSSubgraph, NSNodeTargetType
+from .mappings import (
+    get_base_name_to_sets_of_related_ops,
+    get_unmatchable_types_map,
+)
+from .pattern_utils import (
+    get_type_a_related_to_b,
+    get_reversed_fusions,
+    end_node_matches_reversed_fusion,
+)
+from torch.ao.quantization import (
+    ObserverBase,
+    FakeQuantizeBase,
+)
+
+from typing import Dict, Tuple, List, Optional, Set, Any
+
+def _get_output_nodes(g: Graph) -> List[Node]:
+    return [n for n in g.nodes if n.op == 'output']
+
+class _NSGraphMatchableSubgraphsIterator:
+    """
+    Iterates through the graph of gm, starting with the output nodes
+    and continuing backwards.
+    1. Returns matchable subgraphs, in order. A subgraph is defined by
+       (start_node, end_node).
+    2. Skips over non-matchable subgraphs
+    """
+    def __init__(
+        self,
+        gm: GraphModule,
+        non_matchable_functions: Set[NSNodeTargetType],
+        non_matchable_modules: Set[NSNodeTargetType],
+        non_matchable_methods: Set[NSNodeTargetType],
+    ):
+        self.gm: GraphModule = gm
+        self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
+        self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
+        self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
+        self.seen_nodes: Set[Node] = set()
+        self.stack: List[Node] = []
+        for start_node in _get_output_nodes(self.gm.graph):
+            self.stack.append(start_node)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> NSSubgraph:
+        """
+        Returns the next matchable subgraph.
+        """
+        while len(self.stack) > 0:
+            cur_end_node = self.stack.pop()
+            if cur_end_node in self.seen_nodes:
+                continue
+
+            # for subgraphs which are single nodes, start_node == end_node
+            # for subgraphs with more than one node, start_node != end_node
+            cur_start_node = cur_end_node
+            # Subgraphs like linear-relu have the base node as the start node.
+            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
+            #   base node as the second node.
+            # The cur_base_op_node var will move to the actual node during
+            #   the fusion matching later in this code block.
+            cur_base_op_node = cur_end_node
+
+            # Check for potential fusions. For now, we are greedy
+            # and always skip all non-base nodes of a fusion.  For example,
+            # if we match linear-relu backwards, we will always skip the
+            # relu node and attempt to match the linear node.  This can
+            # be made configurable later if needed.
+            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
+                is_match = end_node_matches_reversed_fusion(
+                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
+                if is_match:
+                    # navigate to the base node
+                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
+                        self.seen_nodes.add(cur_start_node)
+                        # for now, assume that there are no other nodes
+                        # which need to be added to the stack
+                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
+                        # if the base op index matches the current node, set it
+                        rev_base_op_idx = \
+                            len(_reverse_fusion_ops) - 2 - base_op_idx
+                        if rev_fusion_idx == rev_base_op_idx:
+                            cur_base_op_node = cur_start_node
+                    break
+
+            self.seen_nodes.add(cur_start_node)
+            # add args of previous nodes to stack
+            for arg in cur_start_node.all_input_nodes:
+                self._recursively_add_node_arg_to_stack(arg)
+
+            # skip unmatchable nodes
+            # note: this check is done on the start_node, i.e.
+            # if we are matching linear-relu in reverse, this would do the matchable
+            # check on the linear
+            if not self._is_matchable(cur_base_op_node):
+                continue
+
+            # If an observer or a fake_quant was not matched as a part of
+            # a pattern of multiple nodes, ignore it. One case where this is
+            # relevant is an observer on a graph input, which was added because
+            # it is necessary for the next node.
+            if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
+                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
+                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
+                    continue
+
+            return NSSubgraph(
+                start_node=cur_start_node, end_node=cur_end_node,
+                base_op_node=cur_base_op_node)
+
+        raise StopIteration
+
+    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
+        """
+        Adds all of the nodes in this arg to the stack, properly navigating
+        through list, dicts and tuples.
+        """
+        if isinstance(arg, Node):
+            self.stack.append(arg)
+        elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
+            for inner_arg in arg:
+                self._recursively_add_node_arg_to_stack(inner_arg)
+        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
+            for value in arg.values():
+                self._recursively_add_node_arg_to_stack(value)
+
+    def _is_matchable(self, node: Node) -> bool:
+        if node.op == 'call_function':
+            return node.target not in self.non_matchable_functions
+        elif node.op == 'call_module':
+            assert isinstance(node.target, str)
+            target_mod = getattr_from_fqn(self.gm, node.target)
+            return not \
+                any(isinstance(target_mod, t)  # type: ignore[arg-type]
+                    for t in self.non_matchable_modules)
+        elif node.op == 'call_method':
+            return node.target not in self.non_matchable_methods
+        else:
+            return False
+
+class GraphMatchingException(Exception):
+    """
+    Exception raised when two graphs cannot be matched.
+    """
+    pass
+
+class SubgraphTypeRelationship(enum.Enum):
+    # same type, known
+    # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
+    EQUAL = enum.auto()
+    # same type, but the type is not known to Numeric Suite
+    # (user defined type, etc).
+    EQUAL_BUT_UKNOWN = enum.auto()
+    # known, same subgraph_relationship set, but not the same type
+    # example: F.linear and toq.linear
+    RELATED_BUT_NOT_EQUAL = enum.auto()
+    # not related
+    NOT_RELATED = enum.auto()
+
+def _get_subgraph_relationship_type(
+    subgraph_a: NSSubgraph,
+    subgraph_b: NSSubgraph,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
+) -> SubgraphTypeRelationship:
+    node_a = subgraph_a.base_op_node
+    node_b = subgraph_b.base_op_node
+
+    # TODO(next): make this code handle matching by what is before the base op
+    if node_a.op != node_b.op:
+        if not (
+            node_a.op in ('call_function', 'call_method') and
+            node_b.op in ('call_function', 'call_method')
+        ):
+            return SubgraphTypeRelationship.NOT_RELATED
+
+    if node_a.op in ('call_function', 'call_method'):
+        key = (node_a.target, node_b.target)
+
+        if key not in type_a_related_to_b:
+            if node_a.target == node_b.target:
+                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
+            else:
+                return SubgraphTypeRelationship.NOT_RELATED
+        # after this point, we are dealing with known types
+
+        if node_a.target == node_b.target:
+            node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
+            node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
+            if node_a_has_prev and (not node_b_has_prev):
+                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+            elif (not node_a_has_prev) and node_b_has_prev:
+                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+            elif (not node_a_has_prev) and (not node_b_has_prev):
+                return SubgraphTypeRelationship.EQUAL
+            else:
+                # TODO(future PR): check for matches start_op_node and base_op_node
+                return SubgraphTypeRelationship.EQUAL
+
+        # at this point the key is known to be related and the targets differ
+        return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+    elif node_a.op == 'call_module':
+        assert (subgraph_a.base_op_node == subgraph_a.start_node and
+                subgraph_b.base_op_node == subgraph_b.start_node), \
+            "Matching call_module patterns where base_op_node != start_node is not supported yet"
+        # for call_module, we need to look up the modules to do the type check
+        assert isinstance(node_a.target, str)
+        mod_a = getattr_from_fqn(gm_a, node_a.target)
+        assert isinstance(node_b.target, str)
+        mod_b = getattr_from_fqn(gm_b, node_b.target)
+
+        key = (type(mod_a), type(mod_b))
+
+        if key not in type_a_related_to_b:
+            if type(mod_a) == type(mod_b):
+                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
+            else:
+                return SubgraphTypeRelationship.NOT_RELATED
+        elif type(mod_a) == type(mod_b):
+            return SubgraphTypeRelationship.EQUAL
+        else:
+            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+
+    return SubgraphTypeRelationship.NOT_RELATED
+
+def _get_name_for_subgraph(
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
+    existing_names: Set[str],
+) -> str:
+    """
+    Returns a unique name for a subgraph. This name is based on two things:
+    1. the name of the set containing the underlying type of the base op in the
+       subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
+    2. the number of previous subgraphs with related underlying type of the base op
+
+    For example, in the graph
+
+    linear0 -> relu0 -> linear1 -> relu1
+
+    The subgraphs are (linear0, relu0) and (linear1, relu1).  If we iterate
+    from the output node backwards, the name given to (linear1, relu1) will be
+    `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
+    will be `base_op_torch.nn.functional.linear_1`.
+
+    Why are we not just using the node name? Answer: because of two requirements:
+    A. fusions must be supported
+    B. some Numeric Suite APIs can be called without having all of the models in memory
+
+    For example, let's say we need to match nodes of
+
+    (1) ... -> linear0 -> relu0 -> ...
+
+    And
+
+    (2) ... -> linear_relu0 -> ...
+
+    Without being able to inspect them together. With the current naming scheme, if
+    we iterate through both of these graphs in the same order, and assuming the rest
+    of the graphs match, both of these subgraphs will get the same name without
+    (1) and (2) knowing anything about each other.
+    """
+    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
+    target_base_type = None
+    for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
+        if target_type in sets_of_related_ops:
+            target_base_type = base_name
+    target_base_name = 'base_op_' + str(target_base_type)
+    counter = 0
+    proposed_name = target_base_name + '_' + str(counter)
+    while proposed_name in existing_names:
+        counter += 1
+        proposed_name = target_base_name + '_' + str(counter)
+    existing_names.add(proposed_name)
+    return proposed_name
+
+def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
+    if node.op in ('call_function', 'call_method'):
+        return node.target
+    elif node.op == 'call_module':
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        return type(mod)
+    return None
+
+def get_matching_subgraph_pairs(
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
+    """
+    Matches matchable subgraphs of graph_a to graph_b.
+
+    For a node, "matchable" is defined as a node which is not an observer,
+    fake_quant, quant or dequant.
+
+    A subgraph can contain one or more nodes.  A subgraph is matchable if
+    at least one node inside of it is matchable.  Currently, all nodes in
+    a subgraph must be matchable (because we assume no observers will be
+    inserted in the middle of a fusion).
+
+    A subgraph is defined by (start_node, end_node).  We assume that only
+    start_node and end_node are linked with the surrounding graph, all other
+    nodes in a subgraph are self-contained.
+
+    A pair of nodes is "related" if both nodes represent the same mathematical
+    operation across different quantization flavors. For example,
+    `F.linear` and `torch.ops.quantized.linear` are related, and
+    `F.linear` and `torch.nn.Conv2d` are not related.
+
+    For each matchable pair of nodes node_a and node_b, they will match
+    if node_a and node_b are related.
+
+    For graphs A and B, they will match iff:
+    1. the number of matchable subgraphs in A and B is equivalent
+    2. when iterating through the matchable subgraphs of A and B in the same order, each
+       corresponding pair of base nodes is related.
+
+    This enables us to find the corresponding subgraphs between
+    graphs of related models.  For example, if we had two graphs such as:
+
+    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
+             w  -/
+             b  -/
+
+    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
+           packed_params_0 -/
+
+    This function will return the following result:
+    {
+        'conv_0': (  # the name of the node in graph_b
+          (conv_0, conv_0),  # (start_node_a, end_node_a)
+          (qconv_0, qconv_0),  # (start_node_b, end_node_b)
+        ),
+    }
+
+    Or, if we have a fusion pattern,
+
+    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
+             w  -/
+             b  -/
+
+    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
+           packed_params_0 -/
+
+    This function will return the following result:
+    {
+        'linear_relu_0': (  # the name of the node in graph_b
+          (linear_0, relu_0),  # (start_node_a, end_node_a)
+          (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
+        ),
+    }
+    """
+    if unmatchable_types_map is None:
+        unmatchable_types_map = get_unmatchable_types_map()
+    non_matchable_functions = unmatchable_types_map['funs_unmatchable']
+    non_matchable_modules = unmatchable_types_map['mods_unmatchable']
+    non_matchable_methods = unmatchable_types_map['meths_unmatchable']
+
+    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
+        gm_a, non_matchable_functions, non_matchable_modules,
+        non_matchable_methods)
+    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
+        gm_b, non_matchable_functions, non_matchable_modules,
+        non_matchable_methods)
+    results = collections.OrderedDict()
+    if base_name_to_sets_of_related_ops is None:
+        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
+    type_a_related_to_b = \
+        get_type_a_related_to_b(base_name_to_sets_of_related_ops)
+
+    existing_names_a: Set[str] = set()
+    existing_names_b: Set[str] = set()
+
+    while True:
+        # fetch the next subgraphs from a and b
+        cur_subgraph_a, cur_subgraph_b = None, None
+        try:
+            cur_subgraph_a = next(graph_a_iterator)
+        except StopIteration:
+            pass
+        try:
+            cur_subgraph_b = next(graph_b_iterator)
+        except StopIteration:
+            pass
+
+        # look up types of a and b for useful error messages
+        type_start_a, type_start_b = None, None
+        if cur_subgraph_a is not None:
+            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
+        if cur_subgraph_b is not None:
+            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
+
+        # check for results and determine what to do next
+        if cur_subgraph_a is not None and cur_subgraph_b is not None:
+            # both nodes were fetched, check for subgraph_relationship
+            # note: subgraph_relationship is checked on the start node, i.e.
+            # if a linear-relu pattern is checked, we would check for subgraph_relationship
+            # of the linear
+            subgraph_relationship = _get_subgraph_relationship_type(
+                cur_subgraph_a, cur_subgraph_b,
+                gm_a, gm_b, type_a_related_to_b)
+            if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
+                msg = f"""
+The subgraphs
+({cur_subgraph_a}, {type_start_a}) and
+({cur_subgraph_b}, {type_start_b})
+are not related. Please ensure that the two models you pass in have the same number
+of subgraphs, and each pair of subgraphs is related to each other."""
+                raise GraphMatchingException(msg)
+            elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
+                # skip matching but unknown types
+                continue
+            key_name_a = _get_name_for_subgraph(
+                cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
+                existing_names_a)
+            key_name_b = _get_name_for_subgraph(
+                cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
+                existing_names_b)
+            assert key_name_a == key_name_b, \
+                f"Subgraph names {key_name_a} and {key_name_b} do not match"
+            results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
+            continue
+        elif cur_subgraph_a is None and cur_subgraph_b is None:
+            # we reached the end of both graphs
+            break
+        else:
+            # only one node was fetched, no match possible, throw error
+            msg = f"""
+Attempting to match
+({cur_subgraph_a}, {type_start_a}) and
+({cur_subgraph_b}, {type_start_b}),
+one of which is empty. Please ensure that the two models you pass in have the same number
+of subgraphs."""
+            raise GraphMatchingException(msg)
+
+    # The subgraph pairs are originally created by traversing the two graphs
+    # from the outputs to the inputs. Reverse the results to return the
+    # subgraphs in their order of execution.
+    results = collections.OrderedDict(reversed(list(results.items())))
+
+    return results
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_passes.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_passes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5f610544d1ea6ca56045cf1e0d63a0bba89aa61
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/graph_passes.py
@@ -0,0 +1,950 @@
+import torch
+from torch.fx import GraphModule, map_arg
+from torch.fx.graph import Graph, Node
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+
+from .utils import (
+    get_node_first_input_and_output_type,
+    getattr_from_fqn,
+    NodeInputOrOutputType,
+    return_first_non_observer_node,
+    get_number_of_non_param_args,
+    get_target_type_str,
+    get_arg_indices_of_inputs_to_log,
+    get_node_input_qparams,
+    op_type_supports_shadowing,
+    get_normalized_nth_input,
+)
+
+from .ns_types import (
+    NSSingleResultValuesType,
+    NSSubgraph,
+    NSNodeTargetType,
+)
+from torch.ao.ns.fx.mappings import (
+    get_node_type_to_io_type_map,
+)
+from torch.ao.quantization.observer import _is_activation_post_process
+
+from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
+
+def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
+    fqn = None
+    if hasattr(gm, '_node_name_to_scope'):
+        # fqn on observers is not present, because they do not
+        # exist when the fqns are created during tracing. If this is
+        # an observer, get the fqn of the node being observed.
+        node_to_use_for_fqn = node
+        if node.op == 'call_module':
+            assert isinstance(node.target, str)
+            module = getattr_from_fqn(gm, node.target)
+            if _is_activation_post_process(module):
+                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
+        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
+    return fqn  # type: ignore[return-value]
+
+def _insert_logger_after_node(
+    node: Node,
+    gm: GraphModule,
+    logger_cls: Callable,
+    logger_node_name_suffix: str,
+    ref_node_name: str,
+    model_name: str,
+    ref_name: str,
+    ref_node_target_type: str,
+    results_type: str,
+    index_within_arg: int,
+    index_of_arg: int,
+    fqn: Optional[str],
+) -> Node:
+    """
+    Given a starting graph of
+
+    prev_node -> node -> next_node
+
+    This function creates a new logger_cls obj and adds it
+    after node, resulting in
+
+    prev_node -> node -> logger_obj -> next_node
+    """
+    # create new name
+    logger_node_name = \
+        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
+    target_type = get_target_type_str(node, gm)
+    # create the logger object
+    logger_obj = logger_cls(
+        ref_node_name, node.name, model_name, ref_name, target_type,
+        ref_node_target_type,
+        results_type, index_within_arg, index_of_arg, fqn)
+    # attach the logger object to the parent module
+    setattr(gm, logger_node_name, logger_obj)
+    logger_node = node.graph.create_node(
+        'call_module', logger_node_name, (node,), {})
+    return logger_node
+
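+# A hedged, minimal sketch of a logger class compatible with the contract
+# assumed by `_insert_logger_after_node`: the constructor takes the ten
+# metadata arguments above and `forward` passes activations through
+# unchanged. `_ExampleLogger` is illustrative only and is not part of this
+# module's API; in the public numeric suite flow a class such as
+# `torch.ao.ns._numeric_suite_fx.OutputLogger` is typically passed in.
+class _ExampleLogger(torch.nn.Module):
+    def __init__(self, ref_node_name, prev_node_name, model_name, ref_name,
+                 prev_node_target_type, ref_node_target_type, results_type,
+                 index_within_arg, index_of_arg, fqn):
+        super().__init__()
+        # metadata describing which node this logger observes
+        self.ref_node_name = ref_node_name
+        self.prev_node_name = prev_node_name
+        self.model_name = model_name
+        self.results_type = results_type
+        # collected activations
+        self.stats = []
+
+    def forward(self, x):
+        # record a detached copy and return the input unchanged, so inserting
+        # the logger does not change the semantics of the traced graph
+        if isinstance(x, torch.Tensor):
+            self.stats.append(x.detach())
+        return x
+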
+def add_loggers_to_model(
+    gm: GraphModule,
+    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
+    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
+    logger_cls: Callable,
+    model_name: str,
+) -> GraphModule:
+    """
+    Takes the graph of gm and adds loggers to the inputs and/or outputs of
+    each node in the two instrumentation maps. Returns a GraphModule with
+    the new graph.
+    """
+
+    new_graph = Graph()
+    env: Dict[str, Any] = {}
+    modules = dict(gm.named_modules())
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    for node in gm.graph.nodes:
+        if node.op == 'output':
+            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
+            continue
+
+        if (
+            (node in node_to_instrument_inputs_to_ref_node_name) or
+            (node in node_to_instrument_outputs_to_ref_node_name)
+        ):
+            fqn = _maybe_get_fqn(node, gm)
+
+            if node in node_to_instrument_inputs_to_ref_node_name:
+                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
+                # Ops such as add and mul are special because either
+                # one or two of the first two arguments can be tensors,
+                # and if one argument is a tensor it can be first or
+                # second (x + 1 versus 1 + x).
+                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
+                for node_arg_idx in arg_indices_to_log:
+                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
+                    if type(node_arg) == Node:
+                        # create a single input logger
+                        prev_node = env[node_arg.name]
+                        env[node_arg.name] = _insert_logger_after_node(
+                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
+                            model_name, ref_name, ref_node_type,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0, index_of_arg=node_arg_idx,
+                            fqn=fqn)
+                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
+                        # create N input loggers, one for each node
+                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
+                            prev_node = env[arg.name]
+                            env[prev_node.name] = _insert_logger_after_node(
+                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
+                                model_name, ref_name, ref_node_type,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
+                                fqn=fqn)
+                    else:
+                        pass
+
+            # ensure env is populated with base node
+            # Note: runs for both inputs and outputs
+            env[node.name] = new_graph.node_copy(node, load_arg)
+
+            if node in node_to_instrument_outputs_to_ref_node_name:
+                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
+                # add the logger after the base node
+                env[node.name] = _insert_logger_after_node(
+                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
+                    model_name, ref_name, ref_node_type,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0, index_of_arg=0, fqn=fqn)
+
+        else:
+            env[node.name] = new_graph.node_copy(node, load_arg)
+
+    new_gm = GraphModule(gm, new_graph)
+    return new_gm
+
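+# A hedged usage sketch for `add_loggers_to_model`, relying on the
+# illustrative `_ExampleLogger` above; the model `_ExampleM`, the reference
+# name 'fc_ref', and the model name 'example_model' are all hypothetical.
+def _example_add_loggers_to_model():
+    class _ExampleM(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = torch.nn.Linear(2, 2)
+
+        def forward(self, x):
+            return torch.relu(self.fc(x))
+
+    gm = torch.fx.symbolic_trace(_ExampleM())
+    # instrument the output of the linear node under the name 'fc_ref'
+    fc_node = next(
+        n for n in gm.graph.nodes
+        if n.op == 'call_module' and n.target == 'fc')
+    instrumented = add_loggers_to_model(
+        gm,
+        node_to_instrument_inputs_to_ref_node_name={},
+        node_to_instrument_outputs_to_ref_node_name={
+            fc_node: ('fc_ref', 'call_module')},
+        logger_cls=_ExampleLogger,
+        model_name='example_model')
+    # running the instrumented model populates the logger's stats
+    instrumented(torch.randn(1, 2))
+    return instrumented
+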
+def _insert_quantize_per_tensor_node(
+    prev_node_c: Node,
+    node_a: Node,
+    gm_b: GraphModule,
+    graph_c: Graph,
+    scale: Union[torch.Tensor, float],
+    zero_point: Union[torch.Tensor, int],
+    dtype_cast_name: str,
+) -> Node:
+    # copy scale
+    scale_node_name = \
+        get_new_attr_name_with_prefix(
+            node_a.name + '_input_scale_')(gm_b)
+    setattr(gm_b, scale_node_name, scale)
+    scale_node = graph_c.create_node(
+        'get_attr', scale_node_name, (), {}, scale_node_name)
+    # copy zero_point
+    zero_point_node_name = \
+        get_new_attr_name_with_prefix(
+            node_a.name + '_input_zero_point_')(gm_b)
+    setattr(gm_b, zero_point_node_name, zero_point)
+    zero_point_node = graph_c.create_node(
+        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
+    # create the quantize_per_tensor call
+    return graph_c.create_node(
+        'call_function', torch.quantize_per_tensor,
+        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
+        dtype_cast_name)
+
+def _insert_dtype_cast_after_node(
+    node_a: Node,
+    node_c: Node,
+    prev_node_c: Union[Node, List[Node]],
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    graph_c: Graph,
+    node_name_prefix: str,
+    logger_cls: Callable,
+    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
+) -> Union[Node, List[Node]]:
+    """
+    Given a starting graph C (derived from graph B) of
+
+    ... -> prev_node_c -> node_c -> ...
+
+    And a corresponding related node_a, inserts the correct dtype
+    cast node after prev_node_c to cast into the dtype expected
+    by node_a, resulting in:
+
+                          dtype_cast
+                        /
+    ... -> prev_node_c -> node_c -> ...
+
+    For example, if node_c is an int8 op and node_a is an fp32 op, this function
+    will insert a dequant.
+    """
+    dtype_cast_op = None
+    dtype_cast_mod_cls = None
+    dtype_cast_method = None
+    dtype_cast_method_dtype = None
+    dtype_cast_scale = None
+    dtype_cast_zero_point = None
+    node_input_type_a, _node_output_type_a = \
+        get_node_first_input_and_output_type(
+            node_a, gm_a, logger_cls, node_type_to_io_type_map)
+    node_input_type_c, _node_output_type_c = \
+        get_node_first_input_and_output_type(
+            node_c, gm_b, logger_cls, node_type_to_io_type_map)
+
+    if (
+        (node_input_type_a == NodeInputOrOutputType.FP32 and
+         node_input_type_c == NodeInputOrOutputType.INT8) or
+        (node_input_type_a == NodeInputOrOutputType.FP32 and
+         node_input_type_c == NodeInputOrOutputType.FP16) or
+        # TODO(future PR): determine the actual dtype of node_c,
+        # the current code only works because dequantize works with
+        # multiple input dtypes.
+        (node_input_type_a == NodeInputOrOutputType.FP32 and
+         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
+    ):
+        dtype_cast_op = torch.dequantize
+    elif (
+        node_input_type_a == node_input_type_c and
+        node_input_type_a != NodeInputOrOutputType.UNKNOWN
+    ):
+        dtype_cast_mod_cls = torch.nn.Identity
+    elif (
+        node_input_type_a == NodeInputOrOutputType.INT8 and
+        node_input_type_c == NodeInputOrOutputType.FP32
+    ):
+        # int8 shadows fp32, the dtype cast needs to quantize to int8
+        # with the right qparams.
+        node_a_input_qparams = get_node_input_qparams(
+            node_a, gm_a, node_type_to_io_type_map)
+        if node_a_input_qparams is not None:
+            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
+            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
+    elif (
+        node_input_type_a == NodeInputOrOutputType.FP16 and
+        node_input_type_c == NodeInputOrOutputType.FP32
+    ):
+        dtype_cast_method = 'to'
+        dtype_cast_method_dtype = torch.float16
+    else:
+        raise AssertionError(
+            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
+            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")
+
+    if isinstance(prev_node_c, Node):
+        new_dtype_cast_name = \
+            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+        if dtype_cast_op:
+            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
+                return _insert_quantize_per_tensor_node(
+                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
+                    dtype_cast_zero_point, new_dtype_cast_name)
+            else:
+                return graph_c.create_node(
+                    'call_function', dtype_cast_op, (prev_node_c,), {},
+                    new_dtype_cast_name)
+        elif dtype_cast_method:
+            return graph_c.create_node(
+                'call_method', dtype_cast_method,
+                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
+        else:
+            assert dtype_cast_mod_cls
+            dtype_cast_mod = dtype_cast_mod_cls()
+            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
+            return graph_c.create_node(
+                'call_module', new_dtype_cast_name, (prev_node_c,), {},
+                new_dtype_cast_name)
+    elif isinstance(prev_node_c, list):
+        results = []
+        for prev_node_c_inner in prev_node_c:
+            new_dtype_cast_name = \
+                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+            if dtype_cast_op:
+                # TODO(future PR): add handling for quantize_per_tensor
+                new_dtype_cast_node = graph_c.create_node(
+                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
+                    new_dtype_cast_name)
+                results.append(new_dtype_cast_node)
+            else:
+                assert dtype_cast_mod_cls
+                dtype_cast_mod = dtype_cast_mod_cls()
+                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
+                new_dtype_cast_node = graph_c.create_node(
+                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
+                    new_dtype_cast_name)
+                results.append(new_dtype_cast_node)
+        return results
+    else:
+        raise AssertionError(f"type f{type(prev_node_c)} is not handled")
+
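+# A hedged illustration (not called anywhere): the dispatch above summarized
+# as a lookup from (first input dtype of node_a, first input dtype of node_c)
+# to the cast that would be inserted. Matching known dtypes instead get a
+# pass-through torch.nn.Identity module, and unknown combinations raise.
+def _example_dtype_cast_table():
+    T = NodeInputOrOutputType
+    return {
+        (T.FP32, T.INT8): torch.dequantize,
+        (T.FP32, T.FP16): torch.dequantize,
+        (T.FP32, T.FP32_OR_INT8): torch.dequantize,
+        (T.INT8, T.FP32): torch.quantize_per_tensor,  # uses node_a's input qparams
+        (T.FP16, T.FP32): 'to',                        # i.e. Tensor.to(torch.float16)
+    }
+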
+# TODO(future PR): look into using copy_node API instead
+def _copy_node_from_a_to_c(
+    node_a: Node,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    graph_c: Graph,
+) -> Node:
+    """
+    Simple copy of node_a to graph_c.
+    """
+    if node_a.op == 'get_attr':
+        node_a_copy_name = \
+            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
+        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
+        if torch.is_tensor(node_a_obj):
+            node_a_obj = node_a_obj.detach()
+        setattr(gm_b, node_a_copy_name, node_a_obj)
+        node_a_copy = graph_c.create_node(
+            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
+        return node_a_copy
+    elif node_a.op == 'call_method':
+        assert node_a.target in ('dequantize', 'to'), \
+            f"target {node_a.target} is not implemented"
+        if node_a.target == 'dequantize':
+            arg_copy = _copy_node_from_a_to_c(
+                get_normalized_nth_input(node_a, gm_a, 0),
+                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
+            node_a_copy_name = \
+                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
+            node_a_copy = graph_c.create_node(
+                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
+            return node_a_copy
+        else:  # to
+            arg_copy = _copy_node_from_a_to_c(
+                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
+            node_a_copy_name = \
+                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
+            node_a_copy = graph_c.create_node(
+                node_a.op, node_a.target,
+                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
+                {}, node_a_copy_name)
+            return node_a_copy
+
+    else:
+        raise AssertionError(
+            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")
+
+def _can_insert_copy_of_subgraph_a(
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    num_non_param_args_node_a: int,
+) -> bool:
+    """
+    This function returns `False` if the input subgraph cannot be copied by
+    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
+    that there is a corner case logic for which copy is not yet implemented.
+    """
+    # populate the list of nodes we need to check
+    nodes = []
+    cur_node = subgraph_a.end_node
+    while cur_node != subgraph_a.start_node:
+        nodes.append(cur_node)
+        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
+    nodes.append(cur_node)
+    nodes.reverse()
+
+    def _can_insert(node_a_arg, gm_a):
+        if isinstance(node_a_arg, Node):
+            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
+            if arg_a.op == 'call_method':
+                return arg_a.target in ('dequantize', 'to')
+            elif arg_a.op == 'get_attr':
+                return True
+            else:
+                return False
+        elif isinstance(node_a_arg, (list, tuple)):
+            for el in node_a_arg:
+                if not isinstance(el, Node):
+                    return False
+        return True
+
+    # For each node, check if we handle the copy behavior. This follows the
+    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
+    for node_a in nodes:
+
+        local_num_non_param_args_node_a = num_non_param_args_node_a \
+            if node_a is nodes[0] else 1
+
+        norm_args_kwargs = node_a.normalized_arguments(
+            gm_a, normalize_to_only_use_kwargs=True)
+        if norm_args_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_kwargs
+        else:
+            norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+        cur_idx = 0
+
+        while cur_idx < len(norm_args):
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(norm_args[cur_idx], gm_a):
+                    return False
+            cur_idx += 1
+
+        for kwarg_val in norm_kwargs.values():
+            # stitch the inputs from base graph
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(kwarg_val, gm_a):
+                    return False
+            cur_idx += 1
+
+    return True
+
+def _insert_copy_of_subgraph_a_after_input_node_c(
+    input_node_c: Union[Node, List[Node]],
+    input_node_c_2: Optional[Union[Node, List[Node]]],
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    node_name_prefix: str,
+) -> Node:
+    """
+    Inserts a copy of the nodes of subgraph_a (from gm_a) into the graph
+    containing input_node_c, in execution order, wiring input_node_c (and
+    optionally input_node_c_2) as the non-param input(s) of the first copied
+    node. Returns the last node inserted.
+    """
+    if isinstance(input_node_c, Node):
+        graph_c = input_node_c.graph
+    else:
+        assert isinstance(input_node_c, list)
+        graph_c = input_node_c[0].graph
+
+    # create a sequential list of the subgraphs' nodes from start to end,
+    # because we need to add the nodes to graph C in non-reverse order
+    nodes_of_a = [subgraph_a.end_node]
+    cur_node = subgraph_a.end_node
+    while cur_node != subgraph_a.start_node:
+        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
+        nodes_of_a.insert(0, cur_node)
+
+    # go through nodes of a in order, and insert them into the graph of c
+    # sequentially
+    cur_node_a = nodes_of_a[0]
+    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
+        input_node_c,
+        input_node_c_2,
+        cur_node_a,
+        gm_a,
+        gm_b,
+        node_name_prefix)
+    for cur_idx_a in range(1, len(nodes_of_a)):
+        cur_node_a = nodes_of_a[cur_idx_a]
+        prev_node_c = cur_node_c  # previous added node is the input to next node
+        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
+            prev_node_c,
+            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
+            None,
+            cur_node_a,
+            gm_a,
+            gm_b,
+            node_name_prefix)
+    # return the last inserted node
+    return cur_node_c
+
+
+def _insert_copy_of_node_a_after_input_node_c(
+    input_node_c: Union[Node, List[Node]],
+    input_node_c_2: Optional[Union[Node, List[Node]]],
+    node_a: Node,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    node_name_prefix: str,
+) -> Node:
+    """
+    Assume that node_a from graph_a has
+      args (input, (input2)?, arg1, ...), and
+      kwargs {kw0: kwarg0, ...}
+
+    Note: input2 is optional. If it is None, we assume that the op
+    has a single non-param input. If it is specified, we assume that the op
+    has two non-param inputs.
+
+    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
+    and creates the corresponding nodes in graph_c. Note: observers are ignored,
+    so if an arg is an observer we navigate up until we find a non-observer parent.
+
+    If node_a is a call_module, the module that node_a points to in gm_a is
+    attached to gm_b under a new name.
+
+    Creates the copy of node_a in graph_c, with input as the first arg,
+    and all other args and kwargs pointing to the copies of the objects
+    in gm_b created above.
+
+    An example in pictures:
+
+    graph A:
+    ========
+
+    input -------------> node_a
+                         / / /
+    (input_2)?----------/ / /
+                         / /
+    weight -> weight_obs  /
+                         /
+    bias ----------------
+
+    graph C (derived from B):
+    =========================
+
+    input_node_c --> node_a_copy
+                     / / /
+    (input_node_c_2)? / /
+                     / /
+    weight_copy ----/ /
+                     /
+    bias_copy ------/
+    """
+    if isinstance(input_node_c, Node):
+        graph_c = input_node_c.graph
+    else:
+        assert isinstance(input_node_c, list)
+        graph_c = input_node_c[0].graph
+
+    norm_args_kwargs = node_a.normalized_arguments(
+        gm_a, normalize_to_only_use_kwargs=True)
+    if norm_args_kwargs is not None:
+        norm_args, norm_kwargs = norm_args_kwargs
+    else:
+        norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+    new_args = []
+    new_kwargs = {}
+
+    def _copy_arg(arg):
+        # copy the other inputs from the other graph
+        if isinstance(arg, Node):
+            arg = return_first_non_observer_node(arg, gm_a)
+            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
+            return arg
+        elif isinstance(arg, (int, float, torch.dtype)):
+            return arg
+        elif isinstance(arg, (list, tuple)):
+            for el in arg:
+                assert not isinstance(el, Node), \
+                    "handling of Node inside list is not implemented"
+            return arg
+        else:
+            raise AssertionError(
+                f"handling for arg of type {type(arg)} is not implemented")
+
+    cur_idx = 0
+
+    while cur_idx < len(norm_args):
+        if cur_idx == 0:
+            new_arg = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_arg = input_node_c_2
+        else:
+            new_arg = _copy_arg(norm_args[cur_idx])
+        new_args.append(new_arg)
+        cur_idx += 1
+
+    for kwarg_name, kwarg_val in norm_kwargs.items():
+        # stitch the inputs from base graph
+        if cur_idx == 0:
+            new_kwargs[kwarg_name] = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_kwargs[kwarg_name] = input_node_c_2
+        else:
+            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
+        cur_idx += 1
+
+    new_args = tuple(new_args)  # type: ignore[assignment]
+
+    node_a_shadows_c_name = \
+        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+
+    if node_a.op == 'call_module':
+        # if target is a module, we point to the module from gm_b
+        new_mod_copy_name = \
+            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+        # fetch the corresponding module from gm_a
+        assert isinstance(node_a.target, str)
+        mod_a = getattr_from_fqn(gm_a, node_a.target)
+        setattr(gm_b, new_mod_copy_name, mod_a)
+        node_a_shadows_c = graph_c.create_node(
+            node_a.op, new_mod_copy_name, new_args,
+            new_kwargs, node_a_shadows_c_name)
+        return node_a_shadows_c
+    else:
+        assert node_a.op in ('call_function', 'call_method')
+        node_a_shadows_c = graph_c.create_node(
+            node_a.op, node_a.target, new_args,
+            new_kwargs, node_a_shadows_c_name)
+        return node_a_shadows_c
+
+def create_a_shadows_b(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+) -> GraphModule:
+    """
+    Creates a new GraphModule consisting of the graph of C, with the meaningful
+    nodes of A shadowing the corresponding nodes of B.  For example,
+
+    Graph A:
+    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
+
+    Graph B:
+    b0 -> op0_int8 -> b1 -> op1_int8 -> b2
+
+    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
+
+    Graph C (A shadows B):
+
+        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
+       /                                     /
+    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
+
+    In a nutshell, this function does the following for each node pair:
+    * copies the necessary attributes and modules from gm_a to gm_b,
+      keeping names unique
+    * adds a dtype cast op (dequant, quant, etc)
+    * adds a copy of node_a in gm_b's graph
+    * adds loggers to the outputs of node_a and node_b
+    """
+
+    if node_type_to_io_type_map is None:
+        node_type_to_io_type_map = get_node_type_to_io_type_map()
+
+    # graph_c is the graph created from copying the nodes of graph_b and inserting
+    # the shadows with the nodes copied from graph_a
+    graph_c = Graph()
+    env_c: Dict[str, Any] = {}
+    modules = dict(gm_b.named_modules())
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env_c[node.name])
+
+    start_node_b_to_matched_subgraph_a_and_name = {}
+    end_node_b_to_matched_subgraph_a_and_name = {}
+    for match_name, match in matched_subgraph_pairs.items():
+        subgraph_a, subgraph_b = match
+        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
+        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
+        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
+            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
+        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
+            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
+
+    for node_b in gm_b.graph.nodes:
+        if node_b.op == 'output':
+            graph_c.output(map_arg(node_b.args[0], load_arg))
+            continue
+
+        # calculate the flags to determine what to do with this node
+        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
+        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
+
+        if (node_b_is_start_node or node_b_is_end_node):
+
+            if node_b_is_start_node:
+                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
+                    start_node_b_to_matched_subgraph_a_and_name[node_b]
+            else:
+                assert node_b_is_end_node
+                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
+                    end_node_b_to_matched_subgraph_a_and_name[node_b]
+
+            all_op_types_support_shadowing = (
+                op_type_supports_shadowing(subgraph_a.start_node) and
+                op_type_supports_shadowing(node_b)
+            )
+            if not all_op_types_support_shadowing:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', unsupported')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            # For both start_node and end_node verify that we know how to do
+            # the dtype cast. If we do not, skip.
+            node_input_type_a, node_output_type_a = \
+                get_node_first_input_and_output_type(
+                    subgraph_a.start_node, gm_a, logger_cls,
+                    node_type_to_io_type_map)
+            node_input_type_b, node_output_type_b = \
+                get_node_first_input_and_output_type(
+                    node_b, gm_b, logger_cls,
+                    node_type_to_io_type_map)
+            node_io_types_known_a_and_b = (
+                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
+                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
+                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
+                node_output_type_b != NodeInputOrOutputType.UNKNOWN
+            )
+            if not node_io_types_known_a_and_b:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', unknown dtype cast')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            # If we are shadowing from fp32 to int8, we need to insert
+            # quantize_per_tensor call with qparams from the previous node.
+            # Only do this if we are able to infer these qparams from the graph.
+            if (
+                node_input_type_a == NodeInputOrOutputType.INT8 and
+                node_input_type_b == NodeInputOrOutputType.FP32
+            ):
+                node_a_input_qparams = get_node_input_qparams(
+                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
+                if not node_a_input_qparams:
+                    print(
+                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                        ', unknown input qparams')
+                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                    continue
+
+            num_non_param_args_node_a = \
+                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
+            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', unhandled logic in subgraph copy')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
+            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]
+
+            if node_b_is_start_node:
+
+                # if necessary, log the input of node_c
+                if should_log_inputs:
+                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
+                    if isinstance(prev_node_b, Node):
+                        prev_node_c = env_c[prev_node_b.name]
+                        env_c[prev_node_c.name] = _insert_logger_after_node(
+                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
+                            node_b.name, name_b, ref_name, ref_node_type_b,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0, index_of_arg=0,
+                            fqn=fqn_base_b)
+                    elif isinstance(prev_node_b, list):
+                        # first, save the prev_node instances, because they
+                        # will be overwritten in the env after the first logger
+                        # is added
+                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
+
+                        for arg_idx, arg in enumerate(prev_node_b):
+                            prev_node_c = prev_node_c_list[arg_idx]
+                            env_c[prev_node_c.name] = _insert_logger_after_node(
+                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
+                                node_b.name, name_b, ref_name, ref_node_type_b,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=arg_idx, index_of_arg=0,
+                                fqn=fqn_base_b)
+                    else:
+                        # logging of inputs which are neither Nodes nor lists is not supported yet
+                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
+                # subgraph so far:
+                #
+                # (prev_node_c)+ -> (logger_c_input)?
+
+            # Note: this if statement is always True, spelling it out to clarify code
+            # intent.
+            if node_b_is_start_node or node_b_is_end_node:
+                # ensure env_c is populated with base node
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                node_c = env_c[node_b.name]
+
+                # after this point,
+                #
+                # node_a is the original node from graph_a, with parent module gm_a
+                # node_b is the original node from graph_b, with parent module gm_b
+                # node_c is the copy of node_b in graph_c
+                #
+                # subgraph so far:
+                #
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+            if node_b_is_start_node:
+
+                # cast dtype from the dtype of node_c's input to the dtype of
+                # node_a's input (dequant, etc)
+                # prev_node_c = node_c.args[0]
+                prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
+                if should_log_inputs:
+                    # skip the input logger when inserting a dtype cast
+                    if isinstance(prev_node_c, Node):
+                        prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
+                    elif isinstance(prev_node_c, list):
+                        prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
+                dtype_cast_node = _insert_dtype_cast_after_node(
+                    subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
+                    node_b.name + '_dtype_cast_', logger_cls,
+                    node_type_to_io_type_map)
+                # note: not inserting to env_c because all nodes which use the dtype
+                #   casts are copied from graph_a
+                #
+                # subgraph so far:
+                #
+                #           (dtype_cast_node)+
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+                # if input logging is enabled, log the input to the subgraph
+                if should_log_inputs:
+                    # the true ref_node_name is not known yet; it is filled in
+                    # below, after the subgraph copy has been created
+                    ref_node_name = ''
+                    if isinstance(dtype_cast_node, Node):
+                        dtype_cast_node = _insert_logger_after_node(
+                            dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
+                            ref_node_name, name_a, ref_name, ref_node_type_a,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0, index_of_arg=0,
+                            fqn=fqn_base_a)
+                        input_logger: Union[Node, List[Node]] = dtype_cast_node
+                    else:
+                        assert isinstance(dtype_cast_node, list)
+                        new_loggers = []
+                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
+                            dtype_cast_logger = _insert_logger_after_node(
+                                dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
+                                ref_node_name, name_a, ref_name, ref_node_type_a,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=dtype_cast_idx,
+                                index_of_arg=0,
+                                fqn=fqn_base_a)
+                            new_loggers.append(dtype_cast_logger)
+                        dtype_cast_node = new_loggers
+                        input_logger = dtype_cast_node
+                    # subgraph so far:
+                    #
+                    #       (dtype_cast_node)+ -> (logger_a_input)?
+                    #                  /
+                    # prev_node_c -> (logger_c_input)? -> node_start_c
+
+                # hook up the new mod_a copy to be in the graph, receiving the
+                # same inputs as mod_b does, with dtype cast to match a
+                # Some ops, such as LSTMs, have two non-param inputs. If we have
+                # such an op, pass the second input as well. Note: dtype casting
+                # for the second input is not implemented yet; it can be added
+                # later if there is a use case.
+                node_c_second_non_param_arg = None
+                num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
+                if num_non_param_args_node_a == 2:
+                    # node_c_second_non_param_arg = node_c.args[1]
+                    node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
+                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
+                    dtype_cast_node, node_c_second_non_param_arg,
+                    subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
+                env_c[node_a_shadows_c.name] = node_a_shadows_c
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+                if should_log_inputs:
+                    # When we created the input logger, we left the ref_node_name
+                    # as an empty string, because the subgraph copy did not exist
+                    # yet. Now that the subgraph copy exists, we modify this name
+                    # to its true value.
+                    # Note: the alternative to this is to create the input logger
+                    # after creating the subgraph, which is slightly more
+                    # complicated. This is the lesser of two evils.
+                    # input_logger = env_c[dtype_cast_node.name]
+                    # Find the first node in the subgraph
+                    cur_node = node_a_shadows_c
+                    while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
+                        cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
+                    if isinstance(input_logger, Node):
+                        input_logger_mod = getattr(gm_b, input_logger.name)
+                        input_logger_mod.ref_node_name = cur_node.name
+                    else:
+                        assert isinstance(input_logger, list)
+                        for input_logger_inner in input_logger:
+                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
+                            input_logger_mod.ref_node_name = cur_node.name
+
+                # hook up a logger to the mod_a copy
+                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
+                    env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
+                    node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0, index_of_arg=0,
+                    fqn=fqn_base_a)
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+            if node_b_is_end_node:
+
+                # hook up a logger to the mod_b copy
+                env_c[node_b.name] = _insert_logger_after_node(
+                    env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
+                    node_b.name, name_b, ref_name, ref_node_type_b,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0, index_of_arg=0,
+                    fqn=fqn_base_b)
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
+                #                  /
+                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
+                #
+                # Note: node_start_c may be the same node as node_end_c, or they
+                # may have nodes in between.
+
+        else:
+            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+
+    gm_c = GraphModule(gm_b, graph_c)
+    return gm_c
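+
+# A hedged end-to-end sketch of how `create_a_shadows_b` can be driven; the
+# toy model, names, and calibration here are illustrative, and the public
+# wrapper `torch.ao.ns._numeric_suite_fx.add_shadow_loggers` performs an
+# equivalent flow with additional handling.
+def _example_create_a_shadows_b():
+    import copy
+
+    from torch.ao.ns._numeric_suite_fx import OutputLogger
+    from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs
+    from torch.ao.quantization import get_default_qconfig_mapping
+    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+
+    class _ExampleM(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = torch.nn.Linear(2, 2)
+
+        def forward(self, x):
+            return self.fc(x)
+
+    example_inputs = (torch.randn(1, 2),)
+    m_fp32 = _ExampleM().eval()
+
+    # quantize a copy of the model with FX graph mode quantization
+    m_prepared = prepare_fx(
+        copy.deepcopy(m_fp32), get_default_qconfig_mapping(), example_inputs)
+    m_prepared(*example_inputs)  # calibrate
+    m_int8 = convert_fx(m_prepared)
+
+    # match the fp32 and int8 graphs, then let fp32 (A) shadow int8 (B)
+    gm_fp32 = torch.fx.symbolic_trace(m_fp32)
+    matches = get_matching_subgraph_pairs(gm_fp32, m_int8)
+    shadowed = create_a_shadows_b(
+        'fp32', gm_fp32, 'int8', m_int8, matches, OutputLogger,
+        should_log_inputs=False)
+
+    # running the combined model populates both sets of output loggers
+    shadowed(*example_inputs)
+    return shadowed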
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/mappings.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..42bf49f74c958da1b456f616f5f5d28c5c714d79
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/mappings.py
@@ -0,0 +1,761 @@
+import operator
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+toq = torch.ops.quantized
+
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.qat.dynamic as nnqatd
+from torch.ao.quantization.backend_config import get_native_backend_config
+import torch.ao.quantization.fx._lower_to_native_backend as \
+    _lower_to_native_backend
+import torch.ao.quantization.quantization_mappings as quantization_mappings
+
+from .ns_types import NSNodeTargetType
+
+from typing import Callable, Dict, List, Optional, Set, Tuple
+
+
+def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
+    # note: this set is modified below by items from backend_config
+    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
+        # conv modules
+        {
+            nn.Conv1d,
+        },
+        {
+            nn.Conv2d,
+        },
+        {
+            nn.Conv3d,
+        },
+        # conv functionals
+        {
+            F.conv1d,
+        },
+        {
+            F.conv2d,
+        },
+        {
+            F.conv3d,
+        },
+        # linear modules
+        {
+            nn.Linear,
+        },
+        # linear functionals
+        {
+            F.linear,
+        },
+        # average pool
+        {
+            nn.AvgPool1d,
+            torch.avg_pool1d,
+        },
+        {
+            nn.AvgPool2d,
+            torch._C._nn.avg_pool2d,
+        },
+        {
+            nn.AvgPool3d,
+            torch._C._nn.avg_pool3d,
+        },
+        # adaptive average pool
+        {
+            nn.AdaptiveAvgPool1d,
+            F.adaptive_avg_pool1d,
+        },
+        {
+            nn.AdaptiveAvgPool2d,
+            F.adaptive_avg_pool2d,
+        },
+        {
+            nn.AdaptiveAvgPool3d,
+            F.adaptive_avg_pool3d,
+        },
+        # LSTM
+        {
+            nn.LSTM,
+        },
+        # add
+        {
+            torch.add,
+            operator.add,  # x + y
+        },
+        # cat
+        {
+            torch.cat,
+        },
+        # mul
+        {
+            torch.mul,
+            operator.mul,
+        },
+        # relu
+        {
+            F.relu,
+            nn.ReLU,
+            'relu',
+            'relu_',
+            torch.relu,
+        },
+        # maxpool
+        {
+            nn.MaxPool1d,
+            F.max_pool1d,
+        },
+        {
+            nn.MaxPool2d,
+            F.max_pool2d,
+        },
+        {
+            nn.MaxPool3d,
+            F.max_pool3d,
+        },
+        # sigmoid
+        {
+            torch.sigmoid,
+            'sigmoid',
+            'sigmoid_',
+            nn.Sigmoid,
+            F.sigmoid,
+        },
+        # BatchNorm
+        {
+            nn.BatchNorm2d,
+        },
+        {
+            nn.BatchNorm3d,
+        },
+        # ConvTranspose
+        {
+            nn.ConvTranspose1d,
+        },
+        {
+            nn.ConvTranspose2d,
+        },
+        {
+            nn.ConvTranspose3d,
+        },
+        # functional transposed conv
+        {
+            F.conv_transpose1d,
+        },
+        {
+            F.conv_transpose2d,
+        },
+        {
+            F.conv_transpose3d,
+        },
+        # ELU
+        {
+            nn.ELU,
+        },
+        # Embedding
+        {
+            nn.Embedding,
+        },
+        # EmbeddingBag
+        {
+            nn.EmbeddingBag,
+        },
+        # GroupNorm
+        {
+            nn.GroupNorm,
+        },
+        # Hardswish
+        {
+            nn.Hardswish,
+        },
+        # InstanceNorm
+        {
+            nn.InstanceNorm1d,
+        },
+        {
+            nn.InstanceNorm2d,
+        },
+        {
+            nn.InstanceNorm3d,
+        },
+        # LayerNorm
+        {
+            nn.LayerNorm,
+        },
+        # LeakyReLU
+        {
+            nn.LeakyReLU,
+        },
+        # ReLU6
+        {
+            nn.ReLU6,
+            F.relu6,
+        },
+        # F.elu
+        {
+            F.elu,
+        },
+        # F.hardswish
+        {
+            F.hardswish,
+        },
+        # F.group_norm
+        {
+            F.group_norm,
+        },
+        # F.instance_norm
+        {
+            F.instance_norm,
+        },
+        # F.layer_norm
+        {
+            F.layer_norm,
+        },
+        # F.leaky_relu
+        {
+            F.leaky_relu,
+        },
+        # F.silu
+        {
+            nn.SiLU,
+            F.silu,
+        },
+        # F.mish
+        {
+            nn.Mish,
+            F.mish,
+        },
+        # F.tanh
+        {
+            nn.Tanh,
+            F.tanh,
+            torch.tanh,
+            'tanh_',
+            'tanh',
+        },
+        # F.hardsigmoid
+        {
+            'hardsigmoid_',
+            'hardsigmoid',
+            F.hardsigmoid,
+            nn.Hardsigmoid,
+        },
+        # F.hardtanh
+        {
+            nn.Hardtanh,
+            F.hardtanh,
+            F.hardtanh_,
+        },
+        # floordiv
+        {
+            operator.floordiv,
+        },
+        # unsqueeze
+        {
+            torch.unsqueeze,
+        },
+        # stack
+        {
+            torch.stack,
+        },
+        # squeeze
+        {
+            torch.squeeze,
+        },
+        # sort
+        {
+            torch.sort,
+        },
+        # repeat_interleave
+        {
+            torch.repeat_interleave,
+        },
+        # min
+        {
+            torch.min,
+        },
+        # mean
+        {
+            torch.mean,
+        },
+        # max
+        {
+            torch.max,
+        },
+        # transpose
+        {
+            torch.transpose,
+        },
+        # flatten
+        {
+            torch.flatten,
+        },
+        # clamp
+        {
+            torch.clamp,
+        },
+        # chunk
+        {
+            torch.chunk,
+        },
+        # interpolate
+        {
+            torch.nn.functional.interpolate,
+        },
+        # dropout
+        {
+            nn.Dropout,
+        },
+        # F.dropout
+        {
+            F.dropout,
+        },
+        # matmul
+        {
+            torch.matmul,
+        },
+        # Softmax
+        {
+            nn.Softmax,
+        },
+        # PReLU
+        {
+            nn.PReLU,
+            nnq.PReLU,
+        },
+        # F.prelu
+        {
+            F.prelu,
+            toq.prelu,
+        },
+        # pixel shuffle
+        {
+            nn.PixelShuffle,
+        },
+        {
+            F.pixel_shuffle,
+        },
+        # pixel unshuffle
+        {
+            nn.PixelUnshuffle,
+        },
+        {
+            F.pixel_unshuffle,
+        },
+        # narrow
+        {
+            torch.narrow,
+        },
+    ]
+
+    # for each floating point op, add versions of the op added by
+    # backend_config
+    backend_config = get_native_backend_config()
+
+    new_connections: List[Tuple[Callable, Callable]] = [
+        # technical debt edge case
+        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
+    ]
+
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+
+        # pattern format: (c, (b, a))
+        first_element = pattern
+        # look from the end, because pattern is in reverse order
+        while isinstance(first_element, (list, tuple)):
+            first_element = first_element[-1]
+
+        if config.fused_module is not None:
+            # case 1: pattern fuses a pattern of ops into an op
+            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
+            new_connections.append((first_element, config.fused_module))
+
+        if config.qat_module is not None:
+            # case 2: pattern swaps a module into a QAT module
+            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
+            new_connections.append((first_element, config.qat_module))
+
+        if config.reference_quantized_module is not None:
+            # case 3: reference version of floating point module, such as
+            # nn.Conv2d and nnqr.Conv2d
+            new_connections.append((first_element, config.reference_quantized_module))
+
+    #
+    # Add reference module swaps from default lowering path
+    #
+
+    for source_to_target in (
+        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
+        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
+        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
+        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
+    ):
+        for source, target in source_to_target.items():  # type: ignore[attr-defined]
+            new_connections.append((source, target))
+
+    for source_to_double_target in (
+        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
+        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
+        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
+    ):
+        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
+            new_connections.append((source, target1))
+            new_connections.append((source, target2))
+
+    #
+    # Add function swaps from default lowering path
+    #
+
+    for source, (target1, target2) in \
+            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
+        new_connections.append((source, target1))
+        new_connections.append((source, target2))
+
+    for source_to_target in (
+        _lower_to_native_backend.QBIN_OP_MAPPING,
+        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
+        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
+    ):
+        for source, target in source_to_target.items():
+            new_connections.append((source, target))
+
+    #
+    # Add other swaps, ideally in the future this could be removed
+    # after the lowering code stops using these.
+    #
+    for source_to_target in (
+        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
+    ):
+        for source, target in source_to_target.items():
+            new_connections.append((source, target))
+
+
+    # add the new connections from backend_config
+    for item1, item2 in new_connections:
+        for set_of_related_ops in sets_of_related_ops:
+            if item1 in set_of_related_ops or item2 in set_of_related_ops:
+                set_of_related_ops.add(item1)
+                set_of_related_ops.add(item2)
+                break
+
+    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}
+
+    counter = 0
+    for set_of_related_ops in sets_of_related_ops:
+        base_name = str(counter)
+        counter += 1
+        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops
+
+    return base_name_to_sets_of_related_ops
+
+
+def get_base_name_for_op(
+    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
+    op: NSNodeTargetType,
+) -> Optional[str]:
+    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
+        if op in set_of_related_ops:
+            return base_name
+    return None
+
+
+def add_op_to_sets_of_related_ops(
+    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
+    op: NSNodeTargetType,
+    related_op: Optional[NSNodeTargetType],
+) -> None:
+    if related_op is not None:
+        for set_of_related_ops in base_name_to_sets_of_related_ops.values():
+            if related_op in set_of_related_ops:
+                set_of_related_ops.add(op)
+                return
+        # if we got here, related_op was not found
+        raise AssertionError(f"{related_op} was not found")
+    else:
+        counter = 0
+        while str(counter) in base_name_to_sets_of_related_ops:
+            counter += 1
+        base_name_to_sets_of_related_ops[str(counter)] = {op}
+
+
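+# A hedged usage sketch: ops that end up in the same set are treated as
+# "related" by the graph matcher. `_ExampleCustomLinear` is a hypothetical
+# user-defined module used only for illustration.
+def _example_related_ops_lookup():
+    base_map = get_base_name_to_sets_of_related_ops()
+
+    # the fp32 and int8 flavors of Linear resolve to the same base name
+    name_fp32 = get_base_name_for_op(base_map, nn.Linear)
+    name_int8 = get_base_name_for_op(base_map, nnq.Linear)
+    assert name_fp32 is not None and name_fp32 == name_int8
+
+    # registering a custom op as related to nn.Linear places it in that set
+    class _ExampleCustomLinear(nn.Module):
+        pass
+
+    add_op_to_sets_of_related_ops(base_map, _ExampleCustomLinear, nn.Linear)
+    assert get_base_name_for_op(base_map, _ExampleCustomLinear) == name_fp32
+    return base_map
+
+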
+# TODO(future PR): clean this up
+def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
+    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
+        F.linear,
+        F.conv1d,
+        F.conv2d,
+        F.conv3d,
+        torch.cat,
+        F.elu,
+        F.hardswish,
+        F.instance_norm,
+        F.layer_norm,
+        F.leaky_relu,
+        F.dropout,
+        F.silu,
+        F.mish,
+        operator.add,
+        torch.add,
+        operator.mul,
+        torch.mul,
+        torch.sum,
+        F.prelu,
+    }
+
+    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()
+
+    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
+        toq.linear,
+        toq.linear_relu,
+        toq.conv1d,
+        toq.conv1d_relu,
+        toq.conv2d,
+        toq.conv2d_relu,
+        toq.conv3d,
+        toq.conv3d_relu,
+        toq.cat,
+        toq.elu,
+        toq.hardswish,
+        toq.instance_norm,
+        toq.layer_norm,
+        toq.leaky_relu,
+        toq.dropout,
+        toq.prelu,
+        # TODO(future PR): implement shadowing for binary ops and
+        # uncomment below
+        # toq.add,
+        # toq.mul,
+    }
+
+    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
+        F.relu,
+        F.tanh,
+        torch.tanh,
+        F.sigmoid,
+        torch.sigmoid,
+        F.hardsigmoid,
+        operator.floordiv,
+        torch.adaptive_avg_pool1d,
+        F.adaptive_avg_pool2d,
+        F.adaptive_avg_pool3d,
+        F.dropout,
+        F.hardtanh,
+        F.hardtanh_,
+        F.interpolate,
+        F.max_pool1d,
+        F.max_pool2d,
+        F.max_pool3d,
+        F.relu6,
+        F.pixel_shuffle,
+        F.pixel_unshuffle,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.cat,
+        torch.chunk,
+        torch.clamp,
+        torch.flatten,
+        torch.transpose,
+        torch.max,
+        torch.mean,
+        torch.min,
+        torch.narrow,
+        torch.repeat_interleave,
+        torch.sort,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        operator.add,
+    }
+
+    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
+        nn.Linear,
+        nnqat.Linear,
+        nnqatd.Linear,
+        nnqd.Linear,
+        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        nnqat.Conv1d,
+        nnqat.Conv2d,
+        nnqat.Conv3d,
+        nnqat.Embedding,
+        nnqat.EmbeddingBag,
+        nn.LSTM,
+        # note: nnqd.Linear is an instance of nnq.Linear, so this
+        # check has to happen before the int8 module check
+        nnqd.LSTM,
+        nn.BatchNorm2d,
+        nn.BatchNorm3d,
+        nn.Dropout,
+        nn.ConvTranspose1d,
+        nn.ConvTranspose2d,
+        nn.ConvTranspose3d,
+        nn.ELU,
+        nn.GroupNorm,
+        nn.InstanceNorm1d,
+        nn.InstanceNorm2d,
+        nn.InstanceNorm3d,
+        nn.LayerNorm,
+        nn.Hardswish,
+        nn.LeakyReLU,
+        nn.ReLU6,
+        nn.SiLU,
+        nn.Mish,
+        nn.Softmax,
+        nn.PReLU,
+        nni.BNReLU2d,
+        nni.BNReLU3d,
+        nni.ConvReLU1d,
+        nni.ConvReLU2d,
+        nni.ConvReLU3d,
+        nni.LinearReLU,
+        nni.LinearBn1d,
+        nni.ConvBn1d,
+        nni.ConvBn2d,
+        nni.ConvBn3d,
+        nniqat.ConvBn1d,
+        nniqat.ConvBn2d,
+        nniqat.ConvBn3d,
+        nniqat.ConvBnReLU1d,
+        nniqat.ConvBnReLU2d,
+        nniqat.ConvBnReLU3d,
+        nniqat.ConvReLU1d,
+        nniqat.ConvReLU2d,
+        nniqat.ConvReLU3d,
+        nniqat.LinearReLU,
+        nniqat.LinearBn1d,
+        nniqd.LinearReLU,
+        nni.LinearLeakyReLU,
+        nni.LinearTanh,
+        nni.ConvAdd2d,
+        nni.ConvAddReLU2d,
+    }
+
+    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
+        nnq.Linear,
+        nnq.Conv1d,
+        nnq.Conv2d,
+        nnq.Conv3d,
+        nnq.BatchNorm2d,
+        nnq.BatchNorm3d,
+        nnq.Dropout,
+        nnq.ConvTranspose1d,
+        nnq.ConvTranspose2d,
+        nnq.ELU,
+        nnq.InstanceNorm1d,
+        nnq.InstanceNorm2d,
+        nnq.InstanceNorm3d,
+        nnq.LayerNorm,
+        nnq.Hardswish,
+        nnq.LeakyReLU,
+        nnq.Embedding,
+        nnq.EmbeddingBag,
+        nnq.Softmax,
+        nnq.PReLU,
+        nniq.BNReLU2d,
+        nniq.BNReLU3d,
+        nniq.ConvReLU1d,
+        nniq.ConvReLU2d,
+        nniq.ConvReLU3d,
+        nniq.LinearReLU,
+        nniq.LinearLeakyReLU,
+        nniq.LinearTanh,
+        nniq.ConvAdd2d,
+        nniq.ConvAddReLU2d,
+    }
+
+    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
+        nn.ReLU,
+        nn.Tanh,
+        nn.Sigmoid,
+        nn.Hardsigmoid,
+        nn.AdaptiveAvgPool1d,
+        nn.AdaptiveAvgPool2d,
+        nn.AdaptiveAvgPool3d,
+        nn.AvgPool1d,
+        nn.AvgPool2d,
+        nn.AvgPool3d,
+        nn.Dropout,
+        nn.Hardtanh,
+        nn.Identity,
+        nn.MaxPool1d,
+        nn.MaxPool2d,
+        nn.MaxPool3d,
+        nn.PixelShuffle,
+        nn.PixelUnshuffle,
+        nn.ReLU6,
+    }
+
+    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
+        'sigmoid_',
+        'sigmoid',
+        'tanh_',
+        'tanh',
+        'hardsigmoid_',
+        'hardsigmoid',
+        'relu_',
+        'relu',
+    }
+
+    return {
+        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
+        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
+        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
+        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
+        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
+        'mods_io_type_int8': MODS_IO_TYPE_INT8,
+        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
+        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
+    }
+
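+# Illustrative usage sketch of the map above (hedged example only; the key
+# names and ops shown are taken from the sets defined in this function):
+#
+#   node_type_map = get_node_type_to_io_type_map()
+#   F.relu in node_type_map['funs_io_type_fp32_or_int8']  # True
+#   toq.linear in node_type_map['funs_io_type_int8']      # True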
+
+def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
+
+    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
+        torch.quantize_per_tensor,
+        operator.getitem,
+    }
+
+    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
+        nn.Identity,
+    }
+
+    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
+        'to',
+        'dequantize',
+        'reshape',
+        'view',
+        'unsqueeze_',
+        'unsqueeze',
+        'transpose',
+        'squeeze_',
+        'squeeze',
+        'size',
+        'shape',
+        'resize_',
+        'repeat_interleave',
+        'repeat',
+        'permute',
+        'numel',
+        'mean',
+        'detach_',
+        'detach',
+        'contiguous',
+        'clamp',
+        'chunk',
+    }
+
+    return {
+        'funs_unmatchable': FUNS_UNMATCHABLE,
+        'mods_unmatchable': MODS_UNMATCHABLE,
+        'meths_unmatchable': METHS_UNMATCHABLE,
+    }
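+
+# Illustrative usage sketch of the map above (hedged example only; key names
+# and entries are taken from the sets defined in this function):
+#
+#   unmatchable = get_unmatchable_types_map()
+#   operator.getitem in unmatchable['funs_unmatchable']  # True
+#   'dequantize' in unmatchable['meths_unmatchable']     # True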
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eda789155e55da8ed1aaead191ffd7b51f045b13
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/n_shadows_utils.py
@@ -0,0 +1,1311 @@
+import torch
+import torch.fx
+from torch.fx import (
+    Node,
+    GraphModule,
+    Graph,
+)
+
+from torch.ao.ns.fx.utils import (
+    # TODO(future PR): make this work correctly for methods
+    get_target_type_str,
+    get_normalized_nth_input,
+)
+from torch.ao.ns.fx.ns_types import (
+    NSSingleResultValuesType,
+    NSResultsType,
+)
+from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.ao.quantization.fx.match_utils import _MatchResult
+from torch.utils._pytree import tree_map
+
+import collections
+import copy
+from typing import List, Dict, Set, Tuple, Callable, Any, Optional
+import operator
+
+SHADOW_NODE_NAME_PREFIX = 'shadow'
+SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper'
+
+# TODO(future PR): reuse existing mapping instead of creating a new one
+BINARY_FUNCTIONS = {
+    torch.add,
+    torch.Tensor.add,
+    operator.add,
+    torch.mul,
+    torch.Tensor.mul,
+    operator.mul,
+}
+
+def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
+    return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
+
+def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
+    return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
+
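+# Illustrative examples of the attribute naming scheme produced by the two
+# helpers above (values derived directly from the f-strings they return):
+#
+#   _get_attr_name(0, 1)          -> 'shadow_0_1'
+#   _get_attr_wrapper_name(0, 1)  -> 'shadow_wrapper_0_1'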
+
+class OutputProp:
+    """
+    Output propagation (modeled from shape propagation).
+
+    Given a GraphModule and example inputs, records the output flowing
+    through each node on `node.traced_result`.
+
+    Code based on the example from
+    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
+    """
+    def __init__(self, mod):
+        self.mod = mod
+        self.graph = mod.graph
+        self.modules = dict(self.mod.named_modules())
+
+    def propagate(self, *args):
+        args_iter = iter(args)
+        env : Dict[str, Node] = {}
+
+        def load_arg(a):
+            return torch.fx.graph.map_arg(a, lambda n: env[n.name])
+
+        def fetch_attr(target : str):
+            target_atoms = target.split('.')
+            attr_itr = self.mod
+            for i, atom in enumerate(target_atoms):
+                if not hasattr(attr_itr, atom):
+                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+                attr_itr = getattr(attr_itr, atom)
+            return attr_itr
+
+        for node in self.graph.nodes:
+            if node.op == 'placeholder':
+                result = next(args_iter)
+            elif node.op == 'get_attr':
+                result = fetch_attr(node.target)
+            elif node.op == 'call_function':
+                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
+            elif node.op == 'call_method':
+                self_obj, *args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = getattr(self_obj, node.target)(*args, **kwargs)
+            elif node.op == 'call_module':
+                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
+
+            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
+                node.traced_result = result
+
+            env[node.name] = result
+
+        return None
+
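+# Illustrative usage sketch of OutputProp (hedged example; `my_model` and the
+# input shape are hypothetical placeholders):
+#
+#   gm = torch.fx.symbolic_trace(my_model)
+#   OutputProp(gm).propagate(torch.randn(1, 3, 224, 224))
+#   for node in gm.graph.nodes:
+#       if hasattr(node, 'traced_result'):
+#           print(node.name, node.traced_result.shape)
+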
+def _get_dedup_subgraphs(
+    matches: Dict[str, _MatchResult]
+) -> Dict[str, List[Node]]:
+    # the original matches variable is unique by node, make it unique by subgraph
+    # instead
+    seen_nodes = set()
+    subgraphs_dedup = {}
+
+    # Dict items are not reversible before Python 3.8, so we build the
+    # reversed list manually to stay compatible with older Python versions
+    # TODO(future PR): try reversed(list(matches.items()))
+    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
+    for name, cur_match in matches.items():
+        matches_items_reversed.insert(0, (name, cur_match))
+
+    # Note: the order is important.  `matches` currently provides the matches
+    # in reverse order.  We would like to process the matches in non-reverse
+    # order, so that we can create an intuitive naming scheme, such as
+    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
+    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
+        was_seen = False
+        for node_or_tuple in cur_match[1]:
+
+            # cur_match[1] has an unusual type. Its annotation says `List[Node]`,
+            # but in practice it is not. Furthermore, the contents of this field
+            # can differ between the match results of multiple nodes of the same pattern.
+            #
+            # For example, for conv -> bn -> relu, we see
+            # match_results = {
+            #   'conv': (relu, [(bn, conv), relu], ...),
+            #   'bn': (relu, [(bn, conv), relu], ...),
+            #   'relu': (relu, [(bn, conv), relu], ...),
+            # }
+            #
+            # Ideally we should clean up the `find_matches` function to make
+            # this more intuitive. For the purposes of this prototype, we hack
+            # around it.
+
+            if isinstance(node_or_tuple, Node):
+                if node_or_tuple in seen_nodes:
+                    was_seen = True
+                seen_nodes.add(node_or_tuple)
+
+            else:
+                assert isinstance(node_or_tuple, tuple)
+                for node in node_or_tuple:
+                    assert isinstance(node, Node)
+                    if node in seen_nodes:
+                        was_seen = True
+                    seen_nodes.add(node)
+
+        if was_seen:
+            continue
+
+        # Start with the unusual type, convert it to [op_0, ..., op_n]
+        list_of_nodes = []
+
+        if len(cur_match[1]) == 1:
+            list_of_nodes = cur_match[1]
+        else:
+            assert len(cur_match[1]) == 2
+            # either (a, b), ((a, b), c), or (c, (a, b));
+            # we cannot make any assumptions about the order, since it is not
+            # clear how the _find_matches function populates this
+            # TODO(future PR): make this code less confusing, see discussion
+            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
+
+            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
+                nodes = [node_a, node_b, node_c]
+                first_node = None
+                mid_node = None
+                last_node = None
+                for n in nodes:
+                    prev_n = n.args[0]
+                    next_n = next(iter(n.users))
+                    if prev_n not in nodes:
+                        first_node = n
+                    elif next_n not in nodes:
+                        last_node = n
+                    else:
+                        mid_node = n
+                assert first_node is not None and mid_node is not None and \
+                    last_node is not None
+                assert mid_node.args[0] is first_node
+                assert last_node.args[0] is mid_node
+                return [last_node, mid_node, first_node]
+
+            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
+                # (a, b)
+                list_of_nodes = cur_match[1]
+            elif isinstance(cur_match[1][0], tuple):
+                # ((a, b), c)
+                node_a, node_b = cur_match[1][0]
+                node_c = cur_match[1][1]
+                list_of_nodes = _order_nodes(node_a, node_b, node_c)
+            elif isinstance(cur_match[1][1], tuple):
+                # (a, (b, c))
+                node_a, node_b = cur_match[1][1]
+                node_c = cur_match[1][0]
+                list_of_nodes = _order_nodes(node_a, node_b, node_c)
+
+        # [node_n, ..., node_0], note that the order is reversed
+        # to make it chronological for simple subgraphs
+        list_of_nodes.reverse()
+        subgraphs_dedup[name] = list_of_nodes
+
+    return subgraphs_dedup
+
+def _get_logger_for_subgraph(
+    model: GraphModule,
+    first_node: Node,
+    last_node: Node,
+    subgraph_idx: int,
+    subgraph_candidate_idx: int,
+    qconfig_str: str,
+    logger_cls: Callable,
+    fqn: Optional[str],
+) -> torch.nn.Module:
+    """
+    Given a model and a linear subgraph starting from `first_node` and
+    ending with `last_node`, creates a logger for the end of this
+    subgraph.
+    """
+    if fqn is None:
+        fqn = ''
+    logger_mod_orig = logger_cls(
+        first_node.name,  # ref_node_name
+        last_node.name,  # prev_node_name
+        f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}',  # model_name
+        'model',  # ref_name
+        get_target_type_str(last_node, model),  # prev_node_target_type
+        get_target_type_str(first_node, model),  # ref_node_target_type
+        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
+        0,  # index_within_arg
+        0,  # index_of_arg
+        fqn,  # fqn
+        qconfig_str,
+    )
+    # Usually we expect the user to add loggers, then calibrate, then convert,
+    # and then populate loggers.  This is why the loggers start disabled.
+    # TODO(future PR): reconsider the design to make this more intuitive.
+    logger_mod_orig.enabled = False
+    return logger_mod_orig
+
+def create_submodule_from_subgraph(
+    model: torch.nn.Module,
+    first_node: Node,
+    last_node: Node,
+) -> GraphModule:
+    """
+    Input: a model, and a linear subgraph within the model from first_node to
+      last_node.
+
+    Output: a new submodule containing a copy of the subgraph, with the inputs
+      to the first node becoming the inputs to the submodule, and all other
+      nodes in the subgraph being copied.
+
+    Example inputs:
+
+    `model`: a module with graph
+
+      x0 -> op1 -> x1 -> op2 -> x2
+             |
+            arg1
+
+    `first_node`: op1
+    `last_node`: op2
+
+    Example output: a new module with graph
+
+      input1 -> op1_copy -> x1 -> op2_copy -> output1
+                   |
+                  arg1
+    """
+
+    #
+    # create a blank GraphModule with an empty graph
+    #
+
+    class M(torch.nn.Module):
+        def forward(self, x):
+            pass
+
+    m = M()
+    gm = torch.fx.symbolic_trace(m)
+    g = gm.graph
+    for node in reversed(gm.graph.nodes):
+        g.erase_node(node)
+
+    #
+    # modify the graph to have a copy of our subgraph
+    #
+
+    cur_node_orig = first_node
+    cur_args_orig = cur_node_orig.args
+    cur_kwargs_orig = cur_node_orig.kwargs
+
+    cur_name_idx = 0
+
+    iteration_limit = 100
+    cur_iteration = 0
+
+    while True:
+        if cur_node_orig is first_node:
+            # we are at the first node, we need to set up graph inputs
+            # TODO(future): some graphs could have placeholders which are unrelated
+            # to the first node, need to handle this
+            cur_args_copy = []
+            cur_kwargs_copy = {}
+            seen_names: Set[str] = set()
+            old_name_to_new_node: Dict[str, Node] = {}
+
+            def _add_placeholder(
+                g: Graph, node: Node, seen_names, old_name_to_new_node
+            ):
+                # note: for graphs starting with patterns such as `y = x + x`, we
+                # need to ensure we do not add multiple placeholders with the
+                # same name
+                counter = 0
+                while node.name + '_' + str(counter) in seen_names:
+                    counter += 1
+                cur_name = node.name + '_' + str(counter)
+                seen_names.add(cur_name)
+                placeholder = g.placeholder(cur_name)
+                old_name_to_new_node[node.name] = placeholder
+                return placeholder
+
+            for arg in cur_node_orig.args:
+                if isinstance(arg, Node):
+                    p = _add_placeholder(
+                        g, arg, seen_names, old_name_to_new_node)
+                    cur_args_copy.append(p)
+                elif isinstance(arg, (list, tuple)):
+                    new_arg = []
+                    for inner_arg in arg:
+                        if isinstance(inner_arg, Node):
+                            new_arg.append(_add_placeholder(
+                                g, inner_arg, seen_names, old_name_to_new_node))
+                        else:
+                            new_arg.append(inner_arg)
+                    cur_args_copy.append(new_arg)
+                else:
+                    cur_args_copy.append(arg)
+
+            # TODO(future PR): handle non-normalized kwargs
+            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
+                if isinstance(kwarg, Node):
+                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
+                        g, kwarg, seen_names, old_name_to_new_node)
+                elif isinstance(kwarg, (list, tuple)):
+                    new_kwarg = []
+                    for inner_kwarg in kwarg:
+                        p = _add_placeholder(
+                            g, inner_kwarg, seen_names, old_name_to_new_node)
+                        new_kwarg.append(p)
+                    cur_kwargs_copy[kwarg_name] = new_kwarg
+                else:
+                    cur_kwargs_copy[kwarg_name] = kwarg
+
+            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
+        else:
+            # we are not at first node, first arg is from the previous node,
+            # and all other args are copied
+
+            # the current implementation is simplistic and cannot handle
+            # ops with two or more arguments that need to be passed from
+            # the previous op, so we assert that such ops are not present
+            assert cur_node_orig.target not in BINARY_FUNCTIONS
+
+            # at this point in the code, cur_node_copy is pointing to the copy
+            # of the previous node
+            # TODO(future PR): this is not handling complicated graphs correctly, need to
+            # look at actual relationships instead of assuming sequential graph
+            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
+            # for any fusion pattern which has them for a node that is not the
+            # first node.
+            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821
+
+            if len(cur_node_orig.args) > 1:
+                for arg in cur_node_orig.args[1:]:
+                    if isinstance(arg, torch.nn.Parameter):
+                        new_arg = arg.clone().detach()  # type: ignore[assignment]
+                        mod_name = f"mod_{cur_name_idx}"
+                        cur_name_idx += 1
+                        setattr(gm, mod_name, new_arg)
+                        # placeholders are created on the Graph, not the GraphModule
+                        new_arg_placeholder = g.placeholder(mod_name)
+                        cur_args_copy.append(new_arg_placeholder)
+                    elif isinstance(arg, (float, int, torch.dtype)):
+                        cur_args_copy.append(arg)
+                    else:
+                        raise AssertionError(f'arg of type {type(arg)} not handled yet')
+            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
+
+        # copy the node
+        if cur_node_orig.op == 'call_module':
+            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
+            orig_mod_copy = copy.deepcopy(orig_mod)
+            mod_name = f"mod_{cur_name_idx}"
+            setattr(gm, mod_name, orig_mod_copy)
+            cur_name_idx += 1
+            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+
+        elif cur_node_orig.op == 'call_function':
+            cur_node_copy = g.call_function(
+                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+
+        elif cur_node_orig.op == 'call_method':
+            cur_node_copy = g.call_method(
+                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+
+        else:
+            raise AssertionError(f'{cur_node_orig.op} not supported yet')
+
+        if cur_node_orig is last_node:
+            break
+
+        # go to next node
+        assert len(cur_node_orig.users.keys()) == 1, \
+            f'{cur_node_orig} has more than 1 user, not supported yet'
+        cur_node_orig = next(iter(cur_node_orig.users.keys()))
+        cur_args_orig = cur_node_orig.args
+        cur_kwargs_orig = cur_node_orig.kwargs
+
+        cur_iteration += 1
+        if cur_iteration > iteration_limit:
+            raise AssertionError('iteration limit exceeded')
+
+    # set up outputs
+    g.output(cur_node_copy)
+
+    gm.recompile()
+    return gm
+
+def create_one_transformed_and_logged_copy_of_subgraph(
+    mt: GraphModule,
+    subgraph_idx: int,
+    subgraph_candidate_idx: int,
+    first_node: Node,
+    last_node: Node,
+    fqn: Optional[str],
+    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
+    example_inputs: Any,
+    last_added_shadow_node_list: List[Optional[Node]],
+    custom_prepare_fn: Optional[Callable] = None,
+    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
+) -> None:
+    """
+    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
+    subgraph candidate copy and instruments it with loggers.
+
+    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
+    add a logger to the end.
+
+    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
+    prepare it with `prepare_fx`.
+    """
+
+    # TODO(future PR): move logger classes to utils to remove circular dependency
+    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger
+
+    if subgraph_candidate_idx == 0:
+        # idx = 0 is the floating point (original) version of the subgraph
+        # We keep the subgraph as is, and add a logger at the end
+
+        qconfig_str = ''
+        logger_mod_orig = _get_logger_for_subgraph(
+            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
+            qconfig_str, OutputLogger, fqn)
+
+        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, logger_mod_orig)
+        with mt.graph.inserting_after(last_node):
+            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
+            last_added_shadow_node_list[0] = new_node
+
+    else:
+        # idx > 0 means we have a candidate qconfig to try, so we need
+        # to make a copy of the subgraph, feed it with the right inputs,
+        # and add a logger at the end
+
+        # get the qconfig
+        # subtract one because the first candidate is the floating point
+        # version of the subgraph
+        node_name_to_qconfig = \
+            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
+        qconfig = node_name_to_qconfig[first_node.name]
+
+        # if no quantization is requested, skip
+        # TODO(future PR): deduplicate equivalent qconfigs that come from
+        #   different qconfig mapping objects
+        if qconfig is None:
+            return
+
+        qconfig_mapping = QConfigMapping().set_global(qconfig)
+
+        # create a copy of the submodule, wrapped in a separate module
+        orig_mod_copy_wrapped = create_submodule_from_subgraph(
+            mt, first_node, last_node)
+
+        # add a call to prepare_fx on the wrapper module
+        if custom_prepare_fn is None:
+            orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
+                orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs)
+        else:
+            if custom_prepare_kwargs is None:
+                custom_prepare_kwargs = {}
+            for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]:
+                assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs"
+            prepare_kwargs: Dict[str, Any] = {
+                "example_inputs": example_inputs,
+                "qconfig_mapping": qconfig_mapping
+            }
+            prepare_kwargs.update(custom_prepare_kwargs)
+            orig_mod_copy_wrapped = custom_prepare_fn(
+                orig_mod_copy_wrapped,
+                **prepare_kwargs)
+
+        # attach the wrapper to the model
+        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, orig_mod_copy_wrapped)
+
+        # add a call to the wrapper module from the parent graph
+        insert_after_node = last_added_shadow_node_list[0]
+        with mt.graph.inserting_after(insert_after_node):
+            # TODO(future PR): handle fusion patterns where non-first nodes
+            # need inputs
+
+            # pass in all node args and kwargs
+
+            new_args = []
+            for arg in first_node.args:
+                if isinstance(arg, Node):
+                    new_args.append(arg)
+                elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node):
+                    for inner_arg in arg:
+                        if isinstance(inner_arg, Node):
+                            new_args.append(inner_arg)
+
+            new_kwargs = {}
+            for name, old_kwarg in first_node.kwargs.items():
+                if isinstance(old_kwarg, Node):
+                    new_kwargs[name] = old_kwarg
+                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
+                    # TODO(future PR): clarify why we are adding kwargs to args
+                    new_args.extend(old_kwarg)
+
+            new_args = tuple(new_args)  # type: ignore[assignment]
+
+            new_node = mt.graph.call_module(
+                attr_name, args=new_args, kwargs=new_kwargs)
+
+        # add a logger to parent graph to observe the shadow wrapper
+        logger_mod_orig = _get_logger_for_subgraph(
+            mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
+            str(qconfig), OutputComparisonLogger, fqn)
+
+        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, logger_mod_orig)
+        with mt.graph.inserting_after(new_node):
+            logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={})
+            last_added_shadow_node_list[0] = logger
+
+    mt.recompile()
+
+def create_n_transformed_and_logged_copies_of_subgraph(
+    mt: GraphModule,
+    subgraph_idx: int,
+    match_name: str,
+    nodes_in_this_subgraph: List[Any],
+    qconfig_mappings: List[QConfigMapping],
+    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
+    custom_prepare_fn: Optional[Callable] = None,
+    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
+) -> None:
+    """
+    Given a model `mt` and a subgraph_idx, creates the needed copies
+    of the subgraph for all qconfigs, and instruments them with loggers.
+    """
+    # for now, assume that
+    # 1. the first node has one input
+    # 2. the last node has one output
+
+    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
+    # TODO(future PR): implement this
+    if any(
+        not isinstance(node, Node)
+        for node in nodes_in_this_subgraph
+    ):
+        return
+
+    first_node = nodes_in_this_subgraph[0]
+    last_node = nodes_in_this_subgraph[-1]
+    # We used output propagation to populate example values on each
+    # node. Use the example values from the previous node as the input
+    # to the current node.
+    prev_node = get_normalized_nth_input(first_node, mt, 0)
+    if isinstance(prev_node, list):
+        example_inputs = [x.traced_result for x in prev_node]
+    elif isinstance(prev_node, tuple):
+        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
+    else:
+        # currently some customer models do not have a traced_result in
+        # every node, so we have to guard for this case since we cannot
+        # quantize without an example input
+        # TODO(future PR): add a test case for this once we have an easy
+        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
+        # for additional context
+        if hasattr(prev_node, 'traced_result'):
+            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
+        else:
+            print(
+                'unable to get example input for node ' +
+                f'{first_node.format_node()}, skipping')
+            return
+
+    # If there are no quantization configs for this subgraph, skip adding
+    # loggers. This reduces memory usage for models where not all layers are
+    # quantized.
+    # TODO(future): consider making this configurable
+    found_at_least_one_qconfig = False
+    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
+
+        if subgraph_candidate_idx == 0:
+            # fp32 baseline does not need a qconfig
+            continue
+
+        # a. we have N shadows, so len(qconfig_mappings) is N
+        # b. we will have the fp32 layer + N shadows, so overall number of
+        #    (original_op) + (*shadows) will be N+1
+        # c. since `subgraph_candidate_idx` represents (b), we need
+        #    to subtract 1 to query from (a)
+        node_name_to_qconfig = \
+            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
+        qconfig = node_name_to_qconfig[first_node.name]
+        if qconfig is not None:
+            found_at_least_one_qconfig = True
+            break
+    if not found_at_least_one_qconfig:
+        print('unable to find at least one qconfig for node ' +
+              f'{first_node.format_node()}, skipping')
+        return
+
+    fqn = _maybe_get_fqn(first_node, mt)
+
+    # We want the results to contain the subgraphs in natural order,
+    # and the graph to also contain shadow wrappers and shadow loggers
+    # in natural order.
+    # If we just iterate in reverse, the graph will be in natural
+    # order but the eventual results will be in reverse order.
+    # So, we keep track of the last shadow logger we added and
+    # always insert after it.
+    last_added_shadow_node_list: List[Optional[Node]] = [None]
+    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
+
+        create_one_transformed_and_logged_copy_of_subgraph(
+            mt, subgraph_idx, subgraph_candidate_idx, first_node,
+            last_node, fqn, list_of_node_name_to_qconfig,
+            example_inputs, last_added_shadow_node_list, custom_prepare_fn,
+            custom_prepare_kwargs)
+
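+# Illustrative sketch of the attributes that the function above attaches to
+# `mt` for subgraph_idx=0 with a single candidate qconfig (names follow
+# _get_attr_name / _get_attr_wrapper_name):
+#
+#   mt.shadow_0_0          # OutputLogger observing the original fp32 subgraph
+#   mt.shadow_wrapper_0_1  # prepared (quantization-ready) copy of the subgraph
+#   mt.shadow_0_1          # OutputComparisonLogger comparing candidate 1 to fp32
+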
+def create_add_loggers_graph(
+    model: GraphModule,
+    subgraphs_dedup: Dict[str, List[Node]],
+    qconfig_mapping: QConfigMapping,
+    node_name_to_qconfig: Dict[str, QConfigAny],
+) -> None:
+    r"""
+    Given a model, a partition of the model graph (currently a set of matched
+    subgraphs) and instructions on how to transform each subgraph
+    (currently, quantizing it according to qconfig_mapping), modifies
+    the model graph to create an alternate path through the original graph,
+    with each of the subgraphs quantized.  This is useful for comparing the
+    propagation error of a transformation such as quantization.
+
+    For example, given layers op0 and op1, there are four cases when handling op1:
+    1. op0 and op1 quantized
+    2. op0 and op1 unquantized
+    3. op0 quantized, op1 unquantized
+    4. op0 unquantized, op1 quantized
+
+    Example input, case 1:
+
+    .. code::
+
+      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
+       \                        \          \                 \       # noqa: W605
+         ---> op0_1 -> x1_1 ----> clog    op1_1 -> x2_1 ----> clog
+
+    Example output, case 1:
+
+    .. code::
+
+      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
+       \                        \                           \        # noqa: W605
+         ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog
+
+    """
+    # TODO(future PR): move logger classes to utils to remove circular dependency
+    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger
+
+    def _get_subgraph_containing_node(node, subgraphs_dedup):
+        for subgraph in subgraphs_dedup.values():
+            if node in subgraph:
+                return subgraph
+        return None
+
+    # First, we need to create shadow branches, going from
+    #
+    #   x0 -> op0 -> x1 -> ...
+    #
+    #
+    # to
+    #
+    #   x0 -> op0_0 -> x1_0 -> log -> ...
+    #    \                     \
+    #      -> op0_1 -> x1_1 -> clog
+    #
+    # Later, the outputs of each shadow will be rerouted to calculate
+    # propagation error.
+
+    # Note: we cannot iterate over matched subgraphs because some nodes
+    # may not be matched. So, we iterate over nodes in the graph, and
+    # associate them to matched subgraphs if possible.
+
+    nodes_to_skip = set()
+    # for each subgraph, save a mapping from first node of subgraph
+    # to first and last node of the shadow of this subgraph
+    orig_first_node_to_shadow_in_node = {}
+    orig_first_node_to_shadow_out_node = {}
+    # need to record original list because we will mutate the graph as we go
+    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
+    cur_subgraph_idx = 0
+    for n in orig_nodes:
+        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
+            continue
+
+        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
+        insert_submodule_copy = False
+        if maybe_subgraph is not None:
+            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
+            for node_to_skip in maybe_subgraph:
+                nodes_to_skip.add(node_to_skip)
+            qconfig = node_name_to_qconfig[first_node.name]
+            if qconfig is not None:
+                insert_submodule_copy = True
+        else:
+            first_node, last_node = n, n
+
+        if insert_submodule_copy:
+            match_name = first_node.name
+            create_n_transformed_and_logged_copies_of_subgraph(
+                model, cur_subgraph_idx, match_name, maybe_subgraph,
+                [qconfig_mapping], [node_name_to_qconfig],
+                None, None  # type: ignore[arg-type]
+            )
+            # find the created shadow module and record it so we
+            # can find it easily in step 2
+            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
+            new_shadow_mod = None
+            for maybe_shadow_mod in model.graph.nodes:
+                if maybe_shadow_mod.op == 'call_module' and \
+                        maybe_shadow_mod.target == expected_shadow_target:
+                    new_shadow_mod = maybe_shadow_mod
+                    break
+            assert new_shadow_mod is not None
+            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
+            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod
+
+        else:
+            # create a copy of the subgraph by only copying FX nodes
+            # but not copying any parameters, to minimize memory usage
+            subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \
+                else [first_node]
+
+            # add a regular logger after last_node
+            qconfig_str = ''
+            subgraph_candidate_idx = 0
+            fqn = _maybe_get_fqn(first_node, model)
+            logger_mod_orig = _get_logger_for_subgraph(
+                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
+                qconfig_str, OutputLogger, fqn)
+            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
+            assert not hasattr(model, attr_name)
+            setattr(model, attr_name, logger_mod_orig)
+            insertion_point = last_node
+            with model.graph.inserting_after(insertion_point):
+                logger = model.graph.call_module(
+                    attr_name, args=(last_node,), kwargs={})
+                insertion_point = logger
+
+            # create a copy of the subgraph
+            cur_node_orig = first_node
+            cur_node_copy = None
+            first_node_copy = None
+            while cur_node_orig in subgraph_to_use:
+                # TODO(future PR): make this support all possible args/kwargs
+                if cur_node_orig is first_node:
+                    new_args = cur_node_orig.args
+                    new_kwargs = cur_node_orig.kwargs
+                else:
+                    first_arg_for_copy = cur_node_copy
+                    new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]])  # noqa: C409
+                    new_kwargs = cur_node_orig.kwargs
+                # make a copy of cur_node_orig
+                with model.graph.inserting_after(insertion_point):
+                    cur_node_copy = model.graph.create_node(
+                        cur_node_orig.op,
+                        cur_node_orig.target,
+                        new_args,
+                        new_kwargs,
+                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
+                    )
+                    if first_node_copy is None:
+                        first_node_copy = cur_node_copy
+                # since now only linear subgraphs are supported, all nodes
+                # except the last one must have only one user
+                if cur_node_orig != last_node:
+                    assert len(cur_node_orig.users.keys()) == 1
+                cur_node_orig = next(iter(cur_node_orig.users.keys()))
+                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
+                insertion_point = cur_node_copy
+
+            # add a comparison logger after last_node's copy
+            subgraph_candidate_idx = 1
+            logger_mod_orig = _get_logger_for_subgraph(
+                model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx,
+                qconfig_str, OutputComparisonLogger, fqn)
+            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
+            assert not hasattr(model, attr_name)
+            setattr(model, attr_name, logger_mod_orig)
+            with model.graph.inserting_after(insertion_point):
+                logger = model.graph.call_module(
+                    attr_name, args=(cur_node_copy, last_node), kwargs={})
+
+            # save the final node so we can use it in step 2
+            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
+            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy
+
+        cur_subgraph_idx += 1
+
+    model.recompile()
+
+    # Now, we go from
+    #
+    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
+    #    \                     \       \
+    #      -> op0_1 -> x1_1 -> clog      -> op1_1 -> ...
+    #
+    # to
+    #
+    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
+    #    \                     \
+    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
+    #
+    # sample values of key internal variables for the example above:
+    #
+    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
+    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
+    #
+    # note: for subgraphs with more than one node, in_node will be different
+    # compared to out_node
+
+
+    nodes_to_skip = set()
+    for n in orig_nodes:
+        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
+            continue
+
+        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
+        if maybe_subgraph is not None:
+            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
+            for node_to_skip in maybe_subgraph:
+                nodes_to_skip.add(node_to_skip)
+        else:
+            first_node, last_node = n, n
+
+        def maybe_remap_node_to_shadow(node):
+            """
+            If unshadowed `node` has a shadow version, return that. If not,
+            return `node`.
+            """
+            if not isinstance(node, Node):
+                # handle scalars
+                return node
+
+            if node.op in ('placeholder', 'get_attr'):
+                return node
+
+            # Find the shadowed version of this arg from the previous
+            # subgraph. For this, we need to:
+            # 1. navigate to the first node of the previous subgraph
+            # 2. get the output of the shadow wrapper which has (1) as an input
+
+            # For now, assume the arg is in matched subgraphs. In the
+            # future we may have to handle the case where this is not true.
+            prev_subgraph = _get_subgraph_containing_node(
+                node, subgraphs_dedup)
+            if prev_subgraph is None:
+                prev_subgraph = [node]
+            prev_first_node = prev_subgraph[0]
+            prev_shadow_output = \
+                orig_first_node_to_shadow_out_node[prev_first_node]
+            return prev_shadow_output
+
+        cur_shadow_input = \
+            orig_first_node_to_shadow_in_node[first_node]
+        assert cur_shadow_input is not None
+        cur_shadow_input.args = tree_map(
+            maybe_remap_node_to_shadow, cur_shadow_input.args)
+        cur_shadow_input.kwargs = tree_map(
+            maybe_remap_node_to_shadow, cur_shadow_input.kwargs)
+
+        model.recompile()
+
+def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
+    # input: shadow wrapper module
+    # output if shadow wrapper module has a weighted op:
+    #   (quantize_fn, (quantize_fn_args))
+    # output if shadow wrapper module doesn't have a weighted op:
+    #   None
+
+    # For now, assume that the weight is the second input
+    # to the shadow module. If that changes, we can fix it later.
+    placeholders_seen = 0
+    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
+        if shadow_n.op != 'placeholder':
+            continue
+
+        placeholders_seen += 1
+        if placeholders_seen != 2:
+            continue
+
+        # the subgraph looks like
+        #
+        #   _input_scale_1 = self._input_scale_1
+        #   _input_zero_point_1 = self._input_zero_point_1
+        #   quantize_per_channel = torch.quantize_per_channel(
+        #       w2_0, _input_scale_1, _input_zero_point_1,
+        #       0, torch.qint8)
+        #
+        #  we have `w2_0`, and are navigating this subgraph
+        #  to get `_input_scale_1` and `_input_zero_point_1`
+
+        assert len(shadow_n.users) == 1
+        quant_node = next(iter(shadow_n.users.keys()))
+        new_args: Any = None
+        if quant_node.target == torch.quantize_per_channel:
+            _weight, scale_node, zp_node, axis, dtype = quant_node.args
+            scale_val = getattr_from_fqn(
+                shadow_wrapper, scale_node.target)
+            zp_val = getattr_from_fqn(
+                shadow_wrapper, zp_node.target)
+            new_args = (scale_val, zp_val, axis, dtype)
+        else:
+            assert quant_node.target == torch.quantize_per_tensor
+            _weight, scale_node, zp_node, dtype = quant_node.args
+            scale_val = getattr_from_fqn(
+                shadow_wrapper, scale_node.target)
+            zp_val = getattr_from_fqn(
+                shadow_wrapper, zp_node.target)
+            new_args = (scale_val, zp_val, dtype)
+        return (quant_node.target, new_args)
+
+    return None
+
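+# Illustrative return value of the helper above for a per-channel quantized
+# weight (hedged example; the concrete tensors are hypothetical):
+#
+#   (torch.quantize_per_channel, (scale_tensor, zero_point_tensor, 0, torch.qint8))
+#
+# or None when the shadow wrapper does not contain a weighted op.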
+
+def extract_weight_comparison(m: GraphModule) -> NSResultsType:
+
+    # example graph:
+    #
+    #   w1 = self.w1
+    #   b1 = self.b1
+    #   linear = torch._C._nn.linear(x, w1, b1)
+    #   shadow_0_0 = self.shadow_0_0(linear)
+    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
+    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
+    #
+    # algorithm:
+    # 1. for each call_function node matching our allowlist:
+    # 2.   if corresponding shadow wrapper exists, extract the weight pair
+    #
+    # Note: this is not super robust, but that's ok because this is
+    # just for legacy customers who depend on the previous two-model version
+    # of this API. TBD if we need to make this robust.
+    # Note: modules are not supported, since existing customers only
+    # use functions.
+
+    # TODO(future PR): move this to config
+    weighted_ops = {
+        torch.nn.functional.linear,
+    }
+
+    results: NSResultsType = {
+        'model': {NSSingleResultValuesType.WEIGHT.value: {}}
+    }
+
+    for n in m.graph.nodes:  # type: ignore[union-attr]
+        if not (n.op == 'call_function' and n.target in weighted_ops):
+            continue
+
+        # Check if we have a corresponding shadow wrapper
+        # TODO(future PR, if needed): support kwargs
+        # TODO(future PR, if needed): support multiple shadow users
+        first_arg = n.args[0]
+        shadow_wrapper_node = None
+        for user in first_arg.users:
+            # TODO(before land): fix string match
+            if user.op == 'call_module' and \
+                    user.target.startswith('shadow_wrapper'):
+                shadow_wrapper_node = user
+                break
+
+        if shadow_wrapper_node is None:
+            continue
+
+        shadow_wrapper = getattr_from_fqn(
+            m, shadow_wrapper_node.target)  # type: ignore[arg-type]
+        weight_info = _get_weight_info_from_shadow_wrapper(
+            shadow_wrapper)
+        if weight_info is None:
+            continue
+
+        # get weight
+        w_node = n.args[1]
+        w_obj = getattr_from_fqn(m, w_node.target).detach()
+
+        # get a quantized version of weight
+        quant_fn, quant_fn_args_except_first = weight_info
+        new_args = (w_obj, *quant_fn_args_except_first)
+        w_obj_q = quant_fn(*new_args)
+
+        # add a comparison
+        ref_node_name = n.name
+        prev_node_name = n.name
+        ref_node_type = get_target_type_str(n, m)
+        prev_node_type = ref_node_type
+        fqn = None
+        if hasattr(m, '_node_name_to_scope'):
+            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
+        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
+        result_fp32 = {
+            'res_type': NSSingleResultValuesType.WEIGHT.value,
+            'values': [w_obj],
+            'prev_node_name': prev_node_name,
+            'prev_node_target_type': prev_node_type,
+            'ref_node_name': ref_node_name,
+            'ref_node_target_type': ref_node_type,
+            'index_within_arg': 0,
+            'index_of_arg': 0,
+            'fqn': fqn,
+            'qconfig_str': '',
+            'comparisons': [comparison],
+            'comparison_fn_name': 'sqnr',
+        }
+        result_q = {
+            'res_type': NSSingleResultValuesType.WEIGHT.value,
+            'values': [w_obj_q],
+            'prev_node_name': prev_node_name,
+            'prev_node_target_type': prev_node_type,
+            'ref_node_name': ref_node_name,
+            'ref_node_target_type': ref_node_type,
+            'index_within_arg': 0,
+            'index_of_arg': 0,
+            'fqn': fqn,
+            'qconfig_str': '',
+            'comparisons': [comparison],
+            'comparison_fn_name': 'sqnr',
+        }
+
+        # go from subgraph_n_1 to subgraph_n_0
+        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_')
+        name_fp32 = f"subgraph_{node_idx}_0"
+        name_q = f"subgraph_{node_idx}_1"
+
+        results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \
+            [result_fp32]
+        results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \
+            [result_q]
+
+    return results
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def group_results_by_subgraph(results: NSResultsType) -> Any:
+    """
+    Groups the flat logger results by subgraph, to prepare them for comparison.
+
+    Input:
+
+    {
+      'model': {
+        'node_output': {
+          'subgraph_0_0': [
+            'values': [torch.tensor(...), ...], ...
+            'ref_node_name': ...,
+            'ref_node_target_type': ...,
+            'qconfig_str': ...,
+            'comparisons': [], ...
+            'comparison_fn_name': '',
+            'fqn': '...',
+          ],
+          'subgraph_0_1': [
+            'values': [torch.tensor(...), ...], ...
+            'ref_node_name': ...,
+            'ref_node_target_type': ...,
+            'qconfig_str': ...,
+            'comparisons': [torch.tensor(...), ...], ...
+            'comparison_fn_name': '...',
+            'fqn': '...',
+          ],
+          ...
+        },
+      },
+    }
+
+    Output:
+    {
+      'subgraph_0': {
+        '0': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': None,
+          'comparisons': [torch.tensor(...), ...], ...
+          'comparison_fn_name': '...',
+          'fqn': '...',
+        },
+        '1': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '...',
+          'comparisons': [torch.tensor(...), ...], ...
+          'comparison_fn_name': '...',
+          'fqn': '...',
+        },
+      },
+    }
+
+    """
+    subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict)
+
+    # node_output or weight
+    key_to_use = next(iter(results['model'].keys()))
+
+    for subgraph_name_with_idx, subgraph_candidate_results in \
+            results['model'][key_to_use].items():
+
+        # convert from `subgraph_m_n` to `subgraph_m` and `n`
+        subgraph_str, subgraph_idx, subgraph_candidate_idx = \
+            subgraph_name_with_idx.split('_')
+        subgraph_name = f'{subgraph_str}_{subgraph_idx}'
+
+        subgraph_results = {
+            'ref_node_name': subgraph_candidate_results[0]['ref_node_name'],
+            'ref_node_target_type': subgraph_candidate_results[0]['ref_node_target_type'],
+            'fqn': subgraph_candidate_results[0]['fqn'],
+            'values': subgraph_candidate_results[0]['values'],
+            'qconfig_str': subgraph_candidate_results[0]['qconfig_str'],
+            'comparisons': subgraph_candidate_results[0]['comparisons'],
+            'comparison_fn_name': subgraph_candidate_results[0]['comparison_fn_name'],
+        }
+
+        subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = \
+            subgraph_results
+
+    return dict(subgraph_name_to_subgraph_results)
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def create_results_comparison(
+    results_grouped,
+) -> Any:
+    """
+    Input:
+
+    {
+      'subgraph_0': {
+        '0': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '',
+          'comparisons': [],
+          'comparison_fn_name': '',
+          'fqn': '...',
+        },
+        '1': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '...',
+          'comparisons': [torch.tensor(...), ...],
+          'comparison_fn_name': 'sqnr',
+          'fqn': '...',
+        },
+      },
+    }
+
+    Output:
+    {
+      'subgraph_0': {
+        'ref_node_name': '...',
+        'ref_node_target_type': '...',
+        'fqn': '...',
+        'candidates': {
+          '1': {
+            'qconfig_str': ...,
+            'comparison_fn_name': 'sqnr',
+            'cmp_raw': [..., ...],
+            'cmp_mean': ...,
+          },
+          ...,
+        },
+      },
+    }
+    """
+
+    results_comparison = {}
+
+    for subgraph_name, subgraph_results in results_grouped.items():
+
+        candidates = {}
+        for subgraph_inner_name, subgraph_inner_result in subgraph_results.items():
+            # skip comparing baseline to baseline
+            if subgraph_inner_name == '0':
+                continue
+
+            # we expect the comparisons to be precalculated from
+            # calibration, so we just fetch them here
+            cmp_raw = subgraph_inner_result['comparisons']
+            cmp_raw_tensor = torch.stack(cmp_raw)
+
+            candidates[subgraph_inner_name] = {
+                'qconfig_str': subgraph_inner_result['qconfig_str'],
+                'comparison_fn_name': subgraph_inner_result['comparison_fn_name'],
+                'cmp_raw': cmp_raw_tensor,
+                'cmp_mean': torch.mean(cmp_raw_tensor),
+            }
+
+        results_comparison[subgraph_name] = {
+            'ref_node_name': subgraph_results['0']['ref_node_name'],
+            'ref_node_target_type': subgraph_results['0']['ref_node_target_type'],
+            'fqn': subgraph_results['0']['fqn'],
+            'candidates': candidates,
+        }
+
+    return results_comparison
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def print_n_shadows_summary(
+    results_comparison,
+) -> None:
+    """
+    Input:
+
+    {
+      'subgraph_0': {
+        'ref_node_name': 'linear1',
+        'ref_node_target_type': '...',
+        'fqn': '...',
+        'candidates': {
+          '1': {
+            'qconfig_str': ...,
+            'comparison_fn_name': ...,
+            'cmp_raw': [45.0, 55.0],
+            'cmp_mean': 50.0,
+          },
+          ...,
+        },
+      },
+    }
+
+    Prints:
+
+    node_name | node_type | fqn | 0    | 1    | ...
+    linear1   | ...       | ... | 45.0 | 50.0 | ...
+    """
+
+    try:
+        from tabulate import tabulate
+    except ImportError:
+        print("`print_n_shadows_summary` relies on the library `tabulate`, "
+              "which could not be found on this machine. Run `pip "
+              "install tabulate` to install the library.")
+        return
+
+    results = []
+    for subgraph_data in results_comparison.values():
+        mean_all_candidates = [
+            candidate['cmp_mean']
+            for candidate_name, candidate in subgraph_data['candidates'].items()
+        ]
+
+        data_row = [
+            subgraph_data['ref_node_name'],
+            subgraph_data['ref_node_target_type'],
+            subgraph_data['fqn'],
+            *mean_all_candidates,
+        ]
+        results.append(data_row)
+
+    # each data_row contains three metadata columns (node_name, node_type, fqn)
+    # followed by the per-candidate means, so the number of candidate columns
+    # is len(data_row) - 3
+    max_candidate_idx_len = -1
+    for data_row in results:
+        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
+    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]
+
+    headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers]
+    print(tabulate(results, headers=headers))
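+
+# Illustrative end-to-end sketch of how the result helpers above are chained
+# together (hedged example; `extracted_results` stands for an NSResultsType
+# dict produced by the loggers):
+#
+#   grouped = group_results_by_subgraph(extracted_results)
+#   comparison = create_results_comparison(grouped)
+#   print_n_shadows_summary(comparison)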
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/ns_types.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/ns_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..564e5d2041ee7c353311907b84bea27d4643d75b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/ns_types.py
@@ -0,0 +1,64 @@
+import enum
+from typing import NamedTuple
+
+from torch.fx.graph import Node
+
+from typing import Dict, Any, List, Union, Callable
+
+class NSSingleResultValuesType(str, enum.Enum):
+    WEIGHT = 'weight'
+    NODE_OUTPUT = 'node_output'
+    NODE_INPUT = 'node_input'
+
+class NSSubgraph(NamedTuple):
+    start_node: Node
+    end_node: Node
+    base_op_node: Node
+
+# TODO(future PR): see if we can use typing_extensions's TypedDict instead
+# to properly type the various keys
+# {
+#   # one of NSSingleResultValuesType
+#   'type': 'weight',
+#   # the values of type specified above
+#   'values': [torch.tensor(...), ...],
+#   # name of the node directly before the logger
+#   'prev_node_name': 'linear1',
+#   # type of the underlying function or module
+#   'prev_node_target_type': torch.nn.functional.linear  # or torch.nn.Linear, etc
+#   # name of the node responsible for adding this logger
+#   # Note: this may differ from prev_node_name if we are logging inputs
+#   'ref_node_name': 'linear1',
+#   # index of this node within the arg of the input/output node
+#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
+#   'index_within_arg': 0,
+#   # index of this node within the args of the input/output node
+#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
+#   'index_of_arg': 0,
+#   # precomputed comparisons of logger values to reference values
+#   'comparisons': [torch.tensor(...), ...]
+#   # name of function used for precomputed comparisons
+#   'comparison_fn_name': 'sqnr',
+#   # string representation of qconfig responsible for creating this logger
+#   'qconfig_str': 'QConfig(...)',
+# }
+NSSingleResultType = Dict[str, Any]
+
+# {
+#   'layer_name_1': {  # subgraph name
+#     'node_output': {  # results type (node_output, node_input, weight)
+#       'model_name_a':  # model name
+#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
+#       'model_name_b':
+#          [NSSingleResultType, ...],
+#     },
+#   },
+# }
+#
+NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]
+
+# Defines the underlying target type of a node, for example:
+# `F.conv1d` for a `call_function` conv node
+# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
+# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
+NSNodeTargetType = Union[Callable, str]
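+
+# Example (illustrative sketch; the layer and model names are hypothetical):
+# an NSResultsType holding one logged output for one model could look like
+#
+#     example_results: NSResultsType = {
+#         'linear1': {
+#             'node_output': {
+#                 'model_a': [{
+#                     'type': 'node_output',
+#                     'values': [torch.tensor(...)],
+#                     'prev_node_name': 'linear1',
+#                     'ref_node_name': 'linear1',
+#                     'index_within_arg': 0,
+#                     'index_of_arg': 0,
+#                 }],
+#             },
+#         },
+#     }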
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06b0e806d0536c06210821b82001fcad0e627667
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/pattern_utils.py
@@ -0,0 +1,200 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+toq = torch.ops.quantized
+
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from torch.ao.quantization.backend_config import get_native_backend_config
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.utils import getattr_from_fqn
+from .ns_types import NSNodeTargetType
+from torch.ao.quantization import (
+    ObserverBase,
+    FakeQuantizeBase,
+)
+
+from typing import Dict, Tuple, Set, Callable, Any, Union, List
+
+
+def get_type_a_related_to_b(
+    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
+) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
+    # TODO(future PR): allow customizations
+    # TODO(future PR): reuse existing quantization mappings
+    # TODO(future PR): add the rest of modules and ops here
+    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()
+
+    for s in base_name_to_sets_of_related_ops.values():
+        s_list = list(s)
+        # add every bidirectional pair
+        for idx_0 in range(0, len(s_list)):
+            for idx_1 in range(idx_0, len(s_list)):
+                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
+                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
+
+    return type_a_related_to_b
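+
+# Example (illustrative sketch with a hypothetical input): for
+#     {'linear': {nn.Linear, F.linear}}
+# the returned set contains every ordered pair within the group, i.e.
+#     {(nn.Linear, nn.Linear), (nn.Linear, F.linear),
+#      (F.linear, nn.Linear), (F.linear, F.linear)}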
+
+
+NSFusionElType = Union[
+    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
+    str,  # call_method name, example: "dequantize"
+    Tuple[str, Any],  # call_method name and first argument, example: ("to", torch.float16)
+]
+NSFusionType = Union[
+    Tuple[NSFusionElType, NSFusionElType],
+    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
+]
+
+def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
+    """
+    Set of potential fusions, in reverse order.  The order is reversed
+    to match how fusion patterns are defined in quantization code.
+
+    Fusion format:
+    ((fusion_op_0, fusion_op_1), base_op_idx)
+
+    Where base_op_idx is the idx of the op we should use to match other related
+    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
+    of 0 represents the first op in regular (non-reverse) order, 1 represents the
+    second op, etc.
+    """
+    results: List[Tuple[NSFusionType, int]] = []
+
+    # Possible syntaxes:
+    # * single op: torch.nn.Conv2d
+    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
+    # For fusions, we only care about patterns composed of multiple ops.
+    # TODO(future PR): allow customizations from default patterns.
+    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
+
+    default_base_op_idx = 0
+    for quant_pattern in all_quant_patterns.keys():
+        # TODO: this is a temporary hack to flatten the patterns from quantization so
+        # that it works with the ns matcher function, maybe we should use `_is_match`
+        # in torch.ao.quantization.fx.match_utils to match the patterns
+        if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
+           isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
+            # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
+            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])
+
+        # Only patterns of multiple ops are fusions; ignore
+        # patterns which contain a single op (they get matched
+        # without caring about fusions).
+        if isinstance(quant_pattern, tuple):
+            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]
+
+        # For each pattern, add additional patterns with observers and
+        # fake quants at the end.
+        # TODO(future PR): if needed, implement matching for a node
+        #   having multiple output observers.
+        for cls in (ObserverBase, FakeQuantizeBase):
+            if isinstance(quant_pattern, tuple):
+                new_pattern = (cls, *quant_pattern)
+            else:
+                new_pattern = (cls, quant_pattern)
+            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]
+
+
+    # After this point, results contains values such as
+    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]
+
+    # Patterns for matching fp16 emulation are not specified in the quantization
+    # fusion mappings.  For now, define them here.
+    fp16_em_base_op_idx = 1
+    patterns_to_add = [
+        # linear-relu fp16 emulation:
+        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
+        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
+        # Conv-BN fusion (this happens outside of quantization patterns,
+        # which is why it is defined separately here).
+        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
+        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
+        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
+    ]
+    for p in patterns_to_add:
+        results.append(p)  # type: ignore[arg-type]
+        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
+        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]
+
+    return results
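+
+# Example (illustrative sketch): entries returned above pair a reversed pattern
+# with the index of its base op in regular (non-reversed) order, e.g.
+#
+#     ((nn.ReLU, nn.Conv2d), 0)    # conv2d -> relu, the conv is the base op
+#     ((("to", torch.float16), F.relu, F.linear, "dequantize"), 1)
+#                                  # dequantize -> linear -> relu -> to(fp16),
+#                                  # the linear (index 1) is the base op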
+
+
+def end_node_matches_reversed_fusion(
+    end_node: Node,
+    reversed_fusion: NSFusionType,
+    gm: GraphModule,
+    seen_nodes: Set[Node],
+) -> bool:
+    """
+    Returns true if a pattern ending with `end_node` matches
+    the fusion pattern.
+    """
+    cur_node = end_node
+    for fusion_idx in range(len(reversed_fusion)):
+        # each node can only belong to one matched pattern
+        if cur_node in seen_nodes:
+            return False
+
+        cur_fusion_el = reversed_fusion[fusion_idx]
+
+        if cur_node.op == 'call_function':
+            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
+                (not isinstance(cur_fusion_el, type))
+            if fusion_el_is_fun:
+                if cur_node.target != cur_fusion_el:
+                    return False
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+
+        elif cur_node.op == 'call_module':
+            fusion_el_is_mod = isinstance(cur_fusion_el, type)
+            if fusion_el_is_mod:
+                assert isinstance(cur_node.target, str)
+                target_mod = getattr_from_fqn(gm, cur_node.target)
+                if not isinstance(cur_fusion_el, type):
+                    return False
+                if not isinstance(target_mod, cur_fusion_el):
+                    return False
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+
+        elif cur_node.op == 'call_method':
+            fusion_el_is_meth_with_second_arg = \
+                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
+            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
+            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
+                if fusion_el_is_meth_without_args:
+                    if cur_node.target != cur_fusion_el:
+                        return False
+                else:
+                    assert isinstance(cur_fusion_el, tuple)
+                    if cur_node.target != cur_fusion_el[0]:
+                        return False
+                    elif len(cur_node.args) < 2:
+                        return False
+                    elif cur_node.args[1] != cur_fusion_el[1]:
+                        return False
+
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+        else:
+            return False
+
+    return True
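+
+# Example (illustrative sketch; the traced module and node lookup below are
+# hypothetical):
+#
+#     m = torch.fx.symbolic_trace(nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()))
+#     relu_node = next(n for n in m.graph.nodes
+#                      if n.op == 'call_module' and n.target == '1')
+#     end_node_matches_reversed_fusion(relu_node, (nn.ReLU, nn.Conv2d), m, set())
+#     # -> True, since walking backwards from the ReLU reaches the Conv2d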
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..98aafbb14d816ea6bda7d81f68f7862b33e2f699
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py
@@ -0,0 +1,243 @@
+from __future__ import annotations
+
+import copy
+from typing import Any, Callable, Dict, List, Union
+
+import torch
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
+from torch.ao.quantization.qconfig import QConfigAny
+
+__all__ = ["QConfigMultiMapping"]
+
+_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
+    "global_qconfig": "set_global",
+    "object_type_qconfigs": "set_object_type",
+    "module_name_regex_qconfigs": "set_module_name_regex",
+    "module_name_qconfigs": "set_module_name",
+    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
+}
+
+def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
+    to_remove = []
+    for index, cur_qconfig in enumerate(qconfig_list):
+        if cur_qconfig is None:
+            to_remove.append(index)
+            break
+        for checked_qconfig in qconfig_list[:index]:
+            if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
+                to_remove.append(index)
+                break
+    for index in to_remove[::-1]:
+        qconfig_list.pop(index)
+
+class QConfigMultiMapping:
+    """
+    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
+    so that multiple QConfigs can be specified for each QConfig matching style.
+
+    The user can specify QConfigs using the following methods (in increasing match priority):
+
+        ``set_global`` : sets the global (default) QConfigs
+
+        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name
+
+        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string
+
+        ``set_module_name`` : sets the QConfigs for modules matching the given module name
+
+        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
+        of the given module name, object type, and the index at which the module appears
+
+    Note: The set methods are used in the same way as in QConfigMapping, except that a list of QConfigs is
+    passed in rather than a single QConfig.
+
+    Example usage::
+
+        qconfig_mapping = QConfigMultiMapping()
+            .set_global([qconfig1, qconfig2])
+            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
+            .set_object_type(torch.nn.ReLU, [qconfig1])
+            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
+            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
+            .set_module_name("module1", [None])
+            .set_module_name("module2", [qconfig2])
+            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
+
+    """
+
+    def __init__(self):
+        # initialize this with 1 QConfigMapping to avoid corner cases
+        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]
+
+    def _handle_list_size_mismatch(
+        self, qconfig_list: List[QConfigAny], style: str
+    ) -> None:
+        # this method handles cases where the size of qconfig_list does not match
+        # the size of qconfig_mappings_list.
+        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
+        # qconfig C as an object_type_qconfig for conv ops. If we internally store
+        # 1 QConfigMapping with A and C and another with just B, then the
+        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.
+
+        # we avoid this by maintaining the invariant that if any QConfigMapping
+        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
+        # have either a qconfig or None for that same style+key. In the above
+        # example, a None qconfig would prevent the unwanted match in the
+        # second QConfigMapping
+
+        if len(qconfig_list) > len(self.qconfig_mappings_list):
+            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings
+
+            # Add new QConfigMappings (initialized so we maintain the `invariant`)
+
+            new_qconfig_mapping = QConfigMapping()
+            # searches other QConfigMappings for qconfig style+keys
+            # that need to be inserted as `None` into the new QConfigMapping
+            for qconfig_mapping in self.qconfig_mappings_list:
+
+                # global_qconfig has None by default
+                for check_style in _QCONFIG_STYLE_ORDER[1:]:
+                    qconfigs_dict = getattr(qconfig_mapping, check_style)
+                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
+                    for key in qconfigs_dict:
+                        target_qconfigs_dict[key] = None
+                break
+
+            # insert copies of this new QConfigMapping until all entries
+            # in qconfig_list can fit among the QConfigMappings
+            while len(qconfig_list) > len(self.qconfig_mappings_list):
+                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
+        else:
+            # Case: we have fewer qconfigs in qconfig_list than QConfigMappings
+
+            # pad qconfig_list with `None` until length is same
+            while len(qconfig_list) < len(self.qconfig_mappings_list):
+                qconfig_list.append(None)
+
+    # this function applies the insertion method across each QConfigMapping
+    def _insert_qconfig_list(
+        self,
+        style: str,
+        args: List[Union[str, int, Callable]],
+        qconfig_list: List[QConfigAny],
+    ) -> None:
+
+        # we remove duplicates and None to make the ordering of qconfigs
+        # deterministic upon insertion.
+        _remove_duplicates_and_none(qconfig_list)
+
+        self._handle_list_size_mismatch(qconfig_list, style)
+        method_name = _QCONFIG_STYLE_TO_METHOD[style]
+        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
+            # uses QConfigMapping set method to insert qconfig
+            set_method = getattr(qconfig_mapping, method_name)
+            set_method(*args, qconfig)
+
+    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
+        """
+        Set global QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
+        """
+        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
+        return self
+
+    def set_object_type(
+        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set object type QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
+        """
+        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
+        return self
+
+    def set_module_name_regex(
+        self, module_name_regex: str, qconfig_list: List[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name_regex QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
+        """
+        self._insert_qconfig_list(
+            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
+        )
+        return self
+
+    def set_module_name(
+        self, module_name: str, qconfig_list: List[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
+        """
+        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
+        return self
+
+    def set_module_name_object_type_order(
+        self,
+        module_name: str,
+        object_type: Callable,
+        index: int,
+        qconfig_list: List[QConfigAny],
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name_object_type_order QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
+        """
+        self._insert_qconfig_list(
+            "module_name_object_type_order_qconfigs",
+            [module_name, object_type, index],
+            qconfig_list,
+        )
+        return self
+
+    def __repr__(self):
+        return (
+            self.__class__.__name__ +
+            " [" +
+            "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) +
+            "\n]"
+        )
+
+    @classmethod
+    def from_list_qconfig_mapping(
+        cls, qconfig_mapping_list: List[QConfigMapping]
+    ) -> QConfigMultiMapping:
+        """
+        Creates a QConfigMultiMapping from a list of QConfigMappings
+        """
+        new_qconfig_multi_mapping = cls()
+
+        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
+            qconfig_mapping_list
+        )
+
+        # we need to avoid the issue described in _handle_list_size_mismatch,
+        # so we reinsert all the qconfigs using the QConfigMultiMapping
+        # set methods
+
+        # go through all qconfig styles
+        # note: global can be ignored since it is None by default
+        for style in _QCONFIG_STYLE_ORDER[1:]:
+
+            # gather all key+qconfigs for current style
+            # into qconfig_dict_list
+            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
+            for qconfig_mapping in qconfig_mapping_list:
+                qconfig_dict = getattr(qconfig_mapping, style)
+                for key, qconfig in qconfig_dict.items():
+                    if key not in qconfig_dict_list:
+                        qconfig_dict_list[key] = []
+                    qconfig_dict_list[key].append(qconfig)
+
+            # reinsert all gathered key+qconfigs
+            set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
+            set_method = getattr(new_qconfig_multi_mapping, set_method_name)
+            for key, qconfig_list in qconfig_dict_list.items():
+                if isinstance(key, tuple):
+                    set_method(*key, qconfig_list)
+                else:
+                    set_method(key, qconfig_list)
+
+        return new_qconfig_multi_mapping
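+
+# Example (illustrative sketch; the qconfig below is only one possible choice):
+#
+#     from torch.ao.quantization import get_default_qconfig
+#     qconfig_a = get_default_qconfig("fbgemm")
+#     multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping([
+#         QConfigMapping().set_global(qconfig_a),
+#         QConfigMapping().set_object_type(torch.nn.Linear, qconfig_a),
+#     ])
+#     # multi_mapping.qconfig_mappings_list now contains one QConfigMapping per
+#     # list entry, padded with None where needed to preserve the invariant
+#     # described in _handle_list_size_mismatch.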
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/utils.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d64d572e58dcf246fd357b0a1072c29999e5d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/utils.py
@@ -0,0 +1,533 @@
+import enum
+import operator
+
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.quantized as nnq
+
+toq = torch.ops.quantized
+from typing import Tuple, Callable, Dict, Set, List, Optional, Union
+
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+from torch.ao.quantization import (
+    ObserverBase,
+    FakeQuantizeBase,
+)
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.ao.quantization.observer import _is_activation_post_process
+
+from .ns_types import NSNodeTargetType, NSResultsType
+
+# TODO(future PR): consider deleting this enum and using the torch types
+# directly.  This might be tricky because it is not a one to one mapping.
+class NodeInputOrOutputType(enum.Enum):
+    FP32 = enum.auto()  # torch.float
+    INT8 = enum.auto()  # torch.qint8 or torch.quint8
+    FP16 = enum.auto()  # torch.float16
+    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
+    # TODO(future PR): while these functions can support multiple dtypes,
+    #   for the purposes of numerical debugging we want to get the actual
+    #   dtype used in the model. We will likely need some kind of dtype
+    #   propagation to estimate this.
+    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
+    # TODO(future PRs): dynamic quant, fake quant, etc
+
+
+def get_node_first_input_and_output_type(
+    node: Node,
+    gm: GraphModule,
+    logger_cls: Callable,
+    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
+) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
+
+    # TODO(future PR): clean this up
+    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
+    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
+    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
+    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
+    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
+    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
+    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]
+
+    if node.op == "call_function":
+        if node.target in FUNS_IO_TYPE_FP32:
+            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
+        if node.target in FUNS_IO_TYPE_FP16:
+            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
+        elif node.target in FUNS_IO_TYPE_INT8:
+            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
+        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+        else:
+            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+    elif node.op == "call_module":
+        assert node.op == "call_module"
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        is_known_fp32_or_int8_input_module = any(
+            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
+        )
+        if (
+            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
+            or is_known_fp32_or_int8_input_module
+        ):
+            # A logger or observer's input and output type is the output
+            # type of the preceding node.
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+        is_known_fp32_input_module = any(
+            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
+        )
+        is_known_int8_input_module = any(
+            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
+        )
+        if is_known_fp32_input_module:
+            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
+        elif is_known_int8_input_module:
+            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
+        else:
+            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+    elif node.op == "call_method":
+        if node.target == "dequantize":
+            # Dequantize is a special node because it allows multiple input types.
+            # So, we look up the output type of the previous node and return that
+            # as the input type of this node instance.
+            prev_node = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(prev_node, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, NodeInputOrOutputType.FP32)
+
+        elif node.target == "to":
+            # to is a special node because it allows multiple input types.
+            # So, we look up the output type of the previous node and return that
+            # as the input type of this node instance. We also look up the target
+            # of to and return the correct output type.
+            prev_node = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(prev_node, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
+
+            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
+            assert (
+                cur_node_dtype_target is torch.float16
+            ), f"{cur_node_dtype_target} handling needs to be added"
+
+            return (prev_node_output_type, NodeInputOrOutputType.FP16)
+
+        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+
+        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+    else:
+        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+
+def get_node_input_qparams(
+    node: Node,
+    gm: GraphModule,
+    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
+) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
+    """
+    Returns the qparams (scale, zero_point) of the first input to `node`,
+    if they can be inferred from the graph.
+    """
+    prev_node = get_normalized_nth_input(node, gm, 0)
+
+    if not isinstance(prev_node, Node):
+        return None
+
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
+
+    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
+        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
+        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
+        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
+        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
+        scale_obj = getattr_from_fqn(gm, scale_node.target)
+        zp_obj = getattr_from_fqn(gm, zp_node.target)
+        return (scale_obj, zp_obj)
+
+    if prev_node.op == "call_function":
+
+        # quantize - read the args directly
+        if prev_node.target == torch.quantize_per_tensor:
+            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
+        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
+            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
+
+        return None
+        # TODO(future PR): handle more functionals
+        # TODO(future PR): handle functional ops which inherit qparams from input
+
+    elif prev_node.op == "call_module":
+
+        # get type of the module
+        assert isinstance(prev_node.target, str)
+        module_obj = getattr_from_fqn(gm, prev_node.target)
+        if isinstance(
+            module_obj,
+            (
+                nnq.Linear,
+                nnq.Conv1d,
+                nnq.Conv2d,
+                nniq.ConvReLU2d,
+                nnq.Conv3d,
+                nnq.BatchNorm2d,
+                nnq.BatchNorm3d,
+                nnq.ConvTranspose1d,
+                nnq.ConvTranspose2d,
+                nnq.ELU,
+                nnq.GroupNorm,
+                nnq.InstanceNorm1d,
+                nnq.InstanceNorm2d,
+                nnq.InstanceNorm3d,
+                nnq.LayerNorm,
+                nnq.Hardswish,
+                nnq.LeakyReLU,
+                nnq.ReLU6,
+                nniq.BNReLU2d,
+                nniq.BNReLU3d,
+                nniq.ConvReLU1d,
+                nniq.ConvReLU2d,
+                nniq.ConvReLU3d,
+                nniq.LinearReLU,
+            ),
+        ):
+            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]
+
+        is_known_fp32_or_int8_input_module = any(
+            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
+        )
+        if is_known_fp32_or_int8_input_module:
+            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)
+
+    return None
+
+
+def return_first_non_observer_node(
+    node: Node,
+    gm: GraphModule,
+) -> Node:
+    """
+    If node is not an observer, returns it.  If node is an observer,
+    navigates up the graph and returns the first parent which is not an
+    observer.  For example,
+
+    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
+    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
+    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
+    """
+    if node.op == "call_module":
+        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
+        if _is_activation_post_process(node_obj):
+            assert len(node.args) == 1
+            assert isinstance(node.args[0], Node)
+            node = node.args[0]
+            # code duplication intended, not worth refactoring
+            assert isinstance(node.target, str)
+            node_obj = getattr_from_fqn(gm, node.target)
+            if _is_activation_post_process(node_obj):
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], Node)
+                node = node.args[0]
+    return node
+
+
+def get_number_of_non_param_args(
+    node: Node,
+    gm: GraphModule,
+) -> int:
+    """
+    Assumes that all non-param args occur first. Returns the number of
+    non-param args expected for a node.  For example, for
+
+      F.linear(x, weight, bias)
+
+    Returns 1, because x is a non-param arg and weight and bias are params.
+    For
+
+      lstm_mod(x, hid)
+
+    Returns 2, because both x and hid are non-param args.
+    """
+    if node.op == "call_module":
+        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
+        if isinstance(node_obj, nn.LSTM):
+            return 2
+
+    # default is 1
+    return 1
+
+
+def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
+    """
+    Returns the indices of args of the node which we should attach
+    loggers to, if input logging is enabled.
+
+    For example,
+    * for (x + y), returns [0, 1]
+    * for (1 + y), returns [1]
+    * for (x + 1), returns [0]
+    * for (linear(x, w, b)) returns [0]
+    * by default, returns [0]
+    """
+    if len(node.args) == 0:
+        return []
+    if node.op == "call_function" and (
+        # TODO(future PR): use relationship map instead of hardcoding
+        node.target in (torch.add, torch.ops.quantized.add, operator.add)
+        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
+    ):
+        result = []
+        for i in range(2):
+            if type(node.args[i]) == Node:
+                result.append(i)
+        return result
+    return [0]
+
+
+def get_target_type_str(node: Node, gm: GraphModule) -> str:
+    """
+    Returns a string representation of the type of the function or module
+    pointed to by this node, or '' for other node types.
+    """
+    target_type = ""
+    if node.op in ("call_function", "call_method"):
+        target_type = torch.typename(node.target)
+    elif node.op == "call_module":
+        assert isinstance(node.target, str)
+        target_mod = getattr_from_fqn(gm, node.target)
+        target_type = torch.typename(target_mod)
+    return target_type
+
+
+def rekey_logger_info_on_node_name_of_model(
+    results: NSResultsType,
+    model_name: str,
+) -> NSResultsType:
+    """
+    Rekeys the layer name of a results dictionary to use node names
+    from `model_name`.
+
+    For example, transforms
+
+        {'base_op_1_0': {'node_output': {'model_a':
+          [{'ref_node_name': 'linear1', ...}]}}}
+
+    into
+
+        {'linear1': {'node_output': {'model_a':
+          [{'ref_node_name': 'linear1', ...}]}}}
+
+    Note: we cannot use these node names directly because they are not
+    guaranteed to be consistent across models. This is why we extract
+    the results first and rekey afterwards.
+    """
+    new_results = {}
+    for old_layer_name, result_type_to_results in results.items():
+        new_layer_name = None
+        for model_name_to_results in result_type_to_results.values():
+            for cur_model_name, list_of_results in model_name_to_results.items():
+                if cur_model_name == model_name:
+                    assert len(list_of_results)
+                    new_layer_name = list_of_results[0]["ref_node_name"]
+                else:
+                    continue
+        if new_layer_name is not None:
+            new_results[new_layer_name] = result_type_to_results
+        else:
+            new_results[old_layer_name] = result_type_to_results
+    return new_results
+
+
+def maybe_add_missing_fqns(results: NSResultsType) -> None:
+    """
+    If `fqn` entries are filled in for one of the models in `results`, copies
+    them over to any models which do not have them filled out.
+
+    A common use case benefitting from this is comparing a model prepared by
+    quantization to a quantized model. In this case, the model prepared by
+    quantization would have `fqn` entries, and the quantized model would not.
+    """
+
+    # Check in the first result to find any model with fqn entries defined.
+    model_name_with_fqns = None
+    for result_type_to_results in results.values():
+        for model_name_to_results in result_type_to_results.values():
+            for model_name, model_results in model_name_to_results.items():
+                if len(model_results) > 0:
+                    if model_results[0]["fqn"] is not None:
+                        model_name_with_fqns = model_name
+                        break
+            break
+        break
+
+    if model_name_with_fqns:
+        for result_type_to_results in results.values():
+            for model_name_to_results in result_type_to_results.values():
+                ref_model_results = model_name_to_results[model_name_with_fqns]
+                for model_name, model_results in model_name_to_results.items():
+                    if model_name == model_name_with_fqns:
+                        continue
+                    for i in range(len(model_results)):
+                        fqn = ref_model_results[i]["fqn"]
+                        model_results[i]["fqn"] = fqn
+
+
+def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
+    def inner(*args, **kwargs):
+        a0, a1, *a_other = args
+
+        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
+            isinstance(a0, list) and isinstance(a1, list)
+        ):
+            results = []
+            for el0, el1 in zip(a0, a1):
+                new_args = (el0, el1, *a_other)
+                results.append(inner(*new_args, **kwargs))
+            return results
+
+        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
+            if a0.is_quantized:
+                a0 = a0.dequantize()
+            if a1.is_quantized:
+                a1 = a1.dequantize()
+
+        # for the purposes of this util, only handle floats
+        if a0.dtype != torch.float or a1.dtype != torch.float:
+            return None
+
+        new_args = (a0, a1, *a_other)
+        return f(*new_args, **kwargs)
+
+    return inner
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the SQNR between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    Ps = torch.norm(x)
+    Pn = torch.norm(x - y)
+    return 20 * torch.log10(Ps / Pn)
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the normalized L2 error between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the cosine similarity between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    # For convolutions, the shape of the quantized weight has one additional
+    # dimension compared to the shape of the fp32 weight. Match the shapes
+    # to enable cosine similarity comparison.
+    x = x.reshape(1, -1)
+    y = y.reshape(1, -1)
+    return torch.nn.functional.cosine_similarity(x, y)
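+
+# Example (illustrative sketch): comparing an fp32 tensor against a noisy copy.
+#
+#     x = torch.randn(4, 4)
+#     y = x + 0.01 * torch.randn(4, 4)
+#     compute_sqnr(x, y)                 # large value -> y is close to x
+#     compute_normalized_l2_error(x, y)  # small value -> y is close to x
+#     compute_cosine_similarity(x, y)    # close to 1.0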
+
+def op_type_supports_shadowing(node: Node) -> bool:
+    if node.op == 'call_function':
+        if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
+            # shadowing for ops with multiple tensor inputs is not implemented yet
+            return False
+    return True
+
+def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
+    """
+    Given a node, gets the n'th input to that node, normalizing
+    args and kwargs to the best of its ability.
+    """
+    try:
+        norm_args_and_kwargs = node.normalized_arguments(
+            gm, normalize_to_only_use_kwargs=True)
+        if norm_args_and_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_and_kwargs
+            assert len(norm_args) + len(norm_kwargs) > idx
+            if idx < len(norm_args):
+                return norm_args[idx]
+            else:
+                # note: in Python 3.7+ dicts are ordered
+                return list(norm_kwargs.values())[idx]
+        else:
+            assert len(node.args) + len(node.kwargs) > idx
+            if idx < len(node.args):
+                return node.args[idx]  # type: ignore[return-value]
+            else:
+                kwargs_idx = idx + len(node.args)
+                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
+    except RuntimeError:
+        # this RuntimeError happens when node argument normalization
+        # requires typehints to proceed, such as for torch.add where
+        # either the first, second or both arguments could be tensors
+        assert len(node.args) + len(node.kwargs) > idx
+        if idx < len(node.args):
+            return node.args[idx]  # type: ignore[return-value]
+        else:
+            kwargs_idx = idx + len(node.args)
+            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
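+
+# Example (illustrative sketch; the traced module below is hypothetical):
+#
+#     gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(2, 2)))
+#     linear_node = next(n for n in gm.graph.nodes if n.op == 'call_module')
+#     get_normalized_nth_input(linear_node, gm, 0)
+#     # -> the placeholder node feeding the linear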
diff --git a/MLPY/Lib/site-packages/torch/ao/ns/fx/weight_utils.py b/MLPY/Lib/site-packages/torch/ao/ns/fx/weight_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e242d2f58bf6ba328b079f2ed8d026b69150dade
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/ns/fx/weight_utils.py
@@ -0,0 +1,275 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.quantized as nniq
+toq = torch.ops.quantized
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from .utils import (
+    get_target_type_str,
+    getattr_from_fqn,
+    return_first_non_observer_node,
+)
+
+from .ns_types import (
+    NSSingleResultValuesType,
+    NSSingleResultType,
+)
+
+from typing import List, Optional, Dict, Callable
+
+def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
+    return mod.weight.detach()  # type: ignore[operator]
+
+def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
+    return mod[0].weight.detach()  # type: ignore[index]
+
+def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
+    return mod._weight_bias()[0]  # type: ignore[operator]
+
+def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
+    res = []
+    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
+        if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
+            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
+            res.append(param_value)
+    return res
+
+def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
+    res = []
+    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
+        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
+        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
+    return res
+
+def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
+    if (
+        isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d))
+    ):
+        return mod.weight.detach()
+    elif (
+        isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d))
+    ):
+        return mod[0].weight.detach()
+    else:
+        return mod._weight_bias()[0]  # type: ignore[operator]
+
+def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
+    if isinstance(mod, nn.Linear):
+        return mod.weight.detach()
+    elif isinstance(mod, nni.LinearReLU):
+        return mod[0].weight.detach()
+    else:
+        return mod._weight_bias()[0]  # type: ignore[operator]
+
+def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
+    # TODO(future PR): make more generic, handle everything
+    if isinstance(mod, nn.LSTM):
+        res = []
+        for idx, param_name in enumerate(mod._flat_weights_names):
+            if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
+                param_value = mod._flat_weights[idx].detach()
+                res.append(param_value)
+        return res
+    else:
+        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
+        res = []
+        for weight_value in mod._all_weight_values:
+            res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
+            res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
+        return res
+
+def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # traverse backwards from the weight arg, accounting for any observers
+    weight_arg_node = node.args[1]
+    assert isinstance(weight_arg_node, Node)
+    weight_node = return_first_non_observer_node(weight_arg_node, gm)
+    assert isinstance(weight_node, Node)
+    assert weight_node.op == 'get_attr'
+    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+    return weight.detach()
+
+def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # qconv state is arg 1
+    qconv_state_node = node.args[1]
+    assert isinstance(qconv_state_node, Node)
+    assert qconv_state_node.op == 'get_attr'
+    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
+    return qconv_state_obj.weight()
+
+def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # traverse backwards from the weight arg, accounting for any observers
+    # supported patterns:
+    # weight -> obs -> linear
+    # weight -> to(torch.float16) -> dequantize -> linear
+    linear_second_arg = node.args[1]
+    assert isinstance(linear_second_arg, Node)
+
+    if linear_second_arg.op == 'call_module':
+        # weight -> obs -> linear
+        weight_arg_node = node.args[1]
+        assert isinstance(weight_arg_node, Node)
+        weight_node = weight_arg_node.args[0]
+        assert isinstance(weight_node, Node)
+        assert weight_node.op == 'get_attr'
+        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+        return weight.detach()
+    elif linear_second_arg.op == 'call_method':
+        # weight -> to(torch.float16) -> dequantize -> linear
+        assert linear_second_arg.op == 'call_method'
+        dequant_node = node.args[1]
+        assert isinstance(dequant_node, Node)
+        to_fp16_node = dequant_node.args[0]
+        assert isinstance(to_fp16_node, Node)
+        # extract the dtype, so we can cast to it before returning
+        target_dtype = to_fp16_node.args[1]
+        weight_node = to_fp16_node.args[0]
+        assert isinstance(weight_node, Node)
+        assert weight_node.op == 'get_attr'
+        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+        # return the weight with fp16 cast
+        return weight.detach().to(target_dtype)
+    else:
+        assert linear_second_arg.op == 'get_attr'
+        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
+        return weight.detach()
+
+def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # packed weight is arg 1
+    packed_weight_node = node.args[1]
+    assert isinstance(packed_weight_node, Node)
+    assert packed_weight_node.op == 'get_attr'
+    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
+    # TODO(future PR): why does packed_weight.unpack() not work?
+    (weight, _bias), _name = packed_weight.__getstate__()
+    return weight
+
+def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
+
+    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
+        'call_module': {
+            # Conv1d
+            nn.Conv1d: mod_weight_detach,
+            nni.ConvReLU1d: mod_0_weight_detach,
+            nnq.Conv1d: mod_weight_bias_0,
+            nnqat.Conv1d: mod_weight_detach,
+            nniqat.ConvBn1d: mod_weight_detach,
+            nniqat.ConvBnReLU1d: mod_weight_detach,
+            nniqat.ConvReLU1d: mod_weight_detach,
+            nniq.ConvReLU1d: mod_weight_bias_0,
+            # Conv2d
+            nn.Conv2d: mod_weight_detach,
+            nni.ConvReLU2d: mod_0_weight_detach,
+            nnq.Conv2d: mod_weight_bias_0,
+            nnqat.Conv2d: mod_weight_detach,
+            nniqat.ConvBn2d: mod_weight_detach,
+            nniqat.ConvBnReLU2d: mod_weight_detach,
+            nniqat.ConvReLU2d: mod_weight_detach,
+            nniq.ConvReLU2d: mod_weight_bias_0,
+            # Conv3d
+            nn.Conv3d: mod_weight_detach,
+            nni.ConvReLU3d: mod_0_weight_detach,
+            nnq.Conv3d: mod_weight_bias_0,
+            nnqat.Conv3d: mod_weight_detach,
+            nniqat.ConvBn3d: mod_weight_detach,
+            nniqat.ConvBnReLU3d: mod_weight_detach,
+            nniqat.ConvReLU3d: mod_weight_detach,
+            nniq.ConvReLU3d: mod_weight_bias_0,
+            # Linear
+            nn.Linear: mod_weight_detach,
+            nnq.Linear: mod_weight_bias_0,
+            nni.LinearReLU: mod_0_weight_detach,
+            nniq.LinearReLU: mod_weight_bias_0,
+            nnqat.Linear: mod_weight_detach,
+            nnqd.Linear: mod_weight_bias_0,
+            nniqat.LinearReLU: mod_weight_detach,
+            nniqat.LinearBn1d: mod_weight_detach,
+            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
+            # LSTM
+            nn.LSTM: get_lstm_weight,
+            nnqd.LSTM: get_qlstm_weight,
+        },
+        'call_function': {
+            # Conv
+            F.conv1d: get_conv_fun_weight,
+            F.conv2d: get_conv_fun_weight,
+            F.conv3d: get_conv_fun_weight,
+            toq.conv1d: get_qconv_fun_weight,
+            toq.conv2d: get_qconv_fun_weight,
+            toq.conv3d: get_qconv_fun_weight,
+            toq.conv1d_relu: get_qconv_fun_weight,
+            toq.conv2d_relu: get_qconv_fun_weight,
+            toq.conv3d_relu: get_qconv_fun_weight,
+            # Linear
+            F.linear: get_linear_fun_weight,
+            toq.linear: get_qlinear_fun_weight,
+            toq.linear_relu: get_qlinear_fun_weight,
+        },
+    }
+
+    return op_to_type_to_weight_extraction_fn
+
+def extract_weight_from_node(
+    node: Node,
+    gm: GraphModule,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
+) -> Optional[NSSingleResultType]:
+    res_type = NSSingleResultValuesType.WEIGHT.value
+
+    # Not all graphmodules have _node_name_to_scope, so only fill it
+    # out if it exists.
+    fqn = None
+    if hasattr(gm, '_node_name_to_scope'):
+        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]
+
+    if op_to_type_to_weight_extraction_fn is None:
+        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
+
+    ref_node_type = get_target_type_str(node, gm)
+    # for extracting weights, these are always the same
+    prev_node_type = ref_node_type
+
+    if node.op == 'call_function':
+        function_mapping = op_to_type_to_weight_extraction_fn['call_function']
+        for target_fn_type, weight_extraction_fn in function_mapping.items():
+            if node.target == target_fn_type:
+                weight = weight_extraction_fn(node, gm)
+                return {
+                    'type': res_type,
+                    'values': [weight],
+                    'prev_node_name': node.name,
+                    'prev_node_target_type': prev_node_type,
+                    'ref_node_name': node.name,
+                    'ref_node_target_type': ref_node_type,
+                    'index_within_arg': 0,
+                    'index_of_arg': 0,
+                    'fqn': fqn,
+                }
+
+    elif node.op == 'call_module':
+        # for call_module, we need to look up the modules to do the type check
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        module_mapping = op_to_type_to_weight_extraction_fn['call_module']
+        for target_mod_type, weight_extraction_fn in module_mapping.items():
+            if type(mod) == target_mod_type:
+                weight = weight_extraction_fn(mod)
+                return {
+                    'type': res_type,
+                    'values': [weight],
+                    'prev_node_name': node.name,
+                    'prev_node_target_type': prev_node_type,
+                    'ref_node_name': node.name,
+                    'ref_node_target_type': ref_node_type,
+                    'index_within_arg': 0,
+                    'index_of_arg': 0,
+                    'fqn': fqn,
+                }
+
+    return None
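+
+# Example (illustrative sketch; the traced model below is hypothetical):
+#
+#     gm = torch.fx.symbolic_trace(nn.Sequential(nn.Linear(2, 2)))
+#     linear_node = next(n for n in gm.graph.nodes if n.op == 'call_module')
+#     result = extract_weight_from_node(linear_node, gm)
+#     # result['values'][0] holds the detached fp32 weight of the Linear module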
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bcce09809ecf3899bf7bfb1e4a4e8c6c6f5345
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/__init__.py
@@ -0,0 +1,19 @@
+# Variables
+from ._mappings import get_dynamic_sparse_quantized_mapping
+from ._mappings import get_static_sparse_quantized_mapping
+
+# Sparsifier
+from .sparsifier.base_sparsifier import BaseSparsifier
+from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier
+from .sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier
+
+# Scheduler
+from .scheduler.base_scheduler import BaseScheduler
+from .scheduler.lambda_scheduler import LambdaSL
+from .scheduler.cubic_scheduler import CubicSL
+
+# Parametrizations
+from .sparsifier.utils import FakeSparsity
+from .sparsifier.utils import module_to_fqn
+from .sparsifier.utils import fqn_to_module
+from .sparsifier.utils import get_arg_info_from_tensor_fqn
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd43c6268e9ba7f73b044e4c4d7bc059a96549c9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bda8485f6658452e833f74e033099b60cfde06c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66d63690867e579df48cce9bc185da3f340f4126
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..377efe1717afd5f56a4c45ec229f9fe4d32d4388
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae04648c6dd0c65ce571ef9aa0843787f4844919
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..35eb38ab11dd912166722d241d9f026f6a2f7060
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py
@@ -0,0 +1,418 @@
+from typing import Any, Dict, List, Optional
+import torch
+from collections import defaultdict
+from torch import nn
+import copy
+from ...sparsifier.utils import fqn_to_module, module_to_fqn
+import warnings
+
+__all__ = ['ActivationSparsifier']
+
+
+class ActivationSparsifier:
+    r"""
+    The Activation sparsifier class aims to sparsify/prune activations in a neural
+    network. The idea is to attach the sparsifier to a layer (or layers) so that it
+    zeroes out the activations based on the mask_fn (or sparsification function)
+    provided by the user.
+    The mask_fn is applied once all the inputs are aggregated and reduced i.e.
+    mask = mask_fn(reduce_fn(aggregate_fn(activations)))
+
+    Note::
+        The sparsification mask is computed on the input **before it goes through the attached layer**.
+
+    Args:
+        model (nn.Module):
+            The model whose layers will be sparsified. The layers that need to be
+            sparsified should be added separately using the register_layer() function.
+        aggregate_fn (Optional, Callable):
+            default aggregate_fn that is used if not specified while registering the layer.
+            Specifies how inputs should be aggregated over time.
+            The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor.
+            Example
+                def add_agg_fn(tensor1, tensor2):  return tensor1 + tensor2
+        reduce_fn (Optional, Callable):
+            default reduce_fn that is used if not specified while registering the layer.
+            reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after
+            calling agg_fn() on all inputs.
+            Example
+                def mean_reduce_fn(agg_tensor):    return agg_tensor.mean(dim=0)
+        mask_fn (Optional, Callable):
+            default mask_fn that is used to create the sparsification mask using the tensor obtained after
+            calling the reduce_fn(). This is used by default if a custom one is not passed in the
+            register_layer().
+            Note that the mask_fn() definition should accept the sparse arguments that are passed in the
+            sparse_config argument.
+        features (Optional, list):
+            default selected features to sparsify.
+            If this is non-empty, then the mask_fn will be applied for each feature of the input.
+            For example,
+                mask = [mask_fn(reduce_fn(aggregate_fn(input[feature]))) for feature in features]
+        feature_dim (Optional, int):
+            default dimension of input features. Again, features along this dim will be chosen
+            for sparsification.
+        sparse_config (Dict):
+            Default configuration for the mask_fn. This config will be passed
+            with the mask_fn().
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> model = SomeModel()
+        >>> act_sparsifier = ActivationSparsifier(...)  # init activation sparsifier
+        >>> # Initialize aggregate_fn
+        >>> def agg_fn(x, y):
+        >>>     return x + y
+        >>>
+        >>> # Initialize reduce_fn
+        >>> def reduce_fn(x):
+        >>>     return torch.mean(x, dim=0)
+        >>>
+        >>> # Initialize mask_fn
+        >>> def mask_fn(data):
+        >>>     return torch.eye(data.shape).to(data.device)
+        >>>
+        >>>
+        >>> act_sparsifier.register_layer(model.some_layer, aggregate_fn=agg_fn, reduce_fn=reduce_fn, mask_fn=mask_fn)
+        >>>
+        >>> # start training process
+        >>> for _ in [...]:
+        >>>     # epoch starts
+        >>>         # model.forward(), compute_loss() and model.backwards()
+        >>>     # epoch ends
+        >>>     act_sparsifier.step()
+        >>> # end training process
+        >>> sparsifier.squash_mask()
+    """
+    def __init__(self, model: nn.Module, aggregate_fn=None, reduce_fn=None, mask_fn=None,
+                 features=None, feature_dim=None, **sparse_config):
+        self.model = model
+        self.defaults: Dict[str, Any] = defaultdict()
+        self.defaults['sparse_config'] = sparse_config
+
+        # functions
+        self.defaults['aggregate_fn'] = aggregate_fn
+        self.defaults['reduce_fn'] = reduce_fn
+        self.defaults['mask_fn'] = mask_fn
+
+        # default feature and feature_dim
+        self.defaults['features'] = features
+        self.defaults['feature_dim'] = feature_dim
+
+        self.data_groups: Dict[str, Dict] = defaultdict(dict)  # contains all relevant info w.r.t each registered layer
+
+        self.state: Dict[str, Any] = defaultdict(dict)  # layer name -> mask
+
+    @staticmethod
+    def _safe_rail_checks(args):
+        """Makes sure that some of the functions and attributes are not passed incorrectly
+        """
+
+        # if features are not None, then feature_dim must not be None
+        features, feature_dim = args['features'], args['feature_dim']
+        if features is not None:
+            assert feature_dim is not None, "need feature dim to select features"
+
+        # all the *_fns should be callable
+        fn_keys = ['aggregate_fn', 'reduce_fn', 'mask_fn']
+        for key in fn_keys:
+            fn = args[key]
+            assert callable(fn), 'function should be callable'
+
+    def _aggregate_hook(self, name):
+        """Returns hook that computes aggregate of activations passing through.
+        """
+
+        # gather some data
+        feature_dim = self.data_groups[name]['feature_dim']
+        features = self.data_groups[name]['features']
+        agg_fn = self.data_groups[name]['aggregate_fn']
+
+        def hook(module, input) -> None:
+            input_data = input[0]
+
+            data = self.data_groups[name].get('data')  # aggregated data
+            if features is None:
+                # no features associated, data should not be a list
+                if data is None:
+                    data = torch.zeros_like(input_data)
+                    self.state[name]['mask'] = torch.ones_like(input_data)
+                out_data = agg_fn(data, input_data)
+            else:
+                # data should be a list [aggregated over each feature only]
+                if data is None:
+                    out_data = [0 for _ in range(0, len(features))]  # create one in case of the 1st forward
+                    self.state[name]['mask'] = [0 for _ in range(0, len(features))]
+                else:
+                    out_data = data  # a list
+
+                # compute aggregate over each feature
+                for feature_idx in range(len(features)):
+                    # each feature is either a list or scalar, convert it to torch tensor
+                    feature_tensor = torch.Tensor([features[feature_idx]]).long().to(input_data.device)
+                    data_feature = torch.index_select(input_data, feature_dim, feature_tensor)
+                    if data is None:
+                        curr_data = torch.zeros_like(data_feature)
+                        self.state[name]['mask'][feature_idx] = torch.ones_like(data_feature)
+                    else:
+                        curr_data = data[feature_idx]
+                    out_data[feature_idx] = agg_fn(curr_data, data_feature)
+            self.data_groups[name]['data'] = out_data
+        return hook
+
+    def register_layer(self, layer: nn.Module, aggregate_fn=None, reduce_fn=None,
+                       mask_fn=None, features=None, feature_dim=None, **sparse_config):
+        r"""
+        Registers a layer for sparsification. The layer should be part of self.model.
+        Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn
+        and store the aggregated activations that are input at each step.
+
+        Note::
+            - There is no need to pass in the name of the layer as it is automatically computed as per
+              the fqn convention.
+
+            - All the functions (fn) passed as argument will be called at a dim, feature level.
+        """
+        name = module_to_fqn(self.model, layer)
+        assert name is not None, "layer not found in the model"  # satisfy mypy
+
+        if name in self.data_groups:  # unregister layer if already present
+            warnings.warn("layer already attached to the sparsifier, deregistering the layer and registering with new config")
+            self.unregister_layer(name=name)
+
+        local_args = copy.deepcopy(self.defaults)
+        update_dict = {
+            'aggregate_fn': aggregate_fn,
+            'reduce_fn': reduce_fn,
+            'mask_fn': mask_fn,
+            'features': features,
+            'feature_dim': feature_dim,
+            'layer': layer
+        }
+        local_args.update((arg, val) for arg, val in update_dict.items() if val is not None)
+        local_args['sparse_config'].update(sparse_config)
+
+        self._safe_rail_checks(local_args)
+
+        self.data_groups[name] = local_args
+        agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name))
+
+        self.state[name]['mask'] = None  # mask will be created when model forward is called.
+
+        # attach agg hook
+        self.data_groups[name]['hook'] = agg_hook
+
+        # for serialization purposes, we know whether aggregate_hook is attached
+        # or sparsify_hook()
+        self.data_groups[name]['hook_state'] = "aggregate"  # aggregate hook is attached
+
+    def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None):
+        """
+        Returns the mask associated with the layer.
+
+        The mask is
+            - a torch tensor if features for that layer is None.
+            - a list of torch tensors, one for each feature, otherwise
+
+        Note::
+            The shape of the mask is unknown until model.forward() is applied.
+            Hence, if get_mask() is called before model.forward(), an
+            error will be raised.
+        """
+        assert name is not None or layer is not None, "Need at least name or layer obj to retrieve mask"
+
+        if name is None:
+            assert layer is not None
+            name = module_to_fqn(self.model, layer)
+            assert name is not None, "layer not found in the specified model"
+
+        if name not in self.state:
+            raise ValueError("Error: layer with the given name not found")
+
+        mask = self.state[name].get('mask', None)
+
+        if mask is None:
+            raise ValueError("Error: shape unknown, call layer() routine at least once to infer mask")
+        return mask
+
+    def unregister_layer(self, name):
+        """Detaches the sparsifier from the layer
+        """
+
+        # detach any hooks attached
+        self.data_groups[name]['hook'].remove()
+
+        # pop from the state dict
+        self.state.pop(name)
+
+        # pop from the data groups
+        self.data_groups.pop(name)
+
+    def step(self):
+        """Internally calls the update_mask() function for each layer
+        """
+        with torch.no_grad():
+            for name, configs in self.data_groups.items():
+                data = configs['data']
+                self.update_mask(name, data, configs)
+
+                self.data_groups[name].pop('data')  # reset the accumulated data
+
+    def update_mask(self, name, data, configs):
+        """
+        Called for each registered layer and does the following:
+            1. apply reduce_fn on the aggregated activations
+            2. use mask_fn to compute the sparsification mask
+
+        Note:
+            the reduce_fn and mask_fn are called for each feature/dim over the data
+        """
+        mask = self.get_mask(name)
+        sparse_config = configs['sparse_config']
+        features = configs['features']
+        reduce_fn = configs['reduce_fn']
+        mask_fn = configs['mask_fn']
+        if features is None:
+            data = reduce_fn(data)
+            mask.data = mask_fn(data, **sparse_config)
+        else:
+            for feature_idx in range(len(features)):
+                data_feature = reduce_fn(data[feature_idx])
+                mask[feature_idx].data = mask_fn(data_feature, **sparse_config)
+
+    def _sparsify_hook(self, name):
+        """Returns hook that applies sparsification mask to input entering the attached layer
+        """
+        mask = self.get_mask(name)
+        features = self.data_groups[name]['features']
+        feature_dim = self.data_groups[name]['feature_dim']
+
+        def hook(module, input):
+            input_data = input[0]
+            if features is None:
+                # apply to all the features
+                return input_data * mask
+            else:
+                # apply per feature, feature_dim
+                for feature_idx in range(0, len(features)):
+                    feature = torch.Tensor([features[feature_idx]]).long().to(input_data.device)
+                    sparsified = torch.index_select(input_data, feature_dim, feature) * mask[feature_idx]
+                    input_data.index_copy_(feature_dim, feature, sparsified)
+                return input_data
+        return hook
+
+    def squash_mask(self, attach_sparsify_hook=True, **kwargs):
+        """
+        Unregisters aggregate hook that was applied earlier and registers sparsification hooks if
+        attach_sparsify_hook = True.
+        """
+        for name, configs in self.data_groups.items():
+            # unhook agg hook
+            configs['hook'].remove()
+            configs.pop('hook')
+            self.data_groups[name]['hook_state'] = "None"
+            if attach_sparsify_hook:
+                configs['hook'] = configs['layer'].register_forward_pre_hook(self._sparsify_hook(name))
+            configs['hook_state'] = "sparsify"  # signals that sparsify hook is now attached
+
+    def _get_serializable_data_groups(self):
+        """Exclude hook and layer from the config keys before serializing
+
+        TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing.
+              For time-being, functions are treated the same way as other attributes
+        """
+        data_groups: Dict[str, Any] = defaultdict()
+        for name, config in self.data_groups.items():
+            new_config = {key: value for key, value in config.items() if key not in ['hook', 'layer']}
+            data_groups[name] = new_config
+        return data_groups
+
+    def _convert_mask(self, states_dict, sparse_coo=True):
+        r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument.
+        If `sparse_coo=True`, the mask is stored as a sparse COO tensor, otherwise as a dense tensor.
+        """
+        states = copy.deepcopy(states_dict)
+        for state in states.values():
+            if state['mask'] is not None:
+                if isinstance(state['mask'], List):
+                    for idx in range(len(state['mask'])):
+                        if sparse_coo:
+                            state['mask'][idx] = state['mask'][idx].to_sparse_coo()
+                        else:
+                            state['mask'][idx] = state['mask'][idx].to_dense()
+                else:
+                    if sparse_coo:
+                        state['mask'] = state['mask'].to_sparse_coo()
+                    else:
+                        state['mask'] = state['mask'].to_dense()
+        return states
+
+    def state_dict(self) -> Dict[str, Any]:
+        r"""Returns the state of the sparsifier as a :class:`dict`.
+
+        It contains:
+        * state - contains name -> mask mapping.
+        * data_groups - a dictionary containing all config information for each
+            layer
+        * defaults - the default config passed at construction time
+        """
+        data_groups = self._get_serializable_data_groups()
+        state = self._convert_mask(self.state)
+        return {
+            'state': state,
+            'data_groups': data_groups,
+            'defaults': self.defaults
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        r"""The load_state_dict() restores the state of the sparsifier based on the state_dict
+
+        Args:
+        * state_dict - the dictionary from which the state of the sparsifier is restored
+        """
+        state = state_dict['state']
+        data_groups, defaults = state_dict['data_groups'], state_dict['defaults']
+
+        self.__set_state__({'state': state, 'data_groups': data_groups, 'defaults': defaults})
+
+    def __get_state__(self) -> Dict[str, Any]:
+
+        data_groups = self._get_serializable_data_groups()
+        state = self._convert_mask(self.state)
+        return {
+            'defaults': self.defaults,
+            'state': state,
+            'data_groups': data_groups,
+        }
+
+    def __set_state__(self, state: Dict[str, Any]) -> None:
+        state['state'] = self._convert_mask(state['state'], sparse_coo=False)  # convert mask to dense tensor
+        self.__dict__.update(state)
+
+        # need to attach layer and hook info into the data_groups
+        for name, config in self.data_groups.items():
+            # fetch layer
+            layer = fqn_to_module(self.model, name)
+            assert layer is not None  # satisfy mypy
+
+            # if agg_mode is True, then layer in aggregate mode
+            if "hook_state" in config and config['hook_state'] == "aggregate":
+                hook = layer.register_forward_pre_hook(self._aggregate_hook(name))
+
+            elif "hook_state" in config and config["hook_state"] == "sparsify":
+                hook = layer.register_forward_pre_hook(self._sparsify_hook(name))
+
+            config['layer'] = layer
+            config['hook'] = hook  # type: ignore[possibly-undefined]
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + ' ('
+        for name, config in self.data_groups.items():
+            format_string += '\n'
+            format_string += '\tData Group\n'
+            format_string += f'\t    name: {name}\n'
+            for key in sorted(config.keys()):
+                if key in ['data', 'hook', 'reduce_fn', 'mask_fn', 'aggregate_fn']:
+                    continue
+                format_string += f'\t    {key}: {config[key]}\n'
+        format_string += ')'
+        return format_string
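Editor's note: the snippet below is not part of the vendored diff. It is a minimal, illustrative sketch of how the `ActivationSparsifier` above is meant to be driven, assuming a toy `nn.Sequential` model and a hypothetical threshold-based `mask_fn` (the `threshold` keyword travels through `sparse_config`).

```python
import torch
from torch import nn
from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ActivationSparsifier

# Toy model and hook functions -- all illustrative, not prescribed by the class.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

def agg_fn(x, y):                  # how activations accumulate across forward passes
    return x + y

def reduce_fn(x):                  # reduce the aggregate before the mask is computed
    return x / 10                  # e.g. average over the 10 forward passes below

def mask_fn(data, threshold=0.1):  # `threshold` arrives via sparse_config
    return (data.abs() > threshold).float()

sparsifier = ActivationSparsifier(model, aggregate_fn=agg_fn, reduce_fn=reduce_fn,
                                  mask_fn=mask_fn, threshold=0.1)
sparsifier.register_layer(model[0])        # attaches the aggregate pre-forward hook

for _ in range(10):                        # forward passes populate the aggregate
    model(torch.randn(8, 16))
sparsifier.step()                          # reduce_fn + mask_fn -> mask per layer
sparsifier.squash_mask()                   # swap the aggregate hook for the sparsify hook
print(sparsifier.get_mask(layer=model[0]).shape)
```

Because the mask takes the shape of the aggregated input, batches fed through the sparsify hook afterwards need to match that shape.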
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4880dcec5ff6ee8c258fb2bb09ea29a6011fc41b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py
@@ -0,0 +1,5 @@
+from .base_data_scheduler import BaseDataScheduler
+
+__all__ = [
+    "BaseDataScheduler",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ac53cf15a104508937cb84c37be159ca809c23f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..729c8d3db625c29b2179522d251e9efb51eec0eb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..55c3b65273df4ab7fbd7be7f362881aa9bdbf850
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py
@@ -0,0 +1,180 @@
+from functools import wraps
+import weakref
+import abc
+import warnings
+
+from ..data_sparsifier import BaseDataSparsifier
+
+__all__ = ['BaseDataScheduler']
+
+
+class BaseDataScheduler:
+    r"""
+    The BaseDataScheduler is the abstract scheduler class specifically for the
+    BaseDataSparsifier class. This class controls a specific hyperparameter of
+    the sparsifier class and varies it across the training process (or across time).
+
+    Args:
+        data_sparsifier (instance of BaseDataSparsifier)
+            The implemented data sparsifier class wherein the update_mask is implemented.
+        schedule_param (str)
+            A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied.
+        last_epoch (int, default=-1)
+            This is passed specifically when training needs to be resumed from a particular
+            point.
+        verbose (bool, default=False)
+            Verbosity of the BaseDataScheduler.
+
+    The *get_schedule_param()* function needs to be implemented by the user.
+    """
+    def __init__(self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False):
+        # Attach sparsifier
+        if not isinstance(data_sparsifier, BaseDataSparsifier):
+            raise TypeError('{} is not an instance of torch.ao.pruning.BaseDataSparsifier'.format(
+                type(data_sparsifier).__name__))
+        self.data_sparsifier = data_sparsifier
+        self.schedule_param = schedule_param
+
+        # Initialize epoch and base hyper-params
+        self.base_param = {
+            name: config.get(schedule_param, None)
+            for name, config in self.data_sparsifier.data_groups.items()
+        }
+
+        self.last_epoch = last_epoch
+
+        # Following https://github.com/pytorch/pytorch/issues/20124
+        # We would like to ensure that `scheduler.step()` is called after
+        # `sparsifier.step()`
+        def with_counter(method):
+            if getattr(method, '_with_counter', False):
+                # `sparsifier.step()` has already been replaced, return.
+                return method
+
+            # Keep a weak reference to the sparsifier instance to prevent
+            # cyclic references.
+            instance_ref = weakref.ref(method.__self__)
+            # Get the unbound method for the same purpose.
+            func = method.__func__
+            cls = instance_ref().__class__
+            del method
+
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                instance = instance_ref()
+                instance._step_count += 1  # type: ignore[union-attr]
+                wrapped = func.__get__(instance, cls)
+                return wrapped(*args, **kwargs)
+
+            # Note that the returned function here is no longer a bound method,
+            # so attributes like `__func__` and `__self__` no longer exist.
+            wrapper._with_counter = True  # type: ignore[attr-defined]
+            return wrapper
+
+        self.data_sparsifier.step = with_counter(self.data_sparsifier.step)  # type: ignore[assignment]
+        self.data_sparsifier._step_count = 0  # type: ignore[attr-defined]
+        self._step_count: int = 0
+        self.verbose = verbose
+
+        # Housekeeping
+        self._get_sp_called_within_step: bool = False  # sp -> schedule parameter
+        self.step()
+
+    @abc.abstractmethod
+    def get_schedule_param(self):
+        r"""
+        Abstract method that needs to be implemented by the child class.
+        The expected return type is a dictionary of name to schedule_param value.
+        The returned values will be updated in sparsifier when the scheduler step() function
+        is called.
+
+        Example:
+            >>> def get_schedule_param(self):
+            ...     new_param = {}
+            ...     for name in self.sparsifier.data_groups.keys():
+            ...         new_param[name] = self.sparsifier.data_groups[name][self.schedule_param] * 0.5
+            ...     return new_param
+
+        When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param]
+        would be halved
+        """
+        raise NotImplementedError
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + ' ('
+        format_string += '\n'
+        format_string += f'Data Sparsifier {self.data_sparsifier}\n'
+        format_string += f'    {self.schedule_param}: {self.base_param}\n'
+        format_string += ')'
+        return format_string
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the sparsifier.
+
+        Note:
+            The scheduler class does not track the state of the data_sparsifier.
+            Make sure to store the state of the sparsifier before storing the
+            state of the scheduler
+        """
+        return {key: value for key, value in self.__dict__.items() if key != 'data_sparsifier'}
+
+    def load_state_dict(self, state_dict):
+        """Loads the schedulers state.
+
+        Note:
+            Remember to restore the state of the data_sparsifier before the scheduler.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_param(self):
+        return self._last_param
+
+    def step(self):
+        # Raise warning if trying to call scheduler step before the sparsifier.
+        # https://github.com/pytorch/pytorch/issues/20124
+        if self._step_count == 1:
+            if not hasattr(self.data_sparsifier.step, "_with_counter"):
+                warnings.warn("Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler "
+                              "initialization. Please, make sure to call `data_sparsifier.step()` before "
+                              "`scheduler.step()`.", UserWarning)
+
+            # Just check if there were two first scheduler.step() calls before sparsifier.step()
+            elif self.data_sparsifier._step_count < 1:  # type: ignore[attr-defined]
+                warnings.warn("Detected call of `scheduler.step()` before `data_sparsifier.step()`. "
+                              "You have to make sure you run the data_sparsifier.step() BEFORE any "
+                              "calls to the scheduler.step().", UserWarning)
+        self._step_count += 1
+
+        class _enable_get_sp_call:
+
+            def __init__(self, o):
+                self.o = o
+
+            def __enter__(self):
+                self.o._get_sp_called_within_step = True
+                return self
+
+            def __exit__(self, type, value, traceback):
+                self.o._get_sp_called_within_step = False
+
+        with _enable_get_sp_call(self):
+            self.last_epoch += 1
+            updated_scheduler_params = self.get_schedule_param()
+
+        for name, param in updated_scheduler_params.items():
+            self.data_sparsifier.data_groups[name][self.schedule_param] = param
+            if self.verbose:
+                print(f"Adjusting {self.schedule_param} for group {name} to {param}")
+
+        self._last_param = {
+            name: config.get(self.schedule_param, None)
+            for name, config in self.data_sparsifier.data_groups.items()
+        }
+        self.data_sparsifier.enable_mask_update = True
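Editor's note: the sketch below is not part of the vendored diff. It shows one way a concrete scheduler might subclass `BaseDataScheduler`, mirroring the halving example in the `get_schedule_param` docstring; the class name `StepSL`, the `gamma` factor, and the assumption that every data group carries a `'sparsity_level'` entry are all illustrative.

```python
from torch.ao.pruning._experimental.data_scheduler.base_data_scheduler import BaseDataScheduler

class StepSL(BaseDataScheduler):
    """Multiplies the scheduled hyperparameter by `gamma` on every scheduler step."""
    def __init__(self, data_sparsifier, schedule_param='sparsity_level', gamma=0.5, **kwargs):
        self.gamma = gamma            # set before super().__init__, which calls step() once
        super().__init__(data_sparsifier, schedule_param, **kwargs)

    def get_schedule_param(self):
        return {
            name: config[self.schedule_param] * self.gamma
            for name, config in self.data_sparsifier.data_groups.items()
        }

# Typical usage inside a training loop (sketch):
#     data_sparsifier.step()   # update the masks first
#     scheduler.step()         # then adjust the scheduled hyperparameter
```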
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..27dd919555b8ccec0431198c9d43a62b0af2eb82
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py
@@ -0,0 +1,7 @@
+from .base_data_sparsifier import BaseDataSparsifier
+from .data_norm_sparsifier import DataNormSparsifier
+
+__all__ = [
+    "BaseDataSparsifier",
+    "DataNormSparsifier",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..290d0567b66bbdd6bf30906c288b15bff707ef55
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b1961ab5f3814d5c195848ea17cd114889ad15e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..287e0414c1493981d84d18fd436207222fe6e5e0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..383e415d642d79b116079df0b6745ff9dfdf3510
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..74bdf3abf8089742e7db8c565af5f31f53e98e3c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py
@@ -0,0 +1,309 @@
+import abc
+import torch
+from typing import Optional, Tuple, List, Any, Dict
+from ...sparsifier import base_sparsifier
+from collections import defaultdict
+from torch import nn
+import copy
+from ...sparsifier import utils
+from torch.nn.utils import parametrize
+import sys
+import warnings
+
+if not sys.warnoptions:
+    # to suppress repeated warnings when being used in a training loop.
+    warnings.simplefilter("once")
+
+__all__ = ['BaseDataSparsifier']
+
+EMBEDDING_TYPES = {
+    nn.Embedding,
+    nn.EmbeddingBag,
+}
+
+SUPPORTED_TYPES = {
+    torch.Tensor,
+    nn.Parameter,
+    *EMBEDDING_TYPES,
+}
+
+
+class _Container(nn.Module):
+    pass
+
+
+class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
+    r"""
+    Base Data Sparsifier class for all Data sparsifiers.
+    The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above)
+    to prepare for sparsification.
+    In this case, the mask (and parametrizations) are owned by the class and not by the user.
+    Specifically, the container object inside the class maintains the masks and parametrizations of the input data.
+
+    Args:
+        data_list (list of tuples)
+            list of (name, data) tuples to sparsify. Look up SUPPORTED_TYPES above
+            for the supported data types. Internally, a container module handles the data sparsification.
+
+        defaults (dict)
+            default configurations will be attached to the
+            configuration. Only the keys that don't exist in the `config` will
+            be updated.
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> data_list = [('tensor_1', torch.randn(3,3)), ('tensor_2', torch.randn(4,4))]
+        >>> defaults = {'sparsity_level': 0.7}
+        >>> sparsifier = DerivedDataSparsifier(data_list = data_list, **defaults) # Some sparsifier that inherits BaseDataSparsifier
+        >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5,5), 'sparsity_level': 0.3}
+        >>> sparsifier.add_data(**new_tensor_to_add)
+        >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3
+    """
+    def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults):
+        super().__init__(defaults=defaults)
+
+        self._container = _Container()
+
+        self.data_groups: Dict[str, Dict] = defaultdict(dict)  # name -> {**config}
+        if data_list is not None:
+            # add data with default config here
+            [self.add_data(name, data, **self.defaults) for name, data in data_list]
+
+    def prepare(self):
+        raise NotImplementedError("this function is undefined for this class")
+
+    def _extract_weight(self, data):
+        # extract the weight parameter instead of underlying data
+        if type(data) in [torch.Tensor, nn.Parameter]:
+            return data
+        elif type(data) in EMBEDDING_TYPES:
+            return data.weight
+
+    def add_data(self, name: str, data, reuse_mask=True, **config):
+        r""" Configures and parametrizes the internal container model with name and data.
+
+        **Note**:
+            1. If the data with name already exists, it replaces the data.
+            2. While replacing, the old mask is reused when `reuse_mask=True`
+            3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data.
+            4. By default, the config of the replaced data is used as config for the replacing data, unless something
+               is specified in the config dictionary.
+        """
+        assert type(data) in SUPPORTED_TYPES, \
+            "specified data type not supported at the moment"
+        local_args = copy.deepcopy(self.defaults)
+        local_args.update(config)
+        weight = self._extract_weight(data)
+
+        # Bookkeeping in the container class
+        mask = local_args.get('mask', torch.ones_like(weight))
+        param_class = local_args.get('parametrization', utils.FakeSparsity)
+
+        if name in self.state:
+            # If the named data already exists - replace
+            warnings.warn("Replacing existing data of the same name. - Did you mean a different name?")
+
+            # reuse old config
+            old_args = self.data_groups[name]
+            local_args = copy.deepcopy(old_args)
+            local_args.update(config)
+
+            if reuse_mask:
+                current_data = self.get_data(name=name)
+                assert weight.shape == current_data.shape, \
+                    "to retain the old mask, the shape of the new data must be the same as the previous one"
+                mask = self.get_mask(name=name)  # reuse mask instead of creating a new one
+
+            self._delete_data(name=name)
+
+        # parameter creates a deepcopy of the weight inside, so create a buffer
+        self._container.register_buffer(name=name, tensor=weight)
+        parametrize.register_parametrization(self._container, name, param_class(mask))
+        self.state[name]['mask'] = mask
+        self.data_groups[name] = local_args
+        return getattr(self._container, name)
+
+    def get_data(self, name: str, return_original: bool = True):
+        r"""Returns weight tensor (or data)
+        Args:
+            - name: name of the data to be returned
+            - return_original: if True, returns the weight tensor without applying the parametrization;
+                otherwise returns the sparsified (parametrized) version
+        """
+        if name not in self.data_groups:
+            raise ValueError("data with specified name does not exist")
+
+        if return_original:
+            if not parametrize.is_parametrized(self._container, name):
+                raise ValueError("mask squashed - original mask value does not exist")
+            data = getattr(self._container.parametrizations, name).original
+            return data
+        else:
+            return getattr(self._container, name)
+
+    def _convert_mask(self, states, sparse_coo=True):
+        r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument.
+        """
+        states = copy.deepcopy(states)
+        for state in states.values():
+            if sparse_coo:
+                state['mask'] = state['mask'].to_sparse_coo()
+            else:
+                state['mask'] = state['mask'].to_dense()
+
+        return states
+
+    def state_dict(self):
+        r"""Returns the state of the optimizer as a :class:`dict`.
+
+        It contains:
+        * state - contains name -> mask mapping.
+        * data_groups - a dictionary containing all sparsity configuration groups,
+            keyed by the name of the data
+        * container_state_dict - the state dictionary of the internal
+            container model used for sparsification
+        """
+        state = self._convert_mask(self.state)
+        return {
+            'state': state,
+            'data_groups': self.data_groups,
+            '_container': self._container.state_dict()
+        }
+
+    def _load_container_from_state(self, states, data_groups, container_state_dict):
+        r"""This restores the state of the container specifically based on the data present in state and data_groups
+        If the data was parametrized, then the data would be added to the container and then parametrized,
+        else it would just add the attribute the container.
+        """
+        for name, state in states.items():
+            config_name = data_groups.get(name, None)
+            if config_name is None:
+                raise RuntimeError(f"Error loading {name}")
+
+            # check if the data with such a name was parametrized, if so parametrize
+            # otherwise just set the attribute and continue
+            parametrized_name = f'parametrizations.{name}.original'
+            parametrized = False
+            data = container_state_dict.get(name, None)
+            if name in container_state_dict:
+                # the parametrization was probably removed for this
+                data = container_state_dict.get(name)
+
+            elif parametrized_name in container_state_dict:
+                # so the weight was parametrized
+                data = container_state_dict.get(parametrized_name)
+                parametrized = True
+
+            else:
+                raise RuntimeError(f"Error loading {name}")
+
+            self._container.register_buffer(name=name, tensor=data)
+
+            if parametrized:
+                # register parameter if parametrized
+                mask = state.get('mask', torch.ones_like(data))
+                param_class = data_groups.get('parametrization', utils.FakeSparsity)  # change once public_api for utils is fixed!
+                parametrize.register_parametrization(self._container, name, param_class(mask))
+
+    def load_state_dict(self, state_dict, strict=True):
+        r"""The load_state_dict() restores the state of the sparsifier based on the state_dict
+
+        Args:
+        * state_dict - the dictionary from which the state of the current sparsifier is restored
+        * strict - If True - the sparsifier is reset and is restored exactly to the state in state_dict.
+            If False - the current sparsifier is not reset before loading the state_dict i.e. data added
+            before loading the state_dict is not erased.
+        """
+        states = copy.deepcopy(state_dict['state'])
+        data_groups = copy.deepcopy(state_dict['data_groups'])
+        container_state_dict = copy.deepcopy(state_dict['_container'])
+
+        states = self._convert_mask(states, sparse_coo=False)  # convert sparse coo mask to dense
+        if strict:
+            # if strict load -> then reset container
+            self._container = _Container()
+
+        self._load_container_from_state(states, data_groups, container_state_dict)
+
+        if not strict:
+            states.update(self.state)
+            data_groups.update(self.data_groups)
+
+        self.__setstate__({'state': states, 'data_groups': data_groups})
+
+    def __setstate__(self, state):
+        if '_container' in state:  # If container object is in state then load model
+            container_dict = state.pop('_container')
+            self._container = _Container()
+            state['state'] = self._convert_mask(state['state'], sparse_coo=False)  # convert sparse coo mask to dense
+            self._load_container_from_state(state['state'], state['data_groups'], container_dict)
+
+        self.__dict__.update(state)
+
+    def __getstate__(self):
+        state = self._convert_mask(self.state)
+        return {
+            'defaults': self.defaults,
+            'state': state,
+            'data_groups': self.data_groups,
+            '_container': self._container.state_dict()
+        }
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + ' ('
+        for name, sparse_args in self.data_groups.items():
+            format_string += '\n'
+            format_string += '\tData Group\n'
+            format_string += f'\t    name: {name}\n'
+            for key in sorted(sparse_args.keys()):
+                if key == 'data':
+                    continue
+                format_string += f'\t    {key}: {sparse_args[key]}\n'
+        format_string += ')'
+        return format_string
+
+    def get_mask(self, name: str):
+        if name not in self.state:
+            raise ValueError("data with specified name does not exist")
+        return self.state[name]['mask']
+
+    def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs):
+        r"""Squashes the sparse masks into the appropriate tensors. Also, accepts list of strings
+        to squash mask for. If none, squashes mask for all the keys
+        kwargs:
+            * names: list of strings to squash mask for
+            * sparsified: if true - applies the mask before squashing
+                          if false - does not apply the mask before squashing
+        """
+        if names is None:
+            names = list(self.data_groups.keys())
+        for name in names:
+            parametrize.remove_parametrizations(self._container, name, leave_parametrized=leave_parametrized)
+
+    def step(self):
+        if not self.enable_mask_update:
+            return
+        with torch.no_grad():
+            for name, config in self.data_groups.items():
+                # get non-sparsified data
+                data = self.get_data(name)
+                # need name for the mask otherwise can directly pass mask?
+                self.update_mask(name, data, **config)
+
+    @abc.abstractmethod
+    def update_mask(self, name, data, **kwargs):
+        pass
+
+    def _delete_data(self, name):
+        """Detaches some data from the sparsifier.
+
+        Args:
+            name (str)
+                Name of the data to be removed from the sparsifier
+
+        Note:
+            Currently private. Used as a helper function when replacing data of the same name.
+        """
+        self.squash_mask(names=[name], leave_parametrized=False)  # do not apply the mask while deleting
+        delattr(self._container, name)
+        self.state.pop(name)
+        self.data_groups.pop(name)
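Editor's note: the following sketch is not part of the vendored diff. It fills in the `DerivedDataSparsifier` placeholder from the `BaseDataSparsifier` docstring with a hypothetical magnitude-based `update_mask`; the class name and threshold logic are illustrative, and it assumes the parent `BaseSparsifier` constructor sets `enable_mask_update`, as the `step()` method above implies.

```python
import torch
from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import BaseDataSparsifier

class MagnitudeDataSparsifier(BaseDataSparsifier):
    """Illustrative sparsifier: zero out the smallest-magnitude fraction of each tensor."""
    def update_mask(self, name, data, sparsity_level, **kwargs):
        mask = self.get_mask(name)
        if sparsity_level <= 0:
            mask.data = torch.ones_like(mask)
            return
        k = max(int(sparsity_level * data.numel()), 1)         # number of entries to zero out
        threshold = data.abs().flatten().kthvalue(k).values    # k-th smallest |value|
        mask.data = (data.abs() > threshold).to(mask.dtype)

sparsifier = MagnitudeDataSparsifier(
    data_list=[('tensor_1', torch.randn(4, 4))], sparsity_level=0.5)
sparsifier.step()                                              # computes the masks
sparse_view = sparsifier.get_data('tensor_1', return_original=False)
sparsifier.squash_mask()                                       # bake masks into the stored tensors
```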
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a335041963bf15d2c1d6c117239c6562b94180c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py
@@ -0,0 +1,153 @@
+import torch
+from torch.nn import functional as F
+from functools import reduce
+from typing import Any, List, Optional, Tuple
+
+from .base_data_sparsifier import BaseDataSparsifier
+import operator
+
+__all__ = ['DataNormSparsifier']
+
+
+class DataNormSparsifier(BaseDataSparsifier):
+    r"""L1-Norm Sparsifier
+    This sparsifier computes the *L1-norm* of every sparse block and "zeroes-out" the
+    ones with the lowest norm. The level of sparsity defines how many of the
+    blocks are removed.
+    This sparsifier is controlled by three variables:
+    1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out
+    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
+        the sparse blocks originate at the zero-index of the tensor.
+    3. `zeros_per_block` is the number of zeros that we are expecting in each
+        sparse block. By default we assume that all elements within a block are
+        zeroed-out. However, setting this variable sets the target number of
+        zeros per block. The zeros within each block are chosen as the *smallest
+        absolute values*.
+    Args:
+        sparsity_level: The target level of sparsity
+        sparse_block_shape: The shape of a sparse block
+        zeros_per_block: Number of zeros in a sparse block
+    Note::
+        All arguments to the DataNormSparsifier constructor are "default"
+        arguments and could be overridden by the configuration provided in the
+        `add_data` step.
+    """
+    def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, sparsity_level: float = 0.5,
+                 sparse_block_shape: Tuple[int, int] = (1, 4),
+                 zeros_per_block: Optional[int] = None, norm: str = 'L1'):
+        if zeros_per_block is None:
+            zeros_per_block = reduce(operator.mul, sparse_block_shape)
+
+        assert norm in ['L1', 'L2'], "only L1 and L2 norm supported at the moment"
+
+        defaults = {'sparsity_level': sparsity_level, 'sparse_block_shape': sparse_block_shape,
+                    'zeros_per_block': zeros_per_block}
+        self.norm = norm
+        super().__init__(data_list=data_list, **defaults)
+
+    def __get_scatter_folded_mask(self, data, dim, indices, output_size, sparse_block_shape):
+        mask = torch.ones_like(data)
+        mask.scatter_(dim=dim, index=indices, value=0)  # zeroing out
+        mask = F.fold(mask, output_size=output_size, kernel_size=sparse_block_shape,
+                      stride=sparse_block_shape)
+        mask = mask.to(torch.int8)
+        return mask
+
+    def __get_block_level_mask(self, data,
+                               sparse_block_shape, zeros_per_block):
+
+        # Assume data is a squeezed tensor
+        height, width = data.shape[-2], data.shape[-1]
+        block_height, block_width = sparse_block_shape
+        values_per_block = block_height * block_width
+
+        # just return zeros if zeroing all elements in block
+        if values_per_block == zeros_per_block:
+            return torch.zeros_like(data, dtype=torch.int8)
+
+        # creating additional height and width to support padding
+        dh = (block_height - height % block_height) % block_height
+        dw = (block_width - width % block_width) % block_width
+
+        # create a new padded tensor like data (to match the block_shape)
+        padded_data = torch.ones(height + dh, width + dw, dtype=data.dtype, device=data.device)
+        padded_data = padded_data * torch.nan  # can also be replaced with 0 to stop the removal of edge data
+        padded_data[0:height, 0:width] = data
+        unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape,
+                                 stride=sparse_block_shape)
+
+        _, sorted_idx = torch.sort(unfolded_data, dim=1)
+        sorted_idx = sorted_idx[:, :zeros_per_block, :]  # zero out zeros_per_block number of elements
+
+        mask = self.__get_scatter_folded_mask(data=unfolded_data, dim=1, indices=sorted_idx, output_size=padded_data.shape,
+                                              sparse_block_shape=sparse_block_shape)
+
+        mask = mask.squeeze(0).squeeze(0)[:height, :width].contiguous()  # remove padding and make contiguous
+        return mask
+
+    def __get_data_level_mask(self, data, sparsity_level,
+                              sparse_block_shape):
+
+        height, width = data.shape[-2], data.shape[-1]
+        block_height, block_width = sparse_block_shape
+        dh = (block_height - height % block_height) % block_height
+        dw = (block_width - width % block_width) % block_width
+
+        data_norm = F.avg_pool2d(data[None, None, :], kernel_size=sparse_block_shape,
+                                 stride=sparse_block_shape, ceil_mode=True)
+
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+
+        data_norm = data_norm.flatten()
+        num_blocks = len(data_norm)
+
+        data_norm = data_norm.repeat(1, values_per_block, 1)  # get similar shape after unfold
+        _, sorted_idx = torch.sort(data_norm, dim=2)
+
+        threshold_idx = round(sparsity_level * num_blocks)  # number of blocks to remove
+        sorted_idx = sorted_idx[:, :, :threshold_idx]
+
+        mask = self.__get_scatter_folded_mask(data=data_norm, dim=2, indices=sorted_idx,
+                                              output_size=(height + dh, width + dw),
+                                              sparse_block_shape=sparse_block_shape)
+
+        mask = mask.squeeze(0).squeeze(0)[:height, :width]  # squeeze only the first 2 dimension
+        return mask
+
+    def update_mask(self, name, data, sparsity_level,
+                    sparse_block_shape, zeros_per_block, **kwargs):
+
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+        if zeros_per_block > values_per_block:
+            raise ValueError("Number of zeros per block cannot be more than "
+                             "the total number of elements in that block.")
+        if zeros_per_block < 0:
+            raise ValueError("Number of zeros per block should be positive.")
+
+        if self.norm == 'L1':
+            data_norm = torch.abs(data).squeeze()  # absolute value based (L1)
+        else:
+            data_norm = (data * data).squeeze()  # square every element for L2
+
+        if len(data_norm.shape) > 2:  # only supports 2 dimensional data at the moment
+            raise ValueError("only supports 2-D at the moment")
+
+        elif len(data_norm.shape) == 1:  # in case the data is bias (or 1D)
+            data_norm = data_norm[None, :]
+
+        mask = self.get_mask(name)
+        if sparsity_level <= 0 or zeros_per_block == 0:
+            mask.data = torch.ones_like(mask)
+        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
+            mask.data = torch.zeros_like(mask)
+
+        # Fetch the high level mask that zeros out entire blocks
+        data_lvl_mask = self.__get_data_level_mask(data=data_norm, sparsity_level=sparsity_level,
+                                                   sparse_block_shape=sparse_block_shape)
+
+        # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block
+        block_lvl_mask = self.__get_block_level_mask(data=data_norm, sparse_block_shape=sparse_block_shape,
+                                                     zeros_per_block=zeros_per_block)
+
+        # zero out the entries inside those blocks whose block is sparsified
+        mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask)
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e51323a27f71aae2fb7348fe2e15816673271ef
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f73f2ac8fb0d10a1ad32909db656eba883f40c7a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5628ef4f83f632a0310a52da2c79349312d60cfb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17f4e10d603bca0c8bd2ce67d1ab7e1aa1751770
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c0bdcf03cd2589c5019d6490c718f46becceb85
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py
@@ -0,0 +1,39 @@
+import logging
+from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import SUPPORTED_TYPES
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None):
+    """Attaches a data sparsifier to all the layers of the module.
+    Essentially, loop over all the weight parameters in the module and
+    attach it to the data sparsifier.
+    Note::
+        The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below)
+        before attaching to the sparsifier. This is because, the data
+        sparsifier uses a dummy model inside to store the weight parameters.
+    """
+    if config is None:
+        config = {}
+    for name, parameter in module.named_parameters():
+        if type(parameter) in SUPPORTED_TYPES:
+            valid_name = _get_valid_name(name)
+            # will be defaulted to default configs
+            data_sparsifier.add_data(name=valid_name, data=parameter, **config.get(valid_name, {}))
+
+
+def _get_valid_name(name):
+    return name.replace('.', '_')  # . is not allowed as a name
+
+
+def _log_sparsified_level(model, data_sparsifier) -> None:
+    # Show the level of sparsity AFTER step:
+    for name, parameter in model.named_parameters():
+        if type(parameter) not in SUPPORTED_TYPES:
+            continue
+        valid_name = _get_valid_name(name)
+        mask = data_sparsifier.get_mask(name=valid_name)
+        sparsity_level = 1.0 - mask.float().mean()
+        logger.info(
+            "Sparsity in layer %s = % .2%", name, sparsity_level
+        )
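+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (not part of the upstream module): `_MiniSparsifier` below is a
+    # hypothetical stand-in that only records what is handed to it, so the attach/naming
+    # behaviour can be shown without constructing a real data sparsifier.
+    import torch.nn as nn
+
+    class _MiniSparsifier:
+        def __init__(self):
+            self.data = {}
+
+        def add_data(self, name, data, **config):
+            self.data[name] = data
+
+    print(_get_valid_name("encoder.layer1.weight"))  # -> encoder_layer1_weight
+
+    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
+    sparsifier = _MiniSparsifier()
+    _attach_model_to_data_sparsifier(model, sparsifier)
+    print(sorted(sparsifier.data))  # parameter names with '.' replaced by '_', e.g. ['0_bias', '0_weight', ...]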
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2ad926a6485ea0fc5907445599605ead33c8d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py
@@ -0,0 +1,165 @@
+from collections import defaultdict
+from copy import deepcopy
+import torch
+from typing import Any, Optional, Dict
+import pytorch_lightning as pl  # type: ignore[import]
+
+from ._data_sparstity_utils import (
+    _attach_model_to_data_sparsifier,
+    _log_sparsified_level,
+    _get_valid_name
+)
+
+
+class PostTrainingDataSparsity(pl.callbacks.Callback):
+    """Lightning callback that enables post-training sparsity.
+
+    This callback aims to sparsify the model inside lightning module after training.
+    **Note that the model is copied and then sparsified, so the existing model is not modified**
+
+    The sparsified model can be used for comparison and can be accessed using
+        .sparsified
+
+    Args:
+        data_sparsifier_class (some implemented class of BaseDataSparsifier)
+            The data sparsifier object of this class is created once the
+            training completes (inside on_fit_end()).
+            Note: Pass the class itself, not an instance; the object is
+            created only after training completes.
+
+        data_sparsifier_args (Dict)
+            Dictionary of args to be passed to the data sparsifier.
+            Note: data_list arg should be ignored
+
+    Hooks implemented:
+        on_fit_end()
+            1. copies the model and attaches it to the sparsifier
+            2. sparsifier step() is called
+            3. squashes the mask
+    """
+    def __init__(self, data_sparsifier_class, data_sparsifier_args):
+        super().__init__()
+        self.data_sparsifier_class = data_sparsifier_class
+        self.data_sparsifier_args = data_sparsifier_args
+        self.data_sparsifier: Any = None
+        self.sparsified: Optional[torch.nn.Module] = None
+
+    def on_fit_end(self, trainer, pl_module) -> None:
+        self.sparsified = deepcopy(pl_module.model).eval()
+        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)
+
+        _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier)
+
+        self.data_sparsifier.step()
+
+        self.data_sparsifier.squash_mask()  # currently squashes params for all masks
+
+        _log_sparsified_level(self.sparsified, self.data_sparsifier)
+
+
+class TrainingAwareDataSparsity(pl.callbacks.Callback):
+    """Lightning callback that enables in-training sparsity.
+
+    This callback aims to sparsify the model inside lightning module during training.
+    **Note that the model is copied and then sparsified, so the existing model is not modified**
+
+    The sparsified model can be used for comparison and can be accessed using
+        .sparsified
+
+    Args:
+        data_sparsifier_class (some implemented class of BaseDataSparsifier)
+            The data sparsifier object of this class is created when the
+            training starts.
+            Note: Objects should not be passed in here as they are created
+            when the training starts.
+
+        data_sparsifier_args (Dict)
+            Dictionary of args to be passed to the data sparsifier.
+            Note: data_list arg should be ignored
+
+        data_scheduler_class (some implemented class of BaseDataScheduler)
+            The data scheduler of this class is created when the training starts
+            Note: Objects should not be passed in here as they are created
+            when the training starts.
+
+        data_scheduler_args(Dict)
+            Dictionary of args to be passed to the data scheduler.
+            **Note: data_sparsifier arg should be ignored, as the callback
+            creates the sparsifier object and passes it to the scheduler**
+
+    Hooks implemented:
+        on_train_start()
+            Data sparsifier and scheduler objects are created.
+            PyTorch model is attached to the sparsifier
+
+        on_train_epoch_start()
+            Loads the state_dict of the data sparsifier
+
+        on_train_epoch_end()
+            1. Copies the model and attaches it to the sparsifier
+            2. sparsifier step() and scheduler step()
+            3. Dump state_dict of the current sparsifier
+
+        on_train_end()
+            squash mask
+    """
+    def __init__(self, data_sparsifier_class, data_sparsifier_args,
+                 data_scheduler_class, data_scheduler_args):
+        super().__init__()
+        # data sparsifier objects
+        self.data_sparsifier_class = data_sparsifier_class
+        self.data_sparsifier_args = data_sparsifier_args
+
+        # scheduler objects
+        self.data_scheduler_class = data_scheduler_class
+        self.data_scheduler_args = data_scheduler_args
+
+        # fields
+        self.data_sparsifier: Any = None
+        self.data_scheduler: Any = None
+        self.sparsified: Optional[torch.nn.Module] = None
+
+        self.data_sparsifier_state_dict: Any = None
+
+    def on_train_start(self, trainer, pl_module) -> None:
+        # create sparsifier
+        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)
+        self.sparsified = deepcopy(pl_module.model)
+
+        _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier)  # just to populate the base_sl in the scheduler
+
+        # create scheduler
+        args = deepcopy(self.data_scheduler_args)
+        args['data_sparsifier'] = self.data_sparsifier
+        self.data_scheduler = self.data_scheduler_class(**args)
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        if self.data_sparsifier_state_dict is None:
+            return  # probably first epoch
+
+        # load the existing config for each data
+        self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict)
+
+    def __create_config_based_on_state(self, pl_module):
+        config: Dict = defaultdict()
+        if self.data_sparsifier_state_dict is None:
+            return config
+        for name, _ in pl_module.model.named_parameters():
+            valid_name = _get_valid_name(name)
+            config[valid_name] = self.data_sparsifier.data_groups[valid_name]
+
+        return config
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        self.sparsified = deepcopy(pl_module.model)
+        config = self.__create_config_based_on_state(pl_module)
+
+        # attach model to the data sparsifier
+        _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier, config=config)
+        self.data_sparsifier.step()
+        self.data_scheduler.step()
+
+        self.data_sparsifier_state_dict = self.data_sparsifier.state_dict()
+
+    def on_train_end(self, trainer, pl_module):
+        self.data_sparsifier.squash_mask()
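+
+
+if __name__ == "__main__":
+    # Minimal wiring sketch (not part of the upstream module). `_StubSparsifier` is a
+    # hypothetical placeholder; in practice you would pass an implemented subclass of
+    # BaseDataSparsifier (e.g. DataNormSparsifier) together with its constructor args.
+    class _StubSparsifier:
+        def __init__(self, **kwargs):
+            self.kwargs = kwargs
+
+    callback = PostTrainingDataSparsity(
+        data_sparsifier_class=_StubSparsifier,
+        data_sparsifier_args={"sparsity_level": 0.8},
+    )
+    print(callback.data_sparsifier_class.__name__, callback.data_sparsifier_args)
+
+    # In a real run the callback is handed to the trainer and fires after fit():
+    #   trainer = pl.Trainer(callbacks=[callback])
+    #   trainer.fit(lightning_module)  # the LightningModule must expose the model as `.model`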
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca33b242a4deae1394f520653df2160b98cb74ca
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
+from typing import Dict, List, Optional
+
+SUPPORTED_MODULES = {
+    nn.Embedding,
+    nn.EmbeddingBag
+}
+
+
+def _fetch_all_embeddings(model):
+    """Fetches Embedding and EmbeddingBag modules from the model
+    """
+    embedding_modules = []
+    stack = [model]
+    while stack:
+        module = stack.pop()
+        for _, child in module.named_children():
+            fqn_name = module_to_fqn(model, child)
+            if type(child) in SUPPORTED_MODULES:
+                embedding_modules.append((fqn_name, child))
+            else:
+                stack.append(child)
+    return embedding_modules
+
+
+def post_training_sparse_quantize(model,
+                                  data_sparsifier_class,
+                                  sparsify_first=True,
+                                  select_embeddings: Optional[List[nn.Module]] = None,
+                                  **sparse_config):
+    """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
+    The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
+
+    Args:
+        - model (nn.Module)
+            model whose embeddings need to be sparsified
+        - data_sparsifier_class (type of data sparsifier)
+            Type of sparsification that needs to be applied to model
+        - sparsify_first (bool)
+            if true, sparsifies first and then quantizes
+            otherwise, quantizes first and then sparsifies.
+        - select_embeddings (List of Embedding modules)
+            List of embedding modules in the model to be sparsified & quantized.
+            If None, all embedding modules will be sparsified
+        - sparse_config (Dict)
+            config that will be passed to the constructor of data sparsifier object.
+
+    Note:
+        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
+            - before sparsifying, the embedding layers are dequantized.
+            - scales and zero-points are saved
+            - embedding layers are sparsified and `squash_mask` is applied
+            - embedding weights are requantized using the saved scales and zero-points
+        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
+            - embeddings are sparsified first
+            - quantization is applied on the sparsified embeddings
+    """
+    data_sparsifier = data_sparsifier_class(**sparse_config)
+
+    # if select_embeddings is None, perform it on all embeddings
+    if select_embeddings is None:
+        embedding_modules = _fetch_all_embeddings(model)
+
+    else:
+        embedding_modules = []
+        assert isinstance(select_embeddings, List), "select_embeddings must be a list of embedding modules"
+        for emb in select_embeddings:
+            assert type(emb) in SUPPORTED_MODULES, "each module in select_embeddings must be an nn.Embedding or nn.EmbeddingBag"
+            fqn_name = module_to_fqn(model, emb)
+            assert fqn_name is not None, "the embedding modules must be part of the input model"
+            embedding_modules.append((fqn_name, emb))
+
+    if sparsify_first:
+        # sparsify
+        for name, emb_module in embedding_modules:
+            valid_name = name.replace('.', '_')
+            data_sparsifier.add_data(name=valid_name, data=emb_module)
+
+        data_sparsifier.step()
+        data_sparsifier.squash_mask()
+
+        # quantize
+        for _, emb_module in embedding_modules:
+            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
+
+        torch.ao.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
+
+    else:
+        # quantize
+        for _, emb_module in embedding_modules:
+            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
+
+        torch.ao.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.convert(model, inplace=True)
+
+        # retrieve scale & zero_points
+        quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {},
+                                            'dequant_weights': {}, 'axis': {},
+                                            'dtype': {}}
+
+        for name, _ in embedding_modules:
+            quantized_emb = fqn_to_module(model, name)
+            assert quantized_emb is not None  # satisfy mypy
+
+            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
+            quantize_params['scales'][name] = quantized_weight.q_per_channel_scales()
+            quantize_params['zero_points'][name] = quantized_weight.q_per_channel_zero_points()
+            quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight)
+            quantize_params['axis'][name] = quantized_weight.q_per_channel_axis()
+            quantize_params['dtype'][name] = quantized_weight.dtype
+
+            # attach data to sparsifier
+            data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name])
+
+        data_sparsifier.step()
+        data_sparsifier.squash_mask()
+
+        for name, _ in embedding_modules:
+            quantized_emb = fqn_to_module(model, name)
+            assert quantized_emb is not None  # satisfy mypy
+            requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name],
+                                                            scales=quantize_params['scales'][name],
+                                                            zero_points=quantize_params['zero_points'][name],
+                                                            dtype=quantize_params['dtype'][name],
+                                                            axis=quantize_params['axis'][name])
+
+            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
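+
+
+if __name__ == "__main__":
+    # Usage sketch (not part of the upstream module), assuming DataNormSparsifier accepts the
+    # sparsity_level / sparse_block_shape / zeros_per_block keyword arguments shown below;
+    # adjust the sparse_config for your own sparsifier class.
+    from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier
+
+    class EmbeddingModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.emb = nn.EmbeddingBag(100, 16)
+
+        def forward(self, x, offsets):
+            return self.emb(x, offsets)
+
+    model = EmbeddingModel()
+    post_training_sparse_quantize(
+        model,
+        DataNormSparsifier,
+        sparsify_first=True,
+        sparsity_level=0.8,
+        sparse_block_shape=(1, 1),
+        zeros_per_block=1,
+    )
+    print(type(model.emb))  # the EmbeddingBag has been sparsified and quantized in place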
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb8ce411719996ffb6edce123328aad500b83b9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py
@@ -0,0 +1,93 @@
+from typing import Callable, Optional, Union
+
+import torch
+
+from .base_structured_sparsifier import BaseStructuredSparsifier
+
+__all__ = ["FPGMPruner"]
+
+
+class FPGMPruner(BaseStructuredSparsifier):
+    r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner
+    This sparsifier prunes filters (rows) in a tensor according to the distances among filters, following
+    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_.
+
+    This sparsifier is controlled by two variables:
+    1. `sparsity_level` defines the proportion of filters (rows) that are zeroed out.
+    2. `dist` defines the distance measurement type. Default: 2 (L2 distance).
+    Available options are: [1, 2, (custom callable distance function)].
+
+    Note::
+        Inputs should be a 4D convolutional tensor of shape (N, C, H, W).
+            - N: output channels size
+            - C: input channels size
+            - H: height of kernel
+            - W: width of kernel
+    """
+
+    def __init__(
+        self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None
+    ):
+        defaults = {
+            "sparsity_level": sparsity_level,
+        }
+
+        if dist is None:
+            dist = 2
+
+        if callable(dist):
+            self.dist_fn = dist
+        elif dist == 1:
+            self.dist_fn = lambda x: torch.cdist(x, x, p=1)
+        elif dist == 2:
+            self.dist_fn = lambda x: torch.cdist(x, x, p=2)
+        else:
+            raise NotImplementedError("Distance function is not yet implemented.")
+        super().__init__(defaults=defaults)
+
+    def _compute_distance(self, t):
+        r"""Compute distance across all entries in tensor `t` along all dimension
+        except for the one identified by dim.
+        Args:
+            t (torch.Tensor): tensor representing the parameter to prune
+        Returns:
+            distance (torch.Tensor): distance computed across filters
+        """
+        dim = 0  # prune filter (row)
+
+        size = t.size(dim)
+        slc = [slice(None)] * t.dim()
+
+        # flatten the tensor along the dimension
+        t_flatten = [
+            t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1)
+            for i in range(size)
+        ]
+        t_flatten = torch.stack(t_flatten)
+
+        # distance measurement
+        dist_matrix = self.dist_fn(t_flatten)
+
+        # filters closest to the geometric median have the smallest row sums and are pruned first
+        distance = torch.sum(torch.abs(dist_matrix), 1)
+
+        return distance
+
+    def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
+        tensor_weight = getattr(module, tensor_name)
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+
+        if sparsity_level <= 0:
+            mask.data = torch.ones_like(mask).bool()
+        elif sparsity_level >= 1.0:
+            mask.data = torch.zeros_like(mask).bool()
+        else:
+            distance = self._compute_distance(tensor_weight)
+
+            tensor_size = tensor_weight.shape[0]  # prune filter (row)
+            nparams_toprune = round(sparsity_level * tensor_size)
+            nparams_toprune = min(
+                max(nparams_toprune, 0), tensor_size
+            )  # clamp to [0, tensor_size]
+            topk = torch.topk(distance, k=nparams_toprune, largest=False)
+            mask[topk.indices] = False
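+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the upstream module): shows only the pairwise-distance
+    # computation used by FPGM. A full pruning run would go through the BaseSparsifier API
+    # (prepare() with a model and config, then step() and prune()).
+    torch.manual_seed(0)
+    pruner = FPGMPruner(sparsity_level=0.5)
+    conv_weight = torch.randn(6, 3, 3, 3)  # (N, C, H, W) filters of a Conv2d layer
+    distance = pruner._compute_distance(conv_weight)
+    print(distance.shape)  # torch.Size([6]); the filters with the smallest sums are pruned first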
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9589fa6afeb5458172dfa0fd6217c8f6496da45a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py
@@ -0,0 +1,8 @@
+from .base_structured_sparsifier import BaseStructuredSparsifier
+from .parametrization import (
+    FakeStructuredSparsity,
+    BiasHook,
+)
+from .saliency_pruner import SaliencyPruner
+from .lstm_saliency_pruner import LSTMSaliencyPruner
+from .FPGM_pruner import FPGMPruner
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..704126199851897985a00785ee12518bb625269b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee71803163359f389c33e3b7cff04eb75f9d3191
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51b7586135a49e4ce66fe107de668db8beed7be7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4cb3ea0ef6f24d5a339d419de988d2aff2396a9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b623940073893ff7be57a010c320d0c5ec8e9ab2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac8d7e81e1c7e925e9f93976ca44cdbb8adaf8a9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e9379bcdb2128cf9799652c73fe3ebe85fa693a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf073dcb0aba5cefa98ab190d07b30ff6ef5edb1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c33b02fff1eca775d2513bacb156add4533cf2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py
@@ -0,0 +1,310 @@
+from itertools import chain
+from operator import getitem
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.fx import symbolic_trace
+from torch.nn.utils import parametrize
+from typing import Type, Set, Dict, Callable, Tuple, Optional, Union
+
+from torch.ao.pruning import BaseSparsifier
+from .parametrization import FakeStructuredSparsity, BiasHook, module_contains_param
+from .match_utils import apply_match, MatchAllNode
+from .prune_functions import (
+    prune_linear,
+    prune_linear_linear,
+    prune_linear_activation_linear,
+    prune_conv2d,
+    prune_conv2d_conv2d,
+    prune_conv2d_activation_conv2d,
+    prune_conv2d_activation_pool_conv2d,
+    prune_conv2d_pool_activation_conv2d,
+    prune_conv2d_pool_flatten_linear,
+    prune_lstm_output_linear,
+    prune_lstm_output_layernorm_linear,
+)
+
+
+def _get_supported_structured_pruning_modules():
+    SUPPORTED_STRUCTURED_PRUNING_MODULES = {  # added to config if None given
+        nn.Linear,
+        nn.Conv2d,
+        nn.LSTM,
+    }
+    return SUPPORTED_STRUCTURED_PRUNING_MODULES
+
+
+def _get_supported_activation_functions():
+    SUPPORTED_ACTIVATION_FUNCTIONS = {
+        F.relu,
+        F.rrelu,
+        F.hardtanh,
+        F.relu6,
+        F.sigmoid,
+        F.hardsigmoid,
+        F.tanh,
+        F.silu,
+        F.mish,
+        F.hardswish,
+        F.elu,
+        F.celu,
+        F.selu,
+        F.hardshrink,
+        F.leaky_relu,
+        F.logsigmoid,
+        F.softplus,
+        F.prelu,
+        F.softsign,
+        F.tanhshrink,
+        F.gelu,
+    }
+    return SUPPORTED_ACTIVATION_FUNCTIONS
+
+
+def _get_supported_activation_modules():
+    SUPPORTED_ACTIVATION_MODULES = {
+        nn.ReLU,
+        nn.RReLU,
+        nn.Hardtanh,
+        nn.ReLU6,
+        nn.Sigmoid,
+        nn.Hardsigmoid,
+        nn.Tanh,
+        nn.SiLU,
+        nn.Mish,
+        nn.Hardswish,
+        nn.ELU,
+        nn.CELU,
+        nn.SELU,
+        nn.Hardshrink,
+        nn.LeakyReLU,
+        nn.LogSigmoid,
+        nn.Softplus,
+        nn.PReLU,
+        nn.Softsign,
+        nn.Tanhshrink,
+        nn.GELU,
+    }
+    return SUPPORTED_ACTIVATION_MODULES
+
+
+def _get_default_structured_pruning_patterns() -> Dict[
+    Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
+    Callable[..., None],
+]:
+    """
+    Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above.
+    """
+    patterns: Dict[
+        Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...],
+        Callable[..., None],
+    ] = {
+        # linear -> linear
+        (nn.Linear, "output"): prune_linear,
+        (nn.Linear, nn.Linear): prune_linear_linear,
+        # conv2d -> conv2d
+        (nn.Conv2d, "output"): prune_conv2d,
+        (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d,
+        # TODO LSTM Structured pruning does not support returned state currently.
+        # Should find a way to explicitly match getitem(0) instead of getitem.
+        # This will also require changing the pruning function.
+        # lstm -> getitem(0) -> linear
+        (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear,
+        # lstm -> getitem(0) -> layernorm -> linear
+        (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear,
+    }
+
+    for activation in chain(
+        _get_supported_activation_functions(), _get_supported_activation_modules()
+    ):
+        patterns.update(
+            {
+                # linear -> activation -> linear
+                (nn.Linear, activation, nn.Linear): prune_linear_activation_linear,
+                # conv2d -> activation -> conv2d
+                (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d,
+                # conv2d -> activation -> pool -> conv2d
+                (
+                    nn.Conv2d,
+                    activation,
+                    nn.AvgPool2d,
+                    nn.Conv2d,
+                ): prune_conv2d_activation_pool_conv2d,
+                (
+                    nn.Conv2d,
+                    activation,
+                    F.avg_pool2d,
+                    nn.Conv2d,
+                ): prune_conv2d_activation_pool_conv2d,
+                (
+                    nn.Conv2d,
+                    activation,
+                    nn.MaxPool2d,
+                    nn.Conv2d,
+                ): prune_conv2d_activation_pool_conv2d,
+                (
+                    nn.Conv2d,
+                    activation,
+                    F.max_pool2d,
+                    nn.Conv2d,
+                ): prune_conv2d_activation_pool_conv2d,
+                # conv2d -> pool -> activation -> conv2d
+                (
+                    nn.Conv2d,
+                    nn.AvgPool2d,
+                    activation,
+                    nn.Conv2d,
+                ): prune_conv2d_pool_activation_conv2d,
+                (
+                    nn.Conv2d,
+                    F.avg_pool2d,
+                    activation,
+                    nn.Conv2d,
+                ): prune_conv2d_pool_activation_conv2d,
+                (
+                    nn.Conv2d,
+                    nn.MaxPool2d,
+                    activation,
+                    nn.Conv2d,
+                ): prune_conv2d_pool_activation_conv2d,
+                (
+                    nn.Conv2d,
+                    F.max_pool2d,
+                    activation,
+                    nn.Conv2d,
+                ): prune_conv2d_pool_activation_conv2d,
+                # conv2d -> adaptive pool -> flatten -> linear
+                (
+                    nn.Conv2d,
+                    nn.AdaptiveAvgPool2d,
+                    nn.Flatten,
+                    nn.Linear,
+                ): prune_conv2d_pool_flatten_linear,
+                (
+                    nn.Conv2d,
+                    nn.AdaptiveAvgPool2d,
+                    torch.flatten,
+                    nn.Linear,
+                ): prune_conv2d_pool_flatten_linear,
+                (
+                    nn.Conv2d,
+                    nn.AdaptiveMaxPool2d,
+                    nn.Flatten,
+                    nn.Linear,
+                ): prune_conv2d_pool_flatten_linear,
+                (
+                    nn.Conv2d,
+                    nn.AdaptiveMaxPool2d,
+                    torch.flatten,
+                    nn.Linear,
+                ): prune_conv2d_pool_flatten_linear,
+            }
+        )
+    return patterns
+
+
+class BaseStructuredSparsifier(BaseSparsifier):
+    r"""Base class for structured pruning.
+
+    Abstract methods that need to be implemented:
+        - update_mask: Function to compute a new mask for all keys in the
+            `groups` attribute.
+
+    Args:
+        - defaults [dict]: default configurations will be attached to the
+            configuration. Only the keys that don't exist in the `config` will
+            be updated.
+    """
+
+    def __init__(self, defaults, patterns=None):
+        super().__init__(defaults)
+        if patterns is None:
+            patterns = _get_default_structured_pruning_patterns()
+        self.patterns = patterns
+
+    def make_config_from_model(
+        self,
+        model: nn.Module,
+        SUPPORTED_MODULES: Optional[Set[Type]] = None,
+    ) -> None:
+        if SUPPORTED_MODULES is None:
+            SUPPORTED_MODULES = _get_supported_structured_pruning_modules()
+        super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES)
+
+    def _prepare(self, *args, **kwargs) -> None:
+        r"""This function will attach the FakeStructuredSparsity parameterizations
+        and BiasHooks at the appropriate points in the model.
+        """
+        for config in self.groups:
+            module = config["module"]
+            tensor_name = config["tensor_name"]
+            parametrization = config.get("parametrization", FakeStructuredSparsity)
+            tensor = getattr(module, tensor_name)
+
+            mask = config.get(
+                "mask",
+                torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device),
+            )
+            self.state[config["tensor_fqn"]]["mask"] = mask
+            parametrize.register_parametrization(
+                module, tensor_name, parametrization(mask)
+            )
+
+            # if linear / conv, we add in bias hooks
+            if isinstance(module, (nn.Linear, nn.Conv2d)):
+                prune_bias = config.get("prune_bias", True)
+                if module.bias is not None:
+                    module.register_parameter(
+                        "_bias", nn.Parameter(module.bias.detach())
+                    )
+                    module.bias = None
+                    module.prune_bias = prune_bias
+
+                module.register_forward_hook(
+                    BiasHook(module.parametrizations.weight[0], prune_bias)
+                )
+
+    def prune(self) -> nn.Module:
+        r"""
+        This function will FX symbolically trace the model and then find instances of the patterns
+        defined in self.patterns (by default, the patterns returned by _get_default_structured_pruning_patterns()).
+
+        For each match, it will apply the corresponding conversion function, which will modify the output
+        and input sizes expected by the modules within the pattern.
+        """
+
+        self.traced = symbolic_trace(self.model)
+        modules = dict(self.traced.named_modules())
+
+        # Right now we check for matches simply by iterating across all the patterns
+        # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup
+        for node in self.traced.graph.nodes:
+            for pattern, convert_fn in self.patterns.items():
+                matched = apply_match(modules, pattern, node, [])
+                if matched is None:
+                    continue
+
+                first_module = modules.get(node.target)
+                # check if first module exists and has appropriate parameterization, otherwise skip
+                if (
+                    first_module is not None
+                    and parametrize.is_parametrized(first_module)
+                    and module_contains_param(first_module, FakeStructuredSparsity)
+                ):
+                    convert_block = []
+                    for node in matched:
+                        if node.op == "call_module":
+                            convert_block.append(modules.get(node.target))
+                        elif node.op == "call_function":
+                            convert_block.append(node.target)
+                    convert_fn(*convert_block)
+
+        for module in self.traced.modules():
+            if module_contains_param(module, FakeStructuredSparsity):
+                raise Exception(
+                    f"Error: {module} still contains FakeStructuredSparsity parametrizations!"
+                )
+
+        self.traced.graph.lint()
+        self.traced.recompile()
+        return self.traced
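+
+
+if __name__ == "__main__":
+    # End-to-end sketch (not part of the upstream module) of the prepare() -> step() -> prune()
+    # flow, using the SaliencyPruner shipped in this package. The config format
+    # ([{"tensor_fqn": ...}]) follows the BaseSparsifier API; sizes are illustrative only.
+    from torch.ao.pruning._experimental.pruner import SaliencyPruner
+
+    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
+    config = [{"tensor_fqn": "0.weight"}]
+
+    pruner = SaliencyPruner({"sparsity_level": 0.5})
+    pruner.prepare(model, config)    # attaches FakeStructuredSparsity and a BiasHook to `0`
+    pruner.enable_mask_update = True
+    pruner.step()                    # computes a row mask for `0.weight`
+    pruned = pruner.prune()          # matches (Linear, ReLU, Linear) and resizes both linears
+
+    print(pruned.get_submodule("0").out_features)   # reduced from 8 (here 4, with sparsity_level=0.5)
+    print(pruned(torch.randn(2, 8)).shape)          # torch.Size([2, 4]); the pruned network still runs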
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd069e202dd860701a32dc3e6853f9d2e5fb689
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py
@@ -0,0 +1,48 @@
+from typing import cast
+
+import torch
+from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity
+
+class LSTMSaliencyPruner(BaseStructuredSparsifier):
+    """
+    Prune packed LSTM weights based on saliency.
+    For each layer {k} inside a LSTM, we have two packed weight matrices
+    - weight_ih_l{k}
+    - weight_hh_l{k}
+
+    These tensors pack the weights for the 4 linear layers together for efficiency.
+
+    [W_ii | W_if | W_ig | W_io]
+
+    Pruning this tensor directly will lead to weights being misassigned when unpacked.
+    To ensure that each packed linear layer is pruned the same amount:
+        1. We split the packed weight into its 4 constituent linear parts
+        2. We update the mask for each individual piece using its own saliency
+
+    This applies to both weight_ih_l{k} and weight_hh_l{k}.
+    """
+
+    def update_mask(self, module, tensor_name, **kwargs):
+        weights = getattr(module, tensor_name)
+
+        for p in getattr(module.parametrizations, tensor_name):
+            if isinstance(p, FakeStructuredSparsity):
+                mask = cast(torch.Tensor, p.mask)
+
+                # select weights based on magnitude
+                if weights.dim() <= 1:
+                    raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!")
+                # take norm over all but first dim
+                dims = tuple(range(1, weights.dim()))
+                saliency = weights.norm(dim=dims, p=1)
+
+                # handle weights in 4 groups
+                split_size = len(mask) // 4
+                masks = torch.split(mask, split_size)
+                saliencies = torch.split(saliency, split_size)
+
+                for keep_mask, sal in zip(masks, saliencies):
+                    # mask smallest k values to be removed
+                    k = int(len(keep_mask) * kwargs["sparsity_level"])
+                    prune = sal.topk(k, largest=False, sorted=False).indices
+                    keep_mask.data[prune] = False  # modifies underlying p.mask directly
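+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the upstream module) of the mask-splitting trick above:
+    # torch.split returns views, so writing into each chunk updates the shared packed mask,
+    # and every one of the 4 gate blocks ends up pruned by the same amount.
+    hidden_size, sparsity_level = 3, 0.5
+    mask = torch.ones(4 * hidden_size, dtype=torch.bool)
+    saliency = torch.arange(4 * hidden_size, dtype=torch.float)
+
+    for keep_mask, sal in zip(torch.split(mask, hidden_size), torch.split(saliency, hidden_size)):
+        k = int(len(keep_mask) * sparsity_level)
+        prune = sal.topk(k, largest=False).indices
+        keep_mask.data[prune] = False
+
+    print(mask.reshape(4, hidden_size))  # one False per gate block (its lowest-saliency entry)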
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1f995d96279eaaf34264706fdf91a0555341412
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py
@@ -0,0 +1,59 @@
+"""
+Contains utility functions to check if a pattern is in the graph and return the matching nodes
+"""
+import torch
+from torch import nn
+from torch.ao.quantization.utils import (
+    MatchAllNode,
+)
+from torch.fx import Node
+from torch.nn.utils import parametrize
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+def _match(modules: Dict[str, nn.ModuleDict], node: Node, current: Union[nn.Module, Any]) -> bool:
+    r"""
+    checks to see if a single node of a pattern matches
+    """
+    if isinstance(current, type) and issubclass(current, MatchAllNode):
+        return True
+    if not isinstance(node, Node):
+        return False
+    if isinstance(current, type) and issubclass(current, torch.nn.Module):
+        return (
+            node.op == "call_module"
+            and parametrize.type_before_parametrizations(modules[node.target])
+            == current
+        )
+    elif callable(current):
+        return node.op == "call_function" and node.target is current
+    elif isinstance(current, str):
+        return node.target == current
+    return False
+
+def apply_match(
+    modules: Dict[str, nn.ModuleDict],
+    pattern: Union[Tuple[Any], Any],
+    node: Node,
+    matched_node_pattern: List[Node],
+) -> Optional[List[Node]]:
+    r"""
+    This function will return the matched nodes if the pattern matches the node given
+    If there is no match, it will return None
+    """
+    if isinstance(pattern, tuple):
+        if len(pattern) == 1:
+            if _match(modules, node, pattern[0]):
+                return matched_node_pattern + [node]
+
+        first, *rest = pattern
+        if _match(modules, node, first):
+            if rest is None:
+                return matched_node_pattern + [node]
+
+            for user in node.users:
+                return apply_match(
+                    modules, tuple(rest), user, matched_node_pattern + [node]
+                )
+    elif _match(modules, node, pattern):
+        return [node]
+    return None
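+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the upstream module): match the (Linear, ReLU, Linear)
+    # pattern against a traced toy model and print the node targets of the matched chain.
+    from torch.fx import symbolic_trace
+
+    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
+    traced = symbolic_trace(model)
+    modules = dict(traced.named_modules())
+    pattern = (nn.Linear, nn.ReLU, nn.Linear)
+
+    for node in traced.graph.nodes:
+        matched = apply_match(modules, pattern, node, [])
+        if matched is not None:
+            print([n.target for n in matched])  # e.g. ['0', '1', '2']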
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e87299cd894c0731518a37bc224376e84e130ed
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py
@@ -0,0 +1,59 @@
+import torch
+from torch import nn
+from torch.nn.utils.parametrize import is_parametrized
+
+
+def module_contains_param(module, parametrization):
+    if is_parametrized(module):
+        # see if any of the module tensors have a parametrization attached that matches the one passed in
+        return any(
+            any(isinstance(param, parametrization) for param in param_list)
+            for key, param_list in module.parametrizations.items()
+        )
+    return False
+
+
+# Structured Pruning Parameterizations
+class FakeStructuredSparsity(nn.Module):
+    r"""
+    Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to
+    the  'weight' or any other parameter that requires a mask.
+
+    Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask.
+    """
+
+    def __init__(self, mask):
+        super().__init__()
+        self.register_buffer("mask", mask)
+
+    def forward(self, x):
+        assert isinstance(self.mask, torch.Tensor)
+        assert self.mask.shape[0] == x.shape[0]
+        shape = [1] * len(x.shape)
+        shape[0] = -1
+        return self.mask.reshape(shape) * x
+
+    def state_dict(self, *args, **kwargs):
+        # avoid double saving masks
+        return {}
+
+
+class BiasHook:
+    def __init__(self, parametrization, prune_bias):
+        self.param = parametrization
+        self.prune_bias = prune_bias
+
+    def __call__(self, module, input, output):
+
+        if getattr(module, "_bias", None) is not None:
+            bias = module._bias.data
+            if self.prune_bias:
+                bias[~self.param.mask] = 0
+
+            # reshape bias to broadcast over output dimensions
+            idx = [1] * len(output.shape)
+            idx[1] = -1
+            bias = bias.reshape(idx)
+
+            output += bias
+        return output
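+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the upstream module): attach FakeStructuredSparsity to a
+    # small Linear and observe that the masked row of the weight is zeroed out.
+    from torch.nn.utils import parametrize
+
+    linear = nn.Linear(4, 3, bias=False)
+    row_mask = torch.tensor([True, False, True])
+    parametrize.register_parametrization(linear, "weight", FakeStructuredSparsity(row_mask))
+
+    print(linear.weight[1])                                        # all zeros: row 1 is masked
+    print(module_contains_param(linear, FakeStructuredSparsity))   # True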
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..55fb7a973ae0a25c4bb12fd99245c45c740c4aaf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py
@@ -0,0 +1,475 @@
+"""
+Collection of conversion functions for linear / conv2d structured pruning
+Also contains utilities for bias propagation
+"""
+from typing import cast, List, Optional, Callable, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn.utils import parametrize
+from torch.nn.utils.parametrize import ParametrizationList
+from .parametrization import FakeStructuredSparsity, BiasHook
+
+# BIAS PROPAGATION
+def _remove_bias_handles(module: nn.Module) -> None:
+    if hasattr(module, "_forward_hooks"):
+        bias_hooks: List[int] = []
+        for key, hook in module._forward_hooks.items():
+            if isinstance(hook, BiasHook):
+                bias_hooks.append(key)
+
+        for key in bias_hooks:
+            del module._forward_hooks[key]
+
+
+def _get_adjusted_next_layer_bias(
+    next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor
+) -> nn.Parameter:
+    r"""Returns new adjusted bias for the second supported module"""
+    if parametrize.is_parametrized(next_layer):
+        # need to access original weight
+        parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations)
+        weight_parameterizations = cast(
+            ParametrizationList, parametrization_dict.weight
+        )
+        next_weight = weight_parameterizations.original
+    else:
+        next_weight = cast(Tensor, next_layer.weight)
+
+    scaling_weight = next_weight[:, ~mask]
+    if isinstance(next_layer, nn.Conv2d):  # checking for Conv2d
+        # Propagating first layer pruned biases and calculating the new second layer bias
+        # involves more steps since the Conv2d scaling weight has extra dimensions,
+        # so adding bias involves broadcasting, logically:
+        # for each channel k in range(oC):
+        #     scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T)
+        #     new_next_bias[k] = old_next_bias[k] + scaled_biases
+        scaling_product = torch.matmul(
+            pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2)
+        )
+        sum_range = list(range(len(scaling_product.shape)))[
+            1:
+        ]  # all but the first dimension
+        scaled_biases = torch.sum(scaling_product, sum_range)
+    elif isinstance(next_layer, nn.Linear):  # Linear
+        scaled_biases = torch.matmul(
+            pruned_biases, torch.transpose(scaling_weight, 0, 1)
+        )  # recall b2_new = b1 @ w2.T + b2
+    else:
+        raise NotImplementedError(f"Type {type(next_layer)} not supported yet.")
+
+    if (
+        parametrize.is_parametrized(next_layer)
+        and getattr(next_layer, "_bias", None) is not None
+    ):  # next_layer is parametrized & has original bias ._bias
+        adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias)
+    elif (
+        not parametrize.is_parametrized(next_layer) and next_layer.bias is not None
+    ):  # next_layer not parametrized & has .bias
+        adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias)
+    else:  # next_layer has no bias
+        adjusted_bias = nn.Parameter(scaled_biases)
+    return adjusted_bias
+
+
+def _prune_module_bias(module: nn.Module, mask: Tensor) -> None:
+    r"""Applies mask to given modules bias"""
+    # prune bias along with weights, discard pruned indices of bias
+    original_bias = cast(Tensor, getattr(module, "_bias", module.bias))
+    if original_bias is not None:
+        module.bias = nn.Parameter(original_bias[mask])
+
+    #  remove _bias parameter
+    if hasattr(module, "_bias"):
+        delattr(module, "_bias")
+
+
+def _propogate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
+    r"""
+    In the case that we need to propagate biases, this function will return the biases we need
+    """
+    # set current module bias
+    if module.bias is not None:
+        module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
+    elif getattr(module, "_bias", None) is not None:
+        module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
+
+    # get pruned biases to propagate to subsequent layer
+    if getattr(module, "_bias", None) is not None:
+        pruned_biases = cast(Tensor, module._bias)[~mask]
+    else:
+        pruned_biases = None
+
+    if hasattr(module, "_bias"):
+        delattr(module, "_bias")
+
+    return pruned_biases
+
+
+# LINEAR
+def _prune_linear_helper(linear: nn.Linear) -> Tensor:
+    # expects linear to be a parameterized linear module
+    parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
+    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
+    for p in weight_parameterizations:
+        if isinstance(p, FakeStructuredSparsity):
+            mask = cast(Tensor, p.mask)
+
+    with torch.no_grad():
+        parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
+        linear.weight = nn.Parameter(linear.weight[mask])  # type: ignore[possibly-undefined]
+    linear.out_features = linear.weight.shape[0]
+    _remove_bias_handles(linear)
+
+    return mask
+
+
+def prune_linear(linear: nn.Linear) -> None:
+    mask = _prune_linear_helper(linear)
+    if getattr(linear, "prune_bias", False):
+        _prune_module_bias(linear, mask)
+
+
+def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None:
+    prune_linear_activation_linear(linear1, None, linear2)
+
+
+def prune_linear_activation_linear(
+    linear1: nn.Linear,
+    activation: Optional[Callable[[Tensor], Tensor]],
+    linear2: nn.Linear,
+):
+    mask = _prune_linear_helper(linear1)
+    if getattr(linear1, "prune_bias", False):
+        _prune_module_bias(linear1, mask)
+    else:
+        pruned_biases = _propogate_module_bias(linear1, mask)
+        if pruned_biases is not None:
+            if activation:
+                pruned_biases = activation(pruned_biases)
+            linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask)
+
+    with torch.no_grad():
+        if parametrize.is_parametrized(linear2):
+            parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations)
+            weight_parameterizations = cast(
+                ParametrizationList, parametrization_dict.weight
+            )
+
+            weight_parameterizations.original = nn.Parameter(
+                weight_parameterizations.original[:, mask]
+            )
+            linear2.in_features = weight_parameterizations.original.shape[1]
+        else:
+            linear2.weight = nn.Parameter(linear2.weight[:, mask])
+            linear2.in_features = linear2.weight.shape[1]
+
+
+# CONV2D
+def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
+    parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations)
+    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
+    for p in weight_parameterizations:
+        if isinstance(p, FakeStructuredSparsity):
+            mask = cast(Tensor, p.mask)
+
+    with torch.no_grad():
+        parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True)
+        conv2d.weight = nn.Parameter(conv2d.weight[mask])  # type: ignore[possibly-undefined]
+    conv2d.out_channels = conv2d.weight.shape[0]
+
+    _remove_bias_handles(conv2d)
+    return mask
+
+
+def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
+    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
+    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
+    for p in weight_parameterizations:
+        if isinstance(p, FakeStructuredSparsity):
+            mask = cast(Tensor, p.mask)
+
+    with torch.no_grad():
+        parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True)
+
+    if getattr(conv2d_1, "_bias", None) is not None:
+        if (
+            conv2d_1.bias is not None
+        ):  # conv2d_1 has original bias and bias propagated from previous layer
+            new_bias = torch.zeros(conv2d_1.bias.shape)
+            new_bias[mask] = conv2d_1.bias[mask]  # type: ignore[possibly-undefined]
+            # adjusted bias to keep in conv2d_1
+            new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
+            # pruned biases that are kept instead of propagated
+            conv2d_1.bias = nn.Parameter(new_bias)
+        else:  # conv2d_1 has only original bias
+            conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias))
+    else:
+        # no original bias, only propagated bias
+        if (
+            conv2d_1.bias is not None
+        ):  # conv2d_1 has bias propagated from previous layer
+            conv2d_1.bias.data[~mask] = 0  # type: ignore[possibly-undefined]
+
+    if hasattr(conv2d_1, "_bias"):
+        delattr(conv2d_1, "_bias")
+
+
+def prune_conv2d(conv2d: nn.Conv2d) -> None:
+    mask = _prune_conv2d_helper(conv2d)
+    if getattr(conv2d, "prune_bias", False):
+        _prune_module_bias(conv2d, mask)
+
+
+def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None:
+    prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2)
+
+
+def prune_conv2d_activation_conv2d(
+    conv2d_1: nn.Conv2d,
+    activation: Optional[Callable[[Tensor], Tensor]],
+    conv2d_2: nn.Conv2d,
+):
+    r"""
+    Fusion Pattern for conv2d -> some activation module / function -> conv2d layers
+    """
+    parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations)
+    weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight)
+    for p in weight_parameterizations:
+        if isinstance(p, FakeStructuredSparsity):
+            mask = cast(Tensor, p.mask)
+
+    prune_bias = getattr(conv2d_1, "prune_bias", False)
+    if (
+        hasattr(conv2d_2, "padding")
+        and cast(Tuple[int], conv2d_2.padding) > (0, 0)
+        and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None)
+    ):
+        prune_conv2d_padded(conv2d_1)
+    else:
+        mask = _prune_conv2d_helper(conv2d_1)
+        if prune_bias:
+            _prune_module_bias(conv2d_1, mask)
+        else:
+            pruned_biases = _propogate_module_bias(conv2d_1, mask)
+            if pruned_biases is not None:
+                if activation:
+                    pruned_biases = activation(pruned_biases)
+                conv2d_2.bias = _get_adjusted_next_layer_bias(
+                    conv2d_2, pruned_biases, mask
+                )
+
+        if (
+            not (
+                hasattr(conv2d_2, "padding")
+                and cast(Tuple[int], conv2d_2.padding) > (0, 0)
+            )
+            or conv2d_1.bias is None
+        ):
+            with torch.no_grad():
+                if parametrize.is_parametrized(conv2d_2):
+                    parametrization_dict = cast(
+                        nn.ModuleDict, conv2d_2.parametrizations
+                    )
+                    weight_parameterizations = cast(
+                        ParametrizationList, parametrization_dict.weight
+                    )
+                    weight_parameterizations.original = nn.Parameter(
+                        weight_parameterizations.original[:, mask]
+                    )
+                    conv2d_2.in_channels = weight_parameterizations.original.shape[1]
+                else:
+                    conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask])
+                    conv2d_2.in_channels = conv2d_2.weight.shape[1]
+
+
+def prune_conv2d_pool_activation_conv2d(
+    c1: nn.Conv2d,
+    pool: nn.Module,
+    activation: Optional[Callable[[Tensor], Tensor]],
+    c2: nn.Conv2d,
+) -> None:
+    prune_conv2d_activation_conv2d(c1, activation, c2)
+
+
+def prune_conv2d_activation_pool_conv2d(
+    c1: nn.Conv2d,
+    activation: Optional[Callable[[Tensor], Tensor]],
+    pool: nn.Module,
+    c2: nn.Conv2d,
+) -> None:
+    prune_conv2d_activation_conv2d(c1, activation, c2)
+
+
+def prune_conv2d_pool_flatten_linear(
+    conv2d: nn.Conv2d,
+    pool: nn.Module,
+    flatten: Optional[Callable[[Tensor], Tensor]],
+    linear: nn.Linear,
+) -> None:
+    mask = _prune_conv2d_helper(conv2d)
+
+    # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer.
+    # We determine the flattening scale (h * w) and expand the row mask accordingly
+    # (each conv output channel idx maps to the flattened range idx * h * w to (idx + 1) * h * w);
+    # the pruned biases are likewise repeated h * w times before being propagated.
+    if parametrize.is_parametrized(linear):
+        parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
+        weight_parameterizations = cast(
+            ParametrizationList, parametrization_dict.weight
+        )
+        linear_ic = weight_parameterizations.original.shape[1]
+    else:
+        linear_ic = linear.weight.shape[1]
+
+    conv2d_oc = len(mask)
+    assert (
+        linear_ic % conv2d_oc == 0
+    ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported"
+
+    flatten_scale = linear_ic // conv2d_oc
+    flattened_mask = torch.tensor(
+        [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device
+    ).flatten()
+
+    if getattr(conv2d, "prune_bias", False):
+        _prune_module_bias(conv2d, mask)
+    else:
+        pruned_biases = cast(Tensor, _propogate_module_bias(conv2d, mask))
+        flattened_pruned_biases = torch.tensor(
+            [[bias] * flatten_scale for bias in pruned_biases], device=mask.device
+        ).flatten()
+        linear.bias = _get_adjusted_next_layer_bias(
+            linear, flattened_pruned_biases, flattened_mask
+        )
+
+    with torch.no_grad():
+        if parametrize.is_parametrized(linear):
+            parametrization_dict = cast(nn.ModuleDict, linear.parametrizations)
+            weight_parameterizations = cast(
+                ParametrizationList, parametrization_dict.weight
+            )
+            weight_parameterizations.original = nn.Parameter(
+                weight_parameterizations.original[:, flattened_mask]
+            )
+            linear.in_features = weight_parameterizations.original.shape[1]
+        else:
+            linear.weight = nn.Parameter(linear.weight[:, flattened_mask])
+            linear.in_features = linear.weight.shape[1]
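+
+# Illustrative sketch of the flattened-mask expansion performed in
+# ``prune_conv2d_pool_flatten_linear`` above; the values are made up and only
+# show how each channel-mask entry is repeated h * w times:
+#
+#     >>> import torch
+#     >>> mask = [True, False, True]   # Conv2d output-channel mask
+#     >>> flatten_scale = 2            # h * w after pooling
+#     >>> torch.tensor([[v] * flatten_scale for v in mask]).flatten()
+#     tensor([ True,  True, False, False,  True,  True])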
+
+
+def prune_lstm_output_linear(
+    lstm: nn.LSTM, getitem: Callable, linear: nn.Linear
+) -> None:
+    prune_lstm_output_layernorm_linear(lstm, getitem, None, linear)
+
+
+def prune_lstm_output_layernorm_linear(
+    lstm: nn.LSTM,
+    getitem: Callable,
+    layernorm: Optional[nn.LayerNorm],
+    linear: nn.Linear,
+) -> None:
+    for i in range(lstm.num_layers):
+        if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"):
+            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
+            weight_parameterizations = cast(
+                ParametrizationList, parametrization_dict[f"weight_ih_l{i}"]
+            )
+            mask = weight_parameterizations[0].mask
+
+            with torch.no_grad():
+                parametrize.remove_parametrizations(
+                    lstm, f"weight_ih_l{i}", leave_parametrized=True
+                )
+                setattr(
+                    lstm,
+                    f"weight_ih_l{i}",
+                    nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]),
+                )
+                setattr(
+                    lstm,
+                    f"bias_ih_l{i}",
+                    nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]),
+                )
+
+        if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"):
+            parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations)
+            weight_parameterizations = cast(
+                ParametrizationList, parametrization_dict[f"weight_hh_l{i}"]
+            )
+            mask = weight_parameterizations[0].mask
+
+            with torch.no_grad():
+                parametrize.remove_parametrizations(
+                    lstm, f"weight_hh_l{i}", leave_parametrized=True
+                )
+                # splitting out hidden-hidden masks
+                W_hi, W_hf, W_hg, W_ho = torch.split(
+                    getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size
+                )
+                M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size)
+
+                # resize each individual weight separately
+                W_hi = W_hi[M_hi][:, M_hi]
+                W_hf = W_hf[M_hf][:, M_hf]
+                W_hg = W_hg[M_hg][:, M_hg]
+                W_ho = W_ho[M_ho][:, M_ho]
+
+                # concat, use this as new weight
+                new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho))
+                setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight))
+                setattr(
+                    lstm,
+                    f"bias_hh_l{i}",
+                    nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]),
+                )
+
+            # If this is the final layer, then we need to prune linear layer columns
+            if i + 1 == lstm.num_layers:
+                lstm.hidden_size = int(M_hi.sum())
+                with torch.no_grad():
+                    if parametrize.is_parametrized(linear):
+                        parametrization_dict = cast(
+                            nn.ModuleDict, linear.parametrizations
+                        )
+                        weight_parameterizations = cast(
+                            ParametrizationList, parametrization_dict.weight
+                        )
+
+                        weight_parameterizations.original = nn.Parameter(
+                            weight_parameterizations.original[:, M_ho]
+                        )
+                        linear.in_features = weight_parameterizations.original.shape[1]
+                    else:
+                        linear.weight = nn.Parameter(linear.weight[:, M_ho])
+                        linear.in_features = linear.weight.shape[1]
+
+                    # if layernorm module, prune weight and bias
+                    if layernorm is not None:
+                        layernorm.normalized_shape = (linear.in_features,)
+                        layernorm.weight = nn.Parameter(layernorm.weight[M_ho])
+                        layernorm.bias = nn.Parameter(layernorm.bias[M_ho])
+
+            # otherwise need to prune the columns of the input of the next LSTM layer
+            else:
+                with torch.no_grad():
+                    if parametrize.is_parametrized(lstm, f"weight_ih_l{i+1}"):
+                        parametrization_dict = cast(
+                            nn.ModuleDict, lstm.parametrizations
+                        )
+                        weight_parameterizations = cast(
+                            ParametrizationList,
+                            getattr(parametrization_dict, f"weight_ih_l{i+1}"),
+                        )
+
+                        weight_parameterizations.original = nn.Parameter(
+                            weight_parameterizations.original[:, M_ho]
+                        )
+                    else:
+                        next_layer_weight = getattr(lstm, f"weight_ih_l{i+1}")
+                        setattr(
+                            lstm,
+                            f"weight_ih_l{i+1}",
+                            nn.Parameter(next_layer_weight[:, M_ho]),
+                        )
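+
+# Illustrative sketch of the hidden-hidden gate handling in
+# ``prune_lstm_output_layernorm_linear`` above: the LSTM stacks its four gate
+# weights as a (4 * hidden_size, hidden_size) matrix, so each gate block is
+# masked separately and re-concatenated (shapes only, values are random):
+#
+#     >>> import torch
+#     >>> hidden_size = 4
+#     >>> W_hh = torch.randn(4 * hidden_size, hidden_size)
+#     >>> mask = torch.tensor([True, True, False, True] * 4)  # one entry per row
+#     >>> gates = torch.split(W_hh, hidden_size)
+#     >>> gate_masks = torch.split(mask, hidden_size)
+#     >>> new_W_hh = torch.cat([W[m][:, m] for W, m in zip(gates, gate_masks)])
+#     >>> new_W_hh.shape  # one hidden unit pruned from every gate
+#     torch.Size([12, 3])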
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe196576a9f50157ca088380e573b99988cee574
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py
@@ -0,0 +1,29 @@
+from .base_structured_sparsifier import BaseStructuredSparsifier
+
+
+class SaliencyPruner(BaseStructuredSparsifier):
+    """
+    Prune rows based on the saliency (L1 norm) of each row.
+
+    This pruner works on N-Dimensional weight tensors.
+    For each row, we will calculate the saliency, which is the sum of the L1 norms of all weights in that row.
+    We expect that the resulting saliency vector has the same shape as our mask.
+    We then pick elements to remove until we reach the target sparsity_level.
+    """
+
+    def update_mask(self, module, tensor_name, **kwargs):
+        # tensor_name will give you the FQN; all other entries in the sparse config are present in kwargs
+        weights = getattr(module, tensor_name)
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+
+        # use the negated per-row norms so we can use topk (we prune the rows with the smallest saliency)
+        if weights.dim() <= 1:
+            raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!")
+        saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
+        assert saliency.shape == mask.shape
+
+        num_to_pick = int(len(mask) * kwargs["sparsity_level"])
+        prune = saliency.topk(num_to_pick).indices
+
+        # Set the mask to be false for the rows we want to prune
+        mask.data[prune] = False
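+
+# Illustrative sketch of the row-saliency computation above on a standalone
+# tensor (the values are made up):
+#
+#     >>> import torch
+#     >>> weights = torch.tensor([[1., -1., 1.], [0.1, 0.2, 0.1], [3., 3., 3.]])
+#     >>> saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1)
+#     >>> saliency
+#     tensor([-3.0000, -0.4000, -9.0000])
+#     >>> saliency.topk(1).indices  # the row with the smallest L1 norm is pruned
+#     tensor([1])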
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/_mappings.py b/MLPY/Lib/site-packages/torch/ao/pruning/_mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaad70fb7b9e6da48bfbd1ad31653efdc6e24bb9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/_mappings.py
@@ -0,0 +1,18 @@
+__all__ = [
+    "get_static_sparse_quantized_mapping",
+    "get_dynamic_sparse_quantized_mapping",
+]
+
+def get_static_sparse_quantized_mapping():
+    import torch.ao.nn.sparse
+    _static_sparse_quantized_mapping = {
+        torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear,
+    }
+    return _static_sparse_quantized_mapping
+
+def get_dynamic_sparse_quantized_mapping():
+    import torch.ao.nn.sparse
+    _dynamic_sparse_quantized_mapping = {
+        torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear,
+    }
+    return _dynamic_sparse_quantized_mapping
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d73b0664ecf49ca2fa809bc87ab1184cde94686
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76a9b97c3b0e336562af0f3995b62f7f08ed9ecb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d7b50bbacda61fb3fb3de930b374ee643a8129d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3460022221cef516fb1a969d22ad759a1b11fce
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaec38eb3112d756b52f302106377609a4b92317
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/base_scheduler.py
@@ -0,0 +1,159 @@
+
+from torch.ao.pruning import BaseSparsifier
+
+from functools import wraps
+import warnings
+import weakref
+
+__all__ = ["BaseScheduler"]
+
+class BaseScheduler:
+
+    def __init__(self, sparsifier, last_epoch=-1, verbose=False):
+
+        # Attach sparsifier
+        if not isinstance(sparsifier, BaseSparsifier):
+            raise TypeError(f'{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier')
+        self.sparsifier = sparsifier
+
+        # Initialize epoch and base sparsity levels
+
+        self.base_sl = [group['sparsity_level'] for group in sparsifier.groups]
+        self.last_epoch = last_epoch
+
+        # Following https://github.com/pytorch/pytorch/issues/20124
+        # We would like to ensure that `scheduler.step()` is called after
+        # `sparsifier.step()`
+        def with_counter(method):
+            if getattr(method, '_with_counter', False):
+                # `sparsifier.step()` has already been replaced, return.
+                return method
+
+            # Keep a weak reference to the sparsifier instance to prevent
+            # cyclic references.
+            instance_ref = weakref.ref(method.__self__)
+            # Get the unbound method for the same purpose.
+            func = method.__func__
+            cls = instance_ref().__class__
+            del method
+
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                instance = instance_ref()
+                instance._step_count += 1  # type: ignore[union-attr]
+                wrapped = func.__get__(instance, cls)
+                return wrapped(*args, **kwargs)
+
+            # Note that the returned function here is no longer a bound method,
+            # so attributes like `__func__` and `__self__` no longer exist.
+            wrapper._with_counter = True  # type: ignore[attr-defined]
+            return wrapper
+
+        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
+        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
+        self._step_count: int = 0
+        self.verbose = verbose
+
+        # Housekeeping
+        self._get_sl_called_within_step: bool = False
+
+        self.step()
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the sparsifier.
+        """
+        return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'}
+
+    def load_state_dict(self, state_dict):
+        """Loads the schedulers state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_sl(self):
+        """ Return last computed sparsity level by current scheduler.
+        """
+        return self._last_sl
+
+    def get_sl(self):
+        # Compute sparsity level using chainable form of the scheduler
+        # Note: This method is not intended to be called directly, and is only
+        #       used by the ".step" method. Use .get_last_sl() instead.
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`.")
+        raise NotImplementedError
+
+    def print_sl(self, is_verbose, group, sl, epoch=None):
+        """Display the current sparsity level.
+        """
+        if is_verbose:
+            if epoch is None:
+                print(f'Adjusting sparsity level of group {group} to {sl:.4e}.')
+            else:
+                print(f'Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}.')
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + ' ('
+        format_string += '\n'
+        format_string += f'Sparsifier {self.sparsifier}\n'
+        format_string += f'    base_sl: {self.base_sl}\n'
+        format_string += ')'
+        return format_string
+
+    def step(self, epoch=None):
+        # Raise warning if trying to call scheduler step before the sparsifier.
+        # https://github.com/pytorch/pytorch/issues/20124
+        if self._step_count == 1:
+            if not hasattr(self.sparsifier.step, "_with_counter"):
+                warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
+                              "initialization. Please, make sure to call `sparsifier.step()` before "
+                              "`scheduler.step()`.", UserWarning)
+
+            # Just check if there were two first scheduler.step() calls before sparsifier.step()
+            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
+                warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
+                              "You have to make sure you run the sparsifier.step() BEFORE any "
+                              "calls to the scheduler.step().", UserWarning)
+        self._step_count += 1
+
+        class _enable_get_sl_call:
+
+            def __init__(self, o):
+                self.o = o
+
+            def __enter__(self):
+                self.o._get_sl_called_within_step = True
+                return self
+
+            def __exit__(self, type, value, traceback):
+                self.o._get_sl_called_within_step = False
+
+        with _enable_get_sl_call(self):
+            self.last_epoch += 1
+            values = self.get_sl()
+
+        for i, data in enumerate(zip(self.sparsifier.groups, values)):
+            param_group, sl = data
+            param_group['sparsity_level'] = sl
+            self.print_sl(self.verbose, i, sl, epoch)
+
+        self._last_sl = [group['sparsity_level'] for group in self.sparsifier.groups]
+        self.sparsifier.enable_mask_update = True
+
+    def _make_sure_a_list(self, var):
+        r"""Utility that extends it to the same length as the .groups, ensuring it is a list"""
+        n = len(self.sparsifier.groups)
+        if not isinstance(var, (list, tuple)):
+            return [var] * n
+        else:
+            if len(var) != n:
+                raise ValueError(f"Expected variable of length {n}, but got {len(var)}")
+            return list(var)  # We want the result to be in a list, not tuple
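+
+# Illustrative sketch of the call order that the warnings in ``step()`` enforce;
+# ``MySparsifier``, ``MyScheduler`` and the training helpers below are
+# placeholders, not names defined in this module:
+#
+#     sparsifier = MySparsifier(defaults={"sparsity_level": 0.5})
+#     sparsifier.prepare(model, config)
+#     scheduler = MyScheduler(sparsifier)   # any BaseScheduler subclass
+#     for epoch in range(num_epochs):
+#         train_one_epoch(model)
+#         sparsifier.step()                 # update the masks first ...
+#         scheduler.step()                  # ... then advance the sparsity schedule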
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fac7f8212478892e94322873f870bea5722a657e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py
@@ -0,0 +1,107 @@
+import warnings
+
+from .base_scheduler import BaseScheduler
+
+__all__ = ["CubicSL"]
+
+def _clamp(x, lo, hi):
+    return max(lo, min(hi, x))
+
+
+class CubicSL(BaseScheduler):
+    r"""Sets the sparsity level of each parameter group to the final sl
+    plus a given exponential function.
+
+    .. math::
+
+        s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3
+
+    where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final
+    sparsity level, :math:`f(i)` is the function to be applied to the current epoch
+    :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`.
+    :math:`\Delta t` is used to control how often the update of the sparsity level
+    happens. By default,
+
+    Args:
+        sparsifier (BaseSparsifier): Wrapped sparsifier.
+        init_sl (float, list): Initial level of sparsity
+        init_t (int, list): Initial step, when pruning starts
+        delta_t (int, list): Pruning frequency
+        total_t (int, list): Total number of pruning steps
+        initially_zero (bool, list): If True, sets the level of sparsity to 0
+            before init_t (:math:`t_0`). Otherwise, the sparsity level before
+            init_t (:math:`t_0`) is set to init_sl (:math:`s_0`).
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    """
+    def __init__(self,
+                 sparsifier,
+                 init_sl=0.0,
+                 init_t=0,
+                 delta_t=10,
+                 total_t=100,
+                 initially_zero=False,
+                 last_epoch=-1,
+                 verbose=False
+                 ):
+        self.sparsifier = sparsifier
+
+        self.init_sl = self._make_sure_a_list(init_sl)
+        self.init_t = self._make_sure_a_list(init_t)
+        self.delta_t = self._make_sure_a_list(delta_t)
+        self.total_t = self._make_sure_a_list(total_t)
+
+        self.initially_zero = self._make_sure_a_list(initially_zero)
+
+        super().__init__(sparsifier, last_epoch, verbose)
+
+    @staticmethod
+    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
+        r""""Computes the current level of sparsity.
+
+        Based on https://arxiv.org/pdf/1710.01878.pdf
+
+        Args:
+            s_0: Initial level of sparsity, :math:`s_0`
+            s_f: Target level of sparsity, :math:`s_f`
+            t: Current step, :math:`t`
+            t_0: Initial step, :math:`t_0`
+            dt: Pruning frequency, :math:`\Delta T`
+            n: Pruning steps, :math:`n`
+            initially_zero: Sets the level of sparsity to 0 before t_0.
+                If False, sets to s_0
+
+        Returns:
+            The sparsity level :math:`s_t` at the current step :math:`t`
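+
+        Example (illustrative; outputs are rounded for readability, and the
+        helper can be called directly since it is a ``staticmethod``)::
+
+            >>> round(CubicSL.sparsity_compute_fn(s_0=0.0, s_f=0.9, t=5, t_0=0, dt=1, n=10), 4)
+            0.7875
+            >>> round(CubicSL.sparsity_compute_fn(s_0=0.0, s_f=0.9, t=20, t_0=0, dt=1, n=10), 4)
+            0.9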
+        """
+        if initially_zero and t < t_0:
+            return 0
+        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
+        s_t = _clamp(s_t, s_0, s_f)
+        return s_t
+
+    def get_sl(self):
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`.")
+        return [
+            self.sparsity_compute_fn(
+                s_0=initial_sparsity,
+                s_f=final_sparsity,
+                t=self.last_epoch,
+                t_0=initial_epoch,
+                dt=delta_epoch,
+                n=interval_epochs,
+                initially_zero=initially_zero
+            ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in
+            zip(
+                self.init_sl,
+                self.base_sl,
+                self.init_t,
+                self.delta_t,
+                self.total_t,
+                self.initially_zero
+            )
+        ]
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..65bf3998757bfe6f35c1bb57d6281c016473fc0d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py
@@ -0,0 +1,47 @@
+import warnings
+
+from .base_scheduler import BaseScheduler
+
+__all__ = ["LambdaSL"]
+
+class LambdaSL(BaseScheduler):
+    """Sets the sparsity level of each parameter group to the final sl
+    times a given function. When last_epoch=-1, sets initial sl as zero.
+    Args:
+        sparsifier (BaseSparsifier): Wrapped sparsifier.
+        sl_lambda (function or list): A function which computes a multiplicative
+            factor given an integer parameter epoch, or a list of such
+            functions, one for each group in sparsifier.groups.
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    Example:
+        >>> # Assuming sparsifier has two groups.
+        >>> lambda1 = lambda epoch: epoch // 30
+        >>> lambda2 = lambda epoch: 0.95 ** epoch
+        >>> # xdoctest: +SKIP
+        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
+        self.sparsifier = sparsifier
+
+        if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
+            self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
+        else:
+            if len(sl_lambda) != len(sparsifier.groups):
+                raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}")
+            self.sl_lambdas = list(sl_lambda)
+        super().__init__(sparsifier, last_epoch, verbose)
+
+    def get_sl(self):
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`.")
+        return [base_sl * lmbda(self.last_epoch)
+                for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__init__.py b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..518459be9ee9c87f384d8686d8c6fb5b168749e1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbb8273c33af3b7a6f6ad335e33aa3b8b727e98b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15b94ee81735e01b3b6ee61146204755a009a133
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87557d0f95fde903085ef80887350effbe11950c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8537ad7d0c9bd51fdb1a2ab9abf50de06c6cd2db
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..89165aab967982bce9700b047f17421cb023c742
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py
@@ -0,0 +1,353 @@
+import abc
+import copy
+from collections import defaultdict
+from typing import Any, Dict, Optional, Set, Tuple, List, Type
+
+import torch
+from torch import nn
+from torch.nn.utils import parametrize
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from .utils import (
+    module_contains_param,
+    swap_module,
+    FakeSparsity,
+    get_arg_info_from_tensor_fqn,
+    module_to_fqn,
+)
+
+__all__ = ["BaseSparsifier"]
+
+SUPPORTED_MODULES = {nn.Linear}
+
+KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]
+
+__all__ = ["BaseSparsifier"]
+
+
+# TODO update desc with new config args
+class BaseSparsifier(abc.ABC):
+    r"""Base class for all sparsifiers.
+
+    Abstract methods that need to be implemented:
+
+    - update_mask: Function to compute a new mask for all keys in the
+        `groups`.
+
+    Args:
+        - model [nn.Module]: model to configure. The model itself is not saved
+            but used for the state_dict saving / loading.
+        - config [list]: configuration elements should be a dict map that includes
+            `tensor_fqn` of tensors to sparsify
+        - defaults [dict]: default configurations will be attached to the
+            configuration. Only the keys that don't exist in the `config` will
+            be updated.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
+        >>> config = [{'tensor_fqn': 'layer1.weight'}, {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
+        >>> defaults = {'sparsity_level': 0.7}
+        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
+        >>> sparsifier = BaseSparsifier(defaults)
+        >>> sparsifier.prepare(model, config)
+    """
+
+    def __init__(self, defaults: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        self.defaults: Dict[str, Any] = defaults or {}
+
+        self.state: Dict[str, Dict] = defaultdict(dict)
+        self.groups: List[Dict[str, Any]] = []
+        self.enable_mask_update = True
+
+    def __getstate__(self) -> Dict[str, Any]:
+        return {
+            "defaults": self.defaults,
+            "state": self.state,
+            "groups": self.groups,
+        }
+
+    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
+        self.__dict__.update(state)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + " ("
+        for i, sparse_args in enumerate(self.groups):
+            module = sparse_args["module"]
+            format_string += "\n"
+            format_string += f"\tGroup {i}\n"
+            format_string += f"\t    module: {module}\n"
+            for key in sorted(sparse_args.keys()):
+                if key == "module":
+                    continue
+                format_string += f"\t    {key}: {sparse_args[key]}\n"
+        format_string += ")"
+        return format_string
+
+    def state_dict(self) -> Dict[str, Any]:
+        r"""Returns the state of the optimizer as a :class:`dict`.
+
+        It contains:
+        * state - current state of the sparsification.
+        * groups - a list containing all sparsity configuration groups
+            with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model
+
+        TODO: Need a clean way of loading the state of the "prepared" module
+        """
+
+        groups: List[Dict[str, Any]] = [
+            dict(
+                filter(
+                    lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
+                    mg.items(),
+                )
+            )
+            for mg in self.groups
+        ]
+
+        return {
+            "state": self.state,
+            "groups": groups,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
+        groups = copy.deepcopy(state_dict["groups"])
+        states = state_dict["state"]
+        for tensor_fqn, s in states.items():
+            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
+            module = arg_info["module"]
+            tensor_name = arg_info["tensor_name"]
+            if strict and module is None:
+                raise RuntimeError(f"Error loading {tensor_fqn} into the model")
+
+            found = False
+            for p in module.parametrizations[tensor_name]:
+                if isinstance(p, FakeSparsity):
+                    found = True
+                    break
+            if not found:
+                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
+                parametrize.register_parametrization(module, tensor_name, p)
+            if s.get("mask", None) is not None:
+                mask = s.pop("mask")
+                p.mask = mask
+
+            for mg in groups:
+                if mg["tensor_fqn"] == tensor_fqn:
+                    mg.update(arg_info)
+        self.__setstate__({"state": states, "groups": groups})
+
+    def make_config_from_model(
+        self,
+        model: nn.Module,
+        SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES,
+    ) -> None:
+        self.config = []
+        stack = [model]
+        while stack:
+            module = stack.pop()
+            for name, child in module.named_children():
+                if type(child) in SUPPORTED_MODULES:
+                    module_fqn = module_to_fqn(model, child)
+                    assert isinstance(module_fqn, str)  # for mypy
+                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
+                else:
+                    stack.append(child)
+
+    def prepare(self, model, config):
+        r"""Prepares a model, by adding the parametrizations.
+
+        Note::
+
+            The model is modified inplace. If you need to preserve the original
+            model, use copy.deepcopy.
+        """
+        self.model = model  # TODO: Need to figure out how to load without this.
+        self.config = config
+
+        # If no config -- try getting all the supported layers
+        if self.config is None:
+            self.make_config_from_model(model)
+
+        # TODO: Remove the configuration by reference ('module')
+        for module_config in self.config:
+            assert isinstance(module_config, dict), (
+                "config elements should be dicts not modules i.e.:"
+                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
+            )
+
+            assert isinstance(self.defaults, Dict)  # for mypy
+            local_args = copy.deepcopy(self.defaults)
+            local_args.update(module_config)
+
+            tensor_fqn = local_args.get("tensor_fqn", None)
+            assert tensor_fqn is not None, (
+                "tensor_fqn is a required argument in the sparsity config which"
+                "replaces previous `module` and [module]`fqn` arguments"
+            )
+
+            # populate all information from tensor_fqn
+            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+
+            # check that whatever was put into local_args agrees with what was obtained
+            # from tensor_fqn
+            for key in info_from_tensor_fqn.keys():
+                if key in local_args:
+                    assert (
+                        info_from_tensor_fqn[key] == local_args[key]
+                        or (
+                            key == "tensor_fqn"
+                            and "." + info_from_tensor_fqn[key] == local_args[key]
+                        )
+                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
+                    ), (
+                        f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
+                    )
+            local_args.update(info_from_tensor_fqn)
+            self.groups.append(local_args)
+        self._prepare()
+
+    def _prepare(self, *args, **kwargs):
+        r"""Adds mask parametrization to the layer weight"""
+        for config in self.groups:
+            module = config["module"]
+            tensor_name = config["tensor_name"]
+            parametrization = config.get("parametrization", FakeSparsity)
+            mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
+            self.state[config["tensor_fqn"]]["mask"] = mask
+            parametrize.register_parametrization(
+                module, tensor_name, parametrization(mask)
+            )
+
+    def squash_mask(
+        self,
+        params_to_keep: Optional[Tuple[str, ...]] = None,
+        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
+        *args,
+        **kwargs,
+    ):
+        r"""Squashes the sparse masks into the appropriate tensors.
+
+        If either the `params_to_keep` or `params_to_keep_per_layer` is set,
+        the module will have a `sparse_params` dict attached to it.
+
+        Args:
+            params_to_keep: List of keys to save in the module or a dict
+                            representing the modules and keys that will have
+                            sparsity parameters saved
+            params_to_keep_per_layer: Dict to specify the params that should be
+                            saved for specific layers. The keys in the dict
+                            should be the module fqn, while the values should
+                            be a list of strings with the names of the variables
+                            to save in the `sparse_params`
+
+        Examples:
+            >>> # xdoctest: +SKIP("locals are undefined")
+            >>> # Don't save any sparse params
+            >>> sparsifier.squash_mask()
+            >>> hasattr(model.submodule1, 'sparse_params')
+            False
+
+            >>> # Keep sparse params per layer
+            >>> sparsifier.squash_mask(
+            ...     params_to_keep_per_layer={
+            ...         'submodule1.linear1': ('foo', 'bar'),
+            ...         'submodule2.linear42': ('baz',)
+            ...     })
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'baz': 0.1}
+
+            >>> # Keep sparse params for all layers
+            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'foo': 42, 'bar': 24}
+
+            >>> # Keep some sparse params for all layers, and specific ones for
+            >>> # some other layers
+            >>> sparsifier.squash_mask(
+            ...     params_to_keep=('foo', 'bar'),
+            ...     params_to_keep_per_layer={
+            ...         'submodule2.linear42': ('baz',)
+            ...     })
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'foo': 42, 'bar': 24, 'baz': 0.1}
+        """
+        for config in self.groups:
+            module = config["module"]
+            tensor_name = config["tensor_name"]
+            parametrize.remove_parametrizations(
+                module, tensor_name, leave_parametrized=True
+            )
+            sparse_params = {}
+            if params_to_keep is not None:
+                global_params = {k: config[k] for k in params_to_keep}
+                sparse_params.update(global_params)
+            if params_to_keep_per_layer is not None:
+                params = params_to_keep_per_layer.get(config["module_fqn"], None)
+                if params is not None:
+                    per_layer_params = {k: config[k] for k in params}
+                    sparse_params.update(per_layer_params)
+            if sparse_params:
+                # TODO handle multiple tensor being quantized on a single module, where to store sparse_params?
+                module.sparse_params = sparse_params
+
+    def convert(
+        self,
+        module: nn.Module,
+        mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
+        inplace: bool = False,
+        parameterization: Type[nn.Module] = FakeSparsity,
+    ):
+        r"""Converts submodules in input module to a different module according to `mapping`
+    by calling the `from_dense` method on the target module class.
+        Args:
+            module: input module
+            mapping: a dictionary that maps from source module type to target
+                module type, can be overwritten to allow swapping user defined
+                Modules
+            inplace: carry out model transformations in-place, the original module
+                is mutated
+        """
+        if mapping is None:
+            raise NotImplementedError("Need to auto generate mapping ")
+        if not inplace:
+            module = copy.deepcopy(module)
+
+        reassign = {}
+        for name, mod in module.named_children():
+            # leaf node
+            if (
+                module_contains_param(mod, parameterization)
+                and type_before_parametrizations(mod) in mapping
+            ):
+                reassign[name] = swap_module(mod, mapping)
+            else:
+                # recurse
+                reassign[name] = self.convert(
+                    mod,
+                    mapping=mapping,
+                    inplace=True,
+                    parameterization=parameterization,
+                )
+
+        for key, value in reassign.items():
+            module._modules[key] = value
+
+        return module
+
+    def step(self, use_path: bool = True) -> None:
+        if not self.enable_mask_update:
+            return
+        with torch.no_grad():
+            for config in self.groups:
+                self.update_mask(**config)
+
+    @abc.abstractmethod
+    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
+        pass
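+
+# Illustrative sketch of a minimal concrete subclass: ``update_mask`` is the
+# only abstract method, and ``ThresholdSparsifier`` together with its
+# ``threshold`` config entry are hypothetical names used only for this example:
+#
+#     >>> import torch
+#     >>> from torch import nn
+#     >>> class ThresholdSparsifier(BaseSparsifier):
+#     ...     def update_mask(self, module, tensor_name, threshold, **kwargs):
+#     ...         mask = getattr(module.parametrizations, tensor_name)[0].mask
+#     ...         mask.data = getattr(module, tensor_name).abs() > threshold
+#     >>> model = nn.Sequential(nn.Linear(4, 4))
+#     >>> sparsifier = ThresholdSparsifier(defaults={"threshold": 0.1})
+#     >>> sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
+#     >>> sparsifier.step()         # runs update_mask for every configured group
+#     >>> sparsifier.squash_mask()  # folds the masks back into the dense weights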
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..2664b3480cdf14e11d7c25beabedb035aab1306b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py
@@ -0,0 +1,55 @@
+import torch
+
+from . import base_sparsifier
+
+
+class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
+    r"""Nearly Diagonal Sparsifier
+
+    This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
+    Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero.
+    Examples of nearly diagonal matrices with degree (or nearliness) 3 and 5, respectively, follow.
+    1 1 0 0       1 1 1 0
+    1 1 1 0       1 1 1 1
+    0 1 1 1       1 1 1 1
+    0 0 1 1       0 1 1 1
+    Note that a nearly diagonal matrix with degree 1 is just a matrix with only the main diagonal populated.
+
+    This sparsifier is controlled by one variable:
+    1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal.
+        Currently, only odd numbers are supported.
+
+    Note:
+        This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix
+        feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy
+
+    Args:
+        nearliness: The degree of nearliness (default = 1)
+
+    """
+    def __init__(self, nearliness: int = 1):
+        defaults = {'nearliness': nearliness}
+        super().__init__(defaults=defaults)
+
+    def update_mask(self, module, tensor_name, nearliness,
+                    **kwargs):
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+        mask.data = torch.zeros_like(mask)
+        if nearliness <= 0:
+            return
+
+        tensor = getattr(module, tensor_name)
+        height, width = tensor.shape
+
+        if nearliness % 2 == 0:
+            raise ValueError("nearliness can only be an odd number")
+        dist_to_diagonal = nearliness // 2
+        # check
+        if dist_to_diagonal >= min(height, width):
+            raise ValueError("nearliness cannot be larger than the dimensions of tensor.")
+
+        for row in range(0, height):
+            # Bounds of entries that needs to be set to 1
+            low = max(0, row - dist_to_diagonal)
+            high = min(width, row + dist_to_diagonal + 1)
+            mask[row, low:high].fill_(1)
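+
+# Illustrative sketch of the banded mask built above, reproduced on a plain
+# 4x4 tensor with nearliness=3 (this matches the degree-3 example in the class
+# docstring):
+#
+#     >>> import torch
+#     >>> height = width = 4; dist_to_diagonal = 3 // 2
+#     >>> mask = torch.zeros(height, width)
+#     >>> for row in range(height):
+#     ...     mask[row, max(0, row - dist_to_diagonal):min(width, row + dist_to_diagonal + 1)] = 1
+#     >>> mask
+#     tensor([[1., 1., 0., 0.],
+#             [1., 1., 1., 0.],
+#             [0., 1., 1., 1.],
+#             [0., 0., 1., 1.]])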
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3df67b3b53b9548538ce5d298de3a976397c5f6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/utils.py
@@ -0,0 +1,136 @@
+from typing import Any, Dict, Optional, Type
+from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized
+from itertools import chain
+
+from torch import nn
+
+__all__ = [
+    "module_contains_param",
+    "swap_module",
+    "module_to_fqn",
+    "fqn_to_module",
+    "get_arg_info_from_tensor_fqn",
+    "FakeSparsity",
+]
+
+
+def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
+    if is_parametrized(module):
+        # see if any of the module tensors have a parametrization attached that matches the one passed in
+        return any(
+            any(isinstance(param, parametrization) for param in param_list)
+            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
+        )
+    return False
+
+
+def swap_module(
+    mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
+) -> nn.Module:
+    r"""Swaps the module using from_dense according to the mapping passed in.
+    Args:
+        mod: input module
+        mapping: a dictionary that maps from nn module to sparse nn module
+    Return:
+        The corresponding sparse module of `mod` according to mapping, created using from_dense
+    """
+    if type_before_parametrizations(mod) in mapping:
+        sparse_mod = mapping[type_before_parametrizations(mod)]
+
+        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
+        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]
+
+        # Preserve module's pre forward hooks. They'll be called on quantized input
+        for pre_hook_fn in mod._forward_pre_hooks.values():
+            new_mod.register_forward_pre_hook(pre_hook_fn)
+        # Preserve module's post forward hooks except _observer_forward_hook
+        # After convert they'll work with quantized output
+        for hook_fn in mod._forward_hooks.values():
+            new_mod.register_forward_hook(hook_fn)
+
+        # respect device affinity when swapping modules
+        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
+        assert len(devices) <= 1, (
+            f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
+        )
+        device = next(iter(devices)) if len(devices) > 0 else None
+        if device:
+            new_mod.to(device)
+
+        return new_mod
+
+    else:
+        return mod
+
+
+def module_to_fqn(
+    model: nn.Module, module: nn.Module, prefix: str = ""
+) -> Optional[str]:
+    """
+    Returns the fqn for a module, or None if the module is not a descendant of the model.
+    """
+    if module is model:
+        return ""
+    for name, child in model.named_children():
+        fqn = module_to_fqn(child, module, ".")
+        if isinstance(fqn, str):
+            return prefix + name + fqn
+    return None
+
+
+def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
+    """
+    Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
+    doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
+    """
+    if path != "":
+        for name in path.split("."):
+            model = getattr(model, name, None)
+    return model
+
+
+def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
+    """
+    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
+    """
+    # string manip to split tensor_fqn into module_fqn and tensor_name
+    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
+    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
+    tensor_name = tensor_fqn.split(".")[-1]
+    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
+
+    module = fqn_to_module(model, module_fqn)
+
+    return {
+        "module_fqn": module_fqn,
+        "module": module,
+        "tensor_name": tensor_name,
+        "tensor_fqn": tensor_fqn,
+    }
+
+
+# Parametrizations
+class FakeSparsity(nn.Module):
+    r"""Parametrization for the weights. Should be attached to the 'weight' or
+    any other parameter that requires a mask applied to it.
+
+    Note::
+
+        Once the mask is passed in, it should not be replaced by a different
+        tensor object (its id should not change). The contents of the mask may
+        change, but the mask reference itself should not.
+    """
+
+    def __init__(self, mask):
+        super().__init__()
+        self.register_buffer("mask", mask)
+
+    def forward(self, x):
+        assert self.mask.shape == x.shape
+        return self.mask * x
+
+    def state_dict(self, *args, **kwargs):
+        # We don't want the parametrization to save the mask.
+        # That way we make sure that the linear module doesn't store the masks
+        # alongside its parametrizations.
+        return {}
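+
+# Illustrative sketch of the fqn helpers above on a tiny model, assuming the
+# helpers defined in this module are in scope (the module name "0" comes from
+# nn.Sequential's default child naming):
+#
+#     >>> from torch import nn
+#     >>> model = nn.Sequential(nn.Linear(4, 4))
+#     >>> module_to_fqn(model, model[0])
+#     '0'
+#     >>> info = get_arg_info_from_tensor_fqn(model, "0.weight")
+#     >>> info["module_fqn"], info["tensor_name"]
+#     ('0', 'weight')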
diff --git a/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f1a47280a1fff3e169fb6f2b2dc69e8062132c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
@@ -0,0 +1,200 @@
+from functools import reduce
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+from .base_sparsifier import BaseSparsifier
+import operator
+
+__all__ = ["WeightNormSparsifier"]
+
+def _flat_idx_to_2d(idx, shape):
+    rows = idx // shape[1]
+    cols = idx % shape[1]
+    return rows, cols
+
+class WeightNormSparsifier(BaseSparsifier):
+    r"""Weight-Norm Sparsifier
+
+    This sparsifier computes the norm of every sparse block and "zeroes-out" the
+    ones with the lowest norm. The level of sparsity defines how many of the
+    blocks are removed.
+
+    This sparsifier is controlled by three variables:
+    1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out
+    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
+        the sparse blocks originate at the zero-index of the tensor.
+    3. `zeros_per_block` is the number of zeros that we are expecting in each
+        sparse block. By default we assume that all elements within a block are
+        zeroed-out. However, setting this variable sets the target number of
+        zeros per block. The zeros within each block are chosen as the *smallest
+        absolute values*.
+
+    Args:
+
+        sparsity_level: The target level of sparsity
+        sparse_block_shape: The shape of a sparse block (see note below)
+        zeros_per_block: Number of zeros in a sparse block
+        norm: Norm to use. Could be either `int` or a callable.
+            If `int`, only L1 and L2 are implemented.
+
+    Note::
+        The `sparse_block_shape` is a tuple representing (block_ROWS, block_COLS),
+        irrespective of what the rows / cols mean in the data tensor. That means,
+        if you were to sparsify a weight tensor in the nn.Linear, which has a
+        weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
+        channels, while the `block_COLS` would refer to the input channels.
+
+    Note::
+        All arguments to the WeightNormSparsifier constructor are "default"
+        arguments and can be overridden by the configuration provided in the
+        `prepare` step.
+    """
+    def __init__(self,
+                 sparsity_level: float = 0.5,
+                 sparse_block_shape: Tuple[int, int] = (1, 4),
+                 zeros_per_block: Optional[int] = None,
+                 norm: Optional[Union[Callable, int]] = None):
+        if zeros_per_block is None:
+            zeros_per_block = reduce(operator.mul, sparse_block_shape)
+        defaults = {
+            "sparsity_level": sparsity_level,
+            "sparse_block_shape": sparse_block_shape,
+            "zeros_per_block": zeros_per_block,
+        }
+        if norm is None:
+            norm = 2
+        if callable(norm):
+            self.norm_fn = norm
+        elif norm == 1:
+            self.norm_fn = lambda T: T.abs()
+        elif norm == 2:
+            self.norm_fn = lambda T: T * T
+        else:
+            raise NotImplementedError(f"L-{norm} is not yet implemented.")
+        super().__init__(defaults=defaults)
+
+    def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape,
+                                 mask=None, input_shape=None, device=None):
+        r"""Creates patches of size `block_shape` after scattering the indices."""
+        if mask is None:
+            assert input_shape is not None
+            mask = torch.ones(input_shape, device=device)
+        mask.scatter_(dim=dim, index=indices, value=0)
+        mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape)
+        return mask
+
+    def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None):
+        r"""Creates a tensor-level mask.
+
+        A tensor-level mask uses the sparse_block_shape as the smallest granularity of
+        sparsification: for a given mask and sparse_block_shape, the smallest "patch" of
+        zeros/ones has the size of one sparse block.
+
+        In this context, `sparsity_level` describes the fraction of sparse patches.
+        """
+        h, w = data.shape[-2:]
+        block_h, block_w = sparse_block_shape
+        dh = (block_h - h % block_h) % block_h
+        dw = (block_w - w % block_w) % block_w
+
+        if mask is None:
+            mask = torch.ones(h + dh, w + dw, device=data.device)
+
+        if sparsity_level >= 1.0:
+            mask.data = torch.zeros_like(mask)
+            return mask
+        elif sparsity_level <= 0.0:
+            mask.data = torch.ones_like(mask)
+            return mask
+
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+        if values_per_block > 1:
+            # Reduce the data
+            data = F.avg_pool2d(
+                data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True
+            )
+        data = data.flatten()
+        num_blocks = len(data)
+
+        data = data.repeat(1, values_per_block, 1)
+
+        threshold_idx = int(round(sparsity_level * num_blocks))
+        threshold_idx = max(0, min(num_blocks - 1, threshold_idx))  # Sanity check
+        _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)
+
+        # Temp reshape for mask
+        mask_reshape = mask.reshape(data.shape)  # data might be reshaped
+        self._scatter_fold_block_mask(
+            dim=2, output_shape=(h + dh, w + dw),
+            indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape
+        )
+        mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
+        return mask
+
+    def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
+        r"""Creates a block-level mask.
+
+        A block-level mask sparsifies within each block: for a given mask and
+        sparse_block_shape, the sparsity is computed independently within every patch
+        of size sparse_block_shape.
+
+        In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch.
+        """
+        h, w = data.shape[-2:]
+        block_h, block_w = sparse_block_shape
+        dh = (block_h - h % block_h) % block_h
+        dw = (block_w - w % block_w) % block_w
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+
+        if mask is None:
+            mask = torch.ones((h + dh, w + dw), device=data.device)
+
+        if values_per_block == zeros_per_block:
+            # Everything should be sparsified
+            mask.data = torch.zeros_like(mask)
+            return mask
+
+        # create a new padded tensor like data (to match the block_shape)
+        padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
+        padded_data.fill_(torch.nan)
+        padded_data[:h, :w] = data
+        unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape)
+
+        # Temp reshape for mask
+        mask_reshape = mask.reshape(unfolded_data.shape)
+        _, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False)
+
+        self._scatter_fold_block_mask(
+            dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape
+        )
+
+        mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
+        return mask
+
+    def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape,
+                    zeros_per_block, **kwargs):
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+        if zeros_per_block > values_per_block:
+            raise ValueError(
+                "Number of zeros per block cannot be more than the total number of elements in that block."
+            )
+        if zeros_per_block < 0:
+            raise ValueError("Number of zeros per block must be non-negative.")
+
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+        if sparsity_level <= 0 or zeros_per_block == 0:
+            mask.data = torch.ones_like(mask)
+        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
+            mask.data = torch.zeros_like(mask)
+        else:
+            ww = self.norm_fn(getattr(module, tensor_name))
+            tensor_mask = self._make_tensor_mask(
+                data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape
+            )
+            if values_per_block != zeros_per_block:
+                block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape,
+                                                   zeros_per_block=zeros_per_block)
+                tensor_mask = torch.logical_or(tensor_mask, block_mask)
+            mask.data = tensor_mask
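+
+    # Editorial sketch of how the two masks above combine (assumed example, not
+    # part of the upstream module). With sparse_block_shape=(1, 4) we have
+    # values_per_block = 4:
+    #   * zeros_per_block == 4: blocks are pruned wholesale, so only the
+    #     tensor-level mask (block-granular sparsity_level) is applied.
+    #   * zeros_per_block < 4: a block-level mask additionally zeroes the
+    #     zeros_per_block smallest entries inside every block, and the final mask
+    #     is the elementwise logical OR of the two masks (an entry is kept if
+    #     either mask keeps it).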
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4d7921b6d8702567e01038b3eea39e847562a6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/__init__.py
@@ -0,0 +1,189 @@
+# flake8: noqa: F403
+
+from .fake_quantize import *  # noqa: F403
+from .fuse_modules import fuse_modules  # noqa: F403
+from .fuse_modules import fuse_modules_qat  # noqa: F403
+from .fuser_method_mappings import *  # noqa: F403
+from .observer import *  # noqa: F403
+from .qconfig import *  # noqa: F403
+from .qconfig_mapping import *  # noqa: F403
+from .quant_type import *  # noqa: F403
+from .quantization_mappings import *  # type: ignore[no-redef]
+from .quantize import *  # noqa: F403
+from .quantize_jit import *  # noqa: F403
+from .stubs import *  # noqa: F403
+from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
+from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
+from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
+from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle  # noqa: F401
+from typing import Union, List, Callable, Tuple, Optional
+from torch import Tensor
+import torch
+
+ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
+ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
+
+__all__ = [
+    "DeQuantStub",
+    "FakeQuantize",
+    "FakeQuantizeBase",
+    "FixedQParamsFakeQuantize",
+    "FixedQParamsObserver",
+    "FusedMovingAvgObsFakeQuantize",
+    "HistogramObserver",
+    "MatchAllNode",
+    "MinMaxObserver",
+    "MovingAverageMinMaxObserver",
+    "MovingAveragePerChannelMinMaxObserver",
+    "NoopObserver",
+    "ObserverBase",
+    "ObserverOrFakeQuantize",
+    "Pattern",
+    "PerChannelMinMaxObserver",
+    "PlaceholderObserver",
+    "QConfig",
+    "QConfigAny",
+    "QConfigDynamic",
+    "QConfigMapping",
+    "QuantStub",
+    "QuantType",
+    "QuantWrapper",
+    "RecordingObserver",
+    "ReuseInputObserver",
+    "UniformQuantizationObserverBase",
+    "add_quant_dequant",
+    "convert",
+    "convert_dynamic_jit",
+    "convert_jit",
+    "default_affine_fixed_qparams_fake_quant",
+    "default_affine_fixed_qparams_observer",
+    "default_debug_observer",
+    "default_dynamic_fake_quant",
+    "default_dynamic_quant_observer",
+    "default_embedding_fake_quant",
+    "default_embedding_fake_quant_4bit",
+    "default_eval_fn",
+    "default_fake_quant",
+    "default_fixed_qparams_range_0to1_fake_quant",
+    "default_fixed_qparams_range_0to1_observer",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
+    "default_fixed_qparams_range_neg1to1_observer",
+    "default_float_qparams_observer",
+    "default_float_qparams_observer_4bit",
+    "default_fused_act_fake_quant",
+    "default_fused_per_channel_wt_fake_quant",
+    "default_fused_wt_fake_quant",
+    "default_histogram_fake_quant",
+    "default_histogram_observer",
+    "default_observer",
+    "default_per_channel_weight_fake_quant",
+    "default_per_channel_weight_observer",
+    "default_placeholder_observer",
+    "default_reuse_input_observer",
+    "default_symmetric_fixed_qparams_fake_quant",
+    "default_symmetric_fixed_qparams_observer",
+    "default_weight_fake_quant",
+    "default_weight_observer",
+    "disable_fake_quant",
+    "disable_observer",
+    "enable_fake_quant",
+    "enable_observer",
+    "fuse_conv_bn",
+    "fuse_conv_bn_jit",
+    "fuse_conv_bn_relu",
+    "fuse_convtranspose_bn",
+    "fuse_linear_bn",
+    "fuse_modules",
+    "fuse_modules_qat",
+    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
+    "fused_wt_fake_quant_range_neg_127_to_127",
+    "get_combined_dict",
+    "get_default_compare_output_module_list",
+    "get_default_custom_config_dict",
+    "get_default_dynamic_quant_module_mappings",
+    "get_default_dynamic_sparse_quant_module_mappings",
+    "get_default_float_to_quantized_operator_mappings",
+    "get_default_qat_module_mappings",
+    "get_default_qat_qconfig",
+    "get_default_qat_qconfig_dict",
+    "get_default_qat_qconfig_mapping",
+    "get_default_qconfig",
+    "get_default_qconfig_dict",
+    "get_default_qconfig_mapping",
+    "get_default_qconfig_propagation_list",
+    "get_default_static_quant_module_mappings",
+    "get_default_static_quant_reference_module_mappings",
+    "get_default_static_sparse_quant_module_mappings",
+    "get_dynamic_quant_module_class",
+    "get_embedding_qat_module_mappings",
+    "get_embedding_static_quant_module_mappings",
+    "get_fuser_method",
+    "get_fuser_method_new",
+    "get_observer_state_dict",
+    "get_quantized_operator",
+    "get_static_quant_module_class",
+    "load_observer_state_dict",
+    "move_exported_model_to_eval",
+    "move_exported_model_to_train",
+    "allow_exported_model_train_eval",
+    "no_observer_set",
+    "per_channel_weight_observer_range_neg_127_to_127",
+    "prepare",
+    "prepare_dynamic_jit",
+    "prepare_jit",
+    "prepare_qat",
+    "propagate_qconfig_",
+    "qconfig_equals",
+    "quantize",
+    "quantize_dynamic",
+    "quantize_dynamic_jit",
+    "quantize_jit",
+    "quantize_qat",
+    "script_qconfig",
+    "script_qconfig_dict",
+    "swap_module",
+    "weight_observer_range_neg_127_to_127",
+    "generate_numeric_debug_handle",
+]
+
+def default_eval_fn(model, calib_data):
+    r"""Define the default evaluation function.
+
+    The default evaluation function takes a torch.utils.data.Dataset or a list of
+    input Tensors and runs the model on the dataset.
+    """
+    for data, target in calib_data:
+        model(data)
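+
+# Illustrative usage (editorial sketch; `model` and the calibration batch below
+# are assumptions, not part of the upstream module):
+#
+#     calib_data = [(torch.randn(1, 3, 224, 224), torch.tensor([0]))]
+#     default_eval_fn(model, calib_data)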
+
+class _DerivedObserverOrFakeQuantize(ObserverBase):
+    r"""An observer whose quantization parameters are derived from other observers
+    or fake quantizers via a user-provided `derive_qparams_fn`.
+    """
+
+    def __init__(
+        self,
+        dtype: torch.dtype,
+        obs_or_fqs: List[ObserverOrFakeQuantize],
+        derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]],
+        quant_min: Optional[int]=None,
+        quant_max: Optional[int]=None,
+        qscheme: Optional[torch.qscheme]=None,
+        ch_axis: Optional[int] = None
+    ):
+        super().__init__(dtype)
+        self.obs_or_fqs = obs_or_fqs
+        self.derive_qparams_fn = derive_qparams_fn
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.qscheme = qscheme
+        self.ch_axis = ch_axis
+
+        from .utils import is_per_channel
+        if is_per_channel(self.qscheme):
+            assert self.ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel"
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x
+
+    def calculate_qparams(self):
+        return self.derive_qparams_fn(self.obs_or_fqs)
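+
+    # Editorial sketch (assumed example, not part of the upstream module):
+    # deriving shared qparams from two already-calibrated observers.
+    #
+    #     obs_a, obs_b = MinMaxObserver(), MinMaxObserver()
+    #     def derive(obs_or_fqs):
+    #         qparams = [o.calculate_qparams() for o in obs_or_fqs]
+    #         scales = torch.stack([s for s, _ in qparams])
+    #         return scales.max().reshape(1), qparams[0][1]
+    #     shared = _DerivedObserverOrFakeQuantize(
+    #         dtype=torch.quint8, obs_or_fqs=[obs_a, obs_b], derive_qparams_fn=derive)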
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b9648e9edd7a70d95d3c5ea7da36384b74694bd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..860feec201c411ed8e851bd097ca801b476969bc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1f7f6000842706549ff42ab6afdfa82421e0f17
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c24c19627051ff791d121e630d637429fb698d0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6033f1ec149dadee7bd33c16996e56dbbbb44616
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a413d9072c59484bd9ebd3f797a404563d5c9bf0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7188bfb694bd5849d22ec108f01b90583886fd5d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c892e88019eaa92b1fa7159a3c5fb6eb86bb56da
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/observer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e74006f270c51a920b784c89a62c2b58daf9086c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec961f6314650716a1e60da4136eb3266ae9afbe
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59e0220b163a10ea87d3e5f2b90a23e416397e42
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be729b742a9ea0abb7ca9225e1bd6efed8465cf9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3d5a3a04274739a3139b9623d23ee5131216977
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ae86c421b71ef7108056cd51cb412747d87e8e9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad41b6793bbc33fed71186ddf89702397cc0a9b3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bcba94bb5d81805476dbe9aabde36bbbd5d2872
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78ec8b908dea789d1b69d45b7664a15d29071e9d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c1425eb9deafd13fc5846c4ac67b0e2cec9e8d8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/_correct_bias.py b/MLPY/Lib/site-packages/torch/ao/quantization/_correct_bias.py
new file mode 100644
index 0000000000000000000000000000000000000000..646cbb7492bbfa7fadbbdd5bbc7a1677a9510546
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/_correct_bias.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.quantized as nnq
+
+import torch.ao.quantization
+import torch.ao.ns._numeric_suite as ns
+
+__all__ = [
+    "get_module",
+    "parent_child_names",
+    "get_param",
+    "MeanShadowLogger",
+    "bias_correction",
+]
+
+_supported_modules = {nn.Linear, nn.Conv2d}
+_supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
+
+def get_module(model, name):
+    """Given the name of a submodule, return that submodule from the given model."""
+    return dict(model.named_modules())[name]
+
+def parent_child_names(name):
+    """Split full name of submodule into parent submodule's full name and submodule's name."""
+    split_name = name.rsplit('.', 1)
+    if len(split_name) == 1:
+        return '', split_name[0]
+    else:
+        return split_name[0], split_name[1]
+
+def get_param(module, attr):
+    """Get the parameter given a module and attribute.
+
+    Sometimes the weight/bias attribute is the raw tensor, and sometimes it is a
+    function that returns the raw tensor; this helper handles both cases.
+    """
+    param = getattr(module, attr, None)
+    if callable(param):
+        return param()
+    else:
+        return param
+
+class MeanShadowLogger(ns.Logger):
+    """Mean Logger for a Shadow module.
+
+    A logger for a Shadow module whose purpose is to record the rolling mean
+    of the data passed to the floating-point and quantized models.
+    """
+
+    def __init__(self):
+        """Set up initial values for float and quantized stats, count, float sum, and quant sum."""
+        super().__init__()
+        self.stats["float"] = None
+        self.stats["quantized"] = None
+        self.count = 0
+        self.float_sum = None
+        self.quant_sum = None
+
+    def forward(self, x, y):
+        """Compute the average of quantized and floating-point data from modules.
+
+        The inputs x and y are the outputs of the quantized and floating-point modules:
+        x comes from the quantized module and y from the floating-point module.
+        """
+        if x.is_quantized:
+            x = x.dequantize()
+
+        self.count += 1
+        if self.stats["quantized"] is None:
+            self.stats["quantized"] = x
+            self.quant_sum = x
+        else:
+            self.quant_sum += x
+            self.stats["quantized"] = self.quant_sum / self.count
+
+        if self.stats["float"] is None:
+            self.stats["float"] = y
+            self.float_sum = y
+        else:
+            self.float_sum += y
+            self.stats["float"] = self.float_sum / self.count
+
+    def clear(self):
+        self.stats["float"] = None
+        self.stats["quantized"] = None
+        self.count = 0
+        self.float_sum = None
+        self.quant_sum = None
+
+def bias_correction(float_model, quantized_model, img_data, target_modules=_supported_modules_quantized, neval_batches=None):
+    """Perform bias correction on a module.
+
+    Using the numeric suite shadow module, the expected outputs of the floating-point and
+    quantized modules are recorded. Using that data, the bias of supported modules is shifted
+    to compensate for the drift caused by quantization.
+    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)
+
+    Args:
+        float_model: a trained model that serves as the reference for what bias correction should aim for
+        quantized_model: quantized form of float_model that bias correction is to be applied to
+        img_data: calibration data used to estimate the expected output (and hence the quantization error)
+        target_modules: specifies which submodules in quantized_model need bias correction (can be extended to
+                unquantized submodules)
+        neval_batches: a cap on the number of batches used to estimate the expected output
+    """
+    ns.prepare_model_with_stubs(float_model, quantized_model, _supported_modules, MeanShadowLogger)
+
+    uncorrected_modules = {}
+    for name, submodule in quantized_model.named_modules():
+        if type(submodule) in target_modules:
+            uncorrected_modules[name] = submodule
+
+    for uncorrected_module in uncorrected_modules:
+        quantized_submodule = get_module(quantized_model, uncorrected_module)
+        bias = get_param(quantized_submodule, 'bias')
+        if bias is not None:
+
+            count = 0
+            for data in img_data:
+                quantized_model(data[0])
+                count += 1
+                if count == neval_batches:
+                    break
+            ob_dict = ns.get_logger_dict(quantized_model)
+            parent_name, _ = parent_child_names(uncorrected_module)
+
+            float_data = ob_dict[parent_name + '.stats']['float']
+            quant_data = ob_dict[parent_name + '.stats']['quantized']
+
+            # math for expected_error
+            quantization_error = quant_data - float_data
+            dims = list(range(quantization_error.dim()))
+            # Note: we don't want to take the mean over the output channel dimension
+            dims.remove(1)
+            expected_error = torch.mean(quantization_error, dims)
+
+            updated_bias = bias.data - expected_error
+
+            bias.data = updated_bias
+
+            # Resets the data contained in the loggers
+            for name, submodule in quantized_model.named_modules():
+                if isinstance(submodule, MeanShadowLogger):
+                    submodule.clear()
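+
+# Illustrative usage (editorial sketch; the model and data names are assumptions):
+#
+#     # given a calibrated float model and its statically quantized counterpart,
+#     # where img_data is an iterable of (input, target) batches
+#     bias_correction(float_model, quantized_model, img_data, neval_batches=8)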
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/_equalize.py b/MLPY/Lib/site-packages/torch/ao/quantization/_equalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..a752ecc3406b5b83fa3d993960af18f007e9987a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/_equalize.py
@@ -0,0 +1,182 @@
+import torch
+import copy
+from typing import Dict, Any
+
+__all__ = [
+    "set_module_weight",
+    "set_module_bias",
+    "get_module_weight",
+    "get_module_bias",
+    "max_over_ndim",
+    "min_over_ndim",
+    "channel_range",
+    "cross_layer_equalization",
+    "equalize",
+    "converged",
+]
+
+_supported_types = {torch.nn.Conv2d, torch.nn.Linear}
+_supported_intrinsic_types = {torch.ao.nn.intrinsic.ConvReLU2d, torch.ao.nn.intrinsic.LinearReLU}
+_all_supported_types = _supported_types.union(_supported_intrinsic_types)
+
+def set_module_weight(module, weight) -> None:
+    if type(module) in _supported_types:
+        module.weight = torch.nn.Parameter(weight)
+    else:
+        module[0].weight = torch.nn.Parameter(weight)
+
+def set_module_bias(module, bias) -> None:
+    if type(module) in _supported_types:
+        module.bias = torch.nn.Parameter(bias)
+    else:
+        module[0].bias = torch.nn.Parameter(bias)
+
+def get_module_weight(module):
+    if type(module) in _supported_types:
+        return module.weight
+    else:
+        return module[0].weight
+
+def get_module_bias(module):
+    if type(module) in _supported_types:
+        return module.bias
+    else:
+        return module[0].bias
+
+def max_over_ndim(input, axis_list, keepdim=False):
+    """Apply 'torch.max' over the given axes."""
+    axis_list.sort(reverse=True)
+    for axis in axis_list:
+        input, _ = input.max(axis, keepdim)
+    return input
+
+def min_over_ndim(input, axis_list, keepdim=False):
+    """Apply 'torch.min' over the given axes."""
+    axis_list.sort(reverse=True)
+    for axis in axis_list:
+        input, _ = input.min(axis, keepdim)
+    return input
+
+def channel_range(input, axis=0):
+    """Find the range of weights associated with a specific channel."""
+    size_of_tensor_dim = input.ndim
+    axis_list = list(range(size_of_tensor_dim))
+    axis_list.remove(axis)
+
+    mins = min_over_ndim(input, axis_list)
+    maxs = max_over_ndim(input, axis_list)
+
+    assert mins.size(0) == input.size(axis), "Dimensions of the resulting channel range do not match the size of the requested axis"
+    return maxs - mins
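+
+# Editorial note (assumed example): for a Conv2d weight of shape
+# (out_channels, in_channels, kh, kw) and axis=0, the reductions above collapse
+# dims 1..3, so channel_range returns a 1-D tensor of length out_channels holding
+# max - min of each output channel's weights.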
+
+def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
+    """Scale the range of module1's output to equal the range of module2's input.
+
+    Given two adjacent modules, their weights are scaled such that
+    the ranges of the first module's output channels are equal to the
+    ranges of the second module's input channels.
+    """
+    if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
+        raise ValueError("module type not supported:", type(module1), " ", type(module2))
+
+    weight1 = get_module_weight(module1)
+    weight2 = get_module_weight(module2)
+
+    if weight1.size(output_axis) != weight2.size(input_axis):
+        raise TypeError("Number of output channels of the first arg does not match the \
+        number of input channels of the second arg")
+
+    bias = get_module_bias(module1)
+
+    weight1_range = channel_range(weight1, output_axis)
+    weight2_range = channel_range(weight2, input_axis)
+
+    # produce the scaling factors to be applied
+    weight2_range += 1e-9
+    scaling_factors = torch.sqrt(weight1_range / weight2_range)
+    inverse_scaling_factors = torch.reciprocal(scaling_factors)
+
+    bias = bias * inverse_scaling_factors
+
+    # reshape the 1D scaling tensors so they broadcast correctly against the
+    # weight tensors along the output axis of module1 and the input axis of module2
+    size1 = [1] * weight1.ndim
+    size1[output_axis] = weight1.size(output_axis)
+    size2 = [1] * weight2.ndim
+    size2[input_axis] = weight2.size(input_axis)
+
+    scaling_factors = torch.reshape(scaling_factors, size2)
+    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)
+
+    weight1 = weight1 * inverse_scaling_factors
+    weight2 = weight2 * scaling_factors
+
+    set_module_weight(module1, weight1)
+    set_module_bias(module1, bias)
+    set_module_weight(module2, weight2)
+
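+# Editorial sketch of the scaling performed above (not part of the upstream
+# module): for each connecting channel i,
+#
+#     s_i = sqrt(range_out1_i / range_in2_i)
+#     W1[i, ...] <- W1[i, ...] / s_i,    b1[i] <- b1[i] / s_i
+#     W2[:, i, ...] <- W2[:, i, ...] * s_i
+#
+# After scaling, both ranges equal sqrt(range_out1_i * range_in2_i), while the
+# composed computation is preserved for ReLU-like (positively homogeneous)
+# activations, as argued in section 4.1 of the referenced paper.
+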
+def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
+    """Equalize modules until convergence is achieved.
+
+    Given a list of adjacent module pairs within a model, equalization is
+    applied between each pair; this is repeated until convergence is achieved.
+
+    A copy of the changing modules from the previous iteration is kept; if those
+    copies are not meaningfully different from the current modules (as determined
+    by `converged`), the modules have converged enough that further equalization
+    is not necessary.
+
+    The implementation references section 4.1 of https://arxiv.org/pdf/1906.04721.pdf
+
+    Args:
+        model: a model (nn.Module) that equalization is to be applied to
+        paired_modules_list: a list of lists where each sublist is a pair of names of two
+            submodules found in the model; for each pair the two submodules generally
+            have to be adjacent in the model to get expected/reasonable results
+        threshold: a number used by the converged function to determine what degree of
+            similarity between models is necessary for them to be considered equivalent
+        inplace: determines whether the function modifies the model in place or works on a copy
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    name_to_module : Dict[str, torch.nn.Module] = {}
+    previous_name_to_module: Dict[str, Any] = {}
+    name_set = {name for pair in paired_modules_list for name in pair}
+
+    for name, module in model.named_modules():
+        if name in name_set:
+            name_to_module[name] = module
+            previous_name_to_module[name] = None
+    while not converged(name_to_module, previous_name_to_module, threshold):
+        for pair in paired_modules_list:
+            previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
+            previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])
+
+            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
+
+    return model
+
+def converged(curr_modules, prev_modules, threshold=1e-4):
+    """Test whether modules are converged to a specified threshold.
+
+    Tests whether the summed norm of the differences between each pair of modules
+    is less than the given threshold.
+
+    Takes two dictionaries mapping names to modules; the set of names in each dictionary
+    should be the same. Looping over the set of names, the difference between the weights of
+    the associated modules in the two dictionaries is taken and its norm is accumulated.
+    """
+    if curr_modules.keys() != prev_modules.keys():
+        raise ValueError("The keys to the given mappings must have the same set of names of modules")
+
+    summed_norms = torch.tensor(0.)
+    if None in prev_modules.values():
+        return False
+    for name in curr_modules.keys():
+        curr_weight = get_module_weight(curr_modules[name])
+        prev_weight = get_module_weight(prev_modules[name])
+
+        difference = curr_weight.sub(prev_weight)
+        summed_norms += torch.norm(difference)
+    return bool(summed_norms < threshold)
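+
+# Illustrative usage (editorial sketch; the submodule names are assumptions):
+#
+#     pairs = [["features.0", "features.3"], ["features.3", "classifier.0"]]
+#     equalized = equalize(model, pairs, threshold=1e-4, inplace=False)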
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py b/MLPY/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..df448bd15c15f48c4983c8fc3694560a9034c090
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/_learnable_fake_quantize.py
@@ -0,0 +1,164 @@
+import torch
+from torch.nn.parameter import Parameter
+from typing import List
+
+__all__: List[str] = []
+
+class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
+    r"""Generalized extension of the FakeQuantize module in fake_quantize.py.
+
+    This is an extension of the FakeQuantize module in fake_quantize.py that
+    supports more generalized lower-bit quantization and learning of the scale
+    and zero point parameters through backpropagation. For literature references,
+    please see the class _LearnableFakeQuantizePerTensorOp.
+
+    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
+    module also includes the following attributes to support quantization parameter learning.
+
+    * :attr:`channel_len` defines the length of the channel when initializing scale and zero point
+      for the per channel case.
+
+    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
+      normalized by a constant that is proportional to the square root of the number of
+      elements in the tensor. The related literature justifying the use of this particular constant
+      can be found here: https://openreview.net/pdf?id=rkgO66VKDS.
+
+    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.
+
+    * :attr:`static_enabled` defines the flag for using observer's static estimation for
+      scale and zero point.
+
+    * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
+    """
+    def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
+                 use_grad_scaling=False, **observer_kwargs):
+        super().__init__()
+        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        # also pass quant_min and quant_max to observer
+        observer_kwargs["quant_min"] = quant_min
+        observer_kwargs["quant_max"] = quant_max
+        self.use_grad_scaling = use_grad_scaling
+        if channel_len == -1:
+            self.scale = Parameter(torch.tensor([scale]))
+            self.zero_point = Parameter(torch.tensor([zero_point]))
+        else:
+            assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
+            self.scale = Parameter(torch.tensor([scale] * channel_len))
+            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
+
+        self.activation_post_process = observer(**observer_kwargs)
+        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
+            'quant_min out of bound'
+        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
+            'quant_max out of bound'
+        self.dtype = self.activation_post_process.dtype
+        self.qscheme = self.activation_post_process.qscheme
+        self.ch_axis = self.activation_post_process.ch_axis \
+            if hasattr(self.activation_post_process, 'ch_axis') else -1
+        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
+        self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
+        self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))
+
+        bitrange = torch.tensor(quant_max - quant_min + 1).double()
+        self.bitwidth = int(torch.log2(bitrange).item())
+        self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
+
+    @torch.jit.export
+    def enable_param_learning(self):
+        r"""Enable parameter learning over static observer estimates.
+
+        Enables learning of quantization parameters and
+        disables static observer estimates. Forward path returns fake quantized X.
+        """
+        self.toggle_qparam_learning(enabled=True) \
+            .toggle_fake_quant(enabled=True) \
+            .toggle_observer_update(enabled=False)
+        return self
+
+    @torch.jit.export
+    def enable_static_estimate(self):
+        """Enable static estimates of quantization parameters.
+
+        Enables static observer estimates and disables learning of
+        quantization parameters. Forward path returns fake quantized X.
+        """
+        self.toggle_qparam_learning(enabled=False) \
+            .toggle_fake_quant(enabled=True) \
+            .toggle_observer_update(enabled=True)
+
+    @torch.jit.export
+    def enable_static_observation(self):
+        """Enable accumulation of data without updating quantization parameters.
+
+        Enables the static observer to accumulate data from the input but doesn't
+        update the quantization parameters. Forward path returns the original X.
+        """
+        self.toggle_qparam_learning(enabled=False) \
+            .toggle_fake_quant(enabled=False) \
+            .toggle_observer_update(enabled=True)
+
+    @torch.jit.export
+    def toggle_observer_update(self, enabled=True):
+        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
+        return self
+
+    @torch.jit.export
+    def enable_observer(self, enabled=True):
+        self.toggle_observer_update(enabled)
+
+    @torch.jit.export
+    def toggle_qparam_learning(self, enabled=True):
+        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
+        self.scale.requires_grad = enabled
+        self.zero_point.requires_grad = enabled
+        return self
+
+    @torch.jit.export
+    def toggle_fake_quant(self, enabled=True):
+        self.fake_quant_enabled[0] = int(enabled)
+        return self
+
+    @torch.jit.export
+    def observe_quant_params(self):
+        print(f'_LearnableFakeQuantize Scale: {self.scale.detach()}')
+        print(f'_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}')
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
+        scale = self.scale.detach()
+        zero_point = self.zero_point.detach().round().clamp(self.quant_min, self.quant_max).long()
+        return scale, zero_point
+
+    def forward(self, X):
+        if self.static_enabled[0] == 1:  # type: ignore[index]
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.activation_post_process.calculate_qparams()
+            _scale = _scale.to(self.scale.device)
+            _zero_point = _zero_point.to(self.zero_point.device)
+            self.scale.data.copy_(_scale)
+            self.zero_point.data.copy_(_zero_point)
+        else:
+            self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
+
+        if self.fake_quant_enabled[0] == 1:
+            if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
+                self.zero_point.data.zero_()
+
+            if self.use_grad_scaling:
+                grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
+            else:
+                grad_factor = 1.0
+            if self.qscheme in (
+                    torch.per_channel_symmetric, torch.per_channel_affine):
+                X = torch._fake_quantize_learnable_per_channel_affine(
+                    X, self.scale, self.zero_point, self.ch_axis,
+                    self.quant_min, self.quant_max, grad_factor)
+            else:
+                X = torch._fake_quantize_learnable_per_tensor_affine(
+                    X, self.scale, self.zero_point,
+                    self.quant_min, self.quant_max, grad_factor)
+
+        return X
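+
+# Illustrative usage (editorial sketch, not part of the upstream module):
+#
+#     lfq = _LearnableFakeQuantize(
+#         observer=torch.ao.quantization.MovingAverageMinMaxObserver,
+#         quant_min=0, quant_max=255)
+#     lfq.enable_param_learning()   # scale / zero_point become trainable Parameters
+#     y = lfq(torch.randn(4, 8))    # fake-quantized output, differentiable w.r.t. qparams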
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e4baa9c3222f05ca78d5c7baa45fb16bf71332
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__init__.py
@@ -0,0 +1,23 @@
+from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, DTypeWithConstraints, ObservationType
+from .fbgemm import get_fbgemm_backend_config
+from .native import get_native_backend_config, get_native_backend_config_dict
+from .qnnpack import get_qnnpack_backend_config
+from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
+from .executorch import get_executorch_backend_config
+from .onednn import get_onednn_backend_config
+
+__all__ = [
+    "get_fbgemm_backend_config",
+    "get_native_backend_config",
+    "get_native_backend_config_dict",
+    "get_qnnpack_backend_config",
+    "get_tensorrt_backend_config",
+    "get_tensorrt_backend_config_dict",
+    "get_executorch_backend_config",
+    "BackendConfig",
+    "BackendPatternConfig",
+    "DTypeConfig",
+    "DTypeWithConstraints",
+    "ObservationType",
+    "get_onednn_backend_config",
+]
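+
+# Illustrative usage (editorial sketch; `model`, `qconfig_mapping` and
+# `example_inputs` are assumptions): a BackendConfig is typically passed to the
+# FX graph mode quantization APIs to control which patterns are quantized, e.g.
+#
+#     from torch.ao.quantization.quantize_fx import prepare_fx
+#     prepared = prepare_fx(model, qconfig_mapping, example_inputs,
+#                           backend_config=get_native_backend_config())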
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dace7a7a7f1216fbbb0f9c99560d4b2599d7ee03
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9be231216650f232bf2c8fea10f2a2d5208692da
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91bf2fc178c005941244392d7b6ed0b0fc483265
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d411d098c913907ca8edb1c227ea8c39bef7a543
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a442d8fefb81859fe14cd97bc6644cb1c64f88cf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8bde6b1749de64b132c53372c110ef8a3fe5eec
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c0a064164fe8d8fd12ce73a93b1bbc4a36a3b46
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa380a61c9a22045aee48e0b94942b0b31aa3b10
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20d30980f99738c93948c061c5dfe89f3529f462
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af0d3dd2e267b042350f48a5d508f63c4d99cdee
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4695bbe9d038876e1305725bf90a276c528ea204
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5cb014f5f756de0664e115dd35e56fd7de886de
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c1058dacfa21a9165773c945172707cf2596684
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6707c6bb9bb31c06068036e8e6baf7a06618768c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -0,0 +1,637 @@
+import copy
+import operator
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.quantized.reference as nnqr
+from collections import namedtuple
+from typing import Callable, Dict, List, Union
+from .backend_config import (
+    BackendPatternConfig,
+    DTypeConfig,
+    DTypeWithConstraints,
+    ObservationType,
+)
+from ..fuser_method_mappings import (
+    _sequential_wrapper2,
+    fuse_conv_bn,
+    fuse_conv_bn_relu,
+    fuse_linear_bn,
+    fuse_convtranspose_bn,
+)
+
+__all__: List[str] = []
+
+# TODO: rename to be more explicit, e.g. qat_conv_relu
+_ConvMetadata = namedtuple(
+    "_ConvMetadata",
+    ["root", "transpose", "bn", "reference", "transpose_reference",
+     "fused_conv_relu", "fused_conv_bn", "fused_conv_bn_relu",
+     "qat", "relu_qat", "bn_qat", "bn_relu_qat",
+     "func", "func_transpose"])
+_Conv1dMetadata = _ConvMetadata(
+    nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, nnqr.ConvTranspose1d,
+    nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d,
+    nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d,
+    F.conv1d, F.conv_transpose1d)
+_Conv2dMetadata = _ConvMetadata(
+    nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, nnqr.ConvTranspose2d,
+    nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d,
+    nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d,
+    F.conv2d, F.conv_transpose2d)
+_Conv3dMetadata = _ConvMetadata(
+    nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, nnqr.ConvTranspose3d,
+    nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d,
+    nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d,
+    F.conv3d, F.conv_transpose3d)
+
+# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values
+# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh
+_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints(
+    dtype=torch.quint8,
+    quant_min_lower_bound=0,
+    quant_max_upper_bound=255,
+    scale_exact_match=1.0 / 256.0,
+    zero_point_exact_match=0,
+)
+_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints(
+    dtype=torch.quint8,
+    quant_min_lower_bound=0,
+    quant_max_upper_bound=255,
+    scale_exact_match=2.0 / 256.0,
+    zero_point_exact_match=128,
+)
+_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = {
+    torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+}
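+
+# Editorial note on the constants above: an op with output range [0, 1] mapped onto
+# quint8 (256 levels) gets scale = (1 - 0) / 256 = 1/256 and zero_point = 0, while a
+# [-1, 1] range gets scale = (1 - (-1)) / 256 = 2/256 and zero_point = 128, so that
+# the real value 0 maps exactly to the quantized value 128.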
+
+def _get_binary_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    binary_op_configs: List[BackendPatternConfig] = []
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have an extra check in prepare;
+        # will need to change this to NO_OBSERVER later after we have implemented
+        # Tensor dtype inference properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
+    for op_with_quantized_bop_scalar_variant in [operator.add, torch.add, operator.mul, torch.mul]:
+        bop_patterns = [
+            (op_with_quantized_bop_scalar_variant, nn.ReLU),
+            (op_with_quantized_bop_scalar_variant, F.relu),
+            (op_with_quantized_bop_scalar_variant, torch.relu),
+            op_with_quantized_bop_scalar_variant
+        ]
+        for bop_pattern in bop_patterns:
+            binary_op_configs.append(
+                BackendPatternConfig(bop_pattern)
+                    .set_dtype_configs(dtype_configs)  # noqa: E131
+                    ._set_num_tensor_args_to_observation_type(num_tensor_args_to_observation_type_mapping))
+    # matmul
+    binary_op_configs.append(
+        BackendPatternConfig(torch.matmul)
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+    )
+    return binary_op_configs
+
+def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    """
+    Return all configs related to linear modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    linear_configs: List[BackendPatternConfig] = []
+
+    # (1) Single linear modules/functions
+    # -------------------------------------
+    # linear module
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.Linear)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear)
+            .set_qat_module(nnqat.Linear))
+    # linear qat module
+    linear_configs.append(
+        BackendPatternConfig(nnqat.Linear)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear))
+    # functional linear
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.functional.linear)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1, "bias": 2}))
+
+    # (2) Linear + relu
+    # -------------------
+    # 2.1 linear module + relu fusion config
+    # linear relu, linear module + relu module
+    linear_configs.append(
+        BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
+            .set_fused_module(nni.LinearReLU))
+    # linear relu, linear module + functional relu
+    linear_configs.append(
+        BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
+            .set_fused_module(nni.LinearReLU))
+
+    # 2.2 linear module + relu, fused module configs
+    # linear relu, fused module
+    linear_configs.append(
+        BackendPatternConfig(nni.LinearReLU)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear)
+            .set_qat_module(nniqat.LinearReLU))
+    # linear relu, qat fused module
+    linear_configs.append(
+        BackendPatternConfig(nniqat.LinearReLU)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear))
+    # 2.3 functional linear + relu configs
+    # linear relu, functional linear + relu module
+    linear_configs.append(
+        BackendPatternConfig((F.linear, torch.nn.ReLU))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs))
+    # linear relu, functional linear + functional relu
+    linear_configs.append(
+        BackendPatternConfig((F.linear, F.relu))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs))
+
+    # (3) Linear + batchnorm
+    # ------------------------
+    # 3.1 linear bn fusion
+    linear_configs.append(
+        BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_linear_bn)
+            .set_fused_module(nni.LinearBn1d))
+
+    # 3.2 linear bn fused
+    # linear bn, fused module
+    linear_configs.append(
+        BackendPatternConfig(nni.LinearBn1d)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear)
+            .set_qat_module(nniqat.LinearBn1d))
+    # linear bn, qat fused module
+    linear_configs.append(
+        BackendPatternConfig(nniqat.LinearBn1d)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(torch.nn.Linear)
+            .set_reference_quantized_module(nnqr.Linear))
+    return linear_configs
+
+def _get_conv_configs(dtype_configs):
+    """
+    Return all configs related to conv modules and ops.
+    """
+    conv_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
+
+        # (1) Single conv modules/functions
+        # -----------------------------------
+        # conv module
+        conv_configs.append(
+            BackendPatternConfig(convs.root)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference)
+                .set_qat_module(convs.qat))
+        # conv qat module
+        conv_configs.append(
+            BackendPatternConfig(convs.qat)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference))
+        # functional conv
+        conv_configs.append(
+            BackendPatternConfig(convs.func)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                ._set_input_type_to_index({"weight": 1, "bias": 2}))
+
+        # (2) Conv + relu
+        # -----------------
+        # 2.1 conv module + relu fusion configs
+        # conv relu fusion, conv module + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.root, torch.nn.ReLU))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fused_module(convs.fused_conv_relu))
+        # conv relu fusion, conv module + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.root, F.relu))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+                .set_fused_module(convs.fused_conv_relu))
+        # 2.2 conv module + relu fused module configs
+        # conv relu, fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference)
+                .set_qat_module(convs.relu_qat))
+        # conv relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference))
+        # 2.3 functional conv + relu configs
+        # conv relu, functional conv + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.func, torch.nn.ReLU))
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs))
+        # conv relu, functional conv + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.func, F.relu))
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs))
+
+        # fused conv relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_qat_module(convs.relu_qat))
+
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference))
+
+        # (3) Conv + batchnorm (+ relu)
+        # -------------------------------
+        # 3.1 conv bn fusion configs
+        # conv + bn fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(fuse_conv_bn)
+                .set_fused_module(convs.fused_conv_bn))
+        # conv + bn + relu module fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(fuse_conv_bn_relu)
+                .set_fused_module(convs.fused_conv_bn_relu))
+        # conv + bn + relu functional fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, F.relu))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_root_module(convs.root)
+                .set_fuser_method(fuse_conv_bn_relu)
+                .set_fused_module(convs.fused_conv_bn_relu))
+        # TODO: we can add fusion for torch.relu as well
+
+        # 3.2 conv + bn (+ relu) fused module configs
+        # fused conv bn
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_qat_module(convs.bn_qat))
+
+        # fused conv bn relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn_relu)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_qat_module(convs.bn_relu_qat))
+
+        # conv bn, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_qat)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference))
+        # conv bn relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_relu_qat)
+                .set_observation_type(observation_type)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(convs.root)
+                .set_reference_quantized_module(convs.reference))
+
+        # (4) conv transpose and its fusion
+        # 4.1 conv transpose config
+        conv_configs.append(
+            BackendPatternConfig(convs.transpose)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_root_module(convs.transpose)
+                .set_reference_quantized_module(convs.transpose_reference))
+
+        # 4.2 conv transpose + bn fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.transpose, convs.bn))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(fuse_convtranspose_bn)
+                .set_root_module(convs.transpose)
+                .set_reference_quantized_module(convs.transpose_reference))
+
+        # 4.3 functional conv transpose
+        conv_configs.append(
+            BackendPatternConfig(convs.func_transpose)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                ._set_input_type_to_index({"weight": 1, "bias": 2}))
+
+    return conv_configs
+
+def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
+    return BackendPatternConfig(torch.cat) \
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
+        .set_dtype_configs(dtype_configs)
+
+def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    ln_configs = []
+    ln_configs.append(
+        BackendPatternConfig(torch.nn.LayerNorm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    ln_configs.append(
+        BackendPatternConfig(torch.nn.functional.layer_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 3})
+    )
+    return ln_configs
+
+def _get_default_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    configs = []
+    default_ops = [
+        torch.nn.ELU,
+        torch.nn.LeakyReLU,
+        torch.nn.Hardswish,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.Dropout,
+        torch.nn.PReLU,
+        torch.nn.functional.elu,
+        torch.nn.functional.hardswish,
+        torch.nn.functional.leaky_relu,
+        torch.nn.functional.dropout,
+    ]
+    for op in default_ops:
+        configs.append(
+            BackendPatternConfig(op)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs))
+
+    configs.append(
+        BackendPatternConfig(torch.nn.functional.group_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 3})
+    )
+
+    configs.append(
+        BackendPatternConfig(torch.nn.functional.instance_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 3, "bias": 4})
+    )
+    return configs
+
+def _add_fixed_qparams_to_dtype_configs(
+    dtype_configs: List[DTypeConfig],
+    constraints: DTypeWithConstraints,
+) -> List[DTypeConfig]:
+    """
+    Return a copy of the list of DTypeConfigs where activations are subject to the specified
+    constraints required for fixed qparams ops.
+
+    If the data type doesn't match the one in the constraints, simply leave the corresponding
+    DTypeConfig unchanged.
+
+    If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations,
+    throw an exception since these settings are incompatible with fixed qparams ops.
+    """
+    new_dtype_configs = []
+    for dtype_config in dtype_configs:
+        dc = copy.deepcopy(dtype_config)
+        for orig_constraints in [dc.input_dtype_with_constraints, dc.output_dtype_with_constraints]:
+            if orig_constraints.dtype != constraints.dtype:
+                continue
+            if orig_constraints.scale_min_lower_bound is not None:
+                raise ValueError(f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}")
+            if orig_constraints.scale_max_upper_bound is not None:
+                raise ValueError(f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}")
+            orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound
+            orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound
+            orig_constraints.scale_exact_match = constraints.scale_exact_match
+            orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match
+        new_dtype_configs.append(dc)
+    return new_dtype_configs
+
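+# A minimal sketch of the rewrite performed above, assuming the DTypeConfig and
+# DTypeWithConstraints APIs from backend_config.py; the sigmoid-style values
+# below (scale 1/256, zero point 0) are illustrative only.
+def _example_fixed_qparams_rewrite() -> List[DTypeConfig]:
+    base = DTypeConfig(
+        input_dtype=torch.quint8,
+        output_dtype=torch.quint8,
+        weight_dtype=torch.qint8,
+        bias_dtype=torch.float,
+    )
+    sigmoid_like_constraints = DTypeWithConstraints(
+        dtype=torch.quint8,
+        scale_exact_match=1.0 / 256.0,
+        zero_point_exact_match=0,
+    )
+    # Only the quint8 activation entries pick up the exact-match constraints;
+    # the qint8 weight entry is left unchanged.
+    return _add_fixed_qparams_to_dtype_configs([base], sigmoid_like_constraints)
+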
+def _get_fixed_qparams_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    fixed_qparams_op_configs = []
+    for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
+        new_dtype_configs = _add_fixed_qparams_to_dtype_configs(dtype_configs, constraints)
+        fixed_qparams_op_configs.append(
+            BackendPatternConfig(fixed_qparam_op)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(new_dtype_configs))
+    return fixed_qparams_op_configs
+
+def _get_share_qparams_op_configs(dtype_configs):
+    """ Get the operator config for the operators that works for both float and quantized input
+    if input is quantized, the output Tensor shares the same quantization parameter
+    with input.
+    Example operator: avgpool2d, reshape, transpose, maxpool2d
+    Example observed operator:
+    observer_0 - avgpool2d - observer_0 (same observer instance as input)
+    """
+
+    def _get_share_qparams_op_backend_config(op):
+        return BackendPatternConfig(op) \
+            .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
+            .set_dtype_configs(dtype_configs)
+
+    share_qparams_ops = [
+        torch.nn.AdaptiveAvgPool1d,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.AdaptiveAvgPool3d,
+        torch.nn.AvgPool1d,
+        torch.nn.AvgPool2d,
+        torch.nn.AvgPool3d,
+        torch.nn.Hardtanh,
+        torch.nn.Identity,
+        torch.nn.MaxPool1d,
+        torch.nn.MaxPool2d,
+        torch.nn.MaxPool3d,
+        torch.nn.PixelShuffle,
+        torch.nn.PixelUnshuffle,
+        torch.nn.ReLU,
+        torch.nn.ReLU6,
+        torch.adaptive_avg_pool1d,
+        torch.nn.functional.adaptive_avg_pool2d,
+        torch.nn.functional.adaptive_avg_pool3d,
+        torch.nn.functional.hardtanh,
+        torch.nn.functional.hardtanh_,
+        torch.nn.functional.interpolate,
+        torch.nn.functional.max_pool1d,
+        torch.nn.functional.max_pool2d,
+        torch.nn.functional.max_pool3d,
+        torch.nn.functional.pixel_shuffle,
+        torch.nn.functional.pixel_unshuffle,
+        torch.nn.functional.relu,
+        torch.nn.functional.relu6,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        torch.narrow,
+        torch.repeat_interleave,
+        torch.transpose,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        operator.floordiv,
+        "contiguous",
+        "clamp",
+        "detach",
+        "detach_",
+        "mean",
+        "permute",
+        "repeat",
+        "repeat_interleave",
+        "reshape",
+        "resize_",
+        "relu",
+        "relu_",
+        "squeeze",
+        "squeeze_",
+        "transpose",
+        "unsqueeze",
+        "unsqueeze_",
+        "view"
+    ]
+    return [_get_share_qparams_op_backend_config(op) for op in share_qparams_ops]
+
+def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    """ Get configs related to batchnorm. """
+    bn_configs = []
+    bn_to_fused_bn = {
+        torch.nn.BatchNorm2d: nni.BNReLU2d,
+        torch.nn.BatchNorm3d: nni.BNReLU3d,
+    }
+    for bn in bn_to_fused_bn.keys():
+        fused_bn = bn_to_fused_bn[bn]
+        # bn module + relu module fusion config
+        bn_configs.append(
+            BackendPatternConfig((bn, nn.ReLU))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(_sequential_wrapper2(fused_bn))
+                .set_fused_module(fused_bn))
+        # bn module + F.relu fusion config
+        bn_configs.append(
+            BackendPatternConfig((bn, F.relu))
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                .set_fuser_method(_sequential_wrapper2(fused_bn))
+                .set_fused_module(fused_bn))
+        bn_configs.append(
+            BackendPatternConfig(bn)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs))
+
+    # fused bn configs
+    for fused_bn in bn_to_fused_bn.values():
+        bn_configs.append(
+            BackendPatternConfig(fused_bn)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs))
+    return bn_configs
+
+def _get_rnn_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    rnn_op_configs = []
+    for rnn_op, ref_rnn_op in [
+            (nn.GRUCell, nnqr.GRUCell),
+            (nn.LSTMCell, nnqr.LSTMCell),
+            (nn.RNNCell, nnqr.RNNCell),
+            (nn.LSTM, nnqr.LSTM),
+            (nn.GRU, nnqr.GRU)
+    ]:
+        rnn_op_configs.append(
+            BackendPatternConfig(rnn_op)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(rnn_op)
+                .set_reference_quantized_module(ref_rnn_op))
+    return rnn_op_configs
+
+def _get_embedding_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
+    embedding_op_configs = []
+    for embedding_op, qat_embedding_op, ref_embedding_op in [
+            (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
+            (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
+    ]:
+        embedding_op_configs.append(
+            BackendPatternConfig(embedding_op)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_qat_module(qat_embedding_op)
+                .set_root_module(embedding_op)
+                .set_reference_quantized_module(ref_embedding_op))
+
+        # config for qat op
+        embedding_op_configs.append(
+            BackendPatternConfig(qat_embedding_op)
+                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+                .set_dtype_configs(dtype_configs)
+                .set_root_module(embedding_op)
+                .set_reference_quantized_module(ref_embedding_op))
+    return embedding_op_configs
+
+def _get_tensor_info_op_configs(dtype_configs):
+    """
+    These ops work on tensors of different dtypes but return non-tensors
+    containing information about the input tensor.
+    """
+
+    def _get_config(op):
+        return BackendPatternConfig(op) \
+            .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED) \
+            .set_dtype_configs(dtype_configs)
+
+    return [_get_config(op) for op in ("shape", "size")]
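+
+# A minimal sketch of how the private helpers in this file are typically stitched
+# together; the backend name and the single dtype config below are illustrative only.
+def _example_build_backend_config():
+    from torch.ao.quantization.backend_config import BackendConfig
+    dtype_configs = [
+        DTypeConfig(
+            input_dtype=torch.quint8,
+            output_dtype=torch.quint8,
+            weight_dtype=torch.qint8,
+            bias_dtype=torch.float,
+        )
+    ]
+    return (
+        BackendConfig("example_backend")
+        .set_backend_pattern_configs(_get_linear_configs(dtype_configs))
+        .set_backend_pattern_configs(_get_conv_configs(dtype_configs))
+        .set_backend_pattern_configs(_get_bn_configs(dtype_configs))
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(dtype_configs))
+    )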
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..09bc1a9453c001cfbb337417d5b84c8d52d957da
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py
@@ -0,0 +1,160 @@
+import operator
+import torch
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+    BackendPatternConfig,
+)
+
+from typing import List
+
+weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+def get_linear_configs():
+    linear_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+
+    # TODO: need to fix the way we insert observers for this pattern
+    # should be solved in the new fusion API
+    # reason that this doesn't work: the pattern is a bit complicated and we don't
+    # have a way to specify which input of the pattern we would like to observe
+    # pattern:
+    # bias input weight
+    # \     |    /
+    #  \    |   t
+    #   \   |  /
+    #    addmm
+    # we want to observe "weight" as weight, but there is no way to convey this
+    # information with current pattern language
+    #
+    # right now:
+    # original:
+    #         weight - t \
+    #         input  - addmm
+    # observed (no hack):
+    #      weight - t - observer \
+    #       input - observer - addmm
+    # target:
+    #      weight - observer - t \
+    #        input - observer - addmm
+
+    # def root_node_getter(node_pattern):
+    #     addmm, bias, act, weight = node_pattern
+    #     return addmm
+
+    # linear_configs.append(
+    #     BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default))
+    #     .set_observation_type(observation_type)  # noqa: E131
+    #     .set_dtype_configs(dtype_configs)
+    #     ._set_root_node_getter(root_node_getter))
+
+    linear_configs.append(
+        BackendPatternConfig(torch.ops.aten.addmm.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 0})
+    )
+    # linear is decomposed to `t - mm` if bias is not present
+    linear_configs.append(
+        BackendPatternConfig(torch.ops.aten.mm.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1})
+    )
+    return linear_configs
+
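+# A minimal sketch of the decomposition the index mappings above refer to: with a
+# bias, linear decomposes to addmm(bias, x, weight.t()), so the bias is argument 0
+# and the (transposed) weight is argument 2; without a bias it decomposes to
+# mm(x, weight.t()), where the weight is argument 1. Shapes are illustrative only.
+def _example_linear_decomposition():
+    x = torch.randn(4, 8)
+    weight = torch.randn(16, 8)
+    bias = torch.randn(16)
+    with_bias = torch.ops.aten.addmm.default(bias, x, weight.t())
+    without_bias = torch.ops.aten.mm.default(x, weight.t())
+    assert with_bias.shape == without_bias.shape == (4, 16)
+    return with_bias, without_bias
+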
+def get_conv_configs():
+    conv_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    conv_configs.append(
+        BackendPatternConfig(torch.ops.aten.convolution.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    conv_configs.append(
+        BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu.default))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    # TODO: remove when functionalization is supported in PT2 mode
+    conv_configs.append(
+        BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu_.default))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    return conv_configs
+
+def get_pooling_configs():
+    backend_pattern_configs = []
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+
+    def root_node_getter(node_pattern):
+        getitem, maxpool, index = node_pattern
+        return maxpool
+
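+    # In the complex pattern format below, the last op in the sequence comes first,
+    # followed by its inputs: the tuple matches getitem(max_pool2d_with_indices(...), 0),
+    # i.e. selecting the values (index 0) from the op's (values, indices) output.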
+    backend_pattern_configs.append(
+        BackendPatternConfig()
+        ._set_pattern_complex_format((operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_root_node_getter(root_node_getter)
+    )
+
+    return backend_pattern_configs
+
+def get_relu_configs():
+    backend_pattern_configs = []
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    backend_pattern_configs.append(
+        BackendPatternConfig(torch.ops.aten.relu.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs))
+    return backend_pattern_configs
+
+def get_binary_op_configs():
+    binary_op_configs: List[BackendPatternConfig] = []
+    dtype_configs = [weighted_op_quint8_dtype_config]
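+    # Keyed by how many of the binary op's arguments are Tensors (the rest being
+    # scalars): with exactly one Tensor argument the output can share the input's
+    # observer, otherwise a separate output observer is used.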
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have extra check in prepare
+        # will need to change this to NO_OBSERVER later after we implemented
+        # Tensor dtype inference properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
+    for op_with_quantized_bop_scalar_variant in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
+        bop_patterns = [
+            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default),
+            op_with_quantized_bop_scalar_variant,
+            # TODO: remove when functionalization is supported in pt2_mode
+            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
+        ]
+        for bop_pattern in bop_patterns:
+            binary_op_configs.append(
+                BackendPatternConfig(bop_pattern)
+                    .set_dtype_configs(dtype_configs)  # noqa: E131
+                    ._set_num_tensor_args_to_observation_type(num_tensor_args_to_observation_type_mapping))
+
+    return binary_op_configs
+
+def get_qnnpack_pt2e_backend_config():
+    return (
+        BackendConfig("qnnpack_pytorch_2.0_export")
+        .set_backend_pattern_configs(get_linear_configs())
+        .set_backend_pattern_configs(get_binary_op_configs())
+        .set_backend_pattern_configs(get_conv_configs())
+        .set_backend_pattern_configs(get_pooling_configs())
+        .set_backend_pattern_configs(get_relu_configs())
+    )
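+
+# A minimal sketch of consuming the config built above; it relies only on the
+# public BackendConfig.configs property.
+def _example_inspect_backend_config():
+    backend_config = get_qnnpack_pt2e_backend_config()
+    # Patterns registered through the complex format show up with pattern=None;
+    # the linear (addmm) pattern from get_linear_configs() is registered directly.
+    assert any(config.pattern is torch.ops.aten.addmm.default
+               for config in backend_config.configs)
+    return backend_config.configs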
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c287f9aca396b29fdc8b71ce7913c9ec1361d67e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/backend_config.py
@@ -0,0 +1,659 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import torch
+from torch.ao.quantization.utils import Pattern
+from enum import Enum
+
+
+__all__ = [
+    "BackendConfig",
+    "BackendPatternConfig",
+    "DTypeConfig",
+    "DTypeWithConstraints",
+    "ObservationType",
+]
+
+
+# DTypeConfig dict keys
+INPUT_DTYPE_DICT_KEY = "input_dtype"
+OUTPUT_DTYPE_DICT_KEY = "output_dtype"
+WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
+BIAS_DTYPE_DICT_KEY = "bias_dtype"
+IS_DYNAMIC_DICT_KEY = "is_dynamic"
+
+# BackendConfig dict keys
+NAME_DICT_KEY = "name"
+CONFIGS_DICT_KEY = "configs"
+
+# BackendPatternConfig dict keys
+PATTERN_DICT_KEY = "pattern"
+PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format"
+OBSERVATION_TYPE_DICT_KEY = "observation_type"
+DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
+ROOT_MODULE_DICT_KEY = "root_module"
+QAT_MODULE_DICT_KEY = "qat_module"
+REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
+FUSED_MODULE_DICT_KEY = "fused_module"
+FUSER_METHOD_DICT_KEY = "fuser_method"
+ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
+EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
+NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
+INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
+
+
+# TODO: maybe rename this to something that's not related to observer
+# e.g. QParamsType
+class ObservationType(Enum):
+    """ An enum that represents different ways of how an operator/operator pattern
+    should be observed
+    """
+
+    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
+    """this means input and output are observed with different observers, based
+    on qconfig.activation
+    example: conv, linear, softmax
+    """
+
+    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
+    """this means the output will use the same observer instance as input, based
+    on qconfig.activation
+    example: torch.cat, maxpool
+    """
+
+    INPUT_OUTPUT_NOT_OBSERVED = 2
+    """this means the input and output are never observed
+    example: x.shape, x.size
+    """
+
+
+@dataclass
+class DTypeWithConstraints:
+    """
+    Config for specifying additional constraints for a given dtype, such as quantization
+    value ranges, scale value ranges, and fixed quantization params, to be used in
+    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.
+
+    The constraints currently supported are:
+
+    * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum quantized values respectively. If
+      the QConfig’s `quant_min` and `quant_max` fall outside this range,
+      then the QConfig will be ignored.
+
+    * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum scale values respectively. If the
+      QConfig’s minimum scale value (currently exposed as `eps`) falls below
+      the lower bound, then the QConfig will be ignored. Note that the upper
+      bound is currently not enforced.
+
+    * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements
+      for scale and zero point, to be used for operators with fixed quantization
+      parameters such as sigmoid and tanh. If the observer specified in the QConfig
+      is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if
+      the quantization parameters don't match, then the QConfig will be ignored.
+    """
+    dtype: Optional[torch.dtype] = None
+    quant_min_lower_bound: Union[int, float, None] = None
+    quant_max_upper_bound: Union[int, float, None] = None
+    scale_min_lower_bound: Union[int, float, None] = None
+    scale_max_upper_bound: Union[int, float, None] = None
+    scale_exact_match: Optional[float] = None
+    zero_point_exact_match: Optional[int] = None
+
+
+@dataclass
+class DTypeConfig:
+    """
+    Config object that specifies the supported data types passed as arguments to
+    quantize ops in the reference model spec, for input and output activations,
+    weights, and biases.
+
+    For example, consider the following reference model:
+
+      quant1 - [dequant1 - fp32_linear - quant2] - dequant2
+
+    The pattern in the square brackets refers to the reference pattern of
+    statically quantized linear. Setting the input dtype as `torch.quint8`
+    in the DTypeConfig means we pass in `torch.quint8` as the dtype argument
+    to the first quantize op (quant1). Similarly, setting the output dtype as
+    `torch.quint8` means we pass in `torch.quint8` as the dtype argument to
+    the second quantize op (quant2).
+
+    Note that the dtype here does not refer to the interface dtypes of the
+    op. For example, the "input dtype" here is not the dtype of the input
+    tensor passed to the quantized linear op. Though it can still be the
+    same as the interface dtype, this is not always the case, e.g. the
+    interface dtype is fp32 in dynamic quantization but the "input dtype"
+    specified in the DTypeConfig would still be quint8. The semantics of
+    dtypes here are the same as the semantics of the dtypes specified in
+    the observers.
+
+    These dtypes are matched against the ones specified in the user’s
+    QConfig. If there is a match, and the QConfig satisfies the constraints
+    specified in the DTypeConfig (if any), then we will quantize the given
+    pattern using this DTypeConfig. Otherwise, the QConfig is ignored and
+    the pattern will not be quantized.
+
+    Example usage::
+
+        >>> # xdoctest: +SKIP(failing)
+        >>> dtype_config1 = DTypeConfig(
+        ...     input_dtype=torch.quint8,
+        ...     output_dtype=torch.quint8,
+        ...     weight_dtype=torch.qint8,
+        ...     bias_dtype=torch.float)
+
+        >>> dtype_config2 = DTypeConfig(
+        ...     input_dtype=DTypeWithConstraints(
+        ...         dtype=torch.quint8,
+        ...         quant_min_lower_bound=0,
+        ...         quant_max_upper_bound=255,
+        ...     ),
+        ...     output_dtype=DTypeWithConstraints(
+        ...         dtype=torch.quint8,
+        ...         quant_min_lower_bound=0,
+        ...         quant_max_upper_bound=255,
+        ...     ),
+        ...     weight_dtype=DTypeWithConstraints(
+        ...         dtype=torch.qint8,
+        ...         quant_min_lower_bound=-128,
+        ...         quant_max_upper_bound=127,
+        ...     ),
+        ...     bias_dtype=torch.float)
+
+        >>> dtype_config1.input_dtype
+        torch.quint8
+
+        >>> dtype_config2.input_dtype
+        torch.quint8
+
+        >>> dtype_config2.input_dtype_with_constraints
+        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
+scale_min_lower_bound=None, scale_max_upper_bound=None, scale_exact_match=None, \
+zero_point_exact_match=None)
+    """
+    input_dtype_with_constraints: DTypeWithConstraints
+    output_dtype_with_constraints: DTypeWithConstraints
+    weight_dtype_with_constraints: DTypeWithConstraints
+    bias_dtype: Optional[torch.dtype]
+    is_dynamic: Optional[bool]
+
+    def __init__(
+        self,
+        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        bias_dtype: Optional[torch.dtype] = None,
+        is_dynamic: Optional[bool] = None,
+    ):
+        if isinstance(input_dtype, DTypeWithConstraints):
+            self.input_dtype_with_constraints = input_dtype
+        else:
+            self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)
+
+        if isinstance(output_dtype, DTypeWithConstraints):
+            self.output_dtype_with_constraints = output_dtype
+        else:
+            self.output_dtype_with_constraints = DTypeWithConstraints(dtype=output_dtype)
+
+        if isinstance(weight_dtype, DTypeWithConstraints):
+            self.weight_dtype_with_constraints = weight_dtype
+        else:
+            self.weight_dtype_with_constraints = DTypeWithConstraints(dtype=weight_dtype)
+
+        self.bias_dtype = bias_dtype
+        self.is_dynamic = is_dynamic
+
+    @property
+    def input_dtype(self) -> Optional[torch.dtype]:
+        return self.input_dtype_with_constraints.dtype
+
+    @property
+    def output_dtype(self) -> Optional[torch.dtype]:
+        return self.output_dtype_with_constraints.dtype
+
+    @property
+    def weight_dtype(self) -> Optional[torch.dtype]:
+        return self.weight_dtype_with_constraints.dtype
+
+    @classmethod
+    def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
+        """
+        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
+            "input_dtype": torch.dtype or ``DTypeWithConstraints``
+            "output_dtype": torch.dtype or ``DTypeWithConstraints``
+            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
+            "bias_type": torch.dtype
+            "is_dynamic": bool
+        """
+        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
+        if input_dtype is not None and not isinstance(input_dtype, (torch.dtype, DTypeWithConstraints)):
+            raise ValueError("Expected input_dtype to be a torch.dtype or DTypeWithConstraints")
+        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
+        if output_dtype is not None and not isinstance(output_dtype, (torch.dtype, DTypeWithConstraints)):
+            raise ValueError("Expected output_dtype to be a torch.dtype or DTypeWithConstraints")
+        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
+        if weight_dtype is not None and not isinstance(weight_dtype, (torch.dtype, DTypeWithConstraints)):
+            raise ValueError("Expected weight_dtype to be a torch.dtype or DTypeWithConstraints")
+        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
+        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
+        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``DTypeConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
+        """
+        dtype_config_dict: Dict[str, Any] = {}
+        if self.input_dtype is not None:
+            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
+        if self.output_dtype is not None:
+            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints
+        if self.weight_dtype is not None:
+            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints
+        if self.bias_dtype is not None:
+            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
+        if self.is_dynamic is not None:
+            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
+        return dtype_config_dict
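+
+# A minimal sketch of the dict round trip described in from_dict/to_dict above;
+# the dtypes chosen below are illustrative only.
+def _example_dtype_config_round_trip() -> Dict[str, Any]:
+    dtype_config = DTypeConfig.from_dict({
+        "input_dtype": torch.quint8,
+        "output_dtype": torch.quint8,
+        "weight_dtype": torch.qint8,
+        "bias_dtype": torch.float,
+    })
+    # Plain dtypes are normalized into DTypeWithConstraints internally, so the
+    # round-tripped dict holds DTypeWithConstraints for the activations and the
+    # weight, and the plain torch.dtype for the bias.
+    return dtype_config.to_dict()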
+
+
+class BackendConfig:
+    # TODO: refer to NativeBackendConfig once that is implemented
+    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
+    quantized models can be produced from these patterns.
+
+    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
+    of the above. Each pattern supported on the target backend can be individually configured through
+    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:
+
+    (1) The supported input/output activation, weight, and bias data types
+
+    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and
+
+    (3) (Optionally) Fusion, QAT, and reference module mappings.
+
+    The format of the patterns is described in:
+    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md
+
+    Example usage::
+
+        import torch
+        from torch.ao.quantization.backend_config import (
+            BackendConfig,
+            BackendPatternConfig,
+            DTypeConfig,
+            ObservationType,
+        )
+
+        weighted_int8_dtype_config = DTypeConfig(
+            input_dtype=torch.quint8,
+            output_dtype=torch.quint8,
+            weight_dtype=torch.qint8,
+            bias_dtype=torch.float)
+
+        def fuse_conv2d_relu(is_qat, conv, relu):
+            return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)
+
+        # For quantizing Linear
+        linear_config = BackendPatternConfig(torch.nn.Linear) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_root_module(torch.nn.Linear) \
+            .set_qat_module(torch.ao.nn.qat.Linear) \
+            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
+
+        # For fusing Conv2d + ReLU into ConvReLU2d
+        conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
+            .set_fuser_method(fuse_conv2d_relu)
+
+        # For quantizing ConvReLU2d
+        fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_root_module(torch.nn.Conv2d) \
+            .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
+            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)
+
+        backend_config = BackendConfig("my_backend") \
+            .set_backend_pattern_config(linear_config) \
+            .set_backend_pattern_config(conv_relu_config) \
+            .set_backend_pattern_config(fused_conv_relu_config)
+
+    """
+    def __init__(self, name: str = ""):
+        self.name = name
+        # Store all BackendPatternConfigs in a map to handle duplicates
+        # Note: the key in this map uses the complex reversed tuple format.
+        # This is intended only for internal use; users who wish to access
+        # the original patterns should go through `self.configs` instead.
+        self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {}
+
+    def __repr__(self):
+        return f"BackendConfig({self.__dict__})"
+
+    def set_name(self, name: str) -> BackendConfig:
+        """
+        Set the name of the target backend.
+        """
+        self.name = name
+        return self
+
+    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
+        """
+        Set the config for a pattern that can be run on the target backend.
+        This overrides any existing config for the given pattern.
+        """
+        # Avoid circular dependencies
+        pattern_complex_format = torch.ao.quantization.backend_config.utils \
+            ._get_pattern_in_reversed_nested_tuple_format(config)  # type: ignore[attr-defined]
+        self._pattern_complex_format_to_config[pattern_complex_format] = config
+        return self
+
+    def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig:
+        """
+        Set the configs for patterns that can be run on the target backend.
+        This overrides any existing config for a given pattern that was previously registered.
+        """
+        for conf in configs:
+            self.set_backend_pattern_config(conf)
+        return self
+
+    @property
+    def configs(self) -> List[BackendPatternConfig]:
+        """
+        Return a copy of the list of configs set in this `BackendConfig`.
+        """
+        return list(self._pattern_complex_format_to_config.values())
+
+    @classmethod
+    def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
+        """
+        Create a ``BackendConfig`` from a dictionary with the following items:
+
+            "name": the name of the target backend
+
+            "configs": a list of dictionaries that each represents a `BackendPatternConfig`
+
+        """
+        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
+        for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
+            if isinstance(d, BackendPatternConfig):
+                conf.set_backend_pattern_config(d)
+            elif isinstance(d, Dict):
+                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
+            else:
+                raise ValueError(f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary")
+        return conf
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``BackendConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
+        """
+        return {
+            NAME_DICT_KEY: self.name,
+            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs],
+        }
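+
+# A minimal sketch of the override behavior documented in set_backend_pattern_config:
+# registering a second config for the same pattern replaces the first. The dtype
+# values used here are illustrative only.
+def _example_backend_config_override():
+    first = BackendPatternConfig(torch.nn.Linear).set_dtype_configs(
+        [DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8)])
+    second = BackendPatternConfig(torch.nn.Linear).set_dtype_configs(
+        [DTypeConfig(input_dtype=torch.qint8, output_dtype=torch.qint8)])
+    backend_config = (
+        BackendConfig("example")
+        .set_backend_pattern_config(first)
+        .set_backend_pattern_config(second)
+    )
+    # Only the most recently registered config for torch.nn.Linear survives.
+    assert backend_config.configs == [second]
+    return backend_config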
+
+
+class BackendPatternConfig:
+    """
+    Config object that specifies quantization behavior for a given operator pattern.
+    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    """
+    def __init__(self, pattern: Optional[Pattern] = None):
+        self.pattern: Optional[Pattern] = pattern
+        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+        self.dtype_configs: List[DTypeConfig] = []
+        self.root_module: Optional[Type[torch.nn.Module]] = None
+        self.qat_module: Optional[Type[torch.nn.Module]] = None
+        self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None
+        self.fused_module: Optional[Type[torch.nn.Module]] = None
+        self.fuser_method: Optional[Callable] = None
+
+        # Temporary/internal configs
+        self._root_node_getter: Optional[Callable] = None
+        self._extra_inputs_getter: Optional[Callable] = None
+        self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {}
+        self._input_type_to_index: Dict[str, int] = {}
+        self._pattern_complex_format: Optional[Pattern] = None
+
+    def __repr__(self):
+        dict_nonempty = {
+            k: v for k, v in self.__dict__.items()
+            if (
+                (not isinstance(v, (list, dict)) and v is not None)
+                or (isinstance(v, (list, dict)) and len(v) > 0)
+            )
+        }
+        return f"BackendPatternConfig({dict_nonempty})"
+
+    def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
+        """
+        Set the pattern to configure.
+
+        The pattern can be a float module, functional operator, pytorch operator, or a tuple
+        combination of the above. Tuple patterns are treated as sequential patterns, and
+        currently only tuples of 2 or 3 elements are supported.
+        """
+        if self._pattern_complex_format is not None:
+            raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
+        self.pattern = pattern
+        return self
+
+    def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
+        """
+        Set how observers should be inserted in the graph for this pattern.
+
+        Observation type here refers to how observers (or quant-dequant ops) will be placed
+        in the graph. This is used to produce the desired reference patterns understood by
+        the backend. Weighted ops such as linear and conv require different observers
+        (or quantization parameters passed to quantize ops in the reference model) for the
+        input and the output.
+
+        There are two observation types:
+
+            `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
+            will be different from the input. This is the most common observation type.
+
+            `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
+            same as the input. This is useful for operators like `cat`.
+
+        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
+        with observers (and fake quantizes) attached instead of observers themselves.
+        """
+        self.observation_type = observation_type
+        return self
+
+    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
+        """
+        Add a set of supported data types passed as arguments to quantize ops in the
+        reference model spec.
+        """
+        self.dtype_configs.append(dtype_config)
+        return self
+
+    def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
+        """
+        Set the supported data types passed as arguments to quantize ops in the
+        reference model spec, overriding all previously registered data types.
+        """
+        self.dtype_configs = dtype_configs
+        return self
+
+    def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
+        """
+        Set the module that represents the root for this pattern.
+
+        When we construct the reference quantized model during the convert phase,
+        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
+        will be swapped to the corresponding reference quantized modules (e.g.
+        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
+        specify custom reference quantized module implementations to match the
+        numerics of their lowered operators. Since this is a one-to-one mapping,
+        both the root module and the reference quantized module must be specified
+        in the same BackendPatternConfig in order for the conversion to take place.
+        """
+        self.root_module = root_module
+        return self
+
+    def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig:
+        """
+        Set the module that represents the QAT implementation for this pattern.
+        """
+        self.qat_module = qat_module
+        return self
+
+    def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
+        """
+        Set the module that represents the reference quantized implementation for
+        this pattern's root module.
+
+        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
+        """
+        self.reference_quantized_module = reference_quantized_module
+        return self
+
+    def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig:
+        """
+        Set the module that represents the fused implementation for this pattern.
+        """
+        self.fused_module = fused_module
+        return self
+
+    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
+        """
+        Set the function that specifies how to fuse this BackendPatternConfig's pattern.
+
+        The first argument of this function should be `is_qat`, and the rest of the arguments
+        should be the items in the tuple pattern. The return value of this function should be
+        the resulting fused module.
+
+        For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be:
+
+            def fuse_linear_relu(is_qat, linear, relu):
+                return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
+
+        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
+        """
+        self.fuser_method = fuser_method
+        return self
+
+    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
+        self._root_node_getter = root_node_getter
+        return self
+
+    def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig:
+        self._extra_inputs_getter = extra_inputs_getter
+        return self
+
+    def _set_num_tensor_args_to_observation_type(
+            self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig:
+        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
+        return self
+
+    def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig:
+        self._input_type_to_index = input_type_to_index
+        return self
+
+    def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig:
+        """
+        Set the pattern to configure, using the reversed nested tuple format.
+
+        See the BackendConfig README for more detail:
+        https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification
+        """
+        if self.pattern is not None:
+            raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set")
+        self._pattern_complex_format = pattern
+        return self
+
+    @classmethod
+    def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
+        """
+        Create a ``BackendPatternConfig`` from a dictionary with the following items:
+
+            "pattern": the pattern being configured
+            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
+            observers should be inserted for this pattern
+            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
+            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
+            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
+            "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
+            implementation for this pattern's root module.
+            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
+            "fuser_method": a function that specifies how to fuse the pattern for this pattern
+            "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated)
+
+        """
+        def _get_dtype_config(obj: Any) -> DTypeConfig:
+            """
+            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
+            """
+            if isinstance(obj, DTypeConfig):
+                return obj
+            if isinstance(obj, Dict):
+                return DTypeConfig.from_dict(obj)
+            raise ValueError(
+                f"Expected a list of DTypeConfigs in "
+                f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'"
+            )
+
+        conf = cls()
+        if PATTERN_DICT_KEY in backend_pattern_config_dict:
+            conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY])
+        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
+            conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY])
+        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
+            conf.add_dtype_config(_get_dtype_config(d))
+        conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None))
+        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
+        conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None))
+        conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None))
+        conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None))
+        conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None))
+        conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None))
+        conf._set_num_tensor_args_to_observation_type(
+            backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}))
+        conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {}))
+        if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict:
+            conf._set_pattern_complex_format(backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY])
+        return conf
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
+        """
+        backend_pattern_config_dict: Dict[str, Any] = {
+            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
+            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
+        }
+        if self.pattern is not None:
+            backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern
+        if self.root_module is not None:
+            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
+        if self.qat_module is not None:
+            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
+        if self.reference_quantized_module is not None:
+            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module
+        if self.fused_module is not None:
+            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
+        if self.fuser_method is not None:
+            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
+        if self._root_node_getter is not None:
+            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter
+        if self._extra_inputs_getter is not None:
+            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter
+        if len(self._num_tensor_args_to_observation_type) > 0:
+            backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type
+        if len(self._input_type_to_index) > 0:
+            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index
+        if self._pattern_complex_format is not None:
+            backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = self._pattern_complex_format
+        return backend_pattern_config_dict
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..871b969de4e2b14c209ebe96315cd17cd628045e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/executorch.py
@@ -0,0 +1,494 @@
+# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime
+# not a specific backend
+
+import operator
+from typing import List
+
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.quantized.reference as nnqr
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..fuser_method_mappings import (
+    _sequential_wrapper2,
+    fuse_conv_bn,
+    fuse_conv_bn_relu,
+)
+from ._common_operator_config_utils import _Conv2dMetadata
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    DTypeWithConstraints,
+    ObservationType,
+)
+from .qnnpack import (
+    qnnpack_default_op_qint8_symmetric_dtype_config,
+    qnnpack_weighted_op_qint8_symmetric_dtype_config,
+)
+
+
+__all__ = [
+    "get_executorch_backend_config",
+]
+
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+executorch_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+executorch_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+executorch_default_dynamic_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    scale_min_lower_bound=2**-12,
+)
+
+executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    quant_min_lower_bound=-127,
+    quant_max_upper_bound=127,
+    scale_min_lower_bound=2**-12,
+)
+
+executorch_default_dynamic_qint8_dtype_config = DTypeConfig(
+    input_dtype=executorch_act_qint8_scale_min_2_neg_12,
+    output_dtype=torch.float,
+    weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+
+# =============================
+# |  BACKEND PATTERN CONFIGS  |
+# =============================
+
+
+def _get_linear_configs() -> List[BackendPatternConfig]:
+    """
+    Return all configs related to linear modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+        executorch_default_dynamic_quint8_dtype_config,
+        executorch_default_dynamic_qint8_dtype_config,
+        executorch_default_dynamic_float16_dtype_config,
+    ]
+    linear_configs: List[BackendPatternConfig] = []
+    # linear module
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+        .set_qat_module(nnqat.Linear)
+    )
+    # linear qat module
+    linear_configs.append(
+        BackendPatternConfig(nnqat.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+    )
+    # functional linear
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.functional.linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    return linear_configs
+
+
+def _get_conv_configs() -> List[BackendPatternConfig]:
+    """
+    Return all configs related to conv modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+    ]
+    conv_configs = []
+    for convs in [_Conv2dMetadata]:
+        # (1) Single conv modules/functions
+        # -----------------------------------
+        # conv module
+        conv_configs.append(
+            BackendPatternConfig(convs.root)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.qat)
+        )
+        # conv qat module
+        conv_configs.append(
+            BackendPatternConfig(convs.qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # functional conv
+        conv_configs.append(
+            BackendPatternConfig(convs.func)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1, "bias": 2})
+        )
+
+        # (2) Conv + relu
+        # -----------------------------------
+        # conv module + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.root, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # conv module + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.root, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # fused conv relu module
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.relu_qat)
+        )
+        # conv relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # functional conv + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.func, nn.ReLU))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+        # functional conv + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.func, F.relu))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+        # fused conv relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.relu_qat)
+        )
+
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+
+        # (3) Conv + batchnorm (+ relu)
+        # -------------------------------
+        # conv + batchnorm (+ relu)
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn)
+            .set_fused_module(convs.fused_conv_bn)
+        )
+        # conv + bn + relu module fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # conv + bn + relu functional fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # TODO: we can add fusion for torch.relu as well
+        # 3.2 conv + bn (+ relu) fused module configs
+        # fused conv bn
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_qat)
+        )
+
+        # fused conv bn relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_relu_qat)
+        )
+
+        # conv bn, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # conv bn relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+    return conv_configs
+
+
+def _get_binary_ops_configs() -> List[BackendPatternConfig]:
+    """
+    Return all configs related to binary ops.
+    """
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+    ]
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have an extra check in prepare;
+        # we will need to change this to NO_OBSERVER later, after we have
+        # implemented Tensor dtype inference properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
+    binary_op_configs: List[BackendPatternConfig] = []
+    for op in [operator.add, torch.add, operator.sub, torch.sub, operator.mul, torch.mul]:
+        bop_patterns = [
+            (op, torch.nn.ReLU),
+            (op, torch.nn.functional.relu),
+            (op, torch.relu),
+            op
+        ]
+        for bop_pattern in bop_patterns:
+            binary_op_configs.append(
+                BackendPatternConfig(bop_pattern)
+                .set_dtype_configs(dtype_configs)  # noqa: E131
+                ._set_num_tensor_args_to_observation_type(
+                    num_tensor_args_to_observation_type_mapping
+                )
+            )
+    return binary_op_configs
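+
+
+# Illustrative note on the mapping above (an interpretation, not part of the
+# upstream source): the number of Tensor arguments at a binary op call site
+# decides how the output is observed, e.g.
+#
+#     z = x + 1.0   # one Tensor argument  -> output shares the input's observer
+#     z = x + y     # two Tensor arguments -> output gets its own observer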
+
+
+def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
+    """
+    Return the operator configs for operators that work on both float and quantized
+    input. If the input is quantized, the output Tensor shares the same quantization
+    parameters with the input.
+
+    Example operators: avgpool2d, reshape, transpose, maxpool2d
+    Example observed operator:
+    observer_0 - avgpool2d - observer_0 (same observer instance as input)
+    """
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    share_qparams_ops = [
+        torch.nn.Flatten,
+        F.adaptive_avg_pool2d,
+        F.elu,
+        F.hardtanh,
+        F.max_pool2d,
+        F.pad,
+        F.relu,
+        F.relu6,
+        F.leaky_relu,
+        F.leaky_relu_,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.ConstantPad2d,
+        torch.nn.ELU,
+        torch.nn.MaxPool2d,
+        torch.nn.ReLU6,
+        torch.nn.Hardtanh,
+        torch.nn.LeakyReLU,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        torch.permute,
+        torch.permute_copy,
+        torch.squeeze,
+        "clamp",
+        "mean",
+        "permute",
+        "reshape",
+        "relu",
+        "relu_",
+        "squeeze",
+        "squeeze_",
+        "leaky_relu",
+    ]
+    share_qparams_op_configs: List[BackendPatternConfig] = []
+    for op in share_qparams_ops:
+        share_qparams_op_configs.append(
+            BackendPatternConfig(op)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+    return share_qparams_op_configs
+
+
+def _get_bn_configs() -> List[BackendPatternConfig]:
+    """
+    Return all configs related to batchnorm.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    bn_configs = []
+    bn_configs.append(
+        BackendPatternConfig(nn.BatchNorm2d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    return bn_configs
+
+
+def _get_cat_configs() -> List[BackendPatternConfig]:
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    cat_configs = []
+    cat_configs.append(
+        BackendPatternConfig(torch.cat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    cat_configs.append(
+        BackendPatternConfig(torch.concat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    cat_configs.append(
+        BackendPatternConfig(torch.concatenate)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    return cat_configs
+
+
+def _get_embedding_op_configs() -> List[BackendPatternConfig]:
+    dtype_configs = [
+        executorch_weight_only_quint8_dtype_config,
+    ]
+    embedding_op_configs = []
+    for embedding_op, qat_embedding_op, ref_embedding_op in [
+        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
+        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
+    ]:
+        embedding_op_configs.append(
+            BackendPatternConfig(embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_qat_module(qat_embedding_op)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+        # config for qat op
+        embedding_op_configs.append(
+            BackendPatternConfig(qat_embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+
+        # config for functional embedding
+        embedding_op_configs.append(
+            BackendPatternConfig(torch.nn.functional.embedding)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1})
+        )
+    return embedding_op_configs
+
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_executorch_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack.
+    """
+    return (
+        BackendConfig("executorch")
+        .set_backend_pattern_configs(_get_linear_configs())
+        .set_backend_pattern_configs(_get_conv_configs())
+        .set_backend_pattern_configs(_get_binary_ops_configs())
+        .set_backend_pattern_configs(_get_share_qparams_ops_configs())
+        .set_backend_pattern_configs(_get_bn_configs())
+        .set_backend_pattern_configs(_get_cat_configs())
+        .set_backend_pattern_configs(_get_embedding_op_configs())
+    )
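+
+
+# A minimal usage sketch (an addition for illustration, not part of the upstream
+# source), assuming the standard FX graph mode quantization flow; `float_model`
+# and `example_inputs` are hypothetical placeholders:
+#
+#     from torch.ao.quantization import get_default_qconfig_mapping
+#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
+#
+#     backend_config = get_executorch_backend_config()
+#     qconfig_mapping = get_default_qconfig_mapping("qnnpack")
+#     prepared = prepare_fx(float_model, qconfig_mapping, example_inputs,
+#                           backend_config=backend_config)
+#     quantized = convert_fx(prepared, backend_config=backend_config)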
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca2d267ee1597ed2ef360611cdaaf1735313587
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/fbgemm.py
@@ -0,0 +1,116 @@
+import torch
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+__all__ = [
+    "get_fbgemm_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py
+# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max,
+# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
+# values to within [0, 127].
+
+fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+fbgemm_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+fbgemm_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+fbgemm_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+fbgemm_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+fbgemm_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_fbgemm_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native FBGEMM backend.
+    """
+    conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        fbgemm_weighted_op_quint8_dtype_config,
+        fbgemm_default_dynamic_int8_dtype_config,
+        fbgemm_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        fbgemm_default_dynamic_int8_dtype_config,
+        fbgemm_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        fbgemm_weight_only_quint8_dtype_config,
+        fbgemm_weight_only_quint4x2_dtype_config,
+    ]
+    return BackendConfig("fbgemm") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
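+
+
+# A small inspection sketch (an illustration, not part of the upstream source),
+# assuming only the public BackendConfig accessors are used; it lists the patterns
+# this backend config registers and how their outputs are observed:
+#
+#     backend_config = get_fbgemm_backend_config()
+#     for pattern_config in backend_config.configs:
+#         print(pattern_config.pattern, pattern_config.observation_type)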
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/native.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/native.py
new file mode 100644
index 0000000000000000000000000000000000000000..5425e5173fd7711c00e52a5f56197b9460d2b466
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/native.py
@@ -0,0 +1,204 @@
+import torch
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_ln_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+__all__ = [
+    "get_test_only_legacy_native_backend_config",
+    "default_op_quint8_dtype_config",
+    "default_op_fp16_dtype_config",
+    "default_dynamic_int8_dtype_config",
+    "default_dynamic_float16_dtype_config",
+    "input_output_only_quint8_dtype_config",
+    "weight_only_quint8_dtype_config",
+    "weight_only_quint4x2_dtype_config",
+    "get_native_backend_config",
+    "get_native_backend_config_dict",
+    "get_test_only_legacy_native_backend_config_dict",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# weighted op int8 dtype config
+# this is the config for ops that have quantized weights, such as linear and conv
+weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    # currently the dtype check is not yet enabled, so we provide the dtype_configs
+    # here but they are not really used yet;
+    # we will enable the check after everything has moved to backend_config_dict
+    is_dynamic=True,
+)
+
+default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    # currently the dtype check is not yet enabled, so we provide the dtype_configs
+    # here but they are not really used yet;
+    # we will enable the check after everything has moved to backend_config_dict
+    is_dynamic=True,
+)
+
+# Needed for LayerNorm and F.layer_norm, since currently the kernel only supports float weights
+input_output_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.float,
+    bias_dtype=torch.float,
+)
+
+weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_test_only_legacy_native_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
+    """
+    conv_dtype_configs = [weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        weighted_op_quint8_dtype_config,
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    default_op_dtype_configs = [default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config
+    ]
+    tensor_info_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+    ]
+    rnn_op_dtype_configs = [
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        weight_only_quint8_dtype_config,
+        weight_only_quint4x2_dtype_config,
+    ]
+    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
+    return BackendConfig("_native_and_fp16") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
+
+def get_native_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
+    """
+    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
+    conv_dtype_configs = [weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        weighted_op_quint8_dtype_config,
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [default_op_quint8_dtype_config]
+    default_op_dtype_configs = [default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
+    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        weight_only_quint8_dtype_config,
+        weight_only_quint4x2_dtype_config,
+    ]
+    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
+    return BackendConfig("native") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
+
+def get_native_backend_config_dict():
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
+    """
+    return get_native_backend_config().to_dict()
+
+def get_test_only_legacy_native_backend_config_dict():
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
+    fp16 ops in dictionary form.
+    """
+    return get_test_only_legacy_native_backend_config().to_dict()
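+
+
+# A sketch of the dictionary form (an illustration, not part of the upstream
+# source), assuming the layout produced by BackendConfig.to_dict(), i.e. a "name"
+# entry plus a list of per-pattern config dicts:
+#
+#     native_dict = get_native_backend_config_dict()
+#     print(native_dict["name"])          # "native"
+#     print(len(native_dict["configs"]))  # number of registered pattern configs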
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/observation_type.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/observation_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d2f7ff42c3913c4eaee43a919251ad12fc20e0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/onednn.py
@@ -0,0 +1,542 @@
+import torch
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+import torch.nn.functional as F
+import torch.ao.nn.quantized.reference as nnqr
+from ._common_operator_config_utils import (
+    _get_conv_configs,
+    _get_linear_configs,
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_ln_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+)
+from .backend_config import (
+    BackendPatternConfig,
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+)
+from ..fuser_method_mappings import (
+    _sequential_wrapper2,
+)
+import operator
+from torch.ao.quantization.utils import MatchAllNode
+import itertools
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+onednn_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+onednn_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+onednn_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+onednn_weight_only_qint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+)
+
+onednn_input_output_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.float,
+    bias_dtype=torch.float,
+)
+
+# ===================
+# |  FUSER METHODS  |
+# ===================
+
+def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
+    r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module
+    Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+                or post training quantization fusion
+        linear: Module instance of type Linear
+        bn: BatchNorm1d instance that needs to be fused with the linear layer
+        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer
+    Examples::
+        >>> # xdoctest: +SKIP(failing)
+        >>> m1 = nn.Linear(20, 10)
+        >>> b1 = nn.BatchNorm1d(10)
+        >>> lr = nn.LeakyReLU(0.01)
+        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
+    """
+    assert linear.training == bn.training and bn.training == leaky_relu.training, \
+        "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
+
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(linear, bn, leaky_relu)}")
+    else:
+        map_to_fused_module_eval = {
+            nn.Linear: nni.LinearLeakyReLU,
+        }
+        fused_module = map_to_fused_module_eval.get(type(linear), None)
+        if fused_module is not None:
+            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
+            fm = fused_module(fused_linear, leaky_relu)
+            return fm
+        else:
+            raise NotImplementedError(f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}")
+
+# ======================
+# |  CONFIGS FOR CONV  |
+# ======================
+observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+
+conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
+conv_configs = _get_conv_configs(conv_dtype_configs)
+
+# (1) Conv2d + Add
+
+# conv2d   Y
+#   \   /
+#    add
+
+# include:
+# conv2d conv2d
+#   \   /
+#    add
+
+def _fuse_conv_add_left(is_qat, add, conv, _):
+    return nni.ConvAdd2d(conv, add)
+
+def _conv_add_root_node_getter_left(pattern):
+    _, conv, _ = pattern
+    return conv
+
+def _conv_add_extra_inputs_getter_left(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, conv, extra_input = pattern
+    return [extra_input]
+
+# conv2d
+#  \
+#  bn   Y
+#   \   /
+#    add
+
+def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAdd2d(fused_conv, add)
+
+def _conv_bn_add_root_node_getter_left(add_pattern):
+    _, bn_conv, _ = add_pattern
+    bn, conv = bn_conv
+    return conv
+
+def _conv_bn_add_extra_inputs_getter_left(add_pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, bn_conv, extra_input = add_pattern
+    bn, conv = bn_conv
+    return [extra_input]
+
+conv_add_left_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_left_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_bn_add_left)
+                ._set_root_node_getter(_conv_bn_add_root_node_getter_left)
+                ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left)
+                .set_fused_module(nni.ConvAdd2d))
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_add_left)
+                ._set_root_node_getter(_conv_add_root_node_getter_left)
+                ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left)
+                .set_fused_module(nni.ConvAdd2d))
+
+#  Y   conv2d
+#   \   /
+#    add
+
+def _fuse_conv_add_right(is_qat, add, _, conv):
+    return nni.ConvAdd2d(conv, add)
+
+def _conv_add_root_node_getter_right(pattern):
+    add, _, conv = pattern
+    return conv
+
+def _conv_add_extra_inputs_getter_right(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, extra_input, conv = pattern
+    return [extra_input]
+
+#      conv2d
+#        /
+#  Y    bn
+#   \   /
+#    add
+
+def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAdd2d(fused_conv, add)
+
+def _conv_bn_add_root_node_getter_right(pattern):
+    add, _, bn_conv = pattern
+    bn, conv = bn_conv
+    return conv
+
+def _conv_bn_add_extra_inputs_getter_right(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, extra_input, bn_conv = pattern
+    bn, conv = bn_conv
+    return [extra_input]
+
+conv_add_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_bn_add_right)
+                ._set_root_node_getter(_conv_bn_add_root_node_getter_right)
+                ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right)
+                .set_fused_module(nni.ConvAdd2d))
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_add_right)
+                ._set_root_node_getter(_conv_add_root_node_getter_right)
+                ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right)
+                .set_fused_module(nni.ConvAdd2d))
+
+conv_configs.append(
+    BackendPatternConfig(nni.ConvAdd2d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(conv_dtype_configs)
+        .set_root_module(nn.Conv2d)
+        .set_reference_quantized_module(nnqr.Conv2d))
+
+# (2) Conv2d + Add + Relu
+
+# conv2d Y
+#   \   /
+#    add
+#     \
+#     relu
+
+def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
+    add, conv, _ = add_pattern
+    return nni.ConvAddReLU2d(conv, add, relu)
+
+def _conv_add_relu_root_node_getter_left(pattern):
+    relu, add_pattern = pattern
+    _, conv, _ = add_pattern
+    return conv
+
+def _conv_add_relu_extra_inputs_getter_left(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    relu, add_pattern = pattern
+    _, conv, extra_input = add_pattern
+    return [extra_input]
+
+# conv2d
+#  \
+#  bn   Y
+#   \   /
+#    add
+#     \
+#     relu
+
+def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
+    add, bn_conv, _ = add_pattern
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAddReLU2d(fused_conv, add, relu)
+
+def _conv_bn_add_relu_root_node_getter_left(pattern):
+    relu, add_pattern = pattern
+    _, bn_conv, _ = add_pattern
+    bn, conv = bn_conv
+    return conv
+
+def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    relu, add_pattern = pattern
+    _, bn_conv, extra_input = add_pattern
+    bn, conv = bn_conv
+    return [extra_input]
+
+conv_add_relu_left_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_relu_left_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_bn_add_relu_left)
+                ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
+                ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
+                .set_fused_module(nni.ConvAddReLU2d))
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode)))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_add_relu_left)
+                ._set_root_node_getter(_conv_add_relu_root_node_getter_left)
+                ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left)
+                .set_fused_module(nni.ConvAddReLU2d))
+
+#  Y   conv2d
+#   \   /
+#    add
+#     \
+#     relu
+
+def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
+    add, _, conv = add_pattern
+    return nni.ConvAddReLU2d(conv, add, relu)
+
+def _conv_add_relu_root_node_getter_right(pattern):
+    relu, add_pattern = pattern
+    _, _, conv = add_pattern
+    return conv
+
+def _conv_add_relu_extra_inputs_getter_right(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    relu, add_pattern = pattern
+    _, extra_input, conv = add_pattern
+    return [extra_input]
+
+#      conv2d
+#        /
+#  Y    bn
+#   \   /
+#    add
+#     \
+#     relu
+
+def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
+    add, _, bn_conv = add_pattern
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAddReLU2d(fused_conv, add, relu)
+
+def _conv_bn_add_relu_root_node_getter_right(pattern):
+    relu, add_pattern = pattern
+    _, _, bn_conv = add_pattern
+    bn, conv = bn_conv
+    return conv
+
+def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
+    """ get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    relu, add_pattern = pattern
+    _, extra_input, bn_conv = add_pattern
+    bn, conv = bn_conv
+    return [extra_input]
+
+conv_add_relu_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_relu_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_bn_add_relu_right)
+                ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right)
+                ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right)
+                .set_fused_module(nni.ConvAddReLU2d))
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+                ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d)))  # noqa: E131
+                .set_observation_type(observation_type)
+                .set_dtype_configs(conv_dtype_configs)
+                .set_fuser_method(_fuse_conv_add_relu_right)
+                ._set_root_node_getter(_conv_add_relu_root_node_getter_right)
+                ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right)
+                .set_fused_module(nni.ConvAddReLU2d))
+
+conv_configs.append(
+    BackendPatternConfig(nni.ConvAddReLU2d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(conv_dtype_configs)
+        .set_root_module(nn.Conv2d)
+        .set_reference_quantized_module(nnqr.Conv2d))
+
+# ========================
+# |  CONFIGS FOR LINEAR  |
+# ========================
+
+linear_dtype_configs = [
+    onednn_weighted_op_int8_dtype_config,
+    onednn_dynamic_int8_dtype_config,
+]
+linear_configs = _get_linear_configs(linear_dtype_configs)
+
+def _add_eltwise_fusion_configs(configs, root_module, root_op, post_module, post_op,
+                                dtype_configs, fuser_method, fused_module, observation_type,
+                                ref_quant_module):
+    # 1 base module + op module fusion config
+    configs.append(
+        BackendPatternConfig((root_module, post_module))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuser_method)
+            .set_fused_module(fused_module))
+    # base module + functional post op
+    configs.append(
+        BackendPatternConfig((root_module, post_op))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuser_method)
+            .set_fused_module(fused_module))
+
+    # 2 fused module configs
+    configs.append(
+        BackendPatternConfig(fused_module)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(root_module)
+            .set_reference_quantized_module(ref_quant_module))
+
+    # 3 functional base op + post op configs
+    configs.append(
+        BackendPatternConfig((root_op, post_module))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs))
+    configs.append(
+        BackendPatternConfig((root_op, post_op))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs))
+
+# Configs for linear + leaky_relu fusion
+_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear,
+                            nn.LeakyReLU, F.leaky_relu, linear_dtype_configs,
+                            _sequential_wrapper2(nni.LinearLeakyReLU),
+                            nni.LinearLeakyReLU, observation_type, nnqr.Linear)
+
+# Configs for linear module + batchnorm + leaky_relu
+linear_configs.append(
+    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
+        .set_dtype_configs(linear_dtype_configs)  # noqa: E131
+        .set_fuser_method(_fuse_linear_bn_leaky_relu)
+        .set_fused_module(nni.LinearLeakyReLU))
+
+# Configs for linear + tanh fusion
+_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear,
+                            nn.Tanh, torch.tanh, linear_dtype_configs,
+                            _sequential_wrapper2(nni.LinearTanh),
+                            nni.LinearTanh, observation_type, nnqr.Linear)
+
+# ===========================
+# |  CONFIGS FOR OTHER OPS  |
+# ===========================
+
+binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
+default_op_dtype_configs = [onednn_op_quint8_dtype_config]
+fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
+share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
+rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]
+embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
+layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_onednn_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native ONEDNN backend.
+    """
+    return BackendConfig("onednn") \
+        .set_backend_pattern_configs(conv_configs) \
+        .set_backend_pattern_configs(linear_configs) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
+
+__all__ = [
+    "get_onednn_backend_config",
+]
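+
+
+# A minimal sketch of a model that exercises the Conv2d + add + ReLU patterns
+# registered above (an illustration, not part of the upstream source), assuming
+# the standard FX graph mode quantization flow; the class and input names are
+# hypothetical placeholders:
+#
+#     class ConvAddReLU(torch.nn.Module):
+#         def __init__(self):
+#             super().__init__()
+#             self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
+#             self.relu = torch.nn.ReLU()
+#
+#         def forward(self, x, y):
+#             return self.relu(self.conv(x) + y)
+#
+#     from torch.ao.quantization import get_default_qconfig_mapping
+#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
+#
+#     m = ConvAddReLU().eval()
+#     example_inputs = (torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8))
+#     backend_config = get_onednn_backend_config()
+#     prepared = prepare_fx(m, get_default_qconfig_mapping("onednn"), example_inputs,
+#                           backend_config=backend_config)
+#     quantized = convert_fx(prepared, backend_config=backend_config)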
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b6a5fa8cd5c29b4251e6115e27ce6039c7c31c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/qnnpack.py
@@ -0,0 +1,160 @@
+import torch
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints
+
+__all__ = [
+    "get_qnnpack_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+qnnpack_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+qnnpack_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+# xnnpack compatible dtype configs
+
+# We restrict scale values to be at least 2 ** -12 to ensure the
+# requantization scale never falls below the xnnpack lower
+# threshold. Additionally, for qint8 weight, we restrict
+# the quantization values to [-127, +127], excluding -128.
+# For more detail, refer to the description of
+# `default_symmetric_qnnpack_qconfig`.
+
+# TODO: add additional restriction on qscheme to ensure it
+# is either per_tensor_symmetric or per_channel_symmetric
+
+qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    scale_min_lower_bound=2 ** -12,
+)
+
+qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    quant_min_lower_bound=-127,
+    quant_max_upper_bound=127,
+    scale_min_lower_bound=2 ** -12,
+)
+
+qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
+    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
+    bias_dtype=torch.float,
+)
+
+qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
+    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_qnnpack_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
+    """
+    conv_dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        qnnpack_weighted_op_quint8_dtype_config,
+    ]
+    linear_dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        qnnpack_weighted_op_quint8_dtype_config,
+        qnnpack_default_dynamic_int8_dtype_config,
+        qnnpack_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    default_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    fixed_qparams_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    rnn_op_dtype_configs = [
+        qnnpack_default_dynamic_int8_dtype_config,
+        qnnpack_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        qnnpack_weight_only_quint8_dtype_config,
+        qnnpack_weight_only_quint4x2_dtype_config,
+    ]
+    return BackendConfig("qnnpack") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
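+
+
+# A minimal pairing sketch (an illustration, not part of the upstream source),
+# assuming the symmetric qconfig referenced above is combined with this backend
+# config in the FX flow; `float_model` and `example_inputs` are hypothetical
+# placeholders:
+#
+#     from torch.ao.quantization import QConfigMapping
+#     from torch.ao.quantization.qconfig import default_symmetric_qnnpack_qconfig
+#     from torch.ao.quantization.quantize_fx import prepare_fx
+#
+#     qconfig_mapping = QConfigMapping().set_global(default_symmetric_qnnpack_qconfig)
+#     prepared = prepare_fx(float_model, qconfig_mapping, example_inputs,
+#                           backend_config=get_qnnpack_backend_config())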
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2c8c169a587ab0889d0eb36ad7e055c83cb72f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/tensorrt.py
@@ -0,0 +1,81 @@
+import torch
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    ObservationType
+)
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_linear_configs,
+    _get_conv_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+
+__all__ = [
+    "get_tensorrt_backend_config",
+    "get_tensorrt_backend_config_dict",
+]
+
+def get_tensorrt_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for the TensorRT backend.
+    NOTE: The current API will change in the future; it exists only to unblock
+    experimentation with new backends, so please don't rely on it right now.
+    TODO: add a README when it's more stable
+    """
+    # dtype configs
+    weighted_op_qint8_dtype_config = DTypeConfig(
+        input_dtype=torch.qint8,
+        output_dtype=torch.qint8,
+        weight_dtype=torch.qint8,
+        bias_dtype=torch.float,
+    )
+    non_weighted_op_qint8_dtype_config = DTypeConfig(
+        input_dtype=torch.qint8,
+        output_dtype=torch.qint8,
+    )
+
+    addmm_config = BackendPatternConfig(torch.addmm) \
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+        .add_dtype_config(weighted_op_qint8_dtype_config) \
+        ._set_input_type_to_index({
+            "bias": 0,
+            "input": 1,
+            "weight": 2,
+        })
+    cat_config = BackendPatternConfig(torch.cat) \
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
+        .add_dtype_config(non_weighted_op_qint8_dtype_config)
+    conv_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    linear_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        non_weighted_op_qint8_dtype_config,
+    ]
+    tensor_info_op_dtype_configs = [
+        non_weighted_op_qint8_dtype_config,
+    ]
+    # some ops might not be supported in fx2trt yet; those will error out
+    # during fx2trt conversion and can be supported after that
+    return BackendConfig("tensorrt") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_config(addmm_config) \
+        .set_backend_pattern_config(cat_config) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
+
+def get_tensorrt_backend_config_dict():
+    """
+    Return the `BackendConfig` for the TensorRT backend in dictionary form.
+    """
+    return get_tensorrt_backend_config().to_dict()
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bee900e062f2fe9eee9e7127c992360d2671e0d6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/utils.py
@@ -0,0 +1,279 @@
+from typing import Dict, Any, List, Callable, Union, Tuple, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+)
+from ..utils import Pattern
+from ..fuser_method_mappings import (
+    _reverse2,
+    _reverse3,
+)
+
+__all__ = [
+    "get_pattern_to_dtype_configs",
+    "get_qat_module_classes",
+    "get_fused_module_classes",
+    "get_pattern_to_input_type_to_index",
+    "get_root_module_to_quantized_reference_module",
+    "get_fuser_method_mapping",
+    "get_module_to_qat_module",
+    "get_fusion_pattern_to_root_node_getter",
+    "get_fusion_pattern_to_extra_inputs_getter",
+    "remove_boolean_dispatch_from_name",
+    "pattern_to_human_readable",
+    "entry_to_pretty_str",
+]
+
+def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]:
+    pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        pattern_to_dtype_configs[pattern] = config.dtype_configs
+    return pattern_to_dtype_configs
+
+def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
+    qat_module_classes = []
+    for config in backend_config.configs:
+        if config.qat_module is not None:
+            qat_module_classes.append(config.qat_module)
+    return tuple(set(qat_module_classes))
+
+def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
+    fused_module_classes = []
+    for config in backend_config.configs:
+        if config.fused_module is not None:
+            fused_module_classes.append(config.fused_module)
+    return tuple(set(fused_module_classes))
+
+def get_pattern_to_input_type_to_index(backend_config: BackendConfig) -> Dict[Pattern, Dict[str, int]]:
+    pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        pattern_to_input_type_to_index[pattern] = config._input_type_to_index
+    return pattern_to_input_type_to_index
+
+def get_root_module_to_quantized_reference_module(
+        backend_config: BackendConfig) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
+    mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {}
+    for config in backend_config.configs:
+        if config.root_module is not None and config.reference_quantized_module is not None:
+            mapping[config.root_module] = config.reference_quantized_module
+    return mapping
+
+def get_fuser_method_mapping(backend_config: BackendConfig) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
+    fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.fuser_method is not None:
+            # Note: both the fuser method and the pattern are specified in forward order in the
+            # BackendConfig, but the internal pattern matching code uses the reversed nested tuple
+            # format, so we need to convert both to the internal format
+            fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
+            fuser_method_mapping[pattern] = fuser_method
+    return fuser_method_mapping
+
+def get_module_to_qat_module(backend_config: BackendConfig) -> Dict[Pattern, Type[torch.nn.Module]]:
+    module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.qat_module is not None:
+            module_to_qat_module[pattern] = config.qat_module
+    return module_to_qat_module
+
+def get_fusion_pattern_to_root_node_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
+    """ Get a map from fusion pattern to a function that returns the root node
+    from the fusion pattern, e.g. the most common one is:
+    def get_root_node(node_pattern):
+        while not isinstance(node_pattern[-1], Node):
+            node_pattern = node_pattern[-1]
+        return node_pattern[-1]
+    This can work for all patterns whose root node is the "last node" in the pattern,
+    e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d))
+    """
+    root_node_getter_mapping: Dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config._root_node_getter is not None:
+            root_node_getter_mapping[pattern] = config._root_node_getter
+    return root_node_getter_mapping
+
+def get_fusion_pattern_to_extra_inputs_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
+    """ Get a map from fusion pattern to a function that returns extra input nodes
+    from the fusion pattern, in the order required by the root node. This is optional,
+    if not specified, we will not copy over any extra inputs for the root node.
+    Example:
+    # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
+    # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra
+    # argument to the fused module, we can unpack the pattern and return the node at
+    # MatchAllNode here
+    # we can implement extra_inputs_getter as follows:
+    def extra_inputs_getter(pattern) -> List[Any]:
+        add, extra_input, conv_pattern = pattern
+        return [extra_input]
+    """
+    extra_inputs_getter_mapping: Dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config._extra_inputs_getter is not None:
+            extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
+    return extra_inputs_getter_mapping
+
+def remove_boolean_dispatch_from_name(p) -> Any:
+    """
+    Some ops have a default string representation such as
+    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>',
+    this function replaces them with the hardcoded function names.
+    """
+    if p is F.fractional_max_pool2d:
+        return "torch.nn.functional.fractional_max_pool2d"
+    elif p is F.fractional_max_pool3d:
+        return "torch.nn.functional.fractional_max_pool3d"
+    elif p is F.max_pool1d:
+        return "torch.nn.functional.max_pool1d"
+    elif p is F.max_pool2d:
+        return "torch.nn.functional.max_pool2d"
+    elif p is F.max_pool3d:
+        return "torch.nn.functional.max_pool3d"
+    elif p is F.adaptive_max_pool1d:
+        return "torch.nn.functional.adaptive_max_pool1d"
+    elif p is F.adaptive_max_pool2d:
+        return "torch.nn.functional.adaptive_max_pool2d"
+    elif p is F.adaptive_max_pool3d:
+        return "torch.nn.functional.adaptive_max_pool3d"
+    assert "boolean_dispatch" not in str(p), \
+        f"{p} does not have a human readable representation in " + \
+        "quantization documentation"
+    return p
+
+def pattern_to_human_readable(p) -> Any:
+    if isinstance(p, tuple):
+        # nested patterns, recurse
+        return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
+    elif isinstance(p, str):
+        # method names are already human readable
+        return p
+    else:
+        p = remove_boolean_dispatch_from_name(p)
+        return p
+
+# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
+# the corresponding __str__ function
+def entry_to_pretty_str(entry) -> str:
+    """
+    Given a backend_config_dict entry, returns a string with the human readable
+    representation of it.
+    """
+    s = "{\n"
+
+    # always output the pattern first
+    if "pattern" in entry:
+        pattern_str = pattern_to_human_readable(entry["pattern"])
+
+        s += f"  'pattern': {pattern_str},\n"
+
+    # custom output for dtype_configs to make it look nice
+    if "dtype_configs" in entry:
+        s += "  'dtype_configs': [\n"
+        for dtype_config in entry["dtype_configs"]:
+            s += "    {\n"
+            for k, v in dtype_config.items():
+                s += f"      '{k}': {v},\n"
+            s += "    },\n"
+        s += "  ],\n"
+
+    # custom output for num_tensor_args_to_observation_type to make it look nice
+    if "num_tensor_args_to_observation_type" in entry:
+        s += "  'num_tensor_args_to_observation_type': {\n"
+        for k, v in entry["num_tensor_args_to_observation_type"].items():
+            s += f"    {k}: {v},\n"
+        s += "  },\n"
+
+    # output all the other fields
+    custom_handled_fields = [
+        "pattern",
+        "dtype_configs",
+        "num_tensor_args_to_observation_type",
+    ]
+    for field_name in entry:
+        if field_name in custom_handled_fields:
+            continue
+        s += f"  '{field_name}': {entry[field_name]},\n"
+
+    s += "}"
+    return s
+
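+# Illustrative usage sketch for `entry_to_pretty_str`: the entry below is a
+# hand-written example, not a real backend configuration.
+#
+#     example_entry = {
+#         "pattern": "relu",
+#         "dtype_configs": [
+#             {"input_dtype": torch.quint8, "output_dtype": torch.quint8},
+#         ],
+#     }
+#     print(entry_to_pretty_str(example_entry))
+#     # prints a brace-delimited summary with the pattern first,
+#     # followed by the dtype configs, one key per line
+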
+def _get_pattern_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Pattern:
+    """
+    Return the pattern specified in the given config in the reversed nested tuple format
+    used internally in the quantization pattern matching code.
+
+    If the pattern is not a tuple, or the pattern is already specified in the reversed
+    nested tuple format, return the pattern as is. Otherwise:
+
+    For 2-tuples (a, b), return (b, a).
+    For 3-tuples (a, b, c), return (c, (b, a)).
+
+    For example:
+        * Given nn.Linear, return nn.Linear
+        * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
+        * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
+          (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
+
+    For context, the reason why this is needed is that the user-facing BackendConfig
+    API accepts the flat 2-or-3-tuple format in forward order. While this simple
+    format handles the vast majority of use cases, it does not handle the more
+    complex ones, and so the internal pattern matching code for quantization uses
+    the following, more general reversed nested tuple format instead:
+
+        operator = module_type | functional | torch op | native op | MatchAllNode
+        Pattern = (operator, Pattern, Pattern, ...) | operator
+
+    In the future, we expect to replace the above complex format with the one used
+    by the subgraph rewriter in torch.fx, so we don't have to maintain our own
+    complex pattern matching code. Then we won't need this helper function anymore.
+    """
+    if config._pattern_complex_format is not None:
+        return config._pattern_complex_format
+    if config.pattern is None:
+        raise ValueError("Either 'pattern' or 'pattern_complex_format' must be specified")
+    if not isinstance(config.pattern, tuple):
+        return config.pattern
+
+    # Pattern is specified in the simple tuple format, need to convert
+    if len(config.pattern) == 2:
+        (a, b) = config.pattern
+        return (b, a)
+    elif len(config.pattern) == 3:
+        (a, b, c) = config.pattern
+        return (c, (b, a))
+    else:
+        raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
+
+def _get_fuser_method_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Callable:
+    """
+    Return the fuser method specified in the given config in the reversed nested
+    tuple format used internally in the quantization pattern matching code.
+
+    If pattern is specified in the reversed nested tuple format, we assume the
+    fuser method is also specified in this format and simply return it as is.
+    Otherwise, we convert the fuser method as follows:
+
+        * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
+        * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
+          where bn_conv is a 2-tuple (bn, conv)
+
+    The first argument of a fuser method is always `is_qat` and is not affected
+    in the conversion. We currently only support functions with 3 or 4 arguments.
+    """
+    assert config.fuser_method is not None
+    if config._pattern_complex_format is not None:
+        return config.fuser_method
+    if not isinstance(config.pattern, tuple):
+        raise ValueError("Expected pattern to be a tuple, got: ", config.pattern)
+
+    # Pattern is specified in the simple tuple format, need to convert
+    if len(config.pattern) == 2:
+        return _reverse2(config.fuser_method)
+    elif len(config.pattern) == 3:
+        return _reverse3(config.fuser_method)
+    else:
+        raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/x86.py b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/x86.py
new file mode 100644
index 0000000000000000000000000000000000000000..2daaded0499048a80100c184af4d5bd5ee8ea01d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/backend_config/x86.py
@@ -0,0 +1,113 @@
+import torch
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+__all__ = [
+    "get_x86_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# X86 aligns with FBGEMM for now
+
+x86_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+x86_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+x86_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+x86_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+x86_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+def get_x86_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native x86 backend.
+    """
+    conv_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    linear_dtype_configs = [
+        x86_weighted_op_int8_dtype_config,
+        x86_default_dynamic_int8_dtype_config,
+        x86_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    default_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        x86_default_dynamic_int8_dtype_config,
+        x86_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        x86_weight_only_quint8_dtype_config,
+        x86_weight_only_quint4x2_dtype_config,
+    ]
+    return BackendConfig("x86") \
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
+        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
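+
+# Illustrative usage sketch (assumes an FX-traceable `model` and representative
+# `example_inputs`): the config above is passed to FX graph mode quantization
+# via the `backend_config` argument.
+#
+#     from torch.ao.quantization import get_default_qconfig_mapping
+#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
+#
+#     qconfig_mapping = get_default_qconfig_mapping("x86")
+#     backend_config = get_x86_backend_config()
+#     prepared = prepare_fx(model, qconfig_mapping, example_inputs,
+#                           backend_config=backend_config)
+#     # ... run calibration data through `prepared` ...
+#     quantized = convert_fx(prepared, backend_config=backend_config)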
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fake_quantize.py b/MLPY/Lib/site-packages/torch/ao/quantization/fake_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..4307d7fda0470b2b400cebdcdd4176675f7b9249
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fake_quantize.py
@@ -0,0 +1,546 @@
+"""Implements modules  used to perform fake quantization."""
+
+import torch
+from torch.nn import Module
+from torch.ao.quantization.observer import (
+    MovingAverageMinMaxObserver,
+    HistogramObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    FixedQParamsObserver,
+    default_fixed_qparams_range_0to1_observer,
+    default_fixed_qparams_range_neg1to1_observer,
+    _with_args,
+)
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
+
+__all__ = [
+    "FakeQuantizeBase",
+    "FakeQuantize",
+    "FixedQParamsFakeQuantize",
+    "FusedMovingAvgObsFakeQuantize",
+    "disable_fake_quant",
+    "disable_observer",
+    "enable_fake_quant",
+    "enable_observer",
+    "default_fake_quant",
+    "default_weight_fake_quant",
+    "default_dynamic_fake_quant",
+    "default_fixed_qparams_range_neg1to1_fake_quant",
+    "default_fixed_qparams_range_0to1_fake_quant",
+    "default_symmetric_fixed_qparams_fake_quant",
+    "default_affine_fixed_qparams_fake_quant",
+    "default_per_channel_weight_fake_quant",
+    "default_embedding_fake_quant",
+    "default_embedding_fake_quant_4bit",
+    "default_histogram_fake_quant",
+    "default_fused_act_fake_quant",
+    "default_fused_wt_fake_quant",
+    "default_fused_per_channel_wt_fake_quant",
+    "fused_wt_fake_quant_range_neg_127_to_127",
+    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
+]
+
+def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
+
+def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
+
+def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
+
+def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
+    return qscheme in [torch.per_channel_affine_float_qparams, ]
+
+class FakeQuantizeBase(ABC, Module):
+    r"""Base fake quantize module.
+
+    Any fake quantize implementation should derive from this class.
+
+    Concrete fake quantize modules should follow the same API. In forward, they will update
+    the statistics of the observed Tensor and fake quantize the input. They should also provide a
+    `calculate_qparams` function that computes the quantization parameters given
+    the collected statistics.
+
+    """
+
+    fake_quant_enabled: torch.Tensor
+    observer_enabled: torch.Tensor
+
+    def __init__(self):
+        """Set fake_quant_enabled and observer_enabled."""
+        super().__init__()
+        # fake_quant_enabled and observer_enabled are buffers to support their
+        # replication in DDP. Data type is uint8 because NCCL does not support
+        # bool tensors.
+        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
+        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    @abstractmethod
+    def calculate_qparams(self, **kwargs):
+        pass
+
+    @torch.jit.export
+    def enable_fake_quant(self, enabled: bool = True) -> None:
+        self.fake_quant_enabled[0] = 1 if enabled else 0
+
+    @torch.jit.export
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
+
+    @torch.jit.export
+    def enable_observer(self, enabled: bool = True) -> None:
+        self.observer_enabled[0] = 1 if enabled else 0
+
+    @torch.jit.export
+    def disable_observer(self):
+        self.enable_observer(False)
+
+    @classmethod
+    def with_args(cls, **kwargs):
+        fake_quant_constructor = _with_args(cls, **kwargs)
+        # need to assign the correct module to fake_quantize
+        # constructors to satisfy public v private requirements
+        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
+        return fake_quant_constructor
+
+class FakeQuantize(FakeQuantizeBase):
+    r"""Simulate the quantize and dequantize operations in training time.
+
+    The output of this module is given by::
+
+        x_out = (
+          clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
+        ) * scale
+
+    * :attr:`is_dynamic` indicates whether the fake quantize is a placeholder for dynamic quantization
+      operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)
+
+    * :attr:`scale` defines the scale factor used for quantization.
+
+    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps
+
+    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
+      statistics can still be updated.
+
+    * :attr:`observer_enabled` controls statistics collection on tensors
+
+    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
+        allowable values are torch.qint8 and torch.quint8.
+
+    Args:
+
+        observer (module): Module for observing statistics on input tensors and calculating scale
+          and zero-point.
+        observer_kwargs (optional): Arguments for the observer module
+
+    Attributes:
+        activation_post_process (Module): User provided module that collects statistics on the input tensor and
+          provides a method to calculate scale and zero-point.
+
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, is_dynamic=False, **observer_kwargs):
+        super().__init__()
+        # Populate quant_min/quant_max to observer_kwargs if valid
+        if quant_min is not None and quant_max is not None:
+            assert quant_min <= quant_max, \
+                'quant_min must be less than or equal to quant_max'
+            dtype = observer_kwargs.get("dtype", torch.quint8)
+            if hasattr(observer, "p"):
+                # In case observer is _PartialWrapper, dtype can be stored in
+                # observer.p.keywords["dtype"]
+                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
+                    "dtype", dtype
+                )
+            assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
+            assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
+            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
+        observer_kwargs["is_dynamic"] = is_dynamic
+        self.activation_post_process = observer(**observer_kwargs)
+        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
+        # Users should use self.activation_post_process.quant_min
+        self.quant_min = self.activation_post_process.quant_min
+        self.quant_max = self.activation_post_process.quant_max
+        self.is_dynamic = self.activation_post_process.is_dynamic
+        if _is_float_qparams(self.activation_post_process.qscheme):
+            zero_point_dtype = torch.float
+        else:
+            zero_point_dtype = torch.int
+        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
+        self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
+        self.dtype = self.activation_post_process.dtype
+        self.qscheme = self.activation_post_process.qscheme
+        self.ch_axis = self.activation_post_process.ch_axis \
+            if hasattr(self.activation_post_process, 'ch_axis') else -1
+        assert _is_per_channel(self.qscheme) or \
+            _is_per_tensor(self.qscheme), \
+            'Only per channel and per tensor quantization are supported in fake quantize,' + \
+            ' got qscheme: ' + str(self.qscheme)
+        self.is_per_channel = _is_per_channel(self.qscheme)
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.activation_post_process.calculate_qparams()
+
+    def forward(self, X):
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
+            if self.scale.shape != _scale.shape:
+                self.scale.resize_(_scale.shape)
+                self.zero_point.resize_(_zero_point.shape)
+            self.scale.copy_(_scale)
+            self.zero_point.copy_(_zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            if self.is_per_channel:
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point,
+                    self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
+            else:
+                X = torch.fake_quantize_per_tensor_affine(
+                    X, self.scale, self.zero_point,
+                    self.activation_post_process.quant_min, self.activation_post_process.quant_max)
+        return X
+
+    @torch.jit.export
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={}, ' \
+               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
+               'scale={}, zero_point={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.activation_post_process.quant_min, self.activation_post_process.quant_max,
+                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # We cannot currently register scalar values as buffers, so need to manually
+        # specify serialization here.
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + 'scale'] = self.scale
+        destination[prefix + 'zero_point'] = self.zero_point
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Removing this function throws an error that the size of the loaded tensor does not match the original size,
+        # i.e. these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
+        local_state = ['scale', 'zero_point']
+        for name in local_state:
+            key = prefix + name
+            if key in state_dict:
+                val = state_dict[key]
+                # Custom handling to allow loading scale and zero_point
+                # of size N into uninitialized buffers of size 0. The
+                # buffers are resized here, and the values are copied in
+                # the default state_dict loading code of the parent.
+                if name == 'scale':
+                    self.scale.resize_(val.shape)
+                else:
+                    assert name == 'zero_point'
+                    self.zero_point.resize_(val.shape)
+                # For torchscript modules we need to update the attributes here since we do not
+                # call the `_load_from_state_dict` function defined in module.py
+                if torch.jit.is_scripting():
+                    if name == 'scale':
+                        self.scale.copy_(val)
+                    else:
+                        assert name == 'zero_point'
+                        self.zero_point.copy_(val)
+            elif strict:
+                missing_keys.append(key)
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                      missing_keys, unexpected_keys, error_msgs)
+
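+# Illustrative usage sketch (arbitrary shapes/values): observe a tensor and
+# fake-quantize it with the formula documented in the class docstring.
+#
+#     fq = FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
+#                       dtype=torch.quint8, qscheme=torch.per_tensor_affine)
+#     x = torch.randn(4, 8)
+#     y = fq(x)                # updates min/max statistics, then fake-quantizes x
+#     scale, zero_point = fq.calculate_qparams()
+#     fq.disable_observer()    # freeze qparams but keep fake-quantizing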
+
+class FixedQParamsFakeQuantize(FakeQuantize):
+    """Simulate quantize and dequantize in training time.
+
+    Simulate quantize and dequantize with fixed quantization
+    parameters in training time. Only per tensor quantization
+    is supported.
+    """
+
+    # TODO: rename observer to observer_ctr
+    def __init__(self, observer):
+        super().__init__(observer=observer)
+        assert type(self.activation_post_process) == FixedQParamsObserver, \
+            f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
+        self._observer_ctr = observer
+        self.scale = self.activation_post_process.scale
+        self.zero_point = self.activation_post_process.zero_point
+        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
+            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.scale, self.zero_point
+
+    @torch.jit.export
+    def extra_repr(self):
+        """Define a string representation of the object's attributes."""
+        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
+               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.scale, self.zero_point, self.dtype,
+                   self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme)
+
+
+class FusedMovingAvgObsFakeQuantize(FakeQuantize):
+    r"""Define a fused module to observe the tensor.
+
+    Fused module that is used to observe the input tensor (compute min/max), compute
+    scale/zero_point and fake_quantize the tensor.
+    This module uses a calculation similar to MovingAverageMinMaxObserver for the inputs
+    to compute the min/max values, which are then used to compute the scale/zero_point.
+    The qscheme input in the observer is used to differentiate between symmetric/affine
+    quantization scheme.
+
+    The output of this module is given by
+    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
+
+    This module is similar to :class:`~torch.ao.quantization.FakeQuantize` and accepts the same
+    attributes as the base class.
+
+    """
+
+    def __init__(
+        self,
+        observer: Any = MovingAverageMinMaxObserver,
+        quant_min: int = 0,
+        quant_max: int = 255,
+        **observer_kwargs: Any
+    ) -> None:
+        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
+        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
+            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
+        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
+        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
+        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)
+
+    @torch.jit.export
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.activation_post_process.calculate_qparams()
+
+    @torch.jit.export
+    def extra_repr(self) -> str:
+        return (
+            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
+            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
+                self.fake_quant_enabled,
+                self.observer_enabled,
+                self.scale,
+                self.zero_point,
+                self.dtype,
+                self.activation_post_process.quant_min,
+                self.activation_post_process.quant_max,
+                self.qscheme,
+                self.activation_post_process.reduce_range,
+            )
+        )
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        return torch.fused_moving_avg_obs_fake_quant(
+            X,
+            self.observer_enabled,
+            self.fake_quant_enabled,
+            self.activation_post_process.min_val,
+            self.activation_post_process.max_val,
+            self.scale,
+            self.zero_point,
+            self.activation_post_process.averaging_constant,
+            self.activation_post_process.quant_min,
+            self.activation_post_process.quant_max,
+            self.ch_axis,
+            self.is_per_channel,
+            self.is_symmetric_quant,
+        )
+
+default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
+                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
+"""
+Default fake_quant for activations.
+"""
+
+default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
+                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
+"""
+Default fake_quant for weights.
+Observer is memoryless since averaging_constant is 1.
+"""
+
+default_dynamic_fake_quant = FakeQuantize.with_args(
+    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, is_dynamic=True,
+    dtype=torch.quint8, averaging_constant=1)
+"""
+Default dynamic fake_quant for activations.
+"""
+
+default_fixed_qparams_range_neg1to1_fake_quant = (
+    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
+)
+default_fixed_qparams_range_0to1_fake_quant = (
+    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
+)
+# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
+default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
+default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
+
+default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                               quant_min=-128,
+                                                               quant_max=127,
+                                                               dtype=torch.qint8,
+                                                               qscheme=torch.per_channel_symmetric,
+                                                               reduce_range=False,
+                                                               ch_axis=0)
+"""
+Default fake_quant for per-channel weights.
+Observer is memoryless since averaging_constant is 1.
+"""
+default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                      qscheme=torch.per_channel_affine_float_qparams,
+                                                      dtype=torch.quint8,
+                                                      quant_min=0,
+                                                      quant_max=255,
+                                                      ch_axis=0,
+                                                      averaging_constant=1)
+"""
+Default fake_quant for embeddings.
+Observer is memoryless since averaging_constant is 1.
+"""
+
+default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                           qscheme=torch.per_channel_affine_float_qparams,
+                                                           ch_axis=0,
+                                                           dtype=torch.quint4x2,
+                                                           averaging_constant=1)
+
+default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
+                                                      quant_min=0,
+                                                      quant_max=255,
+                                                      dtype=torch.quint8,
+                                                      qscheme=torch.per_tensor_affine,
+                                                      reduce_range=True)
+"""
+Fake_quant for activations using a histogram.
+"""
+
+
+default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                       quant_min=0,
+                                                                       quant_max=255,
+                                                                       dtype=torch.quint8,)
+
+"""
+Fused version of `default_fake_quant`, with improved performance.
+"""
+
+
+default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                      quant_min=-128,
+                                                                      quant_max=127,
+                                                                      dtype=torch.qint8,
+                                                                      qscheme=torch.per_tensor_symmetric)
+"""
+Fused version of `default_weight_fake_quant`, with improved performance.
+"""
+
+default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                                                                  quant_min=-128,
+                                                                                  quant_max=127,
+                                                                                  dtype=torch.qint8,
+                                                                                  qscheme=torch.per_channel_symmetric)
+"""
+Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
+"""
+
+fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                   quant_min=-127,
+                                                                                   quant_max=127,
+                                                                                   dtype=torch.qint8,
+                                                                                   qscheme=torch.per_tensor_symmetric,
+                                                                                   eps=2 ** -12)
+"""
+Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
+"""
+
+fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
+    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
+                                            quant_min=-127,
+                                            quant_max=127,
+                                            dtype=torch.qint8,
+                                            qscheme=torch.per_channel_symmetric,
+                                            eps=2 ** -12)
+
+"""
+Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
+"""
+
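+# Illustrative usage sketch: the `default_*` objects above are partially-applied
+# constructors produced by `with_args`, so calling them builds a fresh module.
+#
+#     act_fq = default_fused_act_fake_quant()   # FusedMovingAvgObsFakeQuantize instance
+#     wt_fq = default_fused_wt_fake_quant()     # FusedMovingAvgObsFakeQuantize instance
+#     y = act_fq(torch.randn(2, 3))             # observe + fake-quantize in a single fused op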
+
+def _is_fake_quant_script_module(mod):
+    """Return true if given mod is an instance of FakeQuantize script module."""
+    if isinstance(mod, torch.jit.RecursiveScriptModule):
+        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
+        suffix = mod._c.qualified_name.split('.', 1)[1]
+        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
+        return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
+            name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
+    return False
+
+def disable_fake_quant(mod):
+    """Disable fake quantization for the module.
+
+    Disable fake quantization for this module, if applicable. Example usage::
+
+      # model is any PyTorch model
+      model.apply(torch.ao.quantization.disable_fake_quant)
+
+    """
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.disable_fake_quant()
+
+def enable_fake_quant(mod):
+    """Enable fake quantization for the module.
+
+    Enable fake quantization for this module, if applicable. Example usage::
+
+      # model is any PyTorch model
+      model.apply(torch.ao.quantization.enable_fake_quant)
+
+    """
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.enable_fake_quant()
+
+def disable_observer(mod):
+    """Disable observation for this module.
+
+    Disable observation for this module, if applicable. Example usage::
+
+      # model is any PyTorch model
+      model.apply(torch.ao.quantization.disable_observer)
+
+    """
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.disable_observer()
+
+def enable_observer(mod):
+    """Enable observation for this module.
+
+    Enable observation for this module, if applicable. Example usage::
+
+      # model is any PyTorch model
+      model.apply(torch.ao.quantization.enable_observer)
+
+    """
+    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
+        mod.enable_observer()
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fuse_modules.py b/MLPY/Lib/site-packages/torch/ao/quantization/fuse_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..af23dd75b1a55caa0f4ed7fe51cafa739bf30d93
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fuse_modules.py
@@ -0,0 +1,175 @@
+import copy
+
+import torch.nn as nn
+
+from torch.ao.quantization.fuser_method_mappings import get_fuser_method
+# for backward compatibility
+from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
+from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from typing import List, Optional
+
+__all__ = [
+    "fuse_known_modules",
+    "fuse_modules",
+    "fuse_modules_qat",
+]
+
+# Generalization of getattr
+def _get_module(model, submodule_key):
+    tokens = submodule_key.split('.')
+    cur_mod = model
+    for s in tokens:
+        cur_mod = getattr(cur_mod, s)
+    return cur_mod
+
+# Generalization of setattr
+def _set_module(model, submodule_key, module):
+    tokens = submodule_key.split('.')
+    sub_tokens = tokens[:-1]
+    cur_mod = model
+    for s in sub_tokens:
+        cur_mod = getattr(cur_mod, s)
+
+    setattr(cur_mod, tokens[-1], module)
+
+def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
+    r"""Return a list of known fuse modules.
+
+    Returns a list of modules that fuses the operations specified
+     in the input module list.
+
+    Fuses only the following sequence of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, bn
+    linear, relu
+    For these sequences, the first element in the output module list performs
+    the fused operation. The rest of the elements are set to nn.Identity()
+    """
+    types = tuple(type_before_parametrizations(m) for m in mod_list)
+    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
+    if fuser_method is None:
+        raise NotImplementedError(f"Cannot fuse modules: {types}")
+    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
+    fused = fuser_method(is_qat, *mod_list)
+    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
+    # Move pre forward hooks of the base module to resulting fused module
+    for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
+        fused.register_forward_pre_hook(pre_hook_fn)
+    mod_list[0]._forward_pre_hooks.clear()
+    # Move post forward hooks of the last module to resulting fused module
+    for hook_fn in mod_list[-1]._forward_hooks.values():
+        fused.register_forward_hook(hook_fn)
+    mod_list[-1]._forward_hooks.clear()
+    new_mod[0] = fused
+
+    for i in range(1, len(mod_list)):
+        identity = nn.Identity()
+        identity.training = mod_list[0].training
+        new_mod[i] = identity
+
+    return new_mod
+
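+# Illustrative usage sketch: fusing an eval-mode Conv2d + BatchNorm2d pair
+# directly with `fuse_known_modules`.
+#
+#     conv = nn.Conv2d(3, 3, 3).eval()
+#     bn = nn.BatchNorm2d(3).eval()
+#     fused_conv, identity = fuse_known_modules([conv, bn], is_qat=False)
+#     # fused_conv has the BN statistics folded into its weights;
+#     # identity is nn.Identity() and replaces the BN in the parent module.
+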
+def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    if fuse_custom_config_dict is None:
+        fuse_custom_config_dict = {}
+    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
+    mod_list = []
+    for item in modules_to_fuse:
+        mod_list.append(_get_module(model, item))
+
+    # Fuse list of modules
+    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
+
+    # Replace original module list with fused module list
+    for i, item in enumerate(modules_to_fuse):
+        _set_module(model, item, new_mod_list[i])
+
+def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+        # Handle case of modules_to_fuse being a list
+        _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict)
+    else:
+        # Handle case of modules_to_fuse being a list of lists
+        for module_list in modules_to_fuse:
+            _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict)
+    return model
+
+def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    r"""Fuse a list of modules into a single module.
+
+    Fuses only the following sequence of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, relu
+    bn, relu
+    All other sequences are left unchanged.
+    For these sequences, replaces the first item in the list
+    with the fused module, replacing the rest of the modules
+    with identity.
+
+    Args:
+        model: Model containing the modules to be fused
+        modules_to_fuse: list of list of module names to fuse. Can also be a list
+                         of strings if there is only a single list of modules to fuse.
+        inplace: bool specifying if fusion happens in place on the model, by default
+                 a new model is returned
+        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
+                    of the same length. For example,
+                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
+                    Defaults to torch.ao.quantization.fuse_known_modules
+        `fuse_custom_config_dict`: custom configuration for fusion
+
+    .. code-block:: python
+
+       # Example of fuse_custom_config_dict
+       fuse_custom_config_dict = {
+           # Additional fuser_method mapping
+           "additional_fuser_method_mapping": {
+               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
+           },
+       }
+
+    Returns:
+        model with fused modules. A new copy is created if inplace=False; otherwise,
+        the model is modified in place.
+
+    Examples::
+
+            >>> # xdoctest: +SKIP
+            >>> m = M().eval()
+            >>> # m is a module containing the sub-modules below
+            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
+            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+            >>> output = fused_m(input)
+
+            >>> m = M().eval()
+            >>> # Alternately provide a single list of modules to fuse
+            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
+            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+            >>> output = fused_m(input)
+
+    """
+    return _fuse_modules(
+        model,
+        modules_to_fuse,
+        is_qat=False,
+        inplace=inplace,
+        fuser_func=fuser_func,
+        fuse_custom_config_dict=fuse_custom_config_dict)
+
+def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    """QAT version for `fuse_modules`."""
+    return _fuse_modules(
+        model,
+        modules_to_fuse,
+        is_qat=True,
+        inplace=inplace,
+        fuser_func=fuser_func,
+        fuse_custom_config_dict=fuse_custom_config_dict)
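+
+# Illustrative end-to-end sketch: `ToyModel` is a made-up example module.
+#
+#     class ToyModel(nn.Module):
+#         def __init__(self):
+#             super().__init__()
+#             self.conv = nn.Conv2d(3, 8, 3)
+#             self.bn = nn.BatchNorm2d(8)
+#             self.relu = nn.ReLU()
+#
+#         def forward(self, x):
+#             return self.relu(self.bn(self.conv(x)))
+#
+#     m = ToyModel().eval()
+#     fused = fuse_modules(m, ['conv', 'bn', 'relu'])
+#     # fused.conv is a ConvReLU2d wrapping the BN-folded conv;
+#     # fused.bn and fused.relu are replaced with nn.Identity().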
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py b/MLPY/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23f4247b3c30ba96bcaa9c0eec3c9f4e7a9d51c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fuser_method_mappings.py
@@ -0,0 +1,259 @@
+import torch.nn as nn
+import torch.ao.nn.intrinsic as nni
+
+from typing import Any, Union, Callable, List, Tuple, Dict, Optional, Type
+from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode
+import itertools
+
+__all__ = [
+    "fuse_conv_bn",
+    "fuse_conv_bn_relu",
+    "fuse_linear_bn",
+    "fuse_convtranspose_bn",
+    "get_fuser_method",
+    "get_fuser_method_new",
+]
+
+def fuse_conv_bn(is_qat, conv, bn):
+    r"""Return the fused the conv and bn modules.
+    Given the conv and bn modules, fuses them and returns the fused module
+
+    Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+        or post training quantization fusion
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> # xdoctest: +SKIP
+        >>> m2 = fuse_conv_bn(m1, b1)
+    """
+    assert conv.training == bn.training, \
+        "Conv and BN both must be in the same mode (train or eval)."
+
+    fused_module_class_map = {
+        nn.Conv1d: nni.ConvBn1d,
+        nn.Conv2d: nni.ConvBn2d,
+        nn.Conv3d: nni.ConvBn3d,
+    }
+
+    if is_qat:
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
+        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with track_running_stats set to True'
+        fused_module_class = fused_module_class_map.get((type(conv)), None)
+        if fused_module_class is not None:
+            return fused_module_class(conv, bn)
+        else:
+            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}")
+    else:
+        return nn.utils.fuse_conv_bn_eval(conv, bn)
+
+def fuse_conv_bn_relu(is_qat, conv, bn, relu):
+    r"""Return the fused conv and bv modules.
+
+    Given the conv and bn modules, fuses them and returns the fused module
+
+    Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+        or post training quantization fusion
+        conv: Module instance of type conv2d/conv3d
+        bn: Spatial BN instance that needs to be fused with the conv
+        relu: ReLU instance that needs to be fused with the conv and bn
+
+    Examples::
+
+        >>> m1 = nn.Conv2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> r1 = nn.ReLU(inplace=False)
+        >>> # xdoctest: +SKIP
+        >>> m2 = fuse_conv_bn_relu(m1, b1, r1)
+    """
+    assert conv.training == bn.training == relu.training, \
+        "Conv and BN both must be in the same mode (train or eval)."
+    fused_module : Optional[Type[nn.Sequential]] = None
+    if is_qat:
+        map_to_fused_module_train = {
+            nn.Conv1d: nni.ConvBnReLU1d,
+            nn.Conv2d: nni.ConvBnReLU2d,
+            nn.Conv3d: nni.ConvBnReLU3d,
+        }
+        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
+        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
+        assert bn.track_running_stats, 'Only support fusing BatchNorm with track_running_stats set to True'
+        fused_module = map_to_fused_module_train.get(type(conv), None)
+        if fused_module is not None:
+            return fused_module(conv, bn, relu)
+        else:
+            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}")
+    else:
+        map_to_fused_module_eval = {
+            nn.Conv1d: nni.ConvReLU1d,
+            nn.Conv2d: nni.ConvReLU2d,
+            nn.Conv3d: nni.ConvReLU3d,
+        }
+        fused_module = map_to_fused_module_eval.get(type(conv), None)
+        if fused_module is not None:
+            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+            return fused_module(fused_conv, relu)
+        else:
+            raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}")
+
+def fuse_linear_bn(is_qat, linear, bn):
+    r"""Return the fused linear and bn modules.
+    Given the linear and bn modules, fuses them and returns the fused module
+
+    Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+        or post training quantization fusion
+        linear: Module instance of type Linear
+        bn: BatchNorm1d instance that needs to be fused with the linear layer
+
+    Examples::
+
+        >>> m1 = nn.Linear(20, 10)
+        >>> b1 = nn.BatchNorm1d(10)
+        >>> # xdoctest: +SKIP
+        >>> m2 = fuse_linear_bn(m1, b1)
+    """
+    assert linear.training == bn.training, \
+        "Linear and BN both must be in the same mode (train or eval)."
+
+    if is_qat:
+        assert bn.num_features == linear.out_features, \
+            "Output features of Linear must match num_features of BatchNorm1d"
+        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
+        assert bn.track_running_stats, \
+            "Only support fusing BatchNorm1d with tracking_running_stats set to True"
+        return nni.LinearBn1d(linear, bn)
+    else:
+        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
+
+def fuse_convtranspose_bn(is_qat, convt, bn):
+    r"""Return the fused ConvTranspose and bn modules.
+    Given ConvTranspose and bn modules, fuses them and returns the fused module
+
+    Args:
+        convt: Module instance of type ConvTransposeNd
+        bn: BatchNormNd instance that needs to be fused with the ConvTranspose layer.
+            The batch norm N should match the ConvTranspose N
+
+    Examples::
+
+        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
+        >>> b1 = nn.BatchNorm2d(20)
+        >>> # xdoctest: +SKIP
+        >>> m2 = fuse_convtranspose_bn(False, m1.eval(), b1.eval())
+    """
+    assert convt.training == bn.training, \
+        "ConvTranspose and BN both must be in the same mode (train or eval)."
+
+    if is_qat:
+        raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in QAT.")
+    else:
+        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
+
+def _sequential_wrapper2(sequential):
+    """Return a fuser method that wraps two modules in the given sequential class.
+
+    Given a sequential class for two modules, return a function that takes
+    is_qat and then two modules as arguments, ignores the is_qat flag,
+    and always returns the sequential that combines the two input modules.
+    """
+    def fuser_method(is_qat, m1, m2):
+        return sequential(m1, m2)
+    return fuser_method
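+
+# Illustrative sketch (not part of the original source): the wrapper above is what
+# backs the (module, ReLU) entries in the default mapping below, e.g.:
+#
+#     >>> # xdoctest: +SKIP
+#     >>> fuse_fn = _sequential_wrapper2(nni.LinearReLU)
+#     >>> fused = fuse_fn(False, nn.Linear(4, 8), nn.ReLU())  # is_qat is ignored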
+
+_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
+    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
+    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
+    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
+    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
+    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
+    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
+    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
+    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
+    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
+    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
+    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
+    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
+    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
+    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
+}
+
+def get_fuser_method(op_list, additional_fuser_method_mapping=None):
+    """Get fuser method for the given list of module types.
+
+    Get fuser method for the given list of module types;
+    raise an AssertionError if no fuser method is found.
+    """
+    if additional_fuser_method_mapping is None:
+        additional_fuser_method_mapping = {}
+    all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD,
+                                     additional_fuser_method_mapping)
+    fuser_method = all_mappings.get(op_list, None)
+    assert fuser_method is not None, f"did not find fuser method for: {op_list} "
+    return fuser_method
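+
+# Illustrative sketch (not part of the original source): looking up and applying a
+# fuser method from the default mapping above.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> conv, bn, relu = nn.Conv2d(3, 16, 3).eval(), nn.BatchNorm2d(16).eval(), nn.ReLU().eval()
+#     >>> fuse_fn = get_fuser_method((nn.Conv2d, nn.BatchNorm2d, nn.ReLU))
+#     >>> fused = fuse_fn(False, conv, bn, relu)  # post-training fusion -> nni.ConvReLU2d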
+
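+# The two helpers below adapt a fuser method to patterns whose arguments arrive in
+# reversed (and, for _reverse3, nested) order: _reverse2 swaps the two module
+# arguments, while _reverse3 additionally unpacks a nested pair, so that e.g. a fuser
+# expecting (is_qat, conv, bn, relu) can be driven from arguments given as (relu, (bn, conv)).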
+def _reverse2(f):
+    def reversed(is_qat, x, y):
+        return f(is_qat, y, x)
+    return reversed
+
+def _reverse3(f):
+    def reversed(is_qat, x, w):
+        y, z = w
+        return f(is_qat, z, y, x)
+    return reversed
+
+def _get_valid_patterns(op_pattern):
+    """Return a list of valid patterns generated from the op_pattern.
+
+    Returns a list of valid patterns generated from the op_pattern,
+    since MatchAllNode can match all types of nodes,
+    e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like
+    (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)
+
+    Example Input:
+    (torch.add, (torch.nn.ReLU, torch.nn.Conv2d))
+
+    Example Output:
+    [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
+     (torch.add, (torch.nn.ReLU, MatchAllNode)),
+     (torch.add, (MatchAllNode, torch.nn.Conv2d)),
+     (torch.add, (MatchAllNode, MatchAllNode)),
+     (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
+     (MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
+     (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
+     (MatchAllNode, (MatchAllNode, MatchAllNode)),
+    ]
+    """
+    result: List[Any]
+    if isinstance(op_pattern, (tuple, list)):
+        sub_combs = []
+        for sub_pattern in op_pattern:
+            sub_combs.append(_get_valid_patterns(sub_pattern))
+        result = list(itertools.product(*sub_combs))
+    else:
+        result = [op_pattern, MatchAllNode]
+    return result
+
+def get_fuser_method_new(
+        op_pattern: Pattern,
+        fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]):
+    """Get fuser method.
+
+    This will be made the default after we deprecate get_fuser_method.
+    We would like to implement this first and have a separate PR for the deprecation.
+    """
+    op_patterns = _get_valid_patterns(op_pattern)
+    fuser_method = None
+    for valid_pattern in op_patterns:
+        fuser_method = fuser_method_mapping.get(valid_pattern, None)
+        if fuser_method is not None:
+            break
+    assert fuser_method is not None, f"did not find fuser method for: {op_pattern}"
+    return fuser_method
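+
+# Illustrative sketch (not part of the original source): the MatchAllNode expansion
+# performed by _get_valid_patterns lets a wildcard mapping entry serve a concrete pattern.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> mapping = {(nn.ReLU, MatchAllNode): _sequential_wrapper2(nni.ConvReLU2d)}
+#     >>> fuse_fn = get_fuser_method_new((nn.ReLU, nn.Conv2d), mapping)
+#     >>> # (nn.ReLU, nn.Conv2d) is not a key itself, but (nn.ReLU, MatchAllNode) matches it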
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..030caec3731699dc8972417d6401b6b0b8eb0f2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__init__.py
@@ -0,0 +1,3 @@
+from .prepare import prepare
+from .convert import convert
+from .fuse import fuse
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca3ce9346ecb53ed58362d62ab5e4eda14bfd99c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..035d37a69ab28ab0f0758805ab5825829ea5dfc4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45f73b2559d66ef2a13b7f5c686a57cb3142b05c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f81805e749c9022d190fd081be3018acfe2489f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be5b23071b4faedc22afb0325fc4bc6045bb1631
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ebf54aa9efcaebc3b9dddec40511fe43e5000b7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1199effa912f593ecb4307b68dedb7c6c8602e91
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc75ab0940467987fc903517e06f381d84cfbbfa
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d477f3451edef53526de5c87e56672e9faac3f1a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12732d8ffe74cbc7136cdd1428b15146a3abab70
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..618563a7eb9204faa873e54451c8bdbe5cc4fa1e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..300ccd9ce964ef3b128efaa5d6468a693e1dff7e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a515d531f11e343ed64145ba10f31d78805c3390
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13133cf1bb150e7297fc4c5364436ffdcad902d9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bd11c669927c92cc176a7089e1483321443b8f6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5adc177a714cc32f3fe4de3a900717f51a65e107
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2656b4d1843e70e78070b110fdd62655598420c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7747022f30d273a096031c4161a9e264435b9ea
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2124e41fc84fdc48e7669459e5e5994ea626da36
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py
new file mode 100644
index 0000000000000000000000000000000000000000..6159cad7c94b157393947e018ffe24d8075b1d20
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_decomposed.py
@@ -0,0 +1,925 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch.library import Library, impl
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
+from torch._refs import _unsqueeze_multiple
+
+
+# Note: decomposed means decomposed quantized tensor, using decomposed so that the
+# name is not too long
+quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
+
+_DTYPE_TO_QVALUE_BOUNDS = {
+    torch.uint8: (0, 255),
+    torch.int8: (-128, 127),
+    torch.int16: (-(2**15), 2**15 - 1),
+    torch.int32: (-(2**31), 2**31 - 1)
+}
+
+# Helper to check the passed in quant min and max are valid for the dtype
+def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+
+    assert quant_min >= quant_min_lower_bound, \
+        "quant_min out of bound for dtype, " \
+        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+
+    assert quant_max <= quant_max_upper_bound, \
+        "quant_max out of bound for dtype, " \
+        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+
+quantized_decomposed_lib.define(
+    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
+def quantize_per_tensor(
+        input: torch.Tensor,
+        scale: float,
+        zero_point: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scale (float): quantization parameter for affine quantization
+       zero_point (int): quantization parameter for affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+
+    inv_scale = 1.0 / scale
+    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
+
+quantized_decomposed_lib.define(
+    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd")
+def quantize_per_tensor_tensor(
+        input: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
+def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    return torch.empty_like(input, dtype=dtype)
+
+# TODO: remove other variants and keep this one
+quantized_decomposed_lib.define(
+    "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd")
+def quantize_per_tensor_tensor2(
+        input: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        quant_min: torch.Tensor,
+        quant_max: torch.Tensor,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
+def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
+
+# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
+# the signature as metadata for the input Tensor, this might be useful for pattern
+# matching in the future
+# We will revisit this later if we find there are no use cases for it
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
+def dequantize_per_tensor(
+        input: torch.Tensor,
+        scale: float,
+        zero_point: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+
+    Args:
+       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
+       e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
+       quantization parameters in the argument of this function (scale/zero_point)
+
+       scale (float): quantization parameter for affine quantization
+
+       zero_point (int): quantization parameter for affine quantization
+
+       quant_min (int): minimum quantized value for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+       quant_max (int): maximum quantized value for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+       dtype (torch.dtype): dtype for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+    Returns:
+       dequantized float32 Tensor
+    """
+    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
+        # TODO: investigate why
+        # (input - zero_point).to(torch.float32) * scale
+        # failed the test
+        return (input.to(torch.float32) - zero_point) * scale
+    else:
+        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
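+
+# Illustrative sketch (not part of the original source): a per-tensor quantize/dequantize
+# round trip using the two ops defined above.  Once this module is imported the same ops
+# are also callable as torch.ops.quantized_decomposed.quantize_per_tensor / dequantize_per_tensor.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> x = torch.randn(2, 3)
+#     >>> xq = quantize_per_tensor(x, 0.1, 0, -128, 127, torch.int8)
+#     >>> xdq = dequantize_per_tensor(xq, 0.1, 0, -128, 127, torch.int8)
+#     >>> # xdq approximates x to within about scale / 2 = 0.05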
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
+def dequantize_per_tensor_tensor(
+        input: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
+def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
+        return torch.empty_like(input, dtype=torch.float32)
+    else:
+        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
+
+# TODO: remove other variants and keep this one
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
+def dequantize_per_tensor_tensor2(
+        input: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        quant_min: torch.Tensor,
+        quant_max: torch.Tensor,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
+def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
+
+quantized_decomposed_lib.define(
+    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
+    "float eps, ScalarType dtype) -> (Tensor, Tensor)")
+
+@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
+def choose_qparams_tensor(
+        input: torch.Tensor,
+        qmin: int,
+        qmax: int,
+        eps: float,
+        dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """ Given an input Tensor, derive the per tensor affine quantization parameter
+    (scale and zero_point) for target quantized Tensor from the Tensor
+
+    Args:
+       input (torch.Tensor): floating point input Tensor
+       quant_min (int): minimum quantized value for target quantized Tensor
+       quant_max (int): maximum quantized value for target quantized Tensor
+       dtype (torch.dtype): dtype for target quantized Tensor
+
+    Returns:
+       scale (float): quantization parameter for the target quantized Tensor
+       zero_point (int): quantization parameter for the target quantized Tensor
+    """
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, \
+        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+    validate_qmin_qmax(qmin, qmax)
+
+    min_val, max_val = torch.aminmax(input)
+
+    return determine_qparams(
+        min_val, max_val, qmin, qmax, dtype, torch.Tensor([eps]), has_customized_qrange=False)
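+
+# Illustrative sketch (not part of the original source): deriving per-tensor qparams
+# from data and feeding them to quantize_per_tensor defined above.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> x = torch.randn(16, 32)
+#     >>> scale, zero_point = choose_qparams_tensor(x, -128, 127, torch.finfo(torch.float32).eps, torch.int8)
+#     >>> xq = quantize_per_tensor(x, scale.item(), int(zero_point.item()), -128, 127, torch.int8)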
+
+quantized_decomposed_lib.define(
+    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
+    "float eps, ScalarType dtype) -> (Tensor, Tensor)")
+
+@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "CompositeExplicitAutograd")
+def choose_qparams_symmetric_tensor(
+        input: torch.Tensor,
+        qmin: int,
+        qmax: int,
+        eps: float,
+        dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """ Given an input Tensor, derive the per tensor affine quantization parameter
+    (scale and zero_point) for target quantized Tensor from the Tensor
+
+    Args:
+       input (torch.Tensor): floating point input Tensor
+       quant_min (int): minimum quantized value for target quantized Tensor
+       quant_max (int): maximum quantized value for target quantized Tensor
+       dtype (torch.dtype): dtype for target quantized Tensor
+
+    Returns:
+       scale (float): quantization parameter for the target quantized Tensor
+       zero_point (int): quantization parameter for the target quantized Tensor
+    """
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, \
+        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+    validate_qmin_qmax(qmin, qmax)
+
+    min_val, max_val = torch.aminmax(input)
+    return determine_qparams(
+        min_val,
+        max_val,
+        qmin,
+        qmax,
+        dtype,
+        torch.Tensor([eps]),
+        has_customized_qrange=False,
+        qscheme=torch.per_tensor_symmetric
+    )
+
+@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
+def choose_qparams_tensor_meta(
+        input: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        eps: float,
+        dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: \
+        {quant_min} max: {quant_max}"
+    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(1, dtype=torch.int64, device=input.device)
+
+@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
+def choose_qparams_symmetric_tensor_meta(
+        input: torch.Tensor,
+        quant_min: int,
+        quant_max: int,
+        eps: float,
+        dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(1, dtype=torch.int64, device=input.device)
+
+# Helper function used to implement per-channel quantization against any axis
+def _permute_to_axis_zero(x, axis):
+    new_axis_list = list(range(x.dim()))
+    new_axis_list[axis] = 0
+    new_axis_list[0] = axis
+    y = x.permute(tuple(new_axis_list))
+    return y, new_axis_list
+
+quantized_decomposed_lib.define(
+    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
+def quantize_per_channel(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine per channel quantization for the Tensor using the same quantization
+    parameters for each channel/axis to map from floating point to quantized values
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (torch.Tensor): a list of scale quantization parameters for
+       affine quantization, one per channel
+       zero_points (torch.Tensor): a list of zero_point quantization parameters for
+       affine quantization, one per channel
+       axis (int): channel axis that the per channel quantization parameters apply to
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    input, permute_axis_list = _permute_to_axis_zero(input, axis)
+    res = torch.zeros_like(input)
+
+    for i in range(input.size(0)):
+        res[i] = torch.clamp(
+            torch.round(input[i] * (1.0 / scales[i])) + zero_points[i],
+            quant_min,
+            quant_max
+        )
+
+    out = res.permute(tuple(permute_axis_list))
+    return out.to(dtype)
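+
+# Illustrative sketch (not part of the original source): per-channel quantization of a
+# weight-like tensor along its output-channel axis (axis=0).
+#
+#     >>> # xdoctest: +SKIP
+#     >>> w = torch.randn(4, 8)                        # 4 output channels
+#     >>> scales = torch.full((4,), 0.05)
+#     >>> zero_points = torch.zeros(4, dtype=torch.int64)
+#     >>> wq = quantize_per_channel(w, scales, zero_points, 0, -128, 127, torch.int8)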
+
+@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
+def quantize_per_channel_meta(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    if input.dtype == torch.bfloat16:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=dtype)
+
+# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
+# the signature as metadata for the input Tensor, this might be useful for pattern
+# matching in the future
+# We will revisit this later if we find there are no use cases for it
+quantized_decomposed_lib.define(
+    "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
+
+@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
+def dequantize_per_channel(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    """ Affine per channel dequantization for the Tensor using the same quantization
+    parameters for each channel/axis to map from quantized values to floating point values
+
+    Args:
+       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
+       e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
+       quantization parameter in the argument of this function (scales/zero_points/axis)
+
+       scales (torch.Tensor): a list of scale quantization parameters for
+       affine quantization, one per channel
+
+       zero_points (torch.Tensor): a list of zero_point quantization parameters for
+       affine quantization, one per channel
+
+       axis (int): channel axis that the per channel quantization parameters apply to
+
+       quant_min (int): minimum quantized value for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+       quant_max (int): maximum quantized value for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+       dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+    Returns:
+       dequantized float32 Tensor
+    """
+    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    input, permute_axis_list = _permute_to_axis_zero(input, axis)
+    res = torch.zeros_like(input, dtype=torch.float32)
+
+    for i in range(input.size(0)):
+        # TODO: investigate why
+        # (input[i] - zero_points[i]).to(torch.float32) * scales[i]
+        # failed the test
+        res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
+
+    out = res.permute(tuple(permute_axis_list))
+    return out
+
+@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
+def dequantize_per_channel_meta(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+        dtype: torch.dtype
+) -> torch.Tensor:
+    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=torch.float32)
+
+
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token",
+    "CompositeExplicitAutograd",
+)
+def choose_qparams_per_token(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Choose quantization parameters for per token quantization. This means that for an N-dimensional Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32/float16 Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+
+    Returns:
+        scales and zero_points, both float32 Tensors
+    """
+
+    scales = input.abs().amax(dim=-1, keepdim=True)
+    if scales.dtype == torch.float16:
+        scales = (
+            scales.float()
+        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
+    if dtype == torch.int8:
+        n_bits = 8
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        raise Exception(f"unsupported dtype in choose_qparams_per_token: {dtype}")
+
+    scales = scales.clamp(min=1e-5).div(quant_max)
+    zero_points = torch.zeros_like(scales)
+    return scales, zero_points
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token",
+    "Meta",
+)
+def choose_qparams_per_token_meta(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    size = (1, input.size(-1))
+    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
+        size, dtype=torch.int64, device=input.device
+    )
+
+
+# TODO: move this to https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token_asymmetric",
+    "CompositeExplicitAutograd",
+)
+def choose_qparams_per_token_asymmetric(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Choose quantization parameters for per token quantization. This means that for an N-dimensional Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32/float16 Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+
+    Returns:
+        scales and zero_points, both float32 Tensors
+    """
+    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
+    qmin, qmax = -128, 127
+    min_val, max_val = torch.aminmax(input, dim=-1, keepdim=True)
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
+
+    # scale
+    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
+    scale = scale.clamp(min=eps)
+
+    # zero point
+    descaled_min = min_val_neg / scale
+    descaled_max = max_val_pos / scale
+    zero_point_from_min_error = qmin + descaled_min
+    zero_point_from_max_error = qmax + descaled_max
+    zero_point = torch.where(
+        zero_point_from_min_error + zero_point_from_max_error > 0,
+        qmin - descaled_min,
+        qmax - descaled_max,
+    )
+    zero_point = torch.clamp(zero_point, qmin, qmax).round()
+
+    return scale.to(torch.float32), zero_point.to(torch.float32)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token_asymmetric",
+    "Meta",
+)
+def choose_qparams_per_token_asymmetric_meta(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    size = (1, input.size(-1))
+    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
+        size, dtype=torch.int64, device=input.device
+    )
+
+
+def _per_token_quant_qparam_dim_check(input, scales, zero_points):
+    num_tokens = math.prod(list(input.size())[:-1])
+    assert (
+        num_tokens == scales.numel()
+    ), f"num_tokens: {num_tokens} scales: {scales.size()}"
+    assert (
+        num_tokens == zero_points.numel()
+    ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
+def quantize_per_token(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+    """Per token quantization for the Tensor using the quantization parameters to map
+    from floating point to quantized values. This means that for an N-dimensional Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (float32 torch.Tensor): quantization parameter for per token affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    input = (
+        torch.round(input / scales + zero_points).clamp(quant_min, quant_max).to(dtype)
+    )
+    return input
+
+
+@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
+def quantize_per_token_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=dtype)
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
+def dequantize_per_token(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
+):
+    """Per token dequantization for the Tensor using the quantization parameters to map
+    from quantized values back to floating point values. This means that for an N-dimensional Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and dequantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
+       scales (float32 torch.Tensor): quantization parameter for per token affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
+       quant_min (int): minimum quantized value for input Tensor
+       quant_max (int): maximum quantized value for input Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
+
+    Returns:
+       dequantized Tensor with dtype `output_dtype`
+    """
+    input = input - zero_points
+    input = input.to(output_dtype) * scales
+    return input
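+
+# Illustrative sketch (not part of the original source): a per-token quantize/dequantize
+# round trip driven by the asymmetric qparam chooser defined above.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> x = torch.randn(2, 3, 8)                     # 6 tokens of 8 elements each
+#     >>> scales, zero_points = choose_qparams_per_token_asymmetric(x, torch.int8)
+#     >>> xq = quantize_per_token(x, scales, zero_points, -128, 127, torch.int8)
+#     >>> xdq = dequantize_per_token(xq, scales, zero_points, -128, 127, torch.int8)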
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
+def dequantize_per_token_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
+):
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    # TODO: support fp16
+    return torch.empty_like(input, dtype=output_dtype)
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
+    "int quant_max, ScalarType dtype, int group_size) -> Tensor"
+)
+
+
+# TODO: dtype is ignored for now
+@impl(
+    quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd"
+)
+def quantize_per_channel_group(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size=128,
+):
+    assert group_size > 1
+    # needed for GPTQ single column quantize
+    if group_size > input.shape[-1] and scales.shape[-1] == 1:
+        group_size = input.shape[-1]
+
+    assert input.shape[-1] % group_size == 0
+    assert input.dim() == 2
+
+    # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
+    to_quant = input.reshape(-1, group_size)
+    assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zero_points = zero_points.reshape(-1, 1)
+
+    input_int8 = (
+        to_quant.div(scales)
+        .add(zero_points)
+        .round()
+        .clamp_(quant_min, quant_max)
+        .to(dtype)
+        .reshape_as(input)
+    )
+
+    return input_int8
+
+
+@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
+def quantize_per_channel_group_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size=128,
+):
+    """Groupwise quantization within each channel for a 2-d Tensor using the quantization parameters
+    to map from floating point to quantized values. This means for each row of a 2-d Tensor
+    (M, N), we calculate scales/zero_points for each `group_size` elements
+    and quantize every `group_size` elements with the same quantization parameter.
+    The dimension for scales/zero_points will be (M * ceil(N / group_size),)
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    assert group_size > 1
+    # needed for GPTQ single column quantize
+    if group_size > input.shape[-1] and scales.shape[-1] == 1:
+        group_size = input.shape[-1]
+
+    assert input.shape[-1] % group_size == 0
+    assert input.dim() == 2
+    return torch.empty_like(input, dtype=dtype)
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
+    "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "dequantize_per_channel_group",
+    "CompositeExplicitAutograd",
+)
+def dequantize_per_channel_group(
+    w_int8: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size: int = 128,
+    output_dtype: torch.dtype = torch.float32,
+):
+    """Groupwise dequantization within each channel for a 2-d Tensor using the quantization parameters
+    to map from quantized values back to floating point values. This means for each row of a 2-d Tensor
+    (M, N), we calculate scales/zero_points for each `group_size` elements
+    and dequantize every `group_size` elements with the same quantization parameter.
+    The dimension for scales/zero_points will be (M * ceil(N / group_size),)
+
+    Args:
+       w_int8 (torch.Tensor): quantized Tensor (uint8/int8 etc.)
+       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
+       quant_min (int): minimum quantized value for input Tensor
+       quant_max (int): maximum quantized value for input Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
+
+    Returns:
+       dequantized Tensor with dtype `output_dtype`
+    """
+
+    assert group_size > 1
+    # needed for GPTQ single column dequantize
+    if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
+        group_size = w_int8.shape[-1]
+    assert w_int8.shape[-1] % group_size == 0
+    assert w_int8.dim() == 2
+
+    w_int8_grouped = w_int8.reshape(-1, group_size)
+    scales = scales.reshape(-1, 1)
+    if zero_points is not None:
+        zp = zero_points.reshape(-1, 1)
+    else:
+        zp = torch.zeros([], dtype=torch.int32, device=scales.device)
+    w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype)
+    return w_dq
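+
+# Illustrative sketch (not part of the original source): groupwise quantization and
+# dequantization of a 2-d weight with group_size=16; the int4-style range [-8, 7] is
+# stored in an int8 container, as the TODO above notes torch.int4 cannot be expressed here.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> w = torch.randn(4, 32)
+#     >>> scales = torch.full((4, 2), 0.05)            # one scale per 16-element group
+#     >>> zero_points = torch.zeros(4, 2, dtype=torch.int32)
+#     >>> wq = quantize_per_channel_group(w, scales, zero_points, -8, 7, torch.int8, group_size=16)
+#     >>> w_dq = dequantize_per_channel_group(wq, scales, zero_points, -8, 7, torch.int8, group_size=16)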
+
+
+quantized_decomposed_lib.define(
+    "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
+    "int quant_min, int quant_max) -> Tensor")
+
+class FakeQuantPerChannel(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
+        with torch._C._AutoDispatchBelowAutograd():
+            if input.dtype == torch.bfloat16:
+                input = input.to(torch.float32)
+            if scales.dtype != torch.float32:
+                scales = scales.to(torch.float32)
+            if zero_points.dtype != torch.int32:
+                zero_points = zero_points.to(torch.int32)
+            assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+            assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+            broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
+            unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
+            unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
+            temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
+            out = (torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points) * unsqueeze_scales
+            mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))
+
+        ctx.save_for_backward(mask)
+        return out
+
+    @staticmethod
+    def backward(ctx, gy):
+        mask, = ctx.saved_tensors
+        return gy * mask, None, None, None, None, None
+
+@impl(quantized_decomposed_lib, "fake_quant_per_channel", "AutogradCPU")
+def fake_quant_per_channel(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+) -> torch.Tensor:
+    return FakeQuantPerChannel.apply(input, scales, zero_points, axis, quant_min, quant_max)
+
+@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta")
+def fake_quant_per_channel_meta(
+        input: torch.Tensor,
+        scales: torch.Tensor,
+        zero_points: torch.Tensor,
+        axis: int,
+        quant_min: int,
+        quant_max: int,
+) -> torch.Tensor:
+    return torch.empty_like(input)
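+
+# Illustrative sketch (not part of the original source): fake quantization with a
+# straight-through-style gradient; values falling outside [quant_min, quant_max] get zero
+# gradient via the mask saved by FakeQuantPerChannel.
+#
+#     >>> # xdoctest: +SKIP
+#     >>> x = torch.randn(2, 4, requires_grad=True)
+#     >>> scales = torch.full((4,), 0.1)
+#     >>> zero_points = torch.zeros(4, dtype=torch.int32)
+#     >>> y = fake_quant_per_channel(x, scales, zero_points, 1, -128, 127)
+#     >>> y.sum().backward()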
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_equalize.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_equalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd06631f0fb1de448162e4b80a675343750081ee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_equalize.py
@@ -0,0 +1,820 @@
+import warnings
+
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.ao.nn.intrinsic as nni
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
+
+from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver
+from ..utils import _parent_name, check_min_max_valid
+
+from .utils import (
+    get_new_attr_name_with_prefix,
+    maybe_get_next_module,
+    node_arg_is_weight,
+)
+
+CUSTOM_MODULE_SUPP_LIST: List[Any] = []
+
+def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
+    """Reshape the scale so that it can be multiplied with the input along the given axis.
+    """
+    new_shape = [1] * input.ndim
+    new_shape[axis] = input.size(axis)
+    return scale.view(new_shape)
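+
+# Illustrative sketch (not part of the original source):
+#
+#     >>> # xdoctest: +SKIP
+#     >>> scale = torch.tensor([2.0, 3.0])
+#     >>> x = torch.randn(4, 2, 8, 8)                  # channel axis is 1
+#     >>> reshape_scale(scale, 1, x).shape             # -> torch.Size([1, 2, 1, 1])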
+
+qsheme_mapping_per_tensor_to_per_channel = {
+    torch.per_tensor_affine: torch.per_channel_affine,
+    torch.per_tensor_symmetric: torch.per_channel_symmetric,
+}
+
+
+class _InputEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of input columns, and
+    computing the quantization parameters for the overall min/max input values.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+
+    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
+    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
+    with the difference that the running min/max values are stored per column.
+    This observer is intended to be used along with a WeightEqualizationObserver
+    to calculate the equalization scale.
+    """
+
+    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
+                 quant_min=None, quant_max=None, factory_kwargs=None) -> None:
+        super().__init__()
+
+        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            raise TypeError("Input qscheme must be per-tensor")
+
+        self.dtype = dtype
+        self.qscheme = qscheme
+
+        per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
+        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
+                                                  qscheme=per_channel_qscheme,
+                                                  quant_min=quant_min,
+                                                  quant_max=quant_max,
+                                                  factory_kwargs=factory_kwargs)
+
+        self.equalization_scale = torch.tensor(1)
+        self.equalization_shape: List[int] = []
+
+    def forward(self, x_orig):
+        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
+            raise ValueError("InputEqualizationObserver only supports Linear and Conv layers")
+
+        # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
+        self.equalization_shape = [1] * x_orig.ndim
+        self.equalization_shape[1] = x_orig.size(1)
+
+        return self.input_obs(x_orig)
+
+    def get_input_minmax(self):
+        return (self.input_obs.min_val, self.input_obs.max_val)
+
+    def set_equalization_scale(self, equalization_scale):
+        # Reshape the equalization scale along axis=1 so that it can be
+        # multiplied with the input along axis=1
+        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+            return
+        self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape)
+
+    def calculate_scaled_minmax(self):
+        r""" Returns the scaled min/max inputs
+        """
+        if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1):
+            warnings.warn(
+                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " +
+                "Will not scale the next quantization observer."
+            )
+            return None, None
+
+        # Calculate qparams for the scaled min/max inputs
+        # Scale the input by the equalization scale located at the same column
+        # index
+        (min_inputs, max_inputs) = self.get_input_minmax()
+        equalization_scale_reshaped = reshape_scale(self.equalization_scale, 0, min_inputs)
+        min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
+        max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))
+
+        return min_input_scaled, max_input_scaled
+
+    with_args = classmethod(_with_args)
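+
+    # Illustrative calibration sketch (not part of the upstream module; the toy
+    # shapes below are assumptions). The observer tracks per-column min/max
+    # values, which are later combined with the weight ranges to derive the
+    # equalization scale:
+    #
+    #     inp_eq_obs = _InputEqualizationObserver(dtype=torch.quint8)
+    #     inp_eq_obs(torch.randn(16, 8))              # calibrate on a batch of inputs
+    #     lo, hi = inp_eq_obs.get_input_minmax()      # per-column, each of shape (8,)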
+
+
+class _WeightEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of weight columns and
+    rows, and computing the quantization parameters for the weight rows.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+
+    This observer is made up of a single PerChannelMinMaxObserver `weight_col_obs` used
+    to record the running minimum and maximum of columns of incoming weight
+    tensors. This observer is intended to be used along with an
+    InputEqualizationObserver to calculate the equalization scale.
+
+    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
+    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
+    """
+
+    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
+                 quant_max=None, factory_kwargs=None) -> None:
+        super().__init__()
+
+        self.dtype = dtype
+        self.qscheme = qscheme
+        self.ch_axis = 1
+
+        per_channel_qscheme = qscheme
+        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            per_channel_qscheme = qscheme_mapping_per_tensor_to_per_channel[qscheme]
+        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
+                                                       qscheme=per_channel_qscheme,
+                                                       quant_min=quant_min,
+                                                       quant_max=quant_max,
+                                                       factory_kwargs=factory_kwargs)
+
+        self.equalization_scale = torch.tensor(1)
+
+    def forward(self, w_orig):
+        if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
+            raise ValueError("InputEqualizationObserver only supports Linear and Conv layers")
+
+        return self.weight_col_obs(w_orig)
+
+    def get_weight_col_minmax(self):
+        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
+
+    def set_equalization_scale(self, equalization_scale):
+        self.equalization_scale = equalization_scale
+
+    with_args = classmethod(_with_args)
+
+
+def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
+                                 weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
+    r""" Calculates the equalization scale and sets the equalization_scale value
+    in the observers.
+
+    Args:
+        input_obs: Observer that tracks the ranges for the input columns
+        weight_obs: Observer that tracks the ranges for the weight columns
+    """
+
+    (min_inputs, max_inputs) = input_obs.get_input_minmax()
+    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()
+
+    if not (check_min_max_valid(min_inputs, max_inputs) and check_min_max_valid(min_weights, max_weights)):
+        warnings.warn(
+            "Must run observer before calling calculate_equalization_scale. " +
+            "Returning default equalization scale torch.tensor(1)."
+        )
+        return torch.tensor(1)
+
+    if not (min_inputs.shape == min_weights.shape):
+        raise ValueError(
+            "Input and Weight must have the same column dimension. " +
+            f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
+        )
+
+    equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
+    # Replace all 'inf', 'nan', 0's with 1s to prevent errors
+    equalization_scale[equalization_scale == 0.] = 1
+    equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
+    return equalization_scale
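+
+# Worked example (illustrative only): for a single column with input range
+# [-2, 2] and weight range [-0.5, 0.5], the scale is
+# sqrt((0.5 - (-0.5)) / (2 - (-2))) = sqrt(0.25) = 0.5, so the input column is
+# multiplied by 0.5 while the corresponding weight column is divided by 0.5.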
+
+
+class EqualizationQConfig(namedtuple('EqualizationQConfig', ['input_activation', 'weight'])):
+    """
+    Describes how to quantize a layer or a part of the network specifically for
+    input-weight equalization by providing settings (observer classes) for
+    inputs, outputs, and weights.
+
+    Note that EqualizationQConfig needs to contain observer **classes** (like
+    MinMaxObserver) or callables that return instances on invocation, not the
+    concrete observer instances themselves. The quantization flow will
+    instantiate the observers multiple times, for each of the layers.
+
+    Observer classes usually have reasonable default arguments, but they can be
+    overridden with the `with_args` method (which behaves like functools.partial):
+
+    my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
+                                    weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
+    """
+    def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
+        if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
+            raise ValueError("EqualizationQConfig received observer instance, please pass observer class instead. " +
+                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
+        self = super().__new__(cls, input_activation, weight)
+        return self
+
+
+input_equalization_observer = _InputEqualizationObserver.with_args(
+    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
+weight_equalization_observer = _WeightEqualizationObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
+default_equalization_qconfig = EqualizationQConfig(input_activation=input_equalization_observer,
+                                                   weight=weight_equalization_observer)
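+
+# Example (a minimal sketch; the module names "fc1"/"fc2" are hypothetical):
+# per-module equalization configs use the same dict format that
+# get_equalization_qconfig_dict at the end of this file produces, e.g.
+#
+#     equalization_qconfig_dict = {
+#         "module_name": [("fc1", default_equalization_qconfig),
+#                         ("fc2", default_equalization_qconfig)],
+#     }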
+
+
+def fused_module_supports_equalization(module) -> bool:
+    """ Checks if the fused node supports equalization. """
+    return type(module) in [nni.LinearReLU, nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d]
+
+def nn_module_supports_equalization(module) -> bool:
+    """ Checks if the torch.nn node supports equalization. """
+    return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
+
+def custom_module_supports_equalization(module) -> bool:
+    """ Checks if the custom node supports equalization. """
+    return type(module) in CUSTOM_MODULE_SUPP_LIST
+
+
+def node_supports_equalization(node: Node, modules) -> bool:
+    """ Checks if the current node supports equalization
+    Currently we only support nn.Linear/F.linear and nn.Conv/F.conv layers
+    """
+    if node.op == 'call_module':
+        return nn_module_supports_equalization(modules[str(node.target)]) or \
+            fused_module_supports_equalization(modules[str(node.target)]) or \
+            custom_module_supports_equalization(modules[str(node.target)])
+    elif node.op == 'call_function':
+        return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
+    return False
+
+def is_equalization_observer(observer: nn.Module) -> bool:
+    return (isinstance(observer, (_InputEqualizationObserver, _WeightEqualizationObserver)))
+
+
+###############################################################################
+# Functions for equalization during convert                                   #
+###############################################################################
+
+def get_op_node_and_weight_eq_obs(
+    input_eq_obs_node: Node,
+    model: GraphModule,
+    modules: Dict[str, nn.Module]
+) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
+    """ Gets the following weight equalization observer. There should always
+    exist a weight equalization observer after an input equalization observer.
+
+    Returns the operation node that follows the input equalization observer node
+    and the weight equalization observer
+    """
+
+    # Find the op node that comes directly after the input equalization observer
+    op_node = None
+    for user in input_eq_obs_node.users.keys():
+        if node_supports_equalization(user, modules):
+            op_node = user
+            break
+
+    assert op_node is not None
+    if op_node.op == 'call_module':
+        # If the op_node is an nn.Linear layer, then it must have a
+        # WeightEqualizationObserver configuration
+        maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
+        assert maybe_equalization_node_name_to_config is not None
+        equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config  # type: ignore[assignment]
+        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
+        weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()
+
+        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+        return op_node, weight_eq_obs
+
+    elif op_node.op == 'call_function':
+        weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
+        if weight_node is not None:
+            weight_eq_obs = modules[str(weight_node.target)]
+            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            return op_node, weight_eq_obs
+
+    return None, None
+
+def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
+    """ Gets the weight equalization observer node if it exists.
+    """
+    assert op_node.op == 'call_function'
+    for node_arg in op_node.args:
+        if node_arg_is_weight(op_node, node_arg):
+            assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and
+                   isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver))
+            return node_arg
+    return None
+
+def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
+    """ Gets the following input equalization observer if it exists.
+
+    For example, in the case of connecting linear layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
+    If the node being passed in is the linear1 node, then we want to return eq_obs2,
+    the following equalization observer for linear2.
+
+    However, if there are no connecting layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
+    Then we want to return None.
+
+    In the case of an unfused linear-relu layer with a connecting linear layer:
+        linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
+    Since it is unfused, we want to skip over the relu layer and return eq_obs2,
+    the following equalization observer for linear2.
+    """
+
+    assert node_supports_equalization(node, modules)
+
+    # Locate the following nn.ReLU or F.relu node if it exists
+    maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
+    if maybe_relu_node is None:
+        maybe_relu_node = maybe_get_next_module(node, modules, target_functional_type=F.relu)
+
+    # Locate the following output observer if it exists.
+    # We will skip the relu node if it exists.
+    maybe_obs_node = (
+        maybe_get_next_module(node, modules, ObserverBase)
+        if maybe_relu_node is None
+        else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
+    )
+    if maybe_obs_node is None:
+        return None
+
+    maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
+    if maybe_eq_obs_node is None:
+        return None
+
+    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
+    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
+    return maybe_eq_obs
+
+def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
+    """ If the next next node is an InputEqualizationObserver then we want to
+    return its equalization scale, else we return 1
+
+    This is used in the case where there are two connecting linear layers:
+        linear1 -> LinearOutObs -> InputEqObs -> linear2
+    In this case, the node given is linear1 and we want to locate the InputEqObs.
+    """
+    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+    if next_inp_eq_obs:
+        if next_inp_eq_obs.equalization_scale.nelement() == 1 and \
+           next_inp_eq_obs.equalization_scale == torch.tensor(1):
+            return None
+        return next_inp_eq_obs.equalization_scale
+    return None
+
+def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
+    """ Scales the following input quantization observer's min/max values by
+    updating the values with the scaled min/max values calculated by the input
+    equalization observer
+    """
+    input_eq_obs = modules[str(node.target)]
+    assert isinstance(input_eq_obs, _InputEqualizationObserver)
+
+    input_quant_obs_node = node.args[0]
+    assert isinstance(input_quant_obs_node, Node)
+
+    input_quant_obs = modules[str(input_quant_obs_node.target)]
+    if not isinstance(input_quant_obs, ObserverBase):
+        return
+
+    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+    if min_input_scaled is None and max_input_scaled is None:
+        return
+    input_quant_obs.min_val = min_input_scaled
+    input_quant_obs.max_val = max_input_scaled
+
+def scale_weight_node(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """ Scale the weights for input-weight equalization by multiplying the
+    weight by 1/equalization_scale and next_equalization_scale
+
+    Args:
+        node: Current node whose weights we want to scale
+        equalization_scale: Current node's calculated equalization scale
+        next_equalization_scale: Next node's calculated equalization scale if
+           the following node needs to be equalized, None otherwise
+    """
+    if equalization_scale is None:
+        return
+
+    if fused_module_supports_equalization(modules[str(node.target)]):
+        op_module = modules[str(node.target)][0]    # type: ignore[index]
+    else:
+        op_module = modules[str(node.target)]
+    assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we will multiply its scale
+    weight = op_module.weight
+    assert isinstance(weight, torch.Tensor)
+
+    # Scale the weights by the reciprocal of the equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
+    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
+
+    if next_equalization_scale is None:
+        op_module.weight = nn.Parameter(scaled_weight)
+        return
+
+    # Multiply the weights row wise by the next equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
+
+    op_module.weight = nn.Parameter(scaled_weight)
+
+    # Multiply the bias element wise by the next equalization scale
+    bias = op_module.bias
+    if bias is None:
+        return
+    assert isinstance(bias, torch.Tensor)
+
+    # Reshape the equalization scale so that we can multiply it element-wise to the bias
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
+    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
+    op_module.bias = nn.Parameter(scaled_bias)
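+
+# In matrix terms (an illustrative summary of scale_weight_node above): with the
+# current scale s and the next layer's scale s_next, the weight columns are
+# scaled by 1/s and the weight rows by s_next, i.e.
+#     W' = diag(s_next) @ W @ diag(1/s)    and    b' = s_next * b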
+
+def scale_weight_functional(
+    op_node: Node,
+    model: GraphModule,
+    modules: Dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """ Scales the weight value for functional layers
+    """
+    if equalization_scale is None:
+        return
+
+    # From the given op_node, the path looks like:
+    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
+    # So we want to trace back from the op_node to get the equalization observer
+    # node, then the quantization observer node, and then finally the weight
+    # node which contains the weight values.
+
+    # Get the equalization observer node
+    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
+    if weight_eq_obs_node is None:
+        return
+
+    # Get the quantization observer node
+    weight_quant_obs_node = weight_eq_obs_node.args[0]
+    if weight_quant_obs_node is None:
+        return
+    assert (isinstance(weight_quant_obs_node, Node) and
+           isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))
+
+    # Get the get_attr(weight) node
+    weight_node = weight_quant_obs_node.args[0]
+    if weight_node is None:
+        return
+    assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'
+
+    weight_parent_name, weight_name = _parent_name(weight_node.target)
+    weight = getattr(modules[weight_parent_name], weight_name)
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we will multiply its scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
+    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
+
+    if next_equalization_scale is None:
+        setattr(modules[weight_parent_name], weight_name, scaled_weight)
+        return
+
+    # Multiply the weights row wise by the next equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, scaled_weight)
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
+
+    setattr(modules[weight_parent_name], weight_name, scaled_weight)
+    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
+
+    # Multiply the bias element wise by the next equalization scale
+    bias_node = None
+    for node in op_node.args:
+        # Find the node containing the weight values
+        if isinstance(node, Node) and node.op == 'get_attr' and 'bias' in node.name:
+            bias_node = node
+            break
+    if bias_node is None:
+        return
+
+    bias_parent_name, bias_name = _parent_name(bias_node.target)
+    bias = getattr(modules[bias_parent_name], bias_name)
+
+    # Reshape the equalization scale so that we can multiply it element-wise to the bias
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
+    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
+    setattr(modules[bias_parent_name], bias_name, scaled_bias)
+
+def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
+    """ Given the operation node, we want find the corresponding quantization
+    observer and reset its min/max values
+    """
+    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
+    if weight_eq_obs_node is None:
+        return
+
+    weight_quant_obs_node = weight_eq_obs_node.args[0]
+    if weight_quant_obs_node is None:
+        return
+    assert isinstance(weight_quant_obs_node, Node)
+
+    weight_quant_obs = modules[str(weight_quant_obs_node.target)]
+    assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    weight_quant_obs.reset_min_max_vals()   # type: ignore[operator]
+
+def remove_node(model: GraphModule, node: Node, prev_node: Node):
+    """ Removes the given node from the model by replacing all of its users with
+    the given previous node
+    """
+    # For all of the current node's users, replace the current node with
+    # the input quantization observer node
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        user_node.replace_input_with(node, prev_node)
+
+    # Erase the InputEqualizationObserver node
+    model.graph.erase_node(node)
+
+def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
+    """ Update all of the observer's equalization scale. For each
+    InputEqualizationObserver, we will find the location of the next
+    WeightEqualizationObserver, create it, and calculate the equalization scale
+    based on the two observers.
+
+    We will then return a dictionary mapping operation node names to
+    the corresponding WeightEqualizationObservers for that operation.
+    """
+    weight_eq_obs_dict = {}
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            input_eq_obs = modules[node.target]
+            assert isinstance(input_eq_obs, _InputEqualizationObserver)
+            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
+
+            if op_node is None or weight_eq_obs is None:
+                continue
+
+            if op_node.op == 'call_module':
+                # Calibrate the weight equalization observer since it has just
+                # been created
+                if fused_module_supports_equalization(modules[str(op_node.target)]):
+                    module = modules[str(op_node.target)][0]   # type: ignore[index]
+                    assert nn_module_supports_equalization(module)
+                    weight_eq_obs(module.weight)
+                else:
+                    weight_eq_obs(modules[str(op_node.target)].weight)
+
+            # Calculate and set the equalization scale values
+            equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
+            input_eq_obs.set_equalization_scale(equalization_scale)
+            weight_eq_obs.set_equalization_scale(equalization_scale)
+
+            weight_eq_obs_dict[op_node.name] = weight_eq_obs
+
+    return weight_eq_obs_dict
+
+def convert_eq_obs(
+    model: GraphModule,
+    modules: Dict[str, nn.Module],
+    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
+) -> None:
+    """ Converts the equalization operations and updates the other nodes in the
+    following way:
+        - Removes the input equalization observers and inserts a mul operator
+          along with an equalization scale node wherever applicable (we do not
+          want to insert a mul operator between connecting linear layers).
+        - Updates the input quantization observers with the scaled input min/max
+          values.
+        - Scales the weights by the current and next equalization scales.
+        - Removes the weight equalization observer node if it exists.
+
+    Before (after prepare):
+                                    weight values
+                                          |
+                                    WeightQuantObs
+                                          |
+                                      WeightEqObs
+                                          |
+        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
+
+    After this function:
+                                              scaled weight values
+                                                      |
+       equalization scale                       WeightQuantObs
+              |                                       |
+        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
+
+    After convert:
+       equalization scale                 scaled weight values
+              |                                    |
+        x -> mul -> quantize_per_tensor -> quantized::linear
+
+    Note that although the equalization observer appeared after the quantization
+    observer after prepare_fx, the mul node appears before the quantization node
+    after convert_fx. This is because placing the equalization observer after
+    the quantization observer in prepare_fx allows us to keep the invariant
+    that the graph preceding the current node is not modified when the current
+    node inserts its observers.
+
+    Having the equalization observer before the quantization observer would also
+    cause some inconsistencies between the ordering of the quantization and
+    equalization observers.
+    For example, a single linear layer would look like:
+        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
+    But between two connected linear layers, it would look like:
+        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
+    """
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            inp_quant_obs_node = node.args[0]
+            prev_node = inp_quant_obs_node.args[0]
+
+            # If the previous node is a layer that needs to be equalized, then
+            # we will remove the current node because we do not need to add any
+            # equalization nodes between two layers that need to be equalized
+
+            # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
+            # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
+            if node_supports_equalization(prev_node, modules) or "relu" in prev_node.name:
+                remove_node(model, node, inp_quant_obs_node)
+                continue
+
+            # Update the following input quantization observer's min/max values
+            scale_input_observer(node, modules)
+
+            # Remove the InputEqualization node and add a mul operator before
+            # the quantization observer node that appears before the equalization node
+            # Before: x -> input_quant_obs -> input_eq_obs -> linear
+            # After: x -> mul -> input_quant_obs -> linear
+
+            # Create a node containing the equalization scale
+            with model.graph.inserting_before(inp_quant_obs_node):
+                get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
+                name = get_new_eq_scale_name(modules)
+                setattr(model, name, modules[node.target].equalization_scale)
+                eq_scale_node = model.graph.create_node('get_attr', name)
+
+            # Create a node multiplying the input with the equalization scale
+            with model.graph.inserting_after(eq_scale_node):
+                inputs = (prev_node, eq_scale_node)
+                mul_node = model.graph.create_node("call_function", torch.mul, inputs)
+
+            # Set the mul node to be the inp_quant_obs_node's input instead of
+            # the previous node
+            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
+            remove_node(model, node, inp_quant_obs_node)
+
+        elif weight_eq_obs_dict.get(node.name, None) is not None:
+            weight_eq_obs = weight_eq_obs_dict.get(node.name)
+            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            equalization_scale = weight_eq_obs.equalization_scale
+
+            if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+                equalization_scale = None  # type: ignore[assignment]
+            maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
+
+            # Scale the weight nodes
+            if node.op == 'call_module':
+                scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
+            elif node.op == 'call_function':
+                scale_weight_functional(node, model, modules, equalization_scale, maybe_next_equalization_scale)
+
+                weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
+                if weight_eq_obs_node is None:
+                    return
+                assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)
+
+                # Clear the quantization observer's min/max values so that they
+                # can get updated later based on the new scale values
+                clear_weight_quant_obs_node(node, modules)
+
+                # Erase the weight equalization observer node
+                prev_node = weight_eq_obs_node.args[0]
+                remove_node(model, weight_eq_obs_node, prev_node)
+            else:
+                raise ValueError("Expected operation node to be 'call_module' or 'call_function" +
+                                 f"Instead got node {node.name} as '{node.op}'.")
+
+def _convert_equalization_ref(model: GraphModule):
+    """ Reference function which applies changes needed for equalization, but
+    does not quantize the nodes
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # Calculate the equalization scale, update the observers with the scaled
+    # inputs, and scale the weight
+    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+    convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    return GraphModule(model, model.graph)
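+
+# Rough end-to-end sketch (assumptions: the standard FX flow via
+# torch.ao.quantization.quantize_fx.prepare_fx and its private
+# `_equalization_config` argument, whose exact name and accepted format may
+# differ across releases; `float_model`, `qconfig_mapping`, `example_inputs`
+# and `calibration_batch` are placeholders):
+#
+#     from torch.ao.quantization.quantize_fx import prepare_fx
+#     prepared = prepare_fx(float_model, qconfig_mapping, example_inputs,
+#                           _equalization_config={"": default_equalization_qconfig})
+#     prepared(calibration_batch)                      # calibrate the observers
+#     equalized = _convert_equalization_ref(prepared)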
+
+
+###############################################################################
+# Functions for running the equalized model on the Numeric Suite              #
+###############################################################################
+
+def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor) -> Dict[str, float]:
+    """ Runs the Numeric Suite on model_a and model_b and returns a dictionary
+    containing the SQNR between layers in model_a and model_b.
+
+    Note: In order to support equalized models, this function has a hacky fix in
+    which we do not match any torch.mul operators. This is because equalized
+    models contain extra mul operators to scale the input by the equalization
+    scale, but this edge case has not been resolved yet within the numeric suite code.
+
+    Args:
+        model_a: A float model
+        model_b: A quantized model
+        x: Inputs to use during calibration
+    """
+    import torch.ao.ns._numeric_suite_fx as ns
+    from torch.ao.ns.fx.mappings import get_unmatchable_types_map
+
+    unmatchable_types_map = get_unmatchable_types_map()
+    unmatchable_types_map["funs_unmatchable"].add(torch.mul)
+
+    model_a_ns, model_b_ns = ns.add_loggers(
+        'fp32', model_a,
+        'int8', model_b,
+        ns.OutputLogger,
+        unmatchable_types_map=unmatchable_types_map
+    )
+
+    model_a_ns(x)
+    model_b_ns(x)
+
+    activation_comparison_dict = ns.extract_logger_info(
+        model_a_ns,
+        model_b_ns,
+        ns.OutputLogger,
+        'int8')
+    ns.extend_logger_results_with_comparison(
+        activation_comparison_dict,
+        'fp32', 'int8',
+        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
+    )
+
+    # Construct a dictionary mapping layer names to the SQNR values
+    layer_sqnr_dict = {}
+    for key in activation_comparison_dict:
+        layer = activation_comparison_dict[key]['node_output']['int8'][0]['fqn']
+        sqnr = activation_comparison_dict[key]['node_output']['int8'][0]['sqnr'][0]
+        layer_sqnr_dict[layer] = sqnr
+
+    return layer_sqnr_dict
+
+def get_equalization_qconfig_dict(
+    layer_sqnr_dict: Dict[str, float],
+    num_layers_to_equalize: int
+) -> Any:
+    """ Given the layer to SQNR dictionary, find the layers with the highest
+    quantization errors, and return an equalization_qconfig_dict
+    specifying to only equalize those top layers.
+
+    Args:
+        layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
+            when comparing an equalized model against a float model)
+        num_layers_to_equalize: Number of layers with the highest quantization
+           errors to equalize
+    """
+
+    # Sort the layer_sqnr_dictionary values and get the layers with the lowest
+    # SQNR values (aka highest quantization errors)
+    layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=lambda item: item[1])
+    layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]
+
+    # Constructs an equalization_qconfig_dict that specifies to only equalize
+    # the layers with the highest quantization errors
+    module_to_qconfig_list = [(item[0], default_equalization_qconfig) for item in layers_to_equalize]
+    equalization_qconfig_dict = {"module_name": module_to_qconfig_list}
+    return equalization_qconfig_dict
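+
+# Usage sketch (illustrative; `float_model`, `quantized_model` and `x` are
+# placeholders for models/data prepared elsewhere): equalize only the two
+# layers with the worst SQNR, then re-run quantization with the returned
+# config.
+#
+#     layer_sqnr_dict = get_layer_sqnr_dict(float_model, quantized_model, x)
+#     eq_qconfig_dict = get_equalization_qconfig_dict(layer_sqnr_dict, num_layers_to_equalize=2)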
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0144595cd1c4f392d99991e47a4af7082bcd83b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -0,0 +1,1170 @@
+import torch
+from torch.fx import map_arg, Node
+from torch.fx.graph import Graph
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.quantized.reference as nnqr
+from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
+from torch.fx import GraphModule
+from .utils import (
+    collect_producer_nodes,
+    get_linear_prepack_op_for_dtype,
+    get_new_attr_name_with_prefix,
+    get_qconv_prepack_op,
+    graph_module_from_producer_nodes,
+)
+from ..utils import _parent_name
+from ..qconfig import QConfigAny
+from ..quantization_mappings import get_quantized_operator
+from .utils import create_node_from_old_node_preserve_meta
+from typing import Dict, Tuple, Type, List, Callable, Any, Union, Set, Optional
+import operator
+
+QOP_TO_ARG_NAMES_TO_SKIP = {
+    torch._ops.ops.quantized.hardswish: ['inplace'],
+    torch._ops.ops.quantized.elu: ['inplace'],
+    torch._ops.ops.quantized.dropout: ['inplace'],
+    torch._ops.ops.quantized.instance_norm:
+    ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
+}
+
+def _is_node_in_list(node, modules, func_list, method_list, module_type_list):
+    is_call_function = node.op == "call_function" and node.target in func_list
+    is_call_method = node.op == "call_method" and node.target in method_list
+    is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
+    return is_call_function, is_call_method, is_call_module
+
+def is_fixed_qparams_node(node, modules):
+    func_list = [
+        torch.nn.functional.hardsigmoid,
+        torch.nn.functional.sigmoid,
+        torch.sigmoid,
+        torch.tanh,
+    ]
+    method_list = [
+        "hardsigmoid",
+        "hardsigmoid_",
+        "sigmoid",
+        "sigmoid_",
+        "tanh",
+        "tanh_",
+    ]
+    module_type_list = [
+        torch.nn.Hardsigmoid,
+        torch.nn.Sigmoid,
+        torch.nn.Tanh,
+        torch.nn.Softmax,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+def is_default_node(node, modules):
+    func_list = [
+        torch.nn.functional.elu,
+        torch.nn.functional.hardswish,
+        torch.nn.functional.instance_norm,
+        torch.nn.functional.layer_norm,
+        torch.nn.functional.leaky_relu,
+        torch.nn.functional.dropout,
+    ]
+    method_list: List[Any] = []
+    module_type_list = [
+        nnqr.ConvTranspose1d,
+        nnqr.ConvTranspose2d,
+        nnqr.ConvTranspose3d,
+        torch.nn.ELU,
+        torch.nn.LeakyReLU,
+        torch.nn.Hardswish,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.Dropout,
+        torch.nn.PReLU,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.ao.nn.intrinsic.BNReLU2d,
+        torch.ao.nn.intrinsic.BNReLU3d,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+def is_copy_node(node, modules):
+    func_list = [
+        torch.adaptive_avg_pool1d,
+        torch.nn.functional.adaptive_avg_pool2d,
+        torch.nn.functional.adaptive_avg_pool3d,
+        torch.nn.functional.hardtanh,
+        torch.nn.functional.hardtanh_,
+        torch.nn.functional.interpolate,
+        torch.nn.functional.max_pool1d,
+        torch.nn.functional.max_pool2d,
+        torch.nn.functional.max_pool3d,
+        torch.nn.functional.relu,
+        torch.nn.functional.relu6,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        operator.floordiv,
+        # F.channel_shuffle and torch.channel_shuffle are essentially the same thing
+        # so we only need to put one of them here
+        torch.channel_shuffle,
+    ]
+    method_list = [
+        "clamp",
+        "mean",
+        "relu",
+        "relu_",
+    ]
+    module_type_list = [
+        torch.nn.AdaptiveAvgPool1d,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.AdaptiveAvgPool3d,
+        torch.nn.AvgPool1d,
+        torch.nn.AvgPool2d,
+        torch.nn.AvgPool3d,
+        torch.nn.Hardtanh,
+        torch.nn.MaxPool1d,
+        torch.nn.MaxPool2d,
+        torch.nn.MaxPool3d,
+        torch.nn.ReLU,
+        torch.nn.ReLU6,
+        torch.nn.ChannelShuffle,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+def is_general_tensor_shape_node(node, modules):
+    func_list = [
+        torch.narrow,
+        torch.transpose,
+        torch.repeat_interleave,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        torch.nn.functional.pixel_shuffle,
+        torch.nn.functional.pixel_unshuffle,
+    ]
+    method_list = [
+        "contiguous",
+        "detach",
+        "detach_",
+        "permute",
+        "repeat",
+        "repeat_interleave",
+        "reshape",
+        "resize_",
+        "shape",
+        "size",
+        "squeeze",
+        "squeeze_",
+        "transpose",
+        "unsqueeze",
+        "unsqueeze_",
+        "view",
+    ]
+    module_type_list = [
+        torch.nn.Identity,
+        torch.nn.PixelShuffle,
+        torch.nn.PixelUnshuffle,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+def is_other_node(node, modules):
+    func_list = [
+        torch.cat,
+    ]
+    method_list: List[Any] = []
+    module_type_list: List[Any] = []
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+def is_special_pattern_node(node, modules):
+    res_function, res_method, res_module = False, False, False
+    for checker in [is_fixed_qparams_node, is_default_node, is_copy_node, is_general_tensor_shape_node, is_other_node]:
+        is_call_function, is_call_method, is_call_module = checker(node, modules)
+        res_function = res_function or is_call_function
+        res_method = res_method or is_call_method
+        res_module = res_module or is_call_module
+    return res_function, res_method, res_module
+
+def is_dequantize_node(node):
+    return isinstance(node, Node) and node.op == "call_method" and node.target == "dequantize"
+
+def is_getattr_tensor_metadata_node(node):
+    return node.op == "call_function" and \
+        node.target == getattr and \
+        node.args[1] in ["shape"]
+
+def is_get_tensor_info_node(node):
+    return node.op == "call_method" and \
+        node.target in ["shape", "size"]
+
+def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]):
+    """
+    Return True if the op is configured with a None qconfig, False otherwise.
+    Note: we may need to generalize this to also check the dtype and only
+    lower when the dtype matches, but right now fbgemm/qnnpack only support
+    a single dtype, so it is OK for now.
+    """
+    return op.name in qconfig_map and qconfig_map[op.name] is None
+
+# Mapping from reference module class to the replacement static quantized module class for lowering
+STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = {
+    nnqr.Linear: nnq.Linear,
+    nnqr.Conv1d: nnq.Conv1d,
+    nnqr.Conv2d: nnq.Conv2d,
+    nnqr.Conv3d: nnq.Conv3d,
+}
+
+# Mapping from reference module class to the replacement dynamic quantized module class for lowering
+DYNAMIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
+    nnqr.Linear: nnqd.Linear,
+    nnqr.GRUCell: nnqd.GRUCell,
+    nnqr.LSTMCell: nnqd.LSTMCell,
+    nnqr.RNNCell: nnqd.RNNCell,
+    nnqr.LSTM: nnqd.LSTM,
+    nnqr.GRU: nnqd.GRU,
+}
+
+# Mapping from reference module class to the replacement weight only quantized module class for lowering
+# TODO: correct the namespace for these modules
+WEIGHT_ONLY_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
+    nnqr.Embedding: nnq.Embedding,
+    nnqr.EmbeddingBag: nnq.EmbeddingBag,
+}
+
+# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge
+# _lower_static_weighted_ref_module and special_pattern_replacement
+SPECIAL_PATTERN_LOWER_MODULE_MAP = {
+    nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
+    nnqr.ConvTranspose1d: nnq.ConvTranspose1d,
+    nnqr.ConvTranspose2d: nnq.ConvTranspose2d,
+    nnqr.ConvTranspose3d: nnq.ConvTranspose3d,
+    nn.ELU: nnq.ELU,
+    nn.LeakyReLU: nnq.LeakyReLU,
+    nn.Hardswish: nnq.Hardswish,
+    nn.InstanceNorm1d: nnq.InstanceNorm1d,
+    nn.InstanceNorm2d: nnq.InstanceNorm2d,
+    nn.InstanceNorm3d: nnq.InstanceNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
+    nn.Dropout: nnq.Dropout,
+    nn.Softmax: nnq.Softmax,
+    nn.PReLU: nnq.PReLU,
+    nni.BNReLU2d: nniq.BNReLU2d,
+    nni.BNReLU3d: nniq.BNReLU3d,
+}
+
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement static quantized module class for lowering
+STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
+    nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
+    # TODO: LinearLeakyReLU is registered as global but it is only fused and
+    # lowered when onednn's backend config is used. We may need to separate
+    # registration and lowering functions for different backends in the future.
+    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
+    nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh),
+    nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
+    nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
+    nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
+}
+
+# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP:
+# The reference node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs.
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement static quantized module class for lowering
+STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
+    nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d),
+    nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d),
+}
+
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement dynamic quantized module class for lowering
+DYNAMIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[nn.Module]]] = {
+    nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU),
+}
+
+# Mapping from a functional to lower to a 2-tuple of
+#   1) The quantized version of the op
+#   2) The quantized version of the op fused with relu, if it exists, else None
+STATIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Optional[Callable]]] = {
+    F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu),
+    F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu),
+    F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu),
+    F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu),
+    F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None),
+    F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None),
+    F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None),
+}
+
+WEIGHT_PREPACK_OPS: Set[Callable] = {
+    torch._ops.ops.quantized.linear_prepack,
+    torch._ops.ops.quantized.linear_prepack_fp16,
+    torch._ops.ops.quantized.conv1d_prepack,
+    torch._ops.ops.quantized.conv2d_prepack,
+    torch._ops.ops.quantized.conv3d_prepack,
+    torch.ops.quantized.conv_transpose1d_prepack,
+    torch.ops.quantized.conv_transpose2d_prepack,
+    torch.ops.quantized.conv_transpose3d_prepack,
+}
+
+# Mapping from a functional to a dictionary, where the key is a 2-tuple of
+# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of
+#   1) The dynamically quantized version of the op
+#   2) The dynamically quantized version of the op fused with relu, if it exists, else None
+DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]] = {
+    F.linear: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.linear_dynamic,
+                                      torch.ops.quantized.linear_relu_dynamic),
+        (torch.float16, torch.float16): (torch.ops.quantized.linear_dynamic_fp16,
+                                         torch.ops.quantized.linear_relu_dynamic_fp16)
+    },
+    # dynamic conv + relu is not available yet
+    F.conv1d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None),
+    },
+    F.conv2d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None),
+    },
+    F.conv3d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None),
+    },
+}
+
+CONV_FUNCTIONAL_OPS: Set[Callable] = {
+    F.conv1d,
+    F.conv2d,
+    F.conv3d,
+}
+
+CONV_TRANSPOSE_FUNCTIONAL_OPS: Set[Callable] = {
+    F.conv_transpose1d,
+    F.conv_transpose2d,
+    F.conv_transpose3d,
+}
+
+# TODO: add tests for lowering these ops
+QBIN_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
+    operator.add: torch.ops.quantized.add,
+    torch.add: torch.ops.quantized.add,
+    operator.mul: torch.ops.quantized.mul,
+    operator.matmul: torch.ops.quantized.matmul,
+    torch.mul: torch.ops.quantized.mul,
+    torch.matmul: torch.ops.quantized.matmul,
+}
+QBIN_RELU_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
+    operator.add: torch.ops.quantized.add_relu,
+    torch.add: torch.ops.quantized.add_relu,
+    operator.mul: torch.ops.quantized.mul_relu,
+    torch.mul: torch.ops.quantized.mul_relu,
+}
+
+def _save_packed_weight(self, destination, prefix, keep_vars):
+    for attr_name in dir(self):
+        if "_packed_weight" in attr_name and \
+           isinstance(getattr(self, attr_name), torch._C.ScriptObject):  # type: ignore[attr-defined]
+            packed_weight = getattr(self, attr_name)
+            destination[prefix + attr_name] = packed_weight
+
+def _load_packed_weight(self, state_dict, prefix, local_metadata, strict,
+                        missing_keys, unexpected_keys, error_msgs):
+    attrs_to_pop = []
+    for attr_name in state_dict:
+        if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
+            setattr(self, attr_name, state_dict[attr_name])
+            attrs_to_pop.append(attr_name)
+
+    # pop the packed param attributes
+    for attr_name in attrs_to_pop:
+        state_dict.pop(attr_name)
+
+def fold_weight(
+    quantized_model: GraphModule,
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+) -> GraphModule:
+    """
+    Trace back from the weight node until we hit getattr, reconstruct the
+    graph module with the traced nodes and run the graph module to pack the
+    weight. Then replace the original chain of ops with the packed weight.
+    """
+    packed_weights = {}
+    # map from folded node name to the prepacked weight name
+    folded_nodes = {}
+    # get packed weights
+    for node in quantized_model.graph.nodes:
+        if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
+            nodes_to_fold = collect_producer_nodes(node)
+            if nodes_to_fold is not None:
+                for node_to_fold in nodes_to_fold:
+                    folded_nodes[node_to_fold.name] = node
+
+                prepacking_module = graph_module_from_producer_nodes(
+                    quantized_model, nodes_to_fold)
+                packed_weight = prepacking_module()
+                packed_weights[node.name] = packed_weight
+
+    # remove folded nodes and replace the prepacking node with getattr
+    folded_graph = Graph()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    for node in quantized_model.graph.nodes:
+        prepack_node = folded_nodes.get(node.name, None)
+        if prepack_node is node:
+            packed_weight = packed_weights[node.name]
+            # add a prepacked attribute to root
+            op_node = next(iter(prepack_node.users))
+            module_path, _ = node_name_to_scope[op_node.name]
+            get_new_packed_weight_name = \
+                get_new_attr_name_with_prefix(module_path + '_packed_weight_')
+            packed_weight_name = get_new_packed_weight_name(quantized_model)
+            setattr(quantized_model, packed_weight_name, packed_weight)
+            # replace prepack node with a getattr node
+            env[node.name] = folded_graph.create_node(
+                'get_attr', packed_weight_name, (), {})
+        elif prepack_node is not None:
+            # remove the folded node
+            continue
+        else:
+            # copy other nodes
+            env[node.name] = folded_graph.node_copy(node, load_arg)
+
+    quantized_model = GraphModule(quantized_model, folded_graph)
+    quantized_model._register_state_dict_hook(_save_packed_weight)
+    quantized_model._register_load_state_dict_pre_hook(_load_packed_weight, with_module=True)
+    return quantized_model
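+
+# Illustrative effect of fold_weight (the exact producer chain and attribute
+# name shown here are assumptions): a weight-prepacking chain such as
+#
+#     get_attr(weight) -> quantize_per_tensor -> quantized.linear_prepack -> quantized.linear
+#
+# is folded into a single attribute lookup of the precomputed packed weight:
+#
+#     get_attr(<module_path>_packed_weight_0) -> quantized.linear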
+
+def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
+    """
+    Return the `torch.nn.Module` that corresponds to the specified node's target.
+    If no such module exists, return None.
+    """
+    if node.op == "call_module" and str(node.target) in modules:
+        return modules[str(node.target)]
+    else:
+        return None
+
+def _match_static_pattern(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    qconfig_map: Dict[str, QConfigAny],
+    matching_modules_or_ops: List[Callable],
+    dequantize_node_arg_indices: List[int]
+) -> Union[Tuple[Node, Node, Node], Tuple[None, None, None]]:
+    """
+    Match the pattern (dequantize - ref node - quantize) against the node provided.
+
+    If there is a match, return a 3-tuple of:
+      1) q_node: the quantize node,
+      2) relu_node: a relu node wrapping the ref_node, and
+      3) ref_node: a reference module or functional node to replace with its quantized counterpart
+    Otherwise, if there is no match, return a 3-tuple of (None, None, None).
+
+    Parameters:
+      node: The `torch.fx.Node` to match against.
+      modules: A mapping from node names to modules in the model graph, used for module lookup.
+      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
+          If the corresponding qconfig for the reference node is None, then return no match.
+      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
+          If the reference node is not in this list, then return no match.
+      dequantize_node_arg_indices: A list of indices in the reference node args where dequantize
+          nodes may be present. An empty list means skipping the check for dequantize nodes.
+    """
+    SKIP_LOWERING_VALUE = (None, None, None)
+
+    # Match quantize node
+    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
+        return SKIP_LOWERING_VALUE
+    q_node = node
+    ref_node = q_node.args[0]
+    assert isinstance(ref_node, Node)
+
+    # Handle cases where the node is wrapped in a ReLU
+    if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\
+            (ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU):
+        relu_node = ref_node
+        ref_node = relu_node.args[0]
+        assert isinstance(ref_node, Node)
+    else:
+        relu_node = None
+    if should_skip_lowering(ref_node, qconfig_map):
+        return SKIP_LOWERING_VALUE
+
+    # Match reference module or functional
+    if isinstance(matching_modules_or_ops[0], type) and issubclass(matching_modules_or_ops[0], nn.Module):
+        expected_op = "call_module"
+        match_key = type(_get_module(ref_node, modules))
+    else:
+        expected_op = "call_function"
+        match_key = ref_node.target
+    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
+        return SKIP_LOWERING_VALUE
+
+    # Match dequantize node(s). Both of the following conditions must pass:
+    # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node
+    # (2) There must be at least one dequantize node
+    matched_dequantize = False
+    for i in dequantize_node_arg_indices:
+        assert i < len(ref_node.args), \
+            f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
+        arg = ref_node.args[i]
+        if is_dequantize_node(arg):
+            matched_dequantize = True
+        elif isinstance(arg, Node):
+            return SKIP_LOWERING_VALUE
+    if not matched_dequantize:
+        return SKIP_LOWERING_VALUE
+
+    return (q_node, relu_node, ref_node)
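+
+# Example of a matching subgraph (illustrative): for a reference linear module,
+# the matcher returns (q_node, relu_node, ref_node) for a chain like
+#
+#     dequantize -> nnqr.Linear (ref_node) -> relu (optional) -> quantize_per_tensor (q_node)
+#
+# and (None, None, None) when any part of the pattern is missing or the
+# reference node's qconfig is None.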
+
+def _match_static_pattern_with_two_inputs(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    qconfig_map: Dict[str, QConfigAny],
+    matching_modules_or_ops: List[Callable]
+) -> Union[Tuple[Node, Node], Tuple[None, None]]:
+    """
+    Match the pattern below, where the reference node has two dequantize
+    inputs, against the node provided:
+
+        dequantize --+
+                     +-- ref node -- quantize
+        dequantize --+
+
+    If there is a match, return a 2-tuple of:
+      1) q_node: the quantize node,
+      2) ref_node: a reference module or functional node to replace with its quantized counterpart
+    Otherwise, if there is no match, return a 2-tuple of (None, None).
+
+    Parameters:
+      node: The `torch.fx.Node` to match against.
+      modules: A mapping from node names to modules in the model graph, used for module lookup.
+      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
+          If the corresponding qconfig for the reference node is None, then return no match.
+      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
+          If the reference node is not in this list, then return no match.
+    """
+    SKIP_LOWERING_VALUE = (None, None)
+
+    # Match quantize node
+    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
+        return SKIP_LOWERING_VALUE
+    q_node = node
+    ref_node = q_node.args[0]
+    assert isinstance(ref_node, Node)
+
+    if should_skip_lowering(ref_node, qconfig_map):
+        return SKIP_LOWERING_VALUE
+
+    # Match reference module or functional
+    if isinstance(matching_modules_or_ops[0], type) and issubclass(matching_modules_or_ops[0], nn.Module):
+        expected_op = "call_module"
+        match_key = type(_get_module(ref_node, modules))
+    else:
+        # This pass only supports ops of type "call_module"
+        return SKIP_LOWERING_VALUE
+
+    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
+        return SKIP_LOWERING_VALUE
+
+    # Check that ref_node has 2 input nodes and that both are dequantize nodes.
+    if len(ref_node.args) != 2:
+        return SKIP_LOWERING_VALUE
+    for i in range(len(ref_node.args)):
+        arg = ref_node.args[i]
+        if not is_dequantize_node(arg):
+            return SKIP_LOWERING_VALUE
+
+    return (q_node, ref_node)
+
+def _lower_static_weighted_ref_module(
+        model: GraphModule,
+        qconfig_map: Dict[str, QConfigAny]):
+    """
+    Traverse the graph and find dequantize - ref module - quantize patterns
+    and replace them with the quantized version of the ref module.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    nodes = list(model.graph.nodes)
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
+        matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list(STATIC_LOWER_FUSED_MODULE_MAP.keys())
+        (q_node, relu_node, ref_node) = _match_static_pattern(
+            n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0])  # type: ignore[arg-type]
+        if q_node is None:
+            continue
+        assert ref_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+        ref_module = _get_module(ref_node, modules)
+        ref_class = type(ref_module)
+        assert isinstance(scale_node, Node)
+        assert isinstance(zero_point_node, Node)
+        assert issubclass(ref_class, nn.Module)
+
+        # Step 1: Change this pattern to use the corresponding quantized module
+        # For fused modules, we also check whether the inner module is a reference module
+        # If so, we replace the entire fused module with the corresponding quantized module
+        if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
+            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
+            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+                continue
+        else:
+            q_class = STATIC_LOWER_MODULE_MAP[ref_class]
+        output_scale = getattr(model, scale_node.target)
+        output_zero_point = getattr(model, zero_point_node.target)
+        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
+        # replace reference module with quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(modules[parent_name], module_name, q_module)
+
+        # Step 2: Reroute around dq_node, and remove q_node and its args
+        assert len(ref_node.args) == 1
+        dq_node = ref_node.args[0]
+        assert isinstance(dq_node, Node)
+        ref_node.replace_input_with(dq_node, dq_node.args[0])
+        q_node.replace_all_uses_with(ref_node)
+        model.graph.erase_node(q_node)
+        model.graph.erase_node(scale_node)
+        model.graph.erase_node(zero_point_node)
+
+def _lower_static_weighted_ref_module_with_two_inputs(
+        model: GraphModule,
+        qconfig_map: Dict[str, QConfigAny]):
+    """
+    Traverse the graph and find patterns
+    dequantize   dequantize
+       \\         //
+        ref module
+            \\
+          quantize
+    and replace them with the quantized version of the ref module.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    nodes = list(model.graph.nodes)
+    for n in model.graph.nodes:
+        #                                            (dequantize \
+        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
+        matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys())
+        (q_node, ref_node) = _match_static_pattern_with_two_inputs(
+            n, modules, qconfig_map, matching_modules)  # type: ignore[arg-type]
+        if q_node is None:
+            continue
+        assert ref_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+        ref_module = _get_module(ref_node, modules)
+        ref_class = type(ref_module)
+        assert isinstance(scale_node, Node)
+        assert isinstance(zero_point_node, Node)
+        assert issubclass(ref_class, nn.Module)
+
+        # Step 1: Change this pattern to use the corresponding quantized module
+        # For fused modules, we also check whether the inner module is a reference module
+        # If so, we replace the entire fused module with the corresponding quantized module
+        if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP:
+            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ref_class]
+            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+                continue
+        else:
+            continue
+        output_scale = getattr(model, scale_node.target)
+        output_zero_point = getattr(model, zero_point_node.target)
+        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
+        # replace reference module with quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(modules[parent_name], module_name, q_module)
+
+        # Step 2: Reroute around dq_node, and remove q_node and its args
+        assert len(ref_node.args) == 2
+        for arg in ref_node.args:
+            if not is_dequantize_node(arg):
+                continue
+            dq_node = arg
+            assert isinstance(dq_node, Node)
+            ref_node.replace_input_with(dq_node, dq_node.args[0])
+
+        q_node.replace_all_uses_with(ref_node)
+        model.graph.erase_node(q_node)
+        model.graph.erase_node(scale_node)
+        model.graph.erase_node(zero_point_node)
+
+def _lower_dynamic_weighted_ref_module(model: GraphModule):
+    """
+    Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns
+    and replace them with the dynamically quantized version of the ref module.
+    """
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        if n.op != "call_module" or \
+           type(named_modules[str(n.target)]) not in \
+           set(DYNAMIC_LOWER_MODULE_MAP.keys()).union(
+               set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())):
+            continue
+        ref_node = n
+        dq_node = ref_node.args[0]
+        if dq_node.op != "call_method" or dq_node.target != "dequantize":
+            continue
+
+        input_dynamic_q_node = dq_node.args[0]
+
+        if input_dynamic_q_node.op != "call_function" or \
+           input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic:
+            continue
+
+        activation_dtype = input_dynamic_q_node.args[1]
+        is_fp16 = activation_dtype == torch.float16
+        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
+        if not is_int8 and not is_fp16:
+            continue
+
+        ref_module = named_modules[str(ref_node.target)]
+        ref_class = type(ref_module)
+        if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
+            inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
+            if type(ref_module[0]) != inner_ref_class:
+                continue
+        else:
+            q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class)  # type: ignore[assignment]
+        # TODO: maybe define a WeightedDynamicallyQuantizedModule
+        q_module = q_class.from_reference(ref_module)  # type: ignore[attr-defined]
+
+        # replace reference module with dynamically quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(named_modules[parent_name], module_name, q_module)
+        ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0])
+
+def _lower_weight_only_weighted_ref_module(model: GraphModule):
+    """
+    Traverse the graph and find ref_module patterns
+    and replace them with the weight only quantized version of the ref module.
+    """
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        if n.op != "call_module" or \
+           type(named_modules[str(n.target)]) not in \
+           set(WEIGHT_ONLY_LOWER_MODULE_MAP.keys()):
+            continue
+        ref_node = n
+        ref_module = named_modules[str(ref_node.target)]
+        ref_class = type(ref_module)
+        q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class)
+        # TODO: WeightedQuantizedModule is currently assuming static quant apis
+        # with output_scale, output_zero_point in from_reference, we may want to
+        # relax that, or rename this
+        # TODO: maybe define a WeightedWeightOnlyQuantizedModule
+        q_module = q_class.from_reference(ref_module)  # type: ignore[union-attr]
+
+        # replace reference module with weight-only quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(named_modules[parent_name], module_name, q_module)
+
+def _lower_static_weighted_ref_functional(
+        model: GraphModule,
+        qconfig_map: Dict[str, QConfigAny]):
+    """
+    Traverse the graph and replace functional reference patterns with their quantized versions.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    nodes = list(model.graph.nodes)
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize)
+        matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys())
+        (q_node, relu_node, func_node) = _match_static_pattern(
+            n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1])
+        if q_node is None:
+            continue
+        assert func_node is not None
+        (_, output_scale_node, output_zp_node, _) = q_node.args
+        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
+        assert isinstance(output_zp_node, Node)
+        assert isinstance(input_dq_node, Node)
+        assert isinstance(weight_dq_node, Node)
+        quantized_weight = weight_dq_node.args[0]
+        assert isinstance(quantized_weight, Node)
+        if quantized_weight.op != "call_function" or\
+                quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel):
+            continue
+
+        # Step 1: Replace quantized weights with packed weights, which will be folded later
+        # Use the right prepack op and prepare the corresponding args
+        # Linear prepack args: (quantized weights[, bias])
+        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
+        prepack_args = [quantized_weight] + remaining_func_args
+        if func_node.target == F.linear:
+            weight_dtype = quantized_weight.args[-1]
+            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
+        elif func_node.target in CONV_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
+            # For conv1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv1d:
+                for i in [2, 3, 4]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+        elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
+            # For conv_transpose1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv_transpose1d:
+                # Note prepack_args[5] is groups.
+                for i in [2, 3, 4, 6]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+            # swap dilation and groups
+            # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups}
+            # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation}
+            if (len(prepack_args) > 6):
+                prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
+        else:
+            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
+        with model.graph.inserting_before(output_scale_node):
+            # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
+            # They are not needed for compute op (i.e., quantized::linear)
+            kwargs = func_node.kwargs
+            # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias
+            if func_node.target == F.linear and 'bias' in kwargs:
+                kwargs = kwargs.copy()
+                kwargs['B'] = kwargs['bias']
+                del kwargs['bias']
+            packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), kwargs)
+
+        # Step 2: Replace reference pattern with the corresponding quantized op
+        (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target]  # type: ignore[index]
+        # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases
+        if q_relu_func is not None:
+            func_node.target = q_relu_func if relu_node is not None else q_func
+        else:
+            func_node.target = q_func
+        func_node.args = (input_dq_node.args[0], packed_weight, output_scale_node, output_zp_node)
+        # kwargs for func_node has been moved to kwargs for prepack op
+        func_node.kwargs = {}
+        q_node.replace_all_uses_with(func_node)
+        # Move func_node after output_zp_node in the graph
+        output_zp_node.append(func_node)
+
+        # Clean up: Remove quantize node, and the relu node if it exists
+        model.graph.erase_node(q_node)
+        if relu_node is not None and q_relu_func is not None:
+            model.graph.erase_node(relu_node)
+
+def _lower_dynamic_weighted_ref_functional(
+        model: GraphModule,
+        qconfig_map: Dict[str, QConfigAny]):
+    """
+    Traverse the graph and replace functional reference patterns with their dynamically
+    quantized versions.
+    Examples:
+    quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic
+    to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    nodes = list(model.graph.nodes)
+    # we want to search in reversed order so that we can match the larger patterns first
+    # e.g. we want to match linear - relu before linear.
+    for n in reversed(model.graph.nodes):
+
+        # Step 0: Find nodes that match this pattern
+        # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op)
+        # We search for the pattern backwards, starting with the quantize node
+        # Quantize node args: (func, scale, zp, dtype)
+        func_node = n
+        # Handle cases where the functional op is wrapped in a ReLU
+        if func_node.op == "call_function" and func_node.target == F.relu or \
+           func_node.op == "call_module" and \
+           type(modules[str(func_node.target)]) == torch.nn.ReLU:
+            relu_node = func_node
+            func_node = relu_node.args[0]
+        else:
+            relu_node = None
+        if should_skip_lowering(func_node, qconfig_map):
+            continue
+        # Linear args: (dequantized inputs, dequantized weights[, bias])
+        # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
+        if func_node.op != "call_function" or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP:
+            continue
+        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
+        if input_dq_node.op != "call_method" or input_dq_node.target != "dequantize" or \
+           weight_dq_node.op != "call_method" or weight_dq_node.target != "dequantize":
+            continue
+
+        input_dynamic_q_node = input_dq_node.args[0]
+
+        if input_dynamic_q_node.op != "call_function" or \
+           input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic:
+            continue
+
+        reduce_range_node = None
+        (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args
+        is_fp16 = activation_dtype == torch.float16
+        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
+        if not is_int8 and not is_fp16:
+            continue
+
+        quantized_weight = weight_dq_node.args[0]
+        weight_dtype = quantized_weight.args[-1]
+
+        # Step 1: Try to select reference pattern with the corresponding quantized op
+        dynamic_quant_dtype_key = (activation_dtype, weight_dtype)
+        if dynamic_quant_dtype_key not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]:
+            print(f"Didn't find dtype combination {dynamic_quant_dtype_key} during "
+                  f"dynamic quantized op lowering for {func_node.target}")
+            continue
+        (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][dynamic_quant_dtype_key]
+
+        if q_func is None or q_relu_func is None:
+            print("Didn't find corresponding quantized function or quantized relu function "
+                  f"for {func_node.target}, {dynamic_quant_dtype_key}")
+            continue
+
+        # Step 2: Replace quantized weights with packed weights, which will be folded later
+        # Use the right prepack op and prepare the corresponding args
+        # Linear prepack args: (quantized weights[, bias])
+        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
+        prepack_args = [quantized_weight] + remaining_func_args
+        if func_node.target == F.linear:
+            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
+        elif func_node.target in CONV_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)
+            # For conv1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv1d:
+                for i in [2, 3, 4]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+        else:
+            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
+        with model.graph.inserting_before(func_node):
+            packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {})
+
+        # Step 3: Replace reference pattern with the corresponding quantized op
+        func_node.target = q_relu_func if relu_node is not None else q_func
+        if is_int8:
+            func_node.args = (pattern_input, packed_weight, reduce_range_node)
+        else:
+            func_node.args = (pattern_input, packed_weight)
+
+        if relu_node is not None:
+            relu_node.replace_all_uses_with(func_node)
+
+        # Step 4: Remove the relu node if it exists
+        if relu_node is not None:
+            model.graph.erase_node(relu_node)
+
+def _lower_quantized_binary_op(
+        model: GraphModule,
+        qconfig_map: Dict[str, QConfigAny]):
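+    # Lower patterns like quantize_per_tensor(dequantize(a) + dequantize(b), scale, zp)
+    # (and the mul / matmul / relu variants) to the corresponding quantized binary op.
+    # Sketch (illustrative): for a Tensor-Tensor add the lowered op receives
+    # (a, b, scale, zero_point); if only one input was dequantized, scale/zero_point
+    # are not appended.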
+    binary_ops_to_lower: List[Callable] = [operator.add, torch.add, operator.mul, torch.mul, torch.matmul]
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
+        (q_node, relu_node, bop_node) = _match_static_pattern(
+            n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1])
+        if q_node is None:
+            continue
+        assert bop_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+
+        # Step 1: Remove dequant nodes
+        num_dq_nodes = 0
+        for arg in bop_node.args:
+            if not is_dequantize_node(arg):
+                continue
+            dq_node = arg
+            assert isinstance(dq_node, Node)
+            dn_input = dq_node.args[0]
+            bop_node.replace_input_with(dq_node, dn_input)
+            num_dq_nodes += 1
+        assert num_dq_nodes > 0
+
+        # Step 2: Swap binary op to quantized binary op
+        assert bop_node.target in QBIN_OP_MAPPING
+        binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
+        qbin_op = binop_to_qbinop[bop_node.target]
+        # prepare the args for quantized binary op
+        # (x, y)
+        qop_node_args = list(bop_node.args)
+        # (x, y, scale, zero_point)
+        # add scale and zero_point arguments for Tensor - Tensor operation
+        if num_dq_nodes == 2:
+            qop_node_args.extend([scale_node, zero_point_node])
+        # insert a call to quantized binary op and remove the original binary op
+        with model.graph.inserting_after(q_node):
+            qop_node = create_node_from_old_node_preserve_meta(
+                model.graph,
+                ("call_function", qbin_op, tuple(qop_node_args), {}),
+                bop_node)
+            q_node.replace_all_uses_with(qop_node)
+
+        # Step 3: Remove quantize node, binary op node, and relu node if any
+        model.graph.erase_node(q_node)
+        if relu_node is not None:
+            model.graph.erase_node(relu_node)
+        model.graph.erase_node(bop_node)
+
+def special_pattern_replacement(model: GraphModule):
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        q_node = n
+        is_quantize = q_node.target == torch.quantize_per_tensor
+        is_to_fp16 = q_node.op == "call_method" and q_node.target == "to" and \
+            len(q_node.args) == 2 and q_node.args[1] == torch.float16
+        if not (is_quantize or is_to_fp16):
+            continue
+        ref_node = q_node.args[0]
+        # get output scale/zero_point/dtype from the quantize node
+        # ref_node, scale_node, zero_point_node, dtype = q_node.args
+        # TODO: add safety checks that users for the ref_node and dq_node needs to be one
+        is_call_function, is_call_method, is_call_module = is_fixed_qparams_node(ref_node, modules)
+        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
+            # TODO: add a warning or error out here? (bc-breaking if error out)
+            # warnings.warn(
+            #     "Only reference patterns are currently supported for {dtype} dtype with {op} op"
+            #     "".format(dtype=dtypes, op=ref_node))
+            continue
+
+        is_call_function, is_call_method, is_call_module = is_default_node(ref_node, modules)
+        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
+            # TODO: add a warning or error out here? (bc-breaking if error out)
+            continue
+
+        # This check includes all supported ops
+        is_call_function, is_call_method, is_call_module = is_special_pattern_node(ref_node, modules)
+        if not (is_call_module or is_call_function or is_call_method):
+            continue
+        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
+        dq_node_or_nodes = ref_node.args[0] if len(ref_node.args) > 0 else next(iter(ref_node.kwargs.values()))
+        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
+        is_dequantize = False
+        if isinstance(dq_node_or_nodes, Node):
+            is_dequantize = dq_node_or_nodes.op == 'call_method' and \
+                dq_node_or_nodes.target == 'dequantize'
+        elif isinstance(dq_node_or_nodes, (tuple, list)):
+            is_dequantize = all(
+                x.op == 'call_method' and x.target == 'dequantize'
+                for x in dq_node_or_nodes)
+
+        if not is_dequantize:
+            continue
+
+        # TODO: enable when we have patterns that need to swap the modules
+        if is_call_module:
+            ref_module = modules[ref_node.target]
+            if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize:
+                qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module))
+                scale_node = q_node.args[1]
+                zero_point_node = q_node.args[2]
+                output_scale = getattr(model, scale_node.target)
+                output_zero_point = getattr(model, zero_point_node.target)
+
+                qmodule = qmodule_cls.from_reference(ref_module, output_scale, output_zero_point)  # type:ignore[union-attr]
+                # replace reference module with quantized module
+                parent_name, module_name = _parent_name(ref_node.target)
+                setattr(modules[parent_name], module_name, qmodule)
+
+        # reroute around dq node:
+        dq_nodes: List[Node] = []
+        if isinstance(dq_node_or_nodes, Node):
+            dq_nodes = [dq_node_or_nodes]
+        elif isinstance(dq_node_or_nodes, (tuple, list)):
+            dq_nodes = list(dq_node_or_nodes)
+
+        for dq_node in dq_nodes:
+            dn_input = dq_node.args[0]
+            ref_node.replace_input_with(dq_node, dn_input)
+
+        # store q node args
+        qnode_qparams = list(q_node.args)[1:]
+        # replace uses of q node with input and remove q node
+        q_node_input = q_node.args[0]
+        q_node.replace_all_uses_with(q_node_input)
+        model.graph.erase_node(q_node)
+
+        is_call_function, is_call_method, is_call_module = is_default_node(ref_node, modules)
+        if is_call_function:
+            # pass scale/zero_point arguments from quantize_per_tensor to the default node operator
+            # insert an op after the zero_point node so that the scale/zero_point
+            # nodes are available
+            qop = get_quantized_operator(ref_node.target)
+            args = list(ref_node.args)
+            kwargs = dict(ref_node.kwargs)
+            if qop in QOP_TO_ARG_NAMES_TO_SKIP:
+                args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop]
+                for arg in args_to_skip:
+                    if arg in kwargs:
+                        kwargs.pop(arg)
+            kwargs["output_scale"] = qnode_qparams[0]
+            kwargs["output_zero_point"] = qnode_qparams[1]
+            with model.graph.inserting_after(qnode_qparams[1]):
+                qop_node = create_node_from_old_node_preserve_meta(
+                    model.graph,
+                    ("call_function", qop, tuple(args), kwargs),
+                    ref_node)
+                ref_node.replace_all_uses_with(qop_node)
+                model.graph.erase_node(ref_node)
+        else:
+            # remove scale/zero_point node for quantize node
+            for n in qnode_qparams:
+                if isinstance(n, Node):
+                    model.graph.erase_node(n)
+
+    return model
+
+def _lower_getattr_tensor_metadta_op(model: GraphModule):
+    """ Modified the graph of the model inplace, to skip extra dequantize op before
+    the general tensor shape ops when possible
+    """
+    for n in model.graph.nodes:
+        if is_getattr_tensor_metadata_node(n):
+            maybe_dq = n.args[0]
+            if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
+                continue
+            # skip the dequantize node
+            args = list(n.args)
+            args[0] = n.args[0].args[0]
+            n.args = tuple(args)
+
+def _lower_get_tensor_info_op(model: GraphModule):
+    """ Modified the graph of the model inplace, to skip extra dequantize op before
+    the general tensor shape ops when possible
+    """
+    for n in model.graph.nodes:
+        if not is_get_tensor_info_node(n):
+            continue
+        maybe_dq = n.args[0]
+        if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
+            continue
+        # skip the dequantize node
+        args = list(n.args)
+        args[0] = n.args[0].args[0]
+        n.args = tuple(args)
+
+def _lower_to_native_backend(
+    model: GraphModule,
+    qconfig_map: Dict[str, QConfigAny],
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+) -> GraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to the native backend in PyTorch (fbgemm/qnnpack). Both backends share the same
+    operator signature, so they can be lowered with the same function.
+    """
+    _lower_static_weighted_ref_module(model, qconfig_map)
+    _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map)
+    _lower_dynamic_weighted_ref_module(model)
+    _lower_weight_only_weighted_ref_module(model)
+    _lower_static_weighted_ref_functional(model, qconfig_map)
+    _lower_dynamic_weighted_ref_functional(model, qconfig_map)
+    _lower_quantized_binary_op(model, qconfig_map)
+    _lower_getattr_tensor_metadta_op(model)
+    _lower_get_tensor_info_op(model)
+    special_pattern_replacement(model)
+    model.graph.eliminate_dead_code()
+    model = fold_weight(model, node_name_to_scope)
+    model.graph.eliminate_dead_code()
+    model.recompile()
+    model.graph.lint()
+    return model
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b86a063bc1a4fb42906a8a67ebe6d09d4e7dedfe
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..540effe9f8366fd3a70e5817ee3f42b63bdfbfe5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae2ecd4a82a75a0cf6ba05ad34c646fae282c1dd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fb05ad105873718499c36d0e8f9e0f640d95809
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5348a5d5dc663a8707e12c7ae4d445f7add09a3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dbd3ed9538ae7379ce1e37aee071fadc19db081
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/detector.py
@@ -0,0 +1,1539 @@
+from typing import Any, Dict, Set, Tuple, Callable, List
+
+import torch
+import torch.nn as nn
+import torch.ao.nn.qat as nnqat
+from abc import ABC, abstractmethod
+from torch.ao.quantization.fake_quantize import FakeQuantize
+from torch.ao.quantization.fx.graph_module import GraphModule
+from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
+from torch.ao.quantization.qconfig import (
+    QConfig,
+    default_qconfig,
+    _assert_valid_qconfig,
+)
+from torch.ao.quantization.observer import (
+    ObserverBase,
+    default_dynamic_quant_observer,
+    default_per_channel_weight_observer,
+    default_observer,
+    default_weight_observer,
+)
+from torch.ao.quantization.fx._equalize import (
+    default_equalization_qconfig,
+    EqualizationQConfig,
+)
+from torch.ao.quantization.observer import _is_activation_post_process
+
+# Names for observer insert keys
+DETECTOR_TARGET_NODE_KEY = "target_node"
+DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
+DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
+DETECTOR_OBS_ARGS_KEY = "observer_args"
+
+# Mapping related code
+class DetectorQConfigInfo:
+    r"""
+    This class contains the QConfig information for a single module.
+    The list of variables / values this contains can grow depending on the
+    extensibility of the qconfig mapping feature set but this currently includes:
+    - if activation observer is dynamic
+    - if weight observer is per channel
+
+
+    Args:
+        module_fqn (str): The fully qualified name (fqn) of the module that this
+            qconfig information is relevant to
+    """
+
+    def __init__(self, module_fqn: str):
+        super().__init__()
+        self.module_fqn = module_fqn
+
+        # populate this section with all the variables we might find important
+        # change these defaults if your detector actually uses them
+        self.is_activation_dynamic = False
+        self.is_weight_per_channel = False
+
+        # equalization related options
+        self.is_equalization_recommended = False
+
+    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
+        r"""
+        Args:
+            module (torch.nn.Module): The module we are generating
+                the qconfig for
+
+        Returns the generated quantization QConfig, picking the highest-priority
+        recommendation that is a valid configuration for the module
+        """
+        # Apply suggestions to new qconfig
+        module_qconfig = default_qconfig
+
+        # keep track of dynamic and per_channel recommendations
+        recommendations_list = []
+        # append as if a list of combinations
+        recommendations_list.append((self.is_activation_dynamic, self.is_weight_per_channel))
+        recommendations_list.append((self.is_activation_dynamic, False))  # only trying dynamic rec
+        recommendations_list.append((False, self.is_weight_per_channel))  # only trying per_channel rec
+
+        # now we try each of the combinations
+        for rec in recommendations_list:
+            # rec[0] -> dynamic recommended
+            # rec[1] -> per channel recommended
+            activation = default_dynamic_quant_observer if rec[0] else default_observer
+            weight = default_per_channel_weight_observer if rec[1] else default_weight_observer
+            test_config = QConfig(activation, weight)
+            try:
+                _assert_valid_qconfig(test_config, module)
+                module_qconfig = test_config
+                break
+            except AssertionError:
+                # if not a valid configuration, we move on to the next one in priority
+                continue
+
+        # return the QConfig chosen
+        return module_qconfig
+
+    def generate_equalization_qconfig(self) -> EqualizationQConfig:
+        r"""
+        This returns the equalization configuration for a module.
+
+        For now, it just returns the default, but as more equalization options become
+        possible, this method can get more fleshed out with more nuanced granularity.
+
+
+        Returns the generated equalization QConfig (currently always the default equalization config)
+        """
+        # in this case, we just return default equalization config
+        # we know this is valid because only valid modules would even
+        # have this option
+        return default_equalization_qconfig
+
+# Adding base class for detectors
+class DetectorBase(ABC):
+    r""" Base Detector Module
+    Any detector class should derive from this class.
+
+    Concrete detectors should follow the same general API, which includes:
+    - A method to calculate and return observer insertion points
+        - Should return both the fqns and the Observer class to insert
+    - A method to return a report based on the detector
+        - Should return a str-based report and dict info in Tuple[str,Dict] format
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.detector_config_info = None
+
+    @abstractmethod
+    def determine_observer_insert_points(self, model) -> Dict:
+        r"""
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
+            This dict maps string keys to detector specific information
+        """
+        pass
+
+    @abstractmethod
+    def get_detector_name(self) -> str:
+        r""" Returns the name of the current detector """
+        pass
+
+
+    @abstractmethod
+    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
+        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        pass
+
+    def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node:
+        r"""
+        Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn.
+
+        If it's not found, it means it is most likely inside a fused layer.
+            We go one level up in the fqn we are searching for until we find the parent node.
+            If no parent fqn is left to try, then we know that the target doesn't exist.
+
+        The reason for the recursion is that if the model that we are looking for got fused,
+        we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module,
+        which would have fqn as x.linear so they will not match.
+        To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear,
+        or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module
+        even in cases with fusion
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+            target_fqn (str): The fqn of the layer we are trying to target
+
+        Returns the node object we are trying to add observers around
+        """
+        for node in prepared_fx_model.graph.nodes:
+            # if the node's target is our target, return it
+            if node.target == target_fqn:
+                return node
+
+        # getting here means node not found
+        # if no "." we are already at base and failed
+        parent_fqn_sep_index = target_fqn.rfind(".")
+        if parent_fqn_sep_index == -1:
+            raise ValueError("passed in target_fqn not found in graph's targets.")
+        else:
+            # recursively call it with parent fqn
+            return self._get_targeting_node(prepared_fx_model, target_fqn[:parent_fqn_sep_index])
+
+    @abstractmethod
+    def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
+        r"""
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Tuple of two elements:
+            Str: string report of the suggested improvements
+            Dict: contains useful data collected by the observer pertinent to this report
+        """
+        pass
+
+class PerChannelDetector(DetectorBase):
+    r""" This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
+        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
+
+        per_channel quantization can lead to major benefits in terms of accuracy.
+        Therefore, if the backend used by the user supports it, it is recommended to use it.
+
+        Args:
+            backend (str, optional): the backend the user wishes to use in production
+                Default value is current torch.backends.quantized.engine
+    """
+
+    # Keys for return dictionary
+    BACKEND_KEY = "backend"
+    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
+    PER_CHAN_USED_KEY = "per_channel_quantization_used"
+
+    # Default map for representing supported per channel quantization modules for different backends
+    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
+        "fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
+        "qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
+        "onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
+        "x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d},
+    }
+
+    def __init__(self, backend: str = torch.backends.quantized.engine):
+        super().__init__()
+
+        # store the backend information
+        self.backend_chosen = backend
+        self.supported_modules = set()
+        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
+            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]
+        else:
+            raise ValueError(f"Not configured to work with {self.backend_chosen}. Try a different default backend")
+
+    def get_detector_name(self) -> str:
+        r""" returns the string name of this detector"""
+        return "per_channel_detector"
+
+    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
+        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        per_channel_info = self._detect_per_channel_helper(model)
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in per_channel_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # see if per channel quantization is supported
+            per_chan_supported: bool = per_channel_info[module_fqn][self.PER_CHAN_SUPPORTED_KEY]
+            detector_qconfig_info.is_weight_per_channel = per_chan_supported
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def determine_observer_insert_points(self, model: nn.Module) -> Dict:
+        r"""
+        No observers are inserted for the PerChannelDetector.
+
+        Returns an empty dictionary since no observers are added or needed
+        """
+        return {}
+
+
+    def _detect_per_channel_helper(self, model: nn.Module):
+        r"""
+        Determines if per_channel quantization is supported in modules and submodules.
+
+        Each entry of the returned dictionary maps a module's fully qualified name to
+        information on whether per_channel quantization is supported and currently used.
+
+        Args:
+            model: The current module that is being checked to see if it is per_channel quantizable
+
+        Returns a dictionary mapping fqns to whether per_channel quantization is possible
+        """
+        # create dict we will return
+        per_channel_info: Dict = {}
+
+        # get the fully qualified name and check if in list of modules to include and list of modules to ignore
+        for fqn, module in model.named_modules():
+
+            is_in_include_list = any(isinstance(module, x) for x in self.supported_modules)
+
+            # check if the module per_channel is supported
+            # based on backend
+            per_channel_supported = False
+
+            if is_in_include_list:
+                per_channel_supported = True
+
+                # assert statement for MyPy
+                q_config_file = module.qconfig
+                assert isinstance(q_config_file, QConfig)
+
+                # this object should either be fake quant or observer
+                q_or_s_obj = module.qconfig.weight.p.func()
+                assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
+
+                per_channel_used = False  # will be true if found in qconfig
+
+                if hasattr(q_or_s_obj, "ch_axis"):  # then we know that per_channel quantization is used
+
+                    # all fake quants have a channel axis, so we need to check is_per_channel
+                    if isinstance(q_or_s_obj, FakeQuantize):
+                        if hasattr(q_or_s_obj, "is_per_channel") and q_or_s_obj.is_per_channel:
+                            per_channel_used = True
+                    elif isinstance(q_or_s_obj, ObserverBase):
+                        # should be an observer otherwise
+                        per_channel_used = True
+                    else:
+                        raise ValueError("Should be either observer or fake quant")
+
+                per_channel_info[fqn] = {
+                    self.PER_CHAN_SUPPORTED_KEY: per_channel_supported,
+                    self.PER_CHAN_USED_KEY: per_channel_used,
+                    self.BACKEND_KEY: self.backend_chosen
+                }
+
+        return per_channel_info
+
+    def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
+        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
+        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
+
+        Looks at q_config format and backend to determine if per_channel can be utilized.
+        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support
+
+        Args:
+            model: The prepared and calibrated model to check for per_channel quantization usage
+
+        Returns a tuple with two elements:
+            String report of potential actions to improve model (if per_channel quantization is available in backend)
+            Dictionary mapping per_channel quantizable elements to:
+                whether per_channel quantization is supported by the backend
+                if it is being utilized in the current model
+        """
+
+        # run the helper function to populate the dictionary
+        per_channel_info = self._detect_per_channel_helper(model)
+
+        # String to let the user know of further optimizations
+        further_optims_str = f"Further Optimizations for backend {self.backend_chosen}: \n"
+
+        optimizations_possible = False
+        for fqn in per_channel_info:
+            fqn_dict = per_channel_info[fqn]
+            if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]:
+                optimizations_possible = True
+                further_optims_str += f"Module {fqn} can be configured to use per_channel quantization.\n"
+
+        if optimizations_possible:
+            further_optims_str += (
+                "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
+            )
+        else:
+            further_optims_str += "No further per_channel optimizations possible."
+
+        # return the string and the dictionary form of same information
+        return (further_optims_str, per_channel_info)
+
+
+class DynamicStaticDetector(DetectorBase):
+    r"""
+    Determines whether dynamic or static quantization is more appropriate for a given module.
+
+    Takes advantage of the ModelReportObserver that records range information.
+    Stationary distributions of data are strictly above the tolerance level for the comparison statistic:
+
+        S = average_batch_activation_range/epoch_activation_range
+
+    Nonstationary distributions are below or at the tolerance level for this metric.
+
+    If the distribution of data right after the module is non-stationary, recommend dynamic quantization
+        Otherwise recommend static quantization
+
+    Args:
+        tolerance (float, optional): The threshold for the S metric: values above it are considered stationary, non-stationary otherwise. Default: 0.5
+    """
+    # names for the pre and post observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
+    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"
+
+    # naming conventions for stationary vs non-stationary data
+    STATIONARY_STR = "stationary"
+    NON_STATIONARY_STR = "non-stationary"
+
+    # naming for activation
+    INPUT_ACTIVATION_PREFIX = "input_activation_"
+    OUTPUT_ACTIVATION_PREFIX = "output_activation_"
+
+    # naming conventions for the keys of the return module info
+    TOLERANCE_KEY = "dynamic_static_tolerance"
+    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
+    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
+    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
+    PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
+    POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
+    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"
+
+    # modules that are supported both dynamic and static for this report function
+    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}
+
+    # modules that will be supported soon for both
+    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}
+
+    def __init__(self, tolerance=0.5):
+        super().__init__()
+
+        # set tolerance level and initialize a set to keep track of useful fqn locations
+        self.tolerance = tolerance
+        self.useful_observer_fqns: Set[str] = set()
+
+    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
+        r"""
+        Determines where observers need to be inserted for the Dynamic vs Static detector.
+        For this detector, we want to place observers on either side of linear layers in the model.
+
+        Currently inserts observers for:
+            linear layers
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # make sure module is supported
+            if self._is_supported(module, insert=True):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args
+                }
+
+                # add entry for post-observer
+                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
+
+                obs_fqn_to_info[post_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
+                    DETECTOR_IS_POST_OBS_KEY: True,
+                    DETECTOR_OBS_ARGS_KEY: (targeted_node,)
+                }
+
+        return obs_fqn_to_info
+
+    def get_detector_name(self) -> str:
+        r""" returns the string name of this detector"""
+        return "dynamic_vs_static_detector"
+
+
+    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
+        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        dynamic_static_info = self._generate_dict_info(model)
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in dynamic_static_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # see if per channel quantization is supported
+            dynamic_static_recommended: bool = dynamic_static_info[module_fqn][self.DEFAULT_DYNAMIC_REC_KEY]
+            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
+        r"""Returns whether the given module is supported for observers
+
+        Args
+            module: The module to check and ensure is supported
+            insert: True if this check is for observer insertion, False if for report generation
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # check to see if module is of a supported type
+        is_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)
+
+        # check if it will be supported
+        future_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED)
+
+        # supported
+        supported = is_supported_type or future_supported_type
+
+        # this check is for observer insertion
+        if insert:
+            return supported
+        else:
+            # this is for report gen and we also need to check if it contains observers
+            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME)
+            return supported and has_obs
+
+    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a Dictionary mapping modules with ModelReportObservers around them to:
+                whether dynamic quantization is recommended
+                their S metric of input to module
+                whether input to module is stationary or non-stationary
+                their S metric of output of module
+                whether output of module is stationary or non-stationary
+                the tolerance level used to decide whether the input/output is stationary or non-stationary
+                whether it is currently supported or planned for the future
+        """
+        # store modules dynamic vs static information
+        module_dynamic_static_info = {}
+
+        # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
+        #   This information primarily includes whether the data distributions around a supported module are stationary or not
+        #   Based on this, it is recorded whether dynamic or static quantization is recommended
+
+        # loop through all submodules, including nested ones
+        for fqn, module in model.named_modules():
+            # if the module is supported and has the ModelReportObserver attached to it
+            if self._is_supported(module):
+                # get pre and post observers for the module
+                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)
+
+                # get the statistics for each module
+                pre_stat = pre_obs.get_batch_to_epoch_ratio()
+                post_stat = post_obs.get_batch_to_epoch_ratio()
+
+                # record module, pre and post stat, and whether to do dynamic or static based off it
+                # true if post observer data distribution is non-stationary, false if it's stationary
+                dynamic_recommended = post_stat <= self.tolerance
+
+                # specify the classifications for whether data distributions considered stationary or non-stationary
+                pre_obs_dist_classif = self.STATIONARY_STR if pre_stat > self.tolerance else self.NON_STATIONARY_STR
+                post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR
+
+                # check if current support or future support
+                is_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)
+
+                # store the set of important information for this module
+                module_info = {
+                    self.TOLERANCE_KEY: self.tolerance,
+                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
+                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
+                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
+                    self.POST_OBS_COMP_STAT_KEY: post_stat,
+                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
+                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
+                }
+
+                module_dynamic_static_info[fqn] = module_info
+
+        return module_dynamic_static_info
+
+    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
+        r"""
+        Determines whether dynamic or static quantization is more appropriate for a given module.
+
+        Takes advantage of the ModelReportObserver that records range information.
+        Stationary distributions of data are strictly above the tolerance level for the comparison statistic:
+
+            S = average_batch_activation_range/epoch_activation_range
+
+        Non-stationary distributions are at or below the tolerance level for this metric.
+
+        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
+            Otherwise recommend static quantization
+
+        This will then generate suggestions for dynamic vs static quantization focused around Linear.
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether dynamic or static quantization is recommended for certain modules
+            Dictionary mapping modules with ModelReportObservers around them to:
+                whether dynamic quantization is recommended
+                their S metric of input to module
+                whether input to module is stationary or non-stationary
+                their S metric of output of module
+                whether output of module is stationary or non-stationary
+                the tolerance level used to decide whether the input/output is stationary or non-stationary
+                whether it is currently supported or planned for the future
+        """
+
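+        # Worked sketch of the decision rule above (hypothetical numbers, for illustration
+        # only): with tolerance = 0.5, a module whose post-observer statistic is
+        # S = average_batch_activation_range / epoch_activation_range = 0.3 satisfies
+        # S <= tolerance, so its output is treated as non-stationary and dynamic
+        # quantization is suggested; S = 0.8 would be treated as stationary and static
+        # quantization suggested instead.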
+        # get the dictionary of the information to format the string report
+        module_dynamic_static_info = self._generate_dict_info(model)
+
+        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"
+
+        modules_added: bool = False  # check to make sure at least 1 module added.
+
+        dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
+        static_benefit = " You can increase model efficiency if you use static quantization"
+        future_support_str = ". This layer is not yet supported for dynamic quantization"
+        # This for loop goes through the information collected in module_dynamic_static_info and:
+        #   Populates the string based report with the information from module_dynamic_static_info
+        #   Compiles the complete report by appending relevant formatted strings
+
+        for module_fqn in module_dynamic_static_info.keys():
+
+            # there is at least 1 module for suggestion
+            modules_added = True
+            module_info = module_dynamic_static_info[module_fqn]
+            suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"
+
+            # decide what string formatting values will be
+            quantization_type = ""
+            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."
+
+            benefit_str = ""
+
+            # strings for if dynamic quantized per tensor is needed
+            recommend_per_tensor = ". We recommend adding a {} before this module if it is static."
+            rec_lay_to_add = "dynamic quantize per tensor layer"
+            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
+            dynamic_per_tensor_reasoning_string = (
+                " This is because the input to this module has a non-stationary distribution"
+            )
+
+            # start composing explanation
+            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
+                quantization_type = "dynamic"
+                # check if currently supported or future supported
+                benefit_str = dynamic_benefit
+                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
+                    benefit_str += future_support_str
+            else:
+                quantization_type = "static"
+                benefit_str = static_benefit
+
+            # now set the quantization explanation string
+            quantization_reasoning = (
+                quantization_reasoning.format(
+                    module_fqn, module_info[self.PRE_OBS_DATA_DIST_KEY], module_info[self.POST_OBS_DATA_DIST_KEY]
+                )
+                + benefit_str
+            )
+
+            # if we have a non-stationary input -> linear -> stationary, we suggest static
+            # however, we also want to recommend adding a dynamic quantize per tensor layer right before this module
+            if (
+                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
+                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
+            ):
+                quantization_reasoning = (
+                    quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
+                )
+
+            # format the overall suggestion string with the specific inputs
+            module_suggestion_string = suggestion_string_template.format(
+                module_fqn, quantization_type, quantization_reasoning
+            )
+
+            # append to overall suggestion
+            dynamic_vs_static_string += module_suggestion_string
+
+        if not modules_added:
+            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"
+
+        # return the string as well as the dictionary of information
+        return (dynamic_vs_static_string, module_dynamic_static_info)
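+    # Minimal usage sketch (assumes a prepared and calibrated GraphModule bound to a
+    # placeholder name `prepared_model`; the tolerance value is just an example):
+    #
+    #     detector = DynamicStaticDetector(tolerance=0.5)
+    #     report_str, report_dict = detector.generate_detector_report(prepared_model)
+    #     print(report_str)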
+
+
+class InputWeightEqualizationDetector(DetectorBase):
+    r"""
+    Determines whether input-weight equalization can help improve quantization for certain modules.
+
+    Specifically, this list of modules includes:
+        linear
+        conv
+
+    Determines whether input-weight equalization is recommended based on the comp stat:
+        s_c = sqrt(w_c/W)/sqrt(i_c/I)
+        where:
+            w_c is range of weight for channel c, W is range of weight over all channels
+            i_c is range of input for channel c, I is range of input over all channels
+
+        if threshold <= s_c <= 1 / threshold, input-weight equalization is recommended
+
+    Args:
+        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
+            Should be between 0 and 1 (both non-inclusive)
+        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
+            Default: 1
+
+    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
+        Should be between 0 and 1
+
+    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization
+
+    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization
+
+    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
+    """
+
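+    # Worked example of the comparison statistic (hypothetical values, illustration only):
+    # for a channel with weight range w_c = 0.5, global weight range W = 2.0, input range
+    # i_c = 1.0 and global input range I = 4.0,
+    #     s_c = sqrt(0.5 / 2.0) / sqrt(1.0 / 4.0) = sqrt(0.25) / sqrt(0.25) = 1.0,
+    # which lies inside [ratio_threshold, 1 / ratio_threshold] for any threshold in (0, 1),
+    # so input-weight equalization would be recommended for that channel.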
+    SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
+                                        nn.Conv1d,
+                                        nn.Conv2d,
+                                        nn.Conv3d,
+                                        nnqat.Linear,
+                                        nnqat.Conv1d,
+                                        nnqat.Conv2d,
+                                        nnqat.Conv3d}
+
+    # names for the pre and post observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
+
+    # weight / activation prefix for each of the below info
+    WEIGHT_PREFIX = "weight_"
+    ACTIVATION_PREFIX = "input_activation_"
+
+    # string names for keys of info dictionaries
+    PER_CHANNEL_MAX_KEY = "per_channel_max"
+    PER_CHANNEL_MIN_KEY = "per_channel_min"
+    GLOBAL_MAX_KEY = "global_max"
+    GLOBAL_MIN_KEY = "global_min"
+
+    # keys for return dict of recommendations
+    RECOMMENDED_KEY = "input_weight_equalization_recommended"
+    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
+    THRESHOLD_KEY = "input_weight_threshold"
+    CHANNEL_KEY = "input_weight_channel_axis"
+
+    # default weight and info strings
+    WEIGHT_STR = "weight"
+    INPUT_STR = "input"
+
+    # default for what ratio we recommend input weight
+    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4
+
+    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
+        # ensure passed in inputs are valid
+        if ratio_threshold <= 0 or ratio_threshold >= 1:
+            raise ValueError("Make sure threshold is > 0 and < 1")
+
+        # initialize attributes based on args
+        self.ratio_threshold: float = ratio_threshold
+        self.ch_axis: int = ch_axis
+
+    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
+        r"""Returns whether the given module is supported for observers
+
+        Args
+            module: The module to check and ensure is supported
+            insert: True if this check is for observer insertion, False if it is for report generation
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # check to see if module is of a supported type
+        is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)
+
+        # this check is for observer insertion
+        if insert:
+            return is_supported_type
+        else:
+            # this is for report gen and we also need to check if it contains observers
+            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+            return is_supported_type and has_obs
+
+    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
+        r""" Returns the DetectorQConfigInfo for each relevant module_fqn
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        # find the range of inputs
+        input_values: Dict[str, Dict] = self._extract_input_info(model)
+
+        # find the range of weights
+        weight_values: Dict[str, Dict] = self._extract_weight_info(model)
+
+        # calculate per_channel comparison statistic s_c
+        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)
+
+        # generate the return dictionary
+        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in input_weight_equalization_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # record whether input-weight equalization is recommended for this module
+            input_weight_recommended: bool = input_weight_equalization_info[module_fqn][self.RECOMMENDED_KEY]
+            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
+        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
+        For this detector, we want to place observers in front of supported layers.
+
+        Currently inserts observers for:
+            linear layers
+            conv layers
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # check to see if module is of a supported type
+            if self._is_supported(module, insert=True):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
+                }
+
+        return obs_fqn_to_info
+
+    def get_detector_name(self) -> str:
+        r"""Returns the name of this detector"""
+        return "input_weight_equalization_detector"
+
+    def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
+        r"""
+        Takes in a calibrated GraphModule and then finds the relevant observers.
+        It then extracts the input information for each observer and returns it
+
+        Args
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping relevant module fqns (str) to a dict with keys:
+            "input_activation_per_channel_max" : maps to the per_channel max values
+            "input_activation_per_channel_min" : maps to the per_channel min values
+            "input_activation_global_max" : maps to the global max recorded
+            "input_activation_global_min" : maps to the global min recorded
+        """
+
+        # return dictionary mapping observer fqns to desired info
+        input_info: Dict[str, Dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._is_supported(module):
+                # get pre observer for the module
+                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+
+                input_info[fqn] = {
+                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
+                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
+                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
+                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
+                }
+
+        return input_info
+
+    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
+        r"""
+        Takes in a calibrated GraphModule and then finds the relevant observers.
+        It then extracts the weight information for each layer an observer is attached to.
+
+        Args
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping module fqns (str) to a dict with keys:
+            "per_channel_max" : maps to the per_channel max values
+            "per_channel_min" : maps to the per_channel min values
+            "global_max" : maps to the global max recorded
+            "global_min" : maps to the global min recorded
+        """
+        # return dictionary mapping observer fqns to desired info
+        weight_info: Dict[str, Dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._is_supported(module):
+                # we don't need actual observer, just the module weights
+                # calculate min and max vals
+                device = module.weight.device
+                min_val: torch.Tensor = torch.tensor([float('inf')], device=device)
+                max_val: torch.Tensor = torch.tensor([float('-inf')], device=device)
+                x_copy = module.weight
+                x_dim = x_copy.size()
+
+                new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+                new_axis_list[self.ch_axis] = 0
+                new_axis_list[0] = self.ch_axis
+                y = x_copy.permute(new_axis_list)
+
+                # Need to match dtype of min/max because the updates to buffers
+                # are done in place and types need to match for comparisons
+                y = y.to(min_val.dtype)
+                y = torch.flatten(y, start_dim=1)
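+                # Shape sketch (hypothetical sizes): for a Conv2d weight of shape
+                # (out=8, in=4, kH=3, kW=3) with ch_axis = 1, the permute above moves the 4
+                # input channels to dim 0 and the flatten yields y of shape (4, 72), so
+                # torch.aminmax(y, dim=1) below returns 4 per-channel min/max values.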
+                if min_val.numel() == 0 or max_val.numel() == 0:
+                    min_val, max_val = torch.aminmax(y, dim=1)
+                else:
+                    min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+                    min_val = torch.min(min_val_cur, min_val)
+                    max_val = torch.max(max_val_cur, max_val)
+
+                weight_info[fqn] = {
+                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
+                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
+                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
+                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
+                }
+
+        return weight_info
+
+    def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor:
+        r"""
+        Takes in an info dict and calculates the per-channel to global range ratio used in the s_c statistic.
+
+        Args:
+            info_dict (dict): A dictionary of either input or weight range info
+            info_str (str): A str describing whether currently looking at weight or input info
+                Either "weight" or "input"
+            module_fqn (str): The fqn of the module we are looking at
+
+        Returns a tensor of values, where each value is the s_c stat for a different channel
+        """
+        # calculate the ratios of the info
+        # get the prefix str
+        prefix_str = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
+
+        per_channel_range = info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
+        global_range = info_dict[prefix_str + self.GLOBAL_MAX_KEY] - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
+
+        if global_range == 0:
+            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
+            raise ValueError(
+                "The range of the {} data for module {} is 0, which means you have a constant value channel. {}".format(
+                    info_str, module_fqn, range_zero_explanation
+                )
+            )
+
+        ratio = per_channel_range / global_range
+
+        return ratio
+
+    def _generate_comparison_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]:
+        r"""
+        Takes in the information on the min and max values of the inputs and weights and:
+            Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)
+
+        Args:
+            input_info (dict): A dict mapping each observer to input range information
+            weight_info (dict): A dict mapping each observer to weight range information
+
+        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
+            Each value is a different s_c value for a different channel
+        """
+        # create return dictionary for each observer
+        module_fqn_to_channel: Dict[str, torch.Tensor] = {}
+
+        # for each module (both passed in dicts should have same keys)
+        for module_fqn in input_info:
+
+            # raise error if not in weight info
+            if module_fqn not in weight_info:
+                raise KeyError(f"Unable to find weight range stats for module {module_fqn}")
+
+            # calculate the ratios of the weight info and input info
+            weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn)
+            input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn)
+
+            # if mismatched size, because of grouping, we want to replicate weight enough times
+            weight_channels = len(weight_ratio)
+            input_channels = len(input_ratio)
+            if weight_channels != input_channels:
+                # we try to replicate
+                assert input_channels % weight_channels == 0, "input channels should be divisible by weight channels."
+                # get replication factor
+                rep_factor: int = input_channels // weight_channels
+
+                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n times
+                weight_ratio = weight_ratio.repeat(rep_factor)
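+                # e.g. (hypothetical sizes) 2 weight channels vs. 8 input channels from a
+                # grouped convolution gives rep_factor = 4, so tensor([a, b]) becomes
+                # tensor([a, b, a, b, a, b, a, b]) before the element-wise ratio below.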
+
+            # calculate the s metric per channel
+            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
+            module_fqn_to_channel[module_fqn] = s
+
+        # return compiled observer ratios
+        return module_fqn_to_channel
+
+    def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            input_info (dict): A dict mapping each module to input range information
+            weight_info (dict): A dict mapping each module to weight range information
+            comp_stats (dict): A dict mapping each module to its corresponding comp stat
+
+        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
+            whether input weight equalization is recommended
+            their s_c metric compared to the threshold
+            the threshold used to make the recommendation
+            the channel used for recording data
+            the input channel range info
+            the weight channel range info
+        """
+        # store modules input weight equalization info
+        input_weight_equalization_info: Dict[str, Dict] = {}
+
+        # for each module we add separate set of suggestions
+        for module_fqn in input_info:
+
+            # get relevant info for this module
+            mod_input_info: Dict = input_info[module_fqn]
+            mod_weight_info: Dict = weight_info[module_fqn]
+            mod_comp_stat: torch.Tensor = comp_stats[module_fqn]
+
+            # decide if each channel should have input weight equalization or not
+            channel_rec_vals: list = []
+
+            for val in mod_comp_stat:
+                float_rep: float = val.item()
+
+                # decide if recommending input weight equalization
+                recommended: bool = float_rep >= self.ratio_threshold and float_rep <= 1 / self.ratio_threshold
+                channel_rec_vals.append(recommended)
+
+            # build the return dict input
+            # also unpack input and weight dicts into it
+            input_weight_equalization_info[module_fqn] = {
+                self.RECOMMENDED_KEY: channel_rec_vals,
+                self.COMP_METRIC_KEY: mod_comp_stat,
+                self.THRESHOLD_KEY: self.ratio_threshold,
+                self.CHANNEL_KEY: self.ch_axis,
+                **mod_input_info,
+                **mod_weight_info,
+            }
+
+        # return our compiled info for each module
+        return input_weight_equalization_info
+
+    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
+        r"""
+        Determines whether input weight equalization is appropriate for a given module.
+
+        Takes advantage of the ModelReport Observer which records per channel information of input range
+        It then uses this in conjunction with the extracted weight information to compute the desired ratio
+        Finally, it gives suggestions based on this information for each module of interest
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether input weight equalization is recommended for certain modules
+            Dictionary mapping modules of interest to:
+                whether input weight equalization is recommended
+                their s_c metric compared to the threshold
+                the threshold used to make the recommendation
+                the channel used for recording data
+                the input channel range info
+                the weight channel range info
+        """
+
+        # find the range of inputs
+        input_values: Dict[str, Dict] = self._extract_input_info(model)
+
+        # find the range of weights
+        weight_values: Dict[str, Dict] = self._extract_weight_info(model)
+
+        # calculate per_channel comparison statistic s_c
+        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)
+
+        # generate the return dictionary
+        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)
+
+        # now we can generate report based on this information
+        input_weight_string = "Input-Weight Equalization suggestions: \n"
+
+        # some strings to be formatted depending on module we are adding
+        module_suggestion_str = "For Module {} looked at with axis {}: \n"
+        channel_suggestion_str = "\tWe suggest {} input weight equalization because {}\n"
+        use_str = "to use"
+        no_use_str = "to not use"
+        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
+        input_weight_non_benefit_reasoning = "{}/{} channels benefitting from input-weight equalization being applied."
+        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"
+
+        # added module check
+        added_module: bool = False
+
+        # compile the suggestion string
+        for module_fqn in input_weight_equalization_info:
+            # we added at least 1 module
+            added_module = True
+            # add the module level description
+            input_weight_string += module_suggestion_str.format(module_fqn, self.ch_axis)
+
+            mod_info: Dict[str, Any] = input_weight_equalization_info[module_fqn]
+
+            # gather info on how many channels would benefit from input-weight equalization
+            recommendation_per_channel: List[bool] = mod_info[self.RECOMMENDED_KEY]
+            num_recs = sum(recommendation_per_channel)
+
+            if num_recs / len(recommendation_per_channel) >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO:
+                input_benefit_formatted = input_weight_benefit_str.format(num_recs, len(recommendation_per_channel))
+                channel_str = channel_suggestion_str.format(use_str, input_benefit_formatted)
+                input_weight_string += channel_str
+            else:
+                non_benefit_reason_formatted = input_weight_non_benefit_reasoning.format(num_recs, len(recommendation_per_channel))
+                non_benefit_str = input_weight_non_benefit_str.format(non_benefit_reason_formatted)
+                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
+                input_weight_string += channel_str
+
+        # if no modules looked at, amend return string
+        if not added_module:
+            input_weight_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"
+
+        # return a tuple with the string explanation and the compiled dict info
+        return (input_weight_string, input_weight_equalization_info)
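+    # Minimal usage sketch (assumes a prepared and calibrated GraphModule bound to a
+    # placeholder name `prepared_model`; the threshold value is just an example):
+    #
+    #     detector = InputWeightEqualizationDetector(ratio_threshold=0.7)
+    #     report_str, report_dict = detector.generate_detector_report(prepared_model)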
+
+
+class OutlierDetector(DetectorBase):
+    r"""
+    Determines whether there are significant outliers in activation data around a certain layer.
+
+    This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
+        If the data is stationary, and there are significant outliers, then we want to flag them
+        We want to do this on a per channel basis for detecting outliers
+
+    Determines whether activation data is flagged as outlier based on if data is stationary and:
+        p_r = avg(100th percentile / "reference_percentile"th percentile)
+        where:
+            p_r is average percentile ratio across all batches in the epoch
+            reference_percentile is given as a fraction between 0 and 1 (exclusive), e.g. 0.975 for the 97.5th percentile
+
+        if p_r is above some threshold, then we consider the activations to have significant outliers
+
+    Args:
+        ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
+            Should be >= 1
+            Default: 3.5
+        reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
+            Should be between 0 and 1
+            Default: 0.975
+        fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
+            If the fraction is below this, we deem the number of samples used to calculate outliers insignificant and alert
+            the user to take a closer look at the channel results, regardless of whether outliers were detected in the channel
+            Should be between 0 and 1
+            Default: 0.95
+        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
+            Default: 1
+
+    * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
+        The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
+        If it is significantly greater, then we consider it an outlier
+        This threshold was calculated based on the ratio of the percentiles in a normal distribution
+        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
+
+    * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
+        Should be between 0 and 1
+        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
+
+    * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
+        Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
+        Should be between 0 and 1
+
+    * :attr:`ch_axis`: The channel axis being observed to determine outliers
+
+    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
+    """
+
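+    # Worked sketch of the ratio test (hypothetical numbers, illustration only): with
+    # reference_percentile = 0.975 and ratio_threshold = 3.5, a channel whose 100th
+    # percentile averages 7.0 while its 97.5th percentile averages 1.0 has
+    #     p_r = 7.0 / 1.0 = 7.0 > 3.5,
+    # so that channel is flagged as containing significant outliers.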
+    # names for the pre observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
+
+    # pre activation prefix
+    INPUT_ACTIVATION_PREFIX = "input_activation_"
+
+    # names for dict keys
+    OUTLIER_KEY = "outliers_detected"
+    NUM_BATCHES_KEY = "outlier_detection_batches_used"
+    IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
+    COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
+    RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
+    REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
+    CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
+    MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
+    CONSTANT_COUNTS_KEY = "constant_batch_counts"
+
+    def __init__(
+        self,
+        ratio_threshold: float = 3.5,
+        reference_percentile: float = 0.975,
+        fraction_batches_used_threshold: float = 0.95,
+        ch_axis: int = 1,
+    ):
+        # initialize the variables of interest
+        self.ratio_threshold = ratio_threshold
+
+        # make sure passed in percentile is valid
+        assert reference_percentile >= 0 and reference_percentile <= 1
+        assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1
+        self.reference_percentile = reference_percentile
+        self.fraction_batches_used_threshold = fraction_batches_used_threshold
+        self.ch_axis = ch_axis
+
+    def get_detector_name(self) -> str:
+        r"""Returns the name of this detector"""
+        return "outlier_detector"
+
+    def _supports_insertion(self, module: nn.Module) -> bool:
+        r"""Returns whether the given module is supported for observers insertion
+
+        Any module that doesn't have children and isn't an observer itself is supported
+
+        Args
+            module: The module to check and ensure is supported
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # case for insertion of module
+        # check if the module has any children and isn't observer
+        num_children = len(list(module.children()))
+        return num_children == 0 and not _is_activation_post_process(module)
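+        # e.g. an nn.Linear or nn.ReLU leaf module qualifies for insertion, while a
+        # container such as nn.Sequential (which has children) or an already-inserted
+        # observer does not.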
+
+    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
+        r""" Returns the DetectorQConfigInfo for each relevant module_fqn
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # currently doesn't do anything for outlier detector
+        return {}
+
+    def _supports_report_gen(self, module: nn.Module) -> bool:
+        r"""Returns whether the given module is supported for report generation
+
+        Any module that has a model report pre-observer is supported
+
+        Args
+            module: The module to check and ensure is supported
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+
+    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
+        r""" Determines where observers need to be inserted for the Outlier Detector.
+
+        For this detector, we want to place observers in front of supported layers.
+
+        Currently inserts observers for:
+            all layers that do not have children (leaf level layers)
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # check to see if module is of a supported type
+            if self._supports_insertion(module):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis, comp_percentile=self.reference_percentile),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
+                }
+
+        return obs_fqn_to_info
+
+    def _calculate_outlier_info(
+        self,
+        percentile_ratios: torch.Tensor,
+        counted_batches: torch.Tensor,
+        total_batches: int,
+    ) -> Dict[str, List[bool]]:
+        r"""
+        Gives info on whether the percentile ratios calculated would be considered outliers
+        Also gives information on whether the collected data is statistically significant to make this claim
+
+        Args:
+            percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
+            counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
+            total_batches (int): The total number of batches that passed through observer in this epoch
+
+        Returns a dictionary mapping:
+            "outliers_detected" : list of bools per channel that are true if it is considered an outlier
+            "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
+                where o_r = counted_batches / total_batches
+        """
+        outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []}
+
+        # get both as flattened lists for easy mapping
+        ratios_list: List = percentile_ratios.tolist()
+        num_batches_list: List = counted_batches.tolist()
+
+        # calculate whether channels were statistically significant
+        significant_size = [
+            batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list
+        ]
+        outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
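+        # e.g. (hypothetical counts) with fraction_batches_used_threshold = 0.95 and
+        # total_batches = 100, a channel whose ratio was computed from 97 batches is marked
+        # sufficient (97 / 100 >= 0.95), while one that used only 90 batches is not.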
+
+        # calculate for each channel whether it's an outlier or not based on ratio
+        outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
+        outlier_dict[self.OUTLIER_KEY] = outlier_detected
+
+        # return the dictionary with the two lists
+        return outlier_dict
+
+    def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping relevant module fqns to:
+            whether there were outliers found in activation before
+            the number of batches used for each channel
+            whether fraction of applicable batches used is above fraction_batches_used_threshold
+            their p_r metric compared to the threshold
+            the threshold used to make the recommendation
+            the reference_percentile used to make the recommendation
+            the channel axis used to determine individual channels
+            the constant batch counts per channel
+            the per channel max values
+        """
+        # return dictionary mapping observer fqns to desired info
+        info_dict: Dict[str, Dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._supports_report_gen(module):
+                # get pre observer for the module
+                pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+
+                # get the number of batches and calculated ratio thresholds
+                num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
+                average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
+                channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
+                total_batches: int = pre_obs.num_batches_tracked
+
+                # also get the max values
+                max_vals: torch.Tensor = pre_obs.max_val
+
+                # we have to specifically modify how we are recording negative ratio for pre-relu layers
+                for index, ratio_val in enumerate(average_ratios):
+                    # check if we have a negative ratio
+                    # a ratio might be negative if we have a situation where the 100th percentile is
+                    # > 0 while the nth percentile is < 0, in which case this would not be detected
+                    # as an outlier. Since we care more about magnitude, we make it positive.
+                    if ratio_val.item() < 0:
+                        # first make it positive
+                        ratio_val = -ratio_val
+                        average_ratios[index] = ratio_val
+
+                    if ratio_val.item() < 1:
+                        # if it's less than 1 we have to flip it as well
+                        average_ratios[index] = 1 / ratio_val
+
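+                # e.g. the intent of the two checks above (hypothetical value): an average
+                # ratio of -0.25 should end up as 1 / 0.25 = 4.0 before the outlier
+                # comparison, since only its magnitude relative to 1 matters.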
+                outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
+
+                # calculate whether ratios were outliers
+                info_dict[fqn] = {
+                    self.CHANNEL_AXIS_KEY: self.ch_axis,
+                    self.REF_PERCENTILE_KEY: self.reference_percentile,
+                    self.RATIO_THRES_KEY: self.ratio_threshold,
+                    self.COMP_METRIC_KEY: average_ratios,
+                    self.NUM_BATCHES_KEY: num_batches,
+                    self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
+                    self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
+                    self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
+                    self.MAX_VALS_KEY: max_vals
+                }
+
+        return info_dict
+
+    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
+        r"""
+        Determines whether there are significant outliers in the activation data around the modules of interest.
+
+        Takes advantage of the ModelReport Observer which records the relevant percentile information
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether there are outliers in the activations around certain modules
+            Dictionary mapping modules of interest to:
+                whether there were outliers found in activation before
+                the number of batches used for each channel
+                whether fraction of applicable batches used is above fraction_batches_used_threshold
+                their p_r metric compared to the threshold
+                the threshold used to make the recommendation
+                the reference_percentile used to make the recommendation
+                the channel axis used to determine individual channels
+                the constant batch counts per channel
+                the per channel max values
+        """
+        # generate the information dictionary of outlier information
+        info_dict = self._generate_info_dict(model)
+
+        # now we can generate report based on this information
+        outlier_string = "Outlier detection report: \n"
+
+        # added module check
+        added_module: bool = False
+
+        # some strings to be formatted depending on module we are adding
+        module_suggestion_str = "For Module {} looked at with axis {}: \n"
+        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
+        channel_max_value_str = "a max value across all batches of {}"
+        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
+        note_distribution = "stationary distributions"
+        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"
+
+        # suggestion for constant batch check since that can make it no outliers
+        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
+        constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."
+
+        # compile the suggestion string
+        for module_fqn in info_dict:
+            # get module specific info
+            mod_info: Dict[str, Any] = info_dict[module_fqn]
+            # check to see if we already added high level model desc
+            added_model_desc = False
+            # look at each individual channel and add a suggestion
+            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
+                if outlier_detected:
+                    # we found at least 1 outlier
+                    if not added_model_desc:
+                        # add the module level description
+                        outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
+                        added_model_desc = True
+
+                    # we mark that we found at least one outlier
+                    added_module = True
+                    max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index])
+                    channel_str = channel_suggestion_str.format(index, max_value_found_str)
+                    outlier_string += channel_str
+
+                # also check if we found constant batch
+                if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
+                    # make sure we add a module level highlight.
+                    if not added_model_desc:
+                        # add the module level description
+                        outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
+                        added_model_desc = True
+
+                    constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index]
+                    formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion)
+                    outlier_string += formatted_str
+                    # we also added at least one thing to description
+                    added_module = True
+
+
+        # if found outlier, give suggestion, else give default response
+        if added_module:
+            # compose the note string
+            note_composed = note_string.format(note_distribution, note_rec)
+            outlier_string += note_composed
+        else:
+            outlier_string += "There were no outliers found in the activations.\n"
+
+        return (outlier_string, info_dict)
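+    # Minimal usage sketch (assumes a prepared and calibrated GraphModule bound to a
+    # placeholder name `prepared_model`; the arguments shown are just the defaults):
+    #
+    #     detector = OutlierDetector(ratio_threshold=3.5, reference_percentile=0.975)
+    #     report_str, report_dict = detector.generate_detector_report(prepared_model)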
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py
new file mode 100644
index 0000000000000000000000000000000000000000..934008931291cd5820b9b2e83c65f31add014f9b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report.py
@@ -0,0 +1,606 @@
+from typing import Any, Dict, Set, Tuple, Callable
+from collections import OrderedDict
+import torch
+from torch.ao.quantization.fx._model_report.detector import (
+    DetectorBase,
+    DETECTOR_OBS_ARGS_KEY,
+    DETECTOR_OBS_TO_INSERT_KEY,
+    DETECTOR_IS_POST_OBS_KEY,
+    DETECTOR_TARGET_NODE_KEY,
+    DetectorQConfigInfo
+)
+from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
+from torch.ao.quantization.fx.graph_module import GraphModule
+from torch.ao.quantization.observer import ObserverBase
+from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig
+from torch.ao.quantization.fx._equalize import EqualizationQConfig
+
+class ModelReport:
+    r"""
+    The ModelReport class aims to provide users an easy way to diagnose issues that they run into
+    with their models. The class works with all traceable GraphModules to help diagnose issues,
+    though the requirements on the type of model more-so depends on the specific report the user
+    is trying to generate. With respect to the reports, the ModelReport class is initialized with
+    a set of Detector classes, each of which generates reports on quantization configuration
+    issues a user might have.
+
+    Currently supports generating reports on:
+    - Suggestions for per-channel vs. per-tensor quantization (nn.Module)
+    - Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
+    - Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
+    - Suggestions for outlier detection for all layers (Graph Modules)
+
+    The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
+    where needed for each detector to gather the information it needs, and then after calibration, the ModelReport
+    class compiles the report generated by each Detector class into a single report to return to the user. It also
+    has the capability to remove all the observers it inserted as well.
+
+    * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
+
+    * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
+        Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
+
+    * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
+        This set is generated by calling the get_detector_name() of each detector
+
+    * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
+        The purpose of this is to keep track of what observers were inserted for each detector, so that they
+        can be removed at the end if desired
+
+    * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
+        This is to ensure we only insert observers once with the ModelReport instance
+
+    * :attr:`_removed_observers` A boolean to track if we have removed observers already
+        The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
+        instance. This also allows the functionality where we can generate the report multiple times
+        as long as we haven't removed the observers yet.
+
+    Note:
+        This class was initially designed to work with the Fx Graph Mode workflow in mind. However,
+        full functionality is available as long as there is a traceable GraphModule that is being used.
+        One method to get a traceable GraphModule without going through the Fx workflow is to use
+        the QuantizationTracer class.
+
+    General Flow for Fx workflow:
+    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
+    2.) Prepare your model with prepare_fx
+    3.) Call model_report.prepare_detailed_calibration to add relevant observers
+    4.) Calibrate your model with data
+    5.) Call model_report.generate_model_report on your model to generate the report and optionally remove added observers
+    Optional
+        6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
+        7.) To help in parsing report information and debugging, view report info as a:
+            - Table
+            - Histogram
+            - Line plot
+    8.) Call model_report.generate_qconfig_mapping (and optionally generate_equalization_mapping) to generate the qconfigs based on the report suggestions
+
+    Example (with QuantizationTracer):
+        >>> # xdoctest: +SKIP
+        >>> # get the necessary qconfig
+        >>> config = PrepareCustomConfig()
+        >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False)
+
+        >>> # initialize our model and get GraphModule
+        >>> model = SomeModel()
+        >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
+        >>> graph_module = GraphModule(model, tracer.trace(model))
+
+        >>> # get our set of detectors and ModelReport instance
+        >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)])
+        >>> tracer_reporter = ModelReport(graph_module, detector_set)
+
+        >>> # now we insert the observers and calibrate the model
+        >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
+        >>> for i in range(num_calibration_batches):
+        >>>     example_input = get_calibration_input()
+        >>>     tracer_model_with_observers(example_input)
+
+        >>> # finally we generate the reports and optionally remove the observers we inserted
+        >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
+
+        >>> # Optional: we can generate the qconfig mapping based on the suggestions
+        >>> qconfig_mapping = tracer_reporter.generate_qconfig_mapping()
+
+        >>> # Optional: we can generate the equalization mapping based on the suggestions
+        >>> equalization_mapping = tracer_reporter.generate_equalization_mapping()
+
+        >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
+        >>> model_report_visualizer = tracer_reporter.generate_visualizer()
+
+    """
+
+    def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
+
+        if len(desired_report_detectors) == 0:
+            raise ValueError("Should include at least 1 desired report")
+
+        # keep track of the model we wish to generate report for
+        self._model: GraphModule = model
+
+        # keep the reports private so they can't be modified
+        self._desired_report_detectors = desired_report_detectors
+        self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}
+
+        # keep a mapping of desired reports to observers of interest
+        # this is used both to get the readings and to remove the observers: the per-detector sets can be
+        # combined into one large set that is used to traverse the graph and remove the added observers
+        self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
+
+        # initialize each report to have empty set of observers of interest
+        for desired_report in self._desired_detector_names:
+            self._detector_name_to_observer_fqns[desired_report] = set()
+
+        # flags to ensure that we can only prepare and remove observers once
+        self._prepared_flag = False
+        self._removed_observers = False
+
+        # store the reports that we generated for visualization purposes
+        # initially empty since no reports generated
+        self._generated_reports: Dict[str, Dict] = {}
+
+    def get_desired_reports_names(self) -> Set[str]:
+        """ Returns a copy of the desired reports for viewing """
+        return self._desired_detector_names.copy()
+
+    def get_observers_of_interest(self) -> Dict[str, Set[str]]:
+        """ Returns a copy of the observers of interest for viewing """
+        return self._detector_name_to_observer_fqns.copy()
+
+    def prepare_detailed_calibration(self) -> GraphModule:
+        r"""
+        Takes in a graph model and inserts the following observers:
+        - ModelReportObserver
+
+        Each observer is inserted based on the desired_reports into the relevant locations
+
+        Right now, each detector in self._desired_detector_names has its observers inserted independently
+            However, if a module already has an Observer of the same type, the insertion will not occur
+            This is because all Observers of the same type collect the same information, so another one would be redundant
+
+        Returns the same GraphModule with the observers inserted
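+
+        Example Use (an illustrative sketch only; assumes a ModelReport instance named
+        model_report that was initialized with a traceable GraphModule and detectors):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model_with_observers = model_report.prepare_detailed_calibration()
+            >>> # model_with_observers is the same GraphModule, now containing the
+            >>> # ModelReportObserver submodules chosen by the detectors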
+        """
+
+        # if already prepared once, cannot prepare again
+        if self._prepared_flag:
+            raise ValueError("Already ran preparing detailed callibration. Run the report generation next after callibration.")
+
+        # loop through each detector, find where placements should be, and keep track
+        insert_observers_fqns: Dict[str, Any] = {}
+
+        for detector in self._desired_report_detectors:
+            # determine observer points for each detector
+            obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
+            # map each insert point to the observer to use
+            insert_observers_fqns.update(obs_fqn_to_info)
+            # update the set of observers this report cares about
+            self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
+
+        # now insert all the observers at their desired locations
+        for observer_fqn in insert_observers_fqns:
+            target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
+            insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
+            insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
+            observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
+            self._insert_observer_around_module(
+                observer_fqn, target_node, insert_obs, observer_args, insert_post
+            )
+
+        self._prepared_flag = True
+
+        return self._model
+
+    def _insert_observer_around_module(
+        self,
+        obs_fqn: str,
+        target_node: torch.fx.node.Node,
+        obs_to_insert: ObserverBase,
+        observer_args: Tuple,
+        insert_post: bool
+    ):
+        r"""
+        Helper function that inserts the observer into both the graph structure and the module of the model
+
+        Args
+            obs_fqn (str): The fully qualified name of the observer we want to insert
+            target_node (torch.fx.node.Node): The node in model we are inserting observers around
+            obs_to_insert (ObserverBase): The observer we are inserting around target_node
+            observer_args (Tuple): The arguments we want to pass into the observer
+            insert_post (bool): whether this is meant to be a post observer for this node
+        """
+        # if we are inserting post, then our target node is the next node
+        if insert_post:
+            target_node = target_node.next
+
+        with self._model.graph.inserting_before(target_node):
+            self._model.add_submodule(obs_fqn, obs_to_insert)
+            self._model.graph.create_node(op="call_module", target=obs_fqn, args=observer_args)
+
+        # recompile model after inserts are made
+        self._model.recompile()
+
+    def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
+        r"""
+        Takes in a node fqn and returns the node based on the fqn
+
+        Args
+            node_fqn (str): The fully qualified name of the node we want to find in model
+
+        Returns the Node object for the given node_fqn; raises a ValueError if no matching node is found
+        """
+        node_to_return = None
+        for node in self._model.graph.nodes:
+            # if the target matches the fqn, it's the node we are looking for
+            if node.target == node_fqn:
+                node_to_return = node
+                break
+
+        if node_to_return is None:
+            raise ValueError("The node_fqn is was not found within the module.")
+
+        # assert for MyPy
+        assert isinstance(node_to_return, torch.fx.node.Node)
+
+        return node_to_return
+
+    def generate_model_report(
+        self, remove_inserted_observers: bool
+    ) -> Dict[str, Tuple[str, Dict]]:
+        r"""
+        Generates all the requested reports.
+
+        Note:
+            You should have calibrated the model with relevant data before calling this
+
+        The reports generated are determined by the detectors the ModelReport instance was initialized with
+
+        Can optionally remove all the observers inserted by the ModelReport instance
+
+        Args:
+            remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
+
+        Returns a mapping of each desired report name to a tuple with:
+            The textual summary of that report information
+            A dictionary containing relevant statistics or information for that report
+
+        Note:
+            Throws exception if we try to generate a report on a model we already removed observers from
+            Throws exception if we try to generate a report without preparing for calibration
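+
+        Example Use (an illustrative sketch only; assumes a prepared and calibrated
+        ModelReport instance named model_report):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> reports = model_report.generate_model_report(remove_inserted_observers=True)
+            >>> for report_name, (report_str, report_dict) in reports.items():
+            ...     print(report_name)
+            ...     print(report_str)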
+        """
+        # if we haven't prepped the model for calibration, then we shouldn't generate the report yet
+        if not self._prepared_flag:
+            raise Exception("Cannot generate report without preparing model for calibration")
+
+        # if we already removed the observers, we cannot generate report
+        if self._removed_observers:
+            raise Exception("Cannot generate report on model you already removed observers from")
+
+        # keep track of all the reports of interest and their outputs
+        reports_of_interest = {}
+
+        for detector in self._desired_report_detectors:
+            # generate the individual report for the detector
+            report_output = detector.generate_detector_report(self._model)
+            reports_of_interest[detector.get_detector_name()] = report_output
+
+        # if user wishes to remove inserted observers, go ahead and remove
+        if remove_inserted_observers:
+            self._removed_observers = True
+            # get the set of all Observers inserted by this instance of ModelReport
+            all_observers_of_interest: Set[str] = set()
+            for desired_report in self._detector_name_to_observer_fqns:
+                observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
+                all_observers_of_interest.update(observers_of_interest)
+
+            # go through all_observers_of_interest and remove them from the graph and model
+            for observer_fqn in all_observers_of_interest:
+                # remove the observer from the model
+                self._model.delete_submodule(observer_fqn)
+
+                # remove the observer from the graph structure
+                node_obj = self._get_node_from_fqn(observer_fqn)
+
+                if node_obj:
+                    self._model.graph.erase_node(node_obj)
+                else:
+                    raise ValueError("Node no longer exists in GraphModule structure")
+
+            # remember to recompile the model
+            self._model.recompile()
+
+        # save the generated reports for visualization purposes
+        saved_reports: Dict[str, Dict] = {
+            report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items()
+        }
+
+        self._generated_reports = saved_reports
+
+        # return the reports of interest
+        return reports_of_interest
+
+    def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
+        r"""
+        Takes in two dictionaries and ensures that any common keys between the two have the same
+        values.
+
+        Args:
+            info_dict_a (Dict): First dictionary we wish to compare
+            info_dict_b (Dict): Second dictionary we wish to compare
+
+        Returns True if all shared keys have the same values, False otherwise
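+
+        Example (an illustrative sketch only):
+            >>> # xdoctest: +SKIP("illustrative only")
+            >>> self._is_same_info_for_same_key({"a": 1, "b": 2}, {"a": 1, "c": 3})
+            True
+            >>> self._is_same_info_for_same_key({"a": 1}, {"a": 2})
+            False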
+        """
+        # get the set of keys for both
+        dict_a_keys: Set = set(info_dict_a.keys())
+        dict_b_keys: Set = set(info_dict_b.keys())
+
+        # get the intersection of the keys and check if the values match for both dicts
+        intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
+
+        for key in intersecting_keys:
+            dict_a_val = info_dict_a[key]
+            dict_b_val = info_dict_b[key]
+
+            # if it's a tensor we have to handle separately
+            if type(dict_a_val) == torch.Tensor:
+                # if dict_b_val not tensor, automatically false
+                if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0:
+                    return False
+            else:
+                # for non-tensor vals
+                if dict_a_val != dict_b_val:
+                    return False
+
+        # if no mismatched shared keys were found, return True
+        return True
+
+    def _reformat_reports_for_visualizer(self) -> OrderedDict:
+        r"""
+        Takes the generated reports and reformats them into the format that is desired by the
+        ModelReportVisualizer
+
+        Returns an OrderedDict mapping module_fqns to their features
+        """
+        # we want to reorder and reformat the information so it follows the order in which
+        # modules appear in the model
+
+        # first create new dict with all modules as keys and features under respective module
+        module_fqns_to_features: Dict[str, Dict] = {}
+
+        for report_name in self._generated_reports:
+            # get mod -> feature dict and go through
+            module_info = self._generated_reports[report_name]
+
+            for module_fqn in module_info:
+                # check if already in our accumulation dict
+                if module_fqn in module_fqns_to_features:
+                    # we merge all the features together
+                    new_info: Dict = module_info[module_fqn]
+                    present_info: Dict = module_fqns_to_features[module_fqn]
+
+                    # merge them together into the new unioned dict
+                    # same feature keys -> same info, so it is okay to override
+
+                    # do safety check to make sure shared keys have same info
+                    if self._is_same_info_for_same_key(new_info, present_info):
+                        module_fqns_to_features[module_fqn] = {**new_info, **present_info}
+                    else:
+                        error_str = "You have the same key with different values across detectors. "
+                        error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
+                        raise ValueError(error_str)
+                else:
+                    # we just set it
+                    module_fqns_to_features[module_fqn] = module_info[module_fqn]
+
+        # our ordered dict so that modules can be ordered in order of how they appear in model
+        features_by_module: OrderedDict[str, Dict] = OrderedDict()
+
+        # we loop through modules in graph in order
+        for fqn, module in self._model.named_modules():
+            # find that fqn in module_fqns_to_features
+            if fqn in module_fqns_to_features:
+                # add it to our ordered dict
+                features_by_module[fqn] = module_fqns_to_features[fqn]
+
+        # return the ordered dict of info we created
+        return features_by_module
+
+    def generate_visualizer(self) -> ModelReportVisualizer:
+        r"""
+        Generates a ModelReportVisualizer instance using the reports generated
+        by the generate_model_report() method.
+
+        Returns the generated ModelReportVisualizer instance initialized
+
+        Note:
+            Throws exception if we attempt to get a visualizer without generating a report first
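+
+        Example Use (an illustrative sketch only; assumes a ModelReport instance named
+        model_report on which generate_model_report() has already been called):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> visualizer = model_report.generate_visualizer()
+            >>> visualizer.generate_table_visualization()  # print the collected stats as tables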
+        """
+        # check if user has generated reports at least once
+        if len(self._generated_reports) == 0:
+            raise Exception("Unable to generate visualizers without first generating reports")
+
+        # get the ordered dict mapping modules to their full set of collected features / stats
+        module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
+
+        # create and return ModelReportVisualizer instance
+        visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features)
+
+        return visualizer
+
+    def _generate_qconfig_mapping_helper(
+        self,
+        detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo],
+        generation_function: Callable
+    ) -> QConfigMapping:
+        r"""
+        This helper takes in the compiled detector qconfig info that
+        has been compiled together and merges it into a QConfigMapping
+        """
+        # keep track of the qconfigmapping
+        qconfig_mapping = QConfigMapping()
+
+        # loop through each module / fqn and attempt to create QConfigMapping
+        for fqn, module in self._model.named_modules():
+            # if we have a qconfig info for this module
+            if fqn in detector_qconfig_info_combined:
+                qconfig_info_compiled = detector_qconfig_info_combined[fqn]
+
+                # now generate the qconfig and add it to the mapping
+                generated_qconfig = generation_function(qconfig_info_compiled, module)
+
+                # add to our config
+                qconfig_mapping.set_module_name(fqn, generated_qconfig)
+
+        # return compiled mapping
+        return qconfig_mapping
+
+    def _update_detector_quantization_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
+        r"""
+        Takes in the old and new information and updates the combined information.
+
+        Args:
+            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
+            new_info (DetectorQConfigInfo): The DetectorQConfigInfo containing the new information we are trying to merge
+                into combined_info
+        """
+        combined_info.is_activation_dynamic = combined_info.is_activation_dynamic or new_info.is_activation_dynamic
+        combined_info.is_weight_per_channel = combined_info.is_weight_per_channel or new_info.is_weight_per_channel
+
+    def _update_detector_equalization_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
+        r"""
+        Takes in the old and new information and updates the combined information.
+
+        Args:
+            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
+            new_info (DetectorQConfigInfo): The DetectorQConfigInfo containing the new information we are trying to merge
+                into combined_info
+        """
+        is_equalization_recommended = combined_info.is_equalization_recommended or new_info.is_equalization_recommended
+        combined_info.is_equalization_recommended = is_equalization_recommended
+
+    def _generate_module_fqn_to_detector_info_mapping(
+        self,
+        update_qconfig_info_function: Callable
+    ) -> Dict[str, DetectorQConfigInfo]:
+        r"""
+        Generates a mapping from module fqns to DetectorQConfigInfo objects based on
+        the suggestions of the ModelReport API. The generated mapping combines the
+        different types of feedback from the different detectors into one place.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Args:
+            update_qconfig_info_function (Callable): A function that takes in two DetectorQConfigInfo objects
+                and updates the one that is being compiled
+
+        Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
+
+        Note:
+            Throws exception if we try to generate a mapping on a model we already removed observers from
+            Throws exception if we try to generate a mapping without preparing for calibration
+        """
+        # if we haven't prepped the model for calibration, then we shouldn't generate the mapping yet
+        if not self._prepared_flag:
+            raise Exception("Cannot generate mapping without preparing model for calibration")
+
+        # if we already removed the observers, we cannot generate the mapping
+        if self._removed_observers:
+            raise Exception("Cannot generate mapping on model you already removed observers from")
+
+        # keep track of qconfig info for each module across detectors
+        detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {}
+
+        for detector in self._desired_report_detectors:
+            # get the info from the detector
+            detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(self._model)
+
+            # we go through the modules
+            for module_fqn in detector_info:
+                # see if we already have info on it
+                if module_fqn in detector_qconfig_info_combined:
+                    # we combine the current options with what is there
+                    current_options = detector_qconfig_info_combined[module_fqn]
+                    detector_options = detector_info[module_fqn]
+
+                    update_qconfig_info_function(current_options, detector_options)
+                else:
+                    # we just use this for now
+                    detector_qconfig_info_combined[module_fqn] = detector_info[module_fqn]
+
+        return detector_qconfig_info_combined
+
+    def generate_qconfig_mapping(self) -> QConfigMapping:
+        r"""
+        Generates a QConfigMapping based on the suggestions of the
+        ModelReport API. The generated mapping encompasses all the
+        different types of feedback from the different detectors
+        all into one place.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Returns a QConfigMapping for the quantization configuration
+
+        Note:
+            Throws exception if we try to generate a mapping on a model we already removed observers from
+            Throws exception if we try to generate a mapping without preparing for calibration
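+
+        Example Use (an illustrative sketch only; assumes a prepared and calibrated
+        ModelReport instance named model_report):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> qconfig_mapping = model_report.generate_qconfig_mapping()
+            >>> # the mapping can then be used, e.g., as the qconfig_mapping argument of prepare_fx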
+        """
+        # get the mapping info
+        detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
+            self._update_detector_quantization_qconfig_info
+        )
+
+        # we will do a bit of processing and remove fqns that don't have input weight recommended
+
+        # now we generate the QConfig for each of the options
+        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
+            detector_qconfig_info_combined,
+            self._quantization_config_generator
+        )
+
+        # return the generated mapping
+        return mapping
+
+    def _quantization_config_generator(self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module) -> QConfig:
+        r"""
+        Returns the quantization configuration generated by the DetectorQConfigInfo object
+        """
+        return detector_qconfig_info.generate_quantization_qconfig(module)
+
+    def _equalization_config_generator(
+        self,
+        detector_qconfig_info: DetectorQConfigInfo,
+        module: torch.nn.Module
+    ) -> EqualizationQConfig:
+        r"""
+        We ignore the module argument here, and only focus on the detector_qconfig_info
+
+        Returns the equalization configuration generated by the DetectorQConfigInfo object
+        """
+        return detector_qconfig_info.generate_equalization_qconfig()
+
+    def generate_equalization_mapping(self) -> QConfigMapping:
+        r"""
+        Generates a QConfigMapping based on the suggestions of the
+        ModelReport API for equalization. The generated mapping encompasses all the
+        different types of feedback from the input-weight equalization detector.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Returns a QConfigMapping for the equalization configuration
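+
+        Example Use (an illustrative sketch only; assumes a prepared and calibrated
+        ModelReport instance named model_report):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> equalization_mapping = model_report.generate_equalization_mapping()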
+        """
+        # get the mapping info
+        detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
+            self._update_detector_equalization_qconfig_info
+        )
+
+        # now we generate the QConfig for each of the options
+        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
+            detector_qconfig_info_combined,
+            self._equalization_config_generator
+        )
+
+        # return the generated mapping
+        return mapping
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbea160cc43229e21e713c69009bd8925ef69867
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py
@@ -0,0 +1,265 @@
+import torch
+from torch.ao.quantization.observer import ObserverBase
+
+
+class ModelReportObserver(ObserverBase):
+    r"""This observer is used to record additional information regarding keeping track
+    of S = average_batch_activation_range/epoch_activation_range.
+
+    The purpose of this information is to prepare a report to present to users on whether
+    Dynamic or Static Quantization is more appropriate for their model given the general
+    distributions of their data.
+
+    Args:
+        ch_axis (int, optional): The channel axis for which the range and outlier stats are computed
+            Default: 1
+        comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers
+            Should be between 0 and 1 exclusive
+            Default: 0.9
+
+    * :attr:`num_batches_tracked` specifies number of batches passed through the observer
+
+    * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through
+
+    * :attr:`epoch_activation_min` defines the minimum value passed through the observer
+
+    * :attr:`epoch_activation_max` defines the maximum value passed through the observer
+
+    * :attr:`ch_axis` defines the channel being used to compute per channel min max stats
+
+    * :attr:`min_val` defines the per channel minimum values passed through
+
+    * :attr:`max_val` defines the per channel maximum values passed through
+
+    * :attr:`comp_percentile` defines comparison percentile to find outliers
+
+    * :attr:`average_percentile_ratio` defines the per channel average percentile ratios
+
+    * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel
+
+    * :attr:`constant_channels` defines, per channel, the number of batches in which that channel was constant
+
+    Note: this tool is meant for FX Graph Mode Quantization
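+
+    Example (an illustrative sketch only; the input shapes here are arbitrary):
+        >>> # xdoctest: +SKIP
+        >>> obs = ModelReportObserver(ch_axis=1)
+        >>> for batch in [torch.randn(4, 3, 8), torch.randn(4, 3, 8)]:
+        ...     _ = obs(batch)
+        >>> ratio = obs.get_batch_to_epoch_ratio()  # S = average batch range / epoch range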
+    """
+
+    epoch_activation_min: torch.Tensor
+    epoch_activation_max: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
+    comp_percentile: torch.Tensor
+    average_percentile_ratio: torch.Tensor
+    percentile_batches_tracked: torch.Tensor
+    constant_channels: torch.Tensor
+
+    def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9):
+        super().__init__(torch.qint8)
+        self.num_batches_tracked = 0
+
+        # keep track of the min and max of the range for the average batch and the epoch as a whole
+        self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0))
+        self.register_buffer("epoch_activation_min", torch.tensor(float("inf")))
+        self.register_buffer("epoch_activation_max", torch.tensor(float("-inf")))
+
+        # keep track of per channel min max information using the given channel
+        self.ch_axis: int = ch_axis
+        self.register_buffer("min_val", torch.tensor([]))
+        self.register_buffer("max_val", torch.tensor([]))
+
+        # keep track of percentile ratio information per channel
+        self.register_buffer("comp_percentile", torch.tensor([comp_percentile]))
+        self.register_buffer("average_percentile_ratio", torch.tensor([]))
+        self.register_buffer("percentile_batches_tracked", torch.tensor([]))
+        self.register_buffer("constant_channels", torch.tensor([]))
+
+    def forward(self, x):
+        x_copy = x.detach()  # avoid keeping autograd tape
+        x_copy = x_copy.to(self.epoch_activation_min.dtype)
+
+        x_copy = self._calculate_range_stats(x_copy)
+        x_copy = self._calculate_min_max_stats(x_copy)
+        x_copy = self._calculate_percentile_stats(x_copy)
+
+        # return the passed-in value
+        return x
+
+    def _calculate_range_stats(self, x_copy):
+        r"""Calculates and stores range stats with forward values.
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
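+
+        Illustrative arithmetic (not from the source): if the running average range is 2.0
+        over 3 tracked batches and the new batch range is 4.0, the updated average is
+        (2.0 * 3 + 4.0) / 4 = 2.5.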
+        """
+        # get the min, max values of the data
+        min_val_cur, max_val_cur = torch.aminmax(x_copy)
+
+        # calculate new epoch range values
+        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
+        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)
+
+        self.epoch_activation_min.copy_(epoch_min_val)
+        self.epoch_activation_max.copy_(epoch_max_val)
+
+        # calculate the average batch activation range
+        current_batch_range = max_val_cur - min_val_cur
+        new_range = (
+            self.average_batch_activation_range * self.num_batches_tracked
+            + current_batch_range
+        ) / (self.num_batches_tracked + 1)
+
+        self.average_batch_activation_range = new_range
+        self.num_batches_tracked += 1  # new batch was processed
+
+        return x_copy
+
+    def _calculate_min_max_stats(self, x_copy):
+        r"""Calculates and stores the per_channel min, max stats with forward values.
+        Does calculation based on channel axis: self.ch_axis
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
+        """
+        # get the current min and max vals
+        min_val = self.min_val
+        max_val = self.max_val
+        x_dim = x_copy.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x_copy.permute(new_axis_list)
+        # Need to match dtype of min/max because the updates to buffers
+        # are done in place and types need to match for comparisons
+        y = y.to(self.min_val.dtype)
+        y = torch.flatten(y, start_dim=1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch.aminmax(y, dim=1)
+        else:
+            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+            min_val = torch.min(min_val_cur, min_val)
+            max_val = torch.max(max_val_cur, max_val)
+
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+
+        return x_copy
+
+    def _calculate_percentile_stats(self, x_copy):
+        r"""Calculates and stores the per_channel percentile stats with forward values.
+        Does calculation based on channel axis: self.ch_axis
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
+        """
+        # get the dimension of the copy
+        x_dim = x_copy.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x_copy.permute(new_axis_list)
+        # Need to match dtype of min/max because the updates to buffers
+        # are done in place and types need to match for comparisons
+        y = y.to(self.min_val.dtype)
+        y = torch.flatten(y, start_dim=1)
+        y = y.to(dtype=self.min_val.dtype, device="cpu")
+
+        # find the percentile values along the axis
+        # we want both the 100th percentile and comp_percentile
+        # we also want to find the 0th quantile to see if we have a constant channel
+        quantiles_list = [0, self.comp_percentile, 1.00]
+        quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype)
+
+        # find the quantiles
+        desired_quantiles = torch.quantile(y, quantiles_to_find, dim=self.ch_axis, interpolation="lower")
+        zero_quantile = desired_quantiles[0]
+        comp_quantile = desired_quantiles[1]
+        hundreth_quartile = desired_quantiles[2]
+
+        # if any of the channels have 0s, we ignore that channel for this calculation
+        any_non_zero_quantile_value: torch.Tensor = (comp_quantile != torch.tensor([0])) | (hundreth_quartile != torch.tensor([0]))
+        any_non_zero_quantile_value = any_non_zero_quantile_value.int()  # transform boolean values to int values
+
+        # we also check if we have a constant channel
+        any_constant_channels: torch.Tensor = (hundreth_quartile - zero_quantile) == torch.tensor([0])
+        any_constant_channels = any_constant_channels.int()  # transform boolean values to int values
+
+        # possibilities to get nan as an answer
+        #   will ignore any of these three cases with 0s and just not deal with them for now
+        # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative
+        # case (2) 0 in denominator: is possible unless case 3, we just ignore
+        # case (3) 0 in both: not outlier, channel just kinda useless, ignore
+
+        # get the ratio and get rid of nan values
+        quantile_ratios = hundreth_quartile / comp_quantile
+        quantile_ratios = torch.nan_to_num(quantile_ratios)
+        # update averages, remembering to only update if we didn't have zeros
+        ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios
+
+        # if num_batches and average_ratio are not initialized, we want to initialize them
+        if self.percentile_batches_tracked.shape[0] == 0 or self.average_percentile_ratio.shape[0] == 0:
+            self.percentile_batches_tracked = torch.zeros_like(any_non_zero_quantile_value)
+            self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero)
+
+        # also initialize the constant channel var if that is not initialized separately
+        if self.constant_channels.shape[0] == 0:
+            self.constant_channels = torch.zeros_like(any_constant_channels)
+
+        # get current num batches and average ratio
+        num_batches = self.percentile_batches_tracked
+        average_ratio = self.average_percentile_ratio
+
+        # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches
+        new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value
+        new_ratios: torch.Tensor = ((average_ratio * num_batches) + ratio_if_not_zero) / new_number_of_batches
+        new_ratios = torch.nan_to_num(new_ratios)
+
+        # update the per-channel count of batches in which the channel was constant
+        new_constant_count: torch.Tensor = self.constant_channels + any_constant_channels
+
+        # update the values locally
+        self.percentile_batches_tracked.copy_(new_number_of_batches)
+        self.average_percentile_ratio.copy_(new_ratios)
+        self.constant_channels.copy_(new_constant_count)
+
+        return x_copy
+
+    @torch.jit.export
+    def get_batch_to_epoch_ratio(self):
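+        r"""Returns S = average_batch_activation_range / epoch_activation_range.
+
+        Raises a ValueError if the epoch range is 0, or if no data has been run through
+        the observer yet (the epoch range is still infinite).
+        """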
+        epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min
+
+        if epoch_activation_range == torch.tensor(float(0)):
+            raise ValueError("Range for Epoch is 0")
+        elif epoch_activation_range == torch.tensor(float("inf")):
+            raise ValueError(
+                "No data has been run through observer or infinity value present"
+            )
+        else:
+            return self.average_batch_activation_range / epoch_activation_range
+
+    @torch.jit.export
+    def reset_batch_and_epoch_values(self):
+        # set all the values back to their original defaults for a new epoch
+        # keep device
+        device = self.max_val.device
+        self.num_batches_tracked = 0
+        self.average_batch_activation_range = torch.tensor(float(0), device=device)
+        self.epoch_activation_min = torch.tensor(float("inf"), device=device)
+        self.epoch_activation_max = torch.tensor(float("-inf"), device=device)
+        self.min_val = torch.tensor([], device=device)
+        self.max_val = torch.tensor([], device=device)
+        self.average_percentile_ratio = torch.tensor([], device=device)
+        self.percentile_batches_tracked = torch.tensor([], device=device)
+        self.constant_channels = torch.tensor([], device=device)
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        raise Exception(
+            "calculate_qparams should not be called for ModelReportObserver"
+        )
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e58463e59a979b24097f0e362a388ec441258048
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
@@ -0,0 +1,666 @@
+import torch
+from typing import Any, Set, Dict, List, Tuple, OrderedDict
+from collections import OrderedDict as OrdDict
+
+# try to import tabulate
+got_tabulate = True
+try:
+    from tabulate import tabulate
+except ImportError:
+    got_tabulate = False
+
+
+# var to see if we could import matplotlib
+got_matplotlib = True
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    got_matplotlib = False
+
+class ModelReportVisualizer:
+    r"""
+    The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics
+    that were generated by the ModelReport API. However, at a higher level, the class aims to provide
+    some level of visualization of statistics to PyTorch users in order to make it easier to parse data and
+    diagnose any potential issues with data or a specific model. With respect to the visualizations,
+    the ModelReportVisualizer class currently supports several methods of visualizing data.
+
+    Supported Visualization Methods Include:
+    - Table format
+    - Plot format (line graph)
+    - Histogram format
+
+    For all of the existing visualization methods, there is the option to filter data based on:
+    - A module fqn prefix
+    - Feature [required for the plot and histogram]
+
+    * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
+        Ensure that features that are the same across different reports have the same name
+        Ensure that objects representing the same features are the same type / dimension (where applicable)
+
+    Note:
+        Currently, the ModelReportVisualizer class supports visualization of data generated by the
+        ModelReport class. However, this structure is extensible and should allow the visualization of
+        other information as long as the information is structured in the following general format:
+
+        Report Structure
+        -- module_fqn [module with attached detectors]
+            |
+            -- feature keys [not every detector extracts same information]
+                                    [same collected info has same keys, unless can be specific to detector]
+
+
+    The goal behind the class is that the generated visualizations can be used in conjunction with the generated
+    report to give people a better understanding of issues and what the fix might be. It also aims to provide
+    a good visualization platform, since the dictionary returned by ModelReport can be hard to parse through as
+    it grows in size.
+
+    General Use Flow Expected
+    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
+    2.) Prepare your model with prepare_fx
+    3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
+    4.) Calibrate your model with data
+    5.) Call model_report.generate_model_report on your model to generate the report and optionally remove the added observers
+    6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
+    7.) Use the instance to view different views of the data as desired, applying filters as needed
+    8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
+
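+    Example (an illustrative sketch only; assumes a ModelReport instance named model_report
+    whose reports have already been generated):
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> visualizer = model_report.generate_visualizer()
+        >>> visualizer.generate_table_visualization(module_fqn_filter="block1")
+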
+    """
+
+    # keys for table dict
+    TABLE_TENSOR_KEY = "tensor_level_info"
+    TABLE_CHANNEL_KEY = "channel_level_info"
+
+    # Constants for header vals
+    NUM_NON_FEATURE_TENSOR_HEADERS = 2
+    NUM_NON_FEATURE_CHANNEL_HEADERS = 3
+
+    # Constants for row index in header
+    CHANNEL_NUM_INDEX = 2
+
+    def __init__(self, generated_reports: OrderedDict[str, Any]):
+        r"""
+        Initializes the ModelReportVisualizer instance with the necessary reports.
+
+        Args:
+            generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
+                can also be a dictionary generated in another manner, as long as the format is the same
+        """
+        self.generated_reports = generated_reports
+
+    def get_all_unique_module_fqns(self) -> Set[str]:
+        r"""
+        The purpose of this method is to provide a user the set of all module_fqns so that if
+        they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
+        they don't need to manually parse the generated_reports dictionary to get this information.
+
+        Returns all the unique module fqns present in the reports the ModelReportVisualizer
+        instance was initialized with.
+        """
+        # returns the keys of the ordered dict
+        return set(self.generated_reports.keys())
+
+    def get_all_unique_feature_names(self, plottable_features_only: bool = True) -> Set[str]:
+        r"""
+        The purpose of this method is to provide a user the set of all feature names so that if
+        they wish to use the filtering capabilities of the generate_table_view(), or use either of
+        the generate_plot_view() or generate_histogram_view(), they don't need to manually parse
+        the generated_reports dictionary to get this information.
+
+        Args:
+            plottable_features_only (bool): True if the user is only looking for plottable features,
+                False otherwise
+                plottable features are those that are tensor values
+                Default: True (only return those feature names that are plottable)
+
+        Returns all the unique feature names present in the reports the ModelReportVisualizer
+        instance was initialized with (optionally restricted to plottable features).
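+
+        Example Use (an illustrative sketch only; assumes an initialized visualizer named
+        mod_report_visualizer, as in the other examples in this class):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> plottable = mod_report_visualizer.get_all_unique_feature_names()
+            >>> all_features = mod_report_visualizer.get_all_unique_feature_names(plottable_features_only=False)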
+        """
+        unique_feature_names = set()
+        for module_fqn in self.generated_reports:
+            # get dict of the features
+            feature_dict: Dict[str, Any] = self.generated_reports[module_fqn]
+
+            # loop through features
+            for feature_name in feature_dict:
+                # if we need plottable, ensure type of val is tensor
+                if not plottable_features_only or type(feature_dict[feature_name]) == torch.Tensor:
+                    unique_feature_names.add(feature_name)
+
+        # return our compiled set of unique feature names
+        return unique_feature_names
+
+    def _get_filtered_data(self, feature_filter: str, module_fqn_filter: str) -> OrderedDict[str, Any]:
+        r"""
+        Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.
+
+        Args:
+            feature_filter (str): The feature filter, if we want to filter the set of data to only include
+                a certain set of features that include feature_filter
+                If feature = "", then we do not filter based on any features
+            module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
+                this prefix will be included
+                If module_fqn_filter = "" we do not filter based on module fqn, and include all modules
+
+        First, the data is filtered based on module_fqn, and then filtered based on feature
+        Returns an OrderedDict (sorted in order of model) mapping:
+            module_fqns -> feature_names -> values
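+
+        Example (an illustrative sketch only):
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> filtered = self._get_filtered_data("per_channel_min", "block1")
+            >>> # only modules whose fqn contains "block1" remain, and within them only
+            >>> # features whose name contains "per_channel_min"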
+        """
+        # create return dict
+        filtered_dict: OrderedDict[str, Any] = OrdDict()
+
+        for module_fqn in self.generated_reports:
+            # first filter based on module
+            if module_fqn_filter == "" or module_fqn_filter in module_fqn:
+                # create entry for module and loop through features
+                filtered_dict[module_fqn] = {}
+                module_reports = self.generated_reports[module_fqn]
+                for feature_name in module_reports:
+                    # check if filtering on features and do so if desired
+                    if feature_filter == "" or feature_filter in feature_name:
+                        filtered_dict[module_fqn][feature_name] = module_reports[feature_name]
+
+        # we have populated the filtered dict, and must return it
+
+        return filtered_dict
+
+    def _generate_tensor_table(
+        self,
+        filtered_data: OrderedDict[str, Dict[str, Any]],
+        tensor_features: List[str]
+    ) -> Tuple[List, List]:
+        r"""
+        Takes in the filtered data and features list and generates the tensor headers and table
+
+        Currently meant to generate the headers and table for the tensor-level information.
+
+        Args:
+            filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
+                module_fqns -> feature_names -> values
+            tensor_features (List[str]): A list of the tensor level features
+
+        Returns a tuple with:
+            A list of the headers of the tensor table
+            A list of lists containing the table information row by row
+            Each row of the table corresponds to one module (the headers are returned separately)
+        """
+        # now we compose the tensor information table
+        tensor_table: List[List[Any]] = []
+        tensor_headers: List[str] = []
+
+        # append the table row to the table only if we have features
+        if len(tensor_features) > 0:
+            # now we add all the data
+            for index, module_fqn in enumerate(filtered_data):
+                # we make a new row for the tensor table
+                tensor_table_row = [index, module_fqn]
+                for feature in tensor_features:
+                    # we iterate in same order of added features
+
+                    if feature in filtered_data[module_fqn]:
+                        # add value if applicable to module
+                        feature_val = filtered_data[module_fqn][feature]
+                    else:
+                        # add that it is not applicable
+                        feature_val = "Not Applicable"
+
+                    # if it's a tensor we want to extract val
+                    if isinstance(feature_val, torch.Tensor):
+                        feature_val = feature_val.item()
+
+                    # we add to our list of values
+                    tensor_table_row.append(feature_val)
+
+                tensor_table.append(tensor_table_row)
+
+        # add the row of headers if we actually have something, otherwise leave it empty
+        if len(tensor_table) != 0:
+            tensor_headers = ["idx", "layer_fqn"] + tensor_features
+
+        return (tensor_headers, tensor_table)
+
+    def _generate_channels_table(
+        self,
+        filtered_data: OrderedDict[str, Any],
+        channel_features: List[str],
+        num_channels: int
+    ) -> Tuple[List, List]:
+        r"""
+        Takes in the filtered data and features list and generates the channels headers and table
+
+        Currently meant to generate the headers and table for the channel-level information.
+
+        Args:
+            filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
+                module_fqns -> feature_names -> values
+            channel_features (List[str]): A list of the channel level features
+            num_channels (int): Number of channels in the channel data
+
+        Returns a tuple with:
+            A list of the headers of the channel table
+            A list of lists containing the table information row by row
+            Each row of the table corresponds to one channel of one module (the headers are returned separately)
+        """
+        # now we compose the table for the channel information table
+        channel_table: List[List[Any]] = []
+        channel_headers: List[str] = []
+
+        # counter to keep track of the number of entries in the table
+        channel_table_entry_counter: int = 0
+
+        if len(channel_features) > 0:
+            # now we add all channel data
+            for module_fqn in filtered_data:
+                # we iterate over all channels
+                for channel in range(num_channels):
+                    # we make a new row for the channel
+                    new_channel_row = [channel_table_entry_counter, module_fqn, channel]
+                    for feature in channel_features:
+                        if feature in filtered_data[module_fqn]:
+                            # add value if applicable to module
+                            feature_val = filtered_data[module_fqn][feature][channel]
+                        else:
+                            # add that it is not applicable
+                            feature_val = "Not Applicable"
+
+                        # if it's a tensor we want to extract val
+                        if type(feature_val) is torch.Tensor:
+                            feature_val = feature_val.item()
+
+                        # add value to channel specific row
+                        new_channel_row.append(feature_val)
+
+                    # add to table and increment row index counter
+                    channel_table.append(new_channel_row)
+                    channel_table_entry_counter += 1
+
+        # add the row of headers if we actually have something, otherwise leave it empty
+        if len(channel_table) != 0:
+            channel_headers = ["idx", "layer_fqn", "channel"] + channel_features
+
+        return (channel_headers, channel_table)
+
+    def generate_filtered_tables(self, feature_filter: str = "", module_fqn_filter: str = "") -> Dict[str, Tuple[List, List]]:
+        r"""
+        Takes in optional filter values and generates two tables with desired information.
+
+        The generated tables are presented in a list-of-lists format
+
+        The reason for the two tables is that they handle different things:
+        1.) the first table handles all tensor level information
+        2.) the second table handles and displays all channel based information
+
+        The reasoning for this is that having all the info in one table can make it ambiguous which collected
+            statistics are global, and which are actually per-channel, so it's better to split it up into two
+            tables. This also makes the information much easier to digest given the plethora of statistics collected
+
+        Tensor table columns:
+            idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
+            ----  ---------  ---------   ---------   ---------        ---------
+
+        Per-Channel table columns:
+            idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
+            ----  ---------  -------  ---------   ---------   ---------        ---------
+
+        Args:
+            feature_filter (str, optional): Filters the features presented to only those that
+                contain this filter substring
+                Default = "", results in all the features being printed
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+
+        Returns:
+            (Dict[str, Tuple[List, List]]) A dict containing two keys:
+            "tensor_level_info", "channel_level_info"
+                Each key maps to a tuple with:
+                    A list of the headers of that table
+                    A list of lists containing the table information row by row
+                    (each row corresponds to one module or one channel; the headers are returned separately)
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_filtered_tables(
+            ...     feature_filter = "per_channel_min",
+            ...     module_fqn_filter = "block1"
+            ... ) # generates table with per_channel_min info for all modules in block 1 of the model
+        """
+        # first get the filtered data
+        filtered_data: OrderedDict[str, Any] = self._get_filtered_data(feature_filter, module_fqn_filter)
+
+        # now we split into tensor and per-channel data
+        tensor_features: Set[str] = set()
+        channel_features: Set[str] = set()
+
+        # keep track of the number of channels we have
+        num_channels: int = 0
+
+        for module_fqn in filtered_data:
+            for feature_name in filtered_data[module_fqn]:
+                # get the data for that specific feature
+                feature_data = filtered_data[module_fqn][feature_name]
+
+                # check if not zero dim tensor
+                is_tensor: bool = isinstance(feature_data, torch.Tensor)
+                is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0
+
+                if is_not_zero_dim or isinstance(feature_data, list):
+                    # a non-zero-dim tensor or a list means the feature is per-channel
+                    channel_features.add(feature_name)
+                    num_channels = len(feature_data)
+                else:
+                    # means is per-tensor
+                    tensor_features.add(feature_name)
+
+        # we make them lists for iteration purposes
+        tensor_features_list: List[str] = sorted(tensor_features)
+        channel_features_list: List[str] = sorted(channel_features)
+
+        # get the tensor info
+        tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list)
+
+        # get the channel info
+        channel_headers, channel_table = self._generate_channels_table(
+            filtered_data, channel_features_list, num_channels
+        )
+
+        # let's now create the dictionary to return
+        table_dict = {
+            self.TABLE_TENSOR_KEY : (tensor_headers, tensor_table),
+            self.TABLE_CHANNEL_KEY : (channel_headers, channel_table)
+        }
+
+        # return the two tables
+        return table_dict
+
+    def generate_table_visualization(self, feature_filter: str = "", module_fqn_filter: str = ""):
+        r"""
+        Takes in optional filter values and prints out formatted tables of the information.
+
+        The reason for the two tables printed out instead of one large one is that they handle different things:
+        1.) the first table handles all tensor level information
+        2.) the second table handles and displays all channel based information
+
+        The reasoning for this is that having all the info in one table can make it ambiguous which collected
+            statistics are global, and which are actually per-channel, so it's better to split it up into two
+            tables. This also makes the information much easier to digest given the plethora of statistics collected
+
+        Tensor table columns:
+         idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
+        ----  ---------  ---------   ---------   ---------        ---------
+
+        Per-Channel table columns:
+
+         idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
+        ----  ---------  -------  ---------   ---------   ---------        ---------
+
+        Args:
+            feature_filter (str, optional): Filters the features presented to only those that
+                contain this filter substring
+                Default = "", results in all the features being printed
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_table_visualization(
+            ...     feature_filter = "per_channel_min",
+            ...     module_fqn_filter = "block1"
+            ... )
+            >>> # prints out neatly formatted table with per_channel_min info
+            >>> # for all modules in block 1 of the model
+        """
+        # see if we got tabulate
+        if not got_tabulate:
+            print("Make sure to install tabulate and try again.")
+            return None
+
+        # get the table dict and the specific tables of interest
+        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
+        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
+        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
+
+        # get the table string and print it out
+        # now we have populated the tables for each one
+        # let's create the strings to be returned
+        table_str = ""
+        # the tables will have some header columns that are non-feature
+        # e.g. table index, module name, channel index, etc.
+        # we want to look at the feature header columns, which come after those non-feature headers
+        if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
+            # if we have at least one tensor level feature to be added we add tensor table
+            table_str += "Tensor Level Information \n"
+            table_str += tabulate(tensor_table, headers=tensor_headers)
+        if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
+            # if we have at least one channel level feature to be added we add the channel table
+            table_str += "\n\n Channel Level Information \n"
+            table_str += tabulate(channel_table, headers=channel_headers)
+
+        # if no features at all, let user know
+        if table_str == "":
+            table_str = "No data points to generate table with."
+
+        print(table_str)
+
+    def _get_plottable_data(self, feature_filter: str, module_fqn_filter: str) -> Tuple[List, List[List], bool]:
+        r"""
+        Takes in the feature filters and module filters and outputs the x and y data for plotting
+
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str): Only includes modules that contains this string
+
+        Returns a tuple of three elements
+            The first is a list containing the relevant x-axis data
+            The second is a list containing the corresponding y-axis data
+            The third is a bool indicating whether the data is per-channel
+        """
+        # get the table dict and the specific tables of interest
+        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
+        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
+        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
+
+        # make sure it is only 1 feature that is being plotted
+        # get the number of features in each of these
+        tensor_info_features_count = len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
+        channel_info_features_count = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
+
+        # see if valid tensor or channel plot
+        is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
+        is_valid_per_channel_plot: bool = channel_info_features_count == 1
+
+        # the feature column offset and table default to the tensor table and are overridden below for per-channel plots
+        feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
+        table = tensor_table
+
+        # if a per_channel plot, we have different offset and table
+        if is_valid_per_channel_plot:
+            feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
+            table = channel_table
+
+        x_data: List = []
+        y_data: List[List] = []
+        # the feature will either be a tensor feature or channel feature
+        if is_valid_per_tensor_plot:
+            for table_row_num, row in enumerate(table):
+                # get x_value to append
+                x_val_to_append = table_row_num
+                # the index of the feature will be 0 + the number of non-feature columns
+                tensor_feature_index = feature_column_offset
+                row_value = row[tensor_feature_index]
+                if not isinstance(row_value, str):
+                    x_data.append(x_val_to_append)
+                    y_data.append(row_value)
+        elif is_valid_per_channel_plot:
+            # gather the x_data and multiple y_data
+            # calculate the number of channels
+            num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
+            for channel in range(num_channels):
+                y_data.append([])  # separate data list per channel
+
+            for table_row_num, row in enumerate(table):
+                # get the x value (the module index) and the channel for this row
+                current_channel = row[self.CHANNEL_NUM_INDEX]  # the channel this row belongs to
+                new_module_index: int = table_row_num // num_channels
+                x_val_to_append = new_module_index
+
+                # the index of the feature will be 0 + the number of non-feature columns
+                tensor_feature_index = feature_column_offset
+                row_value = row[tensor_feature_index]
+                if not isinstance(row_value, str):
+                    # only append if this is a new x index
+                    if len(x_data) == 0 or x_data[-1] != x_val_to_append:
+                        x_data.append(x_val_to_append)
+
+                    # append value for that channel
+                    y_data[current_channel].append(row_value)
+        else:
+            # more than one feature was chosen
+            error_str = "Make sure to pick only a single feature with your filter to plot a graph."
+            error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
+            error_str += " Pick one of those features to plot."
+            raise ValueError(error_str)
+
+        # return x, y values, and if data is per-channel
+        return (x_data, y_data, is_valid_per_channel_plot)
+
+    def generate_plot_visualization(self, feature_filter: str, module_fqn_filter: str = ""):
+        r"""
+        Takes in a feature and optional module_filter and plots the desired data.
+
+        For per channel features, it averages the value across the channels and plots a point
+        per module. The reason for this is that for models with hundreds of channels, it can
+        be hard to differentiate one channel line from another, and so the point of generating
+        a single average point per module is to give a sense of general trends that encourage
+        further deep dives.
+
+        Note:
+            Only features in the report that have tensor value data are plottable by this class
+            When the tensor information is plotted, it will plot:
+                idx as the x val, feature value as the y_val
+            When the channel information is plotted, it will plot:
+                the first idx of each module as the x val, feature value as the y_val [for each channel]
+                The reason for this is that we want to be able to compare values across the
+                channels for the same layer, which would be hard if the values were staggered by idx.
+                This means each module is represented by only 1 x value
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports being visible in the table
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_plot_visualization(
+            ...     feature_filter = "per_channel_min",
+            ...     module_fqn_filter = "block1"
+            ... )
+            >>> # outputs line plot of per_channel_min information for all
+            >>> # modules in block1 of the model; each channel gets its own line,
+            >>> # and it is plotted across the in-order modules on the x-axis
+        """
+        # check if we have matplotlib and let the user know to install it if we don't
+        if not got_matplotlib:
+            print("Make sure to install matplotlib and try again.")
+            return None
+
+        # get the x and y data and if per channel
+        x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
+
+        # plot based on whether data is per channel or not
+        ax = plt.subplot()
+        ax.set_ylabel(feature_filter)
+        ax.set_title(feature_filter + " Plot")
+        plt.xticks(x_data)  # only show ticks for actual points
+
+        if data_per_channel:
+            ax.set_xlabel("First idx of module")
+            # set the legend as well
+            # plot a single line that is the average of the channel values
+            num_modules = len(y_data[0])  # every channel list has one value per module
+            num_channels = len(y_data)  # number of channels, needed to compute the average
+
+            # average across channels for each module index
+            avg_vals = [
+                sum(channel_vals[index] for channel_vals in y_data) / num_channels
+                for index in range(num_modules)
+            ]
+
+            # plot the averaged line
+            ax.plot(x_data, avg_vals, label=f"Average Value Across {num_channels} Channels")
+            ax.legend(loc='upper right')
+        else:
+            ax.set_xlabel("idx")
+            ax.plot(x_data, y_data)
+
+        # actually show the plot
+        plt.show()
+
+    def generate_histogram_visualization(self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10):
+        r"""
+        Takes in a feature and optional module_filter and plots the histogram of desired data.
+
+        Note:
+            Only features in the report that have tensor value data can be viewed as a histogram
+            If you want to plot a histogram from all the channel values of a specific feature for
+                a specific model, make sure to specify both the model and the feature properly
+                in the filters and you should be able to see a distribution of the channel data
+
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports being visible in the table
+            num_bins (int, optional): The number of bins to create the histogram with
+                Default = 10, the values will be split into 10 equal sized bins
+
+        Example Use:
+            >>> # xdoctest: +SKIP
+            >>> mod_report_visualizer.generate_histogram_visualization(
+            ...     feature_filter = "per_channel_min",
+            ...     module_fqn_filter = "block1"
+            ... )
+            >>> # outputs histogram of per_channel_min information for all modules in block1 of the model
+            >>> # information is gathered across all channels for all modules in block1 for the
+            >>> # per_channel_min and is displayed in a histogram of equally sized bins
+        """
+        # check if we have matplotlib and let the user know to install it if we don't
+        if not got_matplotlib:
+            print("Make sure to install matplotlib and try again.")
+            return None
+
+        # get the x and y data and if per channel
+        x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
+
+        # for histogram, we just care about plotting the y data
+        # plot based on whether data is per channel or not
+        ax = plt.subplot()
+        ax.set_xlabel(feature_filter)
+        ax.set_ylabel("Frequency")
+        ax.set_title(feature_filter + " Histogram")
+
+        if data_per_channel:
+            # set the legend as well
+            # combine all the data
+            all_data = []
+            for channel_info in y_data:
+                all_data.extend(channel_info)
+
+            val, bins, _ = plt.hist(
+                all_data,
+                bins=num_bins,
+                stacked=True,
+                rwidth=0.8,
+            )
+            plt.xticks(bins)
+        else:
+            val, bins, _ = plt.hist(
+                y_data,
+                bins=num_bins,
+                stacked=False,
+                rwidth=0.8,
+            )
+            plt.xticks(bins)
+
+        plt.show()
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/convert.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..028c4a94186939afe64bfc1d904a2459da9fc80c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/convert.py
@@ -0,0 +1,1131 @@
+# mypy: ignore-errors
+
+from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
+from torch.ao.quantization.quant_type import QuantType
+import torch
+import copy
+import warnings
+from torch.fx import (
+    GraphModule,
+)
+from torch.fx.graph import (
+    Graph,
+    Node,
+    Argument,
+)
+from ..utils import (
+    activation_is_statically_quantized,
+    weight_is_quantized,
+    get_qparam_dict,
+    _parent_name,
+    get_swapped_custom_module_class,
+)
+from ..qconfig import (
+    QConfigAny,
+    qconfig_equals
+)
+from ..qconfig_mapping import QConfigMapping
+from .qconfig_mapping_utils import (
+    _generate_node_name_to_qconfig,
+    _compare_prepare_convert_qconfig_mappings,
+    _update_qconfig_for_fusion,
+    _is_qconfig_supported_by_dtype_configs,
+    _update_qconfig_for_qat,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_root_module_to_quantized_reference_module,
+    get_pattern_to_dtype_configs,
+    get_fused_module_classes,
+    get_qat_module_classes,
+)
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    get_native_backend_config,
+)
+from torch.ao.quantization.observer import _is_activation_post_process
+from .graph_module import (
+    _is_observed_module,
+    _is_observed_standalone_module,
+)
+from ._equalize import update_obs_for_equalization, convert_eq_obs
+from torch.nn.utils.parametrize import type_before_parametrizations
+from .utils import (
+    _get_module,
+    _is_custom_module_lstm,
+    _is_custom_module_mha,
+    assert_and_get_unique_device,
+    get_custom_module_class_keys,
+    create_getattr_from_value,
+    collect_producer_nodes,
+    graph_module_from_producer_nodes,
+    node_arg_is_weight,
+)
+from torch.ao.quantization.utils import (
+    is_per_channel,
+    to_underlying_dtype,
+)
+from torch.ao.quantization.quantize import (
+    _remove_qconfig,
+)
+from torch.ao.quantization.stubs import DeQuantStub
+from .custom_config import (
+    ConvertCustomConfig,
+    PrepareCustomConfig,
+)
+from .lower_to_fbgemm import lower_to_fbgemm
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401
+import operator
+
+__all__ = [
+    "convert",
+    "convert_custom_module",
+    "convert_standalone_module",
+    "convert_weighted_module",
+]
+
+_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
+    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
+    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
+}
+
+def _replace_observer_with_quantize_dequantize_node_decomposed(
+        model: torch.fx.GraphModule,
+        node: Node,
+        modules: Dict[str, torch.nn.Module],
+        node_name_to_scope: Dict[str, Tuple[str, type]],
+        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
+    """ Replace activation_post_process module call node with quantize and
+    dequantize node working with decomposed Tensor
+
+    Before:
+    ... -> observer_0(x) -> ...
+    After:
+    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
+    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
+
+    or quantize_per_channel and dequantize_per_channel
+    """
+    graph = model.graph
+    assert modules is not None
+    assert isinstance(node.target, str)
+    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
+    activation_post_process = modules[node.target]
+    if hasattr(activation_post_process, "convert"):
+        activation_post_process.convert(model, node)
+        return
+    # skip replacing observers to quant/dequant nodes if the qconfigs of all
+    # consumers and producers of this observer are None
+    skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
+                           list(node.args) + list(node.users.keys()))
+    if skip_replacement or not _is_conversion_supported(activation_post_process):
+        # didn't find corresponding quantize op and info for the activation_post_process
+        # so we just remove the observer
+        with graph.inserting_before(node):
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+        return
+
+    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
+
+    # 1. extract the information from activation_post_process module for generating
+    # the quantize and dequantize operator
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]
+
+    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
+            (not is_dynamic):
+        # TODO: probably should cleanup this condition check, it's hard
+        # to reason about this if and the following elif
+
+        # uint8/int8/int32 static quantization branch
+
+        # 1. extract information for inserting q/dq node from activation_post_process
+        node_type = "call_function"
+        quantize_op : Optional[Callable] = None
+        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
+            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
+            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
+            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
+            quant_min = activation_post_process.quant_min
+            quant_max = activation_post_process.quant_max
+            dtype_ = to_underlying_dtype(dtype)
+            qparams = {
+                "_scale_": scale,
+                "_zero_point_": zero_point,
+                "_axis_": ch_axis,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype_
+            }
+        else:
+            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            scale = float(scale)
+            zero_point = int(zero_point)
+            quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+            quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+            dtype_ = to_underlying_dtype(dtype)
+            qparams = {
+                "_scale_": scale,
+                "_zero_point_": zero_point,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype_
+            }
+
+        # 2. replace activation_post_process node with quantize and dequantize
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
+                    # For scale and zero_point values we register them as buffers in the root module.
+                    # Note that when the values are not tensors, as in the case of
+                    # per_tensor quantization, they will be treated as literals.
+                    # However, registering them as a node seems to cause issues with dynamo
+                    # tracing, where it may consider the tensor overload as opposed to the default one.
+                    # The extra check that scale and zero_point are scalars makes
+                    # sure that the default overload can be used.
+                    # TODO: maybe need more complex attr name here
+                    qparam_node = create_getattr_from_value(
+                        model, graph, module_path + prefix + key, value_or_node)
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            # use the same qparams from quantize op
+            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+            dequantized_node = graph.call_function(
+                dequantize_op,
+                tuple(dq_inputs),
+                {}
+            )
+
+            def remap_fn(x):
+                return dequantized_node if x is node else x
+
+            # remap numeric_debug_handle
+            for user_node in node.users:
+                if "numeric_debug_handle" in user_node.meta:
+                    numeric_debug_handle = user_node.meta["numeric_debug_handle"]
+                    user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif is_dynamic:
+
+        # uint8/int8/fp16 dynamic quantization
+
+        # 1. extract information for inserting q/dq node from activation_post_process
+        node_type = "call_function"
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
+        # we only use choose_qparams for is_decomposed now,
+        # but we should probably align the non-decomposed path with this as well,
+        # and that can be done after we remove reduce_range flag
+        # 1. extract qparams from activation_post_process module
+        dtype_ = to_underlying_dtype(dtype)
+        assert dtype_ in [torch.uint8, torch.int8], \
+            "only uint8 and int8 are supported in reference flow for " \
+            "dynamic quantization right now"
+        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)  # type: ignore[attr-defined]
+        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)  # type: ignore[attr-defined]
+        # note: scale and zero_point are missing for quantize_per_tensor op
+        # we'll need to get this from choose_qparams op, which we'll add after
+        # this step
+        qparams = {
+            "_quant_min_": quant_min,
+            "_quant_max_": quant_max,
+            "_eps_": eps,
+            "_dtype_": dtype_
+        }
+
+        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
+        # 2. insert choose_qparams op and update the qparams list
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            choose_qparams_op_inputs = [node.args[0]]
+            for key, value in qparams.items():
+                # we have quant_min, quant_max and dtype, all should be stored
+                # as literals
+                choose_qparams_op_inputs.append(value)
+            choose_qparams_node = graph.create_node(
+                "call_function",
+                choose_qparams_op,
+                tuple(choose_qparams_op_inputs),
+                {}
+            )
+            # choose_qparams returns (scale, zero_point)
+            scale_node = graph.create_node(
+                "call_function",
+                operator.getitem,
+                (choose_qparams_node, 0),
+                {}
+            )
+            zero_point_node = graph.create_node(
+                "call_function",
+                operator.getitem,
+                (choose_qparams_node, 1),
+                {}
+            )
+            quant_min = qparams["_quant_min_"]
+            quant_max = qparams["_quant_max_"]
+            dtype = qparams["_dtype_"]
+            qparams = {
+                "_scale_": scale_node,
+                "_zero_point_": zero_point_node,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype
+            }
+
+        # 3. replace activation_post_process node to quantize and dequantize node
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ['_scale_', '_zero_point_']:
+                    # in this case we have a node in the graph since it's dynamically
+                    # computed from the input, with choose_qparams op
+                    qparam_node = value_or_node
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we
+                    # store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            # use the same qparams from quantize op
+            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+            # need to use the tensor variant of this op, since scale and zero_point
+            # from choose_qparams are Tensors, instead of float/int; this is to
+            # prevent these nodes from being traced away by downstream systems
+            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
+            dequantized_node = graph.call_function(
+                dequantize_op,
+                tuple(dq_inputs),
+                {}
+            )
+
+            def remap_fn(x):
+                return dequantized_node if x is node else x
+
+            # remap numeric_debug_handle
+            for user_node in node.users:
+                if "numeric_debug_handle" in user_node.meta:
+                    numeric_debug_handle = user_node.meta["numeric_debug_handle"]
+                    user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif dtype == torch.float16:
+        raise NotImplementedError("decomposed to float16 op not implemented yet")
+
+    # should not reach since we have checks in the beginning to make sure the
+    # activation_post_process is supported
+
+def _replace_observer_with_quantize_dequantize_node(
+        model: torch.fx.GraphModule,
+        node: Node,
+        modules: Dict[str, torch.nn.Module],
+        node_name_to_scope: Dict[str, Tuple[str, type]],
+        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
+    """ Replace activation_post_process module call node with quantize and
+    dequantize node
+
+    Before:
+    ... -> observer_0(x) -> ...
+    After:
+    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
+    """
+    assert modules is not None
+    assert isinstance(node.target, str)
+    graph = model.graph
+    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
+    activation_post_process = modules[node.target]
+    # skip replacing observers to quant/dequant nodes if the qconfigs of all
+    # consumers and producers of this observer are None
+    skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
+                           list(node.args) + list(node.users.keys()))
+    if skip_replacement or not _is_conversion_supported(activation_post_process):
+        # didn't find corresponding quantize op and info for the activation_post_process
+        # so we just remove the observer
+        with graph.inserting_before(node):
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+        return
+
+    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
+
+    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
+            (not is_dynamic):
+        # TODO: probably should cleanup this condition check, it's hard
+        # to reason about this if and the following elif
+
+        # uint8/int8/int32 static quantization branch
+
+        # 1. extract the information from activation_post_process module for generating
+        # the quantize and dequantize operator
+        node_type = "call_function"
+        quantize_op : Optional[Callable] = None
+        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
+            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
+            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
+            quantize_op = torch.quantize_per_channel
+        else:
+            scale = float(scale)
+            zero_point = int(zero_point)
+            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
+            quantize_op = torch.quantize_per_tensor
+
+        # 2. replace activation_post_process node with quantize and dequantize
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ['_scale_', '_zero_point_']:
+                    # For scale and zero_point values we register them as buffers in the root module.
+                    # TODO: maybe need more complex attr name here
+                    qparam_node = create_getattr_from_value(
+                        model, graph, module_path + prefix + key, value_or_node)
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif is_dynamic:
+
+        # uint8/int8/fp16 dynamic quantization branch
+
+        node_type = "call_function"
+        quantize_op = torch.quantize_per_tensor_dynamic
+        # TODO: get reduce range from observer
+        # reduce_range = activation_post_process.reduce_range
+        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
+        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}
+
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value in qparams.items():
+                quantize_op_inputs.append(value)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif dtype == torch.float16:
+        node_type = "call_method"
+        quantize_op = "to"  # type: ignore[assignment]
+        qparams = {"_dtype_": dtype}
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                quantize_op_inputs.append(value)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+
+    # should not reach since we have checks in the beginning to make sure the
+    # activation_post_process is supported
+
+# this is a temporary hack for custom module, we may want to implement
+# this properly after the custom module class design is finalized
+# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
+# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
+# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
+def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
+    call_custom_module_node = node.args[0]
+    assert isinstance(call_custom_module_node, Node), \
+        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
+    node.replace_all_uses_with(call_custom_module_node)
+    graph.erase_node(node)
+    _insert_dequantize_node(call_custom_module_node, graph)
+
+def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
+
+    return (
+        (dtype in [
+            torch.quint8,
+            torch.qint8,
+            torch.qint32,
+            torch.uint8,
+            torch.int8,
+            torch.int16,
+            torch.int32
+        ] and (not is_dynamic)) or  # type: ignore[return-value]
+        is_dynamic or
+        dtype == torch.float16
+    )
+
+def _has_none_qconfig(node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]) -> bool:
+    """ Check if a node has a qconfig of None, i.e. user requested to not quantize
+    the node
+    """
+    return isinstance(node, Node) and node.name in node_name_to_qconfig and node_name_to_qconfig[node.name] is None
+
+def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
+    """ Extract the subgraph that produces the weight for dynamic quant
+    or weight only quant node and run the subgraph to observe the weight.
+    Note that the observers of dynamic quant or weight only quant ops are
+    run during the convert step.
+    """
+    for node in observed.graph.nodes:
+        if node.op != "call_function":
+            continue
+        for node_arg in node.args:
+            # node_arg is weight
+            if node_arg and node_arg_is_weight(node, node_arg):
+                weight_observer_nodes = collect_producer_nodes(node_arg)
+                if weight_observer_nodes is None:
+                    continue
+                weight_observer_module = \
+                    graph_module_from_producer_nodes(
+                        observed, weight_observer_nodes)
+                # run the weight observer
+                weight_observer_module()
+
+def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
+    """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
+    we'll recursively remove the dequantize Node
+    """
+    if isinstance(arg, Node) and \
+       arg.op == "call_method" and \
+       arg.target == "dequantize":
+        quantize_node = arg.args[0]
+        # we only replace the specific use since dequantize could be used by other nodes
+        # as well
+        node.replace_input_with(arg, quantize_node)
+    elif isinstance(arg, (list, tuple)):
+        for arg_element in arg:
+            _maybe_recursive_remove_dequantize(arg_element, node, graph)
+    elif isinstance(arg, dict):
+        for arg_element in arg.values():
+            _maybe_recursive_remove_dequantize(arg_element, node, graph)
+    else:
+        warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
+
+def _get_module_path_and_prefix(
+        obs_node: Node,
+        node_name_to_scope: Dict[str, Tuple[str, type]],
+        node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
+    """ Given and observer node, get the `Scope` or the fully qualified name for
+    the submodule containing the observed node, also return a prefix of "_input"
+    when the observed node is an input of a F.linear op, and not the output of another
+    quantized op.
+    TODO: this logic is hacky, we should think about how to remove it or make it more
+    general
+    """
+    observed_node = obs_node.args[0]
+    # an observer can be inserted for both input of the next operator or output of the previous
+    # operator (they can be the same)
+    # this flag identifies if the observer is inserted only because the observed node is
+    # the input of the next operator
+    assert isinstance(observed_node, Node), \
+        f"Expecting observed node to be a Node, but got {observed_node}"
+    is_input_observer_only = node_name_to_qconfig[observed_node.name] is None \
+        if observed_node.name in node_name_to_qconfig else None
+    if is_input_observer_only:
+        # if the quantize function is at the input of op, then we find the first user of the observer_node
+        # to get the path. If a linear call_function is in the user list, we return the first instance
+        # of linear node to get the FQN.
+        users = list(obs_node.users)
+        first_linear_use_or_first_use = users[0] if users else None
+        linear_node = None
+        for n in users:
+            if n.op == "call_function" and n.target == torch.nn.functional.linear:
+                linear_node = n
+                break
+        if linear_node:
+            first_linear_use_or_first_use = linear_node
+        prefix = "_input"
+    else:
+        # if the quantize function is at the output of the op, we use the observer input node to get the path
+        first_linear_use_or_first_use = observed_node
+        prefix = ""
+
+    if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
+        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
+    else:
+        # TODO: it's not used, so actually we can skip quantization
+        # but this requires changing return type of quantize_node
+        # we can fix it later if needed
+        module_path = ""
+    return module_path, prefix
+
+def _insert_dequantize_node(
+        node: Node,
+        graph: Graph) -> None:
+    """ Inserts dequantize node for `node` in `graph`
+    """
+    with graph.inserting_after(node):
+        dequantize_node = graph.call_method("dequantize", (node,))
+        for user_node in dict(node.users):
+            if user_node is not dequantize_node:
+                user_node.replace_input_with(node, dequantize_node)
+
+def _maybe_get_observer_for_node(
+        node: Node,
+        modules: Dict[str, torch.nn.Module]
+) -> Optional[torch.nn.Module]:
+    """
+    If the node is observed, return the observer
+    instance. Otherwise, return None.
+    """
+    for maybe_obs_node in node.users.keys():
+        if maybe_obs_node.op == 'call_module':
+            maybe_obs = modules[str(maybe_obs_node.target)]
+            if _is_activation_post_process(maybe_obs):
+                return maybe_obs
+    return None
+
+def convert_standalone_module(
+        node: Node,
+        modules: Dict[str, torch.nn.Module],
+        model: torch.fx.GraphModule,
+        is_reference: bool,
+        backend_config: Optional[BackendConfig]) -> None:
+    """ Converts a observed standalone module to a quantized standalone module by calling
+    the fx convert api, currently using the same `is_reference` flag as parent, but we may
+    changing this behavior in the future (e.g. separating quantization and lowering for
+    standalone module as well)
+
+    Args:
+      - node: The call_module node of the observed standalone module
+      - modules: named_module of original model
+      - model: original model
+      - is_reference: a flag from parent provided by user to decide if we want to
+        produce a reference model or a fbgemm/qnnpack model
+      - backend_config: backend configuration of the target backend of quantization
+    """
+    # TODO: remove is_reference flag
+    if is_reference:
+        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
+    else:
+        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
+    # We know that observed standalone module is a GraphModule since
+    # it's produced by us
+    observed_standalone_module : GraphModule = modules[str(node.target)]  # type: ignore[assignment]
+    sm_input_quantized_idxs = \
+        observed_standalone_module \
+        .meta["_observed_graph_module_attrs"].standalone_module_input_quantized_idxs
+    # remove the dequantize nodes for inputs
+    args = list(node.args)
+    for idx in range(len(args)):
+        if idx in sm_input_quantized_idxs:
+            arg = args[idx]
+            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
+                quantize_node = arg.args[0]  # type: ignore[union-attr]
+                node.replace_input_with(arg, quantize_node)
+                if len(arg.users) == 0:  # type: ignore[union-attr]
+                    model.graph.erase_node(arg)
+    # add dequantize node for output
+    sm_output_quantized_idxs = \
+        observed_standalone_module \
+        .meta["_observed_graph_module_attrs"].standalone_module_output_quantized_idxs
+    if len(sm_output_quantized_idxs) > 0:
+        assert sm_output_quantized_idxs[0] == 0, "Currently only quantized"
+        "output idxs = [0] is supported"
+
+        # if it's non-empty, then it means the output is kept in quantized form
+        # we'll just add a dequantize node after this node
+        _insert_dequantize_node(node, model.graph)
+
+    # TODO: allow convert_custom_config to override backend_config
+    # for standalone module
+    quantized_standalone_module = convert_fn(
+        observed_standalone_module,
+        backend_config=backend_config)
+    parent_name, name = _parent_name(node.target)
+    # update the modules dict
+    setattr(modules[parent_name], name, quantized_standalone_module)
+    modules[str(node.target)] = quantized_standalone_module
+
+def convert_weighted_module(
+        node: Node,
+        modules: Dict[str, torch.nn.Module],
+        observed_node_names: Set[str],
+        node_name_to_qconfig: Dict[str, QConfigAny],
+        backend_config: BackendConfig,
+        is_decomposed: bool = False,
+        is_reference: bool = False,
+) -> None:
+    """ Convert a weighted module to reference quantized module in the model
+    If the QConfig of a QAT module is not set, the module will still be converted to
+    a float module.
+
+    Args:
+      - node: The call_module node of the observed weighted module
+      - modules: named_module of original model
+      - observed_node_names: names of the set of observed fx nodes; we can skip
+        this conversion if the node is not observed
+    """
+    original_module = modules[str(node.target)]
+    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
+    weight_post_process = None
+    qat_module_classes = get_qat_module_classes(backend_config)
+
+    if isinstance(
+            original_module,
+            qat_module_classes):
+        # When converting a qat module to a float module, we need to attach the
+        # weight fake_quant to the module; the weight fake_quant is assumed to have been run during
+        # QAT, so we don't need to run it again here
+        weight_post_process = original_module.weight_fake_quant
+        original_module = original_module.to_float()  # type: ignore[operator]
+        # change qat module to float module
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name, original_module)
+
+    is_observed = node.name in observed_node_names
+    # If a qconfig is not defined for this node, then skip converting to a reference module
+    if qconfig is None or _has_none_qconfig(node, node_name_to_qconfig) or not is_observed:
+        return
+
+    # skip converting to reference quantized module if the qconfig is not supported
+    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
+    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
+    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
+        return
+
+    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
+    is_weight_quantized = weight_is_quantized(qconfig)
+
+    # the condition for swapping the module to reference quantized module is:
+    # weights need to be quantized
+    if not is_weight_quantized:
+        return
+
+    fused_module = None
+    float_module = original_module
+    # extract the individual float_module and fused module
+    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
+        fused_module = float_module
+        float_module = fused_module[0]  # type: ignore[index]
+
+    # TODO: move this to the reference quantized module
+    # weight_qparams or weight_qparams dict
+    wq_or_wq_dict = {"is_decomposed": is_decomposed}
+    if isinstance(float_module, torch.nn.RNNCellBase):
+        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
+        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
+        weight_post_process_ih(float_module.weight_ih)
+        weight_post_process_hh(float_module.weight_hh)
+        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
+        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
+        wq_or_wq_dict.update({
+            "weight_ih": weight_qparams_ih,
+            "weight_hh": weight_qparams_hh,
+        })
+    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
+        # format for wq_or_wq_dict (flattened attributes):
+        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
+        for wn in float_module._flat_weights_names:
+            if hasattr(float_module, wn) and wn.startswith("weight"):
+                weight = getattr(float_module, wn)
+                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
+                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
+                    weight_post_process(weight)  # type: ignore[operator, misc]
+                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
+    else:
+        # weight_post_process is None means the original module is not a QAT module
+        # we need to get weight_post_process from qconfig in this case
+        is_ptq = weight_post_process is None
+        if is_ptq:
+            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
+            device = assert_and_get_unique_device(float_module)
+            if device:
+                weight_post_process.to(device)
+
+        # Call weight observer/fake_quant at least once to ensure the scales and zero points
+        # have the right shapes. Note: there are two cases where we don't have to do this:
+        #
+        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
+        #     and this typically happens during training, so we don't need to do it here.
+        #
+        # (2) Non-reference (lowered) case: The quantized module's from_float method already
+        #     calls the weight observer/fake_quant, so we don't have to do it here.
+        #
+        # Currently we ignore both cases and call the weight observer/fake_quant here
+        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
+        # in test code, which may not always train before convert. In the future, we should
+        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
+        #
+        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
+        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
+        # Note that we still need it for PTQ in the PT2 flow since the model's forward
+        # method doesn't call the weight observer.
+        is_qat = not is_ptq
+        if not (is_decomposed and is_reference and is_qat):
+            weight_post_process(float_module.weight)  # type: ignore[operator]
+
+        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))
+
+    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
+    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
+    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
+    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
+    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
+    assert (
+        ref_qmodule_cls is not None
+    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
+    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
+    if fused_module is not None:
+        fused_module[0] = ref_qmodule  # type: ignore[operator]
+    else:
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name, ref_qmodule)
+
+def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
+    """
+    Given a custom module `node`, if the previous node is a dequantize, reroute the custom module as follows:
+
+    Before: quantize - dequantize - custom_module
+    After: quantize - custom_module
+                 \\ - dequantize
+    """
+    # expecting the input node for a custom module node to be a Node
+    assert isinstance(prev_node, Node), \
+        f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
+    if prev_node.op == "call_method" and prev_node.target == "dequantize":
+        node.replace_input_with(prev_node, prev_node.args[0])
+        # Remove the dequantize node if it doesn't have other users
+        if len(prev_node.users) == 0:
+            graph.erase_node(prev_node)
+
+def convert_custom_module(
+        node: Node,
+        graph: Graph,
+        modules: Dict[str, torch.nn.Module],
+        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
+        statically_quantized_custom_module_nodes: Set[Node]) -> None:
+    """ Converts an observed custom module to a quantized custom module based on
+    `custom_module_class_mapping`
+    For static quantization, we'll also remove the previous `dequantize` node and
+    attach the output observer node to the module; the observer for the node
+    will be converted to a dequantize node instead of a quantize-dequantize pair
+    later in the graph. In the end we would have a quantized custom module that
+    has the same interface as a default quantized module in the nn.quantized namespace,
+    i.e. quantized input and quantized output.
+
+    Args:
+      - node: The call_module node of the observed custom module
+      - graph: The graph containing the node
+      - modules: named_module of original model
+      - custom_module_class_mapping: mapping from observed custom module class to
+        quantized custom module class, used to swap custom modules
+      - statically_quantized_custom_module_nodes: we'll add the custom module node
+        if we find it is statically quantized; this will be used later when converting
+        observers to quant/dequant node pairs. If the observed node is a statically
+        quantized custom module node, we'll convert the observer to a dequantize node;
+        this is to keep the interface the same as the default quantized module.
+        TODO: maybe we want to redesign this part to align with reference model design
+        as well, but there has been some discussions around the interface, so we can do
+        it later.
+    """
+    observed_custom_module = modules[str(node.target)]
+    maybe_obs = _maybe_get_observer_for_node(node, modules)
+    qconfig = observed_custom_module.qconfig
+    if activation_is_statically_quantized(qconfig):
+        statically_quantized_custom_module_nodes.add(node)
+        if _is_custom_module_lstm(node, modules):
+            # The inputs are tuples in the form (input, (hidden0, hidden1))
+            # Ensure all three input nodes are quantized
+            assert (
+                len(node.args) == 2 and
+                isinstance(node.args[1], tuple) and
+                len(node.args[1]) == 2
+            )
+            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
+            assert isinstance(inputs, Node)
+            assert isinstance(hidden0, Node)
+            assert isinstance(hidden1, Node)
+            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
+            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
+            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
+        elif _is_custom_module_mha(node, modules):
+            # Inputs are in the form (query, key, value)
+            # TODO: This is the first step in enabling the full fx custom module
+            # quantization path for MultiheadAttention, and only covers the inputs
+            # to the module.
+            # Additional handling is yet to be implemented for the outputs, similar
+            # to LSTM custom module
+            assert len(node.args) == 3
+            query, key, value = node.args
+            assert isinstance(query, Node)
+            assert isinstance(key, Node)
+            assert isinstance(value, Node)
+            _remove_previous_dequantize_in_custom_module(node, query, graph)
+            _remove_previous_dequantize_in_custom_module(node, key, graph)
+            _remove_previous_dequantize_in_custom_module(node, value, graph)
+        else:
+            # remove the previous dequant node to ensure the inputs are quantized
+            arg = node.args[0]
+            assert isinstance(arg, Node)
+            _remove_previous_dequantize_in_custom_module(node, arg, graph)
+            # absorb the following observer into the module conversion
+            activation_post_process = _maybe_get_observer_for_node(node, modules)
+            assert activation_post_process is not None
+            observed_custom_module.activation_post_process = activation_post_process
+
+    # swap the observed custom module to quantized custom module
+    quantized_custom_module_class = get_swapped_custom_module_class(
+        observed_custom_module, custom_module_class_mapping, qconfig)
+    quantized_custom_module = \
+        quantized_custom_module_class.from_observed(observed_custom_module)
+    parent_name, name = _parent_name(node.target)
+    setattr(modules[parent_name], name, quantized_custom_module)
+
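+# Illustrative sketch (hypothetical class, not used by this module): the contract
+# convert_custom_module relies on is a `from_observed` classmethod on the quantized
+# custom module class; for static quantization it can read the output qparams from
+# the `activation_post_process` attached above.
+import torch
+
+class _ExampleQuantizedCustomModule(torch.nn.Module):
+    @classmethod
+    def from_observed(cls, observed_module):
+        quantized = cls()
+        if hasattr(observed_module, "activation_post_process"):
+            # the observer attached by convert_custom_module provides the output qparams
+            quantized.scale, quantized.zero_point = \
+                observed_module.activation_post_process.calculate_qparams()
+        return quantized
+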
+def convert(
+        model: GraphModule, is_reference: bool = False,
+        convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+        is_standalone_module: bool = False,
+        _remove_qconfig_flag: bool = True,
+        qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+        is_decomposed: bool = False) -> GraphModule:
+    """
+    We will convert an observed model (a module with observer calls) to a reference
+    quantized model. The rules are simple:
+    1. for each observer module call in the graph, we convert it to calls to
+       quantize and dequantize functions based on the observer instance
+    2. for weighted operations like linear/conv, we convert them to reference
+       quantized modules. This requires knowing whether the dtype configured for the
+       weight is supported in the backend; that is determined in the prepare step, and the
+       result is stored in observed_node_names, which is used here to decide whether the
+       module needs to be swapped
+
+    Args:
+       * `is_standalone_module`: when this flag is True, it means we are quantizing
+       a submodule that is not inlined in the parent module, and it will be quantized
+       separately as one unit.
+
+       * `is_decomposed`: a boolean flag to indicate whether we want to use the
+        quantize operator for decomposed quantized tensor
+        (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+        quantized tensor (torch.quantize_per_tensor)
+
+    Returns:
+         a quantized model (or, when `is_standalone_module` is True, a quantized
+         standalone module). Whether inputs/outputs are quantized is specified by
+         prepare_custom_config via input_quantized_idxs and output_quantized_idxs;
+         please see the docs for :func:`~torch.ao.quantization.prepare_fx` for details
+    """
+    if convert_custom_config is None:
+        convert_custom_config = ConvertCustomConfig()
+
+    if isinstance(convert_custom_config, Dict):
+        warnings.warn(
+            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
+            "in a future version. Please pass in a ConvertCustomConfig instead.")
+        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
+
+    if isinstance(qconfig_mapping, Dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
+            "in a future version. Please pass in a QConfigMapping instead.")
+        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
+    qconfig_mapping = copy.deepcopy(qconfig_mapping)
+    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
+
+    if isinstance(backend_config, Dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.")
+        backend_config = BackendConfig.from_dict(backend_config)
+
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+
+    assert _is_observed_module(model), \
+        'incoming model must be produced by prepare_fx'
+    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
+    node_name_to_scope: Dict[str, Tuple[str, type]] = observed_graph_module_attrs.node_name_to_scope
+    prepare_custom_config: PrepareCustomConfig = observed_graph_module_attrs.prepare_custom_config
+    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
+    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]
+
+    # mapping from fully qualified module name to module instance
+    # for example,
+    # {
+    #   '': Model(...),
+    #   'linear': Linear(...),
+    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
+    # }
+    # We use remove_duplicate=False here because torch.cat uses
+    # the same activation_post_process module instance but different names
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # TODO refactor this code once we update the prepare logic to have additional information on
+    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
+    if qconfig_mapping:
+        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
+        modules_copy = copy.deepcopy(modules)
+
+        if observed_graph_module_attrs.is_qat:
+            _update_qconfig_for_qat(qconfig_mapping, backend_config)
+        _update_qconfig_for_fusion(model, qconfig_mapping)
+
+        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
+        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
+            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
+        # check the convert_node_name_to_qconfig generated and ensure that
+        # all the values either match what was set in prepare node_name_to_qconfig
+        # or are set to None in the convert_node_name_to_qconfig.
+        for k, v in node_name_to_qconfig.items():
+            assert k in convert_node_name_to_qconfig, f'Expected key {k} in convert node_name_to_qconfig'
+            if convert_node_name_to_qconfig[k] is not None:
+                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), \
+                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " \
+                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
+        node_name_to_qconfig = convert_node_name_to_qconfig
+
+    custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
+    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping
+
+    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
+        # If we want to do equalization then do the following:
+        # Calculate the equalization scale, update the observers with the scaled
+        # inputs, and scale the weight
+        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+        convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    # always run weight observers in the top level forward method
+    # for dynamic quant ops or weight only quant ops
+    _run_weight_observers(model, backend_config)
+
+    graph_inputs: List[str] = []
+    for node in model.graph.nodes:
+        if node.op == 'placeholder':
+            graph_inputs.append(node.name)
+
+    # additional state to override inputs to be quantized, if specified
+    # by the user
+    placeholder_node_seen_cnt = 0
+    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
+    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
+
+    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
+    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
+    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
+    qat_module_classes = get_qat_module_classes(backend_config)
+    fused_module_classes = get_fused_module_classes(backend_config)
+    statically_quantized_custom_module_nodes: Set[Node] = set()
+
+    for node in list(model.graph.nodes):
+        if node.op == 'placeholder':
+            cur_placeholder_node_idx = placeholder_node_seen_cnt
+            placeholder_node_seen_cnt += 1
+            if cur_placeholder_node_idx in input_quantized_idxs:
+                # Inputs are assumed to be quantized if the user specified the
+                # input_quantized_idxs override.
+                # we need to dequantize the inputs since all operators took
+                # floating point inputs in reference quantized models
+                _insert_dequantize_node(node, model.graph)
+        elif node.op == "output":
+            # If the argument is empty we don't need to do anything
+            if len(output_quantized_idxs) == 0:
+                continue
+            # Results are kept quantized if the user specified the
+            # output_quantized_idxs override.
+            # Remove the dequantize operator for the node in the end if any
+            return_node = node
+            output = node.args[0]
+            # outputs can be a Node, list, tuple, or dict; other cases are not supported yet
+            if isinstance(output, (list, tuple)):
+                for idx in output_quantized_idxs:
+                    _maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
+            elif isinstance(output, (Node, dict)):
+                # we treat dict as a single argument currently, but it can be extended
+                # to support {"key": dtype} after we change output_quantized_idxs to
+                # dict
+                if 0 in output_quantized_idxs:
+                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
+            else:
+                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
+        elif node.op == "call_module":
+            mod = _get_module(node, modules)
+            assert mod is not None
+            if _is_activation_post_process(mod):
+                observed_node = node.args[0]
+                if observed_node in statically_quantized_custom_module_nodes:
+                    _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+                else:
+                    if is_decomposed:
+                        _replace_observer_with_quantize_dequantize_node_decomposed(
+                            model, node, modules, node_name_to_scope,
+                            node_name_to_qconfig)
+                    else:
+                        _replace_observer_with_quantize_dequantize_node(
+                            model, node, modules, node_name_to_scope,
+                            node_name_to_qconfig)
+            elif isinstance(mod, DeQuantStub):
+                _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+            elif _is_observed_standalone_module(mod):
+                convert_standalone_module(
+                    node, modules, model, is_reference, backend_config)
+            # below this point `type_before_parametrizations` is used
+            # instead of `type` to handle situations with fx quant + sparsity
+            elif type_before_parametrizations(mod) in set(
+                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
+                # extra check for fused module classes to make sure they are fused module classes
+                # of target modules
+                if type_before_parametrizations(mod) in fused_module_classes and \
+                   type_before_parametrizations(mod[0]) not in root_module_classes:  # type: ignore[index]
+                    continue
+                convert_weighted_module(
+                    node, modules, observed_node_names, node_name_to_qconfig, backend_config,
+                    is_decomposed, is_reference)
+            elif type_before_parametrizations(mod) in custom_module_classes:
+                convert_custom_module(
+                    node, model.graph, modules, custom_module_class_mapping,
+                    statically_quantized_custom_module_nodes)
+
+    # remove deadcode after converting observers to quant/dequant ops
+    model.graph.eliminate_dead_code()
+    model = GraphModule(model, model.graph)
+
+    # TODO: maybe move this to quantize_fx.py
+    if not is_reference:
+        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)
+
+    # TODO: this looks hacky, we want to check why we need this and see if we can
+    # remove this
+    # removes qconfig and activation_post_process modules
+    if _remove_qconfig_flag:
+        _remove_qconfig(model)
+    model.delete_all_unused_submodules()
+    model.meta.pop("_observed_graph_module_attrs", None)
+    return model
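+
+# Illustrative sketch (helper not called anywhere): how the convert() pass above is
+# typically reached through the public FX graph mode quantization entry points.
+# `model` and `example_inputs` are assumed to be supplied by the caller.
+def _example_prepare_and_convert(model, example_inputs: Tuple[Any, ...]) -> GraphModule:
+    from torch.ao.quantization import get_default_qconfig_mapping
+    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+
+    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
+    prepared = prepare_fx(model.eval(), qconfig_mapping, example_inputs)
+    prepared(*example_inputs)  # calibration; real callers use representative data
+    # convert_fx eventually dispatches to convert() defined in this file
+    return convert_fx(prepared)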
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/custom_config.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/custom_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c7cab65279fb41141c271cd6e349f211eaa9281
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/custom_config.py
@@ -0,0 +1,419 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str
+
+
+__all__ = [
+    "ConvertCustomConfig",
+    "FuseCustomConfig",
+    "PrepareCustomConfig",
+    "StandaloneModuleConfigEntry",
+]
+
+
+# TODO: replace all usages with these constants
+STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
+STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
+FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
+OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
+NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
+NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
+INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
+OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
+PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
+
+
+@dataclass
+class StandaloneModuleConfigEntry:
+    # qconfig_mapping for the prepare function called in the submodule,
+    # None means use qconfig from parent qconfig_mapping
+    qconfig_mapping: Optional[QConfigMapping]
+    example_inputs: Tuple[Any, ...]
+    prepare_custom_config: Optional[PrepareCustomConfig]
+    backend_config: Optional[BackendConfig]
+
+
+class PrepareCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
+    :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.
+
+    Example usage::
+
+        prepare_custom_config = PrepareCustomConfig() \
+            .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
+                child_prepare_custom_config, backend_config) \
+            .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
+                child_prepare_custom_config, backend_config) \
+            .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
+            .set_non_traceable_module_names(["module2", "module3"]) \
+            .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
+            .set_input_quantized_indexes([0]) \
+            .set_output_quantized_indexes([0]) \
+            .set_preserved_attributes(["attr1", "attr2"])
+    """
+    def __init__(self):
+        self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
+        self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
+        self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
+        self.non_traceable_module_names: List[str] = []
+        self.non_traceable_module_classes: List[Type] = []
+        self.input_quantized_indexes: List[int] = []
+        self.output_quantized_indexes: List[int] = []
+        self.preserved_attributes: List[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {
+            k: v for k, v in self.__dict__.items()
+            if len(v) > 0
+        }
+        return f"PrepareCustomConfig({dict_nonempty})"
+
+    def set_standalone_module_name(
+            self,
+            module_name: str,
+            qconfig_mapping: Optional[QConfigMapping],
+            example_inputs: Tuple[Any, ...],
+            prepare_custom_config: Optional[PrepareCustomConfig],
+            backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
+        """
+        Set the configuration for running a standalone module identified by ``module_name``.
+
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
+        """
+        self.standalone_module_names[module_name] = \
+            StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+        return self
+
+    def set_standalone_module_class(
+            self,
+            module_class: Type,
+            qconfig_mapping: Optional[QConfigMapping],
+            example_inputs: Tuple[Any, ...],
+            prepare_custom_config: Optional[PrepareCustomConfig],
+            backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
+        """
+        Set the configuration for running a standalone module identified by ``module_class``.
+
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
+        """
+        self.standalone_module_classes[module_class] = \
+            StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+        return self
+
+    def set_float_to_observed_mapping(
+            self,
+            float_class: Type,
+            observed_class: Type,
+            quant_type: QuantType = QuantType.STATIC) -> PrepareCustomConfig:
+        """
+        Set the mapping from a custom float module class to a custom observed module class.
+
+        The observed module class must have a ``from_float`` class method that converts the float module class
+        to the observed module class. This is currently only supported for static quantization.
+        """
+        if quant_type != QuantType.STATIC:
+            raise ValueError("set_float_to_observed_mapping is currently only supported for static quantization")
+        if quant_type not in self.float_to_observed_mapping:
+            self.float_to_observed_mapping[quant_type] = {}
+        self.float_to_observed_mapping[quant_type][float_class] = observed_class
+        return self
+
+    def set_non_traceable_module_names(self, module_names: List[str]) -> PrepareCustomConfig:
+        """
+        Set the modules that are not symbolically traceable, identified by name.
+        """
+        self.non_traceable_module_names = module_names
+        return self
+
+    def set_non_traceable_module_classes(self, module_classes: List[Type]) -> PrepareCustomConfig:
+        """
+        Set the modules that are not symbolically traceable, identified by class.
+        """
+        self.non_traceable_module_classes = module_classes
+        return self
+
+    def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
+        """
+        Set the indexes of the inputs of the graph that should be quantized.
+        Inputs are otherwise assumed to be in fp32 by default.
+        """
+        self.input_quantized_indexes = indexes
+        return self
+
+    def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
+        """
+        Set the indexes of the outputs of the graph that should be quantized.
+        Outputs are otherwise assumed to be in fp32 by default.
+        """
+        self.output_quantized_indexes = indexes
+        return self
+
+    def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(cls, prepare_custom_config_dict: Dict[str, Any]) -> PrepareCustomConfig:
+        """
+        Create a ``PrepareCustomConfig`` from a dictionary with the following items:
+
+            "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
+            child_prepare_custom_config, backend_config) tuples
+
+            "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs,
+            child_prepare_custom_config, backend_config) tuples
+
+            "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
+            mode to an inner mapping from float module classes to observed module classes, e.g.
+            {"static": {FloatCustomModule: ObservedCustomModule}}
+
+            "non_traceable_module_name": a list of modules names that are not symbolically traceable
+            "non_traceable_module_class": a list of module classes that are not symbolically traceable
+            "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
+            "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
+        """
+        def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]:
+            """
+            Convert the given object into a QConfigMapping if possible, else throw an exception.
+            """
+            if isinstance(obj, QConfigMapping) or obj is None:
+                return obj
+            if isinstance(obj, Dict):
+                return QConfigMapping.from_dict(obj)
+            raise ValueError(f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'")
+
+        def _get_prepare_custom_config(obj: Any, dict_key: str) -> Optional[PrepareCustomConfig]:
+            """
+            Convert the given object into a PrepareCustomConfig if possible, else throw an exception.
+            """
+            if isinstance(obj, PrepareCustomConfig) or obj is None:
+                return obj
+            if isinstance(obj, Dict):
+                return PrepareCustomConfig.from_dict(obj)
+            raise ValueError(f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'")
+
+        def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]:
+            """
+            Convert the given object into a BackendConfig if possible, else throw an exception.
+            """
+            if isinstance(obj, BackendConfig) or obj is None:
+                return obj
+            if isinstance(obj, Dict):
+                return BackendConfig.from_dict(obj)
+            raise ValueError(f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'")
+
+        conf = cls()
+        for (module_name, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\
+                prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []):
+            qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY)
+            prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY)
+            backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY)
+            conf.set_standalone_module_name(
+                module_name, qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+        for (module_class, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\
+                prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []):
+            qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
+            prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
+            backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY)
+            conf.set_standalone_module_class(
+                module_class, qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+        for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get(FLOAT_TO_OBSERVED_DICT_KEY, {}).items():
+            quant_type = _quant_type_from_str(quant_type_name)
+            for float_class, observed_class in custom_module_mapping.items():
+                conf.set_float_to_observed_mapping(float_class, observed_class, quant_type)
+        conf.set_non_traceable_module_names(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, []))
+        conf.set_non_traceable_module_classes(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, []))
+        conf.set_input_quantized_indexes(prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, []))
+        conf.set_output_quantized_indexes(prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, []))
+        conf.set_preserved_attributes(prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
+        return conf
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.
+        """
+        def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
+            qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None
+            prepare_custom_config_dict = e.prepare_custom_config.to_dict() if e.prepare_custom_config else None
+            return (key, qconfig_dict, e.example_inputs, prepare_custom_config_dict, e.backend_config)
+
+        d: Dict[str, Any] = {}
+        for module_name, sm_config_entry in self.standalone_module_names.items():
+            if STANDALONE_MODULE_NAME_DICT_KEY not in d:
+                d[STANDALONE_MODULE_NAME_DICT_KEY] = []
+            d[STANDALONE_MODULE_NAME_DICT_KEY].append(_make_tuple(module_name, sm_config_entry))
+        for module_class, sm_config_entry in self.standalone_module_classes.items():
+            if STANDALONE_MODULE_CLASS_DICT_KEY not in d:
+                d[STANDALONE_MODULE_CLASS_DICT_KEY] = []
+            d[STANDALONE_MODULE_CLASS_DICT_KEY].append(_make_tuple(module_class, sm_config_entry))
+        for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items():
+            if FLOAT_TO_OBSERVED_DICT_KEY not in d:
+                d[FLOAT_TO_OBSERVED_DICT_KEY] = {}
+            d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping
+        if len(self.non_traceable_module_names) > 0:
+            d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
+        if len(self.non_traceable_module_classes) > 0:
+            d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes
+        if len(self.input_quantized_indexes) > 0:
+            d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes
+        if len(self.output_quantized_indexes) > 0:
+            d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
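+
+
+# Illustrative sketch (hypothetical values): the dictionary form accepted by
+# PrepareCustomConfig.from_dict and produced by to_dict for the simpler keys;
+# round-tripping this dict through from_dict(...).to_dict() returns an equal dict.
+_EXAMPLE_PREPARE_CUSTOM_CONFIG_DICT = {
+    NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["submodule_to_skip"],
+    INPUT_QUANTIZED_INDEXES_DICT_KEY: [0],
+    OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0],
+    PRESERVED_ATTRIBUTES_DICT_KEY: ["version"],
+}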
+
+
+class ConvertCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.
+
+    Example usage::
+
+        convert_custom_config = ConvertCustomConfig() \
+            .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
+            .set_preserved_attributes(["attr1", "attr2"])
+    """
+
+    def __init__(self):
+        self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
+        self.preserved_attributes: List[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {
+            k: v for k, v in self.__dict__.items()
+            if len(v) > 0
+        }
+        return f"ConvertCustomConfig({dict_nonempty})"
+
+    def set_observed_to_quantized_mapping(
+            self,
+            observed_class: Type,
+            quantized_class: Type,
+            quant_type: QuantType = QuantType.STATIC) -> ConvertCustomConfig:
+        """
+        Set the mapping from a custom observed module class to a custom quantized module class.
+
+        The quantized module class must have a ``from_observed`` class method that converts the observed module class
+        to the quantized module class.
+        """
+        if quant_type not in self.observed_to_quantized_mapping:
+            self.observed_to_quantized_mapping[quant_type] = {}
+        self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class
+        return self
+
+    def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(cls, convert_custom_config_dict: Dict[str, Any]) -> ConvertCustomConfig:
+        """
+        Create a ``ConvertCustomConfig`` from a dictionary with the following items:
+
+            "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
+            mode to an inner mapping from observed module classes to quantized module classes, e.g.::
+
+                {
+                    "static": {ObservedCustomModule: QuantizedCustomModule},
+                    "dynamic": {ObservedCustomModule: QuantizedCustomModule},
+                    "weight_only": {ObservedCustomModule: QuantizedCustomModule}
+                }
+
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
+        """
+        conf = cls()
+        for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(OBSERVED_TO_QUANTIZED_DICT_KEY, {}).items():
+            quant_type = _quant_type_from_str(quant_type_name)
+            for observed_class, quantized_class in custom_module_mapping.items():
+                conf.set_observed_to_quantized_mapping(observed_class, quantized_class, quant_type)
+        conf.set_preserved_attributes(convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
+        return conf
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
+        """
+        d: Dict[str, Any] = {}
+        for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items():
+            if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
+                d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
+            d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
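+
+
+# Illustrative sketch (helper not called anywhere): wiring a ConvertCustomConfig into
+# the public convert_fx entry point so the custom module swap in convert.py uses the
+# mapping; `observed_class` and `quantized_class` are assumed to be user-defined.
+def _example_convert_with_custom_module(prepared_model, observed_class: Type, quantized_class: Type):
+    from torch.ao.quantization.quantize_fx import convert_fx
+
+    convert_custom_config = ConvertCustomConfig() \
+        .set_observed_to_quantized_mapping(observed_class, quantized_class)
+    return convert_fx(prepared_model, convert_custom_config=convert_custom_config)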
+
+
+class FuseCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.
+
+    Example usage::
+
+        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
+    """
+
+    def __init__(self):
+        self.preserved_attributes: List[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {
+            k: v for k, v in self.__dict__.items()
+            if len(v) > 0
+        }
+        return f"FuseCustomConfig({dict_nonempty})"
+
+    def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
+        """
+        Create a ``FuseCustomConfig`` from a dictionary with the following items:
+
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
+        """
+        conf = cls()
+        conf.set_preserved_attributes(fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
+        return conf
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.
+        """
+        d: Dict[str, Any] = {}
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c17062c1858f28b010cb1ac77d222d06182ba4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse.py
@@ -0,0 +1,161 @@
+from torch.fx import (
+    GraphModule,
+    Node,
+    map_arg
+)
+from torch.fx.graph import Graph
+from .match_utils import (
+    _is_match,
+    MatchAllNode,
+)
+from .pattern_utils import (
+    _sorted_patterns_dict,
+)
+
+from ..backend_config import (
+    BackendConfig,
+    get_native_backend_config,
+)
+from ..backend_config.utils import (
+    get_fuser_method_mapping,
+    get_fusion_pattern_to_root_node_getter,
+    get_fusion_pattern_to_extra_inputs_getter,
+)
+
+from .custom_config import FuseCustomConfig
+
+from .fuse_handler import (
+    _get_fusion_pattern_to_fuse_handler_cls,
+    FuseHandler,
+)
+
+from typing import Any, Callable, Dict, List, Tuple, Union
+import warnings
+
+from torch.ao.quantization.utils import Pattern, NodePattern
+
+
+__all__ = [
+    "fuse",
+    # TODO: We should make this private in the future
+    # This is currently needed for test_public_bindings for some reason
+    "FuseHandler",
+]
+
+
+def fuse(
+    model: GraphModule,
+    is_qat: bool,
+    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    if fuse_custom_config is None:
+        fuse_custom_config = FuseCustomConfig()
+
+    if isinstance(fuse_custom_config, Dict):
+        warnings.warn(
+            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
+            "in a future version. Please pass in a FuseCustomConfig instead.")
+        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
+
+    if isinstance(backend_config, Dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.")
+        backend_config = BackendConfig.from_dict(backend_config)
+
+    named_modules = dict(model.named_modules())
+
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+
+    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
+    fuser_method_mapping = get_fuser_method_mapping(backend_config)
+    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
+    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)
+
+    # find fusion
+    fusion_pairs = _find_matches(
+        model, model.graph, fusion_pattern_to_fuse_handler_cls)
+    # TODO: change this to inplace changes to graph, since we no longer construct
+    # new GraphModule anymore
+    fused_graph = Graph()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    def default_root_node_getter(node_pattern):
+        while not isinstance(node_pattern[-1], Node):
+            node_pattern = node_pattern[-1]
+        return node_pattern[-1]
+
+    for node in model.graph.nodes:
+        maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
+            fusion_pairs.get(node.name, (None, None, None, None, None))
+        # get the corresponding subpattern for the current node
+        if node_to_subpattern is not None:
+            node_subpattern = node_to_subpattern.get(node, None)
+        else:
+            node_subpattern = None
+        if maybe_last_node is node:
+            assert obj is not None
+            root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
+            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
+            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
+            extra_inputs = []
+            if extra_inputs_getter is not None:
+                extra_inputs = extra_inputs_getter(matched_node_pattern)
+            # TODO: add validation that root_node is a module and has the same type
+            # as the root_module in the configuration
+            env[node.name] = obj.fuse(
+                load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
+                fuse_custom_config, fuser_method_mapping, is_qat)
+        elif maybe_last_node is None or node_subpattern is MatchAllNode:
+            env[node.name] = fused_graph.node_copy(node, load_arg)
+        # node matched in patterns and is not root is removed here
+
+    model = GraphModule(model, fused_graph)
+    return model
+
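+# Illustrative sketch (helper not called anywhere): the public entry point that runs
+# the fuse() pass above; `model` is assumed to be a float module used in eval mode.
+def _example_fuse_flow(model):
+    from torch.ao.quantization.quantize_fx import fuse_fx
+
+    # fuse_fx symbolically traces the model and then applies fuse() with the
+    # native backend config
+    return fuse_fx(model.eval())
+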
+def _find_matches(
+        root: GraphModule,
+        graph: Graph,
+        pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
+) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
+    modules = dict(root.named_modules())
+    # node name -> (root_node, match_value)
+    match_map : Dict[
+        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
+    # a map from node to the matched subpattern
+    node_to_subpattern: Dict[Node, Any] = {}
+
+    # TODO: dedup with quantization matching function in match_utils.py
+    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
+        if isinstance(pattern, tuple):
+            s, *args = pattern
+            current_node_pattern: List[Node] = []
+            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
+            for subpattern, arg in zip(args, node.args):
+                apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
+            matched_node_pattern.append(tuple(current_node_pattern))
+        else:
+            # the first pattern matches will take precedence
+            if node.name not in match_map:
+                matched_node_pattern.append(node)
+                # MatchAllNode here is actually MatchAllInputNode which should not
+                # be added to match_map
+                if pattern is not MatchAllNode:
+                    node_to_subpattern[node] = pattern
+                    root_node, pattern, handler = match
+                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)
+
+    for node in reversed(graph.nodes):
+        if node.name not in match_map:
+            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
+                matched_node_pattern: List[Node] = []
+                if _is_match(modules, node, pattern):
+                    apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern)
+                    break
+
+    return match_map
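+
+# Illustrative sketch: the reversed nested tuple pattern format consumed by
+# _find_matches and _is_match; e.g. a Conv2d -> BatchNorm2d -> ReLU fusion is
+# expressed with the last op first.
+import torch
+
+_EXAMPLE_FUSION_PATTERN: Pattern = (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))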
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8752d4fe2d4486dd82a75e6c5e531e2efa888842
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/fuse_handler.py
@@ -0,0 +1,120 @@
+import torch
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.fx.graph import Node, Graph
+from ..utils import _parent_name, NodePattern, Pattern
+from ..fuser_method_mappings import get_fuser_method_new
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Union
+from .custom_config import FuseCustomConfig
+from .match_utils import MatchAllNode
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+__all__ = [
+    "DefaultFuseHandler",
+    "FuseHandler",
+]
+
+
+# ----------------------------
+# Fusion Pattern Registrations
+# ----------------------------
+
+# Base Pattern Handler
+class FuseHandler(ABC):
+    """ Base handler class for the fusion patterns
+    """
+    @abstractmethod
+    def __init__(self, node: Node):
+        pass
+
+    @abstractmethod
+    def fuse(self,
+             load_arg: Callable,
+             named_modules: Dict[str, torch.nn.Module],
+             fused_graph: Graph,
+             root_node: Node,
+             extra_inputs: List[Any],
+             matched_node_pattern: NodePattern,
+             fuse_custom_config: FuseCustomConfig,
+             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
+             is_qat: bool) -> Node:
+        pass
+
+class DefaultFuseHandler(FuseHandler):
+    def __init__(
+            self,
+            node: Node):
+        super().__init__(node)
+
+    def fuse(self,
+             load_arg: Callable,
+             named_modules: Dict[str, torch.nn.Module],
+             fused_graph: Graph,
+             root_node: Node,
+             extra_inputs: List[Any],
+             matched_node_pattern: NodePattern,
+             fuse_custom_config: FuseCustomConfig,
+             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
+             is_qat: bool) -> Node:
+        assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
+        root_module = named_modules[str(root_node.target)]
+
+        def get_modules(pattern):
+            """ Given a node pattern, extract the corresponding modules
+            e.g. input: (relu_node, (bn_node, conv_node))
+                 output: (relu_module, (bn_module, conv_module))
+            """
+            if isinstance(pattern, (tuple, list)):
+                n, *args = pattern
+                modules: List[torch.nn.Module] = []
+                modules.append(get_modules(n))
+                for a in args:
+                    modules.append(get_modules(a))
+                return tuple(modules)
+            else:
+                n = pattern
+                if n.op == "call_module":
+                    return named_modules[n.target]
+                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
+                    relu = torch.nn.ReLU()
+                    relu.training = root_module.training
+                    return relu
+                elif n.op == "call_function" or n.op == "call_method":
+                    return n.target
+                else:
+                    return MatchAllNode
+
+        # since relu can be used multiple times, we'll need to create a relu module for each match
+        matched_modules = get_modules(matched_node_pattern)
+
+        def get_matched_types(m):
+            if isinstance(m, tuple):
+                return tuple(map(get_matched_types, m))
+            if isinstance(m, torch.nn.Module):
+                return type_before_parametrizations(m)
+            return m
+
+        matched_module_types = get_matched_types(matched_modules)
+        module_parent_name, module_name = _parent_name(root_node.target)
+        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
+        # TODO: change the signature for fuser_method to take matched module patterns
+        # as input
+        fused_module = fuser_method(is_qat, *matched_modules)
+        setattr(named_modules[module_parent_name], module_name, fused_module)
+        extra_args = []
+        for input in extra_inputs:
+            extra_args.append(load_arg(input))
+        node = fused_graph.node_copy(root_node, load_arg)
+        args = list(node.args)
+        args.extend(extra_args)
+        node.args = tuple(args)
+        return node
+
+def _get_fusion_pattern_to_fuse_handler_cls(
+        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
+    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.fuser_method is not None:
+            # TODO: is this logic right?
+            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
+    return fusion_pattern_to_fuse_handlers
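+
+# Illustrative sketch (hypothetical fuser method, not registered anywhere): the callable
+# shape looked up via get_fuser_method_new and invoked above as
+# `fuser_method(is_qat, *matched_modules)`. Real fuser methods build intrinsic fused
+# modules such as torch.ao.nn.intrinsic.ConvReLU2d; this sketch only shows the signature.
+def _example_fuser_method(is_qat: bool, *matched_modules: torch.nn.Module) -> torch.nn.Module:
+    del is_qat  # QAT-specific handling is out of scope for this sketch
+    return torch.nn.Sequential(*matched_modules)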
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/graph_module.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/graph_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c17701b589a6689eceb27677888ecdeb519ff13b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/graph_module.py
@@ -0,0 +1,119 @@
+import torch
+import copy
+from torch.fx import GraphModule
+from torch.fx.graph import Graph
+from typing import Union, Dict, Any, Set
+
+__all__ = [
+    "FusedGraphModule",
+    "ObservedGraphModule",
+    "ObservedStandaloneGraphModule",
+    "QuantizedGraphModule",
+]
+
+class FusedGraphModule(GraphModule):
+    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
+        self.preserved_attr_names = preserved_attr_names
+        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+
+    # GraphModule does not copy attributes which are not in the __dict__
+    # of vanilla nn.Module.  So, we override __deepcopy__ in order
+    # to copy the quantization specific attributes correctly.
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return FusedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
+
+class ObservedGraphModule(GraphModule):
+
+    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
+        self.preserved_attr_names = {
+            '_activation_post_process_map',
+            '_activation_post_process_indexes',
+            '_patterns',
+            '_node_name_to_qconfig',
+            '_prepare_custom_config',
+            '_equalization_node_name_to_qconfig',
+            '_node_name_to_scope',
+            '_qconfig_mapping',
+            '_is_qat',
+            '_observed_node_names'}.union(preserved_attr_names)
+        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+
+    # GraphModule does not copy attributes which are not in the __dict__
+    # of vanilla nn.Module.  So, we override __deepcopy__ in order
+    # to copy the quantization specific attributes correctly.
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
+
+def _is_observed_module(module: Any) -> bool:
+    return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta
+
+def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule], attr_name: str) -> Any:
+    if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta:  # type: ignore[operator, index]
+        return getattr(model.meta["_observed_graph_module_attrs"], attr_name)  # type: ignore[index]
+    return None
+
+class ObservedStandaloneGraphModule(ObservedGraphModule):
+    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
+        preserved_attr_names = preserved_attr_names.union({
+            "_standalone_module_input_quantized_idxs",
+            "_standalone_module_output_quantized_idxs"})
+        super().__init__(root, graph, preserved_attr_names)
+
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
+
+def _is_observed_standalone_module(module: Any) -> bool:
+    return _is_observed_module(module) and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module
+
+def _save_packed_weight(self, destination, prefix, keep_vars):
+    for attr_name in dir(self):
+        if "_packed_weight" in attr_name and \
+           isinstance(getattr(self, attr_name), torch._C.ScriptObject):  # type: ignore[attr-defined]
+            packed_weight = getattr(self, attr_name)
+            destination[prefix + attr_name] = packed_weight
+
+class QuantizedGraphModule(GraphModule):
+    """ This class is created to make sure PackedParams
+    (e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict
+    so that we can serialize and deserialize quantized graph module with
+    torch.save(m.state_dict()) and m.load_state_dict(state_dict)
+    """
+    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
+        self.preserved_attr_names = preserved_attr_names
+        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+        self._register_state_dict_hook(_save_packed_weight)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        attrs_to_pop = []
+        for attr_name in state_dict:
+            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
+                setattr(self, attr_name, state_dict[attr_name])
+                attrs_to_pop.append(attr_name)
+
+        # pop the packed param attributes
+        for attr_name in attrs_to_pop:
+            state_dict.pop(attr_name)
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return QuantizedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
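+
+# Illustrative sketch (helper not called anywhere): the serialization round trip enabled
+# by the packed-weight state_dict hooks above; `quantized` is assumed to be a
+# QuantizedGraphModule produced by the lowering pass and `path` a writable file path.
+def _example_save_and_reload(quantized: QuantizedGraphModule, path: str) -> QuantizedGraphModule:
+    torch.save(quantized.state_dict(), path)   # packed params are included by _save_packed_weight
+    state_dict = torch.load(path)
+    quantized.load_state_dict(state_dict)      # packed params are restored in _load_from_state_dict
+    return quantized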
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0d526dfc4f62c76324b712489a2a34278a582ee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py
@@ -0,0 +1,16 @@
+from ._lower_to_native_backend import _lower_to_native_backend
+from ..qconfig import QConfigAny
+from torch.fx import GraphModule
+from typing import Dict, Tuple
+
+__all__ = ['lower_to_fbgemm']
+
+def lower_to_fbgemm(
+    model: GraphModule,
+    qconfig_map: Dict[str, QConfigAny],
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+) -> GraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to fbgemm
+    """
+    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
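+
+# Illustrative sketch (helper not called anywhere; empty maps are an assumption): lowering
+# a reference quantized GraphModule when no per-node qconfig or scope information needs
+# to be forwarded; real callers pass the maps collected during prepare/convert.
+def _example_lower_reference_model(model: GraphModule) -> GraphModule:
+    return lower_to_fbgemm(model, qconfig_map={}, node_name_to_scope={})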
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py
new file mode 100644
index 0000000000000000000000000000000000000000..54d816a214a8fbb45fe4fd425bc5d1fcefc46c78
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py
@@ -0,0 +1,18 @@
+from ._lower_to_native_backend import _lower_to_native_backend
+from ..qconfig import QConfigAny
+from torch.fx import GraphModule
+from typing import Dict, Tuple
+
+__all__ = [
+    "lower_to_qnnpack"
+]
+
+def lower_to_qnnpack(
+    model: GraphModule,
+    qconfig_map: Dict[str, QConfigAny],
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+) -> GraphModule:
+    """ Lower a quantized reference model (with reference quantized operator patterns)
+    to qnnpack
+    """
+    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee786f1d42deb4187db0afda2fb6f1e496d8a0b3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/lstm_utils.py
@@ -0,0 +1,183 @@
+import copy
+import operator
+import torch
+from typing import Any, Callable, Optional, Tuple
+from torch.ao.quantization import (
+    default_weight_observer,
+    default_weight_fake_quant,
+    FakeQuantizeBase,
+    QConfig,
+    QConfigMapping,
+)
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.observer import _PartialWrapper
+from torch.ao.quantization.quantize_fx import (
+    convert_to_reference_fx,
+    prepare_fx,
+)
+
+# TODO: move all LSTM util functions from fx/utils.py to this file
+def _get_lstm_with_individually_observed_parts(
+    float_lstm: torch.nn.LSTM,
+    example_inputs: Tuple[Any, ...],
+    backend_config: Optional[BackendConfig] = None,
+    linear_output_obs_ctr: Optional[_PartialWrapper] = None,
+    sigmoid_obs_ctr: Optional[_PartialWrapper] = None,
+    tanh_obs_ctr: Optional[_PartialWrapper] = None,
+    cell_state_obs_ctr: Optional[_PartialWrapper] = None,
+    hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
+) -> torch.ao.nn.quantizable.LSTM:
+    """
+    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
+    with specific observers or fake quantizes assigned to the inner ops or submodules.
+
+    In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
+    used as an observed custom module, which is responsible for inserting its own
+    observers. By default, all inner ops inherit the parent custom module's QConfig.
+    Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
+    and use this helper function to customize the observer insertion logic.
+
+    This is meant to be used to convert a float module to an observed module in the
+    custom module flow.
+
+    Args:
+        `float_lstm`: The float LSTM module
+        `example_inputs`: example inputs for the forward function of the LSTM module
+        `backend_config`: BackendConfig to use to observe the LSTM module
+        `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
+            where W is the weight matrix, b is the bias, and x is either the inputs
+            or the hidden state from the previous layer (if any)
+        `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
+        `tanh_obs_ctr`: observer or fake quantize for tanh activations
+        `cell_state_obs_ctr`: observer or fake quantize for the cell state
+        `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
+            the output
+
+    Return:
+        A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
+        assigned to the inner ops.
+    """
+    def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig:
+        """
+        Make a QConfig with fixed qparams observers or fake quantizes.
+        """
+        if isinstance(obs_ctr(), FakeQuantizeBase):
+            weight = default_weight_fake_quant
+        else:
+            weight = default_weight_observer
+        return QConfig(activation=obs_ctr, weight=weight)
+
+    quantizable_lstm = torch.ao.nn.quantizable.LSTM(
+        float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias,
+        float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional)
+    quantizable_lstm.qconfig = float_lstm.qconfig
+
+    for idx in range(float_lstm.num_layers):
+        quantizable_lstm.layers[idx] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(float_lstm,
+                                                                                                 idx,
+                                                                                                 float_lstm.qconfig,
+                                                                                                 batch_first=False)
+
+    # Build QConfigMapping for the LSTM cell
+    # Note: FloatFunctional qconfigs will be configured separately below
+    cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
+    if sigmoid_obs_ctr is not None:
+        cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr))
+        cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr))
+        cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr))
+    if tanh_obs_ctr is not None:
+        cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr))
+
+    # Insert observers into each LSTM cell
+    # TODO: maybe make this work for layer_bw as well
+    for layer in quantizable_lstm.layers:
+        cell = layer.layer_fw.cell
+        cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
+        # HACK: Manually replace the activation_post_process following these ops.
+        # This is needed for FloatFunctional ops because there is currently no way
+        # to configure these ops in FX graph mode quantization. This is because
+        # the FloatFunctional modules simply disappear from the graph after tracing.
+        # In the future, we should rewrite quantizable LSTM without FloatFunctionals.
+        op_index_to_activation_post_process_ctr = {
+            (torch.add, 0): linear_output_obs_ctr,  # gates.add
+            (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
+            (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
+            (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
+            (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
+        }
+        add_count = 0
+        mul_count = 0
+        for node in cell.graph.nodes:
+            op_index: Optional[Tuple[Callable, int]] = None  # e.g. (torch.add, 1)
+            if node.target == torch.add:
+                op_index = (torch.add, add_count)
+                add_count += 1
+            elif node.target == torch.mul:
+                op_index = (torch.mul, mul_count)
+                mul_count += 1
+            else:
+                # Neither torch.add nor torch.mul
+                continue
+            if op_index not in op_index_to_activation_post_process_ctr:
+                continue
+            assert len(node.users) == 1
+            activation_post_process_name = next(iter(node.users.keys())).name
+            activation_post_process_ctr = op_index_to_activation_post_process_ctr[op_index]
+            if activation_post_process_ctr is not None:
+                setattr(cell, activation_post_process_name, activation_post_process_ctr())
+        layer.layer_fw.cell = cell
+    return quantizable_lstm
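+
+# A usage sketch for the helper above (the observer parameters are illustrative assumptions,
+# not prescribed defaults, and `float_lstm.qconfig` is assumed to be set already): callers
+# typically construct fixed-qparams observers for the sigmoid/tanh outputs and pass them in
+# together with example inputs, e.g.
+#
+#     from torch.ao.quantization import FixedQParamsObserver
+#     sigmoid_obs = FixedQParamsObserver.with_args(
+#         scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
+#     tanh_obs = FixedQParamsObserver.with_args(
+#         scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
+#     observed_lstm = _get_lstm_with_individually_observed_parts(
+#         float_lstm, example_inputs, sigmoid_obs_ctr=sigmoid_obs, tanh_obs_ctr=tanh_obs)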
+
+def _get_reference_quantized_lstm_module(
+    observed_lstm: torch.ao.nn.quantizable.LSTM,
+    backend_config: Optional[BackendConfig] = None,
+) -> torch.ao.nn.quantized.LSTM:
+    """
+    Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM`
+    with observers or fake quantizes inserted through `prepare_fx`, e.g. from
+    `_get_lstm_with_individually_observed_parts`.
+
+    This is meant to be used to convert an observed module to a quantized module in the
+    custom module flow.
+
+    Args:
+        `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx`
+        `backend_config`: BackendConfig to use to produce the reference quantized model
+
+    Return:
+        A reference `torch.ao.nn.quantized.LSTM` module.
+    """
+    quantized_lstm = torch.ao.nn.quantized.LSTM(
+        observed_lstm.input_size, observed_lstm.hidden_size, observed_lstm.num_layers,
+        observed_lstm.bias, observed_lstm.batch_first, observed_lstm.dropout,
+        observed_lstm.bidirectional)
+
+    for i, layer in enumerate(quantized_lstm.layers):
+        cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell)  # type: ignore[union-attr]
+        cell = convert_to_reference_fx(cell, backend_config=backend_config)  # type: ignore[arg-type]
+        assert isinstance(cell, torch.fx.GraphModule)
+        # HACK: Manually remove input quantize nodes and output dequantize nodes,
+        # since custom modules expect quint8 inputs and outputs for now. Note that
+        # this functionality is supposedly handled through PrepareCustomConfig's
+        # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that
+        # API doesn't currently handle tuple inputs and outputs, so we have to do
+        # this manually for now. In the future we should (1) relax the restriction
+        # on custom module input/output dtypes, and (2) expand support for complex
+        # input/output structures.
+        for node in cell.graph.nodes:
+            if node.target == torch.quantize_per_tensor:
+                arg = node.args[0]
+                # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
+                if arg.target == "x" or (arg.target == operator.getitem and arg.args[0].target == "hidden"):
+                    with cell.graph.inserting_before(node):
+                        node.replace_all_uses_with(arg)
+                        cell.graph.erase_node(node)
+            if node.target == "output":
+                # Remove all dequantize nodes in the output tuple
+                for arg in node.args[0]:
+                    with cell.graph.inserting_before(node):
+                        node.replace_input_with(arg, arg.args[0])
+        cell.graph.eliminate_dead_code()
+        cell.recompile()
+        layer.layer_fw.cell = cell
+    return quantized_lstm
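+
+# Taken together, the two helpers in this file implement the custom module flow for LSTM.
+# A sketch of the end-to-end usage (assuming `observed_lstm` was produced by
+# `_get_lstm_with_individually_observed_parts` and then calibrated):
+#
+#     observed_lstm = _get_lstm_with_individually_observed_parts(float_lstm, example_inputs)
+#     # ... run calibration data through observed_lstm ...
+#     quantized_lstm = _get_reference_quantized_lstm_module(observed_lstm)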
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/match_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/match_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b999f0b1d81b45a9748a79645bebcaedbc2fbf90
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/match_utils.py
@@ -0,0 +1,237 @@
+import sys
+import torch
+from torch.fx.graph import (
+    Graph,
+    Node,
+)
+from torch.ao.quantization.utils import Pattern
+from .quantize_handler import (
+    QuantizeHandler,
+)
+from ..qconfig import (
+    QConfigAny,
+)
+from ..utils import (
+    MatchAllNode
+)
+from .graph_module import (
+    _is_observed_standalone_module,
+)
+from torch.nn.utils.parametrize import type_before_parametrizations
+from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
+
+
+__all__: List[str] = []
+
+# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
+# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
+_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
+
+_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
+                                QConfigAny]
+
+# Note: The order of patterns is important! The match function will take whatever is matched first, so we
+# need to put fusion patterns before single patterns. For example, add_relu should be registered before relu.
+# Decorators are applied in reverse order of appearance. Also, when we match the nodes in the graph against
+# these patterns, we start from the last node of the graph and traverse backwards.
+def _is_match(modules, node, pattern, max_uses=sys.maxsize):
+    """ Matches a node in fx against a pattern
+    """
+    if isinstance(pattern, tuple):
+        self_match, *arg_matches = pattern
+        if self_match is getattr:
+            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
+            arg_matches = []
+    else:
+        self_match = pattern
+        arg_matches = []
+
+    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
+        return True
+
+    if node == pattern:
+        return True
+
+    if not isinstance(node, Node) or len(node.users) > max_uses:
+        return False
+
+    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
+        if node.op != 'call_module':
+            return False
+        if not type_before_parametrizations(modules[node.target]) == self_match:
+            return False
+    elif callable(self_match):
+        if node.op != 'call_function' or node.target is not self_match:
+            return False
+        elif node.target is getattr:
+            if node.args[1] != pattern[1]:
+                return False
+    elif isinstance(self_match, str):
+        if node.op != 'call_method' or node.target != self_match:
+            return False
+    elif node.target != self_match:
+        return False
+
+    if not arg_matches:
+        return True
+
+    if len(arg_matches) != len(node.args):
+        return False
+
+    return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
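+
+# Worked example (illustrative): for the pattern (torch.nn.ReLU, torch.nn.Conv2d) and a
+# graph `x -> conv (call_module, Conv2d) -> relu (call_module, ReLU)`, calling
+# _is_match(modules, relu_node, (torch.nn.ReLU, torch.nn.Conv2d)) first matches the ReLU
+# node itself and then recursively matches its argument against torch.nn.Conv2d, i.e.
+# matching starts from the last node of the fused sequence and walks backwards.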
+
+def _find_matches(
+        graph: Graph,
+        modules: Dict[str, torch.nn.Module],
+        patterns: Dict[Pattern, QuantizeHandler],
+        root_node_getter_mapping: Dict[Pattern, Callable],
+        standalone_module_names: Optional[List[str]] = None,
+        standalone_module_classes: Optional[List[Type]] = None,
+        custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
+    """
+    Matches the nodes in the input graph to quantization patterns, and
+    outputs the information needed to quantize them in future steps.
+
+    Inputs:
+      - graph: an fx.Graph object
+      - modules: a mapping of fully qualified module name to instance,
+          for example, {'foo': ModuleFoo, ...}
+      - patterns: a mapping from a tuple of nodes in reverse order to
+          uninitialized QuantizeHandler subclass.
+
+    Outputs a map of
+      node_name ->
+        (node, matched_values, matched_pattern, QuantizeHandler instance,
+         qconfig)
+
+    For example, {
+      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
+                 <QuantizeHandler instance>, QConfig(...)),
+      ...
+    }
+    """
+    if custom_module_classes is None:
+        custom_module_classes = []
+
+    if standalone_module_classes is None:
+        standalone_module_classes = []
+
+    if standalone_module_names is None:
+        standalone_module_names = []
+
+    match_map: Dict[str, _MatchResult] = {}
+    all_matched : Set[str] = set()
+
+    def _recursive_record_node_in_match_map(
+            last_node,
+            match_map,
+            node_pattern,
+            matched_node_pattern,
+            pattern,
+            match_value):
+        if isinstance(node_pattern, Node):
+            match_map[node_pattern.name] = (
+                last_node, matched_node_pattern, pattern, match_value)
+        elif not isinstance(node_pattern, Iterable):
+            return
+        else:
+            for n in node_pattern:
+                _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
+
+    # TODO: 1. merge with fuse matcher 2. document the code
+    def record_match(
+            pattern,
+            node,
+            last_node,
+            matched_node_pattern,
+            match_map):
+        if isinstance(pattern, tuple):
+            s, *args = pattern
+            is_single_arg = len(args) == 1
+            current_node_pattern: List[Node] = []
+            record_match(
+                s,
+                node,
+                last_node,
+                matched_node_pattern,
+                match_map)
+            if pattern[0] is not getattr:
+                for subpattern, arg in zip(args, node.args):
+                    record_match(
+                        subpattern,
+                        arg,
+                        node,
+                        current_node_pattern,
+                        match_map)
+            if len(current_node_pattern) > 1:
+                # current_node_pattern is the node pattern we get from matching
+                # the subpattern with the arguments of the node;
+                # we use is_single_arg to recover the original structure of the pattern:
+                # if the original pattern has a single argument, we will have
+                # (original_op, (original_arg, ...))
+                # otherwise, we'll have a list of arguments
+                # (original_op, arg0, arg1, arg2, ...)
+                if is_single_arg:
+                    matched_node_pattern.append(tuple(current_node_pattern))
+                else:
+                    matched_node_pattern.extend(list(current_node_pattern))
+            else:
+                matched_node_pattern.append(current_node_pattern[0])
+        else:
+            matched_node_pattern.append(node)
+
+    for node in reversed(graph.nodes):
+        if node.name not in match_map and node.name not in all_matched:
+            for pattern, quantize_handler_cls in patterns.items():
+                root_node_getter = root_node_getter_mapping.get(pattern, None)
+                if _is_match(modules, node, pattern) and node.name not in match_map:
+                    matched_node_pattern: List[Node] = []
+                    record_match(
+                        pattern,
+                        node,
+                        node,
+                        matched_node_pattern,
+                        match_map)
+                    quantize_handler = quantize_handler_cls(  # type: ignore[operator]
+                        matched_node_pattern,
+                        modules,
+                        root_node_getter)
+                    last_node = node
+                    # record the match for all nodes in the pattern
+                    _recursive_record_node_in_match_map(
+                        last_node,
+                        match_map,
+                        # we need to record all nodes in the matched pattern in the match_map
+                        matched_node_pattern,
+                        # this is a part of the value corresponding to the node
+                        matched_node_pattern,
+                        pattern,
+                        quantize_handler)
+                    break
+
+    # add custom module instances to the match result
+    assert modules is not None
+    for node in graph.nodes:
+        if node.op == 'call_module' and \
+           type(modules[node.target]) in custom_module_classes:
+            match_map[node.name] = (
+                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
+
+    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
+        assert modules is not None
+        return (
+            node_target in standalone_module_names or  # type: ignore[operator]
+            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
+        )
+
+    # add standalone modules to the match
+    for node in graph.nodes:
+        if node.op == 'call_module' and \
+           (is_standalone_module(node.target, modules) or
+                _is_observed_standalone_module(modules[node.target])):
+            # add node to matched nodes
+            match_map[node.name] = (
+                node, node, None,
+                QuantizeHandler(node, modules, is_standalone_module=True))
+
+    return match_map
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..53eeff6a8e4053849a8255b839ebc34f3b842b32
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/pattern_utils.py
@@ -0,0 +1,87 @@
+from collections import OrderedDict
+from typing import Dict, Any
+from torch.ao.quantization.utils import Pattern
+from ..fake_quantize import FixedQParamsFakeQuantize
+from ..observer import ObserverBase
+import copy
+
+__all__ = [
+    "get_default_fusion_patterns",
+    "get_default_quant_patterns",
+    "get_default_output_activation_post_process_map",
+]
+
+# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
+QuantizeHandler = Any
+
+# pattern for conv bn fusion
+_DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()
+def _register_fusion_pattern(pattern):
+    def insert(fn):
+        _DEFAULT_FUSION_PATTERNS[pattern] = fn
+        return fn
+    return insert
+
+def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
+    return copy.copy(_DEFAULT_FUSION_PATTERNS)
+
+_DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()
+
+# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation
+# e.g. pattern: torch.sigmoid,
+#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
+_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, QuantizeHandler] = {}
+_DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, QuantizeHandler] = {}
+
+# Register pattern for both static quantization and qat
+def _register_quant_pattern(pattern, fixed_qparams_observer=None):
+    def insert(fn):
+        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
+        if fixed_qparams_observer is not None:
+            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
+            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
+        return fn
+    return insert
+
+# Get patterns for both static quantization and qat
+def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
+    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)
+
+# a map from pattern to output activation post process constructor
+# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
+def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]:
+    if is_training:
+        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
+    else:
+        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)
+
+# Example use of register pattern function:
+# @_register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+# class ConvOrLinearBNReLUFusion():
+#     def __init__(...):
+#         ...
+#
+
+def _sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]:
+    """
+    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
+    e.g. match (F.relu, F.linear) before F.relu.
+    This works for current use cases, but we may need to have a more clever way to sort
+    things to address more complex patterns
+    """
+
+    def get_len(pattern):
+        """ Calculate the length of the pattern by counting all the entries
+        in the pattern.
+        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
+        (nn.BatchNorm, nn.Conv2d) so that we can match the former first.
+        """
+        length = 0
+        if isinstance(pattern, tuple):
+            for item in pattern:
+                length += get_len(item)
+        else:
+            length += 1
+        return length
+
+    return OrderedDict(sorted(patterns_dict.items(), key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1))
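+
+# Worked example (illustrative): get_len counts leaf entries, so
+# (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) has length 3 and (nn.BatchNorm2d, nn.Conv2d) has
+# length 2, while a bare nn.ReLU is not a tuple and sorts last with key 1. The fused
+# conv-bn-relu pattern is therefore tried before its sub-patterns during matching.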
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/prepare.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6f47c21a048b0cca8af05aeb358cc1c55312c8f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/prepare.py
@@ -0,0 +1,1880 @@
+import copy
+import torch
+import warnings
+from torch.fx import (
+    GraphModule,
+)
+from torch.fx.graph import (
+    Graph,
+    Node,
+)
+from torch.fx.node import Argument
+
+from ..quantize import (
+    propagate_qconfig_,
+)
+from ..observer import (
+    _is_activation_post_process,
+    _PartialWrapper,
+)
+from ..qconfig import (
+    _is_reuse_input_qconfig,
+    QConfigAny,
+)
+from ..qconfig_mapping import (
+    QConfigMapping,
+)
+from .qconfig_mapping_utils import (
+    _generate_node_name_to_qconfig,
+    _update_qconfig_for_fusion,
+    _get_flattened_qconfig_dict,
+    _update_qconfig_for_qat,
+)
+
+from .quantize_handler import (
+    _default_root_node_getter,
+    _get_pattern_to_quantize_handlers,
+    QuantizeHandler,
+)
+
+from torch.ao.quantization import (
+    ObserverBase,
+    FixedQParamsObserver,
+    FixedQParamsFakeQuantize,
+    _DerivedObserverOrFakeQuantize,
+)
+
+from torch.ao.quantization.utils import (
+    Pattern,
+    NodePattern,
+)
+
+from ._equalize import (
+    is_equalization_observer,
+    node_supports_equalization,
+)
+
+from .pattern_utils import (
+    _sorted_patterns_dict,
+)
+
+from .match_utils import (
+    _MatchResultWithQConfig,
+    _find_matches,
+)
+
+from .utils import (
+    _insert_dequant_stubs_for_custom_module_lstm_output,
+    _is_custom_module_lstm,
+    _maybe_get_custom_module_lstm_from_node_arg,
+    _qconfig_satisfies_dtype_config_constraints,
+    get_custom_module_class_keys,
+    all_node_args_have_no_tensors,
+    assert_and_get_unique_device,
+    get_non_observable_arg_indexes_and_types,
+    get_new_attr_name_with_prefix,
+    node_arg_is_weight,
+    node_arg_is_bias,
+    NON_QUANTIZABLE_WEIGHT_OPS,
+    ObservedGraphModuleAttrs,
+)
+
+from torch.ao.quantization import (
+    PlaceholderObserver
+)
+from torch.ao.quantization.quantize import (
+    convert
+)
+
+from ..utils import (
+    _parent_name,
+    get_qconfig_dtypes,
+    get_swapped_custom_module_class,
+)
+
+from ..backend_config.utils import (
+    get_pattern_to_dtype_configs,
+    get_module_to_qat_module,
+    get_fusion_pattern_to_root_node_getter,
+)
+from ..backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    get_native_backend_config,
+)
+from .custom_config import (
+    PrepareCustomConfig,
+    StandaloneModuleConfigEntry,
+)
+from torch.ao.quantization.quantizer import (
+    EdgeOrNode,
+    QuantizationSpec,
+    QuantizationSpecBase,
+    FixedQParamsQuantizationSpec,
+    SharedQuantizationSpec,
+    DerivedQuantizationSpec,
+)
+from torch.ao.quantization import ObserverOrFakeQuantize
+
+from torch._subclasses import FakeTensor
+
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from dataclasses import asdict
+
+__all__ = [
+    "insert_observers_for_model",
+    "prepare",
+    "propagate_dtypes_for_known_nodes",
+]
+
+
+# list of dtypes to not add observers to
+_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
+_OBS_DTYPE_LIST = [
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+    torch.float16,
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32
+]
+
+_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
+
+# note: the following default target dtype info dicts are temporary,
+# should be moved to the new programmable API class soon
+_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
+    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
+    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation
+}
+
+_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
+    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
+    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation
+}
+
+
+def _get_observer_kwargs(quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]):
+    kwargs_dict = asdict(quant_spec)
+    return copy.deepcopy(kwargs_dict)
+
+def _get_qspec_for_arg(
+    arg: Node,
+    input_qspec_map: Dict[Node, QuantizationSpecBase],
+    named_modules: Dict[str, torch.nn.Module]
+) -> Optional[QuantizationSpecBase]:
+    while _is_activation_post_process_node(arg, named_modules):
+        arg = arg.args[0]  # type: ignore[assignment]
+    return input_qspec_map.get(arg, None)
+
+def _create_obs_or_fq_from_qspec(
+    quantization_spec: Optional[QuantizationSpecBase],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+):
+    """ Create observer or fake quantize objects based on quantization spec
+
+    Args:
+       quantization_spec: used to store parameters to create the observer or fake quantizer
+       obs_or_fq_map: a map from edge/output to the corresponding observer/fake_quant
+           instance; it may be reused for different edges/outputs depending on the configuration
+    """
+    if quantization_spec is None:
+        return None
+    if isinstance(quantization_spec, SharedQuantizationSpec):
+        edge_or_node = quantization_spec.edge_or_node
+        assert edge_or_node in obs_or_fq_map, \
+            "please make sure only refer to edge or node that has " \
+            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
+        return obs_or_fq_map[edge_or_node]
+    elif isinstance(quantization_spec, DerivedQuantizationSpec):
+        # can't use asdict, so not calling get_observer_kwargs here
+        kwargs = {
+            "dtype": quantization_spec.dtype,
+            "derive_qparams_fn": quantization_spec.derive_qparams_fn,
+            "quant_min": quantization_spec.quant_min,
+            "quant_max": quantization_spec.quant_max,
+            "qscheme": quantization_spec.qscheme,
+            "ch_axis": quantization_spec.ch_axis,
+        }
+        edge_or_nodes = quantization_spec.derived_from
+        obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
+        kwargs["obs_or_fqs"] = obs_or_fqs
+        return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
+    elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
+        kwargs = _get_observer_kwargs(quantization_spec)
+        observer_ctr = FixedQParamsObserver.with_args(**kwargs)
+        if is_qat:
+            return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)
+        else:
+            return observer_ctr()
+
+    assert isinstance(quantization_spec, QuantizationSpec)
+    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
+    kwargs = _get_observer_kwargs(quantization_spec)
+    kwargs.pop("observer_or_fake_quant_ctr")
+    # we will remove is_dynamic from QuantizationSpec because
+    # it seems that dynamic range quantization
+    obs_or_fq_class = observer_or_fake_quant_ctr
+    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
+        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
+    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
+        kwargs.pop("ch_axis")
+    return observer_or_fake_quant_ctr.with_args(**kwargs)()
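+
+# A minimal sketch of how the spec types map to observers (the argument values are
+# illustrative assumptions): a plain QuantizationSpec instantiates its
+# observer_or_fake_quant_ctr with the remaining kwargs, while a SharedQuantizationSpec
+# reuses the observer/fake quant already created for the referenced edge or node, e.g.
+#
+#     act_qspec = QuantizationSpec(
+#         dtype=torch.int8, quant_min=-128, quant_max=127,
+#         qscheme=torch.per_tensor_affine,
+#         observer_or_fake_quant_ctr=torch.ao.quantization.MinMaxObserver)
+#     shared_qspec = SharedQuantizationSpec((input_node, conv_node))  # hypothetical nodes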
+
+def _needs_obs_or_fq(
+        prev_output_dtype: Any,
+        prev_output_is_dynamic: bool,
+        cur_target_dtype: Any,
+        cur_target_is_dynamic: bool,
+        reuse_input_obs_or_fq: bool,
+        is_zeroth_arg: bool = False) -> bool:
+    """
+    Utility function that checks whether we should insert an observer or fake quant node,
+    based on the dtype requested by the user for the nodes.
+    Note: we treat "not specified" as torch.float for now.
+
+    is_zeroth_arg: we only dynamically quantize the first arg of the node right now;
+      this should be removed once we support configuring dynamic quantization
+      for a specific argument (it can also be removed if we deprecate fx graph mode
+      quantization).
+    """
+
+    # need to insert placeholder observer for dynamic quantization so that it can
+    # be converted to choose_qparams -> q -> dq in convert step
+    if cur_target_is_dynamic:
+        assert cur_target_dtype in _OBS_DTYPE_LIST, \
+            f"Expected cur_target_dtype to be one of {_OBS_DTYPE_LIST}, but got: {cur_target_dtype}"
+        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
+        return is_zeroth_arg
+    if reuse_input_obs_or_fq:
+        return False
+    # non dynamic quantization
+    if cur_target_dtype in _OBS_DTYPE_LIST:
+        return prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] and cur_target_dtype != prev_output_dtype
+
+    # lots of error checking are skipped here for now
+    return False
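+
+# Worked examples (illustrative, with reuse_input_obs_or_fq=False):
+#   prev=torch.float,  target=torch.quint8, static  -> True  (insert an observer)
+#   prev=torch.quint8, target=torch.quint8, static  -> False (dtypes already match)
+#   target is dynamic                                -> observe only the zeroth arg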
+
+def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
+    return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
+        _is_activation_post_process(named_modules[str(node.target)])
+
+def _get_dtype_and_is_dynamic(obs_or_fq: Optional[ObserverOrFakeQuantize]) -> Tuple[Optional[torch.dtype], bool]:
+    """ Given an observer or fake quant module instance (or None), returns
+    a Tuple of its dtype and is_dynamic flag
+    """
+    # TODO: instead of instantiating the instance, we can use inspect to get the default args
+    if obs_or_fq is None:
+        return None, False
+    else:
+        return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)  # type: ignore[return-value]
+
+def _is_input_arg_dtype_supported_by_backend(
+    arg: Argument,
+    node: Node,
+    qconfig: QConfigAny,
+    dtype_config: DTypeConfig,
+    backend_config: BackendConfig,
+) -> bool:
+    """ Check if the configured qconfig for the argument
+    is supported by the backend or not
+    """
+    if isinstance(arg, (list, tuple)):
+        return all(_is_input_arg_dtype_supported_by_backend(
+            a, node, qconfig,
+            dtype_config, backend_config) for a in arg)
+    if not isinstance(arg, Node):
+        return True
+    # TODO: support check for standalone module
+    is_weight = node_arg_is_weight(node, arg)
+    is_bias = node_arg_is_bias(node, arg)
+    is_activation = not is_weight and not is_bias
+    if is_activation:
+        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
+        input_act_obs_or_fq = input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
+        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq)
+        # TODO(future PR): remove the cast to bool below after figuring
+        # out why backend_config has is_dynamic set to None in some cases.
+        return (dtype_config.input_dtype is None) or (
+            dtype_config.input_dtype == qconfig_dtype and
+            bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
+            _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
+        )
+    elif is_weight:
+        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
+        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
+        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
+        backend_config_weight_dtype = dtype_config.weight_dtype
+        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
+        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
+            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
+        return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
+    else:  # bias
+        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
+        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
+        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
+        backend_config_bias_dtype = dtype_config.bias_dtype
+        return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype
+
+def _is_output_dtype_supported_by_backend(
+    node: Node,
+    qconfig: QConfigAny,
+    dtype_config: DTypeConfig,
+) -> bool:
+    """ Check if the configured qconfig for the output
+    is supported by the backend or not
+    """
+    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+    backend_config_output_dtype = dtype_config.output_dtype
+    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
+    # from input activation check can be reused here
+    qconfig_output_dtype = None
+    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
+    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
+    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
+    # linear op which has fp32 output dtype, this should be removed if we generalize
+    # the structure of qconfig in the future
+    if qconfig_output_is_dynamic:
+        qconfig_output_dtype = torch.float32
+    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
+    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
+        qconfig, dtype_config.output_dtype_with_constraints)
+    return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
+
+def _is_observer_in_same_graph(
+    node: Node,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat,
+):
+    """ Check if the observer is in the same graph as the node.
+    When the node output is not fp32 and its input is a 'placeholder',
+    the input is assumed to be quantized, i.e. it is observed
+    in a different place rather than not observed at all.
+    """
+    node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules, obs_or_fq_map, is_qat)
+    if len(node.args) > 0 and isinstance(node.args[0], Node):
+        if node_output_dtype in [torch.quint8, torch.uint8] and node.args[0].op == 'placeholder':
+            return False
+    return True
+
+def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+    pattern: Optional[Pattern],
+    matched_node_pattern: Optional[List[Node]],
+    qconfig: QConfigAny,
+    backend_config: BackendConfig,
+) -> bool:
+    """ Check if the dtype configuration of a pattern is supported by
+    the backend or not, and whether the qconfig satisfies constraints
+    specified in the corresponding dtype config.
+    """
+    if backend_config is None or pattern is None:
+        return True
+    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
+    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
+    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
+    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
+
+    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
+    root_node = root_node_getter(matched_node_pattern)
+    input_node = root_node
+    output_node = matched_node_pattern[0]
+    for dtype_config in dtype_configs:
+        # check if arg dtype are supported
+        supported = True
+        for arg in list(input_node.args) + list(input_node.kwargs.values()):
+            supported = supported and _is_input_arg_dtype_supported_by_backend(
+                arg, input_node, qconfig, dtype_config, backend_config)
+        # check if output dtype is supported
+        supported = supported and _is_output_dtype_supported_by_backend(
+            output_node, qconfig, dtype_config)
+        if supported:
+            return True
+    return False
+
+def _get_standalone_module_configs(
+    node: Node,
+    named_modules: Dict[str, torch.nn.Module],
+    prepare_custom_config: PrepareCustomConfig,
+    parent_qconfig: QConfigAny,
+    parent_backend_config: Optional[BackendConfig],
+) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
+    """
+    Returns the standalone module's QConfigMapping, example inputs,
+    PrepareCustomConfig and BackendConfig for `node`, assuming that the
+    module pointed to by `node` is a standalone module.
+    """
+    module_name = str(node.target)
+    module_type = type(named_modules[module_name])  # type: ignore[index]
+    # name config has precedence over type config
+    config_entry = StandaloneModuleConfigEntry(None, (), None, None)
+    config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
+    config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
+    # fallback to use parent module's qconfig if user didn't specify qconfig dict
+    qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
+    example_inputs = config_entry.example_inputs
+    prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
+    backend_config = config_entry.backend_config or parent_backend_config
+    return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+
+def _qat_swap_modules(
+        root: torch.nn.Module,
+        module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None:
+    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
+
+def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
+    if isinstance(matched_node_pattern, Node):
+        s.add(matched_node_pattern.name)
+    elif isinstance(matched_node_pattern, (list, tuple)):
+        for maybe_node in matched_node_pattern:
+            _add_matched_node_name_to_set(maybe_node, s)
+
+def _insert_obs_or_fq(
+    node: Node,
+    obs_or_fq: ObserverOrFakeQuantize,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Attaches `obs_or_fq` to `model`, and creates a node which calls
+    `obs_or_fq` on the output of `node`.
+
+    obs_or_fq: an instance of Observer or FakeQuantize module
+    """
+    model_device = assert_and_get_unique_device(model)
+    if model_device:
+        obs_or_fq.to(model_device)
+    # add obs_or_fq module as attribute
+    if is_equalization_observer(obs_or_fq):
+        prefix = node.name + '_equalization_process_'
+    else:
+        prefix = 'activation_post_process_'
+    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
+    obs_or_fq_name = get_new_obs_or_fq_name(model)
+    setattr(model, obs_or_fq_name, obs_or_fq)
+    named_modules[obs_or_fq_name] = obs_or_fq
+    with graph.inserting_after(node):
+        new_obs = graph.create_node(
+            'call_module', obs_or_fq_name, (node,), {})
+    return new_obs
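+
+# Effect on the graph (sketch): given `x -> linear1 -> linear2`, calling this helper on the
+# output of `linear1` attaches a new `activation_post_process_0` submodule and creates
+#
+#     x -> linear1 -> activation_post_process_0
+#
+# as an additional user of `linear1`; callers such as
+# _maybe_insert_input_observer_for_arg_or_kwarg then rewire `linear2` to consume
+# `activation_post_process_0` instead of `linear1`.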
+
+def _set_target_dtype_info_for_matched_node_pattern(
+    matched_node_pattern: NodePattern,
+    last_node: Node,
+    qconfig: QConfigAny,
+    qhandler: Optional[QuantizeHandler],
+    backend_config: BackendConfig,
+    named_modules: Dict[str, torch.nn.Module],
+    cache_for_no_tensor_check: Dict[Node, bool],
+    processed_nodes: Set[Node],
+) -> None:
+    """ Sets the target_dtype_info for each node in matched_node_pattern
+    Note: processed_nodes is used to ensure we only process each node once
+    """
+    if isinstance(matched_node_pattern, (list, tuple)):
+        for node_pattern in matched_node_pattern:
+            _set_target_dtype_info_for_matched_node_pattern(
+                node_pattern,
+                last_node,
+                qconfig,
+                qhandler,
+                backend_config,
+                named_modules,
+                cache_for_no_tensor_check,
+                processed_nodes
+            )
+
+    # set target_dtype_info if matched_node_pattern is a Node
+    # other types of matched object, e.g. int, float literals, are ignored
+    elif isinstance(matched_node_pattern, Node):
+        # for pyre
+        assert isinstance(matched_node_pattern, Node)
+        node = matched_node_pattern
+        if node in processed_nodes:
+            return
+        processed_nodes.add(node)
+
+        if qconfig is None:
+            return
+        # TODO: refactor the following code in terms of apply a qconfig to a pattern
+        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
+        # we set the input_obs_or_fq_ctr for the arguments of op1 based on qconfig.input_act,
+        # and set output_obs_or_fq_ctr based on qconfig.output_act
+        # this also requires we extend the structure of QConfig to support more fine
+        # grained configurations
+        target_dtype_info: Dict[str, Any] = (
+            _get_target_activation_dtype_for_node(
+                node,
+                qconfig,
+                qhandler,
+                named_modules,
+                backend_config,
+                cache_for_no_tensor_check,
+            )
+        )
+        node.meta["target_dtype_info"] = target_dtype_info
+
+def _get_target_activation_dtype_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    qhandler: Optional[QuantizeHandler],
+    named_modules: Dict[str, torch.nn.Module],
+    backend_config: BackendConfig,
+    cache_for_no_tensor_check: Dict[Node, bool],
+) -> Dict[str, Any]:
+    """
+    For each op attribute in the op's input activation, output activation,
+    weight, bias - returns the settings of dtype and is_dynamic we expect
+    for the `quantize` call in the reference model representation, or None
+    if there is no `quantize` call needed.
+
+    For example, if we have a node corresponding to `op0` in
+
+      x0 -> op0 -> x1
+
+    And we want a reference quantized representation to be
+
+      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1
+
+    Then this function will return
+
+      {
+        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
+        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
+      }
+
+    TODO(future PR, if needed): explicitly spell out the non-Tensor
+    dtypes.
+    """
+    args_have_no_tensors = \
+        all_node_args_have_no_tensors(
+            node, named_modules, cache_for_no_tensor_check)
+    if args_have_no_tensors:
+        return {
+            "input_act_obs_or_fq_ctr": None,
+            "output_act_obs_or_fq_ctr": None,
+        }
+    # get qconfig to determine the eventual dtype of this node
+    if qconfig is not None:
+        act_dtype, weight_dtype, input_act_is_dynamic = \
+            get_qconfig_dtypes(qconfig)
+
+        # Currently `QConfig` only has one `activation` field.
+        # For static quantization, it is reused for both input
+        # and output activation. For dynamic quantization, this
+        # field is currently only used for the input activation,
+        # with the output activation being in fp32.
+        # In the future this may change as we add more fields
+        # to the `QConfig` object.
+        output_act_dtype = act_dtype \
+            if (not input_act_is_dynamic) else torch.float
+
+        bias_dtype = torch.float16 \
+            if (
+                act_dtype == torch.float16
+                and weight_dtype == torch.float16
+                and (not input_act_is_dynamic)
+            ) else torch.float
+
+        is_general_tensor_value_op = \
+            (qhandler is not None and qhandler.is_general_tensor_value_op())
+
+        _is_standalone_module = (
+            qhandler is not None and qhandler.is_standalone_module()
+        )
+
+        weight_index = None
+        if isinstance(node, Node) and node.op == "call_function" and \
+           node.target in backend_config._pattern_complex_format_to_config:
+            weight_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("weight")
+
+        bias_index = None
+        if isinstance(node, Node) and node.op == "call_function" and \
+           node.target in backend_config._pattern_complex_format_to_config:
+            bias_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("bias")
+
+        return {
+            "input_act_obs_or_fq_ctr": qconfig.activation,
+            "weight_obs_or_fq_ctr": qconfig.weight,
+            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
+            "weight_index": weight_index,
+            "bias_index": bias_index,
+            "output_act_obs_or_fq_ctr": qconfig.activation,
+            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
+            "input_output_share_observers": is_general_tensor_value_op,
+            "_is_standalone_module": _is_standalone_module,
+        }
+    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
+
+def _get_output_act_obs_or_fq(
+    arg: Node,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> ObserverOrFakeQuantize:
+    """ Get the observer or fake quant object for the argument in the
+    original graph, viewed as the output of the previous node,
+    skipping inserted observers
+
+    We are assuming that the observers are inserted correctly, and that the dtype for
+    the argument in the quantized graph will match what is specified by the qconfig
+    """
+    assert isinstance(arg, Node)
+    if "quantization_annotation" in arg.meta:
+        return _create_obs_or_fq_from_qspec(arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)
+
+    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
+    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    # Since we modified the graph in this case, we must trace back from the args through
+    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
+    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
+    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
+    output_act_obs_or_fq_ctr = None
+    if custom_module_lstm_node is not None:
+        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+    elif _is_activation_post_process_node(arg, named_modules):
+        observed_arg = arg.args[0]
+        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
+        if "quantization_annotation" in observed_arg.meta:
+            output_act_obs_or_fq = \
+                _create_obs_or_fq_from_qspec(
+                    observed_arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)
+        else:
+            assert "target_dtype_info" in observed_arg.meta
+            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+            output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+    else:
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = \
+                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
+        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+
+    return output_act_obs_or_fq
+
+def _get_arg_target_dtype_as_output(
+    arg: Node,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[torch.dtype]:
+    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
+    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
+    return arg_as_output_target_dtype
+
+def _get_arg_as_input_act_obs_or_fq(
+    arg: Node,
+    node: Node,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[ObserverOrFakeQuantize]:
+    """ Get the observer or fake quant instance for the Argument `arg`, used as
+    input to Node `node`
+    """
+    assert isinstance(arg, Node)
+    # "input_qspec_map" is the more general design we'll use for pt2e path
+    # it is a map from input argument node to observer or fake quant constructor, for example
+    # for the following graph:
+    # x -> conv -> output
+    #
+    # we may annotate conv node like the following:
+    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
+    #
+    if "quantization_annotation" in node.meta:
+        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
+        if input_arg_qspec is None:
+            input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR()
+        else:
+            input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(input_arg_qspec, obs_or_fq_map, is_qat)
+        return input_arg_obs_or_fq
+
+    # we can remove the following path in the future if fx graph mode quantization is
+    # no longer used
+    is_weight = node_arg_is_weight(node, arg)
+    is_bias = node_arg_is_bias(node, arg)
+    is_activation = not is_weight and not is_bias
+    obs_or_fq_ctr = None
+    if is_activation:
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    elif is_weight:
+        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
+            obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    else:
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
+    return obs_or_fq_ctr() if obs_or_fq_ctr else None
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    qhandler: Optional[QuantizeHandler],
+    prepare_custom_config: PrepareCustomConfig,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+    backend_config: Optional[BackendConfig] = None,
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node, inner_arg, qconfig, model, named_modules,
+                graph,
+                qhandler,
+                prepare_custom_config,
+                obs_or_fq_map,
+                is_qat,
+                backend_config)
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
+    # TODO: move this to a separate function
+    if not is_standalone_module:
+        # Note: qconfig can be None in this branch since we are getting act/fq from
+        # node.meta now
+        # regular flow for most nodes, except standalone modules
+
+        if "quantization_annotation" in node.meta:
+            reuse_input_obs_or_fq = node.meta["quantization_annotation"]._reuse_input_obs_or_fq
+        else:
+            assert "target_dtype_info" in node.meta
+            # TODO: we are assuming "target_dtype_info" exists here; maybe
+            # a default value also needs to be provided here
+            target_dtype_info = node.meta["target_dtype_info"]
+            # for nodes that don't have `reuse_input_obs_or_fq` configured,
+            # we'll default to False; this makes configuring this field optional for users
+            reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
+        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
+        arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)
+
+        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
+        arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
+
+
+        needs_obs_or_fq = _needs_obs_or_fq(
+            arg_as_output_target_dtype,
+            arg_as_output_target_is_dynamic,
+            arg_as_input_target_dtype,
+            arg_as_input_target_is_dynamic,
+            reuse_input_obs_or_fq,
+            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
+        )
+
+    else:
+        assert qconfig is not None
+        # custom flow for standalone modules
+        _, _, sm_prepare_custom_config, _ = \
+            _get_standalone_module_configs(
+                node, named_modules, prepare_custom_config, qconfig, backend_config)
+        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes
+
+        # for args, this is set to the index of the current arg
+        # for kwargs, this is left at None
+        cur_input_idx = None
+        for arg_idx, arg_to_check in enumerate(node.args):
+            if arg_to_check is arg:
+                cur_input_idx = arg_idx
+                break
+
+        if cur_input_idx is None:
+            needs_obs_or_fq = False
+        else:
+            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules, obs_or_fq_map, is_qat)
+            arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
+                else torch.float
+            needs_obs_or_fq = (
+                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
+                (arg_as_input_target_dtype != torch.float)
+            )
+
+        act_post_process_ctr = qconfig.activation
+        arg_as_input_act_obs_or_fq = act_post_process_ctr() if act_post_process_ctr else None
+
+    if needs_obs_or_fq:
+
+        existing_obs_node = None
+
+        # Before creating a new observer, check if an observer
+        # of the correct type already exists for this arg. If it does, reuse it.
+        # This prevents duplicate observer insertions when a node is
+        # used by multiple consumers.
+        # TODO: this peeks at how the value is used downstream; we should remove it.
+        # Removing it means we insert one observer for each use, even if they
+        # have the same dtype; an extra pass could then remove the redundant observers.
+        for maybe_obs_node in arg.users.keys():
+            if maybe_obs_node.op == 'call_module':
+                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+                if (
+                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
+                    maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
+                ):
+                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
+                    existing_obs_node = maybe_obs_node
+                    break
+
+        assert arg_as_input_act_obs_or_fq is not None
+        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
+        if existing_obs_node is None:
+            new_obs_node = _insert_obs_or_fq(
+                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph)
+            # override this arg to be the observed arg
+            new_arg = new_obs_node
+        else:
+            new_arg = existing_obs_node
+
+    return new_arg
+
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    qhandler: Optional[QuantizeHandler],
+    prepare_custom_config: PrepareCustomConfig,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+    backend_config: Optional[BackendConfig] = None
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+
+    Note: backend_config only needed for standalone_module node
+    """
+    # Look through every input arg.  If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node, arg, qconfig, model, named_modules, graph,
+            qhandler,
+            prepare_custom_config,
+            obs_or_fq_map,
+            is_qat,
+            backend_config)
+        new_args.append(new_arg)
+
+    new_kwargs = {}
+    for k, kwarg in node.kwargs.items():
+        new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node, kwarg, qconfig, model, named_modules, graph,
+            qhandler,
+            prepare_custom_config,
+            obs_or_fq_map,
+            is_qat,
+            backend_config)
+        new_kwargs[k] = new_kwarg
+
+    # assign the new args and kwargs to the node, inplace
+    node.args = tuple(new_args)
+    node.kwargs = new_kwargs
+
+def _maybe_insert_input_equalization_observers_for_node(
+    node: Node,
+    equalization_qconfig: Any,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    is_branch: bool,
+) -> None:
+    """
+    If `node` needs to be equalized, finds the input/weight observer constructors
+    it needs in `equalization_qconfig`, creates the observers, and inserts them into `graph`.
+
+    If `node` does not need an equalization observer, returns None.
+    """
+    if equalization_qconfig is None or not node_supports_equalization(node, named_modules):
+        return
+
+    if is_branch:
+        warnings.warn(
+            f"Cannot equalize {node} because it is part of a branch."
+        )
+        return
+
+    new_args = []
+    for arg in node.args:
+        if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
+            new_args.append(arg)
+            continue
+
+        is_weight = node_arg_is_weight(node, arg)
+
+        act_eq_process_ctr = equalization_qconfig.weight if is_weight else \
+            equalization_qconfig.input_activation
+
+        new_eq_obs_mod = act_eq_process_ctr()
+        new_eq_obs_node = _insert_obs_or_fq(
+            arg, new_eq_obs_mod, model, named_modules, graph)
+
+        new_args.append(new_eq_obs_node)
+
+    # assign the new args and kwargs to the node, inplace
+    node.args = tuple(new_args)
+
+def _maybe_insert_output_observer_for_node(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[Node]:
+    """
+    If `node` needs an output observer, creates it, inserts it into `graph`
+    and returns it.
+
+    If `node` does not need an output observer, returns None.
+
+    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
+    quantization code path right now
+    """
+    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'
+
+    is_standalone_module = False
+    if "quantization_annotation" in node.meta:
+        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
+            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
+        )
+    else:
+        assert "target_dtype_info" in node.meta
+        is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False)
+        output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
+        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
+    # uncomment after we support reuse_input_obs_or_fq properly by having a separate
+    # implementation for this key instead of reusing the input_output_share_observers
+    # code
+    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
+    # for now we set this to False since reuse_input_obs_or_fq for
+    # the output of a node is implemented in the same code path as observer sharing;
+    # we should refactor this part to make it clearer in the future,
+    # and then we would be able to read this from the config directly
+    reuse_input_obs_or_fq = False
+
+    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
+    # because the prev_output is the output of an fp32 op, although technically
+    # we should get the dtype of the output from node.meta["val"] in the future
+    # if we deprecate fx graph mode quantization
+    needs_obs_or_fq = _needs_obs_or_fq(torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq)
+    # currently the activation in QConfig(activation=...) applies to both input
+    # and output, and when the activation is configured for dynamic quantization,
+    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
+    # the input should be dynamically quantized, but the output should not be quantized
+    #
+    # there is no way to specify different observers/fqs for input and output
+    # activation through QConfig today; this limitation is lifted in the
+    # quantizer/annotation API of the PyTorch 2.0 export quantization code path,
+    # but since this code is reused there, annotating an output to be dynamically
+    # quantized does not work for that path either.
+    # we could change QConfig to support separate input/output activation if we want
+    # to remove the following check, or we can deprecate fx graph mode quantization
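+    # Illustrative example of the dynamic case described above, assuming the
+    # stock observer classes:
+    #   qconfig = QConfig(
+    #       activation=PlaceholderObserver.with_args(dtype=torch.quint8, is_dynamic=True),
+    #       weight=MinMaxObserver.with_args(dtype=torch.qint8),
+    #   )
+    # with such a qconfig the input is dynamically quantized at runtime, so no
+    # static output observer should be inserted for this node.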
+    if target_is_dynamic:
+        needs_obs_or_fq = False
+
+    # we never insert observers at the output of a standalone module; we assume
+    # that, if needed, they are inserted inside the standalone module
+    needs_obs_or_fq = needs_obs_or_fq and \
+        (not is_standalone_module)
+
+    if needs_obs_or_fq:
+        obs_or_fq_map[node] = output_act_obs_or_fq
+        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
+    else:
+        return None
+
+def _maybe_insert_observers_before_graph_output(
+    graph_output_node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> None:
+    """
+    If the output needs to be quantized and there are any nodes
+    in the output which are not already observed, inserts observers
+    for those nodes.
+    """
+
+    def _recursive_maybe_replace_node_with_obs(
+        maybe_node: Argument,
+        model: torch.nn.Module,
+        named_modules: Dict[str, torch.nn.Module],
+        graph: Graph,
+    ) -> Argument:
+        """
+        Navigate an arbitrary data structure of lists, tuples, dicts.
+        For each container type, recurse on all inputs. Once any Node
+        is found, insert an observer if needed and do not recurse further.
+
+        For example, given a structure of
+
+          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}
+
+        we recurse down to bar1 and bar3, observe them if necessary,
+        and if we inserted an observer then replace the original node
+        with its observer.
+
+        Returns the data structure with all nodes needing observation being
+        replaced by their observers.
+        """
+        if isinstance(maybe_node, Node):
+            # check dtype of this node
+            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(maybe_node, named_modules, obs_or_fq_map, is_qat)
+            observer_mod = None
+            arg_as_input_target_dtype = torch.float
+            if "target_dtype_info" in maybe_node.meta:
+                observer_cls = maybe_node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", None)
+                if observer_cls is not None:
+                    observer_mod = observer_cls()
+                    arg_as_input_target_dtype = observer_mod.dtype
+            # TODO: this does not handle dynamic quantization yet
+            need_obs = (
+                arg_as_output_target_dtype != arg_as_input_target_dtype and
+                arg_as_input_target_dtype != torch.float
+            )
+            if need_obs:
+                assert observer_mod is not None
+                # insert observer
+                observer_node = _insert_obs_or_fq(
+                    maybe_node, observer_mod, model, named_modules, graph)
+                return observer_node
+            else:
+                return maybe_node
+        elif isinstance(maybe_node, (list, tuple)):
+            results = []
+            for inner_node in maybe_node:
+                results.append(_recursive_maybe_replace_node_with_obs(
+                    inner_node, model, named_modules, graph))
+            if isinstance(maybe_node, list):
+                return results
+            else:
+                return tuple(results)
+        elif isinstance(maybe_node, dict):
+            results_dict = {}
+            for k, inner_v in maybe_node.items():
+                results_dict[k] = _recursive_maybe_replace_node_with_obs(
+                    inner_v, model, named_modules, graph)
+            return results_dict
+        elif maybe_node is None:
+            return None
+        else:
+            raise Exception("Unhandled type for returned node:", maybe_node)
+
+    new_args = []
+    for old_arg in graph_output_node.args:
+        new_args.append(
+            _recursive_maybe_replace_node_with_obs(
+                old_arg, model, named_modules, graph))
+
+    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]
+
+
+def _maybe_propagate_dtype_for_node(
+    node: Node,
+    target_dtype: Union[torch.dtype, type],
+    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
+) -> None:
+    """
+    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
+    is a general tensor shape op, also call this function recursively on
+    the first argument, to propagate the dtype to the caller.
+    """
+    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
+    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
+    # if this is a copy node, propagate to first arg
+    root_node, _, pattern, qhandler, qconfig = node_name_to_match_result_with_qconfig.get(
+        node.name, (None, None, None, None, None))
+    # TODO: probably need to remove `is_general_tensor_value_op`
+    if qhandler is not None and qhandler.is_general_tensor_value_op():
+        prev_node = node.args[0]
+        if isinstance(prev_node, Node):
+            _maybe_propagate_dtype_for_node(
+                prev_node, target_dtype, node_name_to_match_result_with_qconfig)
+
+def propagate_dtypes_for_known_nodes(
+    graph: Graph,
+    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
+) -> None:
+    """
+    Currently we assume that inputs to the graph are either `torch.float` or
+    `torch.quint8`, which is not always correct. For ops such as
+    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
+    `BoolTensor`. Propagate this information throughout the graph.
+
+    Note: not all dtypes in the graph will be correct after this pass, but a
+    higher percentage of them will be correct. Hopefully in the future we can
+    replace this with a better way to reason about dtypes of tensors.
+    """
+    for node in graph.nodes:
+        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)
+
+        for arg_type in non_observable_arg_dict:
+            non_observable_indices = non_observable_arg_dict[arg_type](node)
+
+            for index in non_observable_indices:
+                arg = node.args[index]
+
+                # when an argument is a tuple, it does not show up as another node so we need to go through
+                # all elements of the tuple manually
+                if isinstance(arg, (tuple, list)):
+                    arg_list = list(arg)
+                else:
+                    arg_list = [arg]
+
+                for cur_arg in arg_list:
+                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
+                    if isinstance(cur_arg, torch.fx.node.Node):
+                        _maybe_propagate_dtype_for_node(
+                            cur_arg, arg_type, node_name_to_match_result_with_qconfig)
+
+def _maybe_make_input_output_share_observers(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+) -> bool:
+    """
+    Ensures that we share an observer
+    for all input arguments as well as the output argument. In detail, given
+    a graph of
+
+      x0 -> obs0 -> op -> obs2 -> x2
+                  /
+      x1 -> obs1 /
+
+    where node obs0 points to observer instance observer0,
+    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
+    and obs2 point to observer0.
+    Returns: whether the operation succeeded or not
+    """
+    first_arg = None
+    # find the first arg that is a Node (or a list/tuple of Nodes)
+    for i in range(len(node.args)):
+        if isinstance(node.args[i], (Node, list, tuple)):
+            first_arg = node.args[i]
+            break
+
+    # if there is no such arg, return directly
+    if first_arg is None:
+        return False
+
+    if isinstance(first_arg, (list, tuple)):
+        first_arg_arg = first_arg[0]
+    elif isinstance(first_arg, Node):
+        first_arg_arg = first_arg
+    else:
+        return False
+
+    # if we have a graph such as
+    #   observed_node -> non_observed_node -> cat
+    # we need to navigate up to the first observer
+    iteration_guard = 0
+    while not _is_activation_post_process_node(first_arg_arg, named_modules):
+        if not isinstance(first_arg_arg, Node):
+            return False
+        # did not find an activation_post_process for the op
+        if first_arg_arg.op == "placeholder":
+            return False
+        # trace back the args until we find the first Node
+        trace_back_node = None
+        for i in range(len(first_arg_arg.args)):
+            trace_back_node = first_arg_arg.args[i]
+            if isinstance(trace_back_node, Node):
+                break
+        if trace_back_node is None:
+            return False
+        first_arg_arg = trace_back_node
+
+        iteration_guard += 1
+        if iteration_guard > 10000:
+            raise AssertionError('Unable to find observer of previous node')
+
+    assert isinstance(first_arg_arg, Node)
+    target_to_use = first_arg_arg.target
+    assert isinstance(target_to_use, str)
+    obs_mod_to_use = named_modules[target_to_use]
+
+    if isinstance(first_arg, (list, tuple)):
+        # set all other input observer nodes to use that module
+        for input_idx, input_arg in enumerate(first_arg):
+            if input_idx == 0:
+                continue
+            iteration_guard = 0
+            while not _is_activation_post_process_node(input_arg, named_modules):
+                # failed to trace back since no input arg for the current node
+                if len(input_arg.args) < 1:
+                    return False
+                input_arg = input_arg.args[0]
+                iteration_guard += 1
+                if iteration_guard > 10000:
+                    raise AssertionError('Unable to find observer of previous node')
+
+            parent_name, name = _parent_name(input_arg.target)
+            setattr(named_modules[parent_name], name, obs_mod_to_use)
+
+    # set the output observer node to use that module
+    for output_obs_node in node.users.keys():
+        assert _is_activation_post_process_node(output_obs_node, named_modules)
+        parent_name, name = _parent_name(output_obs_node.target)
+        setattr(named_modules[parent_name], name, obs_mod_to_use)
+
+    # TODO(future PR): delete the orphaned observer modules
+    return True
+
+def _remove_output_observer(
+        node: Node,
+        model: torch.nn.Module,
+        named_modules: Dict[str, torch.nn.Module]):
+    items = list(node.users.items())
+    for output_obs_node, _ in items:
+        assert _is_activation_post_process_node(output_obs_node, named_modules)
+        output_obs_node.replace_all_uses_with(node)
+        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]
+
+def _swap_custom_module_to_observed(
+        node: Node,
+        qconfig: QConfigAny,
+        named_modules: Dict[str, torch.nn.Module],
+        prepare_custom_config: PrepareCustomConfig):
+    custom_module = named_modules[node.target]  # type: ignore[index]
+    custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
+    observed_custom_module_class = \
+        get_swapped_custom_module_class(
+            custom_module, custom_module_class_mapping, qconfig)
+    observed_custom_module = \
+        observed_custom_module_class.from_float(custom_module)
+    parent_name, name = _parent_name(node.target)
+    setattr(named_modules[parent_name], name, observed_custom_module)
+
+def insert_observers_for_model(
+    model: GraphModule,
+    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
+    node_name_to_qconfig: Dict[str, QConfigAny],
+    prepare_custom_config: PrepareCustomConfig,
+    equalization_config_map: Dict[str, Any],
+    backend_config: BackendConfig,
+    observed_node_names: Set[str],
+    is_qat: bool,
+) -> Optional[Node]:
+    """
+    Inserts observers, using the following high level algorithm:
+
+    For each node in the graph:
+      1. determine the target dtype of this node in the quantized graph, and save
+           it for future steps
+      2. determine the target dtype of all args and kwargs of this node
+      3. if any arg or kwarg's target dtype does not match the current node's
+           dtype, insert an observer
+      4. if the current node needs an output observer, insert it
+
+    For example:
+
+    - starting graph:
+        x0 -> linear -> x1
+
+    - observed graph after processing x0:
+        x0(fp32)
+
+    - observed graph after processing linear:
+        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
+
+    - observed graph after processing x1:
+        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
+
+    After a node is processed, the naive observer placement is guaranteed to be
+    complete for that node and all of its predecessors. There can be future
+    passes which optimize the graph by deduplicating observers, etc.
+    """
+
+    # node.meta["target_dtype_info"] stores the target dtype information
+    # that's derived from qconfig for the Node, for example, if we have
+    # a conv2d node that has a qconfig
+    # qconfig = QConfig(activation=..., weight=...)
+    # # information for input and bias node omitted
+    # # for getattr node
+    # # weight = getattr(self, 'weight')
+    # weight.meta["target_dtype_info"] = {
+    #    'output_act_obs_or_fq_ctr': qconfig.weight,
+    # }
+    # # for conv2d node
+    # # conv2d = call_function[target=torch.nn.functional.conv2d](
+    # #            args=(input, weight, bias))
+    # conv2d.meta["target_dtype_info"] = {
+    #   'input_act_obs_or_fq_ctr': qconfig.activation,
+    #   'weight_obs_or_fq_ctr': qconfig.weight,
+    #   'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
+    #   'output_act_obs_or_fq_ctr': qconfig.activation,
+    # }
+    #
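+    # As a concrete (illustrative) example, assuming the stock defaults, such a
+    # qconfig could come from
+    #   qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+    # whose `activation` and `weight` fields are observer constructors created
+    # with <ObserverClass>.with_args(...).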
+    cache_for_no_tensor_check: Dict[Node, bool] = {}
+
+    # first, populate the dtype map based only on qconfig and qhandler
+    # this assumes:
+    # graph inputs are fp32 by default, and int8 where overridden
+    # other nodes' output dtype is specified by the qconfig
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+
+    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
+    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
+    processed_nodes: Set[Node] = set()
+    # initialize target_dtype_info
+    for node in model.graph.nodes:
+        node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
+
+    inputs_seen_counter = 0
+    outputs_seen_counter = 0
+    placeholder_node_to_input_index: Dict[Node, int] = {}
+    # TODO: we probably don't need this counter since each graph will only have
+    # one output node?
+    output_node_to_output_index: Dict[Node, int] = {}
+    for node in model.graph.nodes:
+        if node.op == "placeholder":
+            placeholder_node_to_input_index[node] = inputs_seen_counter
+            inputs_seen_counter += 1
+        if node.op == "output":
+            output_node_to_output_index[node] = outputs_seen_counter
+            outputs_seen_counter += 1
+
+    # Step 1, set the observer or fake quantize module constructor for each node in the
+    # matched_node_pattern
+
+    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
+        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
+        assert qhandler is not None
+        _set_target_dtype_info_for_matched_node_pattern(
+            matched_node_pattern,
+            last_node,
+            qconfig,
+            qhandler,
+            backend_config,
+            named_modules,
+            cache_for_no_tensor_check,
+            processed_nodes
+        )
+
+    # Step 2. Special cases for some operators, we might be able to remove them
+    # in the future if we know dtype information of each node better
+
+    # Step 2.1. some settings are not based on patterns, we need to process each node
+    # instead
+    for node in model.graph.nodes:
+        if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs:
+            # users are not supposed to call calculate_qparams on PlaceholderObserver, and
+            # that is OK here because we only use it to encode the dtype of the input
+            # tensor; we won't actually insert these observers in the graph and won't
+            # actually call calculate_qparams
+            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
+        elif node.op in ("call_module", "call_method", "call_function"):
+            args_have_no_tensors = \
+                all_node_args_have_no_tensors(
+                    node, named_modules, cache_for_no_tensor_check)
+            if args_have_no_tensors:
+                node.meta["target_dtype_info"] = {
+                    "input_act_obs_or_fq_ctr": None,
+                    "output_act_obs_or_fq_ctr": None,
+                }
+        elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs:
+            # TODO(future PR): update the output_quantized_idxs API to match
+            # arbitrary data structures. There is always a single output, and
+            # that output can have arbitrary nesting of values. List[int] is
+            # not the right data type for this.
+
+            # TODO(future PR): support more dtypes in model outputs, if necessary
+            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
+
+    # Step 2.2, for nodes with known input dtypes, propagate them throughout the
+    # graph. For example, if there is a call such as
+    #   x1 = x0.masked_fill(mask, 1)
+    # we propagate the type of mask to be torch.bool
+    propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig)
+
+    # Step 3, check if the requested target_dtype_info is supported by backend or not
+    # if not, we'll reset the target_dtype_info to use the default (float Tensor)
+
+    # reset the counters and set of processed_nodes
+    processed_nodes: Set[Node] = set()
+    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
+        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
+        is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+            pattern, matched_node_pattern, qconfig, backend_config)
+        assert qhandler is not None
+
+        # get output_act_dtype so that we don't also reset the special typed nodes
+        # TODO: we might want to handle these more uniformly with the default path
+        # this can be improved if we can use node.meta["val"]
+        output_act_or_fq_ctr = node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
+        output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
+        output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
+        if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]:
+            # restore target_dtype_info to default if it is not supported by backend
+            _set_target_dtype_info_for_matched_node_pattern(
+                matched_node_pattern,
+                last_node,
+                torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
+                None,
+                backend_config,
+                named_modules,
+                cache_for_no_tensor_check,
+                processed_nodes
+            )
+
+    # After this point, the current node and all of its arguments
+    # have a target_dtype_info assigned. Now, we insert observers for inputs
+    # of this node (if needed for this node), and the output of this node
+    # (if needed for this node).
+
+    # Since we are mutating the graph as we go, we iterate over the original
+    # nodes before observer insertion, instead of model.graph.nodes.
+    nodes_before_observation = list(model.graph.nodes)
+
+    # Avoid duplicate custom module swaps for multiple nodes with the same target.
+    custom_module_names_already_swapped: Set[str] = set()
+
+    # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
+    # reset inputs/outputs counters
+    inputs_seen_counter = 0
+    outputs_seen_counter = 0
+    results_node = None
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
+
+    # TODO: change this to insert obs/fq by pattern instead of by node
+    for node in nodes_before_observation:
+
+        if node.op == 'placeholder':
+            # if a graph input is in fp32, it does not need observation
+            # if a graph input is in int8, we assume the observation happens
+            #   outside of the graph, and no additional observation is needed
+            pass
+
+        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
+            # check for matches
+            last_node, matched_node_pattern, pattern, qhandler, qconfig = (
+                node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None))  # type: ignore[assignment]
+            )
+            equalization_qconfig = equalization_config_map.get(node.name, None)
+
+            this_node_dtype_info = node.meta["target_dtype_info"]
+            if "val" in node.meta:
+                output_is_a_tensor = (
+                    this_node_dtype_info is not None and
+                    isinstance(node.meta["val"], FakeTensor)
+                )
+            else:
+                output_is_a_tensor = this_node_dtype_info is not None
+
+            skip_inserting_observers = (
+                (qconfig is None) or
+                not output_is_a_tensor
+            ) and (
+                not node.op == 'output'
+            )
+
+            # TODO: take a closer look to see if we can remove this check
+            # right now it is here because of `observed_node_names`, we are using
+            # it as an indicator for swapping the modules to reference modules in
+            # convert
+            is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+                pattern, matched_node_pattern, qconfig, backend_config)
+
+            if not skip_inserting_observers and is_supported_by_backend:
+                named_modules = dict(model.named_modules(remove_duplicate=False))
+                if node.op != 'output':
+                    assert matched_node_pattern is not None
+                    # add matched nodes to the observed node name set
+                    _add_matched_node_name_to_set(matched_node_pattern, observed_node_names)
+
+                    # This is currently only used for equalization.
+                    # Checks if the current node is in a branch in which the
+                    # first two layers are both being quantized.
+                    #
+                    # ex.       conv2
+                    #         /
+                    #      x -> conv1
+                    #
+                    # If this is the case, we will not apply equalization to the
+                    # initial two layers.
+                    is_quantized_branch = False
+                    if (
+                        len(node.args) > 0 and
+                        isinstance(node.args[0], Node) and
+                        len(node.args[0].users) > 1
+                    ):
+                        for user in node.args[0].users:
+                            # Checks if there exists another user being quantized
+                            is_user_quantized = (
+                                node_name_to_qconfig.get(user.name, None) is not None or
+                                (user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase))
+                            )
+                            if user != node and is_user_quantized:
+                                is_quantized_branch = True
+
+                    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
+                    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
+                    root_node = root_node_getter(matched_node_pattern)
+                    is_input_node_of_the_pattern = node is root_node
+                    if is_input_node_of_the_pattern:
+                        # this modifies node inplace
+                        _maybe_insert_input_observers_for_node(
+                            node, qconfig, model, named_modules, model.graph,
+                            qhandler,
+                            prepare_custom_config,
+                            obs_or_fq_map,
+                            is_qat,
+                            backend_config)
+
+                        # insert equalization input observers if needed
+                        _maybe_insert_input_equalization_observers_for_node(
+                            node, equalization_qconfig, model, named_modules, model.graph,
+                            is_quantized_branch)
+
+                    is_last_node_of_pattern = node is last_node
+                    input_output_share_observers = node.meta["target_dtype_info"].get("input_output_share_observers", False)
+                    reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
+
+                    if is_last_node_of_pattern:
+                        if _is_custom_module_lstm(node, named_modules, qconfig, qhandler):
+                            # Currently custom module outputs are assumed to be already quantized,
+                            # so we need to insert a DeQuantStub after the output. For custom module
+                            # LSTM specifically, the outputs are also a nested tuple, so we must first
+                            # break down the tuple to insert DeQuantStubs after the internal nodes.
+
+                            # TODO: This currently diverges from how custom modules are handled today,
+                            # where we insert observers after the output instead of DeQuantStubs, and
+                            # replace these observers with "dequantize" nodes during convert. Conceptually,
+                            # these output observers are the same as DeQuantStubs. In the future, we
+                            # should resolve this inconsistency by inserting DeQuantStubs for all custom
+                            # modules, not just for LSTM.
+                            _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
+                            if node.target not in custom_module_names_already_swapped:
+                                custom_module_names_already_swapped.add(node.target)
+                                _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
+                        else:
+                            # this returns the new observer node if it was needed
+                            maybe_output_obs_node = _maybe_insert_output_observer_for_node(
+                                node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
+
+                            if maybe_output_obs_node is not None:
+                                # Update users of original node to use the output observer
+                                # instead. For example, change
+                                #
+                                #           next_node
+                                #          /
+                                #   cur_node -> obs
+                                #
+                                # to
+                                #
+                                #                 next_node
+                                #                 /
+                                #   cur_node -> obs
+                                #
+                                # We need to save orig users before updating uses because
+                                # the list of users will change as we update uses
+                                orig_users = list(node.users.keys())
+                                for user_node in orig_users:
+                                    if user_node is maybe_output_obs_node:
+                                        continue
+                                    user_node.replace_input_with(node, maybe_output_obs_node)
+
+                                _is_observer_in_same_graph_ = _is_observer_in_same_graph(
+                                    node, named_modules, obs_or_fq_map, is_qat)
+
+                                # for ops whose inputs and outputs share observer/fqs, we modify the graph
+                                # to make all inputs and outputs use the first input's
+                                # observer/fq
+                                if (input_output_share_observers and _is_observer_in_same_graph_) or \
+                                        reuse_input_obs_or_fq:
+                                    if not _maybe_make_input_output_share_observers(node, model, named_modules):
+                                        _remove_output_observer(node, model, named_modules)
+
+                                if qhandler is not None and qhandler.is_custom_module():
+                                    if node.target not in custom_module_names_already_swapped:
+                                        custom_module_names_already_swapped.add(node.target)
+                                        _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
+
+                else:  # output
+                    _maybe_insert_observers_before_graph_output(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
+
+        #
+        # After this point, the current node has input and output observers
+        # that it needs for itself inserted.
+        #
+
+        # increment the counters, so future inputs and outputs are assigned
+        # correct dtypes
+        if node.op == 'placeholder':
+            inputs_seen_counter += 1
+        elif node.op == 'output':
+            outputs_seen_counter += 1
+            results_node = node
+
+    return results_node
+
+def _run_prepare_fx_on_standalone_modules(
+    model: torch.nn.Module,
+    is_qat: bool,
+    named_modules: Dict[str, torch.nn.Module],
+    node_name_to_match_result_with_qconfig: Any,
+    prepare_custom_config: PrepareCustomConfig,
+    backend_config: BackendConfig,
+) -> None:
+    """
+    Runs prepare_fx on each standalone module. Note: this does
+    not modify the graph, it just replaces the unobserved modules with
+    their observed versions.
+    """
+    for (root_node, _, pattern, qhandler, qconfig) in node_name_to_match_result_with_qconfig.values():
+        if qhandler is None:
+            continue
+        elif not qhandler.is_standalone_module():
+            continue
+
+        sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config, \
+            sm_backend_config = _get_standalone_module_configs(
+                root_node, named_modules, prepare_custom_config, qconfig, backend_config)
+
+        standalone_module = named_modules[root_node.target]
+        prepare = \
+            torch.ao.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
+        observed_standalone_module = \
+            prepare(
+                standalone_module,
+                sm_qconfig_mapping,
+                is_qat,
+                example_inputs=sm_example_inputs,
+                prepare_custom_config=sm_prepare_custom_config,
+                backend_config=sm_backend_config)
+        parent_name, name = _parent_name(root_node.target)
+        setattr(named_modules[parent_name], name, observed_standalone_module)
+        named_modules[root_node.target] = observed_standalone_module
+
+def _save_state(
+    observed: GraphModule,
+    node_name_to_qconfig: Dict[str, QConfigAny],
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    prepare_custom_config: PrepareCustomConfig,
+    equalization_node_name_to_qconfig: Dict[str, Any],
+    qconfig_mapping: QConfigMapping,
+    is_qat: bool,
+    observed_node_names: Set[str],
+) -> None:
+    observed.meta["_observed_graph_module_attrs"] = (
+        ObservedGraphModuleAttrs(
+            node_name_to_qconfig=node_name_to_qconfig,
+            node_name_to_scope=node_name_to_scope,
+            prepare_custom_config=prepare_custom_config,
+            equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
+            qconfig_mapping=qconfig_mapping,
+            is_qat=is_qat,
+            observed_node_names=observed_node_names,
+        )
+    )
+
+def prepare(
+        model: GraphModule,
+        qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
+        is_qat: bool,
+        node_name_to_scope: Dict[str, Tuple[str, type]],
+        example_inputs: Tuple[Any, ...],
+        prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
+        _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
+        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+        is_standalone_module: bool = False) -> GraphModule:
+    """ standalone_module means it a submodule that is not inlined in
+    parent module, and will be quantized separately as one unit.
+
+    How the standalone module is observed is specified by `input_quantized_idxs` and
+    `output_quantized_idxs` in the prepare_custom_config for the standalone module
+    Args:
+        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
+        The scope is a tuple of fully qualified path of the module and the type of the module
+    Returns:
+        model(GraphModule): prepared standalone module
+        attributes related to standalone module
+        in model.meta["_observed_graph_module_attrs"]:
+            is_observed_standalone_module (bool): boolean value that shows whether the
+            current model is an observed standalone module or not
+            standalone_module_input_quantized_idxs(List[Int]): a list of
+                indexes for the graph input that is expected to be quantized,
+                same as input_quantized_idxs configuration provided
+                for the standalone module
+            standalone_module_output_quantized_idxs(List[Int]): a list of
+                indexes for the graph output that is quantized,
+                same as the output_quantized_idxs configuration provided
+                for the standalone module
+    """
+    if prepare_custom_config is None:
+        prepare_custom_config = PrepareCustomConfig()
+    if _equalization_config is None:
+        _equalization_config = QConfigMapping()
+
+    if isinstance(qconfig_mapping, Dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a QConfigMapping instead.")
+        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
+
+    if isinstance(_equalization_config, Dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
+            "be supported in a future version. Please pass in a QConfigMapping instead.")
+        _equalization_config = QConfigMapping.from_dict(_equalization_config)
+
+    if isinstance(prepare_custom_config, Dict):
+        warnings.warn(
+            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a PrepareCustomConfig instead.")
+        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
+
+    if isinstance(backend_config, Dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.")
+        backend_config = BackendConfig.from_dict(backend_config)
+
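+    # Illustrative non-deprecated equivalents of the dict arguments handled above:
+    #   qconfig_mapping = QConfigMapping().set_global(
+    #       torch.ao.quantization.get_default_qconfig("fbgemm"))
+    #   prepare_custom_config = PrepareCustomConfig()
+    #   backend_config = get_native_backend_config()
+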
+    assert isinstance(qconfig_mapping, QConfigMapping)
+    assert isinstance(_equalization_config, QConfigMapping)
+    qconfig_mapping = copy.deepcopy(qconfig_mapping)
+    _equalization_config = copy.deepcopy(_equalization_config)
+
+    # mapping from a tuple of nodes in reverse order to uninitialized
+    #   QuantizeHandler subclass. For example,
+    # {
+    #   # match a single node
+    #   (<pattern for a single node, e.g. torch.nn.Conv2d>:
+    #     <uninitialized QuantizeHandler subclass>),
+    #   # match multiple nodes in reverse order
+    #   ((<pattern for the last node, e.g. torch.nn.ReLU>, <pattern for the preceding node, e.g. torch.nn.Conv2d>):
+    #     <uninitialized QuantizeHandler subclass>),
+    # }
+
+    pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
+    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)
+
+    root_node_getter_mapping = \
+        get_fusion_pattern_to_root_node_getter(backend_config)
+
+    _update_qconfig_for_fusion(model, qconfig_mapping)
+    _update_qconfig_for_fusion(model, _equalization_config)
+    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
+    # TODO: support regex as well
+    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
+
+    if is_qat:
+        module_to_qat_module = get_module_to_qat_module(backend_config)
+        _qat_swap_modules(model, module_to_qat_module)
+        _update_qconfig_for_qat(qconfig_mapping, backend_config)
+
+    # mapping from fully qualified module name to module instance
+    # for example,
+    # {
+    #   '': Model(...),
+    #   'linear': Linear(...),
+    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
+    # }
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+
+    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
+    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
+        model, named_modules, model.graph, _equalization_config, node_name_to_scope)
+    node_name_to_qconfig = _generate_node_name_to_qconfig(model, named_modules, model.graph, qconfig_mapping, node_name_to_scope)
+
+    # match the patterns that will get quantized
+    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
+    standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
+
+    custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
+    matches_without_qconfig = _find_matches(
+        model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping,
+        standalone_module_names, standalone_module_classes, custom_module_classes)
+
+    # map qconfig instances to matches
+    node_name_to_match_result_with_qconfig = {}
+    for node_name, match_without_qconfig in matches_without_qconfig.items():
+        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
+        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig
+
+    _run_prepare_fx_on_standalone_modules(
+        model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config)
+
+    # record names for the set of observed nodes, so that in the convert step
+    # we know whether we need to convert a floating point module to a reference
+    # quantized module or not
+    observed_node_names: Set[str] = set()
+
+    result_node = insert_observers_for_model(
+        model,
+        node_name_to_match_result_with_qconfig,
+        node_name_to_qconfig,
+        prepare_custom_config,
+        equalization_node_name_to_qconfig,
+        backend_config,
+        observed_node_names,
+        is_qat,
+    )
+    model = GraphModule(model, model.graph)
+
+    _save_state(model, node_name_to_qconfig, node_name_to_scope,
+                prepare_custom_config, equalization_node_name_to_qconfig,
+                qconfig_mapping, is_qat, observed_node_names)
+
+    if is_standalone_module:
+        assert result_node is not None
+        assert isinstance(result_node.args[0], Node), \
+            "standalone module only supports returning simple value currently"\
+            "(not tuple, dict etc.)"
+        # these inputs are observed in parent
+        # converting List[int] to Tensor since module attribute is
+        # Union[Tensor, Module]
+        input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
+        output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
+        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
+        # inplace modification
+        observed_graph_module_attrs.is_observed_standalone_module = True
+        observed_graph_module_attrs.standalone_module_input_quantized_idxs = \
+            input_quantized_idxs
+        observed_graph_module_attrs.standalone_module_output_quantized_idxs = \
+            output_quantized_idxs
+    return model
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc428539d8efb0650101339def74b019520890f6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -0,0 +1,343 @@
+import torch
+import re
+from collections import defaultdict, OrderedDict
+from typing import Callable, Any, Dict, Tuple, Set, List, Union
+from torch.ao.quantization import QConfig
+from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
+from torch.ao.quantization.observer import (
+    _is_activation_post_process,
+)
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_module_to_qat_module,
+)
+
+from torch.fx import (
+    GraphModule,
+)
+from torch.fx.graph import (
+    Graph,
+)
+from torch.ao.nn.intrinsic import _FusedModule
+
+from ..utils import (
+    _parent_name,
+    get_qconfig_dtypes,
+)
+from ..qconfig_mapping import (
+    _OBJECT_TYPE_DICT_KEY,
+    _MODULE_NAME_DICT_KEY,
+    _MODULE_NAME_REGEX_DICT_KEY,
+    QConfigMapping,
+)
+
+__all__: List[str] = []
+
+
+
+def _maybe_adjust_qconfig_for_module_name_object_type_order(
+    qconfig_mapping: QConfigMapping,
+    cur_module_path: str,
+    cur_object_type: Callable,
+    cur_object_type_idx: int,
+    fallback_qconfig: QConfigAny,
+) -> QConfigAny:
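+    # Illustrative example of the configuration this function matches against,
+    # using the public QConfigMapping API (module path and qconfig are placeholders):
+    #   QConfigMapping().set_module_name_object_type_order(
+    #       "blocks.0", torch.nn.functional.linear, 1, custom_qconfig)
+    # gives the second F.linear call inside `blocks.0` its own qconfig.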
+    for (module_name, object_type, index), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
+        if (
+            (module_name == cur_module_path) and
+            (object_type == cur_object_type) and
+            (index == cur_object_type_idx)
+        ):
+            return qconfig
+    return fallback_qconfig
+
+
+def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
+    """
+    Update the QConfigMapping to account for fused modules such as LinearReLU.
+    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
+    """
+    object_type_dict = qconfig_mapping.object_type_qconfigs
+    if len(object_type_dict) == 0:
+        return qconfig_mapping
+
+    modules = dict(model.named_modules())
+
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and node.target in modules:
+            maybe_fused_module = modules[str(node.target)]
+            if not isinstance(maybe_fused_module, _FusedModule):
+                continue
+
+            ops = list(maybe_fused_module._modules.values())
+            fused_qconfig = object_type_dict.get(type(ops[0]), None)
+
+            # Raise an error if the modules in the fused module have
+            # different qconfigs specified in the qconfig_dict
+            # TODO: currently it only works for modules,
+            # need to make this work for torch.nn.functional.relu
+            # TODO: currently it only works for object_type configurations,
+            # ideally it should work for different types of configurations,
+            # maybe we want to redesign this part
+            for op in ops[1:]:
+                if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
+                    raise LookupError(
+                        "During fusion, we need to specify the same " +
+                        f"qconfigs for all module types in {type(maybe_fused_module)}; " +
+                        f"offending type: {type(op)}")
+
+            if fused_qconfig is not None:
+                object_type_dict[type(maybe_fused_module)] = fused_qconfig
+
+def _generate_node_name_to_qconfig(
+        root: torch.nn.Module,
+        modules: Dict[str, torch.nn.Module],
+        input_graph: Graph,
+        qconfig_mapping: QConfigMapping,
+        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
+    global_qconfig = qconfig_mapping.global_qconfig
+    node_name_to_qconfig = {}
+
+    # example:
+    #
+    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
+    #
+    # meaning in submodule 'foo.bar', we have seen 0 F.linear and
+    # 1 F.conv2d invocations so far.
+    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \
+        defaultdict(lambda: defaultdict(int))
+    for node in input_graph.nodes:
+        qconfig = None
+        if node.op == "get_attr":
+            module_name, _ = _parent_name(node.target)
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
+        elif node.op == "call_function":
+            # precedence: module_name_qconfig
+            # > function_qconfig > global_qconfig
+            # module_name takes precedence over function qconfig
+            function_qconfig = _get_object_type_qconfig(
+                qconfig_mapping, node.target, global_qconfig)
+            module_path, module_type = node_name_to_scope[node.name]
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, module_type, module_path, function_qconfig)
+
+            cur_object_type_idx = \
+                submodule_to_object_type_to_cur_idx[module_path][node.target]
+            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
+                qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig)
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
+
+        elif node.op == "call_method":
+            module_path, module_type = node_name_to_scope[node.name]
+            # first use node.target (string) to get the qconfig
+            # this is to support configs like
+            # "object_type": [("reshape", qconfig)]
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, node.target, module_path, global_qconfig)
+            # if there is no special config for the method, we'll fall back to the
+            # config for the module that contains the call_method node
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, module_type, module_path, qconfig)
+            # currently call_method does not support modifying qconfig
+            # by order, we can add this later if it is needed.
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
+
+        elif node.op == 'call_module':
+            # if the node is an observer, just continue - don't add it to the qconfig_map
+            if _is_activation_post_process(modules[node.target]):
+                continue
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
+
+            module_path, module_type = node_name_to_scope[node.name]
+            # Note: for call_module, the module_path is the current module's name.
+            # To meaningfully count invocations, we need to count them in the parent
+            # module.
+            parent_name, _ = _parent_name(module_path)
+            cur_object_type_idx = \
+                submodule_to_object_type_to_cur_idx[parent_name][module_type]
+            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
+                qconfig_mapping, parent_name, module_type, cur_object_type_idx,
+                qconfig)
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
+
+            # regex is not supported in eager mode propagate_qconfig_, so we
+            # need to set the qconfig explicitly here in case a regex
+            # is used
+            modules[node.target].qconfig = qconfig_with_device_check
+        else:
+            qconfig_with_device_check = None
+
+        node_name_to_qconfig[node.name] = qconfig_with_device_check
+    return node_name_to_qconfig
+
+
+def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
+    r""" Checks that the given config_dict only contains keys from `allowed_keys`
+
+    Args:
+      `config_dict`: dictionary whose keys we want to check
+      `allowed_keys`: set of keys that are allowed in `config_dict`
+      `dict_name`: name of the dictionary, used in the error message
+    """
+
+    for k in config_dict.keys():
+        if k not in allowed_keys:
+            raise ValueError(
+                'Expected ' + dict_name + ' to have the following keys: ' +
+                str(allowed_keys) + '. But found \'' + k +
+                '\' instead.')
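+
+# Hedged usage sketch of the check above (the key names below are made up for
+# illustration only): any key outside of `allowed_keys` raises a ValueError
+# naming the offending key.
+# >> _check_is_valid_config_dict({"": None, "bad_key": None}, {""}, "qconfig_dict")
+# ValueError: Expected qconfig_dict to have the following keys: ...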
+
+
+def _compare_prepare_convert_qconfig_mappings(
+        prepare_qconfig_mapping: QConfigMapping,
+        convert_qconfig_mapping: QConfigMapping):
+    r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values
+
+    Args:
+      `prepare_qconfig_mapping`: configuration for prepare quantization step
+      `convert_qconfig_mapping`: configuration for convert quantization step
+    """
+    assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \
+        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
+    prepare_dicts: List[OrderedDict] = [
+        prepare_qconfig_mapping.object_type_qconfigs,
+        prepare_qconfig_mapping.module_name_qconfigs,
+        prepare_qconfig_mapping.module_name_regex_qconfigs,
+    ]
+    convert_dicts: List[OrderedDict] = [
+        convert_qconfig_mapping.object_type_qconfigs,
+        convert_qconfig_mapping.module_name_qconfigs,
+        convert_qconfig_mapping.module_name_regex_qconfigs,
+    ]
+    dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY]
+    for i in range(len(prepare_dicts)):
+        for name in prepare_dicts[i].keys():
+            assert name in convert_dicts[i], f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
+                when it was present in prepare"
+            assert convert_dicts[i][name] is None \
+                or qconfig_equals(prepare_dicts[i][name], convert_dicts[i][name]), \
+                f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
+                prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
+
+def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
+    for dtype_config in dtype_configs:
+        is_dynamic = dtype_config.is_dynamic
+        if is_dynamic is None:
+            is_dynamic = False
+        input_dtype = dtype_config.input_dtype or torch.float
+        weight_dtype = dtype_config.weight_dtype or torch.float
+        bias_dtype = dtype_config.bias_dtype or torch.float
+        output_dtype = dtype_config.output_dtype or torch.float
+        qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \
+            get_qconfig_dtypes(qconfig)
+        qconfig_bias_dtype = torch.float16 \
+            if (
+                qconfig_activation_dtype == torch.float16
+                and qconfig_weight_dtype == torch.float16
+                and not is_dynamic
+            ) else torch.float
+
+        if is_dynamic:
+            is_match = qconfig_input_act_is_dynamic and \
+                input_dtype == qconfig_activation_dtype and \
+                output_dtype == torch.float and \
+                weight_dtype == qconfig_weight_dtype
+        else:
+            is_match = input_dtype == qconfig_activation_dtype and \
+                output_dtype == qconfig_activation_dtype and \
+                weight_dtype == qconfig_weight_dtype and \
+                bias_dtype == qconfig_bias_dtype
+        if is_match:
+            return True
+    return False
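+
+# A minimal sketch of how the check above can be exercised, assuming the default
+# fbgemm qconfig (static quint8 activation / qint8 weight); the DTypeConfig below
+# is an illustrative backend entry, not one taken from this file:
+# >> dtype_config = DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8,
+# ..                            weight_dtype=torch.qint8, bias_dtype=torch.float)
+# >> qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+# >> _is_qconfig_supported_by_dtype_configs(qconfig, [dtype_config])  # expected True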
+
+def _get_object_type_qconfig(
+        qconfig_mapping: QConfigMapping,
+        object_type: Union[Callable, str],
+        fallback_qconfig: QConfigAny) -> QConfigAny:
+    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
+
+
+def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
+        if re.match(regex_pattern, module_name):
+            # first match wins
+            return qconfig
+    return fallback_qconfig
+
+
+def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+    if module_name == '':
+        # module name qconfig not found
+        return fallback_qconfig
+    if module_name in qconfig_mapping.module_name_qconfigs:
+        return qconfig_mapping.module_name_qconfigs[module_name]
+    else:
+        parent, _ = _parent_name(module_name)
+        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
+
+
+def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
+    # get qconfig for module_name,
+    # fallback to module_name_regex_qconfig, module_type_qconfig,
+    # global_qconfig if necessary
+    module_type_qconfig = _get_object_type_qconfig(
+        qconfig_mapping, module_type, global_qconfig)
+    module_name_regex_qconfig = _get_module_name_regex_qconfig(
+        qconfig_mapping, module_name, module_type_qconfig)
+    module_name_qconfig = _get_module_name_qconfig(
+        qconfig_mapping, module_name, module_name_regex_qconfig)
+    return module_name_qconfig
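+
+# Hedged illustration of the precedence implemented above
+# (module_name > module_name_regex > object_type > global); the module names and
+# qconfig values are hypothetical:
+# >> qconfig_mapping = (QConfigMapping()
+# ..     .set_global(global_qconfig)
+# ..     .set_object_type(torch.nn.Linear, linear_qconfig)
+# ..     .set_module_name("sub.fc", fc_qconfig))
+# >> _maybe_adjust_qconfig_for_module_type_or_name(
+# ..     qconfig_mapping, torch.nn.Linear, "sub.fc", global_qconfig)    # -> fc_qconfig
+# >> _maybe_adjust_qconfig_for_module_type_or_name(
+# ..     qconfig_mapping, torch.nn.Linear, "other.fc", global_qconfig)  # -> linear_qconfig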
+
+
+def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
+    """ flatten the global, object_type and module_name qconfig
+    to the same qconfig_dict so that it can be used by
+    propagate_qconfig_ function.
+    "module_name_regex" is ignored for now since it's not supported
+    in propagate_qconfig_, but it can be fixed later.
+
+    For example:
+    Input: {
+      "": qconfig,
+      "object_type": [
+        (torch.add, qconfig)
+      ],
+      "module_name": [
+        ("conv", qconfig)
+      ]
+    }
+
+    Output: {
+      "": qconfig,
+      torch.add: qconfig,
+      "conv": qconfig
+    }
+    """
+    flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig}
+    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
+        flattened[obj] = qconfig
+    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
+        flattened[obj] = qconfig
+    return flattened
+
+
+def _update_qconfig_for_qat(
+        qconfig_mapping: QConfigMapping,
+        backend_config: BackendConfig):
+    """
+    Update the qconfig_mapping to account for module swaps during QAT.
+    During QAT, the nn.Module types are swapped for their corresponding nn.qat.modules types.
+    """
+    module_to_qat_module_class = get_module_to_qat_module(backend_config)
+    object_type_dict = qconfig_mapping.object_type_qconfigs
+    new_object_type_dict = object_type_dict.copy()
+    for k, v in new_object_type_dict.items():
+        if k in module_to_qat_module_class:
+            object_type_dict[module_to_qat_module_class[k]] = v
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bda8e0210590306933d20680530ec9911c0537b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/quantize_handler.py
@@ -0,0 +1,197 @@
+from abc import ABC
+from typing import Callable, Dict, List, Optional, Type
+
+import torch
+
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+)
+from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
+from torch.fx.graph import Node
+
+from .utils import all_node_args_have_no_tensors
+
+
+__all__ = [
+    "QuantizeHandler",
+    "BinaryOpQuantizeHandler",
+    "CatQuantizeHandler",
+    "ConvReluQuantizeHandler",
+    "LinearReLUQuantizeHandler",
+    "BatchNormQuantizeHandler",
+    "EmbeddingQuantizeHandler",
+    "RNNDynamicQuantizeHandler",
+    "DefaultNodeQuantizeHandler",
+    "FixedQParamsOpQuantizeHandler",
+    "CopyNodeQuantizeHandler",
+    "GeneralTensorShapeOpQuantizeHandler",
+    "CustomModuleQuantizeHandler",
+    "StandaloneModuleQuantizeHandler",
+]
+
+def _default_root_node_getter(node_pattern):
+    if node_pattern is None:
+        return node_pattern
+    while not isinstance(node_pattern, Node):
+        node_pattern = node_pattern[-1]
+    return node_pattern
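+
+# e.g. for a nested node pattern such as (relu_node, (bn_node, conv_node)),
+# the getter above repeatedly takes the last element and returns conv_node
+# as the root node.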
+
+# Base Pattern Handler
+class QuantizeHandler(ABC):  # noqa: B024
+    """ Base handler class for the quantizer patterns
+    """
+    def __init__(
+            self,
+            node_pattern: NodePattern,
+            modules: Dict[str, torch.nn.Module],
+            root_node_getter: Optional[Callable] = None,
+            is_custom_module=False,
+            is_standalone_module=False):
+        """ Records pattern information in __init__, which will be used
+        in convert
+        """
+        self.node_pattern = node_pattern
+        self.modules = modules
+        if root_node_getter is None:
+            root_node_getter = _default_root_node_getter
+        self.root_node = root_node_getter(node_pattern)
+        self.is_custom_module_ = is_custom_module
+        self.is_standalone_module_ = is_standalone_module
+        self.num_tensor_args = 0
+        # determine how many of the first two args are Tensors (versus scalars)
+        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
+        if isinstance(self.root_node, Node):
+            cache_for_no_tensor_check: Dict[Node, bool] = {}
+            for arg_idx in range(len(self.root_node.args)):
+                arg = self.root_node.args[arg_idx]
+                if isinstance(arg, Node) and (
+                        not all_node_args_have_no_tensors(
+                            arg, self.modules, cache_for_no_tensor_check)):
+                    self.num_tensor_args += 1
+
+    def is_general_tensor_value_op(self) -> bool:
+        """
+        Returns True if the operator works for both floating point and
+        quantized input, and either does some computation based on the input
+        Tensor or only re-arranges the Tensor values / queries some metadata
+        about the Tensor.
+
+        For such ops we need to insert an observer/fake_quant for the output of
+        the operator (the same observer instance as the input), since the
+        distribution of values can differ between the input and output Tensors
+        (relevant for HistogramObserver) while they share the same quantization
+        parameters.
+
+        Example operators: avgpool2d, reshape, transpose, maxpool2d
+        Example observed operator:
+        observer_0 - avgpool2d - observer_0 (same observer instance as input)
+        """
+        return False
+
+    def is_custom_module(self):
+        return self.is_custom_module_
+
+    def is_standalone_module(self):
+        return self.is_standalone_module_
+
+def _get_quantize_handler_cls(
+        observation_type: ObservationType,
+        dtype_configs: List[DTypeConfig],
+        num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> Type[QuantizeHandler]:
+    """
+    Return a configurable QuantizeHandler that matches the given specifications from the backend.
+    """
+
+    class ConfigurableQuantizeHandler(QuantizeHandler):
+        def __init__(
+                self,
+                node_pattern: NodePattern,
+                modules: Dict[str, torch.nn.Module],
+                root_node_getter: Optional[Callable] = None):
+            super().__init__(node_pattern, modules, root_node_getter)
+            if num_tensor_args_to_observation_type:
+                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
+                    f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
+                    f" in num_tensor_args_to_observation_type for {node_pattern}"
+                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
+            else:
+                self.observation_type = observation_type
+            self.dtype_configs = dtype_configs
+
+        def is_general_tensor_value_op(self) -> bool:
+            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+
+    return ConfigurableQuantizeHandler
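+
+# Hedged sketch of using the factory above; the observation type, the empty
+# dtype_configs list, and the `node_pattern`/`named_modules` variables are
+# assumptions for illustration only:
+# >> handler_cls = _get_quantize_handler_cls(
+# ..     ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, [], {})
+# >> handler = handler_cls(node_pattern, named_modules)
+# >> handler.is_general_tensor_value_op()  # False for this observation type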
+
+def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
+    """
+    Note: QuantizeHandler is just a holder for some check methods like
+    `should_insert_observer_for_output`; maybe this can be an enum as well.
+    We can refactor this after the fbgemm/qnnpack path is fully converted to the
+    new path. This is not exposed to backend developers.
+    """
+    pattern_to_quantize_handlers = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        observation_type = config.observation_type
+        dtype_configs = config.dtype_configs
+        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
+        pattern_to_quantize_handlers[pattern] = \
+            _get_quantize_handler_cls(
+                observation_type,
+                dtype_configs,
+                num_tensor_args_to_observation_type)
+    return pattern_to_quantize_handlers
+
+# TODO: remove this class, this is still exposed in torch.ao.quantization
+# but we should be able to break bc
+class BinaryOpQuantizeHandler(QuantizeHandler):
+    pass
+
+class CatQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class ConvReluQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class LinearReLUQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class BatchNormQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class EmbeddingQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class RNNDynamicQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove this class
+class DefaultNodeQuantizeHandler(QuantizeHandler):
+    """ Common quantized op, first input and first output will be quantized
+    """
+    pass
+
+# TODO: remove this class
+class FixedQParamsOpQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove
+class CopyNodeQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: remove
+class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
+class CustomModuleQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
+class StandaloneModuleQuantizeHandler(QuantizeHandler):
+    pass
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/tracer.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..914779749d4b5af6426e0890309908e7141ba050
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/tracer.py
@@ -0,0 +1,45 @@
+import torch
+from torch.fx._symbolic_trace import Tracer
+from torch.fx.proxy import Scope
+from torch.ao.nn.intrinsic import _FusedModule
+from typing import List, Callable
+
+__all__ = [
+    "QuantizationTracer",
+]
+
+class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
+    def __init__(
+        self,
+        scope: Scope,
+        current_module: torch.nn.Module,
+        current_module_path: str
+    ):
+        super().__init__(scope, Scope(current_module_path, type(current_module)))
+
+
+class QuantizationTracer(Tracer):
+    def __init__(
+        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
+    ):
+        super().__init__()
+        self.skipped_module_names = skipped_module_names
+        self.skipped_module_classes = skipped_module_classes
+        # NB: initialize the module_type of the top level module to None.
+        # We are assuming people won't configure the model with the type of the
+        # top level module here, since people can use "" for a global config.
+        # We can change this if there is a use case that configures
+        # qconfig using the top level module type.
+        self.scope = Scope("", None)
+        self.record_stack_traces = True
+
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        return (
+            (
+                (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
+                and not isinstance(m, torch.nn.Sequential)
+            )
+            or module_qualified_name in self.skipped_module_names
+            or type(m) in self.skipped_module_classes
+            or isinstance(m, _FusedModule)
+        )
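+
+# e.g. with the leaf check above, a stock torch.nn.Linear submodule is treated as
+# a leaf (not traced into), while a user-defined nn.Module subclass is traced
+# through, unless its name or type was explicitly skipped or it is a fused module.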
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/fx/utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/fx/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87a8d7b6a933bdd0dc3eb805ec477b59980c5a9a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/fx/utils.py
@@ -0,0 +1,885 @@
+import copy
+import torch
+import torch.nn as nn
+from torch.ao.quantization import (
+    QConfigAny,
+    QuantType,
+)
+from torch.ao.quantization.backend_config import (
+    DTypeWithConstraints,
+)
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantizeBase,
+    FixedQParamsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    FixedQParamsObserver,
+    ObserverBase,
+)
+from torch.ao.quantization.qconfig import (
+    float16_static_qconfig,
+    float16_dynamic_qconfig,
+    qconfig_equals,
+)
+from torch.ao.quantization.stubs import DeQuantStub
+from torch.ao.quantization.utils import (
+    activation_is_statically_quantized,
+)
+from torch.ao.quantization.observer import _is_activation_post_process
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+
+from torch.fx import GraphModule, map_arg
+
+from torch.fx.graph import (
+    Graph,
+    Node,
+)
+from .custom_config import PrepareCustomConfig
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401
+
+from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
+from dataclasses import dataclass
+from collections import namedtuple
+import operator
+import warnings
+
+# TODO: revisit this list. Many helper methods shouldn't be public
+__all__ = [
+    "all_node_args_except_first",
+    "all_node_args_have_no_tensors",
+    "assert_and_get_unique_device",
+    "collect_producer_nodes",
+    "create_getattr_from_value",
+    "create_node_from_old_node_preserve_meta",
+    "EMPTY_ARG_DICT",
+    "get_custom_module_class_keys",
+    "get_linear_prepack_op_for_dtype",
+    "get_new_attr_name_with_prefix",
+    "get_non_observable_arg_indexes_and_types",
+    "get_qconv_prepack_op",
+    "get_skipped_module_name_and_classes",
+    "graph_module_from_producer_nodes",
+    "maybe_get_next_module",
+    "NodeInfo",
+    "node_arg_is_bias",
+    "node_arg_is_weight",
+    "NON_OBSERVABLE_ARG_DICT",
+    "NON_QUANTIZABLE_WEIGHT_OPS",
+    "return_arg_list",
+    "ObservedGraphModuleAttrs",
+]
+
+NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}
+
+@dataclass
+class ObservedGraphModuleAttrs:
+    node_name_to_qconfig: Dict[str, QConfigAny]
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+    prepare_custom_config: PrepareCustomConfig
+    equalization_node_name_to_qconfig: Dict[str, Any]
+    qconfig_mapping: QConfigMapping
+    is_qat: bool
+    observed_node_names: Set[str]
+    is_observed_standalone_module: bool = False
+    standalone_module_input_quantized_idxs: Optional[List[int]] = None
+    standalone_module_output_quantized_idxs: Optional[List[int]] = None
+
+def node_arg_is_weight(node: Node, arg: Any) -> bool:
+    """Returns if node arg is weight"""
+    weight_index = None
+    if "target_dtype_info" in node.meta:
+        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
+    if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
+        return True
+    return node.kwargs.get("weight") is arg
+
+def node_arg_is_bias(node: Node, arg: Any) -> bool:
+    """Returns if node arg is bias"""
+    bias_index = None
+    if "target_dtype_info" in node.meta:
+        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
+    if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
+        return True
+    return node.kwargs.get("bias") is arg
+
+def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
+    r""" Get all the unique custom module keys in the custom config dict
+    e.g.
+    Input:
+    {
+        QuantType.STATIC: {
+            CustomModule1: ObservedCustomModule
+        },
+        QuantType.DYNAMIC: {
+            CustomModule2: DynamicObservedCustomModule
+        },
+        QuantType.WEIGHT_ONLY: {
+            CustomModule3: WeightOnlyObservedCustomModule
+        },
+    }
+
+    Output:
+    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
+    [CustomModule1, CustomModule2, CustomModule3]
+    """
+    # using set to dedup
+    float_custom_module_classes : Set[Any] = set()
+    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
+        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
+        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
+        float_custom_module_classes |= quant_mode_custom_module_classes
+    return list(float_custom_module_classes)
+
+def get_linear_prepack_op_for_dtype(dtype):
+    if dtype == torch.float16:
+        return torch.ops.quantized.linear_prepack_fp16
+    elif dtype == torch.qint8:
+        return torch.ops.quantized.linear_prepack
+    else:
+        raise Exception("can't get linear prepack op for dtype:", dtype)
+
+def get_qconv_prepack_op(conv_op: Callable) -> Callable:
+    prepack_ops = {
+        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
+        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
+        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
+        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
+        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
+        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
+    }
+    prepack_op = prepack_ops.get(conv_op, None)
+    assert prepack_op, f"Didn't find prepack op for {conv_op}"
+    return prepack_op
+
+# Returns a function that can get a new attribute name for module with given
+# prefix, for example,
+# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
+# >> new_name = get_new_observer_name(module)
+# new_name will be an unused attribute name on module, e.g. `_observer_1`
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    prefix = prefix.replace(".", "_")
+
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+    return get_new_attr_name
+
+def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
+    r''' Starting from a target node, trace back until we hit an input or
+    getattr node. This is used to extract the chain of operators
+    starting from getattr to the target node, for example
+    def forward(self, x):
+      observed = self.observer(self.weight)
+      return F.linear(x, observed)
+    collect_producer_nodes(observed) will either return a list of nodes that
+    produce the observed node or None if we can't extract a self-contained
+    graph without free variables (inputs of the forward function).
+    '''
+    nodes = [node]
+    frontier = [node]
+    while frontier:
+        node = frontier.pop()
+        all_args = list(node.args) + list(node.kwargs.values())
+        for arg in all_args:
+            if not isinstance(arg, Node):
+                continue
+            if arg.op == 'placeholder':
+                # hit input, can't fold in this case
+                return None
+            nodes.append(arg)
+            if not (arg.op == 'call_function' and arg.target == getattr):
+                frontier.append(arg)
+    return nodes
+
+def graph_module_from_producer_nodes(
+        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
+    r''' Construct a graph module from extracted producer nodes
+    from `collect_producer_nodes` function
+    Args:
+      root: the root module for the original graph
+      producer_nodes: a list of nodes we use to construct the graph
+    Return:
+      A graph module constructed from the producer nodes
+    '''
+    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
+    # since we traced back from node to getattr
+    producer_nodes.reverse()
+    graph = Graph()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node])
+    for producer_node in producer_nodes:
+        env[producer_node] = graph.node_copy(producer_node, load_arg)
+    graph.output(load_arg(producer_nodes[-1]))
+    graph_module = GraphModule(root, graph)
+    return graph_module
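+
+# Hedged usage sketch combining the two helpers above; `weight_node` and
+# `root_module` are hypothetical names:
+# >> producer_nodes = collect_producer_nodes(weight_node)
+# >> if producer_nodes is not None:
+# ..     weight_gm = graph_module_from_producer_nodes(root_module, producer_nodes)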
+
+def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
+    """
+    Returns the unique device for a module, or None if no device is found.
+    Throws an error if multiple devices are detected.
+    """
+    devices = {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+    # As a temporary workaround for the AIMP HHC publish we added this CPU check.
+    # Remove it later. T163614564
+    if {torch.device("cpu"), torch.device("meta")} == devices:
+        warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We select 'cpu'.")
+        devices = {torch.device("cpu")}
+    assert len(devices) <= 1, (
+        "prepare only works with cpu or single-device CUDA modules, "
+        f"but got devices {devices}"
+    )
+    device = next(iter(devices)) if len(devices) > 0 else None
+    return device
+
+def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
+    """
+    Given a value of any type, creates a getattr node corresponding to the value and
+    registers the value as a buffer to the module.
+    """
+    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+    attr_name = get_new_attr_name(module)
+    device = assert_and_get_unique_device(module)
+    new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
+        else torch.tensor(value, device=device)
+    module.register_buffer(attr_name, new_value)
+    # Create get_attr with value
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
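+
+# e.g. create_getattr_from_value(model, model.graph, "_scale_", 0.1) registers a
+# fresh buffer (such as `_scale_0`) on `model` and returns the corresponding
+# get_attr node; `model` is assumed to be a GraphModule here.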
+
+def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
+    """
+    If we know for sure that all of this node's args have no
+    tensors (are primitives), return True.  If we either
+    find a tensor or are not sure, return False. Note: this
+    function is not exact.
+    """
+    if cache and node in cache:
+        return cache[node]
+
+    result = False  # will be overwritten
+    if not isinstance(node, Node):
+        result = True
+    elif node.op == 'placeholder':
+        result = False
+    elif node.op == 'call_module':
+        assert isinstance(node.target, str)
+        if _is_activation_post_process(modules[node.target]):
+            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
+        # for any other call_module node, `result` keeps its initial value of False
+    elif node.op == 'call_function' and node.target is operator.getitem:
+        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
+    elif node.op == 'get_attr':
+        result = False
+    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
+        # x1 = x0.ndim
+        result = True
+    elif node.op == 'call_method' and node.target == 'size':
+        # x1 = x0.size(0)
+        result = True
+    else:
+        found_one_tensor = False
+        for arg in node.args:
+            if isinstance(arg, list):
+                for list_el in arg:
+                    if isinstance(list_el, Node):
+                        this_list_el_args_have_no_tensors = \
+                            all_node_args_have_no_tensors(list_el, modules, cache)
+                        found_one_tensor = found_one_tensor or \
+                            (not this_list_el_args_have_no_tensors)
+                        # If found_one_tensor is True, there is no point in
+                        # recursing further as the end result will always
+                        # be True.
+                        # TODO(future PR): remove this entire function  and
+                        # change to dtype inference without recursion.
+                        if found_one_tensor:
+                            result = not found_one_tensor
+                            if cache:
+                                cache[node] = result
+                            return result
+            elif isinstance(arg, int):
+                pass
+            else:
+                if isinstance(arg, Node):
+                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
+                    found_one_tensor = found_one_tensor or \
+                        (not this_arg_args_have_no_tensors)
+                    # If found_one_tensor is True, there is no point in
+                    # recursing further as the end result will always
+                    # be True.
+                    # TODO(future PR): remove this entire function  and
+                    # change to dtype inference without recursion.
+                    if found_one_tensor:
+                        result = not found_one_tensor
+                        if cache:
+                            cache[node] = result
+                        return result
+                else:
+                    found_one_tensor = True
+            result = not found_one_tensor
+    if cache:
+        cache[node] = result
+    return result
+
+def all_node_args_except_first(node: Node) -> List[int]:
+    """
+    Returns all node arg indices after first
+    """
+    return list(range(1, len(node.args)))
+
+def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
+    """
+    Constructs a function that takes a node as arg and returns the arg_indices
+    that are valid for node.args
+    """
+    def arg_indices_func(node: Node) -> List[int]:
+        return [i for i in arg_indices if i < len(node.args)]
+    return arg_indices_func
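+
+# e.g. return_arg_list([1, 3])(node) keeps only the indices that actually exist
+# on the node, so for a node with two args it returns [1].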
+
+NodeInfo = namedtuple("NodeInfo", "op target")
+
+# this dict identifies which indices of a node are non tensors
+# so that they can be propagated correctly since inserting observers
+# for them would cause errors
+
+NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
+    NodeInfo("call_method", "masked_fill") : {
+        torch.bool: return_arg_list([1]),
+        float: return_arg_list([2])
+    },
+    NodeInfo("call_method", "permute") : {
+        int: all_node_args_except_first
+    },
+    NodeInfo("call_method", "repeat") : {
+        int: all_node_args_except_first
+    },
+    NodeInfo("call_method", "reshape") : {
+        int: all_node_args_except_first
+    },
+    NodeInfo("call_method", "size") : {
+        int: return_arg_list([1])
+    },
+    NodeInfo("call_method", "transpose") : {
+        int: all_node_args_except_first
+    },
+    NodeInfo("call_method", torch.transpose) : {
+        int: all_node_args_except_first
+    },
+    NodeInfo("call_method", "unsqueeze") : {
+        int: return_arg_list([1])
+    },
+    NodeInfo("call_method", "unsqueeze_") : {
+        int: return_arg_list([1])
+    },
+    NodeInfo("call_method", torch.unsqueeze) : {
+        int: return_arg_list([1])
+    },
+    NodeInfo("call_method", "view") : {
+        int: all_node_args_except_first
+    },
+}
+
+EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
+
+def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
+    """
+    Returns a dict with non-float tensor types as keys, where each value is a
+    function (taking the node as an argument) that retrieves the corresponding arg indices
+    """
+    info = NodeInfo(node.op, node.target)
+
+    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
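+
+# Hedged example: for a call_method node targeting "masked_fill", the dict above
+# maps torch.bool to arg index [1] and float to arg index [2], so
+# >> get_non_observable_arg_indexes_and_types(node)[torch.bool](node)  # -> [1] (if the arg exists)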
+
+def maybe_get_next_module(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    target_module_type: Optional[Type[nn.Module]] = None,
+    target_functional_type: Any = None,
+) -> Optional[Node]:
+    """ Gets the next module (among the users of `node`) that matches the given
+    target module type or functional type, if it exists
+
+    Args:
+        node: The node whose users we want to look at
+        modules: Mapping from module name to module, used to resolve call_module targets
+        target_module_type: Module type that we want to check
+        target_functional_type: Functional type that we want to check
+    """
+
+    for user in node.users.keys():
+        if user.op == 'call_module' and target_module_type is not None and \
+           isinstance(modules[str(user.target)], target_module_type):
+            return user
+        elif (user.op == 'call_function' and target_functional_type is not None and
+              user.target == target_functional_type):
+            return user
+
+    return None
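+
+# Hedged usage sketch (the `conv_node` and `named_modules` names are hypothetical):
+# >> bn_node = maybe_get_next_module(conv_node, named_modules,
+# ..                                 target_module_type=torch.nn.BatchNorm2d)
+# >> # returns the batchnorm user of `conv_node`, or None if there isn't one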
+
+def create_node_from_old_node_preserve_meta(
+    quantized_graph: Graph,
+    create_node_args: Tuple[Any, ...],
+    old_node: Node,
+) -> Node:
+    """
+    Creates `new_node` and copies the necessary metadata to it from `old_node`.
+    """
+    new_node = quantized_graph.create_node(*create_node_args)
+    new_node.stack_trace = old_node.stack_trace
+    return new_node
+
+def get_skipped_module_name_and_classes(
+        prepare_custom_config: PrepareCustomConfig,
+        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
+    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
+    skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
+    if not is_standalone_module:
+        # standalone module and custom module config are applied in top level module
+        skipped_module_names += list(prepare_custom_config.standalone_module_names.keys())
+        skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys())
+        skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
+
+    return skipped_module_names, skipped_module_classes
+
+def _is_custom_module_lstm(
+        node: Node,
+        named_modules: Dict[str, torch.nn.Module],
+        qconfig: QConfigAny = None,
+        # QuantizeHandler, but we cannot include the type here due to circular imports
+        qhandler: Optional[Any] = None,
+) -> bool:
+    """
+    Return whether this refers to the custom module LSTM flow.
+    """
+    mod = _get_module(node, named_modules)
+    if qconfig is not None and qhandler is not None:
+        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
+        return isinstance(mod, torch.nn.LSTM) and \
+            activation_is_statically_quantized(qconfig) and \
+            qhandler.is_custom_module()
+    else:
+        return isinstance(mod, torch.ao.nn.quantizable.LSTM)
+
+def _is_custom_module_mha(
+        node: Node,
+        named_modules: Dict[str, torch.nn.Module],
+        qconfig: QConfigAny = None,
+        # QuantizeHandler, but we cannot include the type here due to circular imports
+        qhandler: Optional[Any] = None,
+) -> bool:
+    """
+    Return whether this refers to the custom module MultiheadAttention flow.
+    """
+    mod = _get_module(node, named_modules)
+    if qconfig is not None and qhandler is not None:
+        assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
+        return isinstance(mod, torch.nn.MultiheadAttention) and \
+            activation_is_statically_quantized(qconfig) and \
+            qhandler.is_custom_module()
+    else:
+        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
+
+def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
+    """
+    If `node` refers to a call_module node, return the module, else None.
+    """
+    if node.op == "call_module" and str(node.target) in named_modules:
+        return named_modules[str(node.target)]
+    else:
+        return None
+
+def _insert_dequant_stub(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Attach a `DeQuantStub` to the model and create a node that calls this
+    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
+    """
+    prefix = "dequant_stub_"
+    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
+    dequant_stub_name = get_new_dequant_stub_name(model)
+    dequant_stub = DeQuantStub()
+    setattr(model, dequant_stub_name, dequant_stub)
+    named_modules[dequant_stub_name] = dequant_stub
+    with graph.inserting_after(node):
+        return graph.call_module(dequant_stub_name, (node,))
+
+def _insert_dequant_stubs_for_custom_module_lstm_output(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Insert DeQuantStubs after each internal output node of custom module LSTM.
+
+    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)),
+    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
+    components through `getitem`. This function transforms the graph as follows:
+
+      (1) Split the LSTM node into (output, (hidden0, hidden1))
+      (2) Insert a DeQuantStub after each internal node
+      (3) Recombine the DeQuantStubs into the same structure as before
+      (4) Reroute all consumers of the original LSTM node and its sub-nodes
+          (e.g. lstm[0])
+
+    Before:
+                   lstm_output
+                        |
+                        v
+                  original_user(s)
+    After:
+                   lstm_output
+                  /           \\
+                 /  (getitem)  \\
+                /               \\
+               v                 v
+             output            hidden
+               |               /   \\
+         (DeQuantStub)        (getitem)
+               |             /       \\
+               v            v         v
+           output_dq     hidden0    hidden1
+               |            |         |
+               |    (DeQuantStub) (DeQuantStub)
+               |            |         |
+               |            v         v
+               |      hidden0_dq  hidden1_dq
+               |            \\       /
+               |              (tuple)
+               |              \\   /
+               |               v  v
+               |             hidden_dq
+               \\               /
+                \\   (tuple)   /
+                 v            v
+                 lstm_output_dq
+                       |
+                       v
+                original_user(s)
+
+    For step (4), reroute all users of the original LSTM node(s) as follows:
+      lstm_output -> lstm_output_dq
+      lstm_output[0] -> output_dq
+      lstm_output[1] -> hidden_dq
+      lstm_output[1][0] -> hidden0_dq
+      lstm_output[1][1] -> hidden1_dq
+
+    Return the node `lstm_output_dq`.
+    """
+    # (1) Split the LSTM node into (output, (hidden0, hidden1))
+    # (2) Insert a DeQuantStub after each internal node
+    with graph.inserting_after(node):
+        output = graph.call_function(operator.getitem, (node, 0))
+        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
+    with graph.inserting_after(output_dq):
+        hidden = graph.call_function(operator.getitem, (node, 1))
+    with graph.inserting_after(hidden):
+        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
+        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
+    with graph.inserting_after(hidden0_dq):
+        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
+        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)
+
+    # (3) Recombine the DeQuantStubs into the same structure as before
+    with graph.inserting_after(hidden1_dq):
+        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
+    with graph.inserting_after(hidden_dq):
+        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))
+
+    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
+    for user in list(node.users.keys()):
+        if user != output and user != hidden:
+            user.replace_input_with(node, lstm_output_dq)
+    # The getitem and tuple nodes we added here may interfere with reference quantized
+    # pattern matching, so we need to redirect the consumers of internal nodes to the
+    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
+    # in order to preserve reference patterns like "dequantize - consumer - quantize".
+    _reroute_tuple_getitem_pattern(graph)
+    return lstm_output_dq
+
+def _maybe_get_custom_module_lstm_from_node_arg(
+    arg: Node,
+    named_modules: Dict[str, torch.nn.Module],
+) -> Optional[Node]:
+    """
+    Given an argument of a node, if the argument refers to the path through which the node
+    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.
+
+    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
+    skip inserting input observers for this node. This is because custom module LSTM produces
+    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
+    would unnecessarily quantize the outputs again.
+
+      lstm -> consumer
+
+    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
+    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    This tuple can be consumed in one of four ways:
+
+      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
+      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
+      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
+      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm
+
+    Thus, we must match against the above patterns instead of simply checking the parent node
+    to determine whether this node is a consumer of a custom module LSTM.
+    """
+    def match_dq(a):
+        return isinstance(_get_module(a, named_modules), DeQuantStub)
+
+    def match_lstm(a):
+        return _is_custom_module_lstm(a, named_modules)
+
+    def match_getitem(a):
+        return a.op == "call_function" and a.target == operator.getitem
+
+    def match_tuple(a):
+        return a.op == "call_function" and a.target == tuple
+
+    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
+        """
+        Traverse up the graph and match the args one by one.
+        If there is a match, return the last matched node, or None otherwise.
+        """
+        a = arg
+        for i, match in enumerate(match_pattern):
+            if not match(a):
+                return None
+            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
+            if i < len(match_pattern) - 1:
+                if match == match_tuple:
+                    a = a.args[0][0]  # type: ignore[assignment,index]
+                else:
+                    a = a.args[0]  # type: ignore[assignment]
+        return a
+
+    all_match_patterns = [
+        [match_dq, match_getitem, match_lstm],
+        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
+        [match_dq, match_getitem, match_getitem, match_lstm],
+        [match_tuple, match_dq, match_getitem, match_lstm],
+    ]
+
+    for p in all_match_patterns:
+        matched_node = _match_pattern(p)
+        if matched_node is not None:
+            return matched_node
+    return None
+
+def _reroute_tuple_getitem_pattern(graph: Graph):
+    """
+    Search for patterns where N consecutive `tuple` call_function nodes are followed by
+    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
+    If we find this pattern, reroute the consumers of the last `getitem` to skip these
+    N `tuple` and `getitem` nodes.
+
+    Before:
+
+        a   b     c
+        |   \\   /
+        \\   tuple
+         \\   /
+          tuple
+            |
+        getitem(1)
+            |
+        getitem(0)
+            |
+            d
+
+    After:
+
+        b
+        |
+        d
+    """
+    def find_patterns(
+            node: Node,
+            index_stack: List[int],
+            current_pattern: List[Node],
+            matched_patterns: List[List[Node]],
+            seen: Set[Tuple[Node, Tuple[int, ...]]]):
+        """
+        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
+        starting at the given node.
+
+        We use a stack to keep track of the expected `getitem` indices, since these are
+        reversed from the `tuple` indices. In the above example, the stack after
+        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
+        and then by getitem(0).
+
+        TODO: traverse upwards from the output and handle the case when tuple is not a
+        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
+        """
+        if len(index_stack) == 0 and len(current_pattern) > 0:
+            matched_patterns.append(copy.copy(current_pattern))
+            current_pattern.clear()
+
+        # Avoid duplicating work
+        state = (node, tuple(index_stack))
+        if state in seen:
+            return
+        seen.add(state)
+
+        # Iterate through users of this node to find tuple/getitem nodes to match
+        for user in node.users:
+            if user.op == "call_function" and user.target == tuple:
+                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
+                    if user_arg == node:
+                        index_stack.append(i)
+                        current_pattern.append(user)
+                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
+            elif user.op == "call_function" and user.target == operator.getitem:
+                if len(index_stack) > 0:
+                    if user.args[1] == index_stack[-1]:
+                        index_stack.pop()
+                        current_pattern.append(user)
+                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
+        return matched_patterns
+
+    # Collect all matched patterns
+    matched_patterns: List[List[Node]] = []
+    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
+    for node in graph.nodes:
+        find_patterns(node, [], [], matched_patterns, seen)
+
+    # For each pattern, redirect all consumers of the last getitem node to the correct input
+    # of the first tuple node
+    for pattern in matched_patterns:
+        first_tuple = pattern[0]
+        last_getitem = pattern[-1]
+        assert first_tuple.op == "call_function" and first_tuple.target == tuple
+        assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
+        last_getitem_index = last_getitem.args[1]
+        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
+        for user in list(last_getitem.users.keys()):
+            user.replace_input_with(last_getitem, new_input)
+
+def _get_observer_from_activation_post_process(
+    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
+) -> ObserverBase:
+    """
+    If `activation_post_process` is an observer, return the observer.
+    If `activation_post_process` is a fake quantize, return the internal observer.
+    """
+    if isinstance(activation_post_process, ObserverBase):
+        return activation_post_process
+    else:
+        assert isinstance(activation_post_process, FakeQuantizeBase)
+        return activation_post_process.activation_post_process  # type: ignore[return-value]
+
+def _qconfig_satisfies_dtype_config_constraints(
+        qconfig: QConfigAny,
+        dtype_with_constraints: DTypeWithConstraints,
+        is_activation: bool = True) -> bool:
+    """
+    Return whether `qconfig` satisfies the following constraints from the backend,
+    specified through the activation and weight DTypeWithConstraints.
+
+        1. QConfig specified a quantization range that falls within the backend's, if any
+        2. QConfig specified a min scale value that is >= the backend's, if any
+        3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
+           scale and zero point that match the backend's, if any
+
+    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
+    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
+    """
+    # TODO: log warnings only when the user enabled a debug flag
+    def _activation_post_process_satisfies_dtype_config_constraints(
+            activation_post_process: Union[ObserverBase, FakeQuantizeBase],
+            dtype_with_constraints: DTypeWithConstraints,
+            debug_string: str) -> bool:
+        observer = _get_observer_from_activation_post_process(activation_post_process)
+        app_quant_min = getattr(observer, "quant_min", None)
+        app_quant_max = getattr(observer, "quant_max", None)
+        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
+        # resolve the differences between the two, either by renaming eps or some other way
+        app_scale_min = getattr(observer, "eps", None)
+        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
+        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
+        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
+        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
+        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
+        # check quantization ranges
+        if backend_quant_min is not None and backend_quant_max is not None:
+            if app_quant_min is None or app_quant_max is None:
+                warnings.warn(f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}")
+                return False
+            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
+                warnings.warn(
+                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
+                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
+                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
+                    f"ignoring {qconfig}"
+                )
+                return False
+        # check scale min
+        if backend_scale_min is not None:
+            if app_scale_min is None:
+                warnings.warn(f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}")
+                return False
+            if app_scale_min < backend_scale_min:
+                warnings.warn(
+                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
+                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
+                )
+                return False
+        # check fixed scale and zero point
+        if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None:
+            # For tests only, accept the following qconfigs for now
+            # TODO: handle fp16 qconfigs properly
+            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
+                if qconfig_equals(qconfig, accepted_qconfig):
+                    return True
+            suggestion_str = (
+                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
+                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
+                "    qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n"
+                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
+            )
+            if not isinstance(activation_post_process, FixedQParamsObserver) and \
+                    not isinstance(activation_post_process, FixedQParamsFakeQuantize):
+                warnings.warn(
+                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
+                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
+                )
+                return False
+            if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match:
+                warnings.warn(
+                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
+                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
+                    f"ignoring {qconfig}.\n{suggestion_str}"
+                )
+                return False
+        return True
+
+    if qconfig is None or dtype_with_constraints.dtype is None:
+        return True
+
+    activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
+    debug_string = "activation" if is_activation else "weight"
+    satisfies_constraints = True
+    if activation_post_process_ctr is not None:
+        activation_post_process = activation_post_process_ctr()
+        assert _is_activation_post_process(activation_post_process)
+        # If dtypes don't match, don't check the activation_post_process and return True early
+        if activation_post_process.dtype != dtype_with_constraints.dtype:
+            return True
+        satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
+            activation_post_process, dtype_with_constraints, debug_string)
+    return satisfies_constraints
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/observer.py b/MLPY/Lib/site-packages/torch/ao/quantization/observer.py
new file mode 100644
index 0000000000000000000000000000000000000000..45036534daf7d0ce4cfd9270fe242a39c0c315e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/observer.py
@@ -0,0 +1,1688 @@
+"""
+This module implements observers which are used to collect statistics about
+the values observed during calibration (PTQ) or training (QAT).
+"""
+
+import re
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+from functools import partial
+from typing import Any, List, Tuple, Optional, Dict
+
+import torch
+import torch.nn as nn
+from torch.ao.quantization.utils import (
+    check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel, validate_qmin_qmax)
+
+__all__ = [
+    "default_affine_fixed_qparams_observer",
+    "default_debug_observer",
+    "default_dynamic_quant_observer",
+    "default_fixed_qparams_range_0to1_observer",
+    "default_fixed_qparams_range_neg1to1_observer",
+    "default_float_qparams_observer",
+    "default_float_qparams_observer_4bit",
+    "default_histogram_observer",
+    "default_observer",
+    "default_per_channel_weight_observer",
+    "default_placeholder_observer",
+    "default_reuse_input_observer",
+    "default_symmetric_fixed_qparams_observer",
+    "default_weight_observer",
+    "get_observer_state_dict",
+    "load_observer_state_dict",
+    "per_channel_weight_observer_range_neg_127_to_127",
+    "weight_observer_range_neg_127_to_127",
+    "FixedQParamsObserver",
+    "HistogramObserver",
+    "MinMaxObserver",
+    "MovingAverageMinMaxObserver",
+    "MovingAveragePerChannelMinMaxObserver",
+    "NoopObserver",
+    "ObserverBase",
+    "PerChannelMinMaxObserver",
+    "PlaceholderObserver",
+    "RecordingObserver",
+    "ReuseInputObserver",
+    "UniformQuantizationObserverBase",
+]
+
+
+class _PartialWrapper:
+    def __init__(self, p):
+        self.p = p
+        self.callable_args = {}
+
+    def __call__(self, *args, **keywords):
+        # call each arg in callable_args and add the result to keywords, then run with keywords;
+        # skip if arg_name is already in keywords so it's possible to overwrite
+        for arg_name in self.callable_args:
+            if arg_name not in keywords:
+                keywords = {**keywords, arg_name: self.callable_args[arg_name]()}
+        return self.p(*args, **keywords)
+
+    def __repr__(self):
+        return self.p.__repr__() + self.callable_args.__repr__()
+
+    def with_args(self, **kwargs):
+        return _with_args(self, **kwargs)
+
+    def with_callable_args(self, **kwargs):
+        result = _PartialWrapper(p=self.p)
+        result.callable_args = {**self.callable_args, **kwargs}
+        return result
+
+
+def _with_args(cls_or_self, **kwargs):
+    r"""Wrapper that allows creation of class factories.
+
+    This can be useful when there is a need to create classes with the same
+    constructor arguments, but different instances. Can be used in conjunction with
+    _callable_args
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Undefined vars")
+        >>> Foo.with_args = classmethod(_with_args)
+        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
+        >>> foo_instance1 = foo_builder()
+        >>> foo_instance2 = foo_builder()
+        >>> id(foo_instance1) == id(foo_instance2)
+        False
+    """
+    r = _PartialWrapper(partial(cls_or_self, **kwargs))
+    return r
+
+def _with_callable_args(cls_or_self, **kwargs):
+    r"""Wrapper that allows creation of class factories args that need to be
+    called at construction time.
+
+    This can be useful when there is a need to create classes with the same
+    constructor arguments, but different instances and those arguments should only
+    be calculated at construction time. Can be used in conjunction with _with_args
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Undefined vars")
+        >>> Foo.with_callable_args = classmethod(_with_callable_args)
+        >>> Foo.with_args = classmethod(_with_args)
+        >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
+        >>> foo_instance1 = foo_builder()
+        >>> # wait 50
+        >>> foo_instance2 = foo_builder()
+        >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
+        False
+    """
+    r = _PartialWrapper(partial(cls_or_self))
+    return r.with_callable_args(**kwargs)
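+
+
+# A minimal sketch of how these factory helpers are typically used in practice
+# (illustrative comment only; it assumes the observer classes defined later in
+# this module):
+#
+#     observer_ctr = MinMaxObserver.with_args(dtype=torch.qint8,
+#                                             qscheme=torch.per_tensor_symmetric)
+#     obs_a = observer_ctr()   # two independent observer instances built from
+#     obs_b = observer_ctr()   # the same pre-bound constructor arguments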
+
+
+ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3
+
+
+class ObserverBase(ABC, nn.Module):
+    r"""Base observer Module.
+    Any observer implementation should derive from this class.
+
+    Concrete observers should follow the same API. In forward, they will update
+    the statistics of the observed Tensor. And they should provide a
+    `calculate_qparams` function that computes the quantization parameters given
+    the collected statistics.
+
+    Args:
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization
+                    or static quantization
+    """
+
+    def __init__(self, dtype, is_dynamic=False):
+        super().__init__()
+        self.dtype = dtype
+        self.is_dynamic = is_dynamic
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    @abstractmethod
+    def calculate_qparams(self, **kwargs):
+        pass
+
+    with_args = classmethod(_with_args)
+    with_callable_args = classmethod(_with_callable_args)
+
+
+class UniformQuantizationObserverBase(ObserverBase):
+    r"""Common base for all observers using uniform quantization to calculate
+    scale and zero_point.
+
+    Args:
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        qscheme: Quantization scheme to be used.
+        reduce_range: Reduces the range of the quantized data type by 1 bit.
+                      This is sometimes required to avoid instruction overflow.
+        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    .. warning::
+
+        :attr:`dtype` can only take ``torch.qint8``, ``torch.quint8``,
+        ``torch.int8`` or ``torch.uint8``.
+
+    .. warning::
+
+        :attr:`qscheme` can only take one of the following options:
+
+        - ``torch.per_tensor_affine``
+        - ``torch.per_tensor_symmetric``
+        - ``torch.per_channel_affine``
+        - ``torch.per_channel_symmetric``
+    """
+
+    # Note: the version is shared by all observer types
+    #
+    # Version 1/None
+    #   self
+    #
+    # Version 2 (base class only, does not include child class buffers)
+    #   self
+    #   |--- eps : Tensor
+    #
+    # Version 3
+    #   for HistogramObserver only, changed the shape of uninitialized
+    #   min_val and max_val buffers from torch.Size([0]) to torch.Size([])
+    #   for PerChannelObservers, changed the name of the buffers from min_vals
+    #   to min_val and from max_vals to max_val.
+    _version = 3
+
+    eps: torch.Tensor
+
+    def __init__(
+        self,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
+        self.qscheme = qscheme
+        if reduce_range:
+            warnings.warn(
+                "Please use quant_min and quant_max to specify the range for observers. "
+                "reduce_range will be deprecated in a future release of PyTorch."
+            )
+        self.reduce_range = reduce_range
+        self.register_buffer(
+            "eps", torch.tensor([eps], **factory_kwargs)
+        )
+        assert self.qscheme in (
+            torch.per_tensor_affine,
+            torch.per_tensor_symmetric,
+            torch.per_channel_affine,
+            torch.per_channel_symmetric,
+            torch.per_channel_affine_float_qparams,
+        ), (
+            "Default Observer only works for per_tensor_affine, per_tensor_symmetric, "
+            "per_channel_affine, per_channel_symmetric and per_channel_affine_float_qparams "
+            "quantization schemes"
+        )
+
+        _ALLOWED_DTYPES = (
+            torch.qint8,
+            torch.quint8,
+            torch.quint4x2,
+            torch.qint32,
+            torch.int8,
+            torch.uint8,
+            torch.int16,
+            torch.int32,
+        )
+
+        assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type"
+        self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
+        if self.has_customized_qrange:
+            validate_qmin_qmax(quant_min, quant_max)
+        self.quant_min, self.quant_max = \
+            calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+
+        version = local_metadata.get("version", None)
+
+        if version is None or version == 1:
+            # eps was moved to a buffer in version 2
+            eps = torch.tensor([torch.finfo(torch.float32).eps])
+            state_dict[prefix + "eps"] = eps
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    @torch.jit.export
+    def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
+        r"""Validates that the user-specified quantization range is properly initialized
+        and within the given bound supported by the observer dtype.
+
+        To accommodate lower-bit quantization with respect to the existing torch.qint8 and
+        torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
+        in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
+        values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
+        fake quantization. These estimates are compared against parameters learned through backpropagation.
+        The related literatures for scale and zero point via backpropagation are as follows:
+
+        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
+        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
+        """
+        # The user-specified qmin and qmax might be adjusted later based on whether the
+        # quantization range is reduced and the datatype (signed/unsigned) used by the observer.
+        assert (
+            quant_min <= 0 <= quant_max
+        ), "User-specified quantization range must include 0."
+        assert (
+            quant_min < quant_max
+        ), "qmin must be strictly less than qmax for user-specified quantization range."
+
+    @torch.jit.export
+    def _calculate_qparams(
+        self, min_val: torch.Tensor, max_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Calculates the quantization parameters, given min and max
+        value tensors. Works for both per tensor and per channel cases
+
+        Args:
+            min_val: Minimum values per channel
+            max_val: Maximum values per channel
+
+        Returns:
+            scales: Scales tensor of shape (#channels,)
+            zero_points: Zero points tensor of shape (#channels,)
+        """
+        # Functionally equivalent to 'determine_qparams' in utils.py. However, observers must be
+        # torchscriptable, and qscheme apparently cannot be passed as a parameter to torchscript
+        # functions. That makes refactoring the observer to reuse that utility painful, so for now
+        # we duplicate the logic here; once torchscript is fully deprecated this can be refactored.
+        # TODO(jakeszwe, jerryzh168)
+        if not check_min_max_valid(min_val, max_val):
+            return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+
+        quant_min, quant_max = self.quant_min, self.quant_max
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+        device = min_val_neg.device
+        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
+        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+        if (
+            self.qscheme == torch.per_tensor_symmetric
+            or self.qscheme == torch.per_channel_symmetric
+        ):
+            max_val_pos = torch.max(-min_val_neg, max_val_pos)
+            scale = max_val_pos / (float(quant_max - quant_min) / 2)
+            scale = torch.max(scale, self.eps)
+            if self.dtype in [torch.quint8, torch.uint8]:
+                if self.has_customized_qrange:
+                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
+                    zero_point = zero_point.new_full(
+                        zero_point.size(), (quant_min + quant_max) // 2
+                    )
+                else:
+                    zero_point = zero_point.new_full(zero_point.size(), 128)
+        elif self.qscheme == torch.per_channel_affine_float_qparams:
+            scale = (max_val - min_val) / float(quant_max - quant_min)
+            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
+            # We use the quantize function
+            # xq = Round(Xf * inv_scale + zero_point),
+            # setting zero_point to (-1 * min *inv_scale) we get
+            # Xq = Round((Xf - min) * inv_scale)
+            zero_point = -1 * min_val / scale
+        else:
+            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+            scale = torch.max(scale, self.eps)
+            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+
+        # For scalar values, cast them to Tensors of size 1 to keep the shape
+        # consistent with default values in FakeQuantize.
+        if len(scale.shape) == 0:
+            # TODO: switch to scale.item() after adding JIT support
+            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
+        if len(zero_point.shape) == 0:
+            # TODO: switch to zero_point.item() after adding JIT support
+            zero_point = torch.tensor(
+                [int(zero_point)], dtype=zero_point.dtype, device=device
+            )
+            if self.qscheme == torch.per_channel_affine_float_qparams:
+                zero_point = torch.tensor(
+                    [float(zero_point)], dtype=zero_point.dtype, device=device
+                )
+
+        return scale, zero_point
+
+    @torch.jit.export
+    def reset_min_max_vals(self):
+        raise NotImplementedError("Cannot reset min/max values in the given observer.")
+
+
+# Originally, this class was called `_ObserverBase`.  Keeping the old name around
+# for backwards compatibility.
+# TODO(after v1.13): delete this
+_ObserverBase = UniformQuantizationObserverBase
+
+
+class MinMaxObserver(UniformQuantizationObserverBase):
+    r"""Observer module for computing the quantization parameters based on the
+    running min and max values.
+
+    This observer uses the tensor min/max statistics to compute the quantization
+    parameters. The module records the running minimum and maximum of incoming
+    tensors, and uses this statistic to compute the quantization parameters.
+
+    Args:
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
+    scale :math:`s` and zero point :math:`z` are computed as follows.
+
+    The running minimum/maximum :math:`x_\text{min/max}` is computed as:
+
+    .. math::
+
+        \begin{array}{ll}
+        x_\text{min} &= \begin{cases}
+            \min(X) & \text{if~}x_\text{min} = \text{None} \\
+            \min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
+        \end{cases}\\
+        x_\text{max} &= \begin{cases}
+            \max(X) & \text{if~}x_\text{max} = \text{None} \\
+            \max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
+        \end{cases}\\
+        \end{array}
+
+    where :math:`X` is the observed tensor.
+
+    The scale :math:`s` and zero point :math:`z` are then computed as:
+
+    .. math::
+
+        \begin{aligned}
+            \text{if Symmetric:}&\\
+            &s = 2 \max(|x_\text{min}|, x_\text{max}) /
+                \left( Q_\text{max} - Q_\text{min} \right) \\
+            &z = \begin{cases}
+                0 & \text{if dtype is qint8} \\
+                128 & \text{otherwise}
+            \end{cases}\\
+            \text{Otherwise:}&\\
+                &s = \left( x_\text{max} - x_\text{min}  \right ) /
+                    \left( Q_\text{max} - Q_\text{min} \right ) \\
+                &z = Q_\text{min} - \text{round}(x_\text{min} / s)
+        \end{aligned}
+
+    where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
+    maximum of the quantized data type.
+
+    .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
+
+    .. note:: If the running minimum equals the running maximum, the scale
+              and zero_point are set to 1.0 and 0.
+    """
+    min_val: torch.Tensor
+    max_val: torch.Tensor
+
+    def __init__(
+        self,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        if not is_per_tensor(qscheme):
+            raise NotImplementedError(
+                "MinMaxObserver's qscheme only supports torch.per_tensor_symmetric "
+                "and torch.per_tensor_affine."
+            )
+        # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but
+        # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it
+        # supports dynamic quantization, we may need to better error checking here
+
+        # For x86 quantized kernels, we need to ensure that the vpmaddubsw
+        # instruction does not overflow. We allow for a reduce_range argument to
+        # observers that reduces the quantized range to (0,127) or (-64, 63).
+        # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
+        # This is not an optimal choice for non x86 backends as it loses a bit
+        # of precision for activations.
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs,
+        )
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
+        if (
+            self.qscheme == torch.per_tensor_symmetric
+            and self.reduce_range
+            and self.dtype == torch.quint8
+        ):
+            raise NotImplementedError(
+                "Cannot reduce range for symmetric quantization for quint8"
+            )
+
+    def forward(self, x_orig):
+        r"""Records the running minimum and maximum of ``x``."""
+        if x_orig.numel() == 0:
+            return x_orig
+        x = x_orig.detach()  # avoid keeping autograd tape
+        x = x.to(self.min_val.dtype)
+        min_val_cur, max_val_cur = torch.aminmax(x)
+        min_val = torch.min(min_val_cur, self.min_val)
+        max_val = torch.max(max_val_cur, self.max_val)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+        return x_orig
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        r"""Calculates the quantization parameters."""
+        return self._calculate_qparams(self.min_val, self.max_val)
+
+    @torch.jit.export
+    def extra_repr(self):
+        return f"min_val={self.min_val}, max_val={self.max_val}"
+
+    @torch.jit.export
+    def reset_min_max_vals(self):
+        """Resets the min/max values."""
+        self.min_val.copy_(torch.tensor(float("inf")))
+        self.max_val.copy_(torch.tensor(float("-inf")))
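+
+
+# A minimal MinMaxObserver calibration sketch (illustrative comment only; the
+# sample values are an assumption used to make the arithmetic concrete):
+#
+#     obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
+#     obs(torch.tensor([-1.0, 0.0, 2.0]))          # running min = -1.0, max = 2.0
+#     scale, zero_point = obs.calculate_qparams()
+#     # affine case: scale = (2.0 - (-1.0)) / 255 ~= 0.0118
+#     #              zero_point = 0 - round(-1.0 / scale) = 85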
+
+class MovingAverageMinMaxObserver(MinMaxObserver):
+    r"""Observer module for computing the quantization parameters based on the
+    moving average of the min and max values.
+
+    This observer computes the quantization parameters based on the moving
+    averages of minimums and maximums of the incoming tensors. The module
+    records the average minimum and maximum of incoming tensors, and uses this
+    statistic to compute the quantization parameters.
+
+    Args:
+        averaging_constant: Averaging constant for min/max.
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    The moving average min/max is computed as follows
+
+    .. math::
+
+        \begin{array}{ll}
+                x_\text{min} = \begin{cases}
+                    \min(X) & \text{if~}x_\text{min} = \text{None} \\
+                    (1 - c) x_\text{min} + c \min(X) & \text{otherwise}
+                \end{cases}\\
+                x_\text{max} = \begin{cases}
+                    \max(X) & \text{if~}x_\text{max} = \text{None} \\
+                    (1 - c) x_\text{max} + c \max(X) & \text{otherwise}
+                \end{cases}\\
+        \end{array}
+
+    where :math:`x_\text{min/max}` is the running average min/max, :math:`X`
+    is the incoming tensor, and :math:`c` is the ``averaging_constant``.
+
+    The scale and zero point are then computed as in
+    :class:`~torch.ao.quantization.observer.MinMaxObserver`.
+
+    .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme.
+
+    .. note:: If the running minimum equals the running maximum, the scale
+              and zero_point are set to 1.0 and 0.
+    """
+
+    def __init__(
+        self,
+        averaging_constant=0.01,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs
+    ) -> None:
+        if not is_per_tensor(qscheme):
+            raise NotImplementedError(
+                "MovingAverageMinMaxObserver's qscheme only supports "
+                f"torch.per_tensor_symmetric and torch.per_tensor_affine, but got: {qscheme}"
+            )
+        self.averaging_constant = averaging_constant
+        if is_dynamic and self.averaging_constant != 1:
+            raise NotImplementedError(
+                "MovingAverageMinMaxObserver doesn't support dynamic quantization for "
+                f"averaging constant of {self.averaging_constant}"
+            )
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs
+        )
+
+    def forward(self, x_orig):
+        if x_orig.numel() == 0:
+            return x_orig
+        x = x_orig.detach()  # avoid keeping autograd tape
+        x = x.to(self.min_val.dtype)
+        min_val = self.min_val
+        max_val = self.max_val
+        if min_val == float("inf") and max_val == float("-inf"):
+            min_val, max_val = torch.aminmax(x)
+        else:
+            min_val_cur, max_val_cur = torch.aminmax(x)
+            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
+            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+        return x_orig
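+
+
+# A sketch of the moving-average update above (illustrative comment only; the
+# numbers are assumptions): with averaging_constant c = 0.01, a previous running
+# max of 4.0 and a new batch max of 6.0,
+#
+#     max_val = 4.0 + 0.01 * (6.0 - 4.0) = 4.02
+#
+# so a single outlier batch shifts the recorded range only gradually.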
+
+
+class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
+    r"""Observer module for computing the quantization parameters based on the
+    running per channel min and max values.
+
+    This observer uses the tensor min/max statistics to compute the per channel
+    quantization parameters. The module records the running minimum and maximum
+    of incoming tensors, and uses this statistic to compute the quantization
+    parameters.
+
+    Args:
+        ch_axis: Channel axis
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    The quantization parameters are computed the same way as in
+    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
+    that the running min/max values are stored per channel.
+    Scales and zero points are thus computed per channel as well.
+
+    .. note:: If the running minimum equals the running maximum, the scales
+              and zero_points are set to 1.0 and 0.
+    """
+    min_val: torch.Tensor
+    max_val: torch.Tensor
+
+    def __init__(
+        self,
+        ch_axis=0,
+        dtype=torch.quint8,
+        qscheme=torch.per_channel_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        if not is_per_channel(qscheme):
+            raise NotImplementedError(
+                "PerChannelMinMaxObserver's qscheme only supports torch.per_channel_symmetric, "
+                "torch.per_channel_affine and torch.per_channel_affine_float_qparams."
+            )
+        if is_dynamic:
+            raise NotImplementedError(
+                "PerChannelMinMaxObserver doesn't support dynamic quantization"
+            )
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs,
+        )
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        self.ch_axis = ch_axis
+        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
+        if (
+            self.qscheme == torch.per_channel_symmetric
+            and self.reduce_range
+            and self.dtype == torch.quint8
+        ):
+            raise NotImplementedError(
+                "Cannot reduce range for symmetric quantization for quint8"
+            )
+
+    def forward(self, x_orig):
+        return self._forward(x_orig)
+
+    def _forward(self, x_orig):
+        if x_orig.numel() == 0:
+            return x_orig
+        x = x_orig.detach()  # avoid keeping autograd tape
+        min_val = self.min_val
+        max_val = self.max_val
+        x_dim = x.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x.permute(new_axis_list)
+        # Need to match dtype of min/max because the updates to buffers
+        # are done in place and types need to match for comparisons
+        y = y.to(self.min_val.dtype)
+        y = torch.flatten(y, start_dim=1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch.aminmax(y, dim=1)
+        else:
+            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+            min_val = torch.min(min_val_cur, min_val)
+            max_val = torch.max(max_val_cur, max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+        return x_orig
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self._calculate_qparams(self.min_val, self.max_val)
+
+    def extra_repr(self):
+        return f"min_val={self.min_val}, max_val={self.max_val}"
+
+    def _load_from_state_dict(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, torch.Tensor],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ):
+        version = local_metadata.get("version", None)
+        if version is not None and version < 3:
+            local_state = ["min_vals", "max_vals"]
+            expected_min_name = "min_vals"
+            expected_max_name = "max_vals"
+        else:
+            local_state = ["min_val", "max_val"]
+            expected_min_name = "min_val"
+            expected_max_name = "max_val"
+        for name in local_state:
+            key = prefix + name
+            if key in state_dict:
+                val = state_dict[key]
+                # Custom handling to allow loading min_val or max_val
+                # of size N into uninitialized buffers of size 0. The
+                # buffers are resized here, and the values are copied in
+                # the default state_dict loading code of the parent.
+                if name == expected_min_name:
+                    self.min_val.resize_(val.shape)
+                elif name == expected_max_name:
+                    self.max_val.resize_(val.shape)
+                else:
+                    warnings.warn(f"Observer load_from_state_dict got unexpected name {name}")
+                # For torchscript module we need to update the attributes here since we do not
+                # call the `_load_from_state_dict` function defined in module.py
+                if torch.jit.is_scripting():
+                    if name == expected_min_name:
+                        self.min_val.copy_(val)
+                    elif name == expected_max_name:
+                        self.max_val.copy_(val)
+                    else:
+                        warnings.warn(f"Observer load_from_state_dict got unexpected name {name}")
+            elif strict:
+                missing_keys.append(key)
+
+        if not torch.jit.is_scripting():
+            super()._load_from_state_dict(
+                state_dict,
+                prefix,
+                local_metadata,
+                False,
+                missing_keys,
+                unexpected_keys,
+                error_msgs,
+            )
+
+    def _load_from_state_dict_script(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, torch.Tensor],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ):
+
+        self._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    @torch.jit.export
+    def reset_min_max_vals(self):
+        """Resets the min/max values."""
+        # This used to be torch.ones but that does not work because
+        # JIT compiler can optimize it via common subexpression elimination
+        # in which case both min_val and max_val point to the same tensor.
+        self.min_val = torch.rand(0, )
+        self.max_val = torch.rand(0, )
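+
+
+# Per-channel usage sketch (illustrative comment only; the weight shape below is
+# an assumed example): for a linear weight of shape (out_features, in_features)
+# observed along ch_axis=0,
+#
+#     obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
+#                                    qscheme=torch.per_channel_symmetric)
+#     obs(torch.randn(16, 32))
+#     scales, zero_points = obs.calculate_qparams()   # both have shape (16,)
+#
+# i.e. one (scale, zero_point) pair is produced per output channel.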
+
+
+class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
+    r"""Observer module for computing the quantization parameters based on the
+    running per channel min and max values.
+
+    This observer uses the tensor min/max statistics to compute the per channel
+    quantization parameters. The module records the running minimum and maximum
+    of incoming tensors, and uses this statistic to compute the quantization
+    parameters.
+
+    Args:
+        averaging_constant: Averaging constant for min/max.
+        ch_axis: Channel axis
+        dtype: Quantized data type
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    The quantization parameters are computed the same way as in
+    :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the
+    difference that the running min/max values are stored per channel.
+    Scales and zero points are thus computed per channel as well.
+
+    .. note:: If the running minimum equals the running maximum, the scales
+              and zero_points are set to 1.0 and 0.
+    """
+
+    def __init__(
+        self,
+        averaging_constant=0.01,
+        ch_axis=0,
+        dtype=torch.quint8,
+        qscheme=torch.per_channel_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs
+    ) -> None:
+        if not is_per_channel(qscheme):
+            raise NotImplementedError(
+                "MovingAveragePerChannelMinMaxObserver's qscheme only supports torch.per_channel_symmetric, "
+                "torch.per_channel_affine and torch.per_channel_affine_float_qparams."
+            )
+        if is_dynamic:
+            raise NotImplementedError(
+                "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization"
+            )
+        super().__init__(
+            ch_axis=ch_axis,
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs
+        )
+        self.averaging_constant = averaging_constant
+
+    def forward(self, x_orig):
+        if x_orig.numel() == 0:
+            return x_orig
+        x = x_orig.detach()  # avoid keeping autograd tape
+        x = x.to(self.min_val.dtype)
+        min_val = self.min_val
+        max_val = self.max_val
+        x_dim = x.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x.permute(new_axis_list)
+        y = torch.flatten(y, start_dim=1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch.aminmax(y, dim=1)
+        else:
+            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
+            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+        return x_orig
+
+
+class HistogramObserver(UniformQuantizationObserverBase):
+    r"""
+    The module records the running histogram of tensor values along with
+    min/max values. ``calculate_qparams`` will calculate scale and zero_point.
+
+    Args:
+        bins: Number of bins to use for the histogram
+        upsample_rate: Factor by which the histograms are upsampled; this is
+                       used to interpolate histograms with varying ranges across observations
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec
+        qscheme: Quantization scheme to be used
+        reduce_range: Reduces the range of the quantized data type by 1 bit
+        eps: Epsilon value for float32. Defaults to `torch.finfo(torch.float32).eps`.
+
+    The scale and zero point are computed as follows:
+
+    1. Create the histogram of the incoming inputs.
+        The histogram is computed continuously, and the ranges per bin change
+        with every new tensor observed.
+    2. Search the distribution in the histogram for optimal min/max values.
+        The search for the min/max values ensures the minimization of the
+        quantization error with respect to the floating point model.
+    3. Compute the scale and zero point the same way as in the
+        :class:`~torch.ao.quantization.MinMaxObserver`
+    """
+    histogram: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
+
+    def __init__(
+        self,
+        bins: int = 2048,
+        upsample_rate: int = 128,
+        dtype: torch.dtype = torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=False,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+        eps=torch.finfo(torch.float32).eps,
+        is_dynamic=False,
+        **kwargs,
+    ) -> None:
+        if not is_per_tensor(qscheme):
+            raise NotImplementedError(
+                "HistogramObserver's qscheme only supports torch.per_tensor_symmetric "
+                "and torch.per_tensor_affine."
+            )
+        if is_dynamic:
+            raise NotImplementedError(
+                "HistogramObserver doesn't support dynamic quantization"
+            )
+        # bins: The number of bins used for histogram calculation.
+        super().__init__(
+            dtype=dtype,
+            qscheme=qscheme,
+            reduce_range=reduce_range,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+            eps=eps,
+            is_dynamic=is_dynamic,
+            **kwargs
+        )
+        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
+        self.bins = bins
+        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
+        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
+        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
+        self.upsample_rate = upsample_rate
+
+    def _get_norm(
+        self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Compute the norm of the values uniformly distributed between
+        delta_begin and delta_end.
+        Currently only L2 norm is supported.
+
+        norm = density * (integral_{begin, end} x^2)
+             = density * (end^3 - begin^3) / 3
+        """
+        norm = (
+            delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin
+        ) / 3
+        return density * norm
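+
+    # Quick sanity check of the formula above (illustrative comment only): with
+    # density = 1, delta_begin = 0 and delta_end = 1, _get_norm returns
+    # (1**3 - 0**3) / 3 = 1/3, matching density * integral_{0}^{1} x^2 dx.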
+
+    def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
+        r"""
+        Compute the quantization error if we use start_bin to end_bin as the
+        min and max to do the quantization.
+        """
+        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
+
+        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
+        if dst_bin_width == 0.0:
+            return 0.0
+
+        src_bin = torch.arange(self.bins, device=self.histogram.device)
+        # distances from the beginning of first dst_bin to the beginning and
+        # end of src_bin
+        src_bin_begin = (src_bin - next_start_bin) * bin_width
+        src_bin_end = src_bin_begin + bin_width
+
+        # which dst_bins the beginning and end of src_bin belong to?
+        dst_bin_of_begin = torch.clamp(
+            torch.div(src_bin_begin, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1
+        )
+        dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
+
+        dst_bin_of_end = torch.clamp(
+            torch.div(src_bin_end, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1
+        )
+        density = self.histogram / bin_width
+
+        norm = torch.zeros(self.bins, device=self.histogram.device)
+
+        delta_begin = src_bin_begin - dst_bin_of_begin_center
+        delta_end = dst_bin_width / 2
+        norm += self._get_norm(delta_begin,
+                               torch.ones(self.bins, device=self.histogram.device) * delta_end,
+                               density)
+
+        norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
+            torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
+        )
+
+        dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2
+
+        delta_begin = -dst_bin_width / 2
+        delta_end = src_bin_end - dst_bin_of_end_center
+        norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)
+
+        return norm.sum().item()
+
+    def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Non-linear parameter search.
+
+        An approximation for L2 error minimization for selecting min/max.
+        By selecting new min/max, we filter out outliers in input distribution.
+        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
+        caffe2/quantization/server/norm_minimization.cc
+        """
+        assert self.histogram.size()[0] == self.bins, "bins mismatch"
+        bin_width = (self.max_val - self.min_val) / self.bins
+
+        # cumulative sum
+        total = torch.sum(self.histogram).item()
+        cSum = torch.cumsum(self.histogram, dim=0)
+
+        stepsize = 1e-5  # granularity
+        alpha = 0.0  # lower bound
+        beta = 1.0  # upper bound
+        start_bin = 0
+        end_bin = self.bins - 1
+        norm_min = float("inf")
+
+        while alpha < beta:
+            # Find the next step
+            next_alpha = alpha + stepsize
+            next_beta = beta - stepsize
+
+            # find the left and right bins between the quantile bounds
+            l = start_bin
+            r = end_bin
+            while l < end_bin and cSum[l] < next_alpha * total:
+                l = l + 1
+            while r > start_bin and cSum[r] > next_beta * total:
+                r = r - 1
+
+            # decide the next move
+            next_start_bin = start_bin
+            next_end_bin = end_bin
+            if (l - start_bin) > (end_bin - r):
+                # move the start bin
+                next_start_bin = l
+                alpha = next_alpha
+            else:
+                # move the end bin
+                next_end_bin = r
+                beta = next_beta
+
+            if next_start_bin == start_bin and next_end_bin == end_bin:
+                continue
+
+            # calculate the quantization error using next_start_bin and next_end_bin
+            norm = self._compute_quantization_error(next_start_bin, next_end_bin)
+
+            if norm > norm_min:
+                break
+            norm_min = norm
+            start_bin = next_start_bin
+            end_bin = next_end_bin
+
+        new_min = self.min_val + bin_width * start_bin
+        new_max = self.min_val + bin_width * (end_bin + 1)
+        return new_min, new_max
+
+    def _adjust_min_max(
+        self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
+        # We ensure that:
+        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
+        # This allows us to have a common grid of resolution s, where we can align
+        # the input histogram
+        # start_idx maps min_val to the histogram bin index.
+
+        # Computing the histogram bin width directly, i.e.
+        # hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate),
+        # would be straightforward, but the division can underflow when the numerator is
+        # close to the smallest positive subnormal FP32 number, so we avoid it.
+        downsample_rate = int(
+            torch.ceil(
+                ((combined_max - combined_min) / (self.max_val - self.min_val)) * upsample_rate
+            ).item()
+        )
+        e = downsample_rate / upsample_rate * (self.max_val - self.min_val) - (combined_max - combined_min)
+        start_idx = int(
+            torch.round((self.min_val - combined_min) / (self.max_val - self.min_val) * self.bins * upsample_rate).item()
+        )
+        combined_max = combined_max + e
+        return combined_min, combined_max, downsample_rate, start_idx
+
+    def _combine_histograms(
+        self,
+        orig_hist: torch.Tensor,
+        new_hist: torch.Tensor,
+        upsample_rate: int,
+        downsample_rate: int,
+        start_idx: int,
+        Nbins: int,
+    ) -> torch.Tensor:
+        # First up-sample the histogram of the new data by a factor of upsample_rate.
+        # This creates an approximate probability density that is piecewise constant.
+        upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
+        # Now insert the upsampled histogram into the output
+        # histogram, which is initialized with zeros.
+        # The offset at which the histogram is introduced is determined
+        # by the start index as the output histogram can cover a wider range
+        histogram_with_output_range = torch.zeros(
+            (Nbins * downsample_rate), device=orig_hist.device
+        )
+        histogram_with_output_range[
+            start_idx : Nbins * upsample_rate + start_idx
+        ] = upsampled_histogram
+        # Compute integral histogram, double precision is needed to ensure
+        # that there are no overflows
+        integral_histogram = torch.cumsum(
+            histogram_with_output_range, 0, dtype=torch.double
+        )[downsample_rate - 1 :: downsample_rate]
+        # Finally perform interpolation
+        shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device)
+        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
+        interpolated_histogram = (
+            integral_histogram - shifted_integral_histogram
+        ) / upsample_rate
+        orig_hist = orig_hist + interpolated_histogram.to(torch.float)
+        return orig_hist
+
+    def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
+        if x_orig.numel() == 0:
+            return x_orig
+        x = x_orig.detach()
+        x_min, x_max = torch.aminmax(x)
+        # want to ignore torch.inf since we don't actually
+        # want to make our quantization range infinite
+        # and in practice those values will be clamped
+        if x_min == -torch.inf or x_max == torch.inf:
+            warnings.warn("torch.inf detected in input tensor, ignoring input")
+            x = x[x.abs() != torch.inf]
+            if x.numel() == 0:
+                return x_orig
+            x_min, x_max = torch.aminmax(x)
+        min_val = self.min_val
+        max_val = self.max_val
+        same_values = min_val.item() == max_val.item()
+        is_uninitialized = min_val == float("inf") and max_val == float("-inf")
+        if is_uninitialized or same_values:
+            min_val, max_val = x_min, x_max
+            self.min_val.resize_(min_val.shape)
+            self.min_val.copy_(min_val)
+            self.max_val.resize_(max_val.shape)
+            self.max_val.copy_(max_val)
+            assert (
+                min_val.numel() == 1 and max_val.numel() == 1
+            ), "histogram min/max values must be scalar."
+            torch.histc(
+                x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
+            )
+        else:
+            new_min, new_max = x_min, x_max
+            combined_min = torch.min(new_min, min_val)
+            combined_max = torch.max(new_max, max_val)
+            # combine the existing histogram and new histogram into 1 histogram
+            # We do this by first upsampling the histogram to a dense grid
+            # and then downsampling the histogram efficiently
+            (
+                combined_min,
+                combined_max,
+                downsample_rate,
+                start_idx,
+            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
+            assert (
+                combined_min.numel() == 1 and combined_max.numel() == 1
+            ), "histogram min/max values must be scalar."
+
+            # TODO: For some reason, this is required for it to pass torchscript test
+            # combined_min and combined_max should already have requires_grad set to False
+            combined_min, combined_max = combined_min.detach(), combined_max.detach()
+
+            combined_histogram = torch.histc(
+                x, self.bins, min=combined_min, max=combined_max  # type: ignore[arg-type]
+            )
+            if combined_min == min_val and combined_max == max_val:
+                combined_histogram += self.histogram
+            else:
+                combined_histogram = self._combine_histograms(
+                    combined_histogram,
+                    self.histogram,
+                    self.upsample_rate,
+                    downsample_rate,
+                    start_idx,
+                    self.bins,
+                )
+
+            self.histogram.detach_().resize_(combined_histogram.shape)
+            self.histogram.copy_(combined_histogram)
+            self.min_val.detach_().resize_(combined_min.shape)
+            self.min_val.copy_(combined_min)
+            self.max_val.detach_().resize_(combined_max.shape)
+            self.max_val.copy_(combined_max)
+        return x_orig
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        is_uninitialized = self.min_val == float("inf") and self.max_val == float(
+            "-inf"
+        )
+        if is_uninitialized:
+            warnings.warn(
+                "Must run observer before calling calculate_qparams. "
+                "Returning default scale and zero point."
+            )
+            return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor([0], device=self.min_val.device.type)
+        assert self.bins == len(self.histogram), (
+            "The number of bins in histogram should be equal to the number of bins "
+            "supplied while making this observer"
+        )
+
+        new_min, new_max = self._non_linear_param_search()
+
+        return self._calculate_qparams(new_min, new_max)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        destination[prefix + "min_val"] = self.min_val
+        destination[prefix + "max_val"] = self.max_val
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 3:
+            # if min_val and max_val are not initialized, update their shape
+            # to account for the differences between v2 and v3
+            min_val_name, max_val_name = prefix + "min_val", prefix + "max_val"
+            if min_val_name in state_dict:
+                if state_dict[min_val_name].shape == torch.Size([0]):
+                    state_dict[min_val_name] = torch.tensor(float("inf"))
+            if max_val_name in state_dict:
+                if state_dict[max_val_name].shape == torch.Size([0]):
+                    state_dict[max_val_name] = torch.tensor(float("-inf"))
+
+        local_state = ["min_val", "max_val"]
+        for name in local_state:
+            key = prefix + name
+            if key in state_dict:
+                val = state_dict[key]
+                setattr(self, name, val)
+            elif strict:
+                missing_keys.append(key)
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    def extra_repr(self):
+        return f"min_val={self.min_val}, max_val={self.max_val}"
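+
+
+# HistogramObserver calibration sketch (illustrative comment only; the number of
+# batches and the tensor shape are arbitrary assumptions):
+#
+#     obs = HistogramObserver(bins=2048, dtype=torch.quint8)
+#     for _ in range(10):
+#         obs(torch.randn(8, 16))                  # accumulates the running histogram
+#     scale, zero_point = obs.calculate_qparams()  # runs the non-linear min/max search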
+
+
+class FixedQParamsObserver(ObserverBase):
+    r"""
+    Observer that simulates quantize and dequantize with fixed
+    quantization parameters in training time. Only per tensor
+    quantization is supported.
+
+    Args:
+        `scale` (float): fixed scale for the observer
+        `zero_point` (int): fixed zero point for the observer
+        `dtype`, `qscheme`, `quant_min`, `quant_max`
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(
+        self,
+        scale,
+        zero_point,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        quant_min=0,
+        quant_max=255,
+        is_dynamic=False,
+        **kwargs,
+    ):
+        if is_dynamic:
+            raise NotImplementedError(
+                "FixedQParamsObserver doesn't support dynamic quantization"
+            )
+        super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
+        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
+        self.dtype = dtype
+        self.qscheme = qscheme
+
+    def forward(self, X):
+        return X
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.scale, self.zero_point
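+
+
+# FixedQParamsObserver sketch (illustrative comment only; the qparams below are an
+# assumed example for an op whose output is known to lie in [0, 1)):
+#
+#     obs = FixedQParamsObserver(scale=1.0 / 256.0, zero_point=0,
+#                                dtype=torch.quint8, quant_min=0, quant_max=255)
+#     obs(torch.randn(4))                          # forward passes the input through
+#     scale, zero_point = obs.calculate_qparams()  # always the fixed values above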
+
+
+class PlaceholderObserver(ObserverBase):
+    r"""
+    Observer that doesn't do anything and just passes its configuration to the
+    quantized module's ``.from_float()``.
+
+    Can be used for quantization to float16 which doesn't require determining
+    ranges.
+
+    Args:
+        dtype: dtype argument to the `quantize` node needed to implement the
+               reference model spec.
+        quant_min: minimum value in quantized domain (TODO: align behavior with other observers)
+        quant_max: maximum value in quantized domain
+        custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
+                        (Can be used in Graph Mode Passes for special case ops).
+        compute_dtype (deprecated): if set, marks the future quantize function to use
+                       dynamic quantization instead of static quantization.
+                       This field is deprecated, use `is_dynamic=True` instead.
+        is_dynamic: if True, the `quantize` function in the reference model
+                    representation taking stats from this observer instance will
+                    use dynamic quantization.
+    """
+
+    def __init__(
+        self, dtype=torch.float32, custom_op_name="", compute_dtype=None,
+        quant_min=None, quant_max=None, qscheme=None, eps=None,
+        is_dynamic=False,
+    ) -> None:
+        super().__init__(dtype=dtype, is_dynamic=is_dynamic)
+        if qscheme is None:
+            qscheme = torch.per_tensor_affine
+        if eps is None:
+            eps = torch.finfo(torch.float32).eps
+
+        # dtype of input of the target operator, e.g. for dynamic quantization
+        # ops, the dtype will be float32
+        self.dtype = dtype
+        self.qscheme = qscheme
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+        self.eps = eps
+        self.custom_op = custom_op_name
+        # used for configuration of computation type for dynamic quantization
+        if compute_dtype:
+            is_dynamic = True
+            warnings.warn(
+                "Please use `is_dynamic` instead of `compute_dtype`. \
+                    `compute_dtype` will be deprecated in a future release \
+                    of PyTorch."
+            )
+
+    def forward(self, x):
+        return x
+
+    @torch.jit.export
+    def extra_repr(self):
+        return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}"
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        raise Exception(
+            "calculate_qparams should not be called for PlaceholderObserver"
+        )
+
+
+class RecordingObserver(ObserverBase):
+    r"""
+    The module is mainly for debugging; it records the tensor values observed at runtime.
+
+    Args:
+        dtype: Quantized data type
+    """
+    __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]}
+
+    def __init__(self, dtype=torch.quint8):
+        super().__init__(dtype=dtype, is_dynamic=False)  # type: ignore[call-arg]
+        self.tensor_val = []
+
+    def forward(self, x):
+        self.tensor_val.append(x.clone())
+        return x
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        raise Exception("calculate_qparams should not be called for RecordingObserver")
+
+    @torch.jit.export
+    def get_tensor_value(self):
+        return self.tensor_val
+
+
+class NoopObserver(ObserverBase):
+    r"""
+    Observer that doesn't do anything and just passes its configuration to the
+    quantized module's ``.from_float()``.
+
+    Primarily used for quantization to float16 which doesn't require determining
+    ranges.
+
+    Args:
+        dtype: Quantized data type
+        custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
+                        (Can be used in Graph Mode Passes for special case ops).
+    """
+
+    def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
+        super().__init__(dtype=dtype, is_dynamic=False)
+        self.dtype = dtype
+        self.custom_op = custom_op_name
+
+    def forward(self, x):
+        return x
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        raise Exception("calculate_qparams should not be called for NoopObserver")
+
+class ReuseInputObserver(ObserverBase):
+    r""" This observer is used when we want to reuse the observer from the operator
+    that produces the input Tensor, typically used for operators like reshape, e.g.
+    ```
+    x0 = ...
+    x1 = x0.reshape()
+    ```
+    if we configure x0 to be observed by some observer, let's say MinMaxObserver,
+    and reshape is configured with ReuseInputObserver, we'll reuse the observer instance
+    of x0 for x1 (the output of reshape). If x0 is not observed, we won't observe x1 either.
+
+    Note: this is only enabled in FX Graph Mode Quantization
+    """
+    def __init__(self):
+        super().__init__(torch.quint8, is_dynamic=False)
+
+    def forward(self, x):
+        return x
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        raise Exception("calculate_qparams should not be called for ReuseInputObserver")
+
+def _is_observer_script_module(mod, obs_type_name):
+    """Returns true if given mod is an instance of Observer script module."""
+    if isinstance(mod, torch.jit.RecursiveScriptModule):
+        # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver'
+        suffix = mod._c.qualified_name.split(".", 1)[1]
+        name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
+        return obs_type_name in name
+    return False
+
+
+def _is_activation_post_process(module):
+    return (
+        isinstance(module, (torch.ao.quantization.ObserverBase,
+                            torch.ao.quantization.FakeQuantizeBase)) or _is_observer_script_module(module, "quantization.observer")
+    )
+
+
+def _is_per_channel_script_obs_instance(module):
+    if isinstance(module, torch.jit.RecursiveScriptModule):
+        return _is_observer_script_module(
+            module, "quantization.observer.PerChannelMinMaxObserver"
+        ) or _is_observer_script_module(
+            module, "quantization.observer.MovingAveragePerChannelMinMaxObserver"
+        )
+    return False
+
+
+def get_observer_state_dict(mod):
+    r"""
+    Returns the state dict corresponding to the observer stats.
+    Traverses the model state_dict and extracts the stats.
+    """
+    od = OrderedDict()
+    if isinstance(mod, torch.jit.RecursiveScriptModule):
+        for k, v in mod.state_dict().items():
+            if "observer" in k:
+                od[k] = v
+    else:
+        # path for GraphModule and nn.Module (eager mode)
+        for k, v in mod.state_dict().items():
+            if "activation_post_process" in k:
+                od[k] = v
+    od._metadata = mod.state_dict()._metadata  # type: ignore[attr-defined]
+    return od
+
+
+def load_observer_state_dict(mod, obs_dict):
+    r"""
+    Given an input model and a state_dict containing model observer stats,
+    load the stats back into the model. The observer state_dict can be saved
+    using torch.ao.quantization.get_observer_state_dict.
+    """
+    missing_keys: List[str] = []
+    unexpected_keys: List[str] = []
+    for name, module in mod.named_modules():
+        prefix = name + "."
+        if _is_activation_post_process(module):
+            if _is_per_channel_script_obs_instance(module):
+                # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor.
+                # However this is not called when the module is scripted and we end up calling the default one in module.py
+                module._load_from_state_dict_script(
+                    obs_dict, prefix, {}, True, missing_keys, unexpected_keys, []
+                )
+            else:
+                module._load_from_state_dict(
+                    obs_dict, prefix, {}, False, missing_keys, unexpected_keys, []
+                )
+    for k in missing_keys:
+        if "observer" in k or "activation_post_process" in k:
+            raise Exception(f"Missing keys for observer {k} in state_dict")
+    for k in unexpected_keys:
+        if "observer" in k or "activation_post_process" in k:
+            raise Exception(f"Unexpected keys for observer {k} in state_dict")
+
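+
+# Hedged editorial sketch (not part of the upstream file): round-tripping observer
+# statistics between two identically prepared models; `calibrated_model` and
+# `fresh_model` are assumed to be models that already contain observers
+# (e.g. returned by an eager-mode prepare call), with the first one calibrated.
+def _example_observer_stats_roundtrip(calibrated_model, fresh_model):
+    obs_dict = get_observer_state_dict(calibrated_model)
+    load_observer_state_dict(fresh_model, obs_dict)
+    return fresh_model
+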
+
+# Restrict activations to be in the range [0, 127]
+default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127)
+"""
+Default observer for static quantization, usually used for debugging.
+"""
+
+default_placeholder_observer = PlaceholderObserver
+"""
+Default placeholder observer, usually used for quantization to torch.float16.
+"""
+
+default_debug_observer = RecordingObserver
+"""
+Default debug-only observer.
+"""
+
+default_weight_observer = MinMaxObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
+)
+"""
+Default weight observer.
+"""
+
+weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
+    quant_min=-127, quant_max=127, eps=2 ** -12)
+"""
+Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
+"""
+
+default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127)
+"""
+Default histogram observer, usually used for PTQ.
+"""
+
+default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+)
+"""
+Default per-channel weight observer, usually used on backends where per-channel
+weight quantization is supported, such as `fbgemm`.
+"""
+
+per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_channel_symmetric,
+    quant_min=-127, quant_max=127, eps=2 ** -12)
+"""
+Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
+"""
+
+default_dynamic_quant_observer = PlaceholderObserver.with_args(
+    dtype=torch.quint8, quant_min=0, quant_max=255, is_dynamic=True,
+)
+"""
+Default observer for dynamic quantization.
+"""
+
+default_float_qparams_observer = PerChannelMinMaxObserver.with_args(
+    dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
+)
+"""
+Default observer for a floating point zero-point.
+"""
+
+default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args(
+    dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
+)
+"""
+Default observer for a floating point zero-point and 4 bit activations.
+"""
+
+# TODO(future PR): remove these defaults and enforce activation functions
+# to explicitly specify their output range
+default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
+    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
+default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
+    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
+# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
+default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
+default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer
+
+"""
+Default observers for fixed qparams operations.
+"""
+
+default_reuse_input_observer = ReuseInputObserver
+"""
+Default observer for operators like reshape that reuse the observer of the input to
+the operator.
+"""
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d4c3a5c76441472b7241c9545d3eeeeb49269f3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32be676ab67b098d8236792f350199d87e06e3a8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e88759f060a10af9dde483824aac3064a07abbc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1becfa40540a4d5095cb797d710130bedabc0985
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/generate_numeric_debug_handle.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d32dc3b1523ddc9824e9b061e6b117d659f30cae
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48949f110c41e99577b5c2b0c5fbdc24101640e5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..400fbea40d61442e714fc2789fa85abc6a5d8e00
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b940ca39992969e4e4b1e897a9f095db9fbc207b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da8152f5c770ee7cf2cd22f610b3495bb73bc73a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..8882292db139d23d5cb4f752663e8fc6d522d07e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py
@@ -0,0 +1,83 @@
+import logging
+import operator
+
+import torch
+
+from torch.ao.quantization.pt2e.utils import (
+    _filter_sym_size_users,
+    _is_valid_annotation,
+)
+
+from torch.fx.node import map_arg
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+__all__ = ["DuplicateDQPass"]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+
+def _maybe_duplicate_dq(
+    gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
+):
+    annotation = user.meta.get("quantization_annotation", None)
+    if not _is_valid_annotation(annotation):
+        return
+    with gm.graph.inserting_after(dq_node):
+        new_node = gm.graph.node_copy(dq_node)
+
+        def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
+            if n == dq_node:
+                return new_node
+            else:
+                return n
+
+        new_args = map_arg(user.args, maybe_replace_node)
+        new_kwargs = map_arg(user.kwargs, maybe_replace_node)
+        user.args = new_args
+        user.kwargs = new_kwargs
+
+
+class DuplicateDQPass(PassBase):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
+                dq_users = _filter_sym_size_users(node)
+                if len(dq_users) <= 1:
+                    continue
+                # Do not duplicate dq for dynamic quantization
+                # Pattern: choose_qparam - getitem - q - dq
+                q_node = node.args[0]
+                if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
+                    getitem_node = q_node.args[1]
+                    if (
+                        isinstance(getitem_node, torch.fx.node.Node)
+                        and getitem_node.op == "call_function"
+                        and getitem_node.target == operator.getitem
+                    ):
+                        choose_qparam_node = getitem_node.args[0]
+                        if (
+                            isinstance(choose_qparam_node, torch.fx.node.Node)
+                            and choose_qparam_node.op == "call_function"
+                            and choose_qparam_node.target
+                            == torch.ops.quantized_decomposed.choose_qparams.tensor
+                        ):
+                            continue
+                for user in dq_users:
+                    _maybe_duplicate_dq(graph_module, node, user)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
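+
+
+# Hedged editorial sketch (not part of the upstream file): the pass is normally run
+# by the PT2E convert flow, but it can also be applied directly via its `call`
+# method, which returns a PassResult wrapping the (possibly modified) GraphModule.
+def _example_run_duplicate_dq_pass(quantized_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    result = DuplicateDQPass().call(quantized_gm)
+    return result.graph_module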
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe884508bcb9137b9772e6c11f120cd9549cd30
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/export_utils.py
@@ -0,0 +1,211 @@
+import types
+
+import torch
+import torch.nn.functional as F
+
+
+__all__ = [
+    "model_is_exported",
+    "_WrapperModule",
+]
+
+
+class _WrapperModule(torch.nn.Module):
+    """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
+    are trying to export a callable.
+    """
+
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, *args, **kwargs):
+        """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
+        return self.fn(*args, **kwargs)
+
+
+def model_is_exported(m: torch.nn.Module) -> bool:
+    """
+    Return True if the `torch.nn.Module` was exported, False otherwise
+    (e.g. if the model was FX symbolically traced or not traced at all).
+    """
+    return isinstance(m, torch.fx.GraphModule) and any(
+        "val" in n.meta for n in m.graph.nodes
+    )
+
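+
+# Hedged editorial sketch (not part of the upstream file): a plain (non-FX) eager
+# nn.Module is never considered exported, while a GraphModule coming out of the
+# PT2 export flow (whose nodes carry "val" metadata, as checked above) is.
+def _example_model_is_exported(eager_model: torch.nn.Module, exported_gm: torch.fx.GraphModule) -> None:
+    assert not model_is_exported(eager_model)
+    assert model_is_exported(exported_gm)
+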
+
+def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
+    """
+    Switch dropout patterns in the model between train and eval modes.
+
+    Dropout has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the dropout behavior between the two modes, so here we need to rewrite the aten
+    dropout patterns manually to achieve the same effect.
+
+    See https://github.com/pytorch/pytorch/issues/103681.
+    """
+    # Avoid circular dependencies
+    from .utils import get_aten_graph_module
+
+    # Needed to ensure subgraph matches are self-contained
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    for inplace in [False, True]:
+
+        def dropout_train(x):
+            return F.dropout(x, p=0.5, training=True, inplace=inplace)
+
+        def dropout_eval(x):
+            return F.dropout(x, p=0.5, training=False, inplace=inplace)
+
+        example_inputs = (torch.randn(1),)
+        if train_to_eval:
+            match_pattern = get_aten_graph_module(
+                _WrapperModule(dropout_train), example_inputs
+            )
+            replacement_pattern = get_aten_graph_module(
+                _WrapperModule(dropout_eval), example_inputs
+            )
+        else:
+            match_pattern = get_aten_graph_module(
+                _WrapperModule(dropout_eval), example_inputs
+            )
+            replacement_pattern = get_aten_graph_module(
+                _WrapperModule(dropout_train), example_inputs
+            )
+
+        from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+
+        replace_pattern_with_filters(
+            m,
+            match_pattern,
+            replacement_pattern,
+            match_filters=[],
+            ignore_literals=True,
+        )
+        m.recompile()
+
+
+def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
+    """
+    Switch batchnorm patterns in the model between train and eval modes.
+
+    Batchnorm has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the batchnorm behavior between the two modes, so here we need to rewrite the aten
+    batchnorm patterns manually to achieve the same effect.
+    """
+    # TODO(Leslie): This function still fails to support custom momentum and eps value.
+    # Enable this support in future updates.
+
+    # Avoid circular dependencies
+    from .utils import get_aten_graph_module
+
+    # Needed to ensure subgraph matches are self-contained
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    def bn_train(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
+        )
+
+    def bn_eval(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
+        )
+
+    example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+    if train_to_eval:
+        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
+        replacement_pattern = get_aten_graph_module(
+            _WrapperModule(bn_eval), example_inputs
+        )
+    else:
+        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
+        replacement_pattern = get_aten_graph_module(
+            _WrapperModule(bn_train), example_inputs
+        )
+
+    from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+
+    replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern,
+        match_filters=[],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+
+# TODO: expose these under this namespace?
+def _move_exported_model_to_eval(model: torch.fx.GraphModule):
+    """
+    Move an exported GraphModule to eval mode.
+
+    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
+    QAT users should call this before performing inference on the model.
+    """
+    _replace_dropout(model, train_to_eval=True)
+    _replace_batchnorm(model, train_to_eval=True)
+    return model
+
+
+def _move_exported_model_to_train(model: torch.fx.GraphModule):
+    """
+    Move an exported GraphModule to train mode.
+
+    This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
+    QAT users should call this before performing training on the model.
+    """
+    _replace_dropout(model, train_to_eval=False)
+    _replace_batchnorm(model, train_to_eval=False)
+    return model
+
+
+def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
+    """
+    Allow users to call `model.train()` and `model.eval()` on an exported model,
+    but with the effect of changing behavior between the two modes limited to special
+    ops only, which are currently dropout and batchnorm.
+
+    Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
+    do in eager models, but only provides an approximation. In particular, user code
+    branching on the `training` flag will generally not function correctly because the branch
+    is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
+    that have different train/eval behavior will also not be converted properly.
+    """
+
+    def _train(self, mode: bool = True):
+        if mode:
+            _move_exported_model_to_train(self)
+        else:
+            _move_exported_model_to_eval(self)
+
+    def _eval(self):
+        _move_exported_model_to_eval(self)
+
+    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
+    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
+    return model
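+
+
+# Hedged editorial sketch (not part of the upstream file): after patching, calling
+# train()/eval() on the exported module only toggles the rewritten dropout and
+# batchnorm patterns, as described in the docstring above.
+def _example_allow_train_eval(exported_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    m = _allow_exported_model_train_eval(exported_gm)
+    m.eval()   # switches dropout/batchnorm aten patterns to their eval form
+    m.train()  # switches them back to their train form
+    return m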
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dcb555b4756b72f99233f4ebf27dc517282ecc0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py
@@ -0,0 +1,17 @@
+from torch.fx import GraphModule, Node
+
+__all__ = ["generate_numeric_debug_handle"]
+
+
+def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
+    unique_id = 0
+    for node in graph_module.graph.nodes:
+        if node.op == "call_function":
+            node.meta["numeric_debug_handle"] = {}
+            for arg in node.args:
+                if isinstance(arg, Node):
+                    node.meta["numeric_debug_handle"][arg] = unique_id
+                    unique_id += 1
+
+            node.meta["numeric_debug_handle"]["output"] = unique_id
+            unique_id += 1
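+
+
+# Hedged editorial sketch (not part of the upstream file): attach debug handles to
+# a GraphModule (e.g. one produced by torch.fx.symbolic_trace) and collect them.
+def _example_collect_debug_handles(graph_module: GraphModule):
+    generate_numeric_debug_handle(graph_module)
+    return [
+        node.meta["numeric_debug_handle"]
+        for node in graph_module.graph.nodes
+        if node.op == "call_function"
+    ]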
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff0644a3785bfccdb58527d58b90ad0e0b66aa48
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/graph_utils.py
@@ -0,0 +1,109 @@
+import itertools
+from typing import Any, List, OrderedDict, Set, Optional, Callable
+import operator
+from torch.fx import Node
+
+import torch
+
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    get_source_partitions,
+    SourcePartition,
+)
+
+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+]
+
+_EQUIVALENT_TYPES: List[Set] = [
+    {torch.nn.Conv1d, torch.nn.functional.conv1d},
+    {torch.nn.Conv2d, torch.nn.functional.conv2d},
+    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
+    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
+    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
+    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
+    {torch.add, operator.add, operator.iadd, "add", "add_"},
+    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},
+]
+
+
+def _create_equivalent_types_dict():
+    _DICT = {}
+    for values in _EQUIVALENT_TYPES:
+        for v in values:
+            _DICT[v] = list(values)
+    return _DICT
+
+
+_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
+
+def get_equivalent_types() -> List[Set]:
+    return _EQUIVALENT_TYPES
+
+def update_equivalent_types_dict(customized_equivalent_types=None):
+    """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
+    When customized_equivalent_types passes in,
+    re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
+    """
+    if customized_equivalent_types is None:
+        raise ValueError("customized_equivalent_types should not be None")
+    global _EQUIVALENT_TYPES
+    global _EQUIVALENT_TYPES_DICT
+    _EQUIVALENT_TYPES = customized_equivalent_types
+    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
+
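+# Hedged editorial sketch (not part of the upstream file): extend the equivalence
+# table with a custom entry, here treating the module and functional forms of SiLU
+# as interchangeable when matching partitions.
+def _example_add_silu_equivalence():
+    new_types = get_equivalent_types() + [{torch.nn.SiLU, torch.nn.functional.silu}]
+    update_equivalent_types_dict(new_types)
+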
+def _partitions_sequential(partitions: List[SourcePartition]):
+    prev_partition = None
+    for partition in partitions:
+        if prev_partition is not None and not check_subgraphs_connected(
+            prev_partition, partition
+        ):
+            return False
+        prev_partition = partition
+    return True
+
+
+def _get_matching_types(partition_type):
+    matching_types = [partition_type]
+    if partition_type in _EQUIVALENT_TYPES_DICT:
+        matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
+    return matching_types
+
+
+def _valid_type_sequence(partition_types: List[Any]):
+    partition_types_set = set()  # type: ignore[var-annotated]
+    for partition_type in partition_types:
+        matching_types = _get_matching_types(partition_type)
+        matching_types_set = set(matching_types)
+        if len(partition_types_set & matching_types_set) > 0:
+            return False
+        partition_types_set |= matching_types_set
+    return True
+
+
+def find_sequential_partitions(
+    gm: torch.fx.GraphModule,
+    partition_types: List[Any],
+    include_functional_equivalent=True,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+):
+    if not _valid_type_sequence(partition_types):
+        raise ValueError(
+            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
+        )
+
+    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
+    for partition_type in partition_types:
+        types_to_match = _get_matching_types(partition_type)
+        partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
+        typed_partitions[partition_type] = list(itertools.chain.from_iterable(partitions.values()))
+
+    typed_partitions_list = list(typed_partitions.values())
+    fusion_candidates = itertools.product(*typed_partitions_list)
+    fused_partitions = []
+    for candidate in fusion_candidates:
+        if _partitions_sequential(candidate):  # type: ignore[arg-type]
+            fused_partitions.append(candidate)
+    return fused_partitions
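+
+
+# Hedged editorial sketch (not part of the upstream file): locate Conv2d -> ReLU
+# chains in an FX graph; the functional variants (F.conv2d, F.relu) also match
+# because of the equivalence table above.
+def _example_find_conv_relu_partitions(gm: torch.fx.GraphModule):
+    return find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])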
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a836ff60ab1f93961ced1415d357b6e1855d64
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py
@@ -0,0 +1,198 @@
+import logging
+from typing import Optional
+
+import torch
+from torch._export.error import InternalError
+
+from torch.ao.quantization.pt2e.utils import (
+    _filter_sym_size_users,
+    _find_q_dq_node_for_user,
+    _is_valid_annotation,
+)
+
+from torch.ao.quantization.quantizer import QuantizationSpecBase
+
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
+
+__all__ = ["PortNodeMetaForQDQ"]
+
+_METADATA_TO_PORT = [
+    "stack_trace",
+    "quantization_tag",
+]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+
+def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
+    from_meta = from_node.meta
+    for meta_name in _METADATA_TO_PORT:
+        if meta_name in from_meta:
+            to_node.meta[meta_name] = from_meta[meta_name]
+
+
+def _has_quant_annotation(node: torch.fx.Node) -> bool:
+    return "quantization_annotation" in node.meta
+
+
+def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
+    # BFS to look for choose qparams
+    from collections import deque
+
+    queue = deque(list(node.users.keys()))
+    while len(queue):
+        n = queue.popleft()
+        if n.op == "output":
+            continue
+        if (
+            n.op == "call_function"
+            and n.target == torch.ops.quantized_decomposed.choose_qparams.tensor
+        ):
+            return n
+        for k in n.users.keys():
+            queue.append(k)
+    return None
+
+
+def _port_metadata_for_input_quant_nodes(
+    input_node: torch.fx.Node,
+    node: torch.fx.Node,
+    qspec: Optional[QuantizationSpecBase],
+):
+    if qspec is None:
+        return
+
+    is_dynamic_quant = getattr(qspec, "is_dynamic", None)
+    if is_dynamic_quant is not None and is_dynamic_quant is True:
+        choose_qparams_node = _find_choose_qparams_node(input_node)
+        if choose_qparams_node is None:
+            raise ValueError(f"No chose qparams node found for {node}")
+        choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
+        if len(choose_qparam_users) != 2:
+            raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
+        scale_node = choose_qparam_users.pop()
+        dynamic_q_node = next(iter(scale_node.users.keys()))
+        dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
+        if len(dynamic_q_node_users) > 1:
+            raise InternalError(f"Expecting single user for {dynamic_q_node}")
+        dynamic_dq_node = dynamic_q_node_users.pop()
+        _add_metadata(choose_qparams_node, node)
+        _add_metadata(dynamic_q_node, node)
+        _add_metadata(dynamic_dq_node, node)
+    else:
+        q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
+        if q_node is None or dq_node is None:
+            return
+        # add metadata for all the nodes between q_node and the get_attr node
+        # if q_node can be traced back to a get_attr node
+        q_to_get_attr_nodes = [q_node]
+        q_node_input = q_node.args[0]
+        while isinstance(q_node_input, torch.fx.Node) and q_node_input.op not in [
+            "placeholder",
+            "get_attr",
+        ]:
+            q_to_get_attr_nodes.append(q_node_input)
+            q_node_input = q_node_input.args[0]
+        if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
+            for n in q_to_get_attr_nodes:
+                _add_metadata(n, q_node_input)
+        _add_metadata(dq_node, node)
+
+
+def _port_metadata_for_output_quant_nodes(
+    node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
+):
+    if qspec is None:
+        return
+
+    node_users = _filter_sym_size_users(node)
+    if len(node_users) != 1:
+        raise InternalError(f"Expecting {node} to have single user")
+    q_node = node_users.pop()
+    if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
+        logger.warning(
+            f"Expecting {node} user to be a quantized op but got {q_node}"  # noqa: G004
+        )  # noqa: G004
+        return
+
+    _add_metadata(q_node, node)
+
+
+class PortNodeMetaForQDQ(PassBase):
+    """
+    Port metadata for nodes added by quantization flow.
+    For static quant these are:
+    - quantize_per_tensor.default, dequantize_per_tensor.default
+    - quantize_per_channel.default, dequantize_per_channel.default
+    For dynamic quant these are:
+    - choose_qparams.tensor
+    - quantize_per_tensor.tensor, dequantize_per_tensor.tensor
+    - quantize_per_channel.default, dequantize_per_channel.default
+
+    Rules of porting metadata:
+    - Metadata to be ported:
+      - nn_module_stack
+      - stack_trace
+      - quantization_tag
+    - Metadata to NOT be ported:
+      - Everything else
+    - Rules:
+      - Statically quantized patterns:
+        - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
+        - Quantize nodes on the outputs inherit metadata of the producer node.
+        - Example 1:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from
+          - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
+          - Note first Q and last DQ do not inherit metadata from any nodes
+        - Example 2:
+          - Original: [Conv -> AvgPool -> Linear]
+          - AvgPool is not quantized
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from
+          - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
+          - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
+            AvgPool was not supposed to be quantized. Metadata porting relies on the quantization_annotation
+            on the nodes (in this case the AvgPool node) to conclude whether the node or pattern was
+            supposed to be quantized, and subsequently to decide whether the preceding Q, if any, should
+            inherit metadata from AvgPool.
+      - Dynamically quantized patterns:
+        - Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes
+        - For example, below linear is dynamically quantized while rest statically:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
+          - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
+          - Note first Q does not inherit metadata from any nodes
+    NB:
+    - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
+      knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
+      However, since the FX and PT2E quant workflows share a common code base, doing it there hurts
+      readability quite a bit. Doing it via a separate pass helps readability of the code. Once we are
+      able to refactor the PT2E quant code, this pass should be integrated into the refactored variant
+      of the "convert" step.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            annotation = node.meta.get("quantization_annotation", None)
+            if _is_valid_annotation(annotation):
+                input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+                output_qspec = node.meta["quantization_annotation"].output_qspec
+                for input_node, qspec in input_qspec_map.items():
+                    _port_metadata_for_input_quant_nodes(input_node, node, qspec)
+                _port_metadata_for_output_quant_nodes(node, output_qspec)
+        return PassResult(graph_module, True)
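+
+
+# Hedged editorial sketch (not part of the upstream file): the pass is normally run
+# as part of the PT2E convert flow; applying it directly via `call` returns a
+# PassResult whose graph_module carries the ported metadata.
+def _example_port_qdq_metadata(quantized_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    return PortNodeMetaForQDQ().call(quantized_gm).graph_module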
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb45684e3a162789c447641876866da5b214fa0e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/prepare.py
@@ -0,0 +1,489 @@
+import torch
+from torch._subclasses import FakeTensor
+from torch.ao.quantization.fx.prepare import (
+    _insert_obs_or_fq,
+    _save_state,
+    _is_activation_post_process_node,
+    _create_obs_or_fq_from_qspec,
+)
+from torch.fx import (
+    GraphModule,
+    Graph,
+    Node,
+)
+from torch.fx.node import Argument
+
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
+from typing import Dict, Tuple, Union, Any, Optional
+from torch.ao.quantization.quantizer import (
+    EdgeOrNode,
+    SharedQuantizationSpec,
+    QuantizationSpecBase,
+)
+from torch.ao.quantization import ObserverOrFakeQuantize
+
+# TODO: make pt2e folder private?
+__all__ = [
+    "prepare",
+]
+
+
+def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
+    """Find the root node for the sharing tree
+    Args:
+        edge_or_node: edge/node that we want to find the root
+        shared_with_map: each edge/node points to its parent; the root node points to itself
+
+    Returns:
+        root edge/node
+    """
+    parent = shared_with_map[edge_or_node]
+    if parent == edge_or_node:
+        return edge_or_node
+    root = _find_root_edge_or_node(parent, shared_with_map)
+    # path compression
+    shared_with_map[edge_or_node] = root
+    return root
+
+def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
+    """Merge the subtree for `child` with `parent`, the order is important here
+    """
+    root_parent = _find_root_edge_or_node(parent, shared_with_map)
+    root_child = _find_root_edge_or_node(child, shared_with_map)
+    # union the two trees by pointing the root of child to root of parent
+    shared_with_map[root_child] = root_parent
+
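+# Hedged editorial sketch (not part of the upstream file): the two helpers above form
+# a plain union-find; strings stand in for real EdgeOrNode values purely for illustration.
+def _example_union_find():
+    shared_with_map = {"a": "a", "b": "b"}
+    _union("a", "b", shared_with_map)           # point b's tree at a's root
+    assert _find_root_edge_or_node("b", shared_with_map) == "a"
+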
+def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
+    """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
+    configuration and established the relationship between `edge_or_node` with the edge/node that it
+    is pointing to, we'll use this information in the end to get the group id
+    """
+    if isinstance(qspec, SharedQuantizationSpec):
+        parent = qspec.edge_or_node
+        # we point from edge_or_node to the node that it is sharing_with, e.g.
+        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
+        _union(parent, child, shared_with_map)
+
+def _unwrap_shared_qspec(
+    qspec: QuantizationSpecBase,
+    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
+    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
+) -> QuantizationSpecBase:
+    """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
+    if qspec is SharedQuantizationSpec
+       (1). tries to find the root edge or node for the node that the qspec points to
+       (2). recursively find the root qspec based on the qspec for the root node
+    """
+    if isinstance(qspec, SharedQuantizationSpec):
+        sharing_with = qspec.edge_or_node
+        root = _find_root_edge_or_node(sharing_with, shared_with_map)
+        qspec = edge_or_node_to_qspec[root]
+        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    return qspec
+
+def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
+    return (
+        hasattr(qspec_a, "dtype") and
+        hasattr(qspec_b, "dtype") and
+        qspec_a.dtype == qspec_b.dtype
+    )
+
+def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
+    return (
+        hasattr(qspec_a, "is_dynamic") and
+        hasattr(qspec_b, "is_dynamic") and
+        qspec_a.is_dynamic == qspec_b.is_dynamic
+    )
+
+def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
+    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
+    """
+    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
+    for n in model.graph.nodes:
+        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
+            qa = n.meta["quantization_annotation"]
+            for input_to_n, qspec in qa.input_qspec_map.items():
+                input_edge = (input_to_n, n)
+                edge_or_node_to_qspec[input_edge] = qspec
+            if qa.output_qspec is not None:
+                output_node = n
+                qspec = qa.output_qspec
+                edge_or_node_to_qspec[output_node] = qspec
+    return edge_or_node_to_qspec
+
+def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
+    """Union input edge with another edge or node, used in implicit sharing to point the current input
+    edge to other user edges of the producer node, or the output of producer node since these are
+    referring to the same Tensor
+    """
+    root_qspec = None
+    if edge_or_node in edge_or_node_to_qspec:
+        qspec = edge_or_node_to_qspec[edge_or_node]
+        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    # TODO: add assertions for types of root qspecs
+    if (
+        root_qspec is not None and
+        _has_same_dtype(root_qspec, input_edge_root_qspec) and
+        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
+    ):
+        # the input arg to the node should reuse the existing output observer for arg
+        # since dtype is the same (we may want to extend this to be a more strict check
+        # in the future)
+        # so we point from `input_edge` to `arg` (output of the argument)
+        _union(edge_or_node, input_edge, shared_with_map)
+
+
+def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
+    """Map from edge/node to the group ID, generated from quantization annotations,
+    edge/node with the same group ID should use the same observer/fake_quant instance
+
+    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
+    There is another implicit sharing built into quantization, when we have the following:
+       * op1 -> op2
+       * output of op1: int8_qspec
+       * (op1 -> op2) input edge: int8_qspec
+    we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
+
+    Figuring out the correct group ID for all edge/node is a standard union find problem:
+    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
+
+    Args:
+        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
+    Returns:
+        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes that
+        belong to the same group should have the same id
+
+    Example:
+        op2 -> cat1 -> cat2
+           op1 /        /
+                     op3
+        edge_or_node_to_qspec: {
+            op1: int8_qspec,
+            op2: int8_qspec,
+            (op1, cat1): int8_qspec,
+            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
+            cat1: SharedQuantizationSpec((op1, cat1)),
+            (op3, cat2): int8_qspec,
+            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
+            cat2: SharedQuantizationSpec((op3, cat2)),
+        }
+
+        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+        edge_or_node_to_group_id: {
+            op1: 1,
+            op2: 1,
+            (op1, cat1): 1,
+            (op2, cat1): 1,
+            cat1: 1,
+            (op3, cat2): 1,
+            (cat1, cat2): 1,
+            cat2: 1,
+        }
+        # everything is in the same group because cat1 and the edge (cat1, cat2) are implicitly shared, which
+        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
+    """
+    # means the observer of the key should be shared with the observer of the value; by default it will
+    # be shared with itself
+    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
+    for edge_or_node, qspec in edge_or_node_to_qspec.items():
+        if isinstance(edge_or_node, torch.fx.Node):
+            output_node = edge_or_node
+            _update_shared_with(output_node, qspec, shared_with_map)
+        else:
+            input_edge = edge_or_node
+            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+
+            assert isinstance(input_edge, tuple)
+            arg, n = input_edge
+            if n.meta["quantization_annotation"].allow_implicit_sharing:
+                # NOTE: the order is important here, we first share with other users and then share with previous
+                # output because the reverse order could cause circular dependency
+                # e.g node1 -> node2
+                #          \ -> node3
+                # when processing (node1, node2), if we first point (node1, node2) to node1
+                # Step 1. shared_map = {(node1, node2): node1}
+                # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
+                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
+                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
+                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
+                # have a circular dependency
+                # the following order works around this issue, but this does not allow arbitrary configuration
+                # of sharing so it might break in a different case in the future, when it breaks
+                # quantizer writer can check the notes here to debug the issue
+
+                # sharing with other users of the producer node
+                # (arg, user)
+                if not isinstance(arg, Node) or not isinstance(n, Node):
+                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
+                for user in arg.users:
+                    if user is n:
+                        continue
+                    arg_to_user_edge = (arg, user)
+                    _union_input_edge_with(
+                        input_edge,
+                        input_edge_root_qspec,
+                        arg_to_user_edge,
+                        edge_or_node_to_qspec,
+                        shared_with_map
+                    )
+
+                # sharing with output of producer node
+                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)
+
+            _update_shared_with(input_edge, qspec, shared_with_map)
+
+    # now that we have the sharing relations between all edges and nodes, we can assign group ids
+    cur_group_id = 0
+    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
+    for edge_or_node in shared_with_map.keys():
+        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
+        if root not in edge_or_node_to_group_id:
+            edge_or_node_to_group_id[root] = cur_group_id
+            cur_group_id += 1
+        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
+
+    return edge_or_node_to_group_id
+
+def _get_obs_or_fq_map(
+    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
+    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
+    is_qat: bool
+) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
+    """Generates the EdgeOrNode to observer/fake_quant instances
+    Makes sure that EdgeOrNodes with the same group_id use the same observer or fake quant
+    instance
+    """
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
+    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
+    for edge_or_node, qspec in edge_or_node_to_qspec.items():
+        group_id = edge_or_node_to_group_id[edge_or_node]
+        if group_id not in group_id_to_obs_or_fq:
+            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
+            # the implementation for _create_obs_or_fq_from_qspec
+            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
+        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
+    return obs_or_fq_map
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
+            )
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    # find the original `arg` node feeding the current node, skipping inserted observer/fake_quant nodes
+    original_arg = arg
+    while _is_activation_post_process_node(original_arg, named_modules):
+        original_arg = original_arg.args[0]  # type: ignore[assignment]
+    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"
+
+    input_edge = (original_arg, node)
+    if input_edge not in obs_or_fq_map:
+        return new_arg
+    # input_edge needs to be observed
+    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
+    if input_edge_obs_or_fq is None:
+        return new_arg
+
+    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
+    # the arg is observed as the output and is using the same instance as the input_edge
+    # we'll reuse the inserted observer/fake_quant
+    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
+        return new_arg
+
+    # otherwise, we'll insert a new observer/fake_quant node
+
+    existing_obs_node = None
+    # skip inserting new observers if the same observer instance was already inserted for another user
+    # Example:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #             \ -> conv3
+    #
+    # instead of inserting new observers we will have:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #                            \ -> conv3
+    for maybe_obs_node in arg.users.keys():
+        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
+            continue
+        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
+            return maybe_obs_node
+
+    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
+    return new_arg
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+
+    """
+    # Look through every input arg.  If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    # map from old arg to new arg, used for updating the numeric debug handle map
+    remap = {}
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
+        )
+        new_args.append(new_arg)
+        remap[arg] = new_arg
+
+    if "numeric_debug_handle" in node.meta:
+
+        def remap_fn(x):
+            return remap.get(x, x)
+
+        numeric_debug_handle = node.meta["numeric_debug_handle"]
+        node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
+
+    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
+    # that persist in the exported graph. This is just a workaround for these.
+    assert (
+        node.target == torch.ops.aten.clone.default or
+        node.target == torch.ops.aten.zeros_like.default or
+        len(node.kwargs) == 0
+    ), " expecting kwargs for aten op IR to be empty"
+
+    # assign the new args to the node, inplace
+    node.args = tuple(new_args)
+
+def _maybe_insert_output_observer_for_node(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: Dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[Node]:
+    if node in obs_or_fq_map:
+        output_act_obs_or_fq = obs_or_fq_map[node]
+        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
+    return None
+
+def _maybe_insert_input_and_output_observers_for_node(
+    node: Node,
+    model: torch.fx.GraphModule,
+    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+):
+    this_node_quantization_annotation = node.meta["quantization_annotation"] if "quantization_annotation" in node.meta else None
+    if this_node_quantization_annotation is None:
+        return
+
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    _maybe_insert_input_observers_for_node(
+        node,
+        None,  # qconfig
+        model,
+        named_modules,
+        obs_or_fq_map,
+        is_qat,
+    )
+
+    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
+    if not output_is_a_tensor:
+        return
+
+    # this returns the new observer node if it was needed
+    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
+        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
+
+    if maybe_output_obs_node is None:
+        return
+    # Update users of original node to use the output observer
+    # instead. For example, change
+    #
+    #           next_node
+    #          /
+    #   cur_node -> obs
+    #
+    # to
+    #
+    #                 next_node
+    #                 /
+    #   cur_node -> obs
+    #
+    # We need to save orig users before updating uses because
+    # the list of users will change as we update uses
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        if user_node is maybe_output_obs_node:
+            continue
+        user_node.replace_input_with(node, maybe_output_obs_node)
+
+def prepare(
+    model: GraphModule,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    is_qat: bool,
+) -> GraphModule:
+    # Since we are mutating the graph as we go, we iterate over the original
+    # nodes before observer insertion, instead of model.graph.nodes.
+    nodes_before_observation = list(model.graph.nodes)
+
+    # At a high level, we construct a map from EdgeOrNode to an observer_or_fake_quant instance.
+    # All edges/nodes that belong to the same group will use the same instance, and when we
+    # insert observers we simply query this map to get the correct observer_or_fake_quant
+    # instance.
+    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
+    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)
+
+    for node in nodes_before_observation:
+        # TODO: simplify logic for inserting observers
+        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)
+
+    model = GraphModule(model, model.graph)
+
+    _save_state(
+        model,
+        {},  # node_name_to_qconfig
+        node_name_to_scope,
+        PrepareCustomConfig(),
+        {},  # equalization_node_name_to_qconfig
+        QConfigMapping(),
+        is_qat,
+        set()  # observed_node_names
+    )
+    return model
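+
+# --- Illustrative usage sketch (editor's note, not part of the upstream module) ---
+# A minimal sketch of how `prepare` might be driven, assuming `exported_model` is a
+# torch.fx.GraphModule that has already been annotated by a quantizer (i.e. nodes carry
+# "quantization_annotation" in their meta). The `node_name_to_scope` built below is only
+# a placeholder; in the pt2e flow it is typically derived from each node's
+# nn_module_stack metadata:
+#
+#     node_name_to_scope = {n.name: ("", type(None)) for n in exported_model.graph.nodes}
+#     observed_model = prepare(exported_model, node_name_to_scope, is_qat=False)
+#     # observers / fake-quantize modules are now inserted according to the annotations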
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0527d506d2bdb750640d822b4bc4f6f4e638a75
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/qat_utils.py
@@ -0,0 +1,788 @@
+import dataclasses
+import itertools
+import operator
+from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING
+
+import torch
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.subgraph_rewriter import (
+    replace_pattern_with_filters,
+    ReplacedPatterns,
+)
+import torch.nn.functional as F
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.ao.quantization.pt2e.export_utils import _WrapperModule
+from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
+    EdgeOrNode,
+    SharedQuantizationSpec,
+    QuantizationSpecBase,
+)
+from .utils import (
+    _conv1d_bn_example_inputs,
+    _conv2d_bn_example_inputs,
+    _is_conv,
+    _is_bn_node,
+    fold_bn_weights_into_conv_node,
+    get_aten_graph_module,
+)
+
+if TYPE_CHECKING:
+    from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+# Example inputs for quantized and folded conv-bn1d patterns used in convert
+_quantized_conv1d_bn_example_inputs = (
+    torch.randn(1, 1, 3),  # x
+    torch.randn(1, 1, 1),  # conv_weight
+    torch.randn(1),        # bn_weight
+    torch.randn(1),        # bn_bias
+    torch.randn(1),        # bn_running_mean
+    torch.randn(1),        # bn_running_var
+)
+
+# Example inputs for quantized and folded conv-bn2d patterns used in convert
+_quantized_conv2d_bn_example_inputs = (
+    torch.randn(1, 1, 3, 3),  # x
+    torch.randn(1, 1, 1, 1),  # conv_weight
+    torch.randn(1),           # bn_weight
+    torch.randn(1),           # bn_bias
+    torch.randn(1),           # bn_running_mean
+    torch.randn(1),           # bn_running_var
+)
+
+
+def _get_quantized_conv_bn_example_inputs_kwargs(
+    is_per_channel: bool,
+    has_bias: bool,
+    is_cuda: bool,
+) -> Dict[str, Any]:
+    """
+    Optional example inputs for quantized and folded conv-bn patterns
+    used in convert, expressed as kwargs.
+    """
+    kwargs = {}
+    # Per tensor quantization uses literals to represent scale and zero
+    # point, so there is no need to include them here as kwargs
+    if is_per_channel:
+        kwargs["scale"] = torch.tensor([1], dtype=torch.float)
+        kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
+    if has_bias:
+        kwargs["conv_bias"] = torch.randn(1)
+    if is_cuda:
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                kwargs[k] = v.cuda()
+    return kwargs
+
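+# For illustration (editor's note): with is_per_channel=True, has_bias=True and
+# is_cuda=False, the helper above would return roughly
+#     {"scale": tensor([1.]), "zero_point": tensor([0]), "conv_bias": tensor([...])}
+# while with is_per_channel=False it would only contain "conv_bias" (scale/zero_point
+# stay literal in the pattern).
+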
+def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
+    def _conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        x = conv_fn(x, conv_weight, conv_bias)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
+        return x
+    return _WrapperModule(_conv_bn_pattern)
+
+# TODO: merge this with the `no_conv_bias` case
+def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
+    def _qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Approximate method to fuse conv and bn. It requires only one forward pass.
+        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
+        This is based on `nniqat.ConvBn2d._forward_approximate`.
+        """
+        # TODO: allow setting eps
+        bn_eps = 1e-5
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
+        x = conv_fn(x, scaled_weight, zero_bias)
+        x = x / scale_factor.reshape(bias_shape)
+        x = x + conv_bias.reshape(bias_shape)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
+        return x
+    return _WrapperModule(_qat_conv_bn_pattern)
+
+def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
+    def _qat_conv_bn_pattern_no_conv_bias(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        # Not used, only for matching convenience
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
+        """
+        # TODO: allow setting eps
+        bn_eps = 1e-5
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        x = conv_fn(x, scaled_weight, None)
+        x = x / scale_factor.reshape(bias_shape)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
+        return x
+    return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)
+
+def _append_qdq(x, is_per_channel, kwargs):
+    """
+    Helper function to append q-dq ops after `x`, using dummy values for the qparams
+    and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
+    and will manually replace these values after subgraph rewriting.
+
+    Return the dq node.
+    """
+    # Dummy args to be passed into q-dq ops
+    per_channel_axis = 0
+    scale = kwargs["scale"] if is_per_channel else 1.0
+    zp = kwargs["zero_point"] if is_per_channel else 0
+    qmin = -127
+    qmax = 127
+    dtype = torch.int8
+
+    qd = torch.ops.quantized_decomposed
+    if is_per_channel:
+        x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
+        x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
+    else:
+        x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
+        x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
+    return x
+
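+# For illustration (editor's note): in the per-tensor case, `_append_qdq(x, False, {})`
+# traces to roughly
+#     x = quantized_decomposed.quantize_per_tensor(x, 1.0, 0, -127, 127, torch.int8)
+#     x = quantized_decomposed.dequantize_per_tensor(x, 1.0, 0, -127, 127, torch.int8)
+# with the dummy qparams later replaced after subgraph rewriting (see `_copy_over_q_dq_args`).
+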
+def _get_quantized_qat_conv_bn_pattern(
+    is_per_channel: bool,
+    has_bias: bool,
+    bias_is_quantized: bool,
+    conv_fn: Callable,
+    bn_is_training: bool,
+) -> Callable:
+    """
+    Return the quantized version of QAT conv + BN pattern.
+    This is based on `nniqat.ConvBn2d._forward_approximate`,
+    used in QAT convert. We first match this pattern and replace
+    it with the normal [conv - bn] pattern, then fold the BN
+    weights into conv.
+    """
+    # TODO: allow setting eps
+    bn_eps = 1e-5
+
+    def _quantized_qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        scaled_weight = _append_qdq(scaled_weight, is_per_channel, kwargs)
+        if has_bias:
+            zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
+            if bias_is_quantized:
+                zero_bias = _append_qdq(zero_bias, is_per_channel, kwargs)
+            x = conv_fn(x, scaled_weight, zero_bias)
+        else:
+            x = conv_fn(x, scaled_weight, None)
+        x = x / scale_factor.reshape(bias_shape)
+        if has_bias:
+            x = x + kwargs["conv_bias"].reshape(bias_shape)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps)
+        return x
+    return _WrapperModule(_quantized_qat_conv_bn_pattern)
+
+def _get_folded_quantized_qat_conv_bn_pattern(
+    is_per_channel: bool,
+    has_bias: bool,
+    bias_is_quantized: bool,
+    conv_fn: Callable,
+    bn_is_training: bool,
+) -> Callable:
+    """
+    Quantized QAT conv - bn pattern with bn weights being folded into conv.
+    """
+    # TODO: allow setting eps
+    bn_eps = 1e-5
+
+    def _folded_quantized_qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        conv_weight = _append_qdq(conv_weight, is_per_channel, kwargs)
+        if has_bias:
+            bias = kwargs["conv_bias"]
+            if bias_is_quantized:
+                bias = _append_qdq(bias, is_per_channel, kwargs)
+        else:
+            bias = None
+        x = conv_fn(x, conv_weight, bias)
+        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps)
+        return x
+    return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)
+
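+# Editor's note: the actual folding into the conv weights is done later by
+# `fold_bn_weights_into_conv_node` (imported from .utils). The standard per-output-channel
+# conv-bn folding is:
+#     w_folded = w * (bn_weight / sqrt(bn_running_var + eps))
+#     b_folded = (b - bn_running_mean) * (bn_weight / sqrt(bn_running_var + eps)) + bn_bias
+# (with b = 0 when the conv has no bias).
+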
+def _has_conv_bias_filter(
+    match: "InternalMatch",
+    original_graph: Graph,
+    pattern_graph: Graph,
+) -> bool:
+    """
+    Match filter for the subgraph rewriter that returns True if the conv node in
+    the original graph has bias.
+    """
+    for n in match.nodes_map.values():
+        if _is_conv(n):
+            return len(n.args) > 2 and n.args[2] is not None
+    raise ValueError("Could not find conv node in matched conv + bn pattern")
+
+def _no_conv_bias_filter(
+    match: "InternalMatch",
+    original_graph: Graph,
+    pattern_graph: Graph,
+) -> bool:
+    """
+    Match filter for the subgraph rewriter that returns True if the conv node in
+    the original graph does NOT have bias.
+    """
+    return not _has_conv_bias_filter(match, original_graph, pattern_graph)
+
+def _is_quantize(n: Node) -> bool:
+    return n.target in [
+        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+    ]
+
+def _is_dequantize(n: Node) -> bool:
+    return n.target in [
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    ]
+
+def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]:
+    """
+    Helper function to extract the nodes in the conv-bn fusion pattern after
+    subgraph rewriting, in the form of a map:
+
+        {name: (original_node, replacement_node)}
+
+    The following names must exist in the map:
+
+        "conv", "conv_weight", "conv_input", "bn", "getitem"
+
+    The following names may exist in the map:
+
+        "conv_weight_q", "conv_weight_dq", "conv_bias",
+        "conv_bias_q", "conv_bias_dq"
+    """
+    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
+        """
+        Return a 3-tuple of (conv_node, bn_node, getitem_node).
+        This asserts that the match contains exactly one of each node.
+        """
+        conv_node, bn_node, getitem_node = None, None, None
+        for n in nodes:
+            if n.op != "call_function":
+                continue
+            if _is_conv(n):
+                assert conv_node is None
+                conv_node = n
+            if _is_bn_node(n):
+                assert bn_node is None
+                bn_node = n
+            if n.target == operator.getitem:
+                assert getitem_node is None
+                getitem_node = n
+        assert conv_node is not None
+        assert bn_node is not None
+        assert getitem_node is not None
+        return (conv_node, bn_node, getitem_node)
+
+    def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]:
+        """
+        Return a 3-tuple of (orig_node, q_node, dq_node).
+        """
+        assert _is_dequantize(n)
+        q_node = n.args[0]
+        assert isinstance(q_node, Node)
+        assert _is_quantize(q_node)
+        orig_node = q_node.args[0]
+        assert isinstance(orig_node, Node)
+        return (orig_node, q_node, n)
+
+    original_nodes = list(_filter_nodes_map(r.nodes_map).values())
+    o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
+    r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
+
+    # Create the mapping from original node to replacement node
+    mapping = {
+        "conv": (o_conv, r_conv),
+        "bn": (o_bn, r_bn),
+        "getitem": (o_getitem, r_getitem),
+    }
+
+    # Extract conv input and weight
+    # Note: here we extract the original nodes indirectly through the pattern nodes
+    # because the args of the original nodes are no longer available after replacement
+    (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
+    (p_conv_input, p_conv_weight, *_) = p_conv.args
+    (r_conv_input, r_conv_weight, *_) = r_conv.args
+    assert isinstance(p_conv_input, Node)
+    assert isinstance(p_conv_weight, Node)
+    assert isinstance(r_conv_input, Node)
+    assert isinstance(r_conv_weight, Node)
+    o_conv_input = r.nodes_map[p_conv_input]
+    o_conv_weight = r.nodes_map[p_conv_weight]
+
+    # If conv weight is quantized, extract the q - dq nodes
+    if _is_dequantize(p_conv_weight):
+        p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(p_conv_weight)
+        r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(r_conv_weight)
+        o_conv_weight = r.nodes_map[p_conv_weight]
+        o_conv_weight_q = r.nodes_map[p_conv_weight_q]
+        o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
+        mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
+        mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
+    mapping["conv_input"] = (o_conv_input, r_conv_input)
+    mapping["conv_weight"] = (o_conv_weight, r_conv_weight)
+
+    # Extract conv bias
+    if len(p_conv.args) > 2 and len(r_conv.args) > 2:
+        p_conv_bias = p_conv.args[2]
+        r_conv_bias = r_conv.args[2]
+        assert isinstance(p_conv_bias, Node)
+        assert isinstance(r_conv_bias, Node)
+        o_conv_bias = r.nodes_map[p_conv_bias]
+
+        # If conv bias is quantized, extract the q - dq nodes
+        if _is_dequantize(p_conv_bias):
+            p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
+            r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
+            o_conv_bias = r.nodes_map[p_conv_bias]
+            o_conv_bias_q = r.nodes_map[p_conv_bias_q]
+            o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
+            mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
+            mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
+        mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
+    return mapping
+
+def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]:
+    """
+    Return a filtered `nodes_map` returned from the subgraph rewriter.
+    The filtered `nodes_map` will contain only nodes that are actually
+    matched in the pattern, excluding None or placeholder nodes.
+    """
+    new_nodes_map: Dict[Node, Node] = {}
+    for pattern_node, graph_node in nodes_map.items():
+        # bias can be None
+        if graph_node is None:
+            continue
+        # skip pattern placeholder nodes
+        if pattern_node.op == "placeholder":
+            continue
+        new_nodes_map[pattern_node] = graph_node
+    return new_nodes_map
+
+# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
+def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
+    """
+    Copy over literal args in conv, such as stride and padding, from the matched node
+    in the original graph to its replacement in the new graph.
+
+    This is needed due to the following limitation in the subgraph rewriter when used
+    with dynamo export: literal (non-tensor) args are not supported in the match and
+    replacement patterns. This is because dynamo export automatically inlines these
+    literal args, making them dead placeholder nodes. In the future, we should check
+    if dynamo export can optionally disable this inlining, or if subgraph rewriter
+    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
+
+    Note: Unlike other tensor args like conv weights and biases, literal args are
+    preserved in the original nodes after replacement, so we can access them here.
+    """
+    assert _is_conv(original_node)
+    assert _is_conv(new_node)
+    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
+    new_args = list(new_node.args)
+    if len(new_args) < 3:
+        # bias is optional; when it is not present, it is None
+        new_args.append(None)
+    new_node.args = tuple(new_args[:3]) + original_node.args[3:]
+
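+# For illustration (editor's note): if the original conv node had args
+#     (x, w, b, [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
+# and the replacement conv node only has (x', w', b'), the helper above rewrites the
+# replacement's args to
+#     (x', w', b', [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
+# i.e. the literal stride/padding/dilation/... values are copied over from the original node.
+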
+def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacement_node: Node):
+    """
+    Update the `input_qspec_map` in the annotation after subgraph rewriting.
+
+    The original annotation referred to the nodes in the original graph,
+    so the keys in the `input_qspec_map` will need to be updated to reflect
+    the corresponding nodes in the replacement graph.
+    """
+    assert _is_conv(original_node)
+    assert _is_conv(replacement_node)
+    if "quantization_annotation" not in original_node.meta:
+        return
+    original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
+    input_qspec_map = {}
+    # get the list of configs; it should be ordered as input, weight, bias
+    # note: this is really hacky; we need a better solution, hopefully in
+    # subgraph_rewriter. Issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
+    all_configs = list(original_input_qspec_map.items())
+    # input activation
+    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
+    # weight
+    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
+    # bias
+    if len(replacement_node.args) > 2 and len(all_configs) > 2:
+        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
+    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
+
+def _update_special_qspecs_after_replacement(
+    node: Node,
+    original_to_replacement_node: Dict[Node, Node],
+):
+    """
+    Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
+    used in `node`'s quantization annotation after subgraph rewriting.
+
+    The original annotation referred to the nodes in the original graph,
+    so the nodes used in these special quantization specs will need to
+    be updated to the corresponding nodes in the replacement graph.
+    """
+    def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
+        if isinstance(edge_or_node, Node):
+            _node = edge_or_node
+            return original_to_replacement_node.get(_node, _node)
+        elif isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node):
+            src, dest = edge_or_node
+            return (
+                original_to_replacement_node.get(src, src),
+                original_to_replacement_node.get(dest, dest),
+            )
+        else:
+            raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node))
+
+    def _get_new_qspec(qspec: QuantizationSpecBase):
+        if isinstance(qspec, SharedQuantizationSpec):
+            new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
+            return SharedQuantizationSpec(new_edge_or_node)
+        elif isinstance(qspec, DerivedQuantizationSpec):
+            new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
+            return dataclasses.replace(qspec, derived_from=new_derived_from)
+        else:
+            return qspec
+
+    if "quantization_annotation" not in node.meta:
+        return
+    annotation = node.meta["quantization_annotation"]
+    for input_node, qspec in annotation.input_qspec_map.items():
+        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
+    annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
+
+def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return m
+    m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=False)
+    m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=False)
+    if torch.cuda.is_available():
+        m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=True)
+        m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=True)
+    return m
+
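+# Editor's note: a minimal sketch of the intended round trip (names taken from this file,
+# the surrounding call sites are assumed):
+#     m = _fuse_conv_bn_qat(m)    # before observer insertion: conv + bn -> fused QAT subgraph
+#     ...                         # insert observers / fake-quants, train or calibrate
+#     m = _fold_conv_bn_qat(m)    # during convert: fold bn weights back into conv
+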
+def _fuse_conv_bn_qat_helper(
+    m: GraphModule,
+    conv_fn: Callable,
+    example_inputs: Tuple[Any, ...],
+    is_cuda: bool,
+) -> GraphModule:
+    """
+    Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
+    the fused QAT subgraph equivalent. The input graph should already be annotated.
+    The annotations in the original nodes will be preserved in the corresponding
+    nodes in the new subgraph.
+
+    Note: This also handles the (conv + bn + relu) pattern.
+    """
+    m.graph.eliminate_dead_code()
+    m.recompile()
+    conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
+    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)
+
+    # Step (1): Replace patterns with conv bias
+    #
+    # Here we do replacement separately for cases with and without conv bias, since
+    # the replacement patterns for these two cases are substantially different.
+    # TODO: use the public replace_pattern API once it also returns replacement nodes
+
+    qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
+    replacement_pattern_with_conv_bias = get_aten_graph_module(
+        qat_conv_bn_pattern,
+        example_inputs,
+        is_cuda,
+    )
+    replacements_with_conv_bias = replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern_with_conv_bias,
+        match_filters=[_has_conv_bias_filter],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+    # Step (2): Replace patterns without conv bias
+
+    qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
+    replacement_pattern_no_conv_bias = get_aten_graph_module(
+        qat_conv_bn_pattern_no_conv_bias,
+        example_inputs,
+        is_cuda,
+    )
+    replacements_no_conv_bias = replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern_no_conv_bias,
+        match_filters=[_no_conv_bias_filter],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+    # Step (3): Post processing
+    #
+    # Due to limited functionality in the subgraph rewriter, here we manually
+    # update the replacement graph as follows:
+    #
+    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
+    #       and annotations are preserved in the new subgraph
+    #
+    #   (b) Copy over literal args for conv from the original subgraph
+    #       TODO: do this for literal args for batchnorm as well
+    #
+    #   (c) Update all references of the old nodes in the original subgraph to refer
+    #       to the corresponding nodes in the new subgraph in the annotations
+    #
+    # In the future, we should try to push as much of this functionality into the
+    # subgraph rewriter as possible, so we don't have to manually copy anything over.
+    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.
+
+    all_original_to_replacement_nodes = {}
+    for r in replacements_with_conv_bias + replacements_no_conv_bias:
+        for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
+            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
+            replacement_node.meta = original_node.meta
+            if _is_conv(original_node):
+                # Step (3b): Copy over conv literal args
+                _copy_over_literal_conv_args(original_node, replacement_node)
+                # Step (3c): Update old references in the conv node's input_qspec_map
+                _update_conv_input_qspec_map_after_replacement(original_node, replacement_node)
+            all_original_to_replacement_nodes[original_node] = replacement_node
+
+    # Step (3c): Update old references in the special qspecs for all nodes in the graph
+    for n in m.graph.nodes:
+        _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)
+
+    return m
+
+def _duplicate_dequantize_node(m: GraphModule):
+    """
+    Helper function to duplicate all dequantize nodes in the graph if the
+    node has more than one user. For example:
+
+    Before:
+      quantize -> dequantize -> a
+                          \\--> b
+                          \\--> c
+
+    After:
+      quantize -> dequantize_1 -> a
+            \\--> dequantize_2 -> b
+            \\--> dequantize_3 -> c
+
+    This is useful for subgraph rewriting. E.g. if we wish to match the
+    pattern [dequantize - a] above, subgraph matching would fail because
+    the dequantize node has users outside the matched portion of the graph.
+    Instead, we match [dequantize_1 - a], which is safe.
+    """
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    for n in m.graph.nodes:
+        if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
+            continue
+        for user in list(n.users):
+            with m.graph.inserting_before(n):
+                new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
+            user.replace_input_with(n, new_node)
+        m.graph.erase_node(n)
+    m.recompile()
+
+def _remove_extra_dequantize(m: GraphModule):
+    """
+    Remove duplicate dequantize nodes in the graph: for an operator that has
+    multiple dequantize nodes as users, replace them with a single dequantize
+    node that can be shared across all the uses. This should be seen as the
+    "reverse" of `_duplicate_dequantize_node`.
+    """
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    for n in m.graph.nodes:
+        dq_users = [user for user in n.users if user.op == "call_function" and user.target == dq_op]
+        if len(dq_users) > 1:
+            with m.graph.inserting_after(dq_users[0]):
+                new_node = m.graph.create_node("call_function", dq_op, dq_users[0].args, {})
+            for dq_user in dq_users:
+                dq_user.replace_all_uses_with(new_node)
+                m.graph.erase_node(dq_user)
+    m.recompile()
+
+def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
+    """
+    Given a pair of quantize or dequantize nodes, copy over all literal args
+    from the original node to the replacement node.
+    """
+    # For quantize_per_tensor, scale and zp are literals and need to be copied
+    # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
+    assert original_node.target == replacement_node.target
+    if original_node.target in (
+        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    ):
+        # Args: input, [scale, zp, qmin, qmax, dtype]
+        start_copy_arg_index = 1
+    elif original_node.target in (
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    ):
+        # Args: input, scale, zp, [axis, qmin, qmax, dtype]
+        start_copy_arg_index = 3
+    else:
+        raise ValueError("Expected quantize/dequantize nodes, got '%s'" % original_node.target)
+    replacement_node.args = (
+        replacement_node.args[:start_copy_arg_index] + original_node.args[start_copy_arg_index:]
+    )
+
+def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return m
+    m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=False)
+    m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=False)
+    if torch.cuda.is_available():
+        m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=True)
+        m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=True)
+    return m
+
+def _fold_conv_bn_qat_helper(
+    m: GraphModule,
+    conv_fn: Callable,
+    example_inputs: Tuple[Any, ...],
+    is_cuda: bool,
+) -> GraphModule:
+    """
+    Replace the quantized (conv + bn) pattern with a conv whose weights have the bn weights folded in.
+    """
+    m.graph.eliminate_dead_code()
+    m.recompile()
+    _duplicate_dequantize_node(m)
+
+    # Step (1): Replace QAT pattern with simple [conv - bn] pattern
+    replacements = []
+    replacement_options = itertools.product(
+        [True, False],  # is_per_channel
+        [True, False],  # has_bias
+        [True, False],  # bias_is_quantized
+        [True, False],  # bn_is_training
+    )
+    for is_per_channel, has_bias, bias_is_quantized, bn_is_training in replacement_options:
+        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
+        # filter out one of the values for this flag to avoid having duplicate patterns
+        if not has_bias and bias_is_quantized:
+            continue
+        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
+        match_pattern = _get_quantized_qat_conv_bn_pattern(
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
+        )
+        match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs)
+        replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
+        )
+        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs)
+        replacements.extend(
+            replace_pattern_with_filters(
+                m,
+                match_pattern,
+                replacement_pattern,
+                ignore_literals=True,
+            )
+        )
+    m.recompile()
+    _remove_extra_dequantize(m)
+
+    for r in replacements:
+        node_map = _get_conv_bn_pattern_nodes(r)
+
+        # Step (2): Copy over metadata from original subgraph
+        for original_node, replacement_node in node_map.values():
+            replacement_node.meta = original_node.meta
+
+        # Step (3): Copy over args for weight (and optionally bias) q - dq nodes
+        _copy_over_q_dq_args(*node_map["conv_weight_q"])
+        _copy_over_q_dq_args(*node_map["conv_weight_dq"])
+        if "conv_bias_q" in node_map:
+            assert "conv_bias_dq" in node_map
+            _copy_over_q_dq_args(*node_map["conv_bias_q"])
+            _copy_over_q_dq_args(*node_map["conv_bias_dq"])
+
+        # Step (4): Fold BN weights into conv
+        conv_bias = None
+        (_, conv_node) = node_map["conv"]
+        (_, bn_node) = node_map["bn"]
+        (_, conv_weight) = node_map["conv_weight"]
+        if "conv_bias" in node_map:
+            (_, conv_bias) = node_map["conv_bias"]
+        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+
+        # Copy over literal args for conv
+        for original_node in _filter_nodes_map(r.nodes_map).values():
+            if _is_conv(original_node):
+                _copy_over_literal_conv_args(original_node, conv_node)
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+    return m
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..510943b9feb5a6b86af0d6ad7048425f8c018a44
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__init__.py
@@ -0,0 +1,5 @@
+from .rewrite import reference_representation_rewrite
+
+__all__ = [
+    "reference_representation_rewrite",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79200fb8d840a18062a7b1dfc1ef51ab4e7f57da
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce754c009110a22961ef7f8b829d2593ac172f9d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py
new file mode 100644
index 0000000000000000000000000000000000000000..68357be3577e057ce88af1e9d07541a96eec16e1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -0,0 +1,600 @@
+import torch
+from torch.fx import GraphModule
+from ..export_utils import _WrapperModule
+from ..utils import (
+    get_aten_graph_module,
+    remove_tensor_overload_for_qdq_ops,
+    _replace_literals_with_new_placeholders,
+    _replace_literals_with_existing_placeholders,
+)
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.fx.subgraph_rewriter import replace_pattern
+from torch._higher_order_ops.out_dtype import out_dtype
+from typing import Optional, Callable, Tuple, Any
+from dataclasses import dataclass
+
+from functools import partial
+
+__all__ = [
+    "reference_representation_rewrite",
+]
+
+
+_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
+    # This results in failure to match the pattern.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale or not
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None)
+    # TODO: change to mul.Scalar
+    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
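+# Editor's note: the reference integer math above can be summarized (as a sketch) as
+#     acc_i32 = (x_i16 - x_zp) @ (w_i16 - w_zp).T                     # linear, accumulated in int32
+#     acc_i32 = acc_i32 + int32(bias_fp32 / (x_scale * w_scale))      # bias quantized to int32
+#     out_i8  = clamp(int32(acc_i32 * x_scale * w_scale / out_scale) + out_zp,
+#                     out_quant_min, out_quant_max).to(int8)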
+
+_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+    torch.randn((2, 5), dtype=torch.float),
+    -128,
+    127,
+    torch.finfo(torch.float32).eps,
+    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+)
+
+
+def _qdq_dynamic_quantized_linear(
+    x_fp32, x_quant_min, x_quant_max, x_eps,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
+    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+def _reference_dynamic_quantized_linear(
+    x_fp32, x_quant_min, x_quant_max, x_eps,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
+    # decomposed representation for quantize_per_tensor
+    # TODO: use out_dtype(mul, ...) here when the op is ready
+    x_fp32 = x_fp32 / x_scale  # fp32
+    # rounding modes might differ here
+    # pytorch rounds half to even, which is also common for most backends
+    x_fp32 = torch.round(x_fp32)  # fp32
+    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
+    x_i32 = x_i32 + x_zero_point  # int32
+    # clamp works for fp32, int32 and int8 dtypes
+    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
+    x_i8 = x_i32.to(dtype=torch.int8)
+
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale or not
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None)
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
+    out_fp32 = acc_i32 * (x_scale * weight_scale)
+    return out_fp32
+
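+# Editor's note: the "dynamic" variant above differs from the static one in that the input
+# qparams are computed at runtime via choose_qparams, and the output stays fp32 (a sketch):
+#     out_fp32 = (x_scale * w_scale) * ((x_i16 - x_zp) @ (w_i16 - w_zp).T
+#                                       + int32(bias_fp32 / (x_scale * w_scale)))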
+
+_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_conv2d(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    stride = [1, 1]
+    padding = [0, 0]
+    dilation = [1, 1]
+    transposed = False
+    output_padding = [0, 0]
+    groups = 1
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.convolution.default(
+        x_fp32, weight_fp32, bias_fp32, stride, padding, dilation, transposed, output_padding, groups)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_conv2d(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    stride = [1, 1]
+    padding = [0, 0]
+    dilation = [1, 1]
+    transposed = False
+    output_padding = [0, 0]
+    groups = 1
+    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
+    # This results in failure to match the pattern.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale or not
+    acc_i32 = out_dtype(
+        torch.ops.aten.convolution.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None, stride, padding, dilation, transposed, output_padding, groups)
+    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to:
+    # Take the linear calculation for example
+    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(j, k)_fp32] + bias_(j)_fp32
+    # Represent X, W fp32 via their dequant transforms
+    # A_fp32 = (A_q - A_zero_point) * A_scale
+    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(j, k)_q - W_zp) * W_scale] + bias_(j)_fp32
+    # Factor out X_scale and W_scale
+    # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)]) + bias_(j)_fp32
+    # In order to move the addition of bias_(j)_fp32 inside the scaled sum, we must write
+    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(j)_fp32)  # noqa: B950
+    # Note we had to scale bias_fp32 by 1 / (X_scale * W_scale) = 1 / bias_scale
+    # Thus bias quantization to int32 must use the scale X_scale * W_scale
+
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    # Unsqueeze to match broadcast dims
+    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to literal-matching issues
+    # in graph pattern replacement
+    bias_i32 = bias_i32.unsqueeze(-1)
+    bias_i32 = bias_i32.unsqueeze(-1)
+    acc_i32 = acc_i32 + bias_i32
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = out_dtype(
+        torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
+
+_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_add_relu(
+    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
+    out_scale, out_zero_point, quant_min, quant_max
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
+    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
+    out_fp32 = x_fp32 + y_fp32
+    out_fp32 = torch.ops.aten.relu(out_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+def _reference_quantized_add_relu(
+    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
+    out_scale, out_zero_point, quant_min, quant_max
+):
+    """
+    See comments for `_reference_quantized_add` for more information on
+    how to derive the formula for out_i8 based on x_i8 and y_i8
+    """
+    x_i32 = x_i8.to(torch.int32)
+    y_i32 = y_i8.to(torch.int32)
+    # TODO: change this to mul.Scalar?
+    x_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (x_i32 - x_zero_point), (x_scale / out_scale))
+    y_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (y_i32 - y_zero_point), (y_scale / out_scale))
+    out_i32 = x_i32 + y_i32 + out_zero_point
+    # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
+    out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
+    return out_i8
+
+def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point, quant_min, quant_max):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8)
+    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8)
+    out_fp32 = x_fp32 + y_fp32
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+def _reference_quantized_add(
+    x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point,
+    out_scale, out_zero_point, quant_min, quant_max
+):
+    """
+    # How to derive the formula for out_i8 based on x_i8 and y_i8
+    # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produces an out_i8)
+
+    # out_i8 is the quantized output, so we can write down the formula for it first:
+    out_i8 = out_fp32 / out_scale + out_zero_point           (1)
+
+    # out_fp32 is computed from x_fp32 + y_fp32, where x_fp32 and y_fp32 are the dequantized x_i8 and y_i8:
+    out_fp32 = x_fp32 + y_fp32           (2)
+    x_fp32 = (x_i8 - x_zero_point) * x_scale         (3)
+    y_fp32 = (y_i8 - y_zero_point) * y_scale         (4)
+
+    # applying the above formulas to the out_i8 equation we get:
+    out_i8 = out_fp32 / out_scale + out_zero_point                                                             # (1)
+           = (x_fp32 + y_fp32) / out_scale + out_zero_point                                                    # substitute (2)
+           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
+    """
+    x_i32 = x_i8.to(torch.int32)
+    y_i32 = y_i8.to(torch.int32)
+    # TODO: use out_dtype op
+    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
+    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
+    out_i32 = x_i32 + y_i32 + out_zero_point
+    quant_min = -128
+    quant_max = 127
+    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
+    return out_i8
+
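+# Worked example (editor's note): with x_i8=10, x_scale=0.5, x_zp=0, y_i8=20, y_scale=0.25,
+# y_zp=0, out_scale=0.5 and out_zp=0:
+#     x term: round((0.5 / 0.5) * 10) = 10
+#     y term: round((0.25 / 0.5) * 20) = 10
+#     out_i8 = clamp(10 + 10 + 0, -128, 127) = 20
+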
+_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_max_pool2d(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
+    kernel_size = 1
+    stride = 1
+    padding = 0
+    dilation = 1
+    ceil_mode = False
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(x_fp32, kernel_size, stride, padding, dilation, ceil_mode)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_max_pool2d(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max):
+    kernel_size = 1
+    stride = 1
+    padding = 0
+    dilation = 1
+    ceil_mode = False
+    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
+    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
+    x_i32 = x_i8.to(torch.int32)
+    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
+        x_i32 - x_zero_point,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        ceil_mode
+    )
+    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
+    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
+    out_i8 = out_fp32.to(torch.int8)
+    return out_i8
+
+_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+    torch.randn(1, 3, 3, 3, dtype=torch.float),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
+    x = torch.ops.quantized_decomposed.quantize_per_tensor(x_fp32, scale, zero_point, quant_min, quant_max, torch.int8)
+    return x
+
+def _reference_quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
+    # TODO: use out_dtype(mul, ...) here when the op is ready
+    x = x_fp32 / scale  # fp32
+    # rounding modes might differ between backends here;
+    # PyTorch rounds half to even, which is also what most backends do
+    x = torch.round(x)  # fp32
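+    # e.g. torch.round(torch.tensor([0.5, 1.5, 2.5])) gives tensor([0., 2., 2.]) (half-to-even),
+    # whereas a backend rounding half away from zero would give 1., 2., 3.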
+    x = x.to(dtype=torch.int32)  # int32
+    x = x + zero_point  # int32
+    # clamp works for fp32, int32 and int8 dtypes
+    x = torch.clamp(x, quant_min, quant_max)  # int32
+    x = x.to(dtype=torch.int8)
+    return x
+
+_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max, torch.int8)
+    return x_fp32
+
+def _reference_dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
+    # without using quant_min/quant_max in the clamp, the traced graph will not have the quant_min/quant_max args.
+    # This results in a failure to match the pattern.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
+    # TODO: use out_dtype op
+    # note: x_i8.to(torch.int32) does not work here
+    # TODO: debug the implementation later when the torchdynamo timeout issue is resolved
+    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+
+_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+    torch.randn(1, 3, 3, 3, dtype=torch.float),
+    torch.randn(3, dtype=torch.float),
+    torch.zeros(3, dtype=torch.int),
+    1,
+    -128,
+    127,
+)
+
+def _quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
+        x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+def _reference_quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max):
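+    # transposing ch_axis to the last dimension lets the 1-D per-channel scales/zero_points
+    # broadcast against x_fp32 without any explicit reshaping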
+    x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
+    out_i32 = torch.ops.aten.clamp(torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max)
+    out_i32 = torch.transpose(out_i32, ch_axis, -1)
+    return out_i32.to(torch.int8)
+
+_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+    torch.randn(3, dtype=torch.float),
+    torch.zeros(3, dtype=torch.int),
+    1,
+    -128,
+    127,
+)
+
+def _dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
+    # the ch_axis/quant_min/quant_max literals below will later be replaced with placeholders
+    out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
+        x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
+    )
+    return out_fp32
+
+def _reference_dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max):
+    # the ch_axis/quant_min/quant_max literals below will later be replaced with placeholders;
+    # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops)
+    # we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
+    x_i8 = torch.transpose(x_i8, ch_axis, -1)
+    x_i32 = x_i8.to(torch.int32)
+    out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
+    out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
+    return out_fp32
+
+def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
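+    # the indices follow the placeholder order of the per-channel example inputs above:
+    # ch_axis literal 1 -> placeholder 3, quant_min -128 -> 4, quant_max 127 -> 5
+    # (the scale/zero_point tensors are already placeholders 1 and 2)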
+    return _replace_literals_with_existing_placeholders(
+        gm,
+        exclude_literals=[-1],
+        literal_to_ph_idx={1: 3, -128: 4, 127: 5}
+    )
+
+
+@dataclass
+class _RewriteInfo:
+    """Data needed for rewrite, this includes example inputs, pattern and replacement functions
+    and post transformation functions for the exported pattern and replacement GraphModule
+    """
+
+    # example inputs used for exporting the pattern into GraphModule
+    example_inputs: Tuple[Any, ...]
+    pattern: Callable
+    replacement: Callable
+    # post transformation on the exported pattern and replacement GraphModule
+    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+
+_REWRITE_INFO_LIST = [
+    _RewriteInfo(
+        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_dynamic_quantized_linear),
+        _WrapperModule(_reference_dynamic_quantized_linear),
+        partial(
+            _replace_literals_with_existing_placeholders,
+            literal_to_ph_idx={
+                -128: 1,
+                127: 2,
+                torch.finfo(torch.float32).eps: 3
+            }
+        ),
+        partial(
+            _replace_literals_with_existing_placeholders,
+            literal_to_ph_idx={
+                -128: 1,
+                127: 2,
+                torch.finfo(torch.float32).eps: 3
+            }
+        ),
+    ),
+    _RewriteInfo(
+        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_quantized_linear),
+        _WrapperModule(_reference_quantized_linear),
+        _replace_literals_with_new_placeholders,
+        _replace_literals_with_new_placeholders,
+    ),
+    _RewriteInfo(
+        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_quantized_conv2d),
+        _WrapperModule(_reference_quantized_conv2d),
+        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+    ),
+    _RewriteInfo(
+        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_quantized_add_relu),
+        _WrapperModule(_reference_quantized_add_relu),
+    ),
+    _RewriteInfo(
+        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_quantized_add),
+        _WrapperModule(_reference_quantized_add),
+    ),
+    _RewriteInfo(
+        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
+        _WrapperModule(_qdq_quantized_max_pool2d),
+        _WrapperModule(_reference_quantized_max_pool2d),
+        _replace_literals_with_new_placeholders,
+        _replace_literals_with_new_placeholders
+    ),
+    _RewriteInfo(
+        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+        _WrapperModule(_quantize_per_tensor_int8),
+        _WrapperModule(_reference_quantize_per_tensor_int8),
+    ),
+    _RewriteInfo(
+        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+        _WrapperModule(_dequantize_per_tensor_int8),
+        _WrapperModule(_reference_dequantize_per_tensor_int8),
+    ),
+    _RewriteInfo(
+        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+        _WrapperModule(_quantize_per_channel_int8),
+        _WrapperModule(_reference_quantize_per_channel_int8),
+        _replace_ph_qdq_per_channel_replacement,
+        _replace_ph_qdq_per_channel_replacement
+    ),
+    _RewriteInfo(
+        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+        _WrapperModule(_dequantize_per_channel_int8),
+        _WrapperModule(_reference_dequantize_per_channel_int8),
+        _replace_ph_qdq_per_channel_replacement,
+        _replace_ph_qdq_per_channel_replacement
+    ),
+]
+
+def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+    remove_tensor_overload_for_qdq_ops(model)
+    for rewrite_info in _REWRITE_INFO_LIST:
+        example_inputs = rewrite_info.example_inputs
+        pattern = rewrite_info.pattern
+        replacement = rewrite_info.replacement
+        pattern_post_trans = rewrite_info.pattern_post_trans
+        replacement_post_trans = rewrite_info.replacement_post_trans
+        pattern = get_aten_graph_module(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+        remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
+        replacement = get_aten_graph_module(replacement, example_inputs)  # type: ignore[arg-type, assignment]
+        remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
+        if pattern_post_trans:
+            pattern = pattern_post_trans(pattern)
+        if replacement_post_trans:
+            replacement = replacement_post_trans(replacement)
+        pattern.recompile()  # type: ignore[attr-defined]
+        replacement.recompile()  # type: ignore[attr-defined]
+        matches = replace_pattern(model, pattern, replacement)
+    return model
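+
+# Example (sketch): assuming `quantized_gm` is a GraphModule produced by the PT2E convert flow,
+# the rewrite mutates it in place and also returns it:
+#
+#   quantized_gm = reference_representation_rewrite(quantized_gm)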
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7cbb8c8c8f4f44db255a22e143a56b6d0e12766
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/pt2e/utils.py
@@ -0,0 +1,540 @@
+import operator
+import types
+
+import torch
+from torch._export import capture_pre_autograd_graph
+from torch.fx import (
+    GraphModule,
+    Node,
+)
+from torch.nn.utils.fusion import fuse_conv_bn_weights
+from typing import Any, Callable, Dict, Optional, Tuple, List, Union
+from torch.utils._pytree import LeafSpec
+from torch.export.unflatten import _AttrKind, _assign_attr
+
+# Makes sure that quantized_decomposed ops are registered
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+
+from torch.ao.quantization.quantizer import QuantizationAnnotation
+
+
+__all__ = [
+    "fold_bn_weights_into_conv_node",
+    "get_aten_graph_module",
+    "remove_tensor_overload_for_qdq_ops",
+]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+# Example inputs for conv-bn1d patterns
+_conv1d_bn_example_inputs = (
+    torch.randn(1, 1, 3),  # x
+    torch.randn(1, 1, 1),  # conv_weight
+    torch.randn(1),        # conv_bias
+    torch.randn(1),        # bn_weight
+    torch.randn(1),        # bn_bias
+    torch.randn(1),        # bn_running_mean
+    torch.randn(1),        # bn_running_var
+)
+
+# Example inputs for conv-bn2d patterns
+_conv2d_bn_example_inputs = (
+    torch.randn(1, 1, 3, 3),  # x
+    torch.randn(1, 1, 1, 1),  # conv_weight
+    torch.randn(1),           # conv_bias
+    torch.randn(1),           # bn_weight
+    torch.randn(1),           # bn_bias
+    torch.randn(1),           # bn_running_mean
+    torch.randn(1),           # bn_running_var
+)
+
+def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
+    """
+    Assuming `dest` is one of the ops inserted by the quantization workflow, this function
+    checks whether `source` and `dest` are connected. The assumption is that only
+    quantization-workflow-inserted ops exist between `source` and `dest`.
+    """
+    quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
+    quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
+    while dest.target in quant_workflow_ops:
+        if not isinstance(dest.args[0], torch.fx.Node):
+            raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}")
+        dest = dest.args[0]
+    return (dest == source)
+
+
+def _find_q_dq_node_for_user(
+    producer: torch.fx.Node, user: torch.fx.Node
+) -> Tuple[Any, Any]:
+    """
+    Find the q, dq pair corresponding to [producer -> q -> dq -> user].
+    This works by finding the dq arg of `user` and ensuring it is connected to
+    `producer`.
+    """
+    dq_node = None
+    for n in user.args:
+        if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
+            if _is_connected(producer, n):
+                dq_node = n
+                break
+    if dq_node is None:
+        for n in user.kwargs:
+            if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
+                if _is_connected(producer, n):
+                    dq_node = n
+                    break
+    if dq_node is None:
+        return (None, None)
+
+    q_node = None
+    if dq_node.args[0].op == "call_function" and dq_node.args[0].target in _QUANTIZE_OPS:
+        q_node = dq_node.args[0]
+    return (q_node, dq_node)
+
+
+
+def _is_sym_size_node(node: Node):
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops.aten.sym_size.default
+        or node.target == torch.ops.aten.sym_numel.default
+        or node.target == torch.ops.aten.sym_numel
+        or node.target == torch.ops.aten.sym_size
+    )
+
+
+def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
+    node_users = [n for n in node.users if not _is_sym_size_node(n)]
+    return node_users
+
+
+def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
+    if annotation is None:
+        return False
+    input_qspec_map = annotation.input_qspec_map
+    output_qspec = annotation.output_qspec
+    if len(input_qspec_map) == 0 and output_qspec is None:
+        return False
+    return True
+
+
+def _get_tensor_constant_from_node(node, m):
+    if node is None:
+        return None
+    assert node.op == "get_attr"
+    target_atoms = node.target.split('.')
+    attr_itr = m
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
+
+def _get_all_arguments(orig_args, orig_kwargs, args_schema):
+    all_args = []
+    for i, schema in enumerate(args_schema):
+        if schema.name in orig_kwargs:
+            all_args.append(orig_kwargs[schema.name])
+        elif not schema.kwarg_only and i < len(orig_args):
+            all_args.append(orig_args[i])
+        else:
+            all_args.append(schema.default_value)
+    return all_args
+
+def _is_supported_batch_norm_for_training(node: Node):
+    """
+    Return True if the given node refers to an aten batch norm op QAT supports.
+    """
+    supported_ops = [
+        torch.ops.aten._native_batch_norm_legit.default,
+        # Note: we won't need this op anymore after batch norm consolidation
+        # For now, we need to continue to support it because it gives better
+        # training numerics than `_native_batch_norm_legit`
+        torch.ops.aten.cudnn_batch_norm.default,
+        torch.ops.aten.miopen_batch_norm.default,
+    ]
+    return node.target in supported_ops
+
+# TODO: rename this to _is_conv_node
+def _is_conv(n: Node):
+    """
+    Return whether the node refers to an aten conv op.
+    """
+    return n.op == "call_function" and n.target in [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+    ]
+
+# TODO: rename this to _is_conv_transpose_node
+def _is_conv_transpose(n: Node):
+    """
+    Return whether the node refers to an aten conv_transpose op.
+    """
+    return n.op == "call_function" and n.target in [
+        torch.ops.aten.conv_transpose1d,
+        torch.ops.aten.conv_transpose2d,
+    ]
+
+def _is_bn_node(n: Node):
+    return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+
+def fold_bn_weights_into_conv_node(
+    conv_node: Node,
+    conv_weight_node: Node,
+    conv_bias_node: Optional[Node],
+    bn_node: Node,
+    m: GraphModule
+) -> None:
+    # conv args: input, weight, bias, stride, padding, dilation, ...
+    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
+    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
+    transpose = _is_conv_transpose(conv_node)
+
+    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
+    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
+    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
+    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
+    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
+    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
+    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
+    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
+    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
+        eps_arg_index = 6
+    elif _is_supported_batch_norm_for_training(bn_node):
+        eps_arg_index = 7
+    else:
+        raise ValueError("BN node target is unexpected ", bn_node.target)
+    bn_eps = bn_args[eps_arg_index]
+
+    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)
+
+    # update the weight and bias for conv
+    conv_args = list(conv_node.args)
+    # filling in the default bias argument
+    if len(conv_args) == 2:
+        conv_args.append(None)
+
+    # fused_weight and fused_bias are nn.Parameters; assign them back onto the module
+    weight_attr_name = conv_weight_node.target
+    assert isinstance(weight_attr_name, str)
+    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
+    if conv_bias_node is not None:
+        bias_attr_name = conv_bias_node.target
+        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
+    else:
+        bias_attr_name = weight_attr_name + "_bias"
+        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
+        with m.graph.inserting_before(conv_node):
+            get_bias_node = m.graph.get_attr(bias_attr_name)
+        # NOTE: here we assume the bias of conv is not quantized!
+        conv_args[2] = get_bias_node
+    conv_node.args = tuple(conv_args)
+
+    # native_batch_norm has 3 outputs, we expect getitem calls on the output
+    # and we want to replace the uses of getitem 0 with the output of conv
+    #
+    # Before:
+    # conv -> bn - (first output) -> users1
+    #          \ - (second output) -> users2
+    #          \ - (third output) -> users3
+    # After:
+    # conv -> (first output) -> users1
+    #       bn -
+    #          \ - (second output) -> users2
+    #          \ - (third output) -> users3
+    # if users2 and users3 are empty then bn will be removed through dead code elimination
+
+    for user in bn_node.users:
+        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
+            continue
+        user.replace_all_uses_with(conv_node)
+
+# fuse conv and bn weights; modifies the graph_module and its graph in place
+def _fuse_conv_bn_(m: GraphModule) -> None:
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return
+    for n in m.graph.nodes:
+        if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
+            continue
+        bn_node = n
+        n = bn_node.args[0]
+        if not _is_conv(n):
+            continue
+        conv_node = n
+        conv_weight_node = conv_node.args[1]
+        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
+        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
+    # TODO: move this information to fx node itself
+    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
+    for n in model.graph.nodes:
+        nn_module_stack = n.meta.get("nn_module_stack", None)
+        current_scope = ("", type(None))
+        if nn_module_stack:
+            bt = list(nn_module_stack.values())[-1]
+            current_scope = (bt[0].split(".")[-1], bt[1])
+        node_name_to_scope[n.name] = current_scope
+    return node_name_to_scope
+
+def get_aten_graph_module(
+    pattern: Callable,
+    example_inputs: Tuple[Any, ...],
+    is_cuda: bool = False,
+    **kwargs,
+) -> GraphModule:
+    """
+    Convert the pattern to an FX graph with decomposed aten ops.
+    """
+    if is_cuda:
+        example_inputs = tuple([x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs])
+    aten_pattern = capture_pre_autograd_graph(
+        pattern,
+        example_inputs,
+        kwargs,
+    )
+    aten_pattern.graph.eliminate_dead_code()
+    aten_pattern.recompile()
+    return aten_pattern
+
+def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+    """ Remove .tensor overload for quantize/dequantize ops so that we can
+    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
+    """
+    _MAP = {
+        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
+        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
+    }
+    for n in match_pattern.graph.nodes:
+        if n.op != "call_function":
+            continue
+        if n.target in _MAP:
+            n.target = _MAP[n.target]
+
+def _is_literal(arg):
+    if isinstance(arg, (int, float)):
+        return True
+    if isinstance(arg, (tuple, list)):
+        return all(map(_is_literal, arg))
+    return False
+
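+# Illustration of what _is_literal treats as a literal (example values, not exhaustive):
+#   _is_literal(3)                -> True   # plain int
+#   _is_literal((1, 2.0))         -> True   # nested ints/floats
+#   _is_literal(torch.tensor(1))  -> False  # Tensors are not literals
+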
+def _replace_literals_with_new_placeholders(
+    gm: torch.fx.GraphModule,
+    merge_dup: bool = False,
+    exclude_literals: Optional[List[Any]] = None
+):
+    """Replace the literals in the graph with placeholder nodes that's created on the fly while we
+    traverse the graph, so that the literal arguments in the graph can be matched and replaced
+
+    To use this, the pattern and replacement graph should have the exact same number of literal args
+    and they should be used in the exact same order in the pattern and replacement graph.
+
+    If the literal arguments are not used in the same order in pattern and replacement graph, please
+    use `_replace_literals_with_existing_placeholders` instead
+
+    Args:
+        `gm`: input GraphModule that we'll transform
+        `merge_dup`: boolean flag indicating whether, when the same literal appears multiple times in
+         the graph, the occurrences should correspond to the same placeholder or not
+        `exclude_literals`: a list of literals that will not be replaced with placeholders
+
+    Example:
+
+    # 1. Original Graph
+    def pattern(self, x):
+        return x + 3
+
+    def replacement(self, x):
+        return x - 3
+
+    example_inputs = (torch.randn(1, 3, 3, 3),)
+    pattern_gm = get_aten_graph_module(pattern, example_inputs)
+    replacement_gm = get_aten_graph_module(replacement, example_inputs)
+
+    # 2. Before calling replace literals we'll see the following graph:
+    def pattern(self, x):
+        return x + 3
+
+    def replacement(self, x):
+        return x - 3
+
+    pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
+    replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)
+
+    # 3. After replacing literals with new placeholder nodes
+
+    def pattern(self, x, new_ph):
+        return x + new_ph
+
+    def replacement(self, x, new_ph):
+        return x - new_ph
+
+    """
+    last_ph = None
+    cnt = 0
+    literal_to_ph: Dict[Union[float, bool, int, torch.dtype], Node] = {}
+    if exclude_literals is None:
+        exclude_literals = []
+
+    in_spec = gm._in_spec
+    args_spec = in_spec.children_specs[0]
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            last_ph = node
+            cnt += 1
+            continue
+        with gm.graph.inserting_after(last_ph):
+            new_args = []
+            for arg in node.args:
+                if _is_literal(arg) and arg not in exclude_literals:
+                    if merge_dup and arg in literal_to_ph:
+                        new_args.append(literal_to_ph[arg])
+                    else:
+                        ph_node = gm.graph.placeholder("arg" + str(cnt))
+                        new_args.append(ph_node)
+                        args_spec.children_specs.append(LeafSpec())
+                        cnt += 1
+                        if merge_dup:
+                            literal_to_ph[arg] = ph_node
+                else:
+                    new_args.append(arg)
+            new_args = tuple(new_args)
+
+        node.args = new_args
+
+    # Update `num_nodes`, `num_leaves`, `num_children`.
+    args_spec.__post_init__()
+    in_spec.__post_init__()
+    return gm
+
+
+def _replace_literals_with_existing_placeholders(
+    gm: torch.fx.GraphModule,
+    exclude_literals: Optional[List[Any]] = None,
+    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None
+):
+    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
+    in the graph can be matched and replaced
+
+    To use this, all literal args in the graph should be unique and each of them should correspond
+    to exactly one placeholder node
+
+    # 1. Original Graph
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
+        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+
+    example_inputs = (
+        torch.randn(1, 3, 3, 3),
+        1.0,
+        0,
+        -128,
+        127,
+    )
+    pattern_gm = get_aten_graph_module(pattern, example_inputs)
+    replacement_gm = get_aten_graph_module(replacement, example_inputs)
+
+    # 2. Before calling replace literals we'll see the following graph:
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        x_i8 = torch.clamp(x_i8, -128, 127)
+        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)
+
+    # Note that literal args appear in different order in pattern and replacement graph, so
+    # we can't use _replace_literals_with_new_placeholders
+
+    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
+    pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
+    replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx=literal_to_ph_idx)
+
+    # 3. After replacing literals with existing placeholder nodes
+
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max now refer to the existing placeholder nodes
+        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max now refer to the existing placeholder nodes
+        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
+        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+    """
+    if exclude_literals is None:
+        exclude_literals = []
+
+    if literal_to_ph_idx is None:
+        literal_to_ph_idx = {}
+
+    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
+
+    for node in gm.graph.nodes:
+        if node.op != "call_function":
+            continue
+        new_args = []
+        for arg in node.args:
+            if _is_literal(arg) and arg not in exclude_literals and arg in literal_to_ph_idx:
+                ph_idx = literal_to_ph_idx[arg]
+                ph_node = phs[ph_idx]
+                new_args.append(ph_node)
+            else:
+                new_args.append(arg)
+        new_args = tuple(new_args)
+        node.args = new_args
+    return gm
+
+# TODO: Handle this in export itself and don't wrap the model in another GraphModule
+# in prepare and convert
+def _disallow_eval_train(model: GraphModule):
+    """
+    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
+    This is useful for exported models, where these methods don't actually behave as expected.
+    """
+    error_message = \
+        """
+        Calling train() or eval() is not supported for exported models.
+        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.
+
+        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
+        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
+        which does the above automatically for you. Note that this has limited effect on switching
+        behavior between train and eval modes, and should be used only for special ops such as dropout
+        and batchnorm.
+        """
+
+    def _train(self, mode: bool = True):
+        raise NotImplementedError(error_message)
+
+    def _eval(self, mode: bool = True):
+        raise NotImplementedError(error_message)
+
+    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
+    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
+    return model
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/qconfig.py b/MLPY/Lib/site-packages/torch/ao/quantization/qconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..99cc79f2607e6c3bffe2a044e367e693a039ba3c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/qconfig.py
@@ -0,0 +1,560 @@
+from collections import namedtuple
+from typing import Optional, Any, Union, Type
+
+import torch
+import torch.nn as nn
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    FakeQuantizeBase,
+    default_fake_quant,
+    default_dynamic_fake_quant,
+    default_per_channel_weight_fake_quant,
+    default_weight_fake_quant,
+    default_fused_act_fake_quant,
+    default_fused_wt_fake_quant,
+    FusedMovingAvgObsFakeQuantize,
+    default_fused_per_channel_wt_fake_quant,
+    default_embedding_fake_quant,
+    default_embedding_fake_quant_4bit,
+    fused_wt_fake_quant_range_neg_127_to_127,
+    fused_per_channel_wt_fake_quant_range_neg_127_to_127,
+)
+
+from .observer import (
+    _PartialWrapper,
+    MinMaxObserver,
+    HistogramObserver,
+    MovingAverageMinMaxObserver,
+    NoopObserver,
+    PlaceholderObserver,
+    ReuseInputObserver,
+    default_debug_observer,
+    default_dynamic_quant_observer,
+    default_float_qparams_observer,
+    default_float_qparams_observer_4bit,
+    default_observer,
+    default_per_channel_weight_observer,
+    default_placeholder_observer,
+    default_weight_observer,
+    weight_observer_range_neg_127_to_127,
+    per_channel_weight_observer_range_neg_127_to_127,
+    default_reuse_input_observer,
+    ObserverBase,
+)
+import warnings
+import copy
+
+__all__ = [
+    "QConfig",
+    # TODO: deprecated, remove
+    "QConfigDynamic",
+    "default_qconfig",
+    "default_debug_qconfig",
+    "default_per_channel_qconfig",
+    "default_dynamic_qconfig",
+    "float16_dynamic_qconfig",
+    "float16_static_qconfig",
+    "per_channel_dynamic_qconfig",
+    "float_qparams_weight_only_qconfig",
+    "float_qparams_weight_only_qconfig_4bit",
+    "default_quint8_weight_qconfig",
+    "default_qat_qconfig",
+    "default_dynamic_qat_qconfig",
+    "default_weight_only_qconfig",
+    "default_activation_only_qconfig",
+    "default_qat_qconfig_v2",
+    "default_reuse_input_qconfig",
+    "default_symmetric_qnnpack_qconfig",
+    "default_per_channel_symmetric_qnnpack_qconfig",
+    "default_symmetric_qnnpack_qat_qconfig",
+    "default_per_channel_symmetric_qnnpack_qat_qconfig",
+    "default_embedding_qat_qconfig",
+    "default_embedding_qat_qconfig_4bit",
+    "get_default_qconfig",
+    "get_default_qat_qconfig",
+    "get_default_qconfig_dict",
+    "get_default_qat_qconfig_dict",
+    "QConfigAny",
+    "qconfig_equals",
+
+]
+
+class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
+    """
+    Describes how to quantize a layer or a part of the network by providing
+    settings (observer classes) for activations and weights respectively.
+
+
+    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
+    instances on invocation, not the concrete observer instances themselves.
+    The quantization preparation function will instantiate observers multiple times, once for each of the layers.
+
+
+    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
+    method (which behaves like functools.partial)::
+
+      my_qconfig = QConfig(
+          activation=MinMaxObserver.with_args(dtype=torch.qint8),
+          weight=default_observer.with_args(dtype=torch.qint8))
+
+    """
+    def __new__(cls, activation, weight):
+        # catch common mistakes
+        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
+            raise ValueError("QConfig received observer instance, please pass observer class instead. " +
+                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
+        return super().__new__(cls, activation, weight)
+
+
+class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
+    """
+    Describes how to dynamically quantize a layer or a part of the network by providing
+    settings (observer classes) for weights.
+
+    It's like QConfig, but for dynamic quantization.
+
+    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
+    instances on invocation, not the concrete observer instances themselves.
+    The quantization function will instantiate observers multiple times, once for each of the layers.
+
+    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
+    method (which behaves like functools.partial)::
+
+      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
+    """
+    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
+        # catch common mistakes
+        if isinstance(weight, nn.Module):
+            raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
+                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
+        warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
+        return super().__new__(cls, activation, weight)
+
+
+default_qconfig = QConfig(activation=default_observer,
+                          weight=default_weight_observer)
+"""
+Default qconfig configuration.
+"""
+
+default_debug_qconfig = QConfig(weight=default_weight_observer,
+                                activation=default_debug_observer)
+"""
+Default qconfig configuration for debugging.
+"""
+
+default_per_channel_qconfig = QConfig(activation=default_observer,
+                                      weight=default_per_channel_weight_observer)
+"""
+Default qconfig configuration for per channel weight quantization.
+"""
+
+default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
+                                  weight=default_weight_observer)
+"""
+Default dynamic qconfig.
+"""
+
+float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True),
+                                  weight=PlaceholderObserver.with_args(dtype=torch.float16))
+"""
+Dynamic qconfig with weights quantized to `torch.float16`.
+"""
+
+float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16),
+                                 weight=PlaceholderObserver.with_args(dtype=torch.float16))
+"""
+Dynamic qconfig with both activations and weights quantized to `torch.float16`.
+"""
+
+per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
+                                      weight=default_per_channel_weight_observer)
+"""
+Dynamic qconfig with weights quantized per channel.
+"""
+
+float_qparams_weight_only_qconfig = QConfig(
+    activation=default_placeholder_observer,
+    weight=default_float_qparams_observer)
+"""
+Dynamic qconfig with weights quantized with a floating point zero_point.
+"""
+
+float_qparams_weight_only_qconfig_4bit = QConfig(
+    activation=default_placeholder_observer,
+    weight=default_float_qparams_observer_4bit)
+
+default_qat_qconfig = QConfig(activation=default_fake_quant,
+                              weight=default_weight_fake_quant)
+"""
+Default qconfig for QAT.
+"""
+
+default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant,
+                                      weight=default_weight_fake_quant)
+"""
+Default qconfig for dynamic QAT.
+"""
+
+default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
+                                      weight=default_weight_fake_quant)
+"""
+Default qconfig for quantizing weights only.
+"""
+
+default_activation_only_qconfig = QConfig(activation=default_fake_quant,
+                                          weight=torch.nn.Identity)
+"""
+Default qconfig for quantizing activations only.
+"""
+
+# QAT config that uses fused observer + fake-quant modules for optimized training performance.
+# To modify the activation/weight observers, change the default entries in fake_quantize.py.
+default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant)
+"""
+Fused version of `default_qat_qconfig`, with performance benefits.
+"""
+
+default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer,
+                                      weight=NoopObserver)
+"""
+Default qconfig for operators that reuse the observers from the input Tensor, e.g. reshape
+"""
+
+def get_default_qconfig(backend='x86', version=0):
+    """
+    Returns the default PTQ qconfig for the specified backend.
+
+    Args:
+      * `backend` (str): a string representing the target backend. Currently supports
+        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
+
+    Return:
+        qconfig
+    """
+    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
+    if backend not in supported_backends:
+        raise AssertionError(
+            "backend: " + str(backend) +
+            f" not supported. backend must be one of {supported_backends}"
+        )
+
+    if version == 0:
+        if backend == 'fbgemm':
+            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
+                              weight=default_per_channel_weight_observer)
+        elif backend == 'qnnpack':
+            # TODO: make this compatible with xnnpack constraints
+            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
+                              weight=default_weight_observer)
+        elif backend == 'onednn':
+            if not torch.cpu._is_cpu_support_vnni():
+                warnings.warn(
+                    "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues "
+                    "on CPU without Vector Neural Network Instruction support.")
+            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
+                              weight=default_per_channel_weight_observer)
+        elif backend == 'x86':
+            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
+                              weight=default_per_channel_weight_observer)
+        else:
+            # won't reach
+            qconfig = default_qconfig
+    else:
+        raise AssertionError("Version number: " + str(version) +
+                             " in get_default_qconfig is not supported. Version number must be 0")
+
+    return qconfig
+
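+# Minimal eager-mode usage sketch for get_default_qconfig ("MyModel" is a placeholder name,
+# not defined in this file):
+#
+#   model = MyModel().eval()
+#   model.qconfig = get_default_qconfig("x86")
+#   prepared = torch.ao.quantization.prepare(model)
+#   # ... run calibration data through `prepared` ...
+#   quantized = torch.ao.quantization.convert(prepared)
+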
+"""
+Default, symmetric PTQ qconfig for the specified backend. And a per_channel
+variant of the same.
+
+Symmetric here applies to signed weights with zero point = 0, and additional
+value restrictions. The activations are also signed 8-bit integers with this
+qconfig.
+
+    * Once this change is merged [as of 3/17/22], with backend or qengine =
+    'qnnpack', some quantized operators with this symmetric qconfig may use
+    operators from the xnnpack library.
+
+        ** Support for using xnnpack ops with the `qnnpack` backend for an asymmetric
+        qconfig (returned by get_default_qconfig()) is not available yet.
+
+    * This qconfig uses signed activations and weights. Weights have added
+    restrictions, such as the zero point being forced to 0, making the weights
+    symmetric, hence the name. The 8-bit quantized values are
+    restricted to [-127, +127], excluding -128.
+
+    * xnnpack has a requantization scale value restriction, 0x1p-32 <=
+    requantization_scale < 256.0, where `requantization_scale = (input_scale
+    * kernel_scale) / (output_scale)`. Using this eps (with an assumed max value
+    of 256) prevents requantization_scale from going below the xnnpack lower
+    threshold.
+"""
+default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
+                                                                                   reduce_range=False,
+                                                                                   eps=2 ** -12),
+                                            weight=weight_observer_range_neg_127_to_127)
+
+default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
+                                                                                               reduce_range=False,
+                                                                                               eps=2 ** -12),
+                                                        weight=per_channel_weight_observer_range_neg_127_to_127)
+
+default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
+                                        weight=default_embedding_fake_quant)
+
+default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
+                                             weight=default_embedding_fake_quant_4bit)
+
+default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)
+
+def get_default_qat_qconfig(backend='x86', version=1):
+    """
+    Returns the default QAT qconfig for the specified backend.
+
+    Args:
+      * `backend` (str): a string representing the target backend. Currently supports
+        `x86` (default), `fbgemm`, `qnnpack` and `onednn`.
+      * `version`: version, for backwards compatibility. Can be `None` or `1`.
+
+    Return:
+        qconfig
+    """
+    supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"]
+    if backend not in supported_backends:
+        raise AssertionError(
+            "backend: " + str(backend) +
+            f" not supported. backend must be one of {supported_backends}"
+        )
+
+    # Histogram observer is too slow for quantization aware training
+    if version == 0:
+        if backend == 'fbgemm':
+            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                quant_min=0,
+                                                                quant_max=255,
+                                                                reduce_range=True),
+                              weight=default_per_channel_weight_fake_quant)
+        elif backend == 'qnnpack':
+            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                quant_min=0,
+                                                                quant_max=255,
+                                                                reduce_range=False),
+                              weight=default_weight_fake_quant)
+        elif backend == 'onednn':
+            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                quant_min=0,
+                                                                quant_max=255),
+                              weight=default_per_channel_weight_fake_quant)
+        elif backend == 'x86':
+            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                quant_min=0,
+                                                                quant_max=255,
+                                                                reduce_range=True),
+                              weight=default_per_channel_weight_fake_quant)
+        else:
+            qconfig = default_qat_qconfig
+    # Use the fused observer + fake_quant modules for doing QAT.
+    elif version == 1:
+        if backend == 'fbgemm':
+            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                 quant_min=0,
+                                                                                 quant_max=255,
+                                                                                 reduce_range=True),
+                              weight=default_fused_per_channel_wt_fake_quant)
+        elif backend == 'qnnpack':
+            # TODO: make this compatible with xnnpack constraints
+            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                 quant_min=0,
+                                                                                 quant_max=255,
+                                                                                 reduce_range=False),
+                              weight=default_fused_wt_fake_quant)
+        elif backend == 'onednn':
+            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                 quant_min=0,
+                                                                                 quant_max=255),
+                              weight=default_fused_per_channel_wt_fake_quant)
+        elif backend == 'x86':
+            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                                                 quant_min=0,
+                                                                                 quant_max=255,
+                                                                                 reduce_range=True),
+                              weight=default_fused_per_channel_wt_fake_quant)
+        else:
+            qconfig = default_qat_qconfig_v2
+    else:
+        raise AssertionError("Version number: " + str(version) +
+                             "in get_default_qat_qconfig is not supported. Version number must be 0 or 1")
+
+    return qconfig
+
+"""
+Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
+"""
+default_symmetric_qnnpack_qat_qconfig = QConfig(
+    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                       quant_min=-128,
+                                                       quant_max=127,
+                                                       dtype=torch.qint8,
+                                                       reduce_range=False,
+                                                       eps=2 ** -12),
+    weight=fused_wt_fake_quant_range_neg_127_to_127)
+
+default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
+    activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
+                                                       quant_min=-128,
+                                                       quant_max=127,
+                                                       dtype=torch.qint8,
+                                                       reduce_range=False,
+                                                       eps=2 ** -12),
+    weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
+
+_default_fp32_placeholder_qconfig = QConfig(
+    activation=PlaceholderObserver.with_args(dtype=torch.float32),
+    weight=PlaceholderObserver.with_args(dtype=torch.float32)
+)
+
+_default_quint8_placeholder_qconfig = QConfig(
+    activation=PlaceholderObserver.with_args(dtype=torch.quint8),
+    # operators using this qconfig don't have weights
+    weight=None,
+)
+
+def get_default_qconfig_dict(backend='x86', version=0):
+    warnings.warn(
+        "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict()
+
+def get_default_qat_qconfig_dict(backend='x86', version=1):
+    warnings.warn(
+        "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in "
+        "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.")
+    return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict()
+
+def _assert_valid_qconfig(qconfig: Optional[QConfig],
+                          mod: torch.nn.Module) -> None:
+    """
+    Verifies that this `qconfig` is valid.
+    """
+    if qconfig is None:
+        return
+    is_conv_transpose_mod = (
+        isinstance(mod, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)))
+    if is_conv_transpose_mod:
+        if qconfig.weight is None:
+            # for now, we assume that any qconfig for ConvTranspose without a weight is valid
+            return
+        example_observer = qconfig.weight()
+        is_per_channel = (
+            isinstance(example_observer, (torch.ao.quantization.PerChannelMinMaxObserver,
+                                          torch.ao.quantization.MovingAveragePerChannelMinMaxObserver))
+        )
+        assert not is_per_channel, \
+            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
+
+QConfigAny = Optional[QConfig]
+QConfigAny.__module__ = "torch.ao.quantization.qconfig"
+
+def _add_module_to_qconfig_obs_ctr(
+        qconfig: QConfigAny,
+        module: Optional[nn.Module]) -> Any:
+    r"""This is a helper function for use in quantization prepare that updates a qconfig so that
+    the constructors stored in the qconfig will create observers on the same device that
+    'module' is on. This is intended to be used when the qconfigs are propagated to each
+    module in order to avoid potential device alignment issues.
+
+    Args:
+        qconfig: QConfig with obs constructors stored in activation and weight
+        module: module which the qconfig is related to
+
+    Return:
+        qconfig: configured so that the observer constructors are set to construct observers on the same device as `module`
+    """
+
+    if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
+        return qconfig
+
+    def get_factory_kwargs_based_on_module_device():
+        assert isinstance(module, torch.nn.Module)
+        devices = {p.device for p in module.parameters()} | \
+            {p.device for p in module.buffers()}
+        device = next(iter(devices)) if len(devices) > 0 else None
+        return None if device is None else {'device': device}
+
+    def configure_constructor_to_put_obs_on_module_device(original_constructor):
+        try:
+            # check if constructor can accept factory_kwargs
+            check = original_constructor.with_args(factory_kwargs=None)
+            check()
+            return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
+        except AttributeError:  # qconfig doesn't have activation or weight
+            return original_constructor
+        except TypeError:  # the class doesn't accept factory_kwargs argument
+            return original_constructor
+
+    activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
+    weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)
+
+    return QConfig(activation, weight)
+
+_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]]
+
+def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
+    if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
+        return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2)
+    return obs_or_fq1 == obs_or_fq2
+
+def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper):
+    """
+    Return whether the two partial wrappers are equal.
+    """
+    # functools.partial has no __eq__ operator defined so '==' defaults to 'is'
+    obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords)
+    obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords)
+    keywords_equal = True
+    # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail
+    if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords:
+        keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"])
+        obs_or_fq1_keywords.pop("observer")
+        obs_or_fq2_keywords.pop("observer")
+    keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords
+    return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal
+
+def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
+    """
+    Returns `True` if `q1` equals `q2`, and `False` otherwise.
+    """
+    if q1 is None or q2 is None:
+        return q1 == q2
+    else:
+        assert q1 is not None and q2 is not None
+        try:
+            # Qconfig weight and activation can be either a partial wrapper,
+            # or an observer class. Special handling is required (above) for
+            # comparing partial wrappers.
+            activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation)
+            weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight)
+            return activation_same and weight_same
+        except AttributeError:
+            return q1 == q2
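+
+# Minimal illustrative sketch: two qconfigs built from the same default
+# constructors compare equal even though the underlying partial wrappers are
+# distinct objects, which is why the structural comparison above is needed.
+def _example_qconfig_equals():
+    q1 = get_default_qconfig("fbgemm")
+    q2 = get_default_qconfig("fbgemm")
+    return qconfig_equals(q1, q2)  # expected to be True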
+
+def _activation_is_memoryless(qconfig: QConfig):
+    """
+    Return whether the observer for activations defined in the given QConfig is memoryless.
+    This means a MovingAverage observer with averaging constant equal to 1.
+    """
+    def _is_memoryless(observer):
+        return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
+    act = qconfig.activation()
+    if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
+        return _is_memoryless(act.activation_post_process)
+    else:
+        return _is_memoryless(act)
+
+def _is_reuse_input_qconfig(qconfig: Optional[QConfig]):
+    return qconfig is not None and \
+        isinstance(qconfig.activation(), ReuseInputObserver) and \
+        isinstance(qconfig.weight(), NoopObserver)
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py b/MLPY/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6280daefccba3eae33541598398d152165db8e8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/qconfig_mapping.py
@@ -0,0 +1,350 @@
+from __future__ import annotations
+from collections import OrderedDict
+from typing import Any, Callable, Dict, Tuple, Union, List
+
+import torch
+
+from .fake_quantize import (
+    default_weight_fake_quant,
+    FixedQParamsFakeQuantize,
+)
+from .observer import (
+    _PartialWrapper,
+    default_fixed_qparams_range_0to1_observer,
+    default_fixed_qparams_range_neg1to1_observer,
+    default_placeholder_observer,
+    default_weight_observer,
+)
+from .qconfig import (
+    default_reuse_input_qconfig,
+    default_symmetric_qnnpack_qconfig,
+    default_symmetric_qnnpack_qat_qconfig,
+    get_default_qconfig,
+    get_default_qat_qconfig,
+    QConfig,
+    QConfigAny,
+    default_quint8_weight_qconfig
+)
+
+
+__all__ = [
+    "get_default_qconfig_mapping",
+    "get_default_qat_qconfig_mapping",
+    "QConfigMapping",
+]
+
+
+# TODO: replace all usages with these constants
+_GLOBAL_DICT_KEY = ""
+_OBJECT_TYPE_DICT_KEY = "object_type"
+_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
+_MODULE_NAME_DICT_KEY = "module_name"
+_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
+
+# TODO: derive this map from the BackendConfig
+_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
+    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
+    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
+    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
+    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
+    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
+    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
+    "sigmoid": default_fixed_qparams_range_0to1_observer,
+    "sigmoid_": default_fixed_qparams_range_0to1_observer,
+    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
+    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
+    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
+    "tanh": default_fixed_qparams_range_neg1to1_observer,
+    "tanh_": default_fixed_qparams_range_neg1to1_observer,
+}
+
+
+def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping:
+    """
+    Return the default QConfigMapping for the given quantization type and backend.
+    """
+    if is_qat:
+        qconfig = get_default_qat_qconfig(backend, version)
+    else:
+        qconfig = get_default_qconfig(backend, version)
+    default_weight = default_weight_fake_quant if is_qat else default_weight_observer
+
+    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
+    # so we have to modify the weight observer to default_weight_observer or another
+    # per tensor supported observer.
+    # see https://github.com/pytorch/pytorch/issues/47535
+    if backend in ("fbgemm", "x86"):
+        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
+    else:
+        qconfig_transpose = qconfig
+
+    # currently layernorm only supports float weights
+    # we have to add this because otherwise there will be an extra quantize-dequantize pair
+    qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)
+
+    qconfig_mapping = QConfigMapping() \
+        .set_global(qconfig) \
+        .set_object_type("reshape", default_reuse_input_qconfig) \
+        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
+        .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
+        .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
+        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \
+
+    # Use special observers for ops with fixed qparams
+    fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
+    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
+        if observer in fixed_qparams_observer_to_qconfig:
+            fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
+        else:
+            if is_qat:
+                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
+            else:
+                activation = observer
+            fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight)
+            fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
+        qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)
+
+    # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
+    #      Need to be able to support fusion of ops with different qconfigs
+
+    return qconfig_mapping
+
+def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping:
+    """
+    Return the default QConfigMapping for post training quantization.
+
+    Args:
+      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
+         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
+      * ``version`` (int) : the version for the default qconfig mapping
+    """
+    # TODO: add assert for backend choices
+    return _get_default_qconfig_mapping(False, backend, version)
+
+def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping:
+    """
+    Return the default QConfigMapping for quantization aware training.
+
+    Args:
+      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
+         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
+      * ``version`` (int) : the version for the default qconfig mapping
+    """
+    return _get_default_qconfig_mapping(True, backend, version)
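+
+# Illustrative sketch (hypothetical module name): the returned mapping can be
+# customized further before being handed to the FX graph mode prepare APIs.
+def _example_customized_default_mapping():
+    qconfig_mapping = get_default_qconfig_mapping("x86")
+    # Skip quantization for a particular submodule by assigning it a None qconfig.
+    qconfig_mapping.set_module_name("head", None)
+    return qconfig_mapping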
+
+def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
+    """
+    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
+    as the default QConfig.
+    """
+    default_qconfig = default_symmetric_qnnpack_qconfig
+    return _get_default_qconfig_mapping_with_default_qconfig(False, "qnnpack", default_qconfig)
+
+def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
+    """
+    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
+    as the default QConfig.
+    """
+    default_qconfig = default_symmetric_qnnpack_qat_qconfig
+    return _get_default_qconfig_mapping_with_default_qconfig(True, "qnnpack", default_qconfig)
+
+def _get_default_qconfig_mapping_with_default_qconfig(
+    is_qat: bool,
+    backend: str,
+    default_qconfig: QConfig,
+) -> QConfigMapping:
+    """
+    Return a QConfigMapping that uses the provided qconfig as the default QConfig.
+    """
+    if is_qat:
+        qconfig_mapping = get_default_qat_qconfig_mapping(backend)
+    else:
+        qconfig_mapping = get_default_qconfig_mapping(backend)
+    qconfig_mapping.set_global(default_qconfig)
+    for pattern in qconfig_mapping.object_type_qconfigs.keys():
+        if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
+            qconfig_mapping.set_object_type(pattern, default_qconfig)
+    return qconfig_mapping
+
+_QCONFIG_STYLE_ORDER: List[str] = [
+    "global_qconfig",
+    "object_type_qconfigs",
+    "module_name_regex_qconfigs",
+    "module_name_qconfigs",
+    "module_name_object_type_order_qconfigs",
+]
+
+class QConfigMapping:
+    """
+    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.
+
+    The user can specify QConfigs using the following methods (in increasing match priority):
+
+        ``set_global`` : sets the global (default) QConfig
+
+        ``set_object_type`` : sets the QConfig for a given module type, function, or method name
+
+        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string
+
+        ``set_module_name`` : sets the QConfig for modules matching the given module name
+
+        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
+        of the given module name, object type, and the index at which the module appears
+
+    Example usage::
+
+        qconfig_mapping = QConfigMapping()
+            .set_global(global_qconfig)
+            .set_object_type(torch.nn.Linear, qconfig1)
+            .set_object_type(torch.nn.ReLU, qconfig1)
+            .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
+            .set_module_name_regex("foo.*", qconfig2)
+            .set_module_name("module1", qconfig1)
+            .set_module_name("module2", qconfig2)
+            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)
+
+    """
+
+    def __init__(self):
+        # In increasing match priority:
+        self.global_qconfig: QConfigAny = None
+        self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = OrderedDict()
+        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
+        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
+        self.module_name_object_type_order_qconfigs: OrderedDict[Tuple[str, Callable, int], QConfigAny] =\
+            OrderedDict()
+
+    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
+        """
+        Set the global (default) QConfig.
+        """
+        self.global_qconfig = global_qconfig
+        return self
+
+    def set_object_type(self, object_type: Union[Callable, str], qconfig: QConfigAny) -> QConfigMapping:
+        """
+        Set the QConfig for a given module type, function, or method name.
+        If the QConfig for an existing object type was already set, the new QConfig will override the old one.
+        """
+        self.object_type_qconfigs[object_type] = qconfig
+        return self
+
+    def set_module_name_regex(self, module_name_regex: str, qconfig: QConfigAny) -> QConfigMapping:
+        """
+        Set the QConfig for modules matching the given regex string.
+
+        Regexes will be matched in the order in which they are registered through this method.
+        Thus, the caller should register more specific patterns first, e.g.::
+
+            qconfig_mapping = QConfigMapping()
+                .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
+                .set_module_name_regex("foo.*bar.*", qconfig2)
+                .set_module_name_regex("foo.*", qconfig3)
+
+        In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2,
+        and "foo.baz.relu" would match qconfig3.
+
+        If the QConfig for an existing module name regex was already set, the new QConfig will override the
+        old one while preserving the order in which the regexes were originally registered.
+        """
+        self.module_name_regex_qconfigs[module_name_regex] = qconfig
+        return self
+
+    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
+        """
+        Set the QConfig for modules matching the given module name.
+        If the QConfig for an existing module name was already set, the new QConfig will override the old one.
+        """
+        self.module_name_qconfigs[module_name] = qconfig
+        return self
+
+    def set_module_name_object_type_order(
+            self,
+            module_name: str,
+            object_type: Callable,
+            index: int,
+            qconfig: QConfigAny) -> QConfigMapping:
+        """
+        Set the QConfig for modules matching a combination of the given module name, object type,
+        and the index at which the module appears.
+
+        If the QConfig for an existing (module name, object type, index)  was already set, the new QConfig
+        will override the old one.
+        """
+        self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig
+        return self
+
+    def __repr__(self) -> str:
+        output = self.__class__.__name__ + " ("
+        for style_name in _QCONFIG_STYLE_ORDER:
+            output += f"\n {style_name}"
+            qconfigs = getattr(self, style_name)
+            if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0:
+                for key, qconfig in qconfigs.items():
+                    output += f"\n  {key}: {qconfig}"
+            else:
+                output += f"\n  {qconfigs}"
+        return output + "\n)"
+
+    # TODO: remove this
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert this ``QConfigMapping`` to a dictionary with the following keys:
+
+            "" (for global QConfig)
+
+            "object_type"
+
+            "module_name_regex"
+
+            "module_name"
+
+            "module_name_object_type_order"
+
+        The values of this dictionary are lists of tuples.
+        """
+        return {
+            _GLOBAL_DICT_KEY: self.global_qconfig,
+            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
+            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
+            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
+            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
+                (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items()
+            ],
+        }
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
+        """
+        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):
+
+            "" (for global QConfig)
+
+            "object_type"
+
+            "module_name_regex"
+
+            "module_name"
+
+            "module_name_object_type_order"
+
+        The values of this dictionary are expected to be lists of tuples.
+        """
+        conf = cls()
+        if _GLOBAL_DICT_KEY in qconfig_dict:
+            conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
+        for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
+            conf.set_object_type(object_type, qconfig)
+        for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []):
+            conf.set_module_name_regex(module_name_regex, qconfig)
+        for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
+            conf.set_module_name(module_name, qconfig)
+        for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
+            conf.set_module_name_object_type_order(module_name, object_type, index, qconfig)
+        return conf
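+
+# Illustrative round-trip sketch (hypothetical module name and qconfig choice):
+# the dictionary produced by to_dict() can be fed back through from_dict() to
+# rebuild an equivalent mapping.
+def _example_qconfig_mapping_round_trip():
+    original = QConfigMapping() \
+        .set_global(get_default_qconfig("x86")) \
+        .set_module_name("fc", None)
+    restored = QConfigMapping.from_dict(original.to_dict())
+    return restored.module_name_qconfigs == original.module_name_qconfigs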
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quant_type.py b/MLPY/Lib/site-packages/torch/ao/quantization/quant_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1676d986388ac889ed97b40cb5e776f04dc96e0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quant_type.py
@@ -0,0 +1,30 @@
+import enum
+
+__all__ = [
+    "QuantType",
+]
+
+# Quantization type (dynamic quantization, static quantization).
+# Should match the c++ enum in quantization_type.h
+class QuantType(enum.IntEnum):
+    DYNAMIC = 0
+    STATIC = 1
+    QAT = 2
+    WEIGHT_ONLY = 3
+
+_quant_type_to_str = {
+    QuantType.STATIC: "static",
+    QuantType.DYNAMIC: "dynamic",
+    QuantType.QAT: "qat",
+    QuantType.WEIGHT_ONLY: "weight_only",
+}
+
+# TODO: make this private
+def _get_quant_type_to_str(quant_type: QuantType) -> str:
+    return _quant_type_to_str[quant_type]
+
+def _quant_type_from_str(name: str) -> QuantType:
+    for quant_type, s in _quant_type_to_str.items():
+        if name == s:
+            return quant_type
+    raise ValueError(f"Unknown QuantType name '{name}'")
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantization_mappings.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantization_mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0afe0acd712e5740a3602f280a9386309cdeed
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantization_mappings.py
@@ -0,0 +1,348 @@
+import copy
+
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.reference as nnqr
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.qat.dynamic as nnqatd
+
+from typing import Optional, Union, Dict, Set, Callable, Any
+
+# Because `torch.ao.nn` uses lazy imports, we need to make
+# sure we import the contents explicitly here.
+import torch.ao.nn.sparse
+import torch.ao.nn as ao_nn
+from torch.ao.quantization.stubs import QuantStub, DeQuantStub
+from torch.ao.quantization.fake_quantize import (
+    default_fixed_qparams_range_0to1_fake_quant,
+    default_fixed_qparams_range_neg1to1_fake_quant,
+)
+from torch.ao.quantization.utils import get_combined_dict
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+__all__ = [
+    "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS",
+    "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS",
+    "DEFAULT_QAT_MODULE_MAPPINGS",
+    "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS",
+    "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS",
+    "DEFAULT_MODULE_TO_ACT_POST_PROCESS",
+    "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS",
+    "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS",
+    "no_observer_set",
+    "get_default_static_quant_module_mappings",
+    "get_default_static_quant_reference_module_mappings",
+    "get_embedding_static_quant_module_mappings",
+    "get_default_static_sparse_quant_module_mappings",
+    "get_static_quant_module_class",
+    "get_dynamic_quant_module_class",
+    "get_default_qat_module_mappings",
+    "get_embedding_qat_module_mappings",
+    "get_default_dynamic_quant_module_mappings",
+    "get_default_dynamic_sparse_quant_module_mappings",
+    "get_default_qconfig_propagation_list",
+    "get_default_compare_output_module_list",
+    "get_default_float_to_quantized_operator_mappings",
+    "get_quantized_operator",
+]
+
+# Default map for swapping float module to reference quantized modules
+DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
+    nn.Linear: nnqr.Linear,
+    nn.Conv1d: nnqr.Conv1d,
+    nn.Conv2d: nnqr.Conv2d,
+    nn.Conv3d: nnqr.Conv3d,
+    nn.ConvTranspose1d: nnqr.ConvTranspose1d,
+    nn.ConvTranspose2d: nnqr.ConvTranspose2d,
+    nn.ConvTranspose3d: nnqr.ConvTranspose3d,
+    nn.Embedding: nnqr.Embedding,
+    nn.EmbeddingBag: nnqr.EmbeddingBag,
+    nn.GRUCell: nnqr.GRUCell,
+    nn.LSTMCell: nnqr.LSTMCell,
+    nn.RNNCell: nnqr.RNNCell,
+    nn.LSTM: nnqr.LSTM,
+}
+
+# Default map for swapping float module to quantized ones
+DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    QuantStub: nnq.Quantize,
+    DeQuantStub: nnq.DeQuantize,
+    nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
+    nn.Dropout: nnq.Dropout,
+    nn.Conv1d: nnq.Conv1d,
+    nn.Conv2d: nnq.Conv2d,
+    nn.Conv3d: nnq.Conv3d,
+    nn.ConvTranspose1d: nnq.ConvTranspose1d,
+    nn.ConvTranspose2d: nnq.ConvTranspose2d,
+    nn.ConvTranspose3d: nnq.ConvTranspose3d,
+    nn.ELU: nnq.ELU,
+    nn.Embedding: nnq.Embedding,
+    nn.EmbeddingBag: nnq.EmbeddingBag,
+    nn.GroupNorm: nnq.GroupNorm,
+    nn.Hardswish: nnq.Hardswish,
+    nn.InstanceNorm1d: nnq.InstanceNorm1d,
+    nn.InstanceNorm2d: nnq.InstanceNorm2d,
+    nn.InstanceNorm3d: nnq.InstanceNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
+    nn.LeakyReLU: nnq.LeakyReLU,
+    nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear,
+    nn.Linear: nnq.Linear,
+    nn.ReLU6: nnq.ReLU6,
+    nn.Dropout: nnq.Dropout,
+    nn.PReLU: nnq.PReLU,
+    # Wrapper Modules:
+    nnq.FloatFunctional: nnq.QFunctional,
+    # Intrinsic modules:
+    nni.BNReLU2d: nniq.BNReLU2d,
+    nni.BNReLU3d: nniq.BNReLU3d,
+    nni.ConvReLU1d: nniq.ConvReLU1d,
+    nni.ConvReLU2d: nniq.ConvReLU2d,
+    nni.ConvReLU3d: nniq.ConvReLU3d,
+    nni.ConvAdd2d: nniq.ConvAdd2d,
+    nni.ConvAddReLU2d: nniq.ConvAddReLU2d,
+    nni.LinearReLU: nniq.LinearReLU,
+    nni.LinearLeakyReLU: nniq.LinearLeakyReLU,
+    nni.LinearTanh: nniq.LinearTanh,
+    nniqat.ConvBn1d: nnq.Conv1d,
+    nniqat.ConvBn2d: nnq.Conv2d,
+    nniqat.ConvBn3d: nnq.Conv3d,
+    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
+    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
+    nniqat.ConvBnReLU3d: nniq.ConvReLU3d,
+    nniqat.ConvReLU2d: nniq.ConvReLU2d,
+    nniqat.ConvReLU3d: nniq.ConvReLU3d,
+    nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.LinearBn1d: nnq.Linear,
+    # QAT modules:
+    nnqat.Linear: nnq.Linear,
+    nnqat.Conv2d: nnq.Conv2d,
+    nnqat.Conv3d: nnq.Conv3d,
+}
+
+# Default map for swapping float module to qat modules
+DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Conv2d: nnqat.Conv2d,
+    nn.Conv3d: nnqat.Conv3d,
+    nn.Linear: nnqat.Linear,
+    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
+    # Intrinsic modules:
+    nni.ConvBn1d: nniqat.ConvBn1d,
+    nni.ConvBn2d: nniqat.ConvBn2d,
+    nni.ConvBn3d: nniqat.ConvBn3d,
+    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
+    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
+    nni.ConvBnReLU3d: nniqat.ConvBnReLU3d,
+    nni.ConvReLU2d: nniqat.ConvReLU2d,
+    nni.ConvReLU3d: nniqat.ConvReLU3d,
+    nni.LinearReLU: nniqat.LinearReLU,
+    nni.LinearBn1d: nniqat.LinearBn1d,
+}
+
+# Default map for swapping dynamic modules
+DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.GRUCell: nnqd.GRUCell,
+    nn.Linear: nnqd.Linear,
+    nnqatd.Linear: nnqd.Linear,
+    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear,
+    nn.LSTM: nnqd.LSTM,
+    nn.GRU: nnqd.GRU,
+    nn.LSTMCell: nnqd.LSTMCell,
+    nn.RNNCell: nnqd.RNNCell,
+    nni.LinearReLU: nniqd.LinearReLU,
+    nn.EmbeddingBag: nnq.EmbeddingBag,
+    nn.Embedding: nnq.Embedding,
+    # Don't want to enable these by default because the numerical
+    # accuracy is poor compared to other dynamic ops
+    # nn.Conv1d: nnqd.Conv1d,
+    # nn.Conv2d: nnqd.Conv2d,
+    # nn.Conv3d: nnqd.Conv3d,
+    # nn.ConvTranspose1d: nnqd.ConvTranspose1d,
+    # nn.ConvTranspose2d: nnqd.ConvTranspose2d,
+    # nn.ConvTranspose3d: nnqd.ConvTranspose3d,
+}
+
+# Allowlist for propagating the qconfig
+_INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
+    nn.Sequential,
+}
+
+# Default mapping from floating point function or torch ops to quantized ops
+# TODO: merge with default static mapping
+DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
+    F.elu: torch.ops.quantized.elu,
+    F.hardswish: torch.ops.quantized.hardswish,
+    F.instance_norm: torch.ops.quantized.instance_norm,
+    F.layer_norm: torch.ops.quantized.layer_norm,
+    F.leaky_relu: torch.ops.quantized.leaky_relu,
+    F.dropout: torch.ops.quantized.dropout,
+}
+
+# mapping from module to output activation post process class
+DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = {
+    nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
+    nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
+}
+
+# Default map for swapping float module to static sparse quantized ones
+DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: ao_nn.sparse.quantized.Linear
+}
+
+# Default map for swapping float module to dynamic sparse quantized ones
+DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: ao_nn.sparse.quantized.dynamic.Linear
+}
+
+def no_observer_set() -> Set[Any]:
+    r"""These modules cannot have observers inserted by default."""
+    no_observers = {
+        nn.quantizable.LSTM,
+        nn.quantizable.MultiheadAttention
+    }
+    return no_observers
+
+def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping for post training static quantization
+    '''
+    return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
+
+def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
+    ''' Get reference module mapping for post training static quantization
+    '''
+    return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
+
+def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping, including mapping for embedding QAT
+    '''
+    mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
+    mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag
+    mapping[nnqat.Embedding] = nnq.Embedding
+    return mapping
+
+def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping for post training static sparse quantization
+    '''
+    return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS)
+
+def get_static_quant_module_class(
+        float_module_class: Callable,
+        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
+        is_reference: bool = False) -> Any:
+    r"""n Get the statically quantized module class corresponding to
+    the floating point module class
+    """
+    if additional_static_quant_mapping is None:
+        additional_static_quant_mapping = {}
+    all_mappings = get_combined_dict(
+        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
+        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping)
+    static_quant_module_class = all_mappings.get(float_module_class, None)
+    assert static_quant_module_class is not None, \
+        f"Floating point module class {str(float_module_class)}" + \
+        " does not have a corresponding quantized module class"
+    return copy.deepcopy(static_quant_module_class)
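+
+# Illustrative sketch: looking up the quantized counterpart of a float module
+# class; with the default static mapping above, nn.Linear resolves to nnq.Linear.
+def _example_static_quant_lookup():
+    qlinear_cls = get_static_quant_module_class(nn.Linear)
+    return qlinear_cls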
+
+def get_dynamic_quant_module_class(
+        float_module_class: Callable,
+        additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
+    r"""n Get the dynamically quantized module class corresponding to
+    the floating point module class
+    """
+    if additional_dynamic_quant_mapping is None:
+        additional_dynamic_quant_mapping = {}
+    all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
+    dynamic_quant_module_class = all_mappings.get(float_module_class, None)
+    assert dynamic_quant_module_class is not None, \
+        f"Floating point module class {str(float_module_class)}" + \
+        " does not have a corresponding quantized module class"
+    return copy.deepcopy(dynamic_quant_module_class)
+
+def get_default_qat_module_mappings() -> Dict[Callable, Any]:
+    ''' Get default module mapping for quantization aware training
+    '''
+    return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
+
+def get_embedding_qat_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping for quantization aware training
+        This includes the default values in addition to
+        enabling qat for embeddings.
+    '''
+    mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
+    mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag
+    mapping[nn.Embedding] = nnqat.Embedding
+    return mapping
+
+def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping for post training dynamic quantization
+    '''
+    return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS
+
+def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]:
+    ''' Get module mapping for post training dynamic sparse quantization
+    '''
+    return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS
+
+def get_default_qconfig_propagation_list() -> Set[Callable]:
+    ''' Get the default list of module types that we'll attach qconfig
+    attribute to in prepare
+    '''
+    QCONFIG_PROPAGATE_MODULE_CLASS_LIST = (
+        set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) |
+        set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) |
+        set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) |
+        _INCLUDE_QCONFIG_PROPAGATE_LIST
+    )
+    return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST)
+
+def get_default_compare_output_module_list() -> Set[Callable]:
+    ''' Get the list of module class types whose outputs we will record
+    in the numeric suite
+    '''
+    NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = (
+        set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values())
+        | set(DEFAULT_QAT_MODULE_MAPPINGS.values())
+        | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values())
+        | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
+        | set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
+        | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
+        | _INCLUDE_QCONFIG_PROPAGATE_LIST
+    )
+    return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST)
+
+def get_default_float_to_quantized_operator_mappings(
+) -> Dict[Union[Callable, str], Callable]:
+    return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS)
+
+# TODO: merge with get_static_quant_module_class
+def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
+    ''' Get the quantized operator corresponding to the float operator
+    '''
+    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
+    assert quantized_op is not None, \
+        f'Operator {str(float_op)} does not have corresponding quantized op'
+    return quantized_op
+
+def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]:
+    r""" Get the special activation post process for `module`, this has
+    higher priority than the activation post process in `qconfig`
+    e.g.
+    input: torch.nn.Sigmoid
+    output: default_affine_fixed_qparam_fake_quant
+    """
+    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module), None)
+
+def _has_special_act_post_process(module: torch.nn.Module) -> bool:
+    return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS
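+
+# Illustrative sketch: a Sigmoid module in training mode picks up the fixed
+# qparams fake quant from DEFAULT_MODULE_TO_ACT_POST_PROCESS, which takes
+# priority over whatever activation observer its qconfig would provide.
+def _example_special_act_post_process():
+    sigmoid = nn.Sigmoid().train()
+    assert _has_special_act_post_process(sigmoid)
+    return _get_special_act_post_process(sigmoid)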
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantize.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02818a99eaf0bdd666b12864bb1f8c590f22842
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantize.py
@@ -0,0 +1,664 @@
+import copy
+import itertools
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.ao.nn.quantized as nnq
+from torch.ao.nn.intrinsic import _FusedModule
+
+from torch.ao.quantization.quantization_mappings import (
+    get_default_dynamic_quant_module_mappings,
+    get_default_static_quant_module_mappings,
+    get_default_static_quant_reference_module_mappings,
+    get_default_qat_module_mappings,
+    get_default_qconfig_propagation_list,
+    no_observer_set,
+    _has_special_act_post_process,
+    _get_special_act_post_process,
+)
+from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
+from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
+from torch.ao.quantization.qconfig import (
+    _add_module_to_qconfig_obs_ctr,
+    default_dynamic_qconfig,
+    float16_dynamic_qconfig,
+    float_qparams_weight_only_qconfig,
+    float_qparams_weight_only_qconfig_4bit,
+    _activation_is_memoryless)
+from torch.nn.utils.parametrize import type_before_parametrizations
+from torch.ao.quantization.observer import _is_activation_post_process
+
+# TODO remove this once BC is no longer required to avoid a SEV
+from torch.ao.quantization.observer import (   # noqa: F401
+    _is_activation_post_process as is_activation_post_process
+)
+
+__all__ = [
+    "get_default_custom_config_dict",
+    "propagate_qconfig_",
+    "add_quant_dequant",
+    "prepare",
+    "quantize",
+    "quantize_dynamic",
+    "prepare_qat",
+    "quantize_qat",
+    "convert",
+    "swap_module",
+]
+
+_DEFAULT_CUSTOM_CONFIG_DICT = {
+    'float_to_observed_custom_module_class': {
+        nn.LSTM: nn.quantizable.LSTM,
+        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
+    },
+    'observed_to_quantized_custom_module_class': {
+        nn.quantizable.LSTM: nn.quantized.LSTM,
+        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
+    }
+}
+
+def get_default_custom_config_dict():
+    r"""Defines the default custom config dict.
+    """
+    return _DEFAULT_CUSTOM_CONFIG_DICT
+
+def _propagate_qconfig_helper(module, qconfig_dict,
+                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
+    r"""This is a helper function for `propagate_qconfig_`
+
+    Args:
+        module: input module
+        qconfig_dict: dictionary that maps from name of submodule to quantization
+                     configuration
+        qconfig_parent: quantization config of the parent module; we will fall back to
+                       this config when there is no specified config for the current
+                       module
+        prefix: corresponding prefix of the current module, used as key in
+                qconfig_dict
+        prepare_custom_config_dict: dictionary for custom handling of modules
+                                    see docs for :func:`~torch.ao.quantization.prepare_fx`
+
+    Return:
+        None, module is modified inplace with qconfig attached
+    """
+
+    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
+    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
+    module_qconfig = getattr(module, 'qconfig', module_qconfig)
+
+    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)
+
+    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
+    module.qconfig = qconfig_with_device_check
+
+    for name, child in module.named_children():
+        module_prefix = prefix + '.' + name if prefix else name
+        # do not propagate qconfig to child if child is non-traceable
+        if prepare_custom_config_dict is None or not (
+            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
+            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
+        ):
+            _propagate_qconfig_helper(
+                child, qconfig_dict, qconfig_with_device_check, module_prefix
+            )
+
+def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
+    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
+    attribute on each leaf module
+
+    Args:
+        module: input module
+        qconfig_dict: dictionary that maps from name or type of submodule to
+            quantization configuration, qconfig applies to all submodules of a
+            given module unless qconfig for the submodules are specified (when
+            the submodule already has qconfig attribute)
+        prepare_custom_config_dict: dictionary for custom handling of modules
+            see docs for :func:`~torch.ao.quantization.prepare_fx`
+
+    Return:
+        None, module is modified inplace with qconfig attached
+    """
+    if qconfig_dict is None:
+        qconfig_dict = {}
+    if prepare_custom_config_dict is None:
+        prepare_custom_config_dict = {}
+    _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
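+
+# Illustrative sketch (hypothetical two-layer model): after propagation, every
+# submodule carries a .qconfig attribute derived from the type-keyed dict below.
+def _example_propagate_qconfig():
+    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
+    propagate_qconfig_(model, qconfig_dict={nn.Linear: default_dynamic_qconfig})
+    # Only the Linear submodule receives a non-None qconfig here.
+    return [getattr(m, 'qconfig', None) for m in model.modules()]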
+
+def _observer_forward_hook(self, input, output):
+    r"""Forward hook that calls observer on the output
+    """
+    return self.activation_post_process(output)
+
+def _observer_forward_pre_hook(self, input):
+    r"""Forward pre hook that calls observer on the output
+    """
+    return self.activation_post_process(input[0])
+
+def _register_activation_post_process_hook(module, pre_hook=False):
+    assert hasattr(module, 'activation_post_process'), \
+        'Expect activation_post_process attribute already attached to the module'
+    if pre_hook:
+        handle = module.register_forward_pre_hook(
+            _observer_forward_pre_hook, prepend=True
+        )
+    else:
+        handle = module.register_forward_hook(
+            _observer_forward_hook, prepend=True
+        )
+
+
+def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
+    r"""Add observer for the leaf child of the module.
+
+    This function inserts observer modules into all leaf child modules that
+    have a valid qconfig attribute.
+
+    Args:
+        module: input module with qconfig attributes for all the leaf modules that we want to quantize
+        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
+            if they are leaf nodes
+        device: parent device, if any
+        non_leaf_module_list: list of non-leaf modules we want to add observer
+
+    Return:
+        None, module is modified inplace with added observer modules and forward_hooks
+    """
+    if qconfig_propagation_list is None:
+        qconfig_propagation_list = get_default_qconfig_propagation_list()
+
+    if custom_module_class_mapping is None:
+        custom_module_class_mapping = {}
+
+    # respect device affinity when adding observers
+    if device is None:
+        devices = _get_unique_devices_(module)
+        assert len(devices) <= 1, (
+            f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}"
+        )
+        device = next(iter(devices)) if len(devices) > 0 else None
+
+    def get_activation_post_process(qconfig, device, special_act_post_process=None):
+        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
+        if device is not None:
+            activation.to(device)
+        return activation
+
+    def needs_observation(m):
+        return hasattr(m, 'qconfig') and m.qconfig is not None
+
+    def insert_activation_post_process(m, special_act_post_process=None):
+        """ Adds an activation post process module and register
+        a pre or post hook that calls the module
+        """
+        # We don't insert observer/fake_quantize for DeQuantStub
+        if needs_observation(m) and not isinstance(m, DeQuantStub):
+            # observer and hook will be gone after we swap the module
+            m.add_module('activation_post_process', get_activation_post_process(
+                m.qconfig, device, special_act_post_process))
+            # Register observer as the first entry in the hook list
+            # All post forward hooks are preserved and will be executed after the observer before convert
+            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
+
+    for name, child in module.named_children():
+        # TODO remove Dropout special after codebase stable
+        if type_before_parametrizations(child) in [nn.Dropout]:
+            continue
+        elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)):
+            if needs_observation(child):
+                assert hasattr(child, "activation_post_process"), (
+                    f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`"
+                )
+                child.activation_post_process = get_activation_post_process(child.qconfig, device)
+        elif isinstance(child, _FusedModule):
+            # activation_post_process are now added directly to nn.Sequential/_FusedModule
+            if needs_observation(child):
+                insert_activation_post_process(child)
+        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
+            if needs_observation(child):
+                insert_activation_post_process(child)
+        elif _has_special_act_post_process(child):
+            special_act_post_process = _get_special_act_post_process(child)
+            insert_activation_post_process(child, special_act_post_process)
+        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
+            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
+            setattr(module, name, observed_child)
+            # TODO: These are the modules that cannot be observed
+            #       Once there are more, we should move them to a separate list
+            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
+                insert_activation_post_process(observed_child)
+        else:
+            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
+
+    # Insert observers only for leaf nodes. Note that this observer is for
+    # the output of the module; inputs are observed by QuantStub
+    if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
+       and type_before_parametrizations(module) in qconfig_propagation_list:
+        insert_activation_post_process(module)
+
+def _get_unique_devices_(module):
+    return {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+
+def add_quant_dequant(module):
+    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
+    Note that this function will modify the children of module inplace and it
+    can return a new module which wraps the input module as well.
+
+    Args:
+        module: input module with qconfig attributes for all the leaf modules
+        that we want to quantize
+
+    Return:
+        Either the inplace modified module with submodules wrapped in
+        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
+        wraps the input module, the latter case only happens when the input
+        module is a leaf module and we want to quantize it.
+    """
+    if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
+        return QuantWrapper(module)
+
+    for name, child in module.named_children():
+        module._modules[name] = add_quant_dequant(child)
+    return module
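+
+# Illustrative sketch (hypothetical qconfig choice): wrapping a leaf module that
+# already carries a qconfig produces a QuantWrapper around it.
+def _example_add_quant_dequant():
+    linear = nn.Linear(4, 4)
+    linear.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+    return add_quant_dequant(linear)   # a QuantWrapper containing linear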
+
+def prepare(model, inplace=False, allow_list=None,
+            observer_non_leaf_module_list=None,
+            prepare_custom_config_dict=None):
+    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
+
+    Quantization configuration should be assigned preemptively
+    to individual submodules in `.qconfig` attribute.
+
+    The model will be attached with observer or fake quant modules, and qconfig
+    will be propagated.
+
+    Args:
+        `model`: input model to be modified in-place
+        `inplace`: carry out model transformations in-place, the original module is mutated
+        `allow_list`: list of quantizable modules
+        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
+        `prepare_custom_config_dict`: customization configuration dictionary for prepare function
+
+    .. code-block:: python
+
+       # Example of prepare_custom_config_dict:
+       prepare_custom_config_dict = {
+           # user will manually define the corresponding observed
+           # module class which has a from_float class method that converts
+           # float custom module to observed custom module
+           "float_to_observed_custom_module_class": {
+               CustomModule: ObservedCustomModule
+           }
+        }
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
+    if prepare_custom_config_dict is None:
+        prepare_custom_config_dict = get_default_custom_config_dict()
+    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    # TODO: remove allow_list
+    qconfig_propagation_list = allow_list
+    if allow_list is None:
+        qconfig_propagation_list = get_default_qconfig_propagation_list()
+    propagate_qconfig_(model, qconfig_dict=None)
+
+    # sanity check common API misusage
+    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
+        warnings.warn("None of the submodule got qconfig applied. Make sure you "
+                      "passed correct configuration through `qconfig_dict` or "
+                      "by assigning the `.qconfig` attribute directly on submodules")
+
+    _add_observer_(
+        model, qconfig_propagation_list, observer_non_leaf_module_list,
+        custom_module_class_mapping=custom_module_class_mapping)
+    return model
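+
+# Illustrative eager-mode sketch (hypothetical model and calibration data): the
+# usual prepare -> calibrate -> convert flow built from the APIs in this module.
+def _example_eager_static_quantization(calibration_batches):
+    model = nn.Sequential(
+        torch.ao.quantization.QuantStub(),
+        nn.Linear(4, 4),
+        torch.ao.quantization.DeQuantStub(),
+    )
+    model.eval()
+    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+    prepare(model, inplace=True)
+    for batch in calibration_batches:   # user-supplied calibration tensors
+        model(batch)
+    return convert(model, inplace=True)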
+
+def _remove_activation_post_process(module):
+    # TODO: maybe we should change activation_post_process to _activation_post_process
+    # to prevent it from being used by user
+    if hasattr(module, 'activation_post_process') and \
+       _is_activation_post_process(module.activation_post_process):
+        delattr(module, 'activation_post_process')
+
+    # remove activation_post_process pre and post hooks
+    def remove_hooks(pre_hook=False):
+        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
+        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
+        handle_ids_to_remove = set()
+        for handle_id, hook_fn in hook_map.items():
+            if hook_fn is observer_hook:
+                handle_ids_to_remove.add(handle_id)
+        for handle_id in handle_ids_to_remove:
+            hook_map.pop(handle_id)
+
+    remove_hooks(pre_hook=True)
+    remove_hooks(pre_hook=False)
+
+# TODO: rename to something more general
+def _remove_qconfig(module):
+    r"""Clean up the qconfig left in the module so that new qconfig can be
+    propagated.
+
+    Args:
+        module: module to be cleaned up
+    """
+    for child in module.children():
+        _remove_qconfig(child)
+
+    if hasattr(module, "qconfig"):
+        del module.qconfig
+
+    _remove_activation_post_process(module)
+
+def quantize(model, run_fn, run_args, mapping=None, inplace=False):
+    r"""Quantize the input float model with post training static quantization.
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn`, which runs the calibration step; after that the model is
+    converted to a quantized model.
+
+    Args:
+        model: input float model
+        run_fn: a calibration function for calibrating the prepared model
+        run_args: positional arguments for `run_fn`
+        inplace: carry out model transformations in-place, the original module is mutated
+        mapping: correspondence between original module types and quantized counterparts
+
+    Return:
+        Quantized model.
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
+    if mapping is None:
+        mapping = get_default_static_quant_module_mappings()
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.eval()
+    prepare(model, inplace=True)
+    run_fn(model, *run_args)
+    convert(model, mapping, inplace=True)
+    return model
+
+def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
+                     mapping=None, inplace=False):
+    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
+
+    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.
+
+    For the simplest usage, provide the `dtype` argument, which can be float16 or qint8. Weight-only quantization
+    is performed by default for layers with large weight sizes - i.e. Linear and RNN variants.
+
+    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
+    If `qconfig` is provided, the `dtype` argument is ignored.
+
+    Args:
+        model: input model
+        qconfig_spec: Either:
+
+            - A dictionary that maps from name or type of submodule to quantization
+              configuration, qconfig applies to all submodules of a given
+              module unless qconfig for the submodules are specified (when the
+              submodule already has qconfig attribute). Entries in the dictionary
+              need to be QConfig instances.
+
+            - A set of types and/or submodule names to apply dynamic quantization to,
+              in which case the `dtype` argument is used to specify the bit-width
+
+        inplace: carry out model transformations in-place, the original module is mutated
+        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
+            with which the submodule needs to be replaced
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
+    if qconfig_spec is None:
+        if dtype == torch.qint8:
+            qconfig_spec = {
+                nn.Linear : default_dynamic_qconfig,
+                nn.LSTM : default_dynamic_qconfig,
+                nn.GRU : default_dynamic_qconfig,
+                nn.LSTMCell : default_dynamic_qconfig,
+                nn.RNNCell : default_dynamic_qconfig,
+                nn.GRUCell : default_dynamic_qconfig,
+            }
+        elif dtype == torch.float16:
+            qconfig_spec = {
+                nn.Linear : float16_dynamic_qconfig,
+                nn.LSTM : float16_dynamic_qconfig,
+                nn.GRU : float16_dynamic_qconfig,
+                nn.LSTMCell : float16_dynamic_qconfig,
+                nn.RNNCell : float16_dynamic_qconfig,
+                nn.GRUCell : float16_dynamic_qconfig,
+            }
+        elif dtype == torch.quint8:
+            qconfig_spec = {
+                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
+                nn.Embedding : float_qparams_weight_only_qconfig,
+            }
+        elif dtype == torch.quint4x2:
+            qconfig_spec = {
+                nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
+            }
+        else:
+            raise ValueError(
+                f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please")
+    elif isinstance(qconfig_spec, set):
+        if dtype is torch.qint8:
+            default_qconfig = default_dynamic_qconfig
+        elif dtype is torch.float16:
+            default_qconfig = float16_dynamic_qconfig
+        elif dtype is torch.quint8:
+            default_qconfig = float_qparams_weight_only_qconfig
+        elif dtype is torch.quint4x2:
+            default_qconfig = float_qparams_weight_only_qconfig_4bit
+        else:
+            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
+        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
+
+    if mapping is None:
+        mapping = get_default_dynamic_quant_module_mappings()
+
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.eval()
+    propagate_qconfig_(model, qconfig_spec)
+    convert(model, mapping, inplace=True)
+    return model
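+
+# Illustrative sketch (hypothetical float model): dynamic weight-only
+# quantization of the Linear layers, selected by type via a set-based
+# qconfig_spec; everything else is left untouched.
+def _example_quantize_dynamic():
+    float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
+    return quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)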
+
+def prepare_qat(model, mapping=None, inplace=False):
+    r"""
+    Prepares a copy of the model for quantization calibration or
+    quantization-aware training and converts it to its fake-quantized version.
+
+    Quantization configuration should be assigned preemptively
+    to individual submodules in `.qconfig` attribute.
+
+    Args:
+        model: input model to be modified in-place
+        mapping: dictionary that maps float modules to quantized modules to be
+                 replaced.
+        inplace: carry out model transformations in-place, the original module
+                 is mutated
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
+    assert model.training, "prepare_qat only works on models in training mode"
+    if mapping is None:
+        mapping = get_default_qat_module_mappings()
+
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    propagate_qconfig_(model, qconfig_dict=None)
+    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
+    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
+    return model
+
+def quantize_qat(model, run_fn, run_args, inplace=False):
+    r"""Do quantization aware training and output a quantized model
+
+    Args:
+        model: input model
+        run_fn: a function for evaluating the prepared model, can be a
+                function that simply runs the prepared model or a training
+                loop
+        run_args: positional arguments for `run_fn`
+
+    Return:
+        Quantized model.
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.train()
+    prepare_qat(model, inplace=True)
+    run_fn(model, *run_args)
+    convert(model, inplace=True)
+    return model
+
+def convert(
+        module, mapping=None, inplace=False, remove_qconfig=True,
+        is_reference=False, convert_custom_config_dict=None):
+    r"""Converts submodules in input module to a different module according to `mapping`
+    by calling `from_float` method on the target module class. And remove qconfig at the
+    end if remove_qconfig is set to True.
+
+    Args:
+        `module`: prepared and calibrated module
+        `mapping`: a dictionary that maps from source module type to target
+                   module type, can be overwritten to allow swapping user defined
+                   Modules
+        `inplace`: carry out model transformations in-place, the original module
+                   is mutated
+        `convert_custom_config_dict`: custom configuration dictionary for convert function
+
+    .. code-block:: python
+
+       # Example of convert_custom_config_dict:
+       convert_custom_config_dict = {
+           # user will manually define the corresponding quantized
+           # module class which has a from_observed class method that converts
+           # observed custom module to quantized custom module
+           "observed_to_quantized_custom_module_class": {
+               ObservedCustomModule: QuantizedCustomModule
+           }
+       }
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.convert")
+    if not inplace:
+        module = copy.deepcopy(module)
+    _convert(
+        module, mapping, inplace=True, is_reference=is_reference,
+        convert_custom_config_dict=convert_custom_config_dict)
+    if remove_qconfig:
+        _remove_qconfig(module)
+    return module
+
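+# A minimal eager-mode static PTQ sketch that ends in convert() (illustrative only;
+# `MyModel` and `calibration_loader` are hypothetical placeholders; a complete example
+# would typically also wrap the model with QuantStub/DeQuantStub):
+#
+#   import torch
+#   from torch.ao.quantization import get_default_qconfig, prepare, convert
+#
+#   model = MyModel().eval()
+#   model.qconfig = get_default_qconfig('fbgemm')
+#   prepare(model, inplace=True)              # insert observers
+#   with torch.no_grad():
+#       for batch in calibration_loader:      # calibrate observers
+#           model(batch)
+#   quantized = convert(model)                # swap observed modules for quantized ones
+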
+def _convert(
+        module, mapping=None, inplace=False,
+        is_reference=False, convert_custom_config_dict=None):
+    r"""Converts submodules in input module to a different module according to `mapping`
+    by calling `from_float` method on the target module class
+
+    Args:
+        module: input module
+        mapping: a dictionary that maps from source module type to target
+                 module type, can be overwritten to allow swapping user defined
+                 Modules
+        inplace: carry out model transformations in-place, the original module
+                 is mutated
+        is_reference: a flag to enable quantized reference module
+
+    """
+    if mapping is None:
+        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
+            else get_default_static_quant_module_mappings()
+    if convert_custom_config_dict is None:
+        convert_custom_config_dict = get_default_custom_config_dict()
+    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
+
+    if not inplace:
+        module = copy.deepcopy(module)
+    reassign = {}
+    for name, mod in module.named_children():
+        # both fused modules and observed custom modules are
+        # swapped as one unit
+        if not isinstance(mod, _FusedModule) and \
+           type_before_parametrizations(mod) not in custom_module_class_mapping:
+            _convert(mod, mapping, True,  # inplace
+                     is_reference, convert_custom_config_dict)
+        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
+
+    for key, value in reassign.items():
+        module._modules[key] = value
+
+    return module
+
+def swap_module(mod, mapping, custom_module_class_mapping):
+    r"""Swaps the module if it has a quantized counterpart and it has an
+    `observer` attached.
+
+    Args:
+        mod: input module
+        mapping: a dictionary that maps from nn module to nnq module
+
+    Return:
+        The corresponding quantized module of `mod`
+    """
+    new_mod = mod
+    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
+        swapped = False
+        if type_before_parametrizations(mod) in custom_module_class_mapping:
+            new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
+            swapped = True
+        elif type_before_parametrizations(mod) in mapping:
+            qmod = mapping[type_before_parametrizations(mod)]
+            if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
+                assert mod.qconfig is not None
+                weight_post_process = mod.qconfig.weight()
+                weight_post_process(mod.weight)
+                weight_qparams = get_qparam_dict(weight_post_process)
+                new_mod = qmod.from_float(mod, weight_qparams)
+            else:
+                new_mod = qmod.from_float(mod)
+            swapped = True
+
+        if swapped:
+            # Preserve module's pre forward hooks. They'll be called on quantized input
+            for pre_hook_fn in mod._forward_pre_hooks.values():
+                new_mod.register_forward_pre_hook(pre_hook_fn)
+            # Preserve module's post forward hooks except _observer_forward_hook
+            # After convert they'll work with quantized output
+            for hook_fn in mod._forward_hooks.values():
+                if hook_fn is not _observer_forward_hook:
+                    new_mod.register_forward_hook(hook_fn)
+
+            # respect device affinity when swapping modules
+            devices = _get_unique_devices_(mod)
+            assert len(devices) <= 1, (
+                f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
+            )
+            device = next(iter(devices)) if len(devices) > 0 else None
+            if device:
+                new_mod.to(device)
+    return new_mod
+
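+# Conceptually, swap_module performs replacements like the following (illustrative
+# sketch; assumes the default static quant mapping, where nn.Linear resolves to
+# torch.ao.nn.quantized.Linear):
+#
+#   import torch
+#   import torch.ao.nn.quantized as nnq
+#   from torch.ao.quantization import get_default_qconfig, prepare
+#
+#   model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
+#   model.qconfig = get_default_qconfig('fbgemm')
+#   prepare(model, inplace=True)
+#   model(torch.randn(1, 4))                    # run once so observers see data
+#   q_linear = nnq.Linear.from_float(model[0])  # what swap_module calls via the mapping
+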
+def _get_observer_dict(mod, target_dict, prefix=""):
+    r"""Traverse the modules and save all observers into dict.
+    This is mainly used for quantization accuracy debug
+    Args:
+        mod: the top module we want to save all observers
+        prefix: the prefix for the current module
+        target_dict: the dictionary used to save all the observers
+    """
+    def get_prefix(prefix):
+        return prefix if prefix == "" else prefix + '.'
+
+    if hasattr(mod, 'activation_post_process'):
+        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
+    for name, child in mod.named_children():
+        module_prefix = get_prefix(prefix) + name if prefix else name
+        _get_observer_dict(child, target_dict, module_prefix)
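+
+# A small debugging sketch for _get_observer_dict (illustrative only; `prepared_model`
+# is assumed to be a model that already went through prepare() and calibration):
+#
+#   observers = {}
+#   _get_observer_dict(prepared_model, observers)
+#   for name, obs in observers.items():
+#       print(name, obs.calculate_qparams())   # inspect per-observer scale/zero_point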
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantize_fx.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..b42fcee88c32e8e357ccf9e961e7774bfd52fe98
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_fx.py
@@ -0,0 +1,726 @@
+from typing import Any, Dict, Optional, Tuple, Union
+import warnings
+
+import torch
+import copy
+from torch.fx import GraphModule
+from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
+from .fx.tracer import QuantizationTracer
+from .fx.tracer import (  # noqa: F401
+    Scope,
+    ScopeContextManager
+)
+from .fx.fuse import fuse  # noqa: F401
+from .fx.prepare import prepare  # noqa: F401
+from .fx.convert import convert
+from .backend_config import (  # noqa: F401
+    BackendConfig,
+    get_tensorrt_backend_config,
+)
+from .fx.graph_module import ObservedGraphModule  # noqa: F401
+from .fx.custom_config import (
+    ConvertCustomConfig,
+    FuseCustomConfig,
+    PrepareCustomConfig,
+)
+from .fx.utils import get_custom_module_class_keys  # noqa: F401
+from .fx.utils import get_skipped_module_name_and_classes
+from .qconfig_mapping import QConfigMapping
+
+def attach_preserved_attrs_to_model(
+    model: Union[GraphModule, torch.nn.Module],
+    preserved_attrs: Dict[str, Any],
+) -> None:
+    """ Store preserved attributes to the model.meta so that it can be preserved during deepcopy
+    """
+    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
+    # set the preserved attributes in the model so that user can call
+    # model.attr as they do before calling fx graph mode quantization
+    for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():  # type: ignore[index, union-attr]
+        setattr(model, attr_name, attr)
+
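+# How preserved attributes typically flow through the FX APIs (illustrative sketch;
+# `float_model` and `example_inputs` are hypothetical placeholders, and `version` is
+# just an example attribute name):
+#
+#   from torch.ao.quantization import get_default_qconfig_mapping
+#   from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
+#   from torch.ao.quantization.quantize_fx import prepare_fx
+#
+#   prepare_custom_config = PrepareCustomConfig().set_preserved_attributes(["version"])
+#   prepared = prepare_fx(
+#       float_model,
+#       get_default_qconfig_mapping("fbgemm"),
+#       example_inputs,
+#       prepare_custom_config=prepare_custom_config,
+#   )
+#   assert prepared.version == float_model.version   # re-attached via model.meta
+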
+def _check_is_graph_module(model: torch.nn.Module) -> None:
+    if not isinstance(model, GraphModule):
+        raise ValueError(
+            "input model must be a GraphModule, "
+            + "Got type:"
+            + str(type(model))
+            + " Please make "
+            + "sure to follow the tutorials."
+        )
+
+def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
+    """ Attach meta field to all nodes of the graph if it does not exist,
+    meta field is a field stores some meta information about the node, such
+    as dtype and shape information for output of the node, this only exists
+    if the program is captured by make_fx (used in quantize_pt2e flow), if
+    the program is captured by torch.fx symbolic tracing, this field may not exist,
+    so we add it here to avoid checking this all over the places
+    """
+    for node in model.graph.nodes:
+        if not hasattr(node, "meta"):
+            node.meta = {}
+
+def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
+    r""" Swap FloatFunctional with FXFloatFunctional
+    """
+    modules_to_swap = []
+    for name, module in model.named_children():
+        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
+            modules_to_swap.append(name)
+        else:
+            _swap_ff_with_fxff(module)
+
+    for name in modules_to_swap:
+        del model._modules[name]
+        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
+
+
+def _fuse_fx(
+    model: GraphModule,
+    is_qat: bool,
+    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Internal helper function to fuse modules in preparation for quantization
+
+    Args:
+        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
+    """
+    _check_is_graph_module(model)
+    return fuse(
+        model, is_qat, fuse_custom_config, backend_config)  # type: ignore[operator]
+
+def _prepare_fx(
+    model: torch.nn.Module,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
+    is_qat: bool,
+    example_inputs: Tuple[Any, ...],
+    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
+    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    is_standalone_module: bool = False,
+) -> GraphModule:
+    r""" Internal helper function for prepare_fx
+    Args:
+      `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
+      see docs for :func:`~torch.ao.quantization.prepare_fx`
+      `is_standalone_module`: a boolean flag indicating whether we are
+      quantizing a standalone module. A standalone module is a submodule
+      of the parent module that is not inlined in the forward graph of
+      the parent module; the way we quantize a standalone module is
+      described in :func:`~torch.ao.quantization._prepare_standalone_module_fx`
+    """
+    if prepare_custom_config is None:
+        prepare_custom_config = PrepareCustomConfig()
+    if _equalization_config is None:
+        _equalization_config = QConfigMapping()
+
+    if isinstance(prepare_custom_config, Dict):
+        warnings.warn(
+            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a PrepareCustomConfig instead.")
+        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
+
+    # swap FloatFunctional with FXFloatFunctional
+    _swap_ff_with_fxff(model)
+
+    skipped_module_names, skipped_module_classes = \
+        get_skipped_module_name_and_classes(prepare_custom_config, is_standalone_module)
+    preserved_attr_names = prepare_custom_config.preserved_attributes
+    preserved_attrs = {attr: getattr(model, attr) for attr in preserved_attr_names if hasattr(model, attr)}
+    # symbolically trace the model
+    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
+    graph_module = GraphModule(model, tracer.trace(model))
+    _attach_meta_to_node_if_not_exist(graph_module)
+
+    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(prepare_custom_config.preserved_attributes)
+    graph_module = _fuse_fx(
+        graph_module,
+        is_qat,
+        fuse_custom_config,
+        backend_config)
+    prepared = prepare(
+        graph_module,
+        qconfig_mapping,
+        is_qat,
+        tracer.node_name_to_scope,
+        example_inputs=example_inputs,
+        prepare_custom_config=prepare_custom_config,
+        _equalization_config=_equalization_config,
+        backend_config=backend_config,
+        is_standalone_module=is_standalone_module,
+    )  # type: ignore[operator]
+
+    attach_preserved_attrs_to_model(prepared, preserved_attrs)
+    return prepared
+
+
+def _prepare_standalone_module_fx(
+    model: torch.nn.Module,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
+    is_qat: bool,
+    example_inputs: Tuple[Any, ...],
+    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
+    parent module.
+    standalone_module means it is a submodule that is not inlined in the parent module,
+    and will be quantized separately as one unit.
+
+    How the standalone module is observed is specified by `input_quantized_idxs` and
+    `output_quantized_idxs` in the prepare_custom_config for the standalone module
+
+    Returns:
+
+        * model(GraphModule): prepared standalone module. It has these attributes in
+          model.meta:
+
+            * `standalone_module_input_quantized_idxs(List[Int])`: a list of
+              indexes for the graph inputs that are expected to be quantized,
+              same as the input_quantized_idxs configuration provided
+              for the standalone module
+            * `standalone_module_output_quantized_idxs(List[Int])`: a list of
+              indexes for the graph outputs that are quantized,
+              same as the output_quantized_idxs configuration provided
+              for the standalone module
+
+    """
+    return _prepare_fx(
+        model,
+        qconfig_mapping,
+        is_qat,
+        example_inputs,
+        prepare_custom_config,
+        backend_config=backend_config,
+        is_standalone_module=True,
+    )
+
+
+def fuse_fx(
+    model: torch.nn.Module,
+    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
+    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py
+
+    Args:
+
+        * `model` (torch.nn.Module): a torch.nn.Module model
+        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
+            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details
+    Example::
+
+        from torch.ao.quantization import fuse_fx
+        m = Model().eval()
+        m = fuse_fx(m)
+
+    """
+    if fuse_custom_config is None:
+        fuse_custom_config = FuseCustomConfig()
+
+    if isinstance(fuse_custom_config, Dict):
+        warnings.warn(
+            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
+            "in a future version. Please pass in a FuseCustomConfig instead.")
+        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
+
+    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
+    preserved_attr_names = fuse_custom_config.preserved_attributes
+    preserved_attrs = {attr: getattr(model, attr) for attr in preserved_attr_names if hasattr(model, attr)}
+
+    graph_module = torch.fx.symbolic_trace(model)
+    _attach_meta_to_node_if_not_exist(graph_module)
+    graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)
+
+    attach_preserved_attrs_to_model(graph_module, preserved_attrs)
+    return graph_module
+
+def prepare_fx(
+    model: torch.nn.Module,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
+    example_inputs: Tuple[Any, ...],
+    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
+    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Prepare a model for post training quantization
+
+    Args:
+      * `model` (torch.nn.Module): torch.nn.Module model
+
+      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
+         quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
+         for more details
+
+      * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
+         Tuple of positional args (keyword args can be passed as positional args as well)
+
+      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
+          See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details
+
+      * `_equalization_config`: config for specifying how to perform equalization on the model
+
+      * `backend_config` (BackendConfig): config that specifies how operators are quantized
+         in a backend, this includes how the operators are observed,
+         supported fusion patterns, how quantize/dequantize ops are
+         inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
+
+    Return:
+      A GraphModule with observer (configured by qconfig_mapping), ready for calibration
+
+    Example::
+
+        import torch
+        from torch.ao.quantization import get_default_qconfig_mapping
+        from torch.ao.quantization.quantize_fx import prepare_fx
+
+        class Submodule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.sub = Submodule()
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.sub(x) + x
+                return x
+
+        # initialize a floating point model
+        float_model = M().eval()
+
+        # define calibration function
+        def calibrate(model, data_loader):
+            model.eval()
+            with torch.no_grad():
+                for image, target in data_loader:
+                    model(image)
+
+        # qconfig is the configuration for how we insert observers for a particular
+        # operator
+        # qconfig = get_default_qconfig("fbgemm")
+        # Example of customizing qconfig:
+        # qconfig = torch.ao.quantization.QConfig(
+        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
+        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
+        # `activation` and `weight` are constructors of observer module
+
+        # qconfig_mapping is a collection of quantization configurations, user can
+        # set the qconfig for each operator (torch op calls, functional calls, module calls)
+        # in the model through qconfig_mapping
+        # the following call will get the qconfig_mapping that works best for models
+        # that target "fbgemm" backend
+        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
+
+        # We can customize qconfig_mapping in different ways.
+        # e.g. set the global qconfig, which means we will use the same qconfig for
+        # all operators in the model, this can be overwritten by other settings
+        # qconfig_mapping = QConfigMapping().set_global(qconfig)
+        # e.g. quantize the linear submodule with a specific qconfig
+        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
+        # e.g. quantize all nn.Linear modules with a specific qconfig
+        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
+        # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping`
+        # argument
+
+        # example_inputs is a tuple of inputs, that is used to infer the type of the
+        # outputs in the model
+        # currently it's not used, but please make sure model(*example_inputs) runs
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
+        # backend_config. If the configuration for an operator in qconfig_mapping
+        # is supported in the backend_config (meaning it's supported by the target
+        # hardware), we'll insert observer modules according to the qconfig_mapping
+        # otherwise the configuration in qconfig_mapping will be ignored
+        #
+        # Example:
+        # in qconfig_mapping, user sets linear module to be quantized with quint8 for
+        # activation and qint8 for weight:
+        # qconfig = torch.ao.quantization.QConfig(
+        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
+        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
+        # Note: current qconfig api does not support setting output observer, but
+        # we may extend this to support these more fine grained control in the
+        # future
+        #
+        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
+        # in the backend config, the linear module also supports this configuration:
+        # weighted_int8_dtype_config = DTypeConfig(
+        #   input_dtype=torch.quint8,
+        #   output_dtype=torch.quint8,
+        #   weight_dtype=torch.qint8,
+        #   bias_dtype=torch.float)
+
+        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
+        #    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+        #    .add_dtype_config(weighted_int8_dtype_config) \
+        #    ...
+
+        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
+        # `prepare_fx` will check that the setting requested by the user in qconfig_mapping
+        # is supported by the backend_config and insert observers and fake quant modules
+        # in the model
+        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
+        # Run calibration
+        calibrate(prepared_model, sample_inference_data)
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
+    return _prepare_fx(
+        model,
+        qconfig_mapping,
+        False,  # is_qat
+        example_inputs,
+        prepare_custom_config,
+        _equalization_config,
+        backend_config,
+    )
+
+
+def prepare_qat_fx(
+    model: torch.nn.Module,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
+    example_inputs: Tuple[Any, ...],
+    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Prepare a model for quantization aware training
+
+    Args:
+      * `model` (torch.nn.Module): torch.nn.Module model
+      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
+      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
+      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
+      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`
+
+    Return:
+      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
+      quantization aware training
+
+    Example::
+
+        import torch
+        from torch.ao.quantization import get_default_qat_qconfig_mapping
+        from torch.ao.quantization.quantize_fx import prepare_qat_fx
+
+        class Submodule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.sub = Submodule()
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.sub(x) + x
+                return x
+
+        # initialize a floating point model
+        float_model = M().train()
+        # (optional, but preferred) load the weights from pretrained model
+        # float_model.load_weights(...)
+
+        # define the training loop for quantization aware training
+        def train_loop(model, train_data):
+            model.train()
+            for image, target in train_data:
+                ...
+
+        # qconfig is the configuration for how we insert observers for a particular
+        # operator
+        # qconfig = get_default_qconfig("fbgemm")
+        # Example of customizing qconfig:
+        # qconfig = torch.ao.quantization.QConfig(
+        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
+        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
+        # `activation` and `weight` are constructors of observer module
+
+        # qconfig_mapping is a collection of quantization configurations, user can
+        # set the qconfig for each operator (torch op calls, functional calls, module calls)
+        # in the model through qconfig_mapping
+        # the following call will get the qconfig_mapping that works best for models
+        # that target "fbgemm" backend
+        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
+
+        # We can customize qconfig_mapping in different ways, please take a look at
+        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
+        # to configure this
+
+        # example_inputs is a tuple of inputs, that is used to infer the type of the
+        # outputs in the model
+        # currently it's not used, but please make sure model(*example_inputs) runs
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
+        # backend_config, if the configuration for an operator in qconfig_mapping
+        # is supported in the backend_config (meaning it's supported by the target
+        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping
+        # otherwise the configuration in qconfig_mapping will be ignored
+        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
+        # how qconfig_mapping interacts with backend_config
+        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
+        # Run training
+        train_loop(prepared_model, train_data)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
+    return _prepare_fx(
+        model,
+        qconfig_mapping,
+        True,  # is_qat
+        example_inputs,
+        prepare_custom_config,
+        backend_config=backend_config,
+    )
+
+
+def _convert_fx(
+    graph_module: GraphModule,
+    is_reference: bool,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    is_standalone_module: bool = False,
+    _remove_qconfig: bool = True,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    is_decomposed: bool = False,
+) -> GraphModule:
+    """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
+    """
+    if convert_custom_config is None:
+        convert_custom_config = ConvertCustomConfig()
+
+    if isinstance(convert_custom_config, Dict):
+        warnings.warn(
+            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
+            "in a future version. Please pass in a ConvertCustomConfig instead.")
+        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
+
+    _check_is_graph_module(graph_module)
+    preserved_attr_names = convert_custom_config.preserved_attributes
+    preserved_attrs = {attr: getattr(graph_module, attr) for attr in preserved_attr_names if hasattr(graph_module, attr)}
+
+    quantized = convert(
+        graph_module,
+        is_reference,
+        convert_custom_config,
+        is_standalone_module,
+        _remove_qconfig_flag=_remove_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+        is_decomposed=is_decomposed,
+    )
+
+    attach_preserved_attrs_to_model(quantized, preserved_attrs)
+    return quantized
+
+
+def convert_fx(
+    graph_module: GraphModule,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    _remove_qconfig: bool = True,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Convert a calibrated or trained model to a quantized model
+
+    Args:
+        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)
+
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details
+
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
+
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+
+           The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
+           with the same values or `None`. Additional keys can be specified with values set to `None`.
+
+          For each entry whose value is set to None, we skip quantizing that entry in the model::
+
+            qconfig_mapping = (QConfigMapping()
+                .set_global(qconfig_from_prepare)
+                .set_object_type(torch.add, None)  # skip quantizing torch.add
+                .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
+                .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
+            )
+
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
+            operators should be quantized in the backend, this includes quantization
+            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
+            observer placement for each operators and fused operators.
+            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
+
+    Return:
+        A quantized model (torch.nn.Module)
+
+    Example::
+
+        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
+        # convert_fx converts a calibrated/trained model to a quantized model for the
+        # target hardware, this includes converting the model first to a reference
+        # quantized model, and then lowering the reference quantized model to a backend.
+        # Currently, the supported backends are fbgemm (onednn) and qnnpack (xnnpack);
+        # they share the same set of quantized operators, so we use the same
+        # lowering procedure
+        #
+        # backend_config defines the corresponding reference quantized module for
+        # the weighted modules in the model, e.g. nn.Linear
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        quantized_model = convert_fx(prepared_model)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
+    return _convert_fx(
+        graph_module,
+        is_reference=False,
+        convert_custom_config=convert_custom_config,
+        _remove_qconfig=_remove_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+    )
+
+
+def convert_to_reference_fx(
+    graph_module: GraphModule,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    _remove_qconfig: bool = True,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Convert a calibrated or trained model to a reference quantized model,
+    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
+    reference quantized model is a standard representation of a quantized model provided
+    by FX Graph Mode Quantization, it can be further lowered to run on the target
+    hardware, like accelerators
+
+    Args:
+        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
+
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
+
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
+            operators should be quantized in the backend. See
+            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+    Return:
+        A reference quantized model (GraphModule)
+
+    Example::
+
+        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        reference_quantized_model = convert_to_reference_fx(prepared_model)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
+    return _convert_fx(
+        graph_module,
+        is_reference=True,
+        convert_custom_config=convert_custom_config,
+        _remove_qconfig=_remove_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+    )
+
+def _convert_to_reference_decomposed_fx(
+    graph_module: GraphModule,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" Convert a calibrated or trained model to a reference quantized model, with
+    decomposed representation for quantized Tensor
+    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
+    reference quantized model is a standard representation of a quantized model provided
+    by FX Graph Mode Quantization, it can be further lowered to run on the target
+    hardware, like accelerators
+
+    Note: this is not public API
+
+    Args:
+        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
+
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
+            operators should be quantized in the backend. See
+            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
+
+    Return:
+        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor
+
+    Example::
+
+        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
+        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
+        # e.g. backend_config = get_default_backend_config("fbgemm")
+        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
+    return _convert_fx(
+        graph_module,
+        is_reference=True,
+        convert_custom_config=convert_custom_config,
+        _remove_qconfig=False,
+        qconfig_mapping=qconfig_mapping,
+        backend_config=backend_config,
+        is_decomposed=True,
+    )
+
+
+def _convert_standalone_module_fx(
+    graph_module: GraphModule,
+    is_reference: bool = False,
+    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
+) -> GraphModule:
+    r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
+    and convert it to a quantized model
+
+    Returns a quantized standalone module, whether input/output is quantized is
+    specified by prepare_custom_config, with
+    input_quantized_idxs, output_quantized_idxs, please
+    see docs for prepare_fx for details
+    """
+    return _convert_fx(
+        graph_module,
+        is_reference,
+        convert_custom_config,
+        is_standalone_module=True,
+    )
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantize_jit.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3184b8f963133f57e76121ffa7656eb1b2f6af2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_jit.py
@@ -0,0 +1,335 @@
+
+import torch
+from torch.ao.quantization.qconfig import QConfig
+from torch.ao.quantization.quant_type import QuantType
+from torch.jit._recursive import wrap_cpp_module
+
+__all__ = [
+    "script_qconfig",
+    "script_qconfig_dict",
+    "fuse_conv_bn_jit",
+    "prepare_jit",
+    "prepare_dynamic_jit",
+    "convert_jit",
+    "convert_dynamic_jit",
+    "quantize_jit",
+    "quantize_dynamic_jit",
+]
+
+def _check_is_script_module(model):
+    if not isinstance(model, torch.jit.ScriptModule):
+        raise ValueError('input must be a script module, got: ' + str(type(model)))
+
+def _check_forward_method(model):
+    if not model._c._has_method('forward'):
+        raise ValueError('input script module does not have forward method')
+
+def script_qconfig(qconfig):
+    r"""Instantiate the activation and weight observer modules and script
+    them; these observer module instances will be deepcopied during the
+    prepare_jit step.
+    """
+    return QConfig(
+        activation=torch.jit.script(qconfig.activation())._c,
+        weight=torch.jit.script(qconfig.weight())._c)
+
+def script_qconfig_dict(qconfig_dict):
+    r"""Helper function used by `prepare_jit`.
+    Apply `script_qconfig` to all entries in `qconfig_dict` that are
+    not None.
+    """
+    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
+
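+# A small sketch of how these helpers are used ahead of prepare_jit (illustrative only):
+#
+#   from torch.ao.quantization import get_default_qconfig
+#
+#   qconfig_dict = {'': get_default_qconfig('fbgemm')}
+#   scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
+#   # each value now holds scripted activation/weight observers, ready to be
+#   # deep-copied into the module during prepare_jit
+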
+def fuse_conv_bn_jit(model, inplace=False):
+    r""" Fuse conv - bn module
+    Works for eval model only.
+
+    Args:
+        model: TorchScript model from scripting or tracing
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit")
+    model_c = model._c
+    model_c = torch._C._jit_pass_fold_convbn(model_c)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    return model
+
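+# Minimal usage sketch for fuse_conv_bn_jit (illustrative only; `float_model` is a
+# hypothetical eval-mode model containing conv/bn pairs):
+#
+#   import torch
+#
+#   ts_model = torch.jit.script(float_model.eval())
+#   fused_model = fuse_conv_bn_jit(ts_model)   # returns a new module unless inplace=True
+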
+def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
+    _check_is_script_module(model)
+    _check_forward_method(model)
+    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
+        raise ValueError('qconfig_dict should only contain names(str) as keys.')
+    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
+    model = fuse_conv_bn_jit(model, inplace)
+    model_c = torch._C._jit_pass_insert_observers(model._c,
+                                                  'forward',
+                                                  scripted_qconfig_dict,
+                                                  inplace,
+                                                  quant_type)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    return model
+
+def _prepare_ondevice_jit(model, qconfig_dict, method_name='forward', inplace=False, quant_type=QuantType.STATIC):
+    _check_is_script_module(model)
+    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
+        raise ValueError('qconfig_dict should only contain names(str) as keys.')
+    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
+    method_graph = model._c._get_method(method_name).graph
+    torch._C._jit_pass_inline(method_graph)
+    model = fuse_conv_bn_jit(model, inplace)
+    model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(model._c,
+                                                                         method_name,
+                                                                         scripted_qconfig_dict,
+                                                                         inplace,
+                                                                         quant_type)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    return model
+
+def prepare_jit(model, qconfig_dict, inplace=False):
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit")
+    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)
+
+def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit")
+    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
+
+
+def _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
+    return _prepare_ondevice_jit(model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC)
+
+def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
+                 preserved_attrs=None):
+    _check_is_script_module(model)
+    model.eval()
+    model_c = model._c
+    model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
+    if not debug:
+        is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
+        if not is_xpu:
+            # Moving model parameters to CPU since quantized operators
+            # are only supported on CPU and XPU right now
+            model.cpu()
+        if preserved_attrs is None:
+            preserved_attrs = []
+        model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    torch._C._jit_pass_constant_propagation(model.graph)
+    torch._C._jit_pass_dce(model.graph)
+    return model
+
+
+def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
+    _check_is_script_module(model)
+    assert quant_type == QuantType.DYNAMIC, "This API, while it should work for static quant, is only tested for dynamic quant."
+    assert not method_name.startswith("observe_"), "Pass in valid method to be quantized, e.g. forward"
+    observe_method_name = "observe_" + method_name
+    quantize_method_name = "quantize_" + method_name
+    model_c = model._c
+    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
+        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
+    model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(model_c, QuantType.DYNAMIC, quantize_method_name)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    return model
+
+def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
+    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)
+
+def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None):
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
+    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)
+
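+# The prepare_*/convert_* pairs above can also be driven manually instead of going
+# through quantize_jit/quantize_dynamic_jit (illustrative sketch; `float_model` and
+# `data_loader` are hypothetical placeholders):
+#
+#   import torch
+#   from torch.ao.quantization import get_default_qconfig
+#
+#   ts_model = torch.jit.script(float_model.eval())
+#   prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
+#   with torch.no_grad():
+#       for image, _ in data_loader:
+#           prepared(image)                    # calibration
+#   quantized = convert_jit(prepared)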
+
+def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
+    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)
+
+
+def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
+    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
+    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
+    return model
+
+def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
+    # Always do inplace convert because the Tensor is already
+    # copied in prepare_jit when inplace is False
+    if quant_type == QuantType.DYNAMIC:
+        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
+        model = convert_dynamic_jit(model, True, debug)
+    else:
+        assert run_fn, "Must provide calibration function for post training static quantization"
+        assert run_args, "Must provide calibration dataset for post training static quantization"
+        model = prepare_jit(model, qconfig_dict, inplace)
+        run_fn(model, *run_args)
+        model = convert_jit(model, True, debug)
+
+    torch._C._jit_pass_constant_propagation(model.graph)
+    torch._C._jit_pass_dce(model.graph)
+    return model
+
+def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training static quantization.
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: a dictionary with names of submodules as keys and the
+        qconfig for each module as values. An empty key means the qconfig will be applied
+        to the whole model unless it is overwritten by more specific configurations; the
+        qconfig for each module is either found in the dictionary or falls back to
+        the qconfig of the parent module.
+
+        Right now qconfig_dict is the only way to configure how the model is quantized,
+        and it is done at module granularity, that is, we only support one type
+        of qconfig for each torch.nn.Module, and the qconfig for a submodule will
+        override the qconfig for the parent module; an empty string means global configuration.
+        `run_fn`: a calibration function for calibrating the prepared model
+        `run_args`: positional arguments for `run_fn`
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization import quantize_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = get_default_qconfig('fbgemm')
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+
+    quantized_model = quantize_jit(
+        ts_model,
+        {'': qconfig},
+        calibrate,
+        [data_loader_test])
+    ```
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
+    return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)
+
+def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: a dictionary with names of submodules as keys and the
+        qconfig for each module as values, please see detailed
+        descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import per_channel_dynamic_qconfig
+    from torch.ao.quantization import quantize_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = per_channel_dynamic_qconfig
+
+    quantized_model = quantize_dynamic_jit(
+        ts_model,
+        {'': qconfig})
+    ```
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
+    return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
+
+
+def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
+    r"""Prepares the input float TorchScript model with
+    *on-device* post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: a dictionary with names of submodules as keys and the
+        qconfig for each module as values, please see detailed
+        descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `method_name`: name of the method within the model to be prepared for quantization
+        `inplace`: carry out model transformations in-place, the original module is
+        mutated
+
+    Return:
+        TorchScript model that is ready for on-device quantization.
+        This means that the returned model has:
+        - The method inlined.
+        - Observer modules inserted in the model.
+        - Packed params inserted in the model. However, they are empty, i.e. they don't
+          contain valid quantized weights.
+        - An observe_ method added that observes the values to be quantized.
+        - A reset_observers_ method added to reset observers.
+        - A quantize_ method added to the model.
+          - This method extracts scales and zero points.
+          - Quantizes observed weights.
+          - Creates packed params from them and updates the attributes of the model with the
+            new values for the packed params.
+          - Resets the original fp32 weights with an empty tensor using SetAttr.
+        - A quantized_ method added to the model.
+          - This method uses quantized weights and quantized linear ops instead of fp32 ops.
+          - This method should be used for inference post PTQ.
+        - Note that all methods' signatures should be the same as method_name.
+
+        Later on device:
+        - Run reset_observers_
+        - Run observe_
+        - Run quantize_
+        - Now model can be saved and loaded later.
+        - Run model with quantized_
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import per_channel_dynamic_qconfig
+    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = per_channel_dynamic_qconfig
+    quant_ready_model = _quantize_ondevice_dynamic_jit(
+        ts_model,
+        {'': qconfig},
+        'forward',
+        True)
+    ```
+    """
+    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..33267f8d81a9757aff3852d2444c9c8af6fed42a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantize_pt2e.py
@@ -0,0 +1,250 @@
+import torch
+from torch.fx import GraphModule
+from torch.fx import Node
+
+from .pt2e.prepare import prepare
+from .pt2e.qat_utils import (
+    _fuse_conv_bn_qat,
+    _fold_conv_bn_qat,
+)
+from .pt2e.utils import (
+    _get_node_name_to_scope,
+    _fuse_conv_bn_,
+    _disallow_eval_train,
+)
+from .pt2e.representation import reference_representation_rewrite
+from .quantize_fx import _convert_to_reference_decomposed_fx
+from torch.ao.quantization.quantizer import (  # noqa: F401
+    Quantizer,
+    QuantizationSpecBase,
+    QuantizationSpec,
+    FixedQParamsQuantizationSpec,
+    SharedQuantizationSpec,
+    DerivedQuantizationSpec,
+    QuantizationAnnotation,
+)
+from torch.fx.passes.infra.pass_manager import PassManager
+from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
+from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
+from torch._inductor.constant_folding import constant_fold
+
+__all__ = [
+    "prepare_pt2e",
+    "prepare_qat_pt2e",
+    "convert_pt2e",
+]
+
+
+def prepare_pt2e(
+    model: GraphModule,
+    quantizer: Quantizer,
+) -> GraphModule:
+    """Prepare a model for post training quantization
+
+    Args:
+      * `model` (torch.fx.GraphModule): a model captured by `torch.export` API
+        in the short term we are using `torch._export.capture_pre_autograd_graph`,
+        in the long term we'll migrate to some `torch.export` API
+      * `quantizer`: A backend specific quantizer that conveys how user want the
+        model to be quantized. Tutorial for how to write a quantizer can be found here:
+        https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
+
+    Return:
+      A GraphModule with observer (based on quantizer annotation), ready for calibration
+
+    Example::
+
+        import torch
+        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+        from torch._export import capture_pre_autograd_graph
+        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+            XNNPACKQuantizer,
+            get_symmetric_quantization_config,
+        )
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # initialize a floating point model
+        float_model = M().eval()
+
+        # define calibration function
+        def calibrate(model, data_loader):
+            model.eval()
+            with torch.no_grad():
+                for image, target in data_loader:
+                    model(image)
+
+        # Step 1. program capture
+        # NOTE: this API will be updated to torch.export API in the future, but the captured
+        # result should mostly stay the same
+        example_inputs = (torch.randn(1, 5),)
+        m = capture_pre_autograd_graph(float_model, example_inputs)
+        # we get a model with aten ops
+
+        # Step 2. quantization
+        # backend developers will write their own Quantizer and expose methods to
+        # allow users to express how they want the model to be quantized
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+        m = prepare_pt2e(m, quantizer)
+
+        # run calibration
+        # calibrate(m, sample_inference_data)
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
+    original_graph_meta = model.meta
+    node_name_to_scope = _get_node_name_to_scope(model)
+    # TODO: check qconfig_mapping to make sure conv and bn are both configured
+    # to be quantized before fusion
+    # TODO: (maybe) rewrite this with subgraph_rewriter
+    _fuse_conv_bn_(model)
+    quantizer.transform_for_annotation(model)
+    quantizer.annotate(model)
+    quantizer.validate(model)
+    model = prepare(model, node_name_to_scope, is_qat=False)
+    model.meta.update(original_graph_meta)
+    model = _disallow_eval_train(model)
+    return model
+
+def prepare_qat_pt2e(
+    model: GraphModule,
+    quantizer: Quantizer,
+) -> GraphModule:
+    """Prepare a model for quantization aware training
+
+    Args:
+      * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
+      * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
+
+    Return:
+      A GraphModule with fake quant modules (based on the quantizer annotation), ready for
+      quantization aware training
+
+    Example::
+
+        import torch
+        from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
+        from torch._export import capture_pre_autograd_graph
+        from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+            XNNPACKQuantizer,
+            get_symmetric_quantization_config,
+        )
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # initialize a floating point model
+        float_model = M().eval()
+
+        # define the training loop for quantization aware training
+        def train_loop(model, train_data):
+            model.train()
+            for image, target in train_data:
+                ...
+
+        # Step 1. program capture
+        # NOTE: this API will be updated to the torch.export API in the future, but the captured
+        # result should mostly stay the same
+        example_inputs = (torch.randn(1, 5),)
+        m = capture_pre_autograd_graph(float_model, *example_inputs)
+        # we get a model with aten ops
+
+        # Step 2. quantization
+        # backend developers will write their own Quantizer and expose methods to
+        # allow users to express how they want the model to be quantized
+        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+        m = prepare_qat_pt2e(m, quantizer)
+
+        # run quantization aware training
+        # train_loop(m, train_data)
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
+    original_graph_meta = model.meta
+    node_name_to_scope = _get_node_name_to_scope(model)
+    quantizer.transform_for_annotation(model)
+    quantizer.annotate(model)
+    quantizer.validate(model)
+    # Perform fusion after annotate to avoid quantizing ops in the new
+    # subgraph that don't need to be quantized
+    # TODO: only fuse if conv and bn are both configured to be quantized
+    _fuse_conv_bn_qat(model)
+    model = prepare(model, node_name_to_scope, is_qat=True)
+    model.meta.update(original_graph_meta)
+    model = _disallow_eval_train(model)
+    return model
+
+_QUANT_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+def _quant_node_constraint(n: Node) -> bool:
+    """If there is any pure ops between get_attr and quantize op they will be const propagated
+    e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
+    (Note: dequantize op is not going to be constant propagated)
+
+    This filter is added because we don't want to constant fold the things that are not
+    related to quantization
+    """
+    return n.op == "call_function" and n.target in _QUANT_OPS
+
+def convert_pt2e(
+    model: GraphModule,
+    use_reference_representation: bool = False,
+    fold_quantize: bool = True,
+) -> GraphModule:
+    """Convert a calibrated/trained model to a quantized model
+
+    Args:
+      * `model` (torch.fx.GraphModule): calibrated/trained model
+      * `use_reference_representation` (bool): whether to produce the reference representation or not
+      * `fold_quantize` (bool): whether to fold (constant-propagate) the quantize op or not
+
+    Returns:
+        quantized model, either in q/dq representation or reference representation
+
+    Example::
+
+        # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
+        # `convert_pt2e` produces a quantized model that represents quantized computation with
+        # quantize/dequantize ops and fp32 ops by default.
+        # Please refer to
+        # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
+        # for detailed explanation of output quantized model
+        quantized_model = convert_pt2e(prepared_model)
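+
+        # Illustrative sketch: the same call with both documented flags flipped,
+        # producing the reference representation and keeping the quantize ops unfolded
+        quantized_model_ref = convert_pt2e(
+            prepared_model,
+            use_reference_representation=True,
+            fold_quantize=False,
+        )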
+
+    """  # flake8: noqa
+    torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
+    if not isinstance(use_reference_representation, bool):
+        raise ValueError(
+            "Unexpected argument type for `use_reference_representation`, "
+            f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e")
+    original_graph_meta = model.meta
+    model = _convert_to_reference_decomposed_fx(model)
+    model = _fold_conv_bn_qat(model)
+
+    pm = PassManager([DuplicateDQPass()])
+    model = pm(model).graph_module
+
+    pm = PassManager([PortNodeMetaForQDQ()])
+    model = pm(model).graph_module
+
+    if fold_quantize:
+        constant_fold(model, _quant_node_constraint)
+
+    if use_reference_representation:
+        model = reference_representation_rewrite(model)
+
+    model.meta.update(original_graph_meta)
+    model = _disallow_eval_train(model)
+    return model
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..232f79dd3591caa2b70e15626031658c00cdff0a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__init__.py
@@ -0,0 +1,21 @@
+from .quantizer import (
+    DerivedQuantizationSpec,
+    EdgeOrNode,
+    FixedQParamsQuantizationSpec,
+    QuantizationAnnotation,
+    QuantizationSpec,
+    QuantizationSpecBase,
+    Quantizer,
+    SharedQuantizationSpec,
+)
+
+__all__ = [
+    "EdgeOrNode",
+    "Quantizer",
+    "QuantizationSpecBase",
+    "QuantizationSpec",
+    "FixedQParamsQuantizationSpec",
+    "SharedQuantizationSpec",
+    "DerivedQuantizationSpec",
+    "QuantizationAnnotation",
+]
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c683af5ac2a237cbf9f1aede59a716d8b5a02ea
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4c0b4211813bb2d2830870e0b4b08acd676f75b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cfe70be517f579f5d5338872120fac86f06a3b7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af219fd828a423715a9cf3c2adcf28043e2f8046
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c8dec6ba3c86f62e73e661e5dac06cefe9b3a65
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..086552d922ce599503a26f50ae01fff650dbebcc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d92d1b13b20525e1468c2e957dd3b44c78629bf3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0186d8b8bc58fd47b2f065ed16faf8efc175eb5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..34289a1bba9d1521e4decd6f3b79ca1f43a19c6a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from typing import Dict, List
+
+import torch
+
+from torch.fx import Node
+
+from .quantizer import QuantizationAnnotation, Quantizer
+
+__all__ = [
+    "ComposableQuantizer",
+]
+
+
+class ComposableQuantizer(Quantizer):
+    """
+    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
+    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
+    may be supported by one quantizer while linear layers and other ops might be supported by another
+    quantizer.
+
+    ComposableQuantizer is initialized with a list of `Quantizer` instances.
+    The order of the composition matters since that is the order in which the quantizers will be
+    applied.
+    Example:
+    ```
+    embedding_quantizer = EmbeddingQuantizer()
+    linear_quantizer = MyLinearQuantizer()
+    xnnpack_quantizer = XNNPACKQuantizer()  # to handle ops not quantized by the previous two quantizers
+    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
+    prepared_m = prepare_pt2e(model, composed_quantizer)
+    ```
+    """
+
+    def __init__(self, quantizers: List[Quantizer]):
+        super().__init__()
+        self.quantizers = quantizers
+        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}
+
+    def _record_and_validate_annotations(
+        self, gm: torch.fx.GraphModule, quantizer: Quantizer
+    ) -> None:
+        for n in gm.graph.nodes:
+            if "quantization_annotation" in n.meta:
+                # check if the annotation has been changed by
+                # comparing QuantizationAnnotation object id
+                if n in self._graph_annotations and (
+                    id(self._graph_annotations[n])
+                    != id(n.meta["quantization_annotation"])
+                ):
+                    raise RuntimeError(
+                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
+                    )
+                else:
+                    self._graph_annotations[n] = n.meta["quantization_annotation"]
+            else:
+                if n in self._graph_annotations:
+                    raise RuntimeError(
+                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
+                    )
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        for quantizer in self.quantizers:
+            quantizer.annotate(model)
+            self._record_and_validate_annotations(model, quantizer)
+        return model
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        for quantizer in self.quantizers:
+            model = quantizer.transform_for_annotation(model)
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcbc6115f129394d441b470c2890acd6605f8de
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import copy
+from typing import List, Set
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.observer import PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    Quantizer,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationConfig,
+)
+
+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]
+
+
+def get_embedding_operators_config() -> OperatorConfig:
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        qscheme=torch.per_channel_affine_float_qparams,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
+    )
+    quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
+    ops: List[OperatorPatternType] = [[torch.nn.Embedding]]
+    ops.append([F.embedding])
+    supported_config_and_operators = OperatorConfig(
+        config=quantization_config, operators=ops
+    )
+    return copy.deepcopy(supported_config_and_operators)
+
+
+class EmbeddingQuantizer(Quantizer):
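+    """Quantizer that annotates the weight input of `aten.embedding` nodes with the
+    per-channel affine float-qparams spec from `get_embedding_operators_config`.
+
+    Example (illustrative sketch; `captured_model` is a hypothetical module captured
+    with `capture_pre_autograd_graph` as in `prepare_pt2e`)::
+
+        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+
+        quantizer = EmbeddingQuantizer()
+        prepared_m = prepare_pt2e(captured_model, quantizer)
+    """
+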
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
+        op_configs: Set[QuantizationConfig] = set({})
+        for spec, _ in cls.get_supported_operators():
+            op_configs.add(spec)
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: QuantizationConfig
+    ) -> List[OperatorPatternType]:
+        for config, ops in cls.get_supported_operators():
+            # note: this assumes each entry in cls.get_supported_operators()
+            # corresponds to one spec, e.g. we don't have
+            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
+            # where the first and second entry have the same spec but did not
+            # merge the op list
+            if config == quantization_config:
+                return ops
+        return []
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        self._annotate_embedding_ops(model.graph)
+        return model
+
+    def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
+        embedding_config: OperatorConfig = get_embedding_operators_config()
+        for node in graph.nodes:
+            # Keep node parsing based annotations instead of module partitioners
+            # just as an example of alternate ways of annotating
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.aten.embedding.default
+            ):
+                if embedding_config.config.weight is None:
+                    raise ValueError(
+                        "Embedding config must have a valid weight quantization spec."
+                    )
+                node.meta["quantization_annotation"] = QuantizationAnnotation(
+                    input_qspec_map={
+                        node.args[0]: embedding_config.config.weight,
+                    }
+                )
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> List[OperatorConfig]:
+        return [get_embedding_operators_config()]
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be2acc70b340519dd8971a66eceeffea89890ae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/quantizer.py
@@ -0,0 +1,158 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.ao.quantization import ObserverOrFakeQuantize
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.fx import Node
+
+__all__ = [
+    "Quantizer",
+    "QuantizationSpecBase",
+    "QuantizationSpec",
+    "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
+    "SharedQuantizationSpec",
+    "DerivedQuantizationSpec",
+    "QuantizationAnnotation",
+]
+
+
+class QuantizationSpecBase(ABC):  # noqa: B024
+    """Base class for different types of quantization specs that allows users to
+    specify how to quantize a Tensor (input/output of a Node) in the model
+    """
+
+    pass
+
+
+@dataclass(eq=True, frozen=True)
+class QuantizationSpec(QuantizationSpecBase):
+    """Quantization spec for common operators that allows user to specify how to
+    quantize a Tensor, this includes dtype, quant_min, quant_max etc.
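+
+    Example (illustrative sketch; assumes the default MinMaxObserver is acceptable
+    for the target backend)::
+
+        from torch.ao.quantization.observer import MinMaxObserver
+
+        act_qspec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_tensor_symmetric,
+            observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
+        )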
+    """
+
+    dtype: torch.dtype
+    # observer or fake_quantize constructor such as
+    # MinMaxObserver, PerChannelHistogramObserver etc.
+    # or we can attach some custom args to them
+    # e.g. MinMaxObserver.with_args(eps=eps)
+    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+    ch_axis: Optional[int] = None
+    is_dynamic: bool = False
+
+    def __post_init__(self):
+        # quant_min must be less than or equal to quant_max
+        if (
+            self.quant_min is not None
+            and self.quant_max is not None
+            and self.quant_min > self.quant_max
+        ):
+            raise ValueError(
+                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
+            )
+
+        # ch_axis must be less than the number of channels,
+        # but there is no way to check that here. Just check that it is not < 0.
+        if self.ch_axis is not None and self.ch_axis < 0:
+            raise ValueError("Ch_axis is < 0.")
+
+
+@dataclass(eq=True, frozen=True)
+class FixedQParamsQuantizationSpec(QuantizationSpecBase):
+    dtype: torch.dtype
+    scale: float
+    zero_point: int
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+
+
+"""
+The way we refer to other points of quantization in the graph is either
+an input edge or an output value.
+An input edge is the connection between an input node and the node consuming the input, so it's a Tuple[Node, Node].
+An output value is an fx Node.
+"""
+EdgeOrNode = Union[Tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
+
+
+@dataclass(eq=True, frozen=True)
+class SharedQuantizationSpec(QuantizationSpecBase):
+    """
+    Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
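+
+    Example (illustrative sketch; `conv_node` and `add_node` are hypothetical fx
+    nodes in the graph being annotated)::
+
+        # share the observer with the (conv_node, add_node) input edge
+        shared_with_edge = SharedQuantizationSpec((conv_node, add_node))
+        # or share it with the output of conv_node
+        shared_with_output = SharedQuantizationSpec(conv_node)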
+    """
+
+    # the edge or node to share observer or fake quant instances with
+    edge_or_node: EdgeOrNode
+
+
+@dataclass(eq=True, frozen=True)
+class DerivedQuantizationSpec(QuantizationSpecBase):
+    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
+
+    derived_from: List[EdgeOrNode]
+    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
+    dtype: torch.dtype
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+    ch_axis: Optional[int] = None
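+
+    # Illustrative sketch: deriving a bias qspec for a conv node from the activation
+    # and weight observers; `input_act_node`, `weight_node` and `conv_node` are
+    # hypothetical fx nodes in the annotated graph.
+    #
+    #   def _derive_bias_qparams(obs_or_fqs):
+    #       act_obs, weight_obs = obs_or_fqs
+    #       act_scale, _ = act_obs.calculate_qparams()
+    #       weight_scale, _ = weight_obs.calculate_qparams()
+    #       return act_scale * weight_scale, torch.zeros_like(weight_scale, dtype=torch.int32)
+    #
+    #   bias_qspec = DerivedQuantizationSpec(
+    #       derived_from=[(input_act_node, conv_node), (weight_node, conv_node)],
+    #       derive_qparams_fn=_derive_bias_qparams,
+    #       dtype=torch.int32,
+    #       quant_min=-(2**31),
+    #       quant_max=2**31 - 1,
+    #       qscheme=torch.per_tensor_symmetric,
+    #   )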
+
+
+@dataclass
+class QuantizationAnnotation:
+    """How are input arguemnt or output should be quantized,
+    expressed as QuantizationSpec, this corresponds to how a Tensor in the
+    operator Graph is observed (PTQ) or fake quantized (QAT)
+    """
+
+    # a map from torch.fx.Node to a type of QuantizationSpecBase
+    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
+        default_factory=dict
+    )
+
+    # How the output of this node is quantized, expressed as QuantizationSpec
+    # TODO: change the value to QuantizationSpec in a separate PR
+    output_qspec: Optional[QuantizationSpecBase] = None
+
+    # For a Node node1 and an edge (node1, node2): since they are observing the same
+    # Tensor, we may want to implicitly share observers; this flag allows people to
+    # turn off this behavior for the output of the node
+    allow_implicit_sharing: bool = True
+
+    # whether the node is annotated or not
+    _annotated: bool = False
+
+
+class Quantizer(ABC):
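+    """Base class for quantizers.
+
+    A quantizer annotates nodes of a captured `torch.fx.GraphModule` with
+    `QuantizationAnnotation`s; `prepare_pt2e`/`prepare_qat_pt2e` then insert
+    observers or fake-quantize modules based on those annotations.
+
+    Example (illustrative sketch; `act_qspec` is a hypothetical `QuantizationSpec`
+    for the output activation)::
+
+        class MyLinearQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for n in model.graph.nodes:
+                    if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
+                        n.meta["quantization_annotation"] = QuantizationAnnotation(
+                            output_qspec=act_qspec, _annotated=True
+                        )
+                return model
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+    """
+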
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Allows for user defined transforms to run before annotating the graph.
+        This allows quantizer to allow quantizing part of the model that are otherwise not quantizable.
+        For example quantizer can
+        a) decompose a compound operator like scaled dot product attention,
+        into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa
+        or b) transform scalars to tensor to allow quantizing scalares.
+
+        Note: this is an optional method
+        """
+        return model
+
+    # annotate nodes in the graph with observer or fake quant constructors
+    # to convey the desired way of quantization
+    @abstractmethod
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        pass
+
+    # validate the annotated graph is supported by the backend
+    @abstractmethod
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2553bb8faa43875e9c4673dbad157d1b5af313e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/utils.py
@@ -0,0 +1,49 @@
+from typing import List
+
+from torch.ao.quantization.pt2e.utils import _is_sym_size_node
+
+from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
+from torch.fx import Node
+
+
+def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
+    quantization_annotation = node.meta.get(
+        "quantization_annotation", QuantizationAnnotation()
+    )
+    if quantization_annotation.input_qspec_map is None:
+        quantization_annotation.input_qspec_map = {}
+    quantization_annotation.input_qspec_map[input_node] = qspec
+    node.meta["quantization_annotation"] = quantization_annotation
+
+
+def _annotate_output_qspec(node: Node, qspec):
+    quantization_annotation = node.meta.get(
+        "quantization_annotation", QuantizationAnnotation()
+    )
+    quantization_annotation.output_qspec = qspec
+    node.meta["quantization_annotation"] = quantization_annotation
+
+
+def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
+    """
+    This utility is used to handle cases where dynamic_shape=True tracing leads
+    to symint nodes in the pattern of a linear module. In those cases, we need to
+    distinguish between the nodes that are inputs only for extracting the value of
+    some dimensions (and symint nodes) vs. the one that is the activation.
+    For example:
+    graph(x, y, weight):
+       size_0 = torch.ops.aten.sym_size([x], [0])
+       size_1 = torch.ops.aten.sym_size([y], [1])
+       view_size = size_0 * size_1
+       size_3 = torch.ops.aten.sym_size([x], [2])
+       view_out = torch.ops.aten.view(x, [view_size, size_3])
+       return mm(view_out, weight)
+    In the example above, the y node is not an actual input. It exists only to extract size_1.
+    """
+    if _is_sym_size_node(node):
+        return True
+
+    return all(
+        ((user not in partition_nodes) or _is_sym_size_node(user))
+        for user in node.users
+    )
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..620032293480d3b8d938f78fec4fd187c9ca6533
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -0,0 +1,1016 @@
+import copy
+import functools
+import itertools
+import operator
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    Quantizer,
+    SharedQuantizationSpec,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    _is_annotated,
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationConfig,
+)
+from torch.fx import Node
+from torch.fx.passes.utils.source_matcher_utils import (
+    get_source_partitions,
+    SourcePartition,
+)
+
+__all__ = [
+    "X86InductorQuantizer",
+    "get_default_x86_inductor_quantization_config",
+]
+
+
+@dataclass
+class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
+    # _is_output_of_quantized_pattern:
+    #  * Node as output node of a fusion pattern.
+    #  * The fusion pattern supports int8 data type.
+    #  * The fusion pattern has inputs annotated to insert observer.
+    _is_output_of_quantized_pattern: bool = False
+
+
+# Operations that:
+# 1. are optimized to run with int8 when an int8 input is provided, and
+# 2. otherwise (without an int8 input) run in fp32 and produce fp32 output.
+int8_in_int8_out_ops_pt2e: Set = {
+    torch.ops.aten.max_pool2d.default,
+    torch.ops.aten.cat.default,
+    torch.ops.aten.avg_pool2d.default,
+    torch.ops.aten.adaptive_avg_pool2d.default,
+    torch.ops.aten.flatten.using_ints,
+}
+
+
+# Operations that support the int8 data type, excluding operations such as conv and linear.
+# A superset of int8_in_int8_out_ops_pt2e incorporating additional operators.
+quantizable_ops_pt2e = copy.deepcopy(int8_in_int8_out_ops_pt2e)
+
+QUANT_ANNOTATION_KEY = "quantization_annotation"
+
+
+def _mark_nodes_as_annotated(nodes: List[Node]):
+    for node in nodes:
+        if node is not None:
+            if QUANT_ANNOTATION_KEY not in node.meta:
+                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation()
+            node.meta[QUANT_ANNOTATION_KEY]._annotated = True
+
+
+def _is_node_annotated(_node):
+    """
+    return True if the node is annotated, otherwise return False
+    """
+    return (
+        QUANT_ANNOTATION_KEY in _node.meta
+        and _node.meta[QUANT_ANNOTATION_KEY]._annotated
+    )
+
+
+def _is_any_annotated(nodes: List[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    return True if any of the nodes is annotated, otherwise return False.
+    """
+    return any(_is_node_annotated(node) for node in nodes)
+
+
+def _is_all_annotated(nodes: List[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    return True if all of the nodes are annotated, otherwise return False.
+    """
+    return all(_is_node_annotated(node) for node in nodes)
+
+
+def _is_quantized_op_pt2e(node: torch.fx.Node):
+    """
+    Used for pt2e flow to check if the node is a quantized node:
+    Case1: the node has been annotated as output node of a fusion pattern.
+    Case2: the node has been annotated as single quantized node.
+    """
+    if not _is_any_annotated([node]):
+        # The node has not been annotated, directly return False
+        return False
+    quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
+    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
+    return quantization_annotation._is_output_of_quantized_pattern
+
+
+def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+    # TODO: Add more supported operators here.
+    supported_operators: Dict[str, List[OperatorPatternType]] = {
+        "conv2d": [
+            [torch.nn.Conv2d],
+            [F.conv2d],
+        ],
+    }
+
+    # Append Conv Optional(Add) Optional(ReLU)
+    conv_add_relu_options = itertools.product(
+        [torch.nn.Conv2d, F.conv2d],
+        [torch.add, operator.add, None],  # add
+        [torch.nn.ReLU, F.relu, None],  # relu
+    )
+    for conv_op, add_op, relu_op in conv_add_relu_options:
+        if add_op is None and relu_op is None:
+            # Bare conv is already covered by the base "conv2d" entry above
+            continue
+        if add_op is None:
+            # Append Conv ReLU
+            supported_operators["conv2d"].append([conv_op, relu_op])  # type: ignore[list-item]
+        elif relu_op is None:
+            # Append Conv Add
+            supported_operators["conv2d"].append([conv_op, add_op])  # type: ignore[list-item]
+        else:
+            # Append Conv Add ReLU
+            supported_operators["conv2d"].append([conv_op, add_op, relu_op])  # type: ignore[list-item]
+
+    return copy.deepcopy(supported_operators)
+
+
+def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
+    supported_config_and_operators: List[OperatorConfig] = []
+    for quantization_config in [
+        get_default_x86_inductor_quantization_config(),
+    ]:
+        ops = _supported_quantized_operators()
+        for pattern_list in ops.values():
+            supported_config_and_operators.append(
+                OperatorConfig(quantization_config, pattern_list)
+            )
+    return copy.deepcopy(supported_config_and_operators)
+
+
+@functools.lru_cache
+def get_default_x86_inductor_quantization_config(
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+):
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # Copied from the x86 default qconfig in torch/ao/quantization/qconfig.py
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        quant_min=0,
+        quant_max=255,  # reduce_range=False
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
+    )
+
+    if is_qat:
+        # Only support per channel quant for now
+        extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+    bias_quantization_spec = None  # will use placeholder observer by default
+    quantization_config = QuantizationConfig(
+        act_quantization_spec,
+        act_quantization_spec,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        is_qat,
+    )
+    return quantization_config
+
+
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_x86_inductor_config_and_operators()
+
+
+class X86InductorQuantizer(Quantizer):
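+    """Quantizer implementing the quantization recipe for the X86 Inductor backend.
+
+    Example (illustrative sketch; `captured_model` is a hypothetical module captured
+    with `capture_pre_autograd_graph` as in `prepare_pt2e`)::
+
+        from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+
+        quantizer = X86InductorQuantizer().set_global(
+            get_default_x86_inductor_quantization_config()
+        )
+        prepared_m = prepare_pt2e(captured_model, quantizer)
+    """
+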
+    supported_config_and_operators = _get_supported_config_and_operators()
+
+    def __init__(self):
+        super().__init__()
+        self.global_config: QuantizationConfig = None  # type: ignore[assignment]
+        self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
+        op_configs: Set[QuantizationConfig] = set({})
+        for spec, _ in cls.supported_config_and_operators:
+            op_configs.add(spec)
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: Optional[QuantizationConfig]
+    ) -> List[OperatorPatternType]:
+        if quantization_config is None:
+            all_ops = []
+            for _, ops in cls.supported_config_and_operators:
+                all_ops.extend(ops)
+            return all_ops
+
+        for config, ops in cls.supported_config_and_operators:
+            if config == quantization_config:
+                return ops
+        return []
+
+    def set_global(self, quantization_config: QuantizationConfig):
+        self.global_config = quantization_config
+        return self
+
+    def set_config_for_operator_type(
+        self, operator_type: str, quantization_config: QuantizationConfig
+    ):
+        self.operator_type_config[operator_type] = quantization_config
+        return self
+
+    def _annotate_conv_node_helper(
+        self,
+        conv_node: torch.fx.Node,
+        annotate_output: bool,
+        quantization_config: QuantizationConfig,
+    ) -> None:
+        """Helper function to annotate the conv node"""
+        input_qspec_map = {}
+        input_node = conv_node.args[0]
+        assert isinstance(input_node, Node)
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        weight_node = conv_node.args[1]
+        assert isinstance(weight_node, Node)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+        bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
+        if isinstance(bias_node, Node):
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+        if annotate_output:
+            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+        else:
+            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+
+    def _annotate_linear_node_helper(
+        self,
+        linear_node: torch.fx.Node,
+        annotate_output: bool,
+        quantization_config: QuantizationConfig,
+    ) -> None:
+        """Helper function to annotate the linear node"""
+        input_qspec_map = {}
+        assert linear_node.target in (torch.ops.aten.linear.default,)
+        has_bias = len(linear_node.args) == 3
+        input_index = 0
+        weight_index = 1
+        bias_index = 2
+
+        input_node = linear_node.args[input_index]
+        assert isinstance(input_node, Node)
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+
+        weight_node = linear_node.args[weight_index]
+        assert isinstance(weight_node, Node)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+
+        bias_node = linear_node.args[bias_index] if has_bias else None
+        if isinstance(bias_node, Node):
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+
+        if annotate_output:
+            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+        else:
+            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map, _annotated=True
+            )
+
+    def _get_output_nodes_of_partitions(
+        self,
+        partition_list: List[SourcePartition],
+    ) -> List[torch.fx.Node]:
+        """Helper function to get the output node list from partition list"""
+        output_node_list = []
+        for partition in partition_list:
+            if len(partition.output_nodes) > 1:
+                raise ValueError("Input partition has more than one output node")
+            output_node = partition.output_nodes[0]
+            assert isinstance(output_node, Node)
+            output_node_list.append(output_node)
+        if len(output_node_list) != len(partition_list):
+            raise ValueError(
+                "length of output_node_list should equal to length of partition_list"
+            )
+        return output_node_list
+
+    def _get_input_idx_for_binary_node(
+        self,
+        conv_gemm_node: torch.fx.Node,
+        binary_node: torch.fx.Node,
+    ):
+        """Helper function to check conv_gemm and extra input node index
+        for binary node fused with conv_gemm.
+        """
+        conv_gemm_node_idx = None
+        extra_input_node_idx = None
+        if (binary_node.args[0].op == "call_function") and (  # type: ignore[union-attr]
+            binary_node.args[0] == conv_gemm_node
+        ):
+            conv_gemm_node_idx = 0
+            extra_input_node_idx = 1
+        elif (binary_node.args[1].op == "call_function") and (  # type: ignore[union-attr]
+            binary_node.args[1] == conv_gemm_node
+        ):
+            conv_gemm_node_idx = 1
+            extra_input_node_idx = 0
+        extra_input_node = binary_node.args[extra_input_node_idx]  # type: ignore[index]
+        assert isinstance(extra_input_node, Node)
+        return conv_gemm_node_idx, extra_input_node_idx
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
+            model = self._annotate_for_dynamic_quantization_config(model)
+        else:
+            model = self._annotate_for_static_quantization_config(model)
+        return model
+
+    def _annotate_for_static_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        r"""
+        High-level description of the quantization recipe for the X86 Inductor Backend:
+        Step 1: Apply the quantization recipe to conv/linear fusion patterns to actively enable the int8 data type.
+        Step 2: Propagate quantization annotations for patterns other than conv/linear. Go through the patterns in
+        the model from start to end. If a pattern supports computation with the int8 data type and its inputs are
+        connected to quantized patterns, annotate its inputs as a quantized pattern.
+        Step 3: In step 2 we only annotate the inputs of quantized patterns. For some quantized patterns,
+        such as maxpool2d, which only support int8 output when the input is int8,
+        we also need to annotate the output of the pattern.
+        """
+
+        config = self.global_config
+
+        # Step 1: Recipe for fusion patterns like conv/linear.
+        if config.is_qat:
+            # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat
+            self._annotate_qat_conv2d_fusion_pattern(model, config)
+
+        self._annotate_conv2d_fusion_pattern(model, config)
+
+        # Step 2: Recipe to propagate annotations for patterns besides conv/linear.
+        # Go through all the nodes from start to end.
+        # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/
+        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538
+        for node in model.graph.nodes:
+            self._annotation_propagation_quantizable_pattern(node, config)
+
+        # Step 3: For quantizable ops, such as maxpool2d, we need to quantize their output if their
+        # inputs are quantized, so that we can fuse dq-operator-q into a quantized op.
+        # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/
+        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487
+        for node in model.graph.nodes:
+            self._annotate_output_for_int8_in_int8_out_pattern(node, config)
+
+        return model
+
+    def _annotate_for_dynamic_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        config = self.global_config
+        self._annotate_linear(model, config)
+        return model
+
+    def _annotate_qat_conv2d_fusion_pattern(
+        self, model: torch.fx.GraphModule, config: QuantizationConfig
+    ):
+        # Annotate QAT Specific patterns
+        self._annotate_qat_conv2d_bn_binary_unary(model, config)
+        self._annotate_qat_conv2d_bn_binary(model, config)
+        self._annotate_qat_conv2d_bn_unary(model, config)
+        self._annotate_qat_conv2d_bn(model, config)
+
+    def _annotate_qat_conv2d_bn_binary_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            (
+                conv_partition,
+                bn_partition,
+                binary_partition,
+                unary_partition,
+            ) = fused_partition
+
+            (
+                conv_node,
+                bn_output_node,
+                binary_node,
+                unary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, binary_partition, unary_partition]
+            )
+            if len(bn_output_node.users) != 1:
+                # Conv BN pattern should only have 1 user.
+                continue
+            (
+                bn_output_node_idx,
+                extra_input_node_idx,
+            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
+            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if bn_output_node != binary_node.args[bn_output_node_idx]:
+                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([unary_node, binary_node, bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
+            )
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
+            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn_binary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition, binary_partition = fused_partition
+            (
+                conv_node,
+                bn_output_node,
+                binary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, binary_partition]
+            )
+            if len(bn_output_node.users) != 1:
+                # Conv BN pattern should only have 1 user.
+                continue
+            (
+                bn_output_node_idx,
+                extra_input_node_idx,
+            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
+            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if bn_output_node != binary_node.args[bn_output_node_idx]:
+                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
+
+            extra_input_node = binary_node.args[extra_input_node_idx]
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([binary_node, bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition, unary_partition = fused_partition
+            (
+                conv_node,
+                bn_output_node,
+                unary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, unary_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([unary_node, bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition = fused_partition
+            conv_node, bn_output_node = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _is_annotated([bn_output_node, conv_node]):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            bn_output_node.meta[
+                QUANT_ANNOTATION_KEY
+            ] = _X86InductorQuantizationAnnotation(
+                # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_conv2d_fusion_pattern(
+        self, model: torch.fx.GraphModule, config: QuantizationConfig
+    ):
+        self._annotate_conv2d_binary_unary(model, config)
+        self._annotate_conv2d_binary(model, config)
+        self._annotate_conv2d_unary(model, config)
+        self._annotate_conv2d(model, config)
+        self._annotate_linear_unary(model, config)
+        self._annotate_linear(model, config)
+
+    def _annotate_conv2d_binary_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        # Conv2d + add + unary op
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, binary_partition, unary_partition = fused_partition
+            conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, binary_partition, unary_partition]
+            )
+            if len(conv_node.users) != 1:
+                # Conv Node should only have 1 user node
+                continue
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
+            if (conv_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if conv_node != binary_node.args[conv_node_idx]:
+                raise ValueError(f"{conv_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                # No conv node found to be fused with add
+                continue
+            if _is_annotated([unary_node, binary_node, conv_node]):
+                continue
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
+            )
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d_binary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        # Conv2d + add
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, operator.add]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, binary_partition = fused_partition
+            conv_node, binary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, binary_partition]
+            )
+            if len(conv_node.users) != 1:
+                # The conv node should have exactly one user node
+                continue
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
+            if (conv_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if conv_node != binary_node.args[conv_node_idx]:
+                raise ValueError(f"{conv_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+            assert isinstance(conv_node, Node)
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                # No conv node found to be fused with add
+                continue
+            if _is_annotated([binary_node, conv_node]):
+                continue
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.Hardswish],
+            [torch.nn.Conv2d, torch.nn.ReLU6],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
+        for fused_partition in fused_partitions:
+            conv_partition, unary_partition = fused_partition
+            conv_node, unary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, unary_partition]
+            )
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+            if _is_annotated([unary_node, conv_node]):
+                continue
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        conv_partitions = get_source_partitions(
+            gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
+        )
+        conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values()))
+        for conv_partition in conv_partitions:
+            if len(conv_partition.output_nodes) > 1:
+                raise ValueError("conv partition has more than one output node")
+            conv_node = conv_partition.output_nodes[0]
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            # skip annotation if it is already annotated
+            if _is_annotated([conv_node]):
+                continue
+            self._annotate_conv_node_helper(conv_node, True, quantization_config)
+
+    def _annotate_maxpool2d(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        if node.target is not torch.ops.aten.max_pool2d.default:
+            return
+        maxpool_node = node
+        if _is_any_annotated(
+            [
+                maxpool_node,
+            ]
+        ):
+            return
+        input_node = maxpool_node.args[0]
+        assert isinstance(input_node, Node)
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+            _is_output_of_quantized_pattern=True,
+        )
+
+    def _annotate_cat(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        cat_node = node
+        input_nodes = cat_node.args[0]
+        assert isinstance(input_nodes, Sequence)
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        assert isinstance(first_input_node, Node)
+        assert isinstance(cat_node, Node)
+        input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, cat_node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                # Handle the case where cat receives the same node more than once, e.g. torch.cat([input0, input0], 1)
+                assert isinstance(input_node, Node)
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+            _is_output_of_quantized_pattern=True,
+        )
+
+    def _annotation_propagation_quantizable_pattern(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        # Propagate annotation to quantizable patterns.
+        if (
+            (node.target in quantizable_ops_pt2e)
+            and (not _is_any_annotated([node]))
+            and (node.op == "call_function")
+        ):
+
+            def is_all_inputs_connected_to_quantized_op(input_nodes):
+                # Ensure every input connects to a fusion pattern or a quantized node
+                for input_node in input_nodes:
+                    if not _is_quantized_op_pt2e(input_node):
+                        return False
+                return True
+
+            if node.target is torch.ops.aten.max_pool2d.default:
+                # Recipe for maxpool2d: check whether input arg[0] of maxpool2d is quantized
+                input_nodes_to_check = [node.all_input_nodes[0]]
+                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
+                    return
+                self._annotate_maxpool2d(node, quantization_config)
+                return
+            elif node.target is torch.ops.aten.cat.default:
+                input_nodes_to_check = node.all_input_nodes
+                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
+                    return
+                self._annotate_cat(node, quantization_config)
+            else:
+                input_node = node.all_input_nodes[0]
+                if not is_all_inputs_connected_to_quantized_op(
+                    [
+                        input_node,
+                    ]
+                ):
+                    return
+                input_qspec_map = {}
+                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                    input_qspec_map=input_qspec_map,
+                    _annotated=True,
+                    _is_output_of_quantized_pattern=True,
+                )
+        return
+
+    def _annotate_output_share_observer_as_input(
+        self, input_node: Node, source_node: Node
+    ):
+        source_node_quantization_annotation = (
+            source_node.meta[QUANT_ANNOTATION_KEY]
+            if QUANT_ANNOTATION_KEY in source_node.meta
+            else None
+        )
+        if (
+            source_node_quantization_annotation
+            and source_node_quantization_annotation._is_output_of_quantized_pattern
+        ):
+            edge_or_node = (input_node, source_node)
+            source_node_quantization_annotation.output_qspec = SharedQuantizationSpec(
+                edge_or_node
+            )
+        return
+
+    def _annotate_output_for_int8_in_int8_out_pattern(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        r"""
+        Check and insert observer at output of node in int8_in_int8_out_ops_pt2e if needed.
+        Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
+        90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
+        """
+        edge_or_node: Tuple[Node, Node]
+        if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])):
+            if node.target == torch.ops.aten.max_pool2d.default:
+                maxpool_node = node
+                if not _is_all_annotated(
+                    [
+                        maxpool_node,
+                    ]
+                ):
+                    return
+                # Get the quantization_annotation from the maxpool node
+                maxpool_node_quantization_annotation = (
+                    maxpool_node.meta[QUANT_ANNOTATION_KEY]
+                    if QUANT_ANNOTATION_KEY in maxpool_node.meta
+                    else None
+                )
+                if (
+                    maxpool_node_quantization_annotation
+                    and maxpool_node_quantization_annotation._is_output_of_quantized_pattern
+                ):
+                    # Annotate the output_qspec of the maxpool node
+                    input_act = maxpool_node.args[0]
+                    assert isinstance(input_act, Node)
+                    assert isinstance(maxpool_node, Node)
+                    edge_or_node = (input_act, maxpool_node)
+                    maxpool_node_quantization_annotation.output_qspec = (
+                        SharedQuantizationSpec(edge_or_node)
+                    )
+            else:
+                input_node = node.all_input_nodes[0]
+                self._annotate_output_share_observer_as_input(input_node, node)
+        return
+
+    def _annotate_linear(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        linear_partitions = get_source_partitions(
+            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
+        )
+        linear_partitions = list(
+            itertools.chain.from_iterable(linear_partitions.values())
+        )
+        for partition in linear_partitions:
+            if len(partition.output_nodes) > 1:
+                raise ValueError(
+                    "Linear partition cannot have more than one output node"
+                )
+            linear_node = partition.output_nodes[0]
+            if linear_node.op != "call_function" or linear_node.target not in (
+                torch.ops.aten.linear.default,
+            ):
+                raise ValueError(f"{linear_node} is not an aten linear operator")
+            # skip annotation if it is already annotated
+            if _is_annotated([linear_node]):
+                continue
+            self._annotate_linear_node_helper(linear_node, True, quantization_config)
+
+    def _annotate_linear_unary(
+        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+    ) -> None:
+        postop_list = [
+            torch.nn.ReLU,
+            torch.nn.LeakyReLU,
+            torch.nn.Tanh,
+        ]
+        fused_partitions: List[tuple] = []
+        for postop in postop_list:
+            fused_partitions = fused_partitions + find_sequential_partitions(
+                gm, [torch.nn.Linear, postop]
+            )
+        for fused_partition in fused_partitions:
+            linear_partition, unary_partition = fused_partition
+            linear_node, unary_node = self._get_output_nodes_of_partitions(
+                [linear_partition, unary_partition]
+            )
+            if linear_node.op != "call_function" or linear_node.target not in (
+                torch.ops.aten.linear.default,
+            ):
+                continue
+            if _is_annotated([unary_node, linear_node]):
+                continue
+            self._annotate_linear_node_helper(linear_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> List[OperatorConfig]:
+        return cls.supported_config_and_operators
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ffdfc409b8795c8344b58b13296493f7855554
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -0,0 +1,453 @@
+from __future__ import annotations
+
+import copy
+import functools
+
+from typing import Any, Callable, Dict, List, Optional, Set
+
+import torch
+import torch._dynamo as torchdynamo
+import torch.nn.functional as F
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
+from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    _convert_scalars_to_attrs,
+    OP_TO_ANNOTATOR,
+    OperatorConfig,
+    OperatorPatternType,
+    propagate_annotation,
+    QuantizationConfig,
+)
+
+from torch.fx import Node
+
+
+__all__ = [
+    "XNNPACKQuantizer",
+    "get_symmetric_quantization_config",
+]
+
+
+def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
+    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
+    gm.graph.eliminate_dead_code()
+    return gm.graph
+
+
+def _get_linear_patterns(input_size: List[int]):
+    in_channels = input_size[-1]
+    out_channels = 8  # hard coding but this should not matter
+    weight = torch.ones((out_channels, in_channels))
+    bias = torch.ones((out_channels,))
+    act = torch.ones(input_size)
+
+    def linear_op(act, weight, bias=None):
+        return F.linear(act, weight, bias)
+
+    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
+    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
+    return [pattern_w_bias, pattern_wo_bias]
+
+
+def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+    supported_operators: Dict[str, List[OperatorPatternType]] = {
+        # Both conv and linear should be able to handle relu + hardtanh fusion since
+        # those are clamp ops
+        "conv2d": [
+            [torch.nn.Conv2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, F.relu],
+            [F.conv2d, torch.nn.ReLU],
+            [F.conv2d, F.relu],
+        ],
+        "linear": [[torch.nn.Linear], [F.linear]],
+        "add": [[torch.add]],
+        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
+        "adaptive_avg_pool2d": [
+            [torch.nn.AdaptiveAvgPool2d],
+            [F.adaptive_avg_pool2d],
+        ],
+    }
+    return copy.deepcopy(supported_operators)
+
+
+def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
+    supported_config_and_operators: List[OperatorConfig] = []
+    for quantization_config in [
+        get_symmetric_quantization_config(),
+        get_symmetric_quantization_config(is_qat=True),
+        get_symmetric_quantization_config(is_per_channel=True),
+        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
+    ]:
+        ops = _supported_symmetric_quantized_operators()
+        for pattern_list in ops.values():
+            supported_config_and_operators.append(
+                OperatorConfig(quantization_config, pattern_list)
+            )
+    return copy.deepcopy(supported_config_and_operators)
+
+
+@functools.lru_cache
+def get_symmetric_quantization_config(
+    is_per_channel: bool = False,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    act_qmin: int = -128,
+    act_qmax: int = 127,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=act_qmin,
+        quant_max=act_qmax,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+    weight_qscheme = (
+        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+    )
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
+    if is_qat:
+        # TODO: qat + per channel?
+        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
+    elif is_per_channel:
+        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if weight_qscheme == torch.per_tensor_symmetric:
+            extra_args["observer"] = MovingAverageMinMaxObserver
+        else:
+            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    bias_quantization_spec = None
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,
+            None,
+            weight_quantization_spec,
+            bias_quantization_spec,
+            is_qat,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,
+            act_quantization_spec,
+            weight_quantization_spec,
+            bias_quantization_spec,
+            is_qat,
+        )
+    return quantization_config
+
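+# Illustrative note (added): because of functools.lru_cache above, repeated calls with
+# the same flags return the same cached QuantizationConfig instance, e.g.:
+#
+#     ptq_config = get_symmetric_quantization_config()
+#     qat_per_channel_config = get_symmetric_quantization_config(
+#         is_per_channel=True, is_qat=True
+#     )
+#     dynamic_config = get_symmetric_quantization_config(is_dynamic=True)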
+
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()
+
+
+def _get_module_name_filter(module_name: str):
+    """Get the module_name_filter function for a given module name, the filter accepts
+    a node and checks if the node comes from a module that has certain module name
+
+    For example:
+        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1
+
+
+    >> module_name_filter = _get_module_name_filter("blocks.sub")
+    >> print(module_name_filter(node))
+    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
+    """
+
+    def module_name_filter(n: Node) -> bool:
+        # example: {
+        #    'L__self___sub': ("L['self'].sub", ),
+        #    'L__self___sub_linear': ("L['self'].sub.linear", )
+        # }
+        # get_attr nodes don't have nn_module_stack?
+        nn_module_stack = n.meta.get("nn_module_stack", {})
+        names = [n[len("L['self'].") :] for n, klass in nn_module_stack.values()]
+        return module_name in names
+
+    return module_name_filter
+
+
+def _get_module_type_filter(tp: Callable):
+    """Get the module_type_filter function for a given module type, the filter accepts
+    a node and checks if the node comes from a module that has certain module type
+
+    For example:
+        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear
+
+
+    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
+    >> print(module_type_filter(node))
+    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
+    """
+
+    def module_type_filter(n: Node) -> bool:
+        # example: {
+        #     'L__self___sub': ("L['self'].sub", ),
+        #     'L__self___sub_linear': ("L['self'].sub.linear", )
+        # }
+        nn_module_stack = n.meta.get("nn_module_stack", {})
+        types = [t for _, t in nn_module_stack.values()]
+        return tp in types
+
+    return module_type_filter
+
+
+def _get_not_module_type_or_name_filter(
+    tp_list: List[Callable], module_name_list: List[str]
+) -> Callable[[Node], bool]:
+    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
+    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+
+    def not_module_type_or_name_filter(n: Node) -> bool:
+        return not any(f(n) for f in module_type_filters + module_name_list_filters)
+
+    return not_module_type_or_name_filter
+
+
+class XNNPACKQuantizer(Quantizer):
+    supported_config_and_operators = _get_supported_config_and_operators()
+    STATIC_QAT_ONLY_OPS = [
+        "conv_bn_relu",
+        "conv_bn",
+    ]
+
+    # static quantization ops (both PTQ and QAT)
+    # Preserve the order so that fusion patterns come before the singular ops
+    STATIC_OPS = [
+        "linear_relu",
+        "linear",
+        "conv_relu",
+        "conv",
+        "adaptive_avg_pool2d",
+        # TODO: move this to BoltNNQuantizer?
+        "gru_io_only",
+        "max_pool2d",
+        "add_relu",
+        "add",
+        "mul_relu",
+        "mul",
+        "cat",
+    ]
+
+    DYNAMIC_OPS = [
+        "linear",
+    ]
+
+    def __init__(self):
+        super().__init__()
+        self.global_config: Optional[QuantizationConfig] = None
+        self.operator_type_config: Dict[
+            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
+        ] = {}
+        self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
+        self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
+        op_configs: Set[QuantizationConfig] = set({})
+        for spec, _ in cls.supported_config_and_operators:
+            op_configs.add(spec)
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: Optional[QuantizationConfig]
+    ) -> List[OperatorPatternType]:
+        if quantization_config is None:
+            all_ops = []
+            for _, ops in cls.supported_config_and_operators:
+                all_ops.extend(ops)
+            return all_ops
+
+        for config, ops in cls.supported_config_and_operators:
+            # note: this assumes each entry in cls.supported_config_and_operators
+            # corresponds to one spec, e.g. we don't have
+            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
+            # where the first and second entry have the same spec but did not
+            # merge the op list
+            if config == quantization_config:
+                return ops
+        return []
+
+    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
+        self.global_config = quantization_config
+        return self
+
+    def set_operator_type(
+        self,
+        operator_type: torch._ops.OpOverloadPacket,
+        quantization_config: QuantizationConfig,
+    ) -> XNNPACKQuantizer:
+        self.operator_type_config[operator_type] = quantization_config
+        return self
+
+    def set_module_type(
+        self, module_type: Callable, quantization_config: QuantizationConfig
+    ):
+        """Set quantization_config for a submodule with type: `module_type`, for example:
+        quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
+        patterns in the submodule with this module type with the given `quantization_config`
+        """
+        self.module_type_config[module_type] = quantization_config
+        return self
+
+    def set_module_name(
+        self, module_name: str, quantization_config: Optional[QuantizationConfig]
+    ):
+        """Set quantization_config for a submodule with name: `module_name`, for example:
+        quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
+        patterns in the submodule with this module name with the given `quantization_config`
+        """
+        assert (
+            quantization_config is not None
+        ), " quantization_config == None is not supported yet"
+        self.module_name_config[module_name] = quantization_config
+        return self
+
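+    # Usage sketch (illustrative only; assumes the PT2E entry points
+    # `capture_pre_autograd_graph` and `prepare_pt2e` available in this torch build):
+    #
+    #     quantizer = XNNPACKQuantizer().set_global(
+    #         get_symmetric_quantization_config(is_per_channel=True)
+    #     )
+    #     quantizer.set_module_name("blocks.sub", get_symmetric_quantization_config())
+    #     exported = capture_pre_autograd_graph(model, example_inputs)
+    #     prepared = prepare_pt2e(exported, quantizer)  # invokes annotate() below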
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Transforms scalar values to tensor attributes"""
+        return _convert_scalars_to_attrs(model)
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        # hacked for handling dynamic linear quant. will fix later.
+        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
+            model = self._annotate_for_dynamic_quantization_config(model)
+        else:
+            model = self._annotate_for_static_quantization_config(model)
+        propagate_annotation(model)
+        return model
+
+    def _annotate_all_static_patterns(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        # TODO: implement the support for None to be canceling out previous annotations
+        if quantization_config is None:
+            return model
+
+        if quantization_config.is_qat:
+            for op in self.STATIC_QAT_ONLY_OPS:
+                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        for op in self.STATIC_OPS:
+            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        return model
+
+    def _annotate_all_dynamic_patterns(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        # TODO: implement the support for None to be canceling out previous annotations
+        if quantization_config is None:
+            return model
+
+        for op in self.DYNAMIC_OPS:
+            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        return model
+
+    def _annotate_for_static_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
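+        # Precedence: module-name configs are applied first, then module-type configs;
+        # the global config only annotates nodes excluded by both filters below.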
+        module_name_list = list(self.module_name_config.keys())
+        for module_name, config in self.module_name_config.items():
+            self._annotate_all_static_patterns(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        tp_list = list(self.module_type_config.keys())
+        for module_type, config in self.module_type_config.items():
+            self._annotate_all_static_patterns(
+                model, config, _get_module_type_filter(module_type)
+            )
+
+        self._annotate_all_static_patterns(
+            model,
+            self.global_config,
+            _get_not_module_type_or_name_filter(tp_list, module_name_list),
+        )
+        return model
+
+    def _annotate_for_dynamic_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        module_name_list = list(self.module_name_config.keys())
+        for module_name, config in self.module_name_config.items():
+            self._annotate_all_dynamic_patterns(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        tp_list = list(self.module_type_config.keys())
+        for module_type, config in self.module_type_config.items():
+            self._annotate_all_dynamic_patterns(
+                model, config, _get_module_type_filter(module_type)
+            )
+
+        self._annotate_all_dynamic_patterns(
+            model,
+            self.global_config,
+            _get_not_module_type_or_name_filter(tp_list, module_name_list),
+        )
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> List[OperatorConfig]:
+        return cls.supported_config_and_operators
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a38a28c30a4fc1f970caa1523922ecc127bc69
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -0,0 +1,1032 @@
+import itertools
+import operator
+from dataclasses import dataclass
+from typing import Callable, Dict, List, NamedTuple, Optional
+
+import torch
+import torch.nn.functional as F
+from torch._subclasses import FakeTensor
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.ao.quantization.pt2e.export_utils import _WrapperModule
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.utils import (
+    _conv1d_bn_example_inputs,
+    _conv2d_bn_example_inputs,
+    get_aten_graph_module,
+)
+from torch.ao.quantization.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+
+from torch.ao.quantization.quantizer.utils import (
+    _annotate_input_qspec_map,
+    _annotate_output_qspec,
+)
+from torch.fx import Node
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
+    SubgraphMatcherWithNameNodeMap,
+)
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+__all__ = [
+    "OperatorConfig",
+    "OperatorPatternType",
+    "QuantizationConfig",
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+    "OP_TO_ANNOTATOR",
+    "propagate_annotation",
+]
+
+
+# In the absence of a better name, just winging it with QuantizationConfig
+@dataclass(eq=True, frozen=True)
+class QuantizationConfig:
+    input_activation: Optional[QuantizationSpec]
+    output_activation: Optional[QuantizationSpec]
+    weight: Optional[QuantizationSpec]
+    bias: Optional[QuantizationSpec]
+    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
+    is_qat: bool = False
+
+
+OperatorPatternType = List[Callable]
+OperatorPatternType.__module__ = (
+    "torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
+)
+
+AnnotatorType = Callable[
+    [
+        torch.fx.GraphModule,
+        Optional[QuantizationConfig],
+        Optional[Callable[[Node], bool]],
+    ],
+    Optional[List[List[Node]]],
+]
+OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}
+
+
+def register_annotator(op: str):
+    def decorator(annotator: AnnotatorType):
+        OP_TO_ANNOTATOR[op] = annotator
+
+    return decorator
+
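+# Illustrative sketch (added): a hypothetical out-of-tree annotator could be registered
+# and dispatched through the same table, e.g.:
+#
+#     @register_annotator("my_custom_op")  # "my_custom_op" is a made-up key
+#     def _annotate_my_custom_op(gm, quantization_config, filter_fn=None):
+#         return []  # find matching nodes, annotate them, return annotated partitions
+#
+#     OP_TO_ANNOTATOR["my_custom_op"](gm, quantization_config, None)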
+
+class OperatorConfig(NamedTuple):
+    # TODO: replace List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
+    # Basically we are mapping a quantization config to some list of patterns.
+    # a pattern is defined as a list of nn module, function or builtin function names
+    # e.g. [nn.Conv2d, torch.relu, torch.add]
+    # We have not resolved whether fusion can be considered an internal detail of the
+    # quantizer, in which case it would not need to be communicated to the user.
+    # Note this pattern is not really informative since it does not really
+    # tell us the graph structure resulting from the list of ops.
+    config: QuantizationConfig
+    operators: List[OperatorPatternType]
+
+
+def _is_annotated(nodes: List[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    check if any of the node is annotated, return True if any of the node
+    is annotated, otherwise return False
+    """
+    annotated = False
+    for node in nodes:
+        annotated = annotated or (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        )
+    return annotated
+
+
+def _mark_nodes_as_annotated(nodes: List[Node]):
+    for node in nodes:
+        if node is not None:
+            if "quantization_annotation" not in node.meta:
+                node.meta["quantization_annotation"] = QuantizationAnnotation()
+            node.meta["quantization_annotation"]._annotated = True
+
+
+def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.input_activation is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.input_activation
+    assert quantization_spec.qscheme in [
+        torch.per_tensor_affine,
+        torch.per_tensor_symmetric,
+    ]
+    return quantization_spec
+
+
+def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.output_activation is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.output_activation
+    assert quantization_spec.qscheme in [
+        torch.per_tensor_affine,
+        torch.per_tensor_symmetric,
+    ]
+    return quantization_spec
+
+
+def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    assert quantization_config is not None
+    if quantization_config.weight is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.weight
+    if quantization_spec.qscheme not in [
+        torch.per_tensor_symmetric,
+        torch.per_channel_symmetric,
+    ]:
+        raise ValueError(
+            f"Unsupported quantization_spec {quantization_spec} for weight"
+        )
+    return quantization_spec
+
+
+def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    assert quantization_config is not None
+    if quantization_config.bias is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.bias
+    assert (
+        quantization_spec.dtype == torch.float
+    ), "Only float dtype for bias is supported for bias right now"
+    return quantization_spec
+
+
+@register_annotator("linear")
+def _annotate_linear(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+        act_node = node.args[0]
+        weight_node = node.args[1]
+        bias_node = None
+        if len(node.args) > 2:
+            bias_node = node.args[2]
+
+        if _is_annotated([node]) is False:  # type: ignore[list-item]
+            _annotate_input_qspec_map(
+                node,
+                act_node,
+                input_act_qspec,
+            )
+            _annotate_input_qspec_map(
+                node,
+                weight_node,
+                weight_qspec,
+            )
+            nodes_to_mark_annotated = [node, weight_node]
+            if bias_node:
+                _annotate_input_qspec_map(
+                    node,
+                    bias_node,
+                    bias_qspec,
+                )
+                nodes_to_mark_annotated.append(bias_node)
+            _annotate_output_qspec(node, output_act_qspec)
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+            annotated_partitions.append(nodes_to_mark_annotated)
+
+    return annotated_partitions
+
+
+@register_annotator("linear_relu")
+def _annotate_linear_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = node
+        maybe_linear_node = node.args[0]
+        if (
+            not isinstance(maybe_linear_node, Node)
+            or maybe_linear_node.op != "call_function"
+            or maybe_linear_node.target != torch.ops.aten.linear.default
+        ):
+            continue
+
+        linear_node = maybe_linear_node
+        input_qspec_map = {}
+        input_act = linear_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = input_act_qspec
+
+        weight = linear_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = weight_qspec
+
+        # adding weight node to the partition as well
+        partition = [relu_node, linear_node, weight]
+        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = bias_qspec
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    for n in gm.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+        ]:
+            continue
+        conv_node = n
+
+        input_qspec_map = {}
+        input_act = conv_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
+
+        weight = conv_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+        # adding weight node to the partition as well
+        partition = [conv_node, conv_node.args[1]]
+
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=get_output_act_qspec(quantization_config),
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("conv_relu")
+def _annotate_conv_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    for n in gm.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = n
+        maybe_conv_node = n.args[0]
+        if (
+            not isinstance(maybe_conv_node, Node)
+            or maybe_conv_node.op != "call_function"
+            or maybe_conv_node.target
+            not in [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+            ]
+        ):
+            continue
+        conv_node = maybe_conv_node
+
+        input_qspec_map = {}
+        input_act = conv_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
+
+        weight = conv_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+        # adding weight node to the partition as well
+        partition = [relu_node, conv_node, conv_node.args[1]]
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map, _annotated=True
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("conv_bn")
+def _annotate_conv_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv + batchnorm partitions.
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
+
+
+@register_annotator("conv_bn_relu")
+def _annotate_conv_bn_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv + batchnorm + relu partitions.
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
+
+
+def _do_annotate_conv_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]],
+    has_relu: bool,
+) -> List[List[Node]]:
+    """
+    Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
+    return a list of annotated partitions.
+
+    The output of the pattern must include a dictionary from string name to node
+    for the following names: "input", "conv", "weight", "bias", and "output".
+    """
+
+    def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
+        def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
+            conv = conv_fn(x, conv_weight, conv_bias)
+            bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
+            if has_relu:
+                output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
+            else:
+                output = bn
+            return output, {
+                "input": x,
+                "conv": conv,
+                "weight": conv_weight,
+                "bias": conv_bias,
+                "output": output,
+            }
+
+        return _WrapperModule(_conv_bn)
+
+    # Needed for matching, otherwise the matches get filtered out due to unused
+    # nodes returned by batch norm
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+
+    matches = []
+    combinations = [
+        (F.conv1d, _conv1d_bn_example_inputs),
+        (F.conv2d, _conv2d_bn_example_inputs),
+    ]
+
+    # Add `is_cuda` and `relu_is_inplace` dimensions
+    combinations = itertools.product(
+        combinations,
+        [True, False] if torch.cuda.is_available() else [False],  # is_cuda
+        [True, False] if has_relu else [False],  # relu_is_inplace
+    )
+
+    # Match against all conv dimensions and cuda variants
+    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
+        pattern = get_pattern(conv_fn, relu_is_inplace)
+        pattern = get_aten_graph_module(pattern, example_inputs, is_cuda)
+        pattern.graph.eliminate_dead_code()
+        pattern.recompile()
+        matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
+        matches.extend(matcher.match(gm.graph))
+
+    # Annotate nodes returned in the matches
+    annotated_partitions = []
+    for match in matches:
+        name_node_map = match.name_node_map
+        input_node = name_node_map["input"]
+        conv_node = name_node_map["conv"]
+        weight_node = name_node_map["weight"]
+        bias_node = name_node_map["bias"]
+        output_node = name_node_map["output"]
+
+        # TODO: annotate the uses of input, weight, and bias separately instead
+        # of assuming they come from a single conv node. This is not possible today
+        # because input may have multiple users, and we can't rely on the conv node
+        # always being the first user. This was the case in models with skip
+        # connections like resnet18
+
+        # Validate conv args
+        if conv_node.args[0] is not input_node:
+            raise ValueError(f"Conv arg did not contain input node {input_node}")
+        if conv_node.args[1] is not weight_node:
+            raise ValueError(f"Conv arg did not contain weight node {weight_node}")
+        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
+            raise ValueError(f"Conv arg did not contain bias node {bias_node}")
+
+        # Skip if the partition is already annotated or is filtered out by the user
+        partition = [conv_node, weight_node]
+        if bias_node is not None:
+            partition.append(bias_node)
+        if _is_annotated(partition):
+            continue
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        # Annotate conv inputs and pattern output
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+        if bias_node is not None:
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("gru_io_only")
+def _annotate_gru_io_only(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
+    gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
+    annotated_partitions = []
+    for gru_partition in gru_partitions:
+        annotated_partitions.append(gru_partition.nodes)
+        output_nodes = gru_partition.output_nodes
+        input_nodes = gru_partition.input_nodes
+        # skip annotation if it is already annotated
+        if _is_annotated(input_nodes + output_nodes):
+            continue
+        # inside each GRU partition, we should be able to annotate each linear
+        # subgraph
+        input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
+        input_act = input_nodes[0]
+        input_act_user = next(iter(input_act.users.keys()))
+        assert isinstance(input_act, Node)
+        assert isinstance(input_act_user, Node)
+        input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                input_act: get_input_act_qspec(quantization_config),
+            },
+            _annotated=True,
+        )
+
+        hidden_state = input_nodes[1]
+        hidden_state_user = next(iter(hidden_state.users.keys()))
+        assert isinstance(hidden_state, Node)
+        assert isinstance(hidden_state_user, Node)
+        hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                hidden_state: get_input_act_qspec(quantization_config),
+            },
+            _annotated=True,
+        )
+
+        assert len(output_nodes) == 2, "expecting GRU to have two outputs"
+        for output in output_nodes:
+            output.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=get_output_act_qspec(quantization_config),
+                _annotated=True,
+            )
+        nodes_to_mark_annotated = list(gru_partition.nodes)
+        _mark_nodes_as_annotated(nodes_to_mark_annotated)
+    return annotated_partitions
+
+
+@register_annotator("max_pool2d")
+def _annotate_max_pool2d(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    module_partitions = get_source_partitions(
+        gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
+    )
+    maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values()))
+    annotated_partitions = []
+    for maxpool_partition in maxpool_partitions:
+        annotated_partitions.append(maxpool_partition.nodes)
+        output_node = maxpool_partition.output_nodes[0]
+        maxpool_node = None
+        for n in maxpool_partition.nodes:
+            if n.target == torch.ops.aten.max_pool2d.default:
+                maxpool_node = n
+        assert maxpool_node is not None, (
+            "XNNPACKQuantizer only works with torch.ops.aten.max_pool2d.default, "
+            "please make sure you are exporting the model correctly"
+        )
+        if _is_annotated([output_node, maxpool_node]):  # type: ignore[list-item]
+            continue
+
+        input_act = maxpool_node.args[0]  # type: ignore[union-attr]
+        assert isinstance(input_act, Node)
+
+        # only annotate maxpool when the output of the input node is annotated
+        if (
+            "quantization_annotation" not in input_act.meta
+            or not input_act.meta["quantization_annotation"]._annotated
+            or input_act.meta["quantization_annotation"].output_qspec is None
+        ):
+            continue
+        # input and output of maxpool will share quantization parameter with input of maxpool
+        act_qspec = SharedQuantizationSpec(input_act)
+        # act_qspec = get_act_qspec(quantization_config)
+        maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(  # type: ignore[union-attr]
+            input_qspec_map={
+                input_act: act_qspec,
+            },
+            _annotated=True,
+        )
+        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+@register_annotator("adaptive_avg_pool2d")
+def _annotate_adaptive_avg_pool2d(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """Always annotate adaptive_avg_pool2d op"""
+    module_partitions = get_source_partitions(
+        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
+    )
+    partitions = list(itertools.chain.from_iterable(module_partitions.values()))
+    annotated_partitions = []
+    for partition in partitions:
+        pool_node = partition.output_nodes[0]
+        if (
+            pool_node.op != "call_function"
+            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
+        ):
+            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
+
+        if _is_annotated([pool_node]):
+            continue
+
+        annotated_partitions.append(partition.nodes)
+        input_act = pool_node.args[0]
+        assert isinstance(input_act, Node)
+
+        # only annotate input output sharing operator
+        # when the output of the input node is annotated
+        if (
+            "quantization_annotation" not in input_act.meta
+            or not input_act.meta["quantization_annotation"]._annotated
+            or input_act.meta["quantization_annotation"].output_qspec is None
+        ):
+            input_act_qspec = get_input_act_qspec(quantization_config)
+        else:
+            input_act_qspec = SharedQuantizationSpec(input_act)
+
+        # output sharing with input
+        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
+        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                input_act: input_act_qspec,
+            },
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
+    """Check if input is a large scalar value. So that we can skip quantization for the node
+    since histc op (in HistogramObserver) only works for values up to certain upper bound
+    """
+    if node.op == "get_attr":
+        tensor = getattr(gm, node.target)  # type: ignore[arg-type]
+        # torch.histc works until this upper bound
+        HISTC_UPPER_BOUND = 3.4028235e15
+        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    return False
+
+
+def _is_input_non_float_tensor(node: Node):
+    """Check if the input is not a float tensor, so that we can skip quantization for the node
+    since observers only work with float Tensors
+    """
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+    return node.meta["val"].dtype != torch.float32
+
+
+@register_annotator("add_relu")
+def _annotate_add_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    fused_partitions = find_sequential_partitions(
+        gm, [torch.add, torch.nn.ReLU], filter_fn=filter_fn
+    )
+    annotated_partitions = []
+    for fused_partition in fused_partitions:
+        add_partition, relu_partition = fused_partition
+        annotated_partitions.append(add_partition.nodes + relu_partition.nodes)
+        if len(relu_partition.output_nodes) > 1:
+            raise ValueError("Relu partition has more than one output node")
+        relu_node = relu_partition.output_nodes[0]
+        if len(add_partition.output_nodes) > 1:
+            raise ValueError("add partition has more than one output node")
+        add_node = add_partition.output_nodes[0]
+
+        if _is_annotated([relu_node, add_node]):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = add_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = add_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+
+        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+@register_annotator("add")
+def _annotate_add(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    add_partitions = get_source_partitions(
+        gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
+    )
+    add_partitions = list(itertools.chain.from_iterable(add_partitions.values()))
+    annotated_partitions = []
+    for add_partition in add_partitions:
+        annotated_partitions.append(add_partition.nodes)
+        add_node = add_partition.output_nodes[0]
+        if _is_annotated([add_node]):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = add_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = add_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+
+        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+@register_annotator("mul_relu")
+def _annotate_mul_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    fused_partitions = find_sequential_partitions(
+        gm, [torch.mul, torch.nn.ReLU], filter_fn=filter_fn
+    )
+    annotated_partitions = []
+    for fused_partition in fused_partitions:
+        mul_partition, relu_partition = fused_partition
+        annotated_partitions.append(mul_partition.nodes + relu_partition.nodes)
+        if len(relu_partition.output_nodes) > 1:
+            raise ValueError("Relu partition has more than one output node")
+        relu_node = relu_partition.output_nodes[0]
+        if len(mul_partition.output_nodes) > 1:
+            raise ValueError("mul partition has more than one output node")
+        mul_node = mul_partition.output_nodes[0]
+
+        if _is_annotated([relu_node, mul_node]):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = mul_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = mul_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+
+        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+@register_annotator("mul")
+def _annotate_mul(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    mul_partitions = get_source_partitions(
+        gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
+    )
+    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
+    annotated_partitions = []
+    for mul_partition in mul_partitions:
+        annotated_partitions.append(mul_partition.nodes)
+        mul_node = mul_partition.output_nodes[0]
+        if _is_annotated([mul_node]):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = mul_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = mul_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+
+        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+# TODO: remove Optional in return type, fix annotated_partitions logic
+@register_annotator("cat")
+def _annotate_cat(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
+    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
+    annotated_partitions = []
+    for cat_partition in cat_partitions:
+        cat_node = cat_partition.output_nodes[0]
+        if _is_annotated([cat_node]):
+            continue
+
+        if cat_node.target != torch.ops.aten.cat.default:
+            # TODO: change this to AnnotationException
+            raise Exception(
+                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
+                " please check if you are calling the correct capture API"
+            )
+
+        annotated_partitions.append(cat_partition.nodes)
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        inputs = cat_node.args[0]
+
+        input_qspec_map = {}
+        input_act0 = inputs[0]
+        if isinstance(input_act0, Node):
+            input_qspec_map[input_act0] = input_act_qspec
+
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
+        for input_act in inputs[1:]:
+            input_qspec_map[input_act] = shared_with_input0_qspec
+
+        output_act_qspec = shared_with_input0_qspec
+
+        cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+def _is_share_obs_or_fq_op(op: Callable) -> bool:
+    return op in [
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.mean.default,
+        torch.ops.aten.mean.dim,
+        torch.ops.aten.permute.default,
+        torch.ops.aten.permute_copy.default,
+        torch.ops.aten.squeeze.dim,
+        torch.ops.aten.squeeze_copy.dim,
+        # TODO: remove?
+        torch.ops.aten.adaptive_avg_pool2d.default,
+        torch.ops.aten.view_copy.default,
+        torch.ops.aten.view.default,
+        torch.ops.aten.slice_copy.Tensor,
+        torch.ops.aten.flatten.using_ints,
+    ]
+
+
+def propagate_annotation(model: torch.fx.GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
+            continue
+
+        prev_node = n.args[0]
+        if not isinstance(prev_node, Node):
+            continue
+
+        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+        if not quantization_annotation:
+            continue
+
+        output_qspec = quantization_annotation.output_qspec
+        if not output_qspec:
+            continue
+
+        # make sure current node is not annotated
+        if (
+            "quantization_annotation" in n.meta
+            and n.meta["quantization_annotation"]._annotated
+        ):
+            continue
+
+        shared_qspec = SharedQuantizationSpec(prev_node)
+        # propagate the previous output_qspec to the current node
+        n.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                prev_node: shared_qspec,
+            },
+            output_qspec=shared_qspec,
+            _annotated=True,
+        )
+
+
+# TODO: make the list of ops customizable
+def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+        ]:
+            continue
+        args = list(n.args)
+        new_args = []
+        for i in range(len(args)):
+            if isinstance(args[i], torch.fx.Node):
+                new_args.append(args[i])
+                continue
+            prefix = "_tensor_constant_"
+            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+            tensor_constant_name = get_new_attr_name(model)
+            float_tensor = torch.tensor(float(args[i]))
+            model.register_buffer(tensor_constant_name, float_tensor)
+            fake_mode = n.meta["val"].fake_mode
+            with model.graph.inserting_before(n):
+                get_attr_node = model.graph.create_node(
+                    "get_attr", tensor_constant_name, (), {}
+                )
+                get_attr_node.meta["val"] = fake_mode.from_tensor(
+                    float_tensor, static_shapes=True
+                )
+                new_args.append(get_attr_node)
+        n.args = tuple(new_args)
+    model.recompile()
+    return model
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/stubs.py b/MLPY/Lib/site-packages/torch/ao/quantization/stubs.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a90c8e193e5f2b22874c721b7057a4ef022cf6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/stubs.py
@@ -0,0 +1,64 @@
+
+from torch import nn
+
+class QuantStub(nn.Module):
+    r"""Quantize stub module, before calibration, this is same as an observer,
+    it will be swapped as `nnq.Quantize` in `convert`.
+
+    Args:
+        qconfig: quantization configuration for the tensor,
+            if qconfig is not provided, we will get qconfig from parent modules
+    """
+    def __init__(self, qconfig=None):
+        super().__init__()
+        if qconfig:
+            self.qconfig = qconfig
+
+    def forward(self, x):
+        return x
+
+
+class DeQuantStub(nn.Module):
+    r"""Dequantize stub module, before calibration, this is same as identity,
+    this will be swapped as `nnq.DeQuantize` in `convert`.
+
+    Args:
+        qconfig: quantization configuration for the tensor,
+            if qconfig is not provided, we will get qconfig from parent modules
+    """
+    def __init__(self, qconfig=None):
+        super().__init__()
+        if qconfig:
+            self.qconfig = qconfig
+
+    def forward(self, x):
+        return x
+
+
+class QuantWrapper(nn.Module):
+    r"""A wrapper class that wraps the input module, adds QuantStub and
+    DeQuantStub and surround the call to module with call to quant and dequant
+    modules.
+
+    This is used by the `quantization` utility functions to add the quant and
+    dequant modules, before `convert` function `QuantStub` will just be observer,
+    it observes the input tensor, after `convert`, `QuantStub`
+    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
+    for `DeQuantStub`.
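+
+    Example (a minimal illustrative sketch; the wrapped module below is arbitrary)::
+
+    >> import torch
+    >> from torch.ao.quantization.stubs import QuantWrapper
+    >> wrapped = QuantWrapper(torch.nn.Linear(4, 4))
+    >> out = wrapped(torch.randn(2, 4))  # forward runs quant -> module -> dequant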
+    """
+    quant: QuantStub
+    dequant: DeQuantStub
+    module: nn.Module
+
+    def __init__(self, module):
+        super().__init__()
+        qconfig = getattr(module, "qconfig", None)
+        self.add_module('quant', QuantStub(qconfig))
+        self.add_module('dequant', DeQuantStub(qconfig))
+        self.add_module('module', module)
+        self.train(module.training)
+
+    def forward(self, X):
+        X = self.quant(X)
+        X = self.module(X)
+        return self.dequant(X)
diff --git a/MLPY/Lib/site-packages/torch/ao/quantization/utils.py b/MLPY/Lib/site-packages/torch/ao/quantization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8835b2ce281c5bd4f6b6be9087a2e83731051d55
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/ao/quantization/utils.py
@@ -0,0 +1,703 @@
+"""
+Utils shared by different modes of quantization (eager/graph)
+"""
+import functools
+import warnings
+from collections import OrderedDict
+from inspect import getfullargspec, signature
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch.ao.quantization.quant_type import QuantType
+from torch.fx import Node
+from torch.nn.utils.parametrize import is_parametrized
+
+NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
+NodePattern.__module__ = "torch.ao.quantization.utils"
+
+# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
+# Define separately to prevent circular imports.
+# TODO(future PR): improve this.
+# make this public once fixed (can't be public as is because setting the module directly
+# doesn't work)
+QuantizerCls = Any
+
+# Type for fusion patterns; it can actually be more complicated than the following,
+# see pattern.md for docs
+# TODO: not sure if typing supports recursive data types
+Pattern = Union[
+    Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
+]
+Pattern.__module__ = "torch.ao.quantization.utils"
+
+# TODO: maybe rename this to MatchInputNode
+class MatchAllNode:
+    """ A node pattern that matches all nodes, used in defining
+    fusion patterns in FX Graph Mode Quantization
+    """
+    pass
+
+module_type_list = {
+    torch.nn.ReLU,
+    torch.nn.ReLU6,
+    torch.nn.AdaptiveAvgPool1d,
+    torch.nn.AdaptiveAvgPool2d,
+    torch.nn.AdaptiveAvgPool3d,
+    torch.nn.AvgPool1d,
+    torch.nn.AvgPool2d,
+    torch.nn.AvgPool3d,
+    torch.nn.MaxPool1d,
+    torch.nn.MaxPool2d,
+    torch.nn.MaxPool3d,
+    torch.nn.Identity,
+    torch.nn.Hardsigmoid,
+    torch.nn.Sigmoid,
+    torch.nn.Tanh,
+}
+func_list = {
+    torch.nn.functional.adaptive_avg_pool1d,
+    torch.nn.functional.adaptive_avg_pool2d,
+    torch.nn.functional.adaptive_avg_pool3d,
+    torch.nn.functional.elu,
+    torch.nn.functional.hardswish,
+    torch.nn.functional.instance_norm,
+    torch.nn.functional.layer_norm,
+    torch.nn.functional.leaky_relu,
+    torch.nn.functional.silu,
+    torch.nn.functional.mish,
+    torch.nn.functional.dropout,
+    torch.nn.functional.max_pool1d,
+    torch.nn.functional.max_pool2d,
+    torch.nn.functional.max_pool3d,
+    torch.nn.functional.relu,
+    torch.nn.functional.hardtanh,
+    torch.nn.functional.hardtanh_,
+    torch.nn.functional.hardsigmoid,
+    torch.nn.functional.sigmoid,
+    torch.transpose,
+    torch.repeat_interleave,
+    torch.sigmoid,
+    torch.squeeze,
+    torch.stack,
+    torch.sum,
+    torch.tanh,
+    torch.unsqueeze,
+    torch.cat,
+}
+method_list = {
+    torch.mean,
+    'relu',
+    'relu_',
+    'contiguous',
+    'detach',
+    'detach_',
+    'hardsigmoid',
+    'hardsigmoid_',
+    'permute',
+    'repeat',
+    'repeat_interleave',
+    'reshape',
+    'resize_',
+    'shape',
+    'sigmoid',
+    'sigmoid_',
+    'size',
+    'squeeze',
+    'squeeze_',
+    'tanh',
+    'tanh_',
+    'transpose',
+    'unsqueeze',
+    'unsqueeze_',
+    'view',
+}
+
+# TODO: not used now, remove
+def check_node(node, modules):
+    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
+    is_call_function = node.op == "call_function" and node.target in func_list
+    is_call_method = node.op == "call_method" and node.target in method_list
+    is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
+    return is_call_function, is_call_method, is_call_module
+
+def get_combined_dict(default_dict, additional_dict):
+    d = default_dict.copy()
+    d.update(additional_dict)
+    return d
+
+def is_per_tensor(qscheme):
+    return qscheme == torch.per_tensor_affine or \
+        qscheme == torch.per_tensor_symmetric
+
+def is_per_channel(qscheme):
+    return qscheme in [torch.per_channel_affine,
+                       torch.per_channel_affine_float_qparams,
+                       torch.per_channel_symmetric]
+
+def getattr_from_fqn(obj: Any, fqn: str) -> Any:
+    """
+    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
+    """
+    return functools.reduce(getattr, fqn.split("."), obj)
+
+def to_underlying_dtype(qdtype):
+    DTYPE_MAPPING = {
+        torch.quint8: torch.uint8,
+        torch.qint8: torch.int8,
+        torch.qint32: torch.int32,
+        torch.quint4x2: torch.uint8,
+        torch.quint2x4: torch.uint8,
+        torch.uint8: torch.uint8,
+        torch.int8: torch.int8,
+        torch.int16: torch.int16,
+        torch.int32: torch.int32,
+    }
+    assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
+    return DTYPE_MAPPING[qdtype]
+
+def get_qparam_dict(observer_or_fake_quant):
+    from torch.ao.quantization.observer import PlaceholderObserver
+
+    qscheme = getattr(observer_or_fake_quant, "qscheme", None)
+    dtype = observer_or_fake_quant.dtype
+    qparams = {"qscheme": qscheme, "dtype": dtype}
+
+    if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver):
+        return {"qscheme": None, "dtype": dtype}
+
+    if is_per_tensor(qscheme):
+        qscheme = torch.per_tensor_affine
+    elif is_per_channel(qscheme):
+        # change symmetric to affine since we do not have symmetric
+        # quantized Tensor
+        if qscheme == torch.per_channel_symmetric:
+            qscheme = torch.per_channel_affine
+        qparams["axis"] = observer_or_fake_quant.ch_axis
+    else:
+        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
+    # update qscheme, since we don't have symmetric quant qscheme
+    # in quantized Tensor
+    qparams["qscheme"] = qscheme
+
+    scale, zero_point = observer_or_fake_quant.calculate_qparams()
+    qparams["scale"] = scale
+    qparams["zero_point"] = zero_point
+
+    if hasattr(observer_or_fake_quant, "quant_min"):
+        qparams["quant_min"] = observer_or_fake_quant.quant_min
+    if hasattr(observer_or_fake_quant, "quant_max"):
+        qparams["quant_max"] = observer_or_fake_quant.quant_max
+
+    return qparams
+
+
+def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
+    """ Get the observed/quantized custom module class that we need
+    to swap `custom_module` to
+    Input:
+        custom_module: input, can be an instance of either a float or observed custom module
+        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
+        qconfig: qconfig configured for the custom module
+
+    Output:
+        corresponding observed/quantized custom module class for input custom module instance
+    """
+    quant_type = get_quant_type(qconfig)
+    class_mapping = custom_module_class_mapping.get(quant_type, {})
+    assert type(custom_module) in class_mapping, "did not find corresponding observed " \
+        f"module class for {type(custom_module)} in mapping: {class_mapping}"
+    return class_mapping[type(custom_module)]
+
+def activation_dtype(qconfig):
+    assert qconfig is not None
+    activation = qconfig.activation()
+    return activation.dtype
+
+def weight_dtype(qconfig):
+    assert qconfig is not None
+    weight = qconfig.weight()
+    return weight.dtype
+
+def activation_is_statically_quantized(qconfig):
+    """ Given a qconfig, decide if the activation needs to be
+    quantized or not; this includes quantizing to quint8, qint8, qint32 and float16
+    """
+    return (
+        activation_dtype(qconfig) in [
+            torch.quint8,
+            torch.qint8,
+            torch.qint32,
+            torch.float16,
+            torch.uint8,
+            torch.int8,
+            torch.int16,
+            torch.int32
+        ]
+        and (not activation_is_dynamically_quantized(qconfig))
+    )
+
+def activation_is_dynamically_quantized(qconfig):
+    """ Given a qconfig, decide if the activation needs to be
+    dynamically quantized or not, this includes dynamically quantizing to
+    quint8, qint8 and float16
+    """
+    activation_dtype, _, activation_is_dynamic = \
+        get_qconfig_dtypes(qconfig)
+    return activation_is_dynamic
+
+def activation_is_int8_quantized(qconfig):
+    """ Given a qconfig, decide if the activation needs to be
+    quantized to int8 or not, this includes quantizing to quint8, qint8
+    """
+    return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]
+
+def activation_is_int32_quantized(qconfig):
+    """ Given a qconfig, decide if the activation needs to be
+    quantized to int32 or not
+    """
+    return activation_dtype(qconfig) in [torch.qint32, torch.int32]
+
+def weight_is_quantized(qconfig):
+    """ Given a qconfig, decide if the weight needs to be
+    quantized or not
+    """
+    return weight_dtype(qconfig) in [
+        torch.quint8,
+        torch.qint8,
+        torch.float16,
+        torch.quint4x2,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32
+    ]
+
+def weight_is_statically_quantized(qconfig):
+    """ Given a qconfig, decide if the weight needs to be statically
+    quantized or not
+    """
+    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8]
+
+def op_is_int8_dynamically_quantized(qconfig) -> bool:
+    """ Given a qconfig, returns True if this op is using int8 dynamic
+    quantization
+    """
+    activation_dtype, weight_dtype, activation_is_dynamic = \
+        get_qconfig_dtypes(qconfig)
+    return (
+        activation_dtype in [torch.quint8, torch.uint8] and
+        # for now, the lines below assume fbgemm or qnnpack
+        weight_dtype in [torch.qint8, torch.int8] and
+        activation_is_dynamic
+    )
+
+def get_qconfig_dtypes(qconfig):
+    r""" returns the qconfig tuple for qconfig:
+    (activation_dtype, weight_dtype, activation_is_dynamic)
+    """
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    act_is_dynamic = getattr(activation, "is_dynamic", False)
+    return (activation.dtype, weight.dtype, act_is_dynamic)
+
+def get_quant_type(qconfig):
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32]
+    if weight.dtype in static_dtypes:
+        if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
+            return QuantType.DYNAMIC
+        elif activation.dtype in static_dtypes:
+            return QuantType.STATIC
+        else:
+            return QuantType.WEIGHT_ONLY
+
+    if weight.dtype == torch.float16:
+        if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
+            return QuantType.DYNAMIC
+        elif activation.dtype == torch.float16:
+            return QuantType.STATIC
+
+    raise Exception(f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}),"
+                    f"weight({weight.dtype})")
+
+def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
+    """ Checks if the given minimum and maximum values are valid, meaning that
+    they exist and the min value is not greater than the max value.
+    """
+    if min_val.numel() == 0 or max_val.numel() == 0:
+        warnings.warn(
+            "must run observer before calling calculate_qparams. " +
+            "Returning default values."
+        )
+        return False
+
+    if min_val.dim() == 0 or max_val.dim() == 0:
+        if min_val == float("inf") and max_val == float("-inf"):
+            warnings.warn(
+                "must run observer before calling calculate_qparams. " +
+                "Returning default values."
+            )
+
+            return False
+
+        assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
+    else:
+        assert torch.all(
+            min_val <= max_val
+        ), f"min {min_val} should be less than max {max_val}"
+
+    return True
+
+
+def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
+                        reduce_range: bool) -> Tuple[int, int]:
+    r"""Calculates actual qmin and qmax based on the quantization range,
+    observer datatype and whether the range is reduced.
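+
+    Example (a minimal illustrative sketch; the argument values below are arbitrary)::
+
+    >> calculate_qmin_qmax(0, 15, True, torch.qint8, False)   # customized 4-bit range
+    (0, 15)
+    >> calculate_qmin_qmax(0, 255, False, torch.quint8, True)  # default range, reduced
+    (0, 127)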
+    """
+    # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
+    if has_customized_qrange:
+        # This initialization is here to resolve TorchScript compilation issues and to allow
+        # the use of refinement to decouple initial_qmin and initial_qmax from the quantization range.
+        # The actual values of initial_qmin and initial_qmax will be reset below.
+        if dtype in [torch.qint32, torch.int32]:
+            initial_quant_min, initial_quant_max = 0, 2**32 - 1
+        else:
+            initial_quant_min, initial_quant_max = 0, 255
+        # The following assignment of quant_min and quant_max to local variables and the if check refine
+        # the attributes from Optional to concrete integers for use, based on TorchScript's requirements.
+        custom_quant_min, custom_quant_max = quant_min, quant_max
+        if custom_quant_min is not None and custom_quant_max is not None:
+            initial_quant_min, initial_quant_max = (
+                custom_quant_min,
+                custom_quant_max,
+            )
+
+        qrange_len = initial_quant_max - initial_quant_min + 1
+        if dtype in [torch.qint8, torch.int8]:
+            assert (
+                0 < qrange_len <= 256
+            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
+        elif dtype in [torch.qint32, torch.int32]:
+            assert (
+                0 < qrange_len <= 2**32
+            ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
+        if reduce_range:
+            quant_min, quant_max = quant_min // 2, quant_max // 2
+    else:
+        # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
+        if dtype in [torch.qint8, torch.int8]:
+            if reduce_range:
+                quant_min, quant_max = -64, 63
+            else:
+                quant_min, quant_max = -128, 127
+        elif dtype in [torch.quint8, torch.uint8]:
+            if reduce_range:
+                quant_min, quant_max = 0, 127
+            else:
+                quant_min, quant_max = 0, 255
+        elif dtype in [torch.qint32, torch.int32]:
+            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
+        else:
+            quant_min, quant_max = 0, 15
+    return quant_min, quant_max
+
+
+def _parent_name(target):
+    """
+    Turn 'foo.bar' into ('foo', 'bar')
+    """
+    r = target.rsplit('.', 1)
+    if len(r) == 1:
+        return '', r[0]
+    else:
+        return r[0], r[1]
+
+def has_no_children_ignoring_parametrizations(module):
+    """
+    Checks whether module._modules is empty or, if the module is parametrized,
+    whether module._modules only contains the 'parametrizations' module.
+    """
+    if len(module._modules) == 0:
+        return True
+    elif is_parametrized(module):
+        return len(module._modules) == 1 and 'parametrizations' in module._modules
+    else:
+        return False
+
+def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
+    """ Get the path (fully qualified name) of a submodule
+
+    Example::
+
+    >> class M(torch.nn.Module):
+           def __init__(self):
+               self.linear = torch.nn.Linear(5, 5)
+           def forward(self, x):
+               return self.linear(x)
+
+    >> m = M()
+    >> l = m.linear
+    >> _get_path_of_module(m, l)
+    "linear"
+    """
+    for n, p in root.named_modules():
+        if submodule is p:
+            return n
+    return None
+
+def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
+    """ Get local keyword arguments
+
+    Example::
+
+    >> def f(self, a, b=9):
+           pass
+    >> loc = {"a": 6, "c": 7}
+    >> _get_signature_locals(f, loc)
+    {"a": 6}
+    """
+    return {k: v for k, v in loc.items() if k in signature(f).parameters}
+
+def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
+    """ Get all default keyword arguments from function signature
+
+    Example::
+
+    >> def f(self, a, b=9):
+           pass
+    >> _get_default_kwargs(f)
+    {"b": 9}
+    """
+    kwargs = {}
+    for name, param in signature(f).parameters.items():
+        if param.default is not param.empty:
+            kwargs[name] = param.default
+        elif param.kind is param.VAR_POSITIONAL:
+            kwargs[name] = ()
+        elif param.kind is param.VAR_KEYWORD:
+            kwargs[name] = {}
+    return OrderedDict(kwargs)
+
+def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
+    """ Given a function and local function arguments, normalize the keyword
+    arguments by filling in default arguments from function signature
+
+    Example::
+
+    >> def f(self, key1=3, key2=3):
+           pass
+    >> loc = {"key2": 6}
+    >> _normalize_kwargs(f, loc)
+    {"key1": 3, "key2": 6}
+    """
+    default_kwargs = _get_default_kwargs(func)
+    local_kwargs = _get_signature_locals(func, loc)
+    normalized_kwargs = default_kwargs.copy()
+    for attr, val in local_kwargs.items():
+        if attr in normalized_kwargs:
+            # override the default keyword arguments
+            normalized_kwargs[attr] = val
+    return normalized_kwargs
+
+def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
+    r"""Validates that the user-specified quantization range is properly initialized
+    and within the given bound supported by the observer dtype.
+
+    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
+    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
+    in a tuple of initial qmin and qmax values. One use case is when these customized qmin and qmax
+    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
+    fake quantization. These estimates are compared against parameters learned through backpropagation.
+    The related literature on learning scale and zero point via backpropagation is as follows:
+
+    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
+    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
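+
+    Example (a minimal illustrative sketch)::
+
+    >> validate_qmin_qmax(-8, 7)  # passes: the range includes 0 and qmin < qmax
+    >> validate_qmin_qmax(1, 7)   # raises AssertionError: the range must include 0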
+    """
+    # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
+    # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
+    assert (
+        quant_min <= 0 <= quant_max
+    ), "Used-specified quantization range must include 0."
+    assert (
+        quant_min < quant_max
+    ), "qmin must be strictly less than qmax for user-specified quantization range."
+
+
+# Functionally equivalent to '_calculate_qparams' in observer.py. However, observers must be torchscriptable, and qscheme,
+# as far as I can tell, cannot be passed as a parameter in torchscript functions. This makes refactoring observer
+# to use this utility a massive pain. For now I'm opting to duplicate, as this code seems unlikely to change
+# (last update over 1 year ago), and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168)
+def determine_qparams(
+        min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
+        dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
+        qscheme: torch.qscheme = torch.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Calculates the quantization parameters, given min and max
+    value tensors. Works for both per tensor and per channel cases
+
+    Args:
+        min_val: Minimum values per channel
+        max_val: Maximum values per channel
+
+    Returns:
+        scales: Scales tensor of shape (#channels,)
+        zero_points: Zero points tensor of shape (#channels,)
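+
+    Example (a minimal illustrative sketch; the values below are arbitrary)::
+
+    >> eps = torch.tensor(torch.finfo(torch.float32).eps)
+    >> scale, zero_point = determine_qparams(
+           torch.tensor(-1.0), torch.tensor(1.0), 0, 255,
+           torch.quint8, eps, has_customized_qrange=False)
+    >> # scale is roughly 2 / 255 and zero_point is 128 for this affine range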
+    """
+    if not check_min_max_valid(min_val, max_val):
+        return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
+
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+    device = min_val_neg.device
+    scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
+    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+    if (
+        qscheme == torch.per_tensor_symmetric
+        or qscheme == torch.per_channel_symmetric
+    ):
+        max_val_pos = torch.max(-min_val_neg, max_val_pos)
+        scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        scale = torch.max(scale, eps)
+        if dtype in [torch.uint8, torch.quint8]:
+            if has_customized_qrange:
+                # When customized quantization range is used, down-rounded midpoint of the range is chosen.
+                zero_point = zero_point.new_full(
+                    zero_point.size(), (quant_min + quant_max) // 2
+                )
+            else:
+                zero_point = zero_point.new_full(zero_point.size(), 128)
+    elif qscheme == torch.per_channel_affine_float_qparams:
+        scale = (max_val - min_val) / float(quant_max - quant_min)
+        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
+        # We use the quantize function
+        # xq = Round(Xf * inv_scale + zero_point),
+        # setting zero_point to (-1 * min *inv_scale) we get
+        # Xq = Round((Xf - min) * inv_scale)
+        zero_point = -1 * min_val / scale
+    else:
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.max(scale, eps)
+        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
+        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+
+    # For scalar values, cast them to Tensors of size 1 to keep the shape
+    # consistent with default values in FakeQuantize.
+    if len(scale.shape) == 0:
+        # TODO: switch to scale.item() after adding JIT support
+        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
+    if len(zero_point.shape) == 0:
+        # TODO: switch to zero_point.item() after adding JIT support
+        zero_point = torch.tensor(
+            [int(zero_point)], dtype=zero_point.dtype, device=device
+        )
+        if qscheme == torch.per_channel_affine_float_qparams:
+            zero_point = torch.tensor(
+                [float(zero_point)], dtype=zero_point.dtype, device=device
+            )
+
+    return scale.to(torch.double), zero_point.to(torch.int64)
+
+def _get_num_pos_args(f: Callable) -> int:
+    """ Get number of positional args for a function
+
+    Example::
+
+    >> def f(self, key1=3, key2=3):
+           pass
+    >> _get_num_pos_args(f)
+    3
+    """
+    return len(getfullargspec(f).args)
+
+def get_fqn_to_example_inputs(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...]
+) -> Dict[str, Tuple[Any, ...]]:
+    """ Given a model and its example inputs, return a dictionary from
+    fully qualified name of submodules to example_inputs for that submodule,
+    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
+          "sub.linear1": (tensor4,), ...}
+
+    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
+    example inputs.
+
+    Also works for keyword arguments with default values: we flatten keyword
+    arguments into positional arguments and fill in the missing keyword args with default
+    values, e.g. if we have a forward function:
+    def forward(self, x, key1=3, key2=3):
+        ...
+
+    and we call it with self.submodule(x, key2=6)
+    we'll get example_inputs: (x, 3, 6)
+
+    The user can also override `key1` with a positional argument:
+    for self.submodule(x, 5, key2=6)
+    we'll get: (x, 5, 6)
+
+    Variable positional arguments and variable keyword arguments in the forward
+    function are not supported currently, so please make sure no submodule is using
+    them.
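+
+    Example (a minimal illustrative sketch; the module structure below is arbitrary)::
+
+    >> class M(torch.nn.Module):
+           def __init__(self):
+               super().__init__()
+               self.linear = torch.nn.Linear(5, 5)
+           def forward(self, x, scale=1.0):
+               return self.linear(x) * scale
+    >> m = M()
+    >> fqn_to_inputs = get_fqn_to_example_inputs(m, (torch.randn(1, 5),))
+    >> fqn_to_inputs["linear"][0].shape
+    torch.Size([1, 5])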
+    """
+    root = model
+    fqn_to_example_inputs = {}
+
+    def _patched_module_call(self, *args, **kwargs):
+        submodule_example_inputs = list(args).copy()
+        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
+        # minus 1 to skip counting `self`
+        num_args = _get_num_pos_args(self.forward) - 1
+        num_to_pop = num_args - len(submodule_example_inputs)
+        while num_to_pop and normalized_kwargs:
+            normalized_kwargs.popitem(last=False)
+            num_to_pop -= 1
+        submodule_example_inputs.extend(normalized_kwargs.values())
+        submodule_example_inputs_tuple = tuple(submodule_example_inputs)
+        fqn = _get_path_of_module(root, self)
+        if fqn is not None:
+            fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
+        return orig_module_call(self, *args, **kwargs)
+
+    orig_module_call = torch.nn.Module.__call__
+    torch.nn.Module.__call__ = _patched_module_call  # type: ignore[method-assign]
+    try:
+        model(*example_inputs)
+    finally:
+        # restore the module call even if there is an exception
+        torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
+    return fqn_to_example_inputs
+
+__all__ = [
+    "NodePattern",
+    "Pattern",
+    "MatchAllNode",
+    "check_node",
+    "get_combined_dict",
+    "is_per_tensor",
+    "is_per_channel",
+    "getattr_from_fqn",
+    "get_qparam_dict",
+    "get_swapped_custom_module_class",
+    "activation_dtype",
+    "weight_dtype",
+    "activation_is_statically_quantized",
+    "activation_is_dynamically_quantized",
+    "activation_is_int8_quantized",
+    "activation_is_int32_quantized",
+    "weight_is_quantized",
+    "weight_is_statically_quantized",
+    "op_is_int8_dynamically_quantized",
+    "get_qconfig_dtypes",
+    "get_quant_type",
+    "check_min_max_valid",
+    "calculate_qmin_qmax",
+    "has_no_children_ignoring_parametrizations",
+    "get_fqn_to_example_inputs",
+    "to_underlying_dtype",
+    "determine_qparams",
+    "validate_qmin_qmax",
+]
diff --git a/MLPY/Lib/site-packages/torch/autograd/__init__.py b/MLPY/Lib/site-packages/torch/autograd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30e24e6df2e2aee769cc029475679325bb5243b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/__init__.py
@@ -0,0 +1,515 @@
+"""
+``torch.autograd`` provides classes and functions implementing automatic
+differentiation of arbitrary scalar valued functions. It requires minimal
+changes to the existing code - you only need to declare :class:`Tensor` s
+for which gradients should be computed with the ``requires_grad=True`` keyword.
+As of now, we only support autograd for floating point :class:`Tensor` types (
+half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
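+
+A minimal illustrative example (the tensor values below are arbitrary)::
+
+    >> import torch
+    >> x = torch.ones(3, requires_grad=True)
+    >> y = (2 * x).sum()
+    >> y.backward()
+    >> x.grad
+    tensor([2., 2., 2.])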
+"""
+import warnings
+from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
+from .. import _vmap_internals
+from ..overrides import handle_torch_function, has_torch_function, is_tensor_like
+from . import forward_ad, functional, graph
+from .anomaly_mode import detect_anomaly, set_detect_anomaly
+from .function import Function, NestedIOFunction
+from .grad_mode import (
+    _force_original_view_tracking,
+    _unsafe_preserve_version_counter,
+    enable_grad,
+    inference_mode,
+    no_grad,
+    set_grad_enabled,
+    set_multithreading_enabled,
+)
+from .gradcheck import gradcheck, gradgradcheck
+from .graph import _engine_run_backward
+
+from .variable import Variable
+
+__all__ = ["Variable", "Function", "backward", "grad_mode"]
+
+_OptionalTensor = Optional[torch.Tensor]
+_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
+
+
+def _calculate_shape(
+    output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool
+) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
+    # is_same_size ensures that both tensors are either nested or non-nested
+    # circular import
+    from torch.nested._internal.nested_tensor import NestedTensor
+
+    if output.is_nested and not isinstance(output, NestedTensor):
+        if is_grads_batched:
+            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
+        out_shape = output._nested_tensor_size()
+        grad_shape = grad._nested_tensor_size()
+
+        return out_shape, grad_shape
+
+    reg_out_shape = output.shape
+    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
+    return reg_out_shape, reg_grad_shape
+
+
+def _make_grads(
+    outputs: Sequence[torch.Tensor],
+    grads: Sequence[_OptionalTensor],
+    is_grads_batched: bool,
+) -> Tuple[_OptionalTensor, ...]:
+    new_grads: List[_OptionalTensor] = []
+    for out, grad in zip(outputs, grads):
+        if isinstance(grad, torch.Tensor):
+            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
+
+            first_grad = grad if not is_grads_batched else grad[0]
+            # TODO: We can remove this conditional once we uniformly use
+            # singleton int to represent jagged dimension, so that size() call
+            # on nested tensor works
+            if out.is_nested or first_grad.is_nested:
+                shape_matches = torch.is_same_size(out, first_grad)
+            else:
+                # We need to do a regular size check, without going through
+                # the operator, to be able to handle unbacked symints
+                # (expect_true ensures we can deal with unbacked)
+                shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
+            if not shape_matches:
+                out_shape, grad_shape = _calculate_shape(
+                    out, first_grad, is_grads_batched
+                )
+                if is_grads_batched:
+                    raise RuntimeError(
+                        "If `is_grads_batched=True`, we interpret the first "
+                        "dimension of each grad_output as the batch dimension. "
+                        "The sizes of the remaining dimensions are expected to match "
+                        "the shape of corresponding output, but a mismatch "
+                        "was detected: grad_output["
+                        + str(grads.index(grad))
+                        + "] has a shape of "
+                        + str(grad_shape)
+                        + " and output["
+                        + str(outputs.index(out))
+                        + "] has a shape of "
+                        + str(out_shape)
+                        + ". "
+                        "If you only want some tensors in `grad_output` to be considered "
+                        "batched, consider using vmap."
+                    )
+                else:
+                    raise RuntimeError(
+                        "Mismatch in shape: grad_output["
+                        + str(grads.index(grad))
+                        + "] has a shape of "
+                        + str(grad_shape)
+                        + " and output["
+                        + str(outputs.index(out))
+                        + "] has a shape of "
+                        + str(out_shape)
+                        + "."
+                    )
+            if out.dtype.is_complex != grad.dtype.is_complex:
+                raise RuntimeError(
+                    "For complex Tensors, both grad_output and output"
+                    " are required to have the same dtype."
+                    " Mismatch in dtype: grad_output["
+                    + str(grads.index(grad))
+                    + "] has a dtype of "
+                    + str(grad.dtype)
+                    + " and output["
+                    + str(outputs.index(out))
+                    + "] has a dtype of "
+                    + str(out.dtype)
+                    + "."
+                )
+            new_grads.append(grad)
+        elif grad is None:
+            if out.requires_grad:
+                if out.numel() != 1:
+                    raise RuntimeError(
+                        "grad can be implicitly created only for scalar outputs"
+                    )
+                if not out.dtype.is_floating_point:
+                    msg = (
+                        "grad can be implicitly created only for real scalar outputs"
+                        f" but got {out.dtype}"
+                    )
+                    raise RuntimeError(msg)
+                new_grads.append(
+                    torch.ones_like(out, memory_format=torch.preserve_format)
+                )
+            else:
+                new_grads.append(None)
+        else:
+            raise TypeError(
+                "gradients can be either Tensors or None, but got "
+                + type(grad).__name__
+            )
+    return tuple(new_grads)
+
+
+def _tensor_or_tensors_to_tuple(
+    tensors: Optional[_TensorOrTensors], length: int
+) -> Tuple[_OptionalTensor, ...]:
+    if tensors is None:
+        return (None,) * length
+    if isinstance(tensors, torch.Tensor):
+        return (tensors,)
+    return tuple(tensors)
+
+
+def backward(
+    tensors: _TensorOrTensors,
+    grad_tensors: Optional[_TensorOrTensors] = None,
+    retain_graph: Optional[bool] = None,
+    create_graph: bool = False,
+    grad_variables: Optional[_TensorOrTensors] = None,
+    inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
+) -> None:
+    r"""Computes the sum of gradients of given tensors with respect to graph
+    leaves.
+
+    The graph is differentiated using the chain rule. If any of ``tensors``
+    are non-scalar (i.e. their data has more than one element) and require
+    gradient, then the Jacobian-vector product would be computed; in this
+    case the function additionally requires specifying ``grad_tensors``.
+    It should be a sequence of matching length, that contains the "vector"
+    in the Jacobian-vector product, usually the gradient of the differentiated
+    function w.r.t. corresponding tensors (``None`` is an acceptable value for
+    all tensors that don't need gradient tensors).
+
+    This function accumulates gradients in the leaves - you might need to zero
+    ``.grad`` attributes or set them to ``None`` before calling it.
+    See :ref:`Default gradient layouts`
+    for details on the memory layout of accumulated gradients.
+
+    .. note::
+        Using this method with ``create_graph=True`` will create a reference cycle
+        between the parameter and its gradient which can cause a memory leak.
+        We recommend using ``autograd.grad`` when creating the graph to avoid this.
+        If you have to use this function, make sure to reset the ``.grad`` fields of your
+        parameters to ``None`` after use to break the cycle and avoid the leak.
+
+    .. note::
+
+        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
+        in a user-specified CUDA stream context, see
+        :ref:`Stream semantics of backward passes`.
+
+    .. note::
+
+        When ``inputs`` are provided and a given input is not a leaf,
+        the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
+        It is an implementation detail on which the user should not rely.
+        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
+
+    Args:
+        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
+            computed.
+        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
+            the Jacobian-vector product, usually gradients w.r.t. each element of
+            corresponding tensors. None values can be specified for scalar Tensors or
+            ones that don't require grad. If a None value would be acceptable for all
+            grad_tensors, then this argument is optional.
+        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
+            will be freed. Note that in nearly all cases setting this option to ``True``
+            is not needed and often can be worked around in a much more efficient
+            way. Defaults to the value of ``create_graph``.
+        create_graph (bool, optional): If ``True``, graph of the derivative will
+            be constructed, allowing higher order derivative products to be computed.
+            Defaults to ``False``.
+        inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
+            will be accumulated into ``.grad``. All other Tensors will be ignored. If
+            not provided, the gradient is accumulated into all the leaf Tensors that
+            were used to compute the :attr:`tensors`.
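+
+    A minimal illustrative example (the tensor values below are arbitrary)::
+
+        >> x = torch.tensor([2.0, 3.0], requires_grad=True)
+        >> loss = (x * x).sum()
+        >> torch.autograd.backward(loss)
+        >> x.grad
+        tensor([4., 6.])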
+    """
+    if torch._C._are_functorch_transforms_active():
+        raise RuntimeError(
+            "backward() called inside a functorch transform. This is not "
+            "supported, please use functorch.grad or functorch.vjp instead "
+            "or call backward() outside of functorch transforms."
+        )
+
+    if grad_variables is not None:
+        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
+        if grad_tensors is None:
+            grad_tensors = grad_variables
+        else:
+            raise RuntimeError(
+                "'grad_tensors' and 'grad_variables' (deprecated) "
+                "arguments both passed to backward(). Please only "
+                "use 'grad_tensors'."
+            )
+    if inputs is not None and len(inputs) == 0:
+        raise RuntimeError("'inputs' argument to backward() cannot be empty.")
+
+    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+    inputs = (
+        (inputs,)
+        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
+        else tuple(inputs)
+        if inputs is not None
+        else tuple()
+    )
+
+    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
+    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
+    if retain_graph is None:
+        retain_graph = create_graph
+
+    # The reason we repeat the same comment below is that
+    # some Python versions print out the first line of a multi-line function
+    # call in the traceback and some print out the last line
+    _engine_run_backward(
+        tensors,
+        grad_tensors_,
+        retain_graph,
+        create_graph,
+        inputs,
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )
+
+
+def grad(
+    outputs: _TensorOrTensors,
+    inputs: _TensorOrTensorsOrGradEdge,
+    grad_outputs: Optional[_TensorOrTensors] = None,
+    retain_graph: Optional[bool] = None,
+    create_graph: bool = False,
+    only_inputs: bool = True,
+    allow_unused: Optional[bool] = None,
+    is_grads_batched: bool = False,
+    materialize_grads: bool = False,
+) -> Tuple[torch.Tensor, ...]:
+    r"""Computes and returns the sum of gradients of outputs with respect to
+    the inputs.
+
+    ``grad_outputs`` should be a sequence of length matching ``outputs``,
+    containing the "vector" in the vector-Jacobian product, usually the pre-computed
+    gradients w.r.t. each of the outputs. If an output doesn't require_grad,
+    then the gradient can be ``None``.
+
+    .. note::
+
+        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
+        in a user-specified CUDA stream context, see
+        :ref:`Stream semantics of backward passes`.
+
+    .. note::
+
+        ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
+        To accumulate gradient for other parts of the graph, please use
+        ``torch.autograd.backward``.
+
+    Args:
+        outputs (sequence of Tensor): outputs of the differentiated function.
+        inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
+            returned (and not accumulated into ``.grad``).
+        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
+            Usually gradients w.r.t. each output. None values can be specified for scalar
+            Tensors or ones that don't require grad. If a None value would be acceptable
+            for all grad_tensors, then this argument is optional. Default: None.
+        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
+            will be freed. Note that in nearly all cases setting this option to ``True``
+            is not needed and often can be worked around in a much more efficient
+            way. Defaults to the value of ``create_graph``.
+        create_graph (bool, optional): If ``True``, graph of the derivative will
+            be constructed, allowing higher order derivative products to be computed.
+            Default: ``False``.
+        allow_unused (Optional[bool], optional): If ``False``, specifying inputs
+            that were not used when computing outputs (and therefore their grad is
+            always zero) is an error. Defaults to the value of ``materialize_grads``.
+        is_grads_batched (bool, optional): If ``True``, the first dimension of each
+            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
+            Instead of computing a single vector-Jacobian product, we compute a
+            batch of vector-Jacobian products for each "vector" in the batch.
+            We use the vmap prototype feature as the backend to vectorize calls
+            to the autograd engine so that this computation can be performed in a
+            single call. This should lead to performance improvements when compared
+            to manually looping and performing backward multiple times. Note that
+            due to this feature being experimental, there may be performance
+            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
+            to show any performance warnings and file an issue on github if warnings exist
+            for your use case. Defaults to ``False``.
+        materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
+            to zero instead of None. This is useful when computing higher-order derivatives.
+            If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
+            will be raised. Defaults to ``False``.
+
+    """
+    if materialize_grads and allow_unused is False:
+        raise ValueError(
+            "Expected allow_unused to be True or not passed when materialize_grads=True, "
+            "but got: allow_unused=False."
+        )
+    if allow_unused is None:
+        allow_unused = materialize_grads
+    t_outputs = cast(
+        Tuple[torch.Tensor, ...],
+        (outputs,) if is_tensor_like(outputs) else tuple(outputs),
+    )
+    if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
+        inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
+    else:
+        inputs = tuple(inputs)
+    t_inputs = tuple(i for i in inputs if is_tensor_like(i))
+    overridable_args = t_outputs + t_inputs
+    if has_torch_function(overridable_args):
+        return handle_torch_function(
+            grad,
+            overridable_args,
+            t_outputs,
+            inputs,
+            grad_outputs=grad_outputs,
+            retain_graph=retain_graph,
+            create_graph=create_graph,
+            only_inputs=only_inputs,
+            allow_unused=allow_unused,
+            is_grads_batched=is_grads_batched,
+            materialize_grads=materialize_grads,
+        )
+
+    if not only_inputs:
+        warnings.warn(
+            "only_inputs argument is deprecated and is ignored now "
+            "(defaults to True). To accumulate gradient for other "
+            "parts of the graph, please use torch.autograd.backward."
+        )
+
+    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
+    grad_outputs_ = _make_grads(
+        t_outputs, grad_outputs_, is_grads_batched=is_grads_batched
+    )
+
+    if retain_graph is None:
+        retain_graph = create_graph
+
+    # The reason we repeat the same comment several times below is because
+    # some Python versions print out the first line of multi-line function
+    # calls in the traceback and some print out the last line
+    if is_grads_batched:
+
+        def vjp(gO):
+            return _engine_run_backward(
+                t_outputs,
+                gO,
+                retain_graph,
+                create_graph,
+                inputs,
+                allow_unused,
+                accumulate_grad=False,
+            )
+
+        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
+            grad_outputs_
+        )
+    else:
+        result = _engine_run_backward(
+            t_outputs,
+            grad_outputs_,
+            retain_graph,
+            create_graph,
+            inputs,
+            allow_unused,
+            accumulate_grad=False,
+        )
+    if materialize_grads:
+        if any(
+            result[i] is None and not is_tensor_like(inputs[i])
+            for i in range(len(inputs))
+        ):
+            raise RuntimeError(
+                "materialize_grads cannot be used when the given input is a GradientEdge"
+            )
+        result = tuple(
+            output
+            if output is not None
+            else torch.zeros_like(input, requires_grad=True)
+            for (output, input) in zip(result, inputs)
+        )
+    return result
+
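+# A minimal usage sketch (illustrative only, not part of the upstream module):
+# unlike ``backward``, ``torch.autograd.grad`` returns the gradients instead of
+# accumulating them into ``.grad``.
+#
+# >>> import torch
+# >>> x = torch.randn(3, requires_grad=True)
+# >>> y = (x ** 2).sum()
+# >>> (gx,) = torch.autograd.grad(y, x)
+# >>> torch.allclose(gx, 2 * x)
+# True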
+
+# This function applies in case of gradient checkpointing for memory
+# optimization. Currently, gradient checkpointing is supported only if the
+# execution engine is invoked through torch.autograd.backward() and its
+# inputs argument is not passed. It is not supported for torch.autograd.grad().
+# This is because if inputs are specified, the gradient won't be calculated for
+# anything else e.g. model parameters like weights, bias etc.
+#
+# This function returns whether the checkpointing is valid i.e. torch.autograd.backward
+# or not i.e. torch.autograd.grad. The implementation works by maintaining a thread
+# local variable in torch/csrc/autograd/engine.cpp which looks at the NodeTask
+# in the stack and before a NodeTask is executed in evaluate_function, it
+# checks for whether reentrant backwards is imperative or not.
+# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
+def _is_checkpoint_valid():
+    return Variable._execution_engine.is_checkpoint_valid()
+
+
+def variable(*args, **kwargs):
+    raise RuntimeError(
+        "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
+    )
+
+
+# Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing
+# f"{fn.__module__}.{fn.__name__}(...). This yields torch.autograd.variable.Variable(...) in the
+# output of an FX graph.  Unfortunately the module name torch.autograd.variable is shadowed by the
+# deprecated function - variable(...).
+variable.Variable = Variable  # type: ignore[attr-defined]
+
+if not torch._C._autograd_init():
+    raise RuntimeError("autograd initialization failed")
+
+# Import all native method/classes
+from torch._C._autograd import (
+    _add_metadata_json,
+    _disable_profiler,
+    _disable_profiler_legacy,
+    _enable_profiler,
+    _enable_profiler_legacy,
+    _enable_record_function,
+    _get_sequence_nr,
+    _kineto_step,
+    _KinetoEvent,
+    _pop_saved_tensors_default_hooks,
+    _prepare_profiler,
+    _profiler_enabled,
+    _ProfilerResult,
+    _push_saved_tensors_default_hooks,
+    _record_function_with_args_enter,
+    _record_function_with_args_exit,
+    _set_empty_test_observer,
+    _supported_activities,
+    DeviceType,
+    kineto_available,
+    ProfilerEvent,
+    SavedTensor,
+)
+
+from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
+
+from . import profiler
+
+
+def _register_py_tensor_class_for_device(device, cls):
+    if not isinstance(cls, type):
+        raise RuntimeError("cls isn't a typeinfo object")
+    torch._C._register_py_class_for_device(device, cls)
+
+
+is_multithreading_enabled = torch._C._is_multithreading_enabled
+torch._C._add_docstr(
+    is_multithreading_enabled, "Returns True if multithreading is currently enabled."
+)
+
+is_view_replay_enabled = torch._C._is_view_replay_enabled
+torch._C._add_docstr(
+    is_view_replay_enabled, "Returns True if view-replay is currently enabled."
+)
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a091e221d0e2d5407d335c0654ced30565213732
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2380b3f4a0cb058384ed7e85e3a83d162a48d8a1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..482794897f1220950a56b4186a0acace27a4b0b0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/forward_ad.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4e809bb299c836596338e96da94b6c5c366ffc6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/function.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a84d26aaab958f401e6b7622c9d058df6901569
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/functional.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91f23ce2455c44462ef55bdd24f6513d88fd4798
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/grad_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..640970ccfff440a9ab09f66bcd67c848b9743e40
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/gradcheck.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2fefb1b21f8f4330a3d6974e0c48b95aab5ad58
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/graph.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da7992f12a43fb4e76de133f059f5adaeca49da2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa2d42ae2ebd1f55441f20e5bf158da4762125f8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b93b70b4f609e0ef9e94595e213cd3eda332fa06
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/profiler_util.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2679d8e05cd940d0bec245bd095cfe367979123
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/__pycache__/variable.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/__init__.py b/MLPY/Lib/site-packages/torch/autograd/_functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc92d7c7fe74ad79a100e0233150f90becde55be
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/_functions/__init__.py
@@ -0,0 +1 @@
+from .tensor import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0756615e91568c405180bdd12586fa5a8bddf8e1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e44c30532ca8fffe471ba586b373ae3164852fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80751ea4f8aa4f040170f17b50aa068f3e60fcb2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/tensor.py b/MLPY/Lib/site-packages/torch/autograd/_functions/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dab52745bf21c148b6705d71e3dc306ad57ba6e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/_functions/tensor.py
@@ -0,0 +1,63 @@
+import operator
+import warnings
+from functools import reduce
+
+import torch
+import torch._utils
+from ..function import Function
+
+
+class Type(Function):
+    @staticmethod
+    def forward(ctx, i, dest_type):
+        warnings.warn(
+            "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use "
+            "torch.tensor.to(dtype=dtype) instead."
+        )
+        ctx.input_type = type(i)
+        ctx.input_device = -1 if not i.is_cuda else i.get_device()
+        return i.type(dest_type)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.input_device == -1:
+            return grad_output.type(ctx.input_type), None
+        else:
+            with torch.cuda.device(ctx.input_device):
+                return grad_output.type(ctx.input_type), None
+
+
+# TODO: deprecate this
+class Resize(Function):
+    @staticmethod
+    def forward(ctx, tensor, sizes):
+        ctx.sizes = sizes
+        ctx.numel = reduce(operator.mul, sizes, 1)
+        if tensor.numel() != ctx.numel:
+            raise RuntimeError(
+                (
+                    "requested resize to {} ({} elements in total), "
+                    "but the given tensor has a size of {} ({} elements). "
+                    "autograd's resize can only change the shape of a given "
+                    "tensor, while preserving the number of elements. "
+                ).format(
+                    "x".join(map(str, sizes)),
+                    ctx.numel,
+                    "x".join(map(str, tensor.size())),
+                    tensor.numel(),
+                )
+            )
+        ctx.input_sizes = tensor.size()
+        if tensor.is_quantized:
+            tensor.copy_(tensor)
+            return tensor.contiguous().view(*sizes)
+        if tensor.is_contiguous():
+            result = tensor.new(tensor).contiguous().view(*sizes)
+            return result
+        else:
+            return tensor.contiguous().view(*sizes)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert grad_output.numel() == ctx.numel
+        return grad_output.contiguous().view(ctx.input_sizes), None
diff --git a/MLPY/Lib/site-packages/torch/autograd/_functions/utils.py b/MLPY/Lib/site-packages/torch/autograd/_functions/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..86fd64ed2a2baf88561cd116e0a6ca8cc61dfc15
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/_functions/utils.py
@@ -0,0 +1,62 @@
+import operator
+from functools import reduce
+
+
+def maybe_view(tensor, size, check_same_size=True):
+    if check_same_size and tensor.size() == size:
+        return tensor
+    return tensor.contiguous().view(size)
+
+
+def maybe_unexpand(tensor, old_size, check_same_size=True):
+    if check_same_size and tensor.size() == old_size:
+        return tensor
+    num_unsqueezed = tensor.dim() - len(old_size)
+    expanded_dims = [
+        dim
+        for dim, (expanded, original) in enumerate(
+            zip(tensor.size()[num_unsqueezed:], old_size)
+        )
+        if expanded != original
+    ]
+
+    for _ in range(num_unsqueezed):
+        tensor = tensor.sum(0, keepdim=False)
+    for dim in expanded_dims:
+        tensor = tensor.sum(dim, keepdim=True)
+    return tensor
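+
+# A small illustrative sketch (not part of the upstream module): maybe_unexpand
+# reverses a broadcast/expand by summing the gradient back down to ``old_size``.
+# >>> import torch
+# >>> g = torch.ones(2, 3)                       # gradient of an expanded tensor
+# >>> maybe_unexpand(g, torch.Size([1, 3])).shape
+# torch.Size([1, 3])
+# >>> maybe_unexpand(g, torch.Size([3])).shape
+# torch.Size([3])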
+
+
+# Check whether the op enables broadcasting, and whether it is supported by ONNX.
+# If dims1 and dims2 are different, then broadcast is True.
+# We always assume the combination of dims1 and dims2 is broadcastable.
+# The following types of broadcasting are supported in ONNX:
+#     1) Only one element in dims2, such as dims2 = [1, 1]
+#     2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4]
+# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
+def check_onnx_broadcast(dims1, dims2):
+    broadcast = False
+    supported = True
+    len1 = len(dims1)
+    len2 = len(dims2)
+    numel1 = reduce(operator.mul, dims1)
+    numel2 = reduce(operator.mul, dims2)
+    if len1 < len2:
+        broadcast = True
+        if numel2 != 1:
+            supported = False
+    elif len1 > len2:
+        broadcast = True
+        if numel2 != 1 and dims1[len1 - len2 :] != dims2:
+            supported = False
+    else:
+        if dims1 != dims2:
+            broadcast = True
+            if numel2 != 1:
+                supported = False
+
+    if not supported:
+        raise ValueError(
+            f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}"
+        )
+    return broadcast
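+
+
+# A small illustrative sketch (not part of the upstream module) of the helper
+# above: it returns whether broadcasting occurs and raises if ONNX cannot
+# express the broadcast.
+# >>> check_onnx_broadcast([2, 3, 4], [3, 4])     # dims2 is a suffix of dims1
+# True
+# >>> check_onnx_broadcast([2, 3, 4], [1, 1])     # dims2 has a single element
+# True
+# >>> check_onnx_broadcast([2, 3, 4], [2, 3, 4])  # identical dims, no broadcast
+# False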
diff --git a/MLPY/Lib/site-packages/torch/autograd/anomaly_mode.py b/MLPY/Lib/site-packages/torch/autograd/anomaly_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f186c6527fd1b4273f1e1c3dccfde8c3f2d1e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/anomaly_mode.py
@@ -0,0 +1,119 @@
+import warnings
+
+import torch
+
+__all__ = ["detect_anomaly", "set_detect_anomaly"]
+
+
+class detect_anomaly:
+    r"""Context-manager that enable anomaly detection for the autograd engine.
+
+    This does two things:
+
+    - Running the forward pass with detection enabled will allow the backward
+      pass to print the traceback of the forward operation that created the failing
+      backward function.
+    - If ``check_nan`` is ``True``, any backward computation that generates "nan"
+      values will raise an error. Default ``True``.
+
+    .. warning::
+        This mode should be enabled only for debugging as the additional checks
+        will slow down your program execution.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
+        >>> import torch
+        >>> from torch import autograd
+        >>> class MyFunc(autograd.Function):
+        ...     @staticmethod
+        ...     def forward(ctx, inp):
+        ...         return inp.clone()
+        ...     @staticmethod
+        ...     def backward(ctx, gO):
+        ...         # Error during the backward pass
+        ...         raise RuntimeError("Some error in backward")
+        ...         return gO.clone()
+        >>> def run_fn(a):
+        ...     out = MyFunc.apply(a)
+        ...     return out.sum()
+        >>> inp = torch.rand(10, 10, requires_grad=True)
+        >>> out = run_fn(inp)
+        >>> out.backward()
+            Traceback (most recent call last):
+              File "", line 1, in 
+              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
+                torch.autograd.backward(self, gradient, retain_graph, create_graph)
+              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+                allow_unreachable=True)  # allow_unreachable flag
+              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+                return self._forward_cls.backward(self, *args)
+              File "", line 8, in backward
+            RuntimeError: Some error in backward
+        >>> with autograd.detect_anomaly():
+        ...     inp = torch.rand(10, 10, requires_grad=True)
+        ...     out = run_fn(inp)
+        ...     out.backward()
+            Traceback of forward call that caused the error:
+              File "tmp.py", line 53, in 
+                out = run_fn(inp)
+              File "tmp.py", line 44, in run_fn
+                out = MyFunc.apply(a)
+            Traceback (most recent call last):
+              File "", line 4, in 
+              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
+                torch.autograd.backward(self, gradient, retain_graph, create_graph)
+              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+                allow_unreachable=True)  # allow_unreachable flag
+              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+                return self._forward_cls.backward(self, *args)
+              File "", line 8, in backward
+            RuntimeError: Some error in backward
+
+    """
+
+    def __init__(self, check_nan=True) -> None:
+        self.prev = torch.is_anomaly_enabled()
+        self.check_nan = check_nan
+        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
+        warnings.warn(
+            "Anomaly Detection has been enabled. "
+            "This mode will increase the runtime "
+            "and should only be enabled for debugging.",
+            stacklevel=2,
+        )
+
+    def __enter__(self) -> None:
+        torch.set_anomaly_enabled(True, self.check_nan)
+
+    def __exit__(self, *args: object) -> None:
+        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
+
+
+class set_detect_anomaly:
+    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
+
+    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
+    based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    See ``detect_anomaly`` above for details of the anomaly detection behaviour.
+
+    Args:
+        mode (bool): Flag whether to enable anomaly detection (``True``),
+                     or disable (``False``).
+        check_nan (bool): Flag whether to raise an error when the backward
+                          generate "nan"
+
+    """
+
+    def __init__(self, mode: bool, check_nan: bool = True) -> None:
+        self.prev = torch.is_anomaly_enabled()
+        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
+        torch.set_anomaly_enabled(mode, check_nan)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args: object) -> None:
+        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
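+
+
+# A minimal usage sketch (illustrative only, not part of the upstream module):
+# ``set_detect_anomaly`` works both as a context manager and as a plain call.
+# >>> import torch
+# >>> with torch.autograd.set_detect_anomaly(True):
+# ...     y = (torch.randn(3, requires_grad=True) * 2).sum()
+# ...     y.backward()
+# >>> torch.is_anomaly_enabled()  # restored to the prior state (False by default)
+# False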
diff --git a/MLPY/Lib/site-packages/torch/autograd/forward_ad.py b/MLPY/Lib/site-packages/torch/autograd/forward_ad.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f0c1400a6a1ff12614a2741b3f33030003b560
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/forward_ad.py
@@ -0,0 +1,227 @@
+import os
+from collections import namedtuple
+
+from typing import Any
+
+import torch
+from .grad_mode import _DecoratorContextManager
+
+__all__ = [
+    "UnpackedDualTensor",
+    "enter_dual_level",
+    "exit_dual_level",
+    "make_dual",
+    "unpack_dual",
+    "dual_level",
+]
+
+# Global variable used to make the python API simpler to use
+_current_level = -1
+
+
+def enter_dual_level():
+    r"""Enter a new forward grad level.
+
+    This level can be used to make and unpack dual Tensors to compute
+    forward gradients.
+
+    This function also updates the current level that is used by default
+    by the other functions in this API.
+    """
+    global _current_level
+    new_level = torch._C._enter_dual_level()
+    if new_level != _current_level + 1:
+        raise RuntimeError(
+            "Entering a new forward AD level but the current level "
+            "is not valid. Make sure you did not modified it directly."
+        )
+    _current_level = new_level
+    return new_level
+
+
+def exit_dual_level(*, level=None):
+    r"""Exit a forward grad level.
+
+    This function deletes all the gradients associated with this
+    level. Only deleting the latest entered level is allowed.
+
+    This function also updates the current level that is used by default
+    by the other functions in this API.
+    """
+    global _current_level
+    if level is None:
+        level = _current_level
+    if level != _current_level:
+        raise RuntimeError(
+            "Trying to exit a forward AD level that was not the last one "
+            "that was created. This is not supported."
+        )
+    torch._C._exit_dual_level(level=level)
+    _current_level = level - 1
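+
+# A small illustrative sketch (not part of the upstream module): when the
+# low-level API is used directly, every entered level should be exited again;
+# the ``dual_level`` context manager below does this automatically.
+# >>> level = enter_dual_level()
+# >>> try:
+# ...     pass  # make_dual / unpack_dual calls would go here
+# ... finally:
+# ...     exit_dual_level(level=level)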
+
+
+def make_dual(tensor, tangent, *, level=None):
+    r"""Associate a tensor value with its tangent to create a "dual tensor" for forward AD gradient computation.
+
+    The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded
+    as an attribute as-is if it has the same storage layout or copied otherwise.
+    The tangent attribute can be recovered with :func:`unpack_dual`.
+
+    This function is backward differentiable.
+
+    Given a function `f` whose jacobian is `J`, it allows one to compute the Jacobian-vector product (`jvp`)
+    between `J` and a given vector `v` as follows.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> with dual_level():
+        ...     inp = make_dual(x, v)
+        ...     out = f(inp)
+        ...     y, jvp = unpack_dual(out)
+
+    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
+    for detailed steps on how to use this API.
+
+    """
+    # See NOTE: [forward-mode AD decompositions mechanism]
+    #
+    # Import from torch._decomp import decompositions_for_jvp to register
+    # decompositions for jvp to the jit registry
+    #
+    # FIXME: We specify that __debug__ must be True because
+    # if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
+    # following error:
+    #
+    # Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
+    # type Tuple[Tensor, Tensor]:
+    #   File ".../torch/_decomp/__init__.py", line 1585
+    #     else:
+    #         buffer = z
+    #     return min - torch.log1p(z), buffer
+    #     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+    if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
+        from torch._decomp import decompositions_for_jvp  # noqa: F401
+
+    if level is None:
+        level = _current_level
+
+    if level < 0:
+        raise RuntimeError(
+            "Trying to create a dual Tensor for forward AD but no level "
+            "exists, make sure to enter_dual_level() first."
+        )
+    if not (tensor.is_floating_point() or tensor.is_complex()):
+        raise ValueError(
+            f"Expected primal to be floating point or complex, but got: {tensor.dtype}"
+        )
+    if not (tangent.is_floating_point() or tangent.is_complex()):
+        raise ValueError(
+            f"Expected tangent to be floating point or complex, but got: {tangent.dtype}"
+        )
+
+    return torch._VF._make_dual(tensor, tangent, level=level)
+
+
+_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
+
+
+class UnpackedDualTensor(_UnpackedDualTensor):
+    r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
+
+    See :func:`unpack_dual` for more details.
+
+    """
+
+    pass
+
+
+def unpack_dual(tensor, *, level=None):
+    r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.
+
+    The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
+    :attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is.
+    Neither of these tensors can be a dual tensor of level :attr:`level`.
+
+    This function is backward differentiable.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> with dual_level():
+        ...     inp = make_dual(x, x_t)
+        ...     out = f(inp)
+        ...     y, jvp = unpack_dual(out)
+        ...     jvp = unpack_dual(out).tangent
+
+    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
+    for detailed steps on how to use this API.
+    """
+    if level is None:
+        level = _current_level
+
+    if level < 0:
+        return UnpackedDualTensor(tensor, None)
+
+    primal, dual = torch._VF._unpack_dual(tensor, level=level)
+
+    return UnpackedDualTensor(primal, dual)
+
+
+class dual_level(_DecoratorContextManager):
+    r"""Context-manager for forward AD, where all forward AD computation must occur within the ``dual_level`` context.
+
+    .. Note::
+
+        The ``dual_level`` context appropriately enters and exits the dual level to
+        control the current forward AD level, which is used by default by the other
+        functions in this API.
+
+        We currently don't plan to support nested ``dual_level`` contexts, however, so
+        only a single forward AD level is supported. To compute higher-order
+        forward grads, one can use :func:`torch.func.jvp`.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> x = torch.tensor([1])
+        >>> x_t = torch.tensor([1])
+        >>> with dual_level():
+        ...     inp = make_dual(x, x_t)
+        ...     # Do computations with inp
+        ...     out = your_fn(inp)
+        ...     _, grad = unpack_dual(out)
+        >>> grad is None
+        False
+        >>> # After exiting the level, the grad is deleted
+        >>> _, grad_after = unpack_dual(out)
+        >>> grad_after is None
+        True
+
+    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
+    for detailed steps on how to use this API.
+    """
+
+    def __enter__(self):
+        return enter_dual_level()
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        exit_dual_level()
+
+
+# Private helper functions
+_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled
+
+
+# Private helper function to enable or disable fwd grad.
+# If you're a user and want to use this, please file an issue to discuss the use case.
+class _set_fwd_grad_enabled(_DecoratorContextManager):
+    def __init__(self, mode: bool) -> None:
+        self.prev = _is_fwd_grad_enabled()
+        torch._C._set_fwd_grad_enabled(mode)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_fwd_grad_enabled(self.prev)
diff --git a/MLPY/Lib/site-packages/torch/autograd/function.py b/MLPY/Lib/site-packages/torch/autograd/function.py
new file mode 100644
index 0000000000000000000000000000000000000000..31ef625876de8ec2d9598e2c82c03b32b2e06f4f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/function.py
@@ -0,0 +1,883 @@
+import functools
+import inspect
+import itertools
+import warnings
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch._C as _C
+import torch._functorch as _functorch
+import torch.utils.hooks as hooks
+from torch._C import _functions
+from torch._functorch.autograd_function import custom_function_call
+
+__all__ = [
+    "FunctionCtx",
+    "BackwardCFunction",
+    "FunctionMeta",
+    "Function",
+    "once_differentiable",
+    "traceable",
+    "InplaceFunction",
+    "NestedIOFunction",
+]
+
+# Unique id provider for each class inheriting from Function
+# This is incremented in FunctionMeta during class definition
+AUTOGRAD_FUNCTION_COUNTER = itertools.count()
+
+
+# Formerly known as: _ContextMethodMixin
+class FunctionCtx:
+    def save_for_backward(self, *tensors: torch.Tensor):
+        r"""Save given tensors for a future call to :func:`~Function.backward`.
+
+        ``save_for_backward`` should be called at most once, only from inside the
+        :func:`forward` method, and only with tensors.
+
+        All tensors intended to be used in the backward pass should be saved
+        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
+        incorrect gradients and memory leaks, and enable the application of saved
+        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
+
+        Note that if intermediary tensors, tensors that are neither inputs
+        nor outputs of :func:`forward`, are saved for backward, your custom Function
+        may not support double backward.
+        Custom Functions that do not support double backward should decorate their
+        :func:`backward` method with ``@once_differentiable`` so that performing
+        double backward raises an error. If you'd like to support double backward,
+        you can either recompute intermediaries based on the inputs during backward
+        or return the intermediaries as the outputs of the custom Function. See the
+        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
+        for more details.
+
+        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
+        attribute. Before returning them to the user, a check is made to ensure
+        they weren't used in any in-place operation that modified their content.
+
+        Arguments can also be ``None``. This is a no-op.
+
+        See :ref:`extending-autograd` for more details on how to use this method.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+            >>> class Func(Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
+            >>>         w = x * z
+            >>>         out = x * y + y * z + w * y
+            >>>         ctx.save_for_backward(x, y, w, out)
+            >>>         ctx.z = z  # z is not a tensor
+            >>>         return out
+            >>>
+            >>>     @staticmethod
+            >>>     @once_differentiable
+            >>>     def backward(ctx, grad_out):
+            >>>         x, y, w, out = ctx.saved_tensors
+            >>>         z = ctx.z
+            >>>         gx = grad_out * (y + y * z)
+            >>>         gy = grad_out * (x + z + w)
+            >>>         gz = None
+            >>>         return gx, gy, gz
+            >>>
+            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
+            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
+            >>> c = 4
+            >>> d = Func.apply(a, b, c)
+
+        """
+        self.to_save = tensors
+
+    def save_for_forward(self, *tensors: torch.Tensor):
+        r"""Save given tensors for a future call to :func:`~Function.jvp`.
+
+        ``save_for_forward`` should be called only once, from inside the :func:`forward`
+        method, and only with tensors.
+
+        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
+        attribute.
+
+        Arguments can also be ``None``. This is a no-op.
+
+        See :ref:`extending-autograd` for more details on how to use this method.
+
+        Example::
+            >>> # xdoctest: +SKIP
+            >>> class Func(torch.autograd.Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
+            >>>         ctx.save_for_backward(x, y)
+            >>>         ctx.save_for_forward(x, y)
+            >>>         ctx.z = z
+            >>>         return x * y * z
+            >>>
+            >>>     @staticmethod
+            >>>     def jvp(ctx, x_t, y_t, _):
+            >>>         x, y = ctx.saved_tensors
+            >>>         z = ctx.z
+            >>>         return z * (y * x_t + x * y_t)
+            >>>
+            >>>     @staticmethod
+            >>>     def vjp(ctx, grad_out):
+            >>>         x, y = ctx.saved_tensors
+            >>>         z = ctx.z
+            >>>         return z * grad_out * y, z * grad_out * x, None
+            >>>
+            >>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
+            >>>     t = torch.tensor(1., dtype=torch.double)
+            >>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
+            >>>     c = 4
+            >>>
+            >>>     with fwAD.dual_level():
+            >>>         a_dual = fwAD.make_dual(a, t)
+            >>>         d = Func.apply(a_dual, b, c)
+
+        """
+        for tensor in tensors:
+            assert isinstance(tensor, torch.Tensor) or tensor is None, (
+                "save_for_forward expects all arguments to be tensors; you should "
+                "save non-tensors as attributes on ctx."
+            )
+
+        self.saved_for_forward = tensors
+
+    def mark_dirty(self, *args: torch.Tensor):
+        r"""Mark given tensors as modified in an in-place operation.
+
+        **This should be called at most once, only from inside the**
+        :func:`forward` **method, and all arguments should be inputs.**
+
+        Every tensor that's been modified in-place in a call to :func:`forward`
+        should be given to this function, to ensure correctness of our checks.
+        It doesn't matter whether the function is called before or after
+        modification.
+
+        Examples::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+            >>> class Inplace(Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x):
+            >>>         x_npy = x.numpy() # x_npy shares storage with x
+            >>>         x_npy += 1
+            >>>         ctx.mark_dirty(x)
+            >>>         return x
+            >>>
+            >>>     @staticmethod
+            >>>     @once_differentiable
+            >>>     def backward(ctx, grad_output):
+            >>>         return grad_output
+            >>>
+            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
+            >>> b = a * a
+            >>> Inplace.apply(a)  # This would lead to wrong gradients!
+            >>>                   # but the engine would not know unless we mark_dirty
+            >>> # xdoctest: +SKIP
+            >>> b.backward() # RuntimeError: one of the variables needed for gradient
+            >>>              # computation has been modified by an inplace operation
+
+        """
+        self.dirty_tensors = args
+
+    def mark_shared_storage(self, *pairs):
+        warnings.warn(
+            "mark_shared_storage is deprecated. "
+            "Tensors with shared storages are automatically tracked. Note "
+            "that calls to `set_()` are not tracked"
+        )
+
+    def mark_non_differentiable(self, *args: torch.Tensor):
+        r"""Mark outputs as non-differentiable.
+
+        **This should be called at most once, only from inside the**
+        :func:`forward` **method, and all arguments should be tensor outputs.**
+
+        This will mark outputs as not requiring gradients, increasing the
+        efficiency of backward computation. You still need to accept a gradient
+        for each output in :meth:`~Function.backward`, but it's always going to
+        be a zero tensor with the same shape as the corresponding output.
+
+        This is used e.g. for indices returned from a sort. See example::
+            >>> class Func(Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x):
+            >>>         sorted, idx = x.sort()
+            >>>         ctx.mark_non_differentiable(idx)
+            >>>         ctx.save_for_backward(x, idx)
+            >>>         return sorted, idx
+            >>>
+            >>>     @staticmethod
+            >>>     @once_differentiable
+            >>>     def backward(ctx, g1, g2):  # still need to accept g2
+            >>>         x, idx = ctx.saved_tensors
+            >>>         grad_input = torch.zeros_like(x)
+            >>>         grad_input.index_add_(0, idx, g1)
+            >>>         return grad_input
+
+        """
+        self.non_differentiable = args
+
+    def set_materialize_grads(self, value: bool):
+        r"""Set whether to materialize grad tensors. Default is ``True``.
+
+        **This should be called only from inside the** :func:`forward` **method**
+
+        If ``True``, undefined grad tensors will be expanded to tensors full of zeros
+        prior to calling the :func:`backward` and :func:`jvp` methods.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+            >>> class SimpleFunc(Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x):
+            >>>         return x.clone(), x.clone()
+            >>>
+            >>>     @staticmethod
+            >>>     @once_differentiable
+            >>>     def backward(ctx, g1, g2):
+            >>>         return g1 + g2  # No check for None necessary
+            >>>
+            >>> # We modify SimpleFunc to handle non-materialized grad outputs
+            >>> class Func(Function):
+            >>>     @staticmethod
+            >>>     def forward(ctx, x):
+            >>>         ctx.set_materialize_grads(False)
+            >>>         ctx.save_for_backward(x)
+            >>>         return x.clone(), x.clone()
+            >>>
+            >>>     @staticmethod
+            >>>     @once_differentiable
+            >>>     def backward(ctx, g1, g2):
+            >>>         x, = ctx.saved_tensors
+            >>>         grad_input = torch.zeros_like(x)
+            >>>         if g1 is not None:  # We must check for None now
+            >>>             grad_input += g1
+            >>>         if g2 is not None:
+            >>>             grad_input += g2
+            >>>         return grad_input
+            >>>
+            >>> a = torch.tensor(1., requires_grad=True)
+            >>> b, _ = Func.apply(a)  # induces g2 to be undefined
+
+        """
+        self.materialize_grads = value
+
+
+# DO NOT USE: This is only defined to be able to load old serialized models
+_ContextMethodMixin = FunctionCtx
+
+
+class _HookMixin:
+    @staticmethod
+    def _register_hook(backward_hooks, hook):
+        if backward_hooks is None:
+            backward_hooks = OrderedDict()
+        handle = hooks.RemovableHandle(backward_hooks)
+        backward_hooks[handle.id] = hook
+        return backward_hooks, handle
+
+
+class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
+    r"""
+    This class is used for internal autograd work. Do not use.
+    """
+
+    def apply(self, *args):
+        r"""
+        Apply method used when executing this Node during the backward
+        """
+        # _forward_cls is defined by derived class
+        # The user should define either backward or vjp but never both.
+        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
+        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
+        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
+            raise RuntimeError(
+                "Implementing both 'backward' and 'vjp' for a custom "
+                "Function is not allowed. You should only implement one "
+                "of them."
+            )
+        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
+        return user_fn(self, *args)
+
+    def apply_jvp(self, *args):
+        r"""
+        Apply method used when executing forward mode AD during the forward
+        """
+        # _forward_cls is defined by derived class
+        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]
+
+    def _compiled_autograd_key(self):
+        return self._forward_cls._compiled_autograd_key(self)  # type: ignore[attr-defined]
+
+
+def _warn_traceable_deprecated():
+    warnings.warn(
+        "The is_traceable field on torch.autograd.Function is deprecated "
+        "and will be removed in PyTorch 2.4.",
+        stacklevel=3,
+    )
+
+
+class FunctionMeta(type):
+    """Function metaclass.
+
+    This metaclass sets up the following properties:
+        _backward_cls: The Function class corresponding to the differentiated
+            version of this function (which is generated on the fly by this
+            metaclass).
+    """
+
+    def __init__(cls, name, bases, attrs):
+        backward_fn = type(
+            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
+        )
+        backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER)  # type: ignore[attr-defined]
+        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
+            "_compiled_autograd_should_lift", True
+        )
+        cls._backward_cls = backward_fn
+
+        if "is_traceable" in attrs and attrs["is_traceable"] is True:
+            _warn_traceable_deprecated()
+
+        super().__init__(name, bases, attrs)
+
+    def __getattribute__(cls, name):
+        if name == "is_traceable":
+            _warn_traceable_deprecated()
+        return super().__getattribute__(name)
+
+    def __setattr__(cls, name, value):
+        if name == "is_traceable" and value is True:
+            warnings.warn(
+                "The is_traceable field on torch.autograd.Function is deprecated "
+                "and will be removed in PyTorch 2.4.",
+                stacklevel=2,
+            )
+        return super().__setattr__(name, value)
+
+
+class _SingleLevelFunction(
+    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
+):
+    @staticmethod
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        r"""Define the forward of the custom autograd Function.
+
+        This function is to be overridden by all subclasses.
+        There are two ways to define forward:
+
+        Usage 1 (Combined forward and ctx)::
+
+            @staticmethod
+            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+                pass
+
+        - It must accept a context ctx as the first argument, followed by any
+          number of arguments (tensors or other types).
+        - See :ref:`combining-forward-context` for more details
+
+        Usage 2 (Separate forward and ctx)::
+
+            @staticmethod
+            def forward(*args: Any, **kwargs: Any) -> Any:
+                pass
+
+            @staticmethod
+            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
+                pass
+
+        - The forward no longer accepts a ctx argument.
+        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
+          staticmethod to handle setting up the ``ctx`` object.
+          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
+          to the forward.
+        - See :ref:`extending-autograd` for more details
+
+        The context can be used to store arbitrary data that can be then
+        retrieved during the backward pass. Tensors should not be stored
+        directly on `ctx` (though this is not currently enforced for
+        backward compatibility). Instead, tensors should be saved either with
+        :func:`ctx.save_for_backward` if they are intended to be used in
+        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
+        if they are intended to be used in ``jvp``.
+        """
+        raise NotImplementedError(
+            "You must implement the forward function for custom autograd.Function."
+        )
+
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+        r"""There are two ways to define the forward pass of an autograd.Function.
+
+        Either:
+
+        1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
+           ``setup_context`` is not overridden. Setting up the ctx for backward
+           happens inside the ``forward``.
+        2. Override forward with the signature ``forward(*args, **kwargs)`` and
+           override ``setup_context``. Setting up the ctx for backward happens
+           inside ``setup_context`` (as opposed to inside the ``forward``)
+
+        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
+        """
+        raise NotImplementedError("setup_context is not implemented.")
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        r"""Define a formula for differentiating the operation with backward mode automatic differentiation.
+
+        This function is to be overridden by all subclasses.
+        (Defining this function is equivalent to defining the ``vjp`` function.)
+
+        It must accept a context :attr:`ctx` as the first argument, followed by
+        as many outputs as the :func:`forward` returned (None will be passed in
+        for non tensor outputs of the forward function),
+        and it should return as many tensors, as there were inputs to
+        :func:`forward`. Each argument is the gradient w.r.t the given output,
+        and each returned value should be the gradient w.r.t. the
+        corresponding input. If an input is not a Tensor or is a Tensor not
+        requiring grads, you can just pass None as a gradient for that input.
+
+        The context can be used to retrieve tensors saved during the forward
+        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
+        of booleans representing whether each input needs gradient. E.g.,
+        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
+        first input to :func:`forward` needs gradient computed w.r.t. the
+        output.
+        """
+        raise NotImplementedError(
+            "You must implement either the backward or vjp method for "
+            "your custom autograd.Function to use it with backward "
+            "mode AD."
+        )
+
+    # vjp and backward are alias of each other
+    vjp = backward
+
+    @staticmethod
+    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+        r"""Define a formula for differentiating the operation with forward mode automatic differentiation.
+
+        This function is to be overridden by all subclasses.
+        It must accept a context :attr:`ctx` as the first argument, followed by
+        as many inputs as the :func:`forward` got (None will be passed in
+        for non tensor inputs of the forward function),
+        and it should return as many tensors as there were outputs to
+        :func:`forward`. Each argument is the gradient w.r.t the given input,
+        and each returned value should be the gradient w.r.t. the
+        corresponding output. If an output is not a Tensor or the function is not
+        differentiable with respect to that output, you can just pass None as a
+        gradient for that output.
+
+        You can use the :attr:`ctx` object to pass any value from the forward to this
+        function.
+        """
+        raise NotImplementedError(
+            "You must implement the jvp function for custom "
+            "autograd.Function to use it with forward mode AD."
+        )
+
+
+class Function(_SingleLevelFunction):
+    r"""Base class to create custom `autograd.Function`.
+
+    To create a custom `autograd.Function`, subclass this class and implement
+    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
+    op in the forward pass, call the class method ``apply``. Do not call
+    :meth:`forward` directly.
+
+    To ensure correctness and best performance, make sure you are calling the
+    correct methods on ``ctx`` and validating your backward function using
+    :func:`torch.autograd.gradcheck`.
+
+    See :ref:`extending-autograd` for more details on how to use this class.
+
+    Examples::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> class Exp(Function):
+        >>>     @staticmethod
+        >>>     def forward(ctx, i):
+        >>>         result = i.exp()
+        >>>         ctx.save_for_backward(result)
+        >>>         return result
+        >>>
+        >>>     @staticmethod
+        >>>     def backward(ctx, grad_output):
+        >>>         result, = ctx.saved_tensors
+        >>>         return grad_output * result
+        >>>
+        >>> # Use it by calling the apply method:
+        >>> # xdoctest: +SKIP
+        >>> output = Exp.apply(input)
+    """
+
+    def __init__(self, *args, **kwargs):
+        cls = self.__class__
+        warnings.warn(
+            f"{cls} should not be instantiated. Methods on autograd functions"
+            "are all static, so you should invoke them on the class itself. "
+            "Instantiating an autograd function will raise an "
+            "error in a future version of PyTorch.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    def __call__(self, *args, **kwargs):
+        raise RuntimeError(
+            "Legacy autograd function with non-static forward method is deprecated. "
+            "Please use new-style autograd function with static forward method. "
+            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
+        )
+
+    # for the tracer
+    is_traceable = False
+
+    """
+    Bool that specifies if PyTorch should attempt to autogenerate
+    :func:`torch.vmap` support for this autograd.Function. You may set this to
+    True only if this autograd.Function's forward, backward, and jvp (if they
+    exist) are written using PyTorch operations; otherwise, please override
+    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.
+
+    Please see :ref:`func-autograd-function` for more details.
+    """
+    generate_vmap_rule = False
+
+    @staticmethod
+    def vmap(info, in_dims, *args):
+        r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.
+
+        For a :func:`torch.autograd.Function` to support
+        :func:`torch.vmap`, you must either override this static method, or set
+        ``generate_vmap_rule`` to ``True`` (you may not do both).
+
+        If you choose to override this staticmethod: it must accept
+
+        - an ``info`` object as the first argument. ``info.batch_size``
+          specifies the size of the dimension being vmapped over,
+          while ``info.randomness`` is the randomness option passed to
+          :func:`torch.vmap`.
+        - an ``in_dims`` tuple as the second argument.
+          For each arg in ``args``, ``in_dims`` has a corresponding
+          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
+          the arg is not being vmapped over, otherwise, it is an integer
+          specifying what dimension of the Tensor is being vmapped over.
+        - ``*args``, which is the same as the args to :meth:`~Function.forward`.
+
+        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
+        Similar to ``in_dims``, ``out_dims`` should be of the same structure as
+        ``output`` and contain one ``out_dim`` per output that specifies if the
+        output has the vmapped dimension and what index it is in.
+
+        Please see :ref:`func-autograd-function` for more details.
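+
+        A minimal illustrative sketch for a simple elementwise op (the name
+        ``MyExp`` is a placeholder)::
+
+            >>> # xdoctest: +SKIP
+            >>> class MyExp(torch.autograd.Function):
+            >>>     @staticmethod
+            >>>     def forward(x):
+            >>>         return x.exp()
+            >>>
+            >>>     @staticmethod
+            >>>     def setup_context(ctx, inputs, output):
+            >>>         ctx.save_for_backward(output)
+            >>>
+            >>>     @staticmethod
+            >>>     def backward(ctx, grad_output):
+            >>>         y, = ctx.saved_tensors
+            >>>         return grad_output * y
+            >>>
+            >>>     @staticmethod
+            >>>     def vmap(info, in_dims, x):
+            >>>         # exp is elementwise, so the batched input can be used
+            >>>         # directly; the output keeps the same batched dim.
+            >>>         return x.exp(), in_dims[0]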
+        """
+        raise NotImplementedError(
+            "To use autograd.Function with vmap, you must either override the "
+            "vmap staticmethod or set generate_vmap_rule=True."
+        )
+
+    @classmethod
+    def apply(cls, *args, **kwargs):
+        def bind_default_args(func, *args, **kwargs):
+            signature = inspect.signature(func)
+            bound_args = signature.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+
+            return bound_args.args
+
+        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
+        if is_setup_ctx_defined:
+            args = bind_default_args(cls.forward, *args, **kwargs)
+
+        if not torch._C._are_functorch_transforms_active():
+            # See NOTE: [functorch vjp and autograd interaction]
+            args = _functorch.utils.unwrap_dead_wrappers(args)
+            return super().apply(*args, **kwargs)  # type: ignore[misc]
+
+        if not is_setup_ctx_defined:
+            raise RuntimeError(
+                "In order to use an autograd.Function with functorch transforms "
+                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
+                "staticmethod. For more details, please see "
+                "https://pytorch.org/docs/master/notes/extending.func.html"
+            )
+
+        return custom_function_call(cls, *args, **kwargs)
+
+    @staticmethod
+    def _compiled_autograd_key(ctx):
+        return (ctx._autograd_function_id,)
+
+
+def once_differentiable(fn):
+    @functools.wraps(fn)
+    def wrapper(ctx, *args):
+        with torch.no_grad():
+            outputs = fn(ctx, *args)
+
+        if not torch.is_grad_enabled():
+            return outputs
+
+        # If any of the inputs have requires_grad=True, we force the outputs
+        # to have requires_grad=True but point to a grad_fn which throws an
+        # error message during (double) back-propagation.
+        # XXX: this is only an approximation of requires_grad - there's no way
+        # to figure out if fn didn't use ctx.saved_tensors and as a result
+        # some Tensors might require grad, even if no args do.
+        # Unfortunately, this leads to unexpected error messages ("no nodes
+        # require computing gradients"), but I don't have a better idea.
+        # These functions would raise an error in backward anyway.
+        requires_grad = any(
+            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
+        )
+        if not requires_grad:
+            return outputs
+
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+
+        err_fn = _functions.DelayedError(
+            b"trying to differentiate twice a function that was marked "
+            b"with @once_differentiable",
+            len(outputs),
+        )
+
+        # Create aliases of each output that has requires_grad=True. We need
+        # at least one of the inputs to err_fn to require grad so that the
+        # output will have a grad_fn.
+        def fake_requires_grad(var):
+            if var is not None:
+                var = var.detach()
+                var.requires_grad = True
+            return var
+
+        return err_fn(*[fake_requires_grad(v) for v in outputs])
+
+    return wrapper
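+
+
+# A minimal illustrative usage sketch (the name `Exp` is a placeholder):
+# `once_differentiable` is stacked under `@staticmethod` on a custom
+# Function's backward so that attempting a second differentiation raises a
+# clear error instead of silently producing wrong higher-order gradients:
+#
+#     class Exp(Function):
+#         @staticmethod
+#         def forward(ctx, x):
+#             y = x.exp()
+#             ctx.save_for_backward(y)
+#             return y
+#
+#         @staticmethod
+#         @once_differentiable
+#         def backward(ctx, grad_output):
+#             y, = ctx.saved_tensors
+#             return grad_output * y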
+
+
+def traceable(fn_cls):
+    r"""Mark Function as traceable for the JIT.
+
+    Traceable functions have additional restrictions - they can't pass any
+    data-dependent values to backward (e.g. Prod passes the output, which makes
+    it non-traceable), and their backward should be implemented entirely in terms
+    of operations on autograd Tensors in all cases.
+
+    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH
+    CARE (it can give incorrect results otherwise).
+    """
+    warnings.warn(
+        "torch.autograd.function.traceable is deprecated "
+        "and will be removed in PyTorch 2.4.",
+        stacklevel=2,
+    )
+    fn_cls.is_traceable = True
+    return fn_cls
+
+
+class InplaceFunction(Function):
+    r"""
+    This class is here only for backward compatibility reasons.
+    Use :class:`Function` instead of this for any new use case.
+    """
+
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+
+
+def _nested_map(condition, fn, condition_msg=None):
+    def _map(obj):
+        if condition(obj):
+            return fn(obj)
+        elif obj is None:
+            return None
+        elif isinstance(obj, (list, tuple)):
+            mapped = (_map(x) for x in obj)
+            if hasattr(obj, "_fields"):
+                # obj is namedtuple
+                return type(obj)(*mapped)
+            return type(obj)(mapped)
+        elif isinstance(obj, dict):
+            return {x: _map(obj[x]) for x in obj}
+        else:
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )
+
+    return _map
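+
+
+# Illustrative example of the helper above (`a` and `b` stand for Tensors):
+#     _nested_map(lambda t: isinstance(t, torch.Tensor), lambda t: t * 2)([a, (b, None)])
+# returns [a * 2, (b * 2, None)], preserving the container structure, and raises
+# ValueError for any container type it does not recognize.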
+
+
+def _jit_unwrap_structured(obj):
+    if hasattr(obj, "_jit_unwrap"):
+        return obj._jit_unwrap()
+    return obj
+
+
+def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
+    def _iter(obj):
+        if conversion is not None:
+            obj = conversion(obj)
+        if condition(obj):
+            yield obj
+        elif obj is None:
+            return
+        elif isinstance(obj, (list, tuple)):
+            for o in obj:
+                yield from _iter(o)
+        elif isinstance(obj, dict):
+            # We only accept primitive key types, so we needn't inspect them
+            for o in obj.values():
+                yield from _iter(o)
+        elif allow_unknown:
+            yield obj
+        else:
+            raise ValueError(
+                "Auto nesting doesn't know how to process "
+                "an input object of type "
+                + torch.typename(obj)
+                + (
+                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
+                    if condition_msg
+                    else ""
+                )
+            )
+
+    return _iter
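+
+
+# Illustrative example of the helper above (`t1` and `t2` stand for Tensors):
+#     it = _iter_filter(lambda x: isinstance(x, torch.Tensor))
+#     list(it({"a": t1, "b": [t2, None]}))  # -> [t1, t2]
+# Leaves that fail the condition raise ValueError unless allow_unknown=True.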
+
+
+def _unflatten(input, proto):
+    # unflatten a list or tuple input into a nested list/tuple structure
+    # specified by proto
+    def unflatten_helper(input, proto):
+        res: List[Optional[torch.Tensor]] = []
+        if hasattr(proto, "_jit_wrap"):
+            return proto._jit_wrap(input)
+        if not isinstance(proto, (list, tuple)):
+            return input[0], input[1:]
+        for e in proto:
+            if e is None:
+                res.append(e)
+            else:
+                res_e, input = unflatten_helper(input, e)
+                res.append(res_e)
+        return type(proto)(res), input
+
+    return unflatten_helper(input, proto)[0]
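+
+
+# Illustrative example of the helper above (`t1`..`t3` and `ref_*` stand for Tensors):
+#     _unflatten((t1, t2, t3), [ref_a, (ref_b, ref_c)])  # -> [t1, (t2, t3)]
+# i.e. the flat sequence is re-nested to mirror the structure of `proto`.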
+
+
+_iter_jit_values = _iter_filter(
+    lambda o: o is None or isinstance(o, torch._C.Value),
+    condition_msg="jit's Values or None",
+)
+_iter_tensors = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    condition_msg="Tensors",
+    conversion=_jit_unwrap_structured,
+)
+_iter_tensors_permissive = _iter_filter(
+    lambda x: isinstance(x, torch.Tensor),
+    allow_unknown=True,
+    condition_msg="Tensors (permissive)",
+)
+_iter_None_tensors = _iter_filter(
+    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
+)
+_map_tensor_data = _nested_map(
+    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
+)
+
+
+class NestedIOFunction(Function):
+    r"""
+    This class is here only for backward compatibility reasons.
+    Use :class:`Function` instead of this for any new use case.
+    """
+    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
+    # superclass (Function) but are instance methods here, which mypy reports as incompatible.
+
+    def _do_forward(self, *input):
+        self._nested_input = input
+        flat_input = tuple(_iter_tensors(input))
+        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
+        nested_tensors = _unflatten(flat_output, self._nested_output)
+        return nested_tensors
+
+    def _do_backward(self, gradients, retain_variables):
+        self.retain_variables = retain_variables
+        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
+        if not retain_variables:
+            del self._nested_output
+            del self._to_save_nested
+        return result
+
+    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
+        r"""
+        Shared backward utility.
+        """
+        nested_gradients = _unflatten(gradients, self._nested_output)
+        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
+        return tuple(_iter_None_tensors(result))
+
+    __call__ = _do_forward
+
+    def forward(self, *args: Any) -> Any:  # type: ignore[override]
+        r"""
+        Shared forward utility.
+        """
+        nested_tensors = _map_tensor_data(self._nested_input)
+        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
+        del self._nested_input
+        self._nested_output = result
+        return tuple(_iter_tensors(result))
+
+    def save_for_backward(self, *args: Any) -> None:
+        r"""
+        See :meth:`Function.save_for_backward`.
+        """
+        self.to_save = tuple(_iter_tensors(args))
+        self._to_save_nested = args
+
+    @property
+    def saved_tensors(self):
+        r"""
+        See :meth:`Function.saved_tensors`.
+        """
+        flat_tensors = super().saved_tensors  # type: ignore[misc]
+        return _unflatten(flat_tensors, self._to_save_nested)
+
+    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
+        r"""
+        See :meth:`Function.mark_dirty`.
+        """
+        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
+
+    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
+        r"""
+        See :meth:`Function.mark_non_differentiable`.
+        """
+        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
+
+    def forward_extended(self, *input: Any) -> None:
+        r"""
+        User defined forward.
+        """
+        raise NotImplementedError
+
+    def backward_extended(self, *grad_output: Any) -> None:
+        r"""
+        User defined backward.
+        """
+        raise NotImplementedError
diff --git a/MLPY/Lib/site-packages/torch/autograd/functional.py b/MLPY/Lib/site-packages/torch/autograd/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..23d817dbef15407efc63e8397da87dd5fcc5b99a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/functional.py
@@ -0,0 +1,1182 @@
+from typing import List, Tuple
+
+import torch
+from torch._vmap_internals import _vmap
+from . import forward_ad as fwAD
+
+__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
+
+# Utility functions
+
+
+def _as_tuple_nocheck(x):
+    if isinstance(x, tuple):
+        return x
+    elif isinstance(x, list):
+        return tuple(x)
+    else:
+        return (x,)
+
+
+def _as_tuple(inp, arg_name=None, fn_name=None):
+    # Ensures that inp is a tuple of Tensors
+    # Returns whether or not the original inp was a tuple and the tupled version of the input
+    if arg_name is None and fn_name is None:
+        return _as_tuple_nocheck(inp)
+
+    is_inp_tuple = True
+    if not isinstance(inp, tuple):
+        inp = (inp,)
+        is_inp_tuple = False
+
+    for i, el in enumerate(inp):
+        if not isinstance(el, torch.Tensor):
+            if is_inp_tuple:
+                raise TypeError(
+                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
+                    f" value at index {i} has type {type(el)}."
+                )
+            else:
+                raise TypeError(
+                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
+                    f" given {arg_name} has type {type(el)}."
+                )
+
+    return is_inp_tuple, inp
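+
+
+# Illustrative example of the helper above (`t` stands for a Tensor):
+#     _as_tuple(t, "inputs", "vjp")       # -> (False, (t,))
+#     _as_tuple((t, t), "inputs", "vjp")  # -> (True, (t, t))
+# Non-Tensor entries raise a TypeError naming the offending argument and function.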
+
+
+def _tuple_postprocess(res, to_unpack):
+    # Unpacks a potentially nested tuple of Tensors
+    # to_unpack should be a single boolean or a tuple of two booleans.
+    # It is used to:
+    # - invert _as_tuple when res should match the inp given to _as_tuple
+    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
+    if isinstance(to_unpack, tuple):
+        assert len(to_unpack) == 2
+        if not to_unpack[1]:
+            res = tuple(el[0] for el in res)
+        if not to_unpack[0]:
+            res = res[0]
+    else:
+        if not to_unpack:
+            res = res[0]
+    return res
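+
+
+# Illustrative example of the helper above (`t` stands for a Tensor):
+#     _tuple_postprocess((t,), False)             # -> t (single element unpacked)
+#     _tuple_postprocess(((t,),), (True, False))  # -> (t,) (inner nesting removed)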
+
+
+def _grad_preprocess(inputs, create_graph, need_graph):
+    # Preprocess the inputs to make sure they require gradient
+    # inputs is a tuple of Tensors to preprocess
+    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
+    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
+    # Note that we *always* create a new Tensor object to be able to see the difference between
+    # inputs given as arguments and the same Tensors automatically captured by the user function.
+    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
+    res = []
+    for inp in inputs:
+        if create_graph and inp.requires_grad:
+            # Create at least a new Tensor object in a differentiable way
+            if not inp.is_sparse:
+                # Use .view_as() to get a shallow copy
+                res.append(inp.view_as(inp))
+            else:
+                # We cannot use view for sparse Tensors so we clone
+                res.append(inp.clone())
+        else:
+            res.append(inp.detach().requires_grad_(need_graph))
+    return tuple(res)
+
+
+def _grad_postprocess(inputs, create_graph):
+    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
+    # request it.
+    if isinstance(inputs[0], torch.Tensor):
+        if not create_graph:
+            return tuple(inp.detach() for inp in inputs)
+        else:
+            return inputs
+    else:
+        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
+
+
+def _validate_v(v, other, is_other_tuple):
+    # This assumes that other is the correct shape, and v should match
+    # Both are assumed to be tuples of Tensors
+    if len(other) != len(v):
+        if is_other_tuple:
+            raise RuntimeError(
+                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
+            )
+        else:
+            raise RuntimeError("The given v should contain a single Tensor.")
+
+    for idx, (el_v, el_other) in enumerate(zip(v, other)):
+        if el_v.size() != el_other.size():
+            prepend = ""
+            if is_other_tuple:
+                prepend = f"Entry {idx} in "
+            raise RuntimeError(
+                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
+            )
+
+
+def _check_requires_grad(inputs, input_type, strict):
+    # Used to make all the necessary checks to raise nice errors in strict mode.
+    if not strict:
+        return
+
+    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
+        raise RuntimeError("Invalid input_type to _check_requires_grad")
+    for i, inp in enumerate(inputs):
+        if inp is None:
+            # This can only be reached for grad_inputs.
+            raise RuntimeError(
+                f"The output of the user-provided function is independent of input {i}."
+                " This is not allowed in strict mode."
+            )
+        if not inp.requires_grad:
+            if input_type == "hessian":
+                raise RuntimeError(
+                    f"The hessian of the user-provided function with respect to input {i}"
+                    " is independent of the input. This is not allowed in strict mode."
+                    " You should ensure that your function is thrice differentiable and that"
+                    " the hessian depends on the inputs."
+                )
+            elif input_type == "jacobian":
+                raise RuntimeError(
+                    "While computing the hessian, found that the jacobian of the user-provided"
+                    f" function with respect to input {i} is independent of the input. This is not"
+                    " allowed in strict mode. You should ensure that your function is twice"
+                    " differentiable and that the jacobian depends on the inputs (this would be"
+                    " violated by a linear function for example)."
+                )
+            elif input_type == "grad_inputs":
+                raise RuntimeError(
+                    f"The gradient with respect to input {i} is independent of the inputs of the"
+                    " user-provided function. This is not allowed in strict mode."
+                )
+            else:
+                raise RuntimeError(
+                    f"Output {i} of the user-provided function does not require gradients."
+                    " The outputs must be computed in a differentiable manner from the input"
+                    " when running in strict mode."
+                )
+
+
+def _autograd_grad(
+    outputs,
+    inputs,
+    grad_outputs=None,
+    create_graph=False,
+    retain_graph=None,
+    is_grads_batched=False,
+):
+    # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
+    # This has the extra constraint that inputs has to be a tuple
+    assert isinstance(outputs, tuple)
+    if grad_outputs is None:
+        grad_outputs = (None,) * len(outputs)
+    assert isinstance(grad_outputs, tuple)
+    assert len(outputs) == len(grad_outputs)
+
+    new_outputs: Tuple[torch.Tensor, ...] = tuple()
+    new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
+    for out, grad_out in zip(outputs, grad_outputs):
+        if out is not None and out.requires_grad:
+            new_outputs += (out,)
+            new_grad_outputs += (grad_out,)
+
+    if len(new_outputs) == 0:
+        # No differentiable output, we don't need to call the autograd engine
+        return (None,) * len(inputs)
+    else:
+        return torch.autograd.grad(
+            new_outputs,
+            inputs,
+            new_grad_outputs,
+            allow_unused=True,
+            create_graph=create_graph,
+            retain_graph=retain_graph,
+            is_grads_batched=is_grads_batched,
+        )
+
+
+def _fill_in_zeros(grads, refs, strict, create_graph, stage):
+    # Used to detect None in the grads and depending on the flags, either replace them
+    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
+    # strict and create graph allow us to detect when it is appropriate to raise an error
+    # stage gives us information of which backward call we consider to give good error message
+    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
+        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
+
+    res: Tuple[torch.Tensor, ...] = tuple()
+    for i, grads_i in enumerate(grads):
+        if grads_i is None:
+            if strict:
+                if stage == "back":
+                    raise RuntimeError(
+                        "The output of the user-provided function is independent of "
+                        f"input {i}. This is not allowed in strict mode."
+                    )
+                elif stage == "back_trick":
+                    raise RuntimeError(
+                        f"The gradient with respect to the input is independent of entry {i}"
+                        " in the grad_outputs when using the double backward trick to compute"
+                        " forward mode gradients. This is not allowed in strict mode."
+                    )
+                elif stage == "double_back":
+                    raise RuntimeError(
+                        "The jacobian of the user-provided function is independent of "
+                        f"input {i}. This is not allowed in strict mode."
+                    )
+                else:
+                    raise RuntimeError(
+                        "The hessian of the user-provided function is independent of "
+                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
+                        "mode as it prevents from using the double backward trick to "
+                        "replace forward mode AD."
+                    )
+
+            grads_i = torch.zeros_like(refs[i])
+        else:
+            if strict and create_graph and not grads_i.requires_grad:
+                if "double" not in stage:
+                    raise RuntimeError(
+                        "The jacobian of the user-provided function is independent of "
+                        f"input {i}. This is not allowed in strict mode when create_graph=True."
+                    )
+                else:
+                    raise RuntimeError(
+                        "The hessian of the user-provided function is independent of "
+                        f"input {i}. This is not allowed in strict mode when create_graph=True."
+                    )
+
+        res += (grads_i,)
+
+    return res
+
+
+# Public API
+
+
+def vjp(func, inputs, v=None, create_graph=False, strict=False):
+    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a tuple of Tensors or a Tensor.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        v (tuple of Tensors or Tensor): The vector for which the vector
+            Jacobian product is computed.  Must be the same size as the output
+            of ``func``. This argument is optional when the output of ``func``
+            contains a single element and (if it is not provided) will be set
+            as a Tensor containing a single ``1``.
+        create_graph (bool, optional): If ``True``, both the output and result
+            will be computed in a differentiable way. Note that when ``strict``
+            is ``False``, the result may not require gradients or may be
+            disconnected from the inputs.  Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we
+            detect that there exists an input such that all the outputs are
+            independent of it. If ``False``, we return a Tensor of zeros as the
+            vjp for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+
+    Returns:
+        output (tuple): tuple with:
+            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+            vjp (tuple of Tensors or Tensor): result of the dot product with
+            the same shape as the inputs.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def exp_reducer(x):
+        ...     return x.exp().sum(dim=1)
+        >>> inputs = torch.rand(4, 4)
+        >>> v = torch.ones(4)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> vjp(exp_reducer, inputs, v)
+        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
+         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
+                [2.1288, 1.0652, 1.5483, 2.5035],
+                [2.2046, 1.1292, 1.1432, 1.3059],
+                [1.3225, 1.6652, 1.7753, 2.0152]]))
+
+        >>> vjp(exp_reducer, inputs, v, create_graph=True)
+        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
+         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
+                [2.1288, 1.0652, 1.5483, 2.5035],
+                [2.2046, 1.1292, 1.1432, 1.3059],
+                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))
+
+        >>> def adder(x, y):
+        ...     return 2 * x + 3 * y
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> v = torch.ones(2)
+        >>> vjp(adder, inputs, v)
+        (tensor([2.4225, 2.3340]),
+         (tensor([2., 2.]), tensor([3., 3.])))
+    """
+    with torch.enable_grad():
+        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
+        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+        outputs = func(*inputs)
+        is_outputs_tuple, outputs = _as_tuple(
+            outputs, "outputs of the user-provided function", "vjp"
+        )
+        _check_requires_grad(outputs, "outputs", strict=strict)
+
+        if v is not None:
+            _, v = _as_tuple(v, "v", "vjp")
+            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+            _validate_v(v, outputs, is_outputs_tuple)
+        else:
+            if len(outputs) != 1 or outputs[0].nelement() != 1:
+                raise RuntimeError(
+                    "The vector v can only be None if the "
+                    "user-provided function returns "
+                    "a single Tensor with a single element."
+                )
+
+    enable_grad = True if create_graph else torch.is_grad_enabled()
+    with torch.set_grad_enabled(enable_grad):
+        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
+        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
+
+    # Cleanup objects and return them to the user
+    outputs = _grad_postprocess(outputs, create_graph)
+    vjp = _grad_postprocess(vjp, create_graph)
+
+    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+        vjp, is_inputs_tuple
+    )
+
+
+def jvp(func, inputs, v=None, create_graph=False, strict=False):
+    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a tuple of Tensors or a Tensor.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        v (tuple of Tensors or Tensor): The vector for which the Jacobian
+            vector product is computed. Must be the same size as the input of
+            ``func``. This argument is optional when the input to ``func``
+            contains a single element and (if it is not provided) will be set
+            as a Tensor containing a single ``1``.
+        create_graph (bool, optional): If ``True``, both the output and result
+            will be computed in a differentiable way. Note that when ``strict``
+            is ``False``, the result may not require gradients or may be
+            disconnected from the inputs.  Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we
+            detect that there exists an input such that all the outputs are
+            independent of it. If ``False``, we return a Tensor of zeros as the
+            jvp for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+
+    Returns:
+        output (tuple): tuple with:
+            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+            jvp (tuple of Tensors or Tensor): result of the dot product with
+            the same shape as the output.
+
+    Note:
+        ``autograd.functional.jvp`` computes the jvp by using the backward of
+        the backward (sometimes called the double backwards trick). This is not
+        the most performant way of computing the jvp. Please consider using
+        :func:`torch.func.jvp` or the
+        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def exp_reducer(x):
+        ...     return x.exp().sum(dim=1)
+        >>> inputs = torch.rand(4, 4)
+        >>> v = torch.ones(4, 4)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> jvp(exp_reducer, inputs, v)
+        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
+         tensor([6.3090, 4.6742, 7.9114, 8.2106]))
+
+        >>> jvp(exp_reducer, inputs, v, create_graph=True)
+        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
+         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))
+
+        >>> def adder(x, y):
+        ...     return 2 * x + 3 * y
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> v = (torch.ones(2), torch.ones(2))
+        >>> jvp(adder, inputs, v)
+        (tensor([2.2399, 2.5005]),
+         tensor([5., 5.]))
+
+    """
+    with torch.enable_grad():
+        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
+        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+        if v is not None:
+            _, v = _as_tuple(v, "v", "jvp")
+            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+            _validate_v(v, inputs, is_inputs_tuple)
+        else:
+            if len(inputs) != 1 or inputs[0].nelement() != 1:
+                raise RuntimeError(
+                    "The vector v can only be None if the input to "
+                    "the user-provided function is a single Tensor "
+                    "with a single element."
+                )
+
+        outputs = func(*inputs)
+        is_outputs_tuple, outputs = _as_tuple(
+            outputs, "outputs of the user-provided function", "jvp"
+        )
+        _check_requires_grad(outputs, "outputs", strict=strict)
+        # The backward is linear so the value of grad_outputs is not important as
+        # it won't appear in the double backward graph. We only need to ensure that
+        # it does not contain inf or nan.
+        grad_outputs = tuple(
+            torch.zeros_like(out, requires_grad=True) for out in outputs
+        )
+
+        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
+        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
+
+    if create_graph:
+        with torch.enable_grad():
+            grad_res = _autograd_grad(
+                grad_inputs, grad_outputs, v, create_graph=create_graph
+            )
+            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
+    else:
+        grad_res = _autograd_grad(
+            grad_inputs, grad_outputs, v, create_graph=create_graph
+        )
+        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
+
+    # Cleanup objects and return them to the user
+    outputs = _grad_postprocess(outputs, create_graph)
+    jvp = _grad_postprocess(jvp, create_graph)
+
+    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+        jvp, is_outputs_tuple
+    )
+
+
+def _construct_standard_basis_for(
+    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
+) -> Tuple[torch.Tensor, ...]:
+    # This function:
+    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
+    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
+    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
+    #   device as the tensor
+    #
+    # For example, with tensor_numels = [1, 2, 1], this function returns:
+    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
+    #           [0],             [1, 0],              [0],
+    #           [0],             [0, 1],              [0],
+    #           [0]])  ,         [0, 0]])  ,          [1]])  )
+    #
+    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
+    # Precondition: tensors always has at least one element.
+    #
+    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
+    # for context behind this function. All the pre-conditions are guarded for
+    # in torch.autograd.functional.jacobian.
+    assert len(tensors) == len(tensor_numels)
+    assert len(tensors) > 0
+    total_numel = sum(tensor_numels)
+    chunks = tuple(
+        tensor.new_zeros(total_numel, tensor_numel)
+        for tensor, tensor_numel in zip(tensors, tensor_numels)
+    )
+    diag_start_idx = 0
+    for chunk, numel in zip(chunks, tensor_numels):
+        chunk.diagonal(diag_start_idx).fill_(1)
+        diag_start_idx -= numel
+    return chunks
+
+
+def _jacfwd(func, inputs, strict=False, vectorize=False):
+    if strict:
+        raise RuntimeError(
+            "torch.autograd.functional.jacobian: `strict=True` "
+            'and `strategy="forward-mode"` are not supported together (yet). '
+            "Please either set `strict=False` or "
+            '`strategy="reverse-mode"`.'
+        )
+    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
+    output_info = []
+
+    if vectorize:
+        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
+        input_numels = tuple(input.numel() for input in inputs)
+
+        # Step 1: Prepare tangents
+        tangents = _construct_standard_basis_for(inputs, input_numels)
+
+        # Step 2: Compute vmap over computation with dual tensors
+        def jvp(tangents):
+            with fwAD.dual_level():
+                dual_inputs = tuple(
+                    fwAD.make_dual(input, tangent.view_as(input))
+                    for input, tangent in zip(inputs, tangents)
+                )
+                _is_outputs_tuple, dual_outputs = _as_tuple(
+                    func(*dual_inputs), "outputs"
+                )
+                output_info.append(_is_outputs_tuple)
+                jv = []
+                primal_outs = []
+                for dual_out in dual_outputs:
+                    primal, tangent = fwAD.unpack_dual(dual_out)
+                    primal_outs.append(primal)
+                    if tangent is not None:
+                        jv.append(tangent)
+                    else:
+                        jv.append(torch.zeros_like(primal))
+                output_info.append(primal_outs)
+                return tuple(jv)
+
+        outputs_before_split = _vmap(jvp)(tangents)
+        is_outputs_tuple, outputs = output_info
+        # Step 3: for each of the output tangents, split along dim 0
+        jacobian_input_output = []
+        for jac_output_i, output_i in zip(outputs_before_split, outputs):
+            jacobian_output_i_output = []
+            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
+                # We need to transpose the Jacobian because in forward AD, the
+                # batch dimension represents that of the inputs
+                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
+                    (*output_i.shape, *input_j.shape)
+                )  # noqa: C409
+
+                jacobian_output_i_output.append(jacobian_input_i_output_j)
+            jacobian_input_output.append(jacobian_output_i_output)
+
+        # Omit [Step 4] because everything is already transposed w/ forward AD
+        return _tuple_postprocess(
+            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
+        )
+    else:
+        raise NotImplementedError(
+            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
+            "only implemented for `vectorize=True`."
+        )
+
+
+def jacobian(
+    func,
+    inputs,
+    create_graph=False,
+    strict=False,
+    vectorize=False,
+    strategy="reverse-mode",
+):
+    r"""Compute the Jacobian of a given function.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a tuple of Tensors or a Tensor.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        create_graph (bool, optional): If ``True``, the Jacobian will be
+            computed in a differentiable manner. Note that when ``strict`` is
+            ``False``, the result may not require gradients or may be disconnected
+            from the inputs.  Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we
+            detect that there exists an input such that all the outputs are
+            independent of it. If ``False``, we return a Tensor of zeros as the
+            jacobian for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+        vectorize (bool, optional): This feature is experimental.
+            Please consider using :func:`torch.func.jacrev` or
+            :func:`torch.func.jacfwd` instead if you are looking for something
+            less experimental and more performant.
+            When computing the jacobian, usually we invoke
+            ``autograd.grad`` once per row of the jacobian. If this flag is
+            ``True``, we perform only a single ``autograd.grad`` call with
+            ``batched_grad=True`` which uses the vmap prototype feature.
+            Though this should lead to performance improvements in many cases,
+            because this feature is still experimental, there may be performance
+            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
+            more information.
+        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
+            determine whether the Jacobian will be computed with forward or reverse
+            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
+            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
+            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
+            prefer to use ``"reverse-mode"``.
+
+    Returns:
+        Jacobian (Tensor or nested tuple of Tensors): if there is a single
+        input and output, this will be a single Tensor containing the
+        Jacobian for the linearized inputs and output. If one of the two is
+        a tuple, then the Jacobian will be a tuple of Tensors. If both of
+        them are tuples, then the Jacobian will be a tuple of tuple of
+        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
+        ``i``\th output and ``j``\th input and will have as size the
+        concatenation of the sizes of the corresponding output and the
+        corresponding input and will have same dtype and device as the
+        corresponding input. If strategy is ``forward-mode``, the dtype will be
+        that of the output; otherwise, the input.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def exp_reducer(x):
+        ...     return x.exp().sum(dim=1)
+        >>> inputs = torch.rand(2, 2)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> jacobian(exp_reducer, inputs)
+        tensor([[[1.4917, 2.4352],
+                 [0.0000, 0.0000]],
+                [[0.0000, 0.0000],
+                 [2.4369, 2.3799]]])
+
+        >>> jacobian(exp_reducer, inputs, create_graph=True)
+        tensor([[[1.4917, 2.4352],
+                 [0.0000, 0.0000]],
+                [[0.0000, 0.0000],
+                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
+
+        >>> def exp_adder(x, y):
+        ...     return 2 * x.exp() + 3 * y
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> jacobian(exp_adder, inputs)
+        (tensor([[2.8052, 0.0000],
+                [0.0000, 3.3963]]),
+         tensor([[3., 0.],
+                 [0., 3.]]))
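+
+        A vectorized forward-mode call follows the same pattern (illustrative;
+        the numeric output is omitted since it depends on ``inputs``):
+
+        >>> # xdoctest: +SKIP
+        >>> jacobian(exp_adder, inputs, vectorize=True, strategy="forward-mode")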
+    """
+    assert strategy in ("forward-mode", "reverse-mode"), (
+        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
+        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
+        'Otherwise, prefer to use "reverse-mode".'
+    )
+    if strategy == "forward-mode":
+        if create_graph:
+            raise NotImplementedError(
+                "torch.autograd.functional.jacobian: `create_graph=True` "
+                'and `strategy="forward-mode"` are not supported together (yet). '
+                "Please either set `create_graph=False` or "
+                '`strategy="reverse-mode"`.'
+            )
+        return _jacfwd(func, inputs, strict, vectorize)
+
+    with torch.enable_grad():
+        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
+        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+        outputs = func(*inputs)
+        is_outputs_tuple, outputs = _as_tuple(
+            outputs, "outputs of the user-provided function", "jacobian"
+        )
+        _check_requires_grad(outputs, "outputs", strict=strict)
+
+        if vectorize:
+            if strict:
+                raise RuntimeError(
+                    "torch.autograd.functional.jacobian: `strict=True` "
+                    "and `vectorize=True` are not supported together. "
+                    "Please either set `strict=False` or "
+                    "`vectorize=False`."
+                )
+            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
+            #
+            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
+            # It turns out we can compute the jacobian of this function with a single
+            # call to autograd.grad by using vmap over the correct grad_outputs.
+            #
+            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
+            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
+            #
+            # To get the first row of the jacobian, we call
+            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
+            # To get the 2nd row of the jacobian, we call
+            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
+            # and so on.
+            #
+            # Using vmap, we can vectorize all 4 of these computations into one by
+            # passing the standard basis for R^4 as the grad_output.
+            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
+            #
+            # Now, how do we compute the jacobian *without stacking the output*?
+            # We can just split the standard basis across the outputs. So to
+            # compute the jacobian of f(x), we'd use
+            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
+            # The grad_outputs looks like the following:
+            # ( torch.tensor([[1, 0, 0],
+            #                 [0, 1, 0],
+            #                 [0, 0, 1],
+            #                 [0, 0, 0]]),
+            #   torch.tensor([[0],
+            #                 [0],
+            #                 [0],
+            #                 [1]]) )
+            #
+            # But we're not done yet!
+            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
+            # returns a Tensor of shape [4, 3]. We have to remember to split the
+            # jacobian of shape [4, 3] into two:
+            # - one of shape [3, 3] for the first output
+            # - one of shape [   3] for the second output
+
+            # Step 1: Construct grad_outputs by splitting the standard basis
+            output_numels = tuple(output.numel() for output in outputs)
+            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
+            flat_outputs = tuple(output.reshape(-1) for output in outputs)
+
+            # Step 2: Call vmap + autograd.grad
+            def vjp(grad_output):
+                vj = list(
+                    _autograd_grad(
+                        flat_outputs,
+                        inputs,
+                        grad_output,
+                        create_graph=create_graph,
+                        is_grads_batched=True,
+                    )
+                )
+                for el_idx, vj_el in enumerate(vj):
+                    if vj_el is not None:
+                        continue
+                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
+                        (sum(output_numels),) + inputs[el_idx].shape
+                    )
+                return tuple(vj)
+
+            jacobians_of_flat_output = vjp(grad_outputs)
+
+            # Step 3: The returned jacobian is one big tensor per input. In this step,
+            # we split each Tensor by output.
+            jacobian_input_output = []
+            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
+                jacobian_input_i_output = []
+                for jac, output_j in zip(
+                    jac_input_i.split(output_numels, dim=0), outputs
+                ):
+                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
+                    jacobian_input_i_output.append(jacobian_input_i_output_j)
+                jacobian_input_output.append(jacobian_input_i_output)
+
+            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
+            # The outer List corresponds to the number of inputs,
+            # the inner List corresponds to the number of outputs.
+            # We need to exchange the order of these and convert to tuples
+            # before returning.
+            jacobian_output_input = tuple(zip(*jacobian_input_output))
+
+            jacobian_output_input = _grad_postprocess(
+                jacobian_output_input, create_graph
+            )
+            return _tuple_postprocess(
+                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
+            )
+
+        jacobian: Tuple[torch.Tensor, ...] = tuple()
+
+        for i, out in enumerate(outputs):
+            # mypy complains that expression and variable have different types due to the empty list
+            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
+            for j in range(out.nelement()):
+                vj = _autograd_grad(
+                    (out.reshape(-1)[j],),
+                    inputs,
+                    retain_graph=True,
+                    create_graph=create_graph,
+                )
+
+                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
+                    zip(jac_i, vj, inputs)
+                ):
+                    if vj_el is not None:
+                        if strict and create_graph and not vj_el.requires_grad:
+                            msg = (
+                                "The jacobian of the user-provided function is "
+                                f"independent of input {i}. This is not allowed in "
+                                "strict mode when create_graph=True."
+                            )
+                            raise RuntimeError(msg)
+                        jac_i_el.append(vj_el)
+                    else:
+                        if strict:
+                            msg = (
+                                f"Output {i} of the user-provided function is "
+                                f"independent of input {el_idx}. This is not allowed in "
+                                "strict mode."
+                            )
+                            raise RuntimeError(msg)
+                        jac_i_el.append(torch.zeros_like(inp_el))
+
+            jacobian += (
+                tuple(
+                    torch.stack(jac_i_el, dim=0).view(
+                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
+                    )
+                    for (el_idx, jac_i_el) in enumerate(jac_i)
+                ),
+            )
+
+        jacobian = _grad_postprocess(jacobian, create_graph)
+
+        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
+
+
+def hessian(
+    func,
+    inputs,
+    create_graph=False,
+    strict=False,
+    vectorize=False,
+    outer_jacobian_strategy="reverse-mode",
+):
+    r"""Compute the Hessian of a given scalar function.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a Tensor with a single element.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        create_graph (bool, optional): If ``True``, the Hessian will be computed in
+            a differentiable manner. Note that when ``strict`` is ``False``, the result may not
+            require gradients or may be disconnected from the inputs.
+            Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
+            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
+            hessian for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+        vectorize (bool, optional): This feature is experimental.
+            Please consider using :func:`torch.func.hessian`
+            instead if you are looking for something less experimental and more performant.
+            When computing the hessian, usually we invoke
+            ``autograd.grad`` once per row of the hessian. If this flag is
+            ``True``, we use the vmap prototype feature as the backend to
+            vectorize calls to ``autograd.grad`` so we only invoke it once
+            instead of once per row. This should lead to performance
+            improvements in many use cases, however, due to this feature
+            being incomplete, there may be performance cliffs. Please
+            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
+            to show any performance warnings and file us issues if
+            warnings exist for your use case. Defaults to ``False``.
+        outer_jacobian_strategy (str, optional): The Hessian is computed by
+            computing the Jacobian of a Jacobian. The inner Jacobian is always
+            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
+            or ``"reverse-mode"`` determines whether the outer Jacobian will be
+            computed with forward or reverse mode AD. Currently, computing the outer
+            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
+            to ``"reverse-mode"``.
+
+    Returns:
+        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
+        this will be a single Tensor containing the Hessian for the input.
+        If it is a tuple, then the Hessian will be a tuple of tuples where
+        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
+        and ``j``\th input with size the sum of the size of the ``i``\th input plus
+        the size of the ``j``\th input. ``Hessian[i][j]`` will have the same
+        dtype and device as the corresponding ``i``\th input.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def pow_reducer(x):
+        ...     return x.pow(3).sum()
+        >>> inputs = torch.rand(2, 2)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> hessian(pow_reducer, inputs)
+        tensor([[[[5.2265, 0.0000],
+                  [0.0000, 0.0000]],
+                 [[0.0000, 4.8221],
+                  [0.0000, 0.0000]]],
+                [[[0.0000, 0.0000],
+                  [1.9456, 0.0000]],
+                 [[0.0000, 0.0000],
+                  [0.0000, 3.2550]]]])
+
+        >>> hessian(pow_reducer, inputs, create_graph=True)
+        tensor([[[[5.2265, 0.0000],
+                  [0.0000, 0.0000]],
+                 [[0.0000, 4.8221],
+                  [0.0000, 0.0000]]],
+                [[[0.0000, 0.0000],
+                  [1.9456, 0.0000]],
+                 [[0.0000, 0.0000],
+                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)
+
+
+        >>> def pow_adder_reducer(x, y):
+        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> hessian(pow_adder_reducer, inputs)
+        ((tensor([[4., 0.],
+                  [0., 4.]]),
+          tensor([[0., 0.],
+                  [0., 0.]])),
+         (tensor([[0., 0.],
+                  [0., 0.]]),
+          tensor([[6., 0.],
+                  [0., 6.]])))
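+
+        The outer Jacobian can also be computed with forward-mode AD
+        (illustrative; requires ``vectorize=True``, and the numeric output is
+        omitted here):
+
+        >>> # xdoctest: +SKIP
+        >>> hessian(pow_adder_reducer, inputs, vectorize=True,
+        ...         outer_jacobian_strategy="forward-mode")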
+    """
+    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
+    assert outer_jacobian_strategy in (
+        "forward-mode",
+        "reverse-mode",
+    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'
+
+    def ensure_single_output_function(*inp):
+        out = func(*inp)
+        is_out_tuple, t_out = _as_tuple(
+            out, "outputs of the user-provided function", "hessian"
+        )
+        _check_requires_grad(t_out, "outputs", strict=strict)
+
+        if is_out_tuple or not isinstance(out, torch.Tensor):
+            raise RuntimeError(
+                "The function given to hessian should return a single Tensor"
+            )
+
+        if out.nelement() != 1:
+            raise RuntimeError(
+                "The Tensor returned by the function given to hessian should contain a single element"
+            )
+
+        return out.squeeze()
+
+    def jac_func(*inp):
+        if outer_jacobian_strategy == "forward-mode":
+            # _grad_preprocess requires create_graph=True and input to require_grad
+            # or else the input will be detached
+            inp = tuple(t.requires_grad_(True) for t in inp)
+        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
+        _check_requires_grad(jac, "jacobian", strict=strict)
+        return jac
+
+    res = jacobian(
+        jac_func,
+        inputs,
+        create_graph=create_graph,
+        strict=strict,
+        vectorize=vectorize,
+        strategy=outer_jacobian_strategy,
+    )
+    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
+
+
+def vhp(func, inputs, v=None, create_graph=False, strict=False):
+    r"""Compute the dot product between a vector ``v`` and the Hessian of a given scalar function at a specified point.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a Tensor with a single element.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
+            product is computed. Must be the same size as the input of
+            ``func``. This argument is optional when ``func``'s input contains
+            a single element and (if it is not provided) will be set as a
+            Tensor containing a single ``1``.
+        create_graph (bool, optional): If ``True``, both the output and result
+            will be computed in a differentiable way. Note that when ``strict``
+            is ``False``, the result may not require gradients or may be
+            disconnected from the inputs.
+            Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we
+            detect that there exists an input such that all the outputs are
+            independent of it. If ``False``, we return a Tensor of zeros as the
+            vhp for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+
+    Returns:
+        output (tuple): tuple with:
+            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+            vhp (tuple of Tensors or Tensor): result of the dot product with the
+            same shape as the inputs.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def pow_reducer(x):
+        ...     return x.pow(3).sum()
+        >>> inputs = torch.rand(2, 2)
+        >>> v = torch.ones(2, 2)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> vhp(pow_reducer, inputs, v)
+        (tensor(0.5591),
+         tensor([[1.0689, 1.2431],
+                 [3.0989, 4.4456]]))
+        >>> vhp(pow_reducer, inputs, v, create_graph=True)
+        (tensor(0.5591, grad_fn=<SumBackward0>),
+         tensor([[1.0689, 1.2431],
+                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
+        >>> def pow_adder_reducer(x, y):
+        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> v = (torch.zeros(2), torch.ones(2))
+        >>> vhp(pow_adder_reducer, inputs, v)
+        (tensor(4.8053),
+         (tensor([0., 0.]),
+          tensor([6., 6.])))
+    """
+    with torch.enable_grad():
+        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
+        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+        if v is not None:
+            _, v = _as_tuple(v, "v", "vhp")
+            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+            _validate_v(v, inputs, is_inputs_tuple)
+        else:
+            if len(inputs) != 1 or inputs[0].nelement() != 1:
+                raise RuntimeError(
+                    "The vector v can only be None if the input to the user-provided function "
+                    "is a single Tensor with a single element."
+                )
+        outputs = func(*inputs)
+        is_outputs_tuple, outputs = _as_tuple(
+            outputs, "outputs of the user-provided function", "vhp"
+        )
+        _check_requires_grad(outputs, "outputs", strict=strict)
+
+        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
+            raise RuntimeError(
+                "The function given to vhp should return a single Tensor"
+            )
+
+        if outputs[0].nelement() != 1:
+            raise RuntimeError(
+                "The Tensor returned by the function given to vhp should contain a single element"
+            )
+
+        jac = _autograd_grad(outputs, inputs, create_graph=True)
+        _check_requires_grad(jac, "jacobian", strict=strict)
+
+    enable_grad = True if create_graph else torch.is_grad_enabled()
+    with torch.set_grad_enabled(enable_grad):
+        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
+        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")
+
+    outputs = _grad_postprocess(outputs, create_graph)
+    vhp = _grad_postprocess(vhp, create_graph)
+
+    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+        vhp, is_inputs_tuple
+    )
+
+
+def hvp(func, inputs, v=None, create_graph=False, strict=False):
+    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a Tensor with a single element.
+        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
+        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
+            product is computed. Must be the same size as the input of
+            ``func``. This argument is optional when ``func``'s input contains
+            a single element and (if it is not provided) will be set as a
+            Tensor containing a single ``1``.
+        create_graph (bool, optional): If ``True``, both the output and result will be
+            computed in a differentiable way. Note that when ``strict`` is
+            ``False``, the result can not require gradients or be disconnected
+            from the inputs.  Defaults to ``False``.
+        strict (bool, optional): If ``True``, an error will be raised when we
+            detect that there exists an input such that all the outputs are
+            independent of it. If ``False``, we return a Tensor of zeros as the
+            hvp for said inputs, which is the expected mathematical value.
+            Defaults to ``False``.
+    Returns:
+        output (tuple): tuple with:
+            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
+
+            hvp (tuple of Tensors or Tensor): result of the dot product with
+            the same shape as the inputs.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def pow_reducer(x):
+        ...     return x.pow(3).sum()
+        >>> inputs = torch.rand(2, 2)
+        >>> v = torch.ones(2, 2)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> hvp(pow_reducer, inputs, v)
+        (tensor(0.1448),
+         tensor([[2.0239, 1.6456],
+                 [2.4988, 1.4310]]))
+
+        >>> hvp(pow_reducer, inputs, v, create_graph=True)
+        (tensor(0.1448, grad_fn=<SumBackward0>),
+         tensor([[2.0239, 1.6456],
+                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))
+
+
+        >>> def pow_adder_reducer(x, y):
+        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
+        >>> inputs = (torch.rand(2), torch.rand(2))
+        >>> v = (torch.zeros(2), torch.ones(2))
+        >>> hvp(pow_adder_reducer, inputs, v)
+        (tensor(2.3030),
+         (tensor([0., 0.]),
+          tensor([6., 6.])))
+
+    Note:
+
+        This function is significantly slower than `vhp` due to backward mode AD constraints.
+        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
+        know that your function satisfies this condition, you should use vhp instead, which is
+        much faster with the current implementation.
+
+    """
+    with torch.enable_grad():
+        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
+        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
+
+        if v is not None:
+            _, v = _as_tuple(v, "v", "hvp")
+            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
+            _validate_v(v, inputs, is_inputs_tuple)
+        else:
+            if len(inputs) != 1 or inputs[0].nelement() != 1:
+                raise RuntimeError(
+                    "The vector v can only be None if the input to the user-provided function "
+                    "is a single Tensor with a single element."
+                )
+        outputs = func(*inputs)
+        is_outputs_tuple, outputs = _as_tuple(
+            outputs, "outputs of the user-provided function", "hvp"
+        )
+        _check_requires_grad(outputs, "outputs", strict=strict)
+
+        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
+            raise RuntimeError(
+                "The function given to hvp should return a single Tensor"
+            )
+
+        if outputs[0].nelement() != 1:
+            raise RuntimeError(
+                "The Tensor returned by the function given to hvp should contain a single element"
+            )
+
+        jac = _autograd_grad(outputs, inputs, create_graph=True)
+        _check_requires_grad(jac, "jacobian", strict=strict)
+
+        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)
+
+        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
+        _check_requires_grad(jac, "hessian", strict=strict)
+
+    enable_grad = True if create_graph else torch.is_grad_enabled()
+    with torch.set_grad_enabled(enable_grad):
+        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
+        hvp = _fill_in_zeros(
+            grad_res, inputs, strict, create_graph, "double_back_trick"
+        )
+
+    outputs = _grad_postprocess(outputs, create_graph)
+    hvp = _grad_postprocess(hvp, create_graph)
+
+    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
+        hvp, is_inputs_tuple
+    )
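+
+
+# Illustrative sketch, not part of the upstream module: as the `hvp` docstring
+# above notes, when `func` is twice continuously differentiable its Hessian is
+# symmetric, so `vhp` and `hvp` produce the same product and `vhp` is the
+# cheaper call. This helper exists only as documentation and is never invoked
+# by library code.
+def _demo_vhp_equals_hvp():
+    def pow_reducer(x):
+        return x.pow(3).sum()
+
+    inputs = torch.rand(2, 2)
+    v = torch.ones(2, 2)
+    # Both calls return (func_output, product); the products coincide because
+    # the Hessian of a C^2 scalar function is symmetric.
+    _, vhp_result = vhp(pow_reducer, inputs, v)
+    _, hvp_result = hvp(pow_reducer, inputs, v)
+    return torch.allclose(vhp_result, hvp_result)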
diff --git a/MLPY/Lib/site-packages/torch/autograd/grad_mode.py b/MLPY/Lib/site-packages/torch/autograd/grad_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6ecb5fcea193967d5fdf160f1cfce20b3c37cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/grad_mode.py
@@ -0,0 +1,396 @@
+from typing import Any
+
+import torch
+
+from torch.utils._contextlib import (
+    _DecoratorContextManager,
+    _NoParamDecoratorContextManager,
+    F,
+)
+
+__all__ = [
+    "no_grad",
+    "enable_grad",
+    "set_grad_enabled",
+    "inference_mode",
+    "set_multithreading_enabled",
+]
+
+
+class no_grad(_NoParamDecoratorContextManager):
+    r"""Context-manager that disables gradient calculation.
+
+    Disabling gradient calculation is useful for inference, when you are sure
+    that you will not call :meth:`Tensor.backward()`. It will reduce memory
+    consumption for computations that would otherwise have `requires_grad=True`.
+
+    In this mode, the result of every computation will have
+    `requires_grad=False`, even when the inputs have `requires_grad=True`.
+    There is an exception! All factory functions, or functions that create
+    a new Tensor and take a requires_grad kwarg, will NOT be affected by
+    this mode.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Also functions as a decorator.
+
+    .. note::
+        No-grad is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+        If you want to disable forward AD for a computation, you can unpack
+        your dual tensors.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> x = torch.tensor([1.], requires_grad=True)
+        >>> with torch.no_grad():
+        ...     y = x * 2
+        >>> y.requires_grad
+        False
+        >>> @torch.no_grad()
+        ... def doubler(x):
+        ...     return x * 2
+        >>> z = doubler(x)
+        >>> z.requires_grad
+        False
+        >>> @torch.no_grad
+        ... def tripler(x):
+        ...     return x * 3
+        >>> z = tripler(x)
+        >>> z.requires_grad
+        False
+        >>> # factory function exception
+        >>> with torch.no_grad():
+        ...     a = torch.nn.Parameter(torch.rand(10))
+        >>> a.requires_grad
+        True
+    """
+
+    def __init__(self) -> None:
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        self.prev = False
+
+    def __enter__(self) -> None:
+        self.prev = torch.is_grad_enabled()
+        torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch.set_grad_enabled(self.prev)
+
+
+class enable_grad(_NoParamDecoratorContextManager):
+    r"""Context-manager that enables gradient calculation.
+
+    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
+    or :class:`~set_grad_enabled`.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Also functions as a decorator.
+
+    .. note::
+        enable_grad is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> x = torch.tensor([1.], requires_grad=True)
+        >>> with torch.no_grad():
+        ...     with torch.enable_grad():
+        ...         y = x * 2
+        >>> y.requires_grad
+        True
+        >>> y.backward()
+        >>> x.grad
+        tensor([2.])
+        >>> @torch.enable_grad()
+        ... def doubler(x):
+        ...     return x * 2
+        >>> with torch.no_grad():
+        ...     z = doubler(x)
+        >>> z.requires_grad
+        True
+        >>> @torch.enable_grad
+        ... def tripler(x):
+        ...     return x * 3
+        >>> with torch.no_grad():
+        ...     z = tripler(x)
+        >>> z.requires_grad
+        True
+
+    """
+
+    def __enter__(self) -> None:
+        self.prev = torch.is_grad_enabled()
+        torch._C._set_grad_enabled(True)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_grad_enabled(self.prev)
+
+
+class set_grad_enabled(_DecoratorContextManager):
+    r"""Context-manager that sets gradient calculation on or off.
+
+    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Args:
+        mode (bool): Flag whether to enable grad (``True``), or disable
+                     (``False``). This can be used to conditionally enable
+                     gradients.
+
+    .. note::
+        set_grad_enabled is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> x = torch.tensor([1.], requires_grad=True)
+        >>> is_train = False
+        >>> with torch.set_grad_enabled(is_train):
+        ...     y = x * 2
+        >>> y.requires_grad
+        False
+        >>> _ = torch.set_grad_enabled(True)
+        >>> y = x * 2
+        >>> y.requires_grad
+        True
+        >>> _ = torch.set_grad_enabled(False)
+        >>> y = x * 2
+        >>> y.requires_grad
+        False
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.prev = torch.is_grad_enabled()
+        self.mode = mode
+        torch._C._set_grad_enabled(mode)
+
+    def __call__(self, orig_func: F) -> F:
+        torch._C._set_grad_enabled(self.prev)
+        return super().__call__(orig_func)
+
+    def __enter__(self) -> None:
+        torch._C._set_grad_enabled(self.mode)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_grad_enabled(self.prev)
+
+    def clone(self) -> "set_grad_enabled":
+        r"""
+        Create a copy of this class
+        """
+        return self.__class__(self.mode)
+
+
+class inference_mode(_DecoratorContextManager):
+    r"""Context-manager that enables or disables inference mode.
+
+    InferenceMode is a new context manager analogous to :class:`~no_grad`
+    to be used when you are certain your operations will have no interactions
+    with autograd (e.g., model training). Code run under this mode gets better
+    performance by disabling view tracking and version counter bumps. Note that
+    unlike some other mechanisms that locally enable or disable grad,
+    entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Also functions as a decorator.
+
+    .. note::
+        Inference mode is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
+
+    Args:
+        mode (bool or function): Either a boolean flag whether to enable or
+            disable inference mode or a Python function to decorate with
+            inference mode enabled
+
+    Example::
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> import torch
+        >>> x = torch.ones(1, 2, 3, requires_grad=True)
+        >>> with torch.inference_mode():
+        ...     y = x * x
+        >>> y.requires_grad
+        False
+        >>> # xdoctest: +SKIP("want string isnt quite right")
+        >>> y._version
+        Traceback (most recent call last):
+        File "", line 1, in 
+        RuntimeError: Inference tensors do not track version counter.
+        >>> @torch.inference_mode()
+        ... def func(x):
+        ...     return x * x
+        >>> out = func(x)
+        >>> out.requires_grad
+        False
+        >>> @torch.inference_mode
+        ... def doubler(x):
+        ...     return x * 2
+        >>> out = doubler(x)
+        >>> out.requires_grad
+        False
+
+    """
+
+    def __init__(self, mode: bool = True) -> None:
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        self.mode = mode
+
+    def __new__(cls, mode=True):
+        if isinstance(mode, bool):
+            return super().__new__(cls)
+        return cls()(mode)
+
+    def __enter__(self) -> None:
+        self._inference_mode_context = torch._C._InferenceMode(self.mode)
+        self._inference_mode_context.__enter__()
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+
+    def clone(self) -> "inference_mode":
+        r"""
+        Create a copy of this class
+        """
+        return self.__class__(self.mode)
+
+
+def _enter_inference_mode(mode):
+    mode_context = torch._C._InferenceMode(mode)
+    mode_context.__enter__()
+    return mode_context
+
+
+def _exit_inference_mode(mode):
+    mode.__exit__(None, None, None)
+
+
+class set_multithreading_enabled(_DecoratorContextManager):
+    r"""Context-manager that sets multithreaded backwards on or off.
+
+    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Args:
+        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
+                     (``False``).
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.prev = torch._C._is_multithreading_enabled()
+        torch._C._set_multithreading_enabled(mode)
+        self.mode = mode
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_multithreading_enabled(self.prev)
+
+    def clone(self) -> "set_multithreading_enabled":
+        r"""
+        Create a copy of this class
+        """
+        return self.__class__(self.mode)
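+
+
+# Illustrative sketch, not upstream code: `set_multithreading_enabled` works as
+# a context manager (or decorator) to force a single-threaded backward pass,
+# which can make debugging easier. Defined for documentation only and never
+# called by library code.
+def _demo_single_threaded_backward():
+    x = torch.ones(3, requires_grad=True)
+    with set_multithreading_enabled(False):
+        # Autograd does not spawn worker threads for the backward pass here.
+        (x * 2).sum().backward()
+    return x.grad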
+
+
+class _force_original_view_tracking(_DecoratorContextManager):
+    r"""Context-manager that sets whether or not to always enable view-replay in autograd.
+
+    ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    When a tensor view is mutated, the autograd engine needs to decide whether or not
+    to regenerate the "updated view" by either replaying the chain of views from the updated base,
+    or with a single call to as_strided.
+
+    If set_view_replay_enabled is set to True, then autograd will always use view replay.
+    Otherwise, it will fall back to its existing logic.
+
+    Args:
+        mode (bool): Flag whether to enable view-replay (``True``), or disable
+                     (``False``).
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.prev = torch._C._is_view_replay_enabled()
+        torch._C._set_view_replay_enabled(mode)
+        self.mode = mode
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_view_replay_enabled(self.prev)
+
+    def clone(self):
+        return self.__class__(self.mode)
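+
+
+# Illustrative sketch, not upstream code: using `_force_original_view_tracking`
+# so that the autograd graph for a mutated view is rebuilt via view replay
+# rather than a single `as_strided`, as described in the docstring above.
+# Defined for documentation only and never called by library code.
+def _demo_force_view_replay():
+    base = torch.randn(4, requires_grad=True).clone()  # non-leaf, so in-place on a view is allowed
+    view = base[:2]
+    with _force_original_view_tracking(True):
+        view.mul_(2)  # the updated view's grad_fn is regenerated by replaying the slicing view
+    return view.grad_fn is not None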
+
+
+class _unsafe_preserve_version_counter(_DecoratorContextManager):
+    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.
+
+    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
+    (even the ones not touched directly by the context manager)!
+
+    Ordinarily, autograd will track mutations to tensors by incrementing its `._version` attribute.
+    This is generally important for correctness, as for example, mutating a tensor that autograd has saved
+    for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
+    and error out in this situation.
+
+    However, there are rare instances where it might be useful to hide mutations from autograd. For example:
+    if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
+    the tensor right before it is needed by autograd.
+
+    Args:
+        tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    """
+
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self.tensor = tensor
+        self.prev_version = tensor._version
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
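+
+
+# Illustrative sketch, not upstream code: the "hide a mutation from autograd"
+# pattern described in the `_unsafe_preserve_version_counter` docstring above.
+# This is only safe because the tensor holds identical values again before
+# backward runs. Defined for documentation only and never called by library code.
+def _demo_preserve_version_counter():
+    base = torch.randn(4, requires_grad=True)
+    out = base.pow(2)              # autograd saves `base` for pow's backward
+    stash = base.detach().clone()
+    with _unsafe_preserve_version_counter(base):
+        with torch.no_grad():
+            base.zero_()           # temporarily clobber the values
+            base.copy_(stash)      # restore them before they are needed
+    out.sum().backward()           # no "modified by an inplace operation" error
+    return base.grad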
diff --git a/MLPY/Lib/site-packages/torch/autograd/gradcheck.py b/MLPY/Lib/site-packages/torch/autograd/gradcheck.py
new file mode 100644
index 0000000000000000000000000000000000000000..7505c2fd5ff7d4112996d78863eadda8b184846c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/gradcheck.py
@@ -0,0 +1,2266 @@
+import collections
+import functools
+import warnings
+from itertools import product
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.testing
+from torch._vmap_internals import _vmap, vmap
+from torch.overrides import is_tensor_like
+from torch.types import _TensorOrTensors
+
+# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
+# since they have been exposed since before we added `__all__` and we already maintain BC for them.
+# We should eventually deprecate them and remove them from `__all__`.
+__all__ = [
+    "gradcheck",
+    "gradgradcheck",
+    "GradcheckError",
+    "get_numerical_jacobian",
+    "get_analytical_jacobian",
+    "get_numerical_jacobian_wrt_specific_input",
+]
+
+
+class GradcheckError(RuntimeError):
+    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""
+
+    pass
+
+
+def _is_sparse_compressed_tensor(obj: torch.Tensor):
+    return obj.layout in {
+        torch.sparse_csr,
+        torch.sparse_csc,
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    }
+
+
+def _is_sparse_any_tensor(obj: torch.Tensor):
+    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo
+
+
+def _is_float_or_complex_tensor(obj):
+    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())
+
+
+def _allocate_jacobians_with_inputs(
+    input_tensors: Tuple, numel_output
+) -> Tuple[torch.Tensor, ...]:
+    # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
+    # each tensor in `input_tensors`, returns a new zero-filled tensor with height
+    # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
+    # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
+    # the same dtype and device as those of the corresponding input.
+    out: List[torch.Tensor] = []
+    for t in input_tensors:
+        if _is_float_or_complex_tensor(t) and t.requires_grad:
+            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
+    return tuple(out)
+
+
+def _allocate_jacobians_with_outputs(
+    output_tensors: Tuple, numel_input, dtype=None, device=None
+) -> Tuple[torch.Tensor, ...]:
+    # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor
+    # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
+    # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
+    # (t.numel,).
+    out: List[torch.Tensor] = []
+    options = {"dtype": dtype, "device": device, "layout": torch.strided}
+    for t in output_tensors:
+        if _is_float_or_complex_tensor(t):
+            out.append(t.new_zeros((numel_input, t.numel()), **options))
+    return tuple(out)
+
+
+def _iter_tensors(
+    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
+) -> Iterable[torch.Tensor]:
+    if is_tensor_like(x):
+        # mypy doesn't narrow type of `x` to torch.Tensor
+        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
+            yield x  # type: ignore[misc]
+    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+        for elem in x:
+            yield from _iter_tensors(elem, only_requiring_grad)
+
+
+def _densify(x):
+    # return a copy of sparse x with all unspecified elements
+    # "replaced" with zero-valued elements
+    if isinstance(x, (list, tuple)):
+        return type(x)(map(_densify, x))
+    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
+        return x
+    elif x.layout is torch.sparse_coo:
+        device = x.device
+        indices_dtype = x._indices().dtype
+        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
+        indices = tmp.nonzero().t().to(dtype=indices_dtype)
+        values = torch.zeros(
+            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
+        )
+        x_coalesced = x.detach().coalesce()
+        if x_coalesced.numel() > 0:
+            stride = tmp.stride()
+            flat_indices = (
+                x_coalesced.indices()
+                .mul(
+                    torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(
+                        1
+                    )
+                )
+                .sum(0)
+            )
+            values[flat_indices] = x_coalesced.values()
+        return (
+            torch.sparse_coo_tensor(indices, values, x.shape)
+            ._coalesced_(True)
+            .requires_grad_(x.requires_grad)
+        )
+    elif _is_sparse_compressed_tensor(x):
+        blocksize = (
+            x.values().shape[1:3]
+            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
+            else None
+        )
+        compressed_indices = (
+            x.crow_indices()
+            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
+            else x.ccol_indices()
+        )
+        # We'll use intermediate sparse COO for simplicity
+        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
+            layout=x.layout, blocksize=blocksize
+        )
+        # Check that all elements are specified also after `to_sparse` op:
+        dense_numel = r.values().numel() // max(1, r.values().shape[0])
+        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
+        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
+        if sparse_numel != r._nnz():
+            raise AssertionError(
+                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
+            )
+        return r.requires_grad_(x.requires_grad)
+    elif _is_sparse_any_tensor(x):
+        raise NotImplementedError(x.layout)
+    return x
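+
+
+# Illustrative sketch, not upstream code: `_densify` materializes unspecified
+# sparse elements as explicit zero-valued elements, so slow-mode gradcheck can
+# perturb every entry. Defined for documentation only; never called by library code.
+def _demo_densify():
+    i = torch.tensor([[0, 1], [1, 0]])
+    v = torch.tensor([3.0, 4.0])
+    sparse = torch.sparse_coo_tensor(i, v, (2, 2))
+    # The original tensor stores 2 specified values; the densified copy stores
+    # all 4 entries, two of them explicitly zero.
+    return sparse._nnz(), _densify(sparse)._nnz()  # (2, 4)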
+
+
+def _iter_tensor(x_tensor):
+    # (Only used for slow gradcheck) Returns a generator that yields the following
+    # elements at each iteration:
+    #  1) a tensor: the same tensor is returned across all iterations. The tensor
+    #     is not the same as the original x_tensor as given as input - it is
+    #     prepared so that it can be modified in-place. Depending on whether the
+    #     input tensor is strided, sparse, or dense, the returned tensor may or may
+    #     not share storage with x_tensor.
+    #  2) a tuple of indices that can be used with advanced indexing (yielded in
+    #     dictionary order)
+    #  3) flattened index that will be used to index into the Jacobian tensor
+    #
+    # For a tensor t with size (2, 2), _iter_tensor yields:
+    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
+    #
+    # where x is the t.data of the original tensor. Perturbing the entry of x
+    # at index (1, 1) yields the 3rd column of the overall Jacobian matrix.
+    if _is_sparse_any_tensor(x_tensor):
+
+        def get_stride(size):
+            dim = len(size)
+            tmp = 1
+            stride = [0] * dim
+            for i in reversed(range(dim)):
+                stride[i] = tmp
+                tmp *= size[i]
+            return stride
+
+        x_nnz = x_tensor._nnz()
+        x_size = list(x_tensor.size())
+        if x_tensor.layout is torch.sparse_coo:
+            x_indices = x_tensor._indices().t()
+            x_values = x_tensor._values()
+        elif x_tensor.layout is torch.sparse_csr:
+            x_indices = torch._convert_indices_from_csr_to_coo(
+                x_tensor.crow_indices(), x_tensor.col_indices()
+            ).t()
+            x_values = x_tensor.values()
+        elif x_tensor.layout is torch.sparse_csc:
+            x_indices = torch._convert_indices_from_csr_to_coo(
+                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
+            ).t()
+            x_values = x_tensor.values()
+        elif x_tensor.layout is torch.sparse_bsr:
+            x_block_values = x_tensor.values()
+            x_blocksize = x_block_values.size()[1:3]
+            x_indices = (
+                torch._convert_indices_from_csr_to_coo(
+                    x_tensor.crow_indices(), x_tensor.col_indices()
+                )
+                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
+                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
+                .add_(
+                    torch.stack(
+                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
+                    ).repeat(1, x_nnz)
+                )
+                .t()
+            )
+            x_values = x_block_values.flatten(0, 2)
+            x_nnz = x_values.size(0)
+        elif x_tensor.layout is torch.sparse_bsc:
+            x_block_values = x_tensor.values()
+            x_blocksize = x_block_values.size()[1:3]
+            x_indices = (
+                torch._convert_indices_from_csr_to_coo(
+                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
+                )
+                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
+                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
+                .add_(
+                    torch.stack(
+                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
+                    ).repeat(1, x_nnz)
+                )
+                .t()
+            )
+            x_values = x_block_values.flatten(0, 2)
+            x_nnz = x_values.size(0)
+        else:
+            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
+        x_stride = get_stride(x_size)
+        # Use .data here to get around the version check
+        x_values = x_values.data
+        for i in range(x_nnz):
+            x_value = x_values[i]
+            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
+                indices = x_indices[i].tolist() + list(x_idx)
+                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
+                yield x_value, x_idx, d_idx
+    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
+        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+            # this is really inefficient, but without indexing implemented, there's
+            # not really a better way than converting back and forth
+            x_tensor_dense = x_tensor.to_dense()
+            yield x_tensor_dense, x_idx, d_idx
+    else:
+        # Use .data here to get around the version check
+        x_tensor = x_tensor.data
+        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+            yield x_tensor, x_idx, d_idx
+
+
+def _get_numerical_jacobian(
+    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
+) -> List[Tuple[torch.Tensor, ...]]:
+    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.
+
+    If not specified, targets are the input. Returns M * N Jacobians where N is the
+    number of tensors in target that require grad and M is the number of non-integral
+    outputs.
+
+    Args:
+        fn: the function to compute the jacobian for
+        inputs: inputs to `fn`
+        outputs: provide precomputed outputs to avoid one extra invocation of fn
+        target: the Tensors wrt which Jacobians are calculated (default=`inputs`)
+        eps: the magnitude of the perturbation during finite differencing
+             (default=`1e-3`)
+        is_forward_ad: if this numerical jacobian is computed to be checked wrt
+                       forward AD gradients (this is used for error checking only)
+
+    Returns:
+        A list of M N-tuples of tensors
+
+    Note that `target` may not even be part of `input` to `fn`, so please be
+    **very careful** in this function not to clone `target`.
+    """
+    jacobians: List[Tuple[torch.Tensor, ...]] = []
+    if outputs is None:
+        outputs = _as_tuple(fn(*_as_tuple(inputs)))
+    if not is_forward_ad and any(o.is_complex() for o in outputs):
+        raise ValueError(
+            "Expected output to be non-complex. get_numerical_jacobian no "
+            "longer supports functions that return complex outputs."
+        )
+    if target is None:
+        target = inputs
+    inp_indices = [
+        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
+    ]
+    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
+        jacobians += [
+            get_numerical_jacobian_wrt_specific_input(
+                fn,
+                inp_idx,
+                inputs,
+                outputs,
+                eps,
+                input=inp,
+                is_forward_ad=is_forward_ad,
+            )
+        ]
+    return jacobians
+
+
+def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
+    """Compute the numerical Jacobian for a given fn and its inputs.
+
+    This is a Deprecated API.
+
+    Args:
+        fn: the function to compute the Jacobian for (must take inputs as a tuple)
+        inputs: inputs to `fn`
+        target: the Tensors wrt which Jacobians are calculated (default=`inputs`)
+        eps: the magnitude of the perturbation during finite differencing
+             (default=`1e-3`)
+
+    Returns:
+        A list of Jacobians of `fn` (restricted to its first output) with respect to
+        each input or target, if provided.
+
+    Note that `target` may not even be part of `input` to `fn`, so please be
+    **very careful** in this function not to clone `target`.
+    """
+    warnings.warn(
+        "get_numerical_jacobian was part of PyTorch's private API and not "
+        "meant to be exposed. We are deprecating it and it will be removed "
+        "in a future version of PyTorch. If you have a specific use for "
+        "this or feature request for this to be a stable API, please file "
+        "us an issue at https://github.com/pytorch/pytorch/issues/new"
+    )
+    if (
+        grad_out != 1.0
+    ):  # grad_out param is only kept for backward compatibility reasons
+        raise ValueError(
+            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
+            "supports values of grad_out != 1.0."
+        )
+
+    def fn_pack_inps(*inps):
+        return fn(inps)
+
+    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)
+
+    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)
+
+
+def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
+    # Computes numerical directional derivative as finite difference
+    # of function `fn` at input `entry`, perturbed by vector `v`.
+    if _is_sparse_compressed_tensor(entry):
+        # sparse compressed tensors don't implement sub/add/copy_
+        # yet. However, in non-masked semantics context entry and v
+        # have the same sparse indices ...
+        assert entry.layout == v.layout, (entry.layout, v.layout)
+        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
+        # ... the finite differencing can be performed on values only:
+        entry = entry.values()
+        v = v.values()
+        # we'll detach to avoid backward computations that sparse
+        # tensors have limited support for.
+        entry = entry.detach()
+
+    orig = entry.clone()
+    entry.copy_(orig - v)
+    outa = fn()
+    entry.copy_(orig + v)
+    outb = fn()
+    entry.copy_(orig)
+
+    def compute(a, b):
+        nbhd_checks_fn(a, b)
+        ret = (b - a) / (2 * norm_v)  # use central difference approx
+        return ret.detach().reshape(-1)
+
+    return tuple(compute(a, b) for (a, b) in zip(outa, outb))
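+
+
+# Illustrative sketch, not upstream code: the central difference used above,
+# (f(x + v) - f(x - v)) / (2 * eps), shown on a scalar quadratic where it is
+# exact: for f(x) = x**2 at x = 3 it recovers f'(3) = 6 (up to rounding).
+def _demo_central_difference(eps=1e-3):
+    x = torch.tensor(3.0)
+    v = torch.tensor(eps)
+    return ((x + v).pow(2) - (x - v).pow(2)) / (2 * eps)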
+
+
+def _compute_numerical_jvps_wrt_specific_input(
+    jvp_fn, delta, input_is_complex, is_forward_ad=False
+) -> List[torch.Tensor]:
+    # Computing the jacobian only works for real delta
+    # For details on the algorithm used here, see Section 3.5.3 of
+    # https://arxiv.org/pdf/1701.00392.pdf
+    # s = fn(z) where z = x for real valued input
+    # and z = x + yj for complex valued input
+    jvps: List[torch.Tensor] = []
+    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)
+
+    if input_is_complex:  # C -> R
+        ds_dy_tup = (
+            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
+        )
+        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
+            assert not ds_dx.is_complex()
+            # conjugate wirtinger derivative
+            conj_w_d = ds_dx + ds_dy * 1j
+            jvps.append(conj_w_d)
+    else:
+        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
+            assert is_forward_ad or not ds_dx.is_complex()
+            jvps.append(ds_dx)
+    return jvps
+
+
+def _combine_jacobian_cols(
+    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
+) -> Tuple[torch.Tensor, ...]:
+    # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
+    # we return a list that maps output_idx -> full jacobian Tensor
+    jacobians = _allocate_jacobians_with_outputs(
+        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
+    )
+    for i, jacobian in enumerate(jacobians):
+        for k, v in jacobians_cols.items():
+            jacobian[k] = v[i]
+    return jacobians
+
+
+def _prepare_input(
+    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
+) -> torch.Tensor:
+    # Prepares the inputs to be passed into the function while including the new
+    # modified input.
+    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
+        # Convert back to mkldnn
+        if maybe_perturbed_input is not None:
+            return maybe_perturbed_input.to_mkldnn()
+        else:
+            return input
+    elif _is_sparse_any_tensor(input):
+        if fast_mode and maybe_perturbed_input is not None:
+            # entry is already a "cloned" version of the original tensor
+            # thus changes to entry are not reflected in the input
+            return maybe_perturbed_input
+        else:
+            return input
+    else:
+        # We cannot use entry (input.data) if we want gradgrad to work because
+        # fn (in the gradgrad case) needs to compute grad wrt input
+        return input
+
+
+def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
+    # Check that the returned outputs don't have different dtype or shape when you
+    # perturb the input
+    on_index = "on index {idx} " if idx is not None else ""
+    assert output1.shape == output2.shape, (
+        f"Expected `func` to return outputs with the same shape"
+        f" when inputs are perturbed {on_index}by {eps}, but got:"
+        f" shapes {output1.shape} and {output2.shape}."
+    )
+    assert output1.dtype == output2.dtype, (
+        f"Expected `func` to return outputs with the same dtype"
+        f" when inputs are perturbed {on_index}by {eps}, but got:"
+        f" dtypes {output1.dtype} and {output2.dtype}."
+    )
+
+
+def get_numerical_jacobian_wrt_specific_input(
+    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
+) -> Tuple[torch.Tensor, ...]:
+    # Computes the numerical jacobians wrt to a single input. Returns N jacobian
+    # tensors, where N is the number of outputs. We use a dictionary for
+    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs
+    # When we perturb only a single element of the input tensor at a time, the jvp
+    # is equivalent to a single col of the Jacobian matrix of fn.
+    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
+    input = inputs[input_idx] if input is None else input
+    assert input.requires_grad
+    for x, idx, d_idx in _iter_tensor(input):
+        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
+        input_to_perturb = x[idx]
+        nbhd_checks_fn = functools.partial(
+            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
+        )
+        jvp_fn = _get_numerical_jvp_fn(
+            wrapped_fn, input_to_perturb, eps, nbhd_checks_fn
+        )
+        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
+            jvp_fn, eps, x.is_complex(), is_forward_ad
+        )
+    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())
+
+
+def _get_analytical_jacobian_forward_ad(
+    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
+) -> Tuple[Tuple[torch.Tensor, ...], ...]:
+    """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`.
+
+    Return N * M Jacobians where N is the number of tensors in target that require grad and
+    M is the number of non-integral outputs.
+    Contrary to other functions here, this function requires "inputs" to actually be used by the function.
+    The computed value is expected to be wrong if the function captures the inputs by side effect instead of
+    using the passed ones (many torch.nn tests do this).
+
+    Args:
+        fn: the function to compute the jacobian for
+        inputs: inputs to `fn`
+        outputs: provide precomputed outputs to avoid one extra invocation of fn
+        check_grad_dtypes: if True, will check that the gradient dtype are valid
+        all_u (optional): if provided, the Jacobian will be right multiplied with this vector
+
+    Returns:
+        A tuple of M N-tuples of tensors
+    """
+    # To avoid early import issues
+    fwAD = torch.autograd.forward_ad
+
+    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
+
+    if any(i.is_complex() for i in tensor_inputs):
+        raise ValueError(
+            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
+        )
+
+    if all_u:
+        jacobians = tuple(
+            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
+        )
+    else:
+        jacobians = tuple(
+            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
+        )
+
+    with fwAD.dual_level():
+        fw_grads = []
+        dual_inputs = []
+        for i, inp in enumerate(inputs):
+            if is_tensor_like(inp) and inp.requires_grad:
+                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
+                    raise ValueError(
+                        "MKLDNN inputs are not support for forward AD gradcheck."
+                    )
+
+                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
+                # If inp is a differentiable view, the dual might not be the tangent given to
+                # make_dual, so read it explicitly from the dual tensor
+                fw_grads.append(fwAD.unpack_dual(inp)[1])
+            dual_inputs.append(inp)
+
+        if all_u:
+            # Do the full reduction in one pass
+            # To be consistent with numerical evaluation, we actually compute one reduction per input
+            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+                fw_grad.copy_(u.view_as(fw_grad))
+                raw_outputs = _as_tuple(fn(*dual_inputs))
+                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
+                for index_o, d_o in enumerate(dual_outputs):
+                    val, res = fwAD.unpack_dual(d_o)
+                    if (
+                        check_grad_dtypes
+                        and res is not None
+                        and val.is_complex() != res.is_complex()
+                    ):
+                        raise GradcheckError("Forward AD gradient has dtype mismatch.")
+
+                    # Remove extra dimension of size 1 corresponding to the reduced input
+                    jacobians[i][index_o].squeeze_(0)
+                    if res is None:
+                        jacobians[i][index_o].zero_()
+                    else:
+                        jacobians[i][index_o].copy_(res.reshape(-1))
+                fw_grad.zero_()
+        else:
+            # Reconstruct the full Jacobian column by column
+            for i, fw_grad in enumerate(fw_grads):
+                for lin_idx, grad_idx in enumerate(
+                    product(*[range(m) for m in fw_grad.size()])
+                ):
+                    fw_grad[grad_idx] = 1.0
+                    raw_outputs = _as_tuple(fn(*dual_inputs))
+                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
+                    for index_o, d_o in enumerate(dual_outputs):
+                        val, res = fwAD.unpack_dual(d_o)
+                        if (
+                            check_grad_dtypes
+                            and res is not None
+                            and val.is_complex() != res.is_complex()
+                        ):
+                            raise GradcheckError(
+                                "Forward AD gradient has dtype mismatch."
+                            )
+
+                        if res is None:
+                            jacobians[i][index_o][lin_idx].zero_()
+                        else:
+                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
+                    fw_grad[grad_idx] = 0.0
+
+    return jacobians
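+
+
+# Illustrative sketch, not upstream code: the forward-AD primitive the function
+# above applies column by column, shown with the public `torch.autograd.forward_ad`
+# API. One evaluation with a dual input yields one Jacobian column (a JVP).
+# Defined for documentation only and never called by library code.
+def _demo_single_jvp_column():
+    fwAD = torch.autograd.forward_ad
+    x = torch.randn(3)
+    tangent = torch.zeros(3)
+    tangent[0] = 1.0                      # probe the first input coordinate
+    with fwAD.dual_level():
+        dual_x = fwAD.make_dual(x, tangent)
+        primal, jvp = fwAD.unpack_dual(dual_x.sin())
+    # `jvp` is cos(x) * tangent, i.e. the first column of the Jacobian of sin.
+    return primal, jvp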
+
+
+def _get_input_to_perturb(input):
+    # Prepare the input so that it can be modified in-place and do certain
+    # operations that require the tensor to have strides. If fast_mode=False,
+    # _iter_tensor would handle the below cases:
+    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
+        # Convert to dense so we can perform operations that require strided tensors
+        input_to_perturb = input.to_dense()
+    elif _is_sparse_any_tensor(input):
+        # Clone because input may require grad, and copy_ calls resize_,
+        # which is not allowed for .data
+        input_to_perturb = input.clone()
+    else:
+        input_to_perturb = input.data
+    return input_to_perturb
+
+
+def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
+    # Wraps `fn` so that its inputs are already supplied
+    def wrapped_fn():
+        inp = tuple(
+            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
+            if is_tensor_like(a)
+            else a
+            for i, a in enumerate(_as_tuple(inputs))
+        )
+        return tuple(a.clone() for a in _as_tuple(fn(*inp)))
+
+    return wrapped_fn
+
+
+def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
+    # Wraps jvp_fn so that certain arguments are already supplied
+    def jvp_fn(delta):
+        return _compute_numerical_gradient(
+            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
+        )
+
+    return jvp_fn
+
+
+def _reshape_tensor_or_tuple(u, shape):
+    # We don't need to reshape when input corresponding to u is sparse
+    if isinstance(u, tuple):
+        if not _is_sparse_any_tensor(u[0]):
+            return (u[0].reshape(shape), u[1].reshape(shape))
+    else:
+        if not _is_sparse_any_tensor(u):
+            return u.reshape(shape)
+    return u
+
+
+def _mul_tensor_or_tuple(u, k):
+    if isinstance(u, tuple):
+        return (k * u[0], k * u[1])
+    else:
+        return k * u
+
+
+def _get_numerical_jvp_wrt_specific_input(
+    fn, input_idx, inputs, u, eps, is_forward_ad=False
+) -> List[torch.Tensor]:
+    input = inputs[input_idx]
+    input_to_perturb = _get_input_to_perturb(input)
+    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
+    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
+    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
+    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
+    u = _mul_tensor_or_tuple(u, eps)
+    return _compute_numerical_jvps_wrt_specific_input(
+        jvp_fn, u, input.is_complex(), is_forward_ad
+    )
+
+
+def _get_numerical_vJu(
+    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
+):
+    # Note that all_v can also be None, in that case, this function only computes Ju.
+    reduced_jacobians: List[List[torch.Tensor]] = []
+    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
+        all_Ju = _get_numerical_jvp_wrt_specific_input(
+            fn, inp_idx, inputs, u, eps, is_forward_ad
+        )
+        # Filter out the Ju for non floating point outputs
+        filtered_Ju = []
+        func_out = _as_tuple(func_out)
+        assert len(all_Ju) == len(func_out)
+        for Ju, output in zip(all_Ju, func_out):
+            if _is_float_or_complex_tensor(output):
+                filtered_Ju.append(Ju)
+            else:
+                # TODO: handle the other Ju
+                pass
+        if all_v is not None:
+            jacobian_scalars: List[torch.Tensor] = []
+            for v, Ju in zip(all_v, filtered_Ju):
+                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
+            reduced_jacobians.append(jacobian_scalars)
+        else:
+            reduced_jacobians.append(filtered_Ju)
+    return reduced_jacobians
+
+
+def _check_jacobians_equal(j1, j2, atol):
+    # Check whether the max difference between two Jacobian tensors are within some
+    # tolerance `atol`.
+    for j1_x, j2_x in zip(j1, j2):
+        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
+            return False
+    return True
+
+
+def _stack_and_check_tensors(
+    list_of_list_of_tensors, inputs, numel_outputs
+) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
+    # For the ith tensor in the inner list checks whether it has the same size and
+    # dtype as the ith differentiable input.
+    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
+    diff_input_list = list(_iter_tensors(inputs, True))
+    correct_grad_sizes = True
+    correct_grad_types = True
+    for i, tensor_list in enumerate(list_of_list_of_tensors):
+        inp = diff_input_list[i]
+        out_jacobian = out_jacobians[i]
+        for j, tensor in enumerate(tensor_list):
+            if tensor is not None and tensor.size() != inp.size():
+                correct_grad_sizes = False
+            elif tensor is not None and tensor.dtype != inp.dtype:
+                correct_grad_types = False
+            if tensor is None:
+                out_jacobian[:, j].zero_()
+            else:
+                dense = (
+                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
+                )
+                assert out_jacobian[:, j].numel() == dense.numel()
+                out_jacobian[:, j] = dense.reshape(-1)
+    return out_jacobians, correct_grad_sizes, correct_grad_types
+
+
+FAILED_NONDET_MSG = """\n
+NOTE: If your op relies on non-deterministic operations i.e., it is listed here:
+https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+this failure might be expected.
+
+If you are adding a new operator, please file an issue and then use one of the
+workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
+If the test
+- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
+  with `nondet_tol=<tol>` as a keyword argument.
+- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
+  to have `gradcheck_nondet_tol=<tol>`.
+- is a Module test (e.g., in common_nn.py), then modify the corresponding
+  module_test entry to have `gradcheck_nondet_tol=<tol>`
+"""
+
+
+def _check_analytical_jacobian_attributes(
+    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
+) -> Tuple[torch.Tensor, ...]:
+    # This is used by both fast and slow mode:
+    #  - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
+    #    input.
+    #  - For fast mode, vjps[i][0] is a linear combination of the rows
+    #    of the Jacobian wrt the ith input
+    diff_input_list = list(_iter_tensors(inputs, True))
+
+    def vjp_fn(grad_output):
+        return torch.autograd.grad(
+            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
+        )
+
+    # Compute everything twice to check for nondeterminism (which we call reentrancy)
+    if fast_mode:
+        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
+        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
+    else:
+        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
+        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
+
+    output_numel = output.numel() if not fast_mode else 1
+    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
+        vjps1, inputs, output_numel
+    )
+    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
+    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)
+
+    if not types_ok and check_grad_dtypes:
+        raise GradcheckError("Gradient has dtype mismatch")
+    if not sizes_ok:
+        raise GradcheckError("Analytical gradient has incorrect size")
+    if not reentrant:
+        raise GradcheckError(
+            "Backward is not reentrant, i.e., running backward with "
+            "same input and grad_output multiple times gives different values, "
+            "although analytical gradient matches numerical gradient."
+            f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG
+        )
+    return jacobians1
+
+
+def _get_analytical_vJu_backward_mode(
+    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
+):
+    reduced_jacobians: List[List[torch.Tensor]] = []
+    for output, v in zip(outputs, all_v):
+        all_vJ = _check_analytical_jacobian_attributes(
+            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
+        )
+        jacobian_scalars: List[torch.Tensor] = []
+        for vJ, u in zip(all_vJ, all_u):
+            # vJ is kept as a 2-d tensor so that we can reuse the error-checking
+            # logic from slow mode; squeeze it back down to 1-d here.
+            vJ = vJ.T.squeeze(0)
+            if vJ.is_complex():  # C -> R
+                tv = torch.view_as_real(vJ.resolve_conj())
+                tr = tv.select(-1, 0)
+                ti = tv.select(-1, 1)
+                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
+            else:  # R -> R
+                jacobian_scalars.append(vJ.dot(u))
+        reduced_jacobians.append(jacobian_scalars)
+    return reduced_jacobians
+
+
+def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
+    # Replicates the behavior of the old get_analytical_jacobian before the refactor
+    # This shares much of its code with _check_analytical_jacobian_attributes
+    warnings.warn(
+        "get_analytical_jacobian was part of PyTorch's private API and not "
+        "meant to be exposed. We are deprecating it and it will be removed "
+        "in a future version of PyTorch. If you have a specific use for "
+        "this or feature request for this to be a stable API, please file "
+        "us an issue at https://github.com/pytorch/pytorch/issues/new"
+    )
+    if (
+        grad_out != 1.0
+    ):  # grad_out param is only kept for backward compatibility reasons
+        raise ValueError(
+            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
+            "supports values of grad_out != 1.0."
+        )
+    if output.is_complex():
+        raise ValueError(
+            "Expected output to be non-complex. get_analytical_jacobian no "
+            "longer supports functions that return complex outputs."
+        )
+    diff_input_list = list(_iter_tensors(inputs, True))
+
+    def vjp_fn(grad_output):
+        return torch.autograd.grad(
+            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
+        )
+
+    # Compute everything twice to check for nondeterminism (which we call reentrancy)
+    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
+    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
+
+    output_numel = output.numel()
+    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
+        vjps1, inputs, output_numel
+    )
+    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
+    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)
+
+    return jacobians1, reentrant, sizes_ok, types_ok
+
+
+def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
+    # Computes the analytical Jacobian in slow mode for a single input-output pair.
+    # Forgoes performing checks on dtype, shape, and reentrancy.
+    jacobians = _check_analytical_jacobian_attributes(
+        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
+    )
+    return jacobians[input_idx]
+
+
+def _compute_analytical_jacobian_rows(
+    vjp_fn, sample_output
+) -> List[List[Optional[torch.Tensor]]]:
+    # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
+    # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
+    # NB: this function does not assume vjp_fn(v) to return tensors with the same
+    # number of elements for different v. This is checked when we later combine the
+    # rows into a single tensor.
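+    #
+    # A minimal illustrative sketch of the projection trick (hypothetical names,
+    # not executed here): for a differentiable output y = f(x),
+    #
+    #     e_j = torch.zeros_like(y).reshape(-1)
+    #     e_j[j] = 1.0
+    #     row_j = torch.autograd.grad(y, x, e_j.view_as(y), retain_graph=True)[0]
+    #
+    # yields the jth Jacobian row (shaped like x), which is what the loop below
+    # produces through `vjp_fn`.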
+    grad_out_base = torch.zeros_like(
+        sample_output, memory_format=torch.legacy_contiguous_format
+    )
+    flat_grad_out = grad_out_base.view(-1)
+    # jacobians_rows[i][j] is the Jacobian jth row for the ith input
+    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
+    for j in range(flat_grad_out.numel()):
+        flat_grad_out.zero_()
+        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
+        grad_inputs = vjp_fn(grad_out_base)
+        for i, d_x in enumerate(grad_inputs):
+            if j == 0:
+                jacobians_rows.append([])
+            jacobians_rows[i] += [
+                d_x.clone() if isinstance(d_x, torch.Tensor) else None
+            ]
+    return jacobians_rows
+
+
+def _get_analytical_vjps_wrt_specific_output(
+    vjp_fn, sample_output, v
+) -> List[List[Optional[torch.Tensor]]]:
+    vjps: List[List[Optional[torch.Tensor]]] = []
+    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
+    for vjp in grad_inputs:
+        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
+    return vjps
+
+
+def _check_inputs(tupled_inputs) -> bool:
+    # Make sure that gradients are saved for at least one input
+    any_input_requiring_grad = False
+    for idx, inp in enumerate(tupled_inputs):
+        if is_tensor_like(inp) and inp.requires_grad:
+            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
+                warnings.warn(
+                    f"Input #{idx} requires gradient and "
+                    "is not a double precision floating point or complex. "
+                    "This check will likely fail if all the inputs are "
+                    "not of double precision floating point or complex. "
+                )
+            if inp.is_sparse:
+                content = inp._values()
+            elif _is_sparse_compressed_tensor(inp):
+                content = inp.values()
+            else:
+                content = inp
+            # TODO: To cover more problematic cases, replace stride = 0 check with
+            # "any overlap in memory" once we have a proper function to check it.
+            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
+                if not all(
+                    st > 0 or sz <= 1
+                    for st, sz in zip(content.stride(), content.size())
+                ):
+                    raise RuntimeError(
+                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
+                        "supports inputs that are non-overlapping to be able to "
+                        "compute the numerical gradients correctly. You should call "
+                        ".contiguous on the input before passing it to gradcheck."
+                    )
+            any_input_requiring_grad = True
+
+    if not any_input_requiring_grad:
+        raise ValueError(
+            "gradcheck expects at least one input tensor to require gradient, "
+            "but none of the them have requires_grad=True."
+        )
+    return True
+
+
+def _check_outputs(outputs) -> None:
+    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
+        # it is easier to call to_dense() on the sparse output than
+        # to modify analytical jacobian
+        raise ValueError(
+            "Sparse output is not supported at gradcheck yet. "
+            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
+        )
+    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
+        raise ValueError(
+            "MKLDNN output is not supported at gradcheck yet. "
+            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
+        )
+
+
+def _check_no_differentiable_outputs(
+    func, inputs, func_out, eps, *, is_forward_ad
+) -> bool:
+    # When there are no differentiable outputs, numerical gradient for a function is
+    # expected to be zero.
+    jacobians_all_inputs_outputs = _get_numerical_jacobian(
+        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
+    )
+    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
+        for jacobian in jacobians_all_outputs_and_fixed_input:
+            if torch.ne(jacobian, 0).sum() > 0:
+                raise GradcheckError(
+                    "Numerical gradient for function expected to be zero"
+                )
+    return True
+
+
+def _check_no_differentiable_outputs_fast(
+    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
+):
+    for inp_idx, u in zip(inputs_indices, all_u):
+        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
+        for jvp in jvps:
+            if jvp.numel() == 0:
+                continue
+            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
+                raise GradcheckError(
+                    "Numerical gradient for function expected to be zero"
+                )
+    return True
+
+
+FAILED_BATCHED_GRAD_MSG = """
+gradcheck or gradgradcheck failed while testing batched gradient computation.
+This could have been invoked in a number of ways (via a test that calls
+gradcheck/gradgradcheck directly or via an autogenerated test).
+
+If you are adding a new operator, please file an issue and then use one of the
+workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
+If the test
+- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
+  with `check_batched_grad=False` as a keyword argument.
+- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
+  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.
+
+If you're modifying an existing operator that supports batched grad computation,
+or wish to make a new operator work with batched grad computation, please read
+the following.
+
+To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
+computation. The most common failure case is if there is a 'vmap-incompatible
+operation' in the backward pass. Please see
+NOTE: [How to write vmap-compatible backward formulas]
+in the codebase for an explanation of how to fix this.
+""".strip()
+
+FAILED_BATCHED_GRAD_MSG_FWD_AD = """
+gradcheck failed while testing batched gradient computation with forward-mode AD.
+This test is enabled automatically when both `check_batched_grad=True`
+and `check_forward_ad=True`, but can be disabled in the following ways
+depending on how the test was invoked (via a test that calls gradcheck
+directly or via an autogenerated test).
+
+If you are adding a new operator, please file an issue and then use one of the
+workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
+If the test
+- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
+  with `check_batched_forward_grad=False` as a keyword argument.
+- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
+  to have `check_batched_forward_grad=False`
+"""
+
+
+def _get_failed_batched_grad_test_msg(
+    output_idx, input_idx, res, exp, is_forward_ad=False
+):
+    return f"""
+For output {output_idx} and input {input_idx}:
+
+{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}
+
+Got:
+{res}
+
+Expected:
+{exp}
+""".strip()
+
+
+def _test_batched_grad_forward_ad(func, inputs) -> bool:
+    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
+    assert isinstance(inputs, tuple)
+
+    for input_idx, current_input in enumerate(inputs):
+        if not (is_tensor_like(current_input) and current_input.requires_grad):
+            continue
+
+        def jvp(tangent: torch.Tensor):
+            with fwAD.dual_level():
+                dual = fwAD.make_dual(current_input.detach(), tangent)
+                inputs_with_dual = tuple(
+                    dual
+                    if idx == input_idx
+                    else (inp.detach() if is_tensor_like(inp) else inp)
+                    for idx, inp in enumerate(inputs)
+                )
+                dual_outputs = _as_tuple(func(*inputs_with_dual))
+                ret = []
+                for dual_output in dual_outputs:
+                    if dual_output is None:
+                        continue
+                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
+                    if tangent_out is not None:
+                        ret.append(tangent_out)
+                    else:
+                        ret.append(
+                            torch.zeros(
+                                [], dtype=primal_out.dtype, device=primal_out.device
+                            ).expand(primal_out.shape)
+                        )
+                return tuple(ret)
+
+        if not _is_float_or_complex_tensor(current_input):
+            continue
+
+        tangents = [torch.randn_like(current_input) for _ in range(2)]
+        expected = [jvp(t) for t in tangents]
+        expected = [torch.stack(shards) for shards in zip(*expected)]
+
+        try:
+            result = _vmap(jvp)(torch.stack(tangents))
+        except RuntimeError as ex:
+            # Rethrow to provide a better error message
+            raise GradcheckError(
+                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
+            ) from ex
+
+        for output_idx, (res, exp) in enumerate(zip(result, expected)):
+            if torch.allclose(res, exp):
+                continue
+            raise GradcheckError(
+                _get_failed_batched_grad_test_msg(
+                    output_idx, input_idx, res, exp, is_forward_ad=True
+                )
+            )
+    return True
+
+
+def _test_batched_grad(input, output, output_idx) -> bool:
+    # NB: _test_batched_grad compares two autograd.grad invocations with a single
+    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
+    # sense that we're not comparing an analytical jacobian with a numeric one,
+    # but it is morally similar (we could have computed a full analytic jac
+    # via vmap, but that is potentially slow)
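+    #
+    # Illustrative sketch of the comparison being made (hypothetical names, not
+    # executed here): for two cotangents v1 and v2,
+    #
+    #     expected = torch.stack([vjp(v1)[0], vjp(v2)[0]])
+    #     batched = vmap(vjp)(torch.stack([v1, v2]))[0]
+    #
+    # and `batched` should match `expected` (here for the first differentiable input).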
+    diff_input_list = list(_iter_tensors(input, True))
+    grad = functools.partial(
+        torch.autograd.grad,
+        output,
+        diff_input_list,
+        retain_graph=True,
+        allow_unused=True,
+    )
+
+    def vjp(v):
+        results = grad(v)
+        results = tuple(
+            grad
+            if grad is not None
+            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
+            for grad, inp in zip(results, diff_input_list)
+        )
+        return results
+
+    grad_outputs = [torch.randn_like(output) for _ in range(2)]
+
+    expected = [vjp(gO) for gO in grad_outputs]
+    expected = [torch.stack(shards) for shards in zip(*expected)]
+
+    # Squash warnings since these are expected to happen in most cases
+    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", message="There is a performance drop")
+        warnings.filterwarnings("ignore", message="Please use torch.vmap")
+        try:
+            result = vmap(vjp)(torch.stack(grad_outputs))
+        except RuntimeError as ex:
+            # It's OK that we're not raising the error at the correct callsite.
+            # That's because the callsite is always going to inside the Python
+            # autograd.grad instead of the C++ traceback of what line in the
+            # backward formula
+            raise GradcheckError(
+                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
+            ) from ex
+
+    for input_idx, (res, exp) in enumerate(zip(result, expected)):
+        if torch.allclose(res, exp):
+            continue
+        raise GradcheckError(
+            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
+        )
+    return True
+
+
+def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
+    # Tests that backward is multiplied by grad_output
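+    # The check feeds an all-zeros grad_output for every output; a backward
+    # formula that correctly scales by grad_output must then produce input
+    # gradients that are all zero (or undefined), with matching layout, dtype,
+    # device, and size.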
+    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
+    if not diff_input_list:
+        raise GradcheckError("no Tensors requiring grad found in input")
+    grads_input = torch.autograd.grad(
+        outputs,
+        diff_input_list,
+        [
+            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
+            for o in outputs
+        ],
+        allow_unused=True,
+    )
+    for gi, di in zip(grads_input, diff_input_list):
+        if gi is None:
+            continue
+        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
+            if gi.layout != di.layout:
+                raise GradcheckError(
+                    "grad is incorrect layout ("
+                    + str(gi.layout)
+                    + " is not "
+                    + str(di.layout)
+                    + ")"
+                )
+            if _is_sparse_any_tensor(gi):
+                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
+                if gi.sparse_dim() != di.sparse_dim():
+                    raise GradcheckError(
+                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
+                        f" {gi.sparse_dim()}, expected {di.sparse_dim()}"
+                    )
+                if gi.dense_dim() != di.dense_dim():
+                    raise GradcheckError(
+                        f"grad is {sparse_kind} tensor, but has incorrect dense_dim"
+                        f" {gi.dense_dim()}, expected {di.dense_dim()}"
+                    )
+            gi = gi.to_dense()
+            di = di.to_dense()
+        if masked:
+            if not torch.allclose(gi, torch.zeros_like(gi)):
+                raise GradcheckError("backward not multiplied by grad_output")
+        elif not gi.eq(0).all():
+            raise GradcheckError("backward not multiplied by grad_output")
+        if gi.dtype != di.dtype:
+            raise GradcheckError("grad is incorrect type")
+        if gi.device != di.device:
+            raise GradcheckError("grad is incorrect device")
+        if gi.size() != di.size():
+            raise GradcheckError("grad is incorrect size")
+    return True
+
+
+def _test_undefined_forward_mode(func, outputs, inputs):
+    fwAD = torch.autograd.forward_ad
+
+    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)
+
+    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
+
+    with fwAD.dual_level():
+        fw_grads = []
+        dual_inputs = []
+        tensor_indices = set()
+        for i, inp in enumerate(inputs):
+            if is_tensor_like(inp) and inp.requires_grad:
+                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
+                    raise ValueError(
+                        "MKLDNN inputs are not support for forward AD gradcheck."
+                    )
+
+                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
+                # If inp is a differentiable view, the dual might not be the tangent given to
+                # make_dual, so read it explicitly from the dual tensor
+                fw_grads.append(fwAD.unpack_dual(inp)[1])
+                tensor_indices.add(i)
+            dual_inputs.append(inp)
+
+        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+            fw_grad.copy_(u.view_as(fw_grad))
+
+        for idx, inp in enumerate(inputs):
+            if idx not in tensor_indices:
+                continue
+            dual_inp_obj = dual_inputs[idx]
+
+            # case 1 (Materialized Zero Tensor Tangent)
+            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
+            raw_outputs = _as_tuple(func(*dual_inputs))
+            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+            # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
+            dual_inputs[idx] = inp.detach()
+            raw_outputs = _as_tuple(func(*dual_inputs))
+            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+            # reset
+            dual_inputs[idx] = dual_inp_obj
+
+            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
+                val1, res1 = fwAD.unpack_dual(d_o1)
+                val2, res2 = fwAD.unpack_dual(d_o2)
+
+                if not (res1 is None or res2 is None):
+                    if not torch.allclose(res1, res2):
+                        raise GradcheckError(
+                            f"Mismatch in tangent values for output with index {index_o} "
+                            f"when input {inp} has an undefined tangent value. "
+                            f"Got {res1} but expected {res2}."
+                        )
+    return True
+
+
+def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
+    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
+    if not diff_input_list:
+        raise GradcheckError("no Tensors requiring grad found in input")
+
+    def warn_bc_breaking():
+        warnings.warn(
+            "Backwards compatibility: New undefined gradient support checking "
+            "feature is enabled by default, but it may break existing callers "
+            "of this function. If this is true for you, you can call this "
+            'function with "check_undefined_grad=False" to disable the feature'
+        )
+
+    def check_undefined_grad_support(output_to_check):
+        grads_output = [
+            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
+            for o in output_to_check
+        ]
+        try:
+            grads_input = torch.autograd.grad(
+                output_to_check, diff_input_list, grads_output, allow_unused=True
+            )
+        except RuntimeError as e:
+            warn_bc_breaking()
+            raise GradcheckError(
+                "Expected backward function to handle undefined output grads. "
+                'Please look at "Notes about undefined output gradients" in '
+                '"tools/autograd/derivatives.yaml"'
+            ) from e
+
+        for gi, i in zip(grads_input, diff_input_list):
+            if (gi is not None) and (not gi.eq(0).all()):
+                warn_bc_breaking()
+                raise GradcheckError(
+                    "Expected all input grads to be undefined or zero when all output grads are undefined "
+                    'or zero. Please look at "Notes about undefined output gradients" in '
+                    '"tools/autograd/derivatives.yaml"'
+                )
+        return True
+
+    # All backward functions must work properly if all output grads are undefined
+    outputs_to_check = [
+        [
+            torch._C._functions.UndefinedGrad()(o)
+            for o in _differentiable_outputs(func(*inputs))
+            # This check filters out Tensor-likes that aren't instances of Tensor.
+            if isinstance(o, torch.Tensor)
+        ]
+    ]
+
+    # If there are multiple output grads, we should be able to undef one at a time without error
+    if len(outputs_to_check[0]) > 1:
+        for undef_grad_idx in range(len(outputs)):
+            output_to_check = _differentiable_outputs(func(*inputs))
+            outputs_to_check.append(
+                [
+                    torch._C._functions.UndefinedGrad()(o)
+                    if idx == undef_grad_idx
+                    else o
+                    for idx, o in enumerate(output_to_check)
+                ]
+            )
+
+    return all(check_undefined_grad_support(output) for output in outputs_to_check)
+
+
+def _as_tuple(x):
+    if isinstance(x, tuple):
+        return x
+    elif isinstance(x, list):
+        return tuple(x)
+    else:
+        return (x,)
+
+
+def _differentiable_outputs(x):
+    return tuple(o for o in _as_tuple(x) if o.requires_grad)
+
+
+def _get_notallclose_msg(
+    analytical,
+    numerical,
+    output_idx,
+    input_idx,
+    complex_indices,
+    test_imag=False,
+    is_forward_ad=False,
+) -> str:
+    out_is_complex = (
+        (not is_forward_ad) and complex_indices and output_idx in complex_indices
+    )
+    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
+    part = "imaginary" if test_imag else "real"
+    element = "inputs" if is_forward_ad else "outputs"
+    prefix = (
+        ""
+        if not (out_is_complex or inp_is_complex)
+        else f"While considering the {part} part of complex {element} only, "
+    )
+    mode = "computed with forward mode " if is_forward_ad else ""
+    return (
+        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
+        "numerical:%s\nanalytical:%s\n"
+        % (mode, output_idx, input_idx, numerical, analytical)
+    )
+
+
+def _transpose(matrix_of_tensors):
+    # returns list of tuples
+    return list(zip(*matrix_of_tensors))
+
+
+def _real_and_imag_output(fn):
+    # returns new functions real(fn), and imag(fn) where real(fn) and imag(fn) behave the same as
+    # the original fn, except torch.real or torch.imag are applied to the complex outputs
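+    # For example (illustrative): if fn(z) returns a complex tensor c, then
+    # real(fn)(z) returns torch.real(c) and imag(fn)(z) returns torch.imag(c),
+    # while outputs that are already real are passed through unchanged.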
+    def apply_to_c_outs(fn, fn_to_apply):
+        def wrapped_fn(*inputs):
+            outs = _as_tuple(fn(*inputs))
+            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)
+
+        return wrapped_fn
+
+    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)
+
+
+def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
+    # returns new functions that take real inputs instead of complex inputs as
+    # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j).
+    # In each case, the other part is considered constant.
+    # We do not use 0 for the constant here to make sure we always call the user function with a valid input.
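+    # For example (illustrative): given a complex input z0 = x0 + y0*1j,
+    # real_fn(x) evaluates fn(x + y0*1j) and imag_fn(y) evaluates fn(x0 + y*1j),
+    # so each wrapped function receives a purely real tensor at that position.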
+    def apply_to_c_inps(fn, fn_to_apply):
+        def wrapped_fn(*inputs):
+            new_inputs = list(inputs)
+            for should_be_complex in complex_inp_indices:
+                new_inputs[should_be_complex] = fn_to_apply(
+                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
+                )
+            return _as_tuple(fn(*new_inputs))
+
+        return wrapped_fn
+
+    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
+    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
+    return real_fn, imag_fn
+
+
+def _gradcheck_real_imag(
+    gradcheck_fn,
+    func,
+    func_out,
+    tupled_inputs,
+    outputs,
+    eps,
+    rtol,
+    atol,
+    check_grad_dtypes,
+    check_forward_ad,
+    check_backward_ad,
+    nondet_tol,
+    check_undefined_grad,
+):
+    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
+    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
+    if check_backward_ad:
+        if has_any_complex_output:
+            real_fn, imag_fn = _real_and_imag_output(func)
+
+            imag_func_out = imag_fn(*tupled_inputs)
+            imag_outputs = _differentiable_outputs(imag_func_out)
+            gradcheck_fn(
+                imag_fn,
+                imag_func_out,
+                tupled_inputs,
+                imag_outputs,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+                complex_indices=complex_out_indices,
+                test_imag=True,
+            )
+
+            real_func_out = real_fn(*tupled_inputs)
+            real_outputs = _differentiable_outputs(real_func_out)
+            gradcheck_fn(
+                real_fn,
+                real_func_out,
+                tupled_inputs,
+                real_outputs,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+                complex_indices=complex_out_indices,
+            )
+        else:
+            gradcheck_fn(
+                func,
+                func_out,
+                tupled_inputs,
+                outputs,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+            )
+
+    if check_forward_ad:
+        complex_inp_indices = [
+            i
+            for i, inp in enumerate(tupled_inputs)
+            if is_tensor_like(inp) and inp.is_complex()
+        ]
+        if complex_inp_indices:
+            real_fn, imag_fn = _real_and_imag_input(
+                func, complex_inp_indices, tupled_inputs
+            )
+
+            imag_inputs = [
+                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
+                for inp in tupled_inputs
+            ]
+            imag_func_out = imag_fn(*imag_inputs)
+            diff_imag_func_out = _differentiable_outputs(imag_func_out)
+            gradcheck_fn(
+                imag_fn,
+                imag_func_out,
+                imag_inputs,
+                diff_imag_func_out,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+                complex_indices=complex_inp_indices,
+                test_imag=True,
+                use_forward_ad=True,
+            )
+
+            real_inputs = [
+                inp.real if is_tensor_like(inp) and inp.is_complex() else inp
+                for inp in tupled_inputs
+            ]
+            real_func_out = real_fn(*real_inputs)
+            diff_real_func_out = _differentiable_outputs(real_func_out)
+            gradcheck_fn(
+                real_fn,
+                real_func_out,
+                real_inputs,
+                diff_real_func_out,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+                complex_indices=complex_inp_indices,
+                use_forward_ad=True,
+            )
+            if check_undefined_grad:
+                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
+                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
+        else:
+            gradcheck_fn(
+                func,
+                func_out,
+                tupled_inputs,
+                outputs,
+                eps,
+                rtol,
+                atol,
+                check_grad_dtypes,
+                nondet_tol,
+                use_forward_ad=True,
+            )
+            if check_undefined_grad:
+                _test_undefined_forward_mode(func, outputs, tupled_inputs)
+
+
+def _slow_gradcheck(
+    func,
+    func_out,
+    tupled_inputs,
+    outputs,
+    eps,
+    rtol,
+    atol,
+    check_grad_dtypes,
+    nondet_tol,
+    *,
+    use_forward_ad=False,
+    complex_indices=None,
+    test_imag=False,
+    masked=False,
+):
+    func_out = _as_tuple(func_out)
+    if not outputs:
+        return _check_no_differentiable_outputs(
+            func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad
+        )
+    tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs)
+
+    numerical = _transpose(
+        _get_numerical_jacobian(
+            func,
+            tupled_inputs_numerical,
+            func_out,
+            eps=eps,
+            is_forward_ad=use_forward_ad,
+        )
+    )
+    # Note: [numerical vs analytical output length]
+    # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that
+    # output is False. This behavior is necessary for _check_no_differentiable_outputs to work.
+    numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad]
+    if use_forward_ad:
+        analytical_forward = _get_analytical_jacobian_forward_ad(
+            func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes
+        )
+
+        for i, n_per_out in enumerate(numerical):
+            for j, n in enumerate(n_per_out):
+                a = analytical_forward[j][i]
+                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
+                    raise GradcheckError(
+                        _get_notallclose_msg(
+                            a, n, i, j, complex_indices, test_imag, is_forward_ad=True
+                        )
+                    )
+    else:
+        for i, o in enumerate(outputs):
+            analytical = _check_analytical_jacobian_attributes(
+                tupled_inputs, o, nondet_tol, check_grad_dtypes
+            )
+
+            for j, (a, n) in enumerate(zip(analytical, numerical[i])):
+                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
+                    raise GradcheckError(
+                        _get_notallclose_msg(a, n, i, j, complex_indices, test_imag)
+                    )
+
+    return True
+
+
+def _dot_with_type_promotion(u, v):
+    assert u.dim() == 1 and v.dim() == 1
+    return (u * v).sum()
+
+
+def _allclose_with_type_promotion(a, b, rtol, atol):
+    promoted_type = torch.promote_types(a.dtype, b.dtype)
+    a = a.to(dtype=promoted_type)
+    b = b.to(dtype=promoted_type)
+    return torch.allclose(a, b, rtol, atol)
+
+
+def _to_real_dtype(dtype):
+    if dtype == torch.complex128:
+        return torch.float64
+    elif dtype == torch.complex64:
+        return torch.float32
+    else:
+        return dtype
+
+
+def _vec_from_tensor(x, generator, downcast_complex=False):
+    # Create a random vector with the same number of elements as x and the same
+    # dtype/device. If x is complex and downcast_complex is False, we create a
+    # complex tensor with only real component.
+    if x.layout == torch.sparse_coo:
+        # For sparse, create a random sparse vec with random values in the same
+        # indices. Make sure size is set so that it isn't inferred to be smaller.
+        x_values = x._values()
+        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
+        values = (
+            torch.rand(x_values.numel(), generator=generator)
+            .to(dtype=dtype, device=x.device)
+            .view(x_values.shape)
+        )
+        values /= values.norm()
+        vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device)
+    elif _is_sparse_compressed_tensor(x):
+        if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
+            compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
+        else:
+            compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
+        x_values = x.values()
+        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
+        values = (
+            torch.rand(x_values.numel(), generator=generator)
+            .to(dtype=dtype, device=x.device)
+            .view(x_values.shape)
+        )
+        values /= values.norm()
+        vec = torch.sparse_compressed_tensor(
+            compressed_indices,
+            plain_indices,
+            values,
+            x.size(),
+            layout=x.layout,
+            device=x.device,
+        )
+    else:
+        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
+        vec = torch.rand(x.numel(), generator=generator).to(
+            dtype=dtype, device=x.device
+        )
+        vec /= vec.norm()
+    return vec
+
+
+def _get_inp_tensors(tupled_inputs):
+    inp_idx_tup = [
+        (i, t)
+        for i, t in enumerate(tupled_inputs)
+        if is_tensor_like(t) and t.requires_grad
+    ]
+    return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup]
+
+
+def _adjusted_atol(atol, u, v):
+    # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we
+    # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and
+    # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is
+    # the correctly sized matrix in which each entry is atol.
+    #
+    # We see that atol needs to be scaled by v^T M u (where M is an all-ones
+    # matrix of the appropriate size): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{j} v_j)
+    # TODO: properly handle case when u is tuple instead of only taking first element
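+    #
+    # Numeric sanity check (illustrative): with atol = 1e-5, u = [1., 1., 1.]
+    # and v = [0.5, 0.5], the adjusted tolerance is 1e-5 * 3.0 * 1.0 = 3e-5.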
+    u = u[0] if isinstance(u, tuple) else u
+    sum_u = u.sum()
+    sum_v = 1.0 if v is None else v.sum()
+    return atol * float(sum_u) * float(sum_v)
+
+
+FAST_FAIL_SLOW_OK_MSG = """
+Fast gradcheck failed but element-wise differences are small. This means that the
+test might have passed in slow mode!
+
+If you are adding a new operator, please file an issue and then use one of the
+workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck:
+
+If the test
+- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
+  with `fast_mode=False` as a keyword argument.
+- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
+  to have `gradcheck_fast_mode=False`
+- is a Module test (e.g., in common_nn.py), then modify the corresponding
+  module_test entry to have `gradcheck_fast_mode=False`
+""".strip()
+
+
+def _run_slow_mode_and_get_error(
+    func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad
+):
+    # Compute jacobians in slow mode for better error message
+    slow_numerical = _get_numerical_jacobian(
+        func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad
+    )[input_idx][output_idx]
+    if is_forward_ad:
+
+        def new_fn(inp):
+            new_inputs = list(tupled_inputs)
+            new_inputs[input_idx] = inp
+            return _as_tuple(func(*new_inputs))[output_idx]
+
+        slow_analytical = _get_analytical_jacobian_forward_ad(
+            new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],)
+        )[0][0]
+    else:
+        slow_analytical = _get_analytical_jacobian(
+            tupled_inputs, outputs, input_idx, output_idx
+        )
+
+    # Assume jacobians are non-empty and have the same shape
+    slow_max_diff = (slow_numerical - slow_analytical).abs().max()
+
+    slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol)
+    msg = (
+        "\nThe above quantities relating the numerical and analytical jacobians are computed \n"
+        "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n"
+        "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n"
+        f"Numerical:\n {slow_numerical}\n"
+        f"Analytical:\n{slow_analytical}\n\n"
+        f"The max per-element difference (slow mode) is: {slow_max_diff}.\n"
+    )
+    if slow_allclose:
+        # Slow gradcheck would've passed!
+        msg += FAST_FAIL_SLOW_OK_MSG
+    return msg
+
+
+def _to_flat_dense_if_sparse(tensor):
+    if _is_sparse_any_tensor(tensor):
+        return tensor.to_dense().reshape(-1)
+    else:
+        return tensor
+
+
+def _make_vectors(inp_tensors, outputs, *, use_forward_ad):
+    # Use our own generator to avoid messing with the user's RNG state
+    g_cpu = torch.Generator()
+
+    def _vec_from_tensor_cpu(*args):
+        # Allocate all tensors on CPU by default so that they are on the same
+        # device as the generator, even if the user has set a different default device
+        with torch.device("cpu"):
+            return _vec_from_tensor(*args)
+
+    all_u = []
+    all_u_dense = []
+    for inp in inp_tensors:
+        ur = _vec_from_tensor_cpu(inp, g_cpu, True)
+        ur_dense = _to_flat_dense_if_sparse(ur)
+        if inp.is_complex():
+            ui = _vec_from_tensor_cpu(inp, g_cpu, True)
+            all_u.append((ur, ui))
+            ui_dense = _to_flat_dense_if_sparse(ui)
+            all_u_dense.append((ur_dense, ui_dense))
+        else:
+            all_u.append(ur)
+            all_u_dense.append(ur_dense)
+    all_v = (
+        None
+        if use_forward_ad
+        else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs]
+    )
+    return all_v, all_u, all_u_dense
+
+
+def _check_analytical_numerical_equal(
+    all_analytical,
+    all_numerical,
+    complex_indices,
+    tupled_inputs,
+    outputs,
+    func,
+    all_v,
+    all_u,
+    rtol,
+    atol,
+    eps,
+    test_imag,
+    *,
+    is_forward_ad=False,
+):
+    for i, all_numerical_for_input_i in enumerate(all_numerical):
+        for j, n in enumerate(all_numerical_for_input_i):
+            # Forward AD generates the transpose of what this function expects
+            if is_forward_ad:
+                a = all_analytical[i][j]
+            else:
+                a = all_analytical[j][i]
+            n = n.to(device=a.device)
+            updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None)
+            if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol):
+                jacobians_str = _run_slow_mode_and_get_error(
+                    func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad
+                )
+                raise GradcheckError(
+                    _get_notallclose_msg(
+                        a, n, j, i, complex_indices, test_imag, is_forward_ad
+                    )
+                    + jacobians_str
+                )
+
+
+def _fast_gradcheck(
+    func,
+    func_out,
+    inputs,
+    outputs,
+    eps,
+    rtol,
+    atol,
+    check_grad_dtypes,
+    nondet_tol,
+    *,
+    use_forward_ad=False,
+    complex_indices=None,
+    test_imag=False,
+    masked=False,
+):
+    # See https://github.com/pytorch/pytorch/issues/53876 for details
+    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+    # Backward mode computes v^T * J (VJP)
+    # Since we compute J * u (JVP) numerically via finite differences, we check
+    # that v^T (J u) equals (v^T J) u, i.e., the numerical JVP contracted with v
+    # against the analytical VJP contracted with u
+    # ----
+    # Forward mode computes J * u (JVP)
+    # Since we already compute JVP through finite difference method,
+    # we don't need v for correctness check here as asserted below
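+    #
+    # A tiny self-contained sketch of the identity behind this check
+    # (hypothetical example, not executed here):
+    #
+    #     J = torch.randn(4, 3, dtype=torch.double)
+    #     u = torch.randn(3, dtype=torch.double)
+    #     v = torch.randn(4, dtype=torch.double)
+    #     assert torch.allclose(v.dot(J @ u), (J.T @ v).dot(u))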
+    all_v, all_u, all_u_dense = _make_vectors(
+        inp_tensors, outputs, use_forward_ad=use_forward_ad
+    )
+
+    inputs_numerical, all_u_numerical, all_v_numerical = (
+        (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v))
+    )
+
+    numerical_vJu = _get_numerical_vJu(
+        func,
+        inputs_numerical,
+        inp_tensors_idx,
+        func_out,
+        all_u_numerical,
+        all_v_numerical,
+        eps,
+        is_forward_ad=use_forward_ad,
+    )
+    # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
+    if use_forward_ad:
+        assert all_v is None
+        analytical_vJu = _get_analytical_jacobian_forward_ad(
+            func,
+            inputs,
+            _as_tuple(func_out),
+            all_u=all_u,
+            check_grad_dtypes=check_grad_dtypes,
+        )
+    else:
+        if not outputs:
+            _check_no_differentiable_outputs_fast(
+                func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol
+            )
+
+        analytical_vJu = _get_analytical_vJu_backward_mode(
+            inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense
+        )
+
+    _check_analytical_numerical_equal(
+        analytical_vJu,
+        numerical_vJu,
+        complex_indices,
+        inputs,
+        outputs,
+        func,
+        all_v,
+        all_u,
+        rtol,
+        atol,
+        eps,
+        test_imag,
+        is_forward_ad=use_forward_ad,
+    )
+
+    return True
+
+
+# Note [VarArg of Tensors]
+# ~~~~~~~~~~~~~~~~~~~~~~~~
+# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
+# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
+# the '...' first argument of Callable can be replaced with VarArg(Tensor).
+# For now, we permit any input.
+def gradcheck(
+    func: Callable[..., Union[_TensorOrTensors]],  # See Note [VarArg of Tensors]
+    inputs: _TensorOrTensors,
+    *,
+    eps: float = 1e-6,
+    atol: float = 1e-5,
+    rtol: float = 1e-3,
+    raise_exception: bool = True,
+    nondet_tol: float = 0.0,
+    check_undefined_grad: bool = True,
+    check_grad_dtypes: bool = False,
+    check_batched_grad: bool = False,
+    check_batched_forward_grad: bool = False,
+    check_forward_ad: bool = False,
+    check_backward_ad: bool = True,
+    fast_mode: bool = False,
+    masked: Optional[bool] = None,
+) -> bool:  # noqa: D400,D205
+    r"""Check gradients computed via small finite differences against analytical
+    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
+    and with ``requires_grad=True``.
+
+    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
+
+    For most of the complex functions we consider for optimization purposes, no notion of
+    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
+    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
+    computation is done under the assumption that the overall function has a real-valued
+    output, we treat functions with complex output in a special way. For these functions,
+    gradcheck is applied to two real-valued functions corresponding to taking the real
+    components of the complex outputs for the first, and taking the imaginary components
+    of the complex outputs for the second. For more details, check out
+    :ref:`complex_autograd-doc`.
+
+    .. note::
+        The default values are designed for :attr:`input` of double precision.
+        This check will likely fail if :attr:`input` is of less precision, e.g.,
+        ``FloatTensor``.
+
+    .. note::
+        Gradcheck may fail when evaluated on non-differentiable points
+        because the numerically computed gradients via finite differencing may differ
+        those computed analytically (not necessarily because either is incorrect).
+        For more context, see :ref:`non-differentiable-func-grad`.
+
+    .. warning::
+       If any checked tensor in :attr:`input` has overlapping memory, i.e.,
+       different indices pointing to the same memory address (e.g., from
+       :func:`torch.expand`), this check will likely fail because the numerical
+       gradients computed by point perturbation at such indices will change
+       values at all other indices that share the same memory address.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a Tensor or a tuple of Tensors
+        inputs (tuple of Tensor or Tensor): inputs to the function
+        eps (float, optional): perturbation for finite differences
+        atol (float, optional): absolute tolerance
+        rtol (float, optional): relative tolerance
+        raise_exception (bool, optional): indicating whether to raise an exception if
+            the check fails. The exception gives more information about the
+            exact nature of the failure. This is helpful when debugging gradchecks.
+        nondet_tol (float, optional): tolerance for non-determinism. When running
+            identical inputs through the differentiation, the results must either match
+            exactly (default, 0.0) or be within this tolerance.
+        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
+            are supported and treated as zeros, for ``Tensor`` outputs.
+        check_batched_grad (bool, optional): if ``True``, check if we can compute
+            batched gradients using prototype vmap support. Defaults to False.
+        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
+            batched forward gradients using forward ad and prototype vmap support. Defaults to ``False``.
+        check_forward_ad (bool, optional): if ``True``, check that the gradients computed with forward
+            mode AD match the numerical ones. Defaults to ``False``.
+        check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on
+            backward mode AD to be implemented. Defaults to ``True``.
+        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only
+            implemented for R to R functions. If none of the inputs and outputs are complex,
+            a faster implementation of gradcheck that no longer computes the entire jacobian
+            is run; otherwise, we fall back to the slow implementation.
+        masked (bool, optional): if ``True``, the gradients of unspecified elements of
+            sparse tensors are ignored. Defaults to ``False``.
+    Returns:
+        ``True`` if all differences satisfy allclose condition
+
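+    Example::
+
+        >>> # A minimal usage sketch (a double-precision input is assumed so
+        >>> # that the default tolerances apply):
+        >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
+        >>> torch.autograd.gradcheck(torch.sigmoid, (x,), eps=1e-6, atol=1e-4)
+        True
+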
+    """
+    assert (
+        check_forward_ad or check_backward_ad
+    ), "Expected at least one of check_forward_ad or check_backward_ad to be True"
+    assert not (
+        check_batched_grad and not check_backward_ad
+    ), "Setting check_batched_grad=True requires check_backward_ad to be True"
+    assert not (
+        check_batched_forward_grad and not check_forward_ad
+    ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
+    args = locals().copy()
+    args.pop("raise_exception")
+    if not raise_exception:
+        try:
+            return _gradcheck_helper(**args)
+        except GradcheckError:
+            return False
+    else:
+        return _gradcheck_helper(**args)
+
+
+def _gradcheck_helper(
+    func,
+    inputs,
+    eps,
+    atol,
+    rtol,
+    nondet_tol,
+    check_undefined_grad,
+    check_grad_dtypes,
+    check_batched_grad,
+    check_batched_forward_grad,
+    check_forward_ad,
+    check_backward_ad,
+    fast_mode,
+    masked,
+):
+    tupled_inputs = _as_tuple(inputs)
+    _check_inputs(tupled_inputs)
+
+    func_out = func(*tupled_inputs)
+    outputs = _differentiable_outputs(func_out)
+    _check_outputs(outputs)
+
+    gradcheck_fn = functools.partial(
+        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
+    )
+    _gradcheck_real_imag(
+        gradcheck_fn,
+        func,
+        func_out,
+        tupled_inputs,
+        outputs,
+        eps,
+        rtol,
+        atol,
+        check_grad_dtypes,
+        check_forward_ad=check_forward_ad,
+        check_backward_ad=check_backward_ad,
+        nondet_tol=nondet_tol,
+        check_undefined_grad=check_undefined_grad,
+    )
+
+    if check_batched_forward_grad:
+        _test_batched_grad_forward_ad(func, tupled_inputs)
+
+    # Short circuit because remaining tests rely on backward AD to be implemented
+    if not check_backward_ad:
+        return True
+
+    for i, o in enumerate(outputs):
+        if check_batched_grad:
+            _test_batched_grad(tupled_inputs, o, i)
+
+    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)
+
+    if check_undefined_grad and check_backward_ad:
+        _test_undefined_backward_mode(func, outputs, tupled_inputs)
+    return True
+
+
+def gradgradcheck(
+    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
+    inputs: _TensorOrTensors,
+    grad_outputs: Optional[_TensorOrTensors] = None,
+    *,
+    eps: float = 1e-6,
+    atol: float = 1e-5,
+    rtol: float = 1e-3,
+    gen_non_contig_grad_outputs: bool = False,
+    raise_exception: bool = True,
+    nondet_tol: float = 0.0,
+    check_undefined_grad: bool = True,
+    check_grad_dtypes: bool = False,
+    check_batched_grad: bool = False,
+    check_fwd_over_rev: bool = False,
+    check_rev_over_rev: bool = True,
+    fast_mode: bool = False,
+    masked: bool = False,
+) -> bool:  # noqa: D400,D205
+    r"""Check gradients of gradients computed via small finite differences
+    against analytical gradients wrt tensors in :attr:`inputs` and
+    :attr:`grad_outputs` that are of floating point or complex type and with
+    ``requires_grad=True``.
+
+    This function checks that backpropagating through the gradients computed
+    with the given :attr:`grad_outputs` is correct.
+
+    The check between numerical and analytical gradients uses :func:`~torch.allclose`.
+
+    .. note::
+        The default values are designed for :attr:`input` and
+        :attr:`grad_outputs` of double precision. This check will likely fail if
+        they are of less precision, e.g., ``FloatTensor``.
+
+    .. warning::
+       If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
+       overlapping memory, i.e., different indices pointing to the same memory
+       address (e.g., from :func:`torch.expand`), this check will likely fail
+       because the numerical gradients computed by point perturbation at such
+       indices will change values at all other indices that share the same
+       memory address.
+
+    Args:
+        func (function): a Python function that takes Tensor inputs and returns
+            a Tensor or a tuple of Tensors
+        inputs (tuple of Tensor or Tensor): inputs to the function
+        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
+            respect to the function's outputs.
+        eps (float, optional): perturbation for finite differences
+        atol (float, optional): absolute tolerance
+        rtol (float, optional): relative tolerance
+        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
+            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
+            randomly generated gradient outputs are made to be noncontiguous
+        raise_exception (bool, optional): indicating whether to raise an exception if
+            the check fails. The exception gives more information about the
+            exact nature of the failure. This is helpful when debugging gradchecks.
+        nondet_tol (float, optional): tolerance for non-determinism. When running
+            identical inputs through the differentiation, the results must either match
+            exactly (default, 0.0) or be within this tolerance. Note that a small amount
+            of nondeterminism in the gradient will lead to larger inaccuracies in
+            the second derivative.
+        check_undefined_grad (bool, optional): if True, check if undefined output grads
+            are supported and treated as zeros
+        check_batched_grad (bool, optional): if True, check if we can compute
+            batched gradients using prototype vmap support. Defaults to False.
+        fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that
+            no longer computes the entire jacobian.
+        masked (bool, optional): if True, the gradients of unspecified elements of
+            sparse tensors are ignored (default, False).
+    Returns:
+        True if all differences satisfy allclose condition
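+
+    Example::
+
+        >>> # A minimal usage sketch (a double-precision input is assumed):
+        >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
+        >>> torch.autograd.gradgradcheck(lambda t: (t * t).sum(), (x,))
+        True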
+    """
+    assert (
+        check_fwd_over_rev or check_rev_over_rev
+    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
+    assert not (
+        check_undefined_grad and not check_rev_over_rev
+    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
+    assert not (
+        check_batched_grad and not check_rev_over_rev
+    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
+    # TODO: do we want to test this too?
+    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
+    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
+    tupled_inputs = _as_tuple(inputs)
+
+    if grad_outputs is None:
+        # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs
+
+        outputs = _differentiable_outputs(func(*tupled_inputs))
+        tupled_grad_outputs = tuple(
+            torch.testing.make_tensor(
+                x.shape,
+                dtype=x.dtype
+                if x.is_floating_point() or x.is_complex()
+                else torch.double,
+                device=x.device,
+                low=-1,
+                high=1,
+                requires_grad=True,
+                noncontiguous=gen_non_contig_grad_outputs,
+            )
+            for x in outputs
+        )
+    else:
+        tupled_grad_outputs = _as_tuple(grad_outputs)
+
+    num_outputs = len(tupled_grad_outputs)
+
+    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
+    #     before running forward mode AD
+    diff_input_args_indices = {
+        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
+    }
+    diff_grad_output_indices = {
+        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
+    }
+
+    def new_func(*args):
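+        # ``new_func`` maps (inputs..., grad_outputs...) to the first-order
+        # gradients, so running gradcheck on it checks the second-order gradients.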
+        # Restore the requires_grad information
+        input_args = tuple(
+            x.requires_grad_() if i in diff_input_args_indices else x
+            for i, x in enumerate(args[:-num_outputs])
+        )
+        outputs = _differentiable_outputs(func(*input_args))
+        grad_outputs = tuple(
+            x.requires_grad_() if i in diff_grad_output_indices else x
+            for i, x in enumerate(args[-num_outputs:])
+        )
+        diff_input_args = tuple(
+            x for i, x in enumerate(input_args) if i in diff_input_args_indices
+        )
+        grad_inputs = torch.autograd.grad(
+            outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True
+        )
+        grad_inputs = tuple(g for g in grad_inputs if g is not None)
+        return grad_inputs
+
+    return gradcheck(
+        new_func,
+        tupled_inputs + tupled_grad_outputs,
+        eps=eps,
+        atol=atol,
+        rtol=rtol,
+        raise_exception=raise_exception,
+        nondet_tol=nondet_tol,
+        check_undefined_grad=check_undefined_grad,
+        check_grad_dtypes=check_grad_dtypes,
+        check_batched_grad=check_batched_grad,
+        fast_mode=fast_mode,
+        check_forward_ad=check_fwd_over_rev,
+        check_backward_ad=check_rev_over_rev,
+        masked=masked,
+    )
diff --git a/MLPY/Lib/site-packages/torch/autograd/graph.py b/MLPY/Lib/site-packages/torch/autograd/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6db9f086ee0851a57563bdf265bf7e2f636ef34
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/graph.py
@@ -0,0 +1,749 @@
+import abc
+import collections
+import contextlib
+import functools
+import logging
+import threading
+import weakref
+from collections import defaultdict, namedtuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Deque,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+from torch.autograd.variable import Variable
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils.hooks import RemovableHandle
+
+log = logging.getLogger(__name__)
+
+
+__all__ = [
+    "saved_tensors_hooks",
+    "save_on_cpu",
+    "disable_saved_tensors_hooks",
+    "register_multi_grad_hook",
+    "allow_mutation_on_saved_tensors",
+    "Node",
+    "GradientEdge",
+    "get_gradient_edge",
+    "increment_version",
+]
+
+
+class Node(abc.ABC):
+    @abc.abstractmethod
+    def name(self) -> str:
+        r"""Return the name.
+
+        Example::
+
+            >>> import torch
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> print(b.grad_fn.name())
+            CloneBackward0
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
+        ...
+
+    @abc.abstractmethod
+    def metadata(self) -> dict:
+        r"""Return the metadata."""
+        ...
+
+    @abc.abstractmethod
+    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
+        ...
+
+    @abc.abstractmethod
+    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
+        r"""Register a backward hook.
+
+        The hook will be called every time a gradient with respect to the
+        Node is computed. The hook should have the following signature::
+
+            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
+
+
+        The hook should not modify its argument, but it can optionally return
+        a new gradient which will be used in place of :attr:`grad_inputs`.
+
+        This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the Node.
+
+        .. note::
+            See :ref:`backward-hooks-execution` for more information on how and when this
+            hook is executed, and how its execution is ordered relative to other hooks.
+
+        Example::
+
+            >>> import torch
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([2., 2., 2.])
+            >>> handle.remove() # Removes the hook
+            >>> a.grad = None
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([1., 1., 1.])
+        """
+        ...
+
+    @abc.abstractmethod
+    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
+        r"""Register a backward pre-hook.
+
+        The hook will be called every time a gradient with respect to the
+        Node is computed. The hook should have the following signature::
+
+            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
+
+        The hook should not modify its argument, but it can optionally return
+        a new gradient which will be used in place of :attr:`grad_outputs`.
+
+        This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the Node.
+
+        .. note::
+            See :ref:`backward-hooks-execution` for more information on how and when this
+            hook is executed, and how its execution is ordered relative to other hooks.
+
+        Example::
+
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([2., 2., 2.])
+            >>> handle.remove()
+            >>> a.grad = None
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([1., 1., 1.])
+        """
+        ...
+
+    @classmethod
+    def __subclasshook__(cls, C):
+        if cls is Node:
+            if (
+                C is not None and C is getattr(torch._C._functions, C.__name__, None)
+            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
+                return True
+        return NotImplemented
+
+
+def _get_grad_fn_or_grad_acc(t):
+    if t.requires_grad and t.grad_fn is None:
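+        # Leaf tensors have no grad_fn; take a no-op view so that the view's
+        # grad_fn leads back to the leaf's AccumulateGrad node.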
+        return t.view_as(t).grad_fn.next_functions[0][0]
+    else:
+        return t.grad_fn
+
+
+GradientEdge = namedtuple("GradientEdge", ("node output_nr"))
+GradientEdge.__doc__ = """\
+Object representing a given gradient edge within the autograd graph.
+To get the gradient edge where a given Tensor gradient will be computed,
+you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
+"""
+
+
+def get_gradient_edge(tensor):
+    """Get the gradient edge for computing the gradient of the given Tensor.
+
+    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
+    ``g = autograd.grad(loss, get_gradient_edge(input))``.
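+
+    A minimal sketch of that equivalence (``x`` and ``loss`` are illustrative
+    names, not part of this module)::
+
+        >>> import torch
+        >>> x = torch.randn(3, requires_grad=True)
+        >>> loss = (x * 2).sum()
+        >>> edge = torch.autograd.graph.get_gradient_edge(x)
+        >>> (g1,) = torch.autograd.grad(loss, x, retain_graph=True)
+        >>> (g2,) = torch.autograd.grad(loss, (edge,))  # same values as ``g1``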
+    """
+    if not tensor.requires_grad:
+        raise RuntimeError(
+            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
+        )
+    grad_fn = _get_grad_fn_or_grad_acc(tensor)
+
+    # Note that output_nr defaults to 0, which is the right value
+    # for the AccumulateGrad node.
+    return GradientEdge(grad_fn, tensor.output_nr)
+
+
+def increment_version(tensor):
+    """Update autograd metadata tracking whether the given Tensor was modified in place.
+
+    This is to enable more accurate error checking within the autograd engine.
+    It is already done automatically by PyTorch functions and within custom Function
+    when mark_dirty() is called appropriately, so you only need to call this explicitly
+    if you are doing an in-place operation on the Tensor data in a way that PyTorch doesn't
+    know about; for example, a custom kernel that reads the Tensor's data_ptr and modifies
+    the memory in place based on this pointer.
+
+    Note that incrementing the version counter multiple times for a single in-place operation
+    is not problematic.
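+
+    A minimal sketch (``t`` is an illustrative name; the NumPy write below stands
+    in for an external kernel that autograd cannot see)::
+
+        >>> import torch
+        >>> t = torch.zeros(3, requires_grad=True).clone()
+        >>> t.detach().numpy()[0] = 1.0  # in-place write outside of autograd's view
+        >>> torch.autograd.graph.increment_version(t)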
+    """
+    torch._C._increment_version(tensor)
+
+
+class saved_tensors_hooks:
+    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
+
+    Use this context-manager to define how intermediary results of an operation
+    should be packed before saving, and unpacked on retrieval.
+
+    In that context, the ``pack_hook`` function will be called every time an
+    operation saves a tensor for backward (this includes intermediary results
+    saved using
+    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
+    also those recorded by a PyTorch-defined operation). The output of
+    ``pack_hook`` is then stored in the computation graph instead of the
+    original tensor.
+
+    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
+    namely when executing :func:`torch.Tensor.backward()` or
+    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
+    returned by ``pack_hook`` and should return a tensor which has the same
+    content as the original tensor (passed as input to the corresponding
+    ``pack_hook``).
+
+    The hooks should have the following signatures:
+
+        pack_hook(tensor: Tensor) -> Any
+
+        unpack_hook(Any) -> Tensor
+
+    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
+
+    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
+    of value, size, dtype and device.
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> def pack_hook(x):
+        ...     print("Packing", x)
+        ...     return x
+        >>>
+        >>> def unpack_hook(x):
+        ...     print("Unpacking", x)
+        ...     return x
+        >>>
+        >>> a = torch.ones(5, requires_grad=True)
+        >>> b = torch.ones(5, requires_grad=True) * 2
+        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
+        ...     y = a * b
+        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
+        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
+        >>> y.sum().backward()
+        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
+        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
+
+    .. warning::
+        Performing an in-place operation on the input to either hook may lead
+        to undefined behavior.
+
+    .. warning::
+        Only one pair of hooks is allowed at a time. When recursively nesting this
+        context-manager, only the inner-most pair of hooks will be applied.
+    """
+
+    def __init__(
+        self,
+        pack_hook: Callable[[torch.Tensor], Any],
+        unpack_hook: Callable[[Any], torch.Tensor],
+    ):
+        self.pack_hook = pack_hook
+        self.unpack_hook = unpack_hook
+
+    def __enter__(self):
+        torch._C._autograd._push_saved_tensors_default_hooks(
+            self.pack_hook, self.unpack_hook
+        )
+
+    def __exit__(self, *args: object):
+        torch._C._autograd._pop_saved_tensors_default_hooks()
+
+
+class save_on_cpu(saved_tensors_hooks):
+    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.
+
+    When performing operations within this context manager, intermediary
+    results saved in the graph during the forward pass will be moved to CPU,
+    then copied back to the original device when needed for the backward pass.
+    If the graph was already on CPU, no tensor copy is performed.
+
+    Use this context-manager to trade compute for GPU memory usage (e.g.
+    when your model doesn't fit in GPU memory during training).
+
+    Args:
+        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
+                           during packing and copied to GPU asynchronously during unpacking.
+                           Defaults to ``False``.
+                           Also see :ref:`cuda-memory-pinning`.
+
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
+        >>> a = torch.randn(5, requires_grad=True, device="cuda")
+        >>> b = torch.randn(5, requires_grad=True, device="cuda")
+        >>> c = torch.randn(5, requires_grad=True, device="cuda")
+        >>>
+        >>> def f(a, b, c):
+        ...     prod_1 = a * b           # a and b are saved on GPU
+        ...     with torch.autograd.graph.save_on_cpu():
+        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
+        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
+        ...     return y
+        >>>
+        >>> y = f(a, b, c)
+        >>> del a, b, c  # for illustration only
+        >>> # the content of a, b, and prod_2 are still alive on GPU
+        >>> # the content of prod_1 and c only live on CPU
+        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
+        >>> # all intermediary tensors are released (deleted) after the call to backward
+
+    """
+
+    def __init__(self, pin_memory=False, device_type="cuda"):
+        device_module = getattr(torch, device_type, torch.cuda)
+
+        def pack_to_cpu(tensor):
+            if not pin_memory:
+                return (tensor.device, tensor.cpu())
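+            # Pinned-memory path: stage the tensor in a page-locked CPU buffer so
+            # the copy back to the device in ``unpack_from_cpu`` can be asynchronous.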
+            packed = torch.empty(
+                tensor.size(),
+                dtype=tensor.dtype,
+                layout=tensor.layout,
+                pin_memory=(device_module.is_available() and not tensor.is_sparse),
+            )
+            packed.copy_(tensor)
+            return (tensor.device, packed)
+
+        def unpack_from_cpu(packed):
+            device, tensor = packed
+            return tensor.to(device, non_blocking=pin_memory)
+
+        super().__init__(pack_to_cpu, unpack_from_cpu)
+
+
+@contextlib.contextmanager
+def disable_saved_tensors_hooks(error_message):
+    """Context-manager that disables the saved tensors default hooks feature.
+
+    Useful if you are creating a feature that does not work with saved
+    tensors default hooks.
+
+    Args:
+        error_message (str): When saved tensors default hooks are used while they
+                             are disabled, a RuntimeError with this
+                             error message is raised.
+
+    Example::
+
+        >>> # xdoctest: +SKIP(failing)
+        >>> message = "saved tensors default hooks are disabled"
+        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
+        ...     # Raises RuntimeError: saved tensors default hooks are disabled
+        ...     with torch.autograd.graph.save_on_cpu():
+        ...         pass
+
+    """
+    try:
+        maybe_prev_message = (
+            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
+        )
+        torch._C._autograd._saved_tensors_hooks_disable(error_message)
+        yield
+    finally:
+        # See NOTE: [disabled_error_message invariant]
+        if maybe_prev_message is None:
+            torch._C._autograd._saved_tensors_hooks_enable()
+        else:
+            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
+
+
+def register_multi_grad_hook(
+    tensors: Sequence[torch.Tensor],
+    fn: Union[
+        Callable[[Sequence[Optional[torch.Tensor]]], None],
+        Callable[[torch.Tensor], None],
+    ],
+    *,
+    mode: str = "all",
+):
+    r"""Register a multi-grad backward hook.
+
+    There are two supported modes: ``"all"`` and ``"any"``.
+
+    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
+    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
+    is not part of the graph, or if a tensor is not needed to compute the gradients
+    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
+    this tensor will be ignored and the hook will not wait for its gradient to be
+    computed.
+
+    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
+    called with those gradients. ``None`` will be passed for tensors that did not
+    have their gradients computed.
+
+    Under the ``"any"`` mode, the hook will be called after the first gradient
+    with respect to a tensor in :attr:`tensors` has been computed. The hook
+    will be called with that gradient as its argument.
+
+    The hook should not modify its arguments.
+
+    This function returns a handle with a method ``handle.remove()`` that removes the hook.
+
+    .. note::
+        See :ref:`backward-hooks-execution` for more information on how and when this
+        hook is executed, and how its execution is ordered relative to other hooks.
+
+    Example::
+
+        >>> import torch
+        >>>
+        >>> a = torch.rand(2, 3, requires_grad=True)
+        >>> b = torch.rand(2, 3, requires_grad=True)
+        >>> c = a * b
+        >>> d = a * b
+        >>>
+        >>> def fn(grads):
+        ...     print([g is not None for g in grads])
+        ...
+        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
+        >>>
+        >>> c.sum().backward(retain_graph=True)
+        [True, True, True, False]
+        >>> c.sum().backward(inputs=(a,), retain_graph=True)
+        [True, False, True, False]
+        >>>
+    """
+    supported_modes = ("all", "any")
+    if mode not in supported_modes:
+        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
+
+    class Handle(RemovableHandle):
+        handles: Tuple[RemovableHandle, ...]
+
+        def __init__(self, handles: Tuple[RemovableHandle, ...]):
+            self.handles = handles
+
+        def remove(self):
+            for handle in self.handles:
+                handle.remove()
+
+        def __getstate__(self):
+            return self.handles
+
+        def __setstate__(self, state):
+            self.handles = state
+
+    if mode == "all":
+        count: Dict[int, int] = dict()
+        nb_calls = None
+        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
+
+        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
+        len_tensors = len(tensors)
+
+        def get_inner_hook(idx):
+            def inner_hook(grad: torch.Tensor):
+                nonlocal count, nb_calls, buffer, fn
+                id = torch._C._current_graph_task_id()
+                assert (
+                    id != -1
+                ), "expected this hook to be called inside a backward call"
+                count[id] = count.get(id, 0)
+                buffer[id] = buffer.get(id, [None] * len_tensors)
+
+                if count[id] == 0:
+                    # On the first call, compute the actual nb_calls and buffer
+                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]
+
+                buffer[id][idx] = grad
+                count[id] += 1
+
+                if count[id] == nb_calls:
+                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
+                    fn(buffer[id])
+                    del count[id]
+                    del buffer[id]
+
+            return inner_hook
+
+        handles: Tuple[RemovableHandle, ...] = tuple(
+            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
+        )
+    elif mode == "any":
+        fn = cast(Callable[[torch.Tensor], None], fn)
+        lock = threading.Lock()
+        ran_hook: Dict[int, bool] = defaultdict(bool)
+
+        @functools.wraps(fn)
+        def wrapped_fn(grad: torch.Tensor):
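+            # Run ``fn`` at most once per backward graph task: the first hook to
+            # fire for a given task id runs it, later hooks return early.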
+            nonlocal ran_hook
+            id = torch._C._current_graph_task_id()
+            assert id != -1, "expected this hook to be called inside a backward call"
+            with lock:
+                prev, ran_hook[id] = ran_hook[id], True
+            if prev:
+                return
+            fn(grad)
+
+        handles = tuple(
+            tensor.register_hook(wrapped_fn)
+            for tensor in tensors
+            if tensor.requires_grad
+        )
+
+    return Handle(handles)  # type: ignore[possibly-undefined]
+
+
+# NOTE [Allow mutation on tensors saved for backward]
+#
+# 1. Tensor gets saved for backward
+#    - remember the python object id and the version of the tensor
+#    - remember aliasing information (data_ptr of base + version)
+#    - save the original so we control its lifetime
+# 2. Any time a tensor gets in-placed
+#    - for each tensor aliased to it:
+#      - check using its object id and version to see if it has been saved
+#      - if it has been saved, clone it
+#      - delete the reference to the original
+# 3. during backward
+#    - if the clone exists, the tensor must've been modified in-place
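+#
+# A minimal sketch of the scenario this handles (names are illustrative only):
+#
+# >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
+# ...     a = torch.rand(2, requires_grad=True).clone()
+# ...     out = (a ** 2).sum()  # step 1: ``a`` is saved for backward
+# ...     a.mul_(2)             # step 2: ``a`` is cloned before the in-place op
+# ...     out.backward()        # step 3: backward uses the saved clone of ``a``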
+_allow_mutation_on_saved_tensors_enabled = False
+
+
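+# A "tid" identifies a particular tensor object at a particular version, while a
+# "sid" identifies the underlying storage (data_ptr) at a particular version and
+# is therefore shared by all aliases of that storage.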
+def _get_tid(t) -> Tuple[int, int, int]:
+    return (id(t), t.data_ptr(), t._version)
+
+
+def _get_sid(t) -> Tuple[int, int]:
+    return (t.data_ptr(), t._version)
+
+
+class _Handle:
+    pass
+
+
+class _swap_with_cloned(saved_tensors_hooks):
+    def __init__(self, ctx):
+        def pack_hook(t):
+            tid = _get_tid(t)
+            sid = _get_sid(t)
+            # Tensors saved for backward have an entry in ctx.tid_to_weakhandle
+            handle: Optional[_Handle] = None
+
+            # Save aliasing information
+            ctx.sid_to_tid[sid].add(tid)
+
+            # NB: The same tensor (of the same version) can be saved multiple times
+            if tid not in ctx.tid_to_weakhandle:
+                handle = _Handle()
+                ctx.tid_to_weakhandle[tid] = handle
+                ctx.original[handle] = t
+            else:
+                # Store an additional strong reference to the handle
+                handle = ctx.tid_to_weakhandle[tid]
+            return handle
+
+        def unpack_hook(tup):
+            handle = tup
+            error_msg = (
+                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
+                "in which the graph was originally recorded."
+            )
+            assert _allow_mutation_on_saved_tensors_enabled, error_msg
+            if handle in ctx.cloned:
+                res = ctx.cloned[handle]
+            else:
+                assert handle in ctx.original, error_msg
+                res = ctx.original[handle]
+            return res
+
+        super().__init__(pack_hook, unpack_hook)
+
+
+class _CloneArgBeforeMutateMode(TorchDispatchMode):
+    def __init__(self, ctx):
+        self.ctx = ctx
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        for idx, arg in enumerate(func._schema.arguments):
+            if arg.alias_info is not None and arg.alias_info.is_write:
+                t = kwargs["out"] if arg.is_out else args[idx]
+                tid = _get_tid(t)
+                sid = _get_sid(t)
+                ctx = self.ctx
+                if sid in ctx.sid_to_tid:
+                    for tid in ctx.sid_to_tid[sid]:
+                        if tid not in ctx.tid_to_weakhandle:
+                            # We know that if tid is in sid_to_tid, then it must also be in
+                            # tid_to_weakhandle. However, it is possible for the tensor to be
+                            # saved at one point, but cleared by backward before it is modified
+                            # in-place. Consider the following example:
+                            #
+                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
+                            # >>> out = (a**2).sum()
+                            # >>> out.backward()
+                            # >>> a.sin_()
+                            continue
+                        handle = ctx.tid_to_weakhandle[tid]
+                        if handle in ctx.cloned:
+                            # The same exact tensor has been cloned already
+                            continue
+                        ctx.cloned[handle] = ctx.original[handle].clone()
+                        del ctx.original[handle]
+
+        rs = func(*args, **kwargs)
+        return rs
+
+
+class _AllowMutationOnSavedContext:
+    def __init__(self):
+        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
+            weakref.WeakValueDictionary()
+        )
+        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
+            set
+        )
+
+    def clear(self):
+        self.cloned.clear()
+        self.original.clear()
+        self.tid_to_weakhandle.clear()
+        self.sid_to_tid.clear()
+
+
+@contextlib.contextmanager
+def allow_mutation_on_saved_tensors():
+    """Context manager under which mutating tensors saved for backward is allowed.
+
+    Under this context manager, tensors saved for backward are cloned on mutation,
+    so the original version can still be used during backward. Normally, mutating a tensor
+    saved for backward will result in an error raised when it's used during backward.
+
+    To ensure the correct behavior, both the forward and backward should be run under
+    the same context manager.
+
+    Returns:
+        An _AllowMutationOnSavedContext object storing the state managed by this
+        context manager. This object can be useful for debugging purposes. The state
+        managed by the context manager is automatically cleared upon exiting.
+
+    Example::
+
+        >>> import torch
+        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
+        ...     # forward
+        ...     a = torch.ones(2, 3, requires_grad=True)
+        ...     b = a.clone()
+        ...     out = (b**2).sum()
+        ...     b.sin_()
+        ...     # backward
+        ...     out.sum().backward()
+        ...
+        tensor([[0.8415, 0.8415, 0.8415],
+                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
+    """
+    global _allow_mutation_on_saved_tensors_enabled
+
+    ctx = _AllowMutationOnSavedContext()
+
+    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
+        try:
+            if _allow_mutation_on_saved_tensors_enabled:
+                raise RuntimeError(
+                    "allow_mutation_on_saved_tensors contexts cannot be nested"
+                )
+            _allow_mutation_on_saved_tensors_enabled = True
+            yield ctx
+        finally:
+            ctx.clear()
+            _allow_mutation_on_saved_tensors_enabled = False
+
+
+def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
+    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
+
+    def iter_graph(roots):
+        if not roots:
+            return
+        seen = set()
+        q: Deque = collections.deque()
+        for node in roots:
+            if node is not None:
+                seen.add(node)
+                q.append(node)
+
+        while q:
+            node = q.popleft()
+            for fn, _idx in node.next_functions:
+                if fn in seen or fn is None:
+                    continue
+                seen.add(fn)
+                q.append(fn)
+
+            yield node
+
+    def fmt(t):
+        # Avoid circular import
+        from torch.testing._internal.common_utils import dtype_abbrs
+
+        if t is None:
+            return "None"
+        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"
+
+    def prehook(grad_outputs):
+        node = torch._C._current_autograd_node()
+        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
+        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
+        log.debug(log_str)
+
+    handles = []
+    for node in iter_graph(grad_fns):
+        handles.append(node.register_prehook(prehook))
+
+    def unregister_hooks():
+        for handle in handles:
+            handle.remove()
+
+    return unregister_hooks
+
+
+def _engine_run_backward(t_outputs, *args, **kwargs):
+    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
+    if attach_logging_hooks:
+        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
+    try:
+        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
+            t_outputs, *args, **kwargs
+        )
+    finally:
+        if attach_logging_hooks:
+            unregister_hooks()  # type: ignore[possibly-undefined]
diff --git a/MLPY/Lib/site-packages/torch/autograd/profiler.py b/MLPY/Lib/site-packages/torch/autograd/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5daeb6d250ffa5982e6ddaca21d4a7a689f8e56b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/profiler.py
@@ -0,0 +1,1042 @@
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
+from warnings import warn
+
+import torch
+
+import torch.cuda
+from torch._C import _get_privateuse1_backend_name
+from torch._C._profiler import _ExperimentalConfig
+
+from torch.autograd import (
+    _disable_profiler,
+    _enable_profiler,
+    _kineto_step,
+    _prepare_profiler,
+    _ProfilerResult,
+    _supported_activities,
+    DeviceType,
+    kineto_available,
+    ProfilerActivity,
+    ProfilerConfig,
+    ProfilerState,
+)
+from torch.autograd.profiler_util import (
+    _filter_name,
+    _filter_stack_entry,
+    _rewrite_name,
+    EventList,
+    FunctionEvent,
+    MEMORY_EVENT_NAME,
+    MemRecordsAcc,
+    OUT_OF_MEMORY_EVENT_NAME,
+)
+from torch.futures import Future
+
+__all__ = [
+    "profile",
+    "record_function",
+    "emit_itt",
+    "emit_nvtx",
+    "load_nvprof",
+    "EnforceUnique",
+    "parse_nvprof_trace",
+    "KinetoStepTracker",
+    "EventList",
+    "FunctionEvent",
+    "MemRecordsAcc",
+]
+
+try:
+    # Available in Python >= 3.2
+    from contextlib import ContextDecorator as _ContextDecorator
+except ImportError:
+    import functools
+
+    class _ContextDecorator:  # type: ignore[no-redef]
+        def __enter__(self):
+            raise NotImplementedError
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            raise NotImplementedError
+
+        def __call__(self, func):
+            @functools.wraps(func)
+            def wrapped(*args, **kwargs):
+                with self:
+                    return func(*args, **kwargs)
+
+            return wrapped
+
+
+# global python state - whether profiler is currently enabled
+# useful for fast python checks to reduce latency
+_is_profiler_enabled: bool = False
+
+
+def _set_is_profiler_enabled(enable: bool):
+    global _is_profiler_enabled
+    _is_profiler_enabled = enable
+
+
+def _run_on_profiler_start():
+    _set_is_profiler_enabled(True)
+
+
+def _run_on_profiler_stop():
+    _set_is_profiler_enabled(False)
+
+
+class profile:
+    """Context manager that manages autograd profiler state and holds a summary of results.
+
+    Under the hood it just records events of functions being executed in C++ and
+    exposes those events to Python. You can wrap any code in it, and it will
+    only report the runtime of PyTorch functions.
+    Note: the profiler is thread local and is automatically propagated into async tasks.
+
+    Args:
+        enabled (bool, optional): Setting this to False makes this context manager a no-op.
+
+        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
+            Adds approximately 4us of overhead to each tensor operation.
+
+        record_shapes (bool, optional): If shapes recording is set, information
+            about input dimensions will be collected. This allows one to see which
+            dimensions have been used under the hood and further group by them
+            using prof.key_averages(group_by_input_shape=True). Please note that
+            shape recording might skew your profiling data. It is recommended to
+            use separate runs with and without shape recording to validate the timing.
+            Most likely the skew will be negligible for the bottommost events (in the case
+            of nested function calls), but for higher-level functions the total
+            self CPU time might be artificially increased because of the shape
+            collection.
+
+        with_flops (bool, optional): If with_flops is set, the profiler will estimate
+            the FLOPs (floating point operations) value using the operator's input shape.
+            This allows one to estimate the hardware performance. Currently,
+            this option only works for the matrix multiplication and 2D convolution operators.
+
+        profile_memory (bool, optional): track tensor memory allocation/deallocation.
+
+        with_stack (bool, optional): record source information (file and line number) for the ops.
+
+        with_modules (bool): record module hierarchy (including function names)
+            corresponding to the callstack of the op. For example, if module A's forward calls
+            module B's forward, which contains an aten::add op,
+            then aten::add's module hierarchy is A.B.
+            Note that this support exists, at the moment, only for TorchScript models
+            and not eager mode models.
+
+        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.
+
+        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
+            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
+
+        experimental_config (_ExperimentalConfig): A set of experimental options
+            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
+
+
+    .. warning::
+        Enabling memory profiling or source attribution incurs additional profiler
+        overhead.
+
+    .. warning::
+        This context manager should not be called recursively, i.e. no nested
+        instances are allowed.
+
+    .. warning::
+        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
+        one cannot use the profiler with ``use_cuda = True`` to benchmark
+        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
+        please use ``use_cuda = False`` or ``num_workers = 0``.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
+        >>> x = torch.randn((1, 1), requires_grad=True)
+        >>> with torch.autograd.profiler.profile() as prof:
+        >>>     for _ in range(100):  # any normal python code, really!
+        >>>         y = x ** 2
+        >>>         y.backward()
+        >>> # NOTE: some columns were removed for brevity
+        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
+        -----------------------------------  ---------------  ---------------  ---------------
+        Name                                 Self CPU total   CPU time avg     Number of Calls
+        -----------------------------------  ---------------  ---------------  ---------------
+        mul                                  32.048ms         32.048ms         200
+        pow                                  27.041ms         27.041ms         200
+        PowBackward0                         9.727ms          55.483ms         100
+        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
+        torch::autograd::GraphRoot           691.816us        691.816us        100
+        -----------------------------------  ---------------  ---------------  ---------------
+
+    """
+
+    def __init__(
+        self,
+        enabled=True,
+        *,
+        use_cuda=False,
+        use_device=None,
+        record_shapes=False,
+        with_flops=False,
+        profile_memory=False,
+        with_stack=False,
+        with_modules=False,
+        use_kineto=False,
+        use_cpu=True,
+        use_mtia=False,
+        experimental_config=None,
+    ):
+        self.enabled: bool = enabled
+        if not self.enabled:
+            return
+        self.use_cuda = use_cuda
+        self.use_device: Optional[str] = (
+            use_device if use_device != "privateuseone" else None
+        )
+        self.function_events: Optional[EventList] = None
+        self.entered = False
+        self.record_shapes = record_shapes
+        self.with_flops = with_flops
+        self.record_shapes |= self.with_flops
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.with_modules = with_modules
+        self.use_cpu = use_cpu
+        self.use_mtia = use_mtia
+        if experimental_config is None:
+            experimental_config = _ExperimentalConfig()
+        self.experimental_config = experimental_config
+        self.kineto_results: Optional[_ProfilerResult] = None
+
+        if not self.use_cpu:
+            assert (
+                use_kineto
+            ), "Device-only events supported only with Kineto (use_kineto=True)"
+
+        if self.use_device == "cuda":
+            self.use_device = None
+            self.use_cuda = True
+
+        if self.use_device and self.use_device != _get_privateuse1_backend_name():
+            warn(f"{self.use_device} doesn't support profile.")
+            self.use_device = None
+
+        if self.use_cuda and not torch.cuda.is_available():
+            warn("CUDA is not available, disabling CUDA profiling")
+            self.use_cuda = False
+
+        self.kineto_activities = set()
+        if self.use_cpu:
+            self.kineto_activities.add(ProfilerActivity.CPU)
+        if self.use_mtia:
+            self.kineto_activities.add(ProfilerActivity.MTIA)
+
+        self.profiler_kind = ProfilerState.KINETO
+        if self.use_cuda:
+            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
+                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
+                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
+            else:
+                self.kineto_activities.add(ProfilerActivity.CUDA)
+
+        if self.use_device:
+            if (
+                not use_kineto
+                or ProfilerActivity.PrivateUse1 not in _supported_activities()
+            ):
+                assert (
+                    self.use_cpu
+                ), "Legacy custombackend profiling requires use_cpu=True"
+                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
+            else:
+                self.kineto_activities.add(ProfilerActivity.PrivateUse1)
+                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1
+
+        assert (
+            len(self.kineto_activities) > 0
+        ), "No activities specified for the profiler"
+
+    def config(self):
+        return ProfilerConfig(
+            self.profiler_kind,
+            self.record_shapes,
+            self.profile_memory,
+            self.with_stack,
+            self.with_flops,
+            self.with_modules,
+            self.experimental_config,
+        )
+
+    def __enter__(self):
+        if not self.enabled:
+            return
+        if self.entered:
+            raise RuntimeError("Profiler context manager is not reentrant")
+        self._prepare_trace()
+        self._start_trace()
+        return self
+
+    def _prepare_trace(self):
+        self.entered = True
+        _prepare_profiler(self.config(), self.kineto_activities)
+
+    def _start_trace(self):
+        self.entered = True
+        _run_on_profiler_start()
+        _enable_profiler(self.config(), self.kineto_activities)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.enabled:
+            return
+        if self.use_cuda:
+            torch.cuda.synchronize()
+        self.kineto_results = _disable_profiler()
+        _run_on_profiler_stop()
+        parsed_results = self._parse_kineto_results(self.kineto_results)
+        self.function_events = EventList(
+            parsed_results,
+            use_cuda=self.use_cuda,
+            use_device=self.use_device,
+            profile_memory=self.profile_memory,
+            with_flops=self.with_flops,
+        )
+        self.function_events._build_tree()
+        return False
+
+    def __repr__(self):
+        if self.function_events is None:
+            return ""
+        return repr(self.function_events)
+
+    def __str__(self):
+        if self.function_events is None:
+            return ""
+        return str(self.function_events)
+
+    def _check_finish(self):
+        if self.function_events is None:
+            raise RuntimeError("Profiler didn't finish running")
+
+    def table(
+        self,
+        sort_by=None,
+        row_limit=100,
+        max_src_column_width=75,
+        max_name_column_width=55,
+        max_shapes_column_width=80,
+        header=None,
+        top_level_events_only=False,
+    ):
+        self._check_finish()
+        assert self.function_events is not None
+        return self.function_events.table(
+            sort_by=sort_by,
+            row_limit=row_limit,
+            max_src_column_width=max_src_column_width,
+            max_name_column_width=max_name_column_width,
+            max_shapes_column_width=max_shapes_column_width,
+            header=header,
+            top_level_events_only=top_level_events_only,
+        )
+
+    table.__doc__ = EventList.table.__doc__
+
+    def export_chrome_trace(self, path):
+        self._check_finish()
+        if kineto_available():
+            self.kineto_results.save(path)  # type: ignore[union-attr]
+        else:
+            return self.function_events.export_chrome_trace(path)  # type: ignore[union-attr]
+
+    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
+
+    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        assert self.with_stack, "export_stacks() requires with_stack=True"
+        return self.function_events.export_stacks(path, metric)
+
+    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
+
+    key_averages.__doc__ = EventList.key_averages.__doc__
+
+    def total_average(self):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        return self.function_events.total_average()
+
+    total_average.__doc__ = EventList.total_average.__doc__
+
+    @property
+    def self_cpu_time_total(self):
+        """Returns total time spent on CPU.
+
+        The total time is a sum of all self times across all the events.
+        """
+        self._check_finish()
+        assert self.function_events is not None
+        return self.function_events.self_cpu_time_total
+
+    def _parse_kineto_results(self, result: _ProfilerResult):
+        # result.events() has most of the events - PyTorch op-level and device-level events
+
+        trace_start_us = result.trace_start_us()
+        mem_records = [
+            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
+        ]
+        oom_records = [
+            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
+        ]
+        mem_records_acc = MemRecordsAcc(mem_records)
+
+        def _cpu_memory_usage(mem_record):
+            return (
+                mem_record.nbytes()
+                if mem_record.device_type()
+                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
+                else 0
+            )
+
+        def _cuda_memory_usage(mem_record):
+            return (
+                mem_record.nbytes()
+                if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP]
+                else 0
+            )
+
+        def _privateuse1_memory_usage(mem_record):
+            return (
+                mem_record.nbytes()
+                if mem_record.device_type() in [DeviceType.PrivateUse1]
+                else 0
+            )
+
+        # Create and return FunctionEvent list
+        function_events = []
+        device_corr_map: Dict[int, List[FunctionEvent]] = {}
+        max_evt_id = 0
+        for kineto_event in result.events():
+            if _filter_name(kineto_event.name()):
+                continue
+            rel_start_us = kineto_event.start_us() - trace_start_us
+            rel_end_us = rel_start_us + kineto_event.duration_us()
+            abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
+
+            cpu_memory_usage = 0
+            cuda_memory_usage = 0
+            privateuse1_memory_usage = 0
+            if kineto_event.device_type() == DeviceType.CPU:
+                # find the corresponding memory allocation events
+                for mem_record in mem_records_acc.in_interval(
+                    kineto_event.start_us(), abs_end_us
+                ):
+                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
+                    cuda_memory_usage += _cuda_memory_usage(mem_record[0])
+                    privateuse1_memory_usage += _privateuse1_memory_usage(mem_record[0])
+                    mem_record[1] = True
+
+            is_async = kineto_event.is_async() or (
+                kineto_event.start_thread_id() != kineto_event.end_thread_id()
+            )
+
+            fe = FunctionEvent(
+                id=kineto_event.correlation_id(),
+                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
+                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
+                thread=kineto_event.start_thread_id(),
+                start_us=rel_start_us,
+                end_us=rel_end_us,
+                fwd_thread=kineto_event.fwd_thread_id(),
+                input_shapes=kineto_event.shapes(),
+                concrete_inputs=kineto_event.concrete_inputs(),
+                stack=[
+                    entry
+                    for entry in kineto_event.stack()
+                    if _filter_stack_entry(entry)
+                ],
+                scope=kineto_event.scope(),
+                use_device=self.use_device,
+                cpu_memory_usage=cpu_memory_usage,
+                cuda_memory_usage=cuda_memory_usage,
+                privateuse1_memory_usage=privateuse1_memory_usage,
+                is_async=is_async,
+                sequence_nr=kineto_event.sequence_nr(),
+                device_type=kineto_event.device_type(),
+                device_index=kineto_event.device_index(),
+                flops=kineto_event.flops(),
+            )
+            max_evt_id = max(max_evt_id, fe.id)
+            if fe.device_type == DeviceType.CPU and not fe.is_async:
+                if self.use_device:
+                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
+                    if privateuse1_time > 0:
+                        fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
+                        fe.is_legacy = True
+                else:
+                    # Check if we have CUDA time as a fallback
+                    cuda_time = kineto_event.cuda_elapsed_us()
+                    if cuda_time > 0:
+                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
+                        fe.is_legacy = True
+            function_events.append(fe)
+            corr_id = kineto_event.linked_correlation_id()
+            if corr_id > 0:
+                if corr_id not in device_corr_map:
+                    device_corr_map[corr_id] = []
+                device_corr_map[corr_id].append(fe)
+
+        # associate CUDA kernels and CUDA runtime (CPU) with CPU events
+        for fe in function_events:
+            if (
+                fe.device_type == DeviceType.CPU
+                and not fe.is_async
+                and fe.id in device_corr_map
+            ):
+                for f_evt in device_corr_map[fe.id]:
+                    if f_evt.device_type == DeviceType.CUDA:
+                        fe.append_kernel(
+                            f_evt.name,
+                            f_evt.device_index,
+                            f_evt.time_range.end - f_evt.time_range.start,
+                        )
+                    elif f_evt.device_type == DeviceType.CPU:
+                        # make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated
+                        # with the 'thread' of the corresponding linked PyTorch event to properly track
+                        # parents and children
+                        f_evt.thread = fe.thread
+
+        def createFunctionEventForMemoryEvents(evt):
+            rel_start_us = evt.start_us() - trace_start_us
+            fe = FunctionEvent(
+                id=max_evt_id,
+                name=evt.name(),
+                trace_name=None,  # not outputting in the trace
+                thread=evt.start_thread_id(),
+                start_us=rel_start_us,
+                end_us=rel_start_us,  # no duration
+                fwd_thread=evt.start_thread_id(),
+                input_shapes=[],
+                stack=[],
+                scope=0,  # RecordScope::FUNCTION
+                use_device=self.use_device,
+                cpu_memory_usage=_cpu_memory_usage(evt),
+                cuda_memory_usage=_cuda_memory_usage(evt),
+                privateuse1_memory_usage=_privateuse1_memory_usage(evt),
+                is_async=False,
+                sequence_nr=-1,
+                device_type=DeviceType.CPU,
+                device_index=0,
+            )
+            return fe
+
+        # output top-level memory events
+        for mem_record in mem_records:
+            if not mem_record[1]:
+                max_evt_id += 1
+                fe = createFunctionEventForMemoryEvents(mem_record[0])
+                function_events.append(fe)
+
+        for oom_record in oom_records:
+            max_evt_id += 1
+            fe = createFunctionEventForMemoryEvents(oom_record)
+            function_events.append(fe)
+
+        function_events.sort(
+            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
+        )
+        return function_events
+
+
+class record_function(_ContextDecorator):
+    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
+
+    It is useful for attributing time in a trace to a specific region of code.
+
+    Args:
+        name (str): Label assigned to the block of code.
+        node_id (int): ID of the node, for distributed profiling. Unset in
+            non-distributed cases.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
+        >>> x = torch.randn((1, 1), requires_grad=True)
+        >>> with torch.autograd.profiler.profile() as prof:
+        ...     y = x ** 2
+        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
+        ...         z = y ** 3
+        ...     y.backward()
+        ...
+        >>> # xdoctest: +IGNORE_WANT
+        >>> # NOTE: some columns were removed for brevity
+        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
+        -----------------------------------  ---------------  ---------------  ---------------
+        Name                                 Self CPU total %  CPU time avg     Number of Calls
+        -----------------------------------  ---------------  ---------------  ---------------
+        pow                                  60.77%           47.470us         3
+        mul                                  21.73%           25.465us         2
+        PowBackward0                         12.03%           121.891us        1
+        torch::autograd::AccumulateGrad      2.70%            6.324us          1
+        label-z                              2.13%            12.421us         1
+        torch::autograd::GraphRoot           0.64%            1.503us          1
+        -----------------------------------  ---------------  ---------------  ---------------
+        Self CPU time total: 234.344us
+        CUDA time total: 0.000us
+
+    """
+
+    def __init__(self, name: str, args: Optional[str] = None):
+        self.name: str = name
+        self.args: Optional[str] = args
+        # Whether or not we should run record function's end callbacks when exiting.
+        self.run_callbacks_on_exit: bool = True
+        # TODO: TorchScript ignores standard type annotation here
+        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
+        self.record = torch.jit.annotate(
+            Optional["torch.classes.profiler._RecordFunction"], None
+        )
+
+    def __enter__(self):
+        self.record = torch.ops.profiler._record_function_enter_new(
+            self.name, self.args
+        )
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
+        if not self.run_callbacks_on_exit:
+            return
+
+        # Local variable is needed by TorchScript to refine Optional[T] to T
+        record = self.record
+        assert record is not None
+
+        # TODO: Too slow with __torch_function__ handling enabled
+        # See https://github.com/pytorch/pytorch/issues/76410
+        if not torch.jit.is_scripting():
+            with torch._C.DisableTorchFunctionSubclass():
+                torch.ops.profiler._record_function_exit._RecordFunction(record)
+        else:
+            torch.ops.profiler._record_function_exit(record)
+
+    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
+        """Use for profiling async calls that return a future.
+
+        Calling this function will extend recording beyond this scope, until the future is
+        satisfied. It is useful for profiling the end-to-end time of asynchronous calls.
+        This function should only be called once to attach the callback onto the future, and
+        will throw if called multiple times.
+
+        Args:
+            fut (torch._C.Future): future for which to schedule the
+                callback.
+
+        Returns:
+            A future that completes with the value of the passed-in future when
+            the profiling callbacks have run.
+
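+        Example (a minimal sketch; the future here is completed by hand purely
+        for illustration)::
+
+            >>> # xdoctest: +SKIP
+            >>> fut = torch.futures.Future()
+            >>> with torch.autograd.profiler.record_function("async-label") as rf:
+            ...     profiled_fut = rf._call_end_callbacks_on_future(fut)
+            >>> fut.set_result(torch.ones(1))  # recording for "async-label" ends here
+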
+        """
+        # Throw if we have already attached a callback onto the future.
+        if not self.run_callbacks_on_exit:
+            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")
+
+        # We are scheduling to run this RecordFunction's end callbacks when the
+        # passed in future completes, so don't run end callbacks on exit.
+        self.run_callbacks_on_exit = False
+
+        # Local variable is needed by TorchScript to refine Optional[T] to T
+        record = self.record
+        assert record is not None
+
+        # TODO: Too slow with __torch_function__ handling enabled
+        # See https://github.com/pytorch/pytorch/issues/76410
+        if not torch.jit.is_scripting():
+            with torch._C.DisableTorchFunctionSubclass():
+                profiled_future = (
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
+                        record, fut
+                    )
+                )
+        else:
+            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
+                record, fut
+            )
+        return profiled_future
+
+
+class emit_itt:
+    """Context manager that makes every autograd operation emit an ITT range.
+
+    It is useful when running the program under Intel(R) VTune Profiler::
+
+        vtune <--vtune-flags> <regular command here>
+
+    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
+    control the collection of trace data during its execution across different Intel tools.
+    This context manager is used to annotate the Intel(R) VTune Profiler trace. With the help of this
+    context manager, you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.
+
+    .. warning::
+        This context manager should not be called recursively, i.e. at most one
+        instance should be enabled at any given time.
+
+    Args:
+        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
+            Default: ``True``.
+        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
+            each autograd op will append information about the sizes of Tensor arguments received
+            by that op, in the following format:
+            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
+            Non-tensor arguments will be represented by ``[]``.
+            Arguments will be listed in the order they are received by the backend op.
+            Please note that this order may not match the order in which those arguments were passed
+            on the Python side.  Also note that shape recording may increase the overhead of itt range creation.
+            Default: ``False``
+
+    Example:
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
+        >>> with torch.autograd.profiler.emit_itt():
+        ...     model(x)
+
+    """
+
+    def __init__(self, enabled=True, record_shapes=False):
+        self.enabled = enabled
+        self.entered = False
+        self.record_shapes = record_shapes
+
+    def __enter__(self):
+        if not self.enabled:
+            return
+        if self.entered:
+            raise RuntimeError("ITT annotation context manager is not reentrant")
+        self.entered = True
+        _run_on_profiler_start()
+        _enable_profiler(
+            ProfilerConfig(
+                ProfilerState.ITT,
+                self.record_shapes,
+                False,
+                False,
+                False,
+                False,
+                _ExperimentalConfig(),
+            ),
+            set(),
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.enabled:
+            return
+        _disable_profiler()
+        _run_on_profiler_stop()
+        return False
+
+
+class emit_nvtx:
+    """Context manager that makes every autograd operation emit an NVTX range.
+
+    It is useful when running the program under nvprof::
+
+        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
+
+    Unfortunately, there's no way to force nvprof to flush the data it collected
+    to disk, so for CUDA profiling one has to use this context manager to annotate
+    nvprof traces and wait for the process to exit before inspecting them.
+    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
+    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
+    e.g. in Python REPL.
+
+    .. warning::
+        This context manager should not be called recursively, i.e. at most one
+        instance should be enabled at any given time.
+
+    Args:
+        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
+            Default: ``True``.
+        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
+            each autograd op will append information about the sizes of Tensor arguments received
+            by that op, in the following format:
+            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
+            Non-tensor arguments will be represented by ``[]``.
+            Arguments will be listed in the order they are received by the backend op.
+            Please note that this order may not match the order in which those arguments were passed
+            on the Python side.  Also note that shape recording may increase the overhead of nvtx range creation.
+            Default: ``False``
+
+    Example:
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
+        >>> with torch.cuda.profiler.profile():
+        ...     model(x)  # Warmup CUDA memory allocator and profiler
+        ...     with torch.autograd.profiler.emit_nvtx():
+        ...         model(x)
+
+    **Forward-backward correlation**
+
+    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
+    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
+    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
+    generates.
+
+    During the forward pass, each function range is decorated with ``seq=<N>``.  ``seq`` is a running
+    counter, incremented each time a new backward Function object is created and stashed for backward.
+    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
+    if a backward Function object is created by this forward function,
+    the backward object will receive sequence number N.
+    During the backward pass, the top-level range wrapping each C++ backward Function's
+    ``apply()`` call is decorated with ``stashed seq=<M>``.  ``M`` is the sequence number that
+    the backward object was created with.  By comparing ``stashed seq`` numbers in backward with ``seq``
+    numbers in forward, you can track down which forward op created each backward Function.
+
+    Any functions executed during the backward pass are also decorated with ``seq=<N>``.  During
+    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
+    ``N`` may simply be 0 for all such functions.  Only the top-level ranges associated with
+    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
+    objects with the earlier forward pass.
+
+    **Double-backward**
+
+    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
+    if you are setting up for a double-backward), each function's execution during backward
+    is given a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
+    to be executed later during double-backward, just as the original functions in the forward pass did.
+    The relationship between backward and double-backward is conceptually the same as the relationship
+    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
+    the Function objects they create still stash those sequence numbers, and during the eventual
+    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
+    numbers, which can be compared to `seq` numbers from the backward pass.
+
+    .. warning::
+        The sequence number is thread-local, and some forward functions don't create an associated
+        backward Function object (instead delegating that to sub-functions further down the call chain).
+        For these reasons, the correspondence of stashed sequence numbers in
+        backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is
+        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
+        disambiguate which forward function created which
+        backward Function object.  You may need to make a judgment based on analytic knowledge of what
+        the expected correspondence should be.
+    """
+
+    def __init__(self, enabled=True, record_shapes=False):
+        self.enabled = enabled
+        self.entered = False
+        self.record_shapes = record_shapes
+
+    def __enter__(self):
+        if not self.enabled:
+            return
+        if self.entered:
+            raise RuntimeError("NVTX annotation context manager is not reentrant")
+        self.entered = True
+        torch.cuda.synchronize()
+        _run_on_profiler_start()
+        _enable_profiler(
+            ProfilerConfig(
+                ProfilerState.NVTX,
+                self.record_shapes,
+                False,
+                False,
+                False,
+                False,
+                _ExperimentalConfig(),
+            ),
+            set(),
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.enabled:
+            return
+        torch.cuda.synchronize()
+        _disable_profiler()
+        _run_on_profiler_stop()
+        return False
+
+
+def load_nvprof(path):
+    """Open an nvprof trace file and parses autograd annotations.
+
+    Args:
+        path (str): path to nvprof trace
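+
+    Example (an illustrative sketch; assumes a trace previously captured with
+    ``nvprof -o trace_name.prof`` under :class:`emit_nvtx`):
+        >>> # xdoctest: +SKIP("requires an existing nvprof trace file")
+        >>> events = torch.autograd.profiler.load_nvprof("trace_name.prof")
+        >>> print(events.table(sort_by="cpu_time_total"))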
+    """
+    return EventList(parse_nvprof_trace(path))
+
+
+class EnforceUnique:
+    """Raises an error if a key is seen more than once."""
+
+    def __init__(self):
+        self.seen = set()
+
+    def see(self, *key):
+        r"""
+        Observe a key and raise an error if it is seen multiple times.
+        """
+        if key in self.seen:
+            raise RuntimeError("duplicate key: " + str(key))
+        self.seen.add(key)
+
+
+def parse_nvprof_trace(path):
+    import sqlite3
+
+    conn = sqlite3.connect(path)
+    conn.row_factory = sqlite3.Row
+
+    # Parse strings table
+    strings = {}
+    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
+        strings[r["id"]] = torch._C._demangle(r["value"])
+
+    # First, find all functions and create FunctionEvents for them
+    marker_query = """
+    SELECT
+        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
+    FROM
+        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
+        ON start.id = end.id
+    WHERE
+        start.name != 0 AND end.name = 0
+    """
+    functions = []
+    functions_map = {}
+    unique = EnforceUnique()
+    for row in conn.execute(marker_query):
+        unique.see(row["marker_id"])
+        evt = FunctionEvent(
+            id=row["marker_id"],
+            node_id=0,  # node_id is not available when parsing an nvprof trace; use 0
+            # so that creating the FunctionEvent() object does not fail
+            name=strings[row["name"]],
+            start_us=row["start_time"],
+            end_us=row["end_time"],
+            thread=0,
+        )  # TODO: find in sqlite database
+        functions.append(evt)
+        functions_map[evt.id] = evt
+
+    # Now, correlate all kernels with FunctionEvents
+    kernel_query = """
+    SELECT
+        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
+        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
+        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
+    FROM
+        CUPTI_ACTIVITY_KIND_MARKER AS start
+        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
+            ON start.id = end.id
+        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
+            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
+        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
+            ON kernel.correlationId = runtime.correlationId
+    """
+    unique = EnforceUnique()
+    for row in conn.execute(kernel_query):
+        unique.see(row["marker_id"], row["runtime_id"])
+        # 211 is cudaKernelLaunch for cuda >= 9.2
+        assert row["cbid"] == 211
+        evt = functions_map[row["marker_id"]]
+        evt.append_kernel(
+            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
+        )
+
+    functions.sort(key=lambda evt: evt.time_range.start)
+    return functions
+
+
+class KinetoStepTracker:
+    """Provides an abstraction for incrementing the step count globally.
+
+    Previously, we only had one place to mark that a step() has occurred
+    in the program, via the pytorch profiler step(). We now also add step hooks
+    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446).
+
+    - This could mean programs that already call profiler.step() every
+      iteration can end up double-incrementing the step count.
+    - If a model uses multiple optimizers, the step can also be counted
+      twice or more.
+
+    We fix this by adding a layer of abstraction before calling step()
+    to the kineto library. The idea is to maintain steps per requester in a dict:
+
+    .. code-block::
+
+        {
+           "ProfilerStep": 100,  # triggered by profiler step() call
+           "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
+           "Optimizer2Step": 100,
+        }
+
+    To figure out the global step count, just take the max of the dict values (100).
+
+    If one of the counts increments, the max will go up.
+
+    .. code-block::
+
+        {
+           "ProfilerStep": 100,
+           "Optimizer1Step": 101,   # Optimizer1 got incremented first say
+           "Optimizer2Step": 100,
+        }
+
+    Then the global step count is 101.
+    We only call the kineto step() function when the global count increments.
+
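+    For example, a hypothetical requester could drive the tracker like this
+    (a minimal illustrative sketch):
+
+    .. code-block::
+
+        from torch.autograd.profiler import KinetoStepTracker
+
+        KinetoStepTracker.init_step_count("MyRequester")
+        KinetoStepTracker.increment_step("MyRequester")  # bumps the global step if this is now the max
+        print(KinetoStepTracker.current_step())
+        KinetoStepTracker.erase_step_count("MyRequester")
+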
+    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
+    for now. Doing so could result in incorrect increments of the step count.
+    """
+
+    _current_step = 0
+    _step_dict: Dict[str, int] = defaultdict(int)
+
+    @classmethod
+    def init_step_count(cls, requester: str):
+        r"""
+        Initialize for a given requester.
+        """
+        cls._step_dict[requester] = cls._current_step
+
+    @classmethod
+    def erase_step_count(cls, requester: str) -> bool:
+        r"""
+        Remove a given requester.
+        """
+        return cls._step_dict.pop(requester, None) is not None
+
+    @classmethod
+    def increment_step(cls, requester: str) -> int:
+        """Increments the step count for the requester.
+
+        Additionally if the max over all step counts has incremented then
+        trigger the _kineto_step() returns global step count
+        """
+        if requester not in cls._step_dict:
+            cls.init_step_count(requester)
+        cls._step_dict[requester] += 1
+
+        new_step = max(cls._step_dict.values())
+        if new_step > cls._current_step:
+            delta = new_step - cls._current_step
+            if delta > 1:
+                warn(
+                    "Profiler step count has increased more than 1 - "
+                    f"current_step = {cls._current_step} step dict =  {cls._step_dict}"
+                )
+            for _ in range(0, delta):
+                _kineto_step()
+            cls._current_step = new_step
+        return cls._current_step
+
+    @classmethod
+    def current_step(cls) -> int:
+        r"""
+        Get the latest step for any requester
+        """
+        return cls._current_step
diff --git a/MLPY/Lib/site-packages/torch/autograd/profiler_legacy.py b/MLPY/Lib/site-packages/torch/autograd/profiler_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbccb7e69d26ce2e95bad4e64c33eea51594ce7b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/profiler_legacy.py
@@ -0,0 +1,303 @@
+import itertools
+from warnings import warn
+
+import torch
+import torch.cuda
+
+from torch.autograd import (
+    _disable_profiler_legacy,
+    _enable_profiler_legacy,
+    DeviceType,
+    ProfilerConfig,
+    ProfilerState,
+)
+from torch.autograd.profiler_util import (
+    _filter_name,
+    _filter_stack_entry,
+    _rewrite_name,
+    EventList,
+    FunctionEvent,
+    MEMORY_EVENT_NAME,
+)
+
+__all__ = ["profile"]
+
+
+class profile:
+    """DEPRECATED: use torch.profiler instead."""
+
+    def __init__(
+        self,
+        enabled=True,
+        *,
+        use_cuda=False,
+        record_shapes=False,
+        with_flops=False,
+        profile_memory=False,
+        with_stack=False,
+        with_modules=False,
+    ):
+        self.enabled: bool = enabled
+        if not self.enabled:
+            return
+        self.use_cuda = use_cuda
+        self.function_events = None
+        self.entered = False
+        self.record_shapes = record_shapes
+        self.with_flops = with_flops
+        self.record_shapes |= self.with_flops
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.with_modules = with_modules
+
+        if self.use_cuda and not torch.cuda.is_available():
+            warn("CUDA is not available, disabling CUDA profiling")
+            self.use_cuda = False
+
+        if self.use_cuda:
+            self.profiler_kind = ProfilerState.CUDA
+        else:
+            self.profiler_kind = ProfilerState.CPU
+
+    def config(self):
+        return ProfilerConfig(
+            self.profiler_kind,
+            self.record_shapes,
+            self.profile_memory,
+            self.with_stack,
+            self.with_flops,
+            self.with_modules,
+            # avoid exposing _ExperimentalConfig in the legacy public API
+            torch._C._profiler._ExperimentalConfig(),
+        )
+
+    def __enter__(self):
+        if not self.enabled:
+            return
+        if self.entered:
+            raise RuntimeError("Profiler context manager is not reentrant")
+        self.entered = True
+        self._start_trace()
+        return self
+
+    def _start_trace(self):
+        _enable_profiler_legacy(self.config())
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.enabled:
+            return
+        if self.use_cuda:
+            torch.cuda.synchronize()
+
+        records = _disable_profiler_legacy()
+        parsed_results = _parse_legacy_records(records)
+        self.function_events = EventList(
+            parsed_results,
+            use_cuda=self.use_cuda,
+            profile_memory=self.profile_memory,
+            with_flops=self.with_flops,
+        )
+        self.function_events._build_tree()
+        return False
+
+    def __repr__(self):
+        if self.function_events is None:
+            return ""
+        return repr(self.function_events)
+
+    def __str__(self):
+        if self.function_events is None:
+            return ""
+        return str(self.function_events)
+
+    def _check_finish(self):
+        if self.function_events is None:
+            raise RuntimeError("Profiler didn't finish running")
+
+    def table(
+        self,
+        sort_by=None,
+        row_limit=100,
+        max_src_column_width=75,
+        max_name_column_width=55,
+        max_shapes_column_width=80,
+        header=None,
+        top_level_events_only=False,
+    ):
+        self._check_finish()
+        assert self.function_events is not None
+        return self.function_events.table(
+            sort_by=sort_by,
+            row_limit=row_limit,
+            max_src_column_width=max_src_column_width,
+            max_name_column_width=max_name_column_width,
+            max_shapes_column_width=max_shapes_column_width,
+            header=header,
+            top_level_events_only=top_level_events_only,
+        )
+
+    table.__doc__ = EventList.table.__doc__
+
+    def export_chrome_trace(self, path):
+        self._check_finish()
+        assert self.function_events is not None
+        return self.function_events.export_chrome_trace(path)
+
+    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
+
+    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        assert self.with_stack, "export_stacks() requires with_stack=True"
+        return self.function_events.export_stacks(path, metric)
+
+    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
+
+    key_averages.__doc__ = EventList.key_averages.__doc__
+
+    def total_average(self):
+        self._check_finish()
+        assert self.function_events is not None, "Expected profiling results"
+        return self.function_events.total_average()
+
+    total_average.__doc__ = EventList.total_average.__doc__
+
+    @property
+    def self_cpu_time_total(self):
+        """Return CPU time as the sum of self times across all events."""
+        self._check_finish()
+        assert self.function_events is not None
+        return self.function_events.self_cpu_time_total
+
+
+def _parse_legacy_records(thread_records):
+    def _get_record_key(record):
+        """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
+        return (record.handle(), record.node_id())
+
+    next_id = 0
+    start_record = None
+    functions = []
+    record_stack = []
+
+    # '__start_profile' is not guaranteed to be first, so we must find it here
+    for record in itertools.chain.from_iterable(thread_records):
+        name = record.name()
+        if start_record is None and name == "__start_profile":
+            start_record = record
+
+    assert start_record is not None and not start_record.is_remote()
+
+    for thread_record_list in thread_records:
+        # accumulated memory allocations per handle
+        cpu_memory_allocs = {}
+        cuda_memory_allocs = {}
+        # ranges per handle
+        range_starts = {}
+
+        filtered_handles = set()
+        prev_record = None
+        for record in thread_record_list:
+            record_key = _get_record_key(record)
+            if _filter_name(record.name()) or record_key in filtered_handles:
+                filtered_handles.add(record_key)
+                continue
+
+            if record.kind() == "push":
+                # workaround to reduce double logging from operator
+                # wrappers and redispatch
+                if prev_record is not None:
+                    duplicate = (
+                        prev_record.name() == record.name()
+                        and prev_record.kind() == record.kind()
+                        and prev_record.node_id() == record.node_id()
+                    )
+                    if duplicate:
+                        filtered_handles.add(record_key)
+                        continue
+
+                range_starts[record_key] = record
+                cpu_memory_allocs[record_key] = 0
+                cuda_memory_allocs[record_key] = 0
+            elif record.kind() == "pop":
+                assert (
+                    record_key in range_starts
+                ), f"""Expected record with key {record_key} to exist in range_starts.
+                    This means that the pop event did not have a corresponding push."""
+
+                start = range_starts[record_key]
+
+                cpu_memory_usage = cpu_memory_allocs[record_key]
+                cuda_memory_usage = cuda_memory_allocs[record_key]
+                is_async = start.is_async() or (start.thread_id() != record.thread_id())
+                is_remote_event = record.is_remote()
+                start_flops = start.flops()
+
+                fe = FunctionEvent(
+                    id=record.handle(),
+                    node_id=record.node_id(),
+                    name=_rewrite_name(name=start.name(), with_wildcard=True),
+                    trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
+                    thread=start.thread_id(),
+                    start_us=start_record.cpu_elapsed_us(start),
+                    end_us=start_record.cpu_elapsed_us(record),
+                    fwd_thread=start.fwd_thread_id(),
+                    input_shapes=start.shapes(),
+                    stack=[
+                        entry for entry in start.stack() if _filter_stack_entry(entry)
+                    ],
+                    scope=start.scope(),
+                    cpu_memory_usage=cpu_memory_usage,
+                    cuda_memory_usage=cuda_memory_usage,
+                    is_async=is_async,
+                    is_remote=is_remote_event,
+                    sequence_nr=start.sequence_nr(),
+                    device_type=DeviceType.CPU,
+                    is_legacy=True,
+                    flops=start_flops,
+                )
+                # note: async events have only cpu total time
+                if not is_async and start.has_cuda():
+                    duration = start.cuda_elapsed_us(record)
+                    if duration > 0:
+                        fe.append_kernel(start.name(), start.device(), duration)
+                functions.append(fe)
+                del range_starts[record_key]
+                del cpu_memory_allocs[record_key]
+                del cuda_memory_allocs[record_key]
+            elif record.kind() == "memory_alloc":
+                num_open_handles_cpu = len(cpu_memory_allocs)
+                num_open_handles_cuda = len(cuda_memory_allocs)
+                assert num_open_handles_cpu == num_open_handles_cuda
+                for handle in cpu_memory_allocs.keys():
+                    cpu_memory_allocs[handle] += record.cpu_memory_usage()
+                for handle in cuda_memory_allocs.keys():
+                    cuda_memory_allocs[handle] += record.cuda_memory_usage()
+                if num_open_handles_cpu == 0:
+                    # output event as a top-level memory event
+                    fe = FunctionEvent(
+                        id=0,
+                        name=MEMORY_EVENT_NAME,
+                        trace_name=None,
+                        thread=0,
+                        start_us=0,
+                        end_us=0,
+                        stack=[],
+                        cpu_memory_usage=record.cpu_memory_usage(),
+                        cuda_memory_usage=record.cuda_memory_usage(),
+                        is_legacy=True,
+                    )
+                    functions.append(fe)
+            prev_record = record
+
+    # Sort functions by start time then by end time ascending.
+    # This ensures that--in the case of nested events which
+    # have the same start time (which may happen due to the
+    # granularity of the given clock tick)--we always show
+    # the outermost nested call first. This adds stability
+    # in how FunctionEvents appear.
+    functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
+    return functions
diff --git a/MLPY/Lib/site-packages/torch/autograd/profiler_util.py b/MLPY/Lib/site-packages/torch/autograd/profiler_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..43dbab2b490092c96b16a3cdce49f8f0745b5847
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/profiler_util.py
@@ -0,0 +1,1178 @@
+import bisect
+import itertools
+import math
+
+from collections import defaultdict, namedtuple
+from operator import attrgetter
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch.autograd import DeviceType
+
+__all__ = [
+    "EventList",
+    "FormattedTimesMixin",
+    "Interval",
+    "Kernel",
+    "FunctionEvent",
+    "FunctionEventAvg",
+    "StringTable",
+    "MemRecordsAcc",
+]
+
+
+class EventList(list):
+    """A list of Events (for pretty printing)."""
+
+    def __init__(self, *args, **kwargs):
+        use_cuda = kwargs.pop("use_cuda", True)
+        use_device = kwargs.pop("use_device", None)
+        profile_memory = kwargs.pop("profile_memory", False)
+        with_flops = kwargs.pop("with_flops", False)
+        super().__init__(*args, **kwargs)
+        self._use_cuda = use_cuda
+        self._use_device = use_device
+        self._profile_memory = profile_memory
+        self._tree_built = False
+        self._with_flops = with_flops
+
+    def _build_tree(self):
+        self._populate_cpu_children()
+        self._remove_dup_nodes()
+        self._set_backward_stacktraces()
+        self._tree_built = True
+
+    def __str__(self):
+        return self.table()
+
+    def _remove_dup_nodes(self):
+        while True:
+            to_delete = set()
+            for idx in range(len(self)):
+                if (
+                    self[idx].cpu_parent is not None
+                    and self[idx].cpu_parent.name == self[idx].name
+                    and len(self[idx].cpu_parent.cpu_children) == 1
+                ):
+                    self[idx].cpu_parent.cpu_children = self[idx].cpu_children
+                    self[idx].cpu_parent.kernels = self[idx].kernels  # lift kernels up
+                    for ch in self[idx].cpu_children:
+                        ch.cpu_parent = self[idx].cpu_parent
+                    to_delete.add(idx)
+            if len(to_delete) == 0:
+                break
+            new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
+            self.clear()
+            self.extend(new_evts)
+
+    def _populate_cpu_children(self):
+        """Populate child events into each underlying FunctionEvent object.
+
+        One event is a child of another if [s1, e1) is inside [s2, e2), where
+        s1 and e1 are the start and end of the child event's interval, and
+        s2 and e2 are the start and end of the parent event's interval.
+
+        Example: in the event list [[0, 10], [1, 3], [3, 4]], the interval [0, 10]
+        would be made the parent of the two other intervals.
+
+        If for any reason two intervals intersect only partially, this function
+        will not record a parent-child relationship between them.
+        """
+        # Some events can be async (i.e. start and end on different threads).
+        # Since it's generally undefined how to attribute children ranges to
+        # async ranges, we do not use them when calculating nested ranges and stats
+        sync_events = [
+            evt
+            for evt in self
+            if not evt.is_async and evt.device_type == DeviceType.CPU
+        ]
+        events = sorted(
+            sync_events,
+            key=attrgetter("thread"),
+        )
+        # Group by both thread and node_id, so that events that happen to have
+        # the same thread_id but are from different nodes aren't incorrectly
+        # grouped together.
+        threads = itertools.groupby(
+            events, key=lambda event: (event.thread, event.node_id)
+        )
+
+        # For each thread we keep a stack of current nested parents.
+        # We maintain the invariant that each interval is a subset of all other
+        # intervals lower in the stack.
+        #
+        # First we sort the intervals by their start time. Then we iterate over them.
+        # Every time we see a new interval we remove several parents from
+        # the top until we restore the invariant. Then a parent-child relationship
+        # is recorded if the stack is not empty.
+        # Finally, we push the new interval onto the stack.
+        #
+        # The algorithm has O(N * log(N)) complexity, where N is the number of
+        # intervals.
+        for thread_id, thread_events in threads:
+            thread_events_ = sorted(
+                thread_events,
+                key=lambda event: [event.time_range.start, -event.time_range.end],
+            )
+            current_events: List[FunctionEvent] = []
+            cur_end = 0
+            for event in thread_events_:
+                while len(current_events) > 0:
+                    parent = current_events[-1]
+                    if (
+                        event.time_range.start >= parent.time_range.end
+                        or event.time_range.end > parent.time_range.end
+                    ):
+                        # this can't be a parent
+                        current_events.pop()
+                    else:
+                        parent.append_cpu_child(event)
+                        assert (
+                            event.cpu_parent is None
+                        ), f"There is already a CPU parent event for {event.key}"
+                        event.set_cpu_parent(parent)
+                        break
+
+                current_events.append(event)
+
+    def _set_backward_stacktraces(self):
+        def bw_parent(evt):
+            if evt is None:
+                return None
+            elif evt.scope == 1:  # BACKWARD_FUNCTION
+                return evt
+            else:
+                return bw_parent(evt.cpu_parent)
+
+        fwd_stacks = {}
+        for evt in self:
+            if bw_parent(evt) is None and evt.stack is not None:
+                t = (evt.sequence_nr, evt.thread)
+                if t not in fwd_stacks:
+                    fwd_stacks[t] = evt.stack
+
+        for evt in self:
+            p = bw_parent(evt)
+            if p is not None:
+                assert p.fwd_thread is not None
+                t = (p.sequence_nr, p.fwd_thread)
+                if t in fwd_stacks:
+                    evt.stack = fwd_stacks[t]
+                else:
+                    evt.stack = []
+
+    @property
+    def self_cpu_time_total(self):
+        return sum([event.self_cpu_time_total for event in self])
+
+    def table(
+        self,
+        sort_by=None,
+        row_limit=100,
+        max_src_column_width=75,
+        max_name_column_width=55,
+        max_shapes_column_width=80,
+        header=None,
+        top_level_events_only=False,
+    ):
+        """Print an EventList as a nicely formatted table.
+
+        Args:
+            sort_by (str, optional): Attribute used to sort entries. By default
+                they are printed in the same order as they were registered.
+                Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
+                ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
+                ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
+            top_level_events_only(bool, optional): Boolean flag to determine the
+                selection of events to display. If true, the profiler will only
+                display top-level events, such as the top-level invocation of python
+                `lstm`, python `add` or other functions; nested events such as low-level
+                cpu/cuda ops are omitted to keep the profiler output readable.
+
+        Returns:
+            A string containing the table.
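+
+        Example (an illustrative sketch; ``events`` is assumed to be a populated
+        EventList from a finished profiling run):
+            >>> # xdoctest: +SKIP("illustrative sketch")
+            >>> print(events.table(sort_by="cpu_time_total", row_limit=10))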
+        """
+        return _build_table(
+            self,
+            sort_by=sort_by,
+            row_limit=row_limit,
+            max_src_column_width=max_src_column_width,
+            max_name_column_width=max_name_column_width,
+            max_shapes_column_width=max_shapes_column_width,
+            header=header,
+            profile_memory=self._profile_memory,
+            with_flops=self._with_flops,
+            top_level_events_only=top_level_events_only,
+        )
+
+    def export_chrome_trace(self, path):
+        """Export an EventList as a Chrome tracing tools file.
+
+        The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.
+
+        Args:
+            path (str): Path where the trace will be written.
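+
+        Example (an illustrative sketch; the EventList is typically reached through
+        a finished ``torch.autograd.profiler.profile`` run):
+            >>> # xdoctest: +SKIP("illustrative sketch")
+            >>> with torch.autograd.profiler.profile() as prof:
+            ...     torch.ones(2, 2).add_(1)
+            >>> prof.export_chrome_trace("trace.json")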
+        """
+        import os
+
+        device_name = "cuda" if not self._use_device else self._use_device
+        with open(path, "w") as f:
+            chrome_events = []
+            next_id = 0
+            # Use direct file IO instead of json.dump, since JSON dumping is very slow
+            # and this technique is proven to give a 4x speedup.
+            f.write("[")
+            for evt in self:
+                if evt.trace_name is None:
+                    continue
+                f.write(
+                    '{{"name": "{}", '
+                    '"ph": "X", '
+                    '"ts": {}, '
+                    '"dur": {}, '
+                    '"tid": {}, '
+                    '"pid": "CPU functions", '
+                    '"args": {{}}}}, '.format(
+                        evt.trace_name,
+                        evt.time_range.start,
+                        evt.time_range.elapsed_us(),
+                        evt.thread
+                        if not evt.is_remote
+                        else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
+                    )
+                )
+                for k in evt.kernels:
+                    # 's' and 'f' draw Flow arrows from
+                    # the CPU launch to the GPU kernel
+                    f.write(
+                        f'{{"name": "{evt.trace_name}", '
+                        '"ph": "s", '
+                        f'"ts": {evt.time_range.start}, '
+                        f'"tid": {evt.thread}, '
+                        '"pid": "CPU functions", '
+                        f'"id": {next_id}, '
+                        f'"cat": "cpu_to_{device_name}", '
+                        '"args": {}}, '
+                    )
+                    # Note: use torch.profiler to get device kernel trace
+                    next_id += 1
+            if len(self) > 0:
+                # remove trailing whitespace and comma
+                f.seek(f.tell() - 2, os.SEEK_SET)
+                f.truncate()
+            f.write("]")
+
+    def supported_export_stacks_metrics(self):
+        return [
+            "self_cpu_time_total",
+            "self_cuda_time_total",
+            "self_privateuse1_time_total",
+        ]
+
+    def export_stacks(self, path: str, metric: str):
+        if metric not in self.supported_export_stacks_metrics():
+            raise ValueError(
+                "metric should be one of: "
+                + str(self.supported_export_stacks_metrics())
+            )
+        translate_table = str.maketrans(" ;\t\n", "____")
+        with open(path, "w") as f:
+            for evt in self:
+                if evt.stack and len(evt.stack) > 0:
+                    metric_value = getattr(evt, metric)
+                    if int(metric_value) > 0:
+                        stack_str = ""
+                        for entry in reversed(evt.stack):
+                            stack_str += entry.translate(translate_table)
+                            stack_str += ";"
+                        stack_str = stack_str[:-1] + " " + str(int(metric_value))
+                        f.write(stack_str + "\n")
+
+    def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0):
+        """Averages all function events over their keys.
+
+        Args:
+            group_by_input_shapes: group entries by
+                (event name, input shapes) rather than just event name.
+                This is useful to see which input shapes contribute to the runtime
+                the most and may help with size-specific optimizations or
+                choosing the best candidates for quantization (aka fitting a roof line)
+
+            group_by_stack_n: group by top n stack trace entries
+
+        Returns:
+            An EventList containing FunctionEventAvg objects.
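+
+        Example (an illustrative sketch; ``events`` is assumed to be an EventList
+        whose call tree has already been built):
+            >>> # xdoctest: +SKIP("illustrative sketch")
+            >>> averages = events.key_averages(group_by_input_shapes=True)
+            >>> print(averages.table(sort_by="cpu_time_total"))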
+        """
+        assert self._tree_built
+        stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
+
+        def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
+            key = [
+                str(event.key),
+                str(event.node_id),
+                str(event.device_type),
+                str(event.is_legacy),
+            ]
+            if group_by_input_shapes:
+                key.append(str(event.input_shapes))
+            if group_by_stack_n > 0:
+                key += event.stack[:group_by_stack_n]
+            return tuple(key)
+
+        for evt in self:
+            stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt)
+
+        avg_list = EventList(
+            stats.values(),
+            use_cuda=self._use_cuda,
+            use_device=self._use_device,
+            profile_memory=self._profile_memory,
+            with_flops=self._with_flops,
+        )
+        for evt in avg_list:
+            evt.stack = evt.stack[:group_by_stack_n]
+            if not group_by_input_shapes:
+                evt.input_shapes = ""
+        return avg_list
+
+    def total_average(self):
+        """Averages all events.
+
+        Returns:
+            A FunctionEventAvg object.
+        """
+        total_stat = FunctionEventAvg()
+        for evt in self:
+            total_stat += evt
+            total_stat.key = None
+        total_stat.key = "Total"
+        return total_stat
+
+
+def _format_time(time_us):
+    """Define how to format time in FunctionEvent."""
+    US_IN_SECOND = 1000.0 * 1000.0
+    US_IN_MS = 1000.0
+    if time_us >= US_IN_SECOND:
+        return f"{time_us / US_IN_SECOND:.3f}s"
+    if time_us >= US_IN_MS:
+        return f"{time_us / US_IN_MS:.3f}ms"
+    return f"{time_us:.3f}us"
+
+
+def _format_time_share(time_us, total_time_us):
+    """Define how to format time in FunctionEvent."""
+    if total_time_us == 0:
+        assert time_us == 0, f"Expected time_us == 0 but got {time_us}"
+        return "NaN"
+    return f"{time_us * 100.0 / total_time_us:.2f}%"
+
+
+def _format_memory(nbytes):
+    """Return a formatted memory size string."""
+    KB = 1024
+    MB = 1024 * KB
+    GB = 1024 * MB
+    if abs(nbytes) >= GB:
+        return f"{nbytes * 1.0 / GB:.2f} Gb"
+    elif abs(nbytes) >= MB:
+        return f"{nbytes * 1.0 / MB:.2f} Mb"
+    elif abs(nbytes) >= KB:
+        return f"{nbytes * 1.0 / KB:.2f} Kb"
+    else:
+        return str(nbytes) + " b"
+
+
+def _attr_formatter(name):
+    return property(lambda self: _format_time(getattr(self, name)))
+
+
+class FormattedTimesMixin:
+    """Helpers for FunctionEvent and FunctionEventAvg.
+
+    The subclass should define `*_time_total` and `count` attributes.
+    """
+
+    cpu_time_str = _attr_formatter("cpu_time")
+    cuda_time_str = _attr_formatter("cuda_time")
+    privateuse1_time_str = _attr_formatter("privateuse1_time")
+    cpu_time_total_str = _attr_formatter("cpu_time_total")
+    cuda_time_total_str = _attr_formatter("cuda_time_total")
+    privateuse1_time_total_str = _attr_formatter("privateuse1_time_total")
+    self_cpu_time_total_str = _attr_formatter("self_cpu_time_total")
+    self_cuda_time_total_str = _attr_formatter("self_cuda_time_total")
+    self_privateuse1_time_total_str = _attr_formatter("self_privateuse1_time_total")
+
+    @property
+    def cpu_time(self):
+        return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count  # type: ignore[attr-defined]
+
+    @property
+    def cuda_time(self):
+        return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count  # type: ignore[attr-defined]
+
+    @property
+    def privateuse1_time(self):
+        return 0.0 if self.count == 0 else 1.0 * self.privateuse1_time_total / self.count  # type: ignore[attr-defined]
+
+
+class Interval:
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+
+    def elapsed_us(self):
+        r"""
+        Returns the length of the interval
+        """
+        return self.end - self.start
+
+
+Kernel = namedtuple("Kernel", ["name", "device", "duration"])
+
+
+class FunctionEvent(FormattedTimesMixin):
+    """Profiling information about a single function."""
+
+    def __init__(
+        self,
+        id,
+        name,
+        thread,
+        start_us,
+        end_us,
+        fwd_thread=None,
+        input_shapes=None,
+        stack=None,
+        scope=0,
+        use_device=None,
+        cpu_memory_usage=0,
+        cuda_memory_usage=0,
+        privateuse1_memory_usage=0,
+        is_async=False,
+        is_remote=False,
+        sequence_nr=-1,
+        node_id=-1,
+        device_type=DeviceType.CPU,
+        device_index=0,
+        is_legacy=False,
+        flops=None,
+        trace_name=None,
+        concrete_inputs=None,
+    ):
+        self.id: int = id
+        self.node_id: int = node_id
+        self.name: str = name
+        self.trace_name: str = trace_name
+        self.time_range: Interval = Interval(start_us, end_us)
+        self.thread: int = thread
+        self.fwd_thread: Optional[int] = fwd_thread
+        self.kernels: List[Kernel] = []
+        self.count: int = 1
+        self.cpu_children: List[FunctionEvent] = []
+        self.cpu_parent: Optional[FunctionEvent] = None
+        self.input_shapes: Tuple[int, ...] = input_shapes
+        self.concrete_inputs: List[Any] = concrete_inputs
+        self.stack: List = stack
+        self.scope: int = scope
+        self.use_device: Optional[str] = use_device
+        self.cpu_memory_usage: int = cpu_memory_usage
+        self.cuda_memory_usage: int = cuda_memory_usage
+        self.privateuse1_memory_usage: int = privateuse1_memory_usage
+        self.is_async: bool = is_async
+        self.is_remote: bool = is_remote
+        self.sequence_nr: int = sequence_nr
+        self.device_type: DeviceType = device_type
+        self.device_index: int = device_index
+        self.is_legacy: bool = is_legacy
+        self.flops: Optional[int] = flops
+
+    def append_kernel(self, name, device, duration):
+        assert self.device_type == DeviceType.CPU
+        self.kernels.append(Kernel(name, device, duration))
+
+    def append_cpu_child(self, child):
+        """Append a CPU child of type FunctionEvent.
+
+        One is supposed to append only direct children to the event so that
+        the self cpu time is reported correctly.
+        """
+        assert self.device_type == DeviceType.CPU
+        assert isinstance(child, FunctionEvent)
+        assert child.device_type == DeviceType.CPU
+        self.cpu_children.append(child)
+
+    def set_cpu_parent(self, parent):
+        """Set the immediate CPU parent of type FunctionEvent.
+
+        One profiling FunctionEvent should have only one CPU parent such that
+        the child's range interval is completely inside the parent's. We use
+        this connection to determine whether the event is from a top-level op or not.
+        """
+        assert self.device_type == DeviceType.CPU
+        assert isinstance(parent, FunctionEvent)
+        assert parent.device_type == DeviceType.CPU
+        self.cpu_parent = parent
+
+    # Note: async events don't have children, are not used when computing 'self'
+    # metrics of other events, and have only total cpu time
+    @property
+    def self_cpu_memory_usage(self):
+        if self.is_async or self.device_type != DeviceType.CPU:
+            return 0
+        return self.cpu_memory_usage - sum(
+            [child.cpu_memory_usage for child in self.cpu_children]
+        )
+
+    @property
+    def self_cuda_memory_usage(self):
+        if self.is_async or self.device_type != DeviceType.CPU:
+            return 0
+        return self.cuda_memory_usage - sum(
+            [child.cuda_memory_usage for child in self.cpu_children]
+        )
+
+    @property
+    def self_privateuse1_memory_usage(self):
+        if self.is_async or self.device_type != DeviceType.CPU:
+            return 0
+        return self.privateuse1_memory_usage - sum(
+            [child.privateuse1_memory_usage for child in self.cpu_children]
+        )
+
+    @property
+    def self_cpu_time_total(self):
+        if self.is_async or self.device_type != DeviceType.CPU:
+            return 0
+        return self.cpu_time_total - sum(
+            [child.cpu_time_total for child in self.cpu_children]
+        )
+
+    @property
+    def cuda_time_total(self):
+        if self.is_async or self.use_device:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            if not self.is_legacy:
+                # account for the kernels in the children ops
+                return sum(kinfo.duration for kinfo in self.kernels) + sum(
+                    ch.cuda_time_total for ch in self.cpu_children
+                )
+            else:
+                # each legacy cpu event has a single (fake) kernel
+                return sum(kinfo.duration for kinfo in self.kernels)
+        else:
+            assert self.device_type == DeviceType.CUDA
+            return self.time_range.elapsed_us()
+
+    @property
+    def self_cuda_time_total(self):
+        if self.is_async or self.use_device:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            return self.cuda_time_total - sum(
+                [child.cuda_time_total for child in self.cpu_children]
+            )
+        else:
+            assert self.device_type == DeviceType.CUDA
+            return self.cuda_time_total
+
+    @property
+    def cpu_time_total(self):
+        if self.device_type == DeviceType.CPU:
+            return self.time_range.elapsed_us()
+        else:
+            return 0
+
+    @property
+    def self_privateuse1_time_total(self):
+        if self.is_async or not self.use_device:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            return self.privateuse1_time_total - sum(
+                [child.privateuse1_time_total for child in self.cpu_children]
+            )
+        else:
+            assert self.device_type == DeviceType.CUDA
+            return self.privateuse1_time_total
+
+    @property
+    def privateuse1_time_total(self):
+        if self.is_async or not self.use_device:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            if not self.is_legacy:
+                # account for the kernels in the children ops
+                return sum(kinfo.duration for kinfo in self.kernels) + sum(
+                    ch.privateuse1_time_total for ch in self.cpu_children
+                )
+            else:
+                # each legacy cpu event has a single (fake) kernel
+                return sum(kinfo.duration for kinfo in self.kernels)
+        else:
+            assert self.device_type == DeviceType.PrivateUse1
+            return self.time_range.elapsed_us()
+
+    @property
+    def key(self):
+        return self.name
+
+    def __repr__(self):
+        device_name = "cuda" if not self.use_device else self.use_device
+        device_time = (
+            self.cuda_time_str if not self.use_device else self.privateuse1_time_str
+        )
+        device_memory_usage = (
+            self.cuda_memory_usage
+            if not self.use_device
+            else self.privateuse1_memory_usage
+        )
+        return (
+            "".format(
+                self.id,
+                self.name,
+                self.device_type,
+                self.node_id,
+                self.cpu_time_str,
+                self.time_range.start,
+                self.time_range.end,
+                str([child.id for child in self.cpu_children]),
+                device_name,
+                device_time,
+                self.name,
+                self.thread,
+                str(self.input_shapes),
+                self.cpu_memory_usage,
+                device_name,
+                device_memory_usage,
+                self.is_async,
+                self.is_remote,
+                self.sequence_nr,
+                self.is_legacy,
+            )
+        )
+
+
+class FunctionEventAvg(FormattedTimesMixin):
+    """Used to average stats over multiple FunctionEvent objects."""
+
+    def __init__(self):
+        self.key: Optional[str] = None
+        self.count: int = 0
+        self.node_id: int = 0
+        self.is_async: bool = False
+        self.is_remote: bool = False
+        self.use_device: Optional[str] = None
+        self.cpu_time_total: int = 0
+        self.cuda_time_total: int = 0
+        self.privateuse1_time_total: int = 0
+        self.self_cpu_time_total: int = 0
+        self.self_cuda_time_total: int = 0
+        self.self_privateuse1_time_total: int = 0
+        self.input_shapes: Optional[List[List[int]]] = None
+        self.stack: Optional[List] = None
+        self.scope: Optional[int] = None
+        self.cpu_memory_usage: int = 0
+        self.cuda_memory_usage: int = 0
+        self.privateuse1_memory_usage: int = 0
+        self.self_cpu_memory_usage: int = 0
+        self.self_cuda_memory_usage: int = 0
+        self.self_privateuse1_memory_usage: int = 0
+        self.cpu_children: Optional[List[FunctionEvent]] = None
+        self.cpu_parent: Optional[FunctionEvent] = None
+        self.device_type: DeviceType = DeviceType.CPU
+        self.is_legacy: bool = False
+        self.flops: int = 0
+
+    def add(self, other):
+        if self.key is None:
+            # First function being recorded as part of FunctionEventAvg, propagate
+            # fields.
+            self.key = other.key
+            self.node_id = other.node_id
+            self.is_async = other.is_async
+            self.is_remote = other.is_remote
+            self.cpu_parent = other.cpu_parent
+            self.cpu_children = other.cpu_children
+
+            self.input_shapes = other.input_shapes
+            self.stack = other.stack
+            self.scope = other.scope
+            self.device_type = other.device_type
+            self.is_legacy = other.is_legacy
+            self.use_device = other.use_device
+
+        assert isinstance(other, (FunctionEvent, FunctionEventAvg))
+        assert other.key == self.key
+        self.cpu_time_total += other.cpu_time_total
+        self.cuda_time_total += other.cuda_time_total
+        self.privateuse1_time_total += other.privateuse1_time_total
+        self.self_cpu_time_total += other.self_cpu_time_total
+        self.self_cuda_time_total += other.self_cuda_time_total
+        self.self_privateuse1_time_total += other.self_privateuse1_time_total
+        self.cpu_memory_usage += other.cpu_memory_usage
+        self.cuda_memory_usage += other.cuda_memory_usage
+        self.privateuse1_memory_usage += other.privateuse1_memory_usage
+        self.self_cpu_memory_usage += other.self_cpu_memory_usage
+        self.self_cuda_memory_usage += other.self_cuda_memory_usage
+        self.self_privateuse1_memory_usage += other.self_privateuse1_memory_usage
+        self.count += other.count
+        if self.flops is None:
+            self.flops = other.flops
+        elif other.flops is not None:
+            self.flops += other.flops
+        return self
+
+    def __iadd__(self, other):
+        return self.add(other)
+
+    def __repr__(self):
+        device_name = "cuda" if not self.use_device else self.use_device
+        self_device_time = (
+            self.self_cuda_time_total_str
+            if not self.use_device
+            else self.self_privateuse1_time_total_str
+        )
+        device_time = (
+            self.cuda_time_str if not self.use_device else self.privateuse1_time_str
+        )
+        device_memory = (
+            self.cuda_memory_usage
+            if not self.use_device
+            else self.privateuse1_memory_usage
+        )
+        return (
+            "".format(
+                self.key,
+                self.self_cpu_time_total_str,
+                self.cpu_time_str,
+                device_name,
+                self_device_time,
+                device_name,
+                device_time,
+                str(self.input_shapes),
+                self.cpu_memory_usage,
+                device_name,
+                device_memory,
+            )
+        )
+
+
+class StringTable(defaultdict):
+    def __missing__(self, key):
+        # manage cases like 't' (demangled to 'unsigned short') separately,
+        # for now simply check the length to avoid unexpected results for
+        # the short sequences
+        self[key] = torch._C._demangle(key) if len(key) > 1 else key
+        return self[key]
+
+
+class MemRecordsAcc:
+    """Acceleration structure for accessing mem_records in interval."""
+
+    def __init__(self, mem_records):
+        self._mem_records = mem_records
+        self._start_uses: List[int] = []
+        self._indices: List[int] = []
+        if len(mem_records) > 0:
+            tmp = sorted([(r[0].start_us(), i) for i, r in enumerate(mem_records)])
+            self._start_uses, self._indices = zip(*tmp)  # type: ignore[assignment]
+
+    def in_interval(self, start_us, end_us):
+        r"""
+        Return all records in the given interval
+        """
+        start_idx = bisect.bisect_left(self._start_uses, start_us)
+        end_idx = bisect.bisect_right(self._start_uses, end_us)
+        for i in range(start_idx, end_idx):
+            yield self._mem_records[self._indices[i]]
+
+
+def _filter_stack_entry(entry):
+    filtered_entries = [
+        ("autograd/__init__", "_make_grads"),
+        ("autograd/__init__", "backward"),
+        ("torch/tensor", "backward"),
+        ("_internal/common_utils", "prof_callable"),
+        ("_internal/common_utils", "prof_func_call"),
+        ("_internal/common_utils", "prof_meth_call"),
+    ]
+    return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries)
+
+
+MEMORY_EVENT_NAME = "[memory]"
+OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]"
+
+
+def _filter_name(name):
+    # ignoring the following utility ops
+    filtered_out_names = [
+        MEMORY_EVENT_NAME,  # used only for the top-level memory events
+        OUT_OF_MEMORY_EVENT_NAME,
+        "profiler::_record_function_enter",
+        "profiler::_record_function_enter_new",
+        "profiler::_record_function_exit",
+        "aten::is_leaf",
+        "aten::output_nr",
+        "aten::_version",
+    ]
+    return name in filtered_out_names
+
+
+# Demangles and optionally rewrites the provided event name,
+# with_wildcard - whether to replace certain numbered event names
+# with a wildcard name to aggregate them together in the profiler table
+# output
+def _rewrite_name(name, with_wildcard=False):
+    string_table = StringTable()
+    name = string_table[name]
+    if with_wildcard:
+        if name.startswith("ProfilerStep#"):
+            name = "ProfilerStep*"
+    return name
+
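+
+# Worked example (illustrative, not part of the vendored file): with the
+# wildcard enabled, numbered profiler steps collapse into one aggregated name.
+def _example_rewrite_name():
+    assert _rewrite_name("ProfilerStep#12", with_wildcard=True) == "ProfilerStep*"
+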
+
+def _build_table(
+    events,
+    sort_by=None,
+    header=None,
+    row_limit=100,
+    max_src_column_width=75,
+    max_name_column_width=55,
+    max_shapes_column_width=80,
+    with_flops=False,
+    profile_memory=False,
+    top_level_events_only=False,
+):
+    """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
+    if len(events) == 0:
+        return ""
+
+    has_cuda_time = any(event.self_cuda_time_total > 0 for event in events)
+    has_cuda_mem = any(event.self_cuda_memory_usage > 0 for event in events)
+    has_privateuse1_time = any(
+        event.self_privateuse1_time_total > 0 for event in events
+    )
+    has_privateuse1_mem = any(
+        event.self_privateuse1_memory_usage > 0 for event in events
+    )
+    use_device = events[0].use_device
+    if not use_device and (has_privateuse1_mem or has_privateuse1_time):
+        raise RuntimeError(
+            "use_device is None, but there is private device performance data."
+        )
+
+    has_input_shapes = any(
+        (event.input_shapes is not None and len(event.input_shapes) > 0)
+        for event in events
+    )
+
+    if sort_by is not None:
+        events = EventList(
+            sorted(events, key=lambda evt: getattr(evt, sort_by), reverse=True),
+            use_cuda=has_cuda_time,
+            use_device=use_device,
+            profile_memory=profile_memory,
+            with_flops=with_flops,
+        )
+
+    name_column_width = max([len(evt.key) for evt in events]) + 4
+    if max_name_column_width is not None:
+        name_column_width = min(name_column_width, max_name_column_width)
+
+    shapes_column_width = max([len(str(evt.input_shapes)) for evt in events]) + 4
+    if max_shapes_column_width is not None:
+        shapes_column_width = min(shapes_column_width, max_shapes_column_width)
+
+    DEFAULT_COLUMN_WIDTH = 12
+    flops_column_width = DEFAULT_COLUMN_WIDTH
+
+    src_column_width = None
+    stacks = []
+    for evt in events:
+        if evt.stack is not None and len(evt.stack) > 0:
+            stacks.append(evt.stack)
+    has_stack = len(stacks) > 0
+    if has_stack:
+        src_column_width = (
+            max([max([len(entry) for entry in stack]) for stack in stacks]) + 4
+        )
+        if max_src_column_width is not None:
+            src_column_width = min(src_column_width, max_src_column_width)
+
+    headers = [
+        "Name",
+        "Self CPU %",
+        "Self CPU",
+        "CPU total %",
+        "CPU total",
+        "CPU time avg",
+    ]
+    if has_cuda_time:
+        headers.extend(
+            [
+                "Self CUDA",
+                "Self CUDA %",
+                "CUDA total",
+                "CUDA time avg",
+            ]
+        )
+    if has_privateuse1_time:
+        privateuse1 = use_device.upper()
+        headers.extend(
+            [
+                f"Self {privateuse1}",
+                f"Self {privateuse1} %",
+                f"{privateuse1} total",
+                f"{privateuse1} time avg",
+            ]
+        )
+    if profile_memory:
+        headers.extend(
+            [
+                "CPU Mem",
+                "Self CPU Mem",
+            ]
+        )
+        if has_cuda_mem:
+            headers.extend(
+                [
+                    "CUDA Mem",
+                    "Self CUDA Mem",
+                ]
+            )
+        if has_privateuse1_mem:
+            privateuse1 = use_device.upper()
+            headers.extend(
+                [
+                    f"{privateuse1} Mem",
+                    f"Self {privateuse1} Mem",
+                ]
+            )
+    headers.append("# of Calls")
+    # Only append Node ID if any event has a valid (>= 0) Node ID
+    append_node_id = any(evt.node_id != -1 for evt in events)
+    if append_node_id:
+        headers.append("Node ID")
+
+    # Build the row format, separator and line length incrementally; keep them
+    # in one-element lists so the nested add_column helper can mutate them.
+    SPACING_SIZE = 2
+    row_format_lst = [""]
+    header_sep_lst = [""]
+    line_length_lst = [-SPACING_SIZE]
+    MAX_STACK_ENTRY = 5
+
+    def add_column(padding, text_dir=">"):
+        row_format_lst[0] += (
+            "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE)
+        )
+        header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
+        line_length_lst[0] += padding + SPACING_SIZE
+
+    def auto_scale_flops(flops):
+        flop_headers = [
+            "FLOPs",
+            "KFLOPs",
+            "MFLOPs",
+            "GFLOPs",
+            "TFLOPs",
+            "PFLOPs",
+        ]
+        assert flops > 0
+        log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
+        assert log_flops >= 0 and log_flops < len(flop_headers)
+        return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])
+
+    add_column(name_column_width)
+    for _ in headers[1:]:
+        add_column(DEFAULT_COLUMN_WIDTH)
+
+    if has_input_shapes:
+        headers.append("Input Shapes")
+        add_column(shapes_column_width)
+
+    if has_stack:
+        headers.append("Source Location")
+        add_column(src_column_width, text_dir="<")
+
+    if with_flops:
+        # Auto-scaling of flops header
+        raw_flops = []
+        for evt in events:
+            if evt.flops > 0:
+                raw_flops.append(evt.flops)
+        if len(raw_flops) != 0:
+            (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
+            headers.append(f"Total {flops_header}")
+            add_column(flops_column_width)
+        else:
+            with_flops = False  # can't find any valid flops
+
+    row_format = row_format_lst[0]
+    header_sep = header_sep_lst[0]
+    line_length = line_length_lst[0]
+    add_column = None  # type: ignore[assignment]
+
+    # Accumulate the output lines in a list and join them at the end.
+    result = []
+
+    def append(s):
+        result.append(s)
+        result.append("\n")  # Yes, newline after the end as well
+
+    sum_self_cpu_time_total = sum([event.self_cpu_time_total for event in events])
+    sum_self_cuda_time_total = 0
+    sum_self_privateuse1_time_total = 0
+    for evt in events:
+        if evt.device_type == DeviceType.CPU:
+            # in legacy profiler, kernel info is stored in cpu events
+            if evt.is_legacy:
+                if not use_device:
+                    sum_self_cuda_time_total += evt.self_cuda_time_total
+                else:
+                    sum_self_privateuse1_time_total += evt.self_privateuse1_time_total
+        elif evt.device_type == DeviceType.CUDA:
+            # in kineto profiler, there are events with the correct device type (e.g. CUDA)
+            sum_self_cuda_time_total += evt.self_cuda_time_total
+        elif evt.device_type == DeviceType.PrivateUse1:
+            sum_self_privateuse1_time_total += evt.self_privateuse1_time_total
+
+    # Actual printing
+    if header is not None:
+        append("=" * line_length)
+        append(header)
+    if top_level_events_only:
+        append("=" * line_length)
+        append("This report only display top-level ops statistics")
+    append(header_sep)
+    append(row_format.format(*headers))
+
+    append(header_sep)
+
+    def trim_path(path, src_column_width):
+        if len(path) > src_column_width:
+            offset = len(path) - src_column_width
+            path = path[offset:]
+            if len(path) > 3:
+                path = "..." + path[3:]
+        return path
+
+    event_limit = 0
+    for evt in events:
+        if event_limit == row_limit:
+            break
+        if top_level_events_only and evt.cpu_parent is not None:
+            continue
+        else:
+            event_limit += 1
+        name = evt.key
+        if max_name_column_width is not None and len(name) >= max_name_column_width - 3:
+            name = name[: (max_name_column_width - 3)] + "..."
+        row_values = [
+            name,
+            # Self CPU total %, 0 for async events.
+            _format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total),
+            evt.self_cpu_time_total_str,  # Self CPU total
+            # CPU total %, 0 for async events.
+            _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
+            if not evt.is_async
+            else 0,
+            evt.cpu_time_total_str,  # CPU total
+            evt.cpu_time_str,  # CPU time avg
+        ]
+        if has_cuda_time:
+            row_values.extend(
+                [
+                    evt.self_cuda_time_total_str,
+                    # CUDA time total %
+                    _format_time_share(
+                        evt.self_cuda_time_total, sum_self_cuda_time_total
+                    ),
+                    evt.cuda_time_total_str,
+                    evt.cuda_time_str,  # Cuda time avg
+                ]
+            )
+        if has_privateuse1_time:
+            row_values.extend(
+                [
+                    evt.self_privateuse1_time_total_str,
+                    # PrivateUse1 time total %
+                    _format_time_share(
+                        evt.self_privateuse1_time_total, sum_self_privateuse1_time_total
+                    ),
+                    evt.privateuse1_time_total_str,
+                    evt.privateuse1_time_str,  # PrivateUse1 time avg
+                ]
+            )
+        if profile_memory:
+            row_values.extend(
+                [
+                    # CPU Mem Total
+                    _format_memory(evt.cpu_memory_usage),
+                    # Self CPU Mem Total
+                    _format_memory(evt.self_cpu_memory_usage),
+                ]
+            )
+            if has_cuda_mem:
+                row_values.extend(
+                    [
+                        # CUDA Mem Total
+                        _format_memory(evt.cuda_memory_usage),
+                        # Self CUDA Mem Total
+                        _format_memory(evt.self_cuda_memory_usage),
+                    ]
+                )
+            if has_privateuse1_mem:
+                row_values.extend(
+                    [
+                        # PrivateUse1 Mem Total
+                        _format_memory(evt.privateuse1_memory_usage),
+                        # Self PrivateUse1 Mem Total
+                        _format_memory(evt.self_privateuse1_memory_usage),
+                    ]
+                )
+        row_values.append(
+            evt.count,  # Number of calls
+        )
+
+        if append_node_id:
+            row_values.append(evt.node_id)
+        if has_input_shapes:
+            row_values.append(str(evt.input_shapes)[:shapes_column_width])
+        if with_flops:
+            if evt.flops <= 0:
+                row_values.append("--")
+            else:
+                row_values.append(f"{evt.flops * flops_scale:8.3f}")  # type: ignore[possibly-undefined]
+        if has_stack:
+            src_field = ""
+            if len(evt.stack) > 0:
+                src_field = trim_path(evt.stack[0], src_column_width)
+            row_values.append(src_field)
+        append(row_format.format(*row_values))
+
+        if has_stack:
+            empty_headers = [""] * (len(headers) - 1)
+            for entry in evt.stack[1:MAX_STACK_ENTRY]:
+                append(
+                    row_format.format(
+                        *(empty_headers + [trim_path(entry, src_column_width)])
+                    )
+                )
+            empty_headers.append("")
+            append(row_format.format(*empty_headers))
+
+    append(header_sep)
+    append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}")
+    if has_cuda_time:
+        append(f"Self CUDA time total: {_format_time(sum_self_cuda_time_total)}")
+    if has_privateuse1_time:
+        append(
+            f"Self {use_device.upper()} time total: {_format_time(sum_self_privateuse1_time_total)}"
+        )
+    return "".join(result)
diff --git a/MLPY/Lib/site-packages/torch/autograd/variable.py b/MLPY/Lib/site-packages/torch/autograd/variable.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6c74819c79090f1926cc83a0b15e2aff322e304
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/autograd/variable.py
@@ -0,0 +1,14 @@
+import torch
+from torch._C import _ImperativeEngine as ImperativeEngine
+
+
+__all__ = ["VariableMeta", "Variable"]
+
+
+class VariableMeta(type):
+    def __instancecheck__(cls, other):
+        return isinstance(other, torch.Tensor)
+
+
+class Variable(torch._C._LegacyVariableBase, metaclass=VariableMeta):  # type: ignore[misc]
+    _execution_engine = ImperativeEngine()
diff --git a/MLPY/Lib/site-packages/torch/backends/__init__.py b/MLPY/Lib/site-packages/torch/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fdfe3ff1655ad581775d2761660c0713292fa15
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/__init__.py
@@ -0,0 +1,70 @@
+import types
+from contextlib import contextmanager
+
+# The idea for this parameter is that we forbid bare assignment
+# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
+# test suite, where it's very easy to forget to undo the change
+# later.
+__allow_nonbracketed_mutation_flag = True
+
+
+def disable_global_flags():
+    global __allow_nonbracketed_mutation_flag
+    __allow_nonbracketed_mutation_flag = False
+
+
+def flags_frozen():
+    return not __allow_nonbracketed_mutation_flag
+
+
+@contextmanager
+def __allow_nonbracketed_mutation():
+    global __allow_nonbracketed_mutation_flag
+    old = __allow_nonbracketed_mutation_flag
+    __allow_nonbracketed_mutation_flag = True
+    try:
+        yield
+    finally:
+        __allow_nonbracketed_mutation_flag = old
+
+
+class ContextProp:
+    def __init__(self, getter, setter):
+        self.getter = getter
+        self.setter = setter
+
+    def __get__(self, obj, objtype):
+        return self.getter()
+
+    def __set__(self, obj, val):
+        if not flags_frozen():
+            self.setter(val)
+        else:
+            raise RuntimeError(
+                "not allowed to set %s flags "
+                "after disable_global_flags; please use flags() context manager instead"
+                % obj.__name__
+            )
+
+
+class PropModule(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
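+
+# Illustrative sketch (not part of the vendored file): how a backend submodule
+# typically exposes a flag via ContextProp inside a PropModule subclass, so
+# that bare assignment is rejected once disable_global_flags() has run.  The
+# getter/setter pair and module names here are hypothetical.
+def _example_prop_module(real_module, name, getter, setter):
+    class _ExampleBackendModule(PropModule):
+        enabled = ContextProp(getter, setter)
+
+    return _ExampleBackendModule(real_module, name)
+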
+
+from torch.backends import (
+    cpu as cpu,
+    cuda as cuda,
+    cudnn as cudnn,
+    mha as mha,
+    mkl as mkl,
+    mkldnn as mkldnn,
+    mps as mps,
+    nnpack as nnpack,
+    openmp as openmp,
+    quantized as quantized,
+)
diff --git a/MLPY/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e579111fd3cd5dcdf8e5ce62e8ca0f211348556
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_coreml/__init__.py b/MLPY/Lib/site-packages/torch/backends/_coreml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6e7be639b24e1913a24176f8d87ad6d9f7b1e30
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1d080d40419b6f4a6b9f878bb2fbbe670a85748
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_coreml/preprocess.py b/MLPY/Lib/site-packages/torch/backends/_coreml/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ab3d64edcf1b8e69ef00ae754a587c31aaf69c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/_coreml/preprocess.py
@@ -0,0 +1,146 @@
+import hashlib
+import json
+from typing import Dict, Tuple
+
+import coremltools as ct  # type: ignore[import]
+from coremltools.converters.mil.input_types import TensorType  # type: ignore[import]
+from coremltools.converters.mil.mil import types  # type: ignore[import]
+from coremltools.models.neural_network import quantization_utils  # type: ignore[import]
+
+import torch
+
+CT_METADATA_VERSION = "com.github.apple.coremltools.version"
+CT_METADATA_SOURCE = "com.github.apple.coremltools.source"
+
+
+class ScalarType:
+    Float = 0
+    Double = 1
+    Int = 2
+    Long = 3
+    Undefined = 4
+
+
+# Supported Tensor types in coremltools:
+# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28
+torch_to_mil_types = {
+    ScalarType.Float: types.fp32,
+    ScalarType.Double: types.fp64,
+    ScalarType.Int: types.int32,
+    ScalarType.Long: types.int64,
+}
+
+
+class CoreMLComputeUnit:
+    CPU = "cpuOnly"
+    CPUAndGPU = "cpuAndGPU"
+    ALL = "all"
+
+
+class CoreMLQuantizationMode:
+    LINEAR = "linear"
+    LINEAR_SYMMETRIC = "linear_symmetric"
+    NONE = "none"
+
+
+def TensorSpec(shape, dtype=ScalarType.Float):
+    return (shape, dtype)
+
+
+def CompileSpec(
+    inputs,
+    outputs,
+    backend=CoreMLComputeUnit.CPU,
+    allow_low_precision=True,
+    quantization_mode=CoreMLQuantizationMode.NONE,
+    mlmodel_export_path=None,
+):
+    return (
+        inputs,
+        outputs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+    )
+
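+
+# Illustrative sketch (not part of the vendored file): a compile spec for a
+# hypothetical classifier whose forward() takes one float32 (1, 3, 224, 224)
+# input and returns one float32 (1, 1000) output.  All shapes are made up.
+def _example_compile_spec():
+    return {
+        "forward": CompileSpec(
+            inputs=(TensorSpec(shape=[1, 3, 224, 224]),),
+            outputs=(TensorSpec(shape=[1, 1000]),),
+            backend=CoreMLComputeUnit.ALL,
+            allow_low_precision=True,
+        ),
+    }
+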
+
+def _check_enumerated_shape(shape):
+    for s in shape:
+        if not isinstance(s, (list, tuple)):
+            return False
+    return True
+
+
+def _convert_to_mil_type(shape, dtype, name: str):
+    mil_shape = shape
+    if _check_enumerated_shape(shape):
+        mil_shape = ct.EnumeratedShapes(shape)
+    ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
+    ml_type.name = name
+    return ml_type
+
+
+def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
+    spec = compile_spec["forward"]
+    (
+        input_specs,
+        output_specs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+    ) = spec
+    mil_inputs = []
+    inputs = []
+    for index, input in enumerate(input_specs):
+        shape, dtype = input
+        name = "input_" + str(index)
+        inputs.append([name, str(dtype), str(shape)])
+        ml_type = _convert_to_mil_type(shape, dtype, name)
+        mil_inputs.append(ml_type)
+    model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
+    mlmodel = ct.convert(model, inputs=mil_inputs)
+
+    if quantization_mode != CoreMLQuantizationMode.NONE:
+        quant_model_spec = quantization_utils.quantize_weights(
+            mlmodel, nbits=8, quantization_mode=quantization_mode
+        )
+        mlmodel = ct.models.MLModel(quant_model_spec)
+
+    spec = mlmodel.get_spec()
+    assert len(spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
+    outputs = []
+    for index, output in enumerate(output_specs):
+        shape, dtype = output
+        name = spec.description.output[index].name  # type: ignore[attr-defined]
+        outputs.append([name, str(dtype), str(shape)])
+    mlmodel = ct.models.model.MLModel(spec)
+    print(mlmodel)
+
+    if mlmodel_export_path is not None:
+        print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}")
+        mlmodel.save(mlmodel_export_path)
+
+    config = {
+        "spec_ver": str(spec.specificationVersion),  # type: ignore[attr-defined]
+        "backend": backend,
+        "allow_low_precision": str(allow_low_precision),
+    }
+    metadata = {
+        "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],
+        "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE],
+    }
+    coreml_compile_spec = {
+        "inputs": inputs,
+        "outputs": outputs,
+        "config": config,
+        "metadata": metadata,
+    }
+    mlmodel = spec.SerializeToString()  # type: ignore[attr-defined]
+
+    return {
+        "model": mlmodel,
+        "hash": str(hashlib.sha256(mlmodel).hexdigest()),
+        "extra": json.dumps(coreml_compile_spec),
+    }
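+
+
+# Illustrative sketch (hedged, not part of the vendored file): this preprocess
+# function is normally reached through the to_backend flow, roughly
+# torch._C._jit_to_backend("coreml", scripted_model, compile_spec), where both
+# arguments are placeholders built by the caller.
+def _example_lower_to_coreml(scripted_model, compile_spec):
+    return torch._C._jit_to_backend("coreml", scripted_model, compile_spec)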
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/__init__.py b/MLPY/Lib/site-packages/torch/backends/_nnapi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4352ed061e292ea69fbe696366620e3ef8bbef0c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8530b9989253ef1bd9803c4b61a836b815fd64c4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dabe7e27254c02cffec7b38d58eef24d6752e54
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/prepare.py b/MLPY/Lib/site-packages/torch/backends/_nnapi/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c377f149c8e223bdbfd1c258f65ef1ad61d0388
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/_nnapi/prepare.py
@@ -0,0 +1,198 @@
+from typing import List, Optional
+
+import torch
+from torch.backends._nnapi.serializer import _NnapiSerializer
+
+ANEURALNETWORKS_PREFER_LOW_POWER = 0
+ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
+ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
+
+
+class NnapiModule(torch.nn.Module):
+    """Torch Module that wraps an NNAPI Compilation.
+
+    This module handles preparing the weights, initializing the
+    NNAPI TorchBind object, and adjusting the memory formats
+    of all inputs and outputs.
+    """
+
+    # _nnapi.Compilation is defined
+    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
+    weights: List[torch.Tensor]
+    out_templates: List[torch.Tensor]
+
+    def __init__(
+        self,
+        shape_compute_module: torch.nn.Module,
+        ser_model: torch.Tensor,
+        weights: List[torch.Tensor],
+        inp_mem_fmts: List[int],
+        out_mem_fmts: List[int],
+        compilation_preference: int,
+        relax_f32_to_f16: bool,
+    ):
+        super().__init__()
+        self.shape_compute_module = shape_compute_module
+        self.ser_model = ser_model
+        self.weights = weights
+        self.inp_mem_fmts = inp_mem_fmts
+        self.out_mem_fmts = out_mem_fmts
+        self.out_templates = []
+        self.comp = None
+        self.compilation_preference = compilation_preference
+        self.relax_f32_to_f16 = relax_f32_to_f16
+
+    @torch.jit.export
+    def init(self, args: List[torch.Tensor]):
+        assert self.comp is None
+        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
+        self.weights = [w.contiguous() for w in self.weights]
+        comp = torch.classes._nnapi.Compilation()
+        comp.init2(
+            self.ser_model,
+            self.weights,
+            self.compilation_preference,
+            self.relax_f32_to_f16,
+        )
+
+        self.comp = comp
+
+    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
+        if self.comp is None:
+            self.init(args)
+        comp = self.comp
+        assert comp is not None
+        outs = [torch.empty_like(out) for out in self.out_templates]
+
+        assert len(args) == len(self.inp_mem_fmts)
+        fixed_args = []
+        for idx in range(len(args)):
+            fmt = self.inp_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt == 0:
+                fixed_args.append(args[idx].contiguous())
+            elif fmt == 1:
+                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
+            else:
+                raise Exception("Invalid mem_fmt")
+        comp.run(fixed_args, outs)
+        assert len(outs) == len(self.out_mem_fmts)
+        for idx in range(len(self.out_templates)):
+            fmt = self.out_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt in (0, 2):
+                pass
+            elif fmt == 1:
+                outs[idx] = outs[idx].permute(0, 3, 1, 2)
+            else:
+                raise Exception("Invalid mem_fmt")
+        return outs
+
+
+def convert_model_to_nnapi(
+    model,
+    inputs,
+    serializer=None,
+    return_shapes=None,
+    use_int16_for_qint16=False,
+    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
+    relax_f32_to_f16=False,
+):
+    (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    ) = process_for_nnapi(
+        model, inputs, serializer, return_shapes, use_int16_for_qint16
+    )
+
+    nnapi_model = NnapiModule(
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        compilation_preference,
+        relax_f32_to_f16,
+    )
+
+    class NnapiInterfaceWrapper(torch.nn.Module):
+        """NNAPI list-ifying and de-list-ifying wrapper.
+
+        NNAPI always expects a list of inputs and provides a list of outputs.
+        This module allows us to accept inputs as separate arguments.
+        It returns results as either a single tensor or tuple,
+        matching the original module.
+        """
+
+        def __init__(self, mod):
+            super().__init__()
+            self.mod = mod
+
+    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
+    wrapper_model = torch.jit.script(wrapper_model_py)
+    # TODO: Maybe make these names match the original.
+    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
+    if retval_count < 0:
+        ret_expr = "retvals[0]"
+    else:
+        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
+    wrapper_model.define(
+        f"def forward(self, {arg_list}):\n"
+        f"    retvals = self.mod([{arg_list}])\n"
+        f"    return {ret_expr}\n"
+    )
+    return wrapper_model
+
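+
+# Illustrative sketch (not part of the vendored file): the usual entry point is
+# a traced/scripted module plus one representative input; `model` and
+# `example_input` are placeholders.
+def _example_convert_to_nnapi(model, example_input):
+    traced = torch.jit.trace(model, example_input)
+    # NNAPI wants contiguous (or channels-last) inputs with known shapes.
+    return convert_model_to_nnapi(traced, example_input.contiguous())
+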
+
+def process_for_nnapi(
+    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
+):
+    model = torch.jit.freeze(model)
+
+    if isinstance(inputs, torch.Tensor):
+        inputs = [inputs]
+
+    serializer = serializer or _NnapiSerializer(
+        config=None, use_int16_for_qint16=use_int16_for_qint16
+    )
+    (
+        ser_model,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        shape_compute_lines,
+        retval_count,
+    ) = serializer.serialize_model(model, inputs, return_shapes)
+    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
+
+    # We have to create a new class here every time this function is called
+    # because module.define adds a method to the *class*, not the instance.
+    class ShapeComputeModule(torch.nn.Module):
+        """Code-gen-ed module for tensor shape computation.
+
+        module.prepare will mutate ser_model according to the computed operand
+        shapes, based on the shapes of args.  Returns a list of output templates.
+        """
+
+        pass
+
+    shape_compute_module = torch.jit.script(ShapeComputeModule())
+    real_shape_compute_lines = [
+        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
+    ] + [f"    {line}\n" for line in shape_compute_lines]
+    shape_compute_module.define("".join(real_shape_compute_lines))
+
+    return (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    )
diff --git a/MLPY/Lib/site-packages/torch/backends/_nnapi/serializer.py b/MLPY/Lib/site-packages/torch/backends/_nnapi/serializer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e42ab22fa5cd101778e9096903c6a3a25bacc8d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/_nnapi/serializer.py
@@ -0,0 +1,2188 @@
+import array
+import enum
+import functools
+import logging
+import operator
+import struct
+import sys
+from typing import List, NamedTuple, Optional, Tuple
+
+import torch
+
+
+# TODO: Add type annotations
+# TODO: Check tensor types for ops
+
+
+LOG = logging.getLogger("nnapi_serialize")
+
+
+class NNAPI_OperandCode:
+    FLOAT32 = 0
+    INT32 = 1
+    UINT32 = 2
+    TENSOR_FLOAT32 = 3
+    TENSOR_INT32 = 4
+    TENSOR_QUANT8_ASYMM = 5
+    BOOL = 6
+    TENSOR_QUANT16_SYMM = 7
+    TENSOR_FLOAT16 = 8
+    TENSOR_BOOL8 = 9
+    FLOAT16 = 10
+    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
+    TENSOR_QUANT16_ASYMM = 12
+
+
+class NNAPI_OperationCode:
+    ADD = 0
+    AVERAGE_POOL_2D = 1
+    CONCATENATION = 2
+    CONV_2D = 3
+    DEPTHWISE_CONV_2D = 4
+    DEPTH_TO_SPACE = 5
+    DEQUANTIZE = 6
+    EMBEDDING_LOOKUP = 7
+    FLOOR = 8
+    FULLY_CONNECTED = 9
+    HASHTABLE_LOOKUP = 10
+    L2_NORMALIZATION = 11
+    L2_POOL_2D = 12
+    LOCAL_RESPONSE_NORMALIZATION = 13
+    LOGISTIC = 14
+    LSH_PROJECTION = 15
+    LSTM = 16
+    MAX_POOL_2D = 17
+    MUL = 18
+    RELU = 19
+    RELU1 = 20
+    RELU6 = 21
+    RESHAPE = 22
+    RESIZE_BILINEAR = 23
+    RNN = 24
+    SOFTMAX = 25
+    SPACE_TO_DEPTH = 26
+    SVDF = 27
+    TANH = 28
+    BATCH_TO_SPACE_ND = 29
+    DIV = 30
+    MEAN = 31
+    PAD = 32
+    SPACE_TO_BATCH_ND = 33
+    SQUEEZE = 34
+    STRIDED_SLICE = 35
+    SUB = 36
+    TRANSPOSE = 37
+    ABS = 38
+    ARGMAX = 39
+    ARGMIN = 40
+    AXIS_ALIGNED_BBOX_TRANSFORM = 41
+    BIDIRECTIONAL_SEQUENCE_LSTM = 42
+    BIDIRECTIONAL_SEQUENCE_RNN = 43
+    BOX_WITH_NMS_LIMIT = 44
+    CAST = 45
+    CHANNEL_SHUFFLE = 46
+    DETECTION_POSTPROCESSING = 47
+    EQUAL = 48
+    EXP = 49
+    EXPAND_DIMS = 50
+    GATHER = 51
+    GENERATE_PROPOSALS = 52
+    GREATER = 53
+    GREATER_EQUAL = 54
+    GROUPED_CONV_2D = 55
+    HEATMAP_MAX_KEYPOINT = 56
+    INSTANCE_NORMALIZATION = 57
+    LESS = 58
+    LESS_EQUAL = 59
+    LOG = 60
+    LOGICAL_AND = 61
+    LOGICAL_NOT = 62
+    LOGICAL_OR = 63
+    LOG_SOFTMAX = 64
+    MAXIMUM = 65
+    MINIMUM = 66
+    NEG = 67
+    NOT_EQUAL = 68
+    PAD_V2 = 69
+    POW = 70
+    PRELU = 71
+    QUANTIZE = 72
+    QUANTIZED_16BIT_LSTM = 73
+    RANDOM_MULTINOMIAL = 74
+    REDUCE_ALL = 75
+    REDUCE_ANY = 76
+    REDUCE_MAX = 77
+    REDUCE_MIN = 78
+    REDUCE_PROD = 79
+    REDUCE_SUM = 80
+    ROI_ALIGN = 81
+    ROI_POOLING = 82
+    RSQRT = 83
+    SELECT = 84
+    SIN = 85
+    SLICE = 86
+    SPLIT = 87
+    SQRT = 88
+    TILE = 89
+    TOPK_V2 = 90
+    TRANSPOSE_CONV_2D = 91
+    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
+    UNIDIRECTIONAL_SEQUENCE_RNN = 93
+    RESIZE_NEAREST_NEIGHBOR = 94
+
+
+class NNAPI_FuseCode:
+    FUSED_NONE = 0
+    FUSED_RELU = 1
+    FUSED_RELU1 = 2
+    FUSED_RELU6 = 3
+
+
+class OperandValueSourceType:
+    IMMEDIATE = 0
+    NUMBERED_BUFFER = 2
+    NUMBERED_MEMORY = 3
+
+
+# Scalar types that appear explicitly in models.
+# These must be kept in sync with
+# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+# TODO: Expose these directly to Python to avoid maintaining this list.
+class TorchScalarTypes(enum.Enum):
+    QUINT8 = 13
+
+
+def approx_equal(lhs, rhs, tolerance=1e-6):
+    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
+
+
+def tensor_size(op_type, dims):
+    ITEM_SIZES = {
+        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
+        NNAPI_OperandCode.TENSOR_INT32: 4,
+        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
+        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
+        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
+    }
+    size = ITEM_SIZES[op_type]
+    for d in dims:
+        size *= d
+    return size
+
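+
+# Worked example (illustrative, not part of the vendored file): a float32
+# tensor of shape (1, 3, 224, 224) occupies 4 * 3 * 224 * 224 = 602112 bytes.
+def _example_tensor_size():
+    assert tensor_size(NNAPI_OperandCode.TENSOR_FLOAT32, (1, 3, 224, 224)) == 602112
+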
+
+def change_element(tup, index, value):
+    ls = list(tup)
+    ls[index] = value
+    return tuple(ls)
+
+
+class ConvPoolArgs2d(NamedTuple):
+    """Configuration arguments for a convolution."""
+
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_t: int
+    pad_b: int
+    pad_l: int
+    pad_r: int
+    dilation_h: int
+    dilation_w: int
+    group: int
+
+
+class DimOrder(enum.Enum):
+    PRESUMED_CONTIGUOUS = 0
+    CHANNELS_LAST = 1
+    SCALAR_OR_VECTOR = 2
+    UNKNOWN_CONSTANT = 999
+
+
+class Operand(NamedTuple):
+    """Represenation of an NNAPI operand."""
+
+    # NNAPI operand type.  One of NNAPI_OperandCode.
+    # TODO: Make this an enum.
+    op_type: int
+
+    # This is always the PyTorch shape, which is NCHW for feature maps.
+    # The actual NNAPI operand might have a transposed shape.
+    # we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes
+    shape: Tuple[int, ...]
+
+    # Specifies how the shape of the operand that we define in NNAPI
+    # relates to the shape we track above.
+    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
+    #   the shape of the PyTorch tensor.
+    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
+    #   the NNAPI operand will be represented explicitly as NHWC.
+    dim_order: DimOrder
+
+    # Quantization params
+    scale: float
+    zero_point: int
+
+    def use_nchw(self):
+        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+            return True
+        if self.dim_order is DimOrder.CHANNELS_LAST:
+            return False
+        raise Exception("Unknown dim order")
+
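+
+# Illustrative sketch (not part of the vendored file): an Operand for a float32
+# feature map that PyTorch tracks as NCHW (1, 32, 28, 28) but that NNAPI will
+# store as NHWC.  The shape is made up.
+def _example_feature_map_operand():
+    return Operand(
+        op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
+        shape=(1, 32, 28, 28),
+        dim_order=DimOrder.CHANNELS_LAST,
+        scale=0.0,
+        zero_point=0,
+    )
+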
+
+def broadcast_shapes(shape1, shape2):
+    assert len(shape1) > 0
+    assert len(shape2) > 0
+    s1 = list(shape1)
+    s2 = list(shape2)
+    # TODO: Support non-equal-rank broadcast where semantics match.
+    # This can be tricky for NHWC tensors because dimension orders
+    # don't match between PT and NNAPI, even though semantics match.
+    if len(s1) > len(s2):
+        # s2 = [1] * (len(s1) - len(s2)) + s2
+        raise Exception("Non-equal-rank broadcast is not supported yet.")
+    if len(s2) > len(s1):
+        # s3 = [1] * (len(s2) - len(s1)) + s1
+        raise Exception("Non-equal-rank broadcast is not supported yet.")
+    ret = []
+    for d1, d2 in zip(s1, s2):
+        if d1 == 1:
+            ret.append(d2)
+        elif d2 == 1:
+            ret.append(d1)
+        elif d1 == d2:
+            ret.append(d1)
+        else:
+            raise Exception(f"Cannot broadcast shapes: {shape1} and {shape2}")
+    return tuple(ret)
+
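+
+# Worked example (illustrative, not part of the vendored file): equal-rank
+# broadcasting of a per-channel scale against a feature map.
+def _example_broadcast_shapes():
+    assert broadcast_shapes((1, 4, 1, 1), (1, 4, 8, 8)) == (1, 4, 8, 8)
+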
+
+def get_conv_pool_shape(image_shape, args, out_ch, transpose):
+    batch, in_c, in_h, in_w = image_shape
+
+    # TODO: Handle dilation
+    if args.dilation_h != 1 or args.dilation_w != 1:
+        raise Exception("Dilation not supported yet.")
+
+    if transpose:
+        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
+        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
+    else:
+        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
+        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1
+
+    # Handle variable-sized tensors.
+    if in_h == 0:
+        out_h = 0
+    if in_w == 0:
+        out_w = 0
+
+    out_shape = (batch, out_ch, out_h, out_w)
+    return out_shape
+
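+
+# Worked example (illustrative, not part of the vendored file): a 3x3
+# convolution with stride 1 and padding 1 preserves spatial size, so a
+# (1, 8, 32, 32) input with 16 output channels maps to (1, 16, 32, 32).
+def _example_conv_shape():
+    args = ConvPoolArgs2d(
+        kernel_h=3, kernel_w=3,
+        stride_h=1, stride_w=1,
+        pad_t=1, pad_b=1, pad_l=1, pad_r=1,
+        dilation_h=1, dilation_w=1,
+        group=1,
+    )
+    assert get_conv_pool_shape((1, 8, 32, 32), args, 16, False) == (1, 16, 32, 32)
+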
+
+def fix_shape(shape, dim_order):
+    # Return the actual shape that an operand should have in NNAPI,
+    # given a PyTorch shape and dimension order.  This is where we
+    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
+    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+        return shape
+    if dim_order is DimOrder.CHANNELS_LAST:
+        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
+    if dim_order is DimOrder.SCALAR_OR_VECTOR:
+        assert len(shape) == 0 or len(shape) == 1
+        return shape
+    if dim_order is DimOrder.UNKNOWN_CONSTANT:
+        # XXX think this through
+        return shape
+    raise Exception(f"Bad dim_order: {dim_order!r}.")
+
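+
+# Worked example (illustrative, not part of the vendored file): an NCHW shape
+# is reported to NNAPI as NHWC for a CHANNELS_LAST operand.
+def _example_fix_shape():
+    assert fix_shape((1, 4, 8, 8), DimOrder.CHANNELS_LAST) == (1, 8, 8, 4)
+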
+
+def reverse_map_dim(dim_order, d):
+    # Return the original PyTorch dimension position for a given dimension.
+    # d should be the dimension that NNAPI will see.
+    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
+    # reverse_map_dim(CHANNELS_LAST, 3) == 1
+    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
+        return d
+    assert dim_order is DimOrder.CHANNELS_LAST
+    return [0, 2, 3, 1][d]
+
+
+def flex_name(op_id, dim):
+    # Return the local variable name for the computed flexible size
+    # for a given op and dimension.
+    return f"s_{op_id}_{dim}"
+
+
+class _NnapiSerializer:
+    def __init__(self, config, use_int16_for_qint16=False):
+        self.operands = []
+        self.values = []
+        self.operations = []
+        self.value_data = []
+        self.operation_args = []
+        self.inputs = []
+        self.outputs = []
+        self.flexible_shape_computation_lines = []
+
+        self.modules = {}
+        self.constants = {}
+        self.tensor_sequences = {}
+        self.jitval_operand_map = {}
+        self.cached_immediates = {}
+        self.used_weights = []
+        self.weight_offset = 0
+        self.use_int16_for_qint16 = use_int16_for_qint16
+
+        if config is None:
+            config = {}
+
+    def get_next_operand_id(self):
+        return len(self.operands)
+
+    # Add a tensor operand corresponding to a JIT Value.
+    # Returns the NNAPI operand ID.  Can be looked up later with
+    # get_tensor_operand_by_jitval.
+    def add_tensor_operand(self, jitval, oper):
+        assert isinstance(oper, Operand)
+        if jitval in self.jitval_operand_map:
+            raise Exception(f"Duplicate tensor: {jitval!r}")
+
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        self.jitval_operand_map[jitval] = operand_id
+        return operand_id
+
+    # Add a tensor operand that does not correspond to a JIT Value.
+    # Useful for cases where multiple NNAPI operands are required
+    # to implement one JIT IR node.  Returns the NNAPI operand ID.
+    def add_anonymous_tensor_operand(self, oper):
+        assert isinstance(oper, Operand)
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        return operand_id
+
+    def torch_tensor_to_operand(self, tensor, dim_order):
+        dtype = str(tensor.dtype).replace("torch.", "")
+        scale = 0.0
+        zero_point = 0
+        if dtype == "float32":
+            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
+        elif dtype == "int32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+        elif dtype == "quint8":
+            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+        elif dtype == "qint32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+            assert zero_point == 0
+        elif dtype == "int16":
+            if self.use_int16_for_qint16:
+                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
+                op_codes = (
+                    NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+                    NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+                )
+                if nnapi_dtype in op_codes:
+                    op_type = nnapi_dtype
+                    scale = tensor.nnapi_scale
+                    zero_point = tensor.nnapi_zero_point
+                else:
+                    raise Exception(
+                        f"`nnapi_type` needs to be one of {op_codes} for `int16`"
+                    )
+            else:
+                raise Exception(
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
+                )
+        else:
+            raise Exception(f"Can't handle input with dtype '{tensor.dtype}'")
+        return Operand(
+            shape=tuple(tensor.shape),
+            op_type=op_type,
+            dim_order=dim_order,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
+        dim_order = (
+            DimOrder.CHANNELS_LAST
+            if getattr(tensor, "nnapi_nhwc", False)
+            else DimOrder.PRESUMED_CONTIGUOUS
+        )
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = self.add_tensor_operand(jitval, toper)
+        self.inputs.append(operand_id)
+        for dim, size in enumerate(tensor.shape):
+            if size == 0:
+                self.compute_operand_shape(
+                    operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
+                )
+        return operand_id
+
+    def add_tensor_operand_for_weight(
+        self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
+    ):
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = len(self.operands)
+        self.operands.append(toper)
+        tsize = tensor_size(toper.op_type, toper.shape)
+        psize = ((tsize - 1) | 0x3) + 1
+        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
+        buf_num = len(self.used_weights)
+        offset = 0
+        self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
+        # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor
+        if dim_order == DimOrder.CHANNELS_LAST:
+            tensor = tensor.permute(0, 2, 3, 1)
+        self.used_weights.append(tensor)
+        return operand_id
+
+    def add_immediate_operand(self, code, value, dims):
+        assert isinstance(dims, tuple)
+        cache_key = (code, value)
+        if cache_key not in self.cached_immediates:
+            operand_id = len(self.operands)
+            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
+            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
+            self.value_data.append(value)
+            self.cached_immediates[cache_key] = operand_id
+        return self.cached_immediates[cache_key]
+
+    def add_immediate_int_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.INT32, struct.pack("i", value), ()
+        )
+
+    def add_immediate_float_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
+        )
+
+    def add_immediate_bool_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
+        )
+
+    def add_immediate_int_vector(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.TENSOR_INT32,
+            array.array("i", value).tobytes(),
+            (len(value),),
+        )
+
+    def has_operand_for_jitval(self, jitval):
+        return jitval in self.jitval_operand_map
+
+    def get_tensor_operand_by_jitval(self, jitval):
+        operand_id = self.jitval_operand_map[jitval]
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
+        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
+        for s in oper.shape:
+            if s == 0:
+                # TODO: Improve this error message, possibly after converting
+                # many callsites to support flexible size.
+                raise Exception("Flexible size is not supported for this operand.")
+            if s < 0:
+                # runtime flex
+                LOG.warning("Operand %s has runtime flex shape", oper)
+        return op_id, oper
+
+    def get_tensor_operand_or_constant(
+        self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+    ):
+        operand_id = self.jitval_operand_map.get(jitval)
+        if operand_id is None:
+            _, value = self.get_constant_value(jitval, "TensorType")
+            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_for_weight(self, jitval):
+        _, value = self.get_constant_value(jitval, "TensorType")
+        operand_id = self.add_tensor_operand_for_weight(value)
+        return (operand_id, self.operands[operand_id])
+
+    def add_operation(self, opcode, inputs, outputs):
+        self.operations.append((opcode, len(inputs), len(outputs)))
+        self.operation_args.extend(inputs + outputs)
+
+    def add_tensor_sequence(self, jitval, values):
+        assert jitval not in self.tensor_sequences
+        self.tensor_sequences[jitval] = values
+
+    def add_constant_value(self, jitval, ctype, value):
+        assert jitval not in self.constants
+        self.constants[jitval] = (ctype, value)
+
+    def get_constant_value(self, jitval, typekind=None):
+        record = self.constants.get(jitval)
+        if record is None:
+            raise Exception(f"Could not find constant value for '{jitval!r}'.")
+        ctype, _ = record
+        if typekind is not None and ctype.kind() != typekind:
+            raise Exception(
+                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
+            )
+        return record
+
+    def operand_to_template_torchscript(self, op_id, oper, shape=None):
+        """Return a TorchScript expression to build a template for a given operand."""
+        if shape is None:
+            shape = oper.shape
+        else:
+            assert len(shape) == len(oper.shape)
+
+        shape_parts = ["("]
+        for d, s in enumerate(shape):
+            if s > 0:
+                # Fixed shape dimension: just add the value.
+                shape_parts.append(str(s))
+            elif s == 0:
+                # Load time flexible shape dimension: it should have been computed in a variable.
+                shape_parts.append(flex_name(op_id, d))
+            elif s == -1:
+                # Runtime flexible shape
+                shape_parts.append("0")
+            else:
+                raise Exception("Unknown dim value, dimensions should be >= -1")
+            shape_parts.append(",")
+        shape_parts.append(")")
+        shape_code = "".join(shape_parts)
+        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            return f"torch.zeros({shape_code}, dtype=torch.float32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
+            return f"torch.zeros({shape_code}, dtype=torch.int32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            return (
+                f"torch.quantize_per_tensor("
+                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
+                f".expand({shape_code}).contiguous()"
+            )
+        elif oper.op_type in (
+            NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+            NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+        ):
+            if self.use_int16_for_qint16:
+                return f"torch.zeros({shape_code}, dtype=torch.int16)"
+            else:
+                raise Exception(
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
+                )
+
+        raise Exception(f"Unsupported output operand type: {oper.op_type}")
+
+    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
+        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))
+
+    def compute_operand_shape(self, op_id, dim, expr):
+        self.flexible_shape_computation_lines.append(
+            f"{flex_name(op_id, dim)} = {expr}"
+        )
+
+    def transpose_to_nhwc(self, in_id, oper):
+        if oper.shape[2:] != (1, 1):
+            raise Exception("Automatic transpose only supported for H,W == 1,1")
+
+        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
+
+        outputs = [None] * 1
+        outputs[0] = self.add_anonymous_tensor_operand(out_oper)
+
+        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
+
+        return outputs[0], out_oper
+
+    # Transpose inputs as necessary to allow broadcasting.
+    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
+        if in0_oper.dim_order == in1_oper.dim_order:
+            return in0_id, in0_oper, in1_id, in1_oper
+
+        # Assume NHWC is preferred if there is a mismatch.
+        orders = (in0_oper.dim_order, in1_oper.dim_order)
+        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
+            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
+            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
+        raise Exception(
+            f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
+        )
+
+    def get_size_arg(self, jitval):
+        ctype, value = self.get_constant_value(jitval)
+        if ctype.kind() == "ListType":
+            assert ctype.getElementType().kind() == "IntType"
+            return value
+        raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'")
+
+    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
+        pc = [i.item() for i in packed_config]
+        assert pc[0] == 2
+        strides = [pc[1], pc[2]]
+        paddings = [pc[3], pc[4]]
+        dilations = [pc[5], pc[6]]
+        output_padding = [pc[7], pc[8]]
+        group_num = pc[9]
+
+        assert len(pc) == 11
+        assert output_padding == [0, 0]
+
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_from_jit(
+        self, kernel_size, stride, padding, dilation=None, group=None
+    ):
+        strides = self.get_size_arg(stride)
+        paddings = self.get_size_arg(padding)
+        if dilation is None:
+            dilations = [1, 1]
+        else:
+            dilations = self.get_size_arg(dilation)
+        if group is not None:
+            _, group_num = self.get_constant_value(group, "IntType")
+        else:
+            group_num = None
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_common(
+        self, kernel_size, strides, paddings, dilations, group_num
+    ):
+        kernels = list(kernel_size)
+
+        assert len(kernels) == 2
+        assert len(strides) == 2
+        assert len(paddings) == 2
+        assert len(dilations) == 2
+
+        # NNAPI uses 4 values for padding.
+        ph, pw = paddings
+        real_paddings = [ph, ph, pw, pw]
+
+        return ConvPoolArgs2d(
+            *(kernels + strides + real_paddings + dilations + [group_num])
+        )
+
+    def serialize_model(self, model, inputs, return_shapes=None):
+        self.add_immediate_bool_scalar(False)
+        self.add_immediate_bool_scalar(True)
+
+        inp_dim_orders = []
+        out_dim_orders = []
+
+        self_jitval = next(model.graph.inputs())
+        self.add_constant_value(self_jitval, self_jitval.type(), model)
+
+        for arg_idx, (input_value, input_tensor) in enumerate(
+            zip(list(model.graph.inputs())[1:], inputs)
+        ):
+            op_id = self.add_tensor_operand_for_input(
+                arg_idx, input_value, input_tensor
+            )
+            inp_dim_orders.append(self.operands[op_id].dim_order.value)
+
+        for idx, node in enumerate(model.graph.nodes()):
+            LOG.debug("Processing node #%d: %r", idx, node)
+            self.add_node(node)
+
+        retn = model.graph.return_node()
+        assert retn.inputsSize() == 1
+        assert retn.outputsSize() == 0
+        retn_input = retn.inputsAt(0)
+        template_return_lines = ["return ["]
+        if retn_input.type().kind() == "TensorType":
+            return_values = [retn_input]
+            retval_count = -1
+        elif retn_input.type().kind() == "TupleType":
+            return_values = self.tensor_sequences[retn_input]
+            retval_count = len(return_values)
+        else:
+            raise Exception(f"Unsupported return type: {retn_input.type()}")
+
+        if return_shapes is not None:
+            assert len(return_shapes) == len(return_values)
+        for i, v in enumerate(return_values):
+            op_id = self.jitval_operand_map[v]
+            self.outputs.append(op_id)
+            out_dim_orders.append(self.operands[op_id].dim_order.value)
+            shape = return_shapes[i] if return_shapes else None
+            template_return_lines.append(
+                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
+                + ","
+            )
+        template_return_lines.append("]")
+
+        model = []
+
+        version = 1
+        header = struct.pack(
+            "iiiiii",
+            version,
+            len(self.operands),
+            len(self.values),
+            len(self.operations),
+            len(self.inputs),
+            len(self.outputs),
+        )
+        model.append(header)
+
+        serialized_values, serialized_value_data = self.serialize_values()
+
+        model.extend(
+            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
+        )
+        model.extend(serialized_values)
+        model.extend(struct.pack("iii", *x) for x in self.operations)
+
+        # Compact the model so we can get its length so far.
+        model = [b"".join(model)]
+        model_offset = len(model[0])
+        # Model offset is the index into the model (in 32-bit words, not bytes)
+        # of the next dimension we're about to serialize.  If that dimension is
+        # flexible (0), generate code to patch it in before passing to NNAPI.
+        assert model_offset % 4 == 0
+        model_offset = int(model_offset / 4)
+
+        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
+            shape = fix_shape(dims, dim_order)
+            for d, s in enumerate(shape):
+                if s == 0:
+                    pt_d = reverse_map_dim(dim_order, d)
+                    self.flexible_shape_computation_lines.append(
+                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
+                    )
+                model_offset += 1
+
+            # convert runtime flex shape from -1 to 0
+            shape = tuple(d if d != -1 else 0 for d in shape)
+            model.append(self.serialize_ints(shape))
+
+        model.extend(serialized_value_data)
+        model.append(self.serialize_ints(self.operation_args))
+        model.append(self.serialize_ints(self.inputs))
+        model.append(self.serialize_ints(self.outputs))
+
+        self.flexible_shape_computation_lines.extend(template_return_lines)
+
+        return (
+            array.array("i", b"".join(model)),
+            self.used_weights,
+            inp_dim_orders,
+            out_dim_orders,
+            self.flexible_shape_computation_lines,
+            retval_count,
+        )
+
+    def serialize_values(self):
+        serialized_values = []
+        serialized_value_data = []
+        assert len(self.values) == len(self.value_data)
+        for (op_index, source_type), data in zip(self.values, self.value_data):
+            source_length = len(data)
+
+            # Pad with 0 bytes out to a multiple of 4 for alignment.
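+            # e.g. 1-4 bytes pad to 4, 5-8 bytes pad to 8.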
+            physical_length = ((source_length - 1) | 0x3) + 1
+            padded_data = data + (b"\0" * (physical_length - source_length))
+
+            serialized_values.append(
+                struct.pack("iii", op_index, source_type, source_length)
+            )
+            serialized_value_data.append(padded_data)
+
+        return serialized_values, serialized_value_data
+
+    @staticmethod
+    def serialize_ints(ints):
+        return array.array("i", ints).tobytes()
+
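+    # Dispatch table from TorchScript node kind to the handler that serializes it.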
+    ADDER_MAP = {
+        "prim::GetAttr": lambda self, node: self.add_getattr(node),
+        "prim::Constant": lambda self, node: self.add_constant_node(node),
+        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
+        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
+        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
+        "aten::to": lambda self, node: self.add_to(node),
+        "aten::detach": lambda self, node: self._identity(node),
+        "aten::reshape": lambda self, node: self.add_reshape(node),
+        "aten::flatten": lambda self, node: self.add_flatten(node),
+        "aten::slice": lambda self, node: self.add_slice(node),
+        "aten::size": lambda self, node: self.add_size(node),
+        "aten::cat": lambda self, node: self.add_cat(node),
+        "aten::mean": lambda self, node: self.add_mean(node),
+        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
+        "aten::dequantize": lambda self, node: self.add_dequantize(node),
+        "aten::add": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::sub": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.RELU
+        ),
+        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.LOGISTIC
+        ),
+        "aten::softmax": lambda self, node: self.add_softmax(node),
+        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
+        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
+        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
+            node, NNAPI_OperationCode.MAX_POOL_2D
+        ),
+        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
+            node
+        ),
+        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
+            node
+        ),
+        "aten::prelu": lambda self, node: self.add_prelu_op(node),
+        "aten::addmm": lambda self, node: self.add_addmm(node),
+        "aten::linear": lambda self, node: self.add_linear(node),
+        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
+        "aten::conv2d": lambda self, node: self.add_conv2d(node),
+        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
+        "quantized::linear": lambda self, node: self.add_qlinear(node),
+        "quantized::conv2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
+        ),
+        "quantized::add": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::add_relu": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::mul": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+    }
+
+    def add_node(self, node):
+        adder = self.ADDER_MAP.get(node.kind())
+        if not adder:
+            raise Exception(f"Unsupported node kind ({node.kind()!r}) in node {node!r}")
+        adder(self, node)
+
+    def _identity(self, node):
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        jitval = node.outputsAt(0)
+        self.jitval_operand_map[jitval] = in_id
+
+    def add_getattr(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
+        assert str(obj_ctype).startswith("__torch__.")
+        name = node.s("name")
+        value = getattr(obj, name)
+        output = node.outputsAt(0)
+        ctype = output.type()
+        self.add_constant_value(output, ctype, value)
+
+    def add_constant_node(self, node):
+        assert node.inputsSize() == 0
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        value = output.toIValue()
+        self.add_constant_value(output, ctype, value)
+
+    def add_list_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        const_vals: Optional[List] = []
+        tensors: Optional[List] = []
+        for inp in node.inputs():
+            if const_vals is not None and inp in self.constants:
+                _, val = self.get_constant_value(inp)
+                const_vals.append(val)
+            else:
+                const_vals = None
+            if tensors is not None and inp.type().kind() == "TensorType":
+                tensors.append(inp)
+            else:
+                tensors = None
+
+        if const_vals is not None:
+            # NOTE: Now that TorchScript supports list constants,
+            # this code path might not be used anymore.
+            self.add_constant_value(output, ctype, const_vals)
+        if tensors is not None:
+            self.add_tensor_sequence(output, tensors)
+        if const_vals is None and tensors is None:
+            raise Exception(
+                f"Unable to handle ListConstruct node.  Neither all constants nor all tensors. {node!r}"
+            )
+
+    def add_tuple_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        values = list(node.inputs())
+        self.add_tensor_sequence(output, values)
+
+    def add_unsqueeze(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS
+
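+        # A negative dim counts from the end of the output shape, which has one extra dim.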
+        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
+        out_shape_list = list(in_oper.shape)
+        out_shape_list.insert(real_dim, 1)
+        out_shape = tuple(out_shape_list)
+        out_oper = in_oper._replace(shape=out_shape)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)
+
+    def add_to(self, node):
+        # Handle to("cpu") / to("gpu") case
+        self._identity(node)
+
+    def add_reshape(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
+        assert shape_ctype.kind() == "ListType"
+        assert shape_ctype.getElementType().kind() == "IntType"
+        is_trivial_reshape = len(shape) == 2 and shape[1] == -1
+
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
+            raise Exception(
+                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
+            )
+
+        # Bit of a hack here.  Use a real tensor to infer the output shape.
+        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(shape)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_flatten(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
+
+        # channels last with channels == 1 or (height & width both 1)
+        is_trivial_flatten = len(in_oper.shape) == 4 and (
+            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
+        )
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
+            raise Exception(
+                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
+            )
+
+        if start_dim < 0:
+            start_dim += len(in_oper.shape)
+        if end_dim < 0:
+            end_dim += len(in_oper.shape)
+
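+        # Collapse dims [start_dim, end_dim] into a single dim equal to their product.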
+        out_shape = (
+            in_oper.shape[:start_dim]
+            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
+            + in_oper.shape[end_dim + 1 :]
+        )
+
+        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
+            raise Exception("Flattening flexible dims is not supported yet")
+        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
+        if non_flattened_dims.count(0) > 1:
+            raise Exception("Only 1 dim can be flexible")
+
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
+
+        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(inputs_1)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_slice(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        _, dim_value = self.get_constant_value(node.inputsAt(1))
+        _, start_value = self.get_constant_value(node.inputsAt(2))
+        _, stop_value = self.get_constant_value(node.inputsAt(3))
+        _, step_value = self.get_constant_value(node.inputsAt(4))
+
+        if start_value is None:
+            start_value = 0
+        if stop_value is None:
+            stop_value = sys.maxsize
+
+        if start_value < 0:
+            start_value += in_oper.shape[dim_value]
+        elif start_value == sys.maxsize:
+            start_value = 0
+
+        if start_value == 0 and stop_value == sys.maxsize:
+            self._identity(node)
+            return
+
+        if in_oper.shape[dim_value] == 0:
+            raise Exception("Unable to slice with flexible shape")
+
+        if stop_value < 0:
+            stop_value += in_oper.shape[dim_value]
+        elif stop_value == sys.maxsize:
+            stop_value = in_oper.shape[dim_value]
+
+        if start_value >= stop_value:
+            raise Exception("Slice start value should be less than stop value")
+
+        out_len = (stop_value - start_value) // step_value
+        out_shape = tuple(
+            out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
+        )
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), in_oper._replace(shape=out_shape)
+        )
+
+        # flex inputs
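+        # For flexible (size-0) dims, forward the input's runtime size and set the
+        # corresponding end_mask bit so the slice runs to the end of that dim.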
+        end_mask = 0
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+                end_mask |= 1 << idx
+
+        inputs = [None] * 7
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(
+            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
+        )
+        inputs[2] = self.add_immediate_int_vector(
+            [
+                stop_value if i == dim_value else dim
+                for i, dim in enumerate(in_oper.shape)
+            ]
+        )
+        inputs[3] = self.add_immediate_int_vector(
+            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
+        )
+        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
+        inputs[5] = self.add_immediate_int_scalar(end_mask)
+        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)
+
+    def add_size(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, value = self.constants[node.inputsAt(1)]
+        res = in_oper.shape[value]
+        output = node.outputsAt(0)
+        self.add_constant_value(output, output.type(), res)
+
+    def add_cat(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        tensors = self.tensor_sequences[node.inputsAt(0)]
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        assert len(tensors) > 0
+        in_ids = []
+        out_oper = None
+        out_dim_size = 0
+        for inp in tensors:
+            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
+            if out_oper is None:
+                out_shape = change_element(in_oper.shape, dim, -1)
+                out_oper = in_oper._replace(shape=out_shape)
+            assert in_oper.op_type == out_oper.op_type
+            assert in_oper.dim_order == out_oper.dim_order
+            assert change_element(in_oper.shape, dim, -1) == change_element(
+                out_oper.shape, dim, -1
+            )
+            # TODO: Possibly check scale and zero point.
+            in_ids.append(in_id)
+            # TODO: Possibly support variable-sized inputs.
+            out_dim_size += in_oper.shape[dim]
+
+        assert out_oper is not None
+        out_oper = out_oper._replace(
+            shape=change_element(out_oper.shape, dim, out_dim_size)
+        )
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
+            assert len(out_oper.shape) == 4
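+            # Map the PyTorch NCHW dim index onto NNAPI's NHWC layout.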
+            nnapi_dim = [0, 3, 1, 2][dim]
+        else:
+            nnapi_dim = dim
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+        for idx, d in enumerate(out_oper.shape):
+            if d == 0:
+                if idx == dim:
+                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
+                    self.compute_operand_shape(out_id, idx, shape)
+                else:
+                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)
+
+        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)
+
+    def add_mean(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
+        assert dim_ctype.kind() == "ListType"
+        assert dim_ctype.getElementType().kind() == "IntType"
+        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
+        # Expect None for dtype
+        self.get_constant_value(node.inputsAt(3), "NoneType")
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
+            assert len(in_oper.shape) == 4
+            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
+        else:
+            nnapi_dim = dim
+
+        collapsed_dims = set()
+        for d in dim:
+            if d < 0:
+                d += len(in_oper.shape)
+            collapsed_dims.add(d)
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
+            assert collapsed_dims.issuperset({2, 3})
+            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
+        else:
+            out_dim_order = in_oper.dim_order
+
+        out_shape = []
+        for i, s in enumerate(in_oper.shape):
+            if i not in collapsed_dims:
+                out_shape.append(s)
+            elif keep_dim:
+                out_shape.append(1)
+
+        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
+        inputs[2] = self.add_immediate_int_scalar(keep_dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
+
+    def add_quantize(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
+            raise Exception(
+                "Most hardware backends prefer NHWC quantized tensors.  "
+                "Try setting `t.nnapi_nhwc = True` on your tensor inputs.  "
+            )
+        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
+        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
+        if scalar_type != TorchScalarTypes.QUINT8.value:
+            raise Exception(
+                "PyTorch NNAPI export only supports quantized tensors "
+                "with the quint8 dtype."
+            )
+        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+
+        out_oper = in_oper._replace(
+            op_type=op_type,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)
+
+    def add_dequantize(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        out_oper = in_oper._replace(
+            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
+            scale=0.0,
+            zero_point=0,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)
+
+    def add_pointwise_simple_unary_op(self, node, opcode):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        out_oper = in_oper
+        if opcode == NNAPI_OperationCode.LOGISTIC:
+            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
+            # must be 1.f / 256 and the zeroPoint must be 0.
+            # https://fburl.com/h52stoog
+            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(in_oper.shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
+        """Helper for pointwise binary broadcast ops with superfluous extra args."""
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        if self.has_operand_for_jitval(node.inputsAt(0)):
+            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+            in1_id, in1_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(1), in0_oper.dim_order
+            )
+        elif self.has_operand_for_jitval(node.inputsAt(1)):
+            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
+            in0_id, in0_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(0), in1_oper.dim_order
+            )
+        else:
+            raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants")
+
+        assert in0_oper.op_type == in1_oper.op_type
+        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
+            in0_id, in0_oper, in1_id, in1_oper
+        )
+        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
+        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
+        out_oper = in0_oper._replace(shape=out_shape)
+        if qparams is not None:
+            scale, zp = qparams
+            out_oper = out_oper._replace(scale=scale, zero_point=zp)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
+            if d0 == 1 and d1 == 0:
+                self.forward_operand_shape(out_id, idx, in1_id, idx)
+            elif d0 == 0 and d1 == 1:
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+            elif d0 == 0 and d1 == 0:
+                self.flexible_shape_computation_lines.append(
+                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
+                )
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+
+        inputs = [None] * 3
+        inputs[0] = in0_id
+        inputs[1] = in1_id
+        inputs[2] = self.add_immediate_int_scalar(fuse_code)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 2
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_add_sub_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 3
+
+        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
+        if alpha != 1:
+            raise Exception("NNAPI does not support add/sub with alpha.")
+
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_qadd(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 4
+
+        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
+
+        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
+
+    def add_softmax(self, node):
+        assert node.inputsSize() == 3
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size == 0:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_float_scalar(
+            1.0
+        )  # positive scaling factor of exponent, beta
+        inputs[2] = self.add_immediate_int_scalar(softmax_dim)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)
+
+    def add_hardtanh(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")
+
+        op_map = {
+            (-1, 1): NNAPI_OperationCode.RELU1,
+            (0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
+        }
+
+        opcode = op_map.get((min_val, max_val))
+        if opcode is None:
+            raise Exception("NNAPI only supports hardtanh with args (-1, 1) or (0, 6).")
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_prelu_op(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
+        assert len(w_oper.shape) == 1
+        assert w_oper.shape[0] > 0
+        if w_oper.shape[0] > 1:
+            if in_oper.use_nchw():
+                # TODO: Support this by adding trailing 1 dims.
+                raise Exception(
+                    "Per-channel PReLU only supports channels_last right now."
+                )
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size > 0:
+                pass
+            elif dim <= 1:
+                raise Exception("PReLU requires fixed size for dim 0 and dim 1.")
+            else:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = w_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
+
+    def add_pool2d_node(self, node, opcode):
+        assert node.inputsSize() == 6
+        assert node.outputsSize() == 1
+        image, kernel, stride, padding, dilation, ceil_mode = node.inputs()
+
+        stride = stride or kernel
+
+        # TODO: Validate ceil_mode semantics.
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding, dilation
+        )
+        if args.dilation_h != 1 or args.dilation_w != 1:
+            raise Exception("NNAPI does not support dilated pooling.")
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
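+        # Eleven inputs: image, pad l/r/t/b, stride w/h, kernel w/h, fuse code, NCHW layout flag.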
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_avg_pool2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+        (
+            image,
+            kernel,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+        ) = node.inputs()
+
+        _, count_include_pad_value = self.get_constant_value(count_include_pad)
+        _, divisor_override_value = self.get_constant_value(divisor_override)
+        if not count_include_pad_value or divisor_override_value:
+            raise Exception(
+                "NNAPI doesn't support count_include_pad=False or divisor_override"
+            )
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding
+        )
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+        self._handle_conv_pool_flexible_input(out_id, image, args, False)
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_adaptive_avg_pool2d(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
+            node.inputsAt(0)
+        )
+        assert len(image_oper.shape) == 4
+
+        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
+        assert size_ctype.kind() == "ListType"
+        assert size_ctype.getElementType().kind() == "IntType"
+        if size_arg != [1, 1]:
+            raise Exception(
+                "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
+            )
+
+        out_shape = image_oper.shape[0:2] + tuple(size_arg)
+        use_nchw = image_oper.use_nchw()
+
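+        # Lower adaptive (1, 1) pooling to a plain average pool whose kernel covers the full H and W.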
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(0)
+        inputs[2] = self.add_immediate_int_scalar(0)
+        inputs[3] = self.add_immediate_int_scalar(0)
+        inputs[4] = self.add_immediate_int_scalar(0)
+        inputs[5] = self.add_immediate_int_scalar(1)
+        inputs[6] = self.add_immediate_int_scalar(1)
+        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
+        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_upsample_nearest2d(self, node):
+        assert node.inputsSize() == 3 or node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        if node.inputsSize() == 3:
+            image, size_jit, scale_jit = node.inputs()
+        else:
+            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
+        size_ctype, size_arg = self.get_constant_value(size_jit)
+
+        if node.inputsSize() == 3:
+            scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
+        else:
+            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
+            scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]
+
+            # The only way for the 4-argument overload of upsample_nearest2d to
+            # have been added to the graph without error is if the scale_h and
+            # scale_w arguments are None
+            assert scale_h_ctype.kind() == "NoneType"
+            assert scale_w_ctype.kind() == "NoneType"
+
+            scale_ctype = scale_h_ctype
+            scale_arg = scale_h_arg
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
+            raise Exception("Size and scale cannot both be non-None.")
+        elif size_ctype.kind() != "NoneType":
+            assert size_ctype.kind() == "ListType"
+            assert size_ctype.getElementType().kind() == "IntType"
+            assert scale_ctype.kind() == "NoneType"
+            assert scale_arg is None
+            assert isinstance(size_arg, list)
+            assert size_arg
+            assert all(isinstance(val, int) for val in size_arg)
+            if len(size_arg) == 1:
+                size_arg = size_arg * 2
+            assert len(size_arg) == 2
+            out_h = size_arg[0]
+            out_w = size_arg[1]
+            arg_h = self.add_immediate_int_scalar(out_h)
+            arg_w = self.add_immediate_int_scalar(out_w)
+        elif scale_ctype.kind() != "NoneType":
+            assert scale_ctype.kind() == "ListType"
+            assert scale_ctype.getElementType().kind() == "FloatType"
+            assert size_ctype.kind() == "NoneType"
+            assert size_arg is None
+            assert isinstance(scale_arg, list)
+            assert scale_arg
+            assert all(isinstance(val, float) for val in scale_arg)
+            if len(scale_arg) == 1:
+                scale_arg = scale_arg * 2
+            assert len(scale_arg) == 2
+            out_h = int(scale_arg[0] * image_oper.shape[2])
+            out_w = int(scale_arg[1] * image_oper.shape[3])
+            arg_h = self.add_immediate_float_scalar(scale_arg[0])
+            arg_w = self.add_immediate_float_scalar(scale_arg[1])
+        else:
+            raise Exception("Size and scale cannot both be None.")
+
+        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
+        use_nchw = image_oper.use_nchw()
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
+            raise Exception("Flexible batch or channels not supported")
+
+        # Handle variable input size
+        for dim in (2, 3):  # h, w indices
+            if image_oper.shape[dim] == 0:
+                if size_ctype.kind() != "NoneType":
+                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
+                elif scale_ctype.kind() != "NoneType":
+                    self.compute_operand_shape(
+                        out_id,
+                        dim,
+                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
+                    )
+                else:
+                    raise Exception("Size and scale cannot both be None.")
+
+        inputs = [None] * 4
+        inputs[0] = image_id
+        inputs[1] = arg_w
+        inputs[2] = arg_h
+        inputs[3] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
+
+    def add_addmm(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()
+
+        for jitval in (jit_beta, jit_alpha):
+            scale_ctype, scale_value = self.get_constant_value(jitval)
+            assert scale_ctype.kind() in ("IntType", "FloatType")
+            if scale_value != 1:
+                raise Exception(
+                    "NNAPI Fully-Connected does not support alpha and beta."
+                )
+
+        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)
+
+    def add_linear(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+        jit_input, jit_weight, jit_bias = node.inputs()
+
+        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)
+
+    def add_addmm_or_linear(
+        self, node, transpose_weight, jit_input, jit_weight, jit_bias
+    ):
+        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
+        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)
+
+        assert len(input_oper.shape) == 2
+        assert len(bias_oper.shape) == 1
+
+        # TODO: Transform at load time to share weights with CPU model.
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        assert len(weight_tensor.shape) == 2
+        if transpose_weight:
+            nnapi_weight_tensor = weight_tensor.t().contiguous()
+        else:
+            nnapi_weight_tensor = weight_tensor.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+
+        if input_oper.shape[0] == 0:
+            self.forward_operand_shape(out_id, 0, input_id, 0)
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def add_qlinear(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        (
+            jit_input,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        # TODO: Support automatic reshape
+        assert len(input_oper.shape) == 2
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "LinearPackedParamsBase"
+        raw_weight, raw_bias = packed_weight.__getstate__()[0]
+        assert raw_bias is not None
+
+        assert len(raw_weight.shape) == 2
+        assert len(raw_bias.shape) == 1
+        assert raw_bias.shape[0] == raw_weight.shape[0]
+        assert raw_weight.shape[1] == input_oper.shape[1]
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
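+            # NNAPI uses unsigned asymmetric quantization, so shift the qint8 values
+            # and zero point by +128 to re-express the weight as quint8.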
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        bias_scale = input_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = input_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        # TODO: Transform at load time to share weights with CPU model.
+        nnapi_weight_tensor = unsigned_weight.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_oper = input_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
+        ctype, value = self.get_constant_value(jit_bias)
+        if ctype.kind() == "NoneType":
+            bias_idx = 1 if transpose else 0
+            nnapi_bias_tensor = torch.zeros(
+                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
+            )
+            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
+            bias_oper = self.operands[bias_id]
+            return bias_id, bias_oper
+        else:
+            return self.get_tensor_operand_for_weight(jit_bias)
+
+    def add_conv2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_groups,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            False,  # transpose
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_conv_underscore(self, node):
+        assert node.inputsSize() == 13
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_transpose,
+            _,
+            jit_groups,
+            _,
+            _,
+            _,
+            _,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        _, transpose = self.get_constant_value(jit_transpose)
+        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            transpose,
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_log_softmax(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        (jit_input, jit_dim, jit_half_to_float) = node.inputs()
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        _, dim = self.get_constant_value(jit_dim, "IntType")
+
+        out_shape = input_oper.shape
+
+        inputs = [None] * 3
+        inputs[0] = input_id
+        # specifying 1 as the scaling factor for the exponent, beta
+        inputs[1] = self.add_immediate_float_scalar(1)
+        inputs[2] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
+
+    def add_qconv2d(self, node, fuse_code, transpose=False):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "Conv2dPackedParamsBase"
+        (
+            pack_version,
+            tensors,
+            opt_tensors,
+        ) = packed_weight.__getstate__()[0]
+        assert pack_version == "2"
+        packed_config, raw_weight = tensors
+        (raw_bias,) = opt_tensors
+        assert raw_bias is not None
+        args = self.get_conv_pool_args_2d_from_pack(
+            raw_weight.shape[2:4], packed_config
+        )
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
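+            # As in add_qlinear: re-express the signed qint8 weight as quint8 by
+            # shifting the values and zero point by +128.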
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        bias_scale = image_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = image_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            out_scale,
+            out_zero_point,
+            jit_image,
+            unsigned_weight,
+            bias_id,
+            args,
+            transpose,
+            fuse_code,
+        )
+
+    def add_conv2d_common(
+        self,
+        jit_out,
+        out_scale,
+        out_zero_point,
+        jit_image,
+        weight_tensor,
+        bias_id,
+        args,
+        transpose,
+        fuse_code,
+    ):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        in_c = image_oper.shape[1]
+
+        if args.group == 1:
+            # Full convolution
+            depthwise = False
+            if transpose:
+                weight_permutation = (1, 2, 3, 0)
+            else:
+                weight_permutation = (0, 2, 3, 1)
+        elif args.group == in_c:
+            # Depthwise convolution
+            depthwise = True
+            weight_permutation = (1, 2, 3, 0)
+        else:
+            raise Exception("Group convolution not supported yet.")
+
+        # TODO: Transform at load time to share weights with CPU model.
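+        # Permute the PyTorch OIHW weight into the NNAPI layout selected above.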
+        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        bias_oper = self.operands[bias_id]
+
+        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
+            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
+            assert bias_oper.zero_point == 0
+        else:
+            raise Exception(f"Unsupported input type for conv2d: {image_oper.op_type}")
+
+        assert len(image_oper.shape) == 4
+        assert len(weight_oper.shape) == 4
+        assert len(bias_oper.shape) == 1
+
+        if depthwise:
+            # Depthwise convolution
+            one, kern_h, kern_w, out_c = weight_oper.shape
+            assert one == 1
+            assert out_c % in_c == 0
+            channel_multiplier = out_c // in_c
+            assert channel_multiplier == 1  # Don't support multiplier
+            assert out_c == in_c
+        else:
+            # Full convolution
+            out_c, kern_h, kern_w, kern_d = weight_oper.shape
+            assert kern_d == in_c
+
+        assert out_c == bias_oper.shape[0]
+
+        use_nchw = image_oper.use_nchw()
+
+        if depthwise:
+            num_args = 12
+            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
+        else:
+            num_args = 11
+            if transpose:
+                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
+            else:
+                opcode = NNAPI_OperationCode.CONV_2D
+
+        inputs = [None] * num_args
+        inputs[0] = image_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
+        if depthwise:
+            inputs[9] = self.add_immediate_int_scalar(1)
+            inputs[10] = self.add_immediate_int_scalar(fuse_code)
+            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
+        else:
+            inputs[9] = self.add_immediate_int_scalar(fuse_code)
+            inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
+        out_oper = image_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+        out_id = self.add_tensor_operand(jit_out, out_oper)
+        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)
+
+        outputs[0] = out_id
+        self.add_operation(opcode, inputs, outputs)
+
+    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        batch, in_ch, in_h, in_w = image_oper.shape
+
+        if batch == 0:
+            self.forward_operand_shape(out_id, 0, image_id, 0)
+        if in_ch == 0:
+            raise Exception("Input channels can't be flexible")
+        # H & W
+        if transpose:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
+                )
+        else:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
+                )
+
+
+def serialize_model(
+    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
+):
+    """Convert to NNAPI and serialize torchscript module.
+
+    Parameters:
+        module: Torchscript module to convert
+        inputs: Tensors used to specify input details for NNAPI
+        config (optional): Optional config to attach to module
+        return_shapes (optional): Specify the output shapes if
+            your module uses runtime-flexible shapes, so the output
+            buffer sizes can be set for NNAPI
+        use_int16_for_qint16 (optional): Use PyTorch int16 to represent NNAPI qint16 values
+    """
+    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
+        module, inputs, return_shapes
+    )
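As a usage sketch only: assuming this serializer is importable as `torch.backends._nnapi.serializer` (the file path is not shown in this hunk) and that the scripted module uses only NNAPI-supported ops, `serialize_model` is driven roughly as below; the exact layout of the returned value is defined by `_NnapiSerializer.serialize_model`, which this hunk does not show.

import torch
from torch.backends._nnapi.serializer import serialize_model  # assumed import path

class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = torch.jit.script(TinyConv().eval())
# The example input fixes dtype, shape, and memory format for NNAPI.
example = torch.zeros(1, 3, 224, 224)
result = serialize_model(model, [example])  # whether a module converts depends on NNAPI op coverage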
diff --git a/MLPY/Lib/site-packages/torch/backends/cpu/__init__.py b/MLPY/Lib/site-packages/torch/backends/cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..923fbd6401673d7569efc5ddb2f4edd9101b7860
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/cpu/__init__.py
@@ -0,0 +1,19 @@
+import torch
+
+__all__ = [
+    "get_cpu_capability",
+]
+
+
+def get_cpu_capability() -> str:
+    r"""Return cpu capability as a string value.
+
+    Possible values:
+    - "DEFAULT"
+    - "VSX"
+    - "Z VECTOR"
+    - "NO AVX"
+    - "AVX2"
+    - "AVX512"
+    """
+    return torch._C._get_cpu_capability()
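As a usage sketch (assuming this module is importable as ``torch.backends.cpu``, per the path above), the capability string can be queried and branched on directly:

import torch

cap = torch.backends.cpu.get_cpu_capability()
print(f"Detected CPU capability: {cap}")
# Branch on the reported ISA level, e.g. to choose a kernel variant.
has_wide_simd = cap in ("AVX2", "AVX512")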
diff --git a/MLPY/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba1f81ddc28b41a7035c7550f5a7f3ef6b766234
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/cuda/__init__.py b/MLPY/Lib/site-packages/torch/backends/cuda/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29a25146a9b88243b4ded4b648b11725fb9a272
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/cuda/__init__.py
@@ -0,0 +1,371 @@
+import contextlib
+import warnings
+
+from typing import Union
+
+import torch
+
+__all__ = [
+    "is_built",
+    "cuFFTPlanCacheAttrContextProp",
+    "cuFFTPlanCache",
+    "cuFFTPlanCacheManager",
+    "cuBLASModule",
+    "preferred_linalg_library",
+    "cufft_plan_cache",
+    "matmul",
+    "SDPBackend",
+    "SDPAParams",
+    "enable_cudnn_sdp",
+    "cudnn_sdp_enabled",
+    "enable_flash_sdp",
+    "flash_sdp_enabled",
+    "enable_mem_efficient_sdp",
+    "mem_efficient_sdp_enabled",
+    "math_sdp_enabled",
+    "enable_math_sdp",
+    "can_use_flash_attention",
+    "can_use_efficient_attention",
+    "sdp_kernel",
+]
+
+
+def is_built():
+    r"""
+    Return whether PyTorch is built with CUDA support.
+
+    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
+    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
+    """
+    return torch._C._has_cuda
+
+
+class cuFFTPlanCacheAttrContextProp:
+    # Like regular ContextProp, but uses the `.device_index` attribute from the
+    # calling object as the first argument to the getter and setter.
+    def __init__(self, getter, setter):
+        self.getter = getter
+        self.setter = setter
+
+    def __get__(self, obj, objtype):
+        return self.getter(obj.device_index)
+
+    def __set__(self, obj, val):
+        if isinstance(self.setter, str):
+            raise RuntimeError(self.setter)
+        self.setter(obj.device_index, val)
+
+
+class cuFFTPlanCache:
+    r"""
+    Represent a specific plan cache for a specific `device_index`.
+
+    The attributes `size` and `max_size`, and the method `clear`, can fetch and/or
+    change properties of the C++ cuFFT plan cache.
+    """
+
+    def __init__(self, device_index):
+        self.device_index = device_index
+
+    size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_size,
+        ".size is a read-only property showing the number of plans currently in the "
+        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
+    )
+
+    max_size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
+    )
+
+    def clear(self):
+        return torch._cufft_clear_plan_cache(self.device_index)
+
+
+class cuFFTPlanCacheManager:
+    r"""
+    Represent all cuFFT plan caches, returning the cuFFTPlanCache for a given device when indexed.
+
+    When this object is used directly as a `cuFFTPlanCache` object (e.g., setting
+    the `.max_size` attribute), the current device's cuFFT plan cache is used.
+    """
+
+    __initialized = False
+
+    def __init__(self):
+        self.caches = []
+        self.__initialized = True
+
+    def __getitem__(self, device):
+        index = torch.cuda._utils._get_device_index(device)
+        if index < 0 or index >= torch.cuda.device_count():
+            raise RuntimeError(
+                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
+                f"device with index {index}"
+            )
+        if len(self.caches) == 0:
+            self.caches.extend(
+                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
+            )
+        return self.caches[index]
+
+    def __getattr__(self, name):
+        return getattr(self[torch.cuda.current_device()], name)
+
+    def __setattr__(self, name, value):
+        if self.__initialized:
+            return setattr(self[torch.cuda.current_device()], name, value)
+        else:
+            return super().__setattr__(name, value)
+
+
+class cuBLASModule:
+    def __getattr__(self, name):
+        if name == "allow_tf32":
+            return torch._C._get_cublas_allow_tf32()
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+        raise AttributeError("Unknown attribute " + name)
+
+    def __setattr__(self, name, value):
+        if name == "allow_tf32":
+            return torch._C._set_cublas_allow_tf32(value)
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
+        raise AttributeError("Unknown attribute " + name)
+
+
+_LinalgBackends = {
+    "default": torch._C._LinalgBackend.Default,
+    "cusolver": torch._C._LinalgBackend.Cusolver,
+    "magma": torch._C._LinalgBackend.Magma,
+}
+_LinalgBackends_str = ", ".join(_LinalgBackends.keys())
+
+
+def preferred_linalg_library(
+    backend: Union[None, str, torch._C._LinalgBackend] = None
+) -> torch._C._LinalgBackend:
+    r"""
+    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
+    and if both are available it decides which to use with a heuristic.
+    This flag (a :class:`str`) allows overriding those heuristics.
+
+    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
+    * If `"magma"` is set then MAGMA will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between
+      cuSOLVER and MAGMA if both are available.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
+      globally.
+      This flag only sets the initial value of the preferred library and the preferred library
+      may still be overridden by this function call later in your script.
+
+    Note: When a library is preferred, other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
+    for your application's inputs.
+
+    Currently supported linalg operators:
+
+    * :func:`torch.linalg.inv`
+    * :func:`torch.linalg.inv_ex`
+    * :func:`torch.linalg.cholesky`
+    * :func:`torch.linalg.cholesky_ex`
+    * :func:`torch.cholesky_solve`
+    * :func:`torch.cholesky_inverse`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
+    * :func:`torch.linalg.qr`
+    * :func:`torch.linalg.eigh`
+    * :func:`torch.linalg.eigvalsh`
+    * :func:`torch.linalg.svd`
+    * :func:`torch.linalg.svdvals`
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _LinalgBackends:
+            raise RuntimeError(
+                "Unknown input value. " f"Choose from: {_LinalgBackends_str}."
+            )
+        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
+    elif isinstance(backend, torch._C._LinalgBackend):
+        torch._C._set_linalg_preferred_backend(backend)
+    else:
+        raise RuntimeError("Unknown input value type.")
+
+    return torch._C._get_linalg_preferred_backend()
+
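A brief sketch of using the override above; the query form (no argument) and the setter form can be combined to save and restore the preference:

import torch

previous = torch.backends.cuda.preferred_linalg_library()  # query only
torch.backends.cuda.preferred_linalg_library("cusolver")   # prefer cuSOLVER

if torch.cuda.is_available():
    a = torch.randn(64, 64, device="cuda")
    q, r = torch.linalg.qr(a)  # dispatched under the cuSOLVER preference where possible

torch.backends.cuda.preferred_linalg_library(previous)     # restore the old preference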
+
+from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
+
+# Set the __module__ attribute
+SDPAParams.__module__ = "torch.backends.cuda"
+SDPAParams.__name__ = "SDPAParams"
+
+
+def flash_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether flash scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_flash_sdp_enabled()
+
+
+def enable_flash_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables flash scaled dot product attention.
+    """
+    torch._C._set_sdp_use_flash(enabled)
+
+
+def mem_efficient_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether memory efficient scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_mem_efficient_sdp_enabled()
+
+
+def enable_mem_efficient_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables memory efficient scaled dot product attention.
+    """
+    torch._C._set_sdp_use_mem_efficient(enabled)
+
+
+def math_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether math scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_math_sdp_enabled()
+
+
+def enable_math_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables math scaled dot product attention.
+    """
+    torch._C._set_sdp_use_math(enabled)
+
+
+def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: Whether to log (via ``logging.warn``) debug information explaining why FlashAttention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if FlashAttention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_flash_attention(params, debug)
+
+
+def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: Whether to log (via ``logging.warn``) information explaining why efficient_attention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if efficient_attention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_mem_efficient_attention(params, debug)
+
+
+def cudnn_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether cuDNN scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_cudnn_sdp_enabled()
+
+
+def enable_cudnn_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables cuDNN scaled dot product attention.
+    """
+    torch._C._set_sdp_use_cudnn(enabled)
+
+
+@contextlib.contextmanager
+def sdp_kernel(
+    enable_flash: bool = True,
+    enable_math: bool = True,
+    enable_mem_efficient: bool = True,
+    enable_cudnn: bool = True,
+):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    This context manager can be used to temporarily enable or disable any of the four backends for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored.
+    """
+    warnings.warn(
+        (
+            "torch.backends.cuda.sdp_kernel() "
+            "is deprecated. In the future, this context manager will be removed. "
+            "Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated "
+            "signature."
+        ),
+        FutureWarning,
+    )
+    from torch.nn.attention import sdpa_kernel, SDPBackend
+
+    backend_list = []
+    if enable_flash:
+        backend_list.append(SDPBackend.FLASH_ATTENTION)
+    if enable_mem_efficient:
+        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    if enable_math:
+        backend_list.append(SDPBackend.MATH)
+    if enable_cudnn:
+        backend_list.append(SDPBackend.CUDNN_ATTENTION)
+
+    with sdpa_kernel(backend_list) as context:
+        try:
+            yield context
+        finally:
+            pass
+
+
+cufft_plan_cache = cuFFTPlanCacheManager()
+matmul = cuBLASModule()
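The module-level objects created above are the public handles. A short sketch, guarded so it is a no-op without CUDA, including the non-deprecated context manager that `sdp_kernel` forwards to:

import torch

if torch.backends.cuda.is_built() and torch.cuda.is_available():
    # cuBLAS TF32 toggle exposed through the cuBLASModule proxy above.
    torch.backends.cuda.matmul.allow_tf32 = True

    # Per-device cuFFT plan cache, indexed like a sequence.
    torch.backends.cuda.cufft_plan_cache[0].max_size = 32
    print(torch.backends.cuda.cufft_plan_cache[0].size)

    # Preferred replacement for the deprecated sdp_kernel wrapper above.
    from torch.nn.attention import sdpa_kernel, SDPBackend

    q = k = v = torch.randn(2, 4, 8, 16, device="cuda", dtype=torch.float16)
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)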
diff --git a/MLPY/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cf7fdc2911d65ae5b785f1f7685f2fd29c7392e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/cudnn/__init__.py b/MLPY/Lib/site-packages/torch/backends/cudnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939fa72b4fa2a6293e6eba9ba3385fd4ea34045
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/cudnn/__init__.py
@@ -0,0 +1,206 @@
+import os
+import sys
+import warnings
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    _cudnn = None  # type: ignore[assignment]
+
+# Write:
+#
+#   torch.backends.cudnn.enabled = False
+#
+# to globally disable CuDNN/MIOpen
+
+__cudnn_version: Optional[int] = None
+
+if _cudnn is not None:
+
+    def _init():
+        global __cudnn_version
+        if __cudnn_version is None:
+            __cudnn_version = _cudnn.getVersionInt()
+            runtime_version = _cudnn.getRuntimeVersion()
+            compile_version = _cudnn.getCompileVersion()
+            runtime_major, runtime_minor, _ = runtime_version
+            compile_major, compile_minor, _ = compile_version
+            # Different major versions are always incompatible
+            # Starting with cuDNN 7, minor versions are backwards-compatible
+            # Not sure about MIOpen (ROCm), so always do a strict check
+            if runtime_major != compile_major:
+                cudnn_compatible = False
+            elif runtime_major < 7 or not _cudnn.is_cuda:
+                cudnn_compatible = runtime_minor == compile_minor
+            else:
+                cudnn_compatible = runtime_minor >= compile_minor
+            if not cudnn_compatible:
+                if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
+                    return True
+                base_error_msg = (
+                    f"cuDNN version incompatibility: "
+                    f"PyTorch was compiled  against {compile_version} "
+                    f"but found runtime version {runtime_version}. "
+                    f"PyTorch already comes bundled with cuDNN. "
+                    f"One option to resolving this error is to ensure PyTorch "
+                    f"can find the bundled cuDNN. "
+                )
+
+                if "LD_LIBRARY_PATH" in os.environ:
+                    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
+                    if any(
+                        substring in ld_library_path for substring in ["cuda", "cudnn"]
+                    ):
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. "
+                            f"Please either remove it from the path or install cudnn {compile_version}"
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"one possibility is that there is a "
+                            f"conflicting cuDNN in LD_LIBRARY_PATH."
+                        )
+                else:
+                    raise RuntimeError(base_error_msg)
+
+        return True
+
+else:
+
+    def _init():
+        return False
+
+
+def version():
+    """Return the version of cuDNN."""
+    if not _init():
+        return None
+    return __cudnn_version
+
+
+CUDNN_TENSOR_DTYPES = {
+    torch.half,
+    torch.float,
+    torch.double,
+}
+
+
+def is_available():
+    r"""Return a bool indicating if CUDNN is currently available."""
+    return torch._C._has_cudnn
+
+
+def is_acceptable(tensor):
+    if not torch._C._get_cudnn_enabled():
+        return False
+    if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES:
+        return False
+    if not is_available():
+        warnings.warn(
+            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
+            "PyTorch making sure the library is visible to the build system."
+        )
+        return False
+    if not _init():
+        warnings.warn(
+            "cuDNN/MIOpen library not found. Check your {libpath}".format(
+                libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
+                    sys.platform, "LD_LIBRARY_PATH"
+                )
+            )
+        )
+        return False
+    return True
+
+
+def set_flags(
+    _enabled=None,
+    _benchmark=None,
+    _benchmark_limit=None,
+    _deterministic=None,
+    _allow_tf32=None,
+):
+    orig_flags = (
+        torch._C._get_cudnn_enabled(),
+        torch._C._get_cudnn_benchmark(),
+        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
+        torch._C._get_cudnn_deterministic(),
+        torch._C._get_cudnn_allow_tf32(),
+    )
+    if _enabled is not None:
+        torch._C._set_cudnn_enabled(_enabled)
+    if _benchmark is not None:
+        torch._C._set_cudnn_benchmark(_benchmark)
+    if _benchmark_limit is not None and is_available():
+        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
+    if _deterministic is not None:
+        torch._C._set_cudnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_cudnn_allow_tf32(_allow_tf32)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    enabled=False,
+    benchmark=False,
+    benchmark_limit=10,
+    deterministic=False,
+    allow_tf32=True,
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            enabled, benchmark, benchmark_limit, deterministic, allow_tf32
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.cudnn.enabled = True
+
+
+class CudnnModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
+    deterministic = ContextProp(
+        torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic
+    )
+    benchmark = ContextProp(
+        torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark
+    )
+    benchmark_limit = None
+    if is_available():
+        benchmark_limit = ContextProp(
+            torch._C._cuda_get_cudnn_benchmark_limit,
+            torch._C._cuda_set_cudnn_benchmark_limit,
+        )
+    allow_tf32 = ContextProp(
+        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
+    )
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+enabled: bool
+deterministic: bool
+benchmark: bool
+allow_tf32: bool
+benchmark_limit: int
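A minimal sketch of the two ways the flags above are typically driven; attribute assignments are global, while `flags` scopes and restores them (the convolution runs only when a GPU is present):

import torch

# Global, attribute-style configuration via the module replacement above.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Scoped configuration: the previous values are restored on exit.
with torch.backends.cudnn.flags(enabled=True, benchmark=False, deterministic=True):
    if torch.cuda.is_available():
        x = torch.randn(8, 3, 32, 32, device="cuda")
        conv = torch.nn.Conv2d(3, 16, 3).cuda()
        y = conv(x)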
diff --git a/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dd8d20bf0ee32606c91a9086f90582ca9033cbe
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4976c7f091ea415da0e0034f1318ed92ebb3543
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/cudnn/rnn.py b/MLPY/Lib/site-packages/torch/backends/cudnn/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bbbe5676413715ecf551326cc392584ae5b356a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/cudnn/rnn.py
@@ -0,0 +1,62 @@
+import torch.cuda
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
+    # so it's safe to not emit any checks here.
+    _cudnn = None  # type: ignore[assignment]
+
+
+def get_cudnn_mode(mode):
+    if mode == "RNN_RELU":
+        return int(_cudnn.RNNMode.rnn_relu)
+    elif mode == "RNN_TANH":
+        return int(_cudnn.RNNMode.rnn_tanh)
+    elif mode == "LSTM":
+        return int(_cudnn.RNNMode.lstm)
+    elif mode == "GRU":
+        return int(_cudnn.RNNMode.gru)
+    else:
+        raise Exception(f"Unknown mode: {mode}")
+
+
+# NB: We don't actually need this class anymore (in fact, we could serialize the
+# dropout state for even better reproducibility), but it is kept for backwards
+# compatibility for old models.
+class Unserializable:
+    def __init__(self, inner):
+        self.inner = inner
+
+    def get(self):
+        return self.inner
+
+    def __getstate__(self):
+        # Note: can't return {}, because python2 won't call __setstate__
+        # if the value evaluates to False
+        return ""
+
+    def __setstate__(self, state):
+        self.inner = None
+
+
+def init_dropout_state(dropout, train, dropout_seed, dropout_state):
+    dropout_desc_name = "desc_" + str(torch.cuda.current_device())
+    dropout_p = dropout if train else 0
+    if (dropout_desc_name not in dropout_state) or (
+        dropout_state[dropout_desc_name].get() is None
+    ):
+        if dropout_p == 0:
+            dropout_state[dropout_desc_name] = Unserializable(None)
+        else:
+            dropout_state[dropout_desc_name] = Unserializable(
+                torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
+                    dropout_p,
+                    train,
+                    dropout_seed,
+                    self_ty=torch.uint8,
+                    device=torch.device("cuda"),
+                )
+            )
+    dropout_ts = dropout_state[dropout_desc_name].get()
+    return dropout_ts
diff --git a/MLPY/Lib/site-packages/torch/backends/mha/__init__.py b/MLPY/Lib/site-packages/torch/backends/mha/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48cd8ce957cdf149ce7fd4608710b303261b3dda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/mha/__init__.py
@@ -0,0 +1,24 @@
+# Config options to enable/disable C++ kernel for nn.functional.MHA
+# and nn.TransformerEncoder
+import torch
+
+_is_fastpath_enabled: bool = True
+
+
+def get_fastpath_enabled() -> bool:
+    """Returns whether fast path for TransformerEncoder and MultiHeadAttention
+    is enabled, or ``True`` if jit is scripting.
+
+    .. note::
+        The fastpath might not be run even if ``get_fastpath_enabled`` returns
+        ``True`` unless all conditions on inputs are met.
+    """
+    if not torch.jit.is_scripting():
+        return _is_fastpath_enabled
+    return True
+
+
+def set_fastpath_enabled(value: bool) -> None:
+    """Sets whether fast path is enabled"""
+    global _is_fastpath_enabled
+    _is_fastpath_enabled = value
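A sketch of temporarily disabling the fastpath around a call and restoring the previous setting:

import torch
from torch.backends.mha import get_fastpath_enabled, set_fastpath_enabled

prev = get_fastpath_enabled()
set_fastpath_enabled(False)   # force the regular (non-fastpath) implementation
try:
    layer = torch.nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
    encoder = torch.nn.TransformerEncoder(layer, num_layers=1).eval()
    with torch.no_grad():
        out = encoder(torch.randn(2, 5, 16))
finally:
    set_fastpath_enabled(prev)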
diff --git a/MLPY/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31882d6906a9d030a5d900ad69d2ce8d92360e52
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/mha/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/mkl/__init__.py b/MLPY/Lib/site-packages/torch/backends/mkl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a79bf8d184428242e76805185007878dd356c6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/mkl/__init__.py
@@ -0,0 +1,56 @@
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL support."""
+    return torch._C.has_mkl
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+
+
+class verbose:
+    """
+    On-demand oneMKL verbose functionality.
+
+    To make it easier to debug performance issues, oneMKL can dump verbose
+    messages containing execution information such as kernel duration. The
+    verbose functionality can be enabled via an environment variable named
+    `MKL_VERBOSE`. However, that approach dumps messages for every step,
+    producing a large volume of output, while verbose messages from a single
+    iteration are usually enough to investigate a performance issue. This
+    on-demand verbose functionality makes it possible to control the scope of
+    verbose message dumping. In the following example, verbose messages will
+    be dumped out for the second inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+        model(data)
+        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+    """
+
+    def __init__(self, enable):
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkl_set_verbose(self.enable)
+        assert (
+            st
+        ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
+        return False
diff --git a/MLPY/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83c9399aaf34fe6256e2a8bb26d336eb2dc92fa4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/mkldnn/__init__.py b/MLPY/Lib/site-packages/torch/backends/mkldnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0d9f83f95d8d147f60e3bde84c1dc191b1c49e0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/mkldnn/__init__.py
@@ -0,0 +1,97 @@
+import sys
+from contextlib import contextmanager
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL-DNN support."""
+    return torch._C._has_mkldnn
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+VERBOSE_ON_CREATION = 2
+
+
+class verbose:
+    """
+    On-demand oneDNN (formerly MKL-DNN) verbose functionality.
+
+    To make it easier to debug performance issues, oneDNN can dump verbose
+    messages containing information such as kernel size, input data size and
+    execution duration while executing the kernel. The verbose functionality
+    can be enabled via an environment variable named `DNNL_VERBOSE`. However,
+    that approach dumps messages for every step, producing a large volume of
+    output, while verbose messages from a single iteration are usually enough
+    to investigate a performance issue. This on-demand verbose functionality
+    makes it possible to control the scope of verbose message dumping. In the
+    following example, verbose messages will be dumped out for the second
+    inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+        model(data)
+        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+            - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation
+    """
+
+    def __init__(self, level):
+        self.level = level
+
+    def __enter__(self):
+        if self.level == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkldnn_set_verbose(self.level)
+        assert (
+            st
+        ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope."
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
+        return False
+
+
+def set_flags(_enabled):
+    orig_flags = (torch._C._get_mkldnn_enabled(),)
+    torch._C._set_mkldnn_enabled(_enabled)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(orig_flags[0])
+
+
+class MkldnnModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
+
+
+if TYPE_CHECKING:
+    enabled: ContextProp
+
+
+# Cool stuff from torch/backends/cudnn/__init__.py and
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
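Once the module is replaced by `MkldnnModule`, the flag can be read as an attribute or scoped with `flags`; a minimal sketch:

import torch

# Attribute-style global toggle (intercepted by MkldnnModule above).
print(torch.backends.mkldnn.enabled)

# Scoped disabling of oneDNN-backed kernels; the flag is restored on exit.
with torch.backends.mkldnn.flags(enabled=False):
    x = torch.randn(1, 3, 8, 8)
    y = torch.nn.functional.conv2d(x, torch.randn(4, 3, 3, 3))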
diff --git a/MLPY/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..debe775ccc8916b83ab91d237fd1089c53523774
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/mps/__init__.py b/MLPY/Lib/site-packages/torch/backends/mps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84f2a190303c90d8a73a9825835a2d0a3db5cd04
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/mps/__init__.py
@@ -0,0 +1,54 @@
+from functools import lru_cache as _lru_cache
+
+from typing import Optional
+
+import torch
+from ...library import Library as _Library
+
+__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+
+
+def is_built() -> bool:
+    r"""Return whether PyTorch is built with MPS support.
+
+    Note that this doesn't necessarily mean MPS is available; just that
+    if this PyTorch binary were run on a machine with working MPS drivers
+    and devices, we would be able to use it.
+    """
+    return torch._C._has_mps
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if MPS is currently available."""
+    return torch._C._mps_is_available()
+
+
+@_lru_cache
+def is_macos_or_newer(major: int, minor: int) -> bool:
+    r"""Return a bool indicating whether MPS is running on given MacOS or newer."""
+    return torch._C._mps_is_on_macos_or_newer(major, minor)
+
+
+@_lru_cache
+def is_macos13_or_newer(minor: int = 0) -> bool:
+    r"""Return a bool indicating whether MPS is running on MacOS 13 or newer."""
+    return torch._C._mps_is_on_macos_or_newer(13, minor)
+
+
+_lib: Optional[_Library] = None
+
+
+def _init():
+    r"""Register prims as implementation of var_mean and group_norm."""
+    global _lib
+    if is_built() is False or _lib is not None:
+        return
+    from ..._decomp.decompositions import (
+        native_group_norm_backward as _native_group_norm_backward,
+    )
+    from ..._refs import native_group_norm as _native_group_norm
+
+    _lib = _Library("aten", "IMPL")
+    _lib.impl("native_group_norm", _native_group_norm, "MPS")
+    _lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")
diff --git a/MLPY/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c6991284813119e55b744feee9075ce387aa0f3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/mps/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/nnpack/__init__.py b/MLPY/Lib/site-packages/torch/backends/nnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..938ed29a44572a9f529dda1632dfed915c006a28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/nnpack/__init__.py
@@ -0,0 +1,30 @@
+from contextlib import contextmanager
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+__all__ = ["is_available", "flags", "set_flags"]
+
+
+def is_available():
+    r"""Return whether PyTorch is built with NNPACK support."""
+    return torch._nnpack_available()
+
+
+def set_flags(_enabled):
+    r"""Set if nnpack is enabled globally"""
+    orig_flags = (torch._C._get_nnpack_enabled(),)
+    torch._C._set_nnpack_enabled(_enabled)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False):
+    r"""Context manager for setting if nnpack is enabled globally"""
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(orig_flags[0])
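Usage mirrors the other backend modules; a minimal sketch that temporarily disables NNPACK-backed kernels:

import torch

if torch.backends.nnpack.is_available():
    # The flag is restored to its previous value on exit.
    with torch.backends.nnpack.flags(enabled=False):
        x = torch.randn(1, 1, 16, 16)
        w = torch.randn(4, 1, 3, 3)
        y = torch.nn.functional.conv2d(x, w)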
diff --git a/MLPY/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..119488f907ce85e1c0e72803bda834455e3e5912
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/openmp/__init__.py b/MLPY/Lib/site-packages/torch/backends/openmp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e0afd5a0e58f82d010db5775cae6bf46c48336
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/openmp/__init__.py
@@ -0,0 +1,6 @@
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with OpenMP support."""
+    return torch._C.has_openmp
diff --git a/MLPY/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95f885e201da04103a00da70fa3d1d38eaa0d12b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/opt_einsum/__init__.py b/MLPY/Lib/site-packages/torch/backends/opt_einsum/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..26903bfe75b143c34d00dca97e9a03b0b45f4c29
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/opt_einsum/__init__.py
@@ -0,0 +1,110 @@
+import sys
+import warnings
+from contextlib import contextmanager
+from functools import lru_cache as _lru_cache
+from typing import Any
+
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+try:
+    import opt_einsum as _opt_einsum  # type: ignore[import]
+except ImportError:
+    _opt_einsum = None
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if opt_einsum is currently available."""
+    return _opt_einsum is not None
+
+
+def get_opt_einsum() -> Any:
+    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
+    return _opt_einsum
+
+
+def _set_enabled(_enabled: bool) -> None:
+    if not is_available() and _enabled:
+        raise ValueError(
+            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
+            "the benefits of calculating an optimal path for einsum. torch.einsum will "
+            "fall back to contracting from left to right. To enable this optimal path "
+            "calculation, please install opt-einsum."
+        )
+    global enabled
+    enabled = _enabled
+
+
+def _get_enabled() -> bool:
+    return enabled
+
+
+def _set_strategy(_strategy: str) -> None:
+    if not is_available():
+        raise ValueError(
+            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please install opt_einsum or unset `strategy`."
+        )
+    if not enabled:
+        raise ValueError(
+            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please set `enabled` to `True` as well or unset `strategy`."
+        )
+    if _strategy not in ["auto", "greedy", "optimal"]:
+        raise ValueError(
+            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
+        )
+    global strategy
+    strategy = _strategy
+
+
+def _get_strategy() -> str:
+    return strategy
+
+
+def set_flags(_enabled=None, _strategy=None):
+    orig_flags = (enabled, None if not is_available() else strategy)
+    if _enabled is not None:
+        _set_enabled(_enabled)
+    if _strategy is not None:
+        _set_strategy(_strategy)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=None, strategy=None):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled, strategy)
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.opt_einsum.enabled = True
+
+
+class OptEinsumModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    global enabled
+    enabled = ContextProp(_get_enabled, _set_enabled)
+    global strategy
+    strategy = None
+    if is_available():
+        strategy = ContextProp(_get_strategy, _set_strategy)
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
+
+enabled = True if is_available() else False
+strategy = "auto" if is_available() else None
diff --git a/MLPY/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3ad4df3e554ed64270eaafd6dcfd265b0a67987
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/quantized/__init__.py b/MLPY/Lib/site-packages/torch/backends/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d715f6d4acc1750e76482ab06f3792974b22591d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/quantized/__init__.py
@@ -0,0 +1,65 @@
+import sys
+import types
+from typing import List
+
+import torch
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_id(qengine: str) -> int:
+    if qengine == "none" or qengine == "" or qengine is None:
+        ret = 0
+    elif qengine == "fbgemm":
+        ret = 1
+    elif qengine == "qnnpack":
+        ret = 2
+    elif qengine == "onednn":
+        ret = 3
+    elif qengine == "x86":
+        ret = 4
+    else:
+        ret = -1
+        raise RuntimeError(f"{qengine} is not a valid value for quantized engine")
+    return ret
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_str(qengine: int) -> str:
+    all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"}
+    return all_engines.get(qengine, "*undefined")
+
+
+class _QEngineProp:
+    def __get__(self, obj, objtype) -> str:
+        return _get_qengine_str(torch._C._get_qengine())
+
+    def __set__(self, obj, val: str) -> None:
+        torch._C._set_qengine(_get_qengine_id(val))
+
+
+class _SupportedQEnginesProp:
+    def __get__(self, obj, objtype) -> List[str]:
+        qengines = torch._C._supported_qengines()
+        return [_get_qengine_str(qe) for qe in qengines]
+
+    def __set__(self, obj, val) -> None:
+        raise RuntimeError("Assignment not supported")
+
+
+class QuantizedEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    engine = _QEngineProp()
+    supported_engines = _SupportedQEnginesProp()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
+engine: str
+supported_engines: List[str]
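A short sketch of reading and switching the quantized engine through the replaced module:

import torch

# The replacement module exposes both properties defined above.
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm', ...]
if "fbgemm" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "fbgemm"
print(torch.backends.quantized.engine)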
diff --git a/MLPY/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdae660aae7aa884c6cda8f709c7c6f52885bc72
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/xeon/__init__.py b/MLPY/Lib/site-packages/torch/backends/xeon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59cceff859ff5e0bec1d99334db3f7748b4c3835
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3970883ac193e06ca1de56dd1beeb5fa98d887d8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/backends/xeon/run_cpu.py b/MLPY/Lib/site-packages/torch/backends/xeon/run_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdfe445abad86c8140f4b85f176b8fa1d0c799b0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/xeon/run_cpu.py
@@ -0,0 +1,929 @@
+"""
+This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations.
+
+Both single-instance and multi-instance inference are enabled.
+
+Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes
+multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this
+context.
+
+Illustrated as below:
+
+::
+
+    +-----------------------------+----------------------+-------+
+    |            process          |        thread        | core  |
+    +=============================+======================+=======+
+    | torch.backends.xeon.run_cpu | instance 0: thread 0 |   0   |
+    |                             |             thread 1 |   1   |
+    |                             +----------------------+-------+
+    |                             | instance 1: thread 0 |   2   |
+    |                             |             thread 1 |   3   |
+    |                             +----------------------+-------+
+    |                             | ...                  |  ...  |
+    |                             +----------------------+-------+
+    |                             | instance N: thread 0 |   M   |
+    |                             |             thread 1 |  M+1  |
+    +-----------------------------+----------------------+-------+
+
+To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory
+management. For thread management, the script configures thread affinity and the preload of the Intel OMP library.
+For memory management, it configures NUMA binding and preloads an optimized memory allocation library (e.g. tcmalloc, jemalloc).
+
+Environment variables that will be set by this script:
+
++------------------+-------------------------------------------------------------------------------------------------+
+| Environ Variable |                                             Value                                               |
++==================+=================================================================================================+
+|    LD_PRELOAD    | Depending on knobs you set, <lib>/libiomp5.so, <lib>/libjemalloc.so, <lib>/libtcmalloc.so might |
+|                  | be appended to LD_PRELOAD.                                                                      |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_AFFINITY   | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0".       |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_BLOCKTIME  | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1".                                       |
++------------------+-------------------------------------------------------------------------------------------------+
+|  OMP_NUM_THREADS | value of ncores_per_instance                                                                    |
++------------------+-------------------------------------------------------------------------------------------------+
+|    MALLOC_CONF   | If libjemalloc.so is preloaded, MALLOC_CONF will be set to                                      |
+|                  | "oversize_threshold:1,background_thread:true,metadata_thp:auto".                                |
++------------------+-------------------------------------------------------------------------------------------------+
+
+*Note*: This script respects environment variables that are already set. I.e., if you set the environment variables
+mentioned above before running the script, the script will not overwrite them.
+
+How to use this module:
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Single instance inference
+-------------------------
+
+1. Run single-instance inference on a single node with all CPU nodes.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --throughput-mode script.py args
+
+2. Run single-instance inference on a single CPU node.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --node-id 1 script.py args
+
+Multi-instance inference
+------------------------
+
+1. Multi-instance
+   By default this tool runs one process per node. If you want to set the number of instances and cores per instance,
+   --ninstances and --ncores-per-instance should be set.
+
+::
+
+   python -m torch.backends.xeon.run_cpu -- python_script args
+
+   eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instances, 4 cores per instance
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores-per-instance 4 python_script args
+
+2. Run single-instance inference among multiple instances.
+   By default, all ninstances are run. To run a single instance independently among the ninstances, specify --rank.
+
+   eg: run the 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances (i.e., numactl -C 0-27)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args
+
+   eg: run the 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances (i.e., numactl -C 28-55)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args
+
+   eg: run the 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances, 2 cores per instance,
+   using the first four cores (i.e., numactl -C 0-1)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --core-list "0, 1, 2, 3" --ninstances 2 --ncores-per-instance 2
+   --rank 0 python_script args
+
+3. To look up what optional arguments this module offers:
+
+::
+
+    python -m torch.backends.xeon.run_cpu --help
+
+Memory allocator
+----------------
+
+"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allcator.
+
+"""
+
+import glob
+import logging
+import os
+import platform
+import re
+import subprocess
+import sys
+from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER
+from os.path import expanduser
+from typing import Dict, List
+
+from torch.distributed.elastic.multiprocessing import (
+    DefaultLogsSpecs,
+    start_processes,
+    Std,
+)
+
+format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+logging.basicConfig(level=logging.INFO, format=format_str)
+logger = logging.getLogger(__name__)
+
+
+class _CPUinfo:
+    """Get CPU information, such as cores list and NUMA information."""
+
+    def __init__(self, test_input=""):
+        self.cpuinfo = []
+        if platform.system() in ["Windows", "Darwin"]:
+            raise RuntimeError(f"{platform.system()} is not supported!!!")
+        elif platform.system() == "Linux":
+            # Sample output of: `lscpu --parse=CPU,Core,Socket,Node`
+            #
+            # # The following is the parsable format, which can be fed to other
+            # # programs. Each different item in every column has an unique ID
+            # # starting from zero.
+            # # CPU,Core,Socket,Node
+            # 0,0,0,0
+            # 1,1,0,0
+            # ...
+            if test_input == "":
+                lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"]
+                lscpu_info = subprocess.check_output(
+                    lscpu_cmd, universal_newlines=True
+                ).split("\n")
+            else:
+                lscpu_info = test_input.split("\n")
+
+            # Get information about CPU, core, socket and node
+            for line in lscpu_info:
+                pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)"
+                regex_out = re.search(pattern, line)
+                if regex_out:
+                    self.cpuinfo.append(regex_out.group(1).strip().split(","))
+
+            # physical cores := core column in lscpu output
+            #  logical cores := CPU column in lscpu output
+            self.node_nums = int(max([line[3] for line in self.cpuinfo])) + 1
+            self.node_physical_cores: List[List[int]] = []  # node_id is index
+            self.node_logical_cores: List[List[int]] = []  # node_id is index
+            self.physical_core_node_map = {}  # physical core to numa node id
+            self.logical_core_node_map = {}  # logical core to numa node id
+
+            for node_id in range(self.node_nums):
+                cur_node_physical_core = []
+                cur_node_logical_core = []
+                for cpuinfo in self.cpuinfo:
+                    nid = cpuinfo[3] if cpuinfo[3] != "" else "0"
+                    if node_id == int(nid):
+                        if int(cpuinfo[1]) not in cur_node_physical_core:
+                            cur_node_physical_core.append(int(cpuinfo[1]))
+                            self.physical_core_node_map[int(cpuinfo[1])] = int(node_id)
+                        cur_node_logical_core.append(int(cpuinfo[0]))
+                        self.logical_core_node_map[int(cpuinfo[0])] = int(node_id)
+                self.node_physical_cores.append(cur_node_physical_core)
+                self.node_logical_cores.append(cur_node_logical_core)
+
+    def _physical_core_nums(self):
+        return len(self.node_physical_cores) * len(self.node_physical_cores[0])
+
+    def _logical_core_nums(self):
+        return len(self.node_logical_cores) * len(self.node_logical_cores[0])
+
+    def get_node_physical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_physical_cores[node_id]
+
+    def get_node_logical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_logical_cores[node_id]
+
+    def get_all_physical_cores(self):
+        all_cores = []
+        for cores in self.node_physical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def get_all_logical_cores(self):
+        all_cores = []
+        for cores in self.node_logical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def numa_aware_check(self, core_list):
+        """
+        Check whether all cores in core_list are in the same NUMA node.
+
+        Crossing NUMA nodes will reduce performance.
+        We strongly advise against using cores on different nodes.
+        """
+        cores_numa_map = self.logical_core_node_map
+        numa_ids = []
+        for core in core_list:
+            numa_id = cores_numa_map[core]
+            if numa_id not in numa_ids:
+                numa_ids.append(numa_id)
+        if len(numa_ids) > 1:
+            logger.warning(
+                "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \
+this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\
+instance. Alternatively, please use --skip-cross-node-cores knob.",
+                str(core_list),
+                str(numa_ids),
+            )
+        if len(numa_ids) == 0:
+            raise RuntimeError(
+                "invalid number of NUMA nodes; please make sure numa_ids >= 1"
+            )
+        return numa_ids
+
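+# A minimal sketch of how the parsed topology can be inspected, assuming the
+# constructor exposes the ``test_input`` string used above and given a
+# hypothetical single-node box with two physical cores and two hyper-threads
+# per core:
+#
+#     info = _CPUinfo(test_input="0,0,0,0\n1,0,0,0\n2,1,0,0\n3,1,0,0")
+#     info.get_node_physical_cores(0)   # -> [0, 1]
+#     info.get_node_logical_cores(0)    # -> [0, 1, 2, 3]
+#     info.numa_aware_check([0, 1])     # -> [0]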
+
+class _Launcher:
+    r"""Class for launcher."""
+
+    msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \
+or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
+{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."
+
+    def __init__(self):
+        self.cpuinfo = _CPUinfo()
+
+    def add_lib_preload(self, lib_type):
+        """Enable TCMalloc/JeMalloc/intel OpenMP."""
+        library_paths = []
+        if "CONDA_PREFIX" in os.environ:
+            library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib")
+        if "VIRTUAL_ENV" in os.environ:
+            library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib")
+
+        library_paths += [
+            f"{expanduser('~')}/.local/lib",
+            "/usr/local/lib",
+            "/usr/local/lib64",
+            "/usr/lib",
+            "/usr/lib64",
+        ]
+
+        lib_find = False
+        lib_set = False
+        for item in os.getenv("LD_PRELOAD", "").split(":"):
+            if item.endswith(f"lib{lib_type}.so"):
+                lib_set = True
+                break
+        if not lib_set:
+            for lib_path in library_paths:
+                library_file = os.path.join(lib_path, f"lib{lib_type}.so")
+                matches = glob.glob(library_file)
+                if len(matches) > 0:
+                    ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")]
+                    os.environ["LD_PRELOAD"] = os.pathsep.join(
+                        [p.strip(os.pathsep) for p in ld_preloads if p]
+                    )
+                    lib_find = True
+                    break
+        return lib_set or lib_find
+
+    def is_numactl_available(self):
+        numactl_available = False
+        try:
+            cmd = ["numactl", "-C", "0", "-m", "0", "hostname"]
+            r = subprocess.run(
+                cmd,
+                env=os.environ,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                check=False,
+            )
+            if r.returncode == 0:
+                numactl_available = True
+        except Exception:
+            pass
+        return numactl_available
+
+    def set_memory_allocator(
+        self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False
+    ):
+        """
+        Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc.
+
+        By default, PyTorch uses PTMalloc, but TCMalloc and JeMalloc can achieve better
+        memory reuse and reduce page faults, improving performance.
+        """
+        if enable_tcmalloc and enable_jemalloc:
+            raise RuntimeError(
+                "Unable to enable TCMalloc and JEMalloc at the same time."
+            )
+
+        if enable_tcmalloc:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if not find_tc:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}'
+                logger.warning(msg.format("TCmalloc", "tcmalloc"))  # noqa: G001
+            else:
+                logger.info("Use TCMalloc memory allocator")
+
+        elif enable_jemalloc:
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if not find_je:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}'
+                logger.warning(msg.format("Jemalloc", "jemalloc"))  # noqa: G001
+            else:
+                logger.info("Use JeMalloc memory allocator")
+                self.set_env(
+                    "MALLOC_CONF",
+                    "oversize_threshold:1,background_thread:true,metadata_thp:auto",
+                )
+
+        elif use_default_allocator:
+            pass
+
+        else:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if find_tc:
+                logger.info("Use TCMalloc memory allocator")
+                return
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if find_je:
+                logger.info("Use JeMalloc memory allocator")
+                return
+            logger.warning(
+                """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib
+                            or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or
+                           %s/.local/lib/ so the LD_PRELOAD environment variable will not be set.
+                           This may drop the performance""",
+                expanduser("~"),
+            )
+
+    def log_env_var(self, env_var_name=""):
+        if env_var_name in os.environ:
+            logger.info("%s=%s", env_var_name, os.environ[env_var_name])
+
+    def set_env(self, env_name, env_value):
+        if not env_value:
+            logger.warning("%s is None", env_name)
+        if env_name not in os.environ:
+            os.environ[env_name] = env_value
+        elif os.environ[env_name] != env_value:
+            logger.warning(
+                "Overriding value with the one set in environment variable: %s. \
+Value applied: %s. Value ignored: %s",
+                env_name,
+                os.environ[env_name],
+                env_value,
+            )
+        self.log_env_var(env_name)
+
+    # set_kmp_affinity controls whether KMP_AFFINITY is set or not.
+    # In the scenario that uses all cores on all nodes, including logical cores,
+    # setting KMP_AFFINITY disables the logical cores, so KMP_AFFINITY should not be set.
+    def set_multi_thread_and_allocator(
+        self,
+        ncores_per_instance,
+        disable_iomp=False,
+        set_kmp_affinity=True,
+        enable_tcmalloc=True,
+        enable_jemalloc=False,
+        use_default_allocator=False,
+    ):
+        """
+        Set the multi-thread configuration and enable Intel OpenMP and TCMalloc/JeMalloc.
+
+        By default, PyTorch uses GNU OpenMP and PTMalloc, but Intel OpenMP and TCMalloc/JeMalloc
+        are better alternatives for performance.
+        """
+        self.set_memory_allocator(
+            enable_tcmalloc, enable_jemalloc, use_default_allocator
+        )
+        self.set_env("OMP_NUM_THREADS", str(ncores_per_instance))
+        if not disable_iomp:
+            find_iomp = self.add_lib_preload(lib_type="iomp5")
+            if not find_iomp:
+                msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}'
+                logger.warning(msg.format("iomp", "iomp5"))  # noqa: G001
+            else:
+                logger.info("Using Intel OpenMP")
+                if set_kmp_affinity:
+                    self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0")
+                self.set_env("KMP_BLOCKTIME", "1")
+        self.log_env_var("LD_PRELOAD")
+
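+    # A rough sketch of the environment that set_multi_thread_and_allocator()
+    # produces for a 14-core instance when Intel OpenMP and TCMalloc are found
+    # (library paths are illustrative only):
+    #
+    #   OMP_NUM_THREADS=14
+    #   KMP_AFFINITY=granularity=fine,compact,1,0
+    #   KMP_BLOCKTIME=1
+    #   LD_PRELOAD=.../libiomp5.so:.../libtcmalloc.so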
+    r"""
+     Launcher for single instance and multi-instance
+     """
+
+    def launch(self, args):
+        cores = []
+        set_kmp_affinity = True
+        enable_taskset = False
+        if args.core_list:  # the user specified which cores to use via --core-list
+            cores = [int(x) for x in args.core_list.split(",")]
+            if args.ncores_per_instance == -1:
+                raise RuntimeError(
+                    'please specify "--ncores-per-instance" if you have passed the --core-list parameter'
+                )
+            elif (
+                args.ninstances > 1
+                and args.ncores_per_instance * args.ninstances < len(cores)
+            ):
+                logger.warning(
+                    "only first %s cores will be used, \
+but you specify %s cores in core_list",
+                    args.ncores_per_instance * args.ninstances,
+                    len(cores),
+                )
+            else:
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+        else:
+            if args.use_logical_core:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_logical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_logical_cores()
+                    # When using all cores on all nodes, including logical cores,
+                    # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set.
+                    set_kmp_affinity = False
+            else:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_physical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_physical_cores()
+            if (
+                not args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.ninstances = 1
+                args.ncores_per_instance = len(cores)
+            elif (
+                args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.throughput_mode = True
+            elif args.ncores_per_instance == -1 and args.ninstances != -1:
+                if args.ninstances > len(cores):
+                    raise RuntimeError(
+                        f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \
+please make sure ninstances <= total_cores)"
+                    )
+                else:
+                    args.ncores_per_instance = len(cores) // args.ninstances
+            elif args.ncores_per_instance != -1 and args.ninstances == -1:
+                if not args.skip_cross_node_cores:
+                    args.ninstances = len(cores) // args.ncores_per_instance
+                else:
+                    ncore_per_node = len(self.cpuinfo.node_physical_cores[0])
+                    num_leftover_cores = ncore_per_node % args.ncores_per_instance
+                    if args.ncores_per_instance > ncore_per_node:
+                        # too many ncores_per_instance to skip cross-node cores
+                        logger.warning(
+                            "there are %s core(s) per socket, but you specify %s ncores_per_instance and \
+skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \
+socket",
+                            ncore_per_node,
+                            args.ncores_per_instance,
+                        )
+                        sys.exit(-1)
+                    elif num_leftover_cores == 0:
+                        # aren't any cross-node cores
+                        logger.info(
+                            "--skip-cross-node-cores is set, but there are no cross-node cores."
+                        )
+                        args.ninstances = len(cores) // args.ncores_per_instance
+                    else:
+                        # skip cross-node cores
+                        if args.ninstances != -1:
+                            logger.warning(
+                                "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \
+won't take effect even if it is set explicitly."
+                            )
+
+                        i = 1
+                        leftover_cores = set()
+                        while ncore_per_node * i <= len(cores):
+                            leftover_cores.update(
+                                cores[
+                                    ncore_per_node * i
+                                    - num_leftover_cores : ncore_per_node * i
+                                ]
+                            )
+                            i += 1
+                        cores = list(set(cores) - leftover_cores)
+                        assert len(cores) % args.ncores_per_instance == 0
+                        args.ninstances = len(cores) // args.ncores_per_instance
+            else:
+                if args.ninstances * args.ncores_per_instance > len(cores):
+                    raise RuntimeError(
+                        "Please make sure ninstances * ncores_per_instance <= total_cores"
+                    )
+            if args.latency_mode:
+                logger.warning(
+                    "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even they are set explicitly."
+                )
+                args.ncores_per_instance = 4
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+            if args.throughput_mode:
+                logger.warning(
+                    "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even they are set explicitly."
+                )
+                args.ninstances = self.cpuinfo.node_nums
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ncores_per_instance = len(cores) // args.ninstances
+
+        if args.ninstances > 1 and args.rank != -1:
+            logger.info(
+                "assigning %s cores for instance %s",
+                args.ncores_per_instance,
+                args.rank,
+            )
+
+        if not args.disable_numactl:
+            numactl_available = self.is_numactl_available()
+            if not numactl_available:
+                if not args.disable_taskset:
+                    logger.warning(
+                        "Core binding with numactl is not available. Disabling numactl and using taskset instead. \
+                    This may affect performance in multi-socket system; please use numactl if memory binding is needed."
+                    )
+                    args.disable_numactl = True
+                    enable_taskset = True
+                else:
+                    logger.warning(
+                        "Core binding with numactl is not available, and --disable_taskset is set. \
+                    Please unset --disable_taskset to use taskset instead of numactl."
+                    )
+                    sys.exit(-1)
+
+        if not args.disable_taskset:
+            enable_taskset = True
+
+        self.set_multi_thread_and_allocator(
+            args.ncores_per_instance,
+            args.disable_iomp,
+            set_kmp_affinity,
+            args.enable_tcmalloc,
+            args.enable_jemalloc,
+            args.use_default_allocator,
+        )
+        entrypoint = ""
+        launch_args = {}
+        launch_envs: Dict[int, Dict] = {}
+        launch_tee = {}
+        for i in range(args.ninstances):
+            cmd = []
+            cur_process_cores = ""
+            if not args.disable_numactl or enable_taskset:
+                if not args.disable_numactl:
+                    cmd = ["numactl"]
+                elif enable_taskset:
+                    cmd = ["taskset"]
+                cores = sorted(cores)
+                if (
+                    args.rank == -1
+                ):  # sequentially assign ncores_per_instance to ninstances
+                    core_list = cores[
+                        i
+                        * args.ncores_per_instance : (i + 1)
+                        * args.ncores_per_instance
+                    ]
+                else:  # assign ncores_per_instance from rank
+                    core_list = cores[
+                        args.rank
+                        * args.ncores_per_instance : (args.rank + 1)
+                        * args.ncores_per_instance
+                    ]
+
+                core_ranges: List[Dict] = []
+                for core in core_list:
+                    if len(core_ranges) == 0:
+                        range_elem = {"start": core, "end": core}
+                        core_ranges.append(range_elem)
+                    else:
+                        if core - core_ranges[-1]["end"] == 1:
+                            core_ranges[-1]["end"] = core
+                        else:
+                            range_elem = {"start": core, "end": core}
+                            core_ranges.append(range_elem)
+                for r in core_ranges:
+                    cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']},"
+                cur_process_cores = cur_process_cores[:-1]
+                if not args.disable_numactl:
+                    numa_params = f"-C {cur_process_cores} "
+                    numa_ids = ",".join(
+                        [
+                            str(numa_id)
+                            for numa_id in self.cpuinfo.numa_aware_check(core_list)
+                        ]
+                    )
+                    numa_params += f"-m {numa_ids}"
+                    cmd.extend(numa_params.split())
+                elif enable_taskset:
+                    taskset_params = f"-c {cur_process_cores} "
+                    cmd.extend(taskset_params.split())
+            with_python = not args.no_python
+            if with_python:
+                cmd.append(sys.executable)
+                cmd.append("-u")
+            if args.module:
+                cmd.append("-m")
+            cmd.append(args.program)
+            cmd.extend(args.program_args)
+            cmd_s = " ".join(cmd)
+            logger.info(cmd_s)
+            if entrypoint == "":
+                entrypoint = cmd[0]
+            del cmd[0]
+            launch_args[i] = tuple(cmd)
+            launch_envs[i] = {}
+            launch_tee[i] = Std.ALL
+
+            if args.rank != -1:  # launch only the single instance for this rank
+                break
+
+        ctx = start_processes(
+            name=args.log_file_prefix,
+            entrypoint=entrypoint,
+            args=launch_args,
+            envs=launch_envs,
+            logs_specs=DefaultLogsSpecs(log_dir=args.log_path, tee=launch_tee),
+        )
+        ctx.wait()
+
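+# A rough sketch of the per-instance commands that launch() assembles (core
+# ranges and the script name are illustrative only):
+#
+#   numactl -C 0-13 -m 0 <python> -u your_script.py args...
+#   numactl -C 14-27 -m 0 <python> -u your_script.py args...
+#
+# or, when numactl is unavailable and taskset is used instead:
+#
+#   taskset -c 0-13 <python> -u your_script.py args...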
+
+def _add_memory_allocator_params(parser):
+    group = parser.add_argument_group("Memory Allocator Parameters")
+    # allocator control
+    group.add_argument(
+        "--enable-tcmalloc",
+        "--enable_tcmalloc",
+        action="store_true",
+        default=False,
+        help="Enable tcmalloc allocator",
+    )
+    group.add_argument(
+        "--enable-jemalloc",
+        "--enable_jemalloc",
+        action="store_true",
+        default=False,
+        help="Enable jemalloc allocator",
+    )
+    group.add_argument(
+        "--use-default-allocator",
+        "--use_default_allocator",
+        action="store_true",
+        default=False,
+        help="Use default memory allocator",
+    )
+
+
+def _add_multi_instance_params(parser):
+    group = parser.add_argument_group("Multi-instance Parameters")
+    # multi-instance control
+    group.add_argument(
+        "--ncores-per-instance",
+        "--ncores_per_instance",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="Cores per instance",
+    )
+    group.add_argument(
+        "--ninstances",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="For multi-instance, you should give the cores number you used for per instance.",
+    )
+    group.add_argument(
+        "--skip-cross-node-cores",
+        "--skip_cross_node_cores",
+        action="store_true",
+        default=False,
+        help="If specified --ncores-per-instance, skips cross-node cores.",
+    )
+    group.add_argument(
+        "--rank",
+        metavar="\b",
+        default="-1",
+        type=int,
+        help="Specify instance index to assign ncores_per_instance for rank; \
+otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \
+https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md",
+    )
+    group.add_argument(
+        "--latency-mode",
+        "--latency_mode",
+        action="store_true",
+        default=False,
+        help="By default 4 core per instance and use all physical cores",
+    )
+    group.add_argument(
+        "--throughput-mode",
+        "--throughput_mode",
+        action="store_true",
+        default=False,
+        help="By default one instance per node and use all physical cores",
+    )
+    group.add_argument(
+        "--node-id",
+        "--node_id",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="node id for multi-instance, by default all nodes will be used",
+    )
+    group.add_argument(
+        "--use-logical-core",
+        "--use_logical_core",
+        action="store_true",
+        default=False,
+        help="Whether only use physical cores",
+    )
+    group.add_argument(
+        "--disable-numactl",
+        "--disable_numactl",
+        action="store_true",
+        default=False,
+        help="Disable numactl",
+    )
+    group.add_argument(
+        "--disable-taskset",
+        "--disable_taskset",
+        action="store_true",
+        default=False,
+        help="Disable taskset",
+    )
+    group.add_argument(
+        "--core-list",
+        "--core_list",
+        metavar="\b",
+        default=None,
+        type=str,
+        help='Specify the core list as "core_id,core_id,..."; otherwise, all the cores will be used.',
+    )
+    group.add_argument(
+        "--log-path",
+        "--log_path",
+        metavar="\b",
+        default="",
+        type=str,
+        help="The log file directory. Default path is "
+        ", which means disable logging to files.",
+    )
+    group.add_argument(
+        "--log-file-prefix",
+        "--log_file_prefix",
+        metavar="\b",
+        default="run",
+        type=str,
+        help="log file prefix",
+    )
+
+
+def _add_kmp_iomp_params(parser):
+    group = parser.add_argument_group("IOMP Parameters")
+    group.add_argument(
+        "--disable-iomp",
+        "--disable_iomp",
+        action="store_true",
+        default=False,
+        help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD",
+    )
+
+
+def create_args(parser=None):
+    """
+    Add the command-line options to the given ArgumentParser.
+    """
+    parser.add_argument(
+        "--multi-instance",
+        "--multi_instance",
+        action="store_true",
+        default=False,
+        help="Enable multi-instance, by default one instance per node",
+    )
+
+    parser.add_argument(
+        "-m",
+        "--module",
+        default=False,
+        action="store_true",
+        help="Changes each process to interpret the launch script "
+        "as a python module, executing with the same behavior as"
+        '"python -m".',
+    )
+
+    parser.add_argument(
+        "--no-python",
+        "--no_python",
+        default=False,
+        action="store_true",
+        help='Do not prepend the --program script with "python" - just exec '
+        "it directly. Useful when the script is not a Python script.",
+    )
+
+    _add_memory_allocator_params(parser)
+    _add_kmp_iomp_params(parser)
+
+    _add_multi_instance_params(parser)
+    # positional
+    parser.add_argument(
+        "program",
+        type=str,
+        help="The full path to the program/script to be launched. "
+        "followed by all the arguments for the script",
+    )
+
+    # rest from the training program
+    parser.add_argument("program_args", nargs=REMAINDER)
+
+
+def main(args):
+    env_before = set(os.environ.keys())
+    if platform.system() in ["Windows", "Darwin"]:
+        raise RuntimeError(f"{platform.system()} is not supported!!!")
+
+    if args.log_path:
+        os.makedirs(args.log_path, exist_ok=True)
+    else:
+        args.log_path = os.devnull
+
+    if args.latency_mode and args.throughput_mode:
+        raise RuntimeError(
+            "Either args.latency_mode or args.throughput_mode should be set"
+        )
+
+    if not args.no_python and not args.program.endswith(".py"):
+        raise RuntimeError(
+            'For a non-Python script, you should use the "--no-python" parameter.'
+        )
+
+    # Verify LD_PRELOAD
+    if "LD_PRELOAD" in os.environ:
+        lst_valid = []
+        tmp_ldpreload = os.environ["LD_PRELOAD"]
+        for item in tmp_ldpreload.split(":"):
+            matches = glob.glob(item)
+            if len(matches) > 0:
+                lst_valid.append(item)
+            else:
+                logger.warning("%s doesn't exist. Removing it from LD_PRELOAD.", item)
+        if len(lst_valid) > 0:
+            os.environ["LD_PRELOAD"] = ":".join(lst_valid)
+        else:
+            os.environ["LD_PRELOAD"] = ""
+
+    launcher = _Launcher()
+    launcher.launch(args)
+    for x in sorted(set(os.environ.keys()) - env_before):
+        logger.debug("%s=%s", x, os.environ[x])
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(
+        description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable "
+        "Processors with optimal configurations. Single instance inference, "
+        "multi-instance inference are enable. To get the peak performance on Intel(R) "
+        "Xeon(R) Scalable Processors, the script optimizes the configuration "
+        "of thread and memory management. For thread management, the script configures thread "
+        "affinity and the preload of Intel OMP library. For memory management, it configures "
+        "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) "
+        "\n################################# Basic usage ############################# \n"
+        "\n 1. single instance\n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu python_script args \n"
+        "\n2. multi-instance \n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu --ninstances xxx "
+        "--ncores-per-instance xx python_script args\n"
+        "\n############################################################################# \n",
+        formatter_class=RawTextHelpFormatter,
+    )
+    create_args(parser)
+    args = parser.parse_args()
+    main(args)
diff --git a/MLPY/Lib/site-packages/torch/backends/xnnpack/__init__.py b/MLPY/Lib/site-packages/torch/backends/xnnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4153d7fc4c81126af6fec996619fedc255c264fc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/backends/xnnpack/__init__.py
@@ -0,0 +1,28 @@
+import sys
+import types
+
+import torch
+
+
+class _XNNPACKEnabled:
+    def __get__(self, obj, objtype):
+        return torch._C._is_xnnpack_enabled()
+
+    def __set__(self, obj, val):
+        raise RuntimeError("Assignment not supported")
+
+
+class XNNPACKEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    enabled = _XNNPACKEnabled()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__)
diff --git a/MLPY/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ce784e99210a42c6b922fc17495b530f54ae8ee
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/bin/asmjit.dll b/MLPY/Lib/site-packages/torch/bin/asmjit.dll
new file mode 100644
index 0000000000000000000000000000000000000000..75d2557ec3e06bd90e8fc51199ca6dd1c73879d4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/bin/asmjit.dll differ
diff --git a/MLPY/Lib/site-packages/torch/bin/fbgemm.dll b/MLPY/Lib/site-packages/torch/bin/fbgemm.dll
new file mode 100644
index 0000000000000000000000000000000000000000..c2bafc25ee4fa9f7621a9a9b004d732ac61fab4f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/bin/fbgemm.dll
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e18ad86dae8caa56ebfd9da655f9c8b81d324a35586caf78734d9d0a48aa0518
+size 4961280
diff --git a/MLPY/Lib/site-packages/torch/bin/protoc.exe b/MLPY/Lib/site-packages/torch/bin/protoc.exe
new file mode 100644
index 0000000000000000000000000000000000000000..bc9a6b57f9b243fd57866b26fcedc9fc2bfd1fae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/bin/protoc.exe
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f10551c6cbc7187ee90ece18ffc24635dc1d308479718919a4807fee6c41551
+size 2812416
diff --git a/MLPY/Lib/site-packages/torch/compiler/__init__.py b/MLPY/Lib/site-packages/torch/compiler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17c5fcd57b44644193f2d7cd2519087d041ab68
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/compiler/__init__.py
@@ -0,0 +1,193 @@
+import torch
+from typing import List
+
+__all__ = [
+    "compile",
+    "assume_constant_result",
+    "reset",
+    "allow_in_graph",
+    "list_backends",
+    "disable",
+    "cudagraph_mark_step_begin",
+    "wrap_numpy",
+    "is_compiling",
+    "is_dynamo_compiling",
+]
+
+def compile(*args, **kwargs):
+    """
+    See :func:`torch.compile` for details on the arguments for this function.
+    """
+    return torch.compile(*args, **kwargs)
+
+def reset() -> None:
+    """
+    This function clears all compilation caches and restores the system to its initial state.
+    It is recommended to call this function after using operations like `torch.compile(...)`
+    to ensure a clean state before another, unrelated compilation.
+    """
+    import torch._dynamo
+
+    torch._dynamo.reset()
+
+def allow_in_graph(fn):
+    """
+    Customize which functions compilation will include in the generated graph.
+    It bypasses all introspection of the symbolic Python code in favor of
+    writing it directly to the graph.
+    If ``fn`` is a list or tuple of callables, it recursively applies :func:`allow_in_graph()`
+    to each function and returns a new list or tuple containing the modified functions.
+
+    Args:
+        fn: A callable representing the function to be included in the graph.
+
+    .. warning::
+
+        :func:`allow_in_graph` skips TorchDynamo completely on the decorated function,
+        skipping all TorchDynamo safety checks (graph breaks, handling closures, etc.).
+        Therefore, one has to be very careful with :func:`allow_in_graph`, since subsystems
+        like AOT Autograd rely on TorchDynamo.
+        If not careful, this could lead to soundness and really hard-to-debug issues.
+
+    """
+    import torch._dynamo
+
+    return torch._dynamo.allow_in_graph(fn)
+
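+# A minimal usage sketch (``my_helper`` is a hypothetical function): the allowed
+# helper is written into the captured graph as-is instead of being traced
+# through by TorchDynamo.
+#
+#     def my_helper(x):
+#         return x + 1
+#
+#     torch.compiler.allow_in_graph(my_helper)
+#
+#     @torch.compile(fullgraph=True)
+#     def fn(x):
+#         return my_helper(x) * 2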
+
+def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
+    """
+    Return valid strings that can be passed to `torch.compile(..., backend="name")`.
+
+    Args:
+        exclude_tags(optional): A tuple of strings representing tags to exclude.
+    """
+    import torch._dynamo
+
+    return torch._dynamo.list_backends(exclude_tags)
+
+def assume_constant_result(fn):
+    """
+    This function is used to mark a function `fn` as having a constant result.
+    This allows the compiler to optimize away your function.
+    Returns the same function `fn`.
+
+    Args:
+        fn: The function to be marked as having a constant result.
+
+    .. warning::
+        `assume_constant_result` can, if the assumption is invalid, cause safety and soundness issues;
+        :func:`torch.compile` will not attempt to validate whether the constant assumption is true or not.
+
+    """
+    import torch._dynamo
+
+    return torch._dynamo.assume_constant_result(fn)
+
+def disable(fn=None, recursive=True):
+    """
+    This function provides both a decorator and a context manager to disable compilation on a function.
+    It also provides the option of recursively disabling called functions.
+
+    Args:
+        fn (optional): The function to disable
+        recursive (optional): A boolean value indicating whether the disabling should be recursive.
+    """
+    import torch._dynamo
+
+    return torch._dynamo.disable(fn, recursive)
+
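+# A minimal usage sketch (``debug_dump`` is a hypothetical function): the
+# decorated function, and anything it calls (``recursive=True`` is the default),
+# is skipped by the compiler.
+#
+#     @torch.compiler.disable
+#     def debug_dump(x):
+#         print(x.shape)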
+def cudagraph_mark_step_begin():
+    """
+    Indicates that a new iteration of inference or training is about to begin.
+
+    CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of
+    torch.compile, so long as there is not a pending backward that has not been called.
+
+    If that heuristic is wrong, such as in the following example, manually mark it with this API.
+
+    .. code-block:: python
+
+        @torch.compile(mode="reduce-overhead")
+        def rand_foo():
+            return torch.rand([4], device="cuda")
+
+        for _ in range(5):
+            torch.compiler.cudagraph_mark_step_begin()
+            rand_foo() + rand_foo()
+
+    For more details, see the ``torch.compiler_cudagraph_trees`` note in the PyTorch documentation.
+    """
+    from torch._inductor import cudagraph_trees
+
+    cudagraph_trees.mark_step_begin()
+
+def wrap_numpy(fn):
+    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
+    from ``torch.Tensor``s to ``torch.Tensor``s.
+
+    It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows you to
+    compile a NumPy function as if it were a PyTorch function. This lets you run NumPy code
+    on CUDA or compute its gradients.
+
+    .. note::
+
+        This decorator does not work without :func:`torch.compile`.
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> # Compile a NumPy function as a Tensor -> Tensor function
+        >>> @torch.compile(fullgraph=True)
+        >>> @torch.compiler.wrap_numpy
+        >>> def fn(a: np.ndarray):
+        >>>     return np.sum(a * a)
+        >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
+        >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
+        >>> out = fn(x)
+        >>> out.backward()
+        >>> print(x.grad)
+        tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')
+    """
+    from torch._dynamo.external_utils import wrap_numpy as wrap
+    return wrap(fn)
+
+_is_compiling_flag: bool = False
+
+def is_compiling() -> bool:
+    """
+    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
+
+    Note that there are 2 other related flags that should be deprecated eventually:
+      * torch._dynamo.external_utils.is_compiling()
+      * torch._utils.is_compiling()
+
+    Example::
+
+        >>> def forward(self, x):
+        >>>     if not torch.compiler.is_compiling():
+        >>>        ...logic that is not needed in a compiled/traced graph...
+        >>>
+        >>>     ...rest of the function...
+    """
+    if torch.jit.is_scripting():
+        return False
+    else:
+        return _is_compiling_flag
+
+def is_dynamo_compiling() -> bool:
+    """
+    Indicates whether a graph is traced via TorchDynamo.
+
+    It's stricter than the is_compiling() flag, as it is only set to True when
+    TorchDynamo is used.
+
+    Example::
+
+        >>> def forward(self, x):
+        >>>     if not torch.compiler.is_dynamo_compiling():
+        >>>        ...logic that is not needed in a TorchDynamo-traced graph...
+        >>>
+        >>>     ...rest of the function...
+    """
+    return False
diff --git a/MLPY/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..543c074dd38a49922b284074b3950afa64900790
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/compiler/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/contrib/__init__.py b/MLPY/Lib/site-packages/torch/contrib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cfe12a556b18ba0d2e6d3940602ceea738971b3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/contrib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc b/MLPY/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c607878f3640d37db6ab12af9567a38d725c7f0d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/contrib/_tensorboard_vis.py b/MLPY/Lib/site-packages/torch/contrib/_tensorboard_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b411bb3fd2b52ccf7bd0fe176516789d521fa02
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/contrib/_tensorboard_vis.py
@@ -0,0 +1,142 @@
+import time
+from collections import defaultdict
+from functools import partial
+from typing import DefaultDict
+
+import torch
+
+
+# Unfortunately, there doesn't seem to be any way to get TensorBoard to do
+# anything without having TF installed, so this file has a hard dependency on it
+# as well. Since this is only a debugging tool, that is acceptable.
+try:
+    from tensorflow.core.util import event_pb2
+    from tensorflow.core.framework import graph_pb2
+    from tensorflow.python.summary.writer.writer import FileWriter
+except ImportError:
+    raise ImportError("TensorBoard visualization of GraphExecutors requires having "
+                      "TensorFlow installed") from None
+
+
+def dump_tensorboard_summary(graph_executor, logdir):
+    with FileWriter(logdir) as w:
+        pb_graph = visualize(graph_executor)
+        evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString())
+        w.add_event(evt)
+
+
+def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
+    """Visualizes an independent graph, or a graph executor."""
+    value_map = {}
+    pb_graph = pb_graph or graph_pb2.GraphDef()
+
+    if isinstance(graph, torch._C.GraphExecutorState):
+        visualize_graph_executor(graph, name_prefix, pb_graph,
+                                 partial(visualize, pb_graph=pb_graph))
+        return pb_graph
+
+    # Set up an input node
+    input_node = pb_graph.node.add(op='input', name=name_prefix + 'input')
+    for i, value in enumerate(graph.param_node().outputs()):
+        value_map[value.unique()] = name_prefix + 'input:' + str(i)
+
+    visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)
+
+    # Gather all outputs
+    return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
+    for value in graph.return_node().inputs():
+        return_node.input.append(value_map[value.unique()])
+
+    return pb_graph
+
+
+def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
+    """Append the state of a given GraphExecutor to the graph protobuf.
+
+    Args:
+        state (GraphExecutor or GraphExecutorState): GraphExecutor to display.
+        name_prefix (str): Name prefix of the containing subgraph.
+        pb_graph (GraphDef): graph to append to.
+        inline_graph (Callable): a function that handles setting up a value_map,
+            so that some graphs in here can be inlined. This is necessary, because
+            this will simply be `visualize` for the top-level GraphExecutor,
+            or `inline_graph` for all nested ones.
+
+            The signature should look like (Graph, name_prefix) -> ().
+            It will be called exactly once.
+
+    The strategy is to embed all different configurations as independent subgraphs,
+    while inlining the original graph as the one that actually produces the values.
+    """
+    if state.autograd_fallback_graph is not None:
+        visualize(graph=state.autograd_fallback_graph,
+                  name_prefix=name_prefix + 'autograd_fallback/',
+                  pb_graph=pb_graph,
+                  executors_it=iter(state.autograd_fallback.executors()))
+
+    for i, (arg_spec, plan) in enumerate(state.execution_plans.items()):
+        subgraph_name = name_prefix + f'plan{i}/'
+
+        # Create a disconnected node that will keep information regarding the input
+        # types of this trace. This is unfortunately a bit too verbose to be included
+        # in the subgraph name.
+        input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name)
+        input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii')
+
+        visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors()))
+
+        # Show gradient as an independent subgraph of this plan
+        if plan.grad_executor is not None:
+            grad_subgraph_name = subgraph_name + 'grad/'
+            visualize(plan.grad_executor, grad_subgraph_name, pb_graph)
+
+    return inline_graph(state.graph, name_prefix + 'original/')
+
+
+def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
+    """Recursive part of visualize (basically skips setting up the input and output nodes)."""
+    def inline_graph(subgraph, name, node):
+        rec_value_map = {inp.unique(): value_map[val.unique()]
+                         for inp, val in zip(subgraph.inputs(), node.inputs())}
+        visualize_rec(graph=subgraph,
+                      value_map=rec_value_map,
+                      name_prefix=name,
+                      pb_graph=pb_graph)
+        for out, val in zip(subgraph.outputs(), node.outputs()):
+            value_map[val.unique()] = rec_value_map[out.unique()]
+
+    op_id_counter: DefaultDict[str, int] = defaultdict(int)
+
+    def name_for(node):
+        kind = node.kind()[node.kind().index('::') + 2:]
+        op_id_counter[kind] += 1
+        return kind, name_prefix + kind + '_' + str(op_id_counter[kind])
+
+    def add_fusion_group(node):
+        op, name = name_for(node)
+        inline_graph(node.g('Subgraph'), name + '/', node)
+
+    def add_graph_executor(node):
+        op, name = name_for(node)
+        if executors_it is None:
+            add_node(node)
+        else:
+            ge = next(executors_it)
+            visualize_graph_executor(ge, name + '/', pb_graph,
+                                     partial(inline_graph, node=node))
+
+    def add_node(node):
+        if node.kind() == 'prim::FusionGroup':
+            return add_fusion_group(node)
+        elif node.kind() == 'prim::GraphExecutor':
+            return add_graph_executor(node)
+        op, name = name_for(node)
+        pb_node = pb_graph.node.add(op=op, name=name)
+        for value in node.inputs():
+            pb_node.input.append(value_map[value.unique()])
+        # TODO: handle attrs
+        for i, value in enumerate(node.outputs()):
+            value_map[value.unique()] = name + ':' + str(i)
+
+    for node in graph.nodes():
+        add_node(node)
diff --git a/MLPY/Lib/site-packages/torch/cpu/__init__.py b/MLPY/Lib/site-packages/torch/cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02bf9e7b7eb5f8e9a7e9440c70e9bbeec1ebc3ab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cpu/__init__.py
@@ -0,0 +1,157 @@
+r"""
+This package implements abstractions found in ``torch.cuda``
+to facilitate writing device-agnostic code.
+"""
+
+from contextlib import AbstractContextManager
+from typing import Any, Optional, Union
+
+import torch
+
+from .. import device as _device
+from . import amp
+
+__all__ = [
+    "is_available",
+    "synchronize",
+    "current_device",
+    "current_stream",
+    "stream",
+    "set_device",
+    "device_count",
+    "Stream",
+    "StreamContext",
+    "Event",
+]
+
+_device_t = Union[_device, str, int, None]
+
+
+def _is_cpu_support_vnni() -> bool:
+    r"""Returns a bool indicating if CPU supports VNNI."""
+    return torch._C._cpu._is_cpu_support_vnni()
+
+
+def is_available() -> bool:
+    r"""Returns a bool indicating if CPU is currently available.
+
+    N.B. This function only exists to facilitate device-agnostic code
+
+    """
+    return True
+
+
+def synchronize(device: _device_t = None) -> None:
+    r"""Waits for all kernels in all streams on the CPU device to complete.
+
+    Args:
+        device (torch.device or int, optional): ignored, there's only one CPU device.
+
+    N.B. This function only exists to facilitate device-agnostic code.
+    """
+    pass
+
+
+class Stream:
+    """
+    N.B. This class only exists to facilitate device-agnostic code
+    """
+
+    def __init__(self, priority: int = -1):
+        pass
+
+    def wait_stream(self, stream) -> None:
+        pass
+
+
+class Event:
+    def query(self) -> bool:
+        return True
+
+    def record(self, stream=None):
+        pass
+
+    def synchronize(self):
+        pass
+
+    def wait(self, stream=None):
+        pass
+
+
+_default_cpu_stream = Stream()
+_current_stream = _default_cpu_stream
+
+
+def current_stream(device: _device_t = None) -> Stream:
+    r"""Returns the currently selected :class:`Stream` for a given device.
+
+    Args:
+        device (torch.device or int, optional): Ignored.
+
+    N.B. This function only exists to facilitate device-agnostic code
+
+    """
+    return _current_stream
+
+
+class StreamContext(AbstractContextManager):
+    r"""Context-manager that selects a given stream.
+
+    N.B. This class only exists to facilitate device-agnostic code
+
+    """
+    cur_stream: Optional[Stream]
+
+    def __init__(self, stream):
+        self.stream = stream
+        self.prev_stream = _default_cpu_stream
+
+    def __enter__(self):
+        cur_stream = self.stream
+        if cur_stream is None:
+            return
+
+        global _current_stream
+        self.prev_stream = _current_stream
+        _current_stream = cur_stream
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        cur_stream = self.stream
+        if cur_stream is None:
+            return
+
+        global _current_stream
+        _current_stream = self.prev_stream
+
+
+def stream(stream: Stream) -> AbstractContextManager:
+    r"""Wrapper around the Context-manager StreamContext that
+    selects a given stream.
+
+    N.B. This function only exists to facilitate device-agnostic code
+    """
+    return StreamContext(stream)
+
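+# A small usage sketch: on CPU these stream objects are no-ops, but they let
+# code written against the ``torch.cuda``-style stream API run unchanged here.
+#
+#     s = torch.cpu.Stream()
+#     with torch.cpu.stream(s):
+#         y = torch.ones(2) * 2
+#     torch.cpu.synchronize()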
+
+def device_count() -> int:
+    r"""Returns number of CPU devices (not cores). Always 1.
+
+    N.B. This function only exists to facilitate device-agnostic code
+    """
+    return 1
+
+
+def set_device(device: _device_t) -> None:
+    r"""Sets the current device, in CPU we do nothing.
+
+    N.B. This function only exists to facilitate device-agnostic code
+    """
+    pass
+
+
+def current_device() -> str:
+    r"""Returns current device for cpu. Always 'cpu'.
+
+    N.B. This function only exists to facilitate device-agnostic code
+    """
+    return "cpu"
diff --git a/MLPY/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..774468c28dd9106645502203c00db2ee3a944f8a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cpu/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/__init__.py b/MLPY/Lib/site-packages/torch/cpu/amp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..973717d653ba9c47a5b63be6aa699dfe0d25b58f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cpu/amp/__init__.py
@@ -0,0 +1,2 @@
+from .autocast_mode import autocast
+from .grad_scaler import GradScaler
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b88b8da0941e5b4e3b98d14f104c44f4d2a2ee1f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa58d22a8bc16cb885d29fbe34eebca7c06fb6b1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc612338165d7539d56018a7375085ed5ecd146e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/autocast_mode.py b/MLPY/Lib/site-packages/torch/cpu/amp/autocast_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..03075f923746c3a7f625a50aea4ed1bb9eef6403
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cpu/amp/autocast_mode.py
@@ -0,0 +1,43 @@
+from typing import Any
+
+import torch
+
+__all__ = ["autocast"]
+
+
+class autocast(torch.amp.autocast_mode.autocast):
+    r"""
+    See :class:`torch.autocast`.
+    ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        dtype: torch.dtype = torch.bfloat16,
+        cache_enabled: bool = True,
+    ):
+        if torch._jit_internal.is_scripting():
+            self._enabled = enabled
+            self.device = "cpu"
+            self.fast_dtype = dtype
+            return
+        super().__init__(
+            "cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
+        )
+
+    def __enter__(self):
+        if torch._jit_internal.is_scripting():
+            return self
+        return super().__enter__()
+
+    # TODO: discuss a unified TorchScript-friendly API for autocast
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
+        if torch._jit_internal.is_scripting():
+            return
+        return super().__exit__(exc_type, exc_val, exc_tb)
+
+    def __call__(self, func):
+        if torch._jit_internal.is_scripting():
+            return func
+        return super().__call__(func)
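+
+
+# A minimal usage sketch (``model`` and ``inp`` are illustrative); per the class
+# docstring this is equivalent to ``torch.autocast("cpu", dtype=torch.bfloat16)``:
+#
+#     with autocast(dtype=torch.bfloat16):
+#         out = model(inp)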
diff --git a/MLPY/Lib/site-packages/torch/cpu/amp/grad_scaler.py b/MLPY/Lib/site-packages/torch/cpu/amp/grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d935371df1f66ee93675c259a51361d8caa903
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cpu/amp/grad_scaler.py
@@ -0,0 +1,27 @@
+import torch
+
+__all__ = ["GradScaler"]
+
+
+class GradScaler(torch.amp.GradScaler):
+    r"""
+    See :class:`torch.amp.GradScaler`.
+    ``torch.cpu.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cpu", args...)``
+    """
+
+    def __init__(
+        self,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
+        super().__init__(
+            "cpu",
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            enabled=enabled,
+        )
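+
+
+# A minimal training-step sketch (``model``, ``optimizer``, ``loss_fn`` and the
+# tensors are illustrative); scale()/step()/update() come from the
+# torch.amp.GradScaler base class:
+#
+#     scaler = GradScaler()
+#     with torch.autocast("cpu", dtype=torch.bfloat16):
+#         loss = loss_fn(model(inp), target)
+#     scaler.scale(loss).backward()
+#     scaler.step(optimizer)
+#     scaler.update()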
diff --git a/MLPY/Lib/site-packages/torch/cuda/__init__.py b/MLPY/Lib/site-packages/torch/cuda/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d49f1a4c4417ad0856bcf5d6a99c71a6afbf0363
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/__init__.py
@@ -0,0 +1,1412 @@
+r"""
+This package adds support for CUDA tensor types.
+
+It implements the same functions as CPU tensors, but they utilize
+GPUs for computation.
+
+It is lazily initialized, so you can always import it, and use
+:func:`is_available()` to determine if your system supports CUDA.
+
+:ref:`cuda-semantics` has more details about working with CUDA.
+"""
+
+
+import contextlib
+import importlib
+import os
+import sys
+import threading
+import traceback
+import warnings
+from functools import lru_cache
+from typing import Any, Callable, cast, List, Optional, Tuple, Union
+
+import torch
+import torch._C
+from torch.types import Device
+from .. import device as _device
+from .._utils import _dummy_type, _LazySeedTracker, classproperty
+from ._utils import _get_device_index
+from .graphs import (
+    CUDAGraph,
+    graph,
+    graph_pool_handle,
+    is_current_stream_capturing,
+    make_graphed_callables,
+)
+from .streams import Event, ExternalStream, Stream
+
+try:
+    from torch._C import _cudart  # type: ignore[attr-defined]
+except ImportError:
+    _cudart = None
+
+_initialized = False
+_tls = threading.local()
+_initialization_lock = threading.Lock()
+_queued_calls: List[
+    Tuple[Callable[[], None], List[str]]
+] = []  # don't invoke these until initialization occurs
+_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
+_device_t = Union[_device, str, int, None]
+
+_HAS_PYNVML = False
+_PYNVML_ERR = None
+try:
+    import pynvml  # type: ignore[import]
+
+    _HAS_PYNVML = True
+except ImportError as err:
+    _PYNVML_ERR = err  # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
+
+_lazy_seed_tracker = _LazySeedTracker()
+
+# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
+if hasattr(torch._C, "_CudaDeviceProperties"):
+    _CudaDeviceProperties = torch._C._CudaDeviceProperties
+else:
+    _CudaDeviceProperties = _dummy_type("_CudaDeviceProperties")  # type: ignore[assignment, misc]
+
+if hasattr(torch._C, "_cuda_exchangeDevice"):
+    _exchange_device = torch._C._cuda_exchangeDevice
+else:
+
+    def _exchange_device(device: int) -> int:
+        if device < 0:
+            return -1
+        raise RuntimeError("PyTorch was compiled without CUDA support")
+
+
+if hasattr(torch._C, "_cuda_maybeExchangeDevice"):
+    _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice
+else:
+
+    def _maybe_exchange_device(device: int) -> int:
+        if device < 0:
+            return -1
+        raise RuntimeError("PyTorch was compiled without CUDA support")
+
+
+has_half: bool = True
+has_magma: bool = torch._C._has_magma
+
+default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
+
+
+def _is_compiled() -> bool:
+    r"""Return true if compile with CUDA support."""
+    return hasattr(torch._C, "_cuda_getDeviceCount")
+
+
+def _nvml_based_avail() -> bool:
+    return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1"
+
+
+def is_available() -> bool:
+    r"""Return a bool indicating if CUDA is currently available."""
+    if not _is_compiled():
+        return False
+    if _nvml_based_avail():
+        # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
+        # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
+        # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
+        return device_count() > 0
+    else:
+        # The default availability inspection never throws and returns 0 if the driver is missing or can't
+        # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
+        # API via `cuInit`
+        return torch._C._cuda_getDeviceCount() > 0
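+
+
+# A minimal usage sketch for is_available(), assuming a script that should run on
+# both CUDA and CPU-only machines: pick the device once and pass it everywhere.
+#
+#     device = "cuda" if torch.cuda.is_available() else "cpu"
+#     x = torch.ones(2, 2, device=device)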
+
+
+def is_bf16_supported():
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+    # Check for ROCm first: if true, return True. No ROCM_VERSION check is
+    # required, since bfloat16 is supported on AMD GPU archs.
+    if torch.version.hip:
+        return True
+
+    device = torch.cuda.current_device()
+
+    # Check for CUDA version and device compute capability.
+    # This is a fast way to check for it.
+    cuda_version = torch.version.cuda
+    if (
+        cuda_version is not None
+        and int(cuda_version.split(".")[0]) >= 11
+        and torch.cuda.get_device_properties(device).major >= 8
+    ):
+        return True
+
+    # Finally, try to create a bfloat16 tensor on the device.
+    return _check_bf16_tensor_supported(device)
+
+
+@lru_cache(maxsize=16)
+def _check_bf16_tensor_supported(device: _device_t):
+    try:
+        torch.tensor([1.0], dtype=torch.bfloat16, device=device)
+        return True
+    except Exception:
+        return False
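+
+
+# A short sketch of how is_bf16_supported() is typically consumed: fall back to
+# float16 on devices without bfloat16 support.  `model` and `inputs` are
+# placeholders; assumes CUDA is available.
+#
+#     amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+#     with torch.autocast("cuda", dtype=amp_dtype):
+#         out = model(inputs)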
+
+
+def _sleep(cycles):
+    torch._C._cuda_sleep(cycles)
+
+
+def _check_capability():
+    incorrect_binary_warn = """
+    Found GPU%d %s which requires CUDA_VERSION >= %d to
+     work properly, but your PyTorch was compiled
+     with CUDA_VERSION %d. Please install the correct PyTorch binary
+     using instructions from https://pytorch.org
+    """
+
+    old_gpu_warn = """
+    Found GPU%d %s which is of cuda capability %d.%d.
+    PyTorch no longer supports this GPU because it is too old.
+    The minimum cuda capability supported by this library is %d.%d.
+    """
+
+    if torch.version.cuda is not None:  # on ROCm we don't want this check
+        CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+        for d in range(device_count()):
+            capability = get_device_capability(d)
+            major = capability[0]
+            minor = capability[1]
+            name = get_device_name(d)
+            current_arch = major * 10 + minor
+            min_arch = min(
+                (int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list()),
+                default=35,
+            )
+            if current_arch < min_arch:
+                warnings.warn(
+                    old_gpu_warn
+                    % (d, name, major, minor, min_arch // 10, min_arch % 10)
+                )
+
+
+def _check_cubins():
+    incompatible_device_warn = """
+{} with CUDA capability sm_{} is not compatible with the current PyTorch installation.
+The current PyTorch install supports CUDA capabilities {}.
+If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
+"""
+    if torch.version.cuda is None:  # on ROCm we don't want this check
+        return
+    arch_list = get_arch_list()
+    if len(arch_list) == 0:
+        return
+    supported_sm = [int(arch.split("_")[1]) for arch in arch_list if "sm_" in arch]
+    for idx in range(device_count()):
+        cap_major, cap_minor = get_device_capability(idx)
+        # NVIDIA GPU compute architectures are backward compatible within the same major version
+        supported = any(sm // 10 == cap_major for sm in supported_sm)
+        if not supported:
+            device_name = get_device_name(idx)
+            capability = cap_major * 10 + cap_minor
+            warnings.warn(
+                incompatible_device_warn.format(
+                    device_name, capability, " ".join(arch_list), device_name
+                )
+            )
+
+
+def is_initialized():
+    r"""Return whether PyTorch's CUDA state has been initialized."""
+    return _initialized and not _is_in_bad_fork()
+
+
+def _lazy_call(callable, **kwargs):
+    if is_initialized():
+        callable()
+    else:
+        # TODO(torch_deploy): this accesses linecache, which attempts to read the
+        # file system to get traceback info. Patch linecache or do something
+        # else here if this ends up being important.
+        global _lazy_seed_tracker
+        if kwargs.get("seed_all", False):
+            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
+        elif kwargs.get("seed", False):
+            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
+        else:
+            # Don't store the actual traceback to avoid memory cycle
+            _queued_calls.append((callable, traceback.format_stack()))
+
+
+_lazy_call(_check_capability)
+_lazy_call(_check_cubins)
+
+
+class DeferredCudaCallError(Exception):
+    pass
+
+
+OutOfMemoryError = torch._C._OutOfMemoryError
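+
+# A hedged sketch of catching allocation failures via the OutOfMemoryError alias
+# above; retrying with a smaller batch is just one possible recovery strategy, and
+# `batch` is a placeholder.
+#
+#     try:
+#         y = torch.empty(batch, 4096, 4096, device="cuda")
+#     except torch.cuda.OutOfMemoryError:
+#         torch.cuda.empty_cache()
+#         batch //= 2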
+
+
+def init():
+    r"""Initialize PyTorch's CUDA state.
+
+    You may need to call this explicitly if you are interacting with
+    PyTorch via its C API, as Python bindings for CUDA functionality
+    will not be available until this initialization takes place.
+    Ordinary users should not need this, as all of PyTorch's CUDA methods
+    automatically initialize CUDA state on-demand.
+
+    Does nothing if the CUDA state is already initialized.
+    """
+    _lazy_init()
+
+
+def _lazy_init():
+    global _initialized, _queued_calls
+    if is_initialized() or hasattr(_tls, "is_initializing"):
+        return
+    with _initialization_lock:
+        # This is double-checked locking.  This is OK because
+        # the above test was GIL protected anyway.  The inner test
+        # is for when a thread blocked on some other thread which was
+        # doing the initialization; when they get the lock, they will
+        # find there is nothing left to do.
+        if is_initialized():
+            return
+        # It is important to prevent other threads from entering _lazy_init
+        # immediately, while we are still guaranteed to have the GIL, because some
+        # of the C calls we make below will release the GIL
+        if _is_in_bad_fork():
+            raise RuntimeError(
+                "Cannot re-initialize CUDA in forked subprocess. To use CUDA with "
+                "multiprocessing, you must use the 'spawn' start method"
+            )
+        if not hasattr(torch._C, "_cuda_getDeviceCount"):
+            raise AssertionError("Torch not compiled with CUDA enabled")
+        if _cudart is None:
+            raise AssertionError(
+                "libcudart functions unavailable. It looks like you have a broken build?"
+            )
+        # This function throws if there's a driver initialization error, no GPUs
+        # are found or any other error occurs
+        if "CUDA_MODULE_LOADING" not in os.environ:
+            os.environ["CUDA_MODULE_LOADING"] = "LAZY"
+        torch._C._cuda_init()
+        # Some of the queued calls may reentrantly call _lazy_init();
+        # we need to just return without initializing in that case.
+        # However, we must not let any *other* threads in!
+        _tls.is_initializing = True
+
+        for calls in _lazy_seed_tracker.get_calls():
+            if calls:
+                _queued_calls.append(calls)
+
+        try:
+            for queued_call, orig_traceback in _queued_calls:
+                try:
+                    queued_call()
+                except Exception as e:
+                    msg = (
+                        f"CUDA call failed lazily at initialization with error: {str(e)}\n\n"
+                        f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}"
+                    )
+                    raise DeferredCudaCallError(msg) from e
+        finally:
+            delattr(_tls, "is_initializing")
+        _initialized = True
+
+
+def cudart():
+    _lazy_init()
+    return _cudart
+
+
+class cudaStatus:
+    SUCCESS: int = 0
+    ERROR_NOT_READY: int = 34
+
+
+class CudaError(RuntimeError):
+    def __init__(self, code: int) -> None:
+        msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
+        super().__init__(f"{msg} ({code})")
+
+
+def check_error(res: int) -> None:
+    if res != _cudart.cudaError.success:
+        raise CudaError(res)
+
+
+class _DeviceGuard:
+    def __init__(self, index: int):
+        self.idx = index
+        self.prev_idx = -1
+
+    def __enter__(self):
+        self.prev_idx = torch.cuda._exchange_device(self.idx)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
+        return False
+
+
+class device:
+    r"""Context-manager that changes the selected device.
+
+    Args:
+        device (torch.device or int): device index to select. It's a no-op if
+            this argument is a negative integer or ``None``.
+    """
+
+    def __init__(self, device: Any):
+        self.idx = _get_device_index(device, optional=True)
+        self.prev_idx = -1
+
+    def __enter__(self):
+        self.prev_idx = torch.cuda._exchange_device(self.idx)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
+        return False
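+
+
+# A minimal sketch of the `device` context manager above (assumes at least two
+# GPUs): allocations inside the block land on the selected device, and the
+# previous device is restored on exit.
+#
+#     with torch.cuda.device(1):
+#         a = torch.randn(3, device="cuda")   # allocated on GPU 1
+#     b = torch.randn(3, device="cuda")       # back on the previous device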
+
+
+class device_of(device):
+    r"""Context-manager that changes the current device to that of given object.
+
+    You can use both tensors and storages as arguments. If a given object is
+    not allocated on a GPU, this is a no-op.
+
+    Args:
+        obj (Tensor or Storage): object allocated on the selected device.
+    """
+
+    def __init__(self, obj):
+        idx = obj.get_device() if obj.is_cuda else -1
+        super().__init__(idx)
+
+
+def set_device(device: _device_t) -> None:
+    r"""Set the current device.
+
+    Usage of this function is discouraged in favor of :any:`device`. In most
+    cases it's better to use the ``CUDA_VISIBLE_DEVICES`` environment variable.
+
+    Args:
+        device (torch.device or int): selected device. This function is a no-op
+            if this argument is negative.
+    """
+    device = _get_device_index(device)
+    if device >= 0:
+        torch._C._cuda_setDevice(device)
+
+
+def get_device_name(device: Optional[_device_t] = None) -> str:
+    r"""Get the name of a device.
+
+    Args:
+        device (torch.device or int, optional): device for which to return the
+            name. This function is a no-op if this argument is a negative
+            integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Returns:
+        str: the name of the device
+    """
+    return get_device_properties(device).name
+
+
+def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]:
+    r"""Get the cuda capability of a device.
+
+    Args:
+        device (torch.device or int, optional): device for which to return the
+            device capability. This function is a no-op if this argument is
+            a negative integer. It uses the current device, given by
+            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            (default).
+
+    Returns:
+        tuple(int, int): the major and minor CUDA capability of the device
+    """
+    prop = get_device_properties(device)
+    return prop.major, prop.minor
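+
+
+# A short sketch using the capability tuple from get_device_capability(), e.g. to
+# gate features that need an Ampere-or-newer GPU (compute capability 8.0+);
+# assumes CUDA is available.
+#
+#     major, minor = torch.cuda.get_device_capability()
+#     if (major, minor) >= (8, 0):
+#         torch.backends.cuda.matmul.allow_tf32 = True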
+
+
+def get_device_properties(device: _device_t) -> _CudaDeviceProperties:
+    r"""Get the properties of a device.
+
+    Args:
+        device (torch.device or int or str): device for which to return the
+            properties.
+
+    Returns:
+        _CudaDeviceProperties: the properties of the device
+    """
+    _lazy_init()  # will define _get_device_properties
+    device = _get_device_index(device, optional=True)
+    if device < 0 or device >= device_count():
+        raise AssertionError("Invalid device id")
+    return _get_device_properties(device)  # type: ignore[name-defined]
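+
+
+# A brief sketch reading a few fields off the properties struct returned by
+# get_device_properties() (assumes CUDA is available); `total_memory` is in bytes.
+#
+#     props = torch.cuda.get_device_properties(0)
+#     print(props.name, f"sm_{props.major}{props.minor}",
+#           props.total_memory // 2**20, "MiB")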
+
+
+def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool:
+    r"""Check if peer access between two devices is possible."""
+    _lazy_init()
+    device = _get_device_index(device, optional=True)
+    peer_device = _get_device_index(peer_device)
+    if device < 0 or device >= device_count():
+        raise AssertionError("Invalid device id")
+    if peer_device < 0 or peer_device >= device_count():
+        raise AssertionError("Invalid peer device id")
+    return torch._C._cuda_canDeviceAccessPeer(device, peer_device)
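+
+
+# A sketch of checking peer access before relying on direct GPU-to-GPU copies
+# (assumes two visible devices; `src` is a placeholder tensor already on cuda:0).
+#
+#     if torch.cuda.can_device_access_peer(0, 1):
+#         dst = src.to("cuda:1", non_blocking=True)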
+
+
+class StreamContext:
+    r"""Context-manager that selects a given stream.
+
+    All CUDA kernels queued within its context will be enqueued on a selected
+    stream.
+
+    Args:
+        Stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note:: Streams are per-device.
+    """
+    cur_stream: Optional["torch.cuda.Stream"]
+
+    def __init__(self, stream: Optional["torch.cuda.Stream"]):
+        self.stream = stream
+        self.idx = _get_device_index(None, True)
+        if not torch.jit.is_scripting():
+            if self.idx is None:
+                self.idx = -1
+
+        self.src_prev_stream = (
+            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
+        )
+        self.dst_prev_stream = (
+            None if not torch.jit.is_scripting() else torch.cuda.default_stream(None)
+        )
+
+    def __enter__(self):
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # Return if stream is None or CUDA device not available
+        if cur_stream is None or self.idx == -1:
+            return
+        self.src_prev_stream = torch.cuda.current_stream(None)
+
+        # If the stream is not on the current device, then
+        # set the current stream on the device
+        if self.src_prev_stream.device != cur_stream.device:
+            with device(cur_stream.device):
+                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device)
+        torch.cuda.set_stream(cur_stream)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # If stream is None or no CUDA device available, return
+        if cur_stream is None or self.idx == -1:
+            return
+
+        # Reset the stream on the original device
+        # and destination device
+        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
+            torch.cuda.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
+        torch.cuda.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
+
+
+def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext:
+    r"""Wrap around the Context-manager StreamContext that selects a given stream.
+
+    Arguments:
+        stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note:: In eager mode ``stream`` is of type ``Stream``, while in JIT it is
+        an object of the custom class ``torch.classes.cuda.Stream``.
+    """
+    return StreamContext(stream)
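+
+
+# A minimal sketch of overlapping work on a side stream via torch.cuda.stream();
+# `x` is a placeholder CUDA tensor, and CUDA is assumed to be available.  The
+# wait_stream call orders later default-stream work after the side-stream kernels.
+#
+#     s = torch.cuda.Stream()
+#     with torch.cuda.stream(s):
+#         y = x * 2                               # queued on stream `s`
+#     torch.cuda.current_stream().wait_stream(s)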
+
+
+def _set_stream_by_id(stream_id, device_index, device_type):
+    r"""set stream specified by the stream id, device index and
+        device type
+
+    Args: stream_id (int): stream id in stream pool
+          device_index (int): device index in topo
+          device_type (int): enum device type
+    """
+    torch._C._cuda_setStream(
+        stream_id=stream_id,
+        device_index=device_index,
+        device_type=device_type,
+    )
+
+
+def set_stream(stream: Stream):
+    r"""Set the current stream.This is a wrapper API to set the stream.
+        Usage of this function is discouraged in favor of the ``stream``
+        context manager.
+
+    Args:
+        stream (Stream): selected stream. This function is a no-op
+            if this argument is ``None``.
+    """
+    if stream is None:
+        return
+    _set_stream_by_id(
+        stream_id=stream.stream_id,
+        device_index=stream.device_index,
+        device_type=stream.device_type,
+    )
+
+
+def _parse_visible_devices() -> Union[List[int], List[str]]:
+    r"""Parse CUDA_VISIBLE_DEVICES environment variable."""
+    var = os.getenv("CUDA_VISIBLE_DEVICES")
+    if var is None:
+        return list(range(64))
+
+    def _strtoul(s: str) -> int:
+        """Return -1 or positive integer sequence string starts with."""
+        if not s:
+            return -1
+        for idx, c in enumerate(s):
+            if not (c.isdigit() or (idx == 0 and c in "+-")):
+                break
+            if idx + 1 == len(s):
+                idx += 1
+        return int(s[:idx]) if idx > 0 else -1
+
+    def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
+        rcs: List[str] = []
+        for elem in lst.split(","):
+            # Repeated id results in empty set
+            if elem in rcs:
+                return cast(List[str], [])
+            # Anything that does not start with the prefix stops the parsing
+            if not elem.startswith(prefix):
+                break
+            rcs.append(elem)
+        return rcs
+
+    if var.startswith("GPU-"):
+        return parse_list_with_prefix(var, "GPU-")
+    if var.startswith("MIG-"):
+        return parse_list_with_prefix(var, "MIG-")
+    # CUDA_VISIBLE_DEVICES uses something like strtoul,
+    # which makes `1gpu2,2ampere` equivalent to `1,2`
+    rc: List[int] = []
+    for elem in var.split(","):
+        x = _strtoul(elem.strip())
+        # Repeated ordinal results in empty set
+        if x in rc:
+            return cast(List[int], [])
+        # Negative value aborts the sequence
+        if x < 0:
+            break
+        rc.append(x)
+    return rc
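+
+
+# Illustrative results of _parse_visible_devices() for a few CUDA_VISIBLE_DEVICES
+# values, derived from the parsing rules above:
+#
+#     "0,2"           -> [0, 2]
+#     "1gpu2,2ampere" -> [1, 2]            # strtoul-like parsing keeps leading digits
+#     "0,2,-1,3"      -> [0, 2]            # a negative value aborts the sequence
+#     "0,0"           -> []                # repeated ordinals yield an empty set
+#     "GPU-deadbeef"  -> ["GPU-deadbeef"]  # UUID-prefix form is kept as strings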
+
+
+def _raw_device_count_nvml() -> int:
+    r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
+    from ctypes import byref, c_int, CDLL
+
+    nvml_h = CDLL("libnvidia-ml.so.1")
+    rc = nvml_h.nvmlInit()
+    if rc != 0:
+        warnings.warn("Can't initialize NVML")
+        return -1
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
+    if rc != 0:
+        warnings.warn("Can't get nvml device count")
+        return -1
+    del nvml_h
+    return dev_count.value
+
+
+def _raw_device_uuid_nvml() -> Optional[List[str]]:
+    r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
+    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
+
+    nvml_h = CDLL("libnvidia-ml.so.1")
+    rc = nvml_h.nvmlInit()
+    if rc != 0:
+        warnings.warn("Can't initialize NVML")
+        return None
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
+    if rc != 0:
+        warnings.warn("Can't get nvml device count")
+        return None
+    uuids: List[str] = []
+    for idx in range(dev_count.value):
+        dev_id = c_void_p()
+        rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
+        if rc != 0:
+            warnings.warn("Can't get device handle")
+            return None
+        buf_len = 96
+        buf = create_string_buffer(buf_len)
+        rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
+        if rc != 0:
+            warnings.warn("Can't get device UUID")
+            return None
+        uuids.append(buf.raw.decode("ascii").strip("\0"))
+    del nvml_h
+    return uuids
+
+
+def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
+    r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs."""
+
+    def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
+        best_match = -1
+        for idx, uuid in enumerate(uuids):
+            if not uuid.startswith(candidate):
+                continue
+            # Ambiguous candidate
+            if best_match != -1:
+                return -1
+            best_match = idx
+        return best_match
+
+    rc: List[int] = []
+    for candidate in candidates:
+        idx = uuid_to_ordinal(candidate, uuids)
+        # First invalid ordinal stops parsing
+        if idx < 0:
+            break
+        # Duplicates result in empty set
+        if idx in rc:
+            return cast(List[int], [])
+        rc.append(idx)
+    return rc
+
+
+def _device_count_nvml() -> int:
+    r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
+
+    Negative value is returned if NVML discovery or initialization has failed.
+    """
+    visible_devices = _parse_visible_devices()
+    if not visible_devices:
+        return 0
+    try:
+        if type(visible_devices[0]) is str:
+            # Skip MIG parsing
+            if visible_devices[0].startswith("MIG-"):
+                return -1
+            uuids = _raw_device_uuid_nvml()
+            if uuids is None:
+                return -1
+            visible_devices = _transform_uuid_to_ordinals(
+                cast(List[str], visible_devices), uuids
+            )
+        else:
+            raw_cnt = _raw_device_count_nvml()
+            if raw_cnt <= 0:
+                return raw_cnt
+            # Trim the list up to a maximum available device
+            for idx, val in enumerate(visible_devices):
+                if cast(int, val) >= raw_cnt:
+                    return idx
+    except OSError:
+        return -1
+    except AttributeError:
+        return -1
+    return len(visible_devices)
+
+
+def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int:
+    r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account."""
+    idx = _get_device_index(device, optional=True)
+    visible_devices = _parse_visible_devices()
+    if type(visible_devices[0]) is str:
+        uuids = _raw_device_uuid_nvml()
+        if uuids is None:
+            raise RuntimeError("Can't get device UUIDs")
+        visible_devices = _transform_uuid_to_ordinals(
+            cast(List[str], visible_devices), uuids
+        )
+    visible_devices = cast(List[int], visible_devices)
+    if idx < 0 or idx >= len(visible_devices):
+        raise RuntimeError(
+            f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})"
+        )
+    return visible_devices[idx]
+
+
+@lru_cache(maxsize=1)
+def device_count() -> int:
+    r"""Return the number of GPUs available."""
+    if not _is_compiled():
+        return 0
+    # bypass _device_count_nvml() if rocm (not supported)
+    nvml_count = -1 if torch.version.hip else _device_count_nvml()
+    return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
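+
+
+# A small sketch enumerating visible devices; device_count() is lru_cache'd, so it
+# is cheap to call repeatedly.
+#
+#     for i in range(torch.cuda.device_count()):
+#         print(i, torch.cuda.get_device_name(i))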
+
+
+def get_arch_list() -> List[str]:
+    r"""Return list CUDA architectures this library was compiled for."""
+    if not is_available():
+        return []
+    arch_flags = torch._C._cuda_getArchFlags()
+    if arch_flags is None:
+        return []
+    return arch_flags.split()
+
+
+def get_gencode_flags() -> str:
+    r"""Return NVCC gencode flags this library was compiled with."""
+    arch_list = get_arch_list()
+    if len(arch_list) == 0:
+        return ""
+    arch_list_ = [arch.split("_") for arch in arch_list]
+    return " ".join(
+        [
+            f"-gencode compute=compute_{arch},code={kind}_{arch}"
+            for (kind, arch) in arch_list_
+        ]
+    )
+
+
+def current_device() -> int:
+    r"""Return the index of a currently selected device."""
+    _lazy_init()
+    return torch._C._cuda_getDevice()
+
+
+def synchronize(device: _device_t = None) -> None:
+    r"""Wait for all kernels in all streams on a CUDA device to complete.
+
+    Args:
+        device (torch.device or int, optional): device for which to synchronize.
+            It uses the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+    """
+    _lazy_init()
+    with torch.cuda.device(device):
+        return torch._C._cuda_synchronize()
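+
+
+# A minimal timing sketch: kernel launches are asynchronous, so wall-clock
+# measurements need a synchronize() before reading the clock.  `x` is a
+# placeholder CUDA tensor.
+#
+#     import time
+#     torch.cuda.synchronize()
+#     t0 = time.perf_counter()
+#     y = x @ x
+#     torch.cuda.synchronize()
+#     print(time.perf_counter() - t0)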
+
+
+def ipc_collect():
+    r"""Force collects GPU memory after it has been released by CUDA IPC.
+
+    .. note::
+        Checks if any sent CUDA tensors could be cleaned from the memory. Force
+        closes shared memory file used for reference counting if there is no
+        active counters. Useful when the producer process stopped actively sending
+        tensors and want to release unused memory.
+    """
+    _lazy_init()
+    return torch._C._cuda_ipc_collect()
+
+
+def current_stream(device: Optional[_device_t] = None) -> Stream:
+    r"""Return the currently selected :class:`Stream` for a given device.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            the currently selected :class:`Stream` for the current device, given
+            by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            (default).
+    """
+    _lazy_init()
+    streamdata = torch._C._cuda_getCurrentStream(
+        _get_device_index(device, optional=True)
+    )
+    return Stream(
+        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
+    )
+
+
+def default_stream(device: Optional[_device_t] = None) -> Stream:
+    r"""Return the default :class:`Stream` for a given device.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            the default :class:`Stream` for the current device, given by
+            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            (default).
+    """
+    _lazy_init()
+    streamdata = torch._C._cuda_getDefaultStream(
+        _get_device_index(device, optional=True)
+    )
+    return Stream(
+        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
+    )
+
+
+def current_blas_handle():
+    r"""Return cublasHandle_t pointer to current cuBLAS handle"""
+    _lazy_init()
+    return torch._C._cuda_getCurrentBlasHandle()
+
+
+def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
+    r"""Set the debug mode for cuda synchronizing operations.
+
+    Args:
+        debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
+            if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out synchronizing operations.
+
+    Warning:
+        This is an experimental feature, and not all synchronizing operations will trigger a warning or an error. In
+        particular, operations in the torch.distributed and torch.sparse namespaces are not covered yet.
+    """
+    _lazy_init()
+    if isinstance(debug_mode, str):
+        if debug_mode == "default":
+            debug_mode = 0
+        elif debug_mode == "warn":
+            debug_mode = 1
+        elif debug_mode == "error":
+            debug_mode = 2
+        else:
+            raise RuntimeError(
+                "invalid value of debug_mode, expected one of `default`, `warn`, `error`"
+            )
+
+    torch._C._cuda_set_sync_debug_mode(debug_mode)
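+
+
+# A short sketch for hunting accidental host/device synchronizations with
+# set_sync_debug_mode(): "warn" makes synchronizing eager ops emit warnings.
+# `loss` is a placeholder CUDA scalar tensor.
+#
+#     torch.cuda.set_sync_debug_mode("warn")
+#     loss.item()                     # .item() synchronizes, so a warning is emitted
+#     torch.cuda.set_sync_debug_mode("default")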
+
+
+def get_sync_debug_mode() -> int:
+    r"""Return current value of debug mode for cuda synchronizing operations."""
+    _lazy_init()
+    return torch._C._cuda_get_sync_debug_mode()
+
+
+def _get_pynvml_handler(device: Optional[Union[Device, int]] = None):
+    if not _HAS_PYNVML:
+        raise ModuleNotFoundError(
+            "pynvml does not seem to be installed or it can't be imported."
+        ) from _PYNVML_ERR
+    from pynvml import NVMLError_DriverNotLoaded
+
+    try:
+        pynvml.nvmlInit()
+    except NVMLError_DriverNotLoaded as e:
+        raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e
+
+    device = _get_nvml_device_index(device)
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    return handle
+
+
+def memory_usage(device: Optional[Union[Device, int]] = None) -> int:
+    r"""Return the percent of time over the past sample period during which global (device)
+    memory was being read or written as given by `nvidia-smi`.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Warning: Each sample period may be between 1 second and 1/6 second,
+    depending on the product being queried.
+    """
+    handle = _get_pynvml_handler()
+
+    device = _get_nvml_device_index(device)
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    return pynvml.nvmlDeviceGetUtilizationRates(handle).memory
+
+
+def utilization(device: Optional[Union[Device, int]] = None) -> int:
+    r"""Return the percent of time over the past sample period during which one or
+    more kernels was executing on the GPU as given by `nvidia-smi`.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Warning: Each sample period may be between 1 second and 1/6 second,
+    depending on the product being queried.
+    """
+    handle = _get_pynvml_handler(device)
+    device = _get_nvml_device_index(device)
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
+
+
+def temperature(device: Optional[Union[Device, int]] = None) -> int:
+    r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades).
+
+    The average temperature is computed based on past sample period as given by `nvidia-smi`.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Warning: Each sample period may be between 1 second and 1/6 second,
+    depending on the product being queried.
+    """
+    handle = _get_pynvml_handler(device)
+    # 0 refers to the temperature sensor for the GPU die.
+    return pynvml.nvmlDeviceGetTemperature(handle, 0)
+
+
+def power_draw(device: Optional[Union[Device, int]] = None) -> int:
+    r"""Return the average power draw of the GPU sensor in mW (MilliWatts)
+        over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Warning: Each sample period may be between 1 second and 1/6 second,
+    depending on the product being queried.
+    """
+    handle = _get_pynvml_handler(device)
+    return pynvml.nvmlDeviceGetPowerUsage(handle)
+
+
+def clock_rate(device: Optional[Union[Device, int]] = None) -> int:
+    r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    Warning: Each sample period may be between 1 second and 1/6 second,
+    depending on the product being queried.
+    """
+    handle = _get_pynvml_handler(device)
+    return pynvml.nvmlDeviceGetClockInfo(handle, 1)
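+
+
+# A sketch polling the NVML-backed metrics above (assumes `pynvml` is installed
+# and an NVIDIA driver is loaded); each accepts an optional device argument.
+#
+#     print("util %  :", torch.cuda.utilization())
+#     print("mem %   :", torch.cuda.memory_usage())
+#     print("temp C  :", torch.cuda.temperature())
+#     print("power mW:", torch.cuda.power_draw())
+#     print("clock   :", torch.cuda.clock_rate())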
+
+
+def _get_device(device: Union[int, str, torch.device]) -> torch.device:
+    r"""Return the torch.device type object from the passed in device.
+
+    Args:
+        device (torch.device or int): selected device.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+    return device
+
+
+def _get_generator(device: torch.device) -> torch._C.Generator:
+    r"""Return the CUDA Generator object for the given device.
+
+    Args:
+        device (torch.device): selected device.
+    """
+    idx = device.index
+    if idx is None:
+        idx = current_device()
+    return torch.cuda.default_generators[idx]
+
+
+def _set_rng_state_offset(
+    offset: int, device: Union[int, str, torch.device] = "cuda"
+) -> None:
+    r"""Set the random number generator state offset of the specified GPU.
+
+    Args:
+        offset (int): The desired offset
+        device (torch.device or int, optional): The device to set the RNG state.
+            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
+    """
+    final_device = _get_device(device)
+
+    def cb():
+        default_generator = _get_generator(final_device)
+        default_generator.set_offset(offset)
+
+    _lazy_call(cb)
+
+
+def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int:
+    r"""Return the random number generator state offset of the specified GPU.
+
+    Args:
+        device (torch.device or int, optional): The device to return the RNG state offset of.
+            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
+
+    .. warning::
+        This function eagerly initializes CUDA.
+    """
+    _lazy_init()
+    final_device = _get_device(device)
+    default_generator = _get_generator(final_device)
+    return default_generator.get_offset()
+
+
+from .memory import *  # noqa: F403
+
+
+from .random import *  # noqa: F403
+
+################################################################################
+# Define Storage and Tensor classes
+################################################################################
+
+
+@staticmethod  # type: ignore[misc]
+def _lazy_new(cls, *args, **kwargs):
+    _lazy_init()
+    # We may need to call lazy init again if we are a forked child
+    # del _CudaBase.__new__
+    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
+
+
+class _CudaBase:
+    is_cuda = True
+    is_sparse = False
+
+    def type(self, *args, **kwargs):
+        # We could use a Protocol here to tell mypy that self has `get_device` method
+        # but it is only available in the typing module on Python >= 3.8
+        # or on typing_extensions module on Python >= 3.6
+        with device(self.get_device()):  # type: ignore[attr-defined]
+            return super().type(*args, **kwargs)  # type: ignore[misc]
+
+    __new__ = _lazy_new
+
+
+from torch.storage import _LegacyStorage, _warn_typed_storage_removal
+
+
+class _CudaLegacyStorage(_LegacyStorage):
+    @classmethod
+    def from_buffer(cls, *args, **kwargs):
+        _warn_typed_storage_removal()
+        raise RuntimeError("from_buffer: Not available for CUDA storage")
+
+    @classmethod
+    def _new_with_weak_ptr(cls, *args, **kwargs):
+        raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage")
+
+    @classmethod
+    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
+        raise RuntimeError("_new_shared_filename: Not available for CUDA storage")
+
+
+class ByteStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.uint8
+
+
+class DoubleStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.double
+
+
+class FloatStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.float
+
+
+class HalfStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.half
+
+
+class LongStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.long
+
+
+class IntStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.int
+
+
+class ShortStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.short
+
+
+class CharStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.int8
+
+
+class BoolStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.bool
+
+
+class BFloat16Storage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.bfloat16
+
+
+class ComplexDoubleStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.cdouble
+
+
+class ComplexFloatStorage(_CudaLegacyStorage):
+    @classproperty
+    def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
+        return torch.cfloat
+
+
+del _LegacyStorage
+del _CudaLegacyStorage
+
+torch._storage_classes.add(DoubleStorage)
+torch._storage_classes.add(FloatStorage)
+torch._storage_classes.add(LongStorage)
+torch._storage_classes.add(IntStorage)
+torch._storage_classes.add(ShortStorage)
+torch._storage_classes.add(CharStorage)
+torch._storage_classes.add(ByteStorage)
+torch._storage_classes.add(HalfStorage)
+torch._storage_classes.add(BoolStorage)
+torch._storage_classes.add(BFloat16Storage)
+torch._storage_classes.add(ComplexDoubleStorage)
+torch._storage_classes.add(ComplexFloatStorage)
+
+
+class _WrappedTritonKernel:
+    """Just a simple wrapper to store some metadata for testing purposes."""
+
+    def __init__(self, kernel):
+        self.kernel = kernel
+        self.kernel_invoked = False
+
+    def __call__(self, *args, **kwargs):
+        res = self.kernel(*args, **kwargs)
+        self.kernel_invoked = True
+        return res
+
+
+def _register_triton_kernels():
+    if torch._running_with_deploy():
+        return
+
+    @_WrappedTritonKernel
+    def kernel_impl(*args, **kwargs):
+        from torch.sparse._triton_ops import bsr_dense_mm
+
+        return bsr_dense_mm(*args, skip_checks=True, **kwargs)
+
+    @_WrappedTritonKernel
+    def addmm_kernel_impl(*args, **kwargs):
+        from torch.sparse._triton_ops import bsr_dense_addmm
+
+        return bsr_dense_addmm(*args, skip_checks=True, **kwargs)
+
+    has_triton = importlib.util.find_spec("triton") is not None
+    if has_triton:
+        torch._TritonLibrary.registerOp(
+            "_triton_bsr_dense_mm_out",
+            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
+            kernel_impl,
+            "SparseCsrCUDA",
+        )
+
+        torch._TritonLibrary.registerOp(
+            "_triton_bsr_dense_addmm_out",
+            (
+                "_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense,"
+                " *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)"
+            ),
+            addmm_kernel_impl,
+            "SparseCsrCUDA",
+        )
+
+
+_lazy_call(_register_triton_kernels)
+
+
+from . import amp, jiterator, nvtx, profiler, sparse
+
+__all__ = [
+    # Typed storage and tensors
+    "BFloat16Storage",
+    "BFloat16Tensor",
+    "BoolStorage",
+    "BoolTensor",
+    "ByteStorage",
+    "ByteTensor",
+    "CharStorage",
+    "CharTensor",
+    "ComplexDoubleStorage",
+    "ComplexFloatStorage",
+    "DoubleStorage",
+    "DoubleTensor",
+    "FloatStorage",
+    "FloatTensor",
+    "HalfStorage",
+    "HalfTensor",
+    "IntStorage",
+    "IntTensor",
+    "LongStorage",
+    "LongTensor",
+    "ShortStorage",
+    "ShortTensor",
+    "CUDAGraph",
+    "CudaError",
+    "DeferredCudaCallError",
+    "Event",
+    "ExternalStream",
+    "OutOfMemoryError",
+    "Stream",
+    "StreamContext",
+    "amp",
+    "caching_allocator_alloc",
+    "caching_allocator_delete",
+    "can_device_access_peer",
+    "check_error",
+    "cudaStatus",
+    "cudart",
+    "current_blas_handle",
+    "current_device",
+    "current_stream",
+    "default_generators",
+    "default_stream",
+    "device",
+    "device_count",
+    "device_of",
+    "empty_cache",
+    "get_allocator_backend",
+    "CUDAPluggableAllocator",
+    "change_current_allocator",
+    "get_arch_list",
+    "get_device_capability",
+    "get_device_name",
+    "get_device_properties",
+    "get_gencode_flags",
+    "get_rng_state",
+    "get_rng_state_all",
+    "get_sync_debug_mode",
+    "graph",
+    "graph_pool_handle",
+    "graphs",
+    "has_half",
+    "has_magma",
+    "init",
+    "initial_seed",
+    "ipc_collect",
+    "is_available",
+    "is_bf16_supported",
+    "is_current_stream_capturing",
+    "is_initialized",
+    "jiterator",
+    "list_gpu_processes",
+    "make_graphed_callables",
+    "manual_seed",
+    "manual_seed_all",
+    "max_memory_allocated",
+    "max_memory_cached",
+    "max_memory_reserved",
+    "mem_get_info",
+    "memory",
+    "memory_allocated",
+    "memory_cached",
+    "memory_reserved",
+    "memory_snapshot",
+    "memory_stats",
+    "memory_stats_as_nested_dict",
+    "memory_summary",
+    "memory_usage",
+    "temperature",
+    "power_draw",
+    "clock_rate",
+    "nccl",
+    "nvtx",
+    "profiler",
+    "random",
+    "reset_accumulated_memory_stats",
+    "reset_max_memory_allocated",
+    "reset_max_memory_cached",
+    "reset_peak_memory_stats",
+    "seed",
+    "seed_all",
+    "set_device",
+    "set_per_process_memory_fraction",
+    "set_rng_state",
+    "set_rng_state_all",
+    "set_stream",
+    "set_sync_debug_mode",
+    "sparse",
+    "stream",
+    "streams",
+    "synchronize",
+    "utilization",
+]
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3096f4e8512a1c5cc37d7328d5bb79cc24640b4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52ac78d5adff2e7ce813288779dec57ce5ce2954
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a79d4618c950de3b366622fbcc6a1cbbc9da95e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..227d58e8b6a28b41e3123278d85d87bdfefbe0c3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fb856cc5f5a44403612ec678de1174d2eda150f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/comm.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70c4e65a59eee7949c863c25f7425f82262da24e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/error.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19caddd9980441d869e4befd4b346ec28d336a54
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/graphs.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8cc9465913ef891167f69289b0371b7c132e0f6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/jiterator.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48d451960c45f566a189beaf1ab4038444303c06
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/memory.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66e77a67e7a99b038c21600a9dca06d76a902af1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/nccl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09bde3a7d236c6828e79c62ffa106452795c2b9e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/nvtx.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9842e1e926c3ba9e5cd8bbe52e294ecf3ab17672
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/profiler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1611416ee848d252856cdc44efbeb42679b620f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/random.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..162da8c0e2dff0ee9c526bad6fc0a9455e1b7f62
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/sparse.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e15f5665ef01a2b6c455ff6d3eb2fc1bf20a5f49
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/__pycache__/streams.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/_memory_viz.py b/MLPY/Lib/site-packages/torch/cuda/_memory_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c959803f4d01d04b8590fe678df2124c4ba9c68
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/_memory_viz.py
@@ -0,0 +1,626 @@
+import pickle
+import sys
+import os
+import io
+import subprocess
+import json
+from functools import lru_cache
+from typing import Any
+from itertools import groupby
+import base64
+import warnings
+
+cache = lru_cache(None)
+
+__all__ = ["format_flamegraph", "segments", "memory", "compare"]
+
+def _frame_fmt(f, full_filename=False):
+    i = f['line']
+    fname = f['filename']
+    if not full_filename:
+        fname = fname.split('/')[-1]
+    func = f['name']
+    return f'{fname}:{i}:{func}'
+
+@cache
+def _frame_filter(name, filename):
+    omit_functions = [
+        "unwind::unwind",
+        "CapturedTraceback::gather",
+        "gather_with_cpp",
+        "_start",
+        "__libc_start_main",
+        "PyEval_",
+        "PyObject_",
+        "PyFunction_",
+    ]
+    omit_filenames = [
+        "core/boxing",
+        "/Register",
+        "/Redispatch",
+        "pythonrun.c",
+        "Modules/main.c",
+        "Objects/call.c",
+        "Objects/methodobject.c",
+        "pycore_ceval.h",
+        "ceval.c",
+        "cpython/abstract.h",
+    ]
+    for of in omit_functions:
+        if of in name:
+            return False
+    for of in omit_filenames:
+        if of in filename:
+            return False
+    return True
+
+def _frames_fmt(frames, full_filename=False, reverse=False):
+    if reverse:
+        frames = reversed(frames)
+    return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]
+
+def _block_extra_legacy(b):
+    if 'history' in b:
+        frames = b['history'][0].get('frames', [])
+        real_size = b['history'][0]['real_size']
+    else:
+        real_size = b.get('requested_size', b['size'])
+        frames = []
+    return frames, real_size
+
+def _block_extra(b):
+    if 'frames' not in b:
+        # old snapshot format made it more complicated to get frames/allocated size
+        return _block_extra_legacy(b)
+    return b['frames'], b['requested_size']
+
+def format_flamegraph(flamegraph_lines, flamegraph_script=None):
+    if flamegraph_script is None:
+        flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
+    if not os.path.exists(flamegraph_script):
+        import urllib.request
+        print(f"Downloading flamegraph.pl to: {flamegraph_script}")
+        urllib.request.urlretrieve(
+            'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
+        subprocess.check_call(['chmod', '+x', flamegraph_script])
+    args = [flamegraph_script, '--countname', 'bytes']
+    p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
+    assert p.stdin is not None
+    assert p.stdout is not None
+    p.stdin.write(flamegraph_lines)
+    p.stdin.close()
+    result = p.stdout.read()
+    p.stdout.close()
+    assert p.wait() == 0
+    return result
+
+def _write_blocks(f, prefix, blocks):
+    def frames_fragment(frames):
+        if not frames:
+            return ""
+        return ';'.join(_frames_fmt(frames, reverse=True))
+    for b in blocks:
+        if 'history' not in b:
+            frames, accounted_for_size = _block_extra(b)
+            f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
+        else:
+            accounted_for_size = 0
+            for h in b['history']:
+                sz = h['real_size']
+                accounted_for_size += sz
+                if 'frames' in h:
+                    frames = h['frames']
+                    f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
+                else:
+                    f.write(f'{prefix};{b["state"]}; {sz}\n')
+        gaps = b['size'] - accounted_for_size
+        if gaps:
+            f.write(f'{prefix};{b["state"]}; {gaps}\n')
+
+def segments(snapshot, format_flamegraph=format_flamegraph):
+    f = io.StringIO()
+    for seg in snapshot['segments']:
+        prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
+        _write_blocks(f, prefix, seg['blocks'])
+    return format_flamegraph(f.getvalue())
+
+def memory(snapshot, format_flamegraph=format_flamegraph):
+    f = io.StringIO()
+    for seg in snapshot['segments']:
+        prefix = f'stream_{seg["stream"]}'
+        _write_blocks(f, prefix, seg['blocks'])
+    return format_flamegraph(f.getvalue())
+
+def compare(before, after, format_flamegraph=format_flamegraph):
+    def _seg_key(seg):
+        return (seg['address'], seg['total_size'])
+
+    def _seg_info(seg):
+        return f'stream_{seg["stream"]};seg_{seg["address"]}'
+
+    f = io.StringIO()
+
+    before_segs = {_seg_key(seg) for seg in before}
+    after_segs = {_seg_key(seg) for seg in after}
+
+    print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}')
+    print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}')
+
+    for seg in before:
+        if _seg_key(seg) not in after_segs:
+            _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks'])
+
+    for seg in after:
+        if _seg_key(seg) not in before_segs:
+            _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks'])
+
+    return format_flamegraph(f.getvalue())
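+
+# A hedged usage sketch: these helpers consume the snapshot dictionaries produced
+# by torch.cuda.memory._snapshot() (a private API) and return flamegraph SVG text.
+#
+#     snapshot = torch.cuda.memory._snapshot()
+#     with open("memory_flamegraph.svg", "w") as f:
+#         f.write(memory(snapshot))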
+
+def _format_size(num):
+    # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
+    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
+        if abs(num) < 1024.0:
+            return f"{num:3.1f}{unit}B"
+        num /= 1024.0
+    return f"{num:.1f}YiB"
+
+class Bytes:
+    def __init__(self, value):
+        self.value = value
+
+    def __add__(self, rhs):
+        return Bytes(self.value + rhs)
+
+    def __repr__(self):
+        return _format_size(self.value)
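+
+# A couple of illustrative values for the size-formatting helpers above:
+#
+#     _format_size(1536)        -> '1.5KiB'
+#     repr(Bytes(2048) + 1024)  -> '3.0KiB'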
+
+def calc_active(seg):
+    return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')
+
+def _report_free(free_external, free_internal):
+    total = free_external + free_internal
+    suffix = ''
+    if total != 0:
+        pct = (free_internal / total) * 100
+        suffix = f' ({pct:.1f}% internal)'
+    return f'{Bytes(total)}{suffix}'
+
+PAGE_SIZE = 1024 * 1024 * 20
+legend = f"""\
+
+Legend:
+    [a     ] - a segment in the allocator
+     ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
+    a-z: pages filled with a single block's content
+    ' ': page is completely free
+    *: page is completely full with multiple blocks
+    0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
+    (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
+"""
+
+def segsum(data):
+    r"""Visually reports how the allocator has filled its segments.
+
+    This printout can help debug fragmentation issues since free fragments
+    will appear as gaps in this printout.  The amount of free space is reported
+    for each segment.
+    We distinguish between internal free memory, which occurs because the
+    allocator rounds the allocation size, and external free memory, which is
+    the gaps between allocations in a segment.
+    Args:
+        data: snapshot dictionary created from _snapshot()
+    """
+    segments = []
+    out = io.StringIO()
+    out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
+    total_reserved = 0
+    total_allocated = 0
+    free_external = 0
+    free_internal = 0
+    for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
+        total_reserved += seg['total_size']
+
+        seg_free_external = 0
+        seg_free_internal = 0
+        seg_allocated = 0
+        all_ranges = []
+        boffset = 0
+        for b in seg['blocks']:
+            active = b['state'] == 'active_allocated'
+            if active:
+                _, allocated_size = _block_extra(b)
+                all_ranges.append((boffset, allocated_size, True))
+                seg_allocated += allocated_size
+                seg_free_internal += b['size'] - allocated_size
+            else:
+                seg_free_external += b['size']
+
+            boffset += b['size']
+
+        total_allocated += seg_allocated
+        free_external += seg_free_external
+        free_internal += seg_free_internal
+
+        nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
+        occupied = [' ' for _ in range(nseg)]
+        frac = [0.0 for _ in range(nseg)]
+        active_size = 0
+        for i, (start_, size, active) in enumerate(all_ranges):
+            active_size += size
+            finish_ = (start_ + size)
+            start = start_ // PAGE_SIZE
+            finish = (finish_ - 1) // PAGE_SIZE + 1
+            m = chr(ord('a' if active else 'A') + (i % 26))
+            for j in range(start, finish):
+                s = max(start_, j * PAGE_SIZE)
+                e = min(finish_, (j + 1) * PAGE_SIZE)
+                frac[j] += (e - s) / PAGE_SIZE
+                if occupied[j] != ' ':
+                    occupied[j] = '0123456789*'[int(frac[j] * 10)]
+                else:
+                    occupied[j] = m
+        body = ''.join(occupied)
+        assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
+        stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
+        if seg['total_size'] >= PAGE_SIZE:
+            out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
+                      f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
+    out.write(f'segments: {len(data["segments"])}\n')
+    out.write(f'total_reserved: {Bytes(total_reserved)}\n')
+    out.write(f'total_allocated: {Bytes(total_allocated)}\n')
+    out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
+    out.write(legend)
+    assert free_internal + free_external + total_allocated == total_reserved
+    return out.getvalue()
+
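+# Illustrative usage sketch, not part of the original module: feeding segsum()
+# with a live snapshot.  Assumes a CUDA device and that memory history
+# recording was enabled before the allocations of interest.
+def _example_segsum_usage():
+    import torch
+    torch.cuda.memory._record_memory_history()
+    x = torch.randn(1024, 1024, device="cuda")    # some workload allocations
+    y = x @ x
+    snapshot = torch.cuda.memory._snapshot()      # same dict the CLI unpickles
+    print(segsum(snapshot))                       # per-segment occupancy map
+    return y
+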
+def trace(data):
+    out = io.StringIO()
+
+    def format(entries):
+        segment_intervals : list = []
+        segment_addr_to_name = {}
+        allocation_addr_to_name = {}
+
+        free_names : list = []
+        next_name = 0
+
+        def _name():
+            nonlocal next_name
+            if free_names:
+                return free_names.pop()
+            r, m = next_name // 26, next_name % 26
+            next_name += 1
+            return f'{chr(ord("a") + m)}{"" if r == 0 else r}'
+
+        def find_segment(addr):
+            for name, saddr, size in segment_intervals:
+                if addr >= saddr and addr < saddr + size:
+                    return name, saddr
+            for i, seg in enumerate(data['segments']):
+                saddr = seg['address']
+                size = seg['allocated_size']
+                if addr >= saddr and addr < saddr + size:
+                    return f'seg_{i}', saddr
+            return None, None
+        count = 0
+        out.write(f'{len(entries)} entries\n')
+
+
+        total_reserved = 0
+        for seg in data['segments']:
+            total_reserved += seg['total_size']
+
+        for count, e in enumerate(entries):
+            if e['action'] == 'alloc':
+                addr, size = e['addr'], e['size']
+                n = _name()
+                seg_name, seg_addr = find_segment(addr)
+                if seg_name is None:
+                    seg_name = "MEM"
+                    offset = addr
+                else:
+                    offset = addr - seg_addr
+                out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
+                allocation_addr_to_name[addr] = (n, size, count)
+                count += size
+            elif e['action'] == 'free_requested':
+                addr, size = e['addr'], e['size']
+                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
+                out.write(f'del {name} # {Bytes(size)}\n')
+            elif e['action'] == 'free_completed':
+                addr, size = e['addr'], e['size']
+                count -= size
+                name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
+                out.write(f'# free completed for {name} {Bytes(size)}\n')
+                if name in allocation_addr_to_name:
+                    free_names.append(name)
+                    del allocation_addr_to_name[name]
+            elif e['action'] == 'segment_alloc':
+                addr, size = e['addr'], e['size']
+                name = _name()
+                out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
+                segment_intervals.append((name, addr, size))
+                segment_addr_to_name[addr] = name
+            elif e['action'] == 'segment_free':
+                addr, size = e['addr'], e['size']
+                name = segment_addr_to_name.get(addr, addr)
+                out.write(f'cudaFree({name}) # {Bytes(size)}\n')
+                if name in segment_addr_to_name:
+                    free_names.append(name)
+                    del segment_addr_to_name[name]
+            elif e['action'] == 'oom':
+                size = e['size']
+                free = e['device_free']
+                out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
+            else:
+                out.write(f'{e}\n')
+        out.write(f"TOTAL MEM: {Bytes(count)}")
+    for i, d in enumerate(data['device_traces']):
+        if d:
+            out.write(f'Device {i} ----------------\n')
+            format(d)
+    return out.getvalue()
+
+
+# Minimal HTML shell for the interactive viewer; the $VIZ_KIND and $SNAPSHOT
+# placeholders are substituted in _format_viz below.  The CDN path to
+# MemoryViz.js follows the upstream torch/cuda/_memory_viz.py template.
+_memory_viz_template = r"""
+<!DOCTYPE html>
+<html>
+<head>
+</head>
+<body>
+<script type="module">
+import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
+const local_files = $SNAPSHOT
+add_local_files(local_files, $VIZ_KIND)
+</script>
+</body>
+</html>
+"""
+
+def _format_viz(data, viz_kind, device):
+    if device is not None:
+        warnings.warn('device argument is deprecated, plots now contain all devices')
+    buffer = pickle.dumps(data)
+    buffer += b'\x00' * (3 - len(buffer) % 3)
+    # Encode the buffer with base64
+    encoded_buffer = base64.b64encode(buffer).decode('utf-8')
+
+    json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}])
+    return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \
+                               .replace('$SNAPSHOT', json_format)
+
+def trace_plot(data, device=None, plot_segments=False):
+    """Generate a visualization over time of the memory usage recorded by the trace as an html file.
+
+    Args:
+        data: Memory snapshot as generated from torch.cuda.memory._snapshot()
+        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
+        plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
+                                        Defaults to False.
+
+    Returns:
+        str: HTML of visualization
+    """
+    return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device)
+
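+# Illustrative usage sketch, not part of the original module: writing the
+# timeline produced by trace_plot() to an HTML file.  Assumes memory history
+# recording is on so the snapshot contains device traces.
+def _example_trace_plot_usage(path="memory_timeline.html"):
+    import torch
+    torch.cuda.memory._record_memory_history()
+    torch.randn(4096, 4096, device="cuda").sum().item()   # produce some traffic
+    snapshot = torch.cuda.memory._snapshot()
+    with open(path, "w") as f:
+        f.write(trace_plot(snapshot))
+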
+
+def _profile_to_snapshot(profile):
+    import torch
+    from torch.profiler._memory_profiler import Action, TensorKey
+    from torch._C._profiler import _EventType
+    memory_profile = profile._memory_profile()
+
+    allocation_stacks = {}
+    for event in memory_profile._op_tree.sorted_nodes:
+        if event.tag == _EventType.Allocation:
+            parent = event.parent
+            python_parents = []
+            while parent:
+                if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
+                    python_parents.append(parent)
+                parent = parent.parent
+            key = TensorKey.from_allocation(event.extra_fields)
+
+            # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
+            #              key will be None. I should add some way to identify these, I just haven't yet.
+            if key and event.extra_fields.alloc_size > 0:
+                allocation_stacks[key] = python_parents
+
+
+    device_count = torch.cuda.device_count()
+    snapshot = {
+        'device_traces': [[] for _ in range(device_count + 1)],
+        'segments': [{'device': device,
+                      'address': None,
+                      'total_size': 0,
+                      'stream': 0,
+                      'blocks': []} for device in range(device_count + 1)]
+    }
+
+    def to_device(device):
+        if device.type == 'cuda':
+            return device.index
+        else:
+            return device_count
+
+    def allocate(size, tensor_key, version, during_trace=True):
+        device = to_device(tensor_key.device)
+        addr = tensor_key.storage.ptr
+
+        seg = snapshot['segments'][device]  # type: ignore[index]
+        if seg['address'] is None or seg['address'] > addr:
+            seg['address'] = addr
+        seg['total_size'] = max(seg['total_size'], addr + size)  # record max addr for now, we will make it the size later
+        category = memory_profile._categories.get(tensor_key, version)
+        category = category.name.lower() if category is not None else "unknown"
+        stack = allocation_stacks.get(tensor_key, ())
+        stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
+        r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
+        if during_trace:
+            snapshot['device_traces'][device].append(r)  # type: ignore[index]
+        return r
+
+    def free(alloc, device):
+        for e in ('free_requested', 'free_completed'):
+            snapshot['device_traces'][device].append({'action': e,  # type: ignore[index]
+                                                      'addr': alloc['addr'],
+                                                      'size': alloc['size'],
+                                                      'stream': 0,
+                                                      'frames': alloc['frames']})
+
+    kv_to_elem = {}
+
+
+
+    # create the device trace
+    for time, action, (tensor_key, version), size in memory_profile.timeline:
+        if not isinstance(tensor_key, TensorKey):
+            continue
+        if action == Action.CREATE:
+            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
+        elif action == Action.DESTROY:
+            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
+        elif action == Action.INCREMENT_VERSION:
+            free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
+            kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
+        elif action == Action.PREEXISTING:
+            kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)
+
+
+    # create the final snapshot state
+    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
+                     for (tensor_key, version), event in kv_to_elem.items()]
+    for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
+        seg = snapshot['segments'][device]  # type: ignore[index]
+        last_addr = seg['address']
+        for _, addr, size, frames in blocks:
+            if last_addr < addr:
+                seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
+            seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
+            last_addr = addr + size
+        if last_addr < seg['total_size']:
+            seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
+
+    snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]  # type: ignore[attr-defined]
+    for seg in snapshot['segments']:  # type: ignore[attr-defined, name-defined, no-redef]
+        seg['total_size'] -= seg['address']
+        if not seg['blocks']:
+            seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})
+
+    return snapshot
+
+def profile_plot(profile, device=None):
+    """Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.
+
+    Args:
+        profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
+        device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
+
+    Returns:
+        str: HTML of visualization
+    """
+    snapshot = _profile_to_snapshot(profile)
+    return _format_viz(snapshot, 'Active Memory Timeline', device)
+
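+# Illustrative usage sketch, not part of the original module: building the HTML
+# view from a kineto memory profile.  ``model`` and ``inputs`` are placeholders
+# supplied by the caller; a CUDA device is assumed.
+def _example_profile_plot_usage(model, inputs, path="profile_timeline.html"):
+    from torch.profiler import profile
+    with profile(profile_memory=True, record_shapes=True, with_stack=True) as prof:
+        model(inputs)
+    with open(path, "w") as f:
+        f.write(profile_plot(prof))
+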
+
+def segment_plot(data: Any, device=None):
+    return _format_viz(data, 'Allocator State History', device)
+
+if __name__ == "__main__":
+    import os.path
+    thedir = os.path.realpath(os.path.dirname(__file__))
+    if thedir in sys.path:
+        # otherwise we find cuda/random.py as random...
+        sys.path.remove(thedir)
+    import argparse
+
+    fn_name = 'torch.cuda.memory._snapshot()'
+    pickled = f'pickled memory statistics from {fn_name}'
+    parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')
+
+    subparsers = parser.add_subparsers(dest='action')
+
+    def _output(p):
+        p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')
+
+    description = 'Prints overall allocation statistics and a visualization of how the allocator\'s segments are currently filled.'
+    stats_a = subparsers.add_parser('stats', description=description)
+    stats_a.add_argument('input', help=pickled)
+
+    description = 'Prints the buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
+    trace_a = subparsers.add_parser('trace', description=description)
+    trace_a.add_argument('input', help=pickled)
+
+    description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
+    segments_a = subparsers.add_parser('segments', description=description)
+    segments_a.add_argument('input', help=pickled)
+    _output(segments_a)
+
+    description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
+    memory_a = subparsers.add_parser('memory', description=description)
+    memory_a.add_argument('input', help=pickled)
+    _output(memory_a)
+
+    description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
+        'or removed between two different memory snapshots.'
+    compare_a = subparsers.add_parser('compare', description=description)
+    compare_a.add_argument('before', help=pickled)
+    compare_a.add_argument('after', help=pickled)
+    _output(compare_a)
+
+    plots = (
+        ("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
+        ("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
+    )
+    for cmd, description in plots:
+        trace_plot_a = subparsers.add_parser(cmd, description=description)
+        trace_plot_a.add_argument('input', help=pickled)
+        help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
+        trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
+        help = 'path to save the visualization (default: output.html)'
+        trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
+        if cmd == "trace_plot":
+            help = 'visualize change to segments rather than individual allocations'
+            trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)
+
+
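+    # Typical invocations (the snapshot path is a placeholder):
+    #   python -m torch.cuda._memory_viz stats snapshot.pickle
+    #   python -m torch.cuda._memory_viz trace snapshot.pickle
+    #   python -m torch.cuda._memory_viz memory snapshot.pickle -o memory.svg
+    #   python -m torch.cuda._memory_viz trace_plot snapshot.pickle -o timeline.html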
+    args = parser.parse_args()
+
+    def _read(name):
+        if name == '-':
+            f = sys.stdin.buffer
+        else:
+            f = open(name, 'rb')
+        data = pickle.load(f)
+        if isinstance(data, list):  # segments only...
+            data = {'segments': data, 'traces': []}
+        return data
+
+    def _write(name, data):
+        with open(name, 'w') as f:
+            f.write(data)
+
+    if args.action == 'segments':
+        data = _read(args.input)
+        _write(args.output, segments(data))
+    elif args.action == 'memory':
+        data = _read(args.input)
+        _write(args.output, memory(data))
+    elif args.action == 'stats':
+        data = _read(args.input)
+        print(segsum(data))
+    elif args.action == 'trace':
+        data = _read(args.input)
+        print(trace(data))
+    elif args.action == 'compare':
+        before = _read(args.before)
+        after = _read(args.after)
+        _write(args.output, compare(before, after))
+    elif args.action == 'trace_plot':
+        data = _read(args.input)
+        _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
+    elif args.action == 'segment_plot':
+        data = _read(args.input)
+        _write(args.output, segment_plot(data, device=args.device))
diff --git a/MLPY/Lib/site-packages/torch/cuda/_sanitizer.py b/MLPY/Lib/site-packages/torch/cuda/_sanitizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e01f8a87a43f1643c8bcfe2577e8977e7fb4e3e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/_sanitizer.py
@@ -0,0 +1,622 @@
+r"""
+This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.
+
+It stores information on accesses to tensors to determine if they are synchronized
+or not. When enabled in a Python program and a possible data race is detected, a
+detailed warning will be printed and the program will exit.
+
+It can be enabled either by importing this module and calling
+:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
+environment variable.
+"""
+
+import enum
+import functools
+import inspect
+import io
+import logging
+import sys
+import textwrap
+import traceback
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
+
+import torch
+import torch.utils._cuda_trace as cuda_trace
+from torch.utils import _pytree as pytree
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+DEFAULT_STREAM_ID = 0
+
+TK = TypeVar("TK")
+TVa = TypeVar("TVa")
+TVb = TypeVar("TVb")
+
+DataPtr = int
+StreamId = int
+EventId = int
+SeqNum = int
+
+logger = logging.getLogger(__name__)
+
+
+class AccessType(enum.Enum):
+    READ = enum.auto()
+    WRITE = enum.auto()
+
+    def __str__(self):
+        return "reading from" if self is AccessType.READ else "writing to"
+
+
+@dataclass
+class Access:
+    r"""Stores information about a single access to a tensor by a kernel.
+
+    Args:
+        type: either AccessType.READ or AccessType.WRITE.
+        seq_num: the sequential number of the kernel performing the access.
+        stream: the stream id of the stream executing the kernel.
+        operator: the schema of the launched kernel, which lists the
+            arguments and return type.
+        aliases: the arguments in the schema this access corresponds to.
+        is_output: Whether the tensor was an output of the kernel.
+        stack_trace: the stack summary object captured during access.
+    """
+
+    type: AccessType
+    seq_num: SeqNum
+    stream: StreamId
+    operator: str
+    aliases: List[str]
+    is_output: bool
+    stack_trace: traceback.StackSummary
+
+
+class SynchronizationError(Exception):
+    """Base class for errors detected by CUDA Sanitizer."""
+
+    pass
+
+
+class UnsynchronizedAccessError(SynchronizationError):
+    """Stores information about two unsynchronized accesses to one data pointer."""
+
+    def __init__(
+        self,
+        data_ptr: DataPtr,
+        allocation_stack_trace: Optional[traceback.StackSummary],
+        current_access: Access,
+        previous_access: Access,
+    ):
+        self.data_ptr = data_ptr
+        self.allocation_stack_trace = allocation_stack_trace
+        self.current_access = current_access
+        self.previous_access = previous_access
+
+    def __str__(self):
+        def format_access(access: Access):
+            message.write(f"{access.operator}\n{access.type}")
+            if access.aliases:
+                message.write(" argument(s) " + ", ".join(access.aliases))
+                if access.is_output:
+                    message.write(", and to")
+            if access.is_output:
+                message.write(" the output")
+            message.write(
+                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
+            )
+
+        with io.StringIO() as message:
+            message.write(
+                textwrap.dedent(
+                    f"""\
+                    ============================
+                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
+                    Access by stream {self.current_access.stream} during kernel:
+                    """
+                )
+            )
+            format_access(self.current_access)
+
+            message.write(
+                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
+            )
+            format_access(self.previous_access)
+
+            if self.allocation_stack_trace:
+                message.write(
+                    "Tensor was allocated with stack trace:\n"
+                    f"{''.join(self.allocation_stack_trace.format())}"
+                )
+            else:
+                message.write("Trace for tensor allocation not found.")
+            return message.getvalue()
+
+
+class CUDASanitizerErrors(Exception):
+    """Wrapper class for errors reported by CUDA Sanitizer."""
+
+    def __init__(self, errors: List[SynchronizationError]):
+        self.errors = errors
+
+    def __str__(self):
+        return f"detected {len(self.errors)} errors"
+
+
+@dataclass
+class TensorInfo:
+    r"""Stores information about a single tensor and recent accesses to it.
+
+    Args:
+        allocation_stack_trace: the stack summary object captured during tensor
+            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
+        reads: list of read accesses to the tensor that were performed since
+            the last write.
+        write: the last write access to the tensor.
+    """
+
+    allocation_stack_trace: Optional[traceback.StackSummary]
+    reads: List[Access] = field(default_factory=list)
+    write: Optional[Access] = None
+
+
+class _TensorsAccessed:
+    def __init__(self):
+        self.accesses: Dict[DataPtr, TensorInfo] = {}
+
+    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
+        if data_ptr not in self.accesses:
+            logger.info(
+                "Found tensor with pointer: %s, but no matching tensor "
+                "allocation in the trace. Backfilling the trace now. "
+                "Perhaps the sanitizer was enabled after some torch operations?",
+                data_ptr,
+            )
+            self.create_tensor(data_ptr, None)
+
+    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
+        if data_ptr in self.accesses:
+            logger.info(
+                "Found duplicate tensor allocation in the trace for tensor with "
+                "pointer: %s. Assuming the trace for tensor deallocation "
+                "wasn't caught and backfilling it now. "
+                "Perhaps the sanitizer was enabled after some torch operations?",
+                data_ptr,
+            )
+            self.delete_tensor(data_ptr)
+
+    def create_tensor(
+        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
+    ) -> None:
+        self.accesses[data_ptr] = TensorInfo(stack_trace)
+
+    def delete_tensor(self, data_ptr: DataPtr) -> None:
+        del self.accesses[data_ptr]
+
+    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
+        return bool(self.accesses[data_ptr].reads)
+
+    def get_allocation_stack_trace(
+        self, data_ptr: DataPtr
+    ) -> Optional[traceback.StackSummary]:
+        return self.accesses[data_ptr].allocation_stack_trace
+
+    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
+        return self.accesses[data_ptr].write
+
+    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
+        return self.accesses[data_ptr].reads
+
+    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
+        self.accesses[data_ptr].reads.append(access)
+
+    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
+        self.accesses[data_ptr].write = access
+        self.accesses[data_ptr].reads = []
+
+
+class StreamSynchronizations:
+    def __init__(self):
+        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
+        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
+        self.host_sync_state: Dict[StreamId, SeqNum] = {}
+        self.create_stream(DEFAULT_STREAM_ID)
+
+    def _ensure_stream_exists(self, stream: StreamId) -> None:
+        if stream not in self.current_sync_states:
+            logger.info(
+                "Found Stream with id: %s, but no matching stream "
+                "creation in the trace. Backfilling the trace now. "
+                "Perhaps the sanitizer was enabled after some torch operations?",
+                stream,
+            )
+            self.create_stream(stream)
+
+    def _ensure_event_exists(self, event: EventId) -> None:
+        if event not in self.recorded_sync_states:
+            logger.info(
+                "Found Event with id: %s, but no matching event "
+                "creation in the trace. Backfilling the trace now. "
+                "Perhaps the sanitizer was enabled after some torch operations?",
+                event,
+            )
+            self.create_event(event)
+
+    def _ensure_event_does_not_exist(self, event: EventId) -> None:
+        if event in self.recorded_sync_states:
+            logger.info(
+                "Found duplicate event creation in the trace for event with "
+                "id: %s. Assuming the trace for event deletion wasn't caught "
+                "and backfilling it now. "
+                "Perhaps the sanitizer was enabled after some torch operations?",
+                event,
+            )
+            self.delete_event(event)
+
+    def create_stream(self, stream: StreamId) -> None:
+        if stream in self.current_sync_states:
+            logger.info(
+                "Found duplicate Stream creation in the trace for Stream with "
+                "id: %s. PyTorch Streams are only created once, so this "
+                "trace entry is ignored.",
+                stream,
+            )
+        else:
+            self.host_sync_state[stream] = 0
+            self.current_sync_states[stream] = self.host_sync_state.copy()
+
+    def create_event(self, event: EventId) -> None:
+        self._ensure_event_does_not_exist(event)
+        self.recorded_sync_states[event] = {}
+
+    def delete_event(self, event: EventId) -> None:
+        self._ensure_event_exists(event)
+        del self.recorded_sync_states[event]
+
+    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
+        self._ensure_stream_exists(stream)
+        self.current_sync_states[stream][stream] = seq_num
+
+    def record_state(self, event: EventId, stream: StreamId) -> None:
+        self._ensure_event_exists(event)
+        self._ensure_stream_exists(stream)
+        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()
+
+    def _state_wait_for_other(
+        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
+    ) -> None:
+        for stream, seq_num in other.items():
+            state[stream] = max(state.get(stream, -1), seq_num)
+
+    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
+        self._ensure_stream_exists(stream)
+        self._ensure_event_exists(event)
+        self._state_wait_for_other(
+            self.current_sync_states[stream], self.recorded_sync_states[event]
+        )
+
+    def all_streams_wait_for_event(self, event: EventId) -> None:
+        self._ensure_event_exists(event)
+        for stream in self.current_sync_states.keys():
+            self.stream_wait_for_event(stream, event)
+
+        self._state_wait_for_other(
+            self.host_sync_state, self.recorded_sync_states[event]
+        )
+
+    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
+        self._ensure_stream_exists(stream)
+        for state in self.current_sync_states.values():
+            self._state_wait_for_other(state, self.current_sync_states[stream])
+
+        self._state_wait_for_other(
+            self.host_sync_state, self.current_sync_states[stream]
+        )
+
+    def sync_all_streams(self) -> None:
+        for stream, state in self.current_sync_states.items():
+            self.host_sync_state[stream] = state[stream]
+
+        for state in self.current_sync_states.values():
+            self._state_wait_for_other(state, self.host_sync_state)
+
+    def is_ordered_after(
+        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
+    ) -> bool:
+        self._ensure_stream_exists(current_stream)
+        self._ensure_stream_exists(other_stream)
+        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
+
+
+class EventHandler:
+    """Analyzes CSAN trace for synchronization errors.
+
+    Stores information on each stream's synchronizations with other streams as well
+    as tensor accesses to determine whether a given kernel launch might cause a
+    data race.
+    """
+
+    def __init__(self):
+        self.tensors_accessed = _TensorsAccessed()
+        self.syncs = StreamSynchronizations()
+        self.seq_num: SeqNum = 0
+
+    def _handle_kernel_launch(
+        self,
+        stream: StreamId,
+        read_only: Set[DataPtr],
+        read_write: Set[DataPtr],
+        outputs: Set[DataPtr],
+        operator: str,
+        tensor_aliases: Dict[int, List[str]],
+    ) -> List[SynchronizationError]:
+        def check_conflict(
+            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
+        ) -> None:
+            if previous_access is None:
+                return
+            if not self.syncs.is_ordered_after(
+                current_access.stream, previous_access.seq_num, previous_access.stream
+            ):
+                error_list.append(
+                    UnsynchronizedAccessError(
+                        data_ptr,
+                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
+                        current_access,
+                        previous_access,
+                    )
+                )
+
+        error_list: List[SynchronizationError] = []
+        self.seq_num += 1
+        self.syncs.update_seq_num(stream, self.seq_num)
+        stack_trace = traceback.StackSummary.extract(
+            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
+        )
+        # The stack trace generated in this way is in the inverse order, so it must be
+        # reversed.
+        stack_trace.reverse()
+
+        for data_ptr in read_only:
+            self.tensors_accessed.ensure_tensor_exists(data_ptr)
+            current_access = Access(
+                AccessType.READ,
+                self.seq_num,
+                stream,
+                operator,
+                tensor_aliases[data_ptr],
+                data_ptr in outputs,
+                stack_trace,
+            )
+            check_conflict(
+                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
+            )
+            self.tensors_accessed.add_read(data_ptr, current_access)
+
+        for data_ptr in read_write:
+            self.tensors_accessed.ensure_tensor_exists(data_ptr)
+            current_access = Access(
+                AccessType.WRITE,
+                self.seq_num,
+                stream,
+                operator,
+                tensor_aliases[data_ptr],
+                data_ptr in outputs,
+                stack_trace,
+            )
+            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
+                for previous_access in self.tensors_accessed.get_reads(data_ptr):
+                    check_conflict(data_ptr, current_access, previous_access)
+            else:
+                check_conflict(
+                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
+                )
+            self.tensors_accessed.set_write(data_ptr, current_access)
+
+        return error_list
+
+    def _handle_event_creation(self, event: EventId) -> None:
+        self.syncs.create_event(event)
+
+    def _handle_event_deletion(self, event: EventId) -> None:
+        self.syncs.delete_event(event)
+
+    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
+        self.syncs.record_state(event, stream)
+
+    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
+        self.syncs.stream_wait_for_event(stream, event)
+
+    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
+        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
+        stack_trace = traceback.StackSummary.extract(
+            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
+        )
+        # The stack trace generated in this way is in the inverse order, so it must be
+        # reversed.
+        stack_trace.reverse()
+        self.tensors_accessed.create_tensor(
+            data_ptr,
+            stack_trace,
+        )
+
+    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
+        self.tensors_accessed.ensure_tensor_exists(data_ptr)
+        self.tensors_accessed.delete_tensor(data_ptr)
+
+    def _handle_stream_creation(self, stream: StreamId) -> None:
+        self.syncs.create_stream(stream)
+
+    def _handle_device_synchronization(self) -> None:
+        self.syncs.sync_all_streams()
+
+    def _handle_stream_synchronization(self, stream: StreamId) -> None:
+        self.syncs.all_streams_wait_for_stream(stream)
+
+    def _handle_event_synchronization(self, event: EventId) -> None:
+        self.syncs.all_streams_wait_for_event(event)
+
+
+def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
+    for arg, value in a.items():
+        if arg in b:
+            yield arg, value, b[arg]
+
+
+def zip_arguments(
+    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+) -> Iterator[Tuple[torch.Argument, Any]]:
+    schema_args = schema.arguments[: len(args)]
+    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}
+
+    yield from zip(schema_args, args)
+
+    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
+        yield (argument, value)
+
+
+class ArgumentHandler:
+    def __init__(self):
+        self.dataptrs_read: Set[DataPtr] = set()
+        self.dataptrs_written: Set[DataPtr] = set()
+        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
+        self.outputs: Set[DataPtr] = set()
+
+    def _handle_argument(
+        self,
+        value: Any,
+        is_write: bool,
+        name: Optional[str] = None,
+        is_output: bool = False,
+    ) -> None:
+        if isinstance(value, torch.Tensor) and value.is_cuda:
+            data_ptr = value.data_ptr()
+            if is_write:
+                self.dataptrs_written.add(data_ptr)
+            else:
+                self.dataptrs_read.add(data_ptr)
+
+            self.tensor_aliases.setdefault(data_ptr, [])
+            if name is not None:
+                self.tensor_aliases[data_ptr].append(name)
+            if is_output:
+                self.outputs.add(data_ptr)
+
+    def parse_inputs(
+        self,
+        schema: torch.FunctionSchema,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> None:
+        for argument, value in zip_arguments(schema, args, kwargs):
+            is_write = argument.alias_info is not None and argument.alias_info.is_write
+            pytree.tree_map_(
+                functools.partial(
+                    self._handle_argument, is_write=is_write, name=argument.name
+                ),
+                value,
+            )
+
+    def parse_outputs(self, outputs: Any) -> None:
+        pytree.tree_map_(
+            functools.partial(self._handle_argument, is_write=True, is_output=True),
+            outputs,
+        )
+
+
+class CUDASanitizerDispatchMode(TorchDispatchMode):
+    def __init__(self):
+        self.event_handler = EventHandler()
+        torch._C._activate_cuda_trace()
+        cuda_trace.register_callback_for_cuda_event_creation(
+            self.event_handler._handle_event_creation
+        )
+        cuda_trace.register_callback_for_cuda_event_deletion(
+            self.event_handler._handle_event_deletion
+        )
+        cuda_trace.register_callback_for_cuda_event_record(
+            self.event_handler._handle_event_record
+        )
+        cuda_trace.register_callback_for_cuda_event_wait(
+            self.event_handler._handle_event_wait
+        )
+        cuda_trace.register_callback_for_cuda_memory_allocation(
+            self.event_handler._handle_memory_allocation
+        )
+        cuda_trace.register_callback_for_cuda_memory_deallocation(
+            self.event_handler._handle_memory_deallocation
+        )
+        cuda_trace.register_callback_for_cuda_stream_creation(
+            self.event_handler._handle_stream_creation
+        )
+        cuda_trace.register_callback_for_cuda_device_synchronization(
+            self.event_handler._handle_device_synchronization
+        )
+        cuda_trace.register_callback_for_cuda_stream_synchronization(
+            self.event_handler._handle_stream_synchronization
+        )
+        cuda_trace.register_callback_for_cuda_event_synchronization(
+            self.event_handler._handle_event_synchronization
+        )
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        argument_handler = ArgumentHandler()
+        argument_handler.parse_inputs(func._schema, args, kwargs)
+
+        outputs = func(*args, **kwargs)
+
+        argument_handler.parse_outputs(outputs)
+        errors = self.event_handler._handle_kernel_launch(
+            torch.cuda.current_stream().cuda_stream,
+            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
+            argument_handler.dataptrs_written,
+            argument_handler.outputs,
+            func._schema,
+            argument_handler.tensor_aliases,
+        )
+        if errors:
+            for error in errors:
+                print(error, file=sys.stderr)
+            raise CUDASanitizerErrors(errors)
+
+        return outputs
+
+
+class CUDASanitizer:
+    """Manages the lifetime of a CUDASanitizer dispatch mode object.
+
+    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
+    context manager in the enable function/destructor, respectively. This is to
+    explicitly set the lifetime of the dispatch mode object to that of the application.
+    This approach was deemed more elegant than using the atexit module.
+    """
+
+    def __init__(self):
+        self.dispatch = CUDASanitizerDispatchMode()
+        self.enabled = False
+
+    def enable(self):
+        self.dispatch.__enter__()
+        self.enabled = True
+
+    def __del__(self):
+        if self.enabled:
+            self.dispatch.__exit__(None, None, None)
+
+
+def enable_cuda_sanitizer():
+    """Enable CUDA Sanitizer.
+
+    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
+    for synchronization errors. All data races found will be printed to the standard
+    error output along with stack traces of suspected causes. For best results, the
+    sanitizer should be enabled at the very beginning of the program.
+    """
+    cuda_sanitizer.enable()
+
+
+cuda_sanitizer = CUDASanitizer()
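+
+
+# Minimal usage sketch, not part of the original module: with the sanitizer
+# enabled, an unsynchronized second access to a tensor from a different stream
+# is reported and raises CUDASanitizerErrors.  Assumes a CUDA device.
+def _example_enable_cuda_sanitizer():
+    enable_cuda_sanitizer()
+    a = torch.rand(4, 2, device="cuda")      # written on the default stream
+    with torch.cuda.stream(torch.cuda.Stream()):
+        # Unsynchronized write on a side stream: CSAN prints a detailed report
+        # and raises CUDASanitizerErrors at this call.
+        torch.mul(a, 5, out=a)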
diff --git a/MLPY/Lib/site-packages/torch/cuda/_utils.py b/MLPY/Lib/site-packages/torch/cuda/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a745084612a099771abb4d587b40c1b8b447b2e2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/_utils.py
@@ -0,0 +1,38 @@
+from typing import Any
+
+import torch
+
+# The _get_device_index has been moved to torch.utils._get_device_index
+from torch._utils import _get_device_index as _torch_get_device_index
+
+
+def _get_device_index(
+    device: Any, optional: bool = False, allow_cpu: bool = False
+) -> int:
+    r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
+
+    If :attr:`device` is a torch.device object, returns the device index if it
+    is a CUDA device. Note that for a CUDA device without a specified index,
+    i.e., ``torch.device('cuda')``, this will return the current default CUDA
+    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
+    CPU devices will be accepted and ``-1`` will be returned in this case.
+
+    If :attr:`device` is a Python integer, it is returned as is.
+
+    If :attr:`device` is ``None``, this will return the current default CUDA
+    device if :attr:`optional` is ``True``.
+    """
+    if isinstance(device, int):
+        return device
+    if isinstance(device, str):
+        device = torch.device(device)
+    if isinstance(device, torch.device):
+        if allow_cpu:
+            if device.type not in ["cuda", "cpu"]:
+                raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
+        elif device.type != "cuda":
+            raise ValueError(f"Expected a cuda device, but got: {device}")
+    if not torch.jit.is_scripting():
+        if isinstance(device, torch.cuda.device):
+            return device.idx
+    return _torch_get_device_index(device, optional, allow_cpu)
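+
+
+# Minimal usage sketch, not part of the original module: the helper normalizes
+# several device spellings to a plain index.  The ``optional=True`` call assumes
+# an initialized CUDA context.
+def _example_get_device_index():
+    assert _get_device_index(1) == 1
+    assert _get_device_index("cuda:2") == 2
+    assert _get_device_index(torch.device("cpu"), allow_cpu=True) == -1
+    # With device=None and optional=True the current default CUDA device is used.
+    return _get_device_index(None, optional=True)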
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/__init__.py b/MLPY/Lib/site-packages/torch/cuda/amp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d45cd029f10483aab9f4e849efb7f667a84e12a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/amp/__init__.py
@@ -0,0 +1,11 @@
+from .autocast_mode import autocast, custom_bwd, custom_fwd
+from .common import amp_definitely_not_available
+from .grad_scaler import GradScaler
+
+__all__ = [
+    "amp_definitely_not_available",
+    "autocast",
+    "custom_bwd",
+    "custom_fwd",
+    "GradScaler",
+]
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c26ad5600930d1974e62808cff6007a31c79d0bf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f40c5b4dbe2990d9b6f898c8faef4723adbda2ed
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03f7ebd515964176055ed41f6aa652a5d54bad75
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb598ef5ca81e283f48da43819295fc7a053148a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/autocast_mode.py b/MLPY/Lib/site-packages/torch/cuda/amp/autocast_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2e86429596b43ce6be3a97da4e07e3bdb2221ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/amp/autocast_mode.py
@@ -0,0 +1,144 @@
+import collections
+import functools
+
+import torch
+
+try:
+    import numpy as np
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    np = None  # type: ignore[assignment]
+from typing import Any
+
+__all__ = ["autocast", "custom_fwd", "custom_bwd"]
+
+
+class autocast(torch.amp.autocast_mode.autocast):
+    r"""See :class:`torch.autocast`.
+
+    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
+    """
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        dtype: torch.dtype = torch.float16,
+        cache_enabled: bool = True,
+    ):
+        if torch._jit_internal.is_scripting():
+            self._enabled = enabled
+            self.device = "cuda"
+            self.fast_dtype = dtype
+            return
+        super().__init__(
+            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
+        )
+
+    def __enter__(self):
+        if torch._jit_internal.is_scripting():
+            return self
+        return super().__enter__()
+
+    # TODO: discuss a unified TorchScript-friendly API for autocast
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
+        if torch._jit_internal.is_scripting():
+            return
+        return super().__exit__(exc_type, exc_val, exc_tb)
+
+    def __call__(self, func):
+        if torch._jit_internal.is_scripting():
+            return func
+        return super().__call__(func)
+
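+# Minimal usage sketch, not part of the original module: the usual
+# mixed-precision forward pass under ``torch.cuda.amp.autocast``.  ``model``
+# and ``loss_fn`` are placeholders supplied by the caller.
+def _example_autocast_forward(model, loss_fn, inputs, targets):
+    with autocast():
+        outputs = model(inputs)              # eligible ops run in float16
+        loss = loss_fn(outputs, targets)     # computed under the same autocast state
+    return loss
+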
+
+# Casts Tensors and containers of Tensors.  Special-cases passthroughs for strings and np.ndarrays, which
+# may be falsely detected as "Iterables."
+def _cast(value, dtype):
+    if isinstance(value, torch.Tensor):
+        is_eligible = (
+            value.is_floating_point()
+            and value.is_cuda
+            and (value.dtype is not torch.float64)
+        )
+        return value.to(dtype) if is_eligible else value
+    elif isinstance(value, (str, bytes)):
+        return value
+    elif HAS_NUMPY and isinstance(value, np.ndarray):
+        return value
+    elif isinstance(value, collections.abc.Mapping):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
+    elif isinstance(value, collections.abc.Iterable):
+        iterable = (_cast(v, dtype) for v in value)
+        if isinstance(value, (list, tuple)):
+            return type(value)(iterable)
+        else:
+            return iterable
+    else:
+        return value
+
+
+# custom_fwd is a decorator that may or may not be used with arguments, following
+# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
+# this works:
+#     @custom_fwd
+#     def forward(...):
+# this also works:
+#     @custom_fwd(cast_inputs=torch.float)
+#     def forward(...):
+def custom_fwd(fwd=None, *, cast_inputs=None):
+    """
+    Create a helper decorator for ``forward`` methods of custom autograd functions.
+
+    Autograd functions are subclasses of :class:`torch.autograd.Function`.
+    See the :ref:`example page` for more detail.
+
+    Args:
+        cast_inputs (:class:`torch.dtype` or None, optional, default=None):  If not ``None``,
+            when ``forward`` runs in an autocast-enabled region, casts incoming
+            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
+            then executes ``forward`` with autocast disabled.
+            If ``None``, ``forward``'s internal ops execute with the current autocast state.
+
+    .. note::
+        If the decorated ``forward`` is called outside an autocast-enabled region,
+        :func:`custom_fwd` is a no-op and ``cast_inputs`` has no effect.
+    """
+    if fwd is None:
+        return functools.partial(custom_fwd, cast_inputs=cast_inputs)
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        args[0]._dtype = torch.get_autocast_gpu_dtype()
+        if cast_inputs is None:
+            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
+            return fwd(*args, **kwargs)
+        else:
+            autocast_context = torch.is_autocast_enabled()
+            args[0]._fwd_used_autocast = False
+            if autocast_context:
+                with autocast(enabled=False):
+                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
+            else:
+                return fwd(*args, **kwargs)
+
+    return decorate_fwd
+
+
+# Autograd ensures incoming gradients are the same type as forward outputs.  Allowing a separate
+# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
+# cast_inputs supplied to custom_fwd.
+def custom_bwd(bwd):
+    """Create a helper decorator for backward methods of custom autograd functions.
+
+    Autograd functions are subclasses of :class:`torch.autograd.Function`.
+    Ensures that ``backward`` executes with the same autocast state as ``forward``.
+    See the :ref:`example page` for more detail.
+    """
+
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
+            return bwd(*args, **kwargs)
+
+    return decorate_bwd
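+
+
+# Minimal usage sketch, not part of the original module: a custom autograd
+# Function whose forward is forced to float32 inside autocast regions and whose
+# backward reruns under the same autocast state.
+class _ExampleScaledMM(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(ctx, a, b):
+        ctx.save_for_backward(a, b)
+        return a.mm(b)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad):
+        a, b = ctx.saved_tensors
+        return grad.mm(b.t()), a.t().mm(grad)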
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/common.py b/MLPY/Lib/site-packages/torch/cuda/amp/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f03685281b1b053af152301f6bc5d1981ef32b5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/amp/common.py
@@ -0,0 +1,9 @@
+from importlib.util import find_spec
+
+import torch
+
+__all__ = ["amp_definitely_not_available"]
+
+
+def amp_definitely_not_available():
+    return not (torch.cuda.is_available() or find_spec("torch_xla"))
diff --git a/MLPY/Lib/site-packages/torch/cuda/amp/grad_scaler.py b/MLPY/Lib/site-packages/torch/cuda/amp/grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9cac2ecba674640d92a9a04a81e3740549c911
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/amp/grad_scaler.py
@@ -0,0 +1,28 @@
+import torch
+from torch.amp.grad_scaler import OptState
+
+__all__ = ["GradScaler", "OptState"]
+
+
+class GradScaler(torch.amp.GradScaler):
+    r"""
+    See :class:`torch.amp.GradScaler`.
+    ``torch.cuda.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cuda", args...)``
+    """
+
+    def __init__(
+        self,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
+        super().__init__(
+            "cuda",
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            enabled=enabled,
+        )
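+
+
+# Minimal usage sketch, not part of the original module: one scaled training
+# step.  ``model``, ``optimizer``, ``loss_fn`` and ``scaler`` are placeholders
+# supplied by the caller.
+def _example_scaled_step(model, optimizer, loss_fn, inputs, targets, scaler):
+    optimizer.zero_grad(set_to_none=True)
+    with torch.autocast("cuda"):
+        loss = loss_fn(model(inputs), targets)
+    scaler.scale(loss).backward()   # scale the loss to avoid float16 gradient underflow
+    scaler.step(optimizer)          # unscales gradients and skips the step on inf/nan
+    scaler.update()                 # adjusts the scale factor for the next iteration
+    return loss.detach()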
diff --git a/MLPY/Lib/site-packages/torch/cuda/comm.py b/MLPY/Lib/site-packages/torch/cuda/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c124cbf6f8932cfb9a27cd4276ce7b6c4c7cd6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/comm.py
@@ -0,0 +1,18 @@
+# The functions here have been moved to torch.nn.parallel.comm
+from torch.nn.parallel.comm import (
+    broadcast,
+    broadcast_coalesced,
+    gather,
+    reduce_add,
+    reduce_add_coalesced,
+    scatter,
+)
+
+__all__ = [
+    "broadcast",
+    "broadcast_coalesced",
+    "reduce_add",
+    "reduce_add_coalesced",
+    "scatter",
+    "gather",
+]
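+
+
+# Minimal usage sketch, not part of the original module: broadcast a tensor to
+# two GPUs, do independent per-device work, then gather the results back.
+# Assumes at least two CUDA devices.
+def _example_broadcast_gather():
+    import torch
+    src = torch.randn(4, device="cuda:0")
+    copies = broadcast(src, devices=[0, 1])       # one copy per listed device
+    doubled = [t * 2 for t in copies]             # per-device computation
+    return gather(doubled, dim=0, destination=0)  # concatenated on cuda:0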
diff --git a/MLPY/Lib/site-packages/torch/cuda/error.py b/MLPY/Lib/site-packages/torch/cuda/error.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/cuda/graphs.py b/MLPY/Lib/site-packages/torch/cuda/graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5ead56cf61306ab549623c24a39951e5972d371
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/graphs.py
@@ -0,0 +1,479 @@
+import gc
+from typing import Optional
+
+import torch
+from torch.utils import _pytree
+from .._utils import _dummy_type
+
+if not hasattr(torch._C, "_CudaStreamBase"):
+    # Define dummy base classes
+    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
+    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
+    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
+        "_cuda_isCurrentStreamCapturing"
+    )
+
+from torch._C import (  # noqa: F401
+    _cuda_isCurrentStreamCapturing,
+    _CUDAGraph,
+    _graph_pool_handle,
+)
+
+
+def is_current_stream_capturing():
+    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
+
+    If a CUDA context does not exist on the current device, returns False without initializing the context.
+    """
+    return _cuda_isCurrentStreamCapturing()
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+def graph_pool_handle():
+    r"""Return an opaque token representing the id of a graph memory pool.
+
+    See :ref:`Graph memory management`.
+
+    .. warning::
+        This API is in beta and may change in future releases.
+    """
+    return _graph_pool_handle()
+
+
+# Python shim helps Sphinx process docstrings more reliably.
+class CUDAGraph(torch._C._CUDAGraph):
+    r"""Wrapper around a CUDA graph.
+
+    .. warning::
+        This API is in beta and may change in future releases.
+    """
+
+    def __new__(cls):
+        return super().__new__(cls)
+
+    def capture_begin(self, pool=None, capture_error_mode="global"):
+        r"""Begin capturing CUDA work on the current stream.
+
+        Typically, you shouldn't call ``capture_begin`` yourself.
+        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+        which call ``capture_begin`` internally.
+
+        Arguments:
+            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
+                :meth:`other_Graph_instance.pool()`) that hints this graph may share memory
+                with the indicated pool.  See :ref:`Graph memory management`.
+            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
+                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
+        """  # noqa: B950
+        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
+
+    def capture_end(self):
+        r"""End CUDA graph capture on the current stream.
+
+        After ``capture_end``, ``replay`` may be called on this instance.
+
+        Typically, you shouldn't call ``capture_end`` yourself.
+        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
+        which call ``capture_end`` internally.
+        """
+        super().capture_end()
+
+    def replay(self):
+        r"""Replay the CUDA work captured by this graph."""
+        super().replay()
+
+    def reset(self):
+        r"""Delete the graph currently held by this instance."""
+        super().reset()
+
+    def pool(self):
+        r"""Return an opaque token representing the id of this graph's memory pool.
+
+        This id can optionally be passed to another graph's ``capture_begin``,
+        which hints the other graph may share the same memory pool.
+        """
+        return super().pool()
+
+    def enable_debug_mode(self):
+        r"""Enable debugging mode for CUDAGraph.debug_dump."""
+        return super().enable_debug_mode()
+
+    def debug_dump(self, debug_path):
+        r"""
+        Arguments:
+            debug_path (required): Path to dump the graph to.
+
+        Calls a debugging function to dump the graph if the debugging is
+        enabled via CUDAGraph.enable_debug_mode()
+        """
+        return super().debug_dump(debug_path)
+
+
+class graph:
+    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
+
+    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
+    detailed use, and constraints.
+
+    Arguments:
+        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
+        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
+            :meth:`other_Graph_instance.pool()`) hinting this graph's capture
+            may share memory from the specified pool. See :ref:`Graph memory management`.
+        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
+            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
+        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
+            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
+            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
+            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
+            unless you're familiar with `cudaStreamCaptureMode`_
+
+    .. note::
+        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
+        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
+
+    .. warning::
+        This API is in beta and may change in future releases.
+
+    .. _cudaStreamCaptureMode:
+        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
+    """  # noqa: B950
+
+    default_capture_stream: Optional["torch.cuda.Stream"] = None
+
+    def __init__(
+        self,
+        cuda_graph,
+        pool=None,
+        stream=None,
+        capture_error_mode: str = "global",
+    ):
+        # Lazy-init of default_capture_stream helps avoid circular-import errors.
+        # Not thread safe, but graphs already have the general (explicitly documented)
+        # restriction that only one capture may be underway at a time in the process.
+        if self.__class__.default_capture_stream is None:
+            self.__class__.default_capture_stream = torch.cuda.Stream()
+
+        self.pool = () if pool is None else (pool,)
+        self.capture_stream = (
+            stream if stream is not None else self.__class__.default_capture_stream
+        )
+        assert self.capture_stream is not None
+        self.stream_ctx = torch.cuda.stream(self.capture_stream)
+        self.cuda_graph = cuda_graph
+        self.capture_error_mode = capture_error_mode
+
+    def __enter__(self):
+        # Free as much memory as we can for the graph
+        torch.cuda.synchronize()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Stackoverflow seems comfortable with this pattern
+        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
+        self.stream_ctx.__enter__()
+
+        self.cuda_graph.capture_begin(
+            *self.pool, capture_error_mode=self.capture_error_mode
+        )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.cuda_graph.capture_end()
+        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
+
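+
+# Illustrative sketch (not part of the original module): a minimal capture/replay
+# round trip using the ``graph`` context manager above. It assumes a CUDA-capable
+# device; the shapes and values are arbitrary.
+def _example_graph_capture():  # pragma: no cover - documentation sketch
+    static_input = torch.zeros(8, device="cuda")
+    g = CUDAGraph()
+    with graph(g):
+        # Work launched here is captured, not executed eagerly.
+        static_output = static_input * 2
+    # Copy real data into the placeholder input, then replay the captured work.
+    static_input.copy_(torch.ones(8, device="cuda"))
+    g.replay()
+    return static_output  # now holds the doubled values
+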
+
+def make_graphed_callables(
+    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
+):
+    r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions.
+
+    Each graphed callable's forward pass runs its source callable's
+    forward CUDA work as a CUDA graph inside a single autograd node.
+
+    The graphed callable's forward pass also appends
+    a backward node to the autograd graph. During backward, this node runs the
+    callable's backward work as a CUDA graph.
+
+    Therefore, each graphed callable should be a drop-in replacement for its source callable
+    in an autograd-enabled training loop.
+
+    See :ref:`Partial-network capture` for detailed use and constraints.
+
+    If you pass a tuple of several callables, their captures will use the same memory pool.
+    See :ref:`Graph memory management` for when this is appropriate.
+
+    Arguments:
+        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
+            See :ref:`Graph memory management` for when passing a tuple of callables
+            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
+            they'll run in the live workload.
+        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
+            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
+            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
+        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
+            11 iterations for warm-up. Default: ``3``.
+        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
+            (and therefore their grad is always zero) is an error. Defaults to False.
+        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
+            :meth:`other_Graph_instance.pool()`) that hints this graph may share memory
+            with the indicated pool.  See :ref:`Graph memory management`.
+    .. note::
+        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
+        that's expected for the corresponding real input in the training loop.
+
+    .. warning::
+        This API is in beta and may change in future releases.
+
+    .. warning::
+        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
+
+    .. warning::
+        Returned callables do not support higher order differentiation (e.g., double backward).
+
+    .. warning::
+        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
+        may be trainable. Buffers must have ``requires_grad=False``.
+
+    .. warning::
+        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
+        you may not add or remove any of that Module's parameters or buffers.
+
+    .. warning::
+        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
+        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
+        through :func:`~torch.cuda.make_graphed_callables` is allowed.
+
+    .. warning::
+        When running a graphed callable, you must pass its arguments in the same order and format
+        they appeared in that callable's ``sample_args``.
+
+    .. warning::
+        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with caching
+        disabled. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
+    """
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError(
+            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
+        )
+
+    just_one_callable = False
+
+    if not isinstance(callables, tuple):
+        just_one_callable = True
+        callables = (callables,)
+        sample_args = (sample_args,)
+
+    flatten_sample_args = []
+
+    for c, args in zip(callables, sample_args):
+        if isinstance(c, torch.nn.Module):
+            assert (
+                len(c._backward_hooks) == 0
+                and len(c._forward_hooks) == 0
+                and len(c._forward_pre_hooks) == 0
+            ), (
+                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
+                + "on modules after passing them through make_graphed_callables is allowed."
+            )
+            assert all(b.requires_grad is False for b in c.buffers()), (
+                "In any :class:`~torch.nn.Module` passed to "
+                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
+                + "``requires_grad=False``."
+            )
+        flatten_arg = _pytree.arg_tree_leaves(*args)
+        flatten_sample_args.append(tuple(flatten_arg))
+        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
+            "In the beta API, sample_args "
+            + "for each callable must contain only Tensors. Other types are not allowed."
+        )
+
+    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
+    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
+    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
+    per_callable_module_params = [
+        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
+        for c in callables
+    ]
+    per_callable_static_input_surfaces = [
+        flatten_sample_args[i] + per_callable_module_params[i]
+        for i in range(len(callables))
+    ]
+
+    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
+
+    mempool = graph_pool_handle() if pool is None else pool
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    torch.cuda.synchronize()
+    with torch.cuda.stream(torch.cuda.Stream()):
+        for func, args, static_input_surface in zip(
+            callables, sample_args, per_callable_static_input_surfaces
+        ):
+            for _ in range(num_warmup_iters):
+                outputs = _pytree.tree_leaves(func(*args))
+                grad_inputs = torch.autograd.grad(
+                    outputs=tuple(o for o in outputs if o.requires_grad),
+                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                    grad_outputs=tuple(
+                        torch.empty_like(o) for o in outputs if o.requires_grad
+                    ),
+                    only_inputs=True,
+                    allow_unused=allow_unused_input,
+                )
+            del outputs, grad_inputs  # type: ignore[possibly-undefined]
+    torch.cuda.synchronize()
+
+    # All captures here share a mempool. To avoid replays corrupting each other's memory,
+    # the safest approach is to capture all passes in the same order they'll run:
+    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
+
+    # Capture forward graphs
+    per_callable_static_outputs = []
+    per_callable_output_unflatten_spec = []
+    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+        with torch.cuda.graph(fwd_graph, pool=mempool):
+            outputs = func(*args)
+
+        flatten_outputs, spec = _pytree.tree_flatten(outputs)
+        per_callable_static_outputs.append(tuple(flatten_outputs))
+        per_callable_output_unflatten_spec.append(spec)
+
+    # Capture backward graphs in reverse order
+    per_callable_static_grad_outputs = []
+    per_callable_static_grad_inputs = []
+    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
+        reversed(per_callable_static_input_surfaces),
+        reversed(per_callable_static_outputs),
+        reversed(bwd_graphs),
+        reversed(per_callable_module_params),
+    ):
+        # For now, assumes all static_outputs require grad
+        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
+        static_grad_outputs = tuple(
+            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+        )
+
+        with torch.cuda.graph(bwd_graph, pool=mempool):
+            grad_inputs = torch.autograd.grad(
+                outputs=tuple(o for o in static_outputs if o.requires_grad),
+                inputs=tuple(i for i in static_input_surface if i.requires_grad),
+                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
+                only_inputs=True,
+                allow_unused=allow_unused_input,
+            )
+
+        # Constructs a tuple suitable for returning from Graphed.backward:
+        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
+        # I couldn't think of a slick one-liner for this pattern.
+        static_grad_inputs = []
+        grad_idx = 0
+        for arg in static_input_surface:
+            if arg.requires_grad:
+                static_grad_inputs.append(grad_inputs[grad_idx])
+                grad_idx += 1
+            else:
+                static_grad_inputs.append(None)  # type: ignore[arg-type]
+        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]
+
+        per_callable_static_grad_outputs.append(static_grad_outputs)
+        per_callable_static_grad_inputs.append(static_grad_inputs)
+
+    # Reverses the most recent two lists
+    per_callable_static_grad_outputs.reverse()
+    per_callable_static_grad_inputs.reverse()
+    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
+
+    def make_graphed_autograd_function(
+        fwd_graph,
+        bwd_graph,
+        module_params,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+        static_grad_outputs,
+        static_grad_inputs,
+    ):
+        class Graphed(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, *inputs):
+                # At this stage, only the user args may (potentially) be new tensors.
+                for i in range(len_user_args):
+                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                        static_input_surface[i].copy_(inputs[i])
+                fwd_graph.replay()
+                assert isinstance(static_outputs, tuple)
+                return tuple(o.detach() for o in static_outputs)
+
+            @staticmethod
+            @torch.autograd.function.once_differentiable
+            def backward(ctx, *grads):
+                assert len(grads) == len(static_grad_outputs)
+                for g, grad in zip(static_grad_outputs, grads):
+                    if g is not None:
+                        # don't copy if autograd gods have been kind and the
+                        # incoming grad is already in the right place
+                        if g.data_ptr() != grad.data_ptr():
+                            g.copy_(grad)
+                bwd_graph.replay()
+
+                # Input args that didn't require grad expect a None gradient.
+                assert isinstance(static_grad_inputs, tuple)
+                return tuple(
+                    b.detach() if b is not None else b for b in static_grad_inputs
+                )
+
+        def functionalized(*user_args):
+            # Runs the autograd function with inputs == all inputs to the graph that might require grad
+            # (explicit user args + module parameters)
+            # Assumes module params didn't change since capture.
+            flatten_user_args = _pytree.arg_tree_leaves(*user_args)
+            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
+            return _pytree.tree_unflatten(out, output_unflatten_spec)
+
+        return functionalized
+
+    # Put together the final graphed callables
+    ret = []
+    for i, func in enumerate(callables):
+        graphed = make_graphed_autograd_function(
+            fwd_graphs[i],
+            bwd_graphs[i],
+            per_callable_module_params[i],
+            per_callable_len_user_args[i],
+            per_callable_output_unflatten_spec[i],
+            per_callable_static_input_surfaces[i],
+            per_callable_static_outputs[i],
+            per_callable_static_grad_outputs[i],
+            per_callable_static_grad_inputs[i],
+        )
+
+        if isinstance(func, torch.nn.Module):
+
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+                def new_fwd(*user_args):
+                    # If the module's training-or-eval state matches what we graphed,
+                    # run the graph, otherwise run the original forward method
+                    if func.training == graph_training_state:
+                        return graphed(*user_args)
+                    else:
+                        return orig_fwd(*user_args)
+
+                return new_fwd
+
+            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
+            ret.append(func)
+        else:
+            ret.append(graphed)
+
+    if just_one_callable:
+        return ret[0]
+
+    return tuple(ret)
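+
+
+# Illustrative sketch (not part of the original module): graphing a single small
+# module with make_graphed_callables. Assumes a CUDA device; the layer sizes and
+# batch size are arbitrary.
+def _example_make_graphed_callables():  # pragma: no cover - documentation sketch
+    model = torch.nn.Linear(16, 16).cuda()
+    sample_args = (torch.randn(4, 16, device="cuda", requires_grad=True),)
+    graphed_model = make_graphed_callables(model, sample_args)
+    # The graphed module is a drop-in replacement inside a training loop.
+    out = graphed_model(torch.randn(4, 16, device="cuda", requires_grad=True))
+    out.sum().backward()
+    return out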
diff --git a/MLPY/Lib/site-packages/torch/cuda/jiterator.py b/MLPY/Lib/site-packages/torch/cuda/jiterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee6ddab5b6cc3f5b13456daae3972a318699aad3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/jiterator.py
@@ -0,0 +1,185 @@
+import re
+from typing import Callable, List
+
+import torch
+from torch import Tensor
+
+__all__: List[str] = []
+
+
+class _CodeParser:
+    def __init__(self, code_string: str):
+        optional_ws = r"\s*"
+        required_ws = r"\s+"
+        template_params = r"(?P<template_params>\<.+\>)"
+        return_type = r"(?P<return_type>\w+)"
+        function_name = r"(?P<function_name>\w+)"
+        function_params = r"(?P<function_params>\(.+\))"
+        function_body = r"(?P<function_body>\{.+\})"
+
+        pattern = (
+            optional_ws
+            + "template"
+            + optional_ws
+            + template_params
+            + optional_ws
+            + return_type
+            + required_ws
+            + function_name
+            + optional_ws
+            + function_params
+            + optional_ws
+            + function_body
+            + optional_ws
+        )
+
+        result = re.match(
+            pattern, code_string, re.DOTALL
+        )  # DOTALL for matching multiline
+
+        if result is None:
+            raise Exception(
+                f"Couldn't parse code, please check correctness:\n {code_string}"
+            )
+
+        self.template_params = result["template_params"]
+        self.return_type = result["return_type"]
+        self.function_name = result["function_name"]
+        self.function_params = result["function_params"]
+        self.function_body = result["function_body"]
+
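+# Illustrative sketch (not part of the original module): what _CodeParser
+# extracts from a well-formed kernel string. The kernel name is hypothetical.
+def _example_code_parser():  # pragma: no cover - documentation sketch
+    parsed = _CodeParser(
+        "template <typename T> T my_op(T x, T y) { return x + y; }"
+    )
+    assert parsed.function_name == "my_op"
+    assert parsed.return_type == "T"
+    return parsed
+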
+
+class _JittedFunction:
+    def __init__(
+        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
+    ):
+        self.code_string = code_string
+
+        assert (
+            return_by_ref or num_outputs == 1
+        ), "Return by value only works for single output. "
+        self.return_by_ref = return_by_ref
+        self.num_outputs = num_outputs
+
+        parsed_code = _CodeParser(code_string)
+        self.kernel_name = parsed_code.function_name
+
+        self.kwargs_dict = kwargs
+        self.is_cuda_available = torch.cuda.is_available()
+
+    def __call__(self, *tensors: Tensor, **kwargs):
+        # Jiterator follows torch.cuda's lazy initialization behavior;
+        # defer checking CUDA availability until function invocation time.
+        assert (
+            self.is_cuda_available
+        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
+
+        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
+
+        expanded_kwargs = self.kwargs_dict.copy()
+        for key, value in kwargs.items():
+            if key in self.kwargs_dict:
+                expanded_kwargs[key] = value
+            else:
+                raise KeyError(f"{key} is not declared in function definition")
+
+        return torch._C._cuda_jiterator_compile_and_launch_kernel(
+            self.code_string,
+            self.kernel_name,
+            self.return_by_ref,
+            self.num_outputs,
+            tensors,
+            expanded_kwargs,
+        )
+
+
+def _create_jit_fn(code_string: str, **kwargs) -> Callable:
+    """
+    Create a jiterator-generated cuda kernel for an elementwise op.
+
+    The code string has to be a valid CUDA function that describes the computation for a single element. The code
+    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
+    into an elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory, as
+    well as in a local temp dir.
+
+    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.
+
+    Args:
+        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
+        kwargs (Dict, optional): Keyword arguments for generated function
+
+    Example::
+
+        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
+        jitted_fn = create_jit_fn(code_string, alpha=1.0)
+        a = torch.rand(3, device='cuda')
+        b = torch.rand(3, device='cuda')
+        # invoke jitted function like a regular python function
+        result = jitted_fn(a, b, alpha=3.14)
+
+    code_string also allows multiple function definitions, and the last function will be treated as the entry function.
+
+    Example::
+
+        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
+        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
+        jitted_fn = create_jit_fn(code_string, val=0.0)
+        a = torch.rand(3, device='cuda')
+        b = torch.rand(3, device='cuda')
+        # invoke jitted function like a regular python function
+        result = jitted_fn(a, b)  # using default val=0.0
+
+    Jiterator can be used together with Python registration to override an operator's CUDA kernel.
+    The following example overrides gelu's CUDA kernel with relu.
+
+    Example::
+
+        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
+        my_gelu = create_jit_fn(code_string)
+        my_lib = torch.library.Library("aten", "IMPL")
+        my_lib.impl('aten::gelu', my_gelu, "CUDA")
+        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
+        a = torch.rand(3, device='cuda')
+        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
+
+    .. warning::
+        This API is in beta and may change in future releases.
+
+    .. warning::
+        This API only supports up to 8 inputs and 1 output
+
+    .. warning::
+        All input tensors must be on a CUDA device
+    """
+    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
+
+
+def _create_multi_output_jit_fn(
+    code_string: str, num_outputs: int, **kwargs
+) -> Callable:
+    """
+    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
+
+    Args:
+        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return values by reference.
+        num_outputs (int): number of outputs returned by the kernel
+        kwargs (Dict, optional): Keyword arguments for generated function
+
+    Example::
+
+        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
+        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
+        a = torch.rand(3, device='cuda')
+        b = torch.rand(3, device='cuda')
+        # invoke jitted function like a regular python function
+        result = jitted_fn(a, b, alpha=3.14)
+
+    .. warning::
+        This API is in beta and may change in future releases.
+
+    .. warning::
+        This API only supports up to 8 inputs and 8 outputs
+    """
+    return _JittedFunction(
+        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
+    )
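+
+
+# Illustrative sketch (not part of the original module): a kernel that returns
+# two outputs by reference via _create_multi_output_jit_fn. Assumes a CUDA
+# device; the kernel name and shapes are hypothetical.
+def _example_multi_output_jit_fn():  # pragma: no cover - documentation sketch
+    code_string = (
+        "template <typename T> void sum_and_diff(T x, T y, T& s, T& d) "
+        "{ s = x + y; d = x - y; }"
+    )
+    fn = _create_multi_output_jit_fn(code_string, num_outputs=2)
+    a = torch.rand(8, device="cuda")
+    b = torch.rand(8, device="cuda")
+    s, d = fn(a, b)
+    return s, d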
diff --git a/MLPY/Lib/site-packages/torch/cuda/memory.py b/MLPY/Lib/site-packages/torch/cuda/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..34fcedc5994cb2cd1b0bb05d7edbe1a58ec8b514
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/memory.py
@@ -0,0 +1,914 @@
+r"""This package adds support for device memory management implemented in CUDA."""
+
+import collections
+import contextlib
+import ctypes
+import pickle
+import sys
+import warnings
+from inspect import signature
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import _C
+
+from torch.types import Device
+from .._utils import _dummy_type
+from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
+
+from ._memory_viz import memory as _memory, segments as _segments
+
+__all__ = [
+    "caching_allocator_alloc",
+    "caching_allocator_delete",
+    "set_per_process_memory_fraction",
+    "empty_cache",
+    "memory_stats",
+    "memory_stats_as_nested_dict",
+    "reset_accumulated_memory_stats",
+    "reset_peak_memory_stats",
+    "reset_max_memory_allocated",
+    "reset_max_memory_cached",
+    "memory_allocated",
+    "max_memory_allocated",
+    "memory_reserved",
+    "max_memory_reserved",
+    "memory_cached",
+    "max_memory_cached",
+    "memory_snapshot",
+    "memory_summary",
+    "list_gpu_processes",
+    "mem_get_info",
+    "get_allocator_backend",
+    "CUDAPluggableAllocator",
+    "change_current_allocator",
+]
+
+
+if not hasattr(torch._C, "_cuda_CUDAAllocator"):
+    # Define dummy base classes
+    torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")
+
+
+def _host_allocator():
+    _lazy_init()
+    return torch._C._cuda_cudaHostAllocator()
+
+
+@contextlib.contextmanager
+def _free_mutex():
+    torch._C._cuda_lock_mutex()
+    try:
+        yield
+    finally:
+        torch._C._cuda_unlock_mutex()
+
+
+def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
+    r"""Perform a memory allocation using the CUDA memory allocator.
+
+    Memory is allocated for a given device and a stream, this
+    function is intended to be used for interoperability with other
+    frameworks. Allocated memory is released through
+    :func:`~torch.cuda.caching_allocator_delete`.
+
+    Args:
+        size (int): number of bytes to be allocated.
+        device (torch.device or int, optional): selected device. If it is
+            ``None`` the default CUDA device is used.
+        stream (torch.cuda.Stream or int, optional): selected stream. If it is ``None`` then
+            the default stream for the selected device is used.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    if device is None:
+        device = torch.cuda.current_device()
+    device = _get_device_index(device)
+    if stream is None:
+        stream = torch.cuda.current_stream(device)
+    if isinstance(stream, torch.cuda.streams.Stream):
+        stream = stream.cuda_stream
+    if not isinstance(stream, int):
+        raise TypeError(
+            "Invalid type for stream argument, must be "
+            "`torch.cuda.Stream` or `int` representing a pointer "
+            "to a existing stream"
+        )
+    with torch.cuda.device(device):
+        return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
+
+
+def caching_allocator_delete(mem_ptr):
+    r"""Delete memory allocated using the CUDA memory allocator.
+
+    Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`
+    is freed here. The associated device and stream are tracked inside
+    the allocator.
+
+    Args:
+        mem_ptr (int): memory address to be freed by the allocator.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
+
+
+def set_per_process_memory_fraction(
+    fraction, device: Union[Device, int] = None
+) -> None:
+    r"""Set memory fraction for a process.
+
+    The fraction is used to limit the caching allocator's memory usage on a CUDA device.
+    The allowed value equals the total visible memory multiplied by the fraction.
+    If a process tries to allocate more than the allowed value, the allocator raises an
+    out-of-memory error.
+
+    Args:
+        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
+        device (torch.device or int, optional): selected device. If it is
+            ``None`` the default CUDA device is used.
+    .. note::
+        In general, the total available free memory is less than the total capacity.
+    """
+    _lazy_init()
+    if device is None:
+        device = torch.cuda.current_device()
+    device = _get_device_index(device)
+    if not isinstance(fraction, float):
+        raise TypeError("Invalid type for fraction argument, must be `float`")
+    if fraction < 0 or fraction > 1:
+        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")
+
+    torch._C._cuda_setMemoryFraction(fraction, device)
+
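+# Illustrative sketch (not part of the original module): cap this process at
+# roughly half of the device's visible memory. Assumes CUDA is available; the
+# fraction is arbitrary.
+def _example_memory_fraction():  # pragma: no cover - documentation sketch
+    set_per_process_memory_fraction(0.5)
+    # Allocations that would push usage past total_memory * 0.5 now raise an
+    # out-of-memory error from the caching allocator.
+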
+
+def empty_cache() -> None:
+    r"""Release all unoccupied cached memory currently held by the caching
+    allocator so that it can be used by other GPU applications and is visible in
+    `nvidia-smi`.
+
+    .. note::
+        :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
+        memory available for PyTorch. However, it may help reduce fragmentation
+        of GPU memory in certain cases. See :ref:`cuda-memory-management` for
+        more details about GPU memory management.
+    """
+    if is_initialized():
+        torch._C._cuda_emptyCache()
+
+
+def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
+    r"""Return a dictionary of CUDA memory allocator statistics for a given device.
+
+    The return value of this function is a dictionary of statistics, each of
+    which is a non-negative integer.
+
+    Core statistics:
+
+    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of allocation requests received by the memory allocator.
+    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of allocated memory.
+    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of reserved segments from ``cudaMalloc()``.
+    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of reserved memory.
+    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of active memory blocks.
+    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of active memory.
+    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of inactive, non-releasable memory blocks.
+    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of inactive, non-releasable memory.
+
+    For these core statistics, values are broken down as follows.
+
+    Pool type:
+
+    - ``all``: combined statistics across all memory pools.
+    - ``large_pool``: statistics for the large allocation pool
+      (as of October 2019, for size >= 1MB allocations).
+    - ``small_pool``: statistics for the small allocation pool
+      (as of October 2019, for size < 1MB allocations).
+
+    Metric type:
+
+    - ``current``: current value of this metric.
+    - ``peak``: maximum value of this metric.
+    - ``allocated``: historical total increase in this metric.
+    - ``freed``: historical total decrease in this metric.
+
+    In addition to the core statistics, we also provide some simple event
+    counters:
+
+    - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
+      result in a cache flush and retry.
+    - ``"num_ooms"``: number of out-of-memory errors thrown.
+
+    The caching allocator can be configured via ENV to not split blocks larger than a
+    defined size (see the Memory Management section of the CUDA Semantics documentation).
+    This helps avoid memory fragmentation but may have a performance
+    penalty. Additional outputs to assist with tuning and evaluating impact:
+
+    - ``"max_split_size"``: blocks above this size will not be split.
+    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
+      number of over-size allocation requests received by the memory allocator.
+    - ``"oversize_segments.{current,peak,allocated,freed}"``:
+      number of over-size reserved segments from ``cudaMalloc()``.
+
+    The caching allocator can be configured via ENV to round memory allocations in order
+    to reduce fragmentation. Sometimes the overhead from rounding can be higher than
+    the fragmentation it helps reduce. The following stat can be used to check if
+    rounding adds too much overhead:
+
+    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      memory requested by client code, compare this with allocated_bytes to check if
+      allocation rounding adds too much overhead.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistics for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+
+    .. note::
+        With :ref:`backend:cudaMallocAsync`, some stats are not
+        meaningful, and are always reported as zero.
+    """
+    result = []
+
+    def _recurse_add_to_result(prefix, obj):
+        if isinstance(obj, dict):
+            if len(prefix) > 0:
+                prefix += "."
+            for k, v in obj.items():
+                _recurse_add_to_result(prefix + k, v)
+        else:
+            result.append((prefix, obj))
+
+    stats = memory_stats_as_nested_dict(device=device)
+    _recurse_add_to_result("", stats)
+    result.sort()
+
+    return collections.OrderedDict(result)
+
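+# Illustrative sketch (not part of the original module): reading a few of the
+# flattened allocator statistics for the current device.
+def _example_memory_stats():  # pragma: no cover - documentation sketch
+    stats = memory_stats()
+    current = stats.get("allocated_bytes.all.current", 0)
+    peak = stats.get("allocated_bytes.all.peak", 0)
+    retries = stats.get("num_alloc_retries", 0)
+    return current, peak, retries
+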
+
+def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
+    r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
+    if not is_initialized():
+        return {}
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_memoryStats(device)
+
+
+def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
+    r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.
+
+    See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
+    the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
+    `"num_alloc_retries"` and `"num_ooms"`.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetAccumulatedMemoryStats(device)
+
+
+def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
+    r"""Reset the "peak" stats tracked by the CUDA memory allocator.
+
+    See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
+    `"peak"` key in each individual stat dict.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetPeakMemoryStats(device)
+
+
+def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
+    r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+
+    See :func:`~torch.cuda.max_memory_allocated` for details.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. warning::
+        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
+        /all/ peak memory stats.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    warnings.warn(
+        "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
+        "which resets /all/ peak memory stats.",
+        FutureWarning,
+    )
+    return reset_peak_memory_stats(device=device)
+
+
+def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
+    r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+
+    See :func:`~torch.cuda.max_memory_cached` for details.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. warning::
+        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
+        /all/ peak memory stats.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    warnings.warn(
+        "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
+        "which resets /all/ peak memory stats.",
+        FutureWarning,
+    )
+    return reset_peak_memory_stats(device=device)
+
+
+def memory_allocated(device: Union[Device, int] = None) -> int:
+    r"""Return the current GPU memory occupied by tensors in bytes for a given device.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        This is likely less than the amount shown in `nvidia-smi` since some
+        unused memory can be held by the caching allocator and some context
+        needs to be created on GPU. See :ref:`cuda-memory-management` for more
+        details about GPU memory management.
+    """
+    return memory_stats(device=device).get("allocated_bytes.all.current", 0)
+
+
+def max_memory_allocated(device: Union[Device, int] = None) -> int:
+    r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
+
+    By default, this returns the peak allocated memory since the beginning of
+    this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
+    reset the starting point in tracking this metric. For example, these two
+    functions can measure the peak allocated memory usage of each iteration in a
+    training loop.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
+
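+# Illustrative sketch (not part of the original module): measuring the peak
+# allocation of each iteration, as described in the docstring above.
+# ``train_step`` is a hypothetical zero-argument callable.
+def _example_peak_per_iteration(train_step, num_iters=3):  # pragma: no cover - documentation sketch
+    peaks = []
+    for _ in range(num_iters):
+        reset_peak_memory_stats()
+        train_step()
+        peaks.append(max_memory_allocated())
+    return peaks
+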
+
+def memory_reserved(device: Union[Device, int] = None) -> int:
+    r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    return memory_stats(device=device).get("reserved_bytes.all.current", 0)
+
+
+def max_memory_reserved(device: Union[Device, int] = None) -> int:
+    r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
+
+    By default, this returns the peak cached memory since the beginning of this
+    program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
+    the starting point in tracking this metric. For example, these two functions
+    can measure the peak cached memory amount of each iteration in a training
+    loop.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
+
+
+def memory_cached(device: Union[Device, int] = None) -> int:
+    r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
+    warnings.warn(
+        "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved",
+        FutureWarning,
+    )
+    return memory_reserved(device=device)
+
+
+def max_memory_cached(device: Union[Device, int] = None) -> int:
+    r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
+    warnings.warn(
+        "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved",
+        FutureWarning,
+    )
+    return max_memory_reserved(device=device)
+
+
+def memory_snapshot():
+    r"""Return a snapshot of the CUDA memory allocator state across all devices.
+
+    Interpreting the output of this function requires familiarity with the
+    memory allocator internals.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    return torch._C._cuda_memorySnapshot()["segments"]
+
+
+def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
+    r"""Return a human-readable printout of the current memory allocator statistics for a given device.
+
+    This can be useful to display periodically during training, or when
+    handling out-of-memory exceptions.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            printout for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+        abbreviated (bool, optional): whether to return an abbreviated summary
+            (default: False).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    stats = memory_stats(device=device)
+
+    def _format_size(sz, pref_sz):
+        prefixes = ["B  ", "KiB", "MiB", "GiB", "TiB", "PiB"]
+        prefix = prefixes[0]
+        for new_prefix in prefixes[1:]:
+            if pref_sz < 768 * 1024:
+                break
+            prefix = new_prefix
+            sz //= 1024
+            pref_sz /= 1024
+        return f"{sz:6d} {prefix}"
+
+    def _format_count(cnt, pref_cnt):
+        prefixes = [" ", "K", "M"]
+        prefix = prefixes[0]
+        for new_prefix in prefixes[1:]:
+            if pref_cnt < 750 * 1000:
+                break
+            prefix = new_prefix
+            cnt //= 1000
+            pref_cnt /= 1000
+        return f"{cnt:7d} {prefix} "
+
+    metrics_to_display = [
+        ("allocated_bytes", "Allocated memory", _format_size),
+        ("active_bytes", "Active memory", _format_size),
+        ("requested_bytes", "Requested memory", _format_size),
+        ("reserved_bytes", "GPU reserved memory", _format_size),
+        ("inactive_split_bytes", "Non-releasable memory", _format_size),
+        ("allocation", "Allocations", _format_count),
+        ("active", "Active allocs", _format_count),
+        ("segment", "GPU reserved segments", _format_count),
+        ("inactive_split", "Non-releasable allocs", _format_count),
+    ]
+
+    lines = []
+    lines.append("=" * 75)
+    lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
+    lines.append("-" * 75)
+    lines.append(
+        "  {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d}  "
+    )
+    lines.append("=" * 75)
+    lines.append(
+        "        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  "
+    )
+
+    for metric_key, metric_name, formatter in metrics_to_display:
+        lines.append("-" * 75)
+        submetrics = [("all", metric_name)]
+        if not abbreviated:
+            submetrics.append(("large_pool", "      from large pool"))
+            submetrics.append(("small_pool", "      from small pool"))
+
+        current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
+            None,
+            None,
+            None,
+            None,
+        )
+
+        for submetric_key, submetric_name in submetrics:
+            prefix = metric_key + "." + submetric_key + "."
+
+            current = stats[prefix + "current"]
+            peak = stats[prefix + "peak"]
+            allocated = stats[prefix + "allocated"]
+            freed = stats[prefix + "freed"]
+
+            if current_prefval is None:
+                current_prefval = current
+                peak_prefval = peak
+                allocated_prefval = allocated
+                freed_prefval = freed
+
+            lines.append(
+                " {:<21} | {} | {} | {} | {} ".format(
+                    submetric_name,
+                    formatter(current, current_prefval),
+                    formatter(peak, peak_prefval),
+                    formatter(allocated, allocated_prefval),
+                    formatter(freed, freed_prefval),
+                ),
+            )
+
+    metrics_to_display = [
+        ("oversize_allocations", "Oversize allocations", _format_count),
+        ("oversize_segments", "Oversize GPU segments", _format_count),
+    ]
+
+    for metric_key, metric_name, formatter in metrics_to_display:
+        lines.append("-" * 75)
+
+        prefix = metric_key + "."
+
+        current = stats[prefix + "current"]
+        peak = stats[prefix + "peak"]
+        allocated = stats[prefix + "allocated"]
+        freed = stats[prefix + "freed"]
+
+        lines.append(
+            " {:<21} | {} | {} | {} | {} ".format(
+                metric_name,
+                formatter(current, current),
+                formatter(peak, peak),
+                formatter(allocated, allocated),
+                formatter(freed, freed),
+            ),
+        )
+
+    lines.append("=" * 75)
+
+    fmt_dict = {"_": "", "device": device}
+    for k, v in stats.items():
+        fmt_dict[k.replace(".", "-")] = v
+    return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
+
+
+def list_gpu_processes(device: Union[Device, int] = None) -> str:
+    r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.
+
+    This can be useful to display periodically during training, or when
+    handling out-of-memory exceptions.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            printout for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+    """
+    try:
+        import pynvml  # type: ignore[import]
+    except ModuleNotFoundError:
+        return "pynvml module not found, please install pynvml"
+    from pynvml import NVMLError_DriverNotLoaded
+
+    try:
+        pynvml.nvmlInit()
+    except NVMLError_DriverNotLoaded:
+        return "cuda driver can't be loaded, is cuda enabled?"
+    device = _get_nvml_device_index(device)
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+    lines = []
+    lines.append(f"GPU:{device}")
+    if len(procs) == 0:
+        lines.append("no processes are running")
+    for p in procs:
+        mem = p.usedGpuMemory / (1024 * 1024)
+        lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
+    return "\n".join(lines)
+
+
+def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
+    r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more
+        details about GPU memory management.
+    """
+    if device is None:
+        device = torch.cuda.current_device()
+    device = _get_device_index(device)
+    return torch.cuda.cudart().cudaMemGetInfo(device)
+
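+# Illustrative sketch (not part of the original module): report free and total
+# device memory in GiB for the current device.
+def _example_mem_get_info():  # pragma: no cover - documentation sketch
+    free_bytes, total_bytes = mem_get_info()
+    gib = 1024**3
+    return free_bytes / gib, total_bytes / gib
+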
+
+def _record_memory_history_legacy(
+    enabled: bool,
+    record_context=True,
+    trace_alloc_max_entries=1,
+    trace_alloc_record_context=False,
+    device: Union[Device, int] = None,
+    record_context_cpp=False,
+):
+    _C._cuda_record_memory_history_legacy(
+        enabled,
+        record_context,
+        trace_alloc_max_entries,
+        trace_alloc_record_context,
+        record_context_cpp,
+    )
+
+
+def _record_memory_history(enabled="all", *args, **kwargs):
+    """Enable recording of stack traces associated with memory
+    allocations, so you can tell what allocated any piece of memory in
+    :func:`torch.cuda.memory._snapshot()`.
+
+    In addition to keeping stack traces with each current allocation and free,
+    this will also enable recording of a history of all alloc/free events.
+
+    Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
+    and the tools in `_memory_viz.py` to visualize snapshots.
+
+    The Python trace collection is fast (2us per trace), so you may consider
+    enabling this on production jobs if you anticipate ever having to debug
+    memory issues.
+
+    C++ trace collection is also fast (~50ns/frame), which for many typical programs
+    works out to ~2us per trace, but can vary depending on stack depth.
+
+    Args:
+        enabled (Literal[None, "state", "all"], optional):
+            `None`, disable recording memory history.
+            `"state"`, keep information for currenly allocated memory.
+            `"all"`, additionally keep a history of all alloc/free calls.
+            Defaults to "all".
+        context (Literal[None, "state", "alloc", "all"], optional):
+            `None`, Do not record any tracebacks.
+            `"state"`, Record tracebacks for currently allocated memory.
+            `"alloc"`, additionally keep tracebacks for alloc calls.
+            `"all"`, additionally keep tracebacks for free calls.
+            Defaults to "all".
+        stacks (Literal["python", "all"], optional):
+            `"python"`, include Python, TorchScript, and inductor frames in tracebacks
+            `"all"`, additionally include C++ frames
+            Defaults to "all".
+        max_entries (int, optional): Keep a maximum of `max_entries`
+            alloc/free events in the recorded history.
+    """
+    if isinstance(enabled, bool):
+        return _record_memory_history_legacy(enabled, *args, **kwargs)
+    else:
+        return _record_memory_history_impl(enabled, *args, **kwargs)
+
+
+def _record_memory_history_impl(
+    enabled: Optional[str] = "all",
+    context: Optional[str] = "all",
+    stacks: str = "all",
+    max_entries: int = sys.maxsize,
+    device: Union[Device, int] = None,
+):
+    _C._cuda_record_memory_history(enabled, context, stacks, max_entries)
+
+
+_record_memory_history.__signature__ = signature(_record_memory_history_impl)  # type: ignore[attr-defined]
+
+
+def _snapshot(device: Union[Device, int] = None):
+    """Save a snapshot of CUDA memory state at the time it was called.
+
+    The state is represented as a dictionary with the following structure.
+
+    .. code-block:: python
+
+        class Snapshot(TypedDict):
+            segments : List[Segment]
+            device_traces: List[List[TraceEntry]]
+
+        class Segment(TypedDict):
+            # Segments are memory returned from a cudaMalloc call.
+            # The size of reserved memory is the sum of all Segments.
+            # Segments are cached and reused for future allocations.
+            # If the reuse is smaller than the segment, the segment
+            # is split into more than one Block.
+            # empty_cache() frees Segments that are entirely inactive.
+            address: int
+            total_size: int #  cudaMalloc'd size of segment
+            stream: int
+            segment_type: Literal['small', 'large'] # 'large' (>1MB)
+            allocated_size: int # size of memory in use
+            active_size: int # size of memory in use or in active_awaiting_free state
+            blocks : List[Block]
+
+        class Block(TypedDict):
+            # A piece of memory returned from the allocator, or
+            # currently cached but inactive.
+            size: int
+            requested_size: int # size requested during malloc, may be smaller than
+                                # size due to rounding
+            address: int
+            state: Literal['active_allocated', # used by a tensor
+                        'active_awaiting_free', # waiting for another stream to finish using
+                                                # this, then it will become free
+                        'inactive',] # free for reuse
+            frames: List[Frame] # stack trace from where the allocation occurred
+
+        class Frame(TypedDict):
+                filename: str
+                line: int
+                name: str
+
+        class TraceEntry(TypedDict):
+            # When `torch.cuda.memory._record_memory_history()` is enabled,
+            # the snapshot will contain TraceEntry objects that record each
+            # action the allocator took.
+            action: Literal[
+            'alloc'  # memory allocated
+            'free_requested', # a free was requested for the allocated memory
+            'free_completed', # the memory that was requested to be freed is now
+                            # able to be used in future allocation calls
+            'segment_alloc', # the caching allocator asked cudaMalloc for more memory
+                            # and added it as a segment in its cache
+            'segment_free',  # the caching allocator called cudaFree to return memory
+                            # to cuda, possibly trying to free up memory to
+                            # allocate more segments or because empty_cache was called
+            'oom',          # the allocator threw an OOM exception. 'size' is
+                            # the requested number of bytes that did not succeed
+            'snapshot'      # the allocator generated a memory snapshot
+                            # useful to correlate a previously taken
+                            # snapshot with this trace
+            ]
+            addr: int # not present for OOM
+            frames: List[Frame]
+            size: int
+            stream: int
+            device_free: int # only present for OOM, the amount of
+                            # memory cuda still reports to be free
+
+    Returns:
+        The Snapshot dictionary object
+    """
+    return _C._cuda_memorySnapshot()
+
+
+def _dump_snapshot(filename="dump_snapshot.pickle"):
+    """
+    Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
+
+    This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
+
+    Args:
+        filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
+    """
+    s = _snapshot()
+    with open(filename, "wb") as f:
+        pickle.dump(s, f)
+
+
+def _save_segment_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_segments(snapshot))
+
+
+def _save_memory_usage(filename="output.svg", snapshot=None):
+    if snapshot is None:
+        snapshot = _snapshot()
+    with open(filename, "w") as f:
+        f.write(_memory(snapshot))
+
+
+def _set_allocator_settings(env: str):
+    return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
+
+
+def get_allocator_backend() -> str:
+    r"""Return a string describing the active allocator backend as set by
+    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
+    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
+    (CUDA's built-in asynchronous allocator).
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
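+
+    Example (editor's sketch; the returned value depends on how
+    ``PYTORCH_CUDA_ALLOC_CONF`` is set in the environment)::
+
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> torch.cuda.memory.get_allocator_backend()
+        'native'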
+    """
+    return torch._C._cuda_getAllocatorBackend()
+
+
+class _CUDAAllocator:
+    r"""Wrapper over internal CUDA memory allocators."""
+
+    def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
+        self._allocator = allocator
+
+    def allocator(self):
+        return self._allocator
+
+
+class CUDAPluggableAllocator(_CUDAAllocator):
+    r"""CUDA memory allocator loaded from a so file."""
+
+    def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
+        r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.
+
+        To change the active allocator, use the :func:`torch.cuda.memory.change_current_allocator` function.
+
+        Args:
+            path_to_so_file(str): Path in the filesystem to the `.so` file containing
+                the allocator functions
+            alloc_fn_name(str): Name of the function to perform the memory allocation
+                in the so file. The signature must be:
+                void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
+            free_fn_name(str): Name of the function to perform the memory release
+                in the so file. The signature must be:
+                void free_fn_name(void* ptr, size_t size, cudaStream_t stream);
+
+        .. warning::
+            This is currently supported only on Unix operating systems.
+
+        .. note::
+            See :ref:`cuda-memory-management` for details on creating and using a custom allocator
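+
+        Example (editor's sketch; ``alloc.so`` and the exported symbols
+        ``my_malloc``/``my_free`` are hypothetical names)::
+
+            >>> # xdoctest: +SKIP
+            >>> import torch
+            >>> new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
+            ...     "alloc.so", "my_malloc", "my_free")
+            >>> torch.cuda.memory.change_current_allocator(new_alloc)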
+        """
+        allocator = ctypes.CDLL(path_to_so_file)
+        alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
+        free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
+        assert alloc_fn is not None
+        assert free_fn is not None
+        self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)
+
+
+def change_current_allocator(allocator: _CUDAAllocator) -> None:
+    r"""Change the currently used memory allocator to be the one provided.
+
+    If the current allocator has already been used/initialized, this function will error.
+
+    Args:
+        allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.
+    .. note::
+        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
+    """
+    torch._C._cuda_changeCurrentAllocator(allocator.allocator())
+
+
+def _get_current_allocator() -> _CUDAAllocator:
+    r"""Return the allocator being currently used.
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on creating and using a custom allocator
+    """
+    return _CUDAAllocator(torch._C._cuda_getAllocator())
diff --git a/MLPY/Lib/site-packages/torch/cuda/nccl.py b/MLPY/Lib/site-packages/torch/cuda/nccl.py
new file mode 100644
index 0000000000000000000000000000000000000000..439651a0492431383f358641119b26651d532dda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/nccl.py
@@ -0,0 +1,137 @@
+import collections
+import warnings
+from typing import Optional, Sequence, Union
+
+import torch.cuda
+
+
+__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
+
+SUM = 0  # ncclRedOp_t
+
+
+def is_available(tensors):
+    if not hasattr(torch._C, "_nccl_all_reduce"):
+        warnings.warn("PyTorch is not compiled with NCCL support")
+        return False
+
+    devices = set()
+    for tensor in tensors:
+        if tensor.is_sparse:
+            return False
+        if not tensor.is_contiguous():
+            return False
+        if not tensor.is_cuda:
+            return False
+        device = tensor.get_device()
+        if device in devices:
+            return False
+        devices.add(device)
+
+    return True
+
+
+def version():
+    ver = torch._C._nccl_version()
+    major = ver >> 32
+    minor = (ver >> 16) & 65535
+    patch = ver & 65535
+    suffix = torch._C._nccl_version_suffix().decode("utf-8")
+    if suffix == "":
+        return (major, minor, patch)
+    else:
+        return (major, minor, patch, suffix)
+
+
+def unique_id():
+    return torch._C._nccl_unique_id()
+
+
+def init_rank(num_ranks, uid, rank):
+    return torch._C._nccl_init_rank(num_ranks, uid, rank)
+
+
+def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
+    if not isinstance(inputs, collections.abc.Container) or isinstance(
+        inputs, torch.Tensor
+    ):
+        raise TypeError("Inputs should be a collection of tensors")
+
+
+def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
+    _check_sequence_type(inputs)
+    if outputs is None:
+        outputs = inputs
+    _check_sequence_type(outputs)
+    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
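+
+
+# Editor's example (sketch, not part of the original module): a minimal
+# all_reduce across two visible GPUs, assuming PyTorch was built with NCCL
+# support. Kept as comments so nothing runs at import time.
+#
+#   import torch
+#   from torch.cuda import nccl
+#   tensors = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
+#   if nccl.is_available(tensors):
+#       nccl.all_reduce(tensors)  # each tensor now holds the elementwise sum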
+
+
+# `output` used to be `outputs`, taking in a list of tensors. So we have two
+# arguments for BC reasons.
+def reduce(
+    inputs: Sequence[torch.Tensor],
+    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
+    root: int = 0,
+    op: int = SUM,
+    streams: Optional[Sequence[torch.cuda.Stream]] = None,
+    comms=None,
+    *,
+    outputs: Optional[Sequence[torch.Tensor]] = None,
+) -> None:
+    _check_sequence_type(inputs)
+    _output: torch.Tensor
+    if outputs is not None:
+        if output is not None:
+            raise ValueError(
+                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
+                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
+                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
+            )
+        else:
+            warnings.warn(
+                "nccl.reduce with an output tensor list is deprecated. "
+                "Please specify a single output tensor with argument 'output' instead."
+            )
+            _output = outputs[root]
+    elif not isinstance(output, torch.Tensor) and isinstance(
+        output, collections.abc.Sequence
+    ):
+        # User called old API with positional arguments of list of output tensors.
+        warnings.warn(
+            "nccl.reduce with an output tensor list is deprecated. "
+            "Please specify a single output tensor."
+        )
+        _output = output[root]
+    else:
+        _output = inputs[root] if output is None else output
+    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
+
+
+def broadcast(
+    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
+) -> None:
+    _check_sequence_type(inputs)
+    torch._C._nccl_broadcast(inputs, root, streams, comms)
+
+
+def all_gather(
+    inputs: Sequence[torch.Tensor],
+    outputs: Sequence[torch.Tensor],
+    streams=None,
+    comms=None,
+) -> None:
+    _check_sequence_type(inputs)
+    _check_sequence_type(outputs)
+    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
+
+
+def reduce_scatter(
+    inputs: Sequence[torch.Tensor],
+    outputs: Sequence[torch.Tensor],
+    op: int = SUM,
+    streams=None,
+    comms=None,
+) -> None:
+    _check_sequence_type(inputs)
+    _check_sequence_type(outputs)
+    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
diff --git a/MLPY/Lib/site-packages/torch/cuda/nvtx.py b/MLPY/Lib/site-packages/torch/cuda/nvtx.py
new file mode 100644
index 0000000000000000000000000000000000000000..58713d06e713a3c5fe4ad9f7eedf0a533a0343d0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/nvtx.py
@@ -0,0 +1,91 @@
+r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
+
+from contextlib import contextmanager
+
+try:
+    from torch._C import _nvtx
+except ImportError:
+
+    class _NVTXStub:
+        @staticmethod
+        def _fail(*args, **kwargs):
+            raise RuntimeError(
+                "NVTX functions not installed. Are you sure you have a CUDA build?"
+            )
+
+        rangePushA = _fail
+        rangePop = _fail
+        markA = _fail
+
+    _nvtx = _NVTXStub()  # type: ignore[assignment]
+
+__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
+
+
+def range_push(msg):
+    """
+    Push a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.
+
+    Args:
+        msg (str): ASCII message to associate with range
+    """
+    return _nvtx.rangePushA(msg)
+
+
+def range_pop():
+    """Pop a range off of a stack of nested range spans.  Returns the  zero-based depth of the range that is ended."""
+    return _nvtx.rangePop()
+
+
+def range_start(msg) -> int:
+    """
+    Mark the start of a range with a string message. It returns a unique handle
+    for this range to pass to the corresponding call to range_end().
+
+    A key difference between this and range_push/range_pop is that the
+    range_start/range_end version supports range across threads (start on one
+    thread and end on another thread).
+
+    Returns: A range handle (uint64_t) that can be passed to range_end().
+
+    Args:
+        msg (str): ASCII message to associate with the range.
+    """
+    return _nvtx.rangeStartA(msg)
+
+
+def range_end(range_id) -> None:
+    """
+    Mark the end of a range for a given range_id.
+
+    Args:
+        range_id (int): a unique handle for the start range.
+    """
+    _nvtx.rangeEnd(range_id)
+
+
+def mark(msg):
+    """
+    Describe an instantaneous event that occurred at some point.
+
+    Args:
+        msg (str): ASCII message to associate with the event.
+    """
+    return _nvtx.markA(msg)
+
+
+@contextmanager
+def range(msg, *args, **kwargs):
+    """
+    Context manager / decorator that pushes an NVTX range at the beginning
+    of its scope, and pops it at the end. If extra arguments are given,
+    they are passed as arguments to msg.format().
+
+    Args:
+        msg (str): message to associate with the range
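+
+    Example (editor's sketch; requires a CUDA build with NVTX support)::
+
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> with torch.cuda.nvtx.range("iteration {}", 0):
+        ...     torch.cuda.synchronize()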
+    """
+    range_push(msg.format(*args, **kwargs))
+    try:
+        yield
+    finally:
+        range_pop()
diff --git a/MLPY/Lib/site-packages/torch/cuda/profiler.py b/MLPY/Lib/site-packages/torch/cuda/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..010590d5d87e681edd650bea98100165dcd7e772
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/profiler.py
@@ -0,0 +1,61 @@
+import contextlib
+import tempfile
+
+import torch
+from . import check_error, cudart
+
+__all__ = ["init", "start", "stop", "profile"]
+
+DEFAULT_FLAGS = [
+    "gpustarttimestamp",
+    "gpuendtimestamp",
+    "gridsize3d",
+    "threadblocksize",
+    "streamid",
+    "enableonstart 0",
+    "conckerneltrace",
+]
+
+
+def init(output_file, flags=None, output_mode="key_value"):
+    rt = cudart()
+    if not hasattr(rt, "cudaOutputMode"):
+        raise AssertionError("HIP does not support profiler initialization!")
+    if (
+        hasattr(torch.version, "cuda")
+        and torch.version.cuda is not None
+        and int(torch.version.cuda.split(".")[0]) >= 12
+    ):
+        # Check https://github.com/pytorch/pytorch/pull/91118
+        # cudaProfilerInitialize is no longer needed after CUDA 12
+        raise AssertionError("CUDA12+ does not need profiler initialization!")
+    flags = DEFAULT_FLAGS if flags is None else flags
+    if output_mode == "key_value":
+        output_mode_enum = rt.cudaOutputMode.KeyValuePair
+    elif output_mode == "csv":
+        output_mode_enum = rt.cudaOutputMode.CSV
+    else:
+        raise RuntimeError(
+            "supported CUDA profiler output modes are: key_value and csv"
+        )
+    with tempfile.NamedTemporaryFile(delete=True) as f:
+        f.write(b"\n".join(f.encode("ascii") for f in flags))
+        f.flush()
+        check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
+
+
+def start():
+    check_error(cudart().cudaProfilerStart())
+
+
+def stop():
+    check_error(cudart().cudaProfilerStop())
+
+
+@contextlib.contextmanager
+def profile():
+    try:
+        start()
+        yield
+    finally:
+        stop()
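+
+
+# Editor's example (sketch, not part of the original module): wrapping a region
+# with cudaProfilerStart/Stop markers so an external profiler (nvprof/nsys)
+# only captures that region. Assumes a CUDA build; kept as comments so nothing
+# runs at import time.
+#
+#   import torch
+#   import torch.cuda.profiler as profiler
+#   a = torch.randn(1024, 1024, device="cuda")
+#   with profiler.profile():
+#       (a @ a).sum().item()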
diff --git a/MLPY/Lib/site-packages/torch/cuda/random.py b/MLPY/Lib/site-packages/torch/cuda/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f831b112b2a92b6d51c5dfa1dafe8c0125e6e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/random.py
@@ -0,0 +1,179 @@
+from typing import Iterable, List, Union
+
+import torch
+from .. import Tensor
+from . import _lazy_call, _lazy_init, current_device, device_count
+
+__all__ = [
+    "get_rng_state",
+    "get_rng_state_all",
+    "set_rng_state",
+    "set_rng_state_all",
+    "manual_seed",
+    "manual_seed_all",
+    "seed",
+    "seed_all",
+    "initial_seed",
+]
+
+
+def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
+    r"""Return the random number generator state of the specified GPU as a ByteTensor.
+
+    Args:
+        device (torch.device or int, optional): The device to return the RNG state of.
+            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
+
+    .. warning::
+        This function eagerly initializes CUDA.
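+
+    Example (editor's sketch; requires an available CUDA device)::
+
+        >>> # xdoctest: +SKIP
+        >>> state = torch.cuda.get_rng_state()
+        >>> torch.cuda.set_rng_state(state)  # restore later to replay the same random stream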
+    """
+    _lazy_init()
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+    idx = device.index
+    if idx is None:
+        idx = current_device()
+    default_generator = torch.cuda.default_generators[idx]
+    return default_generator.get_state()
+
+
+def get_rng_state_all() -> List[Tensor]:
+    r"""Return a list of ByteTensor representing the random number states of all devices."""
+    results = []
+    for i in range(device_count()):
+        results.append(get_rng_state(i))
+    return results
+
+
+def set_rng_state(
+    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
+) -> None:
+    r"""Set the random number generator state of the specified GPU.
+
+    Args:
+        new_state (torch.ByteTensor): The desired state
+        device (torch.device or int, optional): The device to set the RNG state.
+            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
+    """
+    with torch._C._DisableFuncTorch():
+        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+
+    def cb():
+        idx = device.index
+        if idx is None:
+            idx = current_device()
+        default_generator = torch.cuda.default_generators[idx]
+        default_generator.set_state(new_state_copy)
+
+    _lazy_call(cb)
+
+
+def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
+    r"""Set the random number generator state of all devices.
+
+    Args:
+        new_states (Iterable of torch.ByteTensor): The desired state for each device.
+    """
+    for i, state in enumerate(new_states):
+        set_rng_state(state, i)
+
+
+def manual_seed(seed: int) -> None:
+    r"""Set the seed for generating random numbers for the current GPU.
+
+    It's safe to call this function if CUDA is not available; in that
+    case, it is silently ignored.
+
+    Args:
+        seed (int): The desired seed.
+
+    .. warning::
+        If you are working with a multi-GPU model, this function is insufficient
+        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
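+
+    Example (editor's sketch; requires an available CUDA device)::
+
+        >>> # xdoctest: +SKIP
+        >>> torch.cuda.manual_seed(42)
+        >>> torch.randn(3, device="cuda")  # reproducible across runs with the same seed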
+    """
+    seed = int(seed)
+
+    def cb():
+        idx = current_device()
+        default_generator = torch.cuda.default_generators[idx]
+        default_generator.manual_seed(seed)
+
+    _lazy_call(cb, seed=True)
+
+
+def manual_seed_all(seed: int) -> None:
+    r"""Set the seed for generating random numbers on all GPUs.
+
+    It's safe to call this function if CUDA is not available; in that
+    case, it is silently ignored.
+
+    Args:
+        seed (int): The desired seed.
+    """
+    seed = int(seed)
+
+    def cb():
+        for i in range(device_count()):
+            default_generator = torch.cuda.default_generators[i]
+            default_generator.manual_seed(seed)
+
+    _lazy_call(cb, seed_all=True)
+
+
+def seed() -> None:
+    r"""Set the seed for generating random numbers to a random number for the current GPU.
+
+    It's safe to call this function if CUDA is not available; in that
+    case, it is silently ignored.
+
+    .. warning::
+        If you are working with a multi-GPU model, this function will only initialize
+        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
+    """
+
+    def cb():
+        idx = current_device()
+        default_generator = torch.cuda.default_generators[idx]
+        default_generator.seed()
+
+    _lazy_call(cb)
+
+
+def seed_all() -> None:
+    r"""Set the seed for generating random numbers to a random number on all GPUs.
+
+    It's safe to call this function if CUDA is not available; in that
+    case, it is silently ignored.
+    """
+
+    def cb():
+        random_seed = 0
+        seeded = False
+        for i in range(device_count()):
+            default_generator = torch.cuda.default_generators[i]
+            if not seeded:
+                default_generator.seed()
+                random_seed = default_generator.initial_seed()
+                seeded = True
+            else:
+                default_generator.manual_seed(random_seed)
+
+    _lazy_call(cb)
+
+
+def initial_seed() -> int:
+    r"""Return the current random seed of the current GPU.
+
+    .. warning::
+        This function eagerly initializes CUDA.
+    """
+    _lazy_init()
+    idx = current_device()
+    default_generator = torch.cuda.default_generators[idx]
+    return default_generator.initial_seed()
diff --git a/MLPY/Lib/site-packages/torch/cuda/sparse.py b/MLPY/Lib/site-packages/torch/cuda/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..702def052945ad1bd54ab221ae517a9c01361e34
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/sparse.py
@@ -0,0 +1 @@
+# The Tensor classes are added to this module by python_tensor.cpp
diff --git a/MLPY/Lib/site-packages/torch/cuda/streams.py b/MLPY/Lib/site-packages/torch/cuda/streams.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ee98beab003d9f5bb7692db6efb969d49b26ea5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/cuda/streams.py
@@ -0,0 +1,241 @@
+import ctypes
+
+import torch
+from torch._streambase import _EventBase, _StreamBase
+from .._utils import _dummy_type
+
+
+if not hasattr(torch._C, "_CudaStreamBase"):
+    # Define dummy base classes
+    torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
+    torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
+
+
+class Stream(torch._C._CudaStreamBase, _StreamBase):
+    r"""Wrapper around a CUDA stream.
+
+    A CUDA stream is a linear sequence of execution that belongs to a specific
+    device, independent from other streams.  See :ref:`cuda-semantics` for
+    details.
+
+    Args:
+        device(torch.device or int, optional): a device on which to allocate
+            the stream. If :attr:`device` is ``None`` (default) or a negative
+            integer, this will use the current device.
+        priority(int, optional): priority of the stream, should be 0 or
+            negative, where negative numbers indicate higher priority. By default,
+            streams have priority 0.
+
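+    Example (editor's sketch; assumes at least one CUDA device)::
+
+        >>> # xdoctest: +SKIP
+        >>> s = torch.cuda.Stream()
+        >>> with torch.cuda.stream(s):
+        ...     y = torch.randn(8, device="cuda") * 2
+        >>> torch.cuda.current_stream().wait_stream(s)  # order later work after s
+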
+    """
+
+    def __new__(cls, device=None, priority=0, **kwargs):
+        # setting device manager is expensive, so we avoid it unless necessary
+        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
+            return super().__new__(cls, priority=priority, **kwargs)
+        else:
+            with torch.cuda.device(device):
+                return super().__new__(cls, priority=priority, **kwargs)
+
+    def wait_event(self, event):
+        r"""Make all future work submitted to the stream wait for an event.
+
+        Args:
+            event (torch.cuda.Event): an event to wait for.
+
+        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
+           `CUDA Stream documentation`_ for more info.
+
+           This function returns without waiting for :attr:`event`: only future
+           operations are affected.
+
+        .. _CUDA Stream documentation:
+           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
+        """
+        event.wait(self)
+
+    def wait_stream(self, stream):
+        r"""Synchronize with another stream.
+
+        All future work submitted to this stream will wait until all kernels
+        submitted to a given stream at the time of call complete.
+
+        Args:
+            stream (Stream): a stream to synchronize.
+
+        .. note:: This function returns without waiting for currently enqueued
+           kernels in :attr:`stream`: only future operations are affected.
+        """
+        self.wait_event(stream.record_event())
+
+    def record_event(self, event=None):
+        r"""Record an event.
+
+        Args:
+            event (torch.cuda.Event, optional): event to record. If not given, a new one
+                will be allocated.
+
+        Returns:
+            Recorded event.
+        """
+        if event is None:
+            event = Event()
+        event.record(self)
+        return event
+
+    def query(self):
+        r"""Check if all the work submitted has been completed.
+
+        Returns:
+            A boolean indicating if all kernels in this stream are completed.
+        """
+        return super().query()
+
+    def synchronize(self):
+        r"""Wait for all the kernels in this stream to complete.
+
+        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
+           `CUDA Stream documentation`_ for more info.
+        """
+        super().synchronize()
+
+    @property
+    def _as_parameter_(self):
+        return ctypes.c_void_p(self.cuda_stream)
+
+    def __eq__(self, o):
+        if isinstance(o, Stream):
+            return super().__eq__(o)
+        return False
+
+    def __hash__(self):
+        return hash((self.cuda_stream, self.device))
+
+    def __repr__(self):
+        return f""
+
+
+class ExternalStream(Stream):
+    r"""Wrapper around an externally allocated CUDA stream.
+
+    This class is used to wrap streams allocated in other libraries in order
+    to facilitate data exchange and multi-library interactions.
+
+    .. note:: This class doesn't manage the stream life-cycle; it is the user's
+       responsibility to keep the referenced stream alive while this class is
+       being used.
+
+    Args:
+        stream_ptr(int): Integer representation of the ``cudaStream_t`` value
+            allocated externally.
+        device(torch.device or int, optional): the device where the stream
+            was originally allocated. If device is specified incorrectly,
+            subsequent launches using this stream may fail.
+    """
+
+    def __new__(cls, stream_ptr, device=None, **kwargs):
+        with torch.cuda.device(device):
+            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
+
+
+class Event(torch._C._CudaEventBase, _EventBase):
+    r"""Wrapper around a CUDA event.
+
+    CUDA events are synchronization markers that can be used to monitor the
+    device's progress, to accurately measure timing, and to synchronize CUDA
+    streams.
+
+    The underlying CUDA events are lazily initialized when the event is first
+    recorded or exported to another process. After creation, only streams on the
+    same device may record the event. However, streams on any device can wait on
+    the event.
+
+    Args:
+        enable_timing (bool, optional): indicates if the event should measure time
+            (default: ``False``)
+        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
+        interprocess (bool): if ``True``, the event can be shared between processes
+            (default: ``False``)
+
+    .. _CUDA Event Documentation:
+       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
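+
+    Example (editor's sketch; ``elapsed_time`` requires ``enable_timing=True``)::
+
+        >>> # xdoctest: +SKIP
+        >>> start = torch.cuda.Event(enable_timing=True)
+        >>> end = torch.cuda.Event(enable_timing=True)
+        >>> start.record()
+        >>> torch.randn(1024, 1024, device="cuda").sum()
+        >>> end.record()
+        >>> torch.cuda.synchronize()
+        >>> elapsed_ms = start.elapsed_time(end)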
+    """
+
+    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
+        return super().__new__(
+            cls,
+            enable_timing=enable_timing,
+            blocking=blocking,
+            interprocess=interprocess,
+        )
+
+    @classmethod
+    def from_ipc_handle(cls, device, handle):
+        r"""Reconstruct an event from an IPC handle on the given device."""
+        return super().from_ipc_handle(device, handle)
+
+    def record(self, stream=None):
+        r"""Record the event in a given stream.
+
+        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
+        stream's device must match the event's device.
+        """
+        if stream is None:
+            stream = torch.cuda.current_stream()
+        super().record(stream)
+
+    def wait(self, stream=None):
+        r"""Make all future work submitted to the given stream wait for this event.
+
+        Uses ``torch.cuda.current_stream()`` if no stream is specified.
+
+        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
+            `CUDA Event documentation`_ for more info.
+        """
+        if stream is None:
+            stream = torch.cuda.current_stream()
+        super().wait(stream)
+
+    def query(self):
+        r"""Check if all work currently captured by event has completed.
+
+        Returns:
+            A boolean indicating if all work currently captured by event has
+            completed.
+        """
+        return super().query()
+
+    def elapsed_time(self, end_event):
+        r"""Return the time elapsed.
+
+        Time reported in milliseconds after the event was recorded and
+        before the end_event was recorded.
+        """
+        return super().elapsed_time(end_event)
+
+    def synchronize(self):
+        r"""Wait for the event to complete.
+
+        Waits until the completion of all work currently captured in this event.
+        This prevents the CPU thread from proceeding until the event completes.
+
+        .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
+            `CUDA Event documentation`_ for more info.
+        """
+        super().synchronize()
+
+    def ipc_handle(self):
+        r"""Return an IPC handle of this event.
+
+        If not recorded yet, the event will use the current device.
+        """
+        return super().ipc_handle()
+
+    @property
+    def _as_parameter_(self):
+        return ctypes.c_void_p(self.cuda_event)
+
+    def __repr__(self):
+        if self.cuda_event:
+            return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
+        else:
+            return "<torch.cuda.Event uninitialized>"
diff --git a/MLPY/Lib/site-packages/torch/distributed/__init__.py b/MLPY/Lib/site-packages/torch/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..851314d31f3ad16d7e3e617582f9e93fc981b47f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/__init__.py
@@ -0,0 +1,132 @@
+import os
+import sys
+from enum import Enum
+import pdb
+import io
+
+import torch
+
+def is_available() -> bool:
+    """
+    Return ``True`` if the distributed package is available.
+
+    Otherwise, ``torch.distributed`` does not expose any other APIs.
+    Currently, ``torch.distributed`` is available on Linux, MacOS and Windows.
+    Set ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
+    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
+    and ``USE_DISTRIBUTED=0`` for MacOS.
+    """
+    return hasattr(torch._C, "_c10d_init")
+
+
+if is_available() and not torch._C._c10d_init():
+    raise RuntimeError("Failed to initialize torch.distributed")
+
+# Custom Runtime Errors thrown from the distributed package
+DistError = torch._C._DistError
+DistBackendError = torch._C._DistBackendError
+DistNetworkError = torch._C._DistNetworkError
+DistStoreError = torch._C._DistStoreError
+
+if is_available():
+    from torch._C._distributed_c10d import (
+        Store,
+        FileStore,
+        TCPStore,
+        ProcessGroup as ProcessGroup,
+        Backend as _Backend,
+        PrefixStore,
+        Reducer,
+        Logger,
+        BuiltinCommHookType,
+        GradBucket,
+        Work as _Work,
+        _DEFAULT_FIRST_BUCKET_BYTES,
+        _register_comm_hook,
+        _register_builtin_comm_hook,
+        _broadcast_coalesced,
+        _compute_bucket_assignment_by_size,
+        _verify_params_across_processes,
+        _test_python_store,
+        DebugLevel,
+        get_debug_level,
+        set_debug_level,
+        set_debug_level_from_env,
+        _make_nccl_premul_sum,
+    )
+
+    class _DistributedPdb(pdb.Pdb):
+        """
+        Supports using PDB from inside a multiprocessing child process.
+
+        Usage:
+        _DistributedPdb().set_trace()
+        """
+        def interaction(self, *args, **kwargs):
+            _stdin = sys.stdin
+            try:
+                sys.stdin = open('/dev/stdin')
+                pdb.Pdb.interaction(self, *args, **kwargs)
+            finally:
+                sys.stdin = _stdin
+
+    def breakpoint(rank: int = 0):
+        """
+        Set a breakpoint, but only on a single rank.  All other ranks will wait for you to be
+        done with the breakpoint before continuing.
+
+        Args:
+            rank (int): Which rank to break on.  Default: ``0``
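+
+        Example (editor's sketch; run under ``torchrun`` after the process
+        group has been initialized)::
+
+            >>> # xdoctest: +SKIP
+            >>> import torch.distributed as dist
+            >>> dist.breakpoint(rank=0)  # rank 0 drops into pdb; other ranks wait at the barrier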
+        """
+        if get_rank() == rank:
+            pdb = _DistributedPdb()
+            pdb.message(
+                "\n!!! ATTENTION !!!\n\n"
+                f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
+            )
+            pdb.set_trace()
+        barrier()
+
+    if sys.platform != "win32":
+        from torch._C._distributed_c10d import (
+            HashStore,
+            _round_robin_process_groups,
+        )
+
+    from .distributed_c10d import *  # noqa: F403
+
+    # Variables prefixed with underscore are not auto imported
+    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
+    # this.
+
+    from .distributed_c10d import (
+        _all_gather_base,
+        _reduce_scatter_base,
+        _create_process_group_wrapper,
+        _rank_not_in_group,
+        _coalescing_manager,
+        _CoalescingManager,
+        _get_process_group_name,
+    )
+
+    from .rendezvous import (
+        rendezvous,
+        _create_store_from_options,
+        register_rendezvous_handler,
+    )
+
+    from .remote_device import _remote_device
+
+    set_debug_level_from_env()
+
+else:
+    # This stub is sufficient to get
+    #   python test/test_public_bindings.py -k test_correct_module_names
+    # working even when USE_DISTRIBUTED=0.  Feel free to add more
+    # stubs as necessary.
+    # We cannot define stubs directly because they confuse pyre
+
+    class _ProcessGroupStub:
+        pass
+    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab1ddd5cedae6a0898fa9dc8ee39dadce4b8f8d5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e77e20e45e637eb46d2d1ac07ff7d6dc674c746
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..007cff9806eb5843947378fdb49d79f1985edf96
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29c917f07a9c66243005ab70bd0bc0f76c7583aa
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42b9a0e5d38593433a5baaf505473a56e6ccf50e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6da77a3da7a953f3d1463a03cc70ebadd550950a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fe57d80b08716e41f1e930f2c4ebb40b669e8bd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4211a6a598096e1b2f41cb1443ef399ce8b6f32
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4805ea9eda967be2acce919b5c52f7d399613ca5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b381162f583c9911cba4f5d038d4ee38f8f0c4fd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/distributed_c10d.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/distributed_c10d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33ff4f0d1389642dbd440c3c371be113838a0dd4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/distributed_c10d.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edd29282eb7c6b3c2f2ded3f4ac5ab443a322e16
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d8bdf18f81fb8729311eb3e20dd9589c7a6694f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35a3fc1468f72b8e19911a0a2b29aa9aa9cbe505
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40338e0ec17d6cd21ecd6e2fea4ad237dff30981
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bf3a4d780165ac6fcf14c1396f745661e110c09
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d17b5af17a584045b80a54605896bafbdc49475
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_composable/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f347203e6c3c9e4aa92a4ea796d35349c0a548f3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/__init__.py
@@ -0,0 +1,4 @@
+from .checkpoint_activation import checkpoint
+from .contract import _get_registry, contract
+from .fully_shard import fully_shard
+from .replicate import replicate
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc39e11aa4a38567d710273c8def068435547b2f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..526120cbaa99e26966e00fbdf80951e4b9725713
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc761259a0c303e6e0d347d1f7a88f6dc56e529b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/fully_shard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/fully_shard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2869920de64190b8bfedf87864aad54aea6f5e0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/fully_shard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85a2d0da1c626c3359c67af35a9580358f39c326
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py b/MLPY/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e398d87eb0af4716fce33d122f722e20b1f152
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py
@@ -0,0 +1,94 @@
+from contextlib import contextmanager, nullcontext
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import (
+    _checkpoint_without_reentrant_generator,
+    _DEFAULT_DETERMINISM_MODE,
+)
+
+from .contract import contract
+
+
+@contextmanager
+def _no_hook(module: nn.Module):
+    r"""
+    Disable hooks installed by checkpoint to avoid unintentional recursion
+    during backward recomputation.
+    """
+    orig_enable_hook = checkpoint.state(module).enable_hook
+    checkpoint.state(module).enable_hook = False
+    try:
+        yield
+    finally:
+        checkpoint.state(module).enable_hook = orig_enable_hook
+
+
+@contract()
+def checkpoint(module: nn.Module) -> nn.Module:
+    r"""
+    This is a composable activation checkpointing API. Unlike functional
+    activation checkpointing APIs, this one does not require changing model
+    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
+    this one does not modify model structure or fully-qualified names either.
+    Under the hood, it registers activation checkpointing logic as pre- and
+    post-forward hooks. Hence, this API can be easily applied to any model or
+    sub-modules in the model.
+
+    Args:
+        module (nn.Module): the target model or sub-module to apply activation
+            checkpointing.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> import torch.nn as nn
+        >>>
+        >>> class MyModel(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.l1 = nn.Linear(10, 10)
+        >>>         self.l2 = nn.Linear(10, 10)
+        >>>
+        >>>     def forward(self, x):
+        >>>         return self.l2(self.l1(x))
+        >>>
+        >>> model = MyModel()
+        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
+        >>> model(torch.zeros(2, 10)).sum().backward()
+
+    """
+    torch._C._log_api_usage_once("torch.distributed.checkpoint")
+
+    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
+        if checkpoint.state(module).enable_hook:
+
+            def context_fns():
+                return nullcontext(), _no_hook(module)
+
+            checkpoint.state(
+                module
+            )._ac_generator = _checkpoint_without_reentrant_generator(
+                module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *inputs
+            )
+            next(checkpoint.state(module)._ac_generator)
+
+    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
+        if checkpoint.state(module).enable_hook:
+            try:
+                next(checkpoint.state(module)._ac_generator)
+            except StopIteration:
+                pass
+            else:
+                raise RuntimeError(
+                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
+                )
+
+        #  Ensure that we no longer hold on to the generator. always_call=True helps ensure we
+        # clear this even in the case of exception in fwd pass.
+        checkpoint.state(module)._ac_generator = None
+
+    checkpoint.state(module).enable_hook = True
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
+    return module
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/contract.py b/MLPY/Lib/site-packages/torch/distributed/_composable/contract.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b54b2f55970a9638803b6b32d03212bac26271
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/contract.py
@@ -0,0 +1,194 @@
+import uuid
+from collections import OrderedDict
+from functools import wraps
+from typing import Callable, Dict, List, Optional, Type
+
+import torch.nn as nn
+from torch.distributed._composable_state import _State
+
+
+def generate_state_key(string="__composable_api_state_key"):
+    return f"{string}_{str(uuid.uuid4())}"
+
+
+STATE_KEY = generate_state_key()
+REGISTRY_KEY = generate_state_key()
+
+
+# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
+# we can add args and kwargs here, and then we can detect whether fully_shard
+# is combined with reentrant activation checkpointing and error out with a clear
+# message.
+class RegistryItem:
+    pass
+
+
+def contract(state_cls: Type[_State] = _State):
+    r"""
+    Decorate a function as a composable distributed API, where the first
+    argument of the function must be an :class:`nn.Module` instance. The
+    decorator verifies that the wrapped function does not modify parameter,
+    buffer or sub-module fully-qualified names (FQN).
+
+    When a function ``func`` is decorated by ``@contract()``, a
+    ``.state(module: nn.Module)`` method will be installed to the decorated
+    function. Then you can retrieve and modify the state on a module by calling
+    ``func.state(module)``.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> import torch.nn as nn
+        >>>
+        >>> class MyModel(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.l1 = nn.Linear(10, 10)
+        >>>         self.l2 = nn.Linear(10, 10)
+        >>>
+        >>>     def forward(self, x):
+        >>>         return self.l2(self.l1(x))
+        >>>
+        >>> @contract()
+        >>> def my_feature(module: nn.Module) -> nn.Module:
+        >>>     my_feature.state(module).some_state = "any value"
+        >>>     return module
+        >>>
+        >>> model = MyModel()
+        >>> my_feature(model.l1)
+        >>> assert my_feature.state(model.l1).some_state == "any value"
+        >>> my_feature(model.l2)
+        >>> model(torch.randn(2, 10)).sum().backward()
+    """
+
+    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
+    @wraps(state_cls)
+    def inner(func):
+        @wraps(func)
+        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
+            # get existing global states
+            default_all_state: Dict[Callable, _State] = OrderedDict()
+            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
+                STATE_KEY, default_all_state
+            )
+            assert isinstance(
+                all_state, dict
+            ), "Distributed composable API states corrupted"
+
+            # get global registry
+            default_registry: Dict[str, RegistryItem] = OrderedDict()
+            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
+                REGISTRY_KEY, default_registry
+            )
+
+            assert isinstance(
+                registry, dict
+            ), "Distributed composable API registry corrupted"
+
+            # make sure the API func has not been applied to the input module yet.
+            assert func not in all_state and func.__name__ not in registry, (
+                "Each distinct composable distributed API can only be applied to a "
+                f"module once. {func.__name__} has already been applied to the "
+                f"following module.\n{module}"
+            )
+
+            # install states specific to the wrapped ``func``
+            all_state.setdefault(func, state_cls())
+            # register ``func`` in the global registry by name
+            registry.setdefault(func.__name__, RegistryItem())
+
+            orig_named_params = OrderedDict(module.named_parameters())
+            orig_named_buffers = OrderedDict(
+                module.named_buffers(remove_duplicate=False)
+            )
+            orig_named_modules = OrderedDict(
+                module.named_modules(remove_duplicate=False)
+            )
+
+            updated = func(module, *args, **kwargs)
+
+            if updated is None:
+                updated = module
+
+            new_named_params = OrderedDict(updated.named_parameters())
+            new_named_buffers = OrderedDict(
+                updated.named_buffers(remove_duplicate=False)
+            )
+            new_named_modules = OrderedDict(
+                updated.named_modules(remove_duplicate=False)
+            )
+
+            assert isinstance(updated, nn.Module), (
+                "Output of composable distributed APIs must be either None or "
+                f"nn.Module, but got {type(updated)}"
+            )
+
+            def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
+                if orig_fqns == new_fqns:
+                    return
+
+                orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
+                orig_only = orig_fqn_set - new_fqn_set
+                new_only = new_fqn_set - orig_fqn_set
+                if len(orig_only) or len(new_only):
+                    raise RuntimeError(
+                        f"{check_key}"
+                        "Composable distributed API implementations cannot modify "
+                        "FQNs.\n"
+                        f"Only in original FQNs: {orig_only},\n"
+                        f"Only in new FQNs: {new_only}"
+                    )
+                else:
+                    raise RuntimeError(
+                        f"{check_key}"
+                        "Composable distributed API implementations cannot modify "
+                        "the order of FQNs.\n"
+                        f"Original FQNs: {orig_fqns}\n"
+                        f"New FQNs: {new_fqns}"
+                    )
+
+            check_fqn(
+                list(orig_named_params.keys()),
+                list(new_named_params.keys()),
+                "Check parameters, ",
+            )
+            check_fqn(
+                list(orig_named_buffers.keys()),
+                list(new_named_buffers.keys()),
+                "Check buffer, ",
+            )
+            check_fqn(
+                list(orig_named_modules.keys()),
+                list(new_named_modules.keys()),
+                "Check modules, ",
+            )
+
+            # TODO: a stricter verification should also reject changing module
+            # types and monkey-patching forward() method implementations.
+
+            # TODO: verify that installed distributed paradigms are compatible with
+            # each other.
+
+            return updated
+
+        def get_state(module: nn.Module) -> Optional[_State]:
+            return module.__dict__.setdefault(  # type: ignore[call-overload]
+                STATE_KEY,
+                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
+            ).get(
+                func
+            )  # type: ignore[call-overload]
+
+        wrapper.state = get_state  # type: ignore[attr-defined]
+
+        return wrapper
+
+    return inner
+
+
+def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
+    r"""
+    Get an ``OrderedDict`` of composable APIs that have been applied to the
+    ``module``, indexed by the API name. If no API has been applied, then this
+    returns ``None``.
+    """
+    return getattr(module, REGISTRY_KEY, None)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81a56b625cd924f8d314a020caef977591546d4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py
@@ -0,0 +1,2 @@
+from ._fsdp_api import MixedPrecisionPolicy
+from .fully_shard import FSDP, fully_shard
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43ce0c9e6237ab960bffbfdea099f320e062db24
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6440f0f061625c891ae0fe9705fdb58ad824a4bf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_collectives.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_collectives.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e80f3b71fe2fe0ef4ac84a2d311c705f92d9ec6d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_collectives.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07d72953511b917fc74f74daac3f274f2df30be0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_init.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_init.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..409ab5af3bf15462e45dfd46d7fe7c567f5a2d4a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_init.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecf73d8a7e0deb95db984d865ac66c9f96c176a2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param_group.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param_group.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e26b3aecf74603ecabf1a0605dd6c8e60193752
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_param_group.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_state.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_state.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d99be15325fa4b7e1ff788f5bd32d60eb0bd4bcf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/_fsdp_state.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea8c9d31a0eda95195557daff3fbaa98fc459a94
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_api.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d77e93513228c74f2defba8270e1e40354cdf8c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_api.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass(frozen=True)
+class MixedPrecisionPolicy:
+    """
+    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
+    precision at the module level, not op level, which means low-precision
+    activations are saved for backward and high-to-low-precision casts are
+    incurred only at module boundaries.
+
+    FSDP works well with module-level mixed precision since it keeps the
+    high-precision sharded parameters in memory anyway. In other words, FSDP
+    does not require any extra memory to keep a high-precision copy of the
+    parameters for the optimizer step.
+
+    Attributes:
+        param_dtype (Optional[torch.dtype]): This specifies the dtype for
+            the unsharded parameter and hence the dtype for forward/backward
+            computation and the parameter all-gather. If this is ``None``, then
+            the unsharded parameter uses the original dtype. The optimizer step
+            uses the sharded parameter in the original dtype. (Default:
+            ``None``)
+        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
+            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
+            ``None`` but ``param_dtype`` is not ``None``, then the reduction
+            uses the compute dtype. This can be used to run gradient reduction
+            in full precision while using low precision for compute. (Default:
+            ``None``)
+        output_dtype (Optional[torch.dtype]): This specifies the dtype for
+            casting floating-point forward outputs. This can be used to
+            help implement cases where different modules have different mixed
+            precision policies. (Default: ``None``)
+        cast_forward_inputs (bool): This specifies whether FSDP should cast the
+            forward's floating-point input tensors to ``param_dtype`` or not.
+            (Default: ``True``)
+    """
+
+    param_dtype: Optional[torch.dtype] = None
+    reduce_dtype: Optional[torch.dtype] = None
+    output_dtype: Optional[torch.dtype] = None
+    cast_forward_inputs: bool = True
+
+    def __post_init__(self):
+        # Clamp `reduce_dtype` to `None` if no casting is required: since
+        # gradients are computed in `param_dtype`, if `reduce_dtype` matches,
+        # then we do not need extra casting
+        if self.param_dtype == self.reduce_dtype:
+            # Bypass the frozen dataclass checks
+            object.__setattr__(self, "reduce_dtype", None)
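
Since `MixedPrecisionPolicy` is a frozen dataclass holding only configuration, the clamping performed in `__post_init__` can be observed without any distributed setup. A small sketch, assuming only that this vendored build exposes the import path shown:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Compute (all-gather, forward/backward) in bf16 but reduce gradients in fp32.
mixed = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
assert mixed.reduce_dtype is torch.float32

# When reduce_dtype equals param_dtype, __post_init__ clamps it to None:
# gradients are already computed in param_dtype, so no extra cast is needed.
uniform = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
assert uniform.reduce_dtype is None
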
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1859f42e2a1a8ff7f19d4c76d57b2400f1c008
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -0,0 +1,217 @@
+from typing import List, NamedTuple, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import ReduceOp
+from ._fsdp_common import (
+    _get_dim0_padded_size,
+    _raise_assert_with_print,
+    _to_dtype_if_needed,
+)
+from ._fsdp_param import FSDPParam
+
+
+class AllGatherResult(NamedTuple):
+    all_gather_output: torch.Tensor
+    all_gather_event: Optional[torch.cuda.Event]
+    all_gather_work: Optional[dist.distributed_c10d.Work]
+    all_gather_input_numels: List[int]
+
+
+@torch.no_grad()
+def foreach_all_gather(
+    fsdp_params: List[FSDPParam],
+    group: dist.ProcessGroup,
+    async_op: bool,
+    all_gather_copy_in_stream: torch.cuda.Stream,
+    all_gather_stream: torch.cuda.Stream,
+    device: torch.device,
+) -> Optional[AllGatherResult]:
+    world_size, rank = group.size(), group.rank()
+    # - Copy in
+    with torch.cuda.stream(all_gather_copy_in_stream):
+        param_all_gather_inputs = [
+            fsdp_param.all_gather_input for fsdp_param in fsdp_params
+        ]
+        dtype = param_all_gather_inputs[0].dtype
+        if not all(t.dtype == dtype for t in param_all_gather_inputs):
+            raise NotImplementedError(
+                f"Mixed dtype not supported yet: {[t.dtype for t in param_all_gather_inputs]}"
+            )
+        inp_split_sizes = [inp.numel() for inp in param_all_gather_inputs]
+        all_gather_input_numel = sum(inp_split_sizes)
+        all_gather_output = torch.empty(
+            (all_gather_input_numel * world_size,), dtype=dtype, device=device
+        )
+        all_gather_input = all_gather_output.narrow(
+            0, all_gather_input_numel * rank, all_gather_input_numel
+        )
+        foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
+        torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
+        del param_all_gather_inputs
+    all_gather_stream.wait_stream(all_gather_copy_in_stream)
+    with torch.cuda.stream(all_gather_stream):
+        # - All-gather
+        all_gather_work = dist.all_gather_into_tensor(
+            output_tensor=all_gather_output,
+            input_tensor=all_gather_input,
+            group=group,
+            async_op=async_op,
+        )
+        all_gather_event = all_gather_stream.record_event()
+        return AllGatherResult(
+            all_gather_output, all_gather_event, all_gather_work, inp_split_sizes
+        )
+
+
+@torch.no_grad()
+def foreach_all_gather_copy_out(
+    all_gather_result: AllGatherResult,
+    fsdp_params: List[FSDPParam],
+    group: dist.ProcessGroup,
+) -> None:
+    (
+        all_gather_output,
+        all_gather_event,
+        all_gather_work,
+        all_gather_input_numels,
+    ) = all_gather_result
+    if all_gather_event is not None:  # sync op
+        torch.cuda.current_stream().wait_event(all_gather_event)
+    if all_gather_work is not None:  # async op
+        all_gather_work.wait()
+    world_size = group.size()
+    dtype, device = all_gather_output.dtype, all_gather_output.device
+    for all_gather_input_numel, fsdp_param in zip(all_gather_input_numels, fsdp_params):
+        fsdp_param.init_all_gather_output(
+            all_gather_input_numel, world_size, dtype, device
+        )  # no-op after 1st call
+        fsdp_param.alloc_all_gather_output()
+    all_gather_output = all_gather_output.view(world_size, -1)
+    out = [
+        fsdp_param.all_gather_output.view(world_size, -1) for fsdp_param in fsdp_params
+    ]
+    torch.split_with_sizes_copy(
+        all_gather_output, all_gather_input_numels, dim=1, out=out
+    )
+
+
+@torch.no_grad()
+def foreach_reduce_scatter(
+    fsdp_params: List[FSDPParam],
+    unsharded_grads: List[torch.Tensor],
+    group: dist.ProcessGroup,
+    reduce_scatter_stream: torch.cuda.Stream,
+    orig_dtype: torch.dtype,
+    reduce_dtype: Optional[torch.dtype],
+    device: torch.device,
+    divide_factors: Optional[Tuple[float, float]],
+) -> torch.cuda.Event:
+    """
+    ``unsharded_grads`` owns the references to the gradients computed by
+    autograd, so clearing the list frees the gradients.
+    """
+    grad_dtypes = {grad.dtype for grad in unsharded_grads}
+    if len(grad_dtypes) != 1:
+        # Check this at runtime since it could be a real runtime error if e.g.
+        # fp8 weights do not produce the correct higher precision gradients
+        _raise_assert_with_print(
+            f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}"
+        )
+    grad_dtype = unsharded_grads[0].dtype
+    reduce_dtype = reduce_dtype or grad_dtype
+    world_size = group.size()
+    padded_unsharded_sizes = tuple(
+        _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
+    )
+    reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
+    reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
+    current_stream = torch.cuda.current_stream()
+    reduce_scatter_stream.wait_stream(current_stream)
+    with torch.cuda.stream(reduce_scatter_stream):
+        reduce_scatter_input = torch.empty(
+            (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
+        )
+        foreach_reduce_scatter_copy_in(
+            unsharded_grads, reduce_scatter_input, world_size
+        )
+        # Only after the copy-in finishes can we free the gradients, which were
+        # computed in the default stream
+        current_stream.wait_stream(reduce_scatter_stream)
+        unsharded_grads.clear()
+        reduce_scatter_output = reduce_scatter_input.new_empty(
+            (reduce_scatter_output_numel,)
+        )
+        _reduce_scatter(
+            reduce_scatter_output, reduce_scatter_input, group, divide_factors
+        )
+        reduce_scatter_output = _to_dtype_if_needed(reduce_scatter_output, orig_dtype)
+        # - View out and accumulate
+        flat_grad_offset = 0  # [0, reduce_scatter_output_numel - 1]
+        for padded_unsharded_size, fsdp_param in zip(
+            padded_unsharded_sizes, fsdp_params
+        ):
+            new_sharded_grad = torch.as_strided(
+                reduce_scatter_output,
+                size=fsdp_param.sharded_size,
+                stride=fsdp_param.contiguous_sharded_stride,
+                storage_offset=flat_grad_offset,
+            )
+            to_accumulate_grad = fsdp_param.sharded_param.grad is not None
+            new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(new_sharded_grad)
+            if to_accumulate_grad:
+                fsdp_param.sharded_param.grad += new_sharded_dtensor_grad
+            else:
+                fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
+            padded_sharded_numel = padded_unsharded_size.numel() // world_size
+            flat_grad_offset += padded_sharded_numel
+        reduce_scatter_view_out_event = reduce_scatter_stream.record_event()
+    # The reduce-scatter output is allocated in the reduce-scatter stream and
+    # used in the default stream (for the optimizer step). No extra
+    # synchronization is needed to keep its memory from being reused by later
+    # reduce-scatters, since the sharded parameters hold references to it
+    # through the end of backward.
+    return reduce_scatter_view_out_event
+
+
+def foreach_reduce_scatter_copy_in(
+    unsharded_grads: List[torch.Tensor],
+    reduce_scatter_input: torch.Tensor,
+    world_size: int,
+) -> None:
+    grad_views: List[torch.Tensor] = []
+    grads_to_copy: List[torch.Tensor] = []
+    padded_grad_slices: List[torch.Tensor] = []
+    for grad in unsharded_grads:
+        grad_size = grad.size()
+        dim0_padded_size = _get_dim0_padded_size(grad_size, world_size)
+        if dim0_padded_size != grad_size:
+            padded_grad = grad.new_empty(dim0_padded_size)
+            padded_grad_slices.append(padded_grad[: grad.size(0)])
+            grads_to_copy.append(grad)
+            grad = padded_grad
+        grad_views.append(grad.view(world_size, -1))
+    if padded_grad_slices:
+        torch._foreach_copy_(padded_grad_slices, grads_to_copy)
+    torch.cat(grad_views, dim=-1, out=reduce_scatter_input.view(world_size, -1))
+
+
+def _reduce_scatter(
+    output: torch.Tensor,
+    input: torch.Tensor,
+    group: dist.ProcessGroup,
+    divide_factors: Optional[Tuple[float, float]],
+) -> None:
+    if divide_factors:
+        predivide_factor, postdivide_factor = divide_factors
+        _div_if_needed(input, predivide_factor)
+        dist.reduce_scatter_tensor(output, input, group=group)
+        _div_if_needed(output, postdivide_factor)
+    else:
+        # Using NCCL's reduce-scatter to do the division by world size saves
+        # extra memory read/write from a separate division kernel
+        dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
+
+
+def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
+    if div_factor > 1:
+        tensor.div_(div_factor)
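
The collectives above rely on one flat buffer per all-gather: on copy-in each rank packs its per-parameter shards into its contiguous slice, and on copy-out the `(world_size, -1)` view is split back into per-parameter outputs. The sketch below reproduces just that layout on a single process with plain CPU tensors; no process group or real all-gather is involved, so the slice for the other rank stays zero.

import torch

world_size = 2
rank = 0  # pretend to be rank 0 of the hypothetical group

# Two parameters' shards on this rank, already flattened to 1D.
shards = [torch.arange(6.0), torch.arange(4.0)]
split_sizes = [s.numel() for s in shards]
numel = sum(split_sizes)

# Copy-in: pack this rank's shards into its slice of one flat output buffer.
all_gather_output = torch.zeros(numel * world_size)
my_slice = all_gather_output.narrow(0, numel * rank, numel)
torch._foreach_copy_(list(torch.split(my_slice, split_sizes)), shards)

# (A real run would now call dist.all_gather_into_tensor on the flat buffer.)

# Copy-out: view as (world_size, per_rank_numel) and split per parameter.
out = [torch.empty(world_size, n) for n in split_sizes]
torch.split_with_sizes_copy(
    all_gather_output.view(world_size, -1), split_sizes, dim=1, out=out
)
assert out[0][rank].equal(shards[0]) and out[1][rank].equal(shards[1])
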
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_common.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aca875bedd6d154a18fc52db30861861566c025
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -0,0 +1,151 @@
+import math
+import traceback
+
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Any, cast, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed._composable.contract import _get_registry
+from torch.distributed._tensor import DeviceMesh, DTensor, Placement
+
+
+@dataclass
+class DataParallelMeshInfo:
+    mesh: DeviceMesh
+    shard_mesh_dim: Optional[int] = None
+    replicate_mesh_dim: Optional[int] = None
+
+    def __post_init__(self):
+        if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
+            raise AssertionError(
+                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
+            )
+
+
+@dataclass
+class FSDPMeshInfo(DataParallelMeshInfo):
+    def __post_init__(self):
+        super().__post_init__()
+        if self.shard_mesh_dim is None:
+            raise AssertionError("Expects non-None shard_mesh_dim")
+        self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim)
+        self.shard_process_group = cast(
+            dist.ProcessGroup, self.mesh.get_group(self.shard_mesh_dim)
+        )
+        self.shard_mesh_rank: int = self.shard_process_group.rank()
+
+
+@dataclass
+class DDPMeshInfo(DataParallelMeshInfo):
+    def __post_init__(self):
+        super().__post_init__()
+        if self.replicate_mesh_dim is None:
+            raise AssertionError("Expects non-None replicate_mesh_dim")
+        self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim)
+        self.replicate_process_group = cast(
+            dist.ProcessGroup, self.mesh.get_group(self.replicate_mesh_dim)
+        )
+        self.replicate_mesh_rank: int = self.replicate_process_group.rank()
+
+
+@dataclass
+class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
+    def __post_init__(self):
+        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
+        super().__post_init__()
+
+
+class TrainingState(Enum):
+    """Describes the training state of one FSDP state / parameter group."""
+
+    # Transition to forward starting pre-forward until post-forward
+    FORWARD = auto()
+    # Transition to pre-backward when unsharding in backward
+    PRE_BACKWARD = auto()
+    # Transition to post-backward when resharding and reducing gradients
+    POST_BACKWARD = auto()
+    # Idle before/after forward or before pre-backward/after post-backward
+    IDLE = auto()
+
+
+def _raise_assert_with_print(*args: Any, **kwargs: Any):
+    print(f"[Rank {dist.get_rank()}] ", end="")
+    print(*args, **kwargs)
+    traceback.print_stack()
+    raise AssertionError(*args, **kwargs)
+
+
+def _is_composable_with_fsdp(module: nn.Module) -> bool:
+    registry = _get_registry(module)
+    if registry is None:
+        return True
+    # The registry is keyed by the composable API's function name (e.g. "replicate")
+    return "replicate" not in registry
+
+
+def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
+    padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
+    return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:])
+
+
+def _chunk_with_empty(
+    tensor: torch.Tensor, num_chunks: int, dim: int
+) -> List[torch.Tensor]:
+    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
+    while len(chunks) < num_chunks:
+        chunks.append(chunks[0].new_empty(0))
+    return chunks
+
+
+def _get_dim0_chunked_size(
+    chunk: torch.Tensor, unchunked_size: torch.Size
+) -> torch.Size:
+    if chunk.numel() > 0:
+        return chunk.size()
+    # For 0 numel, we need to preserve trailing dims for DTensor APIs
+    return cast(torch.Size, torch.Size([0]) + unchunked_size[1:])
+
+
+def _from_local_no_grad(
+    local_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Tuple[Placement, ...],
+    global_size: torch.Size,
+    global_stride: Tuple[int, ...],
+) -> DTensor:
+    """
+    This method is similar to ``DTensor.from_local()`` except it avoids some
+    CPU overhead by avoiding default args and not being differentiable.
+    """
+    return DTensor(
+        # Use the local tensor directly instead of constructing a new tensor
+        # variable, e.g. with `view_as()`, since this is not differentiable
+        local_tensor,
+        device_mesh,
+        placements,
+        shape=global_size,
+        dtype=local_tensor.dtype,
+        requires_grad=local_tensor.requires_grad,
+        stride=global_stride,
+    )
+
+
+def _to_dtype_if_needed(
+    tensor: torch.Tensor, dtype: Optional[torch.dtype]
+) -> torch.Tensor:
+    if dtype is not None and tensor.dtype != dtype:
+        return tensor.to(dtype)
+    return tensor
+
+
+def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
+    if (
+        not isinstance(x, torch.Tensor)
+        or not torch.is_floating_point(x)
+        or x.dtype == dtype
+    ):
+        return x
+    return x.to(dtype)
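
The dim-0 padding and chunking helpers above are pure tensor-shape logic, so they can be exercised without a process group. A short sketch, assuming this vendored build is importable under the path shown:

import torch
from torch.distributed._composable.fsdp._fsdp_common import (
    _chunk_with_empty,
    _get_dim0_chunked_size,
    _get_dim0_padded_size,
)

world_size = 4
param = torch.randn(5, 3)  # dim-0 not divisible by the shard world size

# Padded size rounds dim-0 up to a multiple of the shard world size.
assert _get_dim0_padded_size(param.size(), world_size) == torch.Size([8, 3])

# Chunking pads with empty chunks so every rank gets exactly one chunk.
chunks = _chunk_with_empty(param, world_size, dim=0)
assert [c.size(0) for c in chunks] == [2, 2, 1, 0]

# The 0-numel chunk still reports trailing dims for DTensor bookkeeping.
assert _get_dim0_chunked_size(chunks[3], param.size()) == torch.Size([0, 3])
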
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_init.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b6453309cd139c262bb28dba8ddaf6599b66fe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_init.py
@@ -0,0 +1,144 @@
+import itertools
+from typing import List, Optional, Set, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
+from torch.distributed.device_mesh import _get_device_handle
+from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
+from ._fsdp_state import _get_module_fsdp_state
+
+
+def _get_post_forward_mesh_info(
+    reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
+) -> Optional[FSDPMeshInfo]:
+    shard_mesh_size = mesh_info.shard_mesh_size
+    if not isinstance(reshard_after_forward, (bool, int)):
+        raise ValueError(
+            "reshard_after_forward should be a bool or an int representing the "
+            f"group size to reshard to, not {reshard_after_forward}"
+        )
+    # NOTE: `isinstance(False, int)` returns `True`.
+    if not isinstance(reshard_after_forward, bool) and isinstance(
+        reshard_after_forward, int
+    ):
+        if (
+            reshard_after_forward < 1
+            or reshard_after_forward > shard_mesh_size
+            or shard_mesh_size % reshard_after_forward != 0
+        ):
+            raise ValueError(
+                "If passing reshard_after_forward as an int, it should be a "
+                f"factor of {shard_mesh_size}, not {reshard_after_forward}"
+            )
+        elif reshard_after_forward == 1:
+            reshard_after_forward = False
+        elif reshard_after_forward == shard_mesh_size:
+            reshard_after_forward = True
+    post_forward_mesh_info = None
+    if reshard_after_forward is True:
+        post_forward_mesh_info = mesh_info
+    elif reshard_after_forward is not False:  # int case
+        # For HSDP, we can flatten the two replicate dims into the 0th dim
+        post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
+        post_forward_mesh = DeviceMesh(
+            mesh_info.mesh.device_type, post_forward_mesh_tensor
+        )
+        post_forward_mesh_info = HSDPMeshInfo(
+            post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
+        )
+    return post_forward_mesh_info
+
+
+def _init_default_fully_shard_mesh() -> DeviceMesh:
+    """Default to global CUDA mesh if possible else global CPU mesh."""
+    if not dist.distributed_c10d.is_initialized():
+        dist.distributed_c10d.init_process_group()
+    default_pg = dist.distributed_c10d._get_default_group()
+    device_type = "cuda" if torch.cuda.is_available() else "cpu"
+    mesh = init_device_mesh(device_type, mesh_shape=(default_pg.size(),))
+    return mesh
+
+
+def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
+    if mesh.device_type == "cpu":
+        return torch.device("cpu")
+    device_handle = _get_device_handle(mesh.device_type)
+    return torch.device(mesh.device_type, device_handle.current_device())
+
+
+def _get_managed_modules(root_module: nn.Module) -> List[nn.Module]:
+    modules: List[nn.Module] = []
+    # Track visited modules to avoid visiting shared modules multiple times
+    visited_modules: Set[nn.Module] = set()
+
+    def dfs(module: nn.Module) -> None:
+        """
+        Runs a DFS to collect managed modules, not recursing into modules with
+        a non-composable API or ``fully_shard`` already applied.
+        """
+        if not _is_composable_with_fsdp(module):
+            return
+        elif module is not root_module and _get_module_fsdp_state(module) is not None:
+            return  # nested `fully_shard` module
+        visited_modules.add(module)
+        for submodule in module.children():
+            if submodule not in visited_modules:
+                dfs(submodule)
+        modules.append(module)
+
+    dfs(root_module)
+    return modules
+
+
+def _get_managed_states(
+    modules: List[nn.Module],
+) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
+    params: List[nn.Parameter] = []
+    buffers: List[torch.Tensor] = []
+    # Track visited parameters/buffers to avoid visiting shared parameters and
+    # buffers multiple times
+    visited_params: Set[nn.Parameter] = set()
+    visited_buffers: Set[torch.Tensor] = set()
+    for module in modules:
+        for param in module.parameters(recurse=False):
+            if param not in visited_params:
+                params.append(param)
+                visited_params.add(param)
+        for buffer in module.buffers(recurse=False):
+            if buffer not in visited_buffers:
+                buffers.append(buffer)
+                visited_buffers.add(buffer)
+    return params, buffers
+
+
+def _move_states_to_device(
+    params: List[nn.Parameter],
+    buffers: List[torch.Tensor],
+    device: torch.device,
+    mesh_info: FSDPMeshInfo,
+) -> None:
+    """
+    We have FSDP move states to device for simpler and faster initialization
+    since FSDP almost always uses CUDA for training. We move parameters/buffers
+    rather than modules to support ignoring parameters/buffers in the future.
+    """
+    # TODO: De-duplicate with `_apply` after `swap_tensors` path lands:
+    # https://github.com/pytorch/pytorch/issues/115792
+    for tensor in itertools.chain(params, buffers):
+        if tensor.device == device or tensor.device.type == "meta":
+            # Keep meta-device tensors on meta device for deferred init
+            continue
+        if isinstance(tensor, DTensor):
+            if (dtensor_mesh_type := tensor._spec.mesh.device_type) != device.type:
+                raise ValueError(
+                    "Requires DTensor to have mesh of the same type as the FSDP mesh "
+                    f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
+                )
+            raise AssertionError(
+                f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
+            )
+        tensor.data = tensor.to(device)
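
`_get_post_forward_mesh_info` above folds `reshard_after_forward` into three cases: `1` behaves like `False` (keep the parameters unsharded after forward), the shard mesh size behaves like `True` (reshard back to the full shard mesh), and any other factor reshards to a smaller mesh of that size. The helper below is a hypothetical, pure-Python mirror of just that normalization, for illustration only; it is not part of this package.

from typing import Union


def normalize_reshard_after_forward(
    value: Union[bool, int], shard_mesh_size: int
) -> Union[bool, int]:
    if isinstance(value, bool):
        return value
    if value < 1 or value > shard_mesh_size or shard_mesh_size % value != 0:
        raise ValueError(f"{value} is not a factor of {shard_mesh_size}")
    if value == 1:
        return False  # keep parameters unsharded after forward
    if value == shard_mesh_size:
        return True  # reshard back to the full shard mesh
    return value  # reshard to a smaller mesh of this size


assert normalize_reshard_after_forward(1, 8) is False
assert normalize_reshard_after_forward(8, 8) is True
assert normalize_reshard_after_forward(4, 8) == 4
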
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..275e2f4f980a54e9f7d9f71243e5ba3b3c8f9a67
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -0,0 +1,438 @@
+from dataclasses import dataclass, field
+from enum import auto, Enum
+from typing import cast, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from torch._prims_common import make_contiguous_strides_for
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed._tensor import DTensor, Placement, Replicate, Shard
+from torch.distributed._tensor.device_mesh import _mesh_resources
+from torch.distributed._tensor.placement_types import DTensorSpec
+from ._fsdp_api import MixedPrecisionPolicy
+from ._fsdp_common import (
+    _chunk_with_empty,
+    _from_local_no_grad,
+    _get_dim0_chunked_size,
+    _raise_assert_with_print,
+    _to_dtype_if_needed,
+    FSDPMeshInfo,
+    HSDPMeshInfo,
+)
+
+"""
+[Note: FSDP tensors]
+FSDP considers the following tensors:
+- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
+  on the module when applying FSDP
+- Sharded parameter: sharding the original parameter on dim-0 as a DTensor
+  over the main mesh
+- All-gather input: the ``torch.Tensor`` passed to all-gather, derived from the
+  sharded parameter
+- All-gather output: the ``torch.Tensor`` resulting from all-gathering the
+  all-gather input
+- Unsharded parameter: parameter used for forward/backward computation, derived
+  from the all-gather output; autograd leaf
+
+We define these tensors to describe the general framework that can accommodate
+extensions, where:
+- all-gather-input = pre-all-gather-transform(sharded-parameter)
+- unsharded-parameter = post-all-gather-transform(all-gather-output)
+
+For the default ``torch.Tensor`` case, the sharded parameter and all-gather
+input share the same underlying tensor data, meaning that they can be thought
+of as the same tensors. The same applies for the all-gather output and
+unsharded parameter. For non-``torch.Tensor`` extensions, these equivalences
+may no longer hold due to the pre/post-all-gather transforms.
+
+[Note: FSDP and autograd]
+FSDP dynamically frees and allocates the unsharded parameter. Since autograd
+can pack a reference to it or a view to save for backward, we use storage
+resizing to implement the freeing/allocation since that preserves the aliasing.
+This implies that we construct the unsharded parameter object once and write to
+it in-place thereafter. For the default ``torch.Tensor`` original parameter
+case, the all-gather output and unsharded parameter share the same
+data, so we use storage resizing on the all-gather output.
+"""
+
+
+class ShardedState(Enum):
+    """
+    - ``SHARDED``: The sharded parameter is registered to the module. It is the
+      only contributor to parameter memory.
+    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
+      smaller world size. Since this data should not be used for computation,
+      we do not register it to the module. Users should reshard the module
+      before any in-place modifications. Both it and the sharded parameter
+      contribute to parameter memory.
+    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
+      it and the sharded parameter contribute to parameter memory.
+    """
+
+    SHARDED = auto()
+    SHARDED_POST_FORWARD = auto()
+    UNSHARDED = auto()
+
+
+@dataclass
+class ParamModuleInfo:
+    """
+    For a parameter, this stores the module and the parameter name to be able
+    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
+    the parameter via ``getattr(module, param_name)``. We additionally save
+    shared modules and shared parameter names to update them accordingly.
+    """
+
+    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
+    module: nn.Module
+    param_name: str
+    shared_modules: List[nn.Module] = field(default_factory=list)
+    shared_param_names: List[str] = field(default_factory=list)
+
+
+class FSDPParam:
+    """
+    This class manages a parameter with FSDP or FSDP variants applied,
+    implementing dim-0 per-parameter sharding.
+    """
+
+    orig_dtype: torch.dtype
+    param_dtype: Optional[torch.dtype]
+    reduce_dtype: Optional[torch.dtype]
+    _orig_size: torch.Size  # ND
+    _contiguous_orig_stride: Tuple[int, ...]
+    sharded_size: torch.Size  # ND
+    contiguous_sharded_stride: Tuple[int, ...]
+    padded_sharded_param_size: torch.Size  # ND
+    sharded_post_forward_size: torch.Size  # ND
+    contiguous_sharded_post_forward_stride: Tuple[int, ...]
+    _sharded_param_data: torch.Tensor  # 1D
+    sharded_param: nn.Parameter  # ND
+    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D
+    _sharded_post_forward_param: Optional[nn.Parameter]  # ND
+    _unsharded_param: nn.Parameter  # ND
+    _global_placements: Tuple[Placement, ...]
+    _global_size: torch.Size
+    _global_stride: Tuple[int, ...]
+    # DTensor attributes (only defined for DTensor `param`):
+    _tp_spec: DTensorSpec
+
+    def __init__(
+        self,
+        param: nn.Parameter,
+        module_info: ParamModuleInfo,
+        mesh_info: FSDPMeshInfo,
+        post_forward_mesh_info: Optional[FSDPMeshInfo],
+        device: torch.device,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        self._module_info: ParamModuleInfo = module_info
+        self.mesh_info = mesh_info
+        self.post_forward_mesh_info = post_forward_mesh_info
+        self.device = device
+        self._init_sharded_param(param, device)
+        if self.post_forward_mesh_info:
+            self._init_sharded_post_forward_param_metadata(param)
+        self.all_gather_output = torch.empty(0)
+        self._param_fqn: Optional[str] = None  # prefixed from root module
+
+    @torch.no_grad()
+    def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
+        if param.device != device and param.device.type != "meta":
+            raise AssertionError(
+                f"Expects the parameter to already be moved to device {device} but got {param.device}"
+            )
+        # TODO: Replace the sharded DTensor parameter construction logic with
+        # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
+        # TODO: Simplify the following sharded parameter padding logic after
+        # https://github.com/pytorch/pytorch/issues/113045
+        self.is_dtensor = isinstance(param, DTensor)
+        if self.is_dtensor:
+            self._tp_spec = cast(DTensor, param)._spec
+            if (
+                self.mesh_info.shard_mesh_dim != 0
+                or self.mesh_info.replicate_mesh_dim is not None
+            ):
+                raise NotImplementedError("Using TP with HSDP is not supported")
+            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
+            dp_global_mesh = _mesh_resources.get_parent_mesh(dp_mesh)
+            tp_global_mesh = _mesh_resources.get_parent_mesh(tp_mesh)
+            if dp_global_mesh != tp_global_mesh or (
+                dp_global_mesh is None or tp_global_mesh is None
+            ):
+                raise AssertionError(
+                    "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
+                    f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
+                )
+            self._global_mesh = dp_global_mesh
+            if len(self._tp_spec.placements) != 1:
+                raise NotImplementedError(
+                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
+                )
+            global_placements: List[Placement] = [Replicate(), Replicate()]
+            global_dp_mesh_dim = _mesh_resources.get_parent_mesh_dim(dp_mesh)
+            global_tp_mesh_dim = _mesh_resources.get_parent_mesh_dim(tp_mesh)
+            assert global_dp_mesh_dim is not None  # mypy
+            assert global_tp_mesh_dim is not None  # mypy
+            # TODO: Hard code FSDP + TP; need to support HSDP + TP
+            global_placements[global_dp_mesh_dim] = Shard(0)
+            global_placements[global_tp_mesh_dim] = self._tp_spec.placements[0]
+            self._global_placements = tuple(global_placements)
+            self._global_size = param.size()
+            self._global_stride = param.stride()
+            param_data = cast(DTensor, param)._local_tensor
+        else:
+            self._global_mesh = self.mesh_info.mesh
+            self._global_placements = (Shard(0),)
+            self._global_size = param.size()
+            self._global_stride = param.stride()
+            param_data = param
+        self._orig_size = param_data.size()
+        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
+        shard_rank = self.mesh_info.shard_mesh_rank
+        shard_world_size = self.mesh_info.shard_mesh_size
+        chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
+        sharded_param = chunks[shard_rank]
+        self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size())
+        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
+        padded_sharded_size = chunks[0].size()  # 0th always padded
+        padded_sharded_param = param_data.new_zeros(padded_sharded_size)
+        self.padded_sharded_param_size = padded_sharded_param.size()
+        if sharded_param.numel() > 0:
+            padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
+        self._sharded_param_data = padded_sharded_param.view(-1)
+        self.sharded_param = nn.Parameter(
+            self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)])
+        )
+        self.sharded_param.requires_grad_(param.requires_grad)
+        # Let `param_data` be freed normally once its ref count reaches 0 (i.e.
+        # when the `fully_shard` call returns) to allow provided parameters to alias
+        self._setattr_on_modules(self.sharded_param)
+        self.sharded_state = ShardedState.SHARDED
+
+    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
+        mesh_info = self.post_forward_mesh_info
+        assert mesh_info is not None  # mypy
+        param_data = param._local_tensor if isinstance(param, DTensor) else param
+        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
+        self.sharded_post_forward_size = _get_dim0_chunked_size(
+            chunks[mesh_info.shard_mesh_rank], param_data.size()
+        )
+        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
+            self.sharded_post_forward_size
+        )
+
+    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
+        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
+        self.orig_dtype = self.sharded_param.dtype
+        # Clamp `param_dtype` to `None` if no casting is required
+        if param_dtype == self.orig_dtype:
+            param_dtype = None
+        self.param_dtype = param_dtype
+        self.reduce_dtype = reduce_dtype
+        # `None` indicates that mixed precision is not enabled
+
+    def init_all_gather_output(
+        self,
+        all_gather_input_numel: int,
+        world_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        if self.all_gather_output.numel() > 0:
+            return  # already initialized
+        all_gather_output_size = torch.Size([all_gather_input_numel * world_size])
+        self.all_gather_output = torch.empty(
+            all_gather_output_size, dtype=dtype, device=device
+        )
+
+    def init_unsharded_param(self):
+        if hasattr(self, "_unsharded_param"):
+            return  # already initialized
+        # For the default path (no post-all-gather), the all-gather output
+        # gives the unsharded parameter data directly
+        unsharded_param = torch.as_strided(
+            self.all_gather_output,
+            self._orig_size,
+            self._contiguous_orig_stride,
+            storage_offset=0,
+        )
+        if self.is_dtensor:
+            unsharded_param = _from_local_no_grad(
+                unsharded_param,
+                self._tp_spec.mesh,
+                self._tp_spec.placements,
+                self._global_size,
+                self._global_stride,
+            )
+        self._unsharded_param = nn.Parameter(unsharded_param)
+        self._unsharded_param.requires_grad_(self.sharded_param.requires_grad)
+
+    def to_sharded(self) -> None:
+        self._setattr_on_modules(self.sharded_param)
+        self.free_all_gather_output()
+        self.sharded_state = ShardedState.SHARDED
+
+    def to_sharded_post_forward(self) -> None:
+        if self.is_dtensor:
+            raise NotImplementedError(
+                "Resharding to smaller mesh with TP is not supported yet"
+            )
+        self._assert_in_states(ShardedState.UNSHARDED)
+        assert self.post_forward_mesh_info is not None  # mypy
+        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
+        if (numel := self.all_gather_output.numel()) % shard_world_size != 0:
+            _raise_assert_with_print(
+                f"All-gather output size ({numel}) must be divisible by the shard "
+                f"world size ({shard_world_size})"
+            )
+        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
+        sharded_numel = numel // shard_world_size
+        self._sharded_post_forward_param_data = (
+            self.all_gather_output.narrow(0, sharded_numel * shard_rank, sharded_numel)
+        ).clone()  # clone to be able to free all-gather output
+        sharded_post_forward_tensor = torch.as_strided(
+            self._sharded_post_forward_param_data,
+            size=self.sharded_post_forward_size,
+            stride=self.contiguous_sharded_post_forward_stride,
+            storage_offset=0,
+        )
+        self._sharded_post_forward_param = nn.Parameter(
+            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
+        )
+        self._setattr_on_modules(self._sharded_post_forward_param)
+        self.free_all_gather_output()
+        self.sharded_state = ShardedState.SHARDED_POST_FORWARD
+
+    def to_unsharded(self) -> None:
+        # Assume that the data has been allocated and all-gathered
+        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
+        self._setattr_on_modules(self._unsharded_param)
+        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
+            # The data is allocated in the default stream via the post-forward
+            # reshard and must be kept alive for the next all-gather copy-in.
+            # Since we call this method after the copy-out, the data's lifetime
+            # is ensured without further synchronization.
+            self._sharded_post_forward_param = None
+            self._sharded_post_forward_param_data = None  # free
+        self.sharded_state = ShardedState.UNSHARDED
+
+    def _setattr_on_modules(self, param: nn.Parameter) -> None:
+        unsafe_setattr_param(
+            self._module_info.module, self._module_info.param_name, param
+        )
+        for shared_module, shared_param_name in zip(
+            self._module_info.shared_modules, self._module_info.shared_param_names
+        ):
+            unsafe_setattr_param(shared_module, shared_param_name, param)
+
+    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
+        """
+        Converts a local tensor representing either the sharded parameter or
+        sharded gradient to DTensor.
+        """
+        if tensor.shape != self.sharded_size:
+            _raise_assert_with_print(
+                f"Expects size {self.sharded_size} but got {tensor.shape}"
+            )
+        return _from_local_no_grad(
+            tensor,
+            self._global_mesh,
+            self._global_placements,
+            self._global_size,
+            self._global_stride,
+        )
+
+    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
+        if tensor.shape != self.sharded_post_forward_size:
+            _raise_assert_with_print(
+                f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
+            )
+        assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
+        # TODO: Prefer this DTensor to be read-only and generalize the
+        # placement once we support TP.
+        return _from_local_no_grad(
+            tensor,
+            self.post_forward_mesh_info.mesh,
+            (Replicate(), Shard(0)),
+            self._global_size,
+            self._global_stride,
+        )
+
+    def alloc_all_gather_output(self) -> None:
+        unsafe_alloc_storage(self.all_gather_output)
+
+    def free_all_gather_output(self) -> None:
+        unsafe_free_storage(self.all_gather_output)
+
+    @property
+    def all_gather_input(self) -> torch.Tensor:  # 1D
+        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
+        if self.sharded_state == ShardedState.SHARDED:
+            return _to_dtype_if_needed(self._sharded_param_data, self.param_dtype)
+        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
+            return _to_dtype_if_needed(
+                cast(torch.Tensor, self._sharded_post_forward_param_data),
+                self.param_dtype,
+            )
+        return torch.empty(0)  # mypy
+
+    @property
+    def unsharded_param(self) -> nn.Parameter:  # ND
+        self._assert_in_states(ShardedState.UNSHARDED)
+        return self._unsharded_param
+
+    @property
+    def unsharded_grad_data(self) -> torch.Tensor:
+        grad = self.unsharded_param.grad
+        assert grad is not None, "Expects unsharded_param.grad to not be None"
+        return self._get_grad_inner_tensor(grad)
+
+    def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
+        if self.is_dtensor:
+            if isinstance(grad, AsyncCollectiveTensor):
+                grad = grad.wait()
+            grad = cast(DTensor, grad)._local_tensor
+        return grad
+
+    def _assert_in_states(self, *states: ShardedState) -> None:
+        if self.sharded_state not in states:
+            _raise_assert_with_print(
+                f"Expects to be in one of {states}, not {self.sharded_state}"
+            )
+
+
+# NOTE: Unsafe here refers to not checking whether the storage is already
+# allocated or freed, respectively. We should be safe to use them since we
+# explicitly manage the state transition.
+def unsafe_alloc_storage(tensor: torch.Tensor) -> None:
+    # Skip the already-allocated check and assume that `tensor` is the base
+    # tensor to save CPU overhead
+    tensor.untyped_storage().resize_(tensor.numel() * tensor.itemsize)
+
+
+def unsafe_free_storage(tensor: torch.Tensor) -> None:
+    # Skip the already-freed check to save CPU overhead
+    tensor.untyped_storage().resize_(0)
+
+
+# NOTE: When the module has not overridden `__setattr__`, these bypass the
+# `nn.Module.__setattr__` checks, which incur non-trivial CPU overhead. For
+# FSDP, we know we do not need those checks when transitioning between
+# sharded/unsharded parameters.
+def unsafe_setattr_param(
+    module: nn.Module, param_name: str, param: nn.Parameter
+) -> None:
+    if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
+        module._parameters[param_name] = param
+    else:  # slow path
+        setattr(module, param_name, param)
+
+
+def set_requires_grad_if_needed(
+    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
+) -> None:
+    # Only call `requires_grad_` if needed to avoid the Python <> C++ context
+    # switch overhead
+    if src_tensor.requires_grad != dst_tensor.requires_grad:
+        dst_tensor.requires_grad_(src_tensor.requires_grad)
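
The `unsafe_alloc_storage` / `unsafe_free_storage` helpers above implement the storage-resizing scheme from [Note: FSDP and autograd]: resizing the backing storage frees or re-allocates the bytes while keeping the tensor object, and hence all of its aliases, intact. A minimal CPU-only sketch of that behavior:

import torch

buf = torch.empty(4, 4)
view = buf.view(-1)  # an alias that must survive the free/alloc cycle

buf.untyped_storage().resize_(0)  # "free": drop the backing bytes
assert buf.untyped_storage().size() == 0  # shape and aliasing metadata remain

buf.untyped_storage().resize_(buf.numel() * buf.itemsize)  # "allocate" again
buf.fill_(1.0)
assert view.sum().item() == 16.0  # the old alias sees the rewritten data
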
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..51546afa3ae277fa7fe6df4281f9e5a3fa22688b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -0,0 +1,506 @@
+import contextlib
+
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from torch.autograd.graph import Node
+from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils.hooks import RemovableHandle
+from ._fsdp_api import MixedPrecisionPolicy
+from ._fsdp_collectives import (
+    AllGatherResult,
+    foreach_all_gather,
+    foreach_all_gather_copy_out,
+    foreach_reduce_scatter,
+)
+from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
+from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
+
+_ModuleToHandleDict = Dict[nn.Module, RemovableHandle]  # for state dict
+
+
+"""
+[Note: Overlapping all-gather copy-in and all-gather]
+For implicit forward prefetching, we want to overlap the next copy-in with the
+current all-gather. We do so using a separate copy-in stream. However, since
+we have the all-gather input as a view into the output, we must make sure to
+copy into different memory from the current all-gather's output. Thus, we keep
+a reference to the current all-gather's output and have the next FSDP parameter
+group free it after its copy-in. Finally, we have the last FSDP state flush the
+reference to avoid holding onto memory after forward.
+"""
+
+
+class FSDPCommContext:
+    """This has the communication state shared across FSDP states/parameter groups."""
+
+    def init(self):
+        # Setting the all-gather/reduce-scatter streams to be higher priority
+        # can help avoid some issues where their copies in/out are delayed and
+        # block computation
+        high_priority = -1
+        # All-gather state and copy-in stream allow overlapping the next
+        # copy-in with the current all-gather in forward; copy-in overlaps with
+        # reduce-scatter in backward without the separate copy-in stream
+        self.all_gather_copy_in_stream = torch.cuda.Stream(priority=high_priority)
+        self.all_gather_state: Optional[AllGatherState] = None
+        # All-gather stream allows overlapping next all-gather with current
+        # forward compute
+        self.all_gather_stream = torch.cuda.Stream(priority=high_priority)
+        # The reduce-scatter stream gives a separate execution "thread" for
+        # post-backward logic such as pre/post-gradient division and the
+        # reduce-scatter itself
+        self.reduce_scatter_stream = torch.cuda.Stream(priority=high_priority)
+        # Post-forward order for explicit backward prefetching
+        self.post_forward_order: List[FSDPParamGroup] = []  # will cause ref cycles
+
+    def get_all_gather_streams(
+        self, training_state: TrainingState
+    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
+        if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD):
+            # Use separate streams for implicit prefetching
+            return self.all_gather_copy_in_stream, self.all_gather_stream
+        current_stream = torch.cuda.current_stream()
+        return current_stream, current_stream
+
+
+# See [Note: Overlapping all-gather copy-in and all-gather]
+class AllGatherState(NamedTuple):
+    all_gather_result: AllGatherResult
+    event: torch.cuda.Event  # all-gather copy-out
+
+
+class FSDPParamGroup:
+    """This class represents a parameter group to communicate together."""
+
+    _orig_dtype: torch.dtype
+    _reduce_dtype: Optional[torch.dtype]
+
+    def __init__(
+        self,
+        params: List[nn.Parameter],
+        module: nn.Module,
+        mesh_info: FSDPMeshInfo,
+        post_forward_mesh_info: Optional[FSDPMeshInfo],
+        device: torch.device,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        self.module = module  # permit ref cycle because 1:1 lifetime
+        param_module_infos = _get_param_module_infos(params, module)
+        self.fsdp_params = [
+            FSDPParam(
+                param, module_info, mesh_info, post_forward_mesh_info, device, mp_policy
+            )
+            for param, module_info in zip(params, param_module_infos)
+        ]
+        self.mesh_info = mesh_info
+        self.post_forward_mesh_info = post_forward_mesh_info
+        self.device = device
+        self.mp_policy = mp_policy
+        self._training_state = TrainingState.IDLE
+        # Group's sharded state always matches its parameters' sharded states
+        self._sharded_state = ShardedState.SHARDED
+        self._module_fqn: Optional[str] = None  # prefixed from root module
+
+        # - Hook state
+        self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
+        self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}
+
+        # - Communication and communication/computation overlap
+        self.comm_ctx = FSDPCommContext()
+        # Group's indices in the shared post-forward order
+        self._post_forward_indices: List[int] = []
+        # Used to avoid mistargeted backward prefetches when the module is used
+        # in forward but not in backward: for each forward, we record a tuple
+        # of the output's grad fns and later query the autograd engine for
+        # whether any of those grad fns will execute in the current backward to
+        # decide whether to prefetch.
+        self.all_forward_output_grad_fns: Set[Tuple[Node, ...]] = set()
+        # Whether to reduce-scatter or all-reduce gradients, respectively
+        # (can be set to false to save communication during gradient
+        # accumulation); all-reducing without reduce-scatter is disallowed
+        self.reduce_scatter_grads: bool = True
+        self.all_reduce_grads: bool = True
+
+        # - CUDA events for stream synchronization
+        # Holds the all-gather output buffer, sync objects, and metadata
+        self._all_gather_result: Optional[AllGatherResult] = None
+        # Holds the reduce-scatter view-out CUDA event that marks the end of
+        # the group's post-backward (e.g. reduce-scatter and div), which should
+        # be waited on at the end of backward
+        self._reduce_scatter_view_out_event: Optional[torch.cuda.Event] = None
+        # Holds the reshard-after-forward CUDA event when resharding to a
+        # different world size, which should be waited on in the next unshard
+        self._reshard_after_forward_event: Optional[torch.cuda.Event] = None
+
+    # Initialization #
+    def _init_mp_dtypes(self) -> None:
+        for fsdp_param in self.fsdp_params:
+            fsdp_param.init_dtype_attrs(self.mp_policy)
+        orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params}
+        if len(orig_dtypes) != 1:
+            # This can be relaxed if we copy-out for the reduce-scatter
+            raise AssertionError(
+                f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
+            )
+        self._orig_dtype = next(iter(orig_dtypes))
+        reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params}
+        if len(reduce_dtypes) != 1:
+            # This can be relaxed if we issue one reduce-scatter per reduce
+            # dtype (but we would need a way for users to specify multiple
+            # reduce dtypes)
+            raise AssertionError(
+                f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
+            )
+        self._reduce_dtype = next(iter(reduce_dtypes))
+
+    def _init_grad_divide_factors(self):
+        data_parallel_world_size = 1
+        data_parallel_world_size *= self.mesh_info.shard_mesh_size
+        if isinstance(self.mesh_info, HSDPMeshInfo):
+            data_parallel_world_size *= self.mesh_info.replicate_mesh_size
+        if self._reduce_dtype == torch.float32:
+            # Use NCCL's AVG op to divide after reduction since it is more
+            # performant and fp32 has sufficient precision
+            self._grad_divide_factors: Optional[Tuple[float, float]] = None
+            return
+        # For N data parallel workers, each worker computes g_i, and they
+        # collectively reduce (g_1 + ... + g_N) / N. To avoid overflow and
+        # underflow, we divide by ~sqrt(N) before and after the reduction.
+        factor: int = 1
+        while (
+            data_parallel_world_size % factor == 0
+            and data_parallel_world_size / factor > factor
+        ):
+            factor *= 2
+        factor = float(factor)
+        self._grad_divide_factors = (factor, data_parallel_world_size / factor)
+
+    def lazy_init(self):
+        param_names_on_meta = [
+            fsdp_param._param_fqn
+            for fsdp_param in self.fsdp_params
+            if fsdp_param.sharded_param.device.type == "meta"
+        ]
+        if param_names_on_meta:
+            raise RuntimeError(
+                "FSDP parameters should be materialized from meta device before training, "
+                f"but the following were still on meta device: {param_names_on_meta}\n"
+                "For example, call module.to_empty(device) to materialize to device and "
+                "call module.reset_parameters() on each module to initialize values."
+            )
+        # Initialize mixed precision attributes lazily in case the user changes
+        # the parameter dtypes after construction time but before forward
+        self._init_mp_dtypes()
+        self._init_grad_divide_factors()
+        self._register_state_dict_hooks()
+
+    # Runtime #
+    def unshard(self, async_op: bool = False):
+        if self._all_gather_result is not None:  # already called, pending wait
+            return
+        if self.is_unsharded:
+            return  # no-op
+        if self._reshard_after_forward_event is not None:
+            # Resharded parameter data is allocated in the default stream and
+            # used in the all-gather streams
+            self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
+            self._reshard_after_forward_event = None
+        self._all_gather_result = foreach_all_gather(
+            self.fsdp_params,
+            self._all_gather_process_group,
+            async_op,
+            *self.comm_ctx.get_all_gather_streams(self._training_state),
+            self.device,
+        )
+
+    def wait_for_unshard(self):
+        """
+        1. In forward with implicit prefetching, to overlap the current copy-out
+        with the next all-gather, we save a reference to the current all-gather
+        result to free after the next copy-out.
+        2. Otherwise (explicit prefetching or in backward), we free the
+        all-gather result immediately after the current copy-out since we can
+        already overlap the current copy-out with the previous reduce-scatter.
+        """
+        if not self._all_gather_result:
+            return  # no preceding unshard
+        if self._training_state == TrainingState.FORWARD:  # implicit prefetch
+            if prev_all_gather_state := self.comm_ctx.all_gather_state:
+                self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
+                self.comm_ctx.all_gather_state = None  # free the all-gather result
+        foreach_all_gather_copy_out(
+            self._all_gather_result, self.fsdp_params, self._all_gather_process_group
+        )
+        for fsdp_param in self.fsdp_params:
+            fsdp_param.init_unsharded_param()  # no-op after 1st call
+        self._to_unsharded()
+        all_gather_copy_out_event = torch.cuda.Event()
+        all_gather_copy_out_event.record()
+        if self._training_state == TrainingState.FORWARD:
+            self.comm_ctx.all_gather_state = AllGatherState(
+                self._all_gather_result, all_gather_copy_out_event
+            )
+        else:
+            self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
+        self._all_gather_result = None  # free unless saved in `all_gather_state`
+
+    def _wait_all_gather_streams_on_event(self, event: torch.cuda.Event):
+        self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
+        self.comm_ctx.all_gather_stream.wait_event(event)
+
+    def reshard(self):
+        if self._training_state == TrainingState.FORWARD:
+            if not self._reshard_after_forward:
+                return
+            if self._use_post_forward_mesh:
+                self._to_sharded_post_forward()
+                self._reshard_after_forward_event = torch.cuda.Event()
+                self._reshard_after_forward_event.record()
+                return
+        self._to_sharded()
+
+    def pre_forward(
+        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+        with torch.profiler.record_function("FSDP::pre_forward"):
+            self._training_state = TrainingState.FORWARD
+            self.unshard()
+            self.wait_for_unshard()
+            args, kwargs = self._register_post_backward_hook(args, kwargs)
+            return args, kwargs
+
+    def post_forward(self, module: nn.Module, input: Any, output: Any):
+        with torch.profiler.record_function("FSDP::post_forward"):
+            self.reshard()
+            self._record_post_forward()
+            self._training_state = TrainingState.IDLE
+            return output
+
+    def _record_post_forward(self) -> None:
+        # Since a group has one pre-backward unshard for each forward call
+        # before the backward, we record each usage (with multiplicity)
+        post_forward_index = len(self.comm_ctx.post_forward_order)
+        self.comm_ctx.post_forward_order.append(self)
+        self._post_forward_indices.append(post_forward_index)
+
+    def pre_backward(self, forward_grad_fns: Tuple[Any, ...], *unused: Any):
+        with torch.profiler.record_function("FSDP::pre_backward"):
+            self._training_state = TrainingState.PRE_BACKWARD
+            self.unshard()  # no-op if prefetched
+            self.wait_for_unshard()
+            # Can be already removed if running multiple `backward`s
+            self.all_forward_output_grad_fns.discard(forward_grad_fns)
+            self._prefetch_unshard()
+
+    def post_backward(self, *unused: Any):
+        self._training_state = TrainingState.POST_BACKWARD
+        with torch.profiler.record_function("FSDP::post_backward_reshard"):
+            if not self.reduce_scatter_grads:
+                self.reshard()
+                return
+            # Save the autograd-computed gradients before resharding to only
+            # access the unsharded parameters when their data is present
+            fsdp_params_with_grad: List[FSDPParam] = []
+            unsharded_grads: List[torch.Tensor] = []
+            for fsdp_param in self.fsdp_params:
+                if fsdp_param.unsharded_param.grad is not None:
+                    fsdp_params_with_grad.append(fsdp_param)
+                    unsharded_grads.append(fsdp_param.unsharded_grad_data)
+                    fsdp_param.unsharded_param.grad = None
+            self.reshard()
+        if len(fsdp_params_with_grad) == 0:
+            return
+        with torch.profiler.record_function("FSDP::post_backward_reduce"):
+            self._reduce_scatter_view_out_event = foreach_reduce_scatter(
+                fsdp_params_with_grad,
+                unsharded_grads,
+                self._reduce_scatter_process_group,
+                self.comm_ctx.reduce_scatter_stream,
+                self._orig_dtype,
+                self._reduce_dtype,
+                self.device,
+                self._grad_divide_factors,
+            )
+
+    def finalize_backward(self):
+        if self._reduce_scatter_view_out_event is not None:
+            torch.cuda.current_stream().wait_event(self._reduce_scatter_view_out_event)
+            self._reduce_scatter_view_out_event = None
+        self._training_state = TrainingState.IDLE
+        self._post_forward_indices.clear()
+        self.all_forward_output_grad_fns.clear()
+
+    def _prefetch_unshard(self):
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            if not self._post_forward_indices:
+                # Can be cleared if running multiple `backward`s
+                return
+            curr_index = self._post_forward_indices.pop()
+            if (target_index := curr_index - 1) < 0:
+                return
+            target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
+            if any(
+                torch._C._will_engine_execute_node(grad_fn)  # type: ignore[attr-defined]
+                for grad_fns in target_fsdp_param_group.all_forward_output_grad_fns
+                for grad_fn in grad_fns
+            ):
+                with torch.profiler.record_function(
+                    "FSDP::backward_prefetch"
+                ), target_fsdp_param_group.use_training_state(
+                    TrainingState.PRE_BACKWARD
+                ):
+                    target_fsdp_param_group.unshard()
+
+    # Utilities #
+    def _to_sharded(self):
+        if not self.is_sharded:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_sharded()
+            self._sharded_state = ShardedState.SHARDED
+
+    def _to_sharded_post_forward(self):
+        if not self.is_sharded_post_forward:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_sharded_post_forward()
+            self._sharded_state = ShardedState.SHARDED_POST_FORWARD
+
+    def _to_unsharded(self):
+        if not self.is_unsharded:
+            for fsdp_param in self.fsdp_params:
+                fsdp_param.to_unsharded()
+            self._sharded_state = ShardedState.UNSHARDED
+
+    @property
+    def is_sharded(self) -> bool:
+        return self._sharded_state == ShardedState.SHARDED
+
+    @property
+    def is_sharded_post_forward(self) -> bool:
+        return self._sharded_state == ShardedState.SHARDED_POST_FORWARD
+
+    @property
+    def is_unsharded(self) -> bool:
+        return self._sharded_state == ShardedState.UNSHARDED
+
+    @contextlib.contextmanager
+    def use_training_state(self, training_state: TrainingState):
+        old_training_state = self._training_state
+        self._training_state = training_state
+        try:
+            yield
+        finally:
+            self._training_state = old_training_state
+
+    # Hook Registration #
+    def _register_post_backward_hook(
+        self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+        if not torch.is_grad_enabled():
+            return args, kwargs
+        args_list, args_spec = tree_flatten(args)
+        kwargs_list, kwargs_spec = tree_flatten(kwargs)
+        args_kwargs_list = list(args_list) + list(kwargs_list)
+        inp_tensor_indices: List[int] = []
+        inp_tensors: List[torch.Tensor] = []
+        for i, obj in enumerate(args_kwargs_list):
+            if torch.is_tensor(obj) and obj.requires_grad:
+                inp_tensor_indices.append(i)
+                inp_tensors.append(obj)
+        if len(inp_tensors) == 0:
+            return args, kwargs  # no tensors that require gradients
+        inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
+        for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
+            args_kwargs_list[inp_tensor_idx] = inp_tensor
+        args_list = args_kwargs_list[: len(args_list)]
+        kwargs_list = args_kwargs_list[len(args_list) :]
+        args = tree_unflatten(args_list, args_spec)
+        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
+        return args, kwargs
+
+    def _register_state_dict_hooks(self) -> None:
+        assert len(self._module_to_pre_save_state_dict_hook_handle) == 0
+        assert len(self._module_to_pre_load_state_dict_hook_handle) == 0
+        modules_with_fsdp_params: Set[nn.Module] = {
+            fsdp_param._module_info.module for fsdp_param in self.fsdp_params
+        }
+
+        def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
+            self._to_sharded()
+
+        for module in modules_with_fsdp_params:
+            self._module_to_pre_save_state_dict_hook_handle[
+                module
+            ] = module.register_state_dict_pre_hook(to_sharded_hook)
+            self._module_to_pre_load_state_dict_hook_handle[
+                module
+            ] = module._register_load_state_dict_pre_hook(to_sharded_hook)
+
+    # Properties #
+    @property
+    def _reshard_after_forward(self) -> bool:
+        return self.post_forward_mesh_info is not None
+
+    @property
+    def _use_post_forward_mesh(self) -> bool:
+        return (
+            self._reshard_after_forward
+            and self.mesh_info != self.post_forward_mesh_info
+        )
+
+    @property
+    def _all_gather_process_group(self) -> dist.ProcessGroup:
+        mesh_info = (
+            cast(FSDPMeshInfo, self.post_forward_mesh_info)
+            if self.is_sharded_post_forward
+            else self.mesh_info
+        )
+        assert isinstance(mesh_info, FSDPMeshInfo)
+        return mesh_info.shard_process_group
+
+    @property
+    def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
+        mesh_info = self.mesh_info
+        assert isinstance(mesh_info, FSDPMeshInfo)
+        return mesh_info.shard_process_group
+
+
+def _get_param_module_infos(
+    params: List[nn.Parameter], module: nn.Module
+) -> List[ParamModuleInfo]:
+    """
+    Shared parameter: lin1.weight = lin2.weight
+    Shared module: mlp.lin1 = mlp.lin2
+    We do not remove duplicates when traversing both modules and parameters to
+    find shared modules' parameters and shared parameters within a module.
+    """
+    params_set = set(params)
+    param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
+    for _, submodule in module.named_modules(remove_duplicate=False):
+        for param_name, param in _named_parameters_with_duplicates(
+            submodule, recurse=False
+        ):
+            if param in params_set:
+                if param not in param_to_module_info:
+                    param_to_module_info[param] = ParamModuleInfo(submodule, param_name)
+                else:
+                    param_to_module_info[param].shared_modules.append(submodule)
+                    param_to_module_info[param].shared_param_names.append(param_name)
+    if len(param_to_module_info) != len(params):
+        raise AssertionError(f"Some parameters are not in the module tree of {module}")
+    return [param_to_module_info[param] for param in params]
+
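+# Illustrative note: for a hypothetical module where ``mlp.lin1.weight`` and
+# ``mlp.lin2.weight`` are the same ``nn.Parameter`` (weight tying), the first
+# visited (module, name) pair becomes ``ParamModuleInfo.module``/``param_name``
+# and the second is appended to ``shared_modules``/``shared_param_names``.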
+
+class RegisterPostBackwardFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
+        # All tensors in `inputs` should require gradient
+        ctx.param_group = param_group
+        return inputs
+
+    @staticmethod
+    def backward(ctx, *grads: torch.Tensor):
+        ctx.param_group.post_backward()
+        return (None,) + grads
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..06f839ee429cdef4bac0761a1cfc45746c0765e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/_fsdp_state.py
@@ -0,0 +1,246 @@
+import functools
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.autograd.graph import Node, register_multi_grad_hook
+from torch.distributed._composable_state import (
+    _get_module_state,
+    _insert_module_state,
+    _State,
+)
+from torch.distributed.utils import _to_kwargs
+from torch.utils._pytree import tree_flatten, tree_map
+from torch.utils.hooks import RemovableHandle
+from ._fsdp_api import MixedPrecisionPolicy
+from ._fsdp_common import _cast_fp_tensor, TrainingState
+from ._fsdp_param import FSDPParam
+from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
+
+
+class FSDPStateContext:
+    """This has state shared across FSDP states."""
+
+    def __init__(self):
+        # All FSDP states in the root state's module tree
+        self.all_states: List[FSDPState] = []
+        # Iteration's forward root runs the once-per-forward logic; this root
+        # may not be the overall root set by lazy initialization in cases where
+        # only a submodule runs forward (e.g. encoder-only for eval)
+        self.iter_forward_root: Optional[FSDPState] = None
+        # Final callback should only be queued once per backward
+        self.post_backward_final_callback_queued: bool = False
+        # Whether to finalize backward in this backward's final callback
+        self.is_last_backward: bool = True
+
+
+class FSDPState(_State):
+    def __init__(self):
+        super().__init__()
+        self._fsdp_param_group: Optional[FSDPParamGroup] = None
+        self._is_root: Optional[bool] = None  # root set during lazy init
+        self._state_ctx = FSDPStateContext()
+        self._comm_ctx = FSDPCommContext()
+        self._training_state: TrainingState = TrainingState.IDLE
+        self._pre_backward_hook_handles: List[RemovableHandle] = []
+
+    # Define a separate init since `__init__` is called in the contract
+    def init(
+        self, module: nn.Module, device: torch.device, mp_policy: MixedPrecisionPolicy
+    ) -> None:
+        _insert_module_state(module, self)
+        self._module = module
+        self._device = device
+        self._mp_policy = mp_policy
+        self._pre_forward_hook_handle = module.register_forward_pre_hook(
+            self._pre_forward, prepend=True, with_kwargs=True
+        )
+        self._post_forward_hook_handle = module.register_forward_hook(
+            self._post_forward, prepend=False
+        )
+
+    def _root_pre_forward(
+        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+        self._lazy_init()
+        if self._state_ctx.iter_forward_root is not None:
+            return args, kwargs
+        self._state_ctx.iter_forward_root = self
+        with torch.profiler.record_function("FSDP::root_pre_forward"):
+            # Wait for optimizer before implicitly prefetched all-gathers
+            current_stream = torch.cuda.current_stream()
+            self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
+            self._comm_ctx.all_gather_stream.wait_stream(current_stream)
+            if self._device.type == "cuda":
+                with torch.profiler.record_function("FSDP::inputs_to_device"):
+                    args_tuple, kwargs_tuple = _to_kwargs(
+                        args, kwargs, self._device, False
+                    )  # same as DDP
+                args, kwargs = args_tuple[0], kwargs_tuple[0]
+        return args, kwargs
+
+    def _lazy_init(self) -> None:
+        """
+        Lazy initialization represents when all modules' parallelisms have
+        finalized (e.g. FSDP has been applied to all desired modules). This
+        means that we can determine which state is the root, and we do so by
+        the 1st state to run forward.
+        """
+        if self._is_root is not None:
+            return  # no-op: already initialized
+        self._is_root = True
+        root_module = self._module
+        for module_name, module in root_module.named_modules():
+            if (state := _get_module_fsdp_state(module)) is None:
+                continue
+            if module is not root_module:
+                if state._is_root is not None:
+                    raise RuntimeError(
+                        "FSDP state has already been lazily initialized for "
+                        f"{module_name}\nFSDP requires running forward through "
+                        "the root module first"
+                    )
+                state._is_root = False
+            self._state_ctx.all_states.append(state)
+        if self._fsdp_param_group:
+            # For the root, do not reshard after forward since for training,
+            # the parameters would be freed and all-gathered immediately
+            self._fsdp_param_group.post_forward_mesh_info = None
+        self._init_fqns()
+        self._init_shared_state()
+        # Run parameter group lazy inits after initializing FQNs for improved
+        # error messages
+        for state in self._state_ctx.all_states:
+            if state._fsdp_param_group:
+                state._fsdp_param_group.lazy_init()
+
+    def _init_shared_state(self) -> None:
+        self._comm_ctx.init()
+        for state in self._state_ctx.all_states:
+            state._state_ctx = self._state_ctx
+            state._comm_ctx = self._comm_ctx
+            if fsdp_param_group := state._fsdp_param_group:
+                fsdp_param_group.comm_ctx = self._comm_ctx
+
+    def _init_fqns(self) -> None:
+        """Sets module and parameter FQN attributes for debugging."""
+        assert self._is_root
+        root_module = self._module
+        param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {}
+        module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {}
+        for state in self._state_ctx.all_states:
+            if fsdp_param_group := state._fsdp_param_group:
+                for fsdp_param in fsdp_param_group.fsdp_params:
+                    param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param
+                module_to_fsdp_param_group[fsdp_param_group.module] = fsdp_param_group
+        for param_name, param in root_module.named_parameters():
+            if param in param_to_fsdp_param:
+                param_to_fsdp_param[param]._param_fqn = param_name
+        for module_name, module in root_module.named_modules():
+            if module in module_to_fsdp_param_group:
+                module_to_fsdp_param_group[module]._module_fqn = module_name
+
+    def _pre_forward(
+        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+        # When composing with module-hook-based activation checkpointing, the
+        # pre-backward hook is responsible for the unshard
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            return args, kwargs
+        self._training_state = TrainingState.FORWARD
+        args, kwargs = self._root_pre_forward(module, args, kwargs)
+        if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
+            with torch.profiler.record_function("FSDP::cast_forward_inputs"):
+                cast_fn = functools.partial(
+                    _cast_fp_tensor, self._mp_policy.param_dtype
+                )
+                args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
+        if self._fsdp_param_group:
+            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
+        return args, kwargs
+
+    def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any:
+        # When composing with module-hook-based activation checkpointing, the
+        # post-backward hook is responsible for the reshard
+        if self._training_state == TrainingState.PRE_BACKWARD:
+            return output
+        if self._fsdp_param_group:
+            output = self._fsdp_param_group.post_forward(module, input, output)
+        output = self._register_pre_backward_hook(output)
+        self._training_state = TrainingState.IDLE
+        if self._state_ctx.iter_forward_root is self:
+            if all_gather_state := self._comm_ctx.all_gather_state:
+                # Free the last all-gather result if needed; refer to
+                # [Note: Overlapping all-gather copy-in and all-gather]
+                self._comm_ctx.all_gather_copy_in_stream.wait_event(
+                    all_gather_state.event
+                )
+                self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event)
+                self._comm_ctx.all_gather_state = None  # free the all-gather result
+            self._state_ctx.iter_forward_root = None
+        if self._mp_policy.output_dtype is not None:
+            with torch.profiler.record_function("FSDP::cast_forward_outputs"):
+                output = tree_map(
+                    functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype),
+                    output,
+                )
+        return output
+
+    def _pre_backward(self, forward_grad_fns: Tuple[Node, ...], *unused: Any) -> None:
+        self._training_state = TrainingState.PRE_BACKWARD
+        self._register_root_post_backward_final_callback()
+        if self._fsdp_param_group:
+            self._fsdp_param_group.pre_backward(forward_grad_fns, *unused)
+
+    def _root_post_backward_final_callback(self) -> None:
+        with torch.profiler.record_function("FSDP::root_post_backward_callback"):
+            for state in self._state_ctx.all_states:
+                if state._fsdp_param_group and state._fsdp_param_group.is_unsharded:
+                    # Run post-backward in case forward inputs did not require
+                    # gradient so the autograd backward did not run
+                    state._fsdp_param_group.post_backward()
+                if self._state_ctx.is_last_backward:
+                    state._finalize_backward()
+            if self._state_ctx.is_last_backward:
+                self._comm_ctx.post_forward_order.clear()
+            self._state_ctx.post_backward_final_callback_queued = False
+
+    def _finalize_backward(self) -> None:
+        self._training_state = TrainingState.IDLE
+        for handle in self._pre_backward_hook_handles:
+            handle.remove()
+        self._pre_backward_hook_handles.clear()
+        if self._fsdp_param_group:
+            self._fsdp_param_group.finalize_backward()
+
+    def _register_pre_backward_hook(self, output: Any) -> Any:
+        if not torch.is_grad_enabled():
+            return output
+
+        flat_outputs, _ = tree_flatten(output)
+        tensors = tuple(t for t in flat_outputs if t.requires_grad)
+        if tensors:
+            grad_fns = tuple(t.grad_fn for t in tensors if t.grad_fn is not None)
+            pre_backward = functools.partial(self._pre_backward, grad_fns)
+            handle = register_multi_grad_hook(tensors, pre_backward, mode="any")
+            self._pre_backward_hook_handles.append(handle)
+            if self._fsdp_param_group:
+                self._fsdp_param_group.all_forward_output_grad_fns.add(grad_fns)
+        return output
+
+    def _register_root_post_backward_final_callback(self):
+        if self._state_ctx.post_backward_final_callback_queued:
+            return
+        self._state_ctx.post_backward_final_callback_queued = True
+        Variable._execution_engine.queue_callback(
+            self._root_post_backward_final_callback
+        )
+
+
+def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]:
+    state = _get_module_state(module)
+    if isinstance(state, FSDPState):
+        return state
+    return None
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9962d01d7ca7877079bc4daacc5a203987fa484
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py
@@ -0,0 +1,246 @@
+from typing import Any, cast, Optional, Union
+
+import typing_extensions
+
+import torch
+import torch.nn as nn
+
+from torch.distributed._composable import contract
+from torch.distributed._tensor import DeviceMesh, DTensor
+
+from ._fsdp_api import MixedPrecisionPolicy
+from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
+from ._fsdp_init import (
+    _get_device_from_mesh,
+    _get_managed_modules,
+    _get_managed_states,
+    _get_post_forward_mesh_info,
+    _init_default_fully_shard_mesh,
+    _move_states_to_device,
+)
+from ._fsdp_param_group import FSDPParamGroup
+from ._fsdp_state import _get_module_fsdp_state, FSDPState
+
+
+# The decorator adds a state object to `module` that can be accessed via
+# `fully_shard.state(module)`. The state object and module are 1:1.
+@contract(state_cls=FSDPState)
+def fully_shard(
+    module: nn.Module,
+    *,
+    mesh: Optional[DeviceMesh] = None,
+    reshard_after_forward: Union[bool, int] = True,
+    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
+):
+    """
+    Shard module parameters across data parallel workers.
+
+    This function applies fully sharded data parallelism (FSDP) or a variant to
+    ``module``, a technique that saves memory at the cost of communication.
+    Parameters are sharded across ``mesh``, and in turn, so are their gradients
+    and optimizer states.
+
+    The sharded parameters are all-gathered to construct the unsharded
+    parameters for forward or backward computation. The unsharded parameters
+    are freed after computation to save memory. The gradients are reduced
+    across the mesh and divided by the mesh size for data parallelism. The
+    optimizer step runs on the sharded parameters.
+
+    Each call to ``fully_shard`` constructs one communication group that
+    includes the parameters in ``module.parameters()`` except those already
+    assigned to a group from a nested call. Each group's parameters are
+    communicated together in one collective, and likewise for its gradients.
+    Constructing multiple groups across the model (e.g. "layer by layer")
+    allows for peak memory savings and communication/computation overlap.
+
+    Implementation-wise, the sharded parameters are represented as
+    :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
+    represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
+    parameters, and a module forward hook frees them. Similar backward hooks
+    gather parameters and later free parameters/reduce gradients.
+
+    Args:
+        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
+            sharding and device. If 1D, then parameters are fully sharded
+            across the 1D mesh (FSDP). If 2D, then parameters are sharded
+            across the 0th dim and replicated across the 1st dim (HSDP). The
+            mesh's device type gives the device type used for communication;
+            if a CUDA or CUDA-like device type, then we use the current device.
+        reshard_after_forward (Union[bool, int]): This controls the parameter
+            behavior after forward and can trade off memory and communication:
+            - If ``True``, then this reshards parameters after forward and
+            all-gathers in backward.
+            - If ``False``, then this keeps the unsharded parameters in memory
+            after forward and avoids the all-gather in backward.
+            - If an ``int``, then this represents the world size to reshard to
+            after forward. It should be a non-trivial divisor of the ``mesh``
+            shard dim size (i.e. excluding 1 and the dim size itself). A choice
+            may be the intra-node size (e.g. ``torch.cuda.device_count()``).
+            This allows the all-gather in backward to be over a smaller world
+            size at the cost of higher memory usage than setting to ``True``.
+            - The root FSDP state has its value specially set to ``False`` as a
+            heuristic since its parameters would typically be immediately
+            all-gathered for backward.
+            - After forward, the parameters registered to the module depend on
+            this: The registered parameters are the sharded parameters if
+            ``True``; unsharded parameters if ``False``; and the parameters
+            resharded to the smaller mesh otherwise. To modify the parameters
+            between forward and backward, the registered parameters must be the
+            sharded parameters. For ``False`` or an ``int``, this can be done
+            by manually resharding via :meth:`reshard`.
+        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
+            policy, which offers parameter/reduction mixed precision for this
+            module. See :class:`MixedPrecisionPolicy` for details.
+    """
+    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
+        raise ValueError(
+            f"fully_shard does not support containers that do not implement forward: {module}"
+        )
+    mesh = mesh or _init_default_fully_shard_mesh()
+    if mesh.ndim not in (1, 2):
+        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
+    elif mesh.ndim == 1:
+        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
+    else:
+        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
+    device = _get_device_from_mesh(mesh)
+    post_forward_mesh_info = _get_post_forward_mesh_info(
+        reshard_after_forward, mesh_info
+    )
+
+    state = fully_shard.state(module)
+    state.init(module, device, mp_policy)
+
+    managed_modules = _get_managed_modules(module)
+    params, buffers = _get_managed_states(managed_modules)
+    _move_states_to_device(params, buffers, device, mesh_info)
+    if params:
+        state._fsdp_param_group = FSDPParamGroup(
+            params, module, mesh_info, post_forward_mesh_info, device, mp_policy
+        )
+
+    # for dynamo
+    for module in managed_modules:
+        module._is_fsdp_managed_module = True  # type: ignore[assignment]
+        module._fsdp_use_orig_params = True  # type: ignore[assignment]
+
+    # Place FSDP leftmost for highest priority in the method resolution order
+    cls = module.__class__
+    dct = {"__deepcopy__": unimplemented_deepcopy}
+    new_cls = type(f"FSDP{cls.__name__}", (FSDP, cls), dct)
+    module.__class__ = new_cls
+    return module
+
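+# Illustrative usage sketch (``Transformer``, ``dataloader``, and the training
+# loop below are placeholders, not part of this module): apply ``fully_shard``
+# per block and then to the root so that each block forms its own
+# communication group:
+#
+#     model = Transformer()
+#     for block in model.blocks:
+#         fully_shard(block, reshard_after_forward=True)
+#     fully_shard(model)
+#     optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
+#     for batch in dataloader:
+#         loss = model(batch).sum()
+#         loss.backward()
+#         optim.step()
+#         optim.zero_grad()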
+
+def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> typing_extensions.Never:
+    raise AssertionError(
+        "FSDP does not support deepcopy. Please use state dict for serialization."
+    )
+
+
+class FSDP:
+    def __new__(cls, *args, **kwargs):
+        """
+        Override ``__new__`` to remove the FSDP class and directly construct
+        the original class for cases like indexing into a container module.
+        """
+        # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
+        # and index 1 is the `FSDP` class itself
+        orig_cls = cls.__mro__[2]
+        self = orig_cls.__new__(orig_cls, *args, **kwargs)
+        self.__init__(*args, **kwargs)
+        return self
+
+    def reshard(self) -> None:
+        """
+        Reshards the module's parameters, registering the sharded parameters
+        to the module and freeing the unsharded parameters if needed. This
+        method is *not* recursive.
+        """
+        state = self._get_fsdp_state()
+        if fsdp_param_group := state._fsdp_param_group:
+            fsdp_param_group.reshard()
+
+    def set_is_last_backward(self, is_last_backward: bool) -> None:
+        """
+        Sets whether the next backward is the last one, meaning that FSDP
+        should wait for gradient reduction to finish and clear internal data
+        structures used for explicit prefetching.
+        """
+        state = self._get_fsdp_state()
+        state._state_ctx.is_last_backward = is_last_backward
+
+    def set_requires_gradient_sync(
+        self, requires_gradient_sync: bool, recurse: bool = True
+    ) -> None:
+        """
+        Sets whether the module should sync gradients. This can be used to implement
+        gradient accumulation without communication. For HSDP, this controls
+        both reduce-scatter and all-reduce together.
+
+        Args:
+            requires_gradient_sync (bool): Whether to reduce gradients for the
+                module's parameters.
+            recurse (bool): Whether to set for all submodules or just the
+                passed-in module.
+        """
+        for module in cast(nn.Module, self).modules():
+            if isinstance(module, FSDP):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.reduce_scatter_grads = requires_gradient_sync
+                    fsdp_param_group.all_reduce_grads = requires_gradient_sync
+
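+    # Illustrative gradient-accumulation sketch (``model`` and ``microbatches``
+    # are placeholders): skip gradient reduction on all but the last
+    # microbatch so gradients accumulate locally without communication:
+    #
+    #     for i, microbatch in enumerate(microbatches):
+    #         model.set_requires_gradient_sync(i == len(microbatches) - 1)
+    #         model(microbatch).sum().backward()
+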
+    def set_requires_all_reduce(self, requires_all_reduce: bool, recurse: bool = True):
+        """
+        Sets whether the module should all-reduce gradients. This can be used to
+        implement gradient accumulation with only reduce-scatter but not
+        all-reduce for HSDP.
+        """
+        for module in cast(nn.Module, self).modules():
+            if isinstance(module, FSDP):
+                state = module._get_fsdp_state()
+                if fsdp_param_group := state._fsdp_param_group:
+                    fsdp_param_group.all_reduce_grads = requires_all_reduce
+
+    def _get_fsdp_state(self) -> FSDPState:
+        if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
+            raise AssertionError(f"No FSDP state found on {self}")
+        return state
+
+    def _apply(self, *args: Any, **kwargs: Any) -> Any:
+        # Reshard to ensure that sharded parameters are registered
+        self.reshard()
+        ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
+        state = self._get_fsdp_state()
+        if not (fsdp_param_group := state._fsdp_param_group):
+            return ret
+        # TODO: Remove this padding logic once DTensor pads the local tensor:
+        # https://github.com/pytorch/pytorch/issues/113045
+        with torch.no_grad():
+            for fsdp_param in fsdp_param_group.fsdp_params:
+                module_info = fsdp_param._module_info
+                new_param = getattr(module_info.module, module_info.param_name)
+                if new_param is not fsdp_param.sharded_param:
+                    if torch.__future__.get_swap_module_params_on_conversion():
+                        raise AssertionError(
+                            "Expects swap_tensors to preserve object but got "
+                            f"{new_param} instead of {fsdp_param.sharded_param}"
+                        )
+                    else:
+                        raise AssertionError(
+                            "Please set torch.__future__.set_swap_module_params_on_conversion(True) "
+                            "to use _apply methods with FSDP"
+                        )
+                local_tensor = new_param._local_tensor
+                padded_sharded_size = fsdp_param.padded_sharded_param_size
+                if local_tensor.size() != padded_sharded_size:
+                    padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
+                    padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor)
+                    local_tensor = padded_local_tensor
+                fsdp_param._sharded_param_data = local_tensor.view(-1)
+                assert isinstance(fsdp_param.sharded_param, DTensor)  # mypy
+                fsdp_param.sharded_param._local_tensor = local_tensor[
+                    : fsdp_param.sharded_size[0]
+                ]
+        return ret
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/fully_shard.py b/MLPY/Lib/site-packages/torch/distributed/_composable/fully_shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ee3b54fbb101861fb0886a32369ba6cfb40c43
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/fully_shard.py
@@ -0,0 +1,133 @@
+import warnings
+from typing import Callable, Iterable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed._composable.contract import contract
+from torch.distributed._composable_state import _get_module_state, _insert_module_state
+from torch.distributed.fsdp._common_utils import _FSDPState
+from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
+
+from torch.distributed.fsdp._init_utils import (
+    _init_buffer_state,
+    _init_core_state,
+    _init_device_handle,
+    _init_ignored_module_states,
+    _init_param_handle_from_module,
+    _init_prefetching_state,
+    _init_process_group_state,
+    _init_runtime_state,
+    _init_state_dict_state,
+    HYBRID_SHARDING_STRATEGIES,
+)
+from torch.distributed.fsdp._runtime_utils import (
+    _register_post_forward_hook,
+    _register_pre_forward_hook,
+    _register_root_pre_forward_hook,
+)
+from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
+from torch.distributed.fsdp._wrap_utils import _auto_wrap
+from torch.distributed.fsdp.api import (
+    BackwardPrefetch,
+    CPUOffload,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import _Policy
+
+
+@contract(state_cls=_FSDPState)
+def fully_shard(
+    module: nn.Module,
+    *,
+    process_group: Optional[dist.ProcessGroup] = None,
+    policy: Optional[_Policy] = None,
+    strategy: Optional[ShardingStrategy] = None,
+    mixed_precision: Optional[MixedPrecision] = None,
+    cpu_offload: Optional[CPUOffload] = None,
+    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+    device_id: Optional[Union[int, torch.device]] = None,
+    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+    sync_module_states: bool = False,
+    forward_prefetch: bool = False,
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
+) -> nn.Module:
+    """
+    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
+    """
+    warnings.warn(
+        "``torch.distributed._composable.fully_shard`` is being deprecated."
+        "You can contintue to use the wrapper based FSDP."
+        "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py."
+        "``torch.distributed._composable.fully_shard`` will be removed after PyTorch 2.5."
+    )
+
+    torch._C._log_api_usage_once("torch.distributed.fully_shard")
+    # Enforce the new auto wrap policy
+    if policy is not None and not isinstance(policy, _Policy):
+        raise ValueError(f"Expects a `_Policy` but got {policy}")
+    state = fully_shard.state(module)
+    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
+    state = _init_device_handle(state, module, state._ignored_params, device_id)
+    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
+    state = _init_process_group_state(state, process_group, strategy, policy)
+    if policy is not None:
+        root_kwargs = {
+            "process_group": process_group,
+            "strategy": strategy,
+            "mixed_precision": mixed_precision,
+            "cpu_offload": cpu_offload,
+            "ignored_modules": ignored_modules,
+            "device_id": device_id,
+            "param_init_fn": param_init_fn,
+            "sync_module_states": sync_module_states,
+            "forward_prefetch": forward_prefetch,
+            "ignored_states": ignored_states,
+        }
+        if strategy in HYBRID_SHARDING_STRATEGIES:
+            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
+        _auto_wrap(
+            module,
+            policy,
+            state._ignored_modules,
+            state._ignored_params,
+            root_kwargs,
+            fully_shard,
+        )
+    state = _init_core_state(
+        state,
+        strategy or ShardingStrategy.FULL_SHARD,
+        mixed_precision,
+        cpu_offload,
+        limit_all_gathers=True,
+        use_orig_params=True,
+        backward_prefetch_limit=1,
+        forward_prefetch_limit=1,
+    )
+    state = _init_runtime_state(state)
+    state = _init_prefetching_state(
+        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
+    )
+    state = _init_buffer_state(state, module)
+    state = _init_param_handle_from_module(
+        state, module, device_id, param_init_fn, sync_module_states
+    )
+    state = _init_state_dict_state(state)
+    _register_all_state_dict_hooks(state)
+    _register_pre_forward_hook(state, module)
+    _register_post_forward_hook(state, module)
+    _register_root_pre_forward_hook(state, module)  # prepend last
+    # Always insert the state for the passed-in module even if it has no
+    # managed parameters, in which case it has no handles and does not appear
+    # in `_fully_sharded_module_to_handles`
+    _insert_module_state(module, state)
+    for submodule in module.modules():
+        if (
+            submodule in state._fully_sharded_module_to_handle
+            and _get_module_state(submodule) is None
+        ):
+            _insert_module_state(submodule, state)
+    return module
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable/replicate.py b/MLPY/Lib/site-packages/torch/distributed/_composable/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d67f9f4201ba3bf1450f2f9fa59920cb783dcca
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable/replicate.py
@@ -0,0 +1,154 @@
+import weakref
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+
+import torch
+import torch.nn as nn
+from torch.distributed._composable_state import _State
+from torch.nn.parallel import DistributedDataParallel
+
+from .contract import _get_registry, contract
+
+_ROOT_MODULE_PREFIX = ""
+
+
+class _ReplicateState(_State):
+    def __init__(self) -> None:
+        super().__init__()
+        self.module: nn.Module = nn.ParameterList()
+        self.has_initialized: bool = False
+        self._param_list: nn.ParameterList = nn.ParameterList()
+        # TODO(@fegin): this variable was originally created for testing; we
+        # should remove it if possible.
+        self._param_names: List[str] = []
+
+    def _collect_params(
+        self,
+        module: nn.Module,
+        ignored_modules: Set[nn.Module],
+        ignored_params: Set[nn.Parameter],
+        prefix: str = _ROOT_MODULE_PREFIX,
+    ) -> None:
+        # skip if managed by fully_sharded API
+        if _is_fully_sharded(module):
+            return
+
+        # if a module is ignored, all descendants of the module are ignored.
+        if module in ignored_modules:
+            return
+
+        recurse_prefix = (
+            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
+        )
+
+        for n, p in module.named_parameters(recurse=False):
+            if p not in ignored_params:
+                self._param_list.append(p)
+                self._param_names.append(f"{recurse_prefix}{n}")
+
+        for name, child_module in module.named_children():
+            self._collect_params(
+                child_module,
+                ignored_modules,
+                ignored_params,
+                prefix=f"{recurse_prefix}{name}",
+            )
+
+    def init(
+        self,
+        module: nn.Module,
+        ignored_modules: Set[nn.Module],
+        **kwargs,
+    ) -> None:
+        if _is_fully_sharded(module):
+            raise RuntimeError(
+                "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
+            )
+
+        if self.has_initialized:
+            return
+
+        self.has_initialized = True
+        self.module = module
+        ignored_params = {p for m in ignored_modules for p in m.parameters()}
+        self._collect_params(module, ignored_modules, ignored_params)
+        module.register_forward_pre_hook(self.forward_pre_hook, with_kwargs=True)
+        module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
+
+        if "device_id" in kwargs:
+            # replicate() supports a small usability enhancement where the
+            # user can pass in device_id as a Union[int, torch.device] even for
+            # CPU devices, so users don't have to change code for CPU/GPU runs.
+            # We derive the right device_ids to feed into DDP to support this.
+            if kwargs["device_id"] is not None:
+                device_id = kwargs["device_id"]
+                # Convert to device_ids that DDP expects.
+                if isinstance(device_id, torch.device) and device_id.type == "cpu":
+                    # CPU modules receive device_ids None
+                    kwargs["device_ids"] = None
+                else:
+                    # GPU modules expect device_ids=[cuda_device]
+                    kwargs["device_ids"] = [device_id]
+            else:
+                kwargs["device_ids"] = None
+            kwargs.pop("device_id")
+
+        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
+        # Weakref to the DDP instance is currently only used for testing.
+        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
+
+    def forward_pre_hook(
+        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Any:
+        return self._ddp._pre_forward(*args, **kwargs)
+
+    def forward_post_hook(
+        self,
+        module: nn.Module,
+        input: Tuple[torch.Tensor],
+        output: torch.Tensor,
+    ) -> torch.Tensor:
+        return self._ddp._post_forward(output)
+
+
+@contract(state_cls=_ReplicateState)
+def replicate(
+    module: nn.Module,
+    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+    **kwargs,
+) -> nn.Module:
+    r"""Replicates a module
+
+    Args:
+        module (torch.nn.Module): module to replicate
+
+    Example::
+        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
+        >>> module = nn.Linear(3, 3)
+        >>> replicate(module)
+    """
+    torch._C._log_api_usage_once("torch.distributed.replicate")
+
+    # TODO(fegin): using kwargs is not a good idea if we would like to make
+    # replicate a formal API to replace DDP.
+    if "device_id" in kwargs:
+        if not isinstance(kwargs["device_id"], (int, torch.device)):
+            raise RuntimeError(
+                "Expected device_id to be int or torch.device, "
+                f"but got {type(kwargs['device_id'])}"
+            )
+
+    if ignored_modules is None:
+        ignored_modules = {}
+    else:
+        ignored_modules = set(ignored_modules)
+    replicate.state(module).init(module, ignored_modules, **kwargs)
+
+    return module
+
+
+def _is_fully_sharded(module: nn.Module) -> bool:
+    r"""Check if module is marked with fully_shard."""
+    registry = _get_registry(module)
+    if registry is None:
+        return False
+    return "fully_shard" in registry
diff --git a/MLPY/Lib/site-packages/torch/distributed/_composable_state.py b/MLPY/Lib/site-packages/torch/distributed/_composable_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..5095fd4424838902d80594f4c9b4e53852990ace
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_composable_state.py
@@ -0,0 +1,37 @@
+from typing import cast, Dict, Optional
+
+import torch.nn as nn
+
+
+class _State:
+    pass
+
+
+_module_state_mapping: Dict[nn.Module, _State] = {}
+
+
+def _insert_module_state(module: nn.Module, state: _State) -> None:
+    global _module_state_mapping
+    assert module not in _module_state_mapping, f"Inserting {module} more than once."
+    _module_state_mapping[module] = state
+
+
+def _get_module_state(module: nn.Module) -> Optional[_State]:
+    """
+    Return the ``_State`` associated with ``module``.
+
+    Given a ``module``, this API finds out if the module is also a ``_State``
+    instance or if the module is managed by a composable API. If the module
+    is also a ``_State``, ``module`` will be cast to ``_State`` and returned.
+    If it is managed by a composable API, the corresponding ``_State`` will
+    be returned.
+    """
+    global _module_state_mapping
+    if isinstance(module, _State):
+        return cast(_State, module)
+    else:
+        # https://github.com/pytorch/pytorch/issues/107054
+        if module in _module_state_mapping:
+            return _module_state_mapping[module]
+        else:
+            return None
diff --git a/MLPY/Lib/site-packages/torch/distributed/_functional_collectives.py b/MLPY/Lib/site-packages/torch/distributed/_functional_collectives.py
new file mode 100644
index 0000000000000000000000000000000000000000..84c154f90b8889a80116d779e75bd95341e43d14
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_functional_collectives.py
@@ -0,0 +1,1084 @@
+import sys
+import warnings
+from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union
+
+import torch
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as c10d
+from torch._custom_ops import impl_abstract
+from torch.distributed.device_mesh import DeviceMesh
+from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
+
+from . import _functional_collectives_impl as fun_col_impl
+from ._functional_collectives_impl import (  # noqa: F401
+    _register_tensor_wrapper,
+    native_funcol_enabled,
+)
+
+try:
+    from torch.utils._cxx_pytree import tree_map_only
+except ImportError:
+    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]
+
+
+if torch._running_with_deploy():
+
+    def is_torchdynamo_compiling():
+        """Can't import torchdynamo in torchdeploy builds currently."""
+        return False
+
+else:
+    try:
+        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
+    except Exception:
+        warnings.warn(
+            "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
+        )
+
+        def is_torchdynamo_compiling():
+            return False
+
+
+"""
+New traceable, functional collectives.
+RFC: https://github.com/pytorch/pytorch/issues/93173
+
+  compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
+  eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
+         automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
+         a downstream op.
+
+Issues:
+* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
+* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
+"""
+
+"""
+Functional collectives are asynchronous only and we perform implicit stream synchronization
+on behalf of the user.
+
+We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
+first usage of the tensor and insert cross stream sync at the right place.
+
+The above are the easy bits; the hard one is how we match the Work object returned by
+c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
+op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
+dispatcher which might call other implementations that are allowed to change the returned
+tensor - even return a tensor with a different shape (see ``torch.vmap``).
+
+This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
+allocated by our implementations, and that makes pairing the AsyncTensor with the original
+tensor a lot harder. This pairing is needed so we can look up the Work object to use.
+
+Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
+identity is not stable across dispatch, the op caller would end up with a different Tensor
+instance that would not match any in the dictionary.
+
+With Tensor identity out of the question, we decided to use the tensor data pointer, which
+should be stable across all the Tensor changes done during dispatch.
+
+We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
+
+We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
+
+Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
+can clean up stale entries in the dictionary.
+
+To eliminate the possibility of races we have a global version counter that is used by the finalizer.
+
+As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
+
+"""
+
+"""
+Functional collectives can accept any of these types to describe the ranks participating in collectives.
+
+The different types will be desugared to a canonical format
+"""
+RANK_TYPES = Union[
+    List[int],
+    List[List[int]],
+    dist.ProcessGroup,
+    DeviceMesh,
+    Tuple["dist._tensor.DeviceMesh", int],
+    str,
+]
+
+
+"""
+User facing APIs for functional collectives
+-------------------------------------------
+
+These apis are called by user code and expected to work both in eager execution and compilation,
+but there are significant differences to how the two modes are implemented underneath.
+
+Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
+just before the tensor is first used.  Compiled tracing currently relies on the compiler to perform this optimization,
+and cannot yet correctly trace the AsyncTensor wrapper class.  In the future, these paths may be unified
+if sufficient subclass support is added in dynamo.
+
+Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
+
+Here's how it works under torch.compile/dynamo:
+all_reduce(...)
+  |--> _expand_group(...)               - desugars processgroup into canonical/traceable format
+  |--> c10d_functional.all_reduce(...)  - dynamo captures this op call, doesn't trace deeper
+  |--> _maybe_wrap_tensor(...)          - wait_tensor() op is immediately called, no AsyncTensor subclass needed
+
+And under eager execution:
+all_reduce(...)
+  |--> _expand_group(...)               - same as above, but less critical for eager
+  |--> c10d_functional.all_reduce(...)  - dispatches to real kernel OR records op in trace
+  |--> _maybe_wrap_tensor(...)          - AsyncTensor wrapper applied to returned tensor,
+                                          which issues wait_tensor() at the time of first use
+"""
+
+
+def wait_tensor(tensor):
+    """
+    Wait on a tensor returned by the collectives ops.
+
+    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
+    """
+    if native_funcol_enabled():
+        return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
+    else:
+        return torch.ops.c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
+
+
+def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
+    """
+    Broadcasts the tensor to all processes in the given process group.
+
+    Args:
+        src (int): Source rank
+        group (ProcessGroup or List[int]): The process group to work on.
+        tag (str, optional): A unique identifier for the collective. Default: empty string
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor = torch.ops.c10d_functional.broadcast(
+            self, src, tag, rankset, group_size
+        )
+    return _maybe_wrap_tensor(tensor)
+
+
+def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
+    """
+    Reduces the tensor data across all machines in such a way that all get
+    the final result.
+
+    The input tensor is left unmodified.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        tensor = torch.ops._c10d_functional.all_reduce(
+            self, reduceOp.lower(), group_name
+        )
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor = torch.ops.c10d_functional.all_reduce(  # type: ignore[attr-defined]
+            self,
+            reduceOp,
+            tag,
+            rankset,
+            group_size,
+        )
+    return _maybe_wrap_tensor(tensor)
+
+
+def all_gather_tensor(
+    self: torch.Tensor,
+    gather_dim: int,
+    group: RANK_TYPES,
+    tag: str = "",
+):
+    """
+    Gather tensor data from all machines and concatenate the result over ``gather_dim``.
+
+    Note that it currently only supports gather_dim = 0.
+
+    The input tensor is left unmodified.
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    assert self.is_contiguous()
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        group_size = c10d._get_group_size_by_name(group_name)
+        tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+            self, group_size, group_name
+        )
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor = torch.ops.c10d_functional.all_gather_into_tensor(  # type: ignore[attr-defined]
+            self,
+            tag,
+            rankset,
+            group_size,
+        )
+    res = _maybe_wrap_tensor(tensor)
+    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
+    if gather_dim != 0:
+        # torch.cat accesses the data, so we need to wait here anyway; waiting first
+        # and then doing chunk + cat avoids going through ACT dispatching logic again
+        if isinstance(res, AsyncCollectiveTensor):
+            res = res.wait()  # type: ignore[attr-defined]
+        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
+    return res
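+
+
+# Illustrative sketch (not part of the library): gathering a contiguous per-rank
+# shard and concatenating along a non-zero dimension, which exercises the
+# chunk + cat path above. Assumes an initialized default group; the helper name
+# is hypothetical.
+def _example_all_gather_dim1(local_shard: torch.Tensor) -> torch.Tensor:
+    ws = dist.get_world_size()
+    ranks = list(range(ws))
+    # A (r, c) input per rank yields a (r, c * ws) output on every rank.
+    return all_gather_tensor(local_shard, gather_dim=1, group=ranks)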
+
+
+def reduce_scatter_tensor(
+    self: torch.Tensor,
+    reduceOp: str,
+    scatter_dim: int,
+    group: RANK_TYPES,
+    tag: str = "",
+):
+    """
+    Reduces the tensor data across all machines in such a way that all get
+    the final result, then scatters the results to the corresponding ranks.
+
+
+    The input tensor is left unmodified.
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        group_size = c10d._get_group_size_by_name(group_name)
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+
+    assert (
+        self.size(scatter_dim) % group_size == 0
+    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
+    if scatter_dim != 0:
+        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
+        self = torch.cat(tensor_list)
+
+    if native_funcol_enabled():
+        tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
+            self,
+            reduceOp.lower(),
+            group_size,
+            group_name,  # type: ignore[possibly-undefined]
+        )
+    else:
+        tensor = torch.ops.c10d_functional.reduce_scatter_tensor(  # type: ignore[attr-defined]
+            self,
+            reduceOp,
+            tag,
+            rankset,  # type: ignore[possibly-undefined]
+            group_size,
+        )
+    res = _maybe_wrap_tensor(tensor)
+    return res
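+
+
+# Illustrative sketch (not part of the library): reduce_scatter_tensor requires
+# the scatter dimension to be divisible by the group size, as asserted above.
+# Assumes an initialized default group; the helper name is hypothetical.
+def _example_reduce_scatter(t: torch.Tensor) -> torch.Tensor:
+    ws = dist.get_world_size()
+    ranks = list(range(ws))
+    # t.size(0) must be a multiple of ws; each rank receives a slice of size
+    # t.size(0) // ws holding the element-wise sum across all ranks.
+    return reduce_scatter_tensor(t, "sum", scatter_dim=0, group=ranks)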
+
+
+def all_reduce_coalesced(
+    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
+) -> List[torch.Tensor]:
+    """
+    Reduces a list of tensors across all machines in such a way that all get
+    the final result.
+
+    All tensors in the input list are left unmodified.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
+            self,
+            reduceOp.lower(),
+            group_name,
+        )
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor_list = torch.ops.c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
+            self,
+            reduceOp,
+            tag,
+            rankset,
+            group_size,
+        )
+    return list(map(_maybe_wrap_tensor, tensor_list))
+
+
+def all_gather_into_tensor_coalesced(
+    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
+) -> List[torch.Tensor]:
+    """
+    Gather a list of tensors from all machines.
+
+    Note that it currently only supports gather_dim = 0.
+
+    The input tensors are left unmodified.
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        group_size = c10d._get_group_size_by_name(group_name)
+        tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
+            self,
+            group_size,
+            group_name,
+        )
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
+            self,
+            tag,
+            rankset,
+            group_size,
+        )
+    return list(map(_maybe_wrap_tensor, tensor_list))
+
+
+def reduce_scatter_tensor_coalesced(
+    inputs: List[torch.Tensor],
+    reduceOp: str,
+    scatter_dim: List[int],
+    group: RANK_TYPES,
+    tag: str = "",
+) -> List[torch.Tensor]:
+    """
+    Reduces a list of tensors across all machines in such a way that all get
+    the final result, then scatters the results to the corresponding ranks.
+
+    The input tensors are left unmodified.
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        group_size = c10d._get_group_size_by_name(group_name)
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+
+    assert len(scatter_dim) == len(inputs)
+    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
+        assert (
+            tensor.size(dim) % group_size == 0
+        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
+        if dim != 0:
+            tensor_list = torch.chunk(tensor, group_size, dim=dim)
+            inputs[idx] = torch.cat(tensor_list)
+
+    if native_funcol_enabled():
+        tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
+            inputs,
+            reduceOp.lower(),
+            group_size,
+            group_name,  # type: ignore[possibly-undefined]
+        )
+    else:
+        tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
+            inputs,
+            reduceOp,
+            tag,
+            rankset,  # type: ignore[possibly-undefined]
+            group_size,
+        )
+
+    return list(map(_maybe_wrap_tensor, tensor_list))
+
+
+# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
+# Today, this maps 1:1 with "aten ops that are views".
+def _is_view_op(tgt):
+    assert isinstance(tgt, torch._ops.OpOverload)
+    schema = tgt._schema
+    if len(schema.arguments) > 0:
+        first_arg = schema.arguments[0]
+        # check if op is a view
+        return first_arg.alias_info is not None and not first_arg.alias_info.is_write
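+
+
+# Illustrative sketch (not part of the library): how the alias-info check above
+# classifies two well-known aten overloads; the helper name is hypothetical.
+def _example_is_view_op_check():
+    # aten.view aliases its first argument without writing to it, so it counts as a view.
+    assert _is_view_op(torch.ops.aten.view.default)
+    # aten.add.Tensor does not alias its first argument, so it does not.
+    assert not _is_view_op(torch.ops.aten.add.Tensor)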
+
+
+def all_to_all_single(
+    self: torch.Tensor,
+    output_split_sizes: Optional[List[int]],
+    input_split_sizes: Optional[List[int]],
+    group: RANK_TYPES,
+    tag: str = "",
+) -> torch.Tensor:
+    """
+    Each process splits its input tensor and scatters the resulting chunks
+    to all processes in the group, then concatenates the chunks received from
+    all processes in the group into a single output tensor, which is returned.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    if output_split_sizes is not None:
+        assert all(
+            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
+        ), output_split_sizes
+    if input_split_sizes is not None:
+        assert all(
+            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
+        ), input_split_sizes
+    if native_funcol_enabled():
+        group_name = _resolve_group_name(group, tag)
+        group_size = c10d._get_group_size_by_name(group_name)
+        if output_split_sizes is None or input_split_sizes is None:
+            assert output_split_sizes is None and input_split_sizes is None, (
+                "output_split_sizes and input_split_sizes must either be "
+                "specified together or both set to None"
+            )
+            output_split_sizes = [self.shape[0] // group_size] * group_size
+            input_split_sizes = output_split_sizes
+        tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
+            self,
+            output_split_sizes,
+            input_split_sizes,
+            group_name,
+        )
+    else:
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor = torch.ops.c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
+            self,
+            output_split_sizes,
+            input_split_sizes,
+            tag,
+            rankset,
+            group_size,
+        )
+    return _maybe_wrap_tensor(tensor)
+
+
+def permute_tensor(
+    self: torch.Tensor,
+    src_dst: List[int],
+    group: RANK_TYPES,
+    tag: str = "",
+) -> torch.Tensor:
+    """
+    Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
+    be defined such that src_dst[m] == n means m sends to n.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do a SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do a MPMD collective over one
+    """
+    t, rankset, group_size = _expand_group(group, tag)
+    local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
+
+    output_split_sizes = [0] * group_size
+    input_split_sizes = [0] * group_size
+    for src, dst in enumerate(src_dst):
+        if src == dist.get_rank(local_pg):
+            input_split_sizes[dst] = self.numel()
+        if dst == dist.get_rank(local_pg):
+            output_split_sizes[src] = self.numel()
+
+    return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
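+
+
+# Illustrative sketch (not part of the library): using permute_tensor for a ring
+# shift in which each rank m sends its tensor to rank (m + 1) % world_size.
+# Assumes an initialized default group; the helper name is hypothetical.
+def _example_ring_shift(t: torch.Tensor) -> torch.Tensor:
+    ws = dist.get_world_size()
+    ranks = list(range(ws))
+    # src_dst[m] == n means rank m sends to rank n.
+    src_dst = [(m + 1) % ws for m in range(ws)]
+    return permute_tensor(t, src_dst, ranks)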
+
+
+class AsyncCollectiveTensor(torch.Tensor):
+    r"""
+    A Tensor wrapper subclass that is used to trigger a call to wait
+    prior to first use of the underlying tensor.
+    Use it inside functional collective pytorch wrappers like the following:
+    def functional_collective(self, group, tag):
+        tag, rankset, group_size = _expand_group(group, tag)
+        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
+        return _maybe_wrap_tensor(tensor)
+    """
+    elem: torch.Tensor
+    completed: bool
+
+    __slots__ = ["elem", "completed"]
+
+    @staticmethod
+    def __new__(cls, elem: torch.Tensor):
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            elem.size(),
+            strides=elem.stride(),
+            storage_offset=elem.storage_offset(),
+            dtype=elem.dtype,
+            layout=elem.layout,
+            device=elem.device,
+            requires_grad=False,
+        )
+        r.elem = elem
+        r.completed = False
+        return r
+
+    def __tensor_flatten__(self):
+        return ["elem"], None
+
+    def tolist(self):
+        self.trigger_wait()
+        return self.elem.tolist()
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        elem = inner_tensors["elem"]
+        return AsyncCollectiveTensor(elem)
+
+    def __repr__(self):
+        self.trigger_wait()
+        return f"AsyncCollectiveTensor({self.elem})"
+
+    def trigger_wait(self):
+        if not self.completed:
+            wait_tensor(self.elem)
+            self.completed = True
+        return self.elem
+
+    def wait(self) -> torch.Tensor:
+        wait_tensor(self.elem)
+        return self.elem
+
+    def _get_acs_underlying_tensor(self):
+        """This method enables _functional_collectives_impl to test whether a tensor is an AsyncCollectiveTensor."""
+        return self.elem
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if func == torch.ops.aten.view.default:
+            # Fast-path aten.view, since a lot of view-related ops eventually go
+            # to aten.view; this avoids the pytree slowdown
+            res = func(args[0].elem, args[1])
+            wrapper_res = AsyncCollectiveTensor(res)
+            _register_tensor_wrapper(wrapper_res)
+            return wrapper_res
+
+        is_view_op = _is_view_op(func)
+
+        def unwrap(e: AsyncCollectiveTensor):
+            # wait_tensor is idempotent and will do stream sync only once
+            if not is_view_op:
+                e.trigger_wait()
+            return e.elem
+
+        def wrap(e: torch.Tensor):
+            # wait_tensor is idempotent and will do stream sync only once
+            assert not isinstance(e, AsyncCollectiveTensor)
+            res = AsyncCollectiveTensor(e)
+            _register_tensor_wrapper(res)
+            return res
+
+        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
+        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
+
+        # we don't wrap the result as it doesn't need to be waited on.
+        out = func(*unwrapped_args, **unwrapped_kwargs)
+
+        # View ops don't require a sync, so we should re-wrap the outputs.
+        if is_view_op:
+            out = tree_map_only(torch.Tensor, wrap, out)
+
+        return out
+
+    def numpy(self):
+        return self.wait().numpy()
+
+
+"""
+Utils and infrastructure for tracing support
+"""
+
+
+def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
+    """
+    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.
+
+    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
+    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
+    """
+    # had to define this hack _inside_ expand_group to avoid
+    # graph_break [('torch.* op returned non-Tensor int
+    # caused by `cast_*` functions being treated as 'torch.*' ops (iiuc)
+    if TYPE_CHECKING:
+
+        def cast_listlistint(x):
+            return cast(List[List[int]], x)
+
+        def cast_listint(x):
+            return cast(List[int], x)
+
+    else:
+        # fake cast op for use at runtime since dynamo doesn't support real cast
+        # also, dynamo didn't like encountering 'typing' objects ()
+        # NotImplementedError: argument of type: 
+        def cast_listlistint(x):
+            return x
+
+        def cast_listint(x):
+            return x
+
+    rankset: List[int]
+    if isinstance(group, list):
+        if isinstance(group[0], list):
+            nested_list = cast_listlistint(group)
+            rankset = []
+            group_size = -1
+            for rs in nested_list:
+                rankset.extend(rs)
+                if group_size != -1 and group_size != len(rs):
+                    raise ValueError(
+                        f"group sizes must be identical found {group_size} and {len(rs)}"
+                    )
+                group_size = len(rs)
+        else:
+            rankset = cast_listint(group)
+            group_size = len(rankset)
+    elif isinstance(group, dist.ProcessGroup):
+        rankset = dist.get_process_group_ranks(group)
+        group_size = len(rankset)
+        tag = tag or c10d._get_group_tag(group)
+    elif isinstance(group, DeviceMesh):
+        assert (
+            group.ndim == 1
+        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+        # TODO: it should run collective in the whole mesh instead of dim 0
+        tag, rankset, _ = group._dim_group_infos[0]
+        group_size = len(rankset)
+    elif isinstance(group, tuple):
+        if (
+            len(group) == 2
+            and isinstance(group[0], DeviceMesh)
+            and isinstance(group[1], int)
+        ):
+            dmesh = group[0]
+            dim = group[1]
+            tag, rankset, _ = dmesh._dim_group_infos[dim]
+            group_size = len(rankset)
+        else:
+            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
+    else:
+        raise ValueError(
+            "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
+        )
+
+    return (tag, rankset, group_size)
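+
+
+# Illustrative sketch (not part of the library): what _expand_group produces for
+# the two list-based descriptors; the helper name is hypothetical.
+def _example_expand_group():
+    # A flat rank list becomes (tag, rankset, group_size) with the tag unchanged.
+    assert _expand_group([0, 1, 2, 3]) == ("", [0, 1, 2, 3], 4)
+    # A 2D mesh is flattened into a single rankset; group_size is the size of one sub-group.
+    assert _expand_group([[0, 1], [2, 3]]) == ("", [0, 1, 2, 3], 2)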
+
+
+def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
+    """
+    Given group in RANK_TYPES, return the group name.
+    """
+    # `tag` will be deprecated. See details in:
+    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
+    if isinstance(group, dist.ProcessGroup):
+        return group.group_name
+    elif isinstance(group, str):
+        return group
+    elif isinstance(group, DeviceMesh):
+        assert (
+            group.ndim == 1
+        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+        return group._dim_group_infos[0][2]
+    elif isinstance(group, tuple):
+        if (
+            len(group) == 2
+            and isinstance(group[0], DeviceMesh)
+            and isinstance(group[1], int)
+        ):
+            dmesh = group[0]
+            dim = group[1]
+            return dmesh._dim_group_infos[dim][2]
+        else:
+            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
+    elif isinstance(group, list):
+        if not is_torchdynamo_compiling():
+            warnings.warn(
+                "The combination of ranks + tag as process group "
+                "identifier has been deprecated. Please switch to "
+                "using ProcessGroup, DeviceMesh, or group name instead."
+            )
+        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
+    else:
+        raise ValueError(f"Unsupported group type: {type(group)}, {group}")
+
+
+def _are_we_tracing() -> bool:
+    if is_torchdynamo_compiling():
+        return True
+    # If functionalization is turned on, we are almost definitely compiling/tracing.
+    # (In particular, AOTAutograd traces a model once with functionalization on
+    #  but proxy tracing turned off, so this is how we detect it).
+    if (
+        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
+        is not None
+    ):
+        return True
+    mode = get_innermost_proxy_mode()
+    if mode is None:
+        return False
+    return mode.tracer is not None
+
+
+def _maybe_wrap_tensor(self) -> torch.Tensor:
+    if _are_we_tracing():
+        return wait_tensor(self)
+    res = AsyncCollectiveTensor(self)
+    _register_tensor_wrapper(res)
+    return cast(torch.Tensor, res)
+
+
+def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
+    def mk_out_tensor(shard):
+        out_size = list(shard.size())
+        out_size[0] *= group_size
+        out_tensor = shard.new_empty(out_size)
+        return out_tensor
+
+    return [mk_out_tensor(t) for t in self]
+
+
+# We now register meta kernels to deal with tracing
+def _broadcast_meta(self, *args):
+    return torch.empty_like(self)
+
+
+def _all_reduce_meta(self, *args):
+    return torch.empty_like(self)
+
+
+def _wait_tensor_meta(self, *args):
+    return torch.empty_like(self)
+
+
+def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
+    out_size = list(shard.size())
+    out_size[0] *= group_size
+    return shard.new_empty(out_size)
+
+
+def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
+    out_size = list(input.size())
+    out_size[0] //= group_size
+    return input.new_empty(out_size)
+
+
+def _all_reduce_coalesced_meta(self, *args):
+    return [torch.empty_like(t) for t in self]
+
+
+def _all_reduce__meta(inp, *args):
+    return inp
+
+
+def _broadcast__meta(inp, *args):
+    return inp
+
+
+def _all_reduce_coalesced__meta(inputs, *args):
+    return inputs
+
+
+def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
+    def mk_out_tensor(input):
+        out_size = list(input.size())
+        out_size[0] //= group_size
+        out_tensor = input.new_empty(out_size)
+        return out_tensor
+
+    return [mk_out_tensor(t) for t in inputs]
+
+
+# NB: We often say all_to_all has dynamic output size, but this is not
+# technically true: instead, what typically happens is you manually
+# communicate the output_split_sizes ahead of time (which is dynamic),
+# but then you pass those sizes explicitly, and the all to all itself
+# isn't dynamic, it just follows the specified output splits
+def _all_to_all_single_meta(
+    input, output_split_sizes, input_split_sizes, *args, **kwargs
+):
+    if output_split_sizes is None:
+        return input.new_empty(input.size())
+    else:
+        for s in output_split_sizes:
+            torch._check_is_size(s)
+        out_size = list(input.size())
+        out_size[0] = sum(output_split_sizes)
+        return input.new_empty(out_size)
+
+
+def _all_gather_into_tensor_native_meta(input, group_size, group_name):
+    shape = list(input.size())
+    shape[0] *= group_size
+    return input.new_empty(shape)
+
+
+def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
+    return [
+        _all_gather_into_tensor_native_meta(input, group_size, group_name)
+        for input in inputs
+    ]
+
+
+def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
+    shape = list(inp.size())
+    shape[0] //= group_size
+    return inp.new_empty(shape)
+
+
+def _reduce_scatter_tensor_coalesced_native_meta(
+    inputs, reduce_op, group_size, group_name
+):
+    return [
+        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
+        for inp in inputs
+    ]
+
+
+def _register_ops():
+    ops_defs = [
+        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
+        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
+        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
+        "wait_tensor(Tensor self) -> Tensor",
+        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
+        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
+        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
+        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
+        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
+    ]
+
+    my_module = sys.modules[__name__]
+    for op_def in ops_defs:
+        op_name = op_def[0 : op_def.index("(")]
+        backend_impl = getattr(fun_col_impl, f"_{op_name}")
+        meta_impl = getattr(my_module, f"_{op_name}_meta")
+        c10_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
+        c10_lib_impl.impl(op_name, backend_impl, "CompositeExplicitAutograd")
+        impl_abstract(f"c10d_functional::{op_name}")(meta_impl)
+
+
+if not torch._running_with_deploy():
+    # Library MUST be defined at module scope or it doesn't work
+    # Creating a "DEF" Library always crashes torch::deploy so we create our Library instances here
+    #   guarded against running inside it
+    c10_lib = torch.library.Library("c10d_functional", "DEF")
+    c10_lib_impl = torch.library.Library("c10d_functional", "IMPL")
+    _register_ops()
+
+    _c10_lib_impl = torch.library.Library("_c10d_functional", "IMPL")
+    _c10_lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
+    _c10_lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
+    _c10_lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
+    _c10_lib_impl.impl(
+        "all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta"
+    )
+    _c10_lib_impl.impl(
+        "all_gather_into_tensor_coalesced",
+        _all_gather_into_tensor_coalesced_native_meta,
+        "Meta",
+    )
+    _c10_lib_impl.impl(
+        "reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta"
+    )
+    _c10_lib_impl.impl(
+        "reduce_scatter_tensor_coalesced",
+        _reduce_scatter_tensor_coalesced_native_meta,
+        "Meta",
+    )
+    _c10_lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
+    _c10_lib_impl.impl("broadcast", _broadcast_meta, "Meta")
+    _c10_lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
+else:
+    warnings.warn(
+        "PyTorch Distributed functional collectives do not work with torch::deploy."
+    )
+
+
+"""
+Dynamo remappings allow seamless translation from non-functional collectives of a supportable form into
+functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.
+
+We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via
+the mapping dict below.
+
+These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
+"""
+
+
+def all_gather_tensor_inplace(
+    output_tensor: torch.Tensor,
+    input_tensor: torch.Tensor,
+    group,  # TODO add a type,
+    async_op: bool = False,
+    tag: str = "",
+    gather_dim: int = 0,
+):
+    assert (
+        not async_op
+    ), "Can't remap async version of inplace op to functional collective"
+    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
+
+
+def reduce_scatter_tensor_inplace(
+    output: torch.Tensor,
+    input: torch.Tensor,
+    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
+    group=None,  # TODO add a type
+    async_op: bool = False,
+    scatter_dim: int = 0,
+    tag: str = "",
+):
+    assert (
+        not async_op
+    ), "Can't remap async version of inplace op to functional collective"
+    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
+
+
+REDUCE_OP_TO_STR = {
+    dist.ReduceOp.SUM: "sum",
+    dist.ReduceOp.AVG: "avg",
+    dist.ReduceOp.PRODUCT: "product",
+    dist.ReduceOp.MIN: "min",
+    dist.ReduceOp.MAX: "max",
+    dist.ReduceOp.BAND: "band",
+    dist.ReduceOp.BOR: "bor",
+    dist.ReduceOp.BXOR: "bxor",
+}
+
+
+def all_reduce_inplace(
+    tensor: torch.Tensor,
+    op: str = "sum",
+    group=None,
+    async_op: bool = False,
+    tag: str = "",
+):
+    assert (
+        not async_op
+    ), "Can't remap async version of inplace op to functional collective"
+
+    return tensor.copy_(all_reduce(tensor, op, group, tag))
+
+
+def all_to_all_inplace(
+    output: torch.Tensor,
+    input: torch.Tensor,
+    output_split_sizes=None,
+    input_split_sizes=None,
+    group=None,
+    async_op=False,
+    tag: str = "",
+):
+    assert (
+        not async_op
+    ), "Can't remap async version of inplace op to functional collective"
+    return output.copy_(
+        all_to_all_single(input, output_split_sizes, input_split_sizes, group, tag)
+    )
+
+
+def all_gather_inplace(
+    tensor_list: List[torch.Tensor],
+    tensor: torch.Tensor,
+    group=None,
+    async_op=False,
+    tag: str = "",
+):
+    assert (
+        not async_op
+    ), "Can't remap async version of inplace op to functional collective"
+    assert all(
+        t.size(0) == tensor.size(0) for t in tensor_list
+    ), "Remapping variable size all_gather is not yet supported"
+
+    output = all_gather_tensor(tensor, 0, group, tag)
+
+    # Use aten.slice instead of aten.split because the latter causes
+    # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
+    output_splits = []
+    offset = 0
+    for t in tensor_list:
+        output_splits.append(output[offset : offset + t.size(0)])
+        offset += t.size(0)
+    for dst, src in zip(tensor_list, output_splits):
+        dst.copy_(src)
+    return tensor_list
+
+
+from torch.distributed.distributed_c10d import (
+    _all_gather_base as legacy_all_gather_base,
+    _reduce_scatter_base as legacy_reduce_scatter_base,
+    all_gather as legacy_all_gather,
+    all_gather_into_tensor as legacy_allgather,
+    all_reduce as legacy_allreduce,
+    all_to_all_single as legacy_all_to_all_single,
+    reduce_scatter_tensor as legacy_reducescatter,
+)
+
+# This dict should contain sets of functions that dynamo is allowed to remap.
+# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
+traceable_collective_remaps = {
+    legacy_allgather: all_gather_tensor_inplace,
+    legacy_reducescatter: reduce_scatter_tensor_inplace,
+    legacy_allreduce: all_reduce_inplace,
+    legacy_all_to_all_single: all_to_all_inplace,
+    legacy_all_gather: all_gather_inplace,
+    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
+    legacy_all_gather_base: all_gather_tensor_inplace,
+}
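+
+
+# Illustrative sketch (not part of the library): the remap table above lets dynamo
+# trace a call such as dist.all_reduce(t) as if it were written in the functional
+# form below (an out-of-place collective followed by an in-place copy). Assumes an
+# initialized default group; the helper name is hypothetical.
+def _example_remapped_all_reduce(t: torch.Tensor) -> torch.Tensor:
+    # This mirrors what all_reduce_inplace (the remap target of dist.all_reduce) does.
+    ranks = list(range(dist.get_world_size()))
+    return t.copy_(all_reduce(t, "sum", ranks))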
diff --git a/MLPY/Lib/site-packages/torch/distributed/_functional_collectives_impl.py b/MLPY/Lib/site-packages/torch/distributed/_functional_collectives_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf4c223dc4305ede128bed965ed3fa914e2e894
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_functional_collectives_impl.py
@@ -0,0 +1,409 @@
+import logging
+import os
+import warnings
+import weakref
+from typing import cast, Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as c10d
+
+"""
+Eager kernel implementations were moved to a separate file, partly for readability and partly because it is
+currently easier in dynamo to set tracing policy on a file-by-file level.
+
+Do not put code in this file that Dynamo is expected to trace into, as dynamo may disallow this whole file.
+
+DEBUG/TESTING HELPERS:
+
+This module includes some helpers that are quite useful when debugging or testing functional collectives:
+
+_tensor_needs_wait
+_outstanding_wait_count
+_wait_all
+
+"""
+
+_use_native_funcol: Optional[bool] = None
+
+
+if torch._running_with_deploy():
+
+    def native_funcol_enabled():
+        return False
+
+else:
+    from torch._dynamo import assume_constant_result
+
+    @assume_constant_result
+    def native_funcol_enabled():
+        global _use_native_funcol
+        if _use_native_funcol is None:
+            try:
+                # Disable native funcol when torch_xla is installed. This check
+                # will be removed once torch_xla adopts the native_funcol IR.
+                import torch_xla  # noqa: F401
+
+                _use_native_funcol = False
+            except Exception:
+                # When TORCH_DISABLE_NATIVE_FUNCOL is set, fallback to py funcol
+                _use_native_funcol = (
+                    os.environ.get("TORCH_DISABLE_NATIVE_FUNCOL") != "1"
+                )
+
+        return _use_native_funcol
+
+
+logger = logging.getLogger(__name__)
+
+data_ptr_to_work: Dict[int, "_WaitRegistration"] = dict()
+work_version = 0
+
+
+class _WaitRegistration:
+    def __init__(self, work):
+        global work_version
+        self.work = work
+        self.version = work_version
+        self.ptrs = set()
+        self.ptr_alias_count = {}
+        self.cleanup_count = 0
+        work_version += 1
+
+    def _register_tensor_ptr(self, data_ptr):
+        global data_ptr_to_work
+        data_ptr_to_work[data_ptr] = self
+        self.ptrs.add(data_ptr)
+
+    def _record_wrapper(self, ptr):
+        self._register_tensor_ptr(ptr)
+        self.ptr_alias_count.setdefault(ptr, 0)
+        self.ptr_alias_count[ptr] += 1
+        self.cleanup_count += 1
+
+    def wait(self):
+        if self.work is not None:
+            self.work.wait()
+            self.work = None
+        self.cleanup()
+
+    def decrement_live_tensor(self, ptr):
+        self.cleanup_count -= 1
+        if self.cleanup_count == 0:
+            self.wait()
+        else:
+            self.ptr_alias_count[ptr] -= 1
+            if (
+                self.ptr_alias_count[ptr] < 1
+                and data_ptr_to_work.get(ptr, None) == self
+            ):
+                del data_ptr_to_work[ptr]
+
+    def cleanup(self):
+        for ptr in self.ptrs:
+            if data_ptr_to_work.get(ptr, None) == self:
+                del data_ptr_to_work[ptr]
+
+
+def _register_tensor_work(tensor_or_list, work_or_list):
+    if not isinstance(tensor_or_list, list):
+        tensor_or_list = [tensor_or_list]
+    if not isinstance(work_or_list, list):
+        reg = _WaitRegistration(work_or_list)
+        for tensor in tensor_or_list:
+            reg._register_tensor_ptr(tensor.data_ptr())
+    else:
+        for tensor, work in zip(tensor_or_list, work_or_list):
+            reg = _WaitRegistration(work)
+            reg._register_tensor_ptr(tensor.data_ptr())
+
+
+def _wait_reg_dec(ptr, wait_reg):
+    wait_reg.decrement_live_tensor(ptr)
+
+
+def _register_tensor_wrapper(tensor) -> None:
+    if native_funcol_enabled():
+        # Tensor storage -> work mapping is maintained in C++
+        return
+    global data_ptr_to_work
+    data_ptr = tensor.elem.data_ptr()
+    # Note: we should NEVER try to trace this, bc it registers runtime stuff during trace.
+    # Instead, backends must call this themselves when implementing traced collectives.
+    wait_reg = data_ptr_to_work.get(data_ptr, None)
+    if wait_reg is None:
+        warnings.warn(
+            "Trying to register finalizer to AsyncCollectiveTensor but the inner tensor is already gone"
+        )
+    else:
+        # We force the collective to be waited on if this tensor goes away, to reduce the chance of deadlocks.
+        # NOTE: we register the callback to the ACT wrapper class, for the following reasons:
+        # 1. The inner tensor is referenced by the associated Work object, so it cannot be collected until we
+        #  release the associated work object
+        # 2. There's an n-to-1 relationship between wrappers and inner tensor due to non-waitable ops like view()
+        wait_reg._record_wrapper(data_ptr)
+        weakref.finalize(tensor, _wait_reg_dec, data_ptr, wait_reg)
+
+
+def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    global data_ptr_to_work
+    data_ptr = tensor.data_ptr()
+    wait_reg = data_ptr_to_work.get(data_ptr)
+    if wait_reg is not None:
+        wait_reg.wait()
+    return tensor
+
+
+def _tensor_needs_wait(tensor: torch.Tensor) -> bool:
+    """Returns True if ``tensor`` needs to be waited on. Works with both AsyncCollectiveTensor wrappers and inner tensors."""
+    if hasattr(tensor, "_get_acs_underlying_tensor"):
+        tensor = tensor._get_acs_underlying_tensor()
+    data_ptr = tensor.data_ptr()
+    wait_reg = data_ptr_to_work.get(data_ptr)
+    return wait_reg is not None and wait_reg.work is not None
+
+
+def _outstanding_wait_count() -> int:
+    """Returns the number of outstanding work objects waiting to be waited (sic)."""
+    return len(data_ptr_to_work)
+
+
+def _wait_all() -> None:
+    """Wait for all outstanding collectives."""
+    for work_reg in list(data_ptr_to_work.values()):
+        work_reg.wait()
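+
+
+# Illustrative sketch (not part of the library): using the debug/testing helpers
+# listed in the module docstring to confirm no collective is left outstanding;
+# the helper name is hypothetical.
+def _example_debug_wait_state(t: torch.Tensor) -> None:
+    # Waiting on everything clears the registry, so nothing should remain
+    # outstanding and no tensor should still need a wait.
+    _wait_all()
+    assert _outstanding_wait_count() == 0
+    assert not _tensor_needs_wait(t)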
+
+
+def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
+    reduceOp = reduceOp.upper()
+    op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
+    if op is None:
+        raise ValueError(f"Invalid reduce operation {reduceOp}")
+    return cast(dist.ReduceOp, op)
+
+
+"""
+Kernel implementations (for eager runtime only) - should never be traced by torch.compile
+
+These functions should all be bound to dispatcher ops.  During tracing, the op itself should be
+captured in the graph and the backend should implement the op however it prefers.
+"""
+
+
+def _broadcast(self, src, tag, ranks, group_size):
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+
+    inplace_tensor = self.clone(memory_format=torch.contiguous_format)
+    work = dist.broadcast(inplace_tensor, src, group=group, async_op=True)
+    _register_tensor_work(inplace_tensor, work)
+
+    return inplace_tensor
+
+
+# TODO assert if ranks has duplicated entries
+def _all_reduce(self, reduceOp, tag, ranks, group_size):
+    op = _str_to_reduce_op(reduceOp)
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+
+    inplace_tensor = self.clone(memory_format=torch.contiguous_format)
+    work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
+    _register_tensor_work(inplace_tensor, work)
+
+    return inplace_tensor
+
+
+def _all_reduce_coalesced(self, reduceOp, tag, ranks, group_size):
+    op = _str_to_reduce_op(reduceOp)
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+
+    inplace_tensor_list = [t.clone(memory_format=torch.contiguous_format) for t in self]
+    work = dist.all_reduce_coalesced(
+        inplace_tensor_list, op=op, group=group, async_op=True
+    )
+    _register_tensor_work(inplace_tensor_list, work)
+
+    return inplace_tensor_list
+
+
+def _all_gather_into_tensor(shard, tag, ranks, group_size):
+    # TODO add dim support?
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+    out_size = list(shard.size())
+    out_size[0] *= group_size
+    out_tensor = shard.new_empty(out_size)
+    assert out_tensor.is_contiguous()
+    # FIXME gloo doesn't support _allgather_base
+    if dist.get_backend(group) == dist.Backend.GLOO or shard.is_cpu:
+        tensor_list = list(torch.chunk(out_tensor, group_size))
+        work = dist.all_gather(tensor_list, shard, group=group, async_op=True)
+    else:
+        work = dist.all_gather_into_tensor(
+            out_tensor, shard, group=group, async_op=True
+        )
+    _register_tensor_work(out_tensor, work)
+
+    return out_tensor
+
+
+def _all_gather_into_tensor_coalesced(self, tag, rankset, group_size):
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, rankset, group_size)
+    assert group is not None
+
+    def mk_out_tensor(shard):
+        out_size = list(shard.size())
+        out_size[0] *= group_size
+        out_tensor = shard.new_empty(out_size)
+        assert out_tensor.is_contiguous()
+        return out_tensor
+
+    out_tensors = [mk_out_tensor(t) for t in self]
+
+    work_list = _all_gather_into_tensor_coalesced_fallback(
+        output_tensors=out_tensors, input_tensors=self, group=group, async_op=True
+    )
+
+    _register_tensor_work(out_tensors, work_list)
+    return out_tensors
+
+
+def _reduce_scatter_tensor(
+    input: torch.Tensor,
+    reduceOp: str,
+    tag: str,
+    ranks: List[int],
+    group_size: int,
+):
+    # TODO add dim support?
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+    op = _str_to_reduce_op(reduceOp)
+
+    if dist.get_backend(group) == dist.Backend.GLOO or input.is_cpu:
+        # The cpu::gloo backend does not have reduce_scatter, so we fall back to all_reduce
+        # + a local chunk
+        logger.warning(
+            "ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!"
+        )
+        reduction_input = input.clone()
+        group_rank = dist.get_rank(group)
+        work = dist.all_reduce(reduction_input, op=op, group=group, async_op=True)
+        out_tensor = reduction_input.chunk(group_size, dim=0)[group_rank]
+        _register_tensor_work(out_tensor, work)
+    else:
+        out_size = list(input.size())
+        out_size[0] //= group_size
+        out_tensor = input.new_empty(out_size)
+        work = dist.reduce_scatter_tensor(
+            out_tensor, input, op=op, group=group, async_op=True
+        )
+        _register_tensor_work(out_tensor, work)
+
+    return out_tensor
+
+
+def _reduce_scatter_tensor_coalesced(
+    inputs: List[torch.Tensor],
+    reduce_op: str,
+    tag: str,
+    ranks: List[int],
+    group_size: int,
+):
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+    op = _str_to_reduce_op(reduce_op)
+
+    def mk_out_tensor(shard):
+        out_size = list(shard.size())
+        out_size[0] //= group_size
+        out_tensor = shard.new_empty(out_size)
+        assert out_tensor.is_contiguous()
+        return out_tensor
+
+    out_tensors = [mk_out_tensor(t) for t in inputs]
+
+    work_list = _reduce_scatter_tensor_coalesced_fallback(
+        output_tensors=out_tensors,
+        input_tensors=inputs,
+        op=op,
+        group=group,
+        async_op=False,
+    )
+
+    _register_tensor_work(out_tensors, work_list)
+    return out_tensors
+
+
+def _all_gather_into_tensor_coalesced_fallback(
+    output_tensors, input_tensors, group, async_op=False
+):
+    # all_gather_coalesced is useless: it doesn't work under NCCL and does lots of copies under Gloo.
+    # all_gather is useless too because it only takes a single tensor.
+    # NCCL's PG::all_gather with multiple tensors is broken, it only works for the multi-device setting
+    #  and fails if you mix same-size with different-size tensor lists.
+    # _coalescing_manager crashed NCCL when used with all_gather_into_tensor.
+    if input_tensors[0].is_cpu or not async_op:
+        work_list = []
+        out_tensors_sliced = [
+            list(torch.chunk(out_tensor, dist.get_world_size(group)))
+            for out_tensor in output_tensors
+        ]
+        for shard, out_tensor in zip(input_tensors, out_tensors_sliced):
+            work = c10d.all_gather(out_tensor, shard, group=group, async_op=async_op)
+            work_list.append(work)
+        return work_list
+    else:
+        with c10d._coalescing_manager(group=group, async_ops=True) as cm:
+            for in_t, out_t in zip(input_tensors, output_tensors):
+                dist.all_gather_into_tensor(out_t, in_t, group=group, async_op=True)
+        return cm
+
+
+def _reduce_scatter_tensor_coalesced_fallback(
+    output_tensors, input_tensors, op, group, async_op=False
+):
+    # All the same reasons as the all_gather fallback
+    work_list = []
+    for shard, out_tensor in zip(input_tensors, output_tensors):
+        work = c10d.reduce_scatter_tensor(
+            out_tensor, shard, op=op, group=group, async_op=async_op
+        )
+        work_list.append(work)
+    return work_list
+
+
+def _all_to_all_single(
+    input: torch.Tensor,
+    output_split_sizes: Optional[List[int]],
+    input_split_sizes: Optional[List[int]],
+    tag: str,
+    ranks: List[int],
+    group_size: int,
+):
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+
+    if output_split_sizes is not None:
+        torch._check(
+            input.dim() >= 1,
+            lambda: f"Expected input to have at least 1 dim but got {input.dim()} dim",
+        )
+        out_size = list(input.size())
+        out_size[0] = sum(output_split_sizes)
+        out_tensor = input.new_empty(out_size)
+    else:
+        out_tensor = input.new_empty(input.size())
+
+    work = c10d.all_to_all_single(
+        out_tensor,
+        input,
+        output_split_sizes=output_split_sizes,
+        input_split_sizes=input_split_sizes,
+        group=group,
+        async_op=True,
+    )
+    _register_tensor_work(out_tensor, work)
+
+    return out_tensor
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76091e1460e37564ff79fdd09869fcb09b498741
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/__init__.py
@@ -0,0 +1,6 @@
+from .api import (
+    _shard_tensor,
+    load_with_process_group,
+    shard_module,
+    shard_parameter,
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a11d4156d02c2611a48ffbcdca9b9f85170f14e2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2828a73b8dd5afc285b5842f57e69719198636c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7f8aaa59fc41145ef5f812e79dfcf259ddd9bdc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92f6eba7bd40dc6d9f842cbf63074d9f18d38b7e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1ddda13b7bf9106a9b8c4ef069173a9bbb05fbd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00aa3fbddb11c963f289169c88c4b46326a7ce0c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c07e800c7b956ca727b1d88fe267d1401a6003a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/_utils.py b/MLPY/Lib/site-packages/torch/distributed/_shard/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aedb7c8e6f2d4fa0d33dd79a3efddfe35c2ebfd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/_utils.py
@@ -0,0 +1,28 @@
+import torch
+from torch.distributed._shard.metadata import ShardMetadata
+from typing import Sequence
+
+DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
+
+def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]) -> torch.Tensor:
+    """
+    Narrow the tensor according to ``offsets`` and ``sizes``.
+    """
+    narrowed_tensor = tensor
+    for idx, (offset, size) in enumerate(zip(offsets, sizes)):
+        if size < tensor.size(idx):
+            # Narrow to get the shard for this rank. We don't want autograd to
+            # record the narrow op here, and 'local_shard' should be a leaf
+            # variable in the autograd graph.
+            narrowed_tensor = narrowed_tensor.narrow(
+                idx,
+                offset,
+                size
+            )
+    return narrowed_tensor
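+
+# Illustrative sketch (not part of the library): narrowing a 4x4 tensor down to
+# the 2x2 block starting at row 2, column 0; the helper name is hypothetical.
+def _example_narrow_by_index() -> torch.Tensor:
+    full = torch.arange(16).reshape(4, 4)
+    # offsets/sizes are given per dimension: rows [2, 4), columns [0, 2).
+    shard = narrow_tensor_by_index(full, offsets=(2, 0), sizes=(2, 2))
+    assert shard.shape == (2, 2)
+    return shard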
+
+def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
+    """
+    Narrow the tensor according to the metadata.
+    """
+    return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/api.py b/MLPY/Lib/site-packages/torch/distributed/_shard/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..6867cd0bd5b8a3cb1734314c4051d3b82b7e68c7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/api.py
@@ -0,0 +1,290 @@
+from contextlib import contextmanager
+from typing import Optional
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import distributed_c10d
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor,
+)
+from .sharding_spec import (
+    ShardingSpec,
+    ChunkShardingSpec
+)
+from .sharding_plan import (
+    ShardingPlan
+)
+from .sharder import Sharder
+
+def _shard_tensor(
+    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
+) -> ShardedTensor:
+    """
+    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
+    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
+    used as the ground truth of the data which would be scattered as shards
+    across the rest of the ranks.
+
+    Args:
+        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+
+    Keyword args:
+        src_rank (int, optional): The source rank which is used as the ground truth of
+            the data for the parameter that would be sharded and scattered
+            across the rest of the ranks.
+            Default: 0.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Returns:
+        A :class:`ShardedTensor` sharded from the given tensor.
+
+    .. warning::
+        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
+        currently supported as the ``sharding_spec``.
+    """
+    if not tensor.is_contiguous():
+        raise ValueError('input tensor is not a contiguous Tensor')
+
+    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
+    world_size = dist.get_world_size(pg)
+    current_rank = dist.get_rank(pg)
+
+    # Validate src_rank and sharding_spec are same across all ranks.
+    gathered_list = [None] * world_size
+    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)
+
+    for idx, entry in enumerate(gathered_list):
+        if src_rank != entry[0]:  # type: ignore[index]
+            raise ValueError(
+                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
+                f'match with src_rank={entry[0]} on rank: {idx}')
+        if sharding_spec != entry[1]:  # type: ignore[index]
+            raise ValueError(
+                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
+                f'match with sharding_spec={entry[1]} on rank: {idx}')
+
+    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
+
+    return st
+
+def shard_parameter(
+        module: torch.nn.Module,
+        param_name: str,
+        sharding_spec: ShardingSpec,
+        src_rank=0,
+        process_group=None):
+    """
+    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
+    module, it shards that parameter according to the provided
+    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
+    used as the ground truth of the data which would be scattered as shards
+    across the rest of the ranks.
+
+    This method replaces ``module.param_name`` with a
+    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
+        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+
+    Keyword args:
+        src_rank (int, optional): The source rank which is used as the ground truth of
+            the data for the parameter that would be sharded and scattered
+            across the rest of the ranks.
+            Default: 0.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    .. warning::
+        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
+        currently supported as the ``sharding_spec``.
+    """
+    # Perform some validation first.
+    if not hasattr(module, param_name):
+        raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`')
+
+    tensor = getattr(module, param_name)
+    if not isinstance(tensor, torch.Tensor):
+        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
+
+    if not tensor.is_contiguous():
+        raise ValueError(f'param: {param_name} is not a contiguous Tensor')
+
+    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)
+
+    # Replace param with ShardedTensor.
+    module.register_parameter(param_name, nn.Parameter(st))
+
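+# A minimal usage sketch for ``shard_parameter``, assuming the default process
+# group is already initialized with two ranks; ``fc`` and the placement strings
+# are illustrative only:
+#
+#     from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+#
+#     spec = ChunkShardingSpec(
+#         dim=0,
+#         placements=["rank:0/cuda:0", "rank:1/cuda:1"],
+#     )
+#     fc = torch.nn.Linear(16, 16).cuda()
+#     shard_parameter(fc, "weight", spec)  # fc.weight is now a ShardedTensor
+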
+# Tracks the current process group in the load context manager.
+_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None
+
+@contextmanager
+def load_with_process_group(process_group):
+    """
+    Context manager to set the process group with which to load a ShardedTensor.
+    """
+    global _CURRENT_PROCESS_GROUP
+    if _CURRENT_PROCESS_GROUP is not None:
+        raise RuntimeError(
+            'ProcessGroup already set by previous "load_with_process_group" '
+            'context manager')
+    _CURRENT_PROCESS_GROUP = process_group
+    try:
+        yield process_group
+    finally:
+        _CURRENT_PROCESS_GROUP = None
+
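+# A minimal usage sketch, assuming ``pg`` is a ProcessGroup created by the
+# caller and ``path`` points to a checkpoint containing ShardedTensors; both
+# names are illustrative only:
+#
+#     with load_with_process_group(pg):
+#         state_dict = torch.load(path)  # ShardedTensors are attached to ``pg``
+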
+def _get_current_process_group():
+    """
+    Retrieves the current process group set by ``load_with_process_group``.
+    If not set, it just returns the default group.
+    """
+    global _CURRENT_PROCESS_GROUP
+    if _CURRENT_PROCESS_GROUP is None:
+        return distributed_c10d._get_default_group()
+    else:
+        return _CURRENT_PROCESS_GROUP
+
+def _reshard_output(
+        module: torch.nn.Module,
+        resharding_spec: ShardingSpec) -> torch.nn.Module:
+    """
+    Hook a module with output resharding in the forward pass according
+    to the given ``resharding_spec``.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
+        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
+            The specification describing how the output of the module will be resharded.
+
+    Returns:
+        A :class:`torch.nn.Module` object with reshard API hooked.
+    """
+    def hook_func(_module, _input, output):
+        if isinstance(output, ShardedTensor):
+            return output.reshard(resharding_spec)
+        return output
+    module.register_forward_hook(hook_func)
+    return module
+
+def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Hook a module with local shards collection in the forward pass.
+
+    This API is typically used to convert a sharded representation back to data parallel
+    representation. In particular, it returns the local tensor for this Shard. If the
+    size along the sharding dimension for the local tensor is 1, this dimension is removed
+    from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically
+    a local Tensor of size [16] across each rank and not [1, 16] across each rank.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the
+            local tensor value needs to be returned.
+
+    Returns:
+        A :class:`torch.nn.Module` object with collection API hooked.
+    """
+
+    def hook_func(_module, _input, output):
+        if isinstance(output, ShardedTensor):
+            local_tensor = output.local_tensor()
+            # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
+            sharding_spec = output._sharding_spec
+            if isinstance(sharding_spec, ChunkShardingSpec) \
+               and local_tensor.size(sharding_spec.dim) == 1:  # type: ignore[attr-defined, arg-type]
+                local_tensor = local_tensor.squeeze(
+                    output._sharding_spec.dim  # type: ignore[attr-defined]
+                )
+            return local_tensor
+    module.register_forward_hook(hook_func)
+    return module
+
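+# A minimal wiring sketch for the two hooks above, assuming ``submodule``
+# returns a ShardedTensor from its forward pass and ``reshard_spec`` is a
+# ChunkShardingSpec; both names are illustrative. ``shard_module`` below applies
+# the same hooks when driven by a ShardingPlan:
+#
+#     _reshard_output(submodule, reshard_spec)
+#     _collect_local_shard(submodule)  # forward() now returns a local Tensor
+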
+def shard_module(
+    module: nn.Module,
+    plan: ShardingPlan,
+    src_rank=0,
+    process_group=None
+):
+    """
+    Shards a given module according to the provided sharding `plan`. This method
+    first shards all the parameters according to the given sharding `plan`. Then if
+    `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it
+    will reshard the output of those modules according to `output_plan` and convert
+    the module's output back to data parallel according to `return_local_tensor`.
+
+    Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        module (:class:`torch.nn.Module`): The module to apply sharding to
+        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
+            The ShardingPlan which specifies the param name to ShardingSpec mapping
+            to apply to each parameter.
+
+    Keyword args:
+        src_rank (int, optional): The source rank which is used as the ground truth of
+            the data for the module that would be sharded and scattered across the rest
+            of the ranks.
+            Default: 0.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+    """
+    # record Sharder paths for a sanity check on the plan, to ensure items in the plan
+    # do not conflict with the submodule tree that the Sharder is working with
+    sharder_paths = []
+    for name, spec in plan.plan.items():
+        if isinstance(spec, Sharder):
+            sharder_paths.append(name)
+
+    # shard the parameter according to the ShardingPlan
+    for name, spec in plan.plan.items():
+        if isinstance(spec, ShardingSpec):
+            # if found a sharding spec, try to shard the parameter
+            module_path, _, param_name = name.rpartition(".")
+
+            for sharder_path in sharder_paths:
+                if module_path.startswith(sharder_path):
+                    raise RuntimeError(f"ShardingPlan is in-valid, trying to shard a parameter: {name},"
+                                       f" but there's already a Sharder entry for module {sharder_path},"
+                                       f" parameter sharding should not conflict with the submodule tree"
+                                       f" that a Sharder is working with!")
+
+            mod = module.get_submodule(module_path)
+            shard_parameter(
+                mod,
+                param_name,
+                spec,
+                src_rank=src_rank,
+                process_group=process_group
+            )
+        elif isinstance(spec, Sharder):
+            parent_mod_path, _, mod_name = name.rpartition(".")
+            if name == "":
+                raise KeyError("Module path must not be empty for custom sharder!")
+            mod = module.get_submodule(name)
+            parent_mod = module.get_submodule(parent_mod_path)
+            sharded_mod = spec.shard(mod)
+            # swap this submodule with the sharded module
+            setattr(parent_mod, mod_name, sharded_mod)
+        else:
+            raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'")
+
+    # reshard module outputs if there's an entry in the plan's `output_plan`
+    if plan.output_plan is not None:
+        for module_path, output_spec in plan.output_plan.items():
+            if isinstance(output_spec, ShardingSpec):
+                mod = module.get_submodule(module_path)
+                _reshard_output(mod, output_spec)
+            else:
+                raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'")
+    # convert the output back to data parallel for the modules that appear in
+    # `return_local_tensor` of the plan; we call `_collect_local_shard`
+    # to collect the local tensor for the output of those modules
+    if plan.return_local_tensor is not None:
+        for module_path in plan.return_local_tensor:
+            mod = module.get_submodule(module_path)
+            _collect_local_shard(mod)
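+
+# A minimal end-to-end sketch for ``shard_module``, assuming two ranks and a
+# ``model`` with an ``fc`` submodule; the plan keys and placements are
+# illustrative only:
+#
+#     spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
+#     plan = ShardingPlan(
+#         plan={"fc.weight": spec},
+#         output_plan={"fc": spec},
+#         return_local_tensor=["fc"],
+#     )
+#     shard_module(model, plan)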
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec64ad79bd2a9f35b7fcb82d285f2eef57ff1831
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py
@@ -0,0 +1,12 @@
+# Keep the old package for BC purposes; this file should be removed once
+# everything moves to the `torch.distributed.checkpoint` package.
+import sys
+import torch
+import warnings
+
+from torch.distributed.checkpoint import *  # noqa: F403
+warnings.warn(
+    "torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead",
+    DeprecationWarning
+)
+sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd852c86a58024a1ac5eda84b2adf01d8fa55907
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/common_op_utils.py b/MLPY/Lib/site-packages/torch/distributed/_shard/common_op_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..001c1f312224af64ed1d1ca2230990cec2eee08c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/common_op_utils.py
@@ -0,0 +1,61 @@
+import torch
+from torch.utils import _pytree as pytree
+from typing import Optional
+
+def _basic_validation(op, args=(), kwargs=None):
+    """
+    Common validation across all ops goes here.
+    """
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+    if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
+        raise ValueError(f" No input for '{op.__name__}'!")
+
+    # Validate types
+    has_distributed_tensor = False
+
+    def is_distributed_tensor(e):
+        nonlocal has_distributed_tensor
+        if isinstance(e, ShardedTensor):
+            has_distributed_tensor = True
+
+    pytree.tree_map_(is_distributed_tensor, args)
+    pytree.tree_map_(is_distributed_tensor, kwargs)
+
+    if not has_distributed_tensor:
+        raise TypeError(
+            f"torch function '{op.__name__}', with args: {args} and "
+            f"kwargs: {kwargs} are called without any distributed tensor!"
+        )
+
+    # Validate all distributed tensors use the same PG.
+    cur_pg: Optional[torch.distributed.ProcessGroup] = None
+
+    def validate_pg(e):
+        nonlocal cur_pg
+        if isinstance(e, ShardedTensor):
+            if cur_pg is not None and e._process_group is not cur_pg:
+                raise RuntimeError(
+                    'All distributed tensors should use the '
+                    'same ProcessGroup if used together in an op.'
+                )
+            cur_pg = e._process_group
+
+    pytree.tree_map_(validate_pg, args)
+    pytree.tree_map_(validate_pg, kwargs)
+
+def _register_default_op(op, decorator):
+    @decorator(op)
+    def tensor_default_op(types, args=(), kwargs=None, pg=None):
+        """
+        Handles ``__torch_function__`` dispatch for the default tensor ops that
+        behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
+        ``torch.Tensor.dtype``. We simply lower to the real op call with the
+        DisableTorchFunctionSubclass context, like ``torch.Tensor.__torch_function__``,
+        to avoid recursion.
+        """
+        if kwargs is None:
+            kwargs = {}
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return op(*args, **kwargs)
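+
+# A minimal sketch of how ``_register_default_op`` is typically combined with a
+# registration decorator such as ``_sharded_op_impl`` from
+# ``torch.distributed._shard.sharded_tensor``; the op chosen here is
+# illustrative only:
+#
+#     from torch.distributed._shard.sharded_tensor import _sharded_op_impl
+#
+#     _register_default_op(torch.Tensor.contiguous, _sharded_op_impl)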
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/metadata.py b/MLPY/Lib/site-packages/torch/distributed/_shard/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..057f5a06fd7211ab0c749c6deb7f9dce45790616
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/metadata.py
@@ -0,0 +1,61 @@
+from dataclasses import dataclass
+from typing import List, Union, Optional
+from functools import reduce
+
+from torch.distributed.remote_device import _remote_device
+
+@dataclass
+class ShardMetadata:
+    """
+    Represents a shard of the overall Tensor including its
+    offsets, lengths and device placement.
+
+    Args:
+        shard_offsets(List[int]): Offsets in the original tensor indicating
+            the start offsets for this shard. Should have the same rank as
+            the original tensor.
+        shard_sizes(List[int]): Integers indicating the size of each
+            dimension for this shard. Should have the same rank as the
+            original tensor.
+        placement(:class:`torch.distributed._remote_device`):
+            Specifies the placement of this shard.
+    """
+
+    __slots__ = ['shard_offsets', 'shard_sizes', 'placement']
+
+    shard_offsets: List[int]
+    shard_sizes: List[int]
+    placement: Optional[_remote_device]
+
+    def __init__(
+        self,
+        shard_offsets: List[int],
+        shard_sizes: List[int],
+        placement: Optional[Union[str, _remote_device]] = None
+    ):
+        self.shard_offsets = shard_offsets
+        self.shard_sizes = shard_sizes
+        if isinstance(placement, str):
+            self.placement = _remote_device(placement)
+        else:
+            self.placement = placement
+        if len(self.shard_offsets) != len(self.shard_sizes):
+            raise ValueError(
+                f'shard_offsets and shard_sizes should have '
+                f'the same number of elements, found {len(self.shard_offsets)} '
+                f'and {len(self.shard_sizes)} respectively')
+
+        for i in range(len(self.shard_offsets)):
+            if self.shard_offsets[i] < 0:
+                raise ValueError('shard_offsets should be >=0')
+            if self.shard_sizes[i] < 0:
+                raise ValueError('shard_sizes should be >= 0')
+
+    def __hash__(self):
+        def _hash_reduce(a, b):
+            return (a << 8) + hash(b)
+
+        res = reduce(_hash_reduce, self.shard_offsets, 37)
+        res = reduce(_hash_reduce, self.shard_sizes, res)
+        res = _hash_reduce(res, self.placement)
+        return res
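+
+# A minimal construction sketch: the offsets, sizes and placement below describe
+# one illustrative (5, 5) shard of a (10, 5) tensor placed on rank 0:
+#
+#     md = ShardMetadata(
+#         shard_offsets=[0, 0],
+#         shard_sizes=[5, 5],
+#         placement="rank:0/cuda:0",
+#     )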
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py b/MLPY/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af4a17ecb722383d089a8d445ed14850a3292f3e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py
@@ -0,0 +1,35 @@
+import functools
+from inspect import signature
+from .common_op_utils import _basic_validation
+
+"""
+Common utilities to register ops on ShardedTensor
+and PartialTensor.
+"""
+
+def _register_op(op, func, op_table):
+    """
+    Performs basic validation and registers the provided op in the given
+    op_table.
+    """
+    if len(signature(func).parameters) != 4:
+        raise TypeError(
+            f'Custom sharded op function expects signature: '
+            f'(types, args, kwargs, process_group), but received '
+            f'signature: {signature(func)}')
+
+    op_table[op] = func
+
+def _decorator_func(wrapped_func, op, op_table):
+    """
+    Decorator function to register the given ``op`` in the provided
+    ``op_table``
+    """
+
+    @functools.wraps(wrapped_func)
+    def wrapper(types, args, kwargs, process_group):
+        _basic_validation(op, args, kwargs)
+        return wrapped_func(types, args, kwargs, process_group)
+
+    _register_op(op, wrapper, op_table)
+    return wrapper
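+
+# A minimal sketch of how ``_decorator_func`` is bound into a registration
+# decorator via ``functools.partial`` (as done for the public decorators in
+# ``sharded_tensor/__init__.py``); ``_MY_OPS`` and the op are illustrative only:
+#
+#     import torch
+#
+#     _MY_OPS: dict = {}
+#
+#     @functools.partial(_decorator_func, op=torch.add, op_table=_MY_OPS)
+#     def sharded_add(types, args, kwargs, process_group):
+#         ...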
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94d65d720c8b229037f983bf1ac8d175f19f6155
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py
@@ -0,0 +1,54 @@
+from typing import Iterator, Tuple, Union
+from .api import ShardedOptimizer
+
+import torch.nn as nn
+
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor
+)
+
+def named_params_with_sharded_tensor(
+    module: nn.Module,
+    prefix: str = '',
+    recurse: bool = True,
+) -> Iterator[Tuple[str, Union[nn.Parameter, ShardedTensor]]]:
+
+    r"""Returns an iterator over module parameters (together with the
+    ShardedTensor parameters), yielding both the name of the parameter
+    as well as the parameter itself. This is typically passed to a
+    :class:`torch.distributed._shard.sharded_optim.ShardedOptimizer`.
+
+    Args:
+        module (nn.Module): Module whose parameters (and ShardedTensor
+            parameters) should be yielded.
+        prefix (str): prefix to prepend to all parameter names.
+        recurse (bool): if True, then yields parameters of this module
+            and all submodules. Otherwise, yields only parameters that
+            are direct members of this module.
+
+    Yields:
+        (str, Union[Tensor, ShardedTensor]): Tuple containing
+            the name and parameter (or ShardedTensor parameter)
+
+    Example::
+
+        >>> # xdoctest: +SKIP
+        >>> model = torch.nn.Linear(*linear_size)
+        >>> shard_parameter(model, "weight", spec)
+        >>> for name, param in named_params_with_sharded_tensor(model):
+        >>>    if name in ['weight']:
+        >>>        print(param.size())
+
+    """
+    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+    memo = set()
+    for mod_prefix, mod in modules:
+        # find all sharded tensor params
+        for name, val in vars(mod).items():
+            if isinstance(val, ShardedTensor) and val not in memo:
+                memo.add(val)
+                name = mod_prefix + ('.' if mod_prefix else '') + name
+                yield name, val
+
+    # find all nn.Parameters
+    for name, val in module.named_parameters():
+        yield name, val
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d3dcc94d3d5230bf33e8a667de0b08fea10f2b5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b1008d0e12d1df25b38df6a52d582de29b512d4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14be300e67888aceb78f3124d004e37deea489b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py
@@ -0,0 +1,97 @@
+from typing import List, Union, Mapping, Dict, Any
+
+import torch.optim as optim
+from torch import Tensor
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+
+class ShardedOptimizer(optim.Optimizer):
+    def __init__(
+        self,
+        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
+        optimizer_class,
+        *optimizer_args,
+        **optimizer_kwargs
+    ):
+        """
+        ShardedOptimizer collects all tensors and local shard tensors of
+        ShardedTensor, then uses these tensors as ``params`` for the optimizer.
+
+        Args:
+            named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
+                of parameters, where key is the parameter key, value is either
+                Tensor or ShardedTensor parameter.
+            optimizer_class (torch.optim.Optimizer): the Optimizer to use
+                locally, e.g. torch.optim.SGD, torch.optim.Adagrad, etc.
+            *optimizer_args: the arguments to initialize the optimizer.
+            **optimizer_kwargs: the key-word arguments to initialize the optimizer.
+
+        """
+        tensors: List[Tensor] = []
+        for value in named_params.values():
+            if isinstance(value, ShardedTensor):
+                for local_shard in value.local_shards():
+                    tensors.append(local_shard.tensor)
+            else:
+                tensors.append(value)
+
+        self.named_params = named_params
+        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
+        self.param_groups = self._optim.param_groups
+        self.state = self._optim.state
+
+    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
+        r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
+
+        Args:
+            set_to_none (bool): instead of setting to zero, set the grads to None.
+                This will in general have a lower memory footprint, and can modestly improve performance.
+                However, it changes certain behaviors. For example:
+                1. When the user tries to access a gradient and perform manual ops on it,
+                a None attribute or a Tensor full of 0s will behave differently.
+                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
+                are guaranteed to be None for params that did not receive a gradient.
+                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
+                (in one case it does the step with a gradient of 0 and in the other it skips
+                the step altogether).
+        """
+        self._optim.zero_grad(set_to_none)
+
+    def step(self, closure=None):
+        r"""Performs a single optimization step (parameter update).
+
+        Args:
+            closure (Callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+
+        .. note::
+            Unless otherwise specified, this function should not modify the
+            ``.grad`` field of the parameters.
+        """
+        self._optim.step(closure)
+
+    def state_dict(self) -> Dict[str, Any]:
+        """
+        Returned state and param_groups will contain parameter keys
+        instead of parameter indices (as in torch.optim.Optimizer).
+        This allows for advanced functionality like optimizer re-sharding to be implemented.
+        """
+        # TODO: implement state_dict
+        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")
+
+
+    def load_state_dict(self, state_dict: Mapping[str, Any]):
+        r"""Loads the ShardedOptimizer state.
+
+        Args:
+            state_dict (dict): ShardedOptimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # TODO: implement load_state_dict
+        raise NotImplementedError("ShardedOptimizer load_state_dict not implemented yet!")
+
+    def add_param_group(self, param_group: Any):
+        r"""Add a new param group
+        """
+        # TODO: implement add_param_group
+        raise NotImplementedError("ShardedOptimizer add_param_group not implemented yet!")
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..183854860b6899c398c0aa082197c24c29764b4d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -0,0 +1,469 @@
+import functools
+from typing import List, TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from torch.distributed._shard.sharding_spec import ShardingSpec
+else:
+    ShardingSpec = "ShardingSpec"
+
+from .api import (
+    _CUSTOM_SHARDED_OPS,
+    _SHARDED_OPS,
+    Shard,
+    ShardedTensorBase,
+    ShardedTensor,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+from .metadata import ShardMetadata  # noqa: F401
+from torch.distributed._shard.op_registry_utils import _decorator_func
+
+
+def empty(sharding_spec: ShardingSpec,
+          *size,
+          dtype=None,
+          layout=torch.strided,
+          requires_grad=False,
+          pin_memory=False,
+          memory_format=torch.contiguous_format,
+          process_group=None,
+          init_rrefs=False) -> ShardedTensor:
+    """
+    Returns a :class:`ShardedTensor` filled with uninitialized data.
+        Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+            returned Tensor. Default: ``torch.contiguous_format``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    return ShardedTensor(
+        sharding_spec,
+        *size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+
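+# A minimal usage sketch for the creation ops in this module, assuming the
+# default process group is initialized with two ranks; the placements are
+# illustrative only:
+#
+#     from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+#
+#     spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
+#     st = empty(spec, 10, 5)      # uninitialized (10, 5) ShardedTensor
+#     st_ones = ones(spec, 10, 5)  # same layout, filled with 1
+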
+def ones(sharding_spec: ShardingSpec,
+         *size,
+         dtype=None,
+         layout=torch.strided,
+         requires_grad=False,
+         pin_memory=False,
+         memory_format=torch.contiguous_format,
+         process_group=None,
+         init_rrefs=False) -> ShardedTensor:
+    """
+    Returns a :class:`ShardedTensor` filled with the scalar value 1.
+        Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    return full(
+        sharding_spec,
+        size,
+        fill_value=1,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs
+    )
+
+def zeros(sharding_spec: ShardingSpec,
+          *size,
+          dtype=None,
+          layout=torch.strided,
+          requires_grad=False,
+          pin_memory=False,
+          memory_format=torch.contiguous_format,
+          process_group=None,
+          init_rrefs=False) -> ShardedTensor:
+    """
+    Returns a :class:`ShardedTensor` filled with the scalar value 0.
+        Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    return full(
+        sharding_spec,
+        size,
+        fill_value=0,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs
+    )
+
+def full(sharding_spec: ShardingSpec,
+         size,
+         fill_value,
+         *,
+         dtype=None,
+         layout=torch.strided,
+         requires_grad=False,
+         pin_memory=False,
+         memory_format=torch.contiguous_format,
+         process_group=None,
+         init_rrefs=False) -> ShardedTensor:
+    """
+    Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype
+        is inferred from fill_value. If dtype is specified, it will override the
+        inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
+            output tensor.
+        fill_value (Scalar): the value to fill the output tensor with.
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    sharded_tensor = ShardedTensor(
+        sharding_spec,
+        *size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+    torch.nn.init.constant_(sharded_tensor, fill_value)  # type: ignore[arg-type]
+    return sharded_tensor
+
+def rand(sharding_spec: ShardingSpec,
+         *size,
+         dtype=None,
+         layout=torch.strided,
+         requires_grad=False,
+         pin_memory=False,
+         memory_format=torch.contiguous_format,
+         process_group=None,
+         init_rrefs=False) -> ShardedTensor:
+    """
+    Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution
+        on the interval :math:`[0, 1)`. The shape of the tensor is defined by the
+        variable argument `size`. Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
+            output tensor.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    sharded_tensor = ShardedTensor(
+        sharding_spec,
+        *size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+    torch.nn.init.uniform_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
+    return sharded_tensor
+
+def randn(sharding_spec: ShardingSpec,
+          *size,
+          dtype=None,
+          layout=torch.strided,
+          requires_grad=False,
+          pin_memory=False,
+          memory_format=torch.contiguous_format,
+          process_group=None,
+          init_rrefs=False) -> ShardedTensor:
+    """
+    Creates a :class:`ShardedTensor` filled with random numbers from a normal distribution
+        with mean `0` and variance `1` (also called the standard normal distribution). The shape
+        of the tensor is defined by the variable argument `size`. Needs to be called on all ranks
+        in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
+            output tensor.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    sharded_tensor = ShardedTensor(
+        sharding_spec,
+        *size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        pin_memory=pin_memory,
+        memory_format=memory_format,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+    torch.nn.init.normal_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
+    return sharded_tensor
+
+def init_from_local_shards(
+        local_shards: List[Shard],
+        *global_size,
+        process_group=None,
+        init_rrefs=False) -> ShardedTensor:
+    """
+    Creates a :class:`ShardedTensor` from local shards and the overall global size.
+    Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list
+            of shards that represent the local shards on this rank.
+        global_size (int...):  a list, tuple, or `torch.Size` of integers defining the
+            shape of the overall sharded tensor.
+
+    Keyword args:
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object handle on this rank
+
+
+    Examples:
+        Suppose we want to construct a sharded tensor on two ranks, with global size (10, 5)
+        and a (5, 5) local tensor on each rank; we can do it as below:
+
+        on rank 0:
+        >>> # xdoctest: +SKIP("not distributed")
+        >>> local_shard_metadata = ShardMetadata(
+        >>>     shard_offsets=[0, 0],
+        >>>     shard_sizes=[5, 5],
+        >>>     placement="rank:0/cuda:0"
+        >>> )
+        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
+        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
+
+        on rank 1:
+        >>> # xdoctest: +SKIP("not distributed")
+        >>> local_shard_metadata = ShardMetadata(
+        >>>     shard_offsets=[5, 0],
+        >>>     shard_sizes=[5, 5],
+        >>>     placement="rank:1/cuda:1"
+        >>> )
+        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
+        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
+    """
+    return ShardedTensor._init_from_local_shards(
+        local_shards,
+        *global_size,
+        process_group=process_group,
+        init_rrefs=init_rrefs
+    )
+
+def state_dict_hook(module, destination, prefix, local_metadata):
+    """
+    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
+    registered to the Module using
+    :meth:`torch.nn.Module._register_state_dict_hook`.
+    """
+    for submodule_name, submodule in module.named_modules():
+        for attr_name, attr in submodule.__dict__.items():
+            if isinstance(attr, ShardedTensor):
+                mod_prefix = prefix + submodule_name
+                key = mod_prefix + ('.' if mod_prefix else '') + attr_name
+                destination[key] = attr
+
+def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    """
+    Pre-load state dict hook to add ShardedTensor to the module.
+    """
+    for submodule_name, submodule in module.named_modules():
+        for attr_name in submodule.__dict__.keys():
+            mod_prefix = prefix + submodule_name
+            key = mod_prefix + ('.' if mod_prefix else '') + attr_name
+            if key in state_dict:
+                if isinstance(state_dict[key], ShardedTensor):
+                    setattr(submodule, attr_name, state_dict[key])
+
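+# A minimal wiring sketch for the two hooks above; ``model`` is illustrative and
+# the ``_register_*`` helpers are private ``nn.Module`` methods, so this is an
+# assumption about their current signatures rather than a stable API:
+#
+#     model._register_state_dict_hook(state_dict_hook)
+#     model._register_load_state_dict_pre_hook(pre_load_state_dict_hook, with_module=True)
+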
+def custom_sharded_op_impl(func):
+    """
+    Provides a way for users to write their own custom sharded operator. This
+    can be used to override existing ShardedTensor operators or write a new
+    one not supported by ShardedTensor. If the operator in question is covered
+    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
+    parameters, the function provided will be invoked for that operator.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> @custom_sharded_op_impl(torch.nn.functional.linear)
+        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
+        >>>     ...
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> input = torch.rand(10, 32)
+        >>> weight = sharded_tensor.rand(32, 16)
+        >>> bias = torch.rand(16)
+        >>> # This will call 'my_custom_sharded_linear'
+        >>> torch.nn.functional.linear(input, weight, bias)
+
+    The types, args and kwargs parameters are the same parameters that are
+    passed to the ``__torch_function__`` dispatch API
+    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
+    There is an additional ``process_group`` parameter which is the
+    process_group used for the ShardedTensor and can be used by
+    implementations for communications within a sharded implementation.
+
+    Args:
+        func(Callable): Torch function for which we want to provide a sharded
+            implementation (ex: torch.nn.functional.linear)
+    """
+    return functools.partial(
+        _decorator_func,
+        op=func,
+        op_table=_CUSTOM_SHARDED_OPS
+    )
+
+def _sharded_op_impl(func):
+    """
+    Decorator to register a default sharded op.
+    """
+    return functools.partial(
+        _decorator_func,
+        op=func,
+        op_table=_SHARDED_OPS
+    )
+
+# Import all builtin sharded ops
+from ._ops import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c605dc866f3baf727b34be264e1e083cba7f309
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b37316e2b1f0720ef634aa67634376b38a49ee28
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5088979651b8348215a5e61cd8e397719b4bc870
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a926eaf91a0c857eeb46c87c16979e8845e1bc1b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e31044bd80026f8934bbcf2e22384b47b178184c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc13fdf4927482a6b41284775ad92e86c783045f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b5fc69baf094b5d5bf772c17b7f00a4323a6385
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70c89a491bd6137aa2548c76bd39d5f80dc2c0cf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a994f8e214a17e29349c94e7c823977e5d0962b6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
@@ -0,0 +1,9 @@
+import torch.distributed._shard.sharded_tensor._ops.misc_ops
+import torch.distributed._shard.sharded_tensor._ops.tensor_ops
+
+from .binary_cmp import equal, allclose
+from .init import kaiming_uniform_, normal_, uniform_, constant_
+
+# Import all ChunkShardingSpec ops
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57293da37571eeaf444b1096816891f2b1aacfb1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3eff78092a9e142b3a4d6273f86e14693f33809
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ca60197de7ceae3f09a1814d94acf5aeab0b971
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e395827cb657597984fe1a47413a390ff4bbc562
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5138841b244eda2ba1b33bbf38a538ce305bd90d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2cbbb5a40da7578c5fc24f0b5e70daa07b52738
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d094dda16c6aee25c4c498ba217fb3dc76b0342
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py
@@ -0,0 +1,107 @@
+import functools
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
+from torch.distributed._shard.common_op_utils import _basic_validation
+
+def _sharded_op_common(op, early_stop_func, extra_check):
+    """
+    Inject sharded tensor op registration with common logic executed before
+    the different behaviors are applied to either local shards or a local tensor.
+
+    Example::
+        >>> # xdoctest: +SKIP("Undefined variables")
+        >>> op = torch.transpose
+        >>> @_sharded_op_impl(op)
+        >>> @_sharded_op_common(op, early_stop_func, extra_check)
+        >>> def sharded_tensor_op(types, args, kwargs, process_group):
+        >>>   ...
+        >>>
+        >>> st = sharded_tensor.rand(32, 16)
+        >>> st.transpose(1, 2)
+        >>> # This will call '_sharded_op_common'
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+
+    Return:
+        decorator (Callable): A decorator which wraps a sharded op implementation
+            with the common validation, extra check and early-stop logic above.
+    """
+    def decorator_sharded_func(wrapped_func):
+        @functools.wraps(wrapped_func)
+        def wrapper(types, args=(), kwargs=None, pg=None):
+            _basic_validation(op, args, kwargs)
+
+            st = args[0]
+            if kwargs is None:
+                kwargs = {}
+            if extra_check:
+                extra_check(*args, **kwargs)
+            if early_stop_func:
+                early_stop = early_stop_func(*args, **kwargs)
+                if early_stop:
+                    return st
+            return wrapped_func(types, args, kwargs, pg)
+
+        return wrapper
+
+    return decorator_sharded_func
+
+def _register_sharded_op_on_local_shards(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    each shard of the sharded tensor, such as elementwise ops like
+    ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new shards and sharded tensor size.
+
+    This function expects that the original ShardingSpec for the ShardedTensor
+    is preserved irrespective of whether or not a customized function is used.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate new shards and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                all local shards of the st.
+
+    Return:
+        None. The sharded op implementation is registered for
+        ``__torch_function__`` dispatch as a side effect.
+    """
+    @_sharded_op_impl(op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+        st = args[0]
+        st_metadata = st.metadata()
+        local_shards = st.local_shards()
+        local_shards_new = []
+        if customized_func:
+            local_shards_new, st_metadata = customized_func(args, kwargs, pg)
+        else:
+            for local_shard in local_shards:
+                args = (local_shard.tensor, *args[1:])
+                local_shards_new.append(
+                    Shard(op(*args, **kwargs), local_shard.metadata)
+                )
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards_new,
+            st_metadata,
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+            sharding_spec=st.sharding_spec()
+        )
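+
+
+# --- Editorial usage sketch (not part of the upstream module) ----------------
+# A minimal, hedged example of how `_register_sharded_op_on_local_shards`
+# might be used to register an elementwise op so it runs shard-by-shard.
+# The chosen op (torch.nn.functional.gelu) is illustrative only; this helper
+# is never called here, so importing the module has no side effects.
+def _example_register_elementwise_gelu():
+    import torch
+
+    # Each local shard is passed through the real op and the results are
+    # re-wrapped into a new ShardedTensor with the original sharding spec.
+    _register_sharded_op_on_local_shards(torch.nn.functional.gelu)
+
+    # After registration, calling torch.nn.functional.gelu(st) on a
+    # ShardedTensor `st` dispatches through __torch_function__ to the
+    # handler defined above.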
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9809e70d718399e99fbd20a47779456f4238692f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py
@@ -0,0 +1,68 @@
+import torch
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as distributed_c10d
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor,
+    _sharded_op_impl
+)
+
+def _communicate_result(result, pg):
+    # Gather results from all ranks.
+    if result:
+        result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device()))
+    else:
+        result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device()))
+
+    dist.all_reduce(result_tensor, group=pg)
+
+    expected_result = torch.ones(1, device=torch.device(torch.cuda.current_device())) * dist.get_world_size(pg)
+
+    return torch.equal(result_tensor, expected_result)
+
+def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
+    if len(args) != 2:
+        raise ValueError(f'Expected two arguments for torch.{cmp_fun.__name__}')
+
+    result = True
+    st1 = args[0]
+    st2 = args[1]
+    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
+        raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor')
+
+    # Verify same PG
+    if st1._process_group != st2._process_group:
+        return False
+
+    if distributed_c10d._rank_not_in_group(st1._process_group) or distributed_c10d._rank_not_in_group(st2._process_group):
+        return distributed_c10d._rank_not_in_group(st1._process_group) == distributed_c10d._rank_not_in_group(st2._process_group)
+
+    # Verify metadata
+    if st1.metadata() != st2.metadata():
+        return _communicate_result(False, st1._process_group)
+
+    # Verify number of local shards
+    st1_local_shards = st1.local_shards()
+    st2_local_shards = st2.local_shards()
+    if len(st1_local_shards) != len(st2_local_shards):
+        return _communicate_result(False, st1._process_group)
+
+    # kwargs must be dict-like
+    if kwargs is None:
+        kwargs = {}
+    # Verify each local shard
+    for idx in range(len(st1_local_shards)):
+        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
+            return _communicate_result(False, st1._process_group)
+        if not cmp_fun(st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs):
+            return _communicate_result(False, st1._process_group)
+
+    return _communicate_result(True, st1._process_group)
+
+@_sharded_op_impl(torch.equal)
+def equal(types, args, kwargs, process_group):
+    return binary_cmp(torch.equal, types, args, kwargs, process_group)
+
+@_sharded_op_impl(torch.allclose)
+def allclose(types, args, kwargs, process_group):
+    return binary_cmp(torch.allclose, types, args, kwargs, process_group)
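+
+
+# --- Editorial usage sketch (not part of the upstream module) ----------------
+# Illustrates how the registered comparison ops behave: the call must be made
+# collectively on every rank, and the per-rank result is all-reduced so each
+# rank sees the same boolean. Note that `_communicate_result` above places the
+# reduction tensor on the current CUDA device, so a CUDA-capable setup is
+# assumed. `st_a` and `st_b` are ShardedTensors built elsewhere on an
+# initialized process group; this helper is never called at import time.
+def _example_compare_sharded_tensors(st_a, st_b):
+    # Dispatches to `equal` above; True only if every local shard matches
+    # on every rank.
+    exact_match = torch.equal(st_a, st_b)
+
+    # Keyword arguments are forwarded to the per-shard comparison.
+    approx_match = torch.allclose(st_a, st_b, rtol=1e-5, atol=1e-8)
+    return exact_match, approx_match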
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcd9bfaa083b4f8dccc0372c4577446bc88ac972
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py
@@ -0,0 +1,143 @@
+import torch
+import torch.distributed._shard.sharded_tensor as sharded_tensor
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+)
+
+def validate_param(param, param_name):
+    if param is None:
+        raise ValueError(f"param: {param_name} shouldn't be None!")
+
+@_sharded_op_impl(torch.nn.init.uniform_)
+def uniform_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensors in tensor.local_shards with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+    Args:
+        tensor: tensor sharded across devices
+        a: the lower bound of the uniform distribution
+        b: the upper bound of the uniform distribution
+    """
+    validate_param(kwargs, "kwargs")
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    a = kwargs['a']
+    validate_param(a, "a")
+    b = kwargs['b']
+    validate_param(b, "b")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
+    return sharded_tensor
+
+@_sharded_op_impl(torch.nn.init.normal_)
+def normal_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensors in tensor.local_shards with values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+    Args:
+        tensor: tensor sharded across devices
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+    """
+    validate_param(kwargs, "kwargs")
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    mean = kwargs['mean']
+    validate_param(mean, "mean")
+    std = kwargs['std']
+    validate_param(std, "std")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
+    return sharded_tensor
+
+@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
+def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the Tensors in tensor.local_shards with values according to the method
+    described in `Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification` - He, K. et al. (2015), using a
+    uniform distribution. The resulting tensor will have values sampled from
+    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
+    .. math::
+        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
+    Also known as He initialization.
+    Args:
+        tensor: tensor sharded across devices
+        a: the negative slope of the rectifier used after this layer (only
+            used with ``'leaky_relu'``)
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+        nonlinearity: the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+    """
+    validate_param(kwargs, "kwargs")
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    a = kwargs['a']
+    validate_param(a, "a")
+    mode = kwargs['mode']
+    validate_param(mode, "mode")
+    nonlinearity = kwargs['nonlinearity']
+    validate_param(nonlinearity, "nonlinearity")
+
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+    return sharded_tensor
+
+@_sharded_op_impl(torch.nn.init.constant_)
+def constant_(types, args=(), kwargs=None, pg=None):
+    r"""
+    Fills the input ShardedTensor with the value ``val``.
+    Args:
+        tensor: tensor sharded across devices
+        val: the value to fill the tensor with
+    """
+    validate_param(kwargs, "kwargs")
+    sharded_tensor = kwargs["tensor"]
+    validate_param(sharded_tensor, "tensor")
+    val = kwargs['val']
+    validate_param(val, "val")
+    for shard in sharded_tensor.local_shards():
+        torch.nn.init.constant_(shard.tensor, val=val)
+    return sharded_tensor
+
+tensor_like_creation_op_map = {
+    torch.full_like: sharded_tensor.full,
+    torch.empty_like: sharded_tensor.empty,
+    torch.zeros_like: sharded_tensor.zeros,
+    torch.ones_like: sharded_tensor.ones,
+    torch.rand_like: sharded_tensor.rand,
+    torch.randn_like: sharded_tensor.randn,
+}
+
+# Tensor creation ops that behave the same as on a regular dense tensor
+def register_tensor_creation_op(op):
+    @_sharded_op_impl(op)
+    def tensor_creation_op(types, args=(), kwargs=None, pg=None):
+        """
+        Handles ``__torch_function__`` dispatch for tensor creation ops that
+        takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
+        ``torch.full_like``.
+        """
+        creation_op = tensor_like_creation_op_map.get(op, None)
+        if creation_op is None:
+            raise RuntimeError(f"Tensor creation {op} not supported!")
+        if kwargs is None:
+            kwargs = {}
+
+        st = args[0]
+
+        new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs)  # type: ignore[operator]
+        return new_st
+
+
+register_tensor_creation_op(torch.full_like)
+register_tensor_creation_op(torch.empty_like)
+register_tensor_creation_op(torch.zeros_like)
+register_tensor_creation_op(torch.ones_like)
+register_tensor_creation_op(torch.rand_like)
+register_tensor_creation_op(torch.randn_like)
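+
+
+# --- Editorial usage sketch (not part of the upstream module) ----------------
+# The init handlers above read their inputs from ``kwargs``, which matches how
+# torch.nn.init forwards its arguments through __torch_function__. `st` is
+# assumed to be a ShardedTensor created elsewhere on an initialized process
+# group; this helper is never called at import time.
+def _example_init_and_create_like(st):
+    # In-place init ops run independently on each local shard.
+    torch.nn.init.uniform_(tensor=st, a=0.0, b=1.0)
+    torch.nn.init.constant_(tensor=st, val=0.5)
+
+    # Creation ops such as torch.zeros_like are remapped to their
+    # sharded_tensor counterparts (see tensor_like_creation_op_map) and
+    # return a new ShardedTensor with the same sharding spec and size.
+    return torch.zeros_like(st)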
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..409e6495e803eba46b045f1cdf9a395fac858085
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py
@@ -0,0 +1,12 @@
+import torch
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+)
+
+# This is used by `_apply()` within module.py to set new
+# parameters after applying a certain method; we should follow
+# the future behavior of overwriting the existing tensor
+# instead of doing an in-place change using `.data = `.
+@_sharded_op_impl(torch._has_compatible_shallow_copy_type)
+def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None):
+    return False
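+
+
+# --- Editorial usage sketch (not part of the upstream module) ----------------
+# Returning False above makes `torch.nn.Module._apply` replace a sharded
+# parameter object outright rather than assigning into `param.data`, which
+# would not be meaningful for a wrapper tensor. Illustrative only; never
+# called at import time.
+def _example_shallow_copy_check(st, plain_tensor):
+    # Dispatches to `tensor_has_compatible_shallow_copy_type` and returns False.
+    return torch._has_compatible_shallow_copy_type(st, plain_tensor)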
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a6b7dfdba1f8b85caf7007ac0ce5ffc3f6e8922
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -0,0 +1,215 @@
+import copy
+import torch
+from torch.distributed._shard.sharded_tensor import (
+    _sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
+from ._common import (
+    _register_sharded_op_on_local_shards,
+)
+from torch.distributed._shard.common_op_utils import _register_default_op
+
+
+# Tensor properties access
+_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.size, _sharded_op_impl)
+_register_default_op(torch.Tensor.dim, _sharded_op_impl)
+_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.is_floating_point, _sharded_op_impl)
+
+# __reduce_ex__ to dispatch to get_state/set_state
+_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
+
+# autograd related properties
+_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+# TODO: set grad with a ShardedTensor that consists of all local grads
+_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl)  # type: ignore[union-attr]
+_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl)  # type: ignore[union-attr]
+_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+
+# The device property is ambiguous: from a global perspective,
+# ShardedTensor.device spans multiple devices (possibly even across hosts).
+# We choose to return the device of the local tensor to represent
+# the device property on each rank.
+@_sharded_op_impl(torch.Tensor.device.__get__)
+def tensor_device(types, args=(), kwargs=None, pg=None):
+    self_st = args[0]
+    # Validate types
+    if not isinstance(self_st, ShardedTensor):
+        raise TypeError("input needs to be a ShardedTensor")
+    dev: torch.device
+    if self_st._local_shards:
+        dev = self_st._local_shards[0].tensor.device
+    elif pg and pg._get_backend_name() == "gloo":
+        dev = torch.device("cpu")
+    else:
+        dev = torch.device(torch.cuda.current_device())
+    return dev
+
+@_sharded_op_impl(torch.Tensor.is_meta.__get__)  # type: ignore[attr-defined]
+def st_is_meta(types, args=(), kwargs=None, pg=None):
+    return args[0].local_tensor().is_meta
+
+
+def sharded_type_as_check(*args, **kwargs):
+    """
+    Perform extra checks for the sharded_type_as op, such as verifying that the
+    input is either a Tensor or a ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return: None
+    """
+    if len(args) < 2:
+        raise ValueError("Needs to give a tensor to cast type as!")
+    if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
+        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
+
+
+def same_dtype(*args, **kwargs):
+    """
+    When the dtype is the same, return the original ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return (bool): Whether to return early or not.
+    """
+    return args[0].dtype == args[1].dtype
+
+
+def sharded_type_as(args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return:
+        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
+        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
+    """
+    st = args[0]
+    tensor = args[1]
+    if isinstance(tensor, ShardedTensor):
+        tensor = tensor.local_tensor()
+    new_local_shards = []
+    for shard in st.local_shards():
+        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
+    st_meta = copy.deepcopy(st._metadata)
+    st_meta.tensor_properties.dtype = tensor.dtype
+    return new_local_shards, st_meta
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.type_as,
+    early_stop_func=same_dtype,
+    extra_check=sharded_type_as_check,
+    customized_func=sharded_type_as,
+)
+
+
+def sharded_deepcopy(args, kwargs, pg):
+    # NOTE: we implement the deepcopy magic method directly
+    # instead of relying on the default tensor.__deepcopy__
+    # plus clone(). This is because the default tensor
+    # deepcopy copies every attribute, but the process_group
+    # in ShardedTensor cannot be deep copied.
+    self_st = args[0]
+    new_local_shards = copy.deepcopy(self_st.local_shards())
+    new_metadata = copy.deepcopy(self_st.metadata())
+    return new_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.__deepcopy__,
+    customized_func=sharded_deepcopy,
+)
+
+
+@_sharded_op_impl(torch.Tensor.copy_)
+def sharded_inplace_copy(types, args, kwargs, pg):
+    # NOTE: in-place ops don't need to be re-wrapped
+    kwargs = {} if kwargs is None else kwargs
+    self_st = args[0]
+    new_st = args[1]
+    nonblocking = kwargs.get("non_blocking", False)
+    for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
+        if local_shard.metadata != new_shard.metadata:
+            raise RuntimeError(
+                "inplace copy can only happen between two ShardedTensor with same metadata!"
+            )
+    for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
+        local_shard.tensor.copy_(new_shard.tensor, nonblocking)
+
+    return self_st
+
+
+def sharded_clone(args, kwargs, pg):
+    self_st = args[0]
+    desire_memory_format = kwargs.get("memory_format", None)
+    if desire_memory_format and desire_memory_format != torch.preserve_format:
+        raise RuntimeError("Only support torch.preserve_format for ShardedTensor!")
+    cloned_local_shards = [
+        Shard(
+            local_shard.tensor.clone(memory_format=desire_memory_format),
+            metadata=copy.deepcopy(local_shard.metadata),
+        )
+        for local_shard in self_st.local_shards()
+    ]
+    new_metadata = copy.deepcopy(self_st.metadata())
+    return cloned_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.clone,
+    customized_func=sharded_clone,
+)
+
+
+def sharded_detach(args, kwargs, pg):
+    self_st = args[0]
+    detached_local_shards = [
+        Shard(
+            local_shard.tensor.detach(),
+            metadata=copy.deepcopy(local_shard.metadata),
+        )
+        for local_shard in self_st.local_shards()
+    ]
+    new_metadata = copy.deepcopy(self_st.metadata())
+    new_metadata.tensor_properties.requires_grad = False
+    return detached_local_shards, new_metadata
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.detach,
+    customized_func=sharded_detach,
+)
+
+
+@_sharded_op_impl(torch.Tensor.requires_grad_)
+def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
+    self_st = args[0]
+    # Validate types
+    if not isinstance(self_st, ShardedTensor):
+        raise TypeError("input needs to be a ShardedTensor")
+
+    if kwargs is None:
+        kwargs = {}
+
+    requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True)
+    if requires_grad == self_st.requires_grad:
+        return self_st
+
+    for local_shard in self_st.local_shards():
+        local_shard.tensor.requires_grad_(requires_grad)
+
+    # update the wrapper class property
+    with torch._C.DisableTorchFunctionSubclass():
+        self_st.requires_grad_(requires_grad)
+    # also update the metadata
+    self_st._metadata.tensor_properties.requires_grad = requires_grad
+    return self_st
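+
+
+# --- Editorial usage sketch (not part of the upstream module) ----------------
+# Ties together a few of the registrations above. `st` is assumed to be a
+# ShardedTensor built elsewhere on an initialized process group; this helper
+# is never called at import time.
+def _example_tensor_ops(st):
+    # Property access goes through the default handlers registered at the top.
+    shape, ndim = st.shape, st.ndim
+
+    # clone/detach produce new ShardedTensors via the customized funcs above.
+    st_clone = st.clone()
+    st_detached = st.detach()
+
+    # type_as returns `st` unchanged when the dtypes already match (the
+    # early-stop path); otherwise each local shard is cast.
+    st_cast = st.type_as(torch.empty(1, dtype=torch.float64, device=st.device))
+
+    # requires_grad_ updates the local shards, the wrapper, and the metadata.
+    st.requires_grad_(True)
+    return shape, ndim, st_clone, st_detached, st_cast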
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a072dade94e429c2b96279d8e2c7a73015cf5e5c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py
@@ -0,0 +1,1253 @@
+from __future__ import annotations  # type: ignore[attr-defined]
+from dataclasses import dataclass
+from typing import (
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    cast,
+)
+import copy
+import warnings
+from functools import reduce
+import weakref
+
+import threading
+import torch
+import torch.distributed as dist
+from torch.distributed import rpc
+from torch.distributed import distributed_c10d
+from torch.distributed._shard.metadata import ShardMetadata
+import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharding_spec.api import (
+    _dispatch_custom_op,
+    _has_custom_op,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    check_tensor,
+    validate_non_overlapping_shards_metadata,
+)
+from torch.distributed._shard._utils import (
+    DEPRECATE_MSG,
+)
+
+from .metadata import TensorProperties, ShardedTensorMetadata
+from .shard import Shard
+from .reshard import reshuffle_local_shard, reshard_local_shard
+from .utils import (
+    _flatten_tensor_size,
+    _parse_and_validate_remote_device,
+    _validate_output_tensor_for_gather,
+    build_metadata_from_local_shards,
+    build_global_metadata
+)
+from torch.distributed.remote_device import _remote_device
+from torch.utils import _pytree as pytree
+import operator
+
+# Tracking for sharded tensor objects.
+_sharded_tensor_lock = threading.Lock()
+_sharded_tensor_current_id = 0
+_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {}
+
+# Default sharded ops
+_SHARDED_OPS: Dict[Callable, Callable] = {}
+
+# Customized user ops
+_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
+
+def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
+    with _sharded_tensor_lock:
+        if sharded_tensor_id not in _sharded_tensor_map:
+            raise RuntimeError(
+                f'Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}')
+
+        sharded_tensor = _sharded_tensor_map[sharded_tensor_id]()
+        if sharded_tensor is None:
+            raise RuntimeError('ShardedTensor weakref has been deallocated')
+        else:
+            sharded_tensor._register_remote_shards(rrefs, rpc_rank)
+
+class ShardedTensorBase(torch.Tensor):
+    _sharding_spec: shard_spec.ShardingSpec
+    _metadata: ShardedTensorMetadata
+    _local_shards: List[Shard]
+
+    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
+        # Use __new__ to construct a wrapper tensor, for recording tensor
+        # properties and logging purposes.
+        torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
+
+        # check sharding spec and build sharded tensor metadata
+        if not isinstance(sharding_spec, shard_spec.ShardingSpec):
+            raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}")
+
+        sizes = _flatten_tensor_size(size)
+        dtype = kwargs["dtype"]
+        layout = kwargs["layout"]
+        pin_memory = kwargs["pin_memory"]
+        requires_grad = kwargs["requires_grad"]
+
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+
+        tensor_properties = TensorProperties(
+            dtype, layout, requires_grad, pin_memory=pin_memory
+        )
+        sharded_tensor_metadata = sharding_spec.build_metadata(
+            sizes, tensor_properties=tensor_properties
+        )
+
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            sizes,
+            dtype=dtype,
+            layout=layout,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        # set sharding spec
+        r._sharding_spec = sharding_spec
+        # set metadata
+        r._metadata = sharded_tensor_metadata
+        # set local shards
+        r._local_shards = []
+        return r
+
+    def metadata(self) -> ShardedTensorMetadata:
+        """
+        Returns a :class:`ShardedTensorMetadata` object corresponding to the
+        metadata for the entire tensor.
+        """
+        return self._metadata
+
+    def local_shards(self) -> List[Shard]:
+        """
+        Returns a list of :class:`Shard` corresponding to the
+        local shards for this rank. Returns an empty list if the current rank
+        does not host any shards for this Tensor.
+        """
+        return self._local_shards
+
+    @classmethod
+    def _init_from_local_shards_and_global_metadata(
+        cls,
+        local_shards: List[Shard],
+        sharded_tensor_metadata: ShardedTensorMetadata,
+        sharding_spec=None,
+    ) -> ShardedTensorBase:
+        """
+        Initialize a ShardedTensorBase with local shards and a global
+        ShardedTensorMetadata built on each rank.
+        Warning: This API is experimental and subject to change. It does
+                 not do cross-rank validation and fully relies on the user
+                 for the correctness of sharded_tensor_metadata on each rank.
+        """
+        shards_metadata = sharded_tensor_metadata.shards_metadata
+        tensor_properties = sharded_tensor_metadata.tensor_properties
+
+        if len(shards_metadata) == 0:
+            raise ValueError("shards_metadata must not be empty!")
+
+        if tensor_properties.layout != torch.strided:
+            raise ValueError("Only torch.strided layout is currently supported")
+
+        if sharding_spec is None:
+            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        else:
+            spec = sharding_spec
+
+        sharded_tensor_base = ShardedTensorBase.__new__(
+            ShardedTensor,
+            spec,
+            sharded_tensor_metadata.size,
+            dtype=tensor_properties.dtype,
+            layout=tensor_properties.layout,
+            pin_memory=tensor_properties.pin_memory,
+            requires_grad=tensor_properties.requires_grad,
+        )
+
+        # check if shards_metadata have overlap shards
+        validate_non_overlapping_shards_metadata(shards_metadata)
+
+        # check if the shards_metadata is compatible with overall size of the sharded tensor.
+        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
+
+        # done validation, add local_shards
+        sharded_tensor_base._local_shards = local_shards
+        return sharded_tensor_base
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        raise RuntimeError(
+            f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
+            "but the there is no custom __torch_dispatch__ implementation for it."
+        )
+
+class ShardedTensor(ShardedTensorBase):
+    """
+    ShardedTensor is a torch.Tensor subclass representing Tensors that are sharded
+    across multiple devices and multiple processes.
+
+    ShardedTensor is initialized in an SPMD-like fashion where each rank
+    initializes the ShardedTensor. The ShardedTensor object on each rank
+    then only stores the local shard for the Tensor and provides global
+    metadata for all the shards.
+
+    ShardedTensor doesn't provide any Tensor-like operations itself but is a
+    wrapper exposing the Tensor representing the local shard and the global
+    metadata. Using these, users can build their custom distributed._sharded
+    computations on top of this primitive. The local shards are all initialized
+    using the create_op specified by tensor_init_params.create_op, e.g.
+    torch.ones or torch.empty.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+                Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+            returned Tensor. Default: ``torch.contiguous_format``.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    .. note:: ShardedTensor uses collectives to do various operations, e.g. it
+        uses all_gather to do cross-rank validation. For NCCL-based process
+        groups, internal tensor representations of objects must be moved to the
+        GPU device before communication takes place. In this case, the device
+        used is given by ``torch.cuda.current_device()`` and it is the user's
+        responsibility to ensure that this is set so that each rank has an
+        individual GPU, via ``torch.cuda.set_device()``.
+
+    """
+    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
+        self = super().__new__(cls, sharding_spec, *size, **kwargs)
+        return self
+
+    def __init__(
+        self,
+        sharding_spec: shard_spec.ShardingSpec,
+        *size,
+        dtype=None,
+        layout=torch.strided,
+        requires_grad=False,
+        pin_memory=False,
+        memory_format=torch.contiguous_format,
+        process_group=None,
+        init_rrefs=False,
+    ):
+        # prepare initialization, initialize fields like
+        # _process_group, _local_shards, etc.
+        self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
+
+        if layout != torch.strided:
+            raise ValueError('Only torch.strided layout is currently supported')
+
+        if memory_format != torch.contiguous_format:
+            raise ValueError('Only torch.contiguous_format memory_format is currently supported')
+
+        self._metadata.tensor_properties.memory_format = memory_format
+
+        current_rank = dist.get_rank(self._process_group)
+
+        for shard_metadata in self._metadata.shards_metadata:
+            rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement)
+            if rank == current_rank:
+                local_tensor = _create_tensor_from_params(
+                    shard_metadata.shard_sizes,
+                    local_device=device,
+                    tensor_properties=self._metadata.tensor_properties
+                )
+                self._local_shards.append(Shard(local_tensor, shard_metadata))
+
+        # do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
+        self._post_init()
+
+    def _prepare_init(self, process_group=None, init_rrefs=False):
+        self._init_rrefs = init_rrefs
+        self._sharded_tensor_id = None
+
+        self._process_group = (
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+
+        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
+
+    def _post_init(self):
+        # Initialize RPC if available.
+        if self._init_rrefs:
+            with _sharded_tensor_lock:
+                global _sharded_tensor_current_id, _sharded_tensor_map
+                self._sharded_tensor_id = _sharded_tensor_current_id
+                _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
+                _sharded_tensor_current_id += 1
+
+            if not rpc._is_current_rpc_agent_set():
+                raise RuntimeError(
+                    'RPC Framework needs to be initialized using'
+                    ' torch.distributed.rpc.init_rpc if init_rrefs is set to True')
+            self._init_rpc()
+
+    def __del__(self):
+        # Clean up the global map.
+        with _sharded_tensor_lock:
+            global _sharded_tensor_current_id, _sharded_tensor_map
+            if (
+                hasattr(self, "_sharded_tensor_id")
+                and self._sharded_tensor_id in _sharded_tensor_map
+            ):
+                _sharded_tensor_map.pop(self._sharded_tensor_id)  # type: ignore[call-overload]
+
+    def _init_rpc(self):
+        # Validate PG and RPC ranks match.
+        pg_rank = dist.get_rank()
+        rpc_rank = rpc.get_worker_info().id
+        if pg_rank != rpc_rank:
+            raise ValueError(
+                f'Default ProcessGroup and RPC ranks must be '
+                f'the same for ShardedTensor, found process group rank: '
+                f'{pg_rank} and RPC rank: {rpc_rank}'
+            )
+
+        self._remote_shards = {}
+
+        # Gather all the sharded tensor ids.
+        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
+        rank_to_name = {}
+        name_to_rank = {}
+
+        for worker_info in worker_infos:
+            rank_to_name[worker_info.id] = worker_info.name
+            name_to_rank[worker_info.name] = worker_info.id
+
+        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id)
+
+        # Share the local shards to the entire world.
+        futs = []
+        rpc_rank = rpc.get_worker_info().id
+        for rank in range(dist.get_world_size()):
+            # Skip self.
+            if rank == dist.get_rank():
+                continue
+
+            if len(self.local_shards()) != 0:
+                rrefs: List[rpc.RRef[Shard]] = [rpc.RRef(shard) for shard in self.local_shards()]
+                fut = rpc.rpc_async(
+                    rank,
+                    _register_remote_shards,
+                    args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank))
+                futs.append(fut)
+
+        torch.futures.wait_all(futs)
+
+        # Barrier for all RPCs to finish on all ranks.
+        rpc.api._all_gather(None)
+
+    def _get_preferred_device(self) -> torch.device:
+        """
+        Return the preferred device to be used when creating tensors for collectives.
+        This method takes into account the associated process group.
+        """
+        if dist.get_backend(self._process_group) == dist.Backend.NCCL:
+            return torch.device(torch.cuda.current_device())
+        return torch.device("cpu")
+
+    def gather(  # type: ignore[override]
+        self,
+        dst: int = 0,
+        out: Optional[torch.Tensor] = None,
+        enforce_dtype: bool = False,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """
+        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
+        sharded tensor.
+
+        The API needs to be called on all ranks in SPMD fashion. All ranks should have
+        the same ``dst``. ``out`` should be a tensor of the same size as the overall
+        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.
+
+        Args:
+            dst(int): The rank where full tensor is constructed.
+                Default: 0
+            out (:class:`torch.Tensor`, optional): The output full tensor.
+                Must be provided ONLY on the ``dst`` rank.
+                Default: ``None``
+            enforce_dtype (bool): Deprecated, please use dtype instead.  Force the
+                gathered tensors to be the same type as input and output.
+            dtype (torch.dtype): Force the gathered tensors to be this dtype.
+                Default: ``None``
+        """
+        def shard_size(shard_md):
+            return reduce(operator.mul, shard_md.shard_sizes)  # type: ignore[attr-defined]
+
+        if enforce_dtype:
+            warnings.warn("enforce_dtype is deprecated.  Please use dtype instead.")
+
+        rank = dist.get_rank(self._process_group)
+        full_size = self.metadata().size
+        _validate_output_tensor_for_gather(rank, dst, full_size, out)
+
+        local_shards = self.local_shards()
+        world_size = dist.get_world_size(self._process_group)
+        rank_sizes = [0 for _ in range(world_size)]
+        max_rank_size = 0
+        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
+        # collect sizes
+        for shard_md in self.metadata().shards_metadata:
+            shard_rank = cast(_remote_device, shard_md.placement).rank()
+            assert shard_rank is not None
+
+            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
+            rank_sizes[shard_rank] += shard_size(shard_md)
+            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
+
+        gather_list: Optional[List[torch.Tensor]]
+        if rank == dst:
+            assert out is not None
+            if enforce_dtype:
+                # enforce_dtype is deprecated.  Do it for backward compatibility.
+                dtype = out.dtype
+            # TODO make it as a view of out tensor
+            gather_list = [torch.empty((max_rank_size,), device=out.device, dtype=dtype) for _ in range(world_size)]
+        else:
+            gather_list = None
+
+        with torch.no_grad():
+            if enforce_dtype and len(local_shards) > 0:
+                # enforce_dtype is deprecated.  Do it for backward compatibility.
+                dtype = local_shards[0].tensor.dtype
+            data = torch.empty(max_rank_size, device=self._get_preferred_device(), dtype=dtype)
+
+            for shard in local_shards:
+                src = shard.tensor.flatten()
+                if src.nelement() == 0:
+                    warnings.warn("Gathering a tensor with zero elements on rank " + str(rank))
+                    return
+                shard_offset = shard_placement[shard.metadata][1]
+                data[shard_offset: shard_offset + src.numel()].copy_(src)
+
+        dist.gather(
+            tensor=data,
+            gather_list=gather_list,
+            dst=dst,
+            group=self._process_group,
+        )
+        if rank != dst:
+            return
+        # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst
+        out = cast(torch.Tensor, out)
+        assert gather_list is not None
+
+        full_size = self.metadata().size
+        dims = len(full_size)
+        for shard_md in self.metadata().shards_metadata:
+            rank, rank_offset = shard_placement[shard_md]
+            tensor = gather_list[rank]
+            tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
+            tensor = tensor.view(shard_md.shard_sizes)
+
+            out_narrow_view = out
+            for dim in range(dims):
+                out_narrow_view = out_narrow_view.narrow(
+                    dim,
+                    shard_md.shard_offsets[dim],
+                    shard_md.shard_sizes[dim],
+                )
+
+            out_narrow_view.copy_(tensor)
+
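+    # --- Editorial usage sketch (not part of the upstream class) -------------
+    # gather() is a collective: every rank calls it with the same ``dst``, and
+    # only ``dst`` allocates the output. A hedged sketch, assuming an
+    # initialized process group and a ShardedTensor ``st``:
+    #
+    #     out = torch.empty(st.size()) if dist.get_rank() == 0 else None
+    #     st.gather(dst=0, out=out)  # rank 0 now holds the full tensor
+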
+    def cpu(
+        self,
+        memory_format=torch.preserve_format,
+        process_group=None
+    ) -> ShardedTensor:
+        """
+        Returns a copy of this object in CPU memory.
+
+        If this ShardedTensor is already on CPU memory, then no copy is
+        performed and original object is returned.
+
+        .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might
+            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupGloo);
+            it is the user's responsibility to explicitly pass in a new process_group that
+            is compatible with CPU.
+        """
+        # TODO: make this a __torch_function__ op once ShardedTensor becomes a
+        # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402
+        if memory_format != torch.preserve_format and \
+                memory_format != torch.contiguous_format:
+            raise RuntimeError("Only `torch.contiguous_format` or "
+                               "`torch.preserve_format` is supported!")
+        all_on_cpu = True
+        for meta in self.metadata().shards_metadata:
+            all_on_cpu &= (meta.placement.device().type == "cpu")  # type: ignore[union-attr]
+
+        # if every shard is already on CPU, return the original object
+        if all_on_cpu:
+            return self
+
+        # if not, returns a copy of this object on CPU
+        list_shards: List[Shard] = []
+        # move all local shards to cpu, and change metadata
+        for shard in self._local_shards:
+            cpu_tensor = shard.tensor.cpu(memory_format=memory_format)  # type: ignore[call-arg]
+            metadata = copy.deepcopy(shard.metadata)
+            metadata.placement._device = torch.device("cpu")  # type: ignore[union-attr]
+            list_shards.append(
+                Shard(cpu_tensor, metadata)
+            )
+
+        st_meta = copy.deepcopy(self.metadata())
+        for meta in st_meta.shards_metadata:
+            if meta.placement.device().type != "cpu":  # type: ignore[union-attr]
+                meta.placement._device = torch.device("cpu")  # type: ignore[union-attr]
+
+        pg = self._process_group if process_group is None else process_group
+        st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata(
+            list_shards,
+            sharded_tensor_metadata=st_meta,
+            process_group=pg,
+            init_rrefs=self._init_rrefs
+        )
+        return st_cpu
+
+    def cuda(
+        self,
+        device=None,
+        non_blocking=False,
+        memory_format=torch.preserve_format,
+        process_group=None
+    ) -> ShardedTensor:
+        """
+        Returns a copy of this object in CUDA memory. If the original ShardedTensor
+        is on CPU, we move the local shard to the current GPU device of each
+        process in an SPMD fashion.
+        If this ShardedTensor is already in CUDA memory and the local shards on each rank
+        are already on the current device, we still return a new ShardedTensor object with
+        new metadata, but no underlying data movement is performed.
+        .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might
+            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupNCCL);
+            it is the user's responsibility to explicitly pass in a new process_group that
+            is compatible with GPU.
+        """
+        if memory_format != torch.preserve_format and \
+                memory_format != torch.contiguous_format:
+            raise RuntimeError("Only `torch.contiguous_format` or "
+                               "`torch.preserve_format` is supported!")
+
+        if device is not None:
+            device = torch.device(device) if isinstance(device, str) else device
+            assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \
+                '''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!'''
+
+        current_device = torch.device(torch.cuda.current_device())
+        # returns a copy of ShardedTensor on CUDA current device
+        list_shards: List[Shard] = []
+        # move all local shards to current device, and change metadata
+        # if local shards already on the current device, there's no
+        # real data movement, only the metadata are copied.
+        for shard in self._local_shards:
+            cuda_tensor = shard.tensor.cuda(
+                device=current_device,
+                non_blocking=non_blocking,
+                memory_format=memory_format
+            )  # type: ignore[call-arg]
+            metadata = copy.deepcopy(shard.metadata)
+            metadata.placement._device = current_device  # type: ignore[union-attr]
+
+            list_shards.append(
+                Shard(cuda_tensor, metadata)
+            )
+
+        st_meta = copy.deepcopy(self.metadata())
+        for meta in st_meta.shards_metadata:
+            if meta.placement.device().type != "cuda":  # type: ignore[union-attr]
+                meta.placement._device = current_device  # type: ignore[union-attr]
+
+        pg = self._process_group if process_group is None else process_group
+        # we need to use `init_from_local_shards` to communicate between ranks
+        # and update the sharding spec/shards metadata.
+        st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata(
+            list_shards,
+            sharded_tensor_metadata=st_meta,
+            process_group=pg,
+            init_rrefs=self._init_rrefs
+        )
+        return st_cuda
+
+    def to(self, *args, **kwargs) -> ShardedTensor:
+        current_device: torch.device
+        if self._local_shards:
+            current_device = self._local_shards[0].tensor.device
+        elif self._process_group._get_backend_name() == "gloo":
+            current_device = torch.device("cpu")
+        else:
+            current_device = torch.device(torch.cuda.current_device())
+        current_dtype = self.dtype
+        device_to = current_device
+        dtype_to = current_dtype
+        if len(args) == 1:
+            if isinstance(args[0], torch.dtype):
+                dtype_to = args[0]
+            elif isinstance(args[0], torch.device):
+                device_to = args[0]
+            elif isinstance(args[0], (str, int)):
+                device_to = torch.device(args[0])
+            elif isinstance(args[0], torch.Tensor):
+                dtype_to = args[0].dtype
+                device_to = args[0].device
+            else:
+                raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}")
+        elif len(args) == 2:
+            device_to, dtype_to = args
+        else:
+            dtype_to = kwargs.get("dtype", current_dtype)
+            device_to = kwargs.get("device", current_device)
+
+        device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to
+
+        if device_to.type == "cuda":
+            # if device_to is set to cuda, use the current device even
+            # if the user specified a device index.
+            current_idx = torch.cuda.current_device()
+            if device_to.index != current_idx:
+                warnings.warn("ShardedTensor.to only move tensor to its current device"
+                              "If you want to put to different device, use `reshard` instead.")
+            device_to = torch.device(current_idx)
+
+        copy_tensor = kwargs.get("copy", False)
+        non_blocking = kwargs.get("non_blocking", False)
+        memory_format = kwargs.get("memory_format", torch.preserve_format)
+        process_group = kwargs.get("process_group", None)
+
+        if not copy_tensor and dtype_to == current_dtype and device_to == current_device:
+            # already have correct dtype and device, return itself
+            return self
+
+        # returns a copy of ShardedTensor on CUDA current device
+        list_shards: List[Shard] = []
+
+        for shard in self._local_shards:
+            new_tensor = shard.tensor.to(  # type: ignore[call-overload]
+                device=device_to,
+                dtype=dtype_to,
+                non_blocking=non_blocking,
+                copy=copy_tensor,
+                memory_format=memory_format
+            )
+            metadata = copy.deepcopy(shard.metadata)
+            if metadata.placement is not None:
+                metadata.placement._device = device_to
+            list_shards.append(Shard(new_tensor, metadata))
+
+        # update metadata
+        st_meta = copy.deepcopy(self.metadata())
+        st_meta.tensor_properties.dtype = dtype_to
+        for meta in st_meta.shards_metadata:
+            meta.placement._device = device_to  # type: ignore[union-attr]
+
+        pg = self._process_group if process_group is None else process_group
+        # we need to use `init_from_local_shards` to communicate between ranks
+        # and update the sharding spec/shards metadata.
+        st_to = ShardedTensor._init_from_local_shards_and_global_metadata(
+            list_shards,
+            sharded_tensor_metadata=st_meta,
+            process_group=pg,
+            init_rrefs=self._init_rrefs
+        )
+        return st_to
+
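+    # --- Editorial usage sketch (not part of the upstream class) -------------
+    # to() above accepts a dtype, a device, or both, mirroring torch.Tensor.to.
+    # A hedged sketch, assuming a ShardedTensor ``st``:
+    #
+    #     st_f64 = st.to(torch.float64)   # casts every local shard
+    #     st_same = st.to(st.device)      # nothing changes, returns st itself
+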
+
+    @classmethod
+    def _init_from_local_shards(
+        cls,
+        local_shards: List[Shard],
+        *global_size,
+        process_group=None,
+        init_rrefs=False,
+    ):
+        # STEP 1: Validate the ShardMetadatas locally
+        process_group = (
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+        current_rank = dist.get_rank(process_group)
+        world_size = dist.get_world_size(process_group)
+
+        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
+        global_tensor_size = _flatten_tensor_size(global_size)
+
+        if len(local_shards) > 0:
+            local_sharded_tensor_metadata = \
+                build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)
+
+        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
+        # metadata by gathering local ShardedTensorMetadata
+        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
+        if world_size > 1:
+            gathered_metadatas = [None for _ in range(world_size)]
+
+            dist.all_gather_object(
+                gathered_metadatas,
+                local_sharded_tensor_metadata,
+                group=process_group
+            )
+        else:
+            gathered_metadatas = [local_sharded_tensor_metadata]
+
+        global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
+        tensor_properties = global_sharded_tensor_metadata.tensor_properties
+
+        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
+        # prepare initialization
+        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
+            global_sharded_tensor_metadata.shards_metadata
+        )
+        sharded_tensor = cls.__new__(cls,
+                                     spec,
+                                     global_sharded_tensor_metadata.size,
+                                     dtype=tensor_properties.dtype,
+                                     layout=tensor_properties.layout,
+                                     pin_memory=tensor_properties.pin_memory,
+                                     requires_grad=tensor_properties.requires_grad)
+        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
+
+        # attach local_shards to the ShardedTensor created
+        sharded_tensor._local_shards = local_shards
+
+        # run post initialization, i.e. map registration, rpc initialization
+        sharded_tensor._post_init()
+        return sharded_tensor
+
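+    # --- Editorial usage sketch (not part of the upstream class) -------------
+    # _init_from_local_shards above is also a collective: each rank passes only
+    # its own shards plus the global size, and the metadata is all-gathered to
+    # build the global view. A hedged sketch for a 1-D float tensor chunked
+    # into one shard of length 4 per rank (all names are illustrative):
+    #
+    #     rank = dist.get_rank()
+    #     shard_md = ShardMetadata(
+    #         shard_offsets=[rank * 4], shard_sizes=[4],
+    #         placement=f"rank:{rank}/cpu",
+    #     )
+    #     shard = Shard(torch.rand(4), shard_md)
+    #     st = ShardedTensor._init_from_local_shards(
+    #         [shard], dist.get_world_size() * 4
+    #     )
+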
+    @classmethod
+    def _init_from_local_tensor(
+        cls,
+        local_tensor: torch.Tensor,
+        sharding_spec: shard_spec.ShardingSpec,
+        *global_size: Sequence[int],
+        process_group: Optional[dist.ProcessGroup] = None,
+        init_rrefs=False,
+    ) -> ShardedTensor:
+        """
+        Initialize a ShardedTensor given only one local tensor, global sharded tensor
+        size and sharding spec on each rank.
+
+        Args:
+            local_tensor (Tensor): Single tensor of local shard stored in each rank.
+            sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
+                The specification describing how to shard the Tensor.
+            global_size (Sequence[int]): Size of the sharded tensor.
+            process_group (ProcessGroup, optional): The process group to aggregate on.
+                Default: None
+            init_rrefs (bool, optional): Whether or not to initialize
+                :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+                Need to initialize the RPC Framework if specified as ``True``.
+                Default: ``False``.
+
+        Returns:
+            A :class:`ShardedTensor` sharded based on the given sharding_spec with local
+                tensor stored in the current rank.
+
+        Examples:
+            >>> # xdoctest: +SKIP
+            >>> # All tensors below are of torch.int64 type.
+            >>> # We have 2 process groups, 2 ranks.
+            >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
+            >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]))
+            >>> local_tensor
+            tensor([[1, 2, 3, 4]]) # Rank 0
+            tensor([[3, 4, 5, 6]]) # Rank 1
+            >>> sharding_dim = 0
+            >>> sharding_spec = ChunkShardingSpec(
+                    dim=sharding_dim,
+                    placements=[
+                        "rank:0/cuda:0",
+                        "rank:1/cuda:1",
+                    ],
+                )
+            >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
+            >>> st
+            ShardedTensor(
+                ShardedTensorMetadata(
+                    shards_metadata=[
+                        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
+                        ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
+                    ],
+                    size=torch.Size([2, 4])
+            )
+            >>> st.local_tensor()
+            tensor([1, 2, 3, 4]) # Rank 0
+            tensor([3, 4, 5, 6]) # Rank 1
+
+        Warning: This API is experimental and subject to change. It lacks full
+                 cross-rank validation; we only validate the local shard on the
+                 current rank and fully rely on the user to ensure the local
+                 tensor is sharded based on the sharding spec.
+        """
+        warnings.warn(DEPRECATE_MSG)
+
+        if not local_tensor.is_contiguous():
+            raise ValueError('local_tensor is not a contiguous Tensor.')
+
+        global_tensor_size = _flatten_tensor_size(global_size)
+        tensor_properties = TensorProperties(
+            dtype=local_tensor.dtype,
+            layout=local_tensor.layout,
+            requires_grad=local_tensor.requires_grad,
+            memory_format=torch.contiguous_format,
+            pin_memory=local_tensor.is_pinned())
+        sharded_tensor_metadata = sharding_spec.build_metadata(
+            global_tensor_size,
+            tensor_properties
+        )
+
+        process_group = (
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+        current_rank = dist.get_rank(process_group)
+
+        local_shards: List[Shard] = []
+        for shard_metadata in sharded_tensor_metadata.shards_metadata:
+            rank, device = _parse_and_validate_remote_device(process_group, shard_metadata.placement)
+            if rank == current_rank:
+                local_shards.append(Shard(local_tensor, shard_metadata))
+
+        # TODO: figure out how the API should behave when some ranks have no shards
+        # see https://github.com/pytorch/pytorch/issues/7313
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards,
+            sharded_tensor_metadata,
+            process_group=process_group,
+            init_rrefs=init_rrefs,
+            sharding_spec=sharding_spec,
+        )
+
+    @classmethod
+    def _init_from_local_shards_and_global_metadata(  # type: ignore[override]
+        cls,
+        local_shards: List[Shard],
+        sharded_tensor_metadata: ShardedTensorMetadata,
+        process_group=None,
+        init_rrefs=False,
+        sharding_spec=None,
+    ) -> ShardedTensor:
+        """
+        Initialize a ShardedTensor with local shards and a global
+        ShardedTensorMetadata built on each rank.
+
+        Warning: This API is experimental and subject to change. It does
+                 not do cross-rank validation and fully relies on the user
+                 for the correctness of sharded_tensor_metadata on each rank.
+        """
+        process_group = (
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+        current_rank = dist.get_rank(process_group)
+
+        shards_metadata = sharded_tensor_metadata.shards_metadata
+
+        local_shard_metadatas = []
+
+        # collect local shard metadatas from the global sharded_tensor_metadata
+        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
+            rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement)
+
+            if current_rank == rank:
+                local_shard_metadatas.append(shard_metadata)
+
+        if len(local_shards) != len(local_shard_metadatas):
+            raise RuntimeError(
+                f'Number of local shards ({len(local_shards)}) does not match number of local '
+                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
+                f'on rank ({current_rank}) '
+            )
+
+        shards_metadata = sharded_tensor_metadata.shards_metadata
+        tensor_properties = sharded_tensor_metadata.tensor_properties
+
+        if len(shards_metadata) == 0:
+            raise ValueError("shards_metadata must not be empty!")
+
+        if tensor_properties.layout != torch.strided:
+            raise ValueError("Only torch.strided layout is currently supported")
+
+        if sharding_spec is None:
+            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        else:
+            spec = sharding_spec
+
+        sharded_tensor = ShardedTensor.__new__(
+            ShardedTensor,
+            spec,
+            sharded_tensor_metadata.size,
+            dtype=tensor_properties.dtype,
+            layout=tensor_properties.layout,
+            pin_memory=tensor_properties.pin_memory,
+            requires_grad=tensor_properties.requires_grad,
+        )
+
+        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
+            tensor_property_or_metadata = (
+                "tensor property" if is_property else "local ShardMetadata"
+            )
+            if expected != actual:
+                raise ValueError(
+                    f"Local shards' tensor {prop_name} property is incompatible with "
+                    f"{tensor_property_or_metadata} on rank {rank}: "
+                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
+                    f"local shard tensor {prop_name}={actual}."
+                )
+
+        for shard in local_shards:
+            shard_meta = shard.metadata
+            local_shard_tensor = shard.tensor
+            placement = shard_meta.placement
+            assert placement is not None, "Must specify placement for `Shard`!"
+            rank = placement.rank()
+            local_device = placement.device()
+
+            _raise_if_mismatch(
+                tensor_properties.layout,
+                local_shard_tensor.layout,
+                "layout",
+                rank,
+                True,
+            )
+            if not local_shard_tensor.is_contiguous():
+                raise ValueError(
+                    "Only torch.contiguous_format memory_format is currently supported"
+                )
+
+            _raise_if_mismatch(
+                shard_meta.shard_sizes,
+                list(local_shard_tensor.size()),
+                "size",
+                rank,
+            )
+            _raise_if_mismatch(
+                tensor_properties.pin_memory,
+                local_shard_tensor.is_pinned(),
+                "pin_memory",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
+            _raise_if_mismatch(
+                tensor_properties.dtype,
+                local_shard_tensor.dtype,
+                "dtype",
+                rank,
+                True,
+            )
+            _raise_if_mismatch(
+                tensor_properties.requires_grad,
+                local_shard_tensor.requires_grad,
+                "requires_grad",
+                rank,
+                True,
+            )
+
+        # check if shards_metadata have overlap shards
+        validate_non_overlapping_shards_metadata(shards_metadata)
+
+        # check if the shards_metadata is compatible with overall size of the sharded tensor.
+        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
+
+        # done validation, add local_shards
+        sharded_tensor._local_shards = local_shards
+        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
+
+        # run post initialization, i.e. map registration, rpc initialization
+        sharded_tensor._post_init()
+        return sharded_tensor
+
+    def sharding_spec(self) -> shard_spec.ShardingSpec:
+        """
+        Returns the ShardingSpec for the tensor.
+        """
+        return self._sharding_spec
+
+    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
+        """
+        Reshard a sharded tensor given the ``resharding_spec``. For now, we only support
+        a single local shard.
+
+        If ``resharding_spec`` is the same as the original one, this becomes a no-op.
+        If ``resharding_spec`` only shares the sharding dim with the original one,
+        we swap the local shards directly.
+        For more generic cases, we merge different shards across different ranks and split
+        the local shards based on the ``resharding_spec`` via `all_to_all` collective API.
+
+        Args:
+            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+                specification describing how the tensor is sharded.
+
+        Returns:
+            A :class:`ShardedTensor` object whose local shards are resharded.
+
+        Examples:
+            >>> # xdoctest: +SKIP
+            >>> # We have one process group with 4 ranks.
+            >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank
+            >>> tensor = torch.stack([tensor, tensor])
+            >>> tensor
+            tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0
+            tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1
+            tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2
+            tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3
+            >>> sharding_dim = 0
+            >>> spec = ChunkShardingSpec(
+                    dim=sharding_dim,
+                    placements=[
+                        "rank:0/cuda:0",
+                        "rank:1/cuda:1",
+                        "rank:2/cuda:2",
+                        "rank:3/cuda:3",
+                    ],
+                )
+            >>> current_offsets = [0] * 2
+            >>> current_offsets[0] = rank * 2
+            >>> shard_metadata = ShardMetadata(
+                    shard_offsets=copy.deepcopy(current_offsets),
+                    shard_sizes=tensor.size(),
+                    placement=spec.placements[rank],
+                )
+            >>> local_shards = [
+                    Shard(
+                        tensor=tensor,
+                        metadata=shard_metadata,
+                    )
+                ]
+            >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size())
+            >>> sharding_dim = 1
+            >>> resharding_spec = ChunkShardingSpec(
+                    dim=sharding_dim,
+                    placements=[
+                        "rank:0/cuda:0",
+                        "rank:1/cuda:1",
+                        "rank:2/cuda:2",
+                        "rank:3/cuda:3",
+                    ],
+                )
+            >>> st.reshard(resharding_spec)
+            >>> tensor = st.local_shards()[0].tensor
+            >>> tensor
+            tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0
+            tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1
+            tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2
+            tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3
+        """
+        warnings.warn(DEPRECATE_MSG)
+
+        if (
+            not isinstance(resharding_spec, shard_spec.ChunkShardingSpec) or
+            not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec)
+        ):
+            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
+        if (len(self.local_shards()) != 1):
+            raise NotImplementedError("Only single local shard supported for reshard.")
+
+        if self._sharding_spec.dim == resharding_spec.dim:  # type: ignore[attr-defined]
+            if self._sharding_spec.placements == resharding_spec.placements:  # type: ignore[attr-defined]
+                return self
+            else:
+                local_shards, shards_metadata = reshuffle_local_shard(
+                    self.local_tensor(),
+                    self.size(),  # type: ignore[arg-type]
+                    self._sharding_spec,
+                    resharding_spec,
+                    self._process_group,
+                )
+        else:
+            local_shards, shards_metadata = reshard_local_shard(
+                self.local_tensor(),
+                self.size(),  # type: ignore[arg-type]
+                self._sharding_spec,
+                resharding_spec,
+                self._process_group,
+            )
+        self._local_shards = local_shards
+        self._metadata.shards_metadata = shards_metadata
+        self._sharding_spec = resharding_spec
+        return self
+
+    def local_tensor(self) -> torch.Tensor:
+        """
+        Return the local tensor of a sharded_tensor. For now, we only support a single local shard.
+
+        Returns:
+            A :class:`torch.Tensor` of the local shard.
+        """
+        if len(self.local_shards()) != 1:
+            raise NotImplementedError("Only single local shard is supported.")
+        return self.local_shards()[0].tensor
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        def dispatch(st: ShardedTensor, func: Callable):
+            # Dispatch to custom user provided op first if it exists.
+            if func in _CUSTOM_SHARDED_OPS:
+                return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
+
+            # Dispatch to custom sharding spec op if it has one.
+            if _has_custom_op(st._sharding_spec, func):
+                return _dispatch_custom_op(
+                    st._sharding_spec,
+                    func,
+                    types,
+                    args,
+                    kwargs,
+                    st._process_group
+                )
+
+            if func in _SHARDED_OPS:
+                return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
+
+            raise RuntimeError(
+                f"torch function '{func.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} not supported for ShardedTensor!")
+
+        warnings.warn(DEPRECATE_MSG)
+        # Find ShardedTensor instance to get process_group and sharding_spec.
+        st_instance = None
+
+        def find_sharded_tensor(e):
+            nonlocal st_instance
+            if st_instance is None and isinstance(e, ShardedTensor):
+                st_instance = e
+
+        pytree.tree_map_(find_sharded_tensor, args)
+        pytree.tree_map_(find_sharded_tensor, kwargs)
+
+        if st_instance is not None:
+            return dispatch(st_instance, func)
+
+        raise RuntimeError(
+            f"torch function '{func.__name__}', with args: {args} and "
+            f"kwargs: {kwargs} not supported for ShardedTensor!")
+
+    def is_pinned(self) -> bool:  # type: ignore[override]
+        """
+        Returns True if the sharded tensor (each local shard) resides in pinned memory.
+        """
+        return self._metadata.tensor_properties.pin_memory
+
+    def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int):
+        self._remote_shards[rpc_rank] = remote_shards
+
+    def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]:
+        """
+        Returns a Dict[int, RRef] with keys being the RPC rank and values
+        being RRefs to shards on that rank. The RPC framework needs to be
+        initialized for this functionality.
+
+        Raises an exception if ShardedTensor was created with ``init_rrefs=False``
+        """
+        if not self._init_rrefs:
+            raise RuntimeError(
+                'ShardedTensor created with init_rrefs=False, no RRefs to remote shards available'
+            )
+        return self._remote_shards
+
+    def __hash__(self):
+        return id(self)
+
+    def __repr__(self):
+        return f'ShardedTensor({self._metadata})'
+
+    @dataclass
+    class ProcessGroupState:
+        """
+        State used for serialization/deserialization of the process group.
+        """
+        local_rank: int
+        global_rank: int
+        local_world_size: int
+        global_world_size: int
+
+    def __getstate__(self):
+        pg_state = ShardedTensor.ProcessGroupState(
+            distributed_c10d.get_rank(self._process_group),
+            distributed_c10d.get_rank(),
+            distributed_c10d.get_world_size(self._process_group),
+            distributed_c10d.get_world_size(),
+        )
+
+        return self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs
+
+    def __setstate__(self, state):
+        self._sharded_tensor_id = None
+        if not distributed_c10d.is_initialized():
+            raise RuntimeError(
+                'Need to initialize default process group using '
+                '"init_process_group" before loading ShardedTensor')
+
+        self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs = state
+
+        # Setup process group
+        from torch.distributed._shard.api import _get_current_process_group
+        self._process_group = _get_current_process_group()
+
+        # Validate process group.
+        local_rank = distributed_c10d.get_rank(self._process_group)
+        if pg_state.local_rank != local_rank:
+            raise RuntimeError(
+                f'Local rank at save time was {pg_state.local_rank}, but at '
+                f'load time was {local_rank}')
+
+        global_rank = distributed_c10d.get_rank()
+        if pg_state.global_rank != global_rank:
+            raise RuntimeError(
+                f'Global rank at save time was {pg_state.global_rank}, but at '
+                f'load time was {global_rank}')
+
+        local_world_size = distributed_c10d.get_world_size(self._process_group)
+        if pg_state.local_world_size != local_world_size:
+            raise RuntimeError(
+                f'Local world size at save time was {pg_state.local_world_size}, '
+                f'but at load time was {local_world_size}')
+
+        global_world_size = distributed_c10d.get_world_size()
+        if pg_state.global_world_size != global_world_size:
+            raise RuntimeError(
+                f'Global world size at save time was {pg_state.global_world_size}, '
+                f'but at load time was {global_world_size}')
+
+        self._post_init()
+
+
+def _create_tensor_from_params(*size, local_device, tensor_properties: TensorProperties):
+    """ Helper to construct tensor from size, device and common params. """
+    dtype = tensor_properties.dtype
+    layout = tensor_properties.layout
+    requires_grad = tensor_properties.requires_grad
+    memory_format = tensor_properties.memory_format
+    pin_memory = tensor_properties.pin_memory
+
+    return torch.empty(
+        *size, dtype=dtype, layout=layout,
+        device=local_device, requires_grad=requires_grad,
+        memory_format=memory_format, pin_memory=pin_memory
+    )
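+
+
+# A minimal, single-process sketch of how a TensorProperties bundle drives
+# _create_tensor_from_params: the helper simply forwards dtype/layout/device/
+# memory_format/pin_memory to torch.empty. The shape and CPU device below are
+# illustrative assumptions, not a required configuration.
+if __name__ == "__main__":
+    _example_props = TensorProperties(dtype=torch.float32, requires_grad=True)
+    _example_tensor = _create_tensor_from_params(
+        2, 4, local_device=torch.device("cpu"), tensor_properties=_example_props
+    )
+    assert _example_tensor.shape == (2, 4) and _example_tensor.requires_grad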
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e4a17b6a205060684617d12da849fccc1eee1a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import List, Tuple
+
+from torch.distributed._shard.sharded_tensor.logging_handlers import (
+    _log_handlers,
+)
+
+__all__: List[str] = []
+
+
+def _get_or_create_logger() -> logging.Logger:
+    logging_handler, log_handler_name = _get_logging_handler()
+    logger = logging.getLogger(f"sharding-spec-{log_handler_name}")
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    logging_handler.setFormatter(formatter)
+    logger.propagate = False
+    logger.addHandler(logging_handler)
+    return logger
+
+
+def _get_logging_handler(
+    destination: str = "default",
+) -> Tuple[logging.Handler, str]:
+    log_handler = _log_handlers[destination]
+    log_handler_name = type(log_handler).__name__
+    return (log_handler, log_handler_name)
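+
+
+# A minimal sketch of how these helpers fit together: the logger name is
+# derived from the handler type, and with the default NullHandler destination
+# the record below is simply discarded. The message text is illustrative.
+if __name__ == "__main__":
+    _logger = _get_or_create_logger()
+    _logger.debug("sharded tensor logging is wired up")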
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a775863e0b06b2f7597cd9cae85d19110271a1f6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict, List
+
+__all__: List[str] = []
+
+_log_handlers: Dict[str, logging.Handler] = {
+    "default": logging.NullHandler(),
+}
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..e917f1156eb3fb8c26f9bfbbd1a42850224bc639
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List
+
+import torch
+from torch.distributed._shard.metadata import ShardMetadata
+
+class MEM_FORMAT_ENCODING(Enum):
+    TORCH_CONTIGUOUS_FORMAT = 0
+    TORCH_CHANNELS_LAST = 1
+    TORCH_PRESERVE_FORMAT = 2
+
+@dataclass
+class TensorProperties:
+    """ Properties used to create :class:`Tensor` """
+
+    # Regular tensor fields
+    dtype: torch.dtype = field(default=torch.get_default_dtype())
+    layout: torch.layout = field(default=torch.strided)
+    requires_grad: bool = False
+    memory_format: torch.memory_format = field(default=torch.contiguous_format)
+    pin_memory: bool = False
+
+    def __getstate__(self):
+        # torch.memory_format cannot be pickled, so encode it as an enum value.
+        memory_format = self.memory_format
+        if memory_format == torch.contiguous_format:
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
+        elif memory_format == torch.channels_last:
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
+        elif memory_format == torch.preserve_format:
+            mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
+        else:
+            raise RuntimeError(f'Invalid torch.memory_format: {memory_format}')
+
+        return (
+            self.dtype,
+            self.layout,
+            self.requires_grad,
+            mem_format_encoding,
+            self.pin_memory,
+        )
+
+    def __setstate__(
+        self,
+        state,
+    ):
+        (self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory) = state
+
+        if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
+            memory_format = torch.contiguous_format
+        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
+            memory_format = torch.channels_last
+        elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
+            memory_format = torch.preserve_format
+        else:
+            raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}')
+
+        self.memory_format = memory_format
+
+    @staticmethod
+    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
+        return TensorProperties(
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+            memory_format=torch.contiguous_format,
+            pin_memory=tensor.is_pinned()
+        )
+
+@dataclass
+class ShardedTensorMetadata:
+    """
+    Represents metadata for :class:`ShardedTensor`
+    """
+
+    # Metadata about each shard of the Tensor
+    shards_metadata: List[ShardMetadata] = field(default_factory=list)
+
+    # Size of each dim of the overall Tensor.
+    size: torch.Size = field(default=torch.Size([]))
+
+    tensor_properties: TensorProperties = field(default_factory=TensorProperties)
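+
+
+# A minimal, single-process sketch of why TensorProperties defines
+# __getstate__/__setstate__: torch.memory_format is not picklable, so it is
+# round-tripped through MEM_FORMAT_ENCODING. The dtype and memory_format
+# chosen below are illustrative assumptions.
+if __name__ == "__main__":
+    import pickle
+
+    _props = TensorProperties(dtype=torch.float16, memory_format=torch.channels_last)
+    _restored = pickle.loads(pickle.dumps(_props))
+    assert _restored.dtype == torch.float16
+    assert _restored.memory_format == torch.channels_last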
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a1bd254327e83b6812ba6d5b2484913394c027
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py
@@ -0,0 +1,248 @@
+import copy
+from typing import List, Tuple
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import (
+    ProcessGroup,
+)
+import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharding_spec._internals import (
+    get_split_size,
+    get_chunked_dim_size,
+)
+from torch.distributed.nn.functional import (
+    all_to_all,
+    all_to_all_single,
+)
+from torch.distributed._shard.metadata import ShardMetadata
+
+from .shard import Shard
+
+
+def get_idx_from_placements(placements, current_rank) -> int:
+    """
+    Return the position of the current rank in the given placements.
+
+    Args:
+        placements(List[Union[_remote_device, str]]):
+            Specifies the placement of each shard of the Tensor. The size of
+            the list represents the number of shards to be created. This could
+            be a list of
+            :class:`torch.distributed._remote_device`'s. This list
+            could also contain a string which represents remote
+            device as accepted by
+            :class:`torch.distributed._remote_device`
+        current_rank (int): rank of the current process.
+
+    Returns:
+        An int which is the position of the current rank in the placement list.
+    """
+    for idx, placement in enumerate(placements):  # type: ignore[attr-defined]
+        if current_rank == placement.rank():  # type: ignore[union-attr]
+            return idx
+    raise RuntimeError('current_rank not in the placement.')
+
+
+def build_reshard_metadata(
+    st_size: torch.Size,
+    sharding_spec: shard_spec.ShardingSpec,
+    world_size: int,
+) -> Tuple[List[ShardMetadata], List[int]]:
+    """
+    Based on the given sharding spec, we calculate the offset and local shard size.
+    We then build a ShardMetadata on top of the calculation result.
+
+    Args:
+        st_size (torch.Size): The size of the sharded tensor.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+            specification describing how the tensor is sharded.
+        world_size (int): number of ranks.
+
+    Returns:
+        A Tuple of the following:
+            A List[`ShardMetadata`] which contains the metadata for the shard, including
+                offsets, lengths and device placement.
+            A List[int] which contains the ranks in the order of placement.
+    """
+    shard_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
+    shards_metadata = [None] * world_size
+    ranks = []
+    offsets = [0] * len(st_size)
+    split_size = get_split_size(st_size[shard_dim], world_size)
+    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
+        ranks.append(placement.rank())
+        sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx)
+        local_tensor_size = list(st_size)
+        local_tensor_size[shard_dim] = sharded_dim_size
+        shards_metadata[placement.rank()] = ShardMetadata(  # type: ignore[call-overload]
+            shard_offsets=copy.deepcopy(offsets),
+            shard_sizes=local_tensor_size,
+            placement=placement,
+        )
+        offsets[shard_dim] += sharded_dim_size
+    return shards_metadata, ranks  # type: ignore[return-value]
+
+
+def reshuffle_local_shard(
+    local_shard: torch.Tensor,
+    st_size: torch.Size,
+    sharding_spec: shard_spec.ShardingSpec,
+    resharding_spec: shard_spec.ShardingSpec,
+    pg: ProcessGroup,
+) -> Tuple[List[Shard], List[ShardMetadata]]:
+    """
+    Reshuffle the local shard directly when the reshard dim is the same as the original
+    sharding dim. Logically, we do this in two steps:
+    1. Collect all shards based on the original sharding spec.
+    2. Reshard the tensor based on the given resharding spec.
+
+    In reality, we consolidate the two steps into one by sending the local tensor to
+    the new shard directly based on the resharding spec.
+
+    Args:
+        local_shard (Tensor): Local tensor stored in the current rank.
+        st_size (torch.Size): The size of the sharded tensor.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+            specification describing how the tensor is sharded originally.
+        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+            specification describing how the tensor will be resharded.
+        pg (ProcessGroup): The process group to aggregate on.
+
+    Returns:
+        A Tuple of the following:
+            A List[`Shard`] which contains the local tensor and its metadata.
+            A List[`ShardMetadata`] which contains the metadata for the shard, including
+                offsets, lengths and device placement.
+    """
+    current_rank = dist.get_rank(pg)
+    world_size = dist.get_world_size(pg)
+    # Build shards_metadata first.
+    shards_metadata, ranks = build_reshard_metadata(
+        st_size, resharding_spec, world_size
+    )
+    # Get input split size for all2all.
+    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
+    split_size = get_split_size(st_size[reshard_dim], world_size)
+    input_split_sizes = [0] * world_size
+    idx = get_idx_from_placements(sharding_spec.placements, current_rank)  # type: ignore[attr-defined]
+    new_rank = resharding_spec.placements[idx].rank()  # type: ignore[union-attr, attr-defined]
+    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
+    # Get output split size for all2all.
+    output_split_sizes = [0] * world_size
+    new_idx = ranks.index(current_rank)
+    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx)
+    output_split_sizes[new_rank] = sharded_dim_size
+    # Get gathered_input for all2all.
+    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
+    gathered_input_size = list(local_shard.size())
+    gathered_input_size[0] = sharded_dim_size
+    gathered_input = torch.empty(gathered_input_size, device=local_shard.device, dtype=local_shard.dtype)
+    # all2all.
+    local_shard = all_to_all_single(
+        gathered_input,
+        local_shard,
+        input_split_sizes=input_split_sizes,
+        output_split_sizes=output_split_sizes,
+        group=pg,
+    )
+    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
+    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
+    return local_shards, shards_metadata
+
+
+def reshard_local_shard(
+    local_tensor: torch.Tensor,
+    st_size: torch.Size,
+    sharding_spec: shard_spec.ShardingSpec,
+    resharding_spec: shard_spec.ShardingSpec,
+    pg: ProcessGroup,
+) -> Tuple[List[Shard], List[ShardMetadata]]:
+    """
+    Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
+    different from the original sharding dim, we need to do two steps logically:
+    1. Collect all shards based on the original sharding spec.
+    2. Reshard the tensor based on the given resharding spec.
+
+    In reality, we consolidate the two steps into one by sending each rank the new
+    shard based on the resharding spec.
+
+    Args:
+        local_tensor (Tensor): Local tensor stored in the current rank.
+        st_size (torch.Size): The size of the sharded tensor.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+            specification describing how the tensor is sharded originally.
+        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
+            specification describing how the tensor will be resharded.
+        pg (ProcessGroup): The process group to aggregate on.
+
+    Returns:
+        A Tuple of the following:
+            A List[`Shard`] which contains the local tensor and its metadata.
+            A List[`ShardMetadata`] which contains the metadata for the shard, including
+                offsets, lengths and device placement.
+    """
+    current_rank = dist.get_rank(pg)
+    world_size = dist.get_world_size(pg)
+    current_sharding_dim = int(sharding_spec.dim)  # type: ignore[attr-defined]
+    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
+
+    # Build shards_metadata first.
+    shards_metadata, ranks = build_reshard_metadata(
+        st_size, resharding_spec, world_size
+    )
+
+    # Compute expected size
+    input_split_sizes = []
+    for metadata in shards_metadata:
+        input_split_sizes.append(metadata.shard_sizes[reshard_dim])
+    rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
+
+    if rearrange_input:
+        # Need to re-arrange reshard_dim of local_tensor before all2all.
+        indices: List[int] = []
+        for metadata in shards_metadata:
+            offset_start_idx = metadata.shard_offsets[reshard_dim]
+            split_size = metadata.shard_sizes[reshard_dim]
+            indices += range(offset_start_idx, offset_start_idx + split_size)
+        local_tensor = local_tensor.index_select(
+            reshard_dim, torch.tensor(indices, device=local_tensor.device)
+        )
+
+    # Because reshard_dim != original shard_dim. We need to compute the
+    # size of tensor from each rank.
+    output_tensor_list = [torch.tensor(1)] * world_size
+    split_size = get_split_size(st_size[current_sharding_dim], world_size)
+    rearrange_output_list = False
+    indices = []
+    for idx, placement in enumerate(sharding_spec.placements):  # type: ignore[attr-defined]
+        sharded_dim_size = get_chunked_dim_size(
+            st_size[current_sharding_dim], split_size, idx
+        )
+        output_tensor_size = list(st_size)
+        output_tensor_size[current_sharding_dim] = sharded_dim_size
+        output_tensor_size[reshard_dim] = input_split_sizes[current_rank]
+        output_tensor_list[
+            placement.rank()
+        ] = torch.empty(  # type: ignore[union-attr, index]
+            output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype
+        )
+        indices.append(placement.rank())  # type: ignore[union-attr, index, arg-type]
+        if idx != placement.rank():  # type: ignore[union-attr]
+            rearrange_output_list = True
+
+    # Perform autograd enabled all2all.
+    input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim)
+    input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple]
+    output_tensor_list = all_to_all(
+        output_tensor_list,
+        input_tensor_list,
+        group=pg,
+    )
+
+    if rearrange_output_list:
+        # Need to re-arrange original shard_dim of output_tensor_list.
+        output_tensor_list = [output_tensor_list[idx] for idx in indices]  # type: ignore[call-overload]
+    local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
+    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
+    return local_shards, shards_metadata
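+
+
+# A minimal, single-process sketch (no collectives are issued) of what
+# build_reshard_metadata computes: for a hypothetical 8x4 tensor chunked on
+# dim 0 across two CPU placements, each rank gets a 4x4 shard and the offsets
+# advance by the chunk size. The placements below are illustrative assumptions.
+if __name__ == "__main__":
+    _spec = shard_spec.ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
+    _metas, _ranks = build_reshard_metadata(torch.Size([8, 4]), _spec, world_size=2)
+    assert _ranks == [0, 1]
+    assert [m.shard_offsets for m in _metas] == [[0, 0], [4, 0]]
+    assert [m.shard_sizes for m in _metas] == [[4, 4], [4, 4]]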
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d0768bc9a05f1e9a159f8602e33b26c18c253ba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py
@@ -0,0 +1,58 @@
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from torch.distributed._shard.metadata import ShardMetadata
+from torch.distributed.remote_device import _remote_device
+
+
+@dataclass
+class Shard:
+    """
+    Container which holds the data for a shard as a Tensor and also
+    the associated metadata for that shard.
+
+    Args:
+        tensor(torch.Tensor): Local tensor for the shard.
+        metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`):
+            The metadata for the shard, including offsets, lengths and device placement.
+    """
+    __slots__ = ['tensor', 'metadata']
+    tensor: torch.Tensor
+    metadata: ShardMetadata
+
+    def __post_init__(self):
+        # verification between local tensor and metadata
+        if list(self.tensor.size()) != self.metadata.shard_sizes:
+            raise ValueError(
+                "Shard tensor size does not match with metadata.shard_lengths! "
+                f"Found shard tensor size: {list(self.tensor.size())}, "
+                f"metadata.shard_lengths: {self.metadata.shard_sizes}, "
+            )
+        placement_device = self.metadata.placement
+        if placement_device is not None and placement_device.device() != self.tensor.device:
+            raise ValueError(
+                f"Local shard tensor device does not match with local Shard's placement! "
+                f"Found local shard tensor device: {self.tensor.device}, "
+                f"local shard metadata placement device: {placement_device.device()}"
+            )
+
+    @classmethod
+    def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int):
+        """
+        Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.
+
+        Args:
+            tensor(torch.Tensor): Local tensor for the shard.
+            shard_offsets(List[int]): List of integers specifying the offset
+                of the shard on each dimension.
+            rank(int): The rank of the shard.
+        """
+        shard_sizes = list(tensor.size())
+        placement = _remote_device(f"rank:{rank}/{str(tensor.device)}")
+        shard_meta = ShardMetadata(
+            shard_offsets=shard_offsets,
+            shard_sizes=shard_sizes,
+            placement=placement
+        )
+        return Shard(tensor, shard_meta)
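+
+
+# A minimal, single-process sketch of Shard.from_tensor_and_offsets: the
+# ShardMetadata is derived from the tensor's size and device plus the given
+# offsets and rank. The rank-0 CPU shard below is an illustrative assumption.
+if __name__ == "__main__":
+    _local = torch.zeros(2, 4)
+    _shard = Shard.from_tensor_and_offsets(_local, shard_offsets=[0, 0], rank=0)
+    assert _shard.metadata.shard_sizes == [2, 4]
+    assert _shard.metadata.placement.rank() == 0
+    assert _shard.metadata.placement.device() == torch.device("cpu")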
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..343af02c96600212fa98ec0f5b773e719b909228
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py
@@ -0,0 +1,211 @@
+import collections.abc
+import copy
+from typing import Optional, List, Sequence
+
+import torch
+from torch.distributed import distributed_c10d
+from torch.distributed import rpc
+from torch.distributed._shard.sharding_spec._internals import (
+    check_tensor,
+    validate_non_overlapping_shards_metadata,
+)
+
+from torch.distributed._shard.metadata import ShardMetadata
+from .metadata import TensorProperties, ShardedTensorMetadata
+from .shard import Shard
+
+def _parse_and_validate_remote_device(pg, remote_device):
+    if remote_device is None:
+        raise ValueError("remote device is None")
+
+    worker_name = remote_device.worker_name()
+    rank = remote_device.rank()
+    device = remote_device.device()
+
+    # Validate rank, skip validation if rank is not part of process group.
+    if not distributed_c10d._rank_not_in_group(pg):
+        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
+            raise ValueError(f'Invalid rank: {rank}')
+
+    if worker_name is not None:
+        if not rpc._is_current_rpc_agent_set():
+            raise RuntimeError(f'RPC framework needs to be initialized for using worker names: {worker_name}')
+
+        workers = rpc._get_current_rpc_agent().get_worker_infos()
+        for worker in workers:
+            if worker.name == worker_name:
+                return worker.id, device
+
+        raise ValueError(f'Invalid worker name: {worker_name}')
+
+    return rank, device
+
+def _validate_output_tensor_for_gather(
+    my_rank: int,
+    dst_rank: int,
+    size: torch.Size,
+    dst_tensor: Optional[torch.Tensor],
+) -> None:
+    if dst_rank == my_rank:
+        if dst_tensor is None:
+            raise ValueError(
+                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
+            )
+        if tuple(size) != tuple(dst_tensor.size()):
+            raise ValueError(
+                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
+                f"but should be {tuple(size)}"
+            )
+    elif dst_tensor is not None:
+        raise ValueError(
+            "Argument ``dst_tensor`` must NOT be specified "
+            "on non-destination ranks."
+        )
+
+def _flatten_tensor_size(size) -> torch.Size:
+    """
+    Checks if the tensor size is valid, then flattens it and returns a torch.Size object.
+    """
+    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
+        dims = list(*size)
+    else:
+        dims = list(size)
+
+    for dim in dims:
+        if not isinstance(dim, int):
+            raise TypeError(f'size has to be a sequence of ints, found: {dims}')
+
+    return torch.Size(dims)
+
+def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
+    if is_local:
+        assert isinstance(ranks, int)
+        if expected != actual:
+            raise ValueError(f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! "
+                             f"Found one local shard tensor {prop_name}={expected}, "
+                             f"the other local shard tensor {prop_name}={actual}.")
+    else:
+        # cross-rank comparison; the ranks list should contain exactly two ranks
+        assert len(ranks) == 2
+        if expected != actual:
+            raise ValueError(f"ShardedTensor {prop_name} property does not match from different ranks! "
+                             f"Found {prop_name}={expected} on rank:{ranks[0]}, "
+                             f"and {prop_name}={actual} on rank:{ranks[1]}.")
+
+
+def build_metadata_from_local_shards(
+    local_shards: List[Shard],
+    global_size: torch.Size,
+    current_rank: int,
+    pg: distributed_c10d.ProcessGroup
+) -> ShardedTensorMetadata:
+
+    assert len(local_shards) > 0, "must have local shards!"
+    local_shard_metadatas: List[ShardMetadata] = []
+
+    first_shard_dtype = local_shards[0].tensor.dtype
+    first_shard_layout = local_shards[0].tensor.layout
+    first_shard_requires_grad = local_shards[0].tensor.requires_grad
+    first_shard_is_pinned = local_shards[0].tensor.is_pinned()
+
+    # 1). Validate local tensors and associated metadatas
+    for local_shard in local_shards:
+        local_shard_tensor = local_shard.tensor
+        local_shard_meta = local_shard.metadata
+        local_shard_metadatas.append(local_shard_meta)
+        rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement)
+
+        if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout:
+            raise ValueError(
+                f'Only torch.strided layout is currently supported, but found '
+                f'{local_shard_tensor.layout} on rank:{current_rank}!'
+            )
+
+        if not local_shard_tensor.is_contiguous():
+            raise ValueError('Only torch.contiguous_format memory_format is currently supported!')
+
+        if rank != current_rank:
+            raise ValueError(
+                f"Local shard metadata's rank does not match with the rank in its process group! "
+                f'Found current rank in the process group: {current_rank}, '
+                f"local ShardMetadata placement's rank: {rank}"
+            )
+        if local_shard_tensor.device != local_device:
+            raise ValueError(
+                f"Local shard tensor device does not match with local Shard's placement! "
+                f"Found local shard tensor device: {local_shard_tensor.device}, "
+                f"local shard metadata placement device: {local_device}"
+            )
+
+        _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
+        _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank)
+        _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank)
+        _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank)
+
+    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then
+    #    do all_gather to collect local_sharded_tensor_metadata from all ranks
+    local_tensor_properties = TensorProperties(
+        dtype=first_shard_dtype,
+        layout=first_shard_layout,
+        requires_grad=first_shard_requires_grad,
+        memory_format=torch.contiguous_format,
+        pin_memory=first_shard_is_pinned
+    )
+
+    local_sharded_tensor_metadata = ShardedTensorMetadata(
+        shards_metadata=local_shard_metadatas,
+        size=global_size,
+        tensor_properties=local_tensor_properties)
+
+    return local_sharded_tensor_metadata
+
+
+def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
+    global_sharded_tensor_metadata = None
+    global_metadata_rank = 0
+
+    for rank, rank_metadata in enumerate(gathered_metadatas):
+        if rank_metadata is None:
+            continue
+
+        if global_sharded_tensor_metadata is None:
+            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
+            global_metadata_rank = rank
+        else:
+            _raise_if_mismatch(global_sharded_tensor_metadata.size,
+                               rank_metadata.size,
+                               "global_size",
+                               [global_metadata_rank, rank],
+                               is_local=False)
+
+            # don't need to check layout and memory format as we already checked in local shards validation stage
+            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype,
+                               rank_metadata.tensor_properties.dtype,
+                               "dtype",
+                               [global_metadata_rank, rank],
+                               is_local=False)
+
+            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad,
+                               rank_metadata.tensor_properties.requires_grad,
+                               "requires_grad",
+                               [global_metadata_rank, rank],
+                               is_local=False)
+
+            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory,
+                               rank_metadata.tensor_properties.pin_memory,
+                               "pin_memory",
+                               [global_metadata_rank, rank],
+                               is_local=False)
+            # pass all validations, extend shards metadata
+            global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata)
+
+    if global_sharded_tensor_metadata is not None:
+        # check if shards_metadata have overlap shards
+        validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata)
+
+        # check if the shards_metadata is compatible with global size of the sharded tensor.
+        check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size)
+    else:
+        raise ValueError("ShardedTensor have no local shards on all ranks!")
+
+    return global_sharded_tensor_metadata
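+
+
+# A minimal, single-process sketch of build_global_metadata: two per-rank
+# metadata entries (hypothetical rank:0/cpu and rank:1/cpu shards of a 4x4
+# tensor) are merged, checked for overlap, and validated against the global
+# size. No collectives run here; the shapes and placements are assumptions.
+if __name__ == "__main__":
+    _meta_rank0 = ShardedTensorMetadata(
+        shards_metadata=[
+            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu")
+        ],
+        size=torch.Size([4, 4]),
+    )
+    _meta_rank1 = ShardedTensorMetadata(
+        shards_metadata=[
+            ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu")
+        ],
+        size=torch.Size([4, 4]),
+    )
+    _global = build_global_metadata([_meta_rank0, _meta_rank1])
+    assert len(_global.shards_metadata) == 2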
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharder.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e732208b557377bdcae044f6b8176e6a275fd092
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharder.py
@@ -0,0 +1,27 @@
+import abc
+import torch.nn as nn
+
+class Sharder(abc.ABC):
+    """
+    This is an interface which allows users to create more advanced
+    sharding strategies that cannot easily be composed with a
+    `ShardingSpec`.
+
+    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` can
+    take a `Sharder` object and call `shard` to shard the module,
+    then replace the original module with the returned sharded module.
+    """
+    @abc.abstractmethod
+    def shard(self, module: nn.Module) -> nn.Module:
+        """
+        Shard a module based on the implementation of this method, and
+        return the sharded version of the module.
+
+        Args:
+            module (:class:`torch.nn.Module`):
+                The module to apply sharding to.
+        Returns:
+            A :class:`torch.nn.Module` object that represents a module
+            that's already been sharded.
+        """
+        pass
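+
+
+# A minimal sketch of the Sharder contract: a hypothetical no-op sharder that
+# hands the module back unchanged, only to illustrate that `shard` consumes an
+# nn.Module and returns the nn.Module that should replace it. A real Sharder
+# would return a module whose parameters are ShardedTensors.
+class _IdentitySharder(Sharder):
+    def shard(self, module: nn.Module) -> nn.Module:
+        # No resharding is performed; the original module is returned as-is.
+        return module
+
+
+if __name__ == "__main__":
+    _linear = nn.Linear(4, 4)
+    assert _IdentitySharder().shard(_linear) is _linear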
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..263e9d538f9a7f4442b12c086ef4620a21d49edd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py
@@ -0,0 +1,4 @@
+from .api import (
+    ShardingPlan,
+    ShardingPlanner
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d4b60a2af0f8be1461dc5d68a492863dc28f785
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93ce3151d3c243d80e6bd1f9898115922183c856
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b04a51fb3f1d15a05dd96fd6dbb561b76a5a01
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py
@@ -0,0 +1,86 @@
+import abc
+import torch.nn as nn
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+from torch.distributed._shard.sharder import Sharder
+from torch.distributed._shard.sharding_spec import ShardingSpec
+
+@dataclass
+class ShardingPlan:
+    """
+    Representation of a sharding plan, which describes how to shard a module
+    across hosts. `plan` is used to shard module parameters according to the spec provided.
+    `output_plan` and `return_local_tensor` are optional; they are used to specify the output
+    layout of a module with a spec and when to convert back to the data parallel fashion.
+
+    Args:
+        plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`,
+              :class:`torch.distributed._shard.sharder.Sharder`]]):
+            a dict that describes how to shard a module; there are currently two ways to shard a module:
+                1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of
+                   a parameter to a `ShardingSpec`.
+                2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module
+                   to a `Sharder` object.
+        output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`], optional):
+            a dict that specifies the layout of a module's output which produces a ShardedTensor,
+            keyed by the name of the module to a ShardingSpec ("" as the key means the root module).
+            Default: `None`
+        return_local_tensor (List[str], optional): a list of strings; each element enables
+            a module's sharded output to be returned as a Tensor built from its local shards,
+            allowing further processing in a data parallel fashion ("" in the list means the
+            root module).
+            Default: None
+    Example:
+      Suppose we want to shard a module with two linear layers and then run it with DDP; we also
+      want to convert the output of the second linear layer back to DDP. We can do it as follows:
+
+        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
+        >>> class MyModule(nn.Module):
+        >>>     def __init__(self):
+        >>>        super().__init__()
+        >>>        self.fc1 = nn.Linear()
+        >>>        self.gelu = nn.GELU()
+        >>>        self.fc2 = nn.Linear()
+        >>>        self.relu = nn.ReLU()
+        >>>
+        >>>     def forward(self, input):
+        >>>         return self.relu(self.fc2(self.gelu(self.fc1(input))))
+
+
+        >>> # xdoctest: +SKIP("Undefined spec1, spec2")
+        >>> sharding_plan = ShardingPlan(
+        >>>    plan={
+        >>>        "fc1.weight": spec1,
+        >>>        "fc2.weight": spec2
+        >>>    },
+        >>>    output_plan={
+        >>>        "fc2": output_spec
+        >>>    },
+        >>>    return_local_tensor=["fc2"]
+        >>> )
+    """
+    plan: Dict[str, Union[ShardingSpec, Sharder]]
+    output_plan: Optional[Dict[str, ShardingSpec]] = None
+    return_local_tensor: Optional[List[str]] = None
+
+
+class ShardingPlanner(abc.ABC):
+    """
+    Default ShardingPlanner interface; it can be extended to
+    implement advanced sharding strategies.
+    """
+    @abc.abstractmethod
+    def build_plan(self, module: nn.Module) -> ShardingPlan:
+        """
+        Given an nn.Module, define how to shard the module across
+        ranks and return a ShardingPlan.
+        Args:
+            module (:class:`torch.nn.Module`):
+                The module to apply sharding to.
+        Returns:
+            A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that
+            represents how to shard the module.
+        """
+        pass
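+
+
+# A minimal sketch of a ShardingPlanner: a hypothetical planner that chunks
+# every 2-D parameter on dim 0 across two CPU placements. The placements and
+# the "chunk everything" policy are illustrative assumptions, not a
+# recommended layout.
+class _ChunkEverythingPlanner(ShardingPlanner):
+    def build_plan(self, module: nn.Module) -> ShardingPlan:
+        # Imported lazily so the sketch stays self-contained at call time.
+        from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
+        return ShardingPlan(
+            plan={
+                name: spec
+                for name, param in module.named_parameters()
+                if param.dim() == 2
+            }
+        )
+
+
+if __name__ == "__main__":
+    _plan = _ChunkEverythingPlanner().build_plan(nn.Linear(4, 4))
+    assert set(_plan.plan) == {"weight"}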
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f7a90f5156144a940d79475f65a94bc0ee49f1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py
@@ -0,0 +1,12 @@
+from .api import (
+    DevicePlacementSpec,
+    EnumerableShardingSpec,
+    PlacementSpec,
+    ShardingSpec,
+    _infer_sharding_spec_from_shards_metadata,
+)
+from .chunk_sharding_spec import (
+    ChunkShardingSpec as ChunkShardingSpec,
+)
+
+from torch.distributed._shard.metadata import ShardMetadata
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b06e7528523dd3fd083f00ef18b368857eb057b1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1285ef08f4eb974f152ade69c3f64301a767f725
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f2a2f3d0d6d1f070bf6384796010634747bf229
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d9ef84e97e415ce9410c9b68baa51eec1b1741b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ae66fe5e0332a00e01c38cef297bc630db846b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py
@@ -0,0 +1,209 @@
+from typing import List, Optional, Tuple
+
+from torch.distributed._shard.metadata import ShardMetadata
+
+
+def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
+    """
+    Checks if two shards overlap.
+    """
+
+    # For each dim of each shard, check if one shard resides on the other
+    # end of second shard with respect to that dim. As an example for a 2D
+    # shard, we would check if one shard is above or on the left of the
+    # other shard.
+    ndims = len(shard1.shard_offsets)
+    for i in range(ndims):
+        if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]:
+            return False
+        if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]:
+            return False
+
+    return True
+
+
+def _find_nd_overlapping_shards(
+    shards: List[ShardMetadata], sharded_dims: List[int]
+) -> Optional[Tuple[int, int]]:
+    # Each shard has len(sharded_dims) tuples. Each tuple represents the
+    # [begin, end] (inclusive) range of that shard in that dimension.
+    shard_intervals = [
+        [
+            (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1)
+            for dim in sharded_dims
+        ]
+        for s in shards
+    ]
+
+    for i in range(len(shards)):
+        shard_i = shard_intervals[i]
+        for j in range(i + 1, len(shards)):
+            shard_j = shard_intervals[j]
+            # For each dim of each shard, check if one shard resides on the other
+            # end of second shard with respect to that dim. As an example for a 2D
+            # shard, we would check if one shard is above or on the left of the
+            # other shard.
+            overlap = True
+            for interval_i, interval_j in zip(shard_i, shard_j):
+                if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]:
+                    overlap = False
+                    break
+            if overlap:
+                return (i, j)
+    return None
+
+
+def _find_1d_overlapping_shards(
+    shards: List[ShardMetadata], dim: int
+) -> Optional[Tuple[int, int]]:
+    # (begin, end, index_in_shards). Begin and end are inclusive.
+    intervals = [
+        (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i)
+        for i, s in enumerate(shards)
+    ]
+    intervals.sort()
+    for i in range(len(shards) - 1):
+        if intervals[i][1] >= intervals[i + 1][0]:
+            return (intervals[i][2], intervals[i + 1][2])
+    return None
+
+
+def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
+    """
+    Ensures none of the shards overlap with each other.
+
+    Args:
+        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
+            each shard.
+    Raises:
+        ``ValueError`` if there's overlap in any two shards.
+    """
+    if not shards or len(shards) == 1:
+        return
+
+    sharded_dims: List[int] = []
+    for dim in range(len(shards[0].shard_offsets)):
+        for i in range(1, len(shards)):
+            if (
+                shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] or
+                shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim]
+            ):
+                sharded_dims.append(dim)
+                break
+
+    pair: Optional[Tuple[int, int]] = None
+    if len(sharded_dims) == 0:
+        # All shards are identical (no dim is partitioned), so any two overlap.
+        pair = (0, 1)
+    elif len(sharded_dims) == 1:
+        # Shards are partitioned over only one dimension. Overlap can be found
+        # using an O(n log n) interval-overlap check.
+        pair = _find_1d_overlapping_shards(shards, sharded_dims[0])
+    else:
+        # Shards are partitioned over more than one dimension. Fall back to a
+        # pair-wise check. Even though O(n log n) algorithms (line sweep) exist
+        # for 2D overlap, the implementation is not trivial and may not justify
+        # the time savings in most cases.
+        pair = _find_nd_overlapping_shards(shards, sharded_dims)
+
+    if pair:
+        raise ValueError(f'Shards {shards[pair[0]]} and {shards[pair[1]]} overlap')
+
+
+def check_tensor(shards_metadata, tensor_dims) -> None:
+    """
+    Checks if the shards_metadata is compatible with the provided tensor dims.
+
+    Args:
+        shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata`
+            objects representing each shard of the tensor.
+        tensor_dims(Sequence of int): Dimensions of tensor to verify
+    Raises:
+        ``ValueError`` if not compatible.
+    """
+
+    # If the tensor's volume matches the total volume of all shards and
+    # all shard boundaries are within tensor dims, we have a compatible
+    # sharding spec for this tensor. Note that we have already verified
+    # we don't have overlapping shards.
+    tensor_rank = len(tensor_dims)
+    shards_rank = len(shards_metadata[0].shard_offsets)
+    if tensor_rank != shards_rank:
+        raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}')
+
+    total_shard_volume = 0
+    for shard in shards_metadata:
+        shard_volume = 1
+        for i, shard_length in enumerate(shard.shard_sizes):
+            shard_volume *= shard_length
+            if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]:
+                raise ValueError(
+                    f'Shard offset {shard.shard_offsets[i]} and length '
+                    f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}')
+        total_shard_volume += shard_volume
+
+    tensor_volume = 1
+    for size in tensor_dims:
+        tensor_volume *= size
+
+    if total_shard_volume != tensor_volume:
+        # TODO: Can we improve this error message to point out the gaps?
+        raise ValueError(
+            f'Total volume of shards: {total_shard_volume} '
+            f'does not match tensor volume: {tensor_volume}, in other words '
+            f'all the individual shards do not cover the entire tensor')
+
+def get_split_size(dim_size, chunks):
+    """
+    Computes the split size in line with ``torch.chunk``.
+
+    Args:
+        dim_size(int): Size of the dimension being chunked.
+        chunks(int): Number of chunks to create for ``dim_size``.
+
+    Returns:
+        An int indicating the split size to use.
+    """
+    return (dim_size + chunks - 1) // chunks
+
+def get_chunked_dim_size(dim_size, split_size, idx):
+    """
+    Computes the dim size of the chunk at the provided ``idx`` given ``dim_size``
+    and ``split_size``.
+
+    Args:
+        dim_size(int): Size of the dimension being chunked.
+        split_size(int): The chunk size for each chunk of ``dim_size``.
+        idx(int): The index of chunk whose dim size is being requested.
+
+    Returns:
+        An int indicating the dim size of the chunk.
+    """
+    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
+
+def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
+    """
+    Generate the start position and chunk size for the given rank for
+    chunk sharding.
+
+    Args:
+        sharding_dim_size(int): The dimension length which we shard on.
+        world_size(int): number of ranks.
+        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
+            sharding spec.
+        rank(int): rank of the current process.
+
+    Returns:
+        start_pos(int): start position of sharded tensor on the given rank.
+        chunk_size(int): chunk size of sharded tensor on the given rank.
+    """
+    split_size = get_split_size(sharding_dim_size, world_size)
+    current_offsets = 0
+    start_pos = current_offsets
+    for idx, placement in enumerate(spec.placements):
+        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+        if rank == placement.rank():
+            start_pos = current_offsets
+            break
+        current_offsets += chunk_size
+    return start_pos, chunk_size  # type: ignore[possibly-undefined]
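The split helpers above mirror ``torch.chunk``: ``get_split_size`` is a ceiling division, and ``get_chunked_dim_size`` shrinks (or zeroes out) the trailing chunks once the dimension is exhausted. Below is a minimal standalone sketch of that arithmetic; the helpers are re-implemented locally rather than imported from the vendored module path, so this is only an illustration.

# Standalone sketch of the torch.chunk-style split arithmetic used above.
def get_split_size(dim_size, chunks):
    # Ceiling division: the size of every chunk except possibly the last.
    return (dim_size + chunks - 1) // chunks

def get_chunked_dim_size(dim_size, split_size, idx):
    # Size of chunk `idx`, clamped to 0 once the dimension is exhausted.
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

if __name__ == "__main__":
    dim_size, chunks = 10, 4
    split = get_split_size(dim_size, chunks)                       # 3
    sizes = [get_chunked_dim_size(dim_size, split, i) for i in range(chunks)]
    print(split, sizes)                                            # 3 [3, 3, 3, 1]
    assert sum(sizes) == dim_size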
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..e97286f092bda3e5b50901cbda3f4fd2c363e249
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py
@@ -0,0 +1,242 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+import functools
+from typing import Callable, Dict, List, TYPE_CHECKING
+
+import torch
+
+from ._internals import (
+    check_tensor,
+    get_chunked_dim_size,
+    get_split_size,
+    validate_non_overlapping_shards_metadata
+)
+from torch.distributed._shard.metadata import ShardMetadata
+
+import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
+from torch.distributed._shard.op_registry_utils import _decorator_func
+
+if TYPE_CHECKING:
+    # Only import ShardedTensor during type checking; exclude it
+    # at run-time to avoid a circular dependency.
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+class PlacementSpec(ABC):  # noqa: B024
+    """
+    Base class representing the placement of an entity. Subclasses of this
+    class can be used to specify customized placements which might not be
+    covered by existing APIs.
+    """
+    pass
+
+
+@dataclass
+class DevicePlacementSpec(PlacementSpec):
+    """
+    Associates placement of an entity with a single device.
+
+    Args:
+        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
+    """
+
+    device: torch.distributed._remote_device
+
+    def __post_init__(self):
+        if not isinstance(self.device, torch.distributed._remote_device):
+            self.device = torch.distributed._remote_device(self.device)
+
+class ShardingSpec(ABC):
+    """
+    Base class representing sharding specifications.
+    """
+    @abstractmethod
+    def build_metadata(self,
+                       tensor_sizes: torch.Size,
+                       tensor_properties: sharded_tensor_meta.TensorProperties,
+                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
+        """
+        Given a global tensor size, define how to shard a tensor of this shape
+        across ranks and return a ShardedTensorMetadata.
+        Args:
+            tensor_sizes (:class:`torch.Size`):
+                The tensor shape to shard on, a `torch.Size` object that represents the
+                tensor shape to be sharded according to the ShardingSpec.
+            tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
+                Tensor properties used to create a ShardedTensor.
+        Returns:
+            A :class:`ShardedTensorMetadata` object that encodes the information about
+            the layout of the ShardedTensor and its properties.
+        """
+
+    @abstractmethod
+    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
+        """
+        Given a global tensor on src_rank, shard this tensor
+        across ranks within the process group and return a ShardedTensor.
+        Args:
+            tensor (:class:`torch.Tensor`): The tensor to be sharded.
+        Keyword args:
+            src_rank (int, optional): The source rank which is used as the ground truth of
+                the data for the parameter that would be sharded and scattered
+                across the rest of the ranks.
+                Default: 0.
+            process_group (ProcessGroup, optional): The process group to work on. If None,
+                the default process group will be used.
+        Returns:
+            A :class:`ShardedTensor` sharded from the given tensor.
+        """
+
+# Ops customized for a particular ShardingSpec.
+_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
+
+def _has_custom_op(sharding_spec, op):
+    """
+    Returns whether or not the ShardingSpec has a custom op implementation.
+    """
+    class_name = type(sharding_spec).__qualname__
+    return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]
+
+def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group):
+    """
+    Calls the custom op for this ShardingSpec if it exists.
+    """
+    class_name = type(sharding_spec).__qualname__
+    if not _has_custom_op(sharding_spec, op):
+        raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
+    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
+    return func(types, args, kwargs, process_group)
+
+def custom_sharding_spec_op(sharding_spec_class, func):
+    """
+    Decorator to allow custom registration of ops.
+    Args:
+        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
+        func(Callable): The op to override (ex: torch.bmm)
+    """
+    class_name = sharding_spec_class.__qualname__
+    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
+        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
+    return functools.partial(
+        _decorator_func,
+        op=func,
+        op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
+    )
+
+
+@dataclass
+class EnumerableShardingSpec(ShardingSpec):
+    """
+    This is a type of ShardingSpec that allows users to specify a generic
+    sharding scheme by enumerating exactly how each shard is laid out.
+
+    Args:
+        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
+            each shard. Note that none of the shards should overlap.
+    """
+
+    shards: List[ShardMetadata]
+
+    def __post_init__(self):
+        if len(self.shards) == 0:
+            raise ValueError(f'Empty shard list provided: {self.shards}')
+
+        # Validate that each shard has the same rank (number of dimensions).
+        rank = -1
+        for shard in self.shards:
+            if rank != -1 and rank != len(shard.shard_offsets):
+                raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
+            rank = len(shard.shard_offsets)
+
+        validate_non_overlapping_shards_metadata(self.shards)
+
+    def build_metadata(self,
+                       tensor_sizes: torch.Size,
+                       tensor_properties: sharded_tensor_meta.TensorProperties,
+                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
+        # check if shards form a valid tensor
+        check_tensor(self.shards, tensor_sizes)
+        return sharded_tensor_meta.ShardedTensorMetadata(
+            self.shards,
+            tensor_sizes,
+            tensor_properties
+        )
+
+    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
+        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
+        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")
+
+
+def _infer_sharding_spec_from_shards_metadata(shards_metadata):
+    """
+    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
+    If the tensor is sharded along only one dimension, we then verify whether it
+    matches a ChunkShardingSpec. To verify, we first get the total length along that
+    dimension and perform a chunk split with the given placements to see whether we
+    get the same chunk sizes as the given shards_metadata. If not, we assume it is enum sharded.
+
+    Args:
+        shards_metadata (List[ShardMetadata]): List of Metadata of local shards.
+
+    Returns:
+        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding
+            spec for one sharded tensor.
+    """
+    placements = []
+    chunk_sharding_dim = None
+    chunk_offset_list = []
+    shard_size_list = []
+    shard_offset_list = []
+    # collect local shard metadatas from the global sharded_tensor_metadata
+    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
+        placements.append(shard_metadata.placement)
+        local_offsets = shard_metadata.shard_offsets
+        chunk_offset_list.append(sum(local_offsets))
+        shard_size_list.append(shard_metadata.shard_sizes)
+        shard_offset_list.append(shard_metadata.shard_offsets)
+        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
+        # If the offset is [0, 0, ..., 0] (all zeros),
+        # we cannot decide how the tensor is sharded.
+        if len(shard_dims) == 0:
+            continue
+        # If the offset is [0, N, ..., 0, M, 0, ..., 0],
+        # we are sure it's sharded by more than one dimension.
+        if len(shard_dims) != 1:
+            chunk_sharding_dim = None
+            break
+        # If the offset is [0, 0, ..., 0, M, 0, ..., 0], i.e. it's sharded along just
+        # one dimension, we need to make sure all ranks share the same dimension.
+        if not chunk_sharding_dim:
+            chunk_sharding_dim = shard_dims[0]
+        elif chunk_sharding_dim != shard_dims[0]:
+            chunk_sharding_dim = None
+            break
+
+    if chunk_sharding_dim is not None:
+        # Ensure we infer the correct placement order from offsets
+        placements = [
+            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
+        ]
+
+        from .chunk_sharding_spec import ChunkShardingSpec
+        chunk_spec = ChunkShardingSpec(
+            dim=chunk_sharding_dim,
+            placements=placements,
+        )
+
+        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
+        shard_total_length = sum(shard_sizes)
+        shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list])
+
+        chunks = len(placements)
+        split_size = get_split_size(shard_total_length, chunks)
+        chunk_shard_sizes = sorted(
+            [
+                get_chunked_dim_size(shard_total_length, split_size, idx)
+                for idx in range(chunks)
+            ]
+        )
+        # Should match ChunkShardingSpec offsets calculation
+        chunk_shard_offsets = [split_size * idx for idx in range(chunks)]
+        if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets:
+            return chunk_spec
+    return EnumerableShardingSpec(shards_metadata)
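``_infer_sharding_spec_from_shards_metadata`` decides between ``ChunkShardingSpec`` and ``EnumerableShardingSpec`` mostly from the shard offsets: chunk sharding is only possible when every non-zero offset touches the same single dimension. Below is a simplified, self-contained sketch of that offset heuristic, using plain lists of offsets instead of ``ShardMetadata``; the function name is illustrative, not part of the module above.

# Simplified sketch: infer the single chunked dimension from shard offsets,
# or return None when the layout cannot be a ChunkShardingSpec.
from typing import List, Optional

def infer_chunk_dim(shard_offsets: List[List[int]]) -> Optional[int]:
    chunk_dim = None
    for offsets in shard_offsets:
        nonzero_dims = [d for d, off in enumerate(offsets) if off != 0]
        if not nonzero_dims:          # all-zero offsets carry no information
            continue
        if len(nonzero_dims) != 1:    # sharded along more than one dim
            return None
        if chunk_dim is None:
            chunk_dim = nonzero_dims[0]
        elif chunk_dim != nonzero_dims[0]:
            return None               # shards disagree on the sharded dim
    return chunk_dim

# Four shards of a (12, 8) tensor chunked along dim 0 -> inferred dim is 0.
print(infer_chunk_dim([[0, 0], [3, 0], [6, 0], [9, 0]]))   # 0
# Offsets touching two dims -> cannot be a ChunkShardingSpec.
print(infer_chunk_dim([[0, 0], [3, 4]]))                   # None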
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..723908b94f494a595b0fb4209ed6db4a5073c85c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
@@ -0,0 +1,202 @@
+from dataclasses import dataclass
+import torch
+import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
+from torch.distributed._shard.metadata import ShardMetadata
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._shard.sharded_tensor.utils import (
+    _parse_and_validate_remote_device
+)
+from torch.distributed._shard._utils import narrow_tensor
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as distributed_c10d
+from typing import List, Union, TYPE_CHECKING
+from ._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+
+from .api import ShardingSpec
+
+if TYPE_CHECKING:
+    # Only import ShardedTensor during type checking; exclude it
+    # at run-time to avoid a circular dependency.
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+@dataclass
+class ChunkShardingSpec(ShardingSpec):
+    """
+    This is a type of ShardingSpec that defines the placement as being sharded
+    across multiple devices. In particular, it represents sharding a Tensor
+    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
+
+    The semantics of how a tensor is partitioned are in line with
+    :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
+    specified ``dim`` and ``chunks`` in torch.chunk is the number of elements
+    in the specified ``placements``.
+
+    Args:
+        dim (int or str):
+            The dimension to shard on, could be an integer representing the
+            dimension or a string in case of named tensors where dimensions are
+            named. Note that named tensor support is not added yet.
+        placements(List[Union[_remote_device, str]]):
+            Specifies the placement of each shard of the Tensor. The size of
+            the list represents the number of shards to be created. This can
+            be a list of
+            :class:`torch.distributed._remote_device`'s. The list
+            can also contain strings, which represent a remote
+            device as accepted by
+            :class:`torch.distributed._remote_device`
+    """
+
+    ShardingDim = Union[int, str]
+
+    dim: ShardingDim
+    placements: List[Union[torch.distributed._remote_device, str]]
+
+    def __post_init__(self):
+        self._verify_dim(self.dim)
+        for i, remote_device in enumerate(self.placements):
+            if not isinstance(remote_device, torch.distributed._remote_device):
+                self.placements[i] = torch.distributed._remote_device(remote_device)
+
+    @staticmethod
+    def _verify_dim(dim):
+        # Validate the sharding spec.
+        # TODO: support named dimension
+        if isinstance(dim, str):
+            raise NotImplementedError(
+                "ChunkShardingSpec does not support named dimension yet!"
+            )
+
+        if not isinstance(dim, int):
+            raise ValueError(
+                f"Sharding dim needs to be an integer, found: {dim}"
+            )
+
+    def build_metadata(self,
+                       tensor_sizes: torch.Size,
+                       tensor_properties: sharded_tensor_meta.TensorProperties,
+                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
+        tensor_num_dim = len(tensor_sizes)
+
+        self._verify_dim(self.dim)
+        if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim:  # type: ignore[operator]
+            raise ValueError(f"Invalid sharding dim: {self.dim}")
+
+        shards_metadata = []
+        sharding_dim_size = tensor_sizes[self.dim]  # type: ignore[index]
+        chunks = len(self.placements)
+        split_size = get_split_size(sharding_dim_size, chunks)
+        for idx, placement in enumerate(self.placements):
+            # generate ShardMetadata for each placement device
+            chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+            shard_size = list(tensor_sizes)
+            current_offsets = [0] * tensor_num_dim
+            current_offsets[self.dim] = split_size * idx  # type: ignore[index]
+            shard_size[self.dim] = chunked_dim_size  # type: ignore[index]
+
+            shard_metadata = ShardMetadata(
+                shard_offsets=current_offsets,
+                shard_sizes=shard_size,
+                placement=placement,
+            )
+            shards_metadata.append(shard_metadata)
+
+        return sharded_tensor_meta.ShardedTensorMetadata(
+            shards_metadata,
+            tensor_sizes,
+            tensor_properties
+        )
+
+
+    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
+        """
+        Args:
+            src_rank: group rank relative to ``process_group``
+
+            N.B. If ``process_group`` is None, ``src_rank`` is a global rank.
+        """
+        # relative imports to avoid circular dependency
+        from torch.distributed._shard.sharded_tensor import (
+            ShardedTensor
+        )
+        tensor_properties = sharded_tensor_meta.TensorProperties(
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+            memory_format=torch.contiguous_format,
+            pin_memory=tensor.is_pinned()
+        )
+        current_rank = dist.get_rank(process_group)
+        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
+        local_shards = []
+        local_tensor = None
+        local_metadata = None
+        tensors_to_scatter = [None] * dist.get_world_size(process_group)
+
+        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
+        chunks = len(self.placements)
+        split_size = get_split_size(sharding_dim_size, chunks)
+        scatter_shape = list(tensor.size())
+        scatter_shape[self.dim] = split_size  # type: ignore[index]
+
+        for shard_meta in tensor_meta.shards_metadata:
+            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
+            if current_rank == src_rank:
+                # Narrow to get the shard for this rank. We don't want autograd
+                # recording here for the narrow op, and 'local_shard' should be a
+                # leaf variable in the autograd graph.
+                narrowed_tensor = narrow_tensor(tensor, shard_meta)
+                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
+                    # For the last shard, which might be smaller than the other shards,
+                    # resize the narrowed tensor to the same size and use it for
+                    # the scatter collective, as dist.scatter requires same-size
+                    # inputs on every rank
+                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
+                else:
+                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
+
+                tensors_to_scatter[rank] = tensor_to_scatter
+
+            if current_rank == rank:
+                local_tensor = torch.empty(
+                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
+                local_metadata = shard_meta
+
+        # Each rank should have local_tensor and local_metadata initialized if we built
+        # the metadata list correctly.
+        assert local_tensor is not None
+        assert local_metadata is not None
+
+        # Scatter the shards to all ranks in the pg
+        # scatter takes the global rank as ``src``
+        src_for_scatter = src_rank
+        if process_group is not None and process_group is not distributed_c10d._get_default_group():
+            src_for_scatter = distributed_c10d.get_global_rank(process_group, src_for_scatter)
+
+        dist.scatter(
+            local_tensor,
+            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
+            src=src_for_scatter,
+            group=process_group
+        )
+
+        if list(local_tensor.size()) != local_metadata.shard_sizes:
+            # detach again after receiving to ensure local shards remain a leaf node
+            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()
+
+        # Sync requires_grad to local_shard.
+        local_tensor.requires_grad = tensor.requires_grad
+
+        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
+
+        st = ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards,
+            tensor_meta,
+            process_group=process_group)
+
+        # Manually set sharding_spec
+        st._sharding_spec = self
+
+        return st
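``ChunkShardingSpec.build_metadata`` places shard ``idx`` at offset ``split_size * idx`` along the sharding dim and sizes it with ``get_chunked_dim_size``. Below is a small standalone sketch of the resulting layout, assuming placements are given in rank order; the helper name is illustrative only.

# Sketch of the shard layout produced by ChunkShardingSpec.build_metadata:
# offsets advance by split_size along the sharding dim, and the last
# shard(s) shrink (or become empty) once the dim is exhausted.
def chunk_layout(tensor_sizes, dim, num_placements):
    split_size = (tensor_sizes[dim] + num_placements - 1) // num_placements
    layout = []
    for idx in range(num_placements):
        offsets = [0] * len(tensor_sizes)
        offsets[dim] = split_size * idx
        size = list(tensor_sizes)
        size[dim] = max(min(tensor_sizes[dim], split_size * (idx + 1)) - split_size * idx, 0)
        layout.append((offsets, size))
    return layout

# A (13, 4) tensor sharded on dim 0 across 4 placements:
for offsets, size in chunk_layout([13, 4], dim=0, num_placements=4):
    print(offsets, size)
# [0, 0] [4, 4]
# [4, 0] [4, 4]
# [8, 0] [4, 4]
# [12, 0] [1, 4]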
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c2ff6960456ccaf873b1b21d46a740c37a69788
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b725af33c32ef205fee58422a021e9c37f7e3ecc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a24d3e2d9719147b563ba12a798317a1fe4bb14
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4858d14bed6274985f07b8bc8b667dc929cfec53
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..083bd959a9dc9ddf3b7c7dc12e384194f98dced0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
@@ -0,0 +1,349 @@
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunk_sharding_params,
+    get_chunked_dim_size,
+    get_split_size,
+)
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import (
+    _all_gather_base,
+    all_reduce,
+    all_to_all_single,
+)
+
+
+def _chunk_sharding_spec_check(spec, op):
+    """
+    For the given op implementation check if the sharding spec is ChunkShardingSpec.
+    """
+    if not isinstance(spec, ChunkShardingSpec):
+        raise NotImplementedError(
+            f"Only ChunkShardingSpec supported for '{op.__name__}'."
+        )
+
+
+def _register_sharded_op_on_local_tensor(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    the single local tensor of the sharded tensor, such as
+    ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new local tensor, sharding spec and sharded tensor size.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate the new local tensor, sharding spec and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                the single local tensor of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+        ``__torch_function__`` dispatch.
+    """
+
+    @custom_sharding_spec_op(ChunkShardingSpec, op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None):
+        st = args[0]
+        sharding_spec = st.sharding_spec()
+        if len(st.local_shards()) != 1:
+            raise TypeError(
+                f"torch function '{op.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} only supported for single local tensor!"
+            )
+        st_size = st.size()
+        if customized_func:
+            local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
+        else:
+            args = (st.local_tensor(), *args[1:])
+            local_tensor = op(*args, **kwargs)
+        return ShardedTensor._init_from_local_tensor(
+            local_tensor.contiguous(),
+            sharding_spec,
+            st_size,  # type: ignore[arg-type]
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+        )
+
+
+def _handle_col_wise_sharding_base(
+    op_func,
+    col_dim,
+    input,
+    world_size,
+    weight,
+    local_shard,
+    pg,
+    gathered_inputs,
+    mode=None,
+    gathered_per_sample_weights=None,
+    gathered_offsets=None,
+    padding_idx=None,
+):
+    """
+    For col-wise sharding of weight, much of the logic is common,
+    so we extract it into this function:
+    Step 1. Get the input from each rank.
+    Step 2. Perform the op on the concatenated tensor.
+    Step 3. Distribute results to each rank with column rearrangement.
+    Step 4. Concatenate all results from all ranks.
+
+    Args:
+        op_func: operator which is applied to the input tensor.
+        col_dim: dim of result tensor after the operation.
+        input: tensor to be applied op on.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise sharded weight tensor.
+        pg: process group.
+        gathered_inputs: list of inputs from all ranks. If specified, we
+            don't need to communicate with each rank any more.
+        mode: aggregation mode of EmbeddingBag.
+        gathered_per_sample_weights: per_sample_weights across all ranks.
+        gathered_offsets: offsets across all ranks.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+
+    Return: final result of applying the op to the input.
+    """
+    # run the operator's function for all the inputs.
+    results = []
+    for i, inp in enumerate(gathered_inputs):
+        if op_func == torch.nn.functional.embedding_bag:
+            result = op_func(
+                inp,
+                local_shard,
+                offsets=gathered_offsets[i] if gathered_offsets is not None else None,
+                mode=mode,
+                per_sample_weights=gathered_per_sample_weights[i]
+                if gathered_per_sample_weights is not None
+                else None,
+                padding_idx=padding_idx,
+            )
+        elif op_func == torch.nn.functional.embedding:
+            result = op_func(
+                inp,
+                local_shard,
+                padding_idx=padding_idx,
+            )
+        else:
+            result = op_func(inp, local_shard)
+        results.append(torch.transpose(result, 0, col_dim))
+
+    # Distribute results to each rank with col rearrangement.
+    output = _result_distribute_with_col_rearrange(
+        results, input, world_size, weight, pg
+    )
+
+    # transpose the output and return result.
+    return torch.transpose(output, 0, col_dim)
+
+
+def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg):
+    """
+    For col-wise sharding of weight, we need to distribute
+    results to each rank. We do them in this function.
+    Note that, if the index in the Sharding Spec is not equal to
+    the rank number, we need to do the rearrangement based on the
+    order given by the Sharding Spec (placement).
+
+    Args:
+        results: results from ops applied to inputs from all ranks.
+            We need to distribute them back to their original ranks.
+        input: tensor to be applied op to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        pg: process group.
+
+    Return: column rearranged result.
+    """
+    # Process results and outputs for all2all.
+    sharding_dim = weight._sharding_spec.dim
+    sharding_dim_size = weight.size(sharding_dim)
+    dims = list(results[0].size())
+    dims[0] = sharding_dim_size
+    combined_results = torch.cat(results)
+    output = torch.empty(
+        *dims, device=combined_results.device, dtype=combined_results.dtype
+    )
+
+    # Compute output splits
+    split_size = get_split_size(sharding_dim_size, world_size)
+    output_split_sizes = [0] * world_size
+    for idx, placement in enumerate(weight._sharding_spec.placements):
+        output_split_sizes[placement.rank()] = get_chunked_dim_size(
+            sharding_dim_size, split_size, idx
+        )
+
+    # distribute the outputs using all2all.
+    output = all_to_all_single(
+        output, combined_results, output_split_sizes=output_split_sizes, group=pg
+    )
+
+    # Check if we need to rearrange columns appropriately for output.
+    rearrange_columns = any(
+        idx != placement.rank()
+        for idx, placement in enumerate(weight._sharding_spec.placements)
+    )
+    if not rearrange_columns:
+        return output
+
+    indices = []
+    for placement in weight._sharding_spec.placements:
+        dim_size = output_split_sizes[placement.rank()]
+        start = sum(
+            [
+                split_size if i < placement.rank() else 0
+                for i, split_size in enumerate(output_split_sizes)
+            ]
+        )
+        indices += list(range(start, start + dim_size))
+
+    return output.index_select(0, torch.tensor(indices, device=output.device))
+
+
+def _handle_max_norm_col_wise(
+    max_norm,
+    norm_type,
+    local_shard,
+    input,
+    world_size,
+    gathered_inputs,
+    pg,
+):
+    """
+    For col-wise sharding of weight, we need to aggregate the
+    norm across all ranks before we can perform the proper re-norm.
+    Note that the max_norm logic is applied only to the embedding
+    indices that are looked up, not to the whole shard.
+
+    Args:
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        local_shard: col-wise shared local weight used for lookup.
+        input: tensor to be applied op to.
+        world_size: number of ranks.
+        gathered_inputs: list of inputs from all ranks.
+        pg: process group.
+
+    Return:
+        local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger
+            than it.
+
+    """
+    norm_type = norm_type if norm_type is not None else 2.0
+    unique_inp = torch.unique(torch.cat(gathered_inputs))
+    local_shard_sum = torch.sum(
+        torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
+    )
+    # For col-wise sharding, we need to first aggregate the powered sum
+    # from each rank first and then calculate the norm.
+    local_shard_sum = all_reduce(local_shard_sum, group=pg)
+    local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
+    max_norm_tensor = torch.full(
+        (local_shard.size(0),),
+        float("inf"),
+        dtype=local_shard.dtype,
+        device=input.device,
+    )
+    max_norm_tensor[unique_inp] = max_norm
+    local_shard_t = local_shard.t().contiguous()
+    normalized_tensor = torch.where(
+        local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
+    )
+    # Make sure divisor is not zero.
+    local_shard_norm[local_shard_norm == 0.0] = 1.0
+    local_shard_norm_renormed = (
+        torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
+        .t()
+        .contiguous()
+    )
+    return local_shard_norm_renormed
+
+
+def _all_gather_base_input(input, pg):
+    """
+    Use _all_gather_base to get a concatenated input from each rank.
+
+    Args:
+        input: tensor to be applied op on.
+        pg: process group.
+
+    Returns:
+        gathered_inputs: input gathered from each rank and concat by dim 0.
+    """
+    # allgather the inputs first.
+    gather_inp_size = list(input.size())
+    gather_inp_size[0] = input.size(0) * dist.get_world_size(pg)
+    gather_inp = torch.empty(gather_inp_size, device=input.device, dtype=input.dtype)
+    return _all_gather_base(gather_inp, input, group=pg)
+
+
+def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank):
+    """
+    Mask the input for embedding look-up for IDs which are not stored
+    on the current rank. This function also adjusts the ``padding_idx``
+    so that it is only used on the rank where the corresponding row is
+    stored.
+
+    Note that, with the ``max_norm`` flag on, only the weights of rows being
+    looked up will be re-normed. So we need an extra row for masked IDs
+    so that they do not affect the final result or ``max_norm``.
+
+    Args:
+        gather_inp: tensor to be applied op on gathered from all ranks.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        weight: weight tensor of Embedding look-up table.
+        world_size: number of ranks.
+        rank: rank of the current process.
+
+    Returns:
+        lookup_input: Tensor of masked input.
+        padding_idx: adjusted padding_idx.
+        padding_row: The extra row we used during lookup so that
+            looking up does not affect ``max_norm``.
+    """
+    (start_pos, chunk_size) = get_chunk_sharding_params(
+        weight.size(0), world_size, weight._sharding_spec, rank
+    )
+    mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size)
+    lookup_input = gather_inp.clone() - start_pos
+    lookup_input[mask] = chunk_size
+    if (
+        padding_idx is not None
+        and padding_idx >= start_pos
+        and padding_idx < (start_pos + chunk_size)
+    ):
+        padding_idx = padding_idx - start_pos
+    else:
+        padding_idx = None
+
+    # When max_norm is set, it will only re-norm the row being looked up.
+    padding_row = torch.zeros(
+        1, weight.size(1), device=gather_inp.device, dtype=weight.dtype
+    )
+    return lookup_input, padding_idx, padding_row
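The core of ``_handle_row_wise_mask`` is remapping the gathered IDs into the local shard's index space: IDs owned by the rank become ``id - start_pos``, and everything else is pointed at the extra padding row at index ``chunk_size``. Below is a hedged, single-tensor sketch of that remapping; the function name is illustrative and not part of the module above.

# Sketch of the row-wise masking idea: out-of-range IDs are remapped to an
# extra index (chunk_size) that later points at an appended all-zero row.
import torch

def mask_ids_for_rank(ids: torch.Tensor, start_pos: int, chunk_size: int) -> torch.Tensor:
    mask = (ids < start_pos) | (ids >= start_pos + chunk_size)
    local_ids = ids - start_pos
    local_ids[mask] = chunk_size
    return local_ids

ids = torch.tensor([6, 5, 2, 9, 0, 3])
# Rank 1 of a 10-row table split across 4 ranks owns rows 3..5 (chunk_size 3).
print(mask_ids_for_rank(ids, start_pos=3, chunk_size=3))
# tensor([3, 2, 3, 3, 3, 0])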
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b2bca833975d425b4189398642ab2bf75a33b93
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
@@ -0,0 +1,293 @@
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import all_gather, reduce_scatter
+
+from ._common import (
+    _all_gather_base_input,
+    _handle_col_wise_sharding_base,
+    _handle_max_norm_col_wise,
+    _handle_row_wise_mask,
+)
+
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding)
+def sharded_embedding(types, args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
+    This method computes a sharded embedding lookup and has the following limitations:
+
+    1. Supports only sharding of ``weight``.
+    2. Supports only ``ChunkShardingSpec``.
+    3. Supports only a single local shard per rank.
+    4. Supports all other kwargs except for scale_grad_by_freq, sparse, etc.
+
+    Based on the dimension that the weight is sharded on, there are two
+    algorithms:
+
+    ROWWISE SHARDING
+    ================
+    For row-wise sharding the weight is sharded on dimension 0.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims of the input are (4 x 6) and W is (10 x 17), and W is sharded across
+    4 GPUs creating 3 shards of (3 x 17) and 1 shard of (1 x 17).
+    The algorithm is as follows:
+
+    1. First the input is all-gathered to all ranks, since this is SPMD and
+       the input is actually sharded across all ranks. The inputs then become
+       four (4 x 6) tensors on each rank. For example if the given input is
+       tensor([[6, 5, 2, 9, 6, 3],
+               [3, 1, 2, 4, 7, 6],
+               [4, 0, 4, 9, 8, 9],
+               [8, 6, 6, 4, 6, 1]])
+       on rank 0.
+       Then on every rank, we will have this tensor.
+       If input itself is already replicated, no all-gather will be done.
+    2. Next, we mask the IDs which are not stored on that rank.
+       For example on rank 0, we store IDs [0, 1, 2]. We only keep the IDs
+       inside that set of numbers. The rest of them will be masked to an extra row.
+       The masked matrix will be used for the embedding look-up and looks like:
+       tensor([[4, 4, 2, 4, 4, 4],
+               [4, 1, 2, 4, 4, 4],
+               [4, 0, 4, 4, 4, 4],
+               [4, 4, 4, 4, 4, 1]])
+       The reason for having an extra row (number 4 in the example) is that
+       when max_norm is specified only weights which have been looked up will
+       be re-normed, so masking IDs whose embeddings are not stored on the
+       current rank to an extra row ensures max_norm still works as expected.
+    3. If max_norm is specified, the extra row guarantees that the masked IDs will
+       not affect the behavior of the weight re-norm.
+
+    COLWISE SHARDING
+    ================
+    For col-wise sharding the weight is sharded on dimension 1.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across
+    4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2).
+    The algorithm is as follows:
+
+    1. First the input is broadcasted to all ranks, since this is SPMD we
+       actually do an all_gather for all the inputs resulting in 4 (4 x 6)
+       inputs on each rank.
+    2. Next we perform a local embedding lookup by applying each
+       input (4 x 6) to the local shard (16 x 5) ((16 x 2) for the last).
+       This results in 4 (5 x 6 x 4) ((2 x 6 x 4) for the last) matrices
+       on each rank. We transpose dim 0 and dim 2.
+    3. Next, we concat these 4 matrices and perform an all2all to share the
+       appropriate (5 x 6 x 4) or (2 x 6 x 4) matrices to each rank.
+    4. Now, each rank receives a (17 x 6 x 4) matrix which is basically the
+       size of the result we need.
+    5. If placements are not in order, the appropriate rearrangement of columns
+       is done for the (17 x 6 x 4) matrix, and finally we transpose
+       dim 0 and dim 2 again.
+    6. If max_norm is specified, we manually sum up the norm and renorm. Because
+       the renorm must be in place, we need to override the local_shard to mimic
+       this behavior.
+    """
+    # Validate input params
+    _validate_embedding_param(args, kwargs)
+
+    input = args[0]
+    weight = args[1]
+    max_norm = kwargs.get("max_norm")
+    norm_type = kwargs.get("norm_type")
+    padding_idx = kwargs.get("padding_idx")
+
+    local_shard = weight.local_tensor().contiguous()
+    sharding_dim = weight._sharding_spec.dim
+    world_size = dist.get_world_size(pg)
+    rank = dist.get_rank(pg)
+
+    if sharding_dim == 1:
+        output, local_shard = _handle_col_wise_sharding(
+            input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg
+        )
+        weight.local_shards()[0].tensor = local_shard
+        return output
+    elif sharding_dim == 0:
+        return _handle_row_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            max_norm,
+            norm_type,
+            padding_idx,
+            rank,
+            pg,
+        )
+    else:
+        raise RuntimeError(
+            f"nn.Embedding weight sharded on dim {sharding_dim} not supported!"
+        )
+
+
+def _validate_embedding_param(args, kwargs):
+    """
+    Validate input params of sharded embedding op.
+
+    Args:
+        input: list of ID used for lookup.
+        weight: sharded weight tensor.
+        kwargs: same as normal Embedding.
+
+    Return: None.
+    """
+
+    input = args[0]
+    weight = args[1]
+    max_norm = kwargs.get("max_norm")
+    scale_grad_by_freq = kwargs.get("scale_grad_by_freq")
+    sparse = kwargs.get("sparse")
+
+    # Validate types
+    if not isinstance(input, torch.Tensor):
+        raise TypeError("input need to be torch.Tensor")
+    if not isinstance(weight, ShardedTensor):
+        raise TypeError("weight needs to be ShardedTensor")
+    weight_size = weight.size()
+    if len(weight_size) != 2:
+        raise ValueError("Weight needs to have exactly 2 dims")
+    if int(torch.min(input).item()) < 0:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.min(input).item()),
+            weight_size[1],
+        )
+    if int(torch.max(input).item()) >= weight_size[0]:
+        raise ValueError(
+            "Index out of range in Input %d %d",
+            int(torch.max(input).item()),
+            weight_size[1],
+        )
+    if scale_grad_by_freq:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!'
+        )
+    if sparse:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "sparse" not supported!'
+        )
+    if max_norm and max_norm <= 0.0:
+        raise ValueError('"max_norm" must be larger than zero!')
+
+    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
+        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
+    if len(weight.local_shards()) != 1:
+        raise ValueError("Only one local shard supported!")
+
+
+def _handle_col_wise_sharding(
+    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg
+):
+    """
+    Entry-point function to handle the logic of col-wise sharding of weight
+    for embedding. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise shared local weight used for lookup.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+        pg: process group.
+
+    Returns: final result of lookup.
+    """
+    # allgather the inputs first for non Replicated Tensor.
+    gathered_inputs = all_gather(input, group=pg)
+
+    if max_norm is not None:
+        # max_norm changes the weight in-place
+        local_shard = _handle_max_norm_col_wise(
+            max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg
+        )
+
+    output = _handle_col_wise_sharding_base(
+        torch.nn.functional.embedding,
+        len(input.size()),
+        input,
+        world_size,
+        weight,
+        local_shard,
+        pg,
+        gathered_inputs,
+        padding_idx=padding_idx,
+    )
+    return (output, local_shard)
+
+
+def _handle_row_wise_sharding(
+    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg
+):
+    """
+    Entry-point function to handle the logic of row-wise sharding of weight
+    for embedding. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: row-wise shared local weight used for lookup.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+        rank: rank of the current process.
+        pg: process group.
+
+    Returns: final result of lookup.
+    """
+    # allgather the inputs first for non Replicated Tensor.
+    gather_inp = _all_gather_base_input(input, pg)
+
+    # Mask the input according to sharding spec.
+    lookup_input, padding_idx, padding_row = _handle_row_wise_mask(
+        gather_inp, padding_idx, weight, world_size, rank
+    )
+
+    # When the input is a large tensor, the value of weight is changed.
+    # This is a workaround for now. GH issue: #81717
+    if max_norm is not None:
+        torch.nn.functional.embedding(
+            torch.unique(lookup_input)[:-1],
+            local_shard,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+        )
+        max_norm = None
+
+    local_input_embeddings = torch.nn.functional.embedding(
+        lookup_input,
+        torch.cat([local_shard, padding_row]),
+        padding_idx=padding_idx,
+        max_norm=max_norm,
+        norm_type=norm_type,
+    )
+
+    # TODO: Make the result a PartialTensor.
+    local_shards = local_input_embeddings.chunk(pg.size())
+    return reduce_scatter(
+        torch.empty_like(local_shards[0]),
+        list(local_shards),
+        group=pg,
+    )
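Below is a single-process sketch of the row-wise path in ``sharded_embedding``: each simulated rank looks up only the rows it owns (masked IDs hit an appended all-zero row), and summing the partial results reproduces the unsharded lookup. This is a simplified illustration under local assumptions: it ignores process groups, ``reduce_scatter``, ``max_norm`` and ``padding_idx``.

# Single-process simulation of row-wise sharded embedding lookup.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_rows, dim, world_size = 10, 17, 4
weight = torch.randn(num_rows, dim)          # the "global" embedding table
ids = torch.randint(0, num_rows, (4, 6))     # lookup IDs

split_size = (num_rows + world_size - 1) // world_size
partial_sum = torch.zeros(*ids.shape, dim)
for rank in range(world_size):
    start = split_size * rank
    chunk = max(min(num_rows, split_size * (rank + 1)) - start, 0)
    local_shard = weight[start:start + chunk]
    # Remap IDs into the local index space; masked IDs point at the pad row.
    mask = (ids < start) | (ids >= start + chunk)
    local_ids = ids - start
    local_ids[mask] = chunk
    padding_row = torch.zeros(1, dim)
    partial_sum += F.embedding(local_ids, torch.cat([local_shard, padding_row]))

# Each ID is owned by exactly one rank, so the sum equals the full lookup.
assert torch.allclose(partial_sum, F.embedding(ids, weight))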
diff --git a/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
new file mode 100644
index 0000000000000000000000000000000000000000..b95f5334750e74348e2f0bfd52359d64e2ef899d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
@@ -0,0 +1,476 @@
+
+from typing import cast, List
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import ReduceOp
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed.nn.functional import all_gather, reduce_scatter
+
+from ._common import (
+    _all_gather_base_input,
+    _handle_col_wise_sharding_base,
+    _handle_max_norm_col_wise,
+    _handle_row_wise_mask,
+)
+
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag)
+def sharded_embedding_bag(types, args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
+    This method computes a sharded embedding bag aggregation and has the following limitations:
+
+    1. Supports only sharding of ``weight``.
+    2. Supports only ``ChunkShardingSpec``.
+    3. Supports only a single local shard per rank.
+    4. Supports all other kwargs except for scale_grad_by_freq, sparse, etc.
+
+    Based on the dimension that the weight is sharded on, there are two
+    algorithms:
+
+    ROWWISE SHARDING
+    ================
+    For row-wise sharding the weight is sharded on dimension 0.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims of the input are (4 x 6) and W is (16 x 17), and W is sharded across
+    4 GPUs creating 4 shards of (4 x 17).
+    The algorithm is as follows:
+
+    1. First the input is all-gathered to all ranks, since this is SPMD and
+       the input is actually sharded across all ranks. The inputs then become
+       four (4 x 6) tensors on each rank. For example if the given input is
+       tensor([[6, 5, 2, 9, 6, 3],
+               [3, 1, 2, 4, 7, 6],
+               [4, 0, 4, 9, 8, 9],
+               [8, 6, 6, 4, 6, 1]])
+       on rank 0.
+       Then on every rank, we will have this tensor.
+       If input itself is already replicated, no all-gather will be done.
+    2. Next, we mask the IDs which are not stored on that rank.
+       For example on rank 0, we store IDs [0, 1, 2]. We only keep the IDs
+       inside that set of numbers. The rest of them will be masked to an extra row.
+       The masked matrix will be used for embedding look up and is like:
+       tensor([[4, 4, 2, 4, 4, 4],
+               [4, 1, 2, 4, 4, 4],
+               [4, 0, 4, 4, 4, 4],
+               [4, 4, 4, 4, 4, 1]])
+    3. If ``max_norm`` is specified, the extra row guarantees that the masked IDs will
+       not affect the behavior of the weight re-norm.
+    4. The example above only shows one rank, and each rank does a very similar thing.
+       For "Mean" mode we need to divide by either the column size (2D) or the interval
+       length defined by the offsets (excluding the row specified in ``padding_idx``).
+       We also need to mask the nonexistent rows to negative infinity so that negative
+       values do not get wiped out in the "Max" mode.
+
+    COLWISE SHARDING
+    ================
+    For col-wise sharding the weight is sharded on dimension 1.
+
+    The overall algorithm can be best explained with an example. Let's assume
+    the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across
+    4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2).
+    The algorithm is as follows:
+
+    1. First the input is broadcast to all ranks: since this is SPMD, we
+       actually do an all_gather of all the inputs, resulting in 4 (4 x 6)
+       inputs on each rank.
+    2. Next we perform a local embedding bag operation under the given mode by
+       applying each input (4 x 6) to the local shard (16 x 5) ((16 x 2) for the last).
+       This results in 4 (5 x 4) ((2 x 4) for the last) matrices on each rank.
+       We transpose the aggregation result.
+    3. Next, we concatenate these 4 matrices and perform an all2all to share the
+       appropriate (5 x 4) or (2 x 4) matrices with each rank.
+    4. Now, each rank receives a (17 x 4) matrix, which is exactly the
+       size of the result we need.
+    5. If the placements are not in order, the columns of the (17 x 4) matrix are
+       rearranged appropriately, and finally we transpose the output again.
+    6. If ``max_norm`` is specified, we manually sum up the partial norms across
+       shards and renormalize. Because the renorm must happen in place, we need to
+       override the local_shard to mimic this behavior.
+    """
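+    # A hedged usage sketch (illustrative only; ``w``, ``inp`` and ``offsets`` are
+    # assumed names): with ``w`` a ShardedTensor built from a
+    # ``ChunkShardingSpec(dim=0, placements=[...])``, a call like
+    # ``torch.nn.functional.embedding_bag(inp, w, offsets=offsets, mode="sum")``
+    # is routed to this handler through ``__torch_function__`` dispatch, with ``pg``
+    # being the process group backing the sharded weight.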
+    # Validate input params
+    _validate_embedding_bag_param(args, kwargs)
+
+    input = args[0]
+    weight = args[1]
+    offsets = kwargs.get("offsets")
+    per_sample_weights = kwargs.get("per_sample_weights")
+    mode = kwargs.get("mode")
+    max_norm = kwargs.get("max_norm")
+    norm_type = kwargs.get("norm_type")
+    include_last_offset = kwargs.get("include_last_offset")
+    padding_idx = kwargs.get("padding_idx")
+
+    local_shard = weight.local_tensor().contiguous()
+    sharding_dim = weight._sharding_spec.dim
+    world_size = dist.get_world_size(pg)
+    rank = dist.get_rank(pg)
+    if include_last_offset:
+        offsets = offsets[:-1]
+
+    if sharding_dim == 1:
+        output, local_shard = _handle_col_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            offsets,
+            per_sample_weights,
+            mode,
+            max_norm,
+            norm_type,
+            padding_idx,
+            pg,
+        )
+        weight.local_shards()[0].tensor = local_shard
+        return output
+    elif sharding_dim == 0:
+        return _handle_row_wise_sharding(
+            input,
+            world_size,
+            weight,
+            local_shard,
+            offsets,
+            per_sample_weights,
+            mode,
+            max_norm,
+            norm_type,
+            padding_idx,
+            rank,
+            pg,
+        )
+    else:
+        raise RuntimeError(
+            f"nn.EmbeddingBag weight sharded on dim {sharding_dim} not supported!"
+        )
+
+
+def _validate_embedding_bag_param(args, kwargs):
+    """
+    Validate input params of sharded embeddingBag op.
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        weight: sharded weight tensor.
+        kwargs: same as normal EmbeddingBag.
+
+    Return: None.
+    """
+
+    input = args[0]
+    weight = args[1]
+    offsets = kwargs.get("offsets")
+    per_sample_weights = kwargs.get("per_sample_weights")
+    mode = kwargs.get("mode")
+    max_norm = kwargs.get("max_norm")
+    scale_grad_by_freq = kwargs.get("scale_grad_by_freq")
+    sparse = kwargs.get("sparse")
+    include_last_offset = kwargs.get("include_last_offset")
+
+    # Validate types
+    if not isinstance(input, torch.Tensor):
+        raise TypeError("input need to be torch.Tensor")
+    if offsets is not None and not isinstance(offsets, torch.Tensor):
+        raise TypeError("offsets need to be torch.Tensor")
+    if per_sample_weights is not None and not isinstance(
+        per_sample_weights, torch.Tensor
+    ):
+        raise TypeError("per_sample_weights need to be torch.Tensor")
+    if not isinstance(weight, ShardedTensor):
+        raise TypeError("weight needs to be ShardedTensor")
+    if len(input.size()) > 2:
+        raise ValueError("Input more than 2 dims not supported")
+    weight_size = weight.size()
+    if len(weight_size) != 2:
+        raise ValueError("Weight needs to have exactly 2 dims")
+    if int(torch.min(input).item()) < 0:
+        raise ValueError(
+            f"Index out of range in Input {int(torch.min(input).item())} {weight_size[1]}"
+        )
+    if int(torch.max(input).item()) >= weight_size[0]:
+        raise ValueError(
+            f"Index out of range in Input {int(torch.max(input).item())} {weight_size[1]}"
+        )
+    if offsets is not None and len(input.size()) != 1:
+        raise ValueError("Input dimension needs to be exactly 1 dim")
+    if len(input.size()) == 1 and offsets is None:
+        raise ValueError("offsets is required for 1D input")
+    if per_sample_weights is not None and per_sample_weights.size() != input.size():
+        raise ValueError(
+            f"per_sample_weights size {per_sample_weights.size()} not equal to input size {input.size()}"
+        )
+    if mode is None:
+        mode = "mean"
+    if mode not in ["sum", "mean", "max"]:
+        raise ValueError(f"mode '{mode}' is not supported")
+    if scale_grad_by_freq:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!'
+        )
+    if sparse:
+        raise RuntimeError(
+            'nn.Embedding weight sharded with flag on "sparse" not supported!'
+        )
+    if include_last_offset and offsets is None:
+        raise ValueError('offsets is required for flag "include_last_offset"!')
+    if include_last_offset and cast(List[int], offsets)[-1] != input.size(0):
+        raise ValueError(
+            'offsets need to have the input size in the end when the flag "include_last_offset" is on!'
+        )
+
+    if max_norm and max_norm <= 0.0:
+        raise ValueError('"max_norm" must be larger than zero!')
+
+    if not isinstance(weight._sharding_spec, ChunkShardingSpec):
+        raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!")
+    if len(weight.local_shards()) != 1:
+        raise ValueError("Only one local shard supported!")
+
+
+def _handle_col_wise_sharding(
+    input,
+    world_size,
+    weight,
+    local_shard,
+    offsets,
+    per_sample_weights,
+    mode,
+    max_norm,
+    norm_type,
+    padding_idx,
+    pg,
+):
+    """
+    Entry-point function to handle the logic of col-wise sharding of weight
+    for embeddingBag. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding_bag.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise sharded local weight used for lookup.
+        offsets: list of start positions of each bag for 1D input.
+        per_sample_weights: weights for weighted sum mode.
+        mode: aggregation method of each bag.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        pg: process group.
+
+    Return:
+        output: final result of lookup and aggregation.
+        local_shard: col-wise sharded local weight used for lookup.
+            If max_norm is set, this will be the renormed weight.
+    """
+    # allgather the special input of embedding bag first.
+    (
+        gathered_inputs,
+        gathered_per_sample_weights,
+        gathered_offsets,
+    ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
+
+    if max_norm is not None:
+        # max_norm changes the weight in-place
+        local_shard = _handle_max_norm_col_wise(
+            max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg
+        )
+
+    output = _handle_col_wise_sharding_base(
+        torch.nn.functional.embedding_bag,
+        1,
+        input,
+        world_size,
+        weight,
+        local_shard,
+        pg,
+        gathered_inputs,
+        mode=mode,
+        gathered_per_sample_weights=gathered_per_sample_weights,
+        gathered_offsets=gathered_offsets,
+        padding_idx=padding_idx,
+    )
+    return (output, local_shard)
+
+
+def _handle_row_wise_sharding(
+    input,
+    world_size,
+    weight,
+    local_shard,
+    offsets,
+    per_sample_weights,
+    mode,
+    max_norm,
+    norm_type,
+    padding_idx,
+    rank,
+    pg,
+):
+    """
+    Entry-point function to handle the logic of row-wise sharding of weight
+    for embeddingBag. (Detailed explanations of the logic can be found in
+    the comment for sharded_embedding_bag.)
+
+    Args:
+        input: list of ID used for lookup and aggregation.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: row-wise sharded local weight used for lookup.
+        offsets: list of start positions of each bag for 1D input.
+        per_sample_weights: weights for weighted sum mode.
+        mode: aggregation method of each bag.
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+        rank: rank of the current process within the process group.
+        pg: process group.
+
+    Returns:
+        gathered_output: final result of lookup and aggregation.
+    """
+    if input.dim() > 1 and per_sample_weights is None:
+        # all-gather the inputs first for non-replicated tensors.
+        gather_inp = _all_gather_base_input(input, pg)
+    else:
+        (
+            gathered_inputs,
+            gathered_per_sample_weights,
+            gathered_offsets,
+        ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
+        cat_dim = 0 if input.dim() != 1 else -1
+        gather_inp = torch.cat(gathered_inputs, dim=cat_dim)
+        if per_sample_weights is not None:
+            per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim)
+        offset_add = 0 if input.dim() > 1 else input.size(0)
+        if offsets is not None:
+            offsets_list = torch.cat(
+                [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())],
+                dim=cat_dim,
+            )
+
+    # Mask the input according to sharding spec.
+    lookup_input, padding_local, padding_row = _handle_row_wise_mask(
+        gather_inp, padding_idx, weight, world_size, rank
+    )
+    if mode == "max":
+        padding_row[:] = -float("Inf")
+
+    # When the input is a large tensor, the max_norm renorm changes the weight
+    # in-place; as a workaround for now (GH issue: #81717), apply the renorm once
+    # here and skip it in the actual lookup below.
+    if max_norm is not None:
+        torch.nn.functional.embedding_bag(
+            torch.unique(lookup_input)[:-1],
+            local_shard,
+            offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long),
+            mode=mode,
+            per_sample_weights=None,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            padding_idx=padding_local,
+        )
+        max_norm = None
+    result = torch.nn.functional.embedding_bag(
+        lookup_input,
+        torch.cat([local_shard, padding_row]),
+        offsets=offsets_list if offsets is not None else offsets,  # type: ignore[possibly-undefined]
+        mode=mode if mode != "mean" else "sum",
+        per_sample_weights=per_sample_weights,
+        max_norm=max_norm,
+        norm_type=norm_type,
+        padding_idx=padding_local,
+    )
+
+    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
+    # TODO: Make the result a PartialTensor and move the logic below there.
+    local_shards = result.chunk(pg.size())
+    result = reduce_scatter(
+        torch.empty_like(local_shards[0]),
+        list(local_shards),
+        op=op,
+        group=pg,
+    )
+
+    # For "mean" mode, we cannot do the division until the very end because the sum
+    # of the per-rank means is not equal to the mean of the sum (the divisors differ).
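+    # e.g. if one rank contributes [1, 2] and another [3, 4, 5] to the same bag,
+    # mean([1, 2]) + mean([3, 4, 5]) = 1.5 + 4.0 = 5.5, while the true mean is
+    # (1 + 2 + 3 + 4 + 5) / 5 = 3.0; hence we reduce with SUM above and divide by
+    # the per-bag element counts below.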
+    if mode == "mean":
+        if input.dim() > 1:
+            padding_idx = padding_idx if padding_idx is not None else -1
+            split_sizes = torch.sum(
+                torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype
+            )
+        else:
+            split_sizes = torch.cat(
+                (
+                    offsets[1 : offsets.size(0)] - offsets[0:-1],
+                    (input.size(0) - offsets[-1]).unsqueeze(0),
+                ),
+                dim=-1,
+            )
+        return torch.div(result, split_sizes.unsqueeze(1))
+
+    # Return the appropriate local result.
+    return result
+
+
+def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg):
+    """
+    When we need to gather the input and all other parameters of the
+    embedding_bag op, we stack them together so that the ``all_gather``
+    collective communication is performed just once.
+
+    Note that since offsets does not have the same size as input and
+    is always smaller than input, we resize it during the communication.
+
+    Args:
+        input: tensor to be applied op on.
+        per_sample_weights: weights for weighted sum mode.
+        offsets: when input is 1D. offsets determines the starting
+            index position of each bag (sequence) in input.
+        pg: process group.
+
+    Returns:
+        gathered_inputs: list of input tensor gathered from each rank.
+        gathered_per_sample_weights: list of per_sample_weights from each rank.
+        gathered_offsets: list of offsets from each rank.
+    """
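+    # Illustrative layout (assuming a 1D input with both per_sample_weights and
+    # offsets present): the three tensors are stacked into a single (3, N) tensor so
+    # that one all_gather moves input (row 0), per_sample_weights (row 1) and the
+    # resized offsets (row 2) together; the rows are split back apart below.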
+    input_to_gather = [input]
+    if per_sample_weights is not None:
+        input_to_gather.append(per_sample_weights)
+    if offsets is not None:
+        input_to_gather.append(offsets.clone().resize_(input.size()))
+    gathered_inputs = all_gather(torch.stack(input_to_gather), group=pg)
+
+    gathered_per_sample_weights = None
+    if per_sample_weights is not None:
+        gathered_per_sample_weights = [t[1] for t in gathered_inputs]
+    gathered_offsets = None
+    if offsets is not None:
+        idx = 2 if per_sample_weights is not None else 1
+        gathered_offsets = [
+            t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs
+        ]
+    gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs]
+    return gathered_inputs, gathered_per_sample_weights, gathered_offsets
diff --git a/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30f121173c4c9c80ff09903da42b010f9ade2855
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py
@@ -0,0 +1,12 @@
+# Keep old package for BC purposes, this file should be removed once
+# everything moves to the `torch.distributed._shard` package.
+import sys
+import torch
+import warnings
+
+from torch.distributed._shard.sharded_tensor import *  # noqa: F403
+warnings.warn(
+    "torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead",
+    DeprecationWarning
+)
+sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor
diff --git a/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dde1b44ada39079d3e8161c1fc4a415c20b414b3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1890c6b6f0930426a543a5d19cf8e02710b384
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py
@@ -0,0 +1,14 @@
+# Keep old package for BC purposes, this file should be removed once
+# everything moves to the `torch.distributed._shard` package.
+import sys
+import torch
+import warnings
+
+from torch.distributed._shard.sharding_spec import *  # noqa: F403
+warnings.warn(
+    "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead",
+    DeprecationWarning
+)
+
+import torch.distributed._shard.sharding_spec as _sharding_spec
+sys.modules['torch.distributed._sharding_spec'] = _sharding_spec
diff --git a/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8febcfddb3aab538e86322657d3c8a84cfa88a97
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..475f93d71f7611d7e3d380ff85b8ce086762ecba
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d65d75533dea9100b698005c491561aac3adaf4f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/batch_dim_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/batch_dim_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..163218563efb422d983239a80d165b945e8e637e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/batch_dim_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/comm_tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/comm_tensor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6f038c243e6a3aee558e82dc72017d7c307ef24
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/comm_tensor.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..267dafadde99bbfc1c499fdac25703b1c8f6c35d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/config.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/data_parallel.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/data_parallel.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9604372e55ee7e061916557636aa03409ded11ed
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/data_parallel.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/distribute.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/distribute.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d148a32c448c7046c1fec54a5daf1e94ddb066cd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/distribute.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/experimental_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/experimental_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a96b738f51bbdf497bd47356304bf6f8728bdc11
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/experimental_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/gm_transformation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/gm_transformation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..831bf4aa72395d114df3d10e6ef027b033c8bc99
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/gm_transformation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_optimization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_optimization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a6c4ca0c35b5111ba9850745254bb136c5f1ebf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_optimization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ac2c4a1a0884b991a268f277e116ecbb2de534e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/graph_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/iter_graph_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/iter_graph_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36697f4b787571d879ac84fe548c792c01780698
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/iter_graph_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/log_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/log_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caf9751d168b571530219ad6e59bb9b2cf012a52
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/log_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/parallel_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/parallel_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..484a8d38eae67e36a20429070e10c366d100624e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/parallel_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/partial_lower.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/partial_lower.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b56f7e12e69ea40fba9a83f52068fe7aca8b2dad
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_spmd/__pycache__/partial_lower.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/api.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..77953fee4310e2a26c0afc86bd06da93739dd49e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/api.py
@@ -0,0 +1,575 @@
+from abc import ABC, abstractmethod
+from contextlib import contextmanager, nullcontext
+from copy import copy
+from dataclasses import dataclass
+from functools import partial, wraps
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
+
+from functorch import make_fx
+
+import torch
+import torch.distributed as dist
+
+# We need to import _functional_collectives to trigger op registration
+import torch.distributed._functional_collectives
+import torch.nn as nn
+import torch.utils._pytree as pytree
+
+from torch import fx
+from torch._decomp.decompositions import native_layer_norm_backward
+
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.distributed._spmd.data_parallel import gradients_tagging
+from torch.distributed._spmd.parallel_mode import (
+    DataParallel,
+    DTensorExpandMode,
+    ParallelMode,
+)
+from torch.distributed._tensor import Placement
+from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
+from torch.nn.utils import stateless
+from torch.nn.utils._named_member_accessor import NamedMemberAccessor
+
+
+class Override(ABC):
+    r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
+
+    This is useful when part of the model is not traceable or when you prefer
+    not to trace it for any reason. More specifically, users can implement
+    :meth:`torch.distributed._spmd.Override.replacement` to replace an original
+    submodule with the returned new submodule. The new submodule contains
+    operations that users prefer to have traced, which may simply be a dummy
+    placeholder operator. After tracing, users can implement
+    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
+    graph, where the dummy placeholder operator serves as an anchor to insert
+    new sub-graphs.
+    """
+
+    @abstractmethod
+    def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
+        r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
+        argument in the model.
+
+        This helps if ``orig_submodule`` is not traceable or should not be traced.
+
+        Args:
+            fqn (str): fully qualified name of the submodule.
+            orig_submodule (class:`nn.Module`): original submodule instance to replace.
+
+        Returns:
+            A new :class:`nn.Module` instance to replace the original one.
+
+        """
+        pass
+
+    @abstractmethod
+    def transform(
+        self,
+        gm: fx.GraphModule,
+        flat_state: List[torch.Tensor],
+    ) -> fx.GraphModule:
+        r"""
+        Given a DTensor-expanded graph and sharding schema for every node,
+        conduct additional transformation for the sub-graph from the :class:`nn.Module`
+        returned by :meth:`torch.distributed._spmd.Override.replacement` if
+        necessary.
+
+        Args:
+            gm (:class:`fx.Graph`): a DTensor-expanded graph.
+            flat_state (List[:class:`Tensor`]): a reference to the list of
+                flattened state. The elements in ``flat_state`` map to the first
+                ``len(flat_state)`` placeholders in the graph. The transformation
+                can add state to or remove state from ``flat_state`` as long as
+                it keeps ``flat_state`` and the placeholders consistent.
+
+        Returns:
+            The :class:`fx.Graph` after transformation.
+
+        """
+        pass
+
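+# A hedged sketch of a concrete Override (``MyMoEOverride`` and ``DummySubmodule`` are
+# illustrative assumptions, not part of this module):
+#
+#   class MyMoEOverride(Override):
+#       def replacement(self, fqn, orig_submodule):
+#           # return a traceable stand-in that emits a dummy placeholder op
+#           return DummySubmodule(orig_submodule)
+#
+#       def transform(self, gm, flat_state):
+#           # locate the dummy placeholder nodes in ``gm`` and splice in the real
+#           # sub-graph, keeping ``flat_state`` and the placeholders consistent
+#           return gm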
+
+class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
+    # pyre-ignore[3]
+    def process_inputs(self, *args: Any) -> Any:
+        return args
+
+    # pyre-ignore[2, 3]
+    def gen_fn_def(self, free_vars, maybe_return_annotation):
+        return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation)
+
+
+def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Move the responsibility of flattening the input arguments from the graph module to the caller.
+
+    Example:
+
+        output = gm(my_struct)
+
+        gm = _to_caller_flattened_graph_module(gm)
+
+        output = gm(*pytree.tree_flatten(my_struct)[0])
+
+    """
+    # pyre-ignore[16]
+    gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
+        pytree_info=_PyTreeInfo(
+            # pyre-ignore[6]
+            orig_args=None,  # type: ignore[arg-type]
+            # pyre-ignore[6]
+            in_spec=None,  # type: ignore[arg-type]
+            # pyre-ignore[16]
+            out_spec=gm._graph._codegen.pytree_info.out_spec,
+        )
+    )
+    gm.recompile()
+    return gm
+
+
+# Use a dtensor expand mode for now to preserve the old behavior
+# and avoid breaking existing code
+dtensor_expand_mode = DTensorExpandMode()
+
+
+def _override_placements(t: torch.Tensor, placements: List[Placement]):
+    global dtensor_expand_mode
+    dtensor_expand_mode._placements_override[id(t)] = placements
+
+
+@contextmanager
+def _rematerialize_optimizer(
+    opt: torch.optim.Optimizer,
+    named_states: Dict[str, Any],
+    params: Dict[str, nn.Parameter],
+):
+    assert opt is not None
+
+    # update opt.state with proxy tensors
+    orig_states = copy(opt.state)
+    for n in named_states:
+        # named_states is keyed by parameter name (str), but the optimizer keys its state by Parameter
+        opt.state[params[n]] = named_states[n]  # type: ignore[index]
+
+    # FIXME: support multiple parameter groups
+    param_group = opt.param_groups[0]
+    orig_params = param_group["params"]
+    param_group["params"] = params.values()
+
+    try:
+        yield
+    finally:
+        param_group["params"] = orig_params
+        opt.state = orig_states
+
+
+aten = torch.ops.aten  # pyre-ignore
+
+
+@contextmanager
+def _enable_compile():
+    # The return value of torch._utils.is_compiling changes optimizer behavior.
+    # We need that function to return True to include optimizer in the graph.
+    # See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
+    def f_true():
+        return True
+
+    orig_is_compiling_code = torch._utils.is_compiling.__code__
+    torch._utils.is_compiling.__code__ = f_true.__code__
+    try:
+        yield
+    finally:
+        torch._utils.is_compiling.__code__ = orig_is_compiling_code
+
+
+def _foreach_add_decomp(self, other, alpha=1):
+    self_updated = aten._foreach_add.List(self, other, alpha=alpha)
+    for s, s_u in zip(self, self_updated):
+        s.copy_(s_u)
+
+
+def _foreach_unaop_decomp(op, self):
+    self_updated = op(self)
+    for s, s_u in zip(self, self_updated):
+        s.copy_(s_u)
+
+
+def _foreach_binop_list_decomp(op, self, other):
+    self_updated = op(self, other)
+    for s, s_u in zip(self, self_updated):
+        s.copy_(s_u)
+
+
+def _foreach_binop_scalar_decomp(op, self, scalar=1):
+    self_updated = op(self, scalar)
+    for s, s_u in zip(self, self_updated):
+        s.copy_(s_u)
+
+
+def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
+    self_updated = op(self, tensor1, tensor2, scalar)
+    for s, s_u in zip(self, self_updated):
+        s.copy_(s_u)
+
+
+def _fused_adam_decomp(
+    self,
+    grads,
+    exp_avgs,
+    exp_avg_sqs,
+    max_exp_avg_sqs,
+    state_steps,
+    *,
+    lr=1,
+    beta1=1,
+    beta2=1,
+    weight_decay=1,
+    eps=1,
+    amsgrad=True,
+    maximize=True,
+    grad_scale=None,
+    found_inf=None,
+):
+    orig_tuple = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
+    updated_tuple = aten._fused_adam.default(
+        self,
+        grads,
+        exp_avgs,
+        exp_avg_sqs,
+        max_exp_avg_sqs,
+        state_steps,
+        lr=lr,
+        beta1=beta1,
+        beta2=beta2,
+        weight_decay=weight_decay,
+        eps=eps,
+        amsgrad=amsgrad,
+        maximize=maximize,
+        grad_scale=grad_scale,
+        found_inf=found_inf,
+    )
+
+    for idx, (orig, updated) in enumerate(zip(orig_tuple, updated_tuple)):
+        if idx == 1:
+            # skip gradient copying as we don't need to copy gradients back
+            continue
+        for o, u in zip(orig, updated):
+            o.copy_(u)
+
+
+SPMD_DECOMP_TABLE = {
+    aten._foreach_add_.List: _foreach_add_decomp,
+    aten._foreach_add_.Scalar: partial(
+        _foreach_binop_scalar_decomp, aten._foreach_add.Scalar
+    ),
+    aten._foreach_addcdiv_.Scalar: partial(
+        _foreach_addcop_scalar_decomp, aten._foreach_addcdiv.Scalar
+    ),
+    aten._foreach_addcmul_.Scalar: partial(
+        _foreach_addcop_scalar_decomp, aten._foreach_addcmul.Scalar
+    ),
+    aten._foreach_div_.List: partial(
+        _foreach_binop_list_decomp, aten._foreach_div.List
+    ),
+    aten._foreach_mul_.Scalar: partial(
+        _foreach_binop_scalar_decomp, aten._foreach_mul.Scalar
+    ),
+    aten._foreach_div_.Scalar: partial(
+        _foreach_binop_scalar_decomp, aten._foreach_div.Scalar
+    ),
+    aten._foreach_neg_.default: partial(
+        _foreach_unaop_decomp, aten._foreach_neg.default
+    ),
+    aten._foreach_reciprocal_.default: partial(
+        _foreach_unaop_decomp, aten._foreach_reciprocal.default
+    ),
+    aten._foreach_sqrt_.default: partial(
+        _foreach_unaop_decomp, aten._foreach_sqrt.default
+    ),
+    aten._foreach_sub_.Scalar: partial(
+        _foreach_binop_scalar_decomp, aten._foreach_sub.Scalar
+    ),
+    aten._fused_adam_.default: _fused_adam_decomp,
+    aten.native_layer_norm_backward.default: native_layer_norm_backward,
+}
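+# Hedged note: the decompositions above rewrite in-place foreach/fused optimizer ops
+# into their out-of-place counterparts followed by explicit ``copy_`` calls, so the
+# traced graph avoids in-place foreach ops that functionalization cannot yet handle
+# (see the comment in ``_compile`` below).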
+
+
+DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
+    torch.ops.c10d_functional.all_reduce.default,
+    torch.ops.c10d_functional.wait_tensor.default,
+}
+
+
+def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:
+    args_to_node: Dict[Tuple[Any, ...], fx.Node] = {}
+
+    for node in gm.graph.nodes:
+        # replace all args with the results from the first unique comm op
+        args = pytree.arg_tree_leaves(*node.args)
+
+        if node.target in DEDUP_TARGETS:
+            args_key = (node.target, *args)
+            unique_node = args_to_node.get(args_key, None)
+            if unique_node is None:
+                # first time seeing this combination, remember it
+                args_to_node[args_key] = node
+            else:
+                # the current node is a duplicate, replace it
+                node.replace_all_uses_with(unique_node)
+                gm.graph.erase_node(node)
+
+    gm.recompile()
+
+    return gm
+
+
+@dataclass
+class _CompiledResult:
+    gm: fx.GraphModule
+    mod: nn.Module
+    opt: Optional[torch.optim.Optimizer]
+    flat_state: List[torch.Tensor]
+
+
+def _compile(
+    func: Callable,
+    module_override: Optional[List[Override]],
+    parallel_mode: ParallelMode,
+    *args: Any,
+    **kwargs: Any,
+) -> _CompiledResult:
+    # 1. Extract nn.Module and Optimizer from args and kwargs
+    # FIXME(@mrshenli): support multiple nn.Module instances
+    # FIXME(@mrshenli): support multiple Optimizer instances
+    # FIXME(@mrshenli): need to broadcast model to sync parameters
+    mod, opt = None, None
+    for arg in pytree.arg_tree_leaves(*args, **kwargs):
+        if isinstance(arg, nn.Module):
+            assert mod is None, "Only support single nn.Module for now"
+            mod = arg
+        if isinstance(arg, torch.optim.Optimizer):
+            assert opt is None, "Only support single Optimizer for now"
+            opt = arg
+
+    assert mod is not None, "Couldn't find nn.Module instances from the arguments."
+
+    # 2. Override target submodules (e.g., MoE) with dummy replacements
+    if module_override:
+        accessor = NamedMemberAccessor(mod)
+
+        def swap(fqn_prefix: str, module: torch.nn.Module) -> None:
+            for override in module_override:  # type: ignore[union-attr]
+                for name, child in module.named_children():
+                    if len(name) == 0:
+                        continue
+                    fqn = fqn_prefix + "." + name if fqn_prefix != "" else name
+                    new_child = override.replacement(fqn, child)
+                    if id(new_child) == id(child):
+                        swap(fqn, new_child)
+                    else:
+                        accessor.swap_submodule(fqn, new_child)
+
+        swap("", mod)
+
+    # 3. Trace the stateless version of the train_step
+    params = dict(mod.named_parameters(remove_duplicate=False))
+    buffers = dict(mod.named_buffers(remove_duplicate=False))
+
+    named_states = {}
+    if opt is not None:
+        # Pass named_states instead of opt.state to stateless_func, because
+        # the latter uses nn.Parameter as keys. During tracing, we need to
+        # make sure optimizers can find the states using proxy tensors.
+        for n, p in params.items():
+            if p in opt.state:
+                # named_states is keyed by parameter name (str), while the
+                # optimizer keys its state by Parameter
+                named_states[n] = opt.state[p]  # type: ignore[index]
+
+    is_data_parallel_mode = isinstance(parallel_mode, DataParallel)
+
+    # Lift states and parameters as function arguments so that make_fx
+    # can trace operations applied to them.
+    def stateless_func(func, params, buffers, named_states, args, kwargs):
+        with stateless._reparametrize_module(
+            mod, {**params, **buffers}
+        ), _rematerialize_optimizer(
+            opt, named_states, params
+        ) if opt else nullcontext():
+            # For DataParallel mode, install hooks first to tag the gradients
+            with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
+                ret = func(*args, **kwargs)
+
+            # make sure updated parameters are returned
+            return ret, list(mod.parameters()), list(named_states.values())  # type: ignore[union-attr]
+
+    # FIXME: Use symbolic tracing as a workaround in DTensor expand mode.
+    # Otherwise it hits shape mismatch error, as we use local inputs to
+    # trace local graph and use DTensor to expand operators, where
+    # DTensor's shape is the global shape.
+    tracing_mode = "fake" if is_data_parallel_mode else "symbolic"
+
+    if is_data_parallel_mode:
+        fake_mode = FakeTensorMode()
+        data_parallel_mode = cast(DataParallel, parallel_mode)
+
+        def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
+            # Since compilation happens in the first iteration and we
+            # receive mini-batch inputs, convert them to full-batch
+            # fake tensor inputs first for data parallel sharding
+            # propagation.
+            fake_arg = fake_mode.from_tensor(arg)
+            arg_dims = [1] * arg.ndim
+            # expand the tensor to full batch size on its batch dim
+            arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
+            return fake_arg.repeat(arg_dims)
+
+        args = pytree.tree_map_only(
+            torch.Tensor,
+            _get_full_batch_arg,
+            args,
+        )
+        kwargs = pytree.tree_map_only(
+            torch.Tensor,
+            _get_full_batch_arg,
+            kwargs,
+        )
+
+    with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
+        # FIXME(@mrshenli): functionalization does not work for our use
+        # case yet. Use explicit decompositions for foreach ops.
+        # Remove this when the following issue is addressed.
+        # Issue: https://github.com/pytorch/pytorch/issues/97852
+        gm = make_fx(
+            partial(stateless_func, func),
+            tracing_mode=tracing_mode,
+            decomposition_table=SPMD_DECOMP_TABLE,
+            _allow_non_fake_inputs=False,
+        )(params, buffers, named_states, args, kwargs)
+
+    params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
+        **params,
+        **buffers,
+    }
+
+    # 4. parallel mode to expand a single device graph to a distributed graph
+    gm = parallel_mode.partition(
+        gm,
+        mod,
+        opt,
+        params_and_buffers,
+        named_states,
+        args,
+        kwargs,
+    )
+
+    # 5. Move the responsibility of flattening the input arguments from the
+    # graph module to the caller. This serves two purposes:
+    #   - Transformations that add/remove state need to manipulate a state
+    #   container that maintains the state tensors in the same order as they
+    #   appear in graph placeholders.
+    #   - Reduced runtime cost. The state container is only flattened once upfront.
+    flat_state = pytree.tree_leaves([params_and_buffers, named_states])
+    gm = _to_caller_flattened_graph_module(gm)
+
+    # 6. dedup comm operators.
+    # The duplication could come from DTensor args and kwargs redistribution.
+    # Suppose one operator produces a Partial gradient tensor and model
+    # parameters are replicated. In this case, every optimizer operation using
+    # that Partial gradient tensor would trigger an allreduce. This is because
+    # DTensor only has local information on individual tensor/operator, which is
+    # not sufficient to detect duplications in the graph. This situation can
+    # also happen when inserting FSDP allgather if a parameter is used multiple
+    # times in the forward method.
+    # TODO(@mrshenli): @yifuwang has a suggestion of conducting expansion and
+    # dedup at tracer-level to avoid multiple graph passes.
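+    # Illustrative example (an assumed but typical pattern): if two optimizer ops
+    # both redistribute the same Partial gradient, DTensor expansion emits two
+    # ``c10d_functional.all_reduce`` nodes with identical args; the pass keeps the
+    # first, replaces uses of the duplicate with it, and the duplicate's
+    # ``wait_tensor`` then collapses the same way.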
+    gm = _dedup_collectives(gm)
+
+    # 7. Replace previously inserted dummy ones with real graphs.
+    if module_override:
+        for override in module_override:
+            gm = override.transform(gm, flat_state)
+
+    return _CompiledResult(gm, mod, opt, flat_state)
+
+
+# Note that the Python convention of __dict__ requires the key to be str.
+# TODO: ensure the key is unique.
+COMPILED_OBJECT_KEY = "_compiled_obj"
+
+
+def compile(
+    module_override: Optional[List[Override]] = None,
+    gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+    parallel_mode: Optional[ParallelMode] = None,
+):
+    r"""Compile and optimize a callable, which can be a train step within a training loop.
+
+    This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
+    instances from the input arguments and trace operations applied to their
+    parameters and states.
+
+    Args:
+        module_override (Optional[List[Override]]): a list of Override instances
+            that will be applied to the module in order. The :class:`Override`
+            objects provide :class:`nn.Module` replacements during tracing and a
+            graph transformation function after tracing. (Default: ``None``)
+        gm_transformation (Optional[Callable[fx.GraphModule, fx.GraphModule]]):
+            a callback that will be called after the original callable is
+            compiled and distributed (usually after the first iteration) to
+            transform the compiled GraphModule into a new optimized one.
+        parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
+            that specifies how to parallelize the callable. Each ParallelMode
+            would have its own strategy to partition the model and the captured
+            graph (Default: ``None``)
+
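+    Example (a minimal, hedged sketch assuming a ``train_step`` that runs the
+    forward pass, the backward pass, and the optimizer step)::
+
+        @compile()
+        def train_step(mod, opt, inp):
+            mod(inp).sum().backward()
+            opt.step()
+
+        # The first call traces and expands the step; later calls reuse the
+        # compiled graph.
+        train_step(model, optimizer, batch)
+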
+    """
+
+    def inner(func: Callable):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            last_train_step = kwargs.pop("last_train_step", False) if kwargs else False
+            first_iter = False
+            # Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
+            # ``wrapper`` is the one that users will get.
+            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
+            if compiled_obj is None:
+                first_iter = True
+                global dtensor_expand_mode
+                mode: ParallelMode = (
+                    dtensor_expand_mode if parallel_mode is None else parallel_mode
+                )
+
+                compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
+                wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj
+
+            flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
+                *args, **kwargs
+            )
+
+            with torch.no_grad():
+                # N.B.: we don't need autograd as backward has already been
+                # captured in the graph.
+                if first_iter and gm_transformation:
+                    # TODO: SPMD should provide a default and configurable
+                    # transformation.
+                    compiled_obj.gm = gm_transformation(compiled_obj.gm)
+                if not last_train_step:
+                    output = compiled_obj.gm(*flat_inps)[0]
+                else:
+                    # This is the last train step. Call IterGraphModule.forward()
+                    # with the `last_iter` argument and catch the exception in
+                    # case the compiled_obj is not wrapped with IterGraphModule.
+                    try:
+                        output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[
+                            0
+                        ]
+                    except TypeError as e:
+                        if "last_iter" not in str(e):
+                            raise e
+                        output = compiled_obj.gm(*flat_inps)[0]
+
+                return output
+
+        return wrapper
+
+    return inner
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/batch_dim_utils.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/batch_dim_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8277ae25cdf1d8d7df61e04ad9acf40f62b86b40
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/batch_dim_utils.py
@@ -0,0 +1,179 @@
+from typing import Callable, Dict, List, Set
+
+import torch
+
+import torch.fx as fx
+
+import torch.utils._pytree as pytree
+
+from torch import Tensor
+
+from torch.distributed._tensor import DeviceMesh, Replicate, Shard
+from torch.distributed._tensor.ops.view_ops import (
+    DimSpec,
+    InputDim,
+    ops as view_op_rules,
+)
+from torch.distributed._tensor.placement_types import _Partial, DTensorSpec
+
+aten = torch.ops.aten
+
+
+class BatchDimAnalyzer:
+    """This class is used to analyze the batch dimension of each tensor/node in the graph.
+
+    We need to know the batch dimension of each tensor/node so that we know
+    exactly the sharding layout of intermediate tensors.
+
+    We should possibly evaluate using symbolic shapes to track the batch dimension.
+    We can experiment with this later via dynamo integration (as dynamo has a
+    mark_dynamic API that allows marking only the batch dimension) or try to use
+    FakeTensorMode to mark the batch dimension. For now, let's just use the batch
+    dimension of the first input tensor as the hint to track the batch dimension of
+    all tensors/nodes in the graph.
+    """
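+
+    # A hedged usage sketch (the surrounding flow is an assumption): construct
+    # ``BatchDimAnalyzer(batch_dim=0)``, call ``init_batch_dim_size`` once with the
+    # first input's batch size, then call ``compute_act_spec(node, mesh)`` per fx node
+    # to obtain a Shard/Replicate/_Partial ``DTensorSpec`` for that activation.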
+
+    def __init__(self, batch_dim: int = 0) -> None:
+        self.batch_dim = batch_dim
+
+        self.batch_dim_map: Dict[fx.Node, int] = {}
+        # batch dim size is used to track the batch dim size of the input tensor
+        self.batch_dim_size = -1
+
+        self.dim_rule_map: Dict[torch._ops.OpOverload, Callable[..., torch.Tensor]] = {
+            aten.squeeze.default: torch.squeeze,
+            aten.squeeze.dim: torch.squeeze,
+            aten.view.default: Tensor.view,
+            aten.reshape.default: torch.reshape,
+            aten._unsafe_view.default: Tensor.view,
+            aten.unsqueeze.default: torch.unsqueeze,
+            aten.expand.default: Tensor.expand,
+            aten.permute.default: torch.permute,
+            aten.repeat.default: Tensor.repeat,
+            aten.transpose.int: torch.transpose,
+        }
+
+    def init_batch_dim_size(self, batch_dim_size: int) -> None:
+        """Initialize batch dim size based on the first input batch size."""
+        if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
+            raise RuntimeError(
+                f"batch dim size is already initialized! "
+                f"Found new batch size: {batch_dim_size} not "
+                f"matching existing batch dim size: {self.batch_dim_size}!"
+            )
+        self.batch_dim_size = batch_dim_size
+
+    def set_batch_dim(self, node: fx.Node, batch_dim: int) -> None:
+        self.batch_dim_map[node] = batch_dim
+
+    def get_batch_dim(self, node: fx.Node) -> int:
+        if node not in self.batch_dim_map:
+            raise RuntimeError(f"batch dim analysis failed on node: {node}!")
+        return self.batch_dim_map[node]
+
+    def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
+        """Compute the batch dimension for the `node`."""
+        assert self.batch_dim_size != -1, "batch dim size is not initialized!"
+
+        if node in self.batch_dim_map:
+            # if batch dim already computed, simply return it
+            return self.batch_dim_map[node]
+
+        if node.target in self.dim_rule_map:
+            view_op_rule = view_op_rules[self.dim_rule_map[node.target]]  # type: ignore[index]
+            args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
+            kwargs_val = pytree.tree_map_only(
+                fx.Node, lambda n: n.meta["val"], node.kwargs
+            )
+            output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)
+
+            def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
+                if isinstance(cmd, InputDim):
+                    input_dims.add(cmd.input_dim)
+                for inp in cmd.inputs():
+                    collect_input_dim(inp, input_dims)
+
+            output_dim_to_input_dims: List[Set[int]] = []
+            for inp in output_dim_rules:
+                input_dims: Set[int] = set()
+                collect_input_dim(inp, input_dims=input_dims)
+                output_dim_to_input_dims.append(input_dims)
+
+            operand = node.all_input_nodes[0]
+            operand_batch_dim = self.get_batch_dim(operand)
+            for output_dim, input_dims in enumerate(output_dim_to_input_dims):
+                if operand_batch_dim in input_dims:
+                    self.set_batch_dim(node, output_dim)
+                    # Update the batch dim size before returning, because the batch
+                    # dim size might change partway through the graph.
+                    self.batch_dim_size = node.meta["val"].shape[output_dim]
+                    return output_dim
+
+        # If there are no hints from the output_dim_rules, we infer from the output
+        # shape whether there is a batch dim, and shard correspondingly.
+        node_val = node.meta["val"]
+        if isinstance(node_val, (list, tuple)):
+            shapes = [val.shape for val in node_val]
+        else:
+            shapes = [node_val.shape]
+
+        # For a reduction op that reduces over the sharded batch dim we don't
+        # generate a partial placement but rather a shard. This is because the
+        # intention of data parallel is to never do a full reduction across the
+        # batch dimension; the reduced activation stays sharded.
+        full_reduction = False
+        # loop through the dim size to find the output batch dim
+        for shape in shapes:
+            if len(shape) == 0:
+                full_reduction = True
+
+            for i, dim_size in enumerate(shape):
+                if dim_size == self.batch_dim_size:
+                    self.set_batch_dim(node, i)
+                    return i
+
+        operands = node.all_input_nodes
+        if not operands:
+            # If there are no operands, it must be a factory op producing a tensor
+            # generated for computation, which should be marked as replicated.
+            self.set_batch_dim(node, -1)
+            # -1 means replicated
+            return -1
+        else:
+            # If there are operands, check whether they carry a batch dim. If an
+            # operand has a batch dim but the output does not, it's either a full
+            # reduction, where we should stay sharded, or a reduction over the batch
+            # dim only, where we should produce a partial placement.
+            operand_batch_dim = -1
+            for operand in operands:
+                if operand in self.batch_dim_map:
+                    operand_batch_dim = self.get_batch_dim(operand)
+            # self.get_batch_dim(operands[0])
+            if operand_batch_dim < 0:
+                # if operand does not have batch dim, we also don't have batch dim
+                self.set_batch_dim(node, operand_batch_dim)
+                return operand_batch_dim
+            elif full_reduction:
+                self.set_batch_dim(node, operand_batch_dim)
+                return operand_batch_dim
+            else:
+                # if operand have batch dim but output does not, it should
+                # produce partial, we use -2 to indicate partial
+                self.set_batch_dim(node, -2)
+                return -2
+
+    def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
+        """Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
+        node_batch_dim = self.compute_batch_dim(node)
+        if node_batch_dim == -1:
+            # indicate this activation is replicated
+            act_spec = DTensorSpec(mesh=mesh, placements=(Replicate(),))
+        elif node_batch_dim == -2:
+            # indicate this activation is partial
+            act_spec = DTensorSpec(mesh=mesh, placements=(_Partial(),))
+        else:
+            # indicate this activation is Shard
+            act_spec = DTensorSpec(mesh=mesh, placements=(Shard(node_batch_dim),))
+
+        return act_spec
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/comm_tensor.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/comm_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5369c2d4f628ecdd9b2a7344cb8d1ab155ef11dd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/comm_tensor.py
@@ -0,0 +1,247 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, List, Optional, Tuple
+
+import torch
+from torch._C import _disabled_torch_function_impl
+from torch.fx.experimental.proxy_tensor import (
+    _ProxyTensor,
+    fetch_object_proxy,
+    get_innermost_proxy_mode,
+    get_proxy_slot,
+    set_proxy_slot,
+    track_tensor_tree,
+)
+from torch.utils import _pytree as pytree
+from torch.utils._mode_utils import no_dispatch
+from torch.utils._pytree import tree_flatten, tree_map, tree_map_only
+
+
+@dataclass
+class _CommResult:
+    # a custom type wrapping both inplace output tensor and work handle
+    _tensor: torch.Tensor
+    _work: torch.distributed._Work
+
+
+def _wait_comm(comm_result: _CommResult):
+    # This function is only used by tracing mode as a call_function node right
+    # before consuming a collective result tensor.
+    comm_result._work.wait()
+    return comm_result._tensor
+
+
+def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
+    def wrap(work, e):
+        assert isinstance(e, torch.Tensor), (
+            "Excepting collection of tensors as the first element in the "
+            "return value of communication operations."
+        )
+
+        return _CommResult(e, work)
+
+    # E.g.,
+    # allreduce_ returns ([tensor], work)
+    # allgather_ returns ([[tensor1, tensor2]], work)
+    work = result[1]
+    return (tree_map(partial(wrap, work), result[0]), work)
+
+
+def _get_tracer() -> Optional[torch.fx.Tracer]:
+    mode = get_innermost_proxy_mode()
+    if mode is None:
+        return None
+    return mode.tracer
+
+
+class CommTensor(torch.Tensor):
+    r"""
+    A Tensor subclass to wrap input tensors for collective communications.
+
+    This Tensor subclass works for both eager and tracing mode.
+    In eager mode, it will record whether the inplace collective communication
+    has been launched using this Tensor and remember the corresponding work
+    handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
+    function before subsequent operations consuming the value of the Tensor.
+
+    In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
+    ``__torch_dispatch__`` function.
+    1. The first node is inserted right after the
+    communication, wrapping both the inplace output tensor and the returned
+    work handle into a custom ``_CommResult`` type. We have to do this because
+    ``ProxyTorchDispatchMode`` only handles ``torch.Tensor``, ``_ProxyTensor``,
+    and ``torch.nn.Parameter`` objects and will treat the work handle
+    as a constant and embed that into the graph. As a result, during execution,
+    it will use the work handle created during tracing, which leads to wrong
+    results. The solution here is to manually create a proxy on the
+    return value of ``allreduce_``, which is ``([tensor], work)``, and wrap that
+    into ``[(_CommResult(tensor, work)), work]``. In this way, subsequent nodes can
+    directly consume ``_CommResult``.
+    2. The second node is inserted right before any subsequent node reads from
+    ``_CommResult``. It will call ``wait()`` on the stashed work handle to ensure
+    that computation waits for communication.
+    """
+
+    _supported_comms: List[str] = [
+        "_allgather_base_",
+        "_reduce_scatter_base_",
+        "allreduce_",
+        "allgather_",
+        "alltoall_",
+        "broadcast_",
+        "reduce_scatter_",
+        "scatter_",
+    ]
+
+    _tensor: torch.Tensor
+    _work: Optional[torch.distributed._Work]
+
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor):
+        t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
+        if get_innermost_proxy_mode() is None:
+            # noop for eager mode
+            return tensor
+
+        # Use non-CommTensor to avoid nested CommTensor Wrapping
+        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
+        # The tensor object wrapped by this CommTensor
+        # NB: THIS CAN BE A CommTensor; see test_nested_comm_tensor_wrapping
+        r._tensor = tensor  # type: ignore[attr-defined]
+        # Record the LAST `work` object returned by collective communication
+        # operations. If this is None, it means no collectives have been called
+        # since the last time a tensor was wrapped by CommTensor
+        r._work = None  # type: ignore[attr-defined]
+        return r
+
+    def __repr__(self):
+        return f"CommTensor({self._tensor}, work={self._work})"
+
+    # disable __torch_function__ so that CommTensor can recursively dispatch
+    # with ProxyTorchDispatchMode in make_fx
+    __torch_function__ = _disabled_torch_function_impl
+
+    @classmethod
+    def _is_supported(cls, op_name):
+        return any(comm in op_name for comm in cls._supported_comms)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        # shared states when unwrapping args
+        tracer: Optional[torch.fx.Tracer] = None
+        work: Optional[torch.distributed._Work] = None
+
+        # unwrap to ._tensor if this is a CommTensor, and insert/call wait()
+        # if communication has been launched on this tensor.
+        def unwrap(e: Any):
+            if isinstance(e, CommTensor):
+                nonlocal tracer, work
+
+                work = e._work
+                # TODO(ezyang): I don't really understand what's going on
+                # here, but it seems that tracer doesn't reflect whether or
+                # not there is ambient tracing going on, but rather, whether
+                # or not we will trace THIS particular invocation.  If we
+                # have a nested CommTensor, the outer layer doesn't actually
+                # trace and we only trace the inner layer
+                if not isinstance(e._tensor, CommTensor):
+                    tracer = _get_tracer()
+
+                if work is not None:
+                    if tracer is not None:
+                        # insert a node to the traced graph.
+                        proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
+                            "call_function",
+                            _wait_comm,
+                            (get_proxy_slot(e._tensor, tracer).proxy,),
+                            {},
+                            name="wait_comm",
+                        )
+                        # HACK: update the proxy for the inplace output
+                        set_proxy_slot(e._tensor, tracer, proxy_res)
+                    # For eager mode, simply wait.
+                    # During tracing, still need to wait here, to make sure the
+                    # execution during tracing is correct.
+                    work.wait()
+
+                # the communication has been waited on, stop propagating CommTensor
+                return e._tensor
+            else:
+                return e
+
+        def wrap(e: Any):
+            return CommTensor(e) if isinstance(e, torch.Tensor) else e
+
+        def set_work(work: torch.distributed._Work, e: Any):
+            if isinstance(e, CommTensor):
+                e._work = work  # type: ignore[attr-defined]
+            elif isinstance(e, torch.Tensor):
+                raise RuntimeError(
+                    "Type of output tensors from collective communication during "
+                    "tracing should always be CommTensor instead of torch.Tensor"
+                )
+            return e
+
+        unwrapped_args = tree_map(unwrap, args)
+        unwrapped_kwargs = tree_map(unwrap, kwargs)
+
+        if cls._is_supported(func.__name__):
+            if tracer is not None:
+                # in tracing mode, get proxies for args
+                proxy_args, proxy_kwargs = tree_map_only(
+                    _ProxyTensor,
+                    lambda e: e.proxy,
+                    tree_map_only(
+                        torch.Tensor,
+                        fetch_object_proxy(tracer),
+                        (unwrapped_args, unwrapped_kwargs),
+                    ),
+                )
+
+                # get proxy for output tuple
+                proxy_res = func(*proxy_args, **proxy_kwargs)
+                assert isinstance(proxy_res, torch.fx.Proxy)
+                # insert a node that wraps the output tuple into
+                # _CommResult(tensor, work)
+                comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
+                    "call_function",
+                    _wrap_comm_result,
+                    (proxy_res,),
+                    {},
+                    name="comm_result",
+                )
+
+                with no_dispatch():
+                    # disable dispatch to avoid trigger ProxyTorchDispatchMode logic
+                    out = func(*unwrapped_args, **unwrapped_kwargs)
+
+                # wrap output with the proxy of _CommResult, so that subsequent
+                # ops can link to it.
+                track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)
+
+                # N.B.: we still need to remember the work handle here, and wait
+                # for it later to make sure the execution during tracing is
+                # correct. Also, remember that the comm has already been launched.
+                # args[0] is always the collection of output tensors
+                pytree.tree_map_(partial(set_work, out[1]), args[0])
+
+                # HACK: update the proxy on the input argument as this is an
+                # inplace collective communication.
+                flat_args, args_spec = tree_flatten(unwrapped_args[0])
+                flat_out, out_spec = tree_flatten(out[0])
+                for a, o in zip(flat_args, flat_out):
+                    set_proxy_slot(a, tracer, get_proxy_slot(o, tracer))
+
+                return out
+            else:
+                # in eager mode, simply remember work handle as an attribute
+                out = func(*unwrapped_args, **unwrapped_kwargs)
+                pytree.tree_map_(partial(set_work, out[1]), args[0])
+                return out
+        else:
+            if work is not None:
+                return func(*unwrapped_args, **unwrapped_kwargs)
+            else:
+                # we need to propagate CommTensor wrapping until the first
+                # subsequent operation has waited for it.
+                return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
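+
+
+# Minimal sketch (illustration only, not part of the upstream module): shows how
+# _wrap_comm_result pairs each in-place output tensor with the returned work
+# handle, and how _wait_comm later unwraps it after waiting. _FakeWork is a
+# hypothetical stand-in for a real torch.distributed work handle.
+class _FakeWork:
+    def wait(self) -> None:
+        # a real work handle would block here until the collective finishes
+        pass
+
+
+def _example_wrap_and_wait() -> torch.Tensor:
+    out = torch.ones(2)
+    wrapped, work = _wrap_comm_result(([out], _FakeWork()))
+    # wrapped == [_CommResult(out, work)]; reading it goes through _wait_comm
+    return _wait_comm(wrapped[0])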
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/config.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..31dbcb1f1e7594439b7363cf43230389f48d76c0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/config.py
@@ -0,0 +1,27 @@
+import logging
+import sys
+from types import ModuleType
+from typing import Optional, Set
+
+# log level (each level prints its own messages plus everything from the levels
+# listed below it)
+# DEBUG print full traces <-- lowest level + print tracing of every instruction
+# INFO print compiler functions + distributed graphs
+# WARN print warnings
+# ERROR print exceptions
+log_level: int = logging.DEBUG
+# Verbose will print full stack traces on warnings and errors
+verbose = False
+
+# the name of a file to write the logs to
+log_file_name: Optional[str] = None
+
+
+class _AccessLimitingConfig(ModuleType):
+    def __setattr__(self, name, value) -> None:
+        if name not in _allowed_config_names:
+            raise AttributeError(f"{__name__}.{name} does not exist")
+        return object.__setattr__(self, name, value)
+
+
+_allowed_config_names: Set[str] = {*globals().keys()}
+sys.modules[__name__].__class__ = _AccessLimitingConfig
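+
+
+# Hedged usage sketch (illustration only): because the module class is swapped
+# to _AccessLimitingConfig above, external assignments to names that were not
+# already defined in this module raise AttributeError, which catches config
+# typos early.
+def _example_access_limiting() -> None:
+    import torch.distributed._spmd.config as spmd_config
+
+    spmd_config.verbose = True  # ok: 'verbose' is an existing knob
+    try:
+        spmd_config.verbosity = True  # hypothetical typo: raises AttributeError
+    except AttributeError:
+        pass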
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/data_parallel.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f6b92c0494db3541365ab434a0e7d2b5e7b3831
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/data_parallel.py
@@ -0,0 +1,824 @@
+import operator
+from contextlib import contextmanager
+from enum import Enum
+
+from typing import Any, cast, Dict, List, Optional, Tuple
+
+import torch
+
+import torch.distributed.distributed_c10d as c10d
+import torch.fx as fx
+import torch.library
+import torch.nn as nn
+
+import torch.utils._pytree as pytree
+
+from torch.distributed._spmd.batch_dim_utils import BatchDimAnalyzer
+from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
+
+from torch.distributed._tensor._utils import compute_local_shape
+from torch.distributed._tensor.op_schema import (
+    OpStrategy,
+    PlacementStrategy,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed._tensor.placement_types import _Partial, DTensorSpec, Placement
+from torch.distributed._tensor.redistribute import redistribute_local_tensor
+from torch.fx import GraphModule
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.nn.utils._named_member_accessor import NamedMemberAccessor
+
+aten = torch.ops.aten
+
+# Dummy op used by data parallel to tag gradients.
+_spmd_lib_def = torch.library.Library("_spmd", "DEF")
+_spmd_lib_def.define("tag_grad(Tensor self) -> Tensor")
+
+_spmd_lib_impl = torch.library.Library("_spmd", "IMPL")
+_spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")
+
+
+class DataParallelStyle(Enum):
+    """This enum represents the style of the data-parallel operation.
+
+    We have three types of Data Parallel style:
+    1. DEFAULT: the default data parallel style, which is to represent a mixed
+                replicate and fully shard behavior. For each parameter that is able
+                to be sharded evenly, we shard it, otherwise we would replicate the
+                parameter. This style avoids potential padding if the parameters
+                cannot be sharded evenly, but it would generate a mix of all_reduce
+                and reduce_scatter.
+    2. REPLICATE: the data parallel style that replicates all model parameters.
+                  This is similar to the behavior of DistributedDataParallel.
+    3. FULLY_SHARD: the data parallel style that shards all model parameters. This
+                    is similar to the behavior of FullyShardedDataParallel (ZeRO-3);
+                    the difference is that FullyShardedDataParallel shards the model
+                    using FlatParameter-based sharding, while this style shards
+                    each parameter into a DTensor.
+    """
+
+    DEFAULT = 0
+    REPLICATE = 1
+    FULLY_SHARD = 2
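+
+    # Hedged note (illustration only): mark_data_parallel_shardings and
+    # partition_data_parallel below map REPLICATE to a Replicate() placement on
+    # each parameter (DDP-like) and FULLY_SHARD to Shard(0) (per-parameter
+    # FSDP-like); DEFAULT is not implemented yet in this module.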
+
+
+class NodeType(Enum):
+    """NodeType is an enum that records the type of the tensors in the graph.
+
+    This is used to determine the data parallel strategy.
+    """
+
+    PARAM = 0
+    ACT = 1
+    GRAD = 2
+    STATE = 3
+    NON_TENSOR = 4  # NON_TENSOR tags non-tensor nodes (e.g. the graph output)
+
+
+class DataParallelStrategy(OpStrategy):
+    """DataParallelStrategy is a special case of OpStrategy that only records the "data parallel style" placement
+    strategy for each fx Node.
+
+    It takes a list of PlacementStrategy, where each PlacementStrategy describes
+    one way to distribute the tensor and computation. In the DataParallel case,
+    there are two possible ways to distribute the parameters:
+        1. replicate the parameter over a set of devices (DDP like behavior)
+        2. shard the parameter on its tensor dimension 0 over a set of devices
+           (FSDP like behavior).
+
+    In addition to the strategy list, we also need to:
+    1. `node_type`: record the type of each node in the graph, so that we can
+        determine how to propagate in a data parallel fashion.
+    2. `reduction_over_batch` is specifically tied to data parallel, as the loss
+        calculation usually results in a scalar tensor that comes from a
+        reduction over the batch dimension. We need to know this information
+        so that we can keep the output sharded.
+    """
+
+    def __init__(
+        self,
+        node_type: NodeType,
+        strategy_list: List[PlacementStrategy],
+        reduction_over_batch: bool = False,
+    ):
+        super().__init__(strategy_list)
+        self.node_type = node_type
+        self.reduction_over_batch = reduction_over_batch
+
+    def __str__(self) -> str:
+        return f"type: {self.node_type}, {super().__str__()}"
+
+
+@contextmanager
+def gradients_tagging(params: Dict[str, torch.Tensor]):
+    """Tag the gradient of the parameters with a special tag, so that we can identify them during SPMD expansion.
+
+    It's safe to trace those hooks and we would remove those nodes later.
+    """
+    tagging_hooks = []
+    try:
+        for p in params.values():
+            h = p.register_hook(torch.ops._spmd.tag_grad)
+            tagging_hooks.append(h)
+        yield
+    finally:
+        # remove those hooks after tracing
+        for h in tagging_hooks:
+            h.remove()
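+
+
+# Hedged usage sketch (illustration only, not part of the upstream module): a
+# hypothetical caller would wrap the backward pass so that every parameter
+# gradient flows through the traceable torch.ops._spmd.tag_grad op while the
+# hooks are alive.
+def _example_gradients_tagging(model: nn.Module, loss: torch.Tensor) -> None:
+    params = dict(model.named_parameters())
+    with gradients_tagging(params):
+        loss.backward()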
+
+
+def _gen_shard_strategy(
+    mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
+) -> PlacementStrategy:
+    """Util function to generate a shard strategy on shard_dim."""
+    return PlacementStrategy(
+        output_specs=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
+        input_specs=input_specs,
+    )
+
+
+def _gen_replicate_strategy(
+    mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
+) -> PlacementStrategy:
+    """Util function to generate a replicate strategy."""
+    return PlacementStrategy(
+        output_specs=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
+        input_specs=input_specs,
+    )
+
+
+def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
+    """Util function to generate a partial strategy."""
+    # NOTE: we use AVG by default. Whether avg reduction is needed depends on
+    # the loss function; for most loss functions it should do
+    # gradient averaging. There are certain cases where it should
+    # not do gradient averaging (i.e. sum), but they are pretty rare.
+    # TODO: Only NCCL supports AVG, so using a backend like Gloo would
+    # crash; we should figure out a way to support avg reduction
+    # for non-NCCL backends
+    reduce_op = c10d.ReduceOp.AVG  # type: ignore[attr-defined]
+    return PlacementStrategy(
+        output_specs=DTensorSpec(mesh=mesh, placements=(_Partial(reduce_op),)),
+    )
+
+
+def build_data_parallel_strategies(
+    train_step_graph: GraphModule,
+    num_params: int,
+    num_states: int,
+    mesh: DeviceMesh,
+    batch_dim: int = 0,
+) -> Dict[fx.Node, StrategyType]:
+    """Loop through the train step graph and build the data parallel strategy for each fx Node."""
+    activation_idx = num_params + num_states
+    non_compute_ops = [
+        aten.clone.default,
+        aten.detach.default,
+        aten.ones_like.default,
+        aten.reshape.default,
+        aten.t.default,
+        aten.view.default,
+        torch.ops._spmd.tag_grad.default,
+        operator.getitem,
+    ]
+
+    tuple_strategy_ops = [aten._fused_adam.default]
+
+    dp_strategy_map: Dict[fx.Node, StrategyType] = {}
+    batch_dim_analyzer = BatchDimAnalyzer(batch_dim)
+    placeholder_idx = 0
+    num_param_grad = 0
+
+    # first we propagate backward to mark the param gradient sharding with the
+    # help of the tag_grad nodes, and then delete the tag_grad nodes
+    for node in reversed(list(train_step_graph.graph.nodes)):
+        # find a param_grad node via the tagging
+        if node.target == torch.ops._spmd.tag_grad.default:
+            cur_node = node
+            while cur_node.target in non_compute_ops:
+                cur_node = cur_node.args[0]
+                partial_strategy = _gen_partial_strategy(mesh)
+                dp_strategy_map[cur_node] = DataParallelStrategy(
+                    NodeType.GRAD, [partial_strategy]
+                )
+            num_param_grad += 1
+            # remove the tag_grad node from graph
+            node.replace_all_uses_with(node.args[0])
+            train_step_graph.graph.erase_node(node)
+
+            if num_param_grad == num_params:
+                # early break if we have already processed all param_grads
+                break
+
+    # next we propagate forward to mark all the shardings
+    for node in train_step_graph.graph.nodes:
+        if node.op == "placeholder":
+            if "val" not in node.meta:
+                # NOTE: There are certain cases where the placeholder nodes do
+                # not have real tensor values:
+                # 1. optimizer states can sometimes be None, e.g. SGD with
+                #    no momentum populates the `momentum` state as None, so the
+                #    full graph we get from `compile` would have None as the
+                #    placeholder value
+                # 2. function args might not only contain params or activations,
+                #    but also other non-tensor inputs, e.g. the model and
+                #    optimizer instances baked in as placeholders, and there might
+                #    also be scalar arguments which are not tensors
+                #
+                # For the above cases, we create a NON_TENSOR strategy so that we
+                # know it's not a tensor and we don't need to shard it
+                dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])
+
+            elif placeholder_idx < num_params:
+                # during compilation there's an assumption that the first num_params
+                # placeholders should be parameters
+                shard_strategy = _gen_shard_strategy(mesh, 0)
+                replica_strategy = _gen_replicate_strategy(mesh)
+                dp_strategy_map[node] = DataParallelStrategy(
+                    NodeType.PARAM, [replica_strategy, shard_strategy]
+                )
+
+            elif placeholder_idx < activation_idx:
+                # optimizer states follow the same strategy as
+                # the corresponding parameters
+                replica_strategy = _gen_replicate_strategy(mesh)
+                shard_strategy = _gen_shard_strategy(mesh, 0)
+
+                dp_strategy_map[node] = DataParallelStrategy(
+                    NodeType.STATE, [replica_strategy, shard_strategy]
+                )
+            else:
+                activation_batch_dim_size = node.meta["val"].shape[batch_dim]
+                # find the first activation node and use its batch dim size
+                if batch_dim_analyzer.batch_dim_size == -1:
+                    batch_dim_analyzer.init_batch_dim_size(activation_batch_dim_size)
+
+                batch_dim_analyzer.set_batch_dim(node, batch_dim)
+                shard_strategy = _gen_shard_strategy(mesh, batch_dim)
+                dp_strategy_map[node] = DataParallelStrategy(
+                    NodeType.ACT, [shard_strategy]
+                )
+            placeholder_idx += 1
+        elif node.op == "call_function":
+            # Annotate node types for the computation graph
+            # Data Parallel node propagation logic:
+            # param (non-compute) -> out: param
+            # grad (non-compute before/after) -> out: grad
+            # state -> output: state
+            #
+            # param + activation (param must be replicate, act be sharded) -> out: activation
+            # param/state + grad (param/state/grad be the same spec) -> out: param/state
+            # param + state -> out: param
+
+            if node.target in non_compute_ops:
+                # At this point, we should have removed all the `tag_grad` nodes in the graph
+                assert node.target != torch.ops._spmd.tag_grad.default
+
+                input_nodes = node.all_input_nodes
+                assert (
+                    len(input_nodes) == 1
+                ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
+                arg_strategy = dp_strategy_map[input_nodes[0]]
+
+                if node.target == operator.getitem:
+                    # for getitem call, just forward the strategy from the input
+                    getitem_idx = node.args[1]
+                    if isinstance(arg_strategy, TupleStrategy):
+                        # for tuple strategy, we need to get the child strategy from the tuple
+                        dp_strategy_map[node] = arg_strategy.childs[getitem_idx]
+                    else:
+                        # if it's not a tuple strategy, we just forward the arg strategy
+                        dp_strategy_map[node] = arg_strategy
+                else:
+                    assert isinstance(arg_strategy, DataParallelStrategy)
+                    arg_node_type = arg_strategy.node_type
+                    if arg_node_type == NodeType.PARAM:
+                        replica_strategy = _gen_replicate_strategy(mesh)
+                        dp_strategy_map[node] = DataParallelStrategy(
+                            NodeType.PARAM, [replica_strategy]
+                        )
+                    elif arg_node_type == NodeType.GRAD:
+                        partial_sig = _gen_partial_strategy(mesh)
+                        dp_strategy_map[node] = DataParallelStrategy(
+                            NodeType.GRAD, [partial_sig]
+                        )
+                    elif arg_node_type == NodeType.ACT:
+                        arg_node_spec = batch_dim_analyzer.compute_act_spec(
+                            input_nodes[0], mesh
+                        )
+
+                        output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
+
+                        shard_strategy = PlacementStrategy(
+                            output_specs=output_spec, input_specs=[arg_node_spec]
+                        )
+                        dp_strategy_map[node] = DataParallelStrategy(
+                            NodeType.ACT, [shard_strategy]
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"non compute op not supporting {arg_node_type}! "
+                        )
+
+                # finished processing this non-compute node
+                continue
+
+            # for computation nodes, we need to check all the inputs
+            input_args = node.all_input_nodes
+            input_specs = []
+            if node in dp_strategy_map:
+                # found a param_grad node that already has a pre-filled output spec;
+                # fill in the expected input specs for the pre-filled strategy
+                node_strategy = dp_strategy_map[node]
+                assert isinstance(node_strategy, DataParallelStrategy)
+                node_type = node_strategy.node_type
+                assert node_type == NodeType.GRAD
+                produce_param_grad_strat = node_strategy.strategies
+                has_activation = False
+                for arg in input_args:
+                    arg_strategy = dp_strategy_map[arg]
+                    assert isinstance(arg_strategy, DataParallelStrategy)
+                    arg_node_type = arg_strategy.node_type
+                    if arg_node_type == NodeType.ACT:
+                        # activation sharded
+                        has_activation = True
+                        act_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)
+
+                        input_specs.append(act_spec)
+
+                if has_activation:
+                    assert len(produce_param_grad_strat) == 1
+                    produce_param_grad_strat[0].input_specs = input_specs
+            elif node.target in tuple_strategy_ops:
+                # ops that need to build tuple strategy instead of normal strategy
+                # This should happen rarely and is only needed when we need to generate
+                # different node strategies for multiple outputs (e.g. the fused_adam op)
+                # TODO: Currently this specializes to fused optimizer ops, but we need
+                # to see how to generalize this strategy building logic
+                output_strategy_len = len(node.args) - 1
+                tuple_strategies = []
+                for i in range(output_strategy_len):
+                    if not isinstance(node.args[i], list):
+                        raise RuntimeError(
+                            f"Expecting list as arg to build Tuple Strategy, but found type {type(node.args[i])}!"
+                        )
+                    # for list/tuple arg, use the first one to find out the node type
+                    if len(node.args[i]) > 0:
+                        arg_strategy = dp_strategy_map[node.args[i][0]]
+                        assert isinstance(arg_strategy, DataParallelStrategy)
+                        assert arg_strategy.node_type in [
+                            NodeType.PARAM,
+                            NodeType.GRAD,
+                            NodeType.STATE,
+                        ], "Expecting param/grad/state as arg to build Tuple Strategy!"
+                        replica_strategy = _gen_replicate_strategy(mesh)
+                        shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
+                        out_node_strategy: StrategyType = DataParallelStrategy(
+                            arg_strategy.node_type, [replica_strategy, shard_strategy]
+                        )
+
+                        tuple_strategies.append(out_node_strategy)
+
+                output_tuple_strategy = TupleStrategy(tuple(tuple_strategies))
+                dp_strategy_map[node] = output_tuple_strategy
+            else:
+                # NOTE: This is the common region for all regular computation ops
+
+                input_node_types = [
+                    cast(DataParallelStrategy, dp_strategy_map[arg]).node_type
+                    for arg in input_args
+                    if isinstance(dp_strategy_map[arg], DataParallelStrategy)
+                ]
+                if NodeType.GRAD in input_node_types:
+                    # param/state + grad, build up acceptable strategy
+                    # the strategy should be the same for all the inputs/outputs
+                    # TODO: optimizer parts should follow the dtensor prop logic
+                    # to support more general cases that allow optimizer states
+                    # to have different shardings compared to the params
+                    replica_strategy = _gen_replicate_strategy(mesh)
+                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
+                    output_node_type = NodeType.PARAM
+
+                    non_grad_types = [t for t in input_node_types if t != NodeType.GRAD]
+
+                    output_node_type = non_grad_types[0]
+                    for non_grad_type in non_grad_types:
+                        assert (
+                            non_grad_type == output_node_type
+                        ), f"Found more than one non grad types! Expect {output_node_type} but found {non_grad_type}!"
+                    assert output_node_type in [
+                        NodeType.PARAM,
+                        NodeType.STATE,
+                    ], f"Expecting output node type to be either state or param, but found {output_node_type}!"
+
+                    dp_strategy_map[node] = DataParallelStrategy(
+                        output_node_type, [replica_strategy, shard_strategy]
+                    )
+                elif NodeType.STATE in input_node_types:
+                    # either param + state or state + state
+                    replica_strategy = _gen_replicate_strategy(mesh)
+                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
+                    output_node_type = (
+                        NodeType.PARAM
+                        if NodeType.PARAM in input_node_types
+                        else NodeType.STATE
+                    )
+
+                    dp_strategy_map[node] = DataParallelStrategy(
+                        output_node_type, [replica_strategy, shard_strategy]
+                    )
+                elif NodeType.PARAM in input_node_types:
+                    if NodeType.ACT in input_node_types:
+                        # param + activation, build up acceptable strategy
+                        # param must be replicated, activation must be sharded
+                        for arg in input_args:
+                            arg_strategy = dp_strategy_map[arg]
+                            assert isinstance(arg_strategy, DataParallelStrategy)
+                            node_type = arg_strategy.node_type
+                            if node_type == NodeType.ACT:
+                                # compute activation spec
+                                act_spec = batch_dim_analyzer.compute_act_spec(
+                                    arg, mesh
+                                )
+
+                                input_specs.append(act_spec)
+                            elif node_type == NodeType.PARAM:
+                                # param must be replicated
+                                input_specs.append(
+                                    DTensorSpec(mesh=mesh, placements=(Replicate(),))
+                                )
+                            else:
+                                raise RuntimeError(
+                                    f"Expecting node with parameter and activation, but found {input_node_types}! "
+                                )
+                        # produce activation type sharding for output
+                        output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
+
+                        act_strategy = PlacementStrategy(
+                            output_specs=output_spec, input_specs=input_specs
+                        )
+
+                        dp_strategy_map[node] = DataParallelStrategy(
+                            NodeType.ACT, [act_strategy]
+                        )
+                    else:
+                        # If inputs only have parameters, the
+                        # strategy of this node should follow input
+                        dp_strategy_map[node] = dp_strategy_map[input_args[0]]
+                else:
+                    # If the input nodes have no PARAM/GRAD/STATE, then this
+                    # should be a pure activation computation and it should
+                    # produce an activation output.
+                    # Activations are usually sharded, unless the model creates
+                    # new tensors during computation; depending on whether such a
+                    # new tensor is associated with a batch dim or not, it could
+                    # be shard/replicate/partial, and the batch dim analyzer tells
+                    # us the correct sharding.
+                    for arg in input_args:
+                        arg_strategy = dp_strategy_map[arg]
+                        assert isinstance(arg_strategy, DataParallelStrategy)
+                        input_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)
+
+                        input_specs.append(input_spec)
+
+                    act_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
+                    op_strategy = PlacementStrategy(
+                        output_specs=act_spec, input_specs=input_specs
+                    )
+                    dp_strategy_map[node] = DataParallelStrategy(
+                        NodeType.ACT, [op_strategy]
+                    )
+
+        elif node.op == "output":
+            dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])
+        else:
+            raise RuntimeError(f"op code {node.op} not supported")
+
+    return dp_strategy_map  # type: ignore[return-value]
+
+
+def mark_data_parallel_shardings(
+    train_step_graph: GraphModule,
+    num_parameters: int,
+    num_states: int,
+    dp_strategy_map: Dict[fx.Node, StrategyType],
+    parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
+) -> None:
+    """Mark the sharding for the nodes in the train_step_graph."""
+    activation_idx = num_parameters + num_states
+    placeholder_idx = 0
+    for node in train_step_graph.graph.nodes:
+        node_strategy = dp_strategy_map[node]
+        if node.op == "placeholder":
+            assert isinstance(node_strategy, DataParallelStrategy)
+            node_type = node_strategy.node_type
+            node_strategies = node_strategy.strategies
+            if node_type == NodeType.NON_TENSOR:
+                # set node sharding to None
+                node_sharding = None
+            elif placeholder_idx < activation_idx:
+                assert len(node_strategies) > 0, "node_strategies should not be empty"
+                if parallel_mode == DataParallelStyle.REPLICATE:
+                    # set to replicate for replicate style
+                    node_sharding = node_strategies[0]
+                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
+                    # set to shard for fully shard style
+                    if len(node_strategies) == 1:
+                        # only one strategy, use that one,
+                        # e.g. optimizer state steps can only be replicated
+                        node_sharding = node_strategies[0]
+                    else:
+                        # use the full sharding strategy
+                        node_sharding = node_strategies[1]
+                elif parallel_mode == DataParallelStyle.DEFAULT:
+                    # TODO: add support for default mode
+                    # default mode would generate either replicate or shard
+                    raise NotImplementedError("default mode not implemented")
+            else:
+                assert len(node_strategies) > 0, "node_strategies should not be empty"
+                # mark activation as sharded on batch dim
+                node_sharding = node_strategies[0]
+
+            node.meta["sharding"] = node_sharding  # type: ignore[possibly-undefined]
+
+            placeholder_idx += 1
+        elif node.op == "call_function":
+            if isinstance(node_strategy, TupleStrategy):
+                # A tuple strategy in data parallel mode should have the same strategy
+                # for all tuple elements; assert that, then use the first element's strategy as the sharding
+                first_strategy = cast(DataParallelStrategy, node_strategy.childs[0])
+                for child_strategy in node_strategy.childs:
+                    assert isinstance(child_strategy, DataParallelStrategy)
+                    assert child_strategy.strategies == first_strategy.strategies
+
+                node_strategies = first_strategy.strategies
+            else:
+                assert isinstance(node_strategy, DataParallelStrategy)
+                node_strategies = node_strategy.strategies
+
+            assert (
+                len(node_strategies) <= 2
+            ), "data parallel should have at most 2 strategies"
+            if len(node_strategies) == 1:
+                node.meta["sharding"] = node_strategies[0]
+            elif len(node_strategies) == 2:
+                if parallel_mode == DataParallelStyle.REPLICATE:
+                    # set to replicate for replicate style
+                    node.meta["sharding"] = node_strategies[0]
+                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
+                    # set to shard for fully shard style
+                    node.meta["sharding"] = node_strategies[1]
+                else:
+                    raise RuntimeError("default mode not supported yet!")
+            else:
+                raise RuntimeError(
+                    f"node {node} strategy length {len(node_strategies)} is not expected!"
+                )
+        elif node.op == "output":
+            assert (
+                isinstance(node_strategy, DataParallelStrategy)
+                and node_strategy.node_type == NodeType.NON_TENSOR
+            ), "output node should not be tensor"
+            node.meta["sharding"] = None
+        else:
+            raise RuntimeError(f"op code {node.op} not supported")
+
+
+def _partition_val(val: Any, spec: DTensorSpec) -> Any:
+    """Util function to convert a full tensor val to its local component."""
+    if isinstance(val, torch.Tensor):
+        local_shard = val
+        if val.ndim == 0:
+            # If it's already a scalar tensor, it is already local, we don't
+            # need to do anything
+            return local_shard
+
+        for idx, placement in enumerate(spec.placements):
+            if placement.is_shard():
+                placement = cast(Shard, placement)
+                num_chunks = spec.mesh.size(mesh_dim=idx)
+                my_coord = spec.mesh.get_coordinate()
+                assert my_coord is not None, "current rank not in mesh!"
+                my_coord_on_mesh_dim = my_coord[idx]
+                local_shard = placement._split_tensor(
+                    local_shard, num_chunks, with_padding=False, contiguous=False
+                )[0][my_coord_on_mesh_dim]
+        return local_shard
+    elif isinstance(val, (tuple, list)):
+        return val.__class__(_partition_val(v, spec) for v in val)
+    else:
+        raise RuntimeError(f"val type {type(val)} not supported")
+
+
+def partitioner(graph: GraphModule) -> GraphModule:
+    """Graph partitioner that partitions the single device graph to distributed graph."""
+    shape_adjustment_ops = {
+        aten._unsafe_view.default: 1,
+        aten.expand.default: 1,
+        aten.new_zeros.default: 1,
+        aten.ones.default: 0,
+        aten.reshape.default: 1,
+        aten.view.default: 1,
+        aten.zeros.default: 0,
+    }
+    # partition the graph to distributed
+    for node in graph.graph.nodes:
+        node_sharding = node.meta["sharding"]
+        # None sharding means this node doesn't need sharding
+        if node_sharding is None:
+            continue
+
+        if node.op == "placeholder":
+            out_spec = node_sharding.output_spec
+            if not hasattr(out_spec, "from_local"):
+                local_val = _partition_val(node.meta["val"], out_spec)
+                # update node value
+                node.meta["val"] = local_val
+        elif node.op == "call_function":
+            out_spec = node_sharding.output_spec
+
+            # check if there's misaligned sharding, insert reshard if there is
+            expected_input_specs = node_sharding.input_specs
+            for idx, input_arg in enumerate(node.all_input_nodes):
+                input_arg_sharding = input_arg.meta["sharding"]
+
+                input_arg_spec = input_arg_sharding.output_spec
+                desired_spec = (
+                    out_spec
+                    if expected_input_specs is None
+                    else expected_input_specs[idx]
+                )
+                if input_arg_spec != desired_spec:
+                    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
+                    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
+                    input_arg_tensor = input_arg.meta["val"]
+
+                    # insert reshard operation
+                    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
+                        return redistribute_local_tensor(
+                            local_tensor,
+                            input_arg_spec,
+                            desired_spec,
+                        )
+
+                    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
+                    reshard_gm_nodes = list(reshard_gm.graph.nodes)
+                    input_node = reshard_gm_nodes[0]
+                    with graph.graph.inserting_before(node):
+                        output_node = graph.graph.graph_copy(
+                            reshard_gm.graph,
+                            val_map={
+                                input_node: input_arg,
+                            },
+                        )
+                    node.replace_input_with(input_arg, output_node)
+
+            output_val = node.meta["val"]
+
+            if node.target == torch.ops.aten.repeat.default:
+                # for repeat op, we need to infer the repeat sizes
+                assert isinstance(output_val, torch.Tensor)
+                local_shape = compute_local_shape(
+                    output_val.shape, out_spec.mesh, out_spec.placements
+                )
+                input_shape = node.args[0].meta["val"].shape
+
+                def infer_repeat_sizes(repeated_shape, input_shape):
+                    repeated_size = [1] * len(repeated_shape)
+                    padded_length = len(repeated_shape) - len(input_shape)
+                    for i in range(len(repeated_shape)):
+                        if i < padded_length:
+                            repeated_size[i] = repeated_shape[i]
+                        else:
+                            repeated_size[i] = (
+                                repeated_shape[i] // input_shape[i - padded_length]
+                            )
+
+                    return repeated_size
+
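+                # Worked example (illustration only): infer_repeat_sizes([2, 6], [2, 3])
+                # returns [1, 2] (2 // 2 and 6 // 3), while a padded call such as
+                # infer_repeat_sizes([5, 2, 6], [2, 3]) returns [5, 1, 2], taking the
+                # extra leading dim verbatim.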
+                node.update_arg(1, infer_repeat_sizes(local_shape, input_shape))
+
+            elif node.target in shape_adjustment_ops:
+                # for view-related ops that take a shape argument, adjust the shape to the local shape if needed
+                assert isinstance(output_val, torch.Tensor)
+                local_shape = compute_local_shape(
+                    output_val.shape, out_spec.mesh, out_spec.placements
+                )
+                shape_arg_num = shape_adjustment_ops[node.target]
+                node.update_arg(shape_arg_num, local_shape)
+
+            # convert output val to its local component
+            node.meta["val"] = _partition_val(output_val, out_spec)
+
+        elif node.op == "output":
+            break
+        else:
+            raise RuntimeError(f"op code {node} not supported")
+
+    # clean up the graph by removing sharding and partitioning related metadata
+    for node in graph.graph.nodes:
+        if "sharding" in node.meta:
+            del node.meta["sharding"]
+        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
+            node.meta["tensor_meta"] = local_tensor_meta
+
+    graph.graph.lint()
+    graph.recompile()
+    return graph
+
+
+def partition_data_parallel(
+    graph: GraphModule,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer],
+    params_buffers: Dict[str, torch.Tensor],
+    named_states: Dict[str, Any],
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    mesh: DeviceMesh,
+    parallel_style: DataParallelStyle,
+    input_batch_dim: int,
+) -> GraphModule:
+    """Partition the graph to into a data parallel graph.
+
+    This function also shards/replicates the model parameters and optimizer states to DTensors.
+    """
+    num_params_buffers = len(params_buffers)
+    flattened_states = pytree.tree_leaves(named_states)
+    num_states = len(flattened_states)
+
+    changed = graph.graph.eliminate_dead_code()
+    if changed:
+        graph.recompile()
+
+    # 1. First build up data parallel strategies for the whole graph
+    strategy_map = build_data_parallel_strategies(
+        graph, num_params_buffers, num_states, mesh=mesh, batch_dim=input_batch_dim
+    )
+
+    # 2. Next we mark the data parallel strategy for each node based on
+    #    the parallel_style
+    mark_data_parallel_shardings(
+        graph,
+        num_parameters=num_params_buffers,
+        num_states=num_states,
+        dp_strategy_map=strategy_map,
+        parallel_mode=parallel_style,
+    )
+
+    # 3. Partition the single-machine graph into the distributed graph
+    partitioned_graph = partitioner(graph)
+
+    # preserve node types for the expanded graph
+    for node in partitioned_graph.graph.nodes:
+        if node in strategy_map:
+            node_strategy = strategy_map[node]
+            if isinstance(node_strategy, DataParallelStrategy):
+                node.meta["node_type"] = node_strategy.node_type
+            elif isinstance(node_strategy, TupleStrategy):
+                node.meta["node_type"] = NodeType.NON_TENSOR
+            else:
+                raise RuntimeError(f"Unknown node strategy {node_strategy}")
+        else:
+            # if the nodes are expanded nodes (collectives), we mark them
+            # with the same type as the input node.
+            input_node = node.all_input_nodes[0]
+            node.meta["node_type"] = input_node.meta["node_type"]
+
+    # 4. Last, in-place partition the weights and optim states into
+    #    DTensors based on the parallel style
+    accessor = NamedMemberAccessor(model)
+    for param_key, param in params_buffers.items():
+        placement: Placement = Replicate()
+        if parallel_style == DataParallelStyle.FULLY_SHARD:
+            placement = Shard(0)
+        elif parallel_style != DataParallelStyle.REPLICATE:
+            raise RuntimeError(f"parallel style {parallel_style} not supported yet")
+
+        dtensor_param = distribute_tensor(param, mesh, [placement])
+        # update re-parameterized module param dict and optim states dict to DTensor
+        params_buffers[param_key] = dtensor_param.to_local()
+        # update module parameters to DTensor
+        accessor.set_tensor(param_key, dtensor_param)
+
+        # update the optimizer state key and values to DTensor
+        if optimizer is not None and param in optimizer.state:
+            param_states = named_states[param_key]
+            param_dtensor_states = {}
+            for state_key, state_val in param_states.items():
+                if isinstance(state_val, torch.Tensor) and state_val.ndim > 0:
+                    # shard/replicate non-scalar tensors; for scalar tensors we
+                    # don't do anything
+                    dtensor_state = distribute_tensor(state_val, mesh, [placement])
+                    param_dtensor_states[state_key] = dtensor_state
+                    param_states[state_key] = dtensor_state.to_local()
+                else:
+                    param_dtensor_states[state_key] = state_val
+
+            optimizer.state.pop(param)  # type: ignore[call-overload]
+            optimizer.state[dtensor_param] = param_dtensor_states  # type: ignore[index]
+
+    return partitioned_graph
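+
+
+# Hedged end-to-end sketch (illustration only): a caller such as the _spmd
+# compile path would (1) trace a train step into a GraphModule, (2) call
+# partition_data_parallel(graph, model, optimizer, params_buffers, named_states,
+# args, kwargs, mesh, DataParallelStyle.FULLY_SHARD, input_batch_dim=0) to get a
+# per-rank graph, and (3) execute that graph with local tensors; parameters and
+# optimizer states on the module/optimizer are swapped to DTensors in place.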
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/distribute.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/distribute.py
new file mode 100644
index 0000000000000000000000000000000000000000..daf2ed4b9daec85b0634fefb11de32dd9fa59f2b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/distribute.py
@@ -0,0 +1,783 @@
+import logging
+import operator
+from dataclasses import dataclass
+from enum import auto, Enum
+from functools import partial
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.distributed._spmd.experimental_ops
+import torch.fx as fx
+
+from torch.distributed._spmd.comm_tensor import _get_tracer
+from torch.distributed._spmd.graph_utils import OP
+from torch.distributed._spmd.log_utils import get_logger
+
+from torch.distributed._tensor import DeviceMesh, DTensor
+from torch.distributed._tensor.op_schema import OpSchema
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+    TensorMeta,
+)
+from torch.distributed._tensor.redistribute import redistribute_local_tensor
+from torch.fx.experimental.proxy_tensor import make_fx, proxy_slot
+from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten
+
+
+logger: Optional[logging.Logger] = None
+
+aten = torch.ops.aten
+
+
+class TrainingPhase(Enum):
+    FORWARD = auto()
+    BACKWARD = auto()
+
+
+@dataclass
+class Schema:
+    mesh: DeviceMesh
+    placements: List[Placement]
+
+
+@dataclass
+class DSymInt:
+    """DSymInt represents a value retrieved by a SymInt op from a DTensor.
+
+    DSymInt helps View and Factory ops to determine the placement and shape of the
+    output tensor, as those operators either do not have an input DTensor or
+    the input DTensor is insufficient to determine the output tensor's placement.
+    """
+
+    global_value: int  # value that the SymInt evaluates to
+    local_value: int  # value that this SymInt evaluates to on the local shard
+    mesh: DeviceMesh  # device mesh of the DTensor where this SymInt is retrieved from
+
+    def is_shard(self) -> bool:
+        return self.local_value != self.global_value
+
+    @classmethod
+    def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
+        dim: int = 0
+        if node.target == aten.sym_size:
+            dim = cast(int, node.args[1])
+            return cls(
+                global_value=dtensor.size(dim),
+                local_value=dtensor.to_local().size(dim),
+                mesh=dtensor.device_mesh,
+            )
+        elif node.target == aten.sym_numel:
+            return cls(
+                global_value=dtensor.numel(),
+                local_value=dtensor.to_local().numel(),
+                mesh=dtensor.device_mesh,
+            )
+        elif node.target == aten.sym_stride:
+            dim = cast(int, node.args[1])
+            return cls(
+                global_value=dtensor.stride(dim),
+                local_value=dtensor.to_local().stride(dim),
+                mesh=dtensor.device_mesh,
+            )
+        else:
+            raise NotImplementedError(f"DSymInt does not support {node.target}")
+
+
+def _is_partial_dtensor(obj: Any) -> bool:
+    """Check if object is 1) DTensor and  2) with any placement of _Partial."""
+    if not isinstance(obj, DTensor):
+        return False
+
+    is_partial = False
+    for placement in obj.placements:
+        if isinstance(placement, _Partial):
+            is_partial = True
+            break
+
+    return is_partial
+
+
+def _dispatch_with_local_tensors(
+    op: torch._ops.OpOverload,
+    local_args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    specs: Optional[
+        Dict[
+            torch.Tensor,
+            Tuple[torch.Size, DeviceMesh, Sequence[Placement], Sequence[Placement]],
+        ]
+    ] = None,
+) -> Any:
+    if kwargs is None:
+        kwargs = {}
+    if specs is None:
+        specs = {}
+
+    def redistribute(arg: Any) -> Any:
+        tensor_shape, mesh, current_placement, target_placement = specs[arg]
+        tensor_meta = TensorMeta(
+            tensor_shape,
+            stride=arg.stride(),
+            dtype=arg.dtype,
+        )
+        current_spec = DTensorSpec(
+            mesh, tuple(current_placement), tensor_meta=tensor_meta
+        )
+        target_spec = DTensorSpec(
+            mesh, tuple(target_placement), tensor_meta=tensor_meta
+        )
+
+        return (
+            redistribute_local_tensor(arg, current_spec, target_spec)  # type: ignore[index]
+            if isinstance(arg, torch.Tensor) and arg in specs  # type: ignore[operator]
+            else arg
+        )
+
+    # TODO: this is broken because it won't redistribute potential tensors in the kwargs
+    return op(*tree_map(redistribute, local_args), **kwargs)
+
+
+# Figure out how to specify a type spec for the return specs value
+# without the entire structure.
+# pyre-fixme
+def _update_specs_for_redistribute(args, target_schema, redistribute):
+    # Code adapted from pack_args_kwargs_with_local_tensor
+    flatten_args, args_tree_spec = tree_flatten(args)
+    flatten_args_schema = pytree.tree_leaves(target_schema.args_schema)
+
+    specs: Dict[
+        torch.Tensor,
+        Tuple[
+            torch.Size,
+            DeviceMesh,
+            Sequence[Placement],
+            Sequence[Placement],
+        ],
+    ] = {}
+    for i, arg in enumerate(flatten_args):
+        if isinstance(arg, DTensor):
+            if redistribute:
+                specs[arg._local_tensor] = (
+                    arg.size(),
+                    flatten_args_schema[i].mesh,
+                    arg.placements,
+                    flatten_args_schema[i].placements,
+                )
+            flatten_args_schema[i] = arg._local_tensor
+
+    unflattened_args = tree_unflatten(flatten_args_schema, args_tree_spec)
+    return specs, unflattened_args
+
+
+# When no tensor redistribution is required, we only need to update non-tensor args
+# of the node according to op_schema and avoid building a GraphModule just for the
+# node.
+def _update_node_from_op_schema(node: torch.fx.Node, op_schema: OpSchema) -> None:
+    flat_args, args_tree_spec = tree_flatten(node.args)
+    flat_args_schema = pytree.tree_leaves(op_schema.args_schema)
+
+    def is_sym_int_or_int(arg: Union[int, torch.fx.Node]) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            return arg.target in [
+                aten.sym_size,
+                aten.sym_numel,
+                aten.sym_stride,
+            ]
+        return isinstance(arg, int)
+
+    assert len(flat_args) == len(flat_args_schema)
+    for i, (arg, arg_schema) in enumerate(zip(flat_args, flat_args_schema)):
+        if is_sym_int_or_int(arg) and isinstance(arg_schema, int):
+            flat_args[i] = arg_schema
+
+    args = tree_unflatten(flat_args, args_tree_spec)
+    for idx, arg in enumerate(args):
+        node.update_arg(idx, arg)
+    return None
+
+
+def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
+    if isinstance(arg, torch.fx.Node):
+        obj = node_to_obj[arg]
+        if _get_tracer():
+            # This is a shared arg, already has a tracer from previous
+            # tracing. Delete the tracer.
+            del cast(Dict[Any, Any], obj.__dict__)[proxy_slot]
+        return obj
+    else:
+        return arg
+
+
+def unpack_sizes_and_dims(
+    sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
+) -> Tuple[List[int], List[Placement]]:
+    local_sizes: List[int] = [
+        s.local_value if isinstance(s, DSymInt) else s for s in sizes
+    ]
+    placements: List[Placement] = [
+        Shard(i)
+        for i, a in enumerate(sizes)
+        if (isinstance(a, DSymInt) and a.is_shard())
+    ] or [Replicate()]
+
+    assert len(placements) == mesh.ndim, (
+        f"The number of sharded dimensions ({len(placements)}) must "
+        f"match number of dimensions in device mesh ({mesh.ndim})."
+    )
+
+    return local_sizes, placements
+
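+# An illustrative sketch (not part of the original module) of what
+# unpack_sizes_and_dims above produces. Assume a 1-D mesh and a size list
+# whose first entry, ``dsym_batch``, is a hypothetical DSymInt sharded on
+# mesh dim 0 with local value 4:
+#
+#   local_sizes, placements = unpack_sizes_and_dims([dsym_batch, 16], mesh)
+#   # local_sizes == [4, 16]
+#   # placements  == [Shard(0)]   (falls back to [Replicate()] if nothing is sharded)
+#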
+
+def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
+    assert len(args) == 2, f"Expect two args but got op {node.target} with args {args}"
+    assert isinstance(
+        args[0], DTensor
+    ), f"Expect 1st argument to be DTensor but got {args[0]}"
+    assert isinstance(args[1], list), f"Expect 2nd argument as list but got {args[1]}"
+
+    # extract sharded dimensions in the size list, the output DTensor should
+    # follow these placements.
+    local_sizes, placements = unpack_sizes_and_dims(args[1], args[0].device_mesh)
+
+    # set node args to real int sizes.
+    node.args = (node.args[0], local_sizes)
+    op = cast(torch._ops.OpOverload, node.target)
+    return DTensor.from_local(
+        local_tensor=op(args[0]._local_tensor, local_sizes),
+        device_mesh=args[0].device_mesh,
+        placements=placements,
+        run_check=False,
+    )
+
+
+def slice_backward_sym_int_consumer_rule(
+    node: fx.Node, args: Tuple[Any, ...]
+) -> DTensor:
+    grad_output, input_sizes, dim, start, end, step = args
+
+    local_sizes: List[int] = [
+        s.local_value if isinstance(s, DSymInt) else s for s in input_sizes
+    ]
+
+    input_tensor = torch.zeros(
+        local_sizes, device=grad_output.device, dtype=grad_output.dtype
+    )
+    return DTensor.from_local(
+        local_tensor=torch.slice_scatter(
+            input_tensor, grad_output.to_local(), dim, start, end, step
+        ),
+        device_mesh=grad_output.device_mesh,
+        placements=grad_output.placements,
+        run_check=False,
+    )
+
+
+def factory_with_sizes_rule(
+    node: fx.Node,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_mesh: DeviceMesh,
+) -> DTensor:
+    flat_args = pytree.arg_tree_leaves(*args)
+    assert not any(isinstance(a, DTensor) for a in flat_args), (
+        f"Not expect DTensor argument for factory op, but got {node.target} "
+        f"with arguments {args}."
+    )
+    assert isinstance(args[0], list), f"Expect 2nd argument as list but got {args[1]}"
+
+    local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
+    node.args = (local_sizes, *args[1:])
+    op = cast(torch._ops.OpOverload, node.target)
+    return DTensor.from_local(
+        local_tensor=op(*node.args, **kwargs),
+        device_mesh=default_mesh,
+        placements=placements,
+        run_check=False,
+    )
+
+
+def factory_arange_rule(
+    node: fx.Node,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_mesh: DeviceMesh,
+) -> DTensor:
+    node.args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
+    op = cast(torch._ops.OpOverload, node.target)
+    return DTensor.from_local(
+        local_tensor=op(*node.args, **kwargs),
+        device_mesh=default_mesh,
+        placements=[Replicate()],
+        run_check=False,
+    )
+
+
+def default_factory_op_rule(
+    node: fx.Node,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_mesh: DeviceMesh,
+) -> DTensor:
+    node.args, node.kwargs = args, kwargs
+    op = cast(torch._ops.OpOverload, node.target)
+    return DTensor.from_local(
+        local_tensor=op(*node.args, **node.kwargs),
+        device_mesh=default_mesh,
+        placements=[Replicate()],
+        run_check=False,
+    )
+
+
+# Dispatch override for view and factory ops that consume SymInt arguments,
+# where the output spec should follow dimension placement where the SymInt comes
+# from.
+VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
+    aten._unsafe_view.default: binop_sym_int_consumer_rule,
+    aten.expand.default: binop_sym_int_consumer_rule,
+    aten.slice_backward.default: slice_backward_sym_int_consumer_rule,
+    aten.view.default: binop_sym_int_consumer_rule,
+}
+
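+# An illustrative sketch (not part of the original module): with a hypothetical
+# DSymInt ``dsym_batch`` sharded on mesh dim 0, a traced call such as
+#
+#   aten.view.default(dtensor_input, [dsym_batch, 16])
+#
+# is routed to binop_sym_int_consumer_rule, which rewrites the size list to
+# local values and places the output DTensor as [Shard(0)], i.e. the output
+# follows the dimension the sharded SymInt came from.
+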
+FACTORY_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
+    aten.full.default: factory_with_sizes_rule,
+    aten.arange.default: factory_arange_rule,
+    aten.arange.start: factory_arange_rule,
+}
+
+
+# Dispatch override for factory ops, as DTensor cannot propagate sharding spec
+# without DTensor inputs.
+FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
+    aten.scalar_tensor.default: default_factory_op_rule,
+    aten.arange.start: default_factory_op_rule,
+    aten.zeros.default: default_factory_op_rule,
+}
+
+
+def _get_dtensor_dispatch_graph(
+    node: fx.Node,
+    node_to_obj: Dict[fx.Node, Any],
+    *,
+    force_make_fx: bool = False,
+    default_mesh: Optional[DeviceMesh] = None,
+) -> Optional[fx.GraphModule]:
+    with torch.no_grad():
+        # Args should be a list of objects post remapping.
+        args = tree_map(partial(_remap_arg, node_to_obj), node.args)
+        kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)
+
+        op_overload = cast(torch._ops.OpOverload, node.target)
+
+        if any(
+            a.is_shard()
+            for a in pytree.arg_tree_leaves(*args)
+            if isinstance(a, DSymInt)
+        ):
+            if op_overload in VIEW_SYM_INT_CONSUMERS:
+                assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
+                node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
+                return None
+            elif op_overload in FACTORY_SYM_INT_CONSUMERS:
+                assert default_mesh is not None, "Requires default mesh for factory ops"
+                node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
+                    node, args, kwargs, default_mesh
+                )
+                return None
+            else:
+                assert isinstance(logger, logging.Logger)
+                logger.warning(
+                    "Assuming using local_value from SymInt for %s"
+                    "is mathematically correct. Full args are %s.",
+                    op_overload,
+                    args,
+                )
+
+        if node.target == aten.view.default:
+            # HACK: get around the fact that some view operations on a
+            # "global" tensor are invalid usage, but the view operation on
+            # the batch input might still hit them, so we convert the view
+            # op to reshape before calling DTensor.
+            op_overload = aten.reshape.default
+
+        # DSymInt args are not sharded on any dimension, local value and global
+        # value should be the same
+        args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
+        kwargs = tree_map(
+            lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
+        )
+
+        if op_overload in FACTORY_OPS:
+            # Don't pass factory ops to DTensor dispatch, as DTensor cannot
+            # propagate sharding spec without DTensor inputs.
+            node_to_obj[node] = FACTORY_OPS[op_overload](
+                node, args, kwargs, default_mesh
+            )
+            return None
+
+        dispatch = partial(
+            _dispatch_with_local_tensors,
+            op_overload,
+            kwargs=kwargs,
+            specs=args,
+        )
+
+        gm = make_fx(dispatch, _allow_non_fake_inputs=False)(args)
+        # FIXME(@wanchaol, @mrshenli): the above seems to accidentally capture
+        # DeviceMesh tensor ops when handling inplace operators? The ``_to_copy``
+        # is not connected to the graph output, so we use DCE to get rid of it,
+        # but this doesn't look correct.
+        #
+        # The following operators appear in the captured graph, where the dtype is
+        # torch.int64.
+        #
+        # get_attr       _tensor_constant0  _tensor_constant0         ()
+        # call_function  transpose          aten.transpose.int        (_tensor_constant0, -1, 0)
+        # call_function  view               aten.view.default         (transpose, [-1, 2])
+        # call_function  view_1             aten.view.default         (view, [2])
+        # call_function  _to_copy           aten._to_copy.default     (view_1,)
+        gm.graph.eliminate_dead_code()
+
+        return gm
+
+
+def _build_dummy_add_graph(
+    dt: DTensor, node_to_obj: Dict[fx.Node, Any]
+) -> Tuple[fx.GraphModule, Any]:
+    """Create a graph for a dummy add function from a partial DTensor.
+
+    This dummy add is used for triggering all_reduce on a Partial DTensor
+    during the DTensor expansion of the traced graph.
+    Also returns the actual DTensor after resharding.
+    """
+
+    def dummy_add(grad: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
+        return grad + zero
+
+    grad: torch.Tensor = dt._local_tensor
+    zero: torch.Tensor = torch.zeros_like(dt._local_tensor)
+
+    traced_add = make_fx(dummy_add)(grad, zero)
+
+    placeholders = [n for n in traced_add.graph.nodes if n.op == OP.PLACEHOLDER]
+    call_functions = [n for n in traced_add.graph.nodes if n.op == OP.CALL_FUNCTION]
+    assert len(placeholders) == 2
+    assert len(call_functions) == 1
+    node_to_obj[placeholders[0]] = dt
+    node_to_obj[placeholders[1]] = DTensor.from_local(
+        zero, dt.device_mesh, [Replicate()], run_check=False
+    )
+
+    traced_dispatch = _get_dtensor_dispatch_graph(
+        call_functions[0], node_to_obj, force_make_fx=True
+    )
+    assert traced_dispatch is not None
+
+    # TODO(anj): This depends on the call function node -> actual DTensor output
+    # mapping that we want to avoid for SPMD expansion
+    return traced_dispatch, node_to_obj[call_functions[0]]
+
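+# An illustrative sketch (not part of the original module) of the dummy-add
+# trick above: adding a replicated zeros tensor to a partial DTensor forces a
+# reduction of the partial operand, so the traced subgraph contains the
+# collective and its wait node (roughly):
+#
+#   partial_dt + zeros_like(partial_dt)        # Partial + Replicate
+#   -> all_reduce(local_grad); wait_tensor(...); add(...)
+#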
+
+def _convert_output(
+    gm: fx.GraphModule,
+    node: fx.Node,
+    node_to_obj: Dict[fx.Node, Any],
+) -> fx.Node:
+    new_args = []
+    has_partial = False
+    for argument in node.args[0]:  # type: ignore[union-attr]
+        if not isinstance(argument, fx.Node):
+            new_args.append(argument)
+            continue
+
+        obj = node_to_obj[argument]
+
+        if not _is_partial_dtensor(obj):
+            new_args.append(argument)
+            continue
+
+        has_partial = True
+
+        # we know it's a DTensor from the is-partial DTensor check above
+        dt = cast(DTensor, obj)
+
+        traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)
+
+        wait = [
+            n
+            for n in traced_dispatch.graph.nodes
+            if n.name == "wait_comm" or n.name == "wait_tensor"
+        ]
+        add = [n for n in traced_dispatch.graph.nodes if n.name == "add"]
+        assert len(wait) == 1 and len(add) == 1
+
+        # remove add node and replace it with wait node
+        add[0].replace_all_uses_with(wait[0])
+        traced_dispatch.graph.eliminate_dead_code()
+        # also update the actual DTensor corresponding to the node
+        # TODO(anj): We require mapping of the final DTensor output to the wait
+        # comm node.
+        node_to_obj[wait[0]] = result_obj
+
+        value_remap: Dict[fx.Node, fx.Node] = {}
+        for dtn in traced_dispatch.graph.nodes:
+            if dtn.op == OP.PLACEHOLDER:
+                # map the placeholders of the dummy-add subgraph back to the
+                # original argument node in the outer graph
+                value_remap[dtn] = argument
+            elif dtn.op == OP.OUTPUT:
+                assert (
+                    len(dtn.args) == 1 and len(dtn.args[0]) == 1
+                ), f"Expecting single output, but got {dtn.args} {len(dtn.args)}"
+                new_args.append(value_remap[dtn.args[0][0]])
+                # the concrete DTensor value of output was added when creating the
+                # inner graph (in _build_dummy_add_graph). Just add it to the final
+                # output node so that we can report the final output specs correctly.
+                # TODO(anj): We are depending on the concrete DTensor output of the dummy add.
+                node_to_obj[value_remap[dtn.args[0][0]]] = node_to_obj[dtn.args[0][0]]
+
+            else:
+                if dtn.op == OP.GET_ATTR:
+                    setattr(
+                        gm,
+                        dtn.target,
+                        getattr(traced_dispatch, dtn.target),
+                    )
+                with gm.graph.inserting_before(node):
+                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
+    if has_partial:
+        gm.graph.erase_node(node)
+        return gm.graph.output(new_args)
+    else:
+        return node
+
+
+def _rebuild_graph(
+    gm: fx.GraphModule,
+    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule],
+) -> None:
+    # replace nodes in local traced graph with DTensor's dispatch graph
+    for node in gm.graph.nodes:
+        if node not in node_replacements:
+            continue
+
+        traced_dispatch = node_replacements[node]
+        # Map DT's dispatch graph input placeholder nodes to the ones in
+        # the local traced graph. It uses index-based accessing, which is
+        # brittle and intended for testing purposes only.
+        flatten_args = pytree.arg_tree_leaves(*node.args)
+        i, value_remap = 0, {}
+        for dtn in traced_dispatch.graph.nodes:
+            if dtn.op == OP.PLACEHOLDER:
+                value_remap[dtn] = flatten_args[i]
+                i += 1
+
+        # insert DT's dispatch graph to traced local graph.
+        with gm.graph.inserting_before(node):
+            for dtn in traced_dispatch.graph.nodes:
+                if dtn.op == OP.PLACEHOLDER:
+                    # do nothing, ignore placeholders, as they have already
+                    # been prepared in value_remap
+                    pass
+                elif dtn.op == OP.OUTPUT:
+                    assert (
+                        len(dtn.args) == 1
+                    ), f"Expecting single output, but got {dtn.args} {len(dtn.args[0])}"
+                    outputs = dtn.args[0]
+                    # we currently support two very specific types of output
+                    # 1. single output
+                    # 2. multiple outputs resulting from getitem of all elements of tuple
+                    if len(outputs) == 1:
+                        # for single output, we replace the node with the single node
+                        output = outputs[0]
+                    else:
+                        # for multiple outputs, we check that these outputs correspond
+                        # to all elements of a tuple. In that case, we replace
+                        # uses of the output directly with the original tuple
+                        source = None
+                        for i, out in enumerate(outputs):
+                            # we allow None outputs for certain items in the tuple
+                            if out is None:
+                                continue
+                            assert out.op == "call_function"
+                            assert out.target.__module__ == "_operator"
+                            assert out.target.__name__ == "getitem"
+                            assert source is None or source == out.args[0]
+                            source = out.args[0]
+                            assert out.args[1] == i
+                        assert source is not None
+                        output = source
+
+                    new_node = value_remap[output]
+                    node.replace_all_uses_with(new_node)
+                else:
+                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
+                    if all(
+                        isinstance(n.target, torch._ops.OpOverload)
+                        and n.target._schema.name.startswith(
+                            ("aten::_foreach", "aten::_fused_adam")
+                        )
+                        for n in [dtn, node]
+                    ):
+                        # FIXME(@mrshenli): This is a temporary solution to enable
+                        # foreach ops. The problem is that foreach ops return
+                        # List[Tensor], but make_fx will flatten that before
+                        # passing those tensors to the output node, which will
+                        # introduce additional getitem nodes. These redundant
+                        # getitem nodes break graph correctness as we cannot do
+                        # getitem(getitem(foreach_out, 0), 0). This temporary
+                        # solution skips getitem nodes in DTensor-expanded
+                        # subgraphs.
+                        node.replace_all_uses_with(value_remap[dtn])
+                        break
+            # explicitly erase node instead of relying on DCE, as DCE does not
+            # remove inplace copy_ correctly.
+            gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+
+
+def _get_last_consumer_to_nodes(
+    graph: fx.Graph,
+) -> Dict[fx.Node, List[fx.Node]]:
+    # Run through reverse nodes and record the first instance of a use
+    # of a given node. This represents the *last* use of the node in the
+    # execution order of the program, which we will use to free unused
+    # values
+    node_to_last_consumer: Dict[fx.Node, fx.Node] = {}
+    last_consumer_to_nodes: Dict[fx.Node, List[fx.Node]] = {}
+
+    def _register_final_consumer(arg_node: fx.Node, consumer: fx.Node) -> None:
+        if arg_node not in node_to_last_consumer:
+            node_to_last_consumer[arg_node] = consumer
+            last_consumer_to_nodes.setdefault(consumer, []).append(arg_node)
+
+    for node in reversed(graph.nodes):
+        fx.node.map_arg(
+            node.args, lambda arg_node: _register_final_consumer(arg_node, node)
+        )
+        fx.node.map_arg(
+            node.kwargs,
+            lambda kwarg_node: _register_final_consumer(kwarg_node, node),
+        )
+
+    return last_consumer_to_nodes
+
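+# An illustrative sketch (not part of the original module) of the mapping
+# returned above. For a graph where ``b = f(a)`` and ``c = g(b, a)`` and ``c``
+# is the last user of both values, the result is
+#
+#   {c: [b, a]}
+#
+# so the expansion loop below can drop the cached objects for ``b`` and ``a``
+# right after ``c`` has been processed.
+#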
+
+def _convert_to_distributed(
+    gm: fx.GraphModule,
+    inps: List[torch.Tensor],
+    schemas: List[Schema],
+    default_mesh: Optional[DeviceMesh] = None,
+    _allow_partial: bool = False,
+) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
+    """Transform a graph module to a distributed graph module.
+
+    Returns:
+        - transformed graph module
+        - map from output name to the output's Schema (mesh and placements)
+
+    """
+    global logger
+    logger = get_logger("spmd_exp")
+    operators = {getattr(operator, name) for name in operator.__all__}
+    node_to_obj: Dict[fx.Node, Any] = {}
+    # map local op node in traced_f to its corresponding subgraph of
+    # DTensor ops.
+    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule] = {}
+
+    last_consumer_to_nodes = _get_last_consumer_to_nodes(gm.graph)
+
+    output_schemas: Dict[str, Schema] = {}
+    for i, node in enumerate(gm.graph.nodes):
+        assert logger is not None
+        logger.info("node%s: op=%s target=%s", i, node.op, node.target)
+        if node.op == OP.PLACEHOLDER:
+            assert i < len(
+                inps
+            ), f"got more placeholder nodes ({i + 1}) than inputs ({len(inps)})"
+
+            # our example inputs are local shards. Create DTensors from them.
+            node_to_obj[node] = DTensor.from_local(
+                inps[i].clone(),  # use clone to avoid modifications from inplace ops
+                schemas[i].mesh,
+                schemas[i].placements,
+                # prevent running this collective in backwards pass
+                run_check=False,
+            )
+        elif isinstance(node.target, torch._ops.OpOverloadPacket):
+            dtensor = cast(DTensor, node_to_obj[node.args[0]])
+            node_to_obj[node] = DSymInt.from_node(node, dtensor)
+        elif isinstance(node.target, torch._ops.OpOverload):
+            replacement = _get_dtensor_dispatch_graph(
+                node, node_to_obj, default_mesh=default_mesh
+            )
+            if replacement is not None:
+                node_replacements[node] = replacement
+        elif node.op == OP.OUTPUT:
+            if not _allow_partial:
+                # Returns an expanded dummy add node that ensures
+                # that the partial output tensor has been converted
+                # to a replicated tensor.
+                node = _convert_output(gm, node, node_to_obj)
+
+            # Save output sharding for the inputs to backward pass.
+            # TODO(anj): Pipe the output schema for the BW pass
+            # instead of requiring the full output DTensor to be
+            # materialized.
+            for inp_arg in node.args[0]:
+                if isinstance(inp_arg, fx.Node):
+                    obj = node_to_obj[inp_arg]
+                    if isinstance(obj, DTensor):
+                        output_schemas[inp_arg.name] = Schema(
+                            obj.device_mesh, obj.placements  # type: ignore[arg-type]
+                        )
+        elif node.op == OP.CALL_FUNCTION:
+            args = tree_map(partial(_remap_arg, node_to_obj), node.args)
+            kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)
+
+            dsymints = list(
+                filter(lambda a: isinstance(a, DSymInt), args + tuple(kwargs.values()))
+            )
+
+            if node.target in operators and len(dsymints) > 0:
+                assert all(
+                    dsymints[0].mesh == d.mesh for d in dsymints
+                ), "all DSymInts must have the same mesh. "
+
+                local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
+                local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)
+
+                global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
+                global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)
+
+                node.args = local_args
+                node.kwargs = local_kwargs
+
+                node_to_obj[node] = DSymInt(
+                    local_value=node.target(*local_args, **local_kwargs),
+                    global_value=node.target(*global_args, **global_kwargs),
+                    mesh=dsymints[0].mesh,
+                )
+            else:
+                assert len(dsymints) == 0, (
+                    "SPMD expansion does not support SymInt in non-operator "
+                    f"nodes, got {node.target}."
+                )
+                node_to_obj[node] = node.target(*args, **kwargs)
+        else:
+            raise ValueError(f"Unrecognized node.op type {node.op}")
+
+        if node in last_consumer_to_nodes:
+            # Save memory by deleting objs that won't be used anymore.
+            for arg_node in last_consumer_to_nodes[node]:
+                del node_to_obj[arg_node]
+
+    _rebuild_graph(gm, node_replacements)
+
+    return gm, output_schemas
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/experimental_ops.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/experimental_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f42c52e2bf88a0d56feec928a4b157806bb480c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/experimental_ops.py
@@ -0,0 +1,455 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import cast, List, Optional, Sequence, Tuple
+
+import torch
+from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops.utils import register_prop_rule
+
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+    TensorMeta,
+)
+
+aten = torch.ops.aten  # pyre-ignore
+
+
+@register_prop_rule(  # pyre-ignore
+    [
+        aten._foreach_neg.default,
+        aten._foreach_reciprocal.default,
+        aten._foreach_sqrt.default,
+    ]
+)
+def _prop__foreach_unaop(op_schema: OpSchema) -> OutputSharding:
+    self = op_schema.args_schema[0]
+    assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
+    # FIXME(@mrshenli): for sqrt, this is only mathematically correct for
+    # Replicate and Shard tensor.
+    return OutputSharding(output_spec=self)
+
+
+@register_prop_rule(  # pyre-ignore
+    [
+        aten._foreach_add.List,
+        aten._foreach_div.List,
+        aten._foreach_mul.List,
+    ]
+)
+def _prop__foreach_binop_list(op_schema: OpSchema) -> OutputSharding:
+    self, other = op_schema.args_schema[:2]
+    scalar = None if len(op_schema.args_schema) < 3 else op_schema.args_schema[2]
+    assert isinstance(self, list) and all(
+        isinstance(s, DTensorSpec) for s in self
+    ), f"Expect a List[DTensorSpec] but got {self}"
+    assert isinstance(other, list) and all(
+        isinstance(o, DTensorSpec) for o in other
+    ), f"Expect a List[DTensorSpec] but got {other}"
+    assert len(self) == len(other), (
+        "Two tensor lists must match in length, "
+        f"but got {len(self)} and {len(other)}"
+    )
+
+    if any(s != o for s, o in zip(self, other)):
+        # If the DTensorSpecs for the two operands do not match, suggest using
+        # self's DTensorSpec. This will trigger an allreduce if other is partial
+        # and self is replicated.
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(self, self, scalar) if scalar else (self, self),
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
+    else:
+        return OutputSharding(output_spec=self)
+
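+# An illustrative sketch (not part of the original module) of the suggestion
+# mechanism used above: when the two spec lists disagree, the rule returns no
+# output_spec and instead suggests re-running the op with ``other`` forced to
+# follow ``self``'s specs, e.g.
+#
+#   OutputSharding(
+#       output_spec=None,
+#       schema_suggestions=[OpSchema(op, args_schema=(self, self), kwargs_schema=kwargs)],
+#   )
+#
+# which makes DTensor redistribute (e.g. allreduce a partial gradient) before
+# retrying the op.
+#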
+
+@register_prop_rule(  # pyre-ignore
+    [
+        aten._foreach_add.Scalar,
+        aten._foreach_div.Scalar,
+        aten._foreach_mul.Scalar,
+        aten._foreach_sub.Scalar,
+    ]
+)
+def _prop__foreach_binop_scalar(op_schema: OpSchema) -> OutputSharding:
+    self, scalar = op_schema.args_schema
+    assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
+    assert not isinstance(scalar, list)
+    return OutputSharding(output_spec=self)
+
+
+@register_prop_rule(  # pyre-ignore
+    [
+        aten._foreach_addcdiv.Scalar,
+        aten._foreach_addcmul.Scalar,
+    ]
+)
+def _prop__foreach_addcop_scalar(op_schema: OpSchema):
+    self, tensor1, tensor2 = op_schema.args_schema[:3]
+    scalar = None if len(op_schema.args_schema) < 4 else op_schema.args_schema[3]
+    assert isinstance(self, list) and all(isinstance(s, DTensorSpec) for s in self)
+    assert isinstance(tensor1, list) and all(isinstance(s, DTensorSpec) for s in tensor1)
+    assert isinstance(tensor2, list) and all(isinstance(s, DTensorSpec) for s in tensor2)
+    if any(s != t1 or s != t2 for s, t1, t2 in zip(self, tensor1, tensor2)):
+        # If the DTensorSpecs for the operands do not match, suggest using
+        # self's DTensorSpec. This will trigger an allreduce if an operand is
+        # partial and self is replicated.
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(self, self, self, scalar)
+                    if scalar
+                    else (self, self, self),
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
+    else:
+        return OutputSharding(output_spec=self)
+
+
+@register_prop_rule([aten._foreach_pow.ScalarAndTensor])  # pyre-ignore
+def _prop__foreach_pow_scalar_and_tensor(op_schema: OpSchema):
+    scalar, exponent = op_schema.args_schema
+    assert isinstance(exponent, list) and all(
+        isinstance(s, DTensorSpec) for s in exponent
+    )
+    return OutputSharding(output_spec=exponent)
+
+
+@register_prop_rule([aten._fused_adam.default])  # pyre-ignore
+def _prop__fused_adam(op_schema: OpSchema):
+    NT = 5
+    tensor_list_args: Tuple[List[DTensorSpec]] = op_schema.args_schema[:NT]  # type: ignore[assignment]
+
+    assert all(isinstance(schema, list) for schema in tensor_list_args)
+    assert all(
+        isinstance(s, DTensorSpec) for schema in tensor_list_args for s in schema
+    )
+
+    tensor_schemas: Tuple[List[DTensorSpec]] = [  # type: ignore[assignment]
+        schema for schema in tensor_list_args if len(schema)
+    ]
+
+    assert all(len(s) == len(tensor_schemas[0]) for s in tensor_schemas), (
+        "expect the same number of gradients and states, but got "
+        f"{[len(s) for s in tensor_schemas]}."
+    )
+
+    if any(any(t != ts[0] for t in ts) for ts in zip(*tensor_schemas)):
+        new_schemas: Tuple[List[DTensorSpec]] = tuple(  # type: ignore[assignment]
+            op_schema.args_schema[0] if len(s) else s for s in tensor_list_args
+        )
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=new_schemas + op_schema.args_schema[NT:],
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
+    else:
+        return OutputSharding(output_spec=(op_schema.args_schema[0],) * NT)  # type: ignore[arg-type]
+
+
+@register_prop_rule(aten.nll_loss_forward.default)  # pyre-ignore
+def _prop_nll_loss_forward(op_schema: OpSchema) -> OutputSharding:
+    self, target = op_schema.args_schema[:2]
+    assert isinstance(self, DTensorSpec)
+    assert isinstance(target, DTensorSpec)
+    if self.placements != target.placements:
+        # Self and target must match in placements, which should be sharded along
+        # the batch dimension in data parallel use cases. Force redistribute.
+
+        # need to create a new self instead of returning (target, target), as
+        # target and self might not match in shape.
+        new_self = DTensorSpec(
+            mesh=self.mesh,
+            placements=target.placements,
+            tensor_meta=self.tensor_meta,
+        )
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(new_self, target) + op_schema.args_schema[2:],
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
+    else:
+        return OutputSharding(
+            output_spec=(
+                # by default, nll_loss_forward conducts a reduction and returns
+                # a scalar tensor, and hence the _Partial placements.
+                DTensorSpec(mesh=self.mesh, placements=(_Partial(),)),
+                # the 2nd output total_weight is always a scalar tensor
+                DTensorSpec(mesh=self.mesh, placements=(Replicate(),)),
+            )
+        )
+
+
+@register_prop_rule(aten.nll_loss_backward.default)  # pyre-ignore
+def _prop_nll_loss_backward(op_schema: OpSchema) -> OutputSharding:
+    grad_output, self = op_schema.args_schema[:2]
+    assert isinstance(grad_output, DTensorSpec)
+    assert isinstance(self, DTensorSpec)
+    return OutputSharding(output_spec=self)
+
+
+@register_prop_rule(aten.stack.default)
+def _prop_stack(op_schema: OpSchema) -> OutputSharding:
+    tensors = op_schema.args_schema[0]
+    dim = 0 if len(op_schema.args_schema) == 1 else cast(int, op_schema.args_schema[1])
+    assert (
+        isinstance(tensors, list) and len(tensors) > 0
+    ), "expect at least one tensor to stack"
+    assert all(
+        isinstance(t, DTensorSpec) for t in tensors
+    ), f"expect a list of DTensorSpecs, but got {tensors}"
+    assert all(
+        t.shape == tensors[0].shape for t in tensors
+    ), f"expect all tensors to have the same shape, but got {tensors}."
+    # TODO: provide schema_suggestions when placements do not match
+    assert all(
+        t.placements == tensors[0].placements for t in tensors
+    ), f"expect all tensors to have the same placements, but got {tensors}."
+    assert all(
+        not p.is_shard(dim) for p in tensors[0].placements
+    ), "DTensor does not support stack on sharded dimension."
+
+    return OutputSharding(
+        output_spec=DTensorSpec(mesh=tensors[0].mesh, placements=tensors[0].placements)
+    )
+
+
+@register_prop_rule(aten.select.int)
+def _prop_select(op_schema: OpSchema) -> OutputSharding:
+    tensor, dim = op_schema.args_schema[:2]
+    assert isinstance(tensor, DTensorSpec)
+    assert isinstance(dim, int)
+    placements: Sequence[Placement] = tensor.placements
+    assert all(
+        not p.is_shard(dim) for p in placements
+    ), "DTensor does not support select on sharded dimension."
+
+    # select will remove one dimension, decrement dim of Shard placements by 1
+    # if they are larger than dim.
+    new_placements: List[Placement] = []
+    for p in placements:
+        # Using isinstance instead of is_shard so that mypy won't complain
+        # about accessing dim attribute.
+        if isinstance(p, Shard) and p.dim > dim:
+            new_placements.append(Shard(p.dim - 1))
+        else:
+            new_placements.append(p)
+
+    return OutputSharding(
+        output_spec=DTensorSpec(mesh=tensor.mesh, placements=tuple(new_placements))
+    )
+
+
+@register_prop_rule(aten.native_layer_norm.default)  # pyre-ignore
+def _prop_native_layer_norm(op_schema: OpSchema) -> OutputSharding:
+    input, normalized_shape, weight, bias, eps = op_schema.args_schema
+    assert isinstance(input, DTensorSpec)
+    assert isinstance(normalized_shape, (tuple, list))
+    if weight is not None:
+        assert isinstance(weight, DTensorSpec)
+        assert all(isinstance(p, Replicate) for p in weight.placements)
+    if bias is not None:
+        assert isinstance(bias, DTensorSpec)
+        assert all(isinstance(p, Replicate) for p in bias.placements)
+    # only the left-most (non-normalized) dimensions of the input can be sharded
+    batch_ndim = len(input.shape) - len(normalized_shape)
+    assert all(
+        isinstance(p, Replicate) or (isinstance(p, Shard) and p.dim < batch_ndim)
+        for p in input.placements
+    )
+    stats_spec = DTensorSpec(
+        mesh=input.mesh,
+        placements=input.placements,
+    )
+    return OutputSharding(output_spec=(input, stats_spec, stats_spec))
+
+
+@register_prop_rule(aten.native_layer_norm_backward.default)  # pyre-ignore
+def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
+    (
+        grad,
+        input,
+        normalized_shape,
+        result1,
+        result2,
+        weight,
+        bias,
+        grad_input_mask,
+    ) = op_schema.args_schema
+    assert isinstance(grad, DTensorSpec)
+    assert isinstance(grad_input_mask, (list, tuple))
+    if weight is not None:
+        assert isinstance(weight, DTensorSpec)
+        assert all(isinstance(s, Replicate) for s in weight.placements)
+    if bias is not None:
+        assert isinstance(bias, DTensorSpec)
+        assert all(isinstance(s, Replicate) for s in bias.placements)
+    # ensure sharding on dim 0, which will trigger the "Partial" output on
+    # weight and bias grads
+    assert any(
+        isinstance(s, Shard) and s.dim == 0 for s in grad.placements
+    ), f"Got {grad.placements}"
+    weight_grad = (
+        DTensorSpec(
+            mesh=weight.mesh,
+            placements=tuple([_Partial()] * weight.mesh.ndim),
+        )
+        if weight
+        else None
+    )
+    bias_grad = (
+        DTensorSpec(
+            mesh=bias.mesh,
+            placements=tuple([_Partial()] * bias.mesh.ndim),
+        )
+        if bias
+        else None
+    )
+    return OutputSharding(
+        # NOTE: type errors below are legit. This is because DTensor currently
+        # doesn't support Optional return values. Need to be fixed in DTensor repo.
+        output_spec=(
+            grad if grad_input_mask[0] else None,
+            weight_grad if grad_input_mask[1] else None,
+            bias_grad if grad_input_mask[2] else None,
+        ),
+    )
+
+
+def _refine_sharding(
+    op_schema: OpSchema, active_dim: Optional[int]
+) -> Sequence[Placement]:
+    """Considers 2 first inputs of op_schema as having same shape, and returns suggested placement for a pointwise operation."""
+    # consider the operating dimension as a singleton to prevent sharding on it
+    # however, if active_dim is None, this means the input and output shapes are equal and
+    # we'll apply exactly the pointwise rule.
+
+    args_schema = []
+    for s in op_schema.args_schema[:2]:
+        assert isinstance(s, DTensorSpec) and s.tensor_meta is not None
+        args_schema.append(
+            DTensorSpec(
+                mesh=s.mesh,  # type: ignore[attr-defined]
+                placements=s.placements,  # type: ignore[attr-defined]
+                tensor_meta=TensorMeta(
+                    shape=torch.Size(
+                        s.shape[0:active_dim] + (1,) + s.shape[active_dim + 1 :]
+                    )
+                    if active_dim is not None
+                    else s.shape,
+                    stride=s.tensor_meta.stride,
+                    dtype=s.tensor_meta.dtype,
+                ),
+            )
+        )
+
+    op_schema = OpSchema(
+        op=op_schema.op,
+        args_schema=args_schema,  # type: ignore[arg-type]
+        kwargs_schema={},
+    )
+    output_sharding = pointwise_rule(op_schema, linearity=False)
+    if output_sharding.output_spec:
+        assert isinstance(output_sharding.output_spec, DTensorSpec)
+        return output_sharding.output_spec.placements
+    else:
+        assert output_sharding.schema_suggestions is not None
+        out_schema = output_sharding.schema_suggestions[0].args_schema[0]
+        assert isinstance(out_schema, DTensorSpec)
+        return tuple(out_schema.placements)
+
+
+@register_prop_rule(aten.slice_scatter.default)  # pyre-ignore
+def prop_slice_scatter(op_schema: OpSchema) -> OutputSharding:
+    # 1. number of dimensions in input and src need to match.
+    # 2. number of elements on all non-scatter dimensions needs to match between input and src.
+    # 3. number of elements in src along dim needs to match the slice size.
+    # Given the above:
+    # - We suggest for src to follow the sharding of input, except on the scatter dimension,
+    #   where our best bet for now is to make them replicated as a fall-back.
+    #   TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding.
+
+    defaults = (None, None, 0, None, None, 1)
+    input, src, dim, start, end, step = (
+        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
+    )
+    assert isinstance(input, DTensorSpec)
+    assert isinstance(src, DTensorSpec)
+    assert isinstance(dim, int)
+
+    if dim < 0:
+        dim += input.ndim
+
+    # if the input shape and the output shape are the same on the operating dimension,
+    # this is effectively a no-op, so we just propagate sharding as we would do for
+    # pointwise, no exceptions.
+    if input.shape[dim] == src.shape[dim]:
+        assert start == 0
+        assert end >= src.shape[dim]  # type: ignore[operator]
+        dim = None
+
+    # apply sharding refinement as implemented in pointwise_rule
+    input_suggestion = list(_refine_sharding(op_schema, dim))
+    # apply the exception -- disallow sharding on the operating dimension.
+    for i, p in enumerate(input_suggestion):
+        if isinstance(p, Shard) and p.dim == dim:
+            input_suggestion[i] = Replicate()
+    input_suggestion = tuple(input_suggestion)  # type: ignore[assignment]
+
+    if input_suggestion == tuple(input.placements) and src.placements == tuple(
+        input.placements
+    ):
+        # if our sharding is correct, the output sharding will be the same as the input.
+        return OutputSharding(
+            output_spec=DTensorSpec(
+                mesh=input.mesh,
+                placements=input.placements,
+            )
+        )
+    else:
+        # otherwise, return the suggestion.
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(
+                        DTensorSpec(
+                            mesh=input.mesh,
+                            placements=input_suggestion,
+                            tensor_meta=input.tensor_meta,
+                        ),
+                        DTensorSpec(
+                            mesh=src.mesh,
+                            placements=input_suggestion,
+                            tensor_meta=src.tensor_meta,
+                        ),
+                    )
+                    + op_schema.args_schema[2:],
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/gm_transformation.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/gm_transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..77247616251521f20d0dba36c2137fa55866781d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/gm_transformation.py
@@ -0,0 +1,51 @@
+from typing import Callable
+
+from torch import fx
+from torch.distributed._spmd.graph_optimization import (
+    comm_fusion_with_concat,
+    enable_graph_optimization_dump,
+    remove_copy_from_optimizer,
+    schedule_comm_wait,
+)
+from torch.distributed._spmd.graph_utils import dump_graphs_to_files
+from torch.distributed._spmd.iter_graph_module import IterGraphModule
+
+
+class GraphModuleTransformation:
+    def __init__(
+        self,
+        *,
+        enable_graph_optimization: bool = False,
+        enable_inductor: bool = False,
+        dump_graphs: bool = False,
+    ) -> None:
+        self.enable_graph_optimization = enable_graph_optimization
+        self.enable_inductor = enable_inductor
+        self.dump_graphs = dump_graphs
+
+    def __call__(self, gm: fx.GraphModule) -> Callable:
+        if self.dump_graphs:
+            graph_folder = dump_graphs_to_files(
+                {"before_transformation_gm": gm.print_readable(False)}
+            )
+            enable_graph_optimization_dump(graph_folder)
+
+        iter_gm = IterGraphModule(gm, enable_inductor=self.enable_inductor)
+        if self.enable_graph_optimization:
+            comm_fusion_with_concat(iter_gm, 100)
+            schedule_comm_wait(iter_gm)
+            remove_copy_from_optimizer(iter_gm)
+        # Must be called once we are done moving nodes across the graphs.
+        iter_gm.finalize_setup()
+
+        if self.dump_graphs:
+            dump_graphs_to_files(
+                {
+                    "iter_graph_setup_gm": iter_gm.setup_gm.print_readable(False),
+                    "iter_graph_main_gm": iter_gm.main_gm.print_readable(False),
+                    "iter_graph_cleanup_gm": iter_gm.cleanup_gm.print_readable(False),
+                },
+                graph_folder,  # type: ignore[possibly-undefined]
+            )
+
+        return iter_gm
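+
+
+# An illustrative usage sketch (not part of the original module). The instance
+# is meant to be used as a graph-module transformation callback; ``traced_gm``
+# below is a hypothetical fx.GraphModule produced by an earlier tracing step.
+#
+#   transform = GraphModuleTransformation(
+#       enable_graph_optimization=True,
+#       dump_graphs=False,
+#   )
+#   iter_gm = transform(traced_gm)   # returns an IterGraphModule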
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_optimization.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd96e7cf246f6fbcaa6715894e5e817ab4ce8d46
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_optimization.py
@@ -0,0 +1,986 @@
+# Owner(s): ["oncall: distributed"]
+import collections
+import itertools
+import logging
+import operator
+import tempfile
+import time
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    cast,
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.fx as fx
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.distributed._spmd.graph_utils import (
+    CommType,
+    dump_graphs_to_files,
+    find_node,
+    get_output,
+    OP,
+)
+from torch.distributed._spmd.iter_graph_module import IterGraphModule
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+logger: logging.Logger = logging.getLogger("graph_optimization")
+aten = torch.ops.aten
+fake_tensor_mode = FakeTensorMode()
+
+_optimized_func: Set[str] = set()
+# The key is the target pass and the value is the prerequisites of the pass.
+_prerequisite_sets: DefaultDict[str, Set[str]] = collections.defaultdict(set)
+# The key is the target pass and the value is the passes that must applied before
+# the key.
+_apply_before_sets: DefaultDict[str, Set[str]] = collections.defaultdict(set)
+_dump_graph_folder: str = ""
+
+
+def enable_graph_optimization_dump(folder: str = ""):
+    global _dump_graph_folder
+    if not folder:
+        folder = tempfile.mkdtemp()
+    _dump_graph_folder = folder
+
+
+# TODO(@fegin): Support multiple runs of graph optimization
+# TODO(@fegin): With this design, circular imports will happen when a pass
+# developer accidentally creates a pass dependency cycle. As a result, we need to
+# break this file into finer-grained modules to avoid incorrect circular imports.
+def graph_optimization_pass(
+    prerequisites: Iterable[Callable],
+    apply_after: Iterable[Callable],
+) -> Callable:
+    """Define the contract of a graph optimization pass.
+
+    All the passes should be wrapped with this decorator.
+    `prerequisites` is used to annotate the prerequisite passes of this pass.
+    `apply_after` means that this wrapped pass must be applied after the passes
+    in `apply_after`. The difference between `prerequisites` and `apply_after`
+    is that all the passes in `prerequisites` must be applied to the graph and
+    must be applied before the wrapped pass, while the passes in `apply_after`
+    are optional. But if a pass in `apply_after` is applied to the graph, it has
+    to be done before the wrapped pass.
+    Optimizer pass developers are required to add these fields accordingly and
+    users need to follow the restrictions to avoid the assert.
+
+    The current design has one limitation: users can only apply the optimizations
+    once. In some cases, we may need to run the same optimization multiple
+    times, e.g., optimization passes -> profiling the result -> apply
+    optimization passes with the profiling result again. This limitation will be
+    addressed in the future.
+
+    Args:
+        prerequisites (Iterable[Callable]): the passes that are the
+            prerequisites of this pass.
+        apply_after (Iterable[Callable]): the passes that, if applied to the
+            graph, must be applied before the wrapped pass.
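+
+    Example (an illustrative sketch only; ``my_prerequisite_pass`` and
+    ``my_pass`` are hypothetical passes, not part of this module)::
+
+        @graph_optimization_pass(
+            prerequisites=[my_prerequisite_pass],
+            apply_after=[],
+        )
+        def my_pass(gm: IterGraphModule) -> None:
+            ...  # mutate gm.graph; lint, DCE, and recompile run in the wrapper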
+    """
+
+    def inner(func: Callable) -> Callable:
+        def make_key(func: Callable) -> str:
+            return f"{func.__module__}.{func.__name__}"
+
+        func_key = make_key(func)
+        _prerequisite_sets[func_key] = {make_key(f) for f in prerequisites}
+        for apply_after_pass in apply_after:
+            _apply_before_sets[make_key(apply_after_pass)].add(func_key)
+
+        @wraps(func)
+        def pass_wrapper(
+            gm: Union[fx.GraphModule, IterGraphModule], *args: Any, **kwargs: Any
+        ) -> None:
+            begin = time.time()
+            assert isinstance(gm, (fx.GraphModule, IterGraphModule)), (
+                "The first argument of the pass must be either "
+                "fx.GraphModule or IterGraphModule."
+            )
+            assert func_key not in _optimized_func, f"Cannot apply {func_key} twice."
+            invalid_passes = _apply_before_sets[func_key].intersection(_optimized_func)
+            assert (
+                not invalid_passes
+            ), f"{invalid_passes} must be applied after {func_key}."
+            assert _prerequisite_sets[func_key].issubset(_optimized_func), (
+                f"{_prerequisite_sets[func_key] - _optimized_func} are the "
+                f"prerequisites of {func_key} but are not applified. "
+                f"Applied passes are {_optimized_func}."
+            )
+
+            func(gm, *args, **kwargs)
+            gm.graph.lint()
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+            _optimized_func.add(func_key)
+
+            prefix = f"after_{func.__name__}"
+            if _dump_graph_folder:
+                if isinstance(gm, IterGraphModule):
+                    dump_graphs_to_files(
+                        {
+                            f"{prefix}_setup_gm": gm.setup_gm,
+                            f"{prefix}_main_gm": gm.main_gm,
+                            f"{prefix}_cleanup_gm": gm.cleanup_gm,
+                        },
+                        _dump_graph_folder,
+                    )
+                else:
+                    dump_graphs_to_files({prefix: gm}, _dump_graph_folder)
+
+            logger.info("Spent %f seconds applying %s", time.time() - begin, func_key)
+
+        return pass_wrapper
+
+    return inner
+
+
+@dataclass(unsafe_hash=True)
+class CommBlock:
+    shape: Optional[torch.Size]
+    node_list: List[fx.Node]
+    inputs: List[fx.Node]
+    wait_nodes: List[fx.Node]
+    comm_node: fx.Node
+    outputs: Set[fx.Node]
+
+
+def get_comm_block(comm_node: fx.Node) -> CommBlock:
+    """Find out all the nodes belong to this communcation given a collective node (e.g., allreduce).
+
+    Args:
+        comm_node(fx.Node): The target communication/collective node.
+
+    Returns:
+        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
+        the given comm_node.
+    """
+    # We choose 5 to prevent accidents that could cause an infinite loop. But
+    # with functional collectives, the distance is 1.
+    MAX_WAIT_DISTANCE = 5
+    node_list = []
+    wait_nodes = []
+    inputs = pytree.arg_tree_leaves(*comm_node.args, **comm_node.kwargs)
+    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
+    distance = 0
+    wait_prefixes = ("wait_comm", "wait_tensor")
+    non_end_users_nodes = ("split", "reshape", "getitem", "detach", "alias")
+
+    nodes = collections.deque([comm_node, None])
+    while nodes and distance < MAX_WAIT_DISTANCE:
+        node = nodes.popleft()
+        if node is None:
+            distance += 1
+            if nodes:
+                nodes.append(None)
+            continue
+        node_list.append(node)
+        if node.name.startswith(wait_prefixes):
+            wait_nodes.append(node)
+        else:
+            for child in node.users:
+                if isinstance(child, fx.Node):
+                    nodes.append(child)
+
+    if not wait_nodes:
+        raise RuntimeError(
+            "The wait nodes are too far away from the comm node {comm_node}."
+        )
+
+    # Identify all the outputs of this collective block.
+    outputs: Set[fx.Node] = set()
+    nodes = collections.deque(wait_nodes)
+    while nodes:
+        node = nodes.popleft()
+        assert node is not None
+        for user in node.users:
+            if isinstance(user, fx.Node) and user.name.startswith(non_end_users_nodes):
+                nodes.append(user)
+                node_list.append(user)
+            else:
+                outputs.add(node)
+                break
+
+    # TODO: populate all the tensor metadata and remove the default.
+    tensor_meta = input_nodes[0].meta.get("tensor_meta", None)
+    return CommBlock(
+        # TODO: support symbolic shapes
+        shape=torch.Size(int(s) for s in tensor_meta.shape) if tensor_meta else None,
+        node_list=node_list,
+        wait_nodes=wait_nodes,
+        comm_node=comm_node,
+        inputs=input_nodes,
+        outputs=outputs,
+    )
+
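+# An illustrative sketch (not part of the original module) of the CommBlock a
+# functional allreduce typically yields; node names are indicative only.
+#
+#   allreduce_ -> wait_tensor -> (downstream users...)
+#
+#   CommBlock(
+#       comm_node=allreduce_,
+#       wait_nodes=[wait_tensor],
+#       inputs=[grad_clone],       # tensor input(s) of the collective
+#       outputs={wait_tensor},     # nodes whose users are real consumers
+#       ...
+#   )
+#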
+
+def get_all_comm_blocks(
+    gm: IterGraphModule, comm_ops: Union[Tuple[str, ...], str]
+) -> List[CommBlock]:
+    return [
+        get_comm_block(node)
+        for node in gm.graph.nodes
+        if node.name.startswith(comm_ops)
+    ]
+
+
+def _create_meta_val(
+    fake_tensor_mode: FakeTensorMode,
+    val: FakeTensor,
+) -> FakeTensor:
+    # TODO: fix the memory_format
+    return FakeTensor(
+        fake_tensor_mode,
+        torch.empty(
+            val.shape,
+            dtype=val.dtype,
+            device="meta",
+            requires_grad=val.requires_grad,
+        ),
+        val.device,
+    )
+
+
+def _create_meta_tensor_meta(
+    fake_tensor_mode: FakeTensorMode,
+    val: FakeTensor,
+) -> TensorMetadata:
+    return TensorMetadata(
+        shape=val.shape,
+        dtype=val.dtype,
+        requires_grad=val.requires_grad,
+        stride=val.stride,  # type: ignore[arg-type]
+        # TODO: fix these value
+        memory_format=None,
+        is_quantized=False,
+        qparams={},
+    )
+
+
+def _call_function(
+    gm: IterGraphModule,
+    fake_tensor_mode: FakeTensorMode,
+    meta_val: Optional[FakeTensor],
+    function: Any,
+    *args: Any,
+    **kwargs: Any,
+) -> fx.Node:
+    node = gm.graph.call_function(function, args, kwargs)
+
+    if meta_val is None:
+        flat_args, spec = tree_flatten((args, kwargs))
+        new_flat_args = []
+        memory_format = None
+        for arg in flat_args:
+            if not isinstance(arg, fx.Node):
+                new_flat_args.append(arg)
+                continue
+            val = arg.meta["val"]
+            new_flat_args.append(_create_meta_val(fake_tensor_mode, val))
+
+        fake_args, fake_kwargs = tree_unflatten(new_flat_args, spec)
+        new_meta_val = function(*fake_args, **fake_kwargs)
+    else:
+        new_meta_val = meta_val
+    node.meta["val"] = new_meta_val
+    node.meta["tensor_meta"] = _create_meta_tensor_meta(fake_tensor_mode, new_meta_val)
+    return node
+
+
+def _scatter_wait_result(
+    gm: IterGraphModule,
+    fused_comm_block: CommBlock,
+    comm_blocks: List[CommBlock],
+    node_indices: Dict[fx.Node, int],
+) -> None:
+    """Scatter the result of the fused communication node to the original users -- splitting the output and reshape each subitem."""
+    last_wait_node_idx = 0
+    for node in gm.graph.nodes:
+        if node == fused_comm_block.comm_node:
+            break
+        last_wait_node_idx = max(
+            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
+        )
+
+    fused_comm_node = fused_comm_block.comm_node
+    fused_wait_node = fused_comm_block.wait_nodes[0]
+
+    with gm.graph.inserting_after(fused_wait_node):
+        split_node = gm.graph.call_function(
+            aten.split,
+            (
+                fused_wait_node,
+                # TODO(@fegin): support symbolic shapes
+                [int(cast(torch.Size, cb.shape).numel()) for cb in comm_blocks],
+            ),
+        )
+
+    # Scatter the split result.
+    need_sort_nodes = []
+    last_split_reshape_node = split_node
+    with gm.graph.inserting_after(split_node):
+        for idx, comm_block in enumerate(comm_blocks):
+            # Some users of the original allreduce and wait are scheduled
+            # before the fused allreduce. We must move these users to a
+            # correct topological sort order -- right after the last fused
+            # allreduce result, the `last_split_reshape_node` variable.
+            orig_wait = comm_block.wait_nodes[0]
+            nodes = collections.deque(list(orig_wait.users))
+            while nodes:
+                user_node = nodes.popleft()
+                if not isinstance(user_node, fx.Node):
+                    continue
+                if node_indices[user_node] < last_wait_node_idx:
+                    need_sort_nodes.append(user_node)
+                    nodes.extend(list(user_node.users))
+
+            split_idx_node = gm.graph.call_function(operator.getitem, (split_node, idx))
+            with gm.graph.inserting_after(split_idx_node):
+                wait_output_node = gm.graph.call_function(
+                    aten.reshape, (split_idx_node, comm_block.shape)
+                )
+            gm.graph.node_replace_all_uses_with(orig_wait, wait_output_node)
+
+        if last_split_reshape_node == split_node:
+            last_split_reshape_node = wait_output_node  # type: ignore[possibly-undefined]
+
+    need_sort_nodes = sorted(need_sort_nodes, key=lambda node: node_indices[node])
+    gm.graph.move_after(need_sort_nodes, last_split_reshape_node)
+
+    gm.graph.eliminate_dead_code()
+
+
+def _fuse_with_cat(
+    gm: IterGraphModule,
+    comm_blocks: List[CommBlock],
+    node_indices: Dict[fx.Node, int],
+) -> CommBlock:
+    """Fuse the CommBlocks using concat given a list of CommBlock (only allreduce)."""
+    # Find the last input node.
+    last_input_node = comm_blocks[0].inputs[0]
+    last_input_index = -1
+    all_input_nodes = []
+    for comm_block in comm_blocks:
+        input_node = comm_block.inputs[0]
+        # If the input node is a clone, this is the CommTensor-based implementation.
+        if input_node.name.startswith("clone"):
+            input_node = cast(fx.Node, input_node.args[0])
+        all_input_nodes.append(input_node)
+        index = node_indices[input_node]
+        if index >= last_input_index:
+            assert index != last_input_index
+            last_input_node = input_node
+            last_input_index = index
+
+    # Flatten all the inputs right after the last input is ready.
+    with gm.graph.inserting_after(last_input_node):
+        cat_inputs = []
+        for input_node in all_input_nodes:
+            cat_inputs.append(
+                _call_function(
+                    gm, fake_tensor_mode, None, aten.flatten.using_ints, input_node
+                )
+            )
+
+    with gm.graph.inserting_after(cat_inputs[0]):
+        cat_node = _call_function(gm, fake_tensor_mode, None, aten.cat, cat_inputs)
+
+    # Create a new Comm node.
+    last_comm = comm_blocks[-1]
+    last_comm_node = last_comm.comm_node
+    last_wait_node = last_comm.wait_nodes[0]
+    with gm.graph.inserting_after(cat_node):
+        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
+        flatten_args[0] = cat_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        fused_comm_node = _call_function(
+            gm,
+            fake_tensor_mode,
+            cat_node.meta["val"],
+            last_comm_node.target,
+            *args,
+            **kwargs,
+        )
+
+    # Create a new Wait node.
+    with gm.graph.inserting_after(fused_comm_node):
+        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
+        flatten_args[0] = fused_comm_node
+        args, kwargs = tree_unflatten(flatten_args, spec)
+        fused_wait_node = _call_function(
+            gm,
+            fake_tensor_mode,
+            cat_node.meta["val"],
+            last_wait_node.target,
+            *args,
+            **kwargs,
+        )
+
+    # Move the fused_comm_node and its args to right after the source node
+    nodes_to_move = cat_inputs + [cat_node, fused_comm_node, fused_wait_node]
+    gm.graph.move_after(nodes_to_move, last_input_node)
+
+    tensor_meta = cat_node.meta.get("tensor_meta")
+    fused_comm_block = CommBlock(
+        shape=tensor_meta.shape,  # type: ignore[union-attr]
+        node_list=[fused_comm_node, fused_wait_node],
+        wait_nodes=[fused_wait_node],
+        comm_node=fused_comm_node,
+        inputs=[cat_node],
+        outputs={fused_wait_node},
+    )
+
+    _scatter_wait_result(gm, fused_comm_block, comm_blocks, node_indices)
+
+    return fused_comm_block
+
+
+def _expedite_comm_ops(gm: IterGraphModule, comm_blocks: List[CommBlock]) -> None:
+    node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
+    for comm_block in comm_blocks:
+        last_input = comm_block.comm_node
+        last_input_idx = -1
+        for input in comm_block.inputs:
+            input_idx = node_indices[input]
+            if input_idx > last_input_idx:
+                last_input = input
+                last_input_idx = input_idx
+        gm.graph.node_append(last_input, comm_block.comm_node)
+
+
+@graph_optimization_pass(
+    prerequisites=[],
+    apply_after=[],
+)
+def comm_fusion_with_concat(
+    gm: IterGraphModule,
+    bucket_size_mb: int,
+) -> None:
+    """Run fuse communication with concat.
+
+    This implementation uses concat to concat the bucketed gradients.
+    """
+    comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
+    # First ensure the allreduces are scheduled immediately after the gradients.
+    _expedite_comm_ops(gm, comm_blocks)
+    # Get the comm_blocks based on the new order.
+    comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
+    node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
+
+    bucket_size = 1 * 1024**2
+    bucket_cap_size = bucket_size_mb * 1024**2
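+    # Sizes below are estimated assuming 4-byte (fp32) elements; see the dtype
+    # TODO. For example (illustration only), with bucket_size_mb=25 a gradient
+    # of shape (1024, 1024) contributes 1024 * 1024 * 4 bytes = 4 MiB toward
+    # the 25 MiB cap once the initial 1 MiB bucket has been flushed.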
+    begin = end = curr_size = 0
+    while end < len(comm_blocks):
+        # TODO: determine the dtype
+        curr_size += cast(torch.Size, comm_blocks[end].shape).numel() * 4
+        end += 1
+        if curr_size < bucket_size:
+            continue
+        _fuse_with_cat(gm, comm_blocks[begin:end], node_indices)
+        bucket_size = bucket_cap_size
+        begin = end
+        curr_size = 0
+    else:
+        if begin < len(comm_blocks):
+            _fuse_with_cat(gm, comm_blocks[begin:end], node_indices)
+
+
+@graph_optimization_pass(
+    prerequisites=[comm_fusion_with_concat],
+    apply_after=[],
+)
+def schedule_comm_wait(gm: IterGraphModule) -> None:
+    """Delay the execution of wait tensors of allreduce until its first user."""
+    comm_blocks = get_all_comm_blocks(gm, (CommType.ALLREDUCE, "all_reduce"))
+
+    # Find all the end users.
+    allreduce_users: Set[fx.Node] = set()
+    for allreduce in comm_blocks:
+        for output in allreduce.outputs:
+            allreduce_users.update(output.users)
+
+    node_indices = {node: i for i, node in enumerate(gm.graph.nodes)}
+    for allreduce in comm_blocks:
+        # Find the earliest users.
+        assert (
+            len(allreduce.outputs) >= 1
+        ), f"Found a allreduce that has zero outputs/users -- {allreduce}."
+        # Initialize the target_node to be the first user of the first output.
+        target_node = next(iter(next(iter(allreduce.outputs)).users))
+        target_node_index = 2**31
+        for user in (user for output in allreduce.outputs for user in output.users):
+            index = node_indices[user]
+            if index < target_node_index:
+                target_node = user
+                target_node_index = index
+
+        # Move wait nodes and all the subsequent output nodes before the
+        # earliest user.
+        wait_idx = -1
+        for wait_idx, node in enumerate(allreduce.node_list):
+            if node == allreduce.wait_nodes[0]:
+                break
+        assert wait_idx >= 0
+        gm.graph.move_before(allreduce.node_list[wait_idx:], target_node)
+
+
+@graph_optimization_pass(
+    prerequisites=[],
+    apply_after=[],
+)
+def remove_copy_from_optimizer(gm: IterGraphModule) -> None:
+    """Erase the orphant copy_ that generated when tracing optimizer.
+
+    Two reasons why we could not simply use the DCE of fx.Graph.
+    1. fx.Graph treats copy_ as a side-effect node and does not erase it.
+    2. Users may want to preserve some orphan `copy_` that is not from the
+       optimizer.
+    If the second reason does not hold, this pass can be rewritten as using
+    DCE from fx.Graph (with the overwrite to the side-effect node list).
+    """
+    MAX_COPY_DISTANCE = 5
+    remove_candidates: Set[fx.Node] = set()
+    for node in reversed(gm.graph.nodes):
+        if node.users:
+            continue
+        if node.op != OP.CALL_FUNCTION or node.target != aten.copy_.default:
+            continue
+
+        copy_ancestors: Set[fx.Node] = set()
+        nodes = collections.deque([node, None])
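+        # The deque holds a BFS frontier; the None entry is a level marker so
+        # that `distance` counts how many hops we are from the orphan copy_.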
+        distance = 0
+        should_remove = False
+        while nodes and distance < MAX_COPY_DISTANCE:
+            visiting = nodes.popleft()
+            if visiting is None:
+                distance += 1
+                if nodes:
+                    nodes.append(None)
+                continue
+            copy_ancestors.add(visiting)
+            if visiting.op == OP.CALL_FUNCTION and str(visiting.target).startswith(
+                ("aten._foreach_", "aten._fused_")
+            ):
+                should_remove = True
+            parents = pytree.arg_tree_leaves(*visiting.args, **visiting.kwargs)
+            for parent in parents:
+                if isinstance(parent, fx.Node):
+                    nodes.append(parent)
+        if should_remove:
+            # We add all ancestors to the list and it is okay as not all of
+            # them will be erased -- only those nodes with zero users will be
+            # erased.
+            remove_candidates.update(copy_ancestors)
+
+    for node in reversed(gm.graph.nodes):
+        if node.users:
+            continue
+        if node not in remove_candidates:
+            continue
+        gm.graph.erase_node(node)
+
+
+# The args list of the fused_adam function. We don't care about kwargs.
+AdamArgs = collections.namedtuple(
+    "AdamArgs",
+    ["params", "grads", "exp_avgs", "exp_avg_sqs", "max_exp_avg_sqs", "state_steps"],
+)
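+# For illustration only (a sketch, not part of any traced graph): given an fx
+# node `optim_node` that calls aten._fused_adam.default,
+# `AdamArgs(*optim_node.args)` names its positional tensor-list arguments, e.g.
+#   AdamArgs(params=[p0, p1], grads=[g0, g1], exp_avgs=[m0, m1],
+#            exp_avg_sqs=[v0, v1], max_exp_avg_sqs=[], state_steps=[s0, s1])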
+
+
+# TODO(fegin): Have a template class for all Block class.
+@dataclass(unsafe_hash=True)
+class FusedAdamBlock:
+    optim_node: fx.Node
+    generate_output: bool
+    # The output list of the copy nodes. The order follows the argument order.
+    param_outputs: List[fx.Node] = field(default_factory=list)
+    grad_outputs: List[fx.Node] = field(default_factory=list)
+    exp_avgs_outputs: List[fx.Node] = field(default_factory=list)
+    exp_avg_sqs_outputs: List[fx.Node] = field(default_factory=list)
+    # TODO(fegin): populate/generate the max_exp_avg_sqs if exists
+    max_exp_avg_sqs: List[fx.Node] = field(default_factory=list)
+
+    def generate_outputs(self):
+        # Iterate all the args and generate the corresponding output lists.
+        # Assuming the corresponding output nodes are not created yet.
+        def _generate_outputs(arg_idx, output_list):
+            graph = self.optim_node.graph
+            with graph.inserting_after(self.optim_node):
+                optim_getitem = graph.call_function(
+                    operator.getitem, (self.optim_node, arg_idx)
+                )
+            for i, arg in enumerate(self.optim_node.args[arg_idx]):
+                with graph.inserting_after(optim_getitem):
+                    updated_arg = graph.call_function(
+                        operator.getitem, (optim_getitem, i)
+                    )
+                with graph.inserting_after(updated_arg):
+                    output_copy = graph.call_function(aten.copy_, (arg, updated_arg))
+                output_list.append(output_copy)
+
+        _generate_outputs(0, self.param_outputs)
+        # Do not generate the gradient output list as it is not used.
+        _generate_outputs(2, self.exp_avgs_outputs)
+        _generate_outputs(3, self.exp_avg_sqs_outputs)
+
+    def populate_outputs(self):
+        # Populate the existing output lists from the graph.
+        def _populate_outputs(args_idx, output_list):
+            optim_getitem = self.optim_node
+            for user in self.optim_node.users:
+                assert (
+                    user.target == operator.getitem
+                ), f"The user of {self.optim_node} is not getitem."
+                if user.args[1] == args_idx:
+                    optim_getitem = user
+                    break
+            assert (
+                optim_getitem != self.optim_node
+            ), f"Cannot find the getitem node for {self.optim_node}"
+            output_list.extend(
+                [self.optim_node] * len(cast(List[fx.Node], self.optim_node.args[0]))
+            )
+            for updated_arg in optim_getitem.users:
+                assert (
+                    updated_arg.target == operator.getitem
+                ), f"Unexpected node target {updated_arg.target}."
+                idx = updated_arg.args[1]
+                output_copy = next(iter(updated_arg.users))
+                assert str(output_copy.target).startswith(
+                    "aten.copy_"
+                ), f"Unexpected node target {output_copy.target}."
+                output_list[idx] = output_copy
+            for i, output in enumerate(output_list):
+                assert output != self.optim_node, f"{i}th output is not replaced."
+
+            assert output_list, f"The output for {self.optim_node} is empty."
+
+        _populate_outputs(0, self.param_outputs)
+        _populate_outputs(2, self.exp_avgs_outputs)
+        _populate_outputs(3, self.exp_avg_sqs_outputs)
+
+    def __post_init__(self):
+        if self.param_outputs:
+            return
+        if self.generate_output:
+            self.generate_outputs()
+        else:
+            self.populate_outputs()
+
+
+@dataclass(unsafe_hash=True)
+class ForeachAddBlock:
+    add_node: fx.Node
+    generate_output: bool
+    # The output list of the copy nodes. The order follows the argument order.
+    outputs: List[fx.Node] = field(default_factory=list)
+
+    def generate_outputs(self):
+        # Iterate all the args and generate the corresponding output lists
+        # Assuming the corresponding output nodes are not created yet.
+        graph = self.add_node.graph
+        for i, arg in enumerate(cast(Tuple[Any, ...], self.add_node.args[0])):
+            with graph.inserting_after(self.add_node):
+                updated_arg = graph.call_function(operator.getitem, (self.add_node, i))
+            with graph.inserting_after(updated_arg):
+                output_copy = graph.call_function(aten.copy_, (arg, updated_arg))
+            self.outputs.append(output_copy)
+        assert self.outputs, f"The output for {self.add_node} is empty."
+
+    def populate_outputs(self):
+        # Populate the existing output lists from the graph.
+        self.outputs = [
+            self.add_node for _ in cast(Tuple[Any, ...], self.add_node.args[0])
+        ]
+        for updated_arg in self.add_node.users:
+            assert (
+                updated_arg.target == operator.getitem
+            ), f"Unexpected node target {updated_arg.target}"
+            idx = cast(int, updated_arg.args[1])
+            output_copy = next(iter(updated_arg.users))
+            assert str(output_copy.target).startswith(
+                "aten.copy_"
+            ), f"The execpted output node is different, {str(output_copy.target)}"
+            self.outputs[idx] = output_copy
+        for i, output in enumerate(self.outputs):
+            assert output != self.add_node, f"{i}th output is not replaced."
+
+    def __post_init__(self):
+        if self.outputs:
+            return
+
+        if self.generate_output:
+            self.generate_outputs()
+        else:
+            self.populate_outputs()
+
+
+@dataclass(unsafe_hash=True)
+class FusedOptimizerBlock:
+    step: ForeachAddBlock
+    optim: FusedAdamBlock
+
+
+def get_fused_optimizer_block(optim_node: fx.Node) -> FusedOptimizerBlock:
+    """Given a fused optimizer node and return the FusedOptimizerBlock."""
+    MAX_STEP_DISTANCE = 5
+    # Find the step (foreach_add)
+    nodes = collections.deque([optim_node, None])
+    step_node = optim_node
+    distance = 0
+    while nodes and distance < MAX_STEP_DISTANCE:
+        node = nodes.popleft()
+        if node is None:
+            distance += 1
+            if nodes:
+                nodes.append(None)
+            continue
+        elif node.op == OP.CALL_FUNCTION and str(node.target).startswith(
+            "aten._foreach_add"
+        ):
+            step_node = node
+            break
+        else:
+            nodes.extend(
+                a
+                for a in pytree.arg_tree_leaves(*node.args, **node.kwargs)
+                if isinstance(a, fx.Node)
+            )
+    if step_node == optim_node:
+        raise RuntimeError(
+            "Cannot find step node (foreach_add) for the optimizer node "
+            f"{optim_node} with {MAX_STEP_DISTANCE} BFS distance. "
+            "The API design does not match the tracing graph."
+        )
+
+    step = ForeachAddBlock(step_node, generate_output=False)
+    optim = FusedAdamBlock(optim_node, generate_output=False)
+    return FusedOptimizerBlock(step, optim)
+
+
+def get_all_fused_optimizer_blocks(
+    gm: IterGraphModule, optim_ops: Union[Tuple[str, ...], str]
+) -> List[FusedOptimizerBlock]:
+    """Find all the FusedOptimizerBlock that the optimizer operators are in `optim_ops`."""
+    return [
+        get_fused_optimizer_block(node)
+        for node in gm.graph.nodes
+        if node.name.startswith(optim_ops)
+    ]
+
+
+def _split_fused_adam(
+    gm: IterGraphModule,
+    orig_optim_block: FusedOptimizerBlock,
+    split_gradients: Set[fx.Node],
+) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
+    """Split the `orig_optim_block` into two FusedOptimizerBlock.
+
+    The first one will be the optimizer that optimize `split_gradients`. The second one is
+    used to optimize the remaining gradients.
+    An assert will be raised if one of the optimizer optimize zero gradients.
+    """
+    orig_optim_args = AdamArgs(*orig_optim_block.optim.optim_node.args)
+    optim_args = (AdamArgs([], [], [], [], [], []), AdamArgs([], [], [], [], [], []))
+    # The only hint we can use to split the optimizer is the order/indices.
+    orig_optim_indices: Tuple[List[int], List[int]] = ([], [])
+    orig_step_indices: Tuple[List[int], List[int]] = ([], [])
+
+    for idx, gradient in enumerate(orig_optim_args.grads):
+        group_idx = 0 if gradient in split_gradients else 1
+        orig_optim_indices[group_idx].append(idx)
+        # Get the argument for idx-th gradient from orig_optim_args
+        for orig_arg, optim_arg in zip(orig_optim_args, optim_args[group_idx]):
+            # Only add the argument to the list if the original argument list
+            # is not empty. If the original argument list is empty, the new
+            # one must be an empty list as well.
+            if orig_arg:
+                optim_arg.append(orig_arg[idx])
+
+        # If the argument order of step is the same as that of the optimizer,
+        # nothing has to be done. However, it is risky to rely on this
+        # assumption, so we populate orig_step_indices.
+        orig_step_output = optim_args[group_idx].state_steps[-1]
+        assert str(orig_step_output.target).startswith(
+            "aten.copy_"
+        ), f"The copy output is {orig_step_output.target}, expect aten.copy_"
+        orig_step_getitem = orig_step_output.args[1]
+        assert "getitem" in str(
+            orig_step_getitem.target
+        ), f"The copy getitem is {orig_step_getitem.target}, expect operator.getitem"
+        orig_step_idx = orig_step_getitem.args[1]
+        orig_step_indices[group_idx].append(orig_step_idx)
+
+    if not all(l for l in (orig_step_indices + orig_optim_indices)):
+        raise ValueError("At least one split optimizer does not have input.")
+
+    output = get_output(gm.graph)
+    results: List[FusedOptimizerBlock] = []
+    flatten_output_args, spec = tree_flatten((output.args, output.kwargs))
+    flatten_output_args_indices: DefaultDict[
+        fx.Node, Set[int]
+    ] = collections.defaultdict(set)
+    for idx, output_arg in enumerate(flatten_output_args):
+        if isinstance(output_arg, fx.Node):
+            flatten_output_args_indices[output_arg].add(idx)
+
+    def replace_flatten_output_args(orig_node: fx.Node, new_node: fx.Node):
+        for idx in flatten_output_args_indices[orig_node]:
+            flatten_output_args[idx] = new_node
+
+    # Create the new step and optim nodes and blocks.
+    for group_idx in range(2):
+        step_args: List[fx.Node] = []
+        orig_step_outputs: List[fx.Node] = []
+        # We have to create the new step node and block first because it is used
+        # for the new optim node as the input.
+        with gm.graph.inserting_after(orig_optim_block.optim.optim_node):
+            for idx in orig_step_indices[group_idx]:
+                step_args.append(
+                    cast(Tuple[fx.Node, ...], orig_optim_block.step.add_node.args[0])[
+                        idx
+                    ]
+                )
+                orig_step_outputs.append(orig_optim_block.step.outputs[idx])
+            step = gm.graph.call_function(
+                aten._foreach_add.Scalar,
+                (step_args, 1),
+            )
+        step_block = ForeachAddBlock(step, generate_output=True)
+        for i, step_output in enumerate(step_block.outputs):
+            # Replace the original step output in the graph output node with
+            # the new one.
+            orig_step_output = orig_step_outputs[i]
+            replace_flatten_output_args(orig_step_output, step_output)
+            # Also need to replace the step output used for the new optimizer.
+            assert optim_args[group_idx].state_steps[i] == orig_step_output, (
+                f"The expected step output node mismatched, {orig_step_output} "
+                f"{optim_args[group_idx].state_steps[i]}"
+            )
+            optim_args[group_idx].state_steps[i] = step_output
+
+        # Insert the optimizer node after the first step output because its
+        # topo sort order is the last.
+        with gm.graph.inserting_after(step_block.outputs[0]):
+            optim = gm.graph.call_function(
+                aten._fused_adam.default,
+                optim_args[group_idx],
+                orig_optim_block.optim.optim_node.kwargs,
+            )
+        optim_block = FusedAdamBlock(optim, generate_output=True)
+        for curr_idx, orig_idx in enumerate(orig_optim_indices[group_idx]):
+            list_names = ("param_outputs", "exp_avgs_outputs", "exp_avg_sqs_outputs")
+            for name in list_names:
+                orig_list = getattr(orig_optim_block.optim, name)
+                curr_list = getattr(optim_block, name)
+                replace_flatten_output_args(orig_list[orig_idx], curr_list[curr_idx])
+
+        results.append(FusedOptimizerBlock(step_block, optim_block))
+
+    # Optimizer is used as the output of the train_step. Therefore, we have to
+    # update the output node of the graph.
+    output_args, output_kwargs = tree_unflatten(flatten_output_args, spec)
+    gm.graph.node_set_args(output, output_args)
+    gm.graph.node_set_kwargs(output, output_kwargs)
+    # Remove the original copy_ nodes as they won't be removed by DCE.
+    for copy_output in itertools.chain(
+        orig_optim_block.optim.param_outputs,
+        orig_optim_block.optim.exp_avgs_outputs,
+        orig_optim_block.optim.exp_avg_sqs_outputs,
+    ):
+        gm.graph.erase_node(copy_output)
+    # Call DCE once to get rid of the old optimizer. By doing so, we will be
+    # able to erase the copy_ nodes of step later.
+    gm.graph.eliminate_dead_code()
+    for copy_output in orig_optim_block.step.outputs:
+        gm.graph.erase_node(copy_output)
+    # This is not strictly required, but we call it for consistency.
+    gm.graph.eliminate_dead_code()
+
+    return results[0], results[1]
+
+
+def split_fused_optimizer(
+    gm: IterGraphModule,
+    optim_block: FusedOptimizerBlock,
+    split_gradients: Set[fx.Node],
+) -> Tuple[FusedOptimizerBlock, FusedOptimizerBlock]:
+    if not split_gradients:
+        raise ValueError("The given split_gradients is empty.")
+    if str(optim_block.optim.optim_node.target).startswith("aten._fused_adam"):
+        return _split_fused_adam(gm, optim_block, split_gradients)
+    else:
+        raise NotImplementedError("Only fused_adam is supported now")
+
+
+# TODO(fegin): The API only supports fused adam for now. It should be extended
+# to support foreach as well.
+@graph_optimization_pass(
+    prerequisites=[remove_copy_from_optimizer],
+    apply_after=[schedule_comm_wait],
+)
+def iter_move_grads_and_optimizers(
+    gm: IterGraphModule,
+    target_comm_node: str,
+    target_dest_node: str,
+) -> None:
+    """Extract a comm block and split out a new optimizer and step for it.
+
+    This subgraph is then moved to the forward graph.
+    """
+    for comm_block in get_all_comm_blocks(gm, "all_reduce"):
+        if comm_block.comm_node.name == target_comm_node:
+            break
+    else:
+        raise ValueError(f"Cannot find {target_comm_node}")
+
+    optim_blocks = get_all_fused_optimizer_blocks(gm, "_fused_adam")
+    for optim_block in optim_blocks:
+        optim_args = AdamArgs(*optim_block.optim.optim_node.args)
+        one_output = next(iter(comm_block.outputs))
+        if one_output in optim_args.grads:
+            break
+    else:
+        raise ValueError(f"{target_comm_node} is not used by any fused optimizer.")
+
+    move_optim, _ = split_fused_optimizer(gm, optim_block, comm_block.outputs)
+
+    move_nodes = find_all_descendants(
+        gm, [comm_block.comm_node, move_optim.step.add_node]
+    )
+
+    stop_node = find_node(gm.graph, lambda n: n.name == target_dest_node)[0]
+
+    gm.graph.move_to_next_iter_before(move_nodes, stop_node)
+
+
+def find_all_descendants(
+    gm: IterGraphModule,
+    parent_nodes: List[fx.Node],
+) -> List[fx.Node]:
+    """Identify the list of nodes to move during FX graph transformation."""
+    assert len(parent_nodes) > 0, "No parent nodes are given."
+
+    output = get_output(gm.graph)
+    dq_parent_nodes = collections.deque(parent_nodes)
+    move_node_set = set()
+    while dq_parent_nodes:
+        node = dq_parent_nodes.popleft()
+        move_node_set.add(node)
+        dq_parent_nodes += [
+            u for u in node.users if isinstance(u, fx.Node) and u != output
+        ]
+    move_nodes = [node for node in gm.graph.nodes if node in move_node_set]
+
+    return move_nodes
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_utils.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ee71ce69c2cf09b287dcfbd5f322df11e47ed3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/graph_utils.py
@@ -0,0 +1,145 @@
+import logging
+import os
+import tempfile
+from enum import Enum
+from typing import Callable, cast, Dict, Iterable, List, Set
+
+import torch.fx as fx
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+
+logger: logging.Logger = logging.getLogger("graph_utils")
+
+
+class OP(str, Enum):
+    CALL_FUNCTION = "call_function"
+    CALL_MODULE = "call_module"
+    CALL_METHOD = "call_method"
+    GET_ATTR = "get_attr"
+    OUTPUT = "output"
+    PLACEHOLDER = "placeholder"
+
+
+class CommType(str, Enum):
+    ALLREDUCE = "allreduce_"
+    ALLGATHER = "allgather_"
+    BROADCAST = "broadcast_"
+    REDUCESCATTER = "reduce_scatter_"
+    SCATTER = "scatter_"
+
+
+def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorMetadata:
+    metadata = node.meta.get("tensor_meta", None)
+    if is_required and metadata is None:
+        raise RuntimeError(
+            f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
+            f"but got None instead. Node: {node.op} {node.name} {node.target}"
+        )
+    return metadata
+
+
+def get_output(graph: fx.Graph) -> fx.Node:
+    """Take a graphmodule and return the graph output node.
+
+    We traverse in reverse to expedite it, with the idea that last node should be output
+    """
+    for node in reversed(graph.nodes):
+        if node.op == OP.OUTPUT:
+            return node
+    raise RuntimeError(f"Cannot find the output node in {graph}")
+
+
+def find_node(
+    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
+) -> List[fx.Node]:
+    """Take a predicate and return all the nodes in the `graph` where the predicate holds."""
+    nodes = cast(Iterable[fx.Node], graph.nodes)
+    if reverse_order:
+        nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
+    return [node for node in nodes if predicate(node)]
+
+
+def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
+    """Ensure nodes in ``subgraph`` satisfy one of the following rules.
+
+    1. The user of the node is in ``subgraph``.
+    2. The user of the node is output.
+    3. There are no users -- the node is a side-effect node.
+    """
+    all_nodes: Set[fx.Node] = set(subgraph)
+    output = get_output(graph)
+    for node in subgraph:
+        for user in node.users:
+            if not isinstance(user, fx.Node):
+                continue
+            if user not in all_nodes and user != output:
+                return False
+    return True
+
+
+def clone_subgraph(
+    graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
+) -> List[fx.Node]:
+    """Clone the given subgraph and insert it before ``target``.
+
+    This API currently does not support inserting after ``target``.
+    """
+    all_nodes = set(subgraph)
+    mapping: Dict[fx.Node, fx.Node] = dict()
+    cloned_subgraph = []
+    with graph.inserting_before(target):
+        for node in subgraph:
+            cloned_node = graph.call_function(
+                node.target, node.args, node.kwargs, node.type
+            )
+            # TODO: there are many flatten/unflatten in IterGraph that
+            # can be simplified with tree_map. Will simplify this in
+            # a follow-up PR.
+            original_input = pytree.arg_tree_leaves(*node.args, **node.kwargs)
+            cloned_input, spec = tree_flatten((cloned_node.args, cloned_node.kwargs))
+            mapped_cloned_input = []
+            for original_input_node, cloned_input_node in zip(
+                original_input, cloned_input
+            ):
+                if (
+                    isinstance(original_input_node, fx.Node)
+                    and original_input_node in all_nodes
+                ):
+                    assert original_input_node in mapping
+                    mapped_cloned_input.append(mapping[original_input_node])
+                else:
+                    mapped_cloned_input.append(cloned_input_node)
+            cloned_node.args, cloned_node.kwargs = tree_unflatten(
+                mapped_cloned_input, spec
+            )
+            mapping[node] = cloned_node
+            cloned_subgraph.append(cloned_node)
+
+    return cloned_subgraph
+
+
+def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
+    """Run the required steps to ensure production-ready graph.
+
+    Note - per the fx docs, elimination of dead code is not very precise.
+    Hence, the flag to make this step optional.
+    """
+    gm.graph.lint()
+    if remove_dead_code:
+        gm.graph.eliminate_dead_code()
+    gm.recompile()
+
+
+def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
+    if not folder:
+        folder = tempfile.mkdtemp()
+
+    for prefix, gm in graphs.items():
+        with open(os.path.join(folder, f"{prefix}.graph"), "w") as fp:
+            fp.write(str(gm))
+
+    logger.warning("Dump graphs to %s", folder)
+
+    return folder
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/iter_graph_module.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/iter_graph_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..11576db9161daad1e7c669cead76a33b7f5bc397
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/iter_graph_module.py
@@ -0,0 +1,762 @@
+import copy
+import inspect
+import logging
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type
+
+import torch.nn as nn
+from torch import fx
+from torch.distributed._spmd.graph_utils import (
+    clone_subgraph,
+    get_output,
+    is_leaf_subgraph,
+)
+from torch.distributed._spmd.partial_lower import partial_lower
+from torch.fx.graph import _PyTreeCodeGen, PythonCode
+from torch.fx.node import Argument
+from torch.profiler import record_function
+from torch.utils import _pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten
+
+
+logger: logging.Logger = logging.getLogger("IterGraphModule")
+
+
+class IterGraph(fx.Graph):
+    """``IterGraph`` is used to perform cross-iteration optimization.
+
+    ``IterGraph`` keeps track of 3 graphs: self (the original graph), the setup
+    graph, and the cleanup graph. The 3 graphs should be identical copies of a
+    ``fx.Graph``.
+
+    IterGraph subclasses fx.Graph to override the necessary APIs that are used
+    when constructing an optimization, e.g., communication fusion. IterGraph also
+    provides APIs that originally belong to fx.Node; all these APIs have the
+    ``node_`` prefix. For example, ``IterGraph.node_prepend`` is the equivalent
+    of ``fx.Node.prepend``. Note that all the optimizations must be constructed
+    using these APIs.
+    """
+
+    def __init__(
+        self,
+        orig_graph: fx.Graph,
+        setup_graph: fx.Graph,
+        cleanup_graph: fx.Graph,
+        owning_module: Optional[fx.GraphModule] = None,
+        tracer_cls: Optional[Type["fx.Tracer"]] = None,
+        tracer_extras: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(owning_module, tracer_cls, tracer_extras)
+
+        output_vals = self.graph_copy(orig_graph, {}, return_output_node=True)
+        # TODO: if we do ``deepcopy(_codegen)`` and the input argument contains
+        # a dictionary with the form of Dict[torch.Tensor, Any], the
+        # torch.fx._pytree.tree_flatten_spec will not be able to flatten the
+        # dict -- the torch.Tensor will be duplicated because the _input_spec
+        # will save the ``keys`` of a dictionary (the values are not saved).
+        self._codegen = copy.deepcopy(orig_graph._codegen)
+        assert isinstance(output_vals, tuple)
+        output_val, old_output_val = output_vals
+        super().output(output_val, type_expr=getattr(old_output_val, "type", None))
+
+        self.setup_graph = setup_graph
+        self.cleanup_graph = cleanup_graph
+        self._all_graphs: Tuple[fx.Graph, ...] = (
+            self.setup_graph,
+            self.cleanup_graph,
+            cast(fx.Graph, super()),
+        )
+
+        self._setup_mapping: Dict[fx.Node, fx.Node] = {}
+        self._cleanup_mapping: Dict[fx.Node, fx.Node] = {}
+        self._freeze_cross_iter_movement = False
+        self._cross_iter_block_count = 0
+
+        for node, setup_node, cleanup_node in zip(
+            self.nodes, self.setup_graph.nodes, self.cleanup_graph.nodes
+        ):
+            self._setup_mapping[node] = setup_node
+            self._cleanup_mapping[node] = cleanup_node
+
+        self.num_extra_output = 0
+
+    def _lookup_node(self, node: fx.Node, graph: fx.Graph) -> Optional[fx.Node]:
+        if graph == self.setup_graph:
+            return self._setup_mapping.get(node, None)
+        elif graph == self.cleanup_graph:
+            return self._cleanup_mapping.get(node, None)
+        return node
+
+    def _fx_graph_call(
+        self, graph: fx.Graph, func: str, *args: Any, **kwargs: Any
+    ) -> Any:
+        fx_graph: fx.Graph = graph if graph != self else cast(fx.Graph, super())
+        return getattr(fx_graph, func)(*args, **kwargs)
+
+    def _insert_context(self, func: str, node: fx.Node):
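+        # Open the insertion context (e.g. inserting_after) on all three graphs
+        # at once and return a single context manager that closes them together.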
+        class _InsertPoint:
+            def __init__(self, insert_points: List[Any]):
+                self.insert_points = insert_points
+
+            def __enter__(self):
+                pass
+
+            def __exit__(self, type, value, tb):
+                for insert_point in self.insert_points:
+                    insert_point.__exit__(type, value, tb)
+
+        insert_points = []
+        for graph in self._all_graphs:
+            if node:
+                actual_node = self._lookup_node(node, graph)
+                assert actual_node is not None, "Cannot handle None case now."
+            else:
+                actual_node = node
+            insert_points.append(getattr(graph, func)(actual_node))
+
+        return _InsertPoint(insert_points)
+
+    def inserting_after(self, node):
+        if self._freeze_cross_iter_movement:
+            return super().inserting_after(node)
+        return self._insert_context("inserting_after", node)
+
+    def inserting_before(self, node):
+        if self._freeze_cross_iter_movement:
+            return super().inserting_before(node)
+        return self._insert_context("inserting_before", node)
+
+    def _forward_subgraph_inputs(
+        self, subgraph: List[fx.Node], graph: fx.Graph, erase_node: bool
+    ) -> int:
+        """Turn the inputs of a subgraph into the extra output of the entire graph.
+
+        If ``erase_node`` is True, the subgraph will be erased from the graph -- essentially forwarding
+        the inputs of the subgraph to the output of the graph.
+        """
+        output = get_output(graph)
+        inputs = []
+        all_nodes: Set[fx.Node] = set(subgraph)
+
+        for node in subgraph:
+            node_inputs = pytree.arg_tree_leaves(*node.args, **node.kwargs)
+            for _input in node_inputs:
+                if not isinstance(_input, fx.Node):
+                    continue
+                if _input in all_nodes:
+                    continue
+                inputs.append(_input)
+
+        if erase_node:
+            # We have to remove the nodes in reverse order to ensure each
+            # node has zero users when it is erased.
+            erased = set()
+            for node in reversed(subgraph):
+                if len(node.users) == 1:
+                    key = next(iter(node.users.keys()))
+                    if key == output:
+                        flatten_args, spec = tree_flatten((output.args, output.kwargs))
+                        if node not in flatten_args:
+                            # This optimizer node comes from the legacy _SPMD tracing.
+                            node.users.clear()
+                        elif str(node.target).startswith("aten.copy_"):
+                            # This is the case where the optimizer is
+                            # functionalized with copy_.
+                            for i in range(len(flatten_args)):
+                                if flatten_args[i] == node:
+                                    flatten_args[i] = node.args[0]
+                        else:
+                            # We have not figured out semantics of forwarding
+                            # all diff ops.
+                            raise RuntimeError(
+                                f"IterGraph does not how to forward the output of {node}"
+                            )
+                        output.args, output.kwargs = tree_unflatten(flatten_args, spec)
+
+                # This is the step case where there is a virtual data dependency
+                # (in-place update) between step and optimizer, and
+                # functionalize_optim adds this dependency.
+                for user in list(node.users.keys()):
+                    if user in erased:
+                        node.users.pop(user)
+                if node.users:
+                    raise RuntimeError(
+                        "IterGraph has not supported moving the nodes that "
+                        "produce users output result. "
+                        f"Error node: {node}."
+                    )
+                self._fx_graph_call(graph, "erase_node", node)
+                erased.add(node)
+
+        # Add all the extra output nodes into a list and append the list to
+        # the original output.args[0].
+        if self.num_extra_output:
+            # If the extra-output list already exists, just use it.
+            cast(List[fx.Node], output.args[0][-1]).extend(inputs)  # type: ignore[index]
+            new_output = output.args[0]
+        else:
+            # When adding the extra-output list, out_spec of _PyTreeCodeGen
+            # must be updated accordingly.
+            if isinstance(graph._codegen, _PyTreeCodeGen):
+                codegen = graph._codegen
+                new_output = list(output.args[0])  # type: ignore[arg-type]
+                new_output.append(inputs)
+                assert codegen.pytree_info.out_spec is not None
+                original_tree_out = tree_unflatten(
+                    cast(List[Any], output.args[0]), codegen.pytree_info.out_spec
+                )
+                # Use None as a placeholder. If we used the extra-output list,
+                # the list would be flattened as well and put into out_spec.
+                _, out_spec = tree_flatten((original_tree_out, None))
+                codegen.pytree_info = codegen.pytree_info._replace(out_spec=out_spec)
+            else:
+                new_output = (output.args[0], inputs)
+        self._fx_graph_call(graph, "erase_node", output)
+        self._fx_graph_call(graph, "output", new_output)
+
+        logger.info("Extended outputs from the subgraph inputs: %s", str(inputs))
+        return len(inputs)
+
+    def _forward_inputs_to_subgraph(
+        self, subgraph: List[fx.Node], graph: fx.Graph, extra_input: int
+    ) -> None:
+        """Create extra input nodes and forward the input nodes to the ``subgraph``.
+
+        The external input nodes of ``subgraph`` (nodes that are not in ``subgraph``) will be replaced
+        by the newly created input nodes.
+        """
+        placeholders = [node for node in graph.nodes if str(node.op) == "placeholder"]
+        assert placeholders, "No placeholders are found"
+        # Append the extra input nodes to the current input nodes.
+        with self._fx_graph_call(graph, "inserting_after", placeholders[-1]):
+            new_input_nodes = list(
+                reversed(
+                    [
+                        self._fx_graph_call(
+                            graph,
+                            "placeholder",
+                            f"cross_iter_input_{self._cross_iter_block_count}_{i}",
+                        )
+                        for i in reversed(range(extra_input))
+                    ]
+                )
+            )
+
+        # Update the inputs of subgraph to use the newly created input nodes.
+        all_nodes = set(subgraph)
+        new_input_index = 0
+        for node in subgraph:
+            node_inputs, spec = tree_flatten((node.args, node.kwargs))
+            new_node_inputs = []
+            for input_node in node_inputs:
+                if not isinstance(input_node, fx.Node) or input_node in all_nodes:
+                    new_node_inputs.append(input_node)
+                else:
+                    new_node_inputs.append(new_input_nodes[new_input_index])
+                    new_input_index += 1
+            node.args, node.kwargs = tree_unflatten(new_node_inputs, spec)
+        assert new_input_index == len(
+            new_input_nodes
+        ), f"More inputs than needed {len(new_input_nodes)} > {new_input_index}"
+
+        # Update the in_spec of _PyTreeCodeGen if in_spec is not None (the new
+        # SPMD makes in_spec as None).
+        if (
+            isinstance(graph._codegen, _PyTreeCodeGen)
+            and graph._codegen.pytree_info.in_spec is not None
+        ):
+            codegen = graph._codegen
+            original_tree_in = tree_unflatten(placeholders, codegen.pytree_info.in_spec)
+            _, in_spec = tree_flatten(tuple(list(original_tree_in) + new_input_nodes))
+            codegen.pytree_info = codegen.pytree_info._replace(in_spec=in_spec)
+            for new_input in new_input_nodes:
+                codegen.pytree_info.orig_args.append(new_input.name)
+            codegen.pytree_info = codegen.pytree_info._replace(in_spec=in_spec)
+
+    def move_to_next_iter_before(
+        self, subgraph: List[fx.Node], target_node: fx.Node
+    ) -> None:
+        """Move the ``subgraph`` to the next iteration before ``target_node``.
+
+        The ``subgraph`` is a list of fx.Node and must satisfy the following
+        restrictions:
+            1. The order of the nodes in ``subgraph`` must obey the topological
+               sort order.
+            2. The users of the node in ``subgraph`` must be one of the following:
+                a.) the user is also a node in ``subgraph``.
+                b.) the user is the output of the full graph.
+                c.) the node has no users (it is a side-effect node).
+        """
+        if self._freeze_cross_iter_movement:
+            raise RuntimeError(
+                "The cross-iteration movement has been frozen for the given "
+                "IterGraph."
+            )
+
+        if not is_leaf_subgraph(self, subgraph):
+            raise ValueError(
+                "The target nodes for ``move_to_next_iter_before`` must "
+                "satisfy one of the following conditions: 1) the user of the "
+                "node is in the target nodes, 2) the user is the output of the "
+                "graph, 3) there are no users -- the node is a side-effect node. "
+            )
+
+        self._cross_iter_block_count += 1
+        # The main graph must be the last one to be modified. Otherwise, the
+        # mapping may change and hence introduce incorrect mapping for setup
+        # and cleanup graphs.
+
+        # For the setup graph, no additional input is needed but additional
+        # outputs will be created. The additional output represents the input of
+        # the action to be moved to the next iteration -- main graph.
+        setup_subgraph: List[fx.Node] = []
+        for node in subgraph:
+            mapped_node = self._lookup_node(node, self.setup_graph)
+            assert mapped_node is not None
+            setup_subgraph.append(mapped_node)
+        setup_extra_input = self._forward_subgraph_inputs(
+            subgraph=setup_subgraph,
+            graph=self.setup_graph,
+            erase_node=True,
+        )
+
+        # For the cleanup graph, additional input is required to get the output
+        # from the last iteration -- main graph. Additional nodes are also
+        # needed to perform the action moved from the last iteration.
+        target_cleanup_node = self._lookup_node(target_node, self.cleanup_graph)
+        assert target_cleanup_node is not None, "The target_cleanup_node is None."
+        cleanup_subgraph: List[fx.Node] = []
+        for node in subgraph:
+            mapped_node = self._lookup_node(node, self.cleanup_graph)
+            assert mapped_node is not None
+            cleanup_subgraph.append(mapped_node)
+        cloned_subgraph = clone_subgraph(
+            self.cleanup_graph,
+            cleanup_subgraph,
+            target=target_cleanup_node,
+        )
+        self._forward_inputs_to_subgraph(
+            cloned_subgraph, self.cleanup_graph, setup_extra_input
+        )
+
+        # For the main graph, additional input will be created to represent
+        # the output from the last iteration -- main graph or setup graph.
+        # Additional output will also be generated to represent the input for
+        # the next iteration -- the main graph or the cleanup graph.
+        main_extra_input = self._forward_subgraph_inputs(
+            subgraph=subgraph, graph=self, erase_node=False
+        )
+        assert main_extra_input == setup_extra_input
+        for node in subgraph:
+            target_node.prepend(node)
+        self._forward_inputs_to_subgraph(subgraph, self, main_extra_input)
+
+        # TODO: This is a temporary solution. We are going to remove DCE usage
+        # or have something to replace fx DCE.
+        for node in self.cleanup_graph.nodes:
+            if len(node.users) == 0:
+                node.users["__hold__"] = None  # type: ignore[index]
+        for node in self.nodes:
+            if len(node.users) == 0:
+                node.users["__hold__"] = None  # type: ignore[index]
+        self.num_extra_output += main_extra_input
+
+    def move_before(self, nodes: List[fx.Node], target_node: fx.Node) -> None:
+        for graph in self._all_graphs:
+            actual_nodes = [self._lookup_node(node, graph) for node in nodes]
+            actual_target_node = self._lookup_node(target_node, graph)
+            assert actual_target_node is not None
+            for actual_node in actual_nodes:
+                actual_target_node.prepend(actual_node)
+
+    def move_after(self, nodes: List[fx.Node], target_node: fx.Node) -> None:
+        for graph in self._all_graphs:
+            actual_nodes = [self._lookup_node(node, graph) for node in nodes]
+            actual_target_node = self._lookup_node(target_node, graph)
+            for actual_node in actual_nodes:
+                assert actual_target_node is not None
+                actual_target_node.append(actual_node)
+                actual_target_node = actual_node
+
+    def call_function(
+        self,
+        the_function: Callable[..., Any],
+        args: Optional[Tuple[Argument, ...]] = None,
+        kwargs: Optional[Dict[str, Argument]] = None,
+        type_expr: Optional[Any] = None,
+    ) -> fx.Node:
+        if self._freeze_cross_iter_movement:
+            return super().call_function(the_function, args, kwargs, type_expr)
+
+        setup_args = tree_map(
+            lambda arg: self._lookup_node(arg, self.setup_graph)
+            if isinstance(arg, fx.Node)
+            else arg,
+            args,
+        )
+        setup_kwargs = tree_map(
+            lambda arg: self._lookup_node(arg, self.setup_graph)
+            if isinstance(arg, fx.Node)
+            else arg,
+            kwargs,
+        )
+        cleanup_args = tree_map(
+            lambda arg: self._lookup_node(arg, self.cleanup_graph)
+            if isinstance(arg, fx.Node)
+            else arg,
+            args,
+        )
+        cleanup_kwargs = tree_map(
+            lambda arg: self._lookup_node(arg, self.cleanup_graph)
+            if isinstance(arg, fx.Node)
+            else arg,
+            kwargs,
+        )
+
+        setup_node = self.setup_graph.call_function(
+            the_function, setup_args, setup_kwargs, type_expr
+        )
+        main_node = super().call_function(the_function, args, kwargs, type_expr)
+        cleanup_node = self.cleanup_graph.call_function(
+            the_function, cleanup_args, cleanup_kwargs, type_expr
+        )
+        self._setup_mapping[main_node] = setup_node
+        self._cleanup_mapping[main_node] = cleanup_node
+        return main_node
+
+    def erase_node(self, to_erase: fx.Node) -> None:
+        if self._freeze_cross_iter_movement:
+            return super().erase_node(to_erase)
+
+        setup_node = self._lookup_node(to_erase, self.setup_graph)
+        assert setup_node is not None, "setup_node is None"
+        self.setup_graph.erase_node(setup_node)
+        super().erase_node(to_erase)
+        cleanup_node = self._lookup_node(to_erase, self.cleanup_graph)
+        self.cleanup_graph.erase_node(cleanup_node)
+
+    def placeholder(
+        self,
+        name: str,
+        type_expr: Optional[Any] = None,
+        default_value: Any = inspect.Signature.empty,
+    ) -> fx.Node:
+        if self._freeze_cross_iter_movement:
+            return super().placeholder(name, type_expr, default_value)
+
+        main_placeholder = super().placeholder(name, type_expr, default_value)
+        setup_placeholder = self.setup_graph.placeholder(name, type_expr, default_value)
+        cleanup_placeholder = self.cleanup_graph.placeholder(
+            name, type_expr, default_value
+        )
+        self._setup_mapping[main_placeholder] = setup_placeholder
+        self._cleanup_mapping[main_placeholder] = cleanup_placeholder
+        return main_placeholder
+
+    def output(self, result: Argument, type_expr: Optional[Any] = None) -> fx.Node:
+        if self._freeze_cross_iter_movement:
+            return super().output(result, type_expr)
+
+        main_output = super().output(result, type_expr)
+        setup_result = tree_map(
+            lambda _result: self._lookup_node(_result, self.setup_graph)
+            if isinstance(_result, fx.Node)
+            else _result,
+            result,
+        )
+        cleanup_result = tree_map(
+            lambda _result: self._lookup_node(_result, self.cleanup_graph)
+            if isinstance(_result, fx.Node)
+            else _result,
+            result,
+        )
+        self.setup_graph.output(setup_result, type_expr)
+        self.cleanup_graph.output(cleanup_result, type_expr)
+
+        return main_output
+
+    def lint(self) -> None:
+        self.setup_graph.lint()
+        super().lint()
+        self.cleanup_graph.lint()
+
+    def node_prepend(self, target_node: fx.Node, node: fx.Node) -> None:
+        """Prepend node to target_node."""
+        if self._freeze_cross_iter_movement:
+            target_node.prepend(node)
+            return
+
+        for graph in self._all_graphs:
+            actual_node = self._lookup_node(node, graph)
+            assert actual_node is not None, "The node is None"
+            actual_target_node = self._lookup_node(target_node, graph)
+            assert actual_target_node is not None, "The target node is None"
+            actual_target_node.prepend(actual_node)
+
+    def node_append(self, target_node: fx.Node, node: fx.Node) -> None:
+        """Append node to target_node."""
+        if self._freeze_cross_iter_movement:
+            target_node.append(node)
+            return
+
+        for graph in self._all_graphs:
+            actual_node = self._lookup_node(node, graph)
+            assert actual_node is not None, f"The actual node is None, {node}."
+            actual_target_node = self._lookup_node(target_node, graph)
+            assert (
+                actual_target_node is not None
+            ), f"The actual target node is None, {target_node}."
+            actual_target_node.append(actual_node)
+
+    def node_set_args(self, node: fx.Node, args: Tuple[Argument, ...]) -> None:
+        if self._freeze_cross_iter_movement:
+            node.args = args
+            return
+
+        setup_args = tree_map_only(
+            fx.Node, lambda _arg: self._lookup_node(_arg, self.setup_graph), args
+        )
+        setup_node = self._lookup_node(node, self.setup_graph)
+        assert setup_node is not None
+        setup_node.args = setup_args
+        cleanup_args = tree_map_only(
+            fx.Node, lambda _arg: self._lookup_node(_arg, self.cleanup_graph), args
+        )
+        cleanup_node = self._lookup_node(node, self.cleanup_graph)
+        assert cleanup_node is not None
+        cleanup_node.args = cleanup_args
+        node.args = args
+
+    def node_set_kwargs(self, node: fx.Node, kwargs: Dict[str, Argument]) -> None:
+        if self._freeze_cross_iter_movement:
+            node.kwargs = kwargs
+            return
+
+        setup_kwargs = tree_map_only(
+            fx.Node, lambda _arg: self._lookup_node(_arg, self.setup_graph), kwargs
+        )
+        setup_node = self._lookup_node(node, self.setup_graph)
+        assert setup_node is not None
+        setup_node.kwargs = setup_kwargs
+        cleanup_kwargs = tree_map_only(
+            fx.Node, lambda _arg: self._lookup_node(_arg, self.cleanup_graph), kwargs
+        )
+        cleanup_node = self._lookup_node(node, self.cleanup_graph)
+        assert cleanup_node is not None
+        cleanup_node.kwargs = cleanup_kwargs
+        node.kwargs = kwargs
+
+    def node_replace_all_uses_with(
+        self,
+        node: fx.Node,
+        replace_with: fx.Node,
+        delete_user_cb: Callable[[fx.Node], bool] = lambda user: True,
+        *,
+        propagate_meta=False,
+    ) -> List[fx.Node]:
+        for graph in self._all_graphs:
+            actual_node = self._lookup_node(node, graph)
+            actual_replace_with = self._lookup_node(replace_with, graph)
+            assert actual_node is not None
+            ret = actual_node.replace_all_uses_with(
+                actual_replace_with,
+                delete_user_cb,
+                propagate_meta=propagate_meta,
+            )
+        return ret  # type: ignore[possibly-undefined]
+
+    def node_add_user(self, node: fx.Node, user: Any) -> None:
+        for graph in self._all_graphs:
+            actual_node = self._lookup_node(node, graph)
+            if isinstance(user, fx.Node):
+                actual_user_node = self._lookup_node(user, graph)
+            else:
+                actual_user_node = user
+            assert actual_node is not None
+            actual_node.users[actual_user_node] = None  # type: ignore[index]
+
+    def node_remove_user(self, node: fx.Node, user: Any) -> None:
+        for graph in self._all_graphs:
+            actual_node = self._lookup_node(node, graph)
+            if isinstance(user, fx.Node):
+                actual_user_node = self._lookup_node(user, graph)
+            else:
+                actual_user_node = user
+            assert actual_node is not None
+            del actual_node.users[actual_user_node]  # type: ignore[arg-type]
+
+    def keep_unused_nodes(self) -> None:
+        for node in self.nodes:
+            if len(node.users) == 0 and str(node.op) != "output":
+                self.node_add_user(node, "__hold__")
+
+    def functionalize_optim(self) -> None:
+        # IterGraph only supports full graphs (fwd+bwd+optim). Because the
+        # optimizer is not a functional call (it is an in-place op), this
+        # method adds explicit users to the optimizer and step calls so that
+        # they are not dropped. It makes strong assumptions about the optimizer
+        # and may not always work; it is intended as a temporary solution only.
+
+        # TODO: remove this API after DCE is removed
+        for node in reversed(self.nodes):
+            if node.name.startswith("output"):
+                output_node = node
+            elif node.name.startswith(
+                "_fused_adam_",
+            ):
+                optim_node = node
+            elif node.name.startswith(
+                "_foreach_add_",
+            ):
+                step_node = node
+                self.node_add_user(optim_node, output_node)  # type: ignore[possibly-undefined]
+                self.node_add_user(step_node, optim_node)  # type: ignore[possibly-undefined]
+
+    def defunctionalize_optim(self) -> None:
+        # TODO: remove this API after DCE is not used with IterGraph
+        for graph in self._all_graphs:
+            for node in reversed(graph.nodes):
+                if node.name.startswith("output"):
+                    output_node = node
+                elif node.name.startswith(
+                    "_fused_adam_",
+                ):
+                    optim_node = node
+                elif node.name.startswith(
+                    "_foreach_add_",
+                ):
+                    step_node = node
+                    optim_node.users.pop(output_node, None)  # type: ignore[possibly-undefined]
+                    step_node.users.pop(optim_node, None)  # type: ignore[possibly-undefined]
+
+    def freeze_cross_iter_movement(self) -> None:
+        self._freeze_cross_iter_movement = True
+
+
+class IterGraphModule(nn.Module):
+    """``IterGraphModule`` provides the ability to do cross-iteration optimization.
+
+    Given a ``fx.GraphModule``, main_gm, ``IterGraphModule`` internally
+    duplicates it into three copies and redirects each ``forward`` call to one
+    of them based on the iteration count. This allows users to perform graph
+    optimizations that span iterations (e.g., moving a collective wait from the
+    backward pass to the forward pass of the next iteration).
+
+    Note that users must call the APIs provided by ``IterGraphModule`` or
+    ``IterGraph`` to rewrite the graph so that ``IterGraphModule`` can keep the
+    data dependencies consistent across all three graphs.
+    """
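+
+    # Usage sketch (illustrative only, not part of the class): ``traced_gm`` is
+    # assumed to be an ``fx.GraphModule`` produced by the _spmd frontend and
+    # ``batches`` an iterable of positional-argument tuples.
+    #
+    #     igm = IterGraphModule(traced_gm, max_iters=len(batches))
+    #     igm.finalize_setup()
+    #     for i, batch in enumerate(batches):
+    #         out = igm(*batch, last_iter=(i == len(batches) - 1))
+    #
+    # Iteration 1 runs the setup graph, the final iteration runs the cleanup
+    # graph, and every iteration in between runs the (optimized) main graph.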
+
+    def __init__(
+        self,
+        main_gm: fx.GraphModule,
+        max_iters: int = -1,
+        enable_inductor: bool = False,
+    ) -> None:
+        super().__init__()
+
+        def _copy_gm(src: fx.GraphModule, graph: fx.Graph) -> fx.GraphModule:
+            gm = fx.GraphModule(src, graph)
+            gm.meta = getattr(graph, "meta", {})
+            return gm
+
+        self.setup_gm = _copy_gm(main_gm, copy.deepcopy(main_gm.graph))
+        self.cleanup_gm = _copy_gm(main_gm, copy.deepcopy(main_gm.graph))
+        self.main_gm = _copy_gm(
+            main_gm,
+            IterGraph(main_gm.graph, self.setup_gm.graph, self.cleanup_gm.graph),
+        )
+
+        self._iter = 0
+        self._max_iters = max_iters
+        self._previous_output: Tuple[Any, ...] = tuple()
+        self._num_extra_output = 0
+        self._is_frozen = False
+        self._enable_inductor = enable_inductor
+
+    def finalize_setup(self) -> None:
+        """Set up the internal states; this is also the signal from the user that the maximum iteration count has been decided.
+
+        This method must be called before ``forward()`` is called.
+        """
+        if not self._is_frozen:
+            self.graph.freeze_cross_iter_movement()
+            self._num_extra_output = self.graph.num_extra_output
+            if self._enable_inductor:
+                self.main_gm = partial_lower(self.main_gm)
+            self._is_frozen = True
+
+        self._iter = 0
+
+    def _run(self, gm: fx.GraphModule, last_iter: bool, *args, **kwargs) -> Any:
+        if self._num_extra_output > 0:
+            new_args = args + self._previous_output
+            output = gm(*new_args, **kwargs)
+            if not last_iter:
+                assert len(output) == 2
+                self._previous_output = tuple(output[-1])
+                assert (
+                    len(self._previous_output) > 0
+                ), "There should be at least one extra output."
+                output = output[0]
+        else:
+            # No cross-iteration optimization is done. Simply call the
+            # GraphModule.
+            output = gm(*args, **kwargs)
+        return output
+
+    def forward(self, *args: Any, last_iter: bool = False, **kwargs: Any) -> Any:
+        self._iter += 1
+        last_iter = last_iter or self._iter == self._max_iters
+        if last_iter:
+            logger.info("Using the cleanup graph")
+            gm = self.cleanup_gm
+            profiler_string = "## IterGraphModule: Cleanup Graph ##"
+            self._iter = 0
+        elif self._iter == 1:
+            logger.info("Using the setup graph")
+            gm = self.setup_gm
+            profiler_string = "## IterGraphModule: Setup Graph ##"
+        else:
+            gm = self.main_gm
+            if self._iter == 2:
+                logger.info("Using the main graph")
+                profiler_string = "## IterGraphModule -- Maybe Compiling ##"
+            else:
+                profiler_string = "## IterGraphModule ##"
+
+        with record_function(profiler_string):
+            return self._run(gm, last_iter, *args, **kwargs)
+
+    @property
+    def graph(self) -> IterGraph:
+        return cast(IterGraph, self.main_gm.graph)
+
+    def recompile(self) -> PythonCode:
+        self.setup_gm.recompile()
+        self.cleanup_gm.recompile()
+        return self.main_gm.recompile()
+
+    def freeze_cross_iter_movement(self) -> None:
+        # TODO: remove this API once it is not used.
+        self.graph.freeze_cross_iter_movement()
+        self._num_extra_output = self.graph.num_extra_output
+
+    def print_readable(self, print_output: bool = True) -> str:
+        return self.main_gm.print_readable(print_output)
+
+    def print_all_graphs(self) -> None:
+        logger.info("Printing the three fx.Graphs:")
+        logger.info("1. Setup fx.Graph:")
+        logger.info("%s", self.setup_gm.graph)
+        logger.info("2. Main fx.Graph:")
+        logger.info("%s", self.main_gm.graph)
+        logger.info("3. Cleanup fx.Graph:")
+        logger.info("%s", self.cleanup_gm.graph)
+
+    def print_all_graph_modules(self) -> None:
+        logger.info("Printing the three fx.GraphModules:")
+        logger.info("1. Setup fx.GraphModule:")
+        logger.info("%s", self.setup_gm.print_readable(False))
+        logger.info("2. Main fx.GraphModule:")
+        logger.info("%s", self.main_gm.print_readable(False))
+        logger.info("3. Cleanup fx.GraphModule:")
+        logger.info("%s", self.cleanup_gm.print_readable(False))
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/log_utils.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/log_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c792e8649c96851ffd7c9ba1df12d4dce67a9bbe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/log_utils.py
@@ -0,0 +1,78 @@
+import logging
+import logging.config
+import os
+from typing import Optional
+
+import torch.distributed as dist
+
+
+LOGGING_CONFIG = {
+    "version": 1,
+    "formatters": {
+        "spmd_format": {"format": "%(name)s: [%(levelname)s] %(message)s"},
+        "graph_opt_format": {"format": "%(name)s: [%(levelname)s] %(message)s"},
+    },
+    "handlers": {
+        "spmd_console": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "spmd_format",
+            "stream": "ext://sys.stdout",
+        },
+        "graph_opt_console": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "graph_opt_format",
+            "stream": "ext://sys.stdout",
+        },
+        "null_console": {
+            "class": "logging.NullHandler",
+        },
+    },
+    "loggers": {
+        "spmd_exp": {
+            "level": "DEBUG",
+            "handlers": ["spmd_console"],
+            "propagate": False,
+        },
+        "graph_opt": {
+            "level": "DEBUG",
+            "handlers": ["graph_opt_console"],
+            "propagate": False,
+        },
+        "null_logger": {
+            "handlers": ["null_console"],
+            "propagate": False,
+        },
+        # TODO(anj): Add loggers for MPMD
+    },
+    "disable_existing_loggers": False,
+}
+
+
+def get_logger(log_type: str) -> Optional[logging.Logger]:
+    from torch.distributed._spmd import config
+
+    if "PYTEST_CURRENT_TEST" not in os.environ:
+        logging.config.dictConfig(LOGGING_CONFIG)
+        avail_loggers = list(LOGGING_CONFIG["loggers"].keys())  # type: ignore[attr-defined]
+        assert (
+            log_type in avail_loggers
+        ), f"Unable to find {log_type} in the available list of loggers {avail_loggers}"
+
+        if not dist.is_initialized():
+            return logging.getLogger(log_type)
+
+        if dist.get_rank() == 0:
+            logger = logging.getLogger(log_type)
+            logger.setLevel(config.log_level)
+            if config.log_file_name is not None:
+                log_file = logging.FileHandler(config.log_file_name)
+                log_file.setLevel(config.log_level)
+                logger.addHandler(log_file)
+        else:
+            logger = logging.getLogger("null_logger")
+
+        return logger
+
+    return logging.getLogger("null_logger")
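+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the module), assuming a torch build with
+    # distributed support: fetch one of the pre-configured loggers. Without an
+    # initialized process group this returns a plain stdout logger; under
+    # pytest, or on non-zero ranks of an initialized job, the null logger is
+    # returned instead.
+    example_logger = get_logger("graph_opt")
+    if example_logger is not None:
+        example_logger.info("graph_opt logging is configured")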
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/parallel_mode.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/parallel_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..da719c8086eb48750b900fe8d969099b87d7ef1e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/parallel_mode.py
@@ -0,0 +1,216 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.utils._pytree as pytree
+from torch._subclasses import FakeTensorMode
+from torch.distributed._spmd.data_parallel import (
+    DataParallelStyle,
+    partition_data_parallel,
+)
+from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
+from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
+
+from torch.fx import GraphModule
+
+
+class ParallelMode(ABC):
+    """
+    Basic Parallel Mode interface. Each parallelism pattern should implement
+    this interface to describe how to partition and compile the graph in the
+    spmd compiler.
+    """
+
+    @abstractmethod
+    def partition(
+        self,
+        gm: GraphModule,
+        model: torch.nn.Module,
+        optimizer: Optional[torch.optim.Optimizer],
+        params_and_buffers: Dict[str, Any],
+        named_states: Dict[str, Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> GraphModule:
+        """
+        Partition a single device graph to a distributed graph.
+
+        TODO(@wanchaol): some of these arguments are not necessary for
+        partitioning, remove the unnecessary ones later.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
+        """
+        Transform and compile a distributed graph with a set of graph
+        transformation and optimization passes for each parallel mode.
+
+        The returned result should be a compiled executable graph in
+        the distributed environment.
+        """
+        # TODO: add more necessary arguments to this interface.
+        raise NotImplementedError()
+
+
+class DataParallel(ParallelMode):
+    """Data Parallelism mode."""
+
+    def __init__(
+        self,
+        parallel_style: str = "replicate",
+        *,
+        input_batch_dim: int = 0,
+        custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None,
+    ):
+        """
+        DataParallel mode that partitions the model and graph into data parallel
+        style parallelism (i.e. DDP/FSDP/ZeRO-3). It currently supports three different
+        parallel styles: "replicate", "fully_shard", and "default". See
+        :class:`DataParallelStyle` for more details.
+
+        Args:
+            parallel_style (str): parallel style to use. Currently supports
+                "replicate", "fully_shard", and "default".
+
+        Keyword args:
+            input_batch_dim (int): the batch dimension of the input tensor.
+                 default: 0
+            custom_passes (Callable[[GraphModule], GraphModule], optional):
+                A custom callable that overrides the default graph transformation
+                and optimization passes.
+        """
+        if parallel_style == "replicate":
+            self.parallel_style = DataParallelStyle.REPLICATE
+        elif parallel_style == "fully_shard":
+            self.parallel_style = DataParallelStyle.FULLY_SHARD
+        elif parallel_style == "default":
+            self.parallel_style = DataParallelStyle.DEFAULT
+        else:
+            raise RuntimeError(f"Unknown parallel style: {parallel_style}")
+
+        # TODO: what if the user passes in an incorrect `input_batch_dim`? How
+        # should we detect that and do proper error handling?
+        self.input_batch_dim = input_batch_dim
+
+        if custom_passes is not None:
+            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
+        else:
+            # TODO: add a few default passes here.
+            self._gm_passes = lambda gm: gm
+
+    def partition(
+        self,
+        gm: GraphModule,
+        model: torch.nn.Module,
+        optimizer: Optional[torch.optim.Optimizer],
+        params_and_buffers: Dict[str, Any],
+        named_states: Dict[str, Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> GraphModule:
+        # TODO: figure out a way to avoid explicit "cuda" mesh.
+        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
+
+        gm = partition_data_parallel(
+            gm,
+            model,
+            optimizer,
+            params_and_buffers,
+            named_states,
+            args,
+            kwargs,
+            mesh,
+            self.parallel_style,
+            self.input_batch_dim,
+        )
+        return gm
+
+    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
+        """optimize a distributed graph with a set of optimization passes"""
+        # TODO: add more necessary arguments to this interface.
+        return self._gm_passes(gm)
+
+
+class DTensorExpandMode(ParallelMode):
+    """
+    The DTensor Expand mode. It replicates the parameters and shards the
+    inputs to represent DDP-like behavior. It is currently a transient mode
+    used before we move to the new data parallel expansion.
+    """
+
+    def __init__(
+        self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
+    ):
+        self._placements_override: Dict[int, List[Placement]] = {}
+        if custom_passes is not None:
+            self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
+        else:
+            # TODO: add a few default passes here.
+            self._gm_passes = lambda gm: gm
+
+    def partition(
+        self,
+        gm: GraphModule,
+        model: torch.nn.Module,
+        optimizer: Optional[torch.optim.Optimizer],
+        params_and_buffers: Dict[str, Any],
+        named_states: Dict[str, Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> GraphModule:
+        flat_args = pytree.arg_tree_leaves(*args, **kwargs)
+
+        mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
+        shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
+        # FIXME: allow other sharding schemas
+        replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])
+
+        inps, schemas = [], []
+
+        for p in pytree.tree_leaves(params_and_buffers):
+            assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
+            inps.append(p)
+            schemas.append(replicate_schema)
+
+        for o in pytree.tree_leaves(named_states):
+            if isinstance(o, torch.Tensor):
+                inps.append(o)
+                schemas.append(replicate_schema)
+            else:
+                inps.append(torch.empty(0))
+                schemas.append(replicate_schema)
+
+        for a in flat_args:
+            if isinstance(a, torch.Tensor):
+                inps.append(a)
+                if id(a) in self._placements_override:
+                    schemas.append(
+                        Schema(mesh=mesh, placements=self._placements_override[id(a)])
+                    )
+                else:
+                    schemas.append(shard_schema)
+            else:
+                # Create dummy tensor and schema for non-tensor inputs for
+                # the purpose of dtensor expansion. Non-tensor inputs are
+                # guaranteed unused in dispatcher graphs produced by make_fx.
+                # However, we still need to respect them so that tensor inputs
+                # match with their placeholders.
+                inps.append(torch.empty(0))
+                schemas.append(shard_schema)
+
+        with FakeTensorMode(allow_non_fake_inputs=True):
+            fake_inps = [torch.empty_like(inp) for inp in inps]
+
+        return _convert_to_distributed(
+            gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
+        )[0]
+
+    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
+        """
+        Transform and compile a distributed graph with a set of graph transformation
+        and optimization passes for the dtensor fallback parallel mode.
+        """
+        # TODO: move the transformation passes to this function
+        return self._gm_passes(gm)
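+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the module): constructing a parallel
+    # mode only records its configuration. ``partition`` and
+    # ``transform_and_compile`` are driven later by the _spmd compiler with a
+    # traced graph and an initialized process group, so they are not called here.
+    replicate_mode = DataParallel("replicate", input_batch_dim=0)
+    shard_mode = DataParallel("fully_shard")
+    fallback_mode = DTensorExpandMode()
+    print(replicate_mode.parallel_style, shard_mode.parallel_style, type(fallback_mode).__name__)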
diff --git a/MLPY/Lib/site-packages/torch/distributed/_spmd/partial_lower.py b/MLPY/Lib/site-packages/torch/distributed/_spmd/partial_lower.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bdd6fd85fb425b43728e88586090feffec05a84
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_spmd/partial_lower.py
@@ -0,0 +1,268 @@
+# This file is copied from Meta internal repo and is not synced with the
+# internal version. Once the internal version is fully mature, we should
+# upstream again and retire the internal version. @yifuwang
+
+import logging
+import operator
+from typing import Callable, List, Optional, Set, Tuple
+
+from functorch import make_fx
+
+import torch
+
+from torch._inductor.compile_fx import compile_fx_inner
+from torch._inductor.decomposition import select_decomp_table
+
+MIN_ATEN_OPS_TO_LOWER = 10
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _create_subgraph_module(
+    inputs: List[torch.fx.Node], body: List[torch.fx.Node], outputs: List[torch.fx.Node]
+) -> torch.fx.GraphModule:
+    subgraph: torch.fx.Graph = torch.fx.Graph()
+    node_to_subgraph_node = {}
+    for idx, inp in enumerate(inputs):
+        subgraph_inp = subgraph.placeholder(name=f"arg_{idx}")
+        subgraph_inp.meta = inp.meta
+        node_to_subgraph_node[inp] = subgraph_inp
+
+    for node in body:
+        subgraph_node = subgraph.node_copy(
+            node, arg_transform=lambda x: node_to_subgraph_node[x]
+        )
+        node_to_subgraph_node[node] = subgraph_node
+
+    subgraph.output(result=tuple(node_to_subgraph_node[x] for x in outputs))
+    subgraph.eliminate_dead_code()
+    subgraph.lint()
+    return torch.fx.GraphModule(root={}, graph=subgraph)
+
+
+def _is_container_node(node: torch.fx.Node) -> bool:
+    if any(user.target == operator.getitem for user in node.users):
+        assert all(user.target == operator.getitem for user in node.users), (
+            "Malformed graph: a container node is used as input for non-getitem nodes."
+            "\nNode: {fmt_node}\nUsers: {fmt_users}".format(
+                fmt_node=node.format_node(),
+                fmt_users="\n".join(u.format_node() for u in node.users),
+            )
+        )
+        return True
+    return False
+
+
+def _lower_subgraph_nodes(
+    gm: torch.fx.GraphModule,
+    subgraph_name: str,
+    subgraph_nodes: List[torch.fx.Node],
+    dumper: Callable[[str], str],
+) -> None:
+    prologue: List[torch.fx.Node] = []
+    inputs: List[torch.fx.Node] = []
+    body: List[torch.fx.Node] = []
+    visible: Set[torch.fx.Node] = set()
+
+    # Inductor requires all graph inputs to be tensors. When adding a container
+    # node as a subgraph input, add its descendant getitem nodes to the subgraph
+    # prologue and add its leaf getitem nodes to the subgraph inputs.
+    def add_input(arg: torch.fx.Node) -> None:
+        stack = [arg]
+        while len(stack) != 0:
+            node = stack.pop()
+            if _is_container_node(node):
+                # We should only prepone nodes within subgraph_nodes
+                prologue.extend(user for user in node.users if user in subgraph_nodes)
+                stack.extend(node.users)
+            else:
+                if node not in visible:
+                    inputs.append(node)
+                    visible.add(node)
+
+    for node in subgraph_nodes:
+        if node.op == "get_attr":
+            # Prepone get_attr to avoid having to copy
+            # the attribute to the subgraph module.
+            inputs.append(node)
+            visible.add(node)
+            continue
+
+        for arg in node.all_input_nodes:
+            if arg not in visible:
+                add_input(arg)
+
+        if node not in prologue:
+            body.append(node)
+            visible.add(node)
+
+    outputs: List[torch.fx.Node] = []
+
+    # Inductor requires all graph outputs to be tensors. When adding a container
+    # node as a subgraph output, add its descendant getitem nodes to the subgraph
+    # body and add its leaf getitem nodes to the subgraph outputs.
+    def add_output(output: torch.fx.Node) -> None:
+        stack = [output]
+        while len(stack) != 0:
+            node = stack.pop()
+            if _is_container_node(node):
+                body.extend(node.users)
+                stack.extend(node.users)
+            elif not all(user in visible for user in node.users):
+                if node not in outputs:
+                    outputs.append(node)
+
+    for node in body:
+        if not all(user in visible for user in node.users):
+            add_output(node)
+
+    assert len(inputs) == len(set(inputs))
+    assert len(outputs) == len(set(outputs))
+
+    subgraph_module = _create_subgraph_module(inputs, body, outputs)
+    readable_tag = dumper(str(subgraph_module.graph))
+    setattr(gm, subgraph_name, _InductorModule(subgraph_module))
+
+    insertion_point = subgraph_nodes[-1].next
+    for node in prologue:
+        insertion_point.prepend(node)
+
+    with gm.graph.inserting_before(insertion_point):
+        # Insert subgraph call
+        subgraph_call = gm.graph.create_node(
+            op="call_module",
+            target=subgraph_name,
+            args=tuple(inputs),
+            kwargs={"tag": readable_tag},
+        )
+        # Replace parent graph nodes with their corresponding subgraph outputs
+        for idx, output in enumerate(outputs):
+            new_output = gm.graph.create_node(
+                op="call_function",
+                target=operator.getitem,
+                args=(subgraph_call, idx),
+            )
+            new_output.meta = output.meta
+            output.replace_all_uses_with(new_output)
+
+    # Erase lowered nodes from the parent graph
+    for node in reversed(body + outputs):
+        if len(node.users) == 0:
+            gm.graph.erase_node(node)
+
+
+class _InductorModule(torch.nn.Module):
+    def __init__(self, gm: torch.fx.GraphModule) -> None:
+        super().__init__()
+        self.gm = gm
+        self.compiled: Optional[
+            Callable[[List[torch.Tensor]], List[torch.Tensor]]
+        ] = None
+
+    def forward(self, *args: torch.Tensor, tag: str) -> List[torch.Tensor]:
+        if self.compiled is None:
+            inductor_decompositions = select_decomp_table()
+            # TODO: figure out why turning on cudagraphs causes exceptions.
+            decomp_gm = make_fx(self.gm, decomposition_table=inductor_decompositions)(
+                *args
+            )
+            logger.info("Lowering subgraph (%s) to Inductor...", tag)
+            self.compiled = compile_fx_inner(
+                decomp_gm,
+                list(args),
+                cudagraphs=False,
+            )
+            logger.info("Completed lowering subgraph (%s) to Inductor", tag)
+        with torch.profiler.record_function(tag):
+            assert self.compiled is not None
+            return self.compiled(list(args))
+
+
+def _is_inductor_compatible(node: torch.fx.Node) -> Tuple[bool, str]:
+    # `has_tag` is not supported yet
+    # if has_tag(node, "non_lowerable"):
+
+    if node.target in (
+        torch.ops.aten._fused_adam_.default,
+        torch.ops.aten._fused_adam.default,
+        torch.ops.aten._foreach_add_.Scalar,
+        torch.ops.aten._foreach_add.Scalar,
+    ):
+        return False, "fused adam is not supported yet"
+
+    # TODO(yifu): apparently having a meta kernel is not a necessary
+    # condition for Inductor compatibility. We should refine the check.
+    # Sneaking this one in for now to support comm_fusion_with_cat.
+    if node.target == torch.ops.aten.flatten.using_ints:
+        return True, ""
+
+    if isinstance(node.target, torch._ops.OpOverload):
+        if not node.target.has_kernel_for_dispatch_key(torch._C.DispatchKey.Meta):
+            return False, f"{node.target} doesn't have a meta kernel registered"
+    return True, ""
+
+
+def _subgraph_predicate(nodes: List[torch.fx.Node]) -> bool:
+    num_aten_ops = len([n for n in nodes if str(n.target).startswith("aten.")])
+    return num_aten_ops >= MIN_ATEN_OPS_TO_LOWER
+
+
+def partial_lower(
+    gm: torch.fx.GraphModule,
+    node_predicate: Callable[[torch.fx.Node], bool] = lambda x: True,
+    subgraph_predicate: Callable[[List[torch.fx.Node]], bool] = lambda x: True,
+    dumper: Callable[[str], str] = lambda x: "subgraph",
+) -> torch.fx.GraphModule:
+    """
+    Lower Inductor-compatible portions of the graph module to Inductor.
+
+    Args:
+        node_predicate: user predicate for determining whether to consider a node for
+            lowering.
+        subgraph_predicate: user predicate for determining whether to consider a list of
+            candidate nodes for lowering.
+        dumper: a callback for dumping subgraphs for human inspection. For example, it
+            can be a function that writes to disk/blob storage and returns the
+            path/handle. The returned path/handle for each subgraph is made
+            available in the subgraph call node in the parent graph and is also
+            used as the label of the profiler block for the subgraph.
+    """
+    nodes_per_subgraph: List[List[torch.fx.Node]] = [[]]
+    ptr = next(iter(gm.graph.nodes))
+
+    def _node_predicate(node: torch.fx.Node) -> Tuple[bool, str]:
+        should_lower, reason = _is_inductor_compatible(node)
+        if not should_lower:
+            return should_lower, reason
+        if not node_predicate(node):
+            return False, "user predicate"
+        return True, ""
+
+    while ptr.op != "output":
+        if ptr.op == "placeholder":
+            ptr = ptr.next
+            continue
+        should_lower, reason = _node_predicate(ptr)
+        if should_lower:
+            nodes_per_subgraph[-1].append(ptr)
+        else:
+            if len(nodes_per_subgraph[-1]) > 0:
+                logger.warning(
+                    "partial_lower: graph break at %s. Reason: %s", str(ptr), reason
+                )
+            nodes_per_subgraph.append([])
+        ptr = ptr.next
+
+    nodes_per_subgraph = [
+        nodes
+        for nodes in nodes_per_subgraph
+        if subgraph_predicate(nodes) and _subgraph_predicate(nodes)
+    ]
+
+    for idx, subgraph_nodes in enumerate(nodes_per_subgraph):
+        subgraph_name = f"subgraph_{idx}"
+        _lower_subgraph_nodes(gm, subgraph_name, subgraph_nodes, dumper)
+
+    gm.graph.lint()
+    gm.recompile()
+    return gm
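+
+
+# Usage sketch (illustrative only): given an aten-level ``fx.GraphModule``
+# ``gm`` (e.g. produced by ``make_fx``), the call below replaces every
+# Inductor-compatible region containing at least MIN_ATEN_OPS_TO_LOWER aten
+# ops with a ``call_module`` node targeting an ``_InductorModule``; the actual
+# Inductor compilation is deferred until that subgraph is first executed.
+#
+#     lowered_gm = partial_lower(
+#         gm,
+#         node_predicate=lambda n: True,        # consider every node
+#         dumper=lambda graph_str: "subgraph",  # or persist graph_str and return a path
+#     )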
diff --git a/MLPY/Lib/site-packages/torch/distributed/_state_dict_utils.py b/MLPY/Lib/site-packages/torch/distributed/_state_dict_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce134241ae2a1b067553402aff8783b40cc6934
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_state_dict_utils.py
@@ -0,0 +1,385 @@
+import io
+import math
+from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+
+if dist.is_available() or TYPE_CHECKING:
+    from torch.distributed import distributed_c10d
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+    from torch.distributed._tensor import DTensor, Replicate
+
+
+def _identity_func(
+    obj: torch.Tensor,
+    pg: Optional[dist.ProcessGroup],
+    device: Optional[torch.device],
+    companion_obj: Any,
+) -> torch.Tensor:
+    return obj
+
+
+def _all_gather_sharded_tensor(
+    sharded_tensor: "ShardedTensor",
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
+    if pg is None:
+        pg = distributed_c10d._get_default_group()
+    world_size = dist.get_world_size(pg)
+    shards = sharded_tensor.local_shards()
+    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
+    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
+    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
+    pg_device = (
+        distributed_c10d._get_pg_default_device(pg) if device is None else device
+    )
+    if shards:
+        local_tensor = shards[0].tensor.flatten()
+        if local_tensor.device.type != pg_device.type:
+            local_tensor = local_tensor.to(pg_device)
+        num_padding = chunk_size - local_tensor.numel()
+        if num_padding > 0:
+            local_tensor = F.pad(local_tensor, [0, num_padding])
+    else:
+        local_tensor = torch.zeros(
+            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
+        )
+
+    tensor = torch.empty(
+        chunk_size * world_size,
+        dtype=local_tensor.dtype,
+        device=pg_device,
+    )
+    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
+
+    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
+    return tensor
+
+
+class CompanionMismatch(Exception):
+    ...
+
+
+def _iterate_state_dict(
+    iter_object: Any,
+    sharded_tensor_func: Callable,
+    dtensor_func: Callable,
+    tensor_func: Callable,
+    *,
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    cpu_offload: bool = False,
+    companion_obj: Any = None,
+    ranks_only: Tuple[int, ...] = tuple(),
+    type_check: bool = True,
+) -> Dict[str, Any]:
+    # TODO: should we use pytree?
+    cpu_device = torch.device("cpu")
+    if isinstance(iter_object, ShardedTensor):
+        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
+    elif isinstance(iter_object, DTensor):
+        ret = dtensor_func(iter_object, pg, device, companion_obj)
+    elif isinstance(iter_object, torch.Tensor):
+        ret = tensor_func(iter_object, pg, device, companion_obj)
+    elif (
+        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
+        or iter_object is None
+    ):
+        ret = iter_object
+    elif isinstance(iter_object, dict):
+        if companion_obj is not None and (
+            not isinstance(companion_obj, dict)
+            or set(companion_obj.keys()) != set(iter_object.keys())
+        ):
+            raise CompanionMismatch()
+
+        ret = {
+            key: _iterate_state_dict(
+                value,
+                sharded_tensor_func,
+                dtensor_func,
+                tensor_func,
+                pg=pg,
+                device=device,
+                cpu_offload=cpu_offload,
+                companion_obj=companion_obj[key] if companion_obj is not None else None,
+                ranks_only=ranks_only,
+                type_check=type_check,
+            )
+            for key, value in iter_object.items()
+        }
+    elif isinstance(iter_object, (list, tuple)):
+        if companion_obj is not None and (
+            not isinstance(companion_obj, (list, tuple))
+            or len(companion_obj) != len(iter_object)
+        ):
+            raise CompanionMismatch()
+
+        ret = [
+            _iterate_state_dict(
+                v,
+                sharded_tensor_func,
+                dtensor_func,
+                tensor_func,
+                pg=pg,
+                device=device,
+                cpu_offload=cpu_offload,
+                companion_obj=companion_obj[idx] if companion_obj is not None else None,
+                ranks_only=ranks_only,
+                type_check=type_check,
+            )
+            for idx, v in enumerate(iter_object)
+        ]
+        if isinstance(iter_object, tuple):
+            ret = tuple(ret)
+    elif not type_check:
+        ret = iter_object
+    else:
+        raise ValueError(f"Unexpected value type {type(iter_object)}")
+
+    if not ranks_only or dist.get_rank(pg) in ranks_only:
+        if isinstance(ret, torch.Tensor) and cpu_offload:
+            if companion_obj is None:
+                ret = ret.to(cpu_device)
+            else:
+                # TODO: support DTensor
+                companion_obj.copy_(ret, non_blocking=True)
+                ret = companion_obj
+    else:
+        ret = {} if isinstance(ret, dict) else None
+
+    return ret
+
+
+def _gather_state_dict(
+    state_dict: Dict[str, Any],
+    *,
+    pg: Optional[dist.ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    cpu_offload: bool = False,
+    ranks_only: Tuple[int, ...] = tuple(),
+    type_check: bool = True,
+) -> Dict[str, Any]:
+    """
+    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
+    the state_dict.
+
+    Args:
+        state_dict (Dict[str, Any]): the target sharded state_dict.
+        pg (Optional[dist.ProcessGroup]): the process group that is used to
+            gather ShardedTensor. Note that gathering a DTensor will use
+            the DeviceMesh. So this argument will be ignored when gathering a
+            DTensor.
+        device: (Optional[torch.device]): the device that is used to
+            perform allgather for ShardedTensor. Note that gathering a DTensor
+            will use the DeviceMesh. So this argument will be ignored when
+            gathering a DTensor.
+        cpu_offload (bool): whether to offload the tensors to CPU memory. The
+            default value is False.
+        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
+            have the same state_dicts. Otherwise only ranks in ``ranks_only``
+            have the same state_dicts. Other ranks will get empty state_dicts.
+        type_check: (bool): check if the instance data type is a supported type
+            that can be saved by DCP.  The current supported data types are
+            torch.Tensor, DTensor, int, float, str, list, dict, None.
+
+    Returns:
+        The gathered state dictionary.
+    """
+
+    def sharded_tensor_func(value, pg, device, companion_obj):
+        # ShardedTensor does not seem to record the original device type.
+        # So if the tensor is moved to CPU, we won't know the original type.
+        # As a result, we have to rely on the user to tell us the correct one.
+        cpu_device = torch.device("cpu")
+        output_tensor = _all_gather_sharded_tensor(value, pg, device)
+        local_shard_device = (
+            value.local_shards()[0].tensor.device
+            if value.local_shards()
+            else cpu_device
+        )
+        if output_tensor.device != local_shard_device:
+            value = output_tensor.to(local_shard_device)
+        else:
+            value = output_tensor
+        return value
+
+    def dtensor_func(value, pg, device, companion_obj):
+        if value.device != value.device_mesh.device_type:
+            value = value.to(value.device_mesh.device_type)
+        # FSDP all_gather: [Shard(0)] -> [Replicate()]
+        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
+        # 2D FSDP + TP all_gather:
+        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
+        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
+        placements = [Replicate() for _ in value.placements]
+        value = value.redistribute(
+            device_mesh=value.device_mesh,
+            placements=placements,
+        )
+        # Call `wait()` to force the tensor to be synchronous with respect
+        # to the main stream.
+        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
+        value = value.to_local()
+        if isinstance(value, AsyncCollectiveTensor):
+            value = value.wait()
+        return value
+
+    return _iterate_state_dict(
+        state_dict,
+        sharded_tensor_func,
+        dtensor_func,
+        _identity_func,
+        pg=pg,
+        device=device,
+        cpu_offload=cpu_offload,
+        ranks_only=ranks_only,
+        type_check=type_check,
+    )
+
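+# Usage sketch (illustrative only): with an initialized process group and a
+# ``model`` whose state_dict contains ShardedTensor/DTensor values (e.g. an
+# FSDP-wrapped module), the full state_dict can be materialized on rank 0 as:
+#
+#     full_sd = _gather_state_dict(
+#         model.state_dict(), cpu_offload=True, ranks_only=(0,)
+#     )
+#     # rank 0 holds the unsharded state_dict on CPU; other ranks get {}.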
+
+def _offload_state_dict_to_cpu(
+    state_dict: Dict[str, Any],
+    *,
+    ranks_only: Tuple[int, ...] = tuple(),
+    cpu_offload_state_dict: Optional[Dict[str, Any]] = None,
+    cpu_offload_sync: bool = True,
+    type_check: bool = True,
+) -> Dict[str, Any]:
+    """
+    Given a state_dict, this API offloads all the tensors to CPU memory.
+
+    Args:
+        state_dict (Dict[str, Any]): the target state_dict.
+        ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
+            have the same state_dicts. Otherwise only ranks in ``ranks_only``
+            have the same state_dicts. Other ranks will get empty state_dicts.
+        cpu_offload_state_dict (Optional[Dict[str, Any]]): the CPU state_dict
+            that will be returned. If this is not None, this API will use
+            `copy_` to copy the GPU tensor to the tensor in this CPU state_dict.
+            This CPU state_dict must have exactly the same structure as the
+            `state_dict` the only difference is that all the tensors in this
+            CPU state_dict are on CPU memory.
+        cpu_offload_sync: (bool): flag to decide whether to call `synchronize()`
+            before this API returns.
+        type_check: (bool): check if the instance data type is a supported type
+            that can be saved by DCP.  The current supported data types are
+            torch.Tensor, DTensor, int, float, str, list, dict, None.
+
+    Returns:
+        The offloaded state dictionary.
+    """
+
+    ret = _iterate_state_dict(
+        state_dict,
+        _identity_func,
+        _identity_func,
+        _identity_func,
+        pg=None,
+        device=None,
+        cpu_offload=True,
+        ranks_only=ranks_only,
+        companion_obj=cpu_offload_state_dict,
+        type_check=type_check,
+    )
+    if cpu_offload_state_dict is not None and cpu_offload_sync:
+        torch.cuda.synchronize()
+    return ret
+
+
+def _create_cpu_state_dict(
+    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
+) -> Dict[str, Any]:
+    """
+    Given a state_dict, create another state_dict with the same structure and elements.
+    However, all tensors in the returned state_dict are new tensors on CPU. These
+    tensors can be placed in pinned memory or shared memory based on the provided arguments.
+    """
+
+    if pin_memory and share_memory:
+        raise ValueError(
+            "Cannot allocate memory with both pin_memory and share_memory set"
+        )
+
+    def tensor_func(
+        obj: torch.Tensor,
+        pg: Optional[dist.ProcessGroup],
+        device: Optional[torch.device],
+        companion_obj: Any,
+    ) -> torch.Tensor:
+        if len(obj.size()) == 0:
+            return torch.tensor(0, dtype=obj.dtype)
+
+        if share_memory:
+            return torch.empty(
+                *tuple(companion_obj.size()), dtype=companion_obj.dtype
+            ).share_memory_()
+        else:
+            t = torch.empty(
+                *tuple(companion_obj.size()), dtype=companion_obj.dtype
+            )
+            # Only pin the memory when explicitly requested; a plain CPU tensor
+            # is returned otherwise.
+            return t.pin_memory() if pin_memory else t
+
+    ret = _iterate_state_dict(
+        state_dict,
+        _identity_func,
+        _identity_func,
+        tensor_func,
+        pg=None,
+        device=None,
+        cpu_offload=False,
+        ranks_only=tuple(),
+        companion_obj=state_dict,
+        type_check=False,
+    )
+    return ret
+
+
+def _check_state_dict_similarity(
+    state_dict: Dict[str, Any],
+    compared_state_dict: Dict[str, Any],
+) -> bool:
+    """
+    Given two state_dicts, check whether the structures are the same: if a
+    [key, tensor] pair exists in one state_dict, there must be a corresponding
+    pair, [key, other_tensor], in the other state_dict, where tensor and
+    other_tensor have the same size and dtype.
+
+    Return the check result.
+    """
+
+    def tensor_func(
+        obj: torch.Tensor,
+        pg: Optional[dist.ProcessGroup],
+        device: Optional[torch.device],
+        companion_obj: Any,
+    ) -> torch.Tensor:
+        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
+            raise CompanionMismatch()
+        return obj
+
+    try:
+        _iterate_state_dict(
+            state_dict,
+            _identity_func,
+            _identity_func,
+            tensor_func,
+            pg=None,
+            device=None,
+            cpu_offload=False,
+            ranks_only=tuple(),
+            companion_obj=compared_state_dict,
+            type_check=False,
+        )
+    except CompanionMismatch:
+        return False
+
+    return True
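+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the module): build a shared-memory CPU
+    # copy of a plain state_dict and check that the two dicts have the same
+    # structure, sizes, and dtypes. Runs on a single process; no process group
+    # or accelerator is required.
+    source_sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4), "step": 3}
+    cpu_sd = _create_cpu_state_dict(source_sd, share_memory=True)
+    assert _check_state_dict_similarity(source_sd, cpu_sd)
+    print("keys:", sorted(cpu_sd.keys()))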
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c67c146cc9c40dd9b0d697a3dd9724bb693ac39
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/__init__.py
@@ -0,0 +1,342 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import Optional, Sequence
+
+# Import all builtin dist tensor ops
+import torch
+import torch.distributed._tensor.ops
+import torch.distributed._tensor.random as random
+from torch.distributed._tensor._utils import compute_local_shape
+from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor
+from torch.distributed._tensor.ops.utils import normalize_to_torch_size
+from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
+
+# All public APIs from dtensor package
+__all__ = [
+    "DTensor",
+    "DeviceMesh",
+    "distribute_tensor",
+    "distribute_module",
+    "init_device_mesh",
+    "Shard",
+    "Replicate",
+]
+
+
+def _dtensor_init_helper(
+    init_op,
+    size: torch.Size,
+    device_mesh=None,
+    placements=None,
+    **kwargs,
+) -> DTensor:
+    # if device_mesh is None, use the one from mesh resources
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+    kwargs["device"] = device_mesh.device_type
+
+    # set default placements to replicated if not specified
+    placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))
+
+    # check device_mesh against placements
+    assert device_mesh.ndim == len(
+        placements
+    ), "mesh dimension does not match the length of placements"
+
+    assert kwargs["layout"] == torch.strided, "layout value not supported!"
+    torch_stride = torch._prims_common.make_contiguous_strides_for(size)
+
+    # get local tensor shape
+    local_shape = compute_local_shape(size, device_mesh, placements)
+    # initialize the local tensor
+    if init_op == torch.full:
+        fill_value = kwargs.pop("fill_value", 0)
+        local_tensor = init_op(local_shape, fill_value, **kwargs)
+    elif init_op == torch.rand or init_op == torch.randn:
+        # this tensor meta is not used except `shape`
+        dtype = kwargs.get("dtype", torch.get_default_dtype())
+
+        from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+
+        tensor_meta = TensorMeta(size, (0,), dtype)
+        spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta)
+
+        if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
+            random._rng_tracker = random.OffsetBasedRNGTracker()
+
+        assert random._rng_tracker is not None
+        with random._rng_tracker._distribute_region(spec):
+            local_tensor = init_op(local_shape, **kwargs)
+    else:
+        local_tensor = init_op(local_shape, **kwargs)
+
+    return DTensor(
+        local_tensor=local_tensor,
+        device_mesh=device_mesh,
+        placements=tuple(placements),
+        shape=size,
+        dtype=local_tensor.dtype,
+        stride=torch_stride,
+        requires_grad=kwargs["requires_grad"],
+    )
+
+
+def ones(
+    *size,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    requires_grad: bool = False,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
+    by the variable argument ``size``.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.ones,
+        torch_size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
+
+
+def empty(
+    *size,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    requires_grad: bool = False,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
+    is defined by the variable argument ``size``.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.empty,
+        torch_size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
+
+
+def full(
+    size,
+    fill_value,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    requires_grad: bool = False,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match
+        ``device_mesh.device_type``.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: full((1,2,3..), fill_value) or full([1,2,3..], fill_value)
+        fill_value(Scalar): the value to fill the output tensor with.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.full,
+        torch_size,
+        fill_value=fill_value,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
+
+
+def rand(
+    *size,
+    requires_grad: bool = False,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with random numbers from a uniform distribution
+        on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
+        argument ``size``.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.rand,
+        torch_size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
+
+
+def randn(
+    *size,
+    requires_grad: bool = False,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with random numbers from a normal distribution
+        with mean 0 and variance 1. The shape of the tensor is defined by the variable
+        argument ``size``.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.randn,
+        torch_size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
+
+
+def zeros(
+    *size,
+    requires_grad: bool = False,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Returns a :class:`DTensor` filled with the scalar value 0.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
+            Can be a variable number of arguments or a collection like a list or tuple.
+            E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))
+
+    Keyword args:
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned :class:`DTensor`. Default: ``False``.
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
+            Default: ``torch.strided``.
+        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
+        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``
+
+    Returns:
+        A :class:`DTensor` object on each rank
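+
+    Example (an illustrative sketch, not part of the upstream docstring; it assumes
+    a 2-rank job with an initialized default process group)::
+
+        >>> from torch.distributed._tensor import Replicate, zeros
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> mesh = init_device_mesh("cuda", (2,))
+        >>> # every rank holds a full (4, 4) tensor of zeros
+        >>> dt = zeros(4, 4, device_mesh=mesh, placements=[Replicate()])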
+    """
+    torch_size = normalize_to_torch_size(size)
+
+    return _dtensor_init_helper(
+        torch.zeros,
+        torch_size,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+        device_mesh=device_mesh,
+        placements=placements,
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b7f1a97ddde89cdf5e16281d7642924ee8daeca
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_collective_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_collective_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..053e8ffe80ea0f086a7cf87bee54693836a0bbf1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_collective_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd29b73947615309db4837987dcf569a46b18ae9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dd4c3bd3fc81197e9e16ea5ef1fbb16b540f18b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/device_mesh.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/device_mesh.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84decc8de8cc2fbfeb326d783af0ef2f6563a149
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/device_mesh.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/dispatch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/dispatch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16b21f5945b282c0ebe0fbcdebc7f68c453b1bb8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/dispatch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/op_schema.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/op_schema.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b9cb1d58cf93db23b444f25add94c746ed0fda4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/op_schema.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05400981c736888571e4e4b0e250b60be723c333
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/random.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/random.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..729142ab3d3140c73f7e7e1ce5c9be15eca71176
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/random.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/redistribute.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/redistribute.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..864abcc7f4510e0ea4e22ba526f0aab218fdfbfb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/redistribute.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/sharding_prop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/sharding_prop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b038dd055692de693e3d7f743a9256ca454fcb54
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/sharding_prop.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/tp_conv.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/tp_conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4aa1da6df197117d1b7e58beda383c8ad0f812b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/__pycache__/tp_conv.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/_collective_utils.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/_collective_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb5ba70d71991f1344daad89b22d8df57b54802
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/_collective_utils.py
@@ -0,0 +1,313 @@
+import logging
+import math
+from dataclasses import dataclass
+from functools import lru_cache
+
+from typing import List, Optional
+
+import torch
+import torch.distributed._tensor.placement_types as placement_types
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import (
+    all_to_all,
+    broadcast,
+    get_global_rank,
+    get_rank,
+    get_world_size,
+    GroupMember,
+    ProcessGroup,
+    scatter,
+    Work,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: we need to migrate these APIs to be functional collectives
+
+
+def mesh_scatter(
+    output: torch.Tensor,
+    scatter_list: List[torch.Tensor],
+    mesh: DeviceMesh,
+    mesh_dim: int = 0,
+    async_op: bool = False,
+) -> Optional[Work]:
+    """
+    Scatter a list of tensors across a device mesh dimension. By default we
+    use the first rank of the mesh dimension as the source of truth, i.e.,
+    for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we
+    scatter the tensor list on rank 0 to ranks 0/1, and the tensor list on
+    rank 2 to ranks 2/3.
+
+    Args:
+        output (torch.Tensor): the tensor to receive the scattered result.
+        scatter_list (List[torch.Tensor]): the tensor list to be scattered.
+        mesh (DeviceMesh): the device mesh to scatter on.
+        mesh_dim (int, optional): the mesh dimension to scatter on; the first
+            rank on that mesh dimension is used as the source of truth.
+        async_op (bool, optional): whether to run the collective asynchronously.
+
+    Returns:
+        A :class:`Work` handle if ``async_op`` is True, otherwise ``None``
+        (``None`` is also returned for meta tensors, which skip communication).
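+
+    Example (an illustrative sketch, not part of the upstream docstring; it assumes
+    a 4-rank job with an initialized backend and one GPU per rank)::
+
+        >>> mesh = DeviceMesh("cuda", list(range(4)))
+        >>> output = torch.empty(2, device="cuda")
+        >>> to_scatter = [torch.full((2,), float(r), device="cuda") for r in range(4)]
+        >>> mesh_scatter(output, to_scatter, mesh, mesh_dim=0)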
+    """
+    # TODO: Ideally we should use the meta tensor way
+    # (to register a meta kernel for the collective op)
+    # so that it would avoid the communication. Need to
+    # remove the check below once that is done.
+    if output.is_meta:
+        return None
+    dim_group = mesh.get_group(mesh_dim)
+    assert isinstance(dim_group, ProcessGroup)
+    # src needs to be a global rank
+    src_for_dim = 0
+
+    if dim_group is not GroupMember.WORLD:
+        src_for_dim = get_global_rank(dim_group, 0)
+
+    if src_for_dim == get_rank():
+        fut = scatter(
+            output,
+            scatter_list=scatter_list,
+            src=src_for_dim,
+            group=dim_group,
+            async_op=async_op,
+        )
+    else:
+        fut = scatter(
+            output,
+            scatter_list=None,
+            src=src_for_dim,
+            group=dim_group,
+            async_op=async_op,
+        )
+
+    return fut
+
+
+def mesh_broadcast(
+    tensor: torch.Tensor,
+    mesh: DeviceMesh,
+    mesh_dim: int = 0,
+    async_op: bool = False,
+) -> Optional[Work]:
+    """
+    Broadcast the tensor to a device mesh dimension. By default we use the
+    first rank of the mesh dimension as the source of truth, i.e., for a
+    2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we broadcast
+    the tensor on rank 0 to ranks 0/1, and the tensor on rank 2 to ranks 2/3.
+
+    Args:
+        tensor (torch.Tensor): tensor to broadcast.
+        mesh (DeviceMesh): the device mesh to broadcast on.
+        mesh_dim (int, optional): the mesh dimension to broadcast on; the first
+            rank on that mesh dimension is used as the source of truth.
+        async_op (bool, optional): whether to run the collective asynchronously.
+
+    Returns:
+        A :class:`Work` handle if ``async_op`` is True, otherwise ``None``
+        (``None`` is also returned for meta tensors, which skip communication).
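+
+    Example (an illustrative sketch, not part of the upstream docstring; it assumes
+    a 4-rank job with an initialized backend and one GPU per rank)::
+
+        >>> mesh = DeviceMesh("cuda", list(range(4)))
+        >>> t = torch.empty(2, device="cuda") if get_rank() != 0 else torch.arange(2.0, device="cuda")
+        >>> # after the call, every rank on the mesh dimension holds rank 0's data in `t`
+        >>> mesh_broadcast(t, mesh, mesh_dim=0)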
+    """
+    # TODO: Ideally we should use the meta tensor way
+    # (to register a meta kernel for the collective op)
+    # so that it would avoid the communication. Need to
+    # remove the check below once that is done.
+    if tensor.is_meta:
+        return None
+    dim_group = mesh.get_group(mesh_dim)
+    assert isinstance(dim_group, ProcessGroup)
+    # src needs to be a global rank
+    src_for_dim = 0
+    if dim_group is not GroupMember.WORLD:
+        src_for_dim = get_global_rank(dim_group, 0)
+
+    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
+
+
+# TODO: test uneven split on GLOO and NCCL
+def mesh_all_to_all(
+    output_tensor_list: List[torch.Tensor],
+    input_tensor_list: List[torch.Tensor],
+    mesh: DeviceMesh,
+    mesh_dim: int = 0,
+    async_op: bool = False,
+) -> Optional[Work]:
+    dim_group = mesh.get_group(mesh_dim)
+    assert isinstance(dim_group, ProcessGroup)
+
+    work = None
+    # no direct dist.all_to_all support on 'gloo' so we manually do scatters
+    if mesh.device_type == "cpu":
+        logger.warning(
+            "ProcessGroupGloo does not support all_to_all, falling back with scatters!"
+        )
+        # TODO: pull the handle of uneven case in #492
+        dim_group_size = get_world_size(dim_group)
+        for i in range(dim_group_size):
+            # src needs to be a global rank
+            src_for_dim = i
+            if dim_group is not GroupMember.WORLD:
+                src_for_dim = get_global_rank(dim_group, i)
+
+            work = scatter(
+                output_tensor_list[i],
+                input_tensor_list if mesh.get_rank() == src_for_dim else [],
+                group=dim_group,
+                src=src_for_dim,
+                async_op=async_op,
+            )
+    else:
+        work = all_to_all(
+            output_tensor_list,
+            input_tensor_list,
+            dim_group,
+            async_op=async_op,
+        )
+    return work
+
+
+def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
+    assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
+    return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
+
+
+@dataclass
+class MeshTopoInfo:
+    """
+    Mesh information for collective cost estimation
+    """
+
+    mesh: DeviceMesh
+    mesh_dim_devices: List[int]
+    mesh_dim_bandwidth: List[float]
+    mesh_dim_latency: List[float]
+
+    @staticmethod
+    @lru_cache(None)
+    def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
+        # Generate mesh topology info for intra-host/inter-host communication pattern
+        # Note that we make a bunch of simplifying assumptions:
+        # 1. the mesh is homogeneous and uses the GPU/NCCL model
+        # 2. the GPU arch is Ampere or Hopper
+        # 3. all collectives use ring-based algorithms for now
+        num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
+        # the base bw number (intra-node), GB/s
+        base_bw = 87.7
+        mesh_dim_bandwidth = [base_bw] * mesh.ndim
+        # the latency in terms of us (intra-node, nv-link)
+        mesh_dim_latency = [0.6] * mesh.ndim
+        mesh_dim_devices = [1] * mesh.ndim
+
+        total_num_devices = 1
+        for mesh_dim in reversed(range(mesh.ndim)):
+            num_devices = mesh.size(mesh_dim)
+            mesh_dim_devices[mesh_dim] = num_devices
+            total_num_devices *= num_devices
+            if total_num_devices > num_devices_per_host:
+                # magic number for inter-host communication bandwidth/latency factor
+                # This number assumes latest GPU arch, i.e. Ampere or Hopper
+                # TODO: see if we need to tweak this or offer a way for user
+                # to specify the bandwidths/latency
+                mesh_dim_bandwidth[mesh_dim] *= 0.22
+                # set to ethernet latency for inter-host
+                mesh_dim_latency[mesh_dim] = 2.7
+
+        return MeshTopoInfo(
+            mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
+        )
+
+
+def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    num_hops = num_devices_on_mesh_dim - 1
+    # base latency + comm latency
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]  # us
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth  # s
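+    # Worked example with illustrative numbers (not from the original source):
+    # gathering 1 GB across 4 intra-host devices gives num_hops = 3,
+    # latency = 6.6 + 3 * 0.6 = 8.4 us, bw = (1 * 3 / 4) / 87.7 ~= 0.00855 s,
+    # so the returned cost is roughly 8.4 + 8551.9 ~= 8560 us.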
+    return latency + bw * 1e6  # rescale to us
+
+
+def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    # allreduce has almost 2x the comm bytes compared to allgather/reduce_scatter
+    num_hops = 2 * num_devices_on_mesh_dim - 1
+
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
+    return latency + bw * 1e6
+
+
+def reduce_scatter_cost(
+    bytes_gb: float,
+    mesh_topo: MeshTopoInfo,
+    mesh_dim: int,
+) -> float:
+    num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
+    mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
+    num_hops = num_devices_on_mesh_dim - 1
+    # base latency + comm latency
+    latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
+    bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
+    return latency + bw * 1e6
+
+
+def redistribute_cost(
+    current_spec: "placement_types.DTensorSpec",
+    target_spec: "placement_types.DTensorSpec",
+) -> float:
+    """
+    This function returns the estimated cost of redistributing from the current to the target DTensorSpec.
+
+    NOTE:
+    1. Only communication cost is considered here, since the computation cost of a
+       redistribute is quite trivial (i.e., only narrowing or a simple division).
+    2. Only redistribution on the same mesh is considered; cross-mesh communication
+       cost is not needed for operator strategy estimation/selection.
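+
+    Example (an illustrative sketch, not part of the upstream docstring): redistributing
+    a Shard(0) spec to Replicate on a single mesh dimension is costed as one allgather
+    of the full (unsharded) tensor bytes on that dimension.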
+    """
+    if current_spec.mesh != target_spec.mesh:
+        # return infinite cost if the meshes are not the same
+        # TODO: see if we want to support this once there's cross mesh communication
+        return float("inf")
+
+    if current_spec.is_replicated():
+        # short-cut:
+        # comm cost is 0 if current spec is already full replication
+        return 0.0
+
+    mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
+    cost = 0.0
+    comm_bytes_gb = (
+        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
+    )
+    # Transformations considered for the redistribute cost:
+    # 1. allgather 2. alltoall
+    # 3. allreduce 4. reduce_scatter
+    for i, (current, target) in enumerate(
+        zip(current_spec.placements, target_spec.placements)
+    ):
+        if current == target:
+            continue
+
+        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
+        if current.is_shard() and target.is_replicate():
+            # allgather gives larger comm bytes
+            comm_bytes_gb *= num_devices_on_mesh_dim
+            # add up allgather comm cost
+            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+        elif current.is_shard() and target.is_shard():
+            # this should be an alltoall comm; since we haven't implemented it yet,
+            # add a penalty to favor allgather instead
+            cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
+        elif current.is_partial() and target.is_replicate():
+            # add up allreduce comm cost
+            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
+        elif current.is_partial() and target.is_shard():
+            # add up reduce_scatter comm cost
+            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            # after reduce_scatter, the comm bytes for further collectives shrink by the mesh-dim size.
+            comm_bytes_gb /= num_devices_on_mesh_dim
+        elif current.is_shard() and target.is_partial():
+            # ban shard -> partial as it does not make sense to perform
+            # this redistribute
+            return float("inf")
+
+    return cost
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/_utils.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ee7b188f1fc34fdbcfd8c4c0fc015c99886362
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/_utils.py
@@ -0,0 +1,204 @@
+from typing import cast, List, Sequence, Tuple
+
+import torch
+import torch.distributed._tensor.api as dtensor
+from torch._prims_common import ShapeType
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
+
+
+# TODO: audit existing code base to see if we can safely remove this API.
+def compute_local_shape(
+    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
+) -> Tuple[int, ...]:
+    """
+    Compute the shape of a local shard of the given DTensor on its current
+    coordinate of the mesh.
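+
+    Example (an illustrative sketch, not part of the upstream docstring): a global shape
+    of (8, 4) on a 1-D mesh of 4 devices with placements [Shard(0)] gives a local shape
+    of (2, 4) on every rank.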
+    """
+    my_coordinate = mesh.get_coordinate()
+
+    if my_coordinate is None:
+        # if rank not in the mesh, return empty shape
+        return (0,)
+    else:
+        local_shape = list(global_shape)  # start with global shape
+        ndim = len(global_shape)
+        for idx, placement in enumerate(placements):
+            mesh_dim_size = mesh.size(idx)
+            if isinstance(placement, Shard):
+                shard_dim = placement.dim
+                assert (
+                    shard_dim < ndim
+                ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}"
+                local_shard_size, _ = placement._local_shard_size_on_dim(
+                    local_shape[shard_dim], mesh_dim_size, my_coordinate[idx]
+                )
+                assert isinstance(local_shard_size, int)
+                local_shape[shard_dim] = local_shard_size
+
+        return tuple(local_shape)
+
+
+def compute_local_shape_and_global_offset(
+    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
+) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """
+    Compute the local tensor shape and the global offsets into the original tensor
+    of a DTensor on its current global rank. This is useful for checkpointing purposes.
+
+    Example (2 hosts with 4 GPUs each):
+    # Below is a DeviceMesh with mesh_shape of (2, 4)
+    mesh = DeviceMesh(device_type="cuda",
+                        mesh=[
+                        [0, 1, 2, 3],
+                        [4, 5, 6, 7]
+                        ],
+    )
+
+    Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh
+    with placements of [Shard(0), Shard(0)].
+    The local shape and global offset will be as follows:
+    rank0 -- local_shape:[1, 4], global_offset:[0, 0]
+    rank1 -- local_shape:[1, 4], global_offset:[1, 0]
+    rank2 -- local_shape:[1, 4], global_offset:[2, 0]
+    rank5 -- local_shape:[1, 4], global_offset:[5, 0]
+    rank3 -- local_shape:[1, 4], global_offset:[3, 0]
+    rank4 -- local_shape:[1, 4], global_offset:[4, 0]
+    rank6 -- local_shape:[1, 4], global_offset:[6, 0]
+    rank7 -- local_shape:[1, 4], global_offset:[7, 0]
+
+    Let's say we distribute a global_tensor of shape (2,) over the above DeviceMesh with
+    placements of [Shard(0)]. Not all ranks will have a non-empty local tensor.
+    The local shape and global offset will be as follows:
+    rank0 -- local_shape:[1,], global_offset:[0,]
+    rank1 -- local_shape:[1,], global_offset:[1,]
+    rank2 -- local_shape:[0,], global_offset:[2,]
+    rank5 -- local_shape:[0,], global_offset:[2,]
+    rank3 -- local_shape:[0,], global_offset:[2,]
+    rank4 -- local_shape:[0,], global_offset:[2,]
+    rank6 -- local_shape:[0,], global_offset:[2,]
+    rank7 -- local_shape:[0,], global_offset:[2,]
+    """
+    my_coordinate = mesh.get_coordinate()
+
+    if my_coordinate is None:
+        # if rank not in the mesh, return empty offset
+        return ((), ())
+    else:
+        local_shape = list(global_shape)
+        global_offset = [0] * len(global_shape)
+
+        for idx, placement in enumerate(placements):
+            mesh_dim_size = mesh.size(idx)
+            if isinstance(placement, Shard):
+                shard_dim = placement.dim
+                local_offset = [0] * len(global_shape)
+                assert shard_dim < len(
+                    local_shape
+                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+                shard_size, shard_offset = placement._local_shard_size_on_dim(
+                    local_shape[shard_dim],
+                    mesh_dim_size,
+                    my_coordinate[idx],
+                    return_offset=True,
+                )
+
+                local_shape[shard_dim] = shard_size
+                local_offset[shard_dim] = shard_offset
+
+                # On a given dimension, if local_offset[shard_dim] is smaller than global_offset[shard_dim],
+                # this dimension has already been sharded by a previous placement.
+                # Therefore, we cannot simply replace global_offset[shard_dim] with local_offset[shard_dim].
+                # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to the existing global_offset[shard_dim].
+                if global_offset[shard_dim] <= local_offset[shard_dim]:
+                    global_offset[shard_dim] = local_offset[shard_dim]
+                else:
+                    global_offset[shard_dim] += local_offset[shard_dim]
+
+        return tuple(local_shape), tuple(global_offset)
+
+
+def compute_global_tensor_info(
+    tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
+) -> Tuple[List[int], List[int]]:
+    """
+    Compute the global size and stride of a DTensor from the given local tensor.
+    The local size is multiplied by `world_size` per sharding dim.
+    The local stride is multiplied by `world_size` per sharding dim, for every other
+    dimension whose stride is at least as large as the stride of the sharding dim.
+
+    For example, suppose we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
+    If the DTensor placements are [Shard(2)] and the world_size is 2,
+    then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).
+
+    Args:
+        tensor (:class:`torch.Tensor`):
+            Local tensor which DTensor will be constructed from.
+        mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        placements (Sequence[:class:`Placement`]]):
+            The attribute of the DTensor that describes its layout
+            on the mesh topology.
+
+    Return:
+        tensor_shape: A List of int which specifies the size of the DTensor built
+            on top of the local tensor.
+        tensor_stride: A List of int which specifies the stride of DTensor.
+    """
+    tensor_shape = list(tensor.size())
+    tensor_stride = list(tensor.stride())
+    for idx, placement in enumerate(placements):
+        mesh_dim_size = mesh.size(idx)
+        if placement.is_shard():
+            shard_placement = cast(Shard, placement)
+            if shard_placement.dim < 0:
+                raise AssertionError(
+                    "Shard placements should have negative dims normalized in "
+                    f"the user-facing APIs: {shard_placement}"
+                )
+            shard_dim = shard_placement.dim
+
+            assert (
+                shard_dim < tensor.ndim
+            ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
+
+            local_dim_size = tensor_shape[shard_dim]
+            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
+
+            # recover the tensor stride by scaling every other stride that is at
+            # least as large as the current stride on the shard_dim
+            for i in range(len(tensor_stride)):
+                if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
+                    # rescale the stride by the shard size
+                    tensor_stride[i] = tensor_stride[i] * mesh_dim_size
+        elif not isinstance(placement, (Replicate, _Partial)):
+            raise RuntimeError(f"placement type {type(placement)} not supported!")
+    return tensor_shape, tensor_stride
+
+
+def try_find_mesh_from_args(
+    op_call: torch._ops.OpOverload, args: Sequence[object]
+) -> DeviceMesh:
+    """
+    Find the device mesh object from args.
+    Raises a ValueError if no mesh is found.
+    NOTE: we can optimize this search if needed
+    """
+    for arg in args:
+        if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
+            return arg.device_mesh
+        elif (
+            isinstance(arg, (list, tuple))
+            and len(arg) > 0
+            and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
+        ):
+            return arg[0].device_mesh
+
+    raise ValueError(f"Cannot find device mesh from args for op : {op_call}.")
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/api.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..feabc57d2fd75d7e90b9c32215f4f40bb4a07928
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/api.py
@@ -0,0 +1,760 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import inspect
+import warnings
+from typing import Any, Callable, cast, Optional, Sequence, Tuple
+
+import torch
+
+import torch.distributed._tensor.dispatch as op_dispatch
+import torch.distributed._tensor.random as random
+import torch.nn as nn
+from torch.distributed._tensor._collective_utils import mesh_broadcast
+from torch.distributed._tensor._utils import compute_global_tensor_info
+from torch.distributed._tensor.placement_types import (
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+    TensorMeta,
+)
+from torch.distributed._tensor.random import (
+    is_rng_supported_mesh,
+    OffsetBasedRNGTracker,
+)
+from torch.distributed._tensor.redistribute import (
+    Redistribute,
+    redistribute_local_tensor,
+)
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+
+
+__all__ = ["DTensor", "distribute_tensor", "distribute_module"]
+
+aten = torch.ops.aten
+
+
+# NOTE [Autograd interaction between torch.Tensor]
+#
+# The autograd functions defined below are being used by the public
+# facing APIs (i.e. from_local, to_local) to ensure our DTensor
+# works together with torch.Tensor within autograd engine. This
+# allows DistributedTensor to exist on part of the module hierarchy
+# and still be able to calculate gradients across the torch.Tensor and
+# DistributedTensor boundary.
+# As an example, we have a module that consists of submodules
+# A, B, and C, the execution flow would be like:
+#  input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
+#
+# Suppose we only want to make Module B a sharded module with
+# DistributedTensor params; we would need to make the following
+# flow to work:
+#
+#  input(torch.Tensor) -> Module A
+#       -> DTensor input -> Sharded Module B -> DTensor output
+#           -> output (torch.Tensor) -> Module C -> output (torch.Tensor)
+#
+# We need the conversion from Module A to DTensor input, which is
+# `from_local`, and conversion from DTensor output to output, which
+# is `to_local`, thus these two functions must be Autograd functions.
+#
+class _ToTorchTensor(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,
+        input: "DTensor",
+        grad_placements: Optional[Sequence[Placement]],
+    ):
+        ctx.dtensor_spec = input._spec
+        ctx.grad_placements = grad_placements
+        local_tensor = input._local_tensor
+
+        # We need to return a fresh Tensor object here as autograd metadata
+        # will be written into it in place, and we don't want to pollute the Tensor
+        # object stored in the _local_tensor of this DTensor.
+        return local_tensor.view_as(local_tensor)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
+        dtensor_spec = ctx.dtensor_spec
+        mesh = dtensor_spec.mesh
+        grad_placements = ctx.grad_placements
+        dtensor_meta = dtensor_spec.tensor_meta
+
+        _, tensor_stride = compute_global_tensor_info(
+            grad_output, mesh, dtensor_spec.placements
+        )
+        tensor_stride = tuple(tensor_stride)
+        grad_placements = grad_placements or dtensor_spec.placements
+
+        return (
+            DTensor(
+                grad_output,
+                mesh,
+                grad_placements,
+                shape=dtensor_meta.shape,
+                dtype=dtensor_meta.dtype,
+                requires_grad=grad_output.requires_grad,
+                stride=tensor_stride,
+            ),
+            None,
+        )
+
+
+class _FromTorchTensor(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,  # pyre-ignore[2]: Parameter must be annotated.
+        input: torch.Tensor,
+        device_mesh: DeviceMesh,
+        placements: Tuple[Placement, ...],
+        run_check: bool,
+        shape: Optional[torch.Size] = None,
+        stride: Optional[Tuple[int, ...]] = None,
+    ) -> "DTensor":
+        ctx.previous_placement = placements
+        ctx.previous_device_mesh = device_mesh
+
+        if shape and stride:
+            tensor_shape, tensor_stride = shape, stride
+        elif not shape and not stride:
+            # if shape and stride are not provided, we assume the user is certain
+            # that each rank has the same tensor shape, and we just use that to
+            # calculate the global shape
+            global_shape, global_stride = compute_global_tensor_info(
+                input, device_mesh, placements
+            )
+            tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
+        else:
+            raise RuntimeError(
+                f"Found shape:{shape}, stride:{stride}.",
+                "Please pass both shape and stride at the same time.",
+            )
+
+        if device_mesh.get_coordinate() is None:
+            # if the global rank is not participating in the device mesh, we
+            # simply set the local tensor to an empty tensor
+            input = input.new_empty(0, requires_grad=input.requires_grad)
+        elif run_check:
+            # TODO: by default check tensor metas across rank
+            # TODO: See if we need to make this run_check logic
+            # have a corresponding backward.
+            for idx, placement in enumerate(placements):
+                if placement.is_replicate():
+                    # broadcast rank 0 tensor to all ranks
+                    # only broadcast if run_check is True
+                    input = input.contiguous()
+                    mesh_broadcast(input, device_mesh, mesh_dim=idx)
+
+        # We want a fresh Tensor object that shares memory with the input tensor
+        dist_tensor = DTensor(
+            input.view_as(input),
+            device_mesh,
+            placements,
+            shape=tensor_shape,
+            dtype=input.dtype,
+            # requires_grad of the dist tensor depends on if input
+            # requires_grad or not
+            requires_grad=input.requires_grad,
+            stride=tensor_stride,
+        )
+        return dist_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output: "DTensor"):  # type: ignore[override]
+        previous_placement = ctx.previous_placement
+        previous_device_mesh = ctx.previous_device_mesh
+
+        # reshard to the placement when creating DistributedTensor
+        # so that the gradient layout matches, and we could return
+        # local gradients directly
+        if grad_output.placements != previous_placement:
+            current_spec = grad_output._spec
+            target_spec = DTensorSpec(
+                previous_device_mesh,
+                previous_placement,
+                tensor_meta=grad_output._spec.tensor_meta,
+            )
+            local_tensor = grad_output._local_tensor
+            output = redistribute_local_tensor(
+                local_tensor, current_spec, target_spec, is_backward=True
+            )
+            # TODO: return the redistributed local tensor directly without
+            # differentiable backward. See if this makes sense for all cases.
+            return output, None, None, None, None, None
+
+        # TODO: backward is also differentiable now, add a test
+        # to test higher level gradients.
+        return grad_output.to_local(), None, None, None, None, None
+
+
+class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
+    _local_tensor: torch.Tensor
+    _spec: DTensorSpec
+    __slots__ = ["_local_tensor", "_spec"]
+
+    # class attribute that handles operator placements propagation
+    # rules, keyed by aten op name, value is propagation func
+    _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
+
+    @staticmethod
+    def __new__(
+        cls,
+        local_tensor: torch.Tensor,
+        device_mesh: DeviceMesh,
+        placements: Tuple[Placement, ...],
+        *,
+        shape: torch.Size,
+        dtype: torch.dtype,
+        requires_grad: bool,
+        stride: Tuple[int, ...],
+    ) -> "DTensor":
+        """
+        Construct a DTensor from a local tensor, device mesh, and placements, plus
+        other tensor properties (i.e., shape, dtype, requires_grad, and stride).
+        Note: This is not a public API and it's only supposed to be used by the
+            operator implementations and internals. If you want to construct a
+            DTensor from a local tensor, consider using `DTensor.from_local`, if
+            you want to construct a DTensor from a "global" tensor (where you
+            already have tensor initialized and want to shard this tensor),
+            consider using `distribute_tensor`.
+        """
+        if local_tensor.requires_grad and not requires_grad:
+            warnings.warn(
+                "To construct DTensor from torch.Tensor, it's recommended to "
+                "use local_tensor.detach() and make requires_grad consistent."
+            )
+
+        # construct the wrapper tensor subclass from local_tensor and attach the
+        # placement spec; this does not do any actual distribution
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            shape,
+            strides=stride,
+            dtype=dtype,
+            device=local_tensor.device,
+            layout=local_tensor.layout,
+            requires_grad=requires_grad,
+        )
+
+        tensor_meta = TensorMeta(shape, stride, dtype)
+        # set the spec and the local tensor on the wrapper
+        r._spec = DTensorSpec(device_mesh, placements, tensor_meta=tensor_meta)
+        r._local_tensor = local_tensor
+        return r
+
+    # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    def __repr__(self):
+        # TODO: consider all_gather the local tensors for better debugging
+        return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
+
+    def __tensor_flatten__(self):
+        """
+        protocol to inform how to flatten a DTensor to local tensor
+        for PT2 tracing
+        """
+        return ["_local_tensor"], (self._spec, self.requires_grad)
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        assert (
+            flatten_spec is not None
+        ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
+        local_tensor = inner_tensors["_local_tensor"]
+        spec, requires_grad = flatten_spec
+        return DTensor(
+            local_tensor,
+            spec.mesh,
+            spec.placements,
+            shape=outer_size,
+            dtype=spec.tensor_meta.dtype,
+            requires_grad=requires_grad,
+            stride=outer_stride,
+        )
+
+    @classmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        return DTensor._op_dispatcher.dispatch(
+            func,
+            args,
+            kwargs or {},
+        )
+
+    @staticmethod
+    def from_local(
+        local_tensor: torch.Tensor,
+        device_mesh: Optional[DeviceMesh] = None,
+        placements: Optional[Sequence[Placement]] = None,
+        *,
+        run_check: bool = True,
+        shape: Optional[torch.Size] = None,
+        stride: Optional[Tuple[int, ...]] = None,
+    ) -> "DTensor":
+        """
+        Create a :class:`DTensor` from a local torch.Tensor on each rank
+        according to the `device_mesh` and `placements` specified.
+
+        Args:
+            local_tensor (torch.Tensor): local torch.Tensor on each rank.
+            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+                tensor, if not specified, must be called under a DeviceMesh
+                context manager, default: None
+            placements (List[:class:`Placement`], optional): the placements that
+                describes how to place the local torch.Tensor on DeviceMesh, must
+                have the same number of elements as `device_mesh.ndim`. If not
+                specified, we will by default replicate the tensor across the
+                `device_mesh` from the first rank of each dimension of the `device_mesh`.
+
+        Keyword args:
+            run_check (bool, optional): indicate whether to run checks across ranks
+                to verify meta information and data. If there is a :class:`Replicate`
+                placement in `placements`, the data on the first rank of that device
+                mesh dimension will be broadcast to the other ranks.
+            shape (torch.Size, optional): A list of int which specifies the size of the
+                DTensor built on top of `local_tensor`. Note this needs to be
+                provided if the shape of `local_tensor` differs across the ranks.
+                If not provided, `shape` will be computed assuming the given distributed
+                tensor is evenly sharded across ranks.
+            stride (tuple, optional): A List of int which specifies the stride of DTensor.
+                If not provided, `stride` will be computed assuming the given distributed
+                tensor is evenly sharded across ranks.
+
+        Returns:
+            A :class:`DTensor` object
+
+        .. note:: `from_local` is differentiable, the `requires_grad` of the created
+            `DTensor` object will depend on if `local_tensor` requires_grad or not.
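+
+        Example (an illustrative sketch, not part of the upstream docstring; it assumes
+        a 2-rank job with an initialized default process group and a 1-D mesh)::
+
+            >>> from torch.distributed.device_mesh import init_device_mesh
+            >>> mesh = init_device_mesh("cuda", (2,))
+            >>> local = torch.randn(4, 8)  # each rank supplies its own local shard
+            >>> dt = DTensor.from_local(local, mesh, [Shard(0)])
+            >>> dt.shape  # the logical shape concatenates the shards along dim 0
+            torch.Size([8, 8])
+            >>> dt.to_local().shape  # back to the rank-local shard
+            torch.Size([4, 8])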
+        """
+        # if shape/dtype are the same across ranks, there is no need to run_check;
+        # if not, we must allgather the metadata to check the size/dtype across ranks.
+        # There should be no data communication unless there's a replication
+        # strategy, where we broadcast the replicated data from the first rank
+        # in the mesh dimension
+        device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+        device_type = device_mesh.device_type
+
+        # convert the local tensor to the desired device based on the device mesh's device_type
+        if device_type != local_tensor.device.type and not local_tensor.is_meta:
+            local_tensor = local_tensor.to(device_type)
+
+        # set default placements to replicated if not specified
+        if placements is None:
+            placements = [Replicate() for _ in range(device_mesh.ndim)]
+        else:
+            placements = list(placements)
+            for idx, placement in enumerate(placements):
+                # normalize shard dim to be positive
+                if placement.is_shard():
+                    placement = cast(Shard, placement)
+                    if placement.dim < 0:
+                        placements[idx] = Shard(placement.dim + local_tensor.ndim)
+
+        # `from_local` is differentiable, and the gradient of the dist tensor this function
+        # creates should flow back to the local_tensor, so we call an autograd
+        # function to construct the dist tensor instead.
+        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
+            local_tensor,
+            device_mesh,
+            tuple(placements),
+            run_check,
+            shape,
+            stride,
+        )
+
+    def to_local(
+        self, *, grad_placements: Optional[Sequence[Placement]] = None
+    ) -> torch.Tensor:
+        """
+        Get the local tensor of this DTensor on its current rank. For sharding it returns
+        a local shard of the logical tensor view; for replication it returns the replica on
+        its current rank.
+
+        Keyword args:
+            grad_placements (List[:class:`Placement`], optional): the placements that describe
+                the desired layout of the gradients of the Tensor returned from this
+                function.
+                `to_local` converts DTensor to a local tensor and the returned local tensor
+                might not be used with the original DTensor layout later in the code. This
+                argument is the hint that the user can give to autograd in case the gradient
+                layout of the returned tensor does not match the original DTensor layout.
+                If not specified, we will assume the gradient layout remains the same
+                as the original DTensor and use that for gradient computation.
+
+        Returns:
+            A :class:`torch.Tensor` or `AsyncCollectiveTensor` object. It represents the
+            local tensor on its current rank.
+
+        .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned
+            will depend on if the `DTensor` requires_grad or not.
+        """
+        if grad_placements is not None and not isinstance(grad_placements, tuple):
+            grad_placements = tuple(grad_placements)
+        return _ToTorchTensor.apply(
+            self, grad_placements
+        )  # pyre-ignore[16]: autograd func
+
+    def redistribute(
+        self,
+        device_mesh: Optional[DeviceMesh] = None,
+        placements: Optional[Sequence[Placement]] = None,
+        *,
+        async_op: bool = False,
+    ) -> "DTensor":
+        """
+        `redistribute` performs the collective operations needed to redistribute the current
+        DTensor from its current placements to new placements, or from its current DeviceMesh
+        to a new DeviceMesh, i.e., we can turn a sharded DTensor into a replicated DTensor by
+        specifying a Replicate placement for each dimension of the DeviceMesh.
+
+        Args:
+            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+                DTensor, if not specified, must be called under a DeviceMesh
+                context manager, default: None
+            placements (List[:class:`Placement`], optional): the new placements that
+                describes how to place the DTensor into the DeviceMesh, must
+                have the same number of elements as `device_mesh.ndim`.
+
+        Keyword args:
+            async_op (bool, optional): whether to perform the DTensor redistribute operation
+                asynchronously or not. Default: False
+
+        Returns:
+            A :class:`DTensor` object
+
+        .. note:: `redistribute` is differentiable.
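+
+        Example (an illustrative sketch, not part of the upstream docstring; it assumes
+        ``dt`` is a DTensor sharded on dim 0 over a 1-D mesh)::
+
+            >>> replicated = dt.redistribute(placements=[Replicate()])
+            >>> sharded_again = replicated.redistribute(placements=[Shard(0)])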
+        """
+        # NOTE: This redistribute API currently only supports out-of-place
+        # redistribution, i.e., it always creates a new DTensor object and
+        # leaves the original one unchanged.
+
+        # if device_mesh is not specified, use the current device_mesh
+        device_mesh = device_mesh or self.device_mesh
+        # raise error if new placements not specified
+        if placements is None:
+            raise RuntimeError("placements is needed for redistribute!")
+
+        placements = list(placements)
+        for i, placement in enumerate(placements):
+            if placement.is_partial():
+                raise RuntimeError(
+                    "Can not redistribute to _Partial, _Partial is for internal use only!"
+                )
+            elif isinstance(placement, Shard) and placement.dim < 0:
+                # normalize shard dim to be positive
+                placements[i] = Shard(placement.dim + self.ndim)
+        placements = tuple(placements)
+
+        # Early return the original DTensor if the placements are the same.
+        if self._spec.placements == placements:
+            return self
+
+        # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
+        return Redistribute.apply(self, device_mesh, placements, async_op)
+
+    def full_tensor(
+        self, *, grad_placements: Optional[Sequence[Placement]] = None
+    ) -> torch.Tensor:
+        """
+        Return the full tensor of this DTensor. It will perform the necessary collectives
+        to gather the local tensors from other ranks in its DeviceMesh and concatenate
+        them together. It is syntactic sugar for the following code:
+
+        `dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`
+
+        Keyword args:
+            grad_placements (List[:class:`Placement`], optional): the placements that describe
+                the desired layout of the gradients of the full Tensor returned from this
+                function.
+                `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.Tensor
+                might not be used with the original replicated DTensor layout later in the code. This
+                argument is the hint that the user can give to autograd in case the gradient
+                layout of the returned tensor does not match the original replicated DTensor layout.
+                If not specified, we will assume the gradient layout of the full tensor to be replicated.
+
+        Returns:
+            A :class:`torch.Tensor` object that represents the full tensor of this DTensor.
+
+        .. note:: `full_tensor` is differentiable.
+        """
+
+        redist_res = self.redistribute(
+            placements=[Replicate()] * self.device_mesh.ndim, async_op=False
+        )
+        return _ToTorchTensor.apply(redist_res, grad_placements)
+
+    @property
+    def device_mesh(self) -> DeviceMesh:
+        """
+        The :class:`DeviceMesh` attribute that associates with this DTensor object.
+
+        .. note:: device_mesh is a read-only property, it can not be set.
+        """
+        return self._spec.mesh
+
+    @property
+    def placements(self) -> Sequence[Placement]:
+        """
+        The placements attribute of this DTensor that describes the layout of this
+        DTensor on its DeviceMesh.
+
+        .. note:: placements is a read-only property, it can not be set.
+        """
+        return self._spec.placements
+
+
+def distribute_tensor(
+    tensor: torch.Tensor,
+    device_mesh: Optional[DeviceMesh] = None,
+    placements: Optional[Sequence[Placement]] = None,
+) -> DTensor:
+    """
+    Distribute a torch.Tensor to the `device_mesh` according to the `placements`
+    specified. The number of dimensions of `device_mesh` must match the length of `placements`.
+
+    Args:
+        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
+            want to shard a tensor on a dimension whose size is not evenly divisible by
+            the number of devices in that mesh dimension, we use `torch.chunk`
+            semantics to shard the tensor and scatter the shards.
+        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
+            tensor, if not specified, must be called under a DeviceMesh context
+            manager, default: None
+        placements (List[:class:`Placement`], optional): the placements that
+            describes how to place the tensor on DeviceMesh, must have the same
+            number of elements as `device_mesh.ndim`. If not specified, we will
+            by default replicate the tensor across the `device_mesh` from the
+            first rank of each dimension of the `device_mesh`.
+
+    Returns:
+        A :class:`DTensor` or `XLAShardedTensor` object.
+
+    Note:
+        When the DeviceMesh is initialized with the `xla` device_type, `distribute_tensor`
+        returns an `XLAShardedTensor` instead. See [link](https://github.com/pytorch/pytorch/issues/92909)
+        for more details. The XLA integration is experimental and subject to change.
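+
+    Example (an illustrative sketch, not part of the upstream docstring; it assumes
+    a 4-rank job with an initialized default process group)::
+
+        >>> from torch.distributed._tensor import Shard, distribute_tensor
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> mesh = init_device_mesh("cuda", (4,))
+        >>> big = torch.randn(16, 16)
+        >>> # each rank ends up holding a (4, 16) local shard of `big`
+        >>> dt = distribute_tensor(big, mesh, placements=[Shard(0)])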
+    """
+
+    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")
+
+    # get default device mesh if there's nothing specified
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+    device_type = device_mesh.device_type
+    if device_type == "xla":
+        try:
+            # call PyTorch/XLA SPMD for `xla` backend type device mesh.
+            # This returns XLAShardedTensor
+            from torch_xla.distributed.spmd import (  # type:ignore[import]
+                xla_distribute_tensor,
+            )
+
+            return xla_distribute_tensor(
+                tensor, device_mesh, placements
+            )  # type:ignore[return-value]
+        except ImportError as e:
+            msg = "To use DTensor API with xla, you must install the torch_xla package!"
+            raise ImportError(msg) from e
+
+    # instantiate an RNG tracker if we haven't already. By default DTensor uses an
+    # OffsetBasedRNGTracker to perform random operators.
+    # TODO: the value assignment to global variable is not the ideal solution
+    # we can replace it in future.
+    if is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
+        random._rng_tracker = OffsetBasedRNGTracker(device_type)
+
+    if not tensor.is_leaf:
+        raise RuntimeError(
+            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
+        )
+
+    # convert tensor to the corresponding device type if it's not in that device type
+    if device_type != tensor.device.type and not tensor.is_meta:
+        tensor = tensor.to(device_type)
+
+    # set default placements to replicated if not specified
+    if placements is None:
+        placements = [Replicate() for _ in range(device_mesh.ndim)]
+
+    if len(placements) != device_mesh.ndim:
+        raise ValueError(
+            f"`placements` must have the same length as `device_mesh.ndim`! "
+            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
+        )
+    if isinstance(tensor, DTensor):
+        # if the tensor is already a DTensor, we just need to check if the
+        # device mesh and placements are the same
+        if tensor.device_mesh != device_mesh:
+            raise ValueError(
+                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
+                f"to a different device mesh {device_mesh}."
+            )
+        if tensor.placements != tuple(placements):
+            raise ValueError(
+                f"Cannot distribute a DTensor with placements {tensor.placements} "
+                f"to a different placements {placements}. do you want to call "
+                f"`redistribute` instead?"
+            )
+        return tensor
+
+    local_tensor = tensor
+
+    # distribute the tensor according to the placements.
+    placements = list(placements)
+    for idx, placement in enumerate(placements):
+        if placement.is_shard():
+            placement = cast(Shard, placement)
+            if placement.dim < 0:
+                # normalize shard placement dim
+                placement = Shard(placement.dim + tensor.ndim)
+                placements[idx] = placement
+            local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
+        elif placement.is_replicate():
+            placement = cast(Replicate, placement)
+            local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
+        else:
+            raise RuntimeError(
+                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
+            )
+    placements = tuple(placements)
+
+    assert local_tensor is not None, "distributing a tensor should not be None"
+    # detach the local tensor passed to DTensor since after the construction
+    # of DTensor, autograd would work on top of DTensor instead of local tensor
+    return DTensor(
+        local_tensor.detach().requires_grad_(tensor.requires_grad),
+        device_mesh,
+        placements,
+        shape=tensor.size(),
+        dtype=tensor.dtype,
+        requires_grad=tensor.requires_grad,
+        stride=tensor.stride(),
+    )
+
+
+def distribute_module(
+    module: nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
+    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
+    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
+) -> nn.Module:
+    """
+    This function converts all module parameters to :class:`DTensor` parameters
+    according to the `partition_fn` specified. It can also control the input or
+    output of the module by specifying the `input_fn` and `output_fn` (i.e. convert
+    the input to :class:`DTensor`, convert the output back to torch.Tensor).
+
+    Args:
+        module (:class:`nn.Module`): user module to be partitioned.
+        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
+        partition_fn (Callable): the function to partition parameters (i.e. shard certain
+            parameters across the `device_mesh`). If `partition_fn` is not specified,
+            by default we replicate all module parameters of `module` across the mesh.
+        input_fn (Callable): specify the input distribution, i.e. it can control how the
+            input of the module is sharded. `input_fn` will be installed as a module
+            `forward_pre_hook` (pre forward hook).
+        output_fn (Callable): specify the output distribution, i.e. it can control how the
+            output is sharded, or convert it back to torch.Tensor. `output_fn` will be
+            installed as a module `forward_hook` (post forward hook).
+
+    Returns:
+        A module that contains parameters/buffers that are all `DTensor`s.
+
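+    Example (an illustrative sketch, not part of the original docstring; the mesh
+    shape and ``MyModule`` are placeholders)::
+
+        mesh = init_device_mesh("cuda", (4,))
+
+        def shard_params(name, module, device_mesh):
+            # a simple partition_fn: shard every nn.Linear parameter on dim 0
+            if isinstance(module, nn.Linear):
+                for pname, param in module.named_parameters():
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Shard(0)])
+                    )
+                    module.register_parameter(pname, dist_param)
+
+        sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_params)
+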
+    Note:
+        When initializing the DeviceMesh with the `xla` device_type, `distribute_module`
+        returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See [link](https://github.com/pytorch/pytorch/issues/92909)
+        for more details. The XLA integration is experimental and subject to change.
+    """
+
+    torch._C._log_api_usage_once("torch.dtensor.distribute_module")
+
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+    device_type = device_mesh.device_type
+    if device_type == "xla":
+        try:
+            # This function annotates all module parameters for auto-partitioning with
+            # PyTorch/XLA SPMD, or explicitly partitions them into :class:`XLAShardedTensor`
+            # parameters according to the `partition_fn` specified.
+            from torch_xla.distributed.spmd import (  # type:ignore[import]
+                xla_distribute_module,
+            )
+
+            return xla_distribute_module(
+                module, device_mesh, partition_fn, input_fn, output_fn
+            )  # type:ignore[return-value]
+        except ImportError as e:
+            msg = "To use DTensor API with xla, you must install the torch_xla package!"
+            raise ImportError(msg) from e
+
+    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
+        # This function loops over the immediate module parameters and
+        # buffers and replicates all non-DTensor params/buffers to DTensor
+        # parameters/buffers, if they have not been partitioned in the
+        # partition_fn. We can't easily use `module._apply` here because
+        # we don't know what happened inside partition_fn; the user could
+        # do anything (e.g. install hooks), and we want to preserve those.
+        full_replicate = [Replicate()] * mesh.ndim
+        for key, param in m._parameters.items():
+            if param is not None and not isinstance(param, DTensor):
+                m.register_parameter(
+                    key,
+                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
+                )
+        for key, buffer in m._buffers.items():
+            if buffer is not None and not isinstance(buffer, DTensor):
+                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)
+
+    if partition_fn is None:
+        # if partition_fn not specified, we by default replicate
+        # all module params/buffers
+        for name, submod in module.named_modules():
+            replicate_module_params_buffers(submod, device_mesh)
+    else:
+        # apply partition_fn to submodules
+        for name, submod in module.named_modules():
+            partition_fn(name, submod, device_mesh)
+            replicate_module_params_buffers(submod, device_mesh)
+
+    # register input_fn as module forward pre hook
+    if input_fn is not None:
+        # check the input_fn signature
+        num_args = len(inspect.signature(input_fn).parameters)
+        if num_args == 2:
+            # input_fn only takes in inputs and device mesh
+            warnings.warn(
+                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
+                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
+            )
+            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
+        elif num_args == 3:
+            # input_fn takes in module, inputs, device mesh
+            module.register_forward_pre_hook(
+                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
+            )
+        else:
+            raise ValueError(
+                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
+            )
+    # register output_fn as module forward hook
+    if output_fn is not None:
+        num_args = len(inspect.signature(output_fn).parameters)
+        if num_args == 2:
+            # output_fn only takes in outputs and device mesh
+            warnings.warn(
+                "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
+                "please use output_fn that takes in (module, outputs, device_mesh) instead!",
+            )
+            module.register_forward_hook(
+                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
+            )
+        elif num_args == 3:
+            module.register_forward_hook(
+                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
+            )
+        else:
+            raise ValueError(
+                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
+            )
+
+    return module
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e3cdb8683b076258ab805257e61e8fecd71f67f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__init__.py
@@ -0,0 +1,14 @@
+from torch.distributed._tensor.api import DTensor
+
+from torch.distributed._tensor.debug.comm_mode import CommDebugMode
+
+
+def get_sharding_prop_cache_info():
+    """
+    Get the cache info for the sharding propagation cache, used for debugging
+    purposes only. This returns a named tuple showing the hits, misses, maxsize
+    and currsize of the sharding propagator cache.
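+
+    Example (illustrative; the field values below are placeholders)::
+
+        >>> get_sharding_prop_cache_info()  # doctest: +SKIP
+        CacheInfo(hits=12, misses=3, maxsize=None, currsize=3)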
+    """
+    return (
+        DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info()  # type:ignore[attr-defined]
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b889ca5f95fd25df7ce14965cc051043eef1e609
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/comm_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/comm_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf25caf053dacd986616669fa8947db1ea16e7cb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/comm_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/op_coverage.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/op_coverage.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05207625ac97c44dc52c95cacb58b2a11dc05660
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/op_coverage.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/visualize_sharding.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/visualize_sharding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5df149e8441907e9f084f28ae13e88d368313d47
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/__pycache__/visualize_sharding.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/comm_mode.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/comm_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8adf0e338418717ac6c2436178ebe2f6463af69
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/comm_mode.py
@@ -0,0 +1,91 @@
+from collections import defaultdict
+from typing import Any, Dict
+
+import torch
+from torch.distributed._tensor.api import DTensor
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+funcol_native = torch.ops._c10d_functional
+funcol_py = torch.ops.c10d_functional
+
+NATIVE_TO_PY_MAPPING = {
+    funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
+    funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced,
+    funcol_native.all_reduce: funcol_py.all_reduce,
+    funcol_native.all_to_all_single: funcol_py.all_to_all_single,
+    funcol_native.broadcast: funcol_py.broadcast,
+    funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
+    funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
+}
+
+
+class CommDebugMode(TorchDispatchMode):
+    """
+    ``CommDebugMode`` is a context manager that counts the number of
+    functional collectives within its context. It does this using a
+    ``TorchDispatchMode``.
+
+    NOTE: this mode currently only works for functional collectives; the
+    distributed_c10d collectives are not supported yet.
+
+    Example usage
+
+    .. code-block:: python
+
+        mod = ...
+        comm_mode = CommDebugMode()
+        with comm_mode:
+            mod.sum().backward()
+
+    """
+
+    def __init__(self):
+        self.comm_counts: Dict[Any, int] = defaultdict(int)
+        self.comm_registry = set()
+        for native_op, py_op in NATIVE_TO_PY_MAPPING.items():
+            self.comm_registry.add(native_op)
+            self.comm_registry.add(py_op)
+
+    def get_total_counts(self) -> int:
+        return sum(self.comm_counts.values())
+
+    def get_comm_counts(self) -> Dict[Any, int]:
+        """Returns the communication counts as a dictionary.
+
+        Returns:
+            Dict[Any, int]: The communication counts as a dictionary.
+        """
+        return self.comm_counts
+
+    def __enter__(self):
+        self.comm_counts.clear()
+        super().__enter__()
+        return self
+
+    def __exit__(self, *args):
+        super().__exit__(*args)
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        # When running this mode with DTensor, ordinarily all modes will
+        # run **before** subclasses get a chance to run.
+        # Returning NotImplemented here gives us a chance to let DTensor
+        # run and desugar into comms ops, before CommDebugMode sees them.
+        if any(t == DTensor for t in types):
+            return NotImplemented
+        kwargs = kwargs if kwargs else {}
+        out = func(*args, **kwargs)
+        func_packet = func._overloadpacket
+        # We have many tests that use CommDebugMode to verify the occurrence of
+        # collectives. These tests do so by querying comm_counts with legacy
+        # funcol ops as key. For the purpose of native funcol migration, we
+        # need these tests to work for both legacy and native funcol. To avoid
+        # the need to modify all tests to accommodate the two implementations,
+        # we make CommDebugMode translate native funcol ops into legacy funcol
+        # ops until the migration finishes.
+        if func_packet in self.comm_registry:
+            if func_packet in NATIVE_TO_PY_MAPPING:
+                func_packet = NATIVE_TO_PY_MAPPING[func_packet]
+            self.comm_counts[func_packet] += 1
+
+        return out
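+
+
+# Illustrative note (not part of the original source): after exiting the context
+# manager, per-op counts can be read back keyed by the *legacy* functional
+# collective ops (native funcol ops are translated above), e.g.:
+#
+#     comm_mode = CommDebugMode()
+#     with comm_mode:
+#         mod.sum().backward()  # `mod` is a hypothetical sharded module
+#     n_all_reduce = comm_mode.get_comm_counts()[torch.ops.c10d_functional.all_reduce]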
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/op_coverage.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/op_coverage.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66eddcfecd9c99c29447d65a7680fd92bd5e902
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/op_coverage.py
@@ -0,0 +1,105 @@
+from operator import itemgetter
+from typing import List
+
+from functorch.compile import make_boxed_func
+
+import torch
+import torch.fx
+import torch.nn as nn
+from torch._functorch.compilers import aot_module
+from torch._inductor.decomposition import select_decomp_table
+from torch.distributed._tensor import DTensor
+
+
+inductor_decomps = select_decomp_table()
+
+graphs: List[torch.fx.GraphModule] = []
+
+
+def fwd_bwd_compiler(fx_g, _):
+    graphs.append(fx_g)
+    return make_boxed_func(fx_g)
+
+
+def get_inductor_decomp_graphs(model: nn.Module, args, kwargs):
+    """
+    Obtain forward and backward graphs of a model with inductor decompositions using tracing and aot_module.
+
+    Convenient util to get the fwd and bwd graphs of an arbitrary model
+    with inductor decompositions. Note that this simply does tracing with
+    aot_module and does not ensure correctness. This is useful to track
+    the ops needed in DTensor.
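+
+    Example (illustrative; ``model`` and ``inp`` are hypothetical)::
+
+        fwd_graph, bwd_graph = get_inductor_decomp_graphs(model, [inp], {})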
+    """
+    compiled_mod = aot_module(
+        model, fw_compiler=fwd_bwd_compiler, decompositions=inductor_decomps
+    )
+    output = compiled_mod(*args, **kwargs)
+
+    if output.ndim != 0:
+        # if output is not a scalar tensor, by default sum it in order to
+        # run backward
+        output = output.sum()
+
+    output.backward()
+
+    # one fwd, one bwd graph
+    assert len(graphs) == 2
+    return graphs
+
+
+def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=False):
+    """
+    Util to print the operator coverage summary of a given model with tabulate.
+
+    Requires the tabulate module to be installed.
+    """
+    # python module required for summary
+    import csv
+
+    from tabulate import tabulate
+
+    fwd_graph, bwd_graph = get_inductor_decomp_graphs(model, args, kwargs)
+
+    op_counts = {}
+
+    for node in fwd_graph.graph.nodes:
+        if node.op == "call_function" and isinstance(
+            node.target, torch._ops.OpOverload
+        ):
+            if node.target not in op_counts:
+                op_counts[node.target] = 0
+
+            op_counts[node.target] += 1
+
+    for node in bwd_graph.graph.nodes:
+        if node.op == "call_function" and isinstance(
+            node.target, torch._ops.OpOverload
+        ):
+            if node.target not in op_counts:
+                op_counts[node.target] = 0
+
+            op_counts[node.target] += 1
+
+    op_infos = []
+
+    for op, count in op_counts.items():
+        supported = op in DTensor._op_dispatcher.sharding_propagator.op_to_rules
+        op_infos.append([op, str(op._schema), count, supported])
+
+    # sort the op info based on the total count index
+    count_idx = 2
+    op_infos.sort(key=itemgetter(count_idx), reverse=True)
+
+    headers = ["Operator", "Schema", "Total Count", "Supported"]
+    print(tabulate(op_infos, headers=headers))
+
+    if output_csv:
+        # Open a CSV file for writing
+        with open("op_summary.csv", "w", newline="") as csv_file:
+            # Create a CSV writer object
+            csv_writer = csv.writer(csv_file)
+
+            csv_writer.writerow(headers)
+            # Write each table row to the CSV file
+            for row in op_infos:
+                csv_writer.writerow(row)
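+
+
+# Illustrative usage sketch (not part of the original source); `model` and `inp`
+# are hypothetical:
+#
+#     print_op_coverage_summary(model, [inp], {}, output_csv=True)
+#     # prints the per-operator count table and also writes "op_summary.csv"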
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/visualize_sharding.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/visualize_sharding.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a775b1a0f256c9de6a27801b52d222f5b8c7da
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/debug/visualize_sharding.py
@@ -0,0 +1,176 @@
+from typing import List, Sequence, Tuple
+
+import numpy as np
+
+from torch._prims_common import ShapeType
+from torch.distributed._tensor import DeviceMesh
+
+from torch.distributed._tensor.placement_types import Placement, Shard
+
+
+def _mesh_to_coordinate(mesh, device_type):
+    """
+    Given an n-dimensional device mesh, this function creates a map from each
+    device to its coordinate.
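+    For example (illustrative): a 2x2 mesh ``[[0, 1], [2, 3]]`` with device type
+    "cuda" maps "cuda:3" to ``[1, 1]``.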
+    """
+    # Convert the n-dimensional list to a NumPy array
+    np_mesh = np.array(mesh.mesh.tolist())
+
+    # Create a dictionary to map each value to its coordinate
+    device_to_coordinate_map = {}
+    for coord, value in np.ndenumerate(np_mesh):
+        # device is unique in device_mesh
+        device_to_coordinate_map[f"{device_type}:{str(value)}"] = list(coord)
+
+    return device_to_coordinate_map
+
+
+def _convert_offset_to_ranges(all_offsets):
+    """
+    Creating a table with the tabulate package is easier when we specify row and
+    column ranges, so this function converts offsets to such ranges.
+    """
+    converted_blocks = []
+
+    for offset in all_offsets:
+        shape, offset, value = offset
+
+        # Calculate row_range and column_range
+        row_range = (offset[0], offset[0] + shape[0] - 1)
+        column_range = (offset[1], offset[1] + shape[1] - 1)
+
+        # Convert value to string to match your desired format
+        converted_block = {
+            "row_range": row_range,
+            "column_range": column_range,
+            "value": str(value),
+        }
+        converted_blocks.append(converted_block)
+
+    return converted_blocks
+
+
+def _create_table(blocks):
+    """
+    Creates a tabulate table given row and column ranges with device name
+    """
+    try:
+        from tabulate import tabulate
+    except ImportError as e:
+        raise ImportError("tabulate package is required to visualize sharding") from e
+
+    # Extract unique row and column ranges
+    row_ranges = sorted({block["row_range"] for block in blocks})
+    col_ranges = sorted({block["column_range"] for block in blocks})
+
+    # Create a matrix initialized with empty strings
+    matrix = [["" for _ in col_ranges] for _ in row_ranges]
+
+    # Fill the matrix with values
+    for block in blocks:
+        row_index = row_ranges.index(block["row_range"])
+        col_index = col_ranges.index(block["column_range"])
+        if matrix[row_index][col_index] == "":
+            matrix[row_index][col_index] = block["value"]
+        else:
+            matrix[row_index][col_index] += ", " + block["value"]
+
+    # Prepare headers
+    row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges]
+    col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges]
+
+    return tabulate(matrix, headers=col_headers, showindex=row_headers)
+
+
+def compute_local_shape_and_global_offset(
+    global_shape: ShapeType,
+    mesh: DeviceMesh,
+    placements: Sequence[Placement],
+    my_coordinate: List[int],
+) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """
+    Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset, but
+    with a custom my_coordinate input. This is a modified implementation used by visualize_sharding.
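+
+    Example (illustrative): for a global shape of ``(8, 8)`` on a 1-D mesh of
+    size 4 with ``placements=[Shard(0)]``, ``my_coordinate=[2]`` yields a local
+    shape of ``(2, 8)`` and a global offset of ``(4, 0)``.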
+    """
+
+    if my_coordinate is None:
+        # if rank not in the mesh, return empty offset
+        return ((), ())
+    else:
+        local_shape = list(global_shape)
+        global_offset = [0] * len(global_shape)
+
+        for idx, placement in enumerate(placements):
+            mesh_dim_size = mesh.size(idx)
+            if isinstance(placement, Shard):
+                shard_dim = placement.dim
+                local_offset = [0] * len(global_shape)
+                assert shard_dim < len(
+                    local_shape
+                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
+                shard_size, shard_offset = placement._local_shard_size_on_dim(
+                    local_shape[shard_dim],
+                    mesh_dim_size,
+                    my_coordinate[idx],
+                    return_offset=True,
+                )
+
+                local_shape[shard_dim] = shard_size
+                local_offset[shard_dim] = shard_offset
+
+                # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim],
+                # it means that this dimension has already been sharded by a previous placement.
+                # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim].
+                # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim].
+                if global_offset[shard_dim] <= local_offset[shard_dim]:
+                    global_offset[shard_dim] = local_offset[shard_dim]
+                else:
+                    global_offset[shard_dim] += local_offset[shard_dim]
+
+        return tuple(local_shape), tuple(global_offset)
+
+
+def visualize_sharding(dtensor, header=""):
+    """
+    Visualizes the sharding of 1D or 2D DTensors.
+    Requires the tabulate package; install it with `pip install tabulate`.
+
+    Note: no sharding info will be printed for empty tensors.
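+
+    Example (illustrative; assumes ``dt`` is an existing 1D or 2D DTensor)::
+
+        visualize_sharding(dt, header="my_param")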
+    """
+    if dtensor.numel() == 0:  # we do not print for empty dtensors
+        return
+
+    if len(dtensor.shape) >= 3:
+        raise RuntimeError(
+            "visualize sharding is only implemented for 1D or 2D dtensor"
+        )
+    placements = dtensor.placements
+    device_mesh = dtensor.device_mesh
+    device_type = dtensor.device_mesh.device_type
+
+    if device_mesh.get_coordinate() is None:  # current rank is not in the mesh
+        return
+
+    # Only display the visualization once for each DTensor, on the rank whose
+    # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh,
+    # we will only print on rank 0.
+    local_rank_zero_on_all_dim = all(
+        device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim)
+    )
+    if not local_rank_zero_on_all_dim:
+        return
+
+    device_map = _mesh_to_coordinate(device_mesh, device_type)
+    all_offsets = []
+    for device in device_map:
+        local_shape, global_offset = compute_local_shape_and_global_offset(
+            dtensor.shape, device_mesh, placements, device_map[device]
+        )
+        all_offsets.append([local_shape, global_offset, device])
+
+    # Convert offsets to blocks with row_ranges for tabulate
+    blocks = _convert_offset_to_ranges(all_offsets)
+
+    # Print the table
+    print(header)
+    print(_create_table(blocks))
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/device_mesh.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/device_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c9c01ff186892d6b61097317d24cad9cc2c0cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/device_mesh.py
@@ -0,0 +1,6 @@
+from torch.distributed.device_mesh import (  # noqa: F401
+    _get_device_handle,
+    _mesh_resources,
+    DeviceMesh,
+    init_device_mesh,
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/dispatch.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7686dc5c2e762851b2f9a9dfdf5cfc5c2cd3267
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/dispatch.py
@@ -0,0 +1,393 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
+import operator
+from typing import cast, Dict, List, Optional, Sequence, Tuple
+
+import torch
+
+import torch.distributed as dist
+import torch.distributed._tensor.api as dtensor
+import torch.distributed._tensor.random as random
+from torch.distributed._tensor._utils import try_find_mesh_from_args
+from torch.distributed._tensor.op_schema import (
+    _is_inplace_op,
+    _is_out_variant_op,
+    OpInfo,
+    OpSchema,
+    OutputSpecType,
+)
+from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta
+from torch.distributed._tensor.random import is_rng_supported_mesh
+from torch.distributed._tensor.redistribute import redistribute_local_tensor
+from torch.distributed._tensor.sharding_prop import ShardingPropagator
+from torch.distributed._tensor.tp_conv import (
+    convolution_backward_handler,
+    convolution_handler,
+)
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.utils import _cxx_pytree as pytree
+except ImportError:
+    from torch.utils import _pytree as pytree  # type: ignore[no-redef]
+
+aten = torch.ops.aten
+
+
+def decompose_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    """
+    Decomposes an op into core ATen ops; this handler is mostly here
+    for inference-mode usage where the ops are not core ATen ops.
+    """
+    r = op_call.decompose(*args, **kwargs)
+    if r is not NotImplemented:
+        return r
+    else:
+        raise RuntimeError("Decomposition failed")
+
+
+def is_same_size_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> bool:
+    lhs = cast(torch.Tensor, args[0])
+    rhs = cast(torch.Tensor, args[1])
+    return lhs.shape == rhs.shape
+
+
+class OpDispatcher:
+    """
+    Op dispatching class that handles args/kwargs pre-processing (unwrapping), sharding
+    propagation, redistribution of local args, local compute, and post-processing (re-wrapping).
+    It also handles any op-specific logic if necessary.
+    """
+
+    def __init__(self) -> None:
+        self.sharding_propagator = ShardingPropagator()
+        self._random_ops = {
+            aten.native_dropout.default,
+            aten.normal_.default,
+            aten.rand_like.default,
+            aten.randn_like.default,
+            aten.randint_like.default,
+            aten.randint_like.low_dtype,
+            aten.randint_like.low_dtype_out,
+            aten.uniform_.default,
+            aten.bernoulli.default,
+            aten.bernoulli_.float,
+        }
+        self._custom_op_handlers = {
+            aten.linear.default: decompose_handler,
+            aten.is_same_size.default: is_same_size_handler,
+            aten.convolution.default: convolution_handler,
+            aten.convolution_backward.default: convolution_backward_handler,
+        }
+
+        # This flag is used internally to control whether we treat torch.Tensor (non-DTensor)
+        # arguments as implicitly replicated or throw an error to the user.
+        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default, so we intentionally leave
+        # it as False by default.
+        self._allow_implicit_replication = False
+
+    def dispatch(
+        self,
+        op_call: torch._ops.OpOverload,
+        args: Tuple[object, ...],
+        kwargs: Dict[str, object],
+    ) -> object:
+        """
+        Main dispatching logic
+        """
+        # operators that do not need to go through sharding propagation
+        if op_call in self._custom_op_handlers:
+            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
+
+        # extract local tensors and sharding info into an OpInfo
+        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
+
+        self.sharding_propagator.propagate(op_info)
+        output_sharding = op_info.output_sharding
+        assert output_sharding is not None, "output sharding should not be None"
+
+        mesh = op_info.mesh
+        if mesh.get_coordinate() is None:
+            # For a non-participating device, we do:
+            #   1. if the return type is scalar, set the local result to None.
+            #   The local results from all devices will then be all-gathered
+            #   and a reduce op will be performed on the list of results
+            #   with appropriate operators:
+            #       for bool type, we by default use AND to reduce;
+            #       we can extend for more ops if necessary.
+            #   2. if the return type is Tensor or List[Tensor], return empty
+            #   tensor(s) with correct dtype.
+            spec = output_sharding.output_spec
+            ret_list = op_info.schema.op._schema.returns
+
+            if spec is None:
+                # For a scalar return type, the non-participating device has None
+                # as its local result
+                local_results: object = None
+            else:
+
+                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
+                    if spec.tensor_meta is not None:
+                        shape = spec.tensor_meta.shape
+                        dtype = spec.tensor_meta.dtype
+                        if len(shape) == 0:
+                            # scalar tensor
+                            return torch.zeros((), dtype=dtype)
+                        else:
+                            # non-scalar tensor
+                            return torch.tensor([], dtype=dtype)
+                    else:
+                        raise RuntimeError(f"{spec} has no tensor metadata.")
+
+                if isinstance(spec, DTensorSpec):
+                    # return a Tensor value
+                    local_results = default_tensor(spec)
+                elif isinstance(spec, Sequence):
+                    # return a List[Tensor] value
+                    local_results = [
+                        default_tensor(s) if s is not None else None for s in spec
+                    ]
+                    assert isinstance(local_results, List)
+                    if None in local_results:
+                        ret_type = str(ret_list[0].type)
+                        raise NotImplementedError(
+                            f"return type {ret_type} in DTensor op is not supported"
+                        )
+        else:
+            if output_sharding.needs_redistribute:
+                # compute locally with redistribute first if needed
+                assert output_sharding.schema_suggestions is not None
+                self.redistribute_local_args(
+                    op_info, output_sharding.schema_suggestions[0]
+                )
+
+            local_tensor_args = (
+                pytree.tree_unflatten(
+                    cast(List[object], op_info.local_args), op_info.args_tree_spec
+                )
+                if op_info.args_tree_spec
+                else op_info.local_args
+            )
+
+            # run local op computation with potentially modified args/kwargs
+            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
+            if op_call in self._random_ops and is_rng_supported_mesh(mesh):
+                if not random._rng_tracker:
+                    # Default to `OffsetBasedRNGTracker` if the parallelism API
+                    # did not already construct one
+                    random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
+                # For DTensor random operator, run it within a distribute region
+                with random._rng_tracker._distribute_region(
+                    cast(dtensor.DTensor, args[0])._spec
+                ):
+                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
+            else:
+                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
+
+        # communicate the result to all ranks for some operators that return scalar value
+        if output_sharding.output_spec is None:
+            if op_call == aten.equal.default:
+                obj_list = [None for _ in range(dist.get_world_size())]
+                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
+                obj_list = list(filter(lambda x: x is not None, obj_list))
+                # perform reduce on the collection with AND op
+                local_results = functools.reduce(operator.and_, obj_list, True)
+
+        if _is_inplace_op(op_call):
+            # inplace op should return self instead of re-wrapping
+            if output_sharding.output_spec is not None:
+                return args[0]
+            else:
+                return None
+        elif _is_out_variant_op(op_call):
+            # out variant could possibly have multiple out args (i.e. lu_unpack.out)
+            output_specs = (
+                (output_sharding.output_spec,)
+                if not isinstance(output_sharding.output_spec, tuple)
+                else output_sharding.output_spec
+            )
+            out_dts = []
+            spec_idx = 0
+            for argument in op_call._schema.arguments:
+                if argument.is_out:
+                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
+                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
+                    out_dts.append(out_dt)
+                    spec_idx += 1
+
+            assert len(out_dts) >= 1, "out variant should have at least one out arg"
+            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
+        else:
+            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+
+    @staticmethod
+    def redistribute_local_args(
+        op_info: OpInfo,
+        suggested_input_schema: OpSchema,
+    ) -> None:
+        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
+
+        # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
+        # Need to fix all the ops before doing this.
+        if op_info.args_tree_spec is not None:
+            flatten_args_schema_to_reshard = tuple(
+                pytree.tree_leaves(suggested_input_schema.args_schema)
+            )
+        else:
+            flatten_args_schema_to_reshard = suggested_input_schema.args_schema
+
+        new_local_args: List[object] = []
+        for i, arg_spec in enumerate(op_info.flat_args_schema):
+            reshard_arg_spec = flatten_args_schema_to_reshard[i]
+            if isinstance(arg_spec, DTensorSpec):
+                local_tensor = cast(torch.Tensor, op_info.local_args[i])
+                if arg_spec != reshard_arg_spec:
+                    resharded_local_tensor = redistribute_local_tensor(
+                        local_tensor, arg_spec, reshard_arg_spec
+                    )
+                    new_local_args.append(resharded_local_tensor)
+                else:
+                    new_local_args.append(local_tensor)
+            else:
+                new_local_args.append(reshard_arg_spec)
+
+        op_info.local_args = tuple(new_local_args)
+
+    def unwrap_to_op_info(
+        self,
+        op_call: torch._ops.OpOverload,
+        args: Tuple[object, ...],
+        kwargs: Dict[str, object],
+    ) -> OpInfo:
+        # get runtime schema to determine whether to use pytree to flatten inputs
+        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
+            op_call, None
+        )
+
+        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
+            # flatten args/kwargs when necessary
+            tree_args, args_spec = pytree.tree_flatten(args)
+            args_list: Sequence[object] = tree_args
+        else:
+            args_list, args_spec = args, None
+
+        args_schema: List[object] = []
+        kwargs_schema: Dict[str, object] = {}
+        local_args: List[object] = []
+        local_kwargs: Dict[str, object] = {}
+        mesh: Optional[DeviceMesh] = None
+
+        for arg in args_list:
+            if isinstance(arg, dtensor.DTensor):
+                args_schema.append(arg._spec)
+                local_args.append(arg._local_tensor)
+                if mesh is not None:
+                    if mesh != arg.device_mesh:
+                        raise NotImplementedError(
+                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
+                        )
+                else:
+                    mesh = arg.device_mesh
+            elif isinstance(arg, torch.Tensor):
+                if arg.ndim == 0 or self._allow_implicit_replication:
+                    mesh = mesh or try_find_mesh_from_args(op_call, args_list)
+                    # scalar tensor can be safely treated as replicated
+                    args_schema.append(
+                        DTensorSpec(
+                            mesh,
+                            (Replicate(),) * mesh.ndim,
+                            tensor_meta=TensorMeta(
+                                shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
+                            ),
+                        )
+                    )
+                    local_args.append(arg)
+                else:
+                    raise RuntimeError(
+                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+                        " torch.Tensor to DTensor before calling distributed operators!"
+                    )
+            else:
+                args_schema.append(arg)
+                local_args.append(arg)
+
+        for k, v in kwargs.items():
+            if isinstance(v, dtensor.DTensor):
+                kwargs_schema[k] = v._spec
+                local_kwargs[k] = v._local_tensor
+                if mesh is not None:
+                    if mesh != v.device_mesh:
+                        raise NotImplementedError(
+                            f"{op_call}: DTensor does not support cross-mesh operation yet!"
+                        )
+                else:
+                    mesh = v.device_mesh
+            elif isinstance(v, torch.Tensor):
+                raise RuntimeError(
+                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+                    " torch.Tensor to DTensor before calling distributed operators!"
+                )
+            else:
+                kwargs_schema[k] = v
+                local_kwargs[k] = v
+
+        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
+        op_info = OpInfo(
+            mesh,
+            OpSchema(
+                op_call,
+                pytree.tree_unflatten(args_schema, args_spec)
+                if args_spec
+                else tuple(args_schema),
+                kwargs_schema,
+                schema_info=runtime_schema_info,
+            ),
+            args_schema,
+            tuple(local_args),
+            local_kwargs,
+            args_spec,
+        )
+        return op_info
+
+    @staticmethod
+    def wrap(res: object, spec: OutputSpecType) -> object:
+        if isinstance(res, torch.Tensor):
+            if spec is not None:
+                assert isinstance(
+                    spec, DTensorSpec
+                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
+                assert spec.tensor_meta is not None
+                return dtensor.DTensor(
+                    res,
+                    spec.mesh,
+                    spec.placements,
+                    shape=spec.tensor_meta.shape,
+                    dtype=spec.tensor_meta.dtype,
+                    requires_grad=res.requires_grad,
+                    stride=spec.tensor_meta.stride,
+                )
+            else:
+                # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
+                assert res.ndim == 0, "output tensor should be scalar!"
+                return res
+        elif isinstance(res, (list, tuple)):
+            assert spec is not None and isinstance(
+                spec, (list, tuple)
+            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
+            res_list = []
+            for e, s in zip(res, spec):
+                res_list.append(OpDispatcher.wrap(e, s))
+
+            return tuple(res_list) if isinstance(res, tuple) else res_list
+        else:
+            # if the res contains only non tensor values (i.e. int/float/none), we simply return it
+            # without rewrapping to DTensor.
+            return res
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d488e14db16c07674d9652f410326ec9fe23fe9f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__init__.py
@@ -0,0 +1,12 @@
+from contextlib import contextmanager
+
+from torch.distributed._tensor.api import DTensor
+
+
+@contextmanager
+def implicit_replication():
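+    """
+    Experimental context manager that lets DTensor operators treat plain
+    ``torch.Tensor`` (non-DTensor) arguments as implicitly replicated instead
+    of raising an error; the flag is reset to ``False`` when the context exits.
+    """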
+    try:
+        DTensor._op_dispatcher._allow_implicit_replication = True
+        yield
+    finally:
+        DTensor._op_dispatcher._allow_implicit_replication = False
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d7102a38dab4967db55f50def645a2c80878b1f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/tp_transform.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/tp_transform.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c0212a12848abbacfedb1505583e87f8da5ea7f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/__pycache__/tp_transform.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/tp_transform.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/tp_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..df100c4696341823e30dd131c07aa763e2c010e9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/experimental/tp_transform.py
@@ -0,0 +1,547 @@
+import copy
+import operator
+from typing import Any, cast, Dict, List, Optional, Sequence, Tuple
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
+from torch.distributed._tensor.op_schema import (
+    DTensorSpec,
+    OpSchema,
+    OutputSharding,
+    OutputSpecType,
+    PlacementStrategy,
+)
+from torch.distributed._tensor.placement_types import (
+    Placement,
+    Replicate,
+    Shard,
+    TensorMeta,
+)
+from torch.distributed._tensor.redistribute import redistribute_local_tensor
+from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
+from torch.export import ExportedProgram
+from torch.export.exported_program import ExportGraphSignature
+from torch.fx import GraphModule
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.utils import _pytree as pytree
+
+
+aten = torch.ops.aten
+
+
+def tensor_parallel_transformation(
+    exported_program: ExportedProgram,
+    rank: int,
+    world_size: int,
+    device_type: str,
+    parallel_strategies: Dict[str, ParallelStyle],
+) -> ExportedProgram:
+    """
+    The entry point function for performing graph transformations on an exported
+    program, turning a single-device graph into a tensor parallel graph.
+
+    .. warning::
+        This API is experimental and subject to change.
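+
+    Example (an illustrative sketch, not from the original docstring; ``model``,
+    ``example_input``, the module FQN and the mesh size are placeholders)::
+
+        ep = torch.export.export(model, (example_input,))
+        tp_ep = tensor_parallel_transformation(
+            ep,
+            rank=0,
+            world_size=8,
+            device_type="cuda",
+            parallel_strategies={"net.linear1": ColwiseParallel},
+        )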
+    """
+
+    gm = exported_program.graph_module
+    sig = copy.deepcopy(exported_program.graph_signature)
+    state_dict = copy.copy(exported_program.state_dict)
+
+    with gm._set_replace_hook(sig.get_replace_hook()):
+        res = TensorParallelTransformPass(
+            rank,
+            world_size,
+            device_type,
+            state_dict,
+            exported_program.graph_signature,
+            parallel_strategies,
+        )(gm)
+        assert res is not None
+        gm = res.graph_module
+
+    return exported_program._update(gm, sig, state_dict)
+
+
+class TensorParallelTransformPass(PassBase):
+    """
+    This pass is responsible for transforming a single-device graph into a tensor parallel
+    graph. It will mark the placement strategy of each node in the graph,
+    partition the graph into distributed graph, then shard the parameters/buffers accordingly.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        device_type: str,
+        state_dict: Dict[str, torch.Tensor],
+        graph_signature: ExportGraphSignature,
+        parallel_strategies: Dict[str, ParallelStyle],
+    ) -> None:
+        super().__init__()
+        self.rank = rank
+        self.mesh = DeviceMesh(device_type, torch.arange(world_size))
+        self.state_dict: Dict[str, torch.Tensor] = state_dict
+        self.graph_signature = graph_signature
+        self.parallel_strategies = parallel_strategies
+
+    def call(self, graph_module) -> PassResult:
+        gm = copy.deepcopy(graph_module)
+
+        parameter_placements = _generate_parameter_and_buffer_placements(
+            list(self.state_dict.keys()), self.parallel_strategies
+        )
+        placement_strategies = _mark_sharding(
+            gm, self.graph_signature, self.mesh, parameter_placements
+        )
+        _partitioner(gm)
+        _shard_state_dict(
+            self.state_dict, placement_strategies, self.graph_signature, self.mesh
+        )
+        return PassResult(gm, True)
+
+
+def _generate_parameter_and_buffer_placements(
+    params_and_buffers: List[str],
+    parallel_strategies: Dict[str, ParallelStyle],
+) -> Dict[str, Placement]:
+    """
+    Build parameter placements based on the given parallel style of the linear layers.
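+
+    For example (illustrative): ``{"fc": ColwiseParallel}`` maps ``fc.weight`` to
+    ``Shard(0)`` and ``fc.bias`` (if present) to ``Shard(0)`` as well, while a
+    non-colwise style maps the weight to ``Shard(1)`` and the bias to ``Replicate()``.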
+    """
+    parameter_placements: Dict[str, Placement] = {}
+    for linear_fqn, parallel_style in parallel_strategies.items():
+        weight_fqn = f"{linear_fqn}.weight"
+        bias_fqn = f"{linear_fqn}.bias"
+        assert weight_fqn in params_and_buffers
+        parameter_placements[weight_fqn] = (
+            Shard(0) if parallel_style == ColwiseParallel else Shard(1)
+        )
+        if bias_fqn in params_and_buffers:
+            parameter_placements[bias_fqn] = (
+                Shard(0) if parallel_style == ColwiseParallel else Replicate()
+            )
+    return parameter_placements
+
+
+def _mark_tensor_parallel_shardings(
+    gm: GraphModule,
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+    parameter_placements: Dict[str, Placement],
+) -> Dict[Node, PlacementStrategy]:
+    """
+    Mark the placement strategies of the parameter and buffer placeholder nodes.
+    """
+    placement_strategies: Dict[Node, PlacementStrategy] = {}
+    num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
+        graph_signature.inputs_to_buffers
+    )
+    placeholder_idx: int = 0
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if placeholder_idx < num_params_and_buffers:
+                fqn: str = _get_input_node_fqn(node.name, graph_signature)
+                placement: Placement = (
+                    parameter_placements[fqn]
+                    if fqn in parameter_placements
+                    else Replicate()
+                )
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=(placement,),
+                )
+                placeholder_idx += 1
+            else:
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=(Replicate(),),
+                )
+    return placement_strategies
+
+
+def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str:
+    """
+    Return the FQN of an input node.
+    """
+    if input_name in graph_signature.inputs_to_parameters:
+        return graph_signature.inputs_to_parameters[input_name]
+    elif input_name in graph_signature.inputs_to_buffers:
+        return graph_signature.inputs_to_buffers[input_name]
+    else:
+        raise ValueError(
+            f"{input_name} not found in inputs_to_parameters or inputs_to_buffers"
+        )
+
+
+def _mark_sharding(
+    gm: GraphModule,
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+    parameter_placements: Dict[str, Placement],
+) -> Dict[Node, PlacementStrategy]:
+    """
+    Mark the sharding strategy for each node in the graph module.
+    """
+    placement_strategies: Dict[
+        Node, PlacementStrategy
+    ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
+
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if node not in placement_strategies:
+                placement_strategies[node] = _create_placement_strategy(
+                    node, mesh, placements=(Replicate(),)
+                )
+            node.meta["sharding"] = placement_strategies[node]
+        elif node.op == "call_function":
+            if node.target == operator.getitem:
+                input_nodes = node.all_input_nodes
+                assert (
+                    len(input_nodes) == 1
+                ), f"non-compute op only supports one input for now, found node: {node} with {len(node.args)} inputs"
+                arg_strategy = placement_strategies[input_nodes[0]]
+                placement_strategies[node] = _create_placement_strategy(
+                    node,
+                    mesh,
+                    placements=arg_strategy.output_spec.placements,
+                    input_specs=_get_input_node_specs(node, placement_strategies),
+                )
+                node.meta["sharding"] = placement_strategies[node]
+            else:
+                op_schema = _get_op_schema(node, placement_strategies)
+
+                # get DTensor specs for inputs and outputs
+                if (
+                    op_schema.op
+                    not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
+                    and op_schema.op
+                    not in DTensor._op_dispatcher.sharding_propagator.op_to_rules
+                ):
+                    # Mark all as replicated
+                    output_sharding = _generate_default_output_sharding(
+                        node,
+                        mesh,
+                        op_schema,
+                    )
+                else:
+                    output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding(
+                        op_schema,
+                    )
+                placement_strategies[node] = PlacementStrategy(
+                    output_specs=_get_output_spec_from_output_sharding(output_sharding),
+                    input_specs=output_sharding.schema_suggestions[0].args_spec
+                    if output_sharding.schema_suggestions is not None
+                    else _get_input_node_specs(node, placement_strategies),
+                )
+                node.meta["sharding"] = placement_strategies[node]
+        elif node.op == "output":
+            node.meta["sharding"] = None
+        else:
+            raise RuntimeError(f"op code {node.op} not supported")
+    return placement_strategies
+
+
+def _get_output_spec_from_output_sharding(
+    output_sharding: OutputSharding,
+) -> DTensorSpec:
+    """
+    Util function to extract output spec from output sharding.
+    """
+    if isinstance(output_sharding.output_spec, DTensorSpec):
+        return output_sharding.output_spec
+    else:
+        # For ops that return multiple outputs, the outputs should have the same output spec
+        assert isinstance(output_sharding.output_spec, Sequence)
+        assert output_sharding.output_spec[0] is not None
+        output_sharding.output_spec[0].tensor_meta = None
+        return output_sharding.output_spec[0]
+
+
+def _create_placement_strategy(
+    node: Node,
+    mesh: DeviceMesh,
+    placements: Tuple[Placement, ...],
+    input_specs: Optional[Sequence[DTensorSpec]] = None,
+) -> PlacementStrategy:
+    """
+    Util function to construct a placement strategy for a given node.
+    """
+    placement = PlacementStrategy(
+        input_specs=input_specs,
+        output_specs=DTensorSpec(
+            mesh=mesh,
+            placements=placements,
+        ),
+    )
+    _populate_tensor_meta(node, placement.output_specs)
+    return placement
+
+
+def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None:
+    """
+    Util function to populate tensor meta of output_spec based on node metadata.
+    """
+    if isinstance(node.meta["val"], Sequence):
+        assert isinstance(output_spec, Sequence)
+        for spec, fake_tensor in zip(output_spec, node.meta["val"]):
+            assert spec is not None
+            spec.tensor_meta = TensorMeta(
+                shape=fake_tensor.shape,
+                stride=fake_tensor.stride(),
+                dtype=fake_tensor.dtype,
+            )
+    else:
+        assert isinstance(output_spec, DTensorSpec)
+        output_spec.tensor_meta = TensorMeta(
+            shape=node.meta["val"].shape,
+            stride=node.meta["val"].stride(),
+            dtype=node.meta["val"].dtype,
+        )
+
+
+def _generate_default_output_sharding(
+    node: Node,
+    mesh: DeviceMesh,
+    op_schema: OpSchema,
+) -> OutputSharding:
+    """
+    Util function to create a default output sharding that suggests Replicate placement for both args and outputs.
+    """
+
+    def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec:
+        return DTensorSpec(
+            mesh=arg_spec.mesh,
+            placements=(Replicate(),),
+            tensor_meta=arg_spec.tensor_meta,
+        )
+
+    new_op_schema = OpSchema(
+        op=op_schema.op,
+        args_schema=pytree.tree_map_only(
+            DTensorSpec, update_arg_spec, op_schema.args_schema
+        ),
+        kwargs_schema=op_schema.kwargs_schema,
+    )
+
+    def create_output_spec(tensor: FakeTensor) -> DTensorSpec:
+        return DTensorSpec(
+            mesh=mesh,
+            placements=(Replicate(),),
+            tensor_meta=TensorMeta(
+                shape=tensor.shape,
+                stride=tensor.stride(),
+                dtype=tensor.dtype,
+            ),
+        )
+
+    return OutputSharding(
+        output_spec=pytree.tree_map_only(
+            FakeTensor, create_output_spec, node.meta["val"]
+        ),
+        schema_suggestions=[new_op_schema],
+        failed_reason=f"{node.op} does not have sharding strategy registered",
+        needs_redistribute=True,
+    )
+
+
+def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Graph partitioner that partitions the single-device graph
+    into a distributed graph
+    """
+    for node in gm.graph.nodes:
+        node_sharding = node.meta["sharding"]
+        if node.op == "placeholder":
+            out_spec = node_sharding.output_spec
+            local_val = _partition_val(node.meta["val"], out_spec)
+            # update node value
+            node.meta["val"] = local_val
+        elif node.op == "call_function":
+            out_spec = node_sharding.output_spec
+            # check if there's misaligned sharding, insert reshard if there is
+            expected_input_specs = node_sharding.input_specs
+            for idx, input_arg in enumerate(node.all_input_nodes):
+                input_arg_sharding = input_arg.meta["sharding"]
+                input_arg_spec = input_arg_sharding.output_spec
+                desired_spec = (
+                    out_spec
+                    if expected_input_specs is None
+                    else expected_input_specs[idx]
+                )
+                if input_arg_spec != desired_spec:
+                    _insert_reshard_gm(
+                        gm, node, input_arg, input_arg_spec, desired_spec
+                    )
+            # convert output val to its local component
+            output_val = node.meta["val"]
+            node.meta["val"] = _partition_val(output_val, out_spec)
+        elif node.op == "output":
+            for input_arg in node.all_input_nodes:
+                # input args of output should be Replicate, otherwise redistribution is needed.
+                input_args_to_check: Sequence[Node] = (
+                    input_arg if isinstance(input_arg, Sequence) else [input_arg]
+                )
+                for arg in input_args_to_check:
+                    arg_sharding = arg.meta["sharding"]
+                    arg_spec = arg_sharding.output_spec
+                    desired_spec = copy.copy(arg_spec)
+                    desired_spec.placements = (Replicate(),)
+                    if arg_spec != desired_spec:
+                        _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec)
+        else:
+            raise RuntimeError(f"op code {node} not supported")
+
+    _clean_up_graph_metadata(gm)
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def _partition_val(val: Any, spec: DTensorSpec) -> Any:
+    """
+    util function to convert a full tensor val to its local component
+    """
+    if isinstance(val, torch.Tensor):
+        local_shard = val
+        if val.ndim == 0:
+            # If it's already a scalar tensor, it is already local; we don't
+            # need to do anything
+            return local_shard
+
+        for idx, placement in enumerate(spec.placements):
+            if placement.is_shard():
+                placement = cast(Shard, placement)
+                num_chunks = spec.mesh.size(mesh_dim=idx)
+                my_coord = spec.mesh.get_coordinate()
+                assert my_coord is not None, "current rank not in mesh!"
+                my_coord_on_mesh_dim = my_coord[idx]
+                local_shard = placement._split_tensor(
+                    local_shard, num_chunks, with_padding=False, contiguous=True
+                )[0][my_coord_on_mesh_dim]
+        return local_shard
+    elif isinstance(val, (list, tuple)):
+        return val.__class__(_partition_val(v, spec) for v in val)
+    else:
+        raise RuntimeError(f"val type {type(val)} not supported")
+
+
+def _insert_reshard_gm(
+    gm: torch.fx.GraphModule,
+    node: Node,
+    input_arg: Node,
+    input_arg_spec: DTensorSpec,
+    desired_spec: DTensorSpec,
+) -> None:
+    """
+    Transform the graph for tensor redistribution.
+    """
+    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
+    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
+    input_arg_tensor = input_arg.meta["val"]
+
+    # insert reshard operation
+    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
+        return redistribute_local_tensor(
+            local_tensor,
+            input_arg_spec,
+            desired_spec,
+        )
+
+    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
+    reshard_gm_nodes = list(reshard_gm.graph.nodes)
+    input_node = reshard_gm_nodes[0]
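+    # the first node of the traced reshard graph is its placeholder; map it to the
+    # existing input node so the copied subgraph consumes the original value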
+    with gm.graph.inserting_before(node):
+        output_node = gm.graph.graph_copy(
+            reshard_gm.graph,
+            val_map={
+                input_node: input_arg,
+            },
+        )
+    node.replace_input_with(input_arg, output_node)
+
+
+def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
+    """
+    Clean up the graph by removing sharding and partitioning related metadata
+    """
+    for node in gm.graph.nodes:
+        if "sharding" in node.meta:
+            del node.meta["sharding"]
+        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
+            node.meta["tensor_meta"] = local_tensor_meta
+
+
+def _get_input_node_specs(
+    node: Node, placement_strategies: Dict[Node, PlacementStrategy]
+) -> Tuple[DTensorSpec, ...]:
+    """
+    Get the input specs of a node.
+    """
+    input_specs_list: List[DTensorSpec] = []
+    for input_arg in node.all_input_nodes:
+        if input_arg in placement_strategies:
+            output_spec = placement_strategies[input_arg].output_specs
+            assert isinstance(output_spec, DTensorSpec)
+            input_specs_list.append(output_spec)
+        else:
+            raise ValueError(f"{input_arg} does not have output_spec populated.")
+    return tuple(input_specs_list)
+
+
+def _get_op_schema(
+    node: Node, placement_strategies: Dict[Node, PlacementStrategy]
+) -> OpSchema:
+    """
+    Util function to construct the operator schema of a node.
+    """
+    args_schema_list = pytree.tree_map_only(
+        Node, lambda arg: placement_strategies[arg].output_specs, node.args
+    )
+    op_schema = OpSchema(
+        op=cast(torch._ops.OpOverload, node.target),
+        args_schema=tuple(args_schema_list),
+        kwargs_schema=cast(Dict[str, object], node.kwargs),
+    )
+    return op_schema
+
+
+def _shard_state_dict(
+    state_dict: Dict[str, torch.Tensor],
+    placement_strategies: Dict[Node, PlacementStrategy],
+    graph_signature: ExportGraphSignature,
+    mesh: DeviceMesh,
+) -> None:
+    """
+    Inplace partition the weights based on the placement strategy
+    """
+    for node, placement_strategy in placement_strategies.items():
+        if node.op != "placeholder":
+            continue
+        if node.name in graph_signature.inputs_to_parameters:
+            fqn = graph_signature.inputs_to_parameters[node.name]
+        elif node.name in graph_signature.inputs_to_buffers:
+            fqn = graph_signature.inputs_to_buffers[node.name]
+        else:
+            continue
+        assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}"
+
+        original_param = state_dict[fqn]
+        dtensor_param = distribute_tensor(
+            original_param,
+            mesh,
+            placement_strategy.output_spec.placements,
+        )
+        local_param = dtensor_param.to_local()
+        state_dict[fqn] = (
+            torch.nn.Parameter(local_param)
+            if isinstance(original_param, torch.nn.Parameter)
+            else local_param
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/op_schema.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/op_schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ff8c6db88de8b3a0fbbeb3bd1f0c97cec77f5e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/op_schema.py
@@ -0,0 +1,427 @@
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch._ops import OpOverload
+from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.utils._cxx_pytree import tree_map_only, TreeSpec
+except ImportError:
+    from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
+        tree_map_only,
+        TreeSpec,
+    )
+
+
+# Common type aliases
+ArgsType = Tuple[object, ...]
+KwargsType = Dict[str, object]
+# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type
+# should cover the same set of possibilities.
+OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
+
+
+def _rebuild_tensor_from_dtensor_meta(arg) -> object:
+    """
+    This is used to propagate tensor metadata; it must be called under fake mode.
+    """
+    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
+    return torch.empty_strided(
+        arg.tensor_meta.shape,
+        arg.tensor_meta.stride,
+        dtype=arg.tensor_meta.dtype,
+    )
+
+
+def _is_inplace_op(op: OpOverload):
+    # simple analysis of the function schema to determine
+    # if this is an inplace variant; it might not
+    # be entirely correct, but it's good enough for now.
+    return op._schema.name[-1] == "_"
+
+
+def _is_out_variant_op(op: OpOverload):
+    # simple analysis of the function schema to determine
+    # if this is an out variant; it might not
+    # be entirely correct, but it's good enough for now.
+    return "out" in op._schema.overload_name
+
+
+def _pretty_print_spec(spec: object) -> str:
+    if spec is None:
+        return "None"
+    elif isinstance(spec, DTensorSpec):
+        return "".join([str(p) for p in spec.placements])
+    elif isinstance(spec, Sequence):
+        return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
+    else:
+        raise RuntimeError(f"Unknown spec type to print: spec={spec}")
+
+
+@dataclass
+class PlacementStrategy:
+    """
+    A placement strategy describes acceptable sharding placements of the output
+    and the tensor arguments of an operation.
+
+    note: when the op return value is a single DTensor object, output_specs is
+    DTensorSpec; when the return value is a tuple of Optional[DTensor],
+    output_specs is a tuple of Optional[DTensorSpec].
+    """
+
+    output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
+    input_specs: Optional[Sequence[DTensorSpec]] = None
+
+    # redistribute costs for this op placement strategy
+    # we need a nested list to record the cost for each
+    # operand of this operator, and for each operand of
+    # this operator it might have multiple placement strategies
+    redistribute_cost: Optional[List[List[float]]] = None
+
+    @cached_property
+    def output_spec(self) -> DTensorSpec:
+        """
+        This function requires that the strategy have exactly one DTensorSpec as the
+        output spec. If the output_specs is a tuple, we throw an exception.
+        """
+        if isinstance(self.output_specs, DTensorSpec):
+            return self.output_specs
+        else:
+            raise ValueError(
+                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
+            )
+
+    def input_spec(self, index: int = 0) -> DTensorSpec:
+        assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
+        assert len(self.input_specs) > index, (
+            f"Invalid index {index} for input_specs of length "
+            f"{len(self.input_specs)}: {self.input_specs}"
+        )
+        return self.input_specs[index]
+
+    def __str__(self) -> str:
+        input_specs_str = _pretty_print_spec(self.input_specs)
+        output_spec_str = _pretty_print_spec(self.output_specs)
+        return f"{input_specs_str} -> {output_spec_str}"
+
+
+class StrategyType:
+    """
+    Base class for op strategies. There are two StrategyTypes:
+        OpStrategy and TupleStrategy
+    """
+
+    pass
+
+
+class OpStrategy(StrategyType):
+    """
+    OpStrategy that consists of a list of placement strategies associated with the op
+    """
+
+    def __init__(self, strategies: List[PlacementStrategy]) -> None:
+        super().__init__()
+        self.strategies: List[PlacementStrategy] = strategies
+
+    def __str__(self) -> str:
+        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
+        mesh_shape = self.output_mesh_shape
+        return f"OpStrategy:[{strategy_list_str}] @ mesh: {mesh_shape}"
+
+    def max_num_shards(self) -> int:
+        """
+        Returns the max number of shards across all placement strategies
+        """
+        return max([strategy.output_spec.num_shards for strategy in self.strategies])
+
+    @property
+    def output_mesh_shape(self):
+        output_spec = self.strategies[0].output_specs
+        if isinstance(output_spec, DTensorSpec):
+            return output_spec.mesh.shape
+        else:
+            assert isinstance(
+                output_spec, tuple
+            ), "found no DTensorSpec in the OpStrategy!"
+            assert output_spec[0] is not None
+            return output_spec[0].mesh.shape
+
+    @property
+    def output_ndim(self):
+        return self.strategies[0].output_spec.ndim
+
+    @property
+    def output_shape(self):
+        return self.strategies[0].output_spec.shape
+
+
+class TupleStrategy(StrategyType):
+    """
+    TupleStrategy represents the case where the output strategy of an op is a tuple
+    of strategies, i.e. if the output of the op is a tuple or list of tensors with
+    possibly different placement strategies, we should return a TupleStrategy that
+    contains a tuple of OpStrategy, where each child represents the sharding strategy
+    of each element of the tuple/list of tensors the op returns.
+
+    NOTE: if the output of the op is a List[Tensor] and they share the same placement
+    strategy, then we should return a single OpStrategy instead of a TupleStrategy
+    """
+
+    def __init__(self, childs: Sequence[StrategyType]) -> None:
+        super().__init__()
+        self.childs: Sequence[StrategyType] = childs
+
+    def __str__(self) -> str:
+        child_strategies_str = ", ".join(
+            [f"{str(strat)}" for idx, strat in enumerate(self.childs)]
+        )
+        return f"TupleStrategy({child_strategies_str})"
+
+
+@dataclass
+class RuntimeSchemaInfo:
+    """
+    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
+    execution. This is mainly used for two ways: 1. to generate hash for args to determine
+    whether to re-run sharding prop or not 2. to determine if we need pytree
+    """
+
+    # This static_argnum records static arg "starting index" for ops that have non-tensor
+    # args/kwargs which would affect sharding propagation results. All args starting from
+    # this index would be hashed to our sharding cache.
+    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
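+    # e.g. a view-like op whose first arg is the tensor and whose remaining args are
+    # static sizes would typically be registered with static_argnum=1 (illustrative example)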
+    static_argnum: int = 100
+    # This static_kwargkey records static kwarg names which would affect sharding prop
+    static_kwargkey: Optional[List[str]] = None
+    # each op can decide if it wants to use pytree flatten/unflatten during operator
+    # eager execution. By default we don't do flatten/unflatten; we only do so if the
+    # op indicates it needs to. This is to accelerate eager performance.
+    needs_pytree: bool = False
+
+
+@dataclass
+class OpSchema:
+    """
+    OpSchema is a data class that describes an operator's input schema; it
+    includes the DTensorSpecs of the DTensor args and the non-tensor args/kwargs
+    (positional order preserved). It is mainly used by the dispatching logic below
+    to run things like sharding propagation.
+
+    NOTE: this should be used as a read only data class
+    TODO: make this a frozen dataclass
+
+    Args:
+        op: the operator overload we are intercepting
+        args_schema: contains the args, except that the DTensor args have been replaced
+            with their DTensorSpecs
+        kwargs_schema: contains the kwargs, except that the DTensor kwargs have been replaced
+            with their DTensorSpecs
+    """
+
+    op: OpOverload
+    args_schema: ArgsType
+    kwargs_schema: KwargsType
+
+    schema_info: Optional[RuntimeSchemaInfo] = None
+
+    @property
+    def args_spec(self) -> Tuple[DTensorSpec, ...]:
+        """
+        args_spec: Tuple[DTensorSpec, ...]: a clean tuple of arg specs with NO
+            non-DTensor positional arguments (i.e. int/float/tuple, etc.),
+            mainly used by sharding propagation to propagate the output spec
+        """
+        # filter out non-relevant values from args schema to get a clean spec list
+        # this would mainly be used by sharding propagation rules
+        return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))
+
+    def __repr__(self) -> str:
+        return (
+            f"OpSchema(op={self.op},"
+            f" args_schema={self.args_schema},"
+            f" kwargs_schema={self.kwargs_schema})"
+        )
+
+    def __str__(self) -> str:
+        args_sharding: List[str] = []
+        mesh_shape = None
+        for arg in self.args_schema:
+            if isinstance(arg, DTensorSpec):
+                args_sharding.append(str(arg))
+                mesh_shape = arg.mesh.shape
+            elif isinstance(arg, OpStrategy):
+                assert len(arg.strategies) == 1
+                args_sharding.append(_pretty_print_spec(arg.strategies[0].output_specs))
+                mesh_shape = arg.output_mesh_shape
+            elif isinstance(arg, TupleStrategy):
+                first_op_strtgy = arg.childs[0]
+                assert isinstance(first_op_strtgy, OpStrategy)
+                mesh_shape = first_op_strtgy.output_mesh_shape
+                args_sharding.append(str(arg))
+            else:
+                args_sharding.append(str(arg))
+        return f"Op(op={self.op}, args_sharding={', '.join(args_sharding)} @ mesh: {mesh_shape})"
+
+    def __post_init__(self) -> None:
+        has_symints = False
+        for a in self.args_schema:
+            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
+                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
+                    has_symints = True
+                    break
+        self.has_symints = has_symints
+
+    def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
+        arg = self.args_schema[arg_idx]
+        is_tensor = isinstance(arg, DTensorSpec)
+        if is_tensor:
+            return True
+
+        if not isinstance(arg, list):
+            return False
+
+        return all(isinstance(e, DTensorSpec) or e is None for e in arg)
+
+    def return_type_tuple_tensor_like(self) -> bool:
+        # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
+        # in the tuple, but the first element must be a Tensor, so this check is enough
+        return_types = self.op._schema.returns
+        return len(return_types) > 1 and isinstance(
+            return_types[0].type, torch.TensorType
+        )
+
+    def return_type_tensor(self) -> bool:
+        return_types = self.op._schema.returns
+        # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
+        # return types, so this check is enough for tensor like types
+        return isinstance(return_types[0].type, torch.TensorType)
+
+    def __hash__(self) -> int:
+        # Only hash args and kwargs that op indicates to hash
+        if not self.schema_info:
+            static_argnum = len(self.args_schema)
+            static_kwargkey = None
+        else:
+            static_argnum = self.schema_info.static_argnum
+            static_kwargkey = self.schema_info.static_kwargkey
+
+        args_to_hash = tuple(
+            tuple(e) if isinstance(e, list) else e
+            for i, e in enumerate(self.args_schema)
+            if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
+        )
+        if static_kwargkey is not None:
+            kwargs_to_hash = tuple(
+                self.kwargs_schema.get(k, None) for k in static_kwargkey
+            )
+            return hash((self.op, args_to_hash, kwargs_to_hash))
+        else:
+            return hash((self.op, args_to_hash))
+
+    def __eq__(self, other: object) -> bool:
+        # early return checks
+        if not isinstance(other, OpSchema):
+            return False
+
+        if self.op != other.op:
+            return False
+
+        if len(self.args_schema) != len(other.args_schema):
+            return False
+
+        # compare each element and early return if any of them is different
+        if not self.schema_info:
+            static_argnum = len(self.args_schema)
+            static_kwargkey = None
+        else:
+            static_argnum = self.schema_info.static_argnum
+            static_kwargkey = self.schema_info.static_kwargkey
+
+        for i, (self_arg, other_arg) in enumerate(
+            zip(self.args_schema, other.args_schema)
+        ):
+            if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
+                return False
+            elif i >= static_argnum and self_arg != other_arg:
+                return False
+
+        # check kwarg equality when there's a static kwarg key
+        if static_kwargkey:
+            for key in static_kwargkey:
+                if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
+                    key, None
+                ):
+                    return False
+
+        return True
+
+    def gen_fake_args(self) -> ArgsType:
+        """
+        gen_fake_args: generate fake args for the operator; this is mainly used
+            by sharding propagation rules to run the local tensor operator and
+            get the output spec.
+        """
+        return tree_map_only(
+            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
+        )
+
+    def gen_fake_kwargs(self) -> KwargsType:
+        """
+        gen_fake_kwargs: generate fake kwargs for the operator; this is mainly used
+            by sharding propagation rules to run the local tensor operator and
+            get the output spec.
+        """
+        return tree_map_only(
+            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
+        )
+
+    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
+        suggestion_args_spec = self.args_spec
+        new_arg_schema: List[object] = []
+        idx_of_args_spec = 0
+        for arg in origin_schema.args_schema:
+            if isinstance(arg, DTensorSpec):
+                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
+                idx_of_args_spec += 1
+            else:
+                new_arg_schema.append(arg)
+        self.args_schema = tuple(new_arg_schema)
+        self.kwargs_schema = origin_schema.kwargs_schema
+
+
+@dataclass
+class OutputSharding:
+    """
+    OutputSharding is a data class used by the sharding propagation rules. It
+    sets the output_spec upon successful propagation; if propagation failed,
+    output_spec becomes None and the sharding propagation rules can give a list
+    of suggestions for how to reshard the inputs.
+
+    NOTE: the schema_suggestions generated by sharding propagation should be
+    exactly the same as the operator's OpSchema, except for the DTensorSpecs of the DTensor args.
+    """
+
+    output_spec: OutputSpecType
+    schema_suggestions: Optional[List[OpSchema]] = None
+    failed_reason: Optional[str] = None
+    needs_redistribute: bool = False
+
+
+@dataclass
+class OpInfo:
+    """
+    All runtime op execution info is packed here
+    """
+
+    mesh: DeviceMesh
+    schema: OpSchema
+    flat_args_schema: List[object]
+    local_args: Sequence[object]
+    local_kwargs: Dict[str, object]
+    args_tree_spec: Optional[TreeSpec] = None
+
+    # the output sharding info
+    output_sharding: Optional[OutputSharding] = None
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9112fa0b61e8bf65a5fc7076fb38bf00c8fcc62f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from .embedding_ops import *  # noqa: F403
+from .matrix_ops import *  # noqa: F403
+from .math_ops import *  # noqa: F403
+from .tensor_ops import *  # noqa: F403
+from .pointwise_ops import *  # noqa: F403
+from .random_ops import *  # noqa: F403
+from .view_ops import *  # noqa: F403
+from .conv_ops import *  # noqa: F403
+from .experimental_ops import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7f704e9afe04142b176c5e9c34c3d0219ec93a6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/basic_strategy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/basic_strategy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8139bf72256c9a508f394cc5927329a42852b88
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/basic_strategy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/common_rules.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/common_rules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28a276c5632895e6e5c3c6c72676cd81f09672af
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/common_rules.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/conv_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/conv_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6bd98092828a1470515d2c253a3c21d6963d6b4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/conv_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/embedding_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/embedding_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c73017724fb778dc83fc06cd91f75ab79adc392
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/embedding_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/experimental_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/experimental_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0922ba345ac4389d1eabd0b88691e9f51b381ef0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/experimental_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/math_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/math_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec85984ef9d3b488a6b8e5daa16f7b70a9c969f7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/math_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/matrix_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/matrix_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..365099bc3b6bfcefa6de29d6bad77ea00247ff92
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/matrix_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/pointwise_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/pointwise_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da315038137f27976d39e6270b77a90fc95bbc8a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/pointwise_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/random_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/random_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..993c31f8da2f4c27211d87134034acea5d0a8562
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/random_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/tensor_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/tensor_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..678e418acdf02dd422c2b68d9ebcf6a3a659bce9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/tensor_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13e4fb2ff4643509626dc9b5c7e3469a2d29e237
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/view_ops.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/view_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ace9bf611ec98e31822368cba9d4b616fff9562
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/__pycache__/view_ops.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/basic_strategy.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/basic_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3ea85fd3f1551669d8298090fc6e57bae1fbf8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/basic_strategy.py
@@ -0,0 +1,184 @@
+import itertools
+from dataclasses import dataclass
+
+from typing import List, Tuple
+
+from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+from torch.distributed.device_mesh import DeviceMesh
+
+
+@dataclass
+class EinsumDims:
+    contracting_dims: List[str]
+    batch_dims: List[str]
+    lhs_out_only_dims: List[str]
+    rhs_out_only_dims: List[str]
+
+    @classmethod
+    def parse_equation(cls, equation: str) -> Tuple[List[str], str]:
+        # parse einop equation and extract arg specs
+        """
+        Parse the einsum equation str to input dim chars and output dim char
+        """
+        inputs, outputs = equation.split("->")
+        input_dims, output_dims = inputs.split(","), outputs.split(",")
+
+        # NOTE: only support at most two inputs, and single output
+        # extend to support more inputs if needed in future
+        assert len(input_dims) <= 2, "Only support at most two inputs"
+        assert len(output_dims) == 1, "Only support single output"
+        output_dim = output_dims[0]
+        return input_dims, output_dim
+
+    @classmethod
+    def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
+        """
+        Parse the dims and extract the contracting, batch, and free dimensions
+        for the left and right hand sides.
+        """
+        dim_char_set = set()
+        for input_dim in input_dims:
+            for input_char in list(input_dim):
+                dim_char_set.add(input_char)
+
+        # get a deterministic order of all dim chars
+        all_dim_chars = sorted(dim_char_set)
+
+        # parse input and output dimensions
+        lhs_out_only_dims, rhs_out_only_dims = [], []
+        batch_dims, contracting_dims = [], []
+
+        for dim_char in all_dim_chars:
+            if dim_char not in output_dim:
+                contracting_dims.append(dim_char)
+            else:
+                is_batch_dim = True
+                for input_dim in input_dims:
+                    is_batch_dim = is_batch_dim and dim_char in input_dim
+
+                if is_batch_dim:
+                    batch_dims.append(dim_char)
+                else:
+                    assert (
+                        len(input_dims) == 2
+                    ), "free dimension only supported for two inputs!"
+                    lhs, rhs = input_dims
+                    if dim_char in lhs:
+                        lhs_out_only_dims.append(dim_char)
+                    elif dim_char in rhs:
+                        rhs_out_only_dims.append(dim_char)
+                    else:
+                        raise RuntimeError("Invalid dimension character")
+
+        return cls(
+            contracting_dims=contracting_dims,
+            batch_dims=batch_dims,
+            lhs_out_only_dims=lhs_out_only_dims,
+            rhs_out_only_dims=rhs_out_only_dims,
+        )
+
+
+def gen_einsum_strategies(
+    equation: str,
+    mesh: DeviceMesh,
+    *,
+    linearity: bool = False,
+) -> OpStrategy:
+    """
+    Generate a strategy list for the ops that follow einsum style notation.
+    """
+    # parse einop equation and extract dims
+    input_dims, output_dim = EinsumDims.parse_equation(equation)
+    edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+    all_mesh_dim_strategies = []
+
+    # generate strategies for each mesh dim
+    for mesh_dim in range(mesh.ndim):
+        mesh_dim_strategies = []
+
+        # placement list stores placements of [output, input1, input2, ...]
+        # first we always have replicate all for inputs and output
+        placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
+        mesh_dim_strategies.append(placement_list)
+
+        if mesh.size(mesh_dim) <= 1:
+            # only replicate strategy for mesh dim with size 1
+            # TODO: see if this is valid for the submesh case
+            continue
+
+        # split batch dim
+        for batch_dim in edims.batch_dims:
+            output_batch_dim = output_dim.index(batch_dim)
+            placement_list = [Shard(output_batch_dim)]
+            for input_dim in input_dims:
+                input_batch_dim = input_dim.index(batch_dim)
+                placement_list.append(Shard(input_batch_dim))
+
+            mesh_dim_strategies.append(placement_list)
+
+        # split contracting dim
+        for contracting_dim in edims.contracting_dims:
+            placement_list = [_Partial()]
+            for input_dim in input_dims:
+                input_contracting_dim = input_dim.index(contracting_dim)
+                placement_list.append(Shard(input_contracting_dim))
+
+            mesh_dim_strategies.append(placement_list)
+
+        # split lhs free dim
+        for lhs_dim in edims.lhs_out_only_dims:
+            lhs_free_dim = output_dim.index(lhs_dim)
+            # this means split the lhs input and output
+            # i.e. S(0), R -> S(0)
+            lhs_placement_list: List[Placement] = [
+                Shard(lhs_free_dim),
+                Shard(lhs_free_dim),
+                Replicate(),
+            ]
+            mesh_dim_strategies.append(lhs_placement_list)
+
+        # split rhs free dim
+        for rhs_dim in edims.rhs_out_only_dims:
+            rhs_free_dim = output_dim.index(rhs_dim)
+            rhs_placement_list: List[Placement] = [
+                Shard(rhs_free_dim),
+                Replicate(),
+                Shard(rhs_free_dim),
+            ]
+            mesh_dim_strategies.append(rhs_placement_list)
+
+        # linearity strategy
+        if linearity:
+            linearity_placement_list: List[Placement] = [_Partial()]
+            for input_dim in input_dims:
+                linearity_placement_list.append(_Partial())
+            mesh_dim_strategies.append(linearity_placement_list)
+
+        all_mesh_dim_strategies.append(mesh_dim_strategies)
+
+    # generate strategies for entire mesh
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    # TODO: filter out invalid strategies. At this point we generate
+    # all possible strategies without considering whether the tensor
+    # dim could be sharded or not; we would need to filter out invalid
+    # strategies based on the actual tensor shape
+    # (i.e. for Shard, the tensor dim size must be > mesh size)
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+        strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
+        all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/common_rules.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/common_rules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4add50dee169646009c706316a531070570913a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/common_rules.py
@@ -0,0 +1,289 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import cast, Dict, List, Optional, Tuple
+
+import torch
+from torch.distributed._tensor._utils import compute_local_shape
+from torch.distributed._tensor.op_schema import (
+    _is_inplace_op,
+    _is_out_variant_op,
+    OpSchema,
+    OutputSharding,
+)
+from torch.distributed._tensor.ops.utils import prod
+from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+
+
+def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
+    return string[:idx] + new_char + string[idx + 1 :]
+
+
+def _gen_reshard_suggestions(
+    op_schema: OpSchema,
+    input_dims: List[str],
+    input_specs: Tuple[DTensorSpec, ...],
+    dim_to_sharding: Dict[str, int],
+    pending_sum: List[int],
+) -> OutputSharding:
+    suggested_arg_specs: List[DTensorSpec] = []
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        dim_map = [dim_to_sharding[dim] for dim in input_dim]
+        suggested_arg_specs.append(
+            DTensorSpec.from_dim_map(
+                mesh=input_spec.mesh,
+                dim_map=dim_map,
+                sums=pending_sum,
+                tensor_meta=input_spec.tensor_meta,
+            )
+        )
+    suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
+    suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
+    return OutputSharding(
+        None,
+        schema_suggestions=[suggested_schema],
+        failed_reason="Input placements op sharding propagation failed, need to reshard!",
+    )
+
+
+def einop_rule(
+    equation: str,
+    op_schema: OpSchema,
+    *,
+    linearity: bool = False,
+    enforce_sharding: Optional[Dict[str, int]] = None,
+) -> OutputSharding:
+    """
+    Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
+
+    This is mostly borrowed from @zdevito's sharding simulator. Examples:
+        mk,kn->mn - einsum
+        ij,ij->ij - addition
+        ij,j->ij - broadcasted addition
+        ij->i - reduction
+    Other ops could use this propagation algorithm when applicable; note
+    that einsum propagation only deals with lists of specs (DTensor specs)
+    as it only works on lists of tensors!
+
+    linearity in einop_rule means that the calling op `f` follows this rule:
+        f(a + b) = f(a) + f(b)
+
+    In this case we can propagate the partial sum; note that linearity in einop
+    only applies to partial sum, not other operations like min/max (which are
+    associative but not linear).
+    """
+    # parse einop equation and extract arg specs
+    inputs, outputs = equation.split("->")
+    input_dims, output_dims = inputs.split(","), outputs.split(",")
+    input_specs = op_schema.args_spec
+    # NOTE: only support single output unless needed in future
+    output_dim = output_dims[0]
+
+    dim_to_sharding: Dict[str, int] = {}
+    dim_to_size: Dict[str, int] = {}
+    # record pending sum, key is mesh dimension, value is pending sum
+    # counter across input specs
+    pending_sums_counter: Dict[int, int] = {}
+    seen_shardings: Dict[int, str] = {}
+    needs_reshard = False
+
+    def merge_sharding(dim: str, a: int, b: int) -> int:
+        # merge the sharding of inputs if they are mergeable, i.e. we can merge
+        # replicate and shard to shard, but this will trigger a reshard operation
+        if a != b:
+            if a == -1 or b == -1:
+                # reshard the replicate to match the sharded one
+                nonlocal needs_reshard
+                needs_reshard = True
+                return a if a != -1 else b
+            else:
+                # TODO: further merge the sharding properly (i.e. reshard one input to replicate)
+                raise RuntimeError(
+                    f"{equation}: dim {dim} sharded two different ways: {a} and {b}"
+                )
+        else:
+            return a
+
+    for input_dim, input_spec in zip(input_dims, input_specs):
+        # deal with partial sums
+        input_sums = input_spec.sums
+        for sum_dim in input_sums:
+            if sum_dim not in pending_sums_counter:
+                seen_shardings[sum_dim] = "+"
+            # update pending sum counter for pending sum mesh
+            # dimension with the occurrence from each input
+            pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1
+
+        for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)):
+            if enforce_sharding and dim in enforce_sharding:
+                if enforce_sharding[dim] != mesh_dim:
+                    needs_reshard = True
+                dim_to_sharding[dim] = enforce_sharding[dim]
+                dim_to_size[dim] = input_spec.shape[idx]
+            elif dim not in dim_to_sharding:
+                dim_to_sharding[dim] = mesh_dim
+                dim_to_size[dim] = input_spec.shape[idx]
+            else:
+                dim_to_sharding[dim] = merge_sharding(
+                    dim, dim_to_sharding[dim], mesh_dim
+                )
+                assert dim_to_size[dim] == input_spec.shape[idx]
+
+            # after merging the sharding, we check if there are multiple
+            # shardings on the same mesh dim.
+            merged_sharding_for_dim = dim_to_sharding[dim]
+            if merged_sharding_for_dim != -1:
+                if (
+                    merged_sharding_for_dim in seen_shardings
+                    and dim != seen_shardings[merged_sharding_for_dim]
+                ):
+                    needs_reshard = True
+                    seen_shardings[merged_sharding_for_dim] += dim
+                else:
+                    seen_shardings[merged_sharding_for_dim] = dim
+
+    if pending_sums_counter and not linearity:
+        # return a reshard suggestion with no pending sum; because we already properly
+        # merged the sharding, this reshard suggestion is legit to use
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, []
+        )
+    else:
+        # It's an op that supports linearity, but not all input arguments are partial;
+        # we fail the sharding propagation with a suggestion to make all inputs
+        # partial on the corresponding mesh dim (all inputs should be partial for
+        # the mesh dims in order to execute locally and delay the sum reduction)
+        for value in pending_sums_counter.values():
+            if value != len(input_specs):
+                needs_reshard = True
+
+    for mesh_dim, dims in seen_shardings.items():
+        if len(dims) > 1:
+            # we found different input dims being sharded on the same mesh dim;
+            # in order to perform local op computation, we need to reshard inputs
+            # based on some simple heuristics. For now we simply pick the one with the
+            # least comm volume (i.e. the input with the least size)
+            # TODO: consider a more advanced heuristic to pick the best sharding
+            costs = []
+            for d in dims:
+                cost = 0
+                for input_dim, input_spec in zip(input_dims, input_specs):
+                    if (
+                        d in input_dim
+                        and input_spec.dim_map[input_dim.index(d)] == mesh_dim
+                    ):
+                        assert input_spec.tensor_meta is not None
+                        global_shape = input_spec.tensor_meta.shape
+                        local_shape = compute_local_shape(
+                            global_shape, input_spec.mesh, input_spec.placements
+                        )
+                        cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
+                costs.append(cost)
+            d_to_keep_sharding = dims[costs.index(max(costs))]
+            for d in dims:
+                # update dim_to_sharding to keep the sharding of the dim with the
+                # highest comm cost and make the rest of the dims replicate
+                if d != d_to_keep_sharding:
+                    dim_to_sharding[d] = -1
+
+    pending_sums = list(pending_sums_counter.keys())
+    if needs_reshard:
+        return _gen_reshard_suggestions(
+            op_schema, input_dims, input_specs, dim_to_sharding, pending_sums
+        )
+
+    # generate output pending sum if a dim is sharded, and it appears in input
+    # but not output
+    for dim, shard_on_mesh in dim_to_sharding.items():
+        if dim not in output_dims[0] and shard_on_mesh != -1:
+            pending_sums.append(shard_on_mesh)
+
+    # if no need to reshard, we directly generate the output sharding
+    output_dim_map = []
+    output_shape = []
+    for dim in output_dim:
+        if dim == "1":
+            # find output dim that is a singleton dimension, mark sharding and shape
+            output_dim_map.append(-1)
+            output_shape.append(1)
+        else:
+            output_dim_map.append(dim_to_sharding[dim])
+            output_shape.append(dim_to_size[dim])
+
+    # XXX: since we still need to have intermediate shape calculation, we need
+    # to pass in the shape here. We should remove this once sharding decomp works
+    # for ops like addmm
+    assert input_specs[0].tensor_meta is not None
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        input_specs[0].tensor_meta.stride,
+        input_specs[0].tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_specs[0].mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
+
+
+def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
+    """
+    Propagate the sharding for pointwise operations.
+
+    Examples:
+        ij,ij->ij - addition/mul
+        ij,j->ij - broadcasted addition
+    """
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    # find the max_dim first in case we need broadcasting
+    input_specs = op_schema.args_spec
+    max_dim = max(input.ndim for input in input_specs)
+    dimchars = []
+    singleton_counter: List[int] = [0] * max_dim
+    for input in input_specs:
+        start_dim = max_dim - input.ndim
+        p = alphabet[start_dim:max_dim]
+        # handle the "broadcasting to a common shape case"
+        # see https://pytorch.org/docs/stable/notes/broadcasting.html
+        # If any of the dimensions is singleton dimension (i.e. 1).
+        # we mark the dim char as a special "1" to distinguish with
+        # the non-singleton dimension, so that sharding propagation
+        # should just ignore the singleton dimension.
+        if len(input_specs) > 1:
+            for i in range(max_dim):
+                if i < start_dim:
+                    # treat the leading missing dim chars as singleton
+                    singleton_counter[i] += 1
+                elif input.shape[i - start_dim] == 1:
+                    # mark singleton dim char as a special "1" in einop rule
+                    singleton_counter[i] += 1
+                    p = _replace_char_in_str(p, "1", (i - start_dim))
+
+        dimchars.append(p)
+    out_dimchars = alphabet[:max_dim]
+    # check if we replaced the dim char of all inputs with the singleton dimension;
+    # if we replaced all inputs, we also need to replace the output dimension.
+    for output_dim_idx in range(len(out_dimchars)):
+        out_dimchar = out_dimchars[output_dim_idx]
+        if singleton_counter[output_dim_idx] == len(input_specs):
+            out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx)
+
+    fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
+
+    enforce_sharding: Dict[str, int] = {}
+    if _is_inplace_op(op_schema.op):
+        # inplace op should keep the input sharding it writes to
+        for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):
+            enforce_sharding[out_dimchar] = mesh_dim
+    elif _is_out_variant_op(op_schema.op):
+        out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
+        for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map):
+            enforce_sharding[out_dimchar] = mesh_dim
+
+    return einop_rule(
+        fmt,
+        op_schema,
+        linearity=linearity,
+        enforce_sharding=enforce_sharding,
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/conv_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/conv_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f901f4d2edc7ac52aa14d3c966a9ce42c41544
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/conv_ops.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement convolution related ops for distributed tensor
+from typing import List
+
+import torch
+from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.ops.utils import register_prop_rule
+from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+
+aten = torch.ops.aten
+
+
+@register_prop_rule(aten.convolution.default)
+def convolution_rules(op_schema: OpSchema) -> OutputSharding:
+    (
+        input_spec,
+        weight_spec,
+        bias_spec,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    ) = op_schema.args_schema
+
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    assert isinstance(bias_spec, DTensorSpec)
+    assert input_spec.tensor_meta is not None
+    assert weight_spec.tensor_meta is not None
+    in_shape = input_spec.tensor_meta.shape
+    weight_shape = weight_spec.tensor_meta.shape
+    assert isinstance(stride, List)
+    assert isinstance(padding, List)
+    assert isinstance(dilation, List)
+    assert isinstance(weight_shape, torch.Size)
+    N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
+    C_out = weight_shape[0]
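+    # standard conv2d output size: floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1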
+    H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
+        0
+    ] + 1
+    W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[
+        1
+    ] + 1
+    output_shape = [N, C_out, H_out, W_out]
+    output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1)
+    output_dim_map = input_spec.dim_map
+    pending_sums = input_spec.sums
+
+    tensor_meta = TensorMeta(
+        torch.Size(output_shape),
+        output_stride,
+        input_spec.tensor_meta.dtype,
+    )
+    return OutputSharding(
+        DTensorSpec.from_dim_map(
+            input_spec.mesh,
+            output_dim_map,
+            pending_sums,
+            tensor_meta=tensor_meta,
+        )
+    )
+
+
+@register_prop_rule(aten.convolution_backward.default)
+def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
+    input_spec = op_schema.args_schema[0]
+    (
+        grad_output_spec,
+        input_spec,
+        weight_spec,
+        bias_shape_opt,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        output_mask,
+    ) = op_schema.args_schema
+
+    assert isinstance(grad_output_spec, DTensorSpec)
+    assert isinstance(input_spec, DTensorSpec)
+    assert isinstance(weight_spec, DTensorSpec)
+    assert isinstance(bias_shape_opt, List)
+    assert input_spec.tensor_meta is not None
+    weight_tensor_meta = weight_spec.tensor_meta
+    bias_tensor_meta = TensorMeta(
+        torch.Size(bias_shape_opt),
+        (1,),
+        input_spec.tensor_meta.dtype,
+    )
+
+    grad_input_spec = input_spec
+    grad_weight_spec = DTensorSpec.from_dim_map(
+        input_spec.mesh,
+        [-1, -1, -1, -1],
+        [0],
+        tensor_meta=weight_tensor_meta,
+    )
+    grad_bias_spec = DTensorSpec.from_dim_map(
+        input_spec.mesh,
+        [-1],
+        [0],
+        tensor_meta=bias_tensor_meta,
+    )
+    return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/embedding_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/embedding_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a8dcd9679884c10851dd69db5bbde52598ed8fd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/embedding_ops.py
@@ -0,0 +1,313 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement embedding related ops for distributed tensor
+import itertools
+from dataclasses import dataclass, field
+from typing import cast, List, Optional
+
+import torch
+import torch.distributed._functional_collectives as funcol
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementStrategy,
+    StrategyType,
+)
+from torch.distributed._tensor.ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_shardable,
+    register_op_strategy,
+)
+
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+from torch.distributed.device_mesh import DeviceMesh
+
+aten = torch.ops.aten
+
+
+@dataclass
+class MaskBuffer:
+    data: Optional[torch.Tensor] = None
+
+    def materialize_mask(self, mask):
+        if self.data is not None:
+            raise RuntimeError("MaskBuffer has already been materialized")
+        self.data = mask
+
+    def release_mask(self):
+        # TODO: evaluate if we need to release the mask buffer or the buffer
+        # can just have the same lifetime as the _Partial placement
+        if self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+        self.data = None
+
+    def apply_mask(self, tensor):
+        if self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+
+        # NOTE: _MaskPartial is being used by the embedding op and the gather op.
+        # For gather, the mask has the same dimension as the output tensor, whereas
+        # the output of the embedding op has an additional dimension compared to the input,
+        # hence the output masking logic below has two different cases.
+        if tensor.ndim == self.data.ndim:
+            tensor[self.data] = 0.0
+        else:
+            tensor[self.data, :] = 0.0
+
+
+@dataclass(frozen=True)
+class _MaskPartial(_Partial):
+    """
+    A partial mask placement devised for the rowwise sharded embedding op, where we need
+    to mask and adjust the indices to the local embedding shard; embedding masking
+    is a special type of the Partial placement.
+
+    NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
+    lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
+    """
+
+    logical_dim_size: int = -1
+    mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # override parent logic to perform partial mask for embedding
+        num_chunks = mesh.size(mesh_dim)
+        # get local shard size and offset on the embedding_dim
+        local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
+            self.logical_dim_size,
+            num_chunks,
+            mesh.get_local_rank(mesh_dim),
+            return_offset=True,
+        )
+        # Build the input mask and save it for the current partial placement;
+        # this is so that the output of the embedding op can reuse the mask saved on
+        # the same partial placement to perform mask + reduction
+        mask = (tensor < local_offset_on_dim) | (
+            tensor >= local_offset_on_dim + local_shard_size
+        )
+        # mask the input tensor
+        masked_tensor = tensor.clone() - local_offset_on_dim
+        masked_tensor[mask] = 0
+        # materialize the mask buffer to be used for reduction
+        self.mask_buffer.materialize_mask(mask)
+        return masked_tensor
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # by the time we need the reduction, we should have already saved the mask
+        assert self.mask_buffer.data is not None
+
+        # apply the mask to the tensor that is pending reduction
+        self.mask_buffer.apply_mask(tensor)
+
+        # clear the mask buffer
+        self.mask_buffer.release_mask()
+
+        # perform sum reduction
+        return funcol.all_reduce(
+            tensor, reduceOp=self.reduce_op.name, group=(mesh, mesh_dim)
+        )
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # by the time we need reduction, we should have already saved the mask
+        assert self.mask_buffer.data is not None
+
+        # apply the mask to the tensor that is pending reduction
+        self.mask_buffer.apply_mask(tensor)
+
+        # clear the mask buffer
+        self.mask_buffer.release_mask()
+
+        # call reduce_shard_tensor of the shard_spec.
+        shard_spec = cast(Shard, shard_spec)
+        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _MaskPartial):
+            return False
+
+        # if either data is not None, we invalidate the sharding cache, as this indicates
+        # the current MaskPartial placement is still in use and should not be used for a cache hit.
+        if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
+            return False
+
+        return (
+            self.reduce_op == other.reduce_op
+            and self.logical_dim_size == other.logical_dim_size
+        )
+
+    def __hash__(self) -> int:
+        return 1 + hash(
+            (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+        )
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the MaskPartial placement
+        """
+        return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+
+    def __str__(self) -> str:
+        """
+        human readable representation of the MaskPartial placement
+        """
+        return "MaskP"
+
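+# Illustrative sketch (assuming a 1-D mesh with 2 ranks and logical_dim_size=8, so rank 0
+# owns embedding rows [0, 4) and rank 1 owns rows [4, 8)): for indices = tensor([1, 5]),
+#   rank 0: mask = [False, True], masked local indices = [1, 0]
+#   rank 1: mask = [True, False], masked local indices = [0, 1]
+# Each rank looks up its local weight shard with the masked indices, apply_mask() zeroes
+# the rows that were out of range locally, and the all_reduce in _reduce_value sums the
+# two partial outputs into the final embedding result.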
+
+@register_op_strategy(aten.embedding.default)
+def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """
+    This strategy handles the embedding op. For each mesh dimension we enumerate
+    replication, colwise sharding, rowwise sharding (via _MaskPartial), and batch-dim
+    sharding.
+    """
+    weight_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
+
+    weight_shape = weight_strategy.output_shape
+    indices_shape = indices_strategy.output_shape
+    output_emd_dim = len(indices_shape)
+
+    all_mesh_dim_strategies = []
+
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # placement list stores placements of [output, weight, input_indices]
+        # first we always have replicate all for inputs and output
+        all_replicate: List[Placement] = [Replicate()] * 3
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate
+        colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()]
+        single_mesh_dim_strategies.append(colwise_sharding)
+
+        # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
+        embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+
+        # NOTE: we want to reuse the same mask partial placement so that we can reuse the
+        # same mask generated from the input indices for the output reduction
+        rowwise_sharding = [
+            embedding_partial_placement,
+            Shard(0),
+            embedding_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(rowwise_sharding)
+
+        # batch dim sharding, weight replicated, input can shard on any dim, output follows input
+        for input_dim in range(len(indices_shape)):
+            batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)]
+            single_mesh_dim_strategies.append(batch_sharding)
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
+        if is_tensor_shardable(weight_shape, spec_list[1]) and is_tensor_shardable(
+            indices_shape, spec_list[2]
+        ):
+            # only add to the strategy list when both weight and indices are shardable
+            weight_spec, indices_spec = spec_list[1:]
+            redistribute_cost = [
+                generate_redistribute_costs(weight_strategy, weight_spec),
+                generate_redistribute_costs(indices_strategy, indices_spec),
+            ]
+            strat = PlacementStrategy(
+                output_specs=spec_list[0],
+                input_specs=spec_list[1:],
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
+
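+# Illustrative sketch (hypothetical sizes): for a 2-D mesh and 2-D indices, each mesh
+# dimension contributes 5 candidate placement lists above (replicate-all, colwise,
+# rowwise, and one batch sharding per indices dim), so itertools.product enumerates
+# 5 * 5 = 25 combinations before the shardability check filters out the invalid ones.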
+
+@register_op_strategy(aten.embedding_dense_backward.default)
+def embedding_dense_backward_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> StrategyType:
+    """
+    This strategy handles the embedding dense backward op. We have two possible
+    embedding shardings: rowwise and colwise.
+    # TODO: implement rowwise sharding backward
+    """
+    grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
+
+    grad_out_shape = grad_out_strategy.output_shape
+    indices_shape = indices_strategy.output_shape
+    grad_out_ndim = len(grad_out_shape)
+
+    all_mesh_dim_strategies = []
+
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # placement list stores placements of [weight_grad (output), grad_out, input_indices]
+        # first we always have replicate all for inputs and output
+        all_replicate: List[Placement] = [Replicate()] * 3
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # colwise sharding backward, grad_out shard on last dim, input replicate,
+        # weight grad shard colwise
+        colwise_sharding = [Shard(1), Shard(grad_out_ndim - 1), Replicate()]
+        single_mesh_dim_strategies.append(colwise_sharding)
+
+        # batch dim sharding: grad_out/input share the same sharding, which can be on
+        # any input dim, and the weight grad becomes partial
+        for input_dim in range(len(indices_shape)):
+            batch_sharding = [_Partial(), Shard(input_dim), Shard(input_dim)]
+            single_mesh_dim_strategies.append(batch_sharding)
+
+        # grad_out partial, input replicate, weight grad keep partial
+        partial_sharding = [_Partial(), _Partial(), Replicate()]
+        single_mesh_dim_strategies.append(partial_sharding)
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
+        if is_tensor_shardable(grad_out_shape, spec_list[1]) and is_tensor_shardable(
+            indices_shape, spec_list[2]
+        ):
+            # only add to the strategy list when both grad_out and indices are shardable
+            grad_out_spec, indices_spec = spec_list[1:]
+            redistribute_cost = [
+                generate_redistribute_costs(grad_out_strategy, grad_out_spec),
+                generate_redistribute_costs(indices_strategy, indices_spec),
+            ]
+            strat = PlacementStrategy(
+                output_specs=spec_list[0],
+                input_specs=spec_list[1:],
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/experimental_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/experimental_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c011ba28381280556ac03a923ddf5ddd95c14137
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/experimental_ops.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement experimental ops (slice_backward, bernoulli) for distributed tensor
+from typing import List
+
+try:
+    import numpy as np
+except ModuleNotFoundError:
+    np = None  # type: ignore[assignment]
+
+import torch
+from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.ops.utils import register_prop_rule
+from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
+
+aten = torch.ops.aten
+
+
+@register_prop_rule(aten.slice_backward.default)
+def slice_backward_rules(op_schema: OpSchema) -> OutputSharding:
+    grad_output_spec, input_sizes, dim, start, end, step = op_schema.args_schema
+    assert isinstance(grad_output_spec, DTensorSpec)
+    assert isinstance(input_sizes, List)
+    assert grad_output_spec.tensor_meta is not None
+    grad_input_stride = list(np.cumprod(input_sizes[::-1])[:-1][::-1])
+    grad_input_stride.append(1)
+    dim_map = grad_output_spec.dim_map
+    sums = grad_output_spec.sums
+
+    grad_input_tensor_meta = TensorMeta(
+        torch.Size(input_sizes),
+        tuple(grad_input_stride),
+        grad_output_spec.tensor_meta.dtype,
+    )
+    grad_input_spec = DTensorSpec.from_dim_map(
+        grad_output_spec.mesh,
+        dim_map,
+        sums,
+        tensor_meta=grad_input_tensor_meta,
+    )
+
+    return OutputSharding(grad_input_spec)
+
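+# Illustrative sketch: for input_sizes = [2, 3, 4], input_sizes[::-1] is [4, 3, 2], its
+# cumprod is [4, 12, 24], dropping the last entry and reversing gives [12, 4], and
+# appending 1 yields [12, 4, 1], i.e. the contiguous strides that grad_input_stride
+# reconstructs for the grad_input tensor meta.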
+
+@register_prop_rule(aten.bernoulli.default)
+@register_prop_rule(aten.bernoulli_.float)
+def bernoulli_rules(op_schema: OpSchema) -> OutputSharding:
+    input_spec = op_schema.args_schema[0]
+    assert isinstance(input_spec, DTensorSpec)
+    return OutputSharding(input_spec)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/math_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/math_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b76a08038875da5847db32ae8bb45cde25c25b92
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/math_ops.py
@@ -0,0 +1,957 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from dataclasses import dataclass
+from enum import Enum
+from typing import cast, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+import torch.distributed.distributed_c10d as c10d
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+    TupleStrategy,
+)
+from torch.distributed._tensor.ops.utils import (
+    as_list,
+    generate_redistribute_costs,
+    is_tensor_evenly_shardable,
+    normalize_dim,
+    normalize_dims,
+    normalize_to_torch_size,
+    register_op_strategy,
+)
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
+
+
+aten = torch.ops.aten
+
+
+class Reduction(Enum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
+
+
+@dataclass(frozen=True)
+class NormReduction:
+    norm_type: Union[int, float, str]
+
+
+ReductionOpType = Union[NormReduction, c10d.ReduceOp.RedOpType]
+
+
+@dataclass(frozen=True)
+class _NormPartial(_Partial):
+    """
+    This placement is used for partial vector norm.
+
+    For p-norms (where p is not inf or -inf), the p-norm over n elements computes
+        (sum_i x_i^p)^(1/p)
+    where the sum is from i=1 to n. The reduction op is the p-norm itself.
+    For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and the 2-norm:
+        Rank 0: [t1, t2] | Rank 1: [t3, t4]
+    After computing the local 2-norm on each shard (partial placement):
+        Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)]
+    Converting from partial to replicate should ultimately yield:
+        Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)]
+    This can be achieved by computing the 2-norm over each rank's local result. This
+    holds similarly for the inf and -inf norms. For the 0-norm, the reduction op is sum.
+    """
+
+    norm_type: Union[int, float, str] = 2
+
+    def __post_init__(self):
+        """Set the appropriate reduce op based on the norm type."""
+        # Use `object.__setattr__` to bypass frozen checks
+        if self.norm_type in (float("inf"), "inf"):
+            object.__setattr__(self, "reduce_op", c10d.ReduceOp.MAX)
+        elif self.norm_type in (float("-inf"), "-inf"):
+            object.__setattr__(self, "reduce_op", c10d.ReduceOp.MIN)
+        elif isinstance(self.norm_type, (int, float)):
+            object.__setattr__(self, "reduce_op", c10d.ReduceOp.SUM)
+        else:
+            raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        if self.reduce_op in (c10d.ReduceOp.MAX, c10d.ReduceOp.MIN):
+            return tensor
+        elif self.reduce_op == c10d.ReduceOp.SUM:
+            return tensor / mesh.size(mesh_dim=mesh_dim)
+        raise NotImplementedError(self.reduce_op)
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        assert isinstance(shard_spec, Shard), f"{shard_spec}"
+        tensor = self._pre_reduce_transform(tensor)
+        reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+        return self._post_reduce_transform(reduced_tensor)
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        tensor = self._pre_reduce_transform(tensor)
+        reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
+        return self._post_reduce_transform(reduced_tensor)
+
+    def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.reduce_op == c10d.ReduceOp.SUM:
+            assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
+            if self.norm_type != 0 and self.norm_type != 1:
+                return tensor**self.norm_type
+        return tensor
+
+    def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.reduce_op == c10d.ReduceOp.SUM:
+            assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}"
+            if self.norm_type != 0 and self.norm_type != 1:
+                return tensor ** (1.0 / self.norm_type)
+        return tensor
+
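+# Illustrative sketch (2 ranks, 2-norm, local shards [t1, t2] and [t3, t4]): each rank
+# first holds its local norm sqrt(t1^2 + t2^2) or sqrt(t3^2 + t4^2);
+# _pre_reduce_transform raises it back to the power p (undoing the root), the SUM
+# all-reduce yields t1^2 + t2^2 + t3^2 + t4^2 on every rank, and _post_reduce_transform
+# applies ** (1 / p) to recover the global norm. For the inf/-inf norms the MAX/MIN
+# reduce op is already associative, so no pre/post transform is needed.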
+
+def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
+    if dims_arg is None:
+        return None
+    dims = cast(List[int], as_list(dims_arg))
+    dims = cast(List[int], normalize_dims(dims, ndim))
+    empty_dims = [[0], [-1], []]
+    if ndim == 0 and dims_arg in empty_dims:
+        return None
+    return dims
+
+
+def _infer_reduce_dims_map(
+    reduction_dims: List[int], input_ndim: int, keep_dim=False
+) -> List[int]:
+    reduction_dims_map = []
+    new_dim_count = 0
+    for input_dim in range(input_ndim):
+        if input_dim in reduction_dims and not keep_dim:
+            # if input dim in reduction dims, mark it as -1
+            reduction_dims_map.append(-1)
+        else:
+            # otherwise mark it as the new dim
+            reduction_dims_map.append(new_dim_count)
+            new_dim_count += 1
+
+    return reduction_dims_map
+
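+# Illustrative sketch: for input_ndim=3 and reduction_dims=[1], keep_dim=False maps the
+# input dims to [0, -1, 1] (dim 1 collapses and the later dims shift down), while
+# keep_dim=True maps them to [0, 1, 2] since no output dim disappears.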
+
+def replicate_reduction_dims(
+    placements: Tuple[Placement, ...], reduction_dims: List[int]
+) -> Tuple[Placement, ...]:
+    # replicate the reduction dims if not reduction_linear
+    new_placements: List[Placement] = []
+
+    for p in placements:
+        if p.is_partial():
+            new_placements.append(Replicate())
+        elif isinstance(p, Shard) and p.dim in reduction_dims:
+            new_placements.append(Replicate())
+        else:
+            new_placements.append(p)
+
+    return tuple(new_placements)
+
+
+def map_placements_after_reduction(
+    placements: Tuple[Placement, ...],
+    reduction_dims: List[int],
+    reduction_dims_map: List[int],
+    reduction_op: ReductionOpType,
+) -> Tuple[Placement, ...]:
+    """
+    Map each placement based on the output shape after reduction.
+    """
+    new_placements: List[Placement] = []
+    for placement in placements:
+        if isinstance(placement, (Replicate, _Partial)):
+            new_placements.append(placement)
+        else:
+            assert isinstance(placement, Shard)
+            shard_dim = placement.dim
+            new_shard_dim = reduction_dims_map[shard_dim]
+            if new_shard_dim == -1 or shard_dim in reduction_dims:
+                # if the new shard dim collapsed or it's in the reduction dims
+                # (i.e. for the case where keepdim=True), we generate partial
+                new_placements.append(get_placement_from_reduction_op(reduction_op))
+            else:
+                new_placements.append(Shard(new_shard_dim))
+    return tuple(new_placements)
+
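+# Illustrative sketch: reducing dims=[1] (keep_dim=False) on a 2-D tensor placed as
+# (Shard(0), Shard(1)) over a 2-D mesh uses reduction_dims_map=[0, -1], so Shard(0)
+# stays Shard(0) while Shard(1) collapses into the partial placement produced by
+# get_placement_from_reduction_op below (e.g. _Partial(SUM) for a sum reduction).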
+
+def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
+    if isinstance(reduction_op, NormReduction):
+        return _NormPartial(norm_type=reduction_op.norm_type)
+    return _Partial(reduction_op)
+
+
+def common_reduction_strategy(
+    mesh: DeviceMesh,
+    input_strategy: OpStrategy,
+    reduce_dims: List[int],
+    keep_dim: bool = False,
+    reduction_linear: bool = True,
+    reduction_op: ReductionOpType = c10d.ReduceOp.SUM,
+) -> OpStrategy:
+    """
+    reduction_linear means that the reduction `f` follows this rule:
+        f([f(a), f(b)]) = f([a, b])
+
+    Reduction-linearity is a superset of linearity.
+    """
+    # by default follow reduction input strategy
+    reduction_strategy = OpStrategy([])
+
+    for strtg in input_strategy.strategies:
+        if not reduction_linear:
+            # input placements for this strategy should clear out pending sum and sharding
+            # on the reduction dimension
+            input_placements = replicate_reduction_dims(
+                strtg.output_spec.placements, reduce_dims
+            )
+        else:
+            input_placements = strtg.output_spec.placements
+
+        input_spec = DTensorSpec(
+            mesh=mesh,
+            placements=input_placements,
+            tensor_meta=strtg.output_spec.tensor_meta,
+        )
+
+        reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim)
+        out_placements = map_placements_after_reduction(
+            input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
+        )
+        redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
+        reduction_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=DTensorSpec(
+                    mesh=mesh,
+                    placements=out_placements,
+                ),
+                input_specs=(input_spec,),
+                redistribute_cost=redistribute_cost,
+            )
+        )
+
+    return reduction_strategy
+
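+# Illustrative sketch: sum is reduction-linear, since summing the per-shard sums equals
+# the sum over the whole tensor, so sharded or partial inputs can be kept as-is and the
+# mesh dims sharded on reduced tensor dims simply become partial in the output.
+# Variance is not reduction-linear (per-shard variances cannot be combined by applying
+# var again), which is why var_reduction_strategy below passes reduction_linear=False
+# and forces the reduction dims to be replicated first.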
+
+LINEAR_REDUCTION_OP_MAP = {
+    aten.all.default: c10d.ReduceOp.SUM,
+    aten.all.dim: c10d.ReduceOp.SUM,
+    aten.sum.default: c10d.ReduceOp.SUM,
+    aten.sum.dim_IntList: c10d.ReduceOp.SUM,
+    aten.prod.default: c10d.ReduceOp.PRODUCT,
+    aten.prod.dim_int: c10d.ReduceOp.PRODUCT,
+    aten.prod.int_out: c10d.ReduceOp.PRODUCT,
+    aten.mean.default: c10d.ReduceOp.AVG,
+    aten.mean.dim: c10d.ReduceOp.AVG,
+    aten.mean.out: c10d.ReduceOp.AVG,
+    aten.max.default: c10d.ReduceOp.MAX,
+    aten.max.dim: c10d.ReduceOp.MAX,
+    aten.max.out: c10d.ReduceOp.MAX,
+    aten.min.default: c10d.ReduceOp.MIN,
+    aten.min.dim: c10d.ReduceOp.MIN,
+    aten.min.out: c10d.ReduceOp.MIN,
+}
+
+
+@register_op_strategy(
+    list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
+)
+def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+    dims = None
+    if len(op_schema.args_schema) > 1:
+        dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
+
+    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
+
+    keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
+    reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op]
+    return common_reduction_strategy(
+        mesh,
+        input_strategy,
+        reduce_dims,
+        keep_dim=keep_dim,
+        reduction_linear=True,
+        reduction_op=reduction_op,
+    )
+
+
+@register_op_strategy(
+    [aten.var.correction, aten.var.correction_out],
+    schema_info=RuntimeSchemaInfo(1, ["keepdim"]),
+)
+def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+    dims = None
+    if len(op_schema.args_schema) > 1:
+        dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
+
+    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
+
+    keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False))
+    return common_reduction_strategy(
+        mesh, input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False
+    )
+
+
+@register_op_strategy(
+    [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1)
+)
+def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    args_schema = op_schema.args_schema
+    input_strategy = args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+    norm_type = args_schema[1] if len(args_schema) > 1 else 2
+    assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
+    dim = args_schema[2] if len(args_schema) > 2 else None
+    keepdim = args_schema[3] if len(args_schema) > 3 else False
+    dims = _infer_reduction_dims(dim, input_strategy.output_ndim)
+    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
+    return common_reduction_strategy(
+        mesh,
+        input_strategy,
+        reduce_dims,
+        keep_dim=cast(bool, keepdim),
+        reduction_linear=True,
+        reduction_op=NormReduction(norm_type),
+    )
+
+
+@register_op_strategy(
+    [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
+)
+def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    assert isinstance(input_tuple_strategy, TupleStrategy)
+    norm_type = args_schema[1]
+    assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
+    output_tuple_strategy_childs: List[OpStrategy] = []
+    for op_strategy in input_tuple_strategy.childs:
+        assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
+        reduce_dims = list(range(op_strategy.output_ndim))
+        output_strategy = common_reduction_strategy(
+            mesh,
+            op_strategy,
+            reduce_dims,
+            reduction_linear=True,
+            reduction_op=NormReduction(norm_type),
+        )
+        output_tuple_strategy_childs.append(output_strategy)
+    return TupleStrategy(output_tuple_strategy_childs)
+
+
+@register_op_strategy(
+    [aten._log_softmax.default, aten._softmax.default], schema_info=RuntimeSchemaInfo(1)
+)
+def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    input_strategy, softmax_dim, _ = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
+    softmax_dim = cast(int, softmax_dim)
+    softmax_dim = normalize_dim(softmax_dim, input_strategy.output_ndim)
+
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        redistribute_costs = []
+        input_src_spec = input_placement_strategy.output_spec
+
+        # make sure input is replicated along the softmax dim
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [softmax_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+        output_target_spec = input_target_spec
+        output_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=output_target_spec,
+                input_specs=[input_target_spec],
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [
+        aten._log_softmax_backward_data.default,
+        aten._softmax_backward_data.default,
+    ],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema
+    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
+    out_strategy = cast(OpStrategy, out_strategy)
+    softmax_dim = cast(int, softmax_dim)
+    softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.output_ndim)
+
+    grad_in_strategy = OpStrategy([])
+    for grad_out_placement_strat, out_placement_strat in zip(
+        grad_out_strategy.strategies, out_strategy.strategies
+    ):
+        # follow the sharding of the grad_out or out depending on which has more shards
+        grad_out_src_spec = grad_out_placement_strat.output_spec
+        out_src_spec = out_placement_strat.output_spec
+        src_spec = (
+            grad_out_src_spec
+            if grad_out_src_spec.num_shards >= out_src_spec.num_shards
+            else out_src_spec
+        )
+
+        # make sure inputs are replicated along the softmax dim
+        tgt_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]),
+        )
+        redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec)
+        redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec)
+        grad_in_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=tgt_spec,
+                redistribute_cost=[redist_grad_out_cost, redist_out_cost],
+            )
+        )
+
+    return grad_in_strategy
+
+
+@register_op_strategy(
+    [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default],
+    schema_info=RuntimeSchemaInfo(3),
+)
+def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    assert len(op_schema.args_schema) == 5
+    (
+        input_strategy,
+        target_strategy,
+        weight_strategy,
+        reduction,
+        _,
+    ) = op_schema.args_schema
+    input_strategy = cast(OpStrategy, input_strategy)
+    target_strategy = cast(OpStrategy, target_strategy)
+    reduction = cast(int, reduction)
+
+    input_shape = input_strategy.output_shape
+    channel_dim = 1 if len(input_shape) >= 2 else 0
+
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        # make sure input is replicated along the channel dim
+        input_src_spec = input_placement_strategy.output_spec
+        input_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [channel_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_expected_spec)
+        )
+
+        # target doesn't have channel dim, and it follows input on other dims
+        target_src_spec = target_strategy.strategies[idx].output_spec
+        target_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_skip_dim(input_expected_spec.placements, channel_dim),
+            tensor_meta=target_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(target_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(target_strategy, target_expected_spec)
+        )
+
+        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
+        # make sure it is replicated
+        if weight_strategy is not None:
+            assert isinstance(weight_strategy, OpStrategy)
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_expected_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_expected_spec)
+            )
+
+        if reduction == Reduction.NONE.value:
+            output_expected_spec = target_expected_spec
+            total_weight_expected_spec = DTensorSpec(
+                mesh=mesh, placements=tuple([Replicate()] * mesh.ndim)
+            )
+        else:
+            if reduction == Reduction.MEAN.value:
+                reduction_op = c10d.ReduceOp.AVG
+                if not is_tensor_evenly_shardable(
+                    target_expected_spec.shape, target_expected_spec
+                ):
+                    raise ValueError(
+                        "The intermediate results of nll_loss cannot be evenly sharded, \
+                        resulting in biased mean result."
+                    )
+            else:  # reduction == Reduction.SUM.value:
+                reduction_op = c10d.ReduceOp.SUM
+            reduce_dims = list(range(target_expected_spec.ndim))
+            reduce_dims_map = _infer_reduce_dims_map(
+                reduce_dims, target_expected_spec.ndim, keep_dim=False
+            )
+            out_placements = map_placements_after_reduction(
+                target_expected_spec.placements,
+                reduce_dims,
+                reduce_dims_map,
+                reduction_op,
+            )
+            output_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=out_placements,
+            )
+
+            # whether reduction is sum or mean, the total weight has to be summed up if not replicated
+            total_weight_placements = map_placements_after_reduction(
+                target_expected_spec.placements,
+                reduce_dims,
+                reduce_dims_map,
+                c10d.ReduceOp.SUM,
+            )
+            total_weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=total_weight_placements,
+            )
+
+        output_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=(output_expected_spec, total_weight_expected_spec),
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default],
+    schema_info=RuntimeSchemaInfo(4),
+)
+def nll_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    assert len(op_schema.args_schema) == 7
+    (
+        grad_out_strategy,
+        input_strategy,
+        target_strategy,
+        weight_strategy,
+        reduction,
+        _,
+        total_weight_strategy,
+    ) = op_schema.args_schema
+    grad_out_strategy = cast(OpStrategy, grad_out_strategy)
+    input_strategy = cast(OpStrategy, input_strategy)
+    target_strategy = cast(OpStrategy, target_strategy)
+    reduction = cast(int, reduction)
+    total_weight_strategy = cast(OpStrategy, total_weight_strategy)
+
+    input_shape = input_strategy.output_shape
+    channel_dim = 1 if len(input_shape) >= 2 else 0
+
+    grad_in_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        # make sure input is replicated along the channel dim
+        input_src_spec = input_placement_strategy.output_spec
+        input_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=replicate_reduction_dims(
+                input_src_spec.placements, [channel_dim]
+            ),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_expected_spec)
+        )
+
+        # target doesn't have channel dim, and it follows input on other dims
+        target_src_spec = target_strategy.strategies[idx].output_spec
+        target_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_skip_dim(input_expected_spec.placements, channel_dim),
+            tensor_meta=target_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(target_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(target_strategy, target_expected_spec)
+        )
+
+        # grad_out follows target if there is no reduction;
+        # otherwise, it should be a replicated scalar.
+        grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec
+        if reduction == Reduction.NONE.value:
+            grad_out_expected_spec = target_expected_spec
+        else:
+            grad_out_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(grad_out_src_spec.placements),
+                tensor_meta=grad_out_src_spec.tensor_meta,
+            )
+        op_args_target_specs.insert(0, grad_out_expected_spec)
+        redistribute_costs.insert(
+            0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec)
+        )
+
+        # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim]
+        # make sure it is replicated
+        if weight_strategy is not None:
+            assert isinstance(weight_strategy, OpStrategy)
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            weight_expected_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_expected_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_expected_spec)
+            )
+
+        # total_weight should always be replicated
+        total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec
+        total_weight_expected_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(total_weight_src_spec.placements),
+            tensor_meta=total_weight_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(total_weight_expected_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(
+                total_weight_strategy, total_weight_expected_spec
+            )
+        )
+
+        grad_in_expected_spec = input_expected_spec
+        grad_in_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=grad_in_expected_spec,
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return grad_in_strategy
+
+
+@register_op_strategy(
+    [aten.native_layer_norm.default],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    # args must be: input, normalized_shape, weight, bias, eps
+    # when weight or bias is not provided, the corresponding arg is
+    # None. layer_norm_strategy returns one OpStrategy
+    # for the triple return values (out, mean, rstd).
+    assert len(op_schema.args_schema) == 5
+    (
+        input_strategy,
+        normalized_shape,
+        weight_strategy,
+        bias_strategy,
+        _,
+    ) = op_schema.args_schema
+
+    # the current layer norm implementation requires that all
+    # input DTensor's sharding must be in form of OpStrategy
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(normalized_shape, (int, Sequence, torch.Size))
+    normalized_size = normalize_to_torch_size(normalized_shape)
+
+    input_ndim = input_strategy.output_ndim
+    axis = input_ndim - len(normalized_size)
+
+    # we use OpStrategy because the output (out, mean, rstd)
+    # should have the same placements
+    output_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        op_args_target_specs = []
+        redistribute_costs = []
+        input_src_spec = input_placement_strategy.output_spec
+
+        # for the input tensor, we replicate it on the inner dims if necessary
+        # TODO: we can avoid forcing the redistribution once we figure out
+        # how to decompose layer norm
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_target_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+
+        if weight_strategy is not None:
+            assert isinstance(weight_strategy, OpStrategy)
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+
+            # for the weight tensor, we replicate it on all dims if necessary
+            # TODO: we can avoid forcing the redistribution once we figure out
+            # how to decompose layer norm
+            weight_target_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(weight_src_spec.placements),
+                tensor_meta=weight_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(weight_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(weight_strategy, weight_target_spec)
+            )
+
+        if bias_strategy is not None:
+            assert isinstance(bias_strategy, OpStrategy)
+            bias_src_spec = bias_strategy.strategies[idx].output_spec
+
+            # for the bias tensor, we replicate it on all dims if necessary
+            # TODO: we can avoid forcing the redistribution once we figure out
+            # how to decompose layer norm
+            bias_target_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(bias_src_spec.placements),
+                tensor_meta=bias_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(bias_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(bias_strategy, bias_target_spec)
+            )
+
+        # the output spec is the same as input spec
+        output_target_spec = input_target_spec
+        output_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=output_target_spec,
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return output_strategy
+
+
+@register_op_strategy(
+    [aten.native_layer_norm_backward.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    # args must be: grad_out, input, normalized_shape, mean, rstd,
+    # weight, bias, output_mask. When weight or bias is not provided,
+    # the corresponding arg is None.
+    assert len(op_schema.args_schema) == 8
+    (
+        grad_out_strategy,
+        input_strategy,
+        normalized_shape,
+        mean_strategy,
+        rstd_strategy,
+        weight_strategy,
+        bias_strategy,
+        output_mask,
+    ) = op_schema.args_schema
+
+    assert isinstance(grad_out_strategy, OpStrategy)
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(mean_strategy, OpStrategy)
+    assert isinstance(rstd_strategy, OpStrategy)
+
+    assert isinstance(normalized_shape, (int, Sequence, torch.Size))
+    normalized_size = normalize_to_torch_size(normalized_shape)
+    input_ndim = input_strategy.output_ndim
+    axis = input_ndim - len(normalized_size)
+    outer_dims = list(range(axis))
+
+    assert isinstance(output_mask, List) and len(output_mask) == 3
+
+    # output triple: (d_input, d_weight, d_bias)
+    out_tuple_strategy = OpStrategy([])
+    for idx, input_placement_strategy in enumerate(input_strategy.strategies):
+        # args for PlacementStrategy
+        output_specs_list: List[Optional[DTensorSpec]] = []
+        op_args_target_specs = []
+        redistribute_costs = []
+
+        input_src_spec = input_placement_strategy.output_spec
+        # arg: grad_out
+        # TODO: change the strategy to the following rule.
+        # d_input is basically a product of element-wise mul of
+        # grad_out, rstd, and normalized input, among which rstd
+        # and normalized input (x_hat) should have the same sharding
+        # placements, and grad_out's sharding is determined by the
+        # pointwise result of x_hat and weight/bias.
+        if output_mask[0]:
+            # TODO: now grad_out spec follows input spec. we may need
+            # to change it to apply a pointwise rule over grad_out,
+            # input, and weight.
+            grad_out_target_spec = DTensorSpec(
+                mesh=mesh,
+                placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+                tensor_meta=input_src_spec.tensor_meta,
+            )
+            op_args_target_specs.append(grad_out_target_spec)
+            redistribute_costs.append(
+                generate_redistribute_costs(grad_out_strategy, grad_out_target_spec)
+            )
+            output_specs_list.append(grad_out_target_spec)
+        else:
+            output_specs_list.append(None)
+
+        # arg: input
+        input_target_spec = DTensorSpec(
+            mesh=mesh,
+            placements=_replicate_dims_start_at(input_src_spec.placements, axis),
+            tensor_meta=input_src_spec.tensor_meta,
+        )
+        op_args_target_specs.append(input_target_spec)
+        redistribute_costs.append(
+            generate_redistribute_costs(input_strategy, input_target_spec)
+        )
+
+        # arg: mean, rstd
+        mean_src_spec = mean_strategy.strategies[idx].output_spec
+        op_args_target_specs.append(mean_src_spec)
+        redistribute_costs.append([0.0 for _ in mean_strategy.strategies])
+        rstd_src_spec = rstd_strategy.strategies[idx].output_spec
+        op_args_target_specs.append(rstd_src_spec)
+        redistribute_costs.append([0.0 for _ in rstd_strategy.strategies])
+
+        # arg: weight
+        # d_weight = sum(grad_out * (input - mean) * rstd, outer_dim, keepdim=False)
+        if output_mask[1]:
+            assert isinstance(weight_strategy, OpStrategy)
+            weight_src_spec = weight_strategy.strategies[idx].output_spec
+            # no need to redistribute weight since it should be replicated
+            # in the forward pass
+            op_args_target_specs.append(weight_src_spec)
+            redistribute_costs.append([0.0 for _ in weight_strategy.strategies])
+            # TODO: now d_weight spec follows input spec w/ a reduction.
+            # we may need to change to a pointwise rule over grad_out and
+            # input, then apply a reduction.
+            inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis)
+            reduce_dims_map = _infer_reduce_dims_map(
+                outer_dims, input_src_spec.ndim, False
+            )
+            out_placements = map_placements_after_reduction(
+                inp_placements, outer_dims, reduce_dims_map, c10d.ReduceOp.SUM
+            )
+            output_specs_list.append(
+                DTensorSpec(
+                    mesh=mesh,
+                    placements=out_placements,
+                    tensor_meta=weight_src_spec.tensor_meta,
+                )
+            )
+        else:
+            output_specs_list.append(None)
+
+        # arg: bias
+        # d_bias = sum(grad_out, outer_dim, keepdim=False)
+        if output_mask[2]:
+            assert isinstance(bias_strategy, OpStrategy)
+            bias_src_spec = bias_strategy.strategies[idx].output_spec
+            # no need to redistribute bias since it should be replicated
+            # in the forward pass
+            op_args_target_specs.append(bias_src_spec)
+            redistribute_costs.append([0.0 for _ in bias_strategy.strategies])
+            # Currently we do not support the case where output_mask[0] is False while
+            # output_mask[2] is True. It would be easy to support by accessing
+            # grad_out_spec via a local variable rather than through output_specs_list;
+            # we just have not seen that case in practice.
+            grad_out_spec = output_specs_list[0]
+            assert isinstance(grad_out_spec, DTensorSpec)
+            # d_bias spec follows a reduction over grad_out
+            inp_placements = _replicate_dims_start_at(grad_out_spec.placements, axis)
+            reduce_dims_map = _infer_reduce_dims_map(
+                outer_dims, grad_out_spec.ndim, False
+            )
+            out_placements = map_placements_after_reduction(
+                inp_placements, outer_dims, reduce_dims_map, c10d.ReduceOp.SUM
+            )
+            output_specs_list.append(
+                DTensorSpec(
+                    mesh=mesh,
+                    placements=out_placements,
+                    tensor_meta=bias_src_spec.tensor_meta,
+                )
+            )
+        else:
+            output_specs_list.append(None)
+
+        out_tuple_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=tuple(output_specs_list),
+                input_specs=op_args_target_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+
+    return out_tuple_strategy
+
+
+def _replicate_dims_start_at(
+    placements: Sequence[Placement], start_dim: int = 0
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
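+# Illustrative sketch: _replicate_dims_start_at((Shard(0), Shard(2)), start_dim=1)
+# returns (Shard(0), Replicate()), since only shards on dims >= start_dim (and any
+# partial placements) are forced back to Replicate.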
+
+# return new_placements which align with placements but skip the skipped_dim
+def _skip_dim(
+    placements: Tuple[Placement, ...], skipped_dim: int
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if isinstance(p, Shard) and p.dim >= skipped_dim:
+            new_placements.append(Shard(p.dim - 1))
+        else:
+            new_placements.append(p)
+    return tuple(new_placements)
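+# Illustrative sketch: _skip_dim((Shard(0), Shard(2)), skipped_dim=1) returns
+# (Shard(0), Shard(1)), re-indexing shards past the skipped dim so that, for example,
+# the nll_loss target placements line up with an input that is sharded across a channel
+# dim the target does not have.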
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/matrix_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/matrix_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9f718ec6afc2d7f76acf0457a0654bb365d9f3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/matrix_ops.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement matrix related ops for distributed tensor
+import itertools
+from typing import List, Optional
+
+import torch
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    OutputSharding,
+    PlacementStrategy,
+)
+from torch.distributed._tensor.ops.basic_strategy import gen_einsum_strategies
+from torch.distributed._tensor.ops.common_rules import einop_rule
+from torch.distributed._tensor.ops.utils import (
+    generate_redistribute_costs,
+    infer_broadcast_dims_map,
+    is_tensor_shardable,
+    map_placements_after_broadcast,
+    register_op_strategy,
+    register_prop_rule,
+)
+from torch.distributed._tensor.placement_types import (
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+from torch.distributed.device_mesh import DeviceMesh
+
+aten = torch.ops.aten
+
+
+@register_prop_rule(aten.t.default)
+def transpose_rule(op_schema: OpSchema) -> OutputSharding:
+    return einop_rule("ij->ji", op_schema, linearity=True)
+
+
+def _mm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat2_strategy = op_schema.args_schema
+    assert isinstance(self_strategy, OpStrategy)
+    assert isinstance(mat2_strategy, OpStrategy)
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        assert strtg.input_specs is not None
+        self_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        if is_tensor_shardable(
+            self_strategy.output_shape, self_spec
+        ) and is_tensor_shardable(mat2_strategy.output_shape, mat2_spec):
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
+
+def _addmm_like_strategy(
+    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema
+    assert isinstance(self_strategy, OpStrategy)
+    assert isinstance(mat1_strategy, OpStrategy)
+    assert isinstance(mat2_strategy, OpStrategy)
+    self_shape = self_strategy.output_shape
+    mm_out_shape = torch.Size(
+        [
+            mat2_strategy.output_shape[-1]
+            if i == len(mat1_strategy.output_shape) - 1
+            else dim_size
+            for i, dim_size in enumerate(mat1_strategy.output_shape)
+        ]
+    )
+    # generate all possible strategies for mm
+    mm_strategy = gen_einsum_strategies(mm_equation, mesh)
+    # filter out invalid strategies and associate costs
+    strategies = mm_strategy.strategies
+    filtered_strategies = []
+    for strtg in strategies:
+        # construct a new strategy by considering the self arg
+        assert strtg.input_specs is not None
+        mat1_spec = strtg.input_specs[0]
+        mat2_spec = strtg.input_specs[1]
+        out_spec = strtg.output_spec
+
+        # the self arg's spec should follow the output of mm, but we need
+        # to account for broadcasting of the self arg
+        broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape)
+        self_placements = map_placements_after_broadcast(
+            out_spec.placements, mm_out_shape, broadcast_dims_map
+        )
+        self_spec = DTensorSpec(mesh=mesh, placements=self_placements)
+
+        if is_tensor_shardable(
+            mat1_strategy.output_shape, mat1_spec
+        ) and is_tensor_shardable(mat2_strategy.output_shape, mat2_spec):
+            # update input specs with new self spec
+            strtg.input_specs = (self_spec, mat1_spec, mat2_spec)
+
+            # associate costs
+            redistribute_cost = [
+                generate_redistribute_costs(self_strategy, self_spec),
+                generate_redistribute_costs(mat1_strategy, mat1_spec),
+                generate_redistribute_costs(mat2_strategy, mat2_spec),
+            ]
+            strtg.redistribute_cost = redistribute_cost
+            filtered_strategies.append(strtg)
+
+    mm_strategy.strategies = filtered_strategies
+
+    return mm_strategy
+
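+# Illustrative sketch (hypothetical shapes): for addmm with mat1 of shape (4, 3) and
+# mat2 of shape (3, 5), mm_out_shape is (4, 5); a self arg of shape (5,) broadcasts
+# against it, and infer_broadcast_dims_map / map_placements_after_broadcast translate
+# the mm output placements into a compatible placement for self (roughly, Shard(1) on
+# the (4, 5) output corresponds to Shard(0) on the 1-D self).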
+
+@register_op_strategy(aten.mm.default)
+def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
+
+
+@register_op_strategy(aten.addmm.default)
+def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    return _addmm_like_strategy("mk,kn->mn", mesh, op_schema)
+
+
+@register_op_strategy(aten.bmm.default)
+def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
+
+
+@register_op_strategy(aten.baddbmm.default)
+def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
+
+
+@register_op_strategy(aten._scaled_dot_product_flash_attention.default)
+def scaled_dot_product_attention_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> OpStrategy:
+    # NOTE: currently we only support a few simple strategies for tensor parallelism
+    # TODO: sdpa might be a good candidate for exploring decomposed sharding propagation
+    # as it involves matmul, pointwise, and reduction ops together.
+    return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
+    q_input_strategy = op_schema.args_schema[0]
+    assert isinstance(q_input_strategy, OpStrategy)
+    # q/k/v have the same shape
+    qkv_shape = q_input_strategy.output_shape
+
+    all_mesh_dim_strategies = []
+
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # placement list stores placements of [outputs, inputs]
+        # in the sdpa case, we have 3 valid tensor outputs and 3 tensor inputs
+        # first we can always accept full replication for inputs and output
+        all_replicate: List[Placement] = [Replicate()] * 6
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # second we can accept the sharding pattern of tensor parallelism, which
+        # shards on the num-heads dim
+        qkv_sharding = Shard(1)  # num head dim
+        output_sharding = Shard(1)  # num head dim
+        logsumexp_sharding = Shard(1)  # num head dim
+        if return_debug_mask:
+            debug_attn_mask_sharding: Placement = Shard(1)  # num head dim
+        else:
+            # empty debug mask, replicated
+            debug_attn_mask_sharding = Replicate()
+
+        num_heads_dim_sharding = [
+            output_sharding,
+            logsumexp_sharding,
+            debug_attn_mask_sharding,
+            qkv_sharding,
+            qkv_sharding,
+            qkv_sharding,
+        ]
+        single_mesh_dim_strategies.append(num_heads_dim_sharding)
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
+        assert len(spec_list) == 6
+        input_expected_specs = spec_list[3:]
+        output_specs: List[Optional[DTensorSpec]] = list(spec_list[:3])
+        # fix up output_specs and fill in None for the int and empty tensor return values
+        for i in range(2, 8):
+            output_specs.insert(i, None)
+        if all(is_tensor_shardable(qkv_shape, spec) for spec in input_expected_specs):
+            # only add to the strategy list when all inputs are shardable
+            redistribute_cost = []
+            for input_idx, spec in enumerate(input_expected_specs):
+                qkv_strategy = op_schema.args_schema[input_idx]
+                assert isinstance(qkv_strategy, OpStrategy)
+                qkv_tensor_meta = qkv_strategy.strategies[0].output_spec.tensor_meta
+                spec.tensor_meta = qkv_tensor_meta
+                redistribute_cost.append(
+                    generate_redistribute_costs(qkv_strategy, spec)
+                )
+
+            strat = PlacementStrategy(
+                output_specs=tuple(output_specs),
+                input_specs=tuple(input_expected_specs),
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/pointwise_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/pointwise_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..187dd1b04a613ee69aedac27b58b05a8dd588366
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/pointwise_ops.py
@@ -0,0 +1,629 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import List, Sequence, Tuple
+
+import torch
+
+from torch.distributed._tensor.op_schema import (
+    _is_inplace_op,
+    _is_out_variant_op,
+    OpSchema,
+    OpStrategy,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+
+from torch.distributed._tensor.ops.utils import (
+    generate_redistribute_costs,
+    infer_broadcast_dims_map,
+    map_placements_after_broadcast,
+    normalize_dim,
+    register_op_strategy,
+)
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
+
+
+aten = torch.ops.aten
+# leave the remaining pointwise_ops list here for convenience:
+# the ops below are pointwise ops that are yet to be supported,
+# and they might not be a complete list.
+# pointwise_ops = [
+#     "fake_quantize_per_channel_affine",
+#     "fake_quantize_per_tensor_affine",
+#     "floor_divide",  # floor_divide is deprecated
+#     "frexp",  # multiple output pointwise op, need to add support
+#     "gradient",  #  need investigation on this op
+#     "imag",  # complex data type only
+#     "quantized_batch_norm",
+#     "quantized_max_pool1d",
+#     "quantized_max_pool2d",
+#     "real",  # complex data type only
+# ]
+
+
+linear_pointwise_ops = [
+    aten.div.Scalar,  # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
+    aten.div_.Scalar,  # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
+    aten.to.dtype,
+    aten.add.Tensor,
+    aten.add_.Tensor,
+]
+
+
+pointwise_ops = [
+    # please keep the entries below alphabetically sorted
+    aten.abs.default,
+    aten.abs.out,
+    aten.abs_.default,
+    aten.acos.default,
+    aten.acos.out,
+    aten.acos_.default,
+    aten.acosh.default,
+    aten.acosh.out,
+    aten.acosh_.default,
+    aten.add.Scalar,
+    aten.add.out,
+    aten.add_.Scalar,
+    aten.addcdiv.default,
+    aten.addcdiv.out,
+    aten.addcdiv_.default,
+    aten.addcmul.default,
+    aten.addcmul.out,
+    aten.addcmul_.default,
+    aten.angle.default,
+    aten.angle.out,
+    aten.asin.default,
+    aten.asin.out,
+    aten.asin_.default,
+    aten.asinh.default,
+    aten.asinh.out,
+    aten.asinh_.default,
+    aten.atan.default,
+    aten.atan.out,
+    aten.atan2.default,
+    aten.atan2.out,
+    aten.atan2_.default,
+    aten.atan_.default,
+    aten.atanh.default,
+    aten.atanh.out,
+    aten.atanh_.default,
+    aten.bitwise_and.Scalar,
+    aten.bitwise_and.Scalar_Tensor,
+    aten.bitwise_and.Scalar_out,
+    aten.bitwise_and.Tensor,
+    aten.bitwise_and.Tensor_out,
+    aten.bitwise_and_.Scalar,
+    aten.bitwise_and_.Tensor,
+    aten.bitwise_left_shift.Scalar_Tensor,
+    aten.bitwise_left_shift.Tensor,
+    aten.bitwise_left_shift.Tensor_Scalar,
+    aten.bitwise_left_shift.Tensor_Scalar_out,
+    aten.bitwise_left_shift.Tensor_out,
+    aten.bitwise_left_shift_.Tensor,
+    aten.bitwise_left_shift_.Tensor_Scalar,
+    aten.bitwise_not.default,
+    aten.bitwise_not.out,
+    aten.bitwise_not_.default,
+    aten.bitwise_or.Scalar,
+    aten.bitwise_or.Scalar_Tensor,
+    aten.bitwise_or.Scalar_out,
+    aten.bitwise_or.Tensor,
+    aten.bitwise_or.Tensor_out,
+    aten.bitwise_or_.Scalar,
+    aten.bitwise_or_.Tensor,
+    aten.bitwise_right_shift.Scalar_Tensor,
+    aten.bitwise_right_shift.Tensor,
+    aten.bitwise_right_shift.Tensor_Scalar,
+    aten.bitwise_right_shift.Tensor_Scalar_out,
+    aten.bitwise_right_shift.Tensor_out,
+    aten.bitwise_right_shift_.Tensor,
+    aten.bitwise_right_shift_.Tensor_Scalar,
+    aten.bitwise_xor.Scalar,
+    aten.bitwise_xor.Scalar_Tensor,
+    aten.bitwise_xor.Scalar_out,
+    aten.bitwise_xor.Tensor,
+    aten.bitwise_xor.Tensor_out,
+    aten.bitwise_xor_.Scalar,
+    aten.bitwise_xor_.Tensor,
+    aten.ceil.default,
+    aten.ceil.out,
+    aten.ceil_.default,
+    aten.clamp.default,
+    aten.clamp.out,
+    aten.clamp_.default,
+    aten.clip.default,
+    aten.clip.out,
+    aten.clip_.default,
+    aten.conj_physical.default,
+    aten.conj_physical.out,
+    aten.conj_physical_.default,
+    aten.copysign.Scalar,
+    aten.copysign.Scalar_out,
+    aten.copysign.Tensor,
+    aten.copysign.out,
+    aten.copysign_.Scalar,
+    aten.copysign_.Tensor,
+    aten.cos.default,
+    aten.cos.out,
+    aten.cos_.default,
+    aten.cosh.default,
+    aten.cosh.out,
+    aten.cosh_.default,
+    aten.deg2rad.default,
+    aten.deg2rad.out,
+    aten.deg2rad_.default,
+    aten.digamma.default,
+    aten.digamma.out,
+    aten.digamma_.default,
+    aten.div.Tensor,
+    aten.div.Tensor_mode,
+    aten.div.out,
+    aten.div.out_mode,
+    aten.div_.Tensor,
+    aten.div_.Tensor_mode,
+    aten.eq.Tensor,
+    aten.eq.Tensor_out,
+    aten.eq.Scalar,
+    aten.eq.Scalar_out,
+    aten.erf.default,
+    aten.erf.out,
+    aten.erf_.default,
+    aten.erfc.default,
+    aten.erfc.out,
+    aten.erfc_.default,
+    aten.erfinv.default,
+    aten.erfinv.out,
+    aten.erfinv_.default,
+    aten.exp.default,
+    aten.exp.out,
+    aten.exp2.default,
+    aten.exp2.out,
+    aten.exp2_.default,
+    aten.exp_.default,
+    aten.expm1.default,
+    aten.expm1.out,
+    aten.expm1_.default,
+    aten.float_power.Scalar,
+    aten.float_power.Scalar_out,
+    aten.float_power.Tensor_Scalar,
+    aten.float_power.Tensor_Scalar_out,
+    aten.float_power.Tensor_Tensor,
+    aten.float_power.Tensor_Tensor_out,
+    aten.float_power_.Scalar,
+    aten.float_power_.Tensor,
+    aten.floor.default,
+    aten.floor.out,
+    aten.floor_.default,
+    aten.fmod.Scalar,
+    aten.fmod.Scalar_out,
+    aten.fmod.Tensor,
+    aten.fmod.Tensor_out,
+    aten.fmod_.Scalar,
+    aten.fmod_.Tensor,
+    aten.frac.default,
+    aten.frac.out,
+    aten.frac_.default,
+    aten.ge.Scalar,
+    aten.ge.Tensor,
+    aten.gelu.default,
+    aten.gt.Tensor,
+    aten.gt.Tensor_out,
+    aten.gt.Scalar,
+    aten.gt.Scalar_out,
+    aten.hypot.default,
+    aten.hypot.out,
+    aten.hypot_.default,
+    aten.i0.default,
+    aten.i0.out,
+    aten.i0_.default,
+    aten.igamma.default,
+    aten.igamma.out,
+    aten.igamma_.default,
+    aten.igammac.default,
+    aten.igammac.out,
+    aten.igammac_.default,
+    aten.isnan.default,
+    aten.ldexp.default,
+    aten.ldexp.out,
+    aten.ldexp_.default,
+    aten.lt.Tensor,
+    aten.lt.Tensor_out,
+    aten.lt.Scalar,
+    aten.lt.Scalar_out,
+    aten.le.Scalar,
+    aten.le.Tensor,
+    aten.lerp.Scalar,
+    aten.lerp.Scalar_out,
+    aten.lerp.Tensor,
+    aten.lerp.Tensor_out,
+    aten.lerp_.Scalar,
+    aten.lerp_.Tensor,
+    aten.lgamma.default,
+    aten.lgamma.out,
+    aten.lgamma_.default,
+    aten.log.default,
+    aten.log.out,
+    aten.log10.default,
+    aten.log10.out,
+    aten.log10_.default,
+    aten.log1p.default,
+    aten.log1p.out,
+    aten.log1p_.default,
+    aten.log2.default,
+    aten.log2.out,
+    aten.log2_.default,
+    aten.log_.default,
+    aten.logaddexp.default,
+    aten.logaddexp.out,
+    aten.logaddexp2.default,
+    aten.logaddexp2.out,
+    aten.logical_and.default,
+    aten.logical_and.out,
+    aten.logical_and_.default,
+    aten.logical_not.default,
+    aten.logical_not.out,
+    aten.logical_not_.default,
+    aten.logical_or.default,
+    aten.logical_or.out,
+    aten.logical_or_.default,
+    aten.logical_xor.default,
+    aten.logical_xor.out,
+    aten.logical_xor_.default,
+    aten.logit.default,
+    aten.logit.out,
+    aten.logit_.default,
+    aten.masked_fill.Scalar,
+    aten.maximum.out,
+    aten.mul.Scalar,
+    aten.mul.Tensor,
+    aten.mul.out,
+    aten.mul_.Scalar,
+    aten.mul_.Tensor,
+    aten.mvlgamma.default,
+    aten.mvlgamma.out,
+    aten.mvlgamma_.default,
+    aten.native_dropout_backward.default,
+    aten.native_dropout_backward.out,
+    aten.nan_to_num.default,
+    aten.nan_to_num.out,
+    aten.nan_to_num_.default,
+    aten.ne.Scalar,
+    aten.neg.default,
+    aten.neg.out,
+    aten.neg_.default,
+    aten.nextafter.default,
+    aten.nextafter.out,
+    aten.nextafter_.default,
+    aten.polygamma.default,
+    aten.polygamma.out,
+    aten.polygamma_.default,
+    aten.positive.default,
+    aten.pow.Scalar,
+    aten.pow.Scalar_out,
+    aten.pow.Tensor_Scalar,
+    aten.pow.Tensor_Scalar_out,
+    aten.pow.Tensor_Tensor,
+    aten.pow.Tensor_Tensor_out,
+    aten.pow_.Scalar,
+    aten.pow_.Tensor,
+    aten.reciprocal.default,
+    aten.reciprocal.out,
+    aten.reciprocal_.default,
+    aten.rad2deg.default,
+    aten.rad2deg.out,
+    aten.rad2deg_.default,
+    aten.relu.default,
+    aten.relu_.default,
+    aten.remainder.Scalar,
+    aten.remainder.Scalar_Tensor,
+    aten.remainder.Scalar_out,
+    aten.remainder.Tensor,
+    aten.remainder.Tensor_out,
+    aten.remainder_.Scalar,
+    aten.remainder_.Tensor,
+    aten.round.decimals,
+    aten.round.decimals_out,
+    aten.round.default,
+    aten.round.out,
+    aten.round_.decimals,
+    aten.round_.default,
+    aten.rsqrt.default,
+    aten.rsqrt.out,
+    aten.rsqrt_.default,
+    aten.rsub.Scalar,
+    aten.sgn.default,
+    aten.sgn.out,
+    aten.sgn_.default,
+    aten.sigmoid.default,
+    aten.sigmoid.out,
+    aten.sigmoid_.default,
+    aten.sign.default,
+    aten.sign.out,
+    aten.sign_.default,
+    aten.signbit.default,
+    aten.signbit.out,
+    aten.silu.default,
+    aten.silu.out,
+    aten.sin.default,
+    aten.sin.out,
+    aten.sin_.default,
+    aten.sinc.default,
+    aten.sinc.out,
+    aten.sinc_.default,
+    aten.sinh.default,
+    aten.sinh.out,
+    aten.sinh_.default,
+    aten.sqrt.default,
+    aten.sqrt.out,
+    aten.sqrt_.default,
+    aten.square.default,
+    aten.square.out,
+    aten.square_.default,
+    aten.sub.Scalar,
+    aten.sub.Tensor,
+    aten.sub.out,
+    aten.sub_.Scalar,
+    aten.sub_.Tensor,
+    aten.tan.default,
+    aten.tan.out,
+    aten.tan_.default,
+    aten.tanh.default,
+    aten.tanh.out,
+    aten.tanh_.default,
+    aten.true_divide.Tensor,
+    aten.trunc.default,
+    aten.trunc.out,
+    aten.trunc_.default,
+    aten.where.self,
+    aten.where.self_out,
+    aten.xlogy.OutScalar_Self,
+    aten.xlogy.OutScalar_Other,
+    aten.xlogy.OutTensor,
+    aten.xlogy.Scalar_Other,
+    aten.xlogy.Scalar_Self,
+    aten.xlogy.Tensor,
+    aten.xlogy_.Scalar_Other,
+    aten.xlogy_.Tensor,
+    # backward point-wise ops
+    # please keep the entries below alphabetically sorted
+    aten.gelu_backward.default,
+    aten.sigmoid_backward.default,
+    aten.silu_backward.default,
+    aten.tanh_backward.default,
+    aten.threshold_backward.default,
+]
+
+
+def pointwise_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
+) -> OpStrategy:
+    max_shards_strategy_index = -1
+    max_shards = -1
+
+    if _is_inplace_op(op_schema.op):
+        # inplace op should follow the first arg strategy
+        followed_strategy = op_schema.args_schema[0]
+    elif _is_out_variant_op(op_schema.op):
+        # out variant op should follow the out kwarg strategy
+        followed_strategy = op_schema.kwargs_schema["out"]
+    else:
+        # normal pointwise op, we choose to follow the arg with
+        # the max number of shards in case the operands need resharding
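+        # e.g. for aten.mul.Tensor(a, b) where a is sharded across 4 devices and b
+        # is replicated, we follow a's strategy and redistribute b to match it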
+        for idx, arg_strategy in enumerate(op_schema.args_schema):
+            if not isinstance(arg_strategy, OpStrategy):
+                continue
+
+            arg_max_shards = arg_strategy.max_num_shards()
+            if arg_max_shards > max_shards:
+                max_shards_strategy_index = idx
+                max_shards = arg_max_shards
+
+        followed_strategy = op_schema.args_schema[max_shards_strategy_index]
+
+    assert isinstance(
+        followed_strategy, OpStrategy
+    ), f"no strategy to follow for {op_schema}!"
+    return common_pointwise_strategy(
+        mesh, op_schema.args_schema, followed_strategy, linearity
+    )
+
+
+def common_pointwise_strategy(
+    mesh: DeviceMesh,
+    args_schema: Sequence[object],
+    followed_strategy: OpStrategy,
+    linearity: bool,
+) -> OpStrategy:
+    # handle broadcasting
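+    # e.g. broadcasting shapes (4, 6) and (8, 4, 6) gives common_shape (8, 4, 6); a
+    # Shard(0) placement on the followed (4, 6) operand is remapped to Shard(1) of
+    # the broadcasted output by the shard-dim adjustment below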
+    common_shape = torch.broadcast_shapes(
+        *[arg.output_shape for arg in args_schema if isinstance(arg, OpStrategy)]
+    )
+    pointwise_strategy = OpStrategy([])
+
+    for placement_strategy in followed_strategy.strategies:
+        spec_to_follow = placement_strategy.output_spec
+        out_placements: List[Placement] = []
+        for placement in spec_to_follow.placements:
+            if isinstance(placement, Shard):
+                shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
+                common_ndim = len(common_shape)
+                new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
+                out_placements.append(Shard(new_shard_dim))
+            elif isinstance(placement, _Partial) and not linearity:
+                # clear the partial placement if the op does not support linearity;
+                # by default we just replicate the partial, need to see if this
+                # is optimal for all cases
+                out_placements.append(Replicate())
+            else:
+                out_placements.append(placement)
+
+        input_specs: List[DTensorSpec] = []
+        redistribute_costs: List[List[float]] = []
+        for idx, input_arg in enumerate(args_schema):
+            if isinstance(input_arg, OpStrategy):
+                # every arg follows the out_placements, but we need to handle broadcasting
+                input_arg_spec = input_arg.strategies[0].output_spec
+                input_arg_dims_map = infer_broadcast_dims_map(
+                    common_shape, input_arg_spec.shape
+                )
+                input_target_placements = map_placements_after_broadcast(
+                    tuple(out_placements),
+                    common_shape,
+                    input_arg_dims_map,
+                )
+                input_arg_target_spec = DTensorSpec(
+                    mesh=mesh,
+                    placements=input_target_placements,
+                    tensor_meta=input_arg_spec.tensor_meta,
+                )
+                input_specs.append(input_arg_target_spec)
+                redistribute_costs.append(
+                    generate_redistribute_costs(input_arg, input_arg_target_spec)
+                )
+
+        pointwise_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=DTensorSpec(
+                    mesh=mesh,
+                    placements=tuple(out_placements),
+                ),
+                input_specs=input_specs,
+                redistribute_cost=redistribute_costs,
+            )
+        )
+    return pointwise_strategy
+
+
+def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """
+    Linear pointwise operators can propagate pending reductions.
+    For example, c = add(a, b); if a is pending sum, then c will be
+    pending sum as well without any communication overhead.
+    """
+    return pointwise_strategy(mesh, op_schema, linearity=True)
+
+
+for op in linear_pointwise_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
+        linear_pointwise_strategy
+    )
+
+for op in pointwise_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
+        pointwise_strategy
+    )
+
+
+# TODO: add all for_each ops
+for_each_ops = [
+    aten._foreach_abs_.default,
+    aten._foreach_addcdiv_.Scalar,
+    aten._foreach_addcdiv_.ScalarList,
+    aten._foreach_addcdiv_.Tensor,
+    aten._foreach_addcmul.Scalar,
+    aten._foreach_addcmul_.Scalar,
+    aten._foreach_addcmul_.ScalarList,
+    aten._foreach_addcmul_.Tensor,
+    aten._foreach_div_.List,
+    aten._foreach_div_.ScalarList,
+    aten._foreach_lerp_.Scalar,
+    aten._foreach_maximum_.List,
+    aten._foreach_mul.Scalar,
+    aten._foreach_mul.List,
+    aten._foreach_mul_.Scalar,
+    aten._foreach_mul_.ScalarList,
+    aten._foreach_mul_.Tensor,
+    aten._foreach_mul_.List,
+    aten._foreach_neg.default,
+    aten._foreach_neg_.default,
+    aten._foreach_reciprocal_.default,
+    aten._foreach_sub_.Scalar,
+    aten._foreach_sqrt.default,
+    aten._foreach_sqrt_.default,
+    aten._foreach_zero_.default,
+]
+
+for_each_linearity_ops = [
+    aten._foreach_add.Scalar,
+    aten._foreach_add_.Scalar,
+    aten._foreach_add_.ScalarList,
+    aten._foreach_add.List,
+    aten._foreach_add_.List,
+]
+
+
+def foreach_list_pointwise_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
+) -> StrategyType:
+    """
+    Apply the pointwise strategy to the zipped arguments. For example, if we
+    run a foreach add of two lists l1 and l2, then we apply the pointwise
+    strategy on each pair (l1[i], l2[i]). If the first argument is a list but
+    the second (or later) one is a tensor, then we broadcast the tensor by
+    replicating it into a list with the length of the first argument.
+    """
+
+    def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]:
+        first_arg = args_schema[0]
+        assert isinstance(first_arg, TupleStrategy)
+        strategy_len = len(first_arg.childs)
+        tuple_strategies: List[TupleStrategy] = []
+        for arg_idx, arg in enumerate(args_schema):
+            if isinstance(arg, TupleStrategy):
+                # every tuple strategy should have the same length
+                assert len(arg.childs) == strategy_len
+                tuple_strategies.append(arg)
+            elif isinstance(arg, OpStrategy):
+                if arg_idx > 0:  # implicitly broadcast
+                    tuple_strategies.append(
+                        TupleStrategy([arg for _ in range(strategy_len)])
+                    )
+                else:
+                    raise RuntimeError(
+                        f"foreach list op only supports tuple strategy! {op_schema}"
+                    )
+        return tuple_strategies
+
+    args_strategies = args_tuple_strategies(op_schema.args_schema)
+    follow_strategy: TupleStrategy = args_strategies[0]
+    foreach_strategy_list: List[OpStrategy] = []
+    for child_idx, child_strtgy in enumerate(follow_strategy.childs):
+        assert isinstance(child_strtgy, OpStrategy)
+        args_schema: List[StrategyType] = [
+            arg_strategy.childs[child_idx] for arg_strategy in args_strategies
+        ]
+        pointwise_strategy: OpStrategy = common_pointwise_strategy(
+            mesh, args_schema, child_strtgy, linearity
+        )
+        foreach_strategy_list.append(pointwise_strategy)
+    return TupleStrategy(foreach_strategy_list)
+
+
+def foreach_list_linear_pointwise_strategy(
+    mesh: DeviceMesh, op_schema: OpSchema
+) -> StrategyType:
+    """
+    for each list op stratgy that supports linearity
+    """
+    return foreach_list_pointwise_strategy(mesh, op_schema, linearity=True)
+
+
+for op in for_each_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
+        foreach_list_pointwise_strategy
+    )
+
+for op in for_each_linearity_ops:
+    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
+        foreach_list_linear_pointwise_strategy
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/random_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/random_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db7e2f4c0295e9d9c25c871065f44856c58c7bc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/random_ops.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import torch
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    PlacementStrategy,
+    StrategyType,
+)
+from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy
+from torch.distributed.device_mesh import DeviceMesh
+
+aten = torch.ops.aten
+
+
+@register_op_strategy(
+    [aten.normal_.default, aten.uniform_.default, aten.native_dropout.default]
+)
+def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    self_strategy = op_schema.args_schema[0]
+    assert isinstance(self_strategy, OpStrategy)
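+    # the input placement is preserved (e.g. a Shard(0) input stays Shard(0));
+    # partial inputs are rejected below since their behavior is not yet defined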
+
+    random_strategy = OpStrategy([])
+    for arg_strategy in self_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_partial(arg_spec):
+            # TODO: figure out how inplace random op should behave when it's partial
+            raise RuntimeError(f"{op_schema.op} with _Partial is not supported yet!")
+        random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec))
+
+    return random_strategy
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/tensor_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/tensor_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af0814e1778405e12b65bb91cd9972dc6bc8d28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/tensor_ops.py
@@ -0,0 +1,826 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import itertools
+from typing import cast, List, Optional, Sequence, Tuple
+
+import torch
+
+from torch.distributed._tensor._utils import compute_local_shape
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OpStrategy,
+    OutputSharding,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops.utils import (
+    generate_redistribute_costs,
+    is_tensor_dim_sharded,
+    is_tensor_partial,
+    is_tensor_shardable,
+    normalize_dim,
+    prod,
+    register_op_strategy,
+    register_prop_rule,
+)
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.device_mesh import DeviceMesh
+
+
+aten = torch.ops.aten
+
+
+def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    # the default strategy just propagates the first input strategy
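+    # e.g. aten.clone.default on a Shard(1) input keeps Shard(1) on the output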
+    select_strategy = op_schema.args_schema[0]
+    assert isinstance(select_strategy, OpStrategy)
+    default_strategy = []
+    for strategy in select_strategy.strategies:
+        # we create new DTensorSpecs even for the default strategy to ensure that
+        # the tensor metas are distinct between the arguments and outputs
+        default_strategy.append(
+            PlacementStrategy(
+                output_specs=DTensorSpec(
+                    mesh=strategy.output_spec.mesh,
+                    placements=strategy.output_spec.placements,
+                )
+            )
+        )
+    return OpStrategy(default_strategy)
+
+
+register_op_strategy(
+    [
+        aten.clone.default,
+        aten.contiguous.default,
+        aten.copy_.default,
+        aten.detach.default,
+        aten.fill_.Scalar,
+        aten.zero_.default,
+    ]
+)(default_strategy)
+
+register_op_strategy(
+    aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
+)(default_strategy)
+
+
+@register_op_strategy(
+    [
+        aten.equal.default,
+        aten.is_same_size.default,
+    ]
+)
+def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    # equal_strategy deals with ops that compare two tensors; we need to make sure the
+    # sharding layout is the same for both operands, so we choose to follow the arg with
+    # the max num of shards. We still keep is_same_size here for completeness as they
+    # share the same strategy in theory.
+    self_strategy, other_strategy = op_schema.args_schema
+    assert isinstance(self_strategy, OpStrategy)
+    assert isinstance(other_strategy, OpStrategy)
+
+    select_strategy = (
+        self_strategy
+        if self_strategy.max_num_shards() >= other_strategy.max_num_shards()
+        else other_strategy
+    )
+    equal_strategy = OpStrategy([])
+
+    for arg_strategy in select_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_partial(arg_spec):
+            # if the arg_spec has partial placements, reshard to replicate;
+            # otherwise the local shard tensor comparison would be invalid
+            output_spec = DTensorSpec(
+                mesh=arg_spec.mesh,
+                placements=tuple(
+                    Replicate() if isinstance(p, _Partial) else p
+                    for p in arg_spec.placements
+                ),
+            )
+            equal_strategy.strategies.append(
+                PlacementStrategy(output_specs=output_spec)
+            )
+        else:
+            equal_strategy.strategies.append(PlacementStrategy(arg_spec))
+    return equal_strategy
+
+
+@register_op_strategy(
+    [
+        aten.empty_like.default,
+        aten.ones_like.default,
+        aten.rand_like.default,
+        aten.randn_like.default,
+        aten.zeros_like.default,
+    ],
+    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
+)
+@register_op_strategy(
+    [aten.full_like.default],
+    schema_info=RuntimeSchemaInfo(2, ["dtype"]),
+)
+@register_op_strategy(
+    [
+        aten.randint_like.default,
+        aten.randint_like.low_dtype,
+        aten.randint_like.low_dtype_out,
+    ],
+    schema_info=RuntimeSchemaInfo(3, ["dtype"]),
+)
+def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    # create_like_strategy deals with ops that create tensors with the same
+    # shape as the input, but with specific content that does not depend on
+    # the input. We can propagate sharding, but we have to make sure we
+    # move from partial to replicated.
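+    # e.g. aten.ones_like.default on a _Partial input keeps the partial input spec but
+    # outputs Replicate; a partial output of ones would otherwise be summed across ranks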
+    select_strategy = op_schema.args_schema[0]
+    create_like_strategy = OpStrategy([])
+    assert isinstance(select_strategy, OpStrategy)
+    for arg_strategy in select_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if is_tensor_partial(arg_spec):
+            # if the arg_spec has partial placements, accept partial
+            # in the input_specs but output replicate for
+            # the corresponding mesh dims
+            output_spec = DTensorSpec(
+                mesh=arg_spec.mesh,
+                placements=tuple(
+                    Replicate() if isinstance(p, _Partial) else p
+                    for p in arg_spec.placements
+                ),
+            )
+            create_like_strategy.strategies.append(
+                PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,))
+            )
+
+        else:
+            create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
+
+    return create_like_strategy
+
+
+@register_op_strategy(
+    [
+        aten.new_empty.default,
+        aten.new_full.default,
+        aten.new_ones.default,
+        aten.new_zeros.default,
+        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
+    ],
+    schema_info=RuntimeSchemaInfo(1, ["dtype"]),
+)
+def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    # TODO: maybe we should generate all possible shardings instead of just staying
+    # replicated for new factory methods
+    input_strategy = op_schema.args_schema[0]
+    new_factory_strategy = OpStrategy([])
+    assert isinstance(input_strategy, OpStrategy)
+    for arg_strategy in input_strategy.strategies:
+        input_spec = arg_strategy.output_spec
+        replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
+        new_factory_strategy.strategies.append(
+            PlacementStrategy(output_specs=replica_spec, input_specs=(input_spec,))
+        )
+
+    return new_factory_strategy
+
+
+@register_op_strategy(aten.bucketize.Tensor)
+def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """Just propagate input sharding, but expect replicated for boundaries input."""
+    input_strategy = op_schema.args_schema[0]
+    bucketize_strategy = OpStrategy([])
+    assert isinstance(input_strategy, OpStrategy)
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements)
+        replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
+        bucketize_strategy.strategies.append(
+            PlacementStrategy(
+                output_specs=arg_spec, input_specs=(arg_spec, replica_spec)
+            )
+        )
+
+    return bucketize_strategy
+
+
+@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1))
+def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """Forward all shardings except the slice dimension."""
+    defaults = (None, 0, None, None, 1)
+    input_strategy, dim, start, end, step = (
+        op_schema.args_schema + defaults[len(op_schema.args_schema) :]
+    )
+    assert isinstance(input_strategy, OpStrategy)
+    input_shape = input_strategy.output_shape
+    input_ndim = input_strategy.output_ndim
+    assert isinstance(dim, int)
+    if start is None:
+        start = 0
+    if end is None or end > input_shape[dim]:
+        end = input_shape[dim]
+    assert isinstance(start, int)
+    assert isinstance(end, int)
+    assert isinstance(step, int)
+
+    # normalize args
+    slice_dim = normalize_dim(dim, input_ndim)
+    start = normalize_dim(start, input_shape[dim])
+    end = normalize_dim(end, input_shape[dim])
+
+    redundant_slice = start == 0 and end == input_shape[dim] and step == 1
+
+    slice_strategy = OpStrategy([])
+
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice:
+            # only add the strategy if the slice dim is not sharded
+            out_spec = DTensorSpec(mesh, arg_spec.placements)
+            slice_strategy.strategies.append(PlacementStrategy(output_specs=out_spec))
+    if not slice_strategy.strategies:
+        # if all strategies are filtered out, unshard all specs on the slice dim
+        # of the input strategy and use that as the op strategy
+        for arg_strategy in input_strategy.strategies:
+            arg_spec = arg_strategy.output_spec
+            unshard_spec = DTensorSpec(
+                mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim)
+            )
+            slice_strategy.strategies.append(
+                PlacementStrategy(output_specs=unshard_spec)
+            )
+    return slice_strategy
+
+
+def unshard_tensor_dim(
+    placements: Sequence[Placement], dim: int
+) -> Tuple[Placement, ...]:
+    """Disallow the given tensor dimension to be sharded."""
+    return tuple(
+        p if (not isinstance(p, Shard) or p.dim != dim) else Replicate()
+        for p in placements
+    )
+
+
+def replicate_tensor_dim(
+    placements: Sequence[Placement], dim: int
+) -> Tuple[Placement, ...]:
+    """Force the given tensor dimension to be replicated."""
+    # Not using p.is_shard() to avoid mypy complain about Placement not having
+    # attribute dim.
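+    # e.g. replicate_tensor_dim((Shard(0), _Partial()), dim=0) == (Replicate(), Replicate())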
+    return tuple(
+        Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p
+        for p in placements
+    )
+
+
+@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
+def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    # 1. The number of dimensions in input and src needs to match.
+    # 2. The number of elements on all non-scatter dims needs to match between input and src.
+    # 3. The number of elements in src on the scatter dim needs to match the slice size.
+    # Given the above:
+    # - We suggest that src follow the sharding of input, except on the scatter dimension,
+    #   where our best bet for now is to make them replicated as a fall-back.
+    #   TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding.
+
+    input_strategy = op_schema.args_schema[0]
+    assert isinstance(input_strategy, OpStrategy)
+    input_ndim = input_strategy.output_ndim
+    slice_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
+    )
+    slice_dim = normalize_dim(slice_dim, input_ndim)
+
+    slice_scatter_strategy = OpStrategy([])
+    # by default follow the input strategy for both input and src
+    for arg_strategy in input_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        if not (
+            is_tensor_dim_sharded(arg_spec, dim=slice_dim)
+            or is_tensor_partial(arg_spec)
+        ):
+            # only add the strategy if the slice_scatter dim is not sharded or partial
+            slice_scatter_strategy.strategies.append(
+                PlacementStrategy(output_specs=arg_spec)
+            )
+
+    if not slice_scatter_strategy.strategies:
+        # if all strategies are filtered out, replicate all specs on the slice_scatter dim
+        # of the input strategy and use that as the op strategy
+        for arg_strategy in input_strategy.strategies:
+            arg_spec = arg_strategy.output_spec
+            replicate_spec = DTensorSpec(
+                mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim)
+            )
+            slice_scatter_strategy.strategies.append(
+                PlacementStrategy(output_specs=replicate_spec)
+            )
+    return slice_scatter_strategy
+
+
+@register_op_strategy(aten._local_scalar_dense.default)
+def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    """Only allow replication on the input/output."""
+    replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
+    return OpStrategy([PlacementStrategy(replicate_spec)])
+
+
+@register_op_strategy(aten.gather.default)
+def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    dim = cast(int, op_schema.args_schema[1])
+    index_strategy = cast(OpStrategy, op_schema.args_schema[2])
+
+    input_shape = input_strategy.output_shape
+    index_shape = index_strategy.output_shape
+
+    all_mesh_dim_strategies = []
+
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # placement list stores placements of [output, input, index]
+        # first we can always accept full replication for inputs and output
+        all_replicate: List[Placement] = [Replicate()] * 3
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # input sharding, input sharded, index accepts mask partial, output follows index
+        # this only works when the input is sharded on the gather dimension, and
+        # index has size 1 on the gather dimension
+        if index_shape[dim] == 1:
+            index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim])
+            input_sharding = [
+                index_partial_placement,
+                Shard(dim),
+                index_partial_placement,
+            ]
+            single_mesh_dim_strategies.append(input_sharding)
+
+        # index sharding, input replicated, index sharded, output follows index
+        # this only works when the sharding dimension is the gather dimension
+        index_sharding = [Shard(dim), Replicate(), Shard(dim)]
+        single_mesh_dim_strategies.append(index_sharding)
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
+        if is_tensor_shardable(input_shape, spec_list[1]) and is_tensor_shardable(
+            index_shape, spec_list[2]
+        ):
+            input_spec, index_spec = spec_list[1:]
+            redistribute_cost = [
+                generate_redistribute_costs(input_strategy, input_spec),
+                generate_redistribute_costs(index_strategy, index_spec),
+            ]
+            strat = PlacementStrategy(
+                output_specs=spec_list[0],
+                input_specs=spec_list[1:],
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strat)
+
+    return OpStrategy(all_strategies)
+
+
+@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True))
+def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
+    dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
+
+    # Follow the 1st child strategy's placement strategies
+    child_strategy = input_tuple_strategy.childs[0]
+    assert isinstance(child_strategy, OpStrategy), f"{child_strategy}"
+    strategies: List[PlacementStrategy] = []
+
+    # For each arg strategy of the child to follow, we check if every other
+    # child has an equal strategy. If so, then that is a valid strategy. If
+    # there are no such valid strategies, then we replicate.
+    for arg_strategy in child_strategy.strategies:
+        arg_spec = arg_strategy.output_spec
+        # For each arg strategy (whether the one to follow or other), we
+        # replicate the stack dim since we cannot stack on a sharded dim
+        if is_tensor_dim_sharded(arg_spec, dim):
+            arg_spec = DTensorSpec(
+                mesh, unshard_tensor_dim(arg_spec.placements, dim=dim)
+            )
+        all_compatible = True
+        for other_child_strategy in input_tuple_strategy.childs[1:]:
+            has_compatible_strategy = False
+            assert isinstance(
+                other_child_strategy, OpStrategy
+            ), f"{other_child_strategy}"
+            for other_arg_strategy in other_child_strategy.strategies:
+                other_arg_spec = other_arg_strategy.output_spec
+                if is_tensor_dim_sharded(other_arg_spec, dim):
+                    other_arg_spec = DTensorSpec(
+                        mesh, unshard_tensor_dim(other_arg_spec.placements, dim=dim)
+                    )
+                if other_arg_spec.placements == arg_spec.placements:
+                    has_compatible_strategy = True
+                    break
+            if not has_compatible_strategy:
+                all_compatible = False
+                break
+        if all_compatible:
+            input_specs = tuple(
+                arg_spec for _ in range(len(input_tuple_strategy.childs))
+            )
+            strategies.append(
+                PlacementStrategy(
+                    output_specs=DTensorSpec(mesh, arg_spec.placements),
+                    input_specs=input_specs,
+                )
+            )
+    if not strategies:
+        # Arbitrarily use each child strategy's 0th strategy's output spec
+        input_specs = tuple(
+            cast(OpStrategy, child_strategy).strategies[0].output_spec
+            for child_strategy in input_tuple_strategy.childs
+        )
+        replicate_spec = DTensorSpec(mesh, tuple(Replicate() for _ in range(mesh.ndim)))
+        strategies.append(PlacementStrategy(output_specs=replicate_spec))
+    return OpStrategy(strategies)
+
+
+@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))
+def prop_index_select(op_schema: OpSchema) -> OutputSharding:
+    values_spec, dim, indices_spec = op_schema.args_schema
+
+    assert isinstance(values_spec, DTensorSpec)
+    assert isinstance(dim, int)
+    assert isinstance(indices_spec, DTensorSpec)
+
+    all_indices_spec: List[Optional[DTensorSpec]] = [
+        indices_spec if dim == i else None for i in range(values_spec.ndim)
+    ]
+
+    result = prop_index(
+        OpSchema(
+            op=op_schema.op,
+            args_schema=(values_spec, all_indices_spec),
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    )
+    if result.schema_suggestions:
+        result.schema_suggestions = [
+            OpSchema(
+                op=op_schema.op,
+                args_schema=(s.args_schema[0], dim, s.args_schema[1][dim]),
+                kwargs_schema=op_schema.kwargs_schema,
+            )
+            for s in result.schema_suggestions
+        ]
+    return result
+
+
+@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True))
+def prop_index(op_schema: OpSchema) -> OutputSharding:
+    """
+    Expect replicated on the first input; _mostly_ pointwise on the second input.
+
+    TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first.
+    """
+    # Current sharding constraints:
+    # For values:
+    #   1. We currently require that the dimensions of values_spec be replicated or partial
+    #      if they are being indexed on.
+    #   2. Other dimensions of values_spec can remain sharded if they are so.
+    # For indices:
+    #   Indices can be either sharded or replicated. All index tensors need to be sharded
+    #   in a compatible way, following the pointwise rule (including resolving _Partial
+    #   into either sharded or replicated)
+
+    values_spec, multi_indices_spec = op_schema.args_schema
+    assert isinstance(values_spec, DTensorSpec)
+    assert isinstance(multi_indices_spec, list)
+    multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec)
+    valid_indices_spec: List[Tuple[int, DTensorSpec]] = [
+        (i, a) for i, a in enumerate(multi_indices_spec) if a is not None
+    ]
+
+    # 1. All indices have to be sharded equally. Moreover, indices can be broadcast.
+    #    Here, we piggyback on the pointwise sharding rule for indices.
+    indices_out = pointwise_rule(
+        OpSchema(
+            op=op_schema.op,
+            args_schema=tuple(v[1] for v in valid_indices_spec),
+            kwargs_schema={},
+        )
+    )
+    need_reshard_on_indices = indices_out.output_spec is None
+
+    if not need_reshard_on_indices:
+        # this means that our inputs are already sharded properly and we will use that as our indices_spec
+        assert isinstance(indices_out.output_spec, DTensorSpec)
+        indices_spec: DTensorSpec = indices_out.output_spec
+    else:
+        assert indices_out.schema_suggestions is not None
+        valid_indices_suggestion = indices_out.schema_suggestions[0]
+        for i, v in enumerate(valid_indices_suggestion.args_spec):
+            multi_indices_spec[valid_indices_spec[i][0]] = v
+        # we'll need to call pointwise_rule again to see what our ideal indices_spec is and then
+        # use that to compute our ideal values_spec
+        indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec
+        assert isinstance(indices_output_spec, DTensorSpec)
+        indices_spec = indices_output_spec
+
+    lookup_dims = {v[0] for v in valid_indices_spec}
+
+    need_reshard_on_values = tuple(
+        (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard)))
+        for vp, ip in zip(values_spec.placements, indices_spec.placements)
+    )
+
+    if not need_reshard_on_indices and not any(need_reshard_on_values):
+        value_placements = values_spec.placements
+
+        all_dims_consecutive = all(
+            b[0] - a[0] == 1
+            for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1])
+        )
+        if all_dims_consecutive:
+            # if all index vectors are consecutive, insert at the dimension of the first index
+            insert_dim: int = valid_indices_spec[0][0]
+        else:
+            # else, insert on the first dimension
+            insert_dim = 0
+
+        def place(vp: Placement, ip: Placement) -> Placement:
+            if isinstance(vp, Shard):
+                return Shard(
+                    vp.dim
+                    if vp.dim < insert_dim
+                    # accounts for the offset in output dimensions
+                    else vp.dim
+                    + indices_spec.ndim
+                    - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec)
+                )
+            if isinstance(ip, Shard):
+                return Shard(ip.dim + insert_dim)
+            # _Partial or Replicated
+            return vp
+
+        value_placements = tuple(
+            place(vp, ip)
+            for vp, ip in zip(values_spec.placements, indices_spec.placements)
+        )
+        result = OutputSharding(
+            output_spec=DTensorSpec(
+                mesh=values_spec.mesh,
+                placements=value_placements,
+            )
+        )
+        return result
+    else:
+        result = OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(
+                        DTensorSpec(
+                            mesh=values_spec.mesh,
+                            placements=tuple(
+                                [
+                                    Replicate() if need_reshard_on_values[i] else v
+                                    for i, v in enumerate(values_spec.placements)
+                                ]
+                            ),
+                            tensor_meta=values_spec.tensor_meta,
+                        ),
+                        multi_indices_spec,
+                    ),
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+            ],
+        )
+        return result
+
+
+@register_prop_rule(
+    aten.cat.default, schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
+)
+def cat_rule(op_schema: OpSchema) -> OutputSharding:
+    # torch.cat requires that all tensors either have the same shape (except
+    # in the concatenating dimension) or be "empty". "Empty" here strictly means
+    # tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated
+    # as a non-empty tensor and the shape must match on non-cat dimensions.
+    def is_empty(spec: DTensorSpec) -> bool:
+        return list(spec.shape) == [0]
+
+    # the first arg is a list of input tensor specs
+    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
+    assert len(tensor_list_specs) > 0, "torch.cat expects a non-empty list of tensors"
+    non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]
+
+    if len(non_empty_specs) == 0:
+        # all tensors are empty, we can return any output sharding
+        return OutputSharding(
+            output_spec=DTensorSpec(
+                mesh=tensor_list_specs[0].mesh,
+                placements=tensor_list_specs[0].placements,
+            )
+        )
+
+    assert all(
+        spec.ndim == non_empty_specs[0].ndim for spec in non_empty_specs
+    ), f"Expect all tensors to have same shape or empty, but got {tensor_list_specs}"
+    assert all(
+        spec.mesh == tensor_list_specs[0].mesh for spec in tensor_list_specs
+    ), f"Expect all tensors to have same mesh, but got {tensor_list_specs}"
+
+    # ndim will also be the result's ndim
+    ndim = 1
+    for spec in tensor_list_specs:
+        ndim = max(ndim, spec.ndim)
+
+    dim = 0  # default dim = 0
+    if len(op_schema.args_schema) > 1:
+        dim = cast(int, op_schema.args_schema[1])
+    dim = normalize_dim(dim, ndim)
+
+    # Make sure all tensors are replicated on cat dimension
+    need_reshard = False
+    tensor_list_specs_after: List[DTensorSpec] = []
+    for spec in tensor_list_specs:
+        if not is_empty(spec) and (
+            is_tensor_dim_sharded(spec, dim=dim) or is_tensor_partial(spec)
+        ):
+            need_reshard = True
+            tensor_list_specs_after.append(
+                DTensorSpec(
+                    mesh=spec.mesh,
+                    placements=replicate_tensor_dim(spec.placements, dim=dim),
+                    tensor_meta=spec.tensor_meta,
+                )
+            )
+        else:
+            tensor_list_specs_after.append(spec)
+
+    tensor_list_specs = tensor_list_specs_after
+
+    # align non-cat dimensions placements based on reshard cost
+    non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]
+    mesh = non_empty_specs[0].mesh
+    ndim = non_empty_specs[0].ndim
+    new_placements: List[Placement] = []
+    for mesh_dim in range(mesh.ndim):
+        # compute the minimum cost of resharding on this mesh_dim
+        if any(
+            spec.placements[mesh_dim] != non_empty_specs[0].placements[mesh_dim]
+            for spec in non_empty_specs
+        ):
+            # only reshard if there is a mismatch
+            need_reshard = True
+            reshard_cost = []
+            for shard_dim in range(ndim):
+                # compute the cost of resharding on this shard_dim
+                cost: float = 0.0
+                for spec in non_empty_specs:
+                    global_shape = spec.shape
+                    if global_shape[shard_dim] < mesh.size(mesh_dim):
+                        # found one tensor whose size along shard_dim is smaller than
+                        # the mesh_dim size. In this case, we cannot shard on this
+                        # shard_dim, and hence set the cost to infinity.
+                        cost = +float("inf")
+                    elif (
+                        is_tensor_dim_sharded(spec, dim=shard_dim)
+                        or prod(global_shape) == 0
+                    ):
+                        continue
+                    else:
+                        local_shape = compute_local_shape(
+                            global_shape, spec.mesh, spec.placements
+                        )
+                        cost += prod(local_shape) * spec.mesh.size(mesh_dim)
+                reshard_cost.append(cost)
+            best_dim = reshard_cost.index(min(reshard_cost))
+            new_placements.append(Shard(best_dim))
+        else:
+            # no mismatch, keep the original placement
+            new_placements.append(non_empty_specs[0].placements[mesh_dim])
+
+    if need_reshard:
+        tensor_list_specs_after = []
+        for spec in tensor_list_specs:
+            if is_empty(spec):
+                tensor_list_specs_after.append(spec)
+            else:
+                tensor_list_specs_after.append(
+                    DTensorSpec(
+                        mesh=spec.mesh,
+                        placements=tuple(new_placements),
+                        tensor_meta=spec.tensor_meta,
+                    )
+                )
+
+        return OutputSharding(
+            output_spec=None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(
+                        tuple(tensor_list_specs_after),
+                        *op_schema.args_schema[1:],
+                    ),
+                    kwargs_schema=op_schema.kwargs_schema,
+                ),
+            ],
+        )
+    else:
+        # at this point, the cat dim is not sharded and no reshard is needed
+        return OutputSharding(
+            output_spec=DTensorSpec(
+                mesh=non_empty_specs[0].mesh,
+                placements=non_empty_specs[0].placements,
+            ),
+        )
+
+
+@register_prop_rule(
+    [
+        aten.split.Tensor,
+        aten.split_with_sizes.default,
+        aten.split_with_sizes_copy.default,
+    ],
+    schema_info=RuntimeSchemaInfo(1),
+)
+def split_rule(op_schema: OpSchema) -> OutputSharding:
+    output_spec_list: List[DTensorSpec] = []
+    input_spec = cast(DTensorSpec, op_schema.args_schema[0])
+    ndim = input_spec.ndim
+    split_size_or_sections = op_schema.args_schema[1]
+    dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
+    dim = normalize_dim(dim, ndim)
+
+    # TODO: the tensor to split cannot have _Partial
+    # in its placements for now; this will need to be
+    # supported in the future.
+    if input_spec.sums:
+        raise NotImplementedError(
+            f"splitting distributed tensor with "
+            f"_Partial placement is not implemented!\n"
+            f"DTensorSpec={input_spec}"
+        )
+
+    # TODO: just like slice op, split replicates before
+    # splitting on a sharded dimension
+    need_reshard = False
+    if is_tensor_dim_sharded(input_spec, dim=dim):
+        need_reshard = True
+        input_spec = DTensorSpec(
+            mesh=input_spec.mesh,
+            placements=unshard_tensor_dim(input_spec.placements, dim=dim),
+            tensor_meta=input_spec.tensor_meta,
+        )
+
+    if need_reshard:
+        return OutputSharding(
+            None,
+            schema_suggestions=[
+                OpSchema(
+                    op=op_schema.op,
+                    args_schema=(input_spec,) + op_schema.args_schema[1:],
+                    kwargs_schema=op_schema.kwargs_schema,
+                ),
+            ],
+        )
+
+    def size_split(N, i):
+        # Last chunk will be smaller if the tensor size N
+        # along the given dimension dim is not divisible by i.
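+        # e.g. size_split(5, 2) == [2, 2, 1]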
+        assert i > 0
+        return [i] * (N // i) + ([N % i] if N % i != 0 else [])
+
+    output_size_list = (
+        size_split(input_spec.shape[dim], split_size_or_sections)
+        if isinstance(split_size_or_sections, int)
+        else split_size_or_sections
+    )
+    output_spec_list = [
+        DTensorSpec(
+            mesh=input_spec.mesh,
+            placements=input_spec.placements,
+        )
+        for _ in range(len(output_size_list))
+    ]
+    return OutputSharding(output_spec_list)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/utils.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..559ba483efdfa278d116fac89a02f44d0ceef01a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/utils.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
+import operator
+from typing import cast, Iterable, List, Sequence, Tuple, Union
+
+import torch
+from torch.distributed._tensor._collective_utils import redistribute_cost
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.op_schema import OpStrategy, RuntimeSchemaInfo
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+# convenient wrapper to register sharding propagation rules
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def register_prop_rule(op, schema_info=None):
+    # pyre-fixme[53]: Captured variable `func` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def wrapper(impl):
+        overloads = op if isinstance(op, list) else [op]
+        for overload in overloads:
+            DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
+                overload, impl, schema_info
+            )
+        return impl
+
+    return wrapper
+
+
+def register_op_strategy(op, schema_info=None):
+    # pyre-fixme[53]: Captured variable `func` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+
+    # For every ATen op that accepts any args in this list,
+    # the arg itself can impact the strides (and potentially the sharding strategy)
+    # of the output tensor.
+    # thus, we will detect ATen schemas with any of these args and ensure
+    # that they get specialized here.
+    arg_names_that_require_specializing_cache_strategy = [
+        "memory_format",
+    ]
+
+    def wrapper(impl):
+        if isinstance(op, list):
+            overloads = op
+        else:
+            overloads = [op]
+
+        for overload in overloads:
+            curr_schema_info = None
+            if schema_info is None:
+                specialized_args = [
+                    a.name
+                    for a in overload._schema.arguments
+                    if a.name in arg_names_that_require_specializing_cache_strategy
+                ]
+                if any(specialized_args):
+                    curr_schema_info = RuntimeSchemaInfo(
+                        static_kwargkey=specialized_args
+                    )
+            else:
+                curr_schema_info = schema_info
+            DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
+                overload, impl, curr_schema_info
+            )
+        return impl
+
+    return wrapper
+
+
+def as_list(
+    x: Union[List[object], object]
+    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
+) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:  # type: ignore[valid-type]
+    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
+    # which is an object but treated as a list by the tracer. Therefore, keep
+    # `immutable_list` intact here as well.
+    if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
+        return x
+    else:
+        return [x]
+
+
+def normalize_dim(dim: int, ndim: int) -> int:
+    return dim if dim >= 0 else dim + ndim
+
+
+def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
+    """Normalize a dim or a sequence of dims, so that they are all positive."""
+    if isinstance(dims, int):
+        dims = (normalize_dim(dims, ndim),)
+    elif isinstance(dims, list):
+        dims = [normalize_dim(dim, ndim) for dim in dims]
+    elif isinstance(dims, tuple):
+        dims = tuple([normalize_dim(dim, ndim) for dim in dims])
+    return dims
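+
+# For illustration (assumed values), the dim-normalization helpers above map
+# negative dims to their positive counterparts:
+#   normalize_dim(-1, ndim=3)       -> 2
+#   normalize_dims((0, -1), ndim=3) -> (0, 2)
+#   normalize_dims(-2, ndim=4)      -> (2,)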
+
+
+def normalize_to_torch_size(size) -> torch.Size:
+    """
+    Unify variable types of size argument to torch.Size
+    Acceptable types include:
+        int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
+        or torch.Size
+    """
+    if isinstance(size, torch.Size):
+        return size
+
+    if isinstance(size, int):
+        torch_size = [size]
+    elif len(size) == 1 and isinstance(size[0], Sequence):
+        torch_size = list(size[0])
+    else:
+        torch_size = list(size)
+    return torch.Size(torch_size)
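+
+# For illustration (assumed inputs), the size argument variants are unified as:
+#   normalize_to_torch_size(4)         -> torch.Size([4])
+#   normalize_to_torch_size((2, 3))    -> torch.Size([2, 3])
+#   normalize_to_torch_size(((2, 3),)) -> torch.Size([2, 3])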
+
+
+def prod(xs: Iterable[int]) -> int:
+    return functools.reduce(operator.mul, xs, 1)
+
+
+def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
+    """Check if the shape is shardable according to the spec."""
+    # number of shards in each tensor dimension
+    shards_map = [1] * len(shape)
+    for i, placement in enumerate(spec.placements):
+        if placement.is_shard():
+            shard_dim = cast(Shard, placement).dim
+            shards_map[shard_dim] *= spec.mesh.size(i)
+
+    for i, dim_size in enumerate(shape):
+        # TODO: maybe we should determine is_shardable based on
+        #       whether it's evenly sharded or not
+        if shards_map[i] > 1 and dim_size < shards_map[i]:
+            return False
+
+    return True
+
+
+def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
+    """Check if the shape is evenly shardable according to the spec."""
+    # number of shards in each tensor dimension
+    shards_map = [1] * len(shape)
+    for i, placement in enumerate(spec.placements):
+        if placement.is_shard():
+            shard_dim = cast(Shard, placement).dim
+            shards_map[shard_dim] *= spec.mesh.size(i)
+
+    for i, dim_size in enumerate(shape):
+        if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
+            return False
+
+    return True
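+
+# For illustration (hypothetical spec): a tensor of shape (4, 6) with a single
+# Shard(0) placement on a 1-D mesh of size 3 is shardable (4 >= 3) but not
+# evenly shardable (4 % 3 != 0), so is_tensor_shardable returns True while
+# is_tensor_evenly_shardable returns False.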
+
+
+def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
+    """Return True if tensor dim is sharded."""
+    return any(p.is_shard(dim) for p in spec.placements)
+
+
+def is_tensor_partial(spec: DTensorSpec) -> bool:
+    """Return True if tensor is partial on the mesh."""
+    return any(p.is_partial() for p in spec.placements)
+
+
+def infer_broadcast_dims_map(
+    common_shape: torch.Size, input_shape: torch.Size
+) -> List[int]:
+    # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
+    # this is aligned with the broadcast semantics
+    common_ndim = len(common_shape)
+    input_ndim = len(input_shape)
+    broadcast_dims_map = [-1] * common_ndim
+    for idx in range(-1, -1 - input_ndim, -1):
+        if input_shape[idx] == common_shape[idx]:
+            broadcast_dims_map[common_ndim + idx] = input_ndim + idx
+    return broadcast_dims_map
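+
+# For example (assumed shapes), broadcasting an input of shape (3, 4) to a
+# common shape of (2, 3, 4) yields a dims map of [-1, 0, 1]: common dim 0 has
+# no corresponding input dim, while common dims 1 and 2 map to input dims 0
+# and 1 respectively.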
+
+
+def map_placements_after_broadcast(
+    placements: Tuple[Placement, ...],
+    shape: torch.Size,
+    broadcast_dims_map: List[int],
+) -> Tuple[Placement, ...]:
+    """Map each placement based on the output shape after broadcast."""
+    new_placements: List[Placement] = []
+    for placement in placements:
+        if isinstance(placement, (Replicate, _Partial)):
+            new_placements.append(placement)
+        else:
+            assert isinstance(placement, Shard)
+            shard_dim = normalize_dim(placement.dim, len(shape))
+            new_shard_dim = broadcast_dims_map[shard_dim]
+            if new_shard_dim != -1:
+                # there's a map from the common shape shard dim to
+                # the input shape shard dim before broadcasting,
+                # use that instead
+                new_placements.append(Shard(new_shard_dim))
+            else:
+                # there's no map between common shape shard dim and
+                # the input shape shard dim before broadcasting,
+                # in this case it means implicit broadcasting happen
+                # in this dim, so we can just mark it as replicate
+                # and implict broadcast will broadcast automatically
+                # to the sharded shape
+                new_placements.append(Replicate())
+
+    return tuple(new_placements)
+
+
+def generate_redistribute_costs(
+    src_strategy: OpStrategy, dst_spec: DTensorSpec
+) -> List[float]:
+    redistribute_costs: List[float] = []
+    for strat in src_strategy.strategies:
+        redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
+
+    return redistribute_costs
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/view_ops.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/view_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbb567f220a6c337c8e6804837d391c8c4b028a1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/ops/view_ops.py
@@ -0,0 +1,717 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from dataclasses import dataclass
+from typing import Callable, cast, Dict, Iterable, Optional, Sequence, Set, Tuple, Union
+
+import torch
+
+from torch import Tensor
+from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch.distributed._tensor._utils import compute_local_shape
+from torch.distributed._tensor.api import Shard
+from torch.distributed._tensor.op_schema import (
+    OpSchema,
+    OutputSharding,
+    RuntimeSchemaInfo,
+)
+from torch.distributed._tensor.ops.utils import (
+    normalize_dim,
+    normalize_dims,
+    prod,
+    register_prop_rule,
+)
+
+from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate
+from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
+
+aten = torch.ops.aten
+
+Shape = Tuple[int, ...]
+
+
+@dataclass
+class DimSpec:
+    """Specifies how an output dimension maps to an input dimension."""
+
+    def inputs(self) -> Iterable["DimSpec"]:
+        return ()
+
+
+# Rules that map each dimension of the output to dimensions of the input tensor
+DimMap = Tuple[DimSpec, ...]
+
+
+@dataclass
+class Singleton(DimSpec):
+    """Output dimension is a singleton."""
+
+    pass
+
+
+@dataclass
+class InputDim(DimSpec):
+    """Output dimension maps directly to an input dimension."""
+
+    input_dim: int
+
+
+@dataclass
+class Broadcast(DimSpec):
+    """Output is the broadcast of a singleton input dimension."""
+
+    dim: DimSpec
+    dim_size: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, dim_size: int) -> DimSpec:
+        return Broadcast(dim, dim_size)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.dim,)
+
+
+@dataclass
+class NewDim(DimSpec):
+    """This is a new dimension created by the op."""
+
+    size: int
+
+    @classmethod
+    def new(cls, size: int) -> DimSpec:
+        return Singleton() if size == 1 else NewDim(size)
+
+
+@dataclass
+class Repeat(DimSpec):
+    """Output dimension is the input dimension repeated n-times."""
+
+    input_dim: DimSpec
+    times: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, times: int) -> DimSpec:
+        if times == 1:
+            return dim
+        elif isinstance(dim, Singleton):
+            # repeating a singleton is the same as broadcasting it
+            return Broadcast(dim, times)
+        else:
+            return Repeat(dim, times)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.input_dim,)
+
+
+@dataclass
+class Flatten(DimSpec):
+    """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output."""
+
+    input_dims: Sequence[DimSpec]
+
+    @classmethod
+    def new(cls, dims: Sequence[DimSpec]) -> DimSpec:
+        if len(dims) == 0:
+            # flattening a scalar leads to a singleton
+            return Singleton()
+        elif len(dims) == 1:
+            # flattening a single dimension is no-op
+            return dims[0]
+        else:
+            return Flatten(dims)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return self.input_dims
+
+
+@dataclass
+class Split(DimSpec):
+    """
+    This dimension is a member of a decomposition of the input dim.
+
+    Note that input_dim itself could be a Flattened set of input dims.
+    """
+
+    input_dim: DimSpec
+    group_shape: Shape
+    split_id: int
+
+    @classmethod
+    def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec:
+        assert len(group_shape) > 0
+        if len(group_shape) == 1:
+            # not really a group, just return the input dim back
+            assert idx == 0
+            return dim
+        elif group_shape[idx] == 1:
+            return Singleton()
+        else:
+            # remove singletons from group
+            # group_mapping = [(new_index, (shape, old_index)) ...]
+            group_mapping = list(
+                enumerate((s, i) for i, s in enumerate(group_shape) if s != 1)
+            )
+            new_group_shape = tuple(m[1][0] for m in group_mapping)
+            new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0]
+            return Split(dim, new_group_shape, new_idx)
+
+    def inputs(self) -> Iterable[DimSpec]:
+        return (self.input_dim,)
+
+
+def dim_pad_left(ndim: int, min_dims: int) -> DimMap:
+    return (Singleton(),) * max(0, min_dims - ndim) + tuple(
+        InputDim(i) for i in range(ndim)
+    )
+
+
+def dim_atleast_3d(ndim: int) -> DimMap:
+    if ndim == 0:
+        return (Singleton(), Singleton(), Singleton())
+    elif ndim == 1:
+        return (Singleton(), InputDim(0), Singleton())
+    elif ndim == 2:
+        return (InputDim(0), InputDim(1), Singleton())
+    else:
+        return tuple(InputDim(i) for i in range(ndim))
+
+
+def expand(input_shape: Shape, shape: Shape) -> DimMap:
+    """Implement broadcast on multiple dimensions."""
+    assert len(shape) >= len(input_shape)
+
+    # 1. create padded input dimensions
+    padded_input = dim_pad_left(len(input_shape), len(shape))
+    # 2. check that input shapes are compatible
+    mapping = []
+    for p, desired_s in zip(padded_input, shape):
+        if isinstance(p, Singleton):
+            actual_s = 1
+            assert desired_s >= 0
+        else:
+            assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}"
+            actual_s = input_shape[p.input_dim]
+            assert actual_s == 1 or desired_s == -1 or desired_s == actual_s
+        mapping.append(
+            p
+            if desired_s in (1, -1) or desired_s == actual_s
+            else Broadcast.new(p, desired_s)
+        )
+    return tuple(mapping)
+
+
+def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape:
+    if isinstance(sizes[0], int):
+        return cast(Shape, sizes)
+    elif len(sizes) == 1:
+        return cast(Shape, sizes[0])  # type: ignore[redundant-cast]
+    else:
+        raise RuntimeError("Size must be int... or tuple")
+
+
+def dim_flatten(ndim: int) -> DimMap:
+    if ndim == 0:
+        return (Singleton(),)
+    elif ndim == 1:
+        return (InputDim(0),)
+    else:
+        return (Flatten.new(tuple(InputDim(i) for i in range(ndim))),)
+
+
+def dim_movedim(
+    ndim: int,
+    input: Union[int, Sequence[int]],
+    destination: Union[int, Sequence[int]],
+) -> DimMap:
+    input = normalize_dims(input, ndim)
+    destination = normalize_dims(destination, ndim)
+
+    assert len(input) == len(destination)
+    input_set = set(input)
+    assert len(input_set) == len(input), "Found repeated input dims"
+    assert len(set(destination)) == len(destination), "Found repeated output dims"
+    assert max(input) < ndim
+    assert max(destination) < ndim
+
+    dest = [-1] * ndim
+    for i, d in zip(input, destination):
+        dest[d] = i
+
+    unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set)
+    for i in range(ndim):
+        if dest[i] == -1:
+            dest[i] = next(unused_inputs_iter)
+
+    return tuple(InputDim(i) for i in dest)
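+
+# For illustration (assumed arguments): dim_movedim(3, 0, 2) moves input dim 0
+# to output position 2, producing (InputDim(1), InputDim(2), InputDim(0)).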
+
+
+def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
+    sizes = normalize_sizes(sizes)
+    assert (
+        len(sizes) >= ndim
+    ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
+    pad = len(sizes) - ndim
+    return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
+        Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
+    )
+
+
+def infer_size(total_size: int, sizes: Shape) -> Shape:
+    """
+    One dimension input to view may be "-1".
+
+    Infer the size of this dimension given the total_size.
+    """
+    infers = [i for i, s in enumerate(sizes) if s == -1]
+    size = prod(sizes)
+    assert len(infers) <= 1, "can only infer one size"
+    if infers:
+        size = -size
+        missing_size = total_size // size
+        assert (
+            total_size % size == 0
+        ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
+        return tuple(s if s != -1 else missing_size for s in sizes)
+    assert size == total_size, f"sizes do not match {total_size} vs {size}"
+    return sizes
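+
+# For illustration (assumed arguments):
+#   infer_size(12, (-1, 4)) -> (3, 4)   # the -1 entry is inferred
+#   infer_size(12, (3, 4))  -> (3, 4)   # sizes already match total_size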
+
+
+def view_groups(from_size: Shape, to_size: Shape) -> DimMap:
+    """
+    Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension.
+
+    A view or reshape operation can be decomposed into a set of 3 types of smaller operations:
+    1) Forward a dimension from input to output
+    2) Flatten a set of dimensions into a single dimension
+    3) Split one dimension into multiple dimensions
+
+    view_groups identifies these operations and returns, for each output dimension, what
+    operation was performed on the input dimensions. For example:
+
+        view_groups([2, 3, 4], [2, 12]) -> (
+            InputDim(0),
+            Flatten((InputDim(1), InputDim(2)))
+        )
+
+    - output dimension 0 maps to input dimension 0
+    - output dimension 1 maps to the flattened input dimensions 1 and 2
+
+
+        view_groups([2, 3], [3, 2]) -> (
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
+        )
+
+    - in the above, input is flattened into a single dimension and then split
+      into two separate dimensions with different sizes from the input.
+    """
+    from_nelem = prod(from_size)
+    to_size = infer_size(from_nelem, normalize_sizes(to_size))
+
+    assert from_nelem == prod(to_size), "Total view shape does not add up"
+
+    from_idx = 0
+    to_idx = 0
+    from_len = len(from_size)
+    to_len = len(to_size)
+
+    result_pp = []
+
+    while from_idx < from_len or to_idx < to_len:
+        from_group_dim, to_group_shape = [], []
+
+        if from_idx >= from_len:
+            f = 1
+        else:
+            f = from_size[from_idx]
+            from_group_dim.append(from_idx)
+            from_idx += 1
+
+        if to_idx >= to_len:
+            t = 1
+        else:
+            t = to_size[to_idx]
+            to_group_shape.append(t)
+            to_idx += 1
+
+        # if any of the groups is singleton, great, we need to backtrack though
+        if f == 1 and t != 1:
+            # produces ([1], [])
+            to_idx -= 1
+            to_group_shape = []
+        elif f != 1 and t == 1:
+            # produces ([], [1])
+            from_idx -= 1
+            from_group_dim = []
+        else:
+            # produces ([1], [1]),  ([2], [2]), ([2,3], [6])
+            while f != t:
+                if f < t:
+                    nf = from_size[from_idx]
+                    from_group_dim.append(from_idx)
+                    from_idx += 1
+                    f *= nf
+                else:
+                    nt = to_size[to_idx]
+                    to_group_shape.append(nt)
+                    to_idx += 1
+                    t *= nt
+
+        if len(to_group_shape) > 0:
+            flattened = Flatten.new(
+                tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1)
+            )
+            result_pp += [
+                Split.new(flattened, tuple(to_group_shape), i)
+                for i in range(len(to_group_shape))
+            ]
+
+    return tuple(result_pp)
+
+
+def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap:
+    if len(dims) < ndim:
+        dims = (1,) * (ndim - len(dims)) + dims
+    return dim_repeat(ndim, dims)
+
+
+def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
+    dim1 = normalize_dim(dim1, ndim)
+    dim2 = normalize_dim(dim2, ndim)
+    assert dim1 < ndim
+    assert dim2 < ndim
+    dimmap = [InputDim(i) for i in range(ndim)]
+    swapdim = dimmap[dim1]
+    dimmap[dim1] = dimmap[dim2]
+    dimmap[dim2] = swapdim
+    return tuple(dimmap)
+
+
+def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
+    # FIXME: this is wrong when dim=None and one of the dimensions
+    # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
+    # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
+    # removal of a dimension that is not actually a singleton.
+    return tuple(
+        InputDim(i)
+        for i, s in enumerate(shape)
+        if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape)))
+    )
+
+
+def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
+    dims = tuple(InputDim(i) for i in range(ndim))
+    if dim < 0:
+        dim += ndim + 1
+    return dims[:dim] + (Singleton(),) + dims[dim:]
+
+
+def dim_reduction(
+    ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool
+) -> DimMap:
+    """
+    General fallback for reduction ops where _Partial() does not apply.
+
+    This will cause incoming tensor to be replicated on the reducing dimensions.
+    """
+    if dim_or_dims is None:
+        dim_or_dims = tuple(range(ndim))
+    if isinstance(dim_or_dims, int):
+        dim_or_dims = (dim_or_dims,)
+    dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims)
+    return tuple(
+        InputDim(i) if i not in dim_or_dims else Singleton()
+        for i in range(ndim)
+        if i not in dim_or_dims or keepdim
+    )
+
+
+@dataclass
+class Op:
+    dim_map: Callable[..., DimMap]
+    shape_argnum: Optional[int] = None
+
+
+ops: Dict[Callable[..., torch.Tensor], Op] = {
+    torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)),
+    torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)),
+    torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)),
+    torch.broadcast_to: Op(
+        dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1
+    ),
+    Tensor.expand: Op(
+        dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
+        shape_argnum=1,
+    ),
+    torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
+    torch.movedim: Op(
+        dim_map=lambda input, source, destination: dim_movedim(
+            input.ndim, source, destination
+        )
+    ),
+    torch.permute: Op(
+        dim_map=lambda input, dims: tuple(
+            InputDim(i) for i in normalize_dims(dims, input.ndim)
+        )
+    ),
+    torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
+    Tensor.repeat: Op(dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes)),
+    torch.reshape: Op(
+        dim_map=lambda input, shape: view_groups(input.shape, shape),
+        shape_argnum=1,
+    ),
+    torch.squeeze: Op(dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim)),
+    torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)),
+    torch.transpose: Op(
+        dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1)
+    ),
+    torch.unsqueeze: Op(dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim)),
+    Tensor.view: Op(
+        dim_map=lambda input, *shape: view_groups(input.shape, shape),
+        shape_argnum=1,
+    ),
+}
+
+
+def propagate_shape_and_sharding(
+    in_shard: Sequence[Placement],
+    local_in_shape: Shape,
+    rule: DimMap,
+    mesh_sizes: Shape,
+) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]:
+    """
+    Determine output sharding and tensor shape based on given global tensor shape and input sharding.
+
+    Takes as input the global shape of the tensor, and the input sharding,
+    and produce corresponding output sharding and shape of the output tensor.
+
+    Sharding propagation follows mapped dimensions:
+    - An output dimension that maps directly to an input dimension is sharded equally
+    - An output dimension that is a flattened set of input dimensions can only be
+      sharded if only the leftmost flattened dimension is sharded.
+    - An output dimension that is a split of the input dimension can only be sharded
+      if the leftmost split size is divisible by the mesh dimension
+    """
+    assert len(in_shard) == len(mesh_sizes)
+    sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)}
+    # for each input dim, for each mesh dim, provides a list of possible shardable dimensions
+    shardable_dims: torch.Tensor = torch.ones(
+        (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool
+    )
+
+    # in case an input dimension disappears (e.g. collapsing, reduction)
+    # we cannot shard in that dimension (we need a replication fall-back rule)
+
+    seen_input_dims: Set[int] = set()
+
+    def collect_used_inputs(cmd: DimSpec) -> None:
+        if isinstance(cmd, InputDim):
+            seen_input_dims.add(cmd.input_dim)
+        for inp in cmd.inputs():
+            collect_used_inputs(inp)
+
+    for cmd in rule:
+        collect_used_inputs(cmd)
+    for dim in range(len(local_in_shape)):
+        shardable_dims[dim, :] = dim in seen_input_dims
+
+    def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]:
+        if isinstance(cmd, InputDim):
+            seen_input_dims.add(cmd.input_dim)
+            return (
+                local_in_shape[cmd.input_dim],
+                cmd if cmd.input_dim in sharded_in_dims else None,
+            )
+        elif isinstance(cmd, Flatten):
+            for dim in cmd.input_dims[1:]:
+                if isinstance(dim, InputDim):
+                    shardable_dims[dim.input_dim, :] = False
+            dim0 = cmd.input_dims[0]
+            return (
+                prod(get_dim_size(a)[0] for a in cmd.input_dims),
+                dim0
+                if isinstance(dim0, InputDim) and dim0.input_dim in sharded_in_dims
+                else None,
+            )
+        elif isinstance(cmd, Split):
+            _, in_dim = get_dim_size(cmd.input_dim)
+            out_size = cmd.group_shape[cmd.split_id]
+            if cmd.split_id == 0 and in_dim is not None:
+                # we need to check that the input dimension is divisible
+                # by the size of the submesh we're sharding it on
+                # NOTE: it would be possible to shard the same input dimension
+                # on more than one mesh dimension. In that case, the dimension
+                # needs to be divisible by the product of mesh sizes.
+                # In order to keep the problem more tractable, we will not consider
+                # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ])
+                # but we will allow it if that's the input and it's compatible
+
+                # 1. is this dimension shardable on each individual mesh dim?
+                for mesh_dim, mesh_dim_size in enumerate(mesh_sizes):
+                    shardable_dims[in_dim.input_dim, mesh_dim] = (
+                        out_size % mesh_dim_size == 0
+                    )
+
+                # 2. here we special case things like [Shard(0), Shard(0)]
+                submesh_size = 1
+                for size, shard in zip(mesh_sizes, in_shard):
+                    if isinstance(shard, Shard) and shard.dim == in_dim:
+                        submesh_size *= size
+                assert (
+                    out_size % submesh_size == 0
+                ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
+
+            # we will only shard our first component of the split
+            return out_size, in_dim if cmd.split_id == 0 else None
+        elif isinstance(cmd, Singleton):
+            return 1, None
+        elif isinstance(cmd, Broadcast):
+            return cmd.dim_size, None
+        elif isinstance(cmd, NewDim):
+            return cmd.size, None
+        elif isinstance(cmd, Repeat):
+            size, in_dim = get_dim_size(cmd.input_dim)
+            if in_dim is not None:
+                shardable_dims[in_dim.input_dim, :] = False
+            return size * cmd.times, None
+        else:
+            raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}")
+
+    dim_map = {}
+    out_shape = []
+    for dim, cmd in enumerate(rule):
+        out_size, in_dim = get_dim_size(cmd)
+        out_shape.append(out_size)
+        if in_dim is not None:
+            dim_map[in_dim.input_dim] = dim
+
+    needs_reshard = any(
+        isinstance(placement, Shard) and not shardable_dims[placement.dim][mesh_dim]
+        for mesh_dim, placement in enumerate(in_shard)
+    )
+
+    output_placements = (
+        None
+        if needs_reshard
+        else [Shard(dim_map[s.dim]) if isinstance(s, Shard) else s for s in in_shard]
+    )
+
+    return (tuple(out_shape), output_placements, shardable_dims)
+
+
+def register_prop_rule_map(
+    aten_op_overload: torch._ops.OpOverload,
+    local_op_name: Callable[..., torch.Tensor],
+    schema_info: Optional[RuntimeSchemaInfo] = None,
+) -> None:
+    spec: Op = ops[local_op_name]
+
+    @register_prop_rule(aten_op_overload, schema_info=schema_info)
+    def reshape_prop(op_schema: OpSchema) -> OutputSharding:
+        rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
+        input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0])
+        mesh = input_dtensor_spec.mesh
+
+        assert isinstance(
+            input_dtensor_spec, DTensorSpec
+        ), "Expected first input to be a DTensorSpec"
+        global_in_shape = input_dtensor_spec.shape
+        assert global_in_shape is not None, "Shape required."
+
+        with disable_proxy_modes_tracing(), unset_fake_temporarily():
+            (
+                global_out_shape,
+                shard_out,
+                shardable_dims,
+            ) = propagate_shape_and_sharding(
+                input_dtensor_spec.placements,
+                tuple(global_in_shape),
+                rules,
+                mesh.shape,
+            )
+
+        if shard_out is not None:
+            # no reshard needed
+            output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out))
+
+            # We only need the local shape to lower the call into the local op
+            args = op_schema.args_schema
+            shape_argnum = spec.shape_argnum
+            if shape_argnum is not None:
+                # compute the local shape from the global shape, then return
+                # a resharding even if we don't really reshard, the only reason
+                # for this type of resharding is to lower the global shape to
+                # local shape
+                local_out_shape = compute_local_shape(
+                    list(global_out_shape), mesh, shard_out
+                )
+
+                suggested_schema = OpSchema(
+                    op=op_schema.op,
+                    args_schema=args[:shape_argnum]
+                    + (tuple(local_out_shape),)
+                    + args[shape_argnum + 1 :],
+                    kwargs_schema=op_schema.kwargs_schema,
+                )
+                return OutputSharding(
+                    output_spec=output_dtensor_spec,
+                    schema_suggestions=[suggested_schema],
+                    needs_redistribute=True,
+                )
+
+            return OutputSharding(output_spec=output_dtensor_spec)
+
+        else:
+            # TODO: optimize this. we shouldn't simply blindly replicate
+            #       unshardable dims ...
+            # FIXME: this can be wrong for situations where we have
+            #        [Shard(0), Shard(0)]
+            suggested_placements = [
+                p
+                if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim]
+                else Replicate()
+                for mesh_dim, p in enumerate(input_dtensor_spec.placements)
+            ]
+            return OutputSharding(
+                output_spec=None,
+                schema_suggestions=[
+                    OpSchema(
+                        op=op_schema.op,
+                        args_schema=(
+                            DTensorSpec(
+                                placements=tuple(suggested_placements),
+                                mesh=input_dtensor_spec.mesh,
+                                tensor_meta=input_dtensor_spec.tensor_meta,
+                            ),
+                        )
+                        + op_schema.args_schema[1:],
+                        kwargs_schema=op_schema.kwargs_schema,
+                    )
+                ],
+            )
+
+
+register_prop_rule_map(aten.squeeze.default, torch.squeeze)
+register_prop_rule_map(
+    aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1))
+register_prop_rule_map(
+    aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
+)
+register_prop_rule_map(
+    aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/placement_types.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/placement_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b5424a028eb120c25d20fe524a49f4ce4df9fa4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/placement_types.py
@@ -0,0 +1,620 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from dataclasses import dataclass
+from typing import Any, cast, List, NamedTuple, Optional, Tuple
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+
+from torch.distributed._tensor._collective_utils import mesh_broadcast, mesh_scatter
+from torch.distributed.device_mesh import DeviceMesh
+
+
+class Placement:
+    # base class Placement type
+
+    # convenient utils to check for placement types
+    def is_shard(self, dim: Optional[int] = None) -> bool:
+        is_shard_instance = isinstance(self, Shard)
+        if dim is not None and is_shard_instance:
+            return cast(Shard, self).dim == dim
+        else:
+            return is_shard_instance
+
+    def is_replicate(self) -> bool:
+        return isinstance(self, Replicate)
+
+    def is_partial(self) -> bool:
+        return isinstance(self, _Partial)
+
+
+@dataclass(frozen=True)
+class Shard(Placement):
+    # shard placement, shard on a dim
+    dim: int
+
+    def _split_tensor(
+        self,
+        tensor: torch.Tensor,
+        num_chunks: int,
+        *,
+        with_padding: bool = True,
+        contiguous: bool = True,
+    ) -> Tuple[List[torch.Tensor], List[int]]:
+        """
+        This function uses torch.chunk to split a tensor into num_chunks shards along
+        the Shard placement dimension, and returns a list of shards with their pad sizes.
+
+        Keyword args:
+            with_padding (bool, optional): when True, we pad the tensor on the last
+            few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
+            This is because collectives usually require equal size tensor inputs
+        """
+        assert (
+            self.dim <= tensor.ndim
+        ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
+
+        # chunk tensor over dimension `dim` into n slices with padding if necessary
+        tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
+        # compute the chunk size inline with ``torch.chunk``
+        full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
+
+        # Compute chunk size for each chunk for ``self.dim``
+        chunk_sizes = [
+            tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0
+            for idx in range(num_chunks)
+        ]
+        # Compute pad size on each chunk
+        pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
+
+        # Reuse tensor to fill empty chunk with empty tensor
+        num_empty_tensors = num_chunks - len(tensor_list)
+        tensor_size = list(tensor_list[0].size())
+        tensor_size = [
+            size if idx != self.dim else 0 for idx, size in enumerate(tensor_size)
+        ]
+        tensor = tensor.new_zeros(tensor_size)
+        for _ in range(num_empty_tensors):
+            tensor_list.append(tensor)
+
+        if with_padding or contiguous:
+            shard_list = []
+            for shard, pad_size in zip(tensor_list, pad_sizes):
+                # Pad the shard with zeros if needed.
+                if with_padding and pad_size > 0:
+                    shard = self._pad_tensor(shard, pad_size)
+                shard = shard.contiguous() if contiguous else shard
+                shard_list.append(shard)
+            return shard_list, pad_sizes
+        else:
+            return tensor_list, pad_sizes
+
+    def _pad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        pad = [0, 0] * (tensor.ndim - self.dim)
+        pad[-1] = pad_size
+        return torch.nn.functional.pad(tensor, pad)
+
+    def _unpad_tensor(
+        self,
+        tensor: torch.Tensor,
+        pad_size: int,
+    ) -> torch.Tensor:
+        if pad_size == 0:
+            return tensor
+        return tensor.narrow(
+            self.dim,
+            start=0,
+            length=tensor.size(self.dim) - pad_size,
+        )
+
+    @staticmethod
+    def _local_shard_size_on_dim(
+        size_on_dim: int,
+        num_chunks: int,
+        rank: int,
+        return_offset: bool = False,
+    ) -> Tuple[int, int]:
+        """
+        returns the local shard size and offset on a given tensor dim
+        """
+        # Compute the chunk size inline with ``torch.chunk``
+        if size_on_dim % num_chunks == 0:
+            full_chunk_size = size_on_dim // num_chunks
+            return full_chunk_size, full_chunk_size * rank if return_offset else -1
+
+        # uneven sharding case
+        full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks
+        shard_starting_idx = full_chunk_size * rank
+
+        if size_on_dim < shard_starting_idx:
+            return 0, size_on_dim if return_offset else -1
+        else:
+            local_shard_size = (
+                min(size_on_dim, shard_starting_idx + full_chunk_size)
+                - shard_starting_idx
+            )
+            return local_shard_size, shard_starting_idx if return_offset else -1
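+
+    # For illustration (assumed values), splitting a dimension of size 5 over
+    # 2 ranks mirrors torch.chunk's uneven split:
+    #   _local_shard_size_on_dim(5, 2, rank=0, return_offset=True) -> (3, 0)
+    #   _local_shard_size_on_dim(5, 2, rank=1, return_offset=True) -> (2, 3)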
+
+    def _shard_tensor(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        """
+        shard and scatter a tensor on a mesh dimension (use coordinate
+        0 on the mesh dimension as source of truth)
+        """
+        my_coordinate = mesh.get_coordinate()
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+
+        if my_coordinate is None:
+            # if rank is not part of mesh, we simply return an empty tensor
+            return tensor.new_empty(0, requires_grad=tensor.requires_grad)
+
+        scatter_list, pad_sizes = self._split_tensor(
+            tensor, num_chunks, with_padding=True, contiguous=True
+        )
+
+        output = torch.empty_like(scatter_list[my_coordinate[mesh_dim]])
+        mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim)
+
+        # Only unpad if the local_tensor was padded on the dimension.
+        pad_size = pad_sizes[my_coordinate[mesh_dim]]
+        if pad_size > 0:
+            output = self._unpad_tensor(output, pad_size)
+        return output
+
+    def _reduce_shard_tensor(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        reduce_op: c10d.ReduceOp.RedOpType,
+        mesh_dim: int,
+    ) -> torch.Tensor:
+        """
+        reduce and scatter a tensor on a mesh dimension
+        """
+        my_coordinate = mesh.get_coordinate()
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+
+        if my_coordinate is None:
+            # if rank is not part of mesh, we simply return local_tensor,
+            # which should be an empty tensor
+            return tensor
+
+        is_padded = tensor.size(self.dim) % num_chunks != 0
+        if is_padded:
+            scattered_list, pad_sizes = self._split_tensor(
+                tensor, num_chunks, with_padding=True, contiguous=True
+            )
+            tensor = torch.cat(scattered_list, dim=self.dim)
+        elif not tensor.is_contiguous():
+            tensor = tensor.contiguous()
+
+        output = funcol.reduce_scatter_tensor(
+            tensor, reduce_op.name, scatter_dim=self.dim, group=(mesh, mesh_dim)
+        )
+
+        if is_padded:
+            output = self._unpad_tensor(output, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
+        return output
+
+    def _to_replicate_tensor(
+        self,
+        local_tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        current_logical_shape: List[int],
+    ) -> torch.Tensor:
+        """
+        This function all_gathers all shards and returns a tensor that
+        is replicated on the previously sharded mesh dimension
+        """
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+        # check if it's uneven, so we need to pad input tensor before all_gather
+        local_shape = list(local_tensor.size())
+
+        logical_dim_size = current_logical_shape[self.dim]
+        is_padded = logical_dim_size % num_chunks != 0
+
+        if is_padded:
+            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
+            pad_size = full_chunk_size - local_shape[self.dim]
+            local_tensor = self._pad_tensor(local_tensor, pad_size)
+
+        if not local_tensor.is_contiguous():
+            local_tensor = local_tensor.contiguous()
+
+        result = funcol.all_gather_tensor(
+            local_tensor,
+            gather_dim=self.dim,
+            group=(mesh, mesh_dim),
+        )
+        if is_padded:
+            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
+            result = self._unpad_tensor(result, unpad_size)
+        return result
+
+    def _replicate_to_shard(
+        self,
+        local_tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_index: int,
+    ) -> torch.Tensor:
+        """
+        transform a replicated tensor into a sharded tensor on
+        the current rank by performing a local chunk
+        """
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+        shards, _ = self._split_tensor(
+            local_tensor,
+            num_chunks,
+            with_padding=False,
+            contiguous=False,
+        )
+        return shards[shard_index].clone()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Shard):
+            return False
+        return self.dim == other.dim
+
+    def __hash__(self) -> int:
+        return hash(self.dim)
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Shard placement
+        """
+        return f"Shard(dim={self.dim})"
+
+    def __str__(self) -> str:
+        """human readable representation of the Shard placement"""
+        return f"S({self.dim})"
+
+
+@dataclass(frozen=True)
+class Replicate(Placement):
+    # replicate placement
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Replicate):
+            return False
+        return True
+
+    def __hash__(self) -> int:
+        # every replicate placement is the same
+        return -1
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Replicate placement
+        """
+        return "Replicate()"
+
+    def __str__(self) -> str:
+        """
+        human readable representation of the Replicate placement
+        """
+        return "R"
+
+    def _replicate_tensor(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        """
+        Replicate (broadcast) a torch.Tensor on a mesh dimension (use
+        the first coordinate on the mesh dimension as source of truth)
+        """
+        my_coordinate = mesh.get_coordinate()
+        if my_coordinate is None:
+            # if rank is not part of mesh, we simply return an empty tensor
+            return tensor.new_empty(0, requires_grad=tensor.requires_grad)
+
+        tensor = tensor.contiguous()
+        mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim)
+        return tensor
+
+
+@dataclass(frozen=True)
+class _Partial(Placement):
+    # This is a default _Partial placement with element-wise reduce op
+    # _Partial defines three contracts:
+    # 1. _reduce_value: reduce the value of the tensor on the mesh dimension
+    # 2. _reduce_shard_value: reduce_scatter the value of the tensor on the mesh dimension
+    # 3. _partition_value: partition the value of a replicated tensor on the mesh dimension
+    # We can implement custom reductions as needed by subclassing this
+    # class and overriding those contracts.
+    reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        return funcol.all_reduce(
+            tensor, reduceOp=self.reduce_op.name, group=(mesh, mesh_dim)
+        )
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # by default call reduce_shard_tensor of the shard_spec.
+        shard_spec = cast(Shard, shard_spec)
+        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # _partition_value is the conjugate operation of _reduce_value
+        # - i.e. _partition_value on a sum reduce op is just a division operation
+        # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation
+        # TODO: if the reduce_op is min/max, etc. the _partition_value should be a
+        # different operation
+        assert (
+            self.reduce_op == c10d.ReduceOp.SUM
+        ), "only support replicate to PartialSUM for now!"
+        num_chunks = mesh.size(mesh_dim=mesh_dim)
+        return tensor / num_chunks
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _Partial):
+            return False
+        return self.reduce_op == other.reduce_op
+
+    def __hash__(self) -> int:
+        return 1 + hash(self.reduce_op)
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the Partial placement
+        """
+        return f"_Partial(reduce_op={self.reduce_op})"
+
+    def __str__(self) -> str:
+        """
+        human readable representation of the Partial placement
+        """
+        return "P"
+
+
+class TensorMeta(NamedTuple):
+    # simple named tuple to represent tensor metadata
+    # intentionally to stay simple only for sharding
+    # propagation purposes.
+    shape: torch.Size
+    stride: Tuple[int, ...]
+    dtype: torch.dtype
+
+
+# used internally to propagate the placements
+@dataclass
+class DTensorSpec:
+    mesh: DeviceMesh
+    placements: Tuple[Placement, ...]
+
+    # tensor meta will only be set during sharding propagation
+    tensor_meta: Optional[TensorMeta] = None
+
+    def __post_init__(self):
+        if not isinstance(self.placements, tuple):
+            self.placements = tuple(self.placements)
+        self._hash: Optional[int] = None
+
+    def __setattr__(self, attr: str, value: Any):
+        super().__setattr__(attr, value)
+        # Make sure to recompute the hash in case any of the hashed attributes
+        # change (though we do not expect `mesh` or `placements` to change)
+        if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
+            self._hash = None
+
+    def _hash_impl(self) -> int:
+        # hashing and equality check for DTensorSpec are used to cache the sharding
+        # propagation results. We only need to consider the mesh, placements, shape
+        # dtype and stride.
+        # Caveat: we need to keep this in mind and sync hash and eq if we add more
+        # fields to them.
+        if self.tensor_meta is not None:
+            return hash(
+                (
+                    self.mesh,
+                    self.placements,
+                    self.tensor_meta.shape,
+                    self.tensor_meta.stride,
+                    self.tensor_meta.dtype,
+                )
+            )
+        return hash((self.mesh, self.placements))
+
+    def __hash__(self) -> int:
+        # We lazily cache the spec to avoid recomputing the hash upon each
+        # use, where we make sure to update the hash when the `tensor_meta`
+        # changes by overriding `__setattr__`. This must be lazy so that Dynamo
+        # does not try to hash non-singleton `SymInt`s for the stride.
+        if self._hash is None:
+            self._hash = self._hash_impl()
+        return self._hash
+
+    def __eq__(self, __o: object) -> bool:
+        if not (
+            isinstance(__o, DTensorSpec)
+            and self.mesh == __o.mesh
+            and self.placements == __o.placements
+        ):
+            return False
+        if self.tensor_meta is None or __o.tensor_meta is None:
+            return self.tensor_meta == __o.tensor_meta
+
+        return (
+            self.tensor_meta.shape == __o.tensor_meta.shape  # type: ignore[union-attr]
+            and self.tensor_meta.stride == __o.tensor_meta.stride  # type: ignore[union-attr]
+            and self.tensor_meta.dtype == __o.tensor_meta.dtype  # type: ignore[union-attr]
+        )
+
+    def __str__(self) -> str:
+        """
+        human readable representation of the DTensorSpec
+        """
+        if len(self.placements) == 1:
+            placement_str = str(self.placements[0])
+        else:
+            placement_str = str(self.placements)
+
+        if self.tensor_meta is not None:
+            tensor_shape = str(tuple(self.tensor_meta.shape))
+        else:
+            tensor_shape = "unknown shape"
+
+        return f"Spec({placement_str} on {tensor_shape})"
+
+    @property
+    def shape(self) -> torch.Size:
+        if self.tensor_meta is None:
+            raise ValueError("tensor_meta is not set")
+        return self.tensor_meta.shape
+
+    @property
+    def stride(self) -> Tuple[int, ...]:
+        if self.tensor_meta is None:
+            raise ValueError("tensor_meta is not set")
+        return self.tensor_meta.stride
+
+    @property
+    def ndim(self) -> int:
+        if self.tensor_meta is None:
+            raise ValueError("tensor_meta is not set")
+        return len(self.tensor_meta.shape)
+
+    @property
+    def num_shards(self) -> int:
+        num_shards = 1
+        for i, placement in enumerate(self.placements):
+            if placement.is_shard():
+                num_shards *= self.mesh.size(i)
+        return num_shards
+
+    @property
+    def device_mesh(self) -> DeviceMesh:
+        # simple aliasing for the mesh field, makes some
+        # checks that mix DTensor/DTensorSpec easier
+        return self.mesh
+
+    @property
+    def dim_map(self) -> List[int]:
+        """
+        dim_map is a property we derive from `placements` of
+        the distributed tensor. It simply returns a list of ints
+        where dim_map[i] denotes the sharding mapping to the mesh
+        dimension, and len(dim_map) == dist_tensor.ndim
+        dim_map[i] = -1: means tensor dim i replicate on mesh
+        dim_map[i] = j: means tensor dim i shard on mesh dim j
+
+        For example, given a dist tensor with shape [18, 20, 30], a
+        device_mesh([0, 1, 2, 3]), and placements [Shard(1)], the dim_map of
+        this placement would be [-1, 0, -1]. This representation is helpful
+        during sharding propagation because it tells us exactly whether each
+        tensor dimension is sharded or not.
+
+        Note that if placements contains `_Partial`, we have to
+        explicitly deal with it, so that when we create a DTensorSpec
+        with dim_map, we could properly record the pending sums.
+        """
+        # dims mapping of dist tensor sharding
+        # returns a list of length tensor ndim, where -1 represents replicate
+        # and an int >= 0 represents shard on that device mesh dim
+        r = [-1] * self.ndim
+        for i, placement in enumerate(self.placements):
+            if placement.is_shard():
+                shard_dim = cast(Shard, placement).dim
+                if r[shard_dim] > -1:
+                    raise ValueError(
+                        f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
+                        " DTensor operator implementation does not support things like hybrid"
+                        " sharding strategies yet (i.e. [Shard(0), Shard(0)])"
+                    )
+                r[shard_dim] = i
+        return r
+
+    @property
+    def sums(self) -> List[int]:
+        """
+        sums is a property we derive from `placements` of the
+        distributed tensor. It simply returns a list of ints where
+        sums[i] denotes the pending sum (partial) on mesh dim i
+        """
+        return [
+            idx
+            for idx, placement in enumerate(self.placements)
+            if placement.is_partial()
+        ]
+
+    @classmethod
+    def from_dim_map(
+        cls,
+        mesh: DeviceMesh,
+        dim_map: List[int],
+        sums: List[int],
+        tensor_meta: Optional[TensorMeta] = None,
+    ) -> "DTensorSpec":
+        """
+        Construct a DTensorSpec from dim_map list and pending sum.
+
+        Args:
+            mesh (:class:`DeviceMesh`): device mesh to be used in the DTensorSpec
+            dim_map (List[int]): a list of integers that represents the sharding on
+                each tensor dimension, see the `dim_map` property doc for details
+            sums (List[int]): a list of integers indicating the device mesh dimensions
+                on which the dist tensor has pending sums.
+            tensor_meta (TensorMeta): DTensor metadata
+
+        Returns:
+            a :class:`DTensorSpec` object
+        """
+        # by default replicate on device mesh dims
+        placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
+
+        # find all mesh dims that need pending reductions
+        for s in sums:
+            placements[s] = _Partial()
+
+        for i, m in enumerate(dim_map):
+            if m >= 0:
+                placement = placements[m]
+                if placement.is_shard():
+                    placement = cast(Shard, placement)
+                    raise RuntimeError(
+                        f"DeviceMesh dimension cannot be mapped to two dimensions of the same tensor: {i} and {placement.dim}"
+                    )
+                elif placement.is_partial():
+                    raise RuntimeError(
+                        f"DeviceMesh dimension {m} cannot be both shard and partial!"
+                    )
+                placements[m] = Shard(i)
+
+        return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
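+
+    # For illustration (hypothetical 2-D mesh): from_dim_map(mesh,
+    # dim_map=[-1, 0, -1], sums=[1]) produces placements
+    # (Shard(1), _Partial()), i.e. tensor dim 1 is sharded on mesh dim 0 and
+    # mesh dim 1 carries a pending sum.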
+
+    def is_replicated(self):
+        """
+        return True if the current DTensorSpec replicates on all mesh dims (devices)
+        """
+        return all(placement.is_replicate() for placement in self.placements)
+
+    def shallow_copy_with_tensor_meta(
+        self, tensor_meta: Optional[TensorMeta]
+    ) -> "DTensorSpec":
+        """
+        Shallow copy the DTensorSpec with a new tensor_meta.
+        """
+        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
+        return DTensorSpec(
+            self.mesh,
+            self.placements,
+            tensor_meta=tensor_meta,
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/random.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d0cf9ca708fe5ab026f8130ed8a396b1f2ed57
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/random.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import contextlib
+import warnings
+from typing import Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+
+from torch import Tensor
+from torch.distributed._tensor.placement_types import DTensorSpec, Shard
+from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
+
+
+_rng_tracker: Optional["RNGStateTracker"] = None
+
+
+def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
+    """Checks if the current device of `device_mesh` supports DTensor's random APIs.
+    Currently DTensor random APIs only support cuda/cuda-like devices. We suggest
+    users call this API to check availability before using our random APIs.
+
+    Args:
+        device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the
+            random ops APIs are supported.
+
+    Returns:
+        A bool value. True if `device_mesh` supports DTensor Random APIs; False otherwise.
+
+    .. warning::
+        Currently we only support correct RNG on cuda/cuda-like devices.
+    """
+    device_handle = _get_device_handle(device_mesh.device_type)
+    if device_handle and hasattr(device_handle, "set_rng_state"):
+        return True
+    else:
+        warnings.warn(
+            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh"
+        )
+        return False
+
+
+def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
+    """Sets the seed for generating random numbers for the calling rank.
+
+    Args:
+        seed (int): The desired seed.
+        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
+
+    Returns:
+        None
+
+    .. warning::
+        :func:`manual_seed` must be called from all ranks of the default `ProcessGroup`
+        with the same `seed` value, even if some ranks are not a part of the `device_mesh`.
+        If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
+        `manual_seed` will not set its GPU device's generator seed.
+        The current implementation only supports a GPU device mesh.
+    """
+    device_handle = _get_device_handle(device_mesh.device_type)
+    if not device_handle:
+        raise NotImplementedError(
+            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
+        )
+
+    # allgather the seed over the default PG
+    object_list = [seed] * dist.get_world_size()
+    dist.all_gather_object(object_list, seed)
+    for rank, object in enumerate(object_list):
+        if seed != int(object):
+            raise RuntimeError(
+                f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
+                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!",
+            )
+    # instantiate an RNG tracker if we haven't already. By default DTensor uses an
+    # OffsetBasedRNGTracker to perform random operators.
+    global _rng_tracker
+    if not _rng_tracker:
+        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
+
+    # the current rank is in mesh
+    if device_mesh.get_coordinate() is not None:
+        if isinstance(_rng_tracker, TensorParallelRNGTracker):
+            _rng_tracker._manual_seed(device_mesh, seed)
+        elif isinstance(_rng_tracker, OffsetBasedRNGTracker):
+            _rng_tracker._manual_seed(seed)
+        else:
+            raise RuntimeError(
+                f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}"
+            )
+
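+# Usage sketch (illustrative): every rank is expected to call manual_seed with
+# the same seed over the same mesh, e.g.
+#
+#   mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
+#   manual_seed(42, mesh)
+#
+# The all_gather_object check above raises if ranks disagree on the seed value.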
+
+class RNGStateTracker:
+    """
+    RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object)
+    in a dict, mapping from a corresponding tag to each state tensor. It also provides
+    a set of convenient utility methods to help access/modify the state tensors. The most
+    important interface is _distribute_region which will be used when DTensor executes
+    a random op (an operator that calls RNG).
+    """
+
+    def __init__(self, device_type: str = "cuda"):
+        self._device_type = device_type
+        self._device_handle = _get_device_handle(device_type)
+        if not (self._device_handle and self._device_handle.is_available()):
+            raise RuntimeError(
+                f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
+            )
+
+        self._states: Dict[str, Tensor] = {}
+        self._devices = [self._device_handle.current_device()]
+        self._use_distribute_region = True
+
+    @property
+    def rng_states(self) -> Dict[str, Tensor]:
+        return self._states
+
+    @property
+    def distribute_region_enabled(self) -> bool:
+        return self._use_distribute_region
+
+    @distribute_region_enabled.setter
+    def distribute_region_enabled(self, value) -> None:
+        self._use_distribute_region = value
+
+    def rng_state_is_sync(self, name) -> bool:
+        return name in self.rng_states
+
+    def get_seed(self, name: str) -> int:
+        if name not in self.rng_states:
+            raise RuntimeError(
+                f"{self.__class__.__name__} does not have random state for {name}"
+            )
+
+        seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64)
+        return int(seed_tensor.item())
+
+    def set_seed(self, name: str, seed: int) -> None:
+        seed_tensor = torch.tensor([seed]).view(torch.uint8)
+        offset_tensor = torch.tensor([0]).view(torch.uint8)
+        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
+
+    def _distribute_region(self, spec: DTensorSpec):
+        pass
+
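+# State layout note (illustrative): the trackers below treat each entry in
+# RNGStateTracker.rng_states as a 16-byte uint8 tensor, bytes [0:8] holding the
+# seed and bytes [8:16] holding the offset, each read back via .view(torch.int64):
+#
+#   state = torch.cat(
+#       [torch.tensor([42]).view(torch.uint8), torch.tensor([0]).view(torch.uint8)]
+#   )
+#   assert int(state[0:8].view(torch.int64).item()) == 42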
+
+class OffsetBasedRNGTracker(RNGStateTracker):
+    """
+    This subclass of `RNGStateTracker` defines the default policy of how RNG states
+    should be shared and synchronized among all ranks to respect the semantics of DTensor
+    random operators.
+    """
+
+    def __init__(self, device_type: str = "cuda"):
+        super().__init__(device_type)
+        # synchronize RNG state using rank 0's current one
+        rng_state = self._device_handle.get_rng_state().to(device_type)
+        dist.broadcast(rng_state, 0)
+        self.rng_states["parallel-rng"] = rng_state.to("cpu")
+
+    def _manual_seed(self, parallel_seed: int) -> None:
+        self.set_seed("parallel-rng", parallel_seed)
+
+    @contextlib.contextmanager
+    def _distribute_region(self, spec: DTensorSpec):
+        # check if the parallel rng state has been synchronized or not
+        if not self.rng_state_is_sync("parallel-rng"):
+            raise RuntimeError(
+                "OffsetBasedRNGTracker requires the random state to be synchronized "
+                "before entering into a distribute region!"
+            )
+
+        if self.distribute_region_enabled:
+            old_offset = self.get_offset("parallel-rng")
+            self._set_pre_op_offset(spec)
+            with torch.random.fork_rng(self._devices, device_type=self._device_type):
+                self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
+                try:
+                    yield  # execute the region code
+                finally:
+                    # update offset to synchronize among ranks
+                    self._set_post_op_offset(spec, old_offset)
+        else:
+            yield
+
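+    # Usage sketch (illustrative, not the literal dispatcher wiring): a DTensor
+    # random op runs its local computation inside this region so that every rank
+    # draws from the same seed at a shard-specific offset, e.g.
+    #
+    #   with _rng_tracker._distribute_region(dtensor_spec):
+    #       local_tensor = torch.rand(local_shape, device="cuda")
+    #
+    # `dtensor_spec` and `local_shape` are placeholders for illustration.
+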
+    def get_offset(self, name: str) -> int:
+        if name not in self.rng_states:
+            raise RuntimeError(
+                f"{self.__class__.__name__} does not have random state for {name}"
+            )
+
+        offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64)
+        return int(offset_tensor.item())
+
+    def set_offset(self, name: str, offset: int) -> None:
+        if name not in self.rng_states:
+            raise RuntimeError(
+                f"{self.__class__.__name__} does not have random state for {name}"
+            )
+
+        seed_tensor = (self.rng_states[name])[0:8]
+        offset_tensor = torch.tensor([offset]).view(torch.uint8)
+        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
+
+    def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
+        """Set the starting RNG offset for current device's local shard before actual
+        op execution. The pre_op_offset value should start from the current RNG offset
+        and increment by the size of local shard until it reaches the size of the whole
+        DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset
+        will be the same.
+
+        Args:
+            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
+                we prepare the offset for running random ops.
+
+        Returns:
+            None
+
+        .. warning::
+            Note that the current implementation does not consider DTensor's contiguity.
+
+        Example:
+            take a DTensor of shape [8, 16] as an example. Assume that the DTensor
+            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
+            and the mesh is:
+                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+            ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank
+            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).
+
+            Another concept to introduce besides rank coordinate is shard coordinate.
+            Each rank holds a local shard of the DTensor. In the example, the DTensor
+            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas: rank 0
+            (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) each hold one of them.
+            In other words, the local shards on rank 0 and rank 2 correspond to the same
+            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
+            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
+            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).
+
+            Once we have rank coordinate and shard coordinate, we can calculate on each rank
+            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
+            of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
+            (x, y, z) is simply (z, x), obtained by taking (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
+            Following this calculation,
+            rank 0 and rank 2 hold the shard of coord (0, 0);
+            rank 1 and rank 3 hold the shard of coord (0, 1);
+            rank 4 and rank 6 hold the shard of coord (1, 0);
+            rank 5 and rank 7 hold the shard of coord (1, 1).
+
+            The last value to calculate before obtaining the starting offset is the shard linear index.
+            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
+        """
+        dtensor_shape = spec.shape
+        mesh = spec.mesh
+        dim_map = spec.dim_map
+
+        # Compute shard coordinate:
+        # The coordinate on each tensor dim is a tuple (idx, range)
+        # If a DTensor is partitioned on its dim i into n shards, and the current rank
+        # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
+        coordinate = mesh.get_coordinate()
+        assert coordinate is not None
+        shard_coord = [
+            coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map
+        ]
+        shard_size = [
+            mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map
+        ]
+
+        # compute shard linear index
+        shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size)
+
+        # compute starting offset using the first shard's size
+        local_size_on_rank_0 = list(dtensor_shape)
+        for idx, placement in enumerate(spec.placements):
+            if isinstance(placement, Shard):
+                mesh_dim_size = mesh.size(idx)
+                shard_dim = placement.dim
+                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
+                    dtensor_shape[shard_dim],
+                    mesh_dim_size,
+                    0,
+                    return_offset=False,
+                )[0]
+
+        from torch.distributed._tensor.ops.utils import prod
+
+        local_size = prod(local_size_on_rank_0)
+
+        # get current RNG offset
+        current_offset = self.get_offset("parallel-rng")
+
+        # pytorch: offset must be multiple of 4
+        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
+        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
+        self.set_offset("parallel-rng", current_offset + offset_incr)
+
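+    # Worked example (illustrative, following the docstring above): for the [8, 16]
+    # DTensor with placements [Shard(1), Replicate(), Shard(0)] on a 2x2x2 mesh,
+    # rank 5 has mesh coordinate (1, 0, 1) and dim_map = [2, 0], so
+    #
+    #   shard_coord = [1, 1], shard_size = [2, 2]  ->  shard_linear_idx = 3
+    #   local_size = 4 * 8 = 32
+    #   offset_incr = (3 * 32 + 3) // 4 * 4 = 96
+    #
+    # i.e. rank 5 advances its RNG offset by 96 before running the random op.
+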
+    def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
+        """Sets the RNG to a synchronized state after running the local random op. Every
+        rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is
+        the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor
+        random ops.
+
+        Args:
+            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
+                we post-process the offset for running random ops.
+
+        Returns:
+            None
+        """
+        dtensor_shape = spec.shape
+
+        from torch.distributed._tensor.ops.utils import prod
+
+        numel = prod(dtensor_shape)
+        # pytorch: offset must be multiple of 4
+        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
+        numel = (numel + 3) // 4 * 4
+        self.set_offset("parallel-rng", old_offset + numel)
+
+    def _calc_shard_linear_idx(
+        self, shard_coord: List[int], shard_size: List[int]
+    ) -> int:
+        # compute shard linear index
+        shard_linear_idx = 0
+        shard_coord_stride = 1
+        for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
+            shard_linear_idx += idx * shard_coord_stride
+            shard_coord_stride *= size
+
+        return shard_linear_idx
+
+
+class TensorParallelRNGTracker(RNGStateTracker):
+    def __init__(self, device_type: str = "cuda"):
+        super().__init__(device_type)
+        # copy the default RNG state
+        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()
+
+    def _manual_seed(
+        self,
+        tp_mesh: DeviceMesh,
+        base_seed: int = 1234,
+    ):
+        tensor_parallel_rank = tp_mesh.get_local_rank()
+        # this magic number 2718 comes from Megatron's code
+        # (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163)
+        MegatronMagicNum = 2718
+        tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank
+        self.set_seed("tensor-parallel-rng", tensor_parallel_seed)
+
+    @contextlib.contextmanager
+    def _distribute_region(self, spec: DTensorSpec):
+        # check if the tensor parallel rng state has been synchronized or not
+        if not self.rng_state_is_sync("tensor-parallel-rng"):
+            raise RuntimeError(
+                "TensorParallelRNGTracker requires the random state to be synchronized "
+                "before entering into a distribute region!"
+            )
+
+        if self.distribute_region_enabled:
+            with torch.random.fork_rng(self._devices, device_type=self._device_type):
+                self._device_handle.set_rng_state(
+                    self.rng_states["tensor-parallel-rng"]
+                )
+                try:
+                    yield
+                finally:
+                    self.rng_states[
+                        "tensor-parallel-rng"
+                    ] = self._device_handle.get_rng_state()
+        else:
+            yield
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/redistribute.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/redistribute.py
new file mode 100644
index 0000000000000000000000000000000000000000..021c0adeac5c100de121f5ded51316e898c1d2aa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/redistribute.py
@@ -0,0 +1,337 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from functools import lru_cache
+from typing import cast, Dict, List, NamedTuple, Tuple
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed._tensor.api as dtensor
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+)
+
+
+class _TransformInfo(NamedTuple):
+    mesh_dim: int
+    src_dst_placements: Tuple[Placement, Placement]
+    # logical_shape on this mesh dimension
+    logical_shape: List[int]
+
+
+def _replicate_then_shard(val: _TransformInfo) -> int:
+    """
+    This is a helper function used as a sort key to reorder the _TransformInfo list. The
+    high level idea is that we want to reorder the sharding redistributions so that the
+    DTensor redistribution is consistent with its full tensor. This is built on top of
+    two simple assumptions:
+    1. Replication happens from inner to outer dimension, i.e. Shard -> Replicate
+    2. Sharding happens from outer to inner dimension, i.e. Replicate -> Shard
+
+    So we always put the replication first and the sharding later.
+    """
+    mesh_dim = val.mesh_dim
+    src, dst = val.src_dst_placements
+    if (dst.is_replicate() or dst.is_partial()) and src.is_shard():
+        return -mesh_dim
+    elif (src.is_replicate() or src.is_partial()) and dst.is_shard():
+        return mesh_dim
+    else:
+        return 0
+
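+# Sort-key example (illustrative): for the two transforms
+#   _TransformInfo(mesh_dim=0, src_dst_placements=(Replicate(), Shard(0)), logical_shape=[8])  # key = 0
+#   _TransformInfo(mesh_dim=1, src_dst_placements=(Shard(0), Replicate()), logical_shape=[8])  # key = -1
+# sorting with this key runs the Shard -> Replicate step on mesh dim 1 before the
+# Replicate -> Shard step on mesh dim 0, i.e. replication first, sharding later.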
+
+@lru_cache(maxsize=None)
+def _gen_transform_infos(
+    src_spec: DTensorSpec,
+    dst_spec: DTensorSpec,
+) -> List[_TransformInfo]:
+    """
+    Generate the transform infos from the source placements to the target placements.
+    Transforming from source to target placement may take multiple steps, i.e. it may
+    decompose Si -> Sj into Si -> R -> Sj.
+    This also detects whether there are mis-aligned shardings between src/dst placements,
+    e.g. (Shard(0), Shard(0)) -> (Replicate(), Shard(0)): here Shard(0) -> Shard(0) on
+    mesh dimension 1 actually needs a reshard, because in the first case it is a
+    sub-sharding of an already sharded tensor dimension 0, while in the second case it
+    is the first sharding on tensor dimension 0.
+
+    Note that we also currently handle sharding on different tensor dimensions, e.g.
+    Shard(0) -> Shard(1), in this pass.
+    """
+    src_dim_counts: Dict[int, int] = {}
+    dst_dim_counts: Dict[int, int] = {}
+    transform_infos: List[_TransformInfo] = []
+
+    src_placements = src_spec.placements
+    dst_placements = dst_spec.placements
+    device_mesh = src_spec.device_mesh
+    my_coordinate = device_mesh.get_coordinate()
+    assert my_coordinate is not None
+
+    # logical shape records the logical tensor shape on the mesh dimension
+    # this is useful to ensure uneven sharding gets correct output shape
+    initial_logical_shape = list(src_spec.shape)
+    mesh_dims_to_logical_shape = [initial_logical_shape]
+    mesh_ndim = len(src_placements)
+
+    for i, (src, dst) in enumerate(zip(src_placements, dst_placements)):
+        # detect mis-aligned sharding and build logical shapes
+        current_logical_shape = mesh_dims_to_logical_shape[i]
+        if isinstance(src, Shard):
+            src_dim_counts[src.dim] = src_dim_counts.get(src.dim, 0) + 1
+
+            if i < mesh_ndim - 1:
+                # calculate and save the logical shape for this sharding
+                mesh_dim_size = device_mesh.size(mesh_dim=i)
+                local_shard_size, _ = src._local_shard_size_on_dim(
+                    current_logical_shape[src.dim],
+                    mesh_dim_size,
+                    my_coordinate[i],
+                )
+                new_logical_shape = list(current_logical_shape)
+                new_logical_shape[src.dim] = local_shard_size
+                mesh_dims_to_logical_shape.append(new_logical_shape)
+        else:
+            mesh_dims_to_logical_shape.append(current_logical_shape)
+
+        if isinstance(dst, Shard):
+            dst_dim_counts[dst.dim] = dst_dim_counts.get(dst.dim, 0) + 1
+
+        if (
+            isinstance(src, Shard)
+            and isinstance(dst, Shard)
+            and (
+                src.dim != dst.dim or src_dim_counts[src.dim] != dst_dim_counts[dst.dim]
+            )
+        ):
+            # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j)
+            transform_infos.append(
+                _TransformInfo(
+                    mesh_dim=i,
+                    src_dst_placements=(src, Replicate()),
+                    logical_shape=mesh_dims_to_logical_shape[i],
+                )
+            )
+            transform_infos.append(
+                _TransformInfo(
+                    mesh_dim=i,
+                    src_dst_placements=(Replicate(), dst),
+                    logical_shape=mesh_dims_to_logical_shape[i],
+                )
+            )
+        else:
+            transform_infos.append(
+                _TransformInfo(
+                    mesh_dim=i,
+                    src_dst_placements=(src, dst),
+                    logical_shape=mesh_dims_to_logical_shape[i],
+                )
+            )
+
+    # sort the pairs by first perform replication then sharding
+    transform_infos.sort(key=_replicate_then_shard)
+    return transform_infos
+
+
+def redistribute_local_tensor(
+    local_tensor: torch.Tensor,
+    current_spec: DTensorSpec,
+    target_spec: DTensorSpec,
+    *,
+    async_op: bool = False,
+    is_backward: bool = False,
+) -> torch.Tensor:
+    """
+    This redistributes the local tensor (torch.Tensor) from the current DTensorSpec to
+    the target DTensorSpec, which involves the necessary collective calls to transform
+    the local shard of the DTensor from its current spec to the target spec.
+    """
+
+    if current_spec.mesh != target_spec.mesh:
+        # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
+        raise NotImplementedError("Cross device mesh comm not supported yet!")
+
+    new_local_tensor = None
+    device_mesh = current_spec.mesh
+
+    my_coordinate = device_mesh.get_coordinate()
+
+    if my_coordinate is None:
+        # if rank is not part of mesh, we skip redistribute and simply return local_tensor,
+        # which should be an empty tensor
+        return local_tensor
+
+    transform_infos = _gen_transform_infos(current_spec, target_spec)
+
+    for transform_info in transform_infos:
+        i = transform_info.mesh_dim
+        current, target = transform_info.src_dst_placements
+        num_chunks = device_mesh.size(mesh_dim=i)
+
+        if current == target:
+            # short cut, just use the original local tensor
+            new_local_tensor = local_tensor
+            continue
+
+        if target.is_replicate():
+            # Case 1: target is Replicate
+            if current.is_partial():
+                partial_spec = cast(_Partial, current)
+                new_local_tensor = partial_spec._reduce_value(
+                    local_tensor, device_mesh, i
+                )
+            elif current.is_shard():
+                current_placement = cast(Shard, current)
+                new_local_tensor = current_placement._to_replicate_tensor(
+                    local_tensor, device_mesh, i, transform_info.logical_shape
+                )
+            else:
+                raise RuntimeError(
+                    f"redistribute from {current} to {target} not supported yet"
+                )
+        elif target.is_shard():
+            # Case 2: target is Shard
+            target_placement = cast(Shard, target)
+            target_dim = target_placement.dim
+            if current.is_partial():
+                partial_spec = cast(_Partial, current)
+                new_local_tensor = partial_spec._reduce_shard_value(
+                    local_tensor, device_mesh, i, target_placement
+                )
+            elif current.is_replicate():
+                # split the tensor and return the corresponding cloned local shard
+                new_local_tensor = target_placement._replicate_to_shard(
+                    local_tensor, device_mesh, i, my_coordinate[i]
+                )
+            else:
+                # NOTE: we don't support this case efficiently yet, the fallback path we take here is
+                # to decompose Shard(0) -> Shard(1) into Shard(0) -> Replicate -> Shard(1)
+                # TODO: enable this with all_to_all
+                assert (
+                    current.is_shard()
+                ), f"Current placement should be shard but found {current}"
+                shard_spec = cast(Shard, current)
+                if shard_spec.dim != target_placement.dim:
+                    new_local_tensor = shard_spec._to_replicate_tensor(
+                        local_tensor, device_mesh, i, transform_info.logical_shape
+                    )
+                    shards, _ = target_placement._split_tensor(
+                        new_local_tensor,
+                        num_chunks,
+                        with_padding=False,
+                        contiguous=False,
+                    )
+                    new_local_tensor = shards[my_coordinate[i]]
+        elif target.is_partial():
+            if current.is_replicate():
+                partial_spec = cast(_Partial, target)
+                # Skip the replicate-to-partial transformation when we are in the backward pass.
+                # In this case we keep the grad as replicate, because we don't want to convert
+                # the replicated gradients back to partial: although that logically conforms to
+                # the same layout, converting the gradients back to partial is useless since you
+                # would have to do a reduce later, which is more expensive than keeping them
+                # replicated. For this reason, we keep the replicated grad here.
+                new_local_tensor = (
+                    partial_spec._partition_value(local_tensor, device_mesh, i)
+                    if not is_backward
+                    else local_tensor
+                )
+            elif current.is_shard():
+                if not is_backward:
+                    raise RuntimeError(
+                        f"redistribute from {current} to {target} not supported yet"
+                    )
+                # for backward shard -> partial, we just need to convert the shard to replicate
+                current_placement = cast(Shard, current)
+                new_local_tensor = current_placement._to_replicate_tensor(
+                    local_tensor, device_mesh, i, transform_info.logical_shape
+                )
+            else:
+                # partial -> partial no op, should never hit
+                new_local_tensor = local_tensor
+
+        assert new_local_tensor is not None
+        local_tensor = new_local_tensor
+
+    assert new_local_tensor is not None, "redistribute failed!"
+
+    if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
+        new_local_tensor = new_local_tensor.wait()
+
+    return new_local_tensor
+
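+# Usage sketch (illustrative; `mesh`, `meta` and `local_shard` are placeholders):
+#
+#   current = DTensorSpec(mesh, (Shard(0),), tensor_meta=meta)
+#   target = DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)
+#   full_tensor = redistribute_local_tensor(local_shard, current, target)
+#
+# For this Shard(0) -> Replicate() transform, the single generated transform step
+# results in an all-gather performed by Shard._to_replicate_tensor on mesh dim 0.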
+
+class Redistribute(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        input: "dtensor.DTensor",
+        device_mesh: DeviceMesh,
+        placements: Tuple[Placement, ...],
+        async_op: bool = False,
+    ):
+        current_spec = input._spec
+        ctx.current_spec = current_spec
+        ctx.async_op = async_op
+        target_spec = DTensorSpec(
+            device_mesh, placements, tensor_meta=input._spec.tensor_meta
+        )
+
+        local_tensor = input._local_tensor
+        output = redistribute_local_tensor(
+            local_tensor, current_spec, target_spec, async_op=async_op
+        )
+
+        return dtensor.DTensor(
+            output,
+            device_mesh,
+            target_spec.placements,
+            shape=input.shape,
+            dtype=input.dtype,
+            requires_grad=input.requires_grad,
+            stride=input.stride(),
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
+        previous_spec = ctx.current_spec
+        current_spec = grad_output._spec
+        async_op = ctx.async_op
+
+        local_tensor = grad_output._local_tensor
+        output = redistribute_local_tensor(
+            local_tensor,
+            current_spec,
+            previous_spec,
+            async_op=async_op,
+            is_backward=True,
+        )
+        # normalize the target placement to replicate if it is partial
+        normalized_placements: List[Placement] = []
+        for previous_placement in previous_spec.placements:
+            if previous_placement.is_partial():
+                # keep the target placement as replicate instead of partial in this case
+                normalized_placements.append(Replicate())
+            else:
+                normalized_placements.append(previous_placement)
+        output_dtensor = dtensor.DTensor(
+            output,
+            previous_spec.mesh,
+            tuple(normalized_placements),
+            shape=grad_output.shape,
+            dtype=grad_output.dtype,
+            requires_grad=grad_output.requires_grad,
+            stride=grad_output.stride(),
+        )
+
+        return (
+            output_dtensor,
+            None,
+            None,
+            None,
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/sharding_prop.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/sharding_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..a106c1aeb73f4005688d9e157a43c55303a37b36
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/sharding_prop.py
@@ -0,0 +1,410 @@
+from functools import lru_cache
+from itertools import chain
+from typing import Callable, cast, Dict, List, Optional, Sequence, Union
+
+import torch
+from torch._ops import OpOverload
+from torch._subclasses import FakeTensorMode
+from torch.distributed._tensor._utils import try_find_mesh_from_args
+from torch.distributed._tensor.op_schema import (
+    DTensorSpec,
+    OpInfo,
+    OpSchema,
+    OpStrategy,
+    OutputSharding,
+    OutputSpecType,
+    PlacementStrategy,
+    RuntimeSchemaInfo,
+    StrategyType,
+    TupleStrategy,
+)
+from torch.distributed._tensor.placement_types import TensorMeta
+from torch.distributed.device_mesh import DeviceMesh
+
+aten = torch.ops.aten
+
+
+def _length(obj) -> int:
+    if obj is None:
+        return 0
+    if not isinstance(obj, Sequence):
+        return 1
+    return len(obj)
+
+
+class ShardingPropagator:
+    def __init__(self) -> None:
+        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
+        self.op_strategy_funcs: Dict[
+            OpOverload,
+            Callable[[DeviceMesh, OpSchema], StrategyType],
+        ] = {}
+        # op map that saves the static argnums used to decide whether to reuse the
+        # sharding prop cache or re-run sharding prop
+        self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
+        self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached)  # type: ignore[method-assign]
+
+    def register_sharding_prop_rule(
+        self,
+        op_overload: OpOverload,
+        rule_func: Callable[[OpSchema], OutputSharding],
+        schema_info: Optional[RuntimeSchemaInfo] = None,
+    ):
+        """
+        Register a sharding propagation rule for an operator.
+        """
+        self.op_to_rules[op_overload] = rule_func
+        if schema_info is not None:
+            self.op_to_schema_info[op_overload] = schema_info
+
+    def register_op_strategy(
+        self,
+        op_overload: OpOverload,
+        strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
+        schema_info: Optional[RuntimeSchemaInfo] = None,
+    ):
+        """
+        Register a sharding strategy generator for an operator.
+        """
+        self.op_strategy_funcs[op_overload] = strategy_func
+        if schema_info is not None:
+            self.op_to_schema_info[op_overload] = schema_info
+
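+    # Registration sketch (illustrative; `propagator` and `my_strategy` are
+    # hypothetical names): a strategy function receives the mesh and an OpSchema
+    # whose tensor args have been wrapped into OpStrategy objects, e.g.
+    #
+    #   def my_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    #       arg = cast(OpStrategy, op_schema.args_schema[0])
+    #       return OpStrategy([PlacementStrategy(arg.strategies[0].output_spec)])
+    #
+    #   propagator.register_op_strategy(aten.clone.default, my_strategy)
+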
+    @lru_cache
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
+        """
+        Propagate the tensor metadata. It could return either a TensorMeta
+        or a list/tuple of TensorMetas.
+        """
+        if op_schema.op == aten.equal.default:
+            # data dependent ops can't be used for fake propagation
+            return None
+
+        # NOTE: We must call the tracing in fake tensor mode so that it
+        # avoids materializing memory
+        with FakeTensorMode():
+            fake_args = op_schema.gen_fake_args()
+            fake_kwargs = op_schema.gen_fake_kwargs()
+            fake_out = op_schema.op(*fake_args, **fake_kwargs)
+
+        if isinstance(fake_out, torch.Tensor):
+            return TensorMeta(
+                shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype
+            )
+
+        elif isinstance(fake_out, (tuple, list)):
+            tensor_meta_list: List[Optional[TensorMeta]] = []
+            for fake_out_item in fake_out:
+                if isinstance(fake_out_item, torch.Tensor):
+                    tensor_meta_list.append(
+                        TensorMeta(
+                            shape=fake_out_item.shape,
+                            stride=fake_out_item.stride(),
+                            dtype=fake_out_item.dtype,
+                        )
+                    )
+                else:
+                    tensor_meta_list.append(None)
+            return (
+                tuple(tensor_meta_list)
+                if isinstance(fake_out, tuple)
+                else tensor_meta_list
+            )
+        else:
+            # if the fake output is not a tensor or a tuple/list of tensors, return None
+            return None
+
+    def _wrap_output_spec_tensor_meta(
+        self,
+        op: OpOverload,
+        output_specs: OutputSpecType,
+        output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
+    ) -> None:
+        """
+        Wrap the output_specs with the tensor metadata from the output.
+        """
+
+        if isinstance(output_specs, DTensorSpec):
+            if not isinstance(output_tensor_meta, TensorMeta):
+                # Either error due to ShardingPropagator or due to incorrect OutputSpec
+                if not isinstance(output_tensor_meta, (tuple, list)):
+                    raise ValueError(
+                        "ShardingPropagator error: output does not have an associated TensorMeta"
+                    )
+                raise ValueError(
+                    f"For the op {op.name()}, `output_specs` has 1 output which does not equal the "
+                    f"number of op outputs: {len(output_tensor_meta)}."
+                )
+            output_specs.tensor_meta = output_tensor_meta
+        elif isinstance(output_specs, (tuple, list)):
+            if not isinstance(output_tensor_meta, (tuple, list)) or len(
+                output_specs
+            ) != len(output_tensor_meta):
+                raise ValueError(
+                    f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the "
+                    f"number of op outputs {_length(output_tensor_meta)}."
+                )
+            for i, spec in enumerate(output_specs):
+                if isinstance(spec, DTensorSpec):
+                    output_tensor_meta_i = output_tensor_meta[i]
+                    if not isinstance(output_tensor_meta_i, TensorMeta):
+                        raise ValueError(
+                            f"ShardingPropagator error: output {i} does not have an associated TensorMeta"
+                        )
+                    spec.tensor_meta = output_tensor_meta_i
+
+    def propagate(self, op_info: OpInfo) -> None:
+        # We cannot use an lru cache if we know that inputs will have dynamic shapes,
+        # because SymInts are not hashable.
+        # This is generally ok because this only happens during tracing in torch.compile,
+        # and tracing does not need to be as fast as eagermode DTensor usages.
+        if op_info.schema.has_symints:
+            output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
+        else:
+            output_sharding = self.propagate_op_sharding(op_info.schema)
+        op_info.output_sharding = output_sharding
+
+    def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
+        """
+        Propagate the sharding for an operator given the op_schema.
+        """
+        # special case op, we don't need to propagate for local
+        # scalar. TODO: figure out a better way to handle this
+        if op_schema.op is aten._local_scalar_dense.default:
+            return OutputSharding(None, [op_schema])
+
+        out_tensor_meta = self._propagate_tensor_meta(op_schema)
+
+        def spec_to_strategy(spec: object) -> object:
+            if isinstance(spec, DTensorSpec):
+                return OpStrategy([PlacementStrategy(spec)])
+            elif (
+                isinstance(spec, (list, tuple))
+                and len(spec) > 0
+                and isinstance(spec[0], DTensorSpec)
+            ):
+                # a tensor list creates a tuple strategy
+                tuple_strategy = [spec_to_strategy(s) for s in spec]
+                tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
+                return TupleStrategy(
+                    tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
+                )
+            else:
+                return spec
+
+        if op_schema.op in self.op_strategy_funcs:
+            # generate op strategy for the op.
+            mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
+            # swap the args spec with args strategies
+            args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]
+
+            kwargs_op_strategy = {
+                k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
+            }
+
+            # construct a new OpSchema on args for strategy based propagation
+            strategy_schema: OpSchema = OpSchema(
+                op=op_schema.op,
+                args_schema=tuple(args_op_strategy),
+                kwargs_schema=kwargs_op_strategy,
+            )
+
+            op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)
+
+            if isinstance(op_strategy, OpStrategy):
+                # single Op strategy
+                output_strategy = self._select_strategy(op_strategy)
+
+                # check if we need to redistribute the input
+                needs_redistribute = False
+                expected_input_specs = []
+
+                # in the case where the op does not specify input_specs and output_specs
+                # is a DTensorSpec, we use output_specs as the spec for each DTensor
+                # input arg.
+                if output_strategy.input_specs is None:
+                    assert isinstance(output_strategy.output_specs, DTensorSpec)
+
+                for idx, input_spec in enumerate(op_schema.args_spec):
+                    desired_spec = (
+                        output_strategy.output_spec
+                        if output_strategy.input_specs is None
+                        else output_strategy.input_specs[idx]
+                    )
+                    expected_input_specs.append(desired_spec)
+                    if input_spec.placements != desired_spec.placements:
+                        needs_redistribute = True
+
+                suggestion_schema = None
+                if needs_redistribute:
+                    reshard_schema = OpSchema(
+                        op_schema.op, tuple(expected_input_specs), {}
+                    )
+                    reshard_schema._inplace_rewrap_schema_suggestion(op_schema)
+                    suggestion_schema = [reshard_schema]
+
+                # construct output spec for the op
+                if op_schema.return_type_tuple_tensor_like():
+                    # for ops that return multiple tensors and the output_specs is not
+                    # a tuple, we use a tuple of that single output spec as the new
+                    # output_specs
+                    output_specs: OutputSpecType = output_strategy.output_specs
+                    if isinstance(output_specs, DTensorSpec):
+                        output_specs = tuple(
+                            [
+                                # create a new DTensorSpec with the same placement as the
+                                # output_specs in output_strategy
+                                DTensorSpec(
+                                    mesh=output_specs.mesh,
+                                    placements=output_specs.placements,
+                                    tensor_meta=output_specs.tensor_meta,
+                                )
+                                for _ in range(len(op_schema.op._schema.returns))
+                            ]
+                        )
+                elif op_schema.return_type_tensor():
+                    output_specs = output_strategy.output_specs
+                else:
+                    output_specs = None
+
+                output_sharding = OutputSharding(
+                    output_specs,
+                    suggestion_schema,
+                    needs_redistribute=needs_redistribute,
+                )
+            elif isinstance(op_strategy, TupleStrategy):
+                # tuple strategy output sharding processing
+                # runtime selected placement strategy for each TupleStrategy input arg
+                selected_strategies: List[PlacementStrategy] = []
+                out_spec_list: List[DTensorSpec] = []
+                for strategy in op_strategy.childs:
+                    assert isinstance(strategy, OpStrategy)
+                    selected_strategy = self._select_strategy(strategy)
+                    selected_strategies.append(selected_strategy)
+                    out_spec_list.append(selected_strategy.output_spec)
+
+                needs_redistribute = False
+                suggestion_args: List[object] = []
+                for arg_idx, arg in enumerate(op_schema.args_schema):
+                    if isinstance(arg, (list, tuple)) and isinstance(
+                        arg[0], DTensorSpec
+                    ):
+                        expected_input_spec_list: List[DTensorSpec] = []
+                        for idx, arg_spec in enumerate(arg):
+                            expected_input_spec = selected_strategies[idx].input_spec(
+                                arg_idx
+                            )
+                            expected_input_spec = (
+                                expected_input_spec.shallow_copy_with_tensor_meta(
+                                    arg_spec.tensor_meta
+                                )
+                            )
+                            if arg_spec.placements != expected_input_spec.placements:
+                                needs_redistribute = True
+                            expected_input_spec_list.append(expected_input_spec)
+                        suggestion_args.append(
+                            tuple(expected_input_spec_list)
+                            if isinstance(arg, tuple)
+                            else expected_input_spec_list
+                        )
+                    elif isinstance(arg, DTensorSpec):
+                        expected_input_spec = selected_strategies[0].input_spec(arg_idx)
+                        expected_input_spec = (
+                            expected_input_spec.shallow_copy_with_tensor_meta(
+                                arg.tensor_meta
+                            )
+                        )
+                        if arg.placements != expected_input_spec.placements:
+                            needs_redistribute = True
+                        suggestion_args.append(expected_input_spec)
+                    else:
+                        suggestion_args.append(arg)
+
+                suggestion_schema = None
+                if needs_redistribute:
+                    reshard_schema = OpSchema(
+                        op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
+                    )
+                    suggestion_schema = [reshard_schema]
+
+                output_sharding = OutputSharding(
+                    tuple(out_spec_list) if out_tensor_meta is not None else None,
+                    suggestion_schema,
+                    needs_redistribute=needs_redistribute,
+                )
+            else:
+                raise ValueError("Unsupported op strategy type")
+
+            # associate the output sharding with the output tensor metadata
+            self._wrap_output_spec_tensor_meta(
+                op_schema.op, output_sharding.output_spec, out_tensor_meta
+            )
+            return output_sharding
+        elif op_schema.op in self.op_to_rules:
+            # propagate the sharding with rule
+            sharding_prop_func = self.op_to_rules[op_schema.op]
+
+            # step 1. there is a sharding propagation rule; run
+            # sharding propagation to get the output sharding
+            try:
+                output_sharding = sharding_prop_func(op_schema)
+            except NotImplementedError as e:
+                raise e
+            except Exception as e:
+                raise RuntimeError(
+                    f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
+                ) from e
+
+            # step 2. if we can't get output_spec from sharding
+            # propagation (i.e. no rules apply for input
+            # placements), we return the output sharding
+            # with schema suggestions, which can be used to
+            # decide how to do redistribute on inputs
+            if output_sharding.output_spec is None:
+                if output_sharding.schema_suggestions is None:
+                    if output_sharding.failed_reason is not None:
+                        raise RuntimeError(
+                            f"Sharding propagation failed on op {op_schema}!"
+                            f"Failed reason: {output_sharding.failed_reason}"
+                        )
+                else:
+                    # we do auto redistribute on inputs if necessary
+                    # to get an eligible input, for which we pick a
+                    # schema suggestion based on the redistribute cost.
+                    # For now we simply pick the first suggestion.
+                    suggested_input_schema = output_sharding.schema_suggestions[0]
+                    # run sharding propagation again with suggested schema
+                    propagation_res = sharding_prop_func(suggested_input_schema)
+                    # we set the output sharding with the new propagation result
+                    # so that dispatching know both output_spec and schema_suggestions
+                    # exist, which indicates a reshard is needed
+                    output_sharding.output_spec = propagation_res.output_spec
+                    output_sharding.needs_redistribute = True
+
+            # associate the output sharding with the output tensor metadata
+            self._wrap_output_spec_tensor_meta(
+                op_schema.op, output_sharding.output_spec, out_tensor_meta
+            )
+
+            return output_sharding
+        else:
+            raise NotImplementedError(
+                f"Operator {op_schema.op} does not have a sharding strategy registered."
+            )
+
+    def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
+        if len(strategy.strategies) == 1:
+            # short cut with only one possible strategy
+            return strategy.strategies[0]
+
+        strategy_costs: List[float] = []
+        for strtg in strategy.strategies:
+            assert (
+                strtg.redistribute_cost is not None
+            ), "must set redistribute cost each strategy!"
+            redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
+            strategy_costs.append(redistribute_cost)
+
+        # for eager execution, we just select the one with the minimal redistribute cost
+        return strategy.strategies[strategy_costs.index(min(strategy_costs))]
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tensor/tp_conv.py b/MLPY/Lib/site-packages/torch/distributed/_tensor/tp_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..230b1f2c0974e5f04710c996b7c6a3d6f14a85d7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tensor/tp_conv.py
@@ -0,0 +1,277 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# implement convolution related ops for distributed tensor
+from typing import cast, Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed._tensor.api as dtensor
+
+aten = torch.ops.aten
+
+
+def _requires_data_exchange(padding):
+    # TODO: whether data exchange is required is currently determined by padding
+    return padding[1] != 0
+
+
+def _is_supported(input_size, kernel_size, stride, padding, dilation):
+    if dilation[1] != 1:
+        raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
+    if padding[1] != 0:
+        if stride[1] != 1:
+            raise RuntimeError(
+                "Stride must be 1 when there is padding for tensor parallel convolution."
+            )
+        if kernel_size[3] // 2 > input_size[3]:
+            raise RuntimeError(
+                "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
+            )
+    else:
+        if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
+            raise RuntimeError(
+                "It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] "
+                "when there is padding for tensor parallel convolution."
+            )
+    return True
+
+
+def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
+    # dist comms and reconstruct local input tensor
+    send_to_right = in_tensor[:, :, :, -d1:].contiguous()
+    send_to_left = in_tensor[:, :, :, :d2].contiguous()
+    recv_from_right = torch.zeros_like(send_to_left)
+    recv_from_left = torch.zeros_like(send_to_right)
+
+    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
+    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
+    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
+    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
+
+    reqs = dist.batch_isend_irecv(
+        [send_op_right, send_op_left, recv_op_left, recv_op_right]
+    )
+    for req in reqs:
+        req.wait()
+
+    if rank == 0:
+        in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
+    elif rank == size - 1:
+        in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
+    else:
+        in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)
+
+    return in_tensor
+
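+# Halo-size example (illustrative): for a kernel of width k, tp_convolution splits
+# the overlap as
+#
+#   d = k - 1
+#   d1, d2 = d // 2, d - d1   # d1 columns go to the right neighbor, d2 to the left
+#
+# e.g. k = 3 gives d1 = d2 = 1, so interior ranks receive one extra column on each
+# side before running the local convolution.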
+
+def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
+    # dist comms and aggregate gradients for edge pixels
+    send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
+    send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
+    recv_from_right = torch.zeros_like(send_to_left)
+    recv_from_left = torch.zeros_like(send_to_right)
+
+    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
+    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
+    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
+    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
+
+    reqs = dist.batch_isend_irecv(
+        [send_op_right, send_op_left, recv_op_left, recv_op_right]
+    )
+    for req in reqs:
+        req.wait()
+
+    if rank == 0:
+        grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
+        grad_in_tensor[:, :, :, -d1:] = torch.add(
+            grad_in_tensor[:, :, :, -d1:], recv_from_right
+        )
+    elif rank == size - 1:
+        grad_in_tensor = grad_in_tensor[:, :, :, d1:]
+        grad_in_tensor[:, :, :, :d2] = torch.add(
+            grad_in_tensor[:, :, :, :d2], recv_from_left
+        )
+    else:
+        grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
+        grad_in_tensor[:, :, :, -d1:] = torch.add(
+            grad_in_tensor[:, :, :, -d1:], recv_from_right
+        )
+        grad_in_tensor[:, :, :, :d2] = torch.add(
+            grad_in_tensor[:, :, :, :d2], recv_from_left
+        )
+
+
+def tp_convolution(
+    op_call: torch._ops.OpOverload,
+    local_tensor_args: Tuple[object, ...],
+    local_tensor_kwargs: Dict[str, object],
+) -> object:
+    assert op_call == aten.convolution.default
+    assert len(local_tensor_args) == 9
+
+    rank = dist.get_rank()
+    size = dist.get_world_size()
+    in_tensor = cast(torch.Tensor, local_tensor_args[0])
+    weight = cast(torch.Tensor, local_tensor_args[1])
+    stride, padding, dilation = local_tensor_args[3:6]
+
+    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
+    assert isinstance(padding, List)
+
+    if not _requires_data_exchange(padding):
+        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+        return local_results
+    else:
+        # step 0 compute the overlap pixels of the input tensor
+        d = weight.shape[3] - 1
+        d1 = d // 2
+        d2 = d - d1
+        assert d1 + d2 == d
+        right = (rank + 1) % size
+        left = (rank - 1 + size) % size
+
+        # step1 reconstruct local input tensor
+        in_tensor = _ring_send_recv_construct(
+            in_tensor, d1, d2, left, right, rank, size
+        )
+
+        # step2 feed local input tensor to op_call
+        local_tensor_args_list = list(local_tensor_args)
+        local_tensor_args_list[0] = in_tensor
+        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
+        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+
+        # step3 remove extra outputs from the results
+        padding_w = padding[1]
+        w = local_results.size(3)
+        if rank == 0:
+            local_results = local_results[:, :, :, : w - padding_w]
+        elif rank == size - 1:
+            local_results = local_results[:, :, :, padding_w:]
+        else:
+            local_results = local_results[:, :, :, padding_w : w - padding_w]
+
+        return local_results
+
+
+def tp_convolution_backward(
+    op_call: torch._ops.OpOverload,
+    local_tensor_args: Tuple[object, ...],
+    local_tensor_kwargs: Dict[str, object],
+) -> object:
+    assert op_call == aten.convolution_backward.default
+    assert len(local_tensor_args) == 11
+
+    rank = dist.get_rank()
+    size = dist.get_world_size()
+    grad_out_tensor = cast(torch.Tensor, local_tensor_args[0])
+    in_tensor = cast(torch.Tensor, local_tensor_args[1])
+    weight = cast(torch.Tensor, local_tensor_args[2])
+    stride, padding, dilation = local_tensor_args[4:7]
+
+    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
+    assert isinstance(padding, List)
+
+    if not _requires_data_exchange(padding):
+        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+        return local_results
+    else:
+        # step 0 compute the overlap pixels of the input tensor
+        d = weight.shape[3] - 1
+        d1 = d // 2
+        d2 = d - d1
+        assert d1 + d2 == d
+        right = (rank + 1) % size
+        left = (rank - 1 + size) % size
+
+        # step1 reconstruct local input tensor
+        in_tensor = _ring_send_recv_construct(
+            in_tensor, d1, d2, left, right, rank, size
+        )
+
+        # step2 reconstruct local gradient output tensor
+        N, C_out, H_out, _ = grad_out_tensor.shape
+        padding_w = padding[1]
+        if rank == 0:
+            grad_out_tensor = torch.nn.functional.pad(
+                grad_out_tensor, (0, padding_w), "constant", 0
+            )
+        elif rank == size - 1:
+            grad_out_tensor = torch.nn.functional.pad(
+                grad_out_tensor, (padding_w, 0), "constant", 0
+            )
+        else:
+            grad_out_tensor = torch.nn.functional.pad(
+                grad_out_tensor, (padding_w, padding_w), "constant", 0
+            )
+
+        # step 3 feed local input tensor to op_call
+        local_tensor_args_list = list(local_tensor_args)
+        local_tensor_args_list[0] = grad_out_tensor
+        local_tensor_args_list[1] = in_tensor
+        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
+        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
+
+        # step 4 aggregate gradients for edge pixels
+        grad_in_tensor = local_results[0]
+        grad_in_tensor = _ring_send_recv_aggregate(
+            grad_in_tensor, d1, d2, left, right, rank, size
+        )
+
+        local_results = list(local_results)
+        local_results[0] = grad_in_tensor
+        local_results = cast(Tuple[object, ...], local_results)
+
+        return local_results
+
+
+def convolution_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    # extract local tensor and sharding info into an OpInfo
+    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
+
+    # sharding propagation
+    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
+    output_sharding = op_info.output_sharding
+    assert output_sharding is not None, "output sharding should not be None"
+
+    # local propagation
+    local_results = tp_convolution(
+        op_call, tuple(op_info.local_args), op_info.local_kwargs
+    )
+
+    return dtensor.DTensor._op_dispatcher.wrap(
+        local_results, output_sharding.output_spec
+    )
+
+
+def convolution_backward_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    # Redistribute grad_output tensor to the same placement as input tensor
+    args = list(args)
+    assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
+    args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
+    args = tuple(args)
+
+    # extract local tensor and sharding info into an OpInfo
+    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
+
+    # sharding propagation
+    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
+    output_sharding = op_info.output_sharding
+    assert output_sharding is not None, "output sharding should not be None"
+
+    # local propagation
+    local_results = tp_convolution_backward(
+        op_call, tuple(op_info.local_args), op_info.local_kwargs
+    )
+
+    return dtensor.DTensor._op_dispatcher.wrap(
+        local_results, output_sharding.output_spec
+    )
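+
+
+# Minimal sketch (an assumption, not part of this file): the two handlers above
+# are intended to be looked up by DTensor's op dispatcher for the convolution
+# ops before the default sharding path, e.g. via a table along the lines of
+#
+#     _CONV_HANDLERS = {
+#         aten.convolution.default: convolution_handler,
+#         aten.convolution_backward.default: convolution_backward_handler,
+#     }
+#
+# The table name and the exact registration mechanism are hypothetical here.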
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tools/__init__.py b/MLPY/Lib/site-packages/torch/distributed/_tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d1cf0c563abd9886a50a6ea78d257d4a4d26b03
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tools/__init__.py
@@ -0,0 +1 @@
+from .memory_tracker import MemoryTracker
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0569f337dbfbe171e2de064b919ace6aa8a2354
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..491dea1a476294087acadff886be05872fc72f1c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/_tools/memory_tracker.py b/MLPY/Lib/site-packages/torch/distributed/_tools/memory_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d736703e581fb182955acc86b75b838e0dbb8a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/_tools/memory_tracker.py
@@ -0,0 +1,299 @@
+from collections import defaultdict
+
+from itertools import chain
+
+import pickle
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    no_type_check,
+    Sequence,
+)
+
+import torch
+import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
+from torch.utils._python_dispatch import TorchDispatchMode
+
+
+BYTES_PER_MB = 1024 * 1024.0
+
+
+class MemoryProfileDispatchMode(TorchDispatchMode):
+    """Run in ``TorchDispatchMode`` to get memory stats at operator level."""
+
+    def __init__(self, memory_tracker) -> None:
+        self.memory_tracker = memory_tracker
+
+    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
+        rs = func(*args, **kwargs)
+        if func == torch.ops.aten.detach.default:
+            return rs
+        func_name: str = (
+            self.memory_tracker._cur_module_name
+            + "."
+            + func.__name__
+            + "_"
+            + str(self.memory_tracker._operator_names[func.__name__])
+        )
+        self.memory_tracker._operator_names[func.__name__] = (
+            self.memory_tracker._operator_names[func.__name__] + 1
+        )
+        self.memory_tracker._record_memory_stats(func_name)
+
+        return rs
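+
+    # Descriptive note (added): each recorded key has the form
+    # "<module name>.forward.<op name>_<call index>" (or ".backward." once the
+    # backward prefix is set), e.g. something like "layer1.forward.addmm_0".
+    # The exact op string depends on ``func.__name__`` and is only illustrative.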
+
+
+class MemoryTracker:
+    """
+    Collect and plot the memory stats at operator level.
+
+    Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``.
+    It also prints a summary for the top 20 operators that generate the most memory.
+
+    Example usage:
+
+        >>> # xdoctest: +SKIP(failing)
+        >>> net.cuda()
+        >>> input = input.cuda()
+
+        >>> mem_tracker = MemoryTracker()
+        >>> mem_tracker.start_monitor(net)
+
+        >>> net.zero_grad(True)
+        >>> loss = net(input)
+        >>> if isinstance(loss, dict):
+        >>>    loss = loss['out']
+        >>> loss.sum().backward()
+        >>> net.zero_grad(set_to_none=True)
+
+        >>> mem_tracker.stop()
+        >>> mem_tracker.summary()
+        >>> mem_tracker.show_traces()
+    """
+
+    def __init__(self) -> None:
+        torch._C._log_api_usage_once("torch.distributed.memory_tracker")
+        self._hooks: List[RemovableHandle] = []
+        self._operator_names: Dict[str, int] = defaultdict(int)
+        self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict()
+        self.memories_active: Dict[int, Dict[str, float]] = defaultdict()
+        self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict()
+        self._markers: Dict[str, int] = defaultdict(int)
+        self._cur_module_name: str = ""
+        self._op_index: int = 0
+        self._num_cuda_retries: int = 0
+
+    @no_type_check
+    def start_monitor(self, root_module: nn.Module) -> None:
+        """
+        Register module hooks and enter ``MemoryProfileDispatchMode``.
+
+        This enables operator-level memory stats to be tracked during module runtime.
+        """
+        self._clear_state()
+        root_module.__setattr__("_memory_tracker_is_root", True)
+        for name, m in root_module.named_modules():
+            if m is not root_module:
+                m.__setattr__("_memory_tracker_is_root", False)
+            # fused_proxy_group does not support hooks
+            if ".fused_proxy_grouped_embedding_bag" in name:
+                continue
+            # hook ordering with other hooks added by users is not managed, so
+            # the memory stats tracked here may not be completely accurate.
+            h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
+            h2 = m.register_forward_hook(self._create_post_forward_hook(name))
+            # the backward hook does not work well with jagged tensors and the root cause
+            # is not clear; remove it for now as it does not really capture important info.
+            # h3 = m.register_backward_hook(self._create_backward_hook(name))
+            self._hooks.extend([h1, h2])
+        torch.cuda.empty_cache()
+        assert getattr(self, "profile_mode", None) is None
+        self.profile_mode = MemoryProfileDispatchMode(self)
+        self.profile_mode.__enter__()
+
+    @no_type_check
+    def stop(self) -> None:
+        """
+        Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level.
+
+        Also collect aggregated stats from the tracking session, such as the CUDA ``num_alloc_retries``.
+        """
+        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
+
+        for h in self._hooks:
+            h.remove()
+        self._hooks.clear()
+        assert getattr(self, "profile_mode", None) is not None
+        self.profile_mode.__exit__(None, None, None)
+        self.profile_mode = None
+
+    @no_type_check
+    def summary(self, top: int = 20) -> None:
+        """
+        Print out the top operators that generate the most memory.
+
+        The number of the top operators can be configured.
+        """
+        op_diff: Dict[str, float] = defaultdict(float)
+        op_name, previous_allocated_memory = self.memories_allocated[0]
+        for i in range(1, self._op_index):
+            op_name, current_allocated_memory = self.memories_allocated[i]
+            op_diff[op_name] = current_allocated_memory - previous_allocated_memory
+            previous_allocated_memory = current_allocated_memory
+
+        print("------------------------------------------------")
+        print(f"The number of cuda retries are: {self._num_cuda_retries}")
+        print(f"Top {top} ops that generates memory are:")
+        for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[
+            :top
+        ]:
+            print(f"{k}: {v}MB")
+        print("------------------------------------------------")
+
+    @no_type_check
+    def show_traces(self, path: str = "") -> None:
+        import matplotlib.pyplot as plt
+
+        def _plot_figure(x, y_values, labels):
+            min_val = min(list(chain(*y_values))) * 0.999
+            max_val = max(list(chain(*y_values))) * 1.001
+            plt.figure()
+            for y, label in zip(y_values, labels):
+                plt.plot(x, y, label=label)
+            plt.xlabel("# Operator Calls")
+            plt.ylabel("Memory (MB)")
+            plt.legend()
+            for marker_name, marker in self._markers.items():
+                if marker_name == "fw_bw_boundary":
+                    plt.plot(
+                        [marker, marker],
+                        [min_val, max_val],
+                        "r",
+                        lw=2,
+                        label=marker_name,
+                    )
+                else:
+                    plt.plot(
+                        [marker, marker],
+                        [min_val, max_val],
+                        "k-",
+                        lw=2,
+                        label=marker_name,
+                    )
+
+        if path != "":
+            self.load(path)
+
+        y_1 = [mb for (name, mb) in self.memories_allocated.values()]
+        y_2 = [mb for (name, mb) in self.memories_active.values()]
+        y_3 = [mb for (name, mb) in self.memories_reserved.values()]
+        x = list(range(len(y_1)))
+        # Split figures when there is a big difference between
+        # "reserved_memory" and "allocated_memory" or "active_memory".
+        _plot_figure(
+            x,
+            [list(y_1), list(y_2), list(y_3)],
+            ["allocated_memory", "active_memory", "reserved_memory"],
+        )
+        _plot_figure(x, [list(y_1)], ["allocated_memory"])
+        _plot_figure(x, [list(y_2)], ["active_memory"])
+        _plot_figure(x, [list(y_3)], ["reserved_memory"])
+
+    def save_stats(self, path: str) -> None:
+        """Save the stats using pickle during runtime if users want to plot the traces in other places like notebook."""
+        stats = {
+            "memories_allocated": self.memories_allocated,
+            "memories_active": self.memories_active,
+            "memories_reserved": self.memories_reserved,
+            "markers": self._markers,
+            "num_alloc_retries": self._num_cuda_retries,
+        }
+
+        with open(path, "wb") as f:
+            pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)
+
+    def load(self, path: str) -> None:
+        """Load the pickled memory stats to plot the traces or print the summary."""
+        with open(path, "rb") as f:
+            stats = pickle.load(f)
+
+        self.memories_allocated = stats["memories_allocated"]
+        self.memories_active = stats["memories_active"]
+        self.memories_reserved = stats["memories_reserved"]
+        self._markers = stats["markers"]
+        self._num_cuda_retries = stats["num_alloc_retries"]
+
+    def _create_pre_forward_hook(self, name: str) -> Callable:
+        """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start."""
+        def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
+            self._cur_module_name = f"{name}.forward"
+            if (
+                hasattr(module, "_memory_tracker_is_root")
+                and module._memory_tracker_is_root
+            ):
+                self._add_marker("fw_start")
+
+        return _pre_forward_hook
+
+    def _create_post_forward_hook(self, name: str) -> Callable:
+        """Insert the marker 'fw_bw_boundary' at the boundary of forward and backward pass."""
+
+        def _post_forward_hook(
+            module: nn.Module,
+            inputs: Sequence[torch.Tensor],
+            outputs: Sequence[torch.Tensor],
+        ) -> None:
+            if (
+                hasattr(module, "_memory_tracker_is_root")
+                and module._memory_tracker_is_root
+            ):
+                self._add_marker("fw_bw_boundary")
+
+        return _post_forward_hook
+
+    def _create_backward_hook(self, name: str) -> Callable:
+        """Insert the current module name with backward prefix for the operator name."""
+
+        def _backward_hook(
+            module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor
+        ) -> None:
+            self._cur_module_name = f"{name}.backward"
+
+        return _backward_hook
+
+    @no_type_check
+    def _record_memory_stats(self, fn_name: str) -> None:
+        """
+        Record current memory allocated, current memory active and current memory reserved.
+
+        The memory stats dict is indexed with ``self._op_index``.
+        """
+        memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
+        memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
+        memory_active: float = (
+            torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
+        )
+        self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
+        self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
+        self.memories_active[self._op_index] = (fn_name, memory_active)
+        self._op_index += 1
+
+    def _add_marker(self, marker_name: str) -> None:
+        """Set the marker's x-axis value."""
+        marker_val = len(self.memories_allocated.values())
+        self._markers[marker_name] = marker_val
+
+    def _clear_state(self) -> None:
+        """Clear states when start_monitor() is called."""
+        self._operator_names.clear()
+        self.memories_allocated.clear()
+        self.memories_active.clear()
+        self.memories_reserved.clear()
+        self._markers.clear()
+        self._cur_module_name = ""
+        self._op_index = 0
+        self._num_cuda_retries = 0
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2322e7967783425573d7a1e887cb2249d76a095
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/__init__.py
@@ -0,0 +1,3 @@
+from .join import Join
+from .join import Joinable
+from .join import JoinHook
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72ac83944d2f698e1692540d231f254d4754eb2e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb8199f162ae4bcaa57dac3b950129894cec244e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/__pycache__/join.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2447d96f6aeedc2b58c3026efb27368865fb6157
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bff6259c23cdb42546a94ff812a9ec9a98a19518
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ccaf64b3162bccb88e095f01b1f099ca0b2a0c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py
@@ -0,0 +1,314 @@
+import warnings
+from enum import auto, Enum
+from functools import partial
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.autograd.graph import save_on_cpu
+from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
+from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint
+
+_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
+_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."
+
+
+class CheckpointImpl(Enum):
+    REENTRANT = auto()
+    NO_REENTRANT = auto()
+
+
+class ActivationWrapper(torch.nn.Module):
+    """
+    Base class for Activation Checkpoint and Activation Offload.
+
+    Not meant to be instantiated directly.
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self._checkpoint_wrapped_module = mod
+        # state_dict post hook to remove prefix to allow loading into a
+        # non-checkpoint wrapped module.
+        self._register_state_dict_hook(self._post_state_dict_hook)
+        # load_state_dict pre-hook to allow loading back into
+        # checkpoint-wrapped module.
+        self._register_load_state_dict_pre_hook(
+            self._pre_load_state_dict_hook, with_module=True
+        )
+
+    def forward(self, *args, **kwargs):
+        raise ValueError("Subclasses should implement forward().")
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self._checkpoint_wrapped_module, name)
+
+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is a nn.Sequential."""
+        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]
+
+    def named_parameters(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """
+        Override :meth:`named_parameters()` to intercept parameter names.
+
+        Remove all occurrences of ``_CHECKPOINT_PREFIX`` from the returned names.
+        """
+        for param_name, param in super().named_parameters(*args, **kwargs):
+            yield param_name.replace(_CHECKPOINT_PREFIX, ""), param
+
+    @staticmethod
+    def _post_state_dict_hook(
+        module: nn.Module,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> Dict[str, Any]:
+        """
+        ``_post_state_dict_hook()`` is called after ``state_dict()`` of this module has executed.
+
+        For ``checkpoint_wrapper``, it strips the checkpoint-wrapped module prefix
+        so that this module can be loaded into non-checkpointed modules.
+        The state dict can still be loaded into checkpoint-wrapped modules because
+        this class adds the prefix back before loading the state_dict.
+        """
+        _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
+        return state_dict
+
+    @staticmethod
+    def _pre_load_state_dict_hook(
+        module: nn.Module,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> None:
+        """
+        ``_pre_load_state_dict_hook()`` is called before ``self._load_from_state_dict()`` is called.
+
+        For ``checkpoint_wrapper``, it will add back the module
+        prefix so that non-checkpointed modules can be loaded into
+        checkpoint_wrapper modules properly.
+        """
+        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")
+
+
+class OffloadWrapper(ActivationWrapper):
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def forward(self, *args, **kwargs):
+        with save_on_cpu(pin_memory=True):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+
+
+class CheckpointWrapper(ActivationWrapper):
+    """
+    An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing.
+
+    Note that this module is not meant to be used directly but instead,
+    it is to be used through the ``checkpoint_wrapper`` function.
+    """
+
+    def __init__(
+        self,
+        mod: torch.nn.Module,
+        checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
+        checkpoint_fn=None,
+        **checkpoint_fn_kwargs,
+    ):
+        super().__init__(mod)
+        self.checkpoint_impl = checkpoint_impl
+        if checkpoint_fn is None:
+            # use torch.utils.checkpoint
+            self.checkpoint_fn = partial(
+                torch_utils_checkpoint,
+                use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
+                **checkpoint_fn_kwargs,
+            )
+        else:
+            # Construct user-specified checkpoint function.
+            self.checkpoint_fn = partial(
+                checkpoint_fn,
+                **checkpoint_fn_kwargs,
+            )
+
+    def forward(self, *args, **kwargs):
+        # Support keyword arguments for reentrant checkpoint. Note that this
+        # only works if the user specified ``CheckpointImpl.REENTRANT`` and is
+        # not using a custom checkpoint_fn.
+        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
+            # Pack the args and kwargs
+            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
+
+            # Function that only takes (packed) args, but can unpack them
+            # into the original args and kwargs for the checkpointed
+            # function, and runs that function.
+            def my_function(*inputs):
+                # unpack back into args and kwargs
+                unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys)
+                # run original module
+                return self._checkpoint_wrapped_module(
+                    *unpacked_args, **unpacked_kwargs
+                )
+
+            # Pass the function that only takes packed args into reentrant
+            # checkpoint API.
+            return self.checkpoint_fn(  # type: ignore[misc]
+                my_function,
+                *flat_args,
+            )
+        else:
+            return self.checkpoint_fn(  # type: ignore[misc]
+                self._checkpoint_wrapped_module, *args, **kwargs
+            )
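+
+    # Illustrative comment (added): conceptually the packing above round-trips as
+    #
+    #     flat_args, keys = _pack_kwargs(1, 2, a=3)        # -> (1, 2, 3), ("a",)
+    #     args, kwargs = _unpack_kwargs(flat_args, keys)   # -> (1, 2), {"a": 3}
+    #
+    # The exact return shapes shown here are an assumption based on how the
+    # helpers are used in ``forward`` above.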
+
+
+def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Wrap a module for activation offloading to CPU.
+
+    Offloads intermediate activations to the CPU for modules wrapped with this function.
+    Wrappers with activation offload can be composed with ones that do recomputation-based
+    checkpoint to trade off increased compute versus increased CPU
+    memory usage and additional H2D transfers.
+
+    Usage::
+        offloaded_module = offload_wrapper(module)
+        outputs = offloaded_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+    return OffloadWrapper(module)
+
+
+def checkpoint_wrapper(
+    module: torch.nn.Module,
+    checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
+    checkpoint_fn=None,
+    **checkpoint_fn_kwargs,
+) -> torch.nn.Module:
+    """
+    Wrap a module for activation checkpointing.
+
+    If the module is wrapped with this function, all subsequent calls to the module will
+    automatically perform checkpointing without the user having to explicitly call the ``checkpoint`` function.
+
+    Usage::
+        checkpointed_module = checkpoint_wrapper(module)
+        outputs = checkpointed_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+        checkpoint_impl (Optional[CheckpointImpl]):
+            The checkpointing implementation to use. Note that this will only
+            be passed into the ``torch.utils.checkpoint.checkpoint``
+            implementation, and is ignored if a custom ``checkpoint_fn`` is
+            specified. Note that for implementations using reentrant checkpoint
+            from ``torch.utils.checkpoint``, keyword arguments will only be
+            supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT``.
+        checkpoint_fn (Optional[Callable]):
+            Functional checkpoint implementation to use. If this is specified,
+            it will be used over the default ``torch.utils.checkpoint.checkpoint``
+            implementation and the `checkpoint_impl` argument will be ignored.
+        **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`.
+
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+
+    if checkpoint_impl == CheckpointImpl.REENTRANT:
+        warnings.warn(
+            f"Please specify {CheckpointImpl.NO_REENTRANT} as "
+            f"{CheckpointImpl.REENTRANT} will soon be removed as "
+            "the default and eventually deprecated.",
+            stacklevel=1,
+        )
+    return CheckpointWrapper(
+        module,
+        checkpoint_impl,
+        checkpoint_fn,
+        **checkpoint_fn_kwargs,
+    )
+
+
+def apply_activation_checkpointing(
+    model,
+    checkpoint_wrapper_fn=checkpoint_wrapper,
+    check_fn=lambda _: True,
+    auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
+):
+    """
+    Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration.
+
+    For each module within `model`, the `check_fn` is used to decide
+    whether `module` should be wrapped with :func:`checkpoint_wrapper` or not.
+
+    Note::
+        This function modifies `model` in place and replaces appropriate layers with
+        their checkpoint-wrapped modules.
+    Note::
+        This function will not wrap the overall root module. If this is needed, please directly use
+        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
+    Usage::
+        model = nn.Sequential(
+            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
+        )
+        check_fn = lambda l: isinstance(l, nn.Linear)
+        # checkpoint activations
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
+        # Or offload activations to CPU
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
+    Args:
+        model (nn.Module):
+            The model whose submodules should be wrapped with activation checkpointing.
+        checkpoint_wrapper_fn (Optional[Callable[nn.Module]]):
+            A ``Callable`` which will wrap modules.
+        check_fn (Optional[Callable[nn.Module, nn.Module]]):
+            A lambda function which will be passed each child submodule of ``model`` and returns
+            ``True`` or ``False`` depending on whether the submodule should be wrapped.
+        auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's
+            submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``.
+    Returns: None (`model` is modified in place)
+    """
+    # TODO: Importing inside function to avoid circular import issue between FSDP and
+    # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code.
+    from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy, _Policy
+    from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply
+
+    policy = (
+        auto_wrap_policy
+        if auto_wrap_policy is not None
+        else partial(lambda_auto_wrap_policy, lambda_fn=check_fn)
+    )
+    if not callable(policy):
+        if not isinstance(policy, _Policy):
+            raise ValueError(
+                f"Expected {policy} to be callable or be a pre-defined wrap policy"
+            )
+        target_module_to_kwargs = policy._run_policy(
+            model, ignored_modules=set(), root_kwargs={}
+        )
+        wrap_fn = _construct_wrap_fn(model, target_module_to_kwargs, checkpoint_wrapper_fn)
+        _post_order_apply(model, wrap_fn)
+        return
+
+    _recursive_wrap(
+        module=model,
+        auto_wrap_policy=policy,  # type: ignore[arg-type]
+        wrapper_cls=checkpoint_wrapper_fn,
+        ignored_modules=set(),
+        ignored_params=set(),
+        only_wrap_children=True,
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..990bfd9dcf09c9fc8029739f6c0191a614f87404
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__init__.py
@@ -0,0 +1,7 @@
+
+from . import default_hooks as default
+
+LOW_PRECISION_HOOKS = [
+    default.fp16_compress_hook,
+    default.bf16_compress_hook,
+]
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebba0369be6abbb629d2a66d62e2805782de79d2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b687f74b8acee7c1bb09313ee86bf9ad494bedb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ebd5b4bd3e65b3e09300f171fe18a55d634ff09
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_comm_hooks/default_hooks.py
@@ -0,0 +1,165 @@
+import functools
+import torch
+import torch.distributed as dist
+from typing import Optional
+
+
+class DefaultState:
+    r"""
+    Stores state needed to perform the default communication algorithm within a communication hook.
+
+    Args:
+        process_group (ProcessGroup): The process group to be used.
+    """
+
+    __slots__ = [
+        "process_group",
+        "world_size",
+        "gradient_predivide_factor",
+        "gradient_postdivide_factor"
+    ]
+
+    def __init__(
+        self,
+        process_group: dist.ProcessGroup
+    ):
+        if process_group is None:
+            raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.")
+        self.process_group = process_group
+        self.world_size = dist.get_world_size(process_group)
+        # Set two factors, `self.gradient_predivide_factor` and
+        # `self.gradient_postdivide_factor`, to avoid underflow and overflow.
+        self.gradient_predivide_factor = self._get_gradient_predivide_factor(
+            self.world_size
+        )
+        self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor
+
+    @staticmethod
+    def _get_gradient_predivide_factor(world_size: int) -> float:
+        factor: int = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor *= 2
+        return float(factor)
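+
+    # Worked example (added comment): for power-of-two world sizes the loop above
+    # yields a predivide factor close to sqrt(world_size), e.g.
+    #   world_size=4  -> predivide 2, postdivide 2
+    #   world_size=8  -> predivide 4, postdivide 2
+    #   world_size=16 -> predivide 4, postdivide 4
+    # Splitting the overall 1/world_size averaging into two divisions keeps
+    # intermediate gradient values in range for low-precision communication.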
+
+class LowPrecisionState(DefaultState):
+    r"""
+    Stores state needed to perform gradient communication in a lower precision within a communication hook.
+
+    The communication hook will cast gradients back to the original
+    parameter precision specified by ``parameter_type`` (default: ``torch.float32``).
+    Builds on top of the :class:`DefaultState`.
+
+    Args:
+        parameter_type (torch.dtype): The precision of the model's parameters.
+            Required for a hook to cast gradients back to a parameter's precision.
+    """
+
+    __slots__ = [
+        "parameter_type",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        parameter_type=torch.float32,
+    ):
+        super().__init__(process_group)
+        self.parameter_type = parameter_type
+
+
+def _decompress(state: LowPrecisionState, grad: torch.Tensor):
+    """
+    Casts gradients back to full parameter precision so that further computation happens in full precision.
+    """
+    orig_grad_data = grad.data
+    grad.data = grad.data.to(state.parameter_type)
+    # Don't let this memory get reused until after the transfer.
+    orig_grad_data.record_stream(torch.cuda.current_stream())  # type: ignore[arg-type]
+
+def allreduce_hook(state: DefaultState, grad: torch.Tensor):
+    r"""
+    Implement the FSDP communication hook for the ``all_reduce`` algorithm and the necessary pre- and post-division of gradients.
+
+    Args:
+        state (DefaultState): State information, configures pre- and post-division factors.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks.
+    """
+    # Average grad by pre-division factor. Together pre- and post-division factors
+    # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
+    # This is a two-step process to avoid potential underflow and overflow.
+    if state.gradient_predivide_factor > 1:
+        grad.div_(state.gradient_predivide_factor)
+    dist.all_reduce(grad, group=state.process_group)
+    # Average grad by post-division factor.
+    if state.gradient_postdivide_factor > 1:
+        grad.div_(state.gradient_postdivide_factor)
+
+def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
+    r"""
+    Implement the FSDP communication hook for the ``reduce_scatter`` algorithm.
+
+    Used by sharded FSDP strategies, with the necessary pre- and post-division of gradients.
+
+    Args:
+        state (DefaultState): State information, configures pre- and post-division factors.
+        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
+        communicated across ranks.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    # Average grad by pre-division factor.
+    if state.gradient_predivide_factor > 1:
+        grad.div_(state.gradient_predivide_factor)
+    dist.reduce_scatter_tensor(
+        output, grad, group=state.process_group
+    )
+    # Average grad's shard by post-division factor.
+    if state.gradient_postdivide_factor > 1:
+        output.div_(state.gradient_postdivide_factor)
+
+def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor):
+    if grad.dtype != prec:
+        grad.data = grad.data.to(prec)
+    if output is not None:
+        if output.dtype != prec:
+            output.data = output.data.to(prec)
+        reduce_scatter_hook(state, grad, output)
+        _decompress(state, output)
+    else:
+        allreduce_hook(state, grad)
+        _decompress(state, grad)
+
+def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
+    r"""
+    Implement the FSDP communication hook for a simple gradient compression approach.
+    Casts ``grad`` to half-precision floating-point format (``torch.float16``).
+
+    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
+    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
+    gradients are averaged by a ``state.gradient_postdivide_factor``.
+    Once post-division is done, compressed gradients are cast back to the parameters' precision.
+
+    Args:
+        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    fp16_hook = functools.partial(_low_precision_hook, torch.float16)
+    return fp16_hook(state, grad, output)
+
+def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
+    r"""
+    Implement the FSDP communication hook for a simple gradient compression approach.
+    Casts ``grad`` to brain floating-point format (``torch.bfloat16``).
+
+    It also averages gradients by ``world_size`` in two steps: first it pre-divides gradients by a
+    ``state.gradient_predivide_factor``, and after a communication step (``all_reduce`` or ``reduce_scatter``)
+    gradients are averaged by a ``state.gradient_postdivide_factor``.
+    Once post-division is done, compressed gradients are cast back to the parameters' precision.
+
+    Args:
+        state (LowPrecisionState): State information, configures pre- and post-division factors, parameters' precision.
+        grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
+        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
+    """
+    bf16_hook = functools.partial(_low_precision_hook, torch.bfloat16)
+    return bf16_hook(state, grad, output)
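+
+
+# Usage sketch (added; illustrative only). Assuming an FSDP-wrapped model that
+# exposes ``register_comm_hook`` and an already-initialized process group:
+#
+#     state = LowPrecisionState(process_group=dist.group.WORLD)
+#     fsdp_model.register_comm_hook(state, fp16_compress_hook)
+#
+# ``fsdp_model`` is a placeholder name; whether a given FSDP version accepts
+# these hooks directly is an assumption here, not something this file defines.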
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9460c12ce8abc076e6d22570e65a773039aa145f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__init__.py
@@ -0,0 +1 @@
+from .optimizer_overlap import _as_overlapped_optim
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4400d9acc55d75450755425a11a329bb1005ed1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1250da9aa6e7bc12384bbab9b8e1ffcb6fcc52a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d824897049d10df995e7b471bb8c077adfa434
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py
@@ -0,0 +1,93 @@
+from abc import ABC, abstractmethod
+import inspect
+from typing import Dict, Type
+
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import Optimizer
+from torch.distributed.optim import as_functional_optim
+
+from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
+
+from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
+    _OptimizerHookState,
+    _hook_then_optimizer
+)
+
+# Contains the mappings between the regular and overlapped optimizer types.
+_registered_overlapped_optims: Dict[Type, Type] = {}
+
+
+def register_overlapped(optim_cls):
+    def decorator(target_overlapped_optim_cls):
+        if target_overlapped_optim_cls in _registered_overlapped_optims:
+            raise ValueError(
+                f"{target_overlapped_optim_cls} already registered with optim_cls "
+                f"{_registered_overlapped_optims[optim_cls]} {optim_cls}, trying to"
+                f"re-register it for {optim_cls} is not supported."
+            )
+        _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls
+        return target_overlapped_optim_cls
+    return decorator
+
+
+class OverlappedOptimizer(ABC):
+    def __init__(self, optim_cls: Type) -> None:
+        """
+        Initialize the OverlappedOptimizer.
+
+        ``OverlappedOptimizer`` is a base class that child classes can implement to
+        specify how different optimizers will register themselves with DDP.
+        """
+        self.optim_cls = optim_cls
+
+    @abstractmethod
+    def register_ddp(self, ddp: DistributedDataParallel) -> None:
+        """Registers the overlapped optimizer with DDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped DDP."
+        )
+
+    @abstractmethod
+    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
+        """Registers the overlapped optimizer with FSDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped FSDP."
+        )
+
+
+@register_overlapped(Optimizer)
+class _OverlappedStandardOptimizer(OverlappedOptimizer):
+    """Overlaps a regular ``Optimizer``."""
+
+    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
+        super().__init__(optim_cls)
+        f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
+        self._opt_hook_state = _OptimizerHookState(f_optim, params)
+
+    def register_ddp(self, ddp_inst: DistributedDataParallel):
+        # NOTE: using a custom communication hook and fused optimizer is not
+        # yet supported.
+        ddp_inst.register_comm_hook(  # type: ignore[operator]
+            None,  # wrapped hook state
+            _hook_then_optimizer(allreduce_hook, self._opt_hook_state)
+        )
+
+    # TODO: register_fsdp once FSDP supports communication hook.
+    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
+        """Register the overlapped optimizer with FSDP."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not support overlapped FSDP."
+        )
+
+def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
+    """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
+    for clz in inspect.getmro(optim_cls):
+        try:
+            return _registered_overlapped_optims[clz](optim_cls, params, *args, **kwargs)
+        except KeyError:
+            pass
+
+    # Fall back to the standard overlapped optimizer, which will raise errors if
+    # the user attempts to use an unsupported optimizer.
+    return _OverlappedStandardOptimizer(optim_cls, params, *args, **kwargs)
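+
+
+# Usage sketch (added; illustrative only). Given a DDP-wrapped model named
+# ``ddp_model`` (a placeholder), an overlapped optimizer could be built and
+# attached roughly as:
+#
+#     overlapped = _as_overlapped_optim(torch.optim.SGD, ddp_model.parameters(), lr=0.01)
+#     overlapped.register_ddp(ddp_model)
+#
+# This relies on a functional counterpart of the optimizer being available via
+# ``as_functional_optim``; treat the exact call sequence as an assumption.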
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..407d34d949887c4c9a26c30519776db912243bc7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88009dd7c40b2e87e6c5201b41f7d047c3264b3f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbb6862c46b9596069eec65a91b354046dc6d0ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/_quantization/quantization.py
@@ -0,0 +1,144 @@
+import functools
+import torch
+import torch.distributed as dist
+
+
+from enum import Enum
+
+
+TORCH_HALF_MIN = torch.finfo(torch.float16).min
+TORCH_HALF_MAX = torch.finfo(torch.float16).max
+
+class DQuantType(Enum):
+    """
+    Different quantization methods for auto_quantize API are identified here.
+
+    auto_quantize API currently supports fp16 and bfp16 methods.
+    """
+    FP16 = "fp16",
+    BFP16 = "bfp16"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
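+
+# Added note: torch.finfo(torch.float16).max is 65504, so the clamp above keeps
+# values within roughly +/-65504 before the .half() cast; without it, larger
+# magnitudes would overflow to +/-inf in float16.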
+
+def _quantize_tensor(tensor, qtype):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        return _fp32_to_fp16_with_clamp(tensor)
+    elif qtype == DQuantType.BFP16:
+        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+def _quantize_tensor_list(tensor_list, qtype):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
+    return quantized_tensor_list
+
+def _dequantize_tensor(tensor, qtype, quant_loss=None):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        elif tensor.dtype == torch.float16 and quant_loss is None:
+            return tensor.float()
+        else:
+            return tensor.float() / quant_loss
+    elif qtype == DQuantType.BFP16:
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        else:
+            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+
+def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
+    return dequantized_tensor_list
+
+
+def auto_quantize(func, qtype, quant_loss=None):
+    """
+    Quantize the input tensors, choose the precision type, pass through the other necessary arguments, and then dequantize the output.
+
+    Currently it only supports:
+        . FP16 and BFP16 quantization methods for the gloo and nccl backends
+        . all_gather, all_to_all, and all_to_all_single collective ops
+    Note: BFP16 only supports 2D tensors.
+    Args:
+        func (Callable): A function representing collective operations.
+        qtype (DQuantType): Quantization method.
+        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.
+    Returns:
+        (Callable): the same collective as ``func`` but with automatic quantization/dequantization applied.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        group = kwargs.get('group', None)
+        async_op = kwargs.get('async_op', False)
+        if async_op is True:
+            raise RuntimeError(
+                'The async_op=True mode is not supported yet.'
+            )
+        if func == dist.all_gather:
+            tensors = args[0]
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+
+        elif func == dist.all_to_all:
+            tensors = args[0]
+            input_tensors = _quantize_tensor_list(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+
+        elif func == dist.all_to_all_single:
+            tensors = args[0]
+            out_splits = kwargs.get('out_splits', None)
+            in_splits = kwargs.get('in_splits', None)
+            # Quantizing the input/output tensor
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor(tensors, qtype)
+            dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
+            for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+        else:
+            raise RuntimeError(
+                f"The collective op {func} is not supported yet"
+            )
+
+    return wrapper
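+
+
+# Usage sketch (added; illustrative only). Wrap a collective once and call the
+# wrapper with the same signature as the original op, assuming the process
+# group is already initialized and tensors are on the right device:
+#
+#     quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
+#     quantized_all_gather(gathered_list, local_tensor)
+#
+# ``gathered_list`` and ``local_tensor`` are placeholder names for the usual
+# ``dist.all_gather`` arguments.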
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc837ee0abcfeb481cc60d2e74aab92230c859d6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -0,0 +1,108 @@
+from enum import Enum
+from functools import partial
+
+import torch.distributed as dist
+
+from . import (
+    debugging_hooks as debugging,
+    default_hooks as default,
+    powerSGD_hook as powerSGD,
+    quantization_hooks as quantization,
+    optimizer_overlap_hooks as optimizer_overlap,
+)
+
+__all__ = ['DDPCommHookType', 'register_ddp_comm_hook']
+
+def _ddp_comm_hook_wrapper(comm_hook, model, state):
+    model.register_comm_hook(state, comm_hook)
+
+
+def _powerSGD_comm_hook_wrapper(
+    comm_hook,
+    model,
+    state,
+    matrix_approximation_rank,
+    start_powerSGD_iter=1_000,
+):
+    """
+    Wrap PowerSGD communication hook.
+
+    To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
+    which will be wrapped up with other state info.
+    """
+    powerSGD_state = powerSGD.PowerSGDState(
+        process_group=state,
+        matrix_approximation_rank=matrix_approximation_rank,
+        start_powerSGD_iter=start_powerSGD_iter,
+    )
+    model.register_comm_hook(powerSGD_state, comm_hook)
+
+
+class DDPCommHookType(Enum):
+    """
+    Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communication hook types.
+
+    DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
+    as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example,
+    you can register allreduce hook by
+    ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``.
+    """
+
+    ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
+    FP16_COMPRESS = partial(
+        _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook
+    )
+    BF16_COMPRESS = partial(
+        _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook
+    )
+    QUANTIZE_PER_TENSOR = partial(
+        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook
+    )
+    QUANTIZE_PER_CHANNEL = partial(
+        _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook
+    )
+    POWER_SGD = partial(
+        _powerSGD_comm_hook_wrapper,
+        comm_hook=powerSGD.powerSGD_hook,
+        matrix_approximation_rank=1,
+    )
+    # Rank-2 PowerSGD can give a higher accuracy than the default rank-1 version,
+    # but it runs slower and consumes more memory.
+    POWER_SGD_RANK2 = partial(
+        _powerSGD_comm_hook_wrapper,
+        comm_hook=powerSGD.powerSGD_hook,
+        matrix_approximation_rank=2,
+    )
+    # Batching can lead to a faster training at the cost of accuracy.
+    BATCHED_POWER_SGD = partial(
+        _powerSGD_comm_hook_wrapper,
+        comm_hook=powerSGD.batched_powerSGD_hook,
+        matrix_approximation_rank=1,
+    )
+    BATCHED_POWER_SGD_RANK2 = partial(
+        _powerSGD_comm_hook_wrapper,
+        comm_hook=powerSGD.batched_powerSGD_hook,
+        matrix_approximation_rank=2,
+    )
+    NOOP = partial(
+        _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook,
+    )
+
+
+def register_ddp_comm_hook(
+    comm_hook_type: DDPCommHookType, model, state=None
+):
+    """
+    Register a ``ddp_comm_hooks`` hook on a DDP model.
+
+    Registers one of the hooks from ``torch.distributed.algorithms.ddp_comm_hooks``
+    on the DDP model. The type of hook is selected with the ``comm_hook_type``
+    enum (``DDPCommHookType``), and the ``state`` input is passed through to the
+    registered hook.
+    Uses the Python comm hook implementations.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state)
+    """
+    comm_hook_type.value(model=model, state=state)
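+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): a
+# minimal example of how ``DDPCommHookType`` and ``register_ddp_comm_hook`` fit
+# together, assuming the process group has already been initialized by the
+# launcher. The model, sizes, and device handling are hypothetical placeholders.
+def _example_register_comm_hook(local_rank: int) -> None:
+    import torch.nn as nn
+    from torch.nn.parallel import DistributedDataParallel as DDP
+
+    # Hypothetical one-GPU-per-process setup.
+    ddp_model = DDP(nn.Linear(16, 16).to(local_rank), device_ids=[local_rank])
+
+    # Enum-based registration; equivalent to calling
+    # ``ddp_model.register_comm_hook(dist.group.WORLD, default.fp16_compress_hook)``.
+    register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, ddp_model, dist.group.WORLD)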
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..079216591f84f3a729b3ff218b8fd6cf2b041489
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48be6b606c19ebb1f06b00b1db19617346810a62
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c6ab9522d1a6af5c686170637675f42b0411d59
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bead8fc933c58257733aec8bd629399370f12142
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..047f95f726191a11b33663cd604f709935110b7a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5accb84c6e515a2c5c419316c9edd41c2d9d54ac
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a318bc3eeda9f98fcd26150bf9690afba1c5b0c0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e05d4889ca8a686ac412c18c07304ebd0de4a0a7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95a65ac700c6d9ebbb526903bbf1757a5907420a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..95bfdec9500e95369439a26f878b20a8f6e6417b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py
@@ -0,0 +1,448 @@
+import weakref
+from typing import Any, Callable, List, Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.optim import ZeroRedundancyOptimizer
+from torch.distributed.optim.zero_redundancy_optimizer import (
+    _OverlapStatus,
+)
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"]
+
+# Functional optimizers require passing a list of gradients to their `step()`
+# method, and ZeRO requires a functional optimizer to overlap with DDP.
+# Passing `None` instead of an actual gradient indicates to the optimizer
+# that the corresponding parameter should not be updated.
+_NO_PARAM_UPDATE: None = None
+
+
+def _perform_local_step(
+    bucket: dist.GradBucket,
+    zero: ZeroRedundancyOptimizer,
+    rank: int,
+):
+    r"""
+    Perform a local optimizer step using the gradients provided by ``bucket``.
+
+    Arguments:
+        bucket (dist.GradBucket): the bucket providing the gradients.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to perform the :meth:`_local_step`.
+        rank (int): the calling process's rank.
+
+    .. warning::
+        This function assumes that appropriate synchronization has taken place
+        so that the bucket's gradients can be used.
+    """
+    overlap_info = zero._overlap_info
+    bucket_index = bucket.index()
+    assert len(zero.optim.param_groups) == 1, \
+        "Overlapping DDP with ZeRO only supports a single parameter group"
+
+    # Construct the `gradients` input for the local optimizer step, which
+    # expects `None` in a list position to indicate that the corresponding
+    # parameter should not be updated
+    num_local_optim_params = len(zero.optim.param_groups[0]["params"])
+    gradients: List[Optional[torch.Tensor]] = \
+        [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)]
+    assert bucket_index in overlap_info.offsets, \
+        f"Bucket index {bucket_index} was not assigned to rank {rank}"
+    gradients_offset = overlap_info.offsets[bucket_index]
+    bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
+    bucket_offset = bucket_assignment.offset
+    length = len(bucket_assignment.parameters)
+    bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length]
+    for i, grad in enumerate(bucket_gradients):
+        gradients[gradients_offset + i] = grad
+
+    zero._local_step(gradients)
+
+
+def _broadcast_bucket(
+    bucket_index: int,
+    zero: ZeroRedundancyOptimizer,
+):
+    r"""
+    Broadcast a bucket's parameters.
+
+    Arguments:
+        bucket_index (int): the index of the bucket corresponding to the
+            parameters to broadcast.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+    """
+    overlap_info = zero._overlap_info
+    assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+        "`assigned_ranks_per_bucket` is not fully constructed"
+    # Sort to ensure the same ordering across ranks
+    assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
+    assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \
+        "assigned to at least one rank"
+    for assigned_rank in assigned_ranks:
+        bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
+        if bucket_index in bucket_assignments:
+            overlap_info.broadcast_handles.append(
+                dist.broadcast(
+                    bucket_assignments[bucket_index].tensor,
+                    src=dist.get_global_rank(zero.process_group, assigned_rank),
+                    group=zero.process_group,
+                    async_op=True,
+                )
+            )
+
+
+def _save_ddp_bucket_info(
+    bucket: dist.GradBucket,
+    zero: ZeroRedundancyOptimizer,
+):
+    r"""
+    Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
+
+    In particular, this function is meant to be called once for each gradient
+    bucket used when overlapping; it does not save or compute any global
+    information.
+
+    Arguments:
+        bucket (dist.GradBucket): the current gradient bucket.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+    """
+    overlap_info = zero._overlap_info
+    bucket_params = bucket.parameters()
+    assert len(bucket_params) > 0, "Empty bucket"
+
+    # Save the parameters in the bucket
+    overlap_info.params_per_bucket.append(bucket_params)
+    if overlap_info.shard_buckets:
+        # Additionally save the bucket size for the assignment heuristic to use
+        bucket_size = 0
+        for param in bucket_params:
+            bucket_size += param.numel()
+        assert overlap_info.total_size is not None
+        overlap_info.total_size += bucket_size
+
+
+def _hook_with_zero_step_setup(
+    ddp_ref: weakref.ReferenceType,
+    zero: ZeroRedundancyOptimizer,
+    bucket: dist.GradBucket,
+):
+    r"""
+    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
+
+    That is, the logic that runs in the hook before the backward pass and the
+    optimizer step can actually be overlapped. It is factored out since it is
+    common to both
+    :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
+
+    Arguments:
+        ddp_ref (weakref.ReferenceType): weak reference to the process's
+            :class:`DistributedDataParallel` instance.
+        zero (ZeroRedundancyOptimizer): the calling process's
+            :class:`ZeroRedundancyOptimizer` instance.
+        bucket (dist.GradBucket): the current gradient bucket.
+    """
+    # Proceed as normal until the DDP buckets have been rebuilt
+    if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
+        assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
+        return
+
+    bucket_index = bucket.index()
+    overlap_info = zero._overlap_info
+    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
+        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
+
+    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
+        if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0:
+            # This corresponds to the first bucket of the backward pass
+            # immediately after all information has been saved, so we
+            # can perform the delayed ZeRO initialization
+            zero._init_zero_for_overlap()
+        else:
+            # Once DDP buckets have been rebuilt but ZeRO has not been
+            # properly initialized yet, save the information needed
+            _save_ddp_bucket_info(bucket, zero)
+
+
+def hook_with_zero_step(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
+    ddp: DistributedDataParallel,
+    zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.
+
+    This approach overlaps the optimizer computation and communication with the
+    backward communication. In particular, the backward computation proceeds
+    contiguously, and the optimizer computation follows, overlapping with
+    outstanding backward communication (i.e. all-reduces) and possibly other
+    optimizer communication (i.e. broadcasts).
+    The optimizer step computation begins after the last gradient bucket computation has finished.
+
+    This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
+    if communication is relatively slow compared to computation.
+
+    Arguments:
+        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
+            to modify.
+        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
+            instance to use.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
+
+    Returns:
+        The modified hook.
+
+    Raises:
+        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
+        RuntimeError: if using any backend other than NCCL/HCCL since currently
+            Gloo may hang.
+
+    .. warning::
+        Given the way that overlapping :class:`DistributedDataParallel` with
+        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
+        two or three training iterations do not perform parameter updates in
+        the optimizer step, depending on if ``static_graph=False`` or
+        ``static_graph=True``, respectively. This is because it needs
+        information about the gradient bucketing strategy used by
+        :class:`DistributedDataParallel`, which is not finalized until the
+        second forward pass if ``static_graph=False`` or until the third
+        forward pass if ``static_graph=True``.
+    """
+    if not zero._overlap_with_ddp:
+        raise ValueError(
+            "ZeroRedundancyOptimizer must be constructed with "
+            "`overlap_with_ddp=True` to use this hook properly"
+        )
+    ddp_ref = weakref.ref(ddp)
+
+    # NOTE: Gloo may hang with this overlapping approach, so we require
+    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
+    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
+    if ((pg != dist.Backend.NCCL) and (pg != 'hccl')):
+        raise RuntimeError(
+            "Overlapping DDP with ZeRO using this approach currently requires "
+            "NCCL/HCCL backend to avoid hangs"
+        )
+
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
+    def hook_with_zero_fn(
+        state: Any,
+        bucket: dist.GradBucket,
+    ) -> torch.futures.Future[torch.Tensor]:
+        r"""
+        Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket.
+
+        Perform the equivalent of a :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is the last gradient bucket.
+        The returned future gives the gradient bucket tensor; on the iteration in
+        which the :class:`DistributedDataParallel` buckets are rebuilt, the hook
+        also performs additional computation to collect the
+        information used to implement the modified hook.
+
+        Arguments:
+            state (Any): any state for the hook.
+            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
+                gradient bucket.
+        """
+        fut = hook(state, bucket)
+        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
+        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
+            return fut
+
+        overlap_info = zero._overlap_info
+        bucket_index = bucket.index()
+        rank = zero.global_rank
+
+        assert overlap_info.status == _OverlapStatus.INITIALIZED
+        assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \
+            "`assigned_ranks_per_bucket` is not fully constructed"
+        assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
+
+        # Save the bucket reference and all-reduce future for the final bucket
+        if assigned_to_bucket:
+            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
+            overlap_info.bucket_index_to_future[bucket_index] = fut
+
+        # Check that buckets are indexed incrementally starting from 0 in the
+        # order of their autograd hooks firing
+        if len(overlap_info.bucket_indices_seen) > 0:
+            assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, \
+                "Bucket indices are not in incremental order"
+        else:
+            assert bucket_index == 0, "Bucket indices do not start from 0"
+        overlap_info.bucket_indices_seen.append(bucket_index)
+
+        # Directly return the future without any optimizer computation if this
+        # is not the last bucket
+        num_buckets = len(overlap_info.params_per_bucket)
+        is_last_bucket = bucket_index == num_buckets - 1
+        if not is_last_bucket:
+            return fut
+
+        # Perform partial optimizer step on all buckets after the final
+        # bucket has been computed
+        # NOTE: This should not be chained as a callback to the last bucket's
+        # all-reduce future since that would add synchronization that delays
+        # all optimizer computation to wait for that last all-reduce
+        for bucket_index in range(num_buckets):
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+            if rank in assigned_ranks:
+                # Wait on the bucket's all-reduce future to ensure correct
+                # gradients
+                assert bucket_index in overlap_info.bucket_index_to_future, \
+                    f"All-reduce future for bucket {bucket_index} not saved " \
+                    f"on rank {rank}"
+                allreduce_future = overlap_info.bucket_index_to_future[bucket_index]
+                allreduce_future.wait()
+
+                # Perform the partial optimizer step
+                curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
+                _perform_local_step(curr_bucket, zero, rank)
+
+            _broadcast_bucket(bucket_index, zero)
+
+        # Ensure that all parameter updates are finished before the
+        # next forward pass
+        overlap_info.wait_for_broadcasts()
+        overlap_info.clear_per_iter_info()
+
+        return fut
+
+    return hook_with_zero_fn
+
+
+def hook_with_zero_step_interleaved(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
+    ddp: DistributedDataParallel,
+    zero: ZeroRedundancyOptimizer,
+    shard_buckets: bool = False,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.
+
+    This approach overlaps the optimizer computation and communication with the
+    backward computation and communication. In particular, once a bucket's
+    gradients have been computed, the optimizer computation using those
+    gradients is launched (though the actual computation must wait for the
+    bucket's all-reduce to complete). This yields an interleaving of all-
+    reduces and broadcasts in the communication stream.
+
+    This approach may be preferred over :meth:`hook_with_zero_step` if
+    communication is relatively fast compared to computation.
+
+    Arguments:
+        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
+            to modify.
+        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
+            instance to use.
+        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
+            instance to use.
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity; if
+            ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
+
+    Returns:
+        The modified hook.
+
+    Raises:
+        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
+        RuntimeError: if using any backend other than NCCL/HCCL since currently
+            Gloo may hang.
+
+    .. warning::
+        Given the way that overlapping :class:`DistributedDataParallel` with
+        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
+        two or three training iterations do not perform parameter updates in
+        the optimizer step, depending on if ``static_graph=False`` or
+        ``static_graph=True``, respectively. This is because it needs
+        information about the gradient bucketing strategy used by
+        :class:`DistributedDataParallel`, which is not finalized until the
+        second forward pass if ``static_graph=False`` or until the third
+        forward pass if ``static_graph=True``.
+    """
+    if not zero._overlap_with_ddp:
+        raise ValueError(
+            "ZeroRedundancyOptimizer must be constructed with "
+            "`overlap_with_ddp=True` to use this hook properly"
+        )
+    ddp_ref = weakref.ref(ddp)
+
+    # NOTE: Gloo may hang with this overlapping approach, so we require
+    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
+    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
+    if ((pg != dist.Backend.NCCL) and (pg != 'hccl')):
+        raise RuntimeError(
+            "Overlapping DDP with ZeRO using this approach currently requires "
+            "NCCL/HCCL backend to avoid hangs"
+        )
+
+    if shard_buckets:
+        zero._overlap_info.shard_buckets = True
+        zero._overlap_info.total_size = 0
+
+    def hook_with_zero_interleaved_fn(
+        state,
+        bucket: dist.GradBucket,
+    ) -> torch.futures.Future[torch.Tensor]:
+        r"""
+        Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`.
+
+        This function uses the gradients in the given bucket to perform a partial
+        :class:`ZeroRedundancyOptimizer` :meth:`step`.
+
+        Arguments:
+            state: any state for the hook.
+            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
+                gradient bucket.
+        """
+        fut = hook(state, bucket)
+        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
+        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
+            return fut
+
+        def zero_step(fut: torch.futures.Future) -> torch.Tensor:
+            r"""
+            Perform a partial :class:`ZeroRedundancyOptimizer` :meth:`step` using the gradients in the :class:`DistributedDataParallel` gradient bucket.
+
+            Returns:
+                A :class:`torch.Tensor` representing the contents of the
+                gradient bucket.
+            """
+            overlap_info = zero._overlap_info
+            bucket_index = bucket.index()
+            rank = zero.global_rank
+
+            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
+            overlap_info.bucket_indices_seen.append(bucket_index)
+            if rank in assigned_ranks:
+                _perform_local_step(bucket, zero, rank)
+
+            _broadcast_bucket(bucket_index, zero)
+
+            num_buckets = len(overlap_info.params_per_bucket)
+            if len(overlap_info.bucket_indices_seen) == num_buckets:
+                # Ensure that all parameter updates are finished before the
+                # next forward pass
+                overlap_info.wait_for_broadcasts()
+                overlap_info.clear_per_iter_info()
+
+            return bucket.buffer()
+
+        return fut.then(zero_step)
+
+    return hook_with_zero_interleaved_fn
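+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): one
+# plausible way to wire ``hook_with_zero_step`` together with DDP and
+# ``ZeroRedundancyOptimizer``. The model, learning rate, and device handling are
+# hypothetical placeholders; an NCCL process group is assumed to be initialized.
+# ``hook_with_zero_step_interleaved`` can be substituted in the same position.
+def _example_overlap_ddp_with_zero(local_rank: int) -> None:
+    import torch.nn as nn
+    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
+        allreduce_hook,
+    )
+
+    ddp_model = DistributedDataParallel(
+        nn.Linear(32, 32).to(local_rank), device_ids=[local_rank]
+    )
+    # ``overlap_with_ddp=True`` is required; otherwise the hook raises ValueError.
+    zero = ZeroRedundancyOptimizer(
+        ddp_model.parameters(),
+        optimizer_class=torch.optim.SGD,
+        overlap_with_ddp=True,
+        lr=0.01,
+    )
+    # ``allreduce_hook`` accepts a ``None`` state (it falls back to the default
+    # process group), so no extra hook state is needed here.
+    ddp_model.register_comm_hook(
+        None, hook_with_zero_step(allreduce_hook, ddp_model, zero)
+    )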
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..79673f5e297c0c1053361c09817c447cbd54b439
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
@@ -0,0 +1,28 @@
+from typing import Any
+
+import torch
+from torch.distributed import GradBucket
+
+__all__ = ["noop_hook"]
+
+
+def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
+    """
+    Return a future that wraps the input, so it is a no-op that does not incur any communication overheads.
+
+    This hook should **only** be used for headroom analysis of allreduce optimization,
+    instead of the normal gradient synchronization.
+    For example, if less than a 10% speedup in training time is observed after this hook is registered,
+    it usually implies that allreduce is not a performance bottleneck for this case.
+    Such instrumentation can be particularly useful
+    if GPU traces cannot be easily retrieved or the trace analysis is complicated
+    by factors such as the overlap between allreduce and computation or the desynchronization across ranks.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(None, noop_hook)
+    """
+    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+    fut.set_result(bucket.buffer())
+
+    return fut
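+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): the
+# kind of headroom analysis the docstring above refers to. ``ddp_model``,
+# ``loss_fn``, and ``data`` are hypothetical placeholders; in a real experiment
+# the baseline (no hook) and the no-op run are two separate runs, since a comm
+# hook can only be registered once per DDP instance.
+def _example_headroom_analysis(ddp_model, loss_fn, data, iters: int = 20) -> float:
+    import time
+
+    ddp_model.register_comm_hook(None, noop_hook)
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    start = time.perf_counter()
+    for inputs, targets in list(data)[:iters]:
+        ddp_model.zero_grad()
+        loss_fn(ddp_model(inputs), targets).backward()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    # Compare against the same loop without the hook; a small gap suggests that
+    # allreduce is not the bottleneck.
+    return time.perf_counter() - start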
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f5dfb185cd22e6e8537f45da3f2c0c42a93fb1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -0,0 +1,223 @@
+from typing import Any, Callable, cast, Tuple
+
+import torch
+import torch.distributed as dist
+
+__all__ = [
+    "allreduce_hook",
+    "fp16_compress_hook",
+    "bf16_compress_hook",
+    "fp16_compress_wrapper",
+    "bf16_compress_wrapper",
+]
+
+
+def _allreduce_fut(
+    process_group: dist.ProcessGroup, tensor: torch.Tensor
+) -> torch.futures.Future[torch.Tensor]:
+    """Average the input gradient tensor by allreduce and returns a future."""
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+
+    # Apply the division first to avoid overflow, especially for FP16.
+    tensor.div_(group_to_use.size())
+
+    return (
+        dist.all_reduce(tensor, group=group_to_use, async_op=True)
+        .get_future()
+        .then(lambda fut: fut.value()[0])
+    )
+
+
+def allreduce_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Call ``allreduce`` using ``GradBucket`` tensors.
+
+    Once gradient tensors are aggregated across all workers, its ``then``
+    callback takes the mean and returns the result.
+
+    If a user registers this DDP communication hook,
+    the DDP results are expected to be the same as when no hook is registered.
+    Hence, this does not change the behavior of DDP, and a user can use it as a
+    reference or modify it to log useful information or for other purposes,
+    without affecting DDP behavior.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
+    """
+    return _allreduce_fut(process_group, bucket.buffer())
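+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): the
+# docstring above suggests using ``allreduce_hook`` as a reference for custom
+# hooks. This wrapper only logs the bucket size and then delegates to
+# ``allreduce_hook`` unchanged, so DDP results stay the same as with no hook.
+def _example_logging_allreduce_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    import logging
+
+    logging.getLogger(__name__).info(
+        "allreducing bucket %d with %d elements",
+        bucket.index(),
+        bucket.buffer().numel(),
+    )
+    return allreduce_hook(process_group, bucket)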
+
+
+def fp16_compress_hook(
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Compress by casting the ``GradBucket`` tensor to ``torch.float16`` and dividing it by the process group size.
+
+    This DDP communication hook implements a simple gradient compression
+    approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
+    and then divides it by the process group size.
+    It allreduces those ``float16`` gradient tensors. Once compressed gradient
+    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    buffer = (
+        cast(Tuple[torch.Tensor, ...], bucket)[0]
+        if isinstance(bucket, tuple)
+        else bucket.buffer()
+    )
+    compressed_tensor = buffer.to(torch.float16).div_(world_size)
+
+    def decompress(fut):
+        decompressed_tensor = buffer
+        # Decompress in place to reduce the peak memory.
+        # See: https://github.com/pytorch/pytorch/issues/45968
+        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
+        decompressed_tensor.copy_(value)
+        return decompressed_tensor
+
+    if torch._utils.is_compiling():
+        grad = dist._functional_collectives.all_reduce(
+            compressed_tensor, "sum", group_to_use
+        )
+        return decompress(grad)
+    else:
+        fut = dist.all_reduce(
+            compressed_tensor, group=group_to_use, async_op=True
+        ).get_future()
+        return fut.then(decompress)
+
+
+# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
+def bf16_compress_hook(
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
+
+    This DDP communication hook implements a simple gradient compression
+    approach that casts the ``GradBucket`` tensor to half-precision
+    Brain floating point format (``torch.bfloat16``)
+    and then divides it by the process group size.
+    It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient
+    tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    buffer = (
+        cast(Tuple[torch.Tensor, ...], bucket)[0]
+        if isinstance(bucket, tuple)
+        else bucket.buffer()
+    )
+    compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)
+
+    def decompress(fut):
+        decompressed_tensor = buffer
+        # Decompress in place to reduce the peak memory.
+        # See: https://github.com/pytorch/pytorch/issues/45968
+        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
+        decompressed_tensor.copy_(value)
+        return decompressed_tensor
+
+    if torch._utils.is_compiling():
+        grad = dist._functional_collectives.all_reduce(
+            compressed_tensor, "sum", group_to_use
+        )
+        return decompress(grad)
+    else:
+        fut = dist.all_reduce(
+            compressed_tensor, group=group_to_use, async_op=True
+        ).get_future()
+        return fut.then(decompress)
+
+
+def fp16_compress_wrapper(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    """
+    Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
+
+    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
+    floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to
+    the input data type, such as ``float32``.
+    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
+    """
+
+    def fp16_compress_wrapper_hook(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Cast bucket tensor to FP16.
+        bucket.set_buffer(bucket.buffer().to(torch.float16))
+
+        fut = hook(hook_state, bucket)
+
+        def decompress(fut):
+            decompressed_tensor = bucket.buffer()
+            # Decompress in place to reduce the peak memory.
+            # See: https://github.com/pytorch/pytorch/issues/45968
+            decompressed_tensor.copy_(fut.value())
+            return decompressed_tensor
+
+        # Decompress after hook has run.
+        return fut.then(decompress)
+
+    return fp16_compress_wrapper_hook
+
+
+def bf16_compress_wrapper(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    """
+    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
+
+    This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
+    Brain floating point format (``torch.bfloat16``),
+    and casts the resulting tensor of the given hook back to the input data type, such as ``float32``.
+
+    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
+    """
+
+    def bf16_compress_wrapper_hook(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Cast bucket tensor to BF16.
+        bucket.set_buffer(bucket.buffer().to(torch.bfloat16))
+
+        fut = hook(hook_state, bucket)
+
+        def decompress(fut):
+            decompressed_tensor = bucket.buffer()
+            # Decompress in place to reduce the peak memory.
+            # See: https://github.com/pytorch/pytorch/issues/45968
+            decompressed_tensor.copy_(fut.value())
+            return decompressed_tensor
+
+        # Decompress after hook has run.
+        return fut.then(decompress)
+
+    return bf16_compress_wrapper_hook
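+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): it
+# restates the equivalence noted above, where registering ``fp16_compress_hook``
+# directly and registering ``fp16_compress_wrapper(allreduce_hook)`` compress
+# gradients the same way. ``ddp_model`` and ``process_group`` are placeholders.
+def _example_equivalent_fp16_registrations(ddp_model, process_group, use_wrapper: bool) -> None:
+    if use_wrapper:
+        # Wrapper form: compress any inner hook; with the plain allreduce as the
+        # inner hook, this matches ``fp16_compress_hook``.
+        ddp_model.register_comm_hook(process_group, fp16_compress_wrapper(allreduce_hook))
+    else:
+        ddp_model.register_comm_hook(process_group, fp16_compress_hook)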
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..253a902e004f90bc3f1d7232e4efb250f8b207f4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py
@@ -0,0 +1,85 @@
+import torch
+import torch.distributed as dist
+from torch.autograd import Variable
+
+from dataclasses import dataclass
+from typing import Any, no_type_check
+from torch.distributed.utils import _free_storage
+
+@dataclass
+class _AllreduceUpcastHookState:
+    """
+    State to manage DDP mixed precision in backward / gradient communication.
+
+    This contains a weakref to the DDP module for access to reducer and process
+    group, and a stream to run parameter and gradient upcasts.
+    """
+
+    ddp_weakref: Any
+    upcast_stream: torch.cuda.Stream
+    wait_for_stream_enqueued: bool = False
+
+@no_type_check
+def _reducer_allreduce_and_upcast_hook(
+    hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Perform allreduce in precision ``reduce_dtype``, upcast to prepare for optimizer.
+
+    Performs allreduce in the reduced precision given by DDP's mixed precision
+    reduce_dtype, and upcasts parameters and gradients to fp32 in preparation
+    to run the optimizer.
+    """
+    ddp_weakref = hook_state.ddp_weakref
+    reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group
+    gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view
+    # Cast bucket if different than param_dtype.
+    if (
+        ddp_weakref().mixed_precision.param_dtype != ddp_weakref().mixed_precision.reduce_dtype
+    ):
+        # Cast bucket tensor to reduce_dtype
+        bucket.set_buffer(bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype))
+    fut = reducer._run_allreduce_hook(bucket)
+    ret_fut = torch.futures.Future()
+    stream = hook_state.upcast_stream
+    with torch.cuda.stream(stream):
+        fut.wait()
+        bucket.buffer().div_(process_group.size())
+        ret_fut.set_result(bucket.buffer())
+
+        # Upcast parameters and gradients so optimizer step can run in fp32.
+        params, grads = bucket.parameters(), bucket.gradients()
+        for p, g in zip(params, grads):
+            p.data = p._fp_param
+            # free storage for mp param as it will be allocated again in next
+            # forward pass.
+            _free_storage(p._mp_param)
+            p.grad.data = p.grad.to(p.data.dtype)
+
+    # enqueue a callback to wait for this stream at end of backward
+    def wait_for_stream_cb():
+        torch.cuda.current_stream().wait_stream(stream)
+        # Remove post-backward hooks since they are re-installed in next
+        # iteration, similar to FSDP.
+        # Parameters that don't require grad still need to be cast since
+        # they may participate in computation. However, they would not be recast
+        # by the hook above as they don't have a grad hook installed, so cast them
+        # back here.
+        for n, p in ddp_weakref().module.named_parameters():
+            if hasattr(p, '_ddp_mp_hook_state'):
+                p._ddp_mp_hook_state[1].remove()
+                delattr(p, '_ddp_mp_hook_state')
+            if not p.requires_grad and not hasattr(p, '_ddp_ignored'):
+                p.data = p._fp_param
+
+        # reset for next backward pass
+        hook_state.wait_for_stream_enqueued = False
+
+    if not hook_state.wait_for_stream_enqueued:
+        Variable._execution_engine.queue_callback(
+            wait_for_stream_cb
+        )
+        # mark that the callback is enqueued
+        hook_state.wait_for_stream_enqueued = True
+
+    return ret_fut
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e62f70122b83e6b7c3481b08d7888282b799955
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py
@@ -0,0 +1,154 @@
+from typing import Any, Callable, List, no_type_check
+
+import torch
+import torch.distributed as dist
+from torch.autograd import Variable
+from functools import partial
+from dataclasses import dataclass
+
+__all__: List[str] = []
+
+_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
+
+class _OptimizerHookState:
+    """
+    Holds state for running optimizer in-line after DDP communication hook.
+
+    Currently contains only the functional optimizer, which must have a ``step_param`` method.
+    """
+
+    __slots__ = ["functional_optimizer", "params_to_optimize"]
+
+    def __init__(self, functional_optim, params=None):
+        self.functional_optimizer = functional_optim
+        self._check_valid_functional_optim()
+        self._set_params_to_optimize(params)
+
+    def _set_params_to_optimize(self, params):
+        if params is not None:
+            self.params_to_optimize = set(params)
+
+    def _check_valid_functional_optim(self):
+        if not hasattr(self.functional_optimizer, _FUNCTIONAL_OPTIM_STEP_METHOD_NAME):
+            raise ValueError(
+                f"Class {type(self.functional_optimizer)} must implement method "
+                f"{_FUNCTIONAL_OPTIM_STEP_METHOD_NAME}."
+            )
+
+
+@dataclass
+class _OptimInBackwardHookState:
+    optim_stream: torch.cuda.Stream
+    wait_for_optim_stream_enqueued: bool
+
+@no_type_check
+def _apply_optim_in_backward_hook(
+    gradient_is_bucket_view: bool
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Register hook to apply the optimizer in backward.
+
+    If torch.distributed.optim._apply_optimizer_in_backward is used to overlap
+    optimizer with backward pass, DDP will run the below hook to run optimizer
+    step for parameters after gradient communication has taken place.
+    """
+    optim_in_bwd_state = _OptimInBackwardHookState(
+        optim_stream=torch.cuda.Stream(),
+        wait_for_optim_stream_enqueued=False,
+    )
+
+    def apply_optim_in_backward_hook(
+        hook_state: Any, bucket: dist.GradBucket, optim_stream_state,
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Run original hook
+        ddp_weakref = hook_state
+        ddp_inst = ddp_weakref()
+        reducer, process_group = ddp_inst.reducer, ddp_inst.process_group
+        fut = reducer._run_allreduce_hook(bucket)
+        optimizer_stream = optim_stream_state.optim_stream
+        with torch.cuda.stream(optimizer_stream):
+            fut.wait()
+            # Apply gradient division since C++ side only allreduces and does
+            # not average. TODO: (rohan-varma) the div factor may be different
+            # when running with join hook
+            bucket.buffer().div_(process_group.size())
+            model_params = bucket.parameters()
+            grads = bucket.gradients()
+            # TODO (rohan-varma): upcast as needed for DDP mixed precision,
+            # once optimizer in backward + DDP mixed precision is supported.
+            for p, g in zip(model_params, grads):
+                if hasattr(p, '_in_backward_optimizers'):
+                    # Note: need to set grad to the bucket's grad, because
+                    # running allreduce results in the bucket's grad being
+                    # reduced, but not grad field.
+                    if not gradient_is_bucket_view:
+                        p.grad = g
+                    for optim in p._in_backward_optimizers:
+                        optim.step()
+
+        # Need to return a Future[Tensor] to obey comm hook API contract.
+        ret_fut = torch.futures.Future()
+        ret_fut.set_result(bucket.buffer())
+
+        # enqueue a callback to wait for this optimizer stream at the end of
+        # backward and set all DDP managed grads to None.
+        def wait_for_optim_stream_callback():
+            torch.cuda.current_stream().wait_stream(
+                optim_stream_state.optim_stream
+            )
+            # Set DDP managed grads to None
+            for param in ddp_inst._get_data_parallel_params(ddp_inst.module):
+                if hasattr(param, '_in_backward_optimizers'):
+                    param.grad = None
+
+            # reset for the next backwards pass
+            optim_stream_state.wait_for_optim_stream_enqueued = False
+
+        if not optim_stream_state.wait_for_optim_stream_enqueued:
+            Variable._execution_engine.queue_callback(
+                wait_for_optim_stream_callback
+            )
+            # mark that the callback is enqueued
+            optim_stream_state.wait_for_optim_stream_enqueued = True
+
+        return ret_fut
+
+    comm_hook = partial(
+        apply_optim_in_backward_hook, optim_stream_state=optim_in_bwd_state
+    )
+    # These are needed for DDP's logging of comm hooks
+    comm_hook.__name__ = apply_optim_in_backward_hook.__name__
+    comm_hook.__qualname__ = apply_optim_in_backward_hook.__qualname__
+
+    return comm_hook
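+
+
+# Illustrative sketch (documentation aid, not part of the upstream module):
+# ``_apply_optim_in_backward_hook`` is installed by DDP itself when parameters
+# carry in-backward optimizers, so users do not register it directly. The
+# sketch assumes the private helper
+# ``torch.distributed.optim._apply_optimizer_in_backward(optimizer_class,
+# params, optimizer_kwargs)`` available in this snapshot; the model and
+# hyperparameters are hypothetical placeholders.
+def _example_optimizer_in_backward(local_rank: int) -> None:
+    import torch.nn as nn
+    from torch.distributed.optim import _apply_optimizer_in_backward
+    from torch.nn.parallel import DistributedDataParallel as DDP
+
+    model = nn.Linear(8, 8).to(local_rank)
+    # Attach a per-parameter optimizer that runs during backward; DDP then
+    # wires in the hook defined above automatically.
+    _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})
+    ddp_model = DDP(model, device_ids=[local_rank])  # noqa: F841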
+
+def _hook_then_optimizer(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
+    optimizer_state: _OptimizerHookState,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""Run optimizer in a functional fashion after DDP communication hook."""
+    has_set_params = (
+        hasattr(optimizer_state, 'params_to_optimize')
+        and optimizer_state.params_to_optimize is not None
+    )
+
+    def hook_then_optimizer_wrapper(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Run original hook
+        fut = hook(hook_state, bucket)
+
+        def optimizer_step(fut):
+            gradient_tensors = bucket.gradients()
+            model_params = bucket.parameters()
+            for grad_tensor, model_param in zip(gradient_tensors, model_params):
+                if not has_set_params or model_param in optimizer_state.params_to_optimize:
+                    optimizer_state.functional_optimizer.step_param(
+                        model_param,
+                        grad_tensor,
+                    )
+            return bucket.buffer()
+
+        return fut.then(optimizer_step)
+
+    return hook_then_optimizer_wrapper
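+
+
+# Illustrative sketch (documentation aid, not part of the upstream module):
+# exercising the private helpers above. ``_ExampleFunctionalSGD`` is a
+# hypothetical stand-in for a functional optimizer; the only contract
+# ``_OptimizerHookState`` checks is a ``step_param`` method. ``ddp_model`` and
+# ``process_group`` are placeholders.
+class _ExampleFunctionalSGD:
+    def __init__(self, lr: float = 0.01) -> None:
+        self.lr = lr
+
+    def step_param(self, param: torch.Tensor, grad: torch.Tensor) -> None:
+        # Plain in-place SGD update using the bucket-provided gradient.
+        if grad is not None:
+            param.data.add_(grad, alpha=-self.lr)
+
+
+def _example_hook_then_optimizer(ddp_model, process_group) -> None:
+    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
+        allreduce_hook,
+    )
+
+    optim_state = _OptimizerHookState(_ExampleFunctionalSGD(lr=0.01))
+    ddp_model.register_comm_hook(
+        process_group, _hook_then_optimizer(allreduce_hook, optim_state)
+    )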
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ca0d2868095307a10669168379c109d5fe2990
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
@@ -0,0 +1,123 @@
+import logging
+
+import torch
+import torch.distributed as dist
+
+from . import default_hooks as default
+
+logger = logging.getLogger(__name__)
+
+
+class PostLocalSGDState:
+    r"""
+    Store state for all-reducing gradients globally until given step, then locally after.
+
+    Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
+    and all-reducing gradients locally using ``subgroup`` afterwards.
+
+    If ``process_group`` is ``None``, the global process group will be used.
+    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.
+
+    Additionally, ``post_local_gradient_allreduce`` may be worth tuning,
+    because either value may lead to faster convergence depending on the workload.
+    """
+
+    __slots__ = [
+        "process_group",
+        "subgroup",
+        "start_localSGD_iter",
+        "post_local_gradient_allreduce",
+        "iter",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        subgroup,
+        start_localSGD_iter,
+        post_local_gradient_allreduce=True,
+    ):
+        """Initialize state object with given parameters and log when localSGD start."""
+        logger.info(
+            "Local SGD will be started after %s iterations", start_localSGD_iter
+        )
+
+        # The group used for all-reducing gradients globally.
+        self.process_group = process_group
+        # The group used for all-reducing gradients locally.
+        self.subgroup = subgroup
+        self.start_localSGD_iter = start_localSGD_iter
+        # Allreduce gradients locally starting from iteration `start_localSGD_iter`.
+        # This may help with the convergence efficiency at the cost of relatively cheap intra-subgroup communication.
+        self.post_local_gradient_allreduce = post_local_gradient_allreduce
+        # Iteration/step in the training loop.
+        self.iter = 0
+
+    def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of local SGD."""
+        # Since bucket 0 is the last bucket to be allreduced in an iteration,
+        # only increase `iter` when bucket 0 is processed.
+        if bucket.is_last():
+            self.iter += 1
+
+        if self.iter == self.start_localSGD_iter:
+            logger.info(
+                "Start to apply local SGD after %s iterations.", self.iter
+            )
+
+def post_localSGD_hook(
+    state: PostLocalSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Run post-localSGD algorithm.
+
+    This DDP communication hook is used to run the post-localSGD algorithm,
+    in combination with a model averaging component (e.g.,
+    :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
+    that runs after the optimizer step.
+
+    Args:
+        state (PostLocalSGDState): State information to run post-localSGD.
+            Users mainly need to tune ``start_localSGD_iter`` to determine when to start local SGD.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PostLocalSGDState(process_group=process_group, subgroup=subgroup,
+                                  start_localSGD_iter=10)
+        >>> ddp_model.register_comm_hook(state, post_localSGD_hook)
+        >>> # Also need to establish a model averaging module and run model averaging after ``optimizer.step()``.
+        >>> # Please refer to the examples in ``torch.distributed.algorithms.model_averaging.averagers`` module.
+    """
+    global_group_to_use = (
+        state.process_group if state.process_group is not None else dist.group.WORLD
+    )
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
+    if state.iter < state.start_localSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(global_group_to_use, input_tensor)
+
+    # If `post_local_gradient_allreduce` is not set,
+    # then no gradient synchronization is performed after the first `start_localSGD_iter` iterations.
+    if not state.post_local_gradient_allreduce:
+        fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+        fut.set_result(input_tensor)
+        return fut
+
+    # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
+    # Note that by default, a separate subgroup for each node is created which
+    # causes an intra-node allreduce to be done at each training step.
+    # From this moment, model averaging should run after the optimizer step,
+    # to globally allreduce all the parameters.
+    if state.subgroup is None:
+        state.subgroup, _ = dist.new_subgroups()
+    return default._allreduce_fut(state.subgroup, input_tensor)
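+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): one
+# plausible way to combine the pieces referenced in the docstring above. It
+# assumes ``PeriodicModelAverager`` takes ``period``/``warmup_steps`` arguments
+# as in this snapshot; the model, optimizer, loss, and data loop are
+# hypothetical placeholders.
+def _example_post_localSGD_training(ddp_model, optimizer, loss_fn, data) -> None:
+    from torch.distributed.algorithms.model_averaging.averagers import (
+        PeriodicModelAverager,
+    )
+
+    state = PostLocalSGDState(
+        process_group=None, subgroup=None, start_localSGD_iter=100
+    )
+    ddp_model.register_comm_hook(state, post_localSGD_hook)
+    # Average parameters across the global group every 4 steps once the
+    # warm-up (global allreduce) phase has ended.
+    averager = PeriodicModelAverager(period=4, warmup_steps=100)
+
+    for inputs, targets in data:
+        optimizer.zero_grad()
+        loss_fn(ddp_model(inputs), targets).backward()
+        optimizer.step()
+        averager.average_parameters(ddp_model.parameters())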
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..364cccec56fe28a0bb4451fd0d4f3a1565f49d10
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -0,0 +1,850 @@
+from collections import defaultdict
+import logging
+import math
+from typing import Dict
+
+import torch
+import torch.distributed as dist
+
+from . import default_hooks as default
+from torch.distributed import distributed_c10d
+
+__all__ = [
+    "PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"
+]
+
+logger = logging.getLogger(__name__)
+
+
+def _orthogonalize(matrices, epsilon=0):
+    """
+    Decide between Gram-Schmidt and QR factorization to orthogonalize a batch of matrices.
+
+    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
+    """
+    assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
+
+    num_matrices = matrices.shape[0]
+    rank = matrices.shape[2]
+    dtype = matrices.dtype
+    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
+        _orthogonalize_gram_schmidt(matrices, epsilon=epsilon)
+    else:
+        torch.linalg.qr(
+            matrices,
+            out=(
+                matrices,
+                torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype)
+            )
+        )
+
+def _orthogonalize_gram_schmidt(matrices, epsilon=0):
+    """
+    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
+
+    If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`.
+    """
+    num_cols = matrices.shape[2]
+    for i in range(num_cols):
+        # Normalize the i'th column.
+        col = matrices[:, :, i : i + 1]
+        # If no epsilon is added here, division by zero may be caused by vanishing gradients.
+        # This epsilon is not needed if the input batch of matrices covers the gradients of at least one entire layer
+        # in the neural network.
+        if epsilon == 0:
+            # Note that col ** 2 can underflow/overflow if we use FP16.
+            # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead.
+            try:
+                col /= torch.norm(col, dim=1, keepdim=True)
+            except ZeroDivisionError:
+                logger.error(
+                    "The matrices to be orthogonalized has at least a column of all 0s. Please set a small value such as 1e-8 "
+                    "as `orthogonalization_epsilon` in PowerSGD state."
+                )
+                # Recover the values from NaNs to 0s.
+                col.fill_(0.0)
+        else:
+            col /= torch.norm(col, dim=1, keepdim=True) + epsilon
+        # Project it on the rest and remove it.
+        if i + 1 < num_cols:
+            rest = matrices[:, :, i + 1 :]
+            rest -= torch.sum(col * rest, dim=1, keepdim=True) * col
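+
+
+# Illustrative sketch (documentation aid, not part of the upstream module): a
+# quick sanity check of the Gram-Schmidt routine above. After orthogonalization,
+# each matrix Q in the batch should satisfy Q^T Q ~= I up to numerical tolerance.
+def _example_check_orthogonalization() -> bool:
+    matrices = torch.randn(4, 64, 8)  # batch of 4 matrices with 64 rows, rank 8
+    _orthogonalize_gram_schmidt(matrices)
+    identity = torch.eye(8).expand(4, 8, 8)
+    return bool(
+        torch.allclose(matrices.transpose(1, 2) @ matrices, identity, atol=1e-4)
+    )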
+
+
+def _should_compress(
+    num_rows, num_cols, matrix_approximation_rank, min_compression_rate
+):
+    """
+    Recommend whether the given tensor is worth compressing.
+
+    Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
+    including statistics describing the expected savings from compression.  We consider a tensor worth
+    compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
+    uncompressed size = ``num_rows`` * ``num_cols``,
+    and compressed size = (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
+
+    The result of this function is a tuple of the form (compression_recommendation, uncompressed_el_count, compressed_el_count), where:
+
+    compression_recommendation is true if the tensor is worth compressing, and false otherwise (see above);
+
+    uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and,
+
+    compressed_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``.
+    """  # noqa: B950
+    uncompressed_size = num_rows * num_cols
+    compressed_size = (num_rows + num_cols) * matrix_approximation_rank
+    return (
+        compressed_size * min_compression_rate < uncompressed_size,
+        uncompressed_size,
+        compressed_size,
+    )
+
+
+def _report_compression_stats(bucket, state):
+    """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
+    if (
+        bucket.is_last()
+        and state.iter >= state.next_stats_report
+    ):
+        stats = state.compression_stats()
+        logger.info(
+            "Compression stats: iter %s, total before compression %s, total after compression %s, "
+            "rate %s", state.iter, stats[1], stats[2], stats[0]
+        )
+        state.next_stats_report = state.iter + state.compression_stats_logging_frequency
+
+
+class PowerSGDState:
+    r"""
+    Store both the algorithm's hyperparameters and internal state for all gradients during training.
+
+    Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
+    For performance, we suggest keeping the binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.
+
+    1. ``matrix_approximation_rank`` controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.
+
+        1.1. If ``matrix_approximation_rank`` is too low, the full model quality will need more training steps to reach, or will never be reached, yielding a loss in accuracy.
+
+        1.2. Increasing ``matrix_approximation_rank`` can substantially increase the computation costs of the compression, and the accuracy may not improve further beyond a certain ``matrix_approximation_rank`` threshold.
+
+    To tune ``matrix_approximation_rank``, we suggest starting from 1 and increasing by factors of 2 (like an exponential grid search: 1, 2, 4, ...), until a satisfactory accuracy is reached. Typically only a small value in 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
+
+    2. ``start_powerSGD_iter`` defers PowerSGD compression until step ``start_powerSGD_iter``, and vanilla allreduce runs prior to step ``start_powerSGD_iter``. This hybrid scheme of **vanilla allreduce + PowerSGD** can effectively improve the accuracy, even when a relatively small ``matrix_approximation_rank`` is used. This is because the beginning of the training phase is usually very sensitive to inaccurate gradients, and compressing gradients too early may make the training quickly take a suboptimal trajectory, which can result in an irrecoverable impact on the accuracy.
+
+    To tune ``start_powerSGD_iter``, we suggest starting with 10% of the total training steps and increasing it until a satisfactory accuracy is reached. If there is a warm-up stage in the training, ``start_powerSGD_iter`` typically should be no less than the number of warm-up steps.
+
+    3. ``min_compression_rate`` is the minimum compression rate required when a layer is compressed. Due to the computation overheads incurred by the compression, a tensor is worth compressing only if there can be sufficient saving in bandwidth, where ``(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols``. If the specified compression rate threshold cannot be satisfied, the tensor will be directly allreduced without compression.
+
+    Compression statistics are logged every ``compression_stats_logging_frequency`` iterations once PowerSGD compression starts.
+
+    4. ``orthogonalization_epsilon`` can be a very small value (e.g., 1e-8) added to every normalized matrix column in orthogonalization step, to prevent div-by-zero error if any column has all 0s. If this can already be prevented (e.g., by batch normalization), an epsilon of 0 is recommended for accuracy.
+
+    5. ``batch_tensors_with_same_shape`` controls whether to compress and decompress tensors with the same shape in a batched operation to achieve higher parallelism. Note that you should also increase the bucket size (i.e., the ``bucket_cap_mb`` arg in the DDP constructor) to make more same-shaped tensors appear in the same bucket; however, this may reduce the overlap between computation and communication, and increase the memory footprint due to stacking the tensors of the same shape. Set to ``True`` if the compression / decompression computation is a bottleneck.
+
+    .. warning ::
+        If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2.
+        This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
+        and this can conflict with any tensor memorized before the rebuild process.
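+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> # A configuration sketch following the tuning guidance above (the values are
+        >>> # illustrative, not prescriptive): small rank, compression deferred past the
+        >>> # warm-up phase, error feedback and warm start kept on.
+        >>> state = PowerSGDState(
+        ...     process_group=None,  # the hooks fall back to the world process group
+        ...     matrix_approximation_rank=2,
+        ...     start_powerSGD_iter=1_000,
+        ...     min_compression_rate=2,
+        ...     use_error_feedback=True,
+        ...     warm_start=True,
+        ... )
+        >>> ddp_model.register_comm_hook(state, powerSGD_hook)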
+    """  # noqa: B950
+
+    __slots__ = [
+        "process_group",
+        # The fields below are the hyperparameters that often need to be tuned by the user.
+        "matrix_approximation_rank",
+        "start_powerSGD_iter",
+        # The fields below are the hyperparameters that seldom need to be tuned by the user.
+        "min_compression_rate",
+        "orthogonalization_epsilon",
+        # The fields below are the binary hyperparameters recommended to be turned on for performance and accuracy.
+        "use_error_feedback",
+        "warm_start",
+        "batch_tensors_with_same_shape",
+        # The fields below are internal state.
+        "rng",
+        "error_dict",
+        "p_memory_dict",
+        "q_memory_dict",
+        "iter",
+        # The fields below are for recording compression stats.
+        "total_numel_before_compression",
+        "total_numel_after_compression",
+        "compression_stats_logging_frequency",
+        "next_stats_report",
+    ]
+
+    def __init__(
+        self,
+        process_group,
+        matrix_approximation_rank=1,
+        start_powerSGD_iter=1_000,
+        min_compression_rate=2,
+        use_error_feedback=True,
+        warm_start=True,
+        orthogonalization_epsilon=0,
+        random_seed=0,
+        compression_stats_logging_frequency=10_000,
+        batch_tensors_with_same_shape: bool = False,
+    ):
+        logger.info(
+            "PowerSGD config: matrix_approximation_rank = %s; start_powerSGD_iter = %s; "
+            "min_compression_rate = %s; orthogonalization_epsilon = %s; use_error_feedback = %s; warm_start = %s; "
+            "random_seed = %s; compression_stats_logging_frequency = %s; batch_tensors_with_same_shape = %s",
+            matrix_approximation_rank,
+            start_powerSGD_iter,
+            min_compression_rate,
+            orthogonalization_epsilon,
+            use_error_feedback,
+            warm_start,
+            random_seed,
+            compression_stats_logging_frequency,
+            batch_tensors_with_same_shape,
+        )
+
+        self.process_group = process_group
+        self.matrix_approximation_rank = matrix_approximation_rank
+        # Deferring PowerSGD compression until step 'start_powerSGD_iter' can have two advantages:
+        # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss,
+        # even if the matrix approximation rank is increased to a large value.
+        # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce
+        # (or a more conservative compression such as FP16 compression) with PowerSGD.
+        # 2) There is an internal optimization of rebuilding buckets process in DDP,
+        # in order to save the memory space.
+        # This step takes place after the first iteration.
+        # However, this means that the shape of input bucketized tensors is subject to change,
+        # which will complicate the implementations of error feedback and warm-up.
+        # Running vanilla allreduce in the first few iterations can avoid this complexity.
+        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
+            raise ValueError(
+                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
+                "because PowerSGD can only be applied after the first two iterations in DDP."
+            )
+        self.start_powerSGD_iter = start_powerSGD_iter
+        self.min_compression_rate = min_compression_rate
+        # Error feedback is usually crucial for both convergence and generalization,
+        # because PowerSGD is a biased compressor,
+        # i.e., compressing and decompressing a random gradient does not yield the original in expectation.
+        # This mechanism requires a temporary copy of the input gradients,
+        # so it increases the peak memory consumption by the size of the gradient tensor.
+        # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank),
+        # sometimes it is possible to converge to the optima without error feedback.
+        # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf
+        self.use_error_feedback = use_error_feedback
+        # Warm-start reuses P(s) and Q(s) from the previous iteration.
+        # This can improve the approximation quality and hence improve the accuracy.
+        # Additionally, by avoiding the initialization of these low-rank tensors at every step,
+        # this can also accelerate training.
+        # However, this is at the cost of extra memory.
+        self.warm_start = warm_start
+        # Can use a very small value to prevent div-by-zero error caused by orthogonalization of vanishing gradients.
+        self.orthogonalization_epsilon = orthogonalization_epsilon
+        # The purpose of this RNG is to generate different random seeds for initializing Q across iterations,
+        # but in the same order for all the DDP replicas.
+        # Different random seeds across iterations indicate different 'projections' of the gradients at different SGD steps.
+        # If the same random projection is used,
+        # there will be differences between the gradients that are never synchronized.
+        import numpy as np
+        self.rng = np.random.RandomState(random_seed)
+        # Since there is only a single state instance for all the input buckets,
+        # need to maintain a dictionary that maps each bucket index to the local error.
+        self.error_dict: Dict[int, torch.Tensor] = {}
+        self.p_memory_dict: Dict[int, torch.Tensor] = {}
+        self.q_memory_dict: Dict[int, torch.Tensor] = {}
+        # Iteration/step in the training loop.
+        self.iter = 0
+        # Compression stats accumulators
+        self.total_numel_before_compression = 0
+        self.total_numel_after_compression = 0
+        # We'll report compression stats every 'compression_stats_logging_frequency' iterations
+        # Note that we always report compression stats at least once.
+        self.compression_stats_logging_frequency = max(
+            1, compression_stats_logging_frequency
+        )
+        self.next_stats_report = 0
+        # Batching tensors with same shape can increase parallelism in compression / decompression computation.
+        # This requires a larger bucket size to make more same-shaped tensors appear in one bucket; however,
+        # this may reduce the overlap between computation and communication, and increase the memory footprint
+        # due to stacking tensors.
+        # Turn on if compression / decompression computation is a bottleneck.
+        self.batch_tensors_with_same_shape = batch_tensors_with_same_shape
+
+    def __getstate__(self):
+        r"""
+        Return a ``Dict[str, Any]`` which will be pickled and saved.
+
+        ``process_group`` is not serializable and is excluded from
+        the returned state.
+        """
+        logger.warning(
+            "NOTE: Process group is not serializable and excluded from a saved state."
+        )
+        return {
+            slot: getattr(self, slot)
+            for slot in self.__slots__ if slot != "process_group"
+        }
+
+    def __setstate__(self, state):
+        r"""
+        Take a provided ``state`` and set it on this ``PowerSGDState`` instance.
+
+        ``process_group`` is set to the default group.
+        """
+        self.process_group = distributed_c10d._get_default_group()
+        logger.warning(
+            "NOTE: Process group will be set to a default group (i.e. the world size).\
+                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
+        )
+        for slot, value in state.items():
+            setattr(self, slot, value)
+
+    def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of local SGD."""
+        # Since bucket 0 is the last bucket to allreduce in an iteration.
+        # Only increase `iter` when bucket 0 is processed.
+        if bucket.is_last():
+            self.iter += 1
+
+        if self.iter == self.start_powerSGD_iter:
+            logger.info(
+                "Start to apply PowerSGD after %s iterations.", self.iter
+            )
+
+    def compression_stats(self):
+        r"""
+        Return the latest compression statistics as a tuple.
+
+        Returns a tuple of the form (compress_rate, numel_before_compression, numel_after_compression), where:
+
+        compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression);
+
+        numel_before_compression is the total number of elements before compression was applied; and,
+
+        numel_after_compression is the total number of elements after compression was applied.
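+
+        Example::
+            >>> # xdoctest: +SKIP
+            >>> # Illustrative usage: unpack the running statistics once PowerSGD has started.
+            >>> compress_rate, numel_before, numel_after = state.compression_stats()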
+        """  # noqa: B950
+        compress_rate = (
+            self.total_numel_before_compression / self.total_numel_after_compression
+            if self.total_numel_after_compression > 0
+            else 0
+        )
+        return (
+            compress_rate,
+            self.total_numel_before_compression,
+            self.total_numel_after_compression,
+        )
+
+
+def powerSGD_hook(
+    state: PowerSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    r"""
+    Implement PowerSGD algorithm.
+
+    This DDP communication hook implements PowerSGD gradient compression
+    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
+    Once gradient tensors are aggregated across all workers, this hook applies
+    compression as follows:
+
+    1. Views the input flattened 1D gradient tensor as a list of per-parameter tensors, and divides all the tensors into two groups:
+
+        1.1 The tensors that should be compressed before allreduce, because the compression can give enough savings in bandwidth.
+
+        1.2 The rest of the tensors will be allreduced directly without compression, including all the vector tensors (for biases).
+
+    2. Handles uncompressed tensors:
+
+        2.1. Allocates contiguous memory for those uncompressed tensors, and allreduces them as a batch, without compression;
+
+        2.2. Copies the individual uncompressed tensors from the contiguous memory back to the input tensor.
+
+    3. Handles the tensors that should be compressed by PowerSGD compression:
+
+        3.1. For each tensor M, creates two low-rank tensors P and Q for decomposing M,
+        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
+
+        3.2. Computes each P in Ps, which is equal to MQ;
+
+        3.3. Allreduces Ps as a batch;
+
+        3.4. Orthogonalizes each P in Ps;
+
+        3.5. Computes each Q in Qs, which is approximately equal to M^TP;
+
+        3.6. Allreduces Qs as a batch;
+
+        3.7. Computes each M among all the compressed tensors, which is approximately equal to PQ^T.
+
+    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+    This not only gives the user more control over the tradeoff between speedup and accuracy,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
+    Args:
+        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
+            and ``min_compression_rate``.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1,
+                                  start_powerSGD_iter=10, min_compression_rate=0.5)
+        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
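+        >>>
+        >>> # Conceptual single-process sketch of steps 3.1-3.7 above for one matrix M
+        >>> # (communication omitted; illustrative only, not the hook's actual code path):
+        >>> M = torch.randn(64, 32)
+        >>> Q = torch.randn(32, 1)          # matrix_approximation_rank = 1
+        >>> _orthogonalize(Q.unsqueeze(0))  # 3.1: orthogonalize the random Q
+        >>> P = M @ Q                       # 3.2: P = MQ (3.3: allreduce Ps)
+        >>> _orthogonalize(P.unsqueeze(0))  # 3.4: orthogonalize P
+        >>> Q = M.t() @ P                   # 3.5: Q ~= M^T P (3.6: allreduce Qs)
+        >>> M_approx = P @ Q.t()            # 3.7: M ~= P Q^T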
+    """  # noqa: B950
+    process_group = state.process_group
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
+    device = input_tensor.device
+    dtype = input_tensor.dtype
+
+    # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.index()
+    input_tensor_cp = None
+    total_length = input_tensor.shape[0]
+    if state.use_error_feedback:
+        if bucket_index in state.error_dict:
+            input_tensor.add_(state.error_dict[bucket_index])
+        else:
+            logger.info(
+                "A zero tensor of length %s that represents local error is created.",
+                total_length
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                total_length, device=device, dtype=dtype
+            )
+
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = torch.clone(input_tensor).detach()
+
+    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
+    tensors = bucket.gradients()
+
+    # Step I: Divide all the tensors into two groups,
+    # one will be compressed before allreduce and the other will be directly allreduced without compression.
+    tensors_to_compress, uncompressed_tensors = [], []
+    total_Ps_size = 0
+    total_Qs_size = 0
+    for tensor in tensors:
+        matrix = tensor.view(tensor.shape[0], -1)
+        n, m = matrix.shape
+        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
+        compress_test = _should_compress(
+            n, m, matrix_approximation_rank, state.min_compression_rate
+        )
+        state.total_numel_before_compression += compress_test[1]
+        if compress_test[0]:
+            tensors_to_compress.append(matrix)
+            total_Ps_size += n * matrix_approximation_rank
+            total_Qs_size += m * matrix_approximation_rank
+            state.total_numel_after_compression += compress_test[2]
+        else:
+            uncompressed_tensors.append(tensor)
+            state.total_numel_after_compression += compress_test[1]
+
+    _report_compression_stats(bucket, state)
+
+    # Step II: Handle uncompressed tensors.
+    # Allocate contiguous memory for these tensors to allreduce efficiently.
+    uncompressed_tensors_memory = (
+        torch.cat([tensor.view(-1) for tensor in uncompressed_tensors])
+        if uncompressed_tensors
+        else torch.tensor([], device=device, dtype=dtype)
+    )
+
+    # Step III: Handle the tensors that should be compressed.
+    # Allocate contiguous memory for Ps and Qs to allreduce efficiently.
+    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
+    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
+    need_randomize_qs = False
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
+        need_randomize_qs = True
+        # If warm-start is disabled, low-rank tensors will be initialized at every step.
+        # Only log this if warm-start is enabled, to avoid spamming.
+        if state.warm_start:
+            logger.info(
+                "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.",
+                total_Ps_size, total_Qs_size
+            )
+        state.p_memory_dict[bucket_index] = torch.empty(
+            total_Ps_size, device=device, dtype=dtype
+        )
+        state.q_memory_dict[bucket_index] = torch.empty(
+            total_Qs_size, device=device, dtype=dtype
+        )
+
+    # Batch tensors to compress by shape.
+    shape_to_tensors = defaultdict(list)
+    for tensor in tensors_to_compress:
+        shape_to_tensors[tensor.shape].append(tensor)
+
+    # This function decides whether to batch tensors with the same shape according to `state.batch_tensors_with_same_shape`,
+    # so that the following process can share the same code.
+    def maybe_batched_tensors_to_compress():
+        for tensors in shape_to_tensors.values():
+            if state.batch_tensors_with_same_shape:
+                batch_size = len(tensors)
+                if batch_size == 1:
+                    # Use the original tensor to avoid copy.
+                    yield tensors[0].unsqueeze(0)
+                else:
+                    yield torch.stack(tensors)
+            else:
+                for tensor in tensors:
+                    yield tensor.unsqueeze(0)
+
+    # Create Ps and Qs that point to the allocated memory.
+    tensors_to_compress = []
+    ps = []
+    qs = []
+    p_idx = 0
+    q_idx = 0
+    for tensor in maybe_batched_tensors_to_compress():
+        batch_size, n, m = tensor.shape
+        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
+        tensors_to_compress.append(tensor)
+        ps.append(
+            state.p_memory_dict[bucket_index][
+                p_idx : p_idx + batch_size * n * matrix_approximation_rank
+            ].view(batch_size, n, matrix_approximation_rank)
+        )
+        qs.append(
+            state.q_memory_dict[bucket_index][
+                q_idx : q_idx + batch_size * m * matrix_approximation_rank
+            ].view(batch_size, m, matrix_approximation_rank)
+        )
+        p_idx += batch_size * n * matrix_approximation_rank
+        q_idx += batch_size * m * matrix_approximation_rank
+
+    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
+    # The exception is the first iteration when PowerSGD is applied.
+    if not need_randomize_qs:
+        for q in qs:
+            _orthogonalize(q, state.orthogonalization_epsilon)
+    else:
+        with torch.random.fork_rng(devices=[]):
+            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
+            # The seed makes sure that the initial random values are the same across all the DDP replicas.
+            # This seed should differ at every step.
+            # Since it is very slow to fork RNG state across all the CUDA devices,
+            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
+            torch.manual_seed(state.rng.randint(1_000_000_000))
+            for q in qs:
+                q.copy_(
+                    torch.randn(
+                        *q.shape,
+                        device="cpu",
+                        dtype=dtype,
+                    )
+                )
+                _orthogonalize(q, state.orthogonalization_epsilon)
+
+    # Compute Ps.
+    for tensor, q, p in zip(tensors_to_compress, qs, ps):
+        torch.bmm(tensor, q, out=p)
+
+    # This allreduce is only applied to uncompressed tensors,
+    # so it should have been kicked off before the above computation on the compressed tensors to hide more communication costs.
+    # However, this somehow requires a separate future chain at this time.
+    allreduce_contiguous_uncompressed_tensors_fut = dist.all_reduce(
+        uncompressed_tensors_memory, group=group_to_use, async_op=True
+    ).get_future()
+
+    def unpack_uncompressed_tensors_and_allreduce_ps(fut):
+        uncompressed_tensors_memory = fut.value()[0].div_(world_size)
+        idx = 0
+        for tensor in uncompressed_tensors:
+            tensor.copy_(
+                uncompressed_tensors_memory[idx : idx + tensor.numel()].view_as(tensor)
+            )
+            idx += tensor.numel()
+
+        # Since these Ps will be orthogonalized later, no need to divide them by world size.
+        return (
+            dist.all_reduce(
+                state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def compute_qs(fut):
+        state.p_memory_dict[bucket_index] = fut.value()
+        for p in ps:
+            _orthogonalize(p, state.orthogonalization_epsilon)
+
+        # Compute Qs.
+        for tensor, p, q in zip(tensors_to_compress, ps, qs):
+            torch.bmm(tensor.transpose(1, 2), p, out=q)
+
+        # TODO: The above procedure does two matmul+allreduce steps per iteration --
+        # one left multiplication and one right multiplication.
+        # For warm-start, can take one such step at a time, and alternate between them.
+
+        # Allreduce Qs.
+        return (
+            dist.all_reduce(
+                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def decompress(fut):
+        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
+
+        for p, q, tensor in zip(ps, qs, tensors_to_compress):
+            torch.bmm(p, q.transpose(1, 2), out=tensor)
+
+        # Copy batched tensors back to original buffer.
+        if state.batch_tensors_with_same_shape:
+            for tensor in tensors_to_compress:
+                if tensor.shape[0] == 1:
+                    # Skip a tensor with batch_size == 1, since it is a view of the original tensor.
+                    continue
+                original_tensors = shape_to_tensors[tensor.shape[1:]]
+                for i, original_tensor in enumerate(original_tensors):
+                    original_tensor.copy_(tensor[i])
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
+        if not state.warm_start:
+            state.p_memory_dict.clear()
+            state.q_memory_dict.clear()
+
+        state.maybe_increase_iter(bucket)
+
+        return input_tensor
+
+    return (
+        allreduce_contiguous_uncompressed_tensors_fut.then(
+            unpack_uncompressed_tensors_and_allreduce_ps
+        )
+        .then(compute_qs)
+        .then(decompress)
+    )
+
+
+def batched_powerSGD_hook(
+    state: PowerSGDState, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    r"""
+    Implement simplified PowerSGD algorithm.
+
+    This DDP communication hook implements a simplified PowerSGD gradient compression
+    algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
+    This variant does not compress the gradients layer by layer,
+    but instead compresses the flattened input tensor that batches all the gradients.
+    Therefore, it is **faster** than :meth:`powerSGD_hook`,
+    but usually results in a **much lower accuracy**, unless ``matrix_approximation_rank`` is 1.
+
+    .. warning ::
+        Increasing ``matrix_approximation_rank`` here may not necessarily increase the accuracy,
+        because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
+        Therefore, the user should always consider :meth:`powerSGD_hook` first,
+        and only consider this variant when a satisfactory accuracy can be achieved when ``matrix_approximation_rank`` is 1.
+
+    Once gradient tensors are aggregated across all workers, this hook applies
+    compression as follows:
+
+    1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
+
+    2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
+
+    3. Computes P, which is equal to MQ;
+
+    4. Allreduces P;
+
+    5. Orthogonalizes P;
+
+    6. Computes Q, which is approximately equal to M^TP;
+
+    7. Allreduces Q;
+
+    8. Computes M, which is approximately equal to PQ^T.
+
+    9. Truncates the input tensor to the original length.
+
+    Note that this communication hook enforces vanilla allreduce for the first ``state.start_powerSGD_iter`` iterations.
+    This not only gives the user more control over the tradeoff between speedup and accuracy,
+    but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
+
+    Args:
+        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+            Note that since DDP comm hook only supports single process single device mode,
+            only exactly one tensor is stored in this bucket.
+
+    Returns:
+        Future handler of the communication, which updates the gradients in place.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
+        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
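+        >>>
+        >>> # Illustrative sketch of the padding in steps 1 and 9 (not the hook's actual code):
+        >>> # a flattened gradient of length 10 is viewed as a 4 x 4 matrix with 6 zero pads,
+        >>> # compressed via P (4 x rank) and Q (4 x rank), then truncated back to length 10.
+        >>> grad = torch.randn(10)
+        >>> total_length = grad.numel()
+        >>> side = math.ceil(math.sqrt(total_length))  # 4
+        >>> grad.resize_(side * side)
+        >>> grad[total_length:].fill_(0)
+        >>> matrix = grad.view(side, side)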
+    """  # noqa: B950
+    process_group = state.process_group
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    world_size = group_to_use.size()
+
+    # The input tensor is a flattened 1D tensor.
+    input_tensor = bucket.buffer()
+
+    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
+    if state.iter < state.start_powerSGD_iter:
+        state.maybe_increase_iter(bucket)
+        return default._allreduce_fut(group_to_use, input_tensor)
+
+    # Apply PowerSGD after `start_powerSGD_iter` iterations.
+    device = input_tensor.device
+    total_length = input_tensor.shape[0]
+    state.total_numel_before_compression += total_length
+
+    # View the input tensor as a 2D square-shaped tensor, and pad 0s if necessary.
+    square_side_length = math.ceil(math.sqrt(total_length))
+    state.total_numel_after_compression += (
+        square_side_length * state.matrix_approximation_rank * 2
+    )
+    padded_total_length = square_side_length ** 2
+    input_tensor.resize_(padded_total_length)
+    input_tensor[total_length:padded_total_length].fill_(0)
+
+    _report_compression_stats(bucket, state)
+
+    # Incorporate the error from the previous state into the gradients.
+    bucket_index = bucket.index()
+    input_tensor_cp = None
+    if state.use_error_feedback:
+        if bucket_index in state.error_dict:
+            input_tensor.add_(state.error_dict[bucket_index])
+        else:
+            logger.info(
+                "A zero tensor of length %s that represents local error is created.",
+                padded_total_length
+            )
+            state.error_dict[bucket_index] = torch.zeros(
+                padded_total_length, device=device, dtype=input_tensor.dtype
+            )
+
+        # Keep a copy of the input tensor,
+        # so that we can compute the local error caused by compression later,
+        # by comparing this copy and the input tensor updated after decompression.
+        input_tensor_cp = torch.clone(input_tensor).detach()
+    matrix = input_tensor.view(square_side_length, square_side_length)
+
+    # Reuse P and Q from the previous iteration if possible.
+    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
+    if not state.warm_start or bucket_index not in state.p_memory_dict:
+        # If warm-start is disabled, low-rank tensors will be initialized at every step.
+        # Only log this if warm-start is enabled, to avoid spamming.
+        if state.warm_start:
+            logger.info(
+                "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.",
+                square_side_length, state.matrix_approximation_rank
+            )
+
+        def create_low_rank_tensor(fill_random_values, rng):
+            """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
+            if fill_random_values:
+                with torch.random.fork_rng(devices=[]):
+                    # Fork this RNG to avoid changing the seed globally and affecting the random sampling
+                    # anywhere else in the training.
+                    # The seed makes sure that the initial random values are the same across all the DDP replicas.
+                    # This seed should differ at every step.
+                    # Since it is very slow to fork RNG state across all the CUDA devices,
+                    # only fork on CPU and then move the generated tensor to the CUDA device.
+                    torch.manual_seed(rng.randint(1_000_000_000))
+                    return torch.randn(
+                        square_side_length,
+                        state.matrix_approximation_rank,
+                        device="cpu",
+                        dtype=input_tensor.dtype,
+                    ).to(device)
+            else:
+                return torch.empty(
+                    square_side_length,
+                    state.matrix_approximation_rank,
+                    device=device,
+                    dtype=input_tensor.dtype,
+                )
+
+        state.p_memory_dict[bucket_index] = create_low_rank_tensor(
+            fill_random_values=False, rng=state.rng
+        )
+        state.q_memory_dict[bucket_index] = create_low_rank_tensor(
+            fill_random_values=True, rng=state.rng
+        )
+    _orthogonalize(state.q_memory_dict[bucket_index])
+
+    torch.matmul(
+        matrix, state.q_memory_dict[bucket_index], out=state.p_memory_dict[bucket_index]
+    )
+    allreduce_p_fut = dist.all_reduce(
+        state.p_memory_dict[bucket_index], group=group_to_use, async_op=True
+    ).get_future()
+
+    def compute_q(fut):
+        state.p_memory_dict[bucket_index] = fut.value()[0]
+        _orthogonalize(state.p_memory_dict[bucket_index])
+
+        torch.matmul(
+            matrix.t(),
+            state.p_memory_dict[bucket_index],
+            out=state.q_memory_dict[bucket_index],
+        )
+
+        # TODO: The above procedure does two matmul+allreduce steps per iteration --
+        # one left multiplication and one right multiplication.
+        # For warm-start, can take one such step at a time, and alternate between them.
+
+        return (
+            dist.all_reduce(
+                state.q_memory_dict[bucket_index], group=group_to_use, async_op=True
+            )
+            .get_future()
+            .wait()[0]
+        )
+
+    def decompress(fut):
+        state.q_memory_dict[bucket_index] = fut.value().div_(world_size)
+        torch.matmul(
+            state.p_memory_dict[bucket_index],
+            state.q_memory_dict[bucket_index].t(),
+            out=matrix,
+        )
+
+        if state.use_error_feedback:
+            # Memorize the local errors.
+            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
+        # Removing this seemingly unnecessary sync somehow may cause failures.
+        # See: https://github.com/pytorch/pytorch/pull/54838
+        if torch.cuda.is_available():
+            torch.cuda.synchronize(device)
+        if not state.warm_start:
+            state.p_memory_dict.clear()
+            state.q_memory_dict.clear()
+        ret = input_tensor.resize_(total_length)
+
+        state.maybe_increase_iter(bucket)
+
+        return ret
+
+    return allreduce_p_fut.then(compute_q).then(decompress)
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5091f9cbe918217aa135a46729c191b5320cd82
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -0,0 +1,217 @@
+import torch
+import torch.distributed as dist
+from torch import nn
+
+
+def _quantize_per_tensor_cuda(x, scale, zero_point):
+    y = torch.round(x / scale) + zero_point
+    y = torch.clamp(y, 0, 255).to(torch.uint8)
+    return y
+
+
+def _dequantize_per_tensor_cuda(y, scale, zero_point):
+    x = scale * (y.to(torch.float32) - zero_point)
+    return x
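+
+# Illustrative note (not part of the upstream logic): the two helpers above form an
+# affine round trip.  For a tensor ``x`` with values in [0.0, 2.55], scale=0.01 and
+# zero_point=0, quantization maps each value to the uint8 code round(x / scale), and
+# dequantization recovers it to within half a step:
+#
+#     y = _quantize_per_tensor_cuda(x, 0.01, 0)        # codes clamped to [0, 255]
+#     x_hat = _dequantize_per_tensor_cuda(y, 0.01, 0)  # |x_hat - x| <= 0.5 * scale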
+
+
+def _quantize_per_channel_cuda(x, scale, zero_point):
+    y = torch.zeros(x.size(), device=x.device)
+    for i in range(x.size()[0]):
+        y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i]
+    y = torch.clamp(y, 0, 255).to(torch.uint8)
+    return y
+
+
+def _dequantize_per_channel_cuda(y, scale, zero_point):
+    y = y.to(torch.float32).cuda(y.device)
+    x = torch.zeros_like(y, device=y.device)
+    for i in range(x.size()[0]):
+        x[i, :] = scale[i] * (y[i, :] - zero_point[i])
+    return x
+
+
+def _get_allgather_out_list(all_gather_in_list, world_size):
+    out_list = [
+        torch.zeros_like(
+            all_gather_in_list,
+            device=all_gather_in_list.device,
+            dtype=all_gather_in_list.dtype,
+        )
+        for _ in range(world_size)
+    ]
+    return out_list
+
+
+def quantization_pertensor_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Apply ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` protocol.
+
+    Workers first allgather the scale and zero point of their own
+    ``GradBucket`` prior to the quantization. After all workers have that information,
+    the first ``then`` callback, ``quantize_and_allgather``, quantizes the worker's
+    own gradient tensor and uses ``allgather`` to communicate these across all workers.
+    The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes and
+    aggregates each quantized gradient tensor locally and returns the mean.
+
+    .. warning ::
+        This is experimental, and uses ``allgather`` protocol which is considerably slower than
+        ``allreduce`` protocol. It works only with flattened grads.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    rank = process_group.rank() if process_group is not None else dist.get_rank()
+    world_size = group_to_use.size()
+
+    tensor = bucket.buffer()
+
+    myObserver = torch.ao.quantization.MinMaxObserver().cuda(tensor.device)
+    myObserver(tensor)
+
+    s, z = myObserver.calculate_qparams()
+    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)
+
+    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
+
+    # First, allgather scale and zeros.
+    fut = dist.all_gather(
+        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
+    ).get_future()
+
+    def quantize_and_allgather(fut):
+        # Store scale and zeros across all workers.
+        all_ranks_s_and_z = fut.wait()[0]
+        # All workers quantize their own ``GradBucket`` tensors.
+        quantized_tensor = _quantize_per_tensor_cuda(
+            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
+        )
+        # Allgather quantized tensors.
+        fut = dist.all_gather(
+            _get_allgather_out_list(quantized_tensor, world_size),
+            quantized_tensor,
+            group=group_to_use,
+            async_op=True,
+        ).get_future()
+
+        return fut.wait()
+
+    def dequantize_and_aggregate(fut):
+        all_ranks_quantized_tensor = fut.wait()[0]
+
+        aggregated_dequantized_tensor = torch.zeros_like(
+            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
+        )
+        # Using previously allgathered scales and zeros, dequantize gradient tensors
+        # locally and then aggregate them.
+        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
+            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
+                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
+            )
+
+        return aggregated_dequantized_tensor / world_size
+
+    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
+
+
+def quantization_perchannel_hook(
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
+) -> torch.futures.Future[torch.Tensor]:
+    """
+    Apply ``torch.quantize_per_channel`` logic to DDP using ``allgather`` protocol.
+
+    Compared to per-tensor quantization, the main motivation of per-channel quantization is
+    that for considerably large tensors (e.g., a tensor that contains 6 million elements),
+    quantizing per bucket of 512 (or 128) elements may significantly
+    increase the resolution.
+
+    It first splits ``GradBucket`` tensor into multiple chunks (channels) of ``bucket_size``
+    elements. Then, workers allgather the scales and zero points of their own
+    ``GradBucket`` prior to the quantization. After all workers have that information,
+    the first ``then`` callback, ``quantize_and_allgather``, quantizes the worker's
+    own gradient tensor and uses ``allgather`` to communicate these across all workers.
+    The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes, flattens, and
+    aggregates each quantized gradient tensor locally and returns the mean.
+
+    .. warning ::
+        This is experimental, and uses ``allgather`` protocol which is considerably slower than
+        ``allreduce`` protocol. It works only with flattened grads.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
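+        >>>
+        >>> # Illustrative arithmetic: with the default bucket_size of 512, a gradient
+        >>> # bucket of 6_000_000 elements is zero-padded and split into
+        >>> # ceil(6_000_000 / 512) = 11_719 channels, each quantized with its own
+        >>> # scale and zero point.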
+    """
+    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    rank = process_group.rank() if process_group is not None else dist.get_rank()
+    world_size = group_to_use.size()
+
+    tensor = bucket.buffer()
+
+    tensor_in_channels = (
+        nn.functional.pad(
+            input=tensor,
+            pad=(0, bucket_size - len(tensor) % bucket_size),
+            mode="constant",
+            value=0,
+        )
+        .view(-1, bucket_size)
+        .cuda(tensor.device)
+    )
+
+    myPerChannelObserver = torch.ao.quantization.PerChannelMinMaxObserver().cuda(
+        tensor.device
+    )
+    myPerChannelObserver(tensor_in_channels)
+
+    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
+    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)
+
+    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
+    # First, allgather scale and zeros.
+    fut = dist.all_gather(
+        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
+    ).get_future()
+
+    def quantize_and_allgather(fut):
+        # Store scale and zeros across all workers.
+        all_ranks_s_and_z = fut.wait()[0]
+        # All workers quantize their corresponding ``GradBucket`` tensors.
+        quantized_tensor = _quantize_per_channel_cuda(
+            tensor_in_channels,
+            all_ranks_s_and_z[rank, 0, :],
+            all_ranks_s_and_z[rank, 1, :],
+        )
+        # Allgather quantized tensors.
+        fut = dist.all_gather(
+            _get_allgather_out_list(quantized_tensor, world_size),
+            quantized_tensor,
+            group=group_to_use,
+            async_op=True,
+        ).get_future()
+
+        return fut.wait()
+
+    def dequantize_and_aggregate(fut):
+        all_ranks_quantized_tensor = fut.wait()[0]
+
+        aggregated_dequantized_tensor = torch.zeros_like(
+            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
+        )
+        # Using previously allgathered scales and zeros, dequantize gradient tensors
+        # locally and then aggregate them.
+        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
+            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
+                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
+            )
+
+        return (
+            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
+                : tensor.size()[0]
+            ]
+            / world_size
+        )
+
+    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/join.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/join.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab93d0479bcad17779e74dcf79f0db1ef4a85d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/join.py
@@ -0,0 +1,346 @@
+import warnings
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, List, NamedTuple, Optional, Type
+
+import torch
+import torch.distributed as dist
+
+__all__ = ['JoinHook', 'Joinable', 'Join']
+
+class JoinHook:
+    r"""
+    This defines a join hook, which provides two entry points in the join context manager.
+
+    Entry points: a main hook, which is called repeatedly while there exists a non-joined
+    process, and a post-hook, which is called once all processes have joined.
+
+    To implement a join hook for the generic join context manager, define a
+    class that inherits from :class:`JoinHook` and override ``main_hook()`` and
+    ``post_hook()`` as appropriate.
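+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> # A minimal illustrative hook (hypothetical class, not part of this module):
+        >>> # its main hook shadows one all-reduce that non-joined processes issue each
+        >>> # iteration, and its post-hook does nothing.
+        >>> class _ShadowAllreduceJoinHook(JoinHook):
+        ...     def __init__(self, device, process_group):
+        ...         self.device = device
+        ...         self.process_group = process_group
+        ...     def main_hook(self) -> None:
+        ...         dist.all_reduce(torch.zeros(1, device=self.device), group=self.process_group)
+        ...     def post_hook(self, is_last_joiner: bool) -> None:
+        ...         pass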
+    """
+
+    def main_hook(self) -> None:
+        r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
+
+        A training iteration comprises one forward pass, one backward pass, and an optimizer step.
+        """
+        ...
+
+    def post_hook(self, is_last_joiner: bool) -> None:
+        r"""
+        Call hook after all processes have joined.
+
+        It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
+
+        Arguments:
+            is_last_joiner (bool): ``True`` if the rank is one of the last to
+                join; ``False`` otherwise.
+        """
+        ...
+
+
+class Joinable(ABC):
+    r"""
+    This defines an abstract base class for joinable classes.
+
+    A joinable class
+    (inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
+    which returns a :class:`JoinHook` instance, in addition to
+    :meth:`join_device` and :meth:`join_process_group` that return device and
+    process group information, respectively.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        super().__init__()
+        self._join_config = _JoinConfig.construct_disabled_join_config()
+
+    @abstractmethod
+    def join_hook(self, **kwargs) -> JoinHook:
+        r"""
+        Return a :class:`JoinHook` instance for the given :class:`Joinable`.
+
+        Arguments:
+            kwargs (dict): a :class:`dict` containing any keyword arguments
+                to modify the behavior of the join hook at run time; all
+                :class:`Joinable` instances sharing the same join context
+                manager are forwarded the same value for ``kwargs``.
+        """
+        ...
+
+    @property
+    @abstractmethod
+    def join_device(self) -> torch.device:
+        r"""Return the device from which to perform collective communications needed by the join context manager."""
+        ...
+
+    @property
+    @abstractmethod
+    def join_process_group(self) -> Any:
+        r"""Returns the process group for the collective communications needed by the join context manager itself."""
+        ...
+
+
+class _JoinConfig(NamedTuple):
+    r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
+
+    enable: bool
+    throw_on_early_termination: bool
+    is_first_joinable: bool
+
+    @staticmethod
+    def construct_disabled_join_config():
+        r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
+
+        e.g. if the caller is not in a join context manager.
+        """
+        return _JoinConfig(
+            enable=False,
+            throw_on_early_termination=False,
+            is_first_joinable=False
+        )
+
+
+class Join:
+    r"""
+    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
+
+    These hooks should shadow the
+    collective communications of non-joined processes to prevent hanging and
+    erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
+    for details about the hook definition.
+
+    .. warning::
+        The context manager requires each participating :class:`Joinable` to
+        call the method :meth:`notify_join_context()` before its own per-
+        iteration collective communications to ensure correctness.
+
+    .. warning::
+        The context manager requires that all ``process_group`` attributes in
+        the :class:`JoinHook` objects are the same. If there are multiple
+        :class:`JoinHook` objects, then the ``device`` of the first is used.
+        The process group and device information is used for checking for non-
+        joined processes and for notifying processes to throw an exception if
+        ``throw_on_early_termination`` is enabled, both of which use an all-
+        reduce.
+
+    Arguments:
+        joinables (List[Joinable]): a list of the participating
+            :class:`Joinable` s; their hooks are iterated over in the given
+            order.
+
+        enable (bool): a flag enabling uneven input detection; setting to
+            ``False`` disables the context manager's functionality and should
+            only be set when the user knows the inputs will not be uneven
+            (default: ``True``).
+
+        throw_on_early_termination (bool): a flag controlling whether to throw an
+            exception upon detecting uneven inputs (default: ``False``).
+
+    Example::
+
+        >>> import os
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> import torch.multiprocessing as mp
+        >>> # xdoctest: +SKIP
+        >>> import torch.nn.parallel.DistributedDataParallel as DDP
+        >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
+        >>> from torch.distributed.algorithms.join import Join
+        >>>
+        >>> # On each spawned worker
+        >>> def worker(rank):
+        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
+        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
+        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
+        >>>     # Rank 1 gets one more input than rank 0
+        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
+        >>>     with Join([model, optim]):
+        >>>         for input in inputs:
+        >>>             loss = model(input).sum()
+        >>>             loss.backward()
+        >>>             optim.step()
+        >>>     # All ranks reach here without hanging/erroring
+    """
+
+    def __init__(
+        self,
+        joinables: List[Joinable],
+        enable: bool = True,
+        throw_on_early_termination: bool = False,
+        **kwargs,
+    ):
+        if len(joinables) == 0:
+            raise ValueError("The join context manager requires at least one joinable")
+        self._joinables = joinables
+        self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
+        self._enable = enable
+        self._throw_on_early_termination = throw_on_early_termination
+        self._set_joinable_configs()
+        self._extract_dist_info()
+
+    def _set_joinable_configs(self) -> None:
+        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
+        assert len(self._joinables) > 0
+        is_first_joinable = True
+        for joinable in self._joinables:
+            joinable._join_config = _JoinConfig(
+                enable=self._enable,
+                throw_on_early_termination=self._throw_on_early_termination,
+                is_first_joinable=is_first_joinable
+            )
+            is_first_joinable = False
+
+    def _extract_dist_info(self) -> None:
+        r"""
+        Extract the process group and device information from the joinables.
+
+        If there are multiple joinables, then the context manager uses the
+        first specified device.
+
+        Preconditions:
+            ``self._joinables`` is not ``None`` and is non-empty.
+
+        Raises:
+            ValueError
+                If there are multiple conflicting ``process_group`` attributes
+                among the ``Joinable`` objects.
+        """
+        process_group = None
+        device = None
+        for joinable in self._joinables:
+            if process_group is None:
+                process_group = joinable.join_process_group
+            elif process_group != joinable.join_process_group:
+                raise ValueError("Using join context manager with multiple process groups")
+            if device is None:
+                device = joinable.join_device
+        self._process_group = process_group
+        self._rank = dist.get_rank(self._process_group)
+        self._device = device
+
+    def __enter__(self):
+        ...
+
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType]
+    ):
+        r"""
+        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
+
+        Raises:
+            RuntimeError
+                If ``throw_on_early_termination=True`` and uneven inputs are detected.
+        """
+        if not self._enable or type:
+            return  # propagate the exception directly if one was raised
+
+        all_procs_joined = False
+        is_last_joiner = True
+
+        i = 0
+        WARN_THRESHOLD = 1000
+        warnings.simplefilter("once")
+
+        while not all_procs_joined:
+            if i > WARN_THRESHOLD:
+                warnings.warn(
+                    "Detected uneven input skew of greater than "
+                    f"{WARN_THRESHOLD}. This means that rank "
+                    f"{self._rank} has at least {WARN_THRESHOLD} "
+                    f"fewer inputs than other currently-active ranks. "
+                    "This level of skew could lead to performance "
+                    "degradation during training."
+                )
+            # Shadow the all-reduce in non-joined processes
+            num_nonjoined_procs = self._get_num_nonjoined_procs()
+            if num_nonjoined_procs == 0:
+                all_procs_joined = True
+            else:
+                if self._throw_on_early_termination:
+                    self._notify_procs_to_terminate()
+
+                # Run main hooks
+                for join_hook in self._join_hooks:
+                    join_hook.main_hook()
+
+                is_last_joiner = False
+                i += 1
+
+        # Run post-hooks
+        for join_hook in self._join_hooks:
+            join_hook.post_hook(is_last_joiner)
+
+    def _get_num_nonjoined_procs(self):
+        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
+        num_nonjoined_procs = torch.zeros(1, device=self._device)
+        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
+        return num_nonjoined_procs.item()
+
+    def _notify_procs_to_terminate(self):
+        r"""Schedule an all-reduce to notify non-joined processes to terminate.
+
+        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
+        """
+        ones = torch.ones(1, device=self._device)
+        dist.all_reduce(ones, group=self._process_group)
+        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")
+
+    @staticmethod
+    def notify_join_context(joinable: Joinable):
+        r"""
+        Notifies the join context manager that the calling process has not yet joined.
+
+        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
+        (i.e. if one process has already joined) and throws an exception if so.
+
+        This method should be called from a :class:`Joinable` object before
+        its per-iteration collective communications. For example, this should
+        be called at the beginning of the forward pass in
+        :class:`DistributedDataParallel`.
+
+        Only the first :class:`Joinable` object passed into the context
+        manager performs the collective communications in this method, and
+        for the others, this method is vacuous.
+
+        Arguments:
+            joinable (Joinable): the :class:`Joinable` object calling this
+                method.
+
+        Returns:
+            An async work handle for the all-reduce meant to notify the context
+            manager that the process has not yet joined if ``joinable`` is the
+            first one passed into the context manager; ``None`` otherwise.
+        """
+        assert hasattr(joinable, "_join_config"), \
+            f"Check that the {type(joinable)} constructor calls the " \
+            "``Joinable`` constructor"
+
+        join_config = joinable._join_config
+        # First joinable is responsible for the collective communications
+        if not join_config.is_first_joinable or not join_config.enable:
+            return None
+
+        device = joinable.join_device
+        process_group = joinable.join_process_group
+
+        # Schedule an all-reduce to indicate that the caller has not yet joined
+        ones = torch.ones(1, device=device)
+        work = dist.all_reduce(ones, group=process_group, async_op=True)
+
+        if join_config.throw_on_early_termination:
+            # Check if uneven inputs have been detected
+            zeros = torch.zeros(1, device=device)
+            dist.all_reduce(zeros, group=process_group)
+            should_throw = zeros.item()
+            if should_throw:
+                raise RuntimeError(
+                    "Detected at least one rank that exhausted inputs. "
+                    "Throwing across all ranks."
+                )
+        return work
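A brief usage sketch for this join context manager with uneven per-rank inputs; the model, the per-rank batch counts, and the torchrun-style launch are illustrative assumptions, not part of the file above.

```python
# Illustrative sketch: each rank iterates a different number of batches, and the
# Join context manager shadows the missing collectives so no rank hangs.
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def train(rank: int) -> None:
    model = DDP(torch.nn.Linear(1, 1))                      # DDP is a Joinable
    inputs = [torch.rand(16, 1) for _ in range(5 + rank)]   # uneven inputs per rank
    with Join([model]):
        for x in inputs:
            model(x).sum().backward()


if __name__ == "__main__":
    # Assumes a torchrun-style launch that sets MASTER_ADDR/PORT, RANK, WORLD_SIZE.
    dist.init_process_group("gloo", init_method="env://")
    train(dist.get_rank())
```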
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__init__.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..413b55e18a23afb658c8c77434db638495cd5b8e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2eb37830a301cbac35cd3c3a5c84f26aa529068f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec641f28a37e86cfe44bae54118c16afbcf6fbf2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa1a0e015b2f3054bb285f36b99f62b79129d683
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1085d37563fd2b73efdcaea75e44136195c1e6f7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/averagers.py
@@ -0,0 +1,120 @@
+import warnings
+from abc import ABC, abstractmethod
+from typing import Union, Iterable, Dict
+import torch
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.utils as utils
+
+__all__ = ['ModelAverager', 'PeriodicModelAverager']
+
+class ModelAverager(ABC):
+    r"""Base class for all model averagers.
+
+    Args:
+        process_group: The process group to be used for all-reduce.
+                       If ``None``, the default process group, which
+                       is created by :func:`torch.distributed.init_process_group`,
+                       will be used. (default: ``None``)
+    """
+
+    def __init__(self, process_group=None):
+        self.process_group = (
+            process_group if process_group is not None else dist.group.WORLD
+        )
+        self.step = 0
+
+    @abstractmethod
+    def average_parameters(self, params):
+        raise NotImplementedError
+
+
+class PeriodicModelAverager(ModelAverager):
+    r"""
+    Averages parameters periodically after the warm-up stage.
+
+    This can be used for running `post-local SGD `_,
+    by running :class:`~torch.nn.DistributedDataParallel` (DDP)
+    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.
+
+    Args:
+        period (int): The number of steps per model averaging.
+                      Usually the period should be greater than ``1`` to reduce the communication cost.
+                      Otherwise, plain DDP is sufficient.
+        warmup_steps (int): The number of warm-up steps. During this stage,
+                            model averaging is skipped.
+        process_group: The process group to be used for all-reduce.
+                       If ``None``, the default process group, which
+                       is created by :func:`torch.distributed.init_process_group`,
+                       will be used. (default: ``None``)
+
+    Example::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
+        >>> import torch.nn as nn
+        >>>
+        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
+        >>> torch.cuda.set_device(rank)
+        >>> module = nn.Linear(1, 1, bias=False).cuda()
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>    module, device_ids=[rank], output_device=rank
+        >>> )
+        >>> # Register a post-localSGD communication hook.
+        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
+        >>> # After 100 steps, run model averaging every 4 steps.
+        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
+        >>> for step in range(0, 200):
+        >>>    optimizer.zero_grad()
+        >>>    loss = loss_fn(output, labels)
+        >>>    loss.backward()
+        >>>    optimizer.step()
+        >>>    # Will average model parameters globally every 4 steps. Thus,
+        >>>    # inter-node communication only occurs every 4 iterations after
+        >>>    # the initial ``warmup_steps`` period.
+        >>>    averager.average_parameters(model.parameters())
+    """
+
+    def __init__(
+        self,
+        period,
+        warmup_steps=0,
+        process_group=None
+    ):
+        super().__init__(process_group)
+        if warmup_steps < 0:
+            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
+        self.warmup_steps = warmup_steps
+        if period < 1:
+            raise ValueError("Arg ``period`` must be a positive value.")
+        elif period == 1:
+            warnings.warn(
+                "When period is 1, there is no need to use model averaging because the communication cost "
+                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
+                "DistributedDataParallel should be used for this case."
+            )
+        self.period = period
+
+    def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
+        """
+        Averages parameters or parameter groups of an optimizer.
+
+        Averaging only occurs if ``step`` is no less than ``warmup_steps`` and
+        ``step - warmup_steps`` is divisible by ``period``, where ``step`` is
+        increased by 1 at each iteration in the training loop.
+
+        Args:
+            params: The parameters of a model or parameter groups of an optimizer.
+
+        """
+        if (
+            self.step >= self.warmup_steps
+            and (self.step - self.warmup_steps) % self.period == 0
+        ):
+            utils.average_parameters_or_parameter_groups(params, self.process_group)
+        self.step += 1
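The gating condition in ``average_parameters`` above can be isolated into a short, dependency-free sketch; the helper name ``should_average`` is hypothetical and exists only for illustration.

```python
def should_average(step: int, warmup_steps: int, period: int) -> bool:
    # Mirrors the gating logic of PeriodicModelAverager.average_parameters:
    # skip the warm-up stage, then average every `period` steps afterwards.
    return step >= warmup_steps and (step - warmup_steps) % period == 0


# With warmup_steps=100 and period=4, averaging fires at steps 100, 104, 108, ...
assert [s for s in range(98, 110) if should_average(s, 100, 4)] == [100, 104, 108]
```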
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6246f4c669dc75b011338ba4591b418fdb17fa8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py
@@ -0,0 +1,167 @@
+# Copyright 2022 Cruise LLC
+import logging
+import warnings
+from collections import OrderedDict
+from typing import Union, Iterable, Dict
+
+import torch
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.utils as utils
+
+logger = logging.getLogger(__name__)
+
+
+class HierarchicalModelAverager(averagers.ModelAverager):
+    r"""
+    Runs hierarchical model averaging (`hierarchical SGD `_).
+
+    Process groups of different sizes are organized in a hierarchy, and they average parameters
+    by using different periods concurrently after the warm-up stage.
+    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`
+    that supports `post-local SGD `_, which essentially only supports
+    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
+    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
+    Similarly, the process groups within this class do not have such an intra-machine process
+    subgroup, which should be embedded by the post-local SGD communication hook instead.
+
+    Args:
+        period_group_size_dict: An ordered dict mapping keys of model averaging period to
+                                process group size, used for initializing process groups of
+                                different sizes in a hierarchy to average parameters concurrently.
+                                In particular, at each iteration there will be at most a single
+                                process group that runs averaging -- namely, the group whose period
+                                is the largest period that divides the current step.
+                                For example, if the dict has three keys: 2, 4, and 8,
+                                then this means totally three process groups will be created to
+                                average parameters every 2, 4, and 8 iterations, respectively.
+                                At the 4th iteration, only the second process group will run
+                                averaging, because the first process group should be a
+                                subset of the second process group, so there is no need to run
+                                the first process group redundantly.
+                                On the other hand, the third process group can only be triggered
+                                every 8 iterations, so it will not be triggered at the 4th iteration.
+        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
+        process_group (ProcessGroup, optional): The overall process group containing all the processes that run model averaging.
+                                                If ``None``, the default process group, which is created
+                                                by :func:`torch.distributed.init_process_group`, will be used.
+                                                (default: ``None``)
+
+    Example::
+        >>> # xdoctest: +SKIP('undefined rank')
+        >>> from collections import OrderedDict
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
+        >>>     PostLocalSGDState,
+        >>>     post_localSGD_hook,
+        >>> )
+        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+        >>> import torch.nn as nn
+        >>>
+        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
+        >>> torch.cuda.set_device(rank)
+        >>> module = nn.Linear(1, 1, bias=False).to(rank)
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>    module, device_ids=[rank], output_device=rank
+        >>> )
+        >>> # Register a post-localSGD communication hook.
+        >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
+        >>> subgroup, _ = dist.new_subgroups()
+        >>> state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
+        >>> # the 16 processes every 16 iterations.
+        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
+        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
+        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
+        >>> # After 100 steps, run model averaging at two levels.
+        >>> for step in range(0, 200):
+        >>>    optimizer.zero_grad()
+        >>>    loss = loss_fn(output, labels)
+        >>>    loss.backward()
+        >>>    optimizer.step()
+        >>>    # Average parameters after ``optimizer.step()``.
+        >>>    # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
+        >>>    averager.average_parameters(model.parameters())
+
+    .. warning ::
+        The last group size in the dict must be the size of the provided ``process_group``,
+        which indicates model averaging at the highest level of the hierarchy.
+        If ``process_group`` is not provided, then the last group size should be equal to the world size.
+
+    .. warning ::
+        `HierarchicalModelAverager` is experimental and subject to change.
+    """
+
+    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
+        super().__init__(process_group)
+        if not period_group_size_dict:
+            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
+        self._periods = list(period_group_size_dict.keys())
+        if self._periods[0] <= 0:
+            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
+        elif self._periods[-1] == 1:
+            warnings.warn(
+                "When the maximum period in arg ``period_group_size_dict`` is 1, "
+                "there is no need to use model averaging because the communication cost "
+                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
+                "DistributedDataParallel should be used for this case."
+            )
+        overall_group_size = dist.get_world_size(group=self.process_group)
+        if list(period_group_size_dict.values())[-1] != overall_group_size:
+            raise ValueError(
+                f"The last value in arg ``period_group_size_dict`` {list(period_group_size_dict.values())[-1]} "
+                f"must be equal to the size of arg ``process_group`` {overall_group_size}."
+            )
+
+        self.period_process_group_dict = OrderedDict()
+        logger.info("Model averaging hierarchy:")
+        for period, group_size in period_group_size_dict.items():
+            logger.info(
+                "\tEach group of %s processes averages parameters every %s iterations, "
+                "unless a higher-level group also runs averaging at that step.", group_size, period)
+            if group_size != overall_group_size:
+                self.period_process_group_dict[period], _ = dist.new_subgroups(
+                    group_size=group_size, group=self.process_group)
+            else:
+                self.period_process_group_dict[period] = self.process_group
+
+        if warmup_steps < 0:
+            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
+        self.warmup_steps = warmup_steps
+
+    def _find_process_group(self):
+        """
+        Return the process group from ``period_process_group_dict`` to use at the current ``step``.
+
+        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
+        then the returned process group corresponds to the largest such period, since that group
+        is used for averaging parameters at this ``step``.
+        Returns ``None`` if no period divides ``step``.
+        """
+        for period in reversed(self._periods):
+            if self.step % period == 0:
+                return self.period_process_group_dict[period]
+        return None
+
+    def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
+        """
+        Averages parameters or parameter groups of an optimizer.
+
+        Averaging only occurs if ``step`` is no less than ``warmup_steps`` and ``step`` is
+        divisible by a period in the keys of ``period_process_group_dict``, where ``step``
+        is increased by 1 at each iteration in the training loop.
+        If ``step`` is divisible by multiple periods in the keys of ``period_process_group_dict``,
+        only the largest such period is used, and its corresponding process group performs the averaging.
+
+        Args:
+            params: The parameters of a model or parameter groups of an optimizer.
+        """
+        if self.step >= self.warmup_steps:
+            group = self._find_process_group()
+            if group is not None:
+                utils.average_parameters_or_parameter_groups(params, group)
+        self.step += 1
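A stand-alone sketch of the period-selection rule that ``_find_process_group`` implements: only the largest period dividing the current step fires. The group labels below are placeholders, not real process groups.

```python
from collections import OrderedDict

# Periods map to (hypothetical) group labels instead of real process groups.
period_to_group = OrderedDict([(2, "size-4 subgroup"), (4, "size-8 subgroup"), (8, "world")])


def find_group(step: int):
    # Same rule as HierarchicalModelAverager._find_process_group: scan periods
    # from largest to smallest and return the first one that divides `step`.
    for period in reversed(list(period_to_group)):
        if step % period == 0:
            return period_to_group[period]
    return None


assert find_group(4) == "size-8 subgroup"   # 4 divides 4, but 8 does not
assert find_group(8) == "world"             # the largest period wins
assert find_group(3) is None                # no period divides 3
```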
diff --git a/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..34444b3478e253a50ac71151c44d334f8aefb890
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/algorithms/model_averaging/utils.py
@@ -0,0 +1,72 @@
+# flake8: noqa C101
+import itertools
+from typing import Union, Iterable, Dict, Iterator
+
+import torch
+import torch.distributed as dist
+# The two imports below are not always available depending on the
+# USE_DISTRIBUTED compile flag. Make sure they raise import error
+# if we're trying to use them.
+from torch.distributed import ProcessGroup, group
+
+__all__ = ["average_parameters", "get_params_to_average", "average_parameters_or_parameter_groups"]
+
+def average_parameters(
+    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
+):
+    """
+    Averages all the given parameters.
+
+    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
+    Thus, it requires extra memory of the same size as the given parameters.
+    """
+    group_to_use = process_group if process_group is not None else group.WORLD
+    # Do not update any parameter if not in the process group.
+    if dist._rank_not_in_group(group_to_use):
+        return
+
+    params_it1, params_it2 = itertools.tee(params)
+    # If the input parameters have different data types,
+    # packing these parameters will trigger an implicit type up-casting.
+    # The original parameter data types will be restored during the subsequent unpacking.
+    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
+    flat_params /= dist.get_world_size(group_to_use)
+    # Make sure the allreduce will not conflict with any other ongoing process group.
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    dist.all_reduce(flat_params, group=group_to_use)
+
+    offset = 0
+    for p in params_it2:
+        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
+        offset += p.numel()
+
+
+def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
+    """
+    Return a list of parameters that need to be averaged.
+
+    Parameters that do not have any gradients are filtered out.
+
+    Args:
+        params: The parameters of a model or parameter groups of an optimizer.
+    """
+    filtered_params = []
+    for param in params:
+        if isinstance(param, torch.nn.Parameter):
+            # model.parameters() input
+            param_data = param
+            if param_data.grad is not None:
+                filtered_params.append(param_data)
+        elif isinstance(param, dict):
+            # optimizer.param_groups input
+            for param_data in param["params"]:
+                if param_data.grad is not None:
+                    filtered_params.append(param_data)
+        else:
+            raise NotImplementedError(f"Parameter input of type {type(param)} is not supported")
+    return filtered_params
+
+
+def average_parameters_or_parameter_groups(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]], process_group: ProcessGroup):
+    """Averages parameters of a model or parameter groups of an optimizer."""
+    average_parameters(iter(get_params_to_average(params)), process_group)
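The flatten → all-reduce → scatter-back pattern of ``average_parameters`` can be mimicked locally without a process group. In this sketch, summing pre-divided flat copies stands in for ``dist.all_reduce``; the function name ``average_locally`` is made up for illustration.

```python
import torch


def average_locally(param_sets, world_size):
    # Stand-in for the collective: each "rank" contributes a flattened copy divided
    # by world_size, we sum them, then unpack back into parameter-shaped views.
    flats = [torch.cat([p.reshape(-1) for p in params]) / world_size for params in param_sets]
    averaged = torch.stack(flats).sum(dim=0)
    out, offset = [], 0
    for p in param_sets[0]:
        out.append(averaged[offset : offset + p.numel()].view_as(p))
        offset += p.numel()
    return out


a = [torch.zeros(2, 2), torch.zeros(3)]
b = [torch.ones(2, 2), torch.ones(3)]
avg = average_locally([a, b], world_size=2)
assert torch.allclose(avg[0], torch.full((2, 2), 0.5))
```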
diff --git a/MLPY/Lib/site-packages/torch/distributed/argparse_util.py b/MLPY/Lib/site-packages/torch/distributed/argparse_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..668ebc8c1f89d2952652a68d433d7ef9e019fd9a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/argparse_util.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+from argparse import Action
+
+
+class env(Action):
+    """
+    Get argument values from ``PET_{dest}`` before defaulting to the given ``default`` value.
+
+    For flags (e.g. ``--standalone``)
+    use ``check_env`` instead.
+
+    .. note:: when multiple option strings are specified, ``dest`` is
+              the longest option string (e.g. for ``"-f", "--foo"``
+              the env var to set is ``PET_FOO`` not ``PET_F``)
+
+    Example:
+    ::
+
+     parser.add_argument("-f", "--foo", action=env, default="bar")
+
+     ./program                                      -> args.foo="bar"
+     ./program -f baz                               -> args.foo="baz"
+     ./program --foo baz                            -> args.foo="baz"
+     PET_FOO="env_bar" ./program -f baz    -> args.foo="baz"
+     PET_FOO="env_bar" ./program --foo baz -> args.foo="baz"
+     PET_FOO="env_bar" ./program           -> args.foo="env_bar"
+
+     parser.add_argument("-f", "--foo", action=env, required=True)
+
+     ./program                                      -> fails
+     ./program -f baz                               -> args.foo="baz"
+     PET_FOO="env_bar" ./program           -> args.foo="env_bar"
+     PET_FOO="env_bar" ./program -f baz    -> args.foo="baz"
+    """
+
+    def __init__(self, dest, default=None, required=False, **kwargs) -> None:
+        env_name = f"PET_{dest.upper()}"
+        default = os.environ.get(env_name, default)
+
+        # ``required`` means that the option NEEDS to be present in the command-line args,
+        # rather than "this option requires a value (either set explicitly or by default)".
+        # So if a default was found (from the env var or the caller), we no longer
+        # "require" it on the command line and set ``required`` to False.
+        if default:
+            required = False
+
+        super().__init__(dest=dest, default=default, required=required, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, values)
+
+
+class check_env(Action):
+    """
+    Check whether the env var ``PET_{dest}`` exists before defaulting to the given ``default`` value.
+
+    Equivalent to
+    ``store_true`` argparse built-in action except that the argument can
+    be omitted from the commandline if the env var is present and has a
+    non-zero value.
+
+    .. note:: it is redundant to pass ``default=True`` for arguments
+              that use this action because a flag should be ``True``
+              when present and ``False`` otherwise.
+
+    Example:
+    ::
+
+     parser.add_argument("--verbose", action=check_env)
+
+     ./program                                  -> args.verbose=False
+     ./program --verbose                        -> args.verbose=True
+     PET_VERBOSE=1 ./program           -> args.verbose=True
+     PET_VERBOSE=0 ./program           -> args.verbose=False
+     PET_VERBOSE=0 ./program --verbose -> args.verbose=True
+
+    Anti-pattern (don't do this):
+
+    ::
+
+     parser.add_argument("--verbose", action=check_env, default=True)
+
+     ./program                                  -> args.verbose=True
+     ./program --verbose                        -> args.verbose=True
+     PET_VERBOSE=1 ./program           -> args.verbose=True
+     PET_VERBOSE=0 ./program           -> args.verbose=False
+
+    """
+
+    def __init__(self, dest, default=False, **kwargs) -> None:
+        env_name = f"PET_{dest.upper()}"
+        default = bool(int(os.environ.get(env_name, "1" if default else "0")))
+        super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, self.const)
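A short usage sketch for the two actions above; the option names and values are arbitrary, and the behavior follows the docstring examples.

```python
from argparse import ArgumentParser

from torch.distributed.argparse_util import check_env, env

parser = ArgumentParser()
parser.add_argument("--node_rank", action=env, type=int, default=0)
parser.add_argument("--standalone", action=check_env)

# Note: the PET_* environment variables are read when add_argument() runs,
# so they must already be set for them to influence the defaults.
args = parser.parse_args([])                                   # falls back to defaults / PET_* env vars
args = parser.parse_args(["--node_rank", "3", "--standalone"])  # explicit flags win over env vars
assert args.node_rank == 3 and args.standalone is True
```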
diff --git a/MLPY/Lib/site-packages/torch/distributed/autograd/__init__.py b/MLPY/Lib/site-packages/torch/distributed/autograd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85f017e96de63905fe36edba3045269a853805b4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/autograd/__init__.py
@@ -0,0 +1,52 @@
+
+import sys
+import torch
+
+
+def is_available():
+    return hasattr(torch._C, "_dist_autograd_init")
+
+
+if is_available() and not torch._C._dist_autograd_init():
+    raise RuntimeError("Failed to initialize torch.distributed.autograd")
+
+if is_available():
+    from torch._C._distributed_autograd import (
+        get_gradients,
+        backward,
+        _init,
+        _new_context,
+        _release_context,
+        _get_max_id,
+        _is_valid_context,
+        _retrieve_context,
+        _current_context,
+        _get_debug_info,
+        DistAutogradContext,
+    )
+
+
+class context:
+    '''
+    Context object to wrap forward and backward passes when using
+    distributed autograd. The ``context_id`` generated in the ``with``
+    statement is required to uniquely identify a distributed backward pass
+    on all workers. Each worker stores metadata associated with this
+    ``context_id``, which is required to correctly execute a distributed
+    autograd pass.
+
+    Example::
+        >>> # xdoctest: +SKIP
+        >>> import torch.distributed.autograd as dist_autograd
+        >>> with dist_autograd.context() as context_id:
+        >>>     t1 = torch.rand((3, 3), requires_grad=True)
+        >>>     t2 = torch.rand((3, 3), requires_grad=True)
+        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
+        >>>     dist_autograd.backward(context_id, [loss])
+    '''
+    def __enter__(self):
+        self.autograd_context = _new_context()
+        return self.autograd_context._context_id()
+
+    def __exit__(self, type, value, traceback):
+        _release_context(self.autograd_context._context_id())
diff --git a/MLPY/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64dfdaaa4fa7ae0dfea7199a2408efb524a8fa6c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/c10d_logger.py b/MLPY/Lib/site-packages/torch/distributed/c10d_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b4e312a7de9cb4eeb696718f6d71392a4d6b8c1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/c10d_logger.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import logging
+import time
+from typing import Any, Callable, Dict, List, Tuple, TypeVar
+from typing_extensions import ParamSpec
+
+import torch
+import torch.distributed as dist
+
+from torch.distributed.logging_handlers import _log_handlers
+
+__all__: List[str] = []
+
+
+def _get_or_create_logger() -> logging.Logger:
+    logging_handler, log_handler_name = _get_logging_handler()
+    logger = logging.getLogger(f"c10d-{log_handler_name}")
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    logging_handler.setFormatter(formatter)
+    logger.propagate = False
+    logger.addHandler(logging_handler)
+    return logger
+
+
+def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
+    log_handler = _log_handlers[destination]
+    log_handler_name = type(log_handler).__name__
+    return (log_handler, log_handler_name)
+
+
+global _c10d_logger
+_c10d_logger = _get_or_create_logger()
+
+
+def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
+    if dist.is_initialized():
+        msg_dict = {
+            "func_name": f"{func_name}",
+            "args": f"{args}, {kwargs}",
+            "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}",  # type: ignore[arg-type]
+            "backend": f"{dist.get_backend(kwargs.get('group'))}",
+            "world_size": f"{dist.get_world_size()}",
+            "group_size": f"{dist.get_world_size(kwargs.get('group'))}",
+            "global_rank": f"{dist.get_rank()}",
+            "local_rank": f"{dist.get_rank(kwargs.get('group'))}",
+        }
+        if msg_dict["backend"] == "nccl":
+            nccl_version = torch.cuda.nccl.version()
+            msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
+    else:
+        msg_dict = {
+            "func_name": f"{func_name}",
+            "args": f"{args}, {kwargs}",
+        }
+    return msg_dict
+
+_T = TypeVar('_T')
+_P = ParamSpec('_P')
+
+def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
+    @functools.wraps(func)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+        try:
+            return func(*args, **kwargs)
+        except Exception as error:
+            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
+            msg_dict["error"] = f"{error}"
+            _c10d_logger.debug(msg_dict)
+            raise
+
+    return wrapper
+
+
+def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
+    @functools.wraps(func)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+        t1 = time.time_ns()
+        func_return = func(*args, **kwargs)
+        time_spent = time.time_ns() - t1
+
+        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
+        msg_dict["time_spent"] = f"{time_spent}ns"
+        _c10d_logger.debug(msg_dict)
+
+        return func_return
+
+    return wrapper
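A hedged sketch of wrapping an ordinary function with the timing decorator above. These helpers are private (underscore-prefixed), so this is illustrative only; whether anything is actually emitted depends on the handler registered in ``torch.distributed.logging_handlers``.

```python
from torch.distributed.c10d_logger import _time_logger


@_time_logger
def busy_work(n: int) -> int:
    # An arbitrary function; the wrapper times it and logs func_name, args,
    # and "time_spent" through the c10d logger at DEBUG level.
    return sum(i * i for i in range(n))


busy_work(10_000)
```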
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__init__.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..007a207618e469e90e30bde67d8e530855decf27
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__init__.py
@@ -0,0 +1,15 @@
+from .api import CheckpointException
+from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
+from .filesystem import FileSystemReader, FileSystemWriter
+from .fsspec import FsspecReader, FsspecWriter
+from .metadata import (
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    Metadata,
+    TensorStorageMetadata,
+)
+from .optimizer import load_sharded_optimizer_state_dict
+from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
+from .state_dict_loader import load, load_state_dict
+from .state_dict_saver import async_save, save, save_state_dict
+from .storage import StorageReader, StorageWriter
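A minimal save/load sketch using the entry points re-exported here. It assumes a single-rank ``gloo`` process group and a writable ``/tmp/ckpt`` directory; the names are taken from the imports above (``save``, ``load``, ``FileSystemWriter``, ``FileSystemReader``), but treat this as an illustrative sketch rather than canonical usage.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Single-process group purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

state_dict = {"weight": torch.rand(4, 4)}
dcp.save(state_dict, storage_writer=dcp.FileSystemWriter("/tmp/ckpt"))

# load() restores in place into a pre-allocated state_dict.
restored = {"weight": torch.empty(4, 4)}
dcp.load(restored, storage_reader=dcp.FileSystemReader("/tmp/ckpt"))
```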
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..661dd428dfdcea39f30d89f9eccd05e86f21ae31
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a962f272b00714088e3b9a5b62944aaec92fe2a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c37ce3e477abe7dee3d85398e3d59331473dd947
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7ddb9a6928414779a4067081ebd58ddb96ba948
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72f5999482663ffbd30fe00232ccc0610e0d936c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b1e001c54e9548c2bfd4a1e0ef049c8f54581d9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b523088c2348af8378d8be453745d20a44c35df7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92f01881408b1408c3cdd559ec1e13c9aaa2b772
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0397a4d7cff9af5cc46d60e62421b049d6a5dd80
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2df7b45d6d8b2ab383fffcc662348c3ea58e0870
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0eb51f9af8213dfbbb163c3fee2eb13333a735f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2e7c7a28459aab2a6eabf3c39e8a7ab2d97ecd9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/fsspec.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/fsspec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..027b9716b408f951788bea65c3ffa855e0a95fbf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/fsspec.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d36f2e08568645a02e48c37941c8ad3df863ffd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da8606b2455412855713415d44690cdaf240c51f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31c58dd3e1a4224129a4144f40b29829c6adaa82
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c04ab860496f72c7f20a981d483ceb82f91ddbbc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aca818516da5ee71258a77c4b251b34b79509f02
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ccac23b666a7d4e89bbb06fa6f1869383414ffd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ddbf726b933d44bc9993f0872c1041cef94877f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c7555b08a1238f073d003aefa47239add11ba0b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa89a3972e2bc24112c078dca22e8c968baffdb1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8b448ad2bdd7f3973ef4d411384ecd31f762fe0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e5ee7b56e2c1d5e6134013fb5819b5d849548d2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a7ddce6de0aad680c775db1a6d88ad3849aa13
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_save_plans.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import dataclasses
+from collections import defaultdict
+from typing import Dict, List, Set
+
+from torch.distributed.checkpoint.metadata import MetadataIndex
+from torch.distributed.checkpoint.planner import SavePlan, WriteItem
+
+__all__ = ["dedup_save_plans"]
+
+
+def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]:
+    """
+    Remove duplicate entries that appear in multiple SavePlans. For each duplicate across
+    a set of SavePlans, only the SavePlan with the smallest planned storage keeps the entry.
+    """
+
+    write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
+    write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
+    for plan_idx, plan in enumerate(all_plans):
+        for write_item in plan.items:
+            # map each write item to its plan
+            write_item_to_plan_indices[write_item.index].add(plan_idx)
+            write_item_idx_to_write_item[write_item.index] = write_item
+
+    # put item in the plan with the smallest size and remove it from the other plan_indices
+    to_remove: List[Set] = [set() for _ in range(len(all_plans))]
+    plan_to_size = [0] * len(all_plans)
+    for write_item_idx, plan_indices in write_item_to_plan_indices.items():
+        select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_to_size[plan_idx])
+
+        write_item = write_item_idx_to_write_item[write_item_idx]
+        # essentially ignores the storage size of anything that is not a tensor, since
+        # we don't know how much storage they represent
+        plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
+
+        plan_indices.remove(select_plan_idx)
+        for plan_idx in plan_indices:
+            to_remove[plan_idx].add(write_item_idx)
+
+    for plan_idx, remove_set in enumerate(to_remove):
+        new_items = [
+            write_item
+            for write_item in all_plans[plan_idx].items
+            if write_item.index not in remove_set
+        ]
+        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
+
+    return all_plans
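The assignment rule above (each duplicated item stays only in the plan with the smallest accumulated storage) can be demonstrated with plain dictionaries; the item names and sizes below are made up.

```python
from collections import defaultdict

# plan index -> {item key: size in bytes}; key "b" is duplicated across plans 0 and 1.
plans = [{"a": 100, "b": 10}, {"b": 10, "c": 1}]

owner = {}                    # item key -> chosen plan index
plan_size = defaultdict(int)  # accumulated storage per plan
for key in sorted({k for plan in plans for k in plan}):
    candidates = [i for i, plan in enumerate(plans) if key in plan]
    pick = min(candidates, key=lambda i: plan_size[i])  # smallest plan so far wins
    owner[key] = pick
    plan_size[pick] += plans[pick][key]

deduped = [{k: v for k, v in plan.items() if owner[k] == i} for i, plan in enumerate(plans)]
assert deduped == [{"a": 100}, {"b": 10, "c": 1}]
```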
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eea84f6333bc93cc84b863d85f139bdb2b34fb2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_dedup_tensors.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import dataclasses
+import logging
+from typing import Dict, List
+
+from torch.distributed.checkpoint.metadata import MetadataIndex
+from torch.distributed.checkpoint.planner import SavePlan
+
+__all__ = ["dedup_tensors"]
+
+
+def init_logger() -> logging.Logger:
+    logger = logging.getLogger(__name__)
+    level = logging.INFO
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+logger = init_logger()
+
+
+# TODO add docstring for dedup_tensors
+def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
+    all_plans = list(all_plans)
+    key_to_plan: Dict[MetadataIndex, List[int]] = {}
+    for plan_idx, plan in enumerate(all_plans):
+        for write_item in plan.items:
+            key_to_plan.setdefault(write_item.index, []).append(plan_idx)
+
+    replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
+
+    # Remove duplicates by always keeping the first entry.
+    # Compute the per-rank remove set.
+    plan_to_keys: Dict[int, List[MetadataIndex]] = {}
+    for key, plans in replicated_items.items():
+        for plan_idx in plans[1:]:
+            plan_to_keys.setdefault(plan_idx, []).append(key)
+    if len(plan_to_keys) > 0:
+        logger.info("Duplicate keys to remove: %s", plan_to_keys)
+
+    for plan_idx, keys in plan_to_keys.items():
+        key_set = set(keys)
+        # rewrite items and remove elements
+        new_items = [
+            write_item
+            for write_item in all_plans[plan_idx].items
+            if write_item.index not in key_set
+        ]
+        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
+
+    return all_plans
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py
new file mode 100644
index 0000000000000000000000000000000000000000..8087bd3e0bad21fcd4b80206bf1e8b5e4d029445
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_fsspec_filesystem.py
@@ -0,0 +1,15 @@
+# Mypy will not try inferring the types of any 3rd party libraries installed.
+# mypy: ignore-errors
+
+import logging
+
+from torch.distributed.checkpoint.fsspec import (  # noqa: F401
+    FsspecReader,
+    FsspecWriter,
+)
+
+log = logging.getLogger(__name__)
+log.warning(
+    "FSSpec Filesystem has been made public, please update your "
+    "import to torch.distributed.checkpoint"
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb2bb77cf4db45a26e51613f61fd2dff2b24f10
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_nested_dict.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import Dict, Tuple
+
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
+
+"""
+TODO:
+Need to add ability to handle tuple, OrderedDict, NamedTuple.
+Update mappings from dict to a class.
+Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.
+"""
+
+
+FLATTEN_MAPPING = Dict[str, OBJ_PATH]
+
+
+# TODO: Update Docstring for nested_dict.py
+def flatten_state_dict(
+    state_dict: STATE_DICT_TYPE,
+) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
+    """
+    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.
+
+    Use ``unflatten_state_dict`` to revert this process.
+    Returns:
+        A tuple of the flattened state_dict and a mapping from each new key back to its original object path.
+    N.B. The new keys are derived from the object paths, joined by dot.
+        For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.
+    """
+    flattened: STATE_DICT_TYPE = {}
+    mappings: FLATTEN_MAPPING = {}
+
+    def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
+        new_fqn = ".".join(map(str, path))
+        if new_fqn in flattened:
+            raise ValueError(f"duplicated flatten key {new_fqn}")
+        flattened[new_fqn] = value
+        mappings[new_fqn] = path
+
+    traverse_state_dict(state_dict, flat_copy)
+    return flattened, mappings
+
+
+def unflatten_state_dict(
+    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
+) -> STATE_DICT_TYPE:
+    """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
+    nested: STATE_DICT_TYPE = {}
+    for key, value in state_dict.items():
+        set_element(nested, mapping[key], value)
+    return nested
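A dependency-free sketch of the dot-joined flattening these helpers perform. It handles only nested dicts, whereas the real implementation also traverses lists via ``traverse_state_dict``; the function names ``flatten``/``unflatten`` are local stand-ins.

```python
def flatten(d, prefix=()):
    # Walk the nested dict, recording dot-joined keys and their original paths.
    out, mapping = {}, {}
    for k, v in d.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            sub, sub_map = flatten(v, path)
            out.update(sub)
            mapping.update(sub_map)
        else:
            key = ".".join(map(str, path))
            out[key] = v
            mapping[key] = path
    return out, mapping


def unflatten(flat, mapping):
    # Rebuild the nested structure from the recorded paths.
    nested = {}
    for key, value in flat.items():
        node = nested
        *parents, leaf = mapping[key]
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = value
    return nested


flat, mapping = flatten({"a": {"b": 1}, "c": 2})
assert flat == {"a.b": 1, "c": 2}
assert unflatten(flat, mapping) == {"a": {"b": 1}, "c": 2}
```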
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..27807cfc768979c6c23b53bd6facd4c254c7ecbd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_sharded_tensor_utils.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import copy
+
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
+from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+from torch.distributed.remote_device import _remote_device
+
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
+from .utils import _element_wise_add, _normalize_device_info
+
+
+# TODO: We need to refactor this code.
+def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+    r"""
+    Transform ``state_dict`` by flattening all nested ShardedTensor instances found.
+
+    The resulting ShardedTensor instances are only correct with respect to the local shard and
+    MUST NOT be used for anything other than checkpointing, as no operator will work with them.
+
+    This function should be used in conjunction with a state_dict produced by FSDP's
+    StateDictType.SHARDED_STATE_DICT methods.
+    """
+    new_state_dict: STATE_DICT_TYPE = {}
+
+    def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
+        if not isinstance(value, ShardedTensor):
+            set_element(new_state_dict, path, value)
+            return
+        shards = value.local_shards()
+
+        if len(shards) == 0:
+            return
+        if len(shards) != 1:
+            set_element(new_state_dict, path, value)
+            return
+
+        outer_shard = shards[0]
+
+        inner_st = outer_shard.tensor
+        if not isinstance(inner_st, ShardedTensor):
+            set_element(new_state_dict, path, value)
+            return
+
+        if len(inner_st.local_shards()) != 1:
+            raise ValueError("Cannot handle inner tensor with more than 1 shard")
+        inner_shard = inner_st.local_shards()[0]
+
+        local_shards = [
+            Shard(
+                tensor=inner_shard.tensor,
+                metadata=ShardMetadata(
+                    shard_offsets=_element_wise_add(
+                        outer_shard.metadata.shard_offsets,
+                        inner_shard.metadata.shard_offsets,
+                    ),
+                    shard_sizes=inner_shard.metadata.shard_sizes,
+                    placement=f"rank:{dist.get_rank()}/{inner_shard.tensor.device}",
+                ),
+            )
+        ]
+
+        st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
+        other_rank = 0 if dist.get_rank() > 0 else 1
+        device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)
+
+        # Remove the outer ST shard the inner ST covers
+        for i, shard_md in enumerate(st_meta.shards_metadata):
+            if shard_md.shard_offsets == outer_shard.metadata.shard_offsets:
+                st_meta.shards_metadata.pop(i)
+                break
+
+        # Attribute the remaining shards to the other rank
+        for shard_md in st_meta.shards_metadata:
+            shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")
+
+        # Add other inner shards from the inner tensor
+        for inner_md in inner_st.metadata().shards_metadata:
+            if inner_md.shard_offsets != inner_shard.metadata.shard_offsets:
+                st_meta.shards_metadata.append(
+                    ShardMetadata(
+                        shard_offsets=_element_wise_add(
+                            outer_shard.metadata.shard_offsets,
+                            inner_md.shard_offsets,
+                        ),
+                        shard_sizes=inner_md.shard_sizes,
+                        placement=f"rank:{other_rank}/{device_info}",
+                    )
+                )
+
+        # Finally add this shard
+        st_meta.shards_metadata.append(local_shards[0].metadata)
+
+        st = ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards=local_shards,
+            sharded_tensor_metadata=st_meta,
+        )
+        set_element(new_state_dict, path, st)
+
+    traverse_state_dict(state_dict, rewrite_dict)
+    return new_state_dict
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d66e1a605270766893c21235efc63930cc7b79
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_storage_utils.py
@@ -0,0 +1,50 @@
+import os
+from typing import List, Type, Union
+
+from .filesystem import FileSystemReader, FileSystemWriter
+
+from .storage import StorageReader, StorageWriter
+
+
+def _storage_setup(
+    storage: Union[StorageReader, StorageWriter, None],
+    checkpoint_id: Union[str, os.PathLike, None],
+    reader: bool = False,
+) -> Union[None, StorageReader, StorageWriter]:
+    if storage:
+        if checkpoint_id is not None:
+            storage.reset(checkpoint_id)
+        return storage
+
+    if not checkpoint_id:
+        raise RuntimeError(
+            "`checkpoint_id` must be specified if "
+            "storage_reader/storage_writer is None."
+        )
+
+    targets: List[Type[Union[StorageReader, StorageWriter]]] = []
+    if reader:
+        targets = [
+            FileSystemReader,
+        ]
+    else:
+        targets = [
+            FileSystemWriter,
+        ]
+    try:
+        from .fsspec import FsspecReader, FsspecWriter
+
+        targets.append(FsspecReader if reader else FsspecWriter)
+    except Exception:
+        pass
+
+    for target in targets:
+        if target.validate_checkpoint_id(checkpoint_id):
+            storage = target(checkpoint_id)  # type: ignore[call-arg]
+            storage.reset(checkpoint_id)
+            return storage
+
+    raise RuntimeError(
+        "Cannot detect which StorageReader or StorageWriter to use. "
+        "Please specify the storage_reader/storage_writer."
+    )
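+
+
+# Editor's sketch (not part of the upstream module): _storage_setup resolves a reader or
+# writer from a plain filesystem path when no explicit storage object is given. The
+# checkpoint path below is hypothetical and assumed writable.
+#
+#     writer = _storage_setup(None, checkpoint_id="/tmp/ckpt", reader=False)
+#     assert isinstance(writer, FileSystemWriter)
+#
+#     reader = _storage_setup(None, checkpoint_id="/tmp/ckpt", reader=True)
+#     assert isinstance(reader, FileSystemReader)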
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/_traverse.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_traverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b40729b282424b127cf009cde3532abaf3792ff
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/_traverse.py
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import (
+    Callable,
+    cast,
+    Collection,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
+
+import torch
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+
+PATH_ITEM = Union[str, int]
+OBJ_PATH = Tuple[PATH_ITEM, ...]
+T = TypeVar("T")
+
+STATE_DICT_ITEM = object
+CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
+
+__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
+
+
+def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
+    return isinstance(value, torch.Tensor)
+
+
+# TODO: update docstring for traverse.py
+def traverse_state_dict(
+    state_dict: STATE_DICT_TYPE,
+    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
+    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
+) -> None:
+    """
+    Invoke ``visitor`` for each value recursively in ``state_dict``.
+
+    Traversal is short-circuited when it finds a collection for which ``keep_traversing`` evaluates
+    to false for all elements.
+    By default, all collections with at least one ``torch.Tensor`` element are traversed.
+    ``visitor`` takes a path argument that is a tuple of the keys used to reach it.
+    """
+
+    # a value is terminal if it has no container values inside it
+    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
+        values: Collection[STATE_DICT_ITEM]
+        if isinstance(value, Mapping):
+            values = value.values()
+        elif isinstance(value, list):
+            values = value
+        else:
+            return True
+
+        for entry in values:
+            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
+                return False
+            if keep_traversing is not None and keep_traversing(entry):
+                return False
+        return True
+
+    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
+        if _is_terminal(value):
+            visitor(path, value)
+        elif isinstance(value, Mapping):
+            for k, v in value.items():
+                _traverse_obj(path + (str(k),), v)
+        elif isinstance(value, list):
+            for i, v in enumerate(value):
+                _traverse_obj(path + (i,), v)
+
+    for key, value in state_dict.items():
+        _traverse_obj((str(key),), value)
+
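+# Editor's sketch (not part of the upstream module): collecting every leaf that
+# traverse_state_dict visits, together with its object path.
+#
+#     found = {}
+#     traverse_state_dict(
+#         {"opt": {"step": torch.tensor(3)}, "lr": 0.1},
+#         lambda path, value: found.setdefault(path, value),
+#     )
+#     # found == {("opt", "step"): tensor(3), ("lr",): 0.1}
+#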
+
+def set_element(
+    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
+) -> None:
+    """Set ``value`` in ``root_dict`` along the ``path`` object path."""
+    cur_container = cast(CONTAINER_TYPE, root_dict)
+
+    def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None:
+        while len(lst) <= idx:
+            lst.append(None)
+
+    for i in range(1, len(path)):
+        prev_key = path[i - 1]
+        key = path[i]
+        def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
+
+        if isinstance(cur_container, Mapping):
+            cur_container = cast(
+                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
+            )
+        else:
+            extend_list(cur_container, prev_key)
+            if cur_container[prev_key] is None:
+                cur_container[prev_key] = def_val
+            cur_container = cur_container[prev_key]
+
+    key = path[-1]
+    if type(key) == int:
+        extend_list(cast(List[STATE_DICT_ITEM], cur_container), key)
+
+    cur_container[key] = value
+
+
+def get_element(
+    root_dict: STATE_DICT_TYPE,
+    path: OBJ_PATH,
+    default_value: Optional[T] = None,
+) -> Optional[T]:
+    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found."""
+    cur_value = cast(CONTAINER_TYPE, root_dict)
+    for part in path:
+        if type(part) is int:
+            if not isinstance(cur_value, list) or len(cur_value) <= part:
+                return default_value
+        elif not isinstance(cur_value, Mapping) or part not in cur_value:
+            return default_value
+
+        cur_value = cast(CONTAINER_TYPE, cur_value[part])
+    return cast(Optional[T], cur_value)
+
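+# Editor's sketch (not part of the upstream module): building and querying a nested
+# structure with set_element/get_element; integer path components index into lists.
+#
+#     d: STATE_DICT_TYPE = {}
+#     set_element(d, ("a", "b", 1), 42)
+#     # d == {"a": {"b": [None, 42]}}
+#     assert get_element(d, ("a", "b", 1)) == 42
+#     assert get_element(d, ("a", "missing"), default_value=0) == 0
+#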
+
+def _print_nested(
+    value: STATE_DICT_ITEM,
+    prefix: str = "",
+    print_fun: Callable[[str], None] = print,
+) -> None:
+    if type(value) is ShardedTensor:
+        print_fun(f"{prefix} ShardedTensor size: {value.size()}")
+        for shard in value.local_shards():
+            _print_nested(
+                shard.tensor,
+                f"{shard.metadata.shard_offsets} ",
+                print_fun=print_fun,
+            )
+    elif type(value) is DTensor:
+        print_fun(f"{prefix} DistributedTensor size: {value.size()}")
+        # TODO: add local offset for _local_tensor in print_nested.
+        _print_nested(
+            value._local_tensor,
+            print_fun=print_fun,
+        )
+    elif isinstance(value, torch.Tensor):
+        print_fun(f"{prefix} Tensor size: {value.size()}")
+    else:
+        print_fun(f"{prefix} Type: {type(value)}")
+
+
+def print_tensor(
+    path: OBJ_PATH,
+    value: STATE_DICT_ITEM,
+    print_fun: Callable[[str], None] = print,
+) -> None:
+    """
+    Use this callback with traverse_state_dict to print its content.
+
+    By default the content is printed using the builtin ``print``, but this can
+    be changed by passing a different ``print_fun`` callable.
+    """
+    _print_nested(value, prefix=str(path), print_fun=print_fun)
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/api.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a08215d4751ae7b3594f4e4ac67d72c3d805a3b4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/api.py
@@ -0,0 +1,41 @@
+import traceback as tb
+from typing import Any, Dict, Tuple
+
+WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
+
+__all__ = ["CheckpointException"]
+
+
+def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
+    return (exc, tb.extract_tb(exc.__traceback__))
+
+
+def _is_wrapped_exception(obj: Any) -> bool:
+    if not isinstance(obj, tuple):
+        return False
+    if len(obj) != 2:
+        return False
+    return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
+
+
+class CheckpointException(BaseException):
+    """Exception raised if failure was detected as part of a checkpoint load or save."""
+
+    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
+        super().__init__(msg, failures)
+        self._failures = failures
+
+    @property
+    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
+        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
+        return self._failures
+
+    def __str__(self):
+        msg = f"CheckpointException ranks:{self._failures.keys()}\n"
+        for rank, exc_pair in self._failures.items():
+            exc, trace = exc_pair
+            msg += f"Traceback (most recent call last): (RANK {rank})\n"
+            if trace is not None:
+                msg += "".join(tb.format_list(trace))
+            msg += "".join(tb.format_exception_only(type(exc), value=exc))
+        return msg
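+
+
+# Editor's sketch (not part of the upstream module): how a per-rank failure is typically
+# wrapped and surfaced through CheckpointException.
+#
+#     try:
+#         raise RuntimeError("storage write failed")
+#     except RuntimeError as e:
+#         failures = {0: _wrap_exception(e)}  # rank 0 failed
+#     raise CheckpointException("checkpoint save failed", failures)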
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/default_planner.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/default_planner.py
new file mode 100644
index 0000000000000000000000000000000000000000..df49dfca05ec1c6e1c2a5ec20669df7b64194ebc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/default_planner.py
@@ -0,0 +1,420 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import dataclasses
+import io
+import logging
+import operator
+from collections import ChainMap
+from functools import reduce
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
+from torch.distributed.checkpoint._nested_dict import (
+    FLATTEN_MAPPING,
+    flatten_state_dict,
+)
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
+from torch.distributed.checkpoint._traverse import set_element
+from torch.distributed.checkpoint.metadata import (
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    Metadata,
+    MetadataIndex,
+    STATE_DICT_TYPE,
+    STORAGE_TYPES,
+    TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import (
+    LoadPlan,
+    LoadPlanner,
+    ReadItem,
+    SavePlan,
+    SavePlanner,
+    WriteItem,
+    WriteItemType,
+)
+from torch.distributed.checkpoint.planner_helpers import (
+    _create_default_metadata_only_plan,
+    _create_read_items,
+    _create_write_items,
+    _init_state_dict,
+)
+from torch.distributed.checkpoint.utils import find_state_dict_object
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+__all__ = [
+    "DefaultSavePlanner",
+    "DefaultLoadPlanner",
+    "create_default_local_load_plan",
+    "create_default_global_load_plan",
+    "create_default_local_save_plan",
+    "create_default_global_save_plan",
+]
+
+
+# TODO: Update docstrings for default_planner.py
+class DefaultSavePlanner(SavePlanner):
+    mappings: FLATTEN_MAPPING
+
+    def __init__(
+        self,
+        flatten_state_dict: bool = True,
+        flatten_sharded_tensors: bool = True,
+        dedup_replicated_tensors: Optional[bool] = None,
+    ) -> None:
+        self.flatten_state_dict = flatten_state_dict
+        self.flatten_sharded_tensors = flatten_sharded_tensors
+        self.mappings = {}
+
+        if dedup_replicated_tensors is not None:
+            logger.warning(
+                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
+                "deprecated, and no longer has any effect. Please remove this argument "
+                "from your call."
+            )
+
+    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+        self.state_dict = state_dict
+        self.is_coordinator = is_coordinator
+
+    def create_local_plan(self) -> SavePlan:
+        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
+        if self.flatten_state_dict:
+            plan = dataclasses.replace(plan, planner_data=self.mappings)
+        self.plan = plan
+
+        return self.plan
+
+    def create_global_plan(
+        self, all_plans: List[SavePlan]
+    ) -> Tuple[List[SavePlan], Metadata]:
+        all_plans = dedup_save_plans(all_plans)
+
+        global_plan, metadata = create_default_global_save_plan(all_plans)
+
+        if self.flatten_state_dict:
+            # | does not work for Python 3.8 or older versions.
+            # merged_mappings = reduce(
+            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
+            # )
+            planner_data_dict = [p.planner_data for p in global_plan]
+            merged_mappings = dict(ChainMap(*planner_data_dict))
+            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
+
+        if not _validate_global_plan(global_plan, metadata):
+            raise ValueError("Failed to validate global plan")
+
+        self.global_plan = global_plan
+        self.metadata = metadata
+
+        return self.global_plan, self.metadata
+
+    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
+        self.plan = new_plan
+        return new_plan
+
+    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
+        object = self.lookup_object(write_item.index)
+        return self.transform_object(write_item, object)
+
+    def lookup_object(self, index: MetadataIndex) -> Any:
+        """Extension from the planner interface to make it easy to extend the default planner."""
+        return find_state_dict_object(self.state_dict, index)
+
+    def transform_object(self, write_item: WriteItem, object: Any):
+        """Extension from the planner interface to make it easy to extend the default planner."""
+        if write_item.type == WriteItemType.BYTE_IO:
+            bytes = io.BytesIO()
+            torch.save(object, bytes)
+            object = bytes
+        return object
+
+
+class DefaultLoadPlanner(LoadPlanner):
+    """
+    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
+
+    In particular it adds the following:
+
+    flatten_state_dict: Handle state_dict with nested dicts
+    flatten_sharded_tensors: For FSDP in 2D parallel mode
+    """
+
+    original_state_dict: STATE_DICT_TYPE
+    mappings: FLATTEN_MAPPING
+
+    def __init__(
+        self,
+        flatten_state_dict: bool = True,
+        flatten_sharded_tensors: bool = True,
+    ) -> None:
+        self.flatten_state_dict = flatten_state_dict
+        self.flatten_sharded_tensors = flatten_sharded_tensors
+        self.original_state_dict = {}
+        self.mappings = {}
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        _init_state_dict(state_dict)
+        self.original_state_dict = state_dict
+
+        if self.flatten_sharded_tensors:
+            state_dict = _flatten_sharded_tensors(state_dict)
+
+        if self.flatten_state_dict:
+            state_dict, self.mappings = flatten_state_dict(state_dict)
+
+        self.state_dict = state_dict
+        self.metadata = metadata
+        self.is_coordinator = is_coordinator
+
+    def create_local_plan(self) -> LoadPlan:
+        return create_default_local_load_plan(self.state_dict, self.metadata)
+
+    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+        return create_default_global_load_plan(global_plan)
+
+    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
+        return new_plan
+
+    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
+        if self.flatten_state_dict:
+            set_element(
+                self.original_state_dict,
+                self.mappings[read_item.dest_index.fqn],
+                torch.load(value),
+            )
+        else:
+            self.state_dict[read_item.dest_index.fqn] = torch.load(value)
+
+    def resolve_tensor(self, read_item: ReadItem):
+        tensor = self.lookup_tensor(read_item.dest_index)
+        return self.transform_tensor(read_item, tensor)
+
+    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
+        pass
+
+    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
+        """Extension from the planner interface to make it easy to extend the default planner."""
+        return find_state_dict_object(self.state_dict, index)
+
+    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
+        """Extension from the planner interface to make it easy to extend the default planner."""
+        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
+
+
+def create_default_local_load_plan(
+    state_dict: Dict[str, Any],
+    metadata: Metadata,
+) -> LoadPlan:
+    """
+    Create the ``LoadPlan`` used by DefaultLoadPlanner.
+
+    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
+
+    The default behavior is to match keys exactly between state_dict and metadata.
+    It handles resharding by issuing multiple read requests against storage in order to match
+    load requirements.
+    """
+    requests = []
+
+    for fqn, obj in state_dict.items():
+        md = metadata.state_dict_metadata[fqn]
+        # Since DTensor supports submesh, add an extra check to ensure _create_read_items()
+        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
+        if isinstance(obj, DTensor):
+            if obj.device_mesh.get_coordinate() is not None:
+                requests += _create_read_items(fqn, md, obj)
+        else:
+            requests += _create_read_items(fqn, md, obj)
+
+    return LoadPlan(requests)
+
+
+def create_default_global_load_plan(
+    all_plans: List[LoadPlan],
+) -> List[LoadPlan]:
+    """
+    Create global load plan used by DefaultLoadPlanner.
+
+    The default load behavior involves no global coordination, and this function
+    currently doesn't change the local plans.
+    """
+    return all_plans
+
+
+def create_default_local_save_plan(
+    state_dict: Dict[str, Any], is_coordinator: bool
+) -> SavePlan:
+    """
+    Create the ``SavePlan`` used by DefaultSavePlanner.
+
+    On non-coordinator ranks, this function produces write items only for tensor values
+    (including ShardedTensor and DTensor) and skips non-tensor objects.
+
+    On the coordinator rank, it produces write items for all values.
+    """
+    requests = []
+    for fqn, obj in state_dict.items():
+        # Since DTensor supports submesh, add an extra check to ensure _create_write_items()
+        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
+        if isinstance(obj, DTensor):
+            if obj.device_mesh.get_coordinate() is not None:
+                requests += _create_write_items(fqn, obj)
+        elif isinstance(obj, (torch.Tensor)) or is_coordinator:
+            requests += _create_write_items(fqn, obj)
+
+    return SavePlan(requests)
+
+
+def create_default_global_save_plan(
+    all_plans: List[SavePlan],
+    rewrite_index_hints: bool = True,
+) -> Tuple[List[SavePlan], Metadata]:
+    """
+    Create the global plan and metadata used by DefaultSavePlanner.
+
+    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.
+
+    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
+    ``rewrite_index_hints`` is True.
+    """
+    md: Dict[str, STORAGE_TYPES] = {}
+    new_plans = []
+    for plan in all_plans:
+        new_items = []
+        for item in plan.items:
+            if item.type != WriteItemType.SHARD:
+                assert item.index.fqn not in md
+
+            if item.type == WriteItemType.BYTE_IO:
+                md[item.index.fqn] = BytesStorageMetadata()
+                new_items.append(item)
+            else:
+                assert item.tensor_data is not None
+                tensor_md = cast(
+                    TensorStorageMetadata,
+                    md.setdefault(
+                        item.index.fqn,
+                        TensorStorageMetadata(
+                            properties=item.tensor_data.properties,
+                            size=item.tensor_data.size,
+                            chunks=[],
+                        ),
+                    ),
+                )
+                new_item = item
+                if rewrite_index_hints:
+                    new_index = dataclasses.replace(
+                        item.index, index=len(tensor_md.chunks)
+                    )
+                    new_item = dataclasses.replace(item, index=new_index)
+                new_items.append(new_item)
+
+                assert (
+                    item.tensor_data.chunk is not None
+                ), f"""
+                    Cannot create MD for tensor without bounds.
+                    FQN: {item.index.fqn}
+                """
+                tensor_md.chunks.append(item.tensor_data.chunk)
+        new_plans.append(dataclasses.replace(plan, items=new_items))
+    return (new_plans, Metadata(md))
+
+
+def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
+    """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``."""
+    plan = _create_default_metadata_only_plan(state_dict)
+    _, md = create_default_global_save_plan([plan])
+    return md
+
+
+def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
+    """Check whether two boxes, described by their ``offsets`` and ``sizes``, overlap."""
+    # For each dim of each shard, check if one shard resides on the other
+    # end of second shard with respect to that dim. As an example for a 2D
+    # shard, we would check if one shard is above or on the left of the
+    # other shard.
+    ndims = len(box0.offsets)
+    for i in range(ndims):
+        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
+            return False
+        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
+            return False
+
+    return True
+
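+# Editor's sketch (not part of the upstream module): two 2x2 chunks that share the cell
+# at (1, 1) overlap, while a chunk starting where the other ends in every dim does not.
+#
+#     a = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([2, 2]))
+#     b = ChunkStorageMetadata(offsets=torch.Size([1, 1]), sizes=torch.Size([2, 2]))
+#     c = ChunkStorageMetadata(offsets=torch.Size([2, 2]), sizes=torch.Size([2, 2]))
+#     assert _check_box_overlap(a, b) is True
+#     assert _check_box_overlap(a, c) is False
+#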
+
+def _check_box_bounds(
+    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
+) -> bool:
+    for i in range(len(outer_box_size)):
+        if inner_box.offsets[i] < 0:
+            return False
+        if inner_box.sizes[i] < 0:
+            return False
+        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
+            return False
+
+    return True
+
+
+def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
+    all_good = True
+    for key, value in metadata.state_dict_metadata.items():
+        if isinstance(value, BytesStorageMetadata):
+            continue
+        if len(value.size) == 0:
+            continue
+        chunks_volume = 0
+        for chunk_idx, chunk0 in enumerate(value.chunks):
+            # Compute the volume
+            if not _check_box_bounds(value.size, chunk0):
+                logger.warning(
+                    """
+                        key:%s has out of bounds chunk:
+                        tensor-size:%s chunk: %s
+                    """,
+                    key,
+                    value.size,
+                    chunk0,
+                )
+                all_good = False
+            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
+
+            # Check for overlap
+            for chunk1 in value.chunks[chunk_idx + 1 :]:
+                if _check_box_overlap(chunk0, chunk1):
+                    logger.warning(
+                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
+                    )
+                    all_good = False
+
+        # Check whether the combined chunks cover the whole tensor
+        tensor_volume = reduce(operator.mul, value.size, 1)
+        if chunks_volume != tensor_volume:
+            logger.warning(
+                """
+                    key:%s invalid fill tensor-volume:
+                    %s chunks-volume: %s
+                """,
+                key,
+                tensor_volume,
+                chunks_volume,
+            )
+            all_good = False
+
+    return all_good
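+
+
+# Editor's sketch (not part of the upstream module): saving a state_dict with the default
+# planner through the package-level entry point. Assumes torch.distributed.checkpoint is
+# importable as ``dcp`` in this release; the checkpoint directory below is hypothetical,
+# and the call falls back to no-dist mode when no process group is initialized.
+#
+#     import torch
+#     import torch.distributed.checkpoint as dcp
+#
+#     dcp.save(
+#         {"weight": torch.zeros(4)},
+#         storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
+#         planner=DefaultSavePlanner(),
+#     )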
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/filesystem.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/filesystem.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f7b081ee6408bebed6974c670429683a46784b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/filesystem.py
@@ -0,0 +1,618 @@
+import collections
+import dataclasses
+import io
+import os
+import pickle
+import queue
+import threading
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import (
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    IO,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import torch
+from torch import Tensor
+from torch._utils import _get_available_device_type, _get_device_module
+from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.futures import Future
+
+from .metadata import Metadata, MetadataIndex
+from .planner import (
+    LoadItemType,
+    LoadPlan,
+    LoadPlanner,
+    ReadItem,
+    SavePlan,
+    SavePlanner,
+    WriteItem,
+    WriteItemType,
+)
+from .storage import StorageReader, StorageWriter, WriteResult
+from .utils import _create_file_view
+
+__all__ = ["FileSystemWriter", "FileSystemReader"]
+
+
+@dataclass
+class _StorageInfo:
+    """This is the per entry storage info."""
+
+    relative_path: str
+    offset: int
+    length: int
+
+
+@dataclass
+class _StoragePrefix:
+    prefix: str
+
+
+DEFAULT_SUFFIX = ".distcp"
+
+
+class _TensorLoader(ABC):
+    @abstractmethod
+    def add(self, size: int, obj: object) -> None:
+        pass
+
+    @abstractmethod
+    def start_loading(self) -> None:
+        pass
+
+    @abstractmethod
+    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
+        pass
+
+
+class _SerialCpuLoader(_TensorLoader):
+    def __init__(self, resolve_fun: Callable) -> None:
+        self.resolve_fun = resolve_fun
+        self.items: List[Tuple[int, object]] = []
+
+    def add(self, size: int, obj: object) -> None:
+        self.items.append((size, obj))
+
+    def start_loading(self) -> None:
+        pass
+
+    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
+        for _, obj in self.items:
+            tensor = self.resolve_fun(obj).detach()
+            tensor = tensor.cpu()
+            if tensor.storage().size() != tensor.numel():
+                tensor = tensor.clone()
+            yield (
+                tensor,
+                obj,
+            )
+
+
+class _OverlappingCpuLoader(_TensorLoader):
+    def __init__(
+        self,
+        resolve_fun: Callable,
+        stream: Optional[torch.Stream] = None,
+        inflight_threshhold: int = 1_000_000,
+    ) -> None:
+        self.resolve_fun = resolve_fun
+        self.items: List[Tuple[int, object]] = []
+        self.inflight_threshhold = inflight_threshhold
+        self.in_flight_data = 0
+        self.current_items: collections.deque = collections.deque()
+        self.idx = 0
+        self.started = False
+        self.device_type = (
+            stream.device_type if stream else _get_available_device_type()
+        )
+        self.device_module = _get_device_module(self.device_type)
+        self.stream = cast(
+            torch.cuda.Stream, stream or self.device_module.current_stream()
+        )
+        if self.stream != self.device_module.current_stream():
+            self.stream.wait_stream(self.device_module.current_stream())
+
+    @property
+    def _done(self) -> bool:
+        return self.idx >= len(self.items)
+
+    def _drain(self) -> List[Tuple[torch.Tensor, object]]:
+        drained = []
+        if self.in_flight_data >= self.inflight_threshhold:
+            self.stream.synchronize()
+        while self.in_flight_data >= self.inflight_threshhold:
+            val = self.current_items.popleft()
+            self.in_flight_data -= val[0].numel() * val[0].element_size()
+            drained.append(val)
+        return drained
+
+    def _refill(self) -> None:
+        with self.device_module.stream(self.stream):
+            while not self._done and self.in_flight_data < self.inflight_threshhold:
+                _, obj = self.items[self.idx]
+                self.idx += 1
+                tensor = self.resolve_fun(obj).detach()
+                if tensor.device.type == self.device_type:
+                    tensor = tensor.to(device="cpu", non_blocking=True)
+                elif tensor.device == torch.device("cpu"):
+                    if (
+                        tensor.untyped_storage().size()
+                        != tensor.numel() * tensor.itemsize
+                    ):
+                        # this forces the tensor to be both contiguous and with minimal storage
+                        tensor = tensor.clone()
+
+                self.current_items.append(
+                    (
+                        tensor,
+                        obj,
+                    )
+                )
+                self.in_flight_data += tensor.numel() * tensor.element_size()
+
+    def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
+        assert self._done
+        if len(self.current_items) > 0:
+            self.stream.synchronize()
+        return self.current_items
+
+    def add(self, size: int, obj: object) -> None:
+        if self.started:
+            raise RuntimeError("cannot add items after loading started")
+        self.items.append((size, obj))
+
+    def start_loading(self) -> None:
+        if self.started:
+            return
+        self.started = True
+        self.items.sort(key=lambda x: x[0])
+        self._refill()
+
+    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
+        self.start_loading()
+        while not self._done:
+            drained = self._drain()
+            self._refill()
+            yield from drained
+
+        yield from self._finish()
+
+
+def _item_size(item: WriteItem) -> int:
+    size = 1
+    assert item.tensor_data is not None
+    # can't use math.prod as PT needs to support older python
+    for s in item.tensor_data.size:
+        size *= s
+
+    dtype = item.tensor_data.properties.dtype
+    return size * torch._utils._element_size(dtype)
+
+
+def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
+    if bins == 1:
+        return [items]
+
+    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
+    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
+
+    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
+    bucket_sizes = [0 for _ in range(bins)]
+
+    tensor_w.sort(key=_item_size, reverse=True)
+
+    for i, wi in enumerate(bytes_w):
+        buckets[i % bins].append(wi)
+
+    for wi in tensor_w:
+        # TODO replace with heapq
+        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
+        buckets[idx].append(wi)
+        bucket_sizes[idx] += _item_size(wi)
+
+    return buckets
+
+
+def _write_item(
+    stream: io.IOBase,
+    data: Union[io.BytesIO, torch.Tensor],
+    write_item: WriteItem,
+    storage_key: str,
+) -> WriteResult:
+    offset = stream.tell()
+
+    if write_item.type == WriteItemType.BYTE_IO:
+        assert isinstance(data, io.BytesIO)
+        stream.write(data.getbuffer())
+    else:
+        assert isinstance(data, torch.Tensor)
+        assert data.device == torch.device("cpu")
+        torch.save(data, cast(IO[bytes], stream))
+    length = stream.tell() - offset
+
+    return WriteResult(
+        index=write_item.index,
+        size_in_bytes=length,
+        storage_data=_StorageInfo(storage_key, offset, length),
+    )
+
+
+def _write_files_from_queue(
+    create_stream: Callable,
+    file_queue: queue.Queue,
+    result_queue: queue.Queue,
+    planner: SavePlanner,
+    inflight_threshhold: int,
+    use_fsync: bool,
+    thread_count: int,
+) -> None:
+    try:
+        while True:
+            file_name, storage_key, write_items = file_queue.get_nowait()
+            loader: _TensorLoader
+
+            custom_backend_name = torch._C._get_privateuse1_backend_name()
+            custom_device_mod = getattr(torch, custom_backend_name, None)
+
+            # TODO: Using the OverlappingCpuLoader with multiple threads creates significant
+            # performance degradation, observed as being related to cuda stream syncs. We
+            # should try to fix this and use _OverlappingCpuLoader for all threaded cases
+            if (
+                thread_count == 1
+                and (
+                    torch.cuda.is_available()
+                    or (custom_device_mod and custom_device_mod.is_available())
+                )
+                and inflight_threshhold > 0
+            ):
+                loader = _OverlappingCpuLoader(
+                    planner.resolve_data,
+                    inflight_threshhold=inflight_threshhold,
+                )
+            else:
+                loader = _SerialCpuLoader(
+                    planner.resolve_data,
+                )
+
+            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
+            for write_item in tensor_w:
+                loader.add(_item_size(write_item), write_item)
+            loader.start_loading()
+
+            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
+            write_results = []
+
+            with create_stream(file_name, "wb") as stream:
+                for write_item in bytes_w:
+                    data = planner.resolve_data(write_item)
+                    write_results.append(
+                        _write_item(stream, data, write_item, storage_key)
+                    )
+
+                for tensor, write_item in loader.values():
+                    assert tensor.is_cpu
+                    write_results.append(
+                        _write_item(stream, tensor, write_item, storage_key)
+                    )
+
+                if use_fsync:
+                    try:
+                        os.fsync(stream.fileno())
+                    except AttributeError:
+                        os.sync()
+            result_queue.put(write_results)
+    except queue.Empty:
+        pass
+
+
+class FileSystemBase(ABC):
+    @contextmanager
+    @abstractmethod
+    def create_stream(
+        self, path: Union[str, os.PathLike], mode: str
+    ) -> Generator[io.IOBase, None, None]:
+        ...
+
+    @abstractmethod
+    def concat_path(
+        self, path: Union[str, os.PathLike], suffix: str
+    ) -> Union[str, os.PathLike]:
+        ...
+
+    @abstractmethod
+    def rename(
+        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
+    ) -> None:
+        ...
+
+    @abstractmethod
+    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+        ...
+
+    @abstractmethod
+    def mkdir(self, path: Union[str, os.PathLike]) -> None:
+        ...
+
+    @classmethod
+    @abstractmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        ...
+
+
+class FileSystem(FileSystemBase):
+    @contextmanager
+    def create_stream(
+        self, path: Union[str, os.PathLike], mode: str
+    ) -> Generator[io.IOBase, None, None]:
+        with cast(Path, path).open(mode) as stream:
+            yield cast(io.IOBase, stream)
+
+    def concat_path(
+        self, path: Union[str, os.PathLike], suffix: str
+    ) -> Union[str, os.PathLike]:
+        return cast(Path, path) / suffix
+
+    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+        if not isinstance(path, Path):
+            path = Path(path)
+        return path
+
+    def rename(
+        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
+    ) -> None:
+        cast(Path, path).rename(cast(Path, new_path))
+
+    def mkdir(self, path: Union[str, os.PathLike]) -> None:
+        cast(Path, path).mkdir(parents=True, exist_ok=True)
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        if isinstance(checkpoint_id, Path):
+            return True
+
+        if "://" in str(checkpoint_id):
+            return False
+
+        for p in Path(checkpoint_id).parents:
+            if p.exists() and os.access(str(p), os.W_OK):
+                return True
+
+        return False
+
+
+class FileSystemWriter(StorageWriter):
+    """
+    Basic implementation of StorageWriter using file IO.
+
+    This implementation makes the following assumptions and simplifications:
+
+    * The checkpoint path is an empty or non-existing directory.
+    * File creation is atomic
+
+    The checkpoint consists of one file per write request plus
+    a `.metadata` file with the serialized metadata.
+
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+    ) -> None:
+        """
+        Initialize the writer pointing to `path`.
+
+        Args:
+            path: directory where the checkpoint will be written to.
+            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
+            sync_files: Force files to be synced to permanent storage. Defaults to True.
+            thread_count: Number of IO threads to use to write. Defaults to 1.
+            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10 MB.
+
+        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
+        """
+        super().__init__()
+        self.fs = FileSystem()
+        self.path = self.fs.init_path(path)
+        self.single_file_per_rank = single_file_per_rank
+        self.sync_files = sync_files
+        self.thread_count = thread_count
+        self.per_thread_copy_ahead = per_thread_copy_ahead
+
+    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
+        if checkpoint_id:
+            self.path = self.fs.init_path(checkpoint_id)
+
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
+        pass
+
+    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
+        self.fs.mkdir(self.path)
+        return plan
+
+    def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
+        new_plans = [
+            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
+            for i, plan in enumerate(global_plan)
+        ]
+        return new_plans
+
+    def write_data(
+        self,
+        plan: SavePlan,
+        planner: SavePlanner,
+    ) -> Future[List[WriteResult]]:
+        storage_plan: _StoragePrefix = plan.storage_data
+        file_count = 0
+
+        def gen_file():
+            nonlocal file_count
+            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
+            file_count += 1
+            return file_name
+
+        file_queue: queue.Queue = queue.Queue()
+        if self.single_file_per_rank:
+            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
+                file_name = gen_file()
+                path = self.fs.concat_path(self.path, file_name)
+                file_queue.put((path, file_name, bucket))
+        else:
+            for item in plan.items:
+                file_name = gen_file()
+                path = self.fs.concat_path(self.path, file_name)
+                file_queue.put((path, file_name, [item]))
+
+        result_queue: queue.Queue = queue.Queue()
+
+        threads = []
+        for _ in range(1, self.thread_count):
+            t = threading.Thread(
+                target=_write_files_from_queue,
+                args=(
+                    self.fs.create_stream,
+                    file_queue,
+                    result_queue,
+                    planner,
+                    self.per_thread_copy_ahead,
+                    self.sync_files,
+                    self.thread_count,
+                ),
+            )
+            t.start()
+            threads.append(t)
+
+        _write_files_from_queue(
+            create_stream=self.fs.create_stream,
+            file_queue=file_queue,
+            result_queue=result_queue,
+            planner=planner,
+            inflight_threshhold=self.per_thread_copy_ahead,
+            use_fsync=self.sync_files,
+            thread_count=self.thread_count,
+        )
+
+        for t in threads:
+            t.join()
+
+        res = []
+        try:
+            while True:
+                res += result_queue.get_nowait()
+        except queue.Empty:
+            pass
+
+        fut: Future[List[WriteResult]] = Future()
+        fut.set_result(res)
+        return fut
+
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+        storage_md = dict()
+        for wr_list in results:
+            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
+        metadata.storage_data = storage_md
+        tmp_path = cast(Path, self.fs.concat_path(self.path, ".metadata.tmp"))
+        meta_path = cast(Path, self.fs.concat_path(self.path, ".metadata"))
+        with self.fs.create_stream(tmp_path, "wb") as metadata_file:
+            pickle.dump(metadata, metadata_file)
+            if self.sync_files:
+                try:
+                    os.fsync(metadata_file.fileno())
+                except AttributeError:
+                    os.sync()
+
+        self.fs.rename(tmp_path, meta_path)
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        return FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+class FileSystemReader(StorageReader):
+    def __init__(self, path: Union[str, os.PathLike]) -> None:
+        super().__init__()
+        self.fs = FileSystem()
+        self.path = self.fs.init_path(path)
+        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
+
+    def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
+        return _create_file_view(file, sinfo.offset, sinfo.length)
+
+    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
+        self.storage_data = dict()
+        if checkpoint_id:
+            self.path = self.fs.init_path(checkpoint_id)
+
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+        # group requests by file
+        per_file: Dict[str, List[ReadItem]] = dict()
+        for read_item in plan.items:
+            item_md = self.storage_data[read_item.storage_index]
+            path = item_md.relative_path
+            per_file.setdefault(path, []).append(read_item)
+
+        for relative_path, reqs in per_file.items():
+            new_path = self.fs.concat_path(self.path, relative_path)
+            with self.fs.create_stream(new_path, "rb") as stream:
+                # TODO sort by offset and cache the reading
+                for req in reqs:
+                    item_md = self.storage_data[req.storage_index]
+                    file_slice = self._slice_file(stream, item_md)
+                    if req.type == LoadItemType.BYTE_IO:
+                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
+                        read_bytes.seek(0)
+                        planner.load_bytes(req, read_bytes)
+                    else:
+                        tensor = cast(
+                            Tensor,
+                            torch.load(cast(IO[bytes], file_slice), map_location="cpu"),
+                        )
+                        tensor = narrow_tensor_by_index(
+                            tensor, req.storage_offsets, req.lengths
+                        )
+                        target_tensor = planner.resolve_tensor(req).detach()
+
+                        assert (
+                            target_tensor.size() == tensor.size()
+                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+                        target_tensor.copy_(tensor)
+                        planner.commit_tensor(req, target_tensor)
+
+        fut: Future = Future()
+        fut.set_result(None)
+        return fut
+
+    # Implementing the abstract function in StorageReader
+    def read_metadata(self) -> Metadata:
+        path = self.fs.concat_path(self.path, ".metadata")
+        with self.fs.create_stream(path, "rb") as metadata_file:
+            return pickle.load(metadata_file)
+
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+        self.storage_data = metadata.storage_data
+        assert self.storage_data is not None
+
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        return plan
+
+    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+        return global_plan
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        return FileSystem.validate_checkpoint_id(checkpoint_id)
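+
+
+# Editor's sketch (not part of the upstream module): loading a checkpoint written by
+# FileSystemWriter back into a pre-allocated state_dict. Assumes the package-level
+# ``dcp.load`` entry point; the checkpoint directory below is hypothetical.
+#
+#     import torch
+#     import torch.distributed.checkpoint as dcp
+#
+#     state = {"weight": torch.empty(4)}
+#     dcp.load(state, storage_reader=FileSystemReader("/tmp/ckpt"))
+#     # state["weight"] is filled in place from the checkpoint shards.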
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/format_utils.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/format_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..754ae7b5cb73e841e07f036d47e4f5ecf04e9257
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/format_utils.py
@@ -0,0 +1,311 @@
+import argparse
+import os
+from enum import Enum
+from typing import cast, Dict, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torch.distributed.checkpoint._traverse import set_element
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import (
+    Metadata,
+    STATE_DICT_TYPE,
+    STORAGE_TYPES,
+    TensorProperties,
+    TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
+from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
+from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
+from torch.distributed.checkpoint.storage import StorageReader
+from torch.futures import Future
+
+
+__all__ = [
+    "dcp_to_torch_save",
+    "torch_save_to_dcp",
+    "BroadcastingTorchSaveReader",
+    "DynamicMetaLoadPlanner",
+]
+
+
+class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
+    """
+    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
+    Useful for loading in state_dict without first initializing a model, such as
+    when converting a DCP checkpoint into a Torch save file.
+
+    N.B. ``state_dict`` must be an empty dictionary when used with this LoadPlanner.
+
+    .. warning::
+        Because the entire state dict is initialized, it's recommended to only utilize
+        this LoadPlanner on a single rank or process to avoid OOM.
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        assert not state_dict
+
+        # rebuild the state dict from the metadata
+        for k, v in metadata.state_dict_metadata.items():
+            if isinstance(v, TensorStorageMetadata):
+                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
+            if k in metadata.planner_data:
+                set_element(state_dict, metadata.planner_data[k], v)
+            else:
+                state_dict[k] = v
+
+        super().set_up_planner(state_dict, metadata, is_coordinator)
+
+
+class BroadcastingTorchSaveReader(StorageReader):
+    """
+    StorageReader for reading a Torch Save file. This reader will read the entire checkpoint
+    on the coordinator rank, and then broadcast and shard each tensor to all ranks.
+
+    N.B. Intended to be used with DynamicMetaLoadPlanner.
+
+    .. warning::
+        Current implementation only supports loading Tensors.
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> sd = {"model": model}
+    >>> dcp.load(
+    >>>    sd,
+    >>>    storage_reader=BroadcastingTorchSaveReader(),
+    >>>    planner=DynamicMetaLoadPlanner(),
+    >>>    checkpoint_id="path_to_model.pt"
+    >>> )
+    """
+
+    def __init__(
+        self,
+        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
+        coordinator_rank: int = 0,
+    ) -> None:
+        self.checkpoint_id = checkpoint_id
+        self.coordinator_rank = coordinator_rank
+
+    def read_metadata(self) -> Metadata:
+        """Extends the default StorageReader to support building the metadata file"""
+        # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
+        # the disk
+        return Metadata(state_dict_metadata={})
+
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+        """
+        Reads torch save data on the coordinator rank and broadcasts it afterwards;
+        this incurs a communication cost, but avoids having to load
+        the entire checkpoint on each rank, hopefully preventing OOM issues.
+        """
+        planner = cast(DefaultLoadPlanner, planner)
+
+        # data is read in on the coordinator rank, and broadcast afterwards
+        # this incurs a communication cost, but it avoids having to load
+        # the entire checkpoint on each rank, hopefully preventing OOM issues
+        # TODO: read on each host, instead of only the coordinator
+        if self.is_coordinator:
+            assert self.checkpoint_id is not None
+            torch_state_dict = torch.load(self.checkpoint_id, map_location="cpu")
+            if planner.flatten_state_dict:
+                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
+        else:
+            torch_state_dict = None
+
+        for req in plan.items:
+            if req.type == LoadItemType.BYTE_IO:
+                raise RuntimeError(
+                    f"Non-tensor value identified at {req.storage_index.fqn}. "
+                    f"At this time {type(self).__name__} only supports loading Tensors."
+                )
+
+            #  Broadcast the tensor from the coordinator rank
+            if self.is_coordinator:
+                tensor = torch_state_dict[req.storage_index.fqn].cuda()
+            else:
+                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
+
+            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)
+
+            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
+            target_tensor = planner.resolve_tensor(req).detach()
+            assert target_tensor.size() == tensor.size(), (
+                f"req {req.storage_index} mismatch sizes, "
+                f"{target_tensor.size()} vs {tensor.size()}"
+            )
+            target_tensor.copy_(tensor)
+            planner.commit_tensor(req, target_tensor)
+
+        fut: Future = Future()
+        fut.set_result(None)
+        return fut
+
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+        """Implementation of the StorageReader method"""
+        self.is_coordinator = is_coordinator
+        if self.is_coordinator:
+            assert dist.get_rank() == self.coordinator_rank
+
+        assert self.checkpoint_id is not None
+
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        """Implementation of the StorageReader method"""
+        return plan
+
+    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+        """Implementation of the StorageReader method"""
+        return global_plan
+
+    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
+        """Implementation of the StorageReader method"""
+        self.checkpoint_id = checkpoint_id
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        """Implementation of the StorageReader method"""
+        return os.path.isfile(checkpoint_id)
+
+
+class DynamicMetaLoadPlanner(DefaultLoadPlanner):
+    """
+    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the passed in state dict,
+    avoiding the need to read metadata from disk. This is useful when reading formats which don't have a
+    metadata file, like Torch Save files.
+
+    N.B. Intended to be used with BroadcastingTorchSaveReader.
+
+    .. warning::
+        Current implementation only supports loading Tensors.
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> sd = {"model": model}
+    >>> dcp.load(
+    >>>    sd,
+    >>>    storage_reader=BroadcastingTorchSaveReader(),
+    >>>    planner=DynamicMetaLoadPlanner(),
+    >>>    checkpoint_id="path_to_model.pt"
+    >>> )
+    """
+
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """Setups of the planner, extnding default behavior by creating the Metadata object from the state dict"""
+        super().set_up_planner(state_dict, metadata, is_coordinator)
+
+        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
+        for key, tensor in self.state_dict.items():
+            if not torch.is_tensor(tensor):
+                raise RuntimeError(
+                    f"Non-tensor value identified at {key}. "
+                    f"At this time {type(self).__name__} only supports loading Tensors."
+                )
+
+            state_dict_metadata[key] = TensorStorageMetadata(
+                TensorProperties(dtype=tensor.dtype),
+                tensor.size(),
+                _create_chunk_list(tensor),
+            )
+        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
+
+
+def dcp_to_torch_save(
+    dcp_checkpoint_dir: Union[str, os.PathLike],
+    torch_save_path: Union[str, os.PathLike],
+):
+    """
+    Given a directory containing a DCP checkpoint, this function will convert it into a
+    Torch save file.
+
+    Args:
+        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
+        torch_save_path: Filename to store the converted Torch save file.
+
+    .. warning::
+        To avoid OOM, it's recommended to only run this function on a single rank.
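+
+    A minimal illustrative sketch (the paths below are hypothetical):
+
+    >>> # xdoctest: +SKIP("requires an existing DCP checkpoint")
+    >>> dcp_to_torch_save("dcp_checkpoint_dir/", "model.pt")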
+    """
+    sd: STATE_DICT_TYPE = {}
+    _load_state_dict(
+        sd,
+        storage_reader=FileSystemReader(dcp_checkpoint_dir),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
+    torch.save(sd, torch_save_path)
+
+
+def torch_save_to_dcp(
+    torch_save_path: Union[str, os.PathLike],
+    dcp_checkpoint_dir: Union[str, os.PathLike],
+):
+    """
+    Given the location of a torch save file, converts it into a DCP checkpoint.
+
+    Args:
+        torch_save_path: Filename of the Torch save file to convert.
+        dcp_checkpoint_dir: Directory to store the converted DCP checkpoint.
+
+    .. warning::
+        To avoid OOM, it's recommended to only run this function on a single rank.
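+
+    A minimal illustrative sketch (the paths below are hypothetical):
+
+    >>> # xdoctest: +SKIP("requires an existing torch.save checkpoint")
+    >>> torch_save_to_dcp("model.pt", "dcp_checkpoint_dir/")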
+    """
+
+    state_dict = torch.load(torch_save_path)
+    # We don't need stateful behavior here because the expectation is that anything
+    # loaded by torch.load will not contain stateful objects.
+    _save_state_dict(
+        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
+    )
+
+
+if __name__ == "__main__":
+
+    class FormatMode(Enum):
+        TORCH_TO_DCP = "torch_to_dcp"
+        DCP_TO_TORCH = "dcp_to_torch"
+
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "mode",
+        type=str,
+        help="Conversion mode",
+        choices=[m.value for m in FormatMode],
+        default=FormatMode.TORCH_TO_DCP.value,
+    )
+    parser.add_argument("src", type=str, help="Path to the source model")
+    parser.add_argument("dst", type=str, help="Path to the destination model")
+    args = parser.parse_args()
+
+    print(
+        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
+    )
+    checkpoint_missing_warning = (
+        f"No checkpoint found at {args.src}. Skipping conversion."
+    )
+    if args.mode == FormatMode.TORCH_TO_DCP.value:
+        if os.path.isfile(args.src):
+            torch_save_to_dcp(args.src, args.dst)
+        else:
+            print(checkpoint_missing_warning)
+    elif args.mode == FormatMode.DCP_TO_TORCH.value:
+        if os.path.isdir(args.src):
+            dcp_to_torch_save(args.src, args.dst)
+        else:
+            print(checkpoint_missing_warning)
+    else:
+        raise ValueError(f"Unknown conversion mode: {args.mode}")
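+
+    # Example invocation (hypothetical paths; the exact module path of this script
+    # under torch.distributed.checkpoint is assumed):
+    #   python -m torch.distributed.checkpoint.format_utils torch_to_dcp model.pt dcp_dir/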
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/fsspec.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/fsspec.py
new file mode 100644
index 0000000000000000000000000000000000000000..f471fba89ada2a5aadf1530a3b6aed40bd68ac44
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/fsspec.py
@@ -0,0 +1,122 @@
+# Mypy will not try inferring the types of any 3rd party libraries installed.
+# mypy: ignore-errors
+
+import io
+import os
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator, Optional, Union
+
+import fsspec
+from fsspec import AbstractFileSystem
+from fsspec.core import url_to_fs
+
+from torch.distributed.checkpoint.filesystem import (
+    FileSystemBase,
+    FileSystemReader,
+    FileSystemWriter,
+)
+
+__all__ = [
+    "FsspecWriter",
+    "FsspecReader",
+]
+
+
+class FileSystem(FileSystemBase):
+    def __init__(self) -> None:
+        self.fs: Optional[AbstractFileSystem] = None
+
+    @contextmanager
+    def create_stream(
+        self, path: Union[str, os.PathLike], mode: str
+    ) -> Generator[io.IOBase, None, None]:
+        assert self.fs is not None
+        with self.fs.transaction:
+            with fsspec.open(str(path), mode) as stream:
+                yield stream
+
+    def concat_path(
+        self, path: Union[str, os.PathLike], suffix: str
+    ) -> Union[str, os.PathLike]:
+        return os.path.join(path, suffix)
+
+    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
+        self.fs, _ = url_to_fs(path)
+        return path
+
+    def rename(
+        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
+    ) -> None:
+        self.fs.rename(path, new_path)
+
+    def mkdir(self, path: Union[str, os.PathLike]) -> None:
+        self.fs.makedirs(path, exist_ok=True)
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        if isinstance(checkpoint_id, Path):
+            return False
+
+        try:
+            url_to_fs(checkpoint_id)
+        except ValueError:
+            return False
+
+        return True
+
+
+class FsspecWriter(FileSystemWriter):
+    """
+    Basic implementation of StorageWriter using fsspec.
+
+    This implementation makes the following assumptions and simplifications:
+
+    * The checkpoint path is an empty or non-existing directory.
+    * File creation is atomic.
+
+    The checkpoint consists of one file per write request plus
+    a `.metadata` file with the serialized metadata.
+
+    """
+
+    def __init__(
+        self,
+        path: Union[str, os.PathLike],
+        single_file_per_rank: bool = True,
+        sync_files: bool = True,
+        thread_count: int = 1,
+        per_thread_copy_ahead: int = 10_000_000,
+    ) -> None:
+        """
+        Initialize the writer pointing to `path`.
+
+        Args:
+            path: directory where the checkpoint will be written to.
+            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
+            sync_files: force files to be synced to permanent storage. Defaults to True.
+            thread_count: Number of IO threads to use to write. Defaults to 1.
+            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
+
+        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
+        """
+        super().__init__(
+            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
+        )
+        self.fs = FileSystem()
+        self.path = self.fs.init_path(path)
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        return FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+class FsspecReader(FileSystemReader):
+    def __init__(self, path: Union[str, os.PathLike]) -> None:
+        super().__init__(path)
+        self.fs = FileSystem()
+        self.path = self.fs.init_path(path)
+
+    @classmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        return FileSystem.validate_checkpoint_id(checkpoint_id)
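+
+
+# A hedged usage sketch: the fsspec URL and ``state_dict`` below are hypothetical;
+# the entry points mirror the torch.distributed.checkpoint save/load API.
+#
+#   import torch.distributed.checkpoint as dcp
+#   dcp.save(state_dict, storage_writer=FsspecWriter("s3://bucket/checkpoint"))
+#   dcp.load(state_dict, storage_reader=FsspecReader("s3://bucket/checkpoint"))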
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/metadata.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..2da2237807a90fcdab6dec485b3a9382b8707236
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/metadata.py
@@ -0,0 +1,170 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+import torch
+from torch.distributed.checkpoint.stateful import StatefulT
+
+__all__ = [
+    "ChunkStorageMetadata",
+    "TensorStorageMetadata",
+    "BytesStorageMetadata",
+    "Metadata",
+    "MetadataIndex",
+    "TensorProperties",
+]
+
+
+@dataclass
+class ChunkStorageMetadata:
+    """
+    Each chunk is expected to have the same properties as the TensorStorageMetadata
+    that includes it.
+    """
+
+    offsets: torch.Size
+    sizes: torch.Size
+
+
+class _MEM_FORMAT_ENCODING(Enum):
+    """Describe the memory format of a tensor."""
+
+    TORCH_CONTIGUOUS_FORMAT = 0
+    TORCH_CHANNELS_LAST = 1
+    TORCH_PRESERVE_FORMAT = 2
+
+
+@dataclass
+class TensorProperties:
+    """Properties used to create :class:`Tensor`"""
+
+    # Regular tensor fields
+    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
+    # This field is deprecated.
+    layout: torch.layout = field(default=torch.strided)
+    # This field is deprecated.
+    requires_grad: bool = False
+    # This field is deprecated.
+    memory_format: torch.memory_format = field(default=torch.contiguous_format)
+    # This field is deprecated.
+    pin_memory: bool = False
+
+    def __getstate__(self):
+        # torch.memory_format cannot be pickled, so encode it as an enum.
+        memory_format = self.memory_format
+        if memory_format == torch.contiguous_format:
+            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
+        elif memory_format == torch.channels_last:
+            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
+        elif memory_format == torch.preserve_format:
+            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
+        else:
+            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")
+
+        return (
+            self.dtype,
+            self.layout,
+            self.requires_grad,
+            mem_format_encoding,
+            self.pin_memory,
+        )
+
+    def __setstate__(
+        self,
+        state,
+    ):
+        (
+            self.dtype,
+            self.layout,
+            self.requires_grad,
+            mem_format_encoding,
+            self.pin_memory,
+        ) = state
+
+        if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
+            memory_format = torch.contiguous_format
+        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
+            memory_format = torch.channels_last
+        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
+            memory_format = torch.preserve_format
+        else:
+            raise RuntimeError(
+                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
+            )
+
+        self.memory_format = memory_format
+
+    @staticmethod
+    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
+        return TensorProperties(
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+            memory_format=torch.contiguous_format,
+            pin_memory=tensor.is_pinned(),
+        )
+
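+# A hedged sketch of why __getstate__/__setstate__ exist above: torch.memory_format
+# is not picklable, so it is round-tripped through _MEM_FORMAT_ENCODING.
+#
+#   import pickle
+#   props = TensorProperties(dtype=torch.float16, memory_format=torch.channels_last)
+#   restored = pickle.loads(pickle.dumps(props))
+#   assert restored.memory_format == torch.channels_last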
+
+@dataclass
+class TensorStorageMetadata:
+    properties: TensorProperties
+    size: torch.Size
+    chunks: List[ChunkStorageMetadata]
+
+
+@dataclass
+class BytesStorageMetadata:
+    pass
+
+
+STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
+STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
+
+
+@dataclass
+class Metadata:
+    """This class represents the metadata of the checkpoint."""
+
+    # Keys are the same from the `state_dict` used.
+    state_dict_metadata: Dict[str, STORAGE_TYPES]
+    # It is the responsibility of the planner and storage plugins to ensure
+    # backward compatibility of the planner_data and storage_data. DCP will
+    # also ensure the backward compatibility of the metadata in this file and
+    # the metadata of the built-in planner and storage plugins.
+    planner_data: Any = None
+    storage_data: Any = None
+
+
+@dataclass(frozen=True)
+class MetadataIndex:
+    """This class represents a lookup key for items in a state dict or Metadata."""
+
+    fqn: str
+    """Fully Qualified Name of the object"""
+
+    offset: Optional[torch.Size] = None
+    """If the object is a tensor, offset into the tensor we're looking for"""
+
+    index: Optional[int] = field(hash=False, compare=False, default=None)
+    """
+    Index hint when searching for tensor chunk to speed up lookups (optional)
+
+    A common representation of a sharded tensor is a list of chunks, so finding
+    the index in such a list requires a linear search.
+
+    When constructing an instance of MetadataIndex that points to that list,
+    one can provide the index as a hint; it will be probed before falling back
+    to the linear search, making lookups significantly faster.
+    """
+
+    def __init__(
+        self,
+        fqn: str,
+        offset: Optional[Sequence[int]] = None,
+        index: Optional[int] = None,
+    ):
+        # We must use object.__setattr__ due to frozen=True
+        object.__setattr__(self, "fqn", fqn)
+        object.__setattr__(self, "index", index)
+        if offset is not None:
+            object.__setattr__(self, "offset", torch.Size(offset))
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/optimizer.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..92b969a8266dec88cc79d70e12d874cc80cef871
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/optimizer.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import dataclasses
+from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch._utils import _get_device_module
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed._shard.sharded_tensor.metadata import (
+    TensorProperties as ShardTensorProperties,
+)
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from torch.distributed.checkpoint.metadata import (
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    Metadata,
+    MetadataIndex,
+    STATE_DICT_TYPE,
+    TensorProperties,
+    TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
+from torch.distributed.checkpoint.planner_helpers import (
+    _create_read_items,
+    create_read_items_for_chunk_list,
+)
+from torch.distributed.checkpoint.state_dict_loader import load_state_dict
+from torch.distributed.checkpoint.storage import StorageReader
+from torch.distributed.checkpoint.utils import (
+    _element_wise_add,
+    _element_wise_sub,
+    _normalize_device_info,
+)
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.remote_device import _remote_device
+
+STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
+
+
+# TODO: Update docstrings for optimizer.py
+__all__ = [
+    "load_sharded_optimizer_state_dict",
+]
+
+
+def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
+    if device_type == "cpu":
+        return "cpu"
+    device_module = _get_device_module(device_type)
+    if device_module.is_available():
+        return _normalize_device_info(
+            device_type, global_rank % device_module.device_count()
+        )
+    return "cpu"
+
+
+def _create_colwise_spec(
+    pg: Optional[dist.ProcessGroup] = None,
+) -> ChunkShardingSpec:
+    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
+    if pg is None:
+        placements = [
+            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
+            for idx in range(dist.get_world_size())
+        ]
+    else:
+        placements = [
+            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
+            for idx in range(pg.size())
+        ]
+    return ChunkShardingSpec(
+        dim=0,
+        placements=cast(List[Union[_remote_device, str]], placements),
+    )
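+
+
+# Illustrative result (assuming two ranks on a single host with two CUDA devices):
+# _create_colwise_spec(None) yields a dim-0 ChunkShardingSpec with placements like
+# ["rank:0/cuda:0", "rank:1/cuda:1"].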
+
+
+def _is_nested_tensor(val: torch.Tensor) -> bool:
+    if type(val) is ShardedTensor:
+        if len(val.local_shards()) == 0:
+            return False
+        if type(val.local_shards()[0].tensor) is ShardedTensor:
+            return True
+        if type(val.local_shards()[0].tensor) is DTensor:
+            raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
+    elif type(val) is DTensor and (
+        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
+    ):
+        raise ValueError("Cannot handle nested DTensor")
+    return False
+
+
+def _alloc_tensor(
+    props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
+) -> torch.Tensor:
+    return torch.empty(
+        size=size,
+        dtype=props.dtype,
+        layout=props.layout,
+        requires_grad=props.requires_grad,
+        pin_memory=props.pin_memory,
+        device=cast(torch.device, _get_device_module(device_type).current_device()),
+    )
+
+
+def _get_state_dict_2d_layout(
+    state_dict: STATE_DICT_TYPE,
+) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
+    """
+    Load the right TP slice of the optimizer state.
+
+    This is not easy since the per-tensor slicing can't be inferred from checkpoint metadata.
+    We take advantage of the model state_dict producing a sliced ST to figure out what we need to load.
+    This is pretty fragile and it might be easier for FSDP to compute this info for us.
+    Returns a dictionary where the keys are the same as those of the state_dict and the
+    value is a tuple of (offset, size) for the current rank's TP slice.
+    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
+    """
+    specs: STATE_DICT_2D_LAYOUT = {}
+    dp_pg: Optional[dist.ProcessGroup] = None
+    for key, value in state_dict.items():
+        specs[key] = (None, value.size())
+        if _is_nested_tensor(value):
+            assert (
+                len(value.local_shards()) == 1
+            ), "Cannot handle ST with multiple shards"
+            assert isinstance(
+                value, ShardedTensor
+            ), "Can only handle nested ShardedTensor"
+            shard = value.local_shards()[0]
+            specs[key] = (
+                shard.metadata.shard_offsets,
+                shard.metadata.shard_sizes,
+            )
+            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]
+
+    return (
+        specs,
+        dp_pg,
+    )
+
+
+class _ReaderWithOffset(DefaultLoadPlanner):
+    translation: Dict[MetadataIndex, MetadataIndex]
+    state_dict: STATE_DICT_TYPE
+    metadata: Metadata
+
+    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
+        super().__init__()
+        self.fqn_to_offset = fqn_to_offset
+        self.metadata = Metadata({})
+        self.state_dict = {}
+        self.translation = {}
+
+    def create_local_plan(self) -> LoadPlan:
+        requests = []
+        self.translation = {}
+        for fqn, obj in self.state_dict.items():
+            md = self.metadata.state_dict_metadata[fqn]
+            if not isinstance(obj, ShardedTensor):
+                requests += _create_read_items(fqn, md, obj)
+                continue
+
+            if fqn not in self.fqn_to_offset:
+                requests += _create_read_items(fqn, md, obj)
+                continue
+
+            offset = self.fqn_to_offset[fqn]
+
+            assert len(obj.local_shards()) == 1
+            original_shard = obj.local_shards()[0]
+            local_chunks = [
+                ChunkStorageMetadata(
+                    offsets=torch.Size(
+                        _element_wise_add(original_shard.metadata.shard_offsets, offset)
+                    ),
+                    sizes=torch.Size(original_shard.metadata.shard_sizes),
+                )
+            ]
+
+            reqs = create_read_items_for_chunk_list(
+                fqn, cast(TensorStorageMetadata, md), local_chunks
+            )
+            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
+            # TODO: we should change _create_sharded_read_items to have more ergonomic API
+            for ri in reqs:
+                assert ri.dest_index.offset is not None
+                original_offset = _element_wise_sub(ri.dest_index.offset, offset)
+                original_index = dataclasses.replace(
+                    ri.dest_index, offset=torch.Size(original_offset)
+                )
+                self.translation[ri.dest_index] = original_index
+
+            requests += reqs
+        return LoadPlan(requests)
+
+    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
+        return super().lookup_tensor(self.translation.get(index, index))
+
+
+def load_sharded_optimizer_state_dict(
+    model_state_dict: STATE_DICT_TYPE,
+    optimizer_key: str,
+    storage_reader: StorageReader,
+    planner: Optional[LoadPlanner] = None,
+) -> STATE_DICT_TYPE:
+    """
+    Load a state_dict in conjunction with FSDP sharded optimizer state.
+
+    This is the current recommended way to checkpoint FSDP.
+    >>> # xdoctest: +SKIP
+    >>> import torch.distributed.checkpoint as dist_cp
+    >>> # Save
+    >>> model: torch.nn.Module
+    >>> optim_params = model.parameters()
+    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
+    >>> # Save
+    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+    >>>     state_dict = {
+    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
+    >>>         "model": model.state_dict()
+    >>>     }
+    >>>     dist_cp.save_state_dict(
+    >>>         state_dict=state_dict,
+    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
+    >>>         planner=dist_cp.DefaultSavePlanner(),
+    >>>     )
+    >>>
+    >>> # Load
+    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
+    >>>     model_state_dict = model_tp.state_dict()
+    >>>     checkpoint = {
+    >>>         "model": model_state_dict
+    >>>     }
+    >>>     dist_cp.load_state_dict(
+    >>>         state_dict=checkpoint,
+    >>>         storage_reader=dist_cp.FileSystemReader(checkpoint_file),
+    >>>         planner=dist_cp.DefaultLoadPlanner(),
+    >>>     )
+    >>>     model_tp.load_state_dict(checkpoint["model"])
+    >>>
+    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
+    >>>         model_state_dict,
+    >>>         optimizer_key="optimizer",
+    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
+    >>>     )
+    >>>
+    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
+    >>>        model, optim, optim_state["optimizer"]
+    >>>     )
+    >>>
+    >>>     optim.load_state_dict(flattened_osd)
+    """
+    metadata = storage_reader.read_metadata()
+
+    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
+    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
+    device_module = _get_device_module(dp_pg_device_type)
+
+    if dp_pg is None:
+        placements = []
+        for i in range(dist.get_world_size()):
+            device_info = _normalize_device_info(
+                dp_pg_device_type, i % device_module.device_count()
+            )
+            placements.append(f"rank:{i}/{device_info}")
+        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
+    else:
+        sharding_spec = _create_colwise_spec(dp_pg)
+
+    # Create a state_dict for optimizer state
+    state_dict: STATE_DICT_TYPE = {}
+
+    fqn_to_offset: Dict[str, Sequence[int]] = {}
+    for key, value in metadata.state_dict_metadata.items():
+        key_path = metadata.planner_data[key]
+        if key_path[0] != optimizer_key:
+            continue
+
+        if isinstance(value, BytesStorageMetadata):
+            state_dict[key] = ""
+            continue
+
+        # value: TensorStorageMetadata
+        if value.size.numel() == 1:
+            state_dict[key] = _alloc_tensor(
+                value.properties, value.size, dp_pg_device_type
+            )
+        elif dp_pg is None:
+            state_dict[key] = _create_chunk_sharded_tensor(
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
+                rank=dist.get_rank(),
+                world_size=dist.get_world_size(),
+                num_devices_per_node=device_module.device_count(),
+                pg=_get_default_group(),
+            )
+        else:
+            spec_key = key_path[2]
+            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]
+
+            properties = ShardTensorProperties(
+                dtype=value.properties.dtype,
+                layout=value.properties.layout,
+                requires_grad=value.properties.requires_grad,
+                memory_format=value.properties.memory_format,
+                pin_memory=value.properties.pin_memory,
+            )
+
+            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
+            local_shards = []
+            current_rank = dist.get_rank(dp_pg)
+            for shard_md in st_md.shards_metadata:
+                if cast(_remote_device, shard_md.placement).rank() != current_rank:
+                    continue
+                local_shards.append(
+                    Shard(
+                        tensor=_alloc_tensor(
+                            value.properties, shard_md.shard_sizes, dp_pg_device_type
+                        ),
+                        metadata=shard_md,
+                    )
+                )
+
+            st = ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards, st_md, process_group=dp_pg
+            )
+
+            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
+                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
+
+            state_dict[key] = st
+
+    # Whether we unflatten before or after doesn't matter
+    load_state_dict(
+        state_dict=state_dict,
+        storage_reader=storage_reader,
+        # FIXME the type of planner is wrong in load_state_dict
+        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
+    )
+
+    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)
+
+    return state_dict
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner.py
new file mode 100644
index 0000000000000000000000000000000000000000..8992e3915b96cae3e9d3f65fafa2561c9ee521e6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner.py
@@ -0,0 +1,403 @@
+import abc
+import io
+from dataclasses import dataclass
+from enum import auto, Enum
+from functools import reduce
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+
+from .metadata import (
+    ChunkStorageMetadata,
+    Metadata,
+    MetadataIndex,
+    STATE_DICT_TYPE,
+    TensorProperties,
+)
+
+
+__all__ = [
+    "WriteItemType",
+    "LoadItemType",
+    "TensorWriteData",
+    "WriteItem",
+    "ReadItem",
+    "SavePlan",
+    "LoadPlan",
+    "SavePlanner",
+    "LoadPlanner",
+]
+
+
+class WriteItemType(Enum):
+    TENSOR = auto()
+    SHARD = auto()
+    BYTE_IO = auto()
+
+
+class LoadItemType(Enum):
+    TENSOR = auto()
+    BYTE_IO = auto()
+
+
+@dataclass(frozen=True)
+class TensorWriteData:
+    chunk: ChunkStorageMetadata
+    properties: TensorProperties
+    size: torch.Size
+
+
+@dataclass(frozen=True)
+class WriteItem:
+    """Dataclass which holds information about what needs to be written to storage."""
+
+    index: MetadataIndex
+    type: WriteItemType
+
+    # Value present if it's a tensor write
+    tensor_data: Optional[TensorWriteData] = None
+
+    def tensor_storage_size(self) -> Optional[int]:
+        """
+        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.
+
+        Returns:
+            Optional[int]: storage size in bytes of the underlying tensor, if any.
+        """
+        if self.tensor_data is None:
+            return None
+
+        numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1)
+        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
+        return numels * dtype_size
+
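+# Worked example (hypothetical numbers): a WriteItem for a float32 tensor of size
+# (1024, 1024) reports tensor_storage_size() == 1024 * 1024 * 4 == 4_194_304 bytes.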
+
+@dataclass(frozen=True)
+class ReadItem:
+    # Read Item
+    type: LoadItemType
+
+    # Index into the state_dict
+    dest_index: MetadataIndex
+    # Offsets into destination tensor
+    dest_offsets: torch.Size
+
+    # Index into the checkpoint
+    storage_index: MetadataIndex
+    # Offset into the checkpoint data
+    storage_offsets: torch.Size
+
+    # Size of the hypercube to copy
+    lengths: torch.Size
+
+
+@dataclass(frozen=True)
+class SavePlan:
+    items: List[WriteItem]
+    storage_data: Any = None
+    planner_data: Any = None
+
+
+@dataclass
+class LoadPlan:
+    items: List[ReadItem]
+    storage_data: Any = None
+    planner_data: Any = None
+
+
+class SavePlanner(abc.ABC):
+    """
+    Abstract class defining the protocol used by save_state_dict to plan the save process.
+
+    SavePlanners are stateful objects that can be used to customize the whole save process.
+
+    SavePlanner acts as an access proxy to the state_dict, so any transformation done to it
+    will be visible to the whole process.
+
+    A planner subclass can expect the following sequence of calls during save_state_dict:
+
+    1) set_up_planner - called on all ranks.
+        Signals the start of a checkpoint save.
+
+    2) create_local_plan - called on all ranks.
+        Process the state_dict and produces a `SavePlan` that will be sent for global planning.
+
+    3) create_global_plan - called on the coordinator rank only.
+        Takes the SavePlan from all ranks and makes any global decisions.
+
+    4) finish_plan - called on all ranks.
+        This gives each rank a chance to adjust to global planning decisions.
+
+    5) resolve_data - called multiple times on each rank
+        Looks up a value in the `state_dict` for the storage layer to write.
+
+    Users are recommended to extend DefaultSavePlanner instead of this interface directly as
+    most changes can be expressed by changes in a single method.
+
+    There are 3 usual patterns of extension:
+
+    Rewriting state_dict. This is the simplest way to extend the save process as it
+    doesn't require understanding the intricacies of how SavePlan works:
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> class RenamePlanner(DefaultSavePlanner):
+    >>>     def set_up_planner(self, state_dict, is_coordinator):
+    >>>         # prefix all keys with ``foo_``
+    >>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
+
+    Modifying the local plan and lookup in tandem. This is useful when fine control over how data is persisted is needed:
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> class FP16Planner(DefaultSavePlanner):
+    >>>     def create_local_plan(self):
+    >>>         plan = super().create_local_plan()
+    >>>         for p in plan:
+    >>>             if p.tensor_data is not None:
+    >>>                 p.tensor_data.properties.dtype = torch.float16
+    >>>         return plan
+    >>>
+    >>>     def resolve_data(self, write_item):
+    >>>         item = super().resolve_data(write_item)
+    >>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
+
+    Using the global planning step to make central decisions that can't be made individually by each rank
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> from itertools import islice
+    >>> from dataclasses import replace
+    >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
+    >>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
+    >>>     # This sample doesn't handle ShardedTensors
+    >>>     def create_global_plan(self, all_plans):
+    >>>         def chunk(it, size):
+    >>>             it = iter(it)
+    >>>             return list(iter(lambda: tuple(islice(it, size)), ()))
+    >>>         all_plans = [
+    >>>             replace(plan, items=items) for plan, items in
+    >>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
+    >>>         ]
+    >>>         return super().create_global_plan(all_plans)
+
+    Finally, some planners need to save additional metadata in the checkpoint, this is
+    accomplished by having each rank contribute their data items in the local plan and
+    the global planner aggregate them:
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> class SaveExtraDataPlanner(DefaultSavePlanner):
+    >>>     def create_local_plan(self) -> SavePlan:
+    >>>         plan = super().create_local_plan()
+    >>>         return replace(plan, planner_data="per-rank-data")
+    >>>
+    >>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
+    >>>         global_plan, metadata = super().create_global_plan(all_plans)
+    >>>         merged_data = [p.planner_data for p in global_plan]
+    >>>         metadata = replace(metadata, planner_data=merged_data)
+    >>>         return global_plan, metadata
+    """
+
+    @abc.abstractmethod
+    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
+        """
+        Initialize this planner to save ``state_dict``.
+
+        Implementations should save those values as they won't be provided later in the save process.
+
+        This is called on all ranks.
+        """
+        pass
+
+    @abc.abstractmethod
+    def create_local_plan(self) -> SavePlan:
+        """
+        Compute the save plan for the current rank.
+
+        This will be aggregated and passed to create_global_plan.
+        Planner specific data can be passed through SavePlan::planner_data.
+
+        This is called on all ranks.
+        """
+        pass
+
+    @abc.abstractmethod
+    def create_global_plan(
+        self, all_plans: List[SavePlan]
+    ) -> Tuple[List[SavePlan], Metadata]:
+        """
+        Compute the global checkpoint plan and return the local plan of each rank.
+
+        This is called on the coordinator rank only.
+        """
+        pass
+
+    @abc.abstractmethod
+    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
+        """
+        Merge the plan created by `create_local_plan` and the result of `create_global_plan`.
+
+        This is called on all ranks.
+        """
+        pass
+
+    @abc.abstractmethod
+    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
+        """
+        Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
+
+        Lookup the object associated with ``write_item`` in ``state_dict`` and apply any
+        transformation (such as serialization) prior to the storage layer consuming it.
+
+        Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
+
+        This method should be idempotent and thread-safe. StorageWriter implementations
+        are free to call it as frequently as they need.
+
+        Any transformation that allocates memory should be lazily done when this method
+        is called in order to reduce the peak memory required by checkpointing.
+
+        Returned tensors can be on any device and in any format, and they can be views.
+        It is the storage layer's responsibility to figure out how to save them.
+        """
+        pass
+
+
+class LoadPlanner:
+    """
+    Abstract class defining the protocol used by load_state_dict to plan the load process.
+
+    LoadPlanners are stateful objects that can be used to customize the whole load process.
+
+    LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it
+    will be visible to the whole process.
+
+    A planner subclass can expect the following sequence of calls during load_state_dict:
+
+    1) set_up_planner - called on all ranks.
+        Signals the start of loading a checkpoint.
+
+    2) create_local_plan - called on all ranks.
+        Process the state_dict and produces a `LoadPlan` that will be sent for global planning.
+
+    3) create_global_plan - called on the coordinator rank only.
+        Takes the LoadPlan from all ranks and makes any global decisions.
+
+    4) load_bytes - called multiple times on each rank
+        This is called once per non-tensor value in state_dict.
+
+    5) resolve_tensor and commit_tensor - called multiple times on each rank
+        They are called in pair for each Tensor value in state_dict.
+
+    Users are recommended to extend DefaultLoadPlanner instead of this interface directly as
+    most changes can be expressed by changes in a single method.
+
+    There are two usual patterns of extension:
+
+    Rewriting state_dict. This is the simplest way to extend the load process as it
+    doesn't require understanding the intricacies of how LoadPlan works. Loading
+    happens in place, so we need to keep a reference to the original state_dict
+    in order to write back into it:
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> class RenamePlanner(DefaultLoadPlanner):
+    >>>     def set_up_planner(self, state_dict, metadata, is_coordinator):
+    >>>         self.original_state_dict = state_dict
+    >>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
+    >>>
+    >>>         if self.flatten_sharded_tensors:
+    >>>             state_dict = _flatten_sharded_tensors(state_dict)
+    >>>
+    >>>         if self.flatten_state_dict:
+    >>>             state_dict, self.mappings = flatten_state_dict(state_dict)
+    >>>
+    >>>         self.state_dict = state_dict
+    >>>         self.metadata = metadata
+    >>>         self.is_coordinator = is_coordinator
+    >>>
+    >>>     def load_bytes(self, read_item, value):
+    >>>         # Remove the "foo_" prefix
+    >>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
+
+
+    Modifying resolve_tensor and commit_tensor to handle load time transformation.
+
+    >>> # xdoctest: +SKIP("undefined vars")
+    >>> class MetaModelMaterialize(DefaultLoadPlanner):
+    >>>     def resolve_tensor(self, read_item):
+    >>>         tensor = super().resolve_tensor(read_item)
+    >>>         return torch.empty_like(tensor, device="cpu")
+    >>>
+    >>>     def commit_tensor(self, read_item, tensor):
+    >>>         self.state_dict[read_item.dest_index.fqn] = tensor
+    """
+
+    @abc.abstractmethod
+    def set_up_planner(
+        self,
+        state_dict: STATE_DICT_TYPE,
+        metadata: Metadata,
+        is_coordinator: bool,
+    ) -> None:
+        """
+        Initialize this instance to load data into ``state_dict``.
+
+        N.B. This is called on every rank.
+        """
+        pass
+
+    @abc.abstractmethod
+    def create_local_plan(self) -> LoadPlan:
+        """
+        Create a LoadPlan based on state_dict and metadata provided by set_up_planner.
+
+        N.B. This is called on every rank.
+        """
+        pass
+
+    @abc.abstractmethod
+    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
+        """
+        Compute the global load plan and return plans for each rank.
+
+        N.B. This is called on the coordinator rank only.
+        """
+        pass
+
+    @abc.abstractmethod
+    def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
+        """Accept the plan from coordinator and return final LoadPlan."""
+        pass
+
+    @abc.abstractmethod
+    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
+        """
+        Load the item described by ``read_item`` and ``value``.
+
+        This method is expected to modify in-place the underlying state_dict.
+
+        The contents of ``value`` are defined by the SavePlanner used to produce
+        the checkpoint being loaded.
+        """
+        pass
+
+    @abc.abstractmethod
+    def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
+        """
+        Return the tensor described by ``read_item`` to be used by the StorageReader to load ``read_item``.
+
+        The tensor should alias with one in the underlying state_dict as the StorageReader will replace its contents.
+        If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
+        back to the one in state_dict.
+        """
+        pass
+
+    @abc.abstractmethod
+    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
+        """
+        Called once the StorageReader has finished loading data into ``tensor``.
+
+        The provided tensor is the same one returned by the call to ``resolve_tensor``.
+        This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to
+        copying it back to the one in the state_dict.
+
+        The contents of tensor will follow its device synchronization model.
+        """
+        pass
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..99ee60d95830fb0a76e50199dff7917213653107
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py
@@ -0,0 +1,325 @@
+from typing import Any, cast, List
+
+import torch
+import torch.distributed as dist
+from torch._utils import _get_device_module
+
+from torch.distributed._shard.metadata import ShardMetadata
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
+
+from torch.utils._pytree import tree_map_only
+
+from .metadata import (
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    MetadataIndex,
+    STATE_DICT_TYPE,
+    STORAGE_TYPES,
+    TensorProperties,
+    TensorStorageMetadata,
+)
+from .planner import (
+    LoadItemType,
+    ReadItem,
+    SavePlan,
+    TensorWriteData,
+    WriteItem,
+    WriteItemType,
+)
+from .resharding import (
+    _check_shard_metadata_pair_overlap,
+    _shards_get_overlap_region_wrt_saved_tensor,
+)
+
+__all__: List[str] = ["create_read_items_for_chunk_list"]
+
+
+def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
+    return ChunkStorageMetadata(
+        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
+    )
+
+
+def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
+    return ChunkStorageMetadata(
+        offsets=torch.Size(shard_md.shard_offsets),
+        sizes=torch.Size(shard_md.shard_sizes),
+    )
+
+
+def _sharded_tensor_metadata(
+    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
+) -> TensorWriteData:
+    shard_properties = sharded_tensor.metadata().tensor_properties
+
+    properties = TensorProperties(
+        dtype=shard_properties.dtype,
+        layout=shard_properties.layout,
+        requires_grad=shard_properties.requires_grad,
+        memory_format=shard_properties.memory_format,
+        pin_memory=shard_properties.pin_memory,
+    )
+
+    return TensorWriteData(
+        chunk=_chunk_for_shard(shard_md),
+        properties=properties,
+        size=sharded_tensor.metadata().size,
+    )
+
+
+def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
+    sizes, offsets = compute_local_shape_and_global_offset(
+        tensor.shape, tensor.device_mesh, tensor.placements
+    )
+    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
+
+    return WriteItem(
+        index=MetadataIndex(fqn, offsets),
+        type=WriteItemType.SHARD,
+        tensor_data=TensorWriteData(
+            chunk=ChunkStorageMetadata(
+                offsets=offsets,
+                sizes=sizes,
+            ),
+            properties=TensorProperties.create_from_tensor(tensor.to_local()),
+            size=tensor.size(),
+        ),
+    )
+
+
+def _create_write_item_for_shard(
+    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
+) -> WriteItem:
+    offsets = torch.Size(shard_md.shard_offsets)
+    return WriteItem(
+        index=MetadataIndex(fqn, offsets),
+        type=WriteItemType.SHARD,
+        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
+    )
+
+
+def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
+    offsets = torch.Size([0] * len(tensor.size()))
+    return WriteItem(
+        index=MetadataIndex(fqn, offsets),
+        type=WriteItemType.TENSOR,
+        tensor_data=TensorWriteData(
+            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
+            properties=TensorProperties.create_from_tensor(tensor),
+            size=tensor.size(),
+        ),
+    )
+
+
+def _create_write_item_for_bytesio(fqn: str, bytes: Any):
+    return WriteItem(
+        index=MetadataIndex(fqn),
+        type=WriteItemType.BYTE_IO,
+    )
+
+
+def _create_read_item_for_byteio(
+    dest_index, dest_offset, storage_index, storage_offset, length
+):
+    return ReadItem(
+        type=LoadItemType.BYTE_IO,
+        dest_index=dest_index,
+        dest_offsets=torch.Size((dest_offset,)),
+        storage_index=storage_index,
+        storage_offsets=torch.Size((storage_offset,)),
+        lengths=torch.Size((length,)),
+    )
+
+
+def _create_read_item_for_tensor(
+    dest_index, dest_offsets, storage_index, storage_offsets, lengths
+):
+    return ReadItem(
+        type=LoadItemType.TENSOR,
+        dest_index=dest_index,
+        dest_offsets=torch.Size(dest_offsets),
+        storage_index=storage_index,
+        storage_offsets=torch.Size(storage_offsets),
+        lengths=torch.Size(lengths),
+    )
+
+
+def create_read_items_for_chunk_list(
+    fqn: str,
+    checkpoint_md: TensorStorageMetadata,
+    local_chunks: List[ChunkStorageMetadata],
+) -> List[ReadItem]:
+    """
+    Create a list of ``ReadItem`` based on the checkpoint and local chunks.
+
+    This applies the resharding algorithm and computes the reads needed
+    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.
+
+    Args:
+        fqn (str) : The state_dict FQN to pass to ``ReadItem``.
+        checkpoint_md (TensorStorageMetadata): metadata for a given tensor
+            from a checkpoint.
+        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
+            loaded.
+
+    Returns:
+        A list of ``ReadItem`` that will satisfy all input chunks.
+    """
+    read_items = []
+    # this is a naive quadratic algo that can be optimized later
+    for idx, shard in enumerate(local_chunks):
+        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
+            if not _check_shard_metadata_pair_overlap(shard, storage_md):
+                continue
+
+            storage_offsets = []
+            dest_offsets = []
+            lengths = []
+            for (
+                dim,
+                offset_for_saved_tensor,
+                offset_for_current_tensor,
+                length,
+            ) in _shards_get_overlap_region_wrt_saved_tensor(
+                saved_shard=storage_md, current_shard=shard
+            ):
+                storage_offsets.append(offset_for_saved_tensor)
+                dest_offsets.append(offset_for_current_tensor)
+                lengths.append(length)
+
+            read_items.append(
+                _create_read_item_for_tensor(
+                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
+                    dest_offsets=dest_offsets,
+                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
+                    storage_offsets=storage_offsets,
+                    lengths=lengths,
+                )
+            )
+    return read_items
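+
+
+# A hedged worked example (1-D shapes made up for illustration): a saved chunk
+# covering [0, 10) and a local chunk covering [5, 10) overlap on [5, 10), so a
+# single ReadItem with storage_offsets=(5,), dest_offsets=(0,), lengths=(5,) is
+# produced.
+#
+#   ckpt_md = TensorStorageMetadata(
+#       properties=TensorProperties(dtype=torch.float32),
+#       size=torch.Size([10]),
+#       chunks=[ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([10]))],
+#   )
+#   local = [ChunkStorageMetadata(offsets=torch.Size([5]), sizes=torch.Size([5]))]
+#   reads = create_read_items_for_chunk_list("weight", ckpt_md, local)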
+
+
+def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
+    requests = []
+    for fqn, obj in state_dict.items():
+        if isinstance(obj, DTensor):
+            requests.append(_create_write_items_for_dtensor(fqn, obj))
+        elif isinstance(obj, ShardedTensor):
+            for shard_md in obj.metadata().shards_metadata:
+                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
+        elif isinstance(obj, torch.Tensor):
+            requests.append(_create_write_item_for_tensor(fqn, obj))
+        else:
+            requests.append(_create_write_item_for_bytesio(fqn, obj))
+    return SavePlan(requests)
+
+
+def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
+    if isinstance(object, DTensor):
+        return [_create_write_items_for_dtensor(fqn, object)]
+    elif isinstance(object, ShardedTensor):
+        return [
+            _create_write_item_for_shard(fqn, object, shard.metadata)
+            for shard in object.local_shards()
+        ]
+    elif isinstance(object, torch.Tensor):
+        return [_create_write_item_for_tensor(fqn, object)]
+    else:
+        return [_create_write_item_for_bytesio(fqn, object)]
+
+
+def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
+    sizes, offsets = compute_local_shape_and_global_offset(
+        tensor.shape, tensor.device_mesh, tensor.placements
+    )
+    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
+    return ChunkStorageMetadata(
+        offsets=offsets,
+        sizes=sizes,
+    )
+
+
+def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
+    if isinstance(tensor, DTensor):
+        local_chunks = [_create_chunk_from_dtensor(tensor)]
+    elif isinstance(tensor, ShardedTensor):
+        local_chunks = [
+            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
+        ]
+    elif isinstance(tensor, torch.Tensor):
+        local_chunks = [_create_chunk_from_tensor(tensor)]
+    else:
+        raise ValueError(
+            "Unsupported type, expecting one of [Tensor, DTensor, ShardedTensor], "
+            f"but got {type(tensor)}"
+        )
+
+    return local_chunks
+
+
+def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
+    if not isinstance(md, BytesStorageMetadata):
+        try:
+            local_chunks = _create_chunk_list(obj)
+        except ValueError as ex:
+            raise ValueError(
+                f"Invalid checkpoint metadata for {fqn}, "
+                + f"expected BytesStorageMetadata but found {type(md)}",
+            ) from ex
+
+        return create_read_items_for_chunk_list(fqn, md, local_chunks)
+    else:
+        return [
+            _create_read_item_for_byteio(
+                dest_index=MetadataIndex(fqn),
+                dest_offset=0,
+                storage_index=MetadataIndex(fqn),
+                storage_offset=0,
+                length=0,
+            )
+        ]
+
+
+def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
+    state_dict_assigned_storage = tree_map_only(
+        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
+    )
+    # The in-place version of tree_map_only, tree_map_only_, doesn't seem to work,
+    # so we temporarily update each element in the state dict with the initialized tensor.
+    for k in state_dict.keys():
+        state_dict[k] = state_dict_assigned_storage[k]
+
+
+def _init_meta_tensor(value: Any) -> Any:
+    """
+    Initializes tensor, moves it to device for torch.Tensor/DTensor on meta device.
+    """
+
+    device = getattr(value, "device", None)
+    # DCP does the initialization if it's meta tensor/DTensor.
+    if device == torch.device("meta"):
+        device_type = dist.distributed_c10d._get_pg_default_device().type
+        device = cast(torch.device, _get_device_module(device_type).current_device())
+        if isinstance(value, DTensor):
+            new_local_tensor = torch.empty_like(value.to_local(), device=device)
+            # We need to pass shape and stride explicitly, since DTensor might be
+            # sharded unevenly.
+            dtensor = DTensor.from_local(
+                new_local_tensor,
+                device_mesh=value.device_mesh,
+                placements=value.placements,
+                shape=value.size(),
+                stride=value.stride(),
+            )
+            return dtensor
+        elif isinstance(value, torch.Tensor):
+            tensor = torch.empty_like(value, device=device)
+            return tensor
+        else:
+            raise RuntimeError(
+                f"Found unsupported type {type(value)} for meta device loading."
+            )
+    else:
+        return value
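+
+
+# A hedged sketch of _init_state_dict (the state_dict is hypothetical and an
+# initialized process group is assumed): meta-device tensors are swapped for empty
+# tensors on the process group's default device; other values are left untouched.
+#
+#   sd = {"w": torch.empty(4, 4, device="meta"), "step": torch.tensor(0)}
+#   _init_state_dict(sd)  # "w" now lives on a real device, "step" is unchanged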
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/resharding.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/resharding.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ce4138ce11bc1a45ebd275068a5b043297b870
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/resharding.py
@@ -0,0 +1,70 @@
+from typing import List, Tuple
+
+from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
+
+__all__: List[str] = []
+
+
+def _check_shard_metadata_pair_overlap(
+    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
+):
+    """Check if two shards overlap."""
+    # For each dim of each shard, check if one shard resides on the other
+    # end of second shard with respect to that dim. As an example for a 2D
+    # shard, we would check if one shard is above or on the left of the
+    # other shard.
+    ndims = len(shard1.offsets)
+    for i in range(ndims):
+        if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]:
+            return False
+        if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]:
+            return False
+
+    return True
+
+
+def _shards_get_overlap_region_wrt_saved_tensor(
+    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
+) -> List[Tuple[int, int, int, int]]:
+    """
+    Return the overlapping region between saved_shard and current_shard.
+
+    The returned list has the same number of elements as the tensor's number of dimensions.
+    For each element, we produce a tuple with the following contents:
+        (dimension, `saved_shard` offset, `current_shard` offset, length)
+
+    Offsets are relative to each shard.
+    """
+    narrows = []
+    for dim, (
+        saved_shard_offset,
+        current_shard_offset,
+        saved_shard_size,
+        current_shard_size,
+    ) in enumerate(
+        zip(
+            saved_shard.offsets,
+            current_shard.offsets,
+            saved_shard.sizes,
+            current_shard.sizes,
+        )
+    ):
+        min_range_end = min(
+            saved_shard_offset + saved_shard_size,
+            current_shard_offset + current_shard_size,
+        )
+
+        length = min_range_end - max(current_shard_offset, saved_shard_offset)
+
+        if saved_shard_offset > current_shard_offset:
+            offset_for_saved_tensor = 0
+            offset_for_current_tensor = saved_shard_offset - current_shard_offset
+        else:
+            offset_for_saved_tensor = current_shard_offset - saved_shard_offset
+            offset_for_current_tensor = 0
+
+        narrows.append(
+            (dim, offset_for_saved_tensor, offset_for_current_tensor, length)
+        )
+
+    return narrows
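+
+
+# A hedged worked example (offsets and sizes are made up; assumes ``import torch``
+# for torch.Size): a saved chunk covering [0, 8) and a current chunk covering
+# [4, 12) overlap on [4, 8), i.e. a 4-element copy starting at offset 4 of the
+# saved chunk and offset 0 of the current chunk.
+#
+#   saved = ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8]))
+#   current = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([8]))
+#   assert _check_shard_metadata_pair_overlap(saved, current)
+#   assert _shards_get_overlap_region_wrt_saved_tensor(saved, current) == [(0, 4, 0, 4)]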
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a115f5e349bcd5101eebdfe7b2ad54da1262dd2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict.py
@@ -0,0 +1,1115 @@
+import contextlib
+import functools
+import gc
+from dataclasses import asdict, dataclass, field
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    no_type_check,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import (
+    _gather_state_dict,
+    _offload_state_dict_to_cpu,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+)
+from torch.distributed.fsdp import (
+    FullOptimStateDictConfig,
+    FullStateDictConfig,
+    FullyShardedDataParallel as FSDP,
+    OptimStateDictConfig,
+    ShardedOptimStateDictConfig,
+    ShardedStateDictConfig,
+    StateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp._common_utils import (
+    _get_module_fsdp_state_if_fully_sharded_module,
+    FSDP_WRAPPED_MODULE,
+)
+from torch.nn.modules.module import _IncompatibleKeys
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+FLAT_PARAM = "_flat_param"
+PG = "param_groups"
+PG_PREFIX = f"{PG}."
+STATE = "state"
+STATE_PREFIX = f"{STATE}."
+PARAMS = "params"
+FQNS_T = Set[str]
+
+_patched_state_dict: Set[Callable] = set()
+
+
+PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
+ValueType = Union[
+    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
+]
+DictValueType = Dict[str, ValueType]
+ListDictValueType = List[DictValueType]
+OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]
+
+
+@contextlib.contextmanager
+def gc_context():
+    is_enabled = gc.isenabled()
+    gc.disable()
+    try:
+        yield
+    finally:
+        # TODO: add logging for the gc details/time
+        gc.collect()
+        if is_enabled:
+            gc.enable()
+
+
+@dataclass
+class StateDictOptions:
+    """
+    This dataclass specifies how get_state_dict/set_state_dict will work.
+
+    - ``full_state_dict``: if this is set to True, all the tensors in the
+      returned state_dict will be gathered; no ShardedTensor or DTensor will
+      appear in the returned state_dict.
+
+    - ``cpu_offload``: offload all the tensors to CPU. To prevent CPU OOM, if
+      ``full_state_dict`` is also True, then only rank0 will get the
+      state_dict and all other ranks will get empty state_dicts.
+
+    - ``ignore_frozen_params``: if the value is True, the returned state_dict
+      won't contain any frozen parameters -- parameters whose ``requires_grad``
+      is False. The default value is False.
+
+    - ``keep_submodule_prefixes``: when ``submodules`` is not None, this option
+      indicates whether to keep the submodule prefixes from the state_dict keys.
+      For example, if the submodule is ``module.pretrain`` and the full FQN of
+      the parameter is ``pretrain.layer1.weight``, then when this option is
+      True, the parameter's key in the returned state_dict will be
+      ``pretrain.layer1.weight``. If the option is False, the key will be
+      ``layer1.weight``.
+      Note that if ``keep_submodule_prefixes`` is False, there may be
+      conflicting FQNs, hence there should be only one submodule in ``submodules``.
+
+    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
+      model.load_state_dict().
+      The default value is True.
+    """
+
+    full_state_dict: bool = False
+    cpu_offload: bool = False
+    ignore_frozen_params: bool = False
+    keep_submodule_prefixes: bool = True
+    strict: bool = True
+
+
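+# A minimal sketch (an illustration, not part of the upstream API): the option
+# combination typically used to materialize one full, CPU-resident state_dict
+# on rank0 only, e.g. before writing a single consolidated checkpoint file.
+def _example_full_cpu_offload_options() -> StateDictOptions:
+    return StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+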
+@dataclass
+class _StateDictInfo(StateDictOptions):
+    fqn_param_mapping: Dict[
+        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
+    ] = field(default_factory=dict)
+    all_fqns: Set[str] = field(default_factory=set)
+    submodule_prefixes: Set[str] = field(default_factory=set)
+    handle_model: bool = True
+    handle_optim: bool = True
+    fsdp_context: Callable = contextlib.nullcontext
+    fsdp_modules: List[nn.Module] = field(default_factory=list)
+
+
+def _get_fqns(
+    model: nn.Module,
+    name: str,
+    skip_ddp_prefix: bool = True,
+    skip_compiler_prefix: bool = True,
+) -> FQNS_T:
+    """
+    This API is used to convert the name of a parameter to the FQNs. For FSDP
+    without `use_orig_params`, the name of FlatParameter can be mapped to
+    multiple original parameters. As a result, the return type of this function
+    is `Set[str]`.
+
+    Args:
+        model (nn.Module): the root model.
+        name (str): the name of the parameter.
+        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
+        skip_compiler_prefix (bool): whether to skip the compiler's `_orig_mod` prefix.
+
+    Returns:
+        The canonical FQNs based on the model traversal.
+    """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
+    if "." not in name:
+        return {name}
+
+    obj_names = name.split(".")
+    fqn_obj_names = []
+    curr_obj = model
+    for i, curr_obj_name in enumerate(obj_names):
+        if isinstance(curr_obj, DDP):
+            assert curr_obj_name == "module"
+            curr_obj = curr_obj.module
+            if not skip_ddp_prefix:
+                fqn_obj_names.append(curr_obj_name)
+        elif isinstance(curr_obj, FSDP):
+            if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM:
+                prefix = ".".join(fqn_obj_names)
+                flat_param = getattr(curr_obj, FLAT_PARAM)
+                if prefix:
+                    prefix = f"{prefix}."
+                return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
+            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
+            if curr_obj_name != FSDP_WRAPPED_MODULE:
+                fqn_obj_names.append(curr_obj_name)
+                curr_obj = getattr(curr_obj, curr_obj_name)
+        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
+            assert curr_obj_name == "_orig_mod"
+            curr_obj = curr_obj._orig_mod
+            if not skip_compiler_prefix:
+                fqn_obj_names.append(curr_obj_name)
+        else:
+            fqn_obj_names.append(curr_obj_name)
+            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
+                if i != len(obj_names) - 1:
+                    raise RuntimeError("Expect `_extra_state` to be the last obj name")
+            else:
+                curr_obj = getattr(curr_obj, curr_obj_name)
+
+    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}
+
+
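+# A minimal sketch (an illustration, not part of the upstream API): for a plain,
+# non-parallelized module the canonical FQN is simply the name reported by
+# ``named_parameters()``; the branches above strip wrapper prefixes such as
+# DDP's ``module.`` or the compiler's ``_orig_mod.`` when they are present.
+def _example_get_fqns_sketch() -> FQNS_T:
+    plain_model = nn.Sequential(nn.Linear(2, 2))
+    fqns = _get_fqns(plain_model, "0.weight")
+    assert fqns == {"0.weight"}
+    return fqns
+
+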
+class _EXTRA_STATE:
+    pass
+
+
+def _iterate_valid_model_state(model):
+    visited_modules: Set[nn.Module] = set()
+
+    def recurse(module: nn.Module, curr_fqn: str) -> Generator:
+        visited_modules.add(module)
+
+        curr_fqn = f"{curr_fqn}." if curr_fqn else ""
+        for name, submodule in module.named_children():
+            if submodule in visited_modules:
+                continue
+            new_fqn = f"{curr_fqn}{name}"
+            yield from recurse(submodule, new_fqn)
+
+        for name, obj in chain(
+            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
+        ):
+            if name in module._non_persistent_buffers_set:
+                continue
+            new_fqn = f"{curr_fqn}{name}"
+            yield new_fqn, obj
+
+        if (
+            getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
+            != nn.Module.get_extra_state
+        ):
+            new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
+            yield new_fqn, _EXTRA_STATE()
+
+    yield from recurse(model, "")
+
+
+def _verify_options(
+    model: nn.Module,
+    optims: Tuple[torch.optim.Optimizer, ...],
+    optim_only: bool,
+    *,
+    submodules: Optional[Set[nn.Module]] = None,
+    options: Optional[StateDictOptions] = None,
+) -> _StateDictInfo:
+    """
+    Verify the model and options passed by the user and generate _StateDictInfo.
+    """
+    if optim_only and not optims:
+        raise RuntimeError(
+            "Optimizers are not passed in but optim_only is set to True."
+        )
+
+    options = options or StateDictOptions()
+
+    fqn_param_mapping: Dict[
+        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
+    ] = {}
+    all_fqns = set()
+    for name, param in _iterate_valid_model_state(model):
+        fqns = _get_fqns(model, name)
+        if not isinstance(param, _EXTRA_STATE):
+            fqn_param_mapping[param] = fqns
+        for fqn in fqns:
+            if not isinstance(param, _EXTRA_STATE):
+                fqn_param_mapping[fqn] = param
+            all_fqns.add(fqn)
+
+    submodule_prefixes = set()
+    if submodules:
+        submodules = set(submodules)
+        for name, module in model.named_modules():
+            if module not in submodules:
+                continue
+            fqns = _get_fqns(model, name)
+            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
+            for fqn in fqns:
+                submodule_prefixes.add(f"{fqn}.")
+
+    fsdp_modules = FSDP.fsdp_modules(model)
+    state_dict_config: StateDictConfig
+    optim_state_dict_config: OptimStateDictConfig
+    fsdp_context: Callable
+    if fsdp_modules:
+        # The FSDP APIs only work if at least one FSDP instance exists.
+        if options.full_state_dict:
+            state_dict_config = FullStateDictConfig(
+                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
+            )
+            optim_state_dict_config = FullOptimStateDictConfig(
+                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
+            )
+            state_dict_type = StateDictType.FULL_STATE_DICT
+        else:
+            state_dict_config = ShardedStateDictConfig(
+                offload_to_cpu=options.cpu_offload,
+            )
+            optim_state_dict_config = ShardedOptimStateDictConfig(
+                offload_to_cpu=options.cpu_offload,
+            )
+            state_dict_type = StateDictType.SHARDED_STATE_DICT
+
+        fsdp_context = functools.partial(
+            FSDP.state_dict_type,
+            module=model,
+            state_dict_type=state_dict_type,
+            state_dict_config=state_dict_config,
+            optim_state_dict_config=optim_state_dict_config,
+        )
+    else:
+        fsdp_context = contextlib.nullcontext
+
+    return _StateDictInfo(
+        **asdict(options),
+        fqn_param_mapping=fqn_param_mapping,
+        all_fqns=all_fqns,
+        submodule_prefixes=submodule_prefixes,
+        fsdp_context=fsdp_context,
+        fsdp_modules=cast(List[nn.Module], fsdp_modules),
+        handle_model=not optim_only,
+        handle_optim=(len(optims) > 0),
+    )
+
+
+def _verify_state_dict(
+    model_state_dict: Dict[str, ValueType],
+    optim_state_dict: OptimizerStateType,
+    info: _StateDictInfo,
+) -> None:
+    for module in info.fsdp_modules:
+        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+        assert fsdp_state is not None, "Expected an fsdp_state for an FSDP module."
+
+    # Verify if the model_state_dict and optim_state_dict are valid. This API
+    # should give the users an explicit error message to debug or report.
+    if (
+        info.handle_model
+        and not model_state_dict
+        and not info.submodule_prefixes
+        and not info.ignore_frozen_params
+        and not (info.cpu_offload and info.full_state_dict)
+        and info.strict
+    ):
+        raise RuntimeError(
+            "The option indicates that model state_dict is required to save "
+            "or load, but model state_dict is empty."
+            f"rank = {dist.get_rank()=}."
+        )
+
+    if info.handle_optim:
+        if not (optim_state_dict and optim_state_dict[STATE]) and not (
+            info.cpu_offload and info.full_state_dict
+        ):
+            raise RuntimeError(
+                "The option indicates that model state_dict is required to save, "
+                f"or load but optim state_dict is empty. {optim_state_dict}"
+            )
+
+    for key in model_state_dict.keys():
+        if FLAT_PARAM in key:
+            raise RuntimeError(
+                f"{key} contains {FLAT_PARAM}. This can happen if the model "
+                "is not the root module."
+            )
+
+
+def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
+    call = getattr(obj, api)
+    if call in _patched_state_dict:
+        call = functools.partial(getattr(obj.__class__, api), self=obj)
+    return call
+
+
+def _get_model_state_dict(
+    model: nn.Module, info: _StateDictInfo
+) -> Dict[str, ValueType]:
+    if not info.handle_model:
+        return {}
+
+    with info.fsdp_context():
+        state_dict = _state_dict_fn(model, "state_dict")()
+
+    for key in list(state_dict.keys()):
+        fqns = _get_fqns(model, key)
+        assert len(fqns) == 1
+        fqn = next(iter(fqns))
+        if fqn != key:
+            # As we only support FSDP, DDP, and TP, the only cases are
+            # wrapper-based DDP and compiler. Verify if the assumption
+            # is correct.
+            def verify(key, fqn) -> bool:
+                if len(fqn) >= len(key):
+                    return False
+                fqn_split = fqn.split(".")
+                key_split = key.split(".")
+                fqn_idx = 0
+                for key_idx, key_name in enumerate(key_split):
+                    if key_name == fqn_split[fqn_idx]:
+                        fqn_idx += 1
+                        if fqn_idx == len(fqn_split):
+                            return key_idx == len(key_split) - 1
+                    elif key_name in ("module", "_orig_mod"):
+                        continue
+                    else:
+                        return False
+                return True
+
+            if not verify(key, fqn):
+                raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
+            state_dict[fqn] = state_dict.pop(key)
+
+    if info.submodule_prefixes:
+        new_state_dict: Dict[str, ValueType] = {}
+        # TODO: make this faster.
+        for fqn in state_dict.keys():
+            for prefix in info.submodule_prefixes:
+                if not fqn.startswith(prefix):
+                    continue
+                if info.keep_submodule_prefixes:
+                    new_state_dict[fqn] = state_dict[fqn]
+                else:
+                    new_fqn = fqn[len(prefix) :]
+                    new_state_dict[new_fqn] = state_dict[fqn]
+        state_dict = new_state_dict
+
+    if info.ignore_frozen_params:
+        for key, param in model.named_parameters():
+            if param.requires_grad:
+                continue
+            fqns = _get_fqns(model, key)
+            for fqn in fqns:
+                state_dict.pop(fqn)
+
+    for key, p in list(state_dict.items()):
+        if torch.is_tensor(p) and p.is_meta:
+            state_dict.pop(key)
+
+    if info.full_state_dict:
+        ranks_only = tuple() if not info.cpu_offload else (0,)
+        return _gather_state_dict(
+            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
+        )
+    elif info.cpu_offload:
+        return _offload_state_dict_to_cpu(state_dict)
+    else:
+        return state_dict
+
+
+def _load_model_state_dict(
+    model: nn.Module,
+    state_dict: Dict[str, ValueType],
+    info: _StateDictInfo,
+) -> _IncompatibleKeys:
+    if not info.handle_model or not state_dict:
+        return _IncompatibleKeys({}, {})
+
+    for key, _ in _iterate_valid_model_state(model):
+        fqns = _get_fqns(model, key)
+        fqns_with_prefix = _get_fqns(
+            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
+        )
+        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
+            if fqn != fqn_with_prefix:
+                state_dict[fqn_with_prefix] = state_dict.pop(fqn)
+
+    with info.fsdp_context():
+        return cast(
+            _IncompatibleKeys,
+            _state_dict_fn(model, "load_state_dict")(
+                state_dict=state_dict, strict=info.strict
+            ),
+        )
+
+
+def _init_optim_state(optim: torch.optim.Optimizer) -> None:
+    """
+    Initialize optim states by calling the step() with zero grads.
+    """
+    if optim.state:
+        # The optimizer state is initialized.
+        return
+
+    for param_group in optim.param_groups:
+        for param in param_group[PARAMS]:
+            if param.grad is not None:
+                raise RuntimeError(
+                    "state_dict can only be used if the optimizer "
+                    "states are initialized (usually after one step() with "
+                    "gradients) or gradients are None. For the later case, "
+                    "state_dict will fake the gradients as zero "
+                    "to initialize the optimizer states. However, the "
+                    "gradients are not None."
+                )
+            if param.requires_grad:
+                param.grad = torch.zeros_like(param)
+    optim.step(closure=None)
+    optim.zero_grad(set_to_none=True)
+
+
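+# A minimal local sketch (an illustration, not part of the upstream API): a fresh
+# optimizer has an empty ``state`` until its first step, so ``state_dict()`` would
+# have nothing to report; ``_init_optim_state`` fakes zero gradients and runs one
+# step purely to materialize the per-parameter state.
+def _example_init_optim_state_sketch() -> None:
+    module = nn.Linear(4, 4)
+    adam = torch.optim.Adam(module.parameters(), lr=1e-3)
+    assert not adam.state
+    _init_optim_state(adam)
+    assert adam.state  # exp_avg/exp_avg_sq entries now exist for every parameter
+
+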
+def _get_optim_state_dict(
+    model: nn.Module,
+    optimizers: Tuple[torch.optim.Optimizer, ...],
+    info: _StateDictInfo,
+) -> OptimizerStateType:
+    if not info.handle_optim:
+        return {}
+
+    optim_state_dict: OptimizerStateType = {STATE: {}, PG: []}
+    for optim in optimizers:
+        _init_optim_state(optim)
+        osd = _state_dict_fn(optim, "state_dict")()
+        if info.fsdp_modules:
+            with info.fsdp_context():
+                osd = FSDP.optim_state_dict(model, optim, osd)
+
+            # We need to specially handle FlatParameter FSDP as
+            # FlatParameter FSDP converts the FQNs.
+            # There are no easy ways to do this conversion systematically.
+            # We can only use a string replacement without a correctness check.
+            if not osd:
+                continue
+            for k in list(osd[STATE].keys()):
+                if "_orig_mod" in k:
+                    osd[STATE][k.replace("_orig_mod.", "")] = osd[STATE].pop(k)
+            for g in osd[PG]:
+                params = [k.replace("_orig_mod.", "") for k in g[PARAMS]]
+                g[PARAMS] = params
+        else:
+            params = list(chain.from_iterable(g[PARAMS] for g in optim.param_groups))
+            param_pid_mapping = dict(zip(params, range(len(params))))
+            fqn_pid_mapping = {}
+            for key, param in model.named_parameters():
+                fqns = _get_fqns(model, key)
+                assert len(fqns) == 1
+                fqn = next(iter(fqns))
+                if param not in param_pid_mapping:
+                    continue
+                pid = param_pid_mapping[param]
+                fqn_pid_mapping[fqn] = pid
+                fqn_pid_mapping[pid] = fqn
+
+            for key in list(osd[STATE].keys()):
+                fqn = fqn_pid_mapping[key]
+                osd[STATE][fqn] = osd[STATE].pop(key)
+
+            for group in osd[PG]:
+                group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]]
+
+        if not osd:
+            continue
+
+        cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE])
+        cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG])
+
+    if info.full_state_dict:
+        ranks_only = tuple() if not info.cpu_offload else (0,)
+        return _gather_state_dict(
+            optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
+        )
+    elif info.cpu_offload:
+        return _offload_state_dict_to_cpu(optim_state_dict)
+    else:
+        return optim_state_dict
+
+
+def _split_optim_state_dict(
+    model: nn.Module,
+    optim: torch.optim.Optimizer,
+    optim_state_dict: OptimizerStateType,
+    info: _StateDictInfo,
+) -> OptimizerStateType:
+    """
+    Extract the corresponding optim state_dict from ``optim_state_dict`` for
+    ``optim`` and return the result optim state_dict.
+
+    Args:
+        model (nn.Module): the root model.
+        optim (torch.optim.Optimizer): the optimizer.
+        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
+            contains the optim state_dict of ``optim``.
+        info (_StateDictInfo): state dict information.
+
+    Returns:
+        The optim state_dict of ``optim``.
+    """
+
+    state: DictValueType = {}
+    pg_state: ListDictValueType = []
+    return_osd: OptimizerStateType = {STATE: state, PG: pg_state}
+    pg_mapping: Dict[int, int] = {}
+
+    for param_group in optim.param_groups:
+        pg_state.append({PARAMS: []})
+        for param in param_group[PARAMS]:
+            for fqn in info.fqn_param_mapping[param]:
+                params = pg_state[-1][PARAMS]
+                assert isinstance(params, list)
+                params.append(fqn)
+                if param.requires_grad:
+                    state[fqn] = cast(DictValueType, optim_state_dict[STATE])[fqn]
+                for loaded_param_group in cast(ListDictValueType, optim_state_dict[PG]):
+                    params = loaded_param_group[PARAMS]
+                    assert isinstance(params, list)
+                    if fqn in params:
+                        pg_mapping[id(loaded_param_group)] = len(return_osd[PG]) - 1
+
+    for param_group in cast(ListDictValueType, optim_state_dict[PG]):
+        idx = pg_mapping.get(id(param_group), -1)
+        if idx == -1:
+            continue
+        for key, value in param_group.items():
+            if key == PARAMS:
+                continue
+            # TODO: check if value is the same if exists.
+            pg_state[idx][key] = value
+
+    return return_osd
+
+
+def _load_optim_state_dict(
+    model: nn.Module,
+    optimizers: Tuple[torch.optim.Optimizer, ...],
+    state_dict: OptimizerStateType,
+    info: _StateDictInfo,
+) -> None:
+    if not info.handle_optim:
+        return
+
+    for optim in optimizers:
+        optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
+        if info.fsdp_modules:
+            # We need to specially handle FlatParameter FSDP as
+            # FlatParameter FSDP converts the FQNs.
+            for original_fqn, _ in model.named_parameters():
+                fqns = _get_fqns(model, original_fqn)
+                fqns_with_compiler = _get_fqns(
+                    model, original_fqn, skip_compiler_prefix=False
+                )
+                if fqns == fqns_with_compiler:
+                    continue
+
+                assert len(fqns) == 1
+                fqn = fqns.pop()
+                fqn_with_compiler = fqns_with_compiler.pop()
+                for g in optim_state_dict[PG]:
+                    val = cast(Dict[str, Any], g)
+                    params = [
+                        key.replace(fqn, fqn_with_compiler) for key in val[PARAMS]
+                    ]
+                    val[PARAMS] = params
+                osd_state = cast(DictValueType, optim_state_dict[STATE])
+                for k in list(osd_state.keys()):
+                    if fqn in k:
+                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)
+
+            with info.fsdp_context():
+                optim_state_dict = FSDP.optim_state_dict_to_load(
+                    model, optim, optim_state_dict
+                )
+
+        # Note that we do not have to convert the FQN back to param id here if
+        # order in optim.param_groups[idx][PARAMS] is the same as the one in
+        # optim_state_dict[PG][idx][PARAMS].
+        _init_optim_state(optim)
+        _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)
+
+
+def get_model_state_dict(
+    model: nn.Module,
+    *,
+    submodules: Optional[Set[nn.Module]] = None,
+    options: Optional[StateDictOptions] = None,
+) -> Dict[str, ValueType]:
+    """
+    Return the model state_dict of ``model``.
+
+    See ``get_state_dict`` for the detail usage.
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        submodules: Optional[Set[nn.Module]]: only return the model parameters
+            that belong to the submodules.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be returned. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        The state_dict for ``model``.
+
+    :rtype: typing.Dict[str, ValueType]
+    """
+    with gc_context():
+        info = _verify_options(
+            model,
+            tuple(),
+            optim_only=False,
+            submodules=submodules,
+            options=options,
+        )
+        model_state_dict = _get_model_state_dict(model, info)
+        _verify_state_dict(model_state_dict, {}, info)
+        return model_state_dict
+
+
+def get_optimizer_state_dict(
+    model: nn.Module,
+    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
+    *,
+    submodules: Optional[Set[nn.Module]] = None,
+    options: Optional[StateDictOptions] = None,
+) -> OptimizerStateType:
+    """
+    Return the combined state_dict for optimizers.
+
+    See ``get_state_dict`` for the detail usage.
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
+            The optimizers that are used to optimize ``model``.
+        submodules: Optional[Set[nn.Module]]: only return the model parameters
+            that belong to the submodules.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be returned. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        The state_dict for ``optimizers``.
+
+    :rtype: OptimizerStateType
+    """
+    with gc_context():
+        optimizers = (
+            (optimizers,)
+            if isinstance(optimizers, torch.optim.Optimizer)
+            else tuple(optimizers)
+        )
+        info = _verify_options(
+            model,
+            optimizers,
+            optim_only=True,
+            submodules=submodules,
+            options=options,
+        )
+        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
+        _verify_state_dict({}, optim_state_dict, info)
+        return optim_state_dict
+
+
+def get_state_dict(
+    model: nn.Module,
+    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
+    *,
+    submodules: Optional[Set[nn.Module]] = None,
+    options: Optional[StateDictOptions] = None,
+) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
+    """
+    Return the model state_dict and optimizers state_dict.
+
+    ``get_state_dict`` can process any module that is parallelized by PyTorch
+    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
+    combination of these parallelisms. The main functions of ``get_state_dict``
+    are: 1.) returning a model and optimizer state_dict that can be resharded
+    with a different number of trainers and/or different parallelisms.
+    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
+    these APIs.
+    3.) sanity checking the result state_dict.
+
+    The keys of the result state dictionary are the canonical FQNs (Fully
+    Qualified Names).  A canonical FQN refers to the FQN based on a parameter's
+    position in an nn.Module hierarchy. More specifically, a canonical FQN to a
+    parameter is the FQN returned by ``module.named_parameters()`` or
+    ``module.named_buffers()`` when the module is not distributed by any
+    parallelisms. Since the optimizer internally uses parameter IDs to represent
+    a parameter, there will be a conversion from the parameter IDs to the
+    canonical FQNs when calling this API.
+
+    ``get_state_dict`` can also process a module that is not parallelized. In
+    such a case, ``get_state_dict`` only performs one function -- converting the
+    optimizer parameter IDs to the canonical FQNs.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> import copy
+        >>> import torch
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> from torch.distributed.checkpoint.state_dict import get_state_dict
+
+        >>> fsdp_model = FSDP(copy.deepcopy(model))
+        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
+        >>> ddp_model = DDP(copy.deepcopy(model))
+        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
+
+
+        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
+        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
+
+        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
+        >>> # the asserts will fail.
+        >>> assert ddp_state_dict == fsdp_state_dict
+        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict
+
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
+            The optimizers that are used to optimize ``model``.
+        submodules: Optional[Set[nn.Module]]: only return the model parameters
+            that belong to the submodules.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be returned. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        ``Tuple`` that contain model state_dict and optimizer state_dict.
+
+    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
+    """
+
+    with gc_context():
+        optimizers = (
+            (optimizers,)
+            if isinstance(optimizers, torch.optim.Optimizer)
+            else tuple(optimizers)
+        )
+        info = _verify_options(
+            model,
+            optimizers,
+            optim_only=False,
+            submodules=submodules,
+            options=options,
+        )
+        model_state_dict = _get_model_state_dict(model, info)
+        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
+        _verify_state_dict(model_state_dict, optim_state_dict, info)
+        return model_state_dict, optim_state_dict
+
+
+def _unflatten_model_state_dict(
+    model: nn.Module,
+    state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
+) -> Dict[str, ValueType]:
+    if not state_dict:
+        return {}
+
+    if isinstance(next(iter(state_dict.keys())), nn.Module):
+        cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
+        new_state_dict: Dict[str, ValueType] = {}
+        for submodule, sub_state_dict in cast_state_dict.items():
+            for name, m in model.named_modules():
+                if m != submodule:
+                    continue
+
+                fqns = _get_fqns(model, name)
+                assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
+                prefix = f"{next(iter(fqns))}."
+                new_state_dict.update(
+                    {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
+                )
+        return new_state_dict
+    else:
+        return cast(Dict[str, ValueType], state_dict)
+
+
+def set_model_state_dict(
+    model: nn.Module,
+    model_state_dict: Dict[str, ValueType],
+    *,
+    options: Optional[StateDictOptions] = None,
+) -> _IncompatibleKeys:
+    """Load the model state_dict.
+
+    The counterpart of ``get_model_state_dict`` to set the state_dict to the
+    model. See ``set_state_dict`` for the detail usage.
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        model_state_dict: (Dict[str, ValueType]):
+           the model state_dict to load. If the key of the ``model_state_dict``
+           is nn.Module, the key is a submodule of ``model`` and the value should
+           be the state_dict of the submodule. When loading the state_dict,
+           the prefix of the submodule will be appended to the state_dict.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be loaded. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+            * **missing_keys** is a list of str containing the missing keys
+            * **unexpected_keys** is a list of str containing the unexpected keys
+
+    :type model_state_dict: typing.Dict[str, ValueType]
+    """
+    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
+        model, model_state_dict
+    )
+    with gc_context():
+        info = _verify_options(model, tuple(), optim_only=False, options=options)
+
+        _verify_state_dict(model_state_dict, {}, info)
+        return _load_model_state_dict(model, model_state_dict, info)
+
+
+def set_optimizer_state_dict(
+    model: nn.Module,
+    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
+    *,
+    optim_state_dict: OptimizerStateType,
+    options: Optional[StateDictOptions] = None,
+) -> None:
+    """Load the optimizers state_dict.
+
+    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
+    optimizers. See ``set_state_dict`` for the detail usage.
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        optimizers (Union[Optimizer, Iterable[Optimizer]]):
+            The optimizers that are used to optimize ``model``.
+        optim_state_dict: OptimizerStateType:
+            the optimizer state_dict to load.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be loaded. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        None
+
+    :type optim_state_dict: typing.OptimizerStateType
+    """
+    with gc_context():
+        optimizers = (
+            (optimizers,)
+            if isinstance(optimizers, torch.optim.Optimizer)
+            else tuple(optimizers)
+        )
+        info = _verify_options(model, optimizers, optim_only=True, options=options)
+
+        _verify_state_dict({}, optim_state_dict, info)
+        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
+
+
+def set_state_dict(
+    model: nn.Module,
+    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
+    *,
+    model_state_dict: Dict[str, ValueType],
+    optim_state_dict: OptimizerStateType,
+    options: Optional[StateDictOptions] = None,
+) -> _IncompatibleKeys:
+    """Load the model state_dict and optimizers state_dict.
+
+    The counterpart of ``get_state_dict`` to set the state_dict to the model and
+    optimizers.  The given ``model_state_dict`` and ``optim_state_dict`` do not
+    have to be returned by ``get_state_dict`` but must meet the following
+    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
+    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
+    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
+    the canonical FQNs.
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        optimizers (Union[Optimizer, Iterable[Optimizer]]):
+            The optimizers that are used to optimize ``model``.
+        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
+           the model state_dict to load. If the key of the ``model_state_dict``
+           is nn.Module, the key is a submodule of ``model`` and the value should
+           be the state_dict of the submodule. When loading the state_dict,
+           the prefix of the submodule will be appended to the state_dict.
+        optim_state_dict: OptimizerStateType:
+            the optimizer state_dict to load.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be loaded. See
+            `StateDictOptions` for the details.
+
+    Returns:
+        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
+            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.
+
+    :type model_state_dict: typing.Dict[str, ValueType]
+    :type optim_state_dict: typing.OptimizerStateType
+    """
+
+    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
+        model, model_state_dict
+    )
+    with gc_context():
+        optimizers = (
+            (optimizers,)
+            if isinstance(optimizers, torch.optim.Optimizer)
+            else tuple(optimizers)
+        )
+        info = _verify_options(
+            model, optimizers, optim_only=not model_state_dict, options=options
+        )
+
+        _verify_state_dict(model_state_dict, optim_state_dict, info)
+        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
+        return _load_model_state_dict(model, model_state_dict, info)
+
+
+# TODO: correct the state_dict function signature.
+# TODO: this API is not yet fully tested. Make it private
+@no_type_check
+def _patch_model_state_dict(
+    model: nn.Module,
+    *,
+    options: Optional[StateDictOptions] = None,
+) -> None:
+    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.
+
+    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
+    be a partial function to call ``get_state_dict`` and ``set_state_dict``.
+
+    Example:
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict
+
+        model = FSDP(model)
+        _patch_model_state_dict(model)
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be loaded. See
+            `StateDictOptions` for the details.
+    Returns:
+        None
+    """
+
+    _state_dict_call = functools.partial(
+        get_model_state_dict,
+        model=model,
+        options=options,
+    )
+
+    def state_dict_call():
+        return _state_dict_call()
+
+    model.state_dict = state_dict_call
+
+    _load_state_dict_call = functools.partial(
+        set_model_state_dict,
+        model=model,
+        options=options,
+    )
+
+    def load_state_dict_call(state_dict: Dict[str, Any]):
+        _load_state_dict_call(model_state_dict=state_dict)
+
+    model.load_state_dict = load_state_dict_call
+
+    _patched_state_dict.add(state_dict_call)
+    _patched_state_dict.add(load_state_dict_call)
+
+
+# TODO: correct the load_state_dict function signature.
+# TODO: this API is not yet fully tested. Make it private
+@no_type_check
+def _patch_optimizer_state_dict(
+    model: nn.Module,
+    *,
+    optimizers: Tuple[torch.optim.Optimizer, ...],
+    options: Optional[StateDictOptions] = None,
+) -> None:
+    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.
+
+    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
+    be a partial function to call ``get_state_dict`` and ``set_state_dict``.
+
+    Note that if there are multiple optimizers, all of the optimizers will be patched.
+    So users only need to call one of the state_dict() to get the full result.
+
+    Example:
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict
+
+        model = FSDP(model)
+        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+        _patch_optimizer_state_dict(model, optimizers=(optim,))
+
+    Args:
+        model (nn.Module): the root nn.Module of the model.
+        optimizers (Tuple[torch.optim.Optimizer, ...]): the optimizers whose
+            ``state_dict`` and ``load_state_dict`` attributes will be patched.
+        options (StateDictOptions): the options to control how
+            model state_dict and optimizer state_dict should be loaded. See
+            `StateDictOptions` for the details.
+    Returns:
+        None
+    """
+
+    _state_dict_call = functools.partial(
+        get_optimizer_state_dict,
+        model=model,
+        optimizers=optimizers,
+        options=options,
+    )
+
+    def state_dict_call():
+        return _state_dict_call()
+
+    _load_state_dict_call = functools.partial(
+        set_optimizer_state_dict,
+        model=model,
+        optimizers=optimizers,
+        options=options,
+    )
+
+    def load_state_dict_call(state_dict: Dict[str, Any]):
+        _load_state_dict_call(optim_state_dict=state_dict)
+
+    _patched_state_dict.add(state_dict_call)
+    _patched_state_dict.add(load_state_dict_call)
+    optimizers = (
+        (optimizers,)
+        if isinstance(optimizers, torch.optim.Optimizer)
+        else tuple(optimizers)
+    )
+    for optim in optimizers:
+        optim.state_dict = state_dict_call
+        optim.load_state_dict = load_state_dict_call
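+
+
+if __name__ == "__main__":
+    # A minimal single-process sketch (illustrative, not part of the upstream
+    # module): for a non-parallelized model, ``get_state_dict`` returns FQN-keyed
+    # model and optimizer state_dicts, and ``set_state_dict`` loads them back.
+    # The model, optimizer, and names below are assumptions for the example.
+    example_model = nn.Linear(8, 8)
+    example_optim = torch.optim.Adam(example_model.parameters(), lr=1e-3)
+    msd, osd = get_state_dict(example_model, example_optim)
+    assert set(msd.keys()) == {"weight", "bias"}
+    set_state_dict(
+        example_model, example_optim, model_state_dict=msd, optim_state_dict=osd
+    )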
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28738bc7772012d1ad2f24b484ccfe1c6b69381
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_loader.py
@@ -0,0 +1,218 @@
+import os
+import warnings
+from typing import Any, cast, Dict, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed.checkpoint.stateful import Stateful
+
+from ._storage_utils import _storage_setup
+from .default_planner import DefaultLoadPlanner
+from .planner import LoadPlanner
+from .storage import StorageReader
+from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile
+
+__all__ = ["load_state_dict", "load"]
+
+
+def load_state_dict(
+    state_dict: Dict[str, Any],
+    storage_reader: StorageReader,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[LoadPlanner] = None,
+) -> None:
+    """This method is deprecated. Please switch to 'load'."""
+    warnings.warn(
+        "'load_state_dict' is deprecated and will be removed in future versions. "
+        "Please use 'load' instead."
+    )
+    storage_reader.reset()
+    with _profile():
+        # TODO: test returning `load` here instead.
+        return _load_state_dict(
+            state_dict,
+            storage_reader,
+            process_group,
+            coordinator_rank,
+            no_dist,
+            planner,
+        )
+
+
+@_api_bc_check
+def load(
+    state_dict: Dict[str, Any],
+    *,
+    checkpoint_id: Union[str, os.PathLike, None] = None,
+    storage_reader: Optional[StorageReader] = None,
+    planner: Optional[LoadPlanner] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> None:
+    """
+    Load a distributed ``state_dict`` in SPMD style.
+
+    Each rank will try to read the least amount of data necessary
+    to fulfill the requested `state_dict`. When loading :class:`ShardedTensor`
+    or :class:`DTensor` instances, each rank only reads data for their local shards.
+
+    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
+    load will first call ``state_dict`` before attempting deserialization, followed by
+    ``load_state_dict`` once the deserialization is complete.
+
+    .. warning::
+        All tensors in ``state_dict`` must be allocated on their
+        destination device *prior to* calling this function.
+
+        All non-tensor data is loaded using `torch.load()` and modified in place
+        on state_dict.
+
+    .. warning::
+        Users must call `load_state_dict` on the root module to ensure load
+        post-processing and non-tensor data properly propagate.
+
+    .. note:
+        If no process group is initialized, this function assumes the intent
+        is to load a checkpoint into the local process. This can be useful in the
+        case of local inference, and when using regular Tensors (as opposed to
+        DTensor or ShardedTensor).
+
+    .. note:
+        Rank 0 is assumed to be the coordinator rank.
+
+    Args:
+        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into, modified in place.
+        checkpoint_id (Union[str, os.PathLike, None]):
+            The ID of this checkpoint instance. The meaning of the checkpoint_id
+            depends on the storage. It can be a path to a folder or to a file.
+            It can also be a key if the storage is a key-value store.
+            (Default: ``None``)
+        storage_reader (Optional[StorageReader]):
+            Instance of StorageReader used to perform reads. If this is not
+            specified, DCP will automatically infer the reader based on the
+            checkpoint_id. If checkpoint_id is also None, an exception will
+            be raised. (Default: ``None``)
+        planner (Optional[LoadPlanner]):
+            Instance of LoadPlanner. If this is not specified, the default
+            planner will be used. (Default: ``None``)
+        process_group (Optional[ProcessGroup]):
+            ProcessGroup to be used for cross-rank synchronization.
+            (Default: ``None``)
+
+    Returns:
+        None.
+
+    Examples
+        >>> # xdoctest: +SKIP
+        >>> my_model = MyModule()
+        >>> optimizer = Adagrad(my_model.parameters())
+        >>> model_state_dict = my_model.state_dict()
+        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
+
+        >>> torch.distributed.checkpoint.load(
+        >>>     state_dict=model_state_dict,
+        >>>     storage_reader=fs_storage_reader,
+        >>> )
+
+        >>> # The module.load_state_dict() function might have customized steps
+        >>> # to flush the state_dict; it must be called to
+        >>> # ensure correct behavior.
+        >>> my_model.load_state_dict(model_state_dict)
+
+    .. note::
+        ``load`` uses collectives to coordinate reads across ranks.
+        For NCCL-based process groups, internal tensor representations of
+        objects must be moved to the GPU device before communication takes place.
+        In this case, the device used is given by ``torch.cuda.current_device()``
+        and it is the user's responsibility to ensure that this is set so that each
+        rank has an individual GPU, via ``torch.cuda.set_device()``.
+    """
+
+    no_dist = not (dist.is_available() and dist.is_initialized())
+    if no_dist:
+        warnings.warn(
+            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
+        )
+
+    with _profile():
+        storage_reader = cast(
+            StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
+        )
+
+        if no_dist:
+            keys = list(state_dict.keys())
+        else:
+            keys = _all_gather_keys(state_dict, process_group)
+            if keys != sorted(state_dict.keys()):
+                warnings.warn(
+                    "Detected mismatched keys in state dict after all gather!"
+                    " This behavior is unsupported and may cause errors may cause errors."
+                )
+
+        stateful_sd = {}
+        for key in keys:
+            if key not in state_dict:
+                continue
+            elem = state_dict[key]
+            stateful_sd[key] = (
+                elem.state_dict() if isinstance(elem, Stateful) else elem
+            )
+
+        _load_state_dict(
+            state_dict=stateful_sd,
+            storage_reader=storage_reader,
+            process_group=process_group,
+            no_dist=no_dist,
+            planner=planner,
+        )
+        for key in keys:
+            if key not in state_dict:
+                continue
+            elem = state_dict[key]
+            if isinstance(elem, Stateful):
+                elem.load_state_dict(stateful_sd[key])
+            state_dict[key] = elem
+
+
+def _load_state_dict(
+    state_dict: Dict[str, Any],
+    storage_reader: StorageReader,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[LoadPlanner] = None,
+) -> None:
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
+
+    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
+    if planner is None:
+        planner = DefaultLoadPlanner()
+
+    def local_step():
+        assert planner is not None
+        metadata = storage_reader.read_metadata()
+        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
+        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
+
+        local_plan = planner.create_local_plan()
+        local_plan = storage_reader.prepare_local_plan(local_plan)
+        return local_plan
+
+    def global_step(all_local_plans):
+        assert planner is not None
+        all_local_plans = planner.create_global_plan(all_local_plans)
+        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
+        return all_local_plans
+
+    central_plan = distW.reduce_scatter("plan", local_step, global_step)
+
+    def read_data():
+        assert planner is not None
+        final_local_plan = planner.finish_plan(central_plan)
+        all_reads = storage_reader.read_data(final_local_plan, planner)
+
+        all_reads.wait()
+        return None
+
+    _ = distW.all_gather("read", read_data)
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..280ffdbe518dcf0b0e8f72c64805d170ec2d393b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/state_dict_saver.py
@@ -0,0 +1,288 @@
+import os
+import warnings
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import cast, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.distributed_c10d import _get_default_group
+
+from ._storage_utils import _storage_setup
+from .default_planner import DefaultSavePlanner
+from .metadata import Metadata, STATE_DICT_TYPE
+from .planner import SavePlanner
+from .storage import StorageWriter
+from .utils import _api_bc_check, _DistWrapper, _profile
+
+
+__all__ = ["save_state_dict", "save", "async_save"]
+
+
+def save_state_dict(
+    state_dict: STATE_DICT_TYPE,
+    storage_writer: StorageWriter,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[SavePlanner] = None,
+) -> Metadata:
+    """This method is deprecated. Please switch to 'save'."""
+    warnings.warn(
+        "'save_state_dict' is deprecated and will be removed in future versions."
+        "Please use 'save' instead."
+    )
+
+    storage_writer.reset()
+
+    # TODO: test returning `save` here instead.
+    with _profile():
+        return _save_state_dict(
+            state_dict,
+            storage_writer,
+            process_group,
+            coordinator_rank,
+            no_dist,
+            planner,
+        )
+
+
+@_api_bc_check
+def save(
+    state_dict: STATE_DICT_TYPE,
+    *,
+    checkpoint_id: Union[str, os.PathLike, None] = None,
+    storage_writer: Optional[StorageWriter] = None,
+    planner: Optional[SavePlanner] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Metadata:
+    """
+    Save a distributed model in SPMD style.
+
+    This function is different from ``torch.save()`` as it handles
+    ``ShardedTensor`` , and ``DTensor`` by having each rank only save their local shards.
+
+    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
+    save will call ``state_dict`` before serialization.
+
+    .. warning::
+        There are no guarantees of backwards compatibility across PyTorch versions
+        for saved state_dicts.
+
+    .. warning::
+        If using the `process_group` argument, make sure that only its ranks
+        call `save_state_dict` and that all data in state_dict belong to it.
+
+    .. note::
+        When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of
+        the shard_group should be calling `save_state_dict` and the corresponding process
+        group needs to be passed in.
+
+    .. note::
+        If no process group is available, this function assumes the intention is to save the
+        state_dict in the local process.
+
+    .. note:
+        Rank 0 is assumed to be the coordinator rank.
+
+
+    Args:
+        state_dict (Dict[str, Any]): The state_dict to save.
+        checkpoint_id (Union[str, os.PathLike, None]):
+            The ID of this checkpoint instance. The meaning of the checkpoint_id
+            depends on the storage. It can be a path to a folder or to a file.
+            It can also be a key if the storage is a key-value store.
+            (Default: ``None``)
+        storage_writer (Optional[StorageWriter]):
+            Instance of StorageWriter used to perform writes. If this is not
+            specified, DCP will automatically infer the writer based on the
+            checkpoint_id. If checkpoint_id is also None, an exception will
+            be raised. (Default: ``None``)
+        planner (Optional[SavePlanner]):
+            Instance of SavePlanner. If this is not specified, the default
+            planner will be used. (Default: ``None``)
+        process_group (Optional[ProcessGroup]):
+            ProcessGroup to be used for cross-rank synchronization.
+            (Default: ``None``)
+
+    Returns:
+        Metadata: Metadata object for the saved checkpoint.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> my_model = MyModule()
+
+        >>> state_dict = {"model": my_model}
+
+        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+        >>> torch.distributed.checkpoint.save(
+        >>>     state_dict=state_dict,
+        >>>     storage_writer=fs_storage_writer,
+        >>> )
+
+    .. note::
+        ``save`` uses collectives to coordinate writes across ranks.
+        For NCCL-based process groups, internal tensor representations of
+        objects must be moved to the GPU device before communication takes place.
+        In this case, the device used is given by ``torch.cuda.current_device()``
+        and it is the user's responsibility to ensure that this is set so that
+        each rank has an individual GPU, via ``torch.cuda.set_device()``.
+    """
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.save")
+
+    no_dist = not (dist.is_available() and dist.is_initialized())
+    if no_dist:
+        warnings.warn(
+            "torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
+        )
+
+    with _profile():
+        storage_writer = cast(
+            StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
+        )
+
+        return _save_state_dict(
+            state_dict=_stateful_to_state_dict(state_dict),
+            storage_writer=storage_writer,
+            process_group=process_group,
+            no_dist=no_dist,
+            planner=planner,
+        )
+
+
+def async_save(
+    state_dict: STATE_DICT_TYPE,
+    *,
+    checkpoint_id: Union[str, os.PathLike, None] = None,
+    storage_writer: Optional[StorageWriter] = None,
+    planner: Optional[SavePlanner] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Future:
+    """Asynchronous version of ``save_state_dict``. This code first de-stages the state_dict on CPU, and then calls
+    `save` in a separate thread.
+
+    .. warning::
+        This feature is experimental and subject to change.
+
+    Args:
+        state_dict (Dict[str, Any]): The state_dict to save.
+        checkpoint_id (Union[str, os.PathLike, None]):
+            The ID of this checkpoint instance. The meaning of the checkpoint_id
+            depends on the storage. It can be a path to a folder or to a file.
+            It can also be a key if the storage is a key-value store.
+            (Default: ``None``)
+        storage_writer (Optional[StorageWriter]):
+            Instance of StorageWriter used to perform writes. If this is not
+            specified, DCP will automatically infer the writer based on the
+            checkpoint_id. If checkpoint_id is also None, an exception will
+            be raised. (Default: ``None``)
+        planner (Optional[SavePlanner]):
+            Instance of SavePlanner. If this is not specified, the default
+            planner will be used. (Default: ``None``)
+        process_group (Optional[ProcessGroup]):
+            ProcessGroup to be used for cross-rank synchronization.
+            (Default: ``None``)
+
+    Returns:
+        Future: A future holding the resultant Metadata object from `save`.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> my_model = MyModule()
+
+        >>> state_dict = {"model": my_model}
+
+        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+        >>> checkpoint_future = torch.distributed.checkpoint.async_save(
+        >>>     state_dict=state_dict,
+        >>>     storage_writer=fs_storage_writer,
+        >>> )
+        >>>
+        >>> # ... do some work ...
+        >>>
+        >>> checkpoint_future.result()
+
+    """
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")
+
+    pg = process_group or _get_default_group()
+    assert (
+        torch.device("cpu") in pg._device_types  # type: ignore[attr-defined]
+    ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:ncc'"
+
+    cpu_state_dict = _offload_state_dict_to_cpu(_stateful_to_state_dict(state_dict))
+
+    executor = ThreadPoolExecutor(max_workers=1)
+    f = executor.submit(
+        save,
+        cpu_state_dict,
+        checkpoint_id=checkpoint_id,
+        storage_writer=storage_writer,
+        planner=planner,
+        process_group=process_group,
+    )
+    f.add_done_callback(lambda f: executor.shutdown(wait=False))
+
+    return f
+
+
+def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+    """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
+    stateful_state_dict = {}
+    for key, elem in state_dict.items():
+        stateful_state_dict[key] = (
+            elem.state_dict() if isinstance(elem, Stateful) else elem
+        )
+    return stateful_state_dict
+
+
+def _save_state_dict(
+    state_dict: STATE_DICT_TYPE,
+    storage_writer: StorageWriter,
+    process_group: Optional[dist.ProcessGroup] = None,
+    coordinator_rank: int = 0,
+    no_dist: bool = False,
+    planner: Optional[SavePlanner] = None,
+) -> Metadata:
+    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
+
+    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
+    if planner is None:
+        planner = DefaultSavePlanner()
+    assert planner is not None
+
+    global_metadata = None
+
+    def local_step():
+        assert planner is not None
+        planner.set_up_planner(state_dict, distW.is_coordinator)
+        storage_writer.set_up_storage_writer(distW.is_coordinator)
+        local_plan = planner.create_local_plan()
+        local_plan = storage_writer.prepare_local_plan(local_plan)
+        return local_plan
+
+    def global_step(all_local_plans):
+        nonlocal global_metadata
+
+        assert planner is not None
+        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
+        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
+        return all_local_plans
+
+    central_plan = distW.reduce_scatter("plan", local_step, global_step)
+
+    def write_data():
+        assert planner is not None
+        final_local_plan = planner.finish_plan(central_plan)
+        all_writes = storage_writer.write_data(final_local_plan, planner)
+
+        all_writes.wait()
+        return all_writes.value()
+
+    def finish_checkpoint(all_results):
+        assert global_metadata is not None
+        storage_writer.finish(metadata=global_metadata, results=all_results)
+        return global_metadata
+
+    return distW.all_reduce("write", write_data, finish_checkpoint)
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/stateful.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/stateful.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09cb8fffd1379c4b8c3bf22dd41c68b93200c49
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/stateful.py
@@ -0,0 +1,43 @@
+from typing import Any, Dict, runtime_checkable, TypeVar
+
+from typing_extensions import Protocol
+
+
+__all__ = ["Stateful", "StatefulT"]
+
+
+@runtime_checkable
+class Stateful(Protocol):
+    """
+    Stateful protocol for objects that can be checkpointed and restored.
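+
+    Example (a minimal sketch; the ``Counter`` class below is illustrative and not
+    part of this module)::
+
+        >>> # xdoctest: +SKIP
+        >>> from typing import Any, Dict
+        >>> class Counter:
+        ...     def __init__(self) -> None:
+        ...         self.steps = 0
+        ...     def state_dict(self) -> Dict[str, Any]:
+        ...         return {"steps": self.steps}
+        ...     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        ...         self.steps = state_dict["steps"]
+        >>> isinstance(Counter(), Stateful)  # True, thanks to @runtime_checkable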
+    """
+
+    def state_dict(self) -> Dict[str, Any]:
+        """
+        Objects should return their state_dict representation as a dictionary.
+        The output of this function will be checkpointed, and later restored in
+        `load_state_dict()`.
+
+        .. warning::
+            Because of the inplace nature of restoring a checkpoint, this function
+            is also called during `torch.distributed.checkpoint.load`.
+
+
+        Returns:
+            Dict: The objects state dict
+        """
+
+        ...
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Restore the object's state from the provided state_dict.
+
+        Args:
+            state_dict: The state dict to restore from
+        """
+
+        ...
+
+
+StatefulT = TypeVar("StatefulT", bound=Stateful)
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/storage.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b828c0b9cda33e233947473ab07753cad64346
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/storage.py
@@ -0,0 +1,279 @@
+import abc
+import os
+from dataclasses import dataclass
+from typing import Any, List, Union
+
+from torch.futures import Future
+
+from .metadata import Metadata, MetadataIndex
+from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
+
+__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
+
+
+@dataclass(frozen=True)
+class WriteResult:
+    index: MetadataIndex
+
+    size_in_bytes: int
+    storage_data: Any
+
+
+class StorageWriter(abc.ABC):
+    """
+    Interface used by ``save_state_dict`` to write to storage.
+
+    One StorageWriter instance acts as both the coordinator and the follower
+    in a distributed checkpoint. As part of initialization, each instance
+    is told its role.
+
+    A subclass should expect the following sequence of calls.
+
+    0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
+    1) (all ranks) set_up_storage_writer()
+    2) (all ranks) prepare_local_plan()
+    3) (coordinator) prepare_global_plan()
+    4) (all ranks) write_data()
+    5) (coordinator) finish()
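+
+    A skeletal subclass, shown purely as an illustrative sketch (the
+    ``MyStorageWriter`` name and the trivial method bodies are assumptions,
+    not a reference implementation)::
+
+        >>> # xdoctest: +SKIP
+        >>> class MyStorageWriter(StorageWriter):
+        ...     def reset(self, checkpoint_id=None): ...
+        ...     def set_up_storage_writer(self, is_coordinator): ...
+        ...     def prepare_local_plan(self, plan): return plan
+        ...     def prepare_global_plan(self, plans): return plans
+        ...     def write_data(self, plan, planner): ...
+        ...     def finish(self, metadata, results): ...
+        ...     @classmethod
+        ...     def validate_checkpoint_id(cls, checkpoint_id): return True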
+    """
+
+    @abc.abstractmethod
+    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
+        """
+        Called to indicate that a brand-new checkpoint write is about to happen.
+        A checkpoint_id may be present if users set the checkpoint_id for
+        this checkpoint write. The meaning of the checkpoint_id is
+        storage-dependent. It can be a path to a folder/file or a key for
+        a key-value storage.
+
+        Args:
+            checkpoint_id (Union[str, os.PathLike, None]):
+                The ID of this checkpoint instance. The meaning of the checkpoint_id
+                depends on the storage. It can be a path to a folder or to a file.
+                It can also be a key if the storage is a key-value store.
+                (Default: ``None``)
+        """
+        ...
+
+    @abc.abstractmethod
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
+        """
+        Initialize this instance.
+
+        Args:
+            is_coordinator (bool): Whether this instance is responsible for coordinating
+              the checkpoint.
+        """
+        pass
+
+    @abc.abstractmethod
+    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
+        """
+        Perform storage-specific local planning.
+
+        While this method can produce a completely different plan, the recommended
+        way is to store storage specific data in SavePlan::storage_data.
+
+        Args:
+            plan (SavePlan): The local plan from the ``SavePlanner`` in use.
+
+        Returns:
+            A transformed ``SavePlan`` after storage local planning
+        """
+        pass
+
+    @abc.abstractmethod
+    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
+        """
+        Perform centralized planning of storage.
+
+        This method is only called on the coordinator instance.
+
+        While this method can produce a completely different plan, the preferred
+        way is to store storage specific data in SavePlan::storage_data.
+
+        Args:
+            plans: A list of ``SavePlan`` instances, one for each rank.
+
+        Returns:
+            A list of transformed ``SavePlan`` after storage global planning
+        """
+        pass
+
+    @abc.abstractmethod
+    def write_data(
+        self, plan: SavePlan, planner: SavePlanner
+    ) -> Future[List[WriteResult]]:
+        """
+        Write all items from ``plan`` using ``planner`` to resolve the data.
+
+        A subclass should call ``SavePlanner::resolve_data`` on each item
+        from the plan to get access to the underlying object to write.
+
+        Subclasses should lazily call `resolve_data` as it can allocate memory.
+        In the case of tensors, make the following assumptions:
+
+        - They might be on any device, including not matching the one on ``WriteItem::tensor_data``
+        - They might be views or not contiguous. Only the projection needs to be saved.
+
+        Args:
+            plan (SavePlan): The save plan to execute.
+            planner (SavePlanner): Planner object to be used to resolve items to data.
+
+        Returns:
+            A future that completes to a list of WriteResult
+        """
+        pass
+
+    @abc.abstractmethod
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+        """
+        Write the metadata and mark the current checkpoint as successful.
+
+        The actual format/schema used for serializing `metadata` is an
+        implementation detail. The only requirement is that it can be restored
+        into the same object graph.
+
+        Args:
+            metadata (Metadata): metadata for the new checkpoint
+            results: A list of WriteResults from all ranks.
+
+        Returns:
+            None
+        """
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        """
+        Check if the given checkpoint_id is supported by the storage. This allows
+        us to enable automatic storage selection.
+        """
+        ...
+
+
+class StorageReader(abc.ABC):
+    """
+    Interface used by ``load_state_dict`` to read from storage.
+
+    One StorageReader instance acts as both the coordinator and the follower
+    in a distributed checkpoint. As part of initialization, each instance
+    is told its role.
+
+    A subclass should expect the following sequence of calls from ``load_state_dict``:
+
+    0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
+    1) (all ranks) read_metadata()
+    2) (all ranks) set_up_storage_reader()
+    3) (all ranks) prepare_local_plan()
+    4) (coordinator) prepare_global_plan()
+    5) (all ranks) read_data()
+    """
+
+    @abc.abstractmethod
+    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
+        """
+        Called to indicate that a brand-new checkpoint read is about to happen.
+        A checkpoint_id may be present if users set the checkpoint_id for
+        this checkpoint read. The meaning of the checkpoint_id is
+        storage-dependent. It can be a path to a folder/file or a key for
+        a key-value storage.
+
+        Args:
+            checkpoint_id (Union[str, os.PathLike, None]):
+                The ID of this checkpoint instance. The meaning of the checkpoint_id
+                depends on the storage. It can be a path to a folder or to a file.
+                It can also be a key if the storage is more like a key-value store.
+                (Default: ``None``)
+        """
+        ...
+
+    @abc.abstractmethod
+    def read_metadata(self) -> Metadata:
+        """
+        Read the checkpoint metadata.
+
+        Returns:
+            The metadata object associated with the checkpoint being loaded.
+
+        """
+        pass
+
+    @abc.abstractmethod
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+        """
+        Initialize this instance.
+
+        Args:
+            metadata (Metadata): The metadata schema to use.
+            is_coordinator (bool): Whether this instance is responsible for coordinating
+              the checkpoint.
+        """
+        pass
+
+    @abc.abstractmethod
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        """
+        Perform storage-specific local planning.
+
+        While this method can produce a completely different plan, the recommended
+        way is to store storage specific data in LoadPlan::storage_data.
+
+        Args:
+            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.
+
+        Returns:
+            A transformed ``LoadPlan`` after storage local planning
+        """
+        pass
+
+    @abc.abstractmethod
+    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
+        """
+        Perform centralized planning of storage loading.
+
+        This method is only called on the coordinator instance.
+
+        While this method can produce a completely different plan, the preferred
+        way is to store storage specific data in LoadPlan::storage_data.
+
+        Args:
+            plans: A list of ``LoadPlan`` instances, one for each rank.
+
+        Returns:
+            A list of transformed ``LoadPlan`` after storage global planning
+        """
+        pass
+
+    @abc.abstractmethod
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+        """
+        Read all items from ``plan`` using ``planner`` to resolve the data.
+
+        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
+        object into the right place.
+
+        A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
+        tensors that it should load data into.
+
+        It is the StorageLayer's responsibility to properly schedule any cross-device
+        copies required.
+
+        Args:
+            plan (LoadPlan): The local plan to execute on
+            planner (LoadPlanner): The planner object to use to resolve items.
+
+        Returns:
+            A future that completes once all reads are finished.
+        """
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+        """
+        Check if the given checkpoint_id is supported by the storage. This allows
+        us to enable automatic storage selection.
+        """
+        ...
diff --git a/MLPY/Lib/site-packages/torch/distributed/checkpoint/utils.py b/MLPY/Lib/site-packages/torch/distributed/checkpoint/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed0817736a957549f059d40b585c450d37ea6373
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/checkpoint/utils.py
@@ -0,0 +1,429 @@
+import cProfile
+import inspect
+import io
+import itertools
+import os
+import warnings
+from contextlib import contextmanager
+from functools import wraps
+from pstats import Stats
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._tensor import DTensor
+
+from .api import (
+    _is_wrapped_exception,
+    _wrap_exception,
+    CheckpointException,
+    WRAPPED_EXCEPTION,
+)
+from .metadata import MetadataIndex, STATE_DICT_TYPE
+
+__all__ = ["find_tensor_shard", "find_state_dict_object"]
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+
+def _get_failure_dict(
+    results: List[Union[T, WRAPPED_EXCEPTION]]
+) -> Dict[int, WRAPPED_EXCEPTION]:
+    return cast(
+        Dict[int, WRAPPED_EXCEPTION],
+        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
+    )
+
+
+def _all_gather_keys(
+    local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
+) -> List[Any]:
+    """Gathers all keys, and returns them sorted."""
+    keys = list(local_dict.keys())
+    gathered_keys: List[List[Any]] = [None] * dist.get_world_size()  # type: ignore[list-item]
+
+    dist.all_gather_object(gathered_keys, keys, group=group)
+    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
+
+
+class _DistWrapper:
+    """
+    This is a wrapper around a process group that provides a series of features around object collectives.
+
+    It works even when torch.distributed is not initialized, in which case most collectives turn into no-ops.
+
+    All variants that take functions are exception-robust, meaning that if one or more
+    ranks raise errors, all ranks will observe those errors.
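+
+    Usage sketch (``make_local_plan`` and ``merge_plans`` are hypothetical callables,
+    shown only to illustrate the calling pattern)::
+
+        >>> # xdoctest: +SKIP
+        >>> dw = _DistWrapper(group=None, use_dist=dist.is_initialized(), coordinator_rank=0)
+        >>> central_plan = dw.reduce_scatter("plan", make_local_plan, merge_plans)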
+    """
+
+    def __init__(
+        self,
+        group: Optional[dist.ProcessGroup],
+        use_dist: bool,
+        coordinator_rank: int,
+    ):
+        self.group = group
+        self.use_dist = use_dist
+        self.coordinator_rank = coordinator_rank
+        if self.use_dist:
+            self.rank = dist.get_rank(group)
+            self.is_coordinator = self.rank == coordinator_rank
+        else:
+            self.rank = 0
+            self.is_coordinator = True
+
+    def get_rank(self) -> int:
+        return self.rank
+
+    def get_world_size(self) -> int:
+        if self.use_dist:
+            return dist.get_world_size(self.group)
+        return 1
+
+    def broadcast_object(self, object: Optional[T]) -> T:
+        """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled."""
+        object_list = [object]
+        if self.use_dist:
+            dist.broadcast_object_list(
+                object_list=object_list,
+                group=self.group,
+                src=self.coordinator_rank,
+            )
+        return cast(T, object_list[0])
+
+    def gather_object(self, object: T) -> Optional[List[T]]:
+        """Implement functionality similar to c10d::gather_object but without distributed enabled."""
+        if self.use_dist:
+            gather_objs = (
+                cast(List[T], [None] * dist.get_world_size(self.group))
+                if self.is_coordinator
+                else None
+            )
+
+            dist.gather_object(
+                obj=object,
+                object_gather_list=gather_objs if self.is_coordinator else None,
+                dst=self.coordinator_rank,
+                group=self.group,
+            )
+            result = gather_objs
+        else:
+            result = [object]
+        return result
+
+    def all_gather_object(self, object: T) -> List[T]:
+        """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
+        if self.use_dist:
+            gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
+
+            dist.all_gather_object(
+                object_list=gather_objs, obj=object, group=self.group
+            )
+        else:
+            gather_objs = [object]
+        return gather_objs
+
+    def scatter_object(self, object_list: Optional[List[T]]) -> T:
+        """Implement functionality similar to c10d::scatter_object but without distributed enabled."""
+        if self.use_dist:
+            gather_result = cast(List[T], [None])
+            dist.scatter_object_list(
+                scatter_object_output_list=gather_result,
+                scatter_object_input_list=object_list if self.is_coordinator else None,
+                src=self.coordinator_rank,
+                group=self.group,
+            )
+
+            local_reply = gather_result[0]
+        else:
+            assert object_list is not None
+            local_reply = object_list[0]
+        return local_reply
+
+    def reduce_scatter(
+        self,
+        step: str,
+        map_fun: Callable[[], T],
+        reduce_fun: Callable[[List[T]], List[R]],
+    ) -> R:
+        """
+        Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
+
+        This method operates in the following way:
+            Run ``map_fun`` on all ranks
+            Gather results on rank 0
+            Call ``reduce_fun`` on all those values
+            Scatter to each rank part of the result.
+        """
+        local_data: Union[WRAPPED_EXCEPTION, T]
+        try:
+            local_data = map_fun()
+        except BaseException as e:
+            local_data = _wrap_exception(e)
+
+        all_data = self.gather_object(local_data)
+        all_results: Optional[List[Union[R, CheckpointException]]] = None
+        if self.is_coordinator:
+            assert all_data is not None
+            node_failures = _get_failure_dict(all_data)
+
+            if len(node_failures) == 0:
+                try:
+                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
+                    all_results = cast(
+                        List[Union[R, CheckpointException]],
+                        reduce_fun(cast(List[T], all_data)),
+                    )
+                except BaseException as e:
+                    node_failures[self.rank] = _wrap_exception(e)
+
+            if len(node_failures) > 0:
+                all_results = [
+                    CheckpointException(step, node_failures)
+                ] * self.get_world_size()
+
+        result = self.scatter_object(all_results)
+        if isinstance(result, CheckpointException):
+            raise result
+        return result
+
+    def all_reduce(
+        self,
+        step: str,
+        map_fun: Callable[[], T],
+        reduce_fun: Callable[[List[T]], R],
+    ) -> R:
+        """
+        Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
+
+        This method operates in the following way:
+            Run ``map_fun`` on all ranks
+            Gather results on rank 0
+            Call ``reduce_fun`` on all those values
+            Broadcast the reduced value to all ranks.
+        """
+        local_data: Union[T, WRAPPED_EXCEPTION]
+        try:
+            local_data = map_fun()
+        except BaseException as e:
+            local_data = _wrap_exception(e)
+
+        all_data = self.gather_object(local_data)
+        result: Optional[Union[R, CheckpointException]] = None
+        if self.is_coordinator:
+            assert all_data is not None
+            node_failures = _get_failure_dict(all_data)
+            if len(node_failures) == 0:
+                try:
+                    result = reduce_fun(cast(List[T], all_data))
+                except BaseException as e:
+                    node_failures[self.rank] = _wrap_exception(e)
+
+            if len(node_failures) > 0:
+                result = CheckpointException(step, node_failures)
+
+        final_result = self.broadcast_object(result)
+        if isinstance(final_result, CheckpointException):
+            raise final_result
+        return cast(R, final_result)
+
+    def all_gather(
+        self,
+        step: str,
+        map_fun: Callable[[], T],
+    ) -> List[T]:
+        """
+        Compute a value on each rank, then all_gather them.
+
+        This method operates in the following way:
+            Run ``map_fun`` on all ranks
+            all_gather the values to all ranks
+        """
+        result: Union[T, WRAPPED_EXCEPTION]
+        try:
+            result = map_fun()
+        except BaseException as e:
+            result = _wrap_exception(e)
+
+        all_results = self.all_gather_object(result)
+
+        node_failures = _get_failure_dict(all_results)
+        if len(node_failures) > 0:
+            raise CheckpointException(step, node_failures)
+        return cast(List[T], all_results)
+
+    def broadcast(
+        self,
+        step: str,
+        map_fun: Callable[[], T],
+    ) -> T:
+        """
+        Compute a value on rank 0 and broadcast it.
+
+        This method operates in the following way:
+            Run ``map_fun`` on rank 0
+            broadcast the value
+        """
+        result: Optional[Union[T, CheckpointException]] = None
+        if self.is_coordinator:
+            try:
+                result = map_fun()
+            except BaseException as e:
+                result = CheckpointException(step, {self.rank: _wrap_exception(e)})
+        final_result = self.broadcast_object(result)
+        if isinstance(final_result, CheckpointException):
+            raise final_result
+        return cast(T, final_result)
+
+
+def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
+    if index.offset is None:
+        raise ValueError(
+            f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided"
+        )
+
+    shards = tensor.local_shards()
+    # index fast path
+    if index.index is not None:
+        if (
+            len(shards) > index.index
+            and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
+        ):
+            return shards[index.index]
+
+    for shard in shards:
+        if torch.Size(shard.metadata.shard_offsets) == index.offset:
+            return shard
+    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
+
+
+def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return tensor.to_local()
+    if isinstance(tensor, ShardedTensor):
+        return _find_shard(tensor, index).tensor
+    if index.offset is not None:
+        # special case looking up a tensor by origin
+        if index.offset == torch.Size([0] * len(tensor.size())):
+            return tensor
+        raise ValueError(
+            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
+        )
+    return tensor
+
+
+def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
+    if index.fqn not in state_dict:
+        raise ValueError(f"Could not find FQN: '{index.fqn}'")
+    obj = state_dict[index.fqn]
+
+    if isinstance(obj, torch.Tensor):
+        return find_tensor_shard(obj, index)
+    elif index.offset is not None:
+        raise ValueError(
+            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
+        )
+    return obj
+
+
+def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
+    return [i_a + i_b for i_a, i_b in zip(a, b)]
+
+
+def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
+    return [i_a - i_b for i_a, i_b in zip(a, b)]
+
+
+class _ReaderView(io.IOBase):
+    def __init__(self, base_stream: io.IOBase, offset: int, len: int):
+        super().__init__()
+        self.offset = offset
+        self.len = len
+        self.base_stream = base_stream
+        self.seek(0)
+
+    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
+        if __whence == os.SEEK_SET:
+            __offset = self.offset + __offset
+        elif __whence == os.SEEK_END:
+            __whence = os.SEEK_SET
+            __offset = (self.offset + self.len) - __offset
+        return self.base_stream.seek(__offset, __whence)
+
+    def tell(self) -> int:
+        return self.base_stream.tell() - self.offset
+
+    def readable(self) -> bool:
+        return self.base_stream.readable()
+
+    def seekable(self) -> bool:
+        return self.base_stream.seekable()
+
+    def readinto(self, b):
+        return self.base_stream.readinto(b)  # type: ignore[attr-defined]
+
+    def read(self, size=-1):
+        return self.base_stream.read(size)
+
+
+def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
+    # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
+    return _ReaderView(file, offset, length)
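+
+
+# Usage sketch (illustrative only; "ckpt.bin" is a made-up file name): expose a
+# 64-byte window starting at byte 128 of an open binary stream as its own
+# read-only stream:
+#
+#     view = _create_file_view(open("ckpt.bin", "rb"), offset=128, length=64)
+#     payload = view.read(64)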
+
+
+def _normalize_device_info(device_type: str, device_id: int) -> str:
+    """Device info normalization."""
+    if device_type == "cpu":
+        return "cpu"
+    return f"{device_type}:{device_id}"
+
+
+# TODO: integrate with distributed logging flag
+ENABLE_PROFILE = False
+
+
+@contextmanager
+def _profile():
+    # Only log the profiling when it is enabled and we are on rank 0, or dist is not
+    # available.
+    if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
+        profiler = cProfile.Profile()
+        profiler.enable()
+        try:
+            yield
+        finally:
+            profiler.disable()
+            stats = Stats(profiler)
+            stats.sort_stats("time").print_stats(10)
+    else:
+        yield
+
+
+def _api_bc_check(func):
+    @wraps(func)
+    def inner_func(*args, **kwargs) -> Any:
+        if len(args) == 2:
+            warnings.warn(
+                f"The argument order of {func.__name__} has been changed. "
+                "Please check the document to avoid future breakages."
+            )
+            sig = inspect.signature(func)
+            kwonlyargs = [
+                p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
+            ]
+            if "storage_writer" in kwonlyargs:
+                assert "storage_writer" not in kwargs, (args, kwargs)
+                kwargs["storage_writer"] = args[1]
+            elif "storage_reader" in kwonlyargs:
+                assert "storage_reader" not in kwargs, (args, kwargs)
+                kwargs["storage_reader"] = args[1]
+            else:
+                raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
+            return func(args[0], **kwargs)
+        else:
+            return func(*args, **kwargs)
+
+    return inner_func
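+
+
+# Usage sketch for ``_api_bc_check`` (illustrative only): a function declared as
+#
+#     @_api_bc_check
+#     def save(state_dict, *, storage_writer=None): ...
+#
+# still accepts the legacy positional call ``save(state_dict, writer)``; the wrapper
+# rewrites it to ``save(state_dict, storage_writer=writer)`` and emits a warning.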
diff --git a/MLPY/Lib/site-packages/torch/distributed/collective_utils.py b/MLPY/Lib/site-packages/torch/distributed/collective_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dbe0310e63433d15a957388282eac55431ed270
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/collective_utils.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python3
+
+
+"""
+A set of primitive functions for performing collective ops.
+
+Each should also handle single rank scenario.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar, Union
+
+import torch.distributed as dist
+
+T = TypeVar("T")
+
+@dataclass
+class SyncPayload(Generic[T]):
+    stage_name: Optional[str]
+    success: bool
+    payload: T
+    exception: Optional[Exception] = None
+
+def broadcast(
+    data_or_fn: Union[T, Callable[[], T]],
+    *,
+    success: bool = True,
+    stage_name: Optional[str] = None,
+    rank: int = 0,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> T:
+    """
+    Broadcasts the data payload from the given rank to all other ranks.
+    Or, if a function is passed, executes it on that rank and broadcasts the result to all other ranks.
+
+    Can be used to broadcast a failure signal to stop all ranks.
+
+    If the function raises an exception, all ranks will raise.
+
+    Args:
+        data_or_fn: the data to broadcast or function to execute and broadcast result.
+        success: False to stop all ranks.
+        stage_name: the name of the logical stage for synchronization and debugging
+        rank: rank to broadcast the data from, or to execute the function on and broadcast its result.
+        pg: the process group for sync
+    Throws:
+        RuntimeError from original exception trace
+    Returns:
+        the value after synchronization
+
+    Example usage:
+    >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg)
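+
+    To broadcast a failure signal from the broadcasting rank (a sketch; the stage
+    name is illustrative):
+    >> broadcast(data_or_fn=None, success=False, stage_name="allocate_id", pg=ext_pg.my_pg)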
+    """
+
+    if not success and data_or_fn is not None:
+        raise AssertionError("Data or Function is expected to be None if not successful")
+
+    payload: Optional[T] = None
+    exception : Optional[Exception] = None
+    # if no pg is passed then execute if rank is 0
+    if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
+        # determine if it is an executable function or data payload only
+        if callable(data_or_fn):
+            try:
+                payload = data_or_fn()
+            except Exception as e:
+                success = False
+                exception = e
+        else:
+            payload = data_or_fn
+
+    # broadcast the exception type if any to all ranks for failure categorization
+    sync_obj = SyncPayload(
+        stage_name=stage_name,
+        success=success,
+        payload=payload,
+        exception=exception,
+    )
+
+    if pg is not None:
+        broadcast_list = [sync_obj]
+        dist.broadcast_object_list(broadcast_list, src=rank, group=pg)
+        assert len(broadcast_list) == 1
+        sync_obj = broadcast_list[0]
+
+    # failure in any rank will trigger a throw in every rank.
+    if not sync_obj.success:
+        error_msg = f"Rank {rank} failed"
+        if stage_name is not None:
+            error_msg += f": stage {sync_obj.stage_name}"
+        if sync_obj.exception is not None:
+            error_msg += f": exception {sync_obj.exception}"
+        raise RuntimeError(error_msg) from sync_obj.exception
+
+    return cast(T, sync_obj.payload)
+
+
+def all_gather(
+    data_or_fn: Union[T, Callable[[], T]],
+    stage_name: Optional[str] = None,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> List[T]:
+    """
+    A simple all_gather primitive with basic synchronization guard logic,
+    which checks that the payload from all ranks has the same stage name.
+
+    Args:
+        data_or_fn: the data to be all gathered across ranks or function to be executed
+        stage_name: the sync stage name for out-of-sync protection
+        pg: the process group for sync
+    Throws:
+        RuntimeError from original exception trace
+    Returns:
+        a list of synced data from all ranks
+
+    Example usage:
+    >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
+    """
+    payload: Optional[T] = None
+    exception : Optional[Exception] = None
+    success = True
+    # determine if it is an executable function or data payload only
+    if callable(data_or_fn):
+        try:
+            payload = data_or_fn()
+        except Exception as e:
+            success = False
+            exception = e
+    else:
+        payload = data_or_fn
+
+    sync_obj = SyncPayload(
+        stage_name=stage_name,
+        success=success,
+        payload=payload,
+        exception=exception,
+    )
+
+    if pg is not None:
+        # List of success/failure across all ranks.
+        total_list = [None] * dist.get_world_size(pg)
+        all_gather_object_enforce_type(pg, total_list, sync_obj)
+        # Each rank will throw RuntimeError in case of failure on any rank.
+        stage_name = cast(SyncPayload[T], total_list[0]).stage_name
+        exception_list: List[Tuple[int, Exception]] = []
+        ret_list: List[T] = []
+        error_msg: str = ""
+
+        for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)):
+            if sp.stage_name != stage_name:
+                error_msg += (
+                    f"Unexpected stage name received from rank {i}: {sp.stage_name} "
+                )
+                continue
+            if not sp.success and sp.exception is not None:
+                exception_list.append((i, sp.exception))
+                continue
+            ret_list.append(sp.payload)
+
+        if len(exception_list) > 0:
+            raise RuntimeError(  # type: ignore[misc]
+                error_msg, exception_list) from exception_list[0][1]
+        return ret_list
+    else:
+        if not sync_obj.success:
+            raise RuntimeError(
+                f"all_gather failed with exception {sync_obj.exception}",
+            ) from sync_obj.exception
+        return [sync_obj.payload]  # type: ignore[list-item]
+
+
+# Note: use Any for typing for now so users can pass in
+# either a list of None or target type placeholders
+# otherwise pyre would complain
+def all_gather_object_enforce_type(
+    pg: dist.ProcessGroup,
+    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
+    object_list: List[Any],
+    # pyre-fixme[2]: Parameter must have a type other than `Any`
+    obj: Any,
+    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
+    type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y),
+) -> None:
+    """
+    Similar to plain all_gather_object but with additional type checking
+    AFTER the gather is done to ensure basic consistency.
+    If the check does not pass, all ranks will fail with an exception.
+
+    This is generally meant to prevent conditional logic leading to
+    unexpected messages being received. This is considered a fatal code error,
+    but due to layered logic it might happen implicitly in practice.
+
+    The default check does not consider subtypes (considered different)
+    or covariance (considered the same), but users can pass in a custom checker
+    if a more complicated check is needed.
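+
+    Example sketch (the gathered value and the checker are illustrative only)::
+
+        >>> # xdoctest: +SKIP
+        >>> out = [None] * dist.get_world_size(pg)
+        >>> all_gather_object_enforce_type(
+        ...     pg, out, {"rank": dist.get_rank(pg)},
+        ...     type_checker=lambda x, y: isinstance(x, dict) and isinstance(y, dict),
+        ... )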
+    """
+    dist.all_gather_object(object_list, obj, group=pg)
+
+    # conservative check
+    list_len = len(object_list)
+    if list_len == 0:
+        return
+    first_obj = object_list[0]
+    for i in range(1, list_len):
+        if not type_checker(first_obj, object_list[i]):
+            raise TypeError(
+                f"Object type at index {i} is {type(object_list[i])}, "
+                f"while first object type is {type(first_obj)}"
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/constants.py b/MLPY/Lib/site-packages/torch/distributed/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..102d7bf100080087a867ada7421dc96aef8b972d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/constants.py
@@ -0,0 +1,23 @@
+from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+from datetime import timedelta
+from typing import Optional
+
+__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout']
+
+# Default process group wide timeout, if applicable.
+# This only applies to the non-nccl backends
+# To make an attempt at backwards compatibility with THD, we use an
+# extraordinarily high default timeout, given that THD did not have timeouts.
+default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
+# Separate timeout for PGNCCL mainly because it's always been that way in the C++ layer, but until recently
+# there was one default that applied across all backends in the python layer.
+# Later, we could consider merging them back together at the c++ layer if we can align on the same value.
+# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1).
+
+try:
+    from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
+    default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
+except ImportError:
+    # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
+    # if anyone is actually trying to use nccl in this state, it should error.
+    default_pg_nccl_timeout = None
diff --git a/MLPY/Lib/site-packages/torch/distributed/device_mesh.py b/MLPY/Lib/site-packages/torch/distributed/device_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..192516aa033ee10533b498d1e649982f3a2beac6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/device_mesh.py
@@ -0,0 +1,567 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import logging
+import math
+from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+
+import torch
+
+from torch.distributed import is_available
+
+from ..utils._typing_utils import not_none
+
+__all__ = ["init_device_mesh", "DeviceMesh"]
+
+
+if not is_available():
+    import sys
+
+    # We need to create the stubs when distributed is not available.
+    # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
+    # since it would try to import ``torch.distributed.device_mesh`` or
+    # ``torch.distributed.init_device_mesh`` but cannot find them.
+
+    class _DeviceMeshStub:
+        pass
+
+    def _init_device_mesh_stub():
+        pass
+
+    sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub  # type: ignore[attr-defined]
+    sys.modules[
+        "torch.distributed.device_mesh"
+    ].init_device_mesh = _init_device_mesh_stub  # type: ignore[attr-defined]
+
+
+else:
+    from torch.distributed.distributed_c10d import (
+        _find_pg_by_ranks_and_tag,
+        _get_default_group,
+        _get_group_tag,
+        get_rank,
+        get_world_size,
+        init_process_group,
+        is_initialized,
+        new_group,
+        ProcessGroup,
+    )
+
+    logger = logging.getLogger(__name__)
+
+    # only import numpy typing when type checking
+    if TYPE_CHECKING:
+        try:
+            from numpy.typing import ArrayLike
+        except ImportError:
+            logger.warning(
+                "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
+            )
+
+    class _MeshEnv:
+        def __init__(self) -> None:
+            self.mesh_stack: List[DeviceMesh] = []
+            self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
+            self.parent_to_child_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
+
+        def get_current_mesh(self) -> "DeviceMesh":
+            if len(self.mesh_stack) == 0:
+                raise RuntimeError("No device mesh is currently active!")
+            return self.mesh_stack[-1]
+
+        def create_child_mesh(
+            self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
+        ) -> "DeviceMesh":
+            # Directly return the child mesh if it is already created.
+            child_mesh_mappings = self.parent_to_child_mapping.get(device_mesh)
+            if child_mesh_mappings:
+                sub_mesh = child_mesh_mappings.get(mesh_dim_name)
+                if sub_mesh:
+                    return sub_mesh
+
+            # swap the current dim to the last dim then reshape to flatten out other
+            # dims, so we can just extract the list of ranks which contains cur_rank.
+            cur_rank = device_mesh.get_rank()
+            pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
+                -1, device_mesh.mesh.size(mesh_dim)
+            )
+
+            for mesh_1d in pg_ranks_by_dim:
+                sub_mesh = DeviceMesh(
+                    device_mesh.device_type,
+                    mesh_1d,
+                    mesh_dim_names=(mesh_dim_name,),
+                    _init_backend=False,
+                )
+                if cur_rank in mesh_1d:
+                    res_sub_mesh = sub_mesh
+
+            res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]  # type: ignore[possibly-undefined]
+            # Assign the current DeviceMesh as the parent of the child DeviceMesh.
+            self.child_to_parent_mapping[res_sub_mesh] = device_mesh
+            self.parent_to_child_mapping.setdefault(device_mesh, {})[
+                mesh_dim_name
+            ] = res_sub_mesh
+            return res_sub_mesh
+
+        def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
+            return self.child_to_parent_mapping.get(device_mesh, None)
+
+        def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
+            """
+            Return the index of the mesh dim in the parent mesh.
+            The device_mesh passed in needs to be sliced out from a parent mesh.
+            """
+            parent_mesh = self.get_parent_mesh(device_mesh)
+            child_mesh_dim_names = device_mesh.mesh_dim_names
+            if parent_mesh and child_mesh_dim_names:
+                assert (
+                    len(child_mesh_dim_names) == 1
+                ), "The child mesh can only be a 1D mesh."
+                child_mesh_dim_name = child_mesh_dim_names[0]
+                return self.get_mesh_dim_by_name(parent_mesh, child_mesh_dim_name)
+            return None
+
+        @staticmethod
+        def num_devices_per_host(device_type: str) -> int:
+            return _get_device_handle(device_type).device_count()
+
+        @staticmethod
+        def num_hosts(device_type: str) -> int:
+            # ProcessGroup can't tell us this info so we have to infer it, assume
+            # homogeneous hardware for now
+            return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
+
+        def get_mesh_dim_by_name(
+            self, device_mesh: "DeviceMesh", mesh_dim_name: str
+        ) -> int:
+            if (
+                device_mesh.mesh_dim_names is None
+                or len(device_mesh.mesh_dim_names) == 0
+            ):
+                raise KeyError(
+                    "No `mesh_dim_names` found.",
+                )
+            if mesh_dim_name not in device_mesh.mesh_dim_names:
+                raise KeyError(
+                    f"Mesh dimension '{mesh_dim_name}' does not exist.",
+                    f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}",
+                )
+            return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name))
+
+    _mesh_resources: _MeshEnv = _MeshEnv()
+
+    def _get_device_handle(device_type: str = "cuda"):
+        """
+        Get the module corresponding to the device_type which is cuda or cuda-like device.
+        For example, when the device_type is cuda, the module `torch.cuda` is returned.
+        Return None when there is no corresponding module for device_type, otherwise
+        return the corresponding module.
+        """
+        return getattr(torch, device_type, None)
+
+    class DeviceMesh:
+        """
+        DeviceMesh represents a mesh of devices, where the layout of devices can be
+        represented as an n-dimensional array, and each value of the n-dimensional
+        array is the global ID of a rank in the default process group.
+
+        DeviceMesh could be used to describe the layout of devices across the cluster,
+        and serves as a proxy for communication among the device lists within the cluster.
+
+        DeviceMesh can be used as a context manager.
+
+        .. note::
+            DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program
+            runs on all processes/ranks in the cluster. Therefore, users need to make sure the
+            `mesh` array (which describes the layout of devices) is identical across all ranks.
+            An inconsistent `mesh` will lead to silent hangs.
+
+        Args:
+            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
+            mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
+                of devices, where the IDs are global IDs of the default process group.
+
+        Returns:
+            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
+
+        The following program runs on each process/rank in an SPMD manner. In this example, we have 2
+        hosts with 4 GPUs each.
+        A reduction over the first dimension of mesh will reduce across
+        columns (0, 4), ..., (3, 7); a reduction over the second dimension
+        of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
+
+        Example::
+            >>> # xdoctest: +SKIP("no rank")
+            >>> from torch.distributed.device_mesh import DeviceMesh
+            >>>
+            >>> # Initialize device mesh as (2, 4) to represent the topology
+            >>> # of cross-host(dim 0), and within-host (dim 1).
+            >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
+        """
+
+        device_type: str
+        mesh: torch.Tensor
+        mesh_dim_names: Optional[Tuple[str, ...]]
+
+        def __init__(
+            self,
+            device_type: str,
+            mesh: Union[torch.Tensor, "ArrayLike"],
+            *,
+            mesh_dim_names: Optional[Tuple[str, ...]] = None,
+            _init_backend: bool = True,
+        ) -> None:
+            self.device_type = device_type
+            if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
+                raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
+            self.mesh = (
+                mesh.detach().cpu()
+                if isinstance(mesh, torch.Tensor)
+                else torch.tensor(mesh, dtype=torch.int)
+            )
+            self.mesh_dim_names = mesh_dim_names
+
+            # private field to pre-generate DeviceMesh's hash
+            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+            self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
+
+            # Skip process group initialization if xla device or init backend is False
+            # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
+            if device_type != "xla":
+                # always try to create default (world) pg, even if it is not initialized
+                # already. The world pg is used for device mesh identity (rank) on each
+                # process (we need to know if the current global rank is in the mesh or not).
+                if _init_backend:
+                    self._get_or_create_default_group()
+                    self._init_process_groups()
+
+                # calculate the coordinates of the current global rank on the mesh
+                rank_coords = (self.mesh == get_rank()).nonzero()
+                assert rank_coords.size(0) in (0, 1)
+                self._coordinate_on_dim: Optional[List[int]] = (
+                    rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
+                )
+
+        def _get_or_create_default_group(self):
+            default_initialized = is_initialized()
+            if not default_initialized:
+                init_process_group()
+
+            world_size = get_world_size()
+            if self.mesh.numel() > world_size:
+                raise RuntimeError(
+                    f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
+                )
+
+            device_handle = _get_device_handle(self.device_type)
+            # TODO: if user want to pass pg_options, offer a way to do it
+            if not default_initialized and device_handle:
+                # automatically set the current cuda/cuda-like device based on the number of GPU devices available on each host
+                # NOTE: This device selection would only work for homogeneous hardware.
+                num_devices_per_host = device_handle.device_count()
+                if (
+                    world_size > num_devices_per_host
+                    and world_size % num_devices_per_host != 0
+                ):
+                    raise RuntimeError(
+                        f"DeviceMesh only supports homogeneous hardware, but found "
+                        f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+                    )
+                device_handle.set_device(get_rank() % num_devices_per_host)
+
+            return _get_default_group()
+
+        def _init_process_groups(self):
+            # tag/ranks/group_name associated with each mesh dimension, each
+            # mesh dimension should have one sub-group per rank
+            #
+            # TODO(yifu): remove tag and ranks once we fully migrate to native
+            # functional collectives. See details in:
+            # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
+            dim_group_infos: List[Tuple[str, List[int], str]] = []
+
+            if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
+                # if the mesh is the same as world_pg, we just append the default
+                # pg to the first dim groups, as new_group cannot have the exact
+                # same ranks as world
+                dim_group_infos.append(
+                    (
+                        _get_group_tag(_get_default_group()),
+                        list(range(get_world_size())),
+                        _get_default_group().group_name,
+                    )
+                )
+            else:
+                # create sub pgs based on the mesh argument specified
+                for dim in range(self.mesh.ndim):
+                    # swap the current dim to the last dim
+                    # then reshape to flatten out other dims
+                    pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
+                        -1, self.mesh.size(dim)
+                    )
+                    # multi-dim mesh, create subgroups by looping over the pg_ranks
+                    # for each dim and append the groups
+                    for dim_mesh in pg_ranks_by_dim:
+                        subgroup_ranks = dim_mesh.tolist()
+
+                        # We temporarily revert the re-use subgroup, since it breaks two internal tests.
+                        # Temporarily reverting to resolve test timeout while root-causing.
+                        # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
+                        dim_group = new_group(ranks=subgroup_ranks)
+
+                        # only add to dim_groups if the current rank in the subgroup
+                        if self.get_rank() in subgroup_ranks:
+                            if len(dim_group_infos) > dim:
+                                raise RuntimeError(
+                                    f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
+                                    f"in {subgroup_ranks}!"
+                                )
+                            dim_group_infos.append(
+                                (
+                                    _get_group_tag(not_none(dim_group)),
+                                    subgroup_ranks,
+                                    dim_group.group_name,
+                                )
+                            )
+            self._dim_group_infos = dim_group_infos
+
+        def __enter__(self) -> "DeviceMesh":
+            # set this mesh as the current mesh in mesh env
+            _mesh_resources.mesh_stack.append(self)
+            return self
+
+        # pyre-fixme[2]: Parameter must be annotated.
+        def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+            # pop this mesh from mesh env
+            _mesh_resources.mesh_stack.pop()
+
+        def __repr__(self) -> str:
+            device_mesh_repr = (
+                f"DeviceMesh({self.mesh.tolist()})"
+                if not self.mesh_dim_names
+                else f"DeviceMesh({self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})"
+            )
+            return device_mesh_repr
+
+        def __hash__(self):
+            return self._hash
+
+        def __eq__(self, other: object) -> bool:
+            if not isinstance(other, DeviceMesh):
+                return False
+            if id(self.mesh) == id(other.mesh):
+                return True
+            return (
+                self.mesh.shape == other.mesh.shape
+                and self._flatten_mesh_list == other._flatten_mesh_list
+            )
+
+        def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
+            """
+            Slice the current DeviceMesh based on the mesh_dim_name given to create a child
+            DeviceMesh.
+
+            Args:
+                mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
+                to create a child DeviceMesh for.
+            Returns:
+                A :class:`DeviceMesh` object
+
+            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
+            hosts with 4 GPUs each.
+            Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
+            Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
+            Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
+            Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
+            Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
+            Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
+
+            Example::
+                >>> # xdoctest: +SKIP("no rank")
+                >>> from torch.distributed.device_mesh import DeviceMesh
+                >>>
+                >>> # Initialize device mesh as (2, 4) to represent the topology
+                >>> # of cross-host (dim 0), and within-host (dim 1).
+                >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=("dp", "tp"))
+                >>> tp_mesh = mesh["tp"]  # 1D child mesh over this host's ranks
+            """
+            if self.mesh.ndim == 1:
+                if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]:
+                    return self
+                else:
+                    raise RuntimeError(
+                        f"Invalid mesh_dim_name {mesh_dim_name} specified."
+                    )
+
+            mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name)
+            submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
+            return submesh
+
+        def get_group(
+            self, mesh_dim: Optional[Union[int, str]] = None
+        ) -> Union[ProcessGroup, List[ProcessGroup]]:
+            """
+            Returns a list of ProcessGroups corresponding to the mesh dimensions, or
+            returns a single ProcessGroup if mesh_dim is specified or the given mesh has
+            only one mesh dimension.
+
+            Args:
+                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
+                of the mesh dimension. Default is None.
+
+            Returns:
+                A list of :class:`ProcessGroup` objects when `mesh_dim` is not specified for
+                a DeviceMesh with more than 1 dimension; otherwise, returns a single
+                :class:`ProcessGroup` object.
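+
+            Example::
+                >>> # xdoctest: +SKIP("no rank")
+                >>> # A hedged sketch assuming a 2D mesh named ("dp", "tp") over 8 GPUs.
+                >>> from torch.distributed.device_mesh import init_device_mesh
+                >>> mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
+                >>> tp_group = mesh_2d.get_group("tp")  # a single ProcessGroup for the "tp" dim
+                >>> all_groups = mesh_2d.get_group()  # a list with one ProcessGroup per dim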
+            """
+            if not hasattr(self, "_dim_group_infos"):
+                raise RuntimeError("DeviceMesh process groups not initialized!")
+
+            if self.mesh.ndim == 1:
+                return not_none(
+                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
+                )
+
+            if mesh_dim is not None:
+                if isinstance(mesh_dim, str):
+                    mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
+                return not_none(
+                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])
+                )
+            else:
+                dim_groups = []
+                for ith_dim in range(self.mesh.ndim):
+                    dim_groups.append(
+                        not_none(
+                            _find_pg_by_ranks_and_tag(
+                                *self._dim_group_infos[ith_dim][:2]
+                            )
+                        )
+                    )
+                return dim_groups
+
+        def size(self, mesh_dim: Optional[int] = None) -> int:
+            return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
+
+        @property
+        def ndim(self) -> int:
+            return self.mesh.ndim
+
+        @property
+        def shape(self) -> Tuple[int, ...]:
+            return tuple(self.mesh.shape)
+
+        def get_rank(self) -> int:
+            """
+            Returns the current global rank.
+            """
+            return get_rank()
+
+        def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
+            """
+            Returns the local rank of the given mesh_dim of the DeviceMesh.
+
+            Args:
+                mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
+                of the mesh dimension. Default is None.
+
+            Returns:
+                An integer denoting the local rank.
+
+            The following program runs on each process/rank in an SPMD manner. In this example, we have 2
+            hosts with 4 GPUs each.
+            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
+            Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
+            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
+            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
+            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
+            Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
+
+            Example::
+                >>> # xdoctest: +SKIP("no rank")
+                >>> from torch.distributed.device_mesh import DeviceMesh
+                >>>
+                >>> # Initialize device mesh as (2, 4) to represent the topology
+                >>> # of cross-host (dim 0), and within-host (dim 1).
+                >>> mesh_2d = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=("dp", "tp"))
+                >>> local_rank = mesh_2d.get_local_rank(mesh_dim=1)
+            """
+            if self.ndim > 1 and mesh_dim is None:
+                raise RuntimeError(
+                    f"Found the DeviceMesh has {self.mesh.ndim} dimensions",
+                    "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
+                )
+            elif mesh_dim is None:
+                mesh_dim = 0
+
+            mesh_dim_group = not_none(self.get_group(mesh_dim))
+            assert isinstance(
+                mesh_dim_group, ProcessGroup
+            ), "We expect ProcessGroup before calling `get_rank`!"
+            return not_none(get_rank(mesh_dim_group))
+
+        def get_coordinate(self) -> Optional[List[int]]:
+            """
+            Return this rank's relative index along each dimension of the mesh.
+            If this rank is not part of the mesh, return None.
+            """
+            return self._coordinate_on_dim if self._coordinate_on_dim else None
+
+    def init_device_mesh(
+        device_type: str,
+        mesh_shape: Tuple[int, ...],
+        *,
+        mesh_dim_names: Optional[Tuple[str, ...]] = None,
+    ) -> DeviceMesh:
+        """
+        Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
+
+        This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
+        If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
+
+        .. note::
+            `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
+            runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
+            describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
+
+        .. note::
+            If no process group is found, init_device_mesh will initialize the distributed
+            process group(s) required for distributed communications behind the scenes.
+
+        Args:
+            device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
+            mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
+                describing the layout of devices.
+            mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
+                of the multi-dimensional array describing the layout of devices. Its length must match the length
+                of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
+
+        Returns:
+            DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
+
+        Example::
+            >>> # xdoctest: +SKIP("no rank")
+            >>> from torch.distributed.device_mesh import init_device_mesh
+            >>>
+            >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
+            >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+
+        """
+        if mesh_dim_names is not None:
+            if len(set(mesh_dim_names)) != len(mesh_dim_names):
+                raise RuntimeError(
+                    "Each mesh_dim_name must be unique.",
+                    f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
+                )
+
+            if len(mesh_shape) != len(mesh_dim_names):
+                raise RuntimeError(
+                    "mesh_shape and mesh_dim_names should have the same length!",
+                    f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
+                )
+
+        mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
+        device_mesh = DeviceMesh(
+            device_type=device_type,
+            mesh=mesh,
+            mesh_dim_names=mesh_dim_names,
+        )
+
+        return device_mesh
diff --git a/MLPY/Lib/site-packages/torch/distributed/distributed_c10d.py b/MLPY/Lib/site-packages/torch/distributed/distributed_c10d.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2485e8c4ad09b6cc2845a8c4d77f6c6d85dc6c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/distributed_c10d.py
@@ -0,0 +1,4264 @@
+"""Distributed Collective Communication (c10d)."""
+
+import itertools
+import collections.abc
+import contextlib
+import hashlib
+import io
+import logging
+import os
+import pickle
+import sys
+import time
+import warnings
+from collections import namedtuple
+from datetime import timedelta
+from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+
+import torch
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceCoalescedOptions,
+    AllreduceOptions,
+    AllToAllOptions,
+    _DistributedBackendOptions,
+    BarrierOptions,
+    BroadcastOptions,
+    GatherOptions,
+    PrefixStore,
+    ProcessGroup,
+    ReduceOp,
+    ReduceOptions,
+    ReduceScatterOptions,
+    ScatterOptions,
+    Store,
+    DebugLevel,
+    get_debug_level,
+    Work,
+    _register_process_group,
+    _resolve_process_group,
+    _unregister_all_process_groups,
+    _unregister_process_group,
+)
+from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
+from .constants import default_pg_timeout, default_pg_nccl_timeout
+from .c10d_logger import _exception_logger, _time_logger
+from .rendezvous import register_rendezvous_handler, rendezvous  # noqa: F401
+from ..utils._typing_utils import not_none
+DistStoreError = torch._C._DistStoreError
+
+__all__ = [
+    'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced',
+    'all_gather_object', 'all_reduce',
+    'all_reduce_coalesced', 'all_to_all',
+    'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast',
+    'broadcast_object_list', 'destroy_process_group',
+    'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank',
+    'get_world_size', 'get_pg_count', 'group', 'init_process_group', 'irecv',
+    'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available',
+    'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available',
+    'isend', 'monitored_barrier', 'new_group', 'new_subgroups',
+    'new_subgroups_by_enumeration', 'recv', 'reduce',
+    'reduce_scatter', 'scatter',
+    'scatter_object_list', 'send', 'supports_complex',
+    'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions',
+    'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore',
+    'ProcessGroup', 'ReduceOp', 'ReduceOptions', 'ReduceScatterOptions',
+    'ScatterOptions', 'Store', 'DebugLevel', 'get_debug_level', 'Work',
+    'default_pg_timeout', 'get_group_rank', 'get_global_rank', 'get_process_group_ranks',
+    'reduce_op', 'all_gather_into_tensor', 'reduce_scatter_tensor',
+]
+
+_MPI_AVAILABLE = True
+_NCCL_AVAILABLE = True
+_GLOO_AVAILABLE = True
+_UCC_AVAILABLE = True
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+# Change __module__ of all imported types from torch._C._distributed_c10d that are public
+def _export_c_types() -> None:
+    _public_types_to_change_module = [
+        AllreduceCoalescedOptions,
+        AllreduceOptions,
+        AllToAllOptions,
+        BarrierOptions,
+        BroadcastOptions,
+        GatherOptions,
+        PrefixStore,
+        ProcessGroup,
+        ReduceOp,
+        ReduceOptions,
+        ReduceScatterOptions,
+        ScatterOptions,
+        Store,
+        DebugLevel,
+        get_debug_level,
+        Work
+    ]
+    for type in _public_types_to_change_module:
+        type.__module__ = "torch.distributed.distributed_c10d"
+_export_c_types()
+
+try:
+    from torch._C._distributed_c10d import ProcessGroupMPI
+    ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupMPI"]
+except ImportError:
+    _MPI_AVAILABLE = False
+
+try:
+    from torch._C._distributed_c10d import ProcessGroupNCCL
+    ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupNCCL"]
+except ImportError:
+    _NCCL_AVAILABLE = False
+
+try:
+    from torch._C._distributed_c10d import ProcessGroupGloo
+    from torch._C._distributed_c10d import _ProcessGroupWrapper
+    ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupGloo"]
+except ImportError:
+    _GLOO_AVAILABLE = False
+
+try:
+    from torch._C._distributed_c10d import ProcessGroupUCC
+    ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupUCC"]
+except ImportError:
+    _UCC_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
+
+
+# Some reduce ops are not supported by complex numbers and will result in an error.
+# We currently provide complex support to the distributed API by viewing
+# complex tensors as real (torch.view_as_real), meaning that calling
+# these unsupported ops will return garbage values rather than error out.
+# (e.g. max(2+3i, 3+2i) = 3+3i)
+# We'd like calls to unsupported ops to error out accordingly,
+# rather than returning garbage values.
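+# A hedged illustration (comment only): before reducing a complex tensor, callers can
+# check the op first, e.g. supports_complex(ReduceOp.SUM) is True while
+# supports_complex(ReduceOp.MAX) is False, and raise early instead of silently
+# reducing over the real view.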
+def supports_complex(reduceOp: ReduceOp) -> bool:
+    """Return True if the reduce op is supported for complex tensors, False otherwise."""
+    denyList = [
+        ReduceOp.MAX,
+        ReduceOp.MIN,
+        ReduceOp.PRODUCT,
+        ReduceOp.BAND,
+        ReduceOp.BOR,
+        ReduceOp.BXOR,
+    ]
+    return reduceOp not in denyList
+
+
+class Backend(str):
+    """
+    An enum-like class for backends.
+
+    Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
+
+    The values of this class are lowercase strings, e.g., ``"gloo"``. They can
+    be accessed as attributes, e.g., ``Backend.NCCL``.
+
+    This class can be directly called to parse the string, e.g.,
+    ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
+    return the parsed lowercase string if so. It also accepts uppercase strings,
+    e.g., ``Backend("GLOO")`` returns ``"gloo"``.
+
+    .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
+              initial value of some fields. Users should neither use it directly
+              nor assume its existence.
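+
+    Example::
+        >>> # A hedged sketch of the parsing behavior described above.
+        >>> Backend("GLOO")
+        'gloo'
+        >>> Backend.NCCL
+        'nccl'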
+    """
+
+    UNDEFINED = "undefined"
+    GLOO = "gloo"
+    NCCL = "nccl"
+    UCC = "ucc"
+    MPI = "mpi"
+
+    _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
+
+    _plugins: Dict[str, _BackendPlugin] = {}
+
+    backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
+
+    default_device_backend_map: Dict[str, str] = {
+        'cpu' : GLOO,
+        'cuda' : NCCL,
+    }
+
+    backend_capability: Dict[str, List[str]] = {
+        GLOO : ["cpu", "cuda"],
+        NCCL : ["cuda"],
+        UCC : ["cpu", "cuda"],
+        MPI : ["cpu", "cuda"],
+    }
+
+    backend_type_map: Dict[str, ProcessGroup.BackendType] = {
+        UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
+        GLOO : ProcessGroup.BackendType.GLOO,
+        NCCL: ProcessGroup.BackendType.NCCL,
+        UCC: ProcessGroup.BackendType.UCC,
+    }
+
+    def __new__(cls, name: str):
+        """Create and return a new instance of the class."""
+        if not isinstance(name, str):
+            raise ValueError("Backend constructor parameter must be string-ish")
+        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
+
+        if value == Backend.UNDEFINED:
+            value = name.lower()
+        return value
+
+    @classmethod
+    def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None) -> None:
+        """
+        Register a new backend with the given name and instantiating function.
+
+        This class method is used by 3rd party ``ProcessGroup`` extensions to
+        register new backends.
+
+        Args:
+            name (str): Backend name of the ``ProcessGroup`` extension. It
+                        should match the one in ``init_process_group()``.
+            func (function): Function handler that instantiates the backend.
+                             The function should be implemented in the backend
+                             extension and takes four arguments, including
+                             ``store``, ``rank``, ``world_size``, and ``timeout``.
+            extended_api (bool, optional): Whether the backend supports extended argument structure.
+                                           Default: ``False``. If set to ``True``, the backend
+                                           will get an instance of ``c10d::DistributedBackendOptions``, and
+                                           a process group options object as defined by the backend implementation.
+            devices (str or list of str, optional): device types this backend
+                            supports, e.g. "cpu", "cuda", etc. If ``None``,
+                            both "cpu" and "cuda" are assumed.
+
+        .. note:: Support for third-party backends is experimental and subject to change.
+
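+        Example::
+            >>> # xdoctest: +SKIP
+            >>> # Hedged sketch: ``_make_dummy_pg`` is a hypothetical creator function
+            >>> # supplied by a third-party extension; it is not defined in this module.
+            >>> Backend.register_backend("dummy", _make_dummy_pg, devices=["cpu"])
+            >>> Backend.DUMMY
+            'dummy'
+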
+        """
+        # Allow UCC plugin if Pytorch is not built with native support.
+        # TODO: remove this exception once UCC plugin is fully deprecated.
+        if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())):
+            assert not hasattr(Backend, name.upper()), (
+                f"{name.upper()} c10d backend already exists"
+            )
+        assert name.upper() not in Backend._plugins, (
+            f"{name.upper()} c10d backend creator function already exists"
+        )
+
+        setattr(Backend, name.upper(), name.lower())
+        Backend.backend_list.append(name.lower())
+        if devices is not None:
+            for device in devices:
+                if device != 'cpu' and device != 'cuda':
+                    Backend.default_device_backend_map[device] = name.lower()
+        Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
+
+        # Update device capability matrix in Backend class
+        if devices is None:
+            # This is mostly for backward compatibility with backends like `threaded`:
+            # assume the default devices "cpu" and "cuda", but warn.
+            warnings.warn(
+                f"Device capability of {name} unspecified, assuming `cpu` and "
+                "`cuda`. Please specify it via the `devices` argument of "
+                "`register_backend`."
+            )
+            Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
+        elif isinstance(devices, str):
+            # Single device string specified. Simply convert to list.
+            Backend.backend_capability[name.lower()] = [devices]
+        else:
+            Backend.backend_capability[name.lower()] = devices
+
+        Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
+
+class BackendConfig:
+    """Backend configuration class."""
+
+    def __init__(self, backend: Backend):
+        """Init."""
+        self.device_backend_map: Dict[str, Backend] = {}
+        backend = str(backend)
+
+        if backend == Backend.UNDEFINED:
+            # default config when backend is not specified
+            # supported since PyTorch 2.0
+            for device, default_backend in Backend.default_device_backend_map.items():
+                if is_backend_available(default_backend):
+                    if default_backend == Backend.NCCL and not torch.cuda.is_available():
+                        continue
+                    self.device_backend_map[device] = Backend(default_backend)
+        elif backend.lower() in Backend.backend_list:
+            # Cases for when backend is a single string (without device types)
+            # e.g. "nccl", "gloo", "ucc", "mpi"
+            supported_devices = Backend.backend_capability[backend.lower()]
+            backend_val = Backend(backend)
+            self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
+        elif ":" in backend.lower():
+            # Backend specified in "device:backend" format
+            # make sure the backend string is in the correct format
+            # "{device_type1}:{backend1},{device_type2}:{backend2}"
+            # e.g. "cpu:gloo,cuda:nccl"
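+            # Hedged illustration: assuming both backends are available,
+            # BackendConfig("cpu:gloo,cuda:nccl").get_device_backend_map() would
+            # yield {"cpu": "gloo", "cuda": "nccl"}.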
+            backend_str_error_message = f"""The custom backend string argument is invalid: {backend}.
+                Custom backend string is an experimental feature where the backend string must be in the format:
+                "<device_type1>:<backend1>,<device_type2>:<backend2>...". e.g. 'cpu:gloo,cuda:nccl'"""
+
+            # parse the backend string and populate the device_backend_map
+            for device_backend_pair_str in backend.lower().split(","):
+                device_backend_pair = device_backend_pair_str.split(":")
+                if len(device_backend_pair) != 2:
+                    raise ValueError(f"Invalid device:backend pairing: \
+                                     {device_backend_pair_str}. {backend_str_error_message}")
+                device, backend = device_backend_pair
+                if device in self.device_backend_map:
+                    raise ValueError(f"Duplicate device type {device} \
+                                     in backend string: {backend}. {backend_str_error_message}")
+                self.device_backend_map[device] = Backend(backend)
+        else:
+            # User specified a single backend name whose device capability is
+            # unknown, assuming it can support the default devices of PyTorch
+            # (cpu and cuda)
+            warnings.warn(
+                f"Device capability of {backend} unknown, assuming `cpu` and "
+                "`cuda`. You can specify it in `device:backend` format in "
+                "`init_process_group` call."
+            )
+            backend_val = Backend(backend)
+            self.device_backend_map = {
+                "cpu" : backend_val,
+                "cuda" : backend_val,
+                "xpu" : backend_val,
+            }
+
+        logger.info(
+            f"Using backend config: {self.device_backend_map}"  # noqa: G004
+        )
+
+    def __repr__(self):
+        """Return all the device:backend pairs separated by commas."""
+        return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items())
+
+    def get_device_backend_map(self) -> Dict[str, Backend]:
+        """Return backend map of the device."""
+        return self.device_backend_map
+
+class _reduce_op:
+    r"""
+    Deprecated enum-like class.
+
+    For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``.
+
+    :class:`~torch.distributed.ReduceOp` is recommended instead.
+    """
+
+    def __init__(self):
+        # __members__ is a dict storing key-value pairs for enum classes
+        for k, v in ReduceOp.RedOpType.__members__.items():
+            setattr(self, k, v)
+        self.__members__ = ReduceOp.RedOpType.__members__
+
+    def __getattribute__(self, key):
+        warnings.warn(
+            "torch.distributed.reduce_op is deprecated, please use "
+            "torch.distributed.ReduceOp instead"
+        )
+        return object.__getattribute__(self, key)
+
+
+reduce_op = _reduce_op()
+
+
+class P2POp:
+    """
+    A class to build point-to-point operations for ``batch_isend_irecv``.
+
+    This class builds the type of P2P operation, communication buffer, peer rank,
+    Process Group, and tag. Instances of this class will be passed to
+    ``batch_isend_irecv`` for point-to-point communications.
+
+    Args:
+        op (Callable): A function to send data to or receive data from a peer process.
+            The type of ``op`` is either ``torch.distributed.isend`` or
+            ``torch.distributed.irecv``.
+        tensor (Tensor): Tensor to send or receive.
+        peer (int): Destination or source rank.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match send with recv.
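+
+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Hedged sketch assuming ``torch.distributed`` is imported as ``dist``, the
+        >>> # default group is initialized, and ``rank``/``world_size`` describe this job.
+        >>> send_tensor = torch.ones(2)
+        >>> recv_tensor = torch.zeros(2)
+        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
+        >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size)
+        >>> reqs = dist.batch_isend_irecv([send_op, recv_op])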
+    """
+
+    def __init__(self, op: Callable, tensor: torch.Tensor, peer: int,
+                 group: Optional[ProcessGroup] = None, tag: int = 0):
+        """Init."""
+        self.op = op
+        self.tensor = tensor
+        self.peer = peer
+        self.group = group
+        self.tag = tag
+
+    def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int,
+                group: Optional[ProcessGroup] = None, tag: int = 0):
+        """Create and return a new instance of the class."""
+        _check_op(op)
+        _check_single_tensor(tensor, "tensor")
+        return object.__new__(cls)
+
+
+class _CollOp:
+    """
+    A class to capture collective operations.
+
+    Args:
+        op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``.
+        tensor (Tensor): Tensor to operate on.
+        dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same.
+        redop (ReduceOp, optional): reduce operation.
+        root (int, optional): root of broadcast or reduce.
+    """
+
+    def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torch.Tensor] = None,
+                 redop: Optional[ReduceOp] = None, root: Optional[int] = None):
+        self.op = op
+        self.tensor = tensor
+        self.dst_tensor = dst_tensor
+        self.redop = redop
+        self.root = root
+
+
+# DO NOT USE THESE FIELDS DIRECTLY.
+# Use them through the _world object to make sure the _world override mechanism is respected.
+_pg_map: Dict[ProcessGroup, Tuple[str, Store]] = {}
+_pg_names: Dict[ProcessGroup, str] = {}
+_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
+# For a pg, it is a map from ProcessGroup to BackendConfig
+_pg_backend_config: Dict[ProcessGroup, str] = {}
+_group_count = 0
+_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
+_pg_to_tag: Dict[ProcessGroup, str] = {}
+_backend: Optional[str] = None
+
+class _World:
+    """
+    Container class for c10d process group state.
+
+    This is used during registration and lookup of PG state.
+
+    .. warning:: This is an experimental API intended to expose the inner workings
+       of c10d and is subject to change.
+    """
+
+    def __init__(self):
+        self._default_pg = None
+        self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
+        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
+
+    @property
+    def default_pg(self) -> Optional[ProcessGroup]:
+        """
+        Process group that includes all ranks of the cluster.
+
+        This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed
+        but None is provided.
+        """
+        return self._default_pg
+
+    @default_pg.setter
+    def default_pg(self, value) -> None:
+        self._default_pg = value
+
+    @property
+    def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Store]]:
+        """
+        Provide the mapping from ProcessGroup to backend name and store.
+
+        For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
+        For MPI pg, it is a map from ProcessGroup to (Backend, None)
+
+        TODO don't expose the map, expose fine grained ops
+        """
+        global _pg_map
+        return _pg_map
+
+    @property
+    def pg_names(self) -> Dict[ProcessGroup, str]:
+        """
+        Process group's names, map from ProcessGroup to str.
+
+        TODO don't expose the map, expose fine grained ops
+        """
+        global _pg_names
+        return _pg_names
+
+    @property
+    def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
+        """
+        Process group's global rank to local rank mapping.
+
+        TODO don't expose the map, expose fine grained ops
+        """
+        global _pg_group_ranks
+        return _pg_group_ranks
+
+    @property
+    def pg_backend_config(self) -> Dict[ProcessGroup, str]:
+        """
+        Process group's backend config.
+
+        TODO don't expose the map, expose fine grained ops
+        """
+        global _pg_backend_config
+        return _pg_backend_config
+
+    @property
+    def group_count(self) -> int:
+        """
+        Process group count for default naming.
+
+        TODO don't expose group_count, use something else instead
+        """
+        global _group_count
+        return _group_count
+
+    @group_count.setter
+    def group_count(self, value: int) -> None:
+        """Use to compute the name of ProcessGroups when using global synchronization."""
+        global _group_count
+        _group_count = value
+
+    @property
+    def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
+        global _tags_to_pg
+        return _tags_to_pg
+
+    @property
+    def pg_to_tag(self) -> Dict[ProcessGroup, str]:
+        global _pg_to_tag
+        return _pg_to_tag
+
+    @property
+    def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]:
+        return self._pg_coalesce_state
+
+    @property
+    def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
+        return self._pg_default_device
+
+    @property
+    def pg_config_info(self) -> List[Dict[str, Any]]:
+        """
+        Return a list of dict with process groups and backends.
+
+        Along with their unique IDs and configurations (types and ranks).
+        """
+        config_info: List[Dict[str, Any]] = []
+        default_pg_size = _get_group_size(None)
+        for pg in self.pg_map.keys():
+            ranks = self.pg_group_ranks[pg]
+            config_info.append(
+                {
+                    "pg_name": self.pg_names[pg],
+                    "uid": _get_process_group_uid(pg),
+                    "backend_config": self.pg_backend_config[pg],
+                    "ranks": list(ranks.keys())
+                    if len(ranks) != default_pg_size
+                    else [],  # 'ranks' is an empty list when all ranks are involved in a pg
+                    "group_size": len(ranks),
+                    "group_count": self.group_count,
+                }
+            )
+        return config_info
+
+
+_world = _World()
+"""Holds the singleton instance of ``_World`` used by c10d. Experimental extension point to override it."""
+
+class _WorldMeta(type):
+    """
+    Meta class of ``group`` and ``GroupMember``.
+
+    Allows them to have the class property ``WORLD``.
+    """
+
+    # Points to the default PG once initialized.
+    @property
+    def WORLD(cls) -> Optional[ProcessGroup]:
+        return _world.default_pg
+
+    @WORLD.setter
+    def WORLD(cls, pg: Optional[ProcessGroup]):
+        _world.default_pg = pg
+
+class group(metaclass=_WorldMeta):
+    """Group class. Placeholder."""
+
+    pass
+
+class GroupMember(metaclass=_WorldMeta):
+    """Group member class."""
+
+    NON_GROUP_MEMBER = -100
+
+
+def _get_default_timeout(backend: Backend) -> timedelta:
+    # see note on nccl vs other backend timeout (constants.py)
+    if backend == Backend.NCCL:
+        if not isinstance(default_pg_nccl_timeout, timedelta):
+            # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was
+            # changed to be a warning.  We should fix the moco model.
+            warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled")
+            return default_pg_timeout
+        return default_pg_nccl_timeout
+    else:
+        return default_pg_timeout
+
+def _check_valid_timeout(timeout: Any) -> None:
+    if not isinstance(timeout, timedelta):
+        raise TypeError(
+            f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
+        )
+
+# Default process group state
+_default_pg_init_method: Optional[str] = None
+
+STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
+
+def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device:
+    """
+    Return the device to use with ``group`` for control flow usage (object collectives, barrier).
+
+    There are selection rules:
+        1. If user specifies exactly one backend in ``init_process_group`` call:
+            use that backend
+        2. Else if user specifies multiple "device:backend" pairs in init_process_group:
+            If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory);
+            Otherwise, use the first backend (sort of a random pick).
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Returns:
+        torch.device: The device to use with ``group``.
+
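+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Hedged sketch: with a "cpu:gloo,cuda:nccl" default group, cpu is preferred
+        >>> # for object collectives because the objects live in cpu memory.
+        >>> _get_pg_default_device()
+        device(type='cpu')
+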
+    """
+    group = group or _get_default_group()
+    if group in _world.pg_default_device:
+        # Previously searched and cached; just return
+        return _world.pg_default_device[group]
+
+    if not isinstance(group, ProcessGroup):
+        # Provide backward compatibility to cases where `group` passed in is
+        # actually a Backend (like `ProcessGroupGloo`) rather than a
+        # `ProcessGroup` in PT 2.0 sense
+        warnings.warn(
+            f"You are using a Backend {type(group)} as a ProcessGroup. "
+            "This usage is deprecated since PyTorch 2.0. Please use a public API "
+            "of PyTorch Distributed instead."
+        )
+        # Most users create Gloo with private API for object collectives
+        _world.pg_default_device[group] = torch.device("cpu")
+        return _world.pg_default_device[group]
+
+    """
+    ``group._device_types`` is a property pybind that returns the devices
+    ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
+    ``group`` supports multiple devices.
+    """
+    devices = group._device_types
+
+    if len(devices) == 1:
+        # User fixed exactly one backend in `init_process_group`
+        _world.pg_default_device[group] = devices[0]
+    elif len(devices) == 0:
+        # No backend has been registered with this PG (maybe because no
+        # collective has been run yet?). We pick cpu as the default and hope
+        # this will lazily init Gloo or another available cpu backend.
+        _world.pg_default_device[group] = torch.device("cpu")
+    elif torch.device("cpu") in devices:
+        # There are multiple backends in this PG and cpu is among them.
+        # cpu is preferred as the object is in cpu memory. No need for device
+        # copy.
+        _world.pg_default_device[group] = torch.device("cpu")
+    else:
+        # No cpu in the backend list. Randomly pick the first backend
+        _world.pg_default_device[group] = devices[0]
+
+    logger.info(
+        f"Using device {_world.pg_default_device[group]} for object "  # noqa: G004
+        "collectives."
+    )
+    return _world.pg_default_device[group]
+
+
+@_time_logger
+def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)) -> None:
+    """
+    Store based barrier for synchronizing processes.
+
+    Barrier based on store which is used for synchronizing processes after
+    ``init_process_group`` or ``new_group``. Intended to be used only with
+    those two methods and is not a generic alternative to ``barrier()``.
+    """
+    store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}"
+    store.add(store_key, 1)
+    logger.info("Added key: %s to store for rank: %s", store_key, rank)
+
+    # Now wait for all workers to check in with the store.
+    world_size = rendezvous_count
+    worker_count = store.add(store_key, 0)
+
+    last_worker_key = f"{store_key}:last_worker"
+    if worker_count == world_size:
+        store.set(last_worker_key, "1")
+
+    # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout
+    # this value was empirically found while scale testing.
+    logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000))
+
+    start = time.time()
+    while True:
+        try:
+            # store.wait() raises after logging_interval; we then either log the current
+            # status of the group and keep waiting, or time out officially and raise.
+            store.wait([last_worker_key], logging_interval)
+            break
+        except RuntimeError as e:
+            worker_count = store.add(store_key, 0)
+            # Print status periodically to keep track.
+            logger.info(
+                "Waiting in store based barrier to initialize process group for "
+                "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s)",
+                rank, store_key, world_size, worker_count, timeout
+            )
+
+            if timedelta(seconds=(time.time() - start)) > timeout:
+                raise DistStoreError(  # noqa: TRY200
+                    "Timed out initializing process group in store based barrier on "
+                    "rank {}, for key: {} (world_size={}, num_workers_joined={}, timeout={})".format(
+                        rank, store_key, world_size, worker_count, timeout
+                    )
+                )
+
+    logger.info(
+        "Rank %s: Completed store-based barrier for key:%s with %s nodes.", rank, store_key, world_size
+    )
+
+
+def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool:
+    """Check if the current process's rank is not in a given group."""
+    if group is None:
+        return False
+    return group == GroupMember.NON_GROUP_MEMBER
+
+
+def _warn_not_in_group(op_name) -> None:
+    global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
+    warnings.warn(
+        f"Running {op_name} on global rank {global_rank} which does not "
+        "belong to the given group."
+    )
+
+
+def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
+    """
+    Translate a global rank into a group rank.
+
+    ``global_rank`` must be part of ``group``, otherwise this raises a ValueError.
+
+    Args:
+        group (ProcessGroup): ProcessGroup to find the relative rank.
+        global_rank (int): Global rank to query.
+
+    Returns:
+        Group rank of ``global_rank`` relative to ``group``
+
+    N.B. calling this function on the default process group returns identity
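+
+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Hedged sketch: an 8-rank job where every rank creates the same subgroup
+        >>> # (``dist`` is assumed to be ``torch.distributed``).
+        >>> subgroup = dist.new_group([4, 5, 6, 7])
+        >>> dist.get_group_rank(subgroup, 6)  # global rank 6 is the third member
+        2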
+    """
+    if group is GroupMember.WORLD:
+        return global_rank
+    if group not in _world.pg_group_ranks:
+        raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
+    group_ranks = _world.pg_group_ranks[group]
+    if global_rank not in group_ranks:
+        raise ValueError(f"Global rank {global_rank} is not part of group {group}")
+
+    return group_ranks[global_rank]
+
+def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
+    """
+    Translate a group rank into a global rank.
+
+    ``group_rank`` must be part of ``group``, otherwise this raises a ValueError.
+
+    Args:
+        group (ProcessGroup): ProcessGroup to find the global rank from.
+        group_rank (int): Group rank to query.
+
+    Returns:
+        Global rank of ``group_rank`` relative to ``group``
+
+    N.B. calling this function on the default process group returns identity
+    """
+    if group is GroupMember.WORLD:
+        return group_rank
+    if group not in _world.pg_group_ranks:
+        raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
+    for rank, grp_rank in _world.pg_group_ranks[group].items():
+        if grp_rank == group_rank:
+            return rank
+    raise ValueError(f"Group rank {group_rank} is not part of group {group}")
+
+# TODO: remove this once the ecosystem moves away from it.
+def _get_global_rank(group, rank) -> int:
+    """Use get_global_rank as this method is deprecated."""
+    warnings.warn(
+        "torch.distributed.distributed_c10d._get_global_rank is deprecated "
+        "please use torch.distributed.distributed_c10d.get_global_rank instead"
+    )
+    return get_global_rank(group, rank)
+
+
+def get_process_group_ranks(group: ProcessGroup) -> List[int]:
+    """
+    Get all ranks associated with ``group``.
+
+    Args:
+        group (ProcessGroup): ProcessGroup to get all ranks from.
+
+    Returns:
+        List of global ranks ordered by group rank.
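+
+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Hedged sketch: a 4-rank job with a subgroup spanning ranks 0 and 2
+        >>> # (``dist`` is assumed to be ``torch.distributed``).
+        >>> pg = dist.new_group([0, 2])
+        >>> dist.get_process_group_ranks(pg)
+        [0, 2]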
+    """
+    return list(_world.pg_group_ranks[group].keys())
+
+def _get_group_size(group) -> int:
+    """Get a given group's world size."""
+    if group is GroupMember.WORLD or group is None:
+        default_pg = _get_default_group()
+        return default_pg.size()
+    return group.size()
+
+
+def _get_group_size_by_name(group_name: str) -> int:
+    group = _resolve_process_group(group_name)
+    return group.size()
+
+
+def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str:
+    # TODO(yifu): remove this function once ranks + tag is not a supported
+    # identifier for process group for functional collectives.
+    group = _find_pg_by_ranks_and_tag(tag, ranks)
+    if group is None:
+        raise ValueError(f"Could not find a process group for ranks {ranks} and tag '{tag}'.")
+    return group.group_name
+
+
+def _check_single_tensor(param, param_name) -> None:
+    """Check that the parameter ``param_name`` is a single tensor."""
+    if not isinstance(param, torch.Tensor):
+        raise TypeError(
+            f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor
+             but got {type(param)} instead."""
+        )
+
+
+def _check_tensor_list(param, param_name) -> None:
+    """Check that the parameter ``param_name`` is a list of tensors."""
+    if not isinstance(param, list):
+        raise TypeError(
+            f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
+             but got {type(param)} instead."""
+        )
+    elif not all(isinstance(p, torch.Tensor) for p in param):
+        raise TypeError(
+            f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
+             but got {type(param)} with elements of type {[type(p) for p in param]}."""
+        )
+
+
+def _as_iterable(obj) -> collections.abc.Iterable:
+    return obj if isinstance(obj, list) else (obj,)
+
+def _ensure_all_tensors_same_dtype(*tensors) -> None:
+    last_dtype = None
+    for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
+        tensor_dtype = tensor.dtype
+        # Mixing complex and its element type is allowed
+        if tensor_dtype.is_complex:
+            tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128
+
+        if last_dtype is None:
+            last_dtype = tensor_dtype
+        else:
+            if last_dtype != tensor_dtype:
+                raise ValueError(
+                    "Invalid usage of tensors with different dtypes. "
+                    f"Found {last_dtype} and {tensor.dtype}."
+                )
+
+
+def _check_op(op) -> None:
+    """Check that the ``op`` is either isend or irecv."""
+    if op not in [isend, irecv]:
+        raise ValueError(
+            "Invalid ``op``. Expected ``op`` "
+            "to be of type ``torch.distributed.isend`` or "
+            "``torch.distributed.irecv``."
+        )
+
+
+def _check_p2p_op_list(p2p_op_list) -> None:
+    """
+    Check that the ``p2p_op_list`` is a list of P2POp instances.
+
+    Also, check that all ops use the same group.
+    """
+    if not isinstance(p2p_op_list, list) or not all(
+        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
+    ):
+        raise ValueError(
+            "Invalid ``p2p_op_list``. Each op is expected "
+            "to be of type ``torch.distributed.P2POp``."
+        )
+
+    group = p2p_op_list[0].group
+    if not all(group == p2p_op.group for p2p_op in p2p_op_list):
+        raise ValueError("All ops need to use the same group.")
+
+
+def is_mpi_available() -> bool:
+    """Check if the MPI backend is available."""
+    return _MPI_AVAILABLE
+
+
+def is_nccl_available() -> bool:
+    """Check if the NCCL backend is available."""
+    return _NCCL_AVAILABLE
+
+
+def is_gloo_available() -> bool:
+    """Check if the Gloo backend is available."""
+    return _GLOO_AVAILABLE
+
+
+def is_ucc_available() -> bool:
+    """Check if the UCC backend is available."""
+    return _UCC_AVAILABLE
+
+
+def is_backend_available(backend: str) -> bool:
+    """
+    Check backend availability.
+
+    Checks whether the given backend is available, covering both the built-in backends
+    and third-party backends registered through ``Backend.register_backend``.
+
+    Args:
+        backend (str): Backend name.
+    Returns:
+        bool: Returns True if the backend is available, otherwise False.
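+
+    Example::
+        >>> # A hedged sketch; the result depends on how this torch build was compiled.
+        >>> torch.distributed.is_backend_available("gloo")
+        True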
+    """
+    # If torch.distributed exposes an ``is_<backend>_available`` helper for this backend, return its result directly
+    available_func = getattr(torch.distributed, f"is_{backend.lower()}_available", None)
+    if available_func:
+        return available_func()
+
+    return backend.lower() in Backend.backend_list
+
+
+def is_initialized() -> bool:
+    """Check if the default process group has been initialized."""
+    return GroupMember.WORLD is not None
+
+
+def is_torchelastic_launched() -> bool:
+    """
+    Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic).
+
+    The existence of ``TORCHELASTIC_RUN_ID`` environment
+    variable is used as a proxy to determine whether the current process
+    was launched with torchelastic. This is a reasonable proxy since
+    ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a
+    non-null value indicating the job id for peer discovery purposes.
+    """
+    return os.getenv("TORCHELASTIC_RUN_ID") is not None
+
+
+def _is_barrier_after_init() -> int:
+    # Environment variable to control whether process group should perform a
+    # barrier after its init. Default value is 0, i.e. no barrier. If you
+    # experience issue with this setting, you may set
+    # `TORCH_DIST_INIT_BARRIER=1` to add the barrier.
+    return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0"))
+
+
+def _abort_in_destroy_pg() -> bool:
+    # Environment variable to control whether to abort the communicators when users call destroy_process_group()
+    env = os.getenv("TORCH_NCCL_ABORT_IN_DESTROY_PG", "0")
+    return env == "1" or env.lower() == "true"
+
+
+def _get_default_group() -> ProcessGroup:
+    """Get the default process group created by init_process_group."""
+    if not is_initialized():
+        raise ValueError(
+            "Default process group has not been initialized, "
+            "please make sure to call init_process_group."
+        )
+    return not_none(GroupMember.WORLD)
+
+
+def _get_default_store() -> Store:
+    """Get the default store created by init_process_group."""
+    if not is_initialized():
+        raise ValueError(
+            "Default process group has not been initialized, "
+            "please make sure to call init_process_group."
+        )
+    default_pg = _get_default_group()
+    _, default_store = _world.pg_map[default_pg]
+    return default_store
+
+
+def _update_default_pg(pg) -> None:
+    _world.default_pg = pg
+    rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1
+    torch._C._distributed_c10d._set_global_rank(rank)
+
+def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
+    """
+    Return the backend configuration of the given process group.
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific group
+            is specified, the calling process must be part of :attr:`group`.
+
+    Returns:
+        The backend configuration of the given process group as a lower case string.
+
+    """
+    if group is None:
+        pg = _get_default_group()
+    else:
+        pg = group
+    if _rank_not_in_group(pg):
+        raise ValueError("Invalid process group specified")
+    backend_config = _world.pg_backend_config.get(pg)
+    return str(not_none(backend_config))
+
+def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
+    """
+    Return the backend of the given process group.
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific group
+            is specified, the calling process must be part of :attr:`group`.
+
+    Returns:
+        The backend of the given process group as a lower case string.
+
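+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Hedged sketch assuming the default group was initialized with the gloo
+        >>> # backend (``dist`` is assumed to be ``torch.distributed``).
+        >>> dist.get_backend()
+        'gloo'
+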
+    """
+    if group is None:
+        pg = _get_default_group()
+    else:
+        pg = group
+    if _rank_not_in_group(pg):
+        raise ValueError("Invalid process group specified")
+    pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
+    return Backend(not_none(pg_store)[0])
+
+def _get_process_group_uid(pg: ProcessGroup) -> int:
+    backend = None
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        pass
+    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
+        return backend.uid
+    return -1
+
+def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
+    """
+    Return the pg configuration of the given process group.
+
+    """
+    if group is None:
+        pg = _get_default_group()
+    else:
+        pg = group
+    return {
+        "pg_name": _get_process_group_name(pg),
+        "uid": _get_process_group_uid(pg),
+        "backend_config": get_backend_config(pg),
+        "pg_size": _get_group_size(pg),
+        "ranks": get_process_group_ranks(pg),
+    }
+
+def _get_all_pg_configs() -> List[Dict[str, Any]]:
+    """
+    Return the pg configuration of all the process groups.
+
+    """
+    config_info: List[Dict[str, Any]] = []
+    for pg in _world.pg_map.keys():
+        config_info.append(_get_pg_config(pg))
+    return config_info
+
+def get_pg_count() -> int:
+    """
+    Return the number of process groups.
+
+    """
+    return _world.group_count
+
+def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
+    """
+    Set the timeout for the given process group when users want to use a different timeout instead of
+    default values.
+
+    Args:
+        timeout (timedelta): Timeout for operations executed against the process group which
+            users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends.
+            This is the duration after which collectives will be aborted asynchronously and the process will crash.
+            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
+            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
+            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
+
+        group (ProcessGroup, optional): The process group to work on. The
+            default is the general main process group. If another specific group
+            is specified, the calling process must be part of :attr:`group`.
+
+    Returns:
+        None
+    """
+    if group is None:
+        group = _get_default_group()
+    if _rank_not_in_group(group):
+        raise ValueError("Invalid process group specified")
+    assert isinstance(group, ProcessGroup)
+    devices = group._device_types
+    backends = set()
+    if torch.device("cpu") in devices and is_gloo_available():
+        backend = group._get_backend(torch.device("cpu"))
+        if isinstance(backend, ProcessGroupGloo):
+            backends.add(backend)
+    if torch.device("cuda") in devices:
+        backend = group._get_backend(torch.device("cuda"))
+        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
+            backends.add(backend)  # type: ignore[arg-type]
+        elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
+            backends.add(backend)  # type: ignore[arg-type]
+    if len(backends) == 0:
+        warnings.warn("Setting a timeout is currently only supported for the nccl and gloo backends.")
+    for backend in backends:
+        backend._set_default_timeout(timeout)
+
+
+@_exception_logger
+@_time_logger
+def init_process_group(
+    backend: Optional[str] = None,
+    init_method: Optional[str] = None,
+    timeout: Optional[timedelta] = None,
+    world_size: int = -1,
+    rank: int = -1,
+    store: Optional[Store] = None,
+    group_name: str = "",
+    pg_options: Optional[Any] = None,
+    device_id: Optional[torch.device] = None,
+) -> None:
+    """
+    Initialize the default distributed process group.
+
+    This will also initialize the distributed package.
+
+    There are 2 main ways to initialize a process group:
+        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
+        2. Specify ``init_method`` (a URL string) which indicates where/how
+           to discover peers. Optionally specify ``rank`` and ``world_size``,
+           or encode all required parameters in the URL and omit them.
+
+    If neither is specified, ``init_method`` is assumed to be "env://".
+
+
+    Args:
+        backend (str or Backend, optional): The backend to use. Depending on
+            build-time configurations, valid values include ``mpi``, ``gloo``,
+            ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
+            and ``nccl`` backend will be created, see notes below for how multiple
+            backends are managed. This field can be given as a lowercase string
+            (e.g., ``"gloo"``), which can also be accessed via
+            :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
+            multiple processes per machine with ``nccl`` backend, each process
+            must have exclusive access to every GPU it uses, as sharing GPUs
+            between processes can result in deadlocks. ``ucc`` backend is
+            experimental.
+        init_method (str, optional): URL specifying how to initialize the
+                                     process group. Default is "env://" if no
+                                     ``init_method`` or ``store`` is specified.
+                                     Mutually exclusive with ``store``.
+        world_size (int, optional): Number of processes participating in
+                                    the job. Required if ``store`` is specified.
+        rank (int, optional): Rank of the current process (it should be a
+                              number between 0 and ``world_size``-1).
+                              Required if ``store`` is specified.
+        store(Store, optional): Key/value store accessible to all workers, used
+                                to exchange connection/address information.
+                                Mutually exclusive with ``init_method``.
+        timeout (timedelta, optional): Timeout for operations executed against
+            the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends.
+            This is the duration after which collectives will be aborted asynchronously and the process will crash.
+            This is done because CUDA execution is async and it is no longer safe to continue executing user code, since
+            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
+            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
+
+        group_name (str, optional, deprecated): Group name. This argument is ignored.
+        pg_options (ProcessGroupOptions, optional): process group options
+            specifying what additional options need to be passed in during
+            the construction of specific process groups. As of now, the only
+            option we support is ``ProcessGroupNCCL.Options`` for the ``nccl``
+            backend; ``is_high_priority_stream`` can be specified so that
+            the nccl backend picks up high-priority CUDA streams when
+            there are compute kernels waiting.
+        device_id (torch.device, optional): a single, specific device
+            to "bind" this process to, allowing for backend-specific
+            optimizations.  Currently this has two effects, only under
+            NCCL: the communicator is immediately formed (calling
+            ``ncclCommInit*`` immediately rather than the normal lazy
+            call) and sub-groups will use ``ncclCommSplit`` when
+            possible to avoid the unnecessary overhead of group creation. This
+            field can also be used to surface NCCL initialization errors early.
+
+    .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
+        on a system that supports MPI.
+
+    .. note:: Support for multiple backends is experimental. Currently when no backend is
+        specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend
+        will be used for collectives with CPU tensors and the ``nccl`` backend will be used
+        for collectives with CUDA tensors. A custom backend can be specified by passing in
+        a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g.
+        "cpu:gloo,cuda:custom_backend".
+
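+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Minimal sketch, assuming the launcher (e.g. ``torchrun``) has already
+        >>> # exported MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE.
+        >>> import torch.distributed as dist
+        >>> dist.init_process_group(backend="gloo")  # picks up init_method="env://"
+        >>> dist.get_rank(), dist.get_world_size()
+        (0, 2)  # on rank 0 of a 2-process job
+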
+    """
+
+    global _world
+
+    global _backend
+    global _default_pg_init_method
+
+    if GroupMember.WORLD is not None:
+        raise ValueError("trying to initialize the default process group twice!")
+
+    set_pytorch_distributed_envs_from_justknobs()
+
+    assert (store is None) or (
+        init_method is None
+    ), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = _get_default_timeout(backend)
+
+    _check_valid_timeout(timeout)
+
+    """
+    Group name is not visible to users unless they access
+    internals of c10d. This means we can ignore the value
+    they provide as it is not exposed in a public way.
+    """
+    group_name = _process_group_name([], use_hashed_name=False)
+    if backend == Backend.MPI:
+        if world_size != -1 or rank != -1:
+            warnings.warn(
+                f"For MPI backend, world_size ({world_size}) and rank ({rank}) "
+                "are ignored since they are assigned by the "
+                "MPI runtime."
+            )
+
+        default_pg, _ = _new_process_group_helper(
+            -1, -1, [], backend, None, group_name, timeout=timeout
+        )
+        _update_default_pg(default_pg)
+    else:
+        # backward compatible API
+        if store is None:
+            rendezvous_iterator = rendezvous(
+                not_none(init_method), rank, world_size, timeout=timeout
+            )
+            store, rank, world_size = next(rendezvous_iterator)
+            store.set_timeout(timeout)
+
+            # Use a PrefixStore to avoid accidental overrides of keys used by
+            # different systems (e.g. RPC) in case the store is multi-tenant.
+            store = PrefixStore("default_pg", store)
+
+        default_pg, _ = _new_process_group_helper(
+            world_size,
+            rank,
+            [],
+            backend,
+            store,
+            group_name,
+            pg_options=pg_options,
+            timeout=timeout,
+            device_id=device_id,
+        )
+        _update_default_pg(default_pg)
+
+    _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
+    _backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
+    _default_pg_init_method = init_method
+
+    old_hook = sys.excepthook
+
+    def _distributed_excepthook(*args):
+        old_stderr = sys.stderr
+        sys.stderr = buf = io.StringIO()
+        try:
+            old_hook(*args)
+        finally:
+            sys.stderr = old_stderr
+        msg = buf.getvalue()
+        prefix = f"[rank{get_rank()}]"
+        msg = "\n".join(f"{prefix}: {s}" if s != "" else "" for s in msg.split("\n"))
+        sys.stderr.write(msg)
+        sys.stderr.flush()
+
+    sys.excepthook = _distributed_excepthook
+
+    if _is_barrier_after_init() == 1:
+        # barrier at the end to ensure that once we return from this method, all
+        # process groups including global variables (if any) are updated
+        # correctly on all ranks.
+        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
+        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
+        # these barriers may be unnecessary, as proven by a green CI after
+        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
+        # added which enables this barrier only when set to 1.
+        logger.info(
+            "Performing barrier after ProcessGroup initialization since "
+            "TORCH_DIST_INIT_BARRIER = 1"
+        )
+        if backend == Backend.MPI:
+            # MPI backend doesn't use store.
+            barrier()
+        else:
+            # Use store based barrier here since barrier() used a bunch of
+            # default devices and messes up NCCL internal state.
+            _store_based_barrier(rank, store, group_name, world_size, timeout)
+
+def _get_split_source(pg):
+    split_from = None
+    if pg.bound_device_id:
+        split_from = pg._get_backend(pg.bound_device_id)
+    elif pg is _world.default_pg:
+        try:
+            split_from = pg._get_backend(torch.device("cuda"))
+        except RuntimeError:
+            # no cuda device associated with this backend
+            pass
+
+    if not split_from or not split_from.supports_splitting:
+        return None
+
+    # If necessary, find a backend to split from by peeling process
+    # group wrappers from our potentially wrapped process group.
+    while isinstance(split_from, _ProcessGroupWrapper):
+        split_from = split_from.wrapped_pg
+
+    return split_from
+
+def _shutdown_backend(pg):
+    """
+    Try to shut down the backend of a process group.
+    Currently, only ProcessGroupNCCL backend is supported.
+    No op for other backends.
+    """
+    backend = None
+    try:
+        backend = pg._get_backend(torch.device("cuda"))
+    except RuntimeError:
+        pass
+    if isinstance(backend, ProcessGroupNCCL):
+        # explicitly call shutdown to ensure that NCCL resources are released
+        backend._shutdown()
+
+def _new_process_group_helper(
+    group_size,
+    group_rank,
+    global_ranks_in_group,
+    backend,
+    store,
+    group_name,
+    pg_options=None,
+    timeout=None,
+    pg_tag=None,
+    device_id=None,
+):
+    """
+    Create a new distributed process group.
+
+    This function must be called by ALL processes in the global group, even if
+    the calling process is not part of the newly created group. In that case,
+    this function returns GroupMember.NON_GROUP_MEMBER.
+
+    This function is called with ``global_ranks_in_group == []`` for the default group.
+    """
+    global _world
+
+    if group_name in _world.pg_names.values():
+        raise ValueError(
+            "The specified group name has already been "
+            "created, please use a different group name"
+        )
+
+    if device_id is not None and (device_id.index is None or device_id.type != 'cuda'):
+        raise ValueError("init_process_group device_id parameter must be a cuda device with an "
+                         "id, e.g. cuda:0, not just cuda or cpu")
+
+    # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
+    _check_valid_timeout(timeout)
+
+    if pg_tag not in [None, ""]:
+        # creating with the same tag and rank set results in the same underlying PG
+        existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
+        if existing_group:
+            _, prefix_store = _world.pg_map[existing_group]
+            return existing_group, prefix_store
+
+    # The list of group ranks is empty if we're creating the default group.
+    is_default_group = len(global_ranks_in_group) == 0
+
+    # nccl and potentially other backends allow creation of
+    # communicators based on pre-existing ones, which can save
+    # initialization time.  Due to lazy initialization of
+    # communicators in some backends, we have to be careful and only
+    # split when we *know* the backends already are connected _on all
+    # ranks_.  We can only know this if the group we are making is the
+    # entire world or if we have bound a device id to the world (which
+    # causes early connection initialization).
+    if (is_initialized() and
+            (len(global_ranks_in_group) == _get_default_group().size() or _get_default_group().bound_device_id)):
+        split_from = _get_split_source(_get_default_group())
+    else:
+        split_from = None
+
+    # If this is a subgroup (which means group_ranks is specified),
+    # we check if the current process is a member of the new group.
+    if not is_default_group:
+        global_rank = _get_default_group().rank()
+        if global_rank not in global_ranks_in_group:
+            # If we are using `ncclCommSplit` (or similar split from
+            # other APIs) to create the communicator, we will need to
+            # call `ncclCommSplit` on *all* ranks in this new group's
+            # parent group, even those not in the new group.  This is
+            # a requirement of the NCCL API as otherwise we would get
+            # out of sync.
+            if split_from:
+                split_from.perform_nocolor_split(_get_default_group().bound_device_id)
+            return GroupMember.NON_GROUP_MEMBER, None
+
+    prefix_store = PrefixStore(f"{group_name}/", store)
+    base_pg_options = ProcessGroup.Options(backend=str(backend))
+    base_pg_options._timeout = timeout
+    pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options)
+    if device_id:
+        pg.bound_device_id = device_id
+    backend_config = BackendConfig(backend)
+    backend_class: torch._C._distributed_c10d.Backend
+    for device, backend_str in backend_config.get_device_backend_map().items():
+        # Use the group name as prefix in the default store, such that
+        # a single store can be reused by multiple groups.
+        backend_prefix_store = PrefixStore(f"{device}/", prefix_store)
+
+        if backend_str == Backend.MPI:
+            if not is_mpi_available():
+                raise RuntimeError(
+                    "Distributed package doesn't have MPI built in."
+                    " MPI is only included if you build PyTorch from"
+                    " source on a host that has MPI installed."
+                )
+            backend_class = ProcessGroupMPI.create(global_ranks_in_group)
+            backend_type = ProcessGroup.BackendType.MPI
+            if not backend_class:
+                return GroupMember.NON_GROUP_MEMBER, None
+            # create new process group with accurate rank and size
+            if pg.rank() == -1 and pg.size() == -1:
+                pg = ProcessGroup(backend_prefix_store, backend_class.rank(), backend_class.size(), base_pg_options)
+        elif backend_str == Backend.GLOO:
+            # TODO: remove this check after lazy initialization is supported
+            # if pg_options is not None:
+            #     raise RuntimeError("GLOO options not supported")
+            backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout)
+            backend_type = ProcessGroup.BackendType.GLOO
+        elif backend_str == Backend.NCCL:
+            if not is_nccl_available():
+                raise RuntimeError("Distributed package doesn't have NCCL built in")
+            if pg_options is not None:
+                assert isinstance(
+                    pg_options, ProcessGroupNCCL.Options
+                ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
+                if pg_options._timeout != timeout:
+                    warnings.warn(
+                        "pg_options._timeout was specified, "
+                        "but timeout kwarg has a default value that will always override it. "
+                    )
+            else:
+                # default pg_options for NCCL
+                pg_options = ProcessGroupNCCL.Options()
+                pg_options.is_high_priority_stream = False
+            pg_options._timeout = timeout
+
+            if split_from:
+                pg_options.split_from = split_from
+                pg_options.split_color = _process_group_color(global_ranks_in_group)
+            pg_options.global_ranks_in_group = global_ranks_in_group
+            backend_class = ProcessGroupNCCL(
+                backend_prefix_store, group_rank, group_size, pg_options)
+            backend_type = ProcessGroup.BackendType.NCCL
+        elif backend_str == Backend.UCC and is_ucc_available():
+            # TODO: once UCC plugin is fully deprecated, remove
+            # is_ucc_available() from above elif-condition and raise
+            # RuntimeError if is_ucc_available() returns false.
+
+            backend_class = ProcessGroupUCC(backend_prefix_store, group_rank, group_size, timeout=timeout)
+            backend_type = ProcessGroup.BackendType.UCC
+        else:
+            assert backend_str.upper() in Backend._plugins, (
+                f"Unknown c10d backend type {backend_str.upper()}"
+            )
+
+            backend_plugin = Backend._plugins[backend_str.upper()]
+            creator_fn = backend_plugin.creator_fn
+            extended_api = backend_plugin.extended_api
+            backend_type = ProcessGroup.BackendType.CUSTOM
+
+            if not extended_api:
+                backend_class = creator_fn(backend_prefix_store, group_rank, group_size, timeout)
+            else:
+                dist_backend_opts = _DistributedBackendOptions()
+                dist_backend_opts.store = backend_prefix_store
+                dist_backend_opts.group_rank = group_rank
+                dist_backend_opts.group_size = group_size
+                dist_backend_opts.timeout = timeout
+                dist_backend_opts.group_id = group_name
+                dist_backend_opts.global_ranks_in_group = global_ranks_in_group
+
+                backend_class = creator_fn(dist_backend_opts, pg_options)
+
+        # Set sequence numbers for gloo and nccl backends.
+        if backend_str == Backend.GLOO:
+            assert isinstance(backend_class, ProcessGroupGloo)
+            backend_class._set_sequence_number_for_group()
+        elif backend_str == Backend.NCCL:
+            assert isinstance(backend_class, ProcessGroupNCCL)
+            backend_class._set_sequence_number_for_group()
+
+        # If the type is a subclass of ProcessGroup then return this process group immediately
+        # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
+        # ProcessGroup instance
+        if issubclass(type(backend_class), ProcessGroup):
+            pg = backend_class  # type: ignore[assignment]
+            break
+
+        # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
+        if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]:
+            # In debug mode and if GLOO is available, wrap in a wrapper PG that
+            # enables enhanced collective checking for debuggability.
+            if get_debug_level() == DebugLevel.DETAIL:
+                if not _GLOO_AVAILABLE:
+                    logger.info(
+                        """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
+                                GLOO is not available. Build with Gloo to
+                                create a wrapper process group in debug mode
+                                to aid collective desynchronization debugging."""
+                    )
+                else:
+                    backend_class = _create_process_group_wrapper(
+                        wrapped_pg=backend_class,
+                        store_prefix=group_name,
+                        store=backend_prefix_store,
+                        rank=group_rank,
+                        world_size=group_size,
+                        timeout=timeout,
+                    )
+
+        # register only a single backend when all get_device_backend_map values are the same
+        if len(set(backend_config.get_device_backend_map().values())) == 1:
+            for device in backend_config.get_device_backend_map().keys():
+                pg._register_backend(torch.device(device), backend_type, backend_class)
+
+            # break out of outer loop to not create any more backends
+            break
+
+        pg._register_backend(torch.device(device), backend_type, backend_class)
+
+    if device_id and pg._get_backend(device_id).supports_splitting:
+        eager_backend = pg._get_backend(device_id)
+        eager_backend.eager_connect_single_device(device_id)
+
+    # update global state
+    assert group_name is not None
+    _world.pg_map[pg] = (backend, prefix_store)
+    _world.pg_names[pg] = group_name
+    pg._set_group_name(group_name)
+    _register_process_group(group_name, pg)
+
+    _world.pg_backend_config[pg] = str(backend_config)
+    # "" is the default tag for user PGs
+    if pg_tag in [None, ""]:
+        pg_tag = f"ptd:{group_name}"
+        _world.tags_to_pg.setdefault("", []).append(pg)
+    else:
+        pg_tag = f"user:{pg_tag}"
+
+    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
+    _world.pg_to_tag[pg] = pg_tag
+    return pg, prefix_store
+
+def destroy_process_group(group: Optional[ProcessGroup] = None):
+    """
+    Destroy a given process group, and deinitialize the distributed package.
+
+    Args:
+        group (ProcessGroup, optional): The process group to be destroyed, if
+                                        group.WORLD is given, all process
+                                        groups including the default one will
+                                        be destroyed.
+    """
+    global _world
+
+    if group == GroupMember.NON_GROUP_MEMBER:
+        return
+
+    if group is None:
+        pg = GroupMember.WORLD
+    else:
+        pg = group
+
+    assert pg is not None
+    if _world.pg_map.get(pg, None) is None:
+        raise ValueError("Invalid process group specified")
+
+    # When users register Python onCompletion hooks, those hooks will run on a
+    # different thread than the main thread. Today, the ProcessGroup dtor does
+    # wait for that thread. However, the dtor might finish after the Python
+    # Interpreter exits. After that grabbing the GIL for the Python hook will crash.
+    # We can either revive the interpreter when running hooks or keep the main one
+    # alive until all works and hooks are done. The current implementation does the
+    # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
+    # for the pending hooks to finish.
+    if pg.name().lower() == "nccl" and pg._has_hooks():
+        pg._wait_for_pending_works()
+
+    if group is None or group == GroupMember.WORLD:
+        if _abort_in_destroy_pg():
+            # shutdown all backends in the order of pg names. shutting down in order because
+            # ncclCommAbort() was a 'collective' call in some versions of NCCL.
+            for pg_to_shutdown in sorted(_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True):
+                _shutdown_backend(pg_to_shutdown)
+
+        _update_default_pg(None)
+        _world.pg_map.clear()
+        _world.pg_names.clear()
+        _world.pg_group_ranks.clear()
+        _world.pg_backend_config.clear()
+        _world.pg_to_tag.clear()
+        _world.tags_to_pg.clear()
+        _world.pg_coalesce_state.clear()
+        _world.pg_default_device.clear()
+        _unregister_all_process_groups()
+
+        # when process group doesn't have an explicit name (only WORLD (default)
+        # process group can have an explicit name), we use global _world.group_count
+        # to generate the name. We need to reset the counter on destruction to
+        # allow consistent value to be generated when we re-create process
+        # groups after some trainers recover from failure
+        #
+        # We only reset this when WORLD is being destroyed because if this
+        # process group is in good state, we aren't dealing with failures.
+        _world.group_count = 0
+    else:
+        if _abort_in_destroy_pg():
+            _shutdown_backend(pg)
+        del _world.pg_map[pg]
+        del _world.pg_names[pg]
+        del _world.pg_group_ranks[pg]
+        del _world.pg_backend_config[pg]
+        if pg in _world.pg_default_device:
+            del _world.pg_default_device[pg]
+        if pg in _world.pg_coalesce_state.keys():
+            warnings.warn(
+                "Some coalesced collectives haven't been launched when "
+                "ProcessGroup is destroyed. They will be cleaned."
+            )
+            del _world.pg_coalesce_state[pg]
+
+        tag = _world.pg_to_tag.get(pg)
+        del _world.pg_to_tag[pg]
+        if tag is not None:
+            try:
+                _world.tags_to_pg[tag].remove(pg)
+                if tag.startswith("ptd:"):
+                    _world.tags_to_pg[""].remove(pg)
+            except Exception:
+                pass
+        _unregister_process_group(pg.group_name)
+
+
+def get_rank(group: Optional[ProcessGroup] = None) -> int:
+    """
+    Return the rank of the current process in the provided ``group``, default otherwise.
+
+    Rank is a unique identifier assigned to each process within a distributed
+    process group. They are always consecutive integers ranging from 0 to
+    ``world_size - 1``.
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Returns:
+        The rank of the process group
+        -1, if not part of the group
+
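+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Illustrative sketch, assuming an initialized default group with
+        >>> # world_size=2, run on the first process.
+        >>> import torch.distributed as dist
+        >>> dist.get_rank()
+        0
+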
+    """
+    if _rank_not_in_group(group):
+        return -1
+
+    default_pg = _get_default_group()
+    if group is None or group is GroupMember.WORLD:
+        return default_pg.rank()
+
+    return get_group_rank(group, default_pg.rank())
+
+
+def get_world_size(group: Optional[ProcessGroup] = None) -> int:
+    """
+    Return the number of processes in the current process group.
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Returns:
+        The world size of the process group
+        -1, if not part of the group
+
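+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Illustrative sketch, assuming a job launched with 2 processes and
+        >>> # an initialized default process group.
+        >>> import torch.distributed as dist
+        >>> dist.get_world_size()
+        2
+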
+    """
+    if _rank_not_in_group(group):
+        return -1
+
+    return _get_group_size(group)
+
+
+def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]:
+    """
+    Send a tensor asynchronously.
+
+    .. warning::
+        Modifying ``tensor`` before the request completes causes undefined
+        behavior.
+
+    .. warning::
+        ``tag`` is not supported with the NCCL backend.
+
+    Args:
+        tensor (Tensor): Tensor to send.
+        dst (int): Destination rank on global process group (regardless of ``group`` argument)
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match send with remote recv
+
+    Returns:
+        A distributed request object.
+        None, if not part of the group
+
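+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks and an initialized default group;
+        >>> # ``rank`` is assumed to be provided by the launcher.
+        >>> tensor = torch.ones(2)
+        >>> if rank == 0:
+        >>>     req = dist.isend(tensor, dst=1)
+        >>> else:
+        >>>     req = dist.irecv(tensor, src=0)
+        >>> req.wait()  # block until the transfer has completed
+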
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("isend")
+        return None
+
+    if tensor.is_complex():
+        tensor = torch.view_as_real(tensor)
+
+    if group is None or group is GroupMember.WORLD:
+        pg = _get_default_group()
+    else:
+        pg = group
+        dst = get_group_rank(pg, dst)
+
+    return pg.send([tensor], dst, tag)
+
+def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]:
+    """
+    Receives a tensor asynchronously.
+
+    .. warning::
+        ``tag`` is not supported with the NCCL backend.
+
+    Args:
+        tensor (Tensor): Tensor to fill with received data.
+        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
+            Will receive from any process if unspecified.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match recv with remote send
+
+    Returns:
+        A distributed request object.
+        None, if not part of the group
+
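+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks where rank 0 posts a matching
+        >>> # ``isend``; ``rank`` is assumed to be provided by the launcher.
+        >>> buf = torch.zeros(2)
+        >>> if rank == 1:
+        >>>     req = dist.irecv(buf, src=0)
+        >>>     req.wait()
+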
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("irecv")
+        return None
+
+    if tensor.is_complex():
+        tensor = torch.view_as_real(tensor)
+
+    if group is None or group is GroupMember.WORLD:
+        pg = _get_default_group()
+    else:
+        pg = group
+
+    if src is None:
+        return pg.recv_anysource([tensor], tag)
+    else:
+        if pg is GroupMember.WORLD:
+            return pg.recv([tensor], src, tag)
+        else:
+            group_src_rank = get_group_rank(pg, src)
+            return pg.recv([tensor], group_src_rank, tag)
+
+@_exception_logger
+def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None:
+    """
+    Send a tensor synchronously.
+
+    Args:
+        tensor (Tensor): Tensor to send.
+        dst (int): Destination rank on global process group (regardless of ``group`` argument).
+            Destination rank should not be the same as the rank of the current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match send with remote recv
+
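+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks with an initialized default group;
+        >>> # ``rank`` is assumed to come from the launcher.
+        >>> tensor = torch.arange(4, dtype=torch.float32)
+        >>> if rank == 0:
+        >>>     dist.send(tensor, dst=1)
+        >>> else:
+        >>>     dist.recv(tensor, src=0)
+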
+    """
+    if get_rank() == dst:
+        raise ValueError(
+            "Invalid destination rank: destination rank should not be the same as "
+            "the rank of the current process."
+        )
+
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("send")
+        return None
+
+    if tensor.is_complex():
+        tensor = torch.view_as_real(tensor)
+
+    if group is None or group is GroupMember.WORLD:
+        default_pg = _get_default_group()
+        default_pg.send([tensor], dst, tag).wait()
+    else:
+        group_dst_rank = get_group_rank(group, dst)
+        group.send([tensor], group_dst_rank, tag).wait()
+
+@_exception_logger
+def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> int:
+    """
+    Receives a tensor synchronously.
+
+    Args:
+        tensor (Tensor): Tensor to fill with received data.
+        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
+            Will receive from any process if unspecified.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        tag (int, optional): Tag to match recv with remote send
+
+    Returns:
+        Sender rank
+        -1, if not part of the group
+
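+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming another rank has posted a matching send;
+        >>> # omitting ``src`` accepts a message from any sender and the return
+        >>> # value identifies that sender.
+        >>> buf = torch.zeros(4)
+        >>> sender = dist.recv(buf)
+        >>> sender
+        0  # e.g. when rank 0 sent the tensor
+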
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("recv")
+        return -1
+
+    if tensor.is_complex():
+        tensor = torch.view_as_real(tensor)
+
+    if group is None:
+        pg = _get_default_group()
+    else:
+        pg = group
+
+    if src is None:
+        work = pg.recv_anysource([tensor], tag)
+        work.wait()
+        src_rank = work._source_rank()
+        if group is None or group is GroupMember.WORLD:
+            return src_rank
+        else:
+            return get_global_rank(pg, src_rank)
+    else:
+        if group is None or group is GroupMember.WORLD:
+            pg.recv([tensor], src, tag).wait()
+        else:
+            group_src_rank = get_group_rank(pg, src)
+            pg.recv([tensor], group_src_rank, tag).wait()
+        return src
+
+
+class _IllegalWork(Work):
+    def __getattribute__(self, name):
+        if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]:
+            raise ValueError(f"Illegal to call {name} on IllegalWork object")
+
+
+class _CoalescingManager:
+    def __init__(self):
+        self.works: List[Work] = []
+
+    def append(self, work: Work):
+        if work:
+            self.works.append(work)
+
+    def wait(self):
+        for work in self.works:
+            work.wait()
+
+
+@contextlib.contextmanager
+def _coalescing_manager(
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    async_ops: Optional[bool] = False,
+):
+    """
+    Context manager used to coalesce collectives or P2P operations when possible.
+
+    Args:
+        group (`ProcessGroup`, optional): The process group to work on. If None,
+            the default process group will be used.
+        device (`torch.device`, optional): Default is None, set to a device if
+            there isn't a `**_coalesced` implementation by the backend.
+        async_ops (`bool`, optional): whether the coalesced ops are async ops.
+
+    Examples:
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Synchronous ops
+        >>> with _coalescing_manager():
+        >>>     for i in range(num_colls):
+        >>>         dist.all_reduce(tensors[i])
+        >>> # Asynchronous ops
+        >>> with _coalescing_manager(async_ops=True) as cm:
+        >>>     for i in range(num_colls):
+        >>>         dist.all_reduce(tensors[i])
+        >>> cm.wait()
+
+    .. warning::
+       :func:`_coalescing_manager` currently does not support coalescing
+       all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed
+       with `ReduceOp.PRODUCT`.
+    """
+    group = group or _get_default_group()
+    op_list = _world.pg_coalesce_state.setdefault(group, [])
+    if op_list:
+        raise ValueError("ProcessGroup has non-empty op list at the start of coalescing")
+    if device:
+        group._start_coalescing(device)
+    cm = _CoalescingManager()
+    yield cm
+    op_list = _world.pg_coalesce_state.pop(group)
+    if op_list:
+        # Collectives supporting "Fast Path" coalescing are captured.
+        # See implementation in corresponding collective APIs.
+        # Currently supported:
+        # - coalesced `all_reduce`
+        # - coalesced `all_gather_into_tensor`
+        # - coalesced `reduce_scatter_tensor`
+        op0 = op_list[0].op
+        if op0 == all_reduce:
+            tensors = []
+            for op in op_list:
+                tensors.append(op.tensor)
+            all_reduce_opts = AllreduceCoalescedOptions()
+            all_reduce_opts.reduceOp = not_none(op_list[0].redop)
+            work = group.allreduce_coalesced(tensors, all_reduce_opts)
+        elif op0 == all_gather_into_tensor:
+            inputs = []
+            outputs = []
+            for op in op_list:
+                inputs.append(op.tensor)
+                outputs.append(not_none(op.dst_tensor))
+            work = group.allgather_into_tensor_coalesced(outputs, inputs)
+        elif op0 == reduce_scatter_tensor:
+            inputs = []
+            outputs = []
+            for op in op_list:
+                inputs.append(op.tensor)
+                outputs.append(not_none(op.dst_tensor))
+            reduce_opts = ReduceScatterOptions()
+            reduce_opts.reduceOp = not_none(op_list[0].redop)
+            work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
+        else:
+            raise AssertionError(
+                f"Coalescing manager does not support fast-path coalescing of {op0}, "
+                f"yet {op0} is still recorded in op list. This is an internal error of c10d."
+            )
+
+    if device:
+        # Old style: let each collective inside the context manager call into its C++ counterpart via the Python binding
+        work = group._end_coalescing(device)
+
+    if async_ops:
+        cm.append(work)  # type: ignore[possibly-undefined]
+    else:
+        work.wait()  # type: ignore[possibly-undefined]
+
+
+def batch_isend_irecv(p2p_op_list):
+    """
+    Send or Receive a batch of tensors asynchronously and return a list of requests.
+
+    Process each of the operations in ``p2p_op_list`` and return the corresponding
+    requests. The NCCL, Gloo, and UCC backends are currently supported.
+
+    Args:
+        p2p_op_list: A list of point-to-point operations (the type of each operator is
+            ``torch.distributed.P2POp``). The order of the isend/irecv in the list
+            matters and it needs to match with corresponding isend/irecv on the
+            remote end.
+
+    Returns:
+        A list of distributed request objects returned by calling the corresponding
+        op in the op_list.
+
+    Examples:
+        >>> # xdoctest: +SKIP("no rank")
+        >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
+        >>> recv_tensor = torch.randn(2, dtype=torch.float32)
+        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
+        >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
+        >>> reqs = batch_isend_irecv([send_op, recv_op])
+        >>> for req in reqs:
+        >>>     req.wait()
+        >>> recv_tensor
+        tensor([2, 3])     # Rank 0
+        tensor([0, 1])     # Rank 1
+
+    .. note:: Note that when this API is used with the NCCL PG backend, users must set
+        the current GPU device with `torch.cuda.set_device`, otherwise it will
+        lead to unexpected hang issues.
+
+        In addition, if this API is the first collective call in the ``group``
+        passed to ``dist.P2POp``, all ranks of the ``group`` must participate in
+        this API call; otherwise, the behavior is undefined. If this API call is
+        not the first collective call in the ``group``, batched P2P operations
+        involving only a subset of ranks of the ``group`` are allowed.
+    """
+    _check_p2p_op_list(p2p_op_list)
+    group = p2p_op_list[0].group
+    device = p2p_op_list[0].tensor.device
+    if device.type == "cuda":
+        # NCCL style coalescing
+        with _coalescing_manager(group, device, async_ops=True) as cm:
+            for p2p_op in p2p_op_list:
+                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+        return cm.works
+    else:
+        # Backward support for Gloo
+        reqs = []
+        for p2p_op in p2p_op_list:
+            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            if work:
+                reqs.append(work)
+        return reqs
+
+
+@_exception_logger
+def broadcast(tensor, src, group=None, async_op=False):
+    """
+    Broadcasts the tensor to the whole group.
+
+    ``tensor`` must have the same number of elements in all processes
+    participating in the collective.
+
+    Args:
+        tensor (Tensor): Data to be sent if ``src`` is the rank of current
+            process, and tensor to be used to save received data otherwise.
+        src (int): Source rank on global process group (regardless of ``group`` argument).
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
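+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks; ``rank`` is assumed to come from
+        >>> # the launcher. After the call every rank holds rank 0's data.
+        >>> if rank == 0:
+        >>>     tensor = torch.tensor([1.0, 2.0])
+        >>> else:
+        >>>     tensor = torch.zeros(2)
+        >>> dist.broadcast(tensor, src=0)
+        >>> tensor
+        tensor([1., 2.])  # on every rank
+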
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast")
+        return
+
+    opts = BroadcastOptions()
+    opts.rootRank = src
+    opts.rootTensor = 0
+    opts.asyncOp = async_op
+
+    if group is None or group is GroupMember.WORLD:
+        default_pg = _get_default_group()
+        work = default_pg.broadcast([tensor], opts)
+    else:
+        group_src_rank = get_group_rank(group, src)
+        opts.rootRank = group_src_rank
+        work = group.broadcast([tensor], opts)
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+@_exception_logger
+def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    Reduces the tensor data across all machines in a way that all get the final result.
+
+    After the call ``tensor`` is going to be bitwise identical in all processes.
+
+    Complex tensors are supported.
+
+    Args:
+        tensor (Tensor): Input and output of the collective. The function
+            operates in-place.
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    Examples:
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # All tensors below are of torch.int64 type.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
+        >>> tensor
+        tensor([1, 2], device='cuda:0') # Rank 0
+        tensor([3, 4], device='cuda:1') # Rank 1
+        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
+        >>> tensor
+        tensor([4, 6], device='cuda:0') # Rank 0
+        tensor([4, 6], device='cuda:1') # Rank 1
+
+        >>> # All tensors below are of torch.cfloat type.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
+        >>> tensor
+        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
+        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
+        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
+        >>> tensor
+        tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
+        tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1
+
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce")
+        return
+
+    if tensor.is_complex():
+        if not supports_complex(op):
+            raise ValueError(f"all_reduce does not support {op} on complex tensors")
+        tensor = torch.view_as_real(tensor)
+
+    opts = AllreduceOptions()
+    opts.reduceOp = op
+    if group is None:
+        group = _get_default_group()
+
+    if group in _world.pg_coalesce_state.keys():
+        # We are in coalescing context, do not issue single operation, just append a collective representation
+        coll = _CollOp(all_reduce, tensor, None, op, None)
+        _world.pg_coalesce_state[group].append(coll)
+        if async_op:
+            return _IllegalWork()
+        else:
+            return None
+
+    work = group.allreduce([tensor], opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+@_exception_logger
+def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    WARNING: at this time individual shape checking is not implemented across nodes.
+
+    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
+    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
+    operation will proceed without complaint and return erroneous outputs. This lack
+    of shape checking results in significant performance improvements but users of this
+    function should take extra care to ensure that each node passes in tensors whose
+    shapes match across nodes.
+
+    Reduces each tensor in tensors (residing on the same device) across all machines
+    in such a way that all get the final result.
+
+    After the call each tensor in tensors is going to be bitwise identical
+    in all processes.
+
+    Complex tensors are supported.
+
+    Args:
+        tensors (Union[List[Tensor], Tensor]): Input and output of the collective.
+            The function operates in-place.
+        op (Optional[ReduceOp]): One of the values from
+            ``torch.distributed.ReduceOp`` enum. Specifies an operation used for
+            element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (Optional[bool]): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
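+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks; both tensors are reduced in one
+        >>> # coalesced call. Shapes must match across ranks (see warning above).
+        >>> tensors = [torch.ones(2) * (rank + 1), torch.ones(3) * (rank + 1)]
+        >>> dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM)
+        >>> tensors[0]
+        tensor([3., 3.])  # on every rank
+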
+    """
+    warnings.warn(
+        "torch.distributed.all_reduce_coalesced will be deprecated. If you must "
+        "use it, please revisit our documentation later at "
+        "https://pytorch.org/docs/master/distributed.html#collective-functions"
+    )
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    _check_tensor_list(tensors, "tensor")
+    _ensure_all_tensors_same_dtype(tensors)
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce_coalesced")
+        return
+
+    if any(t.is_complex() for t in tensors) and not supports_complex(op):
+        raise ValueError(f"all_reduce does not support {op} on complex tensors")
+
+    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
+
+    opts = AllreduceCoalescedOptions()
+    opts.reduceOp = op
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.allreduce_coalesced(tensors, opts)
+    else:
+        work = group.allreduce_coalesced(tensors, opts)
+
+    if async_op:
+        return work.get_future()
+    else:
+        work.wait()
+
+@_exception_logger
+def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    Reduces the tensor data across all machines.
+
+    Only the process with rank ``dst`` is going to receive the final result.
+
+    Args:
+        tensor (Tensor): Input and output of the collective. The function
+            operates in-place.
+        dst (int): Destination rank on global process group (regardless of ``group`` argument)
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
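+    Example::
+        >>> # xdoctest: +SKIP("no rank")
+        >>> # Minimal sketch, assuming 2 ranks; per the semantics above, only
+        >>> # rank ``dst`` is guaranteed to receive the reduced result.
+        >>> tensor = torch.ones(2) * (rank + 1)
+        >>> dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
+        >>> tensor  # on rank 0
+        tensor([3., 3.])
+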
+    """
+    _check_single_tensor(tensor, "tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("reduce")
+        return
+
+    opts = ReduceOptions()
+    opts.reduceOp = op
+    opts.rootRank = dst
+
+    if group is None or group is GroupMember.WORLD:
+        default_pg = _get_default_group()
+        work = default_pg.reduce([tensor], opts)
+    else:
+        group_dst_rank = get_group_rank(group, dst)
+        opts.rootRank = group_dst_rank
+        work = group.reduce([tensor], opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+def _object_to_tensor(obj, device, group):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause a 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
+        backend = get_backend(group)
+        if backend == Backend.NCCL:
+            hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
+            logger.warning(f"_object_to_tensor size: {byte_tensor.numel()} hash value: {hash}")  # noqa: G004
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor, tensor_size, group):
+    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
+        backend = get_backend(group)
+        if backend == Backend.NCCL:
+            hash = torch._C._distributed_c10d._hash_tensors([tensor])
+            logger.warning(f"_tensor_to_object size: {tensor.numel()} hash value: {hash}")  # noqa: G004
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+@_exception_logger
+def all_gather_object(object_list, obj, group=None):
+    """
+    Gathers picklable objects from the whole group into a list.
+
+    Similar to :func:`all_gather`, but Python objects can be passed in.
+    Note that the object must be picklable in order to be gathered.
+
+    Args:
+        object_list (list[Any]): Output list. It should be correctly sized as the
+            size of the group for this collective and will contain the output.
+        obj (Any): Picklable Python object to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. If the calling rank is part of this group, the output of the
+        collective will be populated into the input ``object_list``. If the
+        calling rank is not part of the group, the passed in ``object_list`` will
+        be unmodified.
+
+    .. note:: Note that this API differs slightly from the :func:`all_gather`
+        collective since it does not provide an ``async_op`` handle and thus
+        will be a blocking call.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. warning::
+        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    .. warning::
+        Calling :func:`all_gather_object` with GPU tensors is not well supported
+        and inefficient as it incurs GPU -> CPU transfer since tensors would be
+        pickled. Please consider using :func:`all_gather` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_object")
+        return
+
+    current_device = _get_pg_default_device(group)
+    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
+
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(
+        group_size, dtype=torch.long, device=current_device
+    )
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes
+    all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * group_size, dtype=torch.uint8, device=current_device
+    )
+    # Output tensors are nonoverlapping views of coalesced_output_tensor
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(group_size)
+    ]
+    all_gather(output_tensors, input_tensor, group=group)
+    # Deserialize outputs back to object.
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size, group)
+
+
+@_exception_logger
+def gather_object(obj, object_gather_list=None, dst=0, group=None):
+    """
+    Gathers picklable objects from the whole group in a single process.
+
+    Similar to :func:`gather`, but Python objects can be passed in. Note that the
+    object must be picklable in order to be gathered.
+
+    Args:
+        obj (Any): Input object. Must be picklable.
+        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
+            should be correctly sized as the size of the group for this
+            collective and will contain the output. Must be ``None`` on non-dst
+            ranks. (default is ``None``)
+        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. On the ``dst`` rank, ``object_gather_list`` will contain the
+        output of the collective.
+
+    .. note:: Note that this API differs slightly from the gather collective
+        since it does not provide an async_op handle and thus will be a blocking
+        call.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. warning::
+        :func:`gather_object` uses ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    .. warning::
+        Calling :func:`gather_object` with GPU tensors is not well supported
+        and inefficient as it incurs GPU -> CPU transfer since tensors would be
+        pickled. Please consider using :func:`gather` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.gather_object(
+        ...     gather_objects[dist.get_rank()],
+        ...     output if dist.get_rank() == 0 else None,
+        ...     dst=0
+        ... )
+        >>> # On rank 0
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("gather_object")
+        return
+
+    # Ensure object_gather_list is specified appropriately.
+    my_rank = get_rank()
+    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
+    current_device = _get_pg_default_device(group)
+    input_tensor, local_size = _object_to_tensor(obj, current_device, group)
+
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(
+        group_size, dtype=torch.long, device=current_device
+    )
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes. An all-gather is needed here despite this being a
+    # gather, since each rank needs to broadcast a tensor of the same (maximal)
+    # size.
+    all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
+    input_tensor.resize_(max_object_size)
+    # Avoid populating output tensors if the result won't be gathered on this rank.
+    if my_rank == dst:
+        coalesced_output_tensor = torch.empty(
+            max_object_size * group_size, dtype=torch.uint8, device=current_device
+        )
+        # Output tensors are nonoverlapping views of coalesced_output_tensor
+        output_tensors = [
+            coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+            for i in range(group_size)
+        ]
+    # All ranks call gather with equal-sized tensors.
+    gather(
+        input_tensor,
+        gather_list=output_tensors if my_rank == dst else None,  # type: ignore[possibly-undefined]
+        dst=dst,
+        group=group,
+    )
+    if my_rank != dst:
+        return
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group)
+
+
+@_exception_logger
+def broadcast_object_list(object_list, src=0, group=None, device=None):
+    """
+    Broadcasts picklable objects in ``object_list`` to the whole group.
+
+    Similar to :func:`broadcast`, but Python objects can be passed in.
+    Note that all objects in ``object_list`` must be picklable in order to be
+    broadcasted.
+
+    Args:
+        object_list (List[Any]): List of input objects to broadcast.
+            Each object must be picklable. Only objects on the ``src`` rank will
+            be broadcast, but each rank must provide lists of equal sizes.
+        src (int): Source rank from which to broadcast ``object_list``.
+            Source rank is based on global process group (regardless of ``group`` argument)
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, the objects are
+            serialized and converted to tensors which are moved to the
+            ``device`` before broadcasting. Default is ``None``.
+
+    Returns:
+        ``None``. If rank is part of the group, ``object_list`` will contain the
+        broadcasted objects from ``src`` rank.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. note:: Note that this API differs slightly from the :func:`all_gather`
+        collective since it does not provide an ``async_op`` handle and thus
+        will be a blocking call.
+
+    .. warning::
+        :func:`broadcast_object_list` uses ``pickle`` module implicitly, which
+        is known to be insecure. It is possible to construct malicious pickle
+        data which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    .. warning::
+        Calling :func:`broadcast_object_list` with GPU tensors is not well supported
+        and inefficient as it incurs GPU -> CPU transfer since tensors would be
+        pickled. Please consider using :func:`broadcast` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 3.
+        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> else:
+        >>>     objects = [None, None, None]
+        >>> # Assumes backend is not NCCL
+        >>> device = torch.device("cpu")
+        >>> dist.broadcast_object_list(objects, src=0, device=device)
+        >>> objects
+        ['foo', 12, {1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast_object_list")
+        return
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` defaults to ``None``,
+    # in which case we run the existing device-selection logic, i.e.
+    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU device.
+    # If ``device`` is not ``None``, the size and object tensors to be
+    # broadcasted are moved to that device.
+    current_device = device or _get_pg_default_device(group)
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
+    # has only one element, we can skip the copy.
+    if my_rank == src:
+        if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
+            object_tensor = tensor_list[0]
+        else:
+            object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device=current_device
+        )
+
+    broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size, group)
+
+
+@_exception_logger
+def scatter_object_list(
+    scatter_object_output_list, scatter_object_input_list, src=0, group=None
+):
+    """
+    Scatters picklable objects in ``scatter_object_input_list`` to the whole group.
+
+    Similar to :func:`scatter`, but Python objects can be passed in. On
+    each rank, the scattered object will be stored as the first element of
+    ``scatter_object_output_list``. Note that all objects in
+    ``scatter_object_input_list`` must be picklable in order to be scattered.
+
+    Args:
+        scatter_object_output_list (List[Any]): Non-empty list whose first
+            element will store the object scattered to this rank.
+        scatter_object_input_list (List[Any]): List of input objects to scatter.
+            Each object must be picklable. Only objects on the ``src`` rank will
+            be scattered, and the argument can be ``None`` for non-src ranks.
+        src (int): Source rank from which to scatter ``scatter_object_input_list``.
+            Source rank is based on global process group (regardless of ``group`` argument).
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        ``None``. If rank is part of the group, ``scatter_object_output_list``
+        will have its first element set to the scattered object for this rank.
+
+    .. note:: Note that this API differs slightly from the scatter collective
+        since it does not provide an ``async_op`` handle and thus will be a
+        blocking call.
+
+    .. warning::
+        :func:`scatter_object_list` uses ``pickle`` module implicitly, which
+        is known to be insecure. It is possible to construct malicious pickle
+        data which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    .. warning::
+        Calling :func:`scatter_object_list` with GPU tensors is not well supported
+        and inefficient as it incurs GPU -> CPU transfer since tensors would be
+        pickled. Please consider using :func:`scatter` instead.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 3.
+        >>>     objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> else:
+        >>>     # Can be any list on non-src ranks, elements are not used.
+        >>>     objects = [None, None, None]
+        >>> output_list = [None]
+        >>> dist.scatter_object_list(output_list, objects, src=0)
+        >>> # Rank i gets objects[i]. For example, on rank 2:
+        >>> output_list
+        [{1: 2}]
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("scatter_object_list")
+        return
+
+    if (
+        not isinstance(scatter_object_output_list, list)
+        or len(scatter_object_output_list) < 1
+    ):
+        raise ValueError(
+            "Expected argument scatter_object_output_list to be a list of size at least 1."
+        )
+
+    my_rank = get_rank()
+    pg_device = _get_pg_default_device(group)
+    if my_rank == src:
+        tensor_list, tensor_sizes = zip(
+            *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list]
+        )
+        tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
+
+    # Src rank broadcasts the maximum tensor size. This is because all ranks are
+    # expected to call into scatter() with equal-sized tensors.
+    if my_rank == src:
+        max_tensor_size = max(tensor_sizes)  # type: ignore[possibly-undefined]
+        for tensor in tensor_list:  # type: ignore[possibly-undefined]
+            tensor.resize_(max_tensor_size)
+    else:
+        max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
+    broadcast(max_tensor_size, src=src, group=group)
+
+    # Scatter actual serialized objects
+    output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device)
+    scatter(
+        output_tensor,
+        scatter_list=None if my_rank != src else tensor_list,  # type: ignore[possibly-undefined]
+        src=src,
+        group=group,
+    )
+
+    # Scatter per-object sizes to trim tensors when deserializing back to object
+    obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
+    scatter(
+        obj_tensor_size,
+        scatter_list=None if my_rank != src else tensor_sizes,  # type: ignore[possibly-undefined]
+        src=src,
+        group=group,
+    )
+
+    # Deserialize back to object
+    scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group)
+
+
+@_exception_logger
+def all_gather(tensor_list, tensor, group=None, async_op=False):
+    """
+    Gathers tensors from the whole group in a list.
+
+    Complex tensors are supported.
+
+    Args:
+        tensor_list (list[Tensor]): Output list. It should contain
+            correctly-sized tensors to be used for output of the collective.
+        tensor (Tensor): Tensor to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    Examples:
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # All tensors below are of torch.int64 dtype.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)]
+        >>> tensor_list
+        [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
+        [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
+        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
+        >>> tensor
+        tensor([1, 2], device='cuda:0') # Rank 0
+        tensor([3, 4], device='cuda:1') # Rank 1
+        >>> dist.all_gather(tensor_list, tensor)
+        >>> tensor_list
+        [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
+        [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1
+
+        >>> # All tensors below are of torch.cfloat dtype.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)]
+        >>> tensor_list
+        [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
+        [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
+        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
+        >>> tensor
+        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
+        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
+        >>> dist.all_gather(tensor_list, tensor)
+        >>> tensor_list
+        [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
+        [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1
+
+    """
+    _check_tensor_list(tensor_list, "tensor_list")
+    _check_single_tensor(tensor, "tensor")
+    _ensure_all_tensors_same_dtype(tensor_list, tensor)
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather")
+        return
+
+    tensor_list = [
+        t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list
+    ]
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.allgather([tensor_list], [tensor])
+    else:
+        work = group.allgather([tensor_list], [tensor])
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
+    """
+    Gather tensors from all ranks and put them in a single output tensor.
+
+    Args:
+        output_tensor (Tensor): Output tensor to accommodate tensor elements
+            from all ranks. It must be correctly sized to have one of the
+            following forms:
+            (i) a concatenation of all the input tensors along the primary
+            dimension; for definition of "concatenation", see ``torch.cat()``;
+            (ii) a stack of all the input tensors along the primary dimension;
+            for definition of "stack", see ``torch.stack()``.
+            Examples below may better explain the supported output forms.
+        input_tensor (Tensor): Tensor to be gathered from current rank.
+            Different from the ``all_gather`` API, the input tensors in this
+            API must have the same size across all ranks.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    Examples:
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
+        >>> # We have two ranks.
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
+        >>> tensor_in
+        tensor([1, 2], device='cuda:0') # Rank 0
+        tensor([3, 4], device='cuda:1') # Rank 1
+        >>> # Output in concatenation form
+        >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
+        >>> dist.all_gather_into_tensor(tensor_out, tensor_in)
+        >>> tensor_out
+        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
+        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
+        >>> # Output in stack form
+        >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
+        >>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
+        >>> tensor_out2
+        tensor([[1, 2],
+                [3, 4]], device='cuda:0') # Rank 0
+        tensor([[1, 2],
+                [3, 4]], device='cuda:1') # Rank 1
+
+    .. warning::
+        The Gloo backend does not support this API.
+
+    """
+    _check_single_tensor(input_tensor, "input_tensor")
+    _check_single_tensor(output_tensor, "output_tensor")
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_into_tensor")
+        return
+
+    output_tensor = (
+        output_tensor
+        if not output_tensor.is_complex()
+        else torch.view_as_real(output_tensor)
+    )
+    input_tensor = (
+        input_tensor
+        if not input_tensor.is_complex()
+        else torch.view_as_real(input_tensor)
+    )
+
+    opts = AllgatherOptions()
+    opts.asyncOp = async_op
+
+    group = group or _get_default_group()
+
+    if group in _world.pg_coalesce_state.keys():
+        # We are in a coalescing context; do not issue a single operation, just append a collective representation
+        coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor)
+        _world.pg_coalesce_state[group].append(coll)
+        if async_op:
+            return _IllegalWork()
+        else:
+            return None
+
+    work = group._allgather_base(output_tensor, input_tensor, opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
+    """
+    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
+
+    Args:
+        output_tensor (Tensor): Output tensor. It should contain
+            correctly-sized tensors to be used for output of the collective.
+        input_tensor (Tensor): Tensor to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    .. warning::
+        `_all_gather_base` is a private function. Users should use
+        `all_gather_into_tensor` instead.
+
+    """
+    warnings.warn(
+        "torch.distributed._all_gather_base is a private function and will be "
+        "deprecated. Please use torch.distributed.all_gather_into_tensor "
+        "instead."
+    )
+    return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)
+
+
+@_exception_logger
+def all_gather_coalesced(
+    output_tensor_lists, input_tensor_list, group=None, async_op=False
+):
+    """
+    Gathers input tensors from the whole group in a list in a coalesced manner.
+
+    Complex tensors are supported.
+
+    Args:
+        output_tensor_lists (list[list[Tensor]]): Output list. It should contain
+            correctly-sized tensors to be used for output of the collective.
+        input_tensor_list (list[Tensor]): Tensors to be broadcast from
+            current process. At least one tensor has to be non-empty.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    Example:
+        we have 2 process groups, 2 ranks.
+        rank 0 passes:
+            input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]]
+            output_tensor_lists =
+               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
+                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
+        rank 1 passes:
+            input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]]
+            output_tensor_lists =
+               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
+                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
+        both rank 0 and 1 get:
+            output_tensor_lists =
+               [[[[1, 1], [1, 1]], [2], [3, 3]],
+                [[[3, 3], [3, 3]], [5], [1, 1]]].
+
+    WARNING: at this time individual shape checking is not implemented across nodes.
+    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
+    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the
+    all_gather_coalesced operation will proceed without complaint and return
+    erroneous outputs. This lack of shape checking results in significant
+    performance improvements but users of this function should take extra care
+    to ensure that each node passes in tensors whose shapes match across nodes.
+    """
+    warnings.warn(
+        "torch.distributed.all_gather_coalesced will be deprecated. If you must "
+        "use it, please revisit our documentation later at "
+        "https://pytorch.org/docs/master/distributed.html#collective-functions"
+    )
+    # We only check basic compatibility with C++ params here, C++ code will
+    # do shape and type checking.
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_coalesced")
+        return
+    _check_tensor_list(input_tensor_list, "input_tensor_list")
+    _ensure_all_tensors_same_dtype(input_tensor_list)
+    if not isinstance(output_tensor_lists, list):
+        raise TypeError(
+            "Invalid function argument: output_tensor_lists should be a list"
+        )
+    for output_tensor_list in output_tensor_lists:
+        _check_tensor_list(output_tensor_list, "output_tensor_lists")
+        _ensure_all_tensors_same_dtype(output_tensor_list)
+
+    output_tensor_lists = [
+        [t if not t.is_complex() else torch.view_as_real(t) for t in l]
+        for l in output_tensor_lists
+    ]
+    input_tensor_list = [
+        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
+    ]
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.allgather_coalesced(output_tensor_lists, input_tensor_list)
+    else:
+        work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
+
+    if async_op:
+        return work.get_future()
+    else:
+        work.wait()
+
+
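+# Helper shared by ``gather`` and ``gather_object``: the output list must be
+# supplied on the destination rank and must be left unspecified (None or empty)
+# on every other rank.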
+def _validate_output_list_for_rank(my_rank, dst, gather_list):
+    if dst == my_rank:
+        if not gather_list:
+            raise ValueError(
+                "Argument ``gather_list`` must be specified on destination rank."
+            )
+    elif gather_list:
+        raise ValueError(
+            "Argument ``gather_list`` must NOT be specified "
+            "on non-destination ranks."
+        )
+
+
+@_exception_logger
+def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
+    """
+    Gathers a list of tensors in a single process.
+
+    Args:
+        tensor (Tensor): Input tensor.
+        gather_list (list[Tensor], optional): List of appropriately-sized
+            tensors to use for gathered data (default is None, must be specified
+            on the destination rank)
+        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
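+    A minimal illustrative sketch (assuming a world size of 2 and an already
+    initialized default process group; the tensor values below are only an
+    example):
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> tensor_size = 2
+        >>> t = torch.ones(tensor_size) * (dist.get_rank() + 1)
+        >>> # gather_list is only required on the destination rank.
+        >>> gather_list = [torch.zeros(tensor_size) for _ in range(2)] if dist.get_rank() == 0 else None
+        >>> dist.gather(t, gather_list, dst=0)
+        >>> # On rank 0
+        >>> gather_list
+        [tensor([1., 1.]), tensor([2., 2.])]
+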
+    """
+    _check_single_tensor(tensor, "tensor")
+
+    # Parameter ``gather_list`` may be left unspecified on non-dst ranks.
+    if gather_list:
+        _check_tensor_list(gather_list, "gather_list")
+    else:
+        gather_list = []
+    _ensure_all_tensors_same_dtype(tensor, gather_list)
+
+    if _rank_not_in_group(group):
+        _warn_not_in_group("gather")
+        return
+
+    my_rank = get_rank()
+    _validate_output_list_for_rank(my_rank, dst, gather_list)
+    output_tensors = [gather_list] if dst == my_rank else []
+    input_tensors = [tensor]
+
+    opts = GatherOptions()
+    opts.rootRank = dst
+
+    if group is None or group is GroupMember.WORLD:
+        default_pg = _get_default_group()
+        work = default_pg.gather(output_tensors, input_tensors, opts)
+    else:
+        group_dst_rank = get_group_rank(group, dst)
+        opts.rootRank = group_dst_rank
+        work = group.gather(output_tensors, input_tensors, opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
+    """
+    Scatters a list of tensors to all processes in a group.
+
+    Each process will receive exactly one tensor and store its data in the
+    ``tensor`` argument.
+
+    Complex tensors are supported.
+
+    Args:
+        tensor (Tensor): Output tensor.
+        scatter_list (list[Tensor]): List of tensors to scatter (default is
+            None, must be specified on the source rank)
+        src (int): Source rank on global process group (regardless of ``group`` argument).
+            Default is 0
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    .. note:: Note that all Tensors in scatter_list must have the same size.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> tensor_size = 2
+        >>> t_ones = torch.ones(tensor_size)
+        >>> t_fives = torch.ones(tensor_size) * 5
+        >>> output_tensor = torch.zeros(tensor_size)
+        >>> if dist.get_rank() == 0:
+        >>>     # Assumes world_size of 2.
+        >>>     # Only tensors, all of which must be the same size.
+        >>>     scatter_list = [t_ones, t_fives]
+        >>> else:
+        >>>     scatter_list = None
+        >>> dist.scatter(output_tensor, scatter_list, src=0)
+        >>> # Rank i gets scatter_list[i]. For example, on rank 1:
+        >>> output_tensor
+        tensor([5., 5.])
+
+    """
+    _check_single_tensor(tensor, "tensor")
+
+    # Parameter ``scatter_list`` may be left unspecified on non-src ranks.
+    if scatter_list:
+        _check_tensor_list(scatter_list, "scatter_list")
+    else:
+        scatter_list = []
+    _ensure_all_tensors_same_dtype(tensor, scatter_list)
+
+    if _rank_not_in_group(group):
+        _warn_not_in_group("scatter")
+        return
+    scatter_list = [
+        t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list
+    ]
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
+    my_rank = get_rank()
+    if src == my_rank:
+        if not scatter_list:
+            raise ValueError(
+                "Argument ``scatter_list`` must be specified on source rank."
+            )
+        input_tensors = [scatter_list]
+        output_tensors = [tensor]
+    else:
+        if scatter_list:
+            raise ValueError(
+                "Argument ``scatter_list`` must NOT be specified "
+                "on non-source ranks."
+            )
+        input_tensors = []
+        output_tensors = [tensor]
+
+    opts = ScatterOptions()
+    opts.rootRank = src
+    opts.asyncOp = async_op
+
+    if group is None or group is GroupMember.WORLD:
+        default_pg = _get_default_group()
+        work = default_pg.scatter(output_tensors, input_tensors, opts)
+    else:
+        group_src_rank = get_group_rank(group, src)
+        opts.rootRank = group_src_rank
+        work = group.scatter(output_tensors, input_tensors, opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    Reduces, then scatters a list of tensors to all processes in a group.
+
+    Args:
+        output (Tensor): Output tensor.
+        input_list (list[Tensor]): List of tensors to reduce and scatter.
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
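+    A minimal illustrative sketch (assuming a world size of 2 and an already
+    initialized default process group; values are only an example):
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Each rank contributes one input tensor per rank; rank i receives
+        >>> # the element-wise sum of every rank's i-th input tensor.
+        >>> import torch.distributed as dist
+        >>> rank = dist.get_rank()
+        >>> input_list = [torch.tensor([rank + 1.0]), torch.tensor([(rank + 1.0) * 10])]
+        >>> output = torch.zeros(1)
+        >>> dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
+        >>> output
+        tensor([3.])   # Rank 0: 1.0 + 2.0
+        tensor([30.])  # Rank 1: 10.0 + 20.0
+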
+    """
+    _check_single_tensor(output, "output")
+    _check_tensor_list(input_list, "input_list")
+    _ensure_all_tensors_same_dtype(output, input_list)
+    if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter")
+        return
+
+    opts = ReduceScatterOptions()
+    opts.reduceOp = op
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.reduce_scatter([output], [input_list], opts)
+    else:
+        work = group.reduce_scatter([output], [input_list], opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    Reduces, then scatters a tensor to all ranks in a group.
+
+    Args:
+        output (Tensor): Output tensor. It should have the same size across all
+            ranks.
+        input (Tensor): Input tensor to be reduced and scattered. Its size
+            should be output tensor size times the world size. The input tensor
+            can have one of the following shapes:
+            (i) a concatenation of the output tensors along the primary
+            dimension, or
+            (ii) a stack of the output tensors along the primary dimension.
+            For definition of "concatenation", see ``torch.cat()``.
+            For definition of "stack", see ``torch.stack()``.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
+        >>> # We have two ranks.
+        >>> device = torch.device(f'cuda:{rank}')
+        >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
+        >>> # Input in concatenation form
+        >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
+        >>> tensor_in
+        tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
+        tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
+        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
+        >>> tensor_out
+        tensor([0, 2], device='cuda:0') # Rank 0
+        tensor([4, 6], device='cuda:1') # Rank 1
+        >>> # Input in stack form
+        >>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
+        >>> tensor_in
+        tensor([[0, 1],
+                [2, 3]], device='cuda:0') # Rank 0
+        tensor([[0, 1],
+                [2, 3]], device='cuda:1') # Rank 1
+        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
+        >>> tensor_out
+        tensor([0, 2], device='cuda:0') # Rank 0
+        tensor([4, 6], device='cuda:1') # Rank 1
+
+    .. warning::
+        The Gloo backend does not support this API.
+
+    """
+    _check_single_tensor(output, "output")
+    _check_single_tensor(input, "input")
+
+    if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter_tensor")
+        return
+
+    opts = ReduceScatterOptions()
+    opts.reduceOp = op
+    opts.asyncOp = async_op
+
+    group = group or _get_default_group()
+
+    # Check if we are in a coalescing context.
+    # If we are, do not issue a single operation; just append a collective representation.
+    if group in _world.pg_coalesce_state.keys():
+        coll = _CollOp(reduce_scatter_tensor, input, output, op, None)
+        _world.pg_coalesce_state[group].append(coll)
+        if async_op:
+            return _IllegalWork()
+        else:
+            return None
+
+    work = group._reduce_scatter_base(output, input, opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
+    """
+    Reduces, then scatters a flattened tensor to all processes in a group.
+
+    Args:
+        output (Tensor): Output tensor.
+        input (Tensor): Input tensor that is of size output tensor size times world size
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
+    .. warning::
+        `_reduce_scatter_base` is a private function. Users should use
+        `reduce_scatter_tensor` instead.
+
+    """
+    warnings.warn(
+        "torch.distributed._reduce_scatter_base is a private function and will "
+        "be deprecated. Please use torch.distributed.reduce_scatter_tensor "
+        "instead."
+    )
+    return reduce_scatter_tensor(output, input, op, group, async_op)
+
+
+@_exception_logger
+def all_to_all_single(
+    output,
+    input,
+    output_split_sizes=None,
+    input_split_sizes=None,
+    group=None,
+    async_op=False,
+):
+    """
+    Split input tensor and then scatter the split list to all processes in a group.
+
+    Later the received tensors are concatenated from all the processes in the group
+    and returned as a single output tensor.
+
+    Complex tensors are supported.
+
+    Args:
+        output (Tensor): Gathered concatenated output tensor.
+        input (Tensor): Input tensor to scatter.
+        output_split_sizes (list[Int], optional): Output split sizes for dim 0.
+            If None or empty, dim 0 of the ``output`` tensor must be evenly
+            divisible by ``world_size``.
+        input_split_sizes (list[Int], optional): Input split sizes for dim 0.
+            If None or empty, dim 0 of the ``input`` tensor must be evenly
+            divisible by ``world_size``.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
+    .. warning::
+        `all_to_all_single` is experimental and subject to change.
+
+    Examples:
+        >>> # xdoctest: +SKIP("Undefined rank")
+        >>> input = torch.arange(4) + rank * 4
+        >>> input
+        tensor([0, 1, 2, 3])     # Rank 0
+        tensor([4, 5, 6, 7])     # Rank 1
+        tensor([8, 9, 10, 11])   # Rank 2
+        tensor([12, 13, 14, 15]) # Rank 3
+        >>> output = torch.empty([4], dtype=torch.int64)
+        >>> dist.all_to_all_single(output, input)
+        >>> output
+        tensor([0, 4, 8, 12])    # Rank 0
+        tensor([1, 5, 9, 13])    # Rank 1
+        tensor([2, 6, 10, 14])   # Rank 2
+        tensor([3, 7, 11, 15])   # Rank 3
+
+        >>> # Essentially, it is similar to following operation:
+        >>> scatter_list = list(input.chunk(world_size))
+        >>> gather_list  = list(output.chunk(world_size))
+        >>> for i in range(world_size):
+        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)
+
+        >>> # Another example with uneven split
+        >>> input
+        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
+        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
+        tensor([20, 21, 22, 23, 24])                                     # Rank 2
+        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
+        >>> input_splits
+        [2, 2, 1, 1]                                                     # Rank 0
+        [3, 2, 2, 2]                                                     # Rank 1
+        [2, 1, 1, 1]                                                     # Rank 2
+        [2, 2, 2, 1]                                                     # Rank 3
+        >>> output_splits
+        [2, 3, 2, 2]                                                     # Rank 0
+        [2, 2, 1, 2]                                                     # Rank 1
+        [1, 2, 1, 2]                                                     # Rank 2
+        [1, 2, 1, 1]                                                     # Rank 3
+        >>> output = ...
+        >>> dist.all_to_all_single(output, input, output_splits, input_splits)
+        >>> output
+        tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31])                     # Rank 0
+        tensor([ 2,  3, 13, 14, 22, 32, 33])                             # Rank 1
+        tensor([ 4, 15, 16, 23, 34, 35])                                 # Rank 2
+        tensor([ 5, 17, 18, 24, 36])                                     # Rank 3
+
+
+        >>> # Another example with tensors of torch.cfloat type.
+        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
+        >>> input
+        tensor([1+1j, 2+2j, 3+3j, 4+4j])                                # Rank 0
+        tensor([5+5j, 6+6j, 7+7j, 8+8j])                                # Rank 1
+        tensor([9+9j, 10+10j, 11+11j, 12+12j])                          # Rank 2
+        tensor([13+13j, 14+14j, 15+15j, 16+16j])                        # Rank 3
+        >>> output = torch.empty([4], dtype=torch.cfloat)
+        >>> dist.all_to_all_single(output, input)
+        >>> output
+        tensor([1+1j, 5+5j, 9+9j, 13+13j])                              # Rank 0
+        tensor([2+2j, 6+6j, 10+10j, 14+14j])                            # Rank 1
+        tensor([3+3j, 7+7j, 11+11j, 15+15j])                            # Rank 2
+        tensor([4+4j, 8+8j, 12+12j, 16+16j])                            # Rank 3
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all_single")
+        return
+
+    opts = AllToAllOptions()
+    _check_single_tensor(output, "output")
+    _check_single_tensor(input, "input")
+    _ensure_all_tensors_same_dtype(output, input)
+
+    if input.is_complex():
+        input = torch.view_as_real(input)
+    if output.is_complex():
+        output = torch.view_as_real(output)
+
+    output_split_sizes = [] if output_split_sizes is None else output_split_sizes
+    input_split_sizes = [] if input_split_sizes is None else input_split_sizes
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.alltoall_base(
+            output, input, output_split_sizes, input_split_sizes, opts
+        )
+    else:
+        work = group.alltoall_base(
+            output, input, output_split_sizes, input_split_sizes, opts
+        )
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+@_exception_logger
+def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
+    """
+    Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
+
+    Complex tensors are supported.
+
+    Args:
+        output_tensor_list (list[Tensor]): List of tensors to be gathered one
+            per rank.
+        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group.
+
+    .. warning::
+        `all_to_all` is experimental and subject to change.
+
+    Examples:
+        >>> # xdoctest: +SKIP("Undefined rank")
+        >>> input = torch.arange(4) + rank * 4
+        >>> input = list(input.chunk(4))
+        >>> input
+        [tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
+        [tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
+        [tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
+        [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
+        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
+        >>> dist.all_to_all(output, input)
+        >>> output
+        [tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
+        [tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
+        [tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
+        [tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3
+
+        >>> # Essentially, it is similar to following operation:
+        >>> scatter_list = input
+        >>> gather_list  = output
+        >>> for i in range(world_size):
+        >>>     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
+
+        >>> input
+        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
+        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
+        tensor([20, 21, 22, 23, 24])                                     # Rank 2
+        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
+        >>> input_splits
+        [2, 2, 1, 1]                                                     # Rank 0
+        [3, 2, 2, 2]                                                     # Rank 1
+        [2, 1, 1, 1]                                                     # Rank 2
+        [2, 2, 2, 1]                                                     # Rank 3
+        >>> output_splits
+        [2, 3, 2, 2]                                                     # Rank 0
+        [2, 2, 1, 2]                                                     # Rank 1
+        [1, 2, 1, 2]                                                     # Rank 2
+        [1, 2, 1, 1]                                                     # Rank 3
+        >>> input = list(input.split(input_splits))
+        >>> input
+        [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
+        [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
+        [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
+        [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
+        >>> output = ...
+        >>> dist.all_to_all(output, input)
+        >>> output
+        [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
+        [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
+        [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
+        [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3
+
+        >>> # Another example with tensors of torch.cfloat type.
+        >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j)
+        >>> input = list(input.chunk(4))
+        >>> input
+        [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
+        [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
+        [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
+        [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
+        >>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4))
+        >>> dist.all_to_all(output, input)
+        >>> output
+        [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
+        [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
+        [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
+        [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3
+
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all")
+        return
+
+    opts = AllToAllOptions()
+    _check_tensor_list(output_tensor_list, "output_tensor_list")
+    _check_tensor_list(input_tensor_list, "input_tensor_list")
+    _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
+
+    input_tensor_list = [
+        t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
+    ]
+    output_tensor_list = [
+        t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list
+    ]
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.alltoall(output_tensor_list, input_tensor_list, opts)
+    else:
+        work = group.alltoall(output_tensor_list, input_tensor_list, opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+@_exception_logger
+def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
+    """
+    Synchronize all processes.
+
+    This collective blocks processes until the whole group enters this function
+    (if ``async_op`` is ``False``), or until ``wait()`` is called on the returned
+    async work handle (if ``async_op`` is ``True``).
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        async_op (bool, optional): Whether this op should be an async op
+        device_ids ([int], optional): List of device/GPU ids.
+
+    Returns:
+        Async work handle, if async_op is set to True.
+        None, if not async_op or if not part of the group
+
+    .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of
+              device synchronization to block the CPU. Thus, please do not assume that
+              `barrier()` would perform a device synchronization.
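+
+    A minimal illustrative sketch (assuming an already initialized default
+    process group):
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> import torch.distributed as dist
+        >>> # Every rank blocks here until all ranks in the group have arrived.
+        >>> dist.barrier()
+        >>> # Async variant: wait() on the returned handle blocks instead.
+        >>> work = dist.barrier(async_op=True)
+        >>> work.wait()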
+    """
+    if _rank_not_in_group(group):
+        _warn_not_in_group("barrier")
+        return
+
+    opts = BarrierOptions()
+    opts.device = _get_pg_default_device(group)
+    if device_ids is not None:
+        if isinstance(device_ids, list):
+            opts.device_ids = device_ids
+        else:
+            raise TypeError(
+                "Invalid function argument: device_ids type should be List[int]"
+            )
+
+    if group is None:
+        default_pg = _get_default_group()
+        work = default_pg.barrier(opts=opts)
+    else:
+        work = group.barrier(opts=opts)
+
+    if async_op:
+        return work
+    else:
+        work.wait()
+
+
+def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
+    """
+    Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.
+
+    It is able to report ranks that did not pass this barrier within the provided timeout.
+    Specifically, for non-zero ranks, it will block until a send/recv is processed from rank 0.
+    Rank 0 will block until all send/recv from other ranks are processed, and will report
+    failures for ranks that failed to respond in time. Note that if one rank does not reach the
+    monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.
+
+    This collective will block all processes/ranks in the group, until the
+    whole group exits the function successfully, making it useful for debugging
+    and synchronizing. However, it can have a performance impact and should only
+    be used for debugging or scenarios that require full synchronization points
+    on the host-side. For debugging purposes, this barrier can be inserted
+    before the application's collective calls to check if any ranks are
+    desynchronized.
+
+    .. note:: Note that this collective is only supported with the GLOO backend.
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If
+            ``None``, the default process group will be used.
+        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
+            If ``None``, the default process group timeout will be used.
+        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
+            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
+            will throw on the first failed rank it encounters in order to fail
+            fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will
+            collect all failed ranks and throw an error containing information
+            about all failed ranks.
+
+    Returns:
+        ``None``.
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> if dist.get_rank() != 1:
+        >>>     dist.monitored_barrier() # Raises exception indicating that
+        >>> # rank 1 did not call into monitored_barrier.
+        >>> # Example with wait_all_ranks=True
+        >>> if dist.get_rank() == 0:
+        >>>     dist.monitored_barrier(wait_all_ranks=True) # Raises exception
+        >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
+        >>> # monitored_barrier.
+    """
+    # Need to call rank not in group before using the group, otherwise
+    # "Invalid process group" error is raised.
+    if _rank_not_in_group(group):
+        _warn_not_in_group("monitored_barrier")
+        return
+
+    if get_backend(group) != Backend.GLOO:
+        raise ValueError("monitored_barrier is only implemented for GLOO backend.")
+
+    if timeout is None:
+        timeout = _get_default_timeout(get_backend(group))
+    elif isinstance(timeout, float):
+        # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format?
+        warnings.warn(
+            "Please specify timeout arg as a timedelta. "
+            f"Converting current value of {timeout} assuming it represents seconds",
+        )
+        timeout = timedelta(seconds=timeout)
+
+    _check_valid_timeout(timeout)
+
+    group_to_use = _get_default_group() if group is None else group
+    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
+
+
+def _create_process_group_wrapper(
+    wrapped_pg: torch._C._distributed_c10d.Backend,
+    store_prefix: str,
+    store: Store,
+    rank: int,
+    world_size: int,
+    timeout: timedelta = default_pg_timeout,
+):
+    # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
+
+    # Create a separate prefix store for the helper process group.
+    prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
+    store = PrefixStore(prefix, store)
+    helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
+    # Wrap the underlying pg with ProcessGroupWrapper.
+    wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
+    return wrapped_pg
+
+# helper function for deterministically hashing a list of ranks
+def _hash_ranks(ranks: List[int]):
+    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
+
+# Takes a list of ranks and computes an integer color
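+# e.g. ranks [0, 1, 2] are hashed as the string "0_1_2" by ``_hash_ranks``; the
+# hex digest is interpreted as a base-16 integer and reduced modulo
+# ``sys.maxsize >> 1`` so the color is a stable, non-negative value.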
+def _process_group_color(ranks: List[int]) -> int:
+    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
+    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
+
+def _process_group_name(ranks, use_hashed_name):
+    global _world
+    if use_hashed_name:
+        pg_name = _hash_ranks(ranks)
+        while pg_name in _world.pg_names.values():
+            pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest()
+    else:
+        pg_name = str(_world.group_count)
+        _world.group_count += 1
+    return pg_name
+
+def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
+    # Default to the same backend as the global process group
+    #  if backend is not specified.
+    if not backend:
+        backend = get_backend(_get_default_group())
+    return Backend(backend)
+
+
+@_time_logger
+def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False):
+    """
+    Create a new distributed group.
+
+    This function requires that all processes in the main group (i.e. all
+    processes that are part of the distributed job) enter this function, even
+    if they are not going to be members of the group. Additionally, groups
+    should be created in the same order in all processes.
+
+    .. warning::
+        Using multiple process groups with the ``NCCL`` backend concurrently
+        is not safe and the user should perform explicit synchronization in
+        their application to ensure only one process group is used at a time.
+        This means collectives from one process group should have completed
+        execution on the device (not just enqueued since CUDA execution is
+        async) before collectives from another process group are enqueued.
+        See `Using multiple NCCL communicators concurrently <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently>`_ for more details.
+
+    Args:
+        ranks (list[int]): List of ranks of group members. If ``None``, will be
+            set to all ranks. Default is ``None``.
+        timeout (timedelta, optional): see `init_process_group` for details and default value.
+        backend (str or Backend, optional): The backend to use. Depending on
+            build-time configurations, valid values are ``gloo`` and ``nccl``.
+            By default uses the same backend as the global group. This field
+            should be given as a lowercase string (e.g., ``"gloo"``), which can
+            also be accessed via :class:`Backend` attributes (e.g.,
+            ``Backend.GLOO``). If ``None`` is passed in, the backend
+            corresponding to the default process group will be used. Default is
+            ``None``.
+        pg_options (ProcessGroupOptions, optional): process group options
+            specifying what additional options need to be passed in during
+            the construction of specific process groups. i.e. for the ``nccl``
+            backend, ``is_high_priority_stream`` can be specified so that
+            process group can pick up high priority cuda streams.
+        use_local_synchronization (bool, optional): perform a group-local
+            barrier at the end of the process group creation. This is different
+            in that non-member ranks don't need to call into this API and don't
+            join the barrier.
+
+    Returns:
+        A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``.
+
+    N.B. use_local_synchronization doesn't work with MPI.
+
+    N.B. While use_local_synchronization=True can be significantly faster with larger
+    clusters and small process groups, care must be taken since it changes cluster behavior
+    as non-member ranks don't join the group barrier().
+
+    N.B. use_local_synchronization=True can lead to deadlocks when each rank creates
+    multiple overlapping process groups. To avoid that, make sure all ranks follow the
+    same global creation order.
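+
+    A minimal illustrative sketch (assuming a world size of at least 2 and an
+    already initialized default process group):
+
+    Example::
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> import torch.distributed as dist
+        >>> # Every rank must call new_group, even ranks not listed in ``ranks``.
+        >>> group = dist.new_group(ranks=[0, 1])
+        >>> if dist.get_rank() in (0, 1):
+        >>>     t = torch.ones(1) * dist.get_rank()
+        >>>     # Collectives on ``group`` only involve ranks 0 and 1.
+        >>>     dist.all_reduce(t, group=group)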
+    """
+    return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
+
+def _new_group_with_tag(
+    ranks=None,
+    timeout=None,
+    backend=None,
+    pg_options=None,
+    pg_tag=None,
+    use_local_synchronization=False
+):
+    """
+    Variant of ``new_group`` that exposes tag creation.
+
+    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
+    ``torch.distributed._functional_collectives`` for reference on how to use it.
+    """
+    global _world
+
+    default_pg = _get_default_group()
+    default_backend, default_store = _world.pg_map[default_pg]
+    global_rank = default_pg.rank()
+    global_world_size = default_pg.size()
+
+
+    # Default to the same backend as the global process group
+    # if the backend is not specified.
+    if not backend:
+        backend = default_backend
+    backend = Backend(backend)
+
+    # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
+    # which may just pass their timeout value (or None)
+    if timeout is None:
+        timeout = _get_default_timeout(backend)
+    _check_valid_timeout(timeout)
+
+    if use_local_synchronization:
+        # MPI backend doesn't have a way for us to perform a partial sync
+        if backend == Backend.MPI:
+            raise ValueError("MPI backend doesn't support use_local_synchronization=True")
+        if ranks is not None and get_rank() not in ranks:
+            return None
+
+    # checks the input ranks
+    if ranks is not None:
+        ranks = sorted(ranks)
+        group_world_size = len(ranks)
+        if group_world_size > global_world_size:
+            raise ValueError(
+                "the new group's world size should be less or "
+                "equal to the world size set by "
+                "init_process_group"
+            )
+        # check ranks' sanity
+        for rank in ranks:
+            if rank < 0 or rank >= global_world_size:
+                raise ValueError(
+                    "The new group's rank should be within "
+                    "the world_size set by init_process_group"
+                )
+        if global_rank in ranks:
+            group_rank = ranks.index(global_rank)
+        else:
+            group_rank = None
+    else:
+        ranks = list(range(global_world_size))
+        group_world_size = global_world_size
+        group_rank = global_rank
+
+    group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)
+
+    pg, pg_store = _new_process_group_helper(
+        group_world_size,
+        group_rank,
+        ranks,
+        backend,
+        default_store,
+        group_name,
+        pg_options=pg_options,
+        timeout=timeout,
+        pg_tag=pg_tag
+    )
+
+    # Create the global rank to group rank mapping
+    _world.pg_group_ranks[pg] = {
+        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
+    }
+
+    if _is_barrier_after_init() == 1:
+        # barrier at the end to ensure that once we return from this method, all
+        # process groups including global variables (if any) are updated
+        # correctly on all ranks.
+        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
+        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
+        # these barriers may be unnecessary, as proven by a green CI after
+        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
+        # added which enables this barrier only when set to 1.
+        logger.info(
+            "Performing barrier after ProcessGroup initialization since "
+            "TORCH_DIST_INIT_BARRIER = 1"
+        )
+        if backend == Backend.MPI:
+            # MPI doesn't have store.
+            barrier()
+        else:
+            barrier_store = pg_store if use_local_synchronization else default_store
+            world_size = len(ranks) if use_local_synchronization else get_world_size()
+            # Use store based barrier here since barrier() used a bunch of
+            # default devices and messes up NCCL internal state.
+            _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout)
+
+    return pg
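+# A standalone sketch of the global-rank -> group-rank mapping built above
+# (plain Python, hypothetical rank values):
+#
+#     ranks = [2, 5, 7]                                # sorted global ranks in the group
+#     mapping = {g: i for i, g in enumerate(ranks)}    # {2: 0, 5: 1, 7: 2}
+#     assert mapping[5] == 1                           # global rank 5 is group rank 1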
+
+
+def new_subgroups(
+    group_size=None,
+    group=None,
+    timeout=None,
+    backend=None,
+    pg_options=None,
+):
+    """
+    Create subgroups of equal size.
+
+    By default, it creates intra-machine subgroups,
+    each of which contains all the ranks of a machine, based on the assumption
+    that each machine has the same number of devices.
+
+    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
+    It requires that all processes in the main group (i.e. all
+    processes that are part of the distributed job) enter this function, even
+    if they are not going to be members of the group.
+
+    .. warning::
+        If ``group_size`` is passed in, the world size must be divisible by ``group_size``.
+        If no ``group_size`` is passed in, it is assumed that you are creating groups based
+        on CUDA, with the group size determined by the number of CUDA devices; if not all
+        the machines have the same number of devices, the subgroup division will be
+        different across nodes and can cause unexpected behaviors. Therefore, if you are
+        creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please
+        pass in ``group_size`` correctly.
+
+    .. warning::
+        Using multiple process groups with the ``NCCL`` backend concurrently
+        is not safe and the user should perform explicit synchronization in
+        their application to ensure only one process group is used at a time.
+        This means collectives from one process group should have completed
+        execution on the device (not just enqueued since CUDA execution is
+        async) before collectives from another process group are enqueued.
+        See `Using multiple NCCL communicators concurrently `_ for more details.
+
+    Args:
+        group_size (int, optional): The size of each subgroup. If ``None``,
+            the default subgroup size is equal to the number of devices on each machine,
+            based on the assumption that each machine has exactly the same
+            number of devices. Default is ``None``.
+        timeout (timedelta, optional): see `init_process_group` for details and default value.
+        backend (str or Backend, optional): The backend to use. Depending on
+            build-time configurations, valid values are ``gloo`` and ``nccl``.
+            By default uses the same backend as the global group. This field
+            should be given as a lowercase string (e.g., ``"gloo"``), which can
+            also be accessed via :class:`Backend` attributes (e.g.,
+            ``Backend.GLOO``). If ``None`` is passed in, the backend
+            corresponding to the default process group will be used. Default is
+            ``None``.
+        pg_options (ProcessGroupOptions, optional): process group options
+            specifying what additional options need to be passed in during
+            the construction of specific process groups. i.e. for the ``nccl``
+            backend, ``is_high_priority_stream`` can be specified so that
+            process group can pick up high priority cuda streams.
+
+    Returns:
+        The subgroup containing the current rank, and all the subgroups used for cleanup.
+
+    Examples:
+        >>> # Create intra-machine subgroups.
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> cur_subgroup, subgroups = dist.new_subgroups()
+        >>> # Allreduce within the machine.
+        >>> rank = dist.get_rank()
+        >>> tensor = torch.ones(1, device=rank) * rank
+        >>> dist.all_reduce(tensor, group=cur_subgroup)
+        >>> tensor
+        tensor([8])     # Assume 8 is the number of CUDA devices per machine.
+        >>> # Cleanup.
+        >>> for subgroup in subgroups:
+        >>>     dist.destroy_process_group(subgroup)
+    """
+    if group_size is None:
+        if not torch.cuda.is_available():
+            raise ValueError("Default group size only takes effect when CUDA is available."
+                             "If your subgroup using a backend that does not depend on CUDA,"
+                             "please pass in 'group_size' correctly.")
+        group_size = torch.cuda.device_count()
+    if group_size <= 0:
+        raise ValueError(f"The arg 'group_size' ({group_size}) must be positive")
+
+    world_size = get_world_size()
+    if world_size < group_size:
+        raise ValueError(f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})")
+    if world_size % group_size != 0:
+        raise ValueError("The world size must be divisible by 'group_size'")
+
+    subgroups = []
+    cur_subgroup = None
+
+    for subgroup_id in range(world_size // group_size):
+        start_rank = subgroup_id * group_size
+        end_rank = start_rank + group_size
+        ranks_in_subgroup = list(range(start_rank, end_rank))
+        subgroup = new_group(
+            ranks=ranks_in_subgroup,
+            timeout=timeout,
+            backend=backend,
+            pg_options=pg_options,
+        )
+        subgroups.append(subgroup)
+
+        rank = get_rank()
+        if rank in ranks_in_subgroup:
+            cur_subgroup = subgroup
+            logger.info(
+                "Rank %s is assigned to subgroup %s",
+                rank, ranks_in_subgroup
+            )
+
+    return cur_subgroup, subgroups
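+# A standalone sketch of how ``new_subgroups`` partitions ranks into contiguous
+# blocks (pure arithmetic, illustrative values; no process group is needed):
+#
+#     world_size, group_size = 8, 4
+#     blocks = [list(range(i * group_size, (i + 1) * group_size))
+#               for i in range(world_size // group_size)]
+#     assert blocks == [[0, 1, 2, 3], [4, 5, 6, 7]]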
+
+
+def new_subgroups_by_enumeration(
+    ranks_per_subgroup_list,
+    timeout=None,
+    backend=None,
+    pg_options=None,
+):
+    """
+    Create subgroups by dividing the global world.
+
+    The division is specified by a nested list of ranks. The subgroups cannot have
+    overlap, and some ranks may be left out of all subgroups.
+
+    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
+    It requires that all processes in the main group (i.e. all
+    processes that are part of the distributed job) enter this function, even
+    if they are not going to be members of the group.
+
+    .. warning::
+        Using multiple process groups with the ``NCCL`` backend concurrently
+        is not safe and the user should perform explicit synchronization in
+        their application to ensure only one process group is used at a time.
+        This means collectives from one process group should have completed
+        execution on the device (not just enqueued since CUDA execution is
+        async) before collectives from another process group are enqueued.
+        See `Using multiple NCCL communicators concurrently `_ for more details.
+
+    Args:
+        ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of
+            group members.
+        timeout (timedelta, optional): see `init_process_group` for details and default value.
+        backend (str or Backend, optional): The backend to use. Depending on
+             build-time configurations, valid values are ``gloo`` and ``nccl``.
+             By default uses the same backend as the global group. This field
+             should be given as a lowercase string (e.g., ``"gloo"``), which can
+             also be accessed via :class:`Backend` attributes (e.g.,
+             ``Backend.GLOO``). If ``None`` is passed in, the backend
+             corresponding to the default process group will be used. Default is
+             ``None``.
+        pg_options (ProcessGroupOptions, optional): process group options
+            specifying what additional options need to be passed in during
+            the construction of specific process groups. i.e. for the ``nccl``
+            backend, ``is_high_priority_stream`` can be specified so that
+            process group can pick up high priority cuda streams.
+
+    Returns:
+        The subgroup containing the current rank, and all the subgroups used for cleanup.
+
+    Examples:
+        >>> # Create two subgroups, where each has 2 processes.
+        >>> # xdoctest: +SKIP("need process group init")
+        >>> cur_subgroup, subgroups = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]])
+        >>> rank = dist.get_rank()
+        >>> tensor = torch.ones(1, device=rank) * rank
+        >>> dist.all_reduce(tensor, group=cur_subgroup)
+        >>> tensor
+        tensor([2])     # Subgroup 0: ranks 0 and 2
+        tensor([4])     # Subgroup 1: ranks 1 and 3
+    """
+    if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0:
+        raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty")
+
+    subgroups = []
+    cur_subgroup = None
+    # Create a mapping from rank to subgroup to check if there is any subgroup overlap.
+    rank_to_ranks_dict = {}  # type: ignore[var-annotated]
+    for ranks in ranks_per_subgroup_list:
+        subgroup = new_group(
+            ranks=ranks,
+            timeout=timeout,
+            backend=backend,
+            pg_options=pg_options,
+        )
+        subgroups.append(subgroup)
+        my_rank = get_rank()
+        for rank in ranks:
+            if rank in rank_to_ranks_dict:
+                raise ValueError(
+                    f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}"
+                )
+            rank_to_ranks_dict[rank] = ranks
+            if my_rank == rank:
+                cur_subgroup = subgroup
+                logger.info("Rank %s is assigned to subgroup %s", rank, ranks)
+
+    return cur_subgroup, subgroups
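+# A minimal usage sketch for ``new_subgroups_by_enumeration`` (illustrative only;
+# assumes an already-initialized job with 4 ranks, and every rank makes the same
+# call in the same order):
+#
+#     cur_subgroup, subgroups = new_subgroups_by_enumeration([[0, 2], [1, 3]])
+#     # ``cur_subgroup`` is the group containing this rank; overlapping lists such as
+#     # [[0, 1], [1, 2]] would raise ValueError because rank 1 appears twice.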
+
+
+def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
+    if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
+        tag = f"user:{tag}"
+
+    for group in _world.tags_to_pg.get(tag, []):
+        if group.size() != len(ranks):
+            continue
+
+        group_ranks = get_process_group_ranks(group)
+        good = all(r in group_ranks for r in ranks)
+        if good:
+            return group
+    return None
+
+def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup:
+    assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
+
+    my_rank = get_rank()
+    my_ranks = None
+
+    if stride == len(ranks):
+        my_ranks = ranks.copy()
+        assert my_rank in my_ranks, "rankset doesn't include the current node"
+    else:
+        for i in range(0, len(ranks), stride):
+            rank_set = ranks[i : i + stride]
+            if my_rank in rank_set:
+                my_ranks = rank_set
+        assert my_ranks is not None, "rankset doesn't include the current node"
+
+    my_ranks.sort()
+
+    pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
+    if pg is not None:
+        return pg
+    if tag == "":
+        raise ValueError("Cannot automatically create PG with empty tag")
+    # TODO copy settings and timeout from default PG
+    return _new_group_with_tag(my_ranks, pg_tag=tag)
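+# A standalone sketch of the stride-based chunking used above (plain Python,
+# illustrative values):
+#
+#     ranks, stride = [0, 1, 2, 3, 4, 5], 3
+#     chunks = [ranks[i:i + stride] for i in range(0, len(ranks), stride)]
+#     assert chunks == [[0, 1, 2], [3, 4, 5]]
+#     # the current rank's chunk is the one that contains get_rank()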
+
+def _get_group_tag(pg: ProcessGroup) -> str:
+    """Return the tag associated with ``pg``."""
+    tag = _world.pg_to_tag[pg]
+    if tag.startswith("user:"):
+        tag = tag[5:]
+    return tag
+
+def _get_process_group_name(pg: ProcessGroup) -> str:
+    return _world.pg_names.get(pg, "None")
+
+def _get_process_group_store(pg: ProcessGroup) -> Store:
+    return _world.pg_map[pg][1]
+
+# These ops are not friendly to TorchDynamo, so we disallow them in the FX graph
+# and let them run eagerly under torch.compile.
+dynamo_unsupported_distributed_c10d_ops = [
+    recv,
+    all_gather_object,
+    all_gather_coalesced,
+    all_to_all_single,
+    all_reduce,
+    gather_object,
+    all_to_all,
+    all_reduce_coalesced,
+    gather,
+    broadcast_object_list,
+    barrier,
+    scatter,
+    scatter_object_list,
+    reduce,
+    all_gather,
+    reduce_scatter,
+    all_gather_into_tensor,
+    broadcast,
+    reduce_scatter_tensor,
+    send,
+]
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..427e1745c4a2631cd006e0c856c248d7e2968c11
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/__init__.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+
+Torchelastic agent and user worker failover contract:
+
+**TL;DR;**:
+
+* TE (torchelastic) expects user workers to finish within a 5 minute drift of each other.
+* It is better to design a DDP app to fail for all workers rather than for a single one.
+* TE does not synchronize the number of restarts between agents.
+* A TE re-rendezvous does not trigger a restart decrease.
+* When a single agent finishes its job (successfully or not), it will close the rendezvous.
+  If other agents still have workers in progress, they will be terminated.
+* Based on the above, scale down does not work if at least a single agent finishes the job.
+* When scale up is detected by the agents, they will not decrease ``max_restarts``.
+
+
+In general TE (torchelastic) can launch arbitrary user code, but some clarification
+is needed around what failover mechanism torchelastic provides and what failover
+mechanism it expects from user workers.
+
+Torchelastic currently supports DDP style applications.  That means that
+TE expects *ALL* workers to finish at approximately the same time. In practice,
+it is nearly impossible to guarantee that all workers in an arbitrary
+DDP application finish at the same time, so TE provides a finalization barrier
+that waits for TIMEOUT (5 minutes) for worker finalization.
+
+**Worker Failure**
+
+When a worker fails, TE will check the number of restarts
+available; if more than 0 restarts remain, TE will start a new rendezvous
+round and restart the worker process. The new rendezvous round will cause other
+TE agents to terminate their workers.
+
+.. note:: TE agents do not synchronize restarts between themselves.
+          When a single agent performs a restart, it will trigger a local ``max_restarts``
+          decrease; other agents will not decrease their ``max_restarts``.
+
+A single worker failure can cause the whole cluster to fail:
+If a single worker is constantly failing, it will cause the TE agent
+``max_restarts``  to go to zero. This will cause an agent to finish its
+work and close rendezvous. If there are any other workers on different
+agents, they will be terminated.
+
+
+**Re-Rendezvous**
+
+Re-rendezvous occurs when TE agents detect a new node
+trying to join the cluster. TE will not decrease ``max_restarts``. TE agents
+will terminate their workers and start a new rendezvous round.
+
+Note about DynamicRendezvous (etcd-v2, c10d-experimental): if the rendezvous
+already has max_nodes, the new node won't be added to the wait list right
+away since there is no need to tear down a rendezvous that is already fully
+utilized. The new node will wait until its timeout (600 secs by default)
+and periodically check the number of participants. If the number becomes
+less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs.
+
+*Scale up event*. When a scale up event happens, the torchelastic rendezvous
+will detect that there are new nodes trying to join. The torchelastic agent
+will stop all workers and perform a re-rendezvous. Note: when a scale up event
+happens, *``max_restarts``* will *not* decrease.
+
+*Scale down event*. When a scale down event happens, the rendezvous will not
+notify the torchelastic agent about it. If the TE agent was launched with ``max_restarts=0``,
+it relies on the underlying scheduler to handle the job restart. If ``max_restarts>0``, the
+TE agent will terminate its workers and start a new rdzv round, which is a *scale up event*.
+
+"""
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf1d28816b3dba643896be4bb77447933eef2572
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52b18e9d0fbf56ad260176bcf1f4f5baed037d79
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d5992f195597fd71aa809fe1b2dcedb4d6a8ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__init__.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+The elastic agent is the control plane of torchelastic.
+
+It is a process that launches and manages underlying worker processes.
+The agent is responsible for:
+
+1. Working with distributed torch: the workers are started with all the
+   necessary information to successfully and trivially call
+   ``torch.distributed.init_process_group()``.
+
+2. Fault tolerance: monitors workers and upon detecting worker failures
+   or unhealthiness, tears down all workers and restarts everyone.
+
+3. Elasticity: Reacts to membership changes and restarts workers with the new
+   members.
+
+The simplest agents are deployed per node and work with local processes.
+A more advanced agent can launch and manage workers remotely. Agents can
+be completely decentralized, making decisions based on the workers they manage,
+or they can be coordinated, communicating with other agents (that manage workers
+in the same job) to make a collective decision.
+"""
+
+from .api import (  # noqa: F401
+    ElasticAgent,
+    RunResult,
+    SimpleElasticAgent,
+    Worker,
+    WorkerGroup,
+    WorkerSpec,
+    WorkerState,
+)
+from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4acc402160714660e11bd9ca87abe070907f1ea4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d4fa7de067a2dd8a14354b4e6e24f51e7016da3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0ba513f9458a8d57a9b51f8fac58b3e866d488c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d684d1aac0db50e2dcd634314952df2b3e5e8a6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/api.py
@@ -0,0 +1,954 @@
+# mypy: ignore-errors
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+import functools
+import json
+import os
+import signal
+import socket
+import time
+import traceback
+import warnings
+from contextlib import closing
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch.distributed.elastic.rendezvous as rdzv
+import torch.distributed.elastic.utils.store as store_util
+from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
+from torch.distributed import Store
+from torch.distributed.elastic.events import Event, EventSource, record
+from torch.distributed.elastic.metrics import prof, put_metric
+from torch.distributed.elastic.multiprocessing import (
+    ProcessFailure,
+    SignalException,
+)
+from torch.distributed.elastic.utils.logging import get_logger
+
+__all__ = [
+    "WorkerSpec",
+    "Worker",
+    "WorkerState",
+    "WorkerGroup",
+    "RunResult",
+    "ElasticAgent",
+    "SimpleElasticAgent",
+]
+_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"
+
+DEFAULT_ROLE = "default"
+log = get_logger(__name__)
+
+
+@dataclass
+class WorkerSpec:
+    """Blueprint information about a particular type of worker.
+
+    For a given role, there must exist only a single worker spec.
+    The worker spec is expected to be homogeneous across all nodes (machines),
+    that is, each node runs the same number of workers for a particular spec.
+
+    Args:
+        role: user-defined role for the workers with this spec
+        local_world_size: number of local workers to run
+        fn: (deprecated, use ``entrypoint`` instead)
+        entrypoint: worker function or command
+        args: arguments to pass to ``entrypoint``
+        rdzv_handler: handles rdzv for this set of workers
+        max_restarts: number of max retries for the workers
+        monitor_interval: monitor status of workers every ``n`` seconds
+        master_port: fixed port to run the c10d store on rank 0;
+                     if not specified, a random free port will be chosen
+        master_addr: fixed master_addr to run the c10d store on rank 0;
+                     if not specified, the hostname of the agent on rank 0 will be used
+        redirects: redirect std streams to a file,
+                   selectively redirect for a particular
+                   local rank by passing a map
+        tee: tees the specified std stream(s) to console + file,
+             selectively tee for a particular local rank by passing a map,
+             takes precedence over ``redirects`` settings.
+
+    """
+
+    role: str
+    local_world_size: int
+    rdzv_handler: rdzv.RendezvousHandler
+    fn: Optional[Callable] = None
+    # TODO @kiuk - make entrypoint a required field
+    entrypoint: Union[Callable, str, None] = None
+    args: Tuple = ()
+    max_restarts: int = 3
+    monitor_interval: float = 30.0
+    master_port: Optional[int] = None
+    master_addr: Optional[str] = None
+    local_addr: Optional[str] = None
+
+    def __post_init__(self):
+        assert self.local_world_size > 0
+        assert self.monitor_interval > 0
+
+        if self.fn:
+            warnings.warn(
+                "WorkerSpec.fn will be deprecated,"
+                " please use WorkerSpec.entrypoint instead",
+                category=DeprecationWarning,
+            )
+            self.entrypoint = self.fn
+        assert self.entrypoint
+
+    def get_entrypoint_name(self):
+        """Get the entry point name.
+
+        If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__``
+        else if the entrypoint is a binary (e.g. ``str``), returns the binary name.
+        """
+        if isinstance(self.entrypoint, str):
+            return os.path.basename(self.entrypoint)
+        else:
+            assert self.entrypoint is not None
+            return self.entrypoint.__qualname__
+
+
+class Worker:
+    """A worker instance.
+
+    Contrast this with ``WorkerSpec`` that represents the specifications of a
+    worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to
+    a ``WorkerSpec`` as an object is to a class.
+
+    The ``id`` of the worker is interpreted
+    by the specific implementation of ``ElasticAgent``. For a local
+    agent, it could be the ``pid (int)`` of the worker, for a remote
+    agent it could be encoded as ``host:port (string)``.
+
+    Args:
+        id (Any): uniquely identifies a worker (interpreted by the agent)
+        local_rank (int): local rank of the worker
+        global_rank (int): global rank of the worker
+        role_rank (int): rank of the worker across all workers that have the same role
+        world_size (int): number of workers (globally)
+        role_world_size (int): number of workers that have the same role
+    """
+
+    __slots__ = [
+        "id",
+        "local_rank",
+        "global_rank",
+        "role_rank",
+        "world_size",
+        "role_world_size",
+    ]
+
+    def __init__(
+        self,
+        local_rank: int,
+        global_rank: int = -1,
+        role_rank: int = -1,
+        world_size: int = -1,
+        role_world_size: int = -1,
+    ):
+        # unique identifier for this worker
+        self.id: Any = None
+
+        # rank of the worker among workers with the same role being monitored
+        # by the same ``agent`` instance.
+        self.local_rank: int = local_rank
+
+        #  rank of the worker among all the workers across all roles
+        #  across all ``agent`` instances.
+        #  Global rank is not stable between re-rendezvous.
+        self.global_rank: int = global_rank
+
+        #  rank of the worker among all the workers with the same role
+        #  across all ``agent`` instances.
+        #  Role rank is not stable between re-rendezvous.
+        self.role_rank: int = role_rank
+
+        # total number of workers (globally). Due to elasticity
+        # the world size may change between re-rendezvous.
+        self.world_size: int = world_size
+
+        # total number of workers that share the same role. Due to elasticity
+        # the role world size may change between re-rendezvous.
+        self.role_world_size: int = role_world_size
+
+    def __str__(self):
+        return (
+            f"local_rank={self.local_rank},global_rank={self.global_rank}"
+            f",role_rank={self.role_rank},world_size={self.world_size}"
+            f",role_world_size={self.role_world_size}"
+        )
+
+    def __repr__(self):
+        return str(self)
+
+
+class WorkerState(str, Enum):
+    """A state of the ``WorkerGroup``.
+
+    Workers in a worker group change state as a unit. If a single worker
+    in a worker group fails the entire set is considered failed::
+
+      UNKNOWN - agent lost track of worker group state, unrecoverable
+      INIT - worker group object created not yet started
+      HEALTHY - workers running and healthy
+      UNHEALTHY - workers running and unhealthy
+      STOPPED - workers stopped (interrupted) by the agent
+      SUCCEEDED - workers finished running (exit 0)
+      FAILED - workers failed to successfully finish (exit !0)
+
+
+    A worker group starts from an initial ``INIT`` state,
+    then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
+    and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.
+
+    Worker groups can be interrupted and temporarily put into ``STOPPED`` state
+    by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
+    in the near future by the agent. Some examples of workers being put into
+    ``STOPPED`` state are:
+
+    1. Worker group failure|unhealthy observed
+    2. Membership change detected
+
+    When an action (start, stop, rdzv, retry, etc.) on the worker group fails
+    and results in the action being partially applied to the worker group,
+    the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
+    exceptions during state change events on the agent. The agent is not
+    expected to recover worker groups in the ``UNKNOWN`` state and is better off
+    self-terminating and allowing the job manager to retry the node.
+    """
+
+    UNKNOWN = "UNKNOWN"
+    INIT = "INIT"
+    HEALTHY = "HEALTHY"
+    UNHEALTHY = "UNHEALTHY"
+    STOPPED = "STOPPED"
+    SUCCEEDED = "SUCCEEDED"
+    FAILED = "FAILED"
+
+    @staticmethod
+    def is_running(state: "WorkerState") -> bool:
+        """Return the state of the Worker.
+
+        Returns:
+             True if the worker state represents workers still running
+             (e.g. that the process exists but not necessarily healthy).
+        """
+        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
+
+
+class WorkerGroup:
+    """A set of ``Worker`` instances.
+
+    The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker
+    group contains cross instance workers or not depends on the implementation of the agent.
+    """
+
+    __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]
+
+    def __init__(self, spec: WorkerSpec):
+        self.spec = spec
+        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
+
+        # assigned after rdzv
+        self.store = None
+        self.group_rank = None
+        self.group_world_size = None
+
+        self.state = WorkerState.INIT
+
+
+class _RoleInstanceInfo:
+    """The class is used by the agent to exchange the information with other agents.
+
+    The information is used to determine the rank of the workers that agent
+    manages in heterogeneous environments, where different agents can have
+    different number of workers.
+    """
+
+    __slots__ = ["role", "rank", "local_world_size"]
+
+    def __init__(self, role: str, rank: int, local_world_size: int):
+        r"""Initialize the agent class instance.
+
+        Args:
+            role (str): user-defined role for the workers with this spec
+            rank (int): the rank of the agent
+            local_world_size (int): number of local workers to run
+        """
+        self.role = role
+        self.rank = rank
+        self.local_world_size = local_world_size
+
+    def serialize(self) -> bytes:
+        dict_data = {
+            "role": self.role,
+            "rank": self.rank,
+            "local_world_size": self.local_world_size,
+        }
+        return json.dumps(dict_data).encode(encoding="UTF-8")
+
+    @staticmethod
+    def deserialize(data: bytes):
+        dict_data = json.loads(data.decode(encoding="UTF-8"))
+        return _RoleInstanceInfo(
+            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
+        )
+
+    @staticmethod
+    def compare(obj1, obj2) -> int:
+        if obj1.role == obj2.role:
+            return obj1.rank - obj2.rank
+        elif obj1.role > obj2.role:
+            return 1
+        else:
+            return -1
+
+    @staticmethod
+    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
+        start_idx, end_idx = -1, -1
+        for idx, role_info in enumerate(roles_infos):
+            if role_info.role == role:
+                if start_idx == -1:
+                    start_idx = idx
+                end_idx = idx
+        return (start_idx, end_idx)
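+    # A standalone sketch of how ``_RoleInstanceInfo`` objects round-trip through the
+    # store and how they are ordered (illustrative values only):
+    #
+    #     info = _RoleInstanceInfo(role="trainer", rank=1, local_world_size=8)
+    #     assert _RoleInstanceInfo.deserialize(info.serialize()).rank == 1
+    #     # two-level (role, rank) ordering:
+    #     # sorted(infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare))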
+
+
+@dataclass
+class RunResult:
+    """Return results of the worker executions.
+
+    Run results follow an "all-or-nothing" policy where the run is successful if and
+    only if ALL local workers managed by this agent complete successfully.
+
+    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
+    field contains the outputs (return values) of the workers managed by THIS agent mapped
+    by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
+    global rank 0.
+
+    .. note:: ``return_values`` are only meaningful when the worker entrypoint
+              is a function. Workers specified as a binary entrypoint do not canonically
+              have a return value and the ``return_values`` field is meaningless and
+              may be empty.
+
+    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
+    failure information, again, mapped by the GLOBAL rank of the worker that failed.
+
+    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
+    a worker's final state can only be one of: succeeded, failed. Workers intentionally
+    terminated by the agent according to the agent's restart policy are not represented
+    in either ``return_values`` or ``failures``.
+    """
+
+    state: WorkerState
+    return_values: Dict[int, Any] = field(default_factory=dict)
+    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
+
+    def is_failed(self) -> bool:
+        return self.state == WorkerState.FAILED
+
+
+def _get_socket_with_port() -> socket.socket:
+    """Return a free port on localhost.
+
+    The free port is "reserved" by binding a temporary socket on it.
+    Close the socket before passing the port to the entity that
+    requires it. Usage example::
+
+    sock = _get_socket_with_port()
+    with closing(sock):
+        port = sock.getsockname()[1]
+        sock.close()
+        # there is still a race-condition that some other process
+        # may grab this port before func() runs
+        func(port)
+    """
+    addrs = socket.getaddrinfo(
+        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
+    )
+    for addr in addrs:
+        family, type, proto, _, _ = addr
+        s = socket.socket(family, type, proto)
+        try:
+            s.bind(("localhost", 0))
+            s.listen(0)
+            return s
+        except OSError as e:
+            s.close()
+            log.info("Socket creation attempt failed.", exc_info=e)
+    raise RuntimeError("Failed to create a socket")
+
+
+def _get_fq_hostname() -> str:
+    return socket.getfqdn(socket.gethostname())
+
+
+class ElasticAgent(abc.ABC):
+    """An agent process responsible for managing one or more worker processes.
+
+    The worker processes are assumed to be regular distributed PyTorch scripts.
+    When the worker process is created by the agent, the agent provides the
+    necessary information for the worker processes to properly initialize
+    a torch process group.
+
+    The exact deployment topology and ratio of agent-to-worker is dependent
+    on the specific implementation of the agent and the user's job placement
+    preferences. For instance, to run a distributed training job on GPU with
+    8 trainers (one per GPU) one can:
+
+    1. Use 8 x single GPU instances, place an agent per instance, managing
+       1 worker per agent.
+    2. Use 4 x double GPU instances, place an agent per instance, managing
+       2 workers per agent.
+    3. Use 2 x quad GPU instances, place an agent per instance, managing
+       4 workers per agent.
+    4. Use 1 x 8 GPU instance, place an agent per instance, managing
+       8 workers per agent.
+
+    Usage
+    ::
+
+     group_result = agent.run()
+     if group_result.is_failed():
+         # workers failed
+         failure = group_result.failures[0]
+         log.exception("worker 0 failed with exit code : %s", failure.exit_code)
+     else:
+         return group_result.return_values[0]  # return rank 0's results
+
+    """
+
+    @abc.abstractmethod
+    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
+        """Run the agent.
+
+        Supports retrying the worker group on failures up to ``max_restarts``.
+
+        Returns:
+            The result of the execution, containing the return values or
+            failure details for each worker mapped by the worker's global rank.
+
+        Raises:
+            Exception - any other failures NOT related to worker process
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
+        """Return the ``WorkerGroup`` for the given ``role``.
+
+        Note that the worker group is a mutable object and hence in a
+        multi-threaded/process environment it may change state.
+        Implementors are encouraged (but not required) to return
+        a defensive read-only copy.
+        """
+        raise NotImplementedError()
+
+
+class SimpleElasticAgent(ElasticAgent):
+    """An ``ElasticAgent`` that manages one particular type of worker role.
+
+    An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec``
+    such as one particular type of worker role.
+    """
+
+    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
+        self._worker_group = WorkerGroup(spec)
+        self._remaining_restarts = self._worker_group.spec.max_restarts
+        self._store = None
+        self._exit_barrier_timeout = exit_barrier_timeout
+        self._total_execution_time = 0
+
+    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
+        return self._worker_group
+
+    @abc.abstractmethod
+    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
+        r"""Start ``worker_group.spec.local_world_size`` number of workers.
+
+        This is according to worker spec for the worker group .
+        Returns a map of ``local_rank`` to worker ``id``.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _stop_workers(self, worker_group: WorkerGroup) -> None:
+        r"""Stop all workers in the given worker group.
+
+        Implementors must deal with workers in all states defined by
+        ``WorkerState``. That is, it must gracefully handle stopping
+        non-existent workers, unhealthy (stuck) workers, etc.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
+        r"""Check on the workers for the ``worker_group``.
+
+        This function also returns the new state of the worker group.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
+        """Clean up any resources that were allocated during the agent's work.
+
+        Args:
+            death_sig: Signal to send to the child process, SIGTERM is default
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    def _set_master_addr_port(
+        store: Store,
+        master_addr: Optional[str],
+        master_port: Optional[int],
+        local_addr: Optional[str],
+    ):
+        if master_port is None:
+            sock = _get_socket_with_port()
+            with closing(sock):
+                master_port = sock.getsockname()[1]
+
+        if master_addr is None:
+            # If the user specified an address for the local node, use it as the master addr
+            if local_addr:
+                master_addr = local_addr
+            else:
+                master_addr = _get_fq_hostname()
+
+        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
+        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
+
+    @staticmethod
+    def _get_master_addr_port(store: Store) -> Tuple[str, int]:
+        master_addr = store.get("MASTER_ADDR").decode(encoding="UTF-8")
+        master_port = int(store.get("MASTER_PORT").decode(encoding="UTF-8"))
+        return (master_addr, master_port)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _rendezvous(self, worker_group: WorkerGroup) -> None:
+        r"""Run rendezvous for the workers specified by the worker spec.
+
+        Assigns workers a new global rank and world size.
+        Updates the rendezvous store for the worker group.
+        """
+        spec = worker_group.spec
+
+        store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
+        self._store = store
+
+        workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
+        worker_group.workers = workers
+        worker_group.store = store
+        worker_group.group_rank = group_rank
+        worker_group.group_world_size = group_world_size
+
+        if group_rank == 0:
+            self._set_master_addr_port(
+                store,
+                spec.master_addr,
+                spec.master_port,
+                spec.local_addr,
+            )
+
+        master_addr, master_port = self._get_master_addr_port(store)
+        restart_count = spec.max_restarts - self._remaining_restarts
+
+        log.info(
+            "[%(role)s] Rendezvous complete for workers. Result:\n"
+            "  restart_count=%(restart_count)s\n"
+            "  master_addr=%(master_addr)s\n"
+            "  master_port=%(master_port)s\n"
+            "  group_rank=%(group_rank)s\n"
+            "  group_world_size=%(group_world_size)s\n"
+            "  local_ranks=%(local_ranks)s\n"
+            "  role_ranks=%(role_ranks)s\n"
+            "  global_ranks=%(global_ranks)s\n"
+            "  role_world_sizes=%(role_world_sizes)s\n"
+            "  global_world_sizes=%(global_world_sizes)s\n",
+            {
+                "role": spec.role,
+                "restart_count": restart_count,
+                "master_addr": master_addr,
+                "master_port": master_port,
+                "group_rank": group_rank,
+                "group_world_size": group_world_size,
+                "local_ranks": [worker.local_rank for worker in workers],
+                "role_ranks": [worker.role_rank for worker in workers],
+                "global_ranks": [worker.global_rank for worker in workers],
+                "role_world_sizes": [worker.role_world_size for worker in workers],
+                "global_world_sizes": [worker.world_size for worker in workers]
+            }
+        )
+
+    def _get_ranks(
+        self,
+        role_infos: List[_RoleInstanceInfo],
+        role_idx: int,
+        start_idx: int = 0,
+        end_idx: int = -1,
+    ) -> Tuple[int, List[int]]:
+        if end_idx == -1:
+            end_idx = len(role_infos)
+        prefix_sum = 0
+        total_sum = 0
+        for idx in range(start_idx, end_idx):
+            if role_idx > idx:
+                prefix_sum += role_infos[idx].local_world_size
+            total_sum += role_infos[idx].local_world_size
+        return (
+            total_sum,
+            list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
+        )
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _assign_worker_ranks(
+        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
+    ) -> List[Worker]:
+        """Determine proper ranks for worker processes.
+
+        The rank assignment is done according to the following algorithm:
+
+        1. Each agent writes its configuration (group_rank, group_world_size,
+           num_workers) to the common store.
+        2. Each agent retrieves the configuration for all agents
+           and performs a two-level sort using role and rank.
+        3. Determine the global rank: the global rank of the workers for the current
+           agent is the offset into the infos array up to the group_rank of the agent.
+           The offset is computed as the sum of the local_world_size of all agents that
+           have a rank less than the group_rank. The workers would have the ranks:
+           [offset, offset+local_world_size)
+        4. Determine the role rank: the role rank is determined using the algorithm
+           in point 3, with the exception that the offset starts from the first
+           agent that has the same role as the current one and has the minimum group rank.
+        """
+        role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
+        my_role_info = role_infos[group_rank]
+        worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
+        role_infos = sorted(
+            role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
+        )
+        role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
+            role_infos, my_role_info.role
+        )
+        role_pos = next(
+            idx
+            for idx, role_info in enumerate(role_infos)
+            if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
+        )
+        role_world_size, role_ranks = self._get_ranks(
+            role_infos, role_pos, role_start_idx, role_end_idx + 1
+        )
+        workers = []
+        for ind in range(spec.local_world_size):
+            worker = Worker(
+                local_rank=ind,
+                global_rank=worker_global_ranks[ind],
+                role_rank=role_ranks[ind],
+                world_size=worker_world_size,
+                role_world_size=role_world_size,
+            )
+            workers.append(worker)
+        return workers
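+    # A worked example of the rank-offset computation in ``_get_ranks`` above
+    # (pure arithmetic, illustrative values): with three agents whose local_world_size
+    # values are [2, 4, 2], the agent at group_rank=1 gets
+    #
+    #     prefix_sum = 2                          # workers of agents with a lower rank
+    #     total_sum = 8                           # global world size
+    #     global_ranks = list(range(2, 2 + 4))    # -> [2, 3, 4, 5]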
+
+    def _share_and_gather(
+        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
+    ) -> List:
+        agent_role_info = _RoleInstanceInfo(
+            spec.role, group_rank, spec.local_world_size
+        )
+        key_prefix = "torchelastic/role_info"
+        agent_config_enc = agent_role_info.serialize()
+        role_infos_bytes = store_util.synchronize(
+            store, agent_config_enc, group_rank, group_world_size, key_prefix
+        )
+        role_infos = [
+            _RoleInstanceInfo.deserialize(role_info_bytes)
+            for role_info_bytes in role_infos_bytes
+        ]
+        return role_infos
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
+        r"""Start a fresh set of workers for the worker_group.
+
+        Essentially, a rendezvous followed by a ``start_workers``.
+        The caller should first call ``_stop_workers()`` to stop running workers
+        prior to calling this method.
+
+        Optimistically sets the state of the worker group that
+        just started as ``HEALTHY`` and delegates the actual monitoring
+        of state to the ``_monitor_workers()`` method.
+        """
+        role = worker_group.spec.role
+        log.info("[%s] Rendezvous'ing worker group", role)
+
+        # TODO after stopping workers, wait at least monitor_interval*2 for
+        # workers on different nodes to fail on a collective op before waiting
+        # on the rdzv barrier, this way we ensure that nodes enter rdzv
+        # at around the same time and reduce false positive rdzv timeout errors
+        self._rendezvous(worker_group)
+
+        log.info("[%s] Starting worker group", role)
+        worker_ids = self._start_workers(worker_group)
+        for local_rank, w_id in worker_ids.items():
+            worker = worker_group.workers[local_rank]
+            worker.id = w_id
+
+        worker_group.state = WorkerState.HEALTHY
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _restart_workers(self, worker_group: WorkerGroup) -> None:
+        """Restart (stops, rendezvous, starts) all local workers in the group."""
+        role = worker_group.spec.role
+        log.info("[%s] Stopping worker group", role)
+        self._stop_workers(worker_group)
+        worker_group.state = WorkerState.STOPPED
+        self._initialize_workers(worker_group)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
+        start_time = time.monotonic()
+        shutdown_called: bool = False
+        try:
+            result = self._invoke_run(role)
+            self._total_execution_time = int(time.monotonic() - start_time)
+            self._record_metrics(result)
+            self._record_worker_events(result)
+            return result
+        except RendezvousGracefulExitError as e:
+            log.info("Rendezvous gracefully exited: %s", e)
+        except SignalException as e:
+            log.warning("Received %s death signal, shutting down workers", e.sigval)
+            self._shutdown(e.sigval)
+            shutdown_called = True
+            raise
+        finally:
+            if not shutdown_called:
+                self._shutdown()
+            # record the execution time in case there were any exceptions during run.
+            self._total_execution_time = int(time.monotonic() - start_time)
+
+    def get_event_failed(self) -> Event:
+        return self._construct_event(
+            state="FAILED",
+            source=EventSource.AGENT,
+            raw_error=traceback.format_exc(),
+        )
+
+    def get_event_succeeded(self) -> Event:
+        return self._construct_event(
+            state="SUCCEEDED",
+            source=EventSource.AGENT,
+        )
+
+    def _record_worker_events(self, result: RunResult) -> None:
+        for worker in self._worker_group.workers:
+            failure = result.failures.get(worker.global_rank)
+            state: str = self._get_worker_state(worker, result)
+            raw_error = json.dumps(failure.error_file_data) if failure else None
+            record(self._construct_event(state, EventSource.WORKER, worker, raw_error))
+
+    def _get_worker_state(self, worker: Worker, result: RunResult) -> str:
+        failure = result.failures.get(worker.global_rank)
+        if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure:
+            # The worker got terminated by the torchelastic agent via SIGTERM signal
+            return "TERMINATED"
+        elif failure or worker.global_rank in result.return_values:
+            return result.state.value
+        else:
+            raise ValueError(f"Unknown worker: {worker.global_rank}")
+
+    def _construct_event(
+        self,
+        state: str,
+        source: EventSource,
+        worker: Optional[Worker] = None,
+        raw_error: Optional[str] = None,
+    ) -> Event:
+        wg = self._worker_group
+        spec = wg.spec
+        md = {
+            "group_world_size": wg.group_world_size,
+            "entry_point": spec.get_entrypoint_name(),
+        }
+        if worker:
+            md["local_rank"] = (worker.local_rank,)
+            md["role_rank"] = (worker.role_rank,)
+            md["role_world_size"] = (worker.role_world_size,)
+            global_rank = worker.global_rank
+            worker_id = str(worker.id)
+        else:
+            global_rank = None
+            worker_id = None
+        md_str = json.dumps(md)
+        metadata = {
+            "run_id": spec.rdzv_handler.get_run_id(),
+            "global_rank": global_rank,
+            "group_rank": wg.group_rank,
+            "worker_id": worker_id,
+            "role": spec.role,
+            "hostname": _get_fq_hostname(),
+            "state": state,
+            "total_run_time": self._total_execution_time,
+            "rdzv_backend": spec.rdzv_handler.get_backend(),
+            "raw_error": raw_error,
+            "metadata": md_str,
+            "agent_restarts": spec.max_restarts - self._remaining_restarts,
+        }
+        return Event(
+            f"torchelastic.worker.status.{state}", source=source, metadata=metadata
+        )
+
+    def _record_metrics(self, group_results: RunResult):
+        is_failed = group_results.is_failed()
+        self._record_flakiness_metric(is_failed)
+        spec = self._worker_group.spec
+        restarts_happened = self._remaining_restarts != spec.max_restarts
+        put_metric(f"workers.{spec.role}.run_total", 1)
+        self._record_metric_with_condition(
+            "run_success_with_retries", not is_failed and restarts_happened
+        )
+        self._record_metric_with_condition(
+            "run_success_no_retries", not is_failed and not restarts_happened
+        )
+        self._record_metric_with_condition(
+            "run_failed_with_retries", is_failed and restarts_happened
+        )
+        self._record_metric_with_condition(
+            "run_failed_no_retries", is_failed and not restarts_happened
+        )
+
+    def _record_metric_with_condition(self, metric_name, condition):
+        spec = self._worker_group.spec
+        if condition:
+            put_metric(f"workers.{spec.role}.{metric_name}", 1)
+        else:
+            put_metric(f"workers.{spec.role}.{metric_name}", 0)
+
+    def _record_flakiness_metric(self, is_failed: bool = False):
+        spec = self._worker_group.spec
+        if is_failed:
+            flakiness = 100.0
+        else:
+            flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
+                spec.max_restarts + 1
+            )
+
+        put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
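+        # Worked example (assumed numbers): with max_restarts=3 and one restart consumed
+        # (remaining_restarts=2), a non-failed run reports flakiness = 100 - 100*3/4 = 25.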
+
+    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
+        # NOTE: currently only works for a single role
+
+        spec = self._worker_group.spec
+        role = spec.role
+
+        log.info(
+            "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
+        )
+
+        self._initialize_workers(self._worker_group)
+        monitor_interval = spec.monitor_interval
+        rdzv_handler = spec.rdzv_handler
+
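+        # Summary of the monitor loop below:
+        #   SUCCEEDED           -> run the exit barrier and return the result
+        #   UNHEALTHY / FAILED  -> restart while retries remain, otherwise stop and fail
+        #   HEALTHY + waiters   -> restart workers so waiting nodes can join the group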
+        while True:
+            assert self._worker_group.state != WorkerState.INIT
+            time.sleep(monitor_interval)
+            run_result = self._monitor_workers(self._worker_group)
+            state = run_result.state
+            self._worker_group.state = state
+
+            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
+            put_metric(f"workers.{role}.{state.name.lower()}", 1)
+
+            if state == WorkerState.SUCCEEDED:
+                log.info(
+                    "[%s] worker group successfully finished."
+                    " Waiting %s seconds for other agents to finish.",
+                    role, self._exit_barrier_timeout
+                )
+                self._exit_barrier()
+                return run_result
+            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
+                if self._remaining_restarts > 0:
+                    log.info(
+                        "[%s] Worker group %s. "
+                        "%s/%s attempts left;"
+                        " will restart worker group",
+                        role, state.name, self._remaining_restarts, spec.max_restarts
+                    )
+                    self._remaining_restarts -= 1
+                    self._restart_workers(self._worker_group)
+                else:
+                    self._stop_workers(self._worker_group)
+                    self._worker_group.state = WorkerState.FAILED
+                    return run_result
+            elif state == WorkerState.HEALTHY:
+                # membership changes do not count as retries
+                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
+                group_rank = self._worker_group.group_rank
+                if num_nodes_waiting > 0:
+                    log.info(
+                        "[%s] Detected %s "
+                        "new nodes from group_rank=%s; "
+                        "will restart worker group",
+                        role, num_nodes_waiting, group_rank
+                    )
+                    self._restart_workers(self._worker_group)
+            else:
+                raise Exception(f"[{role}] Worker group in {state.name} state")
+
+    def _exit_barrier(self):
+        """
+        Define a barrier that keeps the agent process alive until all workers finish.
+
+        Wait for ``exit_barrier_timeout`` seconds for all agents to finish
+        executing their local workers (either successfully or not). This
+        acts as a safety guard against user scripts that terminate at different
+        times.
+        """
+        log.info(
+            "Local worker group finished (%s). "
+            "Waiting %s seconds for other agents to finish",
+            self._worker_group.state, self._exit_barrier_timeout
+        )
+        start = time.time()
+        try:
+            store_util.barrier(
+                self._store,
+                self._worker_group.group_rank,
+                self._worker_group.group_world_size,
+                key_prefix=_TERMINAL_STATE_SYNC_ID,
+                barrier_timeout=self._exit_barrier_timeout,
+            )
+            log.info(
+                "Done waiting for other agents. Elapsed: %s seconds", time.time() - start
+            )
+        except SignalException as e:
+            log.warning("Got termination signal: %s", e.sigval)
+            raise
+        except Exception:
+            log.exception(
+                "Error waiting on exit barrier. Elapsed: %s seconds",
+                time.time() - start
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f98867024c1df5965a3e54b3097fb63551dc6c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import os
+import signal
+import socket
+from string import Template
+import uuid
+from typing import Any, Dict, Optional, Tuple
+
+import torch.distributed.elastic.timer as timer
+from torch.distributed.elastic import events
+
+from torch.distributed.elastic.agent.server.api import (
+    RunResult,
+    SimpleElasticAgent,
+    WorkerGroup,
+    WorkerSpec,
+    WorkerState,
+)
+from torch.distributed.elastic.events.api import EventMetadataValue
+from torch.distributed.elastic.metrics.api import prof
+from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs
+from torch.distributed.elastic.utils import macros
+from torch.distributed.elastic.utils.logging import get_logger
+
+log = get_logger(__name__)
+
+__all__ = [
+    "LocalElasticAgent",
+    "TORCHELASTIC_ENABLE_FILE_TIMER",
+    "TORCHELASTIC_TIMER_FILE",
+]
+
+TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
+TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
+
+class LocalElasticAgent(SimpleElasticAgent):
+    """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.
+
+    This agent is deployed per host and is configured to spawn ``n`` workers.
+    When using GPUs, ``n`` maps to the number of GPUs available on the host.
+
+    The local agent does not communicate with other local agents deployed on
+    other hosts, even though the workers may communicate across hosts. The worker id
+    is interpreted as a local process. The agent starts and stops all worker
+    processes as a single unit.
+
+
+    The worker function and the arguments passed to it must be compatible with
+    python multiprocessing. To pass multiprocessing data structures to the
+    workers, create the data structure in the same multiprocessing context as
+    the specified ``start_method`` and pass it as a function argument.
+
+    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
+    for other agents to finish. This acts as a safety net to handle cases where
+    workers finish at different times, to prevent agents from viewing workers
+    that finished early as a scale-down event. It is strongly advised that user
+    code ensure workers are terminated in a synchronous manner rather than
+    relying on the ``exit_barrier_timeout``.
+
+    A named-pipe based watchdog can be enabled in ``LocalElasticAgent`` if the
+    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` is set to 1 in the
+    ``LocalElasticAgent`` process.
+    Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE``
+    can be set to a unique file name for the named pipe. If
+    ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent``
+    will internally create a unique file name, assign it to
+    ``TORCHELASTIC_TIMER_FILE``, and propagate this environment variable to
+    the worker processes so they can connect to the same named pipe that
+    ``LocalElasticAgent`` uses.
+
+    Logs are written to the specified log directory. Each log line will be by default
+    prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
+    Log prefixes can be customized by passing a `template string
+    <https://docs.python.org/3/library/string.html#template-strings>`_ as the
+    ``log_line_prefix_template`` argument.
+    The following macros (identifiers) are substituted at runtime:
+    ``${role_name}``, ``${local_rank}``, ``${rank}``. For example, to prefix each log line with
+    the global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:"``.
+
+
+    Example launching function
+
+    ::
+
+        def trainer(args) -> str:
+            return "do train"
+
+        def main():
+            start_method="spawn"
+            shared_queue= multiprocessing.get_context(start_method).Queue()
+            spec = WorkerSpec(
+                        role="trainer",
+                        local_world_size=nproc_per_process,
+                        entrypoint=trainer,
+                        args=("foobar",),
+                        ...)
+            agent = LocalElasticAgent(spec, logs_specs=DefaultLogsSpecs(), start_method=start_method)
+            results = agent.run()
+
+            if results.is_failed():
+                print("trainer failed")
+            else:
+                print(f"rank 0 return value: {results.return_values[0]}")
+                # prints -> rank 0 return value: do train
+
+    Example launching binary
+
+    ::
+
+        def main():
+            spec = WorkerSpec(
+                        role="trainer",
+                        local_world_size=nproc_per_process,
+                        entrypoint="/usr/local/bin/trainer",
+                        args=("--trainer-args", "foobar"),
+                        ...)
+            agent = LocalElasticAgent(spec, logs_specs=DefaultLogsSpecs())
+            results = agent.run()
+
+            if not results.is_failed():
+                print("binary launches do not have return values")
+
+    """
+
+    def __init__(
+        self,
+        spec: WorkerSpec,
+        logs_specs: LogsSpecs,
+        start_method="spawn",
+        exit_barrier_timeout: float = 300,
+        log_line_prefix_template: Optional[str] = None,
+    ):
+        super().__init__(spec, exit_barrier_timeout)
+        self._start_method = start_method
+        self._pcontext: Optional[PContext] = None
+        self._rdzv_handler = spec.rdzv_handler
+        self._log_line_prefix_template = log_line_prefix_template
+        self._worker_watchdog: Optional[timer.FileTimerServer] = None
+        self._logs_specs = logs_specs
+
+
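+    # NOTE (illustrative): the watchdog below is opt-in via environment variables set on
+    # the agent process, e.g. (hypothetical values):
+    #
+    #   os.environ[TORCHELASTIC_ENABLE_FILE_TIMER] = "1"
+    #   os.environ[TORCHELASTIC_TIMER_FILE] = "/tmp/my_watchdog_pipe"  # optional
+    #
+    # When enabled, a FileTimerServer is started and the pipe path is propagated to every
+    # worker's environment so workers can register timers against it.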
+    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
+        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
+        watchdog_enabled = os.getenv(enable_watchdog_env_name)
+        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
+        watchdog_file_path = os.getenv(watchdog_file_env_name)
+        if watchdog_enabled is not None and str(watchdog_enabled) == "1":
+            if watchdog_file_path is None:
+                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
+            log.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
+            self._worker_watchdog = timer.FileTimerServer(
+                file_path=watchdog_file_path,
+                max_interval=0.1,
+                daemon=True,
+                log_event=self._log_watchdog_event)
+            self._worker_watchdog.start()
+            log.info("FileTimerServer started")
+        else:
+            log.info("Environment variable '%s' not found. Do not start FileTimerServer.", enable_watchdog_env_name)
+        # Propagate the watchdog file env to worker processes
+        if watchdog_file_path is not None:
+            for worker_env in envs.values():
+                worker_env[watchdog_file_env_name] = watchdog_file_path
+
+
+    def _get_fq_hostname(self) -> str:
+        return socket.getfqdn(socket.gethostname())
+
+    def _log_watchdog_event(
+        self,
+        name: str,
+        request: Optional[timer.FileTimerRequest],
+    ) -> None:
+        wg = self._worker_group
+        spec = wg.spec
+        md = {
+            "watchdog_event": name
+        }
+        if request is not None:
+            md["worker_pid"] = str(request.worker_pid)
+            md["scope_id"] = request.scope_id
+            md["expiration_time"] = str(request.expiration_time)
+            md["signal"] = str(request.signal)
+        md_str = json.dumps(md)
+        state = "RUNNING"
+        metadata: Dict[str, EventMetadataValue] = {
+            "run_id": spec.rdzv_handler.get_run_id(),
+            "global_rank": None,
+            "group_rank": wg.group_rank,
+            "worker_id": None,
+            "role": spec.role,
+            "hostname": self._get_fq_hostname(),
+            "state": state,
+            "total_run_time": self._total_execution_time,
+            "rdzv_backend": spec.rdzv_handler.get_backend(),
+            "raw_error": None,
+            "metadata": md_str,
+            "agent_restarts": spec.max_restarts - self._remaining_restarts,
+        }
+        # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
+        #       The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
+        event = events.Event(
+            name=name, source=events.EventSource.AGENT, metadata=metadata
+        )
+        events.record(event)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _stop_workers(self, worker_group: WorkerGroup) -> None:
+        self._shutdown()
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
+        spec = worker_group.spec
+        store = worker_group.store
+        assert store is not None
+        master_addr, master_port = super()._get_master_addr_port(store)
+        restart_count = spec.max_restarts - self._remaining_restarts
+
+        use_agent_store = spec.rdzv_handler.get_backend() == "static"
+
+        args: Dict[int, Tuple] = {}
+        envs: Dict[int, Dict[str, str]] = {}
+        log_line_prefixes: Optional[Dict[int, str]] = {} if self._log_line_prefix_template else None
+        for worker in worker_group.workers:
+            local_rank = worker.local_rank
+            worker_env = {
+                "LOCAL_RANK": str(local_rank),
+                "RANK": str(worker.global_rank),
+                "GROUP_RANK": str(worker_group.group_rank),
+                "ROLE_RANK": str(worker.role_rank),
+                "ROLE_NAME": spec.role,
+                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
+                "WORLD_SIZE": str(worker.world_size),
+                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
+                "ROLE_WORLD_SIZE": str(worker.role_world_size),
+                "MASTER_ADDR": master_addr,
+                "MASTER_PORT": str(master_port),
+                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
+                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
+                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
+                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
+                    "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
+                ),
+            }
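+            # Illustrative values (assumed, not from the source): for the second worker on
+            # the first node of a 2-node x 4-proc job this yields LOCAL_RANK="1", RANK="1",
+            # GROUP_RANK="0", ROLE_RANK="1", LOCAL_WORLD_SIZE="4", WORLD_SIZE="8",
+            # GROUP_WORLD_SIZE="2", and ROLE_WORLD_SIZE="8".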
+            if "OMP_NUM_THREADS" in os.environ:
+                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
+
+
+            if self._log_line_prefix_template:
+                log_line_prefix = Template(self._log_line_prefix_template).safe_substitute(
+                    role_name=spec.role,
+                    rank=worker.global_rank,
+                    local_rank=local_rank,)
+                log_line_prefixes[local_rank] = log_line_prefix
+
+            envs[local_rank] = worker_env
+            worker_args = list(spec.args)
+            worker_args = macros.substitute(worker_args, str(local_rank))
+            args[local_rank] = tuple(worker_args)
+
+        self._setup_local_watchdog(envs=envs)
+
+        assert spec.entrypoint is not None
+        assert self._logs_specs is not None
+        self._pcontext = start_processes(
+            name=spec.role,
+            entrypoint=spec.entrypoint,
+            args=args,
+            envs=envs,
+            logs_specs=self._logs_specs,
+            log_line_prefixes=log_line_prefixes,
+            start_method=self._start_method,
+        )
+
+        return self._pcontext.pids()
+
+    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
+        if self._worker_watchdog is not None:
+            self._worker_watchdog.stop()
+            self._worker_watchdog = None
+        if self._pcontext:
+            self._pcontext.close(death_sig)
+        if self._rdzv_handler:
+            self._rdzv_handler.shutdown()
+
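+    # Summary of _monitor_workers below: it returns UNKNOWN if the tracked pids no longer
+    # match the process context, FAILED or SUCCEEDED once the process context reports
+    # completion, and HEALTHY while the workers are still running.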
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    #  `torch.distributed.elastic.metrics.prof`.
+    @prof
+    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
+        role = worker_group.spec.role
+        worker_pids = {w.id for w in worker_group.workers}
+        assert self._pcontext is not None
+        pc_pids = set(self._pcontext.pids().values())
+        if worker_pids != pc_pids:
+            log.error(
+                "[%s] worker pids do not match process_context pids."
+                " Expected: %s, actual: %s",
+                role, worker_pids, pc_pids
+            )
+            return RunResult(state=WorkerState.UNKNOWN)
+
+        result = self._pcontext.wait(0)
+        if result:
+            if result.is_failed():
+                # map local rank failure to global rank
+                worker_failures = {}
+                for local_rank, failure in result.failures.items():
+                    worker = worker_group.workers[local_rank]
+                    worker_failures[worker.global_rank] = failure
+                return RunResult(
+                    state=WorkerState.FAILED,
+                    failures=worker_failures,
+                )
+            else:
+                # copy ret_val_queue into a map with a global ranks
+                workers_ret_vals = {}
+                for local_rank, ret_val in result.return_values.items():
+                    worker = worker_group.workers[local_rank]
+                    workers_ret_vals[worker.global_rank] = ret_val
+                return RunResult(
+                    state=WorkerState.SUCCEEDED,
+                    return_values=workers_ret_vals,
+                )
+        else:
+            return RunResult(state=WorkerState.HEALTHY)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5ff025f8ae3a2a10d3e28ec1b6ef9d5ae4573f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__init__.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Module contains events processing mechanisms that are integrated with the standard python logging.
+
+Example of usage:
+
+::
+
+  from torch.distributed.elastic import events
+  event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
+  events.record(event, destination="console")
+
+"""
+
+import inspect
+import logging
+import os
+import socket
+import traceback
+from enum import Enum
+from typing import Dict, Optional
+
+from torch.distributed.elastic.events.handlers import get_logging_handler
+
+from .api import (  # noqa: F401
+    Event,
+    EventMetadataValue,
+    EventSource,
+    NodeState,
+    RdzvEvent,
+)
+
+_events_loggers: Dict[str, logging.Logger] = {}
+
+def _get_or_create_logger(destination: str = "null") -> logging.Logger:
+    """
+    Construct a python logger based on the destination type, or return the existing one.
+
+    Available destinations can be found in the ``handlers.py`` file.
+    The constructed logger does not propagate messages to upper-level loggers
+    (e.g. the root logger). This ensures that each event is processed only once.
+
+    Args:
+        destination: The string representation of the event handler.
+            Available handlers found in ``handlers`` module
+    """
+    global _events_loggers
+
+    if destination not in _events_loggers:
+        _events_logger = logging.getLogger(f"torchelastic-events-{destination}")
+        _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
+        # Do not propagate message to the root logger
+        _events_logger.propagate = False
+
+        logging_handler = get_logging_handler(destination)
+        _events_logger.addHandler(logging_handler)
+
+        # Add the logger to the global dictionary
+        _events_loggers[destination] = _events_logger
+
+    return _events_loggers[destination]
+
+
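+# Illustrative usage (assumption: the "console" destination prints serialized events):
+#
+#   from torch.distributed.elastic.events import Event, EventSource, record
+#   record(Event("my_event", EventSource.WORKER, metadata={"k": "v"}), destination="console")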
+def record(event: Event, destination: str = "null") -> None:
+    _get_or_create_logger(destination).info(event.serialize())
+
+def record_rdzv_event(event: RdzvEvent) -> None:
+    _get_or_create_logger("dynamic_rendezvous").info(event.serialize())
+
+
+def construct_and_record_rdzv_event(
+    run_id: str,
+    message: str,
+    node_state: NodeState,
+    name: str = "",
+    hostname: str = "",
+    pid: Optional[int] = None,
+    master_endpoint: str = "",
+    local_id: Optional[int] = None,
+    rank: Optional[int] = None,
+) -> None:
+    # We don't want to perform an extra computation if not needed.
+    if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
+        return
+
+    # Set up parameters.
+    if not hostname:
+        hostname = socket.getfqdn()
+    if not pid:
+        pid = os.getpid()
+
+    # Determines which file called this function.
+    callstack = inspect.stack()
+    filename = "no_file"
+    if len(callstack) > 1:
+        stack_depth_1 = callstack[1]
+        filename = os.path.basename(stack_depth_1.filename)
+        if not name:
+            name = stack_depth_1.function
+
+    # Delete the callstack variable. If kept, this can mess with python's
+    # garbage collector as we are holding on to stack frame information in
+    # the inspect module.
+    del callstack
+
+    # Set up error trace if this is an exception
+    if node_state == NodeState.FAILED:
+        error_trace = traceback.format_exc()
+    else:
+        error_trace = ""
+
+    # Initialize event object
+    event = RdzvEvent(
+        name=f"{filename}:{name}",
+        run_id=run_id,
+        message=message,
+        hostname=hostname,
+        pid=pid,
+        node_state=node_state,
+        master_endpoint=master_endpoint,
+        rank=rank,
+        local_id=local_id,
+        error_trace=error_trace,
+    )
+
+    # Finally, record the event.
+    record_rdzv_event(event)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a66c7eee23ed21bb25c273f0c5215005e5f91ac0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d59f704caf9540c1ac4abecb4fb03e88d9623ccd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf4b06bcf5863d79d808a9d82f6b44ca2494bc95
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/events/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..37d1d8947e0ecfcae86c512cad94ede4214970ba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/events/api.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from typing import Dict, Union, Optional
+
+__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']
+
+EventMetadataValue = Union[str, int, float, bool, None]
+
+
+class EventSource(str, Enum):
+    """Known identifiers of the event producers."""
+
+    AGENT = "AGENT"
+    WORKER = "WORKER"
+
+
+@dataclass
+class Event:
+    """
+    Represents a generic event that occurs during torchelastic job execution.
+
+    The event can be any kind of meaningful action.
+
+    Args:
+        name: event name.
+        source: the event producer, e.g. agent or worker
+        timestamp: timestamp in milliseconds when event occurred.
+        metadata: additional data that is associated with the event.
+    """
+
+    name: str
+    source: EventSource
+    timestamp: int = 0
+    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
+
+    def __str__(self):
+        return self.serialize()
+
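+    # Illustrative round trip (assumed values): Event.deserialize(e.serialize()) reconstructs
+    # an equal Event for e = Event("test", EventSource.AGENT, metadata={"k": 1}).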
+    @staticmethod
+    def deserialize(data: Union[str, "Event"]) -> "Event":
+        if isinstance(data, Event):
+            return data
+        if isinstance(data, str):
+            data_dict = json.loads(data)
+        data_dict["source"] = EventSource[data_dict["source"]]  # type: ignore[possibly-undefined]
+        return Event(**data_dict)
+
+    def serialize(self) -> str:
+        return json.dumps(asdict(self))
+
+
+class NodeState(str, Enum):
+    """The states that a node can be in rendezvous."""
+
+    INIT = "INIT"
+    RUNNING = "RUNNING"
+    SUCCEEDED = "SUCCEEDED"
+    FAILED = "FAILED"
+
+
+@dataclass
+class RdzvEvent:
+    """
+    Dataclass to represent any rendezvous event.
+
+    Args:
+        name: Event name. (E.g. Current action being performed)
+        run_id: The run id of the rendezvous
+        message: The message describing the event
+        hostname: Hostname of the node
+        pid: The process id of the node
+        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
+        master_endpoint: The master endpoint for the rendezvous store, if known
+        rank: The rank of the node, if known
+        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
+        error_trace: Error stack trace, if this is an error event.
+    """
+
+    name: str
+    run_id: str
+    message: str
+    hostname: str
+    pid: int
+    node_state: NodeState
+    master_endpoint: str = ""
+    rank: Optional[int] = None
+    local_id: Optional[int] = None
+    error_trace: str = ""
+
+    def __str__(self):
+        return self.serialize()
+
+    @staticmethod
+    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
+        if isinstance(data, RdzvEvent):
+            return data
+        if isinstance(data, str):
+            data_dict = json.loads(data)
+        data_dict["node_state"] = NodeState[data_dict["node_state"]]  # type: ignore[possibly-undefined]
+        return RdzvEvent(**data_dict)
+
+    def serialize(self) -> str:
+        return json.dumps(asdict(self))
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/events/handlers.py b/MLPY/Lib/site-packages/torch/distributed/elastic/events/handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..51dd142801ba1f3d597d41da7df9121aef006fe7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/events/handlers.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict
+
+
+_log_handlers: Dict[str, logging.Handler] = {
+    "console": logging.StreamHandler(),
+    "dynamic_rendezvous": logging.NullHandler(),
+    "null": logging.NullHandler(),
+}
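+# NOTE (illustrative assumption, not an upstream API): additional destinations could be
+# registered by adding entries to ``_log_handlers`` before get_logging_handler is called,
+# e.g. _log_handlers["file"] = logging.FileHandler("/tmp/torchelastic_events.log").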
+
+
+def get_logging_handler(destination: str = "null") -> logging.Handler:
+    global _log_handlers
+    return _log_handlers[destination]
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc47774f8bc60e40663b5cb6e5f703afc1ad2b8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__init__.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Metrics API.
+
+**Overview**:
+
+The metrics API in torchelastic is used to publish telemetry metrics.
+It is designed to be used by torchelastic's internal modules to
+publish metrics for the end user with the goal of increasing visibility
+and helping with debugging. However, you may use the same API in your
+jobs to publish metrics to the same metrics ``sink``.
+
+A ``metric`` can be thought of as timeseries data
+and is uniquely identified by the string-valued tuple
+``(metric_group, metric_name)``.
+
+torchelastic makes no assumptions about what a ``metric_group`` is
+and what relationship it has with ``metric_name``. It is totally up
+to the user to use these two fields to uniquely identify a metric.
+
+.. note:: The metric group ``torchelastic`` is reserved by torchelastic for
+          platform level metrics that it produces.
+          For instance torchelastic may output the latency (in milliseconds)
+          of a re-rendezvous operation from the agent as
+          ``(torchelastic, agent.rendezvous.duration.ms)``
+
+A sensible way to use metric groups is to map them to a stage or module
+in your job. You may also encode certain high-level properties of
+the job, such as the region or stage (dev vs prod).
+
+**Publish Metrics**:
+
+Using torchelastic's metrics API is similar to using python's logging
+framework. You first have to configure a metrics handler before
+trying to add metric data.
+
+The example below measures the latency for the ``calculate()`` function.
+
+::
+
+  import time
+  import torch.distributed.elastic.metrics as metrics
+
+  # makes all metrics other than the ones from "my_module" go to /dev/null
+  metrics.configure(metrics.NullMetricHandler())
+  metrics.configure(metrics.ConsoleMetricHandler(), "my_module")
+
+  def my_method():
+    start = time.time()
+    calculate()
+    end = time.time()
+    metrics.put_metric("calculate_latency", int(end-start), "my_module")
+
+You may also use the ``torch.distributed.elastic.metrics.prof`` decorator
+to conveniently and succinctly profile functions.
+
+::
+
+  # -- in module examples.foobar --
+
+  import torch.distributed.elastic.metrics as metrics
+
+  metrics.configure(metrics.ConsoleMetricHandler(), "foobar")
+  metrics.configure(metrics.ConsoleMetricHandler(), "Bar")
+
+  @metrics.prof
+  def foo():
+    pass
+
+  class Bar():
+
+    @metrics.prof
+    def baz(self):
+        pass
+
+``@metrics.prof`` will publish the following metrics
+::
+
+  .success - 1 if the function finished successfully
+  .failure - 1 if the function threw an exception
+  .duration.ms - function duration in milliseconds
+
+**Configuring Metrics Handler**:
+
+``torch.distributed.elastic.metrics.MetricHandler`` is responsible for emitting
+the added metric values to a particular destination. Metric groups can be
+configured with different metric handlers.
+
+By default torchelastic emits all metrics to ``/dev/null``.
+With the following configuration, metrics in the
+``torchelastic`` and ``my_app`` metric groups will be printed out to
+the console.
+
+::
+
+  import torch.distributed.elastic.metrics as metrics
+
+  metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
+  metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
+
+**Writing a Custom Metric Handler**:
+
+If you want your metrics to be emitted to a custom location, implement
+the ``torch.distributed.elastic.metrics.MetricHandler`` interface
+and configure your job to use your custom metric handler.
+
+Below is a toy example that prints the metrics to ``stdout``
+
+::
+
+  import torch.distributed.elastic.metrics as metrics
+
+  class StdoutMetricHandler(metrics.MetricHandler):
+     def emit(self, metric_data):
+         ts = metric_data.timestamp
+         group = metric_data.group_name
+         name = metric_data.name
+         value = metric_data.value
+         print(f"[{ts}][{group}]: {name}={value}")
+
+  metrics.configure(StdoutMetricHandler(), group="my_app")
+
+Now all metrics in the group ``my_app`` will be printed to stdout as:
+
+::
+
+  [1574213883.4182858][my_app]: my_metric=<value>
+  [1574213940.5237644][my_app]: my_metric=<value>
+
+"""
+
+from typing import Optional
+
+from .api import (  # noqa: F401
+    ConsoleMetricHandler,
+    MetricData,
+    MetricHandler,
+    MetricsConfig,
+    NullMetricHandler,
+    configure,
+    get_elapsed_time_ms,
+    getStream,
+    prof,
+    profile,
+    publish_metric,
+    put_metric,
+)
+
+
+def initialize_metrics(cfg: Optional[MetricsConfig] = None):
+    pass
+
+
+try:
+    from torch.distributed.elastic.metrics.static_init import *  # type: ignore[import] # noqa: F401 F403
+except ModuleNotFoundError:
+    pass
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44096ca0ea2e27b0c28ae614341d8ac24304db0d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c23139ef26a2d91cfd79842f96fcdcad9ec00422
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe2ef22bd53edea9da695dde06259510d569ff8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/metrics/api.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+import time
+import warnings
+from collections import namedtuple
+from functools import wraps
+from typing import Dict, Optional
+
+__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream',
+           'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms',
+           'MetricData']
+
+MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
+
+
+class MetricsConfig:
+    __slots__ = ["params"]
+
+    def __init__(self, params: Optional[Dict[str, str]] = None):
+        self.params = params
+        if self.params is None:
+            self.params = {}
+
+
+class MetricHandler(abc.ABC):
+    @abc.abstractmethod
+    def emit(self, metric_data: MetricData):
+        pass
+
+
+class ConsoleMetricHandler(MetricHandler):
+    def emit(self, metric_data: MetricData):
+        print(
+            f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
+        )
+
+
+class NullMetricHandler(MetricHandler):
+    def emit(self, metric_data: MetricData):
+        pass
+
+
+class MetricStream:
+    def __init__(self, group_name: str, handler: MetricHandler):
+        self.group_name = group_name
+        self.handler = handler
+
+    def add_value(self, metric_name: str, metric_value: int):
+        self.handler.emit(
+            MetricData(time.time(), self.group_name, metric_name, metric_value)
+        )
+
+
+_metrics_map: Dict[str, MetricHandler] = {}
+_default_metrics_handler: MetricHandler = NullMetricHandler()
+
+
+# pyre-fixme[9]: group has type `str`; used as `None`.
+def configure(handler: MetricHandler, group: Optional[str] = None):
+    if group is None:
+        global _default_metrics_handler
+        # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
+        #  as `MetricHandler`.
+        _default_metrics_handler = handler
+    else:
+        _metrics_map[group] = handler
+
+
+def getStream(group: str):
+    if group in _metrics_map:
+        handler = _metrics_map[group]
+    else:
+        handler = _default_metrics_handler
+    return MetricStream(group, handler)
+
+
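+# Illustrative behavior (assumed names): a method ``Bar.baz`` maps to "Bar.baz", while a
+# module-level function ``foo`` defined in module ``examples.foobar`` maps to "foobar.foo".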
+def _get_metric_name(fn):
+    qualname = fn.__qualname__
+    split = qualname.split(".")
+    if len(split) == 1:
+        module = fn.__module__
+        if module:
+            return module.split(".")[-1] + "." + split[0]
+        else:
+            return split[0]
+    else:
+        return qualname
+
+
+def prof(fn=None, group: str = "torchelastic"):
+    r"""
+    The ``@prof`` decorator publishes ``duration.ms``, ``success``, and ``failure`` metrics for the function that it decorates.
+
+    The metric name defaults to the qualified name (``class_name.def_name``) of the function.
+    If the function does not belong to a class, it uses the leaf module name instead.
+
+    Usage
+
+    ::
+
+     @metrics.prof
+     def x():
+         pass
+
+     @metrics.prof(group="agent")
+     def y():
+         pass
+    """
+
+    def wrap(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            key = _get_metric_name(f)
+            try:
+                start = time.time()
+                result = f(*args, **kwargs)
+                put_metric(f"{key}.success", 1, group)
+            except Exception:
+                put_metric(f"{key}.failure", 1, group)
+                raise
+            finally:
+                put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)  # type: ignore[possibly-undefined]
+            return result
+
+        return wrapper
+
+    if fn:
+        return wrap(fn)
+    else:
+        return wrap
+
+
+def profile(group=None):
+    """
+    @profile decorator adds latency and success/failure metrics to any given function.
+
+    Usage
+
+    ::
+
+     @metrics.profile("my_metric_group")
+     def some_function():
+    """
+    warnings.warn("Deprecated, use @prof instead", DeprecationWarning)
+
+    def wrap(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                start_time = time.time()
+                result = func(*args, **kwargs)
+                publish_metric(group, f"{func.__name__}.success", 1)
+            except Exception:
+                publish_metric(group, f"{func.__name__}.failure", 1)
+                raise
+            finally:
+                publish_metric(
+                    group,
+                    f"{func.__name__}.duration.ms",
+                    get_elapsed_time_ms(start_time),  # type: ignore[possibly-undefined]
+                )
+            return result
+
+        return wrapper
+
+    return wrap
+
+
+def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
+    """
+    Publish a metric data point.
+
+    Usage
+
+    ::
+
+     put_metric("metric_name", 1)
+     put_metric("metric_name", 1, "metric_group_name")
+    """
+    getStream(metric_group).add_value(metric_name, metric_value)
+
+
+def publish_metric(metric_group: str, metric_name: str, metric_value: int):
+    warnings.warn(
+        "Deprecated, use put_metric(metric_name, metric_value, metric_group) instead"
+    )
+    metric_stream = getStream(metric_group)
+    metric_stream.add_value(metric_name, metric_value)
+
+
+def get_elapsed_time_ms(start_time_in_seconds: float):
+    """Return the elapsed time in millis from the given start time."""
+    end_time = time.time()
+    return int((end_time - start_time_in_seconds) * 1000)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c7beb6b0ad259bbe3240495a4285fa1a7c6dd19
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__init__.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Library that launches and manages ``n`` copies of worker subprocesses, specified either by a function or a binary.
+
+For functions, it uses ``torch.multiprocessing`` (and therefore python
+``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
+``subprocess.Popen`` to create worker processes.
+
+
+Usage 1: Launching two trainers as a function
+
+::
+
+ from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std, start_processes
+
+ def trainer(a, b, c):
+     pass # train
+
+
+ # runs two trainers
+ # LOCAL_RANK=0 trainer(1,2,3)
+ # LOCAL_RANK=1 trainer(4,5,6)
+ ctx = start_processes(
+         name="trainer",
+         entrypoint=trainer,
+         args={0: (1,2,3), 1: (4,5,6)},
+         envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
+         logs_specs=DefaultLogsSpecs(
+             log_dir="/tmp/foobar",
+             redirects=Std.ALL, # write all worker stdout/stderr to a log file
+             tee={0: Std.ERR}, # tee only local rank 0's stderr to console
+         ),
+       )
+
+ # waits for all copies of trainer to finish
+ ctx.wait()
+
+Usage 2: Launching 2 echo workers as a binary
+
+::
+
+ # same as invoking
+ # echo hello
+ # echo world > stdout.log
+ ctx = start_processes(
+         name="echo",
+         entrypoint="echo",
+         args={0: ("hello",), 1: ("world",)},
+         envs={0: {}, 1: {}},
+         logs_specs=DefaultLogsSpecs(
+             log_dir="/tmp/foobar",
+             redirects={1: Std.OUT},
+         ),
+        )
+
+Just like ``torch.multiprocessing``, the return value of the function
+:func:`start_processes` is a process context (:class:`api.PContext`). If a function
+was launched, a :class:`api.MultiprocessContext` is returned and if a binary
+was launched a :class:`api.SubprocessContext` is returned. Both are specific
+implementations of the parent :class:`api.PContext` class.
+"""
+
+import os
+from typing import Callable, Dict, Optional, Tuple, Union, Set
+
+from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
+    _validate_full_rank,
+    DefaultLogsSpecs,
+    LogsDest,
+    LogsSpecs,
+    MultiprocessContext,
+    PContext,
+    ProcessFailure,
+    RunProcsResult,
+    SignalException,
+    Std,
+    SubprocessContext,
+    to_map,
+)
+from torch.distributed.elastic.utils.logging import get_logger
+
+__all__ = [
+    "start_processes",
+    "MultiprocessContext",
+    "PContext",
+    "ProcessFailure",
+    "RunProcsResult",
+    "SignalException",
+    "Std",
+    "LogsDest",
+    "LogsSpecs",
+    "DefaultLogsSpecs",
+    "SubprocessContext",
+    "to_map",
+]
+
+log = get_logger(__name__)
+
+
+def start_processes(
+    name: str,
+    entrypoint: Union[Callable, str],
+    args: Dict[int, Tuple],
+    envs: Dict[int, Dict[str, str]],
+    logs_specs: LogsSpecs,
+    log_line_prefixes: Optional[Dict[int, str]] = None,
+    start_method: str = "spawn",
+) -> PContext:
+    """
+    Start ``n`` copies of ``entrypoint`` processes with the provided options.
+
+    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
+    The number of copies is determined by the number of entries for ``args`` and
+    ``envs`` arguments, which need to have the same key set.
+
+    ``args`` and ``envs`` parameters are the arguments and environment variables
+    to pass down to the entrypoint mapped by the replica index (local rank).
+    All local ranks must be accounted for.
+    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.
+
+    .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
+              If any other type is given, then it is cast to a string representation
+              (e.g. ``str(arg1)``). Furthermore, a binary failure will only write
+              an ``error.json`` error file if the main function is annotated with
+              ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches,
+              this is done by default and there is no need to manually annotate
+              with the ``@record`` annotation.
+
+    ``redirects`` and ``tee`` (supplied via ``logs_specs``) are bitmasks specifying which
+    std stream(s) to redirect to a log file in the log directory. Valid mask values are defined in ``Std``.
+    To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
+    the local rank to specify the redirect behavior for.
+    Any missing local ranks will default to ``Std.NONE``.
+
+    ``tee`` acts like the unix "tee" command in that it redirects + prints to console.
+    To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
+
+    For each process, the ``log_dir`` will contain:
+
+    #. ``{local_rank}/error.json``: if the process failed, a file with the error info
+    #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT``
+    #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR``
+
+    .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
+
+    Example:
+    ::
+
+     logs_specs = DefaultLogsSpecs(log_dir="/tmp/test")
+
+     # ok; two copies of foo: foo("bar0"), foo("bar1")
+     start_processes(
+        name="trainer",
+        entrypoint=foo,
+        args={0: ("bar0",), 1: ("bar1",)},
+        envs={0: {}, 1: {}},
+        logs_specs=logs_specs,
+     )
+
+     # invalid; envs missing for local rank 1
+     start_processes(
+        name="trainer",
+        entrypoint=foo,
+        args={0: ("bar0",), 1: ("bar1",)},
+        envs={0: {}},
+        logs_specs=logs_specs,
+     )
+
+     # ok; two copies of /usr/bin/touch: touch file1, touch file2
+     start_processes(
+        name="trainer",
+        entrypoint="/usr/bin/touch",
+        args={0: ("file1",), 1: ("file2",)},
+        envs={0: {}, 1: {}},
+        logs_specs=logs_specs,
+      )
+
+     # caution; arguments cast to string, runs:
+     # echo "1" "2" "3" and echo "[1, 2, 3]"
+     start_processes(
+        name="trainer",
+        entrypoint="/usr/bin/echo",
+        args={0: (1, 2, 3), 1: ([1, 2, 3],)},
+        envs={0: {}, 1: {}},
+        logs_specs=logs_specs,
+      )
+
+    Args:
+        name: a human readable short name that describes what the processes are
+              (used as header when tee'ing stdout/stderr outputs)
+        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
+        args: arguments to each replica
+        envs: env vars to each replica
+        logs_specs: defines the log directory, redirects, tee and rank filters
+                    for the worker processes
+        log_line_prefixes: optional per-rank prefixes prepended to tee'd log lines
+        start_method: multiprocessing start method (spawn, fork, forkserver)
+                      ignored for binaries
+
+    """
+
+    nprocs = len(args)
+    _validate_full_rank(args, nprocs, "args")
+    _validate_full_rank(envs, nprocs, "envs")
+
+    context: PContext
+    if isinstance(entrypoint, str):
+        context = SubprocessContext(
+            name=name,
+            entrypoint=entrypoint,
+            args=args,
+            envs=envs,
+            logs_specs=logs_specs,
+            log_line_prefixes=log_line_prefixes,
+        )
+    else:
+        context = MultiprocessContext(
+            name=name,
+            entrypoint=entrypoint,
+            args=args,
+            envs=envs,
+            log_line_prefixes=log_line_prefixes,
+            start_method=start_method,
+            logs_specs=logs_specs,
+        )
+
+    try:
+        context.start()
+        return context
+    except Exception:
+        context.close()
+        raise
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..369c1c91a1162f966f61cde508d0cb778dd07443
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e5912b0f70f85324bb42d495d7c107497a547fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5d8beee573c6e40fe968bbe11789d1635bac105
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c6df5103b1261e6193e01319799f2087c0ef6f8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f51a224ef3db62039a7dc014e656c15ab2bff23
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/api.py
@@ -0,0 +1,873 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+import logging
+import os
+import re
+import shutil
+import signal
+import subprocess
+import sys
+import tempfile
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass, field
+from enum import IntFlag
+from multiprocessing import synchronize
+from types import FrameType
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
+from abc import ABC, abstractmethod
+
+import torch.multiprocessing as mp
+from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
+from torch.distributed.elastic.multiprocessing.redirects import (
+    redirect_stderr,
+    redirect_stdout,
+)
+
+from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler, get_subprocess_handler
+from torch.distributed.elastic.multiprocessing.tail_log import TailLog
+
+IS_WINDOWS = sys.platform == "win32"
+IS_MACOS = sys.platform == "darwin"
+
+
+log = logging.getLogger(__name__)
+
+__all__ = [
+    "DefaultLogsSpecs",
+    "SignalException",
+    "Std",
+    "to_map",
+    "RunProcsResult",
+    "PContext",
+    "get_std_cm",
+    "MultiprocessContext",
+    "SubprocessContext",
+]
+
+class SignalException(Exception):
+    """
+    Raised inside the torchelastic agent process by the termination handler
+    when the process receives a death signal.
+    """
+
+    def __init__(self, msg: str, sigval: signal.Signals) -> None:
+        super().__init__(msg)
+        self.sigval = sigval
+
+
+def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
+    """Termination handler that raises exceptions on the main process.
+
+    When the process receives a death signal (SIGTERM, SIGINT), this termination handler is
+    invoked. It raises a ``SignalException`` that should be processed by the user code.
+    Python does not terminate the process after the termination handler finishes, so the
+    exception should not be silently ignored; otherwise the process will never be
+    terminated.
+    """
+    sigval = signal.Signals(signum)
+    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
+
+
+def _get_kill_signal() -> signal.Signals:
+    """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
+    if IS_WINDOWS:
+        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
+    else:
+        return signal.SIGKILL
+
+
+def _get_default_signal() -> signal.Signals:
+    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
+    if IS_WINDOWS:
+        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
+    else:
+        return signal.SIGTERM
+
+
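+# NOTE (illustrative): for nprocs=2, both ``args`` and ``envs`` passed to start_processes
+# must have exactly the keys {0, 1}; any mismatch raises the RuntimeError below.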
+def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
+    actual_keys = set(d.keys())
+    expected_keys = set(range(nprocs))
+
+    if actual_keys != expected_keys:
+        raise RuntimeError(
+            f"{what}, local rank mapping mismatch,"
+            f" expected: {expected_keys}, actual: {actual_keys}"
+        )
+
+
+_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
+_VALUE_REGEX = r"^[0123]$"
+
+
+class Std(IntFlag):
+    NONE = 0
+    OUT = 1
+    ERR = 2
+    ALL = OUT | ERR
+
+    @classmethod
+    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
+        """
+        Example:
+        ::
+
+         from_str("0") -> Std.NONE
+         from_str("1") -> Std.OUT
+         from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
+
+        Any other input raises an exception
+        """
+
+        def to_std(v: str) -> Std:  # type: ignore[return]
+            s = Std(int(v))
+            if s in Std:
+                return s
+            # return None -> should NEVER reach here since we regex check input
+
+        if re.match(_VALUE_REGEX, vm):  # vm is a number (e.g. 0)
+            return to_std(vm)
+        elif re.match(_MAPPING_REGEX, vm):  # vm is a mapping (e.g. 0:1,1:2)
+            d: Dict[int, Std] = {}
+            for m in vm.split(","):
+                i, v = m.split(":")
+                d[int(i)] = to_std(v)
+            return d
+        else:
+            raise ValueError(
+                f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
+            )
+
+
+def to_map(
+    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
+) -> Dict[int, Std]:
+    """
+    Certain APIs take redirect settings either as a single value (applied to all
+    local ranks) or as an explicit user-provided mapping. This convenience method
+    normalizes a value or mapping into a mapping.
+
+    Example:
+    ::
+
+     to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
+     to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
+     to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
+    """
+    if isinstance(val_or_map, Std):
+        return dict.fromkeys(range(local_world_size), val_or_map)
+    else:
+        return {i: val_or_map.get(i, Std.NONE) for i in range(local_world_size)}
+
+
+@dataclass
+class LogsDest:
+    """
+    For each log type, holds mapping of local rank ids to file paths.
+    """
+    stdouts: Dict[int, str] = field(default_factory=dict)
+    stderrs: Dict[int, str] = field(default_factory=dict)
+    tee_stdouts: Dict[int, str] = field(default_factory=dict)
+    tee_stderrs: Dict[int, str] = field(default_factory=dict)
+    error_files: Dict[int, str] = field(default_factory=dict)
+
+
+class LogsSpecs(ABC):
+    """
+    Defines logs processing and redirection for each worker process.
+
+    Args:
+        log_dir:
+            Base directory where logs will be written.
+        redirects:
+            Streams to redirect to files. Pass a single ``Std``
+            enum to redirect for all workers, or a mapping keyed
+            by local_rank to selectively redirect.
+        tee:
+            Streams to duplicate to stdout/stderr.
+            Pass a single ``Std`` enum to duplicate streams for all workers,
+            or a mapping keyed by local_rank to selectively duplicate.
+    """
+
+    def __init__(
+        self,
+        log_dir: Optional[str] = None,
+        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
+        tee: Union[Std, Dict[int, Std]] = Std.NONE,
+        local_ranks_filter: Optional[Set[int]] = None,
+    ) -> None:
+        self._root_log_dir = log_dir
+        self._redirects = redirects
+        self._tee = tee
+        self._local_ranks_filter = local_ranks_filter
+
+    @abstractmethod
+    def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest:
+        """
+        Given the environment variables, builds the destination of the log files for each of the local ranks.
+
+        The ``envs`` parameter contains the env variables dict for each of the local ranks, where entries are defined in
+        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def root_log_dir(self) -> str:
+        pass
+
+class DefaultLogsSpecs(LogsSpecs):
+    """
+    Default LogsSpecs implementation:
+
+    - `log_dir` will be created if it doesn't exist
+    - Generates nested folders for each attempt and rank.
+    """
+    def __init__(
+        self,
+        log_dir: Optional[str] = None,
+        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
+        tee: Union[Std, Dict[int, Std]] = Std.NONE,
+        local_ranks_filter: Optional[Set[int]] = None,
+    ) -> None:
+        if log_dir != os.devnull:
+            if not log_dir:
+                log_dir = tempfile.mkdtemp(prefix="torchelastic_")
+            elif not os.path.exists(log_dir):
+                os.makedirs(log_dir)
+            else:
+                if os.path.isfile(log_dir):
+                    raise NotADirectoryError(f"log_dir: {log_dir} is a file")
+        super().__init__(log_dir, redirects, tee, local_ranks_filter)
+        # initialized only once
+        self._run_log_dir = None
+
+    @property
+    def root_log_dir(self) -> str:
+        return str(self._root_log_dir)
+
+    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
+        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
+        os.makedirs(base_log_dir, exist_ok=True)
+        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
+        log.info("log directory set to: %s", dir)
+        return dir
+
+    def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest:
+        """
+        Uses the following scheme to build log destination paths:
+
+        - `<log_dir>/<rdzv_run_id>/attempt_<restart_count>/<local_rank>/stdout.log`
+        - `<log_dir>/<rdzv_run_id>/attempt_<restart_count>/<local_rank>/stderr.log`
+        - `<log_dir>/<rdzv_run_id>/attempt_<restart_count>/<local_rank>/error.json`
+        """
+        nprocs = len(envs)
+        global_env = {}  # used only to query properties that are not rank-dependent
+        if nprocs > 0:
+            global_env = envs[0]
+        else:
+            log.warning("Empty envs map provided when defining logging destinations.")
+        # Keys are always defined, but values can be missing in unit tests
+        run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
+        restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")
+
+        attempt_log_dir: str = ""
+        if self._root_log_dir != os.devnull:
+            if not self._run_log_dir:
+                self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
+
+            attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}")  # type: ignore[call-overload]
+            shutil.rmtree(attempt_log_dir, ignore_errors=True)
+            os.makedirs(attempt_log_dir)
+
+        if self._root_log_dir == os.devnull:
+            attempt_log_dir = os.devnull
+
+        # create subdirs for each local rank in the logs_dir
+        # logs_dir
+        #       |- 0
+        #          |- error.json
+        #          |- stdout.log
+        #          |- stderr.log
+        #       |- ...
+        #       |- (nprocs-1)
+        redirs = to_map(self._redirects, nprocs)
+        ts = to_map(self._tee, nprocs)
+
+        # to tee stdout/stderr we first redirect into a file,
+        # then tail -f stdout.log/stderr.log, so add the tee settings to the redirects
+        for local_rank, tee_std in ts.items():
+            redirect_std = redirs[local_rank]
+            redirs[local_rank] = redirect_std | tee_std
+
+        SYS_STREAM = ""  # special case to indicate to output to console
+        stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
+        stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
+        tee_stdouts: Dict[int, str] = {}
+        tee_stderrs: Dict[int, str] = {}
+        error_files = {}
+
+        for local_rank in range(nprocs):
+
+            if attempt_log_dir == os.devnull:
+                tee_stdouts[local_rank] = os.devnull
+                tee_stderrs[local_rank] = os.devnull
+                error_files[local_rank] = os.devnull
+                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
+            else:
+                clogdir = os.path.join(attempt_log_dir, str(local_rank))
+                os.mkdir(clogdir)
+
+                rd = redirs[local_rank]
+                if (rd & Std.OUT) == Std.OUT:
+                    stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
+                if (rd & Std.ERR) == Std.ERR:
+                    stderrs[local_rank] = os.path.join(clogdir, "stderr.log")
+
+                t = ts[local_rank]
+                if t & Std.OUT == Std.OUT:
+                    tee_stdouts[local_rank] = stdouts[local_rank]
+                if t & Std.ERR == Std.ERR:
+                    tee_stderrs[local_rank] = stderrs[local_rank]
+
+                if self._local_ranks_filter and local_rank not in self._local_ranks_filter:
+                    # If stream is tee'd, only write to file, but don't tail
+                    if local_rank in tee_stdouts:
+                        tee_stdouts.pop(local_rank, None)
+                    if local_rank in tee_stderrs:
+                        tee_stderrs.pop(local_rank, None)
+
+                    # If stream is not redirected, don't print
+                    if stdouts[local_rank] == SYS_STREAM:
+                        stdouts[local_rank] = os.devnull
+                    if stderrs[local_rank] == SYS_STREAM:
+                        stderrs[local_rank] = os.devnull
+
+                error_file = os.path.join(clogdir, "error.json")
+                error_files[local_rank] = error_file
+                log.info("Setting worker%s reply file to: %s", local_rank, error_file)
+                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file
+
+        return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
+
+    def __repr__(self) -> str:
+        return (
+            f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
+            f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
+        )
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, DefaultLogsSpecs):
+            return False
+
+        return (
+            self._root_log_dir == other._root_log_dir
+            and self._redirects == other._redirects
+            and self._tee == other._tee
+            and self._local_ranks_filter == other._local_ranks_filter
+        )
+
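+# A minimal usage sketch of DefaultLogsSpecs (illustrative only; the log_dir value and
+# the envs mapping below are placeholders for what the elastic agent normally provides):
+#
+#   specs = DefaultLogsSpecs(log_dir="/tmp/torchelastic_logs", redirects=Std.ALL, tee=Std.OUT)
+#   dest = specs.reify(
+#       {
+#           0: {"TORCHELASTIC_RUN_ID": "run0", "TORCHELASTIC_RESTART_COUNT": "0"},
+#           1: {"TORCHELASTIC_RUN_ID": "run0", "TORCHELASTIC_RESTART_COUNT": "0"},
+#       }
+#   )
+#   # dest.stdouts[0]     -> <log_dir>/run0_<suffix>/attempt_0/0/stdout.log
+#   # dest.error_files[1] -> <log_dir>/run0_<suffix>/attempt_0/1/error.json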
+
+@dataclass
+class RunProcsResult:
+    """
+    Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.
+
+    Note the following:
+
+    1. All fields are mapped by local rank
+    2. ``return_values`` - only populated for functions (not the binaries).
+    3. ``stdouts`` - path to stdout.log (empty string if no redirect)
+    4. ``stderrs`` - path to stderr.log (empty string if no redirect)
+
+    """
+
+    return_values: Dict[int, Any] = field(default_factory=dict)
+    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
+    stdouts: Dict[int, str] = field(default_factory=dict)
+    stderrs: Dict[int, str] = field(default_factory=dict)
+
+    def is_failed(self) -> bool:
+        return len(self.failures) > 0
+
+
+class PContext(abc.ABC):
+    """
+    The base class that standardizes operations over a set of processes that are launched via different mechanisms.
+
+    The name ``PContext`` is intentionally chosen to disambiguate it from ``torch.multiprocessing.ProcessContext``.
+
+    .. warning:: stdouts and stderrs should ALWAYS be a superset of
+                 tee_stdouts and tee_stderrs (respectively), because
+                 tee is implemented as a redirect followed by a tail -f
+    """
+
+    def __init__(
+        self,
+        name: str,
+        entrypoint: Union[Callable, str],
+        args: Dict[int, Tuple],
+        envs: Dict[int, Dict[str, str]],
+        logs_specs: LogsSpecs,
+        log_line_prefixes: Optional[Dict[int, str]] = None,
+
+    ):
+        self.name = name
+        # validate that all mappings have the same number of keys and
+        # all local ranks are accounted for
+        nprocs = len(args)
+
+        # TODO log_line_prefixes can be expanded too
+        logs_dest = logs_specs.reify(envs)
+
+        _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
+        _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")
+
+        self.entrypoint = entrypoint
+        self.args = args
+        self.envs = envs
+        self.stdouts = logs_dest.stdouts
+        self.stderrs = logs_dest.stderrs
+        self.error_files = logs_dest.error_files
+        self.nprocs = nprocs
+
+        self._stdout_tail = TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes)
+        self._stderr_tail = TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes)
+
+    def start(self) -> None:
+        """Start processes using parameters defined in the constructor."""
+        signal.signal(signal.SIGTERM, _terminate_process_handler)
+        signal.signal(signal.SIGINT, _terminate_process_handler)
+        if not IS_WINDOWS:
+            signal.signal(signal.SIGHUP, _terminate_process_handler)
+            signal.signal(signal.SIGQUIT, _terminate_process_handler)
+        self._start()
+        self._stdout_tail.start()
+        self._stderr_tail.start()
+
+    @abc.abstractmethod
+    def _start(self) -> None:
+        """Start processes using strategy defined in a particular context."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _poll(self) -> Optional[RunProcsResult]:
+        """
+        Poll the run status of the processes running under this context.
+        This method follows an "all-or-nothing" policy and returns
+        a ``RunProcsResult`` object if either all processes complete
+        successfully or any process fails. Returns ``None`` if
+        all processes are still running.
+        """
+        raise NotImplementedError()
+
+    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
+        """
+        Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
+        for the processes to be done. Returns ``None`` if the processes are still running
+        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
+        A timeout value of zero simply queries the status of the processes (e.g. equivalent
+        to a poll).
+
+        .. note:: ``PContext.start()`` registers SIGTERM and SIGINT signal handlers that raise
+                ``SignalException`` when the signals are received. It is up to the consumer of the code
+                to properly handle the exception. It is important not to swallow the exception, otherwise
+                the process would not terminate. A typical workflow looks like:
+
+        .. code-block:: python
+
+            pc = start_processes(...)
+            try:
+                pc.wait(1)
+                # ... do some other work
+            except SignalException as e:
+                pc.shutdown(e.sigval, timeout=30)
+
+        If SIGTERM or SIGINT occurs, the code above will try to shutdown the child processes by propagating
+        the received signal. If the child processes do not terminate within the timeout, they will be sent SIGKILL.
+        """
+        if timeout == 0:
+            return self._poll()
+
+        if timeout < 0:
+            timeout = sys.maxsize
+
+        expiry = time.time() + timeout
+        while time.time() < expiry:
+            pr = self._poll()
+            if pr:
+                return pr
+            time.sleep(period)
+
+        return None
+
+    @abc.abstractmethod
+    def pids(self) -> Dict[int, int]:
+        """Return pids of processes mapped by their respective local_ranks."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
+        r"""
+        Terminates all processes managed by this context and cleans up any
+        meta resources (e.g. redirect, error_file files).
+        """
+        raise NotImplementedError()
+
+    def close(
+        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
+    ) -> None:
+        r"""
+        Terminates all processes managed by this context and cleans up any
+        meta resources (e.g. redirect, error_file files).
+
+        Args:
+            death_sig: Death signal to terminate processes.
+            timeout: Time to wait for processes to finish, if process is
+                still alive after this time, it will be terminated via SIGKILL.
+        """
+        if not death_sig:
+            death_sig = _get_default_signal()
+        self._close(death_sig=death_sig, timeout=timeout)
+        if self._stdout_tail:
+            self._stdout_tail.stop()
+        if self._stderr_tail:
+            self._stderr_tail.stop()
+
+
+def get_std_cm(std_rd: str, redirect_fn):
+    if IS_WINDOWS or IS_MACOS or not std_rd:
+        return nullcontext()
+    else:
+        return redirect_fn(std_rd)
+
+
+def _wrap(
+    local_rank: int,
+    fn: Callable,
+    args: Dict[int, Tuple],
+    envs: Dict[int, Dict[str, str]],
+    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
+    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
+    ret_vals: Dict[int, mp.SimpleQueue],
+    queue_finished_reading_event: synchronize.Event,
+) -> None:
+    # get the per-rank params up front so we fail fast if no mapping is found
+    args_ = args[local_rank]
+    env_ = envs[local_rank]
+    ret_val_ = ret_vals[local_rank]
+
+    stdout_rd = stdout_redirects[local_rank]
+    stderr_rd = stderr_redirects[local_rank]
+
+    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
+    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)
+
+    for k, v in env_.items():
+        os.environ[k] = v
+
+    with stdout_cm, stderr_cm:
+        ret = record(fn)(*args_)
+    ret_val_.put(ret)
+    queue_finished_reading_event.wait()
+
+
+class MultiprocessContext(PContext):
+    """``PContext`` holding worker processes invoked as a function."""
+
+    def __init__(
+        self,
+        name: str,
+        entrypoint: Callable,
+        args: Dict[int, Tuple],
+        envs: Dict[int, Dict[str, str]],
+        start_method: str,
+        logs_specs: LogsSpecs,
+        log_line_prefixes: Optional[Dict[int, str]] = None,
+    ):
+        super().__init__(
+            name,
+            entrypoint,
+            args,
+            envs,
+            logs_specs,
+            log_line_prefixes,
+        )
+
+        self.start_method = start_method
+        # each ret_val queue will always contain a single element.
+        self._ret_vals = {
+            local_rank: mp.get_context(self.start_method).SimpleQueue()
+            for local_rank in range(self.nprocs)
+        }
+
+        # see comments in ``join()`` for what this is
+        self._return_values: Dict[int, Any] = {}
+        self._pc: Optional[mp.ProcessContext] = None
+        # Note: the set() method should ONLY be invoked for the use case when all processes finished
+        # successfully. If any process died on event.wait(), calling the set() method will deadlock.
+        self._worker_finished_event = mp.get_context(self.start_method).Event()
+
+    def _start(self):
+        if self._pc:
+            raise ValueError(
+                "The process context already initialized."
+                " Most likely the start method got called twice."
+            )
+        self._pc = mp.start_processes(
+            fn=_wrap,
+            args=(
+                self.entrypoint,
+                self.args,
+                self.envs,
+                self.stdouts,
+                self.stderrs,
+                self._ret_vals,
+                self._worker_finished_event,
+            ),
+            nprocs=self.nprocs,
+            join=False,
+            daemon=False,
+            start_method=self.start_method,
+        )
+
+    def _is_done(self) -> bool:
+        return len(self._return_values) == self.nprocs
+
+    def _poll(self) -> Optional[RunProcsResult]:
+        assert self._pc is not None  # assertion for mypy type checker
+
+        try:
+            # torch.mp.ProcessContext throws an Exception if some/all of the
+            # worker processes failed
+            # timeout < 0 checks worker status and returns immediately
+            # Join will never return success since we use a synchronize.Event to wait
+            # for all processes to finish.
+            self._pc.join(-1)
+
+            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
+            # back to the parent, the worker process will wait before terminating
+            # until all the buffered items are fed by the feeder thread to the underlying
+            # pipe. Hence to prevent deadlocks on large return values,
+            # we opportunistically try queue.get on each join call
+            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
+            for local_rank in range(0, self.nprocs):
+                return_queue = self._ret_vals[local_rank]
+                if not return_queue.empty():
+                    # save the return values temporarily into a member var
+                    self._return_values[local_rank] = return_queue.get()
+
+            if self._is_done():
+                # we should ALWAYS have ALL the return values when all the processes are done
+                self._worker_finished_event.set()
+                # Wait until all processes are finished. At this point the workers have
+                # finished executing the user function.
+                self._pc.join()
+                _validate_full_rank(
+                    self._return_values, self.nprocs, "return_value queue"
+                )
+                self.close()
+                return RunProcsResult(
+                    return_values=self._return_values,
+                    stdouts=self.stdouts,
+                    stderrs=self.stderrs,
+                )
+            else:
+                return None
+        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
+            failed_local_rank = e.error_index
+
+            # entrypoint for MultiprocessContext will always be a Callable
+            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
+            failed_proc = self._pc.processes[failed_local_rank]
+            error_filepath = self.error_files[failed_local_rank]
+
+            log.exception(
+                "failed (exitcode: %s)"
+                " local_rank: %s (pid: %s)"
+                " of fn: %s (start_method: %s)",
+                failed_proc.exitcode,
+                failed_local_rank, e.pid,
+                fn_name, self.start_method,
+            )
+
+            self.close()
+            return RunProcsResult(
+                failures={
+                    failed_local_rank: ProcessFailure(
+                        local_rank=failed_local_rank,
+                        pid=e.pid,
+                        exitcode=failed_proc.exitcode,
+                        error_file=error_filepath,
+                    )
+                },
+                stdouts=self.stdouts,
+                stderrs=self.stderrs,
+            )
+
+    def pids(self) -> Dict[int, int]:
+        assert self._pc is not None  # assertion for mypy type checking
+        return dict(enumerate(self._pc.pids()))
+
+    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
+        if not self._pc:
+            return
+        for proc in self._pc.processes:
+            if proc.is_alive():
+                log.warning("Closing process %s via signal %s", proc.pid, death_sig.name)
+                try:
+                    os.kill(proc.pid, death_sig)
+                except ProcessLookupError:
+                    # If the process exited because of some reason,
+                    # `ProcessLookupError` will be raised, it is safe to ignore it.
+                    pass
+        end = time.monotonic() + timeout
+        for proc in self._pc.processes:
+            time_to_wait = end - time.monotonic()
+            if time_to_wait <= 0:
+                break
+            proc.join(time_to_wait)
+        for proc in self._pc.processes:
+            if proc.is_alive():
+                log.warning(
+                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
+                    proc.pid, death_sig, _get_kill_signal()
+                )
+                try:
+                    os.kill(proc.pid, _get_kill_signal())
+                except ProcessLookupError:
+                    # If the process exited because of some reason,
+                    # `ProcessLookupError` will be raised, it is safe to ignore it.
+                    pass
+            proc.join()
+
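+# A minimal usage sketch of MultiprocessContext (illustrative only; ``train`` stands in
+# for a user-provided picklable function):
+#
+#   ctx = MultiprocessContext(
+#       name="trainer",
+#       entrypoint=train,
+#       args={0: ("hello",), 1: ("world",)},
+#       envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
+#       start_method="spawn",
+#       logs_specs=DefaultLogsSpecs(),
+#   )
+#   ctx.start()
+#   result = ctx.wait()  # blocks until done; returns a RunProcsResult
+#   if result and not result.is_failed():
+#       print(result.return_values)  # {0: ..., 1: ...}
+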
+class SubprocessContext(PContext):
+    """``PContext`` holding worker processes invoked as a binary."""
+
+    def __init__(
+        self,
+        name: str,
+        entrypoint: str,
+        args: Dict[int, Tuple],
+        envs: Dict[int, Dict[str, str]],
+        logs_specs: LogsSpecs,
+        log_line_prefixes: Optional[Dict[int, str]] = None,
+
+    ):
+        super().__init__(
+            name,
+            entrypoint,
+            args,
+            envs,
+            logs_specs,
+            log_line_prefixes,
+        )
+
+        # state vector; a local_rank is removed from _running_local_ranks once it finishes
+        self._running_local_ranks: Set[int] = set(range(self.nprocs))
+        self._failures: Dict[int, ProcessFailure] = {}
+        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
+
+    def _start(self):
+        if self.subprocess_handlers:
+            raise ValueError(
+                "The subprocess handlers already initialized. Most likely the start method got called twice."
+            )
+        self.subprocess_handlers = {
+            local_rank: get_subprocess_handler(
+                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
+                args=self.args[local_rank],
+                env=self.envs[local_rank],
+                stdout=self.stdouts[local_rank],
+                stderr=self.stderrs[local_rank],
+                local_rank_id=local_rank,
+            )
+            for local_rank in range(self.nprocs)
+        }
+
+    def _poll(self) -> Optional[RunProcsResult]:
+        done_local_ranks = set()
+        for local_rank in self._running_local_ranks:
+            handler = self.subprocess_handlers[local_rank]
+            exitcode = handler.proc.poll()
+            if exitcode is not None:
+                done_local_ranks.add(local_rank)
+                if exitcode != 0:  # failed or signaled
+                    self._failures[local_rank] = ProcessFailure(
+                        local_rank=local_rank,
+                        pid=handler.proc.pid,
+                        exitcode=exitcode,
+                        error_file=self.error_files[local_rank],
+                    )
+                # else: --> succeeded; nothing to do
+
+        self._running_local_ranks.difference_update(done_local_ranks)
+
+        # if ALL procs are finished or ANY have failed
+        if not self._running_local_ranks or self._failures:
+            self.close()  # terminate all running procs
+            result = RunProcsResult(
+                failures=self._failures,
+                stdouts=self.stdouts,
+                stderrs=self.stderrs,
+            )
+            if result.is_failed():
+                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
+                log.error(
+                    "failed (exitcode: %s)"
+                    " local_rank: %s (pid: %s)"
+                    " of binary: %s",
+                    first_failure.exitcode, first_failure.local_rank, first_failure.pid, self.entrypoint
+                )
+            else:
+                # Populate return_values with dummy values. This provides consistency with MultiprocessContext
+                result.return_values = dict.fromkeys(range(self.nprocs))
+
+            return result
+        else:  # there are no failures and procs still running
+            return None
+
+    def pids(self) -> Dict[int, int]:
+        return {
+            local_rank: sh.proc.pid
+            for local_rank, sh in self.subprocess_handlers.items()
+        }
+
+    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
+        if not self.subprocess_handlers:
+            return
+        for handler in self.subprocess_handlers.values():
+            if handler.proc.poll() is None:
+                log.warning(
+                    "Sending process %s closing signal %s", handler.proc.pid, death_sig.name
+                )
+                handler.close(death_sig=death_sig)
+        end = time.monotonic() + timeout
+        for handler in self.subprocess_handlers.values():
+            time_to_wait = end - time.monotonic()
+            if time_to_wait <= 0:
+                break
+            try:
+                handler.proc.wait(time_to_wait)
+            except subprocess.TimeoutExpired:
+                # Ignore the timeout expired exception, since
+                # the child process will be forcefully terminated via SIGKILL
+                pass
+        for handler in self.subprocess_handlers.values():
+            if handler.proc.poll() is None:
+                log.warning(
+                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
+                    handler.proc.pid, death_sig, _get_kill_signal()
+                )
+                handler.close(death_sig=_get_kill_signal())
+                handler.proc.wait()
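+
+
+# A minimal usage sketch of SubprocessContext for binary entrypoints (illustrative only;
+# the entrypoint and argument values are placeholders):
+#
+#   ctx = SubprocessContext(
+#       name="echo",
+#       entrypoint="/bin/echo",
+#       args={0: ("hello",), 1: ("world",)},
+#       envs={0: {}, 1: {}},
+#       logs_specs=DefaultLogsSpecs(),
+#   )
+#   ctx.start()
+#   result = ctx.wait()  # return_values hold dummy None entries for binaries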
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c06f1dcb2b61233b0ccf426f72806a3c124db61
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py
@@ -0,0 +1,375 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Each host in a distributed PyTorch job runs with a single TorchElastic agent,
+and multiple workers (as children processes of the TorchElastic agent).
+Since the workers are user-provided (your PyTorch script/job), TorchElastic
+has a way to propagate errors on the trainers through the agent and up to the
+scheduler, which ultimately informs the end-user about the state of the job
+and applies any retry policies.
+
+TorchElastic categorizes errors into 3 categories:
+
++----------------+----------------+--------------------------------------------------------------+
+| Category       | Sub-Category   |  Description                                                 |
++================+================+==============================================================+
+| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
+|                +----------------+--------------------------------------------------------------+
+|                | Worker Failure | any failures on the worker child process                     |
++----------------+----------------+--------------------------------------------------------------+
+| Platform Error |      n/a       | failures caused by the agent                                 |
++----------------+----------------+--------------------------------------------------------------+
+| Infra Error    |      n/a       | failures outside the domain of the agent and workers         |
+|                |                | (e.g. host failures)                                         |
++----------------+----------------+--------------------------------------------------------------+
+
+All errors other than "Worker Failure" are either raised canonically from the
+agent process or implicitly or explicitly crash the agent process. So the
+standard language (python) provided exception handling strategies apply.
+
+Worker Failures are special because the exception/failure originates on a different
+process from the agent so the error needs to be propagated inter-process
+(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
+
+TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
+to launch the workers which has a simple file based inter-process error propagation
+built-in.
+
+Any function or binary entrypoint decorated with :func:`record`
+will write uncaught exceptions (with the trace information) to a file specified by the
+environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
+sets this env var on each child it launches, then aggregates the error files for all
+children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
+"""
+
+import json
+import os
+import signal
+import socket
+import time
+import warnings
+from dataclasses import dataclass, field
+from datetime import datetime
+from functools import wraps
+from string import Template
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+
+from torch.distributed.elastic.utils.logging import get_logger
+
+from .error_handler import ErrorHandler  # noqa: F401
+from .handlers import get_error_handler  # noqa: F401
+
+__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"]
+
+log = get_logger(__name__)
+
+
+JSON = Dict
+
+_EMPTY_ERROR_DATA = {"message": ""}
+_NOT_AVAILABLE = "<N/A>"
+
+T = TypeVar("T")
+
+
+@dataclass
+class ProcessFailure:
+    """
+    Represents a failed process result. When a worker process fails, it may record the failure root cause into a file.
+
+    Tries to read the failure timestamp from the provided ``error_file``,
+    if the ``error_file`` does not exist, the timestamp is the current
+    timestamp (seconds since epoch).
+
+    The ``message`` field is a concise explanation of the failure. If
+    the error file exists then the message is obtained from the error file.
+    Otherwise one is generated based on the failure signature.
+
+    .. note:: It is assumed that the ``error_file`` is written by
+              ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
+              Otherwise the behavior is undefined.
+
+    """
+
+    local_rank: int
+    pid: int
+    exitcode: int
+    error_file: str
+    error_file_data: JSON = field(init=False)
+    message: str = field(init=False)
+    timestamp: int = field(init=False)
+
+    def __post_init__(self):
+        self.error_file_data = _EMPTY_ERROR_DATA
+        if os.path.isfile(self.error_file):
+            try:
+                with open(self.error_file) as fp:
+                    self.error_file_data = json.load(fp)
+                    log.debug(
+                        "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2)
+                    )
+                    self.message, self.timestamp = self._get_error_data(
+                        self.error_file_data
+                    )
+            except Exception:
+                log.exception("Failed to parse reply file: %s", self.error_file)
+                raise
+        else:
+            self._set_no_reply_file()
+
+        # make up an informative message if not already present
+        if not self.message:
+            # signals typically do not generate an error file message
+            if self.exitcode < 0:
+                self.message = (
+                    f"Signal {-self.exitcode} ({self.signal_name()})"
+                    f" received by PID {self.pid}"
+                )
+            else:
+                self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
+
+    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
+        message = error_file_data["message"]
+        if isinstance(message, str):
+            timestamp = int(error_file_data.get("timestamp", 0))
+        else:
+            timestamp = int(message["extraInfo"]["timestamp"])
+        return (message, timestamp)
+
+    def _set_no_reply_file(self):
+        self.error_file = _NOT_AVAILABLE
+        self.error_file_data = _EMPTY_ERROR_DATA
+        self.message = ""
+        self.timestamp = int(time.time())
+
+    def signal_name(self) -> str:
+        if self.exitcode < 0:
+            # We don't want to kill the parent process trying to find the signal name.
+            # if the signal doesn't map to a known name, use not available.
+            try:
+                return signal.Signals(-self.exitcode).name
+            except Exception:
+                return _NOT_AVAILABLE
+        else:
+            return _NOT_AVAILABLE
+
+    def timestamp_isoformat(self):
+        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
+        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
+
+
+GlobalRank = int
+
+_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
+  time      : ${time}
+  host      : ${hostname}
+  rank      : ${rank} (local_rank: ${local_rank})
+  exitcode  : ${exitcode} (pid: ${pid})
+  error_file: ${error_file}
+  traceback : ${message}"""
+
+# extra new lines before and after are intentional
+_MSG_FORMAT_TEMPLATE = """
+${boarder}
+${title}
+${section}
+Failures:
+${other_failures}
+${section}
+Root Cause (first observed failure):
+${root_failure}
+${boarder}"""
+
+
+class ChildFailedError(Exception):
+    """
+    Special exception type that can be raised from a function annotated with the
+    ``@record`` decorator to have the child process' (root exception) propagate
+    up the stack as-is (e.g. without being wrapped in the parent's traceback).
+
+    Useful in cases where the parent is a simple nanny process
+    and the child (worker) processes are actually doing meaningful compute.
+    In this case, errors typically occur on the child process as the parent
+    is not doing anything non-trivial, and child errors should be propagated
+    to the scheduler for accurate root cause diagnostics.
+
+    .. note:: The propagation relies on error files rather than exception handling to
+              support both function and binary launches.
+
+    Example:
+    ::
+
+     # process tree on a host (container)
+     0: scheduler-init-process:
+                |- 1: torchelastic_agent:
+                         |- 2: trainer_0 (ok)
+                         |- 3: trainer_1 (fail) -> error.json
+                         |- ...
+                         |- n+2: trainer_n (ok)
+                |- n+3: other processes
+                |- ...
+
+    In the example above, trainer 1's failure (written into error.json) is
+    the root cause and should be reported to the scheduler's init process.
+    The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
+    upon detecting trainer 1's failure which would propagate the contents
+    of trainer 1's error file to the scheduler's init process.
+    """
+
+    def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
+        self.name = name
+        self.failures = failures
+        assert (
+            self.failures
+        )  # it does not make sense to create a ChildFailedError with no failures
+        super().__init__(self.format_msg())
+
+    def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
+        rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
+        return rank, self.failures[rank]
+
+    def format_msg(self, boarder_delim="=", section_delim="-"):
+        title = f"{self.name} FAILED"
+        root_rank, root_failure = self.get_first_failure()
+
+        root_failure_fmt: str = ""
+        other_failures_fmt: List[str] = []
+        width = len(title)
+        for idx, (rank, failure) in enumerate(self.failures.items()):
+            fmt, w = self._format_failure(idx, rank, failure)
+            width = max(width, w)
+            if rank == root_rank:
+                root_failure_fmt = fmt
+            else:
+                other_failures_fmt.append(fmt)
+
+        # upper boundary on width
+        width = min(width, 60)
+
+        return Template(_MSG_FORMAT_TEMPLATE).substitute(
+            boarder=boarder_delim * width,
+            title=title,
+            section=section_delim * width,
+            root_failure=root_failure_fmt,
+            other_failures="\n".join(other_failures_fmt or ["  <NO_OTHER_FAILURES>"]),
+        )
+
+    def _format_failure(
+        self, idx: int, rank: int, failure: ProcessFailure
+    ) -> Tuple[str, int]:
+
+        # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
+        # or a dict (json) of the form
+        # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, "timestamp": $TS}}
+        # so the display logic is:
+        # 1. if failure.message is not a dict (it is a str) just show it as is
+        # 2. else try to get the traceback (py_callstack)
+        # 3.      if the traceback is not there, use the message
+        # 4.      if the message is not there, fall back to an empty string
+        msg = failure.message
+        if isinstance(failure.message, dict):
+            msg = (
+                failure.message.get("extraInfo", {})
+                .get("py_callstack", failure.message.get("message", ""))
+                .replace("\n", "\n  ")  # to properly indent the traceback
+            )
+
+        fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
+            idx=idx,
+            time=failure.timestamp_isoformat(),
+            hostname=socket.getfqdn(),
+            rank=rank,
+            local_rank=failure.local_rank,
+            exitcode=failure.exitcode,
+            pid=failure.pid,
+            error_file=failure.error_file,
+            message=msg,
+        )
+        width = 0
+        for line in fmt.split("\n"):
+            width = max(width, len(line))
+        return fmt, width
+
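+# A minimal usage sketch of ProcessFailure/ChildFailedError (illustrative only; the pid,
+# exitcode and error_file values are made up):
+#
+#   failure = ProcessFailure(local_rank=1, pid=997, exitcode=-9, error_file="/does/not/exist.json")
+#   raise ChildFailedError("trainer", {1: failure})
+#
+# Since the exit code is negative and no error file exists, ``failure.message`` becomes
+# "Signal 9 (SIGKILL) received by PID 997" and ``format_msg()`` renders it as the root
+# cause in the banner defined by ``_MSG_FORMAT_TEMPLATE``.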
+
+def record(
+    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
+) -> Callable[..., T]:
+    """
+    Syntactic sugar to record errors/exceptions that happened in the decorated
+    function using the provided ``error_handler``.
+
+    Using this decorator is equivalent to:
+
+    ::
+
+     error_handler = get_error_handler()
+     error_handler.initialize()
+     try:
+        foobar()
+     except ChildFailedError as e:
+        _, failure = e.get_first_failure()
+        error_handler.dump_error_file(failure.error_file, failure.exitcode)
+        raise
+     except Exception as e:
+        error_handler.record(e)
+        raise
+
+    .. important:: use this decorator once per process at the top level method,
+                   typically this is the main method.
+
+    Example
+
+    ::
+
+     @record
+     def main():
+         pass
+
+     if __name__=="__main__":
+        main()
+
+    """
+    if not error_handler:
+        error_handler = get_error_handler()
+
+    def wrap(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            assert error_handler is not None  # assertion for mypy type checker
+            error_handler.initialize()
+            try:
+                return f(*args, **kwargs)
+            except SystemExit as se:
+                # For run_path based entrypoints, SystemExit with code = 0 will never exit.
+                # Handling it here by returning a value:
+                if se.code == 0:
+                    return None
+                else:
+                    raise
+            except ChildFailedError as e:
+                rank, failure = e.get_first_failure()
+                if failure.error_file != _NOT_AVAILABLE:
+                    error_handler.dump_error_file(failure.error_file, failure.exitcode)
+                else:
+                    log.info(
+                        "local_rank %s FAILED with no error file."
+                        " Decorate your entrypoint fn with @record for traceback info."
+                        " See: https://pytorch.org/docs/stable/elastic/errors.html",
+                        rank,
+                    )
+                raise
+            except Exception as e:
+                error_handler.record_exception(e)
+                raise
+
+        return wrapper
+
+    return wrap(fn)
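+
+
+# A minimal usage sketch of error propagation (illustrative only; ``main`` and the env var
+# value are placeholders -- the elastic agent normally sets TORCHELASTIC_ERROR_FILE):
+#
+#   os.environ["TORCHELASTIC_ERROR_FILE"] = "/tmp/worker0_error.json"
+#
+#   @record
+#   def main():
+#       raise RuntimeError("boom")
+#
+#   main()  # the uncaught RuntimeError is written to /tmp/worker0_error.json and re-raised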
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b08155fe28c931f8a6bddf27644caf35646a9c4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f61d3650956ab271bc578bdc9daea4700063982
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b176a6050b4e27cbda8f56ad21a8f479c0bd40a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9986e7ac7822c579ae5d64d03a9754359303b8a7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/error_handler.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import faulthandler
+import json
+import logging
+import os
+import time
+import traceback
+import warnings
+from typing import Any, Dict, Optional
+
+__all__ = ['ErrorHandler']
+
+log = logging.getLogger(__name__)
+
+
+class ErrorHandler:
+    """
+    Write the provided exception object along with some other metadata about
+    the error in a structured way in JSON format to an error file specified by the
+    environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment
+    variable is not set, then simply logs the contents of what would have been
+    written to the error file.
+
+    This handler may be subclassed to customize the handling of the error.
+    Subclasses should override ``initialize()`` and ``record_exception()``.
+    """
+
+    def _get_error_file_path(self) -> Optional[str]:
+        """
+        Return the error file path.
+
+        May return ``None`` to have the structured error be logged only.
+        """
+        return os.environ.get("TORCHELASTIC_ERROR_FILE", None)
+
+    def initialize(self) -> None:
+        """
+        Call prior to running code that we wish to capture errors/exceptions.
+
+        Typically registers signal/fault handlers. Users can override this
+        function to add custom initialization/registrations that aid in
+        propagation/information of errors/signals/exceptions/faults.
+        """
+        try:
+            faulthandler.enable(all_threads=True)
+        except Exception as e:
+            warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}")
+
+    def _write_error_file(self, file_path: str, error_msg: str) -> None:
+        """Write error message to the file."""
+        try:
+            with open(file_path, "w") as fp:
+                fp.write(error_msg)
+        except Exception as e:
+            warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}")
+
+    def record_exception(self, e: BaseException) -> None:
+        """
+        Write a structured information about the exception into an error file in JSON format.
+
+        If the error file cannot be determined, then logs the content
+        that would have been written to the error file.
+        """
+        file = self._get_error_file_path()
+        if file:
+            data = {
+                "message": {
+                    "message": f"{type(e).__name__}: {e}",
+                    "extraInfo": {
+                        "py_callstack": traceback.format_exc(),
+                        "timestamp": str(int(time.time())),
+                    },
+                }
+            }
+            with open(file, "w") as fp:
+                json.dump(data, fp)
+
+    def override_error_code_in_rootcause_data(
+        self,
+        rootcause_error_file: str,
+        rootcause_error: Dict[str, Any],
+        error_code: int = 0,
+    ):
+        """Modify the rootcause_error read from the file, to correctly set the exit code."""
+        if "message" not in rootcause_error:
+            log.warning(
+                "child error file (%s) does not have field `message`. \n"
+                "cannot override error code: %s",
+                rootcause_error_file, error_code
+            )
+        elif isinstance(rootcause_error["message"], str):
+            log.warning(
+                "child error file (%s) has a new message format. \n"
+                "skipping error code override",
+                rootcause_error_file
+            )
+        else:
+            rootcause_error["message"]["errorCode"] = error_code
+
+    def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
+        """Dump parent error file from child process's root cause error and error code."""
+        with open(rootcause_error_file) as fp:
+            rootcause_error = json.load(fp)
+            # Override error code since the child process cannot capture the error code if it
+            # is terminated by signals like SIGSEGV.
+            if error_code:
+                self.override_error_code_in_rootcause_data(rootcause_error_file, rootcause_error, error_code)
+            log.debug(
+                "child error file (%s) contents:\n"
+                "%s",
+                rootcause_error_file, json.dumps(rootcause_error, indent=2)
+            )
+
+        my_error_file = self._get_error_file_path()
+        if my_error_file:
+            # Guard against existing error files.
+            # This can happen when the child is created using multiprocessing and the
+            # same env var (TORCHELASTIC_ERROR_FILE) is used on the parent and child to
+            # specify their respective error files. The env vars on the child are set in
+            # the wrapper function, but by default the child inherits the parent's env
+            # vars; so if the child process receives a signal before the wrapper function
+            # kicks in and the signal handler writes to the error file, then the child
+            # will write to the parent's error file. In this case just log the original
+            # error file contents and overwrite the error file.
+            self._rm(my_error_file)
+            self._write_error_file(my_error_file, json.dumps(rootcause_error))
+            log.info("dumped error file to parent's %s", my_error_file)
+        else:
+            log.error(
+                "no error file defined for parent, to copy child error file (%s)", rootcause_error_file
+            )
+
+    def _rm(self, my_error_file):
+        if os.path.isfile(my_error_file):
+            # Log the contents of the original file.
+            with open(my_error_file) as fp:
+                try:
+                    original = json.dumps(json.load(fp), indent=2)
+                    log.warning(
+                        "%s already exists"
+                        " and will be overwritten."
+                        " Original contents:\n%s",
+                        my_error_file, original
+                    )
+                except json.decoder.JSONDecodeError:
+                    log.warning(
+                        "%s already exists"
+                        " and will be overwritten."
+                        " Unable to load original contents:\n",
+                        my_error_file
+                    )
+            os.remove(my_error_file)
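+
+
+# A minimal usage sketch of ErrorHandler (illustrative only; the file path is a placeholder):
+#
+#   os.environ["TORCHELASTIC_ERROR_FILE"] = "/tmp/error.json"
+#   handler = ErrorHandler()
+#   handler.initialize()
+#   try:
+#       raise ValueError("boom")
+#   except ValueError as e:
+#       handler.record_exception(e)  # writes {"message": {...}} JSON to /tmp/error.json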
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d4eac0899df723686a433ef75ba0760623594af
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# Multiprocessing error-reporting module
+
+
+from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
+
+__all__ = ['get_error_handler']
+
+def get_error_handler():
+    return ErrorHandler()
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3bcb5cdf1f43e10b9b43cab2684298c8dc7a64
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/redirects.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Taken and modified from original source:
+# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
+import ctypes
+import logging
+import os
+import sys
+from contextlib import contextmanager
+from functools import partial
+
+IS_WINDOWS = sys.platform == "win32"
+IS_MACOS = sys.platform == "darwin"
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_libc():
+    if IS_WINDOWS or IS_MACOS:
+        logger.warning(
+            "NOTE: Redirects are currently not supported in Windows or MacOs."
+        )
+        return None
+    else:
+        return ctypes.CDLL("libc.so.6")
+
+
+libc = get_libc()
+
+
+def _c_std(stream: str):
+    return ctypes.c_void_p.in_dll(libc, stream)
+
+
+def _python_std(stream: str):
+    return {"stdout": sys.stdout, "stderr": sys.stderr}[stream]
+
+
+_VALID_STD = {"stdout", "stderr"}
+
+
+@contextmanager
+def redirect(std: str, to_file: str):
+    """
+    Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.
+
+    This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
+    See usage for details.
+
+    The directory of ``to_file`` is assumed to exist and the destination file
+    is overwritten if it already exists.
+
+    .. note:: Due to buffering cross source writes are not guaranteed to
+              appear in wall-clock order. For instance in the example below
+              it is possible for the C-outputs to appear before the python
+              outputs in the log file.
+
+    Usage:
+
+    ::
+
+     # syntactic-sugar for redirect("stdout", "/tmp/stdout.log")
+     with redirect_stdout("/tmp/stdout.log"):
+        print("python stdouts are redirected")
+        libc = ctypes.CDLL("libc.so.6")
+        libc.printf(b"c stdouts are also redirected")
+        os.system("echo system stdouts are also redirected")
+
+     print("stdout restored")
+
+    """
+    if std not in _VALID_STD:
+        raise ValueError(
+            f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
+        )
+
+    c_std = _c_std(std)
+    python_std = _python_std(std)
+    std_fd = python_std.fileno()
+
+    def _redirect(dst):
+        libc.fflush(c_std)
+        python_std.flush()
+        os.dup2(dst.fileno(), std_fd)
+
+    with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
+        _redirect(dst)
+        try:
+            yield
+        finally:
+            _redirect(orig_std)
+
+
+redirect_stdout = partial(redirect, "stdout")
+redirect_stderr = partial(redirect, "stderr")
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc1fe591bca2e753727c8365d2ea693d0f61d966
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from torch.distributed.elastic.multiprocessing.subprocess_handler.handlers import (
+    get_subprocess_handler,
+)
+from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
+    SubprocessHandler,
+)
+
+__all__ = ["SubprocessHandler", "get_subprocess_handler"]
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..adbd2dbfcb52f031e36b0522e797c3149b8740f7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..142591654ac2526c627c1e9af888334585ae2fca
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d13a9e15b1482945b3172934085effdaa991e13
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e372f45b88ebf680d9c9dc6b65739cf2a2ff88
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict, Tuple
+
+from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
+    SubprocessHandler,
+)
+
+__all__ = ["get_subprocess_handler"]
+
+
+def get_subprocess_handler(
+    entrypoint: str,
+    args: Tuple,
+    env: Dict[str, str],
+    stdout: str,
+    stderr: str,
+    local_rank_id: int,
+):
+    return SubprocessHandler(
+        entrypoint=entrypoint,
+        args=args,
+        env=env,
+        stdout=stdout,
+        stderr=stderr,
+        local_rank_id=local_rank_id,
+    )
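``get_subprocess_handler`` is a thin factory over ``SubprocessHandler``; a sketch of calling it directly (the entrypoint, args, and env values below are illustrative, and the empty ``stdout``/``stderr`` strings mean the child inherits the parent's streams)::

    import sys

    from torch.distributed.elastic.multiprocessing.subprocess_handler import (
        get_subprocess_handler,
    )

    handler = get_subprocess_handler(
        entrypoint=sys.executable,
        args=("-c", "print('hello from a worker')"),
        env={"LOCAL_RANK": "0"},
        stdout="",  # falsy: no redirection, inherit the parent's stdout
        stderr="",
        local_rank_id=0,
    )
    handler.proc.wait()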
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbed2fc5d8594d8f742a60338aeb5d9647ca952e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import signal
+import subprocess
+import sys
+
+from typing import Any, Dict, Optional, Tuple
+
+__all__ = ["SubprocessHandler"]
+
+IS_WINDOWS = sys.platform == "win32"
+
+
+def _get_default_signal() -> signal.Signals:
+    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
+    if IS_WINDOWS:
+        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
+    else:
+        return signal.SIGTERM
+
+
+class SubprocessHandler:
+    """
+    Convenience wrapper around Python's ``subprocess.Popen``. Keeps track of
+    meta-objects associated with the process (e.g. stdout and stderr redirect fds).
+    """
+
+    def __init__(
+        self,
+        entrypoint: str,
+        args: Tuple,
+        env: Dict[str, str],
+        stdout: str,
+        stderr: str,
+        local_rank_id: int,
+    ):
+        self._stdout = open(stdout, "w") if stdout else None
+        self._stderr = open(stderr, "w") if stderr else None
+        # inherit parent environment vars
+        env_vars = os.environ.copy()
+        env_vars.update(env)
+
+        args_str = (entrypoint, *[str(e) for e in args])
+        self.local_rank_id = local_rank_id
+        self.proc: subprocess.Popen = self._popen(args_str, env_vars)
+
+    def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
+        kwargs: Dict[str, Any] = {}
+        if not IS_WINDOWS:
+            kwargs["start_new_session"] = True
+        return subprocess.Popen(
+            # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
+            #  _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
+            #  `Tuple[str, *Tuple[Any, ...]]`.
+            args=args,
+            env=env,
+            stdout=self._stdout,
+            stderr=self._stderr,
+            **kwargs,
+        )
+
+    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
+        if not death_sig:
+            death_sig = _get_default_signal()
+        if IS_WINDOWS:
+            self.proc.send_signal(death_sig)
+        else:
+            os.killpg(self.proc.pid, death_sig)
+        if self._stdout:
+            self._stdout.close()
+        if self._stderr:
+            self._stderr.close()
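On POSIX the child is started in its own session (``start_new_session=True``), so ``close()`` can signal the whole process group via ``os.killpg``. A sketch of terminating a long-running worker with a custom signal (the ``sleep`` entrypoint is just a stand-in)::

    import signal

    from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler

    handler = SubprocessHandler(
        entrypoint="sleep",
        args=("600",),
        env={},
        stdout="",  # falsy: inherit the parent's streams
        stderr="",
        local_rank_id=0,
    )
    handler.close(death_sig=signal.SIGINT)  # defaults to SIGTERM when omitted
    handler.proc.wait()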
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63ba1feceff2d198ced71f2e5ec568455f1184d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/multiprocessing/tail_log.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import time
+from concurrent.futures._base import Future
+from concurrent.futures.thread import ThreadPoolExecutor
+from threading import Event
+from typing import Dict, List, Optional, TextIO
+
+__all__ = ["tail_logfile", "TailLog"]
+
+log = logging.getLogger(__name__)
+
+
+def tail_logfile(
+    header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
+):
+
+    while not os.path.exists(file):
+        if finished.is_set():
+            return
+        time.sleep(interval_sec)
+
+    with open(file, errors="replace") as fp:
+        while True:
+            line = fp.readline()
+
+            if line:
+                dst.write(f"{header}{line}")
+            else:  # reached EOF
+                if finished.is_set():
+                    # log line producer is finished
+                    break
+                else:
+                    # log line producer is still going
+                    # wait for a bit before looping again
+                    time.sleep(interval_sec)
+
+
+class TailLog:
+    """
+    Tail the given log files.
+
+    The log files do not have to exist when the ``start()`` method is called. The tailer will gracefully wait until
+    the log files are created by the producer and will tail the contents of the
+    log files until the ``stop()`` method is called.
+
+    .. warning:: ``TailLog`` will wait indefinitely for the log file to be created!
+
+    Each log file's line will be prefixed with a header of the form: ``[{name}{idx}]:``,
+    where the ``name`` is user-provided and ``idx`` is the index of the log file
+    in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the
+    header for each log file.
+
+    Usage:
+
+    ::
+
+     log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
+     tailer = TailLog("trainer", log_files, sys.stdout).start()
+     # actually run the trainers to produce 0_stdout.log and 1_stdout.log
+     run_trainers()
+     tailer.stop()
+
+     # once run_trainers() starts writing the ##_stdout.log files
+     # the tailer will print to sys.stdout:
+     # >>> [trainer0]:log_line1
+     # >>> [trainer1]:log_line1
+     # >>> [trainer0]:log_line2
+     # >>> [trainer0]:log_line3
+     # >>> [trainer1]:log_line2
+
+    .. note:: Due to buffering, log lines between files may not necessarily
+              be printed out in order. You should configure your application's
+              logger to include a proper timestamp in each log line.
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        log_files: Dict[int, str],
+        dst: TextIO,
+        log_line_prefixes: Optional[Dict[int, str]] = None,
+        interval_sec: float = 0.1,
+    ):
+        n = len(log_files)
+        self._threadpool = None
+        if n > 0:
+            self._threadpool = ThreadPoolExecutor(
+                max_workers=n,
+                thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
+            )
+
+        self._name = name
+        self._dst = dst
+        self._log_files = log_files
+        self._log_line_prefixes = log_line_prefixes
+        self._finished_events: Dict[int, Event] = {
+            local_rank: Event() for local_rank in log_files.keys()
+        }
+        self._futs: List[Future] = []
+        self._interval_sec = interval_sec
+        self._stopped = False
+
+    def start(self) -> "TailLog":
+        if not self._threadpool:
+            return self
+
+        for local_rank, file in self._log_files.items():
+            header = f"[{self._name}{local_rank}]:"
+            if self._log_line_prefixes and local_rank in self._log_line_prefixes:
+                header = self._log_line_prefixes[local_rank]
+            self._futs.append(
+                self._threadpool.submit(
+                    tail_logfile,
+                    header=header,
+                    file=file,
+                    dst=self._dst,
+                    finished=self._finished_events[local_rank],
+                    interval_sec=self._interval_sec,
+                )
+            )
+        return self
+
+    def stop(self) -> None:
+        for finished in self._finished_events.values():
+            finished.set()
+
+        for local_rank, f in enumerate(self._futs):
+            try:
+                f.result()
+            except Exception as e:
+                log.error(
+                    "error in log tailor for %s%s. %s: %s",
+                    self._name, local_rank,
+                    e.__class__.__qualname__, e,
+                )
+
+        if self._threadpool:
+            self._threadpool.shutdown(wait=True)
+
+        self._stopped = True
+
+    def stopped(self) -> bool:
+        return self._stopped
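A sketch tying ``TailLog`` to log files produced by another process (the paths and the ``run_workers`` helper are hypothetical); each forwarded line is prefixed with ``[worker<idx>]:`` unless ``log_line_prefixes`` overrides it::

    import sys

    from torch.distributed.elastic.multiprocessing.tail_log import TailLog

    log_files = {0: "/tmp/worker_0.log", 1: "/tmp/worker_1.log"}
    tailer = TailLog(
        "worker",
        log_files,
        dst=sys.stdout,
        log_line_prefixes={1: "[rank1|host2]:"},  # optional per-rank header override
    ).start()
    try:
        run_workers()  # hypothetical: whatever writes the two log files
    finally:
        tailer.stop()  # signals the tailing threads and waits for them to drain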
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcff834a35c0c47a8e0eddc838aa92d93fa6d7cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py
@@ -0,0 +1,150 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+In the context of Torch Distributed Elastic we use the term *rendezvous* to
+refer to a particular functionality that combines a **distributed
+synchronization** primitive with **peer discovery**.
+
+It is used by Torch Distributed Elastic to gather participants of a training
+job (i.e. nodes) such that they all agree on the same list of participants and
+everyone's roles, as well as make a consistent collective decision on when
+training can begin/resume.
+
+Torch Distributed Elastic rendezvous provides the following critical
+functionalities:
+
+**Barrier**:
+
+Nodes performing rendezvous will all block until the rendezvous is considered
+complete - this happens when at least ``min`` total number of nodes have joined
+the rendezvous barrier (for the same job). This also implies the barrier is not
+necessarily of fixed size.
+
+There's an additional small waiting time after reaching ``min`` number of
+nodes - this is used to ensure the rendezvous is not completed "too quickly"
+(which could potentially exclude additional nodes attempting to join at
+approximately the same time).
+
+If ``max`` number of nodes is gathered at the barrier, the rendezvous is
+completed immediately.
+
+There's also an overall timeout which causes the rendezvous to fail if ``min``
+number of nodes is never reached - this is meant to be a simple fail-safe to
+help release partially allocated job resources, in case there's a problem with
+the resource manager, and is meant to be interpreted as non-retryable.
+
+**Exclusivity**:
+
+A simple distributed barrier would not be sufficient, as we also need to ensure
+that only one group of nodes exists at any given time (for a given job). In
+other words, new nodes (i.e. joining late) should not be able to form a parallel
+independent group of workers for the same job.
+
+Torch Distributed Elastic rendezvous ensures that if a group of nodes has
+already completed a rendezvous (and hence might already be training), then
+additional "late" nodes attempting to rendezvous will only announce themselves
+as waiting, and will have to wait until the (previously completed) existing
+rendezvous is destroyed first.
+
+**Consistency**:
+
+When a rendezvous is completed, all its members will agree on the job membership
+and everyone's role in it. This role is represented using an integer, called
+rank, that is between 0 and world size.
+
+Note that ranks are *not stable*, in the sense that the same node can be
+assigned a different rank in the next (re-)rendezvous.
+
+**Fault-tolerance**:
+
+Torch Distributed Elastic rendezvous is designed to tolerate node failures
+during the rendezvous process. Should a process crash (or lose network
+connectivity, etc), between joining the rendezvous and it being completed, then
+a re-rendezvous with remaining healthy nodes will happen automatically.
+
+A node can also fail *after* it has completed (or *has been observed* by other
+nodes to have completed) the rendezvous - this scenario will be handled by the
+Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a
+re-rendezvous).
+
+**Shared key-value store**:
+
+When the rendezvous is completed, a shared key-value store is created and
+returned. This store implements a ``torch.distributed.Store`` API (see
+`distributed communication docs
+<https://pytorch.org/docs/stable/distributed.html>`__).
+
+This store is only shared by the members of the completed rendezvous. It
+is intended to be used by Torch Distributed Elastic to exchange information
+necessary to initialize job control and data-planes.
+
+**Waiting workers and rendezvous closing**:
+
+Torch Distributed Elastic rendezvous handler object provides additional
+functionalities, which are technically not part of the rendezvous process:
+
+1. Querying how many workers arrived late at the barrier, who can participate in
+   the *next* rendezvous.
+
+2. Setting the rendezvous *closed* to signal all nodes not to participate in the
+   next rendezvous.
+
+**DynamicRendezvousHandler**:
+
+Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
+class that implements the rendezvous mechanism described above. It is a backend-
+agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
+to be specified during construction.
+
+Torch distributed users can either implement their own backend type or use one
+of the following implementations that come with PyTorch:
+
+- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default
+  ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d
+  store is that it requires no 3rd-party dependency (such as etcd) to establish
+  a rendezvous.
+- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy
+  :py:class:`.EtcdRendezvousHandler` class. Passing an
+  :py:class:`.EtcdRendezvousBackend` instance to
+  :py:class:`.DynamicRendezvousHandler` is functionally equivalent to
+  instantiating an :py:class:`.EtcdRendezvousHandler`.
+
+  ::
+
+     store = TCPStore("localhost")
+
+     backend = C10dRendezvousBackend(store, "my_run_id")
+
+     rdzv_handler = DynamicRendezvousHandler.from_backend(
+         run_id="my_run_id",
+         store=store,
+         backend=backend,
+         min_nodes=2,
+         max_nodes=4
+     )
+"""
+
+from .api import *  # noqa: F403
+from .registry import _register_default_handlers
+
+
+_register_default_handlers()
+
+
+__all__ = [
+    "RendezvousClosedError",
+    "RendezvousConnectionError",
+    "RendezvousError",
+    "RendezvousGracefulExitError",
+    "RendezvousHandler",
+    "RendezvousHandlerCreator",
+    "RendezvousHandlerRegistry",
+    "RendezvousParameters",
+    "RendezvousStateError",
+    "RendezvousTimeoutError",
+    "rendezvous_handler_registry",
+]
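A sketch of the barrier and waiting-workers flow described above, from the point of view of a single node (``rdzv_handler`` stands for any ``RendezvousHandler``, e.g. one built as in the ``DynamicRendezvousHandler`` example, and ``train`` is a hypothetical training loop)::

    store, rank, world_size = rdzv_handler.next_rendezvous()  # blocks at the barrier
    try:
        train(store, rank, world_size)
        if rdzv_handler.num_nodes_waiting() > 0:
            # late arrivals are waiting; an elastic agent would re-rendezvous here
            ...
    finally:
        rdzv_handler.shutdown()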
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c38ff9c80f6815749a82f757806c97022148796
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bbbee62d492c49e3d0db4282cc9f1fa6d844912
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df2f1e8bcd5dac29b9c0b83ccf6770f7f943a1fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe2a5590c7826b730746982691ffde2033874584
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..797c21813f37b94d6201953f909ab97b56427cb1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..379e4426ab1842e4ea0e7dcb84982e24b3a12edd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d669a13ac5d7e515b3225dd7b0abc0fb126700c9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b098351761f5558c848bb16e05bdd0dccf9b7f5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73abe27b28848b47847b9e8bc94682a75778b520
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9029d442d2e86d2526d7ce69ba1865c2bea17688
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a411d6423b7fb6c252226474312bde21560ebd4e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbfbae82a851491184916cb1a1fc8398bfd97f0b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/api.py
@@ -0,0 +1,277 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from torch.distributed import Store
+
+
+class RendezvousError(Exception):
+    """Represents the base type for rendezvous errors."""
+
+
+class RendezvousClosedError(RendezvousError):
+    """Raised when a rendezvous is closed."""
+
+
+class RendezvousTimeoutError(RendezvousError):
+    """Raised when a rendezvous did not complete on time."""
+
+
+class RendezvousConnectionError(RendezvousError):
+    """Raised when the connection to a rendezvous backend has failed."""
+
+
+class RendezvousStateError(RendezvousError):
+    """Raised when the state of a rendezvous is corrupt."""
+
+class RendezvousGracefulExitError(RendezvousError):
+    """Raised when node wasn't not included in rendezvous and gracefully exits.
+
+    Exception is a mechanism to exit the stack, however does not mean a failure.
+    """
+
+class RendezvousHandler(ABC):
+    """Main rendezvous interface.
+
+    Note:
+        Distributed Torch users normally **do not** need to implement their own
+        ``RendezvousHandler``. An implementation based on C10d Store is already
+        provided, and is recommended for most users.
+    """
+
+    @abstractmethod
+    def get_backend(self) -> str:
+        """Return the name of the rendezvous backend."""
+
+    @abstractmethod
+    def next_rendezvous(
+        self,
+    ) -> Tuple[Store, int, int]:
+        """Main entry-point into the rendezvous barrier.
+
+        Blocks until the rendezvous is complete and the current process is
+        included in the formed worker group, or a timeout occurs, or the
+        rendezvous was marked closed.
+
+        Returns:
+            A tuple of :py:class:`torch.distributed.Store`, ``rank``, and
+            ``world size``.
+
+        Raises:
+            RendezvousClosedError:
+                The rendezvous is closed.
+            RendezvousConnectionError:
+                The connection to the rendezvous backend has failed.
+            RendezvousStateError:
+                The rendezvous state is corrupt.
+            RendezvousTimeoutError:
+                The rendezvous did not complete on time.
+        """
+
+    @abstractmethod
+    def is_closed(self) -> bool:
+        """Check whether the rendezvous has been closed.
+
+        A closed rendezvous means all future attempts to re-rendezvous within
+        the same job will fail.
+
+        ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
+        propagation and should not be used for synchronization. The intention is
+        that if at least one node decides the job is finished, it will close the
+        rendezvous, and other nodes will soon observe this and stop running as
+        well.
+        """
+
+    @abstractmethod
+    def set_closed(self):
+        """Mark the rendezvous as closed."""
+
+    @abstractmethod
+    def num_nodes_waiting(self) -> int:
+        """Return the number of nodes who arrived late at the rendezvous
+        barrier, hence were not included in the current worker group.
+
+        Callers should periodically call this method to check whether new
+        nodes are waiting to join the job and if so admit them by calling
+        :py:meth:`next_rendezvous()` (re-rendezvous).
+        """
+
+    @abstractmethod
+    def get_run_id(self) -> str:
+        """Return the run id of the rendezvous.
+
+        The run id is a user-defined id that uniquely identifies an instance of
+        a distributed application. It typically maps to a job id and is used to
+        allow nodes to join the correct distributed application.
+        """
+
+    @abstractmethod
+    def shutdown(self) -> bool:
+        """Close all resources that were open for the rendezvous.
+
+        Example::
+
+            rdzv_handler = ...
+            try:
+                store, rank, world_size = rdzv_handler.next_rendezvous()
+            finally:
+                rdzv_handler.shutdown()
+        """
+
+
+class RendezvousParameters:
+    """Hold the parameters to construct a :py:class:`RendezvousHandler`.
+
+    Args:
+        backend:
+            The name of the backend to use to handle the rendezvous.
+        endpoint:
+            The endpoint of the rendezvous, usually in the form <hostname>[:<port>].
+        run_id:
+            The id of the rendezvous.
+        min_nodes:
+            The minimum number of nodes to admit to the rendezvous.
+        max_nodes:
+            The maximum number of nodes to admit to the rendezvous.
+        local_addr:
+            The address of the local node.
+        **kwargs:
+            Additional parameters for the specified backend.
+    """
+
+    def __init__(
+        self,
+        backend: str,
+        endpoint: str,
+        run_id: str,
+        min_nodes: int,
+        max_nodes: int,
+        local_addr: Optional[str] = None,
+        **kwargs,
+    ):
+        if not backend:
+            raise ValueError("The rendezvous backend name must be a non-empty string.")
+
+        if min_nodes < 1:
+            raise ValueError(
+                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
+            )
+        if max_nodes < min_nodes:
+            raise ValueError(
+                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
+                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
+            )
+
+        self.backend = backend
+        self.endpoint = endpoint
+        self.run_id = run_id
+        self.min_nodes = min_nodes
+        self.max_nodes = max_nodes
+        self.config = kwargs
+        self.local_addr = local_addr
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Return the value for ``key`` if ``key`` exists, else ``default``."""
+        return self.config.get(key, default)
+
+    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
+        """Return the value for ``key`` as a ``bool``."""
+        value = self.get(key, default)
+        if value is None or isinstance(value, bool):
+            return value
+        if isinstance(value, int):
+            if value == 1:
+                return True
+            if value == 0:
+                return False
+        elif isinstance(value, str):
+            if value.lower() in ["1", "true", "t", "yes", "y"]:
+                return True
+            if value.lower() in ["0", "false", "f", "no", "n"]:
+                return False
+        raise ValueError(
+            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
+        )
+
+    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
+        """Return the value for ``key`` as an ``int``."""
+        value = self.get(key, default)
+        if value is None:
+            return value
+        try:
+            return int(value)
+        except ValueError as e:
+            raise ValueError(
+                f"The rendezvous configuration option '{key}' does not represent a valid integer "
+                "value."
+            ) from e
+
+
+RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
+
+
+class RendezvousHandlerRegistry:
+    """Represent a registry of :py:class:`RendezvousHandler` backends."""
+
+    _registry: Dict[str, RendezvousHandlerCreator]
+
+    def __init__(self) -> None:
+        self._registry = {}
+
+    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
+        """Register a new rendezvous backend.
+
+        Args:
+            backend:
+                The name of the backend.
+            creator:
+                The callback to invoke to construct the
+                :py:class:`RendezvousHandler`.
+        """
+        if not backend:
+            raise ValueError("The rendezvous backend name must be a non-empty string.")
+
+        current_creator: Optional[RendezvousHandlerCreator]
+        try:
+            current_creator = self._registry[backend]
+        except KeyError:
+            current_creator = None
+
+        if current_creator is not None and current_creator != creator:
+            raise ValueError(
+                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
+                f"is already registered with '{current_creator}'."
+            )
+
+        self._registry[backend] = creator
+
+    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
+        """Create a new :py:class:`RendezvousHandler`."""
+        try:
+            creator = self._registry[params.backend]
+        except KeyError as e:
+            raise ValueError(
+                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
+                f"to call `{self.register.__name__}`?"
+            ) from e
+
+        handler = creator(params)
+
+        # Do a sanity check.
+        if handler.get_backend() != params.backend:
+            raise RuntimeError(
+                f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
+                f"backend '{params.backend}'."
+            )
+
+        return handler
+
+
+# The default global registry instance used by launcher scripts to instantiate
+# rendezvous handlers.
+rendezvous_handler_registry = RendezvousHandlerRegistry()
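A sketch of wiring a custom backend into the default registry (``MyRendezvousHandler`` is a placeholder; a real implementation must return ``"my-backend"`` from ``get_backend()`` to pass the sanity check in ``create_handler``)::

    from torch.distributed.elastic.rendezvous import (
        RendezvousParameters,
        rendezvous_handler_registry,
    )

    def _create_my_handler(params: RendezvousParameters) -> "MyRendezvousHandler":
        return MyRendezvousHandler(params)  # placeholder creator

    rendezvous_handler_registry.register("my-backend", _create_my_handler)

    params = RendezvousParameters(
        backend="my-backend",
        endpoint="localhost:29400",
        run_id="job-42",
        min_nodes=1,
        max_nodes=2,
    )
    handler = rendezvous_handler_registry.create_handler(params)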
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..b547a1ead07275f564c2db71c9b99bf13791f2eb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py
@@ -0,0 +1,269 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import binascii
+import logging
+import os
+import tempfile
+from base64 import b64decode, b64encode
+from datetime import timedelta
+from typing import Any, Optional, Tuple, cast
+
+from torch.distributed import FileStore, Store, TCPStore
+from torch.distributed.elastic.events import (
+    NodeState,
+    construct_and_record_rdzv_event,
+)
+
+from .api import (
+    RendezvousConnectionError,
+    RendezvousError,
+    RendezvousParameters,
+    RendezvousStateError,
+)
+from .dynamic_rendezvous import RendezvousBackend, Token
+from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
+
+log = logging.getLogger(__name__)
+
+
+class C10dRendezvousBackend(RendezvousBackend):
+    """Represents a C10d-backed rendezvous backend.
+
+    Args:
+        store:
+            The :py:class:`torch.distributed.Store` instance to use to
+            communicate with the C10d store.
+        run_id:
+            The run id of the rendezvous.
+    """
+
+    # See the explanation in the __init__ method.
+    _NULL_SENTINEL = "Y2FuaW1hZGFt"
+
+    _store: Store
+    _key: str
+
+    def __init__(self, store: Store, run_id: str) -> None:
+        if not run_id:
+            raise ValueError("The run id must be a non-empty string.")
+
+        self._store = store
+
+        self._key = "torch.rendezvous." + run_id
+
+        # The read operation of a store blocks the caller until the specified
+        # key becomes available. This behavior makes it tricky to use a store
+        # as a regular key-value dictionary.
+        #
+        # As a workaround we initially set a sentinel value as the rendezvous
+        # state. Whenever this value gets returned we treat it as a None.
+        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)
+
+    @property
+    def name(self) -> str:
+        """See base class."""
+        return "c10d"
+
+    def get_state(self) -> Optional[Tuple[bytes, Token]]:
+        """See base class."""
+        base64_state: bytes = self._call_store("get", self._key)
+
+        return self._decode_state(base64_state)
+
+    def set_state(
+        self, state: bytes, token: Optional[Token] = None
+    ) -> Optional[Tuple[bytes, Token, bool]]:
+        """See base class."""
+        base64_state_str: str = b64encode(state).decode()
+
+        if token:
+            # Shortcut if we know for sure that the token is not valid.
+            if not isinstance(token, bytes):
+                result = self.get_state()
+                if result is not None:
+                    tmp = *result, False
+                    # Python 3.6 does not support tuple unpacking in return
+                    # statements.
+                    return tmp
+                return None
+
+            token = token.decode()
+        else:
+            token = self._NULL_SENTINEL
+
+        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
+
+        state_token_pair = self._decode_state(base64_state)
+        if state_token_pair is None:
+            return None
+
+        new_state, new_token = state_token_pair
+
+        # C10d Store's compare_set method does not offer an easy way to find out
+        # whether our write attempt was successful. As a brute-force solution we
+        # perform a byte-for-byte comparison of our local state and the remote state.
+        return new_state, new_token, new_state == state
+
+    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
+        try:
+            return getattr(self._store, store_op)(*args, **kwargs)
+        except (ValueError, RuntimeError, TimeoutError) as exc:
+            raise RendezvousConnectionError(
+                "The connection to the C10d store has failed. See inner exception for details."
+            ) from exc
+
+    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
+        if base64_state == self._NULL_SENTINEL.encode():
+            return None
+
+        try:
+            state = b64decode(base64_state)
+        except binascii.Error as exc:
+            raise RendezvousStateError(
+                "The state object is corrupt. See inner exception for details."
+            ) from exc
+
+        return state, base64_state
+
+
+def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
+    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)
+
+    cfg_is_host = params.get_as_bool("is_host")
+    # If the user has explicitly specified whether our process should host
+    # the store, respect it.
+    if cfg_is_host is not None:
+        is_host = cfg_is_host
+    # Otherwise try to determine whether we are the host based on our hostname
+    # and IP address.
+    else:
+        is_host = _matches_machine_hostname(host)
+
+    use_libuv = params.get_as_bool("use_libuv", False)
+
+    # The read timeout, in seconds, for TCP store operations.
+    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
+    if read_timeout <= 0:
+        raise ValueError("The read timeout must be a positive integer.")
+
+    # In specific cases we attempt to instantiate the store twice. For details
+    # see the explanation in the except clause below.
+    for is_server in [is_host, False]:
+        try:
+            store = TCPStore(
+                host,
+                port,
+                is_master=is_server,
+                timeout=timedelta(seconds=read_timeout),
+                use_libuv=use_libuv,
+            )
+
+            if is_server:
+                msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
+                construct_and_record_rdzv_event(
+                    run_id=params.run_id, message=msg, node_state=NodeState.INIT
+                )
+                log.info(msg)
+
+            break
+        except (ValueError, RuntimeError, TimeoutError) as exc:
+            # If we heuristically inferred the value of is_host as True and our
+            # first attempt to instantiate the TCP store has failed, try it one
+            # more time with is_host set to False. As an edge case there can be
+            # more than one process that is part of the same rendezvous on this
+            # machine and only one of them will eventually host the store.
+
+            if not is_server or cfg_is_host is not None:
+                raise RendezvousConnectionError(
+                    "The connection to the C10d store has failed. See inner exception for details."
+                ) from exc
+
+    return store  # type: ignore[possibly-undefined]
+
+
+def _create_file_store(params: RendezvousParameters) -> FileStore:
+    # If a user specifies an endpoint, we treat it as a path to a file.
+    if params.endpoint:
+        path = params.endpoint
+    else:
+        try:
+            # The temporary file is readable and writable only by the user of
+            # this process.
+            _, path = tempfile.mkstemp()
+        except OSError as exc:
+            raise RendezvousError(
+                "The file creation for C10d store has failed. See inner exception for details."
+            ) from exc
+
+    try:
+        store = FileStore(path)
+    except (ValueError, RuntimeError) as exc:
+        raise RendezvousConnectionError(
+            "The connection to the C10d store has failed. See inner exception for details."
+        ) from exc
+
+    return store
+
+
+def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
+    """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters.
+
+    +--------------+-----------------------------------------------------------+
+    | Parameter    | Description                                               |
+    +==============+===========================================================+
+    | store_type   | The type of the C10d store. The currently supported types |
+    |              | are "tcp" and "file" which correspond to                  |
+    |              | :py:class:`torch.distributed.TCPStore` and                |
+    |              | :py:class:`torch.distributed.FileStore`, respectively.    |
+    |              | Defaults to "tcp".                                        |
+    +--------------+-----------------------------------------------------------+
+    | read_timeout | The read timeout, in seconds, for store operations.       |
+    |              | Defaults to 60 seconds.                                   |
+    |              |                                                           |
+    |              | Note this only applies to                                 |
+    |              | :py:class:`torch.distributed.TCPStore`. It is not relevant|
+    |              | to :py:class:`torch.distributed.FileStore` which does not |
+    |              | take in timeout as a parameter.                           |
+    +--------------+-----------------------------------------------------------+
+    | is_host      | A boolean value indicating whether this backend instance  |
+    |              | will host the C10d store. If not specified it will be     |
+    |              | inferred heuristically by matching the hostname or the IP |
+    |              | address of this machine against the specified rendezvous  |
+    |              | endpoint. Defaults to ``None``.                           |
+    |              |                                                           |
+    |              | Note that this configuration option only applies to       |
+    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
+    |              | circumstances you can safely skip it; the only time when  |
+    |              | it is needed is if its value cannot be correctly          |
+    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
+    |              | the hostname or does not match the FQDN of the machine).  |
+    +--------------+-----------------------------------------------------------+
+    """
+    # As of today we only support TCPStore and FileStore. Other store types do
+    # not have the required functionality (e.g. compare_set) yet.
+    store_type = params.get("store_type", "tcp").strip().lower()
+    store: Store
+
+    try:
+        if store_type == "file":
+            store = _create_file_store(params)
+        elif store_type == "tcp":
+            store = _create_tcp_store(params)
+        else:
+            raise ValueError("Invalid store type given. Currently only supports file and tcp.")
+
+        backend = C10dRendezvousBackend(store, params.run_id)
+
+    except Exception as e:
+        construct_and_record_rdzv_event(
+            message=f"{type(e).__name__}: {str(e)}",
+            run_id=params.run_id,
+            node_state=NodeState.FAILED,
+        )
+        raise
+
+    return backend, store
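A sketch of building the C10d backend from ``RendezvousParameters`` using the configuration keys documented in ``create_backend`` (the endpoint and run id are illustrative)::

    from torch.distributed.elastic.rendezvous.api import RendezvousParameters
    from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import create_backend

    params = RendezvousParameters(
        backend="c10d",
        endpoint="localhost:29400",
        run_id="job-42",
        min_nodes=1,
        max_nodes=2,
        store_type="tcp",  # or "file"
        read_timeout=60,   # seconds; only relevant for the TCP store
        is_host=True,      # skip the hostname-matching heuristic
    )
    backend, store = create_backend(params)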
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py
new file mode 100644
index 0000000000000000000000000000000000000000..88d649141d190c231364754c34494bf9e2bee47e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py
@@ -0,0 +1,1343 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import inspect
+import logging
+import os
+import pickle
+import socket
+import threading
+import time
+import weakref
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from enum import Enum
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple
+
+from torch.distributed import PrefixStore, Store
+from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
+
+from .api import (
+    RendezvousClosedError,
+    RendezvousError,
+    RendezvousGracefulExitError,
+    RendezvousHandler,
+    RendezvousParameters,
+    RendezvousStateError,
+    RendezvousTimeoutError,
+)
+from .utils import _delay, _PeriodicTimer
+
+__all__ = ['RendezvousBackend', 'RendezvousTimeout', 'RendezvousSettings', 'DynamicRendezvousHandler', 'create_handler']
+
+log = logging.getLogger(__name__)
+
+
+def get_method_name(depth=2):
+    if len(inspect.stack()) > depth:
+        return inspect.stack()[depth].function
+    return "no_method_name"
+
+
+Token = Any
+"""Represent an opaque fencing token used by the rendezvous backend."""
+
+class RendezvousBackend(ABC):
+    """Represent a backend that holds the rendezvous state."""
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Get the name of the backend."""
+
+    @abstractmethod
+    def get_state(self) -> Optional[Tuple[bytes, Token]]:
+        """Get the rendezvous state.
+
+        Returns:
+            A tuple of the encoded rendezvous state and its fencing token or
+            ``None`` if no state is found in the backend.
+
+        Raises:
+            RendezvousConnectionError:
+                The connection to the backend has failed.
+            RendezvousStateError:
+                The rendezvous state is corrupt.
+        """
+
+    @abstractmethod
+    def set_state(
+        self, state: bytes, token: Optional[Token] = None
+    ) -> Optional[Tuple[bytes, Token, bool]]:
+        """Set the rendezvous state.
+
+        The new rendezvous state is set conditionally:
+
+          - If the specified ``token`` matches the fencing token stored in the
+            backend, the state will be updated. The new state will be returned
+            to the caller along with its fencing token.
+          - If the specified ``token`` does not match the fencing token stored
+            in the backend, the state won't be updated; instead the existing
+            state along with its fencing token will be returned to the caller.
+          - If the specified ``token`` is ``None``, the new state will be set
+            only if there is no existing state in the backend. Either the new
+            state or the existing state along with its fencing token will be
+            returned to the caller.
+
+        Args:
+            state:
+                The encoded rendezvous state.
+            token:
+                An optional fencing token that was retrieved by a previous call
+                to :py:meth:`get_state` or ``set_state()``.
+
+        Returns:
+            A tuple of the serialized rendezvous state, its fencing token, and
+            a boolean value indicating whether our set attempt succeeded.
+
+        Raises:
+            RendezvousConnectionError:
+                The connection to the backend has failed.
+            RendezvousStateError:
+                The rendezvous state is corrupt.
+        """
+
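The conditional update above behaves like an optimistic compare-and-set on the shared state. A minimal sketch of a single attempt against any ``RendezvousBackend`` (the helper name is made up)::

    def try_replace_state(backend: RendezvousBackend, new_state: bytes) -> bool:
        """Attempt one fenced write; return True only if our write won."""
        current = backend.get_state()
        token = current[1] if current is not None else None
        response = backend.set_state(new_state, token)
        return response is not None and response[2]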
+
+class RendezvousTimeout:
+    """Hold the timeout configuration of a rendezvous.
+
+    Args:
+        join:
+            The time within which the rendezvous is expected to complete.
+        last_call:
+            An additional wait amount before completing the rendezvous once the
+            rendezvous has the minimum number of required participants.
+        close:
+            The time within which the rendezvous is expected to close after a
+            call to :py:meth:`RendezvousHandler.set_closed` or
+            :py:meth:`RendezvousHandler.shutdown`.
+        heartbeat:
+            The time within which a keep-alive heartbeat is expected to
+            complete.
+    """
+
+    _ZERO = timedelta(0)
+
+    _DEFAULT_TIMEOUTS = {
+        "join": timedelta(seconds=600),
+        "last_call": timedelta(seconds=30),
+        "close": timedelta(seconds=30),
+        "heartbeat": timedelta(seconds=5),
+    }
+
+    _join: timedelta
+    _last_call: timedelta
+    _close: timedelta
+    _heartbeat: timedelta
+
+    def __init__(
+        self,
+        join: Optional[timedelta] = None,
+        last_call: Optional[timedelta] = None,
+        close: Optional[timedelta] = None,
+        heartbeat: Optional[timedelta] = None,
+    ) -> None:
+        self._set_timeouts(join=join, last_call=last_call, close=close, heartbeat=heartbeat)
+
+    @property
+    def join(self) -> timedelta:
+        """Get the join timeout."""
+        return self._join
+
+    @property
+    def last_call(self) -> timedelta:
+        """Get the last call timeout."""
+        return self._last_call
+
+    @property
+    def close(self) -> timedelta:
+        """Get the close timeout."""
+        return self._close
+
+    @property
+    def heartbeat(self) -> timedelta:
+        """Get the keep-alive heartbeat timeout."""
+        return self._heartbeat
+
+    def _set_timeouts(self, **timeouts: Optional[timedelta]):
+        for name, timeout in timeouts.items():
+            if timeout is None:
+                timeout = self._DEFAULT_TIMEOUTS[name]
+            if timeout <= self._ZERO:
+                raise ValueError(f"The {name} timeout ({timeout}) must be positive.")
+            setattr(self, "_" + name, timeout)
+
+
+@dataclass(repr=False, eq=False, frozen=True)
+class RendezvousSettings:
+    """Hold the settings of the rendezvous.
+
+    Attributes:
+        run_id:
+            The run id of the rendezvous.
+        min_nodes:
+            The minimum number of nodes to admit to the rendezvous.
+        max_nodes:
+            The maximum number of nodes to admit to the rendezvous.
+        timeout:
+            The timeout configuration of the rendezvous.
+        keep_alive_interval:
+            The amount of time a node waits before sending a heartbeat to keep
+            it alive in the rendezvous.
+        keep_alive_max_attempt:
+            The maximum number of failed heartbeat attempts after which a node
+            is considered dead.
+    """
+
+    run_id: str
+    min_nodes: int
+    max_nodes: int
+    timeout: RendezvousTimeout
+    keep_alive_interval: timedelta
+    keep_alive_max_attempt: int
+
+
+@dataclass(eq=True, order=True, frozen=True)
+class _NodeDesc:
+    """Describe a node in the rendezvous.
+
+    Attributes:
+        addr:
+            The FQDN of the node or user specified local node address.
+        pid:
+            The id of the process in which the rendezvous handler runs.
+        local_id:
+            A process-wide unique id.
+    """
+
+    addr: str
+    pid: int
+    local_id: int
+
+    def __repr__(self) -> str:
+        return f"{self.addr}_{self.pid}_{self.local_id}"
+
+
+class _NodeDescGenerator:
+    """Generate node descriptors.
+
+    A node descriptor is a combination of an FQDN, a process id, and an auto-
+    incremented integer that uniquely identifies a node in the rendezvous.
+    """
+
+    _lock: threading.Lock
+    _local_id: int
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+
+        # An integer that is incremented with each call to generate().
+        self._local_id = 0
+
+    def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
+        # This method can be called by multiple threads concurrently; therefore,
+        # we must increment the integer atomically.
+        with self._lock:
+            local_id = self._local_id
+
+            self._local_id += 1
+
+        return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id)
+
+
+class _RendezvousState:
+    """Hold the state of a rendezvous.
+
+    Attributes:
+        round:
+            The current round of the rendezvous.
+        complete:
+            A boolean value indicating whether the current round of the
+            rendezvous is complete.
+        deadline:
+            The time at which the current round of the rendezvous will be
+            considered complete if it is still waiting for nodes to join.
+        closed:
+            A boolean value indicating whether the rendezvous is closed.
+        participants:
+            A dictionary of the participants and their corresponding ranks.
+        wait_list:
+            A set of nodes that are waiting to participate in the next round of
+            the rendezvous.
+        redundancy_list:
+            A set of nodes that are redundant in the current round and can join
+            the next rendezvous without triggering re-rendezvous.
+        last_heartbeats:
+            A dictionary containing each node's last heartbeat time.
+    """
+
+    round: int
+    complete: bool
+    deadline: Optional[datetime]
+    closed: bool
+    participants: Dict[_NodeDesc, int]
+    wait_list: Set[_NodeDesc]
+    redundancy_list: Set[_NodeDesc]
+    last_heartbeats: Dict[_NodeDesc, datetime]
+
+    def __init__(self) -> None:
+        self.round = 0
+        self.complete = False
+        self.deadline = None
+        self.closed = False
+        self.participants = {}
+        self.wait_list = set()
+        self.redundancy_list = set()
+        self.last_heartbeats = {}
+
+
+def _remove_participant_epilogue(state: _RendezvousState, settings: RendezvousSettings) -> None:
+    if state.complete:
+        # If we do not have any participants left, move to the next round.
+        if not state.participants:
+            msg = "No participants left in the rendezvous, marking rendezvous as incomplete"
+            log.debug(msg)
+            state.complete = False
+
+            state.round += 1
+    else:
+        if len(state.participants) < settings.min_nodes:
+            msg = (
+                f"Number of participants {len(state.participants)}) less than"
+                f"min_nodes {settings.min_nodes}, clearning deadline in state"
+            )
+            log.debug(msg)
+            state.deadline = None
+
+
+class _RendezvousStateHolder(ABC):
+    """Hold the shared rendezvous state synced with other nodes."""
+
+    @property
+    @abstractmethod
+    def state(self) -> _RendezvousState:
+        """Get the local state."""
+
+    @abstractmethod
+    def sync(self) -> Optional[bool]:
+        """Read or writes the latest state.
+
+        Returns:
+            A boolean value indicating whether the local state, in case marked
+            as dirty, was successfully synced with other nodes.
+        """
+
+    @abstractmethod
+    def mark_dirty(self) -> None:
+        """Mark the local state as dirty."""
+
+
+class _BackendRendezvousStateHolder(_RendezvousStateHolder):
+    """Hold the rendezvous state synced with other nodes via a backend.
+
+    Args:
+        backend:
+            The rendezvous backend to use.
+        settings:
+            The rendezvous settings.
+        cache_duration:
+            The amount of time, in seconds, to cache the last rendezvous state
+            before requesting it from the backend again.
+    """
+
+    _backend: RendezvousBackend
+    _state: _RendezvousState
+    _settings: RendezvousSettings
+    _cache_duration: int
+    _token: Token
+    _dirty: bool
+    _last_sync_time: float
+    _dead_nodes: List[_NodeDesc]
+
+    def __init__(
+        self,
+        backend: RendezvousBackend,
+        settings: RendezvousSettings,
+        cache_duration: int = 1,
+    ) -> None:
+        self._backend = backend
+        self._state = _RendezvousState()
+        self._settings = settings
+        self._cache_duration = cache_duration
+        self._token = None
+        self._dirty = False
+        self._last_sync_time = -1
+        self._dead_nodes = []
+
+    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
+        construct_and_record_rdzv_event(
+            name=f"{self.__class__.__name__}.{get_method_name()}",
+            run_id=self._settings.run_id,
+            message=message,
+            node_state=node_state,
+        )
+
+    @property
+    def state(self) -> _RendezvousState:
+        """See base class."""
+        return self._state
+
+    def sync(self) -> Optional[bool]:
+        """See base class."""
+        state_bits: Optional[bytes] = None
+
+        token = None
+
+        has_set: Optional[bool]
+
+        if self._dirty:
+            has_set = False
+
+            state_bits = pickle.dumps(self._state)
+
+            set_response = self._backend.set_state(state_bits, self._token)
+            if set_response is not None:
+                state_bits, token, has_set = set_response
+        else:
+            has_set = None
+
+            if self._cache_duration > 0:
+                # Avoid overloading the backend if we are asked to retrieve the
+                # state repeatedly. Try to serve the cached state.
+                if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
+                    return None
+
+            get_response = self._backend.get_state()
+            if get_response is not None:
+                state_bits, token = get_response
+
+        if state_bits is not None:
+            try:
+                self._state = pickle.loads(state_bits)
+            except pickle.PickleError as exc:
+                raise RendezvousStateError(
+                    "The rendezvous state is corrupt. See inner exception for details."
+                ) from exc
+        else:
+            self._state = _RendezvousState()
+
+        if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
+            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
+
+            msg = (
+                f"As part of the sync operation the node(s) {node_list} have been removed from the "
+                f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
+            )
+            self._record(message=msg)
+            log.debug(msg)
+
+        self._token = token
+
+        self._dirty = False
+
+        self._last_sync_time = time.monotonic()
+
+        self._sanitize()
+
+        return has_set
+
+    def _sanitize(self) -> None:
+        state = self._state
+
+        expire_time = datetime.utcnow() - (
+            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
+        )
+
+        # Filter out the dead nodes.
+        self._dead_nodes = [
+            node
+            for node, last_heartbeat in state.last_heartbeats.items()
+            if last_heartbeat < expire_time
+        ]
+
+        participant_removed = False
+
+        for dead_node in self._dead_nodes:
+            msg = f"Detected dead node '{dead_node}', removing it from the rendezvous"
+            log.debug(msg)
+            del state.last_heartbeats[dead_node]
+
+            try:
+                del state.participants[dead_node]
+
+                participant_removed = True
+            except KeyError:
+                pass
+
+            try:
+                state.wait_list.remove(dead_node)
+            except KeyError:
+                pass
+
+            try:
+                state.redundancy_list.remove(dead_node)
+            except KeyError:
+                pass
+
+        if participant_removed:
+            # Common epilogue shared with the _remove_from_participants()
+            # function of _DistributedRendezvousOpExecutor.
+            _remove_participant_epilogue(state, self._settings)
+
+    def mark_dirty(self) -> None:
+        """See base class.
+
+        If the local rendezvous state is dirty, the next sync call will try to
+        write the changes back to the backend. However, this attempt might fail
+        if another node, which had the same state, also made changes and wrote
+        them before us.
+        """
+        self._dirty = True
+
+
+class _Action(Enum):
+    """Specifies the possible actions based on the state of the rendezvous."""
+
+    KEEP_ALIVE = 1
+    ADD_TO_PARTICIPANTS = 2
+    ADD_TO_WAIT_LIST = 3
+    ADD_TO_REDUNDANCY_LIST = 4
+    REMOVE_FROM_PARTICIPANTS = 5
+    REMOVE_FROM_WAIT_LIST = 6
+    REMOVE_FROM_REDUNDANCY_LIST = 7
+    MARK_RENDEZVOUS_COMPLETE = 8
+    MARK_RENDEZVOUS_CLOSED = 9
+    SYNC = 10
+    ERROR_CLOSED = 11
+    ERROR_TIMEOUT = 12
+    FINISH = 13
+
+
+class _RendezvousContext:
+    """Holds the context of the rendezvous.
+
+    Attributes:
+        node:
+            The node descriptor associated with the current rendezvous handler
+            instance.
+        state:
+            The current state of the rendezvous.
+        settings:
+            The rendezvous settings.
+    """
+
+    node: _NodeDesc
+    state: _RendezvousState
+    settings: RendezvousSettings
+
+    def __init__(
+        self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
+    ) -> None:
+        self.node = node
+        self.state = state
+        self.settings = settings
+
+
+class _RendezvousOpExecutor(ABC):
+    """Execute rendezvous operations."""
+
+    @abstractmethod
+    def run(
+        self,
+        state_handler: Callable[[_RendezvousContext, float], _Action],
+        deadline: float,
+        update_deadline: Optional[Callable[[timedelta], float]] = None,
+    ) -> None:
+        """Execute a rendezvous operation.
+
+        An operation is run inside a state machine and is expected to transition
+        the rendezvous from one state to another.
+
+        Args:
+            state_handler:
+                A callable that is expected to return the next state transition
+                action based on the current state of the rendezvous.
+            deadline:
+                The time, in seconds, at which the operation will be considered
+                timed-out.
+            update_deadline:
+                Function to generate a new operation deadline if the current
+                node may participate in the next rendezvous.
+        """
+
+
+class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
+    """Execute rendezvous operations using a shared state.
+
+    Args:
+        node:
+            The node descriptor associated with the current rendezvous handler
+            instance.
+        state_holder:
+            The ``RendezvousStateHolder`` to use to sync the rendezvous state
+            with other nodes.
+        settings:
+            The rendezvous settings.
+    """
+
+    _node: _NodeDesc
+    _state: _RendezvousState
+    _state_holder: _RendezvousStateHolder
+    _settings: RendezvousSettings
+
+    def __init__(
+        self,
+        node: _NodeDesc,
+        state_holder: _RendezvousStateHolder,
+        settings: RendezvousSettings,
+    ) -> None:
+        self._node = node
+        self._state_holder = state_holder
+        self._settings = settings
+
+    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
+        construct_and_record_rdzv_event(
+            name=f"{self.__class__.__name__}.{get_method_name()}",
+            run_id=self._settings.run_id,
+            message=message,
+            node_state=node_state,
+            hostname=self._node.addr,
+            pid=self._node.pid,
+            local_id=self._node.local_id,
+        )
+
+    def run(
+        self,
+        state_handler: Callable[[_RendezvousContext, float], _Action],
+        deadline: float,
+        update_deadline: Optional[Callable[[timedelta], float]] = None,
+    ) -> None:
+        """See base class."""
+        action = None
+        while action != _Action.FINISH:
+            # Reads or writes the latest rendezvous state shared by all nodes in
+            # the rendezvous. Note that our local changes might get overridden
+            # by another node if that node synced its changes before us.
+            has_set = self._state_holder.sync()
+            if has_set is not None:
+                if has_set:
+                    msg = (
+                        f"The node '{self._node}' has successfully synced its local changes with "
+                        f"other nodes in the rendezvous '{self._settings.run_id}'."
+                    )
+                else:
+                    msg = (
+                        f"The node '{self._node}' has a stale state and failed to sync its local "
+                        f"changes with other nodes in the rendezvous '{self._settings.run_id}'."
+                    )
+
+                self._record(message=msg)
+                log.debug(msg)
+
+            self._state = self._state_holder.state
+
+            ctx = _RendezvousContext(self._node, self._state, self._settings)
+
+            # Determine the next action to take based on the current state of
+            # the rendezvous.
+            action = state_handler(ctx, deadline)
+
+            if action == _Action.FINISH:
+                continue
+
+            if action == _Action.ERROR_CLOSED:
+                raise RendezvousClosedError()
+
+            if action == _Action.ERROR_TIMEOUT:
+                raise RendezvousTimeoutError()
+
+            if action == _Action.SYNC:
+                # Delay the execution by one second to avoid overloading the
+                # backend if we are asked to poll for state changes.
+                _delay(seconds=1)
+            else:
+                if action == _Action.KEEP_ALIVE:
+                    self._keep_alive()
+                elif action == _Action.ADD_TO_PARTICIPANTS:
+                    self._add_to_participants()
+                elif action == _Action.ADD_TO_WAIT_LIST:
+                    self._add_to_wait_list()
+                elif action == _Action.ADD_TO_REDUNDANCY_LIST:
+                    self._add_to_redundancy_list()
+                elif action == _Action.REMOVE_FROM_PARTICIPANTS:
+                    self._remove_from_participants()
+                elif action == _Action.REMOVE_FROM_WAIT_LIST:
+                    self._remove_from_wait_list()
+                elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST:
+                    self._remove_from_redundancy_list()
+                    # Update the deadline since the node may now participate in the rendezvous.
+                    if update_deadline:
+                        deadline = update_deadline(self._settings.timeout.join)
+                elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
+                    self._mark_rendezvous_complete()
+                elif action == _Action.MARK_RENDEZVOUS_CLOSED:
+                    self._mark_rendezvous_closed()
+
+                # Attempt to sync our changes back to other nodes.
+                self._state_holder.mark_dirty()
+
+    def _keep_alive(self) -> None:
+        msg = (
+            f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
+            f"'{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        self._state.last_heartbeats[self._node] = datetime.utcnow()
+
+    def _add_to_participants(self) -> None:
+        msg = (
+            f"The node '{self._node}' added itself to the participants of round "
+            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        state = self._state
+
+        try:
+            state.wait_list.remove(self._node)
+        except KeyError:
+            pass
+
+        # The ranks of the participants will be set once the rendezvous is
+        # complete.
+        state.participants[self._node] = 0
+
+        self._keep_alive()
+
+        if len(state.participants) == self._settings.min_nodes:
+            state.deadline = datetime.utcnow() + self._settings.timeout.last_call
+
+        if len(state.participants) == self._settings.max_nodes:
+            self._mark_rendezvous_complete()
+
+    def _add_to_wait_list(self) -> None:
+        msg = (
+            f"The node '{self._node}' added itself to the wait list of round "
+            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        if self._node in self._state.redundancy_list:
+            self._state.redundancy_list.remove(self._node)
+        self._state.wait_list.add(self._node)
+
+        self._keep_alive()
+
+    def _add_to_redundancy_list(self) -> None:
+        msg = (
+            f"The node '{self._node}' added itself to the redundancy list of round "
+            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        self._state.redundancy_list.add(self._node)
+
+        self._keep_alive()
+
+    def _remove_from_participants(self) -> None:
+        msg = (
+            f"The node '{self._node}' removed itself from the participants of round "
+            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        state = self._state
+
+        del state.participants[self._node]
+
+        del state.last_heartbeats[self._node]
+
+        # Common epilogue shared with the _sanitize() method of
+        # _BackendRendezvousStateHolder.
+        _remove_participant_epilogue(state, self._settings)
+
+    def _remove_from_wait_list(self) -> None:
+        msg = (
+            f"The node '{self._node}' removed itself from the wait list of round "
+            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        self._state.wait_list.remove(self._node)
+
+        del self._state.last_heartbeats[self._node]
+
+    def _remove_from_redundancy_list(self) -> None:
+        msg = (
+            f"The node '{self._node}' removed itself from the redunant list of round "
+            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
+        )
+        self._record(message=msg)
+        log.debug(msg)
+
+        self._state.redundancy_list.remove(self._node)
+
+        del self._state.last_heartbeats[self._node]
+
+    def _mark_rendezvous_complete(self) -> None:
+        msg = (
+            f"The node '{self._node}' marked round {self._state.round} of the rendezvous "
+            f"'{self._settings.run_id}' as complete. Pending sync."
+        )
+        self._record(message=msg, node_state=NodeState.SUCCEEDED)
+        log.debug(msg)
+
+        state = self._state
+
+        state.complete = True
+        state.deadline = None
+
+        # Assign the ranks.
+        for rank, node in enumerate(sorted(state.participants)):
+            state.participants[node] = rank
+
+    def _mark_rendezvous_closed(self) -> None:
+        msg = (
+            f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. "
+            "Pending sync."
+        )
+        self._record(message=msg, node_state=NodeState.SUCCEEDED)
+        log.debug(msg)
+
+        self._state.closed = True
+
+
+def _should_keep_alive(ctx: _RendezvousContext) -> bool:
+    """Determine whether a keep-alive heartbeat should be sent."""
+    try:
+        last_heartbeat = ctx.state.last_heartbeats[ctx.node]
+    except KeyError:
+        return False
+
+    return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
+
+
+class _RendezvousExitOp:
+    """Represent a rendezvous exit operation."""
+
+    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
+        if ctx.node in ctx.state.participants:
+            if time.monotonic() > deadline:
+                return _Action.ERROR_TIMEOUT
+            return _Action.REMOVE_FROM_PARTICIPANTS
+        return _Action.FINISH
+
+
+class _RendezvousJoinOp:
+    """Represent a rendezvous join operation."""
+
+    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
+        state = ctx.state
+
+        # A closed rendezvous means that it no longer accepts new nodes.
+        if state.closed:
+            if ctx.node in state.redundancy_list:
+                msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous."
+                raise RendezvousGracefulExitError(msg)
+            return _Action.ERROR_CLOSED
+
+        if ctx.node in state.redundancy_list:
+            msg = f"The node {ctx.node} is in redunancy list"
+            log.debug(msg)
+            # don't apply the timeout logic here, since we want to allow the node to rejoin
+            if len(state.participants) == ctx.settings.max_nodes:
+                if _should_keep_alive(ctx):
+                    return _Action.KEEP_ALIVE
+                else:
+                    return _Action.SYNC
+            else:
+                # transition to waiting state that will respect timeouts.
+                msg = f"The node {ctx.node} is removed from redunancy list"
+                log.debug(msg)
+                return _Action.REMOVE_FROM_REDUNDANCY_LIST
+
+        is_participant = ctx.node in state.participants
+
+        # If we are part of the rendezvous and it is already complete there is
+        # no further action to take.
+        if state.complete and is_participant:
+            return _Action.FINISH
+
+        now = time.monotonic()
+        if now > deadline:
+            rollback_period = 5  # 5 seconds
+
+            # If we still have time to rollback (a short period on top of the
+            # operation deadline), try to remove ourselves from the rendezvous.
+            # It is okay if we can't, though, as our keep-alive will eventually
+            # expire.
+            if now <= deadline + rollback_period:
+                # If we are part of the rendezvous, it means we couldn't find
+                # enough participants to complete it on time.
+                if is_participant:
+                    return _Action.REMOVE_FROM_PARTICIPANTS
+                # If we are in the wait list, it means we couldn't wait till the
+                # next round of the rendezvous.
+                if ctx.node in state.wait_list:
+                    return _Action.REMOVE_FROM_WAIT_LIST
+            return _Action.ERROR_TIMEOUT
+
+        if state.complete:
+            # If we are here, it means we are not part of the rendezvous. In
+            # case the rendezvous has capacity for additional participants, add
+            # ourselves to the wait list for the next round.
+            if len(state.participants) < ctx.settings.max_nodes:
+                if ctx.node not in state.wait_list:
+                    return _Action.ADD_TO_WAIT_LIST
+            elif len(state.participants) >= ctx.settings.max_nodes:
+                if ctx.node not in state.redundancy_list and ctx.node not in state.wait_list:
+                    return _Action.ADD_TO_REDUNDANCY_LIST
+        elif is_participant:
+            # If the rendezvous has enough participants, including us,
+            # check whether we have passed the rendezvous deadline. If yes,
+            # complete it.
+            if len(state.participants) >= ctx.settings.min_nodes and \
+                    len(state.participants) <= ctx.settings.max_nodes:
+                if cast(datetime, state.deadline) < datetime.utcnow():
+                    msg = (
+                        f"The node '{ctx.node}' marking the rendezvous complete, "
+                        f"quorum established within deadline"
+                    )
+                    log.debug(msg)
+                    return _Action.MARK_RENDEZVOUS_COMPLETE
+                else:
+                    msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached"
+                    log.debug(msg)
+            else:
+                msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants"
+                log.debug(msg)
+        else:
+            # The rendezvous is not complete yet and we are not part of it. Try
+            # to join.
+            return _Action.ADD_TO_PARTICIPANTS
+
+        if _should_keep_alive(ctx):
+            return _Action.KEEP_ALIVE
+
+        # At this point either the rendezvous is not complete, but we are part
+        # of it, which means we have to wait for other participants to join; or
+        # the rendezvous is complete, but we are not part of it, which means we
+        # have to wait for the next round.
+        return _Action.SYNC
+
+
+class _RendezvousCloseOp:
+    """Represent a rendezvous close operation."""
+
+    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
+        if ctx.state.closed:
+            return _Action.FINISH
+        if time.monotonic() > deadline:
+            return _Action.ERROR_TIMEOUT
+        return _Action.MARK_RENDEZVOUS_CLOSED
+
+
+class _RendezvousKeepAliveOp:
+    """Represent a rendezvous keep-alive update operation."""
+
+    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
+        if _should_keep_alive(ctx):
+            if time.monotonic() > deadline:
+                return _Action.ERROR_TIMEOUT
+            return _Action.KEEP_ALIVE
+        return _Action.FINISH
+
+
+class DynamicRendezvousHandler(RendezvousHandler):
+    """Represent a handler that sets up a rendezvous among a set of nodes."""
+
+    # Static
+    _node_desc_generator = _NodeDescGenerator()
+
+    _this_node: _NodeDesc
+    _settings: RendezvousSettings
+    _backend_name: str
+    _store: Store
+    _state_holder: _RendezvousStateHolder
+    _op_executor: _RendezvousOpExecutor
+    _heartbeat_lock: threading.Lock
+    _keep_alive_timer: Optional[_PeriodicTimer]
+
+    @classmethod
+    def from_backend(
+        cls,
+        run_id: str,
+        store: Store,
+        backend: RendezvousBackend,
+        min_nodes: int,
+        max_nodes: int,
+        local_addr: Optional[str] = None,
+        timeout: Optional[RendezvousTimeout] = None,
+    ):
+        """Create a new :py:class:`DynamicRendezvousHandler`.
+
+        Args:
+            run_id:
+                The run id of the rendezvous.
+            store:
+                The C10d store to return as part of the rendezvous.
+            backend:
+                The backend to use to hold the rendezvous state.
+            min_nodes:
+                The minimum number of nodes to admit to the rendezvous.
+            max_nodes:
+                The maximum number of nodes to admit to the rendezvous.
+            local_addr:
+                The local node address.
+            timeout:
+                The timeout configuration of the rendezvous.
+        """
+        # We associate each handler instance with a unique node descriptor.
+        node = cls._node_desc_generator.generate(local_addr)
+
+        settings = RendezvousSettings(
+            run_id,
+            min_nodes,
+            max_nodes,
+            timeout or RendezvousTimeout(),
+            keep_alive_interval=timedelta(seconds=5),
+            keep_alive_max_attempt=3,
+        )
+
+        state_holder = _BackendRendezvousStateHolder(backend, settings)
+
+        return cls(node, settings, backend.name, store, state_holder)
+
+    def __init__(
+        self,
+        node: _NodeDesc,
+        settings: RendezvousSettings,
+        backend_name: str,
+        store: Store,
+        state_holder: _RendezvousStateHolder,
+    ) -> None:
+        if not settings.run_id:
+            raise ValueError("The run id must be a non-empty string.")
+
+        if settings.min_nodes < 1:
+            raise ValueError(
+                f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero."
+            )
+
+        if settings.max_nodes < settings.min_nodes:
+            raise ValueError(
+                f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal "
+                f"to the minimum number of nodes ({settings.min_nodes})."
+            )
+
+        self._this_node = node
+
+        self._settings = settings
+
+        self._backend_name = backend_name
+
+        self._store = store
+
+        self._state_holder = state_holder
+
+        self._op_executor = _DistributedRendezvousOpExecutor(
+            self._this_node, self._state_holder, self._settings
+        )
+
+        self._heartbeat_lock = threading.Lock()
+
+        self._keep_alive_timer = None
+
+    def _record(
+        self,
+        message: str,
+        node_state: NodeState = NodeState.RUNNING,
+        rank: Optional[int] = None,
+    ) -> None:
+        construct_and_record_rdzv_event(
+            name=f"{self.__class__.__name__}.{get_method_name()}",
+            run_id=self._settings.run_id,
+            message=message,
+            node_state=node_state,
+            hostname=self._this_node.addr,
+            pid=self._this_node.pid,
+            local_id=self._this_node.local_id,
+            rank=rank,
+        )
+
+    @property
+    def settings(self) -> RendezvousSettings:
+        """Get the settings of the rendezvous."""
+        return self._settings
+
+    def get_backend(self) -> str:
+        """See base class."""
+        return self._backend_name
+
+    def next_rendezvous(self) -> Tuple[Store, int, int]:
+        """See base class."""
+        msg = (
+            f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
+            f"'{self._settings.run_id}'."
+        )
+        self._record(message=msg)
+        log.info(msg)
+
+        try:
+            self._stop_heartbeats()
+
+            # Delay the execution for a small random amount of time if this is our
+            # first run. This will slightly skew the rendezvous attempts across the
+            # nodes and reduce the load on the backend.
+            if self._state_holder.state.round == 0:
+                _delay(seconds=(0, 0.3))
+
+            exit_op = _RendezvousExitOp()
+            join_op = _RendezvousJoinOp()
+
+            deadline = self._get_deadline(self._settings.timeout.join)
+            self._op_executor.run(exit_op, deadline)
+            self._op_executor.run(
+                join_op,
+                deadline,
+                self._get_deadline)
+
+            self._start_heartbeats()
+
+            rank, world_size = self._get_world()
+            store = self._get_store()
+
+        except Exception as e:
+            self._record(
+                message=f"{type(e).__name__}: {str(e)}",
+                node_state=NodeState.FAILED,
+            )
+            raise
+
+        msg = (
+            f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
+            f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
+            f"{world_size}."
+        )
+        self._record(message=msg, rank=rank)
+        log.info(msg)
+
+        return store, rank, world_size
+
+    def is_closed(self) -> bool:
+        """See base class."""
+        try:
+            with self._heartbeat_lock:
+                self._state_holder.sync()
+
+                return self._state_holder.state.closed
+
+        except Exception as e:
+            self._record(
+                message=f"{type(e).__name__}: {str(e)}",
+                node_state=NodeState.FAILED,
+            )
+            raise
+
+    def set_closed(self) -> None:
+        """See base class."""
+        try:
+            with self._heartbeat_lock:
+                self._close()
+        except Exception as e:
+            self._record(
+                message=f"{type(e).__name__}: {str(e)}",
+                node_state=NodeState.FAILED,
+            )
+            raise
+
+    def num_nodes_waiting(self) -> int:
+        """See base class."""
+        try:
+            with self._heartbeat_lock:
+                self._state_holder.sync()
+
+                return len(self._state_holder.state.wait_list)
+
+        except Exception as e:
+            self._record(
+                message=f"{type(e).__name__}: {str(e)}",
+                node_state=NodeState.FAILED,
+            )
+            raise
+
+    def get_run_id(self) -> str:
+        """See base class."""
+        return self._settings.run_id
+
+    def shutdown(self) -> bool:
+        """See base class."""
+        self._stop_heartbeats()
+
+        try:
+            self._close()
+
+            return True
+        except RendezvousError as ex:
+            msg = (
+                f"The node '{self._this_node}' has failed to shutdown the rendezvous "
+                f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}."
+            )
+            self._record(message=msg, node_state=NodeState.FAILED)
+            log.warning(msg)
+
+            return False
+        except Exception as e:
+            self._record(
+                message=f"{type(e).__name__}: {str(e)}",
+                node_state=NodeState.FAILED,
+            )
+            raise
+
+    def _close(self) -> None:
+        op = _RendezvousCloseOp()
+
+        deadline = self._get_deadline(self._settings.timeout.close)
+
+        self._op_executor.run(op, deadline)
+
+        msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'."
+        self._record(message=msg, node_state=NodeState.SUCCEEDED)
+        log.info(msg)
+
+    @staticmethod
+    def _keep_alive_weak(weak_self) -> None:
+        self = weak_self()
+        if self is not None:
+            self._keep_alive()
+
+    def _keep_alive(self) -> None:
+        self._heartbeat_lock.acquire()
+
+        op = _RendezvousKeepAliveOp()
+
+        deadline = self._get_deadline(self._settings.timeout.heartbeat)
+
+        try:
+            self._op_executor.run(op, deadline)
+
+            msg = (
+                f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
+                f"'{self._settings.run_id}'."
+            )
+            self._record(message=msg)
+            log.debug(msg)
+        except RendezvousError as ex:
+            msg = (
+                f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
+                f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
+            )
+            self._record(message=msg, node_state=NodeState.FAILED)
+            log.warning(msg)
+        finally:
+            self._heartbeat_lock.release()
+
+    def _start_heartbeats(self) -> None:
+        self._keep_alive_timer = _PeriodicTimer(
+            self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
+        )
+
+        self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")
+
+        self._keep_alive_timer.start()
+
+    def _stop_heartbeats(self) -> None:
+        if self._keep_alive_timer is None:
+            return
+
+        self._keep_alive_timer.cancel()
+
+    def _get_world(self) -> Tuple[int, int]:
+        state = self._state_holder.state
+
+        return state.participants[self._this_node], len(state.participants)
+
+    def _get_store(self) -> Store:
+        key_prefix = f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}"
+
+        return PrefixStore(key_prefix, self._store)
+
+    def _get_deadline(self, timeout: timedelta) -> float:
+        return time.monotonic() + timeout.total_seconds()
+
+
+def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
+    timeout = params.get_as_int(key + "_timeout")
+    if timeout is None:
+        return None
+    return timedelta(seconds=timeout)
+
+
+def create_handler(
+    store: Store, backend: RendezvousBackend, params: RendezvousParameters
+) -> DynamicRendezvousHandler:
+    """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters.
+
+    Args:
+        store:
+            The C10d store to return as part of the rendezvous.
+        backend:
+            The backend to use to hold the rendezvous state.
+
+    +-------------------+------------------------------------------------------+
+    | Parameter         | Description                                          |
+    +===================+======================================================+
+    | join_timeout      | The total time, in seconds, within which the         |
+    |                   | rendezvous is expected to complete. Defaults to 600  |
+    |                   | seconds.                                             |
+    +-------------------+------------------------------------------------------+
+    | last_call_timeout | An additional wait amount, in seconds, before        |
+    |                   | completing the rendezvous once the minimum number of |
+    |                   | nodes has been reached. Defaults to 30 seconds.      |
+    +-------------------+------------------------------------------------------+
+    | close_timeout     | The time, in seconds, within which the rendezvous is |
+    |                   | expected to close after a call to                    |
+    |                   | :py:meth:`RendezvousHandler.set_closed` or           |
+    |                   | :py:meth:`RendezvousHandler.shutdown`. Defaults to   |
+    |                   | 30 seconds.                                          |
+    +-------------------+------------------------------------------------------+
+    """
+    try:
+        timeout = RendezvousTimeout(
+            _get_timeout(params, "join"),
+            _get_timeout(params, "last_call"),
+            _get_timeout(params, "close"),
+        )
+
+        return DynamicRendezvousHandler.from_backend(
+            params.run_id,
+            store,
+            backend,
+            params.min_nodes,
+            params.max_nodes,
+            params.local_addr,
+            timeout,
+        )
+    except Exception as e:
+        construct_and_record_rdzv_event(
+            message=f"{type(e).__name__}: {str(e)}",
+            run_id=params.run_id,
+            node_state=NodeState.FAILED,
+        )
+        raise
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py
new file mode 100644
index 0000000000000000000000000000000000000000..952d3040c3c7383ed3ce2dd1f7a28ad51eec912f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous.py
@@ -0,0 +1,1045 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import logging
+import sys
+import threading
+import time
+from typing import Optional
+
+import etcd  # type: ignore[import]
+from torch.distributed.elastic.rendezvous import (
+    RendezvousClosedError,
+    RendezvousError,
+    RendezvousHandler,
+    RendezvousParameters,
+    RendezvousTimeoutError,
+)
+
+from .utils import parse_rendezvous_endpoint
+from .etcd_store import EtcdStore, cas_delay
+
+
+_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s")
+_log_handler = logging.StreamHandler(sys.stderr)
+_log_handler.setFormatter(_log_fmt)
+
+log = logging.getLogger(__name__)
+log.propagate = False
+log.setLevel(logging.INFO)
+log.addHandler(_log_handler)
+
+
+# Retryable failure exception means that we were too late to make
+# a desired state transition (e.g. because of a race condition),
+# and should now restart from the beginning.
+# A small delay is recommended to avoid spamming Etcd.
+class EtcdRendezvousRetryableFailure(Exception):
+    pass
+
+
+# Similar to retryable failure, but the new state we observed suggests we
+# can re-try immediately, i.e. without a need for "safety delay".
+class EtcdRendezvousRetryImmediately(Exception):
+    pass
+
+
+# Default timeout for the rendezvous.
+_DEFAULT_TIMEOUT: int = 600  # 10 minutes
+
+# Additional waiting time after reaching the minimum number of nodes
+# in case the rendezvous is elastic (min != max).
+_DEFAULT_LAST_CALL_TIMEOUT: int = 30  # 30 seconds
+
+# Various constants used internally in EtcdRendezvous
+CONST_ETCD_SETUP_TTL = 5
+CONST_ETCD_FROZEN_TTL = 10
+CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10
+
+# Ephemeral node TTL for worker's keep-alive key:
+CONST_WORKER_KEEPALIVE_TTL = 10
+
+# TTL for the ephemeral run_id-specific directory. All rendezvous state data
+# for a specific run_id (job instance) is contained within this directory.
+# Its only role is to clean up rendezvous data from old runs (for the case when
+# the etcd server is persistent); it has no effect on correctness, but should be
+# larger than any timeouts that a worker process is expected to survive:
+CONST_RUNID_SUBROOT_TTL = 7200  # 2 hours
+
+
+class EtcdRendezvousHandler(RendezvousHandler):
+    """
+    Implements a
+    :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface
+    backed by
+    :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`.
+    ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to
+    use and to pass implementation specific configurations to the rendezvous
+    module. The basic etcd rendezvous configuration URL looks like the following
+    ::
+
+     etcd://<etcd_address>:<port>/<job_id>?min_workers=<min_workers>&max_workers=<max_workers>  # noqa: W605
+
+     -- example --
+
+     etcd://localhost:2379/1234?min_workers=1&max_workers=3
+
+    The URL above is interpreted as follows:
+
+    1. Use the rendezvous handler that is registered with the ``etcd``
+       scheme
+    2. The ``etcd`` endpoint to use is ``localhost:2379``
+    3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to
+       share a common etcd server for multiple jobs so long as the
+       ``job_ids`` are guaranteed to be unique). Note that the job id can be
+       any string (e.g. does not need to be a number) as long as it is
+       unique.
+    4. ``min_workers=1`` and ``max_workers=3`` specifies a range for
+       membership size - Torch Distributed Elastic starts running the job as
+       long as the cluster size is greater than or equal to ``min_workers``
+       and admits up to ``max_workers`` into the cluster.
+
+    Below is a full list of the parameters that can be passed to etcd
+    rendezvous:
+
+    +--------------------------------------------+--------------------------+
+    | Parameter                                  | Description              |
+    +============================================+==========================+
+    | min_workers                                | minimum number of        |
+    |                                            | workers for the          |
+    |                                            | rendezvous to be valid   |
+    +--------------------------------------------+--------------------------+
+    | max_workers                                | maximum number of        |
+    |                                            | workers to admit         |
+    +--------------------------------------------+--------------------------+
+    | timeout                                    | total timeout within     |
+    |                                            | which next_rendezvous is |
+    |                                            | expected to succeed      |
+    |                                            | (default 600s)           |
+    +--------------------------------------------+--------------------------+
+    | last_call_timeout                          | additional wait amount   |
+    |                                            | (“last call”) after min  |
+    |                                            | number of workers has    |
+    |                                            | been reached (defaults   |
+    |                                            | to 30s)                  |
+    +--------------------------------------------+--------------------------+
+    | etcd_prefix                                | path prefix (from etcd   |
+    |                                            | root), inside which all  |
+    |                                            | etcd nodes will be       |
+    |                                            | created (defaults to     |
+    |                                            | ``/torchelastic/p2p``)   |
+    +--------------------------------------------+--------------------------+
+    """
+
+    def __init__(self, rdzv_impl):
+        self._rdzv_impl = rdzv_impl
+
+    def __del__(self):
+        # TODO: look into using weakref here instead.
+        del self._rdzv_impl
+
+    def get_backend(self) -> str:
+        return "etcd"
+
+    def next_rendezvous(self):
+        rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()
+
+        log.info("Creating EtcdStore as the c10d::Store implementation")
+        store = self._rdzv_impl.setup_kv_store(rdzv_version)
+
+        return store, rank, world_size
+
+    def is_closed(self):
+        try:
+            _, state = self._rdzv_impl.get_rdzv_state()
+            return state["status"] == "closed"
+        except etcd.EtcdKeyNotFound:
+            # No rendezvous state, so it cannot be closed.
+            return False
+
+    def set_closed(self):
+        self._rdzv_impl.set_closed()
+
+    def num_nodes_waiting(self):
+        try:
+            _, state = self._rdzv_impl.get_rdzv_state()
+            if state["status"] == "final":
+                return state["num_workers_waiting"]
+        except etcd.EtcdKeyNotFound:
+            pass
+        return 0
+
+    def get_run_id(self) -> str:
+        return self._rdzv_impl._run_id
+
+    def shutdown(self) -> bool:
+        try:
+            self.set_closed()
+            return True
+        except BaseException as e:
+            log.warning("Shutdown failed. Error occurred: %s", str(e))
+            return False
+
+
+# TODO: we should probably handle a few additional errors,
+# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are
+# only relevant for multi-node Etcd ensemble. A simple retry would work,
+# but is verbose to add everywhere. Consider wrapping the client calls
+# into auto-retry for these errors?
+#
+class EtcdRendezvous:
+    """A rendezvous implementation that uses `etcd `__ as the backend store."""
+
+    def __init__(
+        self,
+        client,
+        prefix,
+        run_id,
+        num_min_workers,
+        num_max_workers,
+        timeout,
+        last_call_timeout,
+    ):
+        self.client = client
+        log.info("Etcd machines: %s", self.client.machines)
+
+        self._prefix = prefix
+        self._run_id = run_id
+        self._num_min_workers = num_min_workers
+        self._num_max_workers = num_max_workers
+        self._timeout = timeout
+        self._last_call_timeout = last_call_timeout
+
+        # For cleaning up TTL refresher threads (for ephemeral keys)
+        self._lease_run_id_stop = None
+        self._lease_this_rank_stop = None
+
+        if not self._prefix.endswith("/"):
+            self._prefix += "/"
+
+        # Set up a permanent prefix dir, if it didn't exist
+        if self._prefix != "/":
+            self.create_path_if_not_exists(self._prefix)
+
+        # Lease a "sub-root" node specific to this job instance (run_id)
+        self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL)
+        self._lease_run_id_stop = self.setup_lease_renewal(
+            self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL
+        )
+
+        # Subdir for all rendezvous work
+        self.create_path_if_not_exists(self.get_path("/rdzv"))
+
+        # Create a rendezvous version counter, if it doesn't exist
+        try:
+            self.client.write(
+                key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False
+            )
+        except etcd.EtcdAlreadyExist:
+            pass
+
+    def __del__(self):
+        # TODO: look into using weakref here instead.
+        if self._lease_run_id_stop is not None:
+            self._lease_run_id_stop.set()
+
+        if self._lease_this_rank_stop is not None:
+            self._lease_this_rank_stop.set()
+
+    def rendezvous_barrier(self):
+        """
+        Main entry point for next rendezvous.
+
+        This method is blocking until rendezvous succeeds or a timeout occurs.
+
+        Returns:
+             ``(rdzv_version, rank, world_size)``
+
+        Raises:
+            RendezvousTimeoutError - timeout waiting for rendezvous
+            RendezvousClosedError - rendezvous is or was closed while waiting
+            RendezvousError - other persistent errors that
+             render the rendezvous non-retryable
+        """
+        self._rendezvous_deadline = time.time() + self._timeout
+        while True:
+            if time.time() > self._rendezvous_deadline:
+                raise RendezvousTimeoutError()
+
+            log.info("Attempting to join next rendezvous")
+            try:
+                # Disown our lease from the previous rendezvous, if it exists
+                if self._lease_this_rank_stop is not None:
+                    self._lease_this_rank_stop.set()
+
+                return self.init_phase()
+
+            except EtcdRendezvousRetryImmediately:
+                # The type of failure suggests we can retry without delay
+                pass
+
+            except EtcdRendezvousRetryableFailure:
+                # In case of retryable failure, wait a small delay
+                # to avoid spamming etcd
+                time.sleep(1)
+
+            except RendezvousTimeoutError:
+                log.info("Rendezvous timeout occurred in EtcdRendezvousHandler")
+                raise
+
+            except RendezvousClosedError:
+                log.info(
+                    "Rendezvous for run_id=%s was observed to be closed", self._run_id
+                )
+                raise
+
+            except RendezvousError:
+                raise
+
+            except Exception as e:
+                # In case of a general exception, wait a small delay
+                # to avoid spamming etcd
+                # FIXME: there are a few things that fall under this like
+                # etcd.EtcdKeyNotFound, etc., which could be handled more explicitly.
+                log.info("Rendezvous attempt failed, will retry. Reason: %s", e)
+                time.sleep(1)
+
+    def init_phase(self):
+        """
+        Initially, the rendezvous state is expected to be one of:
+
+        1. empty (non-existent) - in this case we try to create a new one.
+        2. joinable - we try to join it.
+        3. final - we announce ourselves as waiting, and go into monitoring mode
+
+        Any other state is considered transitional, and will be retried after
+        a short delay.
+
+        Returns:
+            ``(rdzv_version, rank, world_size)``
+
+        Raises:
+            RendezvousClosedError - current rendezvous was/is closed
+            EtcdRendezvousRetryableFailure - observed some intermediate
+             state, which is best handled by retrying later
+        """
+        try:
+            active_version = self.try_create_rendezvous()
+            state = json.loads(active_version.value)
+            log.info("New rendezvous state created: %s", state)
+        except etcd.EtcdAlreadyExist:
+            active_version, state = self.get_rdzv_state()
+            # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
+            # but this is ok for us - just means we'll restart from beginning.
+            log.info("Observed existing rendezvous state: %s", state)
+
+        if state["status"] == "closed":
+            raise RendezvousClosedError()
+
+        if state["status"] == "joinable":
+            return self.join_phase(state["version"])
+
+        if state["status"] == "final":
+            self.handle_existing_rendezvous(state["version"])
+            raise EtcdRendezvousRetryImmediately()
+
+        self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
+        raise EtcdRendezvousRetryableFailure()
+
+    def join_phase(self, expected_version):
+        """
+        We observed a rendezvous state in the 'joinable' state; we attempt to join
+        this particular version and then wait for all other peers to join.
+        """
+        # Failure to join will propagate an exception, causing a re-entry.
+        active_version, this_rank = self.join_rendezvous(expected_version)
+        state = json.loads(active_version.value)
+        log.info(
+            "Joined rendezvous version %s as rank %s. Full state: %s",
+            state["version"], this_rank, state
+        )
+
+        # If this worker was first to reach num_min_workers requirement,
+        # and rendezvous is still joinable (therefore it is elastic),
+        # then this worker will be responsible for waiting out the "last call"
+        # timeout and closing (i.e. transitioning to 'frozen') the rendezvous
+        # afterwards.
+        # As a safety against a potential failure of this worker (during the
+        # last call timeout), the rendezvous state is made ephemeral
+        # when num_min_workers is reached.
+
+        if this_rank == self._num_min_workers - 1 and state["status"] == "joinable":
+            log.info("Rank %s is responsible for join last call.", this_rank)
+            last_call_deadline = time.time() + self._last_call_timeout
+            self.handle_join_last_call(expected_version, last_call_deadline)
+            log.info("Rank %s finished join last call.", this_rank)
+
+        # Wait for rendezvous state to be frozen, which means a fixed set of peers
+        log.info("Waiting for remaining peers.")
+        active_version = self.wait_for_peers(expected_version)
+        state = json.loads(active_version.value)
+
+        assert (
+            state["version"] == expected_version
+        ), "Logic error: failed to observe version mismatch"
+
+        return self.confirm_phase(expected_version, this_rank)
+
+    def confirm_phase(self, expected_version, this_rank):
+        """
+        Once the rendezvous state transitions from 'joinable' to 'frozen',
+        we have every participant confirm their membership and set up per-member
+        keep-alive TTL keys, and then wait for all other participants to confirm,
+        which would then successfully conclude this rendezvous.
+        """
+        log.info("All peers arrived. Confirming membership.")
+        self.confirm_membership(expected_version, this_rank)
+
+        log.info("Waiting for confirmations from all peers.")
+        active_version = self.wait_for_final(expected_version)
+        state = json.loads(active_version.value)
+
+        log.info(
+            "Rendezvous version %s is complete. Final state: %s",
+            state["version"], state
+        )
+
+        # Rendezvous version number; our rank in it; world size
+        return state["version"], this_rank, len(state["participants"])
+
+    def handle_existing_rendezvous(self, expected_version):
+        """
+        Handle the case when there's an existing (state 'final') rendezvous already
+        in place, and we have to announce ourselves waiting, and wait until
+        the next rendezvous opportunity.
+        """
+        # If state is 'final' -> increment num_workers_waiting
+        # Then, observe state changes:
+        #   1. if it's no longer final -> bail out and re-try
+        #   2. if keep alives are missing, destroy it and bail out.
+        active_state = self.announce_self_waiting(expected_version)
+        log.info(
+            "Added self to waiting list. Rendezvous full state: %s",
+            active_state.value
+        )
+
+        self.wait_for_rendezvous_to_free(expected_version)
+        log.info("Previously existing rendezvous state changed. Will re-try joining.")
+
+    def try_create_rendezvous(self):
+        """
+        Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists).
+
+        Raises:
+             RendezvousError - on unexpected state
+        """
+        # Initially active_version is ephemeral - this is to handle the
+        # possibility that we might fail to complete the setup transaction,
+        # i.e. the transition "setup" -> "joinable".
+        active_version = self.client.write(
+            key=self.get_path("/rdzv/active_version"),
+            value=json.dumps({"status": "setup"}),
+            prevExist=False,
+            ttl=CONST_ETCD_SETUP_TTL,
+        )
+
+        try:
+            version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
+            version_counter.value = str(int(version_counter.value) + 1)
+            self.client.update(version_counter)
+        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
+            raise RendezvousError(
+                "Unexpected state of EtcdRendezvousHandler, worker needs to die."
+            ) from e
+
+        # Any failure below results in declaring a retryable rendezvous failure.
+        # The ephemeral /rdzv/active_version will expire and someone can then
+        # re-try the setup process.
+
+        # Create directory node for participant data
+        self.client.write(
+            key=self.get_path(f"/rdzv/v_{version_counter.value}"),
+            value=None,
+            dir=True,
+            prevExist=False,
+        )
+
+        # Publish rendezvous version and signal it is ready-to-be-joined.
+        # If rendezvous was set closed just before this, a retry will happen,
+        # where the closed condition will be handled.
+        return self.client.test_and_set(
+            key=self.get_path("/rdzv/active_version"),
+            value=json.dumps(
+                {
+                    "status": "joinable",
+                    "version": version_counter.value,
+                    "participants": [],
+                }
+            ),
+            prev_value=active_version.value,
+        )
+
+    def join_rendezvous(self, expected_version):
+        """Helper method for the join phase."""
+        # Use compare-and-swap to add self to rendezvous state:
+        while True:
+            cas_delay()
+            active_version, state = self.get_rdzv_state()
+
+            if state["status"] != "joinable":
+                raise EtcdRendezvousRetryableFailure(
+                    "Rendezvous state became non-joinable before we could join. "
+                    "Must join next one."
+                )
+
+            if state["version"] != expected_version:
+                raise EtcdRendezvousRetryImmediately(
+                    "Rendezvous version changed. Must try join the new one."
+                )
+
+            assert (
+                len(state["participants"]) < self._num_max_workers
+            ), "Logic error: joinable rendezvous should always have space left"
+
+            this_rank = len(state["participants"])
+            state["participants"].append(this_rank)
+
+            # When reaching min workers, or changing state to frozen, we'll set
+            # the active_version node to be ephemeral.
+            set_ttl: Optional[int] = None
+            if len(state["participants"]) == self._num_max_workers:
+                state["status"] = "frozen"
+                state["keep_alives"] = []
+                set_ttl = CONST_ETCD_FROZEN_TTL
+            elif len(state["participants"]) >= self._num_min_workers:
+                set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL
+
+            try:
+                # Compare-and-swap.
+                active_version = self.client.test_and_set(
+                    key=self.get_path("/rdzv/active_version"),
+                    value=json.dumps(state),
+                    prev_value=active_version.value,
+                    ttl=set_ttl,
+                )
+                # We succeeded joining.
+                return active_version, this_rank
+
+            except etcd.EtcdCompareFailed:
+                log.info("Join rendezvous CAS unsuccessful, retrying")
+
+    def wait_for_peers(self, expected_version):
+        """Helper method for the join phase."""
+        active_version, state = self.get_rdzv_state()
+        while True:
+            if state["status"] == "frozen" and state["version"] == expected_version:
+                # Success, all peers arrived.
+                return active_version
+
+            elif state["status"] == "joinable" and state["version"] == expected_version:
+                # Continue waiting for any interesting events.
+                active_version, state = self.try_wait_for_state_change(
+                    etcd_index=active_version.etcd_index + 1
+                )
+
+            else:
+                # No valid transition possible at this point
+                raise EtcdRendezvousRetryableFailure(
+                    "Rendezvous state transition no longer possible. Must re-enter."
+                )
+
+    def confirm_membership(self, expected_version, this_rank):
+        """Helper method for the confirm phase."""
+        # Compare-and-swap loop
+        while True:
+            cas_delay()
+            active_version, state = self.get_rdzv_state()
+
+            if state["status"] != "frozen":
+                raise EtcdRendezvousRetryImmediately(
+                    "Rendezvous no longer frozen, before we confirmed. "
+                    "Must join next one"
+                )
+            if state["version"] != expected_version:
+                raise EtcdRendezvousRetryImmediately(
+                    "Rendezvous version changed. Must try join the new one."
+                )
+
+            this_lease_key = self.get_path(
+                f"/rdzv/v_{expected_version}/rank_{this_rank}"
+            )
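+            # Publish an ephemeral keep-alive node for this rank. Its lease is
+            # renewed by the thread started via setup_lease_renewal below;
+            # wait_for_rendezvous_to_free treats a missing node as a dead worker.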
+            self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL)
+
+            state["keep_alives"].append(this_lease_key)
+            if len(state["keep_alives"]) == len(state["participants"]):
+                # Everyone confirmed (this rank is last to do so)
+                state["status"] = "final"
+                state["num_workers_waiting"] = 0
+                finalize = True
+            else:
+                finalize = False
+
+            try:
+                # Compare-and-swap. If new state is still frozen, keep it ephemeral.
+                active_version = self.client.test_and_set(
+                    key=self.get_path("/rdzv/active_version"),
+                    value=json.dumps(state),
+                    prev_value=active_version.value,
+                    ttl=None if finalize else CONST_ETCD_FROZEN_TTL,
+                )
+
+                self._lease_this_rank_stop = self.setup_lease_renewal(
+                    this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL
+                )
+                return active_version
+
+            except etcd.EtcdCompareFailed:
+                log.info("Confirm membership CAS unsuccessful, retrying")
+
+    def wait_for_final(self, expected_version):
+        """Helper method for the confirm phase."""
+        active_version, state = self.get_rdzv_state()
+        while True:
+            if state["status"] == "final" and state["version"] == expected_version:
+                # Success. This rendezvous is final, and we accept it.
+                return active_version
+
+            elif state["status"] == "frozen" and state["version"] == expected_version:
+                # Continue waiting for any interesting events.
+                active_version, state = self.try_wait_for_state_change(
+                    etcd_index=active_version.etcd_index + 1
+                )
+
+            else:
+                # No valid transition possible at this point
+                raise EtcdRendezvousRetryableFailure(
+                    "Rendezvous state transition no longer possible. Must re-enter."
+                )
+
+    def announce_self_waiting(self, expected_version):
+        """
+        Announce this worker is waiting (via num_workers_waiting counter) to join next
+        rendezvous, but only if state and version match.
+        """
+        while True:
+            cas_delay()
+            active_version, state = self.get_rdzv_state()
+
+            if state["status"] != "final" or state["version"] != expected_version:
+                raise EtcdRendezvousRetryImmediately()
+
+            # Increment counter to signal an additional waiting worker.
+            state["num_workers_waiting"] += 1
+
+            try:
+                active_version = self.client.test_and_set(
+                    key=self.get_path("/rdzv/active_version"),
+                    value=json.dumps(state),
+                    prev_value=active_version.value,
+                )
+                return active_version
+
+            except etcd.EtcdCompareFailed:
+                log.info("Announce self as waiting CAS unsuccessful, retrying")
+
+    def wait_for_rendezvous_to_free(self, expected_version):
+        """
+        When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join.
+
+        Such opportunity may come from:
+
+        1. rendezvous state changed by someone else, in which case we unblock and retry.
+        2. rendezvous becomes invalid because at least one member failed to renew their
+           leased keep_alive node. We detect this, and destroy the rendezvous.
+        """
+        active_version, state = self.get_rdzv_state()
+        while True:
+            if state["status"] != "final" or state["version"] != expected_version:
+                return
+
+            # Check if current rendezvous state is valid, in the sense that all
+            # its members are alive (renewing their lease).
+            # If not, try destroy this rendezvous, so a new one can be created.
+            alive_members = self.client.get(
+                self.get_path(f"/rdzv/v_{expected_version}")
+            )
+            keep_alive_keys = [ch.key for ch in alive_members.children]
+
+            for key in state["keep_alives"]:
+                if key not in keep_alive_keys:
+                    # This participant didn't renew their lease. We'll declare this
+                    # rendezvous version as dead (but only if it hadn't changed)
+                    log.info("Keep-alive key %s is not renewed.", key)
+                    log.info(
+                        "Rendezvous version %s is incomplete. ",
+                        expected_version
+                    )
+                    log.info("Attempting to destroy it.")
+
+                    # Compare-and-delete operation. Throws if compare failed,
+                    # which means rendezvous was already destroyed/re-created/closed,
+                    # and we can try to re-enter the barrier.
+                    self.client.delete(
+                        key=self.get_path("/rdzv/active_version"),
+                        prevValue=active_version.value,
+                    )
+
+                    log.info(
+                        "Destroyed rendezvous version %s successfully.",
+                        expected_version
+                    )
+
+                    # We can return (and retry) immediately
+                    return
+
+            # Existing rendezvous seems valid, no reason to destroy it.
+            # We just have to wait until something changes and re-check.
+            try:
+                overall_timeout = (
+                    max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
+                )
+                self.client.watch(
+                    key=self.get_path("/rdzv"),
+                    index=active_version.etcd_index + 1,
+                    recursive=True,
+                    timeout=overall_timeout,
+                )
+            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
+                pass
+
+            if time.time() > self._rendezvous_deadline:
+                raise RendezvousTimeoutError()
+            active_version, state = self.get_rdzv_state()
+
+    def handle_join_last_call(self, expected_version, deadline):
+        """
+        After we reach min number of workers, one particular worker takes on the
+        responsibility of waiting an additional timeout before closing the join window.
+        If the worker responsible for this fails, the rendezvous will be destroyed due
+        to expiring TTL, and the other participants will re-rendezvous.
+
+        Here we expect to see state <joinable, expected_version>.
+        Exit gracefully if either:
+
+        1. state becomes <frozen, expected_version>
+        2. timeout happens (reaching deadline), in which case
+           we try the transition to <frozen, expected_version>
+
+        Exit with exception otherwise.
+        """
+        active_version, state = self.get_rdzv_state()
+        while True:
+            if state["status"] == "frozen" and state["version"] == expected_version:
+                # Worker set became frozen before last-call timeout. This is possible
+                # when num_max_workers is reached before the timeout.
+                return
+
+            if state["status"] != "joinable" or state["version"] != expected_version:
+                raise EtcdRendezvousRetryableFailure(
+                    "Rendezvous state transition no longer possible. Must re-enter."
+                )
+
+            # If timeout occurred, attempt a state transition (joinable -> frozen)
+            if time.time() >= deadline:
+                state["status"] = "frozen"
+                state["keep_alives"] = []
+                try:
+                    active_version = self.client.test_and_set(
+                        key=self.get_path("/rdzv/active_version"),
+                        value=json.dumps(state),
+                        prev_value=active_version.value,
+                        ttl=CONST_ETCD_FROZEN_TTL,
+                    )
+                    # We successfully made this rendezvous frozen.
+                    return
+                except etcd.EtcdCompareFailed:
+                    log.info("Join last-call transition CAS unsuccessful. Will retry")
+                    cas_delay()
+                    active_version, state = self.get_rdzv_state()
+                    continue
+
+            # Timeout did not occur, so we must refresh TTL, and wait for
+            # further changes. Note: we only want TTL to be refreshed if
+            # state is still joinable, hence we use CAS for that here,
+            # even though we don't change any of the data.
+            try:
+                active_version = self.client.test_and_set(
+                    key=self.get_path("/rdzv/active_version"),
+                    value=active_version.value,
+                    prev_value=active_version.value,
+                    ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL,
+                )
+
+                # Minimize "oversleeping":
+                timeout = min(
+                    CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2,
+                    deadline - time.time() + 1.0,  # Oversleeping by 1s is ok.
+                )
+                active_version, state = self.try_wait_for_state_change(
+                    etcd_index=active_version.etcd_index + 1, timeout=timeout
+                )
+            except etcd.EtcdCompareFailed:
+                log.info("Join last-call TTL refresh CAS unsuccessful, will retry")
+                cas_delay()
+                active_version, state = self.get_rdzv_state()
+
+    def set_closed(self):
+        """
+        Mark rendezvous 'closed' for current run_id, which is used to signal other
+        participants to not attempt to perform (re-)rendezvous. This is useful
+        when one of the workers decides the job is complete.
+        """
+        while True:
+            active_version, state = self.get_rdzv_state()
+
+            if state["status"] == "closed":
+                # Already closed by someone else.
+                return
+
+            state["status"] = "closed"
+            try:
+                self.client.test_and_set(
+                    key=self.get_path("/rdzv/active_version"),
+                    value=json.dumps(state),
+                    prev_value=active_version.value,
+                )
+                return
+
+            except etcd.EtcdCompareFailed:
+                log.info("Set closed CAS unsuccessful, retrying")
+                cas_delay()
+
+    def get_rdzv_state(self):
+        active_version = self.client.get(key=self.get_path("/rdzv/active_version"))
+        return active_version, json.loads(active_version.value)
+
+    def try_wait_for_state_change(self, etcd_index, timeout=None):
+        # Don't sleep past the overall deadline (at least more than by 1s)
+        overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
+        timeout = overall_timeout if timeout is None else min(timeout, overall_timeout)
+
+        try:
+            self.client.watch(
+                self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout
+            )
+        except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
+            pass
+
+        if time.time() > self._rendezvous_deadline:
+            raise RendezvousTimeoutError()
+
+        # Unfortunately, we have to do another fetch in order to get last etcd_index.
+        return self.get_rdzv_state()
+
+    def get_path(self, path):
+        if not path.startswith("/"):
+            path = "/" + path
+
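+        # e.g. get_path("/rdzv/active_version") -> "<prefix>run_<run_id>/rdzv/active_version"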
+        return f"{self._prefix}run_{self._run_id}{path}"
+
+    def create_path_if_not_exists(self, full_path, ttl=None):
+        try:
+            self.client.write(
+                key=full_path, value=None, dir=True, prevExist=False, ttl=ttl
+            )
+        except etcd.EtcdAlreadyExist:
+            pass
+
+    def setup_lease_renewal(self, full_path, ttl):
+        # NOTE: For ephemeral key TTL renewal (~lease) to work correctly,
+        # make sure you don't call any long-blocking methods that do not
+        # release the Python's GIL! An example of this is calling a pybind11
+        # extension function that is blocking / long-running, but is not
+        # doing a scoped release of the GIL.
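+        # The worker below refreshes the key's TTL roughly every ttl/2 seconds and
+        # exits once the key disappears or the returned stop event is set.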
+        def lease_worker(client, path, ttl, stop_event):
+            while True:
+                try:
+                    client.refresh(path, ttl=ttl)
+                except etcd.EtcdKeyNotFound:
+                    break
+                except ConnectionRefusedError:
+                    # This error usually occurs during tests, when the server has already been
+                    # terminated but the Python garbage collector has not yet invoked __del__.
+                    break
+
+                if stop_event.wait(timeout=ttl / 2):
+                    break
+
+        lease_stop_event = threading.Event()
+        lease_thread = threading.Thread(
+            target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event)
+        )
+
+        lease_thread.daemon = True
+        lease_thread.start()
+
+        return lease_stop_event
+
+    def store_extra_data(self, rdzv_version, key, value):
+        node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
+        try:
+            # If this is the first time we are storing anything:
+            extra_data = self.client.write(
+                key=node, value=json.dumps({key: value}), prevExist=False
+            )
+            return
+        except etcd.EtcdAlreadyExist:
+            pass
+
+        # CAS loop, to make sure we don't lose concurrent stores.
+        while True:
+            # We never delete extra_data. Failure here should be fatal, no special handling.
+            extra_data = self.client.get(node)
+
+            new_extra_data_value = json.loads(extra_data.value)
+            new_extra_data_value[key] = value
+
+            try:
+                extra_data = self.client.test_and_set(
+                    key=node,
+                    value=json.dumps(new_extra_data_value),
+                    prev_value=extra_data.value,
+                )
+                return
+            except etcd.EtcdCompareFailed:
+                log.info("Store extra_data CAS unsuccessful, retrying")
+                time.sleep(0.1)
+
+    def load_extra_data(self, rdzv_version, key, timeout=None):
+        # 'extra_data' node itself, and the directory it is located in:
+        node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
+        node_dir = self.get_path(f"/rdzv/v_{rdzv_version}")
+
+        # TODO: implement timeout
+        # https://github.com/pytorch/elastic/issues/12
+        while True:
+            # Combined wait for the node itself, and the key inside it.
+            root = self.client.get(node_dir)
+
+            # Find the extra_data node, if it exists
+            extra_data = [n for n in root.children if n.key == node]
+            assert len(extra_data) <= 1
+
+            # Node for extra_data exists, check the desired key inside it.
+            if len(extra_data) == 1:
+                extra_data_dict = json.loads(extra_data[0].value)
+                if key in extra_data_dict:
+                    return extra_data_dict[key]
+
+            # The 'extra_data' node doesn't exist, or the key isn't published yet.
+            # Wait for interesting events on the extra_data node and retry.
+            try:
+                self.client.watch(node, index=root.etcd_index + 1)
+            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
+                pass
+
+    def setup_kv_store(self, rdzv_version):
+        store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv")
+        self.create_path_if_not_exists(store_path)
+        return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path)
+
+
+def _create_etcd_client(params: RendezvousParameters) -> etcd.Client:
+    """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``."""
+    hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379)
+
+    # The communication protocol
+    protocol = params.config.get("protocol")
+    if protocol is None:
+        protocol = "http"
+    else:
+        if protocol != "http" and protocol != "https":
+            raise ValueError("The etcd protocol must be HTTP or HTTPS.")
+
+    # The SSL client certificate
+    ssl_cert = params.config.get("cert")
+    if ssl_cert is not None:
+        cert_key = params.config.get("key")
+        if cert_key is not None:
+            # The etcd client expects the certificate key as the second element
+            # of the `cert` tuple.
+            ssl_cert = (ssl_cert, cert_key)
+
+    # The root certificate
+    ca_cert = params.config.get("cacert")
+
+    return etcd.Client(
+        hostname,
+        port,
+        protocol=protocol,
+        cert=ssl_cert,
+        ca_cert=ca_cert,
+        allow_reconnect=True,
+    )
+
+
+# Handler for torch.distributed "static" registration
+def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
+    """
+    Usage:
+
+    ::
+
+    rdzv_params = RendezvousParameters(
+                        backend="etcd",
+                        endpoint="192.168.0.42:2379",
+                        run_id="123",
+                        min_nodes=4,
+                        max_nodes=8,
+                        timeout=300,
+                        last_call_timeout=30,
+                        etcd_prefix="custom_prefix",
+                        protocol="https",
+                        cacert="/etc/kubernetes/certs/ca.crt",
+                        cert="/etc/kubernetes/certs/client.crt",
+                        key="/etc/kubernetes/certs/client.key")
+    # -- or --
+    rdzv_params = RendezvousParameters(
+                        backend="etcd",
+                        endpoint="192.168.0.42:2379",
+                        run_id="123",
+                        min_nodes=4,
+                        max_nodes=8)
+
+    etcd_rdzv_handler = create_rdzv_handler(rdzv_params)
+
+
+    Where:
+        run_id - unique id for this training job instance,
+        min_nodes - min number of workers expected to join the rendezvous,
+        max_nodes - max number of workers allowed to join the rendezvous,
+                        defaults to min_nodes if not specified.
+        timeout - total timeout within which next_rendezvous is expected to
+                      succeed; a RendezvousTimeoutError is raised otherwise.
+                      Default is 600 (10 minutes).
+        last_call_timeout - additional wait amount ("last call") after
+                            min number of workers has been reached.
+                            Defaults to 30 seconds.
+        etcd_prefix - path prefix (from etcd root), inside which all
+                      etcd nodes will be created.
+                      Default is "/torchelastic/p2p".
+        protocol - http (default) or https to access etcd.
+        cacert - CA cert to access etcd, only makes sense with https.
+        cert - client cert to access etcd, only makes sense with https.
+        key - client key to access etcd, only makes sense with https.
+    """
+    client = _create_etcd_client(params)
+
+    etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p")
+
+    rdzv = EtcdRendezvous(
+        client=client,
+        prefix=etcd_prefix,
+        run_id=params.run_id,
+        num_min_workers=params.min_nodes,
+        num_max_workers=params.max_nodes,
+        timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
+        last_call_timeout=params.get_as_int("last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT),
+    )
+    return EtcdRendezvousHandler(rdzv_impl=rdzv)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..40d1501bed8820c12d3c824c812e277e20a20c0b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py
@@ -0,0 +1,213 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import binascii
+from base64 import b64decode, b64encode
+from typing import Optional, Tuple, cast
+
+import urllib3.exceptions  # type: ignore[import]
+from etcd import Client as EtcdClient  # type: ignore[import]
+from etcd import (
+    EtcdAlreadyExist,
+    EtcdCompareFailed,
+    EtcdException,
+    EtcdKeyNotFound,
+    EtcdResult,
+)
+from torch.distributed import Store
+
+from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
+from .dynamic_rendezvous import RendezvousBackend, Token
+from .etcd_store import EtcdStore
+from .utils import parse_rendezvous_endpoint
+
+
+class EtcdRendezvousBackend(RendezvousBackend):
+    """Represents an etcd-based rendezvous backend.
+
+    Args:
+        client:
+            The ``etcd.Client`` instance to use to communicate with etcd.
+        run_id:
+            The run id of the rendezvous.
+        key_prefix:
+            The path under which to store the rendezvous state in etcd.
+        ttl:
+            The TTL of the rendezvous state. If not specified, defaults to two hours.
+    """
+
+    _DEFAULT_TTL = 7200  # 2 hours
+
+    _client: EtcdClient
+    _key: str
+    _ttl: int
+
+    def __init__(
+        self,
+        client: EtcdClient,
+        run_id: str,
+        key_prefix: Optional[str] = None,
+        ttl: Optional[int] = None,
+    ) -> None:
+        if not run_id:
+            raise ValueError("The run id must be a non-empty string.")
+
+        self._client = client
+
+        if key_prefix:
+            self._key = key_prefix + "/" + run_id
+        else:
+            self._key = run_id
+
+        if ttl and ttl > 0:
+            self._ttl = ttl
+        else:
+            self._ttl = self._DEFAULT_TTL
+
+    @property
+    def name(self) -> str:
+        """See base class."""
+        return "etcd-v2"
+
+    def get_state(self) -> Optional[Tuple[bytes, Token]]:
+        """See base class."""
+        try:
+            result = self._client.read(self._key)
+        except EtcdKeyNotFound:
+            return None
+        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
+            raise RendezvousConnectionError(
+                "The connection to etcd has failed. See inner exception for details."
+            ) from exc
+
+        return self._decode_state(result)
+
+    def set_state(
+        self, state: bytes, token: Optional[Token] = None
+    ) -> Optional[Tuple[bytes, Token, bool]]:
+        """See base class."""
+        base64_state = b64encode(state).decode()
+
+        kwargs = {}
+
+        def get_state():
+            result = self.get_state()
+            if result is not None:
+                tmp = *result, False
+                # Python 3.6 does not support tuple unpacking in return
+                # statements.
+                return tmp
+            return None
+
+        if token:
+            try:
+                token = int(token)
+            except ValueError:
+                return get_state()
+
+        if token:
+            kwargs["prevIndex"] = token
+        else:
+            kwargs["prevExist"] = False
+
+        try:
+            result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
+        except (EtcdAlreadyExist, EtcdCompareFailed):
+            result = None
+        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
+            raise RendezvousConnectionError(
+                "The connection to etcd has failed. See inner exception for details."
+            ) from exc
+
+        if result is None:
+            return get_state()
+
+        tmp = *self._decode_state(result), True
+        return tmp
+
+    def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
+        base64_state = result.value.encode()
+
+        try:
+            state = b64decode(base64_state)
+        except binascii.Error as exc:
+            raise RendezvousStateError(
+                "The state object is corrupt. See inner exception for details."
+            ) from exc
+
+        return state, result.modifiedIndex
+
+
+def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
+    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)
+
+    # The timeout
+    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
+    if read_timeout <= 0:
+        raise ValueError("The read timeout must be a positive integer.")
+
+    # The communication protocol
+    protocol = params.get("protocol", "http").strip().lower()
+    if protocol != "http" and protocol != "https":
+        raise ValueError("The protocol must be HTTP or HTTPS.")
+
+    # The SSL client certificate
+    ssl_cert = params.get("ssl_cert")
+    if ssl_cert:
+        ssl_cert_key = params.get("ssl_cert_key")
+        if ssl_cert_key:
+            # The etcd client expects the certificate key as the second element
+            # of the `cert` tuple.
+            ssl_cert = (ssl_cert, ssl_cert_key)
+
+    # The root certificate
+    ca_cert = params.get("ca_cert")
+
+    try:
+        return EtcdClient(
+            host,
+            port,
+            read_timeout=read_timeout,
+            protocol=protocol,
+            cert=ssl_cert,
+            ca_cert=ca_cert,
+            allow_reconnect=True,
+        )
+    except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
+        raise RendezvousConnectionError(
+            "The connection to etcd has failed. See inner exception for details."
+        ) from exc
+
+
+def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
+    """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.
+
+    +--------------+-----------------------------------------------------------+
+    | Parameter    | Description                                               |
+    +==============+===========================================================+
+    | read_timeout | The read timeout, in seconds, for etcd operations.        |
+    |              | Defaults to 60 seconds.                                   |
+    +--------------+-----------------------------------------------------------+
+    | protocol     | The protocol to use to communicate with etcd. Valid       |
+    |              | values are "http" and "https". Defaults to "http".        |
+    +--------------+-----------------------------------------------------------+
+    | ssl_cert     | The path to the SSL client certificate to use along with  |
+    |              | HTTPS. Defaults to ``None``.                              |
+    +--------------+-----------------------------------------------------------+
+    | ssl_cert_key | The path to the private key of the SSL client certificate |
+    |              | to use along with HTTPS. Defaults to ``None``.            |
+    +--------------+-----------------------------------------------------------+
+    | ca_cert      | The path to the root SSL authority certificate. Defaults  |
+    |              | to ``None``.                                              |
+    +--------------+-----------------------------------------------------------+
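+
+    Illustrative example (the endpoint and certificate paths are placeholders,
+    mirroring the handler docstring above)::
+
+        params = RendezvousParameters(
+            backend="etcd-v2",
+            endpoint="192.168.0.42:2379",
+            run_id="123",
+            min_nodes=4,
+            max_nodes=8,
+            protocol="https",
+            ssl_cert="/etc/kubernetes/certs/client.crt",
+            ssl_cert_key="/etc/kubernetes/certs/client.key",
+            ca_cert="/etc/kubernetes/certs/ca.crt",
+        )
+        backend, store = create_backend(params)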
+    """
+    client = _create_etcd_client(params)
+
+    backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
+
+    store = EtcdStore(client, "/torch/elastic/store")
+
+    return backend, store
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2cd5afb93a53eff8a0ba8537fe7ccfa8840585b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_server.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import atexit
+import logging
+import os
+import shlex
+import shutil
+import socket
+import subprocess
+import tempfile
+import time
+from typing import Optional, TextIO, Union
+
+try:
+    import etcd  # type: ignore[import]
+except ModuleNotFoundError:
+    pass
+
+
+log = logging.getLogger(__name__)
+
+
+def find_free_port():
+    """
+    Find a free port and bind a temporary socket to it so that the port can be "reserved" until used.
+
+    .. note:: the returned socket must be closed before using the port,
+              otherwise an ``address already in use`` error will happen.
+              The socket should be held and closed as close to the
+              consumer of the port as possible since otherwise, there
+              is a greater chance of race-condition where a different
+              process may see the port as being free and take it.
+
+    Returns: a socket bound to the reserved free port
+
+    Usage::
+
+    sock = find_free_port()
+    port = sock.getsockname()[1]
+    sock.close()
+    use_port(port)
+    """
+    addrs = socket.getaddrinfo(
+        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
+    )
+
+    for addr in addrs:
+        family, type, proto, _, _ = addr
+        try:
+            s = socket.socket(family, type, proto)
+            s.bind(("localhost", 0))
+            s.listen(0)
+            return s
+        except OSError as e:
+            s.close()  # type: ignore[possibly-undefined]
+            print(f"Socket creation attempt failed: {e}")
+    raise RuntimeError("Failed to create a socket")
+
+
+def stop_etcd(subprocess, data_dir: Optional[str] = None):
+    if subprocess and subprocess.poll() is None:
+        log.info("stopping etcd server")
+        subprocess.terminate()
+        subprocess.wait()
+
+    if data_dir:
+        log.info("deleting etcd data dir: %s", data_dir)
+        shutil.rmtree(data_dir, ignore_errors=True)
+
+
+class EtcdServer:
+    """
+    .. note:: tested on etcd server v3.4.3.
+
+    Starts and stops a local standalone etcd server on a random free
+    port. Useful for single node, multi-worker launches or testing,
+    where a sidecar etcd server is more convenient than having to
+    separately set up an etcd server.
+
+    This class registers a termination handler to shut down the etcd
+    subprocess on exit. This termination handler is NOT a substitute for
+    calling the ``stop()`` method.
+
+    The following fallback mechanism is used to find the etcd binary:
+
+    1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
+    2. Uses ``<this file root>/bin/etcd`` if one exists
+    3. Uses ``etcd`` from ``PATH``
+
+    Usage
+    ::
+
+     server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd")
+     server.start()
+     client = server.get_client()
+     # use client
+     server.stop()
+
+    Args:
+        data_dir: optional base data directory for etcd; a temporary directory is created if not specified
+    """
+
+    def __init__(self, data_dir: Optional[str] = None):
+        self._port = -1
+        self._host = "localhost"
+
+        root = os.path.dirname(__file__)
+        default_etcd_bin = os.path.join(root, "bin/etcd")
+        self._etcd_binary_path = os.environ.get(
+            "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
+        )
+        if not os.path.isfile(self._etcd_binary_path):
+            self._etcd_binary_path = "etcd"
+
+        self._base_data_dir = (
+            data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
+        )
+        self._etcd_cmd = None
+        self._etcd_proc: Optional[subprocess.Popen] = None
+
+    def _get_etcd_server_process(self) -> subprocess.Popen:
+        if not self._etcd_proc:
+            raise RuntimeError(
+                "No etcd server process started. Call etcd_server.start() first"
+            )
+        else:
+            return self._etcd_proc
+
+    def get_port(self) -> int:
+        """Return the port the server is running on."""
+        return self._port
+
+    def get_host(self) -> str:
+        """Return the host the server is running on."""
+        return self._host
+
+    def get_endpoint(self) -> str:
+        """Return the etcd server endpoint (host:port)."""
+        return f"{self._host}:{self._port}"
+
+    def start(
+        self,
+        timeout: int = 60,
+        num_retries: int = 3,
+        stderr: Union[int, TextIO, None] = None,
+    ) -> None:
+        """
+        Start the server and wait for it to be ready. When this function returns, the server is ready to take requests.
+
+        Args:
+            timeout: time (in seconds) to wait for the server to be ready
+                before giving up.
+            num_retries: number of retries to start the server. Each retry
+                will wait for at most ``timeout`` seconds before being considered failed.
+            stderr: the standard error file handle. Valid values are
+                `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
+                descriptor (a positive integer), an existing file object, and
+                `None`.
+
+        Raises:
+            TimeoutError: if the server is not ready within the specified timeout
+        """
+        curr_retries = 0
+        while True:
+            try:
+                data_dir = os.path.join(self._base_data_dir, str(curr_retries))
+                os.makedirs(data_dir, exist_ok=True)
+                self._start(data_dir, timeout, stderr)
+                # Ensure the etcd subprocess and its data dir are cleaned up on
+                # exit; register only after a successful start.
+                atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
+                return
+            except Exception as e:
+                curr_retries += 1
+                stop_etcd(self._etcd_proc)
+                log.warning(
+                    "Failed to start etcd server, got error: %s, retrying", str(e)
+                )
+                if curr_retries >= num_retries:
+                    shutil.rmtree(self._base_data_dir, ignore_errors=True)
+                    raise
+
+    def _start(
+        self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
+    ) -> None:
+        sock = find_free_port()
+        sock_peer = find_free_port()
+        self._port = sock.getsockname()[1]
+        peer_port = sock_peer.getsockname()[1]
+
+        etcd_cmd = shlex.split(
+            " ".join(
+                [
+                    self._etcd_binary_path,
+                    "--enable-v2",
+                    "--data-dir",
+                    data_dir,
+                    "--listen-client-urls",
+                    f"http://{self._host}:{self._port}",
+                    "--advertise-client-urls",
+                    f"http://{self._host}:{self._port}",
+                    "--listen-peer-urls",
+                    f"http://{self._host}:{peer_port}",
+                ]
+            )
+        )
+
+        log.info("Starting etcd server: [%s]", etcd_cmd)
+
+        sock.close()
+        sock_peer.close()
+        self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
+        self._wait_for_ready(timeout)
+
+    def get_client(self):
+        """Return an etcd client object that can be used to make requests to this server."""
+        return etcd.Client(
+            host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
+        )
+
+    def _wait_for_ready(self, timeout: int = 60) -> None:
+        client = etcd.Client(
+            host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
+        )
+        max_time = time.time() + timeout
+
+        while time.time() < max_time:
+            if self._get_etcd_server_process().poll() is not None:
+                # etcd server process finished
+                exitcode = self._get_etcd_server_process().returncode
+                raise RuntimeError(
+                    f"Etcd server process exited with the code: {exitcode}"
+                )
+            try:
+                log.info("etcd server ready. version: %s", client.version)
+                return
+            except Exception:
+                time.sleep(1)
+        raise TimeoutError("Timed out waiting for etcd server to be ready!")
+
+    def stop(self) -> None:
+        """Stop the server and cleans up auto generated resources (e.g. data dir)."""
+        log.info("EtcdServer stop method called")
+        stop_etcd(self._etcd_proc, self._base_data_dir)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d34ac7cfa25a82f527c3b7b4cde45baa3542482
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/etcd_store.py
@@ -0,0 +1,204 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import datetime
+import random
+import time
+from base64 import b64decode, b64encode
+from typing import Optional
+
+import etcd  # type: ignore[import]
+
+# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
+from torch.distributed import Store
+
+
+# Delay (sleep) for a small random amount to reduce CAS failures.
+# This does not affect correctness, but will reduce requests to etcd server.
+def cas_delay():
+    time.sleep(random.uniform(0, 0.1))
+
+
+# pyre-fixme[11]: Annotation `Store` is not defined as a type.
+class EtcdStore(Store):
+    """
+    Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.
+
+    This is the store object returned by ``EtcdRendezvous``.
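+
+    Illustrative usage sketch (``client`` is assumed to be an ``etcd.Client``
+    connected to the rendezvous etcd instance)::
+
+        store = EtcdStore(etcd_client=client, etcd_store_prefix="/some/prefix")
+        store.set("key", b"value")
+        data = store.get("key")  # bytes; blocks until published or timeout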
+    """
+
+    def __init__(
+        self,
+        etcd_client,
+        etcd_store_prefix,
+        # Default timeout same as in c10d/Store.hpp
+        timeout: Optional[datetime.timedelta] = None,
+    ):
+        super().__init__()  # required for pybind trampoline.
+
+        self.client = etcd_client
+        self.prefix = etcd_store_prefix
+
+        if timeout is not None:
+            self.set_timeout(timeout)
+
+        if not self.prefix.endswith("/"):
+            self.prefix += "/"
+
+    def set(self, key, value):
+        """
+        Write a key/value pair into ``EtcdStore``.
+
+        Both key and value may be either Python ``str`` or ``bytes``.
+        """
+        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))
+
+    def get(self, key) -> bytes:
+        """
+        Get a value by key, possibly doing a blocking wait.
+
+        If the key is not immediately present, this does a blocking wait
+        for at most ``timeout`` duration or until the key is published.
+
+        Returns:
+            value ``(bytes)``
+
+        Raises:
+            LookupError - If the key is still not published after the timeout
+        """
+        b64_key = self.prefix + self._encode(key)
+        kvs = self._try_wait_get([b64_key])
+
+        if kvs is None:
+            raise LookupError(f"Key {key} not found in EtcdStore")
+
+        return self._decode(kvs[b64_key])
+
+    def add(self, key, num: int) -> int:
+        """
+        Atomically increment a value by an integer amount.
+
+        The integer is represented as a string using base 10. If key is not present,
+        a default value of ``0`` will be assumed.
+
+        Returns:
+             the new (incremented) value
+
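+        Example (illustrative)::
+
+            store.add("counter", 1)   # key absent: created as "1", returns 1
+            store.add("counter", 41)  # CAS-updated to "42", returns 42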
+        """
+        b64_key = self._encode(key)
+        # c10d Store assumes value is an integer represented as a decimal string
+        try:
+            # Assume default value "0", if this key didn't yet:
+            node = self.client.write(
+                key=self.prefix + b64_key,
+                value=self._encode(str(num)),  # i.e. 0 + num
+                prevExist=False,
+            )
+            return int(self._decode(node.value))
+        except etcd.EtcdAlreadyExist:
+            pass
+
+        while True:
+            # Note: c10d Store does not have a method to delete keys, so we
+            # can be sure it's still there.
+            node = self.client.get(key=self.prefix + b64_key)
+            new_value = self._encode(str(int(self._decode(node.value)) + num))
+            try:
+                node = self.client.test_and_set(
+                    key=node.key, value=new_value, prev_value=node.value
+                )
+                return int(self._decode(node.value))
+            except etcd.EtcdCompareFailed:
+                cas_delay()
+
+    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
+        """
+        Wait until all of the keys are published, or until timeout.
+
+        Raises:
+            LookupError - if timeout occurs
+        """
+        b64_keys = [self.prefix + self._encode(key) for key in keys]
+        kvs = self._try_wait_get(b64_keys, override_timeout)
+        if kvs is None:
+            raise LookupError("Timeout while waiting for keys in EtcdStore")
+        # No return value on success
+
+    def check(self, keys) -> bool:
+        """Check if all of the keys are immediately present (without waiting)."""
+        b64_keys = [self.prefix + self._encode(key) for key in keys]
+        kvs = self._try_wait_get(
+            b64_keys,
+            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
+        )
+        return kvs is not None
+
+    #
+    # Encode key/value data in base64, so we can store arbitrary binary data
+    # in EtcdStore. Input can be `str` or `bytes`.
+    # In case of `str`, utf-8 encoding is assumed.
+    #
+    def _encode(self, value) -> str:
+        if type(value) == bytes:
+            return b64encode(value).decode()
+        elif type(value) == str:
+            return b64encode(value.encode()).decode()
+        raise ValueError("Value must be of type str or bytes")
+
+    #
+    # Decode a base64 string (of type `str` or `bytes`).
+    # Return type is `bytes`, which is more convenient with the Store interface.
+    #
+    def _decode(self, value) -> bytes:
+        if type(value) == bytes:
+            return b64decode(value)
+        elif type(value) == str:
+            return b64decode(value.encode())
+        raise ValueError("Value must be of type str or bytes")
+
+    #
+    # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
+    # are published or timeout occurs.
+    # This is a helper method for the public interface methods.
+    #
+    # On success, a dictionary of {etcd key -> etcd value} is returned.
+    # On timeout, None is returned.
+    #
+    def _try_wait_get(self, b64_keys, override_timeout=None):
+        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
+        deadline = time.time() + timeout.total_seconds()
+
+        while True:
+            # Read whole directory (of keys), filter only the ones waited for
+            all_nodes = self.client.get(key=self.prefix)
+            req_nodes = {
+                node.key: node.value for node in all_nodes.children if node.key in b64_keys
+            }
+
+            if len(req_nodes) == len(b64_keys):
+                # All keys are available
+                return req_nodes
+
+            watch_timeout = deadline - time.time()
+            if watch_timeout <= 0:
+                return None
+
+            try:
+                self.client.watch(
+                    key=self.prefix,
+                    recursive=True,
+                    timeout=watch_timeout,
+                    index=all_nodes.etcd_index + 1,
+                )
+            except etcd.EtcdWatchTimedOut:
+                if time.time() >= deadline:
+                    return None
+                else:
+                    continue
+            except etcd.EtcdEventIndexCleared:
+                continue
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0cb8e421ef2508e1db799083bcbb36863631bb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/registry.py
@@ -0,0 +1,66 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .api import RendezvousHandler, RendezvousParameters
+from .api import rendezvous_handler_registry as handler_registry
+from .dynamic_rendezvous import create_handler
+
+__all__ = ['get_rendezvous_handler']
+
+def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
+    from . import static_tcp_rendezvous
+
+    return static_tcp_rendezvous.create_rdzv_handler(params)
+
+
+def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
+    from . import etcd_rendezvous
+
+    return etcd_rendezvous.create_rdzv_handler(params)
+
+
+def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
+    from .etcd_rendezvous_backend import create_backend
+
+    backend, store = create_backend(params)
+
+    return create_handler(store, backend, params)
+
+
+def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
+    from .c10d_rendezvous_backend import create_backend
+
+    backend, store = create_backend(params)
+
+    return create_handler(store, backend, params)
+
+
+def _register_default_handlers() -> None:
+    handler_registry.register("etcd", _create_etcd_handler)
+    handler_registry.register("etcd-v2", _create_etcd_v2_handler)
+    handler_registry.register("c10d", _create_c10d_handler)
+    handler_registry.register("static", _create_static_handler)
+
+
+def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
+    """
+    Obtain a reference to a :py:class:`RendezvousHandler`.
+
+    Custom rendezvous handlers can be registered by
+
+    ::
+
+      from torch.distributed.elastic.rendezvous import rendezvous_handler_registry
+      from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
+
+      def create_my_rdzv(params: RendezvousParameters):
+        return MyCustomRdzv(params)
+
+      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)
+
+      # my_rdzv_params is a RendezvousParameters instance with backend="my_rdzv_backend_name"
+      my_rdzv_handler = get_rendezvous_handler(my_rdzv_params)
+    """
+    return handler_registry.create_handler(params)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5f2523512e5b3ad696c4e4f538d69bdcabc34d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import datetime
+import logging
+from typing import Tuple, cast, Optional
+
+# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
+from torch.distributed import Store, TCPStore, PrefixStore
+from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
+from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
+
+log = logging.getLogger(__name__)
+
+_default_timeout_seconds = 600
+
+
+class StaticTCPRendezvous(RendezvousHandler):
+    """
+    Static rendezvous that is a wrapper around the TCPStore.
+
+    Creates TCPStore based on the input parameters with the
+    listener on the agent with group_rank=0
+    """
+
+    def __init__(
+        self,
+        master_addr: str,
+        master_port: int,
+        rank: int,
+        world_size: int,
+        run_id: str,
+        timeout: int,
+    ):
+        self.master_addr = master_addr
+        self.master_port = master_port
+        self.rank = rank
+        self.world_size = world_size
+        self.run_id = run_id
+        self.timeout = datetime.timedelta(seconds=timeout)
+        self._store: Optional[Store] = None
+
+    def get_backend(self) -> str:
+        return "static"
+
+    def next_rendezvous(self) -> Tuple[Store, int, int]:
+        log.info("Creating TCPStore as the c10d::Store implementation")
+        if not self._store:
+            is_master = self.rank == 0
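+            # Only the agent with group_rank 0 hosts the TCPStore listener;
+            # all other ranks connect to it as clients.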
+            self._store = TCPStore(  # type: ignore[call-arg]
+                self.master_addr,
+                self.master_port,
+                self.world_size,
+                is_master,
+                self.timeout,
+                multi_tenant=True,
+            )
+        store = PrefixStore(self.run_id, self._store)
+        return store, self.rank, self.world_size
+
+    def is_closed(self):
+        return False
+
+    def set_closed(self):
+        pass
+
+    def num_nodes_waiting(self):
+        return 0
+
+    def get_run_id(self) -> str:
+        return self.run_id
+
+    def shutdown(self) -> bool:
+        return True
+
+
+def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
+    if "rank" not in params.config:
+        raise ValueError(
+            "rank is absent in RendezvousParameters."
+            "Try add --node-rank to the cmd request"
+        )
+    endpoint = params.endpoint.strip()
+    if not endpoint:
+        raise ValueError(
+            "endpoint is absent in RendezvousParameters"
+            "Try add --master-port and --master-addr to the cmd request"
+        )
+    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
+    if master_port == -1:
+        raise ValueError(
+            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
+        )
+    world_size = params.max_nodes
+    rank = cast(int, params.config.get("rank"))
+    run_id = params.run_id
+    if "timeout" in params.config:
+        timeout = int(params.config["timeout"])
+    else:
+        timeout = _default_timeout_seconds
+    return StaticTCPRendezvous(
+        master_addr, master_port, rank, world_size, run_id, timeout
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fde1a3380f3bd693f6add3127dec191a5c58e31
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py
@@ -0,0 +1,279 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import ipaddress
+import random
+import re
+import socket
+import time
+import weakref
+from datetime import timedelta
+from threading import Event, Thread
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+__all__ = ['parse_rendezvous_endpoint']
+
+def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
+    """Extract key-value pairs from a rendezvous configuration string.
+
+    Args:
+        config_str:
+            A string in format <key1>=<value1>,...,<keyN>=<valueN>.
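+
+            For example (illustrative), "read_timeout=60,protocol=https" parses
+            to {"read_timeout": "60", "protocol": "https"}.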
+    """
+    config: Dict[str, str] = {}
+
+    config_str = config_str.strip()
+    if not config_str:
+        return config
+
+    key_values = config_str.split(",")
+    for kv in key_values:
+        key, *values = kv.split("=", 1)
+
+        key = key.strip()
+        if not key:
+            raise ValueError(
+                "The rendezvous configuration string must be in format "
+                "=,...,=."
+            )
+
+        value: Optional[str]
+        if values:
+            value = values[0].strip()
+        else:
+            value = None
+        if not value:
+            raise ValueError(
+                f"The rendezvous configuration option '{key}' must have a value specified."
+            )
+
+        config[key] = value
+    return config
+
+
+def _try_parse_port(port_str: str) -> Optional[int]:
+    """Try to extract the port number from ``port_str``."""
+    if port_str and re.match(r"^[0-9]{1,5}$", port_str):
+        return int(port_str)
+    return None
+
+
+def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
+    """Extract the hostname and the port number from a rendezvous endpoint.
+
+    Args:
+        endpoint:
+            A string in format <host>[:<port>].
+        default_port:
+            The port number to use if the endpoint does not include one.
+
+    Returns:
+        A tuple of hostname and port number.
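+
+    For example (illustrative), "192.168.0.42:2379" yields ("192.168.0.42", 2379),
+    while an empty endpoint yields ("localhost", default_port).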
+    """
+    if endpoint is not None:
+        endpoint = endpoint.strip()
+
+    if not endpoint:
+        return ("localhost", default_port)
+
+    # An endpoint that starts and ends with brackets represents an IPv6 address.
+    if endpoint[0] == "[" and endpoint[-1] == "]":
+        host, *rest = endpoint, *[]
+    else:
+        host, *rest = endpoint.rsplit(":", 1)
+
+    # Sanitize the IPv6 address.
+    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
+        host = host[1:-1]
+
+    if len(rest) == 1:
+        port = _try_parse_port(rest[0])
+        if port is None or port >= 2 ** 16:
+            raise ValueError(
+                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
+                "between 0 and 65536."
+            )
+    else:
+        port = default_port
+
+    if not re.match(r"^[\w\.:-]+$", host):
+        raise ValueError(
+            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
+            "labels, an IPv4 address, or an IPv6 address."
+        )
+
+    return host, port
+
+
+def _matches_machine_hostname(host: str) -> bool:
+    """Indicate whether ``host`` matches the hostname of this machine.
+
+    This function compares ``host`` to the hostname as well as to the IP
+    addresses of this machine. Note that it may return a false negative if this
+    machine has CNAME records beyond its FQDN or IP addresses assigned to
+    secondary NICs.
+    """
+    if host == "localhost":
+        return True
+
+    try:
+        addr = ipaddress.ip_address(host)
+    except ValueError:
+        addr = None
+
+    if addr and addr.is_loopback:
+        return True
+
+    try:
+        host_addr_list = socket.getaddrinfo(
+            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
+        )
+    except (ValueError, socket.gaierror) as _:
+        host_addr_list = []
+
+    host_ip_list = [
+        host_addr_info[4][0]
+        for host_addr_info in host_addr_list
+    ]
+
+    this_host = socket.gethostname()
+    if host == this_host:
+        return True
+
+    addr_list = socket.getaddrinfo(
+        this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
+    )
+    for addr_info in addr_list:
+        # If we have an FQDN in the addr_info, compare it to `host`.
+        if addr_info[3] and addr_info[3] == host:
+            return True
+
+        # Otherwise if `host` represents an IP address, compare it to our IP
+        # address.
+        if addr and addr_info[4][0] == str(addr):
+            return True
+
+        # If the IP address matches one of the provided host's IP addresses
+        if addr_info[4][0] in host_ip_list:
+            return True
+
+    return False
+
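+# Editor's note: illustrative behavior of the check above; the loopback and
+# own-hostname cases shown here are deterministic, other results depend on
+# this machine's DNS records:
+#
+#     >>> _matches_machine_hostname("localhost")
+#     True
+#     >>> _matches_machine_hostname("127.0.0.1")
+#     True
+#     >>> _matches_machine_hostname(socket.gethostname())
+#     True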
+
+def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
+    """Suspend the current thread for ``seconds``.
+
+    Args:
+        seconds:
+            Either the delay, in seconds, or a tuple of a lower and an upper
+            bound within which a random delay will be picked.
+    """
+    if isinstance(seconds, tuple):
+        seconds = random.uniform(*seconds)
+    # Ignore delay requests that are less than 10 milliseconds.
+    if seconds >= 0.01:
+        time.sleep(seconds)
+
+
+class _PeriodicTimer:
+    """Represent a timer that periodically runs a specified function.
+
+    Args:
+        interval:
+            The interval, in seconds, between each run.
+        function:
+            The function to run.
+    """
+
+    # The state of the timer is held in a separate context object to avoid a
+    # reference cycle between the timer and the background thread.
+    class _Context:
+        interval: float
+        function: Callable[..., None]
+        args: Tuple[Any, ...]
+        kwargs: Dict[str, Any]
+        stop_event: Event
+
+    _name: Optional[str]
+    _thread: Optional[Thread]
+    _finalizer: Optional[weakref.finalize]
+
+    # The context that is shared between the timer and the background thread.
+    _ctx: _Context
+
+    def __init__(
+        self,
+        interval: timedelta,
+        function: Callable[..., None],
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        self._name = None
+
+        self._ctx = self._Context()
+        self._ctx.interval = interval.total_seconds()
+        self._ctx.function = function  # type: ignore[assignment]
+        self._ctx.args = args or ()
+        self._ctx.kwargs = kwargs or {}
+        self._ctx.stop_event = Event()
+
+        self._thread = None
+        self._finalizer = None
+
+    @property
+    def name(self) -> Optional[str]:
+        """Get the name of the timer."""
+        return self._name
+
+    def set_name(self, name: str) -> None:
+        """Set the name of the timer.
+
+        The specified name will be assigned to the background thread and serves
+        for debugging and troubleshooting purposes.
+        """
+        if self._thread:
+            raise RuntimeError("The timer has already started.")
+
+        self._name = name
+
+    def start(self) -> None:
+        """Start the timer."""
+        if self._thread:
+            raise RuntimeError("The timer has already started.")
+
+        self._thread = Thread(
+            target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True
+        )
+
+        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
+        # timer as joining a daemon thread during the interpreter shutdown can
+        # cause deadlocks. The weakref.finalize is a superior alternative that
+        # provides a consistent behavior regardless of the GC implementation.
+        self._finalizer = weakref.finalize(
+            self, self._stop_thread, self._thread, self._ctx.stop_event
+        )
+
+        # We do not attempt to stop our background thread during the interpreter
+        # shutdown. At that point we do not even know whether it still exists.
+        self._finalizer.atexit = False
+
+        self._thread.start()
+
+    def cancel(self) -> None:
+        """Stop the timer at the next opportunity."""
+        if self._finalizer:
+            self._finalizer()
+
+    @staticmethod
+    def _run(ctx) -> None:
+        while not ctx.stop_event.wait(ctx.interval):
+            ctx.function(*ctx.args, **ctx.kwargs)
+
+    @staticmethod
+    def _stop_thread(thread, stop_event):
+        stop_event.set()
+
+        thread.join()
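+
+# Editor's note: an illustrative sketch of how _PeriodicTimer is meant to be
+# driven (the keep-alive callback below is hypothetical, not upstream code):
+#
+#     def _send_keep_alive():
+#         ...  # e.g. refresh this node's entry in the rendezvous backend
+#
+#     timer = _PeriodicTimer(timedelta(seconds=5), _send_keep_alive)
+#     timer.set_name("RendezvousKeepAliveTimer")
+#     timer.start()
+#     ...
+#     timer.cancel()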
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b170b8d4444a1a0825a5268dfe278451170843a3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__init__.py
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Expiration timers are set up in the same process as the agent and
+used from your script to deal with stuck workers. When you go into
+a code-block that has the potential to get stuck you can acquire
+an expiration timer, which instructs the timer server to kill the
+process if it does not release the timer by the self-imposed expiration
+deadline.
+
+Usage::
+
+    import torchelastic.timer as timer
+    import torchelastic.agent.server as agent
+
+    def main():
+        start_method = "spawn"
+        message_queue = mp.get_context(start_method).Queue()
+        server = timer.LocalTimerServer(message_queue, max_interval=0.01)
+        server.start() # non-blocking
+
+        spec = WorkerSpec(
+                    fn=trainer_func,
+                    args=(message_queue,),
+                    ...)
+        agent = agent.LocalElasticAgent(spec, start_method)
+        agent.run()
+
+    def trainer_func(message_queue):
+        timer.configure(timer.LocalTimerClient(message_queue))
+        with timer.expires(after=60): # 60 second expiry
+            # do some work
+
+In the example above if ``trainer_func`` takes more than 60 seconds to
+complete, then the worker process is killed and the agent retries the worker group.
+"""
+
+from .api import TimerClient, TimerRequest, TimerServer, configure, expires  # noqa: F401
+from .local_timer import LocalTimerClient, LocalTimerServer  # noqa: F401
+from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest  # noqa: F401
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a71171a3bb7edce1e3b93e4d06b99a961886de7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8815e326a441a17653d9a1c6f1aef2975c2c103c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e6b9851207520532eb2e8ef30832a5606a025fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1582f48c44d1a904ae4a92f675dc41f5caf5a70a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..50a430ffe0360f173d969feae4d4774dd82ffd26
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/api.py
@@ -0,0 +1,280 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import abc
+import logging
+import threading
+import time
+from contextlib import contextmanager
+from inspect import getframeinfo, stack
+from typing import Any, Dict, List, Optional, Set
+
+__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']
+
+log = logging.getLogger(__name__)
+
+class TimerRequest:
+    """
+    Data object representing a countdown timer acquisition and release
+    that is used between the ``TimerClient`` and ``TimerServer``.
+    A negative ``expiration_time`` should be interpreted as a "release"
+    request.
+
+    .. note:: the type of ``worker_id`` is implementation specific.
+              It is whatever the TimerServer and TimerClient implementations
+              agree on to uniquely identify a worker.
+    """
+
+    __slots__ = ["worker_id", "scope_id", "expiration_time"]
+
+    def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
+        self.worker_id = worker_id
+        self.scope_id = scope_id
+        self.expiration_time = expiration_time
+
+    def __eq__(self, other):
+        if isinstance(other, TimerRequest):
+            return (
+                self.worker_id == other.worker_id
+                and self.scope_id == other.scope_id
+                and self.expiration_time == other.expiration_time
+            )
+        return False
+
+
+class TimerClient(abc.ABC):
+    """
+    Client library to acquire and release countdown timers by communicating
+    with the TimerServer.
+    """
+
+    @abc.abstractmethod
+    def acquire(self, scope_id: str, expiration_time: float) -> None:
+        """
+        Acquires a timer for the worker that holds this client object
+        given the scope_id and expiration_time. Typically registers
+        the timer with the TimerServer.
+        """
+        pass
+
+    @abc.abstractmethod
+    def release(self, scope_id: str):
+        """
+        Releases the timer for the ``scope_id`` on the worker this
+        client represents. After this method is
+        called, the countdown timer on the scope is no longer in effect.
+        """
+        pass
+
+
+class RequestQueue(abc.ABC):
+    """
+    Consumer queue holding timer acquisition/release requests
+    """
+
+    @abc.abstractmethod
+    def size(self) -> int:
+        """
+        Returns the size of the queue at the time this method is called.
+        Note that by the time ``get`` is called the size of the queue
+        may have increased. The size of the queue should not decrease
+        until the ``get`` method is called. That is, the following assertion
+        should hold:
+
+        size = q.size()
+        res = q.get(size, timeout=0)
+        assert size == len(res)
+
+        -- or --
+
+        size = q.size()
+        res = q.get(size * 2, timeout=1)
+        assert size <= len(res) <= size * 2
+        """
+        pass
+
+    @abc.abstractmethod
+    def get(self, size: int, timeout: float) -> List[TimerRequest]:
+        """
+        Gets up to ``size`` timer requests in a blocking fashion
+        (blocks for no more than ``timeout`` seconds).
+        """
+        pass
+
+
+class TimerServer(abc.ABC):
+    """
+    Entity that monitors active timers and expires them
+    in a timely fashion. This server is responsible for
+    reaping workers that have expired timers.
+    """
+
+    def __init__(
+        self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
+    ):
+        """
+        :param request_queue: Consumer ``RequestQueue``
+        :param max_interval: max time (in seconds) to wait
+                             for an item in the request_queue
+        :param daemon: whether to run the watchdog thread as a daemon
+        """
+        super().__init__()
+        self._request_queue = request_queue
+        self._max_interval = max_interval
+        self._daemon = daemon
+        self._watchdog_thread: Optional[threading.Thread] = None
+        self._stop_signaled = False
+
+    @abc.abstractmethod
+    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
+        """
+        Processes the incoming timer requests and registers them with the server.
+        The timer request can either be an acquire-timer or release-timer request.
+        Timer requests with a negative expiration_time should be interpreted
+        as a release-timer request.
+        """
+        pass
+
+    @abc.abstractmethod
+    def clear_timers(self, worker_ids: Set[Any]) -> None:
+        """
+        Clears all timers for the given ``worker_ids``.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
+        """
+        Returns all expired timers for each worker_id. An expired timer
+        is a timer for which the expiration_time is less than or equal to
+        the provided deadline.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _reap_worker(self, worker_id: Any) -> bool:
+        """
+        Reaps the given worker. Returns True if the worker has been
+        successfully reaped, False otherwise. If any uncaught exception
+        is thrown from this method, the worker is considered reaped
+        and all associated timers will be removed.
+        """
+
+    def _reap_worker_no_throw(self, worker_id: Any) -> bool:
+        """
+        Wraps ``_reap_worker(worker_id)``; if an uncaught exception is
+        thrown, the worker is considered reaped.
+        """
+        try:
+            return self._reap_worker(worker_id)
+        except Exception:
+            log.exception(
+                "Uncaught exception thrown from _reap_worker(), "
+                "check that the implementation correctly catches exceptions",
+            )
+            return True
+
+    def _watchdog_loop(self):
+        while not self._stop_signaled:
+            try:
+                self._run_watchdog()
+            except Exception:
+                log.exception("Error running watchdog")
+
+    def _run_watchdog(self):
+        batch_size = max(1, self._request_queue.size())
+        timer_requests = self._request_queue.get(batch_size, self._max_interval)
+        self.register_timers(timer_requests)
+        now = time.time()
+        reaped_worker_ids = set()
+        for worker_id, expired_timers in self.get_expired_timers(now).items():
+            log.info(
+                "Reaping worker_id=[%s]."
+                " Expired timers: %s",
+                worker_id, self._get_scopes(expired_timers)
+            )
+            if self._reap_worker_no_throw(worker_id):
+                log.info("Successfully reaped worker=[%s]", worker_id)
+                reaped_worker_ids.add(worker_id)
+            else:
+                log.error(
+                    "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
+                )
+        self.clear_timers(reaped_worker_ids)
+
+    def _get_scopes(self, timer_requests):
+        return [r.scope_id for r in timer_requests]
+
+    def start(self) -> None:
+        log.info(
+            "Starting %s..."
+            " max_interval=%s,"
+            " daemon=%s",
+            type(self).__name__, self._max_interval, self._daemon
+        )
+        self._watchdog_thread = threading.Thread(
+            target=self._watchdog_loop, daemon=self._daemon
+        )
+        log.info("Starting watchdog thread...")
+        self._watchdog_thread.start()
+
+    def stop(self) -> None:
+        log.info("Stopping %s", type(self).__name__)
+        self._stop_signaled = True
+        if self._watchdog_thread:
+            log.info("Stopping watchdog thread...")
+            self._watchdog_thread.join(self._max_interval)
+            self._watchdog_thread = None
+        else:
+            log.info("No watchdog thread running, doing nothing")
+
+
+_timer_client: Optional[TimerClient] = None
+
+
+def configure(timer_client: TimerClient):
+    """
+    Configures a timer client. Must be called before using ``expires``.
+    """
+    global _timer_client
+    _timer_client = timer_client
+    log.info("Timer client configured to: %s", type(_timer_client).__name__)
+
+
+@contextmanager
+def expires(
+    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
+):
+    """
+    Acquires a countdown timer that expires in ``after`` seconds from now,
+    unless the code-block that it wraps is finished within the timeframe.
+    When the timer expires, this worker is eligible to be reaped. The
+    exact meaning of "reaped" depends on the client implementation. In
+    most cases, reaping means to terminate the worker process.
+    Note that the worker is NOT guaranteed to be reaped at exactly
+    ``time.now() + after``, but rather the worker is "eligible" for being
+    reaped and the ``TimerServer`` that the client talks to will ultimately
+    make the decision when and how to reap the workers with expired timers.
+
+    Usage::
+
+        torch.distributed.elastic.timer.configure(LocalTimerClient())
+        with expires(after=10):
+            torch.distributed.all_reduce(...)
+    """
+    if client is None:
+        if _timer_client is None:
+            raise RuntimeError("Configure timer client before using countdown timers.")
+        client = _timer_client
+    if scope is None:
+        # grab the caller file + lineno
+        caller = getframeinfo(stack()[1][0])
+        scope = f"{caller.filename}#{caller.lineno}"
+    expiration = time.time() + after
+    client.acquire(scope, expiration)
+    try:
+        yield
+    finally:
+        client.release(scope)
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b299ece24162b3db1a34308eb339ad8a7a9c4c5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py
@@ -0,0 +1,333 @@
+# Copyright (c) Meta Platforms, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import json
+import logging
+import os
+import select
+import signal
+import sys
+import threading
+import time
+from typing import Callable, Dict, List, Optional, Set, Tuple
+
+from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
+
+__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
+
+log = logging.getLogger(__name__)
+
+class FileTimerRequest(TimerRequest):
+    """
+    Data object representing a countdown timer acquisition and release
+    that is used between the ``FileTimerClient`` and ``FileTimerServer``.
+    A negative ``expiration_time`` should be interpreted as a "release"
+    request.
+    ``signal`` is the signal to reap the worker process from the server
+    process.
+    """
+
+    __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]
+
+    def __init__(self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0) -> None:
+        self.version = 1
+        self.worker_pid = worker_pid
+        self.scope_id = scope_id
+        self.expiration_time = expiration_time
+        self.signal = signal
+
+    def __eq__(self, other) -> bool:
+        if isinstance(other, FileTimerRequest):
+            return (
+                self.version == other.version
+                and self.worker_pid == other.worker_pid
+                and self.scope_id == other.scope_id
+                and self.expiration_time == other.expiration_time
+                and self.signal == other.signal
+            )
+        return False
+
+    def to_json(self) -> str:
+        return json.dumps(
+            {
+                "version": self.version,
+                "pid": self.worker_pid,
+                "scope_id": self.scope_id,
+                "expiration_time": self.expiration_time,
+                "signal": self.signal
+            },
+        )
+
+
+class FileTimerClient(TimerClient):
+    """
+    Client side of ``FileTimerServer``. This client is meant to be used
+    on the same host that the ``FileTimerServer`` is running on and uses
+    pid to uniquely identify a worker.
+    This client uses a named_pipe to send timer requests to the
+    ``FileTimerServer``. This client is a producer while the
+    ``FileTimerServer`` is a consumer. Multiple clients can work with
+    the same ``FileTimerServer``.
+
+    Args:
+
+        file_path: str, the path of a FIFO special file. ``FileTimerServer``
+                        must have created it by calling os.mkfifo().
+
+        signal: signal, the signal to use to kill the process. Using a
+                        negative or zero signal will not kill the process.
+    """
+    def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else
+                                               signal.CTRL_C_EVENT)) -> None:  # type: ignore[attr-defined]
+        super().__init__()
+        self._file_path = file_path
+        self.signal = signal
+
+    def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
+        try:
+            fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
+            return os.fdopen(fd, "wt")
+        except Exception:
+            return None
+
+    def _send_request(self, request: FileTimerRequest) -> None:
+        # The server may have crashed or may not have started yet.
+        # In that case, calling open() in blocking mode would block the client.
+        # To avoid this, open the pipe in non-blocking mode; the open fails
+        # immediately (and a BrokenPipeError is raised below) if the server is not there.
+        file = self._open_non_blocking()
+        if file is None:
+            raise BrokenPipeError("Could not send the FileTimerRequest because FileTimerServer is not available.")
+        with file:
+            json_request = request.to_json()
+            # Writes of no more than select.PIPE_BUF bytes are guaranteed to be atomic.
+            if len(json_request) > select.PIPE_BUF:
+                raise RuntimeError(
+                    f"FileTimerRequest larger than {select.PIPE_BUF} bytes "
+                    f"is not supported: {json_request}"
+                )
+            file.write(json_request + "\n")
+
+    def acquire(self, scope_id: str, expiration_time: float) -> None:
+        self._send_request(
+            request=FileTimerRequest(
+                worker_pid=os.getpid(),
+                scope_id=scope_id,
+                expiration_time=expiration_time,
+                signal=self.signal
+            ),
+        )
+
+    def release(self, scope_id: str) -> None:
+        self._send_request(
+            request=FileTimerRequest(
+                worker_pid=os.getpid(),
+                scope_id=scope_id,
+                expiration_time=-1,
+                signal=0
+            ),
+        )
+
+
+class FileTimerServer:
+    """
+    Server that works with ``FileTimerClient``. Clients are expected to be
+    running on the same host as the process that is running this server.
+    Each host in the job is expected to start its own timer server locally
+    and each server instance manages timers for local workers (running on
+    processes on the same host).
+
+    Args:
+
+        file_path: str, the path of a FIFO special file to be created.
+
+        max_interval: float, max interval in seconds for each watchdog loop.
+
+        daemon: bool, running the watchdog thread in daemon mode or not.
+                      A daemon thread will not prevent the process from stopping.
+        log_event: Callable[[Dict[str, str]], None], an optional callback for
+                logging the events in JSON format.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        max_interval: float = 10,
+        daemon: bool = True,
+        log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
+    ) -> None:
+        self._file_path = file_path
+        self._max_interval = max_interval
+        self._daemon = daemon
+        self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
+        self._stop_signaled = False
+        self._watchdog_thread: Optional[threading.Thread] = None
+        if os.path.exists(self._file_path):
+            os.remove(self._file_path)
+        os.mkfifo(self._file_path)
+        # For test only. Count the number of requests received.
+        self._request_count = 0
+        # For test only. Process all requests and stop the server.
+        self._run_once = False
+        self._log_event = log_event if log_event is not None else lambda name, request: None
+
+
+    def start(self) -> None:
+        log.info(
+            "Starting %s..."
+            " max_interval=%s,"
+            " daemon=%s",
+            type(self).__name__, self._max_interval, self._daemon
+        )
+        self._watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=self._daemon)
+        log.info("Starting watchdog thread...")
+        self._watchdog_thread.start()
+        self._log_event("watchdog started", None)
+
+    def stop(self) -> None:
+        log.info("Stopping %s", type(self).__name__)
+        self._stop_signaled = True
+        if self._watchdog_thread:
+            log.info("Stopping watchdog thread...")
+            self._watchdog_thread.join(self._max_interval)
+            self._watchdog_thread = None
+        else:
+            log.info("No watchdog thread running, doing nothing")
+        if os.path.exists(self._file_path):
+            os.remove(self._file_path)
+        self._log_event("watchdog stopped", None)
+
+    def run_once(self) -> None:
+        self._run_once = True
+        if self._watchdog_thread:
+            log.info("Stopping watchdog thread...")
+            self._watchdog_thread.join()
+            self._watchdog_thread = None
+        else:
+            log.info("No watchdog thread running, doing nothing")
+        if os.path.exists(self._file_path):
+            os.remove(self._file_path)
+
+    def _watchdog_loop(self) -> None:
+        # Opening the pipe in blocking mode blocks the server thread until a
+        # client connects. This is fine for the following reasons:
+        #  1. The no-client case usually does not happen.
+        #  2. The watchdog loop runs in a separate daemon thread, which will
+        #     not prevent the process from stopping.
+        with open(self._file_path) as fd:
+            while not self._stop_signaled:
+                try:
+                    run_once = self._run_once
+                    self._run_watchdog(fd)
+                    if run_once:
+                        break
+                except Exception:
+                    log.exception("Error running watchdog")
+
+    def _run_watchdog(self, fd: io.TextIOWrapper) -> None:
+        timer_requests = self._get_requests(fd, self._max_interval)
+        self.register_timers(timer_requests)
+        now = time.time()
+        reaped_worker_pids = set()
+        for worker_pid, expired_timers in self.get_expired_timers(now).items():
+            log.info("Reaping worker_pid=[%s]. Expired timers: %s", worker_pid, self._get_scopes(expired_timers))
+            reaped_worker_pids.add(worker_pid)
+            # In case we have multiple expired timers, we find the first timer
+            # with a valid signal (>0) in the expiration time order.
+            expired_timers.sort(key=lambda timer: timer.expiration_time)
+            signal = 0
+            expired_timer = None
+            for timer in expired_timers:
+                self._log_event("timer expired", timer)
+                if timer.signal > 0:
+                    signal = timer.signal
+                    expired_timer = timer
+                    break
+            if signal <= 0:
+                log.info("No signal specified with worker=[%s]. Do not reap it.", worker_pid)
+                continue
+            if self._reap_worker(worker_pid, signal):
+                log.info("Successfully reaped worker=[%s] with signal=%s", worker_pid, signal)
+                self._log_event("kill worker process", expired_timer)
+            else:
+                log.error("Error reaping worker=[%s]. Will retry on next watchdog.", worker_pid)
+        self.clear_timers(reaped_worker_pids)
+
+    def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
+        return [r.scope_id for r in timer_requests]
+
+    def _get_requests(self, fd: io.TextIOWrapper, max_interval: float) -> List[FileTimerRequest]:
+        start = time.time()
+        requests = []
+        while not self._stop_signaled or self._run_once:
+            # For a named pipe, readline() blocks while at least one writer has
+            # the pipe open, and returns only after the writer calls flush()
+            # (flush() is called automatically inside close()).
+            # After the last writer closes the pipe, readline() no longer blocks
+            # and returns an empty string at end-of-file.
+            # Since the client always opens the pipe, writes a message, and closes
+            # it immediately, the readline() call below does not block for long.
+            json_request = fd.readline()
+            if len(json_request) == 0:
+                if self._run_once:
+                    break
+                time.sleep(min(max_interval, 1))
+            else:
+                request = json.loads(json_request)
+                pid = request["pid"]
+                scope_id = request["scope_id"]
+                expiration_time = request["expiration_time"]
+                signal = request["signal"]
+                requests.append(
+                    FileTimerRequest(
+                        worker_pid=pid, scope_id=scope_id, expiration_time=expiration_time, signal=signal
+                    )
+                )
+            now = time.time()
+            if now - start > max_interval:
+                break
+        return requests
+
+    def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
+        for request in timer_requests:
+            pid = request.worker_pid
+            scope_id = request.scope_id
+            expiration_time = request.expiration_time
+            self._request_count += 1
+
+            key = (pid, scope_id)
+            # negative expiration is a proxy for a release call
+            if expiration_time < 0:
+                if key in self._timers:
+                    del self._timers[key]
+            else:
+                self._timers[key] = request
+
+    def clear_timers(self, worker_pids: Set[int]) -> None:
+        for (pid, scope_id) in list(self._timers.keys()):
+            if pid in worker_pids:
+                del self._timers[(pid, scope_id)]
+
+    def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
+        # pid -> [timer_requests...]
+        expired_timers: Dict[int, List[FileTimerRequest]] = {}
+        for request in self._timers.values():
+            if request.expiration_time <= deadline:
+                expired_scopes = expired_timers.setdefault(request.worker_pid, [])
+                expired_scopes.append(request)
+        return expired_timers
+
+    def _reap_worker(self, worker_pid: int, signal: int) -> bool:
+        try:
+            os.kill(worker_pid, signal)
+            return True
+        except ProcessLookupError:
+            log.info("Process with pid=%s does not exist. Skipping", worker_pid)
+            return True
+        except Exception:
+            log.exception("Error terminating pid=%s", worker_pid)
+        return False
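+
+# Editor's note: an illustrative wiring sketch for the file-based timer (the
+# pipe path and scope name are hypothetical; workers may also use the
+# higher-level torch.distributed.elastic.timer.configure/expires API):
+#
+#     server = FileTimerServer("/tmp/elastic_watchdog_pipe", max_interval=0.1)
+#     server.start()  # non-blocking; the watchdog runs in a daemon thread
+#
+#     # in each worker process on the same host:
+#     client = FileTimerClient("/tmp/elastic_watchdog_pipe")
+#     client.acquire("train_step", expiration_time=time.time() + 60)
+#     ...  # work that might hang
+#     client.release("train_step")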
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..62af765ec8ace4aa55ecae9b80f3dccfb7fbbf31
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+import multiprocessing as mp
+import os
+import signal
+import time
+from queue import Empty
+from typing import Any, Dict, List, Set, Tuple
+
+from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
+
+__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']
+
+log = logging.getLogger(__name__)
+
+class LocalTimerClient(TimerClient):
+    """
+    Client side of ``LocalTimerServer``. This client is meant to be used
+    on the same host that the ``LocalTimerServer`` is running on and uses
+    pid to uniquely identify a worker. This is particularly useful in situations
+    where one spawns a subprocess (trainer) per GPU on a host with multiple
+    GPU devices.
+    """
+
+    def __init__(self, mp_queue):
+        super().__init__()
+        self._mp_queue = mp_queue
+
+    def acquire(self, scope_id, expiration_time):
+        pid = os.getpid()
+        acquire_request = TimerRequest(pid, scope_id, expiration_time)
+        self._mp_queue.put(acquire_request)
+
+    def release(self, scope_id):
+        pid = os.getpid()
+        release_request = TimerRequest(pid, scope_id, -1)
+        self._mp_queue.put(release_request)
+
+
+class MultiprocessingRequestQueue(RequestQueue):
+    """
+    A ``RequestQueue`` backed by python ``multiprocessing.Queue``
+    """
+
+    def __init__(self, mp_queue: mp.Queue):
+        super().__init__()
+        self._mp_queue = mp_queue
+
+    def size(self) -> int:
+        return self._mp_queue.qsize()
+
+    def get(self, size, timeout: float) -> List[TimerRequest]:
+        requests = []
+        wait = timeout
+        for _ in range(0, size):
+            start = time.time()
+
+            try:
+                r = self._mp_queue.get(block=True, timeout=wait)
+            except Empty:
+                break
+
+            requests.append(r)
+            wait = wait - (time.time() - start)
+            if wait <= 0:
+                break
+
+        return requests
+
+
+class LocalTimerServer(TimerServer):
+    """
+    Server that works with ``LocalTimerClient``. Clients are expected to be
+    subprocesses to the parent process that is running this server. Each host
+    in the job is expected to start its own timer server locally and each
+    server instance manages timers for local workers (running on processes
+    on the same host).
+    """
+
+    def __init__(
+        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
+    ):
+        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
+        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}
+
+    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
+        for request in timer_requests:
+            pid = request.worker_id
+            scope_id = request.scope_id
+            expiration_time = request.expiration_time
+
+            # negative expiration is a proxy for a release call
+            if expiration_time < 0:
+                self._timers.pop((pid, scope_id), None)
+            else:
+                self._timers[(pid, scope_id)] = request
+
+    def clear_timers(self, worker_ids: Set[int]) -> None:
+        for (pid, scope_id) in list(self._timers.keys()):
+            if pid in worker_ids:
+                self._timers.pop((pid, scope_id))
+
+    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
+        # pid -> [timer_requests...]
+        expired_timers: Dict[Any, List[TimerRequest]] = {}
+        for request in self._timers.values():
+            if request.expiration_time <= deadline:
+                expired_scopes = expired_timers.setdefault(request.worker_id, [])
+                expired_scopes.append(request)
+        return expired_timers
+
+    def _reap_worker(self, worker_id: int) -> bool:
+        try:
+            os.kill(worker_id, signal.SIGKILL)
+            return True
+        except ProcessLookupError:
+            log.info("Process with pid=%s does not exist. Skipping", worker_id)
+            return True
+        except Exception:
+            log.exception("Error terminating pid=%s", worker_id)
+        return False
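+
+# Editor's note: an illustrative sketch pairing the server with workers spawned
+# via multiprocessing (start method, interval, and scope name are hypothetical):
+#
+#     mp_queue = mp.get_context("spawn").Queue()
+#     server = LocalTimerServer(mp_queue, max_interval=0.01)
+#     server.start()
+#
+#     # inside each spawned worker process:
+#     client = LocalTimerClient(mp_queue)
+#     client.acquire("train_step", expiration_time=time.time() + 60)
+#     ...  # work that might hang
+#     client.release("train_step")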
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fbc76bf70244c273d84c617a96dfc9827f1ae70
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .api import get_env_variable_or_raise, get_socket_with_port, macros  # noqa: F401
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fad988cec23a9d1d565b511985cac2f9c076ca6b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8231ff3fa75b555dc5e78d94dfd336a5a0cbf03a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ea7d950cc51e3c59a62ca5ed2debaa5e423fbed
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7d93462f9197319f2deebb7591b949c966f0a2a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e56b9ae8b3bff23486e35ceca524d9998ada068
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af37705e285c636c5a7e968aa39ee2e7bc5ebcb1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/api.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8f48da746fa16a6b4ef41ee276a6931696ce0aa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/api.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import socket
+from string import Template
+from typing import List, Any
+
+
+def get_env_variable_or_raise(env_name: str) -> str:
+    r"""
+    Tries to retrieve environment variable. Raises ``ValueError``
+    if no environment variable found.
+
+    Args:
+        env_name (str): Name of the env variable
+    """
+    value = os.environ.get(env_name, None)
+    if value is None:
+        msg = f"Environment variable {env_name} expected, but not set"
+        raise ValueError(msg)
+    return value
+
+
+def get_socket_with_port() -> socket.socket:
+    addrs = socket.getaddrinfo(
+        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
+    )
+    for addr in addrs:
+        family, type, proto, _, _ = addr
+        s = socket.socket(family, type, proto)
+        try:
+            s.bind(("localhost", 0))
+            s.listen(0)
+            return s
+        except OSError as e:
+            s.close()
+    raise RuntimeError("Failed to create a socket")
+
+
+class macros:
+    """
+    Defines simple macros for caffe2.distributed.launch cmd args substitution
+    """
+
+    local_rank = "${local_rank}"
+
+    @staticmethod
+    def substitute(args: List[Any], local_rank: str) -> List[str]:
+        args_sub = []
+        for arg in args:
+            if isinstance(arg, str):
+                sub = Template(arg).safe_substitute(local_rank=local_rank)
+                args_sub.append(sub)
+            else:
+                args_sub.append(arg)
+        return args_sub
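+
+# Editor's note: an illustrative example of macro substitution (the argument
+# list below is hypothetical):
+#
+#     >>> macros.substitute(["--local_rank", macros.local_rank, 42], local_rank="3")
+#     ['--local_rank', '3', 42]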
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73fd6cdd4431a77cc1cb7ae49efc92cedebfab2e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .cycling_iterator import CyclingIterator  # noqa: F401
+from .elastic_distributed_sampler import ElasticDistributedSampler  # noqa: F401
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1fd34eddc7eb3ca0fe5faa753092ce2bd289814
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b9362074c9dadf5af2c84e632b07f24e7029b5d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b88ee6bf8e5a55341bd04dc4c216d54409e429f0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..526d629cdec61093708aae90f9a0f7a9af257b1d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+class CyclingIterator:
+    """
+    An iterator decorator that cycles through the
+    underlying iterator "n" times. Useful to "unroll"
+    the dataset across multiple training epochs.
+
+    The generator function is called as ``generator_fn(epoch)``
+    to obtain the underlying iterator, where ``epoch`` is a
+    number less than ``n`` identifying the current cycle.
+
+    For example if ``generator_fn`` always returns ``[1,2,3]``
+    then ``CyclingIterator(n=2, generator_fn)`` will iterate through
+    ``[1,2,3,1,2,3]``
+    """
+
+    def __init__(self, n: int, generator_fn, start_epoch=0):
+        self._n = n
+        self._epoch = start_epoch
+        self._generator_fn = generator_fn
+        self._iter = generator_fn(self._epoch)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return next(self._iter)
+        except StopIteration as eod:  # eod == end of data
+            if self._epoch < self._n - 1:
+                self._epoch += 1
+                self._iter = self._generator_fn(self._epoch)
+                return self.__next__()
+            else:
+                raise eod
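+
+# Editor's note: an illustrative example (the generator function below is
+# hypothetical):
+#
+#     >>> it = CyclingIterator(n=2, generator_fn=lambda epoch: iter([1, 2, 3]))
+#     >>> list(it)
+#     [1, 2, 3, 1, 2, 3]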
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d211dfabfbf78f9b4f5b210228c4dded497e472e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+from torch.utils.data.distributed import DistributedSampler
+
+
+class ElasticDistributedSampler(DistributedSampler):
+    """
+    Sampler that restricts data loading to a subset of
+    the dataset for elastic training.
+
+    It is especially useful in conjunction with
+    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+    process can pass a DistributedSampler instance as a DataLoader sampler,
+    and load a subset of the original dataset that is exclusive to it.
+
+    .. note::
+        Dataset is assumed to be of constant size.
+
+    Args:
+        dataset: Dataset used for sampling.
+        num_replicas (optional): Number of processes participating in
+            distributed training.
+        rank (optional): Rank of the current process within num_replicas.
+        start_index (optional):  Which index of the dataset to start sampling from
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
+        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
+        if start_index >= len(dataset):
+            raise ValueError(
+                f"Start index {start_index} should be less than dataset size {len(dataset)}"
+            )
+
+        self.start_index = start_index
+        self.num_samples = int(
+            math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas)  # type: ignore[arg-type]
+        )
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = (
+            torch.randperm(len(self.dataset) - self.start_index, generator=g)  # type: ignore[arg-type]
+            .add(self.start_index)
+            .tolist()
+        )
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
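+
+# Editor's note: an illustrative sketch of resuming mid-epoch with the sampler
+# above (dataset, batch size, and resume index are hypothetical; assumes the
+# default process group is already initialized):
+#
+#     sampler = ElasticDistributedSampler(train_dataset, start_index=resume_index)
+#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)
+#     for epoch in range(start_epoch, max_epochs):
+#         sampler.set_epoch(epoch)
+#         for batch in loader:
+#             ...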
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/distributed.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc50038ca2ada2a37186e85cdeb959d7139a5a25
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/distributed.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import datetime
+import socket
+from contextlib import closing
+
+import torch.distributed as dist
+from torch.distributed.elastic.utils.logging import get_logger
+
+
+log = get_logger(__name__)
+
+_ADDRESS_IN_USE = "Address already in use"
+_SOCKET_TIMEOUT = "Socket Timeout"
+
+_MEMBER_CHECKIN = "_tcp_store/num_members"
+_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
+
+
+def create_c10d_store(
+    is_server: bool,
+    server_addr: str,
+    server_port: int = -1,
+    world_size: int = 1,
+    timeout: float = (60 * 10),  # 10 min
+    wait_for_workers: bool = True,
+    retries=3,
+):
+    if server_port == -1 and world_size > 1:
+        raise ValueError(
+            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
+        )
+
+    if server_port != -1:
+        log.info("server_port: %s specified, ignoring retries", server_port)
+
+    # only retry when server_port is NOT static
+    attempt = retries if server_port == -1 else 1
+    while True:
+        if server_port != -1:
+            port = server_port
+        else:
+            port = get_free_port()
+
+        log.info(
+            "Creating c10d store on %s:%s\n"
+            "  world_size  : %s\n"
+            "  is_server   : %s\n"
+            "  timeout(sec): %s\n",
+            server_addr, port, world_size, is_server, timeout
+        )
+
+        try:
+            store = dist.TCPStore(
+                host_name=server_addr,
+                port=port,
+                world_size=world_size,
+                is_master=is_server,
+                timeout=datetime.timedelta(seconds=timeout),
+                wait_for_workers=wait_for_workers,
+            )
+            # skips full rank check when we don't have to wait for all workers
+            if wait_for_workers:
+                _check_full_rank(store, world_size)
+            log.info("Successfully created c10d store")
+            return store
+        except RuntimeError as e:
+            # this is brittle, but the underlying exception type is not properly pybinded
+            # so we parse the error msg for now, interestingly this is how torch itself
+            # detects timeouts and port conflicts in their own unittests
+            # see - caffe2/torch/testing/_internal/common_utils.py
+            # TODO properly map the exceptions in pybind (c10d/init.cpp)
+            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
+                if attempt < retries:
+                    log.warning(
+                        "port: %s already in use, attempt: [%s/%s]", port, attempt, retries
+                    )
+                    attempt += 1
+                else:
+                    raise RuntimeError(
+                        f"on {server_addr}, port: {port} already in use"
+                    ) from e
+            else:
+                raise
+
+
+def _check_full_rank(store, world_size):
+    idx = store.add(_MEMBER_CHECKIN, 1)
+    if idx == world_size:
+        store.set(_LAST_MEMBER_CHECKIN, "")
+
+    try:
+        store.get(_LAST_MEMBER_CHECKIN)
+    except RuntimeError as e:
+        if str(e) == _SOCKET_TIMEOUT:
+            raise TimeoutError(
+                f"timed out waiting for all {world_size} members to join"
+            ) from e
+        else:
+            raise
+
+
+def get_free_port():
+    sock = get_socket_with_port()
+    with closing(sock):
+        return sock.getsockname()[1]
+
+
+def get_socket_with_port() -> socket.socket:
+    """
+    Returns a socket bound to a free port on localhost. The port is "reserved"
+    for as long as the socket stays open; close the socket before passing the
+    port to the entity that requires it. Usage example
+
+    ::
+
+        sock = get_socket_with_port()
+        with closing(sock):
+            port = sock.getsockname()[1]
+            sock.close()
+            # there is still a race condition: some other process
+            # may grab this port before func() runs
+            func(port)
+    """
+
+    addrs = socket.getaddrinfo(
+        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
+    )
+    for addr in addrs:
+        family, type, proto, _, _ = addr
+        s = socket.socket(family, type, proto)
+        try:
+            s.bind(("localhost", 0))
+            s.listen(0)
+            return s
+        except OSError as e:
+            s.close()
+            log.info("Socket creation attempt failed.", exc_info=e)
+    raise RuntimeError("Failed to create a socket")
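+
+# Editor's note: an illustrative sketch of building a shared TCPStore with the
+# helper above (host name, port, and world size are hypothetical):
+#
+#     # on the host that should own the store (typically the rank-0 agent):
+#     store = create_c10d_store(
+#         is_server=True, server_addr="node0", server_port=29500, world_size=4
+#     )
+#
+#     # on every other host:
+#     store = create_c10d_store(
+#         is_server=False, server_addr="node0", server_port=29500, world_size=4
+#     )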
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/log_level.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/log_level.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf2d31347aeeb3ebc63af253a3f4db678cfdc0fc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/log_level.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+def get_log_level() -> str:
+    """
+    Return default log level for pytorch.
+    """
+    return "WARNING"
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/logging.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..34355c06ddd69d626723b923ce43918c1e1a3a6c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/logging.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import inspect
+import logging
+import os
+import warnings
+from typing import Optional
+
+from torch.distributed.elastic.utils.log_level import get_log_level
+
+
+def get_logger(name: Optional[str] = None):
+    """
+    Util function to set up a simple logger that writes
+    into stderr. The loglevel is fetched from the LOGLEVEL
+    env. variable or WARNING as default. The function will use the
+    module name of the caller if no name is provided.
+
+    Args:
+        name: Name of the logger. If no name provided, the name will
+              be derived from the call stack.
+    """
+
+    # Derive the name of the caller, if none provided
+    # Use depth=2 since this function takes up one level in the call stack
+    return _setup_logger(name or _derive_module_name(depth=2))
+
+
+def _setup_logger(name: Optional[str] = None):
+    log = logging.getLogger(name)
+    log.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
+    return log
+
+
+def _derive_module_name(depth: int = 1) -> Optional[str]:
+    """
+    Derives the name of the caller module from the stack frames.
+
+    Args:
+        depth: The position of the frame in the stack.
+    """
+    try:
+        stack = inspect.stack()
+        assert depth < len(stack)
+        # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
+        frame_info = stack[depth]
+
+        module = inspect.getmodule(frame_info[0])
+        if module:
+            module_name = module.__name__
+        else:
+            # inspect.getmodule(frame_info[0]) does NOT work (returns None) in
+            # binaries built with @mode/opt
+            # return the filename (minus the .py extension) as the module name
+            filename = frame_info[1]
+            module_name = os.path.splitext(os.path.basename(filename))[0]
+        return module_name
+    except Exception as e:
+        warnings.warn(
+            f"Error deriving logger module name, falling back to None. Exception: {e}",
+            RuntimeWarning,
+        )
+        return None
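+
+
+# Illustrative usage sketch (not part of the upstream module): the level can be
+# raised via the environment without code changes.
+#
+#     log = get_logger(__name__)      # or get_logger() to derive the name
+#     log.warning("starting agent")   # visible at the default WARNING level
+#     # Running with LOGLEVEL=INFO makes log.info(...) visible as well.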
diff --git a/MLPY/Lib/site-packages/torch/distributed/elastic/utils/store.py b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/store.py
new file mode 100644
index 0000000000000000000000000000000000000000..953becb9ab53ac939f3ab82361aa9fd76b247abe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/elastic/utils/store.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from datetime import timedelta
+from typing import List
+
+
+def get_all(store, rank: int, prefix: str, size: int):
+    r"""
+    Given a store and a prefix, the method goes through the array of keys
+    of the following format: ``{prefix}{idx}``, where idx is in a range
+    from 0 to size, and tries to retrieve the data.
+
+    The Rank0 process waits at the end to make sure all other processes
+    finished the procedure before exiting.
+
+    Usage
+
+    ::
+
+     values = get_all(store, 'torchelastic/data', 3)
+     value1 = values[0] # retrieves the data for key torchelastic/data0
+     value2 = values[1] # retrieves the data for key torchelastic/data1
+     value3 = values[2] # retrieves the data for key torchelastic/data2
+
+    """
+    data_arr = []
+    for idx in range(size):
+        data = store.get(f"{prefix}{idx}")
+        data_arr.append(data)
+    store.set(f"{prefix}{rank}.FIN", b"FIN")
+    if rank == 0:
+        # Rank0 runs the TCPStore daemon, as a result it needs to exit last.
+        # Otherwise, the barrier may timeout if rank0 process finished the work
+        # before other processes finished `get_all` method
+        for node_rank in range(size):
+            store.get(f"{prefix}{node_rank}.FIN")
+
+    return data_arr
+
+
+def synchronize(
+    store,
+    data: bytes,
+    rank: int,
+    world_size: int,
+    key_prefix: str,
+    barrier_timeout: float = 300,
+) -> List[bytes]:
+    """
+    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
+    The ``data`` will be available on each of the agents.
+
+    Note: The data on the path is not deleted, as a result there can be stale data if
+        you use the same key_prefix twice.
+    """
+    store.set_timeout(timedelta(seconds=barrier_timeout))
+    store.set(f"{key_prefix}{rank}", data)
+    agent_data = get_all(store, rank, key_prefix, world_size)
+    return agent_data
+
+
+def barrier(
+    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
+) -> None:
+    """
+    A global lock between agents.
+
+    Note: Since the data is not removed from the store, the barrier can be used
+        once per unique ``key_prefix``.
+    """
+    data = f"{rank}".encode()
+    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
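+
+
+# Illustrative usage sketch (not part of the upstream module), assuming a c10d
+# TCPStore shared by all agents; the host, port, rank, and world_size values
+# below are placeholders:
+#
+#     import torch.distributed as dist
+#     store = dist.TCPStore("localhost", 29500, world_size, is_master=(rank == 0))
+#     payloads = synchronize(store, b"agent-info", rank, world_size, "init/")
+#     barrier(store, rank, world_size, "post-init/")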
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__init__.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddf24dde2c2e755b3f378a9c73a012e33df6a65
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/__init__.py
@@ -0,0 +1,38 @@
+from ._flat_param import FlatParameter as FlatParameter
+from .fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    CPUOffload,
+    FullOptimStateDictConfig,
+    FullStateDictConfig,
+    FullyShardedDataParallel,
+    LocalOptimStateDictConfig,
+    LocalStateDictConfig,
+    MixedPrecision,
+    OptimStateDictConfig,
+    OptimStateKeyType,
+    ShardedOptimStateDictConfig,
+    ShardedStateDictConfig,
+    ShardingStrategy,
+    StateDictConfig,
+    StateDictSettings,
+    StateDictType,
+)
+
+__all__ = [
+    "BackwardPrefetch",
+    "CPUOffload",
+    "FullOptimStateDictConfig",
+    "FullStateDictConfig",
+    "FullyShardedDataParallel",
+    "LocalOptimStateDictConfig",
+    "LocalStateDictConfig",
+    "MixedPrecision",
+    "OptimStateDictConfig",
+    "OptimStateKeyType",
+    "ShardedOptimStateDictConfig",
+    "ShardedStateDictConfig",
+    "ShardingStrategy",
+    "StateDictConfig",
+    "StateDictSettings",
+    "StateDictType",
+]
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78e890f540e56dfe6f329d04661b07698789d8cb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..904f7ccf46cb7960a44d2fb7b252a49c323acb41
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12755b111043a5055ada85b9e4caff99f6eced97
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52a2285cb2e3d98975a026d38c38ebfe4a20a633
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..800e5729c49483336f5a6da43a08e83ea085d12d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cf2dc40144a5ab29ae5c83f951f56fd0c27f59a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..faa5ca6eabb391bc8222d5854031d064f2d7ab36
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da7ca420280d3fa915085e7801af1338ce1f954e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b38a41343069c07bfa95796b13d54cfc6ef701a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95c55fc07ca6e2bcf0c2e499590fd1e3e4f5796c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3410fe1f71b5366e677fd0f3c3ab7c9738d1745
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4de6c523b23c0af55dc4c7e81f9a1eee381616e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ed27043c5bfd33f6ea7cc7c5ff62dc9c8724fef
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85287e9841b5d163a3e1b9302c0a480bcf56cdd5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..148d99694baa3f2ae2a596132360c5d1d968477f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8349e613e5550baf3ff94b38433d2b60c3bab15e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd6791a4dbab07cc0f97aa6541b0f8e6ac0119f5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d7d60c54b480b31d749401da661cf0b493e4ca8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e67f5a949ce95c76ec959a934252e99abefb5093
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db7aa31a4890fa21c83b34bb282a98d3bffe73e0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d50d7fee1f3dc17807266b12bca44d138a794340
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_common_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d3a722fd62bd0f7f6d9e394008083947618b90
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_common_utils.py
@@ -0,0 +1,563 @@
+"""
+This file includes private common utilities for FSDP.
+"""
+import logging
+import traceback
+import warnings
+import weakref
+from enum import auto, Enum
+from functools import partial
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    no_type_check,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+)
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._flat_param as flat_param_file
+import torch.nn as nn
+from torch.distributed._composable_state import _get_module_state, _State
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+)
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
+from torch.distributed.utils import _apply_to_tensors
+from torch.utils._mode_utils import no_dispatch
+
+from .api import (
+    FullOptimStateDictConfig,
+    FullStateDictConfig,
+    OptimStateDictConfig,
+    ShardingStrategy,
+    StateDictConfig,
+    StateDictType,
+)
+
+if TYPE_CHECKING:
+    from ._flat_param import FlatParamHandle
+
+FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
+FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
+FSDP_FLATTENED = "_fsdp_flattened"
+
+# Save a global mapping from module to its input tensor dtype to be populated
+# during the forward pre-hook and consumed in the forward post-hook when
+# overriding a module's mixed precision
+# NOTE: We currently take the last input tensor's dtype in the case of multiple
+# floating-point input tensors, which may be incorrect. However, since there is
+# not a 1:1 correspondence between input and output tensors, we must use *some*
+# heuristic like this to predict the desired output dtype.
+_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+
+
+class _FSDPDeviceHandle:
+    """
+    This is a simple abstraction for FSDP computing devices,
+    which enables custom backends that implement CUDA-like
+    semantics to be integrated with FSDP.
+    """
+
+    def __init__(self, device: torch.device, backend: Any = None):
+        if backend is None:
+            try:
+                self.__backend = getattr(torch, device.type)
+                self.__device = device
+            except AttributeError as exc:
+                raise AttributeError(
+                    f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'."
+                ) from exc
+        else:
+            self.__backend = backend
+
+    @classmethod
+    def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle":
+        """
+        Return an device handle corresponding to the device, and through this handle,
+        operations with the same semantics as CUDA can be performed on the device.
+        Just return torch.cuda if the device is cuda to make attribute-access faster.
+        Custom backend must first register a module with the same name with {device.type} on torch.
+        """
+        if device.type == "cuda":
+            return cast(_FSDPDeviceHandle, torch.cuda)
+        return cls(device)
+
+    def __getattr__(self, __name: str) -> Any:
+        try:
+            return getattr(self.__backend, __name)
+        except AttributeError as exc:
+            raise AttributeError(
+                f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'"
+            ) from exc
+
+
+class _UninitializedDeviceHandle(_FSDPDeviceHandle):
+    def __init__(self):
+        pass
+
+    def __getattribute__(self, __name: str) -> Any:
+        raise RuntimeError("Trying to use an uninitialized device handle.")
+
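+# Illustrative usage sketch (not part of the upstream module) for the device
+# handle abstraction above:
+#
+#     handle = _FSDPDeviceHandle.from_device(torch.device("cuda"))
+#     stream = handle.Stream()                 # resolves to torch.cuda.Stream()
+#     handle.current_stream().wait_stream(stream)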
+
+class _FSDPState(_State):
+    def __init__(self) -> None:
+        # TODO: Move all the attributes to this class to enable typing for
+        # FSDP/fully_shard.
+        self._ignored_modules: Set[nn.Module] = set()
+        self._ignored_params: Set[nn.Parameter] = set()
+        # Buffer names are cleaned (without wrapper prefixes)
+        self._ignored_buffer_names: Set[str] = set()
+        self.process_group: Optional[dist.ProcessGroup] = None
+        self.rank: int = -1
+        self.world_size: int = -1
+        self._device_mesh: Optional[DeviceMesh] = None
+        self.sharding_strategy = ShardingStrategy.FULL_SHARD
+        self._use_orig_params: bool = False
+        self.training_state = TrainingState.IDLE
+        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
+        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
+        self._state_dict_config: StateDictConfig = FullStateDictConfig()
+        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
+        self._is_root: Optional[bool] = None
+        self._handle: Optional[flat_param_file.FlatParamHandle] = None
+        self._fully_sharded_module_to_handle: Dict[
+            nn.Module, Optional[flat_param_file.FlatParamHandle]
+        ] = {}
+        self.compute_device: Optional[torch.device] = None
+        self._gradient_predivide_factor: int = 0
+        self._gradient_postdivide_factor: int = 0
+        self._comm_hook: Optional[Callable] = None
+        self._comm_hook_state: Optional[Any] = None
+        # Abstract device handle for fsdp compute device. For now,
+        # the compute device must implement cuda semantics used by fsdp
+        self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
+        # All following attributes should only be used for root states:
+        # Save these static lists to avoid the repeated tree traversals
+        self._all_fsdp_states: List[_FSDPState] = []
+        self._all_handles: List[flat_param_file.FlatParamHandle] = []
+        self._fsdp_extension: Optional[FSDPExtensions] = None
+
+
+def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
+    state = _get_module_state(module)
+    if state is None or not isinstance(state, _FSDPState):
+        return None
+    return state
+
+
+def _get_module_fsdp_state_if_fully_sharded_module(
+    module: nn.Module,
+) -> Optional[_FSDPState]:
+    state = _get_module_fsdp_state(module)
+    if state is None:
+        return None
+    if state == module:  # FullyShardedDataParallel module case.
+        return state
+    if module in state._fully_sharded_module_to_handle:  # fully_shard case.
+        return state
+    return None
+
+
+class TrainingState(Enum):
+    """
+    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
+    """
+
+    IDLE = auto()
+    FORWARD_BACKWARD = auto()
+    SUMMON_FULL_PARAMS = auto()
+
+
+class HandleTrainingState(Enum):
+    """
+    An enum that indicates the state of a ``FlatParamHandle``.
+    """
+
+    IDLE = auto()
+    FORWARD = auto()
+    BACKWARD_PRE = auto()
+    BACKWARD_POST = auto()
+    SUMMON_FULL_PARAMS = auto()
+
+
+def _is_composable(state: _FSDPState):
+    # TODO: This is a temporary hack to differentiate between code paths.
+    return not isinstance(state, nn.Module)
+
+
+@no_type_check
+def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]:
+    """
+    Returns the ``FlatParamHandle`` corresponding to ``module``, i.e. the
+    handle that contains some parameter in ``module``, if one exists.
+    """
+    if _is_composable(state):
+        # A valid FSDP state may have no managed parameters and hence no
+        # handles, meaning no entry in `_fully_sharded_module_to_handle`
+        if state._handle is None:
+            return None
+        assert (
+            module in state._fully_sharded_module_to_handle
+        ), f"Expects a fully sharded module but got {module} on rank {state.rank}"
+        return state._fully_sharded_module_to_handle[module]
+    else:
+        # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
+        return module._handle
+
+
+@no_type_check
+def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
+    """Returns if ``module`` has parameters managed by FSDP."""
+    return _module_handle(state, module) is not None
+
+
+def _get_sharding_strategy(handle):
+    """
+    Returns the sharding strategy of the handle.
+    """
+    return handle._sharding_strategy if handle else None
+
+
+def clean_tensor_name(tensor_name: str) -> str:
+    """
+    Cleans the parameter or buffer name by removing any module wrapper
+    prefixes.
+    """
+    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
+    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
+    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
+    # module wrappers.
+    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
+    return tensor_name
+
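+# Illustrative example (not part of the upstream module): with the wrapper
+# prefixes above,
+#
+#     clean_tensor_name("_fsdp_wrapped_module.layer1._fsdp_wrapped_module.weight")
+#     # -> "layer1.weight"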
+
+def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
+    """
+    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
+    avoid re-flattening it during nested construction.
+    """
+    setattr(tensor, FSDP_FLATTENED, True)
+
+
+def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
+    """Returns if ``tensor`` has been marked as flattened by FSDP."""
+    return getattr(tensor, FSDP_FLATTENED, False)
+
+
+def _named_parameters_with_duplicates(
+    module: nn.Module, **kwargs: Any
+) -> List[Tuple[str, nn.Parameter]]:
+    """
+    This API is required as some modules overwrite `named_parameters()` but do not support
+    `remove_duplicate`.
+    """
+    assert (
+        "remove_duplicate" not in kwargs
+    ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
+    kwargs["remove_duplicate"] = False
+    try:
+        ret = list(module.named_parameters(**kwargs))
+    except AssertionError:
+        kwargs.pop("remove_duplicate")
+        ret = list(module.named_parameters(**kwargs))
+    return ret
+
+
+def _get_param_to_fqns(
+    model: torch.nn.Module,
+    dedup_shared_params: bool = True,
+) -> Dict[nn.Parameter, List[str]]:
+    """
+    Constructs a mapping from parameter to a list of its "canonical" FQNs. Here,
+    we use canonical to mean the fully-qualified name assigned to the parameter
+    based on its position in the original nn.Module hierarchy before any wrapper
+    or parallelism has been applied to it. This is in contrast to FQNs that may be
+    generated after parallelisms or wrappers have been applied to the model.
+
+    Each normal parameter maps to a singleton list containing its FQN, while each
+    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
+    have length greater than one.  All FQNs are prefixed starting from ``model``.
+
+    In the case where FSDP was applied with ``use_orig_params=True``, there should be no
+    ``FlatParameter`` s registered to the model's modules and this mapping will only
+    contain mappings from ``nn.Parameter`` s to singleton FQN lists.
+
+    It is only in the case where FSDP was applied with ``use_orig_params=False`` where
+    a ``FlatParameter`` will be registered in place of the original parameters and there
+    will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the
+    original parameters.
+
+    Args:
+        model (torch.nn.Module): Root module (which may or may not be a
+            :class:`FullyShardedDataParallel` instance).
+        dedup_shared_params (bool): For shared parameters, if ``True``, only
+            includes the FQNs corresponding to the first encounter of the
+            shared parameter in the module traversal; if ``False``, then
+            includes the FQNs across all encounters. (Default: ``True``)
+    """
+
+    def module_fn(module, prefix, tree_level, param_to_fqns):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
+            local_fqns = (
+                param._fqns
+                if isinstance(param, flat_param_file.FlatParameter)
+                else [param_name]
+            )  # prefixed from `module`
+            global_fqns = [
+                clean_tensor_name(prefix + name) for name in local_fqns
+            ]  # prefixed from the top level `model` (i.e. including `prefix`)
+            is_shared_param = param in param_to_fqns
+            if not is_shared_param:
+                param_to_fqns[param] = global_fqns
+            else:
+                if isinstance(param, flat_param_file.FlatParameter):
+                    # DMP overwrites `named_parameters` and skips (advances to
+                    # the next child module) the wrapped module (e.g.,
+                    # _dmp_wrapped_module and _fsdp_wrapped_module). When a user
+                    # calls `named_children` to traverse the module recursively and
+                    # calls `named_parameters` with `recurse=False`, parameters
+                    # will be traversed more than once.
+                    # This hack is specifically designed for DMP + FSDP. We
+                    # overwrite the FlatParameter traversal result to keep only
+                    # the last one, which happens to be the correct one.
+                    #
+                    # TODO: Remove this hack once DMP + FSDP is not supported.
+                    warnings.warn(
+                        "FlatParameter is being traversed more than once. "
+                        "This case should only happen when using "
+                        "DistributedModelParallel with FullyShardedDataParallel."
+                    )
+                    param_to_fqns[param] = global_fqns
+                elif not dedup_shared_params:
+                    param_to_fqns[param].extend(global_fqns)
+
+    def return_fn(param_to_fqns):
+        return param_to_fqns
+
+    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        [key for key, _ in _named_parameters_with_duplicates(model)],
+        param_to_unflat_param_names,
+    )
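+
+
+# Illustrative example (not part of the upstream module), for a model wrapped
+# with use_orig_params=False; the parameter/FQN names below are hypothetical:
+#
+#     param_to_fqns = _get_param_to_fqns(model)
+#     # {<FlatParameter>: ["layer1.weight", "layer1.bias", "layer2.weight", ...],
+#     #  <plain nn.Parameter>: ["head.weight"]}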
+
+
+@no_type_check
+def _log_post_backward_hook(
+    state: _FSDPState, handle: "FlatParamHandle", log: logging.Logger
+) -> None:
+    # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
+    # Below logging of module names this post-bwd hook fires for can help debug certain
+    # cases where hooks don't fire, such as under certain activation checkpoint configs.
+    if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
+        param_fqns = _get_handle_fqns_from_root(state, handle)
+        log.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)
+
+
+@no_type_check
+def _get_handle_fqns_from_root(
+    state: _FSDPState, handle: "FlatParamHandle"
+) -> Optional[List[str]]:
+    if handle is None:
+        return None
+    param_to_fqn = state._exec_order_data.param_to_fqn
+    handle_params = handle.flat_param._params  # only populated for use_orig_params
+    param_fqns = [
+        fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list
+    ]
+    return param_fqns
+
+
+def _apply_to_modules(
+    root_module: torch.nn.Module,
+    module_fn: Callable,
+    return_fn: Callable,
+    filter_fqns: Optional[List[str]] = None,
+    *args,
+    **kwargs,
+):
+    """
+    Performs a pre-order traversal of the modules in the hierarchy rooted at
+    ``root_module``, applying ``module_fn`` at each module and finally
+    returning a value using ``return_fn``. The traversal constructs the full
+    module prefix name (e.g. "module.submodule." just like in model state dict)
+    and makes that available to ``module_fn``.
+
+    ``filter_fqns`` is used because some modules may have their own prefix,
+    similar to ``FullyShardedDataParallel``, and override ``named_parameters()``
+    to remove that prefix.
+    """
+
+    def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
+        # Call the module function before recursing over children (pre-order)
+        module_fn(module, prefix, tree_level, *args, **kwargs)
+        for submodule_name, submodule in module.named_children():
+            if submodule is None:
+                continue
+            new_prefix = prefix + submodule_name + "."
+            new_tree_level = tree_level + 1
+            if filter_fqns is not None:
+                for fqn in filter_fqns:
+                    if fqn.startswith(new_prefix):
+                        break
+                else:
+                    # DMP's named_parameters() will mess up the traversal with
+                    # `named_children` + `named_parameters(recurse=False)`.
+                    # This hack is necessary to make the traversal work.
+                    # TODO: Remove this hack once DMP + FSDP is not supported.
+                    if (
+                        submodule_name == "_fsdp_wrapped_module"
+                        or submodule_name == "_dmp_wrapped_module"
+                    ):
+                        if (
+                            not torch.distributed._functional_collectives.is_torchdynamo_compiling()
+                        ):
+                            # TODO(voz): Don't graph break on this
+                            warnings.warn(
+                                "An unexpected prefix is detected. This case "
+                                "should only happen when using DMP with FSDP. "
+                                f"prefix = {prefix}, "
+                                f"submodule_name = {submodule_name}"
+                            )
+                        new_prefix = prefix
+                    elif submodule_name == "module":
+                        warnings.warn(
+                            "An unexpected prefix is detected. This case "
+                            "should only happen when DDP wraps the outer "
+                            "modules while FSDP wraps the inner ones. "
+                            f"prefix = {prefix}, "
+                            f"submodule_name = {submodule_name}"
+                        )
+                        new_prefix = prefix
+            f(submodule, new_prefix, new_tree_level, *args, **kwargs)
+
+    f(root_module, "", 0, *args, **kwargs)
+    return return_fn(*args, **kwargs)
+
+
+@no_type_check
+def _assert_in_training_states(
+    state: _FSDPState,
+    training_states: List[TrainingState],
+) -> None:
+    """Asserts that FSDP is in the states ``_training_states``."""
+    # Raise a `ValueError` instead of using `assert` to ensure that these
+    # logical assertions run even if `assert`s are disabled
+    if state.training_state not in training_states:
+        msg = (
+            f"expected to be in states {training_states} but current state is "
+            f"{state.training_state}"
+        )
+        # Print the error on rank 0 in case this is called in the backward pass
+        if state.rank == 0:
+            if isinstance(state, nn.Module):
+                print(f"Asserting FSDP instance is: {state}")
+            print(f"ERROR: {msg}")
+            traceback.print_stack()
+        raise ValueError(msg)
+
+
+def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
+    """
+    Returns:
+        Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
+        parent-less) with respect to the modules in the set itself. In other
+        words, these are the modules in ``modules`` that are not the child of
+        any other module in ``modules``.
+    """
+    root_modules: Set[nn.Module] = set()
+    module_to_submodules = {module: set(module.modules()) for module in modules}
+    for candidate_module in modules:
+        is_root_module = True
+        for module, submodules in module_to_submodules.items():
+            is_child_module = (
+                candidate_module is not module and candidate_module in submodules
+            )
+            if is_child_module:
+                is_root_module = False
+                break
+        if is_root_module:
+            root_modules.add(candidate_module)
+    return root_modules
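+
+
+# Illustrative example (not part of the upstream module): for a model with
+# submodules ``model.encoder``, ``model.encoder.layer0``, and ``model.head``,
+#
+#     _get_root_modules({model.encoder, model.encoder.layer0, model.head})
+#     # -> {model.encoder, model.head}   (layer0 is a child of encoder)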
+
+
+def _override_module_mixed_precision(
+    root: torch.nn.Module,
+    module_classes_to_override: Iterable[Type[nn.Module]],
+    wrap_override_dict: Dict[str, Any] = {"mixed_precision": None},  # noqa: B006
+) -> Set[Type[nn.Module]]:
+    module_classes_to_override = tuple(set(module_classes_to_override))
+    # Return a set of the actually overridden module classes
+    overridden_module_classes: Set[Type[nn.Module]] = set()
+    for mod in root.modules():
+        if isinstance(mod, module_classes_to_override):
+            overridden_module_classes.add(type(mod))
+            mod._wrap_overrides = wrap_override_dict  # type: ignore[assignment]
+            # TODO: We need to run this mixed precision ignored module in fp32,
+            # but ensure subsequent modules, that may possibly be running with
+            # mixed precision, still receive the appropriate precision inputs
+            # without user having to adjust mixed precision config too much.
+            # As a result, we attach pre and post forward hooks to up / down
+            # cast. We should revisit this design.
+
+            def cast_fn(
+                dtype: torch.dtype, module: nn.Module, x: torch.Tensor
+            ) -> torch.Tensor:
+                if not torch.is_floating_point(x) or x.dtype == dtype:
+                    return x
+                _MODULE_TO_INP_DTYPE[module] = x.dtype
+                return x.to(dtype)
+
+            def forward_pre_hook(module, args):
+                return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)
+
+            def forward_post_hook(module, args, output):
+                # NOTE: If the forward did not have any floating-point tensors,
+                # then the dtype will not be set for this module, and we do not
+                # upcast the dtype.
+                if module in _MODULE_TO_INP_DTYPE:
+                    old_dtype = _MODULE_TO_INP_DTYPE[module]
+                    return _apply_to_tensors(
+                        partial(cast_fn, old_dtype, module), output
+                    )
+
+            # We intentionally append both of these hooks so that they run after
+            # all other hooks.
+            mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
+            mod.register_forward_hook(forward_post_hook, prepend=False)
+    return overridden_module_classes
+
+
+def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
+    # FIXME record_stream doesn't work with non-cuda tensors
+    if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
+        return
+
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        return
+        # from @ezyang:
+        # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin
+        # Looking over the PR, it looks like this is because we don't actually support Stream arguments
+        # in torch dispatch, so it just chokes.
+        # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False),
+        # a better version of this would just be to check if there are any modes before disabling dispatch.
+        # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here.
+        tensor.record_stream(stream)
+    else:
+        with no_dispatch():
+            tensor.record_stream(stream)
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76463883817fb3854da9b5ae9fda643429410341
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py
@@ -0,0 +1,155 @@
+import logging
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+from enum import Enum
+from typing import Dict, Iterator, List, Set, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._flat_param as flat_param_file
+from torch.distributed.fsdp._common_utils import (
+    _apply_to_modules,
+    _get_module_fsdp_state,
+    clean_tensor_name,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleProfiler:
+    class Type(str, Enum):
+        ALL = "all"
+        ALLGATHER = "all_gather"
+        ALLGATHER_OBJ = "all_gather_object"
+        RESHARDING = "resharding"
+        H2D = "H2D"
+        D2H = "D2H"
+
+    results: Dict[str, float] = defaultdict(float)
+    profiling: Set[str] = set()
+
+    @classmethod
+    def reset(cls) -> None:
+        cls.results.clear()
+        cls.profiling.clear()
+
+    @classmethod
+    @contextmanager
+    def profile(cls, profile_type: str) -> Iterator[None]:
+        assert profile_type not in cls.profiling, (
+            f"{profile_type} is already being profiled. "
+            "SimpleProfiler does not support profiling multiple instances at "
+            "the same time. "
+        )
+
+        cls.profiling.add(profile_type)
+        begin = time.monotonic()
+        try:
+            yield
+        finally:
+            end = time.monotonic()
+            cls.results[profile_type] += end - begin
+            cls.profiling.remove(profile_type)
+
+    @classmethod
+    def dump_and_reset(cls, msg: str) -> None:
+        # This cannot be combined with the DETAIL distributed debug level,
+        # as the profiling results would be very inaccurate.
+        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
+            logger.warning("%s %s", msg, cls.results)
+        cls.reset()
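+
+    # Illustrative usage sketch (not part of the upstream module):
+    #
+    #     with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
+    #         dist.all_gather_into_tensor(out, inp)
+    #     SimpleProfiler.dump_and_reset("FSDP comm timings:")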
+
+
+def _get_sharded_module_tree_with_module_name_to_fqns(
+    model: torch.nn.Module,
+) -> Tuple[str, Dict[str, List[str]]]:
+    """
+    This is used for the composable fully_shard() code path. It returns
+      1. sharded module tree info: each line represents a submodule name that contains the
+    submodule's FQN and its submodule class name; if the submodule is sharded by `fully_shard`,
+    the submodule name is suffixed with ' FULLY SHARDED'. Each increased tree
+    level adds 4 spaces before the printed name. The printed sharded module tree info for a toy model
+    looks like this:
+        [CompositeModel] FULLY SHARDED
+            l1[Linear]
+            u1[UnitModule] FULLY SHARDED
+                u1.l1[Linear]
+                u1.seq[Sequential]
+                    u1.seq.0[ReLU]
+                    u1.seq.1[Linear]
+                    u1.seq.2[ReLU]
+                u1.l2[Linear]
+            u2[UnitModule] FULLY SHARDED
+                u2.l1[Linear]
+                u2.seq[Sequential]
+                    u2.seq.0[ReLU]
+                    u2.seq.1[Linear]
+                    u2.seq.2[ReLU]
+                u2.l2[Linear]
+            l2[Linear]
+      2. a dict mapping from the concatenated module FQN and class name to a list of its managed
+    original parameters' FQNs. The dict for the above toy sharded model looks like this:
+            {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
+             'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
+             'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
+            }
+    All FQNs are prefixed starting from ``model``.
+
+    Args:
+        model (torch.nn.Module): Root module (which may or may not be passed to
+                                 composable `fully_shard()`).
+    """
+
+    def module_fn(
+        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
+    ):
+        num_spaces = tree_level * 4
+        trimmed_prefix = (
+            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
+        )
+        prefixed_module_name = trimmed_prefix + "[" + module.__class__.__name__ + "]"
+        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name
+
+        state = _get_module_fsdp_state(module)
+        if state is None:
+            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
+            return
+
+        handle = state._fully_sharded_module_to_handle.get(module, None)
+
+        if handle:
+            sharded_tree_info[0] += (
+                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
+            )
+        else:
+            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
+
+        if handle:
+            param = handle.flat_param
+            assert isinstance(param, flat_param_file.FlatParameter)
+            global_fqns = [
+                clean_tensor_name(prefix + name) for name in param._fqns
+            ]  # prefixed from the top level `model` (i.e. including `prefix`)
+
+            if prefixed_module_name in sharded_module_name_to_fqns:
+                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
+            else:
+                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
+
+    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
+        return sharded_tree_info[0], sharded_module_name_to_fqns
+
+    # Use List to mutate its value in place while running the recursive functions
+    sharded_tree_info: List[str] = [
+        "",
+    ]
+    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        [key for key, _ in model.named_parameters()],
+        sharded_tree_info,
+        sharded_module_name_to_fqns,
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..53629ae60334a0c0661b991188b43a5d3a460f0c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py
@@ -0,0 +1,45 @@
+from typing import Set
+
+import torch.nn as nn
+
+
+def _annotate_modules_for_dynamo(
+    module: nn.Module,
+    ignored_modules: Set[nn.Module],
+    use_orig_params: bool,
+):
+    """
+    Annotates the submodules in ``module`` 's tree, except those in
+    ``ignored_modules``, indicating that the submodules are FSDP-managed and
+    saving the ``use_orig_params`` setting passed to the FSDP constructor.
+    """
+    for submodule in module.modules():
+        if submodule not in ignored_modules:
+            """[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
+
+            Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since
+            it skips tracing all the torch.distributed.fsdp code.
+                - Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also
+                gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops.
+                - However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*),
+                and we need a way to indicate to dynamo which modules are wrapped by FSDP.
+
+            (*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough
+            guards.  NNModules otherwise are 'specialized', meaning there is less overhead due to assuming
+            their code is well-behaved.
+
+            One particular issue with specialized NNModules for FSDP is that the
+            views created for orig_params are captured into the compiled graph on the first iteration, and while
+            they are always going to point to the correct flatparameter and give correct results, their order
+            of creation influences the order of backward execution, preventing overlap of comm and computation
+            during backward.  We need to _use_ the new parameter views created on each forward iteration, in
+            order for backward to interleave hooks with compute per layer.  UnspecializedNNModule lets us achieve
+            this by capturing the module code more 'functionally' and passing parameters in as inputs each time.
+            """
+            submodule._is_fsdp_managed_module = True  # type: ignore[assignment]
+
+            # Dynamo only supports FSDP with use_orig_params=True.
+            # This is hacky, but I could not think of another way to add an assertion to dynamo
+            # for this, since Dynamo skips all the FSDP code frames and thus can't inspect the
+            # FSDP module directly
+            submodule._fsdp_use_orig_params = use_orig_params  # type: ignore[assignment]
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4770469c06464b77dbcff0e866ce014d143f5088
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py
@@ -0,0 +1,364 @@
+import itertools
+import warnings
+from enum import auto, Enum
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.nn as nn
+from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns
+from torch.distributed.fsdp._flat_param import FlatParamHandle
+
+
+class _ExecOrderWarnStatus(Enum):
+    """Used internally for execution order validation."""
+
+    NONE = auto()  # no deviation yet
+    WARNING = auto()  # deviated this iteration; currently issuing warnings
+    WARNED = auto()  # deviated in a previous iteration
+
+
+class _ExecOrderData:
+    """
+    This contains the data structures to track the execution order. We track
+    the pre-forward order on the *first* iteration for forward prefetching
+    (which thus assumes static graph) and the post-forward order on *every*
+    iteration for backward prefetching (which thus does not assume static
+    graph but may provide an incorrect order).
+    """
+
+    def __init__(
+        self,
+        debug_level: dist.DebugLevel,
+        backward_prefetch_limit: int,
+        forward_prefetch_limit: int,
+    ) -> None:
+        # Tracks the (static) pre-forward order for execution order validation
+        # and forward prefetching
+        self.handles_pre_forward_order: List[FlatParamHandle] = []
+        # Tracks the post-forward order for pre-backward prefetching
+        self.handles_post_forward_order: List[Optional[FlatParamHandle]] = []
+        self._iter = 0
+
+        # Gives the max number of backward/forward prefetched all-gathers by a
+        # single module
+        self._backward_prefetch_limit = backward_prefetch_limit
+        self._forward_prefetch_limit = forward_prefetch_limit
+
+        # Data structures for execution order validation
+        self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL
+        self.process_group: Optional[dist.ProcessGroup] = None
+        self.world_size: Optional[int] = None
+        self.all_handles: List[FlatParamHandle] = []
+        # Names are prefixed from the root module
+        self.param_to_fqn: Dict[nn.Parameter, List[str]] = {}
+        # Current index in the pre-forward execution order
+        self.current_order_index = 0
+        self.warn_status = _ExecOrderWarnStatus.NONE
+
+    def init(
+        self,
+        state: _FSDPState,
+        root_module: nn.Module,
+        process_group: dist.ProcessGroup,
+    ) -> None:
+        """
+        Initializes the data structures needed for checking the forward order.
+        This should be called after a root FSDP instance has been set during
+        lazy initialization.
+        """
+        self.process_group = process_group
+        self.rank = process_group.rank()
+        self.world_size = process_group.size()
+        # Fix an order over the handles, which should be the same across ranks
+        for handle in traversal_utils._get_fsdp_handles(root_module):
+            index = len(self.all_handles)
+            self.all_handles.append(handle)
+            handle._handle_index = index
+        self.param_to_fqn = _get_param_to_fqns(root_module)
+        # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles`
+        # to check that all ranks have the same handles in the same order.
+        # https://github.com/pytorch/pytorch/issues/79620
+
+    @property
+    def is_first_iter(self) -> bool:
+        return self._iter == 0
+
+    def get_handle_to_backward_prefetch(
+        self,
+        current_handle: FlatParamHandle,
+    ) -> Optional[FlatParamHandle]:
+        """
+        Returns a :class:`list` of the handles keys of the handles to backward
+        prefetch given the current handles key. If there are no valid handles
+        keys to prefetch, then this returns an empty :class:`list`.
+        """
+        current_index = current_handle._post_forward_index
+        if current_index is None:
+            return None
+        target_index = current_index - 1
+        target_handle: Optional[FlatParamHandle] = None
+        for _ in range(self._backward_prefetch_limit):
+            if target_index < 0:
+                break
+            target_handle = self.handles_post_forward_order[target_index]
+            target_index -= 1
+        return target_handle
+
+    def get_handle_to_forward_prefetch(
+        self,
+        current_handle: FlatParamHandle,
+    ) -> Optional[FlatParamHandle]:
+        """
+        Returns a :class:`list` of the handles keys of the handles to forward
+        prefetch given the current handles key. If there are no valid handles
+        keys to prefetch, then this returns an empty :class:`list`.
+        """
+        current_index = current_handle._pre_forward_order_index
+        if current_index is None:
+            return None
+        target_index = current_index + 1
+        target_handle: Optional[FlatParamHandle] = None
+        for _ in range(self._forward_prefetch_limit):
+            if target_index >= len(self.handles_pre_forward_order):
+                break
+            target_handle = self.handles_pre_forward_order[target_index]
+            target_index += 1
+        return target_handle
+
+    def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None:
+        """
+        Records ``handle`` in the post-forward order, where ``handle`` is the
+        handle used in the module's forward. If ``handle`` is ``None``, then it
+        is omitted.
+
+        Unlike :meth:`record_pre_forward`, this records the order *every*
+        iteration with the expectation that the recorded order is reset in
+        :meth:`next_iter`.
+        """
+        if not handle:
+            return
+        # Only record the first usage of a handle
+        if handle._post_forward_index:
+            self.handles_post_forward_order.append(handle)
+            return
+        index = len(self.handles_post_forward_order)
+        handle._post_forward_index = index
+        self.handles_post_forward_order.append(handle)
+
+    def record_pre_forward(
+        self, handle: Optional[FlatParamHandle], is_training: bool
+    ) -> None:
+        """
+        Records ``handle`` in the pre-forward order, where ``handle`` is the
+        handle used in the module's forward. If ``handle`` is ``None``, then it
+        is omitted.
+
+        On the first iteration, this checks the execution order across ranks.
+        See :meth:`_check_order` for details.
+        """
+        if not handle:
+            return
+        self._check_order(handle, is_training)
+        # Fix the order after the first iteration and only record the first
+        # usage of a handle
+        if not self.is_first_iter or handle._pre_forward_order_index is not None:
+            return
+        index = len(self.handles_pre_forward_order)
+        handle._pre_forward_order_index = index
+        self.handles_pre_forward_order.append(handle)
+
+    def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
+        """
+        Checks the forward execution order as long as ``is_training`` is
+        ``True`` since checking in eval mode is not supported. This only checks
+        if the distributed debug level is DETAIL.
+
+        - On the first iteration, this uses all-gathers to check that all ranks
+        are all-gathering the same handles and hence ``FlatParameter`` s,
+        raising an error if not.
+        - On subsequent iterations, this checks that each rank is locally
+        consistent with its own forward order from the first iteration, issuing
+        a warning if not. This issues a warning on the first deviating
+        iteration and stops warning thereafter.
+        """
+        # Do not check order in eval mode since the post-backward callback does
+        # not run so it cannot be used to mark the end of an iteration
+        if not is_training or not self._checking_order:
+            return
+        if self.is_first_iter:
+            msg_prefix = "Forward order differs across ranks:"
+            optional_local_indices: Tuple[
+                Optional[int], ...
+            ] = self._get_handle_indices(handle)
+            device = handle.device  # guaranteed to be non-CPU
+            num_valid_indices = sum(
+                (index is not None) for index in optional_local_indices
+            )
+            tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = {
+                "dtype": torch.int32,
+                "device": device,
+            }
+            world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)  # type: ignore[arg-type, call-overload]
+            local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)  # type: ignore[arg-type, call-overload]
+            dist.all_gather_into_tensor(
+                world_num_valid_indices,
+                local_num_valid_indices,
+                group=self.process_group,
+            )
+            # Copy entire tensor from D2H once to avoid per element D2H copies
+            world_num_valid_indices = world_num_valid_indices.cpu()
+            # Check that all ranks plan to all-gather the same number of
+            # parameters
+            # TODO (awgu): Since every module has at most one handle in the
+            # current implementation, this should never raise the error.
+            assert self.world_size is not None  # mypy
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
+                # tensor comparison control flow.
+                # https://github.com/pytorch/pytorch/issues/107055
+                for (r1, n1), (r2, n2) in itertools.combinations(
+                    (
+                        (rank, world_num_valid_indices[rank])
+                        for rank in range(self.world_size)
+                    ),
+                    2,
+                ):
+                    if n1 != n2:
+                        raise RuntimeError(
+                            f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
+                            f"while rank {r2} is all-gathering {n2} parameters"
+                        )
+            world_indices = torch.zeros(  # type: ignore[call-overload]
+                self.world_size * num_valid_indices, **tensor_kwargs
+            )
+            local_indices = torch.tensor(optional_local_indices, **tensor_kwargs)  # type: ignore[arg-type]
+            dist.all_gather_into_tensor(
+                world_indices, local_indices, group=self.process_group
+            )
+            # Copy entire tensor from D2H once to avoid per element D2H copies
+            world_indices = world_indices.cpu()
+            # Check that all ranks plan to all-gather the same index parameters
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2
+                # tensor comparison control flow.
+                # https://github.com/pytorch/pytorch/issues/107055
+                for (r1, i1), (r2, i2) in itertools.combinations(
+                    (
+                        (
+                            rank,
+                            world_indices[
+                                rank
+                                * num_valid_indices : (rank + 1)
+                                * num_valid_indices
+                            ],
+                        )
+                        for rank in range(self.world_size)
+                    ),
+                    2,
+                ):
+                    if i1 != i2:
+                        r1_param_names = self._get_names_from_handle_indices(i1)
+                        r2_param_names = self._get_names_from_handle_indices(i2)
+                        raise RuntimeError(
+                            f"{msg_prefix} rank {r1} is all-gathering parameters "
+                            f"for {r1_param_names} while rank {r2} is all-gathering "
+                            f"parameters for {r2_param_names}"
+                        )
+        else:
+            # Only issue warnings on the first deviating iteration and stop
+            # checking thereafter to avoid flooding the console
+            if self.warn_status == _ExecOrderWarnStatus.WARNED:
+                return
+            msg_prefix = None  # non-`None` means we should warn
+            if self.current_order_index >= len(self.handles_pre_forward_order):
+                # This iteration sees extra all-gather(s) compared to the first
+                msg_prefix = (
+                    "Expected to not all-gather any more parameters in the "
+                    "forward but trying to all-gather parameters for "
+                )
+            else:
+                expected_handle = self.handles_pre_forward_order[
+                    self.current_order_index
+                ]
+                if expected_handle != handle:
+                    expected_param_names = self._get_names_from_handles(expected_handle)
+                    msg_prefix = (
+                        f"Expected to all-gather for {expected_param_names} "
+                        "but trying to all-gather parameters for "
+                    )
+            if msg_prefix is not None:
+                param_names = self._get_names_from_handles(handle)
+                msg_suffix = (
+                    f"{param_names}"
+                    if param_names
+                    else "a newly-added parameter since construction time"
+                )
+                warnings.warn(
+                    "Forward order differs from that of the first iteration "
+                    f"on rank {self.rank}. Collectives are unchecked and may "
+                    f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}"
+                )
+                self.warn_status = _ExecOrderWarnStatus.WARNING
+            self.current_order_index += 1
+
+    def _get_handle_indices(
+        self,
+        handle: FlatParamHandle,
+    ) -> Tuple[Optional[int], ...]:
+        """
+        Returns the handle indices (i.e. indices into ``self.all_handles``)
+        corresponding to ``handle``. An entry in the returned tuple is ``None``
+        if the handle is invalid.
+        """
+        indices: List[Optional[int]] = []
+        if handle:
+            indices.append(handle._handle_index)
+        return tuple(indices)
+
+    def _get_names_from_handle_indices(
+        self,
+        handle_indices: Tuple[int, ...],
+    ) -> List[List[str]]:
+        """
+        Returns a list of FQNs for each handle in ``handle_indices``. If a
+        handle index is invalid, then its FQNs are omitted from the returned
+        list.
+        """
+        fqns: List[List[str]] = []
+        for index in handle_indices:
+            if index is None or index < 0 or index >= len(self.all_handles):
+                continue
+            handle = self.all_handles[index]
+            flat_param = handle.flat_param
+            fqns.append(self.param_to_fqn[flat_param])
+        return fqns
+
+    def _get_names_from_handles(
+        self,
+        handle: FlatParamHandle,
+    ) -> List[List[str]]:
+        """
+        Returns a list of FQNs for ``handle``. If the handle is invalid, then
+        its FQNs are omitted from the returned list.
+        """
+        fqns: List[List[str]] = []
+        if handle:
+            flat_param = handle.flat_param
+            if flat_param in self.param_to_fqn:
+                fqns.append(self.param_to_fqn[flat_param])
+        return fqns
+
+    def next_iter(self):
+        """
+        Advances the internal data structures per iteration. This should be
+        called in the post-backward callback since that marks the true end of
+        an iteration.
+        """
+        self._iter += 1
+        self.handles_post_forward_order.clear()
+        if self._checking_order:
+            self.current_order_index = 0
+            if self.warn_status == _ExecOrderWarnStatus.WARNING:
+                self.warn_status = _ExecOrderWarnStatus.WARNED
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_flat_param.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_flat_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a4ebbb27127d3fc54b3afab1f1bf28e1a3f367
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_flat_param.py
@@ -0,0 +1,2731 @@
+import contextlib
+import functools
+import logging
+import os
+import warnings
+from enum import auto, Enum
+from itertools import accumulate, chain
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    NamedTuple,
+    no_type_check,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed.fsdp._common_utils import (
+    _FSDPDeviceHandle,
+    _named_parameters_with_duplicates,
+    _no_dispatch_record_stream,
+    _set_fsdp_flattened,
+    HandleTrainingState,
+)
+from torch.distributed.utils import (
+    _alloc_storage,
+    _data_ptr_allocated,
+    _free_storage,
+    _p_assert,
+)
+from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
+from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
+
+from ._fsdp_extensions import (
+    _ext_post_unflatten_transform,
+    _ext_pre_flatten_transform,
+    FSDPExtensions,
+)
+
+__all__ = [
+    "FlatParameter",
+    "FlatParamHandle",
+    "FlatParamShardMetadata",
+    "ParamInfo",
+    "SharedParamInfo",
+    "HandleShardingStrategy",
+]
+
+log = logging.getLogger(__name__)
+
+
+"""
+[Note: Fully Sharded Module]
+We define the "fully sharded module" to be the original ``nn.Module`` that owns
+a ``FlatParamHandle``. It is the *single* module logically responsible for the
+*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
+forward or backward pass. The fully sharded module should be passed to the
+``FlatParamHandle`` constructor.
+
+For the wrapper code path:
+- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
+runs the unshard/reshard on behalf of the fully sharded module by overriding
+``nn.Module.forward``.
+- The fully sharded module is exactly the module passed to the
+``FullyShardedDataParallel`` constructor's ``module`` argument.
+
+For the non-wrapper code path:
+- Hooks registered on the fully sharded module run the unshard/reshard.
+- The fully sharded module may either be the direct argument to ``fully_shard``
+or a submodule chosen by the provided wrapping policy.
+"""
+
+# Environment variable toggling whether to use unsafe `setattr()` for view
+# setting in `_use_sharded_views()` and `_use_unsharded_views()`
+# We should use 'safe' by default since it respects method overrides, but for
+# special cases, such as when the CPU overhead is too high or when we want to
+# intentionally bypass checks in the overrides, we may use 'unsafe'.
+_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
+
+# Environment variable toggling whether to check for parameter/gradient
+# writeback in case their storages change after FSDP initialization
+# We should check by default since it prevents silent correctness errors, but
+# since such changes are atypical, we may want to skip the check to save CPU
+# overhead, especially since the check happens in the pre-forward and
+# pre-backward each iteration.
+_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"
+
+# Env var toggling whether when model is in .eval() mode, should we run in fp32
+# or the reduced precision.
+_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"
+
+# Some value to set padding in tensors to for debuggability
+_FLAT_PARAM_PADDING_VALUE = 42
+
+# Environment variables for disabling the all-gather and reduce-scatter
+# communication ops for ablation studies. Note that without these communication
+# ops the training won't converge, and you probably need to disable correctness
+# checks in your model.
+_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
+_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
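+
+# For example (illustrative only; these variables are read at handle
+# construction time, so they must be set before FSDP initialization, e.g. in
+# the launch environment):
+#
+#   FSDP_USE_FAKE_ALL_GATHER=1 FSDP_USE_FAKE_REDUCE=1 torchrun train.py
+#
+# `train.py` here is a placeholder for the user's training script.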
+
+
+# TODO: Define this for now to avoid circular imports. See if we can remove.
+class HandleShardingStrategy(Enum):
+    FULL_SHARD = auto()
+    SHARD_GRAD_OP = auto()
+    NO_SHARD = auto()
+    HYBRID_SHARD = auto()
+    _HYBRID_SHARD_ZERO2 = auto()
+
+
+RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+    HandleShardingStrategy.FULL_SHARD,
+    HandleShardingStrategy.HYBRID_SHARD,
+)
+NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+    HandleShardingStrategy.SHARD_GRAD_OP,
+    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
+)
+
+
+class ParamInfo(NamedTuple):
+    """Information for an original parameter."""
+
+    param_name: str  # unprefixed
+    module: nn.Module
+    module_name: str
+
+
+class SharedParamInfo(NamedTuple):
+    """
+    Additional information for a shared parameter.
+
+    For each shared parameter, we designate one module and its parameter
+    variable to be the primary owner, determined as the first one encountered
+    in the parameter walk. These are prefixed with "prim". The primary module
+    and parameter do not have their own :class:`SharedParamInfo` instance.
+    """
+
+    param_name: str  # unprefixed
+    module: nn.Module
+    module_name: str
+    prim_param_name: str  # unprefixed
+    prim_module: nn.Module
+    prim_module_name: str
+
+
+class _ShardParamInfo(NamedTuple):
+    """Shard-related information for an original parameter."""
+
+    in_shard: bool
+    # Use to index into the sharded flat parameter, e.g.
+    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
+    offset_in_shard: Optional[int]
+    numel_in_shard: Optional[int]
+    # Use to get part of the parameter in the local shard from a flattened
+    # version of the unsharded parameter, e.g.
+    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
+    intra_param_start_idx: Optional[int]
+    intra_param_end_idx: Optional[int]  # inclusive
+
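+# Worked example for ``_ShardParamInfo`` (values are made up): if an original
+# parameter occupies indices [4, 9] of the unsharded flat parameter and this
+# rank's shard covers unsharded indices [8, 15], then the parameter's entry
+# would be
+#   _ShardParamInfo(in_shard=True, offset_in_shard=0, numel_in_shard=2,
+#                   intra_param_start_idx=4, intra_param_end_idx=5),
+# i.e. only the parameter's last two elements land in this rank's shard.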
+
+class FlatParamShardMetadata(NamedTuple):
+    """
+    This holds metadata specific to this rank's shard of the flat parameter.
+
+    Attributes:
+        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
+            shard of the parameters; see :class:`FlatParameter`.
+        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
+            shard of the parameters; see :class:`FlatParameter`.
+        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
+            of the parameters; see :class:`FlatParameter`.
+        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
+            units of numels) giving this rank's part of each flattened
+            original parameter.
+    """
+
+    param_names: Tuple[str, ...]
+    param_shapes: Tuple[torch.Size, ...]
+    param_numels: Tuple[int, ...]
+    param_offsets: Tuple[Tuple[int, int], ...]
+
+
+class _FlatParameterMeta(_ParameterMeta):
+    # Make `isinstance(t, FlatParameter)` return True for custom tensor
+    # instances that have the _is_flat_param flag for BC
+    def __instancecheck__(self, instance):
+        # NB: do NOT test the super implementation
+        return isinstance(instance, torch.Tensor) and getattr(
+            instance, "_is_flat_param", False
+        )
+
+
+class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
+    """
+    This is the flat parameter used by :class:`FullyShardedDataParallel`.
+
+    It is composed of one or more original parameters, which are flattened and
+    concatenated to construct the flat parameter.
+
+    Under the current design, this parameter logically represents both the
+    unsharded and sharded flat parameter, and its data changes storages
+    dynamically.
+        - In the :class:`FullyShardedDataParallel` constructor, the parameter
+        is initialized as unsharded and then sharded in-place.
+        - At runtime, the parameter is lazily (re)-initialized. The sharded
+        parameter data is saved in ``self._local_shard``, and a new ``Tensor``
+        ``self._full_param_padded`` is created, which is the all-gather
+        destination and owns the unsharded parameter storage thereafter. (See
+        :meth:`FlatParamHandle.init_flat_param_attributes`.)
+        - Throughout runtime, the parameter data changes storages as needed,
+        e.g. to the sharded flat parameter, low precision sharded flat
+        parameter, or the unsharded flat parameter.
+
+    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
+    padding, we have two versions of the per-parameter numels, one that
+    includes the padding (``_numels_with_padding``) and one that does not
+    (``_numels``). The former may have length longer than the other data
+    structures, while the latter has the same length as the number of actual
+    original parameters like the other per-parameter data structures.
+
+    NOTE: This is not a real class; instead, you will always get a Parameter
+    back out if you try to create one of these.  This is similar to the trick
+    we implemented for Parameter to get it to work with subclasses; this
+    is primarily so that FlatParameter supports combination with FakeTensor.
+
+    Attributes:
+        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
+            without right-hand-side padding for divisibility by the world size.
+            For ``use_orig_params=True``, this includes alignment padding.
+        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
+            with right-hand-side padding for divisibility by the world size.
+            For ``use_orig_params=True``, this includes alignment padding. This
+            is only set for sharded strategies since they require padding for
+            the all-gather.
+        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
+            This is also set for ``NO_SHARD``, in which case it is the same as
+            the unsharded sizes. (We omit "padded" because there is no
+            analogous unpadded one.)
+
+        _num_params (int): Number of original parameters flattened into this
+            flat parameter. This is the length of the per-parameter data
+            structures.
+        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
+            entry; see :class:`ParamInfo` for details.
+        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
+        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
+            prefixed from the ``_fully_sharded_module``. The names are
+            guaranteed to be unique in the subtree rooted at that module.
+        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
+            extension (i.e. some per-parameter state) used to customize
+            pre-flatten and post-unflatten behavior or ``None``. This is
+            experimental, and users should not depend on its existence in the
+            future.
+        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
+            including entries for the padding. This is used to construct views
+            into the flat parameter via ``torch.split()``. This may have length
+            longer than ``_num_params``.
+        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
+            padding. This has length equal to ``_num_params``.
+        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
+            shard parameter info; see :class:`_ShardParamInfo` for details.
+        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
+            info entries; see :class:`SharedParamInfo` for details.
+        _modules (Set[nn.Module]): Modules that contain some original parameter
+            that is flattened into the flat parameter.
+
+        _shard_numel_padded (int): Numel padded for this rank's sharded flat
+            parameter.
+        _local_shard (Tensor): Sharded flat parameter with padding if using a
+            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
+            unsharded flat parameter, and there is no notion of a sharded flat
+            parameter or padded unsharded flat parameter.
+        _full_param_padded (Tensor): Unsharded flat parameter with padding.
+            This is not defined for ``NO_SHARD``. When using mixed precision
+            for parameters, this has the low precision.
+        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
+            parameter with padding. This is used for unsharding outside of
+            computation when using mixed precision for parameters. This is
+            never defined for ``NO_SHARD``.
+        _post_backward_hook_handle (RemovableHandle):
+            Flat parameter's post-backward hook handle. (Compile only)
+        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
+            Flat parameter's :class:`AccumulateGrad` object and post-backward
+            hook handle. (Eager only)
+        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
+            This is only defined when parameter mixed precision is enabled. For
+            ``NO_SHARD``, this is used for computation.
+        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
+            This is only defined when offloading parameters is enabled.
+        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
+            iterations for gradient accumulation without :meth:`no_sync`.
+
+        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
+            then each original parameter variable; otherwise, ``None``. This
+            does not include any padding tensors.
+        _shared_params (Optional[List[nn.Parameter]]): The original shared
+            parameter variables if ``use_orig_params=True`` and ``None``
+            otherwise.
+        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
+            views created in the forward and tracked by autograd when
+            ``use_orig_params=True`` and is ``None`` otherwise. This is to
+            preserve those ``Tensor`` variables for the backward to ensure that
+            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change,
+            since if it changed, the post-backward hook would not run. This is relevant
+            for cases like reentrant activation checkpointing.
+        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
+            a mask over the original parameters' gradients indicating if it is
+            logically ``None`` or not; otherwise, ``None``. This does not
+            include entries for padding. This mask is needed because only some
+            of the parameters may have ``None`` gradient, in which case the
+            flat gradient must be non-``None`` and must use zeros to
+            approximate those original ``None`` gradients. This mask informs
+            FSDP to set the original parameter gradients to ``None`` (instead
+            of zeros) as needed.
+    """
+
+    _unpadded_unsharded_size: torch.Size
+    _padded_unsharded_size: torch.Size
+    _sharded_size: torch.Size
+    _num_params: int
+    _param_infos: Tuple[ParamInfo, ...]
+    _shapes: Tuple[torch.Size, ...]
+    _fqns: Tuple[str, ...]
+    _param_extensions: Tuple[Optional[Any], ...]
+    _numels_with_padding: Tuple[int, ...]
+    _numels: Tuple[int, ...]
+    _shard_param_infos: Tuple[_ShardParamInfo, ...]
+    _shared_param_infos: Tuple[SharedParamInfo, ...]
+    _modules: Set[nn.Module]
+    _shard_numel_padded: int
+    _local_shard: Tensor
+    _full_param_padded: Tensor
+    _full_prec_full_param_padded: Tensor
+    # Eager only
+    _post_backward_hook_state: Tuple[Any, Any]
+    # Compile only
+    _post_backward_hook_handle: Any
+    _mp_shard: Tensor
+    _cpu_grad: Tensor
+    _saved_grad_shard: Tensor
+    _params: Optional[List[nn.Parameter]]
+    _shared_params: Optional[List[nn.Parameter]]
+    _tensors: Optional[List[Optional[Tensor]]]
+    _is_grad_none_mask: Optional[List[bool]]
+
+    _is_padding_mask: List[bool]
+
+    def __new__(cls, data=None, requires_grad=True):
+        assert cls is FlatParameter, "subclasses FlatParameter not supported"
+        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
+        r._is_flat_param = True  # type: ignore[attr-defined]
+        return r
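+
+    # Rough behavioral sketch (not executed here): because ``__new__`` returns
+    # a plain ``nn.Parameter`` tagged with ``_is_flat_param`` and the metaclass
+    # overrides ``__instancecheck__``, one would expect
+    #   p = FlatParameter(torch.zeros(4), requires_grad=False)
+    #   type(p) is nn.Parameter       # True
+    #   isinstance(p, FlatParameter)  # True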
+
+    # NB: This is not a regular method, because FlatParameters are not actually
+    # instances of this class (see __new__ above). So this must be called
+    # through the classmethod, passing the instance explicitly, e.g.
+    # ``FlatParameter._init_metadata(flat_param, ...)``.
+    @classmethod
+    def _init_metadata(
+        cls,
+        self,
+        param_infos: List[ParamInfo],
+        numels: List[int],
+        shapes: List[torch.Size],
+        fqns: List[str],
+        shared_param_infos: List[SharedParamInfo],
+        param_extensions: List[Optional[Any]],
+        params: Optional[List[nn.Parameter]],
+        shared_params: Optional[List[nn.Parameter]],
+        is_padding_mask: List[bool],
+    ) -> None:
+        """
+        Initialize attributes holding metadata about the original parameters comprising the flat parameter.
+
+        We expose this method separately from the constructor to keep the
+        constructor only responsible for the flat parameter's tensor data. This
+        method should only be called once per model, while the constructor may
+        be called multiple times, e.g. when reloading from a checkpoint, in
+        which case only the tensor data needs to be passed to the constructor.
+        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
+        metadata is correctly assumed to be unchanged.
+
+        Args:
+            See the Attributes in the class docstring.
+        """
+        assert len(param_infos) == len(shapes)
+        assert len(param_infos) == len(fqns)
+        assert len(param_infos) == len(param_extensions)
+        self._num_params = len(param_infos)
+        self._param_infos = param_infos
+        self._shapes = shapes
+        self._fqns = fqns
+        self._param_extensions = param_extensions
+        self._is_padding_mask = is_padding_mask
+
+        numels_without_padding: List[int] = []
+        for numel, is_padding in zip(numels, is_padding_mask):
+            if not is_padding:
+                numels_without_padding.append(numel)
+        self._numels = tuple(numels_without_padding)
+        self._numels_with_padding = tuple(numels)
+        assert len(self._numels) == self._num_params
+
+        self._shared_param_infos = tuple(shared_param_infos)
+        self._modules = {pi.module for pi in self._param_infos}.union(
+            {spi.module for spi in self._shared_param_infos}
+        )
+        assert (params is None) == (shared_params is None)
+        if params is not None:
+            assert shared_params is not None and len(shared_params) == len(
+                shared_param_infos
+            )
+            self._params = []
+            for param, is_padding in zip(params, is_padding_mask):
+                if not is_padding:
+                    self._params.append(param)
+            self._shared_params = shared_params
+            # Mark the original parameters to avoid flattening them into
+            # another `FlatParameter` during recursive construction
+            for param in chain(self._params, self._shared_params):
+                _set_fsdp_flattened(param)
+            self._is_grad_none_mask = [False for _ in range(self._num_params)]
+            self._tensors = [None for _ in range(self._num_params)]
+        else:
+            self._params = None
+            self._shared_params = None
+            self._is_grad_none_mask = None
+            self._tensors = None
+        self._unpadded_unsharded_size = self.size()
+        _set_fsdp_flattened(self)
+        # Tracks whether the `FlatParameter`'s post-backward hook has been
+        # called to modify the behavior of the post-backward callback
+        self._post_backward_called = False
+
+
+class FlatParamHandle:
+    """
+    A handle that manages a flat parameter (:class:`FlatParameter`).
+
+    This includes sharding and view management.
+
+    Args:
+        params (Sequence[nn.Parameter]): The parameters to flatten into the
+            flat parameter.
+        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
+        device (torch.device): The compute and communication device, which
+            should be a non-CPU device. We refer to it as the compute device.
+        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
+            this handle's ``FlatParameter``.
+        offload_params (bool): Whether to offload the handle's
+            ``FlatParameter`` to CPU.
+        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
+            setting passed to the FSDP constructor.
+        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
+            precision setting passed to the FSDP constructor.
+        keep_low_precision_grads (bool): Whether to keep gradients in low
+            precision.
+        use_orig_params (bool): If ``True``, then FSDP preserves the original
+            parameter variables and returns them from ``named_parameters()``
+            (e.g. to support different optimizer hyperparameters within one
+            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
+            parameters every iteration and returns the :class:`FlatParameter` s
+            from ``named_parameters()``.
+    """
+
+    ##################
+    # INITIALIZATION #
+    ##################
+    def __init__(
+        self,
+        params: Sequence[Union[nn.Parameter, Tensor]],
+        fully_sharded_module: nn.Module,
+        device: torch.device,
+        sharding_strategy: HandleShardingStrategy,
+        offload_params: bool,
+        mp_param_dtype: Optional[torch.dtype],
+        mp_reduce_dtype: Optional[torch.dtype],
+        keep_low_precision_grads: bool,
+        process_group: dist.ProcessGroup,
+        use_orig_params: bool,
+        *,
+        fsdp_extension: Optional[FSDPExtensions] = None,
+    ):
+        super().__init__()
+        params = list(params)
+        if len(params) == 0:
+            raise ValueError(
+                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
+            )
+        self._init_setattr_fns()
+        self._skip_writeback_check = (
+            os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
+        )
+        self._use_full_prec_in_eval = (
+            os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
+        )
+        self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
+        self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
+        if self._skip_writeback_check:
+            _warn_skip_writeback_check(
+                log,
+                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
+                "for parameter or gradient writeback. Changing parameter or "
+                "gradient storages may lead to silent correctness errors.",
+            )
+        if self._use_fake_all_gather:
+            _warn_use_fake_all_gather(
+                log,
+                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
+                "all-gather ops. Your training will be incorrect, but "
+                "can reveal how much time spent on all-gather ops.",
+            )
+        if self._use_fake_reduce:
+            _warn_use_fake_reduce(
+                log,
+                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
+                "reduce-scatter ops. Your training will be incorrect, but "
+                "can reveal how much time spent on reduce-scatter ops.",
+            )
+        # Only align addresses for `use_orig_params=True` (for now)
+        align_addresses = use_orig_params
+        self._init_get_unflat_views_fn(align_addresses)
+        self.device = device
+        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
+        self.process_group = process_group
+        if self._use_fake_all_gather or self._use_fake_reduce:
+            self._fake_process_group = FakeProcessGroup(
+                rank=process_group.rank(), world_size=process_group.size()
+            )
+        self.rank = process_group.rank()
+        self.world_size = process_group.size()
+        self._sharding_strategy = sharding_strategy
+        self._offload_params = offload_params
+        self._use_orig_params = use_orig_params
+        self._keep_low_precision_grads = keep_low_precision_grads
+        self._training_state = HandleTrainingState.IDLE
+        self._debug_level = dist.get_debug_level()
+        self._fully_sharded_module = fully_sharded_module
+        # For strategies that do not free after forward, we skip using sharded
+        # views after forward since the unsharded data exists. We still switch
+        # `self.flat_param` to point to the sharded flat parameter since what
+        # it points to parameterizes behavior. We use the following attribute
+        # to track which tensor data the parameters are unsharded views into.
+        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
+        # The index in the state's `all_handles`, which must be the
+        # same across ranks for the execution order validation to work
+        self._handle_index: Optional[int] = None
+        # Index in handles_to_pre_forward_order
+        self._pre_forward_order_index: Optional[int] = None
+        # Index in `handles_post_forward_order`
+        self._post_forward_index: Optional[int] = None
+        # Used for guarding against mistargeted forward prefetches
+        self._needs_pre_forward_unshard = False
+        # Used for guarding against mistargeted backward prefetches
+        self._needs_pre_backward_unshard = False
+        # Was the handle prefetched? Set on successful _prefetch_handle and unshard
+        self._prefetched = False
+        # Optimistically assume a valid input `params` and set dtype attributes
+        # before `_init_flat_param()`, which performs the actual validation
+        self._orig_param_dtype = params[0].dtype
+        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
+        assert self._fwd_bwd_param_dtype is not None  # mypy
+        self._aligned_numel = (
+            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
+            if align_addresses
+            else 0
+        )
+        self._fsdp_extension = fsdp_extension
+        self._init_flat_param_and_metadata(
+            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
+        )
+        self._use_unsharded_views(as_params=False)
+
+    def _init_setattr_fns(self):
+        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
+        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
+        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
+        if use_unsafe_setattr:
+            self._setattr_tensor = _unsafe_setattr_tensor
+            self._setattr_param = _unsafe_setattr_param
+        else:
+            self._setattr_tensor = _safe_setattr_tensor_or_param
+            self._setattr_param = _safe_setattr_tensor_or_param
+
+    def _init_get_unflat_views_fn(self, align_addresses: bool):
+        self._get_unflat_views = (
+            self._get_unflat_views_aligned
+            if align_addresses
+            else self._get_unflat_views_unaligned
+        )
+
+    def _init_flat_param_and_metadata(
+        self,
+        params: List[Union[Tensor, nn.Parameter]],
+        module: nn.Module,
+        aligned_numel: int,
+        use_orig_params: bool,
+    ) -> None:
+        """
+        Initialize the ``FlatParameter`` and its metadata.
+
+        NOTE: This should only be called once at construction time, after which
+        the ``FlatParameter`` metadata is assumed to be static.
+
+        NOTE: The elements of ``params`` should only be ``Tensor`` s when
+        composing with ``DTensor`` -based tensor parallelism, in which case the
+        elements may be ``DTensor`` local shards.
+        """
+        if len(params) == 0:
+            raise ValueError("Expects non-empty `params`")
+        if aligned_numel < 0:
+            raise ValueError(
+                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
+            )
+        (
+            dtype,
+            flat_param_requires_grad,
+            device,
+        ) = self._validate_tensors_to_flatten(params)
+        params_set = set(params)
+        # For alignment padding, only `numels`, `params_to_flatten`, and
+        # `is_padding_mask` get entries; the other per-parameter lists do not
+        # include entries for padding.
+        param_infos: List[ParamInfo] = []
+        numels: List[int] = []
+        shapes: List[torch.Size] = []
+        fqns: List[str] = []
+        shared_param_infos: List[SharedParamInfo] = []
+        shared_param_memo: Dict[
+            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
+        ] = {}
+        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
+        shared_params: List[Union[Tensor, nn.Parameter]] = []
+        param_extensions: List[Any] = []
+        is_padding_mask: List[bool] = []
+        total_numel = total_numel_without_padding = 0
+        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
+            for param_name, param in _named_parameters_with_duplicates(
+                submodule, recurse=False
+            ):
+                if param not in params_set:
+                    continue
+                if param in shared_param_memo:  # shared reference
+                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
+                        param
+                    ]
+                    shared_params.append(param)
+                    shared_param_infos.append(
+                        SharedParamInfo(
+                            param_name,
+                            submodule,
+                            submodule_name,
+                            prim_param_name,
+                            prim_module,
+                            prim_module_name,
+                        )
+                    )
+                else:
+                    if aligned_numel > 0:
+                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
+                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
+                            padding_tensor = _construct_padding_tensor(
+                                numel_to_pad, dtype, False, device
+                            )
+                            params_to_flatten.append(padding_tensor)
+                            is_padding_mask.append(True)
+                            numels.append(numel_to_pad)
+                            total_numel += numel_to_pad
+                    transform_t, extension = _ext_pre_flatten_transform(
+                        param,
+                        self._fsdp_extension,
+                    )
+                    param = cast(nn.Parameter, transform_t)
+                    param_extensions.append(extension)
+                    shared_param_memo[param] = (submodule, submodule_name, param_name)
+                    params_to_flatten.append(param)
+                    is_padding_mask.append(False)
+                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
+                    numels.append(param.numel())
+                    shapes.append(param.shape)
+                    fqn = (
+                        submodule_name + "." + param_name
+                        if submodule_name
+                        else param_name
+                    )
+                    fqns.append(fqn)
+                    total_numel += param.numel()
+                    total_numel_without_padding += param.numel()
+        if len(params_to_flatten) == 0:
+            raise ValueError(
+                f"`params` were not found in `module`'s tree"
+                f"params: {params}\nmodule: {module}"
+            )
+        if (
+            self.rank == 0
+            and aligned_numel > 0
+            and total_numel != total_numel_without_padding
+        ):
+            log.info(
+                "FSDP FlatParameter address alignment created "
+                "%s numel of padding (%s vs. %s)",
+                total_numel - total_numel_without_padding,
+                total_numel,
+                total_numel_without_padding,
+            )
+        if aligned_numel > 0:
+            # Pad to be divisible by world size to avoid a copy for the
+            # post-backward reduce-scatter
+            numel_to_pad = self.world_size - (total_numel % self.world_size)
+            if numel_to_pad > 0 and numel_to_pad < self.world_size:
+                if self.rank == 0:
+                    log.info(
+                        "FSDP FlatParameter world size divisibility created "
+                        "%s numel of padding",
+                        numel_to_pad,
+                    )
+                padding_tensor = _construct_padding_tensor(
+                    numel_to_pad, dtype, False, device
+                )
+                params_to_flatten.append(padding_tensor)
+                is_padding_mask.append(True)
+                numels.append(numel_to_pad)
+                total_numel += numel_to_pad
+        # Pass `aligned_numel=0` since we already included padding tensors
+        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
+            params_to_flatten,
+            aligned_numel=0,
+            requires_grad=flat_param_requires_grad,
+        )
+        FlatParameter._init_metadata(
+            self.flat_param,
+            param_infos,
+            numels,
+            shapes,
+            fqns,
+            shared_param_infos,
+            param_extensions,
+            _convert_to_params(params_to_flatten) if use_orig_params else None,
+            _convert_to_params(shared_params) if use_orig_params else None,
+            is_padding_mask,
+        )
+
+    def _validate_tensors_to_flatten(
+        self, tensors: List[Union[Tensor, nn.Parameter]]
+    ) -> Tuple:
+        """Validate the tensors to flatten and returns any necessary metadata."""
+        dtype: Optional[torch.dtype] = None
+        # Return as the logical OR over each tensor's value
+        flat_param_requires_grad: Optional[bool] = None
+        device: Optional[torch.device] = None
+        # For `use_orig_params=True`, permit non-uniform `requires_grad`
+        for tensor in tensors:
+            if isinstance(tensor, FlatParameter):
+                raise ValueError("Cannot flatten a `FlatParameter`")
+            if dtype is None and not tensor.is_floating_point():
+                raise ValueError("Cannot flatten integer dtype tensors")
+            if dtype is not None and tensor.dtype != dtype:
+                raise ValueError(
+                    f"Must flatten tensors with uniform dtype but got {dtype} "
+                    f"and {tensor.dtype}"
+                )
+            if (
+                not self._use_orig_params
+                and flat_param_requires_grad is not None
+                and tensor.requires_grad != flat_param_requires_grad
+            ):
+                raise ValueError(
+                    "Must flatten tensors with uniform `requires_grad` when "
+                    "`use_orig_params=False`"
+                )
+            if device is not None and tensor.device != device:
+                raise ValueError(
+                    "Must flatten tensors on the same device but got both "
+                    f"{device} and {tensor.device}"
+                )
+            dtype = tensor.dtype
+            flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
+            device = tensor.device
+        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
+        return dtype, flat_param_requires_grad, device
+
+    def flatten_tensors(
+        self,
+        tensors: List[Tensor],
+        aligned_numel: int,
+    ) -> Tensor:
+        """
+        Flatten ``tensors`` into a single flat tensor.
+
+        The flattening optionally includes
+        padding if ``aligned_numel`` is greater than 0, where ``aligned_numel``
+        gives the numel required to have address alignment.
+
+        NOTE: The padding alignment algorithm must be kept in sync with
+        :meth:`_init_flat_param_metadata`. We separate the two methods because
+        the initialization happens once, whereas this method may be called
+        multiple times throughout training (e.g. for checkpointing).
+        """
+        if len(tensors) == 0:
+            raise ValueError("Expects non-empty `tensors`")
+        if aligned_numel < 0:
+            raise ValueError(
+                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
+            )
+        dtype, _, device = self._validate_tensors_to_flatten(tensors)
+        flat_tensors: List[Tensor] = []
+        if aligned_numel > 0:
+            total_numel = 0
+            for tensor in tensors:
+                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
+                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
+                    padding_tensor = _construct_padding_tensor(
+                        numel_to_pad, dtype, False, device
+                    )
+                    flat_tensors.append(padding_tensor)
+                    total_numel += numel_to_pad
+                flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
+                total_numel += tensor.numel()
+            numel_to_pad = self.world_size - (total_numel % self.world_size)
+            if numel_to_pad > 0 and numel_to_pad < self.world_size:
+                padding_tensor = _construct_padding_tensor(
+                    numel_to_pad, dtype, False, device
+                )
+                flat_tensors.append(padding_tensor)
+                total_numel += numel_to_pad
+        else:
+            flat_tensors = [
+                torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
+            ]
+        return torch.cat(flat_tensors, dim=0)
+
+    def flatten_tensors_into_flat_param(
+        self,
+        tensors: List[Tensor],
+        aligned_numel: int,
+        requires_grad: bool,
+    ) -> FlatParameter:
+        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
+        return FlatParameter(flat_param_data, requires_grad=requires_grad)
+
+    def _init_param_reduce_dtypes(
+        self,
+        mp_param_dtype: Optional[torch.dtype],
+        mp_reduce_dtype: Optional[torch.dtype],
+    ) -> None:
+        """
+        Initialize param and reduce dtypes.
+
+        Precondition: ``self.flat_param`` is set. This ensures that this
+        handle's parameters have a single dtype.
+
+        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
+        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
+        is ``None``, then we assume the original parameter dtype. One special
+        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
+        is ``None``, in which case we assume the gradient reduction dtype
+        matches the forward/backward parameter dtype.
+        """
+        # Save whether these dtypes were specified so that we permit the
+        # parameter dtype to change up until the lazy initialization
+        self._low_prec_param_dtype_specified = mp_param_dtype is not None
+        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
+        if (
+            self._low_prec_param_dtype_specified
+            and not self._low_prec_reduce_dtype_specified
+        ):
+            # Special case: infer gradient reduction mixed precision
+            self._fwd_bwd_param_dtype = mp_param_dtype
+            self._reduce_dtype = self._fwd_bwd_param_dtype
+        else:
+            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
+            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
+        assert self._fwd_bwd_param_dtype is not None
+        assert self._reduce_dtype is not None
+
+    ###################################
+    # SHARD INITIALIZATION & METADATA #
+    ###################################
+    @torch.no_grad()
+    def shard(self):
+        """
+        Shard the handle's ``FlatParameter``.
+
+        This allocates new memory for
+        the sharded flat parameter and frees the unsharded flat parameter's
+        storage.
+
+        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
+        metadata attributes are set for all sharding strategies.
+        """
+        flat_param = self.flat_param
+        if not self.uses_sharded_strategy:
+            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
+        else:
+            _p_assert(
+                flat_param.storage_offset() == 0,
+                "The `FlatParameter` is not the sole occupant of its storage",
+            )
+            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
+                flat_param, self.rank, self.world_size
+            )
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                allocated = flat_param._typed_storage()._size() > 0
+                if allocated:
+                    flat_param._typed_storage()._resize_(0)
+            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
+            start_idx = sharded_flat_param.numel() * self.rank
+            end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
+            self._init_shard_metadata(numel_padded, start_idx, end_idx)
+        if self._use_orig_params:
+            self._use_sharded_views()
+
+    def _init_shard_metadata(
+        self,
+        numel_padded: int,
+        unsharded_start_idx: int,
+        unsharded_end_idx: int,
+    ) -> None:
+        """
+        Initialize shard-related metadata for this rank's shard of the flat parameter.
+
+        This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``.
+
+        Args:
+            numel_padded (int): Numel padded for this rank's sharded flat
+                parameter.
+            unsharded_start_idx (int): Start index in the unsharded flat
+                parameter assigned to this rank.
+            unsharded_end_idx (int): End index (inclusive) in the unsharded
+                flat parameter assigned to this rank.
+
+        Precondition: ``self.flat_param`` 's data is the sharded flat
+        parameter.
+        """
+        flat_param = self.flat_param
+        flat_param._sharded_size = flat_param.size()  # type: ignore[attr-defined]
+        sharded_flat_param_numel = flat_param.numel()  # includes `numel_padded`
+        _p_assert(
+            unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
+            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
+        )
+        _p_assert(
+            numel_padded <= sharded_flat_param_numel,
+            f"numel_padded: {numel_padded} "
+            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
+        )
+        shard_param_infos = self._get_shard_metadata(
+            unsharded_start_idx, unsharded_end_idx
+        )
+        assert (
+            len(shard_param_infos) == flat_param._num_params
+        ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
+        flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
+        flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
+
+    def _get_shard_metadata(
+        self,
+        unsharded_start_idx: int,
+        unsharded_end_idx: int,
+    ) -> Tuple[_ShardParamInfo, ...]:
+        """
+        Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).
+
+        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the
+        unsharded flat parameter specifying the shard.
+        """
+        flat_param_offsets = self._get_flat_param_offsets()
+        assert len(flat_param_offsets) == len(
+            self.flat_param._numels_with_padding
+        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
+        shard_param_infos: List[_ShardParamInfo] = []
+        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
+        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
+        # into the unsharded flat parameter (inclusive) of the given parameter
+        for i, (
+            (unsharded_param_start_idx, unsharded_param_end_idx),
+            is_padding,
+        ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
+            if is_padding:
+                continue
+            in_sharded_flat_param = (
+                unsharded_start_idx <= unsharded_param_end_idx
+                and unsharded_end_idx >= unsharded_param_start_idx
+            )
+            if not in_sharded_flat_param:
+                shard_param_info = _ShardParamInfo(False, None, None, None, None)
+            else:
+                if unsharded_start_idx <= unsharded_param_start_idx:
+                    # This branch can only happen once since the rank's
+                    # unsharded start index can only intersect one parameter
+                    intra_param_start_idx = 0
+                    offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
+                else:
+                    intra_param_start_idx = (
+                        unsharded_start_idx - unsharded_param_start_idx
+                    )
+                    offset_in_shard = 0
+                assert (
+                    offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
+                ), (
+                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
+                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
+                )
+                intra_param_end_idx = (
+                    min(unsharded_param_end_idx, unsharded_end_idx)
+                    - unsharded_param_start_idx
+                )
+                numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
+                shard_param_info = _ShardParamInfo(
+                    True,
+                    offset_in_shard,
+                    numel_in_shard,
+                    intra_param_start_idx,
+                    intra_param_end_idx,
+                )
+            shard_param_infos.append(shard_param_info)
+        return tuple(shard_param_infos)
+
+    @staticmethod
+    def _get_unpadded_shard(
+        tensor: Tensor,
+        rank: int,
+        world_size: int,
+    ) -> Tuple[Tensor, int]:
+        """
+        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.
+
+        The returned value is a tuple of the shard of ``tensor`` without any
+        padding and the numel to pad for that shard.
+
+        If ``tensor`` is already flattened or may be viewed in the flattened
+        shape (which is true in the expected usage), then this method does not
+        allocate any new tensor memory.
+        """
+        chunks = torch.flatten(tensor).chunk(world_size)
+        if len(chunks) < (rank + 1):
+            # This rank gets an empty chunk fully padded with zeros since there
+            # are not enough chunks across ranks
+            chunk = chunks[0].new_empty(0)
+        else:
+            chunk = chunks[rank]
+        numel_to_pad = chunks[0].numel() - chunk.numel()
+        assert (
+            numel_to_pad >= 0
+        ), "Chunk's size should be at most the first chunk's size"
+        return chunk, numel_to_pad
+
+    @staticmethod
+    def _get_shard(
+        tensor: Tensor,
+        rank: int,
+        world_size: int,
+    ) -> Tuple[Tensor, int]:
+        """
+        Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard.
+
+        This method allocates new memory (via :meth:`clone`) since the
+        unsharded ``tensor`` may be deallocated after this method returns.
+        """
+        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+            tensor, rank, world_size
+        )
+        shard = chunk.clone()
+        if numel_to_pad > 0:
+            shard = F.pad(shard, [0, numel_to_pad])
+        return shard, numel_to_pad
+
+    @staticmethod
+    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
+        """
+        Return the shape of ``tensor`` after sharding, including padding.
+
+        This requires ``tensor`` to have 1D shape and ensures that the returned
+        shape is 1D.
+        """
+        assert len(tensor.shape) == 1, f"{tensor.shape}"
+        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+            tensor, rank, world_size
+        )
+        unpadded_sharded_size = unpadded_sharded_tensor.size()
+        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
+        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
+
+    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
+        """
+        Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
+
+        NOTE: The returned list includes elements for alignment padding.
+        """
+        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
+        starts = [0] + cumulative_sum[:-1]
+        ends = [end - 1 for end in cumulative_sum]  # inclusive
+        param_offsets = list(zip(starts, ends))
+        return param_offsets
+
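+    # Example (a hedged sketch with assumed numels): if
+    # ``flat_param._numels_with_padding == (4, 2, 3)`` (where the 2 might be
+    # alignment padding), then ``accumulate`` yields [4, 6, 9], so
+    # ``starts == [0, 4, 6]``, ``ends == [3, 5, 8]``, and the returned offsets
+    # are ``[(0, 3), (4, 5), (6, 8)]`` (inclusive on both ends).
+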
+    @no_type_check
+    def shard_metadata(
+        self,
+    ) -> FlatParamShardMetadata:
+        """
+        Return the shard-related metadata specific to this rank's shard of the flat parameter.
+
+        NOTE: The returned tuple does not include elements for alignment
+        padding but does account for the padding.
+        """
+        fqns_list = []
+        shapes_list = []
+        numels_list = []
+        shard_param_offsets = []
+        for fqn, shape, numel, shard_param_info in zip(
+            self.flat_param._fqns,
+            self.flat_param._shapes,
+            self.flat_param._numels,
+            self.flat_param._shard_param_infos,
+        ):
+            if not shard_param_info.in_shard:
+                continue
+            fqns_list.append(fqn)
+            shapes_list.append(shape)
+            numels_list.append(numel)
+            shard_param_offsets.append(
+                (
+                    shard_param_info.intra_param_start_idx,
+                    shard_param_info.intra_param_end_idx,
+                )
+            )
+        return FlatParamShardMetadata(
+            tuple(fqns_list),
+            tuple(shapes_list),
+            tuple(numels_list),
+            shard_param_offsets,
+        )
+
+    @no_type_check
+    @torch.no_grad()
+    def init_flat_param_attributes(self) -> None:
+        """
+        Initialize some attributes on the handle's ``FlatParameter``.
+
+        This should be called during lazy initialization since it requires the
+        parameter to be on the compute device if not offloading to CPU, and we
+        want to give users the chance to move the parameter appropriately after
+        the FSDP constructor.
+
+        For each tensor attribute on the ``FlatParameter``, see the unshard and
+        reshard methods in this class for the allocation and free pattern.
+        """
+        flat_param = self.flat_param
+        if flat_param.dtype != self._orig_param_dtype:
+            # Entering this branch means that the user changed the parameter
+            # dtype after FSDP initialization, in which case we may need to
+            # refresh some saved dtype attributes (dtypes specified as a part
+            # of mixed precision take precedence).
+            if not self._low_prec_param_dtype_specified:
+                self._fwd_bwd_param_dtype = flat_param.dtype
+            # For `reduce_dtype`, require `param_dtype` was not specified since
+            # then we infer the `reduce_dtype` from the specified `param_dtype`
+            if (
+                not self._low_prec_reduce_dtype_specified
+                and not self._low_prec_param_dtype_specified
+            ):
+                self._reduce_dtype = flat_param.dtype
+            self._orig_param_dtype = flat_param.dtype
+        cpu_device = torch.device("cpu")
+        if self._offload_params:
+            _p_assert(
+                flat_param.device == cpu_device,
+                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
+                f"offloading is enabled, not {flat_param.device}",
+            )
+        else:
+            self._check_on_compute_device(self.flat_param)
+        flat_param._local_shard = flat_param.data
+        if self._offload_params:
+            # Pin the memory for faster H2D transfer
+            flat_param._local_shard = flat_param._local_shard.pin_memory()
+            # Pre-allocate the sharded gradient on CPU to enable non-blocking
+            # D2H transfer during the backward pass
+            flat_param._cpu_grad = torch.zeros_like(
+                flat_param._local_shard, device=cpu_device
+            ).pin_memory()
+        if self._uses_param_mixed_precision:
+            # For parameter mixed precision, we maintain a low precision
+            # sharded tensor on the compute device to be all-gathered (for
+            # sharded strategies) or directly used (for `NO_SHARD`) for
+            # computation.
+            flat_param._mp_shard = torch.empty_like(
+                flat_param._local_shard,
+                device=self.device,
+                dtype=self._fwd_bwd_param_dtype,
+            )
+            _free_storage(flat_param._mp_shard)
+        if self.uses_sharded_strategy:
+            # We maintain a padded unsharded tensor that serves as the
+            # all-gather destination and owns the original parameter storages.
+            unsharded_param_dtype = (
+                self._fwd_bwd_param_dtype
+                if self._uses_param_mixed_precision
+                else flat_param.dtype
+            )  # use low precision if parameter mixed precision is enabled
+            padded_unsharded_numel = flat_param.numel() * self.world_size
+            flat_param._full_param_padded = torch.empty(
+                padded_unsharded_numel,
+                device=self.device,
+                dtype=unsharded_param_dtype,
+            )
+            flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
+            _free_storage(flat_param._full_param_padded)
+
+            if self._uses_param_mixed_precision:
+                # For parameter mixed precision, we maintain a full precision
+                # padded unsharded tensor for when we force full precision.
+                flat_param._full_prec_full_param_padded = torch.empty(
+                    padded_unsharded_numel,
+                    device=self.device,
+                    dtype=flat_param.dtype,  # full precision
+                )
+                _free_storage(flat_param._full_prec_full_param_padded)
+
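+    # Example (a hedged sketch of the reserve-then-free pattern above; the
+    # private ``_free_storage``/``_alloc_storage`` helpers roughly amount to
+    # resizing a tensor's underlying storage while keeping its metadata):
+    #   >>> buf = torch.empty(8)                        # float32: 32 bytes
+    #   >>> _ = buf.untyped_storage().resize_(0)        # free backing memory
+    #   >>> buf.untyped_storage().size()
+    #   0
+    #   >>> _ = buf.untyped_storage().resize_(buf.numel() * buf.element_size())
+    #   >>> buf.untyped_storage().size()                # re-allocated lazily
+    #   32
+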
+    ###################
+    # UNSHARD/RESHARD #
+    ###################
+    def pre_unshard(self) -> bool:
+        """
+        Return ``False`` if this is a no-op and ``True`` otherwise.
+
+        Postcondition: ``self.flat_param`` 's data is on the device for
+        communication and is what should be all-gathered. This means that it
+        matches the dtype of the expected unsharded parameter.
+        """
+        if (
+            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
+            and self._skipped_use_sharded_views
+        ):
+            # Since this path imposes special semantics for the unsharded flat
+            # parameter (e.g. forcing full precision), use sharded views to
+            # reuse the existing logic for that special handling
+            self._use_sharded_views()
+        ret = False
+        if self._use_orig_params and not self._skip_writeback_check:
+            ret = self._writeback_orig_params()
+        if (
+            self.uses_sharded_strategy
+            and not self._offload_params
+            and not self.needs_unshard()
+        ):
+            pass  # no-op
+        elif self._uses_param_mixed_precision and not self._force_full_precision:
+            self._use_low_precision_shard()
+            ret = True
+        elif self._offload_params and self.flat_param.device != self.device:
+            # NOTE: This creates a new tensor distinct from any attributes.
+            self.flat_param_to(self.device, non_blocking=True)
+            ret = True
+        self._check_on_compute_device(self.flat_param)
+        return ret
+
+    def _use_low_precision_shard(self):
+        """Allocate on the compute device and switch to using the low precision sharded flat parameter."""
+        self._check_low_precision_shard()
+        flat_param = self.flat_param
+        _alloc_storage(
+            flat_param._mp_shard, flat_param._local_shard.size()  # type: ignore[attr-defined]
+        )
+        # `copy_()` implicitly casts to the low precision
+        flat_param._mp_shard.copy_(  # type: ignore[attr-defined]
+            flat_param._local_shard.to(  # type: ignore[attr-defined]
+                self.device, non_blocking=True
+            )
+        )
+        # Invariant: `_mp_shard` is always on the compute device.
+        flat_param.data = flat_param._mp_shard  # type: ignore[attr-defined]
+
+    def unshard(self):
+        """
+        Run the unshard logic.
+
+        This includes all-gathering the flat parameter
+        and switching to using the unsharded flat parameter. If the handle does
+        not need unsharding, then this only switches to using the unsharded
+        flat parameter. For ``NO_SHARD``, this is a no-op.
+
+        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
+        mixed precision, then the parameter is forced to full precision.
+        """
+        if not self.needs_unshard():
+            # Even when not needing an unshard, we should switch to using
+            # the unsharded flat parameter
+            unsharded_flat_param = (
+                self._get_padded_unsharded_flat_param()
+                if self.uses_sharded_strategy
+                else self.flat_param
+            )
+            self._use_unsharded_flat_param(unsharded_flat_param)
+            return
+        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
+        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
+        self._use_unsharded_flat_param(padded_unsharded_flat_param)
+
+    def needs_unshard(self) -> bool:
+        """Return if the handle's flat parameter needs to be unsharded."""
+        if not self.uses_sharded_strategy:
+            return False
+        unsharded_flat_param = self._get_padded_unsharded_flat_param()
+        already_unsharded = _same_storage_size(
+            unsharded_flat_param, unsharded_flat_param.numel()
+        )
+        return not already_unsharded
+
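+    # Example (a hedged sketch of the storage-size check above;
+    # ``_same_storage_size`` essentially compares a tensor's allocated storage
+    # against a target numel):
+    #   >>> padded = torch.empty(8)
+    #   >>> padded.untyped_storage().size() // padded.element_size() == padded.numel()
+    #   True
+    #   >>> _ = padded.untyped_storage().resize_(0)     # freed after reshard
+    #   >>> padded.untyped_storage().size() // padded.element_size() == padded.numel()
+    #   False
+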
+    def _alloc_padded_unsharded_flat_param(self):
+        """
+        Allocate the *padded* unsharded flat parameter.
+
+        The unpadded unsharded flat parameter is always a view into the padded
+        one. This padded parameter is saved to a different attribute on the
+        ``FlatParameter`` depending on whether we force full precision.
+        """
+        self._check_sharded_strategy()
+        flat_param = self.flat_param
+        unsharded_flat_param = self._get_padded_unsharded_flat_param()
+        self._check_storage_freed(unsharded_flat_param)
+        _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
+        return unsharded_flat_param
+
+    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
+        """
+        Return a reference to the padded unsharded flat parameter depending on the calling context.
+
+        This should only be called if using a sharded strategy.
+        """
+        self._check_sharded_strategy()
+        flat_param = self.flat_param
+        if self._force_full_precision and self._uses_param_mixed_precision:
+            # When parameter mixed precision is enabled, we use a different
+            # tensor as the all-gather destination to preserve the invariant
+            # that  `_full_param_padded` is in the low precision
+            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
+            _p_assert(
+                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
+                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
+            )
+            # For no-reshard-after-forward strategies, `_full_param_padded` may
+            # still be allocated from a previous forward. As we are forcing
+            # full precision here, the full-precision unsharded copy may be
+            # modified, invalidating the existing low-precision unsharded copy,
+            # so we should free it here to ensure a new all-gather for the next
+            # forward/backward computation to persist the modifications.
+            if flat_param._full_param_padded.untyped_storage().size() > 0:
+                _free_storage(flat_param._full_param_padded)
+        else:
+            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
+        return unsharded_flat_param
+
+    def _all_gather_flat_param(
+        self,
+        padded_unsharded_flat_param: Tensor,
+    ) -> Tensor:
+        """
+        All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
+
+        The caller should then switch to using the all-gathered tensor.
+        """
+        _p_assert(
+            hasattr(self, "process_group") and hasattr(self, "world_size"),
+            "Expects a process group and world size to have been set via `shard()`",
+        )
+        sharded_flat_param = self.flat_param.data
+        expected_numel = sharded_flat_param.numel() * self.world_size
+        _p_assert(
+            padded_unsharded_flat_param.numel() == expected_numel,
+            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
+        )
+
+        pg = (
+            self._fake_process_group
+            if self._use_fake_all_gather
+            else self.process_group
+        )
+
+        # HACK this should be handled by C10D
+        if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
+            tensor_list = list(
+                torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
+            )
+            work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
+        else:
+            dist.all_gather_into_tensor(
+                padded_unsharded_flat_param,
+                sharded_flat_param,
+                pg,
+            )
+
+        if self._offload_params:
+            # In case of offloading, `flat_param.data` (i.e. sharded param) is
+            # created on the pre-unshard stream. We need to hand it over to the
+            # unshard stream for all-gather
+            _no_dispatch_record_stream(
+                sharded_flat_param,
+                self._device_handle.current_stream(),  # unshard_stream
+            )
+        return padded_unsharded_flat_param
+
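+    # Example (a hedged sketch of the numel invariant asserted above, simulating
+    # the all-gather with ``torch.cat`` instead of a real process group):
+    #   >>> world_size = 4
+    #   >>> sharded = torch.randn(3)                     # this rank's shard
+    #   >>> gathered = torch.cat([sharded] * world_size) # stand-in for all-gather
+    #   >>> gathered.numel() == sharded.numel() * world_size
+    #   True
+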
+    def _use_unsharded_flat_param(
+        self,
+        padded_unsharded_flat_param: torch.Tensor,
+    ) -> None:
+        """
+        Switch to use the *unpadded* unsharded flat parameter.
+
+        This is a view into the *padded* unsharded flat parameter.
+        """
+        unsharded_size = self.flat_param._unpadded_unsharded_size
+        flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
+        # slicing [:] is not visible to autograd because of .data
+        self.flat_param.data = flat_param_part
+        in_forward = self._training_state == HandleTrainingState.FORWARD
+        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
+        if self._use_orig_params:
+            if self._skipped_use_sharded_views and in_pre_backward:
+                # This call corresponds to the complementary pre-backward
+                # `_use_unsharded_views()` to the skipped pre-forward
+                # `_use_sharded_views()`, so we should skip this one too.
+                return
+            # We use `Tensor` views in the forward so that they are tracked by
+            # autograd. We use them in the pre-backward as well to support
+            # reentrant activation checkpointing, which needs the views to be
+            # tracked by autograd in the backward pass's recomputed forward.
+            self._use_unsharded_views(
+                as_params=(not in_forward and not in_pre_backward)
+            )
+        elif in_forward:
+            self._use_unsharded_views(as_params=False)
+
+    def post_unshard(self):
+        """
+        Run the post-unshard logic.
+
+        This includes freeing the low precision shard if needed.
+        """
+        if self._uses_param_mixed_precision and self.uses_sharded_strategy:
+            self._free_low_precision_sharded_param()
+        self._check_on_compute_device(self.flat_param)
+
+    def _free_low_precision_sharded_param(self):
+        """Frees the low precision sharded flat parameter."""
+        self._check_low_precision_shard()
+        # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
+        # unshard stream for sharded strategies, and consumed in both the
+        # unshard and default streams for `NO_SHARD`. For sharded strategies,
+        # the current stream here is the unshard stream, and for `NO_SHARD`,
+        # it is the default stream. For `NO_SHARD`, only recording for the
+        # default stream suffices since the default stream waits for the
+        # unshard stream.
+        _no_dispatch_record_stream(
+            self.flat_param._mp_shard, self._device_handle.current_stream()  # type: ignore[attr-defined]
+        )
+        _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]
+
+    @torch.no_grad()
+    def unshard_grad(self):
+        """
+        Unshard the handle's ``FlatParameter``'s gradient.
+
+        If all ranks have a ``None`` gradient, then all of the original
+        parameters' gradients will be ``None`` as well. This method performs an
+        all-reduce and an all-gather. The additional all-reduce is tolerable
+        since this method is not meant to be used on the computation critical
+        path.
+
+        Postcondition: ``_saved_grad_shard`` is defined and contains the value
+        to set ``flat_param.grad`` after gradients are resharded.
+        """
+        if not self.uses_sharded_strategy:
+            self._use_unsharded_grad_views()
+            return
+        flat_param = self.flat_param
+        self._check_unsharded(flat_param)
+
+        # Check if all ranks have a `None` gradient
+        num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
+        num_grad_none[0] = flat_param.grad is None
+        dist.all_reduce(num_grad_none, group=self.process_group)
+        if num_grad_none[0] == self.world_size:
+            flat_param._saved_grad_shard = None  # type: ignore[assignment]
+            self._use_unsharded_grad_views()
+            return
+
+        if flat_param.grad is None:
+            # In the case that only some but not all ranks have a `None`
+            # gradient, we use zeros as a best-effort approximation
+            if self._debug_level == dist.DebugLevel.INFO:
+                warnings.warn(
+                    f"[Rank {self.rank}] Only some but not all ranks have a "
+                    "`None` `FlatParameter` gradient, so FSDP is using zeros to "
+                    "approximate those ranks' sharded gradients being `None`"
+                )
+            flat_param._saved_grad_shard = None  # type: ignore[assignment]
+            sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device)  # type: ignore[attr-defined]
+        else:
+            self._check_sharded(flat_param.grad)
+            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
+            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        padded_unsharded_grad = torch.empty(
+            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
+            device=self.device,
+            dtype=sharded_grad.dtype,
+        )
+        dist.all_gather_into_tensor(
+            padded_unsharded_grad, sharded_grad, self.process_group
+        )
+        unsharded_size = self.flat_param._unpadded_unsharded_size
+        flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
+            unsharded_size
+        )
+        self._use_unsharded_grad_views()
+
+    def reshard_grad(self):
+        if self._use_orig_params:
+            self._use_sharded_grad_views()
+        if not self.uses_sharded_strategy:
+            return
+        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        delattr(self.flat_param, "_saved_grad_shard")
+
+    def prepare_gradient_for_backward(self):
+        """
+        Prepare the gradient for the backward computation.
+
+        This is done by saving and clearing any existing sharded gradient
+        in ``.grad`` to enable computing a new unsharded gradient.
+        """
+        _p_assert(
+            self._training_state
+            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
+            "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
+        )
+        flat_param = self.flat_param
+        if flat_param.grad is not None and (
+            flat_param.grad.size() != flat_param._unpadded_unsharded_size
+            or flat_param.grad.device != flat_param.device  # grad on CPU
+        ):
+            self._check_on_compute_device(self.flat_param)
+            grad_offloaded = flat_param.grad.device != self.device
+            _p_assert(
+                not grad_offloaded or self._offload_params,
+                f"Expects the sharded gradient to be on {self.device} "
+                f"but got {flat_param.grad.device}",
+            )
+            prev_iter_synced_gradients = (
+                flat_param.grad.size()
+                == flat_param._local_shard.size()  # type: ignore[attr-defined]
+            )
+            if prev_iter_synced_gradients:
+                # TODO (awgu): Gradient accumulation outside `no_sync()`
+                # does not work with CPU offloading. The issue should be
+                # that, in the post-backward hook, we cannot do an addition
+                # between a CPU tensor (the existing sharded gradient) and
+                # a GPU tensor (the new sharded gradient).
+                if not grad_offloaded:
+                    flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
+                    sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+                else:
+                    _p_assert(
+                        hasattr(flat_param, "_cpu_grad"),
+                        "`_cpu_grad` should be defined if the gradient is on CPU",
+                    )
+                    sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
+                # If user specified to keep the gradient in low precision, then
+                # the gradient may still be of the low precision dtype if the
+                # user did not set the gradient to `None` after the previous
+                # backward, in which case FSDP should cast back to the full
+                # precision dtype so that FSDP can accumulate in that dtype in
+                # the post-backward hook and assign to `.grad` in that dtype in
+                # the post-backward callback.
+                local_shard_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
+                if (
+                    self._keep_low_precision_grads
+                    and sharded_grad.dtype != local_shard_dtype
+                ):
+                    sharded_grad.data = sharded_grad.to(local_shard_dtype)
+            else:
+                padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
+                _p_assert(
+                    flat_param.grad.size() == padded_unsharded_size,
+                    "Expects `.grad` to be the unsharded gradient in "
+                    f"`no_sync()` with size {padded_unsharded_size} "
+                    f"but got size {flat_param.grad.size()}",
+                )
+            flat_param.grad = None
+
+    def prepare_gradient_for_optim(self):
+        """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
+
+        def cast_grad_to_param_dtype_if_needed(flat_param):
+            # TODO (rohan-varma): test for full precision with keep_low_precision_grads
+            if not self._force_full_precision and self._keep_low_precision_grads:
+                _p_assert(flat_param.grad is not None, "Unexpected None grad!")
+                if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
+                    flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
+                    if self._use_orig_params:
+                        self._use_sharded_grad_views()
+
+        flat_param = self.flat_param
+        # TODO (awgu): We should replace these conditional checks to encode
+        # the logical intention more directly.
+        if hasattr(flat_param, "_cpu_grad"):
+            # NOTE: This branch includes `NO_SHARD`.
+            self._check_sharded(flat_param)
+            self._check_on_cpu(flat_param)
+            flat_param.grad = flat_param._cpu_grad  # type: ignore[attr-defined]
+            cast_grad_to_param_dtype_if_needed(flat_param)
+        elif hasattr(flat_param, "_saved_grad_shard"):
+            self._check_sharded(flat_param)
+            self._check_on_compute_device(flat_param)
+            if flat_param._saved_grad_shard is not None:
+                self._check_on_compute_device(flat_param._saved_grad_shard)  # type: ignore[attr-defined]
+            # If no sharded gradient was computed this iteration, then there is
+            # no need to forward `_saved_grad_shard` to `grad`
+            if flat_param._post_backward_called:  # type: ignore[attr-defined]
+                flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+                if flat_param.grad is not None:
+                    cast_grad_to_param_dtype_if_needed(flat_param)
+        else:
+            _p_assert(
+                not self.uses_sharded_strategy
+                or not flat_param._post_backward_called,  # type: ignore[attr-defined]
+                "All sharded parameters that received a gradient in the "
+                "post-backward should use `_saved_grad_shard`",
+            )
+        # Delete `_saved_grad_shard` since its existence indicates a previous
+        # gradient to accumulate with in the post-backward hook
+        if hasattr(flat_param, "_saved_grad_shard"):
+            delattr(flat_param, "_saved_grad_shard")
+
+    @contextlib.contextmanager
+    def to_cpu(self):
+        """
+        Move the unpadded unsharded flat parameter to CPU while in the context and move it back to the previous device upon exit.
+
+        For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
+        since (1) there is no reason to include the padding in the copy and (2)
+        there is no use case for the sharded flat parameter.
+
+        Precondition: ``self.flat_param`` 's data is the unpadded unsharded
+        flat parameter on the compute device, and the handle uses a sharded
+        strategy.
+        Postcondition: Same as the precondition.
+        """
+        self._check_sharded_strategy()
+        _p_assert(
+            self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
+            f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
+        )
+        self._check_on_compute_device(self.flat_param)
+        # Check that the unpadded unsharded flat parameter is a view into the
+        # padded unsharded flat parameter as expected
+        # NOTE: This check is not strictly needed for correctness but is a
+        # useful sanity check since the tensor should only be used internally.
+        _p_assert(
+            _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
+            "Expects the unpadded parameter to be a view into the padded parameter",
+        )
+        self.flat_param_to(torch.device("cpu"))
+        self._free_unsharded_flat_param()
+        try:
+            yield
+        finally:
+            _p_assert(
+                self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
+                f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
+            )
+            padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
+            # Copy from CPU to the compute device
+            padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
+                self.flat_param
+            )
+            self._use_unsharded_flat_param(padded_unsharded_flat_param)
+
+    def reshard(self, free_unsharded_flat_param: bool):
+        """
+        Run the reshard logic.
+
+        This includes freeing the unsharded flat
+        parameter if ``free_unsharded_flat_param`` and switching to using the
+        sharded flat parameter. Note that this also implicitly offloads
+        the sharded flat parameter (if CPU offload is enabled) by pointing
+        it to the ``_local_shard`` attribute which resides on CPU.
+        """
+        # Switch to the sharded `FlatParameter` before freeing to prevent
+        # "use-after-free"-type bugs with external profiling tools, where for
+        # `use_orig_params=True`, the `param` does not point to valid memory
+        # when setting `param.data = ...` in `_use_sharded_views()`.
+        self._use_sharded_flat_param()
+        if free_unsharded_flat_param:
+            self._free_unsharded_flat_param()
+
+    def post_reshard(self):
+        """
+        Run the post-reshard logic.
+
+        This includes freeing any memory that
+        can now be freed given that the ``FlatParameter`` points to the full
+        precision sharded flat parameter.
+
+        Precondition: ``self.flat_param`` 's data points to the full precision
+        sharded flat parameter.
+        """
+        # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
+        # is also the low precision *unsharded* flat parameter. Hence, we delay
+        # the free until the reshard.
+        if (
+            self._uses_param_mixed_precision
+            and not self.uses_sharded_strategy
+            and not self._force_full_precision  # did not use the low precision shard
+        ):
+            self._free_low_precision_sharded_param()
+
+    def _free_unsharded_flat_param(self):
+        """
+        Free the padded unsharded flat parameter.
+
+        We allow this function to be called even when the storage is not
+        allocated. The tensor to free depends on the calling context since the
+        unshard may have forced full precision, in which case a different
+        tensor is used.
+        """
+        self._check_sharded_strategy()
+        unsharded_flat_param = self._get_padded_unsharded_flat_param()
+        self._check_on_compute_device(unsharded_flat_param)
+        # Do not free the memory until all ops in the current stream finish
+        _no_dispatch_record_stream(
+            unsharded_flat_param, self._device_handle.current_stream()
+        )
+        _free_storage(unsharded_flat_param)
+
+    def _use_sharded_flat_param(self) -> None:
+        """Switches to using the sharded flat parameter."""
+        flat_param = self.flat_param
+        if self._use_orig_params:
+            in_forward = self._training_state == HandleTrainingState.FORWARD
+            skip_use_sharded_views = (
+                torch.is_grad_enabled()
+                and in_forward
+                and self._sharding_strategy
+                in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
+            )
+            # Only incur the extra `.data` call if needed
+            if skip_use_sharded_views:
+                unsharded_flat_param = flat_param.data
+        if self._offload_params:
+            device = flat_param._local_shard.device  # type: ignore[attr-defined]
+            _p_assert(
+                device == torch.device("cpu"),
+                f"Expects the local shard to be on CPU but got {device}",
+            )
+        flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
+        if self._use_orig_params:
+            if skip_use_sharded_views:  # type: ignore[possibly-undefined]
+                self._unsharded_flat_param_for_skipped_views = unsharded_flat_param  # type: ignore[possibly-undefined]
+            else:
+                self._use_sharded_views()
+            # For the post-forward reshard, we may try to use sharded gradient
+            # views (or unsharded gradient views if a gradient was accumulated
+            # in `no_sync()`), but for the post-backward reshard, we delay the
+            # call to after the reduce-scatter.
+            if (
+                in_forward  # type: ignore[possibly-undefined]
+                # Skip using gradient views if skipped using sharded views
+                # since exposing unsharded parameters with sharded gradients
+                # may be confusing to the user
+                and not self._skipped_use_sharded_views
+            ):
+                # TODO: Change `_unpadded_unsharded_size` if we change the
+                # gradient to be computed directly with padding.
+                accumulated_grad_in_no_sync = (
+                    flat_param.grad is not None
+                    and self.uses_sharded_strategy
+                    and flat_param.grad.shape == flat_param._unpadded_unsharded_size
+                )
+                if accumulated_grad_in_no_sync:
+                    self._use_unsharded_grad_views()
+                else:
+                    self._use_sharded_grad_views()
+
+    #########
+    # VIEWS #
+    #########
+    @no_type_check
+    def _get_unflat_views_unaligned(
+        self,
+        tensor: Optional[torch.Tensor] = None,
+    ) -> Iterator[Tensor]:
+        """
+        Return unflattened ``Tensor`` views into ``tensor``.
+
+        If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is
+        based on ``flat_param`` 's metadata.
+
+        Examples for ``tensor`` include ``flat_param.grad`` or unsharded
+        tensor optimizer state.
+        """
+        flat_param = self.flat_param
+        if tensor is None:
+            tensor = flat_param
+        views = (
+            _ext_post_unflatten_transform(
+                subtensor.view(shape),
+                param_extension,
+                self._fsdp_extension,
+            )
+            for (subtensor, shape, param_extension) in zip(
+                torch.split(tensor, flat_param._numels, dim=0),
+                flat_param._shapes,
+                flat_param._param_extensions,
+            )
+        )
+        return views
+
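+    # Example (a hedged sketch of the split-and-view pattern above, assuming two
+    # original parameters of shapes (2, 3) and (4,) flattened together):
+    #   >>> flat = torch.arange(10.)
+    #   >>> numels, shapes = (6, 4), ((2, 3), (4,))
+    #   >>> views = [t.view(s) for t, s in zip(torch.split(flat, numels, dim=0), shapes)]
+    #   >>> [tuple(v.shape) for v in views]
+    #   [(2, 3), (4,)]
+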
+    @no_type_check
+    def _get_unflat_views_aligned(
+        self,
+        tensor: Optional[Tensor] = None,
+    ) -> List[Tensor]:
+        """
+        Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
+
+        This method has the same contract as :meth:`_get_unflat_views_unaligned`
+        except it checks for ``None`` placeholders representing padding for
+        alignment, which may incur slightly more CPU overhead.
+        """
+        flat_param = self.flat_param
+        if tensor is None:
+            tensor = flat_param
+        splits: List[Tensor] = torch.split(
+            tensor, flat_param._numels_with_padding, dim=0
+        )
+        idx = 0
+        views: List[Tensor] = []
+        for split, is_padding in zip(splits, flat_param._is_padding_mask):
+            if is_padding:
+                continue
+            views.append(
+                _ext_post_unflatten_transform(
+                    split.view(flat_param._shapes[idx]),
+                    flat_param._param_extensions[idx],
+                    self._fsdp_extension,
+                )
+            )
+            idx += 1
+        return views
+
+    @no_type_check
+    @torch.enable_grad()
+    def _use_unsharded_views(self, as_params: bool) -> None:
+        """
+        Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
+
+        Args:
+            as_params (bool): If ``True``, then registers the original
+                parameters as ``nn.Parameter`` s; if ``False``, then registers
+                the original parameters only as ``Tensor`` s. ``False`` should
+                be used during forward/backward computation and when hiding the
+                original parameters from :meth:`nn.Module.named_parameters`.
+
+        Note:
+            When prefetching for the next forward, the current forward may be
+            annotated with ``@torch.no_grad()``. In that case,
+            ``@torch.enable_grad()`` ensures a non-empty ``view.grad_fn``;
+            otherwise, ``_post_backward_hook`` will not get called.
+        """
+        flat_param = self.flat_param
+        self._check_unsharded(flat_param)
+        views = self._get_unflat_views()
+        from torch.distributed._tensor import DTensor
+
+        for i, (view, (param_name, module, _)) in enumerate(
+            zip(views, flat_param._param_infos)
+        ):
+            if self._use_orig_params and as_params:
+                if type(view) is DTensor:
+                    # A `DTensor` `view` is not compatible with assigning
+                    # `param.data = view`, so we cannot preserve the parameter
+                    # variable.
+                    self._setattr_param(
+                        module,
+                        param_name,
+                        nn.Parameter(view, requires_grad=flat_param.requires_grad),
+                    )
+                    continue
+                param = self.flat_param._params[i]
+                self._setattr_param(module, param_name, param)
+                param.data = view
+            elif as_params:
+                self._setattr_param(
+                    module,
+                    param_name,
+                    nn.Parameter(view, requires_grad=flat_param.requires_grad),
+                )
+            else:  # `as_params=False`
+                param_var: Tensor = view
+                if self._use_orig_params:
+                    if self._training_state == HandleTrainingState.FORWARD:
+                        # Save the `Tensor` for the pre-backward
+                        self.flat_param._tensors[i] = view  # save for pre-backward
+                    elif self._training_state == HandleTrainingState.BACKWARD_PRE:
+                        # Use the saved `Tensor` variable from the forward to
+                        # preserve the autograd graph so that the post-backward
+                        # hook fires (e.g. for reentrant AC)
+                        tensor = self.flat_param._tensors[i]
+                        tensor.data = view
+                        param_var = tensor
+                self._setattr_tensor(module, param_name, param_var)
+                if (
+                    self._use_orig_params
+                    and self._training_state == HandleTrainingState.FORWARD
+                ):
+                    module._parameters[param_name] = param_var
+        for i, (
+            param_name,
+            module,
+            _,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in enumerate(self.flat_param._shared_param_infos):
+            prim_param: Union[Tensor, nn.Parameter] = getattr(
+                prim_module, prim_param_name
+            )
+            _p_assert(
+                not as_params or isinstance(prim_param, nn.Parameter),
+                f"as_params={as_params} type(prim_param)={type(prim_param)}",
+            )
+            if self._use_orig_params and as_params:
+                shared_param = self.flat_param._shared_params[i]
+                self._setattr_param(module, param_name, shared_param)
+                shared_param.data = prim_param
+            elif as_params:
+                self._setattr_param(module, param_name, prim_param)
+            else:
+                self._setattr_tensor(module, param_name, prim_param)
+                if (
+                    self._use_orig_params
+                    and self._training_state == HandleTrainingState.FORWARD
+                ):
+                    module._parameters[param_name] = prim_param
+
+    @no_type_check
+    def _use_unsharded_grad_views(self) -> None:
+        """
+        Unflatten the unsharded flat parameter's gradient.
+
+        The original parameter variables' gradients are set to be views into
+        the unsharded flat parameter's gradient.
+        """
+        # Expects the gradient to be in `flat_param.grad`
+        if self.flat_param.grad is None:
+            for param in chain(self.flat_param._params, self.flat_param._shared_params):
+                param.grad = None
+            return
+        self._check_unsharded(self.flat_param.grad)
+        views = self._get_unflat_views(self.flat_param.grad)
+        for i, (view, (param_name, module, _)) in enumerate(
+            zip(views, self.flat_param._param_infos)
+        ):
+            _p_assert(
+                hasattr(module, param_name),
+                f"{self.flat_param._fqns[i]} is missing",
+            )
+            param = getattr(module, param_name)
+            if (
+                param.shape != view.shape
+                or param.dtype != view.dtype
+                or param.device != view.device
+            ):
+                # NOTE: This is a hack using `.data` to side step the check
+                # that parameter/gradient sizes/dtypes/devices match. From
+                # calling `reshard()`, `param` has the sharded size, has the
+                # full precision dtype, and if CPU offloading is enabled, is on
+                # CPU. Thus, one or more of the following cases can hold when
+                # in `no_sync()`, where `view` is the original parameter's
+                # gradient:
+                # 1. `view` can have the unsharded size.
+                # 2. `view` can have the parameter low precision dtype.
+                # 3. `view` can be on GPU.
+                if param.grad is None:
+                    param.grad = torch.empty_like(param)
+                param.grad.data = view
+            else:
+                param.grad = view
+        for i, (
+            param_name,
+            module,
+            module_name,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in enumerate(self.flat_param._shared_param_infos):
+            _p_assert(
+                hasattr(module, param_name),
+                f"{module_name + '.' + param_name if module_name else param_name} is missing",
+            )  # did not save FQN info in `_shared_param_infos`
+            param = getattr(module, param_name)
+            prim_param = getattr(prim_module, prim_param_name)
+            if (
+                param.shape != prim_param.grad.shape
+                or param.dtype != prim_param.grad.dtype
+                or param.device != prim_param.grad.device
+            ):
+                # NOTE: This is the same hack to use `.data` to side step the
+                # size check.
+                if param.grad is None:
+                    param.grad = torch.empty_like(param)
+                param.grad.data = prim_param.grad
+            else:
+                param.grad = prim_param.grad
+
+    @contextlib.contextmanager
+    def unflatten_as_params(self) -> Generator:
+        """
+        Unflatten the original parameters.
+
+        The function assumes that the flat parameter is unsharded. While in the
+        context, it unflattens the original parameters as ``nn.Parameter``
+        views into the flat parameter; after the context, it restores the
+        original parameters as ``Tensor`` views into the flat parameter.
+        """
+        self._use_unsharded_views(as_params=True)
+        try:
+            yield
+        finally:
+            self._use_unsharded_views(as_params=False)
+
+    @no_type_check
+    @torch.no_grad()
+    def _use_sharded_views(self) -> None:
+        """
+        Set the original parameter variables' data to be flattened views into the sharded flat parameter.
+
+        The views are kept as flattened to simplify the case where a parameter
+        is sharded across ranks. Parameters whose data is not present in the
+        sharded flat parameter have their data set to a size-0 empty tensor. We
+        do not delete them, in order to preserve expected behaviors like model
+        printability. Parameters whose data is present must preserve their
+        variables to be passable to an optimizer.
+        """
+        self._unsharded_flat_param_for_skipped_views = None
+        if not self.uses_sharded_strategy:
+            # For `NO_SHARD`, use the *unflattened* unsharded views since we
+            # have the unsharded parameter
+            self._use_unsharded_views(as_params=True)
+            return
+        flat_param = self.flat_param
+        self._check_sharded(flat_param)
+        # Construct once and reuse for all parameters not in the local shard
+        size_0_empty_tensor = torch.empty(
+            0,
+            dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
+            device=self.flat_param.device,
+            requires_grad=False,
+        )
+        for param, shard_param_info, (param_name, module, _) in zip(
+            flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
+        ):
+            self._setattr_param(module, param_name, param)
+            if not shard_param_info.in_shard:
+                # Allow the original data to be freed via garbage collection
+                param.data = size_0_empty_tensor
+            else:
+                offset = shard_param_info.offset_in_shard
+                numel_in_shard = shard_param_info.numel_in_shard
+                param.data = flat_param[offset : offset + numel_in_shard]
+        assert self.flat_param._shared_params is not None
+        for i, (
+            param,
+            (param_name, module, _, prim_param_name, prim_module, _),
+        ) in enumerate(
+            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
+        ):
+            self._setattr_param(module, param_name, param)
+            prim_param = getattr(prim_module, prim_param_name)
+            param.data = prim_param  # could be both empty and non-empty
+        if self._training_state == HandleTrainingState.BACKWARD_POST:
+            # Clear the saved `Tensor`s since they are unneeded now
+            for i in range(len(self.flat_param._tensors)):
+                self.flat_param._tensors[i] = None
+
+    @no_type_check
+    @torch.no_grad()
+    def _use_sharded_grad_views(self) -> None:
+        """
+        Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
+
+        This is a no-op if there is no gradient.
+
+        Parameters whose data is not present in the sharded flat parameter and
+        parameters with ``requires_grad=False`` have their gradients set to
+        ``None``. Since the gradient variables do not need to be preserved,
+        this method does not manipulate existing ``Tensor`` data directly and
+        creates new ``Tensor`` variables instead.
+        """
+        flat_param = self.flat_param
+        self._check_sharded(flat_param)
+        grad = self.sharded_grad
+        if grad is None:
+            for param in chain(flat_param._params, flat_param._shared_params):
+                param.grad = None
+            return
+        self._check_sharded(grad)
+        for param, shard_param_info, is_grad_none in zip(
+            flat_param._params,
+            flat_param._shard_param_infos,
+            flat_param._is_grad_none_mask,
+        ):
+            if not shard_param_info.in_shard:
+                param.grad = None
+            else:
+                numel_in_shard = shard_param_info.numel_in_shard
+                if param.requires_grad and not is_grad_none:
+                    offset = shard_param_info.offset_in_shard
+                    if self._keep_low_precision_grads or param.dtype != grad.dtype:
+                        # NOTE: This is a hack using `.data` to side step the
+                        # check that parameter/gradient dtypes match. Here,
+                        # `param` has full precision; `grad` has low precision.
+                        if param.grad is None:
+                            # `.grad` must have the same shape as `param`
+                            param.grad = torch.empty_like(param)
+                        param.grad.data = grad[
+                            offset : offset + numel_in_shard
+                        ].reshape(param.shape)
+                    else:
+                        param.grad = grad[offset : offset + numel_in_shard].reshape(
+                            param.shape
+                        )
+                else:
+                    param.grad = None
+        assert flat_param._shared_params is not None
+        for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
+            zip(flat_param._shared_params, flat_param._shared_param_infos)
+        ):
+            in_sharded_flat_param = hasattr(prim_module, prim_param_name)
+            if in_sharded_flat_param and param.requires_grad:
+                prim_param = getattr(prim_module, prim_param_name)
+                param.grad = prim_param.grad  # share the same reference
+            else:
+                param.grad = None
+
+    @no_type_check
+    @torch.no_grad()
+    def _writeback_orig_params(self) -> bool:
+        """
+        Write back any parameters that changed storage to the handle's ``FlatParameter``.
+
+        Iterates over the original parameters and writes back any parameters
+        that changed storages (due to a non-inplace operator) to the handle's
+        ``FlatParameter``. This method preserves the ``FlatParameter`` 's
+        device even if an original parameter's device changes.
+
+        Raises:
+            RuntimeError: If an original parameter or gradient changes storages
+                but no longer has the expected flattened shape.
+
+        Returns:
+            ``True`` if some writeback happened, and ``False`` otherwise.
+        """
+        if (
+            self.uses_sharded_strategy
+            and not self.is_sharded(self.flat_param)
+            and not self._skipped_use_sharded_views
+        ):
+            # For `NO_SHARD`, we may still need to writeback
+            return False
+        flat_param = self.flat_param
+        wroteback = False
+        if self._skipped_use_sharded_views and self.uses_sharded_strategy:
+            # NOTE: We must use the unsharded flat parameter from which the
+            # unsharded views were computed, not the one from the current
+            # calling context (`_get_padded_unsharded_flat_param()`) since that
+            # may be different (e.g. the model changed from train to eval).
+            flat_param_tensor = self._unsharded_flat_param_for_skipped_views
+            _p_assert(
+                _data_ptr_allocated(flat_param_tensor),
+                "If skipped using sharded views, the unsharded flat parameter "
+                "should be allocated",
+            )
+        else:
+            flat_param_tensor = flat_param
+        # NOTE: Since this method is called in the pre-unshard, which is only
+        # called during computation in the pre-forward or pre-backward, the
+        # sharded gradient should be guaranteed to be in `.grad`, not in
+        # `._saved_grad_shard`.
+        flat_param_grad = (
+            flat_param.grad
+            if self.uses_sharded_strategy or not self._offload_params
+            else flat_param._cpu_grad
+        )
+        for i, (
+            param,
+            (in_shard, offset_in_shard, numel_in_shard, _, _),
+            (param_name, module, _),
+        ) in enumerate(
+            zip(
+                flat_param._params,
+                flat_param._shard_param_infos,
+                flat_param._param_infos,
+            )
+        ):
+            if not in_shard:
+                continue
+            if not hasattr(module, param_name):
+                # Do not writeback if original parameters are deregistered
+                # (e.g. during model checkpointing)
+                continue
+
+            # Check for parameter writeback
+            if self._skipped_use_sharded_views:
+                param = flat_param._tensors[i]
+                _p_assert(
+                    param is not None,
+                    f"Expects to have saved tensor for {flat_param._fqns[i]}",
+                )
+            param_changed = getattr(module, param_name) is not param
+            needs_param_writeback = (
+                param_changed  # changed parameter variable itself
+                or not _same_storage(param, flat_param_tensor)
+            )
+            if self._skipped_use_sharded_views and (
+                param_changed or needs_param_writeback
+            ):
+                raise AssertionError(
+                    "FSDP does not support changing the parameters between "
+                    f"forward and backward for {self._sharding_strategy}"
+                )
+            if param_changed:
+                # NOTE: The gradient is not preserved after a parameter change.
+                param = getattr(module, param_name)
+                flat_param._params[i] = param
+            if needs_param_writeback:
+                expected_shape = torch.Size([numel_in_shard])
+                self._writeback_tensor(
+                    param, flat_param, i, expected_shape, offset_in_shard, True
+                )
+                wroteback = True
+
+            # Check for gradient writeback
+            if self._skipped_use_sharded_views:
+                # Skip the writeback check because we do not expose gradients
+                # when we skipped using sharded views
+                continue
+            if param.grad is None and flat_param.grad is not None:
+                expected_shape = torch.Size([numel_in_shard])
+                self._writeback_tensor(
+                    None, flat_param.grad, i, expected_shape, offset_in_shard, False
+                )
+            elif param.grad is not None:
+                # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
+                # memory and owns the gradient storage, so it will never
+                # require gradient writeback.
+                if not self.uses_sharded_strategy and self._offload_params:
+                    # Explicitly continue to handle the case of `no_sync()`,
+                    # where `param.grad` is a view into the GPU gradient
+                    # referenced by `flat_param.grad`, while `flat_param_grad`
+                    # is `flat_param._cpu_grad`, which is on CPU
+                    continue
+
+                needs_grad_writeback = flat_param_grad is None or not _same_storage(
+                    param.grad, flat_param_grad
+                )
+                if needs_grad_writeback:
+                    if flat_param_grad is None:
+                        flat_param_grad = torch.zeros_like(flat_param)
+                    expected_shape = torch.Size([numel_in_shard])
+                    self._writeback_tensor(
+                        param.grad,
+                        flat_param_grad,
+                        i,
+                        expected_shape,
+                        offset_in_shard,
+                        False,
+                    )
+                    flat_param.grad = flat_param_grad
+                    flat_param_grad = flat_param.grad
+
+        # TODO: If we want to handle shared parameters, we need to re-generate
+        # the shared parameter data structures in case sharedness changed.
+        for i, (
+            param_name,
+            module,
+            _,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in enumerate(flat_param._shared_param_infos):
+            if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
+                raise NotImplementedError(
+                    "Changing shared parameters is not supported yet"
+                )
+        return wroteback
+
+    def _writeback_tensor(
+        self,
+        src_tensor: Optional[Tensor],
+        dst_tensor: Tensor,
+        tensor_index: int,
+        expected_shape: torch.Size,
+        offset: int,
+        is_param: bool,  # else gradient
+    ) -> None:
+        """
+        Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
+
+        ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
+        ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
+        instead of copying. ``tensor_index`` gives the index of ``src_tensor``
+        in the metadata structures.
+
+        Raises:
+            RuntimeError: If the ``src_tensor`` does not have the expected
+            shape.
+        """
+        _p_assert(
+            len(expected_shape) == 1,
+            f"Expects a 1D expected shape but got {expected_shape}",
+        )
+        if self._debug_level == dist.DebugLevel.INFO:
+            rank = self.rank if hasattr(self, "rank") else dist.get_rank()
+            src_shape = src_tensor.shape if src_tensor is not None else None
+            src_device = src_tensor.device if src_tensor is not None else None
+            warnings.warn(
+                f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
+                f"writeback in {self._training_state}\n"
+                f"expected shape={expected_shape} shape={src_shape} "
+                f"expected device={dst_tensor.device} device={src_device}"
+            )
+        if src_tensor is not None and src_tensor.shape != expected_shape:
+            # NOTE: Gradient shape mismatch is not possible in practice since
+            # the gradient shape is enforced to match that of the parameter and
+            # we already check for parameter shape mismatch.
+            raise RuntimeError(
+                f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
+                f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
+            )
+        if src_tensor is not None:
+            dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
+        else:
+            dst_tensor[offset : offset + expected_shape.numel()].zero_()
+            assert self.flat_param._is_grad_none_mask is not None
+            self.flat_param._is_grad_none_mask[tensor_index] = True
+
+    def _reset_flat_param_grad_info_if_needed(self):
+        """
+        Reset ``flat_param.grad`` if needed.
+
+        When ``use_orig_params=True``:
+        (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
+        original parameters' ``.grad`` are ``None``, and
+        (2) sets ``flat_param.requires_grad=False`` if *none* of the original
+        parameters require gradient.
+        For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
+        which case we want to free the gradients as soon after the
+        ``zero_grad()`` call as possible.
+        """
+        if not self._use_orig_params:
+            return
+        flat_param = self.flat_param
+        assert flat_param._params is not None  # mypy
+        all_grad_none = True
+        requires_grad = False
+        for param in flat_param._params:
+            all_grad_none &= param.grad is None
+            requires_grad |= param.requires_grad
+        if all_grad_none:
+            flat_param.grad = None
+        # As long as one parameter requires gradient, then the flat parameter
+        # must require gradient
+        flat_param.requires_grad = requires_grad
+
+    def _deregister_orig_params(self):
+        for param_info in self.flat_param._param_infos:
+            param_name, module, _ = param_info
+            if hasattr(module, param_name):
+                delattr(module, param_name)
+        for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
+            if hasattr(module, param_name):
+                delattr(module, param_name)
+
+    ###########
+    # HELPERS #
+    ###########
+    def flat_param_to(self, *args, **kwargs):
+        """Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
+        self.flat_param.data = self.flat_param.to(*args, **kwargs)
+        if self._use_orig_params:
+            # Refresh the views because their storage may have changed
+            if self.is_sharded(self.flat_param):
+                self._use_sharded_views()
+            else:
+                self._use_unsharded_views(as_params=True)
+
+    def _get_modules(self) -> Set[nn.Module]:
+        """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
+        return {pi.module for pi in self.flat_param._param_infos}.union(
+            {spi.module for spi in self.flat_param._shared_param_infos}
+        )
+
+    def is_sharded(self, tensor: Tensor) -> bool:
+        """
+        Return whether ``tensor`` is *currently* sharded.
+
+        For ``NO_SHARD``, we choose to have this always return ``False`` for clarity.
+        """
+        if (
+            not hasattr(self.flat_param, "_sharded_size")
+            or not self.uses_sharded_strategy
+        ):
+            # `_sharded_size` is defined iff `handle.shard()` has been called
+            return False
+        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
+        return tensor.size() == sharded_size
+
+    def param_module_names(self) -> Iterator[Tuple[str, str]]:
+        shared_param_infos = [
+            ParamInfo(param_name, module, module_name)
+            for (
+                param_name,
+                module,
+                module_name,
+                _,
+                _,
+                _,
+            ) in self.flat_param._shared_param_infos
+        ]
+        for param_info in chain(self.flat_param._param_infos, shared_param_infos):
+            param_name, _, module_name = param_info  # type: ignore[misc]
+            yield (param_name, module_name)
+
+    def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
+        for param_name, _, module_name in [
+            ParamInfo(param_name, module, module_name)
+            for (
+                param_name,
+                module,
+                module_name,
+                _,
+                _,
+                _,
+            ) in self.flat_param._shared_param_infos
+        ]:
+            yield (param_name, module_name)
+
+    @property
+    def _fqns_in_shard(self) -> List[str]:
+        """Return the FQNs of the parameters present in this rank's shard."""
+        fqns_in_shard: List[str] = []
+        for fqn, shard_param_info in zip(
+            self.flat_param._fqns, self.flat_param._shard_param_infos  # type: ignore[attr-defined]
+        ):
+            if shard_param_info.in_shard:
+                fqns_in_shard.append(fqn)
+        return fqns_in_shard
+
+    @property
+    def sharded_grad(self) -> Optional[Tensor]:
+        """Return the handle's sharded gradient."""
+        flat_param = self.flat_param
+        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
+        # - CPU offloading: `_cpu_grad`
+        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
+        # - No CPU offloading + `NO_SHARD`: `grad`
+        grad: Optional[Tensor]
+        if hasattr(flat_param, "_cpu_grad"):
+            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
+        elif hasattr(flat_param, "_saved_grad_shard"):
+            # In the post-backward hook, the sharded gradient is still in
+            # `_saved_grad_shard`.
+            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        else:
+            # If in IDLE or in FORWARD states, then there may be an
+            # (accumulated) gradient. If accessed in IDLE, then this should
+            # be due to re-registering the original parameters (e.g. in state
+            # dict load).
+            _p_assert(
+                flat_param.grad is None
+                or not self.uses_sharded_strategy
+                or self._training_state
+                in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
+                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
+                "unless in IDLE or FORWARD",
+            )
+            grad = flat_param.grad
+        return grad
+
+    def _reset_is_grad_none(self) -> None:
+        """
+        Reset ``_is_grad_none_mask`` as needed.
+
+        This method should only be
+        called in the post-backward after gradient computation, in which case
+        if a parameter requires gradient, then it will surely receive a
+        gradient and we may reset its mask entry to ``False``.
+        """
+        if not self._use_orig_params:
+            return
+        _p_assert(
+            self._training_state == HandleTrainingState.BACKWARD_POST,
+            "Expects to only be called in the post-backward after gradient computation",
+        )
+        flat_param = self.flat_param
+        assert flat_param._params is not None  # mypy
+        for i, param in enumerate(flat_param._params):  # type: ignore[arg-type]
+            # As long as the parameter requires gradient, it should receive a
+            # meaningful gradient (even if the gradient happens to be zeros)
+            if param.requires_grad:
+                assert flat_param._is_grad_none_mask is not None  # mypy
+                flat_param._is_grad_none_mask[i] = False
+
+    #######################
+    # CHECKS & INVARIANTS #
+    #######################
+    def _check_sharded_strategy(self):
+        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
+
+    def _check_on_compute_device(self, tensor: Tensor):
+        _p_assert(
+            tensor.device == self.device,
+            f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
+        )
+
+    def _check_on_cpu(self, tensor: Tensor):
+        _p_assert(
+            tensor.device == torch.device("cpu"),
+            f"Expects tensor to be on CPU but got {tensor.device}",
+        )
+
+    @staticmethod
+    def _check_storage_freed(tensor: Tensor):
+        # Compile does not resize during trace
+        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+            _p_assert(
+                _same_storage_size(tensor, 0),
+                "Expects storage to be freed but got storage with size > 0",
+            )
+
+    @staticmethod
+    def _check_storage_allocated(tensor: Tensor):
+        _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
+
+    def _check_low_precision_shard(self):
+        _p_assert(
+            self._uses_param_mixed_precision,
+            "Not using low precision for parameters",
+        )
+        _p_assert(
+            getattr(self.flat_param, "_mp_shard", None) is not None,
+            "Expects `_mp_shard` to exist",
+        )
+        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
+        _p_assert(
+            device == self.device,
+            f"Expects the low precision shard to be on {self.device} but got {device}",
+        )
+
+    def _check_unsharded(self, tensor: Tensor):
+        msg_prefix = "Expects tensor to be unsharded "
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
+        unsharded_size = self.flat_param._unpadded_unsharded_size
+        _p_assert(
+            tensor.size() == unsharded_size,
+            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
+        )
+
+    def _check_sharded(self, tensor: Tensor):
+        msg_prefix = "Expects tensor to be sharded "
+        _p_assert(tensor is not None, msg_prefix + "but got `None`")
+        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
+        _p_assert(
+            tensor.size() == sharded_size,
+            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
+        )
+
+    ##############
+    # PROPERTIES #
+    ##############
+    @property
+    def uses_sharded_strategy(self) -> bool:
+        return self._sharding_strategy != HandleShardingStrategy.NO_SHARD
+
+    @property
+    def _uses_param_mixed_precision(self) -> bool:
+        return self._fwd_bwd_param_dtype != self._orig_param_dtype
+
+    @property
+    def _uses_reduce_mixed_precision(self) -> bool:
+        return self._reduce_dtype != self._orig_param_dtype
+
+    @property
+    def _force_full_precision(self) -> bool:
+        return (
+            self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
+        ) and (
+            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
+            or
+            # Also disable mixed precision in model eval mode, if configured
+            (not self._fully_sharded_module.training and self._use_full_prec_in_eval)
+        )
+
+    @property
+    def _skipped_use_sharded_views(self) -> bool:
+        """
+        This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``.
+
+        This returns if this handle is
+        currently in a state where it has skipped using sharded views, in which
+        case it can restore view invariants via ``_use_sharded_views()``.
+        """
+        return self._unsharded_flat_param_for_skipped_views is not None
+
+
+# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
+def _unsafe_setattr_param(
+    module: nn.Module, param_name: str, param: nn.Parameter
+) -> None:
+    module._parameters[param_name] = param
+    # This bypasses any overrides in case `module` is an instance of an
+    # `nn.Module` subclass
+    super(nn.Module, module).__setattr__(param_name, param)
+
+
+def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
+    module._parameters.pop(param_name, None)
+    # This bypasses any overrides in case `module` is an instance of an
+    # `nn.Module` subclass
+    super(nn.Module, module).__setattr__(param_name, tensor)
+
+
+def _safe_setattr_tensor_or_param(
+    module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
+):
+    # Call `delattr()` and `setattr()` to go through `nn.Module` checks
+    if hasattr(module, param_name):
+        delattr(module, param_name)
+    setattr(module, param_name, tensor_or_param)
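+
+
+# Editor's note: illustrative sketch (not part of the upstream file) contrasting the
+# "unsafe" setters above with the normal `nn.Module.__setattr__` path. FSDP swaps
+# plain tensor views in and out of parameter slots on every unshard/reshard, and the
+# normal path rejects assigning a plain `Tensor` over a registered `nn.Parameter`.
+def _setattr_contrast_sketch() -> None:
+    lin = nn.Linear(2, 2)
+    view = torch.empty(2, 2)
+    try:
+        # Normal path: raises TypeError (plain tensors cannot replace parameters)
+        setattr(lin, "weight", view)
+    except TypeError:
+        pass
+    # Unsafe path: deregisters the parameter and sets the attribute directly
+    _unsafe_setattr_tensor(lin, "weight", view)
+    assert not isinstance(lin.weight, nn.Parameter)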
+
+
+def _convert_to_params(
+    tensors: List[Union[torch.Tensor, nn.Parameter]]
+) -> List[nn.Parameter]:
+    return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
+
+
+def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
+    return (
+        param_or_tensor.detach()
+        if isinstance(param_or_tensor, nn.Parameter)
+        else param_or_tensor
+    )
+
+
+def _get_aligned_numel(unsharded_dtype: torch.dtype):
+    # NOTE: This alignment constraint comes from TorchInductor.
+    ALIGNMENT = 16  # bytes
+    unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
+    aligned_numel = ALIGNMENT // unsharded_dtype_size
+    return aligned_numel
+
+
+@functools.lru_cache(8)
+def _get_dtype_size(dtype):
+    return torch.empty((), dtype=dtype).element_size()
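+
+
+# Editor's note: worked example (not part of the upstream file) of the alignment
+# arithmetic above: with a 16-byte target, a float32 flat parameter is padded in
+# multiples of 16 // 4 == 4 elements and a float16 one in multiples of 16 // 2 == 8.
+def _aligned_numel_example_sketch() -> None:
+    assert _get_aligned_numel(torch.float32) == 4  # 4-byte elements
+    assert _get_aligned_numel(torch.float16) == 8  # 2-byte elements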
+
+
+def _construct_padding_tensor(
+    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
+):
+    # NOTE: Set the padding value as a magic number for debuggability. The
+    # value itself should never be used in any user-facing computation.
+    return (
+        torch.ones(
+            (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
+        )
+        * _FLAT_PARAM_PADDING_VALUE
+    )
+
+
+# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
+# message is passed in)
+@functools.lru_cache(1)
+def _warn_skip_writeback_check(log: logging.Logger, warning: str):
+    log.warning(warning)
+
+
+# Use `lru_cache(1)` to only log the warning once
+@functools.lru_cache(1)
+def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
+    log.warning(warning)
+
+
+# Use `lru_cache(1)` to only log the warning once
+@functools.lru_cache(1)
+def _warn_use_fake_reduce(log: logging.Logger, warning: str):
+    log.warning(warning)
+
+
+def _same_storage(a, b):
+    # Params are DTensors in backward
+    # with SHARD_GRAD_OP + TP
+    from torch.distributed._tensor import DTensor
+
+    if isinstance(a, DTensor):
+        a = a._local_tensor
+    if isinstance(b, DTensor):
+        b = b._local_tensor
+    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
+
+
+def _same_storage_size(a: torch.Tensor, b: int):
+    return a.untyped_storage().size() // a.element_size() == b
+
+
+def _storage_size_allocated(tensor: Tensor):
+    storage_size: int = tensor.untyped_storage().size()
+    return storage_size > 0
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed39cfb4ad0b7891066c8503d41bf3d52820a940
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py
@@ -0,0 +1,179 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._tensor import DeviceMesh, DTensor
+from torch.distributed.fsdp._shard_utils import (
+    _all_gather_dtensor,
+    _create_chunk_dtensor,
+    _create_chunk_sharded_tensor,
+)
+
+
+class FSDPExtensions(ABC):
+    """
+    This enables some customizable hooks to enable composability with tensor
+    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
+    set a custom :class:`FSDPExtensions` that implements the hooks.
+    """
+
+    @abstractmethod
+    def pre_flatten_transform(
+        self,
+        tensor: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        """E.g. converting ``DistributedTensor`` to local tensor."""
+        ...
+
+    @abstractmethod
+    def post_unflatten_transform(
+        self,
+        tensor: torch.Tensor,
+        param_extension: Any,
+    ) -> torch.Tensor:
+        """E.g. converting local tensor to ``DistributedTensor``."""
+        ...
+
+    @abstractmethod
+    def chunk_tensor(
+        self,
+        tensor: torch.Tensor,
+        rank: int,
+        world_size: int,
+        num_devices_per_node: int,
+        pg: dist.ProcessGroup,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        """Shards a tensor to chunks and returns the local chunk."""
+        ...
+
+    @abstractmethod
+    def chunk_dtensor(
+        self,
+        tensor: torch.Tensor,
+        rank: int,
+        device_mesh: DeviceMesh,
+    ) -> torch.Tensor:
+        """Shards a tensor/DTensor to DTensor and returns the local DTensor."""
+        ...
+
+    @abstractmethod
+    def pre_load_state_dict_transform(
+        self,
+        tensor: torch.Tensor,
+    ) -> Tuple[torch.Tensor, List[Shard]]:
+        """
+        This is to be called before loading a *sharded* model state dict and
+        should return the tensor and list of shards from which to load data.
+        """
+        ...
+
+    @abstractmethod
+    def all_gather_dtensor(
+        self,
+        tensor: DTensor,
+        parent_mesh: Optional[DeviceMesh],
+    ) -> torch.Tensor:
+        """
+        This is to be called before loading a *sharded* DTensor state dict.
+        This gathers the tensor along the FSDP dimension and returns the local
+        tensor of the TP DTensor.
+        """
+        ...
+
+
+_extensions: Optional[FSDPExtensions] = None
+
+
+def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
+    global _extensions
+    _extensions = flattener
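+
+
+# Editor's note: minimal pass-through sketch (not part of the upstream file) showing
+# how the hooks above could be wired up via `_set_fsdp_extensions`. The bodies defer
+# to the default helpers; a real tensor-parallel integration (see `DTensorExtensions`
+# in `torch.distributed.tensor.parallel.fsdp`) converts between DTensors and local
+# tensors in `pre_flatten_transform` / `post_unflatten_transform`.
+class _PassThroughFSDPExtensions(FSDPExtensions):
+    def pre_flatten_transform(self, tensor):
+        return tensor, None  # no per-parameter extension metadata
+
+    def post_unflatten_transform(self, tensor, param_extension):
+        return tensor
+
+    def chunk_tensor(self, tensor, rank, world_size, num_devices_per_node, pg, device=None):
+        return _create_chunk_sharded_tensor(tensor, rank, world_size, num_devices_per_node, pg)
+
+    def chunk_dtensor(self, tensor, rank, device_mesh):
+        return _create_chunk_dtensor(tensor, rank, device_mesh)
+
+    def pre_load_state_dict_transform(self, tensor):
+        assert type(tensor) is ShardedTensor
+        return tensor, tensor.local_shards()
+
+    def all_gather_dtensor(self, tensor, parent_mesh):
+        return _all_gather_dtensor(tensor, parent_mesh)
+
+
+# Example registration (illustrative only):
+#   _set_fsdp_extensions(_PassThroughFSDPExtensions())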
+
+
+def _ext_pre_flatten_transform(
+    tensor: torch.Tensor,
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> Tuple[torch.Tensor, Optional[Any]]:
+    if fsdp_extension is not None:
+        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
+        if param_extension is not None:
+            return new_tensor, param_extension
+    return tensor, None
+
+
+def _ext_post_unflatten_transform(
+    tensor: torch.Tensor,
+    param_extension: Any,
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> torch.Tensor:
+    if fsdp_extension is not None and param_extension is not None:
+        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
+    return tensor
+
+
+def _ext_chunk_tensor(
+    tensor: torch.Tensor,
+    rank: int,
+    world_size: int,
+    num_devices_per_node: int,
+    pg: dist.ProcessGroup,
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> torch.Tensor:
+    chunk_tensor_fn = (
+        fsdp_extension.chunk_tensor
+        if fsdp_extension is not None
+        else _create_chunk_sharded_tensor
+    )
+    return chunk_tensor_fn(
+        tensor,
+        rank,
+        world_size,
+        num_devices_per_node,
+        pg,
+    )
+
+
+def _ext_chunk_dtensor(
+    tensor: torch.Tensor,
+    rank: int,
+    device_mesh: DeviceMesh,
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> torch.Tensor:
+    chunk_dtensor_fn = (
+        fsdp_extension.chunk_dtensor
+        if fsdp_extension is not None
+        else _create_chunk_dtensor
+    )
+    return chunk_dtensor_fn(
+        tensor,
+        rank,
+        device_mesh,
+    )
+
+
+def _ext_pre_load_state_dict_transform(
+    tensor: torch.Tensor,
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> Tuple[torch.Tensor, List[Shard]]:
+    if fsdp_extension is not None:
+        return fsdp_extension.pre_load_state_dict_transform(tensor)
+
+    assert type(tensor) is ShardedTensor
+    shards = tensor.local_shards()
+    return (tensor, shards)
+
+
+def _ext_all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+    fsdp_extension: Optional[FSDPExtensions] = None,
+) -> torch.Tensor:
+    all_gather_dtensor_fn = (
+        fsdp_extension.all_gather_dtensor
+        if fsdp_extension is not None
+        else _all_gather_dtensor
+    )
+    return all_gather_dtensor_fn(tensor, parent_mesh)
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_init_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_init_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..179336fc13ae8801bbf326a10bfc6b9fe2fb1a00
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_init_utils.py
@@ -0,0 +1,1182 @@
+import collections
+import itertools
+import os
+import warnings
+from typing import (
+    Any,
+    Callable,
+    Deque,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    no_type_check,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._exec_order_utils as exec_order_utils
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
+import torch.nn as nn
+from torch.distributed.algorithms._comm_hooks import default_hooks
+from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._common_utils import (
+    _FSDPDeviceHandle,
+    _FSDPState,
+    _get_module_fsdp_state,
+    _is_fsdp_flattened,
+    _named_parameters_with_duplicates,
+    clean_tensor_name,
+    TrainingState,
+)
+from torch.distributed.fsdp._flat_param import (
+    _FSDP_USE_FULL_PREC_IN_EVAL,
+    FlatParameter,
+    FlatParamHandle,
+    HandleShardingStrategy,
+)
+from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
+from torch.distributed.fsdp.api import (
+    BackwardPrefetch,
+    CPUOffload,
+    FullOptimStateDictConfig,
+    FullStateDictConfig,
+    MixedPrecision,
+    ShardingStrategy,
+    StateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.wrap import _Policy
+from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
+from torch.distributed.utils import _sync_params_and_buffers
+
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+from torch.utils.hooks import RemovableHandle
+
+_TORCHDISTX_AVAIL = True
+try:
+    from torchdistx import deferred_init, fake  # type: ignore[import]
+except ImportError:
+    _TORCHDISTX_AVAIL = False
+
+PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
+FSDP_SYNCED = "_fsdp_synced"
+# Specification of process groups for hybrid sharding strategies.
+HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
+# Overall specification of process group.
+ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]
+
+
+# TODO (awgu): Refactor this later
+SHARDING_STRATEGY_MAP = {
+    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
+    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
+    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
+    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
+    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
+}
+HYBRID_SHARDING_STRATEGIES = [
+    ShardingStrategy.HYBRID_SHARD,
+    ShardingStrategy._HYBRID_SHARD_ZERO2,
+]
+NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
+    ShardingStrategy.SHARD_GRAD_OP,
+    ShardingStrategy._HYBRID_SHARD_ZERO2,
+)
+
+
+# NOTE: Since non-self attributes cannot be type annotated, several attributes
+# on `state` are defined first as local variables before being assigned.
+
+
+@no_type_check
+def _init_process_group_state(
+    state: _FSDPState,
+    process_group: ProcessGroupType,
+    sharding_strategy: ShardingStrategy,
+    policy: Optional[_Policy],
+    device_mesh: Optional[DeviceMesh] = None,
+) -> _FSDPState:
+    if process_group is not None and device_mesh is not None:
+        raise ValueError(
+            "Cannot pass both process_group and device_mesh at the "
+            "same time. Please just pass only one of them."
+        )
+    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
+    if is_hybrid_strategy:
+        if process_group is None and policy is None and device_mesh is None:
+            # Raise an error here, since this is manual wrapping with no process group
+            # passed in, there is no way to ensure all wrapped FSDP instances use the same
+            # process groups.
+            raise ValueError(
+                f"Manual wrapping with {sharding_strategy}",
+                "requires explicit specification of process group or device_mesh.",
+            )
+        else:
+            state = _init_process_group_state_for_hybrid_shard(
+                state, process_group, device_mesh
+            )
+    else:
+        if device_mesh:
+            state._device_mesh = device_mesh
+            state.process_group = device_mesh.get_group(mesh_dim=0)
+        else:
+            state.process_group = (
+                process_group if process_group is not None else _get_default_group()
+            )
+
+    state.rank = state.process_group.rank()
+    state.world_size = state.process_group.size()
+    data_parallel_world_size = state.world_size
+    if is_hybrid_strategy:
+        data_parallel_world_size *= state._inter_node_pg.size()
+    state._gradient_predivide_factor = (
+        default_hooks.DefaultState._get_gradient_predivide_factor(
+            data_parallel_world_size
+        )
+    )
+    state._gradient_postdivide_factor = (
+        data_parallel_world_size / state._gradient_predivide_factor
+    )
+    return state
+
+
+@no_type_check
+def _init_process_group_state_for_hybrid_shard(
+    state: _FSDPState,
+    process_group: ProcessGroupType,
+    device_mesh: DeviceMesh,
+) -> _FSDPState:
+    if device_mesh:
+        if _is_valid_hybrid_shard_device_mesh(device_mesh):
+            state._device_mesh = device_mesh
+            # We currently only allow _inter_node_pg to be the outermost dimension, and the
+            # process_group(intra_node) to be the innermost dimension.
+            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
+            state.process_group = device_mesh.get_group(mesh_dim=1)
+        else:
+            raise ValueError(
+                "Expected device_mesh to have ndim=2 "
+                f"but got {len(device_mesh.get_group())}"
+            )
+    elif process_group is None:
+        default_group = _get_default_group()
+        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
+            default_group, state._device_handle.device_count()
+        )
+        # we shard across intra-node
+        state.process_group = intra_node_group
+        # save _inter_node_pg to allreduce across.
+        state._inter_node_pg = inter_node_group
+    else:
+        # Check type and assign state.process_group and state._inter_node_pg.
+        if _is_valid_hybrid_shard_pg_type(process_group):
+            # Assuming that user passed in as intra node group and inter node group
+            # as documented.
+            state.process_group, state._inter_node_pg = process_group
+        else:
+            raise ValueError(
+                "Expected process_group to be passed in as either None or "
+                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
+            )
+    # Create state for allreduce
+    state._inter_node_state = _get_default_comm_hook_state(
+        process_group=state._inter_node_pg,
+    )
+    return state
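+
+
+# Editor's note: hypothetical usage sketch (not part of the upstream file) of the three
+# accepted process-group specifications handled above for HYBRID_SHARD. The model,
+# the `init_device_mesh` shape, and the group variables are assumptions about the
+# caller's setup, not something this file defines:
+#
+#   # (1) nothing: intra-/inter-node groups are derived from the default group
+#   fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)
+#
+#   # (2) an explicit (intra_node_pg, inter_node_pg) tuple
+#   fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+#                     process_group=(intra_node_pg, inter_node_pg))
+#
+#   # (3) a 2D DeviceMesh: dim 0 replicates across nodes, dim 1 shards within a node
+#   mesh = init_device_mesh("cuda", (num_nodes, gpus_per_node))
+#   fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+#                     device_mesh=mesh)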
+
+
+@no_type_check
+def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
+    return (
+        isinstance(process_group, tuple)
+        and len(process_group) == 2
+        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
+    )
+
+
+@no_type_check
+def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
+    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2
+
+
+@no_type_check
+def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
+    """
+    Return a process group across the current node.
+
+    For example, given each row is a distinct node:
+    0 1 2 3 4 5 6 7
+    8 9 10 11 12 13 14 15
+    This API would return an intra-node subgroup across
+    [0, 7] or [8, 15] depending on the process's rank.
+    For example, rank 3 would get [0, 7].
+    """
+    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
+    return intra_node_subgroup
+
+
+@no_type_check
+def _init_inter_node_process_group(
+    global_process_group: dist.ProcessGroup,
+    num_devices_per_node: int,
+) -> dist.ProcessGroup:
+    """
+    Return an inter-node process group where each contained rank has the same local rank.
+
+    For example, given each row is a distinct node:
+    0 1 2 3 4 5 6 7
+    8 9 10 11 12 13 14 15
+    This API would return inter-node process group {0, 8}, {1, 9}, {2, 10}, and so forth
+    depending on the process's rank. For example, rank 1 would get {1, 9}, rank 5
+    would get {5, 13}.
+    """
+    # the inter-node pg that is returned
+    inter_node_pg = None
+    sharding_backend = dist.get_backend(global_process_group)
+    world_size = dist.get_world_size(global_process_group)
+    # Assuming fully homogeneous setup
+    num_nodes = world_size // num_devices_per_node
+    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
+    for local_rank in range(num_devices_per_node):
+        ranks_for_inter_group = [
+            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
+        ]
+        # every rank always needs to call dist.new_group
+        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
+        if local_rank == my_local_rank:
+            inter_node_pg = grp
+
+    assert (
+        inter_node_pg is not None
+    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
+    return inter_node_pg
+
+
+def _init_intra_and_inter_node_groups(
+    global_process_group: dist.ProcessGroup,
+    num_devices_per_node: int,
+) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
+    """
+    Initialize intra and inter-node process groups and return the ones corresponding to this process's rank.
+
+    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
+    ``_HYBRID_SHARD_ZERO2`` in FSDP.
+    This function assumes each node has an equal number of CUDA-enabled devices.
+    Returns:
+        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
+    """
+    return (
+        _init_intra_node_process_group(num_devices_per_node),
+        _init_inter_node_process_group(global_process_group, num_devices_per_node),
+    )
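+
+
+# Editor's note: worked example (not part of the upstream file), assuming 2 nodes with
+# 8 GPUs each (global ranks 0-7 on node 0, 8-15 on node 1). For global rank 10:
+#
+#   my_local_rank = 10 % 8                 # == 2
+#   intra-node group -> {8, 9, ..., 15}    # sharding / all-gather within the node
+#   inter-node group -> {2, 10}            # gradient all-reduce across nodes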
+
+
+@no_type_check
+def _init_ignored_module_states(
+    state: _FSDPState,
+    module: nn.Module,
+    ignored_modules: Optional[Iterable[torch.nn.Module]],
+    ignored_states: Union[
+        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+    ] = None,
+) -> _FSDPState:
+    if ignored_modules is not None and ignored_states is not None:
+        raise ValueError(
+            "Cannot pass both ignored_modules and ignored_states at the "
+            "same time. Please just pass ignored_states."
+        )
+    ignored_parameters = None
+    passed_as_ignored_states = ignored_states is not None
+    if passed_as_ignored_states:
+        ignored_states_list = list(ignored_states)
+        _check_ignored_states(ignored_states_list, True)
+    else:
+        ignored_states_list = []
+        _check_ignored_states(
+            list(ignored_modules) if ignored_modules is not None else [], False
+        )
+    if len(ignored_states_list) > 0:
+        if isinstance(ignored_states_list[0], nn.Parameter):
+            ignored_parameters = ignored_states_list
+        else:
+            ignored_modules = ignored_states_list
+    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
+    state._ignored_params = _get_ignored_params(
+        module,
+        state._ignored_modules,
+        ignored_parameters,
+    )
+    state._ignored_buffer_names = _get_ignored_buffer_names(
+        module,
+        state._ignored_modules,
+    )
+    # TODO: FSDP's contract for buffers is not well-defined. They are
+    # implicitly ignored for most functionality since they are not sharded;
+    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
+    # precision). We should formalize this contract and decide if we need to
+    # compute and store `_ignored_buffers`.
+    return state
+
+
+def _check_ignored_states(
+    ignored_states: List[Any], passed_as_ignored_states: bool
+) -> None:
+    """
+    Check that the ignored states are uniformly parameters or uniformly modules.
+
+    We may remove this check in the future if we permit mixing.
+    """
+    if len(ignored_states) == 0:
+        return
+    if passed_as_ignored_states:
+        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
+        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
+        if not all_params and not all_modules:
+            # Sort for consistent ordering for unit test regex matching
+            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
+            raise ValueError(
+                "ignored_states expects all nn.Parameter or all nn.Module list "
+                f"elements but got types {sorted_types}"
+            )
+    else:
+        if not all(isinstance(state, nn.Module) for state in ignored_states):
+            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
+            raise ValueError(
+                "ignored_modules expects nn.Module list elements but got "
+                f"types {sorted_types}"
+            )
+
+
+@no_type_check
+def _init_device_handle(
+    state: _FSDPState,
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    device_id: Optional[Union[int, torch.device]],
+) -> _FSDPState:
+    """
+    Determine device handle used for initializing FSDP.
+
+    If a device is specified by ``device_id``, then this returns the device
+    handle corresponding to that device type. Otherwise, if the module is
+    already on a non-CPU device, then the device type is that non-CPU device
+    type. If the module is on CPU or meta, then the device type is the current
+    CUDA device.
+
+    This method is called once the ignored parameters have been determined, as
+    the device handle may be needed for other initialization.
+    """
+    determined_device = None
+    if device_id is not None:
+        determined_device = (
+            device_id
+            if isinstance(device_id, torch.device)
+            else torch.device(device_id)
+        )
+    if determined_device is None:
+        for param in _get_orig_params(module, ignored_params):
+            if param.device.type in {"cpu", "meta"}:
+                continue
+            if determined_device is None:
+                determined_device = param.device
+            else:
+                if param.device.type != determined_device.type:
+                    raise RuntimeError(
+                        f"FSDP does not support modules with different device types "
+                        f"but got params on {determined_device.type} and {param.device.type}"
+                    )
+        determined_device = determined_device or torch.device(
+            "cuda", torch.cuda.current_device()
+        )
+
+    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
+    return state
+
+
+@no_type_check
+def _init_buffer_state(
+    state: _FSDPState,
+    module: nn.Module,
+) -> _FSDPState:
+    state._buffer_names = _get_buffer_names(module)
+    # Save a mapping from clean fully-qualified buffer name (starting from
+    # `module`) to its original dtype for restoring that dtype during model
+    # checkpointing when buffer mixed precision is enabled. The names should
+    # be clean since the casting happens in a `summon_full_params()` context.
+    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
+    for buffer_name, buffer in module.named_buffers():
+        buffer_name = clean_tensor_name(buffer_name)
+        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
+    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
+    return state
+
+
+@no_type_check
+def _init_core_state(
+    state: _FSDPState,
+    sharding_strategy: Optional[ShardingStrategy],
+    mixed_precision: Optional[MixedPrecision],
+    cpu_offload: Optional[CPUOffload],
+    limit_all_gathers: bool,
+    use_orig_params: bool,
+    backward_prefetch_limit: int,
+    forward_prefetch_limit: int,
+) -> _FSDPState:
+    # We clamp the strategy to `NO_SHARD` for world size of 1 since they are
+    # currently functionally equivalent. This may change if/when we integrate
+    # FSDP with MoE.
+    if state.world_size == 1:
+        if sharding_strategy != ShardingStrategy.NO_SHARD:
+            warnings.warn(
+                "FSDP is switching to use `NO_SHARD` instead of "
+                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
+                "the world size is 1."
+            )
+        sharding_strategy = ShardingStrategy.NO_SHARD
+    elif sharding_strategy == ShardingStrategy.NO_SHARD:
+        warnings.warn(
+            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
+            "please use DistributedDataParallel instead.",
+            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
+            # level 3 is from the true caller
+            stacklevel=3,
+        )
+    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
+    state.mixed_precision = mixed_precision or MixedPrecision()
+    if mixed_precision is not None:
+        torch._C._log_api_usage_once(
+            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
+        )
+    state._use_full_prec_in_eval = (
+        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
+    )
+    state.cpu_offload = cpu_offload or CPUOffload()
+    state.limit_all_gathers = limit_all_gathers
+    state._use_orig_params = use_orig_params
+    state.training_state = TrainingState.IDLE
+    state._is_root = None
+    state._free_event_queue = _FreeEventQueue()
+    state._debug_level = dist.get_debug_level()
+    state._exec_order_data = exec_order_utils._ExecOrderData(
+        state._debug_level,
+        backward_prefetch_limit,
+        forward_prefetch_limit,
+    )
+    # Mapping from fully sharded module to the handles it is responsible to
+    # unshard and reshard (see [Note: Fully Sharded Module])
+    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
+    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
+    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
+    # handles in `state._handle`
+    _handle: FlatParamHandle = None
+    state._handle = _handle
+    params: List[FlatParameter] = []
+    state.params = params
+    return state
+
+
+@no_type_check
+def _init_runtime_state(
+    state: _FSDPState,
+) -> _FSDPState:
+    _root_pre_forward_handles: List[RemovableHandle] = []
+    state._root_pre_forward_handles = _root_pre_forward_handles
+    _pre_forward_handles: List[RemovableHandle] = []
+    state._pre_forward_handles = _pre_forward_handles
+    _post_forward_handles: List[RemovableHandle] = []
+    state._post_forward_handles = _post_forward_handles
+    state._sync_gradients = True
+    state._comm_hook = None
+    state._comm_hook_state = None
+    # Used to prevent running the pre-backward hook multiple times
+    return state
+
+
+@no_type_check
+def _init_prefetching_state(
+    state: _FSDPState,
+    backward_prefetch: BackwardPrefetch,
+    forward_prefetch: bool,
+) -> _FSDPState:
+    state.backward_prefetch = backward_prefetch
+    state.forward_prefetch = forward_prefetch
+    # The data structures use tuples of handles to generalize over the case
+    # where a module's forward involves multiple handles.
+    return state
+
+
+@no_type_check
+def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
+    # TODO: we need to add additional check once we support FSDP + PiPPy.
+    # This check is currently sufficient, since we only support FSDP + TP.
+    if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
+        state._fsdp_extension = DTensorExtensions(state._device_handle)
+    else:
+        # We need to explicitly set _fsdp_extension to None.
+        # Otherwise, we will run into an infinite recursion when getting the attribute.
+        state._fsdp_extension = None
+    return state
+
+
+@no_type_check
+def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
+    state._state_dict_type = StateDictType.FULL_STATE_DICT
+    state_dict_config: StateDictConfig = FullStateDictConfig()
+    state._optim_state_dict_config = FullOptimStateDictConfig()
+    state._state_dict_config = state_dict_config
+    unshard_params_ctx: Dict[nn.Module, Generator] = {}
+    state._unshard_params_ctx = unshard_params_ctx
+
+    return state
+
+
+@no_type_check
+def _init_param_handle_from_module(
+    state: _FSDPState,
+    fully_sharded_module: nn.Module,
+    device_id: Optional[Union[int, torch.device]],
+    param_init_fn: Optional[Callable[[nn.Module], None]],
+    sync_module_states: bool,
+) -> _FSDPState:
+    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
+    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
+    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
+    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
+        fully_sharded_module, state._ignored_params, state._ignored_modules
+    )
+    # Materialize the module if needed
+    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
+        _materialize_with_param_init_fn(
+            fully_sharded_module, param_init_fn, state._ignored_modules
+        )
+    elif is_meta_module:
+        _materialize_meta_module(
+            fully_sharded_module, device_id, state._ignored_modules
+        )
+    elif is_torchdistX_deferred_init:
+        deferred_init.materialize_module(
+            fully_sharded_module,
+            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
+            and submodule not in state._ignored_modules,
+        )
+
+    ignored_buffers = {
+        buffer
+        for ignored_module in state._ignored_modules
+        for buffer in ignored_module.buffers()
+    }
+
+    _move_module_to_device(
+        fully_sharded_module,
+        state._ignored_params,
+        ignored_buffers,
+        device_from_device_id,
+    )
+    state.compute_device = _get_compute_device(
+        fully_sharded_module,
+        state._ignored_params,
+        device_from_device_id,
+        state.rank,
+    )
+
+    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
+    if sync_module_states:
+        _sync_module_params_and_buffers(
+            fully_sharded_module, managed_params, state.process_group
+        )
+        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+            _sync_module_params_and_buffers(
+                fully_sharded_module, managed_params, state._inter_node_pg
+            )
+    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
+    return state
+
+
+@no_type_check
+def _init_param_handle_from_params(
+    state: _FSDPState,
+    params: List[nn.Parameter],
+    fully_sharded_module: nn.Module,
+):
+    if len(params) == 0:
+        return
+    handle = FlatParamHandle(
+        params,
+        fully_sharded_module,
+        state.compute_device,
+        SHARDING_STRATEGY_MAP[state.sharding_strategy],
+        state.cpu_offload.offload_params,
+        state.mixed_precision.param_dtype,
+        state.mixed_precision.reduce_dtype,
+        state.mixed_precision.keep_low_precision_grads,
+        state.process_group,
+        state._use_orig_params,
+        fsdp_extension=state._fsdp_extension,
+    )
+    handle.shard()
+    assert not state._handle
+    state.params.append(handle.flat_param)
+    state._handle = handle
+    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
+    cpu_device = torch.device("cpu")
+    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
+        handle.flat_param_to(cpu_device)
+
+
+def _get_ignored_modules(
+    root_module: nn.Module,
+    _ignored_modules: Optional[Iterable[torch.nn.Module]],
+) -> Set[nn.Module]:
+    """
+    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.
+
+    Return the modules contained in their module
+    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
+    already-computed ignored modules are included.
+
+    ``_ignored_modules`` represents the argument passed by the user to FSDP.
+    """
+    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
+    try:
+        ignored_root_modules = (
+            set(_ignored_modules) if _ignored_modules is not None else set()
+        )
+    except TypeError as e:
+        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
+    for module in ignored_root_modules:
+        if not isinstance(module, torch.nn.Module):
+            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
+        if _get_module_fsdp_state(module):
+            # TODO: We may relax this by taking the FSDP instance's wrapped
+            # module to provide more flexibility to the user.
+            raise ValueError("`ignored_modules` should not include FSDP modules")
+    # Treat modules that cannot compose with `fully_shard` as ignored modules,
+    # meaning that their subtrees are ignored
+    for module in root_module.modules():
+        if not traversal_utils._composable(module):
+            ignored_root_modules.add(module)
+    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
+    # that this FSDP instance can get any ignored modules from its children.
+
+    # Include child modules and exclude nested FSDP modules themselves
+    ignored_modules = {
+        child
+        for module in ignored_root_modules
+        for child in module.modules()
+        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
+    }
+    if root_module in ignored_modules:
+        warnings.warn(
+            "Trying to ignore the top-level module passed into the FSDP "
+            "constructor itself will result in all parameters being "
+            f"ignored and is not well-supported: {module}"
+        )
+    # Include nested FSDP modules' ignored modules
+    for submodule in root_module.modules():
+        optional_fsdp_state = _get_module_fsdp_state(submodule)
+        if optional_fsdp_state is not None:
+            assert hasattr(optional_fsdp_state, "_ignored_modules")
+            ignored_modules.update(optional_fsdp_state._ignored_modules)
+    return ignored_modules
+
+
+def _get_ignored_params(
+    root_module: torch.nn.Module,
+    ignored_modules: Set[torch.nn.Module],
+    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
+) -> Set[torch.nn.Parameter]:
+    """
+    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.
+
+    :class:`FlatParameter` s are excluded from the result.
+    """
+    all_ignored_params: Set[torch.nn.Parameter] = set()
+
+    params_in_ignored_modules = {
+        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
+    }
+
+    all_ignored_params.update(params_in_ignored_modules)
+
+    if ignored_parameters is not None:
+        params_in_ignored_parameters = {
+            p for p in ignored_parameters if not _is_fsdp_flattened(p)
+        }
+        all_ignored_params.update(params_in_ignored_parameters)
+
+    # Always include nested FSDP modules' ignored parameters
+    for submodule in root_module.modules():
+        optional_fsdp_state = _get_module_fsdp_state(submodule)
+        if optional_fsdp_state is not None:
+            assert hasattr(optional_fsdp_state, "_ignored_params")
+            all_ignored_params.update(optional_fsdp_state._ignored_params)
+
+    return all_ignored_params
+
+
+def _get_ignored_buffer_names(
+    root_module: torch.nn.Module,
+    ignored_modules: Set[torch.nn.Module],
+) -> Set[str]:
+    """Return the cleaned buffer FQNs in ``ignored_modules``."""
+    all_ignored_buffer_names: Set[str] = set()
+
+    buffers_in_ignored_modules = {
+        buffer for m in ignored_modules for buffer in m.buffers()
+    }
+
+    all_ignored_buffer_names.update(
+        {
+            clean_tensor_name(buffer_name)
+            for buffer_name, buffer in root_module.named_buffers()
+            if buffer in buffers_in_ignored_modules
+        }
+    )
+
+    # Always include nested FSDP modules' ignored buffer names
+    for submodule in root_module.modules():
+        optional_fsdp_state = _get_module_fsdp_state(submodule)
+        if optional_fsdp_state is not None:
+            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
+            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)
+
+    return all_ignored_buffer_names
+
+
+def _get_buffer_names(root_module: nn.Module) -> Set[str]:
+    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`."""
+    return {
+        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
+    }
+
+
+def _check_single_device_module(
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    device_id: Optional[Union[int, torch.device]],
+) -> None:
+    """
+    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.
+
+    Thus, after this method, the
+    module must be either fully on the CPU or fully on a non-CPU device.
+    """
+    devices = {param.device for param in _get_orig_params(module, ignored_params)}
+    # We allow module to be partially on CPU and partially on GPU if device_id is not
+    # None, since the device_id arg will result in the CPU portion being moved to
+    # GPU. This is useful in cases where part of the module may be parallelized
+    # by another algorithm and may already be on GPU. We'd like to enforce device_id
+    # to not be None, otherwise we'd flatten parameters in a mixed module which is
+    # not supported.
+    if len(devices) == 2 and torch.device("cpu") in devices:
+        if device_id is None:
+            raise RuntimeError(
+                "To support a module with both CPU and GPU params, "
+                "please pass in device_id argument."
+            )
+    elif len(devices) > 1:
+        raise RuntimeError(
+            f"FSDP only supports single device modules but got params on {devices}"
+        )
+
+
+def _get_device_from_device_id(
+    device_id: Optional[Union[int, torch.device]],
+    rank: int,
+) -> Optional[torch.device]:
+    """
+    Return a ``torch.device`` for the specified ``device_id``.
+
+    Processes ``device_id`` and returns either the corresponding device or
+    ``None`` if ``device_id`` is ``None``.
+    """
+    if device_id is None:
+        return None
+    device = (
+        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
+    )
+    if device == torch.device("cuda"):
+        warnings.warn(
+            f"FSDP got the argument `device_id` {device_id} on rank "
+            f"{rank}, which does not have an explicit index. "
+            f"FSDP will use the current device {torch.cuda.current_device()}. "
+            "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
+            "before FSDP initialization or pass in the explicit device "
+            "index as the `device_id` argument."
+        )
+        device = torch.device("cuda", torch.cuda.current_device())
+    return device
+
+
+def _need_to_materialize_module(
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    ignored_modules: Set[nn.Module],
+) -> Tuple[bool, bool]:
+    """
+    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.
+
+    At most one of the returned bools can
+    be ``True``. If either is ``True``, then ``module`` needs to be
+    materialized.
+    """
+    managed_params = list(_get_orig_params(module, ignored_params))
+    is_meta_module = any(param.is_meta for param in managed_params)
+    # TODO: We need to establish a contract for FSDP and buffers. For now, we
+    # skip checking for meta buffers from ignored modules. We should consider
+    # refactoring the initialization holistically to avoid so many traversals.
+    for submodule in module.modules():
+        if submodule in ignored_modules:
+            continue
+        for buf in submodule.buffers(recurse=False):
+            is_meta_module |= buf.is_meta
+    is_torchdistX_deferred_init = (
+        not is_meta_module
+        and _TORCHDISTX_AVAIL
+        and any(fake.is_fake(param) for param in managed_params)
+    )
+    return is_meta_module, is_torchdistX_deferred_init
+
+
+def _materialize_with_param_init_fn(
+    root_module: nn.Module,
+    param_init_fn: Callable[[nn.Module], None],
+    ignored_modules: Set[nn.Module],
+) -> None:
+    if not callable(param_init_fn):
+        raise ValueError(
+            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
+        )
+    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
+    for module in modules_to_materialize:
+        param_init_fn(module)
+
+
+def _materialize_meta_module(
+    root_module: nn.Module,
+    device_from_device_id: Optional[torch.device],
+    ignored_modules: Set[nn.Module],
+):
+    # Run default meta device initialization
+    materialization_device = device_from_device_id or torch.device(
+        torch.cuda.current_device()
+    )
+    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
+    try:
+        # Assume that each module's `reset_parameters()` only initializes its
+        # own parameters and not those of its children
+        with torch.no_grad():
+            for module in modules_to_materialize:
+                # As a contract to the user, only call `reset_parameters()` if
+                # the module has directly managed parameters/buffers
+                module_state_iter = itertools.chain(
+                    module.parameters(recurse=False), module.buffers(recurse=False)
+                )
+                has_module_states = len(list(module_state_iter)) > 0
+                if has_module_states:
+                    module.to_empty(device=materialization_device, recurse=False)
+                    module.reset_parameters()  # type: ignore[operator]
+    except BaseException as e:
+        warnings.warn(
+            "Unable to call `reset_parameters()` for module on meta "
+            f"device with error {str(e)}. Please ensure that your module of"
+            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
+        )
+        raise e
+
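+
+# Illustrative sketch of the default materialization path above: a meta-device
+# module gets real (uninitialized) storage via `to_empty()` and is then
+# initialized by its own `reset_parameters()`. CPU is used here so the sketch
+# runs without a GPU; FSDP would normally use the compute device instead.
+def _example_materialize_meta_linear() -> None:
+    import torch
+    import torch.nn as nn
+
+    mod = nn.Linear(4, 4, device="meta")
+    mod.to_empty(device=torch.device("cpu"), recurse=False)
+    with torch.no_grad():
+        mod.reset_parameters()
+    assert not mod.weight.is_meta and mod.weight.device.type == "cpu"
+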
+
+def _get_modules_to_materialize(
+    root_module: nn.Module, ignored_modules: Set[nn.Module]
+) -> List[nn.Module]:
+    # Run BFS to collect the modules to materialize via `reset_parameters()`,
+    # stopping at any module with FSDP already applied or at ignored modules.
+    modules_to_materialize: List[nn.Module] = []
+    queue = collections.deque([root_module])
+    visited_modules: Set[nn.Module] = {root_module}
+    while queue:
+        module = queue.popleft()
+        modules_to_materialize.append(module)
+        for child_module in module.children():
+            if (
+                child_module not in visited_modules
+                and _get_module_fsdp_state(child_module) is None
+                and child_module not in ignored_modules
+            ):
+                visited_modules.add(child_module)
+                queue.append(child_module)
+    return modules_to_materialize
+
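+
+# Illustrative sketch of the BFS above on a toy model: every submodule is
+# collected except ignored modules and anything reachable only through them.
+# The toy model and the ignored set are made up for the example.
+def _example_bfs_collect_modules() -> None:
+    import collections
+    import torch.nn as nn
+
+    toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
+    ignored = {toy[1][0]}  # pretend the ReLU is an ignored module
+    collected = []
+    queue = collections.deque([toy])
+    visited = {toy}
+    while queue:
+        module = queue.popleft()
+        collected.append(module)
+        for child in module.children():
+            if child not in visited and child not in ignored:
+                visited.add(child)
+                queue.append(child)
+    # The ignored ReLU is skipped, but its sibling Linear is still collected.
+    assert toy[1][0] not in collected and toy[1][1] in collected
+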
+
+def _move_module_to_device(
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    ignored_buffers: Set[torch.Tensor],
+    device_from_device_id: Optional[torch.device],
+) -> None:
+    """
+    Move ``module`` depending on ``device_from_device_id`` and its current device.
+
+    This includes moving ignored modules' parameters.
+
+    - If ``device_from_device_id`` is not ``None``, then this moves
+    ``module`` to the device.
+    - If ``device_from_device_id`` is ``None``, then this does not move
+    ``module`` but warns the user if it is on CPU.
+
+    Precondition: ``_check_single_device_module()``.
+    """
+    cpu_device = torch.device("cpu")
+    if device_from_device_id is not None:
+        # BFS from `module` without traversing any nested FSDP instances to
+        # collect the parameters/buffers that have not yet been managed
+        queue: Deque[nn.Module] = collections.deque()
+        queue.append(module)
+        params: List[nn.Parameter] = []
+        buffers: List[torch.Tensor] = []
+        while queue:
+            curr_module = queue.popleft()
+            # NOTE: We include a check to only move parameters/buffers that are
+            # on CPU device. If they are on a CUDA device different from the
+            # one specified by `device_id`, then this does NOT move them. This
+            # is so that we can raise an error in `_get_compute_device()`.
+            params.extend(
+                param
+                for param in curr_module.parameters(recurse=False)
+                if param.device == cpu_device
+            )
+            buffers.extend(
+                buffer
+                for buffer in curr_module.buffers(recurse=False)
+                if buffer.device == cpu_device
+            )
+            for submodule in curr_module.children():
+                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
+                    queue.append(submodule)
+        params_to_move = [p for p in params if p not in ignored_params]
+        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
+        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
+        return
+    param = next(_get_orig_params(module, ignored_params), None)
+    if param is not None and param.device == cpu_device:
+        _warn_cpu_init()
+
+
+def _move_states_to_device(
+    params: List[nn.Parameter],
+    buffers: List[torch.Tensor],
+    device_from_device_id: Optional[torch.device],
+) -> None:
+    """
+    Move states to the specified device.
+
+    Precondition: ``_check_single_device_module()`` and module's parameters and
+    buffers have been materialized if needed.
+    """
+    if len(params) == 0 and len(buffers) == 0:
+        return
+    if len(params) > 0:
+        current_device = params[0].device
+    elif len(buffers) > 0:
+        current_device = buffers[0].device
+    cpu_device = torch.device("cpu")
+    if device_from_device_id is not None:
+        # Move the parameters and buffers like the `.data` code path in
+        # `nn.Module._apply()`, which underlies `nn.Module.to()`
+        for param in params:
+            with torch.no_grad():
+                param.data = param.to(device_from_device_id)
+                if param.grad is not None:
+                    param.grad.data = param.grad.to(device_from_device_id)
+        for buffer in buffers:
+            buffer.data = buffer.to(device_from_device_id)
+    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
+        _warn_cpu_init()
+
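+
+# Illustrative sketch: swapping a parameter's `.data` in place, as done above,
+# keeps the `nn.Parameter` object (and any external references to it) intact.
+# A dtype change is used here so the sketch runs on CPU-only machines; the
+# device move above follows the same pattern.
+def _example_inplace_data_swap() -> None:
+    import torch
+    import torch.nn as nn
+
+    lin = nn.Linear(2, 2)
+    ref = lin.weight  # external reference, e.g. one held by an optimizer
+    with torch.no_grad():
+        lin.weight.data = lin.weight.to(torch.float64)
+    assert lin.weight is ref  # same Parameter object, new storage
+    assert lin.weight.dtype == torch.float64
+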
+
+def _warn_cpu_init():
+    warnings.warn(
+        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
+        "initialization run on CPU, which may be slower than on GPU. We "
+        "recommend passing in the `device_id` argument for FSDP to move "
+        "`module` to GPU for the sharding initialization. `module` must also "
+        "be on GPU device to work with the `sync_module_states=True` flag "
+        "since that requires GPU communication."
+    )
+
+
+def _get_compute_device(
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    device_from_device_id: Optional[torch.device],
+    rank: int,
+) -> torch.device:
+    """
+    Determine and return this FSDP instance's compute device.
+
+    If a device is specified by ``device_id``, then this returns that device.
+    Otherwise, if the
+    module is already on a non-CPU device, then the compute device is that non-CPU
+    device. If the module is on CPU, then the compute device is the current
+    device.
+
+    Since this method should be called after materializing the module, any
+    non-CPU device should not be meta device. For now, the compute device is
+    always a CUDA GPU device with its explicit index.
+
+    Precondition: ``_check_single_device_module()`` and
+    ``_move_module_to_device()``.
+    """
+    param = next(_get_orig_params(module, ignored_params), None)
+    if param is not None and param.device.type != "cpu":
+        compute_device = param.device  # Determined by model param placement
+    else:
+        if device_from_device_id is not None and device_from_device_id.type != "cuda":
+            compute_device = device_from_device_id  # Determined by custom backend
+        else:
+            compute_device = torch.device("cuda", torch.cuda.current_device())
+    if device_from_device_id is not None and compute_device != device_from_device_id:
+        raise ValueError(
+            f"Inconsistent compute device and `device_id` on rank {rank}: "
+            f"{compute_device} vs {device_from_device_id}"
+        )
+    return compute_device
+
+
+# TODO: See how to deprecate!
+def _sync_module_params_and_buffers(
+    module: nn.Module,
+    params: List[nn.Parameter],
+    process_group: dist.ProcessGroup,
+) -> None:
+    """
+    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.
+
+    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
+    been set.
+    """
+    module_states: List[torch.Tensor] = []
+    for buffer in module.buffers():
+        # Avoid re-synchronizing buffers in case of nested wrapping
+        if not getattr(buffer, FSDP_SYNCED, False):
+            setattr(buffer, FSDP_SYNCED, True)
+            detached_buffer = buffer.detach()
+            if is_traceable_wrapper_subclass(detached_buffer):
+                # NOTE: Here we assume no nested subclasses, at most one level of subclass
+                # in both model's buffers and params
+                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
+                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
+                module_states.extend(inner_buffers)
+            else:
+                module_states.append(detached_buffer)
+
+    for param in params:
+        detached_param = param.detach()
+        if is_traceable_wrapper_subclass(detached_param):
+            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
+            inner_params = [getattr(detached_param, attr) for attr in attrs]
+            module_states.extend(inner_params)
+        else:
+            module_states.append(detached_param)
+
+    _check_module_states_for_sync_module_states(module_states)
+    _sync_params_and_buffers(
+        process_group,
+        module_states,
+        PARAM_BROADCAST_BUCKET_SIZE,
+        src=0,
+    )
+
+
+def _sync_module_states(
+    params: List[nn.Parameter],
+    buffers: List[torch.Tensor],
+    process_group: dist.ProcessGroup,
+) -> None:
+    # Assumes that each call to this method passes in disjoint `params`
+    # and `buffers` across calls, so there is no chance of re-synchronizing
+    params_and_buffers = [param.detach() for param in params] + [
+        buffer.detach() for buffer in buffers
+    ]
+    _check_module_states_for_sync_module_states(params_and_buffers)
+    _sync_params_and_buffers(
+        process_group,
+        params_and_buffers,
+        PARAM_BROADCAST_BUCKET_SIZE,
+        src=0,
+    )
+
+
+def _check_module_states_for_sync_module_states(
+    module_states: List[torch.Tensor],
+) -> None:
+    if module_states and any(
+        tensor.device == torch.device("cpu") for tensor in module_states
+    ):
+        raise ValueError(
+            "The module has CPU parameters or buffers when `sync_module_states=True`, "
+            "which requires them to be on GPU. Please specify the `device_id` argument "
+            "or move the module to GPU before passing it to FSDP."
+        )
+
+
+def _get_orig_params(
+    module: nn.Module,
+    ignored_params: Set[nn.Parameter],
+) -> Iterator[nn.Parameter]:
+    """
+    Return an iterator over the original parameters in ``module``.
+
+    The iterator does not return
+    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
+    present due to nested FSDP wrapping), or any original parameters already
+    flattened (only relevant when ``use_orig_params=True``).
+    """
+    param_gen = module.parameters()
+    try:
+        while True:
+            param = next(param_gen)
+            if param not in ignored_params and not _is_fsdp_flattened(param):
+                yield param
+    except StopIteration:
+        pass
+
+
+def _check_orig_params_flattened(
+    fsdp_module,
+    ignored_params: Set[nn.Parameter],
+) -> None:
+    """
+    Check that original parameters in ``fsdp_module`` have been flattened.
+
+    The flattened parameters are made
+    invisible to ``named_parameters()`` for the module hierarchy rooted at
+    ``fsdp_module``. This should be called as a sanity check after flattening
+    the wrapped module's parameters.
+    """
+    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
+        if param not in ignored_params and not _is_fsdp_flattened(param):
+            raise RuntimeError(
+                f"Found an unflattened parameter: {param_name}; "
+                f"{param.size()} {param.__class__}"
+            )
+
+
+def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
+    return (
+        default_hooks.allreduce_hook
+        if sharding_strategy == ShardingStrategy.NO_SHARD
+        else default_hooks.reduce_scatter_hook
+    )
+
+
+def _get_default_comm_hook_state(
+    process_group: dist.ProcessGroup,
+) -> default_hooks.DefaultState:
+    return default_hooks.DefaultState(process_group=process_group)
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b37ad626fa1b2029179d5f07a43f602ecff953
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py
@@ -0,0 +1,33 @@
+import collections
+from typing import Deque, Optional
+
+import torch
+
+
+class _FreeEventQueue:
+    """
+    This tracks all pending frees corresponding to inflight all-gathers. The
+    queueing pattern is iterative enqueues with a single dequeue per iteration
+    once the limit ``_max_num_inflight_all_gathers`` is reached.
+    """
+
+    def __init__(self) -> None:
+        self._queue: Deque[torch.cuda.Event] = collections.deque()
+        self._max_num_inflight_all_gathers = 2  # empirically chosen
+
+    def enqueue(self, free_event: torch.cuda.Event) -> None:
+        """Enqueues a free event."""
+        self._queue.append(free_event)
+
+    def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
+        """Dequeues a single event if the limit is reached."""
+        if len(self._queue) >= self._max_num_inflight_all_gathers:
+            return self._dequeue()
+        return None
+
+    def _dequeue(self) -> Optional[torch.cuda.Event]:
+        """Dequeues a free event if possible."""
+        if self._queue:
+            event = self._queue.popleft()
+            return event
+        return None
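+
+
+# Illustrative sketch (CPU-only stand-in): the rate-limiting pattern above with
+# plain strings in place of `torch.cuda.Event`. Enqueue one event per
+# all-gather; once the inflight limit is reached, each new enqueue pairs with a
+# dequeue whose event would be synchronized before freeing memory.
+def _example_rate_limit_pattern() -> None:
+    queue = _FreeEventQueue()
+    queue.enqueue("event-0")  # type: ignore[arg-type]
+    assert queue.dequeue_if_needed() is None  # under the limit: nothing to wait on
+    queue.enqueue("event-1")  # type: ignore[arg-type]
+    assert queue.dequeue_if_needed() == "event-0"  # at the limit: oldest returned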
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42ed3bfc3fd49945d3be9b48a78227b5edcabe0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py
@@ -0,0 +1,2086 @@
+import copy
+import functools
+import logging
+import warnings
+from contextlib import ExitStack
+from dataclasses import dataclass, field
+from typing import (
+    Any,
+    cast,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    NamedTuple,
+    no_type_check,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.nn as nn
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
+from torch.distributed._tensor import DTensor, Replicate
+from torch.distributed.distributed_c10d import _get_pg_default_device
+from torch.distributed.fsdp._common_utils import (
+    _apply_to_modules,
+    _FSDPState,
+    _get_module_fsdp_state_if_fully_sharded_module,
+    _get_param_to_fqns,
+    _module_handle,
+    _named_parameters_with_duplicates,
+    clean_tensor_name,
+)
+from torch.distributed.fsdp._debug_utils import SimpleProfiler
+from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle
+from torch.distributed.fsdp._fsdp_extensions import (
+    _ext_chunk_dtensor,
+    _ext_chunk_tensor,
+)
+from torch.distributed.fsdp._runtime_utils import (
+    _lazy_init,
+    _reset_flat_param_grad_info_if_needed,
+)
+from torch.distributed.fsdp.api import (
+    ShardingStrategy,
+    StateDictSettings,
+    StateDictType,
+)
+from torch.utils._pytree import tree_map_only
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FSDPParamInfo:
+    state: _FSDPState
+    handle: FlatParamHandle
+    param_indices: Dict[str, int]
+    param_requires_grad: List[bool]
+
+
+def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
+    keys = sorted(dictionary.keys())
+    for k in keys:
+        yield k, dictionary[k]
+
+
+@dataclass
+class _ConsolidatedOptimState:
+    """
+    This holds the consolidated optimizer state on the target rank. Positive-
+    dimension tensor state is communicated across ranks, while zero-dimension
+    tensor state and non-tensor state is taken directly from the target rank.
+
+    PyTorch version 1.12 moved to using zero-dimension tensors for scalar
+    values, but user-implemented optimizers may still use float (i.e. a
+    non-tensor). Thus, we support both and handle them identically.
+
+    Attributes:
+        tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension
+            tensor state name to the unsharded flat tensor representing the
+            state.
+        zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero-
+            dimension tensor state name to its value.
+        non_tensor_state (Dict[str, Any]): Mapping from non-tensor state
+            name to its value.
+    """
+
+    tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
+    non_tensor_state: Dict[str, Any] = field(default_factory=dict)
+
+
+class _PosDimTensorInfo(NamedTuple):
+    """
+    Metadata for positive-dimension tensors used internally for
+    :meth:`scatter_full_optim_state_dict`.
+
+    Attributes:
+        shape (torch.Size): Sharded tensor shape (which is equal to the
+            unsharded tensor shape if the tensor is optimizer state for a
+            non-FSDP parameter and is hence not sharded).
+        dtype (torch.dtype): Data type of the tensor.
+    """
+
+    shape: torch.Size
+    dtype: torch.dtype
+
+
+class _OptimStateKey(NamedTuple):
+    """
+    This represents an optimizer state key that may be used commonly across
+    ranks. It is based on the unflattened parameter names rather than parameter
+    IDs to make it independent of each rank's own optimizer construction.
+    """
+
+    unflat_param_names: Tuple[str, ...]
+    is_fsdp_managed: bool
+
+
+def _unflatten_optim_state(
+    fsdp_param_info: FSDPParamInfo,
+    flat_param_state: Dict[str, Any],
+    to_save: bool,
+    shard_state: bool,
+    cpu_offload: bool,
+) -> List[Dict[str, Any]]:
+    """
+    Unflattens the optimizer state, consisting of the "state" part and the
+    "param_groups" part. Unflattening the "state" part involves consolidating
+    the state on the target rank and remapping from flattened to unflattened
+    parameter IDs, and the "param_groups" part only involves remapping from
+    flattened to unflattened parameter IDs.
+
+    Args:
+        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
+            mapping from FQN to original parameter index.
+        flat_param_state (Dict[str, Any]): Entry for the flat parameter in the
+            "state" part of the optimizer state dict.
+        to_save (bool): Whether to save the state on this rank.
+        shard_state (bool): Whether to shard the unflattened tensor state
+            (only valid when ``to_save`` is ``True``).
+        cpu_offload (bool): Whether to move the returned tensor state to CPU.
+
+    Returns:
+        List[Dict[str, Any]]: A :class:`list` holding the entries in the
+        "state" part of the optimizer state dict corresponding to the
+        unflattened parameters comprising the flat parameter if on the target
+        rank or an empty :class:`list` otherwise. The final optimizer state
+        dict will need to map these entries using the proper unflattened
+        parameter IDs.
+    """
+    assert (
+        not shard_state or to_save
+    ), "If ``shard_state`` is True, ``to_save`` has to be True."
+    consolidated_state = _communicate_optim_state(
+        fsdp_param_info,
+        flat_param_state,
+    )
+    if to_save:
+        unflat_param_state = _unflatten_communicated_optim_state(
+            fsdp_param_info,
+            consolidated_state,
+            shard_state,
+        )
+        for optim_state in unflat_param_state:
+            # We cannot use `.items()` below because we would otherwise hit a
+            # concurrent modification error while mutating `optim_state`
+            if cpu_offload:
+                for key in list(optim_state.keys()):
+                    state = optim_state[key]
+                    if not isinstance(state, torch.Tensor):
+                        continue
+                    optim_state[key] = state.cpu()
+        return unflat_param_state
+    else:
+        return []
+
+
+def _is_zero_dim_tensor(x: Any) -> bool:
+    return torch.is_tensor(x) and x.dim() == 0
+
+
+def _communicate_optim_state(
+    fsdp_param_info: FSDPParamInfo,
+    flat_param_state: Dict[str, Any],
+) -> _ConsolidatedOptimState:
+    """
+    Communicates the optimizer state for a flat parameter across ranks. All
+    ranks will hold the entire non-sharded optimizer state on GPU.
+
+    If ``N`` is the number of tensor optimizer states in the optimizer state
+    dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1``
+    otherwise (where the plus 1 comes from all-gathering the padding per rank).
+
+    Args:
+        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
+            mapping from FQN to original parameter index.
+        flat_param_state (Dict[str, Any]): The entry in the "state" part of the
+            optimizer state dict corresponding to the flat parameter.
+
+    Returns:
+        _ConsolidatedOptimState: Consolidated optimizer state for the target
+        flat parameter.
+    """
+    fsdp_state = fsdp_param_info.state
+    flat_param = fsdp_param_info.handle.flat_param
+    state = _ConsolidatedOptimState()
+    tensor_state, zero_dim_tensor_state, non_tensor_state = (
+        state.tensor_state,
+        state.zero_dim_tensor_state,
+        state.non_tensor_state,
+    )
+
+    for state_name, value in sorted_items(flat_param_state):
+        # Positive-dimension tensor state: communicate across ranks
+        if torch.is_tensor(value) and value.dim() > 0:
+            # If the parameter is not sharded, then neither is the
+            # positive-dimension tensor state, so no need to communicate it --
+            # we take the target rank's value
+            if (
+                fsdp_state.world_size == 1
+                or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+            ):
+                tensor_state[state_name] = value
+                continue
+            assert (
+                fsdp_state.compute_device is not None
+            ), "compute_device has not been initialized"
+            if value.device.type != fsdp_state.compute_device.type:
+                value = value.to(fsdp_state.compute_device)
+            # Assume that positive-dimension tensor optimizer state
+            # has the same shape as the sharded flat parameter
+            buffer_size = flat_param._full_param_padded.size()  # type: ignore[attr-defined]
+            tensor_buffer = value.new_zeros(*buffer_size)
+            dist.all_gather_into_tensor(
+                tensor_buffer, value, group=fsdp_state.process_group
+            )
+            fsdp_state._device_handle.synchronize()
+            unpadded_numel = cast(
+                nn.Parameter, flat_param._unpadded_unsharded_size
+            ).numel()
+            tensor_state[state_name] = tensor_buffer[:unpadded_numel]
+        # Zero-dimension tensor state and non-tensor state: take this rank's
+        # value directly
+        else:
+            if _is_zero_dim_tensor(value):
+                zero_dim_tensor_state[state_name] = value.detach().clone()
+            else:
+                non_tensor_state[state_name] = value
+    return state
+
+
+def _unflatten_communicated_optim_state(
+    fsdp_param_info: FSDPParamInfo,
+    state: _ConsolidatedOptimState,
+    shard_state: bool,
+) -> List[Dict[str, Any]]:
+    """
+    Unflattens the communicated optimizer state (given by ``tensor_state``,
+    ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat
+    parameter. This should only be called on the target rank.
+
+    Args:
+        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
+            mapping from FQN to original parameter index.
+        state (_ConsolidatedOptimState): Consolidated optimizer state.
+
+    Returns:
+        List[Dict[str, Any]]: A :class:`list` holding the entries in the
+        "state" part of the optimizer state dict corresponding to the
+        unflattened parameters comprising the flat parameter. The final
+        optimizer state dict will need to map these entries using the proper
+        unflattened parameter IDs.
+    """
+    fsdp_state = fsdp_param_info.state
+    handle = fsdp_param_info.handle
+    flat_param = handle.flat_param
+    unflat_param_state: List[Dict[str, Any]] = []
+    flat_param_views: Dict[str, Iterator] = {}
+    num_unflat_params = flat_param._num_params
+    tensor_state, zero_dim_tensor_state, non_tensor_state = (
+        state.tensor_state,
+        state.zero_dim_tensor_state,
+        state.non_tensor_state,
+    )
+
+    for _ in range(num_unflat_params):
+        unflat_state_param = {}
+        # Add positive-dimension tensor state: unflatten with views
+        for state_name, flat_tensor in sorted_items(tensor_state):
+            views_generated = state_name in flat_param_views
+            if not views_generated:
+                views = handle._get_unflat_views(flat_tensor)
+                flat_param_views[state_name] = views
+            else:
+                views = flat_param_views[state_name]
+            optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views)
+            if shard_state:
+                osd_config = fsdp_state._optim_state_dict_config
+                if getattr(osd_config, "_use_dtensor", False):
+                    assert fsdp_state._device_mesh is not None
+                    optim_state = _ext_chunk_dtensor(
+                        optim_state,
+                        fsdp_state.rank,
+                        fsdp_state._device_mesh,
+                        fsdp_state._fsdp_extension,
+                    )
+                else:
+                    assert fsdp_state.process_group is not None
+                    optim_state = _ext_chunk_tensor(
+                        optim_state,
+                        fsdp_state.rank,
+                        fsdp_state.world_size,
+                        fsdp_state._device_handle.device_count(),
+                        fsdp_state.process_group,
+                        fsdp_state._fsdp_extension,
+                    )
+            unflat_state_param[state_name] = optim_state
+
+        # Add zero-dimension tensor state: take the target rank's value
+        for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state):
+            unflat_state_param[state_name] = zero_dim_tensor
+        # Add non-tensor state: take the target rank's value
+        for state_name, non_tensor in sorted_items(non_tensor_state):
+            unflat_state_param[state_name] = non_tensor
+        unflat_param_state.append(unflat_state_param)
+    return unflat_param_state
+
+
+def _broadcast_processed_state(
+    fsdp_state: _FSDPState,
+    optim_state: Dict[str, Any],
+    group: Optional[dist.ProcessGroup],
+) -> Dict[str, Any]:
+    objects: List[Any] = [None]
+    if fsdp_state.rank == 0:
+        objects[0] = tree_map_only(
+            torch.Tensor,
+            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),  # type: ignore[union-attr]
+            optim_state,
+        )
+    dist.broadcast_object_list(objects, src=0, group=group)
+    if fsdp_state.rank == 0:
+        return optim_state
+    else:
+        return objects[0]
+
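+
+# Illustrative sketch of what the `tree_map_only` call above produces on rank
+# 0, shown without any process group. Zero-dim tensors are kept (moved to CPU)
+# while positive-dimension tensors are replaced by lightweight metadata, which
+# is what the non-zero ranks receive via `broadcast_object_list`. A plain
+# (shape, dtype) tuple stands in for `_PosDimTensorInfo` here.
+def _example_process_state_for_broadcast() -> None:
+    state = {"step": torch.tensor(3.0), "exp_avg": torch.zeros(4)}
+    processed = tree_map_only(
+        torch.Tensor,
+        lambda v: v.cpu() if v.dim() == 0 else (tuple(v.shape), v.dtype),
+        state,
+    )
+    assert processed["step"].item() == 3.0
+    assert processed["exp_avg"] == ((4,), torch.float32)
+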
+
+def _broadcast_state(
+    fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]
+) -> Any:
+    if fsdp_state.rank == 0:
+        if not isinstance(state, torch.Tensor) or state.dim() == 0:
+            return state
+        tensor = state.to(fsdp_state.compute_device)
+    else:
+        if isinstance(state, torch.Tensor):
+            assert state.dim() == 0, (
+                "For non-zero ranks, a tensor state should have zero dimension, "
+                "but got the state with shape {state.shape()}."
+            )
+            return state
+        elif not isinstance(state, _PosDimTensorInfo):
+            return state
+        tensor = torch.zeros(
+            state.shape, dtype=state.dtype, device=fsdp_state.compute_device
+        )
+    dist.broadcast(tensor, src=0, group=group)
+    return tensor
+
+
+def _shard_orig_param_state(
+    fsdp_param_info: FSDPParamInfo,
+    fqn: str,
+    optim_state: Dict[str, Any],
+) -> Dict[str, Any]:
+    """
+    Shard the optimizer state for the original parameter with the name ``fqn``.
+    This API should only be used when ``use_orig_params`` is True.
+    """
+    if not optim_state:
+        return {}
+    fsdp_state = fsdp_param_info.state
+    flat_param = fsdp_param_info.handle.flat_param
+    param_idx = fsdp_param_info.param_indices[fqn]
+    shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
+    optim_state = _gather_state_dict(
+        optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device
+    )
+    if not shard_param_info.in_shard:
+        return {}
+    # Flatten and shard the state.
+    new_optim_state: Dict[str, Any] = {}
+    intra_param_start_idx = shard_param_info.intra_param_start_idx
+    intra_param_end_idx = shard_param_info.intra_param_end_idx
+    for state_name, value in optim_state.items():
+        if (
+            torch.is_tensor(value)
+            and value.dim() > 0
+            and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+        ):
+            value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
+        new_optim_state[state_name] = value
+    return new_optim_state
+
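+
+# Illustrative sketch of the intra-parameter slice above: a rank keeps only the
+# inclusive [start, end] range of the flattened state that overlaps its shard.
+# The state tensor and the indices below are made up for the example.
+def _example_intra_param_slice() -> None:
+    exp_avg = torch.arange(12.0).reshape(3, 4)
+    intra_param_start_idx, intra_param_end_idx = 4, 9
+    local = exp_avg.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()
+    assert local.tolist() == [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
+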
+
+def _flatten_optim_state_dict(
+    optim_state_dict: Dict[str, Any],
+    model: nn.Module,
+    use_orig_params: bool = False,
+    optim: Optional[torch.optim.Optimizer] = None,
+    rank0_only: bool = False,
+    group: Optional[dist.ProcessGroup] = None,
+) -> Dict[str, Any]:
+    """
+    Flattens the full optimizer state dict, still keying by unflattened parameter
+    names.
+
+    If ``use_orig_params`` is True, each rank will have all FSDP-managed
+    parameters but some of these parameters may be empty due to the sharding.
+    For a regular optim.Optimizer, states for those empty parameters will
+    not be initialized. So, when aggregating the FQNs across ranks, no assert
+    will be raised on a rank even if it does not have all the states -- it is
+    valid and FSDP knows how to aggregate them. However, FSDP has to ignore
+    parameters that are not managed by FSDP and do not exist on the local
+    rank -- they are managed by another parallelism scheme, and FSDP does not
+    know how to handle/aggregate them.
+
+    Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to
+    flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require
+    all the states even if the corresponding parameters are empty. To this end,
+    ``optim`` will be used to get the initial state of the empty parameters.
+    ``optim`` should only be non-None if the optimizer is a KeyedOptimizer or
+    NamedOptimizer.
+
+    Returns:
+        Dict[str, Any]: The flattened optimizer state dict.
+    """
+    SimpleProfiler.reset()
+
+    unflat_osd = optim_state_dict
+    if "state" not in unflat_osd and not rank0_only:
+        raise ValueError(
+            '`optim_state_dict` must have the key "state" '
+            "to be a valid optimizer state dict"
+        )
+    param_to_fqns = _get_param_to_fqns(model)
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
+    fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state
+
+    # Broadcast unflat_osd without non-scalar tensor if rank0_only is True.
+    if rank0_only:
+        unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)
+
+    # Construct the "state" part
+    flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
+    unflat_osd_state = unflat_osd["state"]
+    all_state_keys = set(unflat_osd_state.keys())
+
+    for param, fqns in param_to_fqns.items():
+        fqn = fqns[0]
+        if fqn not in unflat_osd_state:
+            continue
+        all_state_keys.difference_update(fqns)
+
+        if rank0_only:
+            for fqn in fqns:
+                if not unflat_osd_state[fqn]:
+                    continue
+                for state_name in unflat_osd_state[fqn].keys():
+                    unflat_osd_state[fqn][state_name] = _broadcast_state(
+                        fsdp_state, unflat_osd_state[fqn][state_name], group=group
+                    )
+            fqn = fqns[0]
+        if fqn in fqn_to_fsdp_param_info:
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+            if use_orig_params:
+                with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
+                    flat_state = _shard_orig_param_state(
+                        fsdp_param_info,
+                        fqn,
+                        unflat_osd_state[fqn],
+                    )
+            else:
+                flat_state = _flatten_optim_state(
+                    fsdp_param_info,
+                    unflat_osd_state,
+                    fqns,
+                )
+            key = _OptimStateKey(tuple(fqns), True)
+            # Only include non-empty states, as expected by
+            # `torch.optim.Optimizer`, unless the optimizer is KeyedOptimizer
+            # or NamedOptimizer.
+            if flat_state:
+                flat_osd_state[key] = flat_state
+            elif use_orig_params:
+                assert (
+                    len(fqns) == 1
+                ), f"use_orig_params is True but there are multiple FQNs, {fqns}."
+                if optim is not None:  # NamedOptimizer or KeyedOptimizer case.
+                    state = optim.state.get(param, None)  # type: ignore[call-overload]
+                    if state is not None:
+                        flat_osd_state[key] = copy.deepcopy(state)
+                    else:
+                        warnings.warn(
+                            f"optim_state[{key}] is not on rank{fsdp_state.rank}."
+                        )
+
+            else:
+                raise RuntimeError(
+                    f"The state of {key} is empty. This should happen when "
+                    "use_orig_params=True."
+                )
+        else:  # do not flatten non-FSDP parameters' states
+            assert len(fqns) == 1
+            key = _OptimStateKey(tuple(fqns), False)
+            flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])
+
+        if rank0_only:
+            for fqn in fqns:
+                if not unflat_osd_state[fqn]:
+                    continue
+                for state_name, param_state in list(unflat_osd_state[fqn].items()):
+                    if fsdp_state.rank > 0:
+                        # Dereference the tensor so that PyTorch can collect the memory.
+                        del unflat_osd_state[fqn][state_name]
+                    else:
+                        # Move the tensor in the original osd back to CPU to make the
+                        # original osd unaffected.
+                        unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][
+                            state_name
+                        ].cpu()
+
+    # Handle user-defined state, states that are not associated with parameters.
+    for key in all_state_keys:
+        user_state = unflat_osd_state[key]
+        if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params:
+            user_state = _broadcast_state(fsdp_state, user_state, group=group)
+        flat_osd_state[key] = copy.copy(user_state)
+
+    SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ")
+    # Construct the "param_groups" part -- copy as is since it will be
+    # rekeyed later according to the target rank's optimizer
+    # Only copy param_groups if it exists in unflat_osd
+    if "param_groups" in unflat_osd:
+        flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"])
+        return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
+    else:
+        return {"state": flat_osd_state}
+
+
+def _flatten_optim_state(
+    fsdp_param_info: FSDPParamInfo,
+    unflat_osd_state: Dict[str, Dict[str, Any]],
+    unflat_param_names: List[str],
+) -> Dict[str, Any]:
+    """
+    Flattens the optimizer state in ``full_optim_state_dict`` for a single
+    flat parameter in ``fsdp_param_info`` corresponding to the unflattened
+    parameter names in ``unflat_param_names``.
+
+    Args:
+        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
+            mapping from FQN to original parameter index.
+        unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the
+            optimizer state dict corresponding to the unflattened parameters.
+        unflat_param_names (List[str]): A :class:`list` of unflattened
+            parameter names corresponding to the flat parameter ``flat_param``.
+
+    Returns:
+        Dict[str, Any]: A :class:`dict` mapping state names to their values for
+        a particular flat parameter. The sharded optimizer state dict's "state"
+        part will map a key to this returned value.
+    """
+    fsdp_state = fsdp_param_info.state
+    handle = fsdp_param_info.handle
+    flat_param = handle.flat_param
+    num_unflat_params = len(unflat_param_names)
+    assert num_unflat_params > 0, (
+        "Expects at least one unflattened parameter corresponding to the "
+        "flat parameter"
+    )
+    unflat_param_shapes = flat_param._shapes
+    num_unflat_param_shapes = len(unflat_param_shapes)
+    assert (
+        num_unflat_params == num_unflat_param_shapes
+    ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
+
+    # Check if these unflattened parameters have any optimizer state
+    has_state = [
+        bool(unflat_param_name in unflat_osd_state)
+        for unflat_param_name in unflat_param_names
+    ]
+    # If none of the unflattened parameters comprising this flat parameter have
+    # any state, then we do not want an entry in the optimizer state dict
+    if not any(has_state):
+        return {}  # no need to flatten any state
+    # There may still be some unflattened parameters with state and some
+    # without
+    unflat_param_states = [
+        _gather_state_dict(
+            unflat_osd_state[unflat_param_name],
+            pg=fsdp_state.process_group,
+            device=fsdp_state.compute_device,
+        )
+        if unflat_param_name in unflat_osd_state
+        else None
+        for unflat_param_name in unflat_param_names
+    ]
+    # Check that the unflattened parameters have the same state names
+    state_names = None
+    for unflat_param_state in unflat_param_states:
+        if unflat_param_state is None:
+            continue
+        if state_names is None:
+            state_names = set(unflat_param_state.keys())
+        else:
+            if state_names != set(unflat_param_state.keys()):
+                raise ValueError(
+                    "Differing optimizer state names for the unflattened "
+                    f"parameters: {unflat_param_names}"
+                )
+    assert state_names is not None
+
+    # Flatten the state
+    flat_state: Dict[str, Any] = {}
+    for state_name in state_names:
+        state_values = [
+            unflat_param_state[state_name] if unflat_param_state is not None else None
+            for unflat_param_state in unflat_param_states
+        ]
+        non_none_state_values = [v for v in state_values if v is not None]
+        # If all ranks have None, this is a None value
+        if not non_none_state_values:
+            flat_state[state_name] = None
+            continue
+        are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True
+        for v in non_none_state_values:
+            are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0
+            are_zero_dim_tensors &= _is_zero_dim_tensor(v)
+            are_non_tensors &= not torch.is_tensor(v)
+        types = {type(v) for v in non_none_state_values}
+        if len(types) != 1 or not (
+            are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors
+        ):
+            raise ValueError(
+                f"Differing optimizer state types for state {state_name}, "
+                f"values {non_none_state_values}, and unflattened parameter "
+                f"names {unflat_param_names}"
+            )
+        if are_pos_dim_tensors:
+            flat_tensor = _flatten_tensor_optim_state(
+                state_name,
+                state_values,
+                unflat_param_names,
+                unflat_param_shapes,
+                handle,
+            )
+            # Shard the flattened tensor immediately to minimize max memory
+            # usage
+            if (
+                fsdp_state.world_size != 1
+                and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
+            ):
+                sharded_flat_tensor, _ = FlatParamHandle._get_shard(
+                    flat_tensor,
+                    fsdp_state.rank,
+                    fsdp_state.world_size,
+                )
+            else:
+                sharded_flat_tensor = flat_tensor
+            flat_state[state_name] = sharded_flat_tensor
+        elif are_zero_dim_tensors:
+            flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
+                state_name,
+                state_values,
+                unflat_param_names,
+            )
+        else:
+            assert are_non_tensors
+            flat_state[state_name] = _flatten_non_tensor_optim_state(
+                state_name,
+                state_values,
+                unflat_param_names,
+            )
+
+    return flat_state
+
+
+def _flatten_tensor_optim_state(
+    state_name: str,
+    pos_dim_tensors: List[torch.Tensor],
+    unflat_param_names: List[str],
+    unflat_param_shapes: Sequence[torch.Size],
+    handle: FlatParamHandle,
+) -> torch.Tensor:
+    """
+    Flattens the positive-dimension tensor optimizer state given by the values
+    ``pos_dim_tensors`` for the state ``state_name`` for a single flat parameter
+    from ``handle`` corresponding to the unflattened parameter names
+    ``unflat_param_names`` and unflattened parameter shapes
+    ``unflat_param_shapes``. This flattens each unflattened parameter's tensor
+    state into one tensor.
+
+    NOTE: We use zero tensors for any unflattened parameters without state
+    since some value is required to fill those entries. This assumes that the
+    zero tensor is mathematically equivalent to having no state, which is true
+    for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all
+    optimizers.
+
+    Args:
+        state_name (str): Optimizer state name.
+        pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor
+            optimizer state values for the unflattened parameters corresponding
+            to the single flat parameter.
+        unflat_param_names (List[str]): A :class:`list` of unflattened
+            parameter names corresponding to the single flat parameter.
+        unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes
+            corresponding to the single flat parameter.
+        handle (FlatParamHandle): The flat parameter's handle.
+
+    Returns:
+        torch.Tensor: A flat tensor containing the optimizer state
+        corresponding to ``state_name`` constructed by concatenating the
+        unflattened parameter tensor states in ``pos_dim_tensors`` (using zero
+        tensors for any unflattened parameters without the state).
+    """
+    flat_param = handle.flat_param
+    non_none_tensors = [t for t in pos_dim_tensors if t is not None]
+    # Check that all are tensors with the same dtype
+    dtypes = {t.dtype for t in non_none_tensors}
+    if len(dtypes) != 1:
+        raise ValueError(
+            "All unflattened parameters comprising a single flat "
+            "parameter must have positive-dimension tensor state with the "
+            f"same dtype but got dtypes {dtypes} for state {state_name} and "
+            f"unflattened parameter names {unflat_param_names}"
+        )
+    dtype = next(iter(dtypes))
+    # Check that each tensor state matches its parameter's shape
+    for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes):
+        if tensor is None and len(shape) == 0:
+            raise ValueError("Flattening a zero-dimension parameter is not supported")
+        elif tensor is not None and tensor.shape != shape:
+            raise ValueError(
+                "Tensor optimizer state does not have same shape as its "
+                f"parameter: {tensor.shape} {shape}"
+            )
+    # Flatten the tensor states: we do not need to add any right-hand-side
+    # padding since the flat optimizer state tensor is sharded via
+    # `_get_shard()`, which pads the shard as needed (just like for the flat
+    # parameter)
+    cpu_device = torch.device("cpu")
+    tensors_to_flatten = [
+        torch.flatten(state_value.to(cpu_device))
+        if state_value is not None
+        else torch.flatten(
+            torch.zeros(
+                size=shape,
+                dtype=dtype,
+                device=cpu_device,
+            )
+        )
+        for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes)
+    ]
+    flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
+    flat_param_shape = flat_param._unpadded_unsharded_size  # type: ignore[attr-defined]
+    assert flat_tensor.shape == flat_param_shape, (
+        f"tensor optim state: {flat_tensor.shape} "
+        f"flat parameter: {flat_param_shape}"
+    )
+    return flat_tensor
+
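+
+# Illustrative sketch of the flattening rule above, without a FlatParamHandle
+# and without alignment padding: per-parameter tensor states are flattened and
+# concatenated, and a parameter with no state contributes zeros of its shape.
+# The shapes and states below are made up for the example.
+def _example_flatten_tensor_state() -> None:
+    shapes = [torch.Size([2, 3]), torch.Size([4])]
+    states = [torch.ones(2, 3), None]  # the second parameter has no state yet
+    pieces = [
+        torch.flatten(state) if state is not None else torch.zeros(shape).flatten()
+        for state, shape in zip(states, shapes)
+    ]
+    flat = torch.cat(pieces)
+    assert flat.shape == torch.Size([10])  # 2*3 + 4 elements, like the flat parameter
+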
+
+def _flatten_zero_dim_tensor_optim_state(
+    state_name: str,
+    zero_dim_tensors: List[torch.Tensor],
+    unflat_param_names: List[str],
+) -> torch.Tensor:
+    """
+    Flattens the zero-dimension tensor optimizer state given by the values
+    ``zero_dim_tensors`` for the state ``state_name`` for a single flat
+    parameter corresponding to the unflattened parameter names
+    ``unflat_param_names`` by enforcing that all tensors are the same and using
+    that common value.
+
+    NOTE: The requirement that the tensors are the same across all unflattened
+    parameters comprising the flat parameter is needed to maintain the
+    invariant that FSDP performs the same computation as its non-sharded
+    equivalent. This means that none of the unflattened parameters can be
+    missing this state since imposing a value may differ from having no value.
+    For example, for Adam's "step", no value means maximum bias correction,
+    while having some positive value means less bias correction.
+
+    Args:
+        state_name (str): Optimizer state name.
+        zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state
+            for the unflattened parameters corresponding to the single
+            flat parameter.
+        unflat_param_names (List[str]): A :class:`list` of unflattened
+            parameter names corresponding to the single flat parameter.
+
+    Returns:
+        torch.Tensor: A zero-dimensional tensor giving the value of the state
+        ``state_name`` for all unflattened parameters corresponding to the
+        names ``unflat_param_names``.
+    """
+    non_none_tensors = [t for t in zero_dim_tensors if t is not None]
+    # Enforce that all have the same value and dtype
+    values_set = {t.item() if t is not None else None for t in zero_dim_tensors}
+    dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors}
+    if (
+        len(non_none_tensors) != len(zero_dim_tensors)
+        or len(values_set) != 1
+        or len(dtypes) != 1
+    ):
+        raise ValueError(
+            "All unflattened parameters comprising a single flat "
+            "parameter must have scalar state with the same value and dtype "
+            f"but got values {values_set} and dtypes {dtypes} for state "
+            f"{state_name} and unflattened parameter names "
+            f"{unflat_param_names}"
+        )
+    value = next(iter(values_set))
+    dtype = next(iter(dtypes))
+    return torch.tensor(value, dtype=dtype, device=torch.device("cpu"))
+
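+
+# Illustrative sketch of the invariant above: the zero-dim "step" states of all
+# unflattened parameters comprising one flat parameter must agree, and the
+# common value is reused for the flat parameter. The tensors are made up.
+def _example_common_zero_dim_state() -> None:
+    steps = [torch.tensor(10.0), torch.tensor(10.0), torch.tensor(10.0)]
+    values = {t.item() for t in steps}
+    dtypes = {t.dtype for t in steps}
+    assert len(values) == 1 and len(dtypes) == 1
+    flat_step = torch.tensor(next(iter(values)), dtype=next(iter(dtypes)))
+    assert flat_step.item() == 10.0 and flat_step.dim() == 0
+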
+
+def _flatten_non_tensor_optim_state(
+    state_name: str,
+    non_tensors: List[Any],
+    unflat_param_names: List[str],
+) -> Any:
+    """
+    Flattens the non-tensor optimizer state given by the values ``non_tensors``
+    for the state ``state_name`` for a single flat parameter corresponding
+    to the unflattened parameter names ``unflat_param_names`` by enforcing that
+    all values are the same and using that common value.
+
+    See the note in :func:`_flatten_zero_dim_tensor_optim_state`.
+
+    Args:
+        state_name (str): Optimizer state name.
+        non_tensors (List[Any]): Non-tensor optimizer state for the unflattened
+            parameters corresponding to the single flat parameter.
+        unflat_param_names (List[str]): A :class:`list` of unflattened
+            parameter names corresponding to the single flat parameter.
+
+    Returns:
+        Any: A non-tensor giving the value of the state ``state_name`` for all
+        unflattened parameters corresponding to the names
+        ``unflat_param_names``.
+    """
+    non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
+    # Enforce that all have the same value (same type already checked)
+    non_tensor_set = set(non_tensors)
+    if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1:
+        raise ValueError(
+            "All unflattened parameters comprising a single flat "
+            "parameter must have scalar state with the same value and dtype "
+            f"but got values {non_tensor_set} for state {state_name} and  "
+            f"unflattened parameter names {unflat_param_names}"
+        )
+    non_tensor = next(iter(non_tensor_set))
+    return non_tensor
+
+
+def _rekey_sharded_optim_state_dict(
+    sharded_osd: Dict[str, Any],
+    model: nn.Module,
+    optim: torch.optim.Optimizer,
+    optim_input: Optional[
+        Union[
+            List[Dict[str, Any]],
+            Iterable[nn.Parameter],
+        ]
+    ],
+    using_optim_input: bool,
+    is_named_optimizer: bool = False,
+) -> Dict[str, Any]:
+    """
+    Rekeys the optimizer state dict from unflattened parameter names to flat
+    parameter IDs according to the calling rank's ``optim``, which may be
+    different across ranks. In particular, the unflattened parameter names are
+    represented as :class:`_OptimStateKey` s.
+    """
+    param_to_fqns = _get_param_to_fqns(model)
+    flat_param_to_fqn = _get_flat_param_to_fqn(model)
+    param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast(
+        Dict[nn.Parameter, Union[int, str]],
+        (
+            _get_param_to_param_id_from_optim_input(model, optim_input)
+            if using_optim_input
+            else _get_param_to_param_key(
+                optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
+            )
+        ),
+    )
+    # All parameter keys in `param_to_param_key` should be in
+    # `param_to_fqns` -- strict inequality follows when not all parameters are
+    # passed to the optimizer
+    assert len(param_to_param_key) <= len(param_to_fqns)
+
+    unflat_param_names_to_flat_param_key: Dict[
+        Tuple[str, ...], Union[int, str]
+    ] = {}  # for "state"
+    unflat_param_name_to_flat_param_key: Dict[
+        str, Union[int, str]
+    ] = {}  # for "param_groups"
+    for param, unflat_param_names in param_to_fqns.items():
+        if param not in param_to_param_key:
+            # This parameter was not passed to the optimizer
+            continue
+        flat_param_key = param_to_param_key[param]
+        unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key
+        for unflat_param_name in unflat_param_names:
+            unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key
+
+    sharded_osd_state = sharded_osd["state"]
+    rekeyed_osd_state: Dict[Union[str, int], Any] = {}
+    for key, param_state in sharded_osd_state.items():
+        if isinstance(key, str):
+            rekeyed_osd_state[key] = param_state
+            continue
+        flat_param_key = unflat_param_names_to_flat_param_key.get(
+            key.unflat_param_names, key.unflat_param_names
+        )
+        rekeyed_osd_state[flat_param_key] = param_state
+
+    # Only process param_groups if it exists in sharded_osd
+    if "param_groups" in sharded_osd:
+        rekeyed_osd_param_groups: List[Dict[str, Any]] = []
+        for unflat_param_group in sharded_osd["param_groups"]:
+            flat_param_group = copy.deepcopy(unflat_param_group)
+            flat_param_keys = sorted(
+                {
+                    unflat_param_name_to_flat_param_key[unflat_param_name]
+                    for unflat_param_name in unflat_param_group["params"]
+                }
+            )
+            flat_param_group["params"] = flat_param_keys
+            rekeyed_osd_param_groups.append(flat_param_group)
+        return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups}
+    else:
+        return {"state": rekeyed_osd_state}
+
+
+def _get_param_id_to_param_from_optim_input(
+    model: nn.Module,
+    optim_input: Optional[
+        Union[
+            List[Dict[str, Any]],
+            Iterable[nn.Parameter],
+        ]
+    ] = None,
+) -> Dict[int, nn.Parameter]:
+    """
+    Constructs a mapping from parameter IDs to parameters. This may be used
+    both for models with ``FlatParameter`` s and without.
+
+    NOTE: This method is only preserved for backward compatibility. The method
+    :meth:`_get_param_key_to_param` is the preferred code path that does not
+    rely on ``optim_input``.
+
+    NOTE: We critically assume that, whether the optimizer input is a list of
+    parameters or a list of parameter groups, :class:`torch.optim.Optimizer`
+    enumerates the parameter IDs in order. In other words, for a parameter list
+    input, the parameter IDs should be in that list order, and for a parameter
+    groups input, the parameter IDs should be in order within each parameter
+    group and in order across parameter groups.
+
+    Args:
+        model (nn.Module): Model whose parameters are passed into the
+            optimizer.
+        optim_input (Optional[Union[List[Dict[str, Any]],
+        Iterable[nn.Parameter]]]): Input passed into the optimizer
+            representing either a :class:`list` of parameter groups or an
+            iterable of parameters; if ``None``, then this method assumes the
+            input was ``model.parameters()``. (Default: ``None``)
+
+    Returns:
+        Dict[int, nn.Parameter]: Mapping from parameter IDs to parameters,
+        where the parameter ID is implicitly the enumeration index.
+    """
+    # Assume the standard case of passing `model.parameters()` to the optimizer
+    # if `optim_input` is not specified
+    if optim_input is None:
+        return dict(enumerate(model.parameters()))
+    try:
+        params = cast(List[nn.Parameter], list(optim_input))
+    except TypeError as e:
+        raise TypeError(
+            "Optimizer input should be an iterable of Tensors or dicts, "
+            f"but got {optim_input}"
+        ) from e
+    if len(params) == 0:
+        raise ValueError("Optimizer input should not be empty")
+
+    # Check if the optimizer input represents tensors or parameter groups
+    all_tensors = True
+    all_dicts = True
+    for param in params:
+        all_tensors &= isinstance(param, torch.Tensor)
+        all_dicts &= isinstance(param, dict)
+    if not all_tensors and not all_dicts:
+        raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
+    if all_tensors:
+        return dict(enumerate(params))
+    assert all_dicts
+    param_id_to_param: List[nn.Parameter] = []
+    for param_group in params:
+        has_params_key = "params" in param_group  # type: ignore[operator]
+        assert has_params_key, (
+            'A parameter group should map "params" to a list of the '
+            "parameters in the group"
+        )
+        # Implicitly map `flat_param_id` (current length of the list) to
+        # `param`
+        param_id_to_param.extend(param_group["params"])  # type: ignore[index]
+    return dict(enumerate(param_id_to_param))
+
+
+def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
+    """
+    Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes
+    from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical"
+    because ``FlatParameter``  s do not come from the original module but are
+    registered only after FSDP has been applied. This function returns the FSDP-given
+    name for the ``FlatParameter`` (usually module._flat_param) as opposed to the
+    canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``).
+
+    Consequently, this function will only return a non-empty mapping if FSDP was
+    applied with ``use_orig_params=False`` as, otherwise, the original parameters
+    are used within the module and there would be no ``FlatParameter`` s in the module.
+
+    """
+
+    def module_fn(module, prefix, tree_level, flat_param_to_fqn):
+        for param_name, param in _named_parameters_with_duplicates(
+            module, recurse=False
+        ):
+            if not isinstance(param, FlatParameter):
+                continue
+            fqn = clean_tensor_name(prefix + param_name)
+            flat_param_to_fqn[param] = fqn
+
+    def return_fn(flat_param_to_fqn):
+        return flat_param_to_fqn
+
+    flat_param_to_fqn_ret: Dict[FlatParameter, str] = {}
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
+        flat_param_to_fqn_ret,
+    )
+
+
+def _get_param_key_to_param(
+    optim: torch.optim.Optimizer,
+    model: Optional[nn.Module] = None,
+    is_named_optimizer: bool = False,
+    param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
+    flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
+) -> Dict[Union[int, str], nn.Parameter]:
+    """
+    Constructs a mapping from parameter keys to parameters. For the regular
+    optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
+    are FQNs. This API may be used both for models with ``FlatParameter`` s and
+    without.
+    """
+    clean_fqn_to_curr_fqn: Dict[str, str] = {}
+    if is_named_optimizer:
+        assert param_to_fqns is not None and flat_param_to_fqn is not None, (
+            "The optimizer is a NamedOptimizer, so `param_to_fqns` and "
+            "`flat_param_to_fqn` must not be None."
+        )
+        assert model is not None
+        for key, _ in _named_parameters_with_duplicates(model):
+            clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
+
+    param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
+    pid = 0
+    for param_group in optim.param_groups:
+        if is_named_optimizer:
+            for param in param_group["params"]:
+                assert flat_param_to_fqn is not None
+                if param in flat_param_to_fqn:
+                    # FlatParameter case
+                    key = flat_param_to_fqn[param]
+                else:
+                    assert param_to_fqns is not None
+                    # use_orig_params case
+                    assert len(param_to_fqns[param]) == 1
+                    key = param_to_fqns[param][0]
+                try:
+                    key = clean_fqn_to_curr_fqn[key]
+                except KeyError as e:
+                    raise KeyError(
+                        f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}."
+                    ) from e
+                param_key_to_param[key] = param
+        else:
+            for param in param_group["params"]:
+                param_key_to_param[pid] = param
+                pid += 1
+
+    return param_key_to_param
+
+
+def _get_param_to_param_key(
+    optim: torch.optim.Optimizer,
+    model: Optional[nn.Module] = None,
+    is_named_optimizer: bool = False,
+    param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
+    flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
+) -> Dict[nn.Parameter, Union[int, str]]:
+    """
+    Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API
+    only supports the case where `optim` is a regular optimizer, not NamedOptimizer.
+    So the parameter keys will be parameter ids.
+    """
+    param_id_to_param = _get_param_key_to_param(
+        optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
+    )
+    return {param: param_id for param_id, param in param_id_to_param.items()}
+
+
+def _get_param_to_param_id_from_optim_input(
+    model: nn.Module,
+    optim_input: Optional[
+        Union[
+            List[Dict[str, Any]],
+            Iterable[nn.Parameter],
+        ]
+    ] = None,
+) -> Dict[nn.Parameter, int]:
+    """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`."""
+    param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input)
+    return {param: param_id for param_id, param in param_id_to_param.items()}
+
+
+def _check_missing_keys_on_rank(
+    r0_optim_state_keys: List[_OptimStateKey],
+    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]],
+    param_key_to_param: Dict[Union[str, int], nn.Parameter],
+    group: Optional[dist.ProcessGroup],
+) -> None:
+    # Ensure that all ranks have at least the optimizer states needed by
+    # rank 0's optimizer
+    missing_keys: List[_OptimStateKey] = []
+    for r0_optim_state_key in r0_optim_state_keys:
+        if r0_optim_state_key not in optim_state_key_to_param_key:
+            # A parameter from rank 0's optimizer does not exist for this
+            # rank's optimizer
+            missing_keys.append(r0_optim_state_key)
+            continue
+        param_key = optim_state_key_to_param_key[r0_optim_state_key]
+        if isinstance(param_key, int):
+            assert 0 <= param_key < len(
+                param_key_to_param
+            ), "Check the `param_key_to_param` construction"
+    # We cannot use FSDPState.compute_device as this API is a global view.
+    device = _get_pg_default_device(group)
+    num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
+    dist.all_reduce(num_missing, group=group)
+    if num_missing.item() > 0:
+        obj_list = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(obj_list, missing_keys, group=group)
+        error_msg = (
+            "FSDP currently requires each rank to have at least the "
+            "optimizer states needed by rank 0's optimizer but some ranks "
+            "are missing some of those states"
+        )
+        for rank, keys in enumerate(obj_list):
+            keys = cast(List[_OptimStateKey], keys)
+            if len(keys) > 0:
+                error_msg += (
+                    f"\nRank {rank} is missing states for the parameters: "
+                    f"{[key.unflat_param_names for key in keys]}"
+                )
+        raise RuntimeError(error_msg)
+
+
+def _map_param_key_to_optim_keys(
+    optim_state_dict: Dict[str, Any],
+    group: Optional[dist.ProcessGroup],
+    param_key_to_param: Dict[Union[int, str], nn.Parameter],
+    param_to_fqns: Dict[nn.Parameter, List[str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
+    merge_keys: bool = False,
+) -> Tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]:
+    """
+    Construct the local mapping between the ``_OptimStateKey`` and parameter keys
+    and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0
+    must contain all the ``_OptimStateKey``, an exception will be raised otherwise.
+    Note that ``merge_keys`` should equal to ``use_orig_params``.
+    """
+    rank = dist.get_rank(group)
+    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {}  # local
+    all_optim_state_keys: List[_OptimStateKey] = []
+
+    for param_key, param in param_key_to_param.items():
+        # Do not include parameters without state to avoid empty mappings
+        # just like in normal `torch.optim.Optimizer.state_dict()`
+        if param_key not in optim_state_dict["state"]:
+            continue
+        fqns = param_to_fqns[param]
+        is_fsdp_managed = isinstance(param, FlatParameter)
+        if is_fsdp_managed:
+            assert fqns[0] in fqn_to_fsdp_param_info, (
+                fqns[0],
+                list(fqn_to_fsdp_param_info.keys()),
+            )
+        is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
+        optim_state_key = _OptimStateKey(
+            unflat_param_names=tuple(fqns),
+            is_fsdp_managed=is_fsdp_managed,
+        )
+        if rank == 0 or merge_keys:
+            all_optim_state_keys.append(optim_state_key)
+        optim_state_key_to_param_key[optim_state_key] = param_key
+
+    if merge_keys:
+        all_keys: List[List[_OptimStateKey]] = [
+            [] for _ in range(dist.get_world_size(group))
+        ]
+        dist.all_gather_object(all_keys, all_optim_state_keys, group=group)
+        merge_all_optim_state_keys = [
+            key for local_keys in all_keys for key in local_keys
+        ]
+        all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
+    else:
+        key_obj_list: List[Optional[List[_OptimStateKey]]] = (
+            [all_optim_state_keys] if rank == 0 else [None]
+        )
+        dist.broadcast_object_list(key_obj_list, src=0, group=group)
+        assert key_obj_list[0] is not None
+        all_optim_state_keys = key_obj_list[0]
+        _check_missing_keys_on_rank(
+            all_optim_state_keys,
+            optim_state_key_to_param_key,
+            param_key_to_param,
+            group,
+        )
+
+    return all_optim_state_keys, optim_state_key_to_param_key
+
+
+def _unflatten_param_groups(
+    state_dict: Dict[str, Any],
+    param_key_to_param: Dict[Union[int, str], nn.Parameter],
+    param_to_fqns: Dict[nn.Parameter, List[str]],
+) -> List[Dict[str, Any]]:
+    param_groups: List[Dict[str, Any]] = []
+    for flat_param_group in state_dict["param_groups"]:
+        unflat_param_group = copy.deepcopy(flat_param_group)
+        param_group_params = [
+            param_key_to_param[flat_param_key]
+            for flat_param_key in flat_param_group["params"]
+        ]
+        nested_unflat_param_names = [
+            param_to_fqns[param] for param in param_group_params
+        ]
+        unflat_param_group["params"] = [
+            unflat_param_name
+            for unflat_param_names in nested_unflat_param_names
+            for unflat_param_name in unflat_param_names
+        ]  # flatten the list of lists
+        param_groups.append(unflat_param_group)
+    return param_groups
+
+
+def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
+    """
+    Returns whether the state_dict is from a NamedOptimizer.
+    This function checks that the keys in the state_dict['state'] are strings
+    (which usually are FQNs) versus integers (which usually refer to param_ids
+    from a vanilla torch.optim.Optimizer).
+    """
+    state = optim_state_dict.get("state", None)
+    if not state:
+        # If we cannot find a state, assume it is not NamedOptimizer as
+        # NamedOptimizer has eager initialization.
+        return False
+    try:
+        key = next(iter(state.keys()))
+    except Exception as e:
+        raise Exception(optim_state_dict) from e
+    return isinstance(key, str)
+
+
+@dataclass
+class StateInfo:
+    # The keys of these dictionaries are state names, e.g., `exp_avg`.
+    tensors: Dict[str, _PosDimTensorInfo]
+    scalar_tensors: Dict[str, torch.Tensor]
+    non_tensors: Dict[str, Any]
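+    # For illustration (assuming Adam-like state): optimizer state
+    # {"step": tensor(3.), "exp_avg": <tensor of shape (4, 4)>} is recorded as
+    # StateInfo(tensors={"exp_avg": _PosDimTensorInfo(torch.Size([4, 4]), torch.float32)},
+    #           scalar_tensors={"step": tensor(3.)}, non_tensors={}).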
+
+
+def _allgather_state_info(
+    fsdp_state: _FSDPState,
+    input_states: Dict[str, Any],
+) -> List[Dict[str, StateInfo]]:
+    """
+    Given the ``input_states``, allgather a ``StateInfo`` for each state. The
+    function uses ``all_gather_object`` so that no GPU tensors are sent.
+    """
+
+    processed_state_dict: Dict[str, StateInfo] = {}
+    gathered_state_info: List[Dict[str, StateInfo]] = [
+        {} for _ in range(fsdp_state.world_size)
+    ]
+
+    for fqn, optim_state in input_states.items():
+        # Allgather the scalar tensor state, non-tensor states and tensors metadata.
+        processed_state = StateInfo({}, {}, {})
+        for state_name, value in sorted_items(optim_state):
+            if torch.is_tensor(value):
+                if value.dim() == 0:
+                    # Ensure that `step` is on CPU.
+                    processed_state.scalar_tensors[state_name] = value.cpu()
+                else:
+                    processed_state.tensors[state_name] = _PosDimTensorInfo(
+                        value.shape, value.dtype
+                    )
+            else:
+                processed_state.non_tensors[state_name] = value
+        processed_state_dict[fqn] = processed_state
+    dist.all_gather_object(
+        gathered_state_info,
+        processed_state_dict,
+        group=fsdp_state.process_group,
+    )
+    return gathered_state_info
+
+
+def _convert_all_state_info(
+    fsdp_param_info: FSDPParamInfo,
+    gathered_state_info: List[Dict[str, StateInfo]],
+    input_states: Dict[str, Any],
+    output_states: Dict[str, Dict[str, Any]],
+) -> Tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]:
+    """
+    Given the ``gathered_state_info`` and ``input_states``, this API converts
+    the StateInfo into the original state if the state is a scalar tensor or a
+    non-tensor value. For a multi-dimensional tensor, the local state is stored
+    in ``state_buffers`` in the correct order for the later allgather.
+    """
+
+    state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {}
+
+    for fqn, gathered_state in output_states.items():
+        state_info = [s[fqn] for s in gathered_state_info]
+        all_tensor_states = sorted(
+            {n for state in state_info for n in state.tensors.keys()}
+        )
+        empty_ranks: Set[int] = set()
+        dtype: Optional[torch.dtype] = None
+        # First check all the non-scalar states and get the information of
+        # states on each rank.
+        for state_name in all_tensor_states:
+            numels = []
+            _empty_ranks: Set[int] = set()
+            for rank, object_state in enumerate(state_info):
+                numels.append(0)
+                info = object_state.tensors.get(state_name, None)
+                if info is not None:
+                    numels[-1] = info.shape.numel()
+                    if not dtype:
+                        dtype = info.dtype
+                    else:
+                        assert dtype == info.dtype
+                if numels[-1] == 0:
+                    _empty_ranks.add(rank)
+
+            assert not empty_ranks or empty_ranks == _empty_ranks
+            empty_ranks = _empty_ranks
+            if state_name not in state_buffers:
+                state_buffers[state_name] = [
+                    None for _ in fsdp_param_info.param_indices
+                ]
+            local_state = input_states[fqn].get(state_name, None)
+            # N.B. We need to move the state to compute_device. The reason is
+            # not yet clear and we need to figure out why the state may be on a
+            # different device.
+            if local_state is not None:
+                local_state = local_state.to(fsdp_param_info.state.compute_device)
+            state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state
+
+        # Restore the scalar and non-tensor states. If the corresponding
+        # non-scalar states do not exist on the rank, we also skip the scalar
+        # and non-tensor states on that rank.
+        for rank, object_state in enumerate(state_info):
+            if rank in empty_ranks:
+                continue
+            for name, non_tensor_value in object_state.non_tensors.items():
+                curr_non_tensor_value = gathered_state.get(name, None)
+                assert (
+                    curr_non_tensor_value is None
+                    or curr_non_tensor_value == non_tensor_value
+                ), (
+                    f"Rank {rank} has different values for {name}: {non_tensor_value}."
+                    + f" Other ranks: {curr_non_tensor_value}"
+                )
+                gathered_state[name] = non_tensor_value
+
+            for name, scalar_tensor_value in object_state.scalar_tensors.items():
+                curr_scalar_tensor_value = gathered_state.get(name, None)
+                assert curr_scalar_tensor_value is None or torch.equal(
+                    scalar_tensor_value, curr_scalar_tensor_value
+                ), (
+                    f"Rank {rank} has different values for {name}: {scalar_tensor_value}."
+                    + f" Other ranks: {curr_scalar_tensor_value}"
+                )
+                gathered_state[name] = scalar_tensor_value
+
+    return dtype, state_buffers  # type: ignore[possibly-undefined]
+
+
+def _unflatten_orig_param_states(
+    fsdp_param_info: FSDPParamInfo,
+    output_states: Dict[str, Dict[str, Any]],
+    state_name: str,
+    shard_state: bool,
+    to_save: bool,
+    cpu_offload: bool,
+) -> None:
+    """
+    Given a output state dict, ``output_states``, which the keys are FQNs to the
+    original parameters (not FlatParameters nor parmeter ID), and the values
+    are gathered states, unflatten the states to the original dimensions.
+
+    This function performs the unflattening process in-place.
+    """
+    if not to_save:
+        return
+    flat_param = fsdp_param_info.handle.flat_param
+    fsdp_state = fsdp_param_info.state
+    for fqn, gathered_state in output_states.items():
+        value = gathered_state[state_name]
+        param_idx = fsdp_param_info.param_indices[fqn]
+
+        # TODO: This solution is not general and only apply to PTD TP solution.
+        if isinstance(value, DTensor):
+            placement = value.placements[0]
+            # If gathered state is a DTensor and its TP placement is not Replicate(), we need to
+            # gather the tensor on its TP dimension before chunking them into DTensor again.
+            if placement != Replicate():
+                placement_dim = placement.dim  # type: ignore[attr-defined]
+                value_local = value.redistribute(placements=(Replicate(),))
+                reshape_size = list(flat_param._shapes[param_idx])
+                reshape_size[placement_dim] *= value.device_mesh.size(0)
+                reshape_size = torch.Size(reshape_size)
+                value = value.reshape(reshape_size)
+            # If gathered state is a replicate DTensor, we directly reshape it.
+            else:
+                value = value.reshape(flat_param._shapes[param_idx])
+        else:
+            # If gathered state is a tensor, we directly reshape it into unflatten state.
+            value = value.reshape(flat_param._shapes[param_idx])
+
+        if shard_state:
+            osd_config = fsdp_state._optim_state_dict_config
+            if getattr(osd_config, "_use_dtensor", False):
+                assert fsdp_state._device_mesh is not None
+                value = _ext_chunk_dtensor(
+                    value,
+                    fsdp_state.rank,
+                    fsdp_state._device_mesh,
+                    fsdp_state._fsdp_extension,
+                )
+            else:
+                assert fsdp_state.process_group is not None
+                value = _ext_chunk_tensor(
+                    value,
+                    fsdp_state.rank,
+                    fsdp_state.world_size,
+                    fsdp_state._device_handle.device_count(),
+                    fsdp_state.process_group,
+                    fsdp_state._fsdp_extension,
+                )
+        elif not cpu_offload:
+            with SimpleProfiler.profile("clone"):
+                value = value.detach().clone()
+
+        if cpu_offload:
+            with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
+                value = value.cpu()
+        gathered_state[state_name] = value
+
+
+def _allgather_orig_param_states(
+    fsdp_param_info: FSDPParamInfo,
+    gathered_state_info: List[Dict[str, StateInfo]],
+    input_states: Dict[str, Any],
+    shard_state: bool,
+    to_save: bool,
+    cpu_offload: bool,
+) -> Dict[str, Dict[str, Any]]:
+    """
+    Given the ``gathered_state_info`` and ``input_states``, this API allgathers
+    all tensor states and restores non-tensor states from ``gathered_state_info``.
+    """
+    fsdp_state = fsdp_param_info.state
+    if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL:
+        logger.warning(
+            "CUDA Memory Summary before calling to _allgather_orig_param_states %s",
+            torch.cuda.memory_summary(),
+        )
+
+    output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}
+
+    dtype, state_buffers = _convert_all_state_info(
+        fsdp_param_info, gathered_state_info, input_states, output_states
+    )
+
+    if len(state_buffers) == 0:
+        return output_states
+
+    has_state_params: List[bool] = [
+        fqn in output_states
+        for fqn, idx in fsdp_param_info.param_indices.items()
+    ]
+
+    # Loop through the ``state_buffers`` and construct the flattened, concatenated,
+    # sharded states. The size of the constructed state will be the same as that of
+    # flat_param (also sharded).
+    # Then we perform an allgather_into_tensor to get the full flat_param state.
+    # The full flat_param state is the result of concatenating multiple states in
+    # the order of flat_param._fqns.
+    # The final step is to split the flat_param state into the original param states
+    # and return the result.
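+    # Small numeric sketch (assumed values): with _numels_with_padding = (5, 1, 6),
+    # _is_padding_mask = (False, True, False), and world_size = 2 (a 6-element shard
+    # per rank), rank 0's shard covers offsets [0, 5], i.e. parameter 0 plus the
+    # 1-element padding, so it appends its 5-element local state followed by a
+    # 1-element empty buffer; rank 1's shard covers offsets [6, 11], exactly
+    # parameter 1.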
+    flat_param = fsdp_param_info.handle.flat_param
+    empty_func = functools.partial(
+        torch.empty, dtype=dtype, device=fsdp_state.compute_device
+    )
+    gathered_tensor = empty_func(flat_param._padded_unsharded_size)
+    # Synchronizing can be slow, but it makes debugging easier.
+    torch.cuda.synchronize()
+    for state_name, buffers in state_buffers.items():
+        local_buffers: List[torch.Tensor] = []
+        begin = fsdp_state.rank * flat_param._sharded_size.numel()
+        # End is inclusive.
+        end = begin + flat_param._sharded_size.numel() - 1
+        # param_idx corresponds to the parameter index in the FlatParameter.
+        mem_offset, param_idx = 0, 0
+        for numel, is_padding in zip(
+            flat_param._numels_with_padding, flat_param._is_padding_mask
+        ):
+            frozen_and_no_state = not is_padding and (
+                not fsdp_param_info.param_requires_grad[param_idx]
+                and not has_state_params[param_idx]
+            )
+
+            if is_padding or frozen_and_no_state:
+                # This memory range is padding or the param is frozen and does
+                # not require gradient. For the latter case, we treat it as
+                # padding and add empty values to the local_buffers.
+
+                padding_begin, padding_end = mem_offset, mem_offset + numel - 1
+                if padding_begin <= begin <= padding_end:
+                    # The range is an align padding before the first parameter in
+                    # the shard. The shard includes parts of this align padding.
+                    padding_len = (
+                        padding_end - begin + 1
+                        if end >= padding_end
+                        else end - begin + 1
+                    )
+                elif padding_begin <= end <= padding_end:
+                    # The range is an align padding after the last parameter in
+                    # the shard. The shard includes parts of this align padding.
+                    padding_len = (
+                        end - padding_begin + 1
+                        if begin <= padding_begin
+                        else end - begin + 1
+                    )
+                elif begin < padding_begin <= padding_end < end:
+                    # The range is an align padding that is completely in the
+                    # shard.
+                    padding_len = numel
+                else:
+                    padding_len = 0
+                if padding_len:
+                    local_buffers.append(empty_func(padding_len))
+
+            if not is_padding:
+                # This memory range is a parameter in FlatParameter. So there
+                # should be a corresponding state in the optimizer unless the
+                # parameter is frozen, which we treat as padding above.
+
+                # We need to check if this rank owns the buffer. If this is None:
+                # 1.) the rank does not own any part of the original parameter.
+                #     As a result, there is no corresponding optimizer state on
+                #     the rank as well.
+                # 2.) the parameter is frozen AND there is no optimizer state
+                #     for the parameter. A frozen parameter can still have
+                #     optimizer state if it was not frozen in previous steps.
+                if buffers[param_idx] is not None:
+                    local_buffers.append(cast(torch.Tensor, buffers[param_idx]))
+                param_idx += 1
+
+            mem_offset += numel
+
+        shard_numel_padded = flat_param._sharded_size.numel() - (
+            sum(t.numel() for t in local_buffers)
+        )
+
+        assert flat_param._shard_numel_padded == shard_numel_padded, (
+            "Manually calculated _sharded_numel_padded is incorrect. "
+            f"_shard_numel_padded={flat_param._shard_numel_padded}, "
+            f"shard_numel_padded={shard_numel_padded}, "
+            f"_sharded_size.numel={flat_param._sharded_size.numel()}, "
+            f"_numels_with_padding={flat_param._numels_with_padding}, "
+            f"begin={begin}, end={end},"
+        )
+        if shard_numel_padded > 0:
+            # Add right-handed padding.
+            local_buffers.append(empty_func(shard_numel_padded))
+        local_shard = torch.cat(local_buffers)
+        assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), (
+            "The size of local shard times the world size should equal to the "
+            "gathered tensor size. The inconsistency may be from a bug of "
+            "FlatParameter's metadata or the reconstruction logic in optimizer "
+            "state dict."
+        )
+        torch.cuda.synchronize()
+        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
+            dist.all_gather_into_tensor(
+                gathered_tensor, local_shard, group=fsdp_state.process_group
+            )
+            # Synchronizing can be slow, but it makes debugging easier.
+            torch.cuda.synchronize()
+
+        unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()]
+        flat_param_handle = fsdp_param_info.handle
+        orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor)
+        assert len(orig_states) == len(fsdp_param_info.param_indices), (
+            "The number of parameters from FlatParameter is not consistent to "
+            "the number of states used by optimizer state dict reconstruction "
+            "logic."
+        )
+        for fqn, idx in fsdp_param_info.param_indices.items():
+            if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
+                output_states[fqn][state_name] = orig_states[idx]
+
+        _unflatten_orig_param_states(
+            fsdp_param_info,
+            output_states,
+            state_name,
+            shard_state,
+            to_save,
+            cpu_offload,
+        )
+
+    del gathered_tensor
+    return output_states
+
+
+def _gather_all_orig_param_state(
+    fsdp_param_info: FSDPParamInfo,
+    input_states: Dict[str, Any],
+    shard_state: bool,
+    to_save: bool,
+    cpu_offload: bool,
+) -> Dict[str, Any]:
+    """
+    Given an optimizer state dict, ``input_states``, whose keys are FQNs of the
+    original parameters (not FlatParameters nor parameter IDs), gather all the
+    states and unflatten them to the original dimensions. Note that all the
+    params referred to by ``input_states`` must be managed by FSDP.
+    """
+    fsdp_state = fsdp_param_info.state
+    if (
+        fsdp_state.world_size == 1
+        or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+    ):
+        return input_states if to_save else {}
+
+    with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
+        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ):
+            gathered_state_info = _allgather_state_info(fsdp_state, input_states)
+        output_states = _allgather_orig_param_states(
+            fsdp_param_info,
+            gathered_state_info,
+            input_states,
+            shard_state,
+            to_save,
+            cpu_offload,
+        )
+    if to_save:
+        for key, idx in fsdp_param_info.param_indices.items():
+            if key in output_states:
+                continue
+            if not fsdp_param_info.param_requires_grad[idx]:
+                continue
+
+            raise RuntimeError(
+                f"{key} is not in the output state. "
+                "The FSDPParamInfo has the param keys "
+                f"{sorted(fsdp_param_info.param_indices.keys())} while "
+                "the output_states has the param keys "
+                f"{sorted(output_states.keys())}."
+            )
+        return output_states
+    else:
+        return {}
+
+
+def _convert_state_with_orig_params(
+    all_optim_state_keys: List[_OptimStateKey],
+    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
+    optim_state_dict: Dict[Union[str, int], Any],
+    to_save: bool,
+    shard_state: bool,
+    cpu_offload: bool = True,
+) -> Dict[str, Any]:
+    fsdp_osd_state: Dict[str, Any] = {}
+    # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo
+    # usually corresponds to multiple parameters. We cannot use FSDPParamInfo
+    # as the key because FSDPParamInfo is not hashable. As a result, we fall back
+    # to `id(FSDPParamInfo)`, which is an integer.
+    all_states: Dict[int, Dict[str, Any]] = {}
+    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
+    # across ranks
+    for optim_state_key in all_optim_state_keys:
+        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
+            optim_state_key, None
+        )
+
+        if param_key is None and not optim_state_key.is_fsdp_managed:
+            continue
+
+        if optim_state_key.is_fsdp_managed:
+            fqn = optim_state_key.unflat_param_names[0]
+            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
+            if fsdp_param_info is None:
+                # This can happen if not all FSDP instances have all the
+                # parameters, e.g., with FSDP combined with some MPMD-style
+                # parallelism.
+
+                # TODO: it is unclear if we need to do the same check with
+                # non-FSDP managed keys.
+                continue
+            state = {} if param_key is None else optim_state_dict[param_key]
+            if id(fsdp_param_info) not in all_states:
+                all_states[id(fsdp_param_info)] = {}
+            all_states[id(fsdp_param_info)][fqn] = state
+
+        elif to_save:
+            assert len(optim_state_key.unflat_param_names) == 1
+            unflat_param_name = optim_state_key.unflat_param_names[0]
+            with SimpleProfiler.profile("none_fsdp_managed_copy"):
+                param_key = cast(Union[str, int], param_key)
+                fsdp_osd_state[unflat_param_name] = copy.copy(
+                    optim_state_dict[param_key]
+                )
+                if cpu_offload:
+                    for state_name, value in sorted_items(
+                        fsdp_osd_state[unflat_param_name]
+                    ):
+                        if not torch.is_tensor(value):
+                            continue
+                        fsdp_osd_state[unflat_param_name][state_name] = value.cpu()
+
+    # Instead of gathering the state of each parameter individually, we perform
+    # the gathering all at once to speed up the process.
+    for _all_states in all_states.values():
+        fqn = next(iter(_all_states.keys()))
+        fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+        assert len(fsdp_param_info.param_requires_grad) > 0, (
+            "With use_orig_params, FSDPParamInfo should have requires_grad "
+            "information. However, the length is zero."
+        )
+        for key, idx in fsdp_param_info.param_indices.items():
+            if key in _all_states:
+                continue
+            if not fsdp_param_info.param_requires_grad[idx]:
+                continue
+            raise RuntimeError(
+                f"{key} is not in the optimizer state. "
+                "The FSDPParamInfo has the param keys "
+                f"{sorted(fsdp_param_info.param_indices.keys())} while "
+                "the optimizer has the param keys "
+                f"{sorted(_all_states.keys())}."
+            )
+        fsdp_osd_state.update(
+            _gather_all_orig_param_state(
+                fsdp_param_info,
+                _all_states,
+                shard_state,
+                to_save,
+                cpu_offload,
+            )
+        )
+
+    return fsdp_osd_state
+
+
+def _convert_state_with_flat_params(
+    all_optim_state_keys: List[_OptimStateKey],
+    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
+    optim_state_dict: Dict[Union[str, int], Any],
+    to_save: bool,
+    shard_state: bool,
+    cpu_offload: bool = True,
+) -> Dict[str, Any]:
+    fsdp_osd_state: Dict[str, Any] = {}
+    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
+    # across ranks
+    for optim_state_key in all_optim_state_keys:
+        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
+            optim_state_key, None
+        )
+
+        assert param_key is not None, (
+            "If use_orig_params is False, we must be able to find the "
+            f"corresponding param id. {optim_state_key} {param_key}"
+        )
+
+        if optim_state_key.is_fsdp_managed:
+            # If there are multiple unflat_param_names (not use_orig_params),
+            # they share the same FSDPParamInfo. So the first unflat_param_name
+            # is sufficient to fetch the FSDPParamInfo.
+            fqn = optim_state_key.unflat_param_names[0]
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+            unflat_state = _unflatten_optim_state(
+                fsdp_param_info,
+                optim_state_dict[param_key],
+                to_save,
+                shard_state,
+                cpu_offload,
+            )
+            if to_save:
+                assert len(unflat_state) == len(optim_state_key.unflat_param_names)
+                for unflat_param_name, unflat_param_state in zip(
+                    optim_state_key.unflat_param_names,
+                    unflat_state,
+                ):
+                    fsdp_osd_state[unflat_param_name] = unflat_param_state
+        elif to_save:
+            assert len(optim_state_key.unflat_param_names) == 1
+            unflat_param_name = optim_state_key.unflat_param_names[0]
+            fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
+            if cpu_offload:
+                for state_name, value in sorted_items(
+                    fsdp_osd_state[unflat_param_name]
+                ):
+                    if not torch.is_tensor(value):
+                        continue
+                    fsdp_osd_state[unflat_param_name][state_name] = value.cpu()
+
+    return fsdp_osd_state
+
+
+@torch.no_grad()
+def _optim_state_dict(
+    model: nn.Module,
+    optim: torch.optim.Optimizer,
+    optim_state_dict: Dict[str, Any],
+    optim_input: Optional[
+        Union[
+            List[Dict[str, Any]],
+            Iterable[nn.Parameter],
+        ]
+    ],
+    rank0_only: bool,
+    shard_state: bool,
+    group: Optional[dist.ProcessGroup],
+    using_optim_input: bool,
+    use_orig_params: bool = False,
+    cpu_offload: bool = True,
+) -> Dict[str, Any]:
+    """
+    Consolidates the optimizer state and returns it as a :class:`dict`
+    following the convention of :meth:`torch.optim.Optimizer.state_dict`,
+    i.e. with keys ``"state"`` and ``"param_groups"``.
+    The flat parameters in ``FSDP`` modules contained in ``model`` are mapped
+    back to their unflattened parameters.
+
+    Parameter keys are not well-defined. For a regular optimizer, the optimizer
+    state_dict contains a mapping from parameter IDs to parameter states.
+    Parameter IDs are determined by the order of parameters in ``optim.param_groups()``
+    across all the groups. This API also allows the user to pass ``optim_input`` for
+    the mapping between parameters and parameter IDs. Using ``optim_input`` is being
+    deprecated.
+
+    If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not
+    contain parameter IDs mapping but a mapping from parameter FQNs to parameter
+    states. This API finds the mapping from FQNs to parameters if the optimizer
+    is a ``NamedOptimizer``.
+
+    If ``use_orig_params`` is True, each rank will have all FSDP-managed
+    parameters but some of these parameters may be empty due to the sharding.
+    For a regular optim.Optimizer, states for those empty parameters will
+    not be initialized. So, when aggregating the FQNs across ranks, no assert
+    will be raised on a rank even if it does not have all the states -- it is
+    valid and FSDP knows how to aggregate them. However, FSDP must ignore
+    those parameters that are not managed by FSDP and do not exist on
+    the local rank -- those are managed by other parallelisms and FSDP does not
+    know how to handle/aggregate them.
+
+    Args:
+        model (nn.Module): Root module (which may or may not be a
+            :class:`FullyShardedDataParallel` instance) whose parameters
+            were passed into the optimizer ``optim``.
+        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+            parameters.
+        rank0_only (bool): If ``True``, saves the populated :class:`dict`
+            only on rank 0; if ``False``, saves it on all ranks. (Default:
+            ``True``)
+        shard_state (bool): If ``True``, shard and distribute all
+            non-zero-dimension states.
+
+    Returns:
+        Dict[str, Any]: A :class:`dict` containing the optimizer state for
+        ``model`` 's original unflattened parameters and including keys
+        "state" and "param_groups" following the convention of
+        :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
+        then nonzero ranks return an empty :class:`dict`.
+    """
+    SimpleProfiler.reset()
+    cm = ExitStack()
+    cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL))
+    _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model))
+    to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state
+
+    with SimpleProfiler.profile("preprocessing"):
+        param_to_fqns = _get_param_to_fqns(model)
+        flat_param_to_fqn = _get_flat_param_to_fqn(model)
+        is_named_optimizer = _is_named_optimizer(optim_state_dict)
+
+        param_key_to_param = cast(
+            Dict[Union[int, str], nn.Parameter],
+            (
+                _get_param_id_to_param_from_optim_input(model, optim_input)
+                if using_optim_input
+                else _get_param_key_to_param(
+                    optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
+                )
+            ),
+        )
+        fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
+
+    with SimpleProfiler.profile("preprocessing_with_comm"):
+        (
+            all_optim_state_keys,
+            optim_state_key_to_param_key,
+        ) = _map_param_key_to_optim_keys(
+            optim_state_dict,
+            group,
+            param_key_to_param,
+            param_to_fqns,
+            fqn_to_fsdp_param_info,
+            merge_keys=use_orig_params,
+        )
+
+    with SimpleProfiler.profile("state_converting"):
+        convert_fn = (
+            _convert_state_with_orig_params
+            if use_orig_params
+            else _convert_state_with_flat_params
+        )
+        fsdp_osd_state = convert_fn(
+            all_optim_state_keys,
+            optim_state_key_to_param_key,
+            fqn_to_fsdp_param_info,
+            optim_state_dict["state"],
+            to_save,
+            shard_state,
+            cpu_offload,
+        )
+
+    # At this point, communication is complete and ranks can return early if nothing
+    # will be saved on that rank.
+    if not to_save:
+        return {}
+
+    fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state}
+
+    flat_param_fqns = set(flat_param_to_fqn.values())
+    for key, value in optim_state_dict["state"].items():
+        if key in fsdp_osd_state:
+            continue
+        if key in flat_param_fqns:
+            continue
+        if key in param_key_to_param:
+            continue
+        # This key is not recognized by FSDP. It may be a user-defined state
+        # or some parameters state that FSDP is unable to map from
+        # ``optim.param_groups``.
+        warnings.warn(
+            f"Found a optim state, {key}, that FSDP cannot process. FSDP "
+            "will directly copy everything to the returned state_dict. In "
+            "most cases, this is a user-defined state that is not "
+            "associated with any particular parameter. Another possible "
+            "case is this state is managed by TorchRec. Otherwise, there may "
+            " be a mismatched assumption of optim_state_dict of this mode."
+        )
+        fsdp_osd_state[key] = value
+
+    if "param_groups" in optim_state_dict:
+        fsdp_osd["param_groups"] = _unflatten_param_groups(
+            optim_state_dict, param_key_to_param, param_to_fqns
+        )
+
+    cm.close()
+    SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ")
+
+    return fsdp_osd
+
+
+def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
+    """
+    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
+    if the param is managed by FSDP. Shared parameters, or original parameters that
+    are shared across multiple nn.Modules, are required to belong to one and only
+    one FSDP instance and thus correspond to one ``FlatParameter``. Within the one
+    ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared
+    parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters.
+    """
+
+    def module_fn(module, prefix, tree_level, fqn_to_param_info):
+        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+        if fsdp_state is None:
+            return
+        _lazy_init(fsdp_state, module)
+        handle = _module_handle(fsdp_state, module)
+        if not handle:
+            return
+        flat_param = handle.flat_param
+        fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, [])
+        # NOTE: `idx` indexes into the data structures *without* padding
+        # elements
+        for idx, local_fqn in enumerate(flat_param._fqns):
+            fqn = clean_tensor_name(prefix + local_fqn)
+            if fqn in fqn_to_param_info:
+                assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
+            fqn_to_param_info[fqn] = fsdp_param_info
+            fsdp_param_info.param_indices[fqn] = idx
+            if flat_param._params is not None:
+                fsdp_param_info.param_requires_grad.append(
+                    flat_param._params[idx].requires_grad
+                )
+
+    def return_fn(fqn_to_param_info):
+        return fqn_to_param_info
+
+    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
+    # FlatParameter._fqns stores the local fqn, starting from the root of the
+    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
+    # module) allows us to construct the global fqn.
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
+        fqn_to_param_info,
+    )
+
+
+@no_type_check
+def _set_optim_use_dtensor(
+    fsdp_state: _FSDPState,
+    state_dict_settings: StateDictSettings,
+) -> None:
+    # If device_mesh is passed in when initializing FSDP, we automatically set the
+    # _use_dtensor flag to True for ShardedOptimStateDictConfig() if state_dict_type
+    # is set to SHARDED_STATE_DICT.
+    if getattr(fsdp_state, "_device_mesh", None):
+        state_dict_type = state_dict_settings.state_dict_type
+        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
+            raise RuntimeError(
+                "Found state_dict_type LOCAL_STATE_DICT.",
+                "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
+                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
+            )
+        else:
+            state_dict_settings.optim_state_dict_config._use_dtensor = True
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a570627e460bdb9ae9d5dab2dddf60c63da9e4f8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py
@@ -0,0 +1,1630 @@
+import functools
+import logging
+from enum import auto, Enum
+from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torch.autograd.graph import register_multi_grad_hook
+from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
+from torch.distributed.fsdp._common_utils import (
+    _assert_in_training_states,
+    _FSDPState,
+    _get_module_fsdp_state,
+    _is_composable,
+    _log_post_backward_hook,
+    _no_dispatch_record_stream,
+    clean_tensor_name,
+    TrainingState,
+)
+from torch.distributed.fsdp._flat_param import (
+    FlatParameter,
+    FlatParamHandle,
+    HandleShardingStrategy,
+    HandleTrainingState,
+    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
+)
+from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
+from torch.distributed.fsdp.api import BackwardPrefetch
+from torch.distributed.utils import (
+    _apply_to_tensors,
+    _cast_forward_inputs,
+    _p_assert,
+    _to_kwargs,
+)
+from torch.utils import _pytree as pytree
+
+log = logging.getLogger(__name__)
+
+# Do not include "process_group" to enable hybrid shard and MoE cases
+HOMOGENEOUS_ATTR_NAMES = (
+    "_use_orig_params",
+    "limit_all_gathers",
+    "_use_full_prec_in_eval",
+)
+
+
+class _PrefetchMode(Enum):
+    BACKWARD = auto()
+    FORWARD = auto()
+
+
+def _get_fsdp_root_states_with_modules(
+    module: nn.Module,
+) -> Tuple[List[_FSDPState], List[nn.Module]]:
+    """
+    Returns a tuple containing:
+    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
+    ``module`` without any duplicates and following the ``module.modules()``
+    traversal order (which is assumed to be depth-first).
+    2. A corresponding list of the root modules owning the states in the first
+    list.
+
+    This is similar to :func:`_get_fsdp_states_with_modules` except that we
+    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
+    the FSDP root in case lazy initialization has not yet happened.
+    """
+    fsdp_root_states: List[_FSDPState] = []
+    fsdp_root_modules: List[nn.Module] = []
+    visited_fsdp_states: Set[_FSDPState] = set()
+    # NOTE: This function assumes that `module.modules()` proceeds top-down.
+    for submodule in module.modules():
+        optional_state = _get_module_fsdp_state(submodule)
+        if (
+            optional_state is not None
+            and optional_state not in visited_fsdp_states
+            and _is_fsdp_root(optional_state, submodule)
+        ):
+            visited_fsdp_states.add(optional_state)
+            fsdp_root_states.append(optional_state)
+            fsdp_root_modules.append(submodule)
+    return fsdp_root_states, fsdp_root_modules
+
+
+def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
+    """See :func:`_get_fsdp_root_states_with_modules`."""
+    fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
+    return fsdp_root_states
+
+
+def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
+    """
+    Returns if ``state`` corresponds to that of an FSDP root.
+
+    For the wrapper code path, ``state`` and ``module`` should be the same. For
+    the non-wrapper code path, ``state`` should be ``module`` 's state.
+    """
+    # Force a lazy initialization to determine the FSDP root
+    _lazy_init(state, module)
+    assert state._is_root is not None  # mypy
+    return state._is_root
+
+
+@no_type_check
+def _lazy_init(
+    state: _FSDPState,
+    root_module: nn.Module,
+) -> _FSDPState:
+    """
+    Performs initialization lazily, typically right before the first forward
+    pass. The laziness is needed to ensure that the parameter device/dtype and
+    the FSDP hierarchy have been finalized. This method's actual logic only runs on
+    the root FSDP instance, which performs initialization for all non-root FSDP
+    instances to avoid partial initialization.
+
+    For the non-composable code path, ``state`` and ``root_module`` should be
+    the same, namely the FSDP instance itself.
+    """
+    if state._is_root is not None:
+        return  # no-op: already lazily initialized
+    if not state._device_handle.is_available():
+        # Allow the FSDP constructor to run even without CUDA but check this
+        # once we start real execution
+        raise RuntimeError("FSDP does not support CPU only execution")
+    # The following logic is only run on the root FSDP instance since it will
+    # set `_is_root=False` for the non-root instances
+    state._is_root = True
+    _assert_in_training_states(state, [TrainingState.IDLE])
+    _check_flat_params_on_expected_device(state, root_module)
+    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
+    _init_streams(state)
+    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
+    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
+    state._exec_order_data.init(state, root_module, state.process_group)
+    _share_state_and_init_handle_attrs(state, root_module)
+    return state
+
+
+def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
+    """
+    Checks that all ``FlatParameter``s in ``module`` 's tree managed by
+    ``state`` are on the expected device for *lazy initialization*.
+    """
+    cpu_device = torch.device("cpu")
+    for handle in traversal_utils._get_fsdp_handles(module):
+        if (
+            not handle._offload_params
+            and handle.flat_param.device != state.compute_device
+        ):
+            raise RuntimeError(
+                "An FSDP-managed module unexpectedly has parameters on "
+                f"{handle.flat_param.device}. Make sure to move the module to "
+                f"{state.compute_device} before training."
+            )
+        elif handle._offload_params and handle.flat_param.device != cpu_device:
+            raise RuntimeError(
+                "An FSDP-managed module with parameter CPU offloading enabled "
+                f"has parameters on {handle.flat_param.device}. Make sure to "
+                f"not move the module from CPU when offloading parameters."
+            )
+
+
+@no_type_check
+def _share_state_and_init_handle_attrs(
+    root_state: _FSDPState,
+    root_module: nn.Module,
+) -> None:
+    """
+    Shares data structure state from the ``root_state`` to all FSDP states in
+    ``root_module`` 's module tree, and initializes handle attributes. These
+    are done together to require a single loop over the states.
+    """
+    handle = root_state._handle
+    if handle:
+        handle.init_flat_param_attributes()
+    attr_name_to_values: Dict[str, Set[Any]] = {}
+    for attr_name in HOMOGENEOUS_ATTR_NAMES:
+        attr_name_to_values[attr_name] = set()
+    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
+    # Update _has_optim_in_backward for each handle.
+    for handle in root_state._all_handles:
+        flat_param = handle.flat_param
+        if hasattr(flat_param, "_in_backward_optimizers"):
+            raise RuntimeError(
+                "FSDP optimizer in backward only supported with use_orig_params=True!"
+            )
+        handle._has_optim_in_backward = flat_param._params is not None and any(
+            hasattr(param, "_in_backward_optimizers") for param in flat_param._params
+        )
+        if handle._has_optim_in_backward:
+            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
+    for fsdp_state in root_state._all_fsdp_states:
+        for attr_name in HOMOGENEOUS_ATTR_NAMES:
+            _p_assert(
+                hasattr(fsdp_state, attr_name),
+                f"FSDP state missing attribute {attr_name}",
+            )
+            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
+        if fsdp_state is root_state:
+            continue
+        # Relax the assert for non-root FSDP instances in case the nested
+        # initialized module is wrapped again in FSDP later (e.g. after
+        # training to run inference)
+        _p_assert(
+            fsdp_state._is_root is None or not fsdp_state._is_root,
+            "Non-root FSDP instance's `_is_root` should not have been "
+            "set yet or should have been set to `False`",
+        )
+        fsdp_state._is_root = False
+        fsdp_state._unshard_stream = root_state._unshard_stream
+        fsdp_state._post_backward_stream = root_state._post_backward_stream
+        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
+        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
+        fsdp_state._default_stream = root_state._default_stream
+        fsdp_state._exec_order_data = root_state._exec_order_data
+        fsdp_state._free_event_queue = root_state._free_event_queue
+        if fsdp_state._fsdp_extension is not None:
+            fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
+        handle = fsdp_state._handle
+        if handle:
+            handle.init_flat_param_attributes()
+    for attr_name, attr_values in attr_name_to_values.items():
+        if len(attr_values) != 1:
+            raise ValueError(
+                f"Expects one homogeneous value for {attr_name} but got {attr_values}"
+            )
+
+
+@no_type_check
+def _init_streams(
+    state: _FSDPState,
+) -> None:
+    """
+    Initializes CUDA streams for overlapping communication, computation, and
+    data transfers. The streams should be shared across FSDP instances.
+    """
+    assert state._is_root
+    assert state._device_handle.is_available()
+    uses_hybrid_sharding = any(
+        fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
+        for fsdp_state in state._all_fsdp_states
+    )
+    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
+    # preserve the default priority of 0 otherwise
+    high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0
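+    # Editorial note: as a rough illustration (assuming a CUDA backend),
+    # `high_priority == -1` corresponds to something like
+    # `torch.cuda.Stream(priority=-1)`, where more negative values mean higher
+    # priority and 0 is the default priority.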
+    # Default stream for computation
+    state._default_stream = state._device_handle.current_stream()
+    if state._fsdp_extension is not None:
+        # set the compute stream to the FSDP extension
+        state._fsdp_extension.compute_stream = state._default_stream
+
+    # Stream for unshard logic, including allocating the all-gather destination
+    # tensors and the all-gathers themselves
+    state._unshard_stream = state._device_handle.Stream(priority=high_priority)
+    # Stream for overlapping gradient reduction with the backward pass gradient
+    # computation
+    state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
+    # Stream for pre-unshard logic, namely allocations and writes for CPU
+    # offloading (H2D copy) and mixed precision (low precision cast)
+    state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority)
+    # Stream to run HSDP's all-reduce as async (if using HSDP)
+    state._all_reduce_stream = (
+        state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream
+    )
+
+
+@no_type_check
+def _unshard(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    unshard_stream: torch.Stream,
+    pre_unshard_stream: torch.Stream,
+) -> None:
+    """
+    Unshards ``handle`` 's ``FlatParameter``. If the handle is in
+    :meth:`summon_full_params` and is using mixed precision, then it is
+    forced to full precision.
+
+    Postcondition: handle's ``FlatParameter`` 's data is the padded
+    unsharded flat parameter on the compute device.
+    """
+    if not handle:
+        return
+    with state._device_handle.stream(pre_unshard_stream):
+        ran_pre_unshard = handle.pre_unshard()
+    if ran_pre_unshard:
+        unshard_stream.wait_stream(pre_unshard_stream)
+    if state.limit_all_gathers:
+        event = state._free_event_queue.dequeue_if_needed()
+        if event:
+            with torch.profiler.record_function(
+                "FullyShardedDataParallel.rate_limiter"
+            ):
+                event.synchronize()
+    with state._device_handle.stream(unshard_stream):
+        handle.unshard()
+        handle.post_unshard()
+
+
+@no_type_check
+def _reshard(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    free_unsharded_flat_param: bool,
+):
+    """
+    Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
+    free the handle's padded unsharded flat parameter.
+    """
+    handle.reshard(free_unsharded_flat_param)
+    if state.limit_all_gathers and free_unsharded_flat_param:
+        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+            # We do not run an event queue for freeing under torch.compile at
+            # the moment. But maybe we need to? TODO(voz): Look into this
+            free_event = state._device_handle.Event()
+            free_event.record()
+            state._free_event_queue.enqueue(free_event)
+    handle.post_reshard()
+    # Flat parameter freed or not, we always have to "unshard" the parameter
+    # upon next access to get its shape correct.
+    handle._prefetched = False
+
+
+def _unshard_grads(
+    handle: Optional[FlatParamHandle],
+) -> None:
+    if handle:
+        handle.unshard_grad()
+
+
+def _reshard_grads(
+    handle: Optional[FlatParamHandle],
+) -> None:
+    if handle:
+        handle.reshard_grad()
+
+
+@no_type_check
+def _pre_forward(
+    state: _FSDPState,
+    handle: Optional[FlatParamHandle],
+    unshard_fn: Callable,
+    module: nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """
+    Runs the pre-forward logic. This includes an opportunity to unshard
+    currently sharded parameters such as those for the current forward and
+    registering post-backward hooks for these current parameters. This function
+    also converts forward ``args`` and ``kwargs`` to the given precision.
+
+    Args:
+        handle (Optional[FlatParamHandle]): Handle giving the parameters used
+            in the current forward, if any.
+        unshard_fn (Optional[Callable]): A callable to unshard any currently
+            sharded parameters or ``None`` to not do any unsharding.
+        module (nn.Module): Module whose forward this method runs right before;
+            expected by the hook signature.
+        args (Tuple[Any, ...]): Module forward ``args``.
+        kwargs (Dict[str, Any]): Module forward ``kwargs``.
+    """
+    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
+        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
+        # recomputed forward
+        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
+            # For both checkpoint implementations, we do not need to re-cast
+            # inputs here since they will be checkpointed in the low precision
+            # either by AC or normally by autograd as long as the AC region is
+            # nested within FSDP
+            return args, kwargs
+        state.training_state = TrainingState.FORWARD_BACKWARD
+        state._exec_order_data.record_pre_forward(handle, module.training)
+        if handle:
+            handle._training_state = HandleTrainingState.FORWARD
+        if unshard_fn is not None:
+            unshard_fn(state, handle)
+        # Register post-backward hooks to reshard the parameters and reduce-scatter
+        # their gradients. They must be re-registered every forward pass in case
+        # the `grad_fn` is mutated.
+        _register_post_backward_hook(state, handle)
+        # We have to reallocate the `_cpu_grad` if the optimizer overlap
+        # set the gradient to `None` in the backward pass.
+        if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
+            handle.flat_param._cpu_grad = torch.zeros_like(
+                handle.flat_param._local_shard, device=torch.device("cpu")
+            ).pin_memory()
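+            # Note (editorial): allocating the CPU gradient in pinned
+            # (page-locked) memory is what lets the later device-to-CPU copy in
+            # `_offload_grad` with `non_blocking=True` actually overlap with
+            # other work.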
+
+        should_cast_forward_inputs = (
+            state._handle and not state._handle._force_full_precision
+        )
+
+        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
+            # Recursively convert args and kwargs to specified precision.
+            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
+        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
+        return args, kwargs
+
+
+@no_type_check
+def _pre_forward_unshard(
+    state: _FSDPState,
+    handle: Optional[FlatParamHandle],
+) -> None:
+    """Unshards parameters in the pre-forward."""
+    if not handle:
+        return
+    # If the handles have been prefetched, then there is no need to call
+    # `_unshard()` again
+    if not handle._prefetched:
+        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
+    handle._needs_pre_forward_unshard = False
+    # Don't wait during trace
+    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        state._device_handle.current_stream().wait_stream(state._unshard_stream)
+    with torch.profiler.record_function(
+        "FullyShardedDataParallel._pre_forward_prefetch"
+    ):
+        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)
+
+
+@no_type_check
+def _post_forward(
+    state: _FSDPState,
+    handle: Optional[FlatParamHandle],
+    reshard_fn: Callable,
+    module: nn.Module,
+    input: Any,
+    output: Any,
+) -> Any:
+    """
+    Runs the post-forward logic. This includes an opportunity to reshard
+    currently unsharded parameters such as those used in the current forward
+    and registering pre-backward hooks on the forward outputs.
+
+    Args:
+        handle (Optional[FlatParamHandle]): Handle giving the parameters used
+            in the current forward, if any.
+        reshard_fn (Optional[Callable]): A callable to reshard any currently
+            unsharded parameters (e.g. from the current forward) or ``None`` to
+            not do any resharding.
+        module (nn.Module): Module whose forward just ran, which should be a
+            fully sharded module (see [Note: Fully Sharded Module]); expected
+            by the hook signature.
+        input (Any): Unused; expected by the hook signature.
+        output (Any): Forward pass output; pre-backward hooks are registered on
+            the tensors that require gradients in this output.
+
+    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
+    parameter.
+    """
+    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
+        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
+        # recomputed forward
+        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
+            return output
+
+        state._exec_order_data.record_post_forward(handle)
+        if reshard_fn is not None:
+            reshard_fn(state, handle)
+        # Register pre-backward hooks to unshard the flat parameters for the
+        # gradient computation (if needed)
+        output = _register_pre_backward_hooks(state, module, output, handle)
+        state.training_state = TrainingState.IDLE
+        if handle:
+            handle._training_state = HandleTrainingState.IDLE
+        return output
+
+
+@no_type_check
+def _post_forward_reshard(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+) -> None:
+    """Reshards parameters in the post-forward."""
+    if not handle:
+        return
+    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
+    # with the intention that they are immediately used for backward
+    # computation (though this may not be true)
+    free_unsharded_flat_param = (
+        not state._is_root
+        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
+    )
+    _reshard(state, handle, free_unsharded_flat_param)
+
+
+@no_type_check
+def _root_pre_forward(
+    state: _FSDPState,
+    module: nn.Module,
+    args,
+    kwargs,
+) -> Any:
+    """
+    Runs pre-forward logic specific to the root FSDP instance, which should run
+    before any individual module's pre-forward. This starts with an attempt at
+    lazy initialization (which only runs non-vacuously once). If this is
+    called on a non-root FSDP instance, then it returns early.
+
+    Args:
+        module (nn.Module): Module for which this logic tries to run. It may or
+            may not be the root. If not, then this method does not do anything.
+    """
+    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
+        _lazy_init(state, module)
+        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
+        if not state._is_root:
+            # Always cast forward inputs in the root of this local FSDP unit for mixed
+            # precision, as this is where mixed precision could be configured.
+            # This is most useful for auto wrapping, which is recommended in the
+            # composable path. For manual wrapping, casting forward inputs at each
+            # local FSDP unit root would add some overhead, so it is not enabled for
+            # the model-wrapper path right now, where manual wrapping is more
+            # broadly used.
+            if _is_composable(state):
+                return _root_cast_forward_input(state, module, args, kwargs)
+            return args, kwargs
+
+        # We cast buffers back to full precision if we're forcing full precision.
+        # Otherwise, we check whether buffers are in full precision and whether we
+        # should cast them back to lower precision, which happens when exiting
+        # eval() mode.
+        handle = state._handle
+        if handle:
+            should_cast_buffers_to_full_prec = handle._force_full_precision
+        else:
+            should_cast_buffers_to_full_prec = True
+
+        if should_cast_buffers_to_full_prec:
+            _cast_buffers_to_dtype_and_device(
+                buffers=dict(module.named_buffers()).values(),
+                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
+                device=state.compute_device,
+            )
+            # This flag is only set when we cast buffers to full precision, to avoid the
+            # CPU overhead that can stem from retrieving all buffers and their types in the
+            # following else branch.
+            state._needs_buffer_dtype_restore_check = True
+        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
+            # Check if buffers are in full precision and we need to cast them
+            # back down.
+            (
+                buffers,
+                buffer_dtypes_for_computation,
+            ) = _get_buffers_and_dtypes_for_computation(state, module)
+            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
+                if any(
+                    buffer.dtype != buffer_dtype_for_computation
+                    for buffer, buffer_dtype_for_computation in zip(
+                        buffers, buffer_dtypes_for_computation
+                    )
+                ):
+                    # Assume we have to cast everything if there is one mismatch
+                    _cast_buffers_to_dtype_and_device(
+                        buffers, buffer_dtypes_for_computation, state.compute_device
+                    )
+            # We don't have to check this again until we cast buffers to full precision again.
+            state._needs_buffer_dtype_restore_check = False
+
+        if state.forward_prefetch:
+            handles = []
+            for fsdp_state in state._all_fsdp_states:
+                if fsdp_state._handle:
+                    handles.append(fsdp_state._handle)
+            for handle in handles:
+                handle._needs_pre_forward_unshard = True
+                handle._prefetched = False
+        _wait_for_computation_stream(
+            state._device_handle.current_stream(),
+            state._unshard_stream,
+            state._pre_unshard_stream,
+        )
+        _reset_flat_param_grad_info_if_needed(state._all_handles)
+
+        # Prepares the forward inputs by moving them to ``compute_device``
+        # TODO: Do not use the side stream for tensor copies for now; investigate
+        # the perf with/without it.
+        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
+            args_tuple, kwargs_tuple = _to_kwargs(
+                args, kwargs, state.compute_device, False
+            )
+        args = args_tuple[0]
+        kwargs = kwargs_tuple[0]
+
+        return _root_cast_forward_input(state, module, args, kwargs)
+
+
+@no_type_check
+def _root_cast_forward_input(
+    state: _FSDPState, module: torch.nn.Module, args, kwargs
+) -> Tuple[Any, Any]:
+    if state._handle:
+        force_full_precision = not state._handle._force_full_precision
+    else:
+        force_full_precision = True
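+    # NOTE (editorial): despite its name, `force_full_precision` is `True` here
+    # when full precision is *not* being forced on the handle, i.e. when casting
+    # the root forward inputs to the low-precision `param_dtype` is allowed.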
+
+    should_cast_forward_inputs = (
+        (module.training or not state._use_full_prec_in_eval) and force_full_precision
+    ) and state.mixed_precision.cast_root_forward_inputs
+
+    if should_cast_forward_inputs:
+        input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
+
+    return args, kwargs
+
+
+@no_type_check
+def _pre_backward_hook(
+    state: _FSDPState,
+    module: nn.Module,
+    handle: FlatParamHandle,
+    grad,
+    *unused: Any,
+) -> Any:
+    """
+    Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation.
+
+    Args:
+        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
+            Module]).
+    """
+    # Only run the pre-backward hook once per group of handles involved in the
+    # same module forward computation
+    if (
+        handle
+        and hasattr(handle, "_ran_pre_backward_hook")
+        and handle._ran_pre_backward_hook
+    ):
+        log.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
+        return grad
+
+    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
+        # Queue the post-backward callback once for the root FSDP instance to
+        # attach it to the outermost backward graph task so that it is called
+        # after all backward calls complete
+        if state._is_root and not state._post_backward_callback_queued:
+            _register_post_backward_final_callback(state, module)
+            _reset_flat_param_grad_info_if_needed(state._all_handles)
+        elif handle:
+            allowed_states = [TrainingState.IDLE]
+            if _is_composable(state):
+                allowed_states.append(TrainingState.FORWARD_BACKWARD)
+            _assert_in_training_states(state, allowed_states)
+        state.training_state = TrainingState.FORWARD_BACKWARD
+        # Queueing the post-backward callback is the only logic that is not
+        # per-handle in the pre-backward hook, so we can return early here if
+        # there are no handles.
+        if not handle:
+            return grad
+        handle._training_state = HandleTrainingState.BACKWARD_PRE
+
+        if handle._needs_pre_backward_unshard:
+            # If the handles have been prefetched, then there is no need to
+            # call `_unshard()` again
+            if not handle._prefetched:
+                _unshard(
+                    state,
+                    handle,
+                    state._unshard_stream,
+                    state._pre_unshard_stream,
+                )
+            # Don't wait during trace
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                state._device_handle.current_stream().wait_stream(state._unshard_stream)
+
+        # Set this to `False` to ensure that a mistargeted prefetch does not
+        # actually unshard these handles
+        handle._needs_pre_backward_unshard = False
+        with torch.profiler.record_function(
+            "FullyShardedDataParallel._pre_backward_prefetch"
+        ):
+            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
+        handle.prepare_gradient_for_backward()
+        handle._ran_pre_backward_hook = True
+        return grad
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_hook(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    flat_param,
+    *unused: Any,
+):
+    """
+    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
+
+    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
+    unsharded gradient for the local batch.
+
+    Postcondition:
+    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
+    unsharded gradient.
+    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
+    gradient (accumulating with any existing gradient).
+    """
+    _log_post_backward_hook(state, handle, log)
+    flat_param = handle.flat_param
+    flat_param._post_backward_called = True
+    with torch.autograd.profiler.record_function(
+        "FullyShardedDataParallel._post_backward_hook"
+    ):
+        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+        # For multiple applications of reentrant AC across submodules sharing
+        # the same `FlatParameter`, the post-backward hook may run multiple
+        # times in one backward, in which case we permit the state to already
+        # be in `BACKWARD_POST`.
+        _p_assert(
+            handle._training_state
+            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
+            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
+        )
+        handle._training_state = HandleTrainingState.BACKWARD_POST
+
+        if flat_param.grad is None:
+            return
+        if flat_param.grad.requires_grad:
+            raise RuntimeError("FSDP does not support gradients of gradients")
+
+        _post_backward_reshard(state, handle)
+        if not state._sync_gradients:
+            if handle._use_orig_params:
+                handle._use_unsharded_grad_views()
+            return
+
+        # Wait for all ops in the current stream (e.g. gradient computation) to
+        # finish before reduce-scattering the gradient
+        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+            state._post_backward_stream.wait_stream(
+                state._device_handle.current_stream()
+            )
+
+        with state._device_handle.stream(state._post_backward_stream):
+            autograd_computed_grad = flat_param.grad.data
+            if (
+                not _low_precision_hook_enabled(state)
+                and flat_param.grad.dtype != handle._reduce_dtype
+                # If we are forcing full precision but communicating grads
+                # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
+                and not handle._force_full_precision
+            ):
+                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
+            if handle.uses_sharded_strategy:
+                _reduce_grad(state, handle)
+            else:
+                _reduce_grad_no_shard(state, handle)
+            # Since the unsharded gradient is produced in the computation
+            # stream and consumed in the post-backward stream, inform the
+            # caching allocator (before it goes out of scope)
+            _no_dispatch_record_stream(
+                autograd_computed_grad, state._post_backward_stream
+            )
+
+
+def _post_backward_reshard_only_hook(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    *unused: Any,
+) -> None:
+    with torch.profiler.record_function(
+        "FullyShardedDataParallel._post_backward_hook_reshard_only"
+    ):
+        # `_pre_backward_hook` may not get executed
+        # if forward output does not require grad
+        # overwrite IDLE state for post-backward prefetching
+        state.training_state = TrainingState.FORWARD_BACKWARD
+        handle._training_state = HandleTrainingState.BACKWARD_POST
+        _post_backward_reshard(state, handle)
+
+
+def _post_backward_reshard(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    *unused: Any,
+) -> None:
+    free_unsharded_flat_param = _should_free_in_backward(state, handle)
+    _reshard(state, handle, free_unsharded_flat_param)
+
+    # TODO: Post-backward prefetching does not support the multiple handles
+    # per module case since the post-backward hook runs per handle, not per
+    # group of handles.
+    with torch.profiler.record_function(
+        "FullyShardedDataParallel._post_backward_prefetch"
+    ):
+        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
+
+
+@no_type_check
+def _should_free_in_backward(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+) -> bool:
+    """
+    Returns whether FSDP should free the unsharded flat parameter in the
+    post-backward or not.
+    """
+    if not handle.uses_sharded_strategy:
+        return False
+    # If not syncing gradients, then we do not free for strategies that do not
+    # reshard after forward as a *heuristic* to tradeoff higher memory for
+    # higher throughput.
+    return (
+        state._sync_gradients
+        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
+    )
+
+
+@no_type_check
+def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
+    """
+    For sharded strategies, this runs gradient reduction, sharded gradient
+    accumulation if needed, and the post-reduction callback.
+    """
+    flat_param = handle.flat_param
+    uses_hybrid_sharded_strategy = handle._sharding_strategy in (
+        HandleShardingStrategy.HYBRID_SHARD,
+        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
+    )
+    # We clear `.grad` to permit multiple backwards. This avoids a race where
+    # the second backward pass's computation runs ahead of the first backward
+    # pass's reduction, which is possible since the reduction is issued in a
+    # separate stream and is asynchronous; such a race would result in reducing
+    # the wrong gradient.
+    unsharded_grad = flat_param.grad.data
+    flat_param.grad = None
+    padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(
+        state, unsharded_grad
+    )
+    if state._comm_hook is None:  # default path
+        _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
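+        # Editorial note: the pre- and post-divide factors together divide by
+        # the world size in two steps; splitting the division this way helps
+        # avoid overflow/underflow when gradients are communicated in reduced
+        # precision (e.g. fp16).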
+        pg = (
+            handle._fake_process_group
+            if handle._use_fake_reduce
+            else state.process_group
+        )
+        dist.reduce_scatter_tensor(
+            new_sharded_grad,
+            padded_unsharded_grad,
+            group=pg,
+        )
+        if uses_hybrid_sharded_strategy:
+            # Don't wait during trace
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                state._all_reduce_stream.wait_stream(state._post_backward_stream)
+            with state._device_handle.stream(state._all_reduce_stream):
+                # Since the new sharded gradient is produced in the post-
+                # backward stream and consumed in the all-reduce stream,
+                # inform the caching allocator
+                _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream)
+                dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
+                _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
+                grad_to_offload = _accumulate_sharded_grad(
+                    state, handle, new_sharded_grad
+                )
+                _post_reduce_grad_callback(state, handle, grad_to_offload)
+                return
+        _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
+    else:
+        state._comm_hook(
+            state._comm_hook_state, padded_unsharded_grad, new_sharded_grad
+        )
+        # NOTE: HSDP variants do not support communication hook.
+    grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad)
+    _post_reduce_grad_callback(state, handle, grad_to_offload)
+
+
+@no_type_check
+def _get_reduce_scatter_tensors(
+    state: _FSDPState, unsharded_grad: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Returns the input and output tensors to reduce-scatter, respectively.
+    """
+    chunks = list(unsharded_grad.chunk(state.world_size))
+    numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()
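+    # Hypothetical example: with world_size=4 and a 10-element unsharded
+    # gradient, chunks[0].numel() == 3, so numel_to_pad == 4 * 3 - 10 == 2 and
+    # the gradient is padded to 12 elements before the reduce-scatter.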
+    padded_unsharded_grad = (
+        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
+    )
+    new_sharded_grad = torch.empty_like(chunks[0])  # padded
+    return padded_unsharded_grad, new_sharded_grad
+
+
+@no_type_check
+def _accumulate_sharded_grad(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    sharded_grad: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Accumulates the reduce-scattered sharded gradient with any existing sharded
+    gradient if needed, returning the gradient to offload (if CPU offloading is
+    enabled).
+    """
+    flat_param = handle.flat_param
+    _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
+    # Save the sharded gradient in `_saved_grad_shard` to support gradient
+    # accumulation -- for multiple backwards, the gradient reductions may
+    # happen in arbitrary order
+    accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
+    if accumulate_grad:
+        _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
+        flat_param._saved_grad_shard += sharded_grad
+    else:
+        flat_param._saved_grad_shard = sharded_grad
+    grad_to_offload = flat_param._saved_grad_shard
+    return grad_to_offload
+
+
+@no_type_check
+def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
+    """
+    For no-shard, this runs gradient reduction (which implicitly covers any
+    gradient accumulation) and the post-reduction callback.
+    """
+    flat_param = handle.flat_param
+    if state._comm_hook is None:  # default path
+        _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
+        dist.all_reduce(flat_param.grad, group=state.process_group)
+        _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
+    else:
+        state._comm_hook(state._comm_hook_state, flat_param.grad)
+    # For `NO_SHARD`, we can keep the low precision gradients by simply
+    # omitting the cast altogether
+    if not handle._keep_low_precision_grads:
+        _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
+    grad_to_offload = flat_param.grad.data
+    _post_reduce_grad_callback(state, handle, grad_to_offload)
+
+
+@no_type_check
+def _post_reduce_grad_callback(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    # Additional arguments needed for the callback logic
+    grad_to_offload: torch.Tensor,
+):
+    """
+    This callback captures any logic to run after the gradient reduction
+    finishes. Currently, this offloads the gradient to CPU if CPU offloading is
+    enabled and uses sharded gradient views if ``use_orig_params=True``.
+    """
+    _offload_grad(state, handle, grad_to_offload)
+    _post_backward_use_sharded_grad_views(handle)
+
+
+@no_type_check
+def _offload_grad(
+    state: _FSDPState,
+    handle: FlatParamHandle,
+    grad_to_offload: torch.Tensor,
+):
+    if not handle._offload_params:
+        return
+    # Offload the gradient to CPU to ensure parameters and gradients are on the
+    # same device as required by the optimizer
+    # TODO: Investigate why `NO_SHARD` breaks correctness when using
+    # `non_blocking=True` here.
+    # TODO (rohan-varma): When CPU offload and optimizer overlap,
+    # non_blocking=True won't work since the copy may have not finished before
+    # the optimizer step executes on CPU. If we want to use non-blocking=True
+    # here, we'll have to synchronize before using result on CPU.
+    non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
+    handle.flat_param._cpu_grad.copy_(
+        grad_to_offload.detach(), non_blocking=non_blocking
+    )  # synchronized in the post-backward callback
+    # Since the gradient being offloaded may have been produced in the
+    # computation stream and is being consumed here in the post-backward
+    # stream, inform the caching allocator
+    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)
+
+
+@no_type_check
+def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
+    if not handle._use_orig_params:
+        return
+    # Since the handle's `FlatParameter` completed its gradient computation, we
+    # should reset the gradient noneness mask
+    handle._reset_is_grad_none()
+    # Delay using sharded gradient views until after the reduce-scatter instead
+    # of immediately after resharding
+    handle._use_sharded_grad_views()
+    if handle._has_optim_in_backward:
+        handle.prepare_gradient_for_optim()
+        for orig_param in handle.flat_param._params:
+            # Check for `None` gradient to filter parameters not in the rank
+            if orig_param.grad is not None and hasattr(
+                orig_param, "_in_backward_optimizers"
+            ):
+                # TODO (rohan-varma): For CPU offload, this unfortunately
+                # operates on CPU because the parameters and gradients have
+                # already been offloaded. We should run this on GPU after
+                # refactoring.
+                for optim in orig_param._in_backward_optimizers:
+                    optim.step()
+
+                optim.zero_grad(set_to_none=True)
+        handle._reset_flat_param_grad_info_if_needed()
+        if handle._offload_params:
+            handle.flat_param._cpu_grad = None
+
+
+def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
+    if div_factor > 1:
+        tensor.div_(div_factor)
+
+
+@no_type_check
+def _cast_grad_to_param_dtype(
+    state: _FSDPState,
+    sharded_grad: torch.Tensor,
+    param: FlatParameter,
+):
+    """
+    Casts ``sharded_grad`` back to the full parameter dtype so that the
+    optimizer step runs with that dtype. This performs an actual cast if
+    1. parameters were in reduced precision during the forward, in which case
+       gradients would also be in that reduced precision, or
+    2. parameters were not in reduced precision but gradients were in
+       reduced precision for communication.
+    However, if a low precision communication hook is registered, then this
+    dtype cast happens in the hook instead.
+    """
+    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+    if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
+        low_prec_grad_data = sharded_grad.data
+        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
+        # Since for `NO_SHARD`, the gradient is produced in the computation
+        # stream and consumed here in the post-backward stream, inform the
+        # caching allocator; for the sharded strategies, the gradient is
+        # produced in the post-backward stream, so this `record_stream()`
+        # should be a no-op
+        _no_dispatch_record_stream(
+            low_prec_grad_data, state._device_handle.current_stream()
+        )
+
+
+def _check_grad_to_accumulate(
+    new_sharded_grad: torch.Tensor,
+    accumulated_grad: torch.Tensor,
+) -> None:
+    _p_assert(
+        accumulated_grad.shape == new_sharded_grad.shape,
+        "Shape mismatch when accumulating gradients: "
+        f"existing gradient shape={accumulated_grad.shape} "
+        f"new gradient shape={new_sharded_grad.shape}",
+    )
+    _p_assert(
+        accumulated_grad.device == new_sharded_grad.device,
+        "Device mismatch when accumulating gradients: "
+        f"existing gradient device={accumulated_grad.device} "
+        f"new gradient device={new_sharded_grad.device}",
+    )
+
+
+@no_type_check
+def _low_precision_hook_enabled(state: _FSDPState) -> bool:
+    return state._comm_hook in LOW_PRECISION_HOOKS
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_final_callback(
+    state: _FSDPState,
+    module: nn.Module,
+):
+    """
+    This waits for the post-backward to finish and performs some final cleanup.
+    This runs at the end of the entire backward pass and should only be called
+    on the root FSDP instance.
+    """
+    _p_assert(
+        state._is_root,
+        "The post-backward callback should only be called on the root FSDP instance",
+    )
+    root_state = state
+
+    if root_state._sync_gradients:
+        current_stream = state._device_handle.current_stream()
+        # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
+        # since it currently runs in the post-backward stream. That can be
+        # pushed to the next forward if run in a different stream
+        current_stream.wait_stream(root_state._post_backward_stream)
+        if root_state._all_reduce_stream is not current_stream:  # uses HSDP
+            current_stream.wait_stream(root_state._all_reduce_stream)
+        if root_state.cpu_offload.offload_params:
+            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
+            # post-backward hooks to finish explicitly since CPU gradients do
+            # not automatically synchronize with the GPU
+            state._device_handle.current_stream().synchronize()
+    root_state._exec_order_data.next_iter()
+
+    for fsdp_state in state._all_fsdp_states:
+        _catch_all_reshard(fsdp_state)
+        _finalize_params(fsdp_state)
+        fsdp_state.training_state = TrainingState.IDLE
+        handle = fsdp_state._handle
+        if handle:
+            handle._ran_pre_backward_hook = False
+            handle._needs_pre_backward_unshard = False
+            handle._post_forward_index = None
+            handle._training_state = HandleTrainingState.IDLE
+            handle._prefetched = False
+    # Reset for cases like one forward and multiple backwards
+    root_state._post_backward_callback_queued = False
+
+
+@no_type_check
+def _catch_all_reshard(
+    state: _FSDPState,
+) -> None:
+    """
+    Reshards the parameters that may not have been resharded in the
+    post-backward hook. This can happen when a module's output is used in the
+    forward pass, meaning that its pre-backward hook runs (unsharding the
+    parameter), but the post-backward hook does not run because the output was
+    not used in the loss computation corresponding to this backward pass.
+    """
+    # Wrap with a try-except to provide a more informative traceback if an
+    # error is raised
+    try:
+        if state._handle:
+            # TODO: This already-resharded check is brittle:
+            # https://github.com/pytorch/pytorch/issues/83956
+            already_resharded = (
+                state._handle.flat_param.data_ptr()
+                == state._handle.flat_param._local_shard.data_ptr()
+                # If FSDP skipped using sharded views, then the flat parameter
+                # still points to the sharded data, so we need to reshard to
+                # use sharded views
+                and not state._handle._skipped_use_sharded_views
+            )
+            if already_resharded:
+                return
+            free_unsharded_flat_param = _should_free_in_backward(state, state._handle)
+            _reshard(state, state._handle, free_unsharded_flat_param)
+    except Exception as e:
+        _p_assert(
+            False,
+            f"Got exception in the catch-all reshard for {state}: {str(e)}",
+            raise_assertion_error=False,
+        )
+        raise e
+
+
+@no_type_check
+def _finalize_params(
+    state: _FSDPState,
+) -> None:
+    """Finalizes the parameters before the next iteration."""
+    handle = state._handle
+    if not handle:
+        return
+    flat_param = handle.flat_param
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        if hasattr(flat_param, "_post_backward_hook_handle"):
+            pbhs_handle = flat_param._post_backward_hook_handle
+            pbhs_handle.remove()
+            del flat_param._post_backward_hook_handle
+    else:
+        if hasattr(flat_param, "_post_backward_hook_state"):
+            post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
+            expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
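+            # i.e. 2 for the `(acc_grad, hook_handle)` state registered in
+            # `_register_post_backward_hook` when the flat parameter requires
+            # grad, else 1 for the reshard-only `(hook_handle,)` state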
+            _p_assert(
+                post_backward_hook_state_len == expected_post_backward_hook_state_len,
+                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
+            )
+            flat_param._post_backward_hook_state[-1].remove()
+            delattr(flat_param, "_post_backward_hook_state")
+    if flat_param.requires_grad:
+        if not state._sync_gradients:
+            # Preserve the gradient accumulation state if not synchronizing
+            # gradients: `.grad` remains the unsharded gradient from prior
+            # `no_sync()` iterations, and `_saved_grad_shard` remains the
+            # sharded gradient from the last synchronized iteration
+            return
+        if not handle._has_optim_in_backward:
+            handle.prepare_gradient_for_optim()
+        _p_assert(
+            hasattr(flat_param, "_post_backward_called"),
+            "Expects `_post_backward_called` to be set on the `FlatParameter`",
+        )
+        flat_param._post_backward_called = False
+
+
+@no_type_check
+def _prefetch_handle(
+    state: _FSDPState,
+    current_handle: Optional[FlatParamHandle],
+    prefetch_mode: _PrefetchMode,
+) -> None:
+    """
+    Prefetches the next handle if needed (without synchronization). An empty
+    ``current_handle`` cannot prefetch.
+    """
+    if not current_handle:
+        return
+    handle = _get_handle_to_prefetch(state, current_handle)
+    if not handle:
+        return
+    # Temporarily emulate the training state while calling `_unshard` to
+    # ensure the correct `as_params` for `_use_unsharded_views()`
+    prev_training_state = handle._training_state
+    if prefetch_mode == _PrefetchMode.BACKWARD:
+        handle._training_state = HandleTrainingState.BACKWARD_PRE
+    elif prefetch_mode == _PrefetchMode.FORWARD:
+        handle._training_state = HandleTrainingState.FORWARD
+    else:
+        raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
+    # Prefetch the next set of handles without synchronizing to allow
+    # the sync to happen as late as possible to maximize overlap
+    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
+    handle._training_state = prev_training_state
+    handle._prefetched = True
+
+
+@no_type_check
+def _get_handle_to_prefetch(
+    state: _FSDPState,
+    current_handle: FlatParamHandle,
+) -> Optional[FlatParamHandle]:
+    """
+    Returns the handle to prefetch for the next module, or ``None`` if there is
+    nothing to prefetch, where ``current_handle`` represents the current module.
+
+    "Prefetching" refers to running the unshard logic early (without
+    synchronization), and the "next" modules depend on the recorded execution
+    order and the current training state.
+    """
+    training_state = _get_training_state(current_handle)
+    valid_training_states = (
+        HandleTrainingState.BACKWARD_PRE,
+        HandleTrainingState.BACKWARD_POST,
+        HandleTrainingState.FORWARD,
+    )
+    _p_assert(
+        training_state in valid_training_states,
+        f"Prefetching is only supported in {valid_training_states} but "
+        f"currently in {training_state}",
+    )
+    eod = state._exec_order_data
+    target_handle: Optional[FlatParamHandle] = None
+    if (
+        training_state == HandleTrainingState.BACKWARD_PRE
+        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+    ) or (
+        training_state == HandleTrainingState.BACKWARD_POST
+        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
+    ):
+        target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
+        if (
+            target_handle_candidate
+            and target_handle_candidate._needs_pre_backward_unshard
+            and not target_handle_candidate._prefetched
+        ):
+            target_handle = target_handle_candidate
+        else:
+            target_handle = None
+    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
+        target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
+        if (
+            target_handle_candidate
+            and target_handle_candidate._needs_pre_forward_unshard
+            and not target_handle_candidate._prefetched
+        ):
+            target_handle = target_handle_candidate
+        else:
+            target_handle = None
+
+    return target_handle
+
+
+def _get_training_state(
+    handle: FlatParamHandle,
+) -> HandleTrainingState:
+    """Returns the training state of the handles in ``handle``."""
+    _p_assert(handle, "Expects a non-empty handle")
+    return handle._training_state
+
+
+@no_type_check
+def _register_pre_forward_hook(
+    state: _FSDPState,
+    module: nn.Module,
+) -> None:
+    """
+    Registers a pre-forward hook on ``module``.
+    """
+    for forward_handle in state._pre_forward_handles:
+        forward_handle.remove()
+    state._pre_forward_handles.clear()
+    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
+    hook = functools.partial(
+        _pre_forward, state, module_param_handle, _pre_forward_unshard
+    )
+    state._pre_forward_handles.append(
+        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
+    )
+
+
+@no_type_check
+def _register_post_forward_hook(
+    state: _FSDPState,
+    module: nn.Module,
+) -> None:
+    """
+    Registers a post-forward hook on ``module``. Even if the module has no
+    handles, we should register the hook since it will register the module's
+    pre-backward hook.
+    """
+    for forward_handle in state._post_forward_handles:
+        forward_handle.remove()
+    state._post_forward_handles.clear()
+    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
+    hook = functools.partial(
+        _post_forward,
+        state,
+        module_param_handle,
+        _post_forward_reshard,
+    )
+    state._post_forward_handles.append(module.register_forward_hook(hook))
+
+
+@no_type_check
+def _register_root_pre_forward_hook(
+    state: _FSDPState,
+    module: nn.Module,
+):
+    """
+    Registers root pre-forward hook on ``module``, which should be the local
+    FSDP root.
+
+    NOTE: For the current composable FSDP design, each application of
+    ``fully_shard()`` to a module indicates that that module is the local
+    FSDP root. We may remove this assumption in the future, in which case we
+    will need to register this root pre-forward hook on any candidate module
+    that may be the local FSDP root.
+    """
+    for forward_handle in state._root_pre_forward_handles:
+        forward_handle.remove()
+    state._root_pre_forward_handles.clear()
+    hook = functools.partial(_root_pre_forward, state)
+    state._root_pre_forward_handles.append(
+        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
+    )
+
+
+@no_type_check
+def _register_pre_backward_hooks(
+    state: _FSDPState,
+    module: nn.Module,
+    outputs: Any,
+    handle: FlatParamHandle,
+) -> Any:
+    """
+    Registers pre-backward hooks on the tensors that require gradients in the
+    forward pass outputs ``outputs``, which were computed using ``handle`` 's
+    ``FlatParameter``.
+
+    Args:
+        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
+            Module]).
+
+    Returns:
+        Forward pass outputs with pre-backward hooks registered to tensors that
+        require gradients.
+    """
+    # If there is no gradient computation, then there is no need for
+    # pre-backward logic
+    if not torch.is_grad_enabled():
+        return outputs
+    if state._is_root:
+        state._post_backward_callback_queued = False  # only defined on the root
+
+    if handle:
+        handle._needs_pre_backward_unshard = False
+        # Since these handles' `FlatParameter`s participated in a forward, we
+        # conservatively assume that they will be used in the backward
+        handle._ran_pre_backward_hook = False
+
+    def _register_hook(t: torch.Tensor) -> torch.Tensor:
+        if t.requires_grad:
+            t.register_hook(
+                functools.partial(_pre_backward_hook, state, module, handle)
+            )
+            if handle:
+                handle._needs_pre_backward_unshard = True
+        return t
+
+    return _apply_to_tensors(_register_hook, outputs)
+
+
+def _register_post_backward_hook(
+    state: _FSDPState,
+    handle: Optional[FlatParamHandle],
+) -> None:
+    """
+    Registers post-backward hooks on the ``FlatParameter`` s'
+    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.
+
+    The ``AccumulateGrad`` object represents the last function that finalizes
+    the ``FlatParameter`` 's gradient, so it only runs after its entire
+    gradient computation has finished.
+
+    We register the post-backward hook only once in the *first* forward that a
+    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
+    object being preserved through multiple forwards.
+
+    NOTE: We follow this heuristic to prefer the *first* forward to target the
+    parameter mixed precision case, where there are *separate*
+    ``AccumulateGrad`` objects across the different forwards. (Without
+    parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If
+    we instead prefer the *last* forward, then the hook runs early.
+    """
+    # If there is no gradient computation, then there is no need for
+    # post-backward logic
+    if not torch.is_grad_enabled():
+        return
+    if not handle:
+        return
+    flat_param = handle.flat_param
+
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
+        if already_registered or not flat_param.requires_grad:
+            return
+        hook = functools.partial(_post_backward_hook, state, handle)
+        hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
+        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined]
+    else:
+        already_registered = hasattr(flat_param, "_post_backward_hook_state")
+        if already_registered or not flat_param.requires_grad:
+            return
+        # Get the `AccumulateGrad` object
+        temp_flat_param = flat_param.expand_as(flat_param)
+        _p_assert(
+            temp_flat_param.grad_fn is not None,
+            "The `grad_fn` is needed to access the `AccumulateGrad` and "
+            "register the post-backward hook",
+        )
+        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
+        assert acc_grad is not None
+        hook_handle = acc_grad.register_hook(
+            functools.partial(_post_backward_hook, state, handle)
+        )
+        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
+
+
+def _register_post_backward_reshard_only_hook(
+    state: _FSDPState,
+    handle: Optional[FlatParamHandle],
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+) -> None:
+    """
+    Registers post-backward hooks to reshard flat parameters that do not
+    require gradient. We register these using multi-post-grad hooks on the
+    input activations to ensure that all gradients that may depend on the
+    parameters have been computed before resharding.
+    """
+    # If there is no gradient computation, then there is no need for
+    # post-backward logic
+    if not torch.is_grad_enabled():
+        return
+    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
+    # where each flat parameter requires gradient
+    inp_tensors: Optional[List[torch.Tensor]] = None
+    if not handle:
+        return
+    flat_param = handle.flat_param
+
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
+    else:
+        already_registered = hasattr(flat_param, "_post_backward_hook_state")
+
+    if already_registered or flat_param.requires_grad:
+        return
+    if inp_tensors is None:
+        args_flat = pytree.arg_tree_leaves(*args, **kwargs)
+        inp_tensors = [
+            obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
+        ]
+    assert inp_tensors is not None  # mypy
+    hook_handle = register_multi_grad_hook(
+        inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
+    )
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined, assignment]
+    else:
+        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]
+
+
+@no_type_check
+def _register_post_backward_final_callback(
+    state: _FSDPState, module: nn.Module
+) -> None:
+    """
+    Registers the post-backward final callback that runs at the end of the
+    backward pass. This should be called from the root FSDP instance at the
+    beginning of the pre-backward.
+    """
+    _p_assert(
+        state._is_root,
+        "Only the root FSDP instance should register the post-backward callback",
+    )
+    if state._post_backward_callback_queued:
+        return
+    _assert_in_training_states(state, [TrainingState.IDLE])
+    # Trace does not need this callback
+    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        state._post_backward_callback_queued = True
+        Variable._execution_engine.queue_callback(
+            functools.partial(_post_backward_final_callback, state, module)
+        )
+
+
+def _wait_for_computation_stream(
+    computation_stream: torch.Stream,
+    unshard_stream: torch.Stream,
+    pre_unshard_stream: torch.Stream,
+):
+    """
+    Has the unshard and pre-unshard streams wait for the computation stream.
+    For example, this should be called in the FSDP root's pre-forward to
+    respect optimizer step computation.
+    """
+    # Tracing does not need to wait
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        return
+    unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+    # Having the pre-all-gather stream wait for the current stream even if we
+    # do not leverage the pre-all-gather stream is tolerable since this only
+    # runs once per iteration
+    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+
+
+def _reset_flat_param_grad_info_if_needed(
+    handles: List[FlatParamHandle],
+):
+    """
+    Clears the original parameters' gradients if needed. This function's CPU
+    overhead is minimal, so we may call it throughout FSDP methods, using those
+    callsites to free the gradient memory earlier.
+    """
+    if not isinstance(handles, list):
+        handles = [handles]
+    for handle in handles:
+        if handle._use_orig_params:
+            handle._reset_flat_param_grad_info_if_needed()
+
+
+@no_type_check
+def _get_buffers_and_dtypes_for_computation(
+    state: _FSDPState,
+    root_module: nn.Module,
+) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
+    """
+    Returns all buffers in the module tree rooted at ``root_module`` and a
+    corresponding list of the buffer dtypes for computation. Each buffer dtype
+    is either ``None`` if buffer mixed precision is not enabled or the buffer
+    low precision dtype otherwise.
+    """
+    _p_assert(state._is_root, "Expects the root to cast buffers")
+    buffers: List[torch.Tensor] = []
+    buffer_dtypes: List[Optional[torch.dtype]] = []
+    visited_buffers: Set[torch.Tensor] = set()
+    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
+    # instance's mixed precision setting for each buffer
+    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
+        root_module
+    )
+    for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
+        for buffer_name, buffer in fsdp_module.named_buffers():
+            if buffer in visited_buffers:
+                continue
+            visited_buffers.add(buffer)
+            if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names:
+                continue
+            buffers.append(buffer)
+            buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
+    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
+    return buffers, buffer_dtypes
+
+
+@no_type_check
+def _get_orig_buffer_dtypes(
+    state: _FSDPState,
+    buffer_names: List[str],
+) -> List[torch.dtype]:
+    """
+    Returns the original buffer dtypes for the given buffer names.
+    """
+    buffer_dtypes: List[torch.dtype] = []
+    for buffer_name in buffer_names:
+        _p_assert(
+            buffer_name in state._buffer_name_to_orig_dtype,
+            f"{buffer_name} is missing from pre-computed dict on rank "
+            f"{state.rank}, which only has keys "
+            f"{state._buffer_name_to_orig_dtype.keys()}",
+        )
+        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
+    return buffer_dtypes
+
+
+def _cast_buffers_to_dtype_and_device(
+    buffers: List[torch.Tensor],
+    buffer_dtypes: List[Optional[torch.dtype]],
+    device: torch.device,
+) -> None:
+    """
+    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
+    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
+    corresponding buffer is only moved to ``device``.
+    """
+    _p_assert(
+        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
+        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
+        f"`buffer_dtypes` is specified but got {len(buffers)} and "
+        f"{len(buffer_dtypes)}",
+    )
+    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
+        if not torch.is_floating_point(buffer) or buffer_dtype is None:
+            buffer.data = buffer.to(device=device)
+        else:
+            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
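+
+
+# A small sketch of the casting behavior above: floating-point buffers are cast
+# to the given dtype, while a ``None`` dtype (or a non-floating-point buffer)
+# results in a device move only.
+#
+#     >>> bufs = [torch.randn(3), torch.zeros(2, dtype=torch.long)]
+#     >>> _cast_buffers_to_dtype_and_device(bufs, [torch.float16, None], torch.device("cpu"))
+#     >>> bufs[0].dtype, bufs[1].dtype
+#     (torch.float16, torch.int64)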
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..756e65ba7ab3fc98f83299a91b3fb4ae7d596973
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py
@@ -0,0 +1,127 @@
+import copy
+import itertools
+import math
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import distributed_c10d
+from torch.distributed._shard.sharded_tensor import (
+    Shard,
+    ShardedTensor,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+from torch.distributed._shard.sharding_spec import ShardMetadata
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
+
+
+def _get_remote_device_str(rank, device_type, num_devices_per_node):
+    if device_type.lower() == "cpu":
+        return f"rank:{rank}/{device_type}"
+    else:
+        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"
+
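+# For example (derived directly from the two branches above):
+#
+#     >>> _get_remote_device_str(5, "cuda", 8)
+#     'rank:5/cuda:5'
+#     >>> _get_remote_device_str(5, "cpu", 8)
+#     'rank:5/cpu'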
+
+def _create_chunk_sharded_tensor(
+    tensor: torch.Tensor,
+    rank: int,
+    world_size: int,
+    num_devices_per_node: int,
+    pg: dist.ProcessGroup,
+    device: Optional[torch.device] = None,
+) -> ShardedTensor:
+    """
+    Shard a tensor into chunks along the first dimension. The local rank gets its
+    corresponding chunk as the local shard to create a ShardedTensor.
+    """
+    chunks = tensor.chunk(world_size, dim=0)
+    if len(chunks) > rank:
+        local_shard = chunks[rank].clone()
+        offsets = [0 for _ in tensor.size()]
+        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
+        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
+    else:
+        local_shards = []
+
+    # Create a ShardedTensor without invoking communication.
+    chunk_sizes = [list(chunk.size()) for chunk in chunks]
+    dim0_offsets = [0] + list(
+        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
+    )[:-1]
+    offsets = [0] * (len(chunk_sizes[0]) - 1)
+    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
+    device_type = (
+        distributed_c10d._get_pg_default_device(pg).type
+        if device is None
+        else device.type
+    )
+    placements = [
+        _get_remote_device_str(r, device_type, num_devices_per_node)
+        for r in range(len(chunk_sizes))
+    ]
+    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
+    shard_metadata = [
+        ShardMetadata(offset, size, placement)
+        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
+    ]
+    sharded_tensor_metadata = ShardedTensorMetadata(
+        shards_metadata=shard_metadata,
+        size=tensor.size(),
+        tensor_properties=TensorProperties(
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            requires_grad=False,
+            memory_format=torch.contiguous_format,
+            pin_memory=tensor.is_pinned(),
+        ),
+    )
+    return ShardedTensor._init_from_local_shards_and_global_metadata(
+        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
+    )
+
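+# A usage sketch for `_create_chunk_sharded_tensor`, assuming an
+# already-initialized 2-rank process group; the tensor is arbitrary and
+# `dist.group.WORLD` stands in for the FSDP process group:
+#
+#     >>> full = torch.arange(12, dtype=torch.float32).view(6, 2)
+#     >>> st = _create_chunk_sharded_tensor(
+#     ...     full,
+#     ...     rank=dist.get_rank(),
+#     ...     world_size=dist.get_world_size(),
+#     ...     num_devices_per_node=1,
+#     ...     pg=dist.group.WORLD,
+#     ... )
+#     >>> [s.metadata.shard_offsets for s in st.local_shards()]
+#     [[0, 0]]  # on rank 0; rank 1 would see [[3, 0]]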
+
+def _create_chunk_dtensor(
+    tensor: torch.Tensor,
+    rank: int,
+    device_mesh: DeviceMesh,
+) -> DTensor:
+    """
+    Shard a tensor into chunks along the first dimension. The local rank gets its
+    corresponding chunk as the local tensor to create a DTensor.
+    """
+    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
+    tensor = tensor.clone().detach()
+
+    # FSDP placements: [Shard(0)]
+    # HSDP placements: [Replicate(), Shard(0)]
+    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
+    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
+    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]
+
+    return DTensor.from_local(
+        tensor, device_mesh, replicate_placements, run_check=False
+    ).redistribute(
+        placements=shard_placements,
+    )
+
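+# A usage sketch for `_create_chunk_dtensor`, assuming a 1-D device mesh over an
+# already-initialized process group (HSDP would pass a 2-D mesh instead):
+#
+#     >>> mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
+#     >>> dt = _create_chunk_dtensor(torch.randn(8, 4), dist.get_rank(), mesh)
+#     >>> dt.placements  # (Shard(dim=0),)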
+
+def _all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+) -> torch.Tensor:
+    """
+    All gather a DTensor in its sharded dimension and return the local tensor.
+    """
+    assert parent_mesh is None
+
+    placements = list(copy.deepcopy(tensor.placements))
+    # FSDP placements: [Shard(0)] -> [Replicate()]
+    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
+    placements[-1] = Replicate()
+    tensor = tensor.redistribute(
+        device_mesh=tensor.device_mesh,
+        placements=placements,
+    )
+
+    return tensor.to_local()
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..728f9c8a569a270260077391b3a6946a054c4d6a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py
@@ -0,0 +1,928 @@
+import contextlib
+import logging
+import math
+import warnings
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
+
+import torch
+import torch.distributed as dist
+
+import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
+
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._shard.sharded_tensor import (
+    init_from_local_shards,
+    Shard,
+    ShardedTensor,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import _mesh_resources
+
+from torch.distributed.fsdp._common_utils import (
+    _FSDPState,
+    _get_module_fsdp_state_if_fully_sharded_module,
+    _has_fsdp_params,
+    _is_composable,
+    _module_handle,
+    clean_tensor_name,
+    FSDP_PREFIX,
+    FSDP_WRAPPED_MODULE,
+)
+from torch.distributed.fsdp._debug_utils import SimpleProfiler
+from torch.distributed.fsdp._runtime_utils import (
+    _cast_buffers_to_dtype_and_device,
+    _get_orig_buffer_dtypes,
+    _lazy_init,
+    _reset_flat_param_grad_info_if_needed,
+)
+from torch.distributed.fsdp.api import (
+    FullStateDictConfig,
+    ShardingStrategy,
+    StateDictType,
+)
+from torch.distributed.utils import _replace_by_prefix
+
+from ._fsdp_extensions import (
+    _ext_all_gather_dtensor,
+    _ext_chunk_dtensor,
+    _ext_chunk_tensor,
+    _ext_post_unflatten_transform,
+    _ext_pre_load_state_dict_transform,
+)
+from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM
+
+
+logger = logging.getLogger(__name__)
+
+
+def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
+    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD and (
+        _is_composable(fsdp_state) or fsdp_state._use_orig_params
+    ):
+        return False
+    else:
+        return True
+
+
+def _convert_to_wrapped_module_name(module_name: str) -> str:
+    module_name = module_name.replace(f"{FSDP_PREFIX}", "")
+    module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
+    if module_name:
+        module_name = f"{module_name}."
+    # `CheckpointWrapper` adds a prefix that has to be removed as well.
+    module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
+    return module_name
+
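+# For example (pure string manipulation, independent of any FSDP state):
+#
+#     >>> _convert_to_wrapped_module_name(f"{FSDP_PREFIX}layer1")
+#     'layer1.'
+#     >>> _convert_to_wrapped_module_name(FSDP_WRAPPED_MODULE)
+#     ''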
+
+def _param_name_infos(
+    module: nn.Module, fsdp_state: _FSDPState
+) -> Iterator[Tuple[str, str, str]]:
+    if not _has_fsdp_params(fsdp_state, module):
+        return
+    for param_name, module_name in _module_handle(
+        fsdp_state, module
+    ).param_module_names():
+        module_name = _convert_to_wrapped_module_name(module_name)
+        fqn = f"{module_name}{param_name}"
+        yield fqn, param_name, module_name
+
+
+def _shared_param_name_infos(
+    module: nn.Module, fsdp_state
+) -> Iterator[Tuple[str, str, str]]:
+    for param_name, module_name in _module_handle(
+        fsdp_state, module
+    ).shared_param_module_names():
+        module_name = _convert_to_wrapped_module_name(module_name)
+        fqn = f"{module_name}{param_name}"
+        yield fqn, param_name, module_name
+
+
+@no_type_check
+def _enter_unshard_params_ctx(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    writeback: bool = False,
+    rank0_only: bool = False,
+    offload_to_cpu: bool = False,
+    with_grads: bool = False,
+) -> None:
+    """
+    state_dict hooks cannot use the pure context call because the checkpoint flow
+    requires entering the context in the pre-hook but leaving it in the
+    post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
+    """
+    assert module not in fsdp_state._unshard_params_ctx, (
+        "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
+        "is not None."
+    )
+    fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
+        module,
+        fsdp_state,
+        writeback=writeback,
+        rank0_only=rank0_only,
+        offload_to_cpu=offload_to_cpu,
+        with_grads=with_grads,
+    )
+    fsdp_state._unshard_params_ctx[module].__enter__()
+
+
+@no_type_check
+def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
+    """A helper function to exit ``_unshard_fsdp_state_params`` context."""
+    fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
+    fsdp_state._unshard_params_ctx.pop(module)
+
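+# A sketch of the pre-/post-hook pairing these two helpers are designed for,
+# assuming `module` and `fsdp_state` come from an already-constructed FSDP model:
+#
+#     >>> _enter_unshard_params_ctx(module, fsdp_state, writeback=False)
+#     >>> full_sd = module.state_dict()  # parameters are unsharded at this point
+#     >>> _exit_unshard_params_ctx(module, fsdp_state)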
+
+def _common_pre_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+) -> None:
+    """Performs the pre-state_dict tasks shared by all state_dict types."""
+    if fsdp_state._device_handle.is_available():
+        fsdp_state._device_handle.synchronize()
+    # TODO: need to check if this is always correct for composable FSDP.
+    _lazy_init(fsdp_state, module)
+    if fsdp_state._is_root:
+        _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)
+
+
+def _common_unshard_pre_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    offload_to_cpu: bool,
+    rank0_only: bool,
+) -> None:
+    """
+    Performs the pre-state_dict tasks shared by all state_dict types that require
+    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
+    """
+    # For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases.
+    if not _should_unshard_params(fsdp_state):
+        return
+    _enter_unshard_params_ctx(
+        module,
+        fsdp_state,
+        writeback=False,
+        offload_to_cpu=offload_to_cpu,
+        rank0_only=rank0_only,
+    )
+
+
+@no_type_check
+def _common_unshard_post_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+    param_hook: Callable,
+) -> Dict[str, Any]:
+    """
+    The post-state_dict flow shared by all state_dict types that require
+    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this
+    hook.
+    """
+    _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
+    # Return early for trivial cases
+    if not state_dict or not _has_fsdp_params(fsdp_state, module):
+        if _should_unshard_params(fsdp_state):
+            _exit_unshard_params_ctx(module, fsdp_state)
+        return state_dict
+
+    # If a rank does not have unsharded parameters (when `rank0_only=True`
+    # and `rank != 0`), then the rank only needs to participate in the
+    # all-gather and does not need to save the state dict. We simply check
+    # `rank0_only` to handle this case.
+    rank0_only = (
+        fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
+        and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
+    )
+    # no_fsdp_return means the state_dict returned by this rank should contain
+    # only non-FSDP controlled parameters and buffers.
+    no_fsdp_return = rank0_only and fsdp_state.rank != 0
+    if no_fsdp_return and not fsdp_state._use_orig_params:
+        for clean_key in fsdp_state._buffer_names:
+            # This is a hack to support activation checkpoint.
+            clean_key = clean_key.replace(
+                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
+            )
+            state_dict.pop(f"{prefix}{clean_key}", None)
+        # Non-zero ranks have flat_param key when rank0_only=True, because rank0_only=True is
+        # passed in to unshard context, but nonzero ranks reshard early, causing this flat_param
+        # to appear in state_dict.
+        state_dict.pop(f"{prefix}{FLAT_PARAM}")
+        _exit_unshard_params_ctx(module, fsdp_state)
+        return state_dict
+
+    # Loop only the parameters saved in this instance's wrapped module to
+    # avoid processing buffers.
+    for fqn, param_name, module_name in _param_name_infos(module, fsdp_state):
+        fqn = f"{prefix}{fqn}"
+        if no_fsdp_return:
+            state_dict.pop(fqn)
+            continue
+        assert fqn in state_dict, (
+            f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
+            f"has {state_dict.keys()}. "
+            f"prefix={prefix}, module_name={module_name}, "
+            f"param_name={param_name} rank={fsdp_state.rank}."
+        )
+
+        param_hook(state_dict, prefix, fqn)
+
+    if _should_unshard_params(fsdp_state):
+        _exit_unshard_params_ctx(module, fsdp_state)
+
+    cpu_device = torch.device("cpu")
+    buffer_clean_fqns = []
+    buffers = []
+    for clean_key in fsdp_state._buffer_names:
+        # This is a hack to support activation checkpoint.
+        clean_key = clean_tensor_name(clean_key)
+        fqn = f"{prefix}{clean_key}"
+        if fqn not in state_dict:
+            # A buffer can be registered as non-persistent.
+            continue
+        if no_fsdp_return:
+            state_dict.pop(fqn)
+        else:
+            buffer = state_dict[fqn]
+            if (
+                fsdp_state._state_dict_config.offload_to_cpu
+                and buffer.device != cpu_device
+            ):
+                state_dict[fqn] = buffer.to(cpu_device)
+            # skip upcasting for ignored buffers
+            if clean_key not in fsdp_state._ignored_buffer_names:
+                buffer_clean_fqns.append(clean_key)
+                buffers.append(state_dict[fqn])
+
+    if buffers:
+        mixed_precision_enabled_for_buffers = (
+            fsdp_state._mixed_precision_enabled_for_buffers()
+            if not _is_composable(fsdp_state)
+            else (fsdp_state.mixed_precision.buffer_dtype is not None)
+        )
+        if mixed_precision_enabled_for_buffers:
+            buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns)
+            _cast_buffers_to_dtype_and_device(
+                buffers, buffer_dtypes, fsdp_state.compute_device
+            )
+            for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
+                fqn = f"{prefix}{clean_fqn}"
+                logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype)
+                state_dict[fqn] = buffer.clone()
+    return state_dict
+
+
+@no_type_check
+def _full_pre_state_dict_hook(
+    fsdp_state: _FSDPState,
+    module: nn.Module,
+    *args,
+    **kwargs,
+) -> None:
+    """
+    Hook that runs before model.state_dict() is called. A pre-state_dict hook is
+    not actually supported by ``nn.Module``. As a result, this API is called
+    from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict
+    hooks are supported in ``nn.Module``, this hook will be registered as a hook
+    on ``nn.Module``.
+    """
+    if getattr(fsdp_state, "_device_mesh", False):
+        parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
+
+    _common_pre_state_dict_hook(module, fsdp_state)
+    _common_unshard_pre_state_dict_hook(
+        module,
+        fsdp_state,
+        offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
+        rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
+    )
+
+
+@no_type_check
+def _full_post_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> Dict[str, Any]:
+    """
+    Hook that runs after model.state_dict() is called, before returning the result
+    to the user. For FSDP, we may have to clone the tensors in state_dict since the
+    params go back to their sharded version after ``_unshard_fsdp_state_params``
+    ends, and we also remove the ``FSDP_WRAPPED_MODULE`` prefix.
+    """
+
+    def param_hook(
+        state_dict: Dict[str, Any],
+        prefix: str,
+        fqn: str,
+    ) -> None:
+        clean_key = fqn
+        clean_prefix = clean_tensor_name(prefix)
+        # Strip the prefix from the key if needed, since buffer and parameter
+        # names do not include the prefix (they are not computed within the
+        # `state_dict` call).
+        if clean_key.startswith(clean_prefix):
+            clean_key = clean_key[len(clean_prefix) :]
+
+        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
+        if not getattr(state_dict[fqn], "_has_been_cloned", False):
+            try:
+                state_dict[fqn] = state_dict[fqn].clone().detach()
+                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
+            except BaseException as e:
+                warnings.warn(
+                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
+                    "This may mean that this state_dict entry could point to invalid "
+                    "memory regions after returning from state_dict() call if this "
+                    "parameter is managed by FSDP. Please check clone "
+                    f"implementation of {fqn}. Error: {str(e)}"
+                )
+
+    return _common_unshard_post_state_dict_hook(
+        module, fsdp_state, state_dict, prefix, param_hook
+    )
+
+
+def _full_pre_load_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> None:
+    _lazy_init(fsdp_state, module)
+    if _should_unshard_params(fsdp_state):
+        with SimpleProfiler.profile("_enter_unshard_params_ctx"):
+            _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
+    # Add FSDP_PREFIX only for wrapper-based FSDP.
+    if not _is_composable(fsdp_state):
+        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
+
+
+def _full_post_load_state_dict_hook(
+    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
+) -> None:
+    if _should_unshard_params(fsdp_state):
+        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
+            _exit_unshard_params_ctx(module, fsdp_state)
+
+
+def _local_pre_state_dict_hook(
+    fsdp_state: _FSDPState,
+    module: nn.Module,
+    *args,
+    **kwargs,
+) -> None:
+    """
+    Hook that runs before model.state_dict() is called. Right now, a pre-state_dict
+    hook is not supported by PyTorch core, so this API is called from
+    `_local_post_state_dict_hook()` to simulate the case.
+    """
+    if (
+        _has_fsdp_params(fsdp_state, module)
+        and not _module_handle(fsdp_state, module).uses_sharded_strategy
+    ):
+        raise RuntimeError(
+            "``local_state_dict`` can only be used when parameters are flatten "
+            "and sharded."
+        )
+    _common_pre_state_dict_hook(module, fsdp_state)
+
+
+@no_type_check
+def _local_post_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> Dict[str, Any]:
+    """
+    This hook creates a ShardedTensor from the local flat_param and replaces
+    ``state_dict[f"{prefix}{FLAT_PARAM}"]`` with the ShardedTensor. No copy
+    happens; the underlying storage is the same.
+    """
+
+    _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
+    if not _has_fsdp_params(fsdp_state, module):
+        return state_dict
+
+    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
+    # value as the flat_param but it is a pure Tensor because
+    # nn.Module.state_dict() will detach the parameter. Therefore, we need
+    # to get flat_param to get the metadata.
+    assert _module_handle(fsdp_state, module), "Should have returned early"
+    flat_param = _module_handle(fsdp_state, module).flat_param
+    # Constructs a ShardedTensor from the flat_param "without" padding.
+    # Removing the padding allows users to change the number of ranks
+    # when loading the local_state_dict.
+    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
+    shard_offset = flat_param.numel() * fsdp_state.rank
+    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
+    if valid_data_size > 0:
+        # If the FlatParameter were returned, FlatParameter._local_shard would
+        # cause a pickling issue (it can be torch.save'd but not torch.load'ed).
+        # Since there is no benefit in state_dict returning the actual FlatParameter
+        # class, a view (which is a plain tensor) of the FlatParameter is returned.
+        flat_param = flat_param[:valid_data_size].view(valid_data_size)
+        local_shards = [
+            Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
+        ]
+    else:
+        local_shards = []
+    sharded_tensor = init_from_local_shards(
+        local_shards, full_numel, process_group=fsdp_state.process_group
+    )  # type: ignore[assignment]
+    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
+    if fsdp_state._state_dict_config.offload_to_cpu:
+        sharded_tensor = sharded_tensor.cpu()
+    state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
+    return state_dict
+
+
+def _local_post_load_state_dict_hook(
+    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
+) -> None:
+    pass
+
+
+def _local_pre_load_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> None:
+    """
+    This hook finds the local flat_param for this FSDP module from the
+    state_dict. The flat_param should be a ShardedTensor. This hook converts
+    the ShardedTensor to a tensor. No copy happens unless padding is required.
+    """
+    _lazy_init(fsdp_state, module)
+    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
+    fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
+    if fqn not in state_dict:
+        assert not _has_fsdp_params(fsdp_state, module), (
+            "No `FlatParameter` in `state_dict` for this FSDP instance "
+            "but it has parameters"
+        )
+        return
+    load_tensor = state_dict[fqn]
+    assert isinstance(
+        load_tensor, ShardedTensor
+    ), "Tensors in local_state_dict should be ShardedTensor."
+
+    # Convert the ShardedTensor to a Tensor.
+    flat_param = _module_handle(fsdp_state, module).flat_param
+    assert flat_param is not None
+    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
+    shards = load_tensor.local_shards()
+    if valid_data_size > 0:
+        assert len(shards), "load_local_state_dict assume one shard per ShardedTensor."
+        load_tensor = shards[0].tensor
+
+        # Get the metadata of the flat_param to decide whether to pad the loaded
+        # tensor.
+        if flat_param._shard_numel_padded > 0:
+            assert load_tensor.numel() < flat_param.numel(), (
+                f"Local shard size = {flat_param.numel()} and the tensor in "
+                f"the state_dict is {load_tensor.numel()}."
+            )
+            load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
+    else:
+        load_tensor = flat_param
+    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
+    state_dict[fqn] = load_tensor
+
+
+def _sharded_pre_state_dict_hook(
+    fsdp_state: _FSDPState,
+    module: nn.Module,
+    *args,
+    **kwargs,
+) -> None:
+    """
+    Hook that runs before model.state_dict() is called. Check
+    ``_full_pre_load_state_dict_hook`` for details.
+    """
+    if (
+        _has_fsdp_params(fsdp_state, module)
+        and not _module_handle(fsdp_state, module).uses_sharded_strategy
+    ):
+        raise RuntimeError(
+            "``sharded_state_dict`` can only be used when parameters are flatten "
+            "and sharded."
+        )
+    _common_pre_state_dict_hook(module, fsdp_state)
+    # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
+    # We have to create ShardedTensor first then move it to CPU.
+    _common_unshard_pre_state_dict_hook(
+        module,
+        fsdp_state,
+        offload_to_cpu=False,
+        rank0_only=False,
+    )
+
+
+@no_type_check
+def _sharded_post_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> Dict[str, Any]:
+    """
+    The hook replaces the unflattened, unsharded parameter in the state_dict
+    with an unflattened, sharded parameter (a ShardedTensor).
+    """
+
+    def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
+        param = state_dict[fqn]
+        if not fsdp_state._state_dict_config._use_dtensor:
+            sharded_tensor = _ext_chunk_tensor(
+                tensor=param,
+                rank=fsdp_state.rank,
+                world_size=fsdp_state.world_size,
+                num_devices_per_node=fsdp_state._device_handle.device_count(),
+                pg=fsdp_state.process_group,
+                fsdp_extension=fsdp_state._fsdp_extension,
+            )
+        else:
+            sharded_tensor = _ext_chunk_dtensor(
+                tensor=param,
+                rank=fsdp_state.rank,
+                device_mesh=fsdp_state._device_mesh,
+                fsdp_extension=fsdp_state._fsdp_extension,
+            )
+        if fsdp_state._state_dict_config.offload_to_cpu:
+            sharded_tensor = sharded_tensor.cpu()
+        state_dict[fqn] = sharded_tensor
+
+    return _common_unshard_post_state_dict_hook(
+        module, fsdp_state, state_dict, prefix, param_hook
+    )
+
+
+@no_type_check
+def _sharded_post_load_state_dict_hook(
+    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
+) -> None:
+    if _has_fsdp_params(fsdp_state, module):
+        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
+            _exit_unshard_params_ctx(module, fsdp_state)
+
+
+@no_type_check
+def _sharded_pre_load_state_dict_hook(
+    module: nn.Module,
+    fsdp_state: _FSDPState,
+    state_dict: Dict[str, Any],
+    prefix: str,
+) -> None:
+    """
+    The hook combines the unflattened, sharded parameters (ShardedTensors) into
+    a new FlatParameter and shards the new FlatParameter into the local chunk.
+    """
+    _lazy_init(fsdp_state, module)
+    if not _is_composable(fsdp_state):
+        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
+    if not _has_fsdp_params(fsdp_state, module):
+        return
+
+    handle = _module_handle(fsdp_state, module)
+    if not handle.uses_sharded_strategy:
+        raise RuntimeError(
+            "load_sharded_state_dict can only be called when parameters "
+            "are flattened and sharded."
+        )
+    fqn_to_param_ext = dict(
+        zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
+    )
+
+    for fqn, _, _ in _param_name_infos(module, fsdp_state):
+        if not _is_composable(fsdp_state):
+            fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}"
+        else:
+            fqn_from_global_root = f"{prefix}{fqn}"
+        try:
+            param = state_dict.pop(fqn_from_global_root)
+        except KeyError:
+            logger.warning(
+                f"Did not find param with FQN {fqn_from_global_root}, skipping it. "  # noqa: G004
+                "The weight will not be filled if you expect it to be."
+            )
+            continue  # TODO: Improve unittesting for state_dict finetuning
+            # cases: https://github.com/pytorch/pytorch/issues/109134
+
+        if not fsdp_state._state_dict_config._use_dtensor:
+            # All-gather the param (ShardedTensor)
+            param, shards = _ext_pre_load_state_dict_transform(
+                param, fsdp_state._fsdp_extension
+            )
+
+            assert len(shards) < 2, (
+                "Expects 0 or 1 shard per rank "
+                f"but got {len(shards)} shards on rank {fsdp_state.rank}."
+            )
+            param_numel = param.size().numel()
+            dim_0_size = param.size()[0]
+            chunk_size = (
+                math.ceil(dim_0_size / fsdp_state.world_size)
+                * param_numel
+                // dim_0_size
+            )
+            if len(shards) == 1:
+                local_tensor = shards[0].tensor.flatten()
+                with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
+                    local_tensor = local_tensor.to(fsdp_state.compute_device)
+                num_padding = chunk_size - local_tensor.numel()
+                if num_padding > 0:
+                    local_tensor = F.pad(local_tensor, [0, num_padding])
+            else:
+                local_tensor = torch.zeros(
+                    chunk_size, dtype=param.dtype, device=fsdp_state.compute_device
+                )
+            tensor = torch.empty(
+                chunk_size * fsdp_state.world_size,
+                dtype=local_tensor.dtype,
+                device=fsdp_state.compute_device,
+            )
+            with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
+                dist.all_gather_into_tensor(
+                    tensor, local_tensor, group=fsdp_state.process_group
+                )
+            tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
+            state_dict[fqn_from_global_root] = tensor
+        else:
+            if param.device != fsdp_state._device_mesh.device_type:
+                param = param.to(fsdp_state._device_mesh.device_type)
+
+            parent_mesh = _mesh_resources.get_parent_mesh(fsdp_state._device_mesh)
+            local_tensor = _ext_all_gather_dtensor(
+                param, parent_mesh, fsdp_state._fsdp_extension
+            )
+
+            if fqn_to_param_ext.get(fqn) is not None:
+                ext = fqn_to_param_ext[fqn]
+                local_tensor = _ext_post_unflatten_transform(
+                    local_tensor, ext, fsdp_state._fsdp_extension
+                )
+            state_dict[fqn_from_global_root] = local_tensor
+
+    with SimpleProfiler.profile("_enter_unshard_params_ctx"):
+        _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
+
+
+@contextlib.contextmanager
+def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
+    old_state_dict_config = fsdp_state._state_dict_config
+    old_state_dict_type = fsdp_state._state_dict_type
+    fsdp_state._state_dict_config = FullStateDictConfig()
+    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+    yield
+    fsdp_state._state_dict_config = old_state_dict_config
+    fsdp_state._state_dict_type = old_state_dict_type
+
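+# Usage sketch: the ``NO_SHARD`` branches in the hooks below wrap their dispatch
+# in this context so that the FULL_STATE_DICT implementations run instead, e.g.
+#
+#     >>> with _replace_with_full_state_dict_type(fsdp_state):
+#     ...     sd = _full_post_state_dict_hook(module, fsdp_state, state_dict, prefix)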
+
+@no_type_check
+@torch.no_grad()
+def _post_state_dict_hook(
+    module: nn.Module,
+    state_dict: Dict[str, Any],
+    prefix: str,
+    *args: Any,
+) -> Dict[str, Any]:
+    """
+    _post_state_dict_hook() is called after the state_dict() of this
+    FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
+    what postprocessing will be done.
+    """
+    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _post_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
+        }
+        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
+
+    if fsdp_state._is_root:
+        logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
+        for key, tensor in sorted(processed_state_dict.items()):
+            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
+                local_shape = tensor.shape
+                if isinstance(tensor, ShardedTensor):
+                    local_shape = None
+                    shards = tensor.local_shards()
+                    if shards:
+                        local_shape = shards[0].tensor.shape
+                elif isinstance(tensor, DTensor):
+                    local_shape = tensor.to_local().shape
+                logger.info(
+                    "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
+                    key,
+                    type(tensor),
+                    tensor.shape,
+                    local_shape,
+                    tensor.dtype,
+                    tensor.device,
+                )
+
+    return processed_state_dict
+
+
+@no_type_check
+@torch.no_grad()
+def _pre_state_dict_hook(
+    module: nn.Module,
+    *args,
+    **kwargs,
+) -> None:
+    """
+    This is called before the core state dict saving logic of ``module``.
+    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
+    be done.
+    """
+    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        _set_use_dtensor(fsdp_state)
+        context = contextlib.nullcontext()
+
+    with context:
+        _pre_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
+        }
+        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
+            fsdp_state,
+            module,
+            *args,
+            **kwargs,
+        )
+
+
+@no_type_check
+def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
+    # If device_mesh is passed in when initializing FSDP, we automatically set the
+    # _use_dtensor flag to True for ShardedStateDictConfig().
+    if getattr(fsdp_state, "_device_mesh", None):
+        state_dict_type = fsdp_state._state_dict_type
+        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
+            raise RuntimeError(
+                "Found state_dict_type LOCAL_STATE_DICT",
+                "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
+                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
+            )
+        else:
+            fsdp_state._state_dict_config._use_dtensor = True
+
+
+@no_type_check
+@torch.no_grad()
+def _pre_load_state_dict_hook(
+    module: nn.Module,
+    state_dict: Dict[str, Any],
+    prefix: str,
+    *args: Any,
+) -> None:
+    """
+    This is called before ``module._load_from_state_dict()``.
+    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
+    be done.
+    """
+    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        _set_use_dtensor(fsdp_state)
+        context = contextlib.nullcontext()
+
+    _lazy_init(fsdp_state, module)
+    if fsdp_state._is_root:
+        SimpleProfiler.reset()
+
+    with context:
+        _pre_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        if fsdp_state._device_handle.is_available():
+            fsdp_state._device_handle.synchronize()
+        # Dispatch into state_dict specific implementation of pre-hook.
+        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
+
+
+@no_type_check
+@torch.no_grad()
+def _post_load_state_dict_hook(
+    module: nn.Module,
+    incompatible_keys: Tuple[List[str], List[str]],
+    *args: Any,
+) -> None:
+    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
+    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
+
+    with context:
+        _post_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        # Dispatch into state_dict type specific implementation of post-hook for
+        # loading state_dict.
+        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
+
+    # When reporting incompatible keys, trim FSDP prefixes.
+    missing_keys = incompatible_keys[0]
+    unexpected_keys = incompatible_keys[1]
+    for i in range(len(missing_keys)):
+        missing_keys[i] = clean_tensor_name(missing_keys[i])
+
+    for i in range(len(unexpected_keys)):
+        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])
+
+    if fsdp_state._is_root:
+        SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")
+
+
+def _register_all_state_dict_hooks(state: _FSDPState):
+    """
+    Registers pre-save, post-save, pre-load, and post-load state dict hooks.
+    """
+    for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
+        ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
+        ("_register_state_dict_hook", _post_state_dict_hook, {}),
+        (
+            "_register_load_state_dict_pre_hook",
+            _pre_load_state_dict_hook,
+            {"with_module": True},
+        ),
+        ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
+    ):
+        _register_state_dict_hooks_base(
+            state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
+        )
+
+
+@no_type_check
+def _register_state_dict_hooks_base(
+    state: _FSDPState,
+    hook_registration_fn_name: str,
+    hook: Callable,
+    hook_registration_fn_kwargs: Dict[str, Any],
+) -> None:
+    """Registers ``hook`` using ``hook_registration_fn``."""
+    if not _is_composable(state):
+        getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs)
+    else:
+        handle = state._handle
+        if handle:
+            getattr(handle._fully_sharded_module, hook_registration_fn_name)(
+                hook, **hook_registration_fn_kwargs
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d1b1a4ee7b057fc0f00877448e65a47074fa1be
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py
@@ -0,0 +1,237 @@
+import functools
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
+
+import torch
+import torch.nn as nn
+
+
+@dataclass
+class TracingConfig:
+    """
+    This represents a symbolic tracing configuration.
+
+    Args:
+        tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
+            use for symbolic tracing. The default value is the native
+            :class:`torch.fx.Tracer` constructed with default arguments.
+            However, the user may want to pass a different value such as the
+            ``HFTracer`` for models in the HuggingFace Transformers_ library.
+            .. _Transformers: https://huggingface.co/docs/transformers/index
+        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
+            should not be treated as ``torch.fx.Proxy`` when tracing the
+            module ``forward()``. Passing ``concrete_args`` allows partially
+            specializing the forward, e.g. to remove control flow or data
+            structures. This ``concrete_args`` here is the same argument used
+            in :meth:`~torch.fx.Tracer.trace`.
+    """
+
+    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
+    concrete_args: Optional[Dict[str, Any]] = None
+
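+# Construction sketches; ``mask`` is a hypothetical keyword argument of the
+# traced module's ``forward()``:
+#
+#     >>> config = TracingConfig()  # default torch.fx.Tracer, no concrete args
+#     >>> config = TracingConfig(concrete_args={"mask": None})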
+
+class _ParamUsageInfo(NamedTuple):
+    """
+    This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
+    execution information. The ``dict`` maps modules to a list of these
+    ``_ParamUsageInfo`` instances, where each instance represents a group of
+    parameters used together.
+
+    Specifically, for each module key in the ``dict``, each instance of this
+    class represents either:
+    (1) the module and some sublist of its ``named_parameters()`` used
+    together in execution (see ``_patched_create_proxy()``), or
+    (2) a submodule and all of ``submodule.named_parameters()`` (see
+    ``_patched_call_module()``).
+
+    Type (1) corresponds to directly using parameters in ops without calling
+    ``forward()``, and type (2) corresponds to calling ``forward()``. The
+    mapped-to lists in the ``dict`` follow the execution order.
+    """
+
+    module: nn.Module
+    named_params: List[Tuple[str, nn.Parameter]]
+
+
+class _ExecutionInfo:
+    """
+    This represents the execution order information from the forward pass.
+
+    Attributes:
+        curr_module (nn.Module): Current module being traced.
+        module_forward_order (List[nn.Module]): The modules in (pre-)forward
+            order, i.e. the order in which their ``forward()`` methods are
+            called. Each call to a module's ``forward()`` corresponds to one
+            element in the list.
+        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
+            Maps a module to a list of module execution infos. See
+            :class:`_ParamUsageInfo` for details.
+        param_forward_order (List[nn.Parameter]): The parameters in forward
+            execution order, where only a parameter's first participation is
+            included.
+        visited_params (Set[nn.Parameter]): The parameters visited so far
+            during the trace. This is only used during tracing for fast
+            membership check. Invariant: The parameters in
+            ``param_forward_order`` are exactly those in ``visited_params``.
+    """
+
+    def __init__(self, root_module: nn.Module) -> None:
+        self.curr_module: nn.Module = root_module
+        self.module_forward_order: List[nn.Module] = [root_module]
+        self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
+            root_module: []
+        }
+        self.param_forward_order: List[nn.Parameter] = []
+        self.visited_params: Set[nn.Parameter] = set()
+
+
+class _ExecOrderTracer:
+    def __init__(self) -> None:
+        self.exec_info: Optional[_ExecutionInfo] = None
+
+    @contextmanager
+    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
+        self.exec_info = _ExecutionInfo(root_module)
+        orig_call_module = tracer.call_module
+        orig_create_proxy = tracer.create_proxy
+        tracer.call_module = functools.partial(
+            self._patched_call_module, orig_call_module, self.exec_info
+        )
+        fqn_to_param = dict(root_module.named_parameters())
+        tracer.create_proxy = functools.partial(
+            self._patched_create_proxy,
+            orig_create_proxy,
+            self.exec_info,
+            fqn_to_param,
+        )
+        try:
+            yield
+        finally:
+            tracer.call_module = orig_call_module
+            tracer.create_proxy = orig_create_proxy
+
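+    # A usage sketch for ``patch_tracer``, assuming ``model`` is the root module
+    # being traced:
+    #
+    #     >>> tracer = torch.fx.Tracer()
+    #     >>> exec_order_tracer = _ExecOrderTracer()
+    #     >>> with exec_order_tracer.patch_tracer(tracer, model):
+    #     ...     graph = tracer.trace(model)
+    #     >>> exec_order_tracer.exec_info.module_forward_order  # modules in order
+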
+    def _patched_call_module(
+        self,
+        call_module: Callable,
+        exec_info: _ExecutionInfo,
+        # Below are the expected arguments to `call_module()`
+        module: nn.Module,
+        forward: Callable,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> Any:
+        """
+        Overrides ``call_module`` to save execution information to
+        ``exec_info``. Note that ``call_module`` is called during symbolic
+        tracing for each non-root module.
+
+        Args:
+            call_module (Callable): Original ``call_module`` to override.
+            exec_info (_ExecutionInfo): Used to record execution information.
+            module (nn.Module): Module corresponding to this ``call_module``.
+            forward (Callable): ``forward()`` method of ``module`` to be called
+                for this ``call_module``.
+            args (Tuple[Any, ...]): Positional arguments for ``forward``.
+            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.
+
+        Returns:
+            Same return value as ``call_module``.
+        """
+        exec_info.module_forward_order.append(module)
+        named_params = list(module.named_parameters())
+        curr_module = exec_info.curr_module
+        if named_params:
+            assert (
+                curr_module in exec_info.module_to_param_usage_infos
+            ), "The current module should have already been processed by a patched `call_module`"
+            exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
+                _ParamUsageInfo(module, named_params)
+            )
+        prev_curr_module = curr_module
+        exec_info.curr_module = module
+        exec_info.module_to_param_usage_infos[module] = []
+        output = call_module(module, forward, args, kwargs)
+        exec_info.curr_module = prev_curr_module
+        return output
+
+    def _patched_create_proxy(
+        self,
+        create_proxy: Callable,
+        exec_info: _ExecutionInfo,
+        fqn_to_param: Dict[str, nn.Parameter],
+        # Below are the expected arguments to `create_proxy()`
+        kind: str,
+        target: torch.fx.node.Target,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
+    ) -> torch.fx.Proxy:
+        """
+        Overrides ``create_proxy`` to save execution information to
+        ``exec_info``. Note that ``create_proxy`` is called during symbolic
+        tracing for each leaf function/method/module.
+
+        Args:
+            create_proxy (Callable): Original ``create_proxy`` to override.
+            exec_info (_ExecutionInfo): Used to record execution information.
+            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
+                root module's ``named_parameters()`` with FQN as key and
+                parameter as value.
+            kind (str): Kind of the target method ('call_function',
+                'call_method', 'get_attr', 'call_module', 'placeholder', or
+                'output'). See :class:`torch.fx.Graph` for details. This is
+                passed to ``create_proxy``.
+            target (torch.fx.node.Target): Contains the string name of the
+                function/method/module. This is passed to ``create_proxy``.
+            args (Tuple[Any, ...]): Positional arguments for the function/
+                method/module. This is passed to ``create_proxy``.
+            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
+                module. This is passed to ``create_proxy``
+            name (Optional[str]): An optional string name for the ``Node``
+                created in ``create_proxy``. This is passed to
+                ``create_proxy``.
+            type_expr (Optional[Any]): An optional type annotation representing
+                the Python type that the output of the node has. This is passed
+                to ``create_proxy``.
+            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
+                An alternative proxy constructor used in ``create_proxy``. This
+                is passed to ``create_proxy``.
+
+        Returns:
+            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
+        """
+        proxy = create_proxy(
+            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
+        )
+        curr_module = exec_info.curr_module
+        if kind in ("call_function", "call_method"):
+            if args is not None:
+                named_params: List[Tuple[str, nn.Parameter]] = []
+                for arg in args:
+                    if (
+                        isinstance(arg, torch.fx.Proxy)
+                        and arg.node.target in fqn_to_param
+                    ):
+                        param = fqn_to_param[arg.node.target]
+                        named_params.append((arg.node.target, param))
+                        if param not in exec_info.visited_params:
+                            exec_info.visited_params.add(param)
+                            exec_info.param_forward_order.append(param)
+                if named_params:
+                    exec_info.module_to_param_usage_infos[curr_module].append(
+                        _ParamUsageInfo(curr_module, named_params)
+                    )
+        elif kind == "call_module":
+            named_params = list(curr_module.named_parameters())
+            if named_params:
+                exec_info.module_to_param_usage_infos[curr_module].append(
+                    _ParamUsageInfo(curr_module, named_params)
+                )
+            for _, param in named_params:
+                if param not in exec_info.visited_params:
+                    exec_info.visited_params.add(param)
+                    exec_info.param_forward_order.append(param)
+        return proxy
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5db4fda13d369acf5522aef9a594abeb3db7ff
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py
@@ -0,0 +1,113 @@
+"""
+NOTE: This file must be imported like
+``import torch.distributed.fsdp._traversal_utils`` and not like
+``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular
+imports. For brevity, we may import the file as ``traversal_utils``.
+"""
+
+import collections
+from typing import Deque, List, Set, Tuple
+
+import torch.nn as nn
+from torch.distributed._composable.contract import _get_registry
+from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state
+
+
+"""
+[Note: FSDP State Traversal]
+For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel``
+module wrapping a fully sharded module, and for the non-wrapper code path,
+``_FSDPState`` is an object that gets embedded on a fully sharded module.
+See [Note: Fully Sharded Module] for the definition.
+
+There are three common traversal idioms: Given a root module,
+- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree.
+- ``_get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the
+tree (i.e. those with ``_is_root == True``).
+- ``_get_fsdp_handles()`` returns all ``FlatParamHandle`` s in the tree.
+
+All of these methods must take in the root module (i.e. an ``nn.Module``) and
+not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph
+traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal.
+"""
+
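+# A usage sketch of the idioms above, assuming ``model`` is an FSDP-wrapped (or
+# ``fully_shard``-applied) root module:
+#
+#     >>> import torch.distributed.fsdp._traversal_utils as traversal_utils
+#     >>> states = traversal_utils._get_fsdp_states(model)    # all ``_FSDPState`` s
+#     >>> handles = traversal_utils._get_fsdp_handles(model)  # all ``FlatParamHandle`` s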
+
+def _composable(module: nn.Module) -> bool:
+    """
+    Returns whether ``module`` can compose with ``fully_shard``.
+    """
+    # TODO: Add any other composable APIs that are mutually exclusive.
+    registry = _get_registry(module)
+    if registry is None:
+        return True
+    return "replicate" not in registry
+
+
+# TODO (awgu): We may be able to remove this function if we retire the
+# `use_orig_params=False` code path, since so far we only need the module for
+# `FlatParameter` registration, which is not needed for `use_orig_params=True`.
+def _get_fsdp_states_with_modules(
+    module: nn.Module,
+) -> Tuple[List[_FSDPState], List[nn.Module]]:
+    """
+    Returns a tuple containing:
+    1. A list of the ``_FSDPState`` instances in the module tree rooted at
+    ``module`` without any duplicates and following the ``module.modules()``
+    traversal order (which is assumed to be depth-first).
+    2. A corresponding list of the modules owning the states in the first list.
+
+    For the wrapper code path, both returned lists are the same, each
+    containing all ``FullyShardedDataParallel`` instances. For the composable
+    code path, this returns a list of all composable state instances and a list
+    of the corresponding fully sharded modules. See [Note: Fully Sharded
+    Module].
+
+    NOTE: The traversal does not proceed into any module annotated by an
+    incompatible API (e.g. ``replicate``).
+    """
+    fsdp_states: List[_FSDPState] = []
+    fsdp_modules: List[nn.Module] = []
+    # Track the visited FSDP states since multiple modules may share the same
+    # one and we want to return a de-duplicated list
+    visited_fsdp_states: Set[_FSDPState] = set()
+    # Track the visited modules in case of shared modules, which implies the
+    # module graph is no longer a tree
+    visited_modules: Set[nn.Module] = set()
+
+    # Perform depth-first search from `module` to ensure that we do not
+    # traverse into an incompatible API's subtree (use DFS instead of BFS to
+    # match `.modules()` order)
+    deque: Deque[nn.Module] = collections.deque([module])
+    while deque:
+        submodule = deque.popleft()
+        visited_modules.add(submodule)
+        if not _composable(submodule):
+            continue
+        for child_module in reversed(list(submodule.children())):
+            if child_module not in visited_modules:
+                deque.appendleft(child_module)
+        optional_state = _get_module_fsdp_state(submodule)
+        if optional_state is not None and optional_state not in visited_fsdp_states:
+            visited_fsdp_states.add(optional_state)
+            fsdp_states.append(optional_state)
+            fsdp_modules.append(submodule)
+    return fsdp_states, fsdp_modules
+
+
+def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
+    """See :func:`_get_fsdp_states_with_modules`."""
+    fsdp_states, _ = _get_fsdp_states_with_modules(module)
+    return fsdp_states
+
+
+def _get_fsdp_handles(module: nn.Module) -> List:
+    """
+    Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
+    following the rules in :func:`_get_fsdp_states`.
+    """
+    handles = [
+        fsdp_state._handle
+        for fsdp_state in _get_fsdp_states(module)
+        if fsdp_state._handle is not None
+    ]
+    return handles
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..053a73b3c254c33120d3b9dc8fc12dbaf272752d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py
@@ -0,0 +1,357 @@
+import contextlib
+import warnings
+from typing import cast, Generator
+
+import torch
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.nn as nn
+from torch.distributed.fsdp._common_utils import (
+    _FSDPState,
+    _has_fsdp_params,
+    _module_handle,
+    HandleTrainingState,
+    TrainingState,
+)
+from torch.distributed.fsdp._runtime_utils import (
+    _get_fsdp_root_states_with_modules,
+    _lazy_init,
+    _reset_flat_param_grad_info_if_needed,
+    _reshard,
+    _reshard_grads,
+    _unshard,
+    _unshard_grads,
+)
+from torch.distributed.utils import _p_assert
+
+from ._flat_param import FlatParamHandle
+
+FLAT_PARAM = "_flat_param"
+
+
+@torch.no_grad()
+def _writeback_to_local_shard(
+    handle: FlatParamHandle,
+    writeback_grad: bool,
+):
+    """
+    For the handle, writes back this rank's shard of the unsharded
+    flattened parameter to the sharded flattened parameter. If
+    ``writeback_grad=True``, then writes back to the sharded gradient as
+    well.
+
+    Precondition: The handle's ``FlatParameter`` 's data points to the
+    padded unsharded flattened parameter.
+    """
+
+    def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
+        if handle.uses_sharded_strategy:
+            # For sharded strategies, get the *unpadded* shard instead of
+            # the *padded* shard to persist user changes to the padding
+            # (though FSDP does not explicitly support this)
+            shard, _ = FlatParamHandle._get_unpadded_shard(
+                flat_param_or_grad,
+                handle.rank,
+                handle.world_size,
+            )
+            return shard
+        # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
+        # so we write it back directly
+        return flat_param_or_grad
+
+    param_shard = _get_shard(handle.flat_param)
+    handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
+    if writeback_grad:
+        existing_grad = handle.sharded_grad
+        if existing_grad is not None:
+            assert handle.flat_param.grad is not None
+            grad_shard = _get_shard(handle.flat_param.grad)
+            existing_grad[: grad_shard.numel()].copy_(grad_shard)
+
+
+def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
+    """
+    De-registers the flattened parameter from the wrapped module, hiding it
+    from ``nn.Module`` methods.
+
+    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
+    attribute but dynamically change whether it is visible to ``nn.Module``
+    methods.
+    """
+    if _has_fsdp_params(state, module):
+        # TODO: figure out the case for the composable APIs.
+        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)
+
+
+def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
+    """
+    Registers the flattened parameter to the wrapped module, making it
+    visible to ``nn.Module`` methods.
+
+    We do not use :meth:`nn.Module.register_parameter` because we want
+    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
+    it is visible to ``nn.Module`` methods.
+    """
+    handle = _module_handle(state, module)
+    if _has_fsdp_params(state, module):
+        # TODO: figure out the case for the composable APIs.
+        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param
+
+
+@contextlib.contextmanager
+def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
+    """
+    Assumes that the flattened parameter is unsharded. When in the context,
+    de-registers the flattened parameter and unflattens the original
+    parameters as ``nn.Parameter`` views into the flattened parameter.
+    After the context, re-registers the flattened parameter and restores
+    the original parameters as ``Tensor`` views into the flattened
+    parameter.
+    """
+    handle = _module_handle(state, module)
+    if not handle:
+        yield
+    else:
+        _deregister_flat_param(state, module)
+        try:
+            with handle.unflatten_as_params():
+                yield
+        finally:
+            if not handle._use_orig_params:
+                _register_flat_param(state, module)
+
+
+def _validate_unshard_params_args(
+    state: _FSDPState,
+    writeback: bool,
+    rank0_only: bool,
+    offload_to_cpu: bool,
+    with_grads: bool,
+) -> None:
+    if with_grads and (offload_to_cpu or not state._use_orig_params):
+        raise NotImplementedError(
+            f"with_grads={with_grads}, "
+            f"use_orig_params={state._use_orig_params}, "
+            f"offload_to_cpu={offload_to_cpu} "
+            f"is not supported yet"
+        )
+    if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
+        raise NotImplementedError(
+            "offload_to_cpu=True and NO_SHARD is not supported yet"
+        )
+    if writeback and rank0_only:
+        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
+        # persist the changes.
+        raise NotImplementedError(
+            "writeback=True and rank0_only=True is not supported yet"
+        )
+    if offload_to_cpu and not rank0_only:
+        warnings.warn(
+            "offload_to_cpu=True and rank0_only=False may result in the"
+            "unsharded parameters being redundantly copied to CPU memory for "
+            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
+            "recommend using offload_to_cpu=True with rank0_only=True."
+        )
+
+
+@contextlib.contextmanager
+def _unshard_fsdp_state_params(
+    module: nn.Module,
+    state: _FSDPState,
+    writeback: bool,
+    rank0_only: bool,
+    offload_to_cpu: bool,
+    with_grads: bool,
+):
+    """
+    This unshards the parameters for a single FSDP state ``state`` that
+    corresponds to ``module``.
+    """
+    _validate_unshard_params_args(
+        state, writeback, rank0_only, offload_to_cpu, with_grads
+    )
+    state._device_handle.synchronize()
+    # If handles are shared by other module(s), the handle may be already unsharded.
+    maybe_handle = _module_handle(state, module)
+    handle = None
+    if (
+        maybe_handle
+        and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
+    ):
+        handle = maybe_handle
+    if not handle:
+        yield
+        return
+
+    assert (
+        handle._training_state == HandleTrainingState.IDLE
+    ), f"Expects the handle training to be IDLE but got {handle._training_state}"
+
+    handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
+
+    _reset_flat_param_grad_info_if_needed(handle)
+    free_unsharded_flat_param = handle.needs_unshard()
+    # No need to call `wait_stream()` since we unshard in the computation
+    # stream directly
+    computation_stream = state._device_handle.current_stream()
+    _unshard(state, handle, computation_stream, computation_stream)
+    if with_grads:
+        _unshard_grads(handle)
+
+    if rank0_only and state.rank != 0:
+        # Free the unsharded flattened parameter early
+        _reshard(state, handle, free_unsharded_flat_param)
+        if with_grads:
+            _reshard_grads(handle)
+        try:
+            yield
+        finally:
+            handle._training_state = HandleTrainingState.IDLE
+    else:
+        # Unflatten the unsharded flattened parameters
+        with contextlib.ExitStack() as stack:
+            # Invariant: rank == 0 or !rank0_only
+            if offload_to_cpu and handle.uses_sharded_strategy:
+                stack.enter_context(handle.to_cpu())
+                # NOTE: Since PyTorch enforces that a parameter and its
+                # gradients need to match metadata (e.g. device), we must
+                # move gradients to CPU *after* we move parameters.
+            # NOTE: This assumes 1 `FlatParameter`
+            if not state._use_orig_params:
+                stack.enter_context(_unflatten_as_params(state, module))
+            try:
+                yield
+            finally:
+                stack.close()
+                if writeback:
+                    _writeback_to_local_shard(handle, with_grads)
+                _reshard(state, handle, free_unsharded_flat_param)
+                if with_grads:
+                    _reshard_grads(handle)
+                handle._training_state = HandleTrainingState.IDLE
+
+
+@contextlib.contextmanager
+def _unshard_params_recurse(
+    module: nn.Module,
+    state: _FSDPState,
+    recurse: bool,
+    writeback: bool,
+    rank0_only: bool,
+    offload_to_cpu: bool,
+    with_grads: bool,
+):
+    """
+    This is a helper for :func:`_unshard_params` that recursively calls
+    :func:`_unshard_fsdp_state_params` on FSDP states if ``recurse=True``.
+    NOTE: This runs lazy initialization.
+    """
+    _validate_unshard_params_args(
+        state, writeback, rank0_only, offload_to_cpu, with_grads
+    )
+    if recurse:
+        with contextlib.ExitStack() as stack:
+            # TODO (awgu): The traversal function does not traverse through
+            # incompatible composable APIs. Verify if this is the desired
+            # behavior for this function.
+            for state, fsdp_module in zip(
+                *traversal_utils._get_fsdp_states_with_modules(module)
+            ):
+                stack.enter_context(
+                    _unshard_params_recurse(
+                        module=fsdp_module,
+                        state=state,
+                        recurse=False,
+                        writeback=writeback,
+                        rank0_only=rank0_only,
+                        offload_to_cpu=offload_to_cpu,
+                        with_grads=with_grads,
+                    )
+                )
+            yield
+        return
+    _lazy_init(state, module)
+    if state.training_state == TrainingState.FORWARD_BACKWARD:
+        raise AssertionError(
+            "Cannot manually unshard parameters during forward/backward"
+        )
+    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
+        raise AssertionError(
+            "Cannot manually unshard parameters when already unsharding parameters"
+        )
+    with _unshard_fsdp_state_params(
+        module=module,
+        state=state,
+        writeback=writeback,
+        rank0_only=rank0_only,
+        offload_to_cpu=offload_to_cpu,
+        with_grads=with_grads,
+    ):
+        try:
+            state.training_state = TrainingState.SUMMON_FULL_PARAMS
+            yield
+        finally:
+            state.training_state = TrainingState.IDLE
+
+
+@contextlib.contextmanager
+def _unshard_params(
+    module: nn.Module,
+    recurse: bool,
+    writeback: bool,
+    rank0_only: bool,
+    offload_to_cpu: bool,
+    with_grads: bool,
+):
+    """
+    This unshards FSDP-managed parameters for all modules with FSDP applied in
+    the module tree rooted at ``module``.
+    """
+    root_fsdp_states, root_fsdp_modules = _get_fsdp_root_states_with_modules(module)
+    with contextlib.ExitStack() as stack:
+        for root_fsdp_state, root_fsdp_module in zip(
+            root_fsdp_states, root_fsdp_modules
+        ):
+            stack.enter_context(
+                _unshard_params_recurse(
+                    module=root_fsdp_module,
+                    state=root_fsdp_state,
+                    recurse=recurse,
+                    writeback=writeback,
+                    rank0_only=rank0_only,
+                    offload_to_cpu=offload_to_cpu,
+                    with_grads=with_grads,
+                )
+            )
+        yield
+    return
+
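+# As an informal illustration, the public
+# `FullyShardedDataParallel.summon_full_params` context manager is built on top
+# of `_unshard_params` above; typical caller-side usage looks roughly like
+#
+#     with FullyShardedDataParallel.summon_full_params(model, writeback=False):
+#         # Parameters are unsharded and exposed as the original, unflattened
+#         # parameters inside this context.
+#         total_numel = sum(p.numel() for p in model.parameters())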
+
+def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
+    """
+    Deregisters the original parameters; registers the ``FlatParameter``.
+    """
+    handle = _module_handle(state, module)
+    if not handle:
+        return
+    _p_assert(
+        handle._use_orig_params,
+        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
+        f"handle: {handle._use_orig_params}",
+    )
+    handle._deregister_orig_params()
+    _register_flat_param(state, module)
+
+
+def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
+    """
+    Deregisters the ``FlatParameter``; registers the original parameters.
+    """
+    handle = _module_handle(state, module)
+    if not handle:
+        return
+    _deregister_flat_param(state, module)
+    if handle.is_sharded(handle.flat_param):
+        handle._use_sharded_views()
+        handle._use_sharded_grad_views()
+    else:
+        handle._use_unsharded_views(as_params=True)
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d3452edefabe1e5dd4ec1cd31f27a0f436ffa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py
@@ -0,0 +1,262 @@
+import collections
+import functools
+import inspect
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
+
+import torch.nn as nn
+from torch.distributed.fsdp._common_utils import (
+    _get_module_fsdp_state,
+    _override_module_mixed_precision,
+)
+
+from torch.distributed.fsdp.wrap import (
+    _construct_wrap_fn,
+    _or_policy,
+    _Policy,
+    _post_order_apply,
+    _recursive_wrap,
+    _run_mixed_precision_override_policy,
+    _wrap_module_cls_individually,
+)
+
+
+def _auto_wrap(
+    root_module: nn.Module,
+    policy: Union[Callable, _Policy],
+    ignored_modules: Set[nn.Module],
+    ignored_params: Set[nn.Parameter],
+    root_kwargs: Dict[str, Any],
+    fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
+):
+    """
+    Auto wraps modules in ``root_module`` 's tree according to ``policy``
+    following a post-order traversal.
+
+    Precondition: ``root_kwargs`` should contain all arguments except
+    ``module``. This function accepts the kwargs dict directly since it gets
+    forwarded into the post-order traversal function.
+    """
+    mixed_precision = root_kwargs["mixed_precision"]
+    is_wrapper = inspect.isclass(fsdp_fn)
+    # TODO: We may relax this no-nested-wrapping constraint to support manual
+    # wrapping followed by auto wrapping.
+    _check_nested_wrapping(root_module)
+
+    if isinstance(policy, _Policy):
+        root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
+        target_module_to_kwargs = policy._run_policy(
+            root_module, ignored_modules, root_kwargs
+        )
+        if mixed_precision is not None:
+            target_module_to_kwargs = _run_mixed_precision_override_policy(
+                root_module,
+                mixed_precision._module_classes_to_ignore,
+                ignored_modules,
+                root_kwargs,
+                target_module_to_kwargs,
+            )
+            overridden_module_classes = _override_module_mixed_precision(
+                root_module, mixed_precision._module_classes_to_ignore
+            )
+            _warn_on_overridden_mixed_precision(overridden_module_classes)
+        use_orig_params = root_kwargs.get("use_orig_params", False)
+        _validate_frozen_params(
+            root_module,
+            set(target_module_to_kwargs.keys()),
+            ignored_params,
+            use_orig_params,
+        )
+        wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
+        _post_order_apply(root_module, wrap_fn)
+        return
+
+    recursive_wrap_kwargs = {
+        "module": root_module,
+        "auto_wrap_policy": policy,
+        "wrapper_cls": fsdp_fn,
+        "ignored_modules": ignored_modules,
+        "ignored_params": ignored_params,
+        "only_wrap_children": True,
+    }
+    if mixed_precision is not None:
+        # Wrap modules of the ignored types separately and register forward
+        # hooks to cast to fp32 and back to the original dtype, respectively
+        overridden_module_classes = _override_module_mixed_precision(
+            root_module, mixed_precision._module_classes_to_ignore
+        )
+        policy = functools.partial(
+            _or_policy,
+            policies=[
+                policy,
+                partial(
+                    _wrap_module_cls_individually,
+                    module_classes=mixed_precision._module_classes_to_ignore,
+                ),
+            ],
+        )
+        recursive_wrap_kwargs["auto_wrap_policy"] = policy
+        _warn_on_overridden_mixed_precision(overridden_module_classes)
+    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
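+
+# As an informal illustration of the caller side, `_auto_wrap` is invoked from
+# the FSDP constructor (or `fully_shard`) when a policy is given, e.g. roughly
+#
+#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+#     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+#     FSDP(model, auto_wrap_policy=ModuleWrapPolicy({MyBlock}))
+#
+# where `MyBlock` stands in for the submodule class whose instances should each
+# become their own FSDP unit.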
+
+
+def _check_nested_wrapping(root_module: nn.Module):
+    for module_name, module in root_module.named_modules():
+        if _get_module_fsdp_state(module) is not None:
+            raise ValueError(
+                "FSDP auto wrapping requires modules to not already have "
+                f"FSDP applied but found {module_name} in\n{root_module}"
+            )
+
+
+def _warn_on_overridden_mixed_precision(
+    overridden_module_classes: Set[Type[nn.Module]],
+):
+    if len(overridden_module_classes) == 0:
+        return
+    warnings.warn(
+        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
+        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
+        "These modules will be wrapped as separate FSDP instacnes with mixed "
+        "precision disabled."
+    )
+
+
+def _validate_frozen_params(
+    root_module: nn.Module,
+    modules_to_wrap: Set[nn.Module],
+    ignored_params: Set[nn.Parameter],
+    use_orig_params: bool,
+):
+    """
+    This checks that, given ``modules_to_wrap``, each module would manage
+    parameters that are uniformly frozen or non-frozen. This uniformity
+    requirement is strict for ``use_orig_params=False`` (hard error) and highly
+    recommended for ``use_orig_params=True`` (user warning).
+    """
+    post_order_named_modules = _get_post_order_named_modules(root_module)
+    visited_modules: Set[nn.Module] = set()
+    for module_name, module in post_order_named_modules:
+        if module in modules_to_wrap:
+            param_to_fqn = _get_managed_param_to_fqn(
+                module, ignored_params, visited_modules, module_name
+            )
+            frozen_param_fqns: List[str] = []
+            frozen_param_numel = 0
+            nonfrozen_param_fqns: List[str] = []
+            nonfrozen_param_numel = 0
+            for param, fqn in param_to_fqn.items():
+                if param.requires_grad:
+                    nonfrozen_param_fqns.append(fqn)
+                    nonfrozen_param_numel += param.numel()
+                else:
+                    frozen_param_fqns.append(fqn)
+                    frozen_param_numel += param.numel()
+            if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
+                msg = f"{module_name} has both parameters with requires_grad=True and False."
+                if use_orig_params:
+                    total_param_numel = frozen_param_numel + nonfrozen_param_numel
+                    msg += (
+                        " We do not recommend wrapping such modules since "
+                        "the gradient memory usage will be higher than expected "
+                        f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
+                        "before sharding via reduce-scatter). "
+                    )
+                else:
+                    msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
+                msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
+                msg += (
+                    f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
+                    f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
+                )
+                if use_orig_params:
+                    warnings.warn(msg)
+                else:
+                    raise ValueError(msg)
+
+
+def _get_post_order_named_modules(
+    root_module: nn.Module,
+) -> List[Tuple[str, nn.Module]]:
+    """
+    This returns the named modules following a post-order traversal, which is a
+    valid reverse topological sort. We achieve this using the reverse of a
+    stack-based DFS order instead of reversing ``root_module.named_modules()``
+    since the former gives the modules in registration order at each level in
+    the module tree (as opposed to the reverse), which allows us to error/warn
+    on the first registered module that violates the condition.
+
+    For example, consider the following module structure:
+        M(
+          S1(),
+          S2(
+            SS1(),
+            SS2(),
+          ),
+          S3(),
+        )
+    The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
+    ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
+    """
+    visited_modules = {root_module}
+    stack = [("", root_module)]
+    # Append and reverse at the end for linear-time algorithm
+    reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
+    while stack:
+        module_name, module = stack.pop()
+        reverse_post_order_named_modules.append((module_name, module))
+        for child_module_name, child_module in module.named_children():
+            if child_module is None:  # only for overrides of `named_children()`
+                continue
+            if child_module not in visited_modules:
+                visited_modules.add(child_module)
+                if module_name != "":
+                    child_module_name = module_name + "." + child_module_name
+                stack.append((child_module_name, child_module))
+    post_order_named_modules = list(reversed(reverse_post_order_named_modules))
+    return post_order_named_modules
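+
+# As an informal illustration (hypothetical modules), the returned order places
+# children before parents while preserving registration order at each level:
+#
+#     m = nn.Module()
+#     m.s1 = nn.Linear(1, 1)
+#     m.s2 = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
+#     m.s3 = nn.Linear(1, 1)
+#     [name for name, _ in _get_post_order_named_modules(m)]
+#     # -> ['s1', 's2.0', 's2.1', 's2', 's3', '']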
+
+
+def _get_managed_param_to_fqn(
+    module_to_wrap: nn.Module,
+    ignored_params: Set[nn.Parameter],
+    visited_modules: Set[nn.Module],
+    root_prefix: str,
+) -> Dict[nn.Parameter, str]:
+    """
+    This returns a dict that maps managed parameter to its FQN for the given
+    ``module_to_wrap``. The dict's keys are exactly the parameters that would
+    be managed by the module, where this is achieved by calling this function
+    on the modules to wrap in reverse topological order, destructively updating
+    ``visited_modules``, and not traversing into those modules. The FQNs are
+    prefixed from the root (via ``root_prefix``) to be more informative.
+
+    NOTE: This function is meant to be called pre-wrapping and iteratively in
+    reverse topological order to cover the full module tree. This differs from
+    the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
+    on the full module tree in one shot. Given those differences, we do not try
+    to unify the two.
+    """
+    param_to_fqn: Dict[nn.Parameter, str] = {}
+    # Run BFS (or any tree traversal works)
+    queue = collections.deque([(module_to_wrap, root_prefix)])
+    visited_modules.add(module_to_wrap)
+    while queue:
+        module, prefix = queue.popleft()
+        for param_name, param in module.named_parameters(recurse=False):
+            if param not in ignored_params:
+                fqn = param_name if prefix == "" else prefix + "." + param_name
+                param_to_fqn[param] = fqn
+        for child_module_name, child_module in module.named_children():
+            if child_module is None:  # only for overrides of `named_children()`
+                continue
+            if child_module not in visited_modules:
+                visited_modules.add(child_module)
+                child_prefix = (
+                    child_module_name
+                    if prefix == ""
+                    else prefix + "." + child_module_name
+                )
+                queue.append((child_module, child_prefix))
+    return param_to_fqn
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/api.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..4551ddf8e62694edc7c7b9d934c1e8beb2a58d63
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/api.py
@@ -0,0 +1,410 @@
+"""
+This file includes public APIs for FSDP such as the classes used for the
+constructor arguments.
+"""
+
+from dataclasses import dataclass
+from enum import auto, Enum
+
+from typing import Optional, Sequence, Type
+
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = [
+    "ShardingStrategy",
+    "BackwardPrefetch",
+    "MixedPrecision",
+    "CPUOffload",
+    "StateDictType",
+    "StateDictConfig",
+    "FullStateDictConfig",
+    "LocalStateDictConfig",
+    "ShardedStateDictConfig",
+    "OptimStateDictConfig",
+    "FullOptimStateDictConfig",
+    "LocalOptimStateDictConfig",
+    "ShardedOptimStateDictConfig",
+    "StateDictSettings",
+]
+
+
+class ShardingStrategy(Enum):
+    """
+    This specifies the sharding strategy to be used for distributed training by
+    :class:`FullyShardedDataParallel`.
+
+    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
+      For the parameters, this strategy unshards (via all-gather) before the
+      forward, reshards after the forward, unshards before the backward
+      computation, and reshards after the backward computation. For gradients,
+      it synchronizes and shards them (via reduce-scatter) after the backward
+      computation. The sharded optimizer states are updated locally per rank.
+    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
+      computation, and additionally, parameters are sharded outside
+      computation. For the parameters, this strategy unshards before the
+      forward, does not reshard them after the forward, and only reshards them
+      after the backward computation. The sharded optimizer states are updated
+      locally per rank. Inside ``no_sync()``, the parameters are not resharded
+      after the backward computation.
+    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
+      but instead replicated across ranks similar to PyTorch's
+      :class:`DistributedDataParallel` API. For gradients, this strategy
+      synchronizes them (via all-reduce) after the backward computation. The
+      unsharded optimizer states are updated locally per rank.
+    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
+      nodes. This results in reduced communication volume as expensive all-gathers and
+      reduce-scatters are only done within a node, which can be more performant for
+      medium-sized models.
+    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
+      nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
+      since the unsharded parameters are not freed after the forward pass, saving the
+      all-gathers in the pre-backward.
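+
+    As an informal illustration (assuming ``model`` is an ``nn.Module`` and the
+    default process group is initialized), the chosen strategy is passed to the
+    FSDP constructor via its ``sharding_strategy`` argument::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)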
+    """
+
+    FULL_SHARD = auto()
+    SHARD_GRAD_OP = auto()
+    NO_SHARD = auto()
+    HYBRID_SHARD = auto()
+    _HYBRID_SHARD_ZERO2 = auto()
+
+
+class BackwardPrefetch(Enum):
+    """
+    This configures explicit backward prefetching, which improves throughput by
+    enabling communication and computation overlap in the backward pass at the
+    cost of slightly increased memory usage.
+
+    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
+      usage the most. This prefetches the next set of parameters *before* the
+      current set of parameters' gradient computation. This overlaps the *next
+      all-gather* and the *current gradient computation*, and at the peak, it
+      holds the current set of parameters, next set of parameters, and current
+      set of gradients in memory.
+    - ``BACKWARD_POST``: This enables less overlap but uses less memory. This
+      prefetches the next set of parameters *after* the current
+      set of parameters' gradient computation. This overlaps the *current
+      reduce-scatter* and the *next gradient computation*, and it frees the
+      current set of parameters before allocating memory for the next set of
+      parameters, only holding the next set of parameters and current set of
+      gradients in memory at the peak.
+    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
+      the backward prefetching altogether. This has no overlap and does not
+      increase memory usage. In general, we do not recommend this setting since
+      it may degrade throughput significantly.
+
+    For more technical context: For a single process group using NCCL backend,
+    any collectives, even if issued from different streams, contend for the
+    same per-device NCCL stream, which implies that the relative order in which
+    the collectives are issued matters for overlapping. The two backward
+    prefetching values correspond to different issue orders.
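+
+    As an informal illustration (assuming ``model`` is an ``nn.Module``), the
+    chosen value is passed to the FSDP constructor via its ``backward_prefetch``
+    argument::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
+        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)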
+    """
+
+    # NOTE: For both modes, the ordering that defines "current" and "next" is
+    # not always exact in the current implementation. A mistargeted prefetch
+    # simply means that the parameter memory is allocated earlier than needed,
+    # possibly increasing peak memory usage, but does not affect correctness.
+    BACKWARD_PRE = auto()
+    BACKWARD_POST = auto()
+
+
+@dataclass
+class MixedPrecision:
+    """
+    This configures FSDP-native mixed precision training.
+
+    Attributes:
+        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
+            parameters during forward and backward and thus the dtype for
+            forward and backward computation. Outside forward and backward, the
+            *sharded* parameters are kept in full precision (e.g. for the
+            optimizer step), and for model checkpointing, the parameters are
+            always saved in full precision. (Default: ``None``)
+        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
+            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
+            ``None`` but ``param_dtype`` is not ``None``, then this takes on
+            the ``param_dtype`` value, still running gradient reduction in low
+            precision. This is permitted to differ from ``param_dtype``, e.g.
+            to force gradient reduction to run in full precision. (Default:
+            ``None``)
+        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
+            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
+            ``buffer_dtype`` in the first forward pass and keeps them in that
+            dtype thereafter. For model checkpointing, the buffers are saved
+            in full precision except for ``LOCAL_STATE_DICT``. (Default:
+            ``None``)
+        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
+            gradients to full precision after the backward pass in preparation
+            for the optimizer step. If ``True``, then FSDP keeps the gradients
+            in the dtype used for gradient reduction, which can save memory if
+            using a custom optimizer that supports running in low precision.
+            (Default: ``False``)
+        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
+            its forward args and kwargs to ``param_dtype``. This is to ensure
+            that parameter and input dtypes match for forward computation, as
+            required by many ops. This may need to be set to ``True`` when only
+            applying mixed precision to some but not all FSDP modules, in which
+            case a mixed-precision FSDP submodule needs to recast its inputs.
+            (Default: ``False``)
+        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
+            casts its forward args and kwargs to ``param_dtype``, overriding
+            the value of ``cast_forward_inputs``. For non-root FSDP modules,
+            this does not do anything. (Default: ``True``)
+        _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies
+            module classes to ignore for mixed precision when using an
+            ``auto_wrap_policy``: Modules of these classes will have FSDP
+            applied to them separately with mixed precision disabled (meaning
+            that the final FSDP construction would deviate from the specified
+            policy). If ``auto_wrap_policy`` is not specified, then this does
+            not do anything. This API is experimental and subject to change.
+            (Default: ``(_BatchNorm,)``)
+
+    .. note:: This API is experimental and subject to change.
+
+    .. note:: Only floating point tensors are cast to their specified dtypes.
+
+    .. note:: In ``summon_full_params``, parameters are forced to full
+        precision, but buffers are not.
+
+    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
+        their inputs are in a low precision like ``float16`` or ``bfloat16``.
+        Disabling FSDP's mixed precision for those norm modules only means that
+        the affine parameters are kept in ``float32``. However, this incurs
+        separate all-gathers and reduce-scatters for those norm modules, which
+        may be inefficient, so if the workload permits, the user should prefer
+        to still apply mixed precision to those modules.
+
+    .. note:: By default, if the user passes a model with any ``_BatchNorm``
+        modules and specifies an ``auto_wrap_policy``, then the batch norm
+        modules will have FSDP applied to them separately with mixed precision
+        disabled. See the ``_module_classes_to_ignore`` argument.
+
+    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
+        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
+        its ``cast_root_forward_inputs`` takes precedence over its
+        ``cast_forward_inputs``. For non-root FSDP instances, their
+        ``cast_root_forward_inputs`` values are ignored. The default setting is
+        sufficient for the typical case where each FSDP instance has the same
+        ``MixedPrecision`` configuration and only needs to cast inputs to the
+        ``param_dtype`` at the beginning of the model's forward pass.
+
+    .. note:: For nested FSDP instances with different ``MixedPrecision``
+        configurations, we recommend setting individual ``cast_forward_inputs``
+        values to configure casting inputs or not before each instance's
+        forward. In such a case, since the casts happen before each FSDP
+        instance's forward, a parent FSDP instance should have its non-FSDP
+        submodules run before its FSDP submodules to avoid the activation dtype
+        being changed due to a different ``MixedPrecision`` configuration.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+            >>> model[1] = FSDP(
+            >>>     model[1],
+            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
+            >>> )
+            >>> model = FSDP(
+            >>>     model,
+            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
+            >>> )
+
+        The above shows a working example. On the other hand, if ``model[1]``
+        were replaced with ``model[0]``, meaning that the submodule using
+        different ``MixedPrecision`` ran its forward first, then ``model[1]``
+        would incorrectly see ``float16`` activations instead of ``bfloat16``
+        ones.
+
+    """
+
+    param_dtype: Optional[torch.dtype] = None
+    reduce_dtype: Optional[torch.dtype] = None
+    buffer_dtype: Optional[torch.dtype] = None
+    keep_low_precision_grads: bool = False
+    cast_forward_inputs: bool = False
+    cast_root_forward_inputs: bool = True
+    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
+
+
+@dataclass
+class CPUOffload:
+    """
+    This configures CPU offloading.
+
+    Attributes:
+        offload_params (bool): This specifies whether to offload parameters to
+            CPU when not involved in computation. If ``True``, then this
+            offloads gradients to CPU as well, meaning that the optimizer step
+            runs on CPU.
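+
+    As an informal illustration (assuming ``model`` is an ``nn.Module``)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
+        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))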
+    """
+
+    offload_params: bool = False
+
+
+class StateDictType(Enum):
+    """
+    This enum indicates which type of ``state_dict`` the FSDP module is
+    currently processing (returning or loading).
+    The default value is FULL_STATE_DICT to comply with the PyTorch convention.
+
+    .. note::
+        FSDP currently supports three types of ``state_dict``:
+            1. ``state_dict/load_state_dict``: this pair of APIs return and load
+               the non-sharded, unflattened parameters. The semantics are the
+               same as using DDP.
+            2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
+               and load local sharded, flattened parameters. The values returned
+               by ``_local_state_dict`` can be directly used by FSDP and are only
+               meaningful to FSDP (because parameters are flattened). Note that
+               these APIs are meant for use via the :func:`state_dict_type`
+               context manager as follows:
+                   >>> # xdoctest: +SKIP("undefined variables")
+                   >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+                   ...     state = fsdp.state_dict()  # gets the local state dict
+            3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
+               return and load sharded, unflattened parameters. The ``state_dict``
+               return by ``sharded_state_dict`` can be used by all other parallel
+               schemes (resharding may be required).
+    """
+
+    FULL_STATE_DICT = auto()
+    LOCAL_STATE_DICT = auto()
+    SHARDED_STATE_DICT = auto()
+
+
+@dataclass
+class StateDictConfig:
+    """
+    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
+    classes. Users should instantiate a child class (e.g.
+    ``FullStateDictConfig``) in order to configure settings for the
+    corresponding ``state_dict`` type supported by FSDP.
+
+    Attributes:
+        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
+            values to CPU, and if ``False``, then FSDP keeps them on GPU.
+            (Default: ``False``)
+    """
+
+    offload_to_cpu: bool = False
+
+
+@dataclass
+class FullStateDictConfig(StateDictConfig):
+    """
+    ``FullStateDictConfig`` is a config class meant to be used with
+    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
+    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
+    dicts to save GPU memory and CPU memory, respectively. This config class
+    is meant to be used via the :func:`state_dict_type` context manager as
+    follows:
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> fsdp = FSDP(model, auto_wrap_policy=...)
+        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
+        >>>     state = fsdp.state_dict()
+        >>>     # `state` will be empty on nonzero ranks and contain CPU tensors on rank 0.
+        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
+        >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
+        >>> if dist.get_rank() == 0:
+        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
+        >>>     state_dict = torch.load("my_checkpoint.pt")
+        >>>     model.load_state_dict(state_dict)
+        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
+        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
+        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
+        >>> # After this point, all ranks have FSDP model with loaded checkpoint.
+
+    Attributes:
+        rank0_only (bool): If ``True``, then only rank 0 saves the full state
+            dict, and nonzero ranks save an empty dict. If ``False``, then all
+            ranks save the full state dict. (Default: ``False``)
+    """
+
+    rank0_only: bool = False
+
+
+@dataclass
+class LocalStateDictConfig(StateDictConfig):
+    pass
+
+
+@dataclass
+class ShardedStateDictConfig(StateDictConfig):
+    """
+    ``ShardedStateDictConfig`` is a config class meant to be used with
+    ``StateDictType.SHARDED_STATE_DICT``.
+
+    Attributes:
+        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
+            as ``DTensor``, and if ``False``, then FSDP saves them as
+            ``ShardedTensor``. (Default: ``False``)
+
+    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
+      and it is used by FSDP to determine the type of state dict values. Users should not
+      manually modify ``_use_dtensor``.
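+
+    As an informal illustration (assuming ``model`` is already wrapped with
+    FSDP), this config is typically passed through the :func:`state_dict_type`
+    context manager::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
+        >>> cfg = ShardedStateDictConfig()
+        >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, cfg):
+        >>>     state = model.state_dict()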
+    """
+
+    _use_dtensor: bool = False
+
+
+@dataclass
+class OptimStateDictConfig:
+    """
+    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
+    configuration classes.  Users should instantiate a child class (e.g.
+    ``FullOptimStateDictConfig``) in order to configure settings for the
+    corresponding ``optim_state_dict`` type supported by FSDP.
+
+    Attributes:
+        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
+            tensor values to CPU, and if ``False``, then FSDP keeps them on the
+            original device (which is GPU unless parameter CPU offloading is
+            enabled). (Default: ``True``)
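+
+    As an informal illustration (assuming ``model`` is wrapped with FSDP and
+    ``optim`` is its optimizer), an optimizer-state config is set together with
+    the model-state config, e.g. via :func:`set_state_dict_type`::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> FSDP.set_state_dict_type(
+        >>>     model,
+        >>>     StateDictType.SHARDED_STATE_DICT,
+        >>>     ShardedStateDictConfig(),
+        >>>     ShardedOptimStateDictConfig(offload_to_cpu=True),
+        >>> )
+        >>> osd = FSDP.optim_state_dict(model, optim)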
+    """
+
+    offload_to_cpu: bool = True
+
+
+@dataclass
+class FullOptimStateDictConfig(OptimStateDictConfig):
+    """
+    Attributes:
+        rank0_only (bool): If ``True``, then only rank 0 saves the full state
+            dict, and nonzero ranks save an empty dict. If ``False``, then all
+            ranks save the full state dict. (Default: ``False``)
+    """
+
+    rank0_only: bool = False
+
+
+@dataclass
+class LocalOptimStateDictConfig(OptimStateDictConfig):
+    offload_to_cpu: bool = False
+
+
+@dataclass
+class ShardedOptimStateDictConfig(OptimStateDictConfig):
+    """
+    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
+    ``StateDictType.SHARDED_STATE_DICT``.
+
+    Attributes:
+        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
+            as ``DTensor``, and if ``False``, then FSDP saves them as
+            ``ShardedTensor``. (Default: ``False``)
+
+    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
+      and it is used by FSDP to determine the type of state dict values. Users should not
+      manually modify ``_use_dtensor``.
+    """
+
+    _use_dtensor: bool = False
+
+
+@dataclass
+class StateDictSettings:
+    state_dict_type: StateDictType
+    state_dict_config: StateDictConfig
+    optim_state_dict_config: OptimStateDictConfig
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bf141619a9d02e41f1c2f115caa6fc37086260
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -0,0 +1,2075 @@
+# mypy: ignore-errors
+
+import contextlib
+import copy
+import functools
+import math
+import traceback
+import warnings
+from contextlib import contextmanager
+from enum import auto, Enum
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import torch
+import torch.distributed as dist
+import torch.distributed.fsdp._traversal_utils as traversal_utils
+import torch.nn as nn
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_WRAPPED_MODULE,
+    ActivationWrapper,
+)
+from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
+from torch.distributed.fsdp._common_utils import (
+    _FSDPState,
+    _get_param_to_fqns,
+    FSDP_PREFIX,
+    FSDP_WRAPPED_MODULE,
+    TrainingState,
+)
+from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
+from torch.distributed.fsdp._init_utils import (
+    _check_orig_params_flattened,
+    _init_buffer_state,
+    _init_core_state,
+    _init_device_handle,
+    _init_extension,
+    _init_ignored_module_states,
+    _init_param_handle_from_module,
+    _init_prefetching_state,
+    _init_process_group_state,
+    _init_runtime_state,
+    _init_state_dict_state,
+    HYBRID_SHARDING_STRATEGIES,
+    ProcessGroupType,
+)
+from torch.distributed.fsdp._runtime_utils import (
+    _get_fsdp_root_states,
+    _is_fsdp_root,
+    _lazy_init,
+    _post_forward,
+    _post_forward_reshard,
+    _pre_forward,
+    _pre_forward_unshard,
+    _root_pre_forward,
+)
+from torch.distributed.fsdp._wrap_utils import _auto_wrap
+from torch.distributed.fsdp.api import (
+    BackwardPrefetch,
+    CPUOffload,
+    FullOptimStateDictConfig,
+    FullStateDictConfig,
+    LocalOptimStateDictConfig,
+    LocalStateDictConfig,
+    MixedPrecision,
+    OptimStateDictConfig,
+    ShardedOptimStateDictConfig,
+    ShardedStateDictConfig,
+    ShardingStrategy,
+    StateDictConfig,
+    StateDictSettings,
+    StateDictType,
+)
+from torch.distributed.utils import _p_assert
+from ._flat_param import FlatParameter
+
+from ._optim_utils import (
+    _flatten_optim_state_dict,
+    _get_param_id_to_param_from_optim_input,
+    _get_param_key_to_param,
+    _get_param_to_param_id_from_optim_input,
+    _get_param_to_param_key,
+    _optim_state_dict,
+    _rekey_sharded_optim_state_dict,
+    _set_optim_use_dtensor,
+)
+from ._state_dict_utils import _register_all_state_dict_hooks
+from ._unshard_param_utils import (
+    _deregister_orig_params,
+    _register_flat_param,
+    _register_orig_params,
+    _unshard_params,
+    _unshard_params_recurse,
+)
+from .wrap import CustomPolicy, ModuleWrapPolicy
+
+
+__all__ = [
+    "FullyShardedDataParallel",
+    "OptimStateKeyType",
+]
+
+
+FLAT_PARAM = "_flat_param"
+
+
+class OptimStateKeyType(Enum):
+    """Represents the type of key in an optimizer state-dict."""
+
+    PARAM_NAME = auto()
+    PARAM_ID = auto()
+
+
+class FullyShardedDataParallel(nn.Module, _FSDPState):
+    """A wrapper for sharding module parameters across data parallel workers.
+
+    This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
+    FullyShardedDataParallel is commonly shortened to FSDP.
+
+    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
+    .. _DeepSpeed: https://www.deepspeed.ai/
+
+    For advanced notes please refer to :ref:`fsdp_notes`.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import torch
+        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        >>> torch.cuda.set_device(device_id)
+        >>> sharded_module = FSDP(my_module)
+        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
+        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
+        >>> loss = x.sum()
+        >>> loss.backward()
+        >>> optim.step()
+
+    .. warning::
+        The optimizer must be initialized *after* the module has been wrapped
+        with FSDP since FSDP will shard and transform the module's parameters
+        in a way that may not preserve the original parameter variables. Thus,
+        the previously initialized optimizer may have stale references to the
+        parameters.
+
+    .. warning::
+        If the destination CUDA device has ID ``dev_id``, either (1)
+        ``module`` should already be placed on that device, (2) the device
+        should be set using ``torch.cuda.set_device(dev_id)``, or (3)
+        ``dev_id`` should be passed into the ``device_id`` constructor
+        argument. This FSDP instance's compute device will be that destination
+        device. For (1) and (3), the FSDP initialization always occurs on GPU.
+        For (2), the FSDP initialization happens on ``module`` 's current
+        device, which may be CPU.
+
+    .. warning::
+        FSDP currently does not support gradient accumulation outside
+        ``no_sync()`` when using CPU offloading. Trying to do so yields
+        incorrect results since FSDP will use the newly-reduced gradient
+        instead of accumulating with any existing gradient.
+
+    .. warning::
+        Changing the original parameter variable names after construction will
+        lead to undefined behavior.
+
+    .. warning::
+        Passing in the ``sync_module_states=True`` flag requires ``module`` to
+        be on GPU or to use the ``device_id`` argument to specify a CUDA device
+        that FSDP will move ``module`` to in the FSDP constructor. This is
+        because ``sync_module_states=True`` requires GPU communication.
+
+    .. warning::
+        As of PyTorch 1.12, FSDP only offers limited support for shared parameters
+        (for example, setting one ``Linear`` layer's weight to another's). In
+        particular, modules that share parameters must be wrapped as part of the
+        same FSDP unit. If enhanced shared parameter support is needed for your
+        use case, please ping https://github.com/pytorch/pytorch/issues/77724
+
+    .. warning::
+        FSDP has some constraints on freezing parameters (i.e. setting
+        ``param.requires_grad=False``). For ``use_orig_params=False``, each
+        FSDP instance must manage parameters that are all frozen or all
+        non-frozen. For ``use_orig_params=True``, FSDP supports mixing frozen
+        and non-frozen, but we recommend not doing so since then the gradient
+        memory usage will be higher than expected (namely, equivalent to not
+        freezing those parameters). This means that ideally, frozen parameters
+        should be isolated into their own ``nn.Module`` s and wrapped
+        separately with FSDP.
+
+    .. note::
+        Attempting to run the forward pass of a submodule that is contained in an
+        FSDP instance is not supported and will result in errors. This is because the
+        submodule's parameters will be sharded, but it itself is not an FSDP instance,
+        so its forward pass will not all-gather the full parameters appropriately.
+        This could potentially happen when attempting to run only the encoder of an
+        encoder-decoder model, and the encoder is not wrapped in its own FSDP instance. To
+        resolve this, please wrap the submodule in its own FSDP unit.
+
+    .. note::
+        FSDP moves input tensors passed to the ``forward`` method to the GPU compute
+        device, so the user does not need to manually move them from CPU.
+
+    .. warning::
+        The user should not modify the parameters between forward and backward
+        without using the :meth:`summon_full_params` context since the
+        modifications may not persist. Moreover, for ``use_orig_params=False``,
+        accessing the original parameters between forward and backward may
+        raise an illegal memory access.
+
+    .. warning::
+        For ``use_orig_params=True``, ``ShardingStrategy.SHARD_GRAD_OP``
+        exposes the unsharded parameters, not the sharded parameters, after
+        forward since it does not free the unsharded ones, unlike
+        ``ShardingStrategy.FULL_SHARD``. One caveat is that, since gradients
+        are always sharded or ``None``, ``ShardingStrategy.SHARD_GRAD_OP`` will
+        not expose the sharded gradients with the unsharded parameters after
+        forward. If you want to inspect the gradients, try
+        :meth:`summon_full_params` with ``with_grads=True``.
+
+    .. warning::
+        FSDP replaces managed modules' parameters with ``torch.Tensor`` views
+        during forward and backward computation for autograd-related reasons.
+        If your module's forward relies on saved references to the parameters
+        instead of reacquiring the references each iteration, then it will not
+        see FSDP's newly created views, and autograd will not work correctly.
+
+    .. note::
+        With ``limit_all_gathers=True``, you may see a gap in the FSDP
+        pre-forward where the CPU thread is not issuing any kernels. This is
+        intentional and shows the rate limiter in effect. Synchronizing the CPU
+        thread in that way prevents over-allocating memory for subsequent
+        all-gathers, and it should not actually delay GPU kernel execution.
+
+    .. note::
+        When using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD`` with the
+        sharding process group being intra-node and the replication process
+        group being inter-node, setting ``NCCL_CROSS_NIC=1`` can help improve
+        the all-reduce times over the replication process group for some
+        cluster setups.
+
+    .. warning::
+        FSDP does not work with double backwards due to how it registers
+        backward hooks.
+
+    Args:
+        module (nn.Module):
+            This is the module to be wrapped with FSDP.
+        process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
+            This is the process group over which the model is sharded and thus
+            the one used for FSDP's all-gather and reduce-scatter collective
+            communications. If ``None``, then FSDP uses the default process
+            group. For hybrid sharding strategies such as
+            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of
+            process groups, representing the groups over which to shard and
+            replicate, respectively. If ``None``, then FSDP constructs process
+            groups for the user to shard intra-node and replicate inter-node.
+            (Default: ``None``)
+        sharding_strategy (Optional[ShardingStrategy]):
+            This configures the sharding strategy, which may trade off memory
+            saving and communication overhead. See :class:`ShardingStrategy`
+            for details. (Default: ``FULL_SHARD``)
+        cpu_offload (Optional[CPUOffload]):
+            This configures CPU offloading. If this is set to ``None``, then
+            no CPU offloading happens. See :class:`CPUOffload` for details.
+            (Default: ``None``)
+        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
+            This specifies a policy to apply FSDP to submodules of ``module``,
+            which is needed for communication and computation overlap and thus
+            affects performance. If ``None``, then FSDP only applies to
+            ``module``, and users should manually apply FSDP to parent modules
+            themselves (proceeding bottom-up). For convenience, this accepts
+            ``ModuleWrapPolicy`` directly, which allows users to specify the
+            module classes to wrap (e.g. the transformer block). Otherwise,
+            this should be a callable that takes in three arguments
+            ``module: nn.Module``, ``recurse: bool``, and
+            ``nonwrapped_numel: int`` and should return a ``bool`` specifying
+            whether the passed-in ``module`` should have FSDP applied if
+            ``recurse=False`` or if the traversal should continue into the
+            module's subtree if ``recurse=True``. Users may add additional
+            arguments to the callable. The ``size_based_auto_wrap_policy`` in
+            ``torch.distributed.fsdp.wrap`` gives an example callable that
+            applies FSDP to a module if the parameters in its subtree exceed
+            100M numel. We recommend printing the model after applying FSDP
+            and adjusting as needed.
+
+            Example::
+
+                >>> def custom_auto_wrap_policy(
+                >>>     module: nn.Module,
+                >>>     recurse: bool,
+                >>>     nonwrapped_numel: int,
+                >>>     # Additional custom arguments
+                >>>     min_num_params: int = int(1e8),
+                >>> ) -> bool:
+                >>>     return nonwrapped_numel >= min_num_params
+                >>> # Configure a custom `min_num_params`
+                >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+
+        backward_prefetch (Optional[BackwardPrefetch]):
+            This configures explicit backward prefetching of all-gathers. If
+            ``None``, then FSDP does not backward prefetch, and there is no
+            communication and computation overlap in the backward pass. See
+            :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
+        mixed_precision (Optional[MixedPrecision]):
+            This configures native mixed precision for FSDP. If this is set to
+            ``None``, then no mixed precision is used. Otherwise, parameter,
+            buffer, and gradient reduction dtypes can be set. See
+            :class:`MixedPrecision` for details. (Default: ``None``)
+        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
+            own parameters and child modules' parameters and buffers are
+            ignored by this instance. None of the modules directly in
+            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
+            instances, and any child modules that are already-constructed
+            :class:`FullyShardedDataParallel` instances will not be ignored if
+            they are nested under this instance. This argument may be used to
+            avoid sharding specific parameters at module granularity when using an
+            ``auto_wrap_policy`` or if parameters' sharding is not managed by
+            FSDP. (Default: ``None``)
+        param_init_fn (Optional[Callable[[nn.Module], None]]):
+            A ``Callable[[torch.nn.Module], None]`` that
+            specifies how modules that are currently on the meta device should
+            be initialized onto an actual device. As of v1.12, FSDP detects
+            modules with parameters or buffers on meta device via ``is_meta``
+            and either applies ``param_init_fn`` if specified or calls
+            ``nn.Module.reset_parameters()`` otherwise. For both cases, the
+            implementation should *only* initialize the parameters/buffers of
+            the module, not those of its submodules. This is to avoid
+            re-initialization. In addition, FSDP also supports deferred
+            initialization via torchdistX's (https://github.com/pytorch/torchdistX)
+            ``deferred_init()`` API, where the deferred modules are initialized
+            by calling ``param_init_fn`` if specified or torchdistX's default
+            ``materialize_module()`` otherwise. If ``param_init_fn`` is
+            specified, then it is applied to all meta-device modules, meaning
+            that it will likely need to branch on the module type. FSDP calls the
+            initialization function before parameter flattening and sharding.
+
+            Example::
+
+                >>> # xdoctest: +SKIP("undefined variables")
+                >>> module = MyModule(device="meta")
+                >>> def my_init_fn(module: nn.Module):
+                >>>     # E.g. initialize depending on the module type
+                >>>     ...
+                >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
+                >>> print(next(fsdp_model.parameters()).device) # current CUDA device
+                >>> # With torchdistX
+                >>> module = deferred_init.deferred_init(MyModule, device="cuda")
+                >>> # Will initialize via deferred_init.materialize_module().
+                >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
+
+        device_id (Optional[Union[int, torch.device]]): An ``int`` or
+            ``torch.device`` giving the CUDA device on which FSDP
+            initialization takes place, including the module initialization
+            if needed and the parameter sharding. This should be specified to
+            improve initialization speed if ``module`` is on CPU. If the
+            default CUDA device was set (e.g. via ``torch.cuda.set_device``),
+            then the user may pass ``torch.cuda.current_device`` to this.
+            (Default: ``None``)
+        sync_module_states (bool): If ``True``, then each FSDP module will
+            broadcast module parameters and buffers from rank 0 to ensure that
+            they are replicated across ranks (adding communication overhead to
+            this constructor). This can help load ``state_dict`` checkpoints
+            via ``load_state_dict`` in a memory efficient way. See
+            :class:`FullStateDictConfig` for an example of this. (Default:
+            ``False``)
+        forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
+            the next forward-pass all-gather before the current forward
+            computation. This is only useful for CPU-bound workloads, in which
+            case issuing the next all-gather earlier may improve overlap. This
+            should only be used for static-graph models since the prefetching
+            follows the first iteration's execution order. (Default: ``False``)
+        limit_all_gathers (bool): If ``True``, then FSDP explicitly
+            synchronizes the CPU thread to ensure GPU memory usage from only
+            *two* consecutive FSDP instances (the current instance running
+            computation and the next instance whose all-gather is prefetched).
+            If ``False``, then FSDP allows the CPU thread to issue all-gathers
+            without any extra synchronization. (Default: ``True``) We often
+            refer to this feature as the "rate limiter". This flag should only
+            be set to ``False`` for specific CPU-bound workloads with low
+            memory pressure, in which case the CPU thread can aggressively
+            issue all kernels without concern for GPU memory usage.
+        use_orig_params (bool): Setting this to ``True`` has FSDP use
+            ``module`` 's original parameters. FSDP exposes those original
+            parameters to the user via :meth:`nn.Module.named_parameters`
+            instead of FSDP's internal :class:`FlatParameter` s. This means
+            that the optimizer step runs on the original parameters, enabling
+            per-original-parameter hyperparameters. FSDP preserves the original
+            parameter variables and manipulates their data between unsharded
+            and sharded forms, where they are always views into the underlying
+            unsharded or sharded :class:`FlatParameter`, respectively. With the
+            current algorithm, the sharded form is always 1D, losing the
+            original tensor structure. An original parameter may have all,
+            some, or none of its data present for a given rank. In the none
+            case, its data appears as a size-0 empty tensor. Users should not
+            author programs relying on what data is present for a given
+            original parameter in its sharded form. ``True`` is required to
+            use ``torch.compile()``. Setting this to ``False`` exposes FSDP's
+            internal :class:`FlatParameter` s to the user via
+            :meth:`nn.Module.named_parameters`. (Default: ``False``)
+        ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]):
+            Ignored parameters or modules that will not be managed by this FSDP
+            instance, meaning that the parameters are not sharded and their
+            gradients are not reduced across ranks. This argument unifies with
+            the existing ``ignored_modules`` argument, and we may deprecate
+            ``ignored_modules`` soon. For backward compatibility, we keep both
+            ``ignored_states`` and ``ignored_modules``, but FSDP only allows one
+            of them to be specified as not ``None``.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: ProcessGroupType = None,
+        sharding_strategy: Optional[ShardingStrategy] = None,
+        cpu_offload: Optional[CPUOffload] = None,
+        auto_wrap_policy: Optional[
+            Union[Callable, ModuleWrapPolicy, CustomPolicy]
+        ] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
+        mixed_precision: Optional[MixedPrecision] = None,
+        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+        device_id: Optional[Union[int, torch.device]] = None,
+        sync_module_states: bool = False,
+        forward_prefetch: bool = False,
+        limit_all_gathers: bool = True,
+        use_orig_params: bool = False,
+        ignored_states: Union[
+            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
+        ] = None,
+        device_mesh: Optional[DeviceMesh] = None,
+    ):
+        torch._C._log_api_usage_once("torch.distributed.fsdp")
+        super().__init__()
+        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
+        _init_device_handle(self, module, self._ignored_params, device_id)
+
+        # Add module annotations for Dynamo support (see function for details)
+        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
+
+        # Initializes self.process_group, along with rank and world size. This will
+        # also set another attribute, _inter_node_pg, to control the process group
+        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
+        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
+        # the same process group state as the root FSDP module.
+        self._device_mesh = device_mesh
+        _init_process_group_state(
+            self,
+            process_group,
+            sharding_strategy,
+            auto_wrap_policy,
+            device_mesh,
+        )
+        if auto_wrap_policy is not None:
+            root_kwargs = {
+                "process_group": process_group,
+                "sharding_strategy": sharding_strategy,
+                "cpu_offload": cpu_offload,
+                "backward_prefetch": backward_prefetch,
+                "mixed_precision": mixed_precision,
+                "param_init_fn": param_init_fn,
+                "device_id": device_id,
+                "sync_module_states": sync_module_states,
+                "forward_prefetch": forward_prefetch,
+                "limit_all_gathers": limit_all_gathers,
+                "use_orig_params": use_orig_params,
+                "ignored_states": self._ignored_params,
+                "device_mesh": device_mesh,
+            }
+            if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
+                # Share root process groups with children to maintain
+                # the invariant that all FSDP modules will have the same
+                # process groups.
+                root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
+
+            _auto_wrap(
+                module,
+                auto_wrap_policy,
+                self._ignored_modules,
+                self._ignored_params,
+                root_kwargs,
+                FullyShardedDataParallel,
+            )
+
+        backward_prefetch_limit = 1
+        forward_prefetch_limit = 1
+        _init_core_state(
+            self,
+            sharding_strategy,
+            mixed_precision,
+            cpu_offload,
+            limit_all_gathers,
+            use_orig_params,
+            backward_prefetch_limit,
+            forward_prefetch_limit,
+        )
+        _init_runtime_state(self)
+        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
+        _init_buffer_state(self, module)
+        # extension needs to be set before `_init_param_handle_from_module()`
+        _init_extension(self, device_mesh)
+        _init_param_handle_from_module(
+            self,
+            module,
+            device_id,
+            param_init_fn,
+            sync_module_states,
+        )
+        self._fsdp_wrapped_module = module
+        if not use_orig_params:
+            _check_orig_params_flattened(self, self._ignored_params)
+            _register_flat_param(self, self)
+
+        # `_state_dict_type` controls the `state_dict()` behavior, which is
+        # implemented using post-save and pre-load hooks
+        _init_state_dict_state(self)
+        _register_all_state_dict_hooks(self)
+
+    @property
+    def module(self) -> nn.Module:
+        """Return the wrapped module."""
+        # FSDP's `.module` must refer to the innermost wrapped module when
+        # composing with other module wrappers in order for state dict to work
+        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
+            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
+        return self._fsdp_wrapped_module
+
+    @property
+    def _has_params(self) -> bool:
+        """Returns whether this FSDP instance manages any parameters."""
+        return hasattr(self, "_handle") and self._handle is not None
+
+    @property
+    def _flat_param(self) -> Optional[FlatParameter]:
+        return self._handle.flat_param if self._handle else None
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to the wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self._fsdp_wrapped_module, name)
+
+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is an ``nn.Sequential``."""
+        if hasattr(self, FSDP_WRAPPED_MODULE):
+            return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
+        return super().__getitem__(key)
+
+    def check_is_root(self) -> bool:
+        """Check if this instance is a root FSDP module."""
+        return _is_fsdp_root(self, self)
+
+    @staticmethod
+    def fsdp_modules(
+        module: nn.Module,
+        root_only: bool = False,
+    ) -> List["FullyShardedDataParallel"]:
+        """Return all nested FSDP instances.
+
+        This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
+
+        Args:
+            module (torch.nn.Module): Root module, which may or may not be an
+                ``FSDP`` module.
+            root_only (bool): Whether to return only FSDP root modules.
+                (Default: ``False``)
+
+        Returns:
+            List[FullyShardedDataParallel]: FSDP modules that are nested in
+            the input ``module``.
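+
+        Example (illustrative sketch only; ``model`` is a placeholder for an
+        already-wrapped module)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> # Collect every nested FSDP instance, then only the roots
+            >>> all_fsdp = FSDP.fsdp_modules(model)
+            >>> root_fsdp = FSDP.fsdp_modules(model, root_only=True)
+            >>> print(len(all_fsdp), len(root_fsdp))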
+        """
+        if root_only:
+            return _get_fsdp_root_states(module)
+        return traversal_utils._get_fsdp_states(module)
+
+    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
+        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+        Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).
+
+        Compared to ``torch.nn.Module.apply``, this version additionally gathers
+        the full parameters before applying ``fn``. It should not be called from
+        within another ``summon_full_params`` context.
+
+        Args:
+            fn (:class:`Module` -> None): function to be applied to each submodule
+
+        Returns:
+            Module: self
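+
+        Example (illustrative sketch only; ``fsdp_model`` is a placeholder and
+        ``init_weights`` is a hypothetical helper)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> def init_weights(m: nn.Module):
+            >>>     # Re-initialize only the leaf modules of interest
+            >>>     if isinstance(m, nn.Linear):
+            >>>         nn.init.xavier_uniform_(m.weight)
+            >>> # Full parameters are gathered before init_weights runs
+            >>> fsdp_model.apply(init_weights)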
+        """
+        uninitialized = self._is_root is None
+        self._assert_state(TrainingState.IDLE)
+        # Use `_unshard_params_recurse()` with `recurse=False` instead of
+        # `_unshard_fsdp_state_params()` directly to perform lazy
+        # initialization, which is needed to initialize `FlatParameter`
+        # parameter attributes as required by the unshard logic
+        with _unshard_params_recurse(
+            self,
+            self,
+            recurse=False,
+            writeback=True,
+            rank0_only=False,
+            offload_to_cpu=False,
+            with_grads=False,
+        ):
+            ret = super().apply(fn)
+
+        # Reset lazy init called in `_unshard_params_recurse()` since `apply()`
+        # may have been called on an FSDP instance that is not truly a root, in
+        # which case it will be incorrectly marked as one.
+        if uninitialized and self._is_root:
+            for module in traversal_utils._get_fsdp_states(self):
+                module._reset_lazy_init()
+
+        return ret
+
+    def _mixed_precision_enabled_for_buffers(self) -> bool:
+        """Return whether the user explicitly enabled buffer mixed precision.
+
+        NOTE: Unlike parameters and gradient reduction, buffer mixed precision
+        is applied at the FSDP instance level, not the ``FlatParameter`` level,
+        which may be different for the composable code path.
+        """
+        return self.mixed_precision.buffer_dtype is not None
+
+    def _low_precision_hook_enabled(self) -> bool:
+        """Whether a low precision hook is registered or not."""
+        return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS
+
+    def _reset_lazy_init(self) -> None:
+        """Reset instance so :func:`_lazy_init` will run on the next forward."""
+        self._is_root: Optional[bool] = None
+
+    @staticmethod
+    def set_state_dict_type(
+        module: nn.Module,
+        state_dict_type: StateDictType,
+        state_dict_config: Optional[StateDictConfig] = None,
+        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
+    ) -> StateDictSettings:
+        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+
+        Also takes (optional) configuration for the model's and optimizer's state dict.
+        The target module does not have to be an FSDP module. If the target
+        module is an FSDP module, its ``state_dict_type`` will also be changed.
+
+        .. note:: This API should be called for only the top-level (root)
+            module.
+
+        .. note:: This API enables users to transparently use the conventional
+            ``state_dict`` API to take model checkpoints in cases where the
+            root FSDP module is wrapped by another ``nn.Module``. For example,
+            the following will ensure ``state_dict`` is called on all non-FSDP
+            instances, while dispatching into the ``sharded_state_dict``
+            implementation for FSDP:
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model = DDP(FSDP(...))
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.SHARDED_STATE_DICT,
+            >>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
+            >>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
+            >>> )
+            >>> param_state_dict = model.state_dict()
+            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
+
+        Args:
+            module (torch.nn.Module): Root module.
+            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the configuration for the
+                target ``state_dict_type``.
+            optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration
+                for the optimizer state dict.
+
+        Returns:
+            A ``StateDictSettings`` that includes the previous state_dict type
+            and configuration for the module.
+        """
+        _state_dict_type_to_config = {
+            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
+            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
+            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
+        }
+        _optim_state_dict_type_to_config = {
+            StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
+            StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
+            StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
+        }
+
+        # Use the default config if a state_dict config is not set.
+        state_dict_config_type = _state_dict_type_to_config[state_dict_type]
+        optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type]
+        if state_dict_config is None:
+            state_dict_config = state_dict_config_type()
+        if optim_state_dict_config is None:
+            optim_state_dict_config = optim_state_dict_config_type()
+        if state_dict_config_type != type(state_dict_config):
+            raise RuntimeError(
+                f"Expected state_dict_config of type {state_dict_config_type} "
+                f"but got {type(state_dict_config)}"
+            )
+        if optim_state_dict_config_type != type(optim_state_dict_config):
+            raise RuntimeError(
+                f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
+                f"but got {type(optim_state_dict_config)}"
+            )
+
+        # Set the state_dict type and configurations.
+        prev_state_dict_type = None
+        prev_state_dict_config = None
+        prev_optim_state_dict_config = None
+        for submodule in traversal_utils._get_fsdp_states(module):
+            if prev_state_dict_type is None:
+                prev_state_dict_type = submodule._state_dict_type
+            else:
+                assert (
+                    prev_state_dict_type == submodule._state_dict_type
+                ), "All FSDP modules should have the same state_dict_type."
+            if prev_state_dict_config is None:
+                prev_state_dict_config = submodule._state_dict_config
+            else:
+                assert isinstance(
+                    submodule._state_dict_config, type(prev_state_dict_config)
+                ), "All FSDP modules must have the same type of state_dict_config."
+            if prev_optim_state_dict_config is None:
+                prev_optim_state_dict_config = submodule._optim_state_dict_config
+            else:
+                assert isinstance(
+                    submodule._optim_state_dict_config,
+                    type(prev_optim_state_dict_config),
+                ), "All FSDP modules must have the same type of optim_state_dict_config."
+
+            submodule._state_dict_type = state_dict_type
+            submodule._state_dict_config = state_dict_config
+            submodule._optim_state_dict_config = optim_state_dict_config
+
+        return StateDictSettings(
+            prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config
+        )
+
+    @staticmethod
+    def get_state_dict_type(module: nn.Module) -> StateDictSettings:
+        """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
+
+        The target module does not have to be an FSDP module.
+
+        Returns:
+            A ``StateDictSettings`` containing the state_dict_type and
+            state_dict / optim_state_dict configs that are currently set.
+
+        Raises:
+            ``AssertionError`` if the ``StateDictSettings`` for different
+            FSDP submodules differ.
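+
+        Example (illustrative sketch only; ``model`` is a placeholder)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> settings = FSDP.get_state_dict_type(model)
+            >>> print(settings.state_dict_type)
+            >>> print(settings.state_dict_config, settings.optim_state_dict_config)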
+        """
+        state_dict_settings: Optional[StateDictSettings] = None
+        for submodule in FullyShardedDataParallel.fsdp_modules(module):
+            if state_dict_settings is None:
+                state_dict_settings = StateDictSettings(
+                    state_dict_type=submodule._state_dict_type,
+                    state_dict_config=submodule._state_dict_config,
+                    optim_state_dict_config=submodule._optim_state_dict_config,
+                )
+                _set_optim_use_dtensor(submodule, state_dict_settings)
+            else:
+                submodule_settings = StateDictSettings(
+                    submodule._state_dict_type,
+                    submodule._state_dict_config,
+                    submodule._optim_state_dict_config,
+                )
+                assert state_dict_settings == submodule_settings, (
+                    "All FSDP modules must have the same state dict settings."
+                    f"Got {submodule_settings} and {state_dict_settings}."
+                )
+                _set_optim_use_dtensor(submodule, submodule_settings)
+        return state_dict_settings
+
+    @staticmethod
+    @contextlib.contextmanager
+    def state_dict_type(
+        module: nn.Module,
+        state_dict_type: StateDictType,
+        state_dict_config: Optional[StateDictConfig] = None,
+        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
+    ) -> Generator:
+        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
+
+        This context manager has the same function as :meth:`set_state_dict_type`. See the documentation of
+        :meth:`set_state_dict_type` for details.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model = DDP(FSDP(...))
+            >>> with FSDP.state_dict_type(
+            >>>     model,
+            >>>     StateDictType.SHARDED_STATE_DICT,
+            >>> ):
+            >>>     checkpoint = model.state_dict()
+
+        Args:
+            module (torch.nn.Module): Root module.
+            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the model ``state_dict``
+                configuration for the target ``state_dict_type``.
+            optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer
+               ``state_dict`` configuration for the target ``state_dict_type``.
+        """
+        prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
+            module,
+            state_dict_type,
+            state_dict_config,
+            optim_state_dict_config,
+        )
+        yield
+        FullyShardedDataParallel.set_state_dict_type(
+            module,
+            prev_state_dict_settings.state_dict_type,
+            prev_state_dict_settings.state_dict_config,
+            prev_state_dict_settings.optim_state_dict_config,
+        )
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
+        handle = self._handle
+        with torch.autograd.profiler.record_function(
+            "FullyShardedDataParallel.forward"
+        ):
+            args, kwargs = _root_pre_forward(self, self, args, kwargs)
+            unused = None
+            args, kwargs = _pre_forward(
+                self,
+                handle,
+                _pre_forward_unshard,
+                self._fsdp_wrapped_module,
+                args,
+                kwargs,
+            )
+            if handle:
+                _p_assert(
+                    handle.flat_param.device == self.compute_device,
+                    "Expected `FlatParameter` to be on the compute device "
+                    f"{self.compute_device} but got {handle.flat_param.device}",
+                )
+            output = self._fsdp_wrapped_module(*args, **kwargs)
+            return _post_forward(
+                self, handle, _post_forward_reshard, self, unused, output
+            )
+
+    @staticmethod
+    @contextlib.contextmanager
+    def summon_full_params(
+        module: nn.Module,
+        recurse: bool = True,
+        writeback: bool = True,
+        rank0_only: bool = False,
+        offload_to_cpu: bool = False,
+        with_grads: bool = False,
+    ) -> Generator:
+        r"""Expose full params for FSDP instances with this context manager.
+
+        Can be useful *after* forward/backward for a model to get
+        the params for additional processing or checking. It can take a non-FSDP
+        module and will summon full params for all contained FSDP modules as
+        well as their children, depending on the ``recurse`` argument.
+
+        .. note:: This can be used on inner FSDPs.
+        .. note:: This can *not* be used within a forward or backward pass. Nor
+            can forward and backward be started from within this context.
+        .. note:: Parameters will revert to their local shards after the context
+            manager exits; the storage behavior is the same as after forward.
+        .. note:: The full parameters can be modified, but only the portion
+            corresponding to the local param shard will persist after the
+            context manager exits (unless ``writeback=False``, in which case
+            changes will be discarded). In the case where FSDP does not shard
+            the parameters, currently only when ``world_size == 1`` or with the
+            ``NO_SHARD`` config, the modification is persisted regardless of ``writeback``.
+        .. note:: This method works on modules which are not FSDP themselves but
+            may contain multiple independent FSDP units. In that case, the given
+            arguments will apply to all contained FSDP units.
+
+        .. warning:: Note that ``rank0_only=True`` in conjunction with
+            ``writeback=True`` is not currently supported and will raise an
+            error. This is because model parameter shapes would be different
+            across ranks within the context, and writing to them can lead to
+            inconsistency across ranks when the context is exited.
+
+        .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will
+            result in full parameters being redundantly copied to CPU memory for
+            GPUs that reside on the same machine, which may incur the risk of
+            CPU OOM. It is recommended to use ``offload_to_cpu`` with
+            ``rank0_only=True``.
+
+        Args:
+            recurse (bool, Optional): recursively summon all params for nested
+                FSDP instances (default: True).
+            writeback (bool, Optional): if ``False``, modifications to params are
+                discarded after the context manager exits;
+                disabling this can be slightly more efficient (default: True)
+            rank0_only (bool, Optional): if ``True``, full parameters are
+                materialized on only global rank 0. This means that within the
+                context, only rank 0 will have full parameters and the other
+                ranks will have sharded parameters. Note that setting
+                ``rank0_only=True`` with ``writeback=True`` is not supported,
+                as model parameter shapes will be different across ranks
+                within the context, and writing to them can lead to
+                inconsistency across ranks when the context is exited.
+            offload_to_cpu (bool, Optional): If ``True``, full parameters are
+                offloaded to CPU. Note that this offloading currently only
+                occurs if the parameter is sharded (which is not the case only
+                for ``world_size == 1`` or ``NO_SHARD`` config). It is recommended
+                to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid
+                redundant copies of model parameters being offloaded to the same CPU memory.
+            with_grads (bool, Optional): If ``True``, gradients are also
+                unsharded with the parameters. Currently, this is only
+                supported when passing ``use_orig_params=True`` to the FSDP
+                constructor and ``offload_to_cpu=False`` to this method.
+                (Default: ``False``)
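+
+        Example (illustrative sketch only; ``fsdp_model`` is a placeholder)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> with FSDP.summon_full_params(fsdp_model, writeback=False):
+            >>>     # Unsharded parameters are visible here for inspection;
+            >>>     # changes are discarded because writeback=False
+            >>>     total_numel = sum(p.numel() for p in fsdp_model.parameters())
+            >>> # Parameters revert to their local shards on exit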
+        """
+        with _unshard_params(
+            module, recurse, writeback, rank0_only, offload_to_cpu, with_grads
+        ):
+            yield
+
+    @contextlib.contextmanager
+    def _deregister_orig_params_ctx(self):
+        """Deregister the original parameters and expose the :class:`FlatParameter`.
+
+        If a :class:`FlatParameter` is sharded, then
+        this refreshes the sharded views before exiting. This method should
+        only be called when using the original parameters.
+        """
+        _p_assert(
+            self._use_orig_params,
+            "`_deregister_orig_params_ctx()` should only be called when "
+            "`_use_orig_params=True`",
+        )
+        for fsdp_module in traversal_utils._get_fsdp_states(self):
+            _deregister_orig_params(fsdp_module, fsdp_module)
+        try:
+            yield
+        finally:
+            for fsdp_module in traversal_utils._get_fsdp_states(self):
+                _register_orig_params(fsdp_module, fsdp_module)
+
+    def _apply(self, *args, **kwargs):
+        """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
+        # When using the original parameters: Since (1) the `FlatParameter`s
+        # own the storage and (2) `_apply()` is the subroutine underlying the
+        # most common storage-changing ops like `to()` and `cuda()`, we
+        # override `_apply()` to have the storage change directly performed on
+        # the `FlatParameter`s instead of applying to the original parameters
+        # and then writing back to the `FlatParameter`s.
+        context = (
+            self._deregister_orig_params_ctx()
+            if self._use_orig_params
+            else contextlib.nullcontext()
+        )
+        with context:
+            return super()._apply(*args, **kwargs)
+
+    def named_buffers(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
+
+        Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
+        when inside the :meth:`summon_full_params` context manager.
+        """
+        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
+        for buffer_name, buffer in super().named_buffers(*args, **kwargs):
+            if should_clean_name:
+                # Remove any instances of the FSDP-specific prefix; there can
+                # be multiple in the case of nested FSDP modules
+                buffer_name = buffer_name.replace(FSDP_PREFIX, "")
+            yield (buffer_name, buffer)
+
+    def named_parameters(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
+
+        Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
+        when inside the :meth:`summon_full_params` context manager.
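+
+        Example (illustrative sketch only; ``fsdp_model`` is a placeholder)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> with FSDP.summon_full_params(fsdp_model):
+            >>>     # Names are yielded without the FSDP-specific prefix here
+            >>>     clean_names = [name for name, _ in fsdp_model.named_parameters()]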
+        """
+        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
+        for param_name, param in super().named_parameters(*args, **kwargs):
+            if should_clean_name:
+                # Remove any instances of the FSDP-specific prefix; there can
+                # be multiple in the case of nested FSDP modules
+                param_name = param_name.replace(FSDP_PREFIX, "")
+            yield (param_name, param)
+
+    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
+        """Assert we are in the given state."""
+        # Since assert can be turned off and this error checking
+        # is really important, we use explicit error checking
+        # and raise a ValueError if needed.
+        if isinstance(state, TrainingState):
+            state = [state]
+        if self.training_state not in state:
+            msg = (
+                f"expected to be in states {state} but current state "
+                f"is {self.training_state}"
+            )
+            # In case we are failing in the context of autograd hook, asserting
+            # may not generate useful msg. So, let's print it to be sure.
+            if self.rank == 0:
+                print(f"Asserting FSDP instance is: {self}")
+                print(f"ERROR: {msg}")
+                traceback.print_stack()
+            raise ValueError(msg)
+
+    @contextmanager
+    def no_sync(self) -> Generator:
+        """Disable gradient synchronizations across FSDP instances.
+
+        Within this context, gradients will be accumulated in module
+        variables, which will later be synchronized in the first
+        forward-backward pass after exiting the context. This should only be
+        used on the root FSDP instance and will recursively apply to all
+        children FSDP instances.
+
+        .. note:: This likely results in higher memory usage because FSDP will
+            accumulate the full model gradients (instead of gradient shards)
+            until the eventual sync.
+
+        .. note:: When used with CPU offloading, the gradients will not be
+            offloaded to CPU when inside the context manager. Instead, they
+            will only be offloaded right after the eventual sync.
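+
+        Example (gradient accumulation sketch; ``fsdp_model``, ``optim``, and
+        ``batches`` are placeholders)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> # Accumulate unsynced gradients for all but the last micro-batch
+            >>> with fsdp_model.no_sync():
+            >>>     for batch in batches[:-1]:
+            >>>         fsdp_model(batch).sum().backward()
+            >>> # The backward outside the context triggers gradient synchronization
+            >>> fsdp_model(batches[-1]).sum().backward()
+            >>> optim.step()
+            >>> optim.zero_grad()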
+        """
+        _lazy_init(self, self)
+        if not self._is_root:
+            raise RuntimeError(
+                "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module."
+            )
+        self._assert_state(TrainingState.IDLE)
+        old_flags = []
+        for m in self.modules():
+            if isinstance(m, FullyShardedDataParallel):
+                old_flags.append((m, m._sync_gradients))
+                m._sync_gradients = False
+        try:
+            yield
+        finally:
+            for m, old_flag in old_flags:
+                assert not m._sync_gradients, (
+                    "`_sync_gradients` was incorrectly set to "
+                    "`True` while in the `no_sync()` context manager"
+                )
+                m._sync_gradients = old_flag
+
+    @torch.no_grad()
+    def clip_grad_norm_(
+        self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
+    ) -> torch.Tensor:
+        """Clip the gradient norm of all parameters.
+
+        The norm is computed over all parameters' gradients as viewed as a single vector, and the
+        gradients are modified in-place.
+
+        Args:
+            max_norm (float or int): max norm of the gradients
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
+                for infinity norm.
+
+        Returns:
+            Total norm of the parameters (viewed as a single vector).
+
+        .. note:: If every FSDP instance uses ``NO_SHARD``, meaning that no
+            gradients are sharded across ranks, then you may directly use
+            :func:`torch.nn.utils.clip_grad_norm_`.
+
+        .. note:: If at least some FSDP instance uses a sharded strategy (i.e.
+            one other than ``NO_SHARD``), then you should use this method
+            instead of :func:`torch.nn.utils.clip_grad_norm_` since this method
+            handles the fact that gradients are sharded across ranks.
+
+        .. note:: The total norm returned will have the "largest" dtype across
+            all parameters/gradients as defined by PyTorch's type promotion
+            semantics. For example, if *all* parameters/gradients use a low
+            precision dtype, then the returned norm's dtype will be that low
+            precision dtype, but if there exists at least one parameter/
+            gradient using FP32, then the returned norm's dtype will be FP32.
+
+        .. warning:: This needs to be called on all ranks since it uses
+            collective communications.
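+
+        Example (illustrative training-loop sketch; ``fsdp_model``, ``optim``,
+        and ``loss`` are placeholders)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> loss.backward()
+            >>> # Must run on all ranks; sharded gradients are handled internally
+            >>> total_norm = fsdp_model.clip_grad_norm_(max_norm=1.0)
+            >>> optim.step()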
+        """
+        _lazy_init(self, self)
+        if not self._is_root:
+            raise RuntimeError(
+                "`clip_grad_norm_()` should only be called on the root FSDP instance"
+            )
+        self._assert_state(TrainingState.IDLE)
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `nn.utils` one targeting local gradients
+        all_no_shard = all(
+            not handle.uses_sharded_strategy for handle in self._all_handles
+        )
+        if all_no_shard:
+            return torch.nn.utils.clip_grad_norm_(
+                self.parameters(), max_norm, norm_type
+            )
+        # Otherwise, there exists some FSDP instance using a sharded strategy,
+        # where sharded and non-sharded parameters must be handled separately
+        max_norm = float(max_norm)
+        norm_type = float(norm_type)
+        sharded_params = set()
+        nonsharded_params = set()  # `NO_SHARD` or not FSDP-managed
+        grads: List[torch.Tensor] = []
+        for handle in self._all_handles:
+            target_set = (
+                sharded_params if handle.uses_sharded_strategy else nonsharded_params
+            )
+            if handle._use_orig_params:
+                for param in handle.flat_param._params:
+                    target_set.add(param)
+                    if param.grad is not None:
+                        grads.append(param.grad)
+            else:
+                target_set.add(handle.flat_param)
+                if handle.flat_param.grad is not None:
+                    grads.append(handle.flat_param.grad)
+        for param in self.parameters():
+            not_fsdp_managed = (
+                param not in sharded_params and param not in nonsharded_params
+            )
+            if not_fsdp_managed:
+                nonsharded_params.add(param)
+                if param.grad is not None:
+                    grads.append(param.grad)
+        # Compute local norms (forced to be in FP32)
+        local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(
+            self.compute_device
+        )
+        local_nonsharded_norm = _get_grad_norm(nonsharded_params, norm_type).to(
+            self.compute_device
+        )
+        # Reconstruct the total gradient norm depending on the norm type
+        if norm_type == math.inf:
+            total_norm = torch.maximum(local_sharded_norm, local_nonsharded_norm)
+            dist.all_reduce(
+                total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
+            )
+        else:
+            total_norm = local_sharded_norm**norm_type
+            dist.all_reduce(total_norm, group=self.process_group)
+            # All-reducing the local non-sharded norm would count it an extra
+            # world-size-many times
+            total_norm += local_nonsharded_norm**norm_type
+            total_norm = total_norm ** (1.0 / norm_type)
+        if self.cpu_offload.offload_params:
+            total_norm = total_norm.cpu()
+
+        clip_coef = max_norm / (total_norm + 1e-6)
+        # Multiplying by the clamped coefficient is meaningless when it is
+        # equal to 1, but it avoids the host-device sync that would result from
+        # `if clip_coef < 1`
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+        for grad in grads:
+            grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
+        # Use the "largest" dtype by type promotion semantics to use the same
+        # dtype as if we did not force local norm computation to be in FP32
+        if len(grads) == 0:
+            # If this rank has no gradients, then we must default to FP32
+            # unless we use additional communication, which we prefer to avoid
+            # since `clip_grad_norm_()` is called in the training loop
+            warnings.warn(
+                f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
+                "gradients -- returning the total norm in the default dtype "
+                f"{total_norm.dtype}"
+            )  # warn since this is generally unexpected
+            return total_norm
+        total_norm_dtype = functools.reduce(
+            torch.promote_types,
+            [grad.dtype for grad in grads],
+        )
+        return total_norm.to(total_norm_dtype)
+
+    @staticmethod
+    def _warn_optim_input(optim_input):
+        if optim_input is not None:
+            warnings.warn(
+                "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it "
+                "from your code without changing its functionality."
+            )
+
+    @staticmethod
+    def _is_using_optim_input(optim_input, optim) -> bool:
+        if optim_input is None and optim is None:
+            # Use the default behavior of `optim_input`
+            return True
+        if optim_input is not None:
+            # Use the `optim_input` code path
+            return True
+        # Use the `optim` code path
+        return False
+
+    @staticmethod
+    def _warn_legacy_optim_state_dict(curr: str, new: str):
+        warnings.warn(
+            f"``FullyShardedDataParallel.{curr}``is being deprecated and is "
+            f"replaced by ``FullyShardedDataParallel.{new}``. "
+            f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2."
+        )
+
+    @staticmethod
+    def _optim_state_dict_impl(
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        optim_state_dict: Dict[str, Any],
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        rank0_only: bool = True,
+        full_state_dict: bool = True,
+        group: Optional[dist.ProcessGroup] = None,
+        cpu_offload: bool = True,
+    ) -> Dict[str, Any]:
+        """Transform the state-dict of an optimizer corresponding to a sharded model.
+
+        This is the internal API that is used by all the optim_state_dict implementations.
+        Given model, optim, the original optim_state_dict, this API removes the
+        FSDP internal information and internal sharding from the optim_state_dict.
+        """
+        if full_state_dict:
+            FullyShardedDataParallel._warn_optim_input(optim_input)
+            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
+                optim_input,
+                optim,
+            )
+        else:
+            using_optim_input = False
+            assert optim_input is None and not rank0_only
+
+        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
+            0
+        ]._use_orig_params
+        assert all(
+            use_orig_params == m._use_orig_params
+            for m in FullyShardedDataParallel.fsdp_modules(model)
+        ), "Not all FSDP modules have the same _use_orig_params value"
+
+        return _optim_state_dict(
+            model=model,
+            optim=optim,
+            optim_state_dict=optim_state_dict,
+            optim_input=optim_input,
+            rank0_only=rank0_only,
+            shard_state=not full_state_dict,
+            group=group,
+            using_optim_input=using_optim_input,
+            use_orig_params=use_orig_params,
+            cpu_offload=cpu_offload,
+        )
+
+    @staticmethod
+    def _optim_state_dict_to_load_impl(
+        optim_state_dict: Dict[str, Any],
+        model: torch.nn.Module,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        optim: Optional[torch.optim.Optimizer] = None,
+        full_state_dict: bool = True,
+        rank0_only: bool = False,
+        is_named_optimizer: bool = False,
+        group: Optional[dist.ProcessGroup] = None,
+    ) -> Dict[str, Any]:
+        """
+        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
+
+        This is the internal API that is used by all the load optim_state_dict implementations.
+        Given model, optim, and the saved optim_state_dict, this API adds the FSDP
+        internal information and internal sharding to the optim_state_dict.
+        """
+        if full_state_dict:
+            FullyShardedDataParallel._warn_optim_input(optim_input)
+            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
+                optim_input,
+                optim,
+            )
+        else:
+            using_optim_input = False
+            assert optim_input is None and not rank0_only
+
+        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
+            0
+        ]._use_orig_params
+        assert all(
+            use_orig_params == m._use_orig_params
+            for m in FullyShardedDataParallel.fsdp_modules(model)
+        ), "Not all FSDP modules have the same _use_orig_params value"
+
+        if rank0_only and dist.get_rank(group) > 0:
+            optim_state_dict = {}
+        sharded_osd = _flatten_optim_state_dict(
+            optim_state_dict,
+            model=model,
+            use_orig_params=use_orig_params,
+            optim=(optim if is_named_optimizer else None),
+            rank0_only=rank0_only,
+            group=group,
+        )
+        return _rekey_sharded_optim_state_dict(
+            sharded_osd,
+            model=model,
+            optim=optim,
+            optim_input=optim_input,
+            using_optim_input=using_optim_input,
+            is_named_optimizer=is_named_optimizer,
+        )
+
+    @staticmethod
+    def full_optim_state_dict(
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        rank0_only: bool = True,
+        group: Optional[dist.ProcessGroup] = None,
+    ) -> Dict[str, Any]:
+        """Return the full optimizer state-dict.
+
+        Consolidates the full optimizer state on rank 0 and returns it
+        as a :class:`dict` following the convention of
+        :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
+        and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
+        contained in ``model`` are mapped back to their unflattened parameters.
+
+        .. warning:: This needs to be called on all ranks since it uses
+            collective communications. However, if ``rank0_only=True``, then
+            the state dict is only populated on rank 0, and all other ranks
+            return an empty :class:`dict`.
+
+        .. warning:: Unlike ``torch.optim.Optimizer.state_dict()``, this method
+            uses full parameter names as keys instead of parameter IDs.
+
+        .. note:: Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors
+            contained in the optimizer state dict are not cloned, so there may
+            be aliasing surprises. For best practices, consider saving the
+            returned optimizer state dict immediately, e.g. using
+            ``torch.save()``.
+
+        Args:
+            model (torch.nn.Module): Root module (which may or may not be a
+                :class:`FullyShardedDataParallel` instance) whose parameters
+                were passed into the optimizer ``optim``.
+            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+                parameters.
+            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
+                Input passed into the optimizer ``optim`` representing either a
+                :class:`list` of parameter groups or an iterable of parameters;
+                if ``None``, then this method assumes the input was
+                ``model.parameters()``. This argument is deprecated, and there
+                is no need to pass it in anymore. (Default: ``None``)
+            rank0_only (bool): If ``True``, saves the populated :class:`dict`
+                only on rank 0; if ``False``, saves it on all ranks. (Default:
+                ``True``)
+            group (dist.ProcessGroup): Model's process group or ``None`` if using
+                the default process group. (Default: ``None``)
+
+        Returns:
+            Dict[str, Any]: A :class:`dict` containing the optimizer state for
+            ``model`` 's original unflattened parameters and including keys
+            "state" and "param_groups" following the convention of
+            :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
+            then nonzero ranks return an empty :class:`dict`.
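+
+        Example (illustrative sketch only; ``model``, ``optim``, and ``PATH``
+        are placeholders)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> full_osd = FSDP.full_optim_state_dict(model, optim)
+            >>> # With the default rank0_only=True, only rank 0 holds the dict
+            >>> if dist.get_rank() == 0:
+            >>>     torch.save(full_osd, PATH)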
+        """
+        FullyShardedDataParallel._warn_legacy_optim_state_dict(
+            "full_optim_state_dict", "optim_state_dict"
+        )
+        return FullyShardedDataParallel._optim_state_dict_impl(
+            model=model,
+            optim=optim,
+            optim_state_dict=optim.state_dict(),
+            optim_input=optim_input,
+            rank0_only=rank0_only,
+            group=group,
+            full_state_dict=True,
+        )
+
+    @staticmethod
+    def sharded_optim_state_dict(
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        group: Optional[dist.ProcessGroup] = None,
+    ) -> Dict[str, Any]:
+        """Return the optimizer state-dict in its sharded form.
+
+        The API is similar to :meth:`full_optim_state_dict` but this API chunks
+        all non-zero-dimension states to :class:`ShardedTensor` to save memory.
+        This API should only be used when the model ``state_dict`` is derived
+        with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``.
+
+        For the detailed usage, refer to :meth:`full_optim_state_dict`.
+
+        .. warning:: The returned state dict contains ``ShardedTensor`` and
+            cannot be directly used by the regular ``optim.load_state_dict``.
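+
+        Example (illustrative sketch only; ``model`` and ``optim`` are
+        placeholders)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+            >>>     model_state = model.state_dict()
+            >>>     optim_state = FSDP.sharded_optim_state_dict(model, optim)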
+        """
+        FullyShardedDataParallel._warn_legacy_optim_state_dict(
+            "sharded_optim_state_dict", "optim_state_dict"
+        )
+        return FullyShardedDataParallel._optim_state_dict_impl(
+            model=model,
+            optim=optim,
+            optim_state_dict=optim.state_dict(),
+            optim_input=None,
+            rank0_only=False,
+            full_state_dict=False,
+            group=group,
+        )
+
+    @staticmethod
+    def shard_full_optim_state_dict(
+        full_optim_state_dict: Dict[str, Any],
+        model: torch.nn.Module,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        optim: Optional[torch.optim.Optimizer] = None,
+    ) -> Dict[str, Any]:
+        """Shard a full optimizer state-dict.
+
+        Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
+        parameters and restricts to only this rank's part of the optimizer state.
+        The first argument should be the return value of :meth:`full_optim_state_dict`.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> model, optim = ...
+            >>> full_osd = FSDP.full_optim_state_dict(model, optim)
+            >>> torch.save(full_osd, PATH)
+            >>> # Define new model with possibly different world size
+            >>> new_model, new_optim = ...
+            >>> full_osd = torch.load(PATH)
+            >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
+            >>> new_optim.load_state_dict(sharded_osd)
+
+        .. note:: Both :meth:`shard_full_optim_state_dict` and
+            :meth:`scatter_full_optim_state_dict` may be used to get the
+            sharded optimizer state dict to load. Assuming that the full
+            optimizer state dict resides in CPU memory, the former requires
+            each rank to have the full dict in CPU memory, where each rank
+            individually shards the dict without any communication, while the
+            latter requires only rank 0 to have the full dict in CPU memory,
+            where rank 0 moves each shard to GPU memory (for NCCL) and
+            communicates it to ranks appropriately. Hence, the former has
+            higher aggregate CPU memory cost, while the latter has higher
+            communication cost.
+
+        Args:
+            full_optim_state_dict (Dict[str, Any]): Optimizer state dict
+                corresponding to the unflattened parameters and holding the
+                full non-sharded optimizer state.
+            model (torch.nn.Module): Root module (which may or may not be a
+                :class:`FullyShardedDataParallel` instance) whose parameters
+                correspond to the optimizer state in ``full_optim_state_dict``.
+            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
+                Input passed into the optimizer representing either a
+                :class:`list` of parameter groups or an iterable of parameters;
+                if ``None``, then this method assumes the input was
+                ``model.parameters()``. This argument is deprecated, and there
+                is no need to pass it in anymore. (Default: ``None``)
+            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
+                the state dict returned by this method. This is the preferred
+                argument to use over ``optim_input``. (Default: ``None``)
+
+        Returns:
+            Dict[str, Any]: The full optimizer state dict now remapped to
+            flattened parameters instead of unflattened parameters and
+            restricted to only include this rank's part of the optimizer state.
+        """
+        FullyShardedDataParallel._warn_legacy_optim_state_dict(
+            "shard_full_optim_state_dict", "optim_state_dict_to_load"
+        )
+        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
+            optim_state_dict=full_optim_state_dict,
+            model=model,
+            optim_input=optim_input,
+            optim=optim,
+            full_state_dict=True,
+            is_named_optimizer=False,
+        )
+
+    @staticmethod
+    def flatten_sharded_optim_state_dict(
+        sharded_optim_state_dict: Dict[str, Any],
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+    ) -> Dict[str, Any]:
+        """Flatten a sharded optimizer state-dict.
+
+        The API is similar to :meth:`shard_full_optim_state_dict`. The only
+        difference is that the input ``sharded_optim_state_dict`` should be
+        returned from :meth:`sharded_optim_state_dict`. Therefore, there will
+        be all-gather calls on each rank to gather ``ShardedTensor`` s.
+
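+        Example (a minimal illustrative sketch, not part of the original docs;
+        ``model`` and ``optim`` are assumed to be an FSDP-wrapped module and its
+        optimizer constructed elsewhere)::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> from torch.distributed.fsdp import StateDictType
+            >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+            >>>     sharded_osd = FSDP.sharded_optim_state_dict(model, optim)
+            >>> # ... save and reload the sharded dict with a distributed checkpointing tool ...
+            >>> flat_osd = FSDP.flatten_sharded_optim_state_dict(sharded_osd, model, optim)
+            >>> optim.load_state_dict(flat_osd)
+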
+        Args:
+            sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
+                corresponding to the unflattened parameters and holding the
+                sharded optimizer state.
+            model (torch.nn.Module):
+                Refer to :meth:`shard_full_optim_state_dict`.
+            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+                parameters.
+
+        Returns:
+            Refer to :meth:`shard_full_optim_state_dict`.
+        """
+        FullyShardedDataParallel._warn_legacy_optim_state_dict(
+            "flatten_sharded_optim_state_dict", "optim_state_dict_to_load"
+        )
+        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
+            optim_state_dict=sharded_optim_state_dict,
+            model=model,
+            optim_input=None,
+            optim=optim,
+            full_state_dict=False,
+            is_named_optimizer=False,
+        )
+
+    @staticmethod
+    def scatter_full_optim_state_dict(
+        full_optim_state_dict: Optional[Dict[str, Any]],
+        model: torch.nn.Module,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        optim: Optional[torch.optim.Optimizer] = None,
+        group: Optional[Any] = None,
+    ) -> Dict[str, Any]:
+        """Scatter the full optimizer state dict from rank 0 to all other ranks.
+
+        Returns the sharded optimizer state dict on each rank.
+        The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
+        0, the first argument should be the return value of
+        :meth:`full_optim_state_dict`.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> model, optim = ...
+            >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
+            >>> # Define new model with possibly different world size
+            >>> new_model, new_optim, new_group = ...
+            >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
+            >>> new_optim.load_state_dict(sharded_osd)
+
+        .. note:: Both :meth:`shard_full_optim_state_dict` and
+            :meth:`scatter_full_optim_state_dict` may be used to get the
+            sharded optimizer state dict to load. Assuming that the full
+            optimizer state dict resides in CPU memory, the former requires
+            each rank to have the full dict in CPU memory, where each rank
+            individually shards the dict without any communication, while the
+            latter requires only rank 0 to have the full dict in CPU memory,
+            where rank 0 moves each shard to GPU memory (for NCCL) and
+            communicates it to ranks appropriately. Hence, the former has
+            higher aggregate CPU memory cost, while the latter has higher
+            communication cost.
+
+        Args:
+            full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
+                dict corresponding to the unflattened parameters and holding
+                the full non-sharded optimizer state if on rank 0; the argument
+                is ignored on nonzero ranks.
+            model (torch.nn.Module): Root module (which may or may not be a
+                :class:`FullyShardedDataParallel` instance) whose parameters
+                correspond to the optimizer state in ``full_optim_state_dict``.
+            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
+                Input passed into the optimizer representing either a
+                :class:`list` of parameter groups or an iterable of parameters;
+                if ``None``, then this method assumes the input was
+                ``model.parameters()``. This argument is deprecated, and there
+                is no need to pass it in anymore. (Default: ``None``)
+            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
+                the state dict returned by this method. This is the preferred
+                argument to use over ``optim_input``. (Default: ``None``)
+            group (dist.ProcessGroup): Model's process group or ``None`` if
+                using the default process group. (Default: ``None``)
+
+        Returns:
+            Dict[str, Any]: The full optimizer state dict now remapped to
+            flattened parameters instead of unflattened parameters and
+            restricted to only include this rank's part of the optimizer state.
+        """
+        FullyShardedDataParallel._warn_legacy_optim_state_dict(
+            "scatter_full_optim_state_dict", "optim_state_dict_to_load"
+        )
+        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
+            optim_state_dict=full_optim_state_dict,
+            model=model,
+            optim_input=optim_input,
+            optim=optim,
+            full_state_dict=True,
+            rank0_only=True,
+            is_named_optimizer=False,
+            group=group,
+        )
+
+    @staticmethod
+    def rekey_optim_state_dict(
+        optim_state_dict: Dict[str, Any],
+        optim_state_key_type: OptimStateKeyType,
+        model: torch.nn.Module,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
+        optim: Optional[torch.optim.Optimizer] = None,
+    ) -> Dict[str, Any]:
+        """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
+
+        This can be used to achieve compatibility between optimizer state dicts from models with FSDP
+        instances and ones without.
+
+        To re-key an FSDP full optimizer state dict (i.e. from
+        :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
+        a non-wrapped model::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> wrapped_model, wrapped_optim = ...
+            >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
+            >>> nonwrapped_model, nonwrapped_optim = ...
+            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
+            >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
+
+        To re-key a normal optimizer state dict from a non-wrapped model to be
+        loadable to a wrapped model::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> nonwrapped_model, nonwrapped_optim = ...
+            >>> osd = nonwrapped_optim.state_dict()
+            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
+            >>> wrapped_model, wrapped_optim = ...
+            >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
+            >>> wrapped_optim.load_state_dict(sharded_osd)
+
+        Returns:
+            Dict[str, Any]: The optimizer state dict re-keyed using the
+            parameter keys specified by ``optim_state_key_type``.
+        """
+        FullyShardedDataParallel._warn_optim_input(optim_input)
+        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
+            optim_input,
+            optim,
+        )
+        assert optim_state_key_type in (
+            OptimStateKeyType.PARAM_NAME,
+            OptimStateKeyType.PARAM_ID,
+        )
+        osd = optim_state_dict  # alias
+        # Validate that the existing parameter keys are uniformly typed
+        uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
+        uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
+        if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
+            any(uses_param_id_mask) and not all(uses_param_id_mask)
+        ):
+            error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
+            raise ValueError(error_msg)
+        # Return directly if the existing key type matches the target key type
+        if (
+            optim_state_key_type == OptimStateKeyType.PARAM_NAME
+            and all(uses_param_name_mask)
+        ) or (
+            optim_state_key_type == OptimStateKeyType.PARAM_ID
+            and all(uses_param_id_mask)
+        ):
+            return osd
+        # Otherwise, actually perform the re-keying
+        new_osd = {}
+        if optim_state_key_type == OptimStateKeyType.PARAM_NAME:  # ID -> name
+            param_id_to_param = (
+                _get_param_id_to_param_from_optim_input(model, optim_input)
+                if using_optim_input
+                else _get_param_key_to_param(optim)
+            )
+            param_to_param_name = _get_param_to_fqn(model)
+            param_id_to_param_name: List[str] = [
+                param_to_param_name[param] for param in param_id_to_param.values()
+            ]
+            new_osd["state"] = {
+                param_id_to_param_name[param_id]: param_state
+                for param_id, param_state in osd["state"].items()
+            }
+            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
+            for param_group in new_osd["param_groups"]:
+                param_group["params"] = sorted(
+                    [
+                        param_id_to_param_name[param_id]
+                        for param_id in param_group["params"]
+                    ]
+                )
+            return new_osd
+        elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
+            param_name_to_param = _get_fqn_to_param(model)
+            param_to_param_id = (
+                _get_param_to_param_id_from_optim_input(model, optim_input)
+                if using_optim_input
+                else _get_param_to_param_key(optim)
+            )
+            # Because not all model parameters may be passed as the optimizer
+            # input, we may need to drop some parameters from this mapping
+            param_name_to_param_id = {
+                param_name: param_to_param_id[param]
+                for param_name, param in param_name_to_param.items()
+                if param in param_to_param_id
+            }
+            new_osd["state"] = {
+                param_name_to_param_id[param_name]: param_state
+                for param_name, param_state in osd["state"].items()
+            }
+            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
+            for param_group in new_osd["param_groups"]:
+                param_group["params"] = sorted(
+                    [
+                        param_name_to_param_id[param_name]
+                        for param_name in param_group["params"]
+                    ]
+                )
+            return new_osd
+        return new_osd  # should never reach here
+
+    @staticmethod
+    def optim_state_dict(
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        optim_state_dict: Optional[Dict[str, Any]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+    ) -> Dict[str, Any]:
+        """
+        Transform the state-dict of an optimizer corresponding to a sharded model.
+
+        The given state-dict can be transformed to one of three types:
+        1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
+
+        For full optimizer state_dict, all states are unflattened and not sharded.
+        Rank0 only and CPU only can be specified via :meth:`state_dict_type` to
+        avoid OOM.
+
+        For sharded optimizer state_dict, all states are unflattened but sharded.
+        CPU only can be specified via :meth:`state_dict_type` to further save
+        memory.
+
+        For the local state_dict, no transformation will be performed, but each
+        state will be converted from ``torch.Tensor`` to ``ShardedTensor`` to
+        represent its sharded nature (this is not supported yet).
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> from torch.distributed.fsdp import StateDictType
+            >>> from torch.distributed.fsdp import FullStateDictConfig
+            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
+            >>> # Save a checkpoint
+            >>> model, optim = ...
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.FULL_STATE_DICT,
+            >>>     FullStateDictConfig(rank0_only=False),
+            >>>     FullOptimStateDictConfig(rank0_only=False),
+            >>> )
+            >>> state_dict = model.state_dict()
+            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
+            >>> save_a_checkpoint(state_dict, optim_state_dict)
+            >>> # Load a checkpoint
+            >>> model, optim = ...
+            >>> state_dict, optim_state_dict = load_a_checkpoint()
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.FULL_STATE_DICT,
+            >>>     FullStateDictConfig(rank0_only=False),
+            >>>     FullOptimStateDictConfig(rank0_only=False),
+            >>> )
+            >>> model.load_state_dict(state_dict)
+            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
+            >>>     model, optim, optim_state_dict
+            >>> )
+            >>> optim.load_state_dict(optim_state_dict)
+
+        Args:
+            model (torch.nn.Module): Root module (which may or may not be a
+                :class:`FullyShardedDataParallel` instance) whose parameters
+                were passed into the optimizer ``optim``.
+            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+                parameters.
+            optim_state_dict (Dict[str, Any]): the target optimizer state_dict to
+                transform. If the value is None, optim.state_dict() will be used. (
+                Default: ``None``)
+            group (dist.ProcessGroup): Model's process group across which parameters
+                are sharded or ``None`` if using the default process group. (
+                Default: ``None``)
+
+        Returns:
+            Dict[str, Any]: A :class:`dict` containing the optimizer state for
+            ``model``. The sharding of the optimizer state is based on
+            ``state_dict_type``.
+        """
+        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
+        if optim_state_dict is None:
+            optim_state_dict = optim.state_dict()
+        return FullyShardedDataParallel._optim_state_dict_impl(
+            model=model,
+            optim=optim,
+            optim_state_dict=optim_state_dict,
+            optim_input=None,
+            rank0_only=getattr(
+                state_dict_settings.optim_state_dict_config, "rank0_only", False
+            ),
+            full_state_dict=state_dict_settings.state_dict_type
+            == StateDictType.FULL_STATE_DICT,
+            group=group,
+            cpu_offload=getattr(
+                state_dict_settings.optim_state_dict_config, "offload_to_cpu", True
+            ),
+        )
+
+    @staticmethod
+    def optim_state_dict_to_load(
+        model: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        optim_state_dict: Dict[str, Any],
+        is_named_optimizer: bool = False,
+        load_directly: bool = False,
+        group: Optional[dist.ProcessGroup] = None,
+    ) -> Dict[str, Any]:
+        """
+        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
+
+        Given an ``optim_state_dict`` that is transformed through
+        :meth:`optim_state_dict`, it gets converted to the flattened optimizer
+        state_dict that can be loaded to ``optim``, which is the optimizer for
+        ``model``. ``model`` must be sharded by FullyShardedDataParallel.
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            >>> from torch.distributed.fsdp import StateDictType
+            >>> from torch.distributed.fsdp import FullStateDictConfig
+            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
+            >>> # Save a checkpoint
+            >>> model, optim = ...
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.FULL_STATE_DICT,
+            >>>     FullStateDictConfig(rank0_only=False),
+            >>>     FullOptimStateDictConfig(rank0_only=False),
+            >>> )
+            >>> state_dict = model.state_dict()
+            >>> original_osd = optim.state_dict()
+            >>> optim_state_dict = FSDP.optim_state_dict(
+            >>>     model,
+            >>>     optim,
+            >>>     optim_state_dict=original_osd
+            >>> )
+            >>> save_a_checkpoint(state_dict, optim_state_dict)
+            >>> # Load a checkpoint
+            >>> model, optim = ...
+            >>> state_dict, optim_state_dict = load_a_checkpoint()
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.FULL_STATE_DICT,
+            >>>     FullStateDictConfig(rank0_only=False),
+            >>>     FullOptimStateDictConfig(rank0_only=False),
+            >>> )
+            >>> model.load_state_dict(state_dict)
+            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
+            >>>     model, optim, optim_state_dict
+            >>> )
+            >>> optim.load_state_dict(optim_state_dict)
+
+        Args:
+            model (torch.nn.Module): Root module (which may or may not be a
+                :class:`FullyShardedDataParallel` instance) whose parameters
+                were passed into the optimizer ``optim``.
+            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
+                parameters.
+            optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
+            is_named_optimizer (bool): Whether this optimizer is a NamedOptimizer or
+                KeyedOptimizer. Set this to ``True`` only if ``optim`` is TorchRec's
+                KeyedOptimizer or torch.distributed's NamedOptimizer.
+            load_directly (bool): If ``True``, this API also calls
+                ``optim.load_state_dict(result)`` before returning the result.
+                Otherwise, users are responsible for calling ``optim.load_state_dict()``.
+                (Default: ``False``)
+            group (dist.ProcessGroup): Model's process group across which parameters
+                are sharded or ``None`` if using the default process group. (
+                Default: ``None``)
+        """
+        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
+        result = FullyShardedDataParallel._optim_state_dict_to_load_impl(
+            optim_state_dict=optim_state_dict,
+            model=model,
+            optim_input=None,
+            optim=optim,
+            full_state_dict=(
+                state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT
+            ),
+            rank0_only=getattr(
+                state_dict_settings.optim_state_dict_config, "rank0_only", False
+            ),
+            is_named_optimizer=is_named_optimizer,
+            group=group,
+        )
+        if load_directly:
+            optim.load_state_dict(result)
+        return result
+
+    def register_comm_hook(self, state: object, hook: callable):
+        """Register a communication hook.
+
+        This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates
+        gradients across multiple workers.
+        This hook can be used to implement several algorithms like
+        GossipGrad and gradient compression,
+        which involve different communication strategies for
+        parameter syncs while training with :class:`FullyShardedDataParallel`.
+
+        .. warning::
+            The FSDP communication hook should be registered only once, and it must be
+            registered before running the initial forward pass.
+
+        Args:
+            state (object): Passed to the hook to maintain any state information during the training process.
+                            Examples include error feedback in gradient compression,
+                            peers to communicate with next in GossipGrad, etc.
+                            It is locally stored by each worker
+                            and shared by all the gradient tensors on the worker.
+            hook (Callable): Callable, which has one of the following signatures:
+                            1) ``hook: Callable[torch.Tensor] -> None``:
+                            This function takes in a Python tensor, which represents
+                            the full, flattened, unsharded gradient with respect to all variables
+                            corresponding to the model this FSDP unit is wrapping
+                            (that are not wrapped by other FSDP sub-units).
+                            It then performs all necessary processing and returns ``None``;
+                            2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
+                            This function takes in two Python tensors, the first one represents
+                            the full, flattened, unsharded gradient with respect to all variables
+                            corresponding to the model this FSDP unit is wrapping
+                            (that are not wrapped by other FSDP sub-units). The latter
+                            represents a pre-sized tensor to store a chunk of a sharded gradient after
+                            reduction.
+                            In both cases, the callable performs all necessary processing and returns ``None``.
+                            Callables with signature 1 are expected to handle gradient communication for the ``NO_SHARD`` case.
+                            Callables with signature 2 are expected to handle gradient communication for sharded cases.
+
+        """
+        if not self.check_is_root():
+            raise AssertionError(
+                "register_comm_hook can only be called on a root instance."
+            )
+        for fsdp_state in traversal_utils._get_fsdp_states(self):
+            if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+                raise AssertionError(
+                    f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
+                )
+            if fsdp_state._comm_hook is not None:
+                raise AssertionError("A communication hook is already registered")
+            if not callable(hook):
+                raise ValueError(
+                    f"The communication hook must be callable but got {hook}"
+                )
+            fsdp_state._comm_hook = hook
+            fsdp_state._comm_hook_state = state
+
+
+def _get_grad_norm(
+    params: Iterable[nn.Parameter],
+    norm_type: float,
+) -> torch.Tensor:
+    """
+    Return the gradient norm of the given parameters ``params``, where the gradients are viewed as a single vector.
+
+    The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
+    use of this return value is a reduction across ranks.
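+
+    A minimal sketch of the computation (illustrative only; the tiny linear layer
+    here is a hypothetical stand-in for real FSDP-managed parameters)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> lin = nn.Linear(2, 2)
+        >>> lin(torch.ones(1, 2)).sum().backward()
+        >>> _get_grad_norm(lin.parameters(), norm_type=2.0)  # FP32 L2 norm over all grads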
+    """
+    params_with_grad = [param for param in params if param.grad is not None]
+    if len(params_with_grad) == 0:
+        return torch.tensor(0.0)
+    grads = [param.grad for param in params_with_grad]
+    grad_dtypes = {grad.dtype for grad in grads}
+    if len(grad_dtypes) != 1:
+        raise ValueError(
+            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
+        )
+    # Compute the gradient norm in FP32, where we treat the gradients as a
+    # single vector
+    grad_norm = torch.linalg.vector_norm(
+        torch.stack(
+            [
+                torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32)
+                for grad in grads
+            ],
+        ),
+        norm_type,
+        dtype=torch.float32,
+    )
+    return grad_norm
+
+
+def _get_param_to_fqn(
+    model: torch.nn.Module,
+) -> Dict[torch.nn.Parameter, str]:
+    """
+    Construct a mapping from parameters to their parameter names.
+
+    The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
+    means that none of the parameters should be ``FlatParameter`` s. As a
+    result, compared to :meth:`_get_param_to_fqns`, the mapped
+    values may be flattened from singleton :class:`list` s to the contained
+    names themselves.
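+
+    A minimal sketch (illustrative only; a tiny non-FSDP module is assumed)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> m = nn.Sequential(nn.Linear(2, 2))
+        >>> mapping = _get_param_to_fqn(m)
+        >>> sorted(mapping.values())
+        ['0.bias', '0.weight']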
+
+    Args:
+        model (torch.nn.Module): Root module, which should not contain any
+            :class:`FullyShardedDataParallel` instances.
+    """
+    param_to_param_names = _get_param_to_fqns(model)
+    for param_names in param_to_param_names.values():
+        assert (
+            len(param_names) > 0
+        ), "`_get_param_to_fqns()` should not construct empty lists"
+        if len(param_names) > 1:
+            raise RuntimeError(
+                "Each parameter should only map to one parameter name but got "
+                f"{len(param_names)}: {param_names}"
+            )
+    param_to_param_name = {
+        param: param_names[0] for param, param_names in param_to_param_names.items()
+    }
+    return param_to_param_name
+
+
+def _get_fqn_to_param(
+    model: torch.nn.Module,
+) -> Dict[str, torch.nn.Parameter]:
+    """Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
+    param_to_param_name = _get_param_to_fqn(model)
+    return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f144d57cd361fe650729c7bd022383bcc7dbb3a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -0,0 +1,388 @@
+import logging
+from collections import abc, defaultdict
+from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
+from torch.distributed.distributed_c10d import ProcessGroup
+
+log = logging.getLogger(__name__)
+
+
+def _refresh_per_optimizer_state() -> Dict[str, Any]:
+    return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
+def _is_supported_device(tensor: torch.Tensor) -> bool:
+    return tensor.is_cuda or tensor.device.type in ("xla", "cpu", "hpu")
+
+
+class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
+    """
+    Lazily serves copies of the master tensor to the requested device. This class extends
+    _MultiDeviceReplicator to allow support for "cpu" as a device.
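+
+    A minimal sketch (illustrative only)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> rep = _GeneralMultiDeviceReplicator(torch.tensor(2.0))
+        >>> rep.get(torch.device("cpu"))
+        tensor(2.)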
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        assert _is_supported_device(master_tensor)
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+
+class ShardedGradScaler(GradScaler):
+    """
+    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
+    the functionality of GradScaler:
+    * Supports PyTorch DDP and FSDP implementations
+    * Supports CPU-offloaded tensors (as used in fully sharded data parallel [FSDP])
+    * Supports the custom mixed-precision loss dtype (fp16, bf16) that FSDP returns
+    * Syncs inf/nan for scaled gradient tensors on any torch.device (where the tensors
+      are placed) across nodes
+
+    Example::
+
+        # Creates a ShardedGradScaler once at the beginning of training.
+        scaler = ShardedGradScaler()
+
+        for epoch in epochs:
+            for input, target in data:
+                optimizer.zero_grad()
+                output = model(input)
+                loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                scaler.scale(loss).backward()
+
+                # scaler.step() first unscales gradients of the optimizer's params.
+                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+    See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.
+
+    Args:
+        init_scale (float, optional, default=2.**16):  Initial scale factor.
+        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
+            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
+            :meth:`update` if inf/NaN gradients occur in an iteration.
+        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN gradients
+            that must occur for the scale to be multiplied by ``growth_factor``.
+        enabled (bool, optional):  If ``False``, disables gradient scaling. :meth:`step` simply
+            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+            Default: ``True``
+        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
+            process group for sharding
+    """
+
+    def __init__(
+        self,
+        device: str = "cuda",
+        init_scale: float = 2.0**16,
+        backoff_factor: float = 0.5,
+        growth_factor: float = 2.0,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+        process_group: Optional[ProcessGroup] = dist.group.WORLD,
+    ) -> None:
+        super().__init__(
+            device,
+            init_scale=init_scale,
+            backoff_factor=backoff_factor,
+            growth_factor=growth_factor,
+            growth_interval=growth_interval,
+            enabled=enabled,
+        )
+        if self._enabled:
+            self.process_group = process_group
+            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+
+    @overload
+    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
+        ...
+
+    @overload
+    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
+        ...
+
+    @overload
+    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+        ...
+
+    @overload
+    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
+        ...
+
+    def scale(
+        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
+    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
+        if not self._enabled:
+            return outputs
+
+        if isinstance(outputs, torch.Tensor):
+            assert _is_supported_device(outputs)
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            scaled_output = outputs * self._scale.to(
+                device=outputs.device, non_blocking=True
+            )
+            # Here we ensure the return dtype is the same as the outputs dtype.
+            # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
+            # format (fp16, bf16) and so the scaled loss should be of the same dtype.
+            return scaled_output.type(outputs.dtype)
+
+        stash: List[_GeneralMultiDeviceReplicator] = []
+
+        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
+            if isinstance(val, torch.Tensor):
+                assert _is_supported_device(val)
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
+                scaled_val = val * stash[0].get(val.device)
+                # Here we ensure the return dtype is the same as the outputs dtype.
+                # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
+                # format (fp16, bf16) and so the scaled loss should be of the same dtype.
+                return scaled_val.type(val.dtype)
+            if isinstance(val, abc.Iterable):
+                iterator = map(apply_scale, val)
+                if isinstance(val, (list, tuple)):
+                    return type(val)(iterator)
+                return iterator
+            raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
+
+    def _foreach_non_finite_check_and_unscale_cpu_(
+        self,
+        grads: Sequence[torch.Tensor],
+        found_inf: torch.Tensor,
+        inv_scale: torch.Tensor,
+    ) -> None:
+        if len(grads) == 0:
+            return
+        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
+        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."
+
+        for grad in grads:
+            if grad.device.type != "cpu":
+                log.error(
+                    "tensor device is %s but was expected to be ``cpu``",
+                    grad.device,
+                )
+                raise ValueError(
+                    "Gradients were found on a non-CPU device when"
+                    " expected to be on CPU."
+                )
+            if (
+                torch.isinf(grad).any().item()
+                or torch.isnan(grad).any().item()
+            ):
+                found_inf.data = torch.tensor([1.0])
+                break
+            else:
+                grad.data *= inv_scale.item()
+
+    def _unscale_grads_(
+        self,
+        optimizer: torch.optim.Optimizer,
+        inv_scale: torch.Tensor,
+        found_inf: torch.Tensor,
+        allow_fp16: bool = True,
+    ) -> Dict[torch.device, torch.Tensor]:
+        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
+        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
+
+        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
+        # There could be thousands of grads, so we'd like to iterate through them just once.
+        # However, we don't know their devices or dtypes in advance.
+
+        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
+        # Google says mypy struggles with defaultdicts type annotations.
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
+        with torch.no_grad():
+            for group in optimizer.param_groups:
+                for param in group["params"]:
+                    if param.grad is None:
+                        continue
+                    if (not allow_fp16) and param.grad.dtype == torch.float16:
+                        raise ValueError("Attempting to unscale FP16 gradients.")
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.float16:
+                            # coalesce is not supported in torch.float16
+                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
+                            param.grad = param_grad_fp32.type(torch.float16)
+                        to_unscale = param.grad._values()
+                    else:
+                        to_unscale = param.grad
+
+                    per_device_and_dtype_grads[to_unscale.device][
+                        to_unscale.dtype
+                    ].append(to_unscale)
+
+            for device, per_dtype_grads in per_device_and_dtype_grads.items():
+                for grads in per_dtype_grads.values():
+                    if grads[0].device.type == "cpu":
+                        self._foreach_non_finite_check_and_unscale_cpu_(
+                            grads,
+                            per_device_found_inf.get(device),
+                            per_device_inv_scale.get(device),
+                        )
+                    else:
+                        torch._amp_foreach_non_finite_check_and_unscale_(
+                            grads,
+                            per_device_found_inf.get(device),
+                            per_device_inv_scale.get(device),
+                        )
+        # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
+        # ranks may have no (non-zero sized) parameter shards, necessitating the
+        # initialization of `per_device_found_inf._per_device_tensors` here
+        if not per_device_found_inf._per_device_tensors:
+            assert self._scale is not None
+            per_device_found_inf.get(self._scale.device)
+        return per_device_found_inf._per_device_tensors
+
+    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
+        if not self._enabled:
+            return
+
+        self._check_scale_growth_tracker("unscale_")
+
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+
+        if optimizer_state["stage"] is OptState.UNSCALED:
+            raise RuntimeError(
+                "unscale_() has already been called on this optimizer since the last update()."
+            )
+        elif optimizer_state["stage"] is OptState.STEPPED:
+            raise RuntimeError("unscale_() is being called after step().")
+
+        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
+        assert self._scale is not None
+        inv_scale = self._scale.double().reciprocal().float()
+        found_inf = torch.full(
+            (1,), 0.0, dtype=torch.float32, device=self._scale.device
+        )
+
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+            optimizer, inv_scale, found_inf, True
+        )
+        optimizer_state["stage"] = OptState.UNSCALED
+
+        # Synchronize the detected inf across the ranks
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+        works = []
+        found_inf_on_cpus = []
+        found_inf_on_cudas = []
+
+        for found_inf in optimizer_state["found_inf_per_device"].values():
+            if self._device == "cuda" and found_inf.device.type == "cpu":
+                found_inf_on_cpus.append(found_inf)
+                found_inf_on_cuda = found_inf.cuda()
+                found_inf_on_cudas.append(found_inf_on_cuda)
+                works.append(
+                    dist.all_reduce(
+                        found_inf_on_cuda, async_op=True, group=self.process_group
+                    )
+                )
+            else:
+                works.append(
+                    dist.all_reduce(found_inf, async_op=True, group=self.process_group)
+                )
+        for work in works:
+            work.wait()
+        if found_inf_on_cpus:
+            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_cudas)
+
+    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
+        """
+        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
+        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
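+
+        For example (an illustrative sketch of the arithmetic, assuming the defaults
+        ``backoff_factor=0.5`` and ``growth_factor=2.0``): with ``scale == 65536.0``,
+        a ``found_inf`` of ``1.0`` backs the scale off to ``32768.0``, whereas
+        ``growth_interval`` consecutive clean steps double it to ``131072.0``.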
+        """
+        assert self._scale is not None and self._growth_tracker is not None
+
+        if found_inf.item() >= 1.0:
+            self._scale *= self._backoff_factor
+            self._growth_tracker.fill_(0)
+        else:
+            successful = self._growth_tracker + 1
+            if successful == self._growth_interval:
+                self._scale *= self._growth_factor
+                self._growth_tracker.fill_(0)
+            else:
+                self._growth_tracker = successful
+
+    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
+        """
+        Update the scale factor.
+
+        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
+        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
+        the scale is multiplied by ``growth_factor`` to increase it.
+
+        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
+        used directly; it is used to fill GradScaler's internal scale tensor. So if
+        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
+        affect the scale GradScaler uses internally.)
+
+        Args:
+            new_scale (float or :class:`torch.Tensor`, optional, default=None):  New scale factor.
+
+        .. warning::
+            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
+            been invoked for all optimizers used this iteration.
+        """
+
+        if not self._enabled:
+            return
+
+        _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore[var-annotated]
+
+        if new_scale is not None:
+            # Accept a new user-defined scale.
+            if isinstance(new_scale, float):
+                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+            else:
+                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
+                    torch.FloatTensor with requires_grad=False."
+                assert new_scale.device.type == self._device, reason
+                assert new_scale.numel() == 1, reason
+                assert new_scale.requires_grad is False, reason
+                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+        else:
+            # Consume shared inf/nan data collected from optimizers to update the scale.
+            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
+            found_infs = [
+                found_inf.to(device=_scale.device, non_blocking=True)
+                for state in self._per_optimizer_states.values()
+                for found_inf in state["found_inf_per_device"].values()
+            ]
+
+            assert len(found_infs) > 0, "No inf checks were recorded prior to update."
+
+            found_inf_combined = found_infs[0]
+            if len(found_infs) > 1:
+                for i in range(1, len(found_infs)):
+                    found_inf_combined += found_infs[i]
+
+            if _scale.device.type == "cpu":
+                self._amp_update_scale_cpu_(found_inf_combined)
+            else:
+                torch._amp_update_scale_(
+                    self._scale,  # type: ignore[arg-type]
+                    self._growth_tracker,  # type: ignore[arg-type]
+                    found_inf_combined,
+                    self._growth_factor,  # type: ignore[arg-type]
+                    self._backoff_factor,  # type: ignore[arg-type]
+                    self._growth_interval,  # type: ignore[arg-type]
+                )
+
+        # To prepare for next iteration, clear the data collected from optimizers this iteration.
+        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
diff --git a/MLPY/Lib/site-packages/torch/distributed/fsdp/wrap.py b/MLPY/Lib/site-packages/torch/distributed/fsdp/wrap.py
new file mode 100644
index 0000000000000000000000000000000000000000..5122a6dc8431fa1280683e0f73157b39e6c5a123
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/fsdp/wrap.py
@@ -0,0 +1,606 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import copy
+from abc import ABC, abstractmethod
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterable,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
+
+import torch.nn as nn
+
+__all__ = [
+    "always_wrap_policy",
+    "lambda_auto_wrap_policy",
+    "transformer_auto_wrap_policy",
+    "size_based_auto_wrap_policy",
+    "enable_wrap",
+    "wrap",
+    "CustomPolicy",
+    "ModuleWrapPolicy",
+]
+
+
+# NOTE: We intentionally keep this function simple and isolate the complexity
+# to `fn` to enable using this function generically. We may move this to a
+# non-FSDP-specific folder and/or make it public in the future.
+def _post_order_apply(
+    root_module: nn.Module,
+    fn: Callable[[nn.Module], Optional[nn.Module]],
+):
+    """
+    This applies ``fn`` to every module in the module tree of ``root_module``
+    following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
+    then this replaces the original module with the newly returned one in the
+    tree. Otherwise, ``fn`` should return ``None``, in which case the module is
+    not changed.
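+
+    A minimal sketch (illustrative only; replaces every ``nn.ReLU`` with ``nn.Identity``)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> def swap_relu(module: nn.Module) -> Optional[nn.Module]:
+        >>>     return nn.Identity() if isinstance(module, nn.ReLU) else None
+        >>> model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
+        >>> _post_order_apply(model, swap_relu)  # model[1] is now nn.Identity()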
+    """
+    # Track visited modules to avoid visiting shared modules multiple times
+    visited_modules: Set[nn.Module] = {root_module}
+
+    def _post_order_apply_inner(
+        module: nn.Module,
+        module_name: str,
+        parent_module: Optional[nn.Module],
+    ):
+        for child_module_name, child_module in module.named_children():
+            if child_module not in visited_modules:
+                visited_modules.add(child_module)
+                _post_order_apply_inner(child_module, child_module_name, module)
+        optional_module = fn(module)
+        if optional_module is not None:
+            assert isinstance(parent_module, nn.Module), (
+                "Non-root modules should have their parent module set but got "
+                f"{parent_module} for {module}"
+            )
+            assert module_name, (
+                "Non-root modules should have their module name set but got "
+                f"an empty module name for {module}"
+            )
+            assert isinstance(
+                optional_module, nn.Module
+            ), f"fn should return None or an nn.Module but got {optional_module}"
+            setattr(parent_module, module_name, optional_module)
+
+    _post_order_apply_inner(root_module, "", None)
+
+
+def _construct_wrap_fn(
+    root_module: nn.Module,
+    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+    fsdp_fn: Callable,
+) -> Callable[[nn.Module], Optional[nn.Module]]:
+    """
+    This constructs the "wrap" function to pass to :func:`_post_order_apply`
+    based on ``target_module_to_kwargs``, which should be constructed from the
+    wrapping policy.
+    """
+
+    def fn(module: nn.Module) -> Optional[nn.Module]:
+        # Explicitly avoid wrapping the root module since for FSDP, it is
+        # handled by the caller
+        if module in target_module_to_kwargs and module is not root_module:
+            kwargs = target_module_to_kwargs[module]
+            return fsdp_fn(module, **kwargs)
+        return None
+
+    return fn
+
+
+def _run_mixed_precision_override_policy(
+    root_module: nn.Module,
+    module_classes: Iterable[Type[nn.Module]],
+    ignored_modules: Set[nn.Module],
+    root_kwargs: Dict[str, Any],
+    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+):
+    module_classes_tuple = tuple(set(module_classes))
+    for module in root_module.modules():
+        if module in ignored_modules:
+            continue
+        elif isinstance(module, module_classes_tuple):
+            # This policy overrides any existing policy
+            if module not in target_module_to_kwargs:
+                # Only inherit from the root kwargs if not already specified
+                target_module_to_kwargs[module] = root_kwargs
+            target_module_to_kwargs[module]["mixed_precision"] = None
+    return target_module_to_kwargs
+
+
+def always_wrap_policy(*args, **kwargs) -> bool:
+    """
+    A simple recursive wrap policy that always returns ``True``. This means
+    that every submodule is wrapped by the wrapper class in
+    :func:`_recursive_wrap`.
+    """
+    return True
+
+
+class _Policy(ABC):
+    """
+    This defines an abstract base class that represents a policy for applying
+    a module-level API.
+    """
+
+    @abstractmethod
+    def _run_policy(
+        self,
+        root_module: nn.Module,
+        ignored_modules: Set[nn.Module],
+        root_kwargs: Dict[str, Any],
+    ) -> Dict[nn.Module, Dict[str, Any]]:
+        """
+        This should return a dict ``target_module_to_kwargs`` that maps from
+        each target module to wrap to its kwargs.
+        """
+        ...
+
+
+def _module_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    nonwrapped_numel: int,
+    module_classes: Set[Type[nn.Module]],
+) -> bool:
+    """
+    This auto wrap policy wraps every module that is an instance of any type in
+    ``module_classes`` as its own FSDP instance. The root module given by
+    ``module`` is always wrapped as an FSDP instance regardless. Since the
+    wrapping proceeds bottom up, each FSDP instance manages the parameters in
+    its subtree excluding any already managed by a child FSDP instance.
+
+    Args:
+        module (nn.Module): Current module being considered.
+        recurse (bool): If ``False``, then this function must decide whether
+            ``module`` should be wrapped as an FSDP instance or not. If
+            ``True``, then the function is still recursing down the module
+            tree as a part of the DFS.
+        nonwrapped_numel (int): Parameter numel not yet wrapped.
+        module_classes (Set[Type[nn.Module]]): Set of module classes that are
+            wrapped as FSDP instances.
+
+    Returns:
+        ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
+        if ``recurse=False``.
+    """
+    if recurse:
+        return True  # always recurse
+    return isinstance(module, tuple(module_classes))
+
+
+class ModuleWrapPolicy(_Policy):
+    """
+    This policy applies to every module of the specified module classes,
+    passing in the kwargs given to the root.
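+
+    Example (an illustrative sketch; ``TransformerBlock`` and ``model`` are
+    placeholders defined elsewhere)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> policy = ModuleWrapPolicy({TransformerBlock})
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)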
+    """
+
+    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
+        module_classes_set = set(module_classes)
+        self._module_classes = module_classes_set
+        self._module_classes_str = str(module_classes_set)
+
+    def _run_policy(
+        self,
+        root_module: nn.Module,
+        ignored_modules: Set[nn.Module],
+        root_kwargs: Dict[str, Any],
+    ) -> Dict[nn.Module, Dict[str, Any]]:
+        module_classes = tuple(self._module_classes)
+        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
+        for module in root_module.modules():
+            if module in ignored_modules:
+                continue
+            elif isinstance(module, module_classes):
+                # Shallow copy to avoid coupling changes across modules
+                target_module_to_kwargs[module] = copy.copy(root_kwargs)
+        return target_module_to_kwargs
+
+    def __call__(self, module, recurse, *args, **kwargs):
+        # nonwrapped_numel is not used.
+        return _module_wrap_policy(
+            module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
+        )
+
+    def __repr__(self) -> str:
+        return super().__repr__() + f"({self._module_classes_str})"
+
+
+class CustomPolicy(_Policy):
+    """
+    This policy takes in a lambda function that maps a given ``nn.Module`` to
+    either ``False``, ``True``, or a kwarg dictionary.
+    - If the function returns ``False`` or an empty dictionary, then the module
+      does not have the API applied.
+    - If the function returns ``True``, then the module has the API applied
+      with the root's kwargs.
+    - If the function returns a non-empty dictionary, then the module has the
+      API applied, and the dictionary overrides the root's kwargs.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> model = init_transformer_model(...)
+        >>> def lambda_fn(module: nn.Module):
+        >>>     if module is model.lm_head:
+        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
+        >>>     elif isinstance(module, TransformerBlock):
+        >>>         return True
+        >>>     return False
+        >>> policy = CustomPolicy(lambda_fn)
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
+    """
+
+    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
+        self._lambda_fn = lambda_fn
+
+    def _run_policy(
+        self,
+        root_module: nn.Module,
+        ignored_modules: Set[nn.Module],
+        root_kwargs: Dict[str, Any],
+    ) -> Dict[nn.Module, Dict[str, Any]]:
+        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
+        for module in root_module.modules():
+            if module in ignored_modules:
+                continue
+            res = self._lambda_fn(module)
+            if not isinstance(res, (dict, bool)):
+                raise ValueError(
+                    "The lambda_fn passed to CustomPolicy should return "
+                    f"False/True or a kwarg dict, but it returned {res}"
+                )
+            if not res:
+                continue
+            kwargs = copy.copy(root_kwargs)
+            if isinstance(res, dict):
+                # Override the root kwargs with the ones specified by the
+                # lambda function
+                kwargs.update(res)
+            target_module_to_kwargs[module] = kwargs
+        return target_module_to_kwargs
+
+
+def lambda_auto_wrap_policy(
+    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
+) -> bool:
+    """
+    A convenient auto wrap policy to wrap submodules based on an arbitrary user
+    function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped
+    as a ``wrapper_cls`` unit.
+
+    Return if a module should be wrapped during auto wrapping.
+
+    The first three parameters are required by :func:`_recursive_wrap`.
+
+    Args:
+        module (nn.Module): Current module being considered.
+        recurse (bool): If ``False``, then this function must decide whether
+            ``module`` should be wrapped as an FSDP instance or not. If
+            ``True``, then the function is still recursing down the module
+            tree as a part of the DFS.
+        nonwrapped_numel (int): Parameter numel not yet wrapped.
+
+        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
+            this module will be wrapped.
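+
+    Example (a minimal sketch; ``model`` and ``FSDP`` are illustrative
+    placeholders, and ``functools.partial`` is only used to bind ``lambda_fn``)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import functools
+        >>> policy = functools.partial(
+        >>>     lambda_auto_wrap_policy,
+        >>>     lambda_fn=lambda m: isinstance(m, nn.Linear),
+        >>> )
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)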
+    """
+    if recurse:
+        return True  # always recurse
+    return lambda_fn(module)
+
+
+def transformer_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    nonwrapped_numel: int,
+    transformer_layer_cls: Set[Type[nn.Module]],
+) -> bool:
+    """
+    See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
+    same as ``module_classes``. Note that shared parameters must be wrapped in
+    the same FSDP instance, so this auto wrap policy can help wrap shared
+    embeddings into the same FSDP instance for transformer models.
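+
+    Example (a minimal sketch; ``TransformerBlock``, ``model``, and ``FSDP``
+    are illustrative placeholders)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import functools
+        >>> policy = functools.partial(
+        >>>     transformer_auto_wrap_policy,
+        >>>     transformer_layer_cls={TransformerBlock},
+        >>> )
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)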
+    """
+    return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)
+
+
+def _wrap_module_cls_individually(
+    module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
+):
+    if recurse:
+        # always recurse
+        return True
+    else:
+        # if not recursing, decide whether we should wrap based on whether the type of module
+        # is in `module_classes`.
+        return isinstance(module, tuple(module_classes))
+
+
+def _or_policy(
+    module: nn.Module,
+    recurse: bool,
+    nonwrapped_numel: int,
+    policies,
+) -> bool:
+    """
+    A policy that wraps ``module`` if any policy in the passed-in iterable
+    ``policies`` returns ``True``.
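+
+    Example (a minimal sketch; ``TransformerBlock`` is an illustrative
+    placeholder, and the two sub-policies are bound with ``functools.partial``)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import functools
+        >>> policy = functools.partial(
+        >>>     _or_policy,
+        >>>     policies=[
+        >>>         functools.partial(
+        >>>             size_based_auto_wrap_policy, min_num_params=int(1e7)
+        >>>         ),
+        >>>         functools.partial(
+        >>>             transformer_auto_wrap_policy,
+        >>>             transformer_layer_cls={TransformerBlock},
+        >>>         ),
+        >>>     ],
+        >>> )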
+    """
+    return any(
+        policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
+        for policy in policies
+    )
+
+
+def size_based_auto_wrap_policy(
+    module: nn.Module,
+    recurse: bool,
+    nonwrapped_numel: int,
+    # Additional custom arguments
+    min_num_params: int = int(1e8),
+    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
+    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
+) -> bool:
+    """
+    A size-based auto wrap policy.
+
+    Args:
+        module (nn.Module): Current module being considered.
+        recurse (bool): If ``False``, then this function must decide whether
+            ``module`` should be wrapped as an FSDP instance or not. If
+            ``True``, then the function is still recursing down the module
+            tree as a part of the DFS.
+        nonwrapped_numel (int): Parameter numel not yet wrapped.
+
+        min_num_params (int): Customizable policy input that controls the size
+            threshold over which a module is ready to be wrapped. This is in
+            units of numel.
+        force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
+            as leaves, i.e. their children will never be wrapped.
+        exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
+            excluded from wrapping.
+
+    Returns:
+        Whether ``module`` should be wrapped.
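+
+    Example (a minimal sketch; ``model`` and ``FSDP`` are illustrative
+    placeholders and the threshold below is arbitrary)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import functools
+        >>> policy = functools.partial(
+        >>>     size_based_auto_wrap_policy, min_num_params=int(1e5)
+        >>> )
+        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)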
+    """
+    force_leaf_modules = (
+        size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
+        if force_leaf_modules is None
+        else force_leaf_modules
+    )
+    exclude_wrap_modules = (
+        size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]
+        if exclude_wrap_modules is None
+        else exclude_wrap_modules
+    )
+
+    # Keep the argument `min_num_params` for BC for now, but it represents the
+    # minimum non-wrapped *numel* before triggering a wrapping
+    min_nonwrapped_numel = min_num_params
+    is_large = nonwrapped_numel >= min_nonwrapped_numel
+    if recurse:
+        # We should recurse if the module is big enough but not in force_leaf_modules list.
+        return is_large and not isinstance(module, tuple(force_leaf_modules))
+    else:
+        # If we are not recursing, determine if we should wrap.
+        return is_large and not isinstance(module, tuple(exclude_wrap_modules))
+
+
+# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
+size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
+size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
+
+
+@contextlib.contextmanager
+def enable_wrap(
+    *, wrapper_cls: Any, **wrapper_kwargs: Any
+) -> Generator[None, None, None]:
+    """
+    Context manager to wrap modules using a wrapper.
+
+    Useful for when you'd like to apply the same configuration arguments to all
+    child modules that you wrap. A particularly important use case is wrapping
+    large layers so that they get sharded (in-place) during initialization, to
+    avoid running out of system memory. Large layers can indicate that they
+    should be sharded via the ``wrap`` annotation and this context manager can
+    provide the exact configuration for these nested instances.
+
+    Usage::
+
+        with enable_wrap(wrapper_cls=FSDP, **params):
+            # Wraps layer in FSDP by default if within context
+            self.l1 = wrap(torch.nn.Linear(5, 5))
+
+    Args:
+        wrapper_cls:
+            Class that the ``wrap`` annotation will wrap modules with, such as
+            ``FullyShardedDataParallel``.
+        **wrapper_kwargs:
+            Configuration settings that will be passed to all ``wrap``
+            instances inside the context.
+    """
+    kwargs = {
+        "wrapper_cls": wrapper_cls,
+        **wrapper_kwargs,
+    }
+    with _ConfigAutoWrap(**kwargs):
+        yield
+
+
+def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
+    """
+    Annotate that a module should be wrapped. Annotated modules will only be
+    wrapped if inside of an :func:`enable_wrap` context manager. This allows
+    a module to be initialized both with and without a wrapper without code
+    change.
+
+    The class that this function wraps the passed-in ``nn.Module`` with is the
+    ``wrapper_cls`` argument passed into ``enable_wrap``. Both ``enable_wrap``
+    and ``wrap`` can take kwargs specifying how to construct the
+    ``wrapper_cls`` instance. If the same kwarg is given to both
+    ``enable_wrap`` and ``wrap``, the value passed into ``wrap`` takes
+    priority.
+
+    Usage::
+
+        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
+            # Wraps layer in FSDP by default if within context
+            self.l1 = wrap(torch.nn.Linear(5, 5))
+
+    Args:
+        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
+        **wrap_overrides: configuration overrides that will take priority over
+            the values provided by the :func:`enable_wrap` context
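+
+    Override example (a minimal sketch; ``my_cpu_offload`` is a hypothetical
+    per-module setting that takes priority over the context's kwargs)::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> with enable_wrap(wrapper_cls=FSDP, cpu_offload=None):
+        >>>     # Uses the context's kwargs as-is.
+        >>>     self.l1 = wrap(torch.nn.Linear(5, 5))
+        >>>     # Overrides ``cpu_offload`` for this layer only.
+        >>>     self.l2 = wrap(torch.nn.Linear(5, 5), cpu_offload=my_cpu_offload)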
+    """
+    if _ConfigAutoWrap.in_autowrap_context:
+        assert _ConfigAutoWrap.wrapper_cls is not None
+
+        wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
+        return _wrap(
+            module,
+            _ConfigAutoWrap.wrapper_cls,
+            **wrap_overrides,
+        )
+    return module
+
+
+def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
+    assert wrapper_cls is not None
+    if hasattr(module, "_wrap_overrides"):
+        # If module has a _wrap_overrides attribute, we force overriding the
+        # FSDP config with these attributes for this module. Currently this
+        # is only used to disable mixed precision for BatchNorm when
+        # auto_wrapping.
+        overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]
+        return wrapper_cls(module, **overrides)
+
+    return wrapper_cls(module, **kwargs)
+
+
+def _recursive_wrap(
+    module: nn.Module,
+    auto_wrap_policy: Callable,
+    wrapper_cls: Callable,
+    ignored_modules: Set[nn.Module],
+    ignored_params: Set[nn.Parameter],
+    only_wrap_children: bool = False,
+    **kwargs: Any,
+) -> Tuple[nn.Module, int]:
+    """
+    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
+    ``True`` with ``wrapper_cls``.
+
+    Args:
+        module (nn.Module): Module to recursively wrap.
+        auto_wrap_policy (Callable): A callable representing a policy that
+            determines which modules to recursively wrap with ``wrapper_cls``.
+        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
+            wrapping.
+        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
+            wrapping; these should be the parameters contained in the modules
+            in ``ignored_modules``.
+    Returns:
+        (nn.Module, int):
+            ``module`` after wrapping and the numel recursively wrapped.
+    """
+    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
+    assert wrapper_cls is not None, "Must specify wrapper_cls"
+    # Make sure no child is already wrapped.
+    for _, child in module.named_modules():
+        if child in ignored_modules:
+            continue
+        try:
+            assert not isinstance(child, cast(type, wrapper_cls))
+        except TypeError:
+            # wrapper_cls is a function as opposed to a class type, just bypass above check.
+            pass
+
+    # We count all params, assuming none of them are already wrapped.
+    nonwrapped_numel = sum(
+        p.numel() for p in module.parameters() if p not in ignored_params
+    )
+
+    assert auto_wrap_policy is not None
+    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
+        total_wrapped_numel = 0
+        # Iterate through the children, recursively wrap if necessary
+        for name, child in module.named_children():
+            if child in ignored_modules:
+                continue
+            wrapped_child, num_wrapped_params = _recursive_wrap(
+                module=child,
+                auto_wrap_policy=auto_wrap_policy,
+                wrapper_cls=wrapper_cls,
+                ignored_modules=ignored_modules,
+                ignored_params=ignored_params,
+                **kwargs,
+            )
+            setattr(module, name, wrapped_child)
+            # Keep track of how many parameters have been wrapped
+            total_wrapped_numel += num_wrapped_params
+        # Decide whether the current module itself should also be wrapped,
+        # based on the leftover parameters not yet wrapped by any child.
+        remainder = nonwrapped_numel - total_wrapped_numel
+        if not only_wrap_children and auto_wrap_policy(
+            module=module, recurse=False, nonwrapped_numel=remainder
+        ):
+            # Leaf node or final wrapping of the remainder both happen here.
+            return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
+        else:
+            return module, total_wrapped_numel
+    return module, 0
+
+
+class _ConfigAutoWrap:
+    """
+    Helper class to wrap modules based on default config args via a context manager.
+    See :func:`enable_wrap` for more information.
+    """
+
+    in_autowrap_context: bool = False  # Context flag
+    wrapper_cls: Optional[Callable] = None  # The wrapper class
+    kwargs: Dict[str, Any] = {}  # Wrapper's args
+
+    def __init__(self, **kwargs: Dict[str, Any]):
+        self.kwargs = kwargs
+
+    @staticmethod
+    def enable_autowrap_context(kwargs: Any) -> None:
+        if _ConfigAutoWrap.in_autowrap_context:
+            raise NotImplementedError(
+                "You are already within an autowrap context and we currently do not supported nested autowrap."
+            )
+        _ConfigAutoWrap.in_autowrap_context = True
+        # Get and save the wrapper cls for the context.
+        assert (
+            "wrapper_cls" in kwargs.keys()
+        ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
+        _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
+        del kwargs["wrapper_cls"]
+        # Save the rest.
+        _ConfigAutoWrap.kwargs = kwargs
+
+    @staticmethod
+    def disable_autowrap_context() -> None:
+        _ConfigAutoWrap.in_autowrap_context = False
+        _ConfigAutoWrap.wrapper_cls = None
+        _ConfigAutoWrap.kwargs = {}
+
+    def __enter__(self) -> None:
+        self.enable_autowrap_context(self.kwargs)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.disable_autowrap_context()
diff --git a/MLPY/Lib/site-packages/torch/distributed/launch.py b/MLPY/Lib/site-packages/torch/distributed/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..babf258707b2cd719dae5159de1b5b691c69cac3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/launch.py
@@ -0,0 +1,198 @@
+r"""
+Module ``torch.distributed.launch``.
+
+``torch.distributed.launch`` is a module that spawns multiple distributed
+training processes on each of the training nodes.
+
+.. warning::
+
+    This module is going to be deprecated in favor of ``torchrun``.
+
+The utility can be used for single-node distributed training, in which one or
+more processes per node will be spawned. The utility can be used for either
+CPU training or GPU training. If the utility is used for GPU training,
+each distributed process will be operating on a single GPU. This can achieve
+well-improved single-node training performance. It can also be used in
+multi-node distributed training, by spawning up multiple processes on each node
+for well-improved multi-node distributed training performance as well.
+This will especially be beneficial for systems with multiple Infiniband
+interfaces that have direct-GPU support, since all of them can be utilized for
+aggregated communication bandwidth.
+
+In both cases of single-node distributed training or multi-node distributed
+training, this utility will launch the given number of processes per node
+(``--nproc-per-node``). If used for GPU training, this number needs to be less
+than or equal to the number of GPUs on the current system (``nproc_per_node``),
+and each process will be operating on a single GPU from *GPU 0 to
+GPU (nproc_per_node - 1)*.
+
+**How to use this module:**
+
+1. Single-Node multi-process distributed training
+
+::
+
+    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
+               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
+               arguments of your training script)
+
+2. Multi-Node multi-process distributed training: (e.g. two nodes)
+
+
+Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*
+
+::
+
+    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
+               --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
+               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
+               and all other arguments of your training script)
+
+Node 2:
+
+::
+
+    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
+               --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
+               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
+               and all other arguments of your training script)
+
+3. To look up what optional arguments this module offers:
+
+::
+
+    python -m torch.distributed.launch --help
+
+
+**Important Notices:**
+
+1. This utility and multi-process distributed (single-node or
+multi-node) GPU training currently only achieves the best performance using
+the NCCL distributed backend. Thus NCCL backend is the recommended backend to
+use for GPU training.
+
+2. In your training program, you must parse the command-line argument:
+``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
+If your training program uses GPUs, you should ensure that your code only
+runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
+
+Parsing the local_rank argument
+
+::
+
+    >>> # xdoctest: +SKIP
+    >>> import argparse
+    >>> parser = argparse.ArgumentParser()
+    >>> parser.add_argument("--local-rank", type=int)
+    >>> args = parser.parse_args()
+
+Set your device to local rank using either
+
+::
+
+    >>> torch.cuda.set_device(args.local_rank)  # before your code runs
+
+or
+
+::
+
+    >>> with torch.cuda.device(args.local_rank):
+    >>>    # your code to run
+    >>>    ...
+
+3. In your training program, you are supposed to call the following function
+at the beginning to start the distributed backend. It is strongly recommended
+that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work,
+but ``env://`` is the one that is officially supported by this module.
+
+::
+
+    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
+    >>>                                      init_method='env://')
+
+4. In your training program, you can either use regular distributed functions
+or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
+training program uses GPUs for training and you would like to use
+:func:`torch.nn.parallel.DistributedDataParallel` module,
+here is how to configure it.
+
+::
+
+    >>> model = torch.nn.parallel.DistributedDataParallel(model,
+    >>>                                                   device_ids=[args.local_rank],
+    >>>                                                   output_device=args.local_rank)
+
+Please ensure that ``device_ids`` argument is set to be the only GPU device id
+that your code will be operating on. This is generally the local rank of the
+process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``,
+and ``output_device`` needs to be ``args.local_rank`` in order to use this
+utility
+
+5. Another way to pass ``local_rank`` to the subprocesses is via the environment
+variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
+``--use-env=True``. You must adjust the subprocess example above to replace
+``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
+will not pass ``--local-rank`` when you specify this flag.
+
+.. warning::
+
+    ``local_rank`` is NOT globally unique: it is only unique per process
+    on a machine.  Thus, don't use it to decide if you should, e.g.,
+    write to a networked filesystem.  See
+    https://github.com/pytorch/pytorch/issues/12042 for an example of
+    how things can go wrong if you don't do this correctly.
+
+
+
+"""
+
+import logging
+import warnings
+
+from torch.distributed.run import get_args_parser, run
+
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args(args):
+    parser = get_args_parser()
+    parser.add_argument(
+        "--use-env",
+        "--use_env",
+        default=False,
+        action="store_true",
+        help="Use environment variable to pass "
+        "'local rank'. For legacy reasons, the default value is False. "
+        "If set to True, the script will not pass "
+        "--local-rank as argument, and will instead set LOCAL_RANK.",
+    )
+    return parser.parse_args(args)
+
+
+def launch(args):
+    if args.no_python and not args.use_env:
+        raise ValueError(
+            "When using the '--no-python' flag,"
+            " you must also set the '--use-env' flag."
+        )
+    run(args)
+
+
+def main(args=None):
+    warnings.warn(
+        "The module torch.distributed.launch is deprecated\n"
+        "and will be removed in future. Use torchrun.\n"
+        "Note that --use-env is set by default in torchrun.\n"
+        "If your script expects `--local-rank` argument to be set, please\n"
+        "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
+        "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
+        "further instructions\n",
+        FutureWarning,
+    )
+    args = parse_args(args)
+    launch(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MLPY/Lib/site-packages/torch/distributed/launcher/__init__.py b/MLPY/Lib/site-packages/torch/distributed/launcher/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1923710adbdaffe0930da1118e1ec6d60b331c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/launcher/__init__.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env/python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from torch.distributed.launcher.api import (  # noqa: F401
+    LaunchConfig,
+    elastic_launch,
+    launch_agent,
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc3615084d9c0d822523bc81d2eaf0e0747036ad
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a65d9b855d9f8fbdd2c7bb29feb2c8237de2fd0c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/launcher/api.py b/MLPY/Lib/site-packages/torch/distributed/launcher/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b3f6cf03b7564140115fe3331b306050f4ebf2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/launcher/api.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import sys
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch.distributed.elastic.rendezvous.registry as rdzv_registry
+from torch.distributed.elastic import events, metrics
+from torch.distributed.elastic.agent.server.api import WorkerSpec
+from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException
+from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
+from torch.distributed.elastic.rendezvous import RendezvousParameters
+from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
+from torch.distributed.elastic.utils.logging import get_logger
+
+__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class LaunchConfig:
+    """
+    Creates a rendezvous config.
+
+    Args:
+        min_nodes: Minimum number of nodes that the user function will
+                        be launched on. The elastic agent ensures that the user
+                        function starts only when at least min_nodes nodes have
+                        joined the rendezvous.
+        max_nodes: Maximum number of nodes that the user function
+                        will be launched on.
+        nproc_per_node: On each node the elastic agent will launch
+                            this number of workers that will execute the user
+                            defined function.
+        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
+        rdzv_endpoint: The endpoint of the rdzv sync. storage.
+        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
+        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
+            to be removed in future versions, see the note below. The default timeout is 900 seconds.
+        run_id: The unique run id of the job (if not passed a unique one will be
+                deduced from run environment - flow workflow id in flow - or auto generated).
+        role: User defined role of the worker (defaults to "trainer").
+        max_restarts: The maximum number of restarts that the elastic agent will
+                    conduct on workers before failure.
+        monitor_interval: The interval in seconds at which the elastic agent
+                        monitors the workers.
+        start_method: The method used by the elastic agent to start the
+                    workers (spawn, fork, forkserver).
+        metrics_cfg: configuration to initialize metrics.
+        local_addr: address of the local node if any. If not set, a lookup on the local
+                machine's FQDN will be performed.
+        local_ranks_filter: ranks for which to show logs in console. If not set, show from all.
+    .. note::
+        ``rdzv_timeout`` is a legacy argument that will be removed in the future.
+        Set the timeout via ``rdzv_configs['timeout']`` instead.
+
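+    Example (a minimal sketch; ``trainer_fn`` is an illustrative placeholder and
+    the values below are arbitrary)::
+
+        >>> # xdoctest: +SKIP("illustrative only")
+        >>> config = LaunchConfig(
+        >>>     min_nodes=1,
+        >>>     max_nodes=1,
+        >>>     nproc_per_node=4,
+        >>>     rdzv_backend="c10d",
+        >>>     rdzv_endpoint="localhost:29400",
+        >>> )
+        >>> outputs = elastic_launch(config, trainer_fn)()
+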
+    """
+
+    min_nodes: int
+    max_nodes: int
+    nproc_per_node: int
+    logs_specs: Optional[LogsSpecs] = None
+    run_id: str = ""
+    role: str = "default_role"
+    rdzv_endpoint: str = ""
+    rdzv_backend: str = "etcd"
+    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
+    rdzv_timeout: int = -1
+    max_restarts: int = 3
+    monitor_interval: float = 30
+    start_method: str = "spawn"
+    log_line_prefix_template: Optional[str] = None
+    metrics_cfg: Dict[str, str] = field(default_factory=dict)
+    local_addr: Optional[str] = None
+
+    def __post_init__(self):
+        default_timeout = 900
+        if self.rdzv_timeout != -1:
+            self.rdzv_configs["timeout"] = self.rdzv_timeout
+        elif "timeout" not in self.rdzv_configs:
+            self.rdzv_configs["timeout"] = default_timeout
+
+        # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
+        if self.logs_specs is None:
+            self.logs_specs = DefaultLogsSpecs()
+
+
+class elastic_launch:
+    """
+    Launches a torchelastic agent on the container that invoked the entrypoint.
+
+        1. Pass the ``entrypoint`` arguments as non-``kwargs`` (i.e. no named
+           parameters). ``entrypoint`` can be a function or a command.
+        2. The return value is a map of each worker's output mapped
+           by their respective global rank.
+
+    Usage
+
+    ::
+
+    def worker_fn(foo):
+        # ...
+
+    def main():
+        # entrypoint is a function.
+        outputs = elastic_launch(LaunchConfig(...), worker_fn)(foo)
+        # return rank 0's output
+        return outputs[0]
+
+        # entrypoint is a command and ``script.py`` is the python module.
+        outputs = elastic_launch(LaunchConfig(...), "script.py")(args)
+        outputs = elastic_launch(LaunchConfig(...), "python")("script.py")
+    """
+
+    def __init__(
+        self,
+        config: LaunchConfig,
+        entrypoint: Union[Callable, str, None],
+    ):
+        self._config = config
+        self._entrypoint = entrypoint
+
+    def __call__(self, *args):
+        return launch_agent(self._config, self._entrypoint, list(args))
+
+
+def _get_entrypoint_name(
+    entrypoint: Union[Callable, str, None], args: List[Any]
+) -> str:
+    """Retrieve entrypoint name with the rule:
+    1. If entrypoint is a function, use ``entrypoint.__name__``.
+    2. If entrypoint is a string, check its value:
+        2.1 if entrypoint equals ``sys.executable`` (like "python"), use the first element from ``args``
+            that does not start with a hyphen (for example, "-u" will be skipped).
+        2.2 otherwise, use the ``entrypoint`` value.
+    3. Otherwise, return an empty string.
+    """
+    if isinstance(entrypoint, Callable):  # type: ignore[arg-type]
+        return entrypoint.__name__  # type: ignore[union-attr]
+    elif isinstance(entrypoint, str):
+        if entrypoint == sys.executable:
+            return next((arg for arg in args if arg[0] != "-"), "")
+        else:
+            return entrypoint
+    else:
+        return ""
+
+
+def _get_addr_and_port(
+    rdzv_parameters: RendezvousParameters,
+) -> Tuple[Optional[str], Optional[int]]:
+    if rdzv_parameters.backend != "static":
+        return (None, None)
+    endpoint = rdzv_parameters.endpoint
+    endpoint = endpoint.strip()
+    if not endpoint:
+        raise ValueError(
+            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
+        )
+    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
+    if master_port == -1:
+        raise ValueError(
+            f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
+        )
+    return (master_addr, master_port)
+
+
+def launch_agent(
+    config: LaunchConfig,
+    entrypoint: Union[Callable, str, None],
+    args: List[Any],
+) -> Dict[int, Any]:
+    if not config.run_id:
+        run_id = str(uuid.uuid4().int)
+        logger.warning("config has no run_id, generated a random run_id: %s", run_id)
+        config.run_id = run_id
+
+    entrypoint_name = _get_entrypoint_name(entrypoint, args)
+
+    logger.info(
+        "Starting elastic_operator with launch configs:\n"
+        "  entrypoint       : %(entrypoint)s\n"
+        "  min_nodes        : %(min_nodes)s\n"
+        "  max_nodes        : %(max_nodes)s\n"
+        "  nproc_per_node   : %(nproc_per_node)s\n"
+        "  run_id           : %(run_id)s\n"
+        "  rdzv_backend     : %(rdzv_backend)s\n"
+        "  rdzv_endpoint    : %(rdzv_endpoint)s\n"
+        "  rdzv_configs     : %(rdzv_configs)s\n"
+        "  max_restarts     : %(max_restarts)s\n"
+        "  monitor_interval : %(monitor_interval)s\n"
+        "  log_dir          : %(log_dir)s\n"
+        "  metrics_cfg      : %(metrics_cfg)s\n",
+        {
+            "entrypoint": entrypoint_name,
+            "min_nodes": config.min_nodes,
+            "max_nodes": config.max_nodes,
+            "nproc_per_node": config.nproc_per_node,
+            "run_id": config.run_id,
+            "rdzv_backend": config.rdzv_backend,
+            "rdzv_endpoint": config.rdzv_endpoint,
+            "rdzv_configs": config.rdzv_configs,
+            "max_restarts": config.max_restarts,
+            "monitor_interval": config.monitor_interval,
+            "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
+            "metrics_cfg": config.metrics_cfg
+        }
+    )
+
+    rdzv_parameters = RendezvousParameters(
+        backend=config.rdzv_backend,
+        endpoint=config.rdzv_endpoint,
+        run_id=config.run_id,
+        min_nodes=config.min_nodes,
+        max_nodes=config.max_nodes,
+        local_addr=config.local_addr,
+        **config.rdzv_configs,
+    )
+
+    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
+
+    spec = WorkerSpec(
+        role=config.role,
+        local_world_size=config.nproc_per_node,
+        entrypoint=entrypoint,
+        args=tuple(args),
+        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
+        max_restarts=config.max_restarts,
+        monitor_interval=config.monitor_interval,
+        master_addr=master_addr,
+        master_port=master_port,
+        local_addr=config.local_addr,
+    )
+
+    agent = LocalElasticAgent(
+        spec=spec,
+        logs_specs=config.logs_specs,  # type: ignore[arg-type]
+        start_method=config.start_method,
+        log_line_prefix_template=config.log_line_prefix_template,
+    )
+
+    shutdown_rdzv = True
+    try:
+        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
+
+        result = agent.run()
+        # records that agent.run() has succeeded NOT that workers have succeeded
+        events.record(agent.get_event_succeeded())
+
+        if result.is_failed():
+            # ChildFailedError is treated specially by @record
+            # if the error files for the failed children exist
+            # @record will copy the first error (root cause)
+            # to the error file of the launcher process.
+            raise ChildFailedError(
+                name=entrypoint_name,
+                failures=result.failures,
+            )
+
+        return result.return_values
+    except ChildFailedError:
+        raise
+    except SignalException:
+        # when the agent dies with a signal do NOT shutdown the rdzv_handler
+        # since this closes the rendezvous on this rdzv_id permanently and
+        # prevents any additional scaling events
+        shutdown_rdzv = False
+        events.record(agent.get_event_failed())
+        raise
+    except Exception:
+        events.record(agent.get_event_failed())
+        raise
+    finally:
+        if shutdown_rdzv:
+            spec.rdzv_handler.shutdown()
diff --git a/MLPY/Lib/site-packages/torch/distributed/logging_handlers.py b/MLPY/Lib/site-packages/torch/distributed/logging_handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a775863e0b06b2f7597cd9cae85d19110271a1f6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/logging_handlers.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict, List
+
+__all__: List[str] = []
+
+_log_handlers: Dict[str, logging.Handler] = {
+    "default": logging.NullHandler(),
+}
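+
+# A minimal lookup sketch (an assumption for illustration: other
+# torch.distributed utilities fetch a handler from this mapping by destination
+# name, with "default" being the only destination defined here):
+#
+#   handler = _log_handlers["default"]
+#   logging.getLogger("torch.distributed").addHandler(handler)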
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/__init__.py b/MLPY/Lib/site-packages/torch/distributed/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..625a0c95db7f0abf299697ab017038df21c3e0b2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/nn/__init__.py
@@ -0,0 +1,4 @@
+import torch
+if torch.distributed.rpc.is_available():
+    from .api.remote_module import RemoteModule
+from .functional import *  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bc78b0c91cb3905083c840e9a535d24bcec4650
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f26246000fe624edf5834bb64331beed216fe07e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/api/__init__.py b/MLPY/Lib/site-packages/torch/distributed/nn/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c05488093f99cc8008a441c43c3c057ffb1f981f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b092a38973d1f3d2cea1d1ffd400b210be8319a0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/api/remote_module.py b/MLPY/Lib/site-packages/torch/distributed/nn/api/remote_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ad29ee898e26650bdfe0494d6559de05fa0a87
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/nn/api/remote_module.py
@@ -0,0 +1,760 @@
+#!/usr/bin/python3
+import collections
+import io
+import sys
+import types
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor, device, dtype, nn
+from torch.distributed.nn.jit import instantiator
+from torch.distributed import _remote_device
+from torch.distributed.rpc.internal import _internal_rpc_pickler
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from torch.utils.hooks import RemovableHandle
+
+__all__ = ["RemoteModule"]
+
+_grad_t = Union[Tuple[Tensor, ...], Tensor]
+# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
+# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
+# the type of the subclass, not the looser type of `Module`.
+T = TypeVar("T", bound="Module")
+
+_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
+    instantiator.instantiate_non_scriptable_remote_module_template()
+)
+
+_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
+    "on",
+    "device",
+    "is_device_map_set",
+    "is_scriptable",
+    "generated_methods",
+    "module_rref",
+)
+
+_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES)  # type: ignore[misc]
+
+# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
+# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
+# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
+# Otherwise, it will not be pickled.
+_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
+    "training",
+    "_parameters",
+    "_buffers",
+    "_non_persistent_buffers_set",
+    "_backward_hooks",
+    "_backward_pre_hooks",
+    "_is_full_backward_hook",
+    "_forward_hooks",
+    "_forward_hooks_with_kwargs",
+    "_forward_hooks_always_called",
+    "_forward_pre_hooks",
+    "_forward_pre_hooks_with_kwargs",
+    "_state_dict_hooks",
+    "_state_dict_pre_hooks",
+    "_load_state_dict_pre_hooks",
+    "_load_state_dict_post_hooks",
+    "_state_dict_pre_hooks",
+    "_modules",
+    # The two attributes below are generated methods, not available at pickling time.
+    "forward_async",
+    "forward",
+)
+
+
+# RPC handler.
+def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
+    instantiator.instantiate_scriptable_remote_module_template(
+        module_interface_cls, enable_moving_cpu_tensors_to_cuda
+    )
+
+
+def _create_module(module_cls, args, kwargs, device):
+    module = module_cls(*args, **kwargs)
+    if not isinstance(module, nn.Module):
+        raise ValueError(
+            "Expect `module_cls(*args, **kwargs)` returns an instance of , "
+            f"but it returns an instance of {type(module)}."
+        )
+    module.to(device)
+    return module
+
+
+def _create_module_with_interface(
+    module_cls, args, kwargs, device, module_interface_cls
+):
+    module = _create_module(module_cls, args, kwargs, device)
+    if module_interface_cls is not None:
+        module = torch.jit.script(module)
+    return rpc.RRef(module, module_interface_cls)
+
+
+def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
+    ret: List[rpc.RRef[Parameter]] = []
+    for param in module_rref.local_value().parameters(recurse):
+        ret.append(rpc.RRef(param))
+    return ret
+
+
+def _raise_not_supported(name: str) -> None:
+    raise ValueError(f"Method ``{name}`` not supported for RemoteModule")
+
+
+class _RemoteModule(nn.Module):
+
+    def __new__(cls, *args, **kwargs):
+        # Use __new__ for logging purposes.
+        torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
+        return super().__new__(cls)
+
+    def __init__(
+        self,
+        remote_device: str,
+        module_cls: Type[nn.Module],
+        args: Optional[Tuple] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+        _module_interface_cls: Any = None,
+    ):
+        """
+        RemoteModule instance can only be created after RPC initialization.
+
+        It creates a user-specified module on a specified remote node.
+        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
+        executed on the remote node.
+        It takes care of autograd recording to ensure the backward pass propagates
+        gradients back to the corresponding remote module.
+        It can be shared across processes using the RPC framework (``torch.distributed.rpc``),
+        without incurring any overheads of copying the actual module,
+        which is equivalent to an :class:`~torch.distributed.rpc.RRef`
+        pointing to the remote module.
+
+        The arguments of ``forward_async`` and ``forward`` are the same as
+        the ``forward`` method of the module returned by the ``module_cls``.
+
+        Apart from ``forward_async`` and ``forward``, no other methods are supported from nn.Module for now.
+
+        Particularly, to create a hybrid model, typically the local modules should be
+        created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
+        Hybrid Example:
+                >>> class HybridModel(nn.Module):
+                >>>     def __init__(self):
+                >>>         nn.Module.__init__(self)
+                >>>         self.remote_embedding = RemoteModule(...)
+                >>>         self.local_linear = nn.Linear(...)
+
+        For example, if ``module_cls`` returns an instance of ``nn.Linear``,
+        that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``,
+        the generated ``RemoteModule`` will have 2 methods in signature of
+        ``def forward(input: Tensor) -> Tensor:`` and
+        ``def forward_async(input: Tensor) -> Future[Tensor]:``.
+
+        .. note::
+            If the remote module is placed on a cuda device,
+            any input CPU tensors will be automatically moved to the same cuda device,
+            and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend.
+
+        Args:
+            remote_device (str): Device on the destination worker where we'd like to place this module.
+                The device can be a local device or a remote device specified by one of the following remote
+                formats:
+
+                    1. "rank:/" (ex: "rank:0/cuda:0").
+                    2. "/" (ex: "trainer0/cuda:0").
+
+                In addition, the device field can be optional and the default value is "cpu".
+            module_cls (nn.Module): For example,
+                >>> class MyModule(nn.Module):
+                >>>     def forward(input):
+                >>>         return input + 1
+                >>>
+                >>> module_cls = MyModule
+            args (Sequence, optional): args to be passed to ``module_cls``.
+            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
+            _module_interface_cls (type, optional): The TorchScript interface type for the module
+                to be created. The type object should be decorated by @torch.jit.interface.
+                If not provided, the generated RemoteModule is not torchscript-able.
+                Warning, this is an experimental API and susceptible to frequent changes.
+
+        Returns:
+            A remote module instance which wraps the :class:`~nn.Module` created by the
+            user-provided ``module_cls``, it has a blocking ``forward`` method and an
+            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
+            on the user-provided module on the remote side.
+
+        Example::
+            Run the following code in two different processes:
+
+            >>> # xdoctest: +SKIP("distributed")
+            >>> # On worker 0:
+            >>> import torch
+            >>> import torch.distributed.rpc as rpc
+            >>> from torch import nn, Tensor
+            >>> from torch.distributed.nn.api.remote_module import RemoteModule
+            >>>
+            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+            >>> remote_linear_module = RemoteModule(
+            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
+            >>> )
+            >>> input = torch.randn(128, 20)
+            >>> ret_fut = remote_linear_module.forward_async(input)
+            >>> ret = ret_fut.wait()
+            >>> rpc.shutdown()
+
+            >>> # On worker 1:
+            >>> import torch
+            >>> import torch.distributed.rpc as rpc
+            >>>
+            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+            >>> rpc.shutdown()
+        """
+        super().__init__()
+
+        enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device)
+
+        # Default arguments preparation.
+        args = args if args is not None else ()
+        kwargs = kwargs if kwargs is not None else {}
+
+        if _module_interface_cls is not None:
+            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
+            self.is_scriptable = True
+
+            # Instantiate template on remote side.
+            fut = rpc.rpc_async(
+                self.on,
+                _instantiate_template,
+                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
+            )
+
+            self._init_template(
+                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
+            )
+
+            # Instantiate template on remote side.
+            fut = rpc.rpc_async(
+                self.on,
+                _instantiate_template,
+                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
+            )
+
+            # Create the module on the remote side.
+            fut.wait()  # Ensure remote_module_cls is available on remote side.
+
+            # TODO: We need to change this to rpc.remote, and make it async (see the else branch below).
+            # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote
+            # See https://github.com/pytorch/pytorch/issues/58098 for more context.
+            self.module_rref = rpc.rpc_sync(
+                self.on,
+                _create_module_with_interface,
+                (module_cls, args, kwargs, self.device, _module_interface_cls),
+            )
+        else:
+            self.is_scriptable = False
+            self.generated_methods = (
+                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
+            )
+            # Create the module on the remote side.
+            self.module_rref = rpc.remote(
+                self.on,
+                _create_module,
+                (module_cls, args, kwargs, self.device),
+            )
+
+        self._install_generated_methods()
+        self._check_attribute_picklability()
+
+    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
+        """
+        Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters.
+
+        This can typically be used in conjunction
+        with :class:`~torch.distributed.optim.DistributedOptimizer`.
+
+        Args:
+            recurse (bool): if True, then returns parameters of the remote
+                module and all submodules of the remote module. Otherwise,
+                returns only parameters that are direct members of the
+                remote module.
+
+        Returns:
+            A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
+            to remote module's parameters.
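+
+        Example (a minimal sketch; assumes RPC is already initialized and
+        ``remote_module`` is an existing ``RemoteModule`` instance)::
+
+            >>> # xdoctest: +SKIP("distributed")
+            >>> from torch.distributed.optim import DistributedOptimizer
+            >>> opt = DistributedOptimizer(
+            >>>     torch.optim.SGD,
+            >>>     remote_module.remote_parameters(),
+            >>>     lr=0.05,
+            >>> )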
+        """
+        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))
+
+    def get_module_rref(self) -> rpc.RRef[nn.Module]:
+        """Return an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module."""
+        return self.module_rref
+
+    @torch.jit.export
+    def __getstate__(self):
+        raise RuntimeError(
+            "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
+        )
+
+    @torch.jit.export
+    def __setstate__(self, state):
+        raise RuntimeError(
+            "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
+        )
+
+    def register_buffer(
+        self, name: str, tensor: Optional[Tensor], persistent: bool = True
+    ) -> None:
+        _raise_not_supported(self.register_buffer.__name__)
+
+    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
+        _raise_not_supported(self.register_parameter.__name__)
+
+    def add_module(self, name: str, module: Optional[Module]) -> None:
+        _raise_not_supported(self.add_module.__name__)
+
+    def apply(self: T, fn: Callable[[Module], None]) -> T:  # type: ignore[return]
+        _raise_not_supported(self.apply.__name__)
+
+    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
+        _raise_not_supported(self.cuda.__name__)
+
+    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
+        _raise_not_supported(self.ipu.__name__)
+
+    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
+        _raise_not_supported(self.xpu.__name__)
+
+    def cpu(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.cpu.__name__)
+
+    def type(self: T, dst_type: Union[dtype, str]) -> T:  # type: ignore[return]
+        _raise_not_supported(self.type.__name__)
+
+    def float(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.float.__name__)
+
+    def double(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.double.__name__)
+
+    def half(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.half.__name__)
+
+    def bfloat16(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.bfloat16.__name__)
+
+    def to(self, *args, **kwargs) -> T:  # type: ignore[misc, return, type-var]
+        _raise_not_supported(self.to.__name__)
+
+    def register_backward_hook(  # type: ignore[return]
+        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
+    ) -> RemovableHandle:
+        _raise_not_supported(self.register_backward_hook.__name__)
+
+    def register_forward_pre_hook(  # type: ignore[return]
+        self,
+        hook: Union[
+            Callable[[T, Tuple[Any, ...]], Optional[Any]],
+            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
+        ],
+        prepend: bool = False,
+        with_kwargs: bool = False,
+    ) -> RemovableHandle:
+        _raise_not_supported(self.register_forward_pre_hook.__name__)
+
+    def register_forward_hook(  # type: ignore[return, override]
+        self,
+        hook: Union[
+            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
+            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
+        ],
+        prepend: bool = False,
+        with_kwargs: bool = False,
+    ) -> RemovableHandle:
+        _raise_not_supported(self.register_forward_hook.__name__)
+
+    def state_dict(self, *args, **kwargs):
+        _raise_not_supported(self.state_dict.__name__)
+
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        assign: bool = False,
+    ):
+        _raise_not_supported(self.load_state_dict.__name__)
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        raise ValueError(
+            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
+        )
+
+    def named_parameters(  # type: ignore[return]
+        self,
+        prefix: str = "",
+        recurse: bool = True,
+        remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Parameter]]:
+        _raise_not_supported(self.named_parameters.__name__)
+
+    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:  # type: ignore[return]
+        _raise_not_supported(self.buffers.__name__)
+
+    def named_buffers(  # type: ignore[return]
+        self,
+        prefix: str = "",
+        recurse: bool = True,
+        remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Tensor]]:
+        _raise_not_supported(self.named_buffers.__name__)
+
+    def children(self) -> Iterator[Module]:  # type: ignore[return]
+        _raise_not_supported(self.children.__name__)
+
+    def named_children(self) -> Iterator[Tuple[str, Module]]:  # type: ignore[return]
+        _raise_not_supported(self.named_children.__name__)
+
+    def modules(self) -> Iterator[Module]:  # type: ignore[return]
+        _raise_not_supported(self.modules.__name__)
+
+    def named_modules(
+        self,
+        memo: Optional[Set[Module]] = None,
+        prefix: str = "",
+        remove_duplicate: bool = True,
+    ):
+        _raise_not_supported(self.named_modules.__name__)
+
+    def train(self: T, mode: bool = True) -> T:
+        return self.module_rref.rpc_sync().train()  # type: ignore[operator, union-attr]
+
+    def eval(self: T) -> T:
+        return self.module_rref.rpc_sync().eval()  # type: ignore[operator, union-attr]
+
+    def requires_grad_(self: T, requires_grad: bool = True) -> T:  # type: ignore[return]
+        _raise_not_supported(self.requires_grad_.__name__)
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        _raise_not_supported(self.zero_grad.__name__)
+
+    def share_memory(self: T) -> T:  # type: ignore[return]
+        _raise_not_supported(self.share_memory.__name__)
+
+    def extra_repr(self) -> str:  # type: ignore[return]
+        _raise_not_supported(self.extra_repr.__name__)
+
+    def _prepare_init(self, remote_device_str: str) -> bool:
+        """Prepare the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices."""
+        # Sanity check.
+        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
+
+        remote_device = _remote_device(remote_device_str)
+        self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank()
+        self.device = str(remote_device.device())
+        agent = rpc._get_current_rpc_agent()
+        # If the device map of the remote worker is set,
+        # then enable moving any input CPU tensors to the same cuda device.
+        self.is_device_map_set = bool(
+            agent._get_device_map(agent.get_worker_info(self.on))  # type: ignore[arg-type]
+        )
+        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
+        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
+        # then any CPU tensors can still be moved to a cuda device to run forward,
+        # but the output must be moved back to CPU before being sent over the wire.
+        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"
+        return enable_moving_cpu_tensors_to_cuda
+
+    def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda):
+        """Instantiate template on local side."""
+        generated_module = instantiator.instantiate_scriptable_remote_module_template(
+            module_interface_cls, enable_moving_cpu_tensors_to_cuda
+        )
+        self.generated_methods = generated_module._generated_methods
+
+    def _check_attribute_picklability(self):
+        """Check if all the attribute has explicitly defined whether to be pickled (i.e., picklability)."""
+        for k in self.__dict__.keys():
+            if (
+                k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
+                and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
+            ):
+                raise AttributeError(
+                    f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
+                    "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``."
+                )
+
+    def _install_generated_methods(self):
+        for method in self.generated_methods:
+            method_name = method.__name__
+            method = torch.jit.export(method)
+            setattr(self, method_name, types.MethodType(method, self))
+
+    @staticmethod
+    def init_from_module_rref(
+        remote_device: str,
+        module_rref: rpc.RRef[nn.Module],
+        _module_interface_cls: Any = None,
+    ):
+        """
+        Besides the constructor, a RemoteModule instance can also be initialized given a module RRef.
+
+        This alternate initialization method can be particularly useful if we want to create multiple
+        RemoteModule instances that share the same underlying module and reduce memory consumption.
+
+        Moreover, this also provides a workaround for passing script RemoteModule over RPC,
+        which is not supported. The recommended way is as follows:
+
+            1. the sender creates a RemoteModule;
+            2. the sender sends its ``module_rref`` over RPC;
+            3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``.
+
+        Example::
+            Run the following code in two different processes:
+
+            >>> # xdoctest: +SKIP("distributed")
+            >>> # On worker 0:
+            >>> import torch
+            >>> import torch.distributed.rpc as rpc
+            >>> from torch import nn, Tensor
+            >>> from torch.distributed.nn.api.remote_module import RemoteModule
+            >>>
+            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+            >>> remote_module = RemoteModule(
+            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
+            >>> )
+            >>>
+            >>> remote_module1 = rpc.rpc_sync(
+            >>>     "worker1/cpu",
+            >>>     RemoteModule.init_from_module_rref,
+            >>>     ("worker1/cpu", remote_module1.get_module_rref()),
+            >>> )
+            >>> rpc.shutdown()
+
+            >>> # On worker 1:
+            >>> import torch
+            >>> import torch.distributed.rpc as rpc
+            >>>
+            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+            >>> rpc.shutdown()
+
+        Args:
+            remote_device (str): Device on the destination worker where we'd like to place this module.
+                The device can be a local device or a remote device specified by one of the following remote
+                formats:
+
+                    1. "rank:/" (ex: "rank:0/cuda:0").
+                    2. "/" (ex: "trainer0/cuda:0").
+
+                In addition, the device field can be optional and the default value is "cpu".
+            module_rref (RRef[nn.Module]): The module reference shared by both the caller and
+                the created remote module.
+            _module_interface_cls (type, optional): The TorchScript interface type for the module
+                to be created. The type object should be decorated by @torch.jit.interface.
+                If not provided, the generated RemoteModule is not torchscript-able.
+                Warning, this is an experimental API and susceptible to frequent changes.
+
+        Returns:
+            A remote module instance which wraps the :class:`~nn.Module` created by the
+            user-provided ``module_rref``, it has a blocking ``forward`` method and an
+            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
+            on the user-provided module on the remote side.
+        """
+        # NOTE: if a new attribute is added to this class, also need to add it
+        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.
+
+        remote_module = object.__new__(RemoteModule)
+
+        enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)
+
+        if _module_interface_cls is not None:
+            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
+            remote_module.is_scriptable = True
+
+            remote_module._init_template(
+                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
+            )
+        else:
+            remote_module.is_scriptable = False
+            remote_module.generated_methods = (
+                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
+            )
+        remote_module.module_rref = module_rref
+
+        remote_module._install_generated_methods()
+        remote_module._check_attribute_picklability()
+
+        return remote_module
+
+
+class RemoteModule(_RemoteModule):
+    """
+        A RemoteModule instance can only be created after RPC initialization.
+
+        It creates a user-specified module on a specified remote node.
+        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
+        executed on the remote node.
+        It takes care of autograd recording to ensure the backward pass propagates
+        gradients back to the corresponding remote module.
+
+        It generates two methods ``forward_async`` and ``forward`` based on the
+        signature of the ``forward`` method of ``module_cls``. ``forward_async``
+        runs asynchronously and returns a Future. The arguments of ``forward_async``
+        and ``forward`` are the same as the ``forward`` method of the module
+        returned by the ``module_cls``.
+
+        For example, if ``module_cls`` returns an instance of ``nn.Linear``,
+        that has ``forward`` method signature: ``def forward(input: Tensor) -> Tensor:``,
+        the generated ``RemoteModule`` will have 2 methods with the signatures:
+
+        | ``def forward(input: Tensor) -> Tensor:``
+        | ``def forward_async(input: Tensor) -> Future[Tensor]:``
+
+    Args:
+        remote_device (str): Device on the destination worker where we'd like to place this module.
+            The format should be "<worker_name>/<device>", where the device field can be parsed as torch.device type.
+            E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
+            In addition, the device field can be optional and the default value is "cpu".
+        module_cls (nn.Module): Class for the module to be created remotely. For example,
+
+            >>> class MyModule(nn.Module):
+            >>>     def forward(input):
+            >>>         return input + 1
+            >>>
+            >>> module_cls = MyModule
+
+        args (Sequence, optional): args to be passed to ``module_cls``.
+        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
+
+    Returns:
+        A remote module instance which wraps the :class:`~nn.Module` created by the
+        user-provided ``module_cls``, it has a blocking ``forward`` method and an
+        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
+        on the user-provided module on the remote side.
+
+    Example::
+        Run the following code in two different processes:
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> from torch import nn, Tensor
+        >>> from torch.distributed.nn.api.remote_module import RemoteModule
+        >>>
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> remote_linear_module = RemoteModule(
+        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
+        >>> )
+        >>> input = torch.randn(128, 20)
+        >>> ret_fut = remote_linear_module.forward_async(input)
+        >>> ret = ret_fut.wait()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>>
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
+        Furthermore, a more practical example that combines RemoteModule with
+        DistributedDataParallel (DDP) can be found in the PyTorch RPC tutorials.
+    """
+
+    def __init__(
+        self,
+        remote_device: str,
+        module_cls: Type[nn.Module],
+        args: Optional[Tuple] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(remote_device, module_cls, args, kwargs)
+
+
+def _remote_module_receiver(
+    *remote_module_pickled_attrs,
+):
+    """Deserializes a RemoteModule."""
+    serialized_remote_module = _SerializedRemoteModule._make(
+        remote_module_pickled_attrs
+    )
+    m = object.__new__(RemoteModule)
+    m.__dict__.update(serialized_remote_module._asdict())
+
+    # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
+    m.module_rref = rpc.PyRRef._deserialize(m.module_rref)
+
+    # Install generated methods when unpickled.
+    for method in m.generated_methods:
+        method_name = method.__name__
+        method = torch.jit.export(method)
+        setattr(m, method_name, types.MethodType(method, m))
+
+    return m
+
+
+def _remote_module_reducer(remote_module):
+    """Serialize a RemoteModule."""
+    pickled_attrs = {}
+    for k, v in remote_module.__dict__.items():
+        # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
+        if k == "module_rref":
+            pickled_attrs[k] = v._serialize()
+        elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
+            pickled_attrs[k] = v
+        # Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
+        elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
+            print(
+                f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. "
+                "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
+                "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.",
+                file=sys.stderr,
+            )
+
+    return (
+        _remote_module_receiver,
+        tuple(pickled_attrs.values()),
+    )
+
+
+def _recursive_script_module_receiver(
+    recursive_script_module_serialized,
+):
+    """Deserializes a RecursiveScriptModule that does not contain a script RemoteModule."""
+    f = io.BytesIO(recursive_script_module_serialized)
+    m = torch.jit.load(f)
+    return m
+
+
+def _recursive_script_module_reducer(recursive_script_module):
+    """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raises an error otherwise."""
+    if hasattr(recursive_script_module._c, "module_rref"):
+        raise RuntimeError(
+            "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
+            "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
+        )
+
+    f = io.BytesIO()
+    torch.jit.save(recursive_script_module, f)
+    return (_recursive_script_module_receiver, (f.getvalue(),))
+
+
+_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
+_internal_rpc_pickler._register_reducer(
+    torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/functional.py b/MLPY/Lib/site-packages/torch/distributed/nn/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..4477bc62583b9a473d9077203362fb6e8e957b77
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/nn/functional.py
@@ -0,0 +1,440 @@
+import torch
+import torch.distributed as dist
+from torch.autograd import Function
+# The two imports below are not always available depending on the
+# USE_DISTRIBUTED compile flag. Make sure they raise import error
+# if we're trying to use them.
+from torch.distributed import group, ReduceOp
+
+def broadcast(tensor, src, group=group.WORLD):
+    """
+    Broadcasts the tensor to the whole group.
+
+    ``tensor`` must have the same number of elements in all processes
+    participating in the collective.
+
+    Arguments:
+        tensor (Tensor): Data to be sent if ``src`` is the rank of current
+            process.
+        src (int): Source rank.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Received tensor from the broadcast op.
+
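+    Example (illustrative sketch; assumes a 2-rank default process group has
+    already been initialized, e.g. via ``torchrun``, and that this module is
+    imported as ``nnF``)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> import torch.distributed.nn.functional as nnF
+        >>> x = torch.full((2,), float(dist.get_rank() + 1), requires_grad=True)
+        >>> y = nnF.broadcast(x, 0)   # every rank now holds rank 0's data
+        >>> y.sum().backward()        # gradients flow back to ``x`` on rank 0
+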
+    """
+    return _Broadcast.apply(src, group, tensor)
+
+
+def gather(tensor, dst=0, group=group.WORLD):
+    """
+    Gathers a list of tensors in a single process.
+
+    Arguments:
+        tensor (Tensor): Input tensor.
+        dst (int, optional): Destination rank (default is 0).
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
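+
+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> x = torch.full((2,), float(dist.get_rank()), requires_grad=True)
+        >>> gathered = nnF.gather(x, dst=0)   # tuple of per-rank tensors; data is meaningful on ``dst``
+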
+    """
+    return _Gather.apply(dst, group, tensor)
+
+
+def scatter(tensors, src=0, group=group.WORLD):
+    """
+    Scatters a list of tensors to all processes in a group.
+
+    Each process will receive exactly one tensor, which is returned as the
+    output of this function.
+
+    Arguments:
+        tensors (list[Tensor]): List of tensors to scatter on the source rank.
+            Non-source ranks must pass a list of same-shaped tensors as well;
+            their values are ignored and only used to size the output.
+        src (int, optional): Source rank (default is 0).
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Output tensor from the scatter operation.
+
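+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`; every rank builds a same-shaped list, since
+    the implementation only uses its values on ``src``)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> inputs = [torch.ones(2) * (i + 1) for i in range(dist.get_world_size())]
+        >>> out = nnF.scatter(inputs, src=0)   # rank ``r`` receives ``inputs[r]`` from rank 0
+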
+    """
+    return _Scatter.apply(src, group, *tensors)
+
+
+def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
+    """
+    Reduces the tensor data across all machines.
+
+    Only the process with rank ``dst`` is going to receive the final result.
+
+    Arguments:
+        tensor (Tensor): Input of the collective.
+        dst (int): Destination rank.
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Output of the collective.
+
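+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> x = torch.ones(2, requires_grad=True)
+        >>> y = nnF.reduce(x, dst=0)   # element-wise sum over ranks; only meaningful on rank 0
+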
+    """
+    return _Reduce.apply(dst, op, group, tensor)
+
+
+def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
+    """
+    Reduces, then scatters a list of tensors to all processes in a group.
+
+    Arguments:
+        output (Tensor): Output tensor.
+        input_list (list[Tensor]): List of tensors to reduce and scatter.
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Output of the collective.
+
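+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> world_size = dist.get_world_size()
+        >>> inputs = [torch.ones(2) * (r + 1) for r in range(world_size)]
+        >>> out = nnF.reduce_scatter(torch.zeros(2), inputs)   # rank ``r`` gets the sum over ranks of ``inputs[r]``
+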
+    """
+    return _Reduce_Scatter.apply(op, group, output, *input_list)
+
+
+def all_gather(tensor, group=group.WORLD):
+    """
+    Gathers tensors from the whole group in a list.
+
+    Arguments:
+        tensor (Tensor): Tensor to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        tuple[Tensor]: Output of the collective.
+
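+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> x = torch.full((2,), float(dist.get_rank()), requires_grad=True)
+        >>> ys = nnF.all_gather(x)             # ys[r] holds rank r's tensor on every rank
+        >>> torch.stack(ys).sum().backward()   # gradients flow back to every rank's ``x``
+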
+    """
+    return _AllGather.apply(group, tensor)
+
+def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
+    """
+    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
+
+    Args:
+        output_tensor (Tensor): Output tensor. It should be correctly sized to
+            hold the concatenated output of the collective.
+        input_tensor (Tensor): Tensor to be gathered from the current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Examples:
+        >>> # All tensors below are of torch.int64 dtype.
+        >>> # We have a single process group with 2 ranks.
+        >>> # xdoctest: +SKIP("distributed")
+        >>> rank = dist.get_rank()
+        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
+        >>> output_tensor
+        tensor([0, 0]) # Rank 0 and 1
+        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
+        >>> tensor
+        tensor([1]) # Rank 0
+        tensor([2]) # Rank 1
+        >>> output_tensor = _all_gather_base(output_tensor, tensor)
+        >>> output_tensor
+        tensor([1, 2]) # Rank 0
+        tensor([1, 2]) # Rank 1
+
+    .. warning::
+        `_all_gather_base` is experimental and subject to change.
+        It is the caller's responsibility to ensure the output_tensor
+        is correctly sized.
+
+    """
+    return _AllGatherBase.apply(output_tensor, input_tensor, group)
+
+
+def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
+    """
+    Each process scatters a list of input tensors to all processes in the group and returns the gathered list of tensors in the output list.
+
+    Arguments:
+        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
+        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        tuple[Tensor]: Output of the collective.
+
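+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> world_size = dist.get_world_size()
+        >>> inputs = [torch.full((2,), float(dist.get_rank())) for _ in range(world_size)]
+        >>> outputs = [torch.empty(2) for _ in range(world_size)]
+        >>> outputs = nnF.all_to_all(outputs, inputs)   # outputs[r] was sent by rank r
+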
+    """
+    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
+
+
+def all_to_all_single(
+    output,
+    input,
+    output_split_sizes=None,
+    input_split_sizes=None,
+    group=group.WORLD,
+):
+    """
+    Each process splits the input tensor and then scatters the split list to all processes in the group.
+
+    It then concatenates the received tensors from all processes in the group and returns a single output tensor.
+
+    Arguments:
+        output (Tensor): Gathered concatenated output tensor.
+        input (Tensor): Input tensor to scatter.
+        output_split_sizes (list[int], optional): Output split sizes for dim 0.
+            If None or empty, dim 0 of the ``output`` tensor must be divisible
+            by ``world_size``.
+        input_split_sizes (list[int], optional): Input split sizes for dim 0.
+            If None or empty, dim 0 of the ``input`` tensor must be divisible
+            by ``world_size``.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Output of the collective.
+
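+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> world_size = dist.get_world_size()
+        >>> x = torch.arange(world_size, dtype=torch.float32) + world_size * dist.get_rank()
+        >>> out = nnF.all_to_all_single(torch.empty(world_size), x)   # out[r] is the chunk sent by rank r
+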
+    """
+    return _AlltoAllSingle.apply(
+        group, output, output_split_sizes, input_split_sizes, input
+    )
+
+
+def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
+    """
+    Reduces the tensor data across all machines in such a way that all get the final result.
+
+    After the call the returned tensor is going to be bitwise
+    identical in all processes.
+
+    Arguments:
+        tensor (Tensor): Input of the collective.
+        op (optional): One of the values from
+            ``torch.distributed.ReduceOp``
+            enum.  Specifies an operation used for element-wise reductions.
+        group (ProcessGroup, optional): The process group to work on.
+
+    Returns:
+        Tensor: Output of the collective
+
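+    Example (illustrative sketch; assumes the same 2-rank setup and ``nnF``
+    alias as in :func:`broadcast`)::
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> x = torch.full((2,), float(dist.get_rank() + 1), requires_grad=True)
+        >>> y = nnF.all_reduce(x)   # every rank now holds the element-wise sum over ranks
+        >>> y.sum().backward()      # the backward pass all-reduces the gradient as well
+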
+    """
+    return _AllReduce.apply(op, group, tensor)
+
+
+class _Broadcast(Function):
+    @staticmethod
+    def forward(ctx, src, group, tensor):
+        ctx.src = src
+        ctx.group = group
+        ctx.rank = dist.get_rank(group=group)
+        # torch.distributed makes all the calls in place
+        # we allocate new tensors to avoid this
+        tensor = tensor.clone()
+        dist.broadcast(tensor, src, group=group)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
+        if ctx.src != ctx.rank:
+            gx.zero_()
+        return (None, None, gx)
+
+
+class _Gather(Function):
+    @staticmethod
+    def forward(ctx, dst, group, tensor):
+        ctx.dst = dst
+        ctx.group = group
+        # Need to create a list of tensors here to do the
+        # aggregation, get it from the group size
+        # tensor should be correctly sized for the method
+        # gathering
+        tensor_list = [
+            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
+        ]
+
+        tensor = tensor.contiguous()
+        if dist.get_rank(group=group) == dst:
+            dist.gather(tensor, tensor_list, dst, group=group)
+        else:
+            dist.gather(tensor, None, dst, group=group)
+        return tuple(tensor_list)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)
+
+
+class _Scatter(Function):
+    @staticmethod
+    def forward(ctx, src, group, *tensors):
+        ctx.src = src
+        ctx.group = group
+        assert all(t.size() == tensors[0].size() for t in tensors)
+        output = torch.zeros_like(tensors[0])
+        if dist.get_rank(group=group) == src:
+            dist.scatter(output, list(tensors), src, group=group)
+        else:
+            dist.scatter(output, None, src, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)
+
+
+class _Reduce(Function):
+    @staticmethod
+    def forward(ctx, src, op, group, tensor):
+        ctx.src = src
+        ctx.group = group
+        tensor = tensor.clone()
+        dist.reduce(tensor, src, op=op, group=group)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
+
+
+class _Reduce_Scatter(Function):
+    @staticmethod
+    def forward(ctx, op, group, tensor, *input_tensor_list):
+        ctx.group = group
+        # Need contiguous tensors for collectives.
+        tensor = tensor.contiguous()
+        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
+        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
+
+
+class _AllGather(Function):
+    @staticmethod
+    def forward(ctx, group, tensor):
+        # Need contiguous tensors for collectives.
+        tensor = tensor.contiguous()
+
+        ctx.group = group
+        out_tensor_list = [
+            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
+        ]
+
+        dist.all_gather(out_tensor_list, tensor, group=group)
+        return tuple(out_tensor_list)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
+            rank = dist.get_rank(group=ctx.group)
+            gx = torch.empty_like(grad_outputs[rank])
+            gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
+        else:
+            # As many backends don't support ReduceScatter, we use AlltoAll followed
+            # by a sum to emulate the ReduceScatter behavior.
+            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
+            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
+            gx = torch.sum(torch.stack(gxs), dim=0)
+        return (None, gx)
+
+class _AllGatherBase(Function):
+    @staticmethod
+    def forward(ctx, output_tensor, input_tensor, group):
+        ctx.group = group
+        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
+        return output_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
+            world_size = dist.get_world_size(group=ctx.group)
+            out_size = list(grad_output.size())
+            if out_size[0] % world_size != 0:
+                raise RuntimeError(
+                    f'Tensor with dimensions: {out_size} does '
+                    f'not have first dimension divisible by world_size: {world_size}'
+                )
+            out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
+            gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype)
+            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
+        else:
+            raise RuntimeError("Backend not supported!")
+        return (None, gx, None)
+
+class _AlltoAll(Function):
+    @staticmethod
+    def forward(ctx, group, out_tensor_list, *tensors):
+        ctx.group = group
+        ctx.input_tensor_size_list = [
+            tensors[i].size() for i in range(dist.get_world_size(group=group))
+        ]
+        my_rank = dist.get_rank(group=group)
+        tensors = tuple(t.contiguous() for t in tensors)
+        # Implement it by means of scatter/gather; async send/recv operations have issues.
+        if dist.get_backend(group=group) is dist.Backend.GLOO:
+            for i in range(dist.get_world_size(group=group)):
+                to_send = None
+                if i == my_rank:
+                    to_send = list(tensors)
+                dist.scatter(out_tensor_list[i], to_send, i, group=group)
+        else:
+            dist.all_to_all(
+                out_tensor_list,
+                list(tensors),
+                group=group,
+            )
+        return tuple(out_tensor_list)
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        tensor_list = [
+            torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
+            for size in ctx.input_tensor_size_list
+        ]
+        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
+
+
+class _AlltoAllSingle(Function):
+    @staticmethod
+    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
+        ctx.group = group
+        ctx.input_size = input.size()
+        ctx.output_split_sizes = input_split_sizes
+        ctx.input_split_sizes = output_split_sizes
+        dist.all_to_all_single(
+            output,
+            input,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+        )
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype)
+        return (None, None, None, None) + (
+            _AlltoAllSingle.apply(
+                ctx.group,
+                tensor,
+                ctx.output_split_sizes,
+                ctx.input_split_sizes,
+                grad_output.contiguous(),
+            ),
+        )
+
+
+class _AllReduce(Function):
+    @staticmethod
+    def forward(ctx, op, group, tensor):
+        ctx.group = group
+        ctx.op = op
+        tensor = tensor.clone()
+        dist.all_reduce(tensor, op=op, group=group)
+        return tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/__init__.py b/MLPY/Lib/site-packages/torch/distributed/nn/jit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c59e68425d9a7692d862a58dcc06b32c3c91e2b6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f251304788b835010037a9724d1702a808d2f921
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/instantiator.py b/MLPY/Lib/site-packages/torch/distributed/nn/jit/instantiator.py
new file mode 100644
index 0000000000000000000000000000000000000000..56121cc2cd57ec38521f31c00881cf826e4ea59d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/nn/jit/instantiator.py
@@ -0,0 +1,153 @@
+#!/usr/bin/python3
+import importlib
+import logging
+import os
+import sys
+import tempfile
+from typing import Optional
+
+import torch
+from torch.distributed.nn.jit.templates.remote_module_template import (
+    get_remote_module_template,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+_FILE_PREFIX = "_remote_module_"
+_TEMP_DIR = tempfile.TemporaryDirectory()
+INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
+logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
+sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
+
+
+def get_arg_return_types_from_interface(module_interface):
+    assert getattr(
+        module_interface, "__torch_script_interface__", False
+    ), "Expect a TorchScript class interface decorated by @torch.jit.interface."
+    qualified_name = torch._jit_internal._qualified_name(module_interface)
+    cu = torch.jit._state._python_cu
+    module_interface_c = cu.get_interface(qualified_name)
+    assert (
+        "forward" in module_interface_c.getMethodNames()
+    ), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
+    method_schema = module_interface_c.getMethod("forward")
+
+    arg_str_list = []
+    arg_type_str_list = []
+    assert method_schema is not None
+    for argument in method_schema.arguments:
+        arg_str_list.append(argument.name)
+
+        if argument.has_default_value():
+            default_value_str = f" = {argument.default_value}"
+        else:
+            default_value_str = ""
+        arg_type_str = f"{argument.name}: {argument.type}{default_value_str}"
+        arg_type_str_list.append(arg_type_str)
+
+    arg_str_list = arg_str_list[1:]  # Remove "self".
+    args_str = ", ".join(arg_str_list)
+
+    arg_type_str_list = arg_type_str_list[1:]  # Remove "self".
+    arg_types_str = ", ".join(arg_type_str_list)
+
+    assert len(method_schema.returns) == 1
+    argument = method_schema.returns[0]
+    return_type_str = str(argument.type)
+
+    return args_str, arg_types_str, return_type_str
+
+
+def _write(out_path, text):
+    old_text: Optional[str]
+    try:
+        with open(out_path) as f:
+            old_text = f.read()
+    except OSError:
+        old_text = None
+    if old_text != text:
+        with open(out_path, "w") as f:
+            logger.info("Writing %s", out_path)
+            f.write(text)
+    else:
+        logger.info("Skipped writing %s", out_path)
+
+
+def _do_instantiate_remote_module_template(
+    generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
+):
+    generated_code_text = get_remote_module_template(
+        enable_moving_cpu_tensors_to_cuda
+    ).format(**str_dict)
+    out_path = os.path.join(
+        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
+    )
+    _write(out_path, generated_code_text)
+
+    # From importlib doc,
+    # > If you are dynamically importing a module that was created since
+    # the interpreter began execution (e.g., created a Python source file),
+    # you may need to call invalidate_caches() in order for the new module
+    # to be noticed by the import system.
+    importlib.invalidate_caches()
+    generated_module = importlib.import_module(f"{generated_module_name}")
+    return generated_module
+
+
+def instantiate_scriptable_remote_module_template(
+    module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
+):
+    if not getattr(module_interface_cls, "__torch_script_interface__", False):
+        raise ValueError(
+            f"module_interface_cls {module_interface_cls} must be a type object decorated by "
+            "@torch.jit.interface"
+        )
+
+    # Generate the template instance name.
+    module_interface_cls_name = torch._jit_internal._qualified_name(
+        module_interface_cls
+    ).replace(".", "_")
+    generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}"
+
+    # Generate type annotation strs.
+    assign_module_interface_cls_str = (
+        f"from {module_interface_cls.__module__} import "
+        f"{module_interface_cls.__name__} as module_interface_cls"
+    )
+    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
+        module_interface_cls
+    )
+    kwargs_str = ""
+    arrow_and_return_type_str = f" -> {return_type_str}"
+    arrow_and_future_return_type_str = f" -> Future[{return_type_str}]"
+
+    str_dict = dict(
+        assign_module_interface_cls=assign_module_interface_cls_str,
+        arg_types=arg_types_str,
+        arrow_and_return_type=arrow_and_return_type_str,
+        arrow_and_future_return_type=arrow_and_future_return_type_str,
+        args=args_str,
+        kwargs=kwargs_str,
+        jit_script_decorator="@torch.jit.script",
+    )
+    return _do_instantiate_remote_module_template(
+        generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
+    )
+
+
+def instantiate_non_scriptable_remote_module_template():
+    generated_module_name = f"{_FILE_PREFIX}non_scriptable"
+    str_dict = dict(
+        assign_module_interface_cls="module_interface_cls = None",
+        args="*args",
+        kwargs="**kwargs",
+        arg_types="*args, **kwargs",
+        arrow_and_return_type="",
+        arrow_and_future_return_type="",
+        jit_script_decorator="",
+    )
+    # For a non-scriptable template, always enable moving CPU tensors to a cuda device,
+    # because there is no syntax limitation on the extra handling caused by the script.
+    return _do_instantiate_remote_module_template(generated_module_name, str_dict, True)
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__init__.py b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebabf3ef4943cb2bd926a83e8647d932f8876fa8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb3162a75a35f3a295f744d4f7bb1a1c8902f8ec
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..edc2a431dc3b1297590c47411035a058b5097b2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py
@@ -0,0 +1,107 @@
+#!/usr/bin/python3
+
+
+def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool):
+    return _TEMPLATE_PREFIX + (
+        _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA
+        if enable_moving_cpu_tensors_to_cuda
+        else _REMOTE_FORWARD_TEMPLATE
+    )
+
+
+_TEMPLATE_PREFIX = """from typing import *
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch._jit_internal import Future
+from torch.distributed.rpc import RRef
+from typing import Tuple  # pyre-ignore: unused import
+
+
+{assign_module_interface_cls}
+
+
+def forward_async(self, {arg_types}){arrow_and_future_return_type}:
+    args = (self.module_rref, self.device, self.is_device_map_set, {args})
+    kwargs = {{{kwargs}}}
+    return rpc.rpc_async(
+        self.module_rref.owner(),
+        _remote_forward,
+        args,
+        kwargs,
+    )
+
+
+def forward(self, {arg_types}){arrow_and_return_type}:
+    args = (self.module_rref, self.device, self.is_device_map_set, {args})
+    kwargs = {{{kwargs}}}
+    ret_fut = rpc.rpc_async(
+        self.module_rref.owner(),
+        _remote_forward,
+        args,
+        kwargs,
+    )
+    return ret_fut.wait()
+
+
+_generated_methods = [
+    forward_async,
+    forward,
+]
+
+
+{jit_script_decorator}
+"""
+
+# This template may cause a typing error (a mismatch between ``Tuple[()]`` and ``Tuple[Any]``)
+# even though the code is only used for instantiation and not execution.
+# Therefore, only include the handling for moving CPU tensors to a cuda device when necessary.
+# TODO: Merge these two templates together in the future once TorchScript syntax is improved.
+_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """
+def _remote_forward(
+    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
+    module = module_rref.local_value()
+    device = torch.device(device)
+
+    if device.type != "cuda":
+        return module.forward({args}, {kwargs})
+
+    # If the module is on a cuda device,
+    # move any CPU tensor in args or kwargs to the same cuda device.
+    # Since torch script does not support generator expression,
+    # have to use concatenation instead of
+    # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``.
+    args = ({args},)
+    out_args: Tuple[()] = ()
+    for arg in args:
+        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
+        out_args = out_args + arg
+
+    kwargs = {{{kwargs}}}
+    for k, v in kwargs.items():
+        if isinstance(v, Tensor):
+            kwargs[k] = kwargs[k].to(device)
+
+    if is_device_map_set:
+        return module.forward(*out_args, {kwargs})
+
+    # If the device map is empty, then only CPU tensors are allowed to send over wire,
+    # so have to move any GPU tensor to CPU in the output.
+    # Since torch script does not support generator expression,
+    # have to use concatenation instead of
+    # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``.
+    ret: Tuple[()] = ()
+    for i in module.forward(*out_args, {kwargs}):
+        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
+        ret = ret + i
+    return ret
+"""
+
+_REMOTE_FORWARD_TEMPLATE = """
+def _remote_forward(
+    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
+    module = module_rref.local_value()
+
+    return module.forward({args}, {kwargs})
+"""
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__init__.py b/MLPY/Lib/site-packages/torch/distributed/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3b37ae54a5d1ebbada508988623a57cf07acd9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/__init__.py
@@ -0,0 +1,34 @@
+"""
+:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list
+of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the
+optimizer locally on the workers where the parameters live.  The distributed
+optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
+apply the gradients on each worker.
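+
+Example (illustrative sketch; assumes RPC has been initialized, ``param_rrefs``
+is a list of :class:`~torch.distributed.rpc.RRef` objects pointing at remote
+parameters, and ``compute_loss`` is a hypothetical helper that produces a
+scalar loss from them)::
+
+    >>> # xdoctest: +SKIP("distributed")
+    >>> import torch
+    >>> import torch.distributed.autograd as dist_autograd
+    >>> from torch.distributed.optim import DistributedOptimizer
+    >>> dist_optim = DistributedOptimizer(torch.optim.SGD, param_rrefs, lr=0.05)
+    >>> with dist_autograd.context() as context_id:
+    ...     loss = compute_loss(param_rrefs)   # hypothetical helper
+    ...     dist_autograd.backward(context_id, [loss])
+    ...     dist_optim.step(context_id)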
+"""
+import torch
+from torch import optim
+
+from .apply_optimizer_in_backward import (
+    _apply_optimizer_in_backward,
+    _get_in_backward_optimizers,
+)
+from .functional_adadelta import _FunctionalAdadelta
+
+from .functional_adagrad import _FunctionalAdagrad
+from .functional_adam import _FunctionalAdam
+from .functional_adamax import _FunctionalAdamax
+from .functional_adamw import _FunctionalAdamW
+from .functional_rmsprop import _FunctionalRMSprop
+from .functional_rprop import _FunctionalRprop
+from .functional_sgd import _FunctionalSGD
+from .named_optimizer import _NamedOptimizer
+from .utils import as_functional_optim
+
+
+# DistributedOptimizer imports torch.distributed.rpc names, so gate availability
+# based on RPC being available.
+if hasattr(torch._C, "_rpc_init"):
+    from .optimizer import DistributedOptimizer
+
+from .post_localSGD_optimizer import PostLocalSGDOptimizer
+from .zero_redundancy_optimizer import ZeroRedundancyOptimizer
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d91c1d8e65a98e40e17777a5da127cde23c35e21
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e56fdd2fd0eedaf44997b92d1a7f89b6acd8d07
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8166212d674990bdb830d5caa9030b3b2c3580d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1eb9417a655b0fa3737fe971e2b61463a9e474da
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df47c6d98856c39a52f5b2fcc181ac2427219219
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19d5c0ed68522a3a4cc34c56d443553c67836e78
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4e88e7248c84bd4f98d73f8e198982b79d6834e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dd0247f7038b7a0be849907fdb2cf790c049a14
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57a741cdcf26270c457faebc75df605bcd39279f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5974af622f2623e24929709b688fd16180bc61a2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b716e24744624fbf9d2f060edc324be3d816ffb
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e7746d1d773ed61939b94845f1605d86d3b95a4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a637b34cffbe7970b71611bc472c000befa7984a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3be951c7299123fc2dd955a1c83f02acb1e86a4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e705ed2527fcdf2099dd71491e8b4fea112ebf9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py b/MLPY/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..182cfc6ddb9ea9d299546e21b249ec571284038c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, Iterable, List, no_type_check, Type
+
+import torch
+
+__all__: List[str] = []
+
+# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
+# without changing its life-time.
+# NOTE: An alternative is to add the meta-data as an attribute to the tensor,
+#       but that would serialize the meta-data if the Tensor is serialized.
+param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
+param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
+
+@no_type_check
+def _apply_optimizer_in_backward(
+    optimizer_class: Type[torch.optim.Optimizer],
+    params: Iterable[torch.nn.Parameter],
+    optimizer_kwargs: Dict[str, Any],
+    register_hook: bool = True,
+) -> None:
+    """
+    Upon ``backward()``, the optimizer specified for each parameter will fire after
+    the gradient has been accumulated into the parameter.
+
+    Note - gradients for these parameters will be set to None after ``backward()``.
+    This means that any other optimizer that was not registered for this parameter
+    via ``_apply_optimizer_in_backward`` will effectively be a no-op for it.
+
+    Args:
+        optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
+        params: (Iterator[nn.Parameter]): parameters to apply optimizer state to
+        optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor
+        register_hook: (bool): whether to register a hook that runs the optimizer
+            after gradient for this parameter is accumulated. This is the default
+            way that optimizer in backward is implemented, but specific use cases
+            (such as DDP) may wish to override this to implement custom behavior.
+            (Default = True)
+
+    Example::
+        params_generator = model.parameters()
+        param_1 = next(params_generator)
+        remainder_params = list(params_generator)
+
+        _apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
+        _apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
+
+        model(...).sum().backward() # after backward, parameters will already
+        # have their registered optimizer(s) applied.
+
+    """
+    torch._C._log_api_usage_once(
+        "torch.distributed.optim.apply_optimizer_in_backward"
+    )
+
+    @no_type_check
+    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
+        # view_as creates a node in autograd graph that allows us access to the
+        # parameter's AccumulateGrad autograd function object. We register a
+        # hook on this object to fire the optimizer when the gradient for
+        # this parameter is ready (has been accumulated into .grad field)
+
+        # Don't create a new acc_grad if we already have one
+        # i.e. for shared parameters or attaching multiple optimizers to a param.
+        if param not in param_to_acc_grad_map:
+            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
+
+        optimizer = optimizer_class([param], **optimizer_kwargs)
+
+        if not hasattr(param, "_in_backward_optimizers"):
+            param._in_backward_optimizers = []  # type: ignore[attr-defined]
+            # TODO: Remove these attributes once we have a better way of accessing
+            # optimizer classes and kwargs for a parameter.
+            param._optimizer_classes = []  # type: ignore[attr-defined]
+            param._optimizer_kwargs = []  # type: ignore[attr-defined]
+
+        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
+        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
+        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]
+
+        if not register_hook:
+            return
+
+        def optimizer_hook(*_unused) -> None:
+            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
+                opt.step()
+
+            param.grad = None
+
+        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
+        if param not in param_to_optim_hook_handle_map:
+            param_to_optim_hook_handle_map[param] = []
+        param_to_optim_hook_handle_map[param].append(handle)
+
+    for param in params:
+        _apply_optimizer_in_backward_to_param(param)
+
+
+def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
+    """
+    Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
+    optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called
+    by the user and are intended to be used for things like checkpointing.
+
+    Args:
+        module: (torch.nn.Module): model to retrieve in-backward optimizers for
+
+    Returns:
+        List[torch.optim.Optimizer]: the in-backward optimizers.
+
+    Example::
+        _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
+        optims = _get_in_backward_optimizers(model)
+    """
+    optims: List[torch.optim.Optimizer] = []
+    for param in module.parameters():
+        optims.extend(getattr(param, "_in_backward_optimizers", []))
+
+    return optims
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_adadelta.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adadelta.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f4d83fb60a060452f5f7214c1685392bbff6e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adadelta.py
@@ -0,0 +1,102 @@
+from typing import Dict, List, Optional
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible functional Adadelta optimizer
+# that we use in a functional way.
+# Instead of using `param.grad` when updating parameters,
+# we explicitly allow the distributed optimizer to pass gradients to
+# the `step` function. In this way, we can separate the gradients
+# and parameters and allow a multithreaded trainer to update the
+# parameters without data races from accumulating into the same .grad.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalAdadelta:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1.0,
+        rho: float = 0.9,
+        eps: float = 1e-6,
+        weight_decay: float = 0.0,
+        foreach: bool = False,
+        maximize: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        self.defaults = {
+            "lr": lr,
+            "rho": rho,
+            "eps": eps,
+            "weight_decay": weight_decay,
+        }
+        self.foreach = foreach
+        self.maximize = maximize
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        square_avgs = []
+        acc_deltas = []
+        lr = self.defaults["lr"]
+        rho = self.defaults["rho"]
+        eps = self.defaults["eps"]
+        weight_decay = self.defaults["weight_decay"]
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+        has_complex = False
+        for param, gradient in zip(params, gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    state["square_avg"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    state["acc_delta"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+
+                state = self.state[param]
+                square_avgs.append(state["square_avg"])
+                acc_deltas.append(state["acc_delta"])
+
+        with torch.no_grad():
+            F.adadelta(
+                params_with_grad,
+                grads,
+                square_avgs,
+                acc_deltas,
+                lr=lr,
+                rho=rho,
+                eps=eps,
+                weight_decay=weight_decay,
+                foreach=self.foreach,
+                maximize=self.maximize,
+                has_complex=has_complex
+            )
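
Illustrative sketch of the calling convention (not part of the vendored file; in practice only distributed-optimizer internals construct this class): gradients are computed explicitly and handed to `step`, so `param.grad` is never read:

```
import torch
from torch.distributed.optim.functional_adadelta import _FunctionalAdadelta

params = [torch.randn(3, requires_grad=True), torch.randn(3, requires_grad=True)]
opt = _FunctionalAdadelta(params, lr=1.0)

loss = sum((p * p).sum() for p in params)
grads = list(torch.autograd.grad(loss, params))  # one gradient per parameter
opt.step(grads)  # aligned positionally with the constructor's `params` list
```
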
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_adagrad.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adagrad.py
new file mode 100644
index 0000000000000000000000000000000000000000..280201ae4cf61be11c724d14a064964dc1797f84
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adagrad.py
@@ -0,0 +1,104 @@
+from typing import Dict, List, Optional
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional Adagrad Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the user pass gradients to the `step` function.
+# This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalAdagrad:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-2,
+        lr_decay: float = 0.0,
+        weight_decay: float = 0.0,
+        initial_accumulator_value: float = 0.0,
+        warmup_lr_multiplier: float = 1.0,
+        warmup_num_iters: float = 0.0,
+        eps: float = 1e-10,
+        coalesce_grad: bool = True,
+        foreach: bool = False,
+        maximize: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        self.defaults = {
+            "lr": lr,
+            "lr_decay": lr_decay,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "initial_accumulator_value": initial_accumulator_value,
+            "warmup_lr_multiplier": warmup_lr_multiplier,
+            "warmup_num_iters": warmup_num_iters,
+        }
+        self.coalesce_grad = coalesce_grad
+        self.foreach = foreach
+        self.maximize = maximize
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+        # TODO: no union or any types in TorchScript, make step a scalar tensor instead
+        # This is also needed if we want to call share_memory on the step across processes
+        for p in self.param_group["params"]:
+            self.state[p] = {
+                "sum": torch.full_like(p.data, initial_accumulator_value),
+                "step": torch.tensor(0.0),
+            }
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        state_sums = []
+        state_steps: List[Tensor] = []
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_sparse_grad, has_complex = False, False
+        for param, gradient in zip(self.param_group["params"], gradients):
+            if gradient is not None:
+                has_sparse_grad |= gradient.is_sparse
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                state = self.state[param]
+                state_sums.append(state["sum"])
+                state_steps.append(state["step"])
+
+        with torch.no_grad():
+            F.adagrad(
+                params,
+                grads,
+                state_sums,
+                state_steps,
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                lr_decay=self.defaults["lr_decay"],
+                eps=self.defaults["eps"],
+                has_sparse_grad=has_sparse_grad,
+                foreach=self.foreach,
+                maximize=self.maximize,
+                has_complex=has_complex,
+            )
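
One difference from the other files in this directory, sketched below under the same internal-API caveat: `_FunctionalAdagrad` creates its per-parameter state eagerly in `__init__`, while the others initialize state lazily on the first `step`:

```
import torch
from torch.distributed.optim.functional_adagrad import _FunctionalAdagrad

w = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, requires_grad=True)
# State ("sum", "step") for both parameters already exists after construction.
opt = _FunctionalAdagrad([w, b], lr=0.1, initial_accumulator_value=0.1)

loss = (w.sum() + b.sum()) ** 2
grads = list(torch.autograd.grad(loss, [w, b]))
opt.step(grads)  # gradients are positionally aligned with the constructor's params
```
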
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_adam.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adam.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9739bd1c6fb2793ca85c8dc9c9eba55c346ffc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adam.py
@@ -0,0 +1,196 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional Adam Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalAdam:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsgrad: bool = False,
+        maximize: bool = False,
+        foreach: bool = False,
+        fused: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        self.defaults = {
+            "lr": lr,
+            "eps": eps,
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "weight_decay": weight_decay,
+        }
+        self.amsgrad = amsgrad
+        self.maximize = maximize
+        self.foreach = foreach
+        self.fused = fused
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+        """
+        Similar to step, but operates on a single parameter and optionally a
+        gradient tensor.
+        """
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        max_exp_avg_sqs = []
+        state_steps: List[Tensor] = []
+        has_complex = torch.is_complex(param)
+        if grad is not None:
+            params_with_grad.append(param)
+            grads.append(grad)
+        if param not in self.state:
+            self.state[param] = {}
+            state = self.state[param]
+            state["step"] = torch.tensor(0.0)
+            state["exp_avg"] = torch.zeros_like(
+                param, memory_format=torch.preserve_format
+            )
+            state["exp_avg_sq"] = torch.zeros_like(
+                param, memory_format=torch.preserve_format
+            )
+            if self.amsgrad:
+                state["max_exp_avg_sq"] = torch.zeros_like(
+                    param, memory_format=torch.preserve_format
+                )
+
+        state = self.state[param]
+        exp_avgs.append(state["exp_avg"])
+        exp_avg_sqs.append(state["exp_avg_sq"])
+
+        if self.amsgrad:
+            max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+        state_steps.append(state["step"])
+        with torch.no_grad():
+            F.adam(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=self.amsgrad,
+                has_complex=has_complex,
+                maximize=self.maximize,
+                beta1=self.defaults["beta1"],
+                beta2=self.defaults["beta2"],
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                eps=self.defaults["eps"],
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+            )
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        max_exp_avg_sqs = []
+        state_steps: List[Tensor] = []
+        has_complex = False
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        for param, gradient in zip(self.param_group["params"], gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    if self.amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(
+                            param, memory_format=torch.preserve_format
+                        )
+
+                state = self.state[param]
+
+                exp_avgs.append(state["exp_avg"])
+                exp_avg_sqs.append(state["exp_avg_sq"])
+
+                if self.amsgrad:
+                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+                state_steps.append(state["step"])
+
+        with torch.no_grad():
+            F.adam(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=self.amsgrad,
+                has_complex=has_complex,
+                maximize=self.maximize,
+                beta1=self.defaults["beta1"],
+                beta2=self.defaults["beta2"],
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                eps=self.defaults["eps"],
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+            )
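
`step_param` is the per-parameter entry point used when the optimizer is driven one parameter at a time (for example, overlapped with backward); a hypothetical direct use, for illustration only:

```
import torch
from torch.distributed.optim.functional_adam import _FunctionalAdam

params = [torch.randn(4, requires_grad=True) for _ in range(2)]
opt = _FunctionalAdam(params, lr=1e-3)

loss = sum((p * p).sum() for p in params)
grads = torch.autograd.grad(loss, params)
for p, g in zip(params, grads):
    opt.step_param(p, g)  # Adam state for `p` is created lazily on first use
```
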
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamax.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamax.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e02c4ae16edb177c9770aee42728c60980d0cd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamax.py
@@ -0,0 +1,117 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional Adamax Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalAdamax:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        foreach: bool = False,
+        maximize: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        self.defaults = {
+            "lr": lr,
+            "eps": eps,
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "weight_decay": weight_decay,
+        }
+        self.foreach = foreach
+        self.maximize = maximize
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_infs = []
+        state_steps: List[Tensor] = []
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_complex = False
+        for param, gradient in zip(self.param_group["params"], gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_inf"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+
+                state = self.state[param]
+
+                exp_avgs.append(state["exp_avg"])
+                exp_infs.append(state["exp_inf"])
+                state_steps.append(state["step"])
+
+        with torch.no_grad():
+            F.adamax(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_infs,
+                state_steps,
+                eps=self.defaults["eps"],
+                beta1=self.defaults["beta1"],
+                beta2=self.defaults["beta2"],
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                foreach=self.foreach,
+                maximize=self.maximize,
+                has_complex=has_complex,
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamw.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..58752d34615848bcecb354966f8a955e187ee407
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_adamw.py
@@ -0,0 +1,197 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional AdamW Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalAdamW:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        amsgrad: bool = False,
+        maximize: bool = False,
+        foreach: bool = False,
+        fused: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        self.defaults = {
+            "lr": lr,
+            "eps": eps,
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "weight_decay": weight_decay,
+        }
+        self.amsgrad = amsgrad
+        self.maximize = maximize
+        self.foreach = foreach
+        self.fused = fused
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        max_exp_avg_sqs = []
+        state_steps: List[Tensor] = []
+        has_complex = torch.is_complex(param)
+        if grad is not None:
+            params_with_grad.append(param)
+            grads.append(grad)
+        # Lazy state initialization
+        if param not in self.state:
+            self.state[param] = {}
+            state = self.state[param]
+            state["step"] = torch.tensor(0.0)
+            # Exponential moving average of gradient values
+            state["exp_avg"] = torch.zeros_like(
+                param, memory_format=torch.preserve_format
+            )
+            # Exponential moving average of squared gradient values
+            state["exp_avg_sq"] = torch.zeros_like(
+                param, memory_format=torch.preserve_format
+            )
+            if self.amsgrad:
+                # Maintains max of all exp. moving avg. of sq. grad. values
+                state["max_exp_avg_sq"] = torch.zeros_like(
+                    param, memory_format=torch.preserve_format
+                )
+
+        state = self.state[param]
+
+        exp_avgs.append(state["exp_avg"])
+        exp_avg_sqs.append(state["exp_avg_sq"])
+
+        if self.amsgrad:
+            max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+        state_steps.append(state["step"])
+        with torch.no_grad():
+            F.adamw(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=self.amsgrad,
+                maximize=self.maximize,
+                beta1=self.defaults["beta1"],
+                beta2=self.defaults["beta2"],
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                eps=self.defaults["eps"],
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+                has_complex=has_complex,
+            )
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        max_exp_avg_sqs = []
+        state_steps: List[Tensor] = []
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_complex = False
+        for param, gradient in zip(self.param_group["params"], gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    if self.amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(
+                            param, memory_format=torch.preserve_format
+                        )
+
+                state = self.state[param]
+
+                exp_avgs.append(state["exp_avg"])
+                exp_avg_sqs.append(state["exp_avg_sq"])
+
+                if self.amsgrad:
+                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+                state_steps.append(state["step"])
+
+        with torch.no_grad():
+            F.adamw(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=self.amsgrad,
+                maximize=self.maximize,
+                beta1=self.defaults["beta1"],
+                beta2=self.defaults["beta2"],
+                lr=self.defaults["lr"],
+                weight_decay=self.defaults["weight_decay"],
+                eps=self.defaults["eps"],
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+                has_complex=has_complex,
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py
new file mode 100644
index 0000000000000000000000000000000000000000..61c4b15fa79b94fe463c1a8ea507e38975b5bfaa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py
@@ -0,0 +1,122 @@
+from typing import Dict, List, Optional
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional RMSprop Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalRMSprop:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-2,
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        momentum: float = 0.0,
+        centered: bool = False,
+        foreach: bool = False,
+        maximize: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        self.defaults = {
+            "lr": lr,
+            "alpha": alpha,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "momentum": momentum,
+        }
+        self.centered = centered
+        self.foreach = foreach
+        self.maximize = maximize
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        square_avgs = []
+        grad_avgs = []
+        momentum_buffer_list = []
+        lr = self.defaults["lr"]
+        alpha = self.defaults["alpha"]
+        eps = self.defaults["eps"]
+        momentum = self.defaults["momentum"]
+        weight_decay = self.defaults["weight_decay"]
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_complex = False
+        for param, gradient in zip(params, gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    state["square_avg"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    if momentum > 0:
+                        state["momentum_buffer"] = torch.zeros_like(
+                            param, memory_format=torch.preserve_format
+                        )
+                    if self.centered:
+                        state["grad_avg"] = torch.zeros_like(
+                            param, memory_format=torch.preserve_format
+                        )
+
+                state = self.state[param]
+                square_avgs.append(state["square_avg"])
+                if momentum > 0:
+                    momentum_buffer_list.append(state["momentum_buffer"])
+                if self.centered:
+                    grad_avgs.append(state["grad_avg"])
+
+                state["step"] += 1
+
+        with torch.no_grad():
+            F.rmsprop(
+                params_with_grad,
+                grads,
+                square_avgs,
+                grad_avgs,
+                momentum_buffer_list,
+                lr=lr,
+                alpha=alpha,
+                eps=eps,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                centered=self.centered,
+                foreach=self.foreach,
+                maximize=self.maximize,
+                has_complex=has_complex,
+            )
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_rprop.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_rprop.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e4d5fd9b19f1e26f4b1ddf1f6348c9e602da4a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_rprop.py
@@ -0,0 +1,100 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional Rprop Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalRprop:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-2,
+        etas: Tuple[float, float] = (0.5, 1.2),
+        step_sizes: Tuple[float, float] = (1e-6, 50),
+        foreach: bool = False,
+        maximize: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        self.defaults = {
+            "lr": lr,
+        }
+        self.etas = etas
+        self.step_sizes = step_sizes
+        self.foreach = foreach
+        self.maximize = maximize
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        prevs = []
+        step_sizes = []
+        lr = self.defaults["lr"]
+        etaminus, etaplus = self.etas
+        step_size_min, step_size_max = self.step_sizes
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_complex = False
+        for param, gradient in zip(params, gradients):
+            if gradient is not None:
+                has_complex |= torch.is_complex(param)
+                params_with_grad.append(param)
+                grads.append(gradient)
+                # Lazy state initialization
+                if param not in self.state:
+                    self.state[param] = {}
+                    state = self.state[param]
+                    state["step"] = torch.tensor(0.0)
+                    state["prev"] = torch.zeros_like(
+                        param, memory_format=torch.preserve_format
+                    )
+                    state["step_size"] = torch.full_like(gradient, lr)
+
+                state = self.state[param]
+                prevs.append(state["prev"])
+                step_sizes.append(state["step_size"])
+
+                state["step"] += 1
+
+        with torch.no_grad():
+            F.rprop(
+                params_with_grad,
+                grads,
+                prevs,
+                step_sizes,
+                step_size_min=step_size_min,
+                step_size_max=step_size_max,
+                etaminus=etaminus,
+                etaplus=etaplus,
+                foreach=self.foreach,
+                maximize=self.maximize,
+                has_complex=has_complex,
+            )
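
Illustrative sketch (internal API, shown only to make the semantics concrete): Rprop updates with the sign of the gradient and a per-element step size seeded from `lr`, so the gradient's magnitude does not affect the first step:

```
import torch
from torch.distributed.optim.functional_rprop import _FunctionalRprop

p1 = torch.zeros(3, requires_grad=True)
p2 = torch.zeros(3, requires_grad=True)
_FunctionalRprop([p1], lr=0.01).step([torch.full((3,), 1e-3)])
_FunctionalRprop([p2], lr=0.01).step([torch.full((3,), 1e3)])

# Both gradients are positive, so both parameters take the same -0.01 step.
assert torch.allclose(p1, p2)
```
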
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/functional_sgd.py b/MLPY/Lib/site-packages/torch/distributed/optim/functional_sgd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cb305c2fb3c21e51d1de970b0e3c8107c3a401e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/functional_sgd.py
@@ -0,0 +1,160 @@
+from typing import Dict, List, Optional
+
+import torch
+import torch.optim._functional as F
+
+from torch import Tensor
+
+__all__: List[str] = []
+
+# Define a TorchScript compatible Functional SGD Optimizer
+# that we use in a functional way.
+# Instead of reading `param.grad` when updating parameters,
+# we explicitly let the distributed optimizer pass gradients to
+# the `step` function. This keeps the gradients and parameters separate
+# and allows multithreaded trainers to update the parameters
+# without data races from accumulating into the same `.grad`.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class _FunctionalSGD:
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-2,
+        momentum: float = 0.0,
+        dampening: float = 0.0,
+        weight_decay: float = 0.0,
+        nesterov: bool = False,
+        maximize: bool = False,
+        foreach: bool = False,
+        fused: bool = False,
+        _allow_empty_param_list: bool = False,
+    ):
+        self.defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "dampening": dampening,
+            "weight_decay": weight_decay,
+        }
+        self.nesterov = nesterov
+        self.maximize = maximize
+        self.foreach = foreach
+        self.fused = fused
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0 and not _allow_empty_param_list:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow user to add additional
+        # param group as it's not a common use case.
+        self.param_group = {"params": params}
+
+    def step_param(self, param: Tensor, grad: Optional[Tensor]):
+        """Similar to self.step, but operates on a single parameter and
+        its gradient.
+        """
+        # TODO: Once step_param interface is robust, refactor step to call
+        # step param on each param.
+        weight_decay = self.defaults["weight_decay"]
+        momentum = self.defaults["momentum"]
+        dampening = self.defaults["dampening"]
+        lr = self.defaults["lr"]
+        params = [param]
+        momentum_buffer_list: List[Optional[Tensor]] = []
+        grads = []
+
+        has_sparse_grad = False
+        if grad is not None:
+            grads.append(grad)
+            if grad.is_sparse:
+                has_sparse_grad = True
+            if param not in self.state:
+                self.state[param] = {}
+            state = self.state[param]
+            if "momentum_buffer" not in state:
+                momentum_buffer_list.append(None)
+            else:
+                momentum_buffer_list.append(state["momentum_buffer"])
+
+        with torch.no_grad():
+            F.sgd(
+                params,
+                grads,
+                momentum_buffer_list,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                lr=lr,
+                dampening=dampening,
+                nesterov=self.nesterov,
+                maximize=self.maximize,
+                has_sparse_grad=has_sparse_grad,
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+            )
+        # update momentum_buffer in state
+        state = self.state[param]
+        momentum_buffer = momentum_buffer_list[0]
+        if momentum_buffer is not None:
+            state["momentum_buffer"] = momentum_buffer
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group["params"]
+        params_with_grad = []
+        grads = []
+        momentum_buffer_list: List[Optional[Tensor]] = []
+        lr = self.defaults["lr"]
+        weight_decay = self.defaults["weight_decay"]
+        momentum = self.defaults["momentum"]
+        dampening = self.defaults["dampening"]
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the gradients passed in does not equal to the size of the parameters!"
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        has_sparse_grad = False
+        for param, gradient in zip(params, gradients):
+            if gradient is not None:
+                params_with_grad.append(param)
+                grads.append(gradient)
+                if gradient.is_sparse:
+                    has_sparse_grad = True
+
+                if param not in self.state:
+                    self.state[param] = {}
+
+                state = self.state[param]
+                if "momentum_buffer" not in state:
+                    momentum_buffer_list.append(None)
+                else:
+                    momentum_buffer_list.append(state["momentum_buffer"])
+
+        with torch.no_grad():
+            F.sgd(
+                params_with_grad,
+                grads,
+                momentum_buffer_list,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                lr=lr,
+                dampening=dampening,
+                nesterov=self.nesterov,
+                maximize=self.maximize,
+                has_sparse_grad=has_sparse_grad,
+                foreach=self.foreach,
+                fused=self.fused,
+                grad_scale=None,
+                found_inf=None,
+            )
+
+        # update momentum_buffers in state
+        for i, p in enumerate(params_with_grad):
+            state = self.state[p]
+            momentum_buffer = momentum_buffer_list[i]
+            if momentum_buffer is not None:
+                state["momentum_buffer"] = momentum_buffer
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/named_optimizer.py b/MLPY/Lib/site-packages/torch/distributed/optim/named_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..87cb734abccb132ee994555d7f5cf7bb9a823c90
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/named_optimizer.py
@@ -0,0 +1,331 @@
+import logging
+import warnings
+
+from copy import deepcopy
+from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload
+
+import torch
+import torch.nn as nn
+from torch import optim
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+
+__all__: List[str] = []
+
+logger = logging.getLogger(__name__)
+
+
+class _NamedOptimizer(optim.Optimizer):
+    """
+    ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.
+
+    We replace the original (integer) keys in the optimizer state with
+    fully qualified name (FQN) strings. Users can initialize the optim just as they
+    would initialize a PyTorch optim; the only difference is that they also need to
+    pass in the FQN of each parameter.
+
+    Args:
+        named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
+            Mapping from FQN to parameter.
+        optimizer_class (optim.Optimizer):
+            The class of optimizer to instantiate.
+        param_groups (Collection[Mapping[str, Any]]):
+            `param_groups` to pass to the optimizer if specified.
+            The parameters listed in each group must come from ``named_parameters``.
+            Default: None
+        module (nn.Module): the module whose parameters are updated
+            by the optimizer.
+        args: positional arguments to pass to the optimizer constructor.
+        kwargs: keyword arguments to pass to the optimizer constructor.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch import optim
+        >>> from torch.distributed.optim import _NamedOptimizer
+        >>>
+        >>> # Define the named optimizer.
+        >>> m = Model(...)
+        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
+        >>> # Forward pass + backward pass.
+        >>> named_optim.step()
+        >>> ...
+        >>> # Call state_dict for the named optimizer returns a FQN state_dict.
+        >>> named_optim.state_dict()
+
+    Warning: This API is still in development and subject to change.
+
+    TODO: Add tutorial for _NamedOptimizer.
+    TODO: Add documentation in the docstring for the public attributes
+          like self.param_groups and self.named_parameters.
+    """
+
+    def __init__(
+        self,
+        named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
+        optimizer_class: optim.Optimizer,
+        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
+        module: Optional[nn.Module] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
+        self.param_groups: Collection[Mapping[str, Any]] = param_groups  # type: ignore[assignment]
+        self._param_groups_check()
+        self.named_parameters = dict(named_parameters)
+        params_for_optimizer = (
+            self.named_parameters.values() if param_groups is None else param_groups
+        )
+        self._optimizer = optimizer_class(  # type: ignore[operator]
+            params_for_optimizer,
+            *args,
+            **kwargs,
+        )
+        self.module = module
+        if param_groups is None:
+            self.ordered_param_keys = list(self.named_parameters.keys())
+        else:
+            warnings.warn(
+                "Since we pass in param_groups, we will use param_groups to "
+                "initialize the optimizer, not all parameters of the module."
+            )
+            param_to_key = {param: key for key, param in self.named_parameters.items()}  # type: ignore[misc, has-type]
+            ordered_param_keys = []
+            for group in param_groups:
+                for param in group["params"]:
+                    if param not in param_to_key:
+                        raise ValueError(
+                            f"Expect param name {param} found in param group but is missing."
+                        )
+                    ordered_param_keys.append(param_to_key[param])
+            self.ordered_param_keys = ordered_param_keys
+        # Update param_groups from optimizer.
+        self.param_groups = self._optimizer.param_groups
+
+    def _param_groups_check(self):
+        if self.param_groups is not None:
+            for param_group in self.param_groups:
+                assert isinstance(param_group, dict), "param group must be a dict"
+                assert "params" in param_group, "param group must contain key params"
+                params = param_group["params"]
+                if isinstance(params, torch.Tensor):
+                    params = [params]
+                params = list(params)
+                for param in params:
+                    if not isinstance(param, torch.Tensor):
+                        raise TypeError(
+                            "optimizer can only optimize Tensors, "
+                            "but one of the params is " + torch.typename(param)
+                        )
+                param_group["params"] = params
+
+    def state_dict(self) -> Dict[str, Any]:
+        """
+        Return the ``state_dict`` of the optimizer.
+
+        Instead of using numbers to index
+        parameters, we use each parameter's fully qualified name (FQN) as the key.
+        """
+        state_dict = self._optimizer.state_dict()
+        param_groups = state_dict["param_groups"]
+
+        ret_state = {
+            self.ordered_param_keys[st_key]: state_val
+            for st_key, state_val in state_dict["state"].items()
+        }
+
+        ret_groups = []
+        for group in param_groups:
+            param_keys = []
+            for param in group["params"]:
+                param_keys.append(self.ordered_param_keys[param])
+            ret_group = {"params": sorted(param_keys)}
+            for k, v in group.items():
+                if k != "params":
+                    ret_group[k] = deepcopy(v)
+            ret_groups.append(ret_group)
+
+        return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
+
+    @overload
+    def step(self, closure: None = ...) -> None:
+        ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float:
+        ...
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        """
+        Perform a single optimization step.
+
+        This will call :meth:`torch.optim.Optimizer.step` on the wrapped
+        optimizer.
+        """
+        return self._optimizer.step(closure=closure)
+
+    @property
+    def state(self) -> Mapping[torch.Tensor, Any]:  # type: ignore[override]
+        return self._optimizer.state
+
+    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
+        """
+        Define the default behavior for loading a state_dict into ``_NamedOptimizer``.
+
+        Sample Code
+        ```
+            my_model = MyModule()
+            optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
+            ...
+
+            optim_state_dict = optimizer.state_dict()
+            ...
+            ...
+
+            optimizer.load_state_dict(optim_state_dict)
+            ...
+        ```
+        Args:
+            state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
+                Note that this state dict update is performed in place.
+
+        .. note:: PyTorch uses lazy initialization for optimizer state,
+            so it is possible that no optim state exists yet when a user calls
+            ``load_state_dict``. ``_NamedOptimizer`` is stricter: ``load_state_dict``
+            may only be called after the state has been initialized.
+            This allows the optim ``state_dict`` being loaded to be validated.
+        """
+        new_state_dict = self._optimizer.state_dict()
+        state_dict = self._pre_load_state_dict(state_dict)
+        state = state_dict["state"]
+        new_state = new_state_dict["state"]
+        if len(new_state) == 0:
+            raise ValueError(
+                "Expects the optim to be initialized before load but found not initialized."
+            )
+
+        for idx, param_key in enumerate(self.ordered_param_keys):
+            # When conditional training is performed, not all parameters are updated in the optim.
+            if param_key not in state.keys():
+                continue
+            if len(state[param_key]) != len(new_state[idx]):
+                raise ValueError(
+                    f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
+                )
+            # Iterate through all optimizer states.
+            for state_key, state_val in new_state[idx].items():
+                if state_key not in state[param_key]:
+                    raise ValueError(
+                        f"Expects state {state_key} for parameter {param_key} but not found."
+                    )
+
+                src_state_val = state[param_key][state_key]
+                if isinstance(state_val, ShardedTensor):
+                    assert isinstance(src_state_val, ShardedTensor)
+                    num_shards = len(state_val.local_shards())
+                    num_new_shards = len(src_state_val.local_shards())
+                    if num_shards != num_new_shards:
+                        raise ValueError(
+                            f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
+                        )
+                    for shard, src_shard in zip(
+                        state_val.local_shards(), src_state_val.local_shards()
+                    ):
+                        shard.tensor.detach().copy_(src_shard.tensor)
+                elif isinstance(state_val, torch.Tensor):
+                    assert isinstance(src_state_val, torch.Tensor)
+                    state_val.detach().copy_(src_state_val)
+                else:
+                    new_state[idx][state_key] = deepcopy(src_state_val)
+
+        # Load param_groups of state_dict
+        src_param_groups = state_dict["param_groups"]
+        new_param_groups = new_state_dict["param_groups"]
+
+        src_group_map = {}
+        for group in src_param_groups:
+            param_keys = list(group["params"])
+            src_group_map[_gen_param_group_key(param_keys)] = group
+        new_group_map = {}
+        for new_group in new_param_groups:
+            param_keys = []
+            for param_key in new_group["params"]:
+                param_keys.append(self.ordered_param_keys[param_key])  # type: ignore[call-overload]
+            new_group_map[_gen_param_group_key(param_keys)] = new_group
+        for group_key, new_group in new_group_map.items():
+            # When not all parameters are used in training or receive gradients, not all of them
+            # will appear in a param_group, so we skip such group keys here.
+            if group_key not in src_group_map:
+                continue
+            src_group = src_group_map[group_key]
+            if len(src_group) != len(new_group):
+                raise ValueError(
+                    f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
+                )
+            for k in src_group:
+                if k not in new_group:
+                    raise ValueError(
+                        f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
+                    )
+                if k != "params":
+                    new_group[k] = deepcopy(src_group[k])
+
+        self._optimizer.load_state_dict(new_state_dict)
+
+    def add_param_group(self, param_group: Mapping[str, Any]) -> None:
+        """
+        Add a param group to the :class:`_NamedOptimizer`'s `param_groups`.
+
+        Warning: This API is still in development and subject to change.
+        """
+        assert isinstance(param_group, dict), "param group must be a dict"
+
+        params = param_group["params"]
+        if isinstance(params, torch.Tensor):
+            param_group["params"] = [params]
+        else:
+            param_group["params"] = list(params)
+
+        param_to_key = {param: key for key, param in self.named_parameters.items()}  # type: ignore[misc, has-type]
+        for param in param_group["params"]:
+            if param not in param_to_key:
+                raise ValueError("some parameters are not in the module")
+            self.ordered_param_keys.append(param_to_key[param])
+
+        self._optimizer.add_param_group(param_group)
+        # Update param_groups from optimizer.
+        self.param_groups = self._optimizer.param_groups
+
+    def init_state(self) -> None:
+        """
+        Run a dummy optimizer step to initialize the optimizer state, since most optimizers use lazy state initialization.
+
+        This allows doing in-place loading of optimizer state from a checkpoint.
+        """
+        for param in self.named_parameters.values():
+            if param.requires_grad:
+                t = torch.zeros_like(param)
+                param.grad = torch.autograd.Variable(t)
+        # Calling ``step`` will load the initial state for optimizer states.
+        self.step(closure=None)
+
+    def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
+        # TODO(chienchin): This API should be FSDP agnostic and should support
+        # general user hooks.
+        if isinstance(self.module, FSDP):
+            return FSDP.optim_state_dict_to_load(
+                self.module, self._optimizer, state_dict, is_named_optimizer=True
+            )
+        return state_dict
+
+    def _post_state_dict(self, state_dict) -> Dict[str, Any]:
+        # TODO(chienchin): This API should be FSDP agnostic and should support
+        # general user hooks.
+        if isinstance(self.module, FSDP):
+            FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
+        return state_dict
+
+
+def _gen_param_group_key(param_keys: List[str]) -> str:
+    """Concatenate all param keys as a unique indentifier for one param group."""
+    return "/".join(sorted(param_keys))
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/optimizer.py b/MLPY/Lib/site-packages/torch/distributed/optim/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..de7dc1607fd8ed52078ba3407d468767464f8575
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/optimizer.py
@@ -0,0 +1,254 @@
+import logging
+
+from collections import defaultdict
+from threading import Lock
+from typing import List, Optional
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.jit as jit
+import torch.nn as nn
+from torch import Tensor
+from torch.distributed.rpc import RRef
+from .utils import functional_optim_map
+
+__all__ = ["DistributedOptimizer"]
+
+logger = logging.getLogger(__name__)
+
+
+# XXX: we define a _ScriptModuleOptimizer here to explicitly
+# compile the FunctionalOptimizer class into TorchScript
+# This is because ScriptClass instance still lives in
+# python unless you explicitly compile it as an attribute
+# in ScriptModule or pass it to a ScriptFunction
+# _ScriptLocalOptimizerInterface serves as a common
+# interface type for Optimizer ScriptModules.
+#
+# TODO (wanchaol): remove this once we added TorchScript
+# class reference semantics
+@jit.interface
+class _ScriptLocalOptimizerInterface:
+    def step(self, autograd_ctx_id: int) -> None:
+        pass
+
+
+class _ScriptLocalOptimizer(nn.Module):
+    # TorchScript does not support multithread concurrent compiling.
+    # request_callback might invoke concurrent compiling, so we
+    # serialize the compiling with a lock
+    compile_lock = Lock()
+
+    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
+        super().__init__()
+        self._local_params = [rref.local_value() for rref in local_params_rref]
+        self.optim = optim_cls(self._local_params, *args, **kwargs)
+
+    @jit.export
+    def step(self, autograd_ctx_id: int):
+        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
+        # apply functional optimizer step with a list of gradients
+        grads: List[Optional[Tensor]] = [
+            all_local_grads[p] if p in all_local_grads else None
+            for p in self._local_params
+        ]
+
+        self.optim.step(grads)
+
+
+# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
+# we have converted all to functional optimizer in distributed.optim
+class _LocalOptimizer:
+    # Ideally we would only need to share a lock for instances of
+    # _LocalOptimizer that deal with the same parameters. We are
+    # making a simplifying assumption here that if there is more
+    # than one instance of _LocalOptimizer per worker, they will
+    # be optimizing the same parameters (e.g. each data parallel
+    # trainer will create its own instance of _LocalOptimizer but
+    # they will all optimize the same parameters on each worker)
+    global_lock = Lock()
+
+    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
+        self._local_params = [rref.local_value() for rref in local_params_rref]
+        self.optim = optim_cls(self._local_params, *args, **kwargs)
+
+    def step(self, autograd_ctx_id):
+        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
+
+        with _LocalOptimizer.global_lock:
+            for param, grad in all_local_grads.items():
+                param.grad = grad
+            self.optim.step()
+
+
+def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
+    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))
+
+
+def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
+    local_optim = local_optim_rref.local_value()
+    local_optim.step(autograd_ctx_id)
+
+
+# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer
+def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
+    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
+
+    with _ScriptLocalOptimizer.compile_lock:
+        script_optim = jit.script(optim)
+        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)
+
+
+@jit.script
+def _script_local_optimizer_step(
+    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
+) -> None:
+    local_optim = local_optim_rref.local_value()
+    local_optim.step(autograd_ctx_id)
+
+
+def _wait_for_all(rpc_futs):
+    # TODO: improve error propagation
+    exception = None
+    results = []
+    for fut in rpc_futs:
+        try:
+            results.append(fut.wait())
+        except Exception as e:
+            results.append(e)
+            exception = e
+    if exception is not None:
+        raise exception
+    return results
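+
+
+# A small sketch of the error handling in ``_wait_for_all`` above: every future
+# is waited on, results are collected in order, and the last exception seen (if
+# any) is re-raised after all futures complete. The helper name
+# ``_demo_wait_for_all`` is illustrative only, using local ``torch.futures``
+# futures rather than RPC futures.
+def _demo_wait_for_all() -> None:
+    fut = torch.futures.Future()
+    fut.set_result(42)
+    assert _wait_for_all([fut]) == [42]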
+
+
+class DistributedOptimizer:
+    """
+    DistributedOptimizer takes remote references to parameters scattered
+    across workers and applies the given optimizer locally for each parameter.
+
+    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
+    to retrieve the gradients for specific parameters.
+
+    Concurrent calls to
+    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
+    either from the same or different clients, will
+    be serialized on each worker -- as each worker's optimizer can only work
+    on one set of gradients at a time. However, there is no guarantee that
+    the full forward-backward-optimizer sequence will execute for one client
+    at a time. This means that the gradients being applied may not correspond
+    to the latest forward pass executed on a given worker. Also, there is no
+    guaranteed ordering across workers.
+
+    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
+    by default, so that optimizer updates are not blocked by the Python Global
+    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
+    Model Parallel). This feature is currently enabled for most optimizers. You
+    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
+    for your own custom optimizers.
+
+    Args:
+        optimizer_class (optim.Optimizer): the class of optimizer to
+            instantiate on each worker.
+        params_rref (list[RRef]): list of RRefs to local or remote parameters
+            to optimize.
+        args: arguments to pass to the optimizer constructor on each worker.
+        kwargs: arguments to pass to the optimizer constructor on each worker.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> import torch.distributed.autograd as dist_autograd
+        >>> import torch.distributed.rpc as rpc
+        >>> from torch import optim
+        >>> from torch.distributed.optim import DistributedOptimizer
+        >>>
+        >>> with dist_autograd.context() as context_id:
+        >>>   # Forward pass.
+        >>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
+        >>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
+        >>>   loss = rref1.to_here() + rref2.to_here()
+        >>>
+        >>>   # Backward pass.
+        >>>   dist_autograd.backward(context_id, [loss.sum()])
+        >>>
+        >>>   # Optimizer.
+        >>>   dist_optim = DistributedOptimizer(
+        >>>      optim.SGD,
+        >>>      [rref1, rref2],
+        >>>      lr=0.05,
+        >>>   )
+        >>>   dist_optim.step(context_id)
+
+    __ https://github.com/pytorch/tutorials/pull/1465
+    """
+
+    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
+        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
+        per_worker_params_rref = defaultdict(list)
+        for param in params_rref:
+            per_worker_params_rref[param.owner()].append(param)
+
+        if optimizer_class in functional_optim_map and jit._state._enabled:
+            optim_ctor = functional_optim_map.get(optimizer_class)
+        else:
+            optim_ctor = optimizer_class
+        self.is_functional_optim = optim_ctor != optimizer_class
+
+        if self.is_functional_optim:
+            optimizer_new_func = _new_script_local_optimizer
+        else:
+            logger.warning(
+                "Creating the optimizer %s without TorchScript support, "
+                "this might result in slow computation time in multithreading environment"
+                "(i.e. Distributed Model Parallel training on CPU) due to the Python's "
+                "Global Interpreter Lock (GIL). Please file an issue if you need this "
+                "optimizer in TorchScript. ",
+                optimizer_class
+            )
+            optimizer_new_func = _new_local_optimizer
+
+        remote_optim_futs = []
+        for worker, param_rrefs in per_worker_params_rref.items():
+            remote_optim_rref_fut = rpc.rpc_async(
+                worker,
+                optimizer_new_func,
+                args=(optim_ctor, param_rrefs) + args,
+                kwargs=kwargs,
+            )
+            remote_optim_futs.append(remote_optim_rref_fut)
+
+        self.remote_optimizers = _wait_for_all(remote_optim_futs)
+
+    def step(self, context_id):
+        """
+        Performs a single optimization step.
+
+        This will call :meth:`torch.optim.Optimizer.step` on each worker
+        containing parameters to be optimized, and will block until all workers
+        return. The provided ``context_id`` will be used to retrieve the
+        corresponding :class:`~torch.distributed.autograd.context` that
+        contains the gradients that should be applied to the parameters.
+
+        Args:
+            context_id: the autograd context id for which we should run the
+                optimizer step.
+        """
+        dist_autograd._is_valid_context(context_id)
+
+        optimizer_step_func = (
+            _script_local_optimizer_step
+            if self.is_functional_optim
+            else _local_optimizer_step
+        )
+
+        rpc_futs = []
+        for optimizer in self.remote_optimizers:
+            rpc_futs.append(
+                rpc.rpc_async(
+                    optimizer.owner(),
+                    optimizer_step_func,
+                    args=(optimizer, context_id),
+                )
+            )
+        _wait_for_all(rpc_futs)
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py b/MLPY/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..85514aade2d84347e84bebb5a44e1eed3db06218
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py
@@ -0,0 +1,109 @@
+import warnings
+
+import torch
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+
+
+class PostLocalSGDOptimizer(torch.optim.Optimizer):
+    r"""
+    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.
+    This optimizer runs the local optimizer at every step.
+    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.
+
+    Args:
+        optim: The local optimizer.
+        averager: A model averager instance to run post-localSGD algorithm.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("undefined variables")
+        >>> import torch
+        >>> import torch.distributed as dist
+        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
+        >>> import torch.nn as nn
+        >>> from torch.distributed.optim import PostLocalSGDOptimizer
+        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
+        >>>   PostLocalSGDState,
+        >>>   post_localSGD_hook,
+        >>> )
+        >>>
+        >>> model = nn.parallel.DistributedDataParallel(
+        >>>    module, device_ids=[rank], output_device=rank
+        >>> )
+        >>>
+        >>> # Register a post-localSGD communication hook.
+        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
+        >>> model.register_comm_hook(state, post_localSGD_hook)
+        >>>
+        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
+        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
+        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
+        >>> opt = PostLocalSGDOptimizer(
+        >>>     optim=local_optim,
+        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
+        >>> )
+        >>>
+        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
+        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
+        >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
+        >>> for step in range(0, 200):
+        >>>    opt.zero_grad()
+        >>>    loss = loss_fn(output, labels)
+        >>>    loss.backward()
+        >>>    opt.step()
+    """
+
+    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
+        self.optim = optim
+        self.param_groups = self.optim.param_groups
+        self.averager = averager
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        r"""
+        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
+        but adds an extra entry recording the model averager's step to the
+        checkpoint so that reloading does not trigger an unnecessary warm-up.
+        """
+        optim_state_dict = self.optim.state_dict()
+        optim_state_dict["step"] = self.averager.step
+        return optim_state_dict
+
+    def load_state_dict(self, state_dict):
+        r"""
+        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
+        but also restores model averager's step value to the one
+        saved in the provided ``state_dict``.
+
+        If there is no ``"step"`` entry in ``state_dict``,
+        it will raise a warning and initialize the model averager's step to 0.
+        """
+        self.optim.load_state_dict(state_dict)
+        if "step" in state_dict:
+            self.averager.step = state_dict["step"]
+        else:
+            warnings.warn(
+                "Loaded state dict does not contain a step counter for an averager. "
+                "Setting step counter to 0."
+            )
+            self.averager.step = 0
+
+    def step(self):
+        r"""
+        Performs a single optimization step (parameter update).
+        """
+        self.optim.step()
+        self.averager.average_parameters(params=self.param_groups)
+
+    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
+        self.optim.zero_grad(set_to_none=set_to_none)
+
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
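+
+
+# A minimal sketch of the checkpoint behaviour above: ``state_dict`` records the
+# averager's step so that reloading does not restart the warm-up. The stub
+# ``_DemoAverager`` and the helper ``_demo_step_roundtrip`` are illustrative
+# stand-ins, not part of this module's API.
+class _DemoAverager:
+    def __init__(self):
+        self.step = 0
+
+    def average_parameters(self, params):
+        # A real averager would average the parameters across ranks here.
+        pass
+
+
+def _demo_step_roundtrip() -> None:
+    param = torch.nn.Parameter(torch.zeros(1))
+    opt = PostLocalSGDOptimizer(
+        optim=torch.optim.SGD([param], lr=0.1), averager=_DemoAverager()
+    )
+    opt.averager.step = 7
+    saved = opt.state_dict()  # contains an extra {"step": 7} entry
+    opt.averager.step = 0
+    opt.load_state_dict(saved)  # restores the averager's step counter
+    assert opt.averager.step == 7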
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/utils.py b/MLPY/Lib/site-packages/torch/distributed/optim/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9507f1a29d2aadc6cc8672048daae9256e73c8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/utils.py
@@ -0,0 +1,63 @@
+from typing import Type
+
+from torch import optim
+from .functional_adadelta import _FunctionalAdadelta
+from .functional_adagrad import _FunctionalAdagrad
+from .functional_adam import _FunctionalAdam
+from .functional_adamax import _FunctionalAdamax
+from .functional_adamw import _FunctionalAdamW
+from .functional_rmsprop import _FunctionalRMSprop
+from .functional_rprop import _FunctionalRprop
+from .functional_sgd import _FunctionalSGD
+
+# Dict mapping a user-passed optimizer_class to a functional
+# optimizer class, if one is already defined inside the
+# distributed.optim package. This hides the functional
+# optimizer from the user while still providing the same API.
+functional_optim_map = {
+    optim.Adagrad: _FunctionalAdagrad,
+    optim.Adam: _FunctionalAdam,
+    optim.AdamW: _FunctionalAdamW,
+    optim.SGD: _FunctionalSGD,
+    optim.Adadelta: _FunctionalAdadelta,
+    optim.RMSprop: _FunctionalRMSprop,
+    optim.Rprop: _FunctionalRprop,
+    optim.Adamax: _FunctionalAdamax,
+}
+
+
+def register_functional_optim(key, optim):
+    """
+    Interface to insert a new functional optimizer into ``functional_optim_map``.
+    ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key
+    need not be of type :class:`torch.optim.Optimizer` (e.g. for custom optimizers).
+    Example::
+        >>> # import the new functional optimizer
+        >>> # xdoctest: +SKIP
+        >>> from xyz import fn_optimizer
+        >>> from torch.distributed.optim.utils import register_functional_optim
+        >>> fn_optim_key = "XYZ_optim"
+        >>> register_functional_optim(fn_optim_key, fn_optimizer)
+    """
+    if key not in functional_optim_map:
+        functional_optim_map[key] = optim
+
+
+def as_functional_optim(optim_cls: Type, *args, **kwargs):
+    try:
+        functional_cls = functional_optim_map[optim_cls]
+    except KeyError as e:
+        raise ValueError(
+            f"Optimizer {optim_cls} does not have a functional " f"counterpart!"
+        ) from e
+
+    return _create_functional_optim(functional_cls, *args, **kwargs)
+
+
+def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
+    return functional_optim_cls(
+        [],
+        *args,
+        **kwargs,
+        _allow_empty_param_list=True,
+    )
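+
+
+# A small sketch of how the mapping above is used: ``as_functional_optim`` looks
+# up the functional counterpart of an eager optimizer class and constructs it
+# with an empty parameter list (parameters and gradients are supplied later).
+# The helper name ``_demo_functional_lookup`` is illustrative only.
+def _demo_functional_lookup() -> None:
+    fn_optim = as_functional_optim(optim.SGD, lr=0.01)
+    assert isinstance(fn_optim, functional_optim_map[optim.SGD])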
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py b/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f3be9ce6a99bb56045c746b709b96fa38200e77
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -0,0 +1,1651 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""Zero Redundancy Optimizer."""
+import collections
+import copy
+import enum
+import inspect
+import io
+import logging
+from itertools import chain
+from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed.algorithms.join import Join, Joinable, JoinHook
+from torch.distributed.optim.utils import functional_optim_map
+from torch.optim import Optimizer
+
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["ZeroRedundancyOptimizer"]
+
+
+# Credits:  classy_vision/generic/distributed_util.py
+def _recursive_copy_to_device(
+    value: Any,
+    non_blocking: bool,
+    device: torch.device,
+) -> Any:
+    r"""
+    Recursively searches lists, tuples, dicts and copies tensors to device if possible.
+
+    Non-tensor values are passed as-is in the result.
+
+    .. note::  These are all copies, so if two entries reference the same
+        object, then after this call there will be two different objects
+        referenced on the device.
+    """
+    if isinstance(value, torch.Tensor):
+        return value.to(device, non_blocking=non_blocking)
+
+    if isinstance(value, (list, tuple)):
+        values = [
+            _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
+            for val in value
+        ]
+        return values if isinstance(value, list) else tuple(values)
+
+    if isinstance(value, collections.abc.Mapping):
+        return {
+            key: _recursive_copy_to_device(
+                val, non_blocking=non_blocking, device=device
+            )
+            for key, val in value.items()
+        }
+
+    return value
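+
+
+# A minimal sketch of the recursive copy above: tensors nested inside lists,
+# tuples, and dicts are moved to the target device, while non-tensor values
+# pass through unchanged. The helper name ``_demo_recursive_copy`` is
+# illustrative only and not part of this module's API.
+def _demo_recursive_copy() -> None:
+    state = {"step": 3, "buffers": [torch.ones(2), (torch.zeros(1), "tag")]}
+    copied = _recursive_copy_to_device(
+        state, non_blocking=False, device=torch.device("cpu")
+    )
+    assert copied["step"] == 3 and copied["buffers"][1][1] == "tag"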
+
+
+def _is_trainable(param: torch.Tensor) -> bool:
+    r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient."""
+    return param.requires_grad
+
+
+def _broadcast_object(
+    obj: Any,
+    src_rank: int,
+    group: object = dist.group.WORLD,
+    device: torch.device = torch.device("cpu"),
+) -> Any:
+    r"""
+    Broadcasts an object to the given group.
+
+    It sends the object if called from the source rank and receives
+    the object otherwise.
+
+    Arguments:
+        obj: object to broadcast; only used if called on the source rank.
+        src_rank (int): source rank.
+        group (``ProcessGroup``, optional): group used for the broadcast
+            (default: ``dist.group.WORLD``).
+        device (``torch.device``, optional): device to send from or receive
+            to (default: ``torch.device("cpu")``).
+
+    Returns:
+        The broadcasted object.
+    """
+    if dist.get_rank() == src_rank:
+        # Send the object
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        length_tensor = torch.LongTensor([len(data)]).to(device)
+        data_send_tensor = torch.ByteTensor(data).to(device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
+    else:
+        # Receive the object
+        length_tensor = torch.LongTensor([0]).to(device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        data_recv_tensor = torch.empty(
+            [int(length_tensor.item())], dtype=torch.uint8, device=device
+        )
+        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
+        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
+        obj = torch.load(buffer, map_location=device)
+    return obj
+
+
+class _ZeROJoinHook(JoinHook):
+    def __init__(self, zero):
+        assert isinstance(zero, ZeroRedundancyOptimizer), (
+            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
+            "instance as the state"
+        )
+        self.zero = zero
+        super().__init__()
+
+    def main_hook(self):
+        """
+        Perform an optimizer step.
+
+        This step updates the joined process's shard of
+        the parameters and broadcasts those parameters.
+        """
+        self.zero.step()
+
+
+class _DDPBucketAssignment:
+    r"""
+    Represent a :class:`DistributedDataParallel` bucket assignment.
+
+    This is a (possibly non-strict) subset of the parameters corresponding to
+    a DDP bucket, assigned to a single rank to update.
+
+    Attributes:
+        bucket_index (int): index of the bucket determined by the DDP gradient
+            bucket all-reduce order.
+        parameters (List[torch.Tensor]): model parameters in the bucket
+            assigned to this rank.
+        offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
+            giving the index of the first element in the passed-in
+            ``parameters``; this equivalently indexes into the
+            :class:`GradBucket` 's :meth:`gradients`.
+        device (torch.device): device on which the parameters are stored.
+        tensor (torch.Tensor): flattened tensor giving the data of the
+            parameter subset assigned to the rank.
+    """
+
+    def __init__(
+        self,
+        bucket_index: int,
+        parameters: List[torch.Tensor],
+        offset: int,
+    ):
+        self.bucket_index = bucket_index
+        self.parameters = parameters
+        self.offset = offset
+        if len(self.parameters) == 0:
+            raise ValueError("Empty bucket assignment")
+        # DDP guarantees all parameters in the bucket have the same device
+        self.device: torch.device = self.parameters[0].device
+        self.tensor: Optional[torch.Tensor] = None
+
+
+class _OverlapStatus(enum.IntEnum):
+    r"""
+    Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`.
+
+    Attributes:
+        ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
+            is waiting for DDP to finalize its bucketing.
+        ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
+            its bucketing is finalized. The ZeRO instance can now collect the
+            necessary information about the DDP bucketing.
+        ``INITIALIZED``: The ZeRO instance is fully initialized and can now
+            optimize parameters.
+    """
+
+    UNINITIALIZED = 0
+    DDP_HAS_REBUILT_BUCKETS = 1
+    INITIALIZED = 2
+
+
+class _OverlapInfo:
+    r"""
+    Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`.
+
+    Arguments:
+        world_size (int): world size of the process group being used.
+
+    Attributes:
+        shard_buckets (bool): if ``True``, then the assignment of each
+            :class:`DistributedDataParallel` bucket is partitioned across
+            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
+            across possibly multiple ranks) to approximate uniformity following
+            a threshold given by the total parameter size divided by the world
+            size; if ``False``, then each bucket is wholly assigned to a single
+            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
+            this should be set to the value passed into the hook constructor.
+        status (_OverlapStatus): current status; see :class:`_OverlapStatus`
+            for more information.
+        params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
+            gives the model parameters in the ``i``th bucket.
+        params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
+            gives the model parameters assigned to the ``i``th rank, where the
+            parameters are grouped by increasing bucket indices.
+        offsets (Dict[int, int]): maps from bucket index to the offset in
+            ``self.params_per_rank[rank]`` giving the index of the first
+            parameter in that bucket, where ``rank`` is this process's own
+            rank; the keys of this :class:`dict` are the bucket indices
+            assigned to this rank.
+        num_bucket_assignments (int): total number of bucket assignments across
+            all ranks; this is equal to the number of
+            :class:`DistributedDataParallel` gradient buckets if
+            ``shard_buckets=False`` and possibly greater otherwise.
+        total_size (int, optional): total size of all buckets (i.e. sum of
+            ``param.numel()`` for all ``param`` across all buckets) if
+            ``shard_buckets=True``; otherwise, ``None``.
+        broadcast_handles (List[Work]): :class:`list` of async work handles for
+            the parameter broadcasts.
+        bucket_index_to_future (Dict[int, torch.futures.Future]):
+            :class:`dict` mapping bucket index to the corresponding all-reduce
+            future.
+        bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
+            mapping bucket index to the corresponding bucket.
+        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
+            seen on this iteration.
+    """
+
+    def __init__(self, world_size) -> None:
+        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
+        self.shard_buckets: bool = False
+
+        # Modified per bucket reconstruction
+        self.params_per_bucket: List[List[torch.Tensor]] = []
+        self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
+        self.offsets: Dict[int, int] = {}
+        # Group Ranks
+        self.assigned_ranks_per_bucket: List[Set[int]] = []
+        self.num_bucket_assignments: int = 0
+        self.total_size: Optional[int] = None
+
+        # Modified per iteration
+        self.broadcast_handles: List[Any] = []
+        self.bucket_indices_seen: List[int] = []
+        # Used by `hook_with_zero_step()`
+        self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
+        self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
+
+    def wait_for_broadcasts(self) -> None:
+        r"""
+        Wait for all parameter broadcasts.
+
+        This function should be called once all broadcasts have been scheduled,
+        meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
+        in preparation for the next iteration.
+        """
+        assert (
+            len(self.broadcast_handles) == self.num_bucket_assignments
+        ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
+        _ = [x.wait() for x in self.broadcast_handles]
+        self.broadcast_handles.clear()
+
+    def clear_per_iter_info(self) -> None:
+        r"""
+        Clear the data structures that are modified per-iteration.
+
+        This function should be called at the end of an iteration.
+        """
+        self.bucket_indices_seen.clear()
+        self.bucket_index_to_future.clear()
+        self.bucket_index_to_bucket.clear()
+
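+
+# A minimal sketch of the sorted-greedy packing used by
+# ``ZeroRedundancyOptimizer._partition_parameters`` below: parameter sizes are
+# sorted in decreasing order and each one is assigned to the rank with the
+# smallest total so far. The helper name ``_demo_greedy_partition`` is
+# illustrative only and not part of this module's API.
+def _demo_greedy_partition(numels: List[int], world_size: int) -> List[List[int]]:
+    assignments: List[List[int]] = [[] for _ in range(world_size)]
+    totals = [0] * world_size
+    for numel in sorted(numels, reverse=True):
+        # Greedily pick the rank with the least total size assigned so far
+        rank = totals.index(min(totals))
+        assignments[rank].append(numel)
+        totals[rank] += numel
+    return assignments
+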
+
+class ZeroRedundancyOptimizer(Optimizer, Joinable):
+    r"""
+    Wrap an arbitrary :class:`optim.Optimizer` and shard its states across ranks in the group.
+
+    The sharding is done as described by ZeRO_.
+
+    The local optimizer instance in each rank is only
+    responsible for updating approximately ``1 / world_size`` parameters and
+    hence only needs to keep ``1 / world_size`` optimizer states. After
+    parameters are updated locally, each rank will broadcast its parameters to
+    all other peers to keep all model replicas in the same state.
+    ``ZeroRedundancyOptimizer`` can be used in conjunction with
+    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
+    memory consumption.
+
+    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
+    of parameters at each rank. Each parameter belongs to a single rank and is
+    not divided among ranks. The partition is arbitrary and might not match the
+    parameter registration or usage order.
+
+    Arguments:
+        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
+            or :class:`dict` s giving all parameters, which will be sharded
+            across ranks.
+
+    Keyword Args:
+        optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
+            optimizer.
+        process_group (``ProcessGroup``, optional): ``torch.distributed``
+            ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
+            :meth:`torch.distributed.init_process_group`).
+        parameters_as_bucket_view (bool, optional): if ``True``, parameters are
+            packed into buckets to speed up communication, and ``param.data``
+            fields point to bucket views at different offsets; if ``False``,
+            each individual parameter is communicated separately, and each
+            ``params.data`` stays intact (default: ``False``).
+        overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
+            overlapped with :class:`DistributedDataParallel` 's gradient
+            synchronization; this requires (1) either a functional optimizer
+            for the ``optimizer_class`` argument or one with a functional
+            equivalent and (2) registering a DDP communication hook
+            constructed from one of the functions in ``ddp_zero_hook.py``;
+            parameters are packed into buckets matching those in
+            :class:`DistributedDataParallel`, meaning that the
+            ``parameters_as_bucket_view`` argument is ignored.
+            If ``False``, :meth:`step` runs disjointly after the backward pass
+            (per normal).
+            (default: ``False``)
+        **defaults: any trailing arguments, which are forwarded to the local
+            optimizer.
+
+    Example::
+
+        >>> # xdoctest: +SKIP
+        >>> import torch.nn as nn
+        >>> from torch.distributed.optim import ZeroRedundancyOptimizer
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
+        >>> ddp = DDP(model, device_ids=[rank])
+        >>> opt = ZeroRedundancyOptimizer(
+        >>>     ddp.parameters(),
+        >>>     optimizer_class=torch.optim.Adam,
+        >>>     lr=0.01
+        >>> )
+        >>> ddp(inputs).sum().backward()
+        >>> opt.step()
+
+    .. warning::
+        Currently, ``ZeroRedundancyOptimizer`` requires that all of the
+        passed-in parameters are the same dense type.
+
+    .. warning::
+        If you pass ``overlap_with_ddp=True``, be wary of the following: Given
+        the way that overlapping :class:`DistributedDataParallel` with
+        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
+        two or three training iterations do not perform parameter updates in
+        the optimizer step, depending on if ``static_graph=False`` or
+        ``static_graph=True``, respectively. This is because it needs
+        information about the gradient bucketing strategy used by
+        :class:`DistributedDataParallel`, which is not finalized until the
+        second forward pass if ``static_graph=False`` or until the third
+        forward pass if ``static_graph=True``. To adjust for this, one option
+        is to prepend dummy inputs.
+
+    .. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
+
+    .. _ZeRO: https://arxiv.org/abs/1910.02054
+
+    """
+
+    def __init__(
+        self,
+        params,
+        optimizer_class: Type[Optimizer],
+        process_group: Optional[Any] = None,
+        parameters_as_bucket_view: bool = False,
+        overlap_with_ddp: bool = False,
+        **defaults: Any,
+    ):
+        r"""Init."""
+        # Perform type and assumption checks on the input parameters
+        params = self._verify_and_init_params(params)
+        self._verify_same_dense_param_type()
+
+        # NOTE: The parent constructor uses `add_param_group()` which is
+        # partially overloaded in ZeroRedundancyOptimizer, so we use the
+        # `initialized` flag to dissociate the behaviour of `add_param_group()`
+        # between the parent and child.
+        self.initialized = False
+
+        Optimizer.__init__(self, params, defaults)
+        Joinable.__init__(self)
+        # Now, all parameters are held in both `self._all_params` and
+        # `self.param_groups`
+
+        # Internal data structures (`_cache` indicates lazily evaluated)
+        self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
+        self._param_to_index_cache: Dict[torch.Tensor, int] = {}
+        self._partition_parameters_cache: List[List[Dict]] = []
+        self._index_to_param_cache: List[torch.Tensor] = []
+        self._device_to_params_per_rank_cache: Dict[
+            torch.device, List[List[torch.Tensor]]
+        ] = {}
+        self._bucket_assignments_per_rank_cache: List[
+            Dict[int, _DDPBucketAssignment]
+        ] = []
+        self._is_trainable_mask = self._get_is_trainable_mask()
+
+        # Default device for collective communication and buckets
+        self._default_device = self._all_params[0].device
+
+        self.process_group = (
+            process_group if process_group is not None else dist.group.WORLD
+        )
+        self.world_size: int = dist.get_world_size(self.process_group)
+        self.rank: int = dist.get_rank(self.process_group)
+        self.global_rank: int = dist.distributed_c10d.get_global_rank(
+            self.process_group, self.rank
+        )
+
+        self._overlap_with_ddp: bool = overlap_with_ddp
+        self._optim_defaults = defaults
+        self._optim_constructor = self._get_optimizer_constructor(optimizer_class)
+
+        # If `overlap_with_ddp=True`, local optimizer initialization is delayed
+        # to run time after the necessary information has been collected
+        if not overlap_with_ddp:
+            self._init_local_optimizer()
+        else:
+            self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
+            if parameters_as_bucket_view:
+                logger.warning(
+                    "`parameters_as_bucket_view=True` will be ignored since "
+                    "`overlap_with_ddp=True`; instead, a different bucketing "
+                    "strategy will be used"
+                )
+
+        # `self._buckets` is used if `parameters_as_bucket_view=True`, in
+        # which case parameter data is flattened into contiguous bucket tensors
+        self.parameters_as_bucket_view = parameters_as_bucket_view
+        self._buckets: List[List[torch.Tensor]] = []
+        self._build_param_buckets()
+
+        # Optional consolidated optimizer state, only populated if this rank
+        # is the target in `consolidate_state_dict()`
+        self._all_state_dicts: List[Dict[str, Any]] = []
+
+        self.initialized = True
+
+    def _clear_cache(self) -> None:
+        r"""Clear the cached data structures giving partition information."""
+        self._partition_parameters_cache.clear()
+        self._param_to_rank_cache.clear()
+        self._index_to_param_cache.clear()
+        self._param_to_index_cache.clear()
+        self._device_to_params_per_rank_cache.clear()
+        self._bucket_assignments_per_rank_cache.clear()
+
+    def add_param_group(self, param_group: Dict[str, Any]) -> None:
+        r"""
+        Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
+
+        This can be useful when fine tuning a pre-trained network, as frozen
+        layers can be made trainable and added to the :class:`Optimizer` as
+        training progresses.
+
+        Arguments:
+            param_group (dict): specifies the parameters to be optimized and
+                group-specific optimization options.
+
+        .. warning:: This method handles updating the shards on all partitions
+            but needs to be called on all ranks. Calling this on a subset of
+            the ranks will cause the training to hang because communication
+            primitives are called depending on the managed parameters and
+            expect all the ranks to participate on the same set of parameters.
+        """
+        if self.initialized and self._overlap_with_ddp:
+            raise RuntimeError(
+                "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
+                "supports a single parameter group"
+            )
+
+        super().add_param_group(param_group)
+        # NOTE: The rest of the method assumes that the call to the parent's
+        # `add_param_group()` appends the new parameter group and preserves
+        # the previous parameter-group ordering
+
+        if self.initialized:
+            # Force a re-partitioning of the parameters
+            self._clear_cache()
+            param_groups = self._partition_parameters()[self.rank]
+            # NOTE: All parameters in the old parameter groups should be
+            # assigned to the same ranks so that the local optimizers do not
+            # need to be reinitialized
+
+            # Add the parameters assigned to this rank from the new parameter
+            # group to the local optimizer, if any
+            if len(param_groups) == len(self.optim.param_groups) + 1:
+                self.optim.add_param_group(param_groups[-1])
+
+            # Update the bucketing strategy accordingly
+            if self.parameters_as_bucket_view:
+                self._build_param_buckets()
+
+    def consolidate_state_dict(self, to: int = 0) -> None:
+        r"""
+        Consolidate a list of ``state_dict`` s (one per rank) on the target rank.
+
+        Arguments:
+            to (int): the rank that receives the optimizer states (default: 0).
+
+        Raises:
+            RuntimeError: if ``overlap_with_ddp=True`` and this method is
+                called before this :class:`ZeroRedundancyOptimizer` instance
+                has been fully initialized, which happens once
+                :class:`DistributedDataParallel` gradient buckets have been
+                rebuilt.
+
+        .. warning:: This needs to be called on all ranks.
+        """
+        self._check_overlap_initialized()
+
+        # Sync the exposed `param_groups` attributes to the local optimizer in
+        # case they have been updated
+        self._sync_param_groups(self.param_groups, self.optim.param_groups)
+
+        # Pull the sharded state from all ranks and store them in rank order
+        empty_messenger = torch.tensor(
+            [0], dtype=torch.uint8, device=self._default_device
+        )
+
+        # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
+        # due to compatibility issues with NCCL backend; a possible follow-up
+        # is to move all sharded state management to RPC RRef
+        self._all_state_dicts = []
+        for rank in range(self.world_size):
+            global_rank = dist.distributed_c10d.get_global_rank(
+                self.process_group, rank
+            )
+            if self.rank == to:
+                # Consolidate all local `state_dict`s on this rank, storing on
+                # CPU to save GPU memory
+                if rank == self.rank:
+                    # Directly append own optimizer state
+                    self._all_state_dicts.append(
+                        _recursive_copy_to_device(
+                            self.optim.state_dict(),
+                            non_blocking=True,
+                            device=torch.device("cpu"),
+                        )
+                    )
+                else:
+                    # Receive the optimizer state from the source rank
+                    local_state_dict = _broadcast_object(
+                        empty_messenger,
+                        src_rank=global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
+                    self._all_state_dicts.append(
+                        _recursive_copy_to_device(
+                            local_state_dict,
+                            non_blocking=True,
+                            device=torch.device("cpu"),
+                        )
+                    )
+            else:
+                if rank == self.rank:
+                    # Send the optimizer state to the target rank
+                    _ = _broadcast_object(
+                        self.optim.state_dict(),
+                        src_rank=self.global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
+                elif rank != to:
+                    # Discard the received object; `broadcast()` is used for
+                    # compatibility reasons
+                    _ = _broadcast_object(
+                        empty_messenger,
+                        src_rank=global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
+
+    def _verify_params_per_rank(
+        self,
+        params_per_rank: List[List[torch.Tensor]],
+    ) -> None:
+        r"""
+        Verify ``params_per_rank`` for :meth:`_partition_parameters`.
+
+        The verification is done by checking that ``params_per_rank`` has length equal
+        to the world size and that it does not contain any parameters not passed into the
+        :class:`ZeroRedundancyOptimizer` constructor.
+
+        The parameters in ``params_per_rank`` being a strict subset of those
+        passed into the constructor is valid since some parameters may be
+        frozen.
+
+        Raises:
+            ValueError: if ``params_per_rank`` does not have length equal to
+                the world size or if it contains a parameter that was not
+                passed into the :class:`ZeroRedundancyOptimizer` constructor.
+        """
+        if len(params_per_rank) != self.world_size:
+            raise ValueError(
+                "`params_per_rank` must have length equal to the world size"
+            )
+        all_params_set = set(self._all_params)
+        for params in params_per_rank:
+            for param in params:
+                if param not in all_params_set:
+                    raise ValueError(
+                        "Passing a new parameter in `params_per_rank` that "
+                        "was not passed into the ZeroRedundancyOptimizer "
+                        "constructor"
+                    )
+
+    def _partition_param_group(
+        self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
+    ) -> None:
+        r"""
+        Partition the parameter group ``param_group`` according to ``params_per_rank``.
+
+        The partition will modify the ``self._partition_parameters_cache``. This method should
+        only be used as a subroutine for :meth:`_partition_parameters`.
+
+        Arguments:
+            param_group (dict[str, Any]): a parameter group as normally defined
+                in an optimizer state.
+            params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
+                length world size containing :class:`list` s of parameters to
+                assign to each rank.
+        """
+        for rank, params in enumerate(params_per_rank):
+            rank_param_group = copy.copy(param_group)
+            rank_param_group["params"] = params
+            self._partition_parameters_cache[rank].append(rank_param_group)
+
+    def _partition_parameters(
+        self,
+        params_per_rank: Optional[List[List[torch.Tensor]]] = None,
+    ) -> List[List[Dict]]:
+        r"""
+        Partitions parameters across distributed data parallel ranks.
+
+        Arguments:
+            params_per_rank (list[list[torch.Tensor]], optional): a
+                :class:`list` of length world size containing :class:`list` s
+                of parameters to assign to each rank; this provides a way to
+                specify a partition manually.
+                If ``None``, the parameters are partitioned according to an
+                internal algorithm.
+                (default: ``None``)
+
+        Returns:
+            A :class:`list` where each element of the list contains the
+            ``param_groups`` for a rank (which itself is a :class:`list` of
+            :class:`dict`); element 0 corresponds to rank 0, etc.; each rank
+            stores the ``param_groups`` for all ranks for the collective
+            communication in :meth:`step`.
+
+        Raises:
+            ValueError: see :meth:`_verify_params_per_rank`.
+            RuntimeError: if ``params_per_rank`` is not ``None`` and this
+                :class:`ZeroRedundancyOptimizer` instance is using more than
+                one parameter group.
+        """
+        if params_per_rank is None:
+            # Partition the parameters optimizing for uniformity
+            if len(self._partition_parameters_cache) == 0:
+                self._partition_parameters_cache = [[] for _ in range(self.world_size)]
+                sizes = [0] * self.world_size
+                for param_group in self.param_groups:
+                    param_group_params_per_rank: List[List] = [
+                        [] for _ in range(self.world_size)
+                    ]
+                    # Sort the parameters by size (largest first)
+                    params_sorted = sorted(
+                        param_group["params"], key=lambda t: t.numel(), reverse=True
+                    )
+                    for param in params_sorted:
+                        # Greedily add the parameter to rank with smallest size so far
+                        rank = self._get_min_index(sizes)
+                        param_group_params_per_rank[rank].append(param)
+                        sizes[rank] += param.numel()
+                    # Apply the constructed partition of the parameter group
+                    self._partition_param_group(
+                        param_group, param_group_params_per_rank
+                    )
+
+            return self._partition_parameters_cache
+
+        # Partition the parameters according to `params_per_rank`
+        assert len(self._partition_parameters_cache) == 0, (
+            "Specifying `params_per_rank` should only be done when the "
+            "parameters have not been partitioned yet"
+        )
+        if len(self.param_groups) != 1:
+            raise RuntimeError(
+                "Specifying `params_per_rank` only supports a single parameter group"
+            )
+        self._verify_params_per_rank(params_per_rank)
+        self._partition_parameters_cache = [[] for _ in range(self.world_size)]
+
+        # Apply the passed-in partition of the parameter group
+        param_group = self.param_groups[0]
+        self._partition_param_group(param_group, params_per_rank)
+
+        return self._partition_parameters_cache
+
+    @property
+    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
+        r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
+        if len(self._param_to_rank_cache) == 0:
+            for rank, param_groups in enumerate(self._partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        self._param_to_rank_cache[param] = rank
+        return self._param_to_rank_cache
+
+    @property
+    def _param_to_index(self) -> Dict[torch.Tensor, int]:
+        r"""
+        :class:`dict` mapping parameters to their indices in the global optimizer state.
+
+        NOTE: This assumes that the global optimizer state's indexing (in
+        ``state_dict``) follows a linear ordering over the parameter groups.
+        """
+        if len(self._param_to_index_cache) == 0:
+            self._param_to_index_cache = {
+                p: i
+                for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
+            }
+        return self._param_to_index_cache
+
+    @property
+    def _index_to_param(self) -> List[torch.Tensor]:
+        r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
+        if len(self._index_to_param_cache) == 0:
+            self._index_to_param_cache = list(
+                chain(*(g["params"] for g in self.param_groups))
+            )
+        return self._index_to_param_cache
+
+    def _broadcast_params_from_rank(self, rank: int):
+        r"""
+        Broadcast the shard of parameters from a given rank to all other ranks asynchronously.
+
+        Arguments:
+            rank (int): the source rank.
+
+        Returns:
+            A :class:`list` of async work handles for the ``broadcast()`` s
+            performed to synchronize the parameters.
+        """
+        assert not self._overlap_with_ddp, (
+            "`_broadcast_params_from_rank()` should not be used if "
+            "`overlap_with_ddp=True`; instead, the broadcasting should "
+            "happen in the DDP communication hook"
+        )
+        handles = []
+        if self.parameters_as_bucket_view:
+            for dev_i_buckets in self._buckets:
+                bucket = dev_i_buckets[rank]
+                global_rank = dist.distributed_c10d.get_global_rank(
+                    self.process_group, rank
+                )
+                handles.append(
+                    dist.broadcast(
+                        tensor=bucket,
+                        src=global_rank,
+                        group=self.process_group,
+                        async_op=True,
+                    )
+                )
+        else:
+            param_groups = self._partition_parameters()[rank]
+            global_rank = dist.distributed_c10d.get_global_rank(
+                self.process_group, rank
+            )
+            for param_group in param_groups:
+                for param in param_group["params"]:
+                    handles.append(
+                        dist.broadcast(
+                            tensor=param.data,
+                            src=global_rank,
+                            group=self.process_group,
+                            async_op=True,
+                        )
+                    )
+        return handles
+
+    def _sync_params(self):
+        r"""
+        Sync all parameter shards across the ranks.
+
+        This rank sends its shard of the parameters to all other ranks and
+        receives a shard from each other rank. This is done using
+        ``broadcast()``. Parameters are sent bucket-by-bucket if
+        ``parameters_as_bucket_view=True`` and parameter-by-parameter
+        otherwise.
+        """
+        handles = []
+        for rank in range(self.world_size):
+            handles.extend(self._broadcast_params_from_rank(rank))
+        _ = [x.wait() for x in handles]
+
+    @property
+    def _device_to_params_per_rank(
+        self,
+    ) -> Dict[torch.device, List[List[torch.Tensor]]]:
+        r"""
+        Return device parameters assigned per rank.
+
+        :class:`dict` mapping each device to a :class:`list` of the per-rank parameter
+        lists filtered to only include the parameters stored on that device.
+        Each per-rank parameter list gives the parameters assigned to that rank
+        to update.
+
+        This is used for constructing the parameter buckets if
+        ``parameters_as_bucket_view=True``.
+
+        Let ``dev_i`` denote the ``i``th device for this rank. Then:
+        ``dev_0`` maps to a list containing:
+            rank 0's assigned parameters stored on ``dev_0``,
+            rank 1's assigned parameters stored on ``dev_0``,
+            ...
+        ``dev_1`` maps to a list containing:
+            rank 0's assigned parameters stored on ``dev_1``,
+            rank 1's assigned parameters stored on ``dev_1``,
+            ...
+        ...
+        """
+        assert self.parameters_as_bucket_view, (
+            "`_device_to_params_per_rank` should only be used if "
+            "`parameters_as_bucket_view=True`"
+        )
+        if len(self._device_to_params_per_rank_cache) == 0:
+            for rank, param_groups in enumerate(self._partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        device = param.device
+                        if device not in self._device_to_params_per_rank_cache:
+                            self._device_to_params_per_rank_cache[device] = [
+                                [] for _ in range(self.world_size)
+                            ]
+                        self._device_to_params_per_rank_cache[device][rank].append(
+                            param
+                        )
+        return self._device_to_params_per_rank_cache
+
+    def _get_min_index(
+        self,
+        values: List[int],
+        disallowed_indices: Optional[Set[int]] = None,
+    ) -> int:
+        r"""
+        Return ``values.index(min(values))``, except it only uses one pass.
+
+        It also excludes any indices in ``disallowed_indices`` if provided.
+
+        Arguments:
+            values: (List[int]): :class:`list` of values.
+            disallowed_indices (Optional[Set[int]]): indices that are
+                disallowed from being the returned min index.
+        """
+        min_index = -1
+        min_value = float("inf")
+        for i, value in enumerate(values):
+            if disallowed_indices and i in disallowed_indices:
+                continue
+            if value < min_value:
+                min_value = value
+                min_index = i
+        assert min_index >= 0, "All indices are disallowed"
+        return min_index
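+
+    # Example (illustrative, not part of the upstream source): the index of the
+    # smallest value, skipping any disallowed indices, found in a single pass:
+    #
+    #     self._get_min_index([5, 2, 7])                          # -> 1
+    #     self._get_min_index([5, 2, 7], disallowed_indices={1})  # -> 0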
+
+    def _assign_bucket_subset_to_rank(
+        self,
+        bucket_index: int,
+        bucket_params: List[torch.Tensor],
+        bucket_offset: int,
+        assigned_rank: int,
+        assigned_ranks_per_bucket: List[Set[int]],
+    ) -> None:
+        r"""
+        Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.
+
+        The model parameters given by ``bucket_params`` represent a (possibly non-strict)
+        subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket.
+
+        Arguments:
+            bucket_index (int): index of the :class:`DistributedDataParallel`
+                gradient bucket.
+            bucket_params (List[torch.Tensor]): subset of the parameters
+                corresponding to the bucket to assign.
+            bucket_offset (int): offset giving the index of the first element
+                in ``bucket_params`` in the bucket's full parameter list.
+            assigned_rank (int): group rank to assign to.
+            assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
+                assigned to each bucket.
+        """
+        overlap_info = self._overlap_info
+        if len(bucket_params) == 0:
+            raise ValueError("Empty bucket assignment")
+        params_per_rank = overlap_info.params_per_rank
+        offsets = overlap_info.offsets
+
+        self._bucket_assignments_per_rank_cache[assigned_rank][
+            bucket_index
+        ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
+        if self.global_rank == assigned_rank:
+            offsets[bucket_index] = len(params_per_rank[assigned_rank])
+        params_per_rank[assigned_rank].extend(bucket_params)
+        assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
+        self._overlap_info.num_bucket_assignments += 1
+
+    @property
+    def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
+        r"""
+        Return DDP bucket parameters assigned per rank.
+
+        :class:`list` of length world size consisting of :class:`dict` s
+        mapping bucket indices to :class:`_DDPBucketAssignment` s for each
+        rank.
+        """
+        assert self._overlap_with_ddp, (
+            "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
+        )
+        if len(self._bucket_assignments_per_rank_cache) > 0:
+            return self._bucket_assignments_per_rank_cache
+
+        overlap_info = self._overlap_info
+        assert overlap_info.status == _OverlapStatus.INITIALIZED
+
+        self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
+        params_per_bucket = overlap_info.params_per_bucket
+
+        if overlap_info.shard_buckets:
+            # Define the assignment threshold to approximate uniformity
+            assert overlap_info.total_size is not None, "`total_size` was not computed"
+            threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
+            size_per_rank = [0 for _ in range(self.world_size)]
+
+        num_buckets = len(params_per_bucket)
+        overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
+        assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
+        if not overlap_info.shard_buckets:
+            # Assign each DDP bucket entirely to a single rank
+            for bucket_index, bucket_params in enumerate(params_per_bucket):
+                assert len(bucket_params) > 0, "Empty bucket"
+                assigned_rank = self._get_assigned_rank(bucket_index)
+                self._assign_bucket_subset_to_rank(
+                    bucket_index,
+                    bucket_params,
+                    0,
+                    assigned_rank,
+                    assigned_ranks_per_bucket,
+                )
+        else:
+            # Assign each DDP bucket to possibly multiple ranks
+            # Specifically, sort the DDP buckets by increasing size, and for
+            # each bucket, iteratively assign the maximal unassigned subset
+            # with size less than `threshold` to the rank with the least total
+            # size so far -- each such assignment is represented by a
+            # `_DDPBucketAssignment` instance and only contains parameters from
+            # a single DDP bucket
+            params_per_bucket_enum = sorted(
+                enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
+            )
+            for bucket_index, bucket_params in params_per_bucket_enum:
+                assert len(bucket_params) > 0, "Empty bucket"
+                bucket_offset = 0
+                assignment_size = 0
+                for param_index, param in enumerate(bucket_params):
+                    param_numel = param.numel()
+                    if (
+                        assignment_size + param_numel >= threshold
+                        and param_index > bucket_offset
+                    ):
+                        assigned_rank = self._get_min_index(
+                            size_per_rank, assigned_ranks_per_bucket[bucket_index]
+                        )
+                        # Include up to but not including the parameter that
+                        # exceeded the threshold
+                        self._assign_bucket_subset_to_rank(
+                            bucket_index,
+                            bucket_params[bucket_offset:param_index],
+                            bucket_offset,
+                            assigned_rank,
+                            assigned_ranks_per_bucket,
+                        )
+                        size_per_rank[assigned_rank] += assignment_size
+                        bucket_offset = param_index
+                        assignment_size = 0
+                    assignment_size += param_numel
+                # Assign the remainder of the bucket so that no assignment
+                # spans across two buckets
+                assigned_rank = self._get_min_index(
+                    size_per_rank, assigned_ranks_per_bucket[bucket_index]
+                )
+                self._assign_bucket_subset_to_rank(
+                    bucket_index,
+                    bucket_params[bucket_offset:],
+                    bucket_offset,
+                    assigned_rank,
+                    assigned_ranks_per_bucket,
+                )
+                size_per_rank[assigned_rank] += assignment_size
+
+        return self._bucket_assignments_per_rank_cache
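+
+    # Illustrative sketch of the sharded assignment above (not part of the
+    # upstream source): with world size 2, `total_size=20` (so the threshold is
+    # 10), and a single DDP bucket whose parameters have sizes [6, 8], reaching
+    # 6 + 8 >= 10 assigns the size-6 parameter to the least-loaded rank, and
+    # the remaining size-8 parameter is then assigned to the other rank, so no
+    # assignment spans two buckets.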
+
+    def _local_step(
+        self,
+        gradients: Optional[List[Optional[torch.Tensor]]] = None,
+        closure: Optional[Callable[[], float]] = None,
+        **kwargs: Any,
+    ) -> Optional[float]:
+        r"""
+        Perform a single optimizer step without syncing parameters across ranks.
+
+        Arguments:
+            gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
+                of length equal to the number of parameters assigned to this
+                rank containing gradient tensors or ``None`` as its elements;
+                a ``None`` in the :class:`list` indicates that the
+                corresponding parameter should not be updated.
+                If the argument itself is ``None``, then all parameters are
+                updated, and the gradients are assumed to be already populated.
+                (default: ``None``)
+            closure (Callable): a closure that re-evaluates the model and
+                returns the loss; optional for most optimizers and should be
+                ``None`` if ``gradients`` is not ``None``; (default: ``None``)
+        Returns:
+            Optional loss depending on the underlying local optimizer.
+
+        .. warning::
+            The argument ``gradients`` should only be specified (i.e. not
+            ``None``) if ``overlap_with_ddp=True``, in which case
+            :class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
+        """
+        Join.notify_join_context(self)
+        # Check if the model trainability has changed
+        is_trainable_mask = self._get_is_trainable_mask()
+        if is_trainable_mask != self._is_trainable_mask:
+            if self._overlap_with_ddp:
+                raise RuntimeError(
+                    "ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
+                    "does not support changing parameter trainability at run "
+                    "time"
+                )
+            logger.warning(
+                "ZeroRedundancyOptimizer detected that the trainable "
+                "parameters changed; rebuilding the parameter buckets if "
+                "enabled"
+            )
+            self._build_param_buckets()
+            self._is_trainable_mask = is_trainable_mask
+
+        # Sync the exposed `param_groups` attributes to the local optimizer in
+        # case they have been updated
+        self._sync_param_groups(self.param_groups, self.optim.param_groups)
+
+        # Run the optimizer step on this shard only
+        if gradients is None:
+            loss = (
+                self.optim.step(**kwargs)
+                if closure is None
+                else self.optim.step(closure=closure, **kwargs)
+            )
+        else:
+            assert self._overlap_with_ddp, (
+                "Specifying `gradients` should not "
+                "be used when `overlap_with_ddp=False`"
+            )
+            assert closure is None, (
+                "`closure` is not supported when using a local functional optimizer"
+            )
+            loss = self.optim.step(gradients=gradients)
+
+        # Sync any updated attributes in the local optimizer to the exposed
+        # `param_groups`
+        self._sync_param_groups(self.optim.param_groups, self.param_groups)
+
+        return loss
+
+    def step(
+        self,
+        closure: Optional[Callable[[], float]] = None,
+        **kwargs: Any,
+    ) -> Optional[float]:
+        r"""
+        Perform a single optimizer step and sync parameters across all ranks.
+
+        Arguments:
+            closure (Callable): a closure that re-evaluates the model and
+                returns the loss; optional for most optimizers.
+        Returns:
+            Optional loss depending on the underlying local optimizer.
+
+        .. note:: Any extra parameters are passed to the base optimizer as-is.
+        """
+        if self._overlap_with_ddp:
+            logger.warning(
+                "`step()` should not be included in the training loop when "
+                "`overlap_with_ddp=True`"
+            )
+            return None
+
+        # Perform the local optimizer step
+        loss = self._local_step(closure=closure, **kwargs)
+
+        # Sync all of the updated parameter shards across the ranks
+        self._sync_params()
+
+        return loss
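+
+    # Usage sketch for `step()` (illustrative, not part of the upstream
+    # source). Assumes the default process group has already been initialized
+    # via `torch.distributed.init_process_group` and that `model` is an
+    # `nn.Module` replicated on every rank:
+    #
+    #     opt = ZeroRedundancyOptimizer(
+    #         model.parameters(),
+    #         optimizer_class=torch.optim.SGD,
+    #         lr=0.01,
+    #     )
+    #     loss = model(inputs).sum()
+    #     loss.backward()
+    #     opt.step()  # local shard update, then cross-rank parameter sync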
+
+    def join_hook(self, **kwargs):
+        r"""
+        Return the ZeRO join hook.
+
+        It enables training on uneven inputs by
+        shadowing the collective communications in the optimizer step.
+
+        Gradients must be properly set before this hook is called.
+
+        Arguments:
+            kwargs (dict): a :class:`dict` containing any keyword arguments
+                to modify the behavior of the join hook at run time; all
+                :class:`Joinable` instances sharing the same join context
+                manager are forwarded the same value for ``kwargs``.
+
+        This hook does not support any keyword arguments; i.e. ``kwargs`` is
+        unused.
+        """
+        return _ZeROJoinHook(self)
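+
+    # Usage sketch (illustrative, not part of the upstream source): the join
+    # hook is normally consumed through the `Join` context manager together
+    # with a DDP-wrapped model and this optimizer, e.g.
+    #
+    #     from torch.distributed.algorithms.join import Join
+    #
+    #     with Join([ddp_model, zero_optimizer]):
+    #         for inputs in uneven_data_loader:
+    #             ddp_model(inputs).sum().backward()
+    #             zero_optimizer.step()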
+
+    @property
+    def join_device(self) -> torch.device:
+        r"""Return default device."""
+        return self._default_device
+
+    @property
+    def join_process_group(self) -> Any:
+        r"""Return process group."""
+        return self.process_group
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        r"""
+        Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.
+
+        Arguments:
+            state_dict (dict): optimizer state; should be an object returned
+                from a call to :meth:`state_dict`.
+
+        Raises:
+            RuntimeError: if ``overlap_with_ddp=True`` and this method is
+                called before this :class:`ZeroRedundancyOptimizer` instance
+                has been fully initialized, which happens once
+                :class:`DistributedDataParallel` gradient buckets have been
+                rebuilt.
+        """
+        self._check_overlap_initialized()
+
+        for index, value in state_dict["state"].items():
+            param = self._index_to_param[index]
+            if self._param_to_rank[param] != self.rank:
+                # Clear any state irrelevant to this rank
+                state_dict["state"][index] = None
+            else:
+                # Load the parameter state to the local optimizer
+                self.optim.state[param] = _recursive_copy_to_device(
+                    value, non_blocking=True, device=param.device
+                )
+                # Force zero-dimensional tensors (like Adam "step") on CPU
+                for state_name, state_value in self.optim.state[param].items():
+                    if torch.is_tensor(state_value) and state_value.dim() == 0:
+                        self.optim.state[param][state_name] = state_value.cpu()
+
+        super().load_state_dict(state_dict)
+
+        # Sync the input state with the exposed and local optimizer states
+        self._sync_param_groups(state_dict["param_groups"], self.param_groups)
+        self._sync_param_groups(self.param_groups, self.optim.param_groups)
+
+    def state_dict(self) -> Dict[str, Any]:
+        r"""
+        Return the last global optimizer state known to this rank.
+
+        .. warning::
+            If the state has not been consolidated to this rank, this raises a
+            runtime error, and even if it has, the state may not be up-to-date,
+            depending on when :meth:`consolidate_state_dict` was last called.
+
+        Raises:
+            RuntimeError: if ``overlap_with_ddp=True`` and this method is
+                called before this :class:`ZeroRedundancyOptimizer` instance
+                has been fully initialized, which happens once
+                :class:`DistributedDataParallel` gradient buckets have been
+                rebuilt; or if this method is called without a preceding call
+                to :meth:`consolidate_state_dict`.
+        """
+        self._check_overlap_initialized()
+
+        if len(self._all_state_dicts) == 0:
+            raise RuntimeError(
+                "Optimizer state has not been consolidated on this rank. "
+                f"Please call `consolidate_state_dict(to={self.rank})` on "
+                "all ranks beforehand if you meant to save the global state."
+            )
+
+        # Get the possibly-stale global optimizer state that uses global
+        # parameter indexing
+        state_dict = super().state_dict()
+
+        # Update the global optimizer state with local state information,
+        # factoring in the translation from local to global indexing
+        for rank, local_state_dict in enumerate(self._all_state_dicts):
+            local_param_groups = local_state_dict["param_groups"]
+            global_param_groups = self._partition_parameters()[rank]
+            assert len(local_param_groups) == len(
+                global_param_groups
+            ), "Mismatch between number of local and global parameter groups"
+
+            for local_param_group, global_param_group in zip(
+                local_param_groups, global_param_groups
+            ):
+                # `local_param_group` stores local indices, while
+                # `global_param_group` stores the tensors directly
+                local_param_indices = local_param_group["params"]
+                global_params = global_param_group["params"]
+
+                assert len(local_param_indices) == len(
+                    global_params
+                ), "Mismatch between number of local and global parameters in parameter group"
+                for local_param_index, global_param in zip(
+                    local_param_indices, global_params
+                ):
+                    # Update the global parameter state, if any
+                    if local_param_index in local_state_dict["state"]:
+                        global_param_index = self._param_to_index[global_param]
+                        state_dict["state"][global_param_index] = local_state_dict[
+                            "state"
+                        ][local_param_index]
+
+        # Sort the parameters in the state
+        state_dict["state"] = dict(sorted(state_dict["state"].items()))
+        return state_dict
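+
+    # Checkpointing sketch (illustrative, not part of the upstream source):
+    # the global state must first be consolidated onto the saving rank, e.g.
+    #
+    #     opt.consolidate_state_dict(to=0)   # collective: call on all ranks
+    #     if torch.distributed.get_rank() == 0:
+    #         torch.save(opt.state_dict(), "zero_ckpt.pt")
+    #
+    #     # later, on every rank:
+    #     opt.load_state_dict(torch.load("zero_ckpt.pt"))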
+
+    @staticmethod
+    def _sync_param_groups(
+        src_param_groups: List[Dict[Any, Any]],
+        dst_param_groups: List[Dict[Any, Any]],
+    ) -> None:
+        r"""
+        Sync the attributes from the source parameter groups to the destination parameter groups.
+
+        Example attributes include learning rate or scheduler attributes. The
+        two parameter group lists should have the same length (i.e., the same
+        number of parameter groups).
+
+        Arguments:
+            src_param_groups (list[dict]): parameter groups giving the
+                attribute settings to copy.
+            dst_param_groups (list[dict]): parameter groups giving the
+                attribute settings to set.
+        """
+        assert len(src_param_groups) == len(
+            dst_param_groups
+        ), "Mismatch between number of source and destination parameter groups"
+        for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
+            # Sync all attributes except the parameters
+            for attr in filter(lambda x: x != "params", src_param_group.keys()):
+                dst_param_group[attr] = src_param_group[attr]
+
+    def _build_param_buckets(self) -> None:
+        r"""
+        Build parameter buckets if ``parameters_as_bucket_view=True``.
+
+        For each device that stores this rank's parameters, there is a
+        bucket (represented as a tensor) containing all of the parameters on
+        that device that are assigned to a given rank in the parameter update
+        partition.
+
+        This method is called in the constructor and any time parameter
+        trainability is changed.
+
+        .. warning::
+            The current implementation assumes that all of the parameters in a
+            bucket are of the same dense type when allocating the bucket's
+            tensor.
+
+        .. warning::
+            If the model parameters are stored across more than one device,
+            then the storage partitioning must be the same across all
+            processes in order for parameter synchronization to work.
+        """
+        if not self.parameters_as_bucket_view or self._overlap_with_ddp:
+            return
+
+        # `self._buckets[i][j]` are the parameters stored on device i and
+        # assigned to rank j
+        num_devices = len(self._device_to_params_per_rank)
+        self._buckets = [[] for _ in range(num_devices)]  # type: ignore[assignment]
+
+        for dev_i, (device, params_per_rank) in enumerate(
+            self._device_to_params_per_rank.items()
+        ):
+            for params in params_per_rank:
+                bucket_size = 0
+                dtype = None
+                trainable_params = []
+                for param in params:
+                    if not _is_trainable(param):
+                        # Clone in case the parameter was previously part of
+                        # a bucket, to prevent its data from being destroyed
+                        param.data = param.data.detach().clone()
+                    else:
+                        bucket_size += param.numel()
+                        trainable_params.append(param)
+                    dtype = param.dtype  # assumes all same dtype
+
+                if bucket_size == 0:
+                    # Create a dummy bucket if there are no parameters
+                    bucket = torch.zeros(1, device=device)
+                else:
+                    # Construct the bucket (assuming all dense and same dtype)
+                    bucket = torch.empty(bucket_size, dtype=dtype, device=device)
+                    offset = 0
+                    for param in trainable_params:
+                        offset_next = offset + param.numel()
+                        bucket[offset:offset_next].copy_(param.data.flatten())
+                        param.data = bucket[offset:offset_next].view_as(param.data)
+                        offset = offset_next
+                self._buckets[dev_i].append(bucket)  # type: ignore[arg-type]
+
+    def _build_ddp_param_buckets(self) -> None:
+        r"""
+        Build the DDP buckets with parameters assigned to this rank.
+
+        For each DDP bucket with parameters assigned to this rank, this method flattens the
+        data of those parameters into a single tensor and saves the tensor to
+        the ``tensor`` attribute in the corresponding
+        :class:`_DDPBucketAssignment` instance stored in
+        ``self._bucket_assignments_per_rank``.
+
+        :class:`DistributedDataParallel` guarantees that the parameters
+        corresponding to a gradient bucket have the same device and the same
+        dtype.
+        """
+        for bucket_assignments in self._bucket_assignments_per_rank:
+            for bucket_assignment in bucket_assignments.values():
+                params = bucket_assignment.parameters
+                bucket_size = 0
+                dtype = None
+                for param in params:
+                    assert _is_trainable(param), (
+                        "Model parameter "
+                        "corresponding to a gradient in a DDP bucket should "
+                        "require a gradient"
+                    )
+                    bucket_size += param.numel()
+                    dtype = param.dtype  # assumes all same dtype
+                assert bucket_size > 0, "Empty bucket"
+
+                # Construct the bucket tensor (assuming all dense and same dtype)
+                tensor = torch.empty(
+                    bucket_size, dtype=dtype, device=bucket_assignment.device
+                )
+                offset = 0
+                for param in params:
+                    offset_next = offset + param.numel()
+                    tensor[offset:offset_next].copy_(param.data.flatten())
+                    param.data = tensor[offset:offset_next].view_as(param.data)
+                    offset = offset_next
+                bucket_assignment.tensor = tensor
+
+    def _verify_and_init_params(
+        self,
+        params: Any,
+    ) -> Union[List[torch.Tensor], List[dict]]:
+        r"""
+        Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.
+
+        The initialization first makes sure that the provided ``params`` is valid.
+
+        Arguments:
+            params (Any): Candidate parameter list or parameter groups to verify.
+
+        Raises:
+            TypeError: ``params`` has an invalid type.
+            ValueError: ``params`` is empty.
+
+        Returns:
+            The persistent form of ``params`` to be passed into the parent
+            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
+            :class:`list` to ensure that it can be iterated over again.
+        """
+        if isinstance(params, torch.Tensor):
+            raise TypeError(
+                "`params` argument should be an iterable of "
+                f"Tensors, but got {torch.typename(params)}"
+            )
+        try:
+            all_params = list(params)
+        except TypeError as e:
+            raise TypeError(
+                "`params` argument should be an iterable of Tensors"
+                f" or dicts, but got {torch.typename(params)}"
+            ) from e
+        if len(all_params) == 0:
+            raise ValueError("ZeroRedundancyOptimizer got an empty parameter list")
+        all_tensors = True
+        all_dicts = True
+        for param in all_params:
+            all_tensors &= isinstance(param, torch.Tensor)
+            all_dicts &= isinstance(param, dict)
+        if not all_tensors and not all_dicts:
+            raise TypeError(
+                "`params` argument should be an iterable of Tensors or dicts"
+            )
+        # Ensure that `self._all_params` contains a list of all parameters
+        if all_tensors:
+            self._all_params = all_params
+        elif all_dicts:
+            self._all_params = []
+            # `all_params` contains parameter groups (not parameters)
+            for param_group in all_params:
+                if "params" not in param_group:
+                    raise ValueError(
+                        "Each parameter group passed-in via `params` must "
+                        "have a 'params' key mapping to the parameters in "
+                        "the group"
+                    )
+                self._all_params.extend(param_group["params"])
+        return all_params
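+
+    # Illustrative sketch (not part of the upstream source) of the two accepted
+    # forms of `params`, assuming a hypothetical model with `head` and `body`
+    # submodules:
+    #
+    #     # 1. an iterable of tensors
+    #     ZeroRedundancyOptimizer(model.parameters(),
+    #                             optimizer_class=torch.optim.SGD, lr=0.01)
+    #
+    #     # 2. an iterable of parameter-group dicts, each with a "params" key
+    #     ZeroRedundancyOptimizer(
+    #         [{"params": model.head.parameters(), "lr": 0.1},
+    #          {"params": model.body.parameters(), "lr": 0.01}],
+    #         optimizer_class=torch.optim.SGD,
+    #     )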
+
+    def _verify_same_dense_param_type(self) -> None:
+        r"""
+        Verify that all parameters are of the same dense type.
+
+        The method assumes that ``self._all_params`` has been initialized
+        and is non-empty.
+
+        Raises:
+            ValueError: ``params`` contains sparse parameters or parameters
+            of varying dense types.
+
+        NOTE: This method can be removed once support for sparse parameters
+        and varying parameter types is added.
+        """
+        typename = torch.typename(self._all_params[0])
+        if self._all_params[0].is_sparse:
+            raise ValueError(
+                "ZeroRedundancyOptimizer only supports using "
+                "the same dense type for all parameters but got "
+                f"{typename}"
+            )
+        for param in self._all_params[1:]:
+            other_typename = torch.typename(param)
+            if other_typename != typename:
+                raise ValueError(
+                    "ZeroRedundancyOptimizer only supports "
+                    "using the same dense type for all "
+                    f"parameters but got both {typename} and "
+                    f"{other_typename}"
+                )
+
+    def _get_is_trainable_mask(self) -> List[bool]:
+        r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
+        return list(map(_is_trainable, self._all_params))
+
+    def _init_local_optimizer(self) -> None:
+        r"""
+        Initialize this rank's local optimizer, responsible for its subset of the parameters.
+
+        The local optimizer is saved in ``self.optim``.
+        """
+        assert (
+            self._optim_constructor is not None
+        ), "The local optimizer class has not been set"
+
+        param_groups = self._partition_parameters()[self.rank]
+        # `overlap_with_ddp=True` requires a local functional optimizer
+        if self._overlap_with_ddp:
+            # Functional optimizers only support a single parameter group and
+            # require passing in the parameters as a list
+            assert len(param_groups) == 1, (
+                "Initializing the local "
+                "functional optimizer with more than one parameter group"
+            )
+            params = param_groups[0]["params"]
+            # Try to pass `_allow_empty_param_list=True` to avoid erroring
+            if (
+                "_allow_empty_param_list"
+                in inspect.signature(self._optim_constructor).parameters
+            ):
+                self.optim: Any = self._optim_constructor(
+                    params, **self._optim_defaults, _allow_empty_param_list=True
+                )
+            else:
+                logger.warning(
+                    "%s does not support the argument "
+                    "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
+                    "error due to an empty parameter list",
+                    self._optim_constructor
+                )
+                self.optim: Any = self._optim_constructor(params, **self._optim_defaults)  # type: ignore[no-redef]
+
+            # Log information about the DDP and ZeRO bucketing
+            if dist.get_debug_level() != dist.DebugLevel.OFF:
+                local_numel = sum(p.numel() for p in params)
+                num_assigned_buckets = len(
+                    self._bucket_assignments_per_rank[self.global_rank]
+                )
+                logger.info(
+                    "rank %s with %s parameters "
+                    "across %s buckets",
+                    self.global_rank, local_numel, num_assigned_buckets
+                )
+                if self.global_rank == 0:
+                    logger.info(
+                        "%s DDP "
+                        "buckets and "
+                        "%s bucket "
+                        "assignments",
+                        len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments
+                    )
+        else:
+            # NOTE: Passing `param_groups` into the local optimizer constructor
+            # bypasses the empty parameter list check
+            self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults)  # type: ignore[no-redef]
+
+        # TODO: Manually add `self.param_groups` if using a functional
+        # optimizer; remove this if/when the functional optimizers support
+        # multiple parameter groups
+        if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
+            assert hasattr(self.optim, "param_group"), (
+                "The functional optimizer should set at least one of the "
+                "attributes `param_group` or `param_groups`"
+            )
+            self.optim.param_groups = [self.optim.param_group]  # type: ignore[attr-defined]
+
+        self._sync_param_groups(self.optim.param_groups, self.param_groups)
+
+    def _init_zero_for_overlap(self) -> None:
+        r"""Perform a delayed initialization of the local optimizer and the supporting data structures."""
+        assert self._overlap_with_ddp, (
+            "`_init_zero_for_overlap()` should only be called when "
+            "`overlap_with_ddp=True`"
+        )
+        self._overlap_info.status = _OverlapStatus.INITIALIZED
+        self._clear_cache()
+        self._partition_parameters(self._overlap_info.params_per_rank)
+        self._build_ddp_param_buckets()
+        self._init_local_optimizer()
+
+    def _get_assigned_rank(self, bucket_index: int) -> int:
+        r"""
+        Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket.
+
+        Arguments:
+            bucket_index (int): index of the :class:`DistributedDataParallel`
+                bucket for which to get the assigned rank.
+        """
+        assert not self._overlap_info.shard_buckets, (
+            "The bucket assignment requires global bucket information and "
+            "will be computed later; there should be no need to use this "
+            "method"
+        )
+        return bucket_index % self.world_size
+
+    def _check_overlap_initialized(self):
+        r"""
+        Check the delayed initialization depending on the value of ``overlap_with_ddp``.
+
+        The delayed initialization has occurred (see
+        :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and
+        raises a ``RuntimeError`` if not. This should preface methods that
+        should not be run before that delayed initialization.
+
+        Raises:
+            RuntimeError: if ``overlap_with_ddp=True`` and
+                :meth:`_init_zero_for_overlap` has not been called.
+        """
+        if (
+            self._overlap_with_ddp
+            and self._overlap_info.status != _OverlapStatus.INITIALIZED
+        ):
+            raise RuntimeError(
+                "This method should not be called until this "
+                "ZeroRedundancyOptimizer instance has been fully "
+                "initialized"
+            )
+
+    def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
+        r"""
+        Return the optimizer constructor, validating and transforming it depending on ``overlap_with_ddp``.
+
+        Returns:
+            - ``optimizer_class`` if ``overlap_with_ddp=False`` and
+                ``optimizer_class`` is not a functional optimizer.
+            - ``optimizer_class`` if ``overlap_with_ddp=True`` and
+                ``optimizer_class`` is already a functional optimizer.
+            - The functional equivalent of ``optimizer_class`` if
+                ``overlap_with_ddp=True`` and ``optimizer_class`` is not
+                already a functional optimizer (assuming the equivalent
+                exists).
+
+        Raises:
+            ValueError:
+
+                - if ``overlap_with_ddp=True`` but ``optimizer_class`` is
+                    neither a functional optimizer nor translatable to a
+                    functional optimizer.
+                - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
+                    functional optimizer.
+        """
+        functional_optims = functional_optim_map.values()
+        if not self._overlap_with_ddp:
+            if optimizer_class in functional_optims:
+                # Using a functional optimizer is only supported when
+                # `overlap_with_ddp=True`
+                raise ValueError(
+                    f"Passing in a functional optimizer {optimizer_class} "
+                    "when `overlap_with_ddp=False`"
+                )
+            else:
+                return optimizer_class
+        else:
+            if optimizer_class in functional_optims:
+                # Already a functional optimizer
+                return optimizer_class
+            elif optimizer_class in functional_optim_map:
+                # Translate the passed-in optimizer class to its functional
+                # equivalent if `overlap_with_ddp=True`
+                optim_constructor = functional_optim_map[optimizer_class]
+                logger.info(
+                    "Using the functional optimizer %s "
+                    "instead of %s since "
+                    "`overlap_with_ddp=True`",
+                    optim_constructor, optimizer_class
+                )
+                return optim_constructor
+            else:
+                raise ValueError(
+                    "Using `ddp_with_overlap=True` requires using a "
+                    "functional optimizer, but there is no supported functional "
+                    f"optimizer equivalent for {optimizer_class}"
+                )
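+
+    # Illustrative sketch (not part of the upstream source) of the translation
+    # above, assuming `functional_optim_map` maps `torch.optim.SGD` to a
+    # functional counterpart:
+    #
+    #     zero = ZeroRedundancyOptimizer(
+    #         model.parameters(),
+    #         optimizer_class=torch.optim.SGD,
+    #         overlap_with_ddp=True,
+    #         lr=0.01,
+    #     )
+    #     # internally, `_get_optimizer_constructor(torch.optim.SGD)` returns
+    #     # the functional class, which is then used to build the local optimizer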
diff --git a/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi b/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..22c434bb4fddb21d4780aa80c51aacff5bde08ee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi
@@ -0,0 +1,83 @@
+import enum
+from typing import Any, Callable, Dict, List, Optional, overload, Set, Type
+
+import torch
+from torch.distributed.algorithms.join import Joinable, JoinHook
+from torch.optim import Optimizer
+
+class _ZeROJoinHook(JoinHook):
+    zero: Any = ...
+    def __init__(self, zero: Any) -> None: ...
+    def main_hook(self) -> None: ...
+
+class _DDPBucketAssignment:
+    bucket_index: int
+    parameters: List[torch.Tensor]
+    offset: int
+    device: torch.device
+    tensor: Optional[torch.Tensor]
+
+class _OverlapStatus(enum.IntEnum):
+    UNINITIALIZED: int = ...
+    DDP_HAS_REBUILT_BUCKETS: int = ...
+    INITIALIZED: int = ...
+
+class _OverlapInfo:
+    status: Any = ...
+    params_per_bucket: Any = ...
+    params_per_rank: Any = ...
+    offsets: Any = ...
+    broadcast_handles: Any = ...
+    bucket_index_to_future: Any = ...
+    bucket_index_to_bucket: Any = ...
+    bucket_indices_seen: Any = ...
+    assigned_ranks_per_bucket: List[Set[int]] = ...
+    total_size: int = ...
+    shard_buckets: bool = ...
+    def __init__(self) -> None: ...
+    def wait_for_broadcasts(self) -> None: ...
+    def clear_per_iter_info(self) -> None: ...
+
+class ZeroRedundancyOptimizer(Optimizer, Joinable):
+    functional_optim_map: Any = ...
+    initialized: bool = ...
+    process_group: Any = ...
+    world_size: int = ...
+    rank: int = ...
+    global_rank: int = ...
+    parameters_as_bucket_view: bool = ...
+    optim: Any = ...
+    _device_to_device_index: Dict[torch.device, int] = ...
+    _overlap_with_ddp: bool = ...
+    _overlap_info: _OverlapInfo = ...
+    _buckets: List[List[torch.Tensor]] = ...
+    _bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
+    def __init__(
+        self,
+        params: Any,
+        optimizer_class: Type[Optimizer],
+        process_group: Optional[Any] = ...,
+        parameters_as_bucket_view: bool = ...,
+        overlap_with_ddp: bool = ...,
+        **defaults: Any,
+    ) -> None: ...
+    def add_param_group(self, param_group: Dict[str, Any]) -> None: ...
+    def consolidate_state_dict(self, to: int = ...) -> None: ...
+    @overload
+    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
+    @overload
+    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
+    def state_dict(self) -> Dict[str, Any]: ...
+    def _local_step(
+        self,
+        gradients: Optional[List[Optional[torch.Tensor]]] = None,
+        closure: Optional[Callable[[], float]] = None,
+        **kwargs: Any,
+    ) -> Optional[float]: ...
+    def _get_assigned_rank(self, bucket_index: int) -> int: ...
+    def _init_zero_for_overlap(self) -> None: ...
+    def join_hook(self, **kwargs): ...
+    @property
+    def join_device(self) -> torch.device: ...
+    @property
+    def join_process_group(self) -> Any: ...
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/__init__.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f62634d1e9258f7d0d2f9e27504356ccea03b4d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/__init__.py
@@ -0,0 +1,7 @@
+import warnings
+warnings.warn(
+    "torch.distributed.pipeline is deprecated. For up-to-date pipeline parallel "
+    "implementation, please refer to the PiPPy library under the PyTorch "
+    "organization (Pipeline Parallelism for PyTorch): "
+    "https://github.com/pytorch/PiPPy"
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d890ae88b2f29f8e4a1155fa2c1852efb41ed0bd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__init__.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e05daf288010e98488a99cccd375a5d5cea3784b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""A Pipe implementation in PyTorch."""
+from .checkpoint import is_checkpointing, is_recomputing
+from .pipe import Pipe, WithDevice
+from .microbatch import NoChunk
+
+__all__ = ["Pipe", "is_checkpointing", "is_recomputing"]
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d46aad5319d519cd5c12c285eb4608fb94c4ce4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/batchnorm.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/batchnorm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d961a8d85c8cfc57840501c7920daf11be3fd1a0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/batchnorm.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/checkpoint.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/checkpoint.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a008c682a0204e827e0ae2563da5c085055a0312
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/checkpoint.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46e7b4faf68f815f298fbec54fba9399f2afcaf6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/dependency.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/dependency.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc0a5b7e1335d541182e0ef3445c78818b761df3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/dependency.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/microbatch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/microbatch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3c0b2e0ec2b6733b885f716a9c44564b4a481ef
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/microbatch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/phony.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/phony.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb588cff3a0721eafabcb551517ed11f0efeb10b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/phony.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipe.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipe.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fb761651b69025870faa893ea0e6c7e1d069247
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipe.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipeline.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipeline.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e28957f87e633df05438159435226c7d64edc329
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/pipeline.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/stream.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/stream.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0164d127975a53cd89be3989d42bf33032bd46e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/stream.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64745d401c75b5a02fb8d9c40fe0bb924c9b0772
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28e57244590f30929dc1afd8bdb9e1a7726d8cad
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__init__.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3204f5e93573ba4cd23d24aa1a14dc9bca3cf1a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__init__.py
@@ -0,0 +1,164 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""A helper to roughly balance a sequential module.
+
+Usage::
+
+    import torch
+    from torch.distributed.pipeline.sync import Pipe
+    from torch.distributed.pipeline.sync.balance import balance_by_time
+
+    sample = torch.empty(128, 3, 224, 224)
+    balance = balance_by_time(torch.cuda.device_count(), model, sample)
+
+    pipe = Pipe(model, balance, chunks=8)
+
+"""
+from typing import Any, List, Union, Sequence
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from . import blockpartition
+from .profile import profile_sizes, profile_times
+
+__all__ = ["balance_by_time", "balance_by_size"]
+
+
+Device = Union[torch.device, int, str]
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+
+def balance_cost(cost: List[int], partitions: int) -> List[int]:
+    partitioned = blockpartition.solve(cost, partitions)
+    return [len(p) for p in partitioned]
+
+
+def balance_by_time(
+    partitions: int,
+    module: nn.Sequential,
+    sample: Union[List[Any], Tensor],
+    *,
+    timeout: float = 1.0,
+    device: Device = torch.device("cuda"),
+) -> List[int]:
+    """Naive automatic balancing by elapsed time per layer.
+    ::
+
+        sample = torch.empty(128, 3, 224, 224)
+        balance = balance_by_time(torch.cuda.device_count(), model, sample)
+        pipe = Pipe(model, balance, chunks=8)
+
+    Args:
+        partitions (int):
+            intended number of partitions
+        module (torch.nn.Sequential):
+            sequential module to be partitioned
+        sample (torch.Tensor):
+            example input with arbitrary batch size
+
+    Keyword Args:
+        timeout (float):
+            profiling iterates again if the timeout (in seconds) is not exceeded
+            (default: ``1.0``)
+        device ('cpu' or 'cuda' device):
+            CPU or CUDA device where each layer is profiled (default: the
+            current CUDA device)
+
+    Returns:
+        A list of the number of layers in each partition. Use it for the `balance`
+        parameter of :class:`~torchpipe.Pipe`.
+
+    .. note::
+        `module` and `sample` must be placed on the same device.
+
+    """
+    times = profile_times(module, sample, timeout, torch.device(device))
+    return balance_cost(times, partitions)
+
+
+def balance_by_size(
+    partitions: int,
+    module: nn.Sequential,
+    input: Union[List[Any], Tensor],
+    *,
+    chunks: int = 1,
+    param_scale: float = 2.0,
+    device: Device = torch.device("cuda"),
+) -> List[int]:
+    """Naive automatic balancing by CUDA memory usage per layer.
+
+    During training, required memory for parameters depends on which optimizer
+    is used. Optimizers may use buffers for each parameter to track
+    optimization statistics internally, such as momentum buffer in SGD.
+
+    To get a more reliable size-based balance, you should specify `param_scale`
+    with regard to your optimizer. The default `param_scale` is 2 instead of 1
+    due to gradient accumulation, which is necessary for every optimizer.
+
+    Follow this guide to choose correct `param_scale` for typical optimizers:
+
+    =========  =============  =========================================
+    Optimizer  `param_scale`  Internal State
+    =========  =============  =========================================
+    SGD        2--3           (momentum_buffer)
+    Adam       4--5           exp_avg, exp_avg_sq, (max_exp_avg_sq)
+    Adadelta   4              square_avg, acc_delta
+    Adagrad    3              sum
+    RMSprop    3--5           square_avg, (momentum_buffer), (grad_avg)
+    =========  =============  =========================================
+
+    Here's a simple example with the Adam optimizer::
+
+        balance = balance_by_size(
+            torch.cuda.device_count(),
+            model,
+
+            # Same size with mini-batch to train
+            torch.empty(1024, 3, 224, 224),
+
+            # Number of micro-batches to train with Pipe
+            chunks=8,
+
+            # 4 for Adam
+            param_scale=4.0,
+        )
+
+        pipe = Pipe(model, balance, chunks=8)
+        adam = Adam(pipe.parameters())
+
+    Args:
+        partitions (int):
+            intended number of partitions
+        module (torch.nn.Sequential):
+            sequential module to be partitioned
+        input (torch.Tensor):
+            example mini-batch with the same size to train
+
+    Keyword Args:
+        chunks (int):
+            number of micro-batches that will be used to train (default: ``1``)
+        param_scale (float):
+            how many copies of the parameters are allocated for training; this
+            depends on the optimizer. See the guide above. (default: ``2.0``)
+        device ('cuda' device):
+            CUDA device where each layer is profiled (default: the current CUDA
+            device)
+
+    Returns:
+        A list of the number of layers in each partition. Use it for the `balance`
+        parameter of :class:`~torchpipe.Pipe`.
+
+    .. note::
+        `module` and `input` must be placed on the same CUDA device.
+
+    """
+    sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device))
+    return balance_cost(sizes, partitions)
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c07fe92e67695ff44d1ab05cc2a8ba4b0f1852a5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/blockpartition.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/blockpartition.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e958f676dddbe4ffb89272a9bd7ed62a593bf46
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/blockpartition.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/profile.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/profile.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77fc197e3dc20a0ad8b33e6e11227d082e116bde
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/__pycache__/profile.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/blockpartition.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/blockpartition.py
new file mode 100644
index 0000000000000000000000000000000000000000..c95d42771d5f6696433de9cf85db66738fe71d8a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/blockpartition.py
@@ -0,0 +1,95 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Implements "Block Partitions of Sequences" by Imre Bárány et al.
+
+Paper: https://arxiv.org/pdf/1308.2452.pdf
+
+"""
+from typing import Iterator, List, Tuple
+
+__all__ = ["solve"]
+
+
+def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]:
+    """Splits a sequence into several partitions to minimize variance for each
+    partition.
+
+    The result might not be optimal. However, it can be computed in only O(kn³) time,
+    where k is the number of partitions and n is the length of the sequence.
+
+    """
+    if partitions < 1:
+        raise ValueError(f"partitions must be a positive integer ({partitions} < 1)")
+
+    n = len(sequence)
+    if n < partitions:
+        raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})")
+
+    # Normalize the sequence in [0, 1].
+    minimum = min(sequence)
+    maximum = max(sequence) - minimum
+
+    normal_sequence: List[float]
+    if maximum == 0:
+        normal_sequence = [0 for _ in sequence]
+    else:
+        normal_sequence = [(x - minimum) / maximum for x in sequence]
+
+    splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n]
+
+    def block_size(i: int) -> float:
+        start = splits[i - 1] if i > 0 else 0
+        stop = splits[i]
+        return sum(normal_sequence[start:stop])
+
+    def leaderboard() -> Iterator[Tuple[float, int]]:
+        return ((block_size(i), i) for i in range(partitions))
+
+    while True:
+        """
+        (1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P.
+        """
+        # max_size: M(P)
+        max_size, p = max(leaderboard())
+
+        while True:
+            """
+            (2) If M(P) ≤ m(P) + 1, then stop.
+            """
+            # min_size: m(P)
+            min_size, q = min(leaderboard())
+
+            if max_size <= min_size + 1:
+                return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)]
+
+            """
+            (3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is
+            closest to p (ties broken arbitrarily). Thus Bq is a minimal block
+            of P. Let Bh be the block next to Bq between Bp and Bq. (Note that
+            Bh is a non-empty block: if it were, then m(P) = 0 and we should
+            have chosen Bh instead of Bq.)
+            """
+            if p < q:
+                """
+                So either p < q and then h = q−1 and we define P ∗ by moving
+                the last element from Bh = Bq−1 to Bq,
+                """
+                h = q - 1
+                splits[h] -= 1
+            else:
+                """
+                or q < p, and then h = q + 1 and P ∗ is obtained by moving the
+                first element of Bh = Bq+1 to Bq.
+                """
+                h = q + 1
+                splits[q] += 1
+
+            """
+            Set P = P ∗ . If p = h, then go to (1), else go to (2).
+            """
+            if p == h:
+                break
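+
+
+# Example (illustrative, not part of the upstream source):
+#
+#     solve([1, 2, 3, 4, 5, 6], partitions=2)  # -> [[1, 2, 3, 4], [5, 6]]
+#
+# i.e. the sequence is split into contiguous blocks whose sums are roughly
+# balanced after normalization.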
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/profile.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/profile.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0a6b7d00b8cd168c657e0d69d202d023842f9a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/profile.py
@@ -0,0 +1,116 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Per-layer profilers."""
+import copy
+import time
+from typing import Any, Generator, List, Union, Sequence
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from ..microbatch import Batch
+
+__all__: List[str] = []
+
+
+Device = Union[torch.device, int, str]
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+
+def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
+    """Copies layers for ease to profile. It doesn't modify the given
+    module.
+    """
+    for layer in module:
+        layer_copy = copy.deepcopy(layer)
+        layer_copy.to(device)
+        layer_copy.train()
+        yield layer_copy
+
+
+def detach(batch: Batch) -> None:
+    """Detaches from autograd graph."""
+    for i, x in enumerate(batch):
+        batch[i] = x.detach().requires_grad_(x.requires_grad)
+
+
+def profile_times(module: nn.Sequential, sample: Union[List[Any], Tensor], timeout: float, device: torch.device,) -> List[int]:
+    """Profiles elapsed times per layer."""
+    if any(p.grad is not None for p in module.parameters()):
+        raise ValueError("some parameter already has gradient")
+
+    _batch = Batch(sample)
+    for i, x in enumerate(_batch):
+        _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
+
+    time_bufs: List[List[float]] = [[] for _ in module]
+    begun_at = time.time()
+
+    while time.time() - begun_at < timeout:
+        batch = _batch
+
+        for i, layer in enumerate(layerwise_sandbox(module, device)):
+            detach(batch)
+
+            if device.type == "cuda":
+                torch.cuda.synchronize(device)
+            tick = time.time()
+
+            # Forward
+            batch = batch.call(layer)
+
+            # Backward
+            backward_tensors = tuple(y for y in batch if y.requires_grad)
+            if backward_tensors:
+                torch.autograd.backward(backward_tensors, backward_tensors)
+
+            if device.type == "cuda":
+                torch.cuda.synchronize(device)
+            tock = time.time()
+
+            time_bufs[i].append(tock - tick)
+
+    us = 1_000_000
+    return [sum(int(t * us) for t in buf) for buf in time_bufs]
+
+
+def profile_sizes(
+    module: nn.Sequential, input: Union[List[Any], Tensor], chunks: int, param_scale: float, device: torch.device,
+) -> List[int]:
+    """Profiles CUDA memory usage per layer."""
+    if device.type != "cuda":
+        raise ValueError("size profiler supports only CUDA device")
+
+    batch = Batch(input)
+    sizes: List[int] = []
+
+    latent_scale = batch[0].size(0) / chunks
+    for i, x in enumerate(batch):
+        batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)
+
+    for layer in layerwise_sandbox(module, device):
+        detach(batch)
+
+        # Detect memory usage at forward.
+        torch._C._cuda_clearCublasWorkspaces()
+        memory_before = torch.cuda.memory_allocated(device)
+        batch = batch.call(layer)
+        torch._C._cuda_clearCublasWorkspaces()
+        memory_after = torch.cuda.memory_allocated(device)
+        latent_size = memory_after - memory_before
+
+        # Analyze size of parameters.
+        param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters())
+
+        # Combine the sizes of parameters and activations using normalized scales.
+        size = latent_size * latent_scale + param_size * param_scale
+        sizes.append(int(size))
+
+    return sizes
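
A hypothetical sketch of `profile_times` above, run on CPU (a CUDA device works the same way, with stream synchronization added). The import path assumes this vendored layout and the reported numbers are illustrative:

    import torch
    import torch.nn as nn
    from torch.distributed.pipeline.sync._balance.profile import profile_times

    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    sample = torch.rand(32, 64)

    # Repeat forward/backward passes on deep copies of each layer for ~0.5s
    # and report accumulated microseconds per layer.
    times_us = profile_times(model, sample, timeout=0.5, device=torch.device("cpu"))
    print(times_us)  # e.g. [420, 35, 310] -- larger means slower

Note that `profile_sizes` requires a CUDA device, since it measures allocator statistics via `torch.cuda.memory_allocated`.
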
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/py.typed b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..f4830a6416775aae091858a4ac5158ce69f7de29
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/_balance/py.typed
@@ -0,0 +1,6 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/batchnorm.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..882ebe8266feaec65765d82fcbc9b362da5be40f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/batchnorm.py
@@ -0,0 +1,159 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tracks the running statistics per mini-batch instead of micro-batch."""
+from typing import TypeVar, cast
+
+import torch
+from torch import Tensor, nn
+from torch.nn.functional import batch_norm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from .checkpoint import is_recomputing
+
+__all__ = ["DeferredBatchNorm"]
+
+
+TModule = TypeVar("TModule", bound=nn.Module)
+
+
+class DeferredBatchNorm(_BatchNorm):
+    """A BatchNorm layer tracks multiple micro-batches to update running statistics per mini-batch."""
+
+    sum: Tensor
+    sum_squares: Tensor
+    running_mean: Tensor
+    running_var: Tensor
+    num_batches_tracked: Tensor
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        chunks: int = 1,
+    ) -> None:
+        super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
+
+        self.register_buffer("sum", torch.zeros_like(self.running_mean))
+        self.register_buffer("sum_squares", torch.zeros_like(self.running_var))
+
+        self.counter = 0
+        self.tracked = 0
+        self.chunks = chunks
+
+    def _check_input_dim(self, input: Tensor) -> None:
+        # It's the typical _check_input_dim() implementation in PyTorch.
+        if input.dim() <= 2:
+            raise ValueError("expected at least 3D input (got %dD input)" % input.dim())
+
+    def _track(self, input: Tensor) -> bool:
+        """Tracks statistics of a micro-batch."""
+        # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
+        dim = [0]
+        dim.extend(range(2, input.dim()))
+
+        with torch.no_grad():
+            self.sum += input.sum(dim)
+            self.sum_squares += (input ** 2).sum(dim)
+
+        size = input.size().numel() // input.size(1)
+        self.counter += size
+        self.tracked += 1
+
+        return self.tracked == self.chunks
+
+    def _commit(self) -> None:
+        """Update the running statistics of a mini-batch."""
+        exponential_average_factor = 0.0
+        self.num_batches_tracked += 1
+        if self.momentum is None:  # use cumulative moving average
+            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+        else:  # use exponential moving average
+            exponential_average_factor = self.momentum
+
+        mean = self.sum / self.counter
+        var = self.sum_squares / self.counter - mean ** 2
+
+        # Calculate the exponential moving average here.
+        m = exponential_average_factor
+
+        self.running_mean *= 1 - m
+        self.running_mean += mean * m
+
+        self.running_var *= 1 - m
+        self.running_var += var * m
+
+        self.sum.zero_()
+        self.sum_squares.zero_()
+        self.counter = 0
+        self.tracked = 0
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not self.training:
+            # Don't train parameters in evaluation mode.
+            return batch_norm(
+                input,
+                running_mean=self.running_mean,
+                running_var=self.running_var,
+                weight=self.weight,
+                bias=self.bias,
+                training=False,
+                momentum=0.0,
+                eps=self.eps,
+            )
+
+        if not is_recomputing():
+            # Track a micro-batch in training mode,
+            # but not under a recomputation.
+            tracked_enough = self._track(input)
+
+            # Update the running statistics for a mini-batch
+            # if it has tracked enough micro-batches.
+            if tracked_enough:
+                self._commit()
+
+        # Normalize a micro-batch and train the parameters.
+        return batch_norm(
+            input,
+            running_mean=None,
+            running_var=None,
+            weight=self.weight,
+            bias=self.bias,
+            training=True,
+            momentum=0.0,
+            eps=self.eps,
+        )
+
+    @classmethod
+    def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
+        """Converts a :class:`nn.BatchNorm` or underlying :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::
+
+            from torchvision.models.resnet import resnet101
+            from torchpipe.batchnorm import DeferredBatchNorm
+            model = resnet101()
+            model = DeferredBatchNorm.convert_deferred_batch_norm(model)
+
+        """
+        if isinstance(module, DeferredBatchNorm) and module.chunks is chunks:
+            return cast(TModule, module)
+
+        module_output: nn.Module = module
+
+        if isinstance(module, _BatchNorm) and module.track_running_stats:
+            module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
+            if module.affine:
+                module_output.register_parameter("weight", module.weight)
+                module_output.register_parameter("bias", module.bias)
+            module_output.register_buffer("running_mean", module.running_mean)
+            module_output.register_buffer("running_var", module.running_var)
+            module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
+
+        for name, child in module.named_children():
+            module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))
+
+        return cast(TModule, module_output)
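
A hypothetical sketch showing that `DeferredBatchNorm` defers the running-statistics update until `chunks` micro-batches have been tracked. The import path assumes this vendored layout:

    import torch
    import torch.nn as nn
    from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm

    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=2)

    dbn(torch.rand(4, 3, 8, 8))       # first micro-batch: tracked, not committed
    print(dbn.num_batches_tracked)    # tensor(0)

    dbn(torch.rand(4, 3, 8, 8))       # second micro-batch completes the mini-batch
    print(dbn.num_batches_tracked)    # tensor(1) -- running stats updated once
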
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/checkpoint.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7624392d36a2e5e83ea53b9ed87ad7bdb4380ae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/checkpoint.py
@@ -0,0 +1,364 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Checkpointing with preceding recomputation.
+
+PyTorch already provides the official checkpointing utilities in
+:mod:`torch.utils.checkpoint`. The official checkpointing combines
+recomputation and recursive backpropagation into one autograd function named
+``CheckpointFunction``. Hence, the recomputation can start only when the
+gradients arrive at the function. In Pipe, the recomputation needs to precede
+the gradient arrival to minimize the GPU idle time.
+
+We solve this problem by introducing separate autograd functions named
+:class:`Recompute` and :class:`Checkpoint`, which represent recomputation and
+recursive backpropagation, respectively. With this pair of functions we can
+manipulate the control flow from the perspective of both the autograd engine
+and CUDA.
+
+Specifically, we place CUDA stream synchronization between :class:`Recompute`
+and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
+copied entirely.
+
+"""
+from collections import deque
+from contextlib import contextmanager
+import threading
+from typing import (
+    Any,
+    Deque,
+    Generator,
+    List,
+    Optional,
+    Protocol,
+    Union,
+    Sequence,
+    Tuple
+)
+
+import torch
+from torch import Tensor
+import torch.autograd
+
+from .dependency import fork, join
+from .microbatch import Batch
+from .phony import get_phony
+
+__all__ = ["Function", "checkpoint", "Checkpointing", "ThreadLocal", "enable_checkpointing",
+           "enable_recomputing", "is_checkpointing", "is_recomputing", "Context", "save_rng_states",
+           "restore_rng_states", "Checkpoint", "Recompute"]
+
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+# Types for shared memory between Checkpoint and Recompute.
+Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
+RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)
+
+
+# Protocol with __call__ instead of Callable can be used as an attribute type.
+# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
+class Function(Protocol):
+    def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
+        ...
+
+
+def checkpoint(function: Function, input):
+    """Make a checkpoint with a simple interface like
+    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
+    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
+    """
+    batch = Batch(input)
+
+    chk = Checkpointing(function, batch)
+    batch = chk.checkpoint()
+    chk.recompute(batch)
+
+    return batch.values
+
+
+class Checkpointing:
+    """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
+
+    def __init__(self, function: Function, batch: Batch) -> None:
+        self.function = function
+        self.batch = batch
+
+        # Shared memory between Checkpoint and Recompute. 1-length deque is
+        # used for mutability and length limitation.
+        self.recomputed: Deque[Recomputed] = deque(maxlen=1)
+        self.rng_states: Deque[RNGStates] = deque(maxlen=1)
+
+    def checkpoint(self) -> Batch:
+        """Return a batch applied by :class:`Checkpoint`."""
+        input_atomic = self.batch.atomic
+        inputs = tuple(self.batch)
+
+        # Use a phony which requires grad to ensure that Checkpoint can be
+        # tracked by the autograd engine even when none of the input tensors
+        # require grad.
+        phony = get_phony(self.batch.get_device(), requires_grad=True)
+
+        output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)
+
+        # Gradients are only supported for float Tensors.
+        if isinstance(output, tuple):
+            output = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in output])
+
+        return Batch(output)
+
+    def recompute(self, batch: Batch) -> None:
+        """Apply :class:`Recompute` to the batch in place."""
+        input_atomic = self.batch.atomic
+        inputs = tuple(self.batch)
+
+        # Use a tensor in the batch to tie together fork-join
+        tensor_idx = batch.find_tensor_idx()
+        # batch[tensor_idx] always requires grad, because it has been passed
+        # through Checkpoint with a phony that requires grad.
+        batch[tensor_idx], phony = fork(batch[tensor_idx])
+        phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *inputs)
+        batch[tensor_idx] = join(batch[tensor_idx], phony)
+
+
+class ThreadLocal(threading.local):
+    def __init__(self) -> None:
+        self.is_checkpointing = False
+        self.is_recomputing = False
+
+
+thread_local = ThreadLocal()
+
+
+@contextmanager
+def enable_checkpointing() -> Generator[None, None, None]:
+    """Make :func:`is_checkpointing` return :data:`True` within a context."""
+    orig = thread_local.is_checkpointing
+    thread_local.is_checkpointing = True
+    try:
+        yield
+    finally:
+        thread_local.is_checkpointing = orig
+
+
+@contextmanager
+def enable_recomputing() -> Generator[None, None, None]:
+    """Makes :func:`is_recomputing` return :data:`True` within a context."""
+    orig = thread_local.is_recomputing
+    thread_local.is_recomputing = True
+    try:
+        yield
+    finally:
+        thread_local.is_recomputing = orig
+
+
+def is_checkpointing() -> bool:
+    """Whether the current forward propagation is under checkpointing.
+
+    Returns:
+        bool: :data:`True` if it's under checkpointing.
+
+    """
+    return thread_local.is_checkpointing
+
+
+def is_recomputing() -> bool:
+    """Whether the current forward propagation is under checkpoint recomputation.
+
+    Use this to prevent duplicated side-effects at forward
+    propagation::
+
+        class Counter(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.counter = 0
+
+            def forward(self, input):
+                if not is_recomputing():
+                    self.counter += 1
+                return input
+
+    Returns:
+        bool: :data:`True` if it's under checkpoint recomputation.
+
+    .. seealso:: :ref:`Detecting Recomputation`
+
+    """
+    return thread_local.is_recomputing
+
+
+class Context:
+    """The common interface between the :class:`Checkpoint` and :class:`Recompute` context."""
+
+    recomputed: Deque[Recomputed]
+    rng_states: Deque[RNGStates]
+    function: Function
+    input_atomic: bool
+    inputs: Sequence[Any]
+
+    saved_tensors: Tuple[Tensor, ...]
+
+    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
+        pass
+
+
+def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
+    """:
+    Capture the current random number generator states.
+
+    meth:`Checkpoint.forward` captures the current PyTorch's random number
+    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
+
+    .. seealso:: :ref:`Referential Transparency`
+
+    """
+    cpu_rng_state = torch.get_rng_state()
+
+    gpu_rng_state: Optional[Tensor]
+    if device.type == "cuda":
+        gpu_rng_state = torch.cuda.get_rng_state(device)
+    else:
+        gpu_rng_state = None
+
+    rng_states.append((cpu_rng_state, gpu_rng_state))
+
+
+@contextmanager
+def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
+    """:
+    Restore the random number generator state.
+
+    meth:`Recompute.backward` restores the random number generator states
+    captured by :func:`save_rng_states` within its context.
+
+    .. seealso:: :ref:`Referential Transparency`
+
+    """
+    cpu_rng_state, gpu_rng_state = rng_states.pop()
+
+    gpu_devices: List[torch.device] = []
+    if device.type == "cuda":
+        gpu_devices.append(device)
+
+    with torch.random.fork_rng(gpu_devices):
+        torch.set_rng_state(cpu_rng_state)
+        if gpu_rng_state is not None:
+            torch.cuda.set_rng_state(gpu_rng_state, device)
+        yield
+
+
+class Checkpoint(torch.autograd.Function):
+    @staticmethod
+    # type: ignore[override]
+    def forward(
+        ctx: Context,
+        phony: Tensor,
+        recomputed: Deque[Recomputed],
+        rng_states: Deque[RNGStates],
+        function: Function,
+        input_atomic: bool,
+        *inputs,
+    ):
+        ctx.recomputed = recomputed
+        ctx.rng_states = rng_states
+
+        save_rng_states(phony.device, ctx.rng_states)
+
+        ctx.function = function
+        ctx.input_atomic = input_atomic
+        if input_atomic:
+            tensors = [inputs[0]]
+        else:
+            tensors = []
+            for input in inputs:
+                if torch.is_tensor(input):
+                    tensors.append(input)
+
+        ctx.save_for_backward(*tensors)
+
+        with torch.no_grad(), enable_checkpointing():
+            if input_atomic:
+                assert len(inputs) == 1
+                output = function(inputs[0])
+            else:
+                output = function(*inputs)
+        return output
+
+    @staticmethod
+    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
+        output, input_leaf = ctx.recomputed.pop()
+
+        if isinstance(output, tuple):
+            outputs = output
+        else:
+            outputs = (output,)
+        if any(torch.is_tensor(y) and y.requires_grad for y in outputs):
+            tensors = tuple([x for x in outputs if torch.is_tensor(x) and x.requires_grad])
+            torch.autograd.backward(tensors, grad_output)
+
+        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
+        grad_input.extend(x.grad if torch.is_tensor(x) else None for x in input_leaf)
+        return tuple(grad_input)
+
+
+class Recompute(torch.autograd.Function):
+    @staticmethod
+    # type: ignore[override]
+    def forward(
+        ctx: Context,
+        phony: Tensor,
+        recomputed: Deque[Recomputed],
+        rng_states: Deque[RNGStates],
+        function: Function,
+        input_atomic: bool,
+        *inputs,
+    ) -> Tensor:
+        ctx.recomputed = recomputed
+        ctx.rng_states = rng_states
+
+        ctx.function = function
+        ctx.input_atomic = input_atomic
+        ctx.inputs = inputs
+        if input_atomic:
+            tensors = [inputs[0]]
+        else:
+            tensors = []
+            for input in inputs:
+                if torch.is_tensor(input):
+                    tensors.append(input)
+        ctx.save_for_backward(*tensors)
+
+        return phony
+
+    @staticmethod
+    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  # pragma: no cover
+        inputs = ctx.inputs
+        inputs_leaf = tuple(x.detach().requires_grad_(x.requires_grad) if torch.is_tensor(x) else x for x in inputs)
+
+        # Get the device for the inputs from a tensor
+        device = None
+        for input in inputs:
+            if torch.is_tensor(input):
+                device = input.device
+                break
+
+        if device is None:
+            raise RuntimeError(f'No tensors found in {inputs}')
+
+        with restore_rng_states(device, ctx.rng_states):
+            with torch.enable_grad(), enable_recomputing():
+                if ctx.input_atomic:
+                    assert len(inputs_leaf) == 1
+                    output = ctx.function(inputs_leaf[0])
+                else:
+                    output = ctx.function(*inputs_leaf)
+
+        ctx.recomputed.append((output, inputs_leaf))
+
+        grad_input: List[None] = [None, None, None, None, None]
+        grad_input.extend(None for _ in ctx.inputs)
+        return tuple(grad_input)
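
A hypothetical sketch of the testing helper `checkpoint` defined above: the wrapped function runs once under checkpointing during the forward pass and once more under recomputation during the backward pass. The import path assumes this vendored layout:

    import torch
    from torch.distributed.pipeline.sync.checkpoint import checkpoint, is_recomputing

    def double(x):
        if is_recomputing():
            print("recomputing")      # printed during backward, before grads flow
        return x * 2

    x = torch.ones(3, requires_grad=True)
    y = checkpoint(double, x)

    y.sum().backward()                # triggers Recompute, then Checkpoint
    print(x.grad)                     # tensor([2., 2., 2.])
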
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/copy.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..87e124cd42538905dbe056a7300854b776e2df88
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/copy.py
@@ -0,0 +1,108 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Autograd functions for stream-aware CUDA copy.
+
+They are used to overlap copy and computation on the same GPU.
+"""
+from collections import deque
+from typing import Deque, List, Optional, Tuple, Sequence
+
+import torch
+from torch import Tensor
+
+from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream
+
+__all__: List[str] = ["Context", "Copy", "Wait"]
+
+
+Tensors = Sequence[Tensor]
+
+
+# Common interface between :class:`Copy` and :class:`Wait`.
+class Context:
+    prev_stream: AbstractStream
+    next_stream: AbstractStream
+
+
+class Copy(torch.autograd.Function):
+    """Copies tensors on specific streams."""
+
+    @staticmethod
+    # type: ignore[override]
+    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input,) -> Tensors:
+        ctx.prev_stream = prev_stream
+        ctx.next_stream = next_stream
+
+        output = []
+        output_stream = current_stream(get_device(next_stream))
+
+        with use_stream(prev_stream), use_stream(next_stream):
+            for x in input:
+                if torch.is_tensor(x):
+                    y = x.to(get_device(next_stream), non_blocking=True)
+                    output.append(y)
+
+                    # 'prev_stream' is not where 'x' has been allocated.
+                    record_stream(x, prev_stream)
+                    # 'y' has been allocated on 'next_stream'.
+                    # It might be used on the current stream captured as 'output_stream'.
+                    record_stream(y, output_stream)
+                else:
+                    output.append(x)
+
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
+        prev_stream = ctx.prev_stream
+        next_stream = ctx.next_stream
+
+        grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
+        input_stream = current_stream(get_device(prev_stream))
+
+        with use_stream(prev_stream), use_stream(next_stream):
+            for x in reversed(grad_output):
+                y = x.to(get_device(prev_stream), non_blocking=True)
+                grad_input.appendleft(y)
+
+                # 'next_stream' is not where 'x' has been allocated.
+                record_stream(x, next_stream)
+                # 'y' has been allocated on 'prev_stream'.
+                # It might be used on the current stream captured as 'input_stream'.
+                record_stream(y, input_stream)
+
+        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
+        return grad_streams + tuple(grad_input)
+
+
+class Wait(torch.autograd.Function):
+    """Synchronizes a stream to another stream.
+
+    Place it just before you want to start an operation on the next stream,
+    provided that all operations on the previous stream are done.
+
+    """
+
+    @staticmethod
+    # type: ignore[override]
+    def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input) -> Tensors:
+        ctx.prev_stream = prev_stream
+        ctx.next_stream = next_stream
+
+        wait_stream(next_stream, prev_stream)
+
+        return tuple(x.detach() if torch.is_tensor(x) else x for x in input)
+
+    @staticmethod
+    def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
+        prev_stream = ctx.prev_stream
+        next_stream = ctx.next_stream
+
+        wait_stream(prev_stream, next_stream)
+
+        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
+        return grad_streams + grad_input
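
A hypothetical sketch of using `Copy` and `Wait` to move data on a dedicated stream and then hand it to the compute stream. It requires a CUDA device; the stream helpers come from the sibling `stream` module, and the import paths assume this vendored layout:

    import torch
    from torch.distributed.pipeline.sync.copy import Copy, Wait
    from torch.distributed.pipeline.sync.stream import current_stream, new_stream

    cuda = torch.device("cuda:0")
    compute_stream = current_stream(cuda)
    copy_stream = new_stream(cuda)

    x = torch.rand(1024, 1024, pin_memory=True)

    # Copy to the GPU on the dedicated copy stream...
    (y,) = Copy.apply(current_stream(torch.device("cpu")), copy_stream, x)
    # ...and make the compute stream wait for the copy before consuming it.
    (y,) = Wait.apply(copy_stream, compute_stream, y)

    z = y @ y   # issued on the compute stream, after the copy has finished
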
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/dependency.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/dependency.py
new file mode 100644
index 0000000000000000000000000000000000000000..de3d57e5e16e69503806f5194bfcb981a133d4d1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/dependency.py
@@ -0,0 +1,54 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Arbitrary dependency between two autograd lanes."""
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+
+from .phony import get_phony
+
+__all__: List[str] = ["fork", "Fork", "join", "Join"]
+
+
+def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
+    """Branches out from an autograd lane of the given tensor."""
+    if torch.is_grad_enabled() and input.requires_grad:
+        input, phony = Fork.apply(input)
+    else:
+        phony = get_phony(input.device, requires_grad=False)
+
+    return input, phony
+
+
+class Fork(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore[override]
+        phony = get_phony(input.device, requires_grad=False)
+        return input.detach(), phony.detach()
+
+    @staticmethod
+    def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore[override]
+        return grad_input
+
+
+def join(input: Tensor, phony: Tensor) -> Tensor:
+    """Merge two autograd lanes."""
+    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
+        input = Join.apply(input, phony)
+
+    return input
+
+
+class Join(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor:  # type: ignore[override]
+        return input.detach()
+
+    @staticmethod
+    def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore[override]
+        return grad_input, None
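
A hypothetical sketch of `fork` and `join` above: the tensor values are unchanged, but the phony adds an edge to the autograd graph, so gradients upstream of the fork point are not computed until the backward pass has reached the join on the other lane. The import path assumes this vendored layout:

    import torch
    from torch.distributed.pipeline.sync.dependency import fork, join

    a = torch.rand(3, requires_grad=True)
    b = torch.rand(3, requires_grad=True)

    x = a * 2
    y = b * 3

    x, phony = fork(x)   # branch a phony off x's lane
    y = join(y, phony)   # merge it into y's lane

    (x.sum() + y.sum()).backward()
    print(a.grad)  # tensor([2., 2., 2.])
    print(b.grad)  # tensor([3., 3., 3.])
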
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/microbatch.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/microbatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b5646b3b075952f54d6fc85aa3f11892900ec7d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/microbatch.py
@@ -0,0 +1,234 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Manipulation of micro-batches."""
+import typing
+from typing import Any, Callable, List, Union, cast, Sequence
+
+import torch
+from torch import Tensor
+import torch.cuda.comm
+
+__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]
+
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]
+
+
+class NoChunk:
+    """
+    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor
+    should not be chunked on the batch dimension and instead be replicated
+    as-is across all micro-batches. This is useful for tensors which might
+    not have any 'batch' semantics for the model.
+    """
+    def __init__(self, inp: Tensor):
+        if not torch.is_tensor(inp):
+            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
+        self._tensor = inp
+
+    @property
+    def tensor(self):
+        return self._tensor
+
+
+class Batch:
+    """
+    An abstraction representing a microbatch in the pipeline.
+    """
+
+    def __init__(self, values: Union[List[Any], Tensor]) -> None:
+        self._values = values
+        self.atomic = torch.is_tensor(values)
+
+        # Verify there is at least one tensor
+        if not self.atomic:
+            if not any(torch.is_tensor(value) for value in self._values):
+                raise TypeError(f'No tensors found in batch: {self._values}')
+
+    @property
+    def tensor(self) -> Tensor:
+        """Retrieves the underlying tensor."""
+        if not self.atomic:
+            raise AttributeError("not atomic batch")
+        return cast(Tensor, self._values)
+
+    @property
+    def values(self):
+        """Retrieves the underlying values for the batch"""
+        return self._values
+
+    def find_tensor_idx(self):
+        """
+        Retrieves the index of the first tensor found.
+        """
+        if self.atomic:
+            return 0
+        for i, value in enumerate(self._values):
+            if torch.is_tensor(value):
+                return i
+
+        raise TypeError("No tensor found!")
+
+    def get_device(self):
+        """
+        Retrieves the device for this microbatch.
+        """
+        if self.atomic:
+            return self._values.device  # type: ignore[union-attr]
+
+        for value in self._values:
+            if torch.is_tensor(value):
+                return value.device
+
+    def call(self, function: Function) -> "Batch":
+        """Calls a function on the microbatch. It also wraps
+        the output with :class:`Batch`.
+        """
+        if self.atomic:
+            return Batch(function(self._values))
+        else:
+            return Batch(function(*self._values))
+
+    def __repr__(self) -> str:
+        return f"Batch[atomic={self.atomic!r}]({self._values!r})"
+
+    def __iter__(self):
+        if self.atomic:
+            yield self._values
+        else:
+            yield from self._values
+
+    def __len__(self) -> int:
+        return 1 if self.atomic else len(self._values)
+
+    def __getitem__(self, index: int):
+        if not self.atomic:
+            return self._values[index]
+
+        if index != 0:
+            raise IndexError("atomic batch allows index 0 only")
+
+        return self._values
+
+    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
+    @typing.overload
+    def __setitem__(self, index: int, value: Tensor) -> None:
+        ...
+
+    @typing.overload
+    def __setitem__(self, index: slice, value: Tensors) -> None:
+        ...
+
+    def __setitem__(self, index: Union[int, slice], value) -> None:
+        if isinstance(index, int):
+            self._setitem_by_index(index, value)
+        else:
+            self._setitem_by_slice(index, value)
+
+    def _setitem_by_index(self, index: int, value) -> None:
+        if not self.atomic:
+            i = index
+            self._values = self._values[:i] + (value,) + self._values[i + 1 :]  # type: ignore[operator]
+            return
+
+        if index != 0:
+            raise IndexError("atomic batch allows index 0 only")
+
+        self._values = value
+
+    def _setitem_by_slice(self, index: slice, value) -> None:
+        if not (index.start is index.stop is index.step is None):  # noqa: E714
+            raise NotImplementedError("only slice [:] supported")
+
+        if not self.atomic:
+            self._values = value
+            return
+
+        if len(value) != 1:
+            raise IndexError("atomic batch cannot be replaced with multiple tensors")
+
+        self._values = value[0]
+
+
+def check(first_device, *inputs) -> None:
+    """
+    Checks whether the input contains at least one tensor and that each tensor
+    is on the same device as the first partition.
+
+    Raises:
+        TypeError: the input does not contain at least one tensor
+        ValueError: a tensor is not on the same device as the first partition
+
+    """
+
+    if not any(torch.is_tensor(input) for input in inputs):
+        raise TypeError(f'inputs do not have any tensors: {inputs}')
+    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
+        raise ValueError('All inputs should be on the same device as the first partition')
+
+
+def scatter(*inputs, chunks: int) -> List[Batch]:
+    """Splits an input mini-batch into multiple micro-batches."""
+    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
+        return [Batch(x) for x in inputs[0].chunk(chunks)]
+
+    batches: List[Any] = [[] for _ in range(chunks)]
+    # Actual number of chunks produced
+    num_chunks = -1
+    for input in inputs:
+        if torch.is_tensor(input):
+            # Chunk only tensors.
+            tensors = input.chunk(chunks)
+
+            # Validate number of chunks equal across all inputs.
+            if num_chunks != -1 and num_chunks != len(tensors):
+                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
+            num_chunks = len(tensors)
+
+            for i, tensor in enumerate(tensors):
+                batches[i].append(tensor)
+        else:
+            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
+            for i in range(chunks):
+                if isinstance(input, NoChunk):
+                    # Extract the tensor out.
+                    batches[i].append(input.tensor)
+                else:
+                    batches[i].append(input)
+
+    # Truncate to actual number of chunks
+    batches = batches[:num_chunks]
+
+    return [Batch(x) for x in batches]
+
+
+def gather(outputs: List[Batch]):
+    """Concatenates output micro-batches into a mini-batch."""
+    output: Any
+
+    if outputs[0].atomic:
+        tensors = tuple(b.tensor for b in outputs)
+        output = torch.cat(tensors)
+    else:
+        output_buf: List[Any] = []
+        for i in range(len(outputs[0])):
+            output_type = type(outputs[0][i])
+            current_outputs = []
+            for batch in outputs:
+                if output_type != type(batch[i]):
+                    raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
+                current_outputs.append(batch[i])
+
+            if torch.is_tensor(outputs[0][i]):
+                output_buf.append(torch.cat(current_outputs))
+            else:
+                output_buf.append(current_outputs)
+
+        output = tuple(output_buf)
+
+    return output
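
A hypothetical sketch of `scatter` and `gather` above, including a tensor wrapped with `NoChunk` that is replicated instead of being split. The import path assumes this vendored layout:

    import torch
    from torch.distributed.pipeline.sync.microbatch import Batch, NoChunk, scatter, gather

    inputs = torch.rand(8, 4)            # mini-batch of 8 samples
    flag = NoChunk(torch.tensor([1]))    # replicated as-is in every micro-batch

    batches = scatter(inputs, flag, chunks=4)
    print(len(batches), len(batches[0])) # 4 micro-batches, 2 values each

    # Pretend each micro-batch went through a pipeline stage.
    outputs = [Batch(batch[0] * 2) for batch in batches]

    merged = gather(outputs)
    print(merged.shape)                  # torch.Size([8, 4])
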
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/phony.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/phony.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee19ffb6cc82032b5d61db81820eb8034d13ea84
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/phony.py
@@ -0,0 +1,50 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Provides phony for arbitrary dependency in a autograd graph."""
+from typing import Dict, List, Tuple
+
+import torch
+from torch import Tensor
+
+from .stream import default_stream, use_stream
+
+__all__: List[str] = ["get_phony"]
+
+
+_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
+
+
+def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
+    """Get a phony. Phony is tensor without space.
+
+    It is useful to make arbitrary dependency in a autograd graph because it doesn't require any
+    gradient accumulation.
+
+    .. note::
+
+        Phonies for each device are cached. If an autograd function gets a phony
+        internally, the phony must be detached to be returned. Otherwise, the
+        autograd engine will mutate the cached phony in-place::
+
+            class Phonify(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    phony = get_phony(input.device, requires_grad=False)
+                    return phony.detach()  # detach() is necessary.
+
+    """
+    key = (device, requires_grad)
+
+    try:
+        phony = _phonies[key]
+    except KeyError:
+        with use_stream(default_stream(device)):
+            phony = torch.empty(0, device=device, requires_grad=requires_grad)
+
+        _phonies[key] = phony
+
+    return phony
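
A hypothetical sketch of the caching behavior of `get_phony`: phonies are zero-element tensors cached per (device, requires_grad) pair. The import path assumes this vendored layout:

    import torch
    from torch.distributed.pipeline.sync.phony import get_phony

    p1 = get_phony(torch.device("cpu"), requires_grad=True)
    p2 = get_phony(torch.device("cpu"), requires_grad=True)

    print(p1.numel(), p1.requires_grad)  # 0 True
    print(p1 is p2)                      # True -- the same cached tensor is reused
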
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipe.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..139bc701926601b2f8232b2406f47b7c65afa3f3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipe.py
@@ -0,0 +1,490 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""The Pipe interface."""
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast
+
+import torch
+from torch import Tensor, nn
+from torch.distributed.rpc import RRef
+import torch.autograd
+import torch.cuda
+
+from . import microbatch
+from .batchnorm import DeferredBatchNorm
+from .pipeline import Pipeline
+from .skip.layout import inspect_skip_layout
+from .skip.skippable import verify_skippables
+from .stream import AbstractStream, new_stream
+
+__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]
+
+
+Device = Union[torch.device, int, str]
+Devices = Union[Iterable[Device], List[Device]]
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+if TYPE_CHECKING:
+    # Typechecking: nn.Module is not a Generic
+    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
+    NamedModules = OrderedDict[str, Module]
+else:
+    Module = nn.Module
+    NamedModules = OrderedDict
+
+
+def _recommend_auto_balance(message: str) -> str:
+    """Expands a message with recommendation to :mod:`torchpipe.balance`."""
+    return f"""{message}
+
+If your model is still under development, its optimal balance would change
+frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for
+naive automatic balancing:
+
+  from torch.distributed.pipeline.sync import Pipe
+  from torch.distributed.pipeline.sync.balance import balance_by_time
+
+  partitions = torch.cuda.device_count()
+  sample = torch.empty(...)
+  balance = balance_by_time(partitions, model, sample)
+
+  model = Pipe(model, balance, ...)
+"""
+
+
+def _verify_module(module: nn.Sequential) -> None:
+    if not isinstance(module, nn.Sequential):
+        raise TypeError("module must be nn.Sequential to be partitioned")
+
+    named_children = list(module.named_children())
+    if len(named_children) != len(module):
+        raise ValueError("module with duplicate children is not supported")
+
+
+def _verify_splitting(
+    module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
+) -> None:
+    num_parameters = len(list(module.parameters()))
+    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
+    if num_parameters == num_child_parameters:
+        return
+
+    for i in range(len(partitions)):
+        for j in range(i + 1, len(partitions)):
+            parti = partitions[i]
+            partj = partitions[j]
+            if devices[i] == devices[j]:
+                continue
+            for p in parti.parameters():
+                for q in partj.parameters():
+                    if p is q:
+                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
+
+
+class BalanceError(ValueError):
+    pass
+
+
+def _retrieve_device(module: nn.Module) -> torch.device:
+    """Validates all parameters in the Module have the same device and returns
+    the appropriate device.
+
+    Args:
+        An ``nn.Module`` to process.
+
+    Returns:
+        ``torch.Device`` for the entire module.
+
+    Raises:
+        ValueError:
+            If devices for ``nn.Module`` parameters are not all same.
+    """
+
+    device = None
+    for parameter in module.parameters():
+        if device is None:
+            device = parameter.device
+        elif device != parameter.device:
+            raise ValueError(
+                f'nn.Module: {module}, should have all parameters on a single device,'
+                ' please use .to() to place the module on a single device')
+
+    return device if device is not None else torch.device("cpu")
+
+
+class PipeSequential(nn.Sequential):
+    """
+    Pipe variant of ``nn.Sequential`` which supports multiple inputs.
+    """
+
+    def forward(self, *inputs):
+        for module in self:
+            if isinstance(inputs, Tuple):  # type: ignore[arg-type]
+                inputs = module(*inputs)
+            else:
+                # Don't expand single variables (ex: lists/Tensor)
+                inputs = module(inputs)
+        return inputs
+
+
+class WithDevice(nn.Module):
+    """
+    Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe`
+    that overrides the device for that module. In cases where :class:`Pipe`
+    can't implicitly determine the device for the module and places it on CPU,
+    this wrapper can be used to override the implicit behavior and explicitly
+    specify which device a module should run on.
+
+    The provided module is also moved to the given device via ``.to(device)``
+    by :class:`Pipe`.
+
+    Args:
+        module(:class:`torch.nn.Module`): The module to be wrapped.
+        device(:class:`torch.device`): The device to run the module on.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> fc1 = nn.Linear(16, 8).cuda(0)
+        >>> fc2 = nn.Linear(8, 4).cuda(1)
+        >>> dropout = nn.Dropout()
+        >>>
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
+        >>> # Dropout does not have any parameters/buffers, but we want to
+        >>> # run it on cuda:1 to avoid any GPU to CPU transfers.
+        >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1'))
+        >>> # xdoctest: +SKIP("Needs RPC framework init")
+        >>> model = Pipe(model, chunks=8)
+    """
+    def __init__(self, module: nn.Module, device: torch.device):
+        super().__init__()
+        self._module = module
+        self._device = torch.device(device)
+
+    def forward(self, *args, **kwargs):
+        return self._module(*args, **kwargs)
+
+    @property
+    def module(self):
+        return self._module
+
+    @property
+    def device(self):
+        return self._device
+
+
+def _assemble_partition(modules: List[nn.Module]):
+    modules_list: List[nn.Module] = []
+    for module in modules:
+        if isinstance(module, nn.Sequential):
+            modules_list.extend(module.children())
+        else:
+            modules_list.append(module)
+    return PipeSequential(*modules_list)
+
+
+def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]:
+    partitions = []
+    devices = []
+
+    current_partition = []
+    current_device = None
+    for name, module in modules.named_children():
+        if isinstance(module, WithDevice):
+            # Process device override and move module to appropriate device.
+            device = module.device
+            module = module.module
+            module.to(device)
+        else:
+            device = _retrieve_device(module)
+        if current_device is not None and (current_device != device or device.type == 'cpu'):
+            partitions.append(_assemble_partition(current_partition))
+            devices.append(current_device)
+            current_partition = []
+        current_device = device
+        current_partition.append(module)
+
+    if current_device is not None:
+        partitions.append(_assemble_partition(current_partition))
+        devices.append(current_device)
+
+    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
+
+    return partitions, devices
+
+
+MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
+
+
+class Pipe(Module):
+    """Wraps an arbitrary :class:`nn.Sequential ` module
+    to train on using synchronous pipeline parallelism. If the module requires
+    lots of memory and doesn't fit on a single GPU, pipeline parallelism is a
+    useful technique to employ for training.
+
+    The implementation is based on the torchgpipe_ paper.
+
+    .. _torchgpipe: https://arxiv.org/abs/2004.09910
+
+    Pipe combines pipeline parallelism with checkpointing to reduce peak
+    memory required to train while minimizing device under-utilization.
+
+    You should place all the modules on the appropriate devices and wrap them
+    into an :class:`nn.Sequential <torch.nn.Sequential>` module defining the
+    desired order of execution. If a module does not contain any
+    parameters/buffers, it is assumed this module should be executed on CPU
+    and appropriate input tensors to the module are moved to CPU before
+    execution. This behavior can be overridden by the :class:`WithDevice`
+    wrapper which can be used to explicitly specify which device a module
+    should run on.
+
+    Args:
+        module (:class:`nn.Sequential <torch.nn.Sequential>`):
+            sequential module to be parallelized using pipelining. Each module
+            in the sequence has to have all of its parameters on a single
+            device. Each module in the sequence has to either be an nn.Module
+            or :class:`nn.Sequential <torch.nn.Sequential>` (to combine multiple
+            sequential modules on a single device)
+        chunks (int):
+            number of micro-batches (default: ``1``)
+        checkpoint (str):
+            when to enable checkpointing, one of ``'always'``,
+            ``'except_last'``, or ``'never'`` (default: ``'except_last'``).
+            ``'never'`` disables checkpointing completely, ``'except_last'``
+            enables checkpointing for all micro-batches except the last one
+            and ``'always'`` enables checkpointing for all micro-batches.
+        deferred_batch_norm (bool):
+            whether to use deferred ``BatchNorm`` moving statistics (default:
+            :data:`False`). If set to :data:`True`, we track statistics across
+            multiple micro-batches to update the running statistics per
+            mini-batch.
+
+    Raises:
+        TypeError:
+            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
+        ValueError:
+            invalid arguments
+
+    Example::
+        Pipeline of two FC layers across GPUs 0 and 1.
+
+        >>> # Need to initialize RPC framework first.
+        >>> # xdoctest: +SKIP
+        >>> os.environ['MASTER_ADDR'] = 'localhost'
+        >>> os.environ['MASTER_PORT'] = '29500'
+        >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)
+        >>>
+        >>> # Build pipe.
+        >>> fc1 = nn.Linear(16, 8).cuda(0)
+        >>> fc2 = nn.Linear(8, 4).cuda(1)
+        >>> model = nn.Sequential(fc1, fc2)
+        >>> model = Pipe(model, chunks=8)
+        >>> input = torch.rand(16, 16).cuda(0)
+        >>> output_rref = model(input)
+
+    .. note::
+        You can wrap a :class:`Pipe` model with
+        :class:`torch.nn.parallel.DistributedDataParallel` only when the
+        checkpoint parameter of :class:`Pipe` is ``'never'``.
+
+    .. note::
+        :class:`Pipe` only supports intra-node pipelining currently, but
+        will be expanded to support inter-node pipelining in the future.
+        The forward function returns an :class:`~torch.distributed.rpc.RRef`
+        to allow for inter-node pipelining in the future, where the output
+        might be on a remote host. For intra-node pipelining you can use
+        :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the
+        output locally.
+
+    .. warning::
+        :class:`Pipe` is experimental and subject to change.
+    """
+
+    def __init__(
+        self,
+        module: nn.Sequential,
+        chunks: int = 1,
+        checkpoint: str = "except_last",
+        deferred_batch_norm: bool = False,
+    ) -> None:
+        super().__init__()
+
+        # Check if RPC framework is initialized.
+        if not torch.distributed.rpc._is_current_rpc_agent_set():
+            raise RuntimeError(
+                'Please initialize RPC framework for Pipe using '
+                'torch.distributed.rpc.init_rpc')
+
+        chunks = int(chunks)
+        checkpoint = str(checkpoint)
+
+        if chunks <= 0:
+            raise ValueError("number of chunks must be positive integer")
+        if checkpoint not in ["always", "except_last", "never"]:
+            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
+
+        _verify_module(module)
+
+        # Verify if the underlying skippable modules satisfy integrity. The
+        # integrity can be verified before forward() because it is static.
+        verify_skippables(module)
+
+        self.chunks = chunks
+        self.checkpoint = checkpoint
+
+        if deferred_batch_norm:
+            module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
+
+        self.partitions, self.devices = _split_module(module)
+        _verify_splitting(module, self.partitions, self.devices)
+
+        self._copy_streams: List[List[AbstractStream]] = []
+        self._skip_layout = inspect_skip_layout(self.partitions)
+
+        # Separate CUDA streams for copy.
+        copy_streams = self._ensure_copy_streams()
+
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
+
+        self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
+
+    def __len__(self) -> int:
+        """Counts the length of the underlying sequential module."""
+        return sum(len(p) for p in self.partitions)
+
+    def __getitem__(self, index: int) -> nn.Module:
+        """Gets a layer in the underlying sequential module."""
+        partitions = self.partitions
+        if index < 0:
+            partitions = partitions[::-1]
+
+        for partition in partitions:
+            try:
+                return partition[index]
+            except IndexError:
+                pass
+
+            shift = len(partition)
+
+            if index < 0:
+                index += shift
+            else:
+                index -= shift
+
+        raise IndexError
+
+    def __iter__(self) -> Iterator[nn.Module]:
+        """Iterates over children of the underlying sequential module."""
+        for partition in self.partitions:
+            yield from partition
+
+    # Pipe should manage the device of each partition.
+    # Deny cuda(), cpu(), and to() with device, by TypeError.
+    def cuda(self, device: Optional[Device] = None) -> "Pipe":
+        raise MOVING_DENIED
+
+    def cpu(self) -> "Pipe":
+        raise MOVING_DENIED
+
+    def to(self, *args: Any, **kwargs: Any) -> "Pipe":
+        # Deny these usages:
+        #
+        # - to(device[, dtype, non_blocking])
+        # - to(tensor[, non_blocking])
+        #
+        # But allow this:
+        #
+        # - to(dtype[, non_blocking])
+        #
+        if "device" in kwargs or "tensor" in kwargs:
+            raise MOVING_DENIED
+
+        if args:
+            if isinstance(args[0], (torch.device, int, str)):
+                raise MOVING_DENIED
+            if torch.is_tensor(args[0]):
+                raise MOVING_DENIED
+
+        return super().to(*args, **kwargs)
+
+    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
+        """Ensures that :class:`Pipe` caches CUDA streams for copy.
+
+        It's worth caching CUDA streams even though PyTorch already manages a
+        pool of pre-allocated CUDA streams, because doing so may reduce GPU
+        memory fragmentation when the number of micro-batches is small.
+
+        """
+        if not self._copy_streams:
+            for device in self.devices:
+                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
+
+        return self._copy_streams
+
+    def forward(self, *inputs) -> RRef:
+        """
+        Processes a single input mini-batch through the pipe and returns an
+        :class:`~torch.distributed.rpc.RRef` pointing to the output.
+        :class:`Pipe` is a fairly transparent module wrapper. It doesn't
+        modify the input and output signature of the underlying module. But
+        there's a type restriction: the input and the output have to contain at
+        least one tensor. This restriction applies at partition boundaries too.
+
+        The sequence of inputs is fed into the first stage of the pipeline as
+        ``*inputs``. As a result, the positional args for this function should
+        match the positional args for the first stage of the pipeline. The same
+        condition applies to the output of one stage of the pipeline, which is
+        the input for the next stage.
+
+        The input tensor is split into multiple micro-batches based on the
+        ``chunks`` parameter used to initialize :class:`Pipe`. The batch size
+        is assumed to be the first dimension of the tensor and if the batch
+        size is less than ``chunks``, the number of micro-batches is equal to
+        the batch size.
+
+        Only tensors are split into multiple micro-batches; non-Tensor inputs
+        are just replicated as-is in each micro-batch. Non-Tensor outputs from
+        the last stage of the pipeline are aggregated as a ``List`` and
+        returned to the user. For example, if you have 2 micro-batches
+        returning the integer 5, the user would receive the consolidated
+        output of ``[5, 5]``.
+
+        All the input tensors need to be on the same device as the first
+        partition of the pipeline.
+
+        If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor
+        is not split across micro-batches and is replicated as-is similar to
+        non-tensors.
+
+        Args:
+            inputs: input mini-batch
+
+        Returns:
+            :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
+
+        Raises:
+            TypeError: input doesn't contain at least one tensor
+
+        """
+        first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu")
+        microbatch.check(first_partition_device, *inputs)
+
+        if not self.devices:
+            # Empty sequential module is not illegal.
+            return RRef(*inputs)
+
+        # Divide a mini-batch into micro-batches.
+        batches = microbatch.scatter(*inputs, chunks=self.chunks)
+
+        # Run pipeline parallelism.
+        self.pipeline.run(batches)
+
+        # Merge the micro-batches into one mini-batch.
+        output = microbatch.gather(batches)
+        return RRef(output)
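+
+# A minimal usage sketch (illustrative only, not part of the module): it
+# assumes a two-layer model split across two CUDA devices and that the RPC
+# framework required by :class:`Pipe` has already been initialized elsewhere.
+#
+#     fc1 = nn.Linear(16, 8).cuda(0)
+#     fc2 = nn.Linear(8, 4).cuda(1)
+#     model = Pipe(nn.Sequential(fc1, fc2), chunks=4)
+#     output_rref = model(torch.randn(32, 16).cuda(0))
+#     output = output_rref.to_here()  # fetch the gathered mini-batch locally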
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipeline.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c9dc1c93dab43bd28bb76ca6bb6ec389a046d33
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/pipeline.py
@@ -0,0 +1,255 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""The pipeline parallelism of Pipe."""
+from queue import Queue
+from types import TracebackType
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence
+
+import torch
+from torch import Tensor, nn
+from torch.autograd.profiler import record_function
+
+from .checkpoint import Checkpointing
+from .copy import Copy, Wait
+from .dependency import fork, join
+from .microbatch import Batch
+from .skip.layout import SkipLayout
+from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
+from .stream import AbstractStream, current_stream, use_device
+from .worker import Task, create_workers
+
+__all__: List[str] = ["Pipeline"]
+
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
+
+# Queue is generic only in stubs.
+# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
+if TYPE_CHECKING:
+    InQueue = Queue[Optional["Task"]]
+    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
+else:
+    InQueue = Queue
+    OutQueue = Queue
+
+
+def _depend(fork_from: Batch, join_to: Batch) -> None:
+    fork_from_idx = fork_from.find_tensor_idx()
+    join_to_idx = join_to.find_tensor_idx()
+
+    fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx])
+    join_to[join_to_idx] = join(join_to[join_to_idx], phony)
+
+
+def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
+    batch[:] = Copy.apply(prev_stream, next_stream, *batch)
+    # Gradients are only supported for float Tensors.
+    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])
+
+
+def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
+    batch[:] = Wait.apply(prev_stream, next_stream, *batch)
+    # Gradients are only supported for float Tensors.
+    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])
+
+
+def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
+    """Generate schedules for each clock cycle."""
+    # m: number of micro-batches
+    # n: number of partitions
+    # i: index of micro-batch
+    # j: index of partition
+    # k: clock number
+    #
+    # k (i,j) (i,j) (i,j)
+    # - ----- ----- -----
+    # 0 (0,0)
+    # 1 (1,0) (0,1)
+    # 2 (2,0) (1,1) (0,2)
+    # 3       (2,1) (1,2)
+    # 4             (2,2)
+    for k in range(m + n - 1):
+        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
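+
+# For example (derived directly from the generator above, and matching the
+# schedule table in the comment):
+#
+#     >>> list(_clock_cycles(3, 3))
+#     [[(0, 0)],
+#      [(1, 0), (0, 1)],
+#      [(2, 0), (1, 1), (0, 2)],
+#      [(2, 1), (1, 2)],
+#      [(2, 2)]]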
+
+
+class Pipeline:
+    """The pipeline parallelism for Pipe."""
+
+    def __init__(
+        self,
+        partitions: List[nn.Sequential],
+        devices: List[torch.device],
+        copy_streams: List[List[AbstractStream]],
+        skip_layout: SkipLayout,
+        checkpoint_stop: int,
+    ) -> None:
+        self.partitions = partitions
+        self.devices = devices
+        self.copy_streams = copy_streams
+        self.skip_layout = skip_layout
+        self.checkpoint_stop = checkpoint_stop
+        (self.in_queues, self.out_queues) = create_workers(devices)
+
+    def run(self, batches: List[Batch]) -> None:
+        """Runs pipeline parallelism.
+
+        It modifies the given batches in place.
+
+        """
+        partitions = self.partitions
+        devices = self.devices
+        skip_layout = self.skip_layout
+
+        m = len(batches)
+        n = len(partitions)
+
+        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
+
+        for schedule in _clock_cycles(m, n):
+            self.fence(batches, schedule, skip_trackers)
+            self.compute(batches, schedule, skip_trackers)
+
+    def fence(
+        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+    ) -> None:
+        """Copy micro-batches after computation for the previous micro-batches."""
+        copy_streams = self.copy_streams
+        skip_layout = self.skip_layout
+
+        for i, j in schedule:
+            # Ensure that batches[i-1] is executed after batches[i] in
+            # backpropagation by an explicit dependency.
+            if i != 0 and j != 0:
+                _depend(batches[i - 1], batches[i])
+
+            next_stream = copy_streams[j][i]
+
+            for prev_j, ns, name in skip_layout.copy_policy(j):
+                prev_stream = copy_streams[prev_j][i]
+                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)
+
+            if j != 0:
+                prev_stream = copy_streams[j - 1][i]
+                _copy(batches[i], prev_stream, next_stream)
+
+    def compute(
+        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
+    ) -> None:
+        """Run tasks with synchronization to copy streams."""
+        partitions = self.partitions
+        devices = self.devices
+        copy_streams = self.copy_streams
+        checkpoint_stop = self.checkpoint_stop
+
+        # Disable checkpointing if in eval mode.
+        if not self.partitions[0].training:
+            checkpoint_stop = 0
+
+        n = len(partitions)
+        streams = [current_stream(d) for d in devices]
+        exc_info: Optional[ExcInfo] = None
+
+        # With checkpointing, the autograd graph looks like this diagram:
+        # ┌─────┸──────┐
+        # │    Copy    │
+        # └─────┰──────┘   (fence)
+        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
+        #       ┃          (compute)
+        # ┌─────┸──────┐
+        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
+        # └─────┰──────┘
+        # ┌─────┸──────┐
+        # │ Checkpoint │ [2] Compute a partition within checkpointing.
+        # └─────┰──────┘
+        # ┌─────┸──────┐
+        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
+        # └─────┰──────┘
+        #       ┠ ─ ─ ─ ┐
+        #       ┃ ┌─────┴─────┐
+        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
+        #       ┃ └─────┬─────┘
+        #       ┠ ─ ─ ─ ┘
+        #       ┃
+        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
+        # ┌─────┸──────┐   (fence)
+        # │    Copy    │
+        # └─────┰──────┘
+        for i, j in schedule:
+            batch = batches[i]
+            partition = partitions[j]
+
+            # Synchronize with the copied input. ([1] in the diagram)
+            if j != 0:
+                _wait(batch, copy_streams[j][i], streams[j])
+
+            # Determine whether to checkpoint this micro-batch.
+            checkpoint = i < checkpoint_stop
+            if checkpoint:
+
+                def function(
+                    *inputs,
+                    partition: nn.Module = partition,
+                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
+                    chunk_id: int = i,
+                    part_id: int = j,
+                ) -> TensorOrTensors:
+                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                        return partition(*inputs)
+
+                chk = Checkpointing(function, batch)  # type: ignore[arg-type]
+                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
+                del function, chk
+
+            else:
+
+                def compute(
+                    batch: Batch = batch,
+                    partition: nn.Module = partition,
+                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
+                    chunk_id: int = i,
+                    part_id: int = j,
+                ) -> Batch:
+                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                        return batch.call(partition)
+
+                task = Task(streams[j], compute=compute, finalize=None)
+                del compute
+
+            # Compute tasks in parallel. ([2] in the diagram)
+            self.in_queues[j].put(task)
+
+        for i, j in schedule:
+            ok, payload = self.out_queues[j].get()
+
+            # Hold the first exception.
+            if exc_info is not None:
+                continue
+            elif not ok:
+                exc_info = cast(ExcInfo, payload)
+                continue
+
+            task, batch = cast(Tuple[Task, Batch], payload)
+
+            # The copy stream synchronizes to copy the output. ([3] in the
+            # diagram)
+            if j != n - 1:
+                _wait(batch, streams[j], copy_streams[j][i])
+
+            # Finalize tasks. If checkpointing is enabled, here the
+            # recomputation is scheduled at backpropagation. ([4] in the
+            # diagram)
+            with use_device(devices[j]):
+                task.finalize(batch)
+
+            batches[i] = batch
+
+        # Fail at the first exception.
+        if exc_info is not None:
+            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/py.typed b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..f4830a6416775aae091858a4ac5158ce69f7de29
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/py.typed
@@ -0,0 +1,6 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__init__.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e79f0eaa9f6ecef7c31880f25348eb6f4704ec68
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__init__.py
@@ -0,0 +1,11 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Supports efficiency with skip connections."""
+from .namespace import Namespace
+from .skippable import pop, skippable, stash, verify_skippables
+
+__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"]
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f48cff79887d45cec515fbc87b00dc3f7a0cfc9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/layout.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/layout.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47bbd3d583cecfd48e0c002c49c4c30a3c763862
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/layout.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/namespace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/namespace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bd9d0a359df1d44821b77e6f6d0df8d56226b41
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/namespace.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/portal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/portal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..665915b1e8e8c081a767fea4b683998a8e8ca516
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/portal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/skippable.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/skippable.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fbb7c1c435edba8a8ffc66ea9f11ef1934357e5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/skippable.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/tracker.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/tracker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b2d9a3c42f4f4cfc15fccc08a3b280ba37277a6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/__pycache__/tracker.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/layout.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..332108af23a30b0d70c9a4b3cf45b32d14c42375
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/layout.py
@@ -0,0 +1,92 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Static skip connection layout of ``@skippable`` modules."""
+from typing import Dict, Iterable, List, Tuple
+
+from torch import nn
+
+from .namespace import Namespace
+
+__all__: List[str] = []
+
+
+class SkipLayout:
+    """Represents a skip connection layout across partitions."""
+
+    # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...}
+    by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]]
+
+    # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
+    by_partition: List[List[Tuple[int, Namespace, str]]]
+
+    def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
+        # The skip routes are already indexed by 'ns, name'.
+        self.by_ns_name = skip_routes
+
+        # Index skip routes by partition number 'j'.
+        self.by_partition = [[] for _ in range(num_partitions)]
+
+        for (ns, name), (prev_j, next_j) in skip_routes.items():
+            self.by_partition[next_j].append((prev_j, ns, name))
+
+        for p in self.by_partition:
+            p.sort()
+
+    def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
+        """Generates skip routes for the given destination partition number.
+        The skip routes are sorted by source partition number in ascending
+        order.
+
+        Yields:
+            Each tuple of (source partition number, namespace, name).
+
+        """
+        for prev_j, ns, name in self.by_partition[next_j]:
+            if prev_j == next_j:
+                # This skip tensor will be popped at the same partition where
+                # it is stashed. In this case, copy is not required.
+                continue
+
+            yield (prev_j, ns, name)
+
+    def requires_copy(self, ns: Namespace, name: str) -> bool:
+        """Whether the given namespace and name requires partition-to-partition
+        copy or not.
+        """
+        prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1))
+        return prev_j != next_j
+
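+    # A small illustration of the layout bookkeeping (``ns`` is a hypothetical
+    # Namespace instance): with two partitions and a single route stashed at
+    # partition 0 and popped at partition 1,
+    #
+    #     layout = SkipLayout(2, {(ns, "1to3"): (0, 1)})
+    #     layout.requires_copy(ns, "1to3")   # True: the route crosses partitions
+    #     list(layout.copy_policy(1))        # [(0, ns, "1to3")]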
+
+def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout:
+    """Inspects the skip connection layout in the given partitions."""
+    # NOTE(sublee): Hide circular import inside this subroutine. Circular
+    # import is not ideal but placing this logic near to SkipLayout may
+    # increase cohesion of code.
+    from .skippable import Skippable
+
+    skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {}
+    stashed_at: Dict[Tuple[Namespace, str], int] = {}
+
+    for j, partition in enumerate(partitions):
+        def inspect_layer(layer):
+            if not isinstance(layer, Skippable):
+                return
+
+            for ns, name in layer.stashable():
+                stashed_at[(ns, name)] = j
+
+            for ns, name in layer.poppable():
+                prev_j = stashed_at.pop((ns, name))
+                skip_routes[(ns, name)] = (prev_j, j)
+
+        if isinstance(partition, nn.Sequential):
+            for layer in partition:
+                inspect_layer(layer)
+        else:
+            inspect_layer(partition)
+
+    return SkipLayout(len(partitions), skip_routes)
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/namespace.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/namespace.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fcb2687be37496c932e0c8e83ed811f82ec0ae2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/namespace.py
@@ -0,0 +1,50 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Provides isolated namespace of skip tensors."""
+import abc
+from functools import total_ordering
+from typing import Any
+import uuid
+
+__all__ = ["Namespace"]
+
+
+@total_ordering
+class Namespace(metaclass=abc.ABCMeta):
+    """Namespace for isolating skip tensors used by :meth:`isolate()
+    `.
+    """
+
+    __slots__ = ("id",)
+
+    def __init__(self) -> None:
+        self.id = uuid.uuid4()
+
+    def __repr__(self) -> str:
+        return f""
+
+    def __hash__(self) -> int:
+        return hash(self.id)
+
+    # Namespaces should support ordering, since SkipLayout will sort tuples
+    # including a namespace. But actual order between namespaces is not
+    # important. That's why they are ordered by version 4 UUID which generates
+    # random numbers.
+    def __lt__(self, other: Any) -> bool:
+        if isinstance(other, Namespace):
+            return self.id < other.id
+        return False
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Namespace):
+            return self.id == other.id
+        return False
+
+
+# 'None' is the default namespace,
+# which means that 'isinstance(None, Namespace)' is 'True'.
+Namespace.register(type(None))
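+
+# For example, two namespaces compare unequal and still sort deterministically,
+# and ``None`` acts as the default namespace (illustrative comment only):
+#
+#     ns1, ns2 = Namespace(), Namespace()
+#     ns1 == ns2                    # False: distinct version-4 UUIDs
+#     sorted([ns2, ns1])            # ordered by UUID; the order itself is arbitrary
+#     isinstance(None, Namespace)   # True, because of the register() call above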
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/portal.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/portal.py
new file mode 100644
index 0000000000000000000000000000000000000000..97481245907908074364ac90b5ff7c918d1c423e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/portal.py
@@ -0,0 +1,231 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
+autograd engine. The shared context of three functions (:class:`PortalBlue`,
+:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
+one of the most important feature of :mod:`torchpipe.skip`.
+
+The metaphor is inspired by Portal™ from Valve.
+
+"""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from ..copy import Context as CopyContext
+from ..copy import Copy
+from ..phony import get_phony
+from ..stream import AbstractStream, get_device
+
+__all__: List[str] = []
+
+
+class Portal:
+    """A portal for a tensor."""
+
+    def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
+        self.put_tensor(tensor, tensor_life)
+        self.grad: Optional[Tensor] = None
+
+    def blue(self) -> Tensor:
+        """Creates a :class:`PortalBlue` which hides the underlying tensor from
+        the autograd engine.
+
+        Join the returned phony to the main lane of the autograd graph to
+        ensure correct backpropagation::
+
+            PortalBlue --+
+                         |
+            ---------- Join --
+
+        """
+        tensor = self.use_tensor()
+
+        if tensor is None:
+            return get_phony(torch.device("cpu"), requires_grad=False)
+
+        return PortalBlue.apply(self, tensor)
+
+    def orange(self, phony: Tensor) -> Optional[Tensor]:
+        """Creates a :class:`PortalOrange` which retrieves the hidden tensor
+        without losing the ability to backpropagate.
+
+        Give a phony forked from the main lane of an autograd graph::
+
+                +-- PortalOrange --+
+                |                  |
+            -- Fork --------- f(a, b) --
+
+        """
+        self.check_tensor_life()
+
+        if self.tensor is None:
+            return self.use_tensor()
+
+        return PortalOrange.apply(self, phony)
+
+    def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
+        """Copies the hidden tensor by a :class:`PortalCopy`.
+
+        Give a phony and use the returning phony to keep backpropagation::
+
+                +-- PortalCopy --+
+                |                |
+            -- Fork ---------- Join --
+
+        """
+        if self.tensor is None:
+            return get_phony(torch.device("cpu"), requires_grad=False)
+
+        return PortalCopy.apply(self, prev_stream, next_stream, phony)
+
+    def check_tensor_life(self) -> None:
+        if self.tensor_life <= 0:
+            raise RuntimeError("tensor in portal has been removed")
+
+    def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
+        """Stores a tensor into this portal."""
+        # [Life of Tensor through Portal]
+        #
+        # The tensor can be retrieved by use_tensor() up to 'tensor_life'
+        # times. When the life becomes 0, the tensor will be deleted for
+        # deallocation in CUDA memory.
+        #
+        # The below events participate in a tensor through a portal.
+        # Note that [x] denotes the events which call use_tensor():
+        #
+        #  1. [x] blue()
+        #  2. [ ]   PortalBlue.forward
+        #  3. [ ] copy()
+        #  4. [ ]   PortalCopy.forward
+        #  5. [ ] orange()
+        #  6. [x]   PortalOrange.forward
+        # - - - - - - - - - - - - - - - - - - - - - - - - - - -
+        #  7. [ ] orange() (recomputed)
+        #  8. [x]   PortalOrange.forward (recomputed)
+        #  9. [ ]   PortalOrange.backward
+        # 10. [ ] PortalCopy.backward
+        # 11. [x] blue() (recomputed)
+        # 12. [ ]   PortalBlue.forward (recomputed)
+        # 13. [ ]   PortalBlue.backward
+        #
+        self.tensor_life = tensor_life
+
+        if tensor_life > 0:
+            self.tensor = tensor
+        else:
+            self.tensor = None
+
+    def use_tensor(self) -> Optional[Tensor]:
+        """Retrieves the underlying tensor and decreases the tensor  life. When
+        the life becomes 0, it the tensor will be removed.
+        """
+        self.check_tensor_life()
+
+        tensor = self.tensor
+
+        self.tensor_life -= 1
+
+        if self.tensor_life <= 0:
+            self.tensor = None
+
+        return tensor
+
+    def put_grad(self, grad: Tensor) -> None:
+        """Stores a gradient into this portal."""
+        self.grad = grad
+
+    def use_grad(self) -> Tensor:
+        """Retrieves and removes the underlying gradient. The gradient is
+        always ephemeral.
+        """
+        if self.grad is None:
+            raise RuntimeError("grad in portal has been removed or never set")
+
+        grad = self.grad
+        self.grad = None
+        return grad
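+
+    # A minimal sketch of the tensor-life bookkeeping above (no autograd
+    # involved; purely illustrative):
+    #
+    #     p = Portal(torch.zeros(1), tensor_life=2)
+    #     p.use_tensor()   # returns the tensor, life 2 -> 1
+    #     p.use_tensor()   # returns the tensor, life 1 -> 0, tensor is dropped
+    #     p.use_tensor()   # RuntimeError: tensor in portal has been removed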
+
+
+# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
+# :class:`PortalCopy`.
+class Context(CopyContext):
+    portal: Portal
+
+
+class PortalBlue(torch.autograd.Function):
+    """Hides a tensor from the autograd engine by a :class:`Portal`."""
+
+    @staticmethod
+    # type: ignore[override]
+    def forward(
+        ctx: Context,
+        portal: Portal,
+        # This tensor must be retrieved by portal.use_tensor().
+        tensor: Tensor,
+    ) -> Tensor:
+        ctx.portal = portal
+
+        phony = get_phony(tensor.device, requires_grad=False)
+        return phony.detach()
+
+    @staticmethod
+    # type: ignore[override]
+    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
+        # The paired PortalOrange should keep the gradient.
+        grad = ctx.portal.use_grad()
+        return None, grad
+
+
+class PortalOrange(torch.autograd.Function):
+    """Retrieves the hidden tensor from a :class:`Portal`."""
+
+    @staticmethod
+    # type: ignore[override]
+    def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
+        ctx.portal = portal
+
+        tensor = portal.use_tensor()
+        assert tensor is not None
+
+        return tensor.detach()
+
+    @staticmethod
+    def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]:  # type: ignore[override]
+        # The paired PortalBlue will use the gradient.
+        ctx.portal.put_grad(grad)
+        return None, None
+
+
+class PortalCopy(torch.autograd.Function):
+    """Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
+    tensor with the copied one.
+    """
+
+    @staticmethod
+    # type: ignore[override]
+    def forward(
+        ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
+    ) -> Tensor:
+        ctx.portal = portal
+
+        assert portal.tensor is not None
+        (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
+
+        phony = get_phony(get_device(next_stream), requires_grad=False)
+        return phony.detach()
+
+    @staticmethod
+    # type: ignore[override]
+    def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
+        portal = ctx.portal
+
+        assert portal.grad is not None
+        _, _, portal.grad = Copy.backward(ctx, portal.grad)
+
+        return None, None, None, None
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/skippable.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/skippable.py
new file mode 100644
index 0000000000000000000000000000000000000000..8deaa5bb7b0ea3df2c36ecf38ed964eaac2130d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/skippable.py
@@ -0,0 +1,431 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""The user interface to define skip connections."""
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    FrozenSet,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from torch import Tensor, nn
+
+from ..microbatch import Batch
+from .namespace import Namespace
+from .tracker import current_skip_tracker
+
+__all__ = ["skippable", "stash", "pop", "verify_skippables"]
+
+
+Tensors = Sequence[Tensor]
+TensorOrTensors = Union[Tensor, Tensors]
+
+StashPop = Union["stash", "pop"]
+StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
+if TYPE_CHECKING:
+    # Typechecking: nn.Module is not a Generic
+    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
+else:
+    SkippableModule = nn.Module
+
+T = TypeVar("T", bound="Skippable")
+
+
+class Skippable(nn.Module):
+    """The base class for skippable modules.
+
+    Do not use this class directly. Define a subclass by :func:`skippable`
+    instead.
+
+    """
+
+    module_cls: ClassVar[Type[SkippableModule]]
+    stashable_names: ClassVar[FrozenSet[str]]
+    poppable_names: ClassVar[FrozenSet[str]]
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__()
+        self.module = self.module_cls(*args, **kwargs)  # type: ignore[call-arg]
+        self.namespaces: Dict[str, Namespace] = {}
+
+    def __repr__(self) -> str:
+        return f"@skippable({self.module})"
+
+    def namespaced(self, name: str) -> Tuple[Namespace, str]:
+        """Prepend namespace for the given skip name."""
+        ns = self.namespaces.get(name)
+        ns = cast(Namespace, ns)
+        return (ns, name)
+
+    def stashable(self) -> Iterable[Tuple[Namespace, str]]:
+        """Iterate over namespaced skip names to be stashed."""
+        for name in self.stashable_names:
+            yield self.namespaced(name)
+
+    def poppable(self) -> Iterable[Tuple[Namespace, str]]:
+        """Iterate over namespaced skip names to be popped."""
+        for name in self.poppable_names:
+            yield self.namespaced(name)
+
+    def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
+        r"""Isolate a specified subset or the whole set of skip tensors.
+
+        In a single sequential module, skip tensors with the same
+        name are not allowed unless they are isolated by different namespaces.
+
+        Here's an example using the same name for skip tensors twice. Each pair
+        of ``Layer1`` and ``Layer3`` is isolated with its own namespace ``ns1``
+        and ``ns2``. There is no conflict anymore::
+
+            ns1 = Namespace()
+            ns2 = Namespace()
+
+            model = nn.Sequential(
+                Layer1().isolate(ns1),
+                Layer1().isolate(ns2),
+                Layer2(),
+                Layer3().isolate(ns2),
+                Layer3().isolate(ns1),
+            )
+
+        When the `only` parameter is omitted, all skip tensors are isolated. You
+        can isolate a subset of skip tensors by passing the `only` parameter::
+
+            ns_alice = Namespace()
+            ns_bob = Namespace()
+
+            model = nn.Sequential(
+                ...
+                StashStashPop().isolate(ns_alice, only=['alice']) \
+                               .isolate(ns_bob, only=['bob']),
+                ...
+            )
+
+        Args:
+            ns (Namespace):
+                namespace for isolation
+
+        Keyword Args:
+            only (iterable of strs):
+                names of specific skip tensors to be isolated (omit this option
+                to isolate all skip tensors declared in this module)
+
+        Returns:
+            this module itself
+
+        """
+        names: Iterable[str]
+
+        if only is None:
+            names = self.stashable_names | self.poppable_names
+        else:
+            names = set(only)
+
+        for name in names:
+            self.namespaces[name] = ns
+
+        return self
+
+    def dispatch(
+        self,
+        input,
+        handle_stash: Callable[[str, Optional[Tensor]], None],
+        handle_pop: Callable[[str], Optional[Tensor]],
+    ):
+        """Dispatch :class:`stash` or :class:`pop` commands.
+
+        The commands are generated by the module's ``forward()``.
+        """
+        generator = self.module(input)
+
+        if not isinstance(generator, Generator):
+            # The underlying module returned output without any yield.
+            output = generator
+            return output
+
+        try:
+            op = next(generator)
+
+            while True:
+                if isinstance(op, stash):
+                    handle_stash(op.name, op.tensor)
+                    op = next(generator)
+                    continue
+
+                if isinstance(op, pop):
+                    tensor = handle_pop(op.name)
+                    op = generator.send(tensor)
+                    continue
+
+                raise TypeError(f"{op!r} is not a command from @skippable")
+
+        except StopIteration as stop:
+            output = stop.args[0]
+            return output
+
+    def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors:
+        """Perform the forward propagation.
+
+        :class:`stash` or :class:`pop` commands will be handled by portals
+        silently. The portals won't be exposed to users.
+
+        Raises:
+            RuntimeError:
+                illegal 'stash' or 'pop' is found.
+
+        """
+        skip_tracker = current_skip_tracker()
+        stashed_tensors: Dict[str, Optional[Tensor]] = {}
+
+        # Load skip tensors that might be popped.
+        poppable_tensors = {}
+        batch = Batch(input)
+        for ns, name in self.poppable():
+            try:
+                poppable_tensors[name] = skip_tracker.load(batch, ns, name)
+            except KeyError as e:
+                raise RuntimeError(f"'{name}' has not been stashed") from e
+        input = batch.values
+
+        # Handle skip commands.
+        def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
+            if name not in self.stashable_names:
+                raise RuntimeError(f"'{name}' has not been declared as stashable")
+            stashed_tensors[name] = tensor
+
+        def handle_pop(name: str) -> Optional[Tensor]:
+            if name not in self.poppable_names:
+                raise RuntimeError(f"'{name}' has not been declared as poppable")
+            return poppable_tensors.pop(name)
+
+        output = self.dispatch(input, handle_stash, handle_pop)
+
+        # All declared skips must be stashed or popped.
+        not_stashed = self.stashable_names - stashed_tensors.keys()
+        if not_stashed:
+            comma_names = ", ".join(f"'{n}'" for n in not_stashed)
+            raise RuntimeError(f"{comma_names} must be stashed but have not")
+
+        not_popped = poppable_tensors.keys()
+        if not_popped:
+            comma_names = ", ".join(f"'{n}'" for n in not_popped)
+            raise RuntimeError(f"{comma_names} must be popped but have not")
+
+        # Save stashed skip tensors.
+        batch = Batch(output)
+        for ns, name in self.stashable():
+            tensor = stashed_tensors[name]
+            skip_tracker.save(batch, ns, name, tensor)
+        output = batch.values
+
+        return output
+
+
+# TODO(sublee): Move above the Skippable class for better reading flow.
+def skippable(
+    stash: Iterable[str] = (), pop: Iterable[str] = (),
+) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
+    """Define a decorator to create :class:`nn.Module ` with skip connections.
+
+    These decorated modules are called "skippable". This functionality works perfectly
+    fine even when the module is not wrapped by :class:`~torch.distributed.pipeline.sync.Pipe`.
+
+    Each skip tensor is managed by its name. Before manipulating skip tensors,
+    a skippable module must statically declare the names for skip tensors by
+    `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be
+    stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
+    pop(name)``.
+
+    Here is an example with three layers. A skip tensor named "1to3" is stashed
+    and popped at the first and last layer, respectively::
+
+        @skippable(stash=['1to3'])
+        class Layer1(nn.Module):
+            def forward(self, input):
+                yield stash('1to3', input)
+                return f1(input)
+
+        class Layer2(nn.Module):
+            def forward(self, input):
+                return f2(input)
+
+        @skippable(pop=['1to3'])
+        class Layer3(nn.Module):
+            def forward(self, input):
+                skip_1to3 = yield pop('1to3')
+                return f3(input) + skip_1to3
+
+        model = nn.Sequential(Layer1(), Layer2(), Layer3())
+
+    One skippable module can stash or pop multiple skip tensors::
+
+        @skippable(stash=['alice', 'bob'], pop=['carol'])
+        class StashStashPop(nn.Module):
+            def forward(self, input):
+                yield stash('alice', f_alice(input))
+                yield stash('bob', f_bob(input))
+                carol = yield pop('carol')
+                return input + carol
+
+    Every skip tensor must be associated with exactly one pair of `stash` and
+    `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this
+    restriction automatically when wrapping a module. You can also check the
+    restriction by :func:`verify_skippables`
+    without :class:`~torch.distributed.pipeline.sync.Pipe`.
+
+    """
+    stashable_names = frozenset(stash)
+    poppable_names = frozenset(pop)
+
+    def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
+        name = module_cls.__name__
+        bases = (Skippable,)
+        attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
+        return type(name, bases, attrs)
+
+    return extend_skippable
+
+
+class stash:
+    """The command to stash a skip tensor.
+
+    ::
+
+        def forward(self, input):
+            yield stash('name', input)
+            return f(input)
+
+    Args:
+        name (str): name of skip tensor
+        input (torch.Tensor or None): tensor to pass to the skip connection
+
+    """
+
+    __slots__ = ("name", "tensor")
+
+    def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
+        self.name = name
+        self.tensor = tensor
+
+
+class pop:
+    """The command to pop a skip tensor.
+
+    ::
+
+        def forward(self, input):
+            skip = yield pop('name')
+            return f(input) + skip
+
+    Args:
+        name (str): name of skip tensor
+
+    Returns:
+        the skip tensor previously stashed by another layer under the same name
+
+    """
+
+    __slots__ = ("name",)
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+
+
+def verify_skippables(module: nn.Sequential) -> None:
+    """Verify if the underlying skippable modules satisfy integrity.
+
+    Every skip tensor must have only one pair of `stash` and `pop`. If there
+    are one or more unmatched pairs, it will raise :exc:`TypeError` with the
+    detailed messages.
+
+    Here are a few failure cases. :func:`verify_skippables` will report failure
+    for these cases::
+
+        # Layer1 stashes "1to3".
+        # Layer3 pops "1to3".
+
+        nn.Sequential(Layer1(), Layer2())
+        #               └──── ?
+
+        nn.Sequential(Layer2(), Layer3())
+        #                   ? ────┘
+
+        nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
+        #               └───────────────────┘       ^^^^^^
+
+        nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
+        #             ^^^^^^      └───────────────────┘
+
+    To use the same name for multiple skip tensors, they must be isolated by
+    different namespaces. See :meth:`Skippable.isolate`.
+
+    Raises:
+        TypeError:
+            one or more pairs of `stash` and `pop` are not matched.
+
+    """
+    stashed: Set[Tuple[Namespace, str]] = set()
+    popped: Set[Tuple[Namespace, str]] = set()
+    msgs: List[str] = []
+
+    for layer_name, layer in module.named_children():
+        if not isinstance(layer, Skippable):
+            continue
+
+        for name in layer.stashable_names & layer.poppable_names:
+            msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
+            msgs.append(msg)
+
+        for ns, name in layer.stashable():
+            if name in layer.poppable_names:
+                continue
+
+            if (ns, name) in stashed:
+                msg = f"'{layer_name}' redeclared '{name}' as stashable but not isolated by namespace"
+                msgs.append(msg)
+                continue
+
+            stashed.add((ns, name))
+
+        for ns, name in layer.poppable():
+            if name in layer.stashable_names:
+                continue
+
+            if (ns, name) in popped:
+                msg = f"'{layer_name}' redeclared '{name}' as poppable but not isolated by namespace"
+                msgs.append(msg)
+                continue
+
+            if (ns, name) not in stashed:
+                msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
+                msgs.append(msg)
+                continue
+
+            popped.add((ns, name))
+
+    for (_, name) in stashed - popped:
+        msg = f"no module declared '{name}' as poppable but stashed"
+        msgs.append(msg)
+
+    if msgs:
+        raise TypeError(
+            "one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs)
+        )
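+
+# A minimal failure sketch (reusing the hypothetical Layer1/Layer2 pair from
+# the :func:`skippable` docstring, where Layer1 stashes '1to3' and nothing in
+# the sequence pops it):
+#
+#     verify_skippables(nn.Sequential(Layer1(), Layer2()))
+#     # TypeError: one or more pairs of stash and pop do not match:
+#     #
+#     # * no module declared '1to3' as poppable but stashed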
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/tracker.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..33cac8b1deaea59110941867fac6250a439e01b3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/skip/tracker.py
@@ -0,0 +1,180 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tracks skip tensors on a thread."""
+from contextlib import contextmanager
+import threading
+from typing import Dict, Generator, List, Optional, Tuple
+
+from torch import Tensor
+
+from ..checkpoint import is_checkpointing
+from ..dependency import fork, join
+from ..microbatch import Batch
+from ..stream import AbstractStream
+from .layout import SkipLayout
+from .namespace import Namespace
+from .portal import Portal
+
+__all__: List[str] = []
+
+
+class SkipTracker:
+    """Tracks saved skip tensors.
+
+    It will update the given micro-batch in place. This is because when it
+    manipulates the underlying skip tensors, the current micro-batch also has
+    to be connected with the skip tensors.
+
+    One thread has one skip tracker. Call :func:`current_skip_tracker` to get
+    the skip tracker on the current thread.
+
+    """
+
+    def __init__(self) -> None:
+        self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {}
+
+    def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
+        self.tensors[(ns, name)] = tensor
+
+    def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
+        return self.tensors.pop((ns, name))
+
+    def copy(
+        self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
+    ) -> None:
+        raise TypeError("copy is not supported for non-portal skip tensors")
+
+
+class SkipTrackerThroughPotals(SkipTracker):
+    """Tracks saved skip tensors through portals. The skip tensors will be
+    hidden in portals so that the autograd engine does not need to track them.
+
+    This tracker is only used when the training or evaluating module is wrapped
+    with :class:`torchpipe.Pipe`.
+
+    """
+
+    def __init__(self, skip_layout: SkipLayout) -> None:
+        super().__init__()
+        self.skip_layout = skip_layout
+        self.portals: Dict[Tuple[Namespace, str], Portal] = {}
+
+    def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
+        """Saves the stashed skip tensor in a portal. The portal is then
+        connected to the given micro-batch with :class:`Join`.
+        """
+        if not self.skip_layout.requires_copy(ns, name):
+            super().save(batch, ns, name, tensor)
+            return
+
+        # See [Life of Tensor through Portal] at Portal.put_tensor() to understand the
+        # below tensor_life values. Here are the selected events which retrieve
+        # the tensor in portal:
+        #
+        #  1. [x] blue()
+        #     ...
+        #  6. [x]   PortalOrange.forward
+        #     ...
+        #  8. [x]   PortalOrange.forward (recomputed)
+        #     ...
+        # 11. [x] blue() (recomputed)
+        #
+        if (ns, name) not in self.portals:
+            if is_checkpointing():
+                # Under checkpointing, the tensor used by the first
+                # PortalOrange should be alive in the portal. This tensor will
+                # be used again by the second PortalOrange during the
+                # recomputation.
+                tensor_life = 3  # Delete at [8. PortalOrange.forward (recomputed)]
+            else:
+                tensor_life = 2  # Delete at [6. PortalOrange.forward]
+
+            portal = Portal(tensor, tensor_life)
+            self.portals[(ns, name)] = portal
+
+        else:
+            # Under recomputation, the portal already exists.
+            portal = self.portals[(ns, name)]
+
+            # The existing tensor life already became 0. It should be reset as
+            # 1 to delete the tensor after the second PortalBlue immediately.
+            tensor_life = 1  # Delete at [11. blue() (recomputed)]
+
+            portal.put_tensor(tensor, tensor_life)
+
+        phony = portal.blue()
+        tensor_idx = batch.find_tensor_idx()
+        batch[tensor_idx] = join(batch[tensor_idx], phony)
+
+    def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
+        """Loads a skip tensor from the corresponding portal to pop. The given
+        micro-batch is connected to the portal with :class:`Fork`.
+        """
+        if not self.skip_layout.requires_copy(ns, name):
+            tensor = super().load(batch, ns, name)
+            return tensor
+
+        portal = self.portals[(ns, name)]
+        tensor_idx = batch.find_tensor_idx()
+        batch[tensor_idx], phony = fork(batch[tensor_idx])
+        tensor = portal.orange(phony)
+        return tensor
+
+    def copy(
+        self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
+    ) -> None:
+        """Copies the skip tensor in the corresponding portal. The given
+        micro-batch and the portal will be tied with :class:`Fork` and
+        :class:`Join`.
+        """
+        assert self.skip_layout.requires_copy(ns, name)
+
+        tensor_idx = batch.find_tensor_idx()
+        batch[tensor_idx], phony = fork(batch[tensor_idx])
+
+        portal = self.portals[(ns, name)]
+        phony = portal.copy(prev_stream, next_stream, phony)
+
+        batch[tensor_idx] = join(batch[tensor_idx], phony)
+
+
+class ThreadLocal(threading.local):
+    def __init__(self) -> None:
+        self.skip_tracker: Optional[SkipTracker] = None
+
+
+thread_local = ThreadLocal()
+
+
+@contextmanager
+def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]:
+    """Registers the given skip tracker on the current thread within a
+    context::
+
+        with use_skip_tracker(my_skip_tracker):
+            ...
+
+    """
+    orig = thread_local.skip_tracker
+
+    thread_local.skip_tracker = skip_tracker
+
+    try:
+        yield
+    finally:
+        thread_local.skip_tracker = orig
+
+
+def current_skip_tracker() -> SkipTracker:
+    """Gets the skip tracker on the current thread."""
+    skip_tracker = thread_local.skip_tracker
+
+    if skip_tracker is None:
+        skip_tracker = SkipTracker()
+        thread_local.skip_tracker = skip_tracker
+
+    return skip_tracker
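+
+# For example (illustrative; ``some_skip_layout`` is a hypothetical SkipLayout
+# built elsewhere): outside of Pipe the plain, non-portal tracker is used.
+#
+#     tracker = current_skip_tracker()          # plain SkipTracker
+#     with use_skip_tracker(SkipTrackerThroughPotals(some_skip_layout)):
+#         current_skip_tracker()                # the portal-aware tracker
+#     current_skip_tracker() is tracker         # True: previous tracker restored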
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/stream.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..54b97c29211da62cd347cf188661ab1b41e42efd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/stream.py
@@ -0,0 +1,120 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utilities for eliminating boilerplate code to handle abstract streams with
+CPU device.
+"""
+from contextlib import contextmanager
+from typing import Generator, List, Union, cast
+
+import torch
+
+__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream",
+                      "use_device", "use_stream", "get_device", "wait_stream", "record_stream",
+                      "is_cuda", "as_cuda"]
+
+
+class CPUStreamType:
+    pass
+
+
+# The placeholder used in place of a stream for the CPU device instead of CUDA.
+CPUStream = CPUStreamType()
+
+# It represents both CUDA streams and the CPU stream.
+AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
+
+
+def new_stream(device: torch.device) -> AbstractStream:
+    """Creates a new stream for either CPU or CUDA device."""
+    if device.type != "cuda":
+        return CPUStream
+    return torch.cuda.Stream(device)
+
+
+def current_stream(device: torch.device) -> AbstractStream:
+    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
+    if device.type != "cuda":
+        return CPUStream
+    return torch.cuda.current_stream(device)
+
+
+def default_stream(device: torch.device) -> AbstractStream:
+    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
+    if device.type != "cuda":
+        return CPUStream
+    return torch.cuda.default_stream(device)
+
+
+@contextmanager
+def use_device(device: torch.device) -> Generator[None, None, None]:
+    """:func:`torch.cuda.device` for either CPU or CUDA device."""
+    if device.type != "cuda":
+        yield
+        return
+
+    with torch.cuda.device(device):
+        yield
+
+
+@contextmanager
+def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
+    """:func:`torch.cuda.stream` for either CPU or CUDA stream."""
+    if not is_cuda(stream):
+        yield
+        return
+
+    with torch.cuda.stream(as_cuda(stream)):
+        yield
+
+
+def get_device(stream: AbstractStream) -> torch.device:
+    """Gets the device from CPU or CUDA stream."""
+    if is_cuda(stream):
+        return as_cuda(stream).device
+    return torch.device("cpu")
+
+
+def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
+    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
+    makes the source stream wait until the target stream completes its queued work.
+    """
+    if is_cuda(target):
+        if is_cuda(source):
+            # A CUDA stream waits for another CUDA stream.
+            as_cuda(source).wait_stream(as_cuda(target))
+        else:
+            # The CPU waits for a CUDA stream.
+            as_cuda(target).synchronize()
+
+    # If the target is CPU, synchronization is not required.
+
+
+def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
+    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
+    if is_cuda(stream):
+        # NOTE(sublee): record_stream() on a shifted view tensor throws
+        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
+        # protect the tensor against unexpected reallocation, here we use a
+        # temporal tensor associated with the same storage without shifting as
+        # a workaround.
+        #
+        # Issue: https://github.com/pytorch/pytorch/issues/27366
+        #
+        tensor = tensor.new_empty([0]).set_(tensor._typed_storage())
+
+        # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
+        tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
+
+
+def is_cuda(stream: AbstractStream) -> bool:
+    """Returns ``True`` if the given stream is a valid CUDA stream."""
+    return stream is not CPUStream
+
+
+def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
+    """Casts the given stream as :class:`torch.cuda.Stream`."""
+    return cast(torch.cuda.Stream, stream)
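+
+# For example, the CPU placeholder flows through the same helpers as a real
+# CUDA stream (the CUDA branch assumes a CUDA-capable build; illustrative
+# comment only):
+#
+#     s = new_stream(torch.device("cpu"))
+#     is_cuda(s)          # False
+#     get_device(s)       # device(type='cpu')
+#     if torch.cuda.is_available():
+#         c = new_stream(torch.device("cuda", 0))
+#         is_cuda(c)      # True
+#         get_device(c)   # device(type='cuda', index=0)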
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/utils.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32bad20160168f8994e357176f150dc5aa34012
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/utils.py
@@ -0,0 +1,38 @@
+from torch import nn
+from typing import List, Optional
+
+__all__ = ["partition_model"]
+
+def partition_model(
+        module: nn.Sequential,
+        balance: List[int],
+        devices: Optional[List[int]] = None):
+    """
+    Partitions the model across multiple GPU devices.
+
+    Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
+    the model across multiple GPU devices according to the provided ``balance``
+    and ``devices``.
+
+    Args:
+        module (:class:`nn.Sequential <torch.nn.Sequential>`):
+            Sequential model representing the pipe.
+        balance (List[int]):
+            List indicating the number of layers in each partition.
+        devices (List[int], optional):
+            List indicating the device to use for each partition. Defaults to
+            ``range(len(balance))``
+    """
+    device_idx = 0
+    pipe_idx = 0
+    balanced_pipe = []
+    for num_layers in balance:
+        layers = []
+        for i in range(num_layers):
+            layers.append(module[pipe_idx])
+            pipe_idx += 1
+        device = device_idx if devices is None else devices[device_idx]
+        balanced_pipe.append(nn.Sequential(*layers).to(device))
+        device_idx += 1
+
+    return nn.Sequential(*balanced_pipe)
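+
+# A minimal sketch (hypothetical layer sizes; assumes at least two GPUs since
+# ``devices`` defaults to ``range(len(balance))``):
+#
+#     model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
+#     partitioned = partition_model(model, balance=[2, 1])
+#     # partition 0: Linear + ReLU on device 0; partition 1: Linear on device 1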
diff --git a/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/worker.py b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9ed8ded9a30a630f537406b9ff9aea50faa88c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/pipeline/sync/worker.py
@@ -0,0 +1,132 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""Multithreading in pipeline parallelism."""
+from contextlib import contextmanager
+from queue import Queue
+import sys
+from threading import Thread
+from types import TracebackType
+from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast
+
+import torch
+
+from .microbatch import Batch
+from .stream import AbstractStream, use_device, use_stream
+
+__all__: List[str] = ["Task", "worker", "create_workers", "spawn_workers"]
+
+
+ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
+
+# Queue is generic only in stubs.
+# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
+if TYPE_CHECKING:
+    InQueue = Queue[Optional["Task"]]
+    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
+else:
+    InQueue = Queue
+    OutQueue = Queue
+
+
+class Task:
+    """A task represents how to compute a micro-batch on a partition.
+
+    It consists of two parts: :meth:`compute` and :meth:`finalize`.
+    :meth:`compute` should be executed in worker threads concurrently.
+    :meth:`finalize` should be executed after the worker threads have finished
+    executing :meth:`compute`.
+
+    :meth:`compute` can still benefit from worker threads because the user code
+    it runs issues several CUDA API calls, and in PyTorch parallel CUDA API
+    calls are not serialized through the GIL. So more than one CUDA API call
+    can be in flight at the same time.
+
+    """
+
+    def __init__(
+        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
+    ) -> None:
+        self.stream = stream
+        self._compute = compute
+        self._finalize = finalize
+        self._grad_enabled = torch.is_grad_enabled()
+
+    def compute(self) -> Batch:
+        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
+            return self._compute()
+
+    def finalize(self, batch: Batch) -> None:
+        if self._finalize is None:
+            return
+        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
+            self._finalize(batch)
+
+
+def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:
+    """Main loop of a worker thread."""
+    with use_device(device):
+        while True:
+            task = in_queue.get()
+
+            if task is None:
+                break
+
+            try:
+                batch = task.compute()
+            except Exception:
+                exc_info = cast(ExcInfo, sys.exc_info())
+                out_queue.put((False, exc_info))
+                continue
+
+            out_queue.put((True, (task, batch)))
+
+    done = (False, None)
+    out_queue.put(done)
+
+
+def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
+    """Spawns worker threads. A worker thread is bound to a device."""
+    in_queues: List[InQueue] = []
+    out_queues: List[OutQueue] = []
+
+    # Spawn workers.
+    workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}
+
+    def normalize_device(device: torch.device) -> torch.device:
+        if device.type == "cuda" and device.index is None:
+            return torch.device("cuda", index=torch.cuda.current_device())
+
+        if device.type == "cpu" and device.index is not None:
+            return torch.device("cpu")
+
+        return device
+
+    for device in devices:
+        device = normalize_device(device)
+
+        try:
+            in_queue, out_queue = workers[device]
+        except KeyError:
+            in_queue = Queue()
+            out_queue = Queue()
+            workers[device] = (in_queue, out_queue)
+
+            t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
+            t.start()
+
+        in_queues.append(in_queue)
+        out_queues.append(out_queue)
+
+    return (in_queues, out_queues)
+
+@contextmanager
+def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
+    try:
+        (in_queues, out_queues) = create_workers(devices)
+        yield (in_queues, out_queues)
+    finally:
+        pass
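+
+
+# A minimal usage sketch (not called anywhere): pushes one Task through a CPU
+# worker created by create_workers() and reads the (success, payload) tuple
+# that worker() puts on the out queue. Batch(torch.zeros(1)) assumes the
+# microbatch.Batch constructor accepts a single tensor.
+def _example_run_single_task() -> None:
+    from .stream import CPUStream
+
+    in_queues, out_queues = create_workers([torch.device("cpu")])
+
+    # finalize is optional; compute() builds and returns the micro-batch.
+    task = Task(CPUStream, compute=lambda: Batch(torch.zeros(1)), finalize=None)
+    in_queues[0].put(task)
+
+    ok, payload = out_queues[0].get()
+    if not ok:
+        # payload is the exc_info triple captured in the worker thread.
+        raise payload[1]
+    _task, batch = payload
+
+    in_queues[0].put(None)  # ask the worker thread to exit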
diff --git a/MLPY/Lib/site-packages/torch/distributed/remote_device.py b/MLPY/Lib/site-packages/torch/distributed/remote_device.py
new file mode 100644
index 0000000000000000000000000000000000000000..45bde5aefb28f09db425233500179b97edf537cc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/remote_device.py
@@ -0,0 +1,128 @@
+from typing import Optional, Union
+
+import torch
+
+
+class _remote_device:
+    """
+    Represents a device on a remote worker.
+
+    Args:
+        remote_device (str or torch.device): Represents a device on a remote worker.
+            The string format should be one of the following:
+
+                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
+                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
+                   In addition, the device field is optional and defaults to "cpu".
+                2. "rank:<rank>/<device>", where <rank> is the rank of the
+                   process and device can be parsed as torch.device type.
+                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
+                3. <workername> and <rank> are optional; formats like "cpu"
+                    and "cuda:1" just represent local devices.
+    """
+
+    def __init__(self, remote_device: Union[str, torch.device]):
+        PARSE_ERROR = (
+            f"Could not parse remote_device: {remote_device}. The valid format is "
+            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
+        )
+        self._worker_name = None
+        self._rank = None
+        self._device: Optional[Union[str, int, torch.device]] = None
+
+        if isinstance(remote_device, torch.device):
+            self._device = remote_device
+        elif isinstance(remote_device, str):
+            fields = remote_device.split("/")
+            if len(fields) == 2:
+                self._worker_name, self._device = fields
+            elif len(fields) == 1:
+                # Check if this is a valid device.
+                if _remote_device._is_valid_local_device(fields[0]):
+                    self._device = fields[0]
+                else:
+                    self._worker_name = fields[0]
+                    self._device = "cpu"
+            else:
+                raise ValueError(PARSE_ERROR)
+        else:
+            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')
+
+        # Do some basic sanity check (no empty string)
+        if self._worker_name is not None and not self._worker_name:
+            raise ValueError(PARSE_ERROR)
+
+        # Validate the device.
+        self._device = torch.device(self._device)
+
+        # Check for rank based format.
+        if self._worker_name is not None:
+            fields = self._worker_name.split(":")
+            if len(fields) == 2:
+                # rank:/device format, extract rank
+                if fields[0] == "rank" and fields[1].isdigit():
+                    self._rank = int(fields[1])  # type: ignore[assignment]
+                    self._worker_name = None
+                else:
+                    raise ValueError(PARSE_ERROR)
+            elif len(fields) > 2:
+                raise ValueError(PARSE_ERROR)
+
+    @staticmethod
+    def _is_valid_local_device(device):
+        # Check for torch.device
+        try:
+            torch.device(device)
+            return True
+        except Exception:
+            return False
+
+    def worker_name(self) -> Optional[str]:
+        """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
+        return self._worker_name
+
+    def rank(self) -> Optional[int]:
+        """
+        Returns the rank of remote worker representing the remote device.
+        Returns ``None`` if no rank is available.
+        """
+        return self._rank
+
+    def device(self) -> torch.device:
+        """Return the local device on the remote worker."""
+        return self._device  # type: ignore[return-value]
+
+    def __repr__(self):
+        if self._device is not None:
+            if self._worker_name is not None:
+                return f'{self._worker_name}/{self._device}'
+            elif self._rank is not None:
+                return f'rank:{self._rank}/{self._device}'
+            else:
+                return str(self._device)
+        else:
+            if self._worker_name is not None:
+                return f'{self._worker_name}'
+            elif self._rank is not None:
+                return f'{self._rank}'
+            else:
+                raise RuntimeError('Invalid state!')
+
+    def __eq__(self, other):
+        if not isinstance(other, _remote_device):
+            return False
+
+        if (
+            self._worker_name == other._worker_name
+            and self._device == other._device
+            and self._rank == other._rank
+        ):
+            return True
+
+        return False
+
+
+    def __hash__(self):
+        return hash(self._worker_name) ^ \
+            hash(self._device) ^ \
+            hash(self._rank)
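+
+
+# A minimal parsing sketch (not called anywhere): exercises the three accepted
+# spellings from the class docstring above.
+def _example_remote_device_parsing() -> None:
+    d1 = _remote_device("trainer0/cuda:0")
+    assert d1.worker_name() == "trainer0"
+    assert d1.device() == torch.device("cuda:0")
+
+    d2 = _remote_device("rank:0/cpu")
+    assert d2.rank() == 0
+    assert d2.device() == torch.device("cpu")
+
+    d3 = _remote_device("cpu")  # a bare local device: no worker name, no rank
+    assert d3.worker_name() is None
+    assert d3.device() == torch.device("cpu")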
diff --git a/MLPY/Lib/site-packages/torch/distributed/rendezvous.py b/MLPY/Lib/site-packages/torch/distributed/rendezvous.py
new file mode 100644
index 0000000000000000000000000000000000000000..20b86c0e1896f50b4ff478b2df22042f51f4a429
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rendezvous.py
@@ -0,0 +1,256 @@
+try:
+    from urllib.parse import urlparse, urlunparse
+except ImportError as e:
+    raise ImportError(
+        "urllib cannot be found, urlparse from python2 is no longer supported."
+    ) from e
+
+import numbers
+import os
+import sys
+from datetime import timedelta
+from typing import Dict, Optional, Callable, Iterator, Tuple
+
+from torch.distributed import FileStore, PrefixStore, Store, TCPStore
+
+from .constants import default_pg_timeout
+
+
+_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {}
+
+
+def register_rendezvous_handler(scheme, handler):
+    """
+    Register a new rendezvous handler.
+
+    Before we can run collective algorithms, participating processes
+    need to find each other and exchange information to be able to
+    communicate. We call this process rendezvous.
+
+    The outcome of the rendezvous process is a triplet containing a
+    shared key/value store, the rank of the process, and the total
+    number of participating processes.
+
+    If none of the bundled rendezvous methods apply to your execution
+    environment you can opt to register your own rendezvous handler.
+    Pick a unique name and use the URL scheme to identify it when
+    calling the `rendezvous()` function.
+
+    Args:
+        scheme (str): URL scheme to identify your rendezvous handler.
+        handler (function): Handler that is invoked when the
+            `rendezvous()` function is called with a URL that uses
+            the corresponding scheme. It must be a generator function
+            that yields the triplet.
+    """
+    global _rendezvous_handlers
+    if scheme in _rendezvous_handlers:
+        raise RuntimeError(
+            f"Rendezvous handler for {scheme}:// already registered"
+        )
+    _rendezvous_handlers[scheme] = handler
+
+
+# Query will have format "rank=0&world_size=1" and is
+# converted into {"rank": 0, "world_size": 1}
+def _query_to_dict(query: str) -> Dict[str, str]:
+    return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
+
+
+def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
+    result = urlparse(url)
+    if world_size_opt is None:
+        world_size = -1
+        if result.scheme == "env":
+            rank = int(os.environ.get("RANK", rank))
+            # If the world_size env variable is not present then it is a dynamic group
+            world_size = int(os.environ.get("WORLD_SIZE", world_size))
+    else:
+        world_size = world_size_opt
+    if rank != -1 or world_size != -1 or world_size_opt is None:
+        query_dict = _query_to_dict(result.query)
+        assert (
+            "rank" not in query_dict and "world_size" not in query_dict
+        ), f"The url: {url} has node-specific arguments(rank, world_size) already."
+        if rank != -1:
+            query_dict["rank"] = str(rank)
+        if world_size != -1 or world_size_opt is None:
+            query_dict["world_size"] = str(world_size)
+        result = result._replace(
+            query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}"
+        )
+        url = urlunparse(result)
+
+    if result.scheme not in _rendezvous_handlers:
+        raise RuntimeError(f"No rendezvous handler for {result.scheme}://")
+    return _rendezvous_handlers[result.scheme](url, **kwargs)
+
+
+def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
+    if not isinstance(url, (str, bytes)):
+        raise RuntimeError(f"`url` must be a string. {type(url)}: {url}")
+
+    if not isinstance(rank, numbers.Integral):
+        raise RuntimeError(f"`rank` must be an integer. {rank}")
+
+    if not isinstance(world_size, numbers.Integral):
+        raise RuntimeError(f"`world_size` must be an integer. {world_size}")
+
+    return _rendezvous_helper(url, rank, world_size, **kwargs)
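+
+
+# A minimal usage sketch (not called anywhere): a single-process file://
+# rendezvous. The temporary path is a placeholder; on Windows the URL shape
+# differs slightly, as handled in _file_rendezvous_handler below.
+def _example_file_rendezvous() -> None:
+    import tempfile
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        url = f"file://{os.path.join(tmpdir, 'rendezvous_file')}"
+        store, rank, world_size = next(rendezvous(url, rank=0, world_size=1))
+        assert (rank, world_size) == (0, 1)
+        store.set("ready", "1")           # any c10d Store API call works here
+        assert store.get("ready") == b"1"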
+
+
+def _create_store_from_options(backend_options, rank):
+    store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None))
+    return store
+
+
+def _rendezvous_error(msg):
+    return ValueError("Error initializing torch.distributed using " + msg)
+
+
+def _file_rendezvous_handler(url: str, **kwargs):
+    def _error(msg):
+        return _rendezvous_error("file:// rendezvous: " + msg)
+
+    result = urlparse(url)
+    path = result.path
+    if sys.platform == "win32":
+        import urllib.request
+
+        full_path = result.netloc + result.path
+        path = urllib.request.url2pathname(full_path)
+        if path:
+            # Normalizing an empty string produces ".", which is not expected.
+            path = os.path.normpath(path)
+
+    if not path:
+        raise _error("path missing")
+    query_dict = _query_to_dict(result.query)
+    if "rank" not in query_dict:
+        raise _error("rank parameter missing")
+    if "world_size" not in query_dict:
+        raise _error("world size parameter missing")
+
+    rank = int(query_dict["rank"])
+    world_size = int(query_dict["world_size"])
+    store = FileStore(path, world_size)
+    yield (store, rank, world_size)
+
+    # If this configuration is invalidated, there is nothing we can do about it
+    raise RuntimeError("Unable to perform rerendezvous using file:// method")
+
+
+def _torchelastic_use_agent_store() -> bool:
+    return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
+
+
+def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=False) -> Store:
+    """
+    Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use the agent store.
+
+    The TCPStore server is assumed to be hosted
+    on ``hostname:port``.
+
+    If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that
+    the agent leader (node rank 0) hosts the TCPStore server (for which the
+    endpoint is specified by the given ``hostname:port``). Hence
+    ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``).
+
+    If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host
+    the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname
+    and port are correctly passed via ``hostname`` and ``port``. All
+    non-zero ranks will create and return a TCPStore client.
+    """
+    # check if port is uint16_t
+    if not 0 <= port < 2**16:
+        raise ValueError(f"port must have value from 0 to 65535 but was {port}.")
+
+    if _torchelastic_use_agent_store():
+        attempt = os.environ["TORCHELASTIC_RESTART_COUNT"]
+        tcp_store = TCPStore(hostname, port, world_size, False, timeout)
+        return PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
+    else:
+        start_daemon = rank == 0
+        return TCPStore(
+            hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv
+        )
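+
+
+# A minimal usage sketch (not called anywhere): outside of torchelastic, rank 0
+# hosts a multi-tenant TCPStore and every other rank connects as a client. The
+# host and port are placeholders for the real rank-0 endpoint.
+def _example_create_store(rank: int, world_size: int) -> Store:
+    return _create_c10d_store(
+        "localhost", 29500, rank, world_size, timedelta(seconds=30)
+    )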
+
+
+def _tcp_rendezvous_handler(
+    url: str, timeout: timedelta = default_pg_timeout, **kwargs
+):
+    def _error(msg):
+        return _rendezvous_error("tcp:// rendezvous: " + msg)
+
+    result = urlparse(url)
+    if not result.port:
+        raise _error("port number missing")
+    query_dict = _query_to_dict(result.query)
+    if "rank" not in query_dict:
+        raise _error("rank parameter missing")
+    if "world_size" not in query_dict:
+        raise _error("world size parameter missing")
+
+    rank = int(query_dict["rank"])
+    world_size = int(query_dict["world_size"])
+    use_libuv = query_dict.get("use_libuv", "0") == "1"
+    assert result.hostname is not None
+
+    store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv)
+
+    yield (store, rank, world_size)
+
+    # If this configuration is invalidated, there is nothing we can do about it
+    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")
+
+
+def _env_rendezvous_handler(
+    url: str, timeout: timedelta = default_pg_timeout, **kwargs
+):
+    def _error(msg):
+        return _rendezvous_error("env:// rendezvous: " + msg)
+
+    def _env_error(var):
+        return _error(f"environment variable {var} expected, but not set")
+
+    def _get_env_or_raise(env_var: str) -> str:
+        env_val = os.environ.get(env_var, None)
+        if not env_val:
+            raise _env_error(env_var)
+        else:
+            return env_val
+
+    result = urlparse(url)
+    query_dict = _query_to_dict(result.query)
+
+    rank: int
+    world_size: int
+    master_port: int
+    master_addr: str
+
+    if "rank" in query_dict:
+        rank = int(query_dict["rank"])
+    else:
+        rank = int(_get_env_or_raise("RANK"))
+
+    if "world_size" in query_dict:
+        world_size = int(query_dict["world_size"])
+    else:
+        world_size = int(_get_env_or_raise("WORLD_SIZE"))
+
+
+    master_addr = _get_env_or_raise("MASTER_ADDR")
+    master_port = int(_get_env_or_raise("MASTER_PORT"))
+    use_libuv = query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "0")) == "1"
+
+    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
+
+    yield (store, rank, world_size)
+
+    # If this configuration is invalidated, there is nothing we can do about it
+    raise RuntimeError("Unable to perform re-rendezvous using env:// method")
+
+
+register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
+register_rendezvous_handler("env", _env_rendezvous_handler)
+register_rendezvous_handler("file", _file_rendezvous_handler)
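+
+
+# A minimal extension sketch (left unregistered): a custom handler for a
+# hypothetical "demo" scheme, following the contract described in
+# register_rendezvous_handler() above. It reuses FileStore and expects rank
+# and world_size in the query string, mirroring the file:// handler.
+def _demo_rendezvous_handler(url: str, **kwargs):
+    result = urlparse(url)
+    query_dict = _query_to_dict(result.query)
+    rank = int(query_dict["rank"])
+    world_size = int(query_dict["world_size"])
+    store = FileStore(result.path, world_size)
+    yield (store, rank, world_size)
+    raise RuntimeError("Unable to perform re-rendezvous using demo:// method")
+
+
+# Opt in explicitly if desired:
+# register_rendezvous_handler("demo", _demo_rendezvous_handler)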
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__init__.py b/MLPY/Lib/site-packages/torch/distributed/rpc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7af4dfbcb348ce865fead98a154d543ad2c3e3e9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/__init__.py
@@ -0,0 +1,249 @@
+from datetime import timedelta
+import logging
+import os
+import threading
+import warnings
+from typing import Generator, Tuple
+from urllib.parse import urlparse
+
+import torch
+import torch.distributed as dist
+
+logger = logging.getLogger(__name__)
+
+
+_init_counter = 0
+_init_counter_lock = threading.Lock()
+
+__all__ = ["is_available"]
+
+def is_available() -> bool:
+    return hasattr(torch._C, "_rpc_init")
+
+
+if is_available() and not torch._C._rpc_init():
+    raise RuntimeError("Failed to initialize torch.distributed.rpc")
+
+
+if is_available():
+    from torch._C._distributed_c10d import Store
+    from torch._C._distributed_rpc import (
+        _disable_jit_rref_pickle,
+        _enable_jit_rref_pickle,
+        _disable_server_process_global_profiler,
+        _enable_server_process_global_profiler,
+        _set_and_start_rpc_agent,
+        _reset_current_rpc_agent,
+        _delete_all_user_and_unforked_owner_rrefs,
+        _destroy_rref_context,
+        _set_profiler_node_id,
+        _is_current_rpc_agent_set,
+        _rref_context_get_debug_info,
+        _cleanup_python_rpc_handler,
+        _invoke_rpc_builtin,
+        _invoke_rpc_python_udf,
+        _invoke_rpc_torchscript,
+        _invoke_remote_builtin,
+        _invoke_remote_python_udf,
+        _invoke_remote_torchscript,
+        _set_rpc_timeout,
+        _get_current_rpc_agent,
+        get_rpc_timeout,
+        enable_gil_profiling,
+        RpcBackendOptions,
+        _TensorPipeRpcBackendOptionsBase,
+        RpcAgent,
+        PyRRef,
+        TensorPipeAgent,
+        RemoteProfilerManager,
+        WorkerInfo,
+        _DEFAULT_INIT_METHOD,
+        _DEFAULT_NUM_WORKER_THREADS,
+        _UNSET_RPC_TIMEOUT,
+        _DEFAULT_RPC_TIMEOUT_SEC,
+    )  # noqa: F401
+
+    from . import api, backend_registry, functions
+    from .api import *  # noqa: F401,F403
+    import numbers
+
+    import torch.distributed.autograd as dist_autograd
+
+    from .backend_registry import BackendType
+    from .options import TensorPipeRpcBackendOptions  # noqa: F401
+    from .server_process_global_profiler import (
+        _server_process_global_profile,
+    )
+
+    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
+
+    __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"]
+    __all__ = __all__ + api.__all__ + backend_registry.__all__  # noqa: PLE0605
+
+    def init_rpc(
+        name,
+        backend=None,
+        rank=-1,
+        world_size=None,
+        rpc_backend_options=None,
+    ):
+        r"""
+        Initializes RPC primitives such as the local RPC agent
+        and distributed autograd, which immediately makes the current
+        process ready to send and receive RPCs.
+
+        Args:
+            name (str): a globally unique name of this node. (e.g.,
+                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
+                Name can only contain numbers, letters, underscores, colons,
+                and/or dashes, and must be shorter than 128 characters.
+            backend (BackendType, optional): The type of RPC backend
+                implementation. The only supported value is
+                ``BackendType.TENSORPIPE`` (the default).
+                See :ref:`rpc-backends` for more information.
+            rank (int): a globally unique id/rank of this node.
+            world_size (int): The number of workers in the group.
+            rpc_backend_options (RpcBackendOptions, optional): The options
+                passed to the RpcAgent constructor. It must be an agent-specific
+                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
+                and contain agent-specific initialization configurations. By
+                default, for all agents, it sets the default timeout to 60
+                seconds and performs the rendezvous with an underlying process
+                group initialized using ``init_method = "env://"``,
+                meaning that environment variables ``MASTER_ADDR`` and
+                ``MASTER_PORT`` need to be set properly. See
+                :ref:`rpc-backends` for more information and find which options
+                are available.
+        """
+        torch._C._log_api_usage_once("torch.distributed.init_rpc")
+        if backend is not None and not isinstance(
+            backend, backend_registry.BackendType
+        ):
+            raise TypeError("Argument backend must be a member of BackendType")
+
+        if rpc_backend_options is not None and not isinstance(
+            rpc_backend_options, RpcBackendOptions
+        ):
+            raise TypeError(
+                "Argument rpc_backend_options must be an instance of RpcBackendOptions"
+            )
+
+        # Try to detect the backend from the options
+        if backend is None and rpc_backend_options is not None:
+            for candidate_backend in BackendType:
+                if isinstance(
+                    rpc_backend_options,
+                    type(
+                        backend_registry.construct_rpc_backend_options(
+                            candidate_backend
+                        )
+                    ),
+                ):
+                    backend = candidate_backend
+                    break
+            else:
+                raise TypeError(
+                    f"Could not infer backend for options {rpc_backend_options}"
+                )
+            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
+            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
+                logger.warning(
+                    "RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
+                    "corresponding to %(backend)s, hence that backend will be used "
+                    "instead of the default BackendType.TENSORPIPE. To silence this "
+                    "warning pass `backend=%(backend)s` explicitly.",
+                    {'backend': backend}
+                )
+
+        if backend is None:
+            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]
+
+        if rpc_backend_options is None:
+            # default construct a set of RPC backend options.
+            rpc_backend_options = backend_registry.construct_rpc_backend_options(
+                backend
+            )
+
+        # Create store, performs rendezvous for static RPC group.
+        if not world_size:
+            # If world_size is not set at construction time and also not set in the
+            # environment variables, the store will be created for the dynamic group setting.
+            store = dist._create_store_from_options(rpc_backend_options, rank)
+        else:
+            # This rendezvous state sometimes is destroyed before all processes
+            # finish handshaking. To avoid that issue, we make it global to
+            # keep it alive.
+            global rendezvous_iterator
+            rendezvous_iterator = dist.rendezvous(
+                rpc_backend_options.init_method, rank=rank, world_size=world_size
+            )
+            store, _, _ = next(rendezvous_iterator)
+        # Use same timeout as RPC.
+        store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout))
+
+        # Use a PrefixStore to distinguish multiple invocations.
+        with _init_counter_lock:
+            global _init_counter
+            store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store)
+            _init_counter += 1
+
+        # Initialize autograd before RPC since _init_rpc_backend guarantees all
+        # processes sync via the store. If we initialize autograd after RPC,
+        # there could be a race where some nodes might have initialized autograd
+        # and others might not have. As a result, a node calling
+        # torch.distributed.autograd.backward() would run into errors since
+        # other nodes might not have been initialized.
+        dist_autograd._init(rank)
+
+        _set_profiler_node_id(rank)
+        # Initialize RPC.
+        _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
+
+    def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
+        type_mapping = {
+            backend: backend_registry.BackendType,
+            store: dist.Store,
+            name: str,
+            rank: numbers.Integral,
+            # world_size can be None for a dynamic group
+            world_size: (numbers.Integral, type(None)),
+            rpc_backend_options: RpcBackendOptions,
+        }
+        for arg, arg_type in type_mapping.items():
+            if not isinstance(arg, arg_type):  # type: ignore[arg-type]
+                raise RuntimeError(
+                    f"Argument {arg} must be of type {arg_type} but got type {type(arg)}"
+                )
+
+    def _init_rpc_backend(
+        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
+        store=None,
+        name=None,
+        rank=-1,
+        world_size=None,
+        rpc_backend_options=None,
+    ):
+
+        _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options)
+
+        if _is_current_rpc_agent_set():
+            raise RuntimeError("RPC is already initialized")
+
+        # Initialize RPC.
+        rpc_agent = backend_registry.init_backend(
+            backend,
+            store=store,
+            name=name,
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        api._init_rpc_states(rpc_agent)
+
+    @api._require_initialized
+    def _get_debug_info():
+        info = _rref_context_get_debug_info()
+        info.update(api._get_current_rpc_agent().get_debug_info())
+        info.update(dist_autograd._get_debug_info())
+        return info
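+
+    # A minimal usage sketch (not called anywhere): a single-worker RPC group.
+    # The MASTER_ADDR/MASTER_PORT values are placeholders; a real deployment
+    # sets them per cluster and uses world_size > 1.
+    def _example_single_worker_rpc():
+        os.environ.setdefault("MASTER_ADDR", "localhost")
+        os.environ.setdefault("MASTER_PORT", "29500")
+        init_rpc("worker0", rank=0, world_size=1)
+        # With a single worker, an RPC to ourselves exercises the full path.
+        result = rpc_sync("worker0", torch.add, args=(torch.ones(2), 1))
+        assert bool((result == 2).all())
+        shutdown()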
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9b737902c3b858523709a5890947d89e1c9307a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3de56f5bd7849de63dea094b509526d0c58d786
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dabaaaa871a7ebb77b5856bba0a54353b03d63d2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..024bf88af18e2083d7956f9425ff1d36339b95fa
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e4749b6d3390c66736685ca17c71dd4303889cf
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ada019d95c2758157aae69994e443bdad993018f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31be07c9137041ea6d49004c7419839309962c1f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fefff5c856f1b4cc82ca26a8dab977b20257558
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49b52870da33581e59fec3915be2c8c58fa9706d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80e125354f502fa35970dbb34b380548e2d6af9a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..311c67fae2eba7a453a8713e2195cdcd9c2a83cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py
@@ -0,0 +1,18 @@
+
+import torch
+
+
+def is_available():
+    return hasattr(torch._C, "_faulty_agent_init")
+
+
+if is_available() and not torch._C._faulty_agent_init():
+    raise RuntimeError("Failed to initialize torch.distributed.rpc._testing")
+
+if is_available():
+    # Registers FAULTY_TENSORPIPE RPC backend.
+    from . import faulty_agent_backend_registry
+    from torch._C._distributed_rpc_testing import (
+        FaultyTensorPipeRpcBackendOptions,
+        FaultyTensorPipeAgent,
+    )
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..575dcb52396833511947ccf7f3c7234889701c42
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db87aaf03444ad81290ce1c13892b22648aed4d7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad053948e0d55151e45cb1fbad687641d4930e28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
+
+def _faulty_tensorpipe_construct_rpc_backend_options_handler(
+    rpc_timeout,
+    init_method,
+    num_worker_threads,
+    messages_to_fail,
+    messages_to_delay,
+    num_fail_sends,
+    **kwargs
+):
+    from . import FaultyTensorPipeRpcBackendOptions
+
+    return FaultyTensorPipeRpcBackendOptions(
+        num_worker_threads=num_worker_threads,
+        rpc_timeout=rpc_timeout,
+        init_method=init_method,
+        messages_to_fail=messages_to_fail,
+        messages_to_delay=messages_to_delay,
+        num_fail_sends=num_fail_sends,
+    )
+
+
+def _faulty_tensorpipe_init_backend_handler(
+    store, name, rank, world_size, rpc_backend_options
+):
+    from . import FaultyTensorPipeAgent
+    from . import FaultyTensorPipeRpcBackendOptions
+    from torch.distributed.rpc import api
+
+    if not isinstance(store, dist.Store):
+        raise TypeError(f"`store` must be a c10d::Store. {store}")
+
+    if not isinstance(
+        rpc_backend_options, FaultyTensorPipeRpcBackendOptions
+    ):
+        raise TypeError(
+            f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}"
+        )
+
+    agent = FaultyTensorPipeAgent(
+        store,
+        name,
+        rank,
+        world_size,
+        rpc_backend_options,
+        {},  # reverse_device_map
+        [],  # devices
+    )
+    api._init_rpc_states(agent)
+
+    return agent
+
+
+rpc.backend_registry.register_backend(
+    "FAULTY_TENSORPIPE",
+    _faulty_tensorpipe_construct_rpc_backend_options_handler,
+    _faulty_tensorpipe_init_backend_handler,
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/_utils.py b/MLPY/Lib/site-packages/torch/distributed/rpc/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af5299e19240a15da940406d0a3918aaa9d59cce
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/_utils.py
@@ -0,0 +1,37 @@
+from contextlib import contextmanager
+from typing import cast
+import logging
+from . import api
+from . import TensorPipeAgent
+
+logger = logging.getLogger(__name__)
+
+@contextmanager
+def _group_membership_management(store, name, is_join):
+    token_key = "RpcGroupManagementToken"
+    join_or_leave = "join" if is_join else "leave"
+    my_token = f"Token_for_{name}_{join_or_leave}"
+    while True:
+        # Retrieve token from store to signal start of rank join/leave critical section
+        returned = store.compare_set(token_key, "", my_token).decode()
+        if returned == my_token:
+            # Yield to the function this context manager wraps
+            yield
+            # Finished, now exit and release token
+            # Update from store to signal end of rank join/leave critical section
+            store.set(token_key, "")
+            # Other will wait for this token to be set before they execute
+            store.set(my_token, "Done")
+            break
+        else:
+            # Store will wait for the token to be released
+            try:
+                store.wait([returned])
+            except RuntimeError:
+                logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned)
+                raise
+
+def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
+    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
+    ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join)
+    return ret
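+
+# A minimal usage sketch (not called anywhere): the token protocol above turned
+# into a critical section around a join. The store endpoint and port are
+# placeholders; in practice the agent's own store is passed in.
+def _example_group_membership_token():
+    from datetime import timedelta
+    import torch.distributed as dist
+
+    # Rank 0 hosts the store in this single-process sketch.
+    store = dist.TCPStore("localhost", 29501, 1, True, timedelta(seconds=30))
+    with _group_membership_management(store, "worker0", is_join=True):
+        # Rank join/leave bookkeeping runs here, one rank at a time.
+        pass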
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/api.py b/MLPY/Lib/site-packages/torch/distributed/rpc/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c2006a93bd7869f86bbc3cb4bb0169d9c9a1d7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/api.py
@@ -0,0 +1,947 @@
+__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
+           "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
+
+import collections
+import contextlib
+import functools
+import inspect
+import logging
+import threading
+from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING
+
+import torch
+from torch.futures import Future
+
+from torch._C._distributed_rpc import (
+    PyRRef,
+    RemoteProfilerManager,
+    WorkerInfo,
+    TensorPipeAgent,
+    get_rpc_timeout,
+    _cleanup_python_rpc_handler,
+    _delete_all_user_and_unforked_owner_rrefs,
+    _destroy_rref_context,
+    _get_current_rpc_agent,
+    _invoke_remote_builtin,
+    _invoke_remote_python_udf,
+    _invoke_remote_torchscript,
+    _invoke_rpc_builtin,
+    _invoke_rpc_python_udf,
+    _invoke_rpc_torchscript,
+    _is_current_rpc_agent_set,
+    _reset_current_rpc_agent,
+    _set_and_start_rpc_agent,
+)
+
+from .internal import (
+    PythonUDF,
+    RPCExecMode,
+    _internal_rpc_pickler,
+    _build_rpc_profiling_key,
+)
+
+from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
+
+from ._utils import _group_membership_management, _update_group_membership
+
+logger = logging.getLogger(__name__)
+
+# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
+# make sure there are no references to any RRef in the application code and
+# that the Python GC has done its job to delete those RRefs. This could result
+# in a bad debugging experience, especially for large applications. Therefore,
+# by default, we are going to ignore RRef leaks during shutdown. This is usually
+# fine as shutdown means applications have done training and no longer care
+# about states.
+#
+# To enable RRef leak checking, set this _ignore_rref_leak to False
+_ignore_rref_leak = True
+_default_pickler = _internal_rpc_pickler
+
+@contextlib.contextmanager
+def _use_rpc_pickler(rpc_pickler):
+    r"""
+    rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
+    """
+    global _default_pickler
+    _default_pickler = rpc_pickler
+    try:
+        yield
+    finally:
+        _default_pickler = _internal_rpc_pickler
+
+
+def _require_initialized(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if not _is_current_rpc_agent_set():
+            raise RuntimeError(
+                "RPC has not been initialized. Call "
+                "torch.distributed.rpc.init_rpc first."
+            )
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+class AllGatherStates:
+    def __init__(self):
+        # Each `gathered_objects` is an empty dict at beginning.
+        # The leader worker is elected as the first worker in a sorted worker
+        # name list. Whenever there is a worker entering `_all_gather()`, it
+        # runs `_gather_to_leader()` on the leader to add its own name and
+        # data obj to this dict. The leader also adds its own name to the dict
+        # on calling `_all_gather()`.
+        # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
+        # will broadcast the gathered dict to all follower workers and set their
+        # `gathered_objects` field and the `proceed_signal` field.
+        self.gathered_objects = {}
+        # All workers wait on this signal until it receives all gathered
+        # objects.
+        self.proceed_signal = threading.Event()
+
+
+# States used by `def _all_gather()`.
+# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer.
+_ALL_WORKER_NAMES: Set[Any] = set()
+_all_gather_dict_lock = threading.RLock()
+_all_gather_sequence_id: Dict[str, int] = {}
+_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
+
+
+def _init_rpc_states(agent):
+    worker_infos = agent.get_worker_infos()
+    global _ALL_WORKER_NAMES
+    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
+
+    # NB: backend implementation might have already set the rpc_agent.
+    if not _is_current_rpc_agent_set():
+        _set_and_start_rpc_agent(agent)
+
+
+def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
+    with _all_gather_dict_lock:
+        if not worker_names:
+            worker_names = _ALL_WORKER_NAMES
+            assert (
+                worker_name in worker_names
+            ), f"{worker_name} is not expected by leader."
+        states = _all_gather_sequence_id_to_states[sequence_id]
+        assert (
+            worker_name not in states.gathered_objects
+        ), f"{worker_name} reported intent sequence id {sequence_id} twice. "
+        states.gathered_objects[worker_name] = obj
+        if worker_names == set(states.gathered_objects.keys()):
+            states.proceed_signal.set()
+
+
+def _broadcast_to_followers(sequence_id, objects_map):
+    with _all_gather_dict_lock:
+        states = _all_gather_sequence_id_to_states[sequence_id]
+
+    assert (
+        not states.proceed_signal.is_set()
+    ), f"Termination signal sequence id {sequence_id} got set twice."
+    states.gathered_objects = objects_map
+    states.proceed_signal.set()
+
+_thread_local_var = threading.local()
+
+
+@contextlib.contextmanager
+def _wait_all():
+    r"""
+    A context manager that collects all futures returned by ``rpc_async`` and
+    waits on them when the context manager exits, relieving the user of the
+    need to call ``wait`` explicitly.
+
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> with rpc._wait_all():
+        >>>    fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+        >>>    fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+        >>> #fut_1 and fut_2 are waited on
+    """
+    _thread_local_var.future_list = []
+    try:
+        yield
+    finally:
+        try:
+            torch.futures.wait_all(_thread_local_var.future_list)
+        finally:
+            del _thread_local_var.future_list
+
+
+@_require_initialized
+def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
+    r"""
+    This is similar to torch.distributed.all_gather(), but uses RPC. It
+    picks the worker with the smallest name (alphabetical order) as the leader.
+    Then all followers send their data ``obj`` to the leader. After the leader
+    has received everything, it broadcasts the results back to all followers.
+    This function blocks until all workers have received the gathered results.
+    """
+    if not worker_names:
+        assert (
+            _ALL_WORKER_NAMES is not None
+        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
+        worker_names = _ALL_WORKER_NAMES
+    leader_name = min(worker_names)
+
+    self_name = _get_current_rpc_agent().get_worker_info().name
+
+    with _all_gather_dict_lock:
+        concat_names = "".join(sorted(worker_names))
+        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
+        _all_gather_sequence_id[concat_names] = sequence_num + 1
+        sequence_id = concat_names + str(sequence_num)
+
+    is_leader = leader_name == self_name
+
+    if timeout == UNSET_RPC_TIMEOUT:
+        # Timeout is specified by agent for RPC calls
+        rpc_timeout = get_rpc_timeout()
+        # No timeout for signal
+        signal_timeout = None
+    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
+        # No timeout for RPC
+        rpc_timeout = timeout
+        # No timeout for signal
+        signal_timeout = None
+    else:
+        # Signal and RPC timeout use the same timeout
+        signal_timeout = rpc_timeout = timeout
+
+    # Phase 1: Followers send their objects to the leader
+    if is_leader:
+        _gather_to_leader(sequence_id, self_name, obj, worker_names)
+    else:
+        rpc_sync(
+            leader_name,
+            _gather_to_leader,
+            args=(sequence_id, self_name, obj, worker_names),
+            timeout=rpc_timeout,
+        )
+
+    with _all_gather_dict_lock:
+        states = _all_gather_sequence_id_to_states[sequence_id]
+
+    # Timeout is either set by function parameter or None (which is indefinite)
+    states.proceed_signal.wait(timeout=signal_timeout)
+
+    # Phase 2: Leader broadcast gathered results to all followers
+    # Leader's signal is the first to be unblocked, after receiving all
+    # followers' data objects.
+    if is_leader:
+        worker_name_to_response_future_dict = {}
+        for follower_name in worker_names - {leader_name}:
+            fut = rpc_async(
+                follower_name,
+                _broadcast_to_followers,
+                args=(sequence_id, states.gathered_objects),
+                timeout=rpc_timeout
+            )
+            worker_name_to_response_future_dict[follower_name] = fut
+
+        errors = []
+        for follower_name, fut in worker_name_to_response_future_dict.items():
+            try:
+                fut.wait()
+            except RuntimeError as ex:
+                errors.append((follower_name, ex))
+
+        if errors:
+            raise RuntimeError(
+                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
+                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
+            )
+
+    # Clean up for the states using the sequence_id
+    with _all_gather_dict_lock:
+        states = _all_gather_sequence_id_to_states.pop(sequence_id)
+    return states.gathered_objects
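+
+
+# A minimal usage sketch (not called anywhere): every initialized worker calls
+# this with its own payload and gets back the same name -> payload dict once
+# the leader has heard from, and broadcast to, everyone.
+def _example_gather_device_counts():
+    return _all_gather(torch.cuda.device_count())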
+
+
+@_require_initialized
+def _barrier(worker_names):
+    r"""
+    Synchronizes local and remote RPC processes.
+
+    This will block until all local and remote RPC processes specified under worker_names
+    reach this method to wait for all outstanding work to complete.
+
+    Args:
+        worker_names (List[str]): The set of workers to synchronize.
+
+    """
+    try:
+        _all_gather(None, set(worker_names))
+    except RuntimeError as ex:
+        logger.error(
+            "Failed to complete barrier, got error %s", ex
+        )
+
+
+@_require_initialized
+def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
+    r"""
+    Block until all local and remote RPC processes reach this method and wait
+    for all outstanding work to complete. Every RPC process must call this
+    method before exit to perform a graceful shutdown. This should be used to
+    terminate the RPC framework, and there is no guarantee that the RPC
+    framework will work after this method returns.
+    """
+    try:
+        _all_gather(None, timeout=timeout)
+    except RuntimeError as ex:
+        logger.error(
+            "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex
+        )
+        raise ex
+
+
+@_require_initialized
+def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
+    r"""
+    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
+    stops the local agent from accepting outstanding requests, and shuts
+    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
+    this will block until all local and remote RPC processes reach this method
+    and wait for all outstanding work to complete. Otherwise, if
+    ``graceful=False``, this is a local shutdown, and it does not wait for other
+    RPC processes to reach this method.
+
+    .. warning::
+        For :class:`~torch.futures.Future` objects returned by
+        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
+        be called after ``shutdown()``.
+
+    Args:
+        graceful (bool): Whether to do a graceful shutdown or not. If True,
+                         this will 1) wait until there is no pending system
+                         messages for ``UserRRefs`` and delete them; 2) block
+                         until all local and remote RPC processes have reached
+                         this method and wait for all outstanding work to
+                         complete.
+
+    Example::
+        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        export MASTER_ADDR=localhost
+        export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # xdoctest: +SKIP
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> # do some work
+        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
+        >>> # ready to shutdown
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> # wait for worker 0 to finish work, and then shutdown.
+        >>> rpc.shutdown()
+    """
+    if graceful:
+        try:
+            agent = _get_current_rpc_agent()
+            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
+                _wait_all_workers(timeout)
+                _delete_all_user_and_unforked_owner_rrefs()
+                agent.join(shutdown=True, timeout=timeout)
+            else:
+                # This is a dynamic group so we need to grab the token for the operation
+                my_worker_info = agent.get_worker_info()
+                my_name = my_worker_info.name
+                with _group_membership_management(agent.store, my_name, False):
+                    all_worker_infos = agent.get_worker_infos()
+                    for worker in all_worker_infos:
+                        if worker.name != my_name:
+                            rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
+                    agent.join(shutdown=True, timeout=timeout)
+        finally:
+            # In case of errors, continue to complete the local shutdown.
+            _finalize_shutdown()
+    else:
+        _finalize_shutdown()
+
+
+def _finalize_shutdown():
+    try:
+        # This raises a `TORCH_CHECK()` exception on RRef leak detected.
+        _destroy_rref_context(_ignore_rref_leak)
+    finally:
+        _get_current_rpc_agent().shutdown()
+        # clean up python rpc handler in shutdown(), see comments in
+        # PythonRpcHandler::cleanup(), call it in python API because the
+        # cleanup() function has python dependency, it assumes python
+        # interpreter exists.
+        # No matter if RRef leak exception is raised, this clean-up code
+        # must run to avoid destruction segfault in Python 3.5.
+        #
+        # future.wait() should not be called after shutdown().
+        # pythonRpcHandler is cleaned up in shutdown(), after
+        # shutdown(), python objects returned from rpc python call can not be
+        # resolved.
+        _cleanup_python_rpc_handler()
+        _reset_current_rpc_agent()
+
+
+@_require_initialized
+def get_worker_info(worker_name=None):
+    r"""
+    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
+    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
+    expensive string on every invocation.
+
+    Args:
+        worker_name (str): the string name of a worker. If ``None``, return the
+                           id of the current worker. (default ``None``)
+
+    Returns:
+        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
+        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
+        current worker if ``worker_name`` is ``None``.
+    """
+    if worker_name is not None:
+        return _get_current_rpc_agent().get_worker_info(worker_name)
+    else:
+        return _get_current_rpc_agent().get_worker_info()
+
+
+def _to_worker_info(to):
+    if isinstance(to, WorkerInfo):
+        return to
+    elif isinstance(to, (str, int)):
+        return get_worker_info(to)
+    else:
+        raise ValueError(f"Cannot get WorkerInfo from name {to}")
+
+
+def _rref_typeof_on_owner(rref, blocking: bool = True):
+    rref_type = type(rref.local_value())
+    if blocking:
+        return rref_type
+    else:
+        # Wrap result into a completed Future. This is so that if blocking=False
+        # is specified, we return a future regardless of whether this call is on
+        # the user or the owner.
+        future = Future[type]()
+        future.set_result(rref_type)
+        return future
+
+
+def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
+    fut = rpc_async(
+        rref.owner(),
+        _rref_typeof_on_owner,
+        args=(rref,),
+        timeout=timeout
+    )
+    if blocking:
+        return fut.wait()
+    else:
+        return fut
+
+
+T = TypeVar("T")
+GenericWithOneTypeVar = Generic[T]
+
+
+if TYPE_CHECKING:
+    class RRef(PyRRef[T], Generic[T]):
+        pass
+else:
+    try:
+        # Combine the implementation class and the type class.
+        class RRef(PyRRef, Generic[T]):
+            pass
+    except TypeError:
+        # TypeError: metaclass conflict: the metaclass of a derived class
+        # must be a (non-strict) subclass of the metaclasses of all its bases
+        # Mypy doesn't understand __class__ (mypy bug #4177)
+        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
+            pass
+
+        # Combine the implementation class and the type class.
+        # Types for classes expecting a certain generic parameter (mypy bug #7791)
+        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
+            pass
+
+
+# Install docstrings from `PyRRef` to `RRef`.
+#
+# This is needed because pybind11 generates the parameter
+# `self` as type `rpc.PyRRef`, so a `:inherited-members:`
+# under `.. autoclass:: RRef` does not work.
+# We have to do the following to replace `rpc.PyRRef` with `rpc.RRef`.
+#
+def method_factory(method_name, docstring):
+    def method(self, *args, **kwargs):
+        return getattr(super(RRef, self), method_name)(*args, **kwargs)
+
+    if method.__doc__:
+        method.__doc__ = docstring
+    return method
+
+
+for method_name, method in inspect.getmembers(PyRRef):
+    # Ignore magic methods, except "__str__".
+    if method_name.startswith("_") and method_name != "__str__":
+        continue
+
+    # Get the pybind11-generated docstring.
+    # It looks like:
+    """
+    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
+
+        Blocking call that copies the value of the RRef from the owner
+        to the local node and returns it. If the current node is the
+        owner, returns a reference to the local value.
+    """
+    docstring = getattr(method, "__doc__", None)
+    assert docstring is not None, "RRef user-facing methods should all have docstrings."
+
+    # Do surgery on pybind11 generated docstrings.
+    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")
+
+    # Attach user-facing RRef method with modified docstring.
+    new_method = method_factory(method_name, docstring)
+    setattr(RRef, method_name, new_method)
+
+
+@_require_initialized
+def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
+    r"""
+    Make a remote call to run ``func`` on worker ``to`` and return an
+    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
+    Worker ``to`` will be the owner of the returned
+    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
+    a user. The owner manages the global reference count of its
+    :class:`~torch.distributed.rpc.RRef`, and the owner
+    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
+    are no living references to it.
+
+    Args:
+        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
+        func (Callable): a callable function, such as Python callables, builtin
+                         operators (e.g. :meth:`~torch.add`) and annotated
+                         TorchScript functions.
+        args (tuple): the argument tuple for the ``func`` invocation.
+        kwargs (dict): a dictionary of keyword arguments for the ``func``
+                       invocation.
+
+        timeout (float, optional): timeout in seconds for this remote call. If the
+                                   creation of this
+                                   :class:`~torch.distributed.rpc.RRef` on worker
+                                   ``to`` is not successfully processed on this
+                                   worker within this timeout, then the next time
+                                   there is an attempt to use the RRef (such as
+                                   ``to_here()``), a timeout will be raised
+                                   indicating this failure. A value of 0 indicates
+                                   an infinite timeout, i.e. a timeout error will
+                                   never be raised. If not provided, the default
+                                   value set during initialization or with
+                                   ``_set_rpc_timeout`` is used.
+
+    Returns:
+        A user :class:`~torch.distributed.rpc.RRef` instance to the result
+        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
+        to retrieve the result value locally.
+
+    .. warning ::
+        The ``remote`` API does not copy storages of argument tensors until
+        sending them over the wire, which could be done by a different thread
+        depending on the RPC backend type. The caller should make sure that the
+        contents of those tensors stay intact until the returned RRef is
+        confirmed by the owner, which can be checked using the
+        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.
+
+    .. warning ::
+        Errors such as timeouts for the ``remote`` API are handled on a
+        best-effort basis: when a call initiated by ``remote`` fails, for
+        example with a timeout, the error is handled and set on the resulting
+        RRef asynchronously. If the RRef has not been used by the application
+        before this handling (such as ``to_here`` or a fork call), then later
+        uses of the ``RRef`` will appropriately raise the error. However, the
+        application may use the ``RRef`` before the error has been handled, in
+        which case the error may not be raised because it has not yet been
+        processed.
+
+    Example::
+
+        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        export MASTER_ADDR=localhost
+        export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # xdoctest: +SKIP
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
+        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
+        >>> x = rref1.to_here() + rref2.to_here()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
+        Below is an example of running a TorchScript function using RPC.
+
+        >>> # On both workers:
+        >>> @torch.jit.script
+        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
+        >>>    return torch.add(tensor, scalar)
+
+        >>> # On worker 0:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
+        >>> rref.to_here()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+    """
+    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
+    qualified_name = torch.jit._builtins._find_builtin(func)
+    dst_worker_info = _to_worker_info(to)
+    should_profile = _get_should_profile()
+
+    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
+
+    with ctx_manager as rf:
+        args = args if args else ()
+        kwargs = kwargs if kwargs else {}
+
+        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
+
+        if is_async_exec:
+            wrapped = func._wrapped_async_rpc_function
+            if isinstance(wrapped, torch.jit.ScriptFunction):
+                func = wrapped
+
+        if qualified_name is not None:
+            rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
+        elif isinstance(func, torch.jit.ScriptFunction):
+            rref = _invoke_remote_torchscript(
+                dst_worker_info.name,
+                torch._jit_internal._qualified_name(func),
+                timeout,
+                is_async_exec,
+                *args,
+                **kwargs,
+            )
+        else:
+            (pickled_python_udf, tensors) = _default_pickler.serialize(
+                PythonUDF(func, args, kwargs)
+            )
+            rref = _invoke_remote_python_udf(
+                dst_worker_info,
+                pickled_python_udf,
+                tensors,
+                timeout,
+                is_async_exec
+            )
+        # attach profiling information
+        if should_profile:
+            assert torch.autograd._profiler_enabled()
+            assert rf is not None
+            fut = rf._call_end_callbacks_on_future(rref._get_future())
+            rref._set_profiling_future(fut)
+
+    return rref
+
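+# A minimal sketch (illustration only, not part of the original module) of the
+# `confirmed_by_owner` guidance in the warning above: keep argument tensors
+# intact until the owner has confirmed the RRef. The worker name, the sleep
+# interval, and `import time` are assumptions.
+#
+# >>> t = torch.ones(2)
+# >>> rref = rpc.remote("worker1", torch.add, args=(t, 1))
+# >>> while not rref.confirmed_by_owner():
+# ...     time.sleep(0.01)  # do not mutate `t` while the copy may be pending
+# >>> t.add_(1)             # safe to modify `t` now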
+
+def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
+    if not callable(func):
+        raise TypeError("function should be callable.")
+
+    qualified_name = torch.jit._builtins._find_builtin(func)
+    dst_worker_info = _to_worker_info(to)
+
+    should_profile = _get_should_profile()
+
+    ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
+
+    with ctx_manager as rf:
+        args = args if args else ()
+        kwargs = kwargs if kwargs else {}
+
+        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
+
+        if is_async_exec:
+            wrapped = func._wrapped_async_rpc_function
+            if isinstance(wrapped, torch.jit.ScriptFunction):
+                func = wrapped
+
+        if qualified_name is not None:
+            fut = _invoke_rpc_builtin(
+                dst_worker_info,
+                qualified_name,
+                rpc_timeout,
+                *args,
+                **kwargs
+            )
+        elif isinstance(func, torch.jit.ScriptFunction):
+            fut = _invoke_rpc_torchscript(
+                dst_worker_info.name,
+                torch._jit_internal._qualified_name(func),
+                args,
+                kwargs,
+                rpc_timeout,
+                is_async_exec
+            )
+        else:
+            (pickled_python_udf, tensors) = _default_pickler.serialize(
+                PythonUDF(func, args, kwargs)
+            )
+            fut = _invoke_rpc_python_udf(
+                dst_worker_info,
+                pickled_python_udf,
+                tensors,
+                rpc_timeout,
+                is_async_exec
+            )
+        if should_profile:
+            assert torch.autograd._profiler_enabled()
+            assert rf is not None
+            # Schedule profiling callbacks to run when the future completes.
+            # This returns a future that is completed when the original future
+            # completes and the profiling callbacks have been completed as well,
+            # to guarantee that fut.wait() completes the profiling. This new
+            # future will contain the same value as the original future.
+            fut = rf._call_end_callbacks_on_future(fut)
+    return fut
+
+
+@_require_initialized
+def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
+    r"""
+    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
+    messages are sent and received in parallel to execution of Python code. This
+    method is thread-safe.
+
+    Args:
+        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
+        func (Callable): a callable function, such as Python callables, builtin
+                         operators (e.g. :meth:`~torch.add`) and annotated
+                         TorchScript functions.
+        args (tuple): the argument tuple for the ``func`` invocation.
+        kwargs (dict): a dictionary of keyword arguments for the ``func``
+                       invocation.
+        timeout (float, optional): timeout in seconds to use for this RPC. If
+                                   the RPC does not complete in this amount of
+                                   time, an exception indicating it has
+                                   timed out will be raised. A value of 0
+                                   indicates an infinite timeout, i.e. a timeout
+                                   error will never be raised. If not provided,
+                                   the default value set during initialization
+                                   or with ``_set_rpc_timeout`` is used.
+
+    Returns:
+        Returns the result of running ``func`` with ``args`` and ``kwargs``.
+
+    Example::
+        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        export MASTER_ADDR=localhost
+        export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # xdoctest: +SKIP
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
+        Below is an example of running a TorchScript function using RPC.
+
+        >>> # On both workers:
+        >>> @torch.jit.script
+        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
+        >>>    return torch.add(tensor, scalar)
+
+        >>> # On worker 0:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
+    """
+    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
+    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
+    return fut.wait()
+
+
+@_require_initialized
+def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
+    r"""
+    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
+    messages are sent and received in parallel to execution of Python code. This
+    method is thread-safe. This method will immediately return a
+    :class:`~torch.futures.Future` that can be awaited on.
+
+    Args:
+        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
+        func (Callable): a callable function, such as Python callables, builtin
+                         operators (e.g. :meth:`~torch.add`) and annotated
+                         TorchScript functions.
+        args (tuple): the argument tuple for the ``func`` invocation.
+        kwargs (dict): a dictionary of keyword arguments for the ``func``
+                       invocation.
+        timeout (float, optional): timeout in seconds to use for this RPC. If
+                                   the RPC does not complete in this amount of
+                                   time, an exception indicating it has
+                                   timed out will be raised. A value of 0
+                                   indicates an infinite timeout, i.e. a timeout
+                                   error will never be raised. If not provided,
+                                   the default value set during initialization
+                                   or with ``_set_rpc_timeout`` is used.
+
+
+    Returns:
+        Returns a :class:`~torch.futures.Future` object that can be waited
+        on. When completed, the return value of ``func`` on ``args`` and
+        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
+        object.
+
+    .. warning ::
+        Using GPU tensors as arguments or return values of ``func`` is not
+        supported, because GPU tensors cannot be sent over the wire. Explicitly
+        copy GPU tensors to CPU before using them as arguments or return values
+        of ``func``.
+
+    .. warning ::
+        The ``rpc_async`` API does not copy storages of argument tensors until
+        sending them over the wire, which could be done by a different thread
+        depending on the RPC backend type. The caller should make sure that the
+        contents of those tensors stay intact until the returned
+        :class:`~torch.futures.Future` completes.
+
+    Example::
+        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        export MASTER_ADDR=localhost
+        export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # xdoctest: +SKIP
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
+        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
+        >>> result = fut1.wait() + fut2.wait()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
+        Below is an example of running a TorchScript function using RPC.
+
+        >>> # On both workers:
+        >>> @torch.jit.script
+        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
+        >>>    return torch.add(tensor, scalar)
+
+        >>> # On worker 0:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
+        >>> ret = fut.wait()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+    """
+    torch._C._log_api_usage_once("torch.distributed.rpc_async")
+    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
+    if hasattr(_thread_local_var, "future_list"):
+        _thread_local_var.future_list.append(fut)
+    return fut
+
+
+def _get_should_profile():
+    # The legacy profiler must be enabled; RPC profiling is not supported with
+    # the Kineto profiler.
+    ActiveProfilerType = torch._C._profiler.ActiveProfilerType
+    return (
+        torch.autograd._profiler_enabled() and
+        torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
+    )
+
+
+def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
+    ctx_manager = contextlib.nullcontext()
+
+    if should_profile:
+        # Create appropriate string representation based on type of func
+        # (builtin, script, python)
+        if qualified_name is None:
+            func_name = (
+                torch._jit_internal._qualified_name(func)
+                if isinstance(func, torch.jit.ScriptFunction)
+                else func.__qualname__
+            )
+        else:
+            func_name = qualified_name
+        # Build RPC profiling key.
+        rpc_profiling_key = _build_rpc_profiling_key(
+            rpc_type,
+            func_name,
+            get_worker_info().name,
+            dst_worker_info.name,
+        )
+        RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
+
+    return ctx_manager
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/backend_registry.py b/MLPY/Lib/site-packages/torch/distributed/rpc/backend_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28651842199b3e4cc10c8c2665a277f9dbf9da3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/backend_registry.py
@@ -0,0 +1,395 @@
+__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"]
+
+import collections
+import enum
+from typing import cast, Dict, List, Set, Tuple
+
+import torch
+import torch.distributed as dist
+from ._utils import _group_membership_management, _update_group_membership
+
+from . import api
+from . import constants as rpc_constants
+
+__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend",
+           "BackendValue", "BackendType"]
+
+BackendValue = collections.namedtuple(
+    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
+)
+
+
+def _backend_type_repr(self):
+    return "BackendType." + self.name
+
+
+_backend_type_doc = """
+    An enum class of available backends.
+
+    PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
+    Additional ones can be registered using the
+    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
+"""
+
+# Create an enum type, `BackendType`, with empty members.
+# Can't handle Function Enum API (mypy bug #9079)
+BackendType = enum.Enum(value="BackendType", names=dict())  # type: ignore[misc]
+# Unable to assign a function a method (mypy bug #2427)
+BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
+
+if BackendType.__doc__:
+    BackendType.__doc__ = _backend_type_doc
+
+def backend_registered(backend_name):
+    """
+    Checks if backend_name is registered as an RPC backend.
+
+    Args:
+        backend_name (str): string to identify the RPC backend.
+    Returns:
+        True if the backend has been registered with ``register_backend``, else
+        False.
+    """
+    return backend_name in BackendType.__members__.keys()
+
+
+def register_backend(
+    backend_name, construct_rpc_backend_options_handler, init_backend_handler
+):
+    """Registers a new RPC backend.
+
+    Args:
+        backend_name (str): backend string to identify the handler.
+        construct_rpc_backend_options_handler (function):
+            Handler that is invoked when
+            rpc_backend.construct_rpc_backend_options(**dict) is called.
+        init_backend_handler (function): Handler that is invoked when the
+            `_init_rpc_backend()` function is called with a backend. It
+            constructs and returns the RPC agent.
+    """
+    global BackendType
+    if backend_registered(backend_name):
+        raise RuntimeError(f"RPC backend {backend_name}: already registered")
+    # Create a new enum type, `BackendType`, with extended members.
+    existing_enum_dict = {member.name: member.value for member in BackendType}
+    extended_enum_dict = dict(
+        {
+            backend_name: BackendValue(
+                construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
+                init_backend_handler=init_backend_handler,
+            )
+        },
+        **existing_enum_dict
+    )
+    # Can't handle Function Enum API (mypy bug #9079)
+    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
+    # Unable to assign a function a method (mypy bug #2427)
+    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
+    if BackendType.__doc__:
+        BackendType.__doc__ = _backend_type_doc
+    return BackendType[backend_name]
+
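+# A hypothetical registration sketch (illustration only, not part of the original
+# module). The handler names and the "MY_BACKEND" key are assumptions; real
+# handlers must return RPC backend options and an RpcAgent, respectively.
+#
+# >>> def my_construct_options_handler(rpc_timeout, init_method, **kwargs):
+# ...     ...  # return an RpcBackendOptions-like object
+# >>> def my_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
+# ...     ...  # construct and return an RpcAgent
+# >>> if not backend_registered("MY_BACKEND"):
+# ...     register_backend("MY_BACKEND",
+# ...                      my_construct_options_handler,
+# ...                      my_init_backend_handler)
+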
+def construct_rpc_backend_options(
+    backend,
+    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
+    init_method=rpc_constants.DEFAULT_INIT_METHOD,
+    **kwargs
+):
+
+    return backend.value.construct_rpc_backend_options_handler(
+        rpc_timeout, init_method, **kwargs
+    )
+
+def init_backend(backend, *args, **kwargs):
+    return backend.value.init_backend_handler(*args, **kwargs)
+
+def _init_process_group(store, rank, world_size):
+    # Initialize ProcessGroup.
+    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT
+
+    # We're using a bunch of private APIs here since `new_group` requires the
+    # default group to be initialized.
+    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)
+
+    assert group is not None, "Failed to initialize default ProcessGroup."
+
+    if (rank != -1) and (rank != group.rank()):
+        raise RuntimeError(
+            f"rank argument {rank} doesn't match pg rank {group.rank()}"
+        )
+    if (world_size != -1) and (world_size != group.size()):
+        raise RuntimeError(
+            f"world_size argument {world_size} doesn't match pg size {group.size()}"
+        )
+    return group
+
+def _tensorpipe_construct_rpc_backend_options_handler(
+    rpc_timeout,
+    init_method,
+    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
+    _transports=None,
+    _channels=None,
+    **kwargs
+):
+    from . import TensorPipeRpcBackendOptions
+
+    return TensorPipeRpcBackendOptions(
+        rpc_timeout=rpc_timeout,
+        init_method=init_method,
+        num_worker_threads=num_worker_threads,
+        _transports=_transports,
+        _channels=_channels,
+    )
+
+
+def _tensorpipe_validate_devices(devices, device_count):
+    return all(
+        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
+        for d in devices
+    )
+
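+# Sketch of the check above (illustration only, not part of the original module):
+# CPU devices always pass; CUDA devices must have an index below the device count.
+#
+# >>> _tensorpipe_validate_devices([torch.device("cpu"), torch.device("cuda:1")], 2)
+# True
+# >>> _tensorpipe_validate_devices([torch.device("cuda:3")], 2)
+# False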
+
+# Detect whether any worker has an invalid device_map configuration, and return
+# the reverse device maps.
+def _tensorpipe_exchange_and_check_all_device_maps(
+    my_name, my_device_count, my_device_maps, my_devices, group
+):
+    gathered: List[Tuple[
+        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
+    ]] = [("", 0, {}, []) for _ in range(group.size())]
+    dist.all_gather_object(
+        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
+    )
+    all_names = [name for name, _, _, _ in gathered]
+    all_device_counts = {name: count for name, count, _, _ in gathered}
+    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
+    all_devices = {name: devices for name, _, _, devices in gathered}
+
+    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)
+
+    # All checks passed; construct the reverse mapping and get the list of devices handled by this agent.
+    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
+    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
+    return reverse_device_maps, my_devices
+
+def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True):
+    for node in all_names:
+        devices = all_devices[node]
+        if len(set(devices)) != len(devices):
+            raise ValueError(
+                f"Node {node} has duplicated devices\n"
+                f"devices = {devices}"
+            )
+        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
+            raise ValueError(
+                f"Node {node} has devices with invalid indices\n"
+                f"devices = {devices}\n"
+                f"device count = {all_device_counts[node]}"
+            )
+
+    for source_node in all_names:
+        # For a dynamic (non-static) group, do not check the target node names since they may not have joined yet.
+        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names):
+            raise ValueError(
+                f"Node {source_node} has invalid target node names in its device maps\n"
+                f"device maps = {all_device_maps[source_node].keys()}\n"
+                f"node names = {all_names}"
+            )
+        for target_node, map_ in all_device_maps[source_node].items():
+            if len(set(map_.values())) != len(map_):
+                raise ValueError(
+                    f"Node {source_node} has duplicated target devices "
+                    f"in its device map for {target_node}\n"
+                    f"device map = {map_}"
+                )
+            if all_devices[source_node]:
+                if not set(map_.keys()).issubset(all_devices[source_node]):
+                    raise ValueError(
+                        f"Node {source_node} has unexpected source devices "
+                        f"in its device map for {target_node}\n"
+                        f"device map = {map_}\n"
+                        f"devices = {all_devices[source_node]}"
+                    )
+            elif not _tensorpipe_validate_devices(
+                map_.keys(), all_device_counts[source_node]
+            ):
+                raise ValueError(
+                    f"Node {source_node} has source devices with invalid indices "
+                    f"in its device map for {target_node}\n"
+                    f"device map = {map_}\n"
+                    f"device count = {all_device_counts[source_node]}"
+                )
+            if all_devices.get(target_node, []):
+                if not set(map_.values()).issubset(all_devices[target_node]):
+                    raise ValueError(
+                        f"Node {source_node} has unexpected target devices "
+                        f"in its device map for {target_node}\n"
+                        f"device map = {map_}\n"
+                        f"devices = {all_devices[target_node]}"
+                    )
+            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
+                map_.values(), all_device_counts[target_node]
+            ):
+                raise ValueError(
+                    f"Node {source_node} has target devices with invalid indices "
+                    f"in its device map for {target_node}\n"
+                    f"device map = {map_}\n"
+                    f"device count = {all_device_counts[target_node]}"
+                )
+
+def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
+    if not my_devices:
+        devices_set: Set[torch.device] = set()
+        for map_ in my_device_maps.values():
+            devices_set.update(map_.keys())
+        for map_ in reverse_device_maps.values():
+            devices_set.update(map_.keys())
+        devices_set.discard(torch.device("cpu"))
+        my_devices = list(devices_set)
+    my_devices = sorted(my_devices, key=lambda d: d.index)
+    return my_devices
+
+def _create_reverse_mapping(my_name, all_names, all_device_maps):
+    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
+    for node in all_names:
+        if my_name in all_device_maps[node]:
+            reverse_device_maps[node] = {
+                v: k for k, v in all_device_maps[node][my_name].items()
+            }
+    return reverse_device_maps
+
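+# A minimal sketch (illustration only, not part of the original module): if
+# "worker1" maps its cuda:0 onto this worker's ("worker0") cuda:1, the reverse
+# map lets "worker0" route tensors back along the inverted pairing.
+#
+# >>> maps = {"worker1": {"worker0": {torch.device("cuda:0"): torch.device("cuda:1")}}}
+# >>> _create_reverse_mapping("worker0", ["worker1"], maps)
+# {'worker1': {device(type='cuda', index=1): device(type='cuda', index=0)}}
+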
+def _get_device_infos():
+    from . import TensorPipeAgent
+    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
+    opts = agent._get_backend_options()
+    device_count = torch.cuda.device_count()
+    if torch.cuda.is_available() and opts.devices:
+        torch.cuda.init()
+    return device_count, opts.device_maps, opts.devices
+
+def _set_devices_and_reverse_device_map(agent):
+    from . import TensorPipeAgent
+    agent = cast(TensorPipeAgent, agent)
+    # Group state is retrieved from local agent
+    # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid
+    my_worker_info = agent.get_worker_info()
+    my_name = my_worker_info.name
+    all_worker_infos = agent.get_worker_infos()
+    # One round to get device_maps of all workers and construct reverse device maps
+    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
+    for worker_info in all_worker_infos:
+        worker_name = worker_info.name
+        if worker_name != my_name:
+            # TODO: make async?
+            device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos)
+        else:
+            opts = agent._get_backend_options()
+            device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices
+        all_device_counts[worker_name] = device_count
+        all_device_maps[worker_name] = device_map
+        all_devices[worker_name] = devices
+        all_names.append(worker_name)
+
+    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False)
+    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
+
+    # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps
+    for worker_name in all_names:
+        # Set device list for each worker
+        all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps)
+        api.rpc_sync(worker_name, _update_group_membership,
+                     args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True))
+
+def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
+    from . import TensorPipeAgent
+    from . import TensorPipeRpcBackendOptions
+    if not isinstance(store, dist.Store):
+        raise TypeError(f"`store` must be a c10d::Store. {store}")
+
+    if not isinstance(
+        rpc_backend_options, TensorPipeRpcBackendOptions
+    ):
+        raise TypeError(
+            f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
+        )
+
+    device_count = torch.cuda.device_count()
+
+    # If world_size is specified, this is a static group (ranks cannot join or leave).
+    is_static_group = bool(world_size)
+    if is_static_group:
+        # The agent's join method is required to behave like a barrier and perform
+        # collective operations, for which it relies on a process group, instead of
+        # re-implementing this on top of RPCs.
+        group = _init_process_group(store, rank, world_size)
+
+        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
+            name,
+            device_count,
+            rpc_backend_options.device_maps,
+            rpc_backend_options.devices,
+            group,
+        )
+
+        if torch.cuda.is_available() and devices:
+            # It's necessary to initialize PyTorch CUDA states here (e.g.,
+            # CUDACachingAllocator). If this is missing, we could hit errors like
+            # "allocator not initialized", because other processes might send
+            # CUDA-related RPC request to this process before user code in this
+            # process initializes its PyTorch CUDA states.
+            torch.cuda.init()
+
+        # TODO: add try-except and destroy _agent in all processes if any fails.
+        agent = TensorPipeAgent(
+            store,
+            name,
+            rank,
+            world_size,
+            rpc_backend_options,
+            reverse_device_maps,
+            devices,
+        )
+
+        api._init_rpc_states(agent)
+
+        # Run one dummy round of RPC to initialize channels/transports. Without
+        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
+        # on that process before rpc.shutdown(), as the agent initialization can
+        # take longer than 5s.
+        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
+        # Need a barrier here to make sure no peers leave before the rank0 finishes
+        # _all_gather
+        group.barrier().wait()
+
+        return agent
+    # initialization for dynamic rpc (ranks can join and leave)
+    else:
+        with _group_membership_management(store, name, True):
+            # Construct TPAgent with empty reverse_device_map and devices
+            # these properties will be updated after initialization
+            agent = TensorPipeAgent(
+                store,
+                name,
+                rank,
+                world_size,
+                rpc_backend_options,
+                {},
+                [],
+            )
+            api._init_rpc_states(agent)
+
+            try:
+                # Notify all workers in the group that this rank has joined, and set
+                # the devices and reverse_device_map. This is a synchronous operation
+                # that completes once all existing ranks are updated.
+                _set_devices_and_reverse_device_map(agent)
+            except Exception:
+                api.shutdown()
+                raise
+            return agent
+
+register_backend(
+    "TENSORPIPE",
+    _tensorpipe_construct_rpc_backend_options_handler,
+    _tensorpipe_init_backend_handler,
+)
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/constants.py b/MLPY/Lib/site-packages/torch/distributed/rpc/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dba916b05002459e67c4499e7fdb450ef3dbb38
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/constants.py
@@ -0,0 +1,24 @@
+from datetime import timedelta
+from typing import List
+from torch._C._distributed_rpc import (
+    _DEFAULT_INIT_METHOD,
+    _DEFAULT_NUM_WORKER_THREADS,
+    _DEFAULT_RPC_TIMEOUT_SEC,
+    _UNSET_RPC_TIMEOUT,
+)
+
+
+# For any RpcAgent.
+DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
+DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
+DEFAULT_SHUTDOWN_TIMEOUT: float = 0
+
+# For TensorPipeAgent.
+DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
+# Ensure that we don't time out when there are long periods of time without
+# any operations against the underlying ProcessGroup.
+DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1)
+# Value indicating that timeout is not set for RPC call, and the default should be used.
+UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
+
+__all__: List[str] = []
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/functions.py b/MLPY/Lib/site-packages/torch/distributed/rpc/functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d76b5f7e7d3ab72bdaee791c523ca068190431
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/functions.py
@@ -0,0 +1,166 @@
+import functools
+
+
+def async_execution(fn):
+    r"""
+    A decorator for a function indicating that the return value of the function
+    is guaranteed to be a :class:`~torch.futures.Future` object and this
+    function can run asynchronously on the RPC callee. More specifically, the
+    callee extracts the :class:`~torch.futures.Future` returned by the wrapped
+    function and installs subsequent processing steps as a callback to that
+    :class:`~torch.futures.Future`. The installed callback will read the value
+    from the :class:`~torch.futures.Future` when completed and send the
+    value back as the RPC response. That also means the returned
+    :class:`~torch.futures.Future` only exists on the callee side and is never
+    sent through RPC. This decorator is useful when the wrapped function's
+    (``fn``) execution needs to pause and resume due to, e.g., containing
+    :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
+
+    .. note:: To enable asynchronous execution, applications must pass the
+        function object returned by this decorator to RPC APIs. If RPC detects
+        attributes installed by this decorator, it knows that this function
+        returns a ``Future`` object and will handle that accordingly.
+        However, this does not mean this decorator has to be the outermost one
+        when defining a function. For example, when combined with ``@staticmethod``
+        or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
+        inner decorator to allow the target function to be recognized as a static
+        or class function. This target function can still execute asynchronously
+        because, when accessed, the static or class method preserves attributes
+        installed by ``@rpc.functions.async_execution``.
+
+
+    Example::
+        The returned :class:`~torch.futures.Future` object can come from
+        :meth:`~torch.distributed.rpc.rpc_async`,
+        :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
+        constructor. The example below shows directly using the
+        :class:`~torch.futures.Future` returned by
+        :meth:`~torch.futures.Future.then`.
+
+        >>> from torch.distributed import rpc
+        >>>
+        >>> # omitting setup and shutdown RPC
+        >>>
+        >>> # On all workers
+        >>> @rpc.functions.async_execution
+        >>> def async_add_chained(to, x, y, z):
+        >>>     # This function runs on "worker1" and returns immediately when
+        >>>     # the callback is installed through the `then(cb)` API. In the
+        >>>     # mean time, the `rpc_async` to "worker2" can run concurrently.
+        >>>     # When the return value of that `rpc_async` arrives at
+        >>>     # "worker1", "worker1" will run the lambda function accordingly
+        >>>     # and set the value for the previously returned `Future`, which
+        >>>     # will then trigger RPC to send the result back to "worker0".
+        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        >>>         lambda fut: fut.wait() + z
+        >>>     )
+        >>>
+        >>> # On worker0
+        >>> # xdoctest: +SKIP
+        >>> ret = rpc.rpc_sync(
+        >>>     "worker1",
+        >>>     async_add_chained,
+        >>>     args=("worker2", torch.ones(2), 1, 1)
+        >>> )
+        >>> print(ret)  # prints tensor([3., 3.])
+
+        When combined with TorchScript decorators, this decorator must be the
+        outermost one.
+
+        >>> from torch import Tensor
+        >>> from torch.futures import Future
+        >>> from torch.distributed import rpc
+        >>>
+        >>> # omitting setup and shutdown RPC
+        >>>
+        >>> # On all workers
+        >>> @torch.jit.script
+        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
+        >>>     return x + y
+        >>>
+        >>> @rpc.functions.async_execution
+        >>> @torch.jit.script
+        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
+        >>>     return rpc.rpc_async(to, script_add, (x, y))
+        >>>
+        >>> # On worker0
+        >>> ret = rpc.rpc_sync(
+        >>>     "worker1",
+        >>>     async_add,
+        >>>     args=("worker2", torch.ones(2), 1)
+        >>> )
+        >>> print(ret)  # prints tensor([2., 2.])
+
+        When combined with a static or class method, this decorator must be the
+        inner one.
+
+        >>> from torch.distributed import rpc
+        >>>
+        >>> # omitting setup and shutdown RPC
+        >>>
+        >>> # On all workers
+        >>> class AsyncExecutionClass:
+        >>>
+        >>>     @staticmethod
+        >>>     @rpc.functions.async_execution
+        >>>     def static_async_add(to, x, y, z):
+        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        >>>             lambda fut: fut.wait() + z
+        >>>         )
+        >>>
+        >>>     @classmethod
+        >>>     @rpc.functions.async_execution
+        >>>     def class_async_add(cls, to, x, y, z):
+        >>>         ret_fut = torch.futures.Future()
+        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
+        >>>         )
+        >>>         return ret_fut
+        >>>
+        >>>     @rpc.functions.async_execution
+        >>>     def bound_async_add(self, to, x, y, z):
+        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        >>>             lambda fut: fut.wait() + z
+        >>>         )
+        >>>
+        >>> # On worker0
+        >>> ret = rpc.rpc_sync(
+        >>>     "worker1",
+        >>>     AsyncExecutionClass.static_async_add,
+        >>>     args=("worker2", torch.ones(2), 1, 2)
+        >>> )
+        >>> print(ret)  # prints tensor([4., 4.])
+        >>>
+        >>> ret = rpc.rpc_sync(
+        >>>     "worker1",
+        >>>     AsyncExecutionClass.class_async_add,
+        >>>     args=("worker2", torch.ones(2), 1, 2)
+        >>> )
+        >>> print(ret)  # prints tensor([4., 4.])
+
+        This decorator also works with RRef helpers, i.e.,
+        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
+        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
+        :meth:`torch.distributed.rpc.RRef.remote`.
+
+        >>> from torch.distributed import rpc
+        >>>
+        >>> # reuse the AsyncExecutionClass class above
+        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
+        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
+        >>> print(ret)  # prints tensor([4., 4.])
+        >>>
+        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
+        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
+        >>> print(ret)  # prints tensor([4., 4.])
+        >>>
+        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
+        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
+        >>> print(ret)  # prints tensor([4., 4.])
+    """
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        return fn(*args, **kwargs)
+    # Can't declare and use attributes of function objects (mypy#2087)
+    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
+    return wrapper
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/internal.py b/MLPY/Lib/site-packages/torch/distributed/rpc/internal.py
new file mode 100644
index 0000000000000000000000000000000000000000..9938352d0e6bf9c68f9fa2edb4fd66d662b105e8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/internal.py
@@ -0,0 +1,281 @@
+import collections
+import copyreg
+import io
+import pickle
+import sys
+import threading
+import traceback
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_rpc import _get_current_rpc_agent
+
+__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]
+
+# Thread local tensor tables to store tensors while pickling torch.Tensor
+# objects
+_thread_local_tensor_tables = threading.local()
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+class RPCExecMode(Enum):
+    SYNC = "sync"
+    ASYNC = "async"
+    ASYNC_JIT = "async_jit"
+    REMOTE = "remote"
+
+
+class _InternalRPCPickler:
+    r"""
+    This class provides serialize() and deserialize() interfaces to serialize
+    data into a "binary string + tensor table" format.
+    For an RPC Python UDF function and its args, non-tensor data is serialized
+    into a regular binary string, while tensor data is put into thread-local
+    tensor tables. This serialization format is consistent with the one used for
+    builtin operators and args by the JIT pickler, and it makes tensor handling
+    in C++ much easier, e.g. attaching tensors to the distributed autograd graph.
+    """
+
+    def __init__(self):
+        # Ignore type error because dispatch_table is defined in third-party package
+        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
+        self._dispatch_table[torch.Tensor] = self._tensor_reducer
+        # Used for registering customized picklers.
+        self._class_reducer_dict = {}
+
+    def _register_reducer(self, obj_class, reducer):
+        # For the same class, only register the reducer once.
+        if obj_class not in self._class_reducer_dict:
+            self._class_reducer_dict[obj_class] = reducer
+
+    @classmethod
+    def _tensor_receiver(cls, tensor_index):
+        global _thread_local_tensor_tables
+        return _thread_local_tensor_tables.recv_tables[tensor_index]
+
+    def _tensor_reducer(self, tensor):
+        global _thread_local_tensor_tables
+        _thread_local_tensor_tables.send_tables.append(tensor)
+        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
+        return (_InternalRPCPickler._tensor_receiver, (tensor_index,))
+
+    @classmethod
+    def _py_rref_receiver(cls, rref_fork_data):
+        return dist.rpc.PyRRef._deserialize(rref_fork_data)
+
+    def _py_rref_reducer(self, py_rref):
+        rref_fork_data = py_rref._serialize()
+        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))
+
+    def _rref_reducer(self, rref):
+        return self._py_rref_reducer(rref)
+
+    @classmethod
+    def _script_module_receiver(cls, script_module_serialized):
+        """
+        Given a serialized representation of a ScriptModule created with torch.jit.save,
+        loads and returns the ScriptModule.
+        """
+        f = io.BytesIO(script_module_serialized)
+        m = torch.jit.load(f)
+        return m
+
+    def _script_module_reducer(self, script_module):
+        """
+        Serializes a ScriptModule.
+        """
+        f = io.BytesIO()
+        torch.jit.save(script_module, f)
+        return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))
+
+    def serialize(self, obj):
+        r"""
+        Serialize non-tensor data into a binary string and tensor data into a
+        tensor table.
+        """
+        f = io.BytesIO()
+        p = _pickler(f)
+        p.dispatch_table = self._dispatch_table
+
+        # The RPC API can accept user picklers that inherit from _InternalRPCPickler
+        # to serialize RRefs. User picklers may have a different initialization
+        # function from _InternalRPCPickler, but they should all call serialize() and
+        # use _rref_reducer to pickle RRefs in Python. Also, when _internal_rpc_pickler
+        # is imported into rpc/api.py, rpc.RRef is not compiled yet, so the
+        # _InternalRPCPickler constructor is not a good place to access rpc.RRef;
+        # hence RRef's dispatch table entries are installed here.
+        #
+        # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
+        # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
+        # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]
+
+        # Add dispatch pickling for ScriptModule or its subclass.
+        if isinstance(obj, torch.jit.ScriptModule):
+            # Ignore type error because dispatch_table is defined in third-party package
+            p.dispatch_table[obj.__class__] = self._script_module_reducer  # type: ignore[index]
+
+        # Install customized picklers.
+        for class_name in self._class_reducer_dict.keys():
+            p.dispatch_table[class_name] = self._class_reducer_dict[class_name]  # type: ignore[index]
+
+        # save _thread_local_tensor_tables.send_tables if it is in nested call
+        global _thread_local_tensor_tables
+        if hasattr(_thread_local_tensor_tables, "send_tables"):
+            old_send_tables = _thread_local_tensor_tables.send_tables
+        else:
+            old_send_tables = None
+        _thread_local_tensor_tables.send_tables = []
+
+        p.dump(obj)
+
+        # restore _thread_local_tensor_tables.send_tables if return
+        # from nested call, otherwise clean up the table
+        tensors = _thread_local_tensor_tables.send_tables
+        if old_send_tables is not None:
+            _thread_local_tensor_tables.send_tables = old_send_tables
+        else:
+            del _thread_local_tensor_tables.send_tables
+
+        return (f.getvalue(), tensors)
+
+    def deserialize(self, binary_data, tensor_table):
+        r"""
+        Deserialize a binary string + tensor table into the original object.
+        """
+        # save _thread_local_tensor_tables.recv_tables if it is in nested call
+        global _thread_local_tensor_tables
+        if hasattr(_thread_local_tensor_tables, "recv_tables"):
+            old_recv_tables = _thread_local_tensor_tables.recv_tables
+        else:
+            old_recv_tables = None
+        _thread_local_tensor_tables.recv_tables = tensor_table
+
+        try:
+            unpickler = _unpickler(io.BytesIO(binary_data))
+            ret = unpickler.load()
+        except AttributeError as e:
+            # Occurs when function is not found on module/class during
+            # unpickling.
+            except_str = (
+                str(e)
+                + """ Default RPC pickler does not serialize
+            function code. Ensure that UDFs are defined on both caller and
+            callee modules."""
+            )
+            ret = AttributeError(except_str)
+            # Ensure the stack trace gets preserved
+            ret.__cause__ = e
+
+        # restore _thread_local_tensor_tables.recv_tables if return
+        # from nested call, otherwise clean up the table
+        if old_recv_tables is not None:
+            _thread_local_tensor_tables.recv_tables = old_recv_tables
+        else:
+            del _thread_local_tensor_tables.recv_tables
+
+        return ret
+
+
+# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
+_internal_rpc_pickler = _InternalRPCPickler()
+
+
+def serialize(obj):
+    return _internal_rpc_pickler.serialize(obj)
+
+
+def deserialize(binary_data, tensor_table):
+    return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
+
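+# A minimal round-trip sketch (illustration only, not part of the original
+# module): tensors are split out into a separate table and the remaining
+# payload is pickled into bytes, mirroring what the C++ layer expects.
+#
+# >>> payload, tensor_table = serialize(("scale", torch.ones(2)))
+# >>> deserialize(payload, tensor_table)
+# ('scale', tensor([1., 1.]))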
+
+def _run_function(python_udf):
+    r"""
+    This function is exclusively called from C++.
+    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
+
+    Runs a Python UDF and returns its return value.
+    Wraps any exception in ``RemoteException`` if the function raises.
+    """
+    try:
+        if isinstance(python_udf, AttributeError):
+            raise python_udf
+        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
+    except Exception as e:
+        # except str = exception info + traceback string
+        except_str = (
+            f"On {_get_current_rpc_agent().get_worker_info()}:\n"
+            f"{repr(e)}\n{traceback.format_exc()}"
+        )
+        print(except_str, file=sys.stderr)
+        result = RemoteException(except_str, type(e))
+    return result
+
+
+def _handle_exception(result):
+    if isinstance(result, RemoteException):
+        exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
+        # We wrap exception re-creation here in case some exception classes
+        # cannot be constructed directly from a string.
+        exc = None
+        try:
+            exc = result.exception_type(exception_msg)
+        except BaseException as e:
+            raise RuntimeError(  # noqa: B904
+                f"Failed to create original exception type. Error msg was {str(e)}"
+                f" Original exception on remote side was {exception_msg}"
+            ) from e
+
+        if exc is not None:
+            raise exc
+
+
+def _build_rpc_profiling_key(
+    exec_type, func_name, current_worker_name, dst_worker_name
+):
+    """
+    Builds the key that RPC calls are profiled with using the autograd profiler.
+    This will be the name of the corresponding Event recorded in the profiler.
+
+    Args:
+        exec_type (RPCExecMode): Type of RPC/RRef call
+        func_name (str): Name of function being profiled.
+        current_worker_name (str): Name of current worker.
+        dst_worker_name (str): Name of the destination worker.
+
+    Returns:
+        String representing profiling key
+    """
+    profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
+    return profile_key
+
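+# Sketch of the resulting key format (illustration only; the worker names and
+# function name are assumptions):
+#
+# >>> _build_rpc_profiling_key(RPCExecMode.SYNC, "torch.add", "worker0", "worker1")
+# 'rpc_sync#torch.add(worker0 -> worker1)'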
+
+def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
+    """
+    This function should be called from RPC/RRef functions to create a
+    RecordFunction object for profiling. This function also runs the before
+    callbacks that start the profiling, though the user is responsible for
+    running the appropriate callbacks when the function to be profiled finishes.
+
+    Args:
+        exec_type (RPCExecMode): Type of RPC/RRef call
+        func_name (str): Name of function being profiled.
+        current_worker_name (str): Name of current worker.
+        dest_worker_name (str): Name of the destination worker.
+
+    Returns:
+        An instance of `torch.autograd._RecordFunction`.
+    """
+    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
+    profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})"
+    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
+    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
+    return rf
+
+
+PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
+RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])
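+
+
+# Round-trip sketch (illustrative; uses only names defined in this module and is
+# never called): the caller serializes a PythonUDF, the callee deserializes it
+# and invokes it, which is what ``_run_function`` does when called from C++.
+def _example_python_udf_roundtrip():
+    udf = PythonUDF(func=max, args=(1, 2), kwargs={})
+    binary_data, tensor_table = serialize(udf)
+    restored = deserialize(binary_data, tensor_table)
+    return restored.func(*restored.args, **restored.kwargs)  # -> 2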
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/options.py b/MLPY/Lib/site-packages/torch/distributed/rpc/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..0791fef1bfa6d1757027486293b0b6e4a148f1f5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/options.py
@@ -0,0 +1,172 @@
+from typing import Dict, List, Optional, Union
+
+import torch
+from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
+from . import constants as rpc_contants
+
+
+DeviceType = Union[int, str, torch.device]
+
+__all__ = ["TensorPipeRpcBackendOptions"]
+
+def _to_device(device: DeviceType) -> torch.device:
+    device = torch.device(device)
+    if device.type != "cuda":
+        raise ValueError(
+            "`set_devices` expects a list of CUDA devices, but got "
+            f"device type {device.type}."
+        )
+    return device
+
+
+def _to_device_map(
+    device_map: Dict[DeviceType, DeviceType]
+) -> Dict[torch.device, torch.device]:
+    full_device_map: Dict[torch.device, torch.device] = {}
+    reverse_map: Dict[torch.device, torch.device] = {}
+    for k, v in device_map.items():
+        k, v = torch.device(k), torch.device(v)
+        if v in reverse_map:
+            raise ValueError(
+                "`device_map` only supports 1-to-1 mapping, "
+                f"trying to map {k} and {reverse_map[v]} to {v}"
+            )
+        full_device_map[k] = v
+        reverse_map[v] = k
+    return full_device_map
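+
+# Minimal sketch of the invertibility rule enforced above (hypothetical helper,
+# never called; constructing ``torch.device`` objects does not require CUDA):
+def _example_device_map_invertibility():
+    # Distinct values on both sides: accepted and normalized to torch.device keys.
+    ok = _to_device_map({"cuda:0": "cuda:1", "cuda:1": "cuda:2"})
+    assert ok[torch.device("cuda:0")] == torch.device("cuda:1")
+    # Two local devices mapped to the same remote device: rejected.
+    try:
+        _to_device_map({"cuda:0": "cuda:1", "cuda:2": "cuda:1"})
+    except ValueError:
+        pass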
+
+
+def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
+    return list(map(_to_device, devices))
+
+
+class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
+    r"""
+    The backend options for
+    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
+    :class:`~torch.distributed.rpc.RpcBackendOptions`.
+
+    Args:
+        num_worker_threads (int, optional): The number of threads in the
+            thread-pool used by
+            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
+            requests (default: 16).
+        rpc_timeout (float, optional): The default timeout, in seconds,
+            for RPC requests (default: 60 seconds). If the RPC has not
+            completed in this timeframe, an exception indicating so will
+            be raised. Callers can override this timeout for individual
+            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
+            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
+        init_method (str, optional): The URL to initialize the distributed
+            store used for rendezvous. It takes any value accepted for the
+            same argument of :meth:`~torch.distributed.init_process_group`
+            (default: ``env://``).
+        device_maps (Dict[str, Dict], optional): Device placement mappings from
+            this worker to the callee. The key is the callee worker name and the value
+            is a dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``)
+            that maps this worker's devices to the callee worker's devices.
+            (default: ``None``)
+        devices (List[int, str, or ``torch.device``], optional): all local
+            CUDA devices used by the RPC agent. By default, it will be initialized
+            to all local devices from its own ``device_maps`` and corresponding
+            devices from its peers' ``device_maps``. When processing CUDA RPC
+            requests, the agent will properly synchronize CUDA streams for
+            all devices in this ``List``.
+    """
+
+    def __init__(
+        self,
+        *,
+        num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
+        rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
+        init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
+        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
+        devices: Optional[List[DeviceType]] = None,
+        _transports: Optional[List] = None,
+        _channels: Optional[List] = None,
+    ):
+        full_device_maps = (
+            {}
+            if device_maps is None
+            else {k: _to_device_map(v) for k, v in device_maps.items()}
+        )
+        full_device_list = [] if devices is None else _to_device_list(devices)
+        super().__init__(
+            num_worker_threads,
+            _transports,
+            _channels,
+            rpc_timeout,
+            init_method,
+            full_device_maps,
+            full_device_list,
+        )
+
+    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
+        r"""
+        Set device mapping between each RPC caller and callee pair. This
+        function can be called multiple times to incrementally add
+        device placement configurations.
+
+        Args:
+            to (str): Callee name.
+            device_map (Dict of int, str, or torch.device): Device placement
+                mappings from this worker to the callee. This map must be
+                invertible.
+
+        Example:
+            >>> # xdoctest: +SKIP("distributed")
+            >>> # both workers
+            >>> def add(x, y):
+            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
+            >>>     return x + y, (x + y).to(2)
+            >>>
+            >>> # on worker 0
+            >>> options = TensorPipeRpcBackendOptions(
+            >>>     num_worker_threads=8,
+            >>>     device_maps={"worker1": {0: 1}}
+            >>>     # maps worker0's cuda:0 to worker1's cuda:1
+            >>> )
+            >>> options.set_device_map("worker1", {1: 2})
+            >>> # maps worker0's cuda:1 to worker1's cuda:2
+            >>>
+            >>> rpc.init_rpc(
+            >>>     "worker0",
+            >>>     rank=0,
+            >>>     world_size=2,
+            >>>     backend=rpc.BackendType.TENSORPIPE,
+            >>>     rpc_backend_options=options
+            >>> )
+            >>>
+            >>> x = torch.ones(2)
+            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
+            >>> # The first argument will be moved to cuda:1 on worker1. When
+            >>> # sending the return value back, it will follow the invert of
+            >>> # the device map, and hence will be moved back to cuda:0 and
+            >>> # cuda:1 on worker0
+            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
+            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
+        """
+        full_device_map = _to_device_map(device_map)
+        curr_device_maps = super().device_maps
+
+        if to in curr_device_maps:
+            for k, v in full_device_map.items():
+                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
+                    raise ValueError(
+                        "`set_device_map` only supports 1-to-1 mapping, trying"
+                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
+                    )
+
+        super()._set_device_map(to, full_device_map)
+
+    def set_devices(self, devices: List[DeviceType]):
+        r"""
+        Set local devices used by the TensorPipe RPC agent. When processing
+        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
+        CUDA streams for all devices in this ``List``.
+
+        Args:
+            devices (List of int, str, or torch.device): local devices used by
+                the TensorPipe RPC agent.
+        """
+        self.devices = _to_device_list(devices)
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/rref_proxy.py b/MLPY/Lib/site-packages/torch/distributed/rpc/rref_proxy.py
new file mode 100644
index 0000000000000000000000000000000000000000..7219d74c9b9cdb01a1f480abc1927b66fb000a6f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/rref_proxy.py
@@ -0,0 +1,74 @@
+from functools import partial
+
+from . import functions
+from . import rpc_async
+
+import torch
+from .constants import UNSET_RPC_TIMEOUT
+from torch.futures import Future
+
+def _local_invoke(rref, func_name, args, kwargs):
+    return getattr(rref.local_value(), func_name)(*args, **kwargs)
+
+@functions.async_execution
+def _local_invoke_async_execution(rref, func_name, args, kwargs):
+    return getattr(rref.local_value(), func_name)(*args, **kwargs)
+
+def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
+    def _rref_type_cont(rref_fut):
+        rref_type = rref_fut.value()
+
+        _invoke_func = _local_invoke
+        # Bypass ScriptModules when checking for async function attribute.
+        bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
+            rref_type, torch._C.ScriptModule
+        )
+        if not bypass_type:
+            func = getattr(rref_type, func_name)
+            if hasattr(func, "_wrapped_async_rpc_function"):
+                _invoke_func = _local_invoke_async_execution
+
+        return rpc_api(
+            rref.owner(),
+            _invoke_func,
+            args=(rref, func_name, args, kwargs),
+            timeout=timeout
+        )
+
+    rref_fut = rref._get_type(timeout=timeout, blocking=False)
+
+    if rpc_api != rpc_async:
+        rref_fut.wait()
+        return _rref_type_cont(rref_fut)
+    else:
+        # A little explanation on this.
+        # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]`
+        # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]`
+        # To address that, we return a Future that is completed with the result of the async call.
+        result: Future = Future()
+
+        def _wrap_rref_type_cont(fut):
+            try:
+                _rref_type_cont(fut).then(_complete_op)
+            except BaseException as ex:
+                result.set_exception(ex)
+
+        def _complete_op(fut):
+            try:
+                result.set_result(fut.value())
+            except BaseException as ex:
+                result.set_exception(ex)
+
+        rref_fut.then(_wrap_rref_type_cont)
+        return result
+
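+# Nesting sketch (illustrative, standalone; not used by the module): chaining
+# ``then`` on a Future whose callback itself returns a Future produces a
+# Future[Future[T]] rather than unwrapping it, which is why ``_invoke_rpc``
+# completes a fresh ``result`` Future manually above.
+def _example_future_nesting() -> Future:
+    inner: Future = Future()
+    inner.set_result(42)
+    outer = inner.then(lambda fut: inner)  # callback returns a Future
+    assert isinstance(outer.wait(), Future)  # nested, not flattened
+    return outer
+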
+# This class manages proxied RPC API calls for RRefs. It is entirely used from
+# C++ (see python_rpc_handler.cpp).
+class RRefProxy:
+    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
+        self.rref = rref
+        self.rpc_api = rpc_api
+        self.rpc_timeout = timeout
+
+    def __getattr__(self, func_name):
+        return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout)
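+
+
+def _example_rref_proxy_usage():
+    """Minimal usage sketch (not part of the upstream module): assumes RPC has
+    already been initialized and that a peer named "worker1" exists (both are
+    assumptions made for illustration only)."""
+    import torch.distributed.rpc as rpc
+
+    # ``rpc.remote`` returns an RRef owned by "worker1".
+    tensor_rref = rpc.remote("worker1", torch.ones, args=(2, 2))
+    # ``rpc_sync()`` hands back an RRefProxy; the attribute access plus call is
+    # routed through ``_invoke_rpc`` above and runs ``sum()`` on the owner.
+    return tensor_rref.rpc_sync().sum()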
diff --git a/MLPY/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py b/MLPY/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f061904510c6cd11999372a0c82ff6c2157c0057
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py
@@ -0,0 +1,177 @@
+#!/usr/bin/python3
+
+import itertools
+
+import torch
+from torch.autograd.profiler_legacy import profile
+from typing import List
+
+from . import (
+    _disable_server_process_global_profiler,
+    _enable_server_process_global_profiler,
+)
+
+__all__: List[str] = []
+
+class _server_process_global_profile(profile):
+    """
+    It has the same API as the ``torch.autograd.profiler.profile`` class,
+    except that it enables profiling on all threads running RPC server request callbacks.
+
+    Context manager that manages autograd profiler state and holds a summary of results.
+    Under the hood it just records events of functions being executed in C++ and
+    exposes those events to Python. You can wrap any code into it and it will
+    only report runtime of PyTorch functions.
+    Note: the profiler is thread-local and is automatically propagated into async tasks.
+
+    Args:
+        enabled (bool, optional): Setting this to False makes this context manager a no-op.
+            Default: ``True``.
+
+        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
+            Adds approximately 4us of overhead to each tensor operation.
+            Default: ``False``
+
+        record_shapes (bool, optional): If shapes recording is set, information
+            about input dimensions will be collected. This allows one to see which
+            dimensions have been used under the hood and further group by them
+            using prof.key_averages(group_by_input_shape=True). Please note that
+            shape recording might skew your profiling data. It is recommended to
+            use separate runs with and without shape recording to validate the timing.
+            Most likely the skew will be negligible for the bottommost events (in the case
+            of nested function calls), but for higher-level functions the total
+            self CPU time might be artificially increased because of the shape
+            collection.
+
+        profile_memory (bool, optional): Whether to report memory usage, default: ``False``
+
+    .. warning::
+        Enabling memory profiling incurs additional profiler overhead
+
+    .. warning::
+        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
+        one cannot use the profiler with ``use_cuda = True`` to benchmark
+        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
+        please use ``use_cuda = False`` or ``num_workers = 0``.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> x, y = torch.tensor(1), torch.tensor(2)
+        >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
+        >>> outer_profile_rref.rpc_sync().__enter__()
+        >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
+        >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
+        >>> inner_profile_rref.rpc_sync().__enter__()
+        >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
+        >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
+        >>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
+        >>> print(inner_profile_rref.rpc_sync().key_averages())
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        sub        85.06%           76.275us         100.00%          89.667us         89.667us         1
+        empty      14.94%           13.392us         14.94%           13.392us         13.392us         1
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        Self CPU time total: 89.667us
+        >>> print(outer_profile_rref.rpc_sync().key_averages())
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        Name       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        sub        35.65%           76.275us         41.91%           89.667us         89.667us         1
+        empty      12.67%           27.101us         12.67%           27.101us         13.551us         2
+        add        51.68%           110.550us        58.09%           124.259us        124.259us        1
+        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
+        Self CPU time total: 213.926us
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> # wait for worker 0 to finish work, and then shutdown.
+        >>> rpc.shutdown()
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def __enter__(self):
+        """
+        Turn on server-side process-global profiling.
+        This enables thread-local profiler on all RPC threads running server-side request callbacks.
+        """
+        if not self.enabled:
+            return
+
+        if self.entered:  # type: ignore[has-type]
+            raise RuntimeError("autograd profiler traces are not reentrant")
+        self.entered = True
+
+        profiler_kind = (
+            torch.autograd.ProfilerState.CUDA
+            if self.use_cuda
+            else torch.autograd.ProfilerState.CPU
+        )
+        profiler_config = torch.autograd.ProfilerConfig(
+            profiler_kind,
+            self.record_shapes,
+            self.profile_memory,
+            False,
+            False,
+            False,
+            torch.profiler._ExperimentalConfig())
+        _enable_server_process_global_profiler(profiler_config)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Turn off server-side process-global profiling.
+        Aggregate all profiling events recorded by RPC threads.
+
+        These attributes are assigned on exiting context.
+
+        Attributes:
+            function_events (torch.autograd.profiler.EventList). It's a list with helper
+            methods, such as showing record items in a pretty-printed table,
+            averaging by grouping on keys, and more.
+
+            process_global_function_events (List[torch.autograd.profiler.FunctionEvent]).
+            It's a list of ``FunctionEvent`` elements. Every element is a profiling result
+            of an RPC request handling within the profiling range.
+        """
+        if not self.enabled:
+            return
+
+        process_global_events = _disable_server_process_global_profiler()
+
+        # Every element in this list is a thread profiling result from an RPC request handling.
+        process_global_function_events = []
+        for thread_local_events in process_global_events:
+            # Parse from ``Event``s to ``FunctionEvent``s.
+            thread_local_function_events = torch.autograd.profiler_legacy._parse_legacy_records(
+                thread_local_events
+            )
+            thread_local_function_events.sort(
+                key=lambda function_event: [
+                    function_event.time_range.start,
+                    -(function_event.time_range.end),
+                ]
+            )
+            process_global_function_events.append(thread_local_function_events)
+
+        flattened_function_events = list(
+            itertools.chain.from_iterable(process_global_function_events)
+        )
+        self.function_events = torch.autograd.profiler_util.EventList(
+            flattened_function_events,
+            use_cuda=self.use_cuda,
+            profile_memory=self.profile_memory,
+        )
+        self.function_events._build_tree()
+
+        self.process_global_function_events = process_global_function_events
+
+        return False
diff --git a/MLPY/Lib/site-packages/torch/distributed/run.py b/MLPY/Lib/site-packages/torch/distributed/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..120c9cec42a6117a1173c550bd0b3ee32edec5f5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/run.py
@@ -0,0 +1,883 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Superset of ``torch.distributed.launch``.
+
+``torchrun`` provides a superset of the functionality as ``torch.distributed.launch``
+with the following additional functionalities:
+
+1. Worker failures are handled gracefully by restarting all workers.
+
+2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
+
+3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
+
+.. note:: ``torchrun`` is a Python console script for the main module
+          ``torch.distributed.run``, declared in the ``entry_points`` configuration
+          in ``setup.py``. It is equivalent to invoking ``python -m torch.distributed.run``.
+
+
+Transitioning from torch.distributed.launch to torchrun
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
+for ``--use-env`` which is now deprecated. To migrate from ``torch.distributed.launch``
+to ``torchrun`` follow these steps:
+
+1.  If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment
+    variable, then simply omit the ``--use-env`` flag, e.g.:
+
+    +--------------------------------------------------------------------+--------------------------------------------+
+    |         ``torch.distributed.launch``                               |                ``torchrun``                |
+    +====================================================================+============================================+
+    |                                                                    |                                            |
+    | .. code-block:: shell-session                                      | .. code-block:: shell-session              |
+    |                                                                    |                                            |
+    |    $ python -m torch.distributed.launch --use-env train_script.py  |    $ torchrun train_script.py              |
+    |                                                                    |                                            |
+    +--------------------------------------------------------------------+--------------------------------------------+
+
+2.  If your training script reads local rank from a ``--local-rank`` cmd argument.
+    Change your training script to read from the ``LOCAL_RANK`` environment variable as
+    demonstrated by the following code snippet:
+
+    +-------------------------------------------------------+----------------------------------------------------+
+    |         ``torch.distributed.launch``                  |                    ``torchrun``                    |
+    +=======================================================+====================================================+
+    |                                                       |                                                    |
+    | .. code-block:: python                                | .. code-block:: python                             |
+    |                                                       |                                                    |
+    |                                                       |                                                    |
+    |    import argparse                                    |     import os                                      |
+    |    parser = argparse.ArgumentParser()                 |     local_rank = int(os.environ["LOCAL_RANK"])     |
+    |    parser.add_argument("--local-rank", type=int)      |                                                    |
+    |    args = parser.parse_args()                         |                                                    |
+    |                                                       |                                                    |
+    |    local_rank = args.local_rank                       |                                                    |
+    |                                                       |                                                    |
+    +-------------------------------------------------------+----------------------------------------------------+
+
+The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
+To take advantage of new ``torchrun`` features such as elasticity, fault tolerance, and error reporting,
+please refer to:
+
+* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
+* the rest of this page for more information on the features of ``torchrun``.
+
+
+Usage
+--------
+
+Single-node multi-worker
+++++++++++++++++++++++++++++++
+
+::
+
+    torchrun
+        --standalone
+        --nnodes=1
+        --nproc-per-node=$NUM_TRAINERS
+        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+Stacked single-node multi-worker
++++++++++++++++++++++++++++++++++++
+
+To run multiple instances (separate jobs) of single-node, multi-worker training on the
+same host, we need to make sure that each instance (job) is
+set up on different ports to avoid port conflicts (or worse, two jobs being merged
+as a single job). To do this, run with ``--rdzv-backend=c10d``
+and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
+For ``--nnodes=1``, it's often convenient to let ``torchrun`` pick a free random
+port automatically instead of manually assigning different ports for each run.
+
+::
+
+    torchrun
+        --rdzv-backend=c10d
+        --rdzv-endpoint=localhost:0
+        --nnodes=1
+        --nproc-per-node=$NUM_TRAINERS
+        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+
+Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+::
+
+    torchrun
+        --nnodes=$NUM_NODES
+        --nproc-per-node=$NUM_TRAINERS
+        --max-restarts=3
+        --rdzv-id=$JOB_ID
+        --rdzv-backend=c10d
+        --rdzv-endpoint=$HOST_NODE_ADDR
+        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
+the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
+node in your training cluster, but ideally you should pick a node that has a high bandwidth.
+
+.. note::
+   If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
+
+Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+::
+
+    torchrun
+        --nnodes=1:4
+        --nproc-per-node=$NUM_TRAINERS
+        --max-restarts=3
+        --rdzv-id=$JOB_ID
+        --rdzv-backend=c10d
+        --rdzv-endpoint=$HOST_NODE_ADDR
+        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
+
+``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
+the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
+node in your training cluster, but ideally you should pick a node that has a high bandwidth.
+
+.. note::
+   If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
+
+Note on rendezvous backend
+------------------------------
+
+For multi-node training you need to specify:
+
+1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
+2. ``--rdzv-backend``: An implementation of
+   :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
+3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
+   ``host:port``.
+
+Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy)  rendezvous backends are
+supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api
+enabled (e.g. ``--enable-v2``).
+
+.. warning::
+   ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
+   server. Our tests use etcd v3.4.3.
+
+.. warning::
+   For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally
+   equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be
+   removed in a future version.
+
+Definitions
+--------------
+
+1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.
+
+2. ``Worker`` - A worker in the context of distributed training.
+
+3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).
+
+4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.
+
+5. ``RANK`` - The rank of the worker within a worker group.
+
+6. ``WORLD_SIZE`` - The total number of workers in a worker group.
+
+7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.
+
+8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.
+
+9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
+   used by each node to join as a member of a particular worker group.
+
+10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
+    consistent key-value store.
+
+11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.
+
+A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
+all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``.
+
+Environment Variables
+----------------------
+
+The following environment variables are made available to you in your script:
+
+1. ``LOCAL_RANK`` -  The local rank.
+
+2. ``RANK`` -  The global rank.
+
+3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
+   running a single worker group per node, this is the rank of the node.
+
+4. ``ROLE_RANK`` -  The rank of the worker across all the workers that have the same role. The role
+   of the worker is specified in the ``WorkerSpec``.
+
+5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to
+   ``--nproc-per-node`` specified on ``torchrun``.
+
+6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
+
+7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role specified
+   in ``WorkerSpec``.
+
+8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize
+   the Torch Distributed backend.
+
+9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.
+
+10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.
+
+11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.
+
+12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).
+
+13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will
+    use the value of ``PYTHON_EXEC`` as the executable. ``sys.executable`` is used by default.
+
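+For example, a training script can read its topology directly from these variables
+(a minimal sketch):
+
+::
+
+    import os
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+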
+Deployment
+------------
+
+1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
+   passed as ``--rdzv-endpoint`` to the launcher script)
+
+2. Single-node multi-worker: Start the launcher on the host to start the agent process which
+   creates and monitors a local worker group.
+
+3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
+   participating in training.
+
+When using a job/cluster manager the entry point command to the multi-node job should be this
+launcher.
+
+Failure Modes
+---------------
+
+1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers
+   are stopped and restarted up to ``max_restarts``.
+
+2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
+   manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
+   are supported by the agent.
+
+3. Node failure: Same as agent failure.
+
+Membership Changes
+--------------------
+
+1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
+   stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
+   ``WORLD_SIZE``.
+
+2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
+   a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
+   ``WORLD_SIZE``.
+
+Important Notices
+--------------------
+
+1. This utility and multi-process distributed (single-node or
+   multi-node) GPU training currently only achieves the best performance using
+   the NCCL distributed backend. Thus NCCL backend is the recommended backend to
+   use for GPU training.
+
+2. The environment variables necessary to initialize a Torch process group are provided to you by
+   this module; there is no need for you to pass ``RANK`` manually.  To initialize a process group in your
+   training script, simply run:
+
+::
+
+ >>> # xdoctest: +SKIP("stub")
+ >>> import torch.distributed as dist
+ >>> dist.init_process_group(backend="gloo|nccl")
+
+3. In your training program, you can either use regular distributed functions
+   or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
+   training program uses GPUs for training and you would like to use
+   :func:`torch.nn.parallel.DistributedDataParallel` module,
+   here is how to configure it.
+
+::
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    model = torch.nn.parallel.DistributedDataParallel(model,
+                                                      device_ids=[local_rank],
+                                                      output_device=local_rank)
+
+Please ensure that the ``device_ids`` argument is set to the only GPU device id
+that your code will be operating on. This is generally the local rank of the
+process. In other words, ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``
+and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
+utility.
+
+
+4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
+   checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
+   for lost work.
+
+5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
+   nodes run the same number of local workers (per role).
+
+6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
+   different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
+   ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
+
+7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
+   ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
+
+8. It is recommended for your script to have the following structure:
+
+::
+
+  def main():
+    load_checkpoint(checkpoint_path)
+    initialize()
+    train()
+
+  def train():
+    for batch in iter(dataset):
+      train_step(batch)
+
+      if should_checkpoint:
+        save_checkpoint(checkpoint_path)
+
+9. (Recommended) On worker errors, this tool will summarize the details of the error
+   (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
+   is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
+   error summary print out, you must decorate your main entrypoint function in your
+   training script as shown in the example below. If not decorated, then the summary
+   will not include the traceback of the exception and will only contain the exitcode.
+   For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
+
+::
+
+  from torch.distributed.elastic.multiprocessing.errors import record
+
+  @record
+  def main():
+      # do train
+      pass
+
+  if __name__ == "__main__":
+      main()
+
+"""
+import logging
+import os
+import sys
+import uuid
+import importlib.metadata as metadata
+from argparse import REMAINDER, ArgumentParser
+from typing import Callable, List, Tuple, Type, Union, Optional, Set
+
+import torch
+from torch.distributed.argparse_util import check_env, env
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
+from torch.distributed.elastic.utils import macros
+from torch.distributed.elastic.utils.logging import get_logger
+from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+from torch.utils.backend_registration import _get_custom_mod_func
+
+log = get_logger(__name__)
+
+
+def get_args_parser() -> ArgumentParser:
+    """Parse the command line options."""
+    parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")
+
+    #
+    # Worker/node size related arguments.
+    #
+
+    parser.add_argument(
+        "--nnodes",
+        action=env,
+        type=str,
+        default="1:1",
+        help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
+    )
+    parser.add_argument(
+        "--nproc-per-node",
+        "--nproc_per_node",
+        action=env,
+        type=str,
+        default="1",
+        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
+    )
+
+    #
+    # Rendezvous related arguments
+    #
+
+    parser.add_argument(
+        "--rdzv-backend",
+        "--rdzv_backend",
+        action=env,
+        type=str,
+        default="static",
+        help="Rendezvous backend.",
+    )
+    parser.add_argument(
+        "--rdzv-endpoint",
+        "--rdzv_endpoint",
+        action=env,
+        type=str,
+        default="",
+        help="Rendezvous backend endpoint; usually in form <host>:<port>.",
+    )
+    parser.add_argument(
+        "--rdzv-id",
+        "--rdzv_id",
+        action=env,
+        type=str,
+        default="none",
+        help="User-defined group id.",
+    )
+    parser.add_argument(
+        "--rdzv-conf",
+        "--rdzv_conf",
+        action=env,
+        type=str,
+        default="",
+        help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
+    )
+    parser.add_argument(
+        "--standalone",
+        action=check_env,
+        help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
+        "on a free port. Useful when launching a single-node, multi-worker job. If specified, "
+        "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values "
+        "are ignored.",
+    )
+
+    #
+    # User-code launch related arguments.
+    #
+
+    parser.add_argument(
+        "--max-restarts",
+        "--max_restarts",
+        action=env,
+        type=int,
+        default=0,
+        help="Maximum number of worker group restarts before failing.",
+    )
+    parser.add_argument(
+        "--monitor-interval",
+        "--monitor_interval",
+        action=env,
+        type=float,
+        default=5,
+        help="Interval, in seconds, to monitor the state of workers.",
+    )
+    parser.add_argument(
+        "--start-method",
+        "--start_method",
+        action=env,
+        type=str,
+        default="spawn",
+        choices=["spawn", "fork", "forkserver"],
+        help="Multiprocessing start method to use when creating workers.",
+    )
+    parser.add_argument(
+        "--role",
+        action=env,
+        type=str,
+        default="default",
+        help="User-defined role for the workers.",
+    )
+    parser.add_argument(
+        "-m",
+        "--module",
+        action=check_env,
+        help="Change each process to interpret the launch script as a Python module, executing "
+        "with the same behavior as 'python -m'.",
+    )
+    parser.add_argument(
+        "--no-python",
+        "--no_python",
+        action=check_env,
+        help="Skip prepending the training script with 'python' - just execute it directly. Useful "
+        "when the script is not a Python script.",
+    )
+
+    parser.add_argument(
+        "--run-path",
+        "--run_path",
+        action=check_env,
+        help="Run the training script with runpy.run_path in the same interpreter."
+        " Script must be provided as an abs path (e.g. /abs/path/script.py)."
+        " Takes precedence over --no-python.",
+    )
+    parser.add_argument(
+        "--log-dir",
+        "--log_dir",
+        action=env,
+        type=str,
+        default=None,
+        help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
+        "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
+        "rdzv_id as the prefix).",
+    )
+    parser.add_argument(
+        "-r",
+        "--redirects",
+        action=env,
+        type=str,
+        default="0",
+        help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
+        "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
+        "stderr for local rank 1).",
+    )
+    parser.add_argument(
+        "-t",
+        "--tee",
+        action=env,
+        type=str,
+        default="0",
+        help="Tee std streams into a log file and also to console (see --redirects for format).",
+    )
+
+    parser.add_argument(
+        "--local-ranks-filter",
+        "--local_ranks_filter",
+        action=env,
+        type=str,
+        default="",
+        help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
+        "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to "
+        "log files saved via --redirects or --tee.",
+    )
+
+    #
+    # Backwards compatible parameters with caffe2.distributed.launch.
+    #
+
+    parser.add_argument(
+        "--node-rank",
+        "--node_rank",
+        type=int,
+        action=env,
+        default=0,
+        help="Rank of the node for multi-node distributed training.",
+    )
+    parser.add_argument(
+        "--master-addr",
+        "--master_addr",
+        default="127.0.0.1",
+        type=str,
+        action=env,
+        help="Address of the master node (rank 0) that is only used for static rendezvous. It should "
+        "be either the IP address or the hostname of rank 0. For single node multi-proc training "
+        "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
+        "`[0:0:0:0:0:0:0:1]`.",
+    )
+    parser.add_argument(
+        "--master-port",
+        "--master_port",
+        default=29500,
+        type=int,
+        action=env,
+        help="Port on the master node (rank 0) to be used for communication during distributed "
+        "training. It is only used for static rendezvous.",
+    )
+    parser.add_argument(
+        "--local-addr",
+        "--local_addr",
+        default=None,
+        type=str,
+        action=env,
+        help="Address of the local node. If specified, the given address is used for the connection. "
+        "Otherwise, the local node address is looked up, defaulting to the local "
+        "machine's FQDN.",
+    )
+
+    parser.add_argument(
+        "--logs-specs",
+        "--logs_specs",
+        default=None,
+        type=str,
+        help="torchrun.logs_specs group entrypoint name; the value must be a type of LogsSpecs. "
+        "Can be used to override the default logging behavior.",
+    )
+
+    #
+    # Positional arguments.
+    #
+
+    parser.add_argument(
+        "training_script",
+        type=str,
+        help="Full path to the (single GPU) training program/script to be launched in parallel, "
+        "followed by all the arguments for the training script.",
+    )
+
+    # Rest from the training program.
+    parser.add_argument("training_script_args", nargs=REMAINDER)
+
+    return parser
+
+
+def parse_args(args):
+    parser = get_args_parser()
+    return parser.parse_args(args)
+
+
+def parse_min_max_nnodes(nnodes: str):
+    arr = nnodes.split(":")
+
+    if len(arr) == 1:
+        min_nodes = max_nodes = int(arr[0])
+    elif len(arr) == 2:
+        min_nodes = int(arr[0])
+        max_nodes = int(arr[1])
+    else:
+        raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')  # noqa: E231
+
+    return min_nodes, max_nodes
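+
+# Examples of the accepted formats:
+#   parse_min_max_nnodes("2")   -> (2, 2)   # fixed size
+#   parse_min_max_nnodes("1:4") -> (1, 4)   # elastic range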
+
+
+def determine_local_world_size(nproc_per_node: str):
+    try:
+        logging.info("Using nproc_per_node=%s.", nproc_per_node)
+        return int(nproc_per_node)
+    except ValueError as e:
+        if nproc_per_node == "cpu":
+            num_proc = os.cpu_count()
+            device_type = "cpu"
+        elif nproc_per_node == "gpu":
+            if not torch.cuda.is_available():
+                raise ValueError("Cuda is not available.") from e
+            device_type = "gpu"
+            num_proc = torch.cuda.device_count()
+        elif nproc_per_node == torch._C._get_privateuse1_backend_name():
+            if not _get_custom_mod_func("is_available")():
+                raise ValueError(f"{nproc_per_node} is not available.") from e
+            device_type = nproc_per_node
+            num_proc = _get_custom_mod_func("device_count")()
+        elif nproc_per_node == "auto":
+            if torch.cuda.is_available():
+                num_proc = torch.cuda.device_count()
+                device_type = "gpu"
+            elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \
+                    _get_custom_mod_func("is_available")():
+                num_proc = _get_custom_mod_func("device_count")()
+                device_type = torch._C._get_privateuse1_backend_name()
+            else:
+                num_proc = os.cpu_count()
+                device_type = "cpu"
+        else:
+            raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
+
+        log.info(
+            "Using nproc_per_node=%s,"
+            " setting to %s since the instance "
+            "has %s %s",
+            nproc_per_node, num_proc, os.cpu_count(), device_type
+        )
+        return num_proc
+
+
+def get_rdzv_endpoint(args):
+    if args.rdzv_backend == "static" and not args.rdzv_endpoint:
+        return f"{args.master_addr}:{args.master_port}"  # noqa: E231
+    return args.rdzv_endpoint
+
+
+def get_use_env(args) -> bool:
+    """
+    Retrieve ``use_env`` from the args.
+
+    ``use_env`` is a legacy argument; if ``use_env`` is False, the
+    ``--node-rank`` argument will be transferred to all worker processes.
+    ``use_env`` is only used by the ``torch.distributed.launch`` and will
+    be deprecated in future releases.
+    """
+    if not hasattr(args, "use_env"):
+        return True
+    return args.use_env
+
+
+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+    """
+    Attempt to load the `torchrun.logs_specs` entrypoint with the key given by the `logs_specs_name` param.
+    Provides a plugin mechanism for custom implementations of LogsSpecs.
+
+    Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
+    Raises ValueError when an entrypoint for `logs_specs_name` can't be found.
+    """
+    logs_specs_cls = None
+    if logs_specs_name is not None:
+        eps = metadata.entry_points()
+        if hasattr(eps, "select"):  # >= 3.10
+            group = eps.select(group="torchrun.logs_specs")
+            if group.select(name=logs_specs_name):
+                logs_specs_cls = group[logs_specs_name].load()
+
+        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
+            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+                logs_specs_cls = entrypoint_list[0].load()
+
+        if logs_specs_cls is None:
+            raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+        logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+    else:
+        logs_specs_cls = DefaultLogsSpecs
+
+    return logs_specs_cls
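+
+# Registration sketch (package, module, and class names below are hypothetical):
+# a third-party distribution can expose a custom ``LogsSpecs`` implementation
+# through the "torchrun.logs_specs" entry-point group, e.g. in pyproject.toml:
+#
+#   [project.entry-points."torchrun.logs_specs"]
+#   my_specs = "my_pkg.logging:MyLogsSpecs"
+#
+# and select it at launch time with ``torchrun --logs-specs=my_specs ...``.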
+
+
+def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
+    # If ``args`` not passed, defaults to ``sys.argv[:1]``
+    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
+    assert 0 < min_nodes <= max_nodes
+    assert args.max_restarts >= 0
+
+    if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint:
+        log.warning(
+            "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
+            "is not specified."
+        )
+
+    nproc_per_node = determine_local_world_size(args.nproc_per_node)
+    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
+        omp_num_threads = 1
+        log.warning(
+            "\n*****************************************\n"
+            "Setting OMP_NUM_THREADS environment variable for each process to be "
+            "%s by default, to avoid your system being overloaded. "
+            "Please further tune the variable for optimal performance in "
+            "your application as needed. \n"
+            "*****************************************",
+            omp_num_threads
+        )
+        # This env variable will be passed down to the subprocesses
+        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
+
+    log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")
+
+    rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
+
+    if args.rdzv_backend == "static":
+        rdzv_configs["rank"] = args.node_rank
+
+    rdzv_endpoint = get_rdzv_endpoint(args)
+
+    ranks: Optional[Set[int]] = None
+    if args.local_ranks_filter:
+        try:
+            ranks = set(map(int, args.local_ranks_filter.split(",")))
+            assert ranks
+        except Exception as e:
+            raise Exception(
+                "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
+            ) from e
+
+    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+    logs_specs = logs_specs_cls(
+        log_dir=args.log_dir,
+        redirects=Std.from_str(args.redirects),
+        tee=Std.from_str(args.tee),
+        local_ranks_filter=ranks,
+    )
+
+    config = LaunchConfig(
+        min_nodes=min_nodes,
+        max_nodes=max_nodes,
+        nproc_per_node=nproc_per_node,
+        run_id=args.rdzv_id,
+        role=args.role,
+        rdzv_endpoint=rdzv_endpoint,
+        rdzv_backend=args.rdzv_backend,
+        rdzv_configs=rdzv_configs,
+        max_restarts=args.max_restarts,
+        monitor_interval=args.monitor_interval,
+        start_method=args.start_method,
+        log_line_prefix_template=log_line_prefix_template,
+        local_addr=args.local_addr,
+        logs_specs=logs_specs,
+    )
+
+    with_python = not args.no_python
+    cmd: Union[Callable, str]
+    cmd_args = []
+    use_env = get_use_env(args)
+    if args.run_path:
+        cmd = run_script_path
+        cmd_args.append(args.training_script)
+    else:
+        if with_python:
+            cmd = os.getenv("PYTHON_EXEC", sys.executable)
+            cmd_args.append("-u")
+            if args.module:
+                cmd_args.append("-m")
+            cmd_args.append(args.training_script)
+        else:
+            if args.module:
+                raise ValueError(
+                    "Don't use both the '--no-python' flag"
+                    " and the '--module' flag at the same time."
+                )
+            cmd = args.training_script
+    if not use_env:
+        cmd_args.append(f"--local-rank={macros.local_rank}")
+    cmd_args.extend(args.training_script_args)
+
+    return config, cmd, cmd_args
+
+
+def run_script_path(training_script: str, *training_script_args: str):
+    """
+    Run the provided `training_script` from within this interpreter.
+
+    Usage: `run_script_path("/abs/path/to/script.py", "--arg1", "val1")`
+    """
+    import runpy
+    import sys
+
+    sys.argv = [training_script] + [*training_script_args]
+    runpy.run_path(sys.argv[0], run_name="__main__")
+
+
+def run(args):
+    if args.standalone:
+        args.rdzv_backend = "c10d"
+        args.rdzv_endpoint = "localhost:0"
+        args.rdzv_id = str(uuid.uuid4())
+        log.info(
+            "\n**************************************\n"
+            "Rendezvous info:\n"
+            "--rdzv-backend=%s "
+            "--rdzv-endpoint=%s "
+            "--rdzv-id=%s\n"
+            "**************************************\n",
+            args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id
+        )
+
+    config, cmd, cmd_args = config_from_args(args)
+    elastic_launch(
+        config=config,
+        entrypoint=cmd,
+    )(*cmd_args)
+
+
+@record
+def main(args=None):
+    args = parse_args(args)
+    run(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/__init__.py b/MLPY/Lib/site-packages/torch/distributed/tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df247c72ce438d6c965b49ab9e173373ebbb74c2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa7a940b2ffbfd3ba8ba81862705dd7a046c19f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from torch.distributed.tensor.parallel.api import parallelize_module
+
+from torch.distributed.tensor.parallel.loss import loss_parallel
+from torch.distributed.tensor.parallel.style import (
+    ColwiseParallel,
+    ParallelStyle,
+    PrepareModuleInput,
+    PrepareModuleOutput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+
+__all__ = [
+    "ColwiseParallel",
+    "ParallelStyle",
+    "PrepareModuleInput",
+    "PrepareModuleOutput",
+    "RowwiseParallel",
+    "SequenceParallel",
+    "parallelize_module",
+    "loss_parallel"
+]
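+
+
+# Usage sketch (illustrative only; requires an initialized process group with one
+# GPU per rank, and the sub-module names "net1"/"net2" are assumptions):
+def _example_parallelize_sketch(model, world_size):
+    from torch.distributed.device_mesh import init_device_mesh
+
+    mesh = init_device_mesh("cuda", (world_size,))
+    # Shard the first linear column-wise and the second row-wise so the
+    # intermediate activation stays sharded across the mesh.
+    return parallelize_module(
+        model,
+        mesh,
+        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
+    )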
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecf7f940cefe3a648dd5fab924ac9205761cb1f7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ec8b7a19e2ed5c32fc076c4972adc439c115726
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db34013947773a89cda4c68b24dd4a307f93c954
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e22b75090c44c35cbe5056bf33d2d0a36656ec3d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5386f02fd77900a786e060c878320a241083a074
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6141aac0902acd4447093cd31bd36fd5a5fd44a4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5604920f9059e0c708ffa54f75cc71de2c2a2ea9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81b27f4864440fa1851aec646f7de7492c3a06d9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c697ad3051067a3628b20d975d20912e0151cb30
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7390282e5c4448c5fd93ad94b20d86875027d48b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py
@@ -0,0 +1,51 @@
+from functools import partial
+from typing import no_type_check, Optional, Tuple
+
+import torch
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.placement_types import DTensorSpec
+
+
+@no_type_check
+def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
+    if isinstance(grad, AsyncCollectiveTensor):
+        if compute_stream is not None:
+            with device_handle.stream(compute_stream):
+                grad = grad.wait()
+        else:
+            grad = grad.wait()
+
+    return grad
+
+
+def _flatten_tensor(
+    tensor: torch.Tensor,
+) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
+    if isinstance(tensor, DTensor):
+        tensor._local_tensor.requires_grad_()
+        return tensor._local_tensor, tensor._spec
+    return tensor, None
+
+
+@no_type_check
+def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None):
+    # unflatten is mainly called every time FSDP all-gathers parameters.
+    result = DTensor.from_local(
+        tensor,
+        spec.mesh,
+        spec.placements,
+        run_check=False,
+        shape=spec.shape,
+        stride=spec.stride,
+    )
+    if tensor.requires_grad:
+        # only register the hook if the tensor requires grad
+        tensor.register_hook(
+            partial(
+                sync_grad_hook,
+                device_handle=device_handle,
+                compute_stream=compute_stream,
+            )
+        )
+    return result
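+
+
+# A minimal sketch (hypothetical usage, not exercised in this module) of the flatten/unflatten
+# round trip these helpers provide for the data-parallel wrappers:
+#
+#     local, spec = _flatten_tensor(dtensor_param)    # DTensor -> plain local tensor
+#     if spec is not None:
+#         restored = _unflatten_tensor(local, spec)   # plain tensor -> DTensor again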
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..314b3c0e9768aaafc4f03cb4f08b85bebd998e61
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py
@@ -0,0 +1,60 @@
+import warnings
+from typing import Tuple, Union
+
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.device_mesh import _mesh_resources
+try:
+    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
+except Exception:
+    def is_torchdynamo_compiling():  # type: ignore[misc]
+        return False
+
+LayoutsType = Union[Placement, Tuple[Placement, ...]]
+
+
+def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
+    """
+    Emit a deprecation warning for ``func_name``, appending ``extra_msg``.
+
+    The warning is skipped when running under torchdynamo, since ``warnings.warn``
+    is not yet supported there (see the TODO below).
+    """
+    # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo.
+    if not is_torchdynamo_compiling():
+        warnings.warn(f"{func_name} is deprecated and will be removed soon. {extra_msg}")
+
+
+def _validate_tp_mesh_dim(
+    device_mesh: DeviceMesh,
+) -> None:
+    """
+    Check that the TP mesh dimension is valid.
+
+    Args:
+        device_mesh (:class:`DeviceMesh`):
+            The `device_mesh` where we perform
+            Tensor Parallelism on.
+
+    Raises:
+        ``ValueError`` or ``RuntimeError`` if the mesh is not a valid 1D TP mesh.
+    """
+    if device_mesh.ndim > 1:
+        raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D! "
+                         "If you have a 2-D or N-D device_mesh, consider passing in device_mesh[\"tp\"]")
+
+    parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
+    if parent_mesh:
+        if parent_mesh.ndim != 2:
+            raise RuntimeError(
+                f"Found TP device_mesh has a parent mesh with dims {parent_mesh.ndim}. "
+                "Currently we only support 2D TP composition with DP."
+            )
+
+        tp_mesh_dim = _mesh_resources.get_parent_mesh_dim(device_mesh)
+        if tp_mesh_dim != 1:
+            raise RuntimeError(
+                f"Found TP device_mesh on the {tp_mesh_dim} dimension of its parent mesh. "
+                "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh."
+            )
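+
+
+# Illustrative sketch of what this check accepts (assuming an 8-GPU host and
+# ``init_device_mesh`` from ``torch.distributed.device_mesh``; not executed here):
+#
+#     mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
+#     _validate_tp_mesh_dim(mesh_2d["tp"])   # ok: 1D TP mesh, innermost dim of a 2D parent
+#     _validate_tp_mesh_dim(mesh_2d)         # raises: the TP mesh must be 1D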
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/api.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fa5497a8a7ca2dfec302316dfa33d66d0851afe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/api.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import Dict, Union
+
+import torch
+import torch.distributed._tensor.random as random
+import torch.nn as nn
+from torch.distributed._tensor import (
+    DeviceMesh,
+)
+from torch.distributed._tensor.random import (
+    is_rng_supported_mesh,
+    TensorParallelRNGTracker,
+)
+from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
+from torch.distributed.tensor.parallel.style import (
+    ParallelStyle,
+)
+
+
+__all__ = [
+    "parallelize_module",
+]
+
+
+def parallelize_module(  # type: ignore[return]
+    module: nn.Module,
+    device_mesh: DeviceMesh,
+    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
+) -> nn.Module:
+    """
+    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
+
+    We parallelize the module or sub_modules based on a parallelize_plan. The parallelize_plan contains
+    :class:`ParallelStyle`, which indicates how the user wants the module or sub_module
+    to be parallelized.
+
+    Users can also specify a different parallel style per module fully qualified name (FQN).
+
+    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
+    slice the DeviceMesh to a 1-D sub DeviceMesh first, then pass it to this API (i.e. ``device_mesh[\"tp\"]``).
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be parallelized.
+        device_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
+            The plan used to parallelize the module. It can be either a
+            :class:`ParallelStyle` object which contains how
+            we prepare input/output for Tensor Parallelism or it can be a
+            dict of module FQN and its corresponding :class:`ParallelStyle` object.
+    Return:
+        The parallelized :class:`nn.Module` object.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>>
+        >>> # Define the module.
+        >>> m = Model(...)
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
+        >>>
+
+    .. note:: For complex module architectures like Attention or MLP layers, we recommend composing
+        different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
+        them as the parallelize_plan to achieve the desired sharding computation.
+    """
+    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
+
+    _validate_tp_mesh_dim(device_mesh)
+
+    # instantiate a TP RNG state tracker if it's not there
+    if is_rng_supported_mesh(device_mesh) and not isinstance(
+        random._rng_tracker, TensorParallelRNGTracker
+    ):
+        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
+        # TODO: we should allow user to pass in the default seed from a config
+        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
+        # By default we execute random ops in non-tensor-parallel region. If users want
+        # to execute in tensor-parallel region, they can manually set this field to True
+        # after parallelizing the model.
+        random._rng_tracker.distribute_region_enabled = False
+
+    if isinstance(parallelize_plan, ParallelStyle):
+        return parallelize_plan._apply(module, device_mesh)
+    elif isinstance(parallelize_plan, dict):
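+        # For nested FQNs such as "layers.0.attn.wq" (hypothetical name), walk down to the
+        # immediate parent module so the parallelized child can be re-registered under it.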
+        for module_path, parallelize_style in parallelize_plan.items():
+            sub_module = module.get_submodule(module_path)
+            parent_module = module
+            if "." in module_path:
+                parent_module_path = ".".join(module_path.split(".")[:-1])
+                parent_module = module.get_submodule(parent_module_path)
+                module_path = module_path.split(".")[-1]
+            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
+                module_path,
+                parallelize_module(  # type: ignore[arg-type]
+                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
+                ),
+            )
+        return module
+    else:
+        raise RuntimeError(  # pyre-ignore[7]
+            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
+            f" parallelize_plan, {type(parallelize_plan)} found!"
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba58ac5096e2a88a95beae4c3e89105f0162eff1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py
@@ -0,0 +1,96 @@
+from typing import Any, List, Tuple
+
+import torch.nn as nn
+from torch.distributed.tensor.parallel._data_parallel_utils import (
+    _flatten_tensor,
+    _unflatten_tensor,
+)
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _get_submodule_n_params(module: nn.Module, path: str):
+    """
+    Get the parent submodule and the direct parameter path within that submodule.
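+
+    e.g. ``"net.0.weight"`` (hypothetical name) resolves to ``(module.get_submodule("net.0"), "weight")``.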
+    """
+    if "." in path:
+        path_list = path.split(".")
+        parent_module_path = ".".join(path_list[:-1])
+        module = module.get_submodule(parent_module_path)
+        path = path_list[-1]
+    return module, path
+
+
+def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
+    """
+    Update parameters within the module
+    """
+    for item in param_list:
+        parent_module, module_path, t = item
+        assert hasattr(parent_module, module_path)
+        delattr(parent_module, module_path)
+        setattr(parent_module, module_path, t)
+
+
+def _reconstruct_dtensor(module: nn.Module, _input: Any):
+    """
+    Reconstruct DTensor parameters from local tensors.
+    """
+    param_list = []
+    # TODO: add perf optimizations to this iteration
+    for name, t in module.named_parameters():
+        if hasattr(t, "_st_info"):
+            dtensor = _unflatten_tensor(t, t._st_info)
+            param_list.append((*_get_submodule_n_params(module, name), dtensor))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _localize_dtensor(module: nn.Module, *_: Any):
+    """
+    Convert DTensor parameters to local tensors
+    """
+    param_list = []
+    for name, param in module.named_parameters():
+        t, sharding_info = _flatten_tensor(param)
+        if sharding_info is not None:
+            t = nn.Parameter(t)
+            t._st_info = sharding_info  # type: ignore[attr-defined]
+            param_list.append((*_get_submodule_n_params(module, name), t))
+    _update_module_param(param_list)  # type: ignore[arg-type]
+
+
+def _pre_dp_module_transform(module: nn.Module):
+    """
+    Enable the composability between Tensor Parallelism (TP) and Data
+    Parallelism (DP) in PyTorch when using DDP. We need to convert parameters which
+    are DTensors to local tensors before wrapping them with the data parallelism API.
+    We then register two hooks: one to convert local tensors back to DTensors
+    pre-forward, and one to convert DTensors back to local tensors after forward. By
+    integrating this way, we avoid any special handling of DTensor parameters by DDP
+    and get DTensor's gradients propagated back to DP, e.g. the gradient buckets of DDP.
+
+    For now, this API only works with ``DistributedDataParallel``. It will later support
+    other DP methods such as FSDP.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module which has been applied TP on.
+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> from torch.nn.parallel import DistributedDataParallel as DDP
+        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
+        >>>
+        >>> # Define the module.
+        >>> m = Model(...)
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>> parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
+        >>> _pre_dp_module_transform(m)
+        >>> m = DDP(m)
+        >>>
+    """
+
+    _localize_dtensor(module, None, None)
+    # TODO: To add test cases and ensure that it works for nested modules
+    module.register_forward_pre_hook(_reconstruct_dtensor)
+    module.register_forward_hook(_localize_dtensor)
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d67798f77c04096c6541d4343c5d8b0dfc97b00
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py
@@ -0,0 +1,391 @@
+import copy
+from typing import Any, cast, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+import torch.distributed._shard.sharding_spec as shard_spec
+import torch.distributed.distributed_c10d as c10d
+from torch.distributed._shard.sharded_tensor import (
+    Shard,
+    ShardedTensor,
+    ShardedTensorMetadata,
+    TensorProperties,
+)
+
+from torch.distributed._shard.sharding_spec import ShardMetadata
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
+from torch.distributed.device_mesh import _mesh_resources
+
+from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
+from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.remote_device import _remote_device
+from torch.distributed.tensor.parallel._data_parallel_utils import (
+    _flatten_tensor,
+    _unflatten_tensor,
+)
+
+__all__ = ["DTensorExtensions"]
+
+
+def _get_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
+    device_mesh = tensor.device_mesh
+    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
+
+    placement = tensor.placements[0]
+    offsets = [0] * len(tensor.size())
+    num_chunks = device_mesh.size(mesh_dim=0)
+
+    if tensor.placements[0].is_shard():
+        shard_dim = cast(DShard, placement).dim
+        chunk_size = tensor.size(shard_dim) // num_chunks
+        offsets[shard_dim] = chunk_size
+
+    return (torch.Size(offsets), tensor._local_tensor.size())
+
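+# Illustrative example (assuming an (8, 4) tensor placed as Shard(0) over 4 ranks):
+# _get_box returns offsets (2, 0) and local size (2, 4); _get_box_for(t, i) scales the
+# offsets by the rank index i, so rank 3's box starts at (6, 0).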
+
+def _get_box_for(tensor: DTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
+    offsets, size = _get_box(tensor)
+    return (torch.Size([val * idx for val in offsets]), size)
+
+
+def _get_local_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
+    device_mesh = tensor.device_mesh
+    coord = device_mesh.get_coordinate()
+    assert coord is not None
+    return _get_box_for(tensor, coord[0])
+
+
+def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
+    mesh = dt.device_mesh
+    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
+
+    offsets, sizes = _get_local_box(dt)
+    return ShardMetadata(
+        shard_offsets=list(offsets),
+        shard_sizes=list(sizes),
+        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
+    )
+
+
+def _create_sharded_tensor_md_from_dt(
+    dt: DTensor, dt_pg: c10d.ProcessGroup
+) -> ShardedTensorMetadata:
+    # This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
+    # and yet has only one valid shard for the current rank.
+
+    shards_md = []
+    my_rank = dist.get_rank(dt_pg)
+    scapegoat_rank = 0 if my_rank > 0 else 1
+
+    if dt.placements[0].is_shard():
+        shard_count = dt_pg.size()
+    else:
+        shard_count = 1
+
+    for i in range(shard_count):
+        offsets, sizes = _get_box_for(dt, i)
+        shards_md.append(
+            ShardMetadata(
+                shard_offsets=list(offsets),
+                shard_sizes=list(sizes),
+                placement=(
+                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
+                ),
+            )
+        )
+
+    return ShardedTensorMetadata(
+        shards_metadata=shards_md,
+        size=dt.size(),
+        tensor_properties=TensorProperties(
+            dtype=dt.dtype,
+            layout=dt.layout,
+            requires_grad=dt.requires_grad,
+            # ignore memory_format and pin_memory as those are not supported by DT
+        ),
+    )
+
+
+def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
+    mesh = dt.device_mesh
+    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
+    dim_groups = mesh.get_group()
+    assert isinstance(dim_groups, list)
+    return dim_groups[0]
+
+
+def _rewrite_spec_if_needed(
+    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
+) -> shard_spec.ShardingSpec:
+    """
+    Rewrite ``spec`` to match the device of ``tensor``.
+
+    FSDP.sharded_optim_state_dict sneakily ships optimizer state to CPU, so if the original ShardingSpec
+    produces CUDA metadata, ShardedTensor construction fails.
+    """
+    if not isinstance(spec, ChunkShardingSpec):
+        return spec
+
+    # let's see if we need to rewrite any placements
+    rewrite = False
+    for p in spec.placements:
+        p = cast(_remote_device, p)
+        if p.rank() == rank and p.device() != tensor.device:
+            rewrite = True
+            break
+    if rewrite:
+        spec = copy.deepcopy(spec)
+        for i, placement in enumerate(spec.placements):
+            placement = cast(_remote_device, placement)
+            if placement.rank() == rank and placement.device() != tensor.device:
+                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
+
+    return spec
+
+
+def _chunk_tensor(
+    tensor: torch.Tensor,
+    rank: int,
+    world_size: int,
+    num_devices_per_node: int,
+    pg: dist.ProcessGroup,
+) -> torch.Tensor:
+    if type(tensor) is ShardedTensor:
+        assert len(tensor.local_shards()) == 1
+
+        inner_param = tensor.local_tensor()
+        inner_st = _create_chunk_sharded_tensor(
+            inner_param,
+            rank,
+            world_size,
+            num_devices_per_node,
+            pg,
+        )
+
+        outer_local_shard = tensor.local_shards()[0]
+        shards: List[Shard] = [
+            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
+        ]
+        st_meta = copy.deepcopy(tensor.metadata())
+        st_meta.tensor_properties.requires_grad = False
+
+        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
+            shards,
+            sharded_tensor_metadata=st_meta,
+            process_group=tensor._process_group,
+            init_rrefs=False,
+        )
+        return st_outer
+    elif type(tensor) is DTensor:
+        device_mesh = tensor.device_mesh
+        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
+
+        inner_param = tensor._local_tensor
+
+        inner_st = _create_chunk_sharded_tensor(
+            inner_param,
+            rank,
+            world_size,
+            torch.cuda.device_count(),
+            pg,
+        )
+
+        dt_pg = _get_dt_pg(tensor)
+        # We do this differently here: we wrap the inner ST as the single local shard of an outer ST
+        shards = [
+            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
+        ]
+
+        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
+        st_meta.tensor_properties.requires_grad = False
+
+        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
+            shards,
+            sharded_tensor_metadata=st_meta,
+            process_group=dt_pg,
+            init_rrefs=False,
+        )
+
+        return st_outer
+    else:
+        return _create_chunk_sharded_tensor(
+            tensor,
+            rank,
+            world_size,
+            num_devices_per_node,
+            pg,
+        )
+
+
+def _chunk_dtensor(
+    tensor: torch.Tensor,
+    rank: int,
+    device_mesh: DeviceMesh,
+) -> DTensor:
+    """
+    Shard a tensor to chunks along the first dimension.
+
+    The local rank gets its corresponding chunk as the local tensor to create a DTensor.
+    """
+    parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
+    if parent_mesh is None:
+        raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
+    if parent_mesh.ndim < 2:
+        raise RuntimeError(
+            f"Found parent device_mesh of ndim={parent_mesh.ndim}, "
+            "but meshes must be at least 2D."
+        )
+
+    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
+    tensor = tensor.clone().detach()
+
+    # When a layer is not involved in TP, the tensor will not be a DTensor.
+    # e.g. when a layer is not specified in the parallelize_plan, TP will have no effect on the layer;
+    # e.g. when you do PairwiseParallel on a 3-layer model, TP will have no effect on the third layer.
+    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
+
+        # For tensors, it is replicated across tp dimension and sharded across FSDP dimension.
+        # TP is the inner dimension and FSDP is the outer dimension.
+        # Therefore, shard placements for tensor is (Shard(0), Replicate()).
+        replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+        shard_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+        shard_placements[0] = DShard(0)  # type: ignore[call-overload]
+
+        return DTensor.from_local(
+            tensor, parent_mesh, replicate_placements
+        ).redistribute(
+            device_mesh=parent_mesh,
+            placements=shard_placements,
+        )
+
+    else:
+        tp_placements = tensor.placements
+        tp_placement = tp_placements[0]
+
+        tensor = tensor.to_local()
+
+        # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension.
+        # TP is the inner dimension and FSDP is the outer dimension.
+        # Therefore, shard placements for tensor is (Shard(0), tp_placement).
+        # For higher dimensional meshes, it is replicated across other dimensions. For example, with
+        # HSDP the shard placements for tensor is (Replicate, Shard(0), tp_placement).
+        replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)]
+        replicate_placements[-1] = tp_placement  # type: ignore[call-overload]
+        shard_placements = [Replicate() for i in range(parent_mesh.ndim)]  # type: ignore[misc]
+        shard_placements[-2] = DShard(0)  # type: ignore[call-overload]
+        shard_placements[-1] = tp_placement  # type: ignore[call-overload]
+
+        return DTensor.from_local(
+            tensor, parent_mesh, replicate_placements
+        ).redistribute(
+            device_mesh=parent_mesh,
+            placements=shard_placements,
+        )
+
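+# Illustrative resulting placements (assuming a 2D ("dp", "tp") parent mesh):
+#   plain tensor (not touched by TP)         -> DTensor with placements (Shard(0), Replicate())
+#   TP DTensor, e.g. Shard(-1) on the TP dim -> DTensor with placements (Shard(0), Shard(-1))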
+
+def _pre_load_state_dict(
+    tensor: torch.Tensor,
+) -> Tuple[torch.Tensor, List[Shard]]:
+    shards = cast(ShardedTensor, tensor).local_shards()
+    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
+        inner_tensor = shards[0].tensor
+        shards = inner_tensor.local_shards()  # pyre-ignore[16]
+        tensor = inner_tensor
+
+    return (tensor, shards if len(shards) > 0 else [])
+
+
+def _all_gather_dtensor(
+    tensor: DTensor,
+    parent_mesh: Optional[DeviceMesh],
+) -> torch.Tensor:
+    """All gather a DTensor in its FSDP dimension and return the local tensor."""
+    assert parent_mesh == tensor.device_mesh
+
+    placements = list(copy.deepcopy(tensor.placements))
+    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
+    # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
+    for i in range(0, len(placements) - 1):
+        placements[i] = Replicate()
+    tensor = tensor.redistribute(
+        device_mesh=tensor.device_mesh,
+        placements=placements,
+    )
+
+    return tensor.to_local()
+
+
+class DTensorExtensions(FSDPExtensions):
+    """
+    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.
+
+    This is the implementation for FSDPExtensions defined in
+    https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
+    """
+    def __init__(self, device_handle) -> None:
+        super().__init__()
+        self.compute_stream = None
+        self.device_handle = device_handle
+        # we have to disable dynamo this way (by wrapping the bound method) rather than with the
+        # decorator, as the decorator would trigger a build failure with torch deploy...
+        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]
+
+    def pre_flatten_transform(
+        self,
+        tensor: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        return _flatten_tensor(tensor)
+
+    def post_unflatten_transform(
+        self, tensor: torch.Tensor, param_extension: Any
+    ) -> torch.Tensor:
+        stream = self.compute_stream or self.device_handle.current_stream()
+        with self.device_handle.stream(stream):
+            # At runtime we put the unflatten-tensor call on the compute stream since
+            # the unflattened tensor might be used in fwd/bwd computations where we
+            # need to sync properly.
+            # TODO: this is a short term fix and we should make the get_unflat_views
+            # directly happen in the compute stream.
+            result = _unflatten_tensor(
+                tensor,
+                param_extension,
+                device_handle=self.device_handle,
+                compute_stream=self.compute_stream
+            )
+            _set_fsdp_flattened(result)
+            return result
+
+    def chunk_tensor(
+        self,
+        tensor: torch.Tensor,
+        rank: int,
+        world_size: int,
+        num_devices_per_node: int,
+        pg: dist.ProcessGroup,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
+
+    def chunk_dtensor(
+        self,
+        tensor: torch.Tensor,
+        rank: int,
+        device_mesh: DeviceMesh,
+    ) -> torch.Tensor:
+        return _chunk_dtensor(tensor, rank, device_mesh)
+
+    def pre_load_state_dict_transform(
+        self,
+        tensor: torch.Tensor,
+    ) -> Tuple[torch.Tensor, List[Shard]]:
+        return _pre_load_state_dict(tensor)
+
+    def all_gather_dtensor(
+        self,
+        tensor: DTensor,
+        parent_mesh: Optional[DeviceMesh],
+    ) -> torch.Tensor:
+        return _all_gather_dtensor(tensor, parent_mesh)
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba61692411591d940c5206f6c2873a7378866a13
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from functools import partial
+from typing import Any, Optional, Tuple
+
+import torch
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+
+__all__ = [
+    "input_reshard",
+]
+
+
+def input_reshard(
+    module: torch.nn.Module,
+    tp_device_mesh: DeviceMesh,
+    input_reshard_dim: Optional[int] = None,
+) -> torch.nn.Module:
+    """
+    Register hooks on an nn.Module for input resharding, enabling sharding and restoration during backward computation.
+
+    The registered hooks shard the saved input per the given `tp_device_mesh` and
+    `input_reshard_dim` and restore it when recomputing the activations in the
+    backward pass. The reason we can do this is that for Tensor Parallel (TP),
+    the inputs are the same across all TP ranks.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be registered with input resharding.
+        tp_device_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for Tensor Parallel.
+        input_reshard_dim (Optional[int]):
+            The dimension of where we perform the sharding
+            of input. If set None, there is no sharding of input.
+            Default: None
+
+    Return:
+        A :class:`nn.Module` object registered with TP input resharding.
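+
+    Example::
+        >>> # xdoctest: +SKIP("distributed")
+        >>> # A minimal sketch, assuming a 1D TP mesh and that inputs saved for backward
+        >>> # should be resharded on dim 0 ("block" is a hypothetical module).
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>> block = input_reshard(block, tp_mesh, input_reshard_dim=0)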
+    """
+    cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None
+
+    def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
+        saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
+            partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
+            partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
+        )
+        saved_tensor_hooks.__enter__()
+        nonlocal cx
+        cx = saved_tensor_hooks  # type: ignore[name-defined]
+
+    def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
+        nonlocal cx
+        cx.__exit__()  # type: ignore[name-defined, union-attr]
+
+    if input_reshard_dim is None:
+        return module
+    module.register_forward_pre_hook(input_reshard_forward_pre_hook)
+    module.register_forward_hook(input_reshard_backward_hook)
+    return module
+
+
+def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any:  # noqa: D401
+    """Hook function called after FWD to shard input."""
+    if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
+        return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
+    elif (
+        not isinstance(x, DTensor)
+        and isinstance(x, torch.Tensor)
+        and x.numel() >= mesh.size()
+    ):
+        return (
+            DTensor.from_local(x, device_mesh=mesh)
+            .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
+            .to_local()
+        )
+    else:
+        return x
+
+
+def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor:  # noqa: D401
+    """Hook function called before activation recomputing in BWD to restore input."""
+    if (
+        isinstance(x, DTensor)
+        and len(x._spec.placements) == 1
+        and x._spec.placements[0].is_shard()
+    ):
+        return x.redistribute(device_mesh=mesh, placements=[Replicate()])
+    elif (
+        not isinstance(x, DTensor)
+        and isinstance(x, torch.Tensor)
+        and x.numel() >= mesh.size()
+    ):
+        return (
+            DTensor.from_local(
+                x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
+            )
+            .redistribute(device_mesh=mesh, placements=[Replicate()])
+            .to_local()
+        )
+    else:
+        return x
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/loss.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..44df123f23d48281e9f45cfa82d516ce5a61a80b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/loss.py
@@ -0,0 +1,484 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import contextlib
+from typing import cast, Dict, Optional, Tuple
+
+import torch
+import torch._prims_common as utils
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+from torch import Tensor
+from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops.math_ops import (
+    _skip_dim,
+    Reduction,
+    replicate_reduction_dims,
+)
+from torch.distributed._tensor.placement_types import Placement, TensorMeta
+from torch.distributed.device_mesh import DeviceMesh
+
+aten = torch.ops.aten
+
+
+__all__ = ["loss_parallel"]
+
+
+@contextlib.contextmanager
+def loss_parallel():
+    """
+    A context manager that enables loss parallelism, where efficient parallelized loss computation
+    can be performed when the input is sharded on the class dimension. Currently only the cross-entropy
+    loss is supported.
+
+    Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
+    :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
+    The corresponding ``backward()`` call, if any, also needs to happen under this context manager.
+
+    Args:
+        input (:class:`DTensor`):
+            Input logits. Assumed to be sharded on the class dimension.
+        target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
+            Must be ground truth class indices (class probabilities currently not supported).
+            Assumed to be replicated across the ``DeviceMesh``.
+        weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
+            If given, assumed to be replicated across the ``DeviceMesh``.
+        label_smoothing:
+            Currently not supported.
+
+    Returns:
+        A replicated :class:`DTensor`.
+
+    Example:
+        A sharded DTensor is manually created here to showcase the usage.
+        In practice, it is usually the output of a TP module.
+
+        >>> # xdoctest: +SKIP("distributed")
+        >>> from torch.distributed.tensor.parallel import loss_parallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> device_mesh = init_device_mesh("cuda", (8,))
+        >>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
+        >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
+        >>> target = torch.randint(16, (4,), device="cuda")
+        >>> with loss_parallel():
+        >>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
+        >>>     loss.backward()
+        >>> ...
+    """
+    _enable_custom_loss_ops()
+
+    yield
+
+    _disable_custom_loss_ops()
+
+
+# Currently only needs to support one dimensional DeviceMesh; in general return
+# the mesh_dim with placements[mesh_dim].is_shard(dim)
+def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
+    if not len(placements) == 1:
+        raise ValueError(
+            "Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
+        )
+    if not placements[0].is_shard(dim):
+        raise ValueError(
+            f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
+        )
+    return 0
+
+
+def _cast_to_dtensor(
+    tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
+) -> DTensor:
+    if isinstance(tensor, DTensor):
+        if tensor.placements == placements:
+            return tensor
+        else:
+            raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
+    elif isinstance(tensor, torch.Tensor):
+        return DTensor.from_local(
+            tensor, device_mesh=mesh, placements=placements, run_check=False
+        )
+    else:
+        raise TypeError(f"Unsupported type {type(tensor)}")
+
+
+def _propagate_tensor_meta(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> TensorMeta:
+    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
+    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
+        op_info.schema
+    )
+    if isinstance(tensor_meta, TensorMeta):
+        return tensor_meta
+    elif isinstance(tensor_meta, tuple):
+        return tensor_meta[0]
+    else:
+        raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")
+
+
+# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
+# with all_reduce manually inserted to perform distributed computation.
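+#
+# With the class dimension sharded, each rank only holds a slice of ``x``; since
+# log_softmax(x) = x - m - log(sum_j exp(x_j - m)) with m = max_j x_j, both the max and the
+# sum of exponentials are all-reduced across the sharded dimension to obtain the global result.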
+def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
+    x = x.contiguous()
+    if half_to_float:
+        assert x.dtype == torch.half
+    computation_dtype, result_dtype = utils.elementwise_dtypes(
+        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+    x = x.to(computation_dtype)
+    if x.numel() == 0:
+        shifted = x
+    else:
+        x_max = torch.amax(x, dim, keepdim=True)
+        x_max = funcol.all_reduce(
+            x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim)
+        )
+        shifted = x - x_max
+    shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True)
+    shifted_sumexp = funcol.all_reduce(
+        shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim)
+    )
+    shifted_logsumexp = torch.log(shifted_sumexp)
+    result = shifted - shifted_logsumexp
+    if not half_to_float:
+        result = result.to(result_dtype)
+    return result
+
+
+def _log_softmax_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    x = cast(DTensor, args[0])
+    dim = cast(int, args[1])
+    half_to_float = cast(bool, args[2])
+
+    spec = x._spec
+    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)
+
+    output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs)
+
+    res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim)
+
+    return DTensor(
+        res,
+        spec.mesh,
+        spec.placements,
+        shape=output_tensor_meta.shape,
+        dtype=output_tensor_meta.dtype,
+        requires_grad=res.requires_grad,
+        stride=output_tensor_meta.stride,
+    )
+
+
+# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the
+# _log_softmax_backward_handler does not actually do any computation.
+def _log_softmax_backward_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    grad_output = cast(DTensor, args[0])
+    input_dtype = cast(torch.dtype, args[3])
+    return grad_output.to(input_dtype)
+
+
+# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward,
+# with customized communication inserted to perform distributed computation.
+def _nll_loss_forward(
+    x: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    local_weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    channel_dim_size: int,
+    mesh: DeviceMesh,
+    mesh_dim: int,
+) -> Tuple[Tensor, Tensor]:
+    n_dims = x.dim()
+    channel_dim = 1
+    if n_dims < 2:
+        channel_dim = 0
+
+    def _weight_view(weight: Tensor) -> Tensor:
+        if n_dims > 1:
+            shape = [
+                1,
+            ] * n_dims
+            shape[channel_dim] = weight.shape[0]
+            w = weight.view(shape)
+        else:
+            w = weight
+        return w
+
+    if weight is not None:
+        w = _weight_view(weight)
+        assert local_weight is not None
+        local_w = _weight_view(local_weight)
+        x = x * local_w
+    safe_target = torch.where(target != ignore_index, target, 0)
+    safe_target_ = safe_target.unsqueeze(channel_dim)
+
+    # The following code block is a distributed version of
+    # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
+    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    safe_target_partial_ = partial_placement._partition_value(
+        safe_target_, mesh, mesh_dim
+    )
+    result_partial = torch.gather(x, channel_dim, safe_target_partial_)
+    # an all_reduce happens here
+    result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim)
+    result = -result_reduced.squeeze(channel_dim)
+
+    result = torch.where(target != ignore_index, result, 0)
+
+    if reduction == Reduction.NONE.value and n_dims > 1:
+        total_weight = x.new_full((), 0.0)
+        return result, total_weight
+
+    if weight is not None:
+        new_shape = list(x.shape)
+        new_shape[channel_dim] = -1
+        w = w.expand(new_shape)
+        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
+        wsum = torch.where(target != ignore_index, wsum, 0)
+        total_weight = wsum.sum()
+    else:
+        total_weight = (target != ignore_index).sum().to(x)
+
+    # NOTE: this is correct only on 1D DeviceMesh; o/w additional
+    #       all-reduce on result and total_weight is needed
+    if reduction == Reduction.SUM.value:
+        result = result.sum()
+    elif reduction == Reduction.MEAN.value:
+        result = result.sum() / total_weight
+
+    return result, total_weight
+
+
+def _nll_loss_forward_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    x = cast(DTensor, args[0])
+    target = args[1]
+    weight = args[2]
+    reduction = cast(int, args[3])
+    ignore_index = cast(int, args[4])
+
+    channel_dim = 1 if x.dim() >= 2 else 0
+    channel_dim_size = x.shape[channel_dim]
+    spec = x._spec
+    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
+
+    # Check user input: if target and weight are not DTensors, convert them to DTensors;
+    # if they are DTensors, check that they have the desired placements.
+    target_placements = _skip_dim(
+        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
+    )
+    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
+    target = _cast_to_dtensor(target, target_placements, spec.mesh)
+    local_weight = None
+    if weight is not None:
+        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
+        # For local computation, both (replicated) weight and (sharded) local_weight
+        # are needed in _nll_loss_forward(). local_weight is generated here using
+        # DTensor API, without incurring any communication.
+        sharded_placements = [
+            Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)
+        ]
+        local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor
+        assert local_weight.shape[0] == x._local_tensor.shape[channel_dim]
+
+    if reduction == Reduction.NONE.value:
+        output_placements = target_placements
+    else:
+        output_placements = all_replicate_placements
+
+    # tensor inputs to _propagate_tensor_meta need to be DTensors
+    args = list(args)
+    args[1], args[2] = target, weight
+    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
+
+    result, total_weight = _nll_loss_forward(
+        x._local_tensor,
+        target._local_tensor,
+        weight._local_tensor if weight is not None else None,
+        local_weight,
+        reduction,
+        ignore_index,
+        channel_dim_size,
+        spec.mesh,
+        mesh_dim,
+    )
+
+    return (
+        DTensor(
+            result,
+            spec.mesh,
+            output_placements,
+            shape=output_tensor_meta.shape,
+            dtype=output_tensor_meta.dtype,
+            requires_grad=result.requires_grad,
+            stride=output_tensor_meta.stride,
+        ),
+        total_weight,
+    )
+
+
+# NOTE: The backward computation of cross_entropy goes through two steps:
+# backward for nll_loss and then backward for log_softmax. In loss parallel,
+# the two steps are fused into the following function (called by _nll_loss_backward_handler)
+# to avoid communication when target contains class indices not class probabilities.
+# Also note that the _log_softmax_backward_handler does not perform computation.
+# The implementation resembles _nll_loss_backward and _log_softmax_backward_data
+# from torch._decomp.decomposition.
+def _nll_loss_and_log_softmax_backward(
+    grad_output: Tensor,
+    x: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    total_weight: Tensor,
+    channel_dim_size: int,
+    mesh: DeviceMesh,
+    mesh_dim: int,
+) -> Tensor:
+    channel_dim = 0 if x.dim() < 2 else 1
+    if reduction == Reduction.MEAN.value:
+        grad_output = grad_output / total_weight
+
+    target = target.unsqueeze(channel_dim)
+    safe_target = torch.where(target != ignore_index, target, 0)
+    grad_input = torch.zeros_like(x)
+
+    # The following code block is a distributed version of
+    # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
+    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    safe_target = safe_target.squeeze(channel_dim).flatten()
+    masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
+    # only update grad_input to -1 if not masked
+    assert partial_placement.mask_buffer.data is not None
+    grad_update = partial_placement.mask_buffer.data.float() - 1.0
+    arange_1d = torch.arange(
+        masked_safe_target.shape[0], device=masked_safe_target.device
+    )
+    # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default;
+    # the last case is for aten.nll_loss2d_backward.default.
+    if x.dim() == 1:
+        grad_input[masked_safe_target] = grad_update
+    elif x.dim() == 2:
+        grad_input[arange_1d, masked_safe_target] = grad_update
+    else:
+        grad_input_t = grad_input.transpose(channel_dim, -1)
+        intermediate_shape = grad_input_t.shape
+        grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
+        grad_input_2d[arange_1d, masked_safe_target] = grad_update
+        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)
+
+    if grad_input.dim() > grad_output.dim() > 0:
+        grad_output = grad_output.unsqueeze(channel_dim)
+
+    if weight is not None:
+        new_shape = [1 for _ in range(x.dim())]
+        new_shape[channel_dim] = weight.shape[0]
+        weight = weight.reshape(new_shape)
+        # In order for fused computation to work, the following line is rewritten.
+        # grad_output = grad_output * weight
+        new_shape = list(x.shape)
+        new_shape[channel_dim] = -1
+        w = weight.expand(new_shape)
+        w_target = torch.gather(w, channel_dim, target)
+        grad_output = grad_output * w_target
+
+    grad_output = torch.where(target != ignore_index, grad_output, 0)
+
+    # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
+    # here we perform backward computation for log_softmax altogether to avoid the
+    # otherwise extra all_gather communication.
+    # return grad_input * grad_output
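+    # Since x holds log-probabilities, exp(x) recovers the softmax, so the expression below
+    # amounts to (softmax(x) - one_hot(target)) * grad_output, i.e. the fused cross-entropy
+    # gradient (up to the weighting/ignore_index masking applied above).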
+    return (grad_input + torch.exp(x)) * grad_output
+
+
+def _nll_loss_backward_handler(
+    op_call: torch._ops.OpOverload,
+    args: Tuple[object, ...],
+    kwargs: Dict[str, object],
+) -> object:
+    grad_output = cast(DTensor, args[0])
+    x = cast(DTensor, args[1])
+    target = args[2]
+    weight = args[3]
+    reduction = cast(int, args[4])
+    ignore_index = cast(int, args[5])
+    total_weight = cast(Tensor, args[6])
+
+    channel_dim = 1 if x.dim() >= 2 else 0
+    channel_dim_size = x.shape[channel_dim]
+    spec = x._spec
+    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
+
+    # if target and weight are not DTensors, convert them to DTensors
+    target_placements = _skip_dim(
+        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
+    )
+    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
+    target = _cast_to_dtensor(target, target_placements, spec.mesh)
+    if weight is not None:
+        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
+
+    # tensor inputs to _propagate_tensor_meta need to be DTensors
+    args = list(args)
+    args[2], args[3] = target, weight
+    args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
+    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
+
+    result = _nll_loss_and_log_softmax_backward(
+        grad_output._local_tensor,
+        x._local_tensor,
+        target._local_tensor,
+        weight._local_tensor if weight is not None else None,
+        reduction,
+        ignore_index,
+        total_weight,
+        channel_dim_size,
+        spec.mesh,
+        mesh_dim,
+    )
+
+    return DTensor(
+        result,
+        spec.mesh,
+        # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
+        spec.placements,
+        shape=output_tensor_meta.shape,
+        dtype=output_tensor_meta.dtype,
+        requires_grad=result.requires_grad,
+        stride=output_tensor_meta.stride,
+    )
+
+
+customized_loss_ops = {
+    aten._log_softmax.default: _log_softmax_handler,
+    aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
+    aten.nll_loss_forward.default: _nll_loss_forward_handler,
+    aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
+    aten.nll_loss_backward.default: _nll_loss_backward_handler,
+    aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
+}
+
+
+def _enable_custom_loss_ops():
+    DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)
+
+
+def _disable_custom_loss_ops():
+    for custom_op in customized_loss_ops:
+        DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)
diff --git a/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/style.py b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/style.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d892097b77bf25c7a478a7f9f0d22592fb933d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/tensor/parallel/style.py
@@ -0,0 +1,489 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from abc import ABC, abstractmethod
+from typing import Optional, Union, Tuple
+from functools import partial
+
+import torch
+import torch.nn as nn
+from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module
+
+
+__all__ = [
+    "ParallelStyle",
+    "RowwiseParallel",
+    "SequenceParallel",
+    "ColwiseParallel",
+    "PrepareModuleInput",
+    "PrepareModuleOutput",
+]
+
+
+class ParallelStyle(ABC):
+    """
+    The parallel style contract defines how the module or submodule should be parallelized.
+
+    It only defines the ``_apply`` method for ``parallelize_module`` to use, which allows maximum
+    flexibility for different kinds of style implementations.
+    """
+
+    @abstractmethod
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        ...
+
+
+class ColwiseParallel(ParallelStyle):
+    """
+    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
+    Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules
+    (e.g. MLP, Attention).
+
+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
+            become a DTensor. If not specified, we assume the input tensor to be replicated.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
+            with the user desired layout. If not specified, the output tensor is sharded on the last dimension.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+    Returns:
+        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
+        >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
+        >>>
+        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
+        >>> ...
+
+    .. note:: By default the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
+        specified. If there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
+        keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(), )
+        self.output_layouts = (output_layouts or Shard(-1), )
+        # colwise linear runtime sharding (desired sharding):
+        # 1. requires replicate input
+        # 2. shard output on last dim
+        self.desired_input_layouts = (Replicate(), )
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
+
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        # transform the input layouts to the desired layouts of ColwiseParallel
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Colwise shards both weight and bias on dim 0, i.e. Shard(0). Since
+        # nn.Linear computes input @ weight^T + bias, sharding weight on dim 0
+        # shards weight^T (the effective matmul operand) on dim 1, i.e. column-wise.
+        for name, param in module.named_parameters():
+            dist_param = nn.Parameter(
+                distribute_tensor(param, device_mesh, [Shard(0)])
+            )
+            module.register_parameter(name, dist_param)
+
+    def _partition_embedding_fn(self, name, module, device_mesh):
+        # Colwise sharding of embedding.weight is straightforward: Shard(1).
+        for name, param in module.named_parameters():
+            dist_param = nn.Parameter(
+                distribute_tensor(param, device_mesh, [Shard(1)])
+            )
+            module.register_parameter(name, dist_param)
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        # outputs is a DTensor sharded on the last dimension, i.e. Shard(-1)
+        outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        if isinstance(module, nn.Linear):
+            partition_fn = self._partition_linear_fn
+        elif isinstance(module, nn.Embedding):
+            partition_fn = self._partition_embedding_fn
+        else:
+            raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!")
+
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn,
+            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+        )
+
+
+class RowwiseParallel(ParallelStyle):
+    """
+    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
+    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules
+    (e.g. MLP, Attention).
+
+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of the input tensor for the nn.Module; this is used to annotate the input tensor
+            so that it becomes a DTensor. If not specified, the input tensor is assumed to be sharded on the last dimension.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output of the nn.Module; this is used to ensure the output of the nn.Module
+            has the user-desired layout. If not specified, the output tensor is replicated.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+    Returns:
+        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
+        >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
+        >>>
+        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),
+        >>> ...
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Shard(-1), )
+        self.output_layouts = (output_layouts or Replicate(), )
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shards the weight on dim 1, i.e. Shard(1), and replicates the bias.
+        # Since nn.Linear computes input @ weight^T + bias, sharding weight on dim 1
+        # shards weight^T (the effective matmul operand) on dim 0, i.e. row-wise.
+        module.register_parameter("weight", nn.Parameter(
+            distribute_tensor(module.weight, device_mesh, [Shard(1)])
+        ))
+        if module.bias is not None:
+            module.register_parameter("bias", nn.Parameter(
+                distribute_tensor(module.bias, device_mesh, [Replicate()])
+            ))
+
+    def _partition_embedding_fn(self, name, module, device_mesh):
+        # rowwise shard embedding.weight is Shard(0)
+        for name, param in module.named_parameters():
+            dist_param = nn.Parameter(
+                distribute_tensor(param, device_mesh, [Shard(0)])
+            )
+            module.register_parameter(name, dist_param)
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        # Rowwise sharding produces partial output, depending on output layouts:
+        # 1. to replicate -> allreduce
+        # 2. to shard -> reduce_scatter
+        outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        # back to local tensor if use_local_output is True
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        if isinstance(module, nn.Linear):
+            partition_fn = self._partition_linear_fn
+            # rowwise linear runtime sharding requires input tensor shard on last dim
+            self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), )
+        elif isinstance(module, nn.Embedding):
+            partition_fn = self._partition_embedding_fn
+            # rowwise embedding runtime sharding requires input tensor replicated
+            self.desired_input_layouts = (Replicate(), )
+        else:
+            raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")
+
+        return distribute_module(
+            module,
+            device_mesh,
+            partition_fn,
+            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+        )
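+
+
+# --- Illustrative sketch (hypothetical helper, not part of this module) --------
+# ColwiseParallel and RowwiseParallel are typically paired: the first linear of
+# an MLP is sharded column-wise (output sharded on the last dim) and the second
+# row-wise (input sharded on the last dim, output replicated), so no extra
+# resharding is needed between the two matmuls. The submodule names "w1"/"w2"
+# below are assumptions about the model being parallelized.
+def _example_mlp_tp_plan() -> dict:
+    return {
+        "w1": ColwiseParallel(),  # input replicated -> output Shard(-1)
+        "w2": RowwiseParallel(),  # input Shard(-1) -> output replicated
+    }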
+
+
+class SequenceParallel(ParallelStyle):
+    """
+    SequenceParallel replicates a compatible ``nn.Module``'s parameters and runs the sharded computation with
+    the input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
+    RMSNorm python implementation.
+
+    This style implements the operation described in the paper
+    "Reducing Activation Recomputation in Large Transformer Models".
+
+    Both the input and output of the ``nn.Module`` will be sharded on the sequence dimension.
+
+    Keyword Args:
+        sequence_dim (int, optional):
+            The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
+            become a DTensor that is sharded on the sequence dimension, default: 1.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
+    Returns:
+        A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
+        >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
+        >>>
+        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
+        >>> ...
+
+    .. note:: The SequenceParallel style assumes ones initialization if there are weights in the nn.Module (e.g.
+        ``nn.LayerNorm`` or ``RMSNorm``, which default to ones initialization). If you have custom
+        initializations for the weights of those modules, you need to broadcast the weights before/after parallelizing
+        to ensure that they are replicated.
+    """
+    def __init__(
+        self,
+        *,
+        sequence_dim: int = 1,
+        use_local_output: bool = False
+    ):
+        super().__init__()
+        self.sequence_dim = sequence_dim
+        self.use_local_output = use_local_output
+
+    def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
+        for p_name, param in module.named_parameters():
+            # Simple replication relying on the fixed ones_ init of LayerNorm/RMSNorm,
+            # which allows us to simply use from_local here.
+            replicated_param = torch.nn.Parameter(
+                DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
+            )
+            module.register_parameter(p_name, replicated_param)
+
+    @staticmethod
+    def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh):
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, DTensor):
+            return inputs
+        elif isinstance(input_tensor, torch.Tensor):
+            return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False)
+        else:
+            raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}")
+
+    @staticmethod
+    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._replicate_module_fn,
+            partial(self._prepare_input_fn, self.sequence_dim),
+            partial(self._prepare_output_fn, self.use_local_output),
+        )
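+
+
+# --- Illustrative sketch (hypothetical plan, not part of this module) ----------
+# SequenceParallel is usually applied to norm/dropout layers while the
+# surrounding linears reshard to and from the sequence dimension: the colwise
+# input is annotated as ``Shard(1)`` (and then gathered to Replicate), and the
+# rowwise output is requested as ``Shard(1)`` so the following norm sees a
+# sequence-sharded input again. The submodule names below are assumptions about
+# the model being parallelized.
+def _example_sequence_parallel_plan() -> dict:
+    return {
+        "norm": SequenceParallel(),
+        "w1": ColwiseParallel(input_layouts=Shard(1)),
+        "w2": RowwiseParallel(output_layouts=Shard(1)),
+    }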
+
+
+class PrepareModuleInput(ParallelStyle):
+    """
+    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
+    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.
+
+    Keyword Args:
+        input_layouts (Union[Placement, Tuple[Placement]]):
+            The DTensor layouts of the input tensors for the nn.Module; this is used to convert the input tensors to
+            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs to be
+            specified as a placeholder.
+        desired_input_layouts (Union[Placement, Tuple[Placement]]):
+            The desired DTensor layouts of the input tensors for the nn.Module; this is used to ensure the inputs of the
+            nn.Module have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
+    Returns:
+        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
+        >>> # and then redistributed to Replicated DTensor.
+        >>> parallelize_module(
+        >>>     block, # this can be a submodule or module
+        >>>     tp_mesh,
+        >>>     parallelize_plan={
+        >>>         "attn": PrepareModuleInput(
+        >>>             input_layouts=(Shard(0), None, None, ...),
+        >>>             desired_input_layouts=(Replicate(), None, None, ...)
+        >>>         ),
+        >>>     }
+        >>> )
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Union[Placement, Tuple[Placement]],
+        desired_input_layouts: Union[Placement, Tuple[Placement]],
+        use_local_output: bool = False
+    ):
+        self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
+        self.desired_input_layouts = \
+            (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts
+        self.use_local_output = use_local_output
+        assert len(self.input_layouts) == len(self.desired_input_layouts), \
+            "input_layouts and desired_input_layouts should have same length!"
+
+    def _prepare_input_fn(self, inputs, device_mesh):
+        prepared_inputs = []
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        if len(inputs) != len(self.input_layouts):
+            raise ValueError("module inputs and input_layouts should have same length!")
+
+        for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
+            if input_layout is not None:
+                if isinstance(inp, DTensor):
+                    # TODO: re-enable the check once we fix the compile path
+                    # assert inp.placements[0] == input_layout
+                    dt_inp = inp
+                else:
+                    dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
+                if input_layout != desired_layout:
+                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+                prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
+            else:
+                prepared_inputs.append(inp)
+        return tuple(prepared_inputs)
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh))  # type: ignore[misc, call-arg]
+        return module
+
+
+class PrepareModuleOutput(ParallelStyle):
+    """
+    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
+    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.
+
+    Keyword Args:
+        output_layouts (Union[Placement, Tuple[Placement]]):
+            The DTensor layouts of the output tensors for the nn.Module; this is used to convert the output tensors to
+            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be converted
+            to DTensors, ``None`` needs to be specified as a placeholder.
+        desired_output_layouts (Union[Placement, Tuple[Placement]]):
+            The desired DTensor layouts of the output tensors for the nn.Module; this is used to ensure the outputs of the
+            nn.Module have the desired DTensor layouts.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
+    Returns:
+        A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
+        >>> # and then redistributed to Sharded DTensor.
+        >>> parallelize_module(
+        >>>     block, # this can be a submodule or module
+        >>>     tp_mesh,
+        >>>     parallelize_plan = PrepareModuleOutput(
+        >>>         output_layouts=Replicate(),
+        >>>         desired_output_layouts=Shard(0)
+        >>>     )
+        >>> )
+    """
+    def __init__(
+        self,
+        *,
+        output_layouts: Union[Placement, Tuple[Placement]],
+        desired_output_layouts: Union[Placement, Tuple[Placement]],
+        use_local_output: bool = True
+    ):
+        self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
+        self.desired_output_layouts = \
+            (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts
+        self.use_local_output = use_local_output
+        assert len(self.output_layouts) == len(self.desired_output_layouts), \
+            "output_layouts and desired_output_layouts should have same length!"
+
+    def _prepare_out_fn(self, outputs, device_mesh):
+        prepared_outputs = []
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+        if len(outputs) != len(self.output_layouts):
+            raise ValueError("module outputs and output_layouts should have same length!")
+        for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
+            if out_layout is not None:
+                if isinstance(out, DTensor):
+                    # TODO: re-enable the check once we fix the compile path
+                    # assert out.placements[0] == out_layout
+                    dt_out = out
+                else:
+                    dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)
+
+                if out_layout != desired_out_layout:
+                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
+                prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)
+            else:
+                prepared_outputs.append(out)
+        if len(prepared_outputs) == 1:
+            return prepared_outputs[0]
+        else:
+            return tuple(prepared_outputs)
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh))  # type: ignore[misc, call-arg]
+        return module
diff --git a/MLPY/Lib/site-packages/torch/distributed/utils.py b/MLPY/Lib/site-packages/torch/distributed/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c67357e9a3a5f7fed4349552fdf4b794de76d1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributed/utils.py
@@ -0,0 +1,339 @@
+import dataclasses
+import traceback
+from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.nn.parallel._functions import _get_stream
+from torch.nn.parallel.scatter_gather import _is_namedtuple
+from torch.nn.utils.rnn import PackedSequence
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
+    """
+    Turn argument list into separate key list and value list (unpack_kwargs does the opposite).
+
+    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70
+    Usage::
+
+        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
+        assert flat_args == (1, 2, 3, 4)
+        assert kwarg_keys == ("a", "b")
+        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
+        assert args == (1, 2)
+        assert kwargs == {"a": 3, "b": 4}
+    Returns:
+        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
+        both positional args and kwarg values, where the positional args
+        precede the kwarg values and the kwarg values are ordered consistently with the
+        kwarg keys. The second tuple element gives the kwarg keys.
+        The second tuple element's length is at most the first tuple element's length.
+    """
+    kwarg_keys: List[str] = []
+    flat_args: List[Any] = list(args)
+    for k, v in kwargs.items():
+        kwarg_keys.append(k)
+        flat_args.append(v)
+
+    return tuple(flat_args), tuple(kwarg_keys)
+
+def _cast_forward_inputs(
+    dtype: Optional[torch.dtype],
+    *args: Any,
+    **kwargs: Any,
+) -> Tuple[Any, Any]:
+    """
+    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.
+
+    This respects the existing ``requires_grad`` on the tensors.
+    """
+    if dtype is None:
+        return args, kwargs
+
+    def cast_fn(x: torch.Tensor) -> torch.Tensor:
+        if not torch.is_floating_point(x) or x.dtype == dtype:
+            return x
+        return x.to(dtype)
+
+    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
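+
+def _example_cast_forward_inputs() -> None:
+    # Illustrative sketch (not part of this module): mixed-precision wrappers
+    # typically cast incoming floating point activations before calling the
+    # wrapped module's forward, leaving non-float arguments untouched.
+    x = torch.randn(2, 2)
+    (cast_x,), kwargs = _cast_forward_inputs(torch.float16, x, flag=True)
+    assert cast_x.dtype == torch.float16
+    assert kwargs == {"flag": True}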
+
+def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """See _pack_kwargs."""
+    assert len(kwarg_keys) <= len(
+        flat_args
+    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
+    if len(kwarg_keys) == 0:
+        return flat_args, {}
+    args = flat_args[: -len(kwarg_keys)]
+    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
+    return args, kwargs
+
+
+S = TypeVar("S", dict, list, tuple)
+T = TypeVar("T", torch.Tensor, PackedSequence)
+
+
+@overload
+def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]:
+    ...
+
+
+@overload
+def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]:
+    ...
+
+
+def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
+    r"""Recursively moves input to the target_device."""
+
+    def to_map(obj):
+        if isinstance(obj, (torch.Tensor, PackedSequence)):
+            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
+            if device == target_device:
+                return (obj,)
+            if not use_side_stream_for_tensor_copies:
+                return (obj.to(target_device),)
+            else:
+                # If the device module is not registered with torch, streams are not used for acceleration
+                device_mod = getattr(torch, device.type, None)
+                if device.type == "cpu" or device_mod is None:
+                    return (obj.to(target_device),)
+                # Perform CPU -> target_device copies in a background stream. This code is
+                # motivated from similar logic in torch/nn/parallel/_functions.py
+                stream = _get_stream(target_device)
+                with device_mod.stream(stream):
+                    output = obj.to(target_device)
+                # synchronize with the copy stream
+                with device_mod.device(target_device.index):
+                    current_stream = device_mod.current_stream()
+                    # Sync the current stream with the copy stream
+                    current_stream.wait_stream(stream)
+                    # Ensure tensor memory is not reused until work on
+                    # main stream is complete
+                    if isinstance(obj, PackedSequence):
+                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
+                    else:
+                        assert isinstance(output, torch.Tensor)
+                        output.record_stream(current_stream)  # type: ignore[arg-type]
+                return (output,)
+        if _is_namedtuple(obj):
+            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
+        if isinstance(obj, tuple) and len(obj) > 0:
+            return list(zip(*map(to_map, obj)))
+        if isinstance(obj, list) and len(obj) > 0:
+            return [list(i) for i in zip(*map(to_map, obj))]
+        if isinstance(obj, dict) and len(obj) > 0:
+            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
+        return [obj]
+
+    # Avoid reference cycle
+    try:
+        res = to_map(inputs)
+    finally:
+        to_map = None  # type: ignore[assignment]
+    return res
+
+
+def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
+    """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed."""
+    if not cond:
+        print(s)
+        traceback.print_stack()
+        if raise_assertion_error:
+            raise AssertionError(s)
+
+
+def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
+    """
+    Allocate storage for ``tensor`` with the given size.
+
+    This is a no-op if the storage is already allocated to the requested size;
+    otherwise the (currently empty) storage is resized to ``size.numel()`` elements.
+    """
+    with torch.no_grad():
+        if (
+            not torch.distributed._functional_collectives.is_torchdynamo_compiling()
+        ):
+            already_allocated = tensor._typed_storage()._size() == size.numel()
+            if not already_allocated:
+                tensor_storage_size = tensor._typed_storage()._size()
+                _p_assert(
+                    tensor_storage_size == 0,
+                    "Tensor storage should have been resized to be 0 but got PLACEHOLDEr",
+                )
+                tensor._typed_storage()._resize_(size.numel())
+
+
+def _free_storage(tensor: torch.Tensor):
+    """
+    Frees the underlying storage of ``tensor``.
+
+    This is a no-op if the storage is already freed; otherwise the tensor must
+    start at storage offset 0 and its storage is resized to zero elements.
+    """
+    with torch.no_grad():
+        if (
+            not torch.distributed._functional_collectives.is_torchdynamo_compiling()
+        ):
+            already_freed = tensor._typed_storage()._size() == 0
+            if not already_freed:
+                _p_assert(
+                    tensor.storage_offset() == 0,
+                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
+                    f"storage offset: {tensor.storage_offset()}\n"
+                    f"storage size: {tensor._typed_storage()._size()}\n"
+                    f"tensor shape: {tensor.shape}",
+                )
+                tensor._typed_storage()._resize_(0)
+
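+def _example_storage_round_trip() -> None:
+    # Illustrative sketch (not part of this module): callers such as sharded
+    # flat-parameter code free a tensor's storage when it is not needed and
+    # re-allocate it before use. Both helpers above are no-ops when the storage
+    # is already in the requested state.
+    t = torch.empty(4)
+    size = t.size()
+    _free_storage(t)
+    assert t._typed_storage()._size() == 0
+    _alloc_storage(t, size)
+    assert t._typed_storage()._size() == size.numel()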
+
+
+Q = TypeVar("Q")
+R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)
+
+
+@overload
+def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
+    ...
+
+
+@overload
+def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
+    ...
+
+
+def _apply_to_tensors(fn, container):
+    """Recursively apply to all tensor in different kinds of container types."""
+
+    def apply(x):
+        if isinstance(x, torch.Tensor):
+            return fn(x)
+        elif hasattr(x, "__dataclass_fields__"):
+            dc = dataclasses.replace(x)
+            for f in dataclasses.fields(dc):
+                name = f.name
+                setattr(dc, name, apply(getattr(dc, name)))
+            return dc
+        elif isinstance(x, OrderedDict):
+            od = x.__class__()
+            for key, value in x.items():
+                od[key] = apply(value)
+            return od
+        elif isinstance(x, PackedSequence):
+            apply(x.data)
+            return x
+        elif isinstance(x, dict):
+            return {key: apply(value) for key, value in x.items()}
+        elif _is_namedtuple(x):
+            res = (apply(el) for el in x)
+            return type(x)(*res)
+        elif isinstance(x, (list, tuple, set)):
+            return type(x)(apply(el) for el in x)
+        else:
+            return x
+
+    return apply(container)
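+
+
+def _example_apply_to_tensors() -> None:
+    # Illustrative sketch (not part of this module): the traversal above keeps
+    # the container structure (dicts, lists, tuples, namedtuples, dataclasses,
+    # ...) intact and transforms only the tensors found inside it.
+    nested = {"a": [torch.zeros(2), 3], "b": (torch.ones(1),)}
+    doubled = _apply_to_tensors(lambda t: t * 2, nested)
+    assert torch.equal(doubled["a"][0], torch.zeros(2))  # 0 * 2 == 0
+    assert doubled["a"][1] == 3  # non-tensor leaves are returned unchanged
+    assert torch.equal(doubled["b"][0], torch.full((1,), 2.0))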
+
+
+def _to_kwargs(
+    inputs: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]],
+    target_device: torch.device,
+    use_side_stream_for_tensor_copies: bool,
+) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
+    moved_inputs = (
+        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
+        if inputs
+        else []
+    )
+    moved_kwargs = (
+        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
+        if kwargs
+        else []
+    )
+    if len(moved_inputs) < len(moved_kwargs):
+        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
+    elif len(moved_kwargs) < len(moved_inputs):
+        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
+    return tuple(moved_inputs), tuple(moved_kwargs)
+
+
+def _verify_param_shape_across_processes(
+    process_group: dist.ProcessGroup, tensors: List[torch.Tensor], logger: Optional[dist.Logger] = None
+):
+    return dist._verify_params_across_processes(process_group, tensors, logger)
+
+
+def _sync_module_states(
+    module: nn.Module,
+    process_group: dist.ProcessGroup,
+    broadcast_bucket_size: int,
+    src: int,
+    params_and_buffers_to_ignore: Container[str],
+    broadcast_buffers: bool = True,
+) -> None:
+    """
+    Sync ``module``'s parameters and buffers.
+
+    Syncs ``module``'s parameters and buffers so that all ranks hold the same
+    module state. Note that this API assumes that all
+    parameter shapes are consistent before running the synchronization. This can
+    be checked with ``_verify_param_shape_across_processes``.
+    """
+    module_states: List[torch.Tensor] = []
+    for name, param in module.named_parameters():
+        if name not in params_and_buffers_to_ignore:
+            module_states.append(param.detach())
+
+    if broadcast_buffers:
+        for name, buffer in module.named_buffers():
+            if name not in params_and_buffers_to_ignore:
+                module_states.append(buffer.detach())
+
+    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
+
+
+def _sync_params_and_buffers(
+    process_group: dist.ProcessGroup,
+    module_states: List[torch.Tensor],
+    broadcast_bucket_size: int,
+    src: int,
+) -> None:
+    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0."""
+    if len(module_states) > 0:
+        dist._broadcast_coalesced(
+            process_group, module_states, broadcast_bucket_size, src
+        )
+
+
+def _replace_by_prefix(
+    state_dict: Dict[str, Any],
+    old_prefix: str,
+    new_prefix: str,
+) -> None:
+    """
+    Replace all keys that match a given old_prefix with a new_prefix (in-place).
+
+    Usage::
+
+        state_dict = {"layer.xyz": torch.tensor(1)}
+        replace_by_prefix_(state_dict, "layer.", "module.layer.")
+        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
+    """
+    if old_prefix == new_prefix:
+        raise ValueError("old_prefix and new_prefix must be distinct")
+    for key in list(state_dict.keys()):
+        if not key.startswith(old_prefix):
+            continue
+        new_key = new_prefix + key[len(old_prefix) :]
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]
+
+
+def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
+    return tensor.untyped_storage().data_ptr() > 0
diff --git a/MLPY/Lib/site-packages/torch/distributions/__init__.py b/MLPY/Lib/site-packages/torch/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c2c9a8f5fd2e01a4c52fe439081cd43529c3c9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/__init__.py
@@ -0,0 +1,171 @@
+r"""
+The ``distributions`` package contains parameterizable probability distributions
+and sampling functions. This allows the construction of stochastic computation
+graphs and stochastic gradient estimators for optimization. This package
+generally follows the design of the `TensorFlow Distributions`_ package.
+
+.. _`TensorFlow Distributions`:
+    https://arxiv.org/abs/1711.10604
+
+It is not possible to directly backpropagate through random samples. However,
+there are two main methods for creating surrogate functions that can be
+backpropagated through. These are the score function estimator/likelihood ratio
+estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
+seen as the basis for policy gradient methods in reinforcement learning, and the
+pathwise derivative estimator is commonly seen in the reparameterization trick
+in variational autoencoders. Whilst the score function only requires the value
+of samples :math:`f(x)`, the pathwise derivative requires the derivative
+:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
+example. For more details see
+`Gradient Estimation Using Stochastic Computation Graphs`_ .
+
+.. _`Gradient Estimation Using Stochastic Computation Graphs`:
+     https://arxiv.org/abs/1506.05254
+
+Score function
+^^^^^^^^^^^^^^
+
+When the probability density function is differentiable with respect to its
+parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
+:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
+
+.. math::
+
+    \Delta\theta  = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
+
+where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
+:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
+taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
+
+In practice we would sample an action from the output of a network, apply this
+action in an environment, and then use ``log_prob`` to construct an equivalent
+loss function. Note that we use a negative because optimizers use gradient
+descent, whilst the rule above assumes gradient ascent. With a categorical
+policy, the code for implementing REINFORCE would be as follows::
+
+    probs = policy_network(state)
+    # Note that this is equivalent to what used to be called multinomial
+    m = Categorical(probs)
+    action = m.sample()
+    next_state, reward = env.step(action)
+    loss = -m.log_prob(action) * reward
+    loss.backward()
+
+Pathwise derivative
+^^^^^^^^^^^^^^^^^^^
+
+The other way to implement these stochastic/policy gradients would be to use the
+reparameterization trick from the
+:meth:`~torch.distributions.Distribution.rsample` method, where the
+parameterized random variable can be constructed via a parameterized
+deterministic function of a parameter-free random variable. The reparameterized
+sample therefore becomes differentiable. The code for implementing the pathwise
+derivative would be as follows::
+
+    params = policy_network(state)
+    m = Normal(*params)
+    # Any distribution with .has_rsample == True could work based on the application
+    action = m.rsample()
+    next_state, reward = env.step(action)  # Assuming that reward is differentiable
+    loss = -reward
+    loss.backward()
+"""
+
+from .bernoulli import Bernoulli
+from .beta import Beta
+from .binomial import Binomial
+from .categorical import Categorical
+from .cauchy import Cauchy
+from .chi2 import Chi2
+from .constraint_registry import biject_to, transform_to
+from .continuous_bernoulli import ContinuousBernoulli
+from .dirichlet import Dirichlet
+from .distribution import Distribution
+from .exp_family import ExponentialFamily
+from .exponential import Exponential
+from .fishersnedecor import FisherSnedecor
+from .gamma import Gamma
+from .geometric import Geometric
+from .gumbel import Gumbel
+from .half_cauchy import HalfCauchy
+from .half_normal import HalfNormal
+from .independent import Independent
+from .inverse_gamma import InverseGamma
+from .kl import _add_kl_info, kl_divergence, register_kl
+from .kumaraswamy import Kumaraswamy
+from .laplace import Laplace
+from .lkj_cholesky import LKJCholesky
+from .log_normal import LogNormal
+from .logistic_normal import LogisticNormal
+from .lowrank_multivariate_normal import LowRankMultivariateNormal
+from .mixture_same_family import MixtureSameFamily
+from .multinomial import Multinomial
+from .multivariate_normal import MultivariateNormal
+from .negative_binomial import NegativeBinomial
+from .normal import Normal
+from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
+from .pareto import Pareto
+from .poisson import Poisson
+from .relaxed_bernoulli import RelaxedBernoulli
+from .relaxed_categorical import RelaxedOneHotCategorical
+from .studentT import StudentT
+from .transformed_distribution import TransformedDistribution
+from .transforms import *  # noqa: F403
+from . import transforms
+from .uniform import Uniform
+from .von_mises import VonMises
+from .weibull import Weibull
+from .wishart import Wishart
+
+_add_kl_info()
+del _add_kl_info
+
+__all__ = [
+    "Bernoulli",
+    "Beta",
+    "Binomial",
+    "Categorical",
+    "Cauchy",
+    "Chi2",
+    "ContinuousBernoulli",
+    "Dirichlet",
+    "Distribution",
+    "Exponential",
+    "ExponentialFamily",
+    "FisherSnedecor",
+    "Gamma",
+    "Geometric",
+    "Gumbel",
+    "HalfCauchy",
+    "HalfNormal",
+    "Independent",
+    "InverseGamma",
+    "Kumaraswamy",
+    "LKJCholesky",
+    "Laplace",
+    "LogNormal",
+    "LogisticNormal",
+    "LowRankMultivariateNormal",
+    "MixtureSameFamily",
+    "Multinomial",
+    "MultivariateNormal",
+    "NegativeBinomial",
+    "Normal",
+    "OneHotCategorical",
+    "OneHotCategoricalStraightThrough",
+    "Pareto",
+    "RelaxedBernoulli",
+    "RelaxedOneHotCategorical",
+    "StudentT",
+    "Poisson",
+    "Uniform",
+    "VonMises",
+    "Weibull",
+    "Wishart",
+    "TransformedDistribution",
+    "biject_to",
+    "kl_divergence",
+    "register_kl",
+    "transform_to",
+]
+__all__.extend(transforms.__all__)
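+
+
+# --- Illustrative sketch (not part of the public API) --------------------------
+# A minimal, self-contained comparison of the two estimators described in the
+# module docstring above, using a Normal distribution whose mean is the learnable
+# parameter and a toy reward of ``-(x ** 2)``.
+def _example_gradient_estimators() -> None:
+    import torch  # local import; this module does not import torch at top level
+
+    # Score function / REINFORCE: the sample itself is non-differentiable, so the
+    # gradient flows through ``log_prob``.
+    mu = torch.ones(1, requires_grad=True)
+    dist = Normal(mu, 1.0)
+    action = dist.sample()
+    reward = -(action ** 2)  # toy reward, maximal at 0
+    (-dist.log_prob(action) * reward).sum().backward()
+    score_grad = mu.grad.clone()
+
+    # Pathwise derivative: ``rsample`` reparameterizes, so the gradient flows
+    # through the sample directly.
+    mu2 = torch.ones(1, requires_grad=True)
+    sample = Normal(mu2, 1.0).rsample()
+    (sample ** 2).sum().backward()  # loss = -reward
+    pathwise_grad = mu2.grad.clone()
+
+    # Both are stochastic estimates of the gradient of E[-reward] w.r.t. the mean.
+    assert score_grad.shape == pathwise_grad.shape == (1,)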
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6622ed0c298136dd98c8cb01f8d5379b8eb9daba
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b2d2a0b23659b4f19d8c490e230a9a1635863b4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47b88a751c5f12357027de63c8091302fad51f3a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f2f4fd99f62427d8111b172b2a99b913ba3f2a0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfbc2c4ba56dd7e210a145501b362439b4c39aa7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d14eb7f6a5ac6b396893fd4e0fb925da37642e2f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1350f63bf19774ae7a91341c0a7b8079f1181fc9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4c80b0f77676e5ab94baf4461493fa8f679f9b5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d38704cd40b34eed5e51ea81fbac5c563d81aa10
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..517ab1cf0564f2c9d3b1a5a2a147f06c0ba3e497
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..180eda18fdf137f3bb930c2a3a2a0202bf66a871
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15dcdfa00790b6be374d7239e3c67ee728fc791c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19ae9da6d5a89b02bf217e22115734a4f0d6517c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aa5a49e7f9baddfbd36662d106551f656257b9e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19e6073e24f8c8b358739261dd3f9da01e098e90
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2f8fc0f1d80ed8bda2720068c839a30c1053eb7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b6de6e6205596279bfc2de05973b329164493b9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48a0fa2bfa297a2764d06b0006efa99c82beefe9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..109c7434b55504493e8e7a670e60a8ad3551a692
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..314ecfae63882443b70be2e8c64461570d2d772d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f22024469159ee2488092f95c63ea758625e6767
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d95f72f5fb18cf7d28dc5d184a25fe449089e8ce
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07e82ecd2fb02ed3c094c5dad3d51003242e55ea
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e82b8194bc64e22c5cdfc070e023af7a0f9f15e2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..166f92486106b38859bf3bd5ce86fe114c382f1f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9dc2b48b97d11097ee73cd81e75d49821a4aa9a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27f86105c9399b1fa8fb6b8ac563f00565177711
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..787190086ef056318b3270edc5bae682d2d7744a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a8085a90479487e7234e9e10a11d35eaf1b1d4e
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65cafc1ec1e5364341cc547596543cdbdcc98927
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..777c98ea0e9aee8bea5e874af18545c9d55c1f21
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5f9522c63aed983277409063bee0acda40bd9e5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72cf0c0b702a89cc43301a4ecfbad012618c89f0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14f7c5f6b76e72f943bc0699983921a4a3e603f0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..310fc17d7a1982d4b0b0df43be2635a403a6657f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c466b9e939f3a9d722317babe78c83c20294dc5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f23195fbf4b0d9a806fe53ddeaaaea83bb2f0b7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5a7366fa7f4c7d8f8c8504f592911ffa3d33f45
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2dde6bb4b9357c38c495aff682a48e4fbc0739cd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15733f72e80c96a95cf1fafa029c318730694b85
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..095080ff2ca2ceee3b8f34f1b255c62948d0ac08
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f70b4839a42c78776b1293faa4825e25ea70073
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..355c38d3db4c41c4038f99ed66e253e185c475d7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..784d07c37df2892545e56cf93d5801a572353c14
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17284650552d73f9a6fefeb1f593f8d5e93fd43a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ba30a5a45b2efce01083b728876b29fbb2042fd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc b/MLPY/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da2fe2969d0b886defab55cf2ce6147da8115d1b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/distributions/bernoulli.py b/MLPY/Lib/site-packages/torch/distributions/bernoulli.py
new file mode 100644
index 0000000000000000000000000000000000000000..479c6b9a56bc72a4e0e213783f8ec0738606f276
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/bernoulli.py
@@ -0,0 +1,130 @@
+from numbers import Number
+
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import (
+    broadcast_all,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+__all__ = ["Bernoulli"]
+
+
+class Bernoulli(ExponentialFamily):
+    r"""
+    Creates a Bernoulli distribution parameterized by :attr:`probs`
+    or :attr:`logits` (but not both).
+
+    Samples are binary (0 or 1). They take the value `1` with probability `p`
+    and `0` with probability `1 - p`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Bernoulli(torch.tensor([0.3]))
+        >>> m.sample()  # 30% chance 1; 70% chance 0
+        tensor([ 0.])
+
+    Args:
+        probs (Number, Tensor): the probability of sampling `1`
+        logits (Number, Tensor): the log-odds of sampling `1`
+    """
+    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
+    support = constraints.boolean
+    has_enumerate_support = True
+    _mean_carrier_measure = 0
+
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            is_scalar = isinstance(probs, Number)
+            (self.probs,) = broadcast_all(probs)
+        else:
+            is_scalar = isinstance(logits, Number)
+            (self.logits,) = broadcast_all(logits)
+        self._param = self.probs if probs is not None else self.logits
+        if is_scalar:
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self._param.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Bernoulli, _instance)
+        batch_shape = torch.Size(batch_shape)
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(Bernoulli, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    @property
+    def mean(self):
+        return self.probs
+
+    @property
+    def mode(self):
+        mode = (self.probs >= 0.5).to(self.probs)
+        mode[self.probs == 0.5] = nan
+        return mode
+
+    @property
+    def variance(self):
+        return self.probs * (1 - self.probs)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            return torch.bernoulli(self.probs.expand(shape))
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        return -binary_cross_entropy_with_logits(logits, value, reduction="none")
+
+    def entropy(self):
+        return binary_cross_entropy_with_logits(
+            self.logits, self.probs, reduction="none"
+        )
+
+    def enumerate_support(self, expand=True):
+        values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
+        values = values.view((-1,) + (1,) * len(self._batch_shape))
+        if expand:
+            values = values.expand((-1,) + self._batch_shape)
+        return values
+
+    @property
+    def _natural_params(self):
+        return (torch.logit(self.probs),)
+
+    def _log_normalizer(self, x):
+        return torch.log1p(torch.exp(x))
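A minimal usage sketch (not part of the vendored diff; it assumes the bundled torch package added above is importable) showing that the mutually exclusive `probs` and `logits` parameterizations of `Bernoulli` describe the same distribution, since logits = log(p / (1 - p)):

import torch
from torch.distributions import Bernoulli

p = torch.tensor([0.3])
d_probs = Bernoulli(probs=p)                         # parameterized by probability
d_logits = Bernoulli(logits=torch.log(p / (1 - p)))  # equivalent log-odds parameterization

x = torch.tensor([0.0, 1.0])
print(d_probs.log_prob(x))   # approximately tensor([-0.3567, -1.2040])
print(d_logits.log_prob(x))  # same values up to floating-point error
# Passing both probs and logits (or neither) raises ValueError, per __init__ above.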
diff --git a/MLPY/Lib/site-packages/torch/distributions/beta.py b/MLPY/Lib/site-packages/torch/distributions/beta.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d8cc383f0e1eba5a83ad803c4503e11cf04b85
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/beta.py
@@ -0,0 +1,107 @@
+from numbers import Number, Real
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.dirichlet import Dirichlet
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Beta"]
+
+
+class Beta(ExponentialFamily):
+    r"""
+    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
+        >>> m.sample()  # Beta distributed with concentration1 and concentration0
+        tensor([ 0.1046])
+
+    Args:
+        concentration1 (float or Tensor): 1st concentration parameter of the distribution
+            (often referred to as alpha)
+        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
+            (often referred to as beta)
+    """
+    arg_constraints = {
+        "concentration1": constraints.positive,
+        "concentration0": constraints.positive,
+    }
+    support = constraints.unit_interval
+    has_rsample = True
+
+    def __init__(self, concentration1, concentration0, validate_args=None):
+        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
+            concentration1_concentration0 = torch.tensor(
+                [float(concentration1), float(concentration0)]
+            )
+        else:
+            concentration1, concentration0 = broadcast_all(
+                concentration1, concentration0
+            )
+            concentration1_concentration0 = torch.stack(
+                [concentration1, concentration0], -1
+            )
+        self._dirichlet = Dirichlet(
+            concentration1_concentration0, validate_args=validate_args
+        )
+        super().__init__(self._dirichlet._batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Beta, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new._dirichlet = self._dirichlet.expand(batch_shape)
+        super(Beta, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        return self.concentration1 / (self.concentration1 + self.concentration0)
+
+    @property
+    def mode(self):
+        return self._dirichlet.mode[..., 0]
+
+    @property
+    def variance(self):
+        total = self.concentration1 + self.concentration0
+        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))
+
+    def rsample(self, sample_shape=()):
+        return self._dirichlet.rsample(sample_shape).select(-1, 0)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        heads_tails = torch.stack([value, 1.0 - value], -1)
+        return self._dirichlet.log_prob(heads_tails)
+
+    def entropy(self):
+        return self._dirichlet.entropy()
+
+    @property
+    def concentration1(self):
+        result = self._dirichlet.concentration[..., 0]
+        if isinstance(result, Number):
+            return torch.tensor([result])
+        else:
+            return result
+
+    @property
+    def concentration0(self):
+        result = self._dirichlet.concentration[..., 1]
+        if isinstance(result, Number):
+            return torch.tensor([result])
+        else:
+            return result
+
+    @property
+    def _natural_params(self):
+        return (self.concentration1, self.concentration0)
+
+    def _log_normalizer(self, x, y):
+        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
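A short sketch (again assuming the bundled torch is importable) of the closed-form moments that the implementation above derives from its internal Dirichlet:

import torch
from torch.distributions import Beta

m = Beta(torch.tensor([2.0]), torch.tensor([5.0]))   # concentration1=2, concentration0=5
samples = m.sample((100_000,))

print(m.mean)          # 2 / (2 + 5), roughly 0.2857
print(samples.mean())  # empirical mean, close to m.mean
print(m.variance)      # 2 * 5 / ((2 + 5)**2 * (2 + 5 + 1)), roughly 0.0255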
diff --git a/MLPY/Lib/site-packages/torch/distributions/binomial.py b/MLPY/Lib/site-packages/torch/distributions/binomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..847c6779770bd1481798707ac8dd8ed247bd5b02
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/binomial.py
@@ -0,0 +1,165 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import (
+    broadcast_all,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+
+__all__ = ["Binomial"]
+
+
+def _clamp_by_zero(x):
+    # works like clamp(x, min=0) but has gradient 0.5 at 0
+    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
+
+
+class Binomial(Distribution):
+    r"""
+    Creates a Binomial distribution parameterized by :attr:`total_count` and
+    either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
+    broadcastable with :attr:`probs`/:attr:`logits`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Binomial(100, torch.tensor([0., .2, .8, 1.]))
+        >>> m.sample()
+        tensor([   0.,   22.,   71.,  100.])
+
+        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
+        >>> m.sample()
+        tensor([[ 4.,  5.],
+                [ 7.,  6.]])
+
+    Args:
+        total_count (int or Tensor): number of Bernoulli trials
+        probs (Tensor): Event probabilities
+        logits (Tensor): Event log-odds
+    """
+    arg_constraints = {
+        "total_count": constraints.nonnegative_integer,
+        "probs": constraints.unit_interval,
+        "logits": constraints.real,
+    }
+    has_enumerate_support = True
+
+    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            (
+                self.total_count,
+                self.probs,
+            ) = broadcast_all(total_count, probs)
+            self.total_count = self.total_count.type_as(self.probs)
+        else:
+            (
+                self.total_count,
+                self.logits,
+            ) = broadcast_all(total_count, logits)
+            self.total_count = self.total_count.type_as(self.logits)
+
+        self._param = self.probs if probs is not None else self.logits
+        batch_shape = self._param.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Binomial, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.total_count = self.total_count.expand(batch_shape)
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(Binomial, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    @constraints.dependent_property(is_discrete=True, event_dim=0)
+    def support(self):
+        return constraints.integer_interval(0, self.total_count)
+
+    @property
+    def mean(self):
+        return self.total_count * self.probs
+
+    @property
+    def mode(self):
+        return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)
+
+    @property
+    def variance(self):
+        return self.total_count * self.probs * (1 - self.probs)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            return torch.binomial(
+                self.total_count.expand(shape), self.probs.expand(shape)
+            )
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        log_factorial_n = torch.lgamma(self.total_count + 1)
+        log_factorial_k = torch.lgamma(value + 1)
+        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
+        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
+        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
+        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
+        #                                   = k * logit - n * logit - n * log1p(e^-logit)
+        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
+        normalize_term = (
+            self.total_count * _clamp_by_zero(self.logits)
+            + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
+            - log_factorial_n
+        )
+        return (
+            value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
+        )
+
+    def entropy(self):
+        total_count = int(self.total_count.max())
+        if not self.total_count.min() == total_count:
+            raise NotImplementedError(
+                "Inhomogeneous total count not supported by `entropy`."
+            )
+
+        log_prob = self.log_prob(self.enumerate_support(False))
+        return -(torch.exp(log_prob) * log_prob).sum(0)
+
+    def enumerate_support(self, expand=True):
+        total_count = int(self.total_count.max())
+        if not self.total_count.min() == total_count:
+            raise NotImplementedError(
+                "Inhomogeneous total count not supported by `enumerate_support`."
+            )
+        values = torch.arange(
+            1 + total_count, dtype=self._param.dtype, device=self._param.device
+        )
+        values = values.view((-1,) + (1,) * len(self._batch_shape))
+        if expand:
+            values = values.expand((-1,) + self._batch_shape)
+        return values
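A small sketch (assuming the bundled torch is importable) checking that the logit-based `log_prob` above defines a normalized pmf over `enumerate_support`:

import torch
from torch.distributions import Binomial

m = Binomial(total_count=5, probs=torch.tensor(0.3))
support = m.enumerate_support()   # tensor([0., 1., 2., 3., 4., 5.])
pmf = m.log_prob(support).exp()
print(pmf.sum())                  # ~1.0: the pmf sums to one
print(m.mean, m.variance)         # n*p = 1.5 and n*p*(1-p) = 1.05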
diff --git a/MLPY/Lib/site-packages/torch/distributions/categorical.py b/MLPY/Lib/site-packages/torch/distributions/categorical.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dbe7d6df01a5ce75ba6e5330c82bcaf5250d795
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/categorical.py
@@ -0,0 +1,155 @@
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
+
+__all__ = ["Categorical"]
+
+
+class Categorical(Distribution):
+    r"""
+    Creates a categorical distribution parameterized by either :attr:`probs` or
+    :attr:`logits` (but not both).
+
+    .. note::
+        It is equivalent to the distribution that :func:`torch.multinomial`
+        samples from.
+
+    Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
+
+    If `probs` is 1-dimensional with length `K`, each element is the relative probability
+    of sampling the class at that index.
+
+    If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
+    relative probability vectors.
+
+    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+              will return this normalized value.
+              The `logits` argument will be interpreted as unnormalized log probabilities
+              and can therefore be any real number. It will likewise be normalized so that
+              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+              will return this normalized value.
+
+    See also: :func:`torch.multinomial`
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
+        >>> m.sample()  # equal probability of 0, 1, 2, 3
+        tensor(3)
+
+    Args:
+        probs (Tensor): event probabilities
+        logits (Tensor): event log probabilities (unnormalized)
+    """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
+    has_enumerate_support = True
+
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            if probs.dim() < 1:
+                raise ValueError("`probs` parameter must be at least one-dimensional.")
+            self.probs = probs / probs.sum(-1, keepdim=True)
+        else:
+            if logits.dim() < 1:
+                raise ValueError("`logits` parameter must be at least one-dimensional.")
+            # Normalize
+            self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
+        self._param = self.probs if probs is not None else self.logits
+        self._num_events = self._param.size()[-1]
+        batch_shape = (
+            self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
+        )
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Categorical, _instance)
+        batch_shape = torch.Size(batch_shape)
+        param_shape = batch_shape + torch.Size((self._num_events,))
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(param_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(param_shape)
+            new._param = new.logits
+        new._num_events = self._num_events
+        super(Categorical, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    @constraints.dependent_property(is_discrete=True, event_dim=0)
+    def support(self):
+        return constraints.integer_interval(0, self._num_events - 1)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits)
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    @property
+    def mean(self):
+        return torch.full(
+            self._extended_shape(),
+            nan,
+            dtype=self.probs.dtype,
+            device=self.probs.device,
+        )
+
+    @property
+    def mode(self):
+        return self.probs.argmax(axis=-1)
+
+    @property
+    def variance(self):
+        return torch.full(
+            self._extended_shape(),
+            nan,
+            dtype=self.probs.dtype,
+            device=self.probs.device,
+        )
+
+    def sample(self, sample_shape=torch.Size()):
+        if not isinstance(sample_shape, torch.Size):
+            sample_shape = torch.Size(sample_shape)
+        probs_2d = self.probs.reshape(-1, self._num_events)
+        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
+        return samples_2d.reshape(self._extended_shape(sample_shape))
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        value = value.long().unsqueeze(-1)
+        value, log_pmf = torch.broadcast_tensors(value, self.logits)
+        value = value[..., :1]
+        return log_pmf.gather(-1, value).squeeze(-1)
+
+    def entropy(self):
+        min_real = torch.finfo(self.logits.dtype).min
+        logits = torch.clamp(self.logits, min=min_real)
+        p_log_p = logits * self.probs
+        return -p_log_p.sum(-1)
+
+    def enumerate_support(self, expand=True):
+        num_events = self._num_events
+        values = torch.arange(num_events, dtype=torch.long, device=self._param.device)
+        values = values.view((-1,) + (1,) * len(self._batch_shape))
+        if expand:
+            values = values.expand((-1,) + self._batch_shape)
+        return values
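A brief sketch (assuming the bundled torch is importable) of how `Categorical` normalizes unnormalized weights and handles batched parameters:

import torch
from torch.distributions import Categorical

weights = torch.tensor([[1.0, 1.0, 2.0],
                        [4.0, 1.0, 1.0]])
m = Categorical(probs=weights)           # each row is normalized to sum to 1
print(m.probs)                           # tensor([[0.2500, 0.2500, 0.5000], [0.6667, 0.1667, 0.1667]])
print(m.batch_shape)                     # torch.Size([2]); samples are indices in {0, 1, 2}
print(m.sample((5,)).shape)              # torch.Size([5, 2])
print(m.log_prob(torch.tensor([2, 0])))  # one log-probability per batch row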
diff --git a/MLPY/Lib/site-packages/torch/distributions/cauchy.py b/MLPY/Lib/site-packages/torch/distributions/cauchy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3ce3016066928a7a030face80733d0b6386b600
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/cauchy.py
@@ -0,0 +1,90 @@
+import math
+from numbers import Number
+
+import torch
+from torch import inf, nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Cauchy"]
+
+
+class Cauchy(Distribution):
+    r"""
+    Samples from a Cauchy (Lorentz) distribution. The ratio of two independent
+    normally distributed random variables with mean `0` follows a Cauchy
+    distribution.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
+        >>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
+        tensor([ 2.3214])
+
+    Args:
+        loc (float or Tensor): mode or median of the distribution.
+        scale (float or Tensor): half width at half maximum.
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+    has_rsample = True
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        if isinstance(loc, Number) and isinstance(scale, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.loc.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Cauchy, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        super(Cauchy, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        return torch.full(
+            self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
+        )
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def variance(self):
+        return torch.full(
+            self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
+        )
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        eps = self.loc.new(shape).cauchy_()
+        return self.loc + eps * self.scale
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return (
+            -math.log(math.pi)
+            - self.scale.log()
+            - (((value - self.loc) / self.scale) ** 2).log1p()
+        )
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
+
+    def icdf(self, value):
+        return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
+
+    def entropy(self):
+        return math.log(4 * math.pi) + self.scale.log()
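A small sketch (assuming the bundled torch is importable) showing that `icdf` inverts `cdf`, and that the heavy tails leave the mean and variance undefined:

import torch
from torch.distributions import Cauchy

m = Cauchy(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]))
x = torch.tensor([-2.0, 0.0, 3.0])
print(m.icdf(m.cdf(x)))  # recovers x up to floating-point error
print(m.mean)            # tensor([nan]): the Cauchy distribution has no mean
print(m.variance)        # tensor([inf])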
diff --git a/MLPY/Lib/site-packages/torch/distributions/chi2.py b/MLPY/Lib/site-packages/torch/distributions/chi2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8978cd242e9ad43782f1420c5f1b7bc04f6bf0e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/chi2.py
@@ -0,0 +1,33 @@
+from torch.distributions import constraints
+from torch.distributions.gamma import Gamma
+
+__all__ = ["Chi2"]
+
+
+class Chi2(Gamma):
+    r"""
+    Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`.
+    This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)``.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Chi2(torch.tensor([1.0]))
+        >>> m.sample()  # Chi2 distributed with shape df=1
+        tensor([ 0.1046])
+
+    Args:
+        df (float or Tensor): shape parameter of the distribution
+    """
+    arg_constraints = {"df": constraints.positive}
+
+    def __init__(self, df, validate_args=None):
+        super().__init__(0.5 * df, 0.5, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Chi2, _instance)
+        return super().expand(batch_shape, new)
+
+    @property
+    def df(self):
+        return self.concentration * 2
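A minimal sketch (assuming the bundled torch is importable) of the Gamma equivalence stated in the docstring above:

import torch
from torch.distributions import Chi2, Gamma

m = Chi2(torch.tensor([4.0]))
g = Gamma(torch.tensor([2.0]), torch.tensor([0.5]))  # concentration = 0.5*df, rate = 0.5

x = torch.tensor([1.0, 3.0, 10.0])
print(m.log_prob(x))  # identical to g.log_prob(x)
print(g.log_prob(x))
print(m.df)           # tensor([4.]): recovered as concentration * 2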
diff --git a/MLPY/Lib/site-packages/torch/distributions/constraint_registry.py b/MLPY/Lib/site-packages/torch/distributions/constraint_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86b6dead2666bb1426200d55900d6135bdde768
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/constraint_registry.py
@@ -0,0 +1,292 @@
+r"""
+PyTorch provides two global :class:`ConstraintRegistry` objects that link
+:class:`~torch.distributions.constraints.Constraint` objects to
+:class:`~torch.distributions.transforms.Transform` objects. These objects both
+input constraints and return transforms, but they have different guarantees on
+bijectivity.
+
+1. ``biject_to(constraint)`` looks up a bijective
+   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
+   to the given ``constraint``. The returned transform is guaranteed to have
+   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
+2. ``transform_to(constraint)`` looks up a not-necessarily bijective
+   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
+   to the given ``constraint``. The returned transform is not guaranteed to
+   implement ``.log_abs_det_jacobian()``.
+
+The ``transform_to()`` registry is useful for performing unconstrained
+optimization on constrained parameters of probability distributions, which are
+indicated by each distribution's ``.arg_constraints`` dict. These transforms often
+overparameterize a space in order to avoid rotation; they are thus more
+suitable for coordinate-wise optimization algorithms like Adam::
+
+    loc = torch.zeros(100, requires_grad=True)
+    unconstrained = torch.zeros(100, requires_grad=True)
+    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
+    loss = -Normal(loc, scale).log_prob(data).sum()
+
+The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
+samples from a probability distribution with constrained ``.support`` are
+propagated in an unconstrained space, and algorithms are typically rotation
+invariant.::
+
+    dist = Exponential(rate)
+    unconstrained = torch.zeros(100, requires_grad=True)
+    sample = biject_to(dist.support)(unconstrained)
+    potential_energy = -dist.log_prob(sample).sum()
+
+.. note::
+
+    An example where ``transform_to`` and ``biject_to`` differ is
+    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
+    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
+    exponentiates and normalizes its inputs; this is a cheap and mostly
+    coordinate-wise operation appropriate for algorithms like SVI. In
+    contrast, ``biject_to(constraints.simplex)`` returns a
+    :class:`~torch.distributions.transforms.StickBreakingTransform` that
+    bijects its input down to a one-fewer-dimensional space; this a more
+    expensive less numerically stable transform but is needed for algorithms
+    like HMC.
+
+The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
+constraints and transforms using their ``.register()`` method either as a
+function on singleton constraints::
+
+    transform_to.register(my_constraint, my_transform)
+
+or as a decorator on parameterized constraints::
+
+    @transform_to.register(MyConstraintClass)
+    def my_factory(constraint):
+        assert isinstance(constraint, MyConstraintClass)
+        return MyTransform(constraint.param1, constraint.param2)
+
+You can create your own registry by creating a new :class:`ConstraintRegistry`
+object.
+"""
+
+import numbers
+
+from torch.distributions import constraints, transforms
+
+__all__ = [
+    "ConstraintRegistry",
+    "biject_to",
+    "transform_to",
+]
+
+
+class ConstraintRegistry:
+    """
+    Registry to link constraints to transforms.
+    """
+
+    def __init__(self):
+        self._registry = {}
+        super().__init__()
+
+    def register(self, constraint, factory=None):
+        """
+        Registers a :class:`~torch.distributions.constraints.Constraint`
+        subclass in this registry. Usage::
+
+            @my_registry.register(MyConstraintClass)
+            def construct_transform(constraint):
+                assert isinstance(constraint, MyConstraint)
+                return MyTransform(constraint.arg_constraints)
+
+        Args:
+            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
+                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
+                a singleton object of the desired class.
+            factory (Callable): A callable that inputs a constraint object and returns
+                a :class:`~torch.distributions.transforms.Transform` object.
+        """
+        # Support use as decorator.
+        if factory is None:
+            return lambda factory: self.register(constraint, factory)
+
+        # Support calling on singleton instances.
+        if isinstance(constraint, constraints.Constraint):
+            constraint = type(constraint)
+
+        if not isinstance(constraint, type) or not issubclass(
+            constraint, constraints.Constraint
+        ):
+            raise TypeError(
+                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
+            )
+
+        self._registry[constraint] = factory
+        return factory
+
+    def __call__(self, constraint):
+        """
+        Looks up a transform to constrained space, given a constraint object.
+        Usage::
+
+            constraint = Normal.arg_constraints['scale']
+            scale = transform_to(constraint)(torch.zeros(1))  # constrained
+            u = transform_to(constraint).inv(scale)           # unconstrained
+
+        Args:
+            constraint (:class:`~torch.distributions.constraints.Constraint`):
+                A constraint object.
+
+        Returns:
+            A :class:`~torch.distributions.transforms.Transform` object.
+
+        Raises:
+            `NotImplementedError` if no transform has been registered.
+        """
+        # Look up by Constraint subclass.
+        try:
+            factory = self._registry[type(constraint)]
+        except KeyError:
+            raise NotImplementedError(
+                f"Cannot transform {type(constraint).__name__} constraints"
+            ) from None
+        return factory(constraint)
+
+
+biject_to = ConstraintRegistry()
+transform_to = ConstraintRegistry()
+
+
+################################################################################
+# Registration Table
+################################################################################
+
+
+@biject_to.register(constraints.real)
+@transform_to.register(constraints.real)
+def _transform_to_real(constraint):
+    return transforms.identity_transform
+
+
+@biject_to.register(constraints.independent)
+def _biject_to_independent(constraint):
+    base_transform = biject_to(constraint.base_constraint)
+    return transforms.IndependentTransform(
+        base_transform, constraint.reinterpreted_batch_ndims
+    )
+
+
+@transform_to.register(constraints.independent)
+def _transform_to_independent(constraint):
+    base_transform = transform_to(constraint.base_constraint)
+    return transforms.IndependentTransform(
+        base_transform, constraint.reinterpreted_batch_ndims
+    )
+
+
+@biject_to.register(constraints.positive)
+@biject_to.register(constraints.nonnegative)
+@transform_to.register(constraints.positive)
+@transform_to.register(constraints.nonnegative)
+def _transform_to_positive(constraint):
+    return transforms.ExpTransform()
+
+
+@biject_to.register(constraints.greater_than)
+@biject_to.register(constraints.greater_than_eq)
+@transform_to.register(constraints.greater_than)
+@transform_to.register(constraints.greater_than_eq)
+def _transform_to_greater_than(constraint):
+    return transforms.ComposeTransform(
+        [
+            transforms.ExpTransform(),
+            transforms.AffineTransform(constraint.lower_bound, 1),
+        ]
+    )
+
+
+@biject_to.register(constraints.less_than)
+@transform_to.register(constraints.less_than)
+def _transform_to_less_than(constraint):
+    return transforms.ComposeTransform(
+        [
+            transforms.ExpTransform(),
+            transforms.AffineTransform(constraint.upper_bound, -1),
+        ]
+    )
+
+
+@biject_to.register(constraints.interval)
+@biject_to.register(constraints.half_open_interval)
+@transform_to.register(constraints.interval)
+@transform_to.register(constraints.half_open_interval)
+def _transform_to_interval(constraint):
+    # Handle the special case of the unit interval.
+    lower_is_0 = (
+        isinstance(constraint.lower_bound, numbers.Number)
+        and constraint.lower_bound == 0
+    )
+    upper_is_1 = (
+        isinstance(constraint.upper_bound, numbers.Number)
+        and constraint.upper_bound == 1
+    )
+    if lower_is_0 and upper_is_1:
+        return transforms.SigmoidTransform()
+
+    loc = constraint.lower_bound
+    scale = constraint.upper_bound - constraint.lower_bound
+    return transforms.ComposeTransform(
+        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
+    )
+
+
+@biject_to.register(constraints.simplex)
+def _biject_to_simplex(constraint):
+    return transforms.StickBreakingTransform()
+
+
+@transform_to.register(constraints.simplex)
+def _transform_to_simplex(constraint):
+    return transforms.SoftmaxTransform()
+
+
+# TODO define a bijection for LowerCholeskyTransform
+@transform_to.register(constraints.lower_cholesky)
+def _transform_to_lower_cholesky(constraint):
+    return transforms.LowerCholeskyTransform()
+
+
+@transform_to.register(constraints.positive_definite)
+@transform_to.register(constraints.positive_semidefinite)
+def _transform_to_positive_definite(constraint):
+    return transforms.PositiveDefiniteTransform()
+
+
+@biject_to.register(constraints.corr_cholesky)
+@transform_to.register(constraints.corr_cholesky)
+def _transform_to_corr_cholesky(constraint):
+    return transforms.CorrCholeskyTransform()
+
+
+@biject_to.register(constraints.cat)
+def _biject_to_cat(constraint):
+    return transforms.CatTransform(
+        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
+    )
+
+
+@transform_to.register(constraints.cat)
+def _transform_to_cat(constraint):
+    return transforms.CatTransform(
+        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
+    )
+
+
+@biject_to.register(constraints.stack)
+def _biject_to_stack(constraint):
+    return transforms.StackTransform(
+        [biject_to(c) for c in constraint.cseq], constraint.dim
+    )
+
+
+@transform_to.register(constraints.stack)
+def _transform_to_stack(constraint):
+    return transforms.StackTransform(
+        [transform_to(c) for c in constraint.cseq], constraint.dim
+    )
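A runnable sketch of the unconstrained-optimization pattern described in the module docstring above (assuming the bundled torch is importable; the data tensor is a placeholder):

import torch
from torch.distributions import Normal, transform_to

data = torch.randn(100, 10)                          # placeholder observations
loc = torch.zeros(10, requires_grad=True)
unconstrained = torch.zeros(10, requires_grad=True)  # optimized in an unconstrained space

opt = torch.optim.Adam([loc, unconstrained], lr=0.1)
for _ in range(100):
    opt.zero_grad()
    # transform_to maps the unconstrained tensor into the positive reals,
    # because Normal.arg_constraints['scale'] is constraints.positive.
    scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()
    loss.backward()
    opt.step()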
diff --git a/MLPY/Lib/site-packages/torch/distributions/constraints.py b/MLPY/Lib/site-packages/torch/distributions/constraints.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bece7c44c276fe5cf5d47a15cd6d3127a1d7c2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/constraints.py
@@ -0,0 +1,657 @@
+r"""
+The following constraints are implemented:
+
+- ``constraints.boolean``
+- ``constraints.cat``
+- ``constraints.corr_cholesky``
+- ``constraints.dependent``
+- ``constraints.greater_than(lower_bound)``
+- ``constraints.greater_than_eq(lower_bound)``
+- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
+- ``constraints.integer_interval(lower_bound, upper_bound)``
+- ``constraints.interval(lower_bound, upper_bound)``
+- ``constraints.less_than(upper_bound)``
+- ``constraints.lower_cholesky``
+- ``constraints.lower_triangular``
+- ``constraints.multinomial``
+- ``constraints.nonnegative``
+- ``constraints.nonnegative_integer``
+- ``constraints.one_hot``
+- ``constraints.positive_integer``
+- ``constraints.positive``
+- ``constraints.positive_semidefinite``
+- ``constraints.positive_definite``
+- ``constraints.real_vector``
+- ``constraints.real``
+- ``constraints.simplex``
+- ``constraints.square``
+- ``constraints.stack``
+- ``constraints.symmetric``
+- ``constraints.unit_interval``
+"""
+
+import torch
+
+__all__ = [
+    "Constraint",
+    "boolean",
+    "cat",
+    "corr_cholesky",
+    "dependent",
+    "dependent_property",
+    "greater_than",
+    "greater_than_eq",
+    "independent",
+    "integer_interval",
+    "interval",
+    "half_open_interval",
+    "is_dependent",
+    "less_than",
+    "lower_cholesky",
+    "lower_triangular",
+    "multinomial",
+    "nonnegative",
+    "nonnegative_integer",
+    "one_hot",
+    "positive",
+    "positive_semidefinite",
+    "positive_definite",
+    "positive_integer",
+    "real",
+    "real_vector",
+    "simplex",
+    "square",
+    "stack",
+    "symmetric",
+    "unit_interval",
+]
+
+
+class Constraint:
+    """
+    Abstract base class for constraints.
+
+    A constraint object represents a region over which a variable is valid,
+    e.g. within which a variable can be optimized.
+
+    Attributes:
+        is_discrete (bool): Whether constrained space is discrete.
+            Defaults to False.
+        event_dim (int): Number of rightmost dimensions that together define
+            an event. The :meth:`check` method will remove this many dimensions
+            when computing validity.
+    """
+
+    is_discrete = False  # Default to continuous.
+    event_dim = 0  # Default to univariate.
+
+    def check(self, value):
+        """
+        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
+        whether each event in value satisfies this constraint.
+        """
+        raise NotImplementedError
+
+    def __repr__(self):
+        return self.__class__.__name__[1:] + "()"
+
+
+class _Dependent(Constraint):
+    """
+    Placeholder for variables whose support depends on other variables.
+    These variables obey no simple coordinate-wise constraints.
+
+    Args:
+        is_discrete (bool): Optional value of ``.is_discrete`` in case this
+            can be computed statically. If not provided, access to the
+            ``.is_discrete`` attribute will raise a NotImplementedError.
+        event_dim (int): Optional value of ``.event_dim`` in case this
+            can be computed statically. If not provided, access to the
+            ``.event_dim`` attribute will raise a NotImplementedError.
+    """
+
+    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
+        self._is_discrete = is_discrete
+        self._event_dim = event_dim
+        super().__init__()
+
+    @property
+    def is_discrete(self):
+        if self._is_discrete is NotImplemented:
+            raise NotImplementedError(".is_discrete cannot be determined statically")
+        return self._is_discrete
+
+    @property
+    def event_dim(self):
+        if self._event_dim is NotImplemented:
+            raise NotImplementedError(".event_dim cannot be determined statically")
+        return self._event_dim
+
+    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
+        """
+        Support for syntax to customize static attributes::
+
+            constraints.dependent(is_discrete=True, event_dim=1)
+        """
+        if is_discrete is NotImplemented:
+            is_discrete = self._is_discrete
+        if event_dim is NotImplemented:
+            event_dim = self._event_dim
+        return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
+
+    def check(self, x):
+        raise ValueError("Cannot determine validity of dependent constraint")
+
+
+def is_dependent(constraint):
+    return isinstance(constraint, _Dependent)
+
+
+class _DependentProperty(property, _Dependent):
+    """
+    Decorator that extends @property to act like a `Dependent` constraint when
+    called on a class and act like a property when called on an object.
+
+    Example::
+
+        class Uniform(Distribution):
+            def __init__(self, low, high):
+                self.low = low
+                self.high = high
+            @constraints.dependent_property(is_discrete=False, event_dim=0)
+            def support(self):
+                return constraints.interval(self.low, self.high)
+
+    Args:
+        fn (Callable): The function to be decorated.
+        is_discrete (bool): Optional value of ``.is_discrete`` in case this
+            can be computed statically. If not provided, access to the
+            ``.is_discrete`` attribute will raise a NotImplementedError.
+        event_dim (int): Optional value of ``.event_dim`` in case this
+            can be computed statically. If not provided, access to the
+            ``.event_dim`` attribute will raise a NotImplementedError.
+    """
+
+    def __init__(
+        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
+    ):
+        super().__init__(fn)
+        self._is_discrete = is_discrete
+        self._event_dim = event_dim
+
+    def __call__(self, fn):
+        """
+        Support for syntax to customize static attributes::
+
+            @constraints.dependent_property(is_discrete=True, event_dim=1)
+            def support(self):
+                ...
+        """
+        return _DependentProperty(
+            fn, is_discrete=self._is_discrete, event_dim=self._event_dim
+        )
+
+
+class _IndependentConstraint(Constraint):
+    """
+    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
+    dims in :meth:`check`, so that an event is valid only if all its
+    independent entries are valid.
+    """
+
+    def __init__(self, base_constraint, reinterpreted_batch_ndims):
+        assert isinstance(base_constraint, Constraint)
+        assert isinstance(reinterpreted_batch_ndims, int)
+        assert reinterpreted_batch_ndims >= 0
+        self.base_constraint = base_constraint
+        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
+        super().__init__()
+
+    @property
+    def is_discrete(self):
+        return self.base_constraint.is_discrete
+
+    @property
+    def event_dim(self):
+        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
+
+    def check(self, value):
+        result = self.base_constraint.check(value)
+        if result.dim() < self.reinterpreted_batch_ndims:
+            expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
+            raise ValueError(
+                f"Expected value.dim() >= {expected} but got {value.dim()}"
+            )
+        result = result.reshape(
+            result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
+        )
+        result = result.all(-1)
+        return result
+
+    def __repr__(self):
+        return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
+
+
+class _Boolean(Constraint):
+    """
+    Constrain to the two values `{0, 1}`.
+    """
+
+    is_discrete = True
+
+    def check(self, value):
+        return (value == 0) | (value == 1)
+
+
+class _OneHot(Constraint):
+    """
+    Constrain to one-hot vectors.
+    """
+
+    is_discrete = True
+    event_dim = 1
+
+    def check(self, value):
+        is_boolean = (value == 0) | (value == 1)
+        is_normalized = value.sum(-1).eq(1)
+        return is_boolean.all(-1) & is_normalized
+
+
+class _IntegerInterval(Constraint):
+    """
+    Constrain to an integer interval `[lower_bound, upper_bound]`.
+    """
+
+    is_discrete = True
+
+    def __init__(self, lower_bound, upper_bound):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        super().__init__()
+
+    def check(self, value):
+        return (
+            (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
+        )
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += (
+            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
+        )
+        return fmt_string
+
+
+class _IntegerLessThan(Constraint):
+    """
+    Constrain to an integer interval `(-inf, upper_bound]`.
+    """
+
+    is_discrete = True
+
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+        super().__init__()
+
+    def check(self, value):
+        return (value % 1 == 0) & (value <= self.upper_bound)
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += f"(upper_bound={self.upper_bound})"
+        return fmt_string
+
+
+class _IntegerGreaterThan(Constraint):
+    """
+    Constrain to an integer interval `[lower_bound, inf)`.
+    """
+
+    is_discrete = True
+
+    def __init__(self, lower_bound):
+        self.lower_bound = lower_bound
+        super().__init__()
+
+    def check(self, value):
+        return (value % 1 == 0) & (value >= self.lower_bound)
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += f"(lower_bound={self.lower_bound})"
+        return fmt_string
+
+
+class _Real(Constraint):
+    """
+    Trivially constrain to the extended real line `[-inf, inf]`.
+    """
+
+    def check(self, value):
+        return value == value  # False for NaNs.
+
+
+class _GreaterThan(Constraint):
+    """
+    Constrain to a real half line `(lower_bound, inf]`.
+    """
+
+    def __init__(self, lower_bound):
+        self.lower_bound = lower_bound
+        super().__init__()
+
+    def check(self, value):
+        return self.lower_bound < value
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += f"(lower_bound={self.lower_bound})"
+        return fmt_string
+
+
+class _GreaterThanEq(Constraint):
+    """
+    Constrain to a real half line `[lower_bound, inf)`.
+    """
+
+    def __init__(self, lower_bound):
+        self.lower_bound = lower_bound
+        super().__init__()
+
+    def check(self, value):
+        return self.lower_bound <= value
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += f"(lower_bound={self.lower_bound})"
+        return fmt_string
+
+
+class _LessThan(Constraint):
+    """
+    Constrain to a real half line `[-inf, upper_bound)`.
+    """
+
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+        super().__init__()
+
+    def check(self, value):
+        return value < self.upper_bound
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += f"(upper_bound={self.upper_bound})"
+        return fmt_string
+
+
+class _Interval(Constraint):
+    """
+    Constrain to a real interval `[lower_bound, upper_bound]`.
+    """
+
+    def __init__(self, lower_bound, upper_bound):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        super().__init__()
+
+    def check(self, value):
+        return (self.lower_bound <= value) & (value <= self.upper_bound)
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += (
+            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
+        )
+        return fmt_string
+
+
+class _HalfOpenInterval(Constraint):
+    """
+    Constrain to a real interval `[lower_bound, upper_bound)`.
+    """
+
+    def __init__(self, lower_bound, upper_bound):
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        super().__init__()
+
+    def check(self, value):
+        return (self.lower_bound <= value) & (value < self.upper_bound)
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__[1:]
+        fmt_string += (
+            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
+        )
+        return fmt_string
+
+
+class _Simplex(Constraint):
+    """
+    Constrain to the unit simplex in the innermost (rightmost) dimension.
+    Specifically: `x >= 0` and `x.sum(-1) == 1`.
+    """
+
+    event_dim = 1
+
+    def check(self, value):
+        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
+
+
+class _Multinomial(Constraint):
+    """
+    Constrain to nonnegative integer values summing to at most an upper bound.
+
+    Note that, due to limitations of the Multinomial distribution, this currently
+    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
+    this may be strengthened to ``value.sum(-1) == upper_bound``.
+    """
+
+    is_discrete = True
+    event_dim = 1
+
+    def __init__(self, upper_bound):
+        self.upper_bound = upper_bound
+
+    def check(self, x):
+        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)
+
+
+class _LowerTriangular(Constraint):
+    """
+    Constrain to lower-triangular square matrices.
+    """
+
+    event_dim = 2
+
+    def check(self, value):
+        value_tril = value.tril()
+        return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
+
+
+class _LowerCholesky(Constraint):
+    """
+    Constrain to lower-triangular square matrices with positive diagonals.
+    """
+
+    event_dim = 2
+
+    def check(self, value):
+        value_tril = value.tril()
+        lower_triangular = (
+            (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
+        )
+
+        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
+        return lower_triangular & positive_diagonal
+
+
+class _CorrCholesky(Constraint):
+    """
+    Constrain to lower-triangular square matrices with positive diagonals and each
+    row vector being of unit length.
+    """
+
+    event_dim = 2
+
+    def check(self, value):
+        tol = (
+            torch.finfo(value.dtype).eps * value.size(-1) * 10
+        )  # 10 is an adjustable fudge factor
+        row_norm = torch.linalg.norm(value.detach(), dim=-1)
+        unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
+        return _LowerCholesky().check(value) & unit_row_norm
+
+
+class _Square(Constraint):
+    """
+    Constrain to square matrices.
+    """
+
+    event_dim = 2
+
+    def check(self, value):
+        return torch.full(
+            size=value.shape[:-2],
+            fill_value=(value.shape[-2] == value.shape[-1]),
+            dtype=torch.bool,
+            device=value.device,
+        )
+
+
+class _Symmetric(_Square):
+    """
+    Constrain to symmetric square matrices.
+    """
+
+    def check(self, value):
+        square_check = super().check(value)
+        if not square_check.all():
+            return square_check
+        return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
+
+
+class _PositiveSemidefinite(_Symmetric):
+    """
+    Constrain to positive-semidefinite matrices.
+    """
+
+    def check(self, value):
+        sym_check = super().check(value)
+        if not sym_check.all():
+            return sym_check
+        return torch.linalg.eigvalsh(value).ge(0).all(-1)
+
+
+class _PositiveDefinite(_Symmetric):
+    """
+    Constrain to positive-definite matrices.
+    """
+
+    def check(self, value):
+        sym_check = super().check(value)
+        if not sym_check.all():
+            return sym_check
+        return torch.linalg.cholesky_ex(value).info.eq(0)
+
+
+class _Cat(Constraint):
+    """
+    Constraint functor that applies a sequence of constraints
+    `cseq` at the submatrices at dimension `dim`,
+    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
+    """
+
+    def __init__(self, cseq, dim=0, lengths=None):
+        assert all(isinstance(c, Constraint) for c in cseq)
+        self.cseq = list(cseq)
+        if lengths is None:
+            lengths = [1] * len(self.cseq)
+        self.lengths = list(lengths)
+        assert len(self.lengths) == len(self.cseq)
+        self.dim = dim
+        super().__init__()
+
+    @property
+    def is_discrete(self):
+        return any(c.is_discrete for c in self.cseq)
+
+    @property
+    def event_dim(self):
+        return max(c.event_dim for c in self.cseq)
+
+    def check(self, value):
+        assert -value.dim() <= self.dim < value.dim()
+        checks = []
+        start = 0
+        for constr, length in zip(self.cseq, self.lengths):
+            v = value.narrow(self.dim, start, length)
+            checks.append(constr.check(v))
+            start = start + length  # avoid += for jit compat
+        return torch.cat(checks, self.dim)
+
+
+class _Stack(Constraint):
+    """
+    Constraint functor that applies a sequence of constraints
+    `cseq` at the submatrices at dimension `dim`,
+    in a way compatible with :func:`torch.stack`.
+    """
+
+    def __init__(self, cseq, dim=0):
+        assert all(isinstance(c, Constraint) for c in cseq)
+        self.cseq = list(cseq)
+        self.dim = dim
+        super().__init__()
+
+    @property
+    def is_discrete(self):
+        return any(c.is_discrete for c in self.cseq)
+
+    @property
+    def event_dim(self):
+        dim = max(c.event_dim for c in self.cseq)
+        if self.dim + dim < 0:
+            dim += 1
+        return dim
+
+    def check(self, value):
+        assert -value.dim() <= self.dim < value.dim()
+        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
+        return torch.stack(
+            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
+        )
+
+
+# Public interface.
+dependent = _Dependent()
+dependent_property = _DependentProperty
+independent = _IndependentConstraint
+boolean = _Boolean()
+one_hot = _OneHot()
+nonnegative_integer = _IntegerGreaterThan(0)
+positive_integer = _IntegerGreaterThan(1)
+integer_interval = _IntegerInterval
+real = _Real()
+real_vector = independent(real, 1)
+positive = _GreaterThan(0.0)
+nonnegative = _GreaterThanEq(0.0)
+greater_than = _GreaterThan
+greater_than_eq = _GreaterThanEq
+less_than = _LessThan
+multinomial = _Multinomial
+unit_interval = _Interval(0.0, 1.0)
+interval = _Interval
+half_open_interval = _HalfOpenInterval
+simplex = _Simplex()
+lower_triangular = _LowerTriangular()
+lower_cholesky = _LowerCholesky()
+corr_cholesky = _CorrCholesky()
+square = _Square()
+symmetric = _Symmetric()
+positive_semidefinite = _PositiveSemidefinite()
+positive_definite = _PositiveDefinite()
+cat = _Cat
+stack = _Stack
diff --git a/MLPY/Lib/site-packages/torch/distributions/continuous_bernoulli.py b/MLPY/Lib/site-packages/torch/distributions/continuous_bernoulli.py
new file mode 100644
index 0000000000000000000000000000000000000000..a867738dbe74085414bc3d3a591dc40531a57173
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/continuous_bernoulli.py
@@ -0,0 +1,235 @@
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import (
+    broadcast_all,
+    clamp_probs,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+__all__ = ["ContinuousBernoulli"]
+
+
+class ContinuousBernoulli(ExponentialFamily):
+    r"""
+    Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
+    or :attr:`logits` (but not both).
+
+    The distribution is supported in [0, 1] and parameterized by 'probs' (in
+    (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
+    does not correspond to a probability and 'logits' does not correspond to
+    log-odds, but the same names are used due to the similarity with the
+    Bernoulli. See [1] for more details.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = ContinuousBernoulli(torch.tensor([0.3]))
+        >>> m.sample()
+        tensor([ 0.2538])
+
+    Args:
+        probs (Number, Tensor): (0,1) valued parameters
+        logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
+
+    [1] The continuous Bernoulli: fixing a pervasive error in variational
+    autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
+    https://arxiv.org/abs/1907.06845
+    """
+    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
+    support = constraints.unit_interval
+    _mean_carrier_measure = 0
+    has_rsample = True
+
+    def __init__(
+        self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
+    ):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            is_scalar = isinstance(probs, Number)
+            (self.probs,) = broadcast_all(probs)
+            # validate 'probs' here if necessary, as it is later clamped for numerical
+            # stability close to 0 and 1; otherwise the clamped 'probs' would always pass
+            if validate_args is not None:
+                if not self.arg_constraints["probs"].check(self.probs).all():
+                    raise ValueError("The parameter probs has invalid values")
+            self.probs = clamp_probs(self.probs)
+        else:
+            is_scalar = isinstance(logits, Number)
+            (self.logits,) = broadcast_all(logits)
+        self._param = self.probs if probs is not None else self.logits
+        if is_scalar:
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self._param.size()
+        self._lims = lims
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(ContinuousBernoulli, _instance)
+        new._lims = self._lims
+        batch_shape = torch.Size(batch_shape)
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    def _outside_unstable_region(self):
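+        # torch.max of the two comparisons acts as an elementwise OR: True wherever
+        # probs lies outside the numerically unstable window around 0.5.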
+        return torch.max(
+            torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
+        )
+
+    def _cut_probs(self):
+        return torch.where(
+            self._outside_unstable_region(),
+            self.probs,
+            self._lims[0] * torch.ones_like(self.probs),
+        )
+
+    def _cont_bern_log_norm(self):
+        """computes the log normalizing constant as a function of the 'probs' parameter"""
+        cut_probs = self._cut_probs()
+        cut_probs_below_half = torch.where(
+            torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
+        )
+        cut_probs_above_half = torch.where(
+            torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
+        )
+        log_norm = torch.log(
+            torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
+        ) - torch.where(
+            torch.le(cut_probs, 0.5),
+            torch.log1p(-2.0 * cut_probs_below_half),
+            torch.log(2.0 * cut_probs_above_half - 1.0),
+        )
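+        # Near probs = 0.5 the closed form above is numerically indeterminate, so a
+        # Taylor expansion of the log normalizer around probs = 0.5 is used inside
+        # the unstable window.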
+        x = torch.pow(self.probs - 0.5, 2)
+        taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
+        return torch.where(self._outside_unstable_region(), log_norm, taylor)
+
+    @property
+    def mean(self):
+        cut_probs = self._cut_probs()
+        mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
+            torch.log1p(-cut_probs) - torch.log(cut_probs)
+        )
+        x = self.probs - 0.5
+        taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
+        return torch.where(self._outside_unstable_region(), mus, taylor)
+
+    @property
+    def stddev(self):
+        return torch.sqrt(self.variance)
+
+    @property
+    def variance(self):
+        cut_probs = self._cut_probs()
+        vars = cut_probs * (cut_probs - 1.0) / torch.pow(
+            1.0 - 2.0 * cut_probs, 2
+        ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
+        x = torch.pow(self.probs - 0.5, 2)
+        taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
+        return torch.where(self._outside_unstable_region(), vars, taylor)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return clamp_probs(logits_to_probs(self.logits, is_binary=True))
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+        with torch.no_grad():
+            return self.icdf(u)
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+        return self.icdf(u)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        return (
+            -binary_cross_entropy_with_logits(logits, value, reduction="none")
+            + self._cont_bern_log_norm()
+        )
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        cut_probs = self._cut_probs()
+        cdfs = (
+            torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
+            + cut_probs
+            - 1.0
+        ) / (2.0 * cut_probs - 1.0)
+        unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
+        return torch.where(
+            torch.le(value, 0.0),
+            torch.zeros_like(value),
+            torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
+        )
+
+    def icdf(self, value):
+        cut_probs = self._cut_probs()
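+        # Closed-form inverse of the CDF; inside the unstable window around
+        # probs = 0.5 the distribution is close to Uniform(0, 1), so the identity
+        # map is used there instead.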
+        return torch.where(
+            self._outside_unstable_region(),
+            (
+                torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
+                - torch.log1p(-cut_probs)
+            )
+            / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
+            value,
+        )
+
+    def entropy(self):
+        log_probs0 = torch.log1p(-self.probs)
+        log_probs1 = torch.log(self.probs)
+        return (
+            self.mean * (log_probs0 - log_probs1)
+            - self._cont_bern_log_norm()
+            - log_probs0
+        )
+
+    @property
+    def _natural_params(self):
+        return (self.logits,)
+
+    def _log_normalizer(self, x):
+        """computes the log normalizing constant as a function of the natural parameter"""
+        out_unst_reg = torch.max(
+            torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
+        )
+        cut_nat_params = torch.where(
+            out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
+        )
+        log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
+            torch.abs(cut_nat_params)
+        )
+        taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
+        return torch.where(out_unst_reg, log_norm, taylor)
diff --git a/MLPY/Lib/site-packages/torch/distributions/dirichlet.py b/MLPY/Lib/site-packages/torch/distributions/dirichlet.py
new file mode 100644
index 0000000000000000000000000000000000000000..514433b3478bc08ae1168c4e56294fad04afa03f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/dirichlet.py
@@ -0,0 +1,123 @@
+import torch
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+
+__all__ = ["Dirichlet"]
+
+
+# This helper is exposed for testing.
+def _Dirichlet_backward(x, concentration, grad_output):
+    total = concentration.sum(-1, True).expand_as(concentration)
+    grad = torch._dirichlet_grad(x, concentration, total)
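+    # Subtracting (x * grad_output).sum(-1) accounts for the constraint that the
+    # sampled components sum to 1, keeping the gradient in the simplex tangent space.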
+    return grad * (grad_output - (x * grad_output).sum(-1, True))
+
+
+class _Dirichlet(Function):
+    @staticmethod
+    def forward(ctx, concentration):
+        x = torch._sample_dirichlet(concentration)
+        ctx.save_for_backward(x, concentration)
+        return x
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        x, concentration = ctx.saved_tensors
+        return _Dirichlet_backward(x, concentration, grad_output)
+
+
+class Dirichlet(ExponentialFamily):
+    r"""
+    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
+        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
+        tensor([ 0.1046,  0.8954])
+
+    Args:
+        concentration (Tensor): concentration parameter of the distribution
+            (often referred to as alpha)
+    """
+    arg_constraints = {
+        "concentration": constraints.independent(constraints.positive, 1)
+    }
+    support = constraints.simplex
+    has_rsample = True
+
+    def __init__(self, concentration, validate_args=None):
+        if concentration.dim() < 1:
+            raise ValueError(
+                "`concentration` parameter must be at least one-dimensional."
+            )
+        self.concentration = concentration
+        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Dirichlet, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
+        super(Dirichlet, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=()):
+        shape = self._extended_shape(sample_shape)
+        concentration = self.concentration.expand(shape)
+        return _Dirichlet.apply(concentration)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return (
+            torch.xlogy(self.concentration - 1.0, value).sum(-1)
+            + torch.lgamma(self.concentration.sum(-1))
+            - torch.lgamma(self.concentration).sum(-1)
+        )
+
+    @property
+    def mean(self):
+        return self.concentration / self.concentration.sum(-1, True)
+
+    @property
+    def mode(self):
+        concentrationm1 = (self.concentration - 1).clamp(min=0.0)
+        mode = concentrationm1 / concentrationm1.sum(-1, True)
+        mask = (self.concentration < 1).all(axis=-1)
+        mode[mask] = torch.nn.functional.one_hot(
+            mode[mask].argmax(axis=-1), concentrationm1.shape[-1]
+        ).to(mode)
+        return mode
+
+    @property
+    def variance(self):
+        con0 = self.concentration.sum(-1, True)
+        return (
+            self.concentration
+            * (con0 - self.concentration)
+            / (con0.pow(2) * (con0 + 1))
+        )
+
+    def entropy(self):
+        k = self.concentration.size(-1)
+        a0 = self.concentration.sum(-1)
+        return (
+            torch.lgamma(self.concentration).sum(-1)
+            - torch.lgamma(a0)
+            - (k - a0) * torch.digamma(a0)
+            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
+        )
+
+    @property
+    def _natural_params(self):
+        return (self.concentration,)
+
+    def _log_normalizer(self, x):
+        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
diff --git a/MLPY/Lib/site-packages/torch/distributions/distribution.py b/MLPY/Lib/site-packages/torch/distributions/distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9424a7518f8593cda5957ef6acf51a266e3b5d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/distribution.py
@@ -0,0 +1,336 @@
+import warnings
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.utils import lazy_property
+from torch.types import _size
+
+__all__ = ["Distribution"]
+
+
+class Distribution:
+    r"""
+    Distribution is the abstract base class for probability distributions.
+    """
+
+    has_rsample = False
+    has_enumerate_support = False
+    _validate_args = __debug__
+
+    @staticmethod
+    def set_default_validate_args(value: bool) -> None:
+        """
+        Sets whether validation is enabled or disabled.
+
+        The default behavior mimics Python's ``assert`` statement: validation
+        is on by default, but is disabled if Python is run in optimized mode
+        (via ``python -O``). Validation may be expensive, so you may want to
+        disable it once a model is working.
+
+        Args:
+            value (bool): Whether to enable validation.
+        """
+        if value not in [True, False]:
+            raise ValueError
+        Distribution._validate_args = value
+
+    def __init__(
+        self,
+        batch_shape: torch.Size = torch.Size(),
+        event_shape: torch.Size = torch.Size(),
+        validate_args: Optional[bool] = None,
+    ):
+        self._batch_shape = batch_shape
+        self._event_shape = event_shape
+        if validate_args is not None:
+            self._validate_args = validate_args
+        if self._validate_args:
+            try:
+                arg_constraints = self.arg_constraints
+            except NotImplementedError:
+                arg_constraints = {}
+                warnings.warn(
+                    f"{self.__class__} does not define `arg_constraints`. "
+                    + "Please set `arg_constraints = {}` or initialize the distribution "
+                    + "with `validate_args=False` to turn off validation."
+                )
+            for param, constraint in arg_constraints.items():
+                if constraints.is_dependent(constraint):
+                    continue  # skip constraints that cannot be checked
+                if param not in self.__dict__ and isinstance(
+                    getattr(type(self), param), lazy_property
+                ):
+                    continue  # skip checking lazily-constructed args
+                value = getattr(self, param)
+                valid = constraint.check(value)
+                if not valid.all():
+                    raise ValueError(
+                        f"Expected parameter {param} "
+                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
+                        f"of distribution {repr(self)} "
+                        f"to satisfy the constraint {repr(constraint)}, "
+                        f"but found invalid values:\n{value}"
+                    )
+        super().__init__()
+
+    def expand(self, batch_shape: torch.Size, _instance=None):
+        """
+        Returns a new distribution instance (or populates an existing instance
+        provided by a derived class) with batch dimensions expanded to
+        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
+        the distribution's parameters. As such, this does not allocate new
+        memory for the expanded distribution instance. Additionally,
+        this does not repeat any args checking or parameter broadcasting done in
+        `__init__` when an instance is first created.
+
+        Args:
+            batch_shape (torch.Size): the desired expanded size.
+            _instance: new instance provided by subclasses that
+                need to override `.expand`.
+
+        Returns:
+            New distribution instance with batch dimensions expanded to
+            `batch_size`.
+        """
+        raise NotImplementedError
+
+    @property
+    def batch_shape(self) -> torch.Size:
+        """
+        Returns the shape over which parameters are batched.
+        """
+        return self._batch_shape
+
+    @property
+    def event_shape(self) -> torch.Size:
+        """
+        Returns the shape of a single sample (without batching).
+        """
+        return self._event_shape
+
+    @property
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
+        """
+        Returns a dictionary from argument names to
+        :class:`~torch.distributions.constraints.Constraint` objects that
+        should be satisfied by each argument of this distribution. Args that
+        are not tensors need not appear in this dict.
+        """
+        raise NotImplementedError
+
+    @property
+    def support(self) -> Optional[Any]:
+        """
+        Returns a :class:`~torch.distributions.constraints.Constraint` object
+        representing this distribution's support.
+        """
+        raise NotImplementedError
+
+    @property
+    def mean(self) -> torch.Tensor:
+        """
+        Returns the mean of the distribution.
+        """
+        raise NotImplementedError
+
+    @property
+    def mode(self) -> torch.Tensor:
+        """
+        Returns the mode of the distribution.
+        """
+        raise NotImplementedError(f"{self.__class__} does not implement mode")
+
+    @property
+    def variance(self) -> torch.Tensor:
+        """
+        Returns the variance of the distribution.
+        """
+        raise NotImplementedError
+
+    @property
+    def stddev(self) -> torch.Tensor:
+        """
+        Returns the standard deviation of the distribution.
+        """
+        return self.variance.sqrt()
+
+    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+        """
+        Generates a sample_shape shaped sample or sample_shape shaped batch of
+        samples if the distribution parameters are batched.
+        """
+        with torch.no_grad():
+            return self.rsample(sample_shape)
+
+    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+        """
+        Generates a sample_shape shaped reparameterized sample or sample_shape
+        shaped batch of reparameterized samples if the distribution parameters
+        are batched.
+        """
+        raise NotImplementedError
+
+    def sample_n(self, n: int) -> torch.Tensor:
+        """
+        Generates n samples or n batches of samples if the distribution
+        parameters are batched.
+        """
+        warnings.warn(
+            "sample_n will be deprecated. Use .sample((n,)) instead", UserWarning
+        )
+        return self.sample(torch.Size((n,)))
+
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """
+        Returns the log of the probability density/mass function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor):
+        """
+        raise NotImplementedError
+
+    def cdf(self, value: torch.Tensor) -> torch.Tensor:
+        """
+        Returns the cumulative density/mass function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor):
+        """
+        raise NotImplementedError
+
+    def icdf(self, value: torch.Tensor) -> torch.Tensor:
+        """
+        Returns the inverse cumulative density/mass function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor):
+        """
+        raise NotImplementedError
+
+    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
+        """
+        Returns tensor containing all values supported by a discrete
+        distribution. The result will enumerate over dimension 0, so the shape
+        of the result will be `(cardinality,) + batch_shape + event_shape`
+        (where `event_shape = ()` for univariate distributions).
+
+        Note that this enumerates over all batched tensors in lock-step
+        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
+        along dim 0, but with the remaining batch dimensions being
+        singleton dimensions, `[[0], [1], ...]`.
+
+        To iterate over the full Cartesian product use
+        `itertools.product(m.enumerate_support())`.
+
+        Args:
+            expand (bool): whether to expand the support over the
+                batch dims to match the distribution's `batch_shape`.
+
+        Returns:
+            Tensor iterating over dimension 0.
+        """
+        raise NotImplementedError
+
+    def entropy(self) -> torch.Tensor:
+        """
+        Returns entropy of distribution, batched over batch_shape.
+
+        Returns:
+            Tensor of shape batch_shape.
+        """
+        raise NotImplementedError
+
+    def perplexity(self) -> torch.Tensor:
+        """
+        Returns perplexity of distribution, batched over batch_shape.
+
+        Returns:
+            Tensor of shape batch_shape.
+        """
+        return torch.exp(self.entropy())
+
+    def _extended_shape(self, sample_shape: _size = torch.Size()) -> Tuple[int, ...]:
+        """
+        Returns the size of the sample returned by the distribution, given
+        a `sample_shape`. Note that the batch and event shapes of a distribution
+        instance are fixed at the time of construction. If this is empty, the
+        returned shape is upcast to (1,).
+
+        Args:
+            sample_shape (torch.Size): the size of the sample to be drawn.
+        """
+        if not isinstance(sample_shape, torch.Size):
+            sample_shape = torch.Size(sample_shape)
+        return torch.Size(sample_shape + self._batch_shape + self._event_shape)
+
+    def _validate_sample(self, value: torch.Tensor) -> None:
+        """
+        Argument validation for distribution methods such as `log_prob`,
+        `cdf` and `icdf`. The rightmost dimensions of a value to be
+        scored via these methods must agree with the distribution's batch
+        and event shapes.
+
+        Args:
+            value (Tensor): the tensor whose log probability is to be
+                computed by the `log_prob` method.
+        Raises:
+            ValueError: when the rightmost dimensions of `value` do not match the
+                distribution's batch and event shapes.
+        """
+        if not isinstance(value, torch.Tensor):
+            raise ValueError("The value argument to log_prob must be a Tensor")
+
+        event_dim_start = len(value.size()) - len(self._event_shape)
+        if value.size()[event_dim_start:] != self._event_shape:
+            raise ValueError(
+                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
+            )
+
+        actual_shape = value.size()
+        expected_shape = self._batch_shape + self._event_shape
+        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
+            if i != 1 and j != 1 and i != j:
+                raise ValueError(
+                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
+                )
+        try:
+            support = self.support
+        except NotImplementedError:
+            warnings.warn(
+                f"{self.__class__} does not define `support` to enable "
+                + "sample validation. Please initialize the distribution with "
+                + "`validate_args=False` to turn off validation."
+            )
+            return
+        assert support is not None
+        valid = support.check(value)
+        if not valid.all():
+            raise ValueError(
+                "Expected value argument "
+                f"({type(value).__name__} of shape {tuple(value.shape)}) "
+                f"to be within the support ({repr(support)}) "
+                f"of the distribution {repr(self)}, "
+                f"but found invalid values:\n{value}"
+            )
+
+    def _get_checked_instance(self, cls, _instance=None):
+        if _instance is None and type(self).__init__ != cls.__init__:
+            raise NotImplementedError(
+                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
+                "must also define a custom .expand() method."
+            )
+        return self.__new__(type(self)) if _instance is None else _instance
+
+    def __repr__(self) -> str:
+        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
+        args_string = ", ".join(
+            [
+                f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
+                for p in param_names
+            ]
+        )
+        return self.__class__.__name__ + "(" + args_string + ")"
diff --git a/MLPY/Lib/site-packages/torch/distributions/exp_family.py b/MLPY/Lib/site-packages/torch/distributions/exp_family.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06585810a02a42ea239cbd2fc110197e655335e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/exp_family.py
@@ -0,0 +1,62 @@
+import torch
+from torch.distributions.distribution import Distribution
+
+__all__ = ["ExponentialFamily"]
+
+
+class ExponentialFamily(Distribution):
+    r"""
+    ExponentialFamily is the abstract base class for probability distributions belonging to an
+    exponential family, whose probability mass/density function has the form defined below
+
+    .. math::
+
+        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
+
+    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
+    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
+    measure.
+
+    Note:
+        This class is an intermediary between the `Distribution` class and distributions which belong
+        to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
+        divergence methods. We use this class to compute the entropy and KL divergence using the AD
+        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
+        Cross-entropies of Exponential Families).
+    """
+
+    @property
+    def _natural_params(self):
+        """
+        Abstract method for natural parameters. Returns a tuple of Tensors based
+        on the distribution
+        """
+        raise NotImplementedError
+
+    def _log_normalizer(self, *natural_params):
+        """
+        Abstract method for log normalizer function. Returns a log normalizer based on
+        the distribution and input
+        """
+        raise NotImplementedError
+
+    @property
+    def _mean_carrier_measure(self):
+        """
+        Abstract method for expected carrier measure, which is required for computing
+        entropy.
+        """
+        raise NotImplementedError
+
+    def entropy(self):
+        """
+        Method to compute the entropy using Bregman divergence of the log normalizer.
+        """
+        result = -self._mean_carrier_measure
+        nparams = [p.detach().requires_grad_() for p in self._natural_params]
+        lg_normal = self._log_normalizer(*nparams)
+        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
+        result += lg_normal
+        for np, g in zip(nparams, gradients):
+            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
+        return result
diff --git a/MLPY/Lib/site-packages/torch/distributions/exponential.py b/MLPY/Lib/site-packages/torch/distributions/exponential.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b624d3b88dc160aa8de4c6e60227c9cfbb33934
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/exponential.py
@@ -0,0 +1,84 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Exponential"]
+
+
+class Exponential(ExponentialFamily):
+    r"""
+    Creates an Exponential distribution parameterized by :attr:`rate`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Exponential(torch.tensor([1.0]))
+        >>> m.sample()  # Exponential distributed with rate=1
+        tensor([ 0.1046])
+
+    Args:
+        rate (float or Tensor): rate = 1 / scale of the distribution
+    """
+    arg_constraints = {"rate": constraints.positive}
+    support = constraints.nonnegative
+    has_rsample = True
+    _mean_carrier_measure = 0
+
+    @property
+    def mean(self):
+        return self.rate.reciprocal()
+
+    @property
+    def mode(self):
+        return torch.zeros_like(self.rate)
+
+    @property
+    def stddev(self):
+        return self.rate.reciprocal()
+
+    @property
+    def variance(self):
+        return self.rate.pow(-2)
+
+    def __init__(self, rate, validate_args=None):
+        (self.rate,) = broadcast_all(rate)
+        batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Exponential, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.rate = self.rate.expand(batch_shape)
+        super(Exponential, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
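+        # Draw from a standard Exponential(1) in place, then rescale by 1 / rate.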
+        return self.rate.new(shape).exponential_() / self.rate
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return self.rate.log() - self.rate * value
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return 1 - torch.exp(-self.rate * value)
+
+    def icdf(self, value):
+        return -torch.log1p(-value) / self.rate
+
+    def entropy(self):
+        return 1.0 - torch.log(self.rate)
+
+    @property
+    def _natural_params(self):
+        return (-self.rate,)
+
+    def _log_normalizer(self, x):
+        return -torch.log(-x)
diff --git a/MLPY/Lib/site-packages/torch/distributions/fishersnedecor.py b/MLPY/Lib/site-packages/torch/distributions/fishersnedecor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9d01bcd5105ad644c8fce4e202c2bc21b5dbcd0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/fishersnedecor.py
@@ -0,0 +1,98 @@
+from numbers import Number
+
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.gamma import Gamma
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["FisherSnedecor"]
+
+
+class FisherSnedecor(Distribution):
+    r"""
+    Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
+        >>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
+        tensor([ 0.2453])
+
+    Args:
+        df1 (float or Tensor): degrees of freedom parameter 1
+        df2 (float or Tensor): degrees of freedom parameter 2
+    """
+    arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
+    support = constraints.positive
+    has_rsample = True
+
+    def __init__(self, df1, df2, validate_args=None):
+        self.df1, self.df2 = broadcast_all(df1, df2)
+        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
+        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
+
+        if isinstance(df1, Number) and isinstance(df2, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.df1.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(FisherSnedecor, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.df1 = self.df1.expand(batch_shape)
+        new.df2 = self.df2.expand(batch_shape)
+        new._gamma1 = self._gamma1.expand(batch_shape)
+        new._gamma2 = self._gamma2.expand(batch_shape)
+        super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        df2 = self.df2.clone(memory_format=torch.contiguous_format)
+        df2[df2 <= 2] = nan
+        return df2 / (df2 - 2)
+
+    @property
+    def mode(self):
+        mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
+        mode[self.df1 <= 2] = nan
+        return mode
+
+    @property
+    def variance(self):
+        df2 = self.df2.clone(memory_format=torch.contiguous_format)
+        df2[df2 <= 4] = nan
+        return (
+            2
+            * df2.pow(2)
+            * (self.df1 + df2 - 2)
+            / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
+        )
+
+    def rsample(self, sample_shape=torch.Size(())):
+        shape = self._extended_shape(sample_shape)
+        #   X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
+        #   Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
+        X1 = self._gamma1.rsample(sample_shape).view(shape)
+        X2 = self._gamma2.rsample(sample_shape).view(shape)
+        tiny = torch.finfo(X2.dtype).tiny
+        X2.clamp_(min=tiny)
+        Y = X1 / X2
+        Y.clamp_(min=tiny)
+        return Y
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        ct1 = self.df1 * 0.5
+        ct2 = self.df2 * 0.5
+        ct3 = self.df1 / self.df2
+        t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
+        t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
+        t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
+        return t1 + t2 - t3
diff --git a/MLPY/Lib/site-packages/torch/distributions/gamma.py b/MLPY/Lib/site-packages/torch/distributions/gamma.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b39cb12a407803debd0fec717180248ae090320
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/gamma.py
@@ -0,0 +1,108 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Gamma"]
+
+
+def _standard_gamma(concentration):
+    return torch._standard_gamma(concentration)
+
+
+class Gamma(ExponentialFamily):
+    r"""
+    Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.sample()  # Gamma distributed with concentration=1 and rate=1
+        tensor([ 0.1046])
+
+    Args:
+        concentration (float or Tensor): shape parameter of the distribution
+            (often referred to as alpha)
+        rate (float or Tensor): rate = 1 / scale of the distribution
+            (often referred to as beta)
+    """
+    arg_constraints = {
+        "concentration": constraints.positive,
+        "rate": constraints.positive,
+    }
+    support = constraints.nonnegative
+    has_rsample = True
+    _mean_carrier_measure = 0
+
+    @property
+    def mean(self):
+        return self.concentration / self.rate
+
+    @property
+    def mode(self):
+        return ((self.concentration - 1) / self.rate).clamp(min=0)
+
+    @property
+    def variance(self):
+        return self.concentration / self.rate.pow(2)
+
+    def __init__(self, concentration, rate, validate_args=None):
+        self.concentration, self.rate = broadcast_all(concentration, rate)
+        if isinstance(concentration, Number) and isinstance(rate, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.concentration.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Gamma, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.concentration = self.concentration.expand(batch_shape)
+        new.rate = self.rate.expand(batch_shape)
+        super(Gamma, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(
+            shape
+        )
+        value.detach().clamp_(
+            min=torch.finfo(value.dtype).tiny
+        )  # do not record in autograd graph
+        return value
+
+    def log_prob(self, value):
+        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
+        if self._validate_args:
+            self._validate_sample(value)
+        return (
+            torch.xlogy(self.concentration, self.rate)
+            + torch.xlogy(self.concentration - 1, value)
+            - self.rate * value
+            - torch.lgamma(self.concentration)
+        )
+
+    def entropy(self):
+        return (
+            self.concentration
+            - torch.log(self.rate)
+            + torch.lgamma(self.concentration)
+            + (1.0 - self.concentration) * torch.digamma(self.concentration)
+        )
+
+    @property
+    def _natural_params(self):
+        return (self.concentration - 1, -self.rate)
+
+    def _log_normalizer(self, x, y):
+        return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return torch.special.gammainc(self.concentration, self.rate * value)
diff --git a/MLPY/Lib/site-packages/torch/distributions/geometric.py b/MLPY/Lib/site-packages/torch/distributions/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1322fd72d6144df1e7c6d85728cff672a7ecac
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/geometric.py
@@ -0,0 +1,128 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import (
+    broadcast_all,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+__all__ = ["Geometric"]
+
+
+class Geometric(Distribution):
+    r"""
+    Creates a Geometric distribution parameterized by :attr:`probs`,
+    where :attr:`probs` is the probability of success of Bernoulli trials.
+
+    .. math::
+
+        P(X=k) = (1-p)^{k} p, k = 0, 1, ...
+
+    .. note::
+        For :func:`torch.distributions.geometric.Geometric`, the :math:`(k+1)`-th trial is the
+        first success, hence it draws samples in :math:`\{0, 1, \ldots\}`, whereas for
+        :func:`torch.Tensor.geometric_` the :math:`k`-th trial is the first success, hence it draws samples in :math:`\{1, 2, \ldots\}`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Geometric(torch.tensor([0.3]))
+        >>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
+        tensor([ 2.])
+
+    Args:
+        probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
+        logits (Number, Tensor): the log-odds of sampling `1`.
+    """
+    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
+    support = constraints.nonnegative_integer
+
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            (self.probs,) = broadcast_all(probs)
+        else:
+            (self.logits,) = broadcast_all(logits)
+        probs_or_logits = probs if probs is not None else logits
+        if isinstance(probs_or_logits, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = probs_or_logits.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+        if self._validate_args and probs is not None:
+            # Add an extra check beyond unit_interval
+            value = self.probs
+            valid = value > 0
+            if not valid.all():
+                invalid_value = value.data[~valid]
+                raise ValueError(
+                    "Expected parameter probs "
+                    f"({type(value).__name__} of shape {tuple(value.shape)}) "
+                    f"of distribution {repr(self)} "
+                    f"to be positive but found invalid values:\n{invalid_value}"
+                )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Geometric, _instance)
+        batch_shape = torch.Size(batch_shape)
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+        super(Geometric, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        return 1.0 / self.probs - 1.0
+
+    @property
+    def mode(self):
+        return torch.zeros_like(self.probs)
+
+    @property
+    def variance(self):
+        return (1.0 / self.probs - 1.0) / self.probs
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        tiny = torch.finfo(self.probs.dtype).tiny
+        with torch.no_grad():
+            if torch._C._get_tracing_state():
+                # [JIT WORKAROUND] lack of support for .uniform_()
+                u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+                u = u.clamp(min=tiny)
+            else:
+                u = self.probs.new(shape).uniform_(tiny, 1)
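+            # Inverse-CDF sampling: floor(log(U) / log(1 - p)) counts the failures
+            # before the first success, i.e. a Geometric(p) sample on {0, 1, ...}.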
+            return (u.log() / (-self.probs).log1p()).floor()
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        value, probs = broadcast_all(value, self.probs)
+        probs = probs.clone(memory_format=torch.contiguous_format)
+        probs[(probs == 1) & (value == 0)] = 0
+        return value * (-probs).log1p() + self.probs.log()
+
+    def entropy(self):
+        return (
+            binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
+            / self.probs
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributions/gumbel.py b/MLPY/Lib/site-packages/torch/distributions/gumbel.py
new file mode 100644
index 0000000000000000000000000000000000000000..303e8d8e5d9980c5ccc6d3a436a6ff5a750cf39c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/gumbel.py
@@ -0,0 +1,81 @@
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
+from torch.distributions.uniform import Uniform
+from torch.distributions.utils import broadcast_all, euler_constant
+
+__all__ = ["Gumbel"]
+
+
+class Gumbel(TransformedDistribution):
+    r"""
+    Samples from a Gumbel Distribution.
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
+        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
+        tensor([ 1.0124])
+
+    Args:
+        loc (float or Tensor): Location parameter of the distribution
+        scale (float or Tensor): Scale parameter of the distribution
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        finfo = torch.finfo(self.loc.dtype)
+        if isinstance(loc, Number) and isinstance(scale, Number):
+            base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
+        else:
+            base_dist = Uniform(
+                torch.full_like(self.loc, finfo.tiny),
+                torch.full_like(self.loc, 1 - finfo.eps),
+                validate_args=validate_args,
+            )
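+        # Composed in order, these transforms map a Uniform(0, 1) sample u to
+        # loc - scale * log(-log(u)), the inverse CDF of the Gumbel distribution.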
+        transforms = [
+            ExpTransform().inv,
+            AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
+            ExpTransform().inv,
+            AffineTransform(loc=loc, scale=-self.scale),
+        ]
+        super().__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Gumbel, _instance)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        return super().expand(batch_shape, _instance=new)
+
+    # Explicitly defining the log probability function for Gumbel due to precision issues
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        y = (self.loc - value) / self.scale
+        return (y - y.exp()) - self.scale.log()
+
+    @property
+    def mean(self):
+        return self.loc + self.scale * euler_constant
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def stddev(self):
+        return (math.pi / math.sqrt(6)) * self.scale
+
+    @property
+    def variance(self):
+        return self.stddev.pow(2)
+
+    def entropy(self):
+        return self.scale.log() + (1 + euler_constant)
diff --git a/MLPY/Lib/site-packages/torch/distributions/half_cauchy.py b/MLPY/Lib/site-packages/torch/distributions/half_cauchy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9880ff6e7518c40ecd74745ecedf4f0c5b1d3457
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/half_cauchy.py
@@ -0,0 +1,82 @@
+import math
+
+import torch
+from torch import inf
+from torch.distributions import constraints
+from torch.distributions.cauchy import Cauchy
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AbsTransform
+
+__all__ = ["HalfCauchy"]
+
+
+class HalfCauchy(TransformedDistribution):
+    r"""
+    Creates a half-Cauchy distribution parameterized by `scale` where::
+
+        X ~ Cauchy(0, scale)
+        Y = |X| ~ HalfCauchy(scale)
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = HalfCauchy(torch.tensor([1.0]))
+        >>> m.sample()  # half-cauchy distributed with scale=1
+        tensor([ 2.3214])
+
+    Args:
+        scale (float or Tensor): scale of the full Cauchy distribution
+    """
+    arg_constraints = {"scale": constraints.positive}
+    support = constraints.nonnegative
+    has_rsample = True
+
+    def __init__(self, scale, validate_args=None):
+        base_dist = Cauchy(0, scale, validate_args=False)
+        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(HalfCauchy, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def scale(self):
+        return self.base_dist.scale
+
+    @property
+    def mean(self):
+        return torch.full(
+            self._extended_shape(),
+            math.inf,
+            dtype=self.scale.dtype,
+            device=self.scale.device,
+        )
+
+    @property
+    def mode(self):
+        return torch.zeros_like(self.scale)
+
+    @property
+    def variance(self):
+        return self.base_dist.variance
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        value = torch.as_tensor(
+            value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device
+        )
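+        # Folding X onto |X| doubles the base Cauchy density on [0, inf), hence the
+        # + log(2); values outside the support get -inf log-probability.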
+        log_prob = self.base_dist.log_prob(value) + math.log(2)
+        log_prob = torch.where(value >= 0, log_prob, -inf)
+        return log_prob
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return 2 * self.base_dist.cdf(value) - 1
+
+    def icdf(self, prob):
+        return self.base_dist.icdf((prob + 1) / 2)
+
+    def entropy(self):
+        return self.base_dist.entropy() - math.log(2)
diff --git a/MLPY/Lib/site-packages/torch/distributions/half_normal.py b/MLPY/Lib/site-packages/torch/distributions/half_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..60cc45c6633599006e4981437e0903a6ac4df913
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/half_normal.py
@@ -0,0 +1,74 @@
+import math
+
+import torch
+from torch import inf
+from torch.distributions import constraints
+from torch.distributions.normal import Normal
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AbsTransform
+
+__all__ = ["HalfNormal"]
+
+
+class HalfNormal(TransformedDistribution):
+    r"""
+    Creates a half-normal distribution parameterized by `scale` where::
+
+        X ~ Normal(0, scale)
+        Y = |X| ~ HalfNormal(scale)
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = HalfNormal(torch.tensor([1.0]))
+        >>> m.sample()  # half-normal distributed with scale=1
+        tensor([ 0.1046])
+
+    Args:
+        scale (float or Tensor): scale of the full Normal distribution
+    """
+    arg_constraints = {"scale": constraints.positive}
+    support = constraints.nonnegative
+    has_rsample = True
+
+    def __init__(self, scale, validate_args=None):
+        base_dist = Normal(0, scale, validate_args=False)
+        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(HalfNormal, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def scale(self):
+        return self.base_dist.scale
+
+    @property
+    def mean(self):
+        return self.scale * math.sqrt(2 / math.pi)
+
+    @property
+    def mode(self):
+        return torch.zeros_like(self.scale)
+
+    @property
+    def variance(self):
+        return self.scale.pow(2) * (1 - 2 / math.pi)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        log_prob = self.base_dist.log_prob(value) + math.log(2)
+        log_prob = torch.where(value >= 0, log_prob, -inf)
+        return log_prob
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return 2 * self.base_dist.cdf(value) - 1
+
+    def icdf(self, prob):
+        return self.base_dist.icdf((prob + 1) / 2)
+
+    def entropy(self):
+        return self.base_dist.entropy() - math.log(2)
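+
+# A quick sanity sketch (illustrative only) for the closed-form mean above,
+# E[Y] = scale * sqrt(2 / pi):
+#
+#     >>> import math, torch
+#     >>> from torch.distributions import HalfNormal
+#     >>> d = HalfNormal(torch.tensor([2.0]))
+#     >>> torch.allclose(d.mean, torch.tensor([2.0 * math.sqrt(2 / math.pi)]))
+#     True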
diff --git a/MLPY/Lib/site-packages/torch/distributions/independent.py b/MLPY/Lib/site-packages/torch/distributions/independent.py
new file mode 100644
index 0000000000000000000000000000000000000000..2510c4724d57ae6c81be7f04019f40f15beef785
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/independent.py
@@ -0,0 +1,125 @@
+from typing import Dict
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _sum_rightmost
+
+__all__ = ["Independent"]
+
+
+class Independent(Distribution):
+    r"""
+    Reinterprets some of the batch dims of a distribution as event dims.
+
+    This is mainly useful for changing the shape of the result of
+    :meth:`log_prob`. For example to create a diagonal Normal distribution with
+    the same shape as a Multivariate Normal distribution (so they are
+    interchangeable), you can::
+
+        >>> from torch.distributions.multivariate_normal import MultivariateNormal
+        >>> from torch.distributions.normal import Normal
+        >>> loc = torch.zeros(3)
+        >>> scale = torch.ones(3)
+        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
+        >>> [mvn.batch_shape, mvn.event_shape]
+        [torch.Size([]), torch.Size([3])]
+        >>> normal = Normal(loc, scale)
+        >>> [normal.batch_shape, normal.event_shape]
+        [torch.Size([3]), torch.Size([])]
+        >>> diagn = Independent(normal, 1)
+        >>> [diagn.batch_shape, diagn.event_shape]
+        [torch.Size([]), torch.Size([3])]
+
+    Args:
+        base_distribution (torch.distributions.distribution.Distribution): a
+            base distribution
+        reinterpreted_batch_ndims (int): the number of batch dims to
+            reinterpret as event dims
+    """
+    arg_constraints: Dict[str, constraints.Constraint] = {}
+
+    def __init__(
+        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
+    ):
+        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
+            raise ValueError(
+                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
+                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
+            )
+        shape = base_distribution.batch_shape + base_distribution.event_shape
+        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
+        batch_shape = shape[: len(shape) - event_dim]
+        event_shape = shape[len(shape) - event_dim :]
+        self.base_dist = base_distribution
+        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Independent, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.base_dist = self.base_dist.expand(
+            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
+        )
+        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
+        super(Independent, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def has_rsample(self):
+        return self.base_dist.has_rsample
+
+    @property
+    def has_enumerate_support(self):
+        if self.reinterpreted_batch_ndims > 0:
+            return False
+        return self.base_dist.has_enumerate_support
+
+    @constraints.dependent_property
+    def support(self):
+        result = self.base_dist.support
+        if self.reinterpreted_batch_ndims:
+            result = constraints.independent(result, self.reinterpreted_batch_ndims)
+        return result
+
+    @property
+    def mean(self):
+        return self.base_dist.mean
+
+    @property
+    def mode(self):
+        return self.base_dist.mode
+
+    @property
+    def variance(self):
+        return self.base_dist.variance
+
+    def sample(self, sample_shape=torch.Size()):
+        return self.base_dist.sample(sample_shape)
+
+    def rsample(self, sample_shape=torch.Size()):
+        return self.base_dist.rsample(sample_shape)
+
+    def log_prob(self, value):
+        log_prob = self.base_dist.log_prob(value)
+        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
+
+    def entropy(self):
+        entropy = self.base_dist.entropy()
+        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
+
+    def enumerate_support(self, expand=True):
+        if self.reinterpreted_batch_ndims > 0:
+            raise NotImplementedError(
+                "Enumeration over cartesian product is not implemented"
+            )
+        return self.base_dist.enumerate_support(expand=expand)
+
+    def __repr__(self):
+        return (
+            self.__class__.__name__
+            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
+        )
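+
+# A minimal sketch (illustrative only) of how reinterpreting batch dims changes
+# the shape of log_prob for the wrapper above:
+#
+#     >>> import torch
+#     >>> from torch.distributions import Independent, Normal
+#     >>> base = Normal(torch.zeros(5, 3), torch.ones(5, 3))
+#     >>> base.log_prob(torch.zeros(5, 3)).shape  # one log-density per element
+#     torch.Size([5, 3])
+#     >>> diag = Independent(base, 1)
+#     >>> diag.log_prob(torch.zeros(5, 3)).shape  # summed over the event dim
+#     torch.Size([5])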
diff --git a/MLPY/Lib/site-packages/torch/distributions/inverse_gamma.py b/MLPY/Lib/site-packages/torch/distributions/inverse_gamma.py
new file mode 100644
index 0000000000000000000000000000000000000000..418460df7100a34d4b0063b2aae92397a63fb673
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/inverse_gamma.py
@@ -0,0 +1,80 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.gamma import Gamma
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import PowerTransform
+
+
+__all__ = ["InverseGamma"]
+
+
+class InverseGamma(TransformedDistribution):
+    r"""
+    Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate`
+    where::
+
+        X ~ Gamma(concentration, rate)
+        Y = 1 / X ~ InverseGamma(concentration, rate)
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
+        >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
+        >>> m.sample()
+        tensor([ 1.2953])
+
+    Args:
+        concentration (float or Tensor): shape parameter of the distribution
+            (often referred to as alpha)
+        rate (float or Tensor): rate = 1 / scale of the distribution
+            (often referred to as beta)
+    """
+    arg_constraints = {
+        "concentration": constraints.positive,
+        "rate": constraints.positive,
+    }
+    support = constraints.positive
+    has_rsample = True
+
+    def __init__(self, concentration, rate, validate_args=None):
+        base_dist = Gamma(concentration, rate, validate_args=validate_args)
+        neg_one = -base_dist.rate.new_ones(())
+        super().__init__(
+            base_dist, PowerTransform(neg_one), validate_args=validate_args
+        )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(InverseGamma, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def concentration(self):
+        return self.base_dist.concentration
+
+    @property
+    def rate(self):
+        return self.base_dist.rate
+
+    @property
+    def mean(self):
+        result = self.rate / (self.concentration - 1)
+        return torch.where(self.concentration > 1, result, torch.inf)
+
+    @property
+    def mode(self):
+        return self.rate / (self.concentration + 1)
+
+    @property
+    def variance(self):
+        result = self.rate.square() / (
+            (self.concentration - 1).square() * (self.concentration - 2)
+        )
+        return torch.where(self.concentration > 2, result, torch.inf)
+
+    def entropy(self):
+        return (
+            self.concentration
+            + self.rate.log()
+            + self.concentration.lgamma()
+            - (1 + self.concentration) * self.concentration.digamma()
+        )
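+
+# A minimal sketch (illustrative only) of the guarded moments above: the mean
+# is rate / (concentration - 1) only when concentration > 1, and +inf otherwise.
+#
+#     >>> import torch
+#     >>> from torch.distributions import InverseGamma
+#     >>> d = InverseGamma(torch.tensor([0.5, 3.0]), torch.tensor([1.0, 1.0]))
+#     >>> torch.isinf(d.mean[0]).item(), d.mean[1].item()
+#     (True, 0.5)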
diff --git a/MLPY/Lib/site-packages/torch/distributions/kl.py b/MLPY/Lib/site-packages/torch/distributions/kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fe67ea6cb56b6030a22369af71c2bc3e737620b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/kl.py
@@ -0,0 +1,971 @@
+import math
+import warnings
+from functools import total_ordering
+from typing import Callable, Dict, Tuple, Type
+
+import torch
+from torch import inf
+
+from .bernoulli import Bernoulli
+from .beta import Beta
+from .binomial import Binomial
+from .categorical import Categorical
+from .cauchy import Cauchy
+from .continuous_bernoulli import ContinuousBernoulli
+from .dirichlet import Dirichlet
+from .distribution import Distribution
+from .exp_family import ExponentialFamily
+from .exponential import Exponential
+from .gamma import Gamma
+from .geometric import Geometric
+from .gumbel import Gumbel
+from .half_normal import HalfNormal
+from .independent import Independent
+from .laplace import Laplace
+from .lowrank_multivariate_normal import (
+    _batch_lowrank_logdet,
+    _batch_lowrank_mahalanobis,
+    LowRankMultivariateNormal,
+)
+from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
+from .normal import Normal
+from .one_hot_categorical import OneHotCategorical
+from .pareto import Pareto
+from .poisson import Poisson
+from .transformed_distribution import TransformedDistribution
+from .uniform import Uniform
+from .utils import _sum_rightmost, euler_constant as _euler_gamma
+
+_KL_REGISTRY: Dict[
+    Tuple[Type, Type], Callable
+] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
+_KL_MEMOIZE: Dict[
+    Tuple[Type, Type], Callable
+] = {}  # Memoized version mapping many specific (type, type) pairs to functions.
+
+__all__ = ["register_kl", "kl_divergence"]
+
+
+def register_kl(type_p, type_q):
+    """
+    Decorator to register a pairwise function with :meth:`kl_divergence`.
+    Usage::
+
+        @register_kl(Normal, Normal)
+        def kl_normal_normal(p, q):
+            # insert implementation here
+
+    Lookup returns the most specific (type,type) match ordered by subclass. If
+    the match is ambiguous, a `RuntimeWarning` is raised. For example to
+    resolve the ambiguous situation::
+
+        @register_kl(BaseP, DerivedQ)
+        def kl_version1(p, q): ...
+        @register_kl(DerivedP, BaseQ)
+        def kl_version2(p, q): ...
+
+    you should register a third most-specific implementation, e.g.::
+
+        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
+
+    Args:
+        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
+        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
+    """
+    if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
+        raise TypeError(
+            f"Expected type_p to be a Distribution subclass but got {type_p}"
+        )
+    if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
+        raise TypeError(
+            f"Expected type_q to be a Distribution subclass but got {type_q}"
+        )
+
+    def decorator(fun):
+        _KL_REGISTRY[type_p, type_q] = fun
+        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
+        return fun
+
+    return decorator
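+
+# A minimal registration sketch (illustrative only; MyP and MyQ are
+# hypothetical subclasses, not library names). The decorator above stores the
+# function in _KL_REGISTRY and kl_divergence dispatches to it:
+#
+#     >>> import torch
+#     >>> from torch.distributions import Normal, kl_divergence, register_kl
+#     >>> class MyP(Normal): pass
+#     >>> class MyQ(Normal): pass
+#     >>> @register_kl(MyP, MyQ)
+#     ... def _kl_myp_myq(p, q):
+#     ...     return torch.zeros(p.batch_shape)
+#     >>> kl_divergence(MyP(0.0, 1.0), MyQ(0.0, 1.0))
+#     tensor(0.)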
+
+
+@total_ordering
+class _Match:
+    __slots__ = ["types"]
+
+    def __init__(self, *types):
+        self.types = types
+
+    def __eq__(self, other):
+        return self.types == other.types
+
+    def __le__(self, other):
+        for x, y in zip(self.types, other.types):
+            if not issubclass(x, y):
+                return False
+            if x is not y:
+                break
+        return True
+
+
+def _dispatch_kl(type_p, type_q):
+    """
+    Find the most specific approximate match, assuming single inheritance.
+    """
+    matches = [
+        (super_p, super_q)
+        for super_p, super_q in _KL_REGISTRY
+        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
+    ]
+    if not matches:
+        return NotImplemented
+    # Check that the left- and right- lexicographic orders agree.
+    # mypy isn't smart enough to know that _Match implements __lt__
+    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
+    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
+    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
+    left_fun = _KL_REGISTRY[left_p, left_q]
+    right_fun = _KL_REGISTRY[right_p, right_q]
+    if left_fun is not right_fun:
+        warnings.warn(
+            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
+                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
+            ),
+            RuntimeWarning,
+        )
+    return left_fun
+
+
+def _infinite_like(tensor):
+    """
+    Helper function for obtaining infinite KL Divergence throughout
+    """
+    return torch.full_like(tensor, inf)
+
+
+def _x_log_x(tensor):
+    """
+    Utility function for calculating x log x
+    """
+    return tensor * tensor.log()
+
+
+def _batch_trace_XXT(bmat):
+    """
+    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
+    """
+    n = bmat.size(-1)
+    m = bmat.size(-2)
+    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
+    return flat_trace.reshape(bmat.shape[:-2])
+
+
+def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
+    r"""
+    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
+
+    .. math::
+
+        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx
+
+    Args:
+        p (Distribution): A :class:`~torch.distributions.Distribution` object.
+        q (Distribution): A :class:`~torch.distributions.Distribution` object.
+
+    Returns:
+        Tensor: A batch of KL divergences of shape `batch_shape`.
+
+    Raises:
+        NotImplementedError: If the distribution types have not been registered via
+            :meth:`register_kl`.
+    """
+    try:
+        fun = _KL_MEMOIZE[type(p), type(q)]
+    except KeyError:
+        fun = _dispatch_kl(type(p), type(q))
+        _KL_MEMOIZE[type(p), type(q)] = fun
+    if fun is NotImplemented:
+        raise NotImplementedError(
+            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
+        )
+    return fun(p, q)
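+
+# A minimal usage sketch (illustrative only) of the dispatch above, relying on
+# the Normal/Normal registration defined later in this file:
+#
+#     >>> import torch
+#     >>> from torch.distributions import Normal, kl_divergence
+#     >>> p = Normal(torch.tensor(0.0), torch.tensor(1.0))
+#     >>> q = Normal(torch.tensor(1.0), torch.tensor(2.0))
+#     >>> kl_divergence(p, q)  # 0.5 * (var_ratio + t1 - 1 - log(var_ratio))
+#     tensor(0.4431)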
+
+
+################################################################################
+# KL Divergence Implementations
+################################################################################
+
+# Same distributions
+
+
+@register_kl(Bernoulli, Bernoulli)
+def _kl_bernoulli_bernoulli(p, q):
+    t1 = p.probs * (
+        torch.nn.functional.softplus(-q.logits)
+        - torch.nn.functional.softplus(-p.logits)
+    )
+    t1[q.probs == 0] = inf
+    t1[p.probs == 0] = 0
+    t2 = (1 - p.probs) * (
+        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
+    )
+    t2[q.probs == 1] = inf
+    t2[p.probs == 1] = 0
+    return t1 + t2
+
+
+@register_kl(Beta, Beta)
+def _kl_beta_beta(p, q):
+    sum_params_p = p.concentration1 + p.concentration0
+    sum_params_q = q.concentration1 + q.concentration0
+    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
+    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
+    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
+    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
+    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
+    return t1 - t2 + t3 + t4 + t5
+
+
+@register_kl(Binomial, Binomial)
+def _kl_binomial_binomial(p, q):
+    # from https://math.stackexchange.com/questions/2214993/
+    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
+    if (p.total_count < q.total_count).any():
+        raise NotImplementedError(
+            "KL between Binomials where q.total_count > p.total_count is not implemented"
+        )
+    kl = p.total_count * (
+        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
+    )
+    inf_idxs = p.total_count > q.total_count
+    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
+    return kl
+
+
+@register_kl(Categorical, Categorical)
+def _kl_categorical_categorical(p, q):
+    t = p.probs * (p.logits - q.logits)
+    t[(q.probs == 0).expand_as(t)] = inf
+    t[(p.probs == 0).expand_as(t)] = 0
+    return t.sum(-1)
+
+
+@register_kl(ContinuousBernoulli, ContinuousBernoulli)
+def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
+    t1 = p.mean * (p.logits - q.logits)
+    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
+    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
+    return t1 + t2 + t3
+
+
+@register_kl(Dirichlet, Dirichlet)
+def _kl_dirichlet_dirichlet(p, q):
+    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
+    sum_p_concentration = p.concentration.sum(-1)
+    sum_q_concentration = q.concentration.sum(-1)
+    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
+    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
+    t3 = p.concentration - q.concentration
+    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
+    return t1 - t2 + (t3 * t4).sum(-1)
+
+
+@register_kl(Exponential, Exponential)
+def _kl_exponential_exponential(p, q):
+    rate_ratio = q.rate / p.rate
+    t1 = -rate_ratio.log()
+    return t1 + rate_ratio - 1
+
+
+@register_kl(ExponentialFamily, ExponentialFamily)
+def _kl_expfamily_expfamily(p, q):
+    if type(p) != type(q):
+        raise NotImplementedError(
+            "The cross KL-divergence between different exponential families "
+            "cannot be computed using Bregman divergences"
+        )
+    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
+    q_nparams = q._natural_params
+    lg_normal = p._log_normalizer(*p_nparams)
+    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
+    result = q._log_normalizer(*q_nparams) - lg_normal
+    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
+        term = (qnp - pnp) * g
+        result -= _sum_rightmost(term, len(q.event_shape))
+    return result
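+
+# For two members of the same exponential family with natural parameters
+# eta_p, eta_q and log-normalizer A(.), the divergence computed above is the
+# Bregman divergence of A:
+#     KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>,
+# with grad A(eta_p) obtained via autograd on the log-normalizer.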
+
+
+@register_kl(Gamma, Gamma)
+def _kl_gamma_gamma(p, q):
+    t1 = q.concentration * (p.rate / q.rate).log()
+    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
+    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
+    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
+    return t1 + t2 + t3 + t4
+
+
+@register_kl(Gumbel, Gumbel)
+def _kl_gumbel_gumbel(p, q):
+    ct1 = p.scale / q.scale
+    ct2 = q.loc / q.scale
+    ct3 = p.loc / q.scale
+    t1 = -ct1.log() - ct2 + ct3
+    t2 = ct1 * _euler_gamma
+    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
+    return t1 + t2 + t3 - (1 + _euler_gamma)
+
+
+@register_kl(Geometric, Geometric)
+def _kl_geometric_geometric(p, q):
+    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
+
+
+@register_kl(HalfNormal, HalfNormal)
+def _kl_halfnormal_halfnormal(p, q):
+    return _kl_normal_normal(p.base_dist, q.base_dist)
+
+
+@register_kl(Laplace, Laplace)
+def _kl_laplace_laplace(p, q):
+    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
+    scale_ratio = p.scale / q.scale
+    loc_abs_diff = (p.loc - q.loc).abs()
+    t1 = -scale_ratio.log()
+    t2 = loc_abs_diff / q.scale
+    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
+    return t1 + t2 + t3 - 1
+
+
+@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
+def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "KL-divergence between two Low Rank Multivariate Normals with\
+                          different event shapes cannot be computed"
+        )
+
+    term1 = _batch_lowrank_logdet(
+        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
+    ) - _batch_lowrank_logdet(
+        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
+    )
+    term3 = _batch_lowrank_mahalanobis(
+        q._unbroadcasted_cov_factor,
+        q._unbroadcasted_cov_diag,
+        q.loc - p.loc,
+        q._capacitance_tril,
+    )
+    # Expands term2 according to
+    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
+    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
+    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
+    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
+    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
+    term22 = _batch_trace_XXT(
+        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
+    )
+    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
+    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
+    term2 = term21 + term22 - term23 - term24
+    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(MultivariateNormal, LowRankMultivariateNormal)
+def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "KL-divergence between two (Low Rank) Multivariate Normals with\
+                          different event shapes cannot be computed"
+        )
+
+    term1 = _batch_lowrank_logdet(
+        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
+    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+    term3 = _batch_lowrank_mahalanobis(
+        q._unbroadcasted_cov_factor,
+        q._unbroadcasted_cov_diag,
+        q.loc - p.loc,
+        q._capacitance_tril,
+    )
+    # Expands term2 according to
+    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
+    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
+    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
+    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
+    term21 = _batch_trace_XXT(
+        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
+    )
+    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
+    term2 = term21 - term22
+    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(LowRankMultivariateNormal, MultivariateNormal)
+def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "KL-divergence between two (Low Rank) Multivariate Normals with\
+                          different event shapes cannot be computed"
+        )
+
+    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
+        -1
+    ) - _batch_lowrank_logdet(
+        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
+    )
+    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
+    # Expands term2 according to
+    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
+    combined_batch_shape = torch._C._infer_size(
+        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
+    )
+    n = p.event_shape[0]
+    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+    p_cov_factor = p._unbroadcasted_cov_factor.expand(
+        combined_batch_shape + (n, p.cov_factor.size(-1))
+    )
+    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
+        combined_batch_shape + (n, n)
+    )
+    term21 = _batch_trace_XXT(
+        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
+    )
+    term22 = _batch_trace_XXT(
+        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
+    )
+    term2 = term21 + term22
+    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(MultivariateNormal, MultivariateNormal)
+def _kl_multivariatenormal_multivariatenormal(p, q):
+    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "KL-divergence between two Multivariate Normals with\
+                          different event shapes cannot be computed"
+        )
+
+    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
+        -1
+    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+    combined_batch_shape = torch._C._infer_size(
+        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
+    )
+    n = p.event_shape[0]
+    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+    term2 = _batch_trace_XXT(
+        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
+    )
+    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
+    return half_term1 + 0.5 * (term2 + term3 - n)
+
+
+@register_kl(Normal, Normal)
+def _kl_normal_normal(p, q):
+    var_ratio = (p.scale / q.scale).pow(2)
+    t1 = ((p.loc - q.loc) / q.scale).pow(2)
+    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
+
+
+@register_kl(OneHotCategorical, OneHotCategorical)
+def _kl_onehotcategorical_onehotcategorical(p, q):
+    return _kl_categorical_categorical(p._categorical, q._categorical)
+
+
+@register_kl(Pareto, Pareto)
+def _kl_pareto_pareto(p, q):
+    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
+    scale_ratio = p.scale / q.scale
+    alpha_ratio = q.alpha / p.alpha
+    t1 = q.alpha * scale_ratio.log()
+    t2 = -alpha_ratio.log()
+    result = t1 + t2 + alpha_ratio - 1
+    result[p.support.lower_bound < q.support.lower_bound] = inf
+    return result
+
+
+@register_kl(Poisson, Poisson)
+def _kl_poisson_poisson(p, q):
+    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)
+
+
+@register_kl(TransformedDistribution, TransformedDistribution)
+def _kl_transformed_transformed(p, q):
+    if p.transforms != q.transforms:
+        raise NotImplementedError
+    if p.event_shape != q.event_shape:
+        raise NotImplementedError
+    return kl_divergence(p.base_dist, q.base_dist)
+
+
+@register_kl(Uniform, Uniform)
+def _kl_uniform_uniform(p, q):
+    result = ((q.high - q.low) / (p.high - p.low)).log()
+    result[(q.low > p.low) | (q.high < p.high)] = inf
+    return result
+
+
+# Different distributions
+@register_kl(Bernoulli, Poisson)
+def _kl_bernoulli_poisson(p, q):
+    return -p.entropy() - (p.probs * q.rate.log() - q.rate)
+
+
+@register_kl(Beta, ContinuousBernoulli)
+def _kl_beta_continuous_bernoulli(p, q):
+    return (
+        -p.entropy()
+        - p.mean * q.logits
+        - torch.log1p(-q.probs)
+        - q._cont_bern_log_norm()
+    )
+
+
+@register_kl(Beta, Pareto)
+def _kl_beta_infinity(p, q):
+    return _infinite_like(p.concentration1)
+
+
+@register_kl(Beta, Exponential)
+def _kl_beta_exponential(p, q):
+    return (
+        -p.entropy()
+        - q.rate.log()
+        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
+    )
+
+
+@register_kl(Beta, Gamma)
+def _kl_beta_gamma(p, q):
+    t1 = -p.entropy()
+    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+    t3 = (q.concentration - 1) * (
+        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
+    )
+    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
+    return t1 + t2 - t3 + t4
+
+
+# TODO: Add Beta-Laplace KL Divergence
+
+
+@register_kl(Beta, Normal)
+def _kl_beta_normal(p, q):
+    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
+    var_normal = q.scale.pow(2)
+    t1 = -p.entropy()
+    t2 = 0.5 * (var_normal * 2 * math.pi).log()
+    t3 = (
+        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
+        + E_beta.pow(2)
+    ) * 0.5
+    t4 = q.loc * E_beta
+    t5 = q.loc.pow(2) * 0.5
+    return t1 + t2 + (t3 - t4 + t5) / var_normal
+
+
+@register_kl(Beta, Uniform)
+def _kl_beta_uniform(p, q):
+    result = -p.entropy() + (q.high - q.low).log()
+    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
+    return result
+
+
+# Note that the KL between a ContinuousBernoulli and Beta has no closed form
+
+
+@register_kl(ContinuousBernoulli, Pareto)
+def _kl_continuous_bernoulli_infinity(p, q):
+    return _infinite_like(p.probs)
+
+
+@register_kl(ContinuousBernoulli, Exponential)
+def _kl_continuous_bernoulli_exponential(p, q):
+    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean
+
+
+# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
+# TODO: Add ContinuousBernoulli-Laplace KL Divergence
+
+
+@register_kl(ContinuousBernoulli, Normal)
+def _kl_continuous_bernoulli_normal(p, q):
+    t1 = -p.entropy()
+    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
+        q.scale
+    )
+    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
+        2.0 * torch.square(q.scale)
+    )
+    return t1 + t2 + t3
+
+
+@register_kl(ContinuousBernoulli, Uniform)
+def _kl_continuous_bernoulli_uniform(p, q):
+    result = -p.entropy() + (q.high - q.low).log()
+    return torch.where(
+        torch.max(
+            torch.ge(q.low, p.support.lower_bound),
+            torch.le(q.high, p.support.upper_bound),
+        ),
+        torch.ones_like(result) * inf,
+        result,
+    )
+
+
+@register_kl(Exponential, Beta)
+@register_kl(Exponential, ContinuousBernoulli)
+@register_kl(Exponential, Pareto)
+@register_kl(Exponential, Uniform)
+def _kl_exponential_infinity(p, q):
+    return _infinite_like(p.rate)
+
+
+@register_kl(Exponential, Gamma)
+def _kl_exponential_gamma(p, q):
+    ratio = q.rate / p.rate
+    t1 = -q.concentration * torch.log(ratio)
+    return (
+        t1
+        + ratio
+        + q.concentration.lgamma()
+        + q.concentration * _euler_gamma
+        - (1 + _euler_gamma)
+    )
+
+
+@register_kl(Exponential, Gumbel)
+def _kl_exponential_gumbel(p, q):
+    scale_rate_prod = p.rate * q.scale
+    loc_scale_ratio = q.loc / q.scale
+    t1 = scale_rate_prod.log() - 1
+    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
+    t3 = scale_rate_prod.reciprocal()
+    return t1 - loc_scale_ratio + t2 + t3
+
+
+# TODO: Add Exponential-Laplace KL Divergence
+
+
+@register_kl(Exponential, Normal)
+def _kl_exponential_normal(p, q):
+    var_normal = q.scale.pow(2)
+    rate_sqr = p.rate.pow(2)
+    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
+    t2 = rate_sqr.reciprocal()
+    t3 = q.loc / p.rate
+    t4 = q.loc.pow(2) * 0.5
+    return t1 - 1 + (t2 - t3 + t4) / var_normal
+
+
+@register_kl(Gamma, Beta)
+@register_kl(Gamma, ContinuousBernoulli)
+@register_kl(Gamma, Pareto)
+@register_kl(Gamma, Uniform)
+def _kl_gamma_infinity(p, q):
+    return _infinite_like(p.concentration)
+
+
+@register_kl(Gamma, Exponential)
+def _kl_gamma_exponential(p, q):
+    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate
+
+
+@register_kl(Gamma, Gumbel)
+def _kl_gamma_gumbel(p, q):
+    beta_scale_prod = p.rate * q.scale
+    loc_scale_ratio = q.loc / q.scale
+    t1 = (
+        (p.concentration - 1) * p.concentration.digamma()
+        - p.concentration.lgamma()
+        - p.concentration
+    )
+    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
+    t3 = (
+        torch.exp(loc_scale_ratio)
+        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
+        - loc_scale_ratio
+    )
+    return t1 + t2 + t3
+
+
+# TODO: Add Gamma-Laplace KL Divergence
+
+
+@register_kl(Gamma, Normal)
+def _kl_gamma_normal(p, q):
+    var_normal = q.scale.pow(2)
+    beta_sqr = p.rate.pow(2)
+    t1 = (
+        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
+        - p.concentration
+        - p.concentration.lgamma()
+    )
+    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
+    t3 = q.loc * p.concentration / p.rate
+    t4 = 0.5 * q.loc.pow(2)
+    return (
+        t1
+        + (p.concentration - 1) * p.concentration.digamma()
+        + (t2 - t3 + t4) / var_normal
+    )
+
+
+@register_kl(Gumbel, Beta)
+@register_kl(Gumbel, ContinuousBernoulli)
+@register_kl(Gumbel, Exponential)
+@register_kl(Gumbel, Gamma)
+@register_kl(Gumbel, Pareto)
+@register_kl(Gumbel, Uniform)
+def _kl_gumbel_infinity(p, q):
+    return _infinite_like(p.loc)
+
+
+# TODO: Add Gumbel-Laplace KL Divergence
+
+
+@register_kl(Gumbel, Normal)
+def _kl_gumbel_normal(p, q):
+    param_ratio = p.scale / q.scale
+    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
+    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
+    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
+    return -t1 + t2 + t3 - (_euler_gamma + 1)
+
+
+@register_kl(Laplace, Beta)
+@register_kl(Laplace, ContinuousBernoulli)
+@register_kl(Laplace, Exponential)
+@register_kl(Laplace, Gamma)
+@register_kl(Laplace, Pareto)
+@register_kl(Laplace, Uniform)
+def _kl_laplace_infinity(p, q):
+    return _infinite_like(p.loc)
+
+
+@register_kl(Laplace, Normal)
+def _kl_laplace_normal(p, q):
+    var_normal = q.scale.pow(2)
+    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
+    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
+    t2 = 0.5 * p.loc.pow(2)
+    t3 = p.loc * q.loc
+    t4 = 0.5 * q.loc.pow(2)
+    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1
+
+
+@register_kl(Normal, Beta)
+@register_kl(Normal, ContinuousBernoulli)
+@register_kl(Normal, Exponential)
+@register_kl(Normal, Gamma)
+@register_kl(Normal, Pareto)
+@register_kl(Normal, Uniform)
+def _kl_normal_infinity(p, q):
+    return _infinite_like(p.loc)
+
+
+@register_kl(Normal, Gumbel)
+def _kl_normal_gumbel(p, q):
+    mean_scale_ratio = p.loc / q.scale
+    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
+    loc_scale_ratio = q.loc / q.scale
+    t1 = var_scale_sqr_ratio.log() * 0.5
+    t2 = mean_scale_ratio - loc_scale_ratio
+    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
+    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))
+
+
+@register_kl(Normal, Laplace)
+def _kl_normal_laplace(p, q):
+    loc_diff = p.loc - q.loc
+    scale_ratio = p.scale / q.scale
+    loc_diff_scale_ratio = loc_diff / p.scale
+    t1 = torch.log(scale_ratio)
+    t2 = (
+        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
+    )
+    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
+    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))
+
+
+@register_kl(Pareto, Beta)
+@register_kl(Pareto, ContinuousBernoulli)
+@register_kl(Pareto, Uniform)
+def _kl_pareto_infinity(p, q):
+    return _infinite_like(p.scale)
+
+
+@register_kl(Pareto, Exponential)
+def _kl_pareto_exponential(p, q):
+    scale_rate_prod = p.scale * q.rate
+    t1 = (p.alpha / scale_rate_prod).log()
+    t2 = p.alpha.reciprocal()
+    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
+    result = t1 - t2 + t3 - 1
+    result[p.alpha <= 1] = inf
+    return result
+
+
+@register_kl(Pareto, Gamma)
+def _kl_pareto_gamma(p, q):
+    common_term = p.scale.log() + p.alpha.reciprocal()
+    t1 = p.alpha.log() - common_term
+    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+    t3 = (1 - q.concentration) * common_term
+    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
+    result = t1 + t2 + t3 + t4 - 1
+    result[p.alpha <= 1] = inf
+    return result
+
+
+# TODO: Add Pareto-Laplace KL Divergence
+
+
+@register_kl(Pareto, Normal)
+def _kl_pareto_normal(p, q):
+    var_normal = 2 * q.scale.pow(2)
+    common_term = p.scale / (p.alpha - 1)
+    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
+    t2 = p.alpha.reciprocal()
+    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
+    t4 = (p.alpha * common_term - q.loc).pow(2)
+    result = t1 - t2 + (t3 + t4) / var_normal - 1
+    result[p.alpha <= 2] = inf
+    return result
+
+
+@register_kl(Poisson, Bernoulli)
+@register_kl(Poisson, Binomial)
+def _kl_poisson_infinity(p, q):
+    return _infinite_like(p.rate)
+
+
+@register_kl(Uniform, Beta)
+def _kl_uniform_beta(p, q):
+    common_term = p.high - p.low
+    t1 = torch.log(common_term)
+    t2 = (
+        (q.concentration1 - 1)
+        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
+        / common_term
+    )
+    t3 = (
+        (q.concentration0 - 1)
+        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
+        / common_term
+    )
+    t4 = (
+        q.concentration1.lgamma()
+        + q.concentration0.lgamma()
+        - (q.concentration1 + q.concentration0).lgamma()
+    )
+    result = t3 + t4 - t1 - t2
+    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
+    return result
+
+
+@register_kl(Uniform, ContinuousBernoulli)
+def _kl_uniform_continuous_bernoulli(p, q):
+    result = (
+        -p.entropy()
+        - p.mean * q.logits
+        - torch.log1p(-q.probs)
+        - q._cont_bern_log_norm()
+    )
+    return torch.where(
+        torch.max(
+            torch.ge(p.high, q.support.upper_bound),
+            torch.le(p.low, q.support.lower_bound),
+        ),
+        torch.ones_like(result) * inf,
+        result,
+    )
+
+
+@register_kl(Uniform, Exponential)
+def _kl_uniform_exponential(p, q):
+    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
+    result[p.low < q.support.lower_bound] = inf
+    return result
+
+
+@register_kl(Uniform, Gamma)
+def _kl_uniform_gamma(p, q):
+    common_term = p.high - p.low
+    t1 = common_term.log()
+    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+    t3 = (
+        (1 - q.concentration)
+        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
+        / common_term
+    )
+    t4 = q.rate * (p.high + p.low) / 2
+    result = -t1 + t2 + t3 + t4
+    result[p.low < q.support.lower_bound] = inf
+    return result
+
+
+@register_kl(Uniform, Gumbel)
+def _kl_uniform_gumbel(p, q):
+    common_term = q.scale / (p.high - p.low)
+    high_loc_diff = (p.high - q.loc) / q.scale
+    low_loc_diff = (p.low - q.loc) / q.scale
+    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
+    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
+    return t1 - t2
+
+
+# TODO: Uniform-Laplace KL Divergence
+
+
+@register_kl(Uniform, Normal)
+def _kl_uniform_normal(p, q):
+    common_term = p.high - p.low
+    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
+    t2 = (common_term).pow(2) / 12
+    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
+    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)
+
+
+@register_kl(Uniform, Pareto)
+def _kl_uniform_pareto(p, q):
+    support_uniform = p.high - p.low
+    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
+    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
+    result = t2 * (q.alpha + 1) - t1
+    result[p.low < q.support.lower_bound] = inf
+    return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+        raise NotImplementedError
+    result = kl_divergence(p.base_dist, q.base_dist)
+    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
+
+
+@register_kl(Cauchy, Cauchy)
+def _kl_cauchy_cauchy(p, q):
+    # From https://arxiv.org/abs/1905.10965
+    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
+    t2 = (4 * p.scale * q.scale).log()
+    return t1 - t2
+
+
+def _add_kl_info():
+    """Appends a list of implemented KL functions to the doc for kl_divergence."""
+    rows = [
+        "KL divergence is currently implemented for the following distribution pairs:"
+    ]
+    for p, q in sorted(
+        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
+    ):
+        rows.append(
+            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
+        )
+    kl_info = "\n\t".join(rows)
+    if kl_divergence.__doc__:
+        kl_divergence.__doc__ += kl_info  # type: ignore[operator]
diff --git a/MLPY/Lib/site-packages/torch/distributions/kumaraswamy.py b/MLPY/Lib/site-packages/torch/distributions/kumaraswamy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a07e72b67c9a442ee89a8007b7843802674038a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/kumaraswamy.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, PowerTransform
+from torch.distributions.uniform import Uniform
+from torch.distributions.utils import broadcast_all, euler_constant
+
+__all__ = ["Kumaraswamy"]
+
+
+def _moments(a, b, n):
+    """
+    Computes the n-th moment of the Kumaraswamy distribution using torch.lgamma
+    """
+    arg1 = 1 + n / a
+    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
+    return b * torch.exp(log_value)
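+
+# The helper above evaluates the closed-form moment
+#     E[X^n] = b * B(1 + n / a, b)
+#            = b * Gamma(1 + n / a) * Gamma(b) / Gamma(1 + n / a + b)
+# in log-space for numerical stability. A quick check (illustrative only):
+# with a = b = 1 the distribution reduces to Uniform(0, 1), whose first
+# moment is 0.5.
+#
+#     >>> import torch
+#     >>> _moments(torch.tensor(1.0), torch.tensor(1.0), 1)
+#     tensor(0.5000)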
+
+
+class Kumaraswamy(TransformedDistribution):
+    r"""
+    Samples from a Kumaraswamy distribution.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
+        tensor([ 0.1729])
+
+    Args:
+        concentration1 (float or Tensor): 1st concentration parameter of the distribution
+            (often referred to as alpha)
+        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
+            (often referred to as beta)
+    """
+    arg_constraints = {
+        "concentration1": constraints.positive,
+        "concentration0": constraints.positive,
+    }
+    support = constraints.unit_interval
+    has_rsample = True
+
+    def __init__(self, concentration1, concentration0, validate_args=None):
+        self.concentration1, self.concentration0 = broadcast_all(
+            concentration1, concentration0
+        )
+        finfo = torch.finfo(self.concentration0.dtype)
+        base_dist = Uniform(
+            torch.full_like(self.concentration0, 0),
+            torch.full_like(self.concentration0, 1),
+            validate_args=validate_args,
+        )
+        transforms = [
+            PowerTransform(exponent=self.concentration0.reciprocal()),
+            AffineTransform(loc=1.0, scale=-1.0),
+            PowerTransform(exponent=self.concentration1.reciprocal()),
+        ]
+        super().__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Kumaraswamy, _instance)
+        new.concentration1 = self.concentration1.expand(batch_shape)
+        new.concentration0 = self.concentration0.expand(batch_shape)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def mean(self):
+        return _moments(self.concentration1, self.concentration0, 1)
+
+    @property
+    def mode(self):
+        # Evaluate in log-space for numerical stability.
+        log_mode = (
+            self.concentration0.reciprocal() * (-self.concentration0).log1p()
+            - (-self.concentration0 * self.concentration1).log1p()
+        )
+        log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan
+        return log_mode.exp()
+
+    @property
+    def variance(self):
+        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(
+            self.mean, 2
+        )
+
+    def entropy(self):
+        t1 = 1 - self.concentration1.reciprocal()
+        t0 = 1 - self.concentration0.reciprocal()
+        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
+        return (
+            t0
+            + t1 * H0
+            - torch.log(self.concentration1)
+            - torch.log(self.concentration0)
+        )
diff --git a/MLPY/Lib/site-packages/torch/distributions/laplace.py b/MLPY/Lib/site-packages/torch/distributions/laplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..64bf2e3937cd2ea705aee512e12f4ce70190a649
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/laplace.py
@@ -0,0 +1,94 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Laplace"]
+
+
+class Laplace(Distribution):
+    r"""
+    Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
+        >>> m.sample()  # Laplace distributed with loc=0, scale=1
+        tensor([ 0.1046])
+
+    Args:
+        loc (float or Tensor): mean of the distribution
+        scale (float or Tensor): scale of the distribution
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+    has_rsample = True
+
+    @property
+    def mean(self):
+        return self.loc
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def variance(self):
+        return 2 * self.scale.pow(2)
+
+    @property
+    def stddev(self):
+        return (2**0.5) * self.scale
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        if isinstance(loc, Number) and isinstance(scale, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.loc.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Laplace, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        super(Laplace, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        finfo = torch.finfo(self.loc.dtype)
+        if torch._C._get_tracing_state():
+            # [JIT WORKAROUND] lack of support for .uniform_()
+            u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
+            return self.loc - self.scale * u.sign() * torch.log1p(
+                -u.abs().clamp(min=finfo.tiny)
+            )
+        u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
+        # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
+        # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
+        return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(
+            -(value - self.loc).abs() / self.scale
+        )
+
+    def icdf(self, value):
+        term = value - 0.5
+        return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs())
+
+    def entropy(self):
+        return 1 + torch.log(2 * self.scale)
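+
+# A minimal sketch (illustrative only) of the closed forms above for
+# Laplace(loc=0, scale=1): variance = 2 * scale**2 and entropy = 1 + log(2 * scale).
+#
+#     >>> import torch
+#     >>> from torch.distributions import Laplace
+#     >>> d = Laplace(torch.tensor(0.0), torch.tensor(1.0))
+#     >>> d.variance.item(), round(d.entropy().item(), 4)
+#     (2.0, 1.6931)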
diff --git a/MLPY/Lib/site-packages/torch/distributions/lkj_cholesky.py b/MLPY/Lib/site-packages/torch/distributions/lkj_cholesky.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c35ec50cf8f50078decd87a9cb8879703ffce1a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/lkj_cholesky.py
@@ -0,0 +1,142 @@
+"""
+This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).
+
+Original copyright notice:
+
+# Copyright: Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+"""
+
+import math
+
+import torch
+from torch.distributions import Beta, constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["LKJCholesky"]
+
+
+class LKJCholesky(Distribution):
+    r"""
+    LKJ distribution for lower Cholesky factor of correlation matrices.
+    The distribution is controlled by ``concentration`` parameter :math:`\eta`
+    to make the probability of the correlation matrix :math:`M` generated from
+    a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
+    when ``concentration == 1``, we have a uniform distribution over Cholesky
+    factors of correlation matrices::
+
+        L ~ LKJCholesky(dim, concentration)
+        X = L @ L' ~ LKJCorr(dim, concentration)
+
+    Note that this distribution samples the
+    Cholesky factor of correlation matrices and not the correlation matrices
+    themselves and thereby differs slightly from the derivations in [1] for
+    the `LKJCorr` distribution. For sampling, this uses the Onion method from
+    [1] Section 3.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> l = LKJCholesky(3, 0.5)
+        >>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
+        tensor([[ 1.0000,  0.0000,  0.0000],
+                [ 0.3516,  0.9361,  0.0000],
+                [-0.1899,  0.4748,  0.8593]])
+
+    Args:
+        dim (int): dimension of the matrices
+        concentration (float or Tensor): concentration/shape parameter of the
+            distribution (often referred to as eta)
+
+    **References**
+
+    [1] `Generating random correlation matrices based on vines and extended onion method` (2009),
+    Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
+    Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
+    """
+    arg_constraints = {"concentration": constraints.positive}
+    support = constraints.corr_cholesky
+
+    def __init__(self, dim, concentration=1.0, validate_args=None):
+        if dim < 2:
+            raise ValueError(
+                f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
+            )
+        self.dim = dim
+        (self.concentration,) = broadcast_all(concentration)
+        batch_shape = self.concentration.size()
+        event_shape = torch.Size((dim, dim))
+        # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
+        marginal_conc = self.concentration + 0.5 * (self.dim - 2)
+        offset = torch.arange(
+            self.dim - 1,
+            dtype=self.concentration.dtype,
+            device=self.concentration.device,
+        )
+        offset = torch.cat([offset.new_zeros((1,)), offset])
+        beta_conc1 = offset + 0.5
+        beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
+        self._beta = Beta(beta_conc1, beta_conc0)
+        super().__init__(batch_shape, event_shape, validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(LKJCholesky, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.dim = self.dim
+        new.concentration = self.concentration.expand(batch_shape)
+        new._beta = self._beta.expand(batch_shape + (self.dim,))
+        super(LKJCholesky, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def sample(self, sample_shape=torch.Size()):
+        # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
+        # - This vectorizes the for loop and also works for heterogeneous eta.
+        # - Same algorithm generalizes to n=1.
+        # - The procedure is simplified since we are sampling the cholesky factor of
+        #   the correlation matrix instead of the correlation matrix itself. As such,
+        #   we only need to generate `w`.
+        y = self._beta.sample(sample_shape).unsqueeze(-1)
+        u_normal = torch.randn(
+            self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
+        ).tril(-1)
+        u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
+        # Replace NaNs in first row
+        u_hypersphere[..., 0, :].fill_(0.0)
+        w = torch.sqrt(y) * u_hypersphere
+        # Fill diagonal elements; clamp for numerical stability
+        eps = torch.finfo(w.dtype).tiny
+        diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
+        w += torch.diag_embed(diag_elems)
+        return w
+
+    def log_prob(self, value):
+        # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
+        # The probability of a correlation matrix is proportional to
+        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
+        # Additionally, the Jacobian of the transformation from Cholesky factor to
+        # correlation matrix is:
+        #   prod(L_ii ^ (D - i))
+        # So the probability of a Cholesky factor is proportional to
+        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
+        # with order_i = 2 * concentration - 2 + D - i
+        if self._validate_args:
+            self._validate_sample(value)
+        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
+        order = torch.arange(2, self.dim + 1, device=self.concentration.device)
+        order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
+        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
+        # Compute normalization constant (page 1999 of [1])
+        dm1 = self.dim - 1
+        alpha = self.concentration + 0.5 * dm1
+        denominator = torch.lgamma(alpha) * dm1
+        numerator = torch.mvlgamma(alpha - 0.5, dm1)
+        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
+        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
+        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
+        pi_constant = 0.5 * dm1 * math.log(math.pi)
+        normalize_term = pi_constant + numerator - denominator
+        return unnormalized_log_pdf - normalize_term
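+
+# A minimal usage sketch (illustrative only): a draw from this distribution is
+# a lower-triangular Cholesky factor L whose rows have unit norm, so L @ L.mT
+# is a correlation matrix with a unit diagonal.
+#
+#     >>> import torch
+#     >>> from torch.distributions import LKJCholesky
+#     >>> L = LKJCholesky(3, concentration=2.0).sample()
+#     >>> corr = L @ L.mT
+#     >>> torch.allclose(corr.diagonal(dim1=-2, dim2=-1), torch.ones(3))
+#     True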
diff --git a/MLPY/Lib/site-packages/torch/distributions/log_normal.py b/MLPY/Lib/site-packages/torch/distributions/log_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..536c5c307fdcd75d58608b11203ff00a40e12923
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/log_normal.py
@@ -0,0 +1,62 @@
+from torch.distributions import constraints
+from torch.distributions.normal import Normal
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import ExpTransform
+
+__all__ = ["LogNormal"]
+
+
+class LogNormal(TransformedDistribution):
+    r"""
+    Creates a log-normal distribution parameterized by
+    :attr:`loc` and :attr:`scale` where::
+
+        X ~ Normal(loc, scale)
+        Y = exp(X) ~ LogNormal(loc, scale)
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
+        >>> m.sample()  # log-normal distributed with mean=0 and stddev=1
+        tensor([ 0.1046])
+
+    Args:
+        loc (float or Tensor): mean of log of distribution
+        scale (float or Tensor): standard deviation of log of the distribution
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.positive
+    has_rsample = True
+
+    def __init__(self, loc, scale, validate_args=None):
+        base_dist = Normal(loc, scale, validate_args=validate_args)
+        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(LogNormal, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def loc(self):
+        return self.base_dist.loc
+
+    @property
+    def scale(self):
+        return self.base_dist.scale
+
+    @property
+    def mean(self):
+        return (self.loc + self.scale.pow(2) / 2).exp()
+
+    @property
+    def mode(self):
+        return (self.loc - self.scale.square()).exp()
+
+    @property
+    def variance(self):
+        scale_sq = self.scale.pow(2)
+        return scale_sq.expm1() * (2 * self.loc + scale_sq).exp()
+
+    def entropy(self):
+        return self.base_dist.entropy() + self.loc
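# Usage sketch (illustrative, assuming the vendored LogNormal added above): the closed-form
# mean exp(loc + scale**2 / 2) should match a Monte Carlo estimate from samples.
import torch
from torch.distributions import LogNormal

torch.manual_seed(0)
m = LogNormal(torch.tensor(0.0), torch.tensor(0.5))
x = m.sample((100_000,))
print(x.mean(), m.mean)                 # empirical mean vs. exp(loc + scale**2 / 2)
print(m.log_prob(torch.tensor(1.0)))    # log-density at y = 1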
diff --git a/MLPY/Lib/site-packages/torch/distributions/logistic_normal.py b/MLPY/Lib/site-packages/torch/distributions/logistic_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23d0bb7a6f0411f03de33b3f3b7f9c46abd4d79
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/logistic_normal.py
@@ -0,0 +1,54 @@
+from torch.distributions import constraints
+from torch.distributions.normal import Normal
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import StickBreakingTransform
+
+__all__ = ["LogisticNormal"]
+
+
+class LogisticNormal(TransformedDistribution):
+    r"""
+    Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
+    that define the base `Normal` distribution transformed with the
+    `StickBreakingTransform` such that::
+
+        X ~ LogisticNormal(loc, scale)
+        Y = log(X / (1 - X.cumsum(-1)))[..., :-1] ~ Normal(loc, scale)
+
+    Args:
+        loc (float or Tensor): mean of the base distribution
+        scale (float or Tensor): standard deviation of the base distribution
+
+    Example::
+
+        >>> # logistic-normal distributed with mean=(0, 0, 0) and stddev=(1, 1, 1)
+        >>> # of the base Normal distribution
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = LogisticNormal(torch.tensor([0.0] * 3), torch.tensor([1.0] * 3))
+        >>> m.sample()
+        tensor([ 0.7653,  0.0341,  0.0579,  0.1427])
+
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.simplex
+    has_rsample = True
+
+    def __init__(self, loc, scale, validate_args=None):
+        base_dist = Normal(loc, scale, validate_args=validate_args)
+        if not base_dist.batch_shape:
+            base_dist = base_dist.expand([1])
+        super().__init__(
+            base_dist, StickBreakingTransform(), validate_args=validate_args
+        )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(LogisticNormal, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def loc(self):
+        return self.base_dist.base_dist.loc
+
+    @property
+    def scale(self):
+        return self.base_dist.base_dist.scale
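# Usage sketch (illustrative, assuming the vendored LogisticNormal added above): the
# stick-breaking transform maps a base Normal of size k to the simplex with k + 1 coordinates.
import torch
from torch.distributions import LogisticNormal

m = LogisticNormal(torch.zeros(3), torch.ones(3))
y = m.sample((5,))
print(y.shape)        # torch.Size([5, 4]): one extra stick-breaking coordinate
print(y.sum(-1))      # each row sums to 1, i.e. lies on the simplex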
diff --git a/MLPY/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py b/MLPY/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf3fda59c6144089026674d9f6de60be8be5421e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py
@@ -0,0 +1,237 @@
+import math
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
+from torch.distributions.utils import _standard_normal, lazy_property
+
+__all__ = ["LowRankMultivariateNormal"]
+
+
+def _batch_capacitance_tril(W, D):
+    r"""
+    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
+    and a batch of vectors :math:`D`.
+    """
+    m = W.size(-1)
+    Wt_Dinv = W.mT / D.unsqueeze(-2)
+    K = torch.matmul(Wt_Dinv, W).contiguous()
+    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
+    return torch.linalg.cholesky(K)
+
+
+def _batch_lowrank_logdet(W, D, capacitance_tril):
+    r"""
+    Uses "matrix determinant lemma"::
+        log|W @ W.T + D| = log|C| + log|D|,
+    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
+    the log determinant.
+    """
+    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
+        -1
+    )
+
+
+def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
+    r"""
+    Uses "Woodbury matrix identity"::
+        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
+    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
+    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
+    """
+    Wt_Dinv = W.mT / D.unsqueeze(-2)
+    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
+    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
+    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
+    return mahalanobis_term1 - mahalanobis_term2
+
+
+class LowRankMultivariateNormal(Distribution):
+    r"""
+    Creates a multivariate normal distribution with covariance matrix having a low-rank form
+    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::
+
+        covariance_matrix = cov_factor @ cov_factor.T + cov_diag
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
+        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
+        tensor([-0.2102, -0.5429])
+
+    Args:
+        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
+        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
+            `batch_shape + event_shape + (rank,)`
+        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
+            `batch_shape + event_shape`
+
+    Note:
+        The computation of the determinant and inverse of the covariance matrix is avoided when
+        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to the `Woodbury matrix identity
+        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
+        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
+        Thanks to these formulas, we only need to compute the determinant and inverse of
+        the small "capacitance" matrix::
+
+            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
+    """
+    arg_constraints = {
+        "loc": constraints.real_vector,
+        "cov_factor": constraints.independent(constraints.real, 2),
+        "cov_diag": constraints.independent(constraints.positive, 1),
+    }
+    support = constraints.real_vector
+    has_rsample = True
+
+    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
+        if loc.dim() < 1:
+            raise ValueError("loc must be at least one-dimensional.")
+        event_shape = loc.shape[-1:]
+        if cov_factor.dim() < 2:
+            raise ValueError(
+                "cov_factor must be at least two-dimensional, "
+                "with optional leading batch dimensions"
+            )
+        if cov_factor.shape[-2:-1] != event_shape:
+            raise ValueError(
+                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
+            )
+        if cov_diag.shape[-1:] != event_shape:
+            raise ValueError(
+                f"cov_diag must be a batch of vectors with shape {event_shape}"
+            )
+
+        loc_ = loc.unsqueeze(-1)
+        cov_diag_ = cov_diag.unsqueeze(-1)
+        try:
+            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
+                loc_, cov_factor, cov_diag_
+            )
+        except RuntimeError as e:
+            raise ValueError(
+                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
+            ) from e
+        self.loc = loc_[..., 0]
+        self.cov_diag = cov_diag_[..., 0]
+        batch_shape = self.loc.shape[:-1]
+
+        self._unbroadcasted_cov_factor = cov_factor
+        self._unbroadcasted_cov_diag = cov_diag
+        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
+        batch_shape = torch.Size(batch_shape)
+        loc_shape = batch_shape + self.event_shape
+        new.loc = self.loc.expand(loc_shape)
+        new.cov_diag = self.cov_diag.expand(loc_shape)
+        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
+        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
+        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
+        new._capacitance_tril = self._capacitance_tril
+        super(LowRankMultivariateNormal, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        return self.loc
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @lazy_property
+    def variance(self):
+        return (
+            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
+        ).expand(self._batch_shape + self._event_shape)
+
+    @lazy_property
+    def scale_tril(self):
+        # The following identity is used to increase the numerical stability of the
+        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
+        #     W @ W.T + D = D^(1/2) @ (I + D^(-1/2) @ W @ W.T @ D^(-1/2)) @ D^(1/2)
+        # The matrix "I + D^(-1/2) @ W @ W.T @ D^(-1/2)" has eigenvalues bounded from below by 1,
+        # hence it is well-conditioned and safe to take Cholesky decomposition.
+        n = self._event_shape[0]
+        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
+        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
+        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
+        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
+        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
+        return scale_tril.expand(
+            self._batch_shape + self._event_shape + self._event_shape
+        )
+
+    @lazy_property
+    def covariance_matrix(self):
+        covariance_matrix = torch.matmul(
+            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
+        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
+        return covariance_matrix.expand(
+            self._batch_shape + self._event_shape + self._event_shape
+        )
+
+    @lazy_property
+    def precision_matrix(self):
+        # We use "Woodbury matrix identity" to take advantage of low rank form::
+        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
+        # where :math:`C` is the capacitance matrix.
+        Wt_Dinv = (
+            self._unbroadcasted_cov_factor.mT
+            / self._unbroadcasted_cov_diag.unsqueeze(-2)
+        )
+        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
+        precision_matrix = (
+            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
+        )
+        return precision_matrix.expand(
+            self._batch_shape + self._event_shape + self._event_shape
+        )
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
+        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
+        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+        return (
+            self.loc
+            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+            + self._unbroadcasted_cov_diag.sqrt() * eps_D
+        )
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        diff = value - self.loc
+        M = _batch_lowrank_mahalanobis(
+            self._unbroadcasted_cov_factor,
+            self._unbroadcasted_cov_diag,
+            diff,
+            self._capacitance_tril,
+        )
+        log_det = _batch_lowrank_logdet(
+            self._unbroadcasted_cov_factor,
+            self._unbroadcasted_cov_diag,
+            self._capacitance_tril,
+        )
+        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
+
+    def entropy(self):
+        log_det = _batch_lowrank_logdet(
+            self._unbroadcasted_cov_factor,
+            self._unbroadcasted_cov_diag,
+            self._capacitance_tril,
+        )
+        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
+        if len(self._batch_shape) == 0:
+            return H
+        else:
+            return H.expand(self._batch_shape)
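# Usage sketch (illustrative, assuming the vendored LowRankMultivariateNormal and
# MultivariateNormal added in this patch): the low-rank parameterization W @ W.T + diag(D)
# should assign the same density as an explicit full-covariance MultivariateNormal.
import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

torch.manual_seed(0)
loc = torch.zeros(4)
cov_factor = torch.randn(4, 2)          # rank-2 factor W
cov_diag = torch.rand(4) + 0.5          # strictly positive diagonal D
lowrank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
full_cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)
full = MultivariateNormal(loc, covariance_matrix=full_cov)
x = lowrank.rsample()
print(torch.allclose(lowrank.log_prob(x), full.log_prob(x), atol=1e-5))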
diff --git a/MLPY/Lib/site-packages/torch/distributions/mixture_same_family.py b/MLPY/Lib/site-packages/torch/distributions/mixture_same_family.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ba3be63cce2446ed9c839294284485fb4713f3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/mixture_same_family.py
@@ -0,0 +1,214 @@
+from typing import Dict
+
+import torch
+from torch.distributions import Categorical, constraints
+from torch.distributions.distribution import Distribution
+
+__all__ = ["MixtureSameFamily"]
+
+
+class MixtureSameFamily(Distribution):
+    r"""
+    The `MixtureSameFamily` distribution implements a (batch of) mixture
+    distribution where all components are from different parameterizations of
+    the same distribution type. It is parameterized by a `Categorical`
+    "selecting distribution" (over `k` components) and a component
+    distribution, i.e., a `Distribution` with a rightmost batch shape
+    (equal to `[k]`) which indexes each (batch of) component.
+
+    Examples::
+
+        >>> # xdoctest: +SKIP("undefined vars")
+        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
+        >>> # weighted normal distributions
+        >>> mix = D.Categorical(torch.ones(5,))
+        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
+        >>> gmm = MixtureSameFamily(mix, comp)
+
+        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
+        >>> # weighted bivariate normal distributions
+        >>> mix = D.Categorical(torch.ones(5,))
+        >>> comp = D.Independent(D.Normal(
+        ...          torch.randn(5,2), torch.rand(5,2)), 1)
+        >>> gmm = MixtureSameFamily(mix, comp)
+
+        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
+        >>> # consisting of 5 random weighted bivariate normal distributions
+        >>> mix = D.Categorical(torch.rand(3,5))
+        >>> comp = D.Independent(D.Normal(
+        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
+        >>> gmm = MixtureSameFamily(mix, comp)
+
+    Args:
+        mixture_distribution: `torch.distributions.Categorical`-like
+            instance. Manages the probability of selecting components.
+            The number of categories must match the rightmost batch
+            dimension of the `component_distribution`. Must have either
+            scalar `batch_shape` or `batch_shape` matching
+            `component_distribution.batch_shape[:-1]`
+        component_distribution: `torch.distributions.Distribution`-like
+            instance. Right-most batch dimension indexes component.
+    """
+    arg_constraints: Dict[str, constraints.Constraint] = {}
+    has_rsample = False
+
+    def __init__(
+        self, mixture_distribution, component_distribution, validate_args=None
+    ):
+        self._mixture_distribution = mixture_distribution
+        self._component_distribution = component_distribution
+
+        if not isinstance(self._mixture_distribution, Categorical):
+            raise ValueError(
+                "The Mixture distribution needs to be an "
+                "instance of torch.distributions.Categorical"
+            )
+
+        if not isinstance(self._component_distribution, Distribution):
+            raise ValueError(
+                "The Component distribution needs to be an "
+                "instance of torch.distributions.Distribution"
+            )
+
+        # Check that batch size matches
+        mdbs = self._mixture_distribution.batch_shape
+        cdbs = self._component_distribution.batch_shape[:-1]
+        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
+            if size1 != 1 and size2 != 1 and size1 != size2:
+                raise ValueError(
+                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
+                    "compatible with `component_distribution."
+                    f"batch_shape`({cdbs})"
+                )
+
+        # Check that the number of mixture components matches
+        km = self._mixture_distribution.logits.shape[-1]
+        kc = self._component_distribution.batch_shape[-1]
+        if km is not None and kc is not None and km != kc:
+            raise ValueError(
+                f"`mixture_distribution` component count ({km}) does not"
+                " equal `component_distribution.batch_shape[-1]`"
+                f" ({kc})"
+            )
+        self._num_component = km
+
+        event_shape = self._component_distribution.event_shape
+        self._event_ndims = len(event_shape)
+        super().__init__(
+            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
+        )
+
+    def expand(self, batch_shape, _instance=None):
+        batch_shape = torch.Size(batch_shape)
+        batch_shape_comp = batch_shape + (self._num_component,)
+        new = self._get_checked_instance(MixtureSameFamily, _instance)
+        new._component_distribution = self._component_distribution.expand(
+            batch_shape_comp
+        )
+        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
+        new._num_component = self._num_component
+        new._event_ndims = self._event_ndims
+        event_shape = new._component_distribution.event_shape
+        super(MixtureSameFamily, new).__init__(
+            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    @constraints.dependent_property
+    def support(self):
+        # FIXME this may have the wrong shape when support contains batched
+        # parameters
+        return self._component_distribution.support
+
+    @property
+    def mixture_distribution(self):
+        return self._mixture_distribution
+
+    @property
+    def component_distribution(self):
+        return self._component_distribution
+
+    @property
+    def mean(self):
+        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
+        return torch.sum(
+            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
+        )  # [B, E]
+
+    @property
+    def variance(self):
+        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
+        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
+        mean_cond_var = torch.sum(
+            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
+        )
+        var_cond_mean = torch.sum(
+            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
+            dim=-1 - self._event_ndims,
+        )
+        return mean_cond_var + var_cond_mean
+
+    def cdf(self, x):
+        x = self._pad(x)
+        cdf_x = self.component_distribution.cdf(x)
+        mix_prob = self.mixture_distribution.probs
+
+        return torch.sum(cdf_x * mix_prob, dim=-1)
+
+    def log_prob(self, x):
+        if self._validate_args:
+            self._validate_sample(x)
+        x = self._pad(x)
+        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
+        log_mix_prob = torch.log_softmax(
+            self.mixture_distribution.logits, dim=-1
+        )  # [B, k]
+        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]
+
+    def sample(self, sample_shape=torch.Size()):
+        with torch.no_grad():
+            sample_len = len(sample_shape)
+            batch_len = len(self.batch_shape)
+            gather_dim = sample_len + batch_len
+            es = self.event_shape
+
+            # mixture samples [n, B]
+            mix_sample = self.mixture_distribution.sample(sample_shape)
+            mix_shape = mix_sample.shape
+
+            # component samples [n, B, k, E]
+            comp_samples = self.component_distribution.sample(sample_shape)
+
+            # Gather along the k dimension
+            mix_sample_r = mix_sample.reshape(
+                mix_shape + torch.Size([1] * (len(es) + 1))
+            )
+            mix_sample_r = mix_sample_r.repeat(
+                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
+            )
+
+            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
+            return samples.squeeze(gather_dim)
+
+    def _pad(self, x):
+        return x.unsqueeze(-1 - self._event_ndims)
+
+    def _pad_mixture_dimensions(self, x):
+        dist_batch_ndims = len(self.batch_shape)
+        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
+        pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
+        xs = x.shape
+        x = x.reshape(
+            xs[:-1]
+            + torch.Size(pad_ndims * [1])
+            + xs[-1:]
+            + torch.Size(self._event_ndims * [1])
+        )
+        return x
+
+    def __repr__(self):
+        args_string = (
+            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
+        )
+        return "MixtureSameFamily" + "(" + args_string + ")"
diff --git a/MLPY/Lib/site-packages/torch/distributions/multinomial.py b/MLPY/Lib/site-packages/torch/distributions/multinomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b76a0b3424f8b6f4fb1b952f07d18331fa4d0cb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/multinomial.py
@@ -0,0 +1,135 @@
+import torch
+from torch import inf
+from torch.distributions import Categorical, constraints
+from torch.distributions.binomial import Binomial
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Multinomial"]
+
+
+class Multinomial(Distribution):
+    r"""
+    Creates a Multinomial distribution parameterized by :attr:`total_count` and
+    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
+    :attr:`probs` indexes over categories. All other dimensions index over batches.
+
+    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
+    called (see example below).
+
+    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+              will return this normalized value.
+              The `logits` argument will be interpreted as unnormalized log probabilities
+              and can therefore be any real number. It will likewise be normalized so that
+              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+              will return this normalized value.
+
+    -   :meth:`sample` requires a single shared `total_count` for all
+        parameters and samples.
+    -   :meth:`log_prob` allows different `total_count` for each parameter and
+        sample.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("FIXME: found invalid values")
+        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
+        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
+        tensor([ 21.,  24.,  30.,  25.])
+
+        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
+        tensor([-4.1338])
+
+    Args:
+        total_count (int): number of trials
+        probs (Tensor): event probabilities
+        logits (Tensor): event log probabilities (unnormalized)
+    """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
+    total_count: int
+
+    @property
+    def mean(self):
+        return self.probs * self.total_count
+
+    @property
+    def variance(self):
+        return self.total_count * self.probs * (1 - self.probs)
+
+    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+        if not isinstance(total_count, int):
+            raise NotImplementedError("inhomogeneous total_count is not supported")
+        self.total_count = total_count
+        self._categorical = Categorical(probs=probs, logits=logits)
+        self._binomial = Binomial(total_count=total_count, probs=self.probs)
+        batch_shape = self._categorical.batch_shape
+        event_shape = self._categorical.param_shape[-1:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Multinomial, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.total_count = self.total_count
+        new._categorical = self._categorical.expand(batch_shape)
+        super(Multinomial, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._categorical._new(*args, **kwargs)
+
+    @constraints.dependent_property(is_discrete=True, event_dim=1)
+    def support(self):
+        return constraints.multinomial(self.total_count)
+
+    @property
+    def logits(self):
+        return self._categorical.logits
+
+    @property
+    def probs(self):
+        return self._categorical.probs
+
+    @property
+    def param_shape(self):
+        return self._categorical.param_shape
+
+    def sample(self, sample_shape=torch.Size()):
+        sample_shape = torch.Size(sample_shape)
+        samples = self._categorical.sample(
+            torch.Size((self.total_count,)) + sample_shape
+        )
+        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
+        # (sample_shape, batch_shape, total_count)
+        shifted_idx = list(range(samples.dim()))
+        shifted_idx.append(shifted_idx.pop(0))
+        samples = samples.permute(*shifted_idx)
+        counts = samples.new(self._extended_shape(sample_shape)).zero_()
+        counts.scatter_add_(-1, samples, torch.ones_like(samples))
+        return counts.type_as(self.probs)
+
+    def entropy(self):
+        n = torch.tensor(self.total_count)
+
+        cat_entropy = self._categorical.entropy()
+        term1 = n * cat_entropy - torch.lgamma(n + 1)
+
+        support = self._binomial.enumerate_support(expand=False)[1:]
+        binomial_probs = torch.exp(self._binomial.log_prob(support))
+        weights = torch.lgamma(support + 1)
+        term2 = (binomial_probs * weights).sum([0, -1])
+
+        return term1 + term2
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        logits = logits.clone(memory_format=torch.contiguous_format)
+        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
+        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
+        logits[(value == 0) & (logits == -inf)] = 0
+        log_powers = (logits * value).sum(-1)
+        return log_factorial_n - log_factorial_xs + log_powers
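# Usage sketch (illustrative, assuming the vendored Multinomial added above): sampled count
# vectors always sum to total_count, and log_prob scores such count vectors.
import torch
from torch.distributions import Multinomial

m = Multinomial(total_count=100, probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
counts = m.sample()
print(counts, counts.sum())          # the four counts sum to 100
print(m.log_prob(counts))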
diff --git a/MLPY/Lib/site-packages/torch/distributions/multivariate_normal.py b/MLPY/Lib/site-packages/torch/distributions/multivariate_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed25dbf3e0edffa6b65d421502210f87b1cdef4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/multivariate_normal.py
@@ -0,0 +1,262 @@
+import math
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _standard_normal, lazy_property
+
+__all__ = ["MultivariateNormal"]
+
+
+def _batch_mv(bmat, bvec):
+    r"""
+    Performs a batched matrix-vector product, with compatible but different batch shapes.
+
+    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
+    `bvec`, containing length :math:`n` vectors.
+
+    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
+    to a batch shape. They are not necessarily assumed to have the same batch shape,
+    just batch shapes that can be broadcast together.
+    """
+    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
+
+
+def _batch_mahalanobis(bL, bx):
+    r"""
+    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
+    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
+
+    Accepts batches for both `bL` and `bx`. They are not necessarily assumed to have the same batch
+    shape; the batch shape of `bL` just needs to be broadcastable to the batch shape of `bx`.
+    """
+    n = bx.size(-1)
+    bx_batch_shape = bx.shape[:-1]
+
+    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
+    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tri.solve
+    bx_batch_dims = len(bx_batch_shape)
+    bL_batch_dims = bL.dim() - 2
+    outer_batch_dims = bx_batch_dims - bL_batch_dims
+    old_batch_dims = outer_batch_dims + bL_batch_dims
+    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
+    # Reshape bx with the shape (..., 1, i, j, 1, n)
+    bx_new_shape = bx.shape[:outer_batch_dims]
+    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
+        bx_new_shape += (sx // sL, sL)
+    bx_new_shape += (n,)
+    bx = bx.reshape(bx_new_shape)
+    # Permute bx to make it have shape (..., 1, j, i, 1, n)
+    permute_dims = (
+        list(range(outer_batch_dims))
+        + list(range(outer_batch_dims, new_batch_dims, 2))
+        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
+        + [new_batch_dims]
+    )
+    bx = bx.permute(permute_dims)
+
+    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
+    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
+    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
+    M_swap = (
+        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
+    )  # shape = b x c
+    M = M_swap.t()  # shape = c x b
+
+    # Now we revert the above reshape and permute operators.
+    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
+    permute_inv_dims = list(range(outer_batch_dims))
+    for i in range(bL_batch_dims):
+        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
+    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
+    return reshaped_M.reshape(bx_batch_shape)
+
+
+def _precision_to_scale_tril(P):
+    # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
+    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
+    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
+    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
+    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
+    return L
+
+
+class MultivariateNormal(Distribution):
+    r"""
+    Creates a multivariate normal (also called Gaussian) distribution
+    parameterized by a mean vector and a covariance matrix.
+
+    The multivariate normal distribution can be parameterized either
+    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
+    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
+    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
+    diagonal entries, such that
+    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
+    can be obtained via e.g. Cholesky decomposition of the covariance.
+
+    Example:
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
+        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
+        tensor([-0.2102, -0.5429])
+
+    Args:
+        loc (Tensor): mean of the distribution
+        covariance_matrix (Tensor): positive-definite covariance matrix
+        precision_matrix (Tensor): positive-definite precision matrix
+        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
+
+    Note:
+        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
+        :attr:`scale_tril` can be specified.
+
+        Using :attr:`scale_tril` will be more efficient: all computations internally
+        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
+        :attr:`precision_matrix` is passed instead, it is only used to compute
+        the corresponding lower triangular matrices using a Cholesky decomposition.
+    """
+    arg_constraints = {
+        "loc": constraints.real_vector,
+        "covariance_matrix": constraints.positive_definite,
+        "precision_matrix": constraints.positive_definite,
+        "scale_tril": constraints.lower_cholesky,
+    }
+    support = constraints.real_vector
+    has_rsample = True
+
+    def __init__(
+        self,
+        loc,
+        covariance_matrix=None,
+        precision_matrix=None,
+        scale_tril=None,
+        validate_args=None,
+    ):
+        if loc.dim() < 1:
+            raise ValueError("loc must be at least one-dimensional.")
+        if (covariance_matrix is not None) + (scale_tril is not None) + (
+            precision_matrix is not None
+        ) != 1:
+            raise ValueError(
+                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
+            )
+
+        if scale_tril is not None:
+            if scale_tril.dim() < 2:
+                raise ValueError(
+                    "scale_tril matrix must be at least two-dimensional, "
+                    "with optional leading batch dimensions"
+                )
+            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
+            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
+        elif covariance_matrix is not None:
+            if covariance_matrix.dim() < 2:
+                raise ValueError(
+                    "covariance_matrix must be at least two-dimensional, "
+                    "with optional leading batch dimensions"
+                )
+            batch_shape = torch.broadcast_shapes(
+                covariance_matrix.shape[:-2], loc.shape[:-1]
+            )
+            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
+        else:
+            if precision_matrix.dim() < 2:
+                raise ValueError(
+                    "precision_matrix must be at least two-dimensional, "
+                    "with optional leading batch dimensions"
+                )
+            batch_shape = torch.broadcast_shapes(
+                precision_matrix.shape[:-2], loc.shape[:-1]
+            )
+            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
+        self.loc = loc.expand(batch_shape + (-1,))
+
+        event_shape = self.loc.shape[-1:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+        if scale_tril is not None:
+            self._unbroadcasted_scale_tril = scale_tril
+        elif covariance_matrix is not None:
+            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
+        else:  # precision_matrix is not None
+            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(MultivariateNormal, _instance)
+        batch_shape = torch.Size(batch_shape)
+        loc_shape = batch_shape + self.event_shape
+        cov_shape = batch_shape + self.event_shape + self.event_shape
+        new.loc = self.loc.expand(loc_shape)
+        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
+        if "covariance_matrix" in self.__dict__:
+            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
+        if "scale_tril" in self.__dict__:
+            new.scale_tril = self.scale_tril.expand(cov_shape)
+        if "precision_matrix" in self.__dict__:
+            new.precision_matrix = self.precision_matrix.expand(cov_shape)
+        super(MultivariateNormal, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    @lazy_property
+    def scale_tril(self):
+        return self._unbroadcasted_scale_tril.expand(
+            self._batch_shape + self._event_shape + self._event_shape
+        )
+
+    @lazy_property
+    def covariance_matrix(self):
+        return torch.matmul(
+            self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
+        ).expand(self._batch_shape + self._event_shape + self._event_shape)
+
+    @lazy_property
+    def precision_matrix(self):
+        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
+            self._batch_shape + self._event_shape + self._event_shape
+        )
+
+    @property
+    def mean(self):
+        return self.loc
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def variance(self):
+        return (
+            self._unbroadcasted_scale_tril.pow(2)
+            .sum(-1)
+            .expand(self._batch_shape + self._event_shape)
+        )
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        diff = value - self.loc
+        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
+        half_log_det = (
+            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+        )
+        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
+
+    def entropy(self):
+        half_log_det = (
+            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+        )
+        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
+        if len(self._batch_shape) == 0:
+            return H
+        else:
+            return H.expand(self._batch_shape)
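# Usage sketch (illustrative, assuming the vendored MultivariateNormal added above): per the
# note in the docstring, passing scale_tril directly skips the internal Cholesky factorization.
import torch
from torch.distributions import MultivariateNormal

cov = torch.tensor([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.3],
                    [0.0, 0.3, 1.5]])
mvn = MultivariateNormal(torch.zeros(3), scale_tril=torch.linalg.cholesky(cov))
x = mvn.rsample((2,))
print(mvn.log_prob(x))
print(torch.allclose(mvn.covariance_matrix, cov, atol=1e-6))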
diff --git a/MLPY/Lib/site-packages/torch/distributions/negative_binomial.py b/MLPY/Lib/site-packages/torch/distributions/negative_binomial.py
new file mode 100644
index 0000000000000000000000000000000000000000..b55754a0a0fe17c1d50444a76f9314e8a052f0d7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/negative_binomial.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn.functional as F
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import (
+    broadcast_all,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+
+__all__ = ["NegativeBinomial"]
+
+
+class NegativeBinomial(Distribution):
+    r"""
+    Creates a Negative Binomial distribution, i.e. distribution
+    of the number of successful independent and identical Bernoulli trials
+    before :attr:`total_count` failures are achieved. The probability
+    of success of each Bernoulli trial is :attr:`probs`.
+
+    Args:
+        total_count (float or Tensor): non-negative number of failed Bernoulli
+            trials at which sampling stops; the distribution is still valid for
+            real-valued counts
+        probs (Tensor): Event probabilities of success in the half-open interval [0, 1)
+        logits (Tensor): Event log-odds for probabilities of success
+    """
+    arg_constraints = {
+        "total_count": constraints.greater_than_eq(0),
+        "probs": constraints.half_open_interval(0.0, 1.0),
+        "logits": constraints.real,
+    }
+    support = constraints.nonnegative_integer
+
+    def __init__(self, total_count, probs=None, logits=None, validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            (
+                self.total_count,
+                self.probs,
+            ) = broadcast_all(total_count, probs)
+            self.total_count = self.total_count.type_as(self.probs)
+        else:
+            (
+                self.total_count,
+                self.logits,
+            ) = broadcast_all(total_count, logits)
+            self.total_count = self.total_count.type_as(self.logits)
+
+        self._param = self.probs if probs is not None else self.logits
+        batch_shape = self._param.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(NegativeBinomial, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.total_count = self.total_count.expand(batch_shape)
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    @property
+    def mean(self):
+        return self.total_count * torch.exp(self.logits)
+
+    @property
+    def mode(self):
+        return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0)
+
+    @property
+    def variance(self):
+        return self.mean / torch.sigmoid(-self.logits)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    @lazy_property
+    def _gamma(self):
+        # Note we avoid validating because self.total_count can be zero.
+        return torch.distributions.Gamma(
+            concentration=self.total_count,
+            rate=torch.exp(-self.logits),
+            validate_args=False,
+        )
+
+    def sample(self, sample_shape=torch.Size()):
+        with torch.no_grad():
+            rate = self._gamma.sample(sample_shape=sample_shape)
+            return torch.poisson(rate)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+
+        log_unnormalized_prob = self.total_count * F.logsigmoid(
+            -self.logits
+        ) + value * F.logsigmoid(self.logits)
+
+        log_normalization = (
+            -torch.lgamma(self.total_count + value)
+            + torch.lgamma(1.0 + value)
+            + torch.lgamma(self.total_count)
+        )
+        # The case self.total_count == 0 and value == 0 has probability 1 but
+        # lgamma(0) is infinite. Handle this case separately using a function
+        # that does not modify tensors in place to allow Jit compilation.
+        log_normalization = log_normalization.masked_fill(
+            self.total_count + value == 0.0, 0.0
+        )
+
+        return log_unnormalized_prob - log_normalization
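# Usage sketch (illustrative, assuming the vendored NegativeBinomial added above): empirical
# moments against the closed-form mean r * p / (1 - p) and variance mean / (1 - p) used in
# the properties above.
import torch
from torch.distributions import NegativeBinomial

torch.manual_seed(0)
nb = NegativeBinomial(total_count=torch.tensor(10.0), probs=torch.tensor(0.25))
x = nb.sample((100_000,))
print(x.mean(), nb.mean)
print(x.var(), nb.variance)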
diff --git a/MLPY/Lib/site-packages/torch/distributions/normal.py b/MLPY/Lib/site-packages/torch/distributions/normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..11079ebb0618f783d5969d5121fbf494524ae9d1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/normal.py
@@ -0,0 +1,109 @@
+import math
+from numbers import Number, Real
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import _standard_normal, broadcast_all
+
+__all__ = ["Normal"]
+
+
+class Normal(ExponentialFamily):
+    r"""
+    Creates a normal (also called Gaussian) distribution parameterized by
+    :attr:`loc` and :attr:`scale`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+        >>> m.sample()  # normally distributed with loc=0 and scale=1
+        tensor([ 0.1046])
+
+    Args:
+        loc (float or Tensor): mean of the distribution (often referred to as mu)
+        scale (float or Tensor): standard deviation of the distribution
+            (often referred to as sigma)
+    """
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+    has_rsample = True
+    _mean_carrier_measure = 0
+
+    @property
+    def mean(self):
+        return self.loc
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def stddev(self):
+        return self.scale
+
+    @property
+    def variance(self):
+        return self.stddev.pow(2)
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        if isinstance(loc, Number) and isinstance(scale, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.loc.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Normal, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        super(Normal, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+        return self.loc + eps * self.scale
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        # compute the variance
+        var = self.scale**2
+        log_scale = (
+            math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
+        )
+        return (
+            -((value - self.loc) ** 2) / (2 * var)
+            - log_scale
+            - math.log(math.sqrt(2 * math.pi))
+        )
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        return 0.5 * (
+            1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2))
+        )
+
+    def icdf(self, value):
+        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+
+    def entropy(self):
+        return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
+
+    @property
+    def _natural_params(self):
+        return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
+
+    def _log_normalizer(self, x, y):
+        return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
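# Usage sketch (illustrative, assuming the vendored Normal added above): rsample() is the
# reparameterized path loc + eps * scale, so gradients flow back to loc and scale.
import torch
from torch.distributions import Normal

loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)
x = Normal(loc, scale).rsample((1000,))
(x ** 2).mean().backward()
print(loc.grad, scale.grad)          # both populated thanks to reparameterization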
diff --git a/MLPY/Lib/site-packages/torch/distributions/one_hot_categorical.py b/MLPY/Lib/site-packages/torch/distributions/one_hot_categorical.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f354f2ca1a4458d5a273d82cc0481669fc52ad
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/one_hot_categorical.py
@@ -0,0 +1,129 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.categorical import Categorical
+from torch.distributions.distribution import Distribution
+
+__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"]
+
+
+class OneHotCategorical(Distribution):
+    r"""
+    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
+    :attr:`logits`.
+
+    Samples are one-hot coded vectors of size ``probs.size(-1)``.
+
+    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+              will return this normalized value.
+              The `logits` argument will be interpreted as unnormalized log probabilities
+              and can therefore be any real number. It will likewise be normalized so that
+              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+              will return this normalized value.
+
+    See also: :func:`torch.distributions.Categorical` for specifications of
+    :attr:`probs` and :attr:`logits`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
+        >>> m.sample()  # equal probability of 0, 1, 2, 3
+        tensor([ 0.,  0.,  0.,  1.])
+
+    Args:
+        probs (Tensor): event probabilities
+        logits (Tensor): event log probabilities (unnormalized)
+    """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
+    support = constraints.one_hot
+    has_enumerate_support = True
+
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        self._categorical = Categorical(probs, logits)
+        batch_shape = self._categorical.batch_shape
+        event_shape = self._categorical.param_shape[-1:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(OneHotCategorical, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new._categorical = self._categorical.expand(batch_shape)
+        super(OneHotCategorical, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._categorical._new(*args, **kwargs)
+
+    @property
+    def _param(self):
+        return self._categorical._param
+
+    @property
+    def probs(self):
+        return self._categorical.probs
+
+    @property
+    def logits(self):
+        return self._categorical.logits
+
+    @property
+    def mean(self):
+        return self._categorical.probs
+
+    @property
+    def mode(self):
+        probs = self._categorical.probs
+        mode = probs.argmax(axis=-1)
+        return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs)
+
+    @property
+    def variance(self):
+        return self._categorical.probs * (1 - self._categorical.probs)
+
+    @property
+    def param_shape(self):
+        return self._categorical.param_shape
+
+    def sample(self, sample_shape=torch.Size()):
+        sample_shape = torch.Size(sample_shape)
+        probs = self._categorical.probs
+        num_events = self._categorical._num_events
+        indices = self._categorical.sample(sample_shape)
+        return torch.nn.functional.one_hot(indices, num_events).to(probs)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        indices = value.max(-1)[1]
+        return self._categorical.log_prob(indices)
+
+    def entropy(self):
+        return self._categorical.entropy()
+
+    def enumerate_support(self, expand=True):
+        n = self.event_shape[0]
+        values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
+        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
+        if expand:
+            values = values.expand((n,) + self.batch_shape + (n,))
+        return values
+
+
+class OneHotCategoricalStraightThrough(OneHotCategorical):
+    r"""
+    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
+    through gradient estimator from [1].
+
+    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
+    (Bengio et al, 2013)
+    """
+    has_rsample = True
+
+    def rsample(self, sample_shape=torch.Size()):
+        samples = self.sample(sample_shape)
+        probs = self._categorical.probs  # cached via @lazy_property
+        return samples + (probs - probs.detach())
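# Usage sketch (illustrative, assuming the vendored OneHotCategoricalStraightThrough added
# above): the forward pass yields a hard one-hot sample while gradients are routed through
# the probabilities via the probs - probs.detach() term.
import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.randn(4, requires_grad=True)
z = OneHotCategoricalStraightThrough(logits=logits).rsample()
(z * torch.arange(4.0)).sum().backward()
print(z)             # a hard one-hot vector
print(logits.grad)   # non-zero: gradients flow through the relaxed probs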
diff --git a/MLPY/Lib/site-packages/torch/distributions/pareto.py b/MLPY/Lib/site-packages/torch/distributions/pareto.py
new file mode 100644
index 0000000000000000000000000000000000000000..3297b47488ecd2e5143b105d461fd695a83add7f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/pareto.py
@@ -0,0 +1,60 @@
+from torch.distributions import constraints
+from torch.distributions.exponential import Exponential
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Pareto"]
+
+
+class Pareto(TransformedDistribution):
+    r"""
+    Samples from a Pareto Type 1 distribution.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
+        tensor([ 1.5623])
+
+    Args:
+        scale (float or Tensor): Scale parameter of the distribution
+        alpha (float or Tensor): Shape parameter of the distribution
+    """
+    arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}
+
+    def __init__(self, scale, alpha, validate_args=None):
+        self.scale, self.alpha = broadcast_all(scale, alpha)
+        base_dist = Exponential(self.alpha, validate_args=validate_args)
+        transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
+        super().__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Pareto, _instance)
+        new.scale = self.scale.expand(batch_shape)
+        new.alpha = self.alpha.expand(batch_shape)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def mean(self):
+        # mean is inf for alpha <= 1
+        a = self.alpha.clamp(min=1)
+        return a * self.scale / (a - 1)
+
+    @property
+    def mode(self):
+        return self.scale
+
+    @property
+    def variance(self):
+        # var is inf for alpha <= 2
+        a = self.alpha.clamp(min=2)
+        return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):
+        return constraints.greater_than_eq(self.scale)
+
+    def entropy(self):
+        return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())
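Since `Pareto` is constructed as `scale * exp(Exponential(alpha))`, its `log_prob` should agree with the closed-form Pareto Type 1 density. A small illustrative check (the hand-written formula below is standard, not taken from this file):

```python
import torch
from torch.distributions import Pareto

scale, alpha = torch.tensor(2.0), torch.tensor(3.0)
p = Pareto(scale, alpha)

# Closed-form log-density for x >= scale:
#   log p(x) = log(alpha) + alpha * log(scale) - (alpha + 1) * log(x)
x = torch.tensor(2.5)
manual = alpha.log() + alpha * scale.log() - (alpha + 1) * x.log()
print(torch.allclose(p.log_prob(x), manual))  # expected: True
```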
diff --git a/MLPY/Lib/site-packages/torch/distributions/poisson.py b/MLPY/Lib/site-packages/torch/distributions/poisson.py
new file mode 100644
index 0000000000000000000000000000000000000000..fac9ee0aa691c16956063d4b3dfd7e6f1e670d8e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/poisson.py
@@ -0,0 +1,77 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Poisson"]
+
+
+class Poisson(ExponentialFamily):
+    r"""
+    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
+
+    Samples are nonnegative integers, with a pmf given by
+
+    .. math::
+      \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
+
+    Example::
+
+        >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
+        >>> m = Poisson(torch.tensor([4]))
+        >>> m.sample()
+        tensor([ 3.])
+
+    Args:
+        rate (Number, Tensor): the rate parameter
+    """
+    arg_constraints = {"rate": constraints.nonnegative}
+    support = constraints.nonnegative_integer
+
+    @property
+    def mean(self):
+        return self.rate
+
+    @property
+    def mode(self):
+        return self.rate.floor()
+
+    @property
+    def variance(self):
+        return self.rate
+
+    def __init__(self, rate, validate_args=None):
+        (self.rate,) = broadcast_all(rate)
+        if isinstance(rate, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.rate.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Poisson, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.rate = self.rate.expand(batch_shape)
+        super(Poisson, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            return torch.poisson(self.rate.expand(shape))
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        rate, value = broadcast_all(self.rate, value)
+        return value.xlogy(rate) - rate - (value + 1).lgamma()
+
+    @property
+    def _natural_params(self):
+        return (torch.log(self.rate),)
+
+    def _log_normalizer(self, x):
+        return torch.exp(x)
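The `log_prob` above is just the log of the pmf shown in the docstring, written with `xlogy` and `lgamma`. A quick hedged check against the hand-written formula:

```python
import torch
from torch.distributions import Poisson

rate = torch.tensor(4.0)
d = Poisson(rate)

# log pmf from the docstring: k * log(rate) - rate - log(k!)
k = torch.tensor(3.0)
manual = k * rate.log() - rate - torch.lgamma(k + 1)
print(torch.allclose(d.log_prob(k), manual))  # expected: True
```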
diff --git a/MLPY/Lib/site-packages/torch/distributions/relaxed_bernoulli.py b/MLPY/Lib/site-packages/torch/distributions/relaxed_bernoulli.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f56aad243aaf7f0c456e4e61679d1df3d53563e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/relaxed_bernoulli.py
@@ -0,0 +1,149 @@
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import SigmoidTransform
+from torch.distributions.utils import (
+    broadcast_all,
+    clamp_probs,
+    lazy_property,
+    logits_to_probs,
+    probs_to_logits,
+)
+
+__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]
+
+
+class LogitRelaxedBernoulli(Distribution):
+    r"""
+    Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
+    or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
+    distribution.
+
+    Samples are logits of values in (0, 1). See [1] for more details.
+
+    Args:
+        temperature (Tensor): relaxation temperature
+        probs (Number, Tensor): the probability of sampling `1`
+        logits (Number, Tensor): the log-odds of sampling `1`
+
+    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
+    Variables (Maddison et al, 2017)
+
+    [2] Categorical Reparametrization with Gumbel-Softmax
+    (Jang et al, 2017)
+    """
+    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
+    support = constraints.real
+
+    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+        self.temperature = temperature
+        if (probs is None) == (logits is None):
+            raise ValueError(
+                "Either `probs` or `logits` must be specified, but not both."
+            )
+        if probs is not None:
+            is_scalar = isinstance(probs, Number)
+            (self.probs,) = broadcast_all(probs)
+        else:
+            is_scalar = isinstance(logits, Number)
+            (self.logits,) = broadcast_all(logits)
+        self._param = self.probs if probs is not None else self.logits
+        if is_scalar:
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self._param.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.temperature = self.temperature
+        if "probs" in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if "logits" in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        probs = clamp_probs(self.probs.expand(shape))
+        uniforms = clamp_probs(
+            torch.rand(shape, dtype=probs.dtype, device=probs.device)
+        )
+        return (
+            uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()
+        ) / self.temperature
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        diff = logits - value.mul(self.temperature)
+        return self.temperature.log() + diff - 2 * diff.exp().log1p()
+
+
+class RelaxedBernoulli(TransformedDistribution):
+    r"""
+    Creates a RelaxedBernoulli distribution, parametrized by
+    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`
+    (but not both). This is a relaxed version of the `Bernoulli` distribution,
+    so its values are in (0, 1) and its samples are reparametrizable.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = RelaxedBernoulli(torch.tensor([2.2]),
+        ...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
+        >>> m.sample()
+        tensor([ 0.2951,  0.3442,  0.8918,  0.9021])
+
+    Args:
+        temperature (Tensor): relaxation temperature
+        probs (Number, Tensor): the probability of sampling `1`
+        logits (Number, Tensor): the log-odds of sampling `1`
+    """
+    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
+    support = constraints.unit_interval
+    has_rsample = True
+
+    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+        base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
+        super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(RelaxedBernoulli, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def temperature(self):
+        return self.base_dist.temperature
+
+    @property
+    def logits(self):
+        return self.base_dist.logits
+
+    @property
+    def probs(self):
+        return self.base_dist.probs
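A short illustrative sketch (with assumed example values) of how the temperature trades off between smooth, differentiable samples and near-binary ones, and of `rsample` being reparameterized:

```python
import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(4, requires_grad=True)
for temp in (5.0, 0.1):
    d = RelaxedBernoulli(torch.tensor(temp), logits=logits)
    print(temp, d.rsample())        # low temperature pushes samples toward {0, 1}

sample = RelaxedBernoulli(torch.tensor(0.5), logits=logits).rsample()
sample.sum().backward()
print(logits.grad)                  # gradients reach the logits through rsample
```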
diff --git a/MLPY/Lib/site-packages/torch/distributions/relaxed_categorical.py b/MLPY/Lib/site-packages/torch/distributions/relaxed_categorical.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cac7b9c285a51538d9c7219584285564dfc807e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/relaxed_categorical.py
@@ -0,0 +1,139 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.categorical import Categorical
+from torch.distributions.distribution import Distribution
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import ExpTransform
+from torch.distributions.utils import broadcast_all, clamp_probs
+
+__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]
+
+
+class ExpRelaxedCategorical(Distribution):
+    r"""
+    Creates an ExpRelaxedCategorical distribution parameterized by
+    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
+    Returns the log of a point in the simplex. Based on the interface to
+    :class:`OneHotCategorical`.
+
+    Implementation based on [1].
+
+    See also: :func:`torch.distributions.OneHotCategorical`
+
+    Args:
+        temperature (Tensor): relaxation temperature
+        probs (Tensor): event probabilities
+        logits (Tensor): unnormalized log probability for each event
+
+    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
+    (Maddison et al, 2017)
+
+    [2] Categorical Reparametrization with Gumbel-Softmax
+    (Jang et al, 2017)
+    """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
+    support = (
+        constraints.real_vector
+    )  # The true support is actually a submanifold of this.
+    has_rsample = True
+
+    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+        self._categorical = Categorical(probs, logits)
+        self.temperature = temperature
+        batch_shape = self._categorical.batch_shape
+        event_shape = self._categorical.param_shape[-1:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.temperature = self.temperature
+        new._categorical = self._categorical.expand(batch_shape)
+        super(ExpRelaxedCategorical, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._categorical._new(*args, **kwargs)
+
+    @property
+    def param_shape(self):
+        return self._categorical.param_shape
+
+    @property
+    def logits(self):
+        return self._categorical.logits
+
+    @property
+    def probs(self):
+        return self._categorical.probs
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        uniforms = clamp_probs(
+            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
+        )
+        gumbels = -((-(uniforms.log())).log())
+        scores = (self.logits + gumbels) / self.temperature
+        return scores - scores.logsumexp(dim=-1, keepdim=True)
+
+    def log_prob(self, value):
+        K = self._categorical._num_events
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        log_scale = torch.full_like(
+            self.temperature, float(K)
+        ).lgamma() - self.temperature.log().mul(-(K - 1))
+        score = logits - value.mul(self.temperature)
+        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
+        return score + log_scale
+
+
+class RelaxedOneHotCategorical(TransformedDistribution):
+    r"""
+    Creates a RelaxedOneHotCategorical distribution parametrized by
+    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
+    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
+    its samples are on the simplex and are reparametrizable.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
+        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
+        >>> m.sample()
+        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])
+
+    Args:
+        temperature (Tensor): relaxation temperature
+        probs (Tensor): event probabilities
+        logits (Tensor): unnormalized log probability for each event
+    """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
+    support = constraints.simplex
+    has_rsample = True
+
+    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+        base_dist = ExpRelaxedCategorical(
+            temperature, probs, logits, validate_args=validate_args
+        )
+        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
+        return super().expand(batch_shape, _instance=new)
+
+    @property
+    def temperature(self):
+        return self.base_dist.temperature
+
+    @property
+    def logits(self):
+        return self.base_dist.logits
+
+    @property
+    def probs(self):
+        return self.base_dist.probs
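A companion sketch for the relaxed categorical (Gumbel-softmax) case, with assumed example values: samples always lie on the simplex, and they approach one-hot vectors as the temperature shrinks:

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

logits = torch.log(torch.tensor([0.1, 0.2, 0.7]))
for temp in (2.0, 0.05):
    d = RelaxedOneHotCategorical(torch.tensor(temp), logits=logits)
    s = d.rsample()
    print(temp, s, s.sum())         # sums to 1; nearly one-hot for small temp
```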
diff --git a/MLPY/Lib/site-packages/torch/distributions/studentT.py b/MLPY/Lib/site-packages/torch/distributions/studentT.py
new file mode 100644
index 0000000000000000000000000000000000000000..7881b5e50088d071a0a778c7229198b66f2c00b5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/studentT.py
@@ -0,0 +1,116 @@
+import math
+
+import torch
+from torch import inf, nan
+from torch.distributions import Chi2, constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _standard_normal, broadcast_all
+
+__all__ = ["StudentT"]
+
+
+class StudentT(Distribution):
+    r"""
+    Creates a Student's t-distribution parameterized by degrees of
+    freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = StudentT(torch.tensor([2.0]))
+        >>> m.sample()  # Student's t-distributed with degrees of freedom=2
+        tensor([ 0.1046])
+
+    Args:
+        df (float or Tensor): degrees of freedom
+        loc (float or Tensor): mean of the distribution
+        scale (float or Tensor): scale of the distribution
+    """
+    arg_constraints = {
+        "df": constraints.positive,
+        "loc": constraints.real,
+        "scale": constraints.positive,
+    }
+    support = constraints.real
+    has_rsample = True
+
+    @property
+    def mean(self):
+        m = self.loc.clone(memory_format=torch.contiguous_format)
+        m[self.df <= 1] = nan
+        return m
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @property
+    def variance(self):
+        m = self.df.clone(memory_format=torch.contiguous_format)
+        m[self.df > 2] = (
+            self.scale[self.df > 2].pow(2)
+            * self.df[self.df > 2]
+            / (self.df[self.df > 2] - 2)
+        )
+        m[(self.df <= 2) & (self.df > 1)] = inf
+        m[self.df <= 1] = nan
+        return m
+
+    def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
+        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
+        self._chi2 = Chi2(self.df)
+        batch_shape = self.df.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(StudentT, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.df = self.df.expand(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        new._chi2 = self._chi2.expand(batch_shape)
+        super(StudentT, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def rsample(self, sample_shape=torch.Size()):
+        # NOTE: This does not agree with scipy implementation as much as other distributions.
+        # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
+        # parameters seems to help.
+
+        #   X ~ Normal(0, 1)
+        #   Z ~ Chi2(df)
+        #   Y = X / sqrt(Z / df) ~ StudentT(df)
+        shape = self._extended_shape(sample_shape)
+        X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
+        Z = self._chi2.rsample(sample_shape)
+        Y = X * torch.rsqrt(Z / self.df)
+        return self.loc + self.scale * Y
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        y = (value - self.loc) / self.scale
+        Z = (
+            self.scale.log()
+            + 0.5 * self.df.log()
+            + 0.5 * math.log(math.pi)
+            + torch.lgamma(0.5 * self.df)
+            - torch.lgamma(0.5 * (self.df + 1.0))
+        )
+        return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
+
+    def entropy(self):
+        lbeta = (
+            torch.lgamma(0.5 * self.df)
+            + math.lgamma(0.5)
+            - torch.lgamma(0.5 * (self.df + 1))
+        )
+        return (
+            self.scale.log()
+            + 0.5
+            * (self.df + 1)
+            * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
+            + 0.5 * self.df.log()
+            + lbeta
+        )
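The `rsample` above follows the recipe in its comment (`X ~ Normal(0, 1)`, `Z ~ Chi2(df)`, `Y = X / sqrt(Z / df)`). A hedged sanity check that this recipe reproduces the documented variance `df / (df - 2)` for `df > 2`:

```python
import torch
from torch.distributions import Chi2, Normal, StudentT

df = torch.tensor([5.0])
t = StudentT(df)

X = Normal(0.0, 1.0).sample((100_000, 1))
Z = Chi2(df).sample((100_000,))          # shape (100_000, 1)
Y = X * torch.rsqrt(Z / df)

print(Y.var(), t.variance)               # both approximately df / (df - 2) = 5 / 3
```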
diff --git a/MLPY/Lib/site-packages/torch/distributions/transformed_distribution.py b/MLPY/Lib/site-packages/torch/distributions/transformed_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..6165c8cf39087325ae84c42d612f836c7daba2a3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/transformed_distribution.py
@@ -0,0 +1,215 @@
+from typing import Dict
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.independent import Independent
+from torch.distributions.transforms import ComposeTransform, Transform
+from torch.distributions.utils import _sum_rightmost
+
+__all__ = ["TransformedDistribution"]
+
+
+class TransformedDistribution(Distribution):
+    r"""
+    Extension of the Distribution class, which applies a sequence of Transforms
+    to a base distribution.  Let f be the composition of transforms applied::
+
+        X ~ BaseDistribution
+        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
+        log p(Y) = log p(X) + log |det (dX/dY)|
+
+    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
+    maximum shape of its base distribution and its transforms, since transforms
+    can introduce correlations among events.
+
+    An example for the usage of :class:`TransformedDistribution` would be::
+
+        # Building a Logistic Distribution
+        # X ~ Uniform(0, 1)
+        # f = a + b * logit(X)
+        # Y ~ f(X) ~ Logistic(a, b)
+        base_distribution = Uniform(0, 1)
+        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
+        logistic = TransformedDistribution(base_distribution, transforms)
+
+    For more examples, please look at the implementations of
+    :class:`~torch.distributions.gumbel.Gumbel`,
+    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
+    :class:`~torch.distributions.half_normal.HalfNormal`,
+    :class:`~torch.distributions.log_normal.LogNormal`,
+    :class:`~torch.distributions.pareto.Pareto`,
+    :class:`~torch.distributions.weibull.Weibull`,
+    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
+    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
+    """
+    arg_constraints: Dict[str, constraints.Constraint] = {}
+
+    def __init__(self, base_distribution, transforms, validate_args=None):
+        if isinstance(transforms, Transform):
+            self.transforms = [
+                transforms,
+            ]
+        elif isinstance(transforms, list):
+            if not all(isinstance(t, Transform) for t in transforms):
+                raise ValueError(
+                    "transforms must be a Transform or a list of Transforms"
+                )
+            self.transforms = transforms
+        else:
+            raise ValueError(
+                f"transforms must be a Transform or list, but was {transforms}"
+            )
+
+        # Reshape base_distribution according to transforms.
+        base_shape = base_distribution.batch_shape + base_distribution.event_shape
+        base_event_dim = len(base_distribution.event_shape)
+        transform = ComposeTransform(self.transforms)
+        if len(base_shape) < transform.domain.event_dim:
+            raise ValueError(
+                "base_distribution needs to have shape with size at least {}, but got {}.".format(
+                    transform.domain.event_dim, base_shape
+                )
+            )
+        forward_shape = transform.forward_shape(base_shape)
+        expanded_base_shape = transform.inverse_shape(forward_shape)
+        if base_shape != expanded_base_shape:
+            base_batch_shape = expanded_base_shape[
+                : len(expanded_base_shape) - base_event_dim
+            ]
+            base_distribution = base_distribution.expand(base_batch_shape)
+        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
+        if reinterpreted_batch_ndims > 0:
+            base_distribution = Independent(
+                base_distribution, reinterpreted_batch_ndims
+            )
+        self.base_dist = base_distribution
+
+        # Compute shapes.
+        transform_change_in_event_dim = (
+            transform.codomain.event_dim - transform.domain.event_dim
+        )
+        event_dim = max(
+            transform.codomain.event_dim,  # the transform is coupled
+            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
+        )
+        assert len(forward_shape) >= event_dim
+        cut = len(forward_shape) - event_dim
+        batch_shape = forward_shape[:cut]
+        event_shape = forward_shape[cut:]
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(TransformedDistribution, _instance)
+        batch_shape = torch.Size(batch_shape)
+        shape = batch_shape + self.event_shape
+        for t in reversed(self.transforms):
+            shape = t.inverse_shape(shape)
+        base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
+        new.base_dist = self.base_dist.expand(base_batch_shape)
+        new.transforms = self.transforms
+        super(TransformedDistribution, new).__init__(
+            batch_shape, self.event_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    @constraints.dependent_property(is_discrete=False)
+    def support(self):
+        if not self.transforms:
+            return self.base_dist.support
+        support = self.transforms[-1].codomain
+        if len(self.event_shape) > support.event_dim:
+            support = constraints.independent(
+                support, len(self.event_shape) - support.event_dim
+            )
+        return support
+
+    @property
+    def has_rsample(self):
+        return self.base_dist.has_rsample
+
+    def sample(self, sample_shape=torch.Size()):
+        """
+        Generates a sample_shape shaped sample or sample_shape shaped batch of
+        samples if the distribution parameters are batched. Samples first from
+        base distribution and applies `transform()` for every transform in the
+        list.
+        """
+        with torch.no_grad():
+            x = self.base_dist.sample(sample_shape)
+            for transform in self.transforms:
+                x = transform(x)
+            return x
+
+    def rsample(self, sample_shape=torch.Size()):
+        """
+        Generates a sample_shape shaped reparameterized sample or sample_shape
+        shaped batch of reparameterized samples if the distribution parameters
+        are batched. Samples first from base distribution and applies
+        `transform()` for every transform in the list.
+        """
+        x = self.base_dist.rsample(sample_shape)
+        for transform in self.transforms:
+            x = transform(x)
+        return x
+
+    def log_prob(self, value):
+        """
+        Scores the sample by inverting the transform(s) and computing the score
+        using the score of the base distribution and the log abs det jacobian.
+        """
+        if self._validate_args:
+            self._validate_sample(value)
+        event_dim = len(self.event_shape)
+        log_prob = 0.0
+        y = value
+        for transform in reversed(self.transforms):
+            x = transform.inv(y)
+            event_dim += transform.domain.event_dim - transform.codomain.event_dim
+            log_prob = log_prob - _sum_rightmost(
+                transform.log_abs_det_jacobian(x, y),
+                event_dim - transform.domain.event_dim,
+            )
+            y = x
+
+        log_prob = log_prob + _sum_rightmost(
+            self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
+        )
+        return log_prob
+
+    def _monotonize_cdf(self, value):
+        """
+        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
+        monotone increasing.
+        """
+        sign = 1
+        for transform in self.transforms:
+            sign = sign * transform.sign
+        if isinstance(sign, int) and sign == 1:
+            return value
+        return sign * (value - 0.5) + 0.5
+
+    def cdf(self, value):
+        """
+        Computes the cumulative distribution function by inverting the
+        transform(s) and evaluating the cdf of the base distribution.
+        """
+        for transform in self.transforms[::-1]:
+            value = transform.inv(value)
+        if self._validate_args:
+            self.base_dist._validate_sample(value)
+        value = self.base_dist.cdf(value)
+        value = self._monotonize_cdf(value)
+        return value
+
+    def icdf(self, value):
+        """
+        Computes the inverse cumulative distribution function by applying the
+        inverse cdf of the base distribution and then the transform(s).
+        """
+        value = self._monotonize_cdf(value)
+        value = self.base_dist.icdf(value)
+        for transform in self.transforms:
+            value = transform(value)
+        return value
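The Logistic construction in the class docstring can be exercised end to end. A hedged sketch comparing its `log_prob` with the closed-form logistic log-density (the formula below is standard, not taken from this file):

```python
import torch
import torch.nn.functional as F
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform, SigmoidTransform

a, b = 1.0, 2.0
logistic = TransformedDistribution(
    Uniform(0.0, 1.0),
    [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)],
)

x = torch.tensor(0.3)
z = (x - a) / b
# Closed-form Logistic(a, b) log-density: -z - 2 * softplus(-z) - log(b)
manual = -z - 2 * F.softplus(-z) - torch.log(torch.tensor(b))
print(torch.allclose(logistic.log_prob(x), manual))  # expected: True
```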
diff --git a/MLPY/Lib/site-packages/torch/distributions/transforms.py b/MLPY/Lib/site-packages/torch/distributions/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e44f88014fb9810701e9e31b0d63abf15a039f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/transforms.py
@@ -0,0 +1,1245 @@
+import functools
+import math
+import numbers
+import operator
+import weakref
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from torch.distributions import constraints
+from torch.distributions.utils import (
+    _sum_rightmost,
+    broadcast_all,
+    lazy_property,
+    tril_matrix_to_vec,
+    vec_to_tril_matrix,
+)
+from torch.nn.functional import pad, softplus
+
+__all__ = [
+    "AbsTransform",
+    "AffineTransform",
+    "CatTransform",
+    "ComposeTransform",
+    "CorrCholeskyTransform",
+    "CumulativeDistributionTransform",
+    "ExpTransform",
+    "IndependentTransform",
+    "LowerCholeskyTransform",
+    "PositiveDefiniteTransform",
+    "PowerTransform",
+    "ReshapeTransform",
+    "SigmoidTransform",
+    "SoftplusTransform",
+    "TanhTransform",
+    "SoftmaxTransform",
+    "StackTransform",
+    "StickBreakingTransform",
+    "Transform",
+    "identity_transform",
+]
+
+
+class Transform:
+    """
+    Abstract class for invertible transformations with computable log
+    det jacobians. They are primarily used in
+    :class:`torch.distributions.TransformedDistribution`.
+
+    Caching is useful for transforms whose inverses are either expensive or
+    numerically unstable. Note that care must be taken with memoized values
+    since the autograd graph may be reversed. For example while the following
+    works with or without caching::
+
+        y = t(x)
+        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
+
+    However the following will error when caching due to dependency reversal::
+
+        y = t(x)
+        z = t.inv(y)
+        grad(z.sum(), [y])  # error because z is x
+
+    Derived classes should implement one or both of :meth:`_call` or
+    :meth:`_inverse`. Derived classes that set `bijective=True` should also
+    implement :meth:`log_abs_det_jacobian`.
+
+    Args:
+        cache_size (int): Size of cache. If zero, no caching is done. If one,
+            the latest single value is cached. Only 0 and 1 are supported.
+
+    Attributes:
+        domain (:class:`~torch.distributions.constraints.Constraint`):
+            The constraint representing valid inputs to this transform.
+        codomain (:class:`~torch.distributions.constraints.Constraint`):
+            The constraint representing valid outputs to this transform
+            which are inputs to the inverse transform.
+        bijective (bool): Whether this transform is bijective. A transform
+            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
+            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
+            the codomain. Transforms that are not bijective should at least
+            maintain the weaker pseudoinverse properties
+            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
+        sign (int or Tensor): For bijective univariate transforms, this
+            should be +1 or -1 depending on whether transform is monotone
+            increasing or decreasing.
+    """
+
+    bijective = False
+    domain: constraints.Constraint
+    codomain: constraints.Constraint
+
+    def __init__(self, cache_size=0):
+        self._cache_size = cache_size
+        self._inv = None
+        if cache_size == 0:
+            pass  # default behavior
+        elif cache_size == 1:
+            self._cached_x_y = None, None
+        else:
+            raise ValueError("cache_size must be 0 or 1")
+        super().__init__()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_inv"] = None
+        return state
+
+    @property
+    def event_dim(self):
+        if self.domain.event_dim == self.codomain.event_dim:
+            return self.domain.event_dim
+        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
+
+    @property
+    def inv(self):
+        """
+        Returns the inverse :class:`Transform` of this transform.
+        This should satisfy ``t.inv.inv is t``.
+        """
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransform(self)
+            self._inv = weakref.ref(inv)
+        return inv
+
+    @property
+    def sign(self):
+        """
+        Returns the sign of the determinant of the Jacobian, if applicable.
+        In general this only makes sense for bijective transforms.
+        """
+        raise NotImplementedError
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        if type(self).__init__ is Transform.__init__:
+            return type(self)(cache_size=cache_size)
+        raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
+
+    def __eq__(self, other):
+        return self is other
+
+    def __ne__(self, other):
+        # Necessary for Python2
+        return not self.__eq__(other)
+
+    def __call__(self, x):
+        """
+        Computes the transform `x => y`.
+        """
+        if self._cache_size == 0:
+            return self._call(x)
+        x_old, y_old = self._cached_x_y
+        if x is x_old:
+            return y_old
+        y = self._call(x)
+        self._cached_x_y = x, y
+        return y
+
+    def _inv_call(self, y):
+        """
+        Inverts the transform `y => x`.
+        """
+        if self._cache_size == 0:
+            return self._inverse(y)
+        x_old, y_old = self._cached_x_y
+        if y is y_old:
+            return x_old
+        x = self._inverse(y)
+        self._cached_x_y = x, y
+        return x
+
+    def _call(self, x):
+        """
+        Abstract method to compute forward transformation.
+        """
+        raise NotImplementedError
+
+    def _inverse(self, y):
+        """
+        Abstract method to compute inverse transformation.
+        """
+        raise NotImplementedError
+
+    def log_abs_det_jacobian(self, x, y):
+        """
+        Computes the log det jacobian `log |dy/dx|` given input and output.
+        """
+        raise NotImplementedError
+
+    def __repr__(self):
+        return self.__class__.__name__ + "()"
+
+    def forward_shape(self, shape):
+        """
+        Infers the shape of the forward computation, given the input shape.
+        Defaults to preserving shape.
+        """
+        return shape
+
+    def inverse_shape(self, shape):
+        """
+        Infers the shapes of the inverse computation, given the output shape.
+        Defaults to preserving shape.
+        """
+        return shape
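Per the docstring above, a custom transform implements `_call`, `_inverse`, and (when bijective) `log_abs_det_jacobian`. A minimal hedged sketch with a hypothetical `DoubleTransform` (not part of this file):

```python
import math

import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform


class DoubleTransform(Transform):
    """Hypothetical example: y = 2 * x."""

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, DoubleTransform)

    def _call(self, x):
        return 2 * x

    def _inverse(self, y):
        return y / 2

    def log_abs_det_jacobian(self, x, y):
        # |dy/dx| = 2 everywhere, so the log-Jacobian is a constant.
        return torch.full_like(x, math.log(2.0))


t = DoubleTransform()
x = torch.tensor(3.0)
print(t(x), t.inv(t(x)), t.log_abs_det_jacobian(x, t(x)))
```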
+
+
+class _InverseTransform(Transform):
+    """
+    Inverts a single :class:`Transform`.
+    This class is private; please instead use the ``Transform.inv`` property.
+    """
+
+    def __init__(self, transform: Transform):
+        super().__init__(cache_size=transform._cache_size)
+        self._inv: Transform = transform
+
+    @constraints.dependent_property(is_discrete=False)
+    def domain(self):
+        assert self._inv is not None
+        return self._inv.codomain
+
+    @constraints.dependent_property(is_discrete=False)
+    def codomain(self):
+        assert self._inv is not None
+        return self._inv.domain
+
+    @property
+    def bijective(self):
+        assert self._inv is not None
+        return self._inv.bijective
+
+    @property
+    def sign(self):
+        assert self._inv is not None
+        return self._inv.sign
+
+    @property
+    def inv(self):
+        return self._inv
+
+    def with_cache(self, cache_size=1):
+        assert self._inv is not None
+        return self.inv.with_cache(cache_size).inv
+
+    def __eq__(self, other):
+        if not isinstance(other, _InverseTransform):
+            return False
+        assert self._inv is not None
+        return self._inv == other._inv
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self._inv)})"
+
+    def __call__(self, x):
+        assert self._inv is not None
+        return self._inv._inv_call(x)
+
+    def log_abs_det_jacobian(self, x, y):
+        assert self._inv is not None
+        return -self._inv.log_abs_det_jacobian(y, x)
+
+    def forward_shape(self, shape):
+        return self._inv.inverse_shape(shape)
+
+    def inverse_shape(self, shape):
+        return self._inv.forward_shape(shape)
+
+
+class ComposeTransform(Transform):
+    """
+    Composes multiple transforms in a chain.
+    The transforms being composed are responsible for caching.
+
+    Args:
+        parts (list of :class:`Transform`): A list of transforms to compose.
+        cache_size (int): Size of cache. If zero, no caching is done. If one,
+            the latest single value is cached. Only 0 and 1 are supported.
+    """
+
+    def __init__(self, parts: List[Transform], cache_size=0):
+        if cache_size:
+            parts = [part.with_cache(cache_size) for part in parts]
+        super().__init__(cache_size=cache_size)
+        self.parts = parts
+
+    def __eq__(self, other):
+        if not isinstance(other, ComposeTransform):
+            return False
+        return self.parts == other.parts
+
+    @constraints.dependent_property(is_discrete=False)
+    def domain(self):
+        if not self.parts:
+            return constraints.real
+        domain = self.parts[0].domain
+        # Adjust event_dim to be maximum among all parts.
+        event_dim = self.parts[-1].codomain.event_dim
+        for part in reversed(self.parts):
+            event_dim += part.domain.event_dim - part.codomain.event_dim
+            event_dim = max(event_dim, part.domain.event_dim)
+        assert event_dim >= domain.event_dim
+        if event_dim > domain.event_dim:
+            domain = constraints.independent(domain, event_dim - domain.event_dim)
+        return domain
+
+    @constraints.dependent_property(is_discrete=False)
+    def codomain(self):
+        if not self.parts:
+            return constraints.real
+        codomain = self.parts[-1].codomain
+        # Adjust event_dim to be maximum among all parts.
+        event_dim = self.parts[0].domain.event_dim
+        for part in self.parts:
+            event_dim += part.codomain.event_dim - part.domain.event_dim
+            event_dim = max(event_dim, part.codomain.event_dim)
+        assert event_dim >= codomain.event_dim
+        if event_dim > codomain.event_dim:
+            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
+        return codomain
+
+    @lazy_property
+    def bijective(self):
+        return all(p.bijective for p in self.parts)
+
+    @lazy_property
+    def sign(self):
+        sign = 1
+        for p in self.parts:
+            sign = sign * p.sign
+        return sign
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
+            self._inv = weakref.ref(inv)
+            inv._inv = weakref.ref(self)
+        return inv
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return ComposeTransform(self.parts, cache_size=cache_size)
+
+    def __call__(self, x):
+        for part in self.parts:
+            x = part(x)
+        return x
+
+    def log_abs_det_jacobian(self, x, y):
+        if not self.parts:
+            return torch.zeros_like(x)
+
+        # Compute intermediates. This will be free if parts[:-1] are all cached.
+        xs = [x]
+        for part in self.parts[:-1]:
+            xs.append(part(xs[-1]))
+        xs.append(y)
+
+        terms = []
+        event_dim = self.domain.event_dim
+        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
+            terms.append(
+                _sum_rightmost(
+                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
+                )
+            )
+            event_dim += part.codomain.event_dim - part.domain.event_dim
+        return functools.reduce(operator.add, terms)
+
+    def forward_shape(self, shape):
+        for part in self.parts:
+            shape = part.forward_shape(shape)
+        return shape
+
+    def inverse_shape(self, shape):
+        for part in reversed(self.parts):
+            shape = part.inverse_shape(shape)
+        return shape
+
+    def __repr__(self):
+        fmt_string = self.__class__.__name__ + "(\n    "
+        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
+        fmt_string += "\n)"
+        return fmt_string
+
+
+identity_transform = ComposeTransform([])
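A hedged sketch of composition: the composed log-Jacobian is the sum of the parts' log-Jacobians evaluated at the intermediate values, and `identity_transform` is just the empty composition:

```python
import torch
from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    ExpTransform,
    identity_transform,
)

t = ComposeTransform([ExpTransform(), AffineTransform(loc=1.0, scale=3.0)])
x = torch.tensor(0.5)
y = t(x)                                           # 1 + 3 * exp(0.5)

# log|d exp/dx| = x and log|d affine/dz| = log(3), summed over the chain
expected = x + torch.log(torch.tensor(3.0))
print(torch.allclose(t.log_abs_det_jacobian(x, y), expected))  # expected: True
print(identity_transform(x))                       # the empty composition returns x unchanged
```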
+
+
+class IndependentTransform(Transform):
+    """
+    Wrapper around another transform to treat
+    ``reinterpreted_batch_ndims``-many extra of the right most dimensions as
+    dependent. This has no effect on the forward or backward transforms, but
+    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
+    in :meth:`log_abs_det_jacobian`.
+
+    Args:
+        base_transform (:class:`Transform`): A base transform.
+        reinterpreted_batch_ndims (int): The number of extra rightmost
+            dimensions to treat as dependent.
+    """
+
+    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
+        super().__init__(cache_size=cache_size)
+        self.base_transform = base_transform.with_cache(cache_size)
+        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return IndependentTransform(
+            self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
+        )
+
+    @constraints.dependent_property(is_discrete=False)
+    def domain(self):
+        return constraints.independent(
+            self.base_transform.domain, self.reinterpreted_batch_ndims
+        )
+
+    @constraints.dependent_property(is_discrete=False)
+    def codomain(self):
+        return constraints.independent(
+            self.base_transform.codomain, self.reinterpreted_batch_ndims
+        )
+
+    @property
+    def bijective(self):
+        return self.base_transform.bijective
+
+    @property
+    def sign(self):
+        return self.base_transform.sign
+
+    def _call(self, x):
+        if x.dim() < self.domain.event_dim:
+            raise ValueError("Too few dimensions on input")
+        return self.base_transform(x)
+
+    def _inverse(self, y):
+        if y.dim() < self.codomain.event_dim:
+            raise ValueError("Too few dimensions on input")
+        return self.base_transform.inv(y)
+
+    def log_abs_det_jacobian(self, x, y):
+        result = self.base_transform.log_abs_det_jacobian(x, y)
+        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
+        return result
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"
+
+    def forward_shape(self, shape):
+        return self.base_transform.forward_shape(shape)
+
+    def inverse_shape(self, shape):
+        return self.base_transform.inverse_shape(shape)
+
+
+class ReshapeTransform(Transform):
+    """
+    Unit Jacobian transform to reshape the rightmost part of a tensor.
+
+    Note that ``in_shape`` and ``out_shape`` must have the same number of
+    elements, just as for :meth:`torch.Tensor.reshape`.
+
+    Arguments:
+        in_shape (torch.Size): The input event shape.
+        out_shape (torch.Size): The output event shape.
+    """
+
+    bijective = True
+
+    def __init__(self, in_shape, out_shape, cache_size=0):
+        self.in_shape = torch.Size(in_shape)
+        self.out_shape = torch.Size(out_shape)
+        if self.in_shape.numel() != self.out_shape.numel():
+            raise ValueError("in_shape, out_shape have different numbers of elements")
+        super().__init__(cache_size=cache_size)
+
+    @constraints.dependent_property
+    def domain(self):
+        return constraints.independent(constraints.real, len(self.in_shape))
+
+    @constraints.dependent_property
+    def codomain(self):
+        return constraints.independent(constraints.real, len(self.out_shape))
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)
+
+    def _call(self, x):
+        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
+        return x.reshape(batch_shape + self.out_shape)
+
+    def _inverse(self, y):
+        batch_shape = y.shape[: y.dim() - len(self.out_shape)]
+        return y.reshape(batch_shape + self.in_shape)
+
+    def log_abs_det_jacobian(self, x, y):
+        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
+        return x.new_zeros(batch_shape)
+
+    def forward_shape(self, shape):
+        if len(shape) < len(self.in_shape):
+            raise ValueError("Too few dimensions on input")
+        cut = len(shape) - len(self.in_shape)
+        if shape[cut:] != self.in_shape:
+            raise ValueError(
+                f"Shape mismatch: expected {self.in_shape} but got {shape[cut:]}"
+            )
+        return shape[:cut] + self.out_shape
+
+    def inverse_shape(self, shape):
+        if len(shape) < len(self.out_shape):
+            raise ValueError("Too few dimensions on input")
+        cut = len(shape) - len(self.out_shape)
+        if shape[cut:] != self.out_shape:
+            raise ValueError(
+                f"Shape mismatch: expected {self.out_shape} but got {shape[cut:]}"
+            )
+        return shape[:cut] + self.in_shape
+
+
+class ExpTransform(Transform):
+    r"""
+    Transform via the mapping :math:`y = \exp(x)`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, ExpTransform)
+
+    def _call(self, x):
+        return x.exp()
+
+    def _inverse(self, y):
+        return y.log()
+
+    def log_abs_det_jacobian(self, x, y):
+        return x
+
+
+class PowerTransform(Transform):
+    r"""
+    Transform via the mapping :math:`y = x^{\text{exponent}}`.
+    """
+    domain = constraints.positive
+    codomain = constraints.positive
+    bijective = True
+
+    def __init__(self, exponent, cache_size=0):
+        super().__init__(cache_size=cache_size)
+        (self.exponent,) = broadcast_all(exponent)
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return PowerTransform(self.exponent, cache_size=cache_size)
+
+    @lazy_property
+    def sign(self):
+        return self.exponent.sign()
+
+    def __eq__(self, other):
+        if not isinstance(other, PowerTransform):
+            return False
+        return self.exponent.eq(other.exponent).all().item()
+
+    def _call(self, x):
+        return x.pow(self.exponent)
+
+    def _inverse(self, y):
+        return y.pow(1 / self.exponent)
+
+    def log_abs_det_jacobian(self, x, y):
+        return (self.exponent * y / x).abs().log()
+
+    def forward_shape(self, shape):
+        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
+
+    def inverse_shape(self, shape):
+        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
+
+
+def _clipped_sigmoid(x):
+    finfo = torch.finfo(x.dtype)
+    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)
+
+
+class SigmoidTransform(Transform):
+    r"""
+    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
+    """
+    domain = constraints.real
+    codomain = constraints.unit_interval
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, SigmoidTransform)
+
+    def _call(self, x):
+        return _clipped_sigmoid(x)
+
+    def _inverse(self, y):
+        finfo = torch.finfo(y.dtype)
+        y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
+        return y.log() - (-y).log1p()
+
+    def log_abs_det_jacobian(self, x, y):
+        return -F.softplus(-x) - F.softplus(x)
+
+
+class SoftplusTransform(Transform):
+    r"""
+    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
+    The implementation reverts to the linear function when :math:`x > 20`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, SoftplusTransform)
+
+    def _call(self, x):
+        return softplus(x)
+
+    def _inverse(self, y):
+        return (-y).expm1().neg().log() + y
+
+    def log_abs_det_jacobian(self, x, y):
+        return -softplus(-x)
+
+
+class TanhTransform(Transform):
+    r"""
+    Transform via the mapping :math:`y = \tanh(x)`.
+
+    It is equivalent to
+    ```
+    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
+    ```
+    However, the composed form might not be numerically stable, so it is
+    recommended to use `TanhTransform` instead.
+
+    Note that one should use `cache_size=1` to avoid `NaN`/`Inf` values when
+    inverting samples near the boundary.
+
+    """
+    domain = constraints.real
+    codomain = constraints.interval(-1.0, 1.0)
+    bijective = True
+    sign = +1
+
+    def __eq__(self, other):
+        return isinstance(other, TanhTransform)
+
+    def _call(self, x):
+        return x.tanh()
+
+    def _inverse(self, y):
+        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
+        # one should use `cache_size=1` instead
+        return torch.atanh(y)
+
+    def log_abs_det_jacobian(self, x, y):
+        # We use a formula that is more numerically stable, see details in the following link
+        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
+        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
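A hedged numerical check of the equivalence claimed in the docstring: `TanhTransform` matches the composed affine/sigmoid/affine form for both values and log-Jacobians (within float tolerance), while being the numerically safer choice:

```python
import torch
from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    SigmoidTransform,
    TanhTransform,
)

tanh = TanhTransform()
composed = ComposeTransform(
    [AffineTransform(0.0, 2.0), SigmoidTransform(), AffineTransform(-1.0, 2.0)]
)

x = torch.linspace(-3.0, 3.0, 7)
print(torch.allclose(tanh(x), composed(x), atol=1e-6))
print(torch.allclose(
    tanh.log_abs_det_jacobian(x, tanh(x)),
    composed.log_abs_det_jacobian(x, composed(x)),
    atol=1e-5,
))
```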
+
+
+class AbsTransform(Transform):
+    r"""
+    Transform via the mapping :math:`y = |x|`.
+    """
+    domain = constraints.real
+    codomain = constraints.positive
+
+    def __eq__(self, other):
+        return isinstance(other, AbsTransform)
+
+    def _call(self, x):
+        return x.abs()
+
+    def _inverse(self, y):
+        return y
+
+
+class AffineTransform(Transform):
+    r"""
+    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
+
+    Args:
+        loc (Tensor or float): Location parameter.
+        scale (Tensor or float): Scale parameter.
+        event_dim (int): Optional size of `event_shape`. This should be zero
+            for univariate random variables, 1 for distributions over vectors,
+            2 for distributions over matrices, etc.
+    """
+    bijective = True
+
+    def __init__(self, loc, scale, event_dim=0, cache_size=0):
+        super().__init__(cache_size=cache_size)
+        self.loc = loc
+        self.scale = scale
+        self._event_dim = event_dim
+
+    @property
+    def event_dim(self):
+        return self._event_dim
+
+    @constraints.dependent_property(is_discrete=False)
+    def domain(self):
+        if self.event_dim == 0:
+            return constraints.real
+        return constraints.independent(constraints.real, self.event_dim)
+
+    @constraints.dependent_property(is_discrete=False)
+    def codomain(self):
+        if self.event_dim == 0:
+            return constraints.real
+        return constraints.independent(constraints.real, self.event_dim)
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return AffineTransform(
+            self.loc, self.scale, self.event_dim, cache_size=cache_size
+        )
+
+    def __eq__(self, other):
+        if not isinstance(other, AffineTransform):
+            return False
+
+        if isinstance(self.loc, numbers.Number) and isinstance(
+            other.loc, numbers.Number
+        ):
+            if self.loc != other.loc:
+                return False
+        else:
+            if not (self.loc == other.loc).all().item():
+                return False
+
+        if isinstance(self.scale, numbers.Number) and isinstance(
+            other.scale, numbers.Number
+        ):
+            if self.scale != other.scale:
+                return False
+        else:
+            if not (self.scale == other.scale).all().item():
+                return False
+
+        return True
+
+    @property
+    def sign(self):
+        if isinstance(self.scale, numbers.Real):
+            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
+        return self.scale.sign()
+
+    def _call(self, x):
+        return self.loc + self.scale * x
+
+    def _inverse(self, y):
+        return (y - self.loc) / self.scale
+
+    def log_abs_det_jacobian(self, x, y):
+        shape = x.shape
+        scale = self.scale
+        if isinstance(scale, numbers.Real):
+            result = torch.full_like(x, math.log(abs(scale)))
+        else:
+            result = torch.abs(scale).log()
+        if self.event_dim:
+            result_size = result.size()[: -self.event_dim] + (-1,)
+            result = result.view(result_size).sum(-1)
+            shape = shape[: -self.event_dim]
+        return result.expand(shape)
+
+    def forward_shape(self, shape):
+        return torch.broadcast_shapes(
+            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
+        )
+
+    def inverse_shape(self, shape):
+        return torch.broadcast_shapes(
+            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
+        )
+
+
+class CorrCholeskyTransform(Transform):
+    r"""
+    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
+    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
+    triangular matrix with positive diagonals and unit Euclidean norm for each row.
+    The transform is processed as follows:
+
+        1. First we convert x into a lower triangular matrix in row order.
+        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
+           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
+           unit Euclidean length vector using the following steps:
+           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
+           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
+           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
+           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
+    """
+    domain = constraints.real_vector
+    codomain = constraints.corr_cholesky
+    bijective = True
+
+    def _call(self, x):
+        x = torch.tanh(x)
+        eps = torch.finfo(x.dtype).eps
+        x = x.clamp(min=-1 + eps, max=1 - eps)
+        r = vec_to_tril_matrix(x, diag=-1)
+        # apply stick-breaking on the squared values
+        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
+        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
+        z = r**2
+        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
+        # Diagonal elements must be 1.
+        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
+        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
+        return y
+
+    def _inverse(self, y):
+        # inverse stick-breaking
+        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
+        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
+        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
+        y_vec = tril_matrix_to_vec(y, diag=-1)
+        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
+        t = y_vec / (y_cumsum_vec).sqrt()
+        # inverse of tanh
+        x = (t.log1p() - t.neg().log1p()) / 2
+        return x
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        # Because domain and codomain are two spaces with different dimensions, determinant of
+        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
+        # flattened lower triangular part of `y`.
+
+        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
+        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
+        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
+        # also works for 2 x 2 matrix
+        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
+        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
+        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
+        return stick_breaking_logdet + tanh_logdet
+
+    def forward_shape(self, shape):
+        # Reshape from (..., N) to (..., D, D).
+        if len(shape) < 1:
+            raise ValueError("Too few dimensions on input")
+        N = shape[-1]
+        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
+        if D * (D - 1) // 2 != N:
+            raise ValueError("Input is not a flattend lower-diagonal number")
+        return shape[:-1] + (D, D)
+
+    def inverse_shape(self, shape):
+        # Reshape from (..., D, D) to (..., N).
+        if len(shape) < 2:
+            raise ValueError("Too few dimensions on input")
+        if shape[-2] != shape[-1]:
+            raise ValueError("Input is not square")
+        D = shape[-1]
+        N = D * (D - 1) // 2
+        return shape[:-2] + (N,)
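+
+# A minimal round-trip sketch (illustrative only, relying solely on the transform defined
+# above): for D = 3 the input is a real vector of length D*(D-1)/2 = 3 and the output is a
+# 3 x 3 lower-triangular factor whose rows have unit Euclidean norm.
+#
+#     >>> t = CorrCholeskyTransform()
+#     >>> x = torch.randn(3)
+#     >>> L = t(x)
+#     >>> L.shape
+#     torch.Size([3, 3])
+#     >>> bool(torch.allclose(L.norm(dim=-1), torch.ones(3), atol=1e-4))
+#     True
+#     >>> bool(torch.allclose(t.inv(L), x, atol=1e-3))
+#     True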
+
+
+class SoftmaxTransform(Transform):
+    r"""
+    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
+    normalizing.
+
+    This is not bijective and cannot be used for HMC. However this acts mostly
+    coordinate-wise (except for the final normalization), and thus is
+    appropriate for coordinate-wise optimization algorithms.
+    """
+    domain = constraints.real_vector
+    codomain = constraints.simplex
+
+    def __eq__(self, other):
+        return isinstance(other, SoftmaxTransform)
+
+    def _call(self, x):
+        logprobs = x
+        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
+        return probs / probs.sum(-1, True)
+
+    def _inverse(self, y):
+        probs = y
+        return probs.log()
+
+    def forward_shape(self, shape):
+        if len(shape) < 1:
+            raise ValueError("Too few dimensions on input")
+        return shape
+
+    def inverse_shape(self, shape):
+        if len(shape) < 1:
+            raise ValueError("Too few dimensions on input")
+        return shape
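+
+# A minimal sketch of the forward pass (illustrative only): the output lies on the simplex.
+# Because the transform is not bijective, the inverse recovers the input only up to an
+# additive constant along the last dimension.
+#
+#     >>> t = SoftmaxTransform()
+#     >>> y = t(torch.randn(4))
+#     >>> bool(torch.isclose(y.sum(-1), torch.tensor(1.0)))
+#     True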
+
+
+class StickBreakingTransform(Transform):
+    """
+    Transform from unconstrained space to the simplex of one additional
+    dimension via a stick-breaking process.
+
+    This transform arises as an iterated sigmoid transform in a stick-breaking
+    construction of the `Dirichlet` distribution: the first logit is
+    transformed via sigmoid to the first probability and the probability of
+    everything else, and then the process recurses.
+
+    This is bijective and appropriate for use in HMC; however it mixes
+    coordinates together and is less appropriate for optimization.
+    """
+
+    domain = constraints.real_vector
+    codomain = constraints.simplex
+    bijective = True
+
+    def __eq__(self, other):
+        return isinstance(other, StickBreakingTransform)
+
+    def _call(self, x):
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        z = _clipped_sigmoid(x - offset.log())
+        z_cumprod = (1 - z).cumprod(-1)
+        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
+        return y
+
+    def _inverse(self, y):
+        y_crop = y[..., :-1]
+        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
+        sf = 1 - y_crop.cumsum(-1)
+        # we clamp to make sure that sf is positive which sometimes does not
+        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
+        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
+        x = y_crop.log() - sf.log() + offset.log()
+        return x
+
+    def log_abs_det_jacobian(self, x, y):
+        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
+        x = x - offset.log()
+        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
+        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
+        return detJ
+
+    def forward_shape(self, shape):
+        if len(shape) < 1:
+            raise ValueError("Too few dimensions on input")
+        return shape[:-1] + (shape[-1] + 1,)
+
+    def inverse_shape(self, shape):
+        if len(shape) < 1:
+            raise ValueError("Too few dimensions on input")
+        return shape[:-1] + (shape[-1] - 1,)
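+
+# A minimal round-trip sketch (illustrative only): a length-K real vector maps to a point
+# on the (K+1)-dimensional simplex, and the inverse recovers the input to numerical precision.
+#
+#     >>> t = StickBreakingTransform()
+#     >>> x = torch.randn(4)
+#     >>> y = t(x)
+#     >>> y.shape, bool(torch.isclose(y.sum(), torch.tensor(1.0)))
+#     (torch.Size([5]), True)
+#     >>> bool(torch.allclose(t.inv(y), x, atol=1e-3))
+#     True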
+
+
+class LowerCholeskyTransform(Transform):
+    """
+    Transform from unconstrained matrices to lower-triangular matrices with
+    nonnegative diagonal entries.
+
+    This is useful for parameterizing positive definite matrices in terms of
+    their Cholesky factorization.
+    """
+
+    domain = constraints.independent(constraints.real, 2)
+    codomain = constraints.lower_cholesky
+
+    def __eq__(self, other):
+        return isinstance(other, LowerCholeskyTransform)
+
+    def _call(self, x):
+        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()
+
+    def _inverse(self, y):
+        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()
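+
+# A minimal sketch (illustrative only): the output is lower triangular with a strictly
+# positive diagonal, because the diagonal of the input is exponentiated.
+#
+#     >>> t = LowerCholeskyTransform()
+#     >>> L = t(torch.randn(3, 3))
+#     >>> bool((L.triu(1) == 0).all()) and bool((L.diagonal(dim1=-2, dim2=-1) > 0).all())
+#     True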
+
+
+class PositiveDefiniteTransform(Transform):
+    """
+    Transform from unconstrained matrices to positive-definite matrices.
+    """
+
+    domain = constraints.independent(constraints.real, 2)
+    codomain = constraints.positive_definite  # type: ignore[assignment]
+
+    def __eq__(self, other):
+        return isinstance(other, PositiveDefiniteTransform)
+
+    def _call(self, x):
+        x = LowerCholeskyTransform()(x)
+        return x @ x.mT
+
+    def _inverse(self, y):
+        y = torch.linalg.cholesky(y)
+        return LowerCholeskyTransform().inv(y)
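+
+# A minimal sketch (illustrative only): any square matrix maps to a symmetric
+# positive-definite matrix, verified here via a successful Cholesky factorization.
+#
+#     >>> t = PositiveDefiniteTransform()
+#     >>> A = t(torch.randn(3, 3))
+#     >>> bool(torch.allclose(A, A.mT)) and bool(torch.linalg.cholesky_ex(A).info.eq(0))
+#     True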
+
+
+class CatTransform(Transform):
+    """
+    Transform functor that applies a sequence of transforms `tseq`
+    component-wise to each submatrix at `dim`, of length `lengths[dim]`,
+    in a way compatible with :func:`torch.cat`.
+
+    Example::
+
+       x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
+       x = torch.cat([x0, x0], dim=0)
+       t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
+       t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
+       y = t(x)
+    """
+
+    transforms: List[Transform]
+
+    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
+        assert all(isinstance(t, Transform) for t in tseq)
+        if cache_size:
+            tseq = [t.with_cache(cache_size) for t in tseq]
+        super().__init__(cache_size=cache_size)
+        self.transforms = list(tseq)
+        if lengths is None:
+            lengths = [1] * len(self.transforms)
+        self.lengths = list(lengths)
+        assert len(self.lengths) == len(self.transforms)
+        self.dim = dim
+
+    @lazy_property
+    def event_dim(self):
+        return max(t.event_dim for t in self.transforms)
+
+    @lazy_property
+    def length(self):
+        return sum(self.lengths)
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)
+
+    def _call(self, x):
+        assert -x.dim() <= self.dim < x.dim()
+        assert x.size(self.dim) == self.length
+        yslices = []
+        start = 0
+        for trans, length in zip(self.transforms, self.lengths):
+            xslice = x.narrow(self.dim, start, length)
+            yslices.append(trans(xslice))
+            start = start + length  # avoid += for jit compat
+        return torch.cat(yslices, dim=self.dim)
+
+    def _inverse(self, y):
+        assert -y.dim() <= self.dim < y.dim()
+        assert y.size(self.dim) == self.length
+        xslices = []
+        start = 0
+        for trans, length in zip(self.transforms, self.lengths):
+            yslice = y.narrow(self.dim, start, length)
+            xslices.append(trans.inv(yslice))
+            start = start + length  # avoid += for jit compat
+        return torch.cat(xslices, dim=self.dim)
+
+    def log_abs_det_jacobian(self, x, y):
+        assert -x.dim() <= self.dim < x.dim()
+        assert x.size(self.dim) == self.length
+        assert -y.dim() <= self.dim < y.dim()
+        assert y.size(self.dim) == self.length
+        logdetjacs = []
+        start = 0
+        for trans, length in zip(self.transforms, self.lengths):
+            xslice = x.narrow(self.dim, start, length)
+            yslice = y.narrow(self.dim, start, length)
+            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
+            if trans.event_dim < self.event_dim:
+                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
+            logdetjacs.append(logdetjac)
+            start = start + length  # avoid += for jit compat
+        # Decide whether to concatenate or sum.
+        dim = self.dim
+        if dim >= 0:
+            dim = dim - x.dim()
+        dim = dim + self.event_dim
+        if dim < 0:
+            return torch.cat(logdetjacs, dim=dim)
+        else:
+            return sum(logdetjacs)
+
+    @property
+    def bijective(self):
+        return all(t.bijective for t in self.transforms)
+
+    @constraints.dependent_property
+    def domain(self):
+        return constraints.cat(
+            [t.domain for t in self.transforms], self.dim, self.lengths
+        )
+
+    @constraints.dependent_property
+    def codomain(self):
+        return constraints.cat(
+            [t.codomain for t in self.transforms], self.dim, self.lengths
+        )
+
+
+class StackTransform(Transform):
+    """
+    Transform functor that applies a sequence of transforms `tseq`
+    component-wise to each submatrix at `dim`
+    in a way compatible with :func:`torch.stack`.
+
+    Example::
+
+       x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
+       t = StackTransform([ExpTransform(), identity_transform], dim=1)
+       y = t(x)
+    """
+
+    transforms: List[Transform]
+
+    def __init__(self, tseq, dim=0, cache_size=0):
+        assert all(isinstance(t, Transform) for t in tseq)
+        if cache_size:
+            tseq = [t.with_cache(cache_size) for t in tseq]
+        super().__init__(cache_size=cache_size)
+        self.transforms = list(tseq)
+        self.dim = dim
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return StackTransform(self.transforms, self.dim, cache_size)
+
+    def _slice(self, z):
+        return [z.select(self.dim, i) for i in range(z.size(self.dim))]
+
+    def _call(self, x):
+        assert -x.dim() <= self.dim < x.dim()
+        assert x.size(self.dim) == len(self.transforms)
+        yslices = []
+        for xslice, trans in zip(self._slice(x), self.transforms):
+            yslices.append(trans(xslice))
+        return torch.stack(yslices, dim=self.dim)
+
+    def _inverse(self, y):
+        assert -y.dim() <= self.dim < y.dim()
+        assert y.size(self.dim) == len(self.transforms)
+        xslices = []
+        for yslice, trans in zip(self._slice(y), self.transforms):
+            xslices.append(trans.inv(yslice))
+        return torch.stack(xslices, dim=self.dim)
+
+    def log_abs_det_jacobian(self, x, y):
+        assert -x.dim() <= self.dim < x.dim()
+        assert x.size(self.dim) == len(self.transforms)
+        assert -y.dim() <= self.dim < y.dim()
+        assert y.size(self.dim) == len(self.transforms)
+        logdetjacs = []
+        yslices = self._slice(y)
+        xslices = self._slice(x)
+        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
+            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
+        return torch.stack(logdetjacs, dim=self.dim)
+
+    @property
+    def bijective(self):
+        return all(t.bijective for t in self.transforms)
+
+    @constraints.dependent_property
+    def domain(self):
+        return constraints.stack([t.domain for t in self.transforms], self.dim)
+
+    @constraints.dependent_property
+    def codomain(self):
+        return constraints.stack([t.codomain for t in self.transforms], self.dim)
+
+
+class CumulativeDistributionTransform(Transform):
+    """
+    Transform via the cumulative distribution function of a probability distribution.
+
+    Args:
+        distribution (Distribution): Distribution whose cumulative distribution function to use for
+            the transformation.
+
+    Example::
+
+        # Construct a Gaussian copula from a multivariate normal.
+        base_dist = MultivariateNormal(
+            loc=torch.zeros(2),
+            scale_tril=LKJCholesky(2).sample(),
+        )
+        transform = CumulativeDistributionTransform(Normal(0, 1))
+        copula = TransformedDistribution(base_dist, [transform])
+    """
+
+    bijective = True
+    codomain = constraints.unit_interval
+    sign = +1
+
+    def __init__(self, distribution, cache_size=0):
+        super().__init__(cache_size=cache_size)
+        self.distribution = distribution
+
+    @property
+    def domain(self):
+        return self.distribution.support
+
+    def _call(self, x):
+        return self.distribution.cdf(x)
+
+    def _inverse(self, y):
+        return self.distribution.icdf(y)
+
+    def log_abs_det_jacobian(self, x, y):
+        return self.distribution.log_prob(x)
+
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)
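+
+# A minimal sketch (illustrative only, assuming ``from torch.distributions import Normal``):
+# the forward pass is the CDF, the inverse is the quantile function, and the log-determinant
+# is the log density of the wrapped distribution.
+#
+#     >>> t = CumulativeDistributionTransform(Normal(0.0, 1.0))
+#     >>> t(torch.tensor(0.0))          # CDF at the mean
+#     tensor(0.5000)
+#     >>> t.inv(torch.tensor(0.5))      # quantile function
+#     tensor(0.)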
diff --git a/MLPY/Lib/site-packages/torch/distributions/uniform.py b/MLPY/Lib/site-packages/torch/distributions/uniform.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e3e726db7e71f237f0e9533fef85a04255d697
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/uniform.py
@@ -0,0 +1,99 @@
+from numbers import Number
+
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Uniform"]
+
+
+class Uniform(Distribution):
+    r"""
+    Generates uniformly distributed random samples from the half-open interval
+    ``[low, high)``.
+
+    Example::
+
+        >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
+        >>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
+        >>> # xdoctest: +SKIP
+        tensor([ 2.3418])
+
+    Args:
+        low (float or Tensor): lower range (inclusive).
+        high (float or Tensor): upper range (exclusive).
+    """
+    # TODO allow (loc,scale) parameterization to allow independent constraints.
+    arg_constraints = {
+        "low": constraints.dependent(is_discrete=False, event_dim=0),
+        "high": constraints.dependent(is_discrete=False, event_dim=0),
+    }
+    has_rsample = True
+
+    @property
+    def mean(self):
+        return (self.high + self.low) / 2
+
+    @property
+    def mode(self):
+        return nan * self.high
+
+    @property
+    def stddev(self):
+        return (self.high - self.low) / 12**0.5
+
+    @property
+    def variance(self):
+        return (self.high - self.low).pow(2) / 12
+
+    def __init__(self, low, high, validate_args=None):
+        self.low, self.high = broadcast_all(low, high)
+
+        if isinstance(low, Number) and isinstance(high, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.low.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+        if self._validate_args and not torch.lt(self.low, self.high).all():
+            raise ValueError("Uniform is not defined when low>= high")
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Uniform, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.low = self.low.expand(batch_shape)
+        new.high = self.high.expand(batch_shape)
+        super(Uniform, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):
+        return constraints.interval(self.low, self.high)
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
+        return self.low + rand * (self.high - self.low)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        lb = self.low.le(value).type_as(self.low)
+        ub = self.high.gt(value).type_as(self.low)
+        return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        result = (value - self.low) / (self.high - self.low)
+        return result.clamp(min=0, max=1)
+
+    def icdf(self, value):
+        result = value * (self.high - self.low) + self.low
+        return result
+
+    def entropy(self):
+        return torch.log(self.high - self.low)
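+
+# A quick sanity check (illustrative only): for ``Uniform(0, 2)`` the density is 1/2 on the
+# support, the CDF is linear, and the entropy is ``log(high - low)``.
+#
+#     >>> u = Uniform(torch.tensor(0.0), torch.tensor(2.0))
+#     >>> u.log_prob(torch.tensor(1.0))     # log(1 / 2)
+#     tensor(-0.6931)
+#     >>> u.cdf(torch.tensor(0.5))
+#     tensor(0.2500)
+#     >>> u.entropy()
+#     tensor(0.6931)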
diff --git a/MLPY/Lib/site-packages/torch/distributions/utils.py b/MLPY/Lib/site-packages/torch/distributions/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..146ae2adc7e4f917ffb65b82549712f3e7745565
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/utils.py
@@ -0,0 +1,177 @@
+from functools import update_wrapper
+from numbers import Number
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from torch.overrides import is_tensor_like
+
+euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
+
+__all__ = [
+    "broadcast_all",
+    "logits_to_probs",
+    "clamp_probs",
+    "probs_to_logits",
+    "lazy_property",
+    "tril_matrix_to_vec",
+    "vec_to_tril_matrix",
+]
+
+
+def broadcast_all(*values):
+    r"""
+    Given a list of values (possibly containing numbers), returns a list where each
+    value is broadcasted based on the following rules:
+      - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
+      - numbers.Number instances (scalars) are upcast to tensors having
+        the same size and type as the first tensor passed to `values`.  If all the
+        values are scalars, then they are upcasted to scalar Tensors.
+
+    Args:
+        values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__)
+
+    Raises:
+        ValueError: if any of the values is not a `numbers.Number` instance,
+            a `torch.*Tensor` instance, or an instance implementing __torch_function__
+    """
+    if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
+        raise ValueError(
+            "Input arguments must all be instances of numbers.Number, "
+            "torch.Tensor or objects implementing __torch_function__."
+        )
+    if not all(is_tensor_like(v) for v in values):
+        options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
+        for value in values:
+            if isinstance(value, torch.Tensor):
+                options = dict(dtype=value.dtype, device=value.device)
+                break
+        new_values = [
+            v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
+        ]
+        return torch.broadcast_tensors(*new_values)
+    return torch.broadcast_tensors(*values)
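+
+# A minimal sketch (illustrative only): Python scalars are promoted to tensors that share
+# the dtype/device of the first tensor argument, then everything is broadcast together.
+#
+#     >>> a, b = broadcast_all(torch.zeros(2, 3), 1.0)
+#     >>> a.shape, b.shape
+#     (torch.Size([2, 3]), torch.Size([2, 3]))
+#     >>> b.dtype == a.dtype
+#     True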
+
+
+def _standard_normal(shape, dtype, device):
+    if torch._C._get_tracing_state():
+        # [JIT WORKAROUND] lack of support for .normal_()
+        return torch.normal(
+            torch.zeros(shape, dtype=dtype, device=device),
+            torch.ones(shape, dtype=dtype, device=device),
+        )
+    return torch.empty(shape, dtype=dtype, device=device).normal_()
+
+
+def _sum_rightmost(value, dim):
+    r"""
+    Sum out ``dim`` many rightmost dimensions of a given tensor.
+
+    Args:
+        value (Tensor): A tensor of ``.dim()`` at least ``dim``.
+        dim (int): The number of rightmost dims to sum out.
+    """
+    if dim == 0:
+        return value
+    required_shape = value.shape[:-dim] + (-1,)
+    return value.reshape(required_shape).sum(-1)
+
+
+def logits_to_probs(logits, is_binary=False):
+    r"""
+    Converts a tensor of logits into probabilities. Note that for the
+    binary case, each value denotes log odds, whereas for the
+    multi-dimensional case, the values along the last dimension denote
+    the log probabilities (possibly unnormalized) of the events.
+    """
+    if is_binary:
+        return torch.sigmoid(logits)
+    return F.softmax(logits, dim=-1)
+
+
+def clamp_probs(probs):
+    eps = torch.finfo(probs.dtype).eps
+    return probs.clamp(min=eps, max=1 - eps)
+
+
+def probs_to_logits(probs, is_binary=False):
+    r"""
+    Converts a tensor of probabilities into logits. For the binary case,
+    this denotes the probability of occurrence of the event indexed by `1`.
+    For the multi-dimensional case, the values along the last dimension
+    denote the probabilities of occurrence of each of the events.
+    """
+    ps_clamped = clamp_probs(probs)
+    if is_binary:
+        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
+    return torch.log(ps_clamped)
+
+
+class lazy_property:
+    r"""
+    Used as a decorator for lazy loading of class attributes. This uses a
+    non-data descriptor that calls the wrapped method to compute the property on
+    first call; thereafter the computed value replaces the wrapped method as an
+    instance attribute.
+    """
+
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+        update_wrapper(self, wrapped)
+
+    def __get__(self, instance, obj_type=None):
+        if instance is None:
+            return _lazy_property_and_property(self.wrapped)
+        with torch.enable_grad():
+            value = self.wrapped(instance)
+        setattr(instance, self.wrapped.__name__, value)
+        return value
+
+
+class _lazy_property_and_property(lazy_property, property):
+    """We want lazy properties to look like multiple things.
+
+    * property when Sphinx autodoc looks
+    * lazy_property when Distribution validate_args looks
+    """
+
+    def __init__(self, wrapped):
+        property.__init__(self, wrapped)
+
+
+def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
+    r"""
+    Convert a `D x D` matrix or a batch of matrices into a (batched) vector
+    which consists of the lower triangular elements of the matrix in row order.
+    """
+    n = mat.shape[-1]
+    if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
+        raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].")
+    arange = torch.arange(n, device=mat.device)
+    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
+    vec = mat[..., tril_mask]
+    return vec
+
+
+def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
+    r"""
+    Convert a vector or a batch of vectors into a batched `D x D`
+    lower triangular matrix containing elements from the vector in row order.
+    """
+    # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
+    n = (
+        -(1 + 2 * diag)
+        + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
+    ) / 2
+    eps = torch.finfo(vec.dtype).eps
+    if not torch._C._get_tracing_state() and (round(n) - n > eps):
+        raise ValueError(
+            f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
+            + "the lower triangular part of a square D x D matrix."
+        )
+    n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
+    mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
+    arange = torch.arange(n, device=vec.device)
+    tril_mask = arange < arange.view(-1, 1) + (diag + 1)
+    mat[..., tril_mask] = vec
+    return mat
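+
+# A minimal round-trip sketch (illustrative only): with ``diag=-1`` only the strictly lower
+# triangle is packed into the vector, and the matrix helper rebuilds a square matrix of the
+# inferred size.
+#
+#     >>> m = torch.arange(9.).reshape(3, 3)
+#     >>> v = tril_matrix_to_vec(m, diag=-1)
+#     >>> v
+#     tensor([3., 6., 7.])
+#     >>> vec_to_tril_matrix(v, diag=-1).shape
+#     torch.Size([3, 3])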
diff --git a/MLPY/Lib/site-packages/torch/distributions/von_mises.py b/MLPY/Lib/site-packages/torch/distributions/von_mises.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6fdef81b67cbcb4edd60d353d99d57bb7a09e3a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/von_mises.py
@@ -0,0 +1,209 @@
+import math
+
+import torch
+import torch.jit
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all, lazy_property
+
+__all__ = ["VonMises"]
+
+
+def _eval_poly(y, coef):
+    coef = list(coef)
+    result = coef.pop()
+    while coef:
+        result = coef.pop() + y * result
+    return result
+
+
+_I0_COEF_SMALL = [
+    1.0,
+    3.5156229,
+    3.0899424,
+    1.2067492,
+    0.2659732,
+    0.360768e-1,
+    0.45813e-2,
+]
+_I0_COEF_LARGE = [
+    0.39894228,
+    0.1328592e-1,
+    0.225319e-2,
+    -0.157565e-2,
+    0.916281e-2,
+    -0.2057706e-1,
+    0.2635537e-1,
+    -0.1647633e-1,
+    0.392377e-2,
+]
+_I1_COEF_SMALL = [
+    0.5,
+    0.87890594,
+    0.51498869,
+    0.15084934,
+    0.2658733e-1,
+    0.301532e-2,
+    0.32411e-3,
+]
+_I1_COEF_LARGE = [
+    0.39894228,
+    -0.3988024e-1,
+    -0.362018e-2,
+    0.163801e-2,
+    -0.1031555e-1,
+    0.2282967e-1,
+    -0.2895312e-1,
+    0.1787654e-1,
+    -0.420059e-2,
+]
+
+_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
+_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
+
+
+def _log_modified_bessel_fn(x, order=0):
+    """
+    Returns ``log(I_order(x))`` for ``x > 0``,
+    where `order` is either 0 or 1.
+    """
+    assert order == 0 or order == 1
+
+    # compute small solution
+    y = x / 3.75
+    y = y * y
+    small = _eval_poly(y, _COEF_SMALL[order])
+    if order == 1:
+        small = x.abs() * small
+    small = small.log()
+
+    # compute large solution
+    y = 3.75 / x
+    large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
+
+    result = torch.where(x < 3.75, small, large)
+    return result
+
+
+@torch.jit.script_if_tracing
+def _rejection_sample(loc, concentration, proposal_r, x):
+    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
+    while not done.all():
+        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
+        u1, u2, u3 = u.unbind()
+        z = torch.cos(math.pi * u1)
+        f = (1 + proposal_r * z) / (proposal_r + z)
+        c = concentration * (proposal_r - f)
+        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
+        if accept.any():
+            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
+            done = done | accept
+    return (x + math.pi + loc) % (2 * math.pi) - math.pi
+
+
+class VonMises(Distribution):
+    """
+    A circular von Mises distribution.
+
+    This implementation uses polar coordinates. The ``loc`` and ``value`` args
+    can be any real number (to facilitate unconstrained optimization), but are
+    interpreted as angles modulo 2 pi.
+
+    Example::
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
+        tensor([1.9777])
+
+    :param torch.Tensor loc: an angle in radians.
+    :param torch.Tensor concentration: concentration parameter
+    """
+
+    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
+    support = constraints.real
+    has_rsample = False
+
+    def __init__(self, loc, concentration, validate_args=None):
+        self.loc, self.concentration = broadcast_all(loc, concentration)
+        batch_shape = self.loc.shape
+        event_shape = torch.Size()
+        super().__init__(batch_shape, event_shape, validate_args)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        log_prob = self.concentration * torch.cos(value - self.loc)
+        log_prob = (
+            log_prob
+            - math.log(2 * math.pi)
+            - _log_modified_bessel_fn(self.concentration, order=0)
+        )
+        return log_prob
+
+    @lazy_property
+    def _loc(self):
+        return self.loc.to(torch.double)
+
+    @lazy_property
+    def _concentration(self):
+        return self.concentration.to(torch.double)
+
+    @lazy_property
+    def _proposal_r(self):
+        kappa = self._concentration
+        tau = 1 + (1 + 4 * kappa**2).sqrt()
+        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
+        _proposal_r = (1 + rho**2) / (2 * rho)
+        # second order Taylor expansion around 0 for small kappa
+        _proposal_r_taylor = 1 / kappa + kappa
+        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
+
+    @torch.no_grad()
+    def sample(self, sample_shape=torch.Size()):
+        """
+        The sampling algorithm for the von Mises distribution is based on the
+        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
+        von Mises distribution." Applied Statistics (1979): 152-157.
+
+        Sampling is always done in double precision internally to avoid a hang
+        in _rejection_sample() for small values of the concentration, which
+        starts to happen for single precision around 1e-4 (see issue #88443).
+        """
+        shape = self._extended_shape(sample_shape)
+        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
+        return _rejection_sample(
+            self._loc, self._concentration, self._proposal_r, x
+        ).to(self.loc.dtype)
+
+    def expand(self, batch_shape):
+        try:
+            return super().expand(batch_shape)
+        except NotImplementedError:
+            validate_args = self.__dict__.get("_validate_args")
+            loc = self.loc.expand(batch_shape)
+            concentration = self.concentration.expand(batch_shape)
+            return type(self)(loc, concentration, validate_args=validate_args)
+
+    @property
+    def mean(self):
+        """
+        The provided mean is the circular one.
+        """
+        return self.loc
+
+    @property
+    def mode(self):
+        return self.loc
+
+    @lazy_property
+    def variance(self):
+        """
+        The provided variance is the circular one.
+        """
+        return (
+            1
+            - (
+                _log_modified_bessel_fn(self.concentration, order=1)
+                - _log_modified_bessel_fn(self.concentration, order=0)
+            ).exp()
+        )
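+
+# A quick sanity check (illustrative only): samples are angles wrapped to ``[-pi, pi)`` and
+# the sample shape is ``sample_shape + batch_shape``.
+#
+#     >>> m = VonMises(torch.tensor([0.0]), torch.tensor([4.0]))
+#     >>> s = m.sample((1000,))
+#     >>> s.shape
+#     torch.Size([1000, 1])
+#     >>> bool(((s >= -math.pi) & (s < math.pi)).all())
+#     True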
diff --git a/MLPY/Lib/site-packages/torch/distributions/weibull.py b/MLPY/Lib/site-packages/torch/distributions/weibull.py
new file mode 100644
index 0000000000000000000000000000000000000000..3277175e74c9723d256e4e3c9ee3141aa994cf81
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/weibull.py
@@ -0,0 +1,83 @@
+import torch
+from torch.distributions import constraints
+from torch.distributions.exponential import Exponential
+from torch.distributions.gumbel import euler_constant
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, PowerTransform
+from torch.distributions.utils import broadcast_all
+
+__all__ = ["Weibull"]
+
+
+class Weibull(TransformedDistribution):
+    r"""
+    Samples from a two-parameter Weibull distribution.
+
+    Example:
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
+        tensor([ 0.4784])
+
+    Args:
+        scale (float or Tensor): Scale parameter of distribution (lambda).
+        concentration (float or Tensor): Concentration parameter of distribution (k/shape).
+    """
+    arg_constraints = {
+        "scale": constraints.positive,
+        "concentration": constraints.positive,
+    }
+    support = constraints.positive
+
+    def __init__(self, scale, concentration, validate_args=None):
+        self.scale, self.concentration = broadcast_all(scale, concentration)
+        self.concentration_reciprocal = self.concentration.reciprocal()
+        base_dist = Exponential(
+            torch.ones_like(self.scale), validate_args=validate_args
+        )
+        transforms = [
+            PowerTransform(exponent=self.concentration_reciprocal),
+            AffineTransform(loc=0, scale=self.scale),
+        ]
+        super().__init__(base_dist, transforms, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Weibull, _instance)
+        new.scale = self.scale.expand(batch_shape)
+        new.concentration = self.concentration.expand(batch_shape)
+        new.concentration_reciprocal = new.concentration.reciprocal()
+        base_dist = self.base_dist.expand(batch_shape)
+        transforms = [
+            PowerTransform(exponent=new.concentration_reciprocal),
+            AffineTransform(loc=0, scale=new.scale),
+        ]
+        super(Weibull, new).__init__(base_dist, transforms, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @property
+    def mean(self):
+        return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal))
+
+    @property
+    def mode(self):
+        return (
+            self.scale
+            * ((self.concentration - 1) / self.concentration)
+            ** self.concentration.reciprocal()
+        )
+
+    @property
+    def variance(self):
+        return self.scale.pow(2) * (
+            torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal))
+            - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal))
+        )
+
+    def entropy(self):
+        return (
+            euler_constant * (1 - self.concentration_reciprocal)
+            + torch.log(self.scale * self.concentration_reciprocal)
+            + 1
+        )
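+
+# A quick sanity check (illustrative only): with ``concentration=1`` the Weibull reduces to
+# an exponential distribution, so the mean is ``scale`` and the variance is ``scale**2``.
+#
+#     >>> w = Weibull(torch.tensor(2.0), torch.tensor(1.0))
+#     >>> bool(torch.isclose(w.mean, torch.tensor(2.0)))
+#     True
+#     >>> bool(torch.isclose(w.variance, torch.tensor(4.0)))
+#     True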
diff --git a/MLPY/Lib/site-packages/torch/distributions/wishart.py b/MLPY/Lib/site-packages/torch/distributions/wishart.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec018b5caa33871c17479c726e10eff5244fc38
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/distributions/wishart.py
@@ -0,0 +1,335 @@
+import math
+import warnings
+from numbers import Number
+from typing import Optional, Union
+
+import torch
+from torch import nan
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.multivariate_normal import _precision_to_scale_tril
+from torch.distributions.utils import lazy_property
+
+
+__all__ = ["Wishart"]
+
+_log_2 = math.log(2)
+
+
+def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
+    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
+    return torch.digamma(
+        x.unsqueeze(-1)
+        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
+    ).sum(-1)
+
+
+def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
+    # We assume positive input for this function
+    return x.clamp(min=torch.finfo(x.dtype).eps)
+
+
+class Wishart(ExponentialFamily):
+    r"""
+    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
+    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`
+
+    Example:
+        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
+        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
+        >>> m.sample()  # Wishart distributed with mean=`df * I` and
+        >>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
+
+    Args:
+        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
+        covariance_matrix (Tensor): positive-definite covariance matrix
+        precision_matrix (Tensor): positive-definite precision matrix
+        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
+    Note:
+        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
+        :attr:`scale_tril` can be specified.
+        Using :attr:`scale_tril` will be more efficient: all computations internally
+        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
+        :attr:`precision_matrix` is passed instead, it is only used to compute
+        the corresponding lower triangular matrices using a Cholesky decomposition.
+        :class:`torch.distributions.LKJCholesky` is a restricted Wishart distribution [1].
+
+    **References**
+
+    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
+    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
+    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
+    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
+    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
+    """
+    arg_constraints = {
+        "covariance_matrix": constraints.positive_definite,
+        "precision_matrix": constraints.positive_definite,
+        "scale_tril": constraints.lower_cholesky,
+        "df": constraints.greater_than(0),
+    }
+    support = constraints.positive_definite
+    has_rsample = True
+    _mean_carrier_measure = 0
+
+    def __init__(
+        self,
+        df: Union[torch.Tensor, Number],
+        covariance_matrix: Optional[torch.Tensor] = None,
+        precision_matrix: Optional[torch.Tensor] = None,
+        scale_tril: Optional[torch.Tensor] = None,
+        validate_args=None,
+    ):
+        assert (covariance_matrix is not None) + (scale_tril is not None) + (
+            precision_matrix is not None
+        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
+
+        param = next(
+            p
+            for p in (covariance_matrix, precision_matrix, scale_tril)
+            if p is not None
+        )
+
+        if param.dim() < 2:
+            raise ValueError(
+                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
+            )
+
+        if isinstance(df, Number):
+            batch_shape = torch.Size(param.shape[:-2])
+            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
+        else:
+            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
+            self.df = df.expand(batch_shape)
+        event_shape = param.shape[-2:]
+
+        if self.df.le(event_shape[-1] - 1).any():
+            raise ValueError(
+                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
+            )
+
+        if scale_tril is not None:
+            self.scale_tril = param.expand(batch_shape + (-1, -1))
+        elif covariance_matrix is not None:
+            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
+        elif precision_matrix is not None:
+            self.precision_matrix = param.expand(batch_shape + (-1, -1))
+
+        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
+        if self.df.lt(event_shape[-1]).any():
+            warnings.warn(
+                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
+            )
+
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
+        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]
+
+        if scale_tril is not None:
+            self._unbroadcasted_scale_tril = scale_tril
+        elif covariance_matrix is not None:
+            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
+        else:  # precision_matrix is not None
+            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
+
+        # Chi2 distribution is needed for Bartlett decomposition sampling
+        self._dist_chi2 = torch.distributions.chi2.Chi2(
+            df=(
+                self.df.unsqueeze(-1)
+                - torch.arange(
+                    self._event_shape[-1],
+                    dtype=self._unbroadcasted_scale_tril.dtype,
+                    device=self._unbroadcasted_scale_tril.device,
+                ).expand(batch_shape + (-1,))
+            )
+        )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(Wishart, _instance)
+        batch_shape = torch.Size(batch_shape)
+        cov_shape = batch_shape + self.event_shape
+        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
+        new.df = self.df.expand(batch_shape)
+
+        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]
+
+        if "covariance_matrix" in self.__dict__:
+            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
+        if "scale_tril" in self.__dict__:
+            new.scale_tril = self.scale_tril.expand(cov_shape)
+        if "precision_matrix" in self.__dict__:
+            new.precision_matrix = self.precision_matrix.expand(cov_shape)
+
+        # Chi2 distribution is needed for Bartlett decomposition sampling
+        new._dist_chi2 = torch.distributions.chi2.Chi2(
+            df=(
+                new.df.unsqueeze(-1)
+                - torch.arange(
+                    self.event_shape[-1],
+                    dtype=new._unbroadcasted_scale_tril.dtype,
+                    device=new._unbroadcasted_scale_tril.device,
+                ).expand(batch_shape + (-1,))
+            )
+        )
+
+        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    @lazy_property
+    def scale_tril(self):
+        return self._unbroadcasted_scale_tril.expand(
+            self._batch_shape + self._event_shape
+        )
+
+    @lazy_property
+    def covariance_matrix(self):
+        return (
+            self._unbroadcasted_scale_tril
+            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
+        ).expand(self._batch_shape + self._event_shape)
+
+    @lazy_property
+    def precision_matrix(self):
+        identity = torch.eye(
+            self._event_shape[-1],
+            device=self._unbroadcasted_scale_tril.device,
+            dtype=self._unbroadcasted_scale_tril.dtype,
+        )
+        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
+            self._batch_shape + self._event_shape
+        )
+
+    @property
+    def mean(self):
+        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix
+
+    @property
+    def mode(self):
+        factor = self.df - self.covariance_matrix.shape[-1] - 1
+        factor[factor <= 0] = nan
+        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix
+
+    @property
+    def variance(self):
+        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
+        diag_V = V.diagonal(dim1=-2, dim2=-1)
+        return self.df.view(self._batch_shape + (1, 1)) * (
+            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
+        )
+
+    def _bartlett_sampling(self, sample_shape=torch.Size()):
+        p = self._event_shape[-1]  # has singleton shape
+
+        # Implemented Sampling using Bartlett decomposition
+        noise = _clamp_above_eps(
+            self._dist_chi2.rsample(sample_shape).sqrt()
+        ).diag_embed(dim1=-2, dim2=-1)
+
+        i, j = torch.tril_indices(p, p, offset=-1)
+        noise[..., i, j] = torch.randn(
+            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
+            dtype=noise.dtype,
+            device=noise.device,
+        )
+        chol = self._unbroadcasted_scale_tril @ noise
+        return chol @ chol.transpose(-2, -1)
+
+    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
+        r"""
+        .. warning::
+            In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
+            Several attempts to correct singular samples are made by default, but singular samples may still be
+            returned. Singular samples may yield `-inf` values in `.log_prob()`.
+            In those cases, the user should validate the samples and either fix the value of `df`
+            or adjust the `max_try_correction` argument of `.rsample` accordingly.
+        """
+
+        if max_try_correction is None:
+            max_try_correction = 3 if torch._C._get_tracing_state() else 10
+
+        sample_shape = torch.Size(sample_shape)
+        sample = self._bartlett_sampling(sample_shape)
+
+        # The part below is a temporary workaround to improve numerical stability and should be removed in the future
+        is_singular = ~self.support.check(sample)
+        if self._batch_shape:
+            is_singular = is_singular.amax(self._batch_dims)
+
+        if torch._C._get_tracing_state():
+            # Less optimized version for JIT
+            for _ in range(max_try_correction):
+                sample_new = self._bartlett_sampling(sample_shape)
+                sample = torch.where(is_singular, sample_new, sample)
+
+                is_singular = ~self.support.check(sample)
+                if self._batch_shape:
+                    is_singular = is_singular.amax(self._batch_dims)
+
+        else:
+            # More optimized version with data-dependent control flow.
+            if is_singular.any():
+                warnings.warn("Singular sample detected.")
+
+                for _ in range(max_try_correction):
+                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
+                    sample[is_singular] = sample_new
+
+                    is_singular_new = ~self.support.check(sample_new)
+                    if self._batch_shape:
+                        is_singular_new = is_singular_new.amax(self._batch_dims)
+                    is_singular[is_singular.clone()] = is_singular_new
+
+                    if not is_singular.any():
+                        break
+
+        return sample
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        nu = self.df  # has shape (batch_shape)
+        p = self._event_shape[-1]  # has singleton shape
+        return (
+            -nu
+            * (
+                p * _log_2 / 2
+                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
+                .log()
+                .sum(-1)
+            )
+            - torch.mvlgamma(nu / 2, p=p)
+            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
+            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
+            .diagonal(dim1=-2, dim2=-1)
+            .sum(dim=-1)
+            / 2
+        )
+
+    def entropy(self):
+        nu = self.df  # has shape (batch_shape)
+        p = self._event_shape[-1]  # has singleton shape
+        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
+        return (
+            (p + 1)
+            * (
+                p * _log_2 / 2
+                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
+                .log()
+                .sum(-1)
+            )
+            + torch.mvlgamma(nu / 2, p=p)
+            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
+            + nu * p / 2
+        )
+
+    @property
+    def _natural_params(self):
+        nu = self.df  # has shape (batch_shape)
+        p = self._event_shape[-1]  # has singleton shape
+        return -self.precision_matrix / 2, (nu - p - 1) / 2
+
+    def _log_normalizer(self, x, y):
+        p = self._event_shape[-1]
+        return (y + (p + 1) / 2) * (
+            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
+        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
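+
+# A minimal sketch (illustrative only): with an identity covariance the mean is ``df * I``
+# and samples are symmetric positive-definite ``p x p`` matrices.
+#
+#     >>> w = Wishart(df=torch.tensor(4.0), covariance_matrix=torch.eye(2))
+#     >>> bool(torch.allclose(w.mean, 4 * torch.eye(2)))
+#     True
+#     >>> w.sample().shape
+#     torch.Size([2, 2])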
diff --git a/MLPY/Lib/site-packages/torch/export/__init__.py b/MLPY/Lib/site-packages/torch/export/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b0c27c4d227fd2f8eb1821ac2a595678dc54fd7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/__init__.py
@@ -0,0 +1,344 @@
+import builtins
+import copy
+import dataclasses
+import inspect
+import io
+import os
+import sys
+import typing
+import warnings
+from enum import auto, Enum
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
+
+import torch
+import torch.utils._pytree as pytree
+from torch.fx._compatibility import compatibility
+
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import PassManager
+
+from torch.utils._pytree import (
+    FlattenFunc,
+    FromDumpableContextFn,
+    ToDumpableContextFn,
+    UnflattenFunc,
+)
+
+if TYPE_CHECKING:
+    # Import the following modules during type checking to enable code intelligence features,
+    # Do not import unconditionally, as they import sympy and importing sympy is very slow
+    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+
+
+__all__ = [
+    "Constraint",
+    "Dim",
+    "ExportBackwardSignature",
+    "ExportGraphSignature",
+    "ExportedProgram",
+    "ModuleCallEntry",
+    "ModuleCallSignature",
+    "dims",
+    "dynamic_dim",
+    "export",
+    "load",
+    "register_dataclass",
+    "save",
+    "unflatten",
+    "FlatArgsAdapter",
+    "UnflattenedModule",
+]
+
+
+from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
+from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
+from .graph_signature import ExportBackwardSignature, ExportGraphSignature
+from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
+
+
+PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
+
+
+def export(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    strict: bool = True,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+) -> ExportedProgram:
+    """
+    :func:`export` takes an arbitrary Python callable (an nn.Module, a function or
+    a method) along with example inputs, and produces a traced graph representing
+    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
+    which can subsequently be executed with different inputs or serialized.  The
+    traced graph (1) produces normalized operators in the functional ATen operator set
+    (as well as any user-specified custom operators), (2) has eliminated all Python control
+    flow and data structures (with certain exceptions), and (3) records the set of
+    shape constraints needed to show that this normalization and control-flow elimination
+    is sound for future inputs.
+
+    **Soundness Guarantee**
+
+    While tracing, :func:`export()` takes note of shape-related assumptions
+    made by the user program and the underlying PyTorch operator kernels.
+    The output :class:`ExportedProgram` is considered valid only when these
+    assumptions hold true.
+
+    Tracing makes assumptions on the shapes (not values) of input tensors.
+    Such assumptions must be validated at graph capture time for :func:`export`
+    to succeed. Specifically:
+
+    - Assumptions on static shapes of input tensors are automatically validated without additional effort.
+    - Assumptions on dynamic shape of input tensors require explicit specification
+      by using the :func:`Dim` API to construct dynamic dimensions and by associating
+      them with example inputs through the ``dynamic_shapes`` argument.
+
+    If any assumption can not be validated, a fatal error will be raised. When that happens,
+    the error message will include suggested fixes to the specification that are needed
+    to validate the assumptions. For example :func:`export` might suggest the
+    following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the
+    shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``::
+
+        dim = Dim("dim0_x", max=5)
+
+    This example means the generated code requires dimension 0 of input ``x`` to be less
+    than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension
+    definitions and then copy them verbatim into your code without needing to change the
+    ``dynamic_shapes`` argument to your :func:`export` call.
+
+    Args:
+        mod: We will trace the forward method of this module.
+
+        args: Example positional inputs.
+
+        kwargs: Optional example keyword inputs.
+
+        dynamic_shapes:
+         An optional argument where the type should either be:
+         1) a dict from argument names of ``f`` to their dynamic shape specifications,
+         2) a tuple that specifies dynamic shape specifications for each input in original order.
+         If you are specifying dynamism on keyword args, you will need to pass them in the order that
+         is defined in the original function signature.
+
+         The dynamic shape of a tensor argument can be specified as either
+         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
+         not required to include static dimension indices in this dict, but when they are,
+         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
+         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
+         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
+         recursively specified by using mappings or sequences of contained specifications.
+
+        strict: When enabled (default), the export function will trace the program through
+         TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
+         exported program will not validate the implicit assumptions baked into the graph and
+         may cause behavior divergence between the original model and the exported one. This is
+         useful when users need to work around bugs in the tracer, or simply want to incrementally
+         enable safety in their models. Note that this does not change the resulting IR spec,
+         and the model will be serialized in the same way regardless of what value
+         is passed here.
+         WARNING: This option is experimental and use this at your own risk.
+
+    Returns:
+        An :class:`ExportedProgram` containing the traced callable.
+
+    **Acceptable input/output types**
+
+    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
+
+    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
+    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
+    - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
+      ``OrderedDict`` containing all above types.
+
+    """
+    from ._trace import _export
+
+    if not isinstance(mod, torch.nn.Module):
+        raise ValueError(
+            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
+        )
+
+    return _export(
+        mod,
+        args,
+        kwargs,
+        dynamic_shapes,
+        strict=strict,
+        preserve_module_call_signature=preserve_module_call_signature,
+    )
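+
+# A minimal usage sketch (illustrative only; the module ``M`` below is hypothetical, and only
+# the ``export`` and ``Dim`` names defined/imported in this file are assumed): dimension 0 of
+# ``x`` is marked dynamic, so the exported program accepts other batch sizes.
+#
+#     >>> class M(torch.nn.Module):
+#     ...     def forward(self, x):
+#     ...         return x.sin() + 1
+#     >>> batch = Dim("batch")
+#     >>> ep = export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})
+#     >>> ep.module()(torch.randn(7, 3)).shape   # dimension 0 stays dynamic
+#     torch.Size([7, 3])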
+
+
+def save(
+    ep: ExportedProgram,
+    f: Union[str, os.PathLike, io.BytesIO],
+    *,
+    extra_files: Optional[Dict[str, Any]] = None,
+    opset_version: Optional[Dict[str, int]] = None,
+) -> None:
+    """
+
+    .. warning::
+        Under active development, saved files may not be usable in newer versions
+        of PyTorch.
+
+    Saves an :class:`ExportedProgram` to a file-like object. It can then be
+    loaded using the Python API :func:`torch.export.load`.
+
+    Args:
+        ep (ExportedProgram): The exported program to save.
+
+        f (Union[str, os.PathLike, io.BytesIO]): A file-like object (has to
+         implement write and flush) or a string containing a file name.
+
+        extra_files (Optional[Dict[str, Any]]): Map from filename to contents
+         which will be stored as part of f.
+
+        opset_version (Optional[Dict[str, int]]): A map of opset names
+         to the version of this opset
+
+
+    Example::
+
+        import torch
+        import io
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        ep = torch.export.export(MyModule(), (torch.randn(5),))
+
+        # Save to file
+        torch.export.save(ep, 'exported_program.pt2')
+
+        # Save to io.BytesIO buffer
+        buffer = io.BytesIO()
+        torch.export.save(ep, buffer)
+
+        # Save with extra files
+        extra_files = {'foo.txt': b'bar'.decode('utf-8')}
+        torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
+
+    """
+    from torch._export import save
+
+    if not isinstance(ep, ExportedProgram):
+        raise TypeError(
+            f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
+        )
+
+    save(ep, f, extra_files=extra_files, opset_version=opset_version)
+
+
+def load(
+    f: Union[str, os.PathLike, io.BytesIO],
+    *,
+    extra_files: Optional[Dict[str, Any]] = None,
+    expected_opset_version: Optional[Dict[str, int]] = None,
+) -> ExportedProgram:
+    """
+
+    .. warning::
+        Under active development, saved files may not be usable in newer versions
+        of PyTorch.
+
+    Loads an :class:`ExportedProgram` previously saved with
+    :func:`torch.export.save`.
+
+    Args:
+        f (Union[str, os.PathLike, io.BytesIO]): A file-like object or a
+         string containing a file name.
+
+        extra_files (Optional[Dict[str, Any]]): The extra filenames given in
+         this map would be loaded and their content would be stored in the
+         provided map.
+
+        expected_opset_version (Optional[Dict[str, int]]): A map of opset names
+         to expected opset versions
+
+    Returns:
+        An :class:`ExportedProgram` object
+
+    Example::
+
+        import torch
+        import io
+
+        # Load ExportedProgram from file
+        ep = torch.export.load('exported_program.pt2')
+
+        # Load ExportedProgram from io.BytesIO object
+        with open('exported_program.pt2', 'rb') as f:
+            buffer = io.BytesIO(f.read())
+        buffer.seek(0)
+        ep = torch.export.load(buffer)
+
+        # Load with extra files.
+        extra_files = {'foo.txt': ''}  # values will be replaced with data
+        ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
+        print(extra_files['foo.txt'])
+        print(ep(torch.randn(5)))
+    """
+    from torch._export import load
+
+    return load(
+        f, extra_files=extra_files, expected_opset_version=expected_opset_version
+    )
+
+
+def register_dataclass(
+    cls: Type[Any],
+    *,
+    serialized_type_name: Optional[str] = None,
+) -> None:
+    """
+    Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
+
+    Args:
+        cls: the dataclass type to register
+        serialized_type_name: The serialized name for the dataclass. This is
+            required if you want to serialize the pytree TreeSpec containing this
+            dataclass.
+
+    Example::
+
+        from dataclasses import dataclass
+
+        import torch
+
+        @dataclass
+        class InputDataClass:
+            feature: torch.Tensor
+            bias: int
+
+        @dataclass
+        class OutputDataClass:
+            res: torch.Tensor
+
+        torch.export.register_dataclass(InputDataClass)
+        torch.export.register_dataclass(OutputDataClass)
+
+        def fn(o: InputDataClass) -> OutputDataClass:
+            res = o.feature + o.bias
+            return OutputDataClass(res=res)
+
+        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
+        print(ep)
+
+    """
+
+    from torch._export.utils import register_dataclass_as_pytree_node
+
+    return register_dataclass_as_pytree_node(
+        cls, serialized_type_name=serialized_type_name
+    )
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86ec921f4a4e68c078e0682a1ee15f9b405d0a15
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e34f8969990041ae882e2c5e31b81267b3eca530
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1506966b7daf57bc81aac28628d404a04a6fb03f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfe58e97e4df8e2286f57d56842aafd8ed3280cd
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06dbfa8b23d7976c56aea236a49e60cf4a577b15
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d920098d4fd007bfd375d58e2cf88e1b347820a3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f25a81e3c1b3da5ada919dd125c3de800acfd73f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d146d777bf1ba4b929bedfad9aad28d2e818870
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac1b09a0455f865c7f202d119b98b8039427708
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4848acf731f8e1c1af743dfafc6b67c755533881
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85e8b5a2828803d24d1fc5baf1205e6ca6df8e8b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc b/MLPY/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eace8613c9a58a3f7bf3674d119543c4c8203c27
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py b/MLPY/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..e63ed74589c2d98e2befc003e74ff983c42663ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+from typing import List
+
+import torch
+from torch._higher_order_ops.auto_functionalize import (
+    auto_functionalized,
+    get_mutable_arg_names,
+)
+from torch.export import ExportedProgram
+
+
+def unsafe_remove_auto_functionalized_pass(
+    ep: ExportedProgram,
+) -> ExportedProgram:
+    """
+    This pass removes instances of the higher order op 'auto_functionalized',
+    and modifies the calling EP in place to have the original mutator op.
+    This pass doesn't perform safety checks to make sure that this in-place mutation is safe.
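+
+    Illustrative usage (a sketch; ``ep`` is assumed to be an
+    :class:`ExportedProgram` whose graph contains ``auto_functionalized`` nodes)::
+
+        ep = unsafe_remove_auto_functionalized_pass(ep)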
+    """
+    auto_functionalize_nodes: List[torch.fx.Node] = []
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in ep.graph.nodes:
+            if node.op == "call_function" and node.target is auto_functionalized:
+                auto_functionalize_nodes.append(node)
+
+    # Update every use of the HOP
+    for node in reversed(auto_functionalize_nodes):
+        func = node.args[0]
+        original_kwargs = node.kwargs
+        assert isinstance(func, torch._ops.OpOverload)
+
+        with ep.graph.inserting_before(node):
+            # This makes the call_function refer to every arg as a kwarg; this is unusual but should be fine.
+            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
+        for k, v in node.meta.items():
+            new_node.meta[k] = v
+
+        # Replace auto_functionalize(func, args) with just func(args)
+        node.replace_all_uses_with(new_node)
+
+        mutable_args_names = get_mutable_arg_names(new_node.target)
+        output_specs = ep.graph_signature.output_specs
+
+        # update the users of the auto_func node (the getitem nodes)
+        for user in list(new_node.users.keys()):
+            assert user.target == operator.getitem
+            # getitem corresponding to a mutated input, just replace all uses with the original input
+            if user.args[1] >= len(func._schema.returns):
+                assert user.args[1] <= len(func._schema.returns) + len(
+                    mutable_args_names
+                )
+
+                # If the result of getitem was used in an output node, update the output spec with the correct name
+                adjusted_index = user.args[1] - len(func._schema.returns)
+                original_arg = original_kwargs[mutable_args_names[adjusted_index]]
+                for spec in output_specs:
+                    if spec.arg.name == user.name:
+                        spec.arg.name = original_arg.name  # pyre-ignore
+                        break
+
+                # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
+                # of the getitem calls following the HOP.
+                user.replace_all_uses_with(
+                    original_kwargs[mutable_args_names[adjusted_index]]
+                )
+
+        if len(func._schema.returns) == 1:
+            # If the function has 1 return then it will just directly return the
+            # result -- we don't need a getitem. So we can replace all the
+            # getitem(auto_functionalized, 0) with just the node itself.
+            for user in list(new_node.users.keys()):
+                if user.args[1] == 0:
+                    user.replace_all_uses_with(new_node)
+
+                    # Same case as above, update the output spec if getitem result used in an output node
+                    for spec in output_specs:
+                        if spec.arg.name == user.name:
+                            spec.arg.name = new_node.name
+                            break
+
+        new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
+        ep.graph.erase_node(node)
+
+    ep.graph.eliminate_dead_code()
+    return ep
diff --git a/MLPY/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py b/MLPY/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..911381067123c3d0e34343421e4c47f1c13ff412
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py
@@ -0,0 +1,126 @@
+import operator
+from typing import List
+
+import torch
+from torch._higher_order_ops.effects import with_effects
+from .exported_program import ExportedProgram
+from .graph_signature import (
+    InputKind,
+    InputSpec,
+    OutputKind,
+    OutputSpec,
+    TensorArgument,
+)
+
+
+def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
+    """
+    Removes the tokens from the exported program, including:
+    - Removes the input and output tokens
+    - Replaces with_effects(token, func, args) with just func(args)
+
+    This function does an inplace modification on the given ExportedProgram.
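+
+    Illustrative sketch of the rewrite (``op``, ``x`` and the token names are
+    placeholders)::
+
+        # before:  new_token, out = with_effects(token, op, x)
+        # after:   out = op(x)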
+    """
+    num_tokens: int = 0
+    input_token_names: List[str] = []
+    new_input_specs: List[InputSpec] = []
+    for inp in ep.graph_signature.input_specs:
+        if inp.kind == InputKind.TOKEN:
+            num_tokens += 1
+            assert isinstance(inp.arg, TensorArgument)
+            input_token_names.append(inp.arg.name)
+        else:
+            new_input_specs.append(inp)
+
+    num_out_tokens: int = 0
+    new_output_specs: List[OutputSpec] = []
+    output_token_names: List[str] = []
+    for out in ep.graph_signature.output_specs:
+        if out.kind == OutputKind.TOKEN:
+            num_out_tokens += 1
+            output_token_names.append(out.arg.name)
+        else:
+            new_output_specs.append(out)
+
+    assert num_tokens == num_out_tokens
+
+    output_node = None
+    with_effect_nodes: List[torch.fx.Node] = []
+    for node in ep.graph.nodes:
+        if node.op == "output":
+            output_node = node
+            break
+
+        if not (node.op == "call_function" and node.target is with_effects):
+            continue
+
+        with_effect_nodes.append(node)
+
+    # Remove tokens from outputs
+    assert output_node is not None
+    output_args = output_node.args[0]
+    assert len(output_args) >= num_tokens
+    out_token_nodes = output_args[:num_tokens]
+    output_node.args = (tuple(output_args[num_tokens:]),)
+    for out_token in out_token_nodes:
+        assert out_token.name in output_token_names
+        ep.graph.erase_node(out_token)
+
+    # Replace with_effects(token, func, args) with just func(args)
+    for node in reversed(with_effect_nodes):
+        func = node.args[1]
+        assert isinstance(func, torch._ops.OpOverload)
+
+        with ep.graph.inserting_before(node):
+            new_node = ep.graph.call_function(func, node.args[2:])
+        for k, v in node.meta.items():
+            new_node.meta[k] = v
+
+        node.replace_all_uses_with(new_node)
+
+        # Update user getitem nodes
+        for user in list(new_node.users.keys()):
+            assert user.target == operator.getitem
+            # getitem(with_effects, 0) == token
+            if user.args[1] == 0:
+                ep.graph.erase_node(user)
+
+        if len(func._schema.returns) == 1:
+            # If the function has 1 return then it will just directly return the
+            # result -- we don't need a getitem. So we can replace all the
+            # getitem(with_effects, 1) with just the node itself.
+            for user in list(new_node.users.keys()):
+                assert user.args[1] == 1
+                user.replace_all_uses_with(new_node)
+
+            new_node.meta["val"] = node.meta["val"][1]
+        elif len(func._schema.returns) > 1:
+            # If the function has more than 1 return, then since we got rid of
+            # the 1st return value (the token), we need to shift all the other
+            # getitem indices down by 1.
+            for user in list(new_node.users.keys()):
+                assert user.args[1] >= 1
+                user.args = (user.args[0], user.args[1] - 1)
+
+            new_node.meta["val"] = node.meta["val"][1:]
+        else:
+            assert len(func._schema.returns) == 0
+            assert len(new_node.users) == 0
+            new_node.meta["val"] = None
+
+        ep.graph.erase_node(node)
+
+    # Remove tokens from inputs
+    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
+    assert len(placeholders) >= num_tokens
+    inp_token_nodes = placeholders[:num_tokens]
+    for inp_token in inp_token_nodes:
+        assert inp_token.name in input_token_names
+        ep.graph.erase_node(inp_token)
+
+    # Update graph signature
+    ep.graph_signature.input_specs = new_input_specs
+    ep.graph_signature.output_specs = new_output_specs
+
+    ep.graph.eliminate_dead_code()
+    return ep
diff --git a/MLPY/Lib/site-packages/torch/export/_safeguard.py b/MLPY/Lib/site-packages/torch/export/_safeguard.py
new file mode 100644
index 0000000000000000000000000000000000000000..e96e595e3d04df98636dbdaf3e5849e1b54dd216
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_safeguard.py
@@ -0,0 +1,42 @@
+import torch
+from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
+from torch.overrides import TorchFunctionMode
+
+
+class AutogradStateOpsFailSafeguard(TorchFunctionMode):
+    """
+    Detect grad state ops while exporting the graph and fail the process by
+    raising an error, to avoid unexpected behavior. Such grad mode ops include:
+    `torch.no_grad`
+    `torch.enable_grad`
+    `torch.set_grad_enabled`
+
+    Export with predispatch mode is exempted.
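+
+    Illustrative usage (a sketch; in practice this mode is entered internally
+    by the export pipeline rather than by user code)::
+
+        with AutogradStateOpsFailSafeguard():
+            ...  # tracing happens here; flipping the global grad state raises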
+    """
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        unsupported_grad_mode_ops = [
+            torch._C._set_grad_enabled,
+        ]
+        # The check is only enabled while tracing, which is confirmed by an active
+        # PROXY torch dispatch mode. This allows autograd ops outside of tracing.
+        current_state = torch._C.is_grad_enabled()
+        if func in unsupported_grad_mode_ops:
+            assert len(args) == 1
+            changed_state = args[0]
+            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+            # We only error outside of pre_dispatch mode; it's allowed to use
+            # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`.
+            if (
+                mode
+                and isinstance(mode, ProxyTorchDispatchMode)
+                and not mode.pre_dispatch
+                and changed_state != current_state
+            ):
+                raise RuntimeError(
+                    f"Encountered autograd state manager op {func} trying to change global autograd state "
+                    "while exporting. This is unsafe because we don't capture this op in torch.export "
+                    "today, hence we can't reflect the user intention soundly."
+                )
+        return func(*args, **kwargs)
diff --git a/MLPY/Lib/site-packages/torch/export/_trace.py b/MLPY/Lib/site-packages/torch/export/_trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b7bc48868c5091a132367e44ce5187afc8713e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_trace.py
@@ -0,0 +1,1060 @@
+import dataclasses
+import functools
+import inspect
+import logging
+import re
+import time
+import warnings
+from contextlib import contextmanager, nullcontext
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch._dynamo
+import torch.fx
+
+import torch.utils._pytree as pytree
+from torch._dynamo.exc import UserError, UserErrorType
+from torch._export.non_strict_utils import (
+    make_constraints,
+    make_fake_inputs,
+    make_fake_params_buffers,
+)
+from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
+    _AddRuntimeAssertionsForInlineConstraintsPass,
+)
+from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
+from torch._export.passes.lift_constants_pass import (
+    ConstantAttrMap,
+    lift_constants_pass,
+    rewrite_script_object_meta,
+)
+from torch._export.wrappers import _wrap_submodules
+from torch._functorch.aot_autograd import aot_export_module
+from torch._guards import detect_fake_mode
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch._utils_internal import log_export_usage
+from torch.export.exported_program import OutputKind
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    free_unbacked_symbols,
+    GuardOnDataDependentSymNode,
+    ShapeEnv,
+)
+from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from torch.utils._sympy.value_ranges import ValueRangeError
+
+from ._safeguard import AutogradStateOpsFailSafeguard
+
+from .dynamic_shapes import _process_constraints, Constraint
+from .exported_program import (
+    _disable_prexisiting_fake_mode,
+    ExportedProgram,
+    InputKind,
+    ModuleCallEntry,
+    ModuleCallSignature,
+)
+from .graph_signature import (
+    _sig_to_specs,
+    ArgumentSpec,
+    ConstantArgument,
+    CustomObjArgument,
+    ExportGraphSignature,
+    SymIntArgument,
+    TensorArgument,
+)
+
+
+log = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class ExportDynamoConfig:
+    """
+    Manage Export-specific configurations of Dynamo.
+    """
+
+    allow_rnn: bool = True
+    reorderable_logging_functions: Set[Callable] = dataclasses.field(
+        default_factory=set
+    )
+
+
+DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
+DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
+    logging.critical,
+    logging.debug,
+    logging.error,
+    logging.exception,
+    logging.info,
+    logging.log,
+    logging.warning,
+    print,
+    warnings.warn,
+}
+
+
+@contextmanager
+def _ignore_backend_decomps():
+    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
+    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
+    try:
+        yield
+    finally:
+        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
+        torch.backends.nnpack.set_flags(*orig_nnpack_flag)
+
+
+def _convert_input_to_fake(gm, args, kwargs):
+    params_buffers = _get_params_buffers(gm)
+    fake_inps: List[torch.Tensor] = []
+    for node in gm.graph.nodes:
+        if node.op == "placeholder" and "val" in node.meta:
+            fake_val = node.meta["val"]
+            if fake_val is not None and isinstance(fake_val, torch.Tensor):
+                fake_inps.append(fake_val)
+
+    if detected_fake_mode := detect_fake_mode(fake_inps):
+        fake_mode = detected_fake_mode
+    else:
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+
+    if len(args) == 0 and len(kwargs) == 0:
+        return (), {}, params_buffers, fake_mode
+
+    count = 0
+
+    def convert_to_fake(x):
+        nonlocal count
+        val = fake_inps[count]
+        count += 1
+        return val
+
+    fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)
+    # TODO properly use the cached fake tensor
+    fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
+    fake_params_buffers = pytree.tree_map_only(
+        torch.Tensor,
+        functools.partial(fake_mode.from_tensor, static_shapes=True),
+        params_buffers,
+    )
+    return fake_args, fake_kwargs, fake_params_buffers, fake_mode
+
+
+def _replace_param_buffer_names(param_buffer_table, sig):
+    for spec in sig.input_specs:
+        if spec.kind in (
+            InputKind.PARAMETER,
+            InputKind.BUFFER,
+        ):
+            spec.target = param_buffer_table[spec.target]
+    for spec in sig.output_specs:
+        if spec.kind in (
+            OutputKind.BUFFER_MUTATION,
+            OutputKind.GRADIENT_TO_PARAMETER,
+        ):
+            spec.target = param_buffer_table[spec.target]
+
+
+def _convert_to_positional_args(orig_arg_names, args, kwargs):
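+    # For example (illustrative): orig_arg_names = ["x", "y", "z"], args = (1,),
+    # kwargs = {"y": 2, "z": 3}  ->  returns (1, 2, 3).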
+    assert len(orig_arg_names) == len(args) + len(kwargs), (
+        f"Total number of arg names is expected to be {len(orig_arg_names)} "
+        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
+    )
+    reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
+    return (
+        *args,
+        *reordered_kwargs,
+    )
+
+
+def _normalize_nn_module_stack(gm_torch_level, root_cls):
+    # Append a root module to every nn_module_stack.
+    root = "L['self']"
+    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
+    for gm in gm_torch_level.modules():
+        if not isinstance(gm, torch.fx.GraphModule):
+            continue
+        for node in gm.graph.nodes:
+            if node.op in ["placeholder", "output"]:
+                continue
+            add_root = True
+            if nn_module_stack := node.meta.get("nn_module_stack", {}):
+                path, ty = next(iter(nn_module_stack.values()))
+                # After deserializing the class `ty` might not exist anymore so
+                # it could be a string
+                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
+                    # TODO Figure out why sometimes we have root sometimes we don't.
+                    if path == root and ty is root_cls:
+                        add_root = False
+                else:
+                    assert isinstance(ty, str)
+            if add_root:
+
+                def normalize_path(path):
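+                    # Illustrative: "L['self'].layers[0].conv" -> "layers.0.conv";
+                    # the dummy Path object records attribute/index accesses in order.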
+                    try:
+                        parts = []
+
+                        class Path:
+                            def __getattr__(self, name):
+                                parts.append(name)
+                                return self
+
+                            def __getitem__(self, idx):
+                                parts.append(str(idx))
+                                return self
+
+                        eval(path, {"L": {"self": Path()}})
+                        return ".".join(parts)
+                    except Exception:  # TODO(zhxchen17) Remove this.
+                        return path
+
+                nn_module_stack = {root_key: (root, root_cls), **nn_module_stack}
+                node.meta["nn_module_stack"] = {
+                    key: (normalize_path(path), ty)
+                    for key, (path, ty) in nn_module_stack.items()
+                }
+
+
+def _get_param_buffer_mapping(
+    original_module: torch.nn.Module,
+    traced_module: torch.nn.Module,
+) -> Dict[str, str]:
+    """
+    Returns a mapping of parameter/buffer names from the new module to the
+    original model. This is to help with restoring the FQN for parameter/buffers
+    of a traced module to what the original module contains.
+    """
+
+    param_lookup: Dict[int, List[str]] = {}
+    buffer_lookup: Dict[int, List[str]] = {}
+    for name, param in original_module.named_parameters(remove_duplicate=False):
+        param_lookup.setdefault(id(param), []).append(name)
+    for name, buffer in original_module.named_buffers(remove_duplicate=False):
+        buffer_lookup.setdefault(id(buffer), []).append(name)
+
+    param_buffer_table: Dict[str, str] = {}
+    for dynamo_name, dynamo_param in traced_module.named_parameters(
+        remove_duplicate=False
+    ):
+        assert dynamo_name not in param_buffer_table
+        if id(dynamo_param) in param_lookup:
+            param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop()
+
+    for dynamo_name, dynamo_buffer in traced_module.named_buffers(
+        remove_duplicate=False
+    ):
+        assert dynamo_name not in param_buffer_table
+        if id(dynamo_buffer) in buffer_lookup:
+            param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop()
+
+    return param_buffer_table
+
+
+def _remap_constants(
+    orig_constant_attrs: ConstantAttrMap,
+    graph_signature: ExportGraphSignature,
+    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
+) -> None:
+    """Rewrite the graph signature and constants table to use the FQN from the original module."""
+    remap_table: Dict[str, str] = {}
+    for name, value in constants.items():
+        if value in orig_constant_attrs:
+            remap_table[name] = orig_constant_attrs[value]
+
+    for spec in graph_signature.input_specs:
+        if spec.kind in (
+            InputKind.CONSTANT_TENSOR,
+            InputKind.CUSTOM_OBJ,
+        ):
+            orig_target = spec.target
+            assert orig_target is not None
+            spec.target = remap_table.get(orig_target, orig_target)
+
+            constant = constants[orig_target]
+            del constants[orig_target]
+            constants[spec.target] = constant
+
+
+def _restore_state_dict(
+    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
+) -> None:
+    """
+    Restores the state dict of the traced module to that of the original module.
+    """
+    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
+    # Since the graph module is flattened (no module hierarchy), we
+    # need to normalize the module by replacing "." with "_". If we
+    # don't, it will try to save the weight to a submodule which no
+    # longer exists.
+    for name, fqn in param_buffer_table.items():
+        param_buffer_table[name] = fqn.replace(".", "_")
+
+    # Replace state dict attr names with the fqn
+    for name, fqn in param_buffer_table.items():
+        if not hasattr(traced_module, name):
+            continue
+
+        attr = getattr(traced_module, name)
+        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
+            traced_module.register_buffer(fqn, attr)
+        else:
+            setattr(traced_module, fqn, attr)
+        delattr(traced_module, name)
+
+    # Replace graph getattr nodes with the correct name
+    for node in traced_module.graph.nodes:
+        if node.op == "get_attr":
+            attr_name = node.target
+            if attr_name in param_buffer_table:
+                node.target = param_buffer_table[attr_name]
+
+    traced_module.recompile()
+
+
+def _export_to_torch_ir(
+    f: Callable,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    constraints: Optional[List[Constraint]] = None,
+    *,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+    disable_constraint_solver: bool = False,
+    restore_fqn: bool = True,
+    _log_export_usage: bool = True,
+) -> torch.fx.GraphModule:
+    """
+    Traces either an nn.Module's forward function or just a callable with PyTorch
+    operations inside and produces a torch.fx.GraphModule in torch IR.
+    """
+
+    if _log_export_usage:
+        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})
+
+    kwargs = kwargs or {}
+
+    if not isinstance(args, tuple):
+        raise UserError(
+            UserErrorType.INVALID_INPUT,
+            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
+        )
+
+    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
+        try:
+            module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
+            with _wrap_submodules(
+                f, preserve_module_call_signature, module_call_specs
+            ), _ignore_backend_decomps():
+                gm_torch_level, _ = torch._dynamo.export(
+                    f,
+                    constraints=constraints,  # type: ignore[arg-type]
+                    assume_static_by_default=True,
+                    tracing_mode="symbolic",
+                    disable_constraint_solver=disable_constraint_solver,
+                    _log_export_usage=_log_export_usage,
+                )(
+                    *args,
+                    **kwargs,
+                )
+        except (ConstraintViolationError, ValueRangeError) as e:
+            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200
+        except GuardOnDataDependentSymNode as e:
+            raise UserError(  # noqa: TRY200
+                UserErrorType.ANTI_PATTERN,
+                f"Consider annotating your code using torch._constrain_as_*(). {str(e)}",
+                case_name="constrain_as_size_example",
+            )
+
+    gm_torch_level.meta["module_call_specs"] = module_call_specs
+
+    if isinstance(f, torch.nn.Module) and restore_fqn:
+        _restore_state_dict(f, gm_torch_level)
+
+    return gm_torch_level
+
+
+def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
+    """Search the module hierarchy, gathering up all tensor and ScriptObject constants.
+
+    Returns a dictionary mapping hash(value) to the name of the constant. We
+    have to abuse `hash` here unfortunately, see: [ScriptObject hash].
+    """
+    constants = ConstantAttrMap()
+    buffers_parameters = set(m.buffers())
+    buffers_parameters.update(m.parameters())
+
+    def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
+        for k, v in m.__dict__.items():
+            if isinstance(v, (torch.Tensor, torch.ScriptObject)):
+                if v in buffers_parameters:
+                    # filter out buffers and parameters, leaving only constants
+                    continue
+
+                fqn = ".".join(prefix_atoms + [k])
+                if v in constants:
+                    raise ValueError(
+                        f"Duplicate reference to constant attribute found: '{constants[v]}' and '{fqn}'."
+                    )
+
+                constants[v] = fqn
+        for k, v in m.named_children():
+            inner(v, prefix_atoms + [k], constants)
+
+    inner(m, [], constants)
+    return constants
+
+
+def _export_non_strict(
+    mod: torch.nn.Module,
+    fake_args,
+    fake_kwargs,
+    fake_params_buffers,
+    constant_attrs: ConstantAttrMap,
+    *,
+    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
+    pre_dispatch=False,
+):
+    # [NOTE] If the user is exporting under training mode, we want to detect if there is any
+    # state change in the autograd global state and error. If the user is exporting under inference
+    # mode, we don't care.
+    is_grad_enabled = torch._C.is_grad_enabled()
+    grad_safe_guard = (
+        AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
+    )
+
+    @contextmanager
+    def _compiling_state_context():
+        old_value = torch.compiler._is_compiling_flag
+        try:
+            torch.compiler._is_compiling_flag = True
+            yield
+        finally:
+            torch.compiler._is_compiling_flag = old_value
+
+    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
+    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
+    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
+    with torch.nn.utils.stateless._reparametrize_module(
+        mod, fake_params_buffers
+    ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
+        gm, graph_signature = transform(aot_export_module)(
+            mod,
+            fake_args,
+            trace_joint=False,
+            pre_dispatch=pre_dispatch,
+            kwargs=fake_kwargs,
+        )
+    # TODO unfortunately preserving graph-level metadata is not
+    # working well with aot_export. So we manually copy it.
+    # (The node-level meta is addressed above.)
+    if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
+        gm.meta.update(mod.meta)
+
+    if pre_dispatch:
+        from torch._export.passes.replace_set_grad_with_hop_pass import (
+            replace_set_grad_with_hop_pass,
+        )
+
+        gm = replace_set_grad_with_hop_pass(gm)
+
+    # NOTE: aot_export adds symint metadata for placeholders with int values;
+    # since these become specialized, we replace such metadata with the original values
+    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
+    index = 0
+    total_non_user_inputs = (
+        len(graph_signature.parameters)
+        + len(graph_signature.buffers)
+        + len(graph_signature.input_tokens)
+    )
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if index >= total_non_user_inputs:
+                user_arg = flat_args[index - total_non_user_inputs]
+                if not isinstance(user_arg, torch.Tensor):
+                    node.meta["val"] = user_arg
+            index += 1
+
+    is_joint = graph_signature.backward_signature is not None
+
+    def make_argument_spec(node) -> ArgumentSpec:
+        if isinstance(node, (int, bool, float, type(None))):
+            # For const outputs we just directly return this
+            return ConstantArgument(value=node)
+
+        assert (
+            "val" in node.meta
+        ), f"{node} is not a constant or a node with a 'val' metadata field"
+        val = node.meta["val"]
+        if isinstance(val, FakeTensor):
+            return TensorArgument(name=node.name)
+        elif isinstance(val, torch.SymInt):
+            return SymIntArgument(name=node.name)
+        elif isinstance(val, torch.ScriptObject):
+            return CustomObjArgument(
+                name=node.name, class_fqn=val._type().qualified_name()  # type: ignore[attr-defined]
+            )
+        else:
+            # TODO: this branch is likely wrong, all permissible ConstantArgument type
+            # should have been handled already
+            return ConstantArgument(value=val)
+
+    input_specs, output_specs = _sig_to_specs(
+        user_inputs=set(graph_signature.user_inputs),
+        inputs_to_parameters=graph_signature.inputs_to_parameters,  # type: ignore[arg-type]
+        inputs_to_buffers=graph_signature.inputs_to_buffers,  # type: ignore[arg-type]
+        user_outputs=set(graph_signature.user_outputs),  # type: ignore[arg-type]
+        buffer_mutations=graph_signature.buffers_to_mutate,  # type: ignore[arg-type]
+        user_input_mutations=graph_signature.user_inputs_to_mutate,  # type: ignore[arg-type]
+        grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {},  # type: ignore[arg-type, union-attr]
+        grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {},  # type: ignore[arg-type, union-attr]
+        loss_output=graph_signature.backward_signature.loss_output if is_joint else None,  # type: ignore[arg-type, union-attr]
+        inputs=[
+            make_argument_spec(node)
+            for node in gm.graph.nodes
+            if node.op == "placeholder"
+        ],
+        outputs=[
+            make_argument_spec(node)
+            for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
+        ],
+        input_tokens=graph_signature.input_tokens,
+        output_tokens=graph_signature.output_tokens,
+    )
+    export_graph_signature = ExportGraphSignature(
+        input_specs=input_specs, output_specs=output_specs
+    )
+
+    constants = rewrite_script_object_meta(gm)
+    constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
+
+    @dataclasses.dataclass
+    class _ExportedProgramNonStrict:
+        gm: torch.fx.GraphModule
+        sig: ExportGraphSignature
+        constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
+
+    return _ExportedProgramNonStrict(
+        gm,
+        export_graph_signature,
+        constants,
+    )
+
+
+def _get_params_buffers(mod: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    params_buffers: Dict[str, torch.Tensor] = {}
+    for name, param in mod.named_parameters(remove_duplicate=False):
+        params_buffers[name] = param
+
+    for name, buffer in mod.named_buffers(remove_duplicate=False):
+        params_buffers[name] = buffer
+    return params_buffers
+
+
+def _rewrite_dynamo_tensor_constants(
+    orig_mod_buffers: Set[torch.Tensor],
+    traced_mod_buffers: Dict[str, torch.Tensor],
+    graph_signature: ExportGraphSignature,
+    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
+):
+    """Dynamo erroneously marks tensor attributes on modules as a buffers.
+
+    Rewrite them to be tensor constants.
+    """
+    for spec in graph_signature.input_specs:
+        if spec.kind == InputKind.BUFFER:
+            assert spec.target is not None
+            value = traced_mod_buffers[spec.target]
+            if value not in orig_mod_buffers:
+                # This was a tensor constant erroneously marked as a buffer.
+                # Convert it into a constant in the graph signature, and add its
+                # value to the constants table.
+                spec.kind = InputKind.CONSTANT_TENSOR
+                constants[spec.target] = value
+
+
+def _rewrite_non_persistent_buffers(
+    orig_mod: torch.nn.Module,
+    graph_signature: ExportGraphSignature,
+    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
+):
+    """Dynamo erroneously drops the persistent flag on buffers.
+
+    Rewrite non-persistent buffers to reflect the original module.
+    """
+    state_dict = orig_mod.state_dict()
+    for spec in graph_signature.input_specs:
+        if spec.kind == InputKind.BUFFER:
+            assert spec.target is not None
+            if spec.target not in state_dict:
+                assert spec.target not in constants
+                spec.persistent = False
+                constants[spec.target] = orig_mod.get_buffer(spec.target)
+
+
+def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
+    op_count = 0
+    op_set = set()
+    for m in ep.graph_module.modules():
+        if not isinstance(m, torch.fx.GraphModule):
+            continue
+        for node in m.graph.nodes:
+            if node.op != "call_function":
+                continue
+            op_count += 1
+            assert hasattr(node.target, "__module__")
+            assert hasattr(node.target, "__name__")
+            op_set.add(f"{node.target.__module__}.{node.target.__name__}")
+    return {"op_count": op_count, "op_set": op_set}
+
+
+_EXPORT_FLAGS: Optional[Set[str]] = None
+
+
+def _log_export_wrapper(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        global _EXPORT_FLAGS
+        try:
+            start = time.time()
+            ep = fn(*args, **kwargs)
+            end = time.time()
+            log_export_usage(
+                event="export.time",
+                metrics=end - start,
+                flags=_EXPORT_FLAGS,
+                **get_ep_stats(ep),
+            )
+        except Exception as e:
+            t = type(e)
+            error_type = t.__module__ + "." + t.__qualname__
+            log_export_usage(
+                event="export.error",
+                type=error_type,
+                message=str(e),
+                flags=_EXPORT_FLAGS,
+            )
+            raise e
+        finally:
+            _EXPORT_FLAGS = None
+
+        return ep
+
+    return wrapper
+
+
+@_log_export_wrapper
+@_disable_prexisiting_fake_mode
+def _export(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    *,
+    strict: bool = True,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+    pre_dispatch: bool = False,
+) -> ExportedProgram:
+    """
+    Traces either an nn.Module's forward function or just a callable with PyTorch
+    operations inside and produces an ExportedProgram.
+
+    Args:
+        mod: the `nn.Module` to trace.
+
+        args: example positional inputs.
+
+        kwargs: optional example keyword inputs.
+
+        dynamic_shapes:
+         An optional argument where the type should either be:
+         1) a dict from argument names of ``mod``'s forward function to their dynamic shape specifications,
+         2) a tuple that specifies dynamic shape specifications for each input in original order.
+         If you are specifying dynamism on keyword args, you will need to pass them in the order that
+         is defined in the original function signature.
+
+         The dynamic shape of a tensor argument can be specified as either
+         (1) a dict from dynamic dimension indices to :func:`Dim` types, where static
+         dimension indices do not need to be included but, if they are, they should be
+         mapped to None; or (2) a tuple / list of :func:`Dim` types or None, where the
+         :func:`Dim` types correspond to dynamic dimensions and static dimensions are
+         denoted by None. Arguments that are dicts or tuples / lists of tensors are
+         recursively specified by using mappings or sequences of contained specifications.
+
+        preserve_module_call_signature: A list of submodule paths for which the original
+            calling conventions are preserved as metadata.
+
+    Returns:
+        An ExportedProgram containing the traced method.
+    """
+    from .dynamic_shapes import _process_dynamic_shapes
+
+    global _EXPORT_FLAGS
+    flags = set()
+    flags.add("strict" if strict else "non_strict")
+    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
+    log_export_usage(event="export.enter", flags=flags)
+    _EXPORT_FLAGS = flags
+
+    constraints = _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes) or []
+
+    kwargs = kwargs or {}
+
+    constant_attrs = _gather_constant_attrs(mod)
+
+    flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))
+
+    if not strict:
+        out_spec = None
+
+        module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
+
+        def strip_root(x):
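+            # e.g. (illustrative): strip_root("_export_root.linear") -> "linear"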
+            if isinstance(x, str) and x.startswith("_export_root"):
+                stripped = x[len("_export_root") :]
+                return stripped[1:] if stripped.startswith(".") else stripped
+            return x
+
+        def fixup_key(x):
+            return "L__self__" + strip_root(x)
+
+        def _tuplify_outputs(aot_export):
+            def _aot_export_non_strict(mod, args, kwargs=None, **flags):
+                kwargs = kwargs or {}
+
+                class Wrapper(torch.nn.Module):
+                    def __init__(self, mod):
+                        super().__init__()
+                        self._export_root = mod
+
+                    def forward(self, *args, **kwargs):
+                        nonlocal out_spec
+                        if isinstance(self._export_root, torch.fx.GraphModule):
+                            with torch.fx.traceback.preserve_node_meta():
+                                tree_out = torch.fx.Interpreter(self._export_root).run(
+                                    *args, **kwargs
+                                )
+                        else:
+                            tree_out = self._export_root(*args, **kwargs)
+                        flat_outs, out_spec = pytree.tree_flatten(tree_out)
+                        return tuple(flat_outs)
+
+                wrapped_mod = Wrapper(mod)
+                # Patch export_root to the signatures so that wrapper module correctly populates the
+                # in/out spec
+                new_preserved_call_signatures = [
+                    "_export_root." + i for i in preserve_module_call_signature
+                ]
+                with _wrap_submodules(
+                    wrapped_mod, new_preserved_call_signatures, module_call_specs
+                ):
+                    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
+
+                sig.parameters = pytree.tree_map(strip_root, sig.parameters)
+                sig.buffers = pytree.tree_map(strip_root, sig.buffers)
+                sig.inputs_to_buffers = pytree.tree_map(
+                    strip_root, sig.inputs_to_buffers
+                )
+                sig.inputs_to_parameters = pytree.tree_map(
+                    strip_root, sig.inputs_to_parameters
+                )
+                sig.buffers_to_mutate = pytree.tree_map(
+                    strip_root, sig.buffers_to_mutate
+                )
+                for node in gm.graph.nodes:
+                    if "nn_module_stack" in node.meta:
+                        nn_module_stack = node.meta["nn_module_stack"]
+                        node.meta["nn_module_stack"] = {
+                            fixup_key(key): val
+                            for key, val in pytree.tree_map(
+                                strip_root, nn_module_stack
+                            ).items()
+                        }
+
+                return gm, sig
+
+            return _aot_export_non_strict
+
+        (
+            fake_mode,
+            fake_args,
+            fake_kwargs,
+            equalities_inputs,
+            original_signature,
+        ) = make_fake_inputs(mod, args, kwargs, constraints)
+
+        fake_params_buffers = make_fake_params_buffers(
+            fake_mode, _get_params_buffers(mod)
+        )
+        with fake_mode:
+            ep_non_strict = _export_non_strict(
+                mod,
+                fake_args,
+                fake_kwargs,
+                fake_params_buffers,
+                constant_attrs,
+                pre_dispatch=pre_dispatch,
+                transform=_tuplify_outputs,
+            )
+        try:
+            range_constraints = make_constraints(
+                fake_mode,
+                equalities_inputs,
+                original_signature,
+                ep_non_strict.gm,
+            )
+        except (ConstraintViolationError, ValueRangeError) as e:
+            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200
+
+        assert out_spec is not None
+
+        gm = ep_non_strict.gm
+
+        module_call_signatures = {
+            strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
+            for fqn, specs in module_call_specs.items()
+        }
+
+        if len(preserve_module_call_signature) > 0:
+            for node in gm.graph.nodes:
+                if node.target == torch.ops.higher_order._export_tracepoint:
+                    if "path" in node.kwargs:
+                        path = strip_root(node.kwargs["path"])
+                        with gm.graph.inserting_before(node):
+                            new_node = gm.graph.create_node(
+                                "call_function",
+                                torch.ops.higher_order._export_tracepoint,
+                                args=node.args,
+                                kwargs={
+                                    "path": path,
+                                    "kind": node.kwargs["kind"],
+                                },
+                            )
+                            node.replace_all_uses_with(new_node)
+                            gm.graph.erase_node(node)
+
+            res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
+            assert res is not None
+            gm = res.graph_module
+
+        _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)
+        return ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=ep_non_strict.sig,
+            state_dict=mod.state_dict(keep_vars=True),
+            range_constraints=range_constraints,
+            module_call_graph=[
+                ModuleCallEntry(
+                    "",
+                    ModuleCallSignature(
+                        inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
+                    ),
+                )
+            ]
+            + [
+                ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()
+            ],
+            example_inputs=(args, kwargs),
+            constants=ep_non_strict.constants,
+        )
+
+    gm_torch_level = _export_to_torch_ir(
+        mod,
+        args,
+        kwargs,
+        constraints,
+        preserve_module_call_signature=preserve_module_call_signature,
+        restore_fqn=False,  # don't need to restore because we will do it later
+        _log_export_usage=False,
+    )
+
+    # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
+    (
+        fake_args,
+        fake_kwargs,
+        fake_params_buffers,
+        dynamo_fake_mode,
+    ) = _convert_input_to_fake(gm_torch_level, args, kwargs)
+
+    # First, we want to pass through the graph to try populating
+    # val field for getattr if there is anything missing.
+    # This can happen when quantization adds extra params and forgets
+    # to update "val"
+    for node in gm_torch_level.graph.nodes:
+        if node.op == "get_attr" and "val" not in node.meta:
+            attr = getattr(gm_torch_level, node.target)
+            # Checks if it is not a HigherOrderOp branch or a module
+            if not isinstance(attr, torch.nn.Module):
+                assert (
+                    dynamo_fake_mode is not None
+                ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
+                node.meta["val"] = dynamo_fake_mode.from_tensor(
+                    attr, static_shapes=True
+                )
+
+    # When aot_export lifts the params, we lose the nn_module_stack
+    # and source_fn from the param nodes as they are treated as fresh inputs
+    # Therefore, we manually extract them before calling into aot_export
+    params_buffers_to_node_meta = {}
+    for node in gm_torch_level.graph.nodes:
+        target = node.target
+        meta = node.meta
+        if node.op == "call_module":
+            submodule = getattr(gm_torch_level, target)
+            if isinstance(submodule, torch.nn.Module):
+                for name, _ in submodule.named_parameters(
+                    recurse=True, remove_duplicate=False
+                ):
+                    params_buffers_to_node_meta[target + "." + name] = meta
+
+                for name, _ in submodule.named_buffers(
+                    recurse=True, remove_duplicate=False
+                ):
+                    params_buffers_to_node_meta[target + "." + name] = meta
+
+        if node.op == "get_attr":
+            submodule = getattr(gm_torch_level, target)
+            if not isinstance(submodule, torch.fx.GraphModule):
+                params_buffers_to_node_meta[target] = meta
+
+        # If the call_function uses param as input, we also need to update params' meta
+        # with this call_function node's meta.
+        # This is basically the same flow as torch.fx.traceback.preserve_meta()
+        if node.op == "call_function" and not isinstance(
+            node.target, torch._ops.HigherOrderOperator
+        ):
+            for arg in node._input_nodes:
+                if arg.op == "get_attr":
+                    for entry in torch.fx.proxy._COPY_META_FIELDS:
+                        if entry in meta:
+                            params_buffers_to_node_meta[arg.target][entry] = meta[entry]
+
+    # Fix the graph output signature to be tuple if scalar
+    out_spec = orig_out_spec = gm_torch_level._out_spec
+    assert out_spec is not None
+    # aot_export expect the return type to always be a tuple.
+    if out_spec.type not in (list, tuple):
+        out_spec = pytree.TreeSpec(tuple, None, [out_spec])
+
+    orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]
+
+    gm_torch_level.graph._codegen = _PyTreeCodeGen(
+        _PyTreeInfo(
+            orig_arg_names,
+            gm_torch_level._in_spec,
+            out_spec,
+        )
+    )
+    gm_torch_level.recompile()
+
+    _normalize_nn_module_stack(gm_torch_level, type(mod))
+
+    # NOTE: graph module expects only positional args
+    ep_non_strict = _export_non_strict(
+        gm_torch_level,
+        _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
+        {},
+        fake_params_buffers,
+        constant_attrs,
+        pre_dispatch=pre_dispatch,
+    )
+
+    gm = ep_non_strict.gm
+    export_graph_signature = ep_non_strict.sig
+    constants = ep_non_strict.constants
+
+    # After aot_export, set the param/buffer metadata back into placeholders
+    # Technically, users can still construct this data from param names
+    # without relying on this metadata
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if node.target in export_graph_signature.inputs_to_parameters:
+                param_name = export_graph_signature.inputs_to_parameters[node.target]
+                if param_name in params_buffers_to_node_meta:
+                    for k, v in params_buffers_to_node_meta[param_name].items():
+                        node.meta[k] = v
+            if node.target in export_graph_signature.inputs_to_buffers:
+                buffer_name = export_graph_signature.inputs_to_buffers[node.target]
+                if buffer_name in params_buffers_to_node_meta:
+                    for k, v in params_buffers_to_node_meta[buffer_name].items():
+                        node.meta[k] = v
+
+    # The unbacked symint symbols are updated in aot_export
+    # so we serialize them here instead of inside dynamo
+
+    gm.meta["inline_constraints"] = {
+        k: v
+        for k, v in dynamo_fake_mode.shape_env.var_to_range.items()
+        if free_unbacked_symbols(k)
+    }
+
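+    # Number of inputs lifted by aot_export (params/buffers/constants): the
+    # index of the first USER_INPUT spec, or all inputs if there is no user input.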
+    num_lifted = next(
+        (
+            i
+            for i, s in enumerate(export_graph_signature.input_specs)
+            if s.kind == InputKind.USER_INPUT
+        ),
+        len(export_graph_signature.input_specs),
+    )
+    range_constraints = _process_constraints(
+        dynamo_fake_mode,
+        gm,
+        num_lifted,
+        flat_args,
+    )
+
+    # Do some cleanups on the graph module to restore the state dict to the
+    # expected form. Each of these steps should probably get fixed upstream.
+    # 1. Remove tensor constants that were added as buffers.
+    _rewrite_dynamo_tensor_constants(
+        orig_mod_buffers=set(mod.buffers()),
+        traced_mod_buffers=dict(gm_torch_level.named_buffers()),
+        graph_signature=ep_non_strict.sig,
+        constants=ep_non_strict.constants,
+    )
+    # 2. Restore FQN of param/buffers
+    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
+    _replace_param_buffer_names(param_buffer_table, export_graph_signature)
+
+    # 3. Remove non-persistent buffers from the graph signature
+    _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)
+
+    # 4. Rewrite constants to have the same FQN as the original module.
+    _remap_constants(constant_attrs, export_graph_signature, constants)
+
+    module_call_signatures = {
+        fqn: ModuleCallSignature(inputs=[], outputs=[], **specs)
+        for fqn, specs in gm_torch_level.meta["module_call_specs"].items()
+    }
+
+    if len(preserve_module_call_signature) > 0:
+        res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
+        assert res is not None
+        gm = res.graph_module
+
+    assert orig_out_spec is not None
+    exported_program = ExportedProgram(
+        root=gm,
+        graph=gm.graph,
+        graph_signature=export_graph_signature,
+        state_dict=mod.state_dict(keep_vars=True),
+        range_constraints=range_constraints,
+        module_call_graph=[
+            ModuleCallEntry(
+                "",
+                ModuleCallSignature(
+                    inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec
+                ),
+            )
+        ]
+        + [ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
+        example_inputs=(args, kwargs),
+        constants=constants,
+    )
+    log.debug("Exported program from AOTAutograd:\n%s", exported_program)
+
+    if len(range_constraints) > 0:
+        exported_program = exported_program._transform_do_not_use(
+            _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)
+        )
+
+    return exported_program
diff --git a/MLPY/Lib/site-packages/torch/export/_tree_utils.py b/MLPY/Lib/site-packages/torch/export/_tree_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..827793ae94febdfd10a869f48720b8f41eea2a40
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_tree_utils.py
@@ -0,0 +1,64 @@
+from typing import Any, Callable, Dict, Optional
+
+from torch.utils._pytree import Context, TreeSpec
+
+
+def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
+    """Reorder user-provided kwargs to match the order in `spec`. `spec` is
+    expected to be the in_spec of an exported program, i.e. the spec that
+    results from flattening `(args, kwargs)`.
+
+    We need this to provide consistent input ordering, so that users can
+    pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result.
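+
+    Example (an illustrative sketch; ``ep`` is assumed to be an exported program)::
+
+        in_spec = ep.call_spec.in_spec              # spec of flattened (args, kwargs)
+        user_kwargs = {"b": b, "a": a}              # arbitrary key order
+        ordered = reorder_kwargs(user_kwargs, in_spec)
+        # keys of ``ordered`` now follow the order recorded in the spec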
+    """
+    # Make sure that the spec is actually shaped like (args, kwargs)
+    assert spec.type is tuple
+    assert spec.num_children == 2
+    kwargs_spec = spec.children_specs[1]
+    assert kwargs_spec.type is dict
+
+    if set(user_kwargs) != set(kwargs_spec.context):
+        raise ValueError(
+            f"kwarg key mismatch: "
+            f"Got {list(user_kwargs)} but expected {kwargs_spec.context}"
+        )
+
+    reordered_kwargs = {}
+    for kw in kwargs_spec.context:
+        reordered_kwargs[kw] = user_kwargs[kw]
+
+    return reordered_kwargs
+
+
+def is_equivalent(
+    spec1: TreeSpec,
+    spec2: TreeSpec,
+    equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool],
+) -> bool:
+    """Customizable equivalence check for two TreeSpecs.
+
+    Arguments:
+        spec1: The first TreeSpec to compare
+        spec2: The second TreeSpec to compare
+        equivalence_fn: A function to determine the equivalence of two
+            TreeSpecs by examining their types and contexts. It will be called like:
+
+                equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context)
+
+            This function will be applied recursively to all children.
+
+    Returns:
+        True if the two TreeSpecs are equivalent, False otherwise.
+    """
+    if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
+        return False
+
+    # Recurse on children
+    if len(spec1.children_specs) != len(spec2.children_specs):
+        return False
+
+    for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
+        if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
+            return False
+
+    return True
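+
+
+# A minimal usage sketch (illustrative): compare two TreeSpecs structurally,
+# requiring matching node types but ignoring their context payloads:
+#
+#     same_structure = is_equivalent(
+#         spec_a, spec_b, lambda t1, c1, t2, c2: t1 is t2
+#     )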
diff --git a/MLPY/Lib/site-packages/torch/export/_unlift.py b/MLPY/Lib/site-packages/torch/export/_unlift.py
new file mode 100644
index 0000000000000000000000000000000000000000..0171c94ddc7d0372732c98aace680e6f8d565946
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/_unlift.py
@@ -0,0 +1,314 @@
+import copy
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from torch._export.utils import _check_input_constraints_for_graph
+from torch.export.unflatten import _assign_attr, _AttrKind
+from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
+from ._remove_effect_tokens_pass import _remove_effect_tokens
+
+from .exported_program import (
+    ExportedProgram,
+    ExportGraphSignature,
+    InputKind,
+    OutputKind,
+)
+
+
+@torch._dynamo.disable
+def _check_input_constraints_pre_hook(self, *args, **kwargs):
+    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)
+
+    if received_spec != self._in_spec:
+        raise ValueError(  # noqa: TRY200
+            "Trying to flatten user inputs with exported input tree spec: \n"
+            f"{self._in_spec}\n"
+            "but actually got inputs with tree spec of: \n"
+            f"{received_spec}"
+        )
+
+    return _check_input_constraints_for_graph(
+        [node for node in self.graph.nodes if node.op == "placeholder"],
+        flat_args_with_path,
+        self.range_constraints,
+    )
+
+
+def _unlift_inputs_as_getattr(
+    gm: torch.fx.GraphModule,
+    lifted_inputs: List[Optional[str]],
+) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
+    """
+    Unlift inputs referring to params/buffers/constants as getattr nodes in the
+    graph
+    """
+    unlifted_name_to_node = {}
+    input_name_to_node = {}
+
+    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
+    assert len(lifted_inputs) == len(placeholder_nodes)
+    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
+        if lifted_node is None:
+            input_name_to_node[input_node.name] = input_node
+
+        else:
+            with gm.graph.inserting_after(input_node):
+                getattr_node = gm.graph.get_attr(lifted_node)
+                input_node.replace_all_uses_with(getattr_node)
+                metadata = input_node.meta
+                gm.graph.erase_node(input_node)
+                getattr_node.meta = metadata
+                unlifted_name_to_node[lifted_node] = getattr_node
+
+    return unlifted_name_to_node, input_name_to_node
+
+
+def _insert_copy_for_mutations(
+    gm: torch.fx.GraphModule,
+    mutated_outputs: List[Optional[str]],
+    unlifted_name_to_node: Dict[str, torch.fx.Node],
+    input_name_to_node: Dict[str, torch.fx.Node],
+) -> None:
+    """
+    Find all the buffers and inputs that were mutated and insert copy_
+    operators to reflect mutations.
+    """
+    output_node = None
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            output_node = node
+            break
+    assert output_node is not None
+    outputs = pytree.tree_flatten(output_node.args)[0]
+    assert len(outputs) == len(mutated_outputs)
+
+    user_output_nodes = []
+    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
+        if mutated_node_name is None:
+            user_output_nodes.append(return_node)
+            continue
+
+        if mutated_node_name in unlifted_name_to_node:
+            mutated_node = unlifted_name_to_node[mutated_node_name]
+        elif mutated_node_name in input_name_to_node:
+            mutated_node = input_name_to_node[mutated_node_name]
+        else:
+            raise RuntimeError(
+                f"Could not find {mutated_node_name} in either buffer or input nodes"
+            )
+
+        with gm.graph.inserting_before(output_node):
+            _ = gm.graph.call_function(
+                torch.ops.aten.copy_.default, (mutated_node, return_node)
+            )
+
+    with gm.graph.inserting_before(output_node):
+        # Only return user outputs
+        new_output = gm.graph.output(tuple(user_output_nodes))
+        output_node.replace_all_uses_with(new_output)
+        gm.graph.erase_node(output_node)
+
+
+def _get_codegen(
+    in_spec: pytree.TreeSpec,
+    out_spec: Optional[pytree.TreeSpec],
+) -> _PyTreeCodeGen:
+    """
+    Create the codegen for the graph module based on the in/out specs
+    """
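+    # Example (illustrative): an in_spec flattened from ((a, b), {"c": c})
+    # yields names ["arg_0", "arg_1", "c"]; any other spec shape falls back to
+    # purely positional names ["arg_0", ..., "arg_{n-1}"].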
+    if (
+        in_spec.type == tuple
+        and in_spec.num_children == 2
+        and in_spec.children_specs[0].type == tuple
+        and in_spec.children_specs[1].type == dict
+    ):
+        # if in_spec contains the args (tuple) and kwargs (dict)
+        names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
+        # add kwarg names
+        names.extend(in_spec.children_specs[1].context)
+    else:
+        names = [f"arg_{i}" for i in range(in_spec.num_children)]
+
+    return _PyTreeCodeGen(
+        _PyTreeInfo(
+            names,
+            in_spec,
+            out_spec,
+        )
+    )
+
+
+def _unlift(
+    gm: torch.fx.GraphModule,
+    lifted_inputs: List[Optional[str]],
+    mutated_outputs: List[Optional[str]],
+    in_spec: pytree.TreeSpec,
+    out_spec: Optional[pytree.TreeSpec],
+    state_dict: Dict[str, Any],
+    constants: Dict[str, Any],
+):
+    """
+    Args:
+        lifted_inputs: A list matching the graph module's input nodes. For
+        an input node that is referring to a lifted parameter/buffer, this
+        list will contain the fqn of the corresponding attribute. Otherwise, this
+        list will contain None. This is used to unlift the lifted parameters as
+        get_attr nodes.
+
+        mutated_outputs: A list matching the graph module's output nodes. For
+        an output node that is referring to a mutated buffer or user input, this
+        list will contain the name of the corresponding buffer or user input
+        that needs to be mutated. Otherwise, this list will contain None. This
+        is used to re-insert an inplace copy_ operator to copy the mutated
+        values back to the original node.
+    """
+    unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
+        gm, lifted_inputs
+    )
+    _insert_copy_for_mutations(
+        gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
+    )
+    gm.graph._codegen = _get_codegen(in_spec, out_spec)
+    gm.graph.lint()
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+    return gm
+
+
+def _register_attrs_to_new_gm(
+    new_gm: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    state_dict: Dict[str, Any],
+    constants: Dict[str, Any],
+) -> None:
+    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
+    for name in graph_signature.buffers:
+        if name in non_persistent_buffers:
+            persistent = False
+            value = constants[name]
+        else:
+            persistent = True
+            value = state_dict[name]
+        _assign_attr(
+            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
+        )
+    for name in graph_signature.parameters:
+        value = state_dict[name]
+        _assign_attr(
+            value,
+            new_gm,
+            name,
+            attr_kind=_AttrKind.PARAMETER,
+        )
+
+    for name in chain(
+        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
+    ):
+        value = constants[name]
+        _assign_attr(
+            value,
+            new_gm,
+            name,
+            attr_kind=_AttrKind.CONSTANT,
+        )
+
+
+class _StatefulGraphModuleFactory(type):
+    """
+    Metaclass that ensures a private constructor for _StatefulGraphModule
+    """
+
+    def __call__(cls, *args, **kwargs):
+        raise TypeError(
+            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
+        )
+
+    def _create(cls, root, graph, range_constraints=None):
+        return super().__call__(
+            root,
+            graph,
+            range_constraints=range_constraints,
+        )
+
+
+class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
+    def __init__(self, root, graph, range_constraints=None):
+        super().__init__(root, graph)
+        # Need to fix up non-persistent buffers.
+        self.range_constraints = range_constraints or []
+
+
+def _create_stateful_graph_module(
+    plain_graph_module: torch.fx.GraphModule,
+    range_constraints,
+    # TODO(suo): this should not be optional, but it is for now because we
+    # still have capture_pre_autograd_graph
+    graph_signature: Optional[ExportGraphSignature] = None,
+):
+    stateful_gm = _StatefulGraphModule._create(
+        plain_graph_module,
+        plain_graph_module.graph,
+        range_constraints=range_constraints,
+    )
+    stateful_gm.register_forward_pre_hook(
+        _check_input_constraints_pre_hook, with_kwargs=True
+    )
+
+    if graph_signature is None:
+        return stateful_gm
+    # Fix up non-persistent buffers. torch.fx does not distinguish between
+    # persistent and non-persistent buffers, so we must restore that distinction
+    # here.
+    for buffer in graph_signature.non_persistent_buffers:
+        _assign_attr(
+            plain_graph_module.get_buffer(buffer),
+            stateful_gm,
+            buffer,
+            attr_kind=_AttrKind.BUFFER,
+            persistent=False,
+        )
+
+    return stateful_gm
+
+
+def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
+    ep = _remove_effect_tokens(ep)
+    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
+    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
+
+    lifted_inputs: List[Optional[str]] = [
+        in_spec.target
+        if in_spec.kind
+        in (
+            InputKind.BUFFER,
+            InputKind.CONSTANT_TENSOR,
+            InputKind.PARAMETER,
+            InputKind.CUSTOM_OBJ,
+        )
+        else None
+        for in_spec in ep.graph_signature.input_specs
+    ]
+
+    mutated_outputs: List[Optional[str]] = [
+        out_spec.target
+        if out_spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
+        else None
+        for out_spec in ep.graph_signature.output_specs
+    ]
+
+    new_gm = _unlift(
+        new_gm,
+        lifted_inputs,
+        mutated_outputs,
+        ep.call_spec.in_spec,
+        ep.call_spec.out_spec,
+        ep.state_dict,
+        ep.constants,
+    )
+    unlift_gm = _create_stateful_graph_module(
+        new_gm, ep.range_constraints, ep.graph_signature
+    )
+    unlift_gm.meta.update(ep.graph_module.meta)
+    return unlift_gm
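+
+
+# Illustrative note: this helper is what typically backs ``ExportedProgram.module()``.
+# A hypothetical flow looks like:
+#
+#     unlifted = _unlift_exported_program_lifted_states(ep)
+#     out = unlifted(*example_args)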
diff --git a/MLPY/Lib/site-packages/torch/export/custom_obj.py b/MLPY/Lib/site-packages/torch/export/custom_obj.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b04215c31fb79af34511606600a856bc5ba6a8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/custom_obj.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+
+
+__all__ = ["ScriptObjectMeta"]
+
+
+@dataclass
+class ScriptObjectMeta:
+    """
+    Metadata which is stored on nodes representing ScriptObjects.
+    """
+
+    # Key into constants table to retrieve the real ScriptObject.
+    constant_name: str
+
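+    # Fully qualified class name of the ScriptObject's type.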
+    class_fqn: str
diff --git a/MLPY/Lib/site-packages/torch/export/dynamic_shapes.py b/MLPY/Lib/site-packages/torch/export/dynamic_shapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34447908ca3a836598c3e237e24da789cc7900d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/dynamic_shapes.py
@@ -0,0 +1,876 @@
+import builtins
+import dataclasses
+import inspect
+import math
+import sys
+import weakref
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.utils._pytree import SUPPORTED_NODES
+
+from .exported_program import ExportedProgram
+
+if TYPE_CHECKING:
+    from sympy import Symbol
+
+    from torch._guards import Source
+
+    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint
+
+__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]
+
+
+class _Dim(type):
+    """
+    Metaclass for :func:`Dim` types.
+    """
+
+    @staticmethod
+    def readable(name, min_, max_):
+        if min_ == 2:
+            min_ = None
+        if max_ == sys.maxsize - 1:
+            max_ = None
+        if min_ is None and max_ is None:
+            return f"Dim('{name}')"
+        if min_ is None:
+            return f"Dim('{name}', max={max_})"
+        if max_ is None:
+            return f"Dim('{name}', min={min_})"
+        return f"Dim('{name}', min={min_}, max={max_})"
+
+    def __add__(cls, other):
+        # e.g., dim + 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return cls._derive(lambda x: x + other)
+
+    def __radd__(cls, other):
+        return cls + other
+
+    def __sub__(cls, other):
+        # e.g., dim - 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return cls._derive(lambda x: x - other)
+
+    def __rsub__(cls, other):
+        raise NotImplementedError(
+            f"Attempted to negate {cls.__name__}. "
+            "(Only increasing linear operations with integer coefficients are supported.)"
+        )
+
+    def __mul__(cls, other):
+        # e.g., dim * 2
+        if type(other) is not int or other <= 0:
+            raise NotImplementedError(
+                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return cls._derive(lambda x: x * other)
+
+    def __rmul__(cls, other):
+        return cls * other
+
+    def _derived_name(cls, fn):
+        from sympy import sympify
+
+        return str(fn(sympify(cls.__name__)))
+
+    def _derive(cls, fn):
+        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})
+
+
+class _DerivedDim(_Dim):
+    """
+    Metaclass for derived :func:`Dim` types.
+
+    Currently we only support increasing linear expressions with integer coefficients.
+    In other words, a derived Dim can always be written in the form Ax + B, where
+    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
+    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
+    These restrictions on the form of derived Dims make the metatheory simpler: e.g.,
+    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
+    deciding equalities between derived Dims, and so on.
+
+    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
+    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
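+
+    For example (an illustrative sketch)::
+
+        dx = Dim("dx", min=4, max=256)
+        dy = 2 * dx + 1   # derived Dim: root is ``dx``, fn maps x to 2*x + 1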
+    """
+
+    @property
+    def min(self):
+        # assume that self.fn is an increasing function
+        # TODO(avik): use sympy value range analysis instead?
+        from sympy import Integer
+
+        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
+        assert _min_symint >= 2, (
+            f"Expected derived min value of {self.__name__} to be >= 2. "
+            f"Please specify an appropriate min value for {self.root.__name__} "  # type: ignore[attr-defined]
+            f"(currently {self.root.min})."  # type: ignore[attr-defined]
+        )
+        return int(_min_symint)
+
+    @property
+    def max(self):
+        # assume that self.fn is an increasing function
+        # TODO(avik): use sympy value range analysis instead?
+        from sympy import Integer
+
+        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
+        assert _max_symint <= sys.maxsize - 1, (
+            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
+            f"Please specify an appropriate max value for {self.root.__name__} "  # type: ignore[attr-defined]
+            f"(currently {self.root.max})."  # type: ignore[attr-defined]
+        )
+        return int(_max_symint)
+
+    def _derive(self, fn):
+        # We support nesting, e.g., 2*dim + 1.
+        # This is implemented by composing operations on the same root.
+        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
+        return _DerivedDim(
+            self._derived_name(fn),
+            (int,),
+            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
+        )
+
+
+def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
+    """
+    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
+    It can be used to describe multiple possible values of a dynamic tensor dimension.
+    Note that different dynamic dimensions of the same tensor, or of different tensors,
+    can be described by the same type.
+
+    Args:
+        name (str): Human-readable name for debugging.
+        min (Optional[int]): Minimum possible value of given symbol (inclusive)
+        max (Optional[int]): Maximum possible value of given symbol (inclusive)
+
+    Returns:
+        A type that can be used in dynamic shape specifications for tensors.
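+
+    Example (an illustrative sketch, assuming ``mod.forward`` takes a single tensor ``x``)::
+
+        batch = Dim("batch", max=1024)
+        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: batch}})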
+    """
+    _min = 2 if min is None else builtins.max(min, 2)
+    _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
+    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
+    dim = _Dim(name, (int,), {"min": _min, "max": _max})
+    dim.__module__ = getattr(
+        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
+    )
+    return dim
+
+
+def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
+    """
+    Util to create multiple :func:`Dim` types.
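+
+    Example (illustrative)::
+
+        H, W = dims("H", "W", max=4096)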
+    """
+    return tuple(Dim(name, min=min, max=max) for name in names)
+
+
+@dataclasses.dataclass
+class _ConstraintTarget:
+    """
+    This represents input tensor dimensions.  Don't create this
+    class directly; instead, use :func:`dynamic_dim`.
+    """
+
+    w_tensor: Any  # weakref to torch.Tensor
+    # TODO: We don't need t_id; we can get it off of w_tensor
+    t_id: int
+    dim: int
+
+
+class _ConstraintFactory(type):
+    """
+    Metaclass that ensures a private constructor for :class:`_Constraint`
+    """
+
+    def __call__(cls, *args, **kwargs):
+        raise TypeError(
+            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
+            f"Please use torch.export.dynamic_dim() to create one"
+        )
+
+    def _create(
+        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
+    ):
+        return super().__call__(
+            w_tensor, t_id, dim, constraint_range, shared, debug_name
+        )
+
+
+def _create_constraint(
+    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
+):
+    return _Constraint._create(
+        w_tensor, t_id, dim, constraint_range, shared, debug_name
+    )
+
+
+@dataclasses.dataclass
+class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
+    """
+
+    .. warning::
+        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.
+
+    This represents constraints on input tensor dimensions, e.g., requiring
+    them to be fully polymorphic or within some range.
+
+    """
+
+    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, ]
+    constraint_range: "StrictMinMaxConstraint"
+    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
+    # typically arises because of a specified equality with another dynamic dimension.
+    shared: Optional[_ConstraintTarget] = None
+    debug_name: Optional[str] = None
+
+    def _clone_with_range(self, lower=2, upper=math.inf):
+        # Import sympy locally
+        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+        from torch.utils._sympy.value_ranges import ValueRanges
+
+        constraint_range = StrictMinMaxConstraint(
+            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
+            warn_only=False,
+        )
+        return _create_constraint(
+            self.w_tensor,
+            self.t_id,
+            self.dim,
+            constraint_range,
+            self.shared,
+            self.debug_name,
+        )
+
+    def __ge__(self, lower):
+        return self._clone_with_range(lower=lower)
+
+    def __gt__(self, lower):
+        return self._clone_with_range(lower=lower + 1)
+
+    def __le__(self, upper):
+        return self._clone_with_range(upper=upper)
+
+    def __lt__(self, upper):
+        return self._clone_with_range(upper=upper - 1)
+
+    def __bool__(self):
+        # NOTE(avik): We do not support compound expressions like a <= x <= b.
+        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
+        # and moreover, enforces that any overload of __bool__ must return True or False.
+        # FWIW, sympy also raises TypeError in this case.
+        raise TypeError(
+            "Cannot determine truth value of _Constraint. "
+            "If you are trying to combine _Constraint's with logical connectives, "
+            "you can specify them separately instead."
+        )
+
+    @property
+    def serializable_spec(self):
+        # We need a serialization compatible format of the constraint so that it
+        # can be saved in the graph module without breaking the module serialization.
+        # The saved constraints will be used directly for the post-exporting pass
+        # that converts constraints to runtime assertion. The saved constraints
+        # will not be saved in the serialized module.
+        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
+        # which is not reliable
+        return {
+            "t_id": self.t_id,
+            "dim": self.dim,
+            "min": self.constraint_range.vr.lower,
+            "max": self.constraint_range.vr.upper,
+        }
+
+    def __eq__(self, other):
+        if not isinstance(other, _Constraint):
+            raise TypeError(
+                "A dynamic dim can be specified equal only to another dynamic dim. "
+                f"Equality with {type(other)} is not supported."
+            )
+
+        # import sympy locally
+        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+
+        constraint_range = StrictMinMaxConstraint(
+            vr=self.constraint_range.vr & other.constraint_range.vr,
+            warn_only=False,
+        )
+        if self.debug_name is None:
+            debug_name = other.debug_name
+        else:
+            assert other.debug_name is None or self.debug_name == other.debug_name
+            debug_name = self.debug_name
+        return _create_constraint(
+            self.w_tensor,
+            self.t_id,
+            self.dim,
+            constraint_range,
+            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
+            debug_name=debug_name,
+        )
+
+
+@dataclasses.dataclass
+class _PhantomRoot:
+    """
+    This represents the root of a derived Dim where the root does not directly
+    specify the shape of any input dimension, but the derived Dim does.
+
+    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.
+
+    The fields `name`, `constraint_range`, and `val` carried by a phantom root
+    help create a symbol for it. Any derived dims with this phantom root are
+    backed by expressions over this symbol.
+    """
+
+    name: str
+    constraint_range: "StrictMinMaxConstraint"
+    val: int
+
+
+@dataclasses.dataclass
+class _DerivedConstraint(_ConstraintTarget):
+    """
+    This represents a derived Dim, whose root is either a regular constraint target
+    (which directly specifies the shape of some input dimension) or a phantom root
+    (which does so indirectly).
+    """
+
+    # NOTE: This is not currently a subclass of _Constraint because we do not support
+    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
+    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
+    # reusing the same (derived or normal) `Dim`.
+    root: Union[_ConstraintTarget, _PhantomRoot]
+    fn: Callable
+    constraint_range: "StrictMinMaxConstraint"
+    debug_name: Optional[str] = None
+
+    @property
+    def shared(self):
+        # Some code paths expect a union of _Constraint and _DerivedConstraint.
+        # Thus we expose a `shared` field that is always None.
+        # TODO(avik): clean this up
+        return None
+
+    @property
+    def serializable_spec(self):
+        # same as _Constraint.serializable_spec
+        return {
+            "t_id": self.t_id,
+            "dim": self.dim,
+            "min": self.constraint_range.vr.lower,
+            "max": self.constraint_range.vr.upper,
+        }
+
+
+Constraint = Union[_Constraint, _DerivedConstraint]
+
+
+def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
+    """
+    .. warning::
+        (This feature is DEPRECATED. See :func:`Dim` instead.)
+
+    :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of
+    a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to
+    ``constraints`` argument of :func:`export`.
+
+    Args:
+        t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
+        index (int): Index of dynamic dimension
+
+    Returns:
+        A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so
+        that :func:`export` does not assume a static size for the specified tensor, i.e. keeping it dynamic
+        as a symbolic size rather than specializing according to the size of the example tracing input.
+
+    Specifically, :func:`dynamic_dim` can be used to express the following types of dynamism.
+
+    - Size of a dimension is dynamic and unbounded::
+
+        t0 = torch.rand(2, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can have a dynamic size rather than always being static size 2
+        constraints = [dynamic_dim(t0, 0)]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic with a lower bound::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can have a dynamic size with a lower bound of 5 (inclusive)
+        # Second dimension of t1 can have a dynamic size with a lower bound of 2 (exclusive)
+        constraints = [
+            dynamic_dim(t0, 0) >= 5,
+            dynamic_dim(t1, 1) > 2,
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic with an upper bound::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can have a dynamic size with an upper bound of 16 (inclusive)
+        # Second dimension of t1 can have a dynamic size with an upper bound of 8 (exclusive)
+        constraints = [
+            dynamic_dim(t0, 0) <= 16,
+            dynamic_dim(t1, 1) < 8,
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # Sizes of the second dimension of t0 and the first dimension of t1 are always equal
+        constraints = [
+            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Mix and match all types above as long as they do not express conflicting requirements
+
+    """
+    from torch._dynamo.exc import UserError, UserErrorType
+
+    if not isinstance(t, torch.Tensor):
+        raise UserError(
+            UserErrorType.DYNAMIC_DIM,
+            f"Expected tensor as input to dynamic_dim but got {type(t)}",
+        )
+
+    if t.dim() < 1:
+        raise UserError(
+            UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic"
+        )
+
+    if index >= t.dim():
+        raise UserError(
+            UserErrorType.DYNAMIC_DIM,
+            f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
+            f" but got {index}, which is out of bounds for the given tensor.",
+        )
+
+    # Import sympy locally
+    import sympy
+
+    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+    from torch.utils._sympy.value_ranges import ValueRanges
+
+    return _create_constraint(
+        weakref.ref(t),
+        id(t),
+        index,
+        StrictMinMaxConstraint(
+            vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False
+        ),
+        debug_name=debug_name,
+    )
+
+
+def _process_equalities(
+    constraint: Constraint,
+    get_sources: Callable[[int, int], List["Source"]],
+    shape_env: "ShapeEnv",
+    source_pairs: List[Tuple["Source", "Source"]],
+    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
+    phantom_symbols: Dict[str, "Symbol"],
+):
+    """
+    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
+    fields of `EqualityConstraint`) based on a given input `constraint`.
+    """
+
+    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
+    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
+    # constraints that make src0 "equal" to src1, ..., srcN.
+    source_pairs.extend((source, other_source) for other_source in other_sources)
+    if not isinstance(constraint, _DerivedConstraint):
+        if constraint.shared is not None:
+            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
+            # and t'.size()[dim'] maps to src1', ..., srcN', we add
+            # constraints that also make src0 "equal" to src1', ..., srcN'.
+            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
+            source_pairs.extend(
+                (source, other_source) for other_source in other_sources
+            )
+    else:
+        # branch based on the root of the _DerivedConstraint
+        if not isinstance(constraint.root, _PhantomRoot):
+            # either root points to an input source
+            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
+        else:
+            # or root points to a phantom symbol
+            if constraint.root.name in phantom_symbols:
+                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
+            else:
+                # create a phantom symbol in the shape env based on the _PhantomRoot
+                root = shape_env.create_symbol(
+                    val=constraint.root.val,
+                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
+                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
+                    constraint_dim=constraint.root.constraint_range,
+                )
+                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]
+
+        fn = constraint.fn
+        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
+        # Here source describes an input and root might describe another input or a phantom symbol.
+        derived_equalities.append((source, root, fn))
+
+
+def _process_dynamic_shapes(
+    f: Callable,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+) -> Optional[List[Constraint]]:
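+    # Converts the user-facing ``dynamic_shapes`` specification (built from
+    # :func:`Dim` objects and/or None entries) into a list of internal
+    # ``Constraint`` objects. Illustrative example: for ``def forward(self, x)``
+    # with ``x`` of shape (8, 3) and ``dynamic_shapes={"x": {0: Dim("batch")}}``,
+    # the result is a single constraint on dimension 0 of ``x``.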
+    from collections import defaultdict
+    from collections.abc import Mapping, Sequence
+
+    from torch._dynamo.exc import UserError, UserErrorType
+
+    if dynamic_shapes is None or len(dynamic_shapes) == 0:
+        return None
+
+    kwargs = kwargs if kwargs is not None else {}
+
+    def tree_zip(combined_args, dynamic_shapes):
+        if isinstance(combined_args, (tuple, list)):
+            if not isinstance(dynamic_shapes, Sequence):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
+                    f"got {dynamic_shapes} instead",
+                )
+            if len(combined_args) != len(dynamic_shapes):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
+                )
+            for i, shape in enumerate(dynamic_shapes):
+                yield from tree_zip(combined_args[i], shape)
+        elif isinstance(combined_args, dict):
+            if not isinstance(dynamic_shapes, Mapping):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
+                    f"got {dynamic_shapes} instead",
+                )
+            if len(combined_args) != len(dynamic_shapes):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
+                )
+            for k, shape in dynamic_shapes.items():
+                yield from tree_zip(combined_args[k], shape)
+        elif type(combined_args) in SUPPORTED_NODES:
+            if not isinstance(dynamic_shapes, Sequence):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected dynamic_shapes of a user-registered class (e.g., "
+                    f"{type(combined_args)}) to be a Sequence that matches the "
+                    f"flattened structure, but got {dynamic_shapes} instead",
+                )
+            yield from tree_zip(
+                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
+                dynamic_shapes,
+            )
+        elif isinstance(combined_args, torch.Tensor):
+            yield (combined_args, dynamic_shapes)
+        else:
+            if dynamic_shapes is not None:
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
+                    f"got {dynamic_shapes} instead",
+                )
+
+    # map of Dim names representing input shape dimensions to constraints on them
+    symbols: Dict[str, List[Constraint]] = defaultdict(list)
+    # track roots that do not directly represent input shape dimensions
+    phantom_roots: Dict[str, _PhantomRoot] = {}
+    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []
+
+    def to_constraint(dim, tensor, i):
+        import sympy
+
+        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
+        from torch.utils._sympy.solve import try_solve
+        from torch.utils._sympy.value_ranges import ValueRanges
+
+        def root_value():
+            # given tensor.shape[i] is the value of dim = fn(root),
+            # find the value of root
+            symbol = sympy.Symbol(dim.root.__name__, integer=True)
+            expr = dim.fn(symbol)
+            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
+            if solution is not None:
+                return int(solution[1])  # type: ignore[call-overload]
+            else:
+                raise UserError(  # noqa: TRY200
+                    UserErrorType.CONSTRAINT_VIOLATION,
+                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
+                    f"of the form {expr}, where {symbol} is an integer",
+                )
+
+        if isinstance(dim, _DerivedDim):
+            # generate a _DerivedConstraint where the root is:
+            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
+            # - or a _PhantomRoot (otherwise)
+            dim_root = dim.root  # type: ignore[attr-defined]
+            if dim_root.__name__ in symbols:
+                # root represents an input shape dimension
+                root_constraint = symbols[dim_root.__name__][0]
+                root = _ConstraintTarget(
+                    root_constraint.w_tensor,
+                    root_constraint.t_id,
+                    root_constraint.dim,
+                )
+            elif dim_root.__name__ not in phantom_roots:
+                # create a phantom root
+                root = _PhantomRoot(  # type: ignore[assignment]
+                    name=dim_root.__name__,
+                    constraint_range=StrictMinMaxConstraint(
+                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
+                        warn_only=False,
+                    ),
+                    val=root_value(),
+                )
+                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
+            else:
+                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
+            constraint = _DerivedConstraint(
+                weakref.ref(tensor),
+                id(tensor),
+                i,
+                root,
+                dim.fn,  # type: ignore[attr-defined]
+                StrictMinMaxConstraint(
+                    vr=ValueRanges(lower=dim.min, upper=dim.max),
+                    warn_only=False,
+                ),
+                debug_name=dim.__name__,
+            )
+            if isinstance(root, _PhantomRoot):
+                # NOTE(avik): since we have not processed all inputs yet, we may replace this
+                # with a root that does represent an input shape dimension later (see below)
+                derived_constraints_with_phantom_root.append(constraint)
+        else:
+            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
+            if dim.min != 2:
+                constraint = constraint >= dim.min
+            if dim.max != sys.maxsize - 1:
+                constraint = constraint <= dim.max
+        return constraint
+
+    bounds: Dict[str, Tuple[int, int]] = {}
+
+    def check_same_bounds(dim):
+        if dim.__name__ in symbols:
+            min_, max_ = bounds[dim.__name__]
+            if dim.min != min_ or dim.max != max_:
+                this_ = _Dim.readable(dim.__name__, min_, max_)
+                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Found different definitions {this_} and {that_} "
+                    f"for the same symbolic dimension {dim}!",
+                )
+
+        else:
+            bounds[dim.__name__] = (dim.min, dim.max)
+
+    def update_symbols(tensor, shape):
+        if isinstance(shape, dict):
+            for i, dim in shape.items():
+                if isinstance(dim, _Dim):
+                    check_same_bounds(dim)
+                    constraint = to_constraint(dim, tensor, i)
+                    symbols[dim.__name__].append(constraint)
+                else:
+                    if dim is not None:
+                        raise UserError(
+                            UserErrorType.INVALID_INPUT,
+                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
+                            "try None instead",
+                        )
+        elif isinstance(shape, (tuple, list)):
+            for i, dim in enumerate(shape):
+                if isinstance(dim, _Dim):
+                    check_same_bounds(dim)
+                    constraint = to_constraint(dim, tensor, i)
+                    symbols[dim.__name__].append(constraint)
+                else:
+                    if dim is not None:
+                        raise UserError(
+                            UserErrorType.INVALID_INPUT,
+                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
+                            "try None instead",
+                        )
+        else:
+            if shape is not None:
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"Unexpected dynamic_shape {shape} of Tensor, try None instead",
+                )
+
+    import inspect
+
+    if isinstance(f, ExportedProgram):
+        f = f.module()
+    signature = (
+        inspect.signature(f.forward)
+        if isinstance(f, torch.nn.Module)
+        else inspect.signature(f)
+    )
+    combined_args = signature.bind(*args, **kwargs).arguments
+
+    # This means the user didn't specify dynamic shapes with argument names.
+    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
+    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
+        update_symbols(tensor, shape)
+
+    constraints = []
+    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
+        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
+        if phantom_root_name in symbols:
+            # We found an input shape dimension corresponding to this name, so we
+            # do not need a phantom symbol for it after all.
+            # NOTE(avik): Overall we want to maintain the invariant that roots that
+            # are phantom symbols are really "phantom," i.e., they cannot be represented
+            # by any input source. This is important when we are deciding derived equalities,
+            # since we can focus our attention exclusively on input sources: deciding
+            # derived equalities involving phantom symbols are, in comparison, trivial.
+            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]
+
+    for dynamic_dims in symbols.values():
+        if all(
+            isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims
+        ):
+            constraints.extend(dynamic_dims)
+        else:
+            primary, *others = dynamic_dims
+            if others:
+                for other in others:
+                    constraints.append(primary == other)  # type: ignore[arg-type]
+            else:
+                constraints.append(primary)
+
+    return constraints  # type: ignore[return-value]
+
+
+def _process_constraints(
+    fake_mode,
+    graph_module: torch.fx.GraphModule,
+    num_lifted_params_buffers: int,
+    example_inputs: List[torch.Tensor],
+) -> Dict:
+    """
+    Process the constraints stored in the graph module to return something more readable.
+
+    Args:
+        fake_mode: The fake tensor mode used during tracing; its shape_env is
+            consulted for the ranges of free symbols
+
+        graph_module (torch.fx.GraphModule): GraphModule returned from
+            dynamo.export, which contains the "input_shape_constraints" and
+            "inline_constraints" metadata
+
+        num_lifted_params_buffers: Number of leading placeholder nodes that
+            correspond to lifted parameters/buffers rather than example inputs
+
+        example_inputs: Flattened list of example inputs used to export the graph module
+
+    Returns:
+        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
+            symbols (from SymInts) appearing in the fake tensors in
+            node.meta["val"] to their range constraints, which are a tuple
+            containing (lower, upper) constraints.
+    """
+    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
+        InputDim,
+    )
+
+    # Import sympy locally
+    from torch.fx.experimental.symbolic_shapes import SymInt
+    from torch.utils._sympy.value_ranges import ValueRanges
+
+    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
+    inline_constraints = graph_module.meta.get("inline_constraints", [])
+
+    # Create dict mapping tensor_id to node names
+    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
+    # Create dict mapping placeholder node names to their nodes
+    placeholder_nodes: Dict[str, torch.fx.Node] = {}
+    for i, node in enumerate(graph_module.graph.nodes):
+        if node.op != "placeholder":
+            # All placeholder nodes should be together at the beginning of the
+            # graph
+            break
+        if i >= num_lifted_params_buffers:
+            example_input = example_inputs[i - num_lifted_params_buffers]
+            tensor_id_to_nodes[id(example_input)].append(node.name)
+            placeholder_nodes[node.name] = node
+
+    # Create dict mapping (node name, dim) to a list of range (lower, upper)
+    # constraints
+    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
+    for constraint in input_shape_constraints:
+        for node in tensor_id_to_nodes[constraint["t_id"]]:
+            node_dim = InputDim(node, constraint["dim"])
+
+            # Accumulate range constraints
+            multi_range_constraints[node_dim].append(
+                ValueRanges(constraint["min"], constraint["max"])
+            )
+
+    # Create dict mapping symbol to a singular range (lower, upper)
+    range_constraints: Dict[Any, ValueRanges] = {}
+
+    # Add inline constraints to range_constraints
+    range_constraints = {
+        symbol: inline_constraints[symbol] for symbol in inline_constraints
+    }
+
+    free_symbols: Set["Symbol"] = set()
+    # Add input range constraints to range_constraints
+    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
+        # Simplify the range constraints into a single range constraint
+        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
+        min_vals = [rc.lower for rc in multi_range_constraint]
+        max_vals = [rc.upper for rc in multi_range_constraint]
+        min_val = max(min_vals)  # type: ignore[type-var]
+        max_val = min(max_vals)  # type: ignore[type-var]
+        assert min_val <= max_val  # type: ignore[operator]
+
+        # Add input node range constraints
+        val = placeholder_nodes[input_dim.input_name].meta["val"]
+        assert isinstance(val, FakeTensor)
+        symint = val.shape[input_dim.dim]
+        assert isinstance(
+            symint, SymInt
+        ), f"Expected SymInt but got {symint}: {type(symint)}"
+        symbol = symint.node.expr
+        range_constraints[symbol] = ValueRanges(min_val, max_val)
+        free_symbols.update(symbol.free_symbols)
+
+    for symbol in free_symbols:
+        if symbol not in range_constraints:
+            # Placeholders can have symbolic shapes that are derived expressions.
+            # The above code will record direct range constraints for them
+            # so that we can do runtime assertions. In addition, for serde checks
+            # we want to record range constraints for their root symbols.
+            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]
+
+    return range_constraints
diff --git a/MLPY/Lib/site-packages/torch/export/exported_program.py b/MLPY/Lib/site-packages/torch/export/exported_program.py
new file mode 100644
index 0000000000000000000000000000000000000000..0093133ea91ca8129796b91bcb77d4987c3d1c6b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/exported_program.py
@@ -0,0 +1,745 @@
+import copy
+import dataclasses
+import functools
+import types
+import warnings
+from collections import namedtuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
+
+from torch.fx.immutable_collections import immutable_dict, immutable_list
+
+if TYPE_CHECKING:
+    # Import the following modules during type checking to enable code intelligence features,
+    # such as auto-completion in tools like pylance, even when these modules are not explicitly
+    # imported in user code.
+
+    import sympy
+
+    from torch.utils._sympy.value_ranges import ValueRanges
+
+import torch
+import torch.utils._pytree as pytree
+from torch.export._tree_utils import is_equivalent, reorder_kwargs
+from torch.fx._compatibility import compatibility
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
+
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra.pass_manager import PassManager
+
+from .graph_signature import (  # noqa: F401
+    _sig_to_specs,
+    ArgumentSpec,
+    ConstantArgument,
+    CustomObjArgument,
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    OutputKind,
+    OutputSpec,
+    SymIntArgument,
+    TensorArgument,
+)
+
+
+__all__ = [
+    "ExportedProgram",
+    "ModuleCallEntry",
+    "ModuleCallSignature",
+]
+
+
+PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
+
+
+@dataclasses.dataclass
+class ModuleCallSignature:
+    inputs: List[ArgumentSpec]
+    outputs: List[ArgumentSpec]
+    in_spec: pytree.TreeSpec
+    out_spec: pytree.TreeSpec
+
+
+@dataclasses.dataclass
+class ModuleCallEntry:
+    fqn: str
+    signature: Optional[ModuleCallSignature] = None
+
+
+def _disable_prexisiting_fake_mode(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with maybe_disable_fake_tensor_mode():
+            return fn(*args, **kwargs)
+
+    return wrapper
+
+
+def _fx_collection_equivalence_fn(
+    spec1_type: Optional[type],
+    spec1_context: pytree.Context,
+    spec2_type: Optional[type],
+    spec2_context: pytree.Context,
+) -> bool:
+    """Treat containers and their immutable variants as the same type. Otherwise
+    compare as normal.
+    """
+    if spec1_type is None or spec2_type is None:
+        return spec1_type is spec2_type and spec1_context == spec2_context
+
+    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
+        spec2_type, (dict, immutable_dict)
+    ):
+        return spec1_context == spec2_context
+
+    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
+        spec2_type, (list, immutable_list)
+    ):
+        return spec1_context == spec2_context
+
+    return spec1_type is spec2_type and spec1_context == spec2_context
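+
+# Rough illustration (not part of the upstream code): fx graphs store containers
+# as immutable_dict/immutable_list, so a spec captured at export time should
+# still compare equal to one built from plain builtins. A hypothetical check:
+#
+#     _, spec_plain = pytree.tree_flatten({"x": 1})
+#     _, spec_fx = pytree.tree_flatten(immutable_dict({"x": 1}))
+#     assert is_equivalent(spec_plain, spec_fx, _fx_collection_equivalence_fn)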
+
+
+class ExportedProgram:
+    """
+    Package of a program from :func:`export`. It contains
+    an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
+    tensor values of all lifted parameters and buffers, and various metadata.
+
+    You can call an ExportedProgram like the original callable traced by
+    :func:`export` with the same calling convention.
+
+    To perform transformations on the graph, use the ``.module()`` method to access
+    an :class:`torch.fx.GraphModule`. You can then use
+    `FX transformations <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
+    to rewrite the graph. Afterwards, you can simply use :func:`export`
+    again to construct a correct ExportedProgram.
+    """
+
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, Dict[str, Any]],
+        graph: torch.fx.Graph,
+        graph_signature: ExportGraphSignature,
+        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
+        range_constraints: "Dict[sympy.Symbol, Any]",
+        module_call_graph: List[ModuleCallEntry],
+        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
+        verifier: Optional[Type[Any]] = None,  # TODO Change typing hint to Verifier.
+        tensor_constants: Optional[
+            Dict[str, torch.Tensor]
+        ] = None,  # TODO: deprecate this
+        constants: Optional[
+            Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
+        ] = None,
+    ):
+        # Remove codegen related things from the graph. It should just be a flat graph.
+        graph._codegen = torch.fx.graph.CodeGen()
+        self._graph_module = _create_graph_module_for_export(root, graph)
+        if isinstance(root, torch.fx.GraphModule):
+            self._graph_module.meta.update(root.meta)
+
+        self._graph_signature: ExportGraphSignature = graph_signature
+        self._state_dict: Dict[str, Any] = state_dict
+        self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
+        assert module_call_graph is not None
+        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
+        self._example_inputs = example_inputs
+
+        self._constants = tensor_constants or constants or {}
+        assert self._constants is not None
+
+        from torch._export.verifier import Verifier
+
+        if verifier is None:
+            verifier = Verifier
+        assert issubclass(verifier, Verifier)
+        self._verifier = verifier
+        # Validate should be always the last step of the constructor.
+        self.verifier().check(self)
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def graph_module(self):
+        return self._graph_module
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def graph(self):
+        return self.graph_module.graph
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def graph_signature(self):
+        return self._graph_signature
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def state_dict(self):
+        return self._state_dict
+
+    @compatibility(is_backward_compatible=False)
+    def parameters(self) -> Iterator[torch.nn.Parameter]:
+        """
+        Returns an iterator over the original module's parameters.
+        """
+        for _, param in self.named_parameters():
+            yield param
+
+    @compatibility(is_backward_compatible=False)
+    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """
+        Returns an iterator over the original module's parameters, yielding
+        both the name of the parameter and the parameter itself.
+        """
+        for param_name in self.graph_signature.parameters:
+            yield param_name, self.state_dict[param_name]
+
+    @compatibility(is_backward_compatible=False)
+    def buffers(self) -> Iterator[torch.Tensor]:
+        """
+        Returns an iterator over the original module's buffers.
+        """
+        for _, buf in self.named_buffers():
+            yield buf
+
+    @compatibility(is_backward_compatible=False)
+    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        """
+        Returns an iterator over the original module's buffers, yielding
+        both the name of the buffer and the buffer itself.
+        """
+        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
+        for buffer_name in self.graph_signature.buffers:
+            if buffer_name in non_persistent_buffers:
+                yield buffer_name, self.constants[buffer_name]
+            else:
+                yield buffer_name, self.state_dict[buffer_name]
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def range_constraints(self):
+        return self._range_constraints
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def module_call_graph(self):
+        return self._module_call_graph
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def example_inputs(self):
+        return self._example_inputs
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def call_spec(self):
+        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])
+
+        if len(self.module_call_graph) == 0:
+            return CallSpec(in_spec=None, out_spec=None)
+        assert self.module_call_graph[0].fqn == ""
+        return CallSpec(
+            in_spec=self.module_call_graph[0].signature.in_spec,
+            out_spec=self.module_call_graph[0].signature.out_spec,
+        )
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def verifier(self) -> Any:
+        return self._verifier
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def dialect(self) -> str:
+        return self._verifier.dialect
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def tensor_constants(self):
+        return self._constants
+
+    @property
+    @compatibility(is_backward_compatible=False)
+    def constants(self):
+        return self._constants
+
+    def _get_flat_args_with_check(self, args, kwargs):
+        """Flatten args, kwargs using pytree, then, check specs.
+
+        Args:
+            args: List[Any] original args passed to __call__
+            kwargs: Dict[str, Any] original kwargs passed to __call
+
+        Returns:
+            A tuple of (flat_args, received_spec)
+            flat_args is flattend args / kwargs
+            received_spec is the pytree spec produced while flattening the
+            tuple (args, kwargs)
+        """
+        in_spec = self.call_spec.in_spec
+        if in_spec is not None:
+            kwargs = reorder_kwargs(kwargs, in_spec)
+        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
+            (args, kwargs)
+        )  # type: ignore[possibly-undefined]
+        self._check_input_constraints(flat_args_with_path)
+        flat_args = tuple(x[1] for x in flat_args_with_path)
+        return flat_args, received_spec
+
+    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
+        """Transform args, kwargs of __call__ to args for graph_module.
+
+        self.graph_module takes entries from the state dict as additional inputs.
+        The invariant, for ep: ExportedProgram, is
+        ep(args, kwargs) ==
+          ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
+        """
+
+        in_spec = self.call_spec.in_spec
+        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
+        if in_spec is not None and not is_equivalent(
+            received_spec, in_spec, _fx_collection_equivalence_fn
+        ):
+            raise ValueError(
+                "Trying to flatten user inputs with exported input tree spec: \n"
+                f"{in_spec}\n"
+                "but actually got inputs with tree spec of: \n"
+                f"{received_spec}"
+            )
+
+        additional_inputs = []
+        for input_ in self.graph_signature.input_specs:
+            if input_.kind == InputKind.USER_INPUT:
+                continue
+            elif input_.kind in (
+                InputKind.PARAMETER,
+                InputKind.BUFFER,
+            ):
+                if input_.persistent is False:
+                    # This is a non-persistent buffer, grab it from our
+                    # constants instead of the state dict.
+                    additional_inputs.append(self.constants[input_.target])
+                else:
+                    additional_inputs.append(self.state_dict[input_.target])
+            elif input_.kind in (
+                InputKind.CONSTANT_TENSOR,
+                InputKind.CUSTOM_OBJ,
+            ):
+                additional_inputs.append(self.constants[input_.target])
+        additional_inputs = tuple(additional_inputs)
+
+        # NOTE: calling convention is first params, then buffers, then args as user supplied them.
+        # See: torch/_functorch/aot_autograd.py#L1034
+        return additional_inputs + flat_args
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        raise RuntimeError(
+            "Unable to call ExportedProgram directly. "
+            "You should use `exported_program.module()` instead."
+        )
+
+    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
+        """Process potential mutations to the input.
+
+        Because self.graph_module is functional, mutations have to be written
+        back after execution of graph_module.
+        """
+        import torch._export.error as error
+
+        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
+        if self.call_spec.out_spec is not None:
+            buffer_mutation = self.graph_signature.buffers_to_mutate
+            user_input_mutation = self.graph_signature.user_inputs_to_mutate
+            num_mutated = len(buffer_mutation) + len(user_input_mutation)
+            mutated_values = res[:num_mutated]
+
+            # Exclude dependency token from final result.
+            assertion_dep_token = self.graph_signature.assertion_dep_token
+            if assertion_dep_token is not None:
+                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
+                res = res[:assertion_dep_token_index]
+
+            res = res[num_mutated:]
+            try:
+                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
+            except Exception:
+                _, received_spec = pytree.tree_flatten(res)
+                raise error.InternalError(  # noqa: TRY200
+                    "Trying to flatten user outputs with exported output tree spec: \n"
+                    f"{self.call_spec.out_spec}\n"
+                    "but actually got outputs with tree spec of: \n"
+                    f"{received_spec}"
+                )
+            finally:
+                user_inputs = [
+                    spec
+                    for spec in self.graph_signature.input_specs
+                    if spec.kind == InputKind.USER_INPUT
+                ]
+                for i, value in enumerate(mutated_values):
+                    output_spec = self.graph_signature.output_specs[i]
+                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
+                        assert output_spec.target is not None
+                        self.state_dict[output_spec.target] = value
+                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
+                        assert output_spec.target is not None
+                        index = next(
+                            i
+                            for i, spec in enumerate(user_inputs)
+                            if spec.arg.name == output_spec.target
+                        )
+                        flat_args[index].copy_(value)
+                    else:
+                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
+        return res
+
+    def __str__(self) -> str:
+        graph_module = self.graph_module.print_readable(print_output=False).replace(
+            "\n", "\n    "
+        )
+        string = (
+            "ExportedProgram:\n"
+            f"    {graph_module}\n"
+            f"Graph signature: {self.graph_signature}\n"
+            f"Range constraints: {self.range_constraints}\n"
+        )
+        return string
+
+    def module(self) -> torch.nn.Module:
+        """
+        Returns a self-contained GraphModule with all the parameters/buffers inlined.
+        """
+        from ._unlift import _unlift_exported_program_lifted_states
+
+        module = _unlift_exported_program_lifted_states(self)
+
+        def _train(self, mode: bool = True):
+            raise NotImplementedError("Calling train() is not supported yet.")
+
+        def _eval(self, mode: bool = True):
+            raise NotImplementedError("Calling eval() is not supported yet.")
+
+        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
+        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
+        return module
+
+    @_disable_prexisiting_fake_mode
+    def run_decompositions(
+        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
+    ) -> "ExportedProgram":
+        """
+        Run a set of decompositions on the exported program and return a new
+        exported program. By default we will run the Core ATen decompositions to
+        get operators in the
+        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
+
+        For now, we do not decompose joint graphs.
+        """
+        from torch._decomp import core_aten_decompositions
+        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
+            _AddRuntimeAssertionsForInlineConstraintsPass,
+        )
+        from torch._export.passes.lift_constants_pass import (
+            ConstantAttrMap,
+            lift_constants_pass,
+        )
+        from torch._export.passes.replace_sym_size_ops_pass import (
+            _replace_sym_size_ops_pass,
+        )
+        from torch._functorch.aot_autograd import aot_export_module
+
+        def _get_placeholders(gm):
+            placeholders = []
+            for node in gm.graph.nodes:
+                if node.op != "placeholder":
+                    break
+                placeholders.append(node)
+            return placeholders
+
+        decomp_table = decomp_table or core_aten_decompositions()
+
+        old_placeholders = _get_placeholders(self.graph_module)
+        fake_args = [node.meta["val"] for node in old_placeholders]
+
+        buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
+        for name in buffers_to_remove:
+            delattr(self.graph_module, name)
+        # TODO(zhxhchen17) Return the new graph_signature directly.
+        gm, graph_signature = aot_export_module(
+            self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
+        )
+
+        # Update the signatures with the new placeholder names in case they
+        # changed when calling aot_export
+        def update_arg(old_arg, new_ph):
+            if isinstance(old_arg, ConstantArgument):
+                return old_arg
+            elif isinstance(old_arg, TensorArgument):
+                return TensorArgument(name=new_ph.name)
+            elif isinstance(old_arg, SymIntArgument):
+                return SymIntArgument(name=new_ph.name)
+            raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
+
+        new_placeholders = _get_placeholders(gm)
+        new_outputs = list(gm.graph.nodes)[-1].args[0]
+
+        # To match the output target with correct input for input mutations
+        # need to find the old to new placeholder map
+        old_new_placeholder_map = {
+            spec.arg.name: new_placeholders[i].name
+            for i, spec in enumerate(self.graph_signature.input_specs)
+            if not isinstance(spec.arg, ConstantArgument)
+        }
+
+        input_specs = [
+            InputSpec(
+                spec.kind,
+                update_arg(spec.arg, new_placeholders[i]),
+                spec.target,
+                spec.persistent,
+            )
+            for i, spec in enumerate(self.graph_signature.input_specs)
+        ]
+        output_specs = [
+            OutputSpec(
+                spec.kind,
+                update_arg(spec.arg, new_outputs[i]),
+                old_new_placeholder_map.get(spec.target, spec.target),
+            )
+            for i, spec in enumerate(self.graph_signature.output_specs)
+        ]
+
+        assert len(new_placeholders) == len(old_placeholders)
+
+        new_graph_signature = ExportGraphSignature(
+            input_specs=input_specs, output_specs=output_specs
+        )
+        # NOTE: aot_export adds symint metadata for placeholders with int
+        # values; since these become specialized, we replace such metadata with
+        # the original values.
+        # Also, set the param/buffer metadata back to the placeholders.
+        for old_node, new_node in zip(old_placeholders, new_placeholders):
+            if not isinstance(old_node.meta["val"], torch.Tensor):
+                new_node.meta["val"] = old_node.meta["val"]
+
+            if (
+                new_node.target in new_graph_signature.inputs_to_parameters
+                or new_node.target in new_graph_signature.inputs_to_buffers
+            ):
+                for k, v in old_node.meta.items():
+                    new_node.meta[k] = v
+
+        # TODO unfortunately preserving graph-level metadata is not
+        # working well with aot_export. So we manually copy it.
+        # (The node-level meta is addressed above.)
+        gm.meta.update(self.graph_module.meta)
+
+        new_range_constraints = _get_updated_range_constraints(gm)
+
+        constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap())
+        for k, v in constants.items():
+            assert k not in self.constants
+            self.constants[k] = v
+
+        _replace_sym_size_ops_pass(gm)
+        exported_program = ExportedProgram(
+            root=gm,
+            graph=gm.graph,
+            graph_signature=new_graph_signature,
+            state_dict=self.state_dict,
+            range_constraints=new_range_constraints,
+            module_call_graph=copy.deepcopy(self.module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            constants=self.constants,
+        )
+
+        if len(new_range_constraints) > 0:
+            exported_program = exported_program._transform_do_not_use(
+                _AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints)
+            )
+
+        return exported_program
+
+    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
+        pm = PassManager(list(passes))
+        res = pm(self.graph_module)
+        transformed_gm = res.graph_module if res is not None else self.graph_module
+        assert transformed_gm is not None
+
+        if transformed_gm is self.graph_module and not res.modified:
+            return self
+
+        # TODO(zhxchen17) Remove this.
+        def _get_updated_graph_signature(
+            old_signature: ExportGraphSignature,
+            new_gm: torch.fx.GraphModule,
+        ) -> ExportGraphSignature:
+            """
+            Update the graph signature's input/output specs with the new node names.
+            """
+            new_input_specs = []
+            for i, node in enumerate(new_gm.graph.nodes):
+                if node.op != "placeholder":
+                    break
+
+                assert i < len(
+                    old_signature.input_specs
+                ), "Number of inputs changed after transformation"
+                old_input_spec = old_signature.input_specs[i]
+                arg = (
+                    old_input_spec.arg
+                    if isinstance(
+                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
+                    )
+                    else type(old_input_spec.arg)(node.name)
+                )
+                new_input_specs.append(
+                    InputSpec(
+                        old_input_spec.kind,
+                        arg,
+                        old_input_spec.target,
+                        old_input_spec.persistent,
+                    )
+                )
+
+            output_node = list(new_gm.graph.nodes)[-1]
+            assert output_node.op == "output"
+
+            new_output_specs = []
+            for i, node in enumerate(output_node.args[0]):
+                assert i < len(
+                    old_signature.output_specs
+                ), "Number of outputs changed after transformation"
+                old_output_spec = old_signature.output_specs[i]
+                arg = (
+                    old_output_spec.arg
+                    if isinstance(
+                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
+                    )
+                    else type(old_output_spec.arg)(node.name)
+                )
+                new_output_specs.append(
+                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
+                )
+
+            new_signature = ExportGraphSignature(
+                input_specs=new_input_specs, output_specs=new_output_specs
+            )
+            return new_signature
+
+        transformed_ep = ExportedProgram(
+            root=transformed_gm,
+            graph=transformed_gm.graph,
+            graph_signature=_get_updated_graph_signature(
+                self.graph_signature, transformed_gm
+            ),
+            state_dict=self.state_dict,
+            range_constraints=_get_updated_range_constraints(transformed_gm),
+            module_call_graph=copy.deepcopy(self._module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            constants=self.constants,
+        )
+        transformed_ep.graph_module.meta.update(self.graph_module.meta)
+        transformed_ep.graph_module.meta.update(res.graph_module.meta)
+        return transformed_ep
+
+    def _check_input_constraints(self, flat_args_with_path):
+        from torch._export.utils import _check_input_constraints_for_graph
+
+        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
+        input_placeholders = [
+            p
+            for p, s in zip(placeholders, self.graph_signature.input_specs)
+            if s.kind == InputKind.USER_INPUT
+        ]
+        _check_input_constraints_for_graph(
+            input_placeholders, flat_args_with_path, self.range_constraints
+        )
+
+    def _validate(self):
+        self.verifier().check(self)
+
+    # TODO(zhxchen17) Formalize this.
+    def _update(
+        self, graph_module, graph_signature, state_dict=None
+    ) -> "ExportedProgram":
+        return ExportedProgram(
+            root=graph_module,
+            graph=graph_module.graph,
+            graph_signature=graph_signature,
+            state_dict=state_dict or self.state_dict,
+            range_constraints=copy.deepcopy(self.range_constraints),
+            module_call_graph=copy.deepcopy(self._module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            tensor_constants=self.tensor_constants,
+        )
+
+
+def _get_updated_range_constraints(
+    gm: torch.fx.GraphModule,
+) -> "Dict[sympy.Symbol, Any]":
+    def get_shape_env(gm):
+        vals = [
+            node.meta["val"]
+            for node in gm.graph.nodes
+            if node.meta.get("val", None) is not None
+        ]
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode(vals)
+        if fake_mode is not None:
+            return fake_mode.shape_env
+        for v in vals:
+            if isinstance(v, torch.SymInt):
+                return v.node.shape_env
+
+    shape_env = get_shape_env(gm)
+    if shape_env is None:
+        return {}
+    range_constraints = {
+        k: v
+        for k, v in shape_env.var_to_range.items()
+        if k not in shape_env.replacements
+    }
+    # Only when we have an unbacked symint, and it's used as constructor inputs,
+    # runtime_var_to_range will make a difference compared to var_to_range.
+    # e.g. [2, oo) -> [0, oo)
+    for k, v in shape_env.var_to_range.items():
+        if k not in shape_env.replacements:
+            range_constraints[k] = v
+    return range_constraints
+
+
+def _create_graph_module_for_export(root, graph):
+    try:
+        gm = torch.fx.GraphModule(root, graph)
+    except SyntaxError:
+        # If custom objects stored in memory are being used in the graph,
+        # the generated python code will result in a syntax error on the custom
+        # object, since it is unable to parse the in-memory object. However
+        # we can still run the graph eagerly through torch.fx.Interpreter,
+        # so we will bypass this error.
+        warnings.warn(
+            "Unable to execute the generated python source code from "
+            "the graph. The graph module will no longer be directly callable, "
+            "but you can still run the ExportedProgram, and if needed, you can "
+            "run the graph module eagerly using torch.fx.Interpreter."
+        )
+        gm = torch.fx.GraphModule(root, torch.fx.Graph())
+        gm._graph = graph
+
+    return gm
diff --git a/MLPY/Lib/site-packages/torch/export/graph_signature.py b/MLPY/Lib/site-packages/torch/export/graph_signature.py
new file mode 100644
index 0000000000000000000000000000000000000000..57919aae7a1cca6166dedad764c7bbadbc66ca2c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/graph_signature.py
@@ -0,0 +1,504 @@
+import dataclasses
+from enum import auto, Enum
+from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union
+
+
+__all__ = [
+    "ConstantArgument",
+    "CustomObjArgument",
+    "ExportBackwardSignature",
+    "ExportGraphSignature",
+    "InputKind",
+    "InputSpec",
+    "OutputKind",
+    "OutputSpec",
+    "SymIntArgument",
+    "TensorArgument",
+]
+
+
+@dataclasses.dataclass
+class TensorArgument:
+    name: str
+
+
+@dataclasses.dataclass
+class SymIntArgument:
+    name: str
+
+
+@dataclasses.dataclass
+class CustomObjArgument:
+    name: str
+    class_fqn: str
+
+
+@dataclasses.dataclass
+class ConstantArgument:
+    value: Union[int, float, bool, None]
+
+
+ArgumentSpec = Union[
+    TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument
+]
+
+
+class InputKind(Enum):
+    USER_INPUT = auto()
+    PARAMETER = auto()
+    BUFFER = auto()
+    CONSTANT_TENSOR = auto()
+    CUSTOM_OBJ = auto()
+    TOKEN = auto()
+
+
+@dataclasses.dataclass
+class InputSpec:
+    kind: InputKind
+    arg: ArgumentSpec
+    target: Optional[str]
+    persistent: Optional[bool] = None
+
+    def __post_init__(self):
+        if self.kind == InputKind.BUFFER:
+            assert (
+                self.persistent is not None
+            ), "Failed to specify persistent flag on BUFFER."
+        assert isinstance(
+            self.arg,
+            (TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument),
+        ), f"got {type(self.arg)}"
+
+
+class OutputKind(Enum):
+    USER_OUTPUT = auto()
+    LOSS_OUTPUT = auto()
+    BUFFER_MUTATION = auto()
+    GRADIENT_TO_PARAMETER = auto()
+    GRADIENT_TO_USER_INPUT = auto()
+    USER_INPUT_MUTATION = auto()
+    TOKEN = auto()
+
+
+@dataclasses.dataclass
+class OutputSpec:
+    kind: OutputKind
+    arg: ArgumentSpec
+    target: Optional[str]
+
+    def __post_init__(self):
+        assert isinstance(self.arg, (TensorArgument, SymIntArgument, ConstantArgument))
+
+
+def _sig_to_specs(
+    *,
+    user_inputs: Set[str],
+    inputs_to_parameters: Mapping[str, str],
+    inputs_to_buffers: Mapping[str, str],
+    user_outputs: Set[str],
+    buffer_mutations: Mapping[str, str],
+    user_input_mutations: Mapping[str, str],
+    grad_params: Mapping[str, str],
+    grad_user_inputs: Mapping[str, str],
+    loss_output: Optional[str],
+    inputs: List[ArgumentSpec],
+    outputs: List[ArgumentSpec],
+    input_tokens: List[str],
+    output_tokens: List[str],
+) -> Tuple[List[InputSpec], List[OutputSpec]]:
+    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
+        if not isinstance(inp, TensorArgument):
+            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
+        name = inp.name
+        if name in user_inputs:
+            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
+        elif name in inputs_to_parameters:
+            return InputSpec(
+                kind=InputKind.PARAMETER,
+                arg=inp,
+                target=inputs_to_parameters[name],
+            )
+        elif name in inputs_to_buffers:
+            return InputSpec(
+                kind=InputKind.BUFFER,
+                arg=inp,
+                target=inputs_to_buffers[name],
+                # Mark as True for now; we will fix this up to distinguish
+                # persistent from non-persistent later in tracing.
+                # See: rewrite_non_persistent_buffers()
+                # TODO(suo): this is horrible.
+                persistent=True,
+            )
+        elif name in input_tokens:
+            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)
+        else:
+            raise AssertionError(f"Unknown tensor input kind: {name}")
+
+    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
+        if not isinstance(o, TensorArgument):
+            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
+        name = o.name
+        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
+            if name in buffer_mutations:
+                return OutputSpec(
+                    kind=OutputKind.BUFFER_MUTATION,
+                    arg=o,
+                    target=buffer_mutations[name],
+                )
+            elif name in user_input_mutations:
+                return OutputSpec(
+                    kind=OutputKind.USER_INPUT_MUTATION,
+                    arg=o,
+                    target=user_input_mutations[name],
+                )
+            elif name in output_tokens:
+                return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)
+            else:
+                raise AssertionError(f"Unknown tensor mutation kind: {name}")
+        else:
+            if name in user_outputs:
+                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
+
+            elif name in grad_params:
+                return OutputSpec(
+                    kind=OutputKind.GRADIENT_TO_PARAMETER,
+                    arg=o,
+                    target=grad_params[name],
+                )
+            elif name in grad_user_inputs:
+                return OutputSpec(
+                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
+                    arg=o,
+                    target=grad_user_inputs[name],
+                )
+            elif name == loss_output:
+                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
+
+            else:
+                raise AssertionError(f"Unknown tensor output kind: {name}")
+
+    input_specs = [to_input_spec(inp) for inp in inputs]
+    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
+    return input_specs, output_specs
+
+
+@dataclasses.dataclass
+class ExportBackwardSignature:
+    gradients_to_parameters: Dict[str, str]
+    gradients_to_user_inputs: Dict[str, str]
+    loss_output: str
+
+
+@dataclasses.dataclass
+class ExportGraphSignature:
+    """
+    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
+    which is an fx.Graph with stronger invariant guarantees.
+
+    Export Graph is functional and does not access "states" like parameters
+    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
+    guarantees that parameters, buffers, and constant tensors are lifted out of
+    the graph as inputs.  Similarly, any mutations to buffers are not included
+    in the graph either; instead, the updated values of mutated buffers are
+    modeled as additional outputs of Export Graph.
+
+    The ordering of all inputs and outputs is::
+
+        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
+        Outputs = [*mutated_inputs, *flattened_user_outputs]
+
+    e.g. if the following module is exported::
+
+        class CustomModule(nn.Module):
+            def __init__(self):
+                super(CustomModule, self).__init__()
+
+                # Define a parameter
+                self.my_parameter = nn.Parameter(torch.tensor(2.0))
+
+                # Define two buffers
+                self.register_buffer('my_buffer1', torch.tensor(3.0))
+                self.register_buffer('my_buffer2', torch.tensor(4.0))
+
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
+
+                # Mutate one of the buffers (e.g., increment it by 1)
+                self.my_buffer2.add_(1.0) # In-place addition
+
+                return output
+
+    Resulting Graph would be::
+
+        graph():
+            %arg0_1 := placeholder[target=arg0_1]
+            %arg1_1 := placeholder[target=arg1_1]
+            %arg2_1 := placeholder[target=arg2_1]
+            %arg3_1 := placeholder[target=arg3_1]
+            %arg4_1 := placeholder[target=arg4_1]
+            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
+            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
+            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
+            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
+            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
+            return (add_tensor_2, add_tensor_1)
+
+    Resulting ExportGraphSignature would be::
+
+        ExportGraphSignature(
+            input_specs=[
+                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
+                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
+                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
+                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
+                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
+            ],
+            output_specs=[
+                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
+                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
+            ]
+        )
+    """
+
+    input_specs: List[InputSpec]
+    output_specs: List[OutputSpec]
+
+    # A list of parameters uniquely identified by mangled fully qualified name
+    @property
+    def parameters(self) -> Collection[str]:
+        # TODO Make this tuple.
+        return [
+            s.target
+            for s in self.input_specs
+            if s.kind == InputKind.PARAMETER
+            if isinstance(s.target, str)
+        ]
+
+    # A list of buffers uniquely identified by mangled fully qualified name
+    @property
+    def buffers(self) -> Collection[str]:
+        # TODO Make this tuple.
+        return [
+            s.target
+            for s in self.input_specs
+            if s.kind == InputKind.BUFFER
+            if isinstance(s.target, str)
+        ]
+
+    @property
+    def non_persistent_buffers(self) -> Collection[str]:
+        return [
+            s.target
+            for s in self.input_specs
+            if s.kind == InputKind.BUFFER
+            if s.persistent is False
+            if isinstance(s.target, str)
+        ]
+
+    # A list of lifted constant tensors
+    @property
+    def lifted_tensor_constants(self) -> Collection[str]:
+        # TODO Make this tuple.
+        return [
+            s.target
+            for s in self.input_specs
+            if s.kind == InputKind.CONSTANT_TENSOR
+            if isinstance(s.target, str)
+        ]
+
+    @property
+    def lifted_custom_objs(self) -> Collection[str]:
+        # TODO Make this tuple.
+        return [
+            s.target
+            for s in self.input_specs
+            if s.kind == InputKind.CUSTOM_OBJ
+            if isinstance(s.target, str)
+        ]
+
+    # Graph node names of pytree-flattened inputs of original program
+    @property
+    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
+        user_inputs: List[Union[int, float, bool, None, str]] = []
+        for s in self.input_specs:
+            if s.kind != InputKind.USER_INPUT:
+                continue
+
+            if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)):
+                user_inputs.append(s.arg.name)
+            elif isinstance(s.arg, ConstantArgument):
+                user_inputs.append(s.arg.value)
+            else:
+                raise RuntimeError(f"{s.arg} is not a valid user inputs")
+        return tuple(user_inputs)
+
+    # Graph node names of pytree-flattened outputs of original program
+    @property
+    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
+        user_outputs: List[Union[int, float, bool, None, str]] = []
+        for s in self.output_specs:
+            if s.kind != OutputKind.USER_OUTPUT:
+                continue
+
+            if isinstance(s.arg, (TensorArgument, SymIntArgument)):
+                user_outputs.append(s.arg.name)
+            elif isinstance(s.arg, ConstantArgument):
+                user_outputs.append(s.arg.value)
+            else:
+                raise RuntimeError(f"{s.arg} is not a valid user output")
+        return tuple(user_outputs)
+
+    # A dictionary mapping graph input node names to parameters. If a graph input
+    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
+    @property
+    def inputs_to_parameters(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.input_specs
+            if s.kind == InputKind.PARAMETER
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
+    # A dictionary mapping graph input node names to buffers. If a graph input
+    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
+    @property
+    def inputs_to_buffers(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target  # type: ignore[union-attr, misc]
+            for s in self.input_specs
+            if s.kind == InputKind.BUFFER
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
+    # A dictionary mapping graph output node names to buffers that are mutated in the
+    # original program. Buffers that are not mutated will not be found in this dictionary.
+    @property
+    def buffers_to_mutate(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.output_specs
+            if s.kind == OutputKind.BUFFER_MUTATION
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
+    @property
+    def user_inputs_to_mutate(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.output_specs
+            if s.kind == OutputKind.USER_INPUT_MUTATION
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
+    # A dictionary mapping graph input node names to lifted tensor constants.
+    @property
+    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.input_specs
+            if s.kind == InputKind.CONSTANT_TENSOR
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
+    @property
+    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.input_specs
+            if s.kind == InputKind.CUSTOM_OBJ
+            and isinstance(s.arg, CustomObjArgument)
+            and isinstance(s.target, str)
+        }
+
+    @property
+    def backward_signature(self) -> Optional[ExportBackwardSignature]:
+        loss_output = None
+        gradients_to_parameters: Dict[str, str] = {}
+        gradients_to_user_inputs: Dict[str, str] = {}
+        for spec in self.output_specs:
+            if spec.kind == OutputKind.LOSS_OUTPUT:
+                assert loss_output is None
+                assert isinstance(spec.arg, TensorArgument)
+                loss_output = spec.arg.name
+            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
+                assert isinstance(spec.target, str)
+                assert isinstance(spec.arg, TensorArgument)
+                gradients_to_parameters[spec.arg.name] = spec.target
+            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
+                assert isinstance(spec.target, str)
+                assert isinstance(spec.arg, TensorArgument)
+                gradients_to_user_inputs[spec.arg.name] = spec.target
+
+        if loss_output is None:
+            return None
+
+        return ExportBackwardSignature(
+            loss_output=loss_output,
+            gradients_to_parameters=gradients_to_parameters,
+            gradients_to_user_inputs=gradients_to_user_inputs,
+        )
+
+    # Map from assertion dependency token index to assertion dep token output
+    # name in output. The shape of output after aot_autograd will be like:
+    # (updated_inputs, user_outputs, dep_token).
+    @property
+    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
+        return None
+
+    @property
+    def input_tokens(self) -> List[str]:
+        input_tokens = []
+        for s in self.input_specs:
+            if s.kind == InputKind.TOKEN:
+                assert isinstance(s.arg, TensorArgument)
+                input_tokens.append(s.arg.name)
+        return input_tokens
+
+    @property
+    def output_tokens(self) -> List[str]:
+        output_tokens = []
+        for s in self.output_specs:
+            if s.kind == OutputKind.TOKEN:
+                assert isinstance(s.arg, TensorArgument)
+                output_tokens.append(s.arg.name)
+        return output_tokens
+
+    def __post_init__(self) -> None:
+        assertion_dep_token = self.assertion_dep_token
+        if assertion_dep_token is None:
+            return
+        assert len(assertion_dep_token) == 1
+        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
+        assert (
+            len(self.user_outputs) + len(self.buffers_to_mutate)
+            == assertion_dep_token_index
+        )
+
+    def replace_all_uses(self, old: str, new: str):
+        """
+        Replace all uses of the old name with the new name in the signature.
+        """
+        assert isinstance(old, str)
+        assert isinstance(new, str)
+        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument)
+        for o in self.output_specs:
+            if isinstance(o.arg, arg_types):
+                if o.arg.name == old:
+                    o.arg.name = new
+        for i in self.input_specs:
+            if isinstance(i.arg, arg_types):
+                if i.arg.name == old:
+                    i.arg.name = new
+
+    def get_replace_hook(self):
+        def _(old, new, user):
+            if user.op in ("output", "input"):
+                self.replace_all_uses(old.name, new)
+
+        return _
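+
+# Illustrative note (not upstream code): when a pass renames graph nodes, the
+# signature can be kept in sync explicitly, e.g. for a hypothetical rename:
+#
+#     old_name = placeholder_node.name
+#     placeholder_node.name = "x"
+#     ep.graph_signature.replace_all_uses(old_name, placeholder_node.name)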
diff --git a/MLPY/Lib/site-packages/torch/export/unflatten.py b/MLPY/Lib/site-packages/torch/export/unflatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d46ca1792359ff6a5487f5eb8e91e4b4a6dbb5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/export/unflatten.py
@@ -0,0 +1,860 @@
+import abc
+import copy
+import operator
+from copy import deepcopy
+from enum import Enum
+from itertools import chain
+from typing import Any, cast, Dict, List, Optional, Union
+
+import torch
+import torch.fx._pytree as fx_pytree
+import torch.utils._pytree as pytree
+from torch.export._tree_utils import reorder_kwargs
+from torch.export.exported_program import (
+    ConstantArgument,
+    ExportedProgram,
+    ModuleCallSignature,
+    SymIntArgument,
+    TensorArgument,
+)
+from torch.fx._symbolic_trace import is_fx_tracing
+from torch.utils._pytree import GetAttrKey, SequenceKey
+
+__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]
+
+
+class _AttrKind(Enum):
+    PARAMETER = "parameter"
+    BUFFER = "buffer"
+    CONSTANT = "constant"
+
+
+# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
+# This installs empty Modules where none exist yet if they are subpaths of 'target'.
+def _assign_attr(
+    from_obj: Union[torch.Tensor, torch.ScriptObject],
+    to_module: torch.nn.Module,
+    target: str,
+    attr_kind: _AttrKind,
+    persistent: bool = True,
+):
+    *prefix, field = target.split(".")
+    for item in prefix:
+        t = getattr(to_module, item, None)
+
+        if t is None:
+            t = torch.nn.Module()
+            setattr(to_module, item, t)
+        to_module = t
+
+    if attr_kind == _AttrKind.PARAMETER:
+        assert isinstance(from_obj, torch.nn.Parameter)
+        to_module.register_parameter(field, from_obj)
+    elif attr_kind == _AttrKind.BUFFER:
+        assert isinstance(from_obj, torch.Tensor)
+        to_module.register_buffer(field, from_obj, persistent=persistent)
+    elif attr_kind == _AttrKind.CONSTANT:
+        assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject))
+        setattr(to_module, field, from_obj)
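+
+# Rough illustration (hypothetical names): assigning to a dotted path creates
+# intermediate empty Modules as needed, e.g.
+#
+#     root = torch.nn.Module()
+#     _assign_attr(torch.zeros(3), root, "block.sub.stat",
+#                  attr_kind=_AttrKind.BUFFER, persistent=False)
+#     assert isinstance(root.block.sub, torch.nn.Module)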
+
+
+class InterpreterModule(torch.nn.Module):
+    """A module that uses torch.fx.Interpreter to execute instead of the usual
+    codegen that GraphModule uses. This provides better stack trace information
+    and makes it easier to debug execution.
+    """
+
+    def __init__(
+        self,
+        graph: torch.fx.Graph,
+    ):
+        super().__init__()
+        self.graph = graph
+        self.graph.owning_module = self
+
+    def forward(self, *args, **kwargs):
+        assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
+        if torch.compiler.is_dynamo_compiling():
+            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
+            # GraphModule codegen in this instance.
+            return self.graph_module(*args, **kwargs)
+        else:
+            if kwargs:
+                # Handle **kwargs. FX only natively supports positional
+                # arguments (through placeholders). So in order to pass in
+                # kwargs, we must correspond the names of the placeholders with
+                # the keys in the kwarg dict.
+                arg_list = list(args)
+                kwarg_names = self.arg_names[len(arg_list) :]
+                for kwarg_name in kwarg_names:
+                    if kwarg_name in kwargs:
+                        arg_list.append(kwargs[kwarg_name])
+
+                # Assert that the kwargs passed in exactly match the positional
+                # arguments specified by the GraphModule. This should be
+                # guaranteed by the unflattening process.
+                assert len(kwarg_names) == len(kwargs)
+                assert len(arg_list) == len(self.arg_names)
+                args = tuple(arg_list)
+
+            return torch.fx.Interpreter(self, graph=self.graph).run(
+                *args, enable_io_processing=False
+            )
+
+    def finalize(self):
+        # We need to "finalize" because GraphModule populates its own state_dict
+        # based on the get_attrs observed in the graph. So we need to fully
+        # construct the graph and call _sink_params before generating this
+        # GraphModule.
+
+        # need to set `graph_module` directly on the dict to avoid it getting
+        # registered as a submodule.
+        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
+        self.graph.lint()
+
+        # Cache arg names for kwarg handling (see forward())
+        self.arg_names = []
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                self.arg_names.append(node.target)
+
+
+class FlatArgsAdapter(abc.ABC):
+    """
+    Adapts input arguments with ``input_spec`` to align with ``target_spec``.
+    """
+
+    @abc.abstractmethod
+    def adapt(
+        self,
+        target_spec: pytree.TreeSpec,
+        input_spec: pytree.TreeSpec,
+        input_args: List[Any],
+    ) -> List[Any]:
+        """NOTE: This adapter may mutate given ``input_args_with_path``."""
+        ...
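+
+    # Illustrative subclass sketch (hypothetical, not shipped with the library):
+    # an adapter that simply drops trailing extra leaves so the number of args
+    # matches the exported module's TreeSpec.
+    #
+    #     class DropExtraArgsAdapter(FlatArgsAdapter):
+    #         def adapt(self, target_spec, input_spec, input_args):
+    #             return input_args[: target_spec.num_leaves]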
+
+
+class UnflattenedModule(torch.nn.Module):
+    def __init__(
+        self,
+        export_module: ExportedProgram,
+        flat_args_adapter: Optional[FlatArgsAdapter] = None,
+    ):
+        super().__init__()
+        if export_module.graph_signature.backward_signature is not None:
+            raise ValueError("Unflattening on JointExportModule NYI")
+
+        export_graph = deepcopy(export_module.graph)
+        self.graph_signature = deepcopy(export_module.graph_signature)
+        self.graph = torch.fx.Graph()
+        self.module_call_graph = deepcopy(export_module.module_call_graph)
+        self.flat_args_adapter = flat_args_adapter
+        # Flag to indicate whether args have been adapted.
+        self.adapted = False
+
+        _inplace_buffer_mutations(export_graph, self.graph_signature)
+        _outline_submodules(export_graph, self)
+
+        self.range_constraints = export_module.range_constraints
+        self.equality_constraints: List = []
+
+        state_dict = export_module.state_dict
+        for name in self.graph_signature.parameters:
+            cloned = torch.nn.Parameter(state_dict[name].clone())
+            _assign_attr(
+                cloned,
+                self,
+                name,
+                attr_kind=_AttrKind.PARAMETER,
+            )
+
+        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
+        for name in self.graph_signature.buffers:
+            if name in non_persistent_buffers:
+                persistent = False
+                cloned = export_module.constants[name].clone()
+            else:
+                persistent = True
+                cloned = state_dict[name].clone()
+
+            _assign_attr(
+                cloned,
+                self,
+                name,
+                attr_kind=_AttrKind.BUFFER,
+                persistent=persistent,
+            )
+
+        for fqn in chain(
+            self.graph_signature.lifted_tensor_constants,
+            self.graph_signature.lifted_custom_objs,
+        ):
+            constant = export_module.constants[fqn]
+            if isinstance(constant, torch.Tensor):
+                constant = constant.clone()
+            _assign_attr(
+                constant,
+                self,
+                fqn,
+                attr_kind=_AttrKind.CONSTANT,
+            )
+
+        inputs_to_state: Dict[str, str] = {
+            **self.graph_signature.inputs_to_parameters,
+            **self.graph_signature.inputs_to_buffers,
+            **self.graph_signature.inputs_to_lifted_tensor_constants,
+            **self.graph_signature.inputs_to_lifted_custom_objs,
+        }
+
+        _sink_params(self, inputs_to_state, [])
+        # Check that all input nodes have been processed.
+        for module in self.modules():
+            if not isinstance(module, torch.fx.GraphModule):
+                continue
+            for node in module.graph.nodes:
+                if node.op != "placeholder":
+                    continue
+                assert node.name not in inputs_to_state
+
+        # Cache so we don't have to compute this every time.
+        # NOTE: this needs to be kept in sync with the placeholders in
+        # self.graph, but currently we have no way to guarantee that.
+        self.input_placeholders = [
+            node for node in self.graph.nodes if node.op == "placeholder"
+        ]
+        self.check_input_constraints = True
+        assert self.module_call_graph[0].fqn == ""
+
+    def forward(self, *args, **kwargs):
+        signature = self.module_call_graph[0].signature
+
+        reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)
+
+        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
+            (args, reordered_kwargs)
+        )
+        flat_args = [x[1] for x in flat_args_with_path]
+        if is_fx_tracing():
+            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
+                *flat_args, enable_io_processing=False
+            )
+            # For a scalar return value, fx.Graph wraps it in a tuple
+            if isinstance(return_val, tuple) and len(return_val) == 1:
+                return return_val[0]
+            return return_val
+
+        if in_spec != signature.in_spec:
+            if not self.adapted:
+                print(
+                    "Input treespec does not match with exported module's: \n"
+                    f"Input treespec: {in_spec}. ",
+                    f"Exported module treespec: {signature.in_spec}",
+                )
+            if self.flat_args_adapter is None:
+                raise TypeError(
+                    "There is no flat args adapter sepcified. "
+                    "Are you sure you are calling this with the right arguments? "
+                )
+            else:
+                if not self.adapted:
+                    print("Adapting flat arg to match exported module's treespec")
+                flat_args = self.flat_args_adapter.adapt(
+                    target_spec=signature.in_spec,
+                    input_spec=in_spec,
+                    input_args=flat_args,
+                )
+                self.adapted = True
+                if len(flat_args) != signature.in_spec.num_leaves:
+                    raise TypeError(
+                        f"Flat args adaption failed, number of args mismatch "
+                        f"Adatped: {len(flat_args)} \n"
+                        f"Exported module: {signature.in_spec.num_leaves}"
+                    )
+
+        if self.check_input_constraints:
+            # Import here to avoid an unfortunate circular dependency.
+            # TODO(suo): untangle this.
+            from torch._export.utils import _check_input_constraints_for_graph
+
+            if self.adapted is True:
+                # TODO(suo): The FlatArgsAdapter returns a list of flat args,
+                # which we don't have keypaths for. For now, just create a dummy
+                # keypath to associate with the arg.
+                new_flat_args_with_path = [  # type: ignore[var-annotated]
+                    ((SequenceKey(idx=0), GetAttrKey(name="")), arg)
+                    for arg in flat_args
+                ]
+            else:
+                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]
+
+            _check_input_constraints_for_graph(
+                self.input_placeholders, new_flat_args_with_path, self.range_constraints
+            )
+        tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
+            *flat_args, enable_io_processing=False
+        )
+        return pytree.tree_unflatten(tree_out, signature.out_spec)
+
+
+def unflatten(
+    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
+) -> UnflattenedModule:
+    """Unflatten an ExportedProgram, producing a module with the same module
+    hierarchy as the original eager module. This can be useful if you are trying
+    to use :mod:`torch.export` with another system that expects a module
+    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
+
+    .. note:: The args/kwargs of unflattened modules will not necessarily match
+        the eager module, so doing a module swap (e.g. :code:`self.submod =
+        new_mod`) will not necessarily work. If you need to swap a module out, you
+        need to set the :code:`preserve_module_call_signature` parameter of
+        :func:`torch.export.export`.
+
+    Args:
+        module (ExportedProgram): The ExportedProgram to unflatten.
+        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if the input TreeSpec does not match the exported module's.
+
+    Returns:
+        An instance of :class:`UnflattenedModule`, which has the same module
+        hierarchy as the original eager module pre-export.
+    """
+    return UnflattenedModule(module, flat_args_adapter)
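+
+# A minimal usage sketch, not part of the upstream file: `torch.export.export`
+# and calling the resulting UnflattenedModule are real APIs; the module `M`,
+# its submodule name `sub`, and the input shape are illustrative assumptions.
+#
+#     ep = torch.export.export(M(), (torch.randn(2, 4),))
+#     unflat = unflatten(ep)
+#     out = unflat(torch.randn(2, 4))   # same calling convention as M()
+#     print(unflat.sub)                 # submodule hierarchy is restored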
+
+
+def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
+    """Transform buffer mutations from their functionalized form into a copy_
+    node in the graph.
+
+    Functionalization represents buffer mutation by passing the buffer as an input and output. So for example, the eager code:
+        def forward(self, x):
+            self.buffer += x
+            return x * x
+
+    Will become a graph that looks like:
+        def forward(self, buffer, x):
+            mutated_buffer = aten.add(buffer, x)
+            mul = aten.mul(x, x)
+            return (mutated_buffer, mul)
+
+    We want to rewrite this in place into something that looks like the original eager code:
+        def forward(self, buffer, x):
+            mutated_buffer = aten.add(buffer, x)
+            buffer.copy_(mutated_buffer)
+            mul = aten.mul(x, x)
+            return (mul,)
+    """
+    output_node = next(iter(reversed(graph.nodes)))
+    assert output_node.op == "output" and len(output_node.args) == 1
+    return_args = output_node.args[0]
+
+    mutation_node_to_buffer = graph_signature.buffers_to_mutate
+    mutations = return_args[: len(mutation_node_to_buffer)]
+    buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()}
+    input_name_to_node = {
+        node.name: node for node in graph.nodes if node.op == "placeholder"
+    }
+
+    for mutation in mutations:
+        buffer_name = mutation_node_to_buffer[mutation.name]
+        input_name = buffers_to_inputs[buffer_name]
+        input_node = input_name_to_node[input_name]
+
+        with graph.inserting_after(mutation):
+            new_node = graph.create_node(
+                "call_function", torch.ops.aten.copy_, (input_node, mutation)
+            )
+            for k, v in mutation.meta.items():
+                new_node.meta[k] = v
+        # Replace all uses of the previously functional mutation with our copy_ output.
+        mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)
+
+    # Remove the mutated buffer from the graph outputs, since we don't need to
+    # thread it through anymore. We don't need to handle the inputs, which will
+    # be handled by _sink_params.
+    user_outputs = tuple(
+        return_args[len(mutation_node_to_buffer) :],
+    )
+    output_node.args = (user_outputs,)
+
+
+def _is_prefix(candidate, target):
+    """Check whether `candidate` is a prefix of `target`."""
+    return len(candidate) < len(target) and target[: len(candidate)] == candidate
+
+
+def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
+    if parent_fqn == "":
+        # Handle the root module correctly.
+        return child_fqn
+
+    parent_split = parent_fqn.split(".")
+    child_split = child_fqn.split(".")
+
+    assert (
+        child_split[: len(parent_split)] == parent_split
+    ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'"
+    return ".".join(child_split[len(parent_split) :])
+
+
+def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
+    def graph_dump(graph: torch.fx.Graph) -> str:
+        ret = []
+        nodes_idx: Dict[int, int] = {}
+
+        def arg_dump(arg) -> str:
+            if isinstance(arg, torch.fx.Node):
+                return "%" + str(nodes_idx[id(arg)])
+            return str(arg)
+
+        for i, node in enumerate(graph.nodes):
+            args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
+            args_dump += [
+                f"{key}={value}"
+                for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
+            ]
+            target = node.target if node.op == "call_function" else ""
+            ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
+            nodes_idx[id(node)] = i
+        return "\n".join(ret)
+
+    assert graph_dump(x.graph) == graph_dump(y.graph)
+
+
+def _add_spec(gm: torch.nn.Module, spec) -> str:
+    i = 0
+    while hasattr(gm, f"_spec_{i}"):
+        i += 1
+    name = f"_spec_{i}"
+    setattr(gm, name, spec)
+    return name
+
+
+def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node:
+    name = _add_spec(gm, spec)
+    spec_node = gm.graph.get_attr(name)
+    return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))
+
+
+def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node:
+    name = _add_spec(gm, spec)
+    spec_node = gm.graph.get_attr(name)
+    return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
+
+
+def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
+    *prefix, field = target.split(".")
+
+    for item in prefix:
+        submod = getattr(mod, item, None)
+
+        if submod is None:
+            submod = torch.nn.Module()
+            setattr(mod, item, submod)
+
+        if not isinstance(submod, torch.nn.Module):
+            return False
+
+        mod = submod
+
+    mod.add_module(field, module_to_add)
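+
+# Illustrative only (not in the upstream file): _add_submodule(root, "a.b.leaf", leaf)
+# creates empty torch.nn.Module containers for "a" and "a.b" if they do not yet
+# exist, then registers `leaf` as root.a.b.leaf via add_module.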
+
+
+class _ModuleFrame:
+    def __init__(
+        self,
+        flat_graph,
+        nodes,
+        seen_nodes,
+        seen_modules,
+        parent,
+        module_stack,
+        module_id,
+        module_call_graph: Dict[str, ModuleCallSignature],
+        module: Optional[torch.nn.Module] = None,
+    ):
+        self.flat_graph = flat_graph
+        self.nodes = nodes
+        self.seen_nodes = seen_nodes
+        self.seen_modules = seen_modules
+        self.parent = parent
+        self.module_stack = module_stack
+        self.module_id = module_id
+
+        self.module_call_graph = module_call_graph
+        self.verbose = False
+
+        self.fqn = self.module_stack[-1]
+        if module is not None:
+            self.module = module
+        else:
+            self.module = InterpreterModule(torch.fx.Graph())
+        if self.module_id in self.seen_modules:
+            self.cached_graph_module = self.seen_modules[self.module_id]
+        else:
+            self.cached_graph_module = None
+            self.seen_modules[self.module_id] = self.module
+
+        self.graph = self.module.graph
+
+        # Mapping of nodes in the flat graph to nodes in this graph.
+        self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
+        self.node_to_placeholder = {}
+
+        self.parent_call_module: Optional[torch.fx.Node] = None
+        if parent is not None:
+            accessor = _compute_accessor(parent.fqn, self.fqn)
+            _add_submodule(
+                parent.module,
+                accessor,
+                self.module
+                if self.cached_graph_module is None
+                else self.cached_graph_module,
+            )
+            self.parent_call_module = parent.graph.call_module(accessor)
+
+        signature = module_call_graph.get(self.fqn)
+        if signature is not None and self.parent is not None:
+            assert signature.in_spec.num_children == 2
+            args_spec = signature.in_spec.children_specs[0]
+            kwargs_spec = signature.in_spec.children_specs[1]
+            assert args_spec.context is None
+            assert kwargs_spec.context is not None
+
+            with self.graph.inserting_after(None):
+                arg_nodes = []
+                for idx in range(args_spec.num_children):
+                    arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
+                kwarg_nodes = {}
+                for name in kwargs_spec.context:
+                    kwarg_nodes[name] = self.graph.placeholder(name)
+                flat_args = _generate_flatten(
+                    self.module,
+                    (tuple(arg_nodes), kwarg_nodes),
+                    signature.in_spec,
+                )
+                for idx, arg in enumerate(signature.inputs):
+                    flat_arg_node = self.graph.create_node(
+                        op="call_function",
+                        target=operator.getitem,
+                        args=(flat_args, idx),
+                        name=arg.name
+                        if not isinstance(arg, ConstantArgument)
+                        else f"_constant_{idx}",
+                    )
+                    if isinstance(arg, ConstantArgument):
+                        continue
+                    flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
+                    self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node
+
+            with self.parent.graph.inserting_before(self.parent_call_module):
+                input_nodes: List[Optional[torch.fx.Node]] = []
+                for input in signature.inputs:
+                    if isinstance(input, ConstantArgument) and input.value is None:
+                        input_nodes.append(None)
+                    else:
+                        assert isinstance(input, (TensorArgument, SymIntArgument))
+                        input_nodes.append(
+                            self.parent.remap_input(self.seen_nodes[input.name])
+                        )
+
+                inputs_node = _generate_unflatten(
+                    self.parent.module,
+                    input_nodes,
+                    signature.in_spec,
+                )
+
+                args_node = self.parent.graph.call_function(
+                    operator.getitem, (inputs_node, 0)
+                )
+                kwargs_node = self.parent.graph.call_function(
+                    operator.getitem, (inputs_node, 1)
+                )
+                arg_nodes = [
+                    self.parent.graph.call_function(operator.getitem, (args_node, i))
+                    for i in range(args_spec.num_children)
+                ]
+                kwarg_nodes = {
+                    k: self.parent.graph.call_function(
+                        operator.getitem, (kwargs_node, k)
+                    )
+                    for k in kwargs_spec.context
+                }
+            assert self.parent_call_module is not None
+            self.parent_call_module.args = tuple(arg_nodes)
+            self.parent_call_module.kwargs = kwarg_nodes
+
+    def add_placeholder(self, x):
+        assert x.graph is self.flat_graph
+        # x is not in the subgraph; create a new placeholder for it in the subgraph.
+        with self.graph.inserting_before(None):
+            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
+        # copy all meta fields, even if some fields might be irrelevant for
+        # the placeholder node
+        placeholder_node.meta = copy.copy(x.meta)
+        self.node_to_placeholder[x] = placeholder_node
+
+    def remap_input(self, x):
+        assert x.graph is self.flat_graph
+        if x in self.node_map:
+            return self.node_map[x]
+        if x not in self.node_to_placeholder:
+            self.add_placeholder(x)
+            if self.parent_call_module is not None:
+                # Important to *prepend* the output to match how we are
+                # inserting placeholder nodes.
+                self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
+        return self.node_to_placeholder[x]
+
+    def finalize_outputs(self):
+        orig_outputs = []
+
+        signature = self.module_call_graph.get(self.fqn)
+        if signature is not None and self.parent is not None:
+            for output in signature.outputs:
+                if isinstance(output, (TensorArgument, SymIntArgument)):
+                    orig_outputs.append(self.seen_nodes[output.name])
+                else:
+                    raise RuntimeError(
+                        f"Unsupported data type for output node: {output}"
+                    )
+
+            tree_out_node = _generate_unflatten(
+                self.module,
+                tuple(
+                    self.node_map[self.seen_nodes[output.name]]
+                    for output in orig_outputs
+                ),
+                signature.out_spec,
+            )
+            parent_out: Optional[torch.fx.Node] = _generate_flatten(
+                self.parent.module, self.parent_call_module, signature.out_spec
+            )
+            graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
+        else:
+            graph_outputs = []
+            # Iterate through nodes we have copied into self.graph.
+            for orig_node in self.node_map.keys():
+                for user_node in orig_node.users:
+                    if user_node.name not in self.seen_nodes:
+                        # external user node, need to expose as an output
+                        orig_outputs.append(orig_node)
+                        graph_outputs.append(self.node_map[orig_node])
+                        break
+
+            parent_out = self.parent_call_module
+            if len(graph_outputs) == 1:
+                graph_outputs = graph_outputs[0]
+
+        assert isinstance(graph_outputs, (list, torch.fx.Node))
+
+        self.graph.output(graph_outputs)
+
+        # Rewrite outputs in parent module
+        if parent_out is None:
+            return
+
+        parent_out.meta["val"] = (
+            graph_outputs.meta.get("val")
+            if isinstance(graph_outputs, torch.fx.Node)
+            else [o.meta.get("val") for o in graph_outputs]
+        )
+
+        if len(orig_outputs) == 1 and signature is None:
+            self.parent.node_map[orig_outputs[0]] = parent_out
+        else:
+            for i, orig_output in enumerate(orig_outputs):
+                # Use Proxy to record getitem access.
+                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
+                proxy_out.meta["val"] = orig_output.meta.get("val")
+                self.parent.node_map[orig_output] = proxy_out
+
+        if self.cached_graph_module is not None:
+            _verify_graph_equivalence(self.cached_graph_module, self.module)
+
+    def copy_node(self, node):
+        self.print("copying", node.format_node())
+        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
+        self.seen_nodes[node.name] = node
+
+    def run_outer(self):
+        i = 0
+        for node in self.flat_graph.nodes:
+            self.print(i, node.meta.get("nn_module_stack"), node.format_node())
+            i += 1
+
+        # Copy all graph inputs
+        node_idx: int = 0
+        node = self.nodes[node_idx]
+        while node.op == "placeholder":
+            self.copy_node(node)
+            node_idx += 1
+            node = self.nodes[node_idx]
+
+        self.run_from(node_idx)
+
+        # Copy graph outputs
+        for node in self.flat_graph.nodes:
+            if node.op == "output":
+                self.copy_node(node)
+
+    def print(self, *args, **kwargs):
+        if self.verbose:
+            print(*args, **kwargs)
+
+    def run_from(self, node_idx):
+        module_idx = 0
+        # Walk through the graph, building up a new graph with the right submodules
+        while node_idx < len(self.nodes):
+            node = self.nodes[node_idx]
+            assert node.op != "placeholder"
+
+            self.print()
+            self.print("STEP", node_idx, node.format_node())
+            self.print(self.module_stack)
+            if node.op == "output":
+                if len(self.module_stack) == 1:
+                    # We want the output node of the original graph to be handled
+                    # specially by the outermost stack frame (in run_outer). So
+                    # skip finalization here.
+                    return node_idx
+
+                # We've reached the end of the graph. Wrap up all the existing stack frames.
+                self.finalize_outputs()
+                return node_idx
+
+            node_module_stack = (
+                [path for path, ty in node.meta["nn_module_stack"].values()]
+                if "nn_module_stack" in node.meta
+                else self.module_stack
+            )
+            if node_module_stack[: len(self.module_stack)] != self.module_stack:
+                # This means that the current module is done executing and the
+                # current node is the beginning of a new module.
+                #
+                # In this case, we should finalize this module and return without
+                # incrementing the node counter.
+                self.finalize_outputs()
+                self.print("outlining", self.fqn)
+                self.print(self.graph)
+                return node_idx
+
+            assert node_module_stack is not None
+
+            if _is_prefix(self.module_stack, node_module_stack):
+                # This means that the current node represents the execution of a new
+                # module.
+                next_module = node_module_stack[len(self.module_stack)]
+                self.print("Creating new stack frame for", next_module)
+                # Run a nested version of module outliner from the current node
+                # counter. Once it is complete, continue from that point.
+                node_idx = _ModuleFrame(
+                    self.flat_graph,
+                    self.nodes,
+                    self.seen_nodes,
+                    self.seen_modules,
+                    self,
+                    self.module_stack + [next_module],
+                    list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
+                    self.module_call_graph,
+                ).run_from(node_idx)
+                module_idx += 1
+                continue
+
+            # The only remaining possibility is that we are in the right stack
+            # frame. Copy the node into this frame's graph and increment the node counter.
+            assert node_module_stack == self.module_stack
+            self.copy_node(node)
+            node_idx += 1
+
+
+def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
+    seen_nodes: Dict[str, torch.fx.Node] = {}
+    seen_modules: Dict[int, torch.nn.Module] = {}
+    _ModuleFrame(
+        orig_graph,
+        tuple(orig_graph.nodes),
+        seen_nodes,
+        seen_modules,
+        None,
+        [""],
+        "",
+        {
+            entry.fqn: entry.signature
+            for entry in root_module.module_call_graph
+            if entry.signature
+        },
+        module=root_module,
+    ).run_outer()
+
+
+def _sink_params(
+    module: torch.nn.Module,
+    inputs_to_state: Dict[str, str],
+    scope: List[str],
+):
+    """Sink params, buffers, and constants from graph inputs into get_attr nodes.
+
+    Exported modules are purely functional, so they pass their parameters and
+    buffers in as inputs to the graph.
+
+    To replicate eager's semantics, we need to get them from the module state
+    via get_attr instead.
+
+    module: GraphModule, potentially containing nested submodules.
+    inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
+    scope: tracks where we are in the module hierarchy, so that we can emit the
+        right `getattr(self, "foo.bar")` calls, etc.
+    """
+    # We need to use _modules here instead of named_children(), because we
+    # explicitly want duplicate modules to show up in the traversal.
+    for name, submodule in module._modules.items():
+        _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name])
+
+    if not hasattr(module, "graph"):
+        # Not all modules have graphs defined; empty modules with no operations (like ParameterList) do not.
+        return
+
+    graph = module.graph
+    inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
+    the_last_input = inputs[-1]
+
+    # Also remove the state inputs from the arguments of call_module nodes.
+    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
+    for node in call_module_nodes:
+        node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args))
+
+    for node in inputs:
+        if node.name not in inputs_to_state:
+            continue
+
+        if len(node.users) > 0:
+            state_name = inputs_to_state[node.name].split(".")
+            # If there's a mismatch between the scope name and the state name, then there must be
+            # multiple scopes pointing to the same state name, meaning some modules are shared.
+            # In such a case, we can simply skip updating the current node, because a later
+            # iteration will take care of this input node when the unique match between scope
+            # and state name occurs. To make sure this always happens, we enforce the invariant
+            # that no placeholder node in the unflattened graph appears in the inputs_to_state
+            # dict, which means all the extra input nodes have been handled.
+            if state_name[: len(scope)] != scope:
+                continue
+            attr_path = state_name[len(scope) :]
+            state_attr = _recursive_getattr(module, attr_path)
+            assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))
+
+            # Make sure the newly created get_attr node is placed after the last placeholder node
+            with graph.inserting_after(the_last_input):
+                new_node = graph.create_node("get_attr", ".".join(attr_path))
+
+            node.replace_all_uses_with(new_node, propagate_meta=True)
+        graph.erase_node(node)
+    if isinstance(module, InterpreterModule):
+        module.finalize()
+
+
+def _recursive_getattr(obj, attr_path):
+    for attr in attr_path:
+        obj = getattr(obj, attr)
+
+    return obj
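+
+# Illustrative only (not in the upstream file):
+#     _recursive_getattr(model, ["layer1", "weight"])
+# is equivalent to getattr(getattr(model, "layer1"), "weight"), i.e. model.layer1.weight.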
diff --git a/MLPY/Lib/site-packages/torch/fft/__init__.py b/MLPY/Lib/site-packages/torch/fft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd07fb38dfcd97bac41c7dce6286c06698ef981
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fft/__init__.py
@@ -0,0 +1,1360 @@
+import sys
+
+import torch
+from torch._C import _add_docstr, _fft  # type: ignore[attr-defined]
+from torch._torch_docs import factory_common_args, common_args
+
+__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
+           'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
+           'hfft', 'ihfft', 'fftfreq', 'rfftfreq', 'fftshift', 'ifftshift',
+           'Tensor']
+
+Tensor = torch.Tensor
+
+# Note: This not only adds the docstrings for the spectral ops, but also
+# connects the torch.fft Python namespace to the torch._C._fft builtins.
+
+fft = _add_docstr(_fft.fft_fft, r"""
+fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the one dimensional discrete Fourier transform of :attr:`input`.
+
+Note:
+    The Fourier domain representation of any real signal satisfies the
+    Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
+    the positive and negative frequency terms even though, for real inputs, the
+    negative frequencies are redundant. :func:`~torch.fft.rfft` returns the
+    more compact one-sided representation where only the positive frequencies
+    are returned.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    n (int, optional): Signal length. If given, the input will either be zero-padded
+        or trimmed to this length before computing the FFT.
+    dim (int, optional): The dimension along which to take the one dimensional FFT.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.fft`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+        Calling the backward transform (:func:`~torch.fft.ifft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ifft`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.arange(4)
+    >>> t
+    tensor([0, 1, 2, 3])
+    >>> torch.fft.fft(t)
+    tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
+
+    >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
+    >>> torch.fft.fft(t)
+    tensor([12.+16.j, -8.+0.j, -4.-4.j,  0.-8.j])
+""".format(**common_args))
+
+ifft = _add_docstr(_fft.fft_ifft, r"""
+ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    n (int, optional): Signal length. If given, the input will either be zero-padded
+        or trimmed to this length before computing the IFFT.
+    dim (int, optional): The dimension along which to take the one dimensional IFFT.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ifft`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+        Calling the forward transform (:func:`~torch.fft.fft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ifft`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
+    >>> torch.fft.ifft(t)
+    tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
+""".format(**common_args))
+
+fft2 = _add_docstr(_fft.fft_fft2, r"""
+fft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the 2 dimensional discrete Fourier transform of :attr:`input`.
+Equivalent to :func:`~torch.fft.fftn` but FFTs only the last two dimensions by default.
+
+Note:
+    The Fourier domain representation of any real signal satisfies the
+    Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This
+    function always returns all positive and negative frequency terms even
+    though, for real inputs, half of these values are redundant.
+    :func:`~torch.fft.rfft2` returns the more compact one-sided representation
+    where only the positive frequencies of the last dimension are returned.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.fft2`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.ifft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n``
+        between the two transforms. This is required to make
+        :func:`~torch.fft.ifft2` the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> fft2 = torch.fft.fft2(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.fft2`
+    here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
+    >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
+
+""".format(**common_args))
+
+ifft2 = _add_docstr(_fft.fft_ifft2, r"""
+ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
+Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the IFFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ifft2`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.fft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ifft2`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> ifft2 = torch.fft.ifft2(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.ifft2`
+    here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
+
+    >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
+    >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
+
+""".format(**common_args))
+
+fftn = _add_docstr(_fft.fft_fftn, r"""
+fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the N dimensional discrete Fourier transform of :attr:`input`.
+
+Note:
+    The Fourier domain representation of any real signal satisfies the
+    Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
+    function always returns all positive and negative frequency terms even
+    though, for real inputs, half of these values are redundant.
+    :func:`~torch.fft.rfftn` returns the more compact one-sided representation
+    where only the positive frequencies of the last dimension are returned.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.fftn`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.ifftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n``
+        between the two transforms. This is required to make
+        :func:`~torch.fft.ifftn` the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> fftn = torch.fft.fftn(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.fftn`
+    here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
+    >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
+
+""".format(**common_args))
+
+ifftn = _add_docstr(_fft.fft_ifftn, r"""
+ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the IFFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ifftn`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.fftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ifftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> x = torch.rand(10, 10, dtype=torch.complex64)
+    >>> ifftn = torch.fft.ifftn(x)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.ifftn`
+    here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
+
+    >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
+    >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
+
+""".format(**common_args))
+
+rfft = _add_docstr(_fft.fft_rfft, r"""
+rfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the one dimensional Fourier transform of real-valued :attr:`input`.
+
+The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``, so
+the output contains only the positive frequencies below the Nyquist frequency.
+To compute the full output, use :func:`~torch.fft.fft`.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the real input tensor
+    n (int, optional): Signal length. If given, the input will either be zero-padded
+        or trimmed to this length before computing the real FFT.
+    dim (int, optional): The dimension along which to take the one dimensional real FFT.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.rfft`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+        Calling the backward transform (:func:`~torch.fft.irfft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfft`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.arange(4)
+    >>> t
+    tensor([0, 1, 2, 3])
+    >>> torch.fft.rfft(t)
+    tensor([ 6.+0.j, -2.+2.j, -2.+0.j])
+
+    Compare against the full output from :func:`~torch.fft.fft`:
+
+    >>> torch.fft.fft(t)
+    tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
+
+    Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
+    At the Nyquist frequency ``T[-2] == T[2]`` is its own symmetric pair,
+    and therefore must always be real-valued.
+""".format(**common_args))
+
+irfft = _add_docstr(_fft.fft_irfft, r"""
+irfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the inverse of :func:`~torch.fft.rfft`.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
+domain, as produced by :func:`~torch.fft.rfft`. By the Hermitian property, the
+output will be real-valued.
+
+Note:
+    Some input frequencies must be real-valued to satisfy the Hermitian
+    property. In these cases the imaginary component will be ignored.
+    For example, any imaginary component in the zero-frequency term cannot
+    be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`n`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. So, it is recommended to always pass the signal length :attr:`n`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+    With default arguments, the size of the transformed dimension should be (2^n + 1), since the
+    argument `n` defaults to the even output size 2 * (transformed_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor representing a half-Hermitian signal
+    n (int, optional): Output signal length. This determines the length of the
+        output signal. If given, the input will either be zero-padded or trimmed to this
+        length before computing the real IFFT.
+        Defaults to even output: ``n=2*(input.size(dim) - 1)``.
+    dim (int, optional): The dimension along which to take the one dimensional real IFFT.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.irfft`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
+
+        Calling the forward transform (:func:`~torch.fft.rfft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfft`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.linspace(0, 1, 5)
+    >>> t
+    tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
+    >>> T = torch.fft.rfft(t)
+    >>> T
+    tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j])
+
+    Without specifying the output length to :func:`~torch.fft.irfft`, the output
+    will not round-trip properly because the input is odd-length:
+
+    >>> torch.fft.irfft(T)
+    tensor([0.1562, 0.3511, 0.7812, 1.2114])
+
+    So, it is recommended to always pass the signal length :attr:`n`:
+
+    >>> roundtrip = torch.fft.irfft(T, t.numel())
+    >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
+
+""".format(**common_args))
+
+rfft2 = _add_docstr(_fft.fft_rfft2, r"""
+rfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the 2-dimensional discrete Fourier transform of real :attr:`input`.
+Equivalent to :func:`~torch.fft.rfftn` but FFTs only the last two dimensions by default.
+
+The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``,
+so the full :func:`~torch.fft.fft2` output contains redundant information.
+:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
+dimension.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.rfft2`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.irfft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfft2`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.rand(10, 10)
+    >>> rfft2 = torch.fft.rfft2(t)
+    >>> rfft2.size()
+    torch.Size([10, 6])
+
+    Compared against the full output from :func:`~torch.fft.fft2`, we have all
+    elements up to the Nyquist frequency.
+
+    >>> fft2 = torch.fft.fft2(t)
+    >>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2`
+    here is equivalent to a combination of :func:`~torch.fft.fft` and
+    :func:`~torch.fft.rfft`:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
+    >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
+
+""".format(**common_args))
+
+irfft2 = _add_docstr(_fft.fft_irfft2, r"""
+irfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the inverse of :func:`~torch.fft.rfft2`.
+Equivalent to :func:`~torch.fft.irfftn` but IFFTs only the last two dimensions by default.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
+domain, as produced by :func:`~torch.fft.rfft2`. By the Hermitian property, the
+output will be real-valued.
+
+Note:
+    Some input frequencies must be real-valued to satisfy the Hermitian
+    property. In these cases the imaginary component will be ignored.
+    For example, any imaginary component in the zero-frequency term cannot
+    be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`s`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. So, it is recommended to always pass the signal shape :attr:`s`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+    With default arguments, the size of the last dimension should be (2^n + 1), since the
+    argument `s` defaults to the even output size 2 * (last_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Defaults to even output in the last dimension:
+        ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        The last dimension must be the half-Hermitian compressed dimension.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.irfft2`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.rfft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfft2`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.rand(10, 9)
+    >>> T = torch.fft.rfft2(t)
+
+    Without specifying the output length to :func:`~torch.fft.irfft2`, the output
+    will not round-trip properly because the input is odd-length in the last
+    dimension:
+
+    >>> torch.fft.irfft2(T).size()
+    torch.Size([10, 8])
+
+    So, it is recommended to always pass the signal shape :attr:`s`.
+
+    >>> roundtrip = torch.fft.irfft2(T, t.size())
+    >>> roundtrip.size()
+    torch.Size([10, 9])
+    >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
+
+""".format(**common_args))
+
+rfftn = _add_docstr(_fft.fft_rfftn, r"""
+rfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the N-dimensional discrete Fourier transform of real :attr:`input`.
+
+The FFT of a real signal is Hermitian-symmetric,
+``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full
+:func:`~torch.fft.fftn` output contains redundant information.
+:func:`~torch.fft.rfftn` instead omits the negative frequencies in the
+last dimension.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.rfftn`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.irfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.rand(10, 10)
+    >>> rfftn = torch.fft.rfftn(t)
+    >>> rfftn.size()
+    torch.Size([10, 6])
+
+    Compared against the full output from :func:`~torch.fft.fftn`, we have all
+    elements up to the Nyquist frequency.
+
+    >>> fftn = torch.fft.fftn(t)
+    >>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False)
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn`
+    here is equivalent to a combination of :func:`~torch.fft.fft` and
+    :func:`~torch.fft.rfft`:
+
+    >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
+    >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
+
+""".format(**common_args))
+
+irfftn = _add_docstr(_fft.fft_irfftn, r"""
+irfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the inverse of :func:`~torch.fft.rfftn`.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
+domain, as produced by :func:`~torch.fft.rfftn`. By the Hermitian property, the
+output will be real-valued.
+
+Note:
+    Some input frequencies must be real-valued to satisfy the Hermitian
+    property. In these cases the imaginary component will be ignored.
+    For example, any imaginary component in the zero-frequency term cannot
+    be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`s`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. So, it is recommended to always pass the signal shape :attr:`s`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+    With default arguments, the size of the last dimension should be (2^n + 1), since the
+    argument `s` defaults to the even output size 2 * (last_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Defaults to even output in the last dimension:
+        ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        The last dimension must be the half-Hermitian compressed dimension.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.irfftn`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.rfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.irfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.rand(10, 9)
+    >>> T = torch.fft.rfftn(t)
+
+    Without specifying the output length to :func:`~torch.fft.irfftn`, the output
+    will not round-trip properly because the input is odd-length in the last
+    dimension:
+
+    >>> torch.fft.irfftn(T).size()
+    torch.Size([10, 8])
+
+    So, it is recommended to always pass the signal shape :attr:`s`.
+
+    >>> roundtrip = torch.fft.irfftn(T, t.size())
+    >>> roundtrip.size()
+    torch.Size([10, 9])
+    >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
+
+""".format(**common_args))
+
+hfft = _add_docstr(_fft.fft_hfft, r"""
+hfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the one dimensional discrete Fourier transform of a Hermitian
+symmetric :attr:`input` signal.
+
+Note:
+
+    :func:`~torch.fft.hfft`/:func:`~torch.fft.ihfft` are analogous to
+    :func:`~torch.fft.rfft`/:func:`~torch.fft.irfft`. The real FFT expects
+    a real signal in the time-domain and gives a Hermitian symmetry in the
+    frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in
+    the time-domain and real-valued in the frequency-domain. For this reason,
+    special care needs to be taken with the length argument :attr:`n`, in the
+    same way as with :func:`~torch.fft.irfft`.
+
+Note:
+    Because the signal is Hermitian in the time-domain, the result will be
+    real in the frequency domain. Note that some input frequencies must be
+    real-valued to satisfy the Hermitian property. In these cases the imaginary
+    component will be ignored. For example, any imaginary component in
+    ``input[0]`` would result in one or more complex frequency terms which
+    cannot be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`n`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. So, it is recommended to always pass the signal length :attr:`n`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However, it only supports power-of-two signal lengths in every transformed dimension.
+    With default arguments, the size of the transformed dimension should be (2^n + 1), since the
+    argument `n` defaults to the even output size 2 * (transformed_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor representing a half-Hermitian signal
+    n (int, optional): Output signal length. This determines the length of the
+        real output. If given, the input will either be zero-padded or trimmed to this
+        length before computing the Hermitian FFT.
+        Defaults to even output: ``n=2*(input.size(dim) - 1)``.
+    dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.hfft`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
+
+        Calling the backward transform (:func:`~torch.fft.ihfft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfft`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    Taking a real-valued frequency signal and bringing it into the time domain
+    gives Hermitian symmetric output:
+
+    >>> t = torch.linspace(0, 1, 5)
+    >>> t
+    tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
+    >>> T = torch.fft.ifft(t)
+    >>> T
+    tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j,
+            -0.1250+0.1720j])
+
+    Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()``, so the
+    negative frequency terms are redundant. We can thus compute the forward
+    transform without considering negative frequencies:
+
+    >>> torch.fft.hfft(T[:3], n=5)
+    tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
+
+    Like with :func:`~torch.fft.irfft`, the output defaults to an even length, so
+    without passing :attr:`n` the odd-length signal does not round-trip:
+
+    >>> torch.fft.hfft(T[:3])
+    tensor([0.1250, 0.2809, 0.6250, 0.9691])
+""".format(**common_args))
+
+ihfft = _add_docstr(_fft.fft_ihfft, r"""
+ihfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
+
+Computes the inverse of :func:`~torch.fft.hfft`.
+
+:attr:`input` must be a real-valued signal, interpreted in the Fourier domain.
+The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
+:func:`~torch.fft.ihfft` represents this in the one-sided form where only the
+positive frequencies below the Nyquist frequency are included. To compute the
+full output, use :func:`~torch.fft.ifft`.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
+Args:
+    input (Tensor): the real input tensor
+    n (int, optional): Signal length. If given, the input will either be zero-padded
+        or trimmed to this length before computing the Hermitian IFFT.
+    dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ihfft`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+        Calling the forward transform (:func:`~torch.fft.hfft`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfft`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> t = torch.arange(5)
+    >>> t
+    tensor([0, 1, 2, 3, 4])
+    >>> torch.fft.ihfft(t)
+    tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
+
+    Compare against the full output from :func:`~torch.fft.ifft`:
+
+    >>> torch.fft.ifft(t)
+    tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
+            -0.5000+0.6882j])
+""".format(**common_args))
+
+hfft2 = _add_docstr(_fft.fft_hfft2, r"""
+hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric
+:attr:`input` signal. Equivalent to :func:`~torch.fft.hfftn` but only
+transforms the last two dimensions by default.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the time
+domain. By the Hermitian property, the Fourier transform will be real-valued.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, the size of the last dimension should be (2^n + 1), since the
+    argument `s` defaults to an even output size of 2 * (last_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the Hermitian FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Defaults to even output in the last dimension:
+        ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        The last dimension must be the half-Hermitian compressed dimension.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.hfft2`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.ihfft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfft2`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    Starting from a real frequency-space signal, we can generate a
+    Hermitian-symmetric time-domain signal:
+
+    >>> T = torch.rand(10, 9)
+    >>> t = torch.fft.ihfft2(T)
+
+    Without specifying the output length to :func:`~torch.fft.hfft2`, the
+    output will not round-trip properly because the input is odd-length in the
+    last dimension:
+
+    >>> torch.fft.hfft2(t).size()
+    torch.Size([10, 8])
+
+    So, it is recommended to always pass the signal shape :attr:`s`.
+
+    >>> roundtrip = torch.fft.hfft2(t, T.size())
+    >>> roundtrip.size()
+    torch.Size([10, 9])
+    >>> torch.allclose(roundtrip, T)
+    True
+
+""".format(**common_args))
+
+ihfft2 = _add_docstr(_fft.fft_ihfft2, r"""
+ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
+
+Computes the 2-dimensional inverse discrete Fourier transform of real
+:attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
+two last dimensions by default.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the Hermitian IFFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: last two dimensions.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ihfft2`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.hfft2`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfft2`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> T = torch.rand(10, 10)
+    >>> t = torch.fft.ihfft2(T)
+    >>> t.size()
+    torch.Size([10, 6])
+
+    Compared against the full output from :func:`~torch.fft.ifft2`, the
+    Hermitian time-space signal takes up only half the space.
+
+    >>> fftn = torch.fft.ifft2(T)
+    >>> torch.allclose(fftn[..., :6], t)
+    True
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.ihfft2`
+    here is equivalent to a combination of :func:`~torch.fft.ifft` and
+    :func:`~torch.fft.ihfft`:
+
+    >>> two_ffts = torch.fft.ifft(torch.fft.ihfft(T, dim=1), dim=0)
+    >>> torch.allclose(t, two_ffts)
+    True
+
+""".format(**common_args))
+
+hfftn = _add_docstr(_fft.fft_hfftn, r"""
+hfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric
+:attr:`input` signal.
+
+:attr:`input` is interpreted as a one-sided Hermitian signal in the time
+domain. By the Hermitian property, the Fourier transform will be real-valued.
+
+Note:
+    :func:`~torch.fft.hfftn`/:func:`~torch.fft.ihfftn` are analogous to
+    :func:`~torch.fft.rfftn`/:func:`~torch.fft.irfftn`. The real FFT expects
+    a real signal in the time-domain and gives Hermitian symmetry in the
+    frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in
+    the time-domain and real-valued in the frequency-domain. For this reason,
+    special care needs to be taken with the shape argument :attr:`s`, in the
+    same way as with :func:`~torch.fft.irfftn`.
+
+Note:
+    Some input frequencies must be real-valued to satisfy the Hermitian
+    property. In these cases the imaginary component will be ignored.
+    For example, any imaginary component in the zero-frequency term cannot
+    be represented in a real output and so will always be ignored.
+
+Note:
+    The correct interpretation of the Hermitian input depends on the length of
+    the original data, as given by :attr:`s`. This is because each input shape
+    could correspond to either an odd or even length signal. By default, the
+    signal is assumed to be even length and odd signals will not round-trip
+    properly. It is recommended to always pass the signal shape :attr:`s`.
+
+Note:
+    Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+    With default arguments, the size of the last dimension should be (2^n + 1), since the
+    argument `s` defaults to an even output size of 2 * (last_dim_size - 1).
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the real FFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Defaults to even output in the last dimension:
+        ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        The last dimension must be the half-Hermitian compressed dimension.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the forward transform
+        (:func:`~torch.fft.hfftn`), these correspond to:
+
+        * ``"forward"`` - normalize by ``1/n``
+        * ``"backward"`` - no normalization
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical FFT size.
+        Calling the backward transform (:func:`~torch.fft.ihfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (no normalization).
+
+Keyword args:
+    {out}
+
+Example:
+
+    Starting from a real frequency-space signal, we can generate a
+    Hermitian-symmetric time-domain signal:
+
+    >>> T = torch.rand(10, 9)
+    >>> t = torch.fft.ihfftn(T)
+
+    Without specifying the output length to :func:`~torch.fft.hfftn`, the
+    output will not round-trip properly because the input is odd-length in the
+    last dimension:
+
+    >>> torch.fft.hfftn(t).size()
+    torch.Size([10, 8])
+
+    So, it is recommended to always pass the signal shape :attr:`s`.
+
+    >>> roundtrip = torch.fft.hfftn(t, T.size())
+    >>> roundtrip.size()
+    torch.Size([10, 9])
+    >>> torch.allclose(roundtrip, T)
+    True
+
+""".format(**common_args))
+
+ihfftn = _add_docstr(_fft.fft_ihfftn, r"""
+ihfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
+
+Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`.
+
+:attr:`input` must be a real-valued signal, interpreted in the Fourier domain.
+The n-dimensional IFFT of a real signal is Hermitian-symmetric,
+``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`~torch.fft.ihfftn` represents
+this in the one-sided form where only the positive frequencies below the
+Nyquist frequency are included in the last signal dimension. To compute the
+full output, use :func:`~torch.fft.ifftn`.
+
+Note:
+    Supports torch.half on CUDA with GPU Architecture SM53 or greater.
+    However it only supports powers of 2 signal length in every transformed dimension.
+
+Args:
+    input (Tensor): the input tensor
+    s (Tuple[int], optional): Signal size in the transformed dimensions.
+        If given, each dimension ``dim[i]`` will either be zero-padded or
+        trimmed to the length ``s[i]`` before computing the Hermitian IFFT.
+        If a length ``-1`` is specified, no padding is done in that dimension.
+        Default: ``s = [input.size(d) for d in dim]``
+    dim (Tuple[int], optional): Dimensions to be transformed.
+        Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+    norm (str, optional): Normalization mode. For the backward transform
+        (:func:`~torch.fft.ihfftn`), these correspond to:
+
+        * ``"forward"`` - no normalization
+        * ``"backward"`` - normalize by ``1/n``
+        * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)
+
+        Where ``n = prod(s)`` is the logical IFFT size.
+        Calling the forward transform (:func:`~torch.fft.hfftn`) with the same
+        normalization mode will apply an overall normalization of ``1/n`` between
+        the two transforms. This is required to make :func:`~torch.fft.ihfftn`
+        the exact inverse.
+
+        Default is ``"backward"`` (normalize by ``1/n``).
+
+Keyword args:
+    {out}
+
+Example:
+
+    >>> T = torch.rand(10, 10)
+    >>> ihfftn = torch.fft.ihfftn(T)
+    >>> ihfftn.size()
+    torch.Size([10, 6])
+
+    Compared against the full output from :func:`~torch.fft.ifftn`, we have all
+    elements up to the Nyquist frequency.
+
+    >>> ifftn = torch.fft.ifftn(T)
+    >>> torch.allclose(ifftn[..., :6], ihfftn)
+    True
+
+    The discrete Fourier transform is separable, so :func:`~torch.fft.ihfftn`
+    here is equivalent to a combination of :func:`~torch.fft.ihfft` and
+    :func:`~torch.fft.ifft`:
+
+    >>> two_iffts = torch.fft.ifft(torch.fft.ihfft(T, dim=1), dim=0)
+    >>> torch.allclose(ihfftn, two_iffts)
+    True
+
+""".format(**common_args))
+
+fftfreq = _add_docstr(_fft.fft_fftfreq, r"""
+fftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
+
+Computes the discrete Fourier Transform sample frequencies for a signal of size :attr:`n`.
+
+Note:
+    By convention, :func:`~torch.fft.fft` returns positive frequency terms
+    first, followed by the negative frequencies in reverse order, so that
+    ``f[-i]`` for all :math:`0 < i \leq n/2` in Python gives the negative
+    frequency terms. For an FFT of length :attr:`n` and with inputs spaced in
+    length unit :attr:`d`, the frequencies are::
+
+        f = [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1] / (d * n)
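+
+    For instance, a quick illustrative check of this formula for ``n=4, d=0.5``:
+
+    >>> n, d = 4, 0.5
+    >>> manual = torch.tensor([0., 1., -2., -1.]) / (d * n)
+    >>> torch.allclose(torch.fft.fftfreq(n, d), manual)
+    True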
+
+Note:
+    For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
+    either negative or positive. :func:`~torch.fft.fftfreq` follows NumPy's
+    convention of taking it to be negative.
+
+Args:
+    n (int): the FFT length
+    d (float, optional): The sampling length scale.
+        The spacing between individual samples of the FFT input.
+        The default assumes unit spacing; dividing the result by the actual
+        spacing gives the result in physical frequency units.
+
+Keyword Args:
+    {out}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Example:
+
+    >>> torch.fft.fftfreq(5)
+    tensor([ 0.0000,  0.2000,  0.4000, -0.4000, -0.2000])
+
+    For even input, we can see the Nyquist frequency at ``f[2]`` is given as
+    negative:
+
+    >>> torch.fft.fftfreq(4)
+    tensor([ 0.0000,  0.2500, -0.5000, -0.2500])
+
+""".format(**factory_common_args))
+
+rfftfreq = _add_docstr(_fft.fft_rfftfreq, r"""
+rfftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
+
+Computes the sample frequencies for :func:`~torch.fft.rfft` with a signal of size :attr:`n`.
+
+Note:
+    :func:`~torch.fft.rfft` returns Hermitian one-sided output, so only the
+    positive frequency terms are returned. For a real FFT of length :attr:`n`
+    and with inputs spaced in length unit :attr:`d`, the frequencies are::
+
+        f = torch.arange(n // 2 + 1) / (d * n)
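+
+    For instance, a quick illustrative check for ``n=4``:
+
+    >>> torch.allclose(torch.fft.rfftfreq(4), torch.arange(4 // 2 + 1) / 4)
+    True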
+
+Note:
+    For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
+    either negative or positive. Unlike :func:`~torch.fft.fftfreq`,
+    :func:`~torch.fft.rfftfreq` always returns it as positive.
+
+Args:
+    n (int): the real FFT length
+    d (float, optional): The sampling length scale.
+        The spacing between individual samples of the FFT input.
+        The default assumes unit spacing; dividing the result by the actual
+        spacing gives the result in physical frequency units.
+
+Keyword Args:
+    {out}
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+
+Example:
+
+    >>> torch.fft.rfftfreq(5)
+    tensor([0.0000, 0.2000, 0.4000])
+
+    >>> torch.fft.rfftfreq(4)
+    tensor([0.0000, 0.2500, 0.5000])
+
+    Compared to the output from :func:`~torch.fft.fftfreq`, we see that the
+    Nyquist frequency at ``f[2]`` has changed sign:
+
+    >>> torch.fft.fftfreq(4)
+    tensor([ 0.0000,  0.2500, -0.5000, -0.2500])
+
+""".format(**factory_common_args))
+
+fftshift = _add_docstr(_fft.fft_fftshift, r"""
+fftshift(input, dim=None) -> Tensor
+
+Reorders n-dimensional FFT data, as provided by :func:`~torch.fft.fftn`, to have
+negative frequency terms first.
+
+This performs a periodic shift of n-dimensional data such that the origin
+``(0, ..., 0)`` is moved to the center of the tensor. Specifically, to
+``input.shape[dim] // 2`` in each selected dimension.
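+
+Equivalently, for a 1-D tensor this is just a ``torch.roll`` by ``input.shape[0] // 2``,
+as in this small illustrative check:
+
+    >>> x = torch.arange(6.)
+    >>> torch.equal(torch.fft.fftshift(x), torch.roll(x, 3))
+    True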
+
+Note:
+    By convention, the FFT returns positive frequency terms first, followed by
+    the negative frequencies in reverse order, so that ``f[-i]`` for all
+    :math:`0 < i \leq n/2` in Python gives the negative frequency terms.
+    :func:`~torch.fft.fftshift` rearranges all frequencies into ascending order
+    from negative to positive with the zero-frequency term in the center.
+
+Note:
+    For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
+    either negative or positive. :func:`~torch.fft.fftshift` always puts the
+    Nyquist term at the 0-index. This is the same convention used by
+    :func:`~torch.fft.fftfreq`.
+
+Args:
+    input (Tensor): the tensor in FFT order
+    dim (int, Tuple[int], optional): The dimensions to rearrange.
+        Only dimensions specified here will be rearranged, any other dimensions
+        will be left in their original order.
+        Default: All dimensions of :attr:`input`.
+
+Example:
+
+    >>> f = torch.fft.fftfreq(4)
+    >>> f
+    tensor([ 0.0000,  0.2500, -0.5000, -0.2500])
+
+    >>> torch.fft.fftshift(f)
+    tensor([-0.5000, -0.2500,  0.0000,  0.2500])
+
+    Also notice that the Nyquist frequency term at ``f[2]`` was moved to the
+    beginning of the tensor.
+
+    This also works for multi-dimensional transforms:
+
+    >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1)
+    >>> x
+    tensor([[ 0.0000,  1.0000,  2.0000, -2.0000, -1.0000],
+            [ 0.1000,  1.1000,  2.1000, -1.9000, -0.9000],
+            [ 0.2000,  1.2000,  2.2000, -1.8000, -0.8000],
+            [-0.2000,  0.8000,  1.8000, -2.2000, -1.2000],
+            [-0.1000,  0.9000,  1.9000, -2.1000, -1.1000]])
+
+    >>> torch.fft.fftshift(x)
+    tensor([[-2.2000, -1.2000, -0.2000,  0.8000,  1.8000],
+            [-2.1000, -1.1000, -0.1000,  0.9000,  1.9000],
+            [-2.0000, -1.0000,  0.0000,  1.0000,  2.0000],
+            [-1.9000, -0.9000,  0.1000,  1.1000,  2.1000],
+            [-1.8000, -0.8000,  0.2000,  1.2000,  2.2000]])
+
+    :func:`~torch.fft.fftshift` can also be useful for spatial data. If our
+    data is defined on a centered grid (``[-(N//2), (N-1)//2]``) then we can
+    use the standard FFT defined on an uncentered grid (``[0, N)``) by first
+    applying an :func:`~torch.fft.ifftshift`.
+
+    >>> x_centered = torch.arange(-5, 5)
+    >>> x_uncentered = torch.fft.ifftshift(x_centered)
+    >>> fft_uncentered = torch.fft.fft(x_uncentered)
+
+    Similarly, we can convert the frequency domain components to centered
+    convention by applying :func:`~torch.fft.fftshift`.
+
+    >>> fft_centered = torch.fft.fftshift(fft_uncentered)
+
+    The inverse transform, from centered Fourier space back to centered spatial
+    data, can be performed by applying the inverse shifts in reverse order:
+
+    >>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered)))
+    >>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False)
+
+
+""")
+
+ifftshift = _add_docstr(_fft.fft_ifftshift, r"""
+ifftshift(input, dim=None) -> Tensor
+
+Inverse of :func:`~torch.fft.fftshift`.
+
+Args:
+    input (Tensor): the tensor in FFT order
+    dim (int, Tuple[int], optional): The dimensions to rearrange.
+        Only dimensions specified here will be rearranged, any other dimensions
+        will be left in their original order.
+        Default: All dimensions of :attr:`input`.
+
+Example:
+
+    >>> f = torch.fft.fftfreq(5)
+    >>> f
+    tensor([ 0.0000,  0.2000,  0.4000, -0.4000, -0.2000])
+
+    A round-trip through :func:`~torch.fft.fftshift` and
+    :func:`~torch.fft.ifftshift` gives the same result:
+
+    >>> shifted = torch.fft.fftshift(f)
+    >>> torch.fft.ifftshift(shifted)
+    tensor([ 0.0000,  0.2000,  0.4000, -0.4000, -0.2000])
+
+""")
diff --git a/MLPY/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..366ebc07852ba0cdee9ffb10c48c89a9e1581991
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/func/__init__.py b/MLPY/Lib/site-packages/torch/func/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc09ec09c286cf71ccf5dab0df16f535d673910a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/func/__init__.py
@@ -0,0 +1,13 @@
+from torch._functorch.eager_transforms import (
+    vjp,
+    jvp,
+    jacrev,
+    jacfwd,
+    hessian,
+    functionalize,
+    linearize
+)
+from torch._functorch.apis import grad, grad_and_value
+from torch._functorch.functional_call import functional_call, stack_module_state
+from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
+from torch._functorch.apis import vmap
diff --git a/MLPY/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd25816d50d626323cf2b37b645e2731b29311ba
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/functional.py b/MLPY/Lib/site-packages/torch/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..90a08b93b8e8bf00dea9b2ef5248534073a2220c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/functional.py
@@ -0,0 +1,1983 @@
+from typing import (
+    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
+)
+import operator
+import itertools
+
+import torch
+from torch._C import _add_docstr
+import torch.nn.functional as F
+from ._lowrank import svd_lowrank, pca_lowrank
+from .overrides import (
+    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
+    handle_torch_function)
+from ._jit_internal import boolean_dispatch
+from ._jit_internal import _overload as overload
+
+Tensor = torch.Tensor
+from torch import _VF
+
+__all__ = [
+    'atleast_1d',
+    'atleast_2d',
+    'atleast_3d',
+    'align_tensors',
+    'broadcast_shapes',
+    'broadcast_tensors',
+    'cartesian_prod',
+    'block_diag',
+    'cdist',
+    'chain_matmul',
+    'einsum',
+    'istft',
+    'lu',
+    'norm',
+    'meshgrid',
+    'pca_lowrank',
+    'split',
+    'stft',
+    'svd_lowrank',
+    'tensordot',
+    'unique',
+    'unique_consecutive',
+    'unravel_index',
+]
+
+
+def broadcast_tensors(*tensors):
+    r"""broadcast_tensors(*tensors) -> List of Tensors
+
+    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
+
+    Args:
+        *tensors: any number of tensors of the same type
+
+    .. warning::
+
+        More than one element of a broadcasted tensor may refer to a single
+        memory location. As a result, in-place operations (especially ones that
+        are vectorized) may result in incorrect behavior. If you need to write
+        to the tensors, please clone them first.
+
+    Example::
+
+        >>> x = torch.arange(3).view(1, 3)
+        >>> y = torch.arange(2).view(2, 1)
+        >>> a, b = torch.broadcast_tensors(x, y)
+        >>> a.size()
+        torch.Size([2, 3])
+        >>> a
+        tensor([[0, 1, 2],
+                [0, 1, 2]])
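+
+        >>> # An illustrative check: ``a`` is an expanded view of ``x`` (note the
+        >>> # zero stride in dim 0), which is why cloning is recommended before
+        >>> # any in-place write:
+        >>> a.stride()
+        (0, 1)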
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(broadcast_tensors, tensors, *tensors)
+    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
+
+
+def broadcast_shapes(*shapes):
+    r"""broadcast_shapes(*shapes) -> Size
+
+    Similar to :func:`broadcast_tensors` but for shapes.
+
+    This is equivalent to
+    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
+    but avoids the need to create intermediate tensors. This is useful for
+    broadcasting tensors of common batch shape but different rightmost shape,
+    e.g. to broadcast mean vectors with covariance matrices.
+
+    Example::
+
+        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
+        torch.Size([1, 3, 2])
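+
+        >>> # An illustrative check of the equivalence stated above:
+        >>> torch.broadcast_tensors(*map(torch.empty, [(2,), (3, 1), (1, 1, 1)]))[0].shape
+        torch.Size([1, 3, 2])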
+
+    Args:
+        \*shapes (torch.Size): Shapes of tensors.
+
+    Returns:
+        shape (torch.Size): A shape compatible with all input shapes.
+
+    Raises:
+        RuntimeError: If shapes are incompatible.
+    """
+    # This wrapper exists to support variadic args.
+    # TODO Move this to C++ once the jit has better support for torch.Size.
+    if not torch.jit.is_tracing():
+        max_len = 0
+        for shape in shapes:
+            if isinstance(shape, (int, torch.SymInt)):
+                if max_len < 1:
+                    max_len = 1
+            elif isinstance(shape, (tuple, list)):
+                s = len(shape)
+                if max_len < s:
+                    max_len = s
+        result = [1] * max_len
+
+        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+        for shape in shapes:
+            if isinstance(shape, (int, torch.SymInt)):
+                shape = (shape,)
+            if isinstance(shape, (tuple, list)):
+                for i in range(-1, -1 - len(shape), -1):
+                    if shape[i] < 0:
+                        raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
+                    # NB: result is initialized to 1 so this is effectively an
+                    # equals one test
+                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
+                        continue
+                    if result[i] != 1:
+                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
+                    result[i] = shape[i]
+            else:
+                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
+        return torch.Size(result)
+    else:
+        # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
+        with torch.no_grad():
+            scalar = torch.zeros((), device="cpu")
+            tensors = [scalar.expand(shape) for shape in shapes]
+            tensors = broadcast_tensors(*tensors)
+            return tensors[0].shape
+
+
+def split(
+    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
+) -> Tuple[Tensor, ...]:
+    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
+
+    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
+    be split into equally sized chunks (if possible). Last chunk will be smaller if
+    the tensor size along the given dimension :attr:`dim` is not divisible by
+    :attr:`split_size`.
+
+    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
+    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
+    to :attr:`split_size_or_sections`.
+
+    Args:
+        tensor (Tensor): tensor to split.
+        split_size_or_sections (int) or (list(int)): size of a single chunk or
+            list of sizes for each chunk
+        dim (int): dimension along which to split the tensor.
+
+    Example::
+
+        >>> a = torch.arange(10).reshape(5, 2)
+        >>> a
+        tensor([[0, 1],
+                [2, 3],
+                [4, 5],
+                [6, 7],
+                [8, 9]])
+        >>> torch.split(a, 2)
+        (tensor([[0, 1],
+                 [2, 3]]),
+         tensor([[4, 5],
+                 [6, 7]]),
+         tensor([[8, 9]]))
+        >>> torch.split(a, [1, 4])
+        (tensor([[0, 1]]),
+         tensor([[2, 3],
+                 [4, 5],
+                 [6, 7],
+                 [8, 9]]))
+    """
+    if has_torch_function_unary(tensor):
+        return handle_torch_function(
+            split, (tensor,), tensor, split_size_or_sections, dim=dim)
+    # Overwriting reason:
+    # This dispatches to two ATen functions depending on the type of
+    # split_size_or_sections. The branching code is in _tensor.py, which we
+    # call here.
+    return tensor.split(split_size_or_sections, dim)
+
+
+def einsum(*args: Any) -> Tensor:
+    r"""einsum(equation, *operands) -> Tensor
+
+    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
+    based on the Einstein summation convention.
+
+    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
+    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
+    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
+    with some subscript and define which subscripts are part of the output. The output is then computed by summing
+    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
+    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
+    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
+
+    Equation:
+
+        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
+        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
+        comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
+        must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
+        repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
+        must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
+        appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
+        The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
+        on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
+
+        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
+        followed by the subscripts for the output. For instance, the following equation computes the transpose of a
+        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
+        at most once for the output.
+
+        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
+        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
+        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
+        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
+        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
+        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
+        before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
+        batch matrix multiplication `'...ij,...jk'`.
+
+        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
+        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
+
+    .. note::
+
+        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
+        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
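+
+        For example, a small illustrative check where the dimensions covered by the
+        ellipsis are summed out because it is omitted from the output:
+
+        >>> x = torch.randn(2, 3, 4)
+        >>> torch.einsum('...i->i', x).shape
+        torch.Size([4])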
+
+    .. note::
+
+        This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to
+        consume less memory by optimizing contraction order. This optimization occurs when there are at least three
+        inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem,
+        thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available,
+        the default order is to contract from left to right.
+
+        To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path
+        calculation: `torch.backends.opt_einsum.enabled = False`
+
+        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
+        `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and
+        'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
+        the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).
+
+    .. note::
+
+        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
+        subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists
+        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
+        subscripts, e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out])`. Python's `Ellipsis` object
+        may be provided in a sublist to enable broadcasting as described in the Equation section above.
+
+    Args:
+        equation (str): The subscripts for the Einstein summation.
+        operands (List[Tensor]): The tensors to compute the Einstein summation of.
+
+    Examples::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> # trace
+        >>> torch.einsum('ii', torch.randn(4, 4))
+        tensor(-1.2104)
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> # diagonal
+        >>> torch.einsum('ii->i', torch.randn(4, 4))
+        tensor([-0.1034,  0.7952, -0.2433,  0.4545])
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> # outer product
+        >>> x = torch.randn(5)
+        >>> y = torch.randn(4)
+        >>> torch.einsum('i,j->ij', x, y)
+        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
+                [-0.3744,  0.9381,  1.2685, -1.6070],
+                [ 0.7208, -1.8058, -2.4419,  3.0936],
+                [ 0.1713, -0.4291, -0.5802,  0.7350],
+                [ 0.5704, -1.4290, -1.9323,  2.4480]])
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> # batch matrix multiplication
+        >>> As = torch.randn(3, 2, 5)
+        >>> Bs = torch.randn(3, 5, 4)
+        >>> torch.einsum('bij,bjk->bik', As, Bs)
+        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
+                [-1.6706, -0.8097, -0.8025, -2.1183]],
+
+                [[ 4.2239,  0.3107, -0.5756, -0.2354],
+                [-1.4558, -0.3460,  1.5087, -0.8530]],
+
+                [[ 2.8153,  1.8787, -4.3839, -1.2112],
+                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> # with sublist format and ellipsis
+        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
+        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
+                [-1.6706, -0.8097, -0.8025, -2.1183]],
+
+                [[ 4.2239,  0.3107, -0.5756, -0.2354],
+                [-1.4558, -0.3460,  1.5087, -0.8530]],
+
+                [[ 2.8153,  1.8787, -4.3839, -1.2112],
+                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
+
+        >>> # batch permute
+        >>> A = torch.randn(2, 3, 4, 5)
+        >>> torch.einsum('...ij->...ji', A).shape
+        torch.Size([2, 3, 5, 4])
+
+        >>> # equivalent to torch.nn.functional.bilinear
+        >>> A = torch.randn(3, 5, 4)
+        >>> l = torch.randn(2, 5)
+        >>> r = torch.randn(2, 4)
+        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
+        tensor([[-0.3430, -5.2405,  0.4494],
+                [ 0.3311,  5.5201, -3.0356]])
+    """
+    import torch.backends.opt_einsum as opt_einsum
+    # This wrapper exists to support variadic args.
+    if len(args) < 2:
+        raise ValueError('einsum(): must specify the equation string and at least one operand, '
+                         'or at least one operand and its subscripts list')
+
+    equation = None
+    operands = None
+
+    if isinstance(args[0], torch.Tensor):
+        # Convert the subscript list format which is an interleaving of operand and its subscripts
+        # list with an optional output subscripts list at the end (see documentation for more details on this)
+        # to the equation string format by creating the equation string from the subscripts list and grouping the
+        # input operands into a tensorlist (List[Tensor]).
+        def parse_subscript(n: int) -> str:
+            if n == Ellipsis:
+                return '...'
+            if n >= 0 and n < 26:
+                return chr(ord('A') + n)
+            if n >= 26 and n < 52:
+                return chr(ord('a') + n - 26)
+            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+
+        # Parse subscripts for input operands
+        equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+
+        # Parse optional output subscripts (provided when the number of arguments is odd)
+        if len(args) % 2 == 1:
+            equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+            operands = args[:-1:2]
+        else:
+            operands = args[::2]
+    else:
+        equation = args[0]
+        operands = args[1:]
+
+    if has_torch_function(operands):
+        return handle_torch_function(einsum, operands, equation, *operands)
+
+    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
+        # the old interface of passing the operands as one list argument
+        _operands = operands[0]
+    # recurse in case operands contains a value that has a torch function
+        # in the original implementation this line is omitted
+        return einsum(equation, *_operands)
+
+    if len(operands) <= 2 or not opt_einsum.enabled:
+        # the path for contracting 0 or 1 time(s) is already optimized
+        # or the user has disabled using opt_einsum
+        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
+
+    path = None
+    if opt_einsum.is_available():
+        _opt_einsum = opt_einsum.get_opt_einsum()
+        tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
+        # flatten path for dispatching to C++
+        path = [item for pair in tupled_path for item in pair]
+    return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
+
+
+# This wrapper exists to support variadic args.
+if TYPE_CHECKING:
+    # The JIT doesn't understand Union, so only add type annotation for mypy
+    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
+                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+        return _meshgrid(*tensors, indexing=indexing)
+else:
+    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+        r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.
+
+        This is helpful when you want to visualize data over some
+        range of inputs. See below for a plotting example.
+
+        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
+        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
+        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
+        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
+        the output :math:`G_i` is constructed by expanding :math:`T_i`
+        to the result shape.
+
+        .. note::
+            0D inputs are treated equivalently to 1D inputs of a
+            single element.
+
+        .. warning::
+            `torch.meshgrid(*tensors)` currently has the same behavior
+            as calling `numpy.meshgrid(*arrays, indexing='ij')`.
+
+            In the future `torch.meshgrid` will transition to
+            `indexing='xy'` as the default.
+
+            https://github.com/pytorch/pytorch/issues/50276 tracks
+            this issue with the goal of migrating to NumPy's behavior.
+
+        .. seealso::
+
+            :func:`torch.cartesian_prod` has the same effect but it
+            collects the data in a tensor of vectors.
+
+        Args:
+            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
+                treated as tensors of size :math:`(1,)` automatically
+
+            indexing: (str, optional): the indexing mode, either "xy"
+                or "ij", defaults to "ij". See warning for future changes.
+
+                If "xy" is selected, the first dimension corresponds
+                to the cardinality of the second input and the second
+                dimension corresponds to the cardinality of the first
+                input.
+
+                If "ij" is selected, the dimensions are in the same
+                order as the cardinality of the inputs.
+
+        Returns:
+            seq (sequence of Tensors): If the input has :math:`N`
+            tensors of size :math:`S_0 \ldots S_{N-1}`, then the
+            output will also have :math:`N` tensors, where each tensor
+            is of shape :math:`(S_0, ..., S_{N-1})`.
+
+        Example::
+
+            >>> x = torch.tensor([1, 2, 3])
+            >>> y = torch.tensor([4, 5, 6])
+
+            Observe the element-wise pairings across the grid, (1, 4),
+            (1, 5), ..., (3, 6). This is the same thing as the
+            cartesian product.
+            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
+            >>> grid_x
+            tensor([[1, 1, 1],
+                    [2, 2, 2],
+                    [3, 3, 3]])
+            >>> grid_y
+            tensor([[4, 5, 6],
+                    [4, 5, 6],
+                    [4, 5, 6]])
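+
+            >>> # An illustrative check of the ``indexing`` argument described above:
+            >>> # the 'xy' grids are the transposes of the 'ij' grids.
+            >>> xx, yy = torch.meshgrid(x, y, indexing='xy')
+            >>> torch.equal(xx, grid_x.T) and torch.equal(yy, grid_y.T)
+            True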
+
+            This correspondence can be seen when these grids are
+            stacked properly.
+            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
+            ...             torch.cartesian_prod(x, y))
+            True
+
+            `torch.meshgrid` is commonly used to produce a grid for
+            plotting.
+            >>> # xdoctest: +REQUIRES(module:matplotlib)
+            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
+            >>> import matplotlib.pyplot as plt
+            >>> xs = torch.linspace(-5, 5, steps=100)
+            >>> ys = torch.linspace(-5, 5, steps=100)
+            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
+            >>> z = torch.sin(torch.sqrt(x * x + y * y))
+            >>> ax = plt.axes(projection='3d')
+            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
+            >>> plt.show()
+
+        .. image:: ../_static/img/meshgrid.png
+            :width: 512
+
+        """
+        return _meshgrid(*tensors, indexing=indexing)
+
+
+def _meshgrid(*tensors, indexing: Optional[str]):
+    if has_torch_function(tensors):
+        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
+    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
+        # the old interface of passing the operands as one list argument
+        tensors = tensors[0]  # type: ignore[assignment]
+
+    # Continue allowing call of old method that takes no indexing
+    # kwarg for forward compatibility reasons.
+    #
+    # Remove this two weeks after landing.
+    kwargs = {} if indexing is None else {'indexing': indexing}
+    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
+
+
+def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
+         win_length: Optional[int] = None, window: Optional[Tensor] = None,
+         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
+         onesided: Optional[bool] = None,
+         return_complex: Optional[bool] = None) -> Tensor:
+    r"""Short-time Fourier transform (STFT).
+
+    .. warning::
+        From version 1.8.0, :attr:`return_complex` must always be given
+        explicitly for real inputs and `return_complex=False` has been
+        deprecated. Strongly prefer `return_complex=True` as in a future
+        pytorch release, this function will only return complex tensors.
+
+        Note that :func:`torch.view_as_real` can be used to recover a real
+        tensor with an extra last dimension for real and imaginary components.
+
+    .. warning::
+        From version 2.1, a warning will be provided if a :attr:`window` is
+        not specified. In a future release, this attribute will be required.
+        Not providing a window currently defaults to using a rectangular window,
+        which may result in undesirable artifacts. Consider using tapered windows,
+        such as :func:`torch.hann_window`.
+
+    The STFT computes the Fourier transform of short overlapping windows of the
+    input. This gives the frequency components of the signal as they change over
+    time. The interface of this function is modeled after (but *not* a drop-in
+    replacement for) the librosa_ stft function.
+
+    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
+
+    Ignoring the optional batch dimension, this method computes the following
+    expression:
+
+    .. math::
+        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
+                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
+                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),
+
+    where :math:`m` is the index of the sliding window, and :math:`\omega` is
+    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
+    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
+
+    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
+      sequences.
+
+    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
+      ``floor(n_fft / 4)``.
+
+    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
+      :attr:`n_fft`.
+
+    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
+      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
+      treated as if having :math:`1` everywhere in the window. If
+      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
+      both sides to length :attr:`n_fft` before being applied.
+
+    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
+      both sides so that the :math:`t`-th frame is centered at time
+      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
+      begins at time  :math:`t \times \text{hop\_length}`.
+
+    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
+      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
+      all available options. Default is ``"reflect"``.
+
+    * If :attr:`onesided` is ``True`` (default for real input), only values for
+      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
+      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
+      the real-to-complex Fourier transform satisfies the conjugate symmetry,
+      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
+      Note if the input or window tensors are complex, then :attr:`onesided`
+      output is not possible.
+
+    * If :attr:`normalized` is ``True`` (default is ``False``), the function
+      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
+
+    * If :attr:`return_complex` is ``True`` (default if input is complex), the
+      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
+      the output is a ``input.dim() + 2`` dimensional real tensor where the last
+      dimension represents the real and imaginary components.
+
+    Returns either a complex tensor of size :math:`(* \times N \times T)` if
+    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
+    \times T \times 2)`. Where :math:`*` is the optional batch size of
+    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
+    and :math:`T` is the total number of frames used.
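+
+    For example, a minimal illustrative sketch of these shapes (sizes chosen
+    arbitrarily): with ``center=True`` there are ``1 + L // hop_length`` frames
+    and, for one-sided output, ``n_fft // 2 + 1`` frequencies:
+
+    >>> x = torch.randn(1000)
+    >>> S = torch.stft(x, n_fft=400, hop_length=100,
+    ...                window=torch.hann_window(400), return_complex=True)
+    >>> S.shape
+    torch.Size([201, 11])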
+
+    .. warning::
+      This function changed signature at version 0.4.1. Calling with the
+      previous signature may cause error or return incorrect result.
+
+    Args:
+        input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
+            batch dimension
+        n_fft (int): size of Fourier transform
+        hop_length (int, optional): the distance between neighboring sliding window
+            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
+        win_length (int, optional): the size of window frame and STFT filter.
+            Default: ``None``  (treated as equal to :attr:`n_fft`)
+        window (Tensor, optional): the optional window function.
+            Shape must be 1d and `<= n_fft`
+            Default: ``None`` (treated as window of all :math:`1` s)
+        center (bool, optional): whether to pad :attr:`input` on both sides so
+            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+            Default: ``True``
+        pad_mode (str, optional): controls the padding method used when
+            :attr:`center` is ``True``. Default: ``"reflect"``
+        normalized (bool, optional): controls whether to return the normalized STFT results
+             Default: ``False``
+        onesided (bool, optional): controls whether to return half of results to
+            avoid redundancy for real inputs.
+            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
+        return_complex (bool, optional): whether to return a complex tensor, or
+            a real tensor with an extra last dimension for the real and
+            imaginary components.
+
+            .. versionchanged:: 2.0
+               ``return_complex`` is now a required argument for real inputs,
+               as the default is being transitioned to ``True``.
+
+            .. deprecated:: 2.0
+               ``return_complex=False`` is deprecated, instead use ``return_complex=True``
+               Note that calling :func:`torch.view_as_real` on the output will
+               recover the deprecated output format.
+
+    Returns:
+        Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
+           - `B?` is an optional batch dimension from the input.
+           - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
+             `onesided=True`, or otherwise `n_fft`.
+           - `T` is the number of frames, `1 + L // hop_length`
+             for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
+           - `C?` is an optional length-2 dimension of real and imaginary
+             components, present when `return_complex=False`.
+
+    """
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
+            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
+            onesided=onesided, return_complex=return_complex)
+    # NOTE: Do not edit. This code will be removed once the forward-compatibility
+    #       period is over for PR #73432
+    if center:
+        signal_dim = input.dim()
+        extended_shape = [1] * (3 - signal_dim) + list(input.size())
+        pad = int(n_fft // 2)
+        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
+        input = input.view(input.shape[-signal_dim:])
+    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
+                    normalized, onesided, return_complex)
+
+
+istft = _add_docstr(
+    torch.istft,
+    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
+    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
+    r"""
+Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.
+
+.. warning::
+    From version 2.1, a warning will be provided if a :attr:`window` is
+    not specified. In a future release, this attribute will be required.
+    Please provide the same window used in the stft call.
+
+It has the same parameters (plus an additional optional parameter, :attr:`length`) and should return the
+least squares estimate of the original signal. The algorithm will check for the NOLA condition
+(nonzero overlap-add).
+
+An important consideration is that the parameters :attr:`window` and :attr:`center` must be chosen so that the envelope
+created by the summation of all the windows is never zero at any point in time. Specifically,
+:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`.
+
+Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
+``istft`` may return a shorter signal than the original signal (this can occur if :attr:`center` is ``False``,
+since the signal isn't padded). If :attr:`length` is given in the arguments and is longer than expected,
+``istft`` will pad zeros to the end of the returned signal.
+
+If :attr:`center` is ``True``, then there will be padding, e.g. ``'constant'``, ``'reflect'``, etc.
+Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
+calculated without additional information.
+
+Example: Suppose the last window is:
+``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``
+
+The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same, which prevents the calculation
+of the right padding. These additional values could be zeros or a reflection of the signal, so providing
+:attr:`length` could be useful. If :attr:`length` is ``None``, then padding will be aggressively removed
+(some loss of signal).
+
+[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
+IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
+
+Args:
+    input (Tensor): The input tensor. Expected to be in the format of the :func:`~torch.stft`
+        output. That is, a complex tensor of shape `(B?, N, T)` where
+
+        - `B?` is an optional batch dimension
+        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
+          for onesided input, or otherwise `n_fft`.
+        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
+          or `1 + (length - n_fft) // hop_length` otherwise.
+
+        .. versionchanged:: 2.0
+            Real datatype inputs are no longer supported. Input must now have a
+            complex datatype, as returned by ``stft(..., return_complex=True)``.
+    n_fft (int): Size of Fourier transform
+    hop_length (Optional[int]): The distance between neighboring sliding window frames.
+        (Default: ``n_fft // 4``)
+    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
+    window (Optional[torch.Tensor]): The optional window function.
+        Shape must be 1d and `<= n_fft`
+        (Default: ``torch.ones(win_length)``)
+    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
+        centered at time :math:`t \times \text{hop\_length}`.
+        (Default: ``True``)
+    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
+    onesided (Optional[bool]): Whether the STFT was onesided.
+        (Default: ``True`` if `n_fft != fft_size` in the input size)
+    length (Optional[int]): The amount to trim the signal by (i.e. the
+        original signal length). Defaults to `(T - 1) * hop_length` for
+        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
+        is the number of input frames.
+    return_complex (Optional[bool]):
+        Whether the output should be complex, or if the input should be
+        assumed to derive from a real signal and window.
+        Note that this is incompatible with ``onesided=True``.
+        (Default: ``False``)
+
+Returns:
+    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
+        `B?` is an optional batch dimension from the input tensor.
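+
+Example (a minimal round-trip sketch; the window and sizes below are arbitrary choices)::
+
+    >>> signal = torch.randn(1000)
+    >>> window = torch.hann_window(400)
+    >>> spec = torch.stft(signal, n_fft=400, window=window, return_complex=True)
+    >>> torch.istft(spec, n_fft=400, window=window, length=signal.shape[-1]).shape
+    torch.Size([1000])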
+""")
+
+
+if TYPE_CHECKING:
+    # These _impl functions return a variable number of tensors as output with
+    # __torch_function__; tuple unpacking is done already rather than being
+    # done by the caller of the _impl function
+    _unique_impl_out = Any
+else:
+    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
+
+
+def _unique_impl(input: Tensor, sorted: bool = True,
+                 return_inverse: bool = False, return_counts: bool = False,
+                 dim: Optional[int] = None) -> _unique_impl_out:
+    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
+
+    Returns the unique elements of the input tensor.
+
+    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
+        this function also eliminates non-consecutive duplicate values.
+
+    .. note:: Currently, in both the CUDA and the CPU implementations,
+        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
+        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
+        :func:`torch.unique_consecutive`, which avoids the sorting.
+
+    Args:
+        input (Tensor): the input tensor
+        sorted (bool): Whether to sort the unique elements in ascending order
+            before returning as output.
+        return_inverse (bool): Whether to also return the indices for where
+            elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element.
+        dim (int, optional): the dimension to operate upon. If ``None``, the
+            unique of the flattened input is returned. Otherwise, each of the
+            tensors indexed by the given dimension is treated as one of the
+            elements to apply the unique operation upon. See examples for more
+            details. Default: ``None``
+
+    Returns:
+        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+            - **output** (*Tensor*): the output list of unique scalar elements.
+            - **inverse_indices** (*Tensor*): (optional) if
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
+              for where elements in the original input map to in the output;
+              otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
+
+    Example::
+
+        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
+        >>> output
+        tensor([1, 2, 3])
+
+        >>> output, inverse_indices = torch.unique(
+        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
+        >>> output
+        tensor([1, 2, 3])
+        >>> inverse_indices
+        tensor([0, 2, 1, 2])
+
+        >>> output, inverse_indices = torch.unique(
+        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
+        >>> output
+        tensor([1, 2, 3])
+        >>> inverse_indices
+        tensor([[0, 2],
+                [1, 2]])
+
+        >>> a = torch.tensor([
+        ...     [
+        ...         [1, 1, 0, 0],
+        ...         [1, 1, 0, 0],
+        ...         [0, 0, 1, 1],
+        ...     ],
+        ...     [
+        ...         [0, 0, 1, 1],
+        ...         [0, 0, 1, 1],
+        ...         [1, 1, 1, 1],
+        ...     ],
+        ...     [
+        ...         [1, 1, 0, 0],
+        ...         [1, 1, 0, 0],
+        ...         [0, 0, 1, 1],
+        ...     ],
+        ... ])
+
+        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
+        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
+        >>> # each other, so one of them will be removed.
+        >>> (a[0, :, :] == a[2, :, :]).all()
+        tensor(True)
+        >>> a_unique_dim0 = torch.unique(a, dim=0)
+        >>> a_unique_dim0
+        tensor([[[0, 0, 1, 1],
+                 [0, 0, 1, 1],
+                 [1, 1, 1, 1]],
+                [[1, 1, 0, 0],
+                 [1, 1, 0, 0],
+                 [0, 0, 1, 1]]])
+
+        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
+        >>> # `a_unique_dim0`:
+        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
+        tensor(True)
+        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
+        tensor(True)
+
+        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
+        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
+        >>> # them will be removed.
+        >>> (a[:, 0, :] == a[:, 1, :]).all()
+        tensor(True)
+        >>> torch.unique(a, dim=1)
+        tensor([[[0, 0, 1, 1],
+                 [1, 1, 0, 0]],
+                [[1, 1, 1, 1],
+                 [0, 0, 1, 1]],
+                [[0, 0, 1, 1],
+                 [1, 1, 0, 0]]])
+
+        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
+        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
+        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
+        >>> # sub-tensors will be removed.
+        >>> (a[:, :, 0] == a[:, :, 1]).all()
+        tensor(True)
+        >>> (a[:, :, 2] == a[:, :, 3]).all()
+        tensor(True)
+        >>> torch.unique(a, dim=2)
+        tensor([[[0, 1],
+                 [0, 1],
+                 [1, 0]],
+                [[1, 0],
+                 [1, 0],
+                 [1, 1]],
+                [[0, 1],
+                 [0, 1],
+                 [1, 0]]])
+    """
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
+            return_counts=return_counts, dim=dim)
+
+    if dim is not None:
+        output, inverse_indices, counts = _VF.unique_dim(
+            input,
+            dim,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+        )
+    else:
+        output, inverse_indices, counts = torch._unique2(
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+        )
+    return output, inverse_indices, counts
+
+
+def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
+                             return_counts: bool = False,
+                             dim: Optional[int] = None) -> _unique_impl_out:
+    r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+    .. note:: This function is different from :func:`torch.unique` in the sense that this function
+        only eliminates consecutive duplicate values. Its semantics are similar to those of `std::unique`
+        in C++.
+
+    Args:
+        input (Tensor): the input tensor
+        return_inverse (bool): Whether to also return the indices for where
+            elements in the original input ended up in the returned unique list.
+        return_counts (bool): Whether to also return the counts for each unique
+            element.
+        dim (int, optional): the dimension to apply unique over. If ``None``, the unique of the
+            flattened input is returned. Default: ``None``
+
+    Returns:
+        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+            - **output** (*Tensor*): the output list of unique scalar elements.
+            - **inverse_indices** (*Tensor*): (optional) if
+              :attr:`return_inverse` is True, there will be an additional
+              returned tensor (same shape as input) representing the indices
+              for where elements in the original input map to in the output;
+              otherwise, this function will only return a single tensor.
+            - **counts** (*Tensor*): (optional) if
+              :attr:`return_counts` is True, there will be an additional
+              returned tensor (same shape as output or output.size(dim),
+              if dim was specified) representing the number of occurrences
+              for each unique value or tensor.
+
+    Example::
+
+        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
+        >>> output = torch.unique_consecutive(x)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+
+        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> inverse_indices
+        tensor([0, 0, 1, 1, 2, 3, 3, 4])
+
+        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
+        >>> output
+        tensor([1, 2, 3, 1, 2])
+        >>> counts
+        tensor([2, 2, 1, 2, 1])
+    """
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            unique_consecutive, (input,), input, return_inverse=return_inverse,
+            return_counts=return_counts, dim=dim)
+    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+    return output, inverse_indices, counts
+
+
+def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
+
+    if has_torch_function_unary(input):
+        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
+
+    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    return output, counts
+
+
+def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
+
+    if has_torch_function_unary(input):
+        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
+
+    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    return output
+
+
+def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
+
+    if has_torch_function_unary(input):
+        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
+
+    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    return output, inverse_indices
+
+
+_return_inverse_false = boolean_dispatch(
+    arg_name='return_counts',
+    arg_index=3,
+    default=False,
+    if_true=_return_counts,
+    if_false=_return_output,
+    module_name=__name__,
+    func_name='unique')
+
+_return_inverse_true = boolean_dispatch(
+    arg_name='return_counts',
+    arg_index=3,
+    default=False,
+    if_true=_unique_impl,
+    if_false=_return_inverse,
+    module_name=__name__,
+    func_name='unique')
+
+# The return type of unique depends on `return_inverse` and `return_counts`, so in order to
+# resolve the output type in TorchScript we need to statically know the value of both parameters.
+
+unique = boolean_dispatch(
+    arg_name='return_inverse',
+    arg_index=2,
+    default=False,
+    if_true=_return_inverse_true,
+    if_false=_return_inverse_false,
+    module_name=__name__,
+    func_name='unique')
+unique.__doc__ = _unique_impl.__doc__
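+
+# For reference, the statically-resolved dispatch above means:
+#   unique(t)                                            -> _return_output   (Tensor)
+#   unique(t, return_inverse=True)                       -> _return_inverse  (Tensor, Tensor)
+#   unique(t, return_counts=True)                        -> _return_counts   (Tensor, Tensor)
+#   unique(t, return_inverse=True, return_counts=True)   -> _unique_impl     (Tensor, Tensor, Tensor)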
+
+
+def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
+
+    if has_torch_function_unary(input):
+        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+
+    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    return output, counts
+
+
+def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, Optional[int]) -> Tensor
+
+    if has_torch_function_unary(input):
+        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+
+    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    return output
+
+
+def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
+    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
+
+    if has_torch_function_unary(input):
+        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+
+    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    return output, inverse_indices
+
+
+_consecutive_return_inverse_false = boolean_dispatch(
+    arg_name='return_counts',
+    arg_index=1,
+    default=False,
+    if_true=_consecutive_return_counts,
+    if_false=_consecutive_return_output,
+    module_name=__name__,
+    func_name='unique_consecutive')
+
+_consecutive_return_inverse_true = boolean_dispatch(
+    arg_name='return_counts',
+    arg_index=1,
+    default=False,
+    if_true=_unique_consecutive_impl,
+    if_false=_consecutive_return_inverse,
+    module_name=__name__,
+    func_name='unique_consecutive')
+
+# The return type of unique_consecutive depends on `return_inverse` and `return_counts`, so in order to
+# resolve the output type in TorchScript we need to statically know the value of both parameters.
+
+unique_consecutive = boolean_dispatch(
+    arg_name='return_inverse',
+    arg_index=2,
+    default=False,
+    if_true=_consecutive_return_inverse_true,
+    if_false=_consecutive_return_inverse_false,
+    module_name=__name__,
+    func_name='unique_consecutive')
+unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
+
+if TYPE_CHECKING:
+    pass
+    # There's no good way to use this type annotation without breaking JIT
+    # overloads. So leave untyped for mypy for now.
+else:
+    @overload
+    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
+        pass
+
+    @overload  # noqa: F811
+    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+        pass
+
+    @overload  # noqa: F811
+    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+        pass
+
+    @overload  # noqa: F811
+    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
+        pass
+
+
+def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
+    r"""Returns a contraction of a and b over multiple dimensions.
+
+    :attr:`tensordot` implements a generalized matrix product.
+
+    Args:
+      a (Tensor): Left tensor to contract
+      b (Tensor): Right tensor to contract
+      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
+         contract or explicit lists of dimensions for :attr:`a` and
+         :attr:`b` respectively
+
+    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
+    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
+    respectively, :func:`~torch.tensordot` computes
+
+    .. math::
+        r_{i_0,...,i_{m-d}, i_d,...,i_n}
+          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
+
+    When called with :attr:`dims` of the list form, the given dimensions will be contracted
+    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
+    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
+    dimensions.
+
+    Examples::
+
+        >>> a = torch.arange(60.).reshape(3, 4, 5)
+        >>> b = torch.arange(24.).reshape(4, 3, 2)
+        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
+        tensor([[4400., 4730.],
+                [4532., 4874.],
+                [4664., 5018.],
+                [4796., 5162.],
+                [4928., 5306.]])
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> a = torch.randn(3, 4, 5, device='cuda')
+        >>> b = torch.randn(4, 5, 6, device='cuda')
+        >>> torch.tensordot(a, b, dims=2).cpu()
+        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
+                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
+                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])
+
+        >>> a = torch.randn(3, 5, 4, 6)
+        >>> b = torch.randn(6, 4, 5, 3)
+        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
+        tensor([[  7.7193,  -2.4867, -10.3204],
+                [  1.5513, -14.4737,  -6.5113],
+                [ -0.2850,   4.2573,  -3.5997]])
+    """
+    if has_torch_function_variadic(a, b):
+        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
+
+    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
+        raise RuntimeError("tensordot expects dims to be int or "
+                           + "Tuple[List[int], List[int]] or "
+                           + "List[List[int]] containing two lists, but got "
+                           + f"dims={dims}")
+
+    dims_a: List[int] = []
+    dims_b: List[int] = []
+
+    if isinstance(dims, (tuple, list)):
+        dims_a, dims_b = dims
+
+    if isinstance(dims, torch.Tensor):
+        num_elements = dims.numel()
+        if num_elements > 1:
+            assert dims.size()[0] == 2
+            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
+            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
+        else:
+            dims_val = int(dims.item())
+            if dims_val < 0:
+                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
+            dims_a = list(range(-dims_val, 0))
+            dims_b = list(range(dims_val))
+
+    if isinstance(dims, (int, torch.SymInt)):
+        if dims < 0:
+            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
+        if dims > min(a.dim(), b.dim()):
+            raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
+        dims_a = list(range(-dims, 0))
+        dims_b = list(range(dims))
+
+    if out is None:
+        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
+    else:
+        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
+
+
+def cartesian_prod(*tensors: Tensor) -> Tensor:
+    """Do cartesian product of the given sequence of tensors. The behavior is similar to
+    python's `itertools.product`.
+
+    Args:
+        *tensors: any number of 1 dimensional tensors.
+
+    Returns:
+        Tensor: A tensor equivalent to converting all the input tensors into lists,
+        doing `itertools.product` on these lists, and finally converting the resulting list
+        into a tensor.
+
+    Example::
+
+        >>> import itertools
+        >>> a = [1, 2, 3]
+        >>> b = [4, 5]
+        >>> list(itertools.product(a, b))
+        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
+        >>> tensor_a = torch.tensor(a)
+        >>> tensor_b = torch.tensor(b)
+        >>> torch.cartesian_prod(tensor_a, tensor_b)
+        tensor([[1, 4],
+                [1, 5],
+                [2, 4],
+                [2, 5],
+                [3, 4],
+                [3, 5]])
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(cartesian_prod, tensors, *tensors)
+    return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
+
+
+def block_diag(*tensors):
+    """Create a block diagonal matrix from provided tensors.
+
+    Args:
+        *tensors: One or more tensors with 0, 1, or 2 dimensions.
+
+    Returns:
+        Tensor: A 2 dimensional tensor with all the input tensors arranged in
+        order such that their upper left and lower right corners are
+        diagonally adjacent. All other elements are set to 0.
+
+    Example::
+
+        >>> import torch
+        >>> A = torch.tensor([[0, 1], [1, 0]])
+        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
+        >>> C = torch.tensor(7)
+        >>> D = torch.tensor([1, 2, 3])
+        >>> E = torch.tensor([[4], [5], [6]])
+        >>> torch.block_diag(A, B, C, D, E)
+        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
+                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
+                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
+                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
+                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(block_diag, tensors, *tensors)
+    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
+
+
+def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
+    # type: (Tensor, Tensor, float, str) -> (Tensor)
+    r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
+
+    Args:
+        x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
+        x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
+        p: p value for the p-norm distance to calculate between each vector pair
+            :math:`\in [0, \infty]`.
+        compute_mode:
+            'use_mm_for_euclid_dist_if_necessary' - will use the matrix multiplication approach to calculate
+            the Euclidean distance (p = 2) if P > 25 or R > 25
+            'use_mm_for_euclid_dist' - will always use the matrix multiplication approach to calculate
+            the Euclidean distance (p = 2)
+            'donot_use_mm_for_euclid_dist' - will never use the matrix multiplication approach to calculate
+            the Euclidean distance (p = 2)
+            Default: 'use_mm_for_euclid_dist_if_necessary'.
+
+    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
+    output will have shape :math:`B \times P \times R`.
+
+    This function is equivalent to `scipy.spatial.distance.cdist(input, 'minkowski', p=p)`
+    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
+    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
+    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
+
+    Example:
+
+        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])
+        >>> a
+        tensor([[ 0.9041,  0.0196],
+                [-0.3108, -2.4423],
+                [-0.4821,  1.0590]])
+        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])
+        >>> b
+        tensor([[-2.1763, -0.4713],
+                [-0.6986,  1.3702]])
+        >>> torch.cdist(a, b, p=2)
+        tensor([[3.1193, 2.0959],
+                [2.7138, 3.8322],
+                [2.2830, 0.3791]])
+    """
+    if has_torch_function_variadic(x1, x2):
+        return handle_torch_function(
+            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
+    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+        return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
+    elif compute_mode == 'use_mm_for_euclid_dist':
+        return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
+    elif compute_mode == 'donot_use_mm_for_euclid_dist':
+        return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
+    else:
+        raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
+
+
+def atleast_1d(*tensors):
+    r"""
+    Returns a 1-dimensional view of each input tensor with zero dimensions.
+    Input tensors with one or more dimensions are returned as-is.
+
+    Args:
+        input (Tensor or list of Tensors)
+
+    Returns:
+        output (Tensor or tuple of Tensors)
+
+    Example::
+
+        >>> x = torch.arange(2)
+        >>> x
+        tensor([0, 1])
+        >>> torch.atleast_1d(x)
+        tensor([0, 1])
+        >>> x = torch.tensor(1.)
+        >>> x
+        tensor(1.)
+        >>> torch.atleast_1d(x)
+        tensor([1.])
+        >>> x = torch.tensor(0.5)
+        >>> y = torch.tensor(1.)
+        >>> torch.atleast_1d((x, y))
+        (tensor([0.5000]), tensor([1.]))
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(atleast_1d, tensors, *tensors)
+    if len(tensors) == 1:
+        tensors = tensors[0]
+    return _VF.atleast_1d(tensors)  # type: ignore[attr-defined]
+
+
+def atleast_2d(*tensors):
+    r"""
+    Returns a 2-dimensional view of each input tensor with zero dimensions.
+    Input tensors with two or more dimensions are returned as-is.
+
+    Args:
+        input (Tensor or list of Tensors)
+
+    Returns:
+        output (Tensor or tuple of Tensors)
+
+    Example::
+
+        >>> x = torch.tensor(1.)
+        >>> x
+        tensor(1.)
+        >>> torch.atleast_2d(x)
+        tensor([[1.]])
+        >>> x = torch.arange(4).view(2, 2)
+        >>> x
+        tensor([[0, 1],
+                [2, 3]])
+        >>> torch.atleast_2d(x)
+        tensor([[0, 1],
+                [2, 3]])
+        >>> x = torch.tensor(0.5)
+        >>> y = torch.tensor(1.)
+        >>> torch.atleast_2d((x, y))
+        (tensor([[0.5000]]), tensor([[1.]]))
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(atleast_2d, tensors, *tensors)
+    if len(tensors) == 1:
+        tensors = tensors[0]
+    return _VF.atleast_2d(tensors)  # type: ignore[attr-defined]
+
+
+def atleast_3d(*tensors):
+    r"""
+    Returns a 3-dimensional view of each input tensor with zero dimensions.
+    Input tensors with three or more dimensions are returned as-is.
+
+    Args:
+        input (Tensor or list of Tensors)
+
+    Returns:
+        output (Tensor or tuple of Tensors)
+
+    Example:
+
+        >>> x = torch.tensor(0.5)
+        >>> x
+        tensor(0.5000)
+        >>> torch.atleast_3d(x)
+        tensor([[[0.5000]]])
+        >>> y = torch.arange(4).view(2, 2)
+        >>> y
+        tensor([[0, 1],
+                [2, 3]])
+        >>> torch.atleast_3d(y)
+        tensor([[[0],
+                 [1]],
+                
+                [[2],
+                 [3]]])
+        >>> x = torch.tensor(1).view(1, 1, 1)
+        >>> x
+        tensor([[[1]]])
+        >>> torch.atleast_3d(x)
+        tensor([[[1]]])
+        >>> x = torch.tensor(0.5)
+        >>> y = torch.tensor(1.)
+        >>> torch.atleast_3d((x, y))
+        (tensor([[[0.5000]]]), tensor([[[1.]]]))
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(tensors):
+        return handle_torch_function(atleast_3d, tensors, *tensors)
+    if len(tensors) == 1:
+        tensors = tensors[0]
+    return _VF.atleast_3d(tensors)  # type: ignore[attr-defined]
+
+
+if TYPE_CHECKING:
+    pass
+    # There's no good way to use this type annotation; cannot rename norm() to
+    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
+    # for mypy for now.
+    #    def norm(input: Tensor,
+    #             p: Optional[Union[str, Number]] = "fro",
+    #             dim: Optional[Union[int, List[int]]] = None,
+    #             keepdim: bool = False,
+    #             out: Optional[Tensor] = None,
+    #             dtype: _dtype = None) -> Tensor:
+    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
+else:
+    # TODO: type dim as BroadcastingList when
+    # https://github.com/pytorch/pytorch/issues/33782 is fixed
+    @overload
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
+    @overload  # noqa: F811
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
+    @overload  # noqa: F811
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
+    @overload  # noqa: F811
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
+
+def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    r"""Returns the matrix norm or vector norm of a given tensor.
+
+    .. warning::
+
+        torch.norm is deprecated and may be removed in a future PyTorch release.
+        Its documentation and behavior may be incorrect, and it is no longer
+        actively maintained.
+
+        Use :func:`torch.linalg.vector_norm` when computing vector norms and
+        :func:`torch.linalg.matrix_norm` when computing matrix norms.
+        For a function with behavior similar to this one's, see :func:`torch.linalg.norm`.
+        Note, however, that the signature for these functions is slightly different from the
+        signature of ``torch.norm``.
+
+    Args:
+        input (Tensor): The input tensor. Its data type must be either a floating
+            point or complex type. For complex inputs, the norm is calculated using the
+            absolute value of each element. If the input is complex and neither
+            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
+            be the corresponding floating point type (e.g. float if :attr:`input` is
+            complexfloat).
+
+        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
+            The following norms can be calculated:
+
+            ======  ==============  ==========================
+            ord     matrix norm     vector norm
+            ======  ==============  ==========================
+            'fro'   Frobenius norm  --
+            'nuc'   nuclear norm    --
+            Number  --              sum(abs(x)**ord)**(1./ord)
+            ======  ==============  ==========================
+
+            The vector norm can be calculated across any number of dimensions.
+            The corresponding dimensions of :attr:`input` are flattened into
+            one dimension, and the norm is calculated on the flattened
+            dimension.
+
+            Frobenius norm produces the same result as ``p=2`` in all cases
+            except when :attr:`dim` is a list of three or more dims, in which
+            case Frobenius norm throws an error.
+
+            Nuclear norm can only be calculated across exactly two dimensions.
+
+        dim (int, tuple of ints, list of ints, optional):
+            Specifies which dimension or dimensions of :attr:`input` to
+            calculate the norm across. If :attr:`dim` is ``None``, the norm will
+            be calculated across all dimensions of :attr:`input`. If the norm
+            type indicated by :attr:`p` does not support the specified number of
+            dimensions, an error will occur.
+        keepdim (bool, optional): whether the output tensors have :attr:`dim`
+            retained or not. Ignored if :attr:`dim` = ``None`` and
+            :attr:`out` = ``None``. Default: ``False``
+        out (Tensor, optional): the output tensor. Ignored if
+            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
+        dtype (:class:`torch.dtype`, optional): the desired data type of
+            returned tensor. If specified, the input tensor is cast to
+            :attr:`dtype` while performing the operation. Default: None.
+
+    .. note::
+        Even though ``p='fro'`` supports any number of dimensions, the true
+        mathematical definition of Frobenius norm only applies to tensors with
+        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
+        aligns with the mathematical definition, since it can only be applied across
+        exactly two dimensions.
+
+    Example::
+
+        >>> import torch
+        >>> a = torch.arange(9, dtype= torch.float) - 4
+        >>> b = a.reshape((3, 3))
+        >>> torch.norm(a)
+        tensor(7.7460)
+        >>> torch.norm(b)
+        tensor(7.7460)
+        >>> torch.norm(a, float('inf'))
+        tensor(4.)
+        >>> torch.norm(b, float('inf'))
+        tensor(4.)
+        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float)
+        >>> torch.norm(c, dim=0)
+        tensor([1.4142, 2.2361, 5.0000])
+        >>> torch.norm(c, dim=1)
+        tensor([3.7417, 4.2426])
+        >>> torch.norm(c, p=1, dim=1)
+        tensor([6., 6.])
+        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
+        >>> torch.norm(d, dim=(1, 2))
+        tensor([ 3.7417, 11.2250])
+        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+        (tensor(3.7417), tensor(11.2250))
+    """
+
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
+
+    # NB. All the repeated code and weird python is to please TorchScript.
+    #     For a more compact implementation see the relevant function in `_refs/__init__.py`
+
+    # We don't do this for MPS or sparse tensors
+    if input.layout == torch.strided and input.device.type in \
+            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
+        if dim is not None:
+            if isinstance(dim, (int, torch.SymInt)):
+                _dim = [dim]
+            else:
+                _dim = dim
+        else:
+            _dim = None  # type: ignore[assignment]
+
+        if isinstance(p, str):
+            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
+                if out is None:
+                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
+                else:
+                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+
+            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
+            # that will throw an error
+            if _dim is None:
+                _dim = list(range(input.ndim))
+            if out is None:
+                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
+            else:
+                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+        else:
+            # NB. p should be Union[str, number], not Optional!
+            _p = 2.0 if p is None else p
+            if out is None:
+                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
+            else:
+                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+
+    ndim = input.dim()
+
+    # catch default case
+    if dim is None and out is None and dtype is None and p is not None:
+        if isinstance(p, str):
+            if p == "fro":
+                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
+        if not isinstance(p, str):
+            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
+
+    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
+    # remove the overloads where dim is an int and replace with BroadcastingList1
+    # and remove next four lines, replace _dim with dim
+    if dim is not None:
+        if isinstance(dim, (int, torch.SymInt)):
+            _dim = [dim]
+        else:
+            _dim = dim
+    else:
+        _dim = None  # type: ignore[assignment]
+
+    if isinstance(p, str):
+        if p == "fro":
+            if dtype is not None:
+                raise ValueError("dtype argument is not supported in frobenius norm")
+
+            if _dim is None:
+                _dim = list(range(ndim))
+            if out is None:
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
+            else:
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
+        elif p == "nuc":
+            if dtype is not None:
+                raise ValueError("dtype argument is not supported in nuclear norm")
+            if _dim is None:
+                if out is None:
+                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
+                else:
+                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
+            else:
+                if out is None:
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
+                else:
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
+        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
+    else:
+        if _dim is None:
+            _dim = list(range(ndim))
+
+        if out is None:
+            if dtype is None:
+                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
+            else:
+                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
+        else:
+            if dtype is None:
+                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
+            else:
+                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
+
+
+
+def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
+    r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
+    index into an arbitrary tensor of the specified shape.
+
+    Args:
+        indices (Tensor): An integer tensor containing indices into the
+            flattened version of an arbitrary tensor of shape :attr:`shape`.
+            All elements must be in the range ``[0, prod(shape) - 1]``.
+
+        shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
+            tensor. All elements must be non-negative.
+
+    Returns:
+        tuple of Tensors: Each ``i``-th tensor in the output corresponds with
+        dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
+        ``indices`` and contains one index into dimension ``i`` for each of the
+        flat indices given by ``indices``.
+
+    Example::
+
+        >>> import torch
+        >>> torch.unravel_index(torch.tensor(4), (3, 2))
+        (tensor(2),
+         tensor(0))
+
+        >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
+        (tensor([2, 0]),
+         tensor([0, 1]))
+
+        >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
+        (tensor([0, 0, 1, 1, 2, 2]),
+         tensor([0, 1, 0, 1, 0, 1]))
+
+        >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
+        (tensor([1, 5]),
+         tensor([2, 6]),
+         tensor([3, 7]),
+         tensor([4, 8]))
+
+        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
+        (tensor([[1], [5]]),
+         tensor([[2], [6]]),
+         tensor([[3], [7]]),
+         tensor([[4], [8]]))
+
+        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
+        (tensor([[12], [56]]),
+         tensor([[34], [78]]))
+    """
+    if has_torch_function_unary(indices):
+        return handle_torch_function(
+            unravel_index, (indices,), indices, shape=shape)
+    res_tensor = _unravel_index(indices, shape)
+    return res_tensor.unbind(-1)
+
+def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
+    torch._check_type(
+        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
+        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
+
+    torch._check_type(
+        isinstance(shape, (int, torch.SymInt, Sequence)),
+        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
+
+    if isinstance(shape, (int, torch.SymInt)):
+        shape = torch.Size([shape])
+    else:
+        for dim in shape:
+            torch._check_type(
+                isinstance(dim, (int, torch.SymInt)),
+                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
+        shape = torch.Size(shape)
+
+    torch._check_value(
+        all(dim >= 0 for dim in shape),
+        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
+
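+    # Row-major "strides" of `shape`: coefs[i] == prod(shape[i+1:]), built as a
+    # reversed cumulative product so that (index // coefs[i]) % shape[i]
+    # recovers the coordinate along dimension i.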
+    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
+    return indices.unsqueeze(-1).floor_divide(
+        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
+    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
+
+def chain_matmul(*matrices, out=None):
+    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
+    using the matrix chain order algorithm which selects the order that incurs the lowest cost in terms
+    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
+    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
+    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
+
+    .. warning::
+
+        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
+        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
+        rather than multiple arguments.
+
+    Args:
+        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
+        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
+
+    Returns:
+        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
+        would be of dimensions :math:`p_{1} \times p_{N + 1}`.
+
+    Example::
+
+        >>> # xdoctest: +SKIP
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> a = torch.randn(3, 4)
+        >>> b = torch.randn(4, 5)
+        >>> c = torch.randn(5, 6)
+        >>> d = torch.randn(6, 7)
+        >>> # will raise a deprecation warning
+        >>> torch.chain_matmul(a, b, c, d)
+        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
+                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
+                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])
+
+    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
+    """
+    # This wrapper exists to support variadic args.
+    if has_torch_function(matrices):
+        return handle_torch_function(chain_matmul, matrices, *matrices)
+
+    if out is None:
+        return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
+    else:
+        return _VF.chain_matmul(matrices, out=out)  # type: ignore[attr-defined]
+
+
+def _lu_impl(A, pivot=True, get_infos=False, out=None):
+    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
+    r"""Computes the LU factorization of a matrix or batches of matrices
+    :attr:`A`. Returns a tuple containing the LU factorization and
+    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to
+    ``True``.
+
+    .. warning::
+
+        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
+        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
+        future PyTorch release.
+        ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
+
+        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
+
+    .. note::
+        * The returned permutation matrix for every matrix in the batch is
+          represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
+          ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
+          the ``i``-th row was permuted with the ``j-1``-th row.
+        * LU factorization with :attr:`pivot` = ``False`` is not available
+          for CPU, and attempting to do so will throw an error. However,
+          LU factorization with :attr:`pivot` = ``False`` is available for
+          CUDA.
+        * This function does not check if the factorization was successful
+          or not if :attr:`get_infos` is ``True`` since the status of the
+          factorization is present in the third element of the return tuple.
+        * In the case of batches of square matrices with size less or equal
+          to 32 on a CUDA device, the LU factorization is repeated for
+          singular matrices due to the bug in the MAGMA library
+          (see magma issue 13).
+        * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
+
+    .. warning::
+        The gradients of this function will only be finite when :attr:`A` is full rank.
+        This is because the LU decomposition is only differentiable at full rank matrices.
+        Furthermore, if :attr:`A` is close to not being full rank,
+        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
+
+    Args:
+        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
+        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
+        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
+                                    Default: ``False``
+        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
+                               then the elements in the tuple are Tensor, IntTensor,
+                               and IntTensor. If :attr:`get_infos` is ``False``, then the
+                               elements in the tuple are Tensor, IntTensor. Default: ``None``
+
+    Returns:
+        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
+
+            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
+
+            - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
+              ``pivots`` stores all the intermediate transpositions of rows.
+              The final permutation ``perm`` could be reconstructed by
+              applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
+              where ``perm`` is initially the identity permutation of :math:`m` elements
+              (essentially this is what :func:`torch.lu_unpack` is doing).
+
+            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
+              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
+              each minibatch has succeeded or failed
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> A = torch.randn(2, 3, 3)
+        >>> A_LU, pivots = torch.lu(A)
+        >>> A_LU
+        tensor([[[ 1.3506,  2.5558, -0.0816],
+                 [ 0.1684,  1.1551,  0.1940],
+                 [ 0.1193,  0.6189, -0.5497]],
+
+                [[ 0.4526,  1.2526, -0.3285],
+                 [-0.7988,  0.7175, -0.9701],
+                 [ 0.2634, -0.9255, -0.3459]]])
+        >>> pivots
+        tensor([[ 3,  3,  3],
+                [ 3,  3,  3]], dtype=torch.int32)
+        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
+        >>> if info.nonzero().size(0) == 0:
+        ...     print('LU factorization succeeded for all samples!')
+        LU factorization succeeded for all samples!
+    """
+    # If get_infos is True, then we don't need to check for errors and vice versa
+    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
+
+if TYPE_CHECKING:
+    _ListOrSeq = Sequence[Tensor]
+else:
+    _ListOrSeq = List[Tensor]
+
+
+def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
+    get_infos_int = 1 if get_infos else 0
+    if out_len - get_infos_int != 2:
+        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
+    if not isinstance(out, (tuple, list)):
+        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
+
+
+def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
+    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
+    if has_torch_function_unary(A):
+        return handle_torch_function(
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+    result = _lu_impl(A, pivot, get_infos, out)
+    if out is not None:
+        _check_list_size(len(out), get_infos, out)
+        for i in range(len(out)):
+            out[i].resize_as_(result[i]).copy_(result[i])
+        return out
+    else:
+        return result  # A_LU, pivots, infos
+
+
+def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
+    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    # need to check for torch_function here so that we exit early if a
+    # __torch_function__ override handles the call
+    if has_torch_function_unary(A):
+        return handle_torch_function(
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+    result = _lu_impl(A, pivot, get_infos, out)
+    if out is not None:
+        _check_list_size(len(out), get_infos, out)
+        for i in range(len(out)):
+            out[i].resize_as_(result[i]).copy_(result[i])
+        return out
+    else:
+        return result[0], result[1]  # A_LU, pivots
+
+# The return type of lu depends on `get_infos`, so in order to resolve the output type
+# of lu in TorchScript we need to statically know the value of `get_infos`
+lu = boolean_dispatch(
+    arg_name='get_infos',
+    arg_index=2,
+    default=False,
+    if_true=_lu_with_infos,
+    if_false=_lu_no_infos,
+    module_name=__name__,
+    func_name='lu')
+lu.__doc__ = _lu_impl.__doc__
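+
+# For reference: lu(A) resolves to _lu_no_infos and returns (LU, pivots), while
+# lu(A, get_infos=True) resolves to _lu_with_infos and returns (LU, pivots, infos).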
+
+
+def align_tensors(*tensors):
+    raise RuntimeError('`align_tensors` not yet implemented.')
diff --git a/MLPY/Lib/site-packages/torch/futures/__init__.py b/MLPY/Lib/site-packages/torch/futures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a98a32e3da8cbebfa6489c1da74b559b4daab126
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/futures/__init__.py
@@ -0,0 +1,318 @@
+from __future__ import annotations
+
+from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union
+
+import torch
+
+__all__ = ['Future', 'collect_all', 'wait_all']
+
+T = TypeVar("T")
+S = TypeVar("S")
+
+
+class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
+    pass
+
+
+class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
+    r"""
+    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
+    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
+    also exposes a set of APIs to add callback functions and set results.
+
+    .. warning:: GPU support is a beta feature, subject to changes.
+    """
+
+    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
+        r"""
+        Create an empty unset ``Future``. If the future is intended to hold
+        values containing CUDA tensors, (a superset of) their CUDA devices must
+        be specified at construction. (This is only supported if
+        ``torch.cuda.is_available()`` returns ``True``). This is needed to
+        ensure proper CUDA stream synchronization. The child futures, returned
+        by the ``then`` method, will inherit these devices.
+
+        Args:
+            devices(``List[Union[int, str, torch.device]]``, optional): the set
+                of devices on which tensors contained in this future's value are
+                allowed to reside and on which callbacks are allowed to operate.
+        """
+        if devices is None:
+            devices = []
+        super().__init__([torch.device(d) for d in devices])
+
+    def done(self) -> bool:
+        r"""
+        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
+        has a result or an exception.
+
+        If the value contains tensors that reside on GPUs, ``Future.done()``
+        will return ``True`` even if the asynchronous kernels that are
+        populating those tensors haven't yet completed running on the device,
+        because at that stage the result is already usable, provided one
+        performs the appropriate synchronizations (see :meth:`wait`).
+        """
+        return super().done()
+
+    def wait(self) -> T:
+        r"""
+        Block until the value of this ``Future`` is ready.
+
+        If the value contains tensors that reside on GPUs, then an additional
+        synchronization is performed with the kernels (executing on the device)
+        which may be asynchronously populating those tensors. Such sync is
+        non-blocking, which means that ``wait()`` will insert the necessary
+        instructions in the current streams to ensure that further operations
+        enqueued on those streams will be properly scheduled after the async
+        kernels but, once that is done, ``wait()`` will return, even if those
+        kernels are still running. No further synchronization is required when
+        accessing and using the values, as long as one doesn't change streams.
+
+        Returns:
+            The value held by this ``Future``. If the function (callback or RPC)
+            creating the value has thrown an error, this ``wait`` method will
+            also throw an error.
+        """
+        return super().wait()
+
+    def value(self) -> T:
+        r"""
+        Obtain the value of an already-completed future.
+
+        This method should only be called after a call to :meth:`wait` has
+        completed, or inside a callback function passed to :meth:`then`. In
+        other cases this ``Future`` may not yet hold a value and calling
+        ``value()`` could fail.
+
+        If the value contains tensors that reside on GPUs, then this method will
+        *not* perform any additional synchronization. This should be done
+        beforehand, separately, through a call to :meth:`wait` (except within
+        callbacks, for which it's already being taken care of by :meth:`then`).
+
+        Returns:
+            The value held by this ``Future``. If the function (callback or RPC)
+            creating the value has thrown an error, this ``value()`` method will
+            also throw an error.
+        """
+        return super().value()
+
+    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
+        r"""
+        Append the given callback function to this ``Future``, which will be run
+        when the ``Future`` is completed.  Multiple callbacks can be added to
+        the same ``Future``, but the order in which they will be executed cannot
+        be guaranteed (to enforce a certain order consider chaining:
+        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
+        is the reference to this ``Future``. The callback function can use the
+        :meth:`value` method to get the value. Note that if this ``Future`` is
+        already completed, the given callback will be run immediately inline.
+
+        If the ``Future``'s value contains tensors that reside on GPUs, the
+        callback might be invoked while the async kernels that are populating
+        those tensors haven't yet finished executing on the device. However, the
+        callback will be invoked with some dedicated streams set as current
+        (fetched from a global pool) which will be synchronized with those
+        kernels. Hence any operation performed by the callback on these tensors
+        will be scheduled on the device after the kernels complete. In other
+        words, as long as the callback doesn't switch streams, it can safely
+        manipulate the result without any additional synchronization. This is
+        similar to the non-blocking behavior of :meth:`wait`.
+
+        Similarly, if the callback returns a value that contains tensors that
+        reside on a GPU, it can do so even if the kernels that are producing
+        these tensors are still running on the device, as long as the callback
+        didn't change streams during its execution. If one wants to change
+        streams, one must be careful to re-synchronize them with the original
+        streams, that is, those that were current when the callback was invoked.
+
+        Args:
+            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
+                                    the only argument.
+
+        Returns:
+            A new ``Future`` object that holds the return value of the
+            ``callback`` and will be marked as completed when the given
+            ``callback`` finishes.
+
+        .. note:: Note that if the callback function throws, either
+            through the original future being completed with an exception and
+            calling ``fut.wait()``, or through other code in the callback, the
+            future returned by ``then`` will be marked appropriately with the
+            encountered error. However, if this callback later completes
+            additional futures, those futures are not marked as completed with
+            an error and the user is responsible for handling completion/waiting
+            on those futures independently.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+            >>> def callback(fut):
+            ...     print(f"RPC return value is {fut.wait()}.")
+            >>> fut = torch.futures.Future()
+            >>> # The inserted callback will print the return value when
+            >>> # receiving the response from "worker1"
+            >>> cb_fut = fut.then(callback)
+            >>> chain_cb_fut = cb_fut.then(
+            ...     lambda x : print(f"Chained cb done. {x.wait()}")
+            ... )
+            >>> fut.set_result(5)
+            RPC return value is 5.
+            Chained cb done. None
+        """
+        return cast(Future[S], super().then(callback))
+
+    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
+        r"""
+        Append the given callback function to this ``Future``, which will be run
+        when the ``Future`` is completed.  Multiple callbacks can be added to
+        the same ``Future``, but the order in which they will be executed cannot
+        be guaranteed. The callback must take one argument, which is the
+        reference to this ``Future``. The callback function can use the
+        :meth:`value` method to get the value. Note that if this ``Future`` is
+        already completed, the given callback will be run inline.
+
+        We recommend that you use the :meth:`then` method as it provides a way
+        to synchronize after your callback has completed. ``add_done_callback``
+        can be cheaper if your callback does not return anything. But both
+        :meth:`then` and ``add_done_callback`` use the same callback
+        registration API under the hood.
+
+        With respect to GPU tensors, this method behaves in the same way as
+        :meth:`then`.
+
+        Args:
+            callback(``Callable``): a ``Callable`` that takes in one argument,
+                which is the reference to this ``Future``.
+
+        .. note:: Note that if the callback function throws, either
+            through the original future being completed with an exception and
+            calling ``fut.wait()``, or through other code in the callback,
+            error handling must be carefully taken care of. For example, if
+            this callback later completes additional futures, those futures are
+            not marked as completed with an error and the user is responsible
+            for handling completion/waiting on those futures independently.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+            >>> def callback(fut):
+            ...     print("This will run after the future has finished.")
+            ...     print(fut.wait())
+            >>> fut = torch.futures.Future()
+            >>> fut.add_done_callback(callback)
+            >>> fut.set_result(5)
+            This will run after the future has finished.
+            5
+        """
+        super().add_done_callback(callback)
+
+    def set_result(self, result: T) -> None:
+        r"""
+        Set the result for this ``Future``, which will mark this ``Future`` as
+        completed and trigger all attached callbacks. Note that a ``Future``
+        cannot be marked completed twice.
+
+        If the result contains tensors that reside on GPUs, this method can be
+        called even if the asynchronous kernels that are populating those
+        tensors haven't yet completed running on the device, provided that the
+        streams on which those kernels were enqueued are set as the current ones
+        when this method is called. Put simply, it's safe to call this method
+        immediately after launching those kernels, without any additional
+        synchronization, as long as one doesn't change streams in between. This
+        method will record events on all the relevant current streams and will
+        use them to ensure proper scheduling for all the consumers of this
+        ``Future``.
+
+        Args:
+            result (object): the result object of this ``Future``.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+            >>> import threading
+            >>> import time
+            >>> def slow_set_future(fut, value):
+            ...     time.sleep(0.5)
+            ...     fut.set_result(value)
+            >>> fut = torch.futures.Future()
+            >>> t = threading.Thread(
+            ...     target=slow_set_future,
+            ...     args=(fut, torch.ones(2) * 3)
+            ... )
+            >>> t.start()
+            >>> print(fut.wait())
+            tensor([3., 3.])
+            >>> t.join()
+        """
+        super().set_result(result)
+
+    def set_exception(self, result: T) -> None:
+        r"""
+        Set an exception for this ``Future``, which will mark this ``Future`` as
+        completed with an error and trigger all attached callbacks. Note that
+        when calling wait()/value() on this ``Future``, the exception set here
+        will be raised inline.
+
+        Args:
+            result (BaseException): the exception for this ``Future``.
+
+        Example::
+            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+            >>> fut = torch.futures.Future()
+            >>> fut.set_exception(ValueError("foo"))
+            >>> fut.wait()
+            Traceback (most recent call last):
+            ...
+            ValueError: foo
+        """
+        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."
+
+        def raise_error(fut_result):
+            raise fut_result
+
+        super()._set_unwrap_func(raise_error)
+        self.set_result(result)  # type: ignore[arg-type]
+
+
+def collect_all(futures: List[Future]) -> Future[List[Future]]:
+    r"""
+    Collects the provided :class:`~torch.futures.Future` objects into a single
+    combined :class:`~torch.futures.Future` that is completed when all of the
+    sub-futures are completed.
+
+    Args:
+        futures (list): a list of :class:`~torch.futures.Future` objects.
+
+    Returns:
+        Returns a :class:`~torch.futures.Future` object to a list of the passed
+        in Futures.
+
+    Example::
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
+        >>> fut0 = torch.futures.Future()
+        >>> fut1 = torch.futures.Future()
+        >>> fut = torch.futures.collect_all([fut0, fut1])
+        >>> fut0.set_result(0)
+        >>> fut1.set_result(1)
+        >>> fut_list = fut.wait()
+        >>> print(f"fut0 result = {fut_list[0].wait()}")
+        fut0 result = 0
+        >>> print(f"fut1 result = {fut_list[1].wait()}")
+        fut1 result = 1
+    """
+    return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
+
+
+def wait_all(futures: List[Future]) -> List:
+    r"""
+    Waits for all provided futures to be complete, and returns
+    the list of completed values. If any of the futures encounters an error,
+    the method will exit early and report the error, without waiting for the
+    other futures to complete.
+
+    Args:
+        futures (list): a list of :class:`~torch.futures.Future` objects.
+
+    Returns:
+        A list of the completed :class:`~torch.futures.Future` results. This
+        method will throw an error if ``wait`` on any
+        :class:`~torch.futures.Future` throws.
+    """
+    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]
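+
+
+# A minimal sketch of ``wait_all`` semantics (the futures are assumed to be
+# completed elsewhere, e.g. by another thread or an RPC; values are illustrative):
+#
+#     fut0, fut1 = torch.futures.Future(), torch.futures.Future()
+#     fut0.set_result(1)
+#     fut1.set_result(2)
+#     assert torch.futures.wait_all([fut0, fut1]) == [1, 2]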
diff --git a/MLPY/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b62fa75e53d60a3e12f9f0d8f83d1b053ba2781a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__init__.py b/MLPY/Lib/site-packages/torch/fx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9096227068557199f9c17ba0ee4c4b9bef985502
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/__init__.py
@@ -0,0 +1,89 @@
+r'''
+FX is a toolkit for developers to use to transform ``nn.Module``
+instances. FX consists of three main components: a **symbolic tracer,**
+an **intermediate representation**, and **Python code generation**. A
+demonstration of these components in action:
+
+::
+
+    import torch
+    # Simple module for demonstration
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.rand(3, 4))
+            self.linear = torch.nn.Linear(4, 5)
+
+        def forward(self, x):
+            return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+    module = MyModule()
+
+    from torch.fx import symbolic_trace
+    # Symbolic tracing frontend - captures the semantics of the module
+    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
+
+    # High-level intermediate representation (IR) - Graph representation
+    print(symbolic_traced.graph)
+    """
+    graph():
+        %x : [num_users=1] = placeholder[target=x]
+        %param : [num_users=1] = get_attr[target=param]
+        %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
+        %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
+        %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
+        return clamp
+    """
+
+    # Code generation - valid Python code
+    print(symbolic_traced.code)
+    """
+    def forward(self, x):
+        param = self.param
+        add = x + param;  x = param = None
+        linear = self.linear(add);  add = None
+        clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
+        return clamp
+    """
+
+The **symbolic tracer** performs "symbolic execution" of the Python
+code. It feeds fake values, called Proxies, through the code. Operations
+on these Proxies are recorded. More information about symbolic tracing
+can be found in the :func:`symbolic_trace` and :class:`Tracer`
+documentation.
+
+The **intermediate representation** is the container for the operations
+that were recorded during symbolic tracing. It consists of a list of
+Nodes that represent function inputs, callsites (to functions, methods,
+or :class:`torch.nn.Module` instances), and return values. More information
+about the IR can be found in the documentation for :class:`Graph`. The
+IR is the format on which transformations are applied.
+
+**Python code generation** is what makes FX a Python-to-Python (or
+Module-to-Module) transformation toolkit. For each Graph IR, we can
+create valid Python code matching the Graph's semantics. This
+functionality is wrapped up in :class:`GraphModule`, which is a
+:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
+``forward`` method generated from the Graph.
+
+Taken together, this pipeline of components (symbolic tracing ->
+intermediate representation -> transforms -> Python code generation)
+constitutes the Python-to-Python transformation pipeline of FX. In
+addition, these components can be used separately. For example,
+symbolic tracing can be used in isolation to capture a form of
+the code for analysis (and not transformation) purposes. Code
+generation can be used for programmatically generating models, for
+example from a config file. There are many uses for FX!
+
+Several example transformations can be found in the ``examples`` repository.
+'''
+
+from .graph_module import GraphModule
+from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
+from .graph import Graph, CodeGen
+from .node import Node, map_arg, has_side_effect
+from .proxy import Proxy
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer
+from .subgraph_rewriter import replace_pattern
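+
+# A minimal sketch of the trace -> transform -> code generation pipeline that the
+# module docstring describes (the add-to-mul swap is an arbitrary illustrative
+# rewrite, not an FX API):
+#
+#     import torch
+#
+#     def swap_add_for_mul(m: torch.nn.Module) -> GraphModule:
+#         gm = symbolic_trace(m)
+#         for node in gm.graph.nodes:
+#             if node.op == "call_function" and node.target is torch.add:
+#                 node.target = torch.mul
+#         gm.graph.lint()
+#         gm.recompile()
+#         return gm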
diff --git a/MLPY/Lib/site-packages/torch/fx/__init__.pyi b/MLPY/Lib/site-packages/torch/fx/__init__.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..cd49d6e4a7685a71bcec8e67060c30addcc92bc4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/__init__.pyi
@@ -0,0 +1,11 @@
+from ._symbolic_trace import (
+    symbolic_trace as symbolic_trace,
+    Tracer as Tracer,
+    wrap as wrap,
+)
+from .graph import Graph as Graph
+from .graph_module import GraphModule as GraphModule
+from .interpreter import Interpreter as Interpreter, Transformer as Transformer
+from .node import has_side_effect as has_side_effect, map_arg as map_arg, Node as Node
+from .proxy import Proxy as Proxy
+from .subgraph_rewriter import replace_pattern as replace_pattern
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49d940e2232f97efd65d55b6d59192d8e1023b80
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3b3cd3839ba7da50e9129a06d95f9d5c8193ac3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..114bb63f35686b613da45aa816110d0c8ebb5659
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad4c5c949940de3682468746270c66de1d003d75
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1142b16df1c5a1f26d898b3c333b33b35c7a2079
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3740f8f884cda76c3f47cf743cdb4a7448b406d5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8549dbd7a2ea60b9576ddbd0abf86dab0c04269
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0701a607a36730db64d8adf33a102249033a6cab
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de89122380d284cd8e310c5b73e19ae90e22e0bc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60b54a852dd35247f9c08775943ede3db6951bb3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9892a39be0090ac82b7c7a06c5921baea4449a80
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4302a97a559db61c67a0d5646c12b739def8fbf2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42ac3abe7902d5994c817e024fa11ffe64873ba7
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae595d6f20ff2a6d490a16c502d7280a771141fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86be2a47283d3a1fe2e361feeb6ad798e1755945
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..883fc307bb589978227137d488e4e5628622a727
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f79497b9c1c2ff2189bac4be20d158d6d30c554
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/_compatibility.py b/MLPY/Lib/site-packages/torch/fx/_compatibility.py
new file mode 100644
index 0000000000000000000000000000000000000000..24b3da3cbe981d01bb4cd9777320185a367e4c13
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/_compatibility.py
@@ -0,0 +1,34 @@
+from typing import Any, Dict
+import textwrap
+
+_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
+_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
+
+def compatibility(is_backward_compatible : bool):
+    if is_backward_compatible:
+
+        def mark_back_compat(fn):
+            docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+            docstring += """
+.. note::
+    Backwards-compatibility for this API is guaranteed.
+"""
+            fn.__doc__ = docstring
+            _BACK_COMPAT_OBJECTS.setdefault(fn)
+            _MARKED_WITH_COMPATIBILITY.setdefault(fn)
+            return fn
+
+        return mark_back_compat
+    else:
+
+        def mark_not_back_compat(fn):
+            docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+            docstring += """
+.. warning::
+    This API is experimental and is *NOT* backward-compatible.
+"""
+            fn.__doc__ = docstring
+            _MARKED_WITH_COMPATIBILITY.setdefault(fn)
+            return fn
+
+        return mark_not_back_compat
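+
+
+# A minimal sketch of how the decorator is used (``my_pass`` is a hypothetical
+# example): it appends a backward-compatibility note to the docstring and
+# registers the object in the module-level dicts above.
+#
+#     @compatibility(is_backward_compatible=True)
+#     def my_pass(graph):
+#         """Run a hypothetical pass over an FX graph."""
+#         return graph
+#
+#     assert my_pass in _BACK_COMPAT_OBJECTS
+#     assert "Backwards-compatibility for this API is guaranteed." in my_pass.__doc__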
diff --git a/MLPY/Lib/site-packages/torch/fx/_lazy_graph_module.py b/MLPY/Lib/site-packages/torch/fx/_lazy_graph_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bcaf61263d40a4646a7c0b1a92dd6104610e9d5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/_lazy_graph_module.py
@@ -0,0 +1,182 @@
+from contextlib import contextmanager
+
+from torch.fx import GraphModule
+from torch.fx.graph_module import (
+    _format_import_block,
+    reduce_graph_module,
+    reduce_package_graph_module,
+)
+from torch.package import PackageExporter, sys_importer
+from ._compatibility import compatibility
+
+_use_lazy_graph_module_flag = False
+_force_skip_lazy_graph_module_flag = False
+
+
+@compatibility(is_backward_compatible=False)
+@contextmanager
+def _force_skip_lazy_graph_module():
+    """
+    Skip using the lazy graph module regardless of the setting of _use_lazy_graph_module.
+    Use this to skip _LazyGraphModule when testing the inductor torchscript-related backend.
+
+    Calling torch.jit.script on a _LazyGraphModule results in the following error:
+        https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
+    """
+    try:
+        global _force_skip_lazy_graph_module_flag
+        prior = _force_skip_lazy_graph_module_flag
+        _force_skip_lazy_graph_module_flag = True
+        yield
+    finally:
+        _force_skip_lazy_graph_module_flag = prior
+
+
+@compatibility(is_backward_compatible=False)
+@contextmanager
+def _use_lazy_graph_module(should_use: bool):
+    try:
+        global _use_lazy_graph_module_flag
+        prior = _use_lazy_graph_module_flag
+        _use_lazy_graph_module_flag = (
+            should_use and not _force_skip_lazy_graph_module_flag
+        )
+        yield
+    finally:
+        _use_lazy_graph_module_flag = prior
+
+
+@compatibility(is_backward_compatible=False)
+def _get_graph_module_cls():
+    return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule
+
+
+def _make_graph_module(*args, graph_module_cls=None, **kwargs):
+    if graph_module_cls is None:
+        graph_module_cls = _get_graph_module_cls()
+
+    return graph_module_cls(*args, **kwargs)
+
+
+@compatibility(is_backward_compatible=False)
+class _LazyGraphModule(GraphModule):
+    """
+    The main difference between _LazyGraphModule and GraphModule is how recompile happens.
+    GraphModule will do a 'recompile' call to generate python code and the forward method when it's
+    constructed. Later on, if the graph gets updated, the recompile method can be called again to refresh
+    the saved python code and forward method.
+
+    However, in some cases, especially in inductor, the recompilation can be a waste since we never
+    check the python code for the graph module or call its forward method. A few more concrete
+    examples regarding pattern matching fx passes in inductor:
+    1. some passes will update the graph to be compiled and then call recompile on the GraphModule.
+    2. some passes will trace a small pattern function to search for it in the graph being compiled and
+       replace the match with the traced graph of a replacement function. The pattern graph and
+       replacement graph are quite small, but there are a large number of them. Doing GraphModule.recompile
+       for them in GraphModule.__init__ is also a waste of time.
+
+    However, simply skipping the GraphModule.recompile call in these scenarios is also dangerous.
+    People may want to check the python code or call the GraphModule's forward method for debugging purposes.
+
+    The way _LazyGraphModule solves this is by overriding the recompile method to just mark the
+    need for recompilation without doing the actual recompilation. Later on, if people really
+    access the compiled python code or call the GraphModule's forward method, we do the real
+    recompilation.
+    """
+
+    @classmethod
+    def from_graphmodule(cls, gm: GraphModule):
+        if isinstance(gm, _LazyGraphModule):
+            return gm
+        else:
+            return _LazyGraphModule(gm, gm.graph)
+
+    @staticmethod
+    def force_recompile(gm):
+        """
+        Sometimes we need to force a recompile as a workaround:
+        - we want to do the real recompilation before symbolic_trace to avoid the error:
+            https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
+        """
+        if isinstance(gm, _LazyGraphModule):
+            gm.real_recompile()
+
+    def real_recompile(self):
+        if self._needs_recompile():
+            self._real_recompile()
+
+    @classmethod
+    def _needs_recompile(cls):
+        return cls.forward is cls._lazy_forward
+
+    def _lazy_forward(self, *args, **kwargs):
+        # Call self.real_recompile() rather than self._real_recompile() here.
+        # The _lazy_forward method may be saved and called repeatedly.
+        # Calling self.real_recompile can make sure we skip recompilation if
+        # we have already done so.
+        self.real_recompile()
+        assert not self._needs_recompile()
+
+        # call `__call__` rather than 'forward' since recompilation may
+        # install a wrapper for `__call__` to provide a customized error
+        # message.
+        return self(*args, **kwargs)
+
+    forward = _lazy_forward
+
+    # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__,
+    # or __reduce__, by calling _real_recompile. But I haven't found a good way
+    # to test __reduce_deploy__. Also it's very unlikely that LazyGraphModule
+    # will be used in torch::deploy. So it's skipped for now.
+
+    def __reduce_package__(self, exporter: PackageExporter):
+        """
+        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
+        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
+        marks the need for recompilation and does not return the PythonCode object.
+        """
+        python_code = self._real_recompile()
+        dict_without_graph = self.__dict__.copy()
+        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
+        del dict_without_graph["_graph"]
+
+        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
+        import_block = _format_import_block(python_code.globals, exporter.importer)
+        module_code = import_block + self.code
+        exporter.save_source_string(generated_module_name, module_code)
+        return (
+            reduce_package_graph_module,
+            (dict_without_graph, generated_module_name),
+        )
+
+    def __reduce__(self):
+        """
+        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
+        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
+        marks the need for recompilation and does not return the PythonCode object.
+        """
+        python_code = self._real_recompile()
+        dict_without_graph = self.__dict__.copy()
+        import_block = _format_import_block(python_code.globals, sys_importer)
+        del dict_without_graph["_graph"]
+        return (reduce_graph_module, (dict_without_graph, import_block))
+
+    def _real_recompile(self):
+        return super().recompile()
+
+    @classmethod
+    def recompile(cls):
+        cls.forward = cls._lazy_forward
+
+    @property
+    def code(self) -> str:
+        self.real_recompile()
+        return super().code
+
+    def __str__(self) -> str:
+        """
+        str(GraphModule) will access the _code attribute. Make sure recompile
+        happens so _code attribute is available.
+        """
+        self.real_recompile()
+        return super().__str__()
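+
+
+# A minimal sketch of the lazy-recompile behavior described in the class
+# docstring (``M`` is a hypothetical module):
+#
+#     import torch
+#
+#     class M(torch.nn.Module):
+#         def forward(self, x):
+#             return x + 1
+#
+#     lazy_gm = _LazyGraphModule.from_graphmodule(torch.fx.symbolic_trace(M()))
+#     # recompile() on a _LazyGraphModule only marks the need for recompilation;
+#     # the real recompile happens lazily on first access:
+#     print(lazy_gm.code)
+#
+# Using ``_use_lazy_graph_module(True)`` as a context manager makes
+# ``_make_graph_module`` construct ``_LazyGraphModule`` instances instead of
+# plain ``GraphModule``s.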
diff --git a/MLPY/Lib/site-packages/torch/fx/_pytree.py b/MLPY/Lib/site-packages/torch/fx/_pytree.py
new file mode 100644
index 0000000000000000000000000000000000000000..510be5f33516bf7a2bd90f1f92ef18bc540a519b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/_pytree.py
@@ -0,0 +1,102 @@
+from collections import namedtuple
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
+
+import torch.return_types
+
+from torch.utils._pytree import PyTree, TreeSpec
+
+FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
+FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
+
+SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
+SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
+
+
+def register_pytree_flatten_spec(
+    cls: Type[Any],
+    flatten_fn_spec: FlattenFuncSpec,
+    flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
+) -> None:
+    SUPPORTED_NODES[cls] = flatten_fn_spec
+    SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
+
+
+def tree_flatten_spec(
+    pytree: PyTree,
+    spec: TreeSpec,
+    exact_structural_match=False,
+) -> List[Any]:
+    if spec.is_leaf():
+        return [pytree]
+    if spec.type not in SUPPORTED_NODES:
+        raise RuntimeError(
+            f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
+            "torch.fx._pytree.register_pytree_flatten_spec.  If you have serialized your model, make "
+            "sure that any custom pytrees have been registered before loading it.",
+        )
+    flatten_fn_spec = SUPPORTED_NODES[spec.type]
+    child_pytrees = flatten_fn_spec(pytree, spec)
+    if exact_structural_match:
+        flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
+        if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
+            pytree,
+            spec,
+        ):
+            raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
+    result = []
+    for child, child_spec in zip(child_pytrees, spec.children_specs):
+        flat = tree_flatten_spec(child, child_spec, exact_structural_match)
+        result += flat
+    return result
+
+
+def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
+    return [d[k] for k in spec.context]
+
+
+def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
+    return [d[i] for i in range(spec.num_children)]
+
+
+def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
+    return [d[i] for i in range(spec.num_children)]
+
+
+def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
+    return [d[i] for i in range(spec.num_children)]
+
+
+def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
+    return len(d) == spec.num_children
+
+
+def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
+    return len(d) == spec.num_children
+
+
+def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
+    return len(d) == spec.num_children
+
+
+def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
+    return len(d) == spec.num_children
+
+
+register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
+register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
+register_pytree_flatten_spec(
+    tuple,
+    _tuple_flatten_spec,
+    _tuple_flatten_spec_exact_match,
+)
+for return_type in torch.return_types.all_return_types:
+    register_pytree_flatten_spec(
+        return_type,
+        _tuple_flatten_spec,
+        _tuple_flatten_spec_exact_match,
+    )
+register_pytree_flatten_spec(
+    namedtuple,  # type: ignore[arg-type]
+    _namedtuple_flatten_spec,
+    _namedtuple_flatten_spec_exact_match,
+)
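+
+
+# A minimal sketch of ``tree_flatten_spec`` (values are illustrative): the spec
+# is obtained from a reference pytree via ``torch.utils._pytree.tree_flatten``,
+# then a second pytree with the same structure is flattened in spec order.
+#
+#     from torch.utils._pytree import tree_flatten
+#
+#     _, spec = tree_flatten({"a": 0, "b": 0})
+#     assert tree_flatten_spec({"b": 2, "a": 1}, spec) == [1, 2]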
diff --git a/MLPY/Lib/site-packages/torch/fx/_symbolic_trace.py b/MLPY/Lib/site-packages/torch/fx/_symbolic_trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08453e846a2355a9a6a1053aeec777829cd177d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/_symbolic_trace.py
@@ -0,0 +1,1202 @@
+import builtins
+import copy
+import functools
+import inspect
+import math
+import os
+import warnings
+import collections
+from itertools import chain
+from types import CodeType, FunctionType, ModuleType
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
+
+import torch
+import torch.utils._pytree as pytree
+from torch._C import ScriptObject  # type: ignore[attr-defined]
+
+from ._compatibility import compatibility
+from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
+from .graph_module import GraphModule
+from ._lazy_graph_module import _make_graph_module
+from .node import Argument, base_types, map_aggregate
+from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
+
+HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
+
+# These need to run in global scope to handle nested calls correctly
+_orig_module_call: Callable = torch.nn.Module.__call__
+_orig_module_getattr: Callable = torch.nn.Module.__getattr__
+
+_proxyable_classes: Dict[Type, None] = {}
+
+_is_fx_tracing_flag = False
+
+
+def is_fx_tracing():
+    return _is_fx_tracing_flag
+
+@compatibility(is_backward_compatible=True)
+class ProxyableClassMeta(type):
+    """
+    ProxyableClassMeta allows you to make construction of a given Python class
+    symbolically traceable. For example::
+
+        import torch
+        import torch.fx
+
+        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
+            def __init__(self, left, right):
+                self.left, self.right = left, right
+
+            def add(self, other):
+                l = self.left + other.left
+                r = self.right + other.right
+                return TensorPair(l, r)
+
+            def mul(self, other):
+                l = self.left * other.left
+                r = self.right * other.right
+                return TensorPair(l, r)
+
+        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
+            s = x.add(TensorPair(y, y))
+            return s.mul(x)
+
+        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
+        y = torch.randn(5, 3)
+        ref_out = use_tensor_pair_ctor(x, y)
+
+        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
+        print(traced.code)
+        '''
+        def forward(self, x : __main___TensorPair, y : torch.Tensor):
+            tensor_pair = __main___TensorPair(y, y);  y = None
+            add = x.add(tensor_pair);  tensor_pair = None
+            mul = add.mul(x);  add = x = None
+            return mul
+        '''
+
+    From this example, we can see that construction of a class (``TensorPair``)
+    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
+    tracing.
+    """
+
+    def __init__(cls, name, bases, attrs):
+        _proxyable_classes.setdefault(cls)
+        super().__init__(name, bases, attrs)
+
+    def __call__(cls, *args, **kwargs):
+        instance = cls.__new__(cls)  # type: ignore[call-overload]
+
+        if not is_fx_tracing():
+            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
+            return instance
+
+        found_proxies = []
+
+        def check_proxy(a):
+            if isinstance(a, Proxy):
+                found_proxies.append(a)
+
+        map_aggregate(args, check_proxy)
+        map_aggregate(kwargs, check_proxy)
+
+        if len(found_proxies) != 0:
+            tracer = found_proxies[0].tracer
+            return tracer.create_proxy("call_function", cls, args, kwargs)
+        else:
+            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
+            return instance
+
+
+def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
+    co = fn.__code__
+    co_flags = co.co_flags & ~HAS_VARSTUFF
+    co_args: tuple
+    if hasattr(co, "co_qualname"):
+        # Python-3.11+ code signature
+        co_args = (
+            nargs,
+            0,
+            0,
+            co.co_nlocals,
+            co.co_stacksize,
+            co_flags,
+            co.co_code,
+            co.co_consts,
+            co.co_names,
+            co.co_varnames,
+            co.co_filename,
+            co.co_name,
+            co.co_qualname,  # type: ignore[attr-defined]
+            co.co_firstlineno,
+            co.co_lnotab,
+            co.co_exceptiontable,  # type: ignore[attr-defined]
+            co.co_freevars,
+            co.co_cellvars,
+        )
+    elif hasattr(co, "co_posonlyargcount"):
+        co_args = (
+            nargs,
+            0,
+            0,
+            co.co_nlocals,
+            co.co_stacksize,
+            co_flags,
+            co.co_code,
+            co.co_consts,
+            co.co_names,
+            co.co_varnames,
+            co.co_filename,
+            co.co_name,
+            co.co_firstlineno,
+            co.co_lnotab,
+            co.co_freevars,
+            co.co_cellvars,
+        )
+    else:
+        co_args = (
+            nargs,
+            0,
+            co.co_nlocals,
+            co.co_stacksize,
+            co_flags,
+            co.co_code,
+            co.co_consts,
+            co.co_names,
+            co.co_varnames,
+            co.co_filename,
+            co.co_name,
+            co.co_firstlineno,
+            co.co_lnotab,
+            co.co_freevars,
+            co.co_cellvars,
+        )
+    new_code = CodeType(*co_args)  # type: ignore[arg-type]
+    return FunctionType(
+        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
+    )
+
+    # we need to insert placeholder nodes for *args and **kwargs
+    # we can't call this function normally, otherwise it would try to unpack them
+    # instead, let's make python think that args and kwargs are normal variables
+
+
+@compatibility(is_backward_compatible=False)
+class PHBase:
+    """
+    Object representing an input placeholder to `concrete_args`
+    """
+
+    def __repr__(self):
+        return "PH"
+
+
+PH = PHBase()
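+
+# A minimal sketch of how ``PH`` is used with ``concrete_args`` (``f`` and its
+# dict keys are hypothetical examples): leaves set to ``PH`` remain symbolic
+# placeholders while the surrounding container structure is specialized away.
+#
+#     def f(d):
+#         return d["a"] + d["b"]
+#
+#     traced = torch.fx.symbolic_trace(f, concrete_args={"d": {"a": PH, "b": PH}})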
+
+
+@compatibility(is_backward_compatible=False)
+class PHWithMeta(PHBase):
+    """
+    Object representing an input placeholder to `concrete_args`
+    """
+    def __init__(self, ph_key: Optional[str] = None):
+        super().__init__()
+
+        # Provide a key for the user to identify the placeholder node during analysis
+        self.ph_key = ph_key
+
+
+def _transfer_attrs(fr, to):
+    for attr_name in dir(fr):
+        attr_val = getattr(fr, attr_name)
+        if (
+            not callable(attr_val)
+            and not attr_name.startswith("__")
+            and not hasattr(to, attr_name)
+        ):
+            setattr(to, attr_name, attr_val)
+
+
+@compatibility(is_backward_compatible=True)
+class Tracer(TracerBase):
+    # Reference: https://github.com/pytorch/pytorch/issues/54354
+    # The first line of this docstring overrides the one Sphinx generates for the
+    # documentation. We need it so that Sphinx doesn't leak `math`s path from the
+    # build environment (e.g. the repr of the `math` module, which embeds a
+    # local filesystem path).
+
+    """Tracer(autowrap_modules=(math,), autowrap_functions=())
+
+    ``Tracer`` is the class that implements the symbolic tracing functionality
+    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is
+    equivalent to ``Tracer().trace(m)``.
+    """
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(
+        self,
+        autowrap_modules: Tuple[ModuleType] = (math,),
+        autowrap_functions: Tuple[Callable, ...] = (),
+        param_shapes_constant: bool = False,
+    ) -> None:
+        # This method's signature is overridden by the first line of this class'
+        # docstring. If this method's signature is modified, the signature that
+        # overrides it also should be modified accordingly.
+
+        """
+        Construct a Tracer object.
+
+        Args:
+
+            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
+                Python modules whose functions should be wrapped automatically
+                without needing to use fx.wrap(). Backward-compatibility for
+                this parameter is guaranteed.
+
+            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
+                Python functions that should be wrapped automatically without
+                needing to use fx.wrap(). Backward compatibility for this
+                parameter is guaranteed.
+
+            param_shapes_constant (bool): When this flag is set, calls to shape,
+                size and a few other shape-like attributes of a module's parameter
+                will be evaluated directly, rather than returning a new Proxy value
+                for an attribute access. Backward compatibility for this parameter
+                is guaranteed.
+        """
+
+        super().__init__()
+
+        # Functions we will eagerly wrap when we see them while tracing
+        # this captures both `math.sqrt()` and `from math import sqrt` automatically
+        self._autowrap_function_ids: Set[int] = {
+            id(value)
+            for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
+            if not name.startswith("_") and callable(value)
+        }
+        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
+
+        # Python modules to apply autowrap to at the start, in addition to
+        # modules we see while tracing
+        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
+        self.param_shapes_constant = param_shapes_constant
+
+        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
+        self.root_module_name: str = ""
+        # Maps the containing module's name to the operator name
+        self.scope = Scope("", None)
+        # Records the module call stack
+        self.module_stack = collections.OrderedDict()
+        # Mapping of node name to module scope
+        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
+
+    @compatibility(is_backward_compatible=True)
+    def create_arg(self, a: Any) -> "Argument":
+        """
+        A method to specify the behavior of tracing when preparing values to
+        be used as arguments to nodes in the ``Graph``.
+
+        By default, the behavior includes:
+
+        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
+           call ``create_arg`` on the elements.
+        #. Given a Proxy object, return a reference to the underlying IR ``Node``
+        #. Given a non-Proxy Tensor object, emit IR for various cases:
+
+            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
+            * For a non-Parameter Tensor, store the Tensor away in a special
+              attribute and emit a ``get_attr`` node referring to that attribute.
+
+        This method can be overridden to support more types.
+
+        Args:
+
+            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
+
+
+        Returns:
+
+            The value ``a`` converted into the appropriate ``Argument``
+        """
+        # The base tracer is used to construct Graphs when there is no associated
+        # module hierarchy, so it can never create parameter references.
+        # The default tracer adds the ability to refer to parameters when
+        # tracing modules.
+        if isinstance(a, torch.nn.Parameter):
+            for n, p in self.root.named_parameters():
+                if a is p:
+                    return self.create_node("get_attr", n, (), {})
+            raise NameError("parameter is not a member of this module")
+        elif isinstance(a, torch.Tensor):
+            for n_, p_ in self.root.named_buffers():
+                if a is p_:
+                    return self.create_node("get_attr", n_, (), {})
+        elif isinstance(a, torch.nn.Module):
+            for n_, p_ in self.root.named_modules():
+                if a is p_:
+                    return self.create_node("get_attr", n_, (), {})
+        # For NamedTuple instances that appear literally as args, we emit
+        # a node to construct the NamedTuple and use that Node as the argument.
+        if isinstance(a, tuple) and hasattr(a, "_fields"):
+            args = tuple(self.create_arg(elem) for elem in a)
+            return self.create_node("call_function", a.__class__, args, {})
+
+        # Tensors do not have a reliable string repr() from which they can be
+        # constructed (and we probably don't want to rely on that, either), so
+        # for any constant Tensor values we encounter, first search for if they
+        # are an attribute of some module in the module hierarchy. If so, emit
+        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
+        # tensor value into a special attribute on the Module s.t. we can
+        # retrieve it with a get_attr.
+        if isinstance(a, (torch.Tensor, ScriptObject)):
+            qualname: Optional[str] = self.tensor_attrs.get(a)
+
+            # Tensor was not found in the Module hierarchy, stow it away in a
+            # special attribute and set the qualname to refer to that
+            if not qualname:
+                i = 0
+                while True:
+                    qualname = f"_tensor_constant{i}"
+                    if not hasattr(self.root, qualname):
+                        break
+                    i += 1
+                self.tensor_attrs[a] = qualname
+                setattr(self.root, qualname, a)
+
+            return self.create_node("get_attr", qualname, (), {})
+
+        if type(a) in _proxyable_classes:
+            # This is an instance of a proxyable class for which we did not
+            # witness its construction. Intern this as a constant attribute
+
+            # TODO: binary search
+            i = 0
+            while True:
+                qualname = f"_{a.__class__.__name__}_constant_{i}"
+                if not hasattr(self.root, qualname):
+                    break
+                i += 1
+            setattr(self.root, qualname, a)
+
+            return self.create_node("get_attr", qualname, (), {})
+
+        return super().create_arg(a)
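+
+    # A minimal sketch of extending ``create_arg`` for an extra type in a
+    # subclass (``Point`` is a hypothetical container with ``x``/``y`` fields):
+    #
+    #     class MyTracer(Tracer):
+    #         def create_arg(self, a):
+    #             if isinstance(a, Point):
+    #                 args = (self.create_arg(a.x), self.create_arg(a.y))
+    #                 return self.create_node("call_function", Point, args, {})
+    #             return super().create_arg(a)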
+
+    @compatibility(is_backward_compatible=True)
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        """
+        A method to specify whether a given ``nn.Module`` is a "leaf" module.
+
+        Leaf modules are the atomic units that appear in
+        the IR, referenced by ``call_module`` calls. By default,
+        Modules in the PyTorch standard library namespace (torch.nn)
+        are leaf modules. All other modules are traced through and
+        their constituent ops are recorded, unless specified otherwise
+        via this parameter.
+
+        Args:
+
+            m (Module): The module being queried about
+            module_qualified_name (str): The qualified path of this module from the root. For example,
+                if you have a module hierarchy where submodule ``foo`` contains
+                submodule ``bar``, which contains submodule ``baz``, that module will
+                appear with the qualified name ``foo.bar.baz`` here.
+        """
+        return (
+            (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
+            and not isinstance(m, torch.nn.Sequential)
+        )
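+
+    # A minimal sketch of customizing leaf behavior by subclassing (the
+    # ``MySpecialBlock`` name is a hypothetical example):
+    #
+    #     class MyTracer(Tracer):
+    #         def is_leaf_module(self, m, module_qualified_name):
+    #             # keep MySpecialBlock opaque in the IR instead of tracing into it
+    #             return (super().is_leaf_module(m, module_qualified_name)
+    #                     or isinstance(m, MySpecialBlock))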
+
+    @compatibility(is_backward_compatible=True)
+    def path_of_module(self, mod: torch.nn.Module) -> str:
+        """
+        Helper method to find the qualified name of ``mod`` in the Module hierarchy
+        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
+        a submodule named ``bar``, passing ``bar`` into this function will return
+        the string "foo.bar".
+
+        Args:
+
+            mod (str): The ``Module`` to retrieve the qualified name for.
+        """
+        # Prefer the O(1) algorithm
+        if self.submodule_paths:
+            path = self.submodule_paths.get(mod)
+            if path is None:
+                raise NameError("module is not installed as a submodule")
+            assert isinstance(path, str)
+            return path
+        # O(N^2) fallback in the case that we didn't store the submodule
+        # paths.
+        else:
+            for n, p in self.root.named_modules():
+                if mod is p:
+                    return n
+            raise NameError("module is not installed as a submodule")
+
+    @compatibility(is_backward_compatible=True)
+    def call_module(
+        self,
+        m: torch.nn.Module,
+        forward: Callable[..., Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> Any:
+        """
+        Method that specifies the behavior of this ``Tracer`` when it encounters
+        a call to an ``nn.Module`` instance.
+
+        By default, the behavior is to check if the called module is a leaf module
+        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
+        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
+        the operations in its ``forward`` function.
+
+        This method can be overridden to--for example--create nested traced
+        GraphModules, or any other behavior you would want while tracing across
+        ``Module`` boundaries.
+
+        Args:
+
+            m (Module): The module for which a call is being emitted
+            forward (Callable): The forward() method of the ``Module`` to be invoked
+            args (Tuple): args of the module callsite
+            kwargs (Dict): kwargs of the module callsite
+
+        Return:
+
+            The return value from the Module call. In the case that a ``call_module``
+            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
+            value was returned from the ``Module`` invocation.
+        """
+        module_qualified_name = self.path_of_module(m)
+        with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
+            # module_stack is an ordered dict so writing then deleting the
+            # entry is equivalent to push/pop on a list
+            self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
+            if not self.is_leaf_module(m, module_qualified_name):
+                ret_val = forward(*args, **kwargs)
+            else:
+                ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
+            key, _ = self.module_stack.popitem(last=True)
+            assert key == _scope.module_path, f" Unexpected key {key}"
+
+        return ret_val
+
+    @compatibility(is_backward_compatible=False)
+    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
+        """
+        Method that specifies the behavior of this ``Tracer`` when getattr is
+        called on an attribute of an ``nn.Module`` instance during tracing.
+
+        By default, the behavior is to return a proxy value for the attribute. It
+        also stores the proxy value in the ``parameter_proxy_cache``, so that future
+        calls will reuse the proxy rather than creating a new one.
+
+        This method can be overridden to, for example, not return proxies when
+        querying parameters.
+
+        Args:
+
+            attr (str): The name of the attribute being queried
+            attr_val (Any): The value of the attribute
+            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies
+
+        Return:
+
+            The return value from the getattr call.
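+
+        A minimal sketch (``MyTracer`` is a hypothetical subclass) of the override
+        mentioned above, returning parameters as concrete values instead of proxies::
+
+            class MyTracer(torch.fx.Tracer):
+                def getattr(self, attr, attr_val, parameter_proxy_cache):
+                    if isinstance(attr_val, torch.nn.Parameter):
+                        return attr_val
+                    return super().getattr(attr, attr_val, parameter_proxy_cache)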
+        """
+        def maybe_get_proxy_for_attr(
+            attr_val, collection_to_search, parameter_proxy_cache
+        ):
+            for n, p in collection_to_search:
+                if attr_val is p:
+                    if n not in parameter_proxy_cache:
+                        kwargs = {}
+                        if (
+                            "proxy_factory_fn"
+                            in inspect.signature(self.create_proxy).parameters
+                        ):
+                            kwargs["proxy_factory_fn"] = (
+                                None
+                                if not self.param_shapes_constant
+                                else lambda node: ParameterProxy(
+                                    self, node, n, attr_val
+                                )
+                            )
+                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
+                        parameter_proxy_cache[n] = val_proxy
+                    return parameter_proxy_cache[n]
+            return None
+
+        if isinstance(attr_val, torch.nn.Parameter):
+            maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_parameters(), parameter_proxy_cache
+            )
+            if maybe_parameter_proxy is not None:
+                return maybe_parameter_proxy
+
+        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+            maybe_buffer_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_buffers(), parameter_proxy_cache
+            )
+            if maybe_buffer_proxy is not None:
+                return maybe_buffer_proxy
+
+        return attr_val
+
+    # This method will be refactored
+    @compatibility(is_backward_compatible=False)
+    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
+        """
+        Create ``placeholder`` nodes corresponding to the signature of the ``root``
+        Module. This method introspects root's signature and emits those
+        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
+        """
+        # In some cases, a function or method has been decorated with a wrapper
+        # defined via ``functools.wraps``. In this case, the outer code object
+        # will likely not contain the actual parameters we care about, so unwrap
+        # the function to get to the innermost callable.
+        fn_for_analysis = inspect.unwrap(root_fn)
+        co = fn_for_analysis.__code__
+        total_args = co.co_argcount + co.co_kwonlyargcount
+        orig_args = list(co.co_varnames)
+        names_iter = iter(co.co_varnames)
+        args: List[Any] = []
+        skip_arg_idx = 0
+        if is_module:
+            if total_args == 0:
+                raise RuntimeError(
+                    "``self`` argument cannot be part of *args expansion!"
+                )
+            skip_arg_idx = 1
+            next(names_iter)  # skip self
+            args.append(self.root)
+
+        sig = inspect.signature(fn_for_analysis)
+
+        # This covers the very specific case where we are passing in flat
+        # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
+        # In this case, just take the concrete_args and pass them through.
+        name_idx = 0
+        if isinstance(concrete_args, tuple) and \
+                len(concrete_args) > 0 and \
+                (co.co_flags & HAS_VARSTUFF) and \
+                total_args == 1:
+            for concrete_arg in concrete_args:
+                out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
+                if isinstance(concrete_arg, PHBase):
+                    if concrete_arg != PH:
+                        # Transfer attrs in the case where you're using a placeholder other
+                        # than the singleton PH (PH has no attributes to transfer).
+                        # Proxies were created out of the placeholders.
+                        # Transfer any metadata (put on the placeholders in the form of
+                        # attributes set by the user) from the placeholder to the
+                        # underlying nodes (the proxy is unwrapped by the user, but
+                        # the metadata should hold).
+                        _transfer_attrs(fr=concrete_arg, to=out.node)
+                args.append(out)
+                name_idx += 1
+            return root_fn, args
+
+        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
+        if isinstance(concrete_args, tuple):
+            if len(arg_names) != len(concrete_args):
+                raise RuntimeError(
+                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
+                )
+            concrete_args = dict(zip(arg_names, concrete_args))
+
+        def proxy_placeholder(name):
+            return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)
+
+        args.extend(proxy_placeholder(names) for names in arg_names)
+
+        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
+            # TODO: type annotations for *args and **kwargs
+            if co.co_flags & inspect.CO_VARARGS:
+                args.append(proxy_placeholder("*" + next(names_iter)))
+            if co.co_flags & inspect.CO_VARKEYWORDS:
+                args.append(proxy_placeholder("**" + next(names_iter)))
+            root_fn = _patch_function(root_fn, len(args))
+
+        flat_args, in_spec = pytree.tree_flatten(tuple(args))
+        if not all(child.is_leaf() for child in in_spec.children_specs):
+            # In the case that we have pytree-flattened inputs in
+            # `concrete_args`, generate a flattening wrapper around the
+            # original root function and return that.
+            self.graph._codegen = _PyTreeCodeGen(
+                _PyTreeInfo(orig_args[:total_args], in_spec, None)
+            )
+
+            def flatten_fn(*args):
+                tree_args = pytree.tree_unflatten(list(args), in_spec)
+                tree_out = root_fn(*tree_args)
+                out_args, out_spec = pytree.tree_flatten(tree_out)
+                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
+                self.graph._codegen.pytree_info = (
+                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
+                )
+                return out_args
+
+            return flatten_fn, flat_args
+        return root_fn, args
+
+    @compatibility(is_backward_compatible=True)
+    def trace(
+        self,
+        root: Union[torch.nn.Module, Callable[..., Any]],
+        concrete_args: Optional[Dict[str, Any]] = None,
+    ) -> Graph:
+        """
+        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
+        can either be an ``nn.Module`` instance or a Python callable.
+
+        Note that after this call, ``self.root`` may be different from the ``root`` passed
+        in here. For example, when a free function is passed to ``trace()``, we will
+        create an ``nn.Module`` instance to use as the root and add embedded constants
+        to it.
+
+
+        Args:
+
+            root (Union[Module, Callable]): Either a ``Module`` or a function to be
+                traced through. Backwards-compatibility for this parameter is
+                guaranteed.
+            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
+                not be treated as Proxies. This parameter is experimental and
+                its backwards-compatibility is *NOT* guaranteed.
+
+        Returns:
+
+            A ``Graph`` representing the semantics of the passed-in ``root``.
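+
+        A minimal usage sketch (``MyModule`` is a hypothetical ``nn.Module``)::
+
+            tracer = Tracer()
+            graph = tracer.trace(MyModule())
+            gm = torch.fx.GraphModule(tracer.root, graph)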
+        """
+        global _is_fx_tracing_flag
+        old_is_fx_tracing_flag = _is_fx_tracing_flag
+        _is_fx_tracing_flag = True
+        try:
+            if isinstance(root, torch.nn.Module):
+
+                # do real recompilation for _LazyGraphModule before retracing since the trace
+                # method can not trace the _lazy_forward method. Got error:
+                #   https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
+                # without this.
+                from torch.fx._lazy_graph_module import _LazyGraphModule
+                _LazyGraphModule.force_recompile(root)
+
+                self.root = root
+
+                assert hasattr(
+                    type(root), self.traced_func_name
+                ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
+
+                fn = getattr(type(root), self.traced_func_name)
+                self.root_module_name = root._get_name()
+                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
+            else:
+                self.root = torch.nn.Module()
+                fn = root
+
+            tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
+            self.graph = Graph(tracer_cls=tracer_cls)
+            if hasattr(fn, '__code__'):
+                code = fn.__code__
+                self.graph._co_fields = {
+                    'co_name': code.co_name,
+                    'co_filename': code.co_filename,
+                    'co_firstlineno': code.co_firstlineno,
+                }
+
+            # When we encounter a Tensor value that's not a parameter, we look if it
+            # is some other attribute on the model. Construct a dict mapping Tensor
+            # values to the qualified name here for efficiency. This is used downstream
+            # in create_arg
+            self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}
+
+            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
+                for k, v in m.__dict__.items():
+                    if isinstance(v, (torch.Tensor, ScriptObject)):
+                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
+                for k, v in m.named_children():
+                    collect_tensor_attrs(v, prefix_atoms + [k])
+
+            collect_tensor_attrs(self.root, [])
+
+            assert isinstance(fn, FunctionType)
+
+            fn_globals = fn.__globals__  # run before it gets patched
+            fn, args = self.create_args_for_root(
+                fn, isinstance(root, torch.nn.Module), concrete_args
+            )
+
+            parameter_proxy_cache: Dict[
+                str, Proxy
+            ] = {}  # Reduce number of get_attr calls
+
+            # Method dispatch on parameters is not recorded unless it's directly used.
+            # Thus, we need to insert a proxy when __getattr__ requests a parameter.
+            @functools.wraps(_orig_module_getattr)
+            def module_getattr_wrapper(mod, attr):
+                attr_val = _orig_module_getattr(mod, attr)
+                return self.getattr(attr, attr_val, parameter_proxy_cache)
+
+            @functools.wraps(_orig_module_call)
+            def module_call_wrapper(mod, *args, **kwargs):
+                def forward(*args, **kwargs):
+                    return _orig_module_call(mod, *args, **kwargs)
+
+                _autowrap_check(
+                    patcher,
+                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
+                    self._autowrap_function_ids,
+                )
+                return self.call_module(mod, forward, args, kwargs)
+
+            with _Patcher() as patcher:
+                # allow duplicate patches to support the case of nested calls
+                patcher.patch_method(
+                    torch.nn.Module,
+                    "__getattr__",
+                    module_getattr_wrapper,
+                    deduplicate=False,
+                )
+                patcher.patch_method(
+                    torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
+                )
+                _patch_wrapped_functions(patcher)
+                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
+                for module in self._autowrap_search:
+                    _autowrap_check(
+                        patcher, module.__dict__, self._autowrap_function_ids
+                    )
+                self.create_node(
+                    "output",
+                    "output",
+                    (self.create_arg(fn(*args)),),
+                    {},
+                    type_expr=fn.__annotations__.get("return", None),
+                )
+
+            self.submodule_paths = None
+        finally:
+            _is_fx_tracing_flag = old_is_fx_tracing_flag
+        return self.graph
+
+    def __deepcopy__(self, memo):
+        # _autowrap_search contains modules, which cannot be deepcopied.
+        new_tracer = Tracer.__new__(Tracer)
+
+        for k, v in self.__dict__.items():
+            if k in {'_autowrap_search'}:
+                new_obj = copy.copy(v)
+            else:
+                new_obj = copy.deepcopy(v, memo)
+
+            new_tracer.__dict__[k] = new_obj
+
+        return new_tracer
+
+    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
+        if concrete_args is not None and name in concrete_args:
+            cnt = 0
+
+            def replace_ph(x):
+                nonlocal cnt
+                cnt += 1
+                param = sig.parameters[name]
+                default = (
+                    ()
+                    if param.default is inspect.Parameter.empty
+                    else (param.default,)
+                )
+                out = self.create_proxy(
+                    "placeholder", f"{name}_{str(cnt)}", default, {}
+                )
+                if isinstance(x, PHBase):
+                    if x != PH:
+                        # Transfer attrs in the case where you're using a placeholder other
+                        # than the singleton PH (PH has no attributes to transfer).
+                        # Proxies were created out of the placeholders.
+                        # Transfer any metadata (put on the placeholders in the form of
+                        # attributes set by the user) from the placeholder to the
+                        # underlying nodes (the proxy is unwrapped by the user, but
+                        # the metadata should hold).
+                        _transfer_attrs(fr=x, to=out.node)
+
+                    return out
+                # Union[int, bool] == bool in Python <= 3.6
+                if (
+                    type(x) == bool
+                    or type(x) in base_types
+                    and type(x) != torch.Tensor
+                ):
+                    torch._assert(
+                        out == x,
+                        f"{name} has been specialized to have value {x} but got another value",
+                    )
+                elif x is None:
+                    args = (
+                        out,
+                        f"{name} has been specialized to have value None but got another value",
+                    )
+                    self.create_proxy("call_function", _assert_is_none, args, {})
+                else:
+                    warnings.warn(
+                        f"Was not able to add assertion to guarantee correct input {name} to "
+                        f"specialized function. It is up to the user to make sure that your inputs match the "
+                        f"inputs you specialized the function with."
+                    )
+
+                return x
+
+            return pytree.tree_map(replace_ph, concrete_args[name])
+        if name[0] == "*":
+            default = ()
+        else:
+            param = sig.parameters[name]
+            default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
+        return self.create_proxy(
+            "placeholder",
+            name,
+            default,
+            {},
+            type_expr=fn_for_analysis.__annotations__.get(name, None)
+        )
+
+
+# Dictionary of (id(globals dict), function name) => globals_dict to patch for
+# the purposes of the wrap() API.
+# We key by the globals dict id and function name to ensure we're wrapping a given
+# function only once.
+_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}
+
+# List of methods on classes to wrap (class type, function name)
+# this currently only works for Tensor.* methods that aren't traced properly
+_wrapped_methods_to_patch: List[Tuple[type, str]] = []
+
+if os.environ.get("FX_PATCH_GETITEM") == "1":
+    # This change is needed to trace models like PositionalEmbedding from BERT:
+    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
+    # but causes issues in quantization documented here:
+    # https://github.com/pytorch/pytorch/issues/50710
+    # once that is fixed we can make this the default behavior.
+    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
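+
+# Illustrative note: the environment variable must be set before this module is
+# imported for the patch above to take effect, e.g. (the script name is hypothetical):
+#
+#     FX_PATCH_GETITEM=1 python my_script.py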
+
+
+def _find_proxy(*objects_to_search):
+    """
+    Recursively search a data structure for a Proxy() and return it,
+    return None if not found.
+    """
+    proxy = None
+
+    def find_proxy(x):
+        nonlocal proxy
+        if isinstance(x, Proxy):
+            proxy = x
+
+    map_aggregate(objects_to_search, find_proxy)
+    return proxy
+
+
+def _create_wrapped_func(orig_fn):
+    @functools.wraps(orig_fn)
+    def wrapped(*args, **kwargs):
+        """
+        Given a closed-over ``orig_fn`` to invoke, search the args and kwargs for
+        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
+        call to this leaf function directly. Otherwise, just return the results of
+        this function call, as this function is not being traced.
+        """
+        proxy = _find_proxy(args, kwargs)
+        if proxy is not None:
+            return_proxy = proxy.tracer.create_proxy(
+                "call_function", orig_fn, args, kwargs
+            )
+            return_proxy.node.meta["is_wrapped"] = True
+            return return_proxy
+        return orig_fn(*args, **kwargs)
+
+    return wrapped
+
+
+def _create_wrapped_method(cls, name):
+    orig_fn = getattr(cls, name)
+
+    @functools.wraps(orig_fn)
+    def wrapped(*args, **kwargs):
+        """
+        Search the args and kwargs for a Proxy object. If there is one,
+        emit a ``call_method`` node to preserve the call to this method
+        directly. Otherwise, just return the results of this function
+        call, as this function is not being traced.
+        """
+        proxy = _find_proxy(args, kwargs)
+        if proxy is not None:
+            return proxy.tracer.create_proxy("call_method", name, args, kwargs)
+        return orig_fn(*args, **kwargs)
+
+    return wrapped
+
+
+class _PatchedFn(NamedTuple):
+    frame_dict: Any
+    fn_name: str
+    orig_fn: Any
+
+    def revert(self):
+        raise NotImplementedError()
+
+
+class _PatchedFnSetItem(_PatchedFn):
+    def revert(self):
+        self.frame_dict[self.fn_name] = self.orig_fn
+
+
+class _PatchedFnDel(_PatchedFn):
+    def revert(self):
+        del self.frame_dict[self.fn_name]
+
+
+class _PatchedFnSetAttr(_PatchedFn):
+    def revert(self):
+        setattr(self.frame_dict, self.fn_name, self.orig_fn)
+
+
+class _Patcher:
+    def __init__(self):
+        super().__init__()
+        self.patches_made: List[_PatchedFn] = []
+        self.visited: Set[int] = set()
+
+    def patch(
+        self,
+        frame_dict: Dict[str, Any],
+        name: str,
+        new_fn: Callable,
+        deduplicate: bool = True,
+    ):
+        """
+        Replace frame_dict[name] with new_fn until we exit the context manager.
+        """
+        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
+        if name not in frame_dict and hasattr(builtins, name):
+            self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
+        elif getattr(frame_dict[name], "__fx_already_patched", False):
+            return  # already patched, no need to do it again
+        else:
+            self.patches_made.append(
+                _PatchedFnSetItem(frame_dict, name, frame_dict[name])
+            )
+        frame_dict[name] = new_fn
+
+    def patch_method(
+        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
+    ):
+        """
+        Replace cls.name with new_fn until we exit the context manager.
+        """
+        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
+        orig_fn = getattr(cls, name)
+        if getattr(orig_fn, "__fx_already_patched", False):
+            return  # already patched, no need to do it again
+        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
+        setattr(cls, name, new_fn)
+
+    def visit_once(self, thing: Any):
+        """Return True on the first call to with thing, otherwise false"""
+        idx = id(thing)
+        if idx in self.visited:
+            return False
+        self.visited.add(idx)
+        return True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Undo all the changes made via self.patch() and self.patch_method()
+        """
+        while self.patches_made:
+            # unpatch in reverse order to handle duplicates correctly
+            self.patches_made.pop().revert()
+        self.visited.clear()
+
+
+def _patch_wrapped_functions(patcher: _Patcher):
+    """
+    Go through ``_wrapped_fns_to_patch`` and, for each frame object, wrap
+    the listed global functions in the ``_create_wrapped_func`` wrapper.
+    """
+    for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
+        if name not in frame_dict and hasattr(builtins, name):
+            orig_fn = getattr(builtins, name)
+        else:
+            orig_fn = frame_dict[name]
+        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
+
+    for cls, name in _wrapped_methods_to_patch:
+        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
+
+
+def _autowrap_check(
+    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
+):
+    """
+    Some methods, like `math.sqrt`, are common enough that we want to automatically wrap them as we see them.
+    This method searches a scope for them and patches them if found.
+    """
+    if patcher.visit_once(frame_dict):
+        for name, value in frame_dict.items():
+            if (
+                not name.startswith("_")
+                and callable(value)
+                and id(value) in function_ids
+            ):
+                patcher.patch(frame_dict, name, _create_wrapped_func(value))
+
+
+@compatibility(is_backward_compatible=True)
+def wrap(fn_or_name: Union[str, Callable]):
+    """
+    This function can be called at module-level scope to register fn_or_name as a "leaf function".
+    A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
+    traced through::
+
+        # foo/bar/baz.py
+        def my_custom_function(x, y):
+            return x * x + y * y
+
+        torch.fx.wrap('my_custom_function')
+
+        def fn_to_be_traced(x, y):
+            # When symbolic tracing, the below call to my_custom_function will be inserted into
+            # the graph rather than tracing it.
+            return my_custom_function(x, y)
+
+    This function can also equivalently be used as a decorator::
+
+        # foo/bar/baz.py
+        @torch.fx.wrap
+        def my_custom_function(x, y):
+            return x * x + y * y
+
+    A wrapped function can be thought of as a "leaf function", analogous to the concept of
+    "leaf modules"; that is, it is a function that is left as a call in the FX trace
+    rather than traced through.
+
+    Args:
+
+        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
+            graph when it's called
+    """
+    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
+        raise RuntimeError(
+            "Unsupported type for global function! Must be either a callable or "
+            "string name"
+        )
+
+    if callable(fn_or_name):
+        assert not isinstance(fn_or_name, str)  # to make mypy happy
+        fn_name = fn_or_name.__name__
+    else:
+        assert isinstance(
+            fn_or_name, str
+        ), "fn_or_name must be a global function or string name"
+        fn_name = fn_or_name
+
+    currentframe = inspect.currentframe()
+    assert currentframe is not None
+    f = currentframe.f_back
+    assert f is not None
+    if f.f_code.co_name != "<module>":
+        raise NotImplementedError("wrap must be called at the top level of a module")
+
+    # consider implementing a Callable version of this via _autowrap_function_ids / _autowrap_search
+    # semantics would be slightly different, but it would add support for `from x import wrapped_function`
+    _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
+    return fn_or_name
+
+
+@compatibility(is_backward_compatible=True)
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    concrete_args: Optional[Dict[str, Any]] = None,
+) -> GraphModule:
+    """
+    Symbolic tracing API
+
+    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
+    constructed by recording operations seen while tracing through ``root``.
+
+    ``concrete_args`` allows you to partially specialize your function, for example to remove control flow or to flatten away data structures.
+
+    For example::
+
+        def f(a, b):
+            if b == True:
+                return a
+            else:
+                return a*2
+
+    FX typically cannot trace through this due to the presence of control
+    flow. However, we can use `concrete_args` to specialize on the value of
+    `b` to trace through this::
+
+        f = fx.symbolic_trace(f, concrete_args={'b': False})
+        assert f(3, False)  == 6
+
+    Note that although you can still pass in different values of `b`, they will be ignored.
+
+    We can also use `concrete_args` to eliminate data-structure handling from
+    our function. This will use pytrees to flatten your input. To avoid
+    overspecializing, pass in `fx.PH` for values that shouldn't be
+    specialized. For example::
+
+        def f(x):
+            out = 0
+            for v in x.values():
+                out += v
+            return out
+        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
+        assert f({'a': 1, 'b': 2, 'c': 4}) == 7
+
+
+    Args:
+        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
+            into a Graph representation.
+        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized
+
+    Returns:
+        GraphModule: a Module created from the recorded operations from ``root``.
+    """
+    tracer = Tracer()
+    graph = tracer.trace(root, concrete_args)
+    name = (
+        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    )
+    return _make_graph_module(tracer.root, graph, name)
+
+
+@wrap
+def _assert_is_none(value, msg):
+    assert value is None, msg
diff --git a/MLPY/Lib/site-packages/torch/fx/annotate.py b/MLPY/Lib/site-packages/torch/fx/annotate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a771009b04405fe381674138a7762efae0b6de2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/annotate.py
@@ -0,0 +1,21 @@
+from torch.fx.proxy import Proxy
+from ._compatibility import compatibility
+
+@compatibility(is_backward_compatible=False)
+def annotate(val, type):
+    # val could be either a regular value (not tracing)
+    # or fx.Proxy (tracing)
+    if isinstance(val, Proxy):
+        if val.node.type:
+            raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
+                               f" Existing type is {val.node.type} "
+                               f"and new type is {type}. "
+                               f"This could happen if you tried to annotate a function parameter "
+                               f"value (in which case you should use the type slot "
+                               f"on the function signature) or you called "
+                               f"annotate on the same value twice")
+        else:
+            val.node.type = type
+        return val
+    else:
+        return val
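+
+# A minimal usage sketch (``MyType`` and the surrounding module are assumed):
+#
+#     def forward(self, x):
+#         y = annotate(x + 1, MyType)  # attaches MyType to the FX node when tracing
+#         return y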
diff --git a/MLPY/Lib/site-packages/torch/fx/config.py b/MLPY/Lib/site-packages/torch/fx/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2539e748df4aa8016359bd1b068baa7653fcf686
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/config.py
@@ -0,0 +1,6 @@
+# Whether to disable showing progress on compilation passes
+# We need a separate config here; otherwise we would get a circular import if the dynamo config were imported
+disable_progress = True
+
+# If True, this also shows the node names in each pass; this is great for small models but quite noisy for larger ones
+verbose_progress = False
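+
+# Illustrative usage sketch (attribute names as defined above; set from user code):
+#
+#     import torch.fx.config as fx_config
+#     fx_config.disable_progress = False  # show progress on compilation passes
+#     fx_config.verbose_progress = True   # also show node names in each pass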
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__init__.py b/MLPY/Lib/site-packages/torch/fx/experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f79cd452ce43b99533bc6993c4f95bbee4a45fd4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c329fa6b150bddb2908217850f4a3e632704705a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e81713a0958c5f7958694dbf1e5aaca8371fa6b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02fac73ef922ed8b2bb9c8196435f7e5501d20c5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/_sym_dispatch_mode.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbc0f6880c6417efb1927642afc5b37e53ce8a1a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b79bddcfb7291c15d55a7fe324620ea78a6bd425
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50773f27abbd25a65f761da1d4ce5527d291d201
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efd0b203678004b0756579c451c3428a1e951996
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c30e6c21f3f2cf89e139073c0dac7ea5ed38e11
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a2f4ddde21e82fe407adba5c04ceb70a6204ba2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d8eb1260f1009de0aff27edd3e0cac472c4a5ed
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f61bad3a01dcaf019277ac4d777ceb6d9caba446
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4df3e4f26c5de1d75431ebc5d66dc5d28daffa4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..421e5360d50ee065c1575f6663db5bb9aed3d8d3
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fafd7f16f9c25921802d1a4527f572a5e5d89b04
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52f5b718bb8b812f28681548c0af6c31e1b75f1f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11911e3aa45de838b54db3002d64ad3025d3d414
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b42322c237613e5ed5b34bc9dfa75c5965d58c5
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e82518b7083ef477e269fc512cb2ff05edf75dc4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6640d3a5a322305ca55a81ca6167cb5127f1fc28
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/symbolic_shapes.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3789b225b37c556046ed5eb9a6c9116d0c0354de
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a24a581a4404fe63a2b61a72e5f5e6a98783c20
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/_backward_state.py b/MLPY/Lib/site-packages/torch/fx/experimental/_backward_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc9705413e9c29714d6e165c4b2ab3b34796124
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/_backward_state.py
@@ -0,0 +1,27 @@
+import torch.fx
+
+
+class BackwardState:
+    """
+    BackwardState is used to pass Python hooks from the forwards pass
+    into the backwards pass in Dynamo+Compiled Autograd.
+
+    It is created by TorchDynamo and has special handling there.
+    Dynamo will pass an empty BackwardState to the forwards, then populate
+    members on it (via setattr) only after the forwards graph is finished.
+    Later on, in CompiledAutograd we will inline and add the needed guards
+    on the BackwardState.
+
+    BackwardState is identified and has special handling in AOTAutograd.
+    During AOTAutograd:
+        1) BackwardState is an input to the forwards graph
+        2) It must only be used in the backwards
+        3) It will be empty in the forwards
+        4) In the forwards we add a wrapper to save it
+        5) In the backwards it becomes an input
+        6) There can only be one per graph
+
+    BackwardState requires CompiledAutograd.
+    """
+
+    proxy: torch.fx.Proxy
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/_config.py b/MLPY/Lib/site-packages/torch/fx/experimental/_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f7e9c4cb5b41002f7e1d560ddc001032083a8c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/_config.py
@@ -0,0 +1,76 @@
+import os
+import sys
+
+from typing import Optional
+
+# [@compile_ignored: debug] Uses Z3 to validate the guard optimization transformations.
+translation_validation = (
+    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
+)
+# Timeout (in milliseconds) for z3 finding a solution.
+# [@compile_ignored: debug]
+translation_validation_timeout = int(
+    os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
+)
+# Disables bisection for translation validation.
+#
+# Translation validation bisection is enabled by default, if translation validation
+# is also enabled. This should help finding guard simplification issues. However,
+# since validation uses Z3 for bisecting, it might take a lot of time.
+#
+# Set this configuration option so as to avoid bisecting.
+# [@compile_ignored: debug]
+translation_validation_no_bisect = (
+    os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
+)
+# Checks whether replaying ShapeEnv events on a freshly constructed one yields
+# a ShapeEnv with the same state. This should be used only in testing.
+check_shape_env_recorded_events = False
+
+# TODO: Perhaps consider allowing unions for the configs below (so you can hit
+# multiple reps at the same time)
+
+# Give extended debug information if the string representation of a guard
+# matches this.  For example, set this to "Ne(s0, 10)" and whenever we issue
+# this guard, we will generate full Python and C++ backtrace
+# [@compile_ignored: debug]
+extended_debug_guard_added = os.environ.get(
+    "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
+)
+
+# Give extended debug information when a particular symbol is allocated.  For
+# example, set this to "u2" and whenever we create this symbol, we will
+# generate full Python and C++ backtrace
+# [@compile_ignored: debug]
+extended_debug_create_symbol = os.environ.get(
+    "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
+)
+
+# Give extended debug information (C++ backtrace) for all extended debug
+# settings as well as errors.  The C++ backtrace is slow and very spammy so we
+# don't include it by default even when you're requesting extended debug.
+# [@compile_ignored: debug]
+extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
+
+# [@compile_ignored: debug] Show a warning for every specialization
+print_specializations = False
+
+# wraps (un)equalities with 'Not' class after recording the correct expression
+# in the FX graph. This should incorrectly construct the divisible and replacement
+# lists, and incorrectly issue guards.
+inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
+
+# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
+validate_shape_env_version_key = False
+
+# If we produce more than this many guards on a symbol, force the symbol to
+# get specialized and bail out if this many guards mention this particular
+# symbol.  This may be slightly more aggressive than the true number of guards
+# issued (as we test if we've hit the limit on-the-fly, whereas we may
+# do further simplifications at final guard issuance time that make guards
+# irrelevant.)
+symbol_guard_limit_before_specialize: Optional[int] = None
+
+from torch.utils._config_module import install_config_module
+
+install_config_module(sys.modules[__name__])
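+
+# Illustrative example: most of the knobs above are read from the environment, so they
+# are typically set before the process starts, e.g. (the script name is hypothetical):
+#
+#     TORCHDYNAMO_TRANSLATION_VALIDATION=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 python train.py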
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/_sym_dispatch_mode.py b/MLPY/Lib/site-packages/torch/fx/experimental/_sym_dispatch_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..7922d0dbeb40584fd55ee7688a012ce31380a148
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/_sym_dispatch_mode.py
@@ -0,0 +1,58 @@
+from typing import List, Optional, Type
+
+__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
+
+SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
+
+
+# SymDispatchMode gets invoked whenever an operation is processed on
+# a PySymInt.  When this occurs, you get called at __sym_dispatch__
+# with the operation in question.  This is symmetric to TorchDispatchMode
+# but with some caveats:
+#
+#   - In TorchDispatchMode, you get the same arguments as what a user
+#     invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
+#     you get (a, b) as args to your call.  In SymDispatchMode, if
+#     you call a + b (where a and b are SymInts), you will get
+#     (a.node, b.node) as your args (these are PySymInts)
+#
+#   - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
+#     So you have to manually call Tracer/create_node to write into
+#     the graph.  See ProxySymDispatchMode for an example
+#
+class SymDispatchMode:
+    def __sym_dispatch__(self, func, types, args, kwargs):
+        raise NotImplementedError()
+
+    def __enter__(self):
+        global SYM_FUNCTION_MODE
+        old = SYM_FUNCTION_MODE
+        if hasattr(self, "inner"):
+            raise RuntimeError(
+                f"{self} has already been used as a mode. Please use a fresh version"
+            )
+        else:
+            self.inner = old
+        SYM_FUNCTION_MODE = self
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global SYM_FUNCTION_MODE
+        SYM_FUNCTION_MODE = self.inner
+
+
+def handle_sym_dispatch(func, args, kwargs):
+    global SYM_FUNCTION_MODE
+    mode = sym_function_mode()
+    assert mode
+    SYM_FUNCTION_MODE = mode.inner
+    try:
+        # TODO: properly compute types
+        types: List[Type] = []
+        return mode.__sym_dispatch__(func, types, args, kwargs)
+    finally:
+        SYM_FUNCTION_MODE = mode
+
+
+def sym_function_mode():
+    return SYM_FUNCTION_MODE
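+
+# A minimal sketch (illustrative only) of a mode that logs each dispatched operation
+# and then delegates to the underlying function:
+#
+#     class LoggingSymDispatchMode(SymDispatchMode):
+#         def __sym_dispatch__(self, func, types, args, kwargs):
+#             print("sym dispatch:", func)
+#             return func(*args, **kwargs)
+#
+#     with LoggingSymDispatchMode():
+#         ...  # symbolic-shape operations performed here are reported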
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py b/MLPY/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d086c05db1e31b979e53024d51c3024327414a9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py
@@ -0,0 +1,1078 @@
+import operator
+from collections import deque
+from typing import Dict, List, Set, NamedTuple, Tuple, Deque
+
+import torch
+from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
+from torch.fx.experimental.partitioner_utils import (
+    Partition,
+    Device,
+    PartitionerConfig,
+    get_partition_to_latency_mapping,
+    get_latency_of_partitioned_graph,
+    NodeLatency,
+    get_extra_size_of,
+    PartitionMode,
+)
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node, map_arg
+from torch.fx.passes.split_module import split_module
+
+
+class DAGNode:
+    """DAGNode class maintains useful information for a partition (submodule),
+    and its input submodules and output submodules.
+    """
+
+    def __init__(
+        self,
+        submodule_node: Node,
+        input_nodes: List[Node],
+        output_nodes: List[Node],
+        logical_device_ids: List[int],
+        size_bytes: int,
+    ) -> None:
+        self.submodule_node: Node = submodule_node
+        self.input_nodes: List[Node] = input_nodes
+        self.output_nodes: List[Node] = output_nodes
+        self.logical_device_ids: List[int] = logical_device_ids
+        self.size_bytes = size_bytes
+
+    def __str__(self) -> str:
+        return str(self.submodule_node)
+
+
+class DAG:
+    """DAG class contains all the DAG nodes"""
+
+    def __init__(self) -> None:
+        self.nodes: List[DAGNode] = []
+
+    def create_node(
+        self,
+        submodule_node: Node,
+        input_nodes: List[Node],
+        output_nodes: List[Node],
+        logical_devices: List[int],
+        size_bytes: int,
+    ) -> None:
+        node = DAGNode(
+            submodule_node, input_nodes, output_nodes, logical_devices, size_bytes
+        )
+        self.nodes.append(node)
+
+
+class PartitionResult(NamedTuple):
+    """NameTuple used for returning DAG and a new fx module"""
+
+    dag: DAG
+    module_with_submodules: GraphModule
+
+
+"""Followings are some helper functions for partition manipulation"""
+
+
+def reset_partition_device(partitions):
+    for partition in partitions:
+        partition.logical_device_ids = []
+
+
+def combine_two_partitions(
+    partition_0: Partition, partition_1: Partition, partitions: List[Partition]
+) -> None:
+    """Given a list of partitions and its two partitions,
+    combine these two partitions into a new one appending to the partitions
+    and remove the previous two partitions from the list of partitions
+    """
+    partition = Partition(len(partitions))
+    partition.nodes = partition_0.nodes.union(partition_1.nodes)
+    partition.recalculate_mem_size()
+    partitions.append(partition)
+    partitions.remove(partition_0)
+    partitions.remove(partition_1)
+    reorganize_partitions(partitions)
+    return
+
+
+def set_parents_and_children(partitions: List[Partition]) -> None:
+    """Given a list of partitions, mark parents and children for each partition"""
+    # Go through all nodes in a partition.
+    # If a node's user is in other partition,
+    # then the other partition is this partition's children.
+    # This partition is the other partition's parent
+    for partition in partitions:
+        partition.children = set()
+        partition.parents = set()
+    for partition in partitions:
+        for node in partition.nodes:
+            # For each node in the current partition, find its users
+            users = node.users
+            for n in users:
+                # Find which partition the user node belongs to.
+                # Note that if the node itself also belongs to that partition,
+                # that partition is not a child of the current partition
+                for p in partitions:
+                    if p != partition and n in p.nodes and node not in p.nodes:
+                        partition.children.add(p)
+                        p.parents.add(partition)
+    return
+
+
+def reorganize_partitions(partitions: List[Partition]) -> None:
+    """Given a list of partitions, reorganize partition id,
+    its parents and its children for each partition
+    """
+    # Rearrange partition ids
+    for i, partition in enumerate(partitions):
+        partition.partition_id = i
+    set_parents_and_children(partitions)
+    return
+
+
+def get_bfs_level_partition(partitions: List[Partition]) -> None:
+    """Given a list of partitions,
+    mark the bfs level for each partition
+    """
+    current_level: Set[Partition] = set()
+    visited: Set[Partition] = set()
+    for partition in partitions:
+        # If a partition has no parent, it should be in root level
+        if len(partition.parents) == 0:
+            current_level.add(partition)
+    next_level: Set[Partition] = set()
+    level = 0
+    # bfs
+    while current_level:
+        partition = current_level.pop()
+        partition.bfs_level = level
+        visited.add(partition)
+        children = partition.children
+        for child in children:
+            if child not in next_level:
+                next_level.add(child)
+        if not current_level:
+            current_level = next_level.copy()
+            next_level = set()
+            level += 1
+    return
+
+
+def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
+    """Given a list of partitions,return node to partition mapping"""
+    node_to_partition: Dict[Node, int] = {}
+    for partition in partitions:
+        for node in partition.nodes:
+            node_to_partition[node] = partition.partition_id
+    return node_to_partition
+
+
+def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
+    """Get a mapping from device logical ID to Device object."""
+    logical_id_to_device: Dict[int, Device] = {}
+    for d in devices:
+        logical_id_to_device[d.logical_id] = d
+    return logical_id_to_device
+
+
+def get_device_partition_stats(
+    partitions: List[Partition], devices: List[Device]
+) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
+    """Given a list of partitions and a list of devices, returns:
+    1. A mapping from device to partitions on it;
+    2. A mapping from device to its remaining memory size;
+    3. A list of partitions that do not have a device.
+    """
+    # logical id to device
+    logical_id_to_device = get_logical_id_to_device(devices)
+    # Track partitions on device
+    device_to_partitions: Dict[Device, List[Partition]] = {}
+    # Track device's left mem size
+    device_to_left_mem_bytes: Dict[Device, int] = {}
+    for d in devices:
+        device_to_partitions[d] = []
+        device_to_left_mem_bytes[d] = d.available_mem_bytes
+
+    # Deal with the partitions that already have a device
+    # and also collect all partitions without a device (no_device_partitions)
+    no_device_partitions = []
+    for partition in partitions:
+        if partition.logical_device_ids != []:
+            for logical_id in partition.logical_device_ids:
+                device = logical_id_to_device[logical_id]
+                device_to_partitions[device].append(partition)
+                device_to_left_mem_bytes[device] -= partition.used_mem_bytes
+        else:
+            no_device_partitions.append(partition)
+
+    return (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    )
+
+
+def get_device_to_partitions_mapping(
+    partitions: List[Partition], devices: List[Device]
+):
+    """Given a list of partitions and a list of devices,
+    map each partition into a device.
+    """
+
+    def calculate_extra_mem_bytes_needed_for(
+        partition: Partition, partitions: List[Partition]
+    ):
+        all_nodes: Set[Node] = set()
+        for p in partitions:
+            all_nodes = all_nodes.union(p.nodes)
+        if len(all_nodes) == 0:
+            return partition.used_mem_bytes
+        all_nodes = all_nodes.union(partition.nodes)
+        extra_size_needed = 0
+        for node in partition.nodes:
+            extra_size_needed += get_extra_size_of(node, all_nodes)
+        return extra_size_needed
+
+    def find_device_for(partition: Partition):
+        """Given a partition, find a logical device for the partition
+        The algorithm is to put the partition on the device
+        that has just enough mem left for that partition.
+        device_to_left_mem_bytes is a dictionary between device and its left mem size
+        sorted by its left mem size
+        """
+        for d in device_to_left_mem_bytes:
+            extra_size_needed = calculate_extra_mem_bytes_needed_for(
+                partition, device_to_partitions[d]
+            )
+            if extra_size_needed < device_to_left_mem_bytes[d]:
+                device_to_partitions[d].append(partition)
+                partition.logical_device_ids.append(d.logical_id)
+                device_to_left_mem_bytes[d] -= extra_size_needed
+                return True
+        return False
+
+    (
+        device_to_partitions,
+        device_to_left_mem_bytes,
+        no_device_partitions,
+    ) = get_device_partition_stats(partitions, devices)
+
+    # Find devices for all the partitions without a device
+    found_device = True
+    for partition in no_device_partitions:
+        device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]))
+        found_device = find_device_for(partition)
+        if not found_device:
+            break
+    return found_device
+
+
+def check_dependency(partition):
+    """Given a partition,check if there is a circular dependency on
+    this partition using bfs
+    """
+    visited: Set[Partition] = {partition}
+    queue: Deque[Partition] = deque([partition])
+    while queue:
+        p = queue.popleft()
+        for child in p.children:
+            if child == partition:
+                return True
+            else:
+                if child not in visited:
+                    visited.add(child)
+                    queue.append(child)
+    return False
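+
+# A minimal illustrative sketch (not part of the upstream module): the same BFS
+# cycle check on a plain adjacency mapping with hypothetical integer ids.
+# check_dependency above performs the equivalent walk over Partition.children.
+#
+#   from collections import deque
+#
+#   def has_cycle_back_to(start: int, children: Dict[int, List[int]]) -> bool:
+#       visited, queue = {start}, deque([start])
+#       while queue:
+#           for child in children.get(queue.popleft(), []):
+#               if child == start:
+#                   return True
+#               if child not in visited:
+#                   visited.add(child)
+#                   queue.append(child)
+#       return False
+#
+#   # has_cycle_back_to(0, {0: [1], 1: [2], 2: [0]}) -> True
+#   # has_cycle_back_to(0, {0: [1], 1: [2], 2: []})  -> False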
+
+
+class Partitioner:
+    """A fx module may not fit into one device.
+    Partitioner class helps partition one fx module into submodules (partitions),
+    so that the submodules can be executed crossing different accelerators.
+    The main function of this class is self.partition_graph.
+    It partitions the fx module based on the scheme specified in partition_config
+    A DAG structure is returned
+    along with a new fx module with submodule nodes.
+    """
+
+    def __init__(self) -> None:
+        self.partitions: List[Partition] = []
+        self.node_to_partition: Dict[Node, int] = {}
+        self.devices: List[Device] = []
+
+    def partition_graph(
+        self,
+        fx_module: GraphModule,
+        torch_module: torch.nn.Module,
+        partitioner_config: PartitionerConfig,
+    ) -> PartitionResult:
+        """Given the fx module, torch module and partitioner_config,
+        find the partitions, do the partitions,
+        and then return a DAG and a new fx module with submodule nodes (partitions)
+        """
+        self.graph_module = fx_module
+        self.torch_module = torch_module
+        self.devices = partitioner_config.devices
+        if len(self.devices) == 0:
+            raise RuntimeError("No devices")
+        # Tag the size in bytes to all nodes in the graph_module.
+        get_size_of_all_nodes(self.graph_module)
+        # Check if there are op nodes in the fx module
+        nodes = self.graph_module.graph.nodes
+        if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes):
+            raise RuntimeError("No Partition since no operations in the module")
+        # Calculate total size of the fx module
+        total_size_of_graph = 0
+        for node in nodes:
+            if node.op == "output":
+                break
+            total_size_of_graph += node.size_bytes.total_size
+        # Find the device with the max mem size
+        device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes)
+        # AOT based partition
+        if partitioner_config.mode == PartitionMode.aot_based:
+            self.aot_based_partition(
+                partitioner_config.node_to_partition_mapping,
+                partitioner_config.partition_to_logical_device_mapping,
+            )
+        # Single partition if the whole module can be fit into one device
+        elif total_size_of_graph <= device_with_max_mem.available_mem_bytes:
+            self.find_single_partition(
+                total_size_of_graph, logical_device_id=device_with_max_mem.logical_id
+            )
+        elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]):
+            raise RuntimeError("Devices have no enough memory for the module")
+        else:
+            # Sparse nn based partition
+            if partitioner_config.mode == PartitionMode.sparse_nn:
+                available_mem_bytes = self.devices[0].available_mem_bytes
+                if not all(
+                    device.available_mem_bytes == available_mem_bytes
+                    for device in self.devices
+                ):
+                    raise RuntimeError("All devices must have same memory size!")
+                # sparse_nn_partition only support same memory size
+                # TODO: add different size support for sparse_nn_partition
+                self.sparse_nn_partition(available_mem_bytes)
+            # Cost aware partition
+            elif partitioner_config.mode == PartitionMode.cost_aware:
+                self.cost_aware_partition(
+                    partitioner_config.transfer_rate_bytes_per_sec,
+                    partitioner_config.node_to_latency_mapping,
+                )
+            # KL based partition
+            elif partitioner_config.mode == PartitionMode.kl_based:
+                self.kl_based_partition(
+                    partitioner_config.transfer_rate_bytes_per_sec,
+                    partitioner_config.node_to_latency_mapping,
+                )
+            else:
+                self.size_based_partition()
+
+        # Saturate host if possible.
+        if partitioner_config.saturate_host:
+            self.saturate_host()
+
+        # Partition the graph module based on the partition assignment.
+        module_with_submodules = self.do_partition()
+
+        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
+        # and how partitions are connected.
+        dag = self.dump_dag(module_with_submodules)
+        ret = PartitionResult(dag, module_with_submodules)
+        return ret
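+
+    # A minimal usage sketch in comments (the traced module and device sizes are
+    # hypothetical; result attribute names are assumed from PartitionResult above):
+    #
+    #   traced = torch.fx.symbolic_trace(my_module)   # my_module: some torch.nn.Module
+    #   devices = [Device("dev_0", 125_000_000, 0),   # (name, available_mem_bytes, logical_id)
+    #              Device("dev_1", 125_000_000, 1)]
+    #   config = PartitionerConfig(devices)           # default mode falls through to size based partition
+    #   result = Partitioner().partition_graph(traced, my_module, config)
+    #   dag, new_module = result.dag, result.module_with_submodules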
+
+    def find_single_partition(
+        self, total_size_of_graph, logical_device_id: int = 0
+    ) -> None:
+        """Fit the whole fx module into one device"""
+        partition_0 = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op == "output":
+                # Skip the output node, but there can
+                # be nodes after the output in certain cases.
+                continue
+            partition_0.nodes.add(node)
+        partition_0.used_mem_bytes = total_size_of_graph
+        partition_0.logical_device_ids = [logical_device_id]
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
+
+    def size_based_partition(self) -> None:
+        """This method is to partition the fx module based on memory size.
+        It uses greedy approach. The result may not be the best.
+        The basic idea is:
+        Step 1:
+        Find a device which has enough memory to fit the current node, create a empty partition
+        with the size of that device.
+        Then keep adding the following nodes into the partition until the partition is full.
+        Step 2:
+        Repeat Step 1 until no device left
+        Step 3:
+        If some nodes are left, create a partition for each left node (single node partition).
+        and then try to map those partitions into logical devices with enough mem left.
+        """
+
+        def find_device_based_on_size(node) -> Device:
+            """Given a node, this function is to find a logical device
+            that could fit the node.
+            """
+            mem_size_needed = get_extra_size_of(node, set())
+            device = Device("", -1, -1)
+            for d in self.devices:
+                if (
+                    d not in occupied_devices
+                    and d.available_mem_bytes >= mem_size_needed
+                ):
+                    device = d
+                    break
+            if device.available_mem_bytes < 0:
+                raise RuntimeError(str(node) + " is too large to fit into any device")
+            occupied_devices.append(device)
+            return device
+
+        # Track partition and its left mem size
+        partition_to_left_mem_bytes: Dict[Partition, int] = {}
+        # Track all the devices that have been used
+        occupied_devices: List[Device] = []
+        partition = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op in {"call_module", "call_method", "call_function"}:
+                # Check if there are devices left
+                if len(self.partitions) <= len(self.devices):
+                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                    # Check if the current partition is the very first partition
+                    if partition.used_mem_bytes == 0:
+                        # Find a device that can fit the first node
+                        device = find_device_based_on_size(node)
+                        occupied_devices.append(device)
+                        # Track the partition's remaining memory size
+                        partition_to_left_mem_bytes[
+                            partition
+                        ] = device.available_mem_bytes
+                        # Assign the device's logical id to the current partition
+                        partition.logical_device_ids.append(device.logical_id)
+                    else:
+                        # The current partition is not the first partition
+                        # Check if the current node can fit into current partition
+                        if (
+                            partition_to_left_mem_bytes[partition]
+                            < total_size_of_input_nodes
+                        ):
+                            # Check if no device is left
+                            if len(self.partitions) == len(self.devices):
+                                # No device is left
+                                # Put the previous partitions into a list (non_single_node_partitions)
+                                non_single_node_partitions = self.partitions[:]
+                                # Create the first single node partition for the current node
+                                self.create_single_node_partition(node)
+                                continue
+                            # Some devices are still left
+                            # Create a new partition with a mem size that is enough for the current node
+                            device = find_device_based_on_size(node)
+                            partition = self.create_partition()
+                            total_size_of_input_nodes = get_extra_size_of(
+                                node, partition.nodes
+                            )
+                            partition_to_left_mem_bytes[
+                                partition
+                            ] = device.available_mem_bytes
+                            partition.logical_device_ids.append(device.logical_id)
+                    partition.add_node(node)
+                    partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
+                # Create single node partitions if no device is left
+                else:
+                    self.create_single_node_partition(node)
+        reorganize_partitions(self.partitions)
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        # Mapping all partitions into device
+        found_partition_to_device_mapping = get_device_to_partitions_mapping(
+            self.partitions, self.devices
+        )
+        if not found_partition_to_device_mapping:
+            raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
+        return
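+
+    # An illustrative walkthrough of the greedy scheme above, with hypothetical
+    # numbers: given devices of 100 and 60 bytes and op nodes of sizes 40, 50, 30
+    # and 20, the first partition lands on the 100-byte device and absorbs 40 + 50;
+    # the 30-byte node no longer fits, so a second partition is opened on the
+    # 60-byte device for 30 + 20. Any nodes left once all devices are used become
+    # single node partitions and are mapped afterwards by
+    # get_device_to_partitions_mapping.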
+
+    def saturate_host(self) -> None:
+        """Saturate host by assigning replicates to unused devices with enough memory.
+        It uses a greedy approach to find a next available set of devices to place all split
+        partitions: For each used device, it searches for an idle device with minimal memory
+        size that can hold all the partition located on that device; If the search is successful
+        for all used devices, it then assigns the new devices' logical ID to the corresponding
+        partition.
+        """
+        (
+            device_to_partitions,
+            device_to_left_mem_bytes,
+            no_device_partitions,
+        ) = get_device_partition_stats(self.partitions, self.devices)
+
+        assert (
+            len(no_device_partitions) == 0
+        ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
+
+        # Devices that hold partitions
+        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
+        # Track replicates of the assigned devices
+        replicated_device_to_used_device: Dict[Device, Device] = {}
+
+        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
+            self.devices
+        ):
+            # Success flag for this round
+            success = True
+            # Devices that have not been assigned
+            idle_devices = [
+                d
+                for d in self.devices
+                if d not in used_devices and d not in replicated_device_to_used_device
+            ]
+            # Temporary mapping from replicated device to original device
+            temp_replicate_mapping = {}
+
+            # Find a new device to replicate all partitions on a used device
+            for used_device in used_devices:
+                # Idle devices that have enough memory
+                available_devices = [
+                    d
+                    for d in idle_devices
+                    if d.available_mem_bytes
+                    >= used_device.available_mem_bytes
+                    - device_to_left_mem_bytes[used_device]
+                ]
+                if len(available_devices) == 0:
+                    success = False
+                    break
+                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
+                idle_devices.remove(new_device)
+                temp_replicate_mapping[new_device] = used_device
+
+            if not success:
+                break
+            replicated_device_to_used_device.update(temp_replicate_mapping)
+
+        # Update logical device IDs assigned to the partitions
+        for (
+            replicate_device,
+            original_device,
+        ) in replicated_device_to_used_device.items():
+            logical_id = replicate_device.logical_id
+            for partition in device_to_partitions[original_device]:
+                partition.logical_device_ids.append(logical_id)
+        for p in self.partitions:
+            print(p.logical_device_ids)
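+
+    # Illustrative arithmetic for the replication loop above (assumed counts): with
+    # 2 used devices out of 6 in total, the condition 2 * 2 + 0 <= 6 admits a first
+    # replication round; after it, 2 * 2 + 2 <= 6 still holds, so each used device
+    # can be replicated once more before the loop stops.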
+
+    def do_partition(self) -> GraphModule:
+        """Return a new fx module with submodule nodes (partitions)."""
+        module_with_submodules = split_module(
+            self.graph_module,
+            self.torch_module,
+            lambda node: self.node_to_partition[node],
+        )
+        return module_with_submodules
+
+    def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
+        """Return the dag structure and the new fx module with submodules."""
+        dag = DAG()
+        for node in module_with_submodules.graph.nodes:
+            if node.op == "output":
+                break
+            if node.op in {"placeholder", "get_attr"}:
+                continue
+            if node.target == operator.__getitem__:
+                continue
+            input_nodes: Dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # When a node has two or more output nodes,
+            # it outputs its result to 'getitem' nodes.
+            # Those 'getitem' nodes are the output node for this node.
+            # Otherwise, the output node is this node itself.
+            if len(node.users) > 1:
+                output_nodes = list(node.users)
+            else:
+                output_nodes = [node]
+            partition_id = int(node.name.rsplit("_", 1)[-1])
+            device_ids = self.partitions[partition_id].logical_device_ids
+            size_bytes = self.partitions[partition_id].used_mem_bytes
+            dag.create_node(
+                node, list(input_nodes), output_nodes, device_ids, size_bytes
+            )
+        return dag
+
+    def create_partition(self) -> Partition:
+        """Create a partition and append it to self.partitions."""
+        partition_id = len(self.partitions)
+        partition = Partition(partition_id)
+        self.partitions.append(partition)
+        return partition
+
+    def create_single_node_partition(self, node):
+        """Create a partition for a single node"""
+        partition = self.create_partition()
+        partition.add_node(node)
+        return
+
+    def sparse_nn_partition(self, available_mem_bytes: int) -> None:
+        """This method partition a sparse nn module.
+        It is size based partition but different from size_based_partition,
+        it only works when all the devices have same memory size (available_mem_bytes).
+        In the future, devices with different mem sizes will be supported like size_based_partition.
+        It first traverse all the nodes and do the partitions based on the same memory size.
+        If the current partition has no enough memory left for a new op node
+        (call_module, call_method, call_function), a new partition is created.
+        When crossing the boundary between non-embedding nodes and embedding nodes,
+        a new partition is created regardlessly.
+        For example, if the current node is a non-embedding node but the next node is an
+        embedding node, a new partition is created for the next node.
+        After the partition, the partitions are combined as much as possible.
+        The rule is that a non-embedding partition only
+        combines with another non-embedding one.
+        So as the embedding partitions.
+        """
+
+        def combine_partitions_based_on_size(
+            partitions: List[Partition], available_mem_bytes: int
+        ) -> None:
+            """Combining small partitions together to keep as less partitions as possible.
+            Here is an example of the algorithm to do this:
+            Assume some partitions, we first sort them based on partition used memory size.
+            [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)]
+            The available memory is 10.
+            step 1: self.find_partition_to_combine_based_on_size()
+            First, mark bfs level for each partition
+            Second, look the smallest partition, partition_4: 10 - 1 = 9
+            It means any partition has a used memory equal or less than 9 could combine this partition
+            We go from the largest and selection partition_0.
+            Check the bfs level for two partitions, if the level difference is less than 2,
+            it can be combined.
+            step 2: repeat step 1 until no partitions can be combined
+            """
+            find_combination = True
+            while find_combination:
+                # Sort partitions based on memory size
+                sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes)
+                # Mark bfs level
+                get_bfs_level_partition(self.partitions)
+                find_combination, partitions = find_partition_to_combine_based_on_size(
+                    sorted_partitions, available_mem_bytes, partitions
+                )
+            return
+
+        def calculate_mem_bytes_needed(p1, p2):
+            """Given two partitions, calculate how many mem bytes
+            are needed if two partitions are combined
+            """
+            nodes = p1.nodes.union(p2.nodes)
+            mem_bytes_needed = 0
+            for node in nodes:
+                mem_bytes_needed += get_extra_size_of(node, nodes)
+            return mem_bytes_needed
+
+        def find_partition_to_combine_based_on_size(
+            sorted_partitions: List[Partition],
+            available_mem_bytes: int,
+            partitions: List[Partition],
+        ) -> Tuple[bool, List[Partition]]:
+            """step 1 in combine_partition_based_on_size()"""
+            find_combination = False
+            smallest_partition = sorted_partitions.pop(0)
+            for p in sorted_partitions[::-1]:
+                if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
+                    # Calculate how many bytes needed if combined
+                    mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition)
+                    if mem_bytes_needed <= available_mem_bytes:
+                        combine_two_partitions(p, smallest_partition, self.partitions)
+                        partitions.remove(smallest_partition)
+                        partitions.remove(p)
+                        partitions.append(self.partitions[-1])
+                        find_combination = True
+                        break
+            return find_combination, partitions
+
+        def reset_partition_in_sparse_nn(partition, new_partition=True):
+            """If crossing the boundary between non-embedding nodes and
+            embedding nodes, create a new partition
+            """
+            if in_embedding_region:
+                embedding_partitions.append(partition)
+            else:
+                non_embedding_partitions.append(partition)
+            if new_partition:
+                partition = self.create_partition()
+                partition.left_mem_bytes = available_mem_bytes
+                return partition
+            return None
+
+        def is_embedding_node(node: Node) -> bool:
+            """Check if a node is an embedding node"""
+            if node.op == "call_module":
+                submodule = self.graph_module
+                for atom in str(node.target).split("."):
+                    if not hasattr(submodule, atom):
+                        raise RuntimeError(
+                            f"Module {submodule} has no attribute {atom}"
+                        )
+                    submodule = getattr(submodule, atom)
+                    if "Embedding" in str(submodule):
+                        return True
+            return False
+
+        # Track embedding partitions and non-embedding partitions separately
+        embedding_partitions: List[Partition] = []
+        non_embedding_partitions: List[Partition] = []
+        # A Flag to check the boundary
+        in_embedding_region: bool = False
+        partition = self.create_partition()
+        for node in self.graph_module.graph.nodes:
+            if node.op in {"call_module", "call_method", "call_function"}:
+                # Check if crossing the boundary between embedding nodes and non embedding nodes
+                if is_embedding_node(node) != in_embedding_region:
+                    # Crossing the boundary
+                    # Check if the current partition is an empty partition
+                    if partition.used_mem_bytes != 0:
+                        # The current partition isn't an empty partition. Create a new one.
+                        partition = reset_partition_in_sparse_nn(partition)
+                    in_embedding_region = not in_embedding_region
+                total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                if (
+                    total_size_of_input_nodes + partition.used_mem_bytes
+                    > available_mem_bytes
+                ):
+                    partition = reset_partition_in_sparse_nn(partition)
+                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
+                    if total_size_of_input_nodes > available_mem_bytes:
+                        raise RuntimeError(
+                            str(node.target) + " is too large to fit into a device"
+                        )
+                partition.add_node(node)
+        reset_partition_in_sparse_nn(partition, new_partition=False)
+        # Set parents and children for partitions
+        set_parents_and_children(self.partitions)
+        # Combining non-embedding partitions
+        combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes)
+        # Combining embedding partitions
+        combine_partitions_based_on_size(embedding_partitions, available_mem_bytes)
+        total_size_of_non_embedding_partitions = 0
+        for partition in non_embedding_partitions:
+            total_size_of_non_embedding_partitions += partition.used_mem_bytes
+        # Check if devices are enough for all partitions
+        if len(embedding_partitions) > len(self.devices):
+            msg = (
+                "Need "
+                + str(len(embedding_partitions))
+                + " devices, but only "
+                + str(len(self.devices))
+                + " provided"
+            )
+            raise RuntimeError(msg)
+        occupied_devices = []
+        for i, partition in enumerate(embedding_partitions):
+            # Check if all non-embedding partitions can fit into embedding partition devices
+            if (
+                total_size_of_non_embedding_partitions + partition.used_mem_bytes
+                > available_mem_bytes
+            ):
+                raise RuntimeError(
+                    "partition_"
+                    + str(partition.partition_id)
+                    + "(embedding partition) and non embedding partitions can not fit into one device"
+                )
+            else:
+                # Add logical device to the partition
+                partition.logical_device_ids = [self.devices[i].logical_id]
+                occupied_devices.append(self.devices[i].logical_id)
+        # Add logical devices to the non_embedding_partitions
+        for partition in non_embedding_partitions:
+            partition.logical_device_ids = occupied_devices
+        # Get the node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
+
+    def cost_aware_partition(
+        self,
+        transfer_rate_bytes_per_sec: float,
+        node_to_latency_mapping: Dict[Node, NodeLatency],
+    ) -> None:
+        """This method is to partition the fx module based on the cost.
+        The cost is the total latency of running the whole fx module.
+        In partitioner_utils.py, the cost model is built.
+        The cost aware partition algorithm is:
+        #1. At every beginning, each node is a partition.
+            Then we map all the partitions to the devices
+            and calculate the cost
+        #2. Then try to pre-combine any two of the partitions if the two
+            partitions can be combined.
+            (the bfs level is less than 2 or two partitions are connected and
+            can find partition to device mapping)
+            See if any partition pair could reduce the current cost.
+            Choose the pair that shows the minimum cost and then combine them
+        #3. Repeat #2 until the cost cannot be reduced.
+        """
+
+        def try_combining_partitions(p0_index, p1_index, partitions) -> float:
+            """Given two partitions and a list of partitions, combine these two partitions
+            and see what is the cost of the modified partition list
+            """
+            p0 = partitions[p0_index]
+            p1 = partitions[p1_index]
+            """If two partitions' bfs level are less than 2 or two partitions are connected to each other,
+               then they can be combined
+            """
+            if (
+                (abs(p0.bfs_level - p1.bfs_level) <= 1)
+                or (p0 in p1.parents)
+                or p0 in (p1.children)
+            ):
+                combine_two_partitions(p0, p1, partitions)
+                # Check if a circular dependency exists after combining
+                if check_dependency(partitions[-1]):
+                    return float("inf")
+                # Check if the modified partition list can be mapped to devices after combination
+                reset_partition_device(partitions)
+                found_device = get_device_to_partitions_mapping(
+                    partitions, self.devices
+                )
+                if not found_device:
+                    return float("inf")
+                # Calculate the new cost
+                partition_to_latency_mapping = get_partition_to_latency_mapping(
+                    partitions, node_to_latency_mapping
+                )
+                cost = get_latency_of_partitioned_graph(
+                    partitions,
+                    partition_to_latency_mapping,
+                    transfer_rate_bytes_per_sec,
+                )
+                return cost
+            # If the two partitions cannot be combined, the cost is inf
+            return float("inf")
+
+        def search_combination(
+            transfer_rate_bytes_per_sec, node_to_latency_mapping
+        ) -> bool:
+            """Given transfer rate between partitions and each node's latency,
+            find two partitions to combine so the cost of the partitions can
+            be reduced.
+            The algorithm is :
+            1. Go through all the partition pairs and see
+            if any pair of partitions can be combined.
+            2. Calculate the cost after the combination.
+            3. Select the minimum cost and combine its corresponding partition pair.
+            """
+            partition_to_latency_mapping = get_partition_to_latency_mapping(
+                self.partitions, node_to_latency_mapping
+            )
+            cost = get_latency_of_partitioned_graph(
+                self.partitions,
+                partition_to_latency_mapping,
+                transfer_rate_bytes_per_sec,
+            )
+            if len(self.partitions) == 1:
+                return False
+            partition_pair: List[int] = []
+            for i in range(len(self.partitions) - 1):
+                for j in range(i + 1, len(self.partitions)):
+                    # Try to combine the partition pair
+                    # and see the new cost after combination
+                    new_cost = try_combining_partitions(i, j, self.partitions[:])
+                    if new_cost <= cost:
+                        partition_pair = [i, j]
+                        cost = new_cost
+                    reorganize_partitions(self.partitions)
+            # If a partition pair is found, combine them
+            if len(partition_pair) != 0:
+                p0 = self.partitions[partition_pair[0]]
+                p1 = self.partitions[partition_pair[1]]
+                combine_two_partitions(p0, p1, self.partitions)
+            get_bfs_level_partition(self.partitions)
+            reset_partition_device(self.partitions)
+            get_device_to_partitions_mapping(self.partitions, self.devices)
+            return len(partition_pair) != 0
+
+        for node in self.graph_module.graph.nodes:
+            if node.op not in {"placeholder", "get_attr", "output"}:
+                self.create_single_node_partition(node)
+        # Set up parent partitions and children partitions for each partition
+        set_parents_and_children(self.partitions)
+        # Get bfs level for each partition
+        get_bfs_level_partition(self.partitions)
+        find_combination = True
+        while find_combination:
+            # Search for the pair of partitions that gives the minimum new cost,
+            # then combine them
+            find_combination = search_combination(
+                transfer_rate_bytes_per_sec, node_to_latency_mapping
+            )
+        # Make sure all partitions are set up correctly
+        reorganize_partitions(self.partitions)
+        # Set up node to partition mapping
+        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
+        return
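+
+    # A minimal sketch of building the inputs for the cost aware mode, in comments
+    # (field names follow how partitioner_config is read above; the latency values
+    # and transfer rate are placeholders, not measurements):
+    #
+    #   node_to_latency = {
+    #       n: NodeLatency(float(n.size_bytes.total_size), 1.0)
+    #       for n in traced.graph.nodes
+    #       if n.op not in {"placeholder", "get_attr", "output"}
+    #   }   # assumes get_size_of_all_nodes(traced) has tagged node sizes
+    #   config = PartitionerConfig(
+    #       devices,
+    #       mode=PartitionMode.cost_aware,
+    #       transfer_rate_bytes_per_sec=2.0e9,
+    #       node_to_latency_mapping=node_to_latency,
+    #   )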
+
+    def kl_based_partition(
+        self,
+        transfer_rate_bytes_per_sec: float,
+        node_to_latency_mapping: Dict[Node, NodeLatency],
+    ) -> None:
+        """This function is a cost aware partition based
+        on Kernighan-Lin algorithm.
+        First, the graph is partitioned using size_based_partition.
+        Then, each node is swapped with any other node in a different
+        partition, and at the same time, the cost is estimated after
+        the swapping.
+        For example, we have nodes n0, n1, n2, n3 and n4.
+        Using size_based_partition, n0 and n1 are in Partition p0.
+        n2, n3 and n4 in Partition p1. The current cost is estimated.
+        We first tried using n0 to swap with n2 from the other partition.
+        Then we see that swapping n0 and n2 shows a lower cost
+        than the current cost and it is the minimum among other pairs like
+        (n0, None)(This means moving n0 to Partition without swapping other nodes),
+        (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost
+        as the current cost.
+        Then We repeat this process for all the other nodes until all swapping pairs
+        are tried.
+        """
+
+        def swap_nodes(n0, n1, p0, p1):
+            # Either n0 or n1 could be None
+            # That means we simply move the node
+            # to another partition
+            if n0 is not None:
+                p0.remove_node(n0)
+                p1.add_node(n0)
+            if n1 is not None:
+                p0.add_node(n1)
+                p1.remove_node(n1)
+
+        def try_swap_nodes(
+            n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+        ):
+            cost = float("inf")
+            swap_nodes(n0, n1, p0, p1)
+            # Reorganize partitions after swapping
+            reorganize_partitions(self.partitions)
+            # Check if there is a circular dependency after swapping
+            if (not check_dependency(p0)) and (not check_dependency(p1)):
+                reset_partition_device(self.partitions)
+                partition_to_latency_mapping = get_partition_to_latency_mapping(
+                    self.partitions, node_to_latency_mapping
+                )
+                # Check if all partitions can be mapped to logical devices after swapping
+                found_device = get_device_to_partitions_mapping(
+                    self.partitions, self.devices
+                )
+                if not found_device:
+                    cost = float("inf")
+                else:
+                    cost = get_latency_of_partitioned_graph(
+                        self.partitions,
+                        partition_to_latency_mapping,
+                        transfer_rate_bytes_per_sec,
+                    )
+            # Swap back and reset all partitions back to original
+            swap_nodes(n1, n0, p0, p1)
+            reorganize_partitions(self.partitions)
+            reset_partition_device(self.partitions)
+            get_device_to_partitions_mapping(self.partitions, self.devices)
+            return cost
+
+        def swap_node_to_partition(
+            node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+        ):
+            """This function helps to swap one node from partition p0
+            with all the nodes in another partition p1
+            """
+            p1_nodes = list(p1.nodes) + [None]
+            min_cost = float("inf")
+            node_pair: List[Node] = []
+            for n1 in p1_nodes:
+                # Ignore the node if it is not an op node
+                if n1 is not None and n1.op in {"placeholder", "get_attr"}:
+                    continue
+                # Try swapping node in p0 with n1 in p1
+                cost = try_swap_nodes(
+                    node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
+                )
+                if cost < min_cost:
+                    node_pair = [node, n1]
+                    min_cost = cost
+            return min_cost, node_pair
+
+        # First use size_based_partition
+        self.size_based_partition()
+        partition_to_latency_mapping = get_partition_to_latency_mapping(
+            self.partitions, node_to_latency_mapping
+        )
+        # Calculate the cost of the partitions
+        cost = get_latency_of_partitioned_graph(
+            self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
+        )
+        # Keep track of the node pair that gives the best cost
+        node_pair: List[Node] = []
+        # Keep track of the partition pair for that node pair
+        partition_pair: List[Partition] = []
+        # Collect all the op nodes from the graph
+        op_nodes = []
+        for n in self.graph_module.graph.nodes:
+            if n.op not in {"placeholder", "get_attr", "output"}:
+                op_nodes.append(n)
+        for node in op_nodes:
+            # Find which partition the current node belongs to
+            p0_index = self.node_to_partition[node]
+            p0 = self.partitions[p0_index]
+            # Go through all the other partitions to swap
+            # with other nodes from those partitions
+            for p1_index, _ in enumerate(self.partitions):
+                if p0_index != p1_index:
+                    p1 = self.partitions[p1_index]
+                    new_cost, new_node_pair = swap_node_to_partition(
+                        node,
+                        p0,
+                        p1,
+                        node_to_latency_mapping,
+                        transfer_rate_bytes_per_sec,
+                    )
+                    # Update the cost
+                    # Track the swapped node pair and their partitions
+                    if new_cost < cost:
+                        cost = new_cost
+                        node_pair = new_node_pair
+                        partition_pair = [p0, p1]
+            # Do the swapping after trying all the nodes from a partition
+            if len(node_pair) != 0:
+                swap_nodes(
+                    node_pair[0], node_pair[1], partition_pair[0], partition_pair[1]
+                )
+                reorganize_partitions(self.partitions)
+                get_device_to_partitions_mapping(self.partitions, self.devices)
+        reorganize_partitions(self.partitions)
+        # Mapping the device to the partition
+        get_device_to_partitions_mapping(self.partitions, self.devices)
+        return
+
+    def aot_based_partition(
+        self, node_to_partition_mapping, partition_to_logical_device_mapping
+    ):
+        """This function helps to rebuild the partitions given the nodes and its
+        corresponding partition id
+        """
+        partition_id_to_partition_mapping: Dict[int, Partition] = {}
+        self.node_to_partition = node_to_partition_mapping
+        for node in self.node_to_partition:
+            partition_id = self.node_to_partition[node]
+            # If the requested partition has not been created, create the partition
+            if partition_id not in partition_id_to_partition_mapping:
+                partition = Partition(partition_id)
+                self.partitions.append(partition)
+                partition_id_to_partition_mapping[partition_id] = partition
+                partition.logical_device_ids = partition_to_logical_device_mapping[
+                    partition_id
+                ]
+            else:
+                partition = partition_id_to_partition_mapping[
+                    self.node_to_partition[node]
+                ]
+            # Add the current node into the partition
+            partition.add_node(node)
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/const_fold.py b/MLPY/Lib/site-packages/torch/fx/experimental/const_fold.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1e4a002b6f831f579cc647142567815e7bcdfb1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/const_fold.py
@@ -0,0 +1,289 @@
+import re
+from typing import Callable, Dict, Optional, Set, Union
+
+import torch.fx
+from torch.fx.node import map_arg
+from torch.fx.passes.split_module import split_module
+
+
+__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
+
+class FoldedGraphModule(torch.fx.GraphModule):
+    """
+    FoldedGraphModule is a GraphModule which also contains another
+    `const_subgraph_module` representing a subgraph which has all const attr
+    inputs and which can be run once before running the main standard
+    `graph`. The folded results are stored on the module under the attribute named by
+    `fx_const_folded_attrs_name`, which the main `graph` then reads via get_attr instead
+    of recomputing the constant subgraph.
+    """
+
+    def __init__(
+        self,
+        root: torch.nn.Module,
+        graph: torch.fx.Graph,
+        const_subgraph: Optional[torch.fx.Graph] = None,
+        fx_const_folded_attrs_name: Optional[str] = None,
+        device_for_folded_attrs: str = "cuda",
+    ):
+        super().__init__(root, graph)
+        self.const_subgraph_module = (
+            None
+            if const_subgraph is None
+            else torch.fx.GraphModule(root, const_subgraph)
+        )
+        self.has_folding_been_run = False
+        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
+        self.device_for_folded_attrs = device_for_folded_attrs
+
+    def __call__(self, *args, **kwargs):
+        if not self.has_folding_been_run:
+            self.run_folding()
+        return super().__call__(*args, **kwargs)
+
+    def run_folding(self):
+        # If there's no const subgraph module or attr output names to use, return
+        # early as there is no const folding to perform.
+        if (
+            self.const_subgraph_module is None
+            or self.fx_const_folded_attrs_name is None
+        ):
+            return
+
+        assert not self.has_folding_been_run
+        self.has_folding_been_run = True
+
+        # Actually run const folding subgraph. Note that single attr const fold
+        # subgraphs output a single Tensor while multiple outputs are returned as
+        # Tuple[Tensor,].
+        folded_attrs = self.const_subgraph_module()
+
+        def _create_param(i):
+            return torch.nn.Parameter(
+                i
+                if not isinstance(i, int)
+                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
+                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
+            )
+
+        params = (
+            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+            if isinstance(folded_attrs, tuple)
+            else _create_param(folded_attrs)
+        )
+        setattr(self, self.fx_const_folded_attrs_name, params)
+
+
+def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
+    """
+    Given `gm` and some graph module which is called with target name `inline_mod_name`,
+    this helper will inline all of the nodes from that called graph module into `gm`.
+    """
+    # Fetch the inner graph module that we want to inline inside `gm`.
+    inline_mod = dict(gm.named_modules())[inline_mod_name]
+    assert isinstance(inline_mod, torch.fx.GraphModule)
+    call_mod_node_to_replace = None
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and node.target == inline_mod_name:
+            call_mod_node_to_replace = node
+            break
+    assert call_mod_node_to_replace is not None
+
+    # Now actually do the swap. Note that we have to keep track of new nodes that are
+    # copied into `gm` -- we do this via replacement_mapping.
+    call_mod_args = call_mod_node_to_replace.args
+    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+    ph_count = 0
+
+    def replacement_fn(node):
+        new_node = replacement_mapping[node]
+        new_node.meta = node.meta.copy()
+        return new_node
+
+    for inline_node in inline_mod.graph.nodes:
+        if inline_node.op == "placeholder":
+            replacement_mapping[inline_node] = call_mod_args[ph_count]
+            ph_count += 1
+            continue
+
+        if inline_node.op == "output":
+            outputs = inline_node.args[0]
+            output_replacements = map_arg(outputs, replacement_fn)
+            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
+            continue
+
+        with gm.graph.inserting_before(call_mod_node_to_replace):
+            new_node = gm.graph.node_copy(inline_node, replacement_fn)
+        replacement_mapping[inline_node] = new_node
+
+    gm.graph.eliminate_dead_code()
+
+
+def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
+    """
+    Make sure the name is unique (in a module) and can represent an attr.
+    """
+    # Delete all characters that are illegal in a Python identifier.
+    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
+    if name[0].isdigit():
+        name = f"_{name}"
+    # Now make sure it is in fact unique to the module by incrementing suffix value.
+    while hasattr(mod_traced, name):
+        match = re.match(r"(.*)_(\d+)$", name)
+        if match is None:
+            name = name + "_1"
+        else:
+            base, num = match.group(1, 2)
+            name = f"{base}_{int(num) + 1}"
+
+    return name
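+
+# For example, assuming a module that already has attributes "w" and "w_1":
+#   get_unique_attr_name_in_module(mod, "w")     -> "w_2"
+#   get_unique_attr_name_in_module(mod, "a.b/c") -> "a_b_c"  (illegal chars replaced)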
+
+
+def split_const_subgraphs(
+    module: Union[torch.nn.Module, torch.fx.GraphModule],
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+    device_for_folded_attrs: str = "cpu",
+) -> FoldedGraphModule:
+    """
+    Looks through `module` for any nodes that have all constant attribute inputs
+    and separates them out into their own constant subgraph, and returns a
+    FoldedGraphModule which runs that constant subgraph on the first run to set
+    attributes on the module prior to running the non-constant portion of the
+    graph.
+    """
+    if not isinstance(module, torch.fx.GraphModule):
+        mod_traced = torch.fx.symbolic_trace(module)
+    else:
+        mod_traced = module
+
+    # Build up a list of const_nodes, defined as nodes that are themselves
+    # get_attrs, or have all get_attr or other constant node inputs.
+    const_nodes: Set[torch.fx.Node] = set()
+    found_const_folding = False
+    for node in mod_traced.graph.nodes:
+        # Skip over placeholders/outputs because they can't be const folded and
+        # we don't want to add tags to them.
+        if node.op in {"placeholder", "output"}:
+            continue
+
+        # If the node itself is constant, or all of its inputs are constant,
+        # then tag it as constant.
+        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
+            const_nodes
+        ):
+            continue
+
+        # If provided skip folding function says to skip, then skip.
+        if skip_folding_node_fn and skip_folding_node_fn(node):
+            continue
+
+        # Skip folding side-effectful functions
+        if node.is_impure():
+            continue
+
+        # Must be a constant foldable node at this point.
+        const_nodes.add(node)
+        if node.op != "get_attr":
+            found_const_folding = True
+
+    # If we did not find any const folding then return early without a const fold subgraph.
+    if not found_const_folding:
+        return FoldedGraphModule(mod_traced, mod_traced.graph)
+
+    # Partition the module into two: submod_0 for constant folding subgraph, and
+    # submod_1 for the rest.
+    def mod_partition(node: torch.fx.Node):
+        return 0 if node in const_nodes else 1
+
+    split = split_module(mod_traced, module, mod_partition)
+
+    const_gm, non_const_gm = split.submod_0, split.submod_1
+    const_mod_name, non_const_mod_name = "submod_0", "submod_1"
+
+    # The module that a call_module node refers to gets copied to submodules during split.
+    # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
+    # attach inlined modules to `split` as it's the owning module now.
+    for node in non_const_gm.graph.nodes:
+        if node.op == "call_module":
+            setattr(split, node.target, getattr(non_const_gm, node.target))
+    for node in const_gm.graph.nodes:
+        if node.op == "call_module":
+            setattr(split, node.target, getattr(const_gm, node.target))
+
+    # split_module currently does not use get_attrs for attrs. Instead it passes
+    # them in as args from the parent module, which used get_attrs. Here we set
+    # them as get_attrs inside const_gm, allowing for running folding without
+    # somehow a priori knowing the attrs that should be passed as args. We can
+    # unconditionally do this for all placeholders because we know all
+    # placeholders to const_gm must be constants accessible via get_attr.
+    call_const_gm_args = None
+    for node in split.graph.nodes:
+        if node.op == "call_module":
+            if node.target == const_mod_name:
+                call_const_gm_args = node.args
+                break
+    assert call_const_gm_args is not None
+
+    # Here we do the actual replacement of placeholders to get_attrs. Note that here we
+    # set the const_gm.graph into a new root_const_gm with split as the root module,
+    # because we are fetching attributes directly from the root module, instead of
+    # fetching them from const_gm. Example: The const_gm must have some format like:
+    # graph():
+    #    %inp : [num_users=1] = placeholder[target=const_inp]
+    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
+    #    return add
+    # We replace that with the following, which does not have any placeholders:
+    # graph():
+    #    %inp_1 : [num_users=1] = get_attr[target=const_inp]
+    #    %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
+    #    return add
+    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
+    for node in root_const_gm.graph.nodes:
+        if node.op == "output":
+            multiple_outputs = isinstance(node.args[0], tuple)
+            continue
+        if node.op != "placeholder":
+            continue
+        in_node = next(n for n in call_const_gm_args if n.name == node.target)
+        assert in_node.op == "get_attr"
+        with root_const_gm.graph.inserting_before(node):
+            new_node = root_const_gm.graph.get_attr(in_node.target)
+        new_node.meta = node.meta.copy()
+        node.replace_all_uses_with(new_node)
+        root_const_gm.graph.erase_node(node)
+    assert "multiple_outputs" in locals()
+
+    # Now find the call to const_gm inside split, and replace it with a getattr to the
+    # folded tensor(s) that result from constant folding. Note that we don't need to
+    # worry about whether this is one or more tensors because the original graph
+    # correctly uses getitem to extract individual tensors if there are multiple folded.
+    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
+        split, "_FX_CONST_FOLDED_ATTRS"
+    )
+    setattr(
+        split,
+        fx_const_folded_attrs_name,
+        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
+    )
+    for node in split.graph.nodes:
+        if node.op == "call_module" and node.target == const_mod_name:
+            with node.graph.inserting_before(node):
+                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
+            folded_attrs.meta = node.meta.copy()
+            node.replace_all_uses_with(folded_attrs)
+            break
+
+    split.graph.eliminate_dead_code()
+
+    # Finally, inline the non-constant submod into the split submod. This is so that the
+    # original caller who may have passed in a graph module will get back out a graph
+    # module whose graph is traced to the same granularity.
+    _inline_module(split, non_const_mod_name)
+
+    return FoldedGraphModule(
+        split,
+        split.graph,
+        root_const_gm.graph,
+        fx_const_folded_attrs_name,
+        device_for_folded_attrs,
+    )
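+
+# A minimal usage sketch (the module below is a hypothetical example, not part of
+# this file): the constant subexpression is folded once, on the first call.
+#
+#   class AddConst(torch.nn.Module):
+#       def __init__(self):
+#           super().__init__()
+#           self.w = torch.nn.Parameter(torch.randn(4))
+#
+#       def forward(self, x):
+#           return x + (self.w + 1.0)   # (self.w + 1.0) has only constant inputs
+#
+#   folded = split_const_subgraphs(AddConst())
+#   out = folded(torch.randn(4))        # first call runs the const subgraph once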
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/debug.py b/MLPY/Lib/site-packages/torch/fx/experimental/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c89590a704535a6d2dccd404b873a20dbccb169
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/debug.py
@@ -0,0 +1,31 @@
+import torch.fx as fx
+
+def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
+    """
+    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
+    `gm` gets run.
+
+    Args:
+        gm: graph module to insert the breakpoint into. It is then recompiled for
+            the change to take effect.
+
+    Returns:
+        the `gm` with breakpoint inserted.
+    """
+    def insert_pdb(body):
+        return ["import pdb; pdb.set_trace()\n", *body]
+
+    with gm.graph.on_generate_code(
+        make_transformer=lambda cur_transform: (
+            # new code transformer to register
+            lambda body: (
+                insert_pdb(
+                    cur_transform(body) if cur_transform
+                    else body
+                )
+            )
+        )
+    ):
+        gm.recompile()
+
+    return gm
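+
+# A minimal usage sketch (the traced module and input are hypothetical):
+#
+#   gm = fx.symbolic_trace(my_module)
+#   gm = set_trace(gm)
+#   gm(example_input)   # execution drops into pdb before the graph body runs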
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/MLPY/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c85b879a070fb4e37c5d6b55e12ab7bae2c5b9c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py
@@ -0,0 +1,914 @@
+from functools import reduce
+import torch
+import operator
+from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
+from typing import Callable, Dict
+from torch.fx.node import Target, Node
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.conv import Conv2d
+from torch.fx.experimental.refinement_types import Equality
+import itertools
+
+from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
+
+import sympy
+
+_INFERENCE_RULES: Dict[Target, Callable] = {}
+_REFINEMENT_RULES: Dict[Target, Callable] = {}
+_RULES: Dict[Target, Callable] = {}
+
+
+def expand_to_tensor_dim(t, n):
+    """
+    Expand a type to the desired tensor dimension, if possible;
+    raise an error otherwise.
+    - t is the given type
+    - n is the number of dimensions to expand to
+    """
+    if t == Dyn:
+        dims = [Dyn] * n
+        return TensorType(tuple(dims))
+    elif isinstance(t, TensorType):
+        if len(t.__args__) != n:
+            raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
+        return t
+    else:
+        raise TypeError(f'Cannot match the type {t}')
+
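+# For example (Dyn and TensorType come from torch.fx.tensor_type):
+#   expand_to_tensor_dim(Dyn, 3)                    -> TensorType((Dyn, Dyn, Dyn))
+#   expand_to_tensor_dim(TensorType((1, 2, 3)), 3)  -> TensorType((1, 2, 3))
+#   expand_to_tensor_dim(TensorType((1, 2)), 3)     -> raises TypeError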
+
+def broadcast_types(t1, t2):
+    """
+    Applies broadcasting to both given types such that they
+    become consistent with each other, and returns the two new
+    resulting types.
+    """
+
+    # if either type is Dyn, do nothing since the types are already consistent
+    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
+        return t1, t2
+
+    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
+        s1 = len(t1.__args__)
+        s2 = len(t2.__args__)
+
+        new_t1 = list(t1.__args__)
+        new_t2 = list(t2.__args__)
+
+        # We make the types the same length which is the first requirement
+        # for consistency
+        if s1 > s2:
+            for i in range(s1 - s2):
+                new_t2.insert(0, 1)
+
+        elif s2 > s1:
+            for i in range(s2 - s1):
+                new_t1.insert(0, 1)
+
+        # we replace occurrences of "1" with each tensor with
+        # the corresponding type from the other tensor
+        for i, (x, y) in enumerate(zip(new_t1, new_t2)):
+            if x == 1:
+                new_t1[i] = y
+            elif y == 1:
+                new_t2[i] = x
+
+        # at this point our tensors should be consistent
+        # and we can apply the element-wise operation and find the right dimension
+        # for the output of the operation
+        (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
+        return (t1, t2)
+    else:
+        raise TypeError(f'Cannot broadcast types {t1} and {t2}')
+
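+# For example, broadcasting pads the shorter type with leading 1s and then fills
+# in 1-sized dimensions from the other operand:
+#   broadcast_types(TensorType((1, 2, 3)), TensorType((3,)))
+#       -> (TensorType((1, 2, 3)), TensorType((1, 2, 3)))
+#   broadcast_types(TensorType((2, 1)), TensorType((1, 4)))
+#       -> (TensorType((2, 4)), TensorType((2, 4)))
+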
+def register_inference_rule(call_target):
+    def register(fn):
+        if call_target in _INFERENCE_RULES:
+            raise RuntimeError(f'Inference rule already registered for {call_target}!')
+        _INFERENCE_RULES[call_target] = fn
+        return fn
+    return register
+
+def register_refinement_rule(call_target):
+    def register(fn):
+        if call_target in _REFINEMENT_RULES:
+            raise RuntimeError(f'Refinement rule already registered for {call_target}!')
+        _REFINEMENT_RULES[call_target] = fn
+        return fn
+    return register
+
+def register_algebraic_expressions_inference_rule(call_target):
+    def register(fn):
+        if call_target in _RULES:
+            raise RuntimeError(f'Rule already registered for {call_target}!')
+        _RULES[call_target] = fn
+        return fn
+    return register
+
+@register_inference_rule(torch.add)
+@register_inference_rule(operator.add)
+def add_inference_rule(n: Node):
+    """
+    Apply the addition inference rule. This includes:
+    - scalar addition
+    - broadcasting semantics
+
+    Note that we always return the least precise type between
+    the operands (after applying broadcasting) to be the final type of the operation
+
+    Note that we do not modify the operand types themselves after applying broadcasting
+    to them. We only use them to calculate the final type
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+    t1 = n.args[0].type
+    t2 = n.args[1].type
+
+    # handle scalar addition
+    if t1 == int and isinstance(t2, TensorType):
+        n.type = t2
+        return n.type
+
+    # handle scalar addition
+    elif t2 == int and isinstance(t1, TensorType):
+        n.type = t1
+        return n.type
+
+    # we bring the new types to the point where
+    # we can check for consistency;
+    # any inconsistency would not have been caused
+    # by broadcasting at this point
+    (new_t1, new_t2) = broadcast_types(t1, t2)
+
+    if new_t1 != t1 or new_t2 != t2:
+        n.meta['broadcast'] = True
+        n.meta[str(n.args[0])] = new_t1
+        n.meta[str(n.args[1])] = new_t2
+
+    else:
+        n.meta['broadcast'] = False
+
+    new_t1 = t1 if not n.meta['broadcast'] else new_t1
+    new_t2 = t2 if not n.meta['broadcast'] else new_t2
+
+    # we check for consistency between the new types
+    if is_consistent(new_t1, new_t2):
+        # we return the less precise type because
+        # broadcasting may have happened
+        # for operands with shape [1,2,Dyn] and [1,2,1]
+        # we have to assign the node [1,2,Dyn]
+        if is_more_precise(new_t1, new_t2):
+            n.type = new_t2
+        else:
+            n.type = new_t1
+        return n.type
+    else:
+        raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
+                        f' Types should match ')
+
+@register_inference_rule(getattr)
+def get_attr_inference_rule(n: Node, traced):
+    """
+    The current getattr rule only handles the shape attribute.
+    It can be extended to other attributes.
+    The most representative type we have is "Dyn", but the system
+    can be extended with more types, such as a type to represent shapes
+    """
+    attr_node = n.args[0]
+    attr_name = n.args[1]
+
+    if attr_name == "shape":
+        n.type = Dyn
+    else:
+        raise TypeError("Not yet implemented")
+
+    # TODO. We leave it like this till we add a type to represent tensor sizes
+    return n.type
+
+@register_inference_rule(torch.transpose)
+def transpose_inference_rule(n: Node):
+    """
+    We check that dimensions for the transpose operations
+    are within range of the tensor type of the node
+    """
+    if n.target == torch.transpose:
+        assert isinstance(n.args[0], Node)
+        t = n.args[0].type
+
+        assert isinstance(n.args[1], int)
+        assert isinstance(n.args[2], int)
+        dim1, dim2 = n.args[1], n.args[2]
+
+        if t == Dyn:
+            n.type = Dyn
+            return n.type
+
+        elif isinstance(t, TensorType):
+            if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
+                new_type = list(t.__args__)
+                new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
+                final = TensorType(new_type)
+                n.type = get_greatest_upper_bound(n.type, final)
+                return n.type
+            else:
+                raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
+        else:
+            raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
+
+
+@register_inference_rule(torch.reshape)
+def reshape_inference_rule(n: Node):
+    """
+    Without dynamism, the rule checks that the
+    product of the elements of the argument tensor
+    type is equal to the product of the elements
+    of the required shape. We gradualize this rule
+    by adding a case to handle fully dynamic input
+    as well as input where some of the tensor dimensions
+    are unknown. In this case we check for divisibility
+    """
+    assert isinstance(n.args[0], Node)
+    t1 = n.args[0].type
+
+    assert isinstance(n.args[1], list)
+    t2 = n.args[1]
+    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])
+
+    # if we do not know the original tensor dimension,
+    # we return the required dimension
+    if t1 == Dyn:
+        n.type = t2_type
+        return t2_type
+
+    # if any of the dimensions are unknown,
+    # we check for divisibility
+    elif isinstance(t1, TensorType):
+        assert isinstance(t1, TensorType)
+        a = [e if e != Dyn else 1 for e in t1.__args__]
+        p1 = reduce(operator.mul, a)
+        p2 = reduce(operator.mul, t2)
+        if p1 % p2 == 0 or p2 % p1 == 0:
+            n.type = t2_type
+            return t2_type
+        else:
+            raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
+    else:
+        raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
+
+@register_inference_rule(BatchNorm2d)
+def bn2d_inference_rule(n: Node, module_instance):
+    """
+    Given a BatchNorm2D instance and a node check the following conditions:
+    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, x_3, x_4)
+    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
+    - t is consistent with t'
+    - x_2 is consistent with the module's num_features
+    - x_2' is consistent with the module's num_features
+    output type: the more precise type of t and t'
+    """
+    assert isinstance(n.args[0], Node)
+    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
+    arg_type = n.args[0].type
+    n.type = expand_to_tensor_dim(n.type, 4)
+
+    # we check the conditions on the incoming argument
+    # and any existing annotation
+    # we also check for consistency between both annotations
+    if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
+            is_consistent(n.type.__args__[1], module_instance.num_features) and \
+            is_consistent(arg_type, n.type):
+
+        # we choose the more precise type
+        # to be the node type
+        # so if an incoming argument has more type information
+        # we set this node's type to be the argument type
+        n.type = get_greatest_upper_bound(arg_type, n.type)
+        return n.type
+    else:
+        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
+
+
+def calculate_out_dimension(d_in, module_instance, index):
+    """
+    For calculating h_out and w_out according to the Conv2d documentation
+    """
+    padding = (module_instance.padding, module_instance.padding) \
+        if isinstance(module_instance.padding, int) else module_instance.padding
+    kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
+        if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
+    stride = (module_instance.stride, module_instance.stride) \
+        if isinstance(module_instance.stride, int) else module_instance.stride
+    dilation = (module_instance.dilation, module_instance.dilation) \
+        if isinstance(module_instance.dilation, int) else module_instance.dilation
+
+    DIMENSION_TYPES = (int, sympy.Symbol)
+
+    if d_in == Dyn:
+        return Dyn
+
+    elif isinstance(d_in, DIMENSION_TYPES):
+        n = d_in + 2 * padding[index] - \
+            dilation[index] * \
+            (kernel_size[index] - 1) - 1
+
+        return (n // stride[0]) + 1
+
+    else:
+        raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
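+
+# Worked example (illustrative, not exhaustive): for a Conv2d with kernel_size=3,
+# stride=1, padding=1, dilation=1 and an input height of 32,
+#   n = 32 + 2*1 - 1*(3 - 1) - 1 = 31, so the output height is 31 // 1 + 1 = 32.
+# A Dyn input dimension stays Dyn.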
+
+
+def get_greatest_upper_bound(type1, type2):
+    """
+    Get the most precise type that's consistent with the given types
+    """
+    if type1 == Dyn:
+        return type2
+    elif type2 == Dyn:
+        return type1
+    elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
+        if not is_consistent(type1, type2):
+            raise TypeError(f'Inconsistent types {type1}, {type2}')
+        gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
+        return TensorType(tuple(gub))
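+
+# Illustrative example:
+#   get_greatest_upper_bound(TensorType((1, 2, Dyn)), TensorType((1, Dyn, 3)))
+#   -> TensorType((1, 2, 3))
+# Dyn is the least precise element, so the concrete dimension wins whenever the
+# two types are consistent; inconsistent types raise a TypeError.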
+
+
+@register_inference_rule(Conv2d)
+def conv2d_inference_rule(n: Node, module_instance):
+    """
+    Given a Conv2D instance and a node check the following conditions:
+    - the input type can be expanded to a size 4 tensor: t =  (x_1, x_2, H, W)
+    - the current node type can be expanded to a size 4 tensor: t' =  (x_1', x_2', x_3', x_4')
+    - x_2 is consistent with the module's in_channels
+    - let o = (x_1, out_channels, H_out, W_out)
+    then the output is the greatest upper bound of o and the existing node type t'.
+    """
+    assert isinstance(n.args[0], Node)
+    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
+    arg_type = n.args[0].type
+    curr_node_type = expand_to_tensor_dim(n.type, 4)
+
+    if is_consistent(arg_type.__args__[1], module_instance.in_channels):
+        w_in = arg_type.__args__[3]
+        h_in = arg_type.__args__[2]
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+        new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
+        gub = get_greatest_upper_bound(new_type, curr_node_type)
+        n.type = gub
+        return n.type
+    else:
+        raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
+
+
+@register_inference_rule(torch.nn.ReLU)
+def relu_inference_rule(n: Node, module_instance):
+    """
+    Input and output shapes should be equal.
+    """
+    assert isinstance(n.args[0], Node)
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+
+    if isinstance(n.args[0].type, TensorType):
+        n.type = get_greatest_upper_bound(n.args[0].type, n.type)
+    return n.type
+
+
+def maxpool2d_check(typ, module_instance):
+    """
+    Applies the maxpool2d shape information to the input;
+    this affects the last two dimensions
+    """
+    new_type_list = list(typ.__args__)
+    if len(new_type_list) == 4 or len(new_type_list) == 3:
+        w_in = new_type_list[-1]
+        h_in = new_type_list[-2]
+
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+
+        new_type_list[-1] = w_out
+        new_type_list[-2] = h_out
+        return TensorType(tuple(new_type_list))
+
+    else:
+        raise TypeError(f'Wrong size {typ} for {module_instance}')
+
+
+@register_inference_rule(torch.nn.MaxPool2d)
+def maxpool2d_inference_rule(n: Node, module_instance):
+    """
+    Given a MaxPool2D instance and a node check the following conditions:
+    - Input size matches size 3 or 4
+    - Current node type is consistent with the output type we will calculate
+    - Input size matches output size and the last two dimensions of the output
+      are w_out and h_out. The remaining dimensions are the same as the input
+    - Our final result is the greatest upper bound of the output we calculate
+      and the current node type.
+    """
+    assert isinstance(n.args[0], Node)
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output = maxpool2d_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(output, n.type)
+    return n.type
+
+
+
+def linear_check(tensor_type, module_instance):
+    """
+    Checks that an input tensor type satisfies the conditions for linear operation
+    and returns the output type based on in and out features given by module_instance
+    """
+    if len(tensor_type.__args__) >= 2:
+        if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
+            new_type_args = list(tensor_type.__args__)
+            new_type_args[-1] = module_instance.out_features
+            return TensorType(tuple(new_type_args))
+        else:
+            raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
+    else:
+        raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
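+
+# Illustrative example (assuming module_instance = torch.nn.Linear(10, 5)):
+#   linear_check(TensorType((2, 10)), module_instance) -> TensorType((2, 5))
+# Only the last dimension is rewritten; an inconsistency with in_features raises a TypeError.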
+
+
+@register_inference_rule(torch.nn.Linear)
+def linear_inference_rule(n: Node, module_instance):
+    """
+    Applies the shape information to the input then gets the greatest upper bound
+    of the resulting type and the existing type
+    """
+    assert isinstance(n.args[0], Node)
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output_type = linear_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(output_type, n.type)
+    return n.type
+
+
+def adaptiveavgpool2d_check(tensor_type, module_instance):
+    output_size = module_instance.output_size
+    if isinstance(output_size, int):
+        output_size = [output_size, output_size]
+    elif isinstance(output_size, tuple):
+        output_size = list(output_size)
+        if output_size[0] is None:
+            output_size[0] = output_size[1]
+        if output_size[1] is None:
+            output_size[1] = output_size[0]
+
+    new_type_list = list(tensor_type.__args__)
+
+    if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3:
+        new_type_list[-1] = output_size[1]
+        new_type_list[-2] = output_size[0]
+
+        return TensorType(tuple(new_type_list))
+
+    else:
+        raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
+
+@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
+def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
+    """
+    The input and output sizes should be the same except for the last
+    two dimensions, which are given by the module's output_size and represent height and width
+    """
+    assert isinstance(n.args[0], Node)
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+    if isinstance(n.args[0].type, TensorType):
+        output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance)
+        n.type = get_greatest_upper_bound(n.type, output_type)
+    return n.type
+
+def flatten_check(tensor_type, start_dim, end_dim):
+    l = len(tensor_type.__args__)
+
+    start_dim = l if start_dim == -1 else abs(start_dim)
+    end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
+
+    if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
+        my_args = list(tensor_type.__args__)
+        lhs = my_args[0:start_dim]
+        rhs = my_args[end_dim:]
+        mid = my_args[start_dim:end_dim]
+        if Dyn in mid:
+            mid = [Dyn]
+        else:
+            mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
+        new_type_list = lhs + mid + rhs
+        return TensorType(tuple(new_type_list))
+    else:
+        raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
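+
+# Illustrative example:
+#   flatten_check(TensorType((2, 3, 4, 5)), start_dim=1, end_dim=-1) -> TensorType((2, 60))
+# Dimensions in the flattened range are multiplied together; if any of them is Dyn,
+# the flattened dimension becomes Dyn.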
+
+@register_inference_rule(torch.flatten)
+def flatten_inference_rule(n: Node):
+    """
+    Applies the flatten shape information to the input then gets the
+    greatest upper bound of the resulting type and the existing type
+    """
+    assert isinstance(n.args[0], Node)
+
+    # set the default start and end dims
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    if n.args[0].type == Dyn and isinstance(n.type, TensorType):
+        n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
+
+    if isinstance(n.args[0].type, TensorType):
+        output_type = flatten_check(n.args[0].type, start_dim, end_dim)
+        n.type = get_greatest_upper_bound(output_type , n.type)
+
+    return n.type
+
+class GraphTypeChecker:
+    def __init__(self, env, traced):
+        self.env = env
+        self.traced = traced
+
+    def type_check(self):
+        """
+        A gradual type checker for graphs
+        Effect: every node's field type will be
+        populated with a type after type-checking is done
+        """
+        graph = self.traced.graph
+
+        # type check every node with gradual type rules
+        # if any node does not type check return false
+        for n in graph.nodes:
+            self.type_check_node(n)
+        return True
+
+    def type_check_node(self, n: Node):
+        """
+        Type check a given fx node.
+        Current operations:
+        - Reshape
+        - Transpose
+        - Add
+        - Relu
+        - conv2d
+        - batchnorm2d
+        - flatten
+        - maxpool2d
+        - adaptiveavgpool2d
+        - linear
+        """
+        if n.type is None:
+            n.type = Dyn
+
+        if n.op == 'placeholder':
+            return n.type
+
+        elif n.op == 'get_attr':
+            t = get_parameter(self.traced, n.target)  # type: ignore[arg-type]
+            if isinstance(t.data, torch.Tensor):
+                n.type = TensorType(t.data.shape)
+            return n.type
+
+        elif n.op == 'call_function':
+            if n.target == getattr:
+                assert getattr in _INFERENCE_RULES
+                return _INFERENCE_RULES[n.target](n, self.traced)
+
+            elif n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](n)
+            else:
+                raise RuntimeError(f'No inference rule registered for target {n.target}!')
+
+        elif n.op == 'call_module':
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _INFERENCE_RULES:
+                return _INFERENCE_RULES[type(module_instance)](n, module_instance)
+            else:
+                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
+
+        elif n.op == 'output':
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+        else:
+            raise NotImplementedError(f"Method {n.op} not yet implemented")
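+
+# Usage sketch (illustrative; `MyModule` is a placeholder for a user-defined nn.Module
+# whose placeholder nodes have been annotated with TensorType/Dyn types):
+#   traced = torch.fx.symbolic_trace(MyModule())
+#   tc = GraphTypeChecker({}, traced)
+#   tc.type_check()   # populates every node's `type` field or raises TypeError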
+
+
+@register_refinement_rule(Conv2d)
+def conv_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
+        return res
+
+
+@register_refinement_rule(torch.nn.Linear)
+def linear_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
+    return res
+
+@register_refinement_rule(BatchNorm2d)
+@register_refinement_rule(torch.nn.ReLU)
+def all_eq(n: Node):
+    """
+    For operations where the input shape is equal to the output shape
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        args1 = arg_type.__args__
+        args2 = n.type.__args__
+        res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
+    return res
+
+
+@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
+@register_refinement_rule(torch.nn.MaxPool2d)
+def first_two_eq(n: Node):
+    """
+    For operations where the first two dimensions of the input and output shape
+    are equal
+    """
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        args1 = arg_type.__args__
+        args2 = n.type.__args__
+        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
+    return res
+
+
+@register_refinement_rule(torch.add)
+@register_refinement_rule(operator.add)
+def element_wise_eq(n: Node):
+    """
+    For element-wise operations; handles broadcasting.
+    Note that after applying broadcasting to the arguments
+    we can determine that certain dimensions have not been broadcast
+    when they are symbolically equal.
+
+    In this case, we can establish equality between those dimensions and the
+    corresponding output dimensions.
+
+    Note that it takes two iterations for this result: one iteration to establish
+    equality between certain dimensions of the operands (requiring the whole solver,
+    including unification) and another iteration to establish equality between the operands
+    and the resulting type, requiring another round of constraint generation and unification.
+    """
+    res = []
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        arg_type1 = n.args[0].type
+        arg_type2 = n.args[1].type
+        if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
+            args1, args2 = broadcast_types(arg_type1, arg_type2)
+            # by this point, we know that args1 and args2 are the same size.
+            a1 = args1.__args__
+            a2 = args2.__args__
+            a3 = n.type.__args__
+
+            # we would be here in the second iteration where we establish equality
+            # between operand type dimensions and the resulting type dimensions
+            r = []
+            for x, y, z in zip(a1, a2, a3):
+                if x == y:
+                    r.append(Equality(x, z))
+            res = r
+    return res
+
+
+@register_refinement_rule(torch.flatten)
+def flatten_refinement_rule(n: Node):
+    """
+    Generates equality constraints between the dimensions of the input and output
+    that will not be involved in the flatten operation
+    """
+    assert isinstance(n.args[0], Node)
+
+    eq_const = []
+
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType):
+        l = len(n.type.__args__)
+        arg_type = n.args[0].type
+        start_dim = l if start_dim == -1 else start_dim
+        end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
+
+        for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]):
+            eq_const.append(Equality(t1, t2))
+
+        for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]):
+            eq_const.append(Equality(t1, t2))
+    return eq_const
+
+
+@register_algebraic_expressions_inference_rule(Conv2d)
+def conv_rule(n: Node, module_instance):
+    """
+    Represents the output in terms of an algebraic expression w.r.t.
+    the input when possible
+    """
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        w_in = arg_type.__args__[3]
+        h_in = arg_type.__args__[2]
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+        new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
+        n.type = new_type
+        return new_type
+
+class Refine:
+    """
+    Symbolic shape inference.
+    Generates constraints over type variables.
+    Currently all constraints are equality constraints.
+    """
+    def __init__(self, traced):
+        self.constraints = []
+        self.traced = traced
+        self.symbol_iter = itertools.count(start=0, step=1)
+
+    def refine(self):
+        """
+        Generates constraints for
+        every node in the graph based on
+        the operation.
+        """
+        graph = self.traced.graph
+        for n in graph.nodes:
+            self.refine_node(n)
+        return True
+
+    def symbolic_relations(self):
+        """
+        Infers algebraic relations
+        """
+        graph = self.traced.graph
+        for n in graph.nodes:
+            self.infer_symbolic_relations(n)
+        return True
+
+    def replace_dyn_with_fresh_var(self, typ):
+        """
+        Replace all unknown types with fresh type variables.
+        """
+        if typ == Dyn:
+            new_symbol = Var(next(self.symbol_iter))
+            return new_symbol
+        elif isinstance(typ, TensorType):
+            new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
+            return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.replace_dyn_with_fresh_var(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return tuple(self.replace_dyn_with_fresh_var(t) for t in typ)
+        else:
+            return typ
+
+
+    def convert_to_sympy_symbols(self, typ):
+        """
+        Replace all type variables with sympy symbols.
+        """
+        if isinstance(typ, Var):
+            return sympy.symbols(str(typ))
+        elif isinstance(typ, TensorType):
+            new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
+            return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.convert_to_sympy_symbols(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return tuple(self.convert_to_sympy_symbols(t) for t in typ)
+        else:
+            return typ
+
+    def refine_node(self, n: Node):
+        """
+        Returns a list of equality constraints for
+        call_module and call_function nodes.
+        Models the relation between input and output dimensions
+        using constraints in case they are both tensors.
+        All operations used in resnet50 are defined.
+        """
+        if n.type is None:
+            n.type = Dyn
+
+        n.type = self.replace_dyn_with_fresh_var(n.type)
+
+        if n.op == 'call_function':
+            if n.target in _REFINEMENT_RULES:
+                self.constraints += _REFINEMENT_RULES[n.target](n)
+            else:
+                pass
+
+        if n.op == 'call_module':
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _REFINEMENT_RULES:
+                self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
+            else:
+                pass
+
+        if n.op == 'output':
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+        else:
+            pass
+
+    def infer_symbolic_relations(self, n: Node):
+        n.type = self.convert_to_sympy_symbols(n.type)
+        if n.op == 'call_function':
+            if n.target in _RULES:
+                return _RULES[n.target](n)
+            else:
+                pass
+
+        if n.op == 'call_module':
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _RULES:
+                return _RULES[type(module_instance)](n, module_instance)
+            else:
+                pass
+
+        if n.op == 'output':
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
+
+        else:
+            pass
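+
+# Usage sketch (illustrative; assumes `traced` is a symbolically traced GraphModule
+# that has already been type-checked):
+#   r = Refine(traced)
+#   r.refine()                # fills r.constraints with Equality constraints
+#   r.symbolic_relations()    # optionally rewrites node types as sympy expressions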
+
+def get_parameter(traced, target: str):
+    """
+    Returns the parameter given by ``target`` if it exists,
+    otherwise throws an error.
+
+    See the docstring for ``get_submodule`` for a more detailed
+    explanation of this method's functionality as well as how to
+    correctly specify ``target``.
+
+    Args:
+        target: The fully-qualified string name of the Parameter
+            to look for. (See ``get_submodule`` for how to specify a
+            fully-qualified string.)
+
+    Returns:
+        torch.nn.Parameter: The Parameter referenced by ``target``
+
+    Raises:
+        AttributeError: If the target string references an invalid
+            path or resolves to something that is not an
+            ``nn.Parameter``
+    """
+    module_path, _, param_name = target.rpartition(".")
+
+    mod: torch.nn.Module = traced.get_submodule(module_path)
+
+    if not hasattr(mod, param_name):
+        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
+
+    param: torch.nn.Parameter = getattr(mod, param_name)
+
+    return param
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/merge_matmul.py b/MLPY/Lib/site-packages/torch/fx/experimental/merge_matmul.py
new file mode 100644
index 0000000000000000000000000000000000000000..a14d1a9d8ab0fc86283b1fc1deacce6cd1ef2f17
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/merge_matmul.py
@@ -0,0 +1,171 @@
+import torch
+
+from torch.fx.node import Node
+from torch.fx._symbolic_trace import symbolic_trace
+from torch.fx.passes.tools_common import legalize_graph
+import itertools
+import operator
+
+from typing import Dict, List, Tuple
+
+
+def split_result_tensors(
+    result: torch.Tensor, inputs: List[torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    """
+    A free function for use in the merge_matmul graph transformation below that
+    splits the output from a merged matmul into the individual results for each
+    input tensor.
+
+    Arguments:
+        result: The merged matmul result tensor.
+        inputs: The list of inputs that were merged into one for the matmul.
+
+    Returns:
+        List of matmul results for each input tensor.
+    """
+    # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
+    # need an int even when tracing
+    if isinstance(result, torch.fx.Proxy):
+        splits = [0] * len(inputs)
+    else:
+        splits = [x.shape[0] for x in inputs]
+
+    return torch.split(result, splits)
+
+
+def may_depend_on(a: Node, b: Node, search_depth: int = 6):
+    """
+    Determine if one node depends on another in a torch.fx.Graph.
+
+    Arguments:
+        a: The node that may have a dependency on b.
+        b: The node that a may have a dependency on.
+        search_depth: In the case of an indirect dependency, this function
+                        searches up to this many nodes away in search of a
+                        data dependency. If none is found, the function
+                        makes the conservative assumption that there is a
+                        dependency.
+
+    Returns:
+        True if a may depend on b, False if it definitely does not.
+    """
+    # Equivalence is defined as dependence.
+    if a == b:
+        return True
+
+    # If a has no inputs, it cannot depend on b.
+    if len(a.all_input_nodes) == 0:
+        return False
+
+    # If the search depth has been exhausted and no conclusion has been
+    # reached, assume that there is a data dependency.
+    if search_depth == 0:
+        return True
+
+    # Recursively check all inputs of a.
+    for inp in a.all_input_nodes:
+        if may_depend_on(inp, b, search_depth - 1):
+            return True
+
+    return False
+
+
+def are_nodes_independent(nodes: List[Node]):
+    """
+    Check if all of the given nodes are pairwise-data independent.
+
+    Arguments:
+        nodes: The nodes to check for data dependencies.
+
+    Returns:
+        True if no pair of nodes has a data dependency, False otherwise.
+    """
+    # For each pair in nodes:
+    for i, j in itertools.combinations(nodes, 2):
+        if may_depend_on(i, j) or may_depend_on(j, i):
+            return False
+
+    return True
+
+
+def merge_matmul(in_mod: torch.nn.Module):
+    """
+    A graph transformation that merges matrix multiplication operations that share the same right-hand
+    side operand into one large matrix multiplication.
+               ____      _________        _________
+      ----    |    |    |         |     M|  A * C  |
+    M| A  |  T| B  | * K|    C    | =    |---------|
+      ---- ,  |    |    |         |     T|  B * C  |
+       K       ----      ---------        ---------
+                K            R                R
+    """
+    gm = symbolic_trace(in_mod)
+
+    rhs_users: Dict[Node, List[Node]] = {}
+    lhs_users: Dict[Node, List[Node]] = {}
+
+    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
+    # the matmul of which they are the LHS/RHS.
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target is not torch.matmul:
+            continue
+
+        lhs, rhs = node.args
+
+        # TODO: Properly handle aliasing caused by get_attr. For now,
+        # use the attribute name as the operand if the node is a
+        # get_attr.
+        lhs = lhs.target if lhs.op == "get_attr" else lhs
+        rhs = rhs.target if rhs.op == "get_attr" else rhs
+
+        lhs_users.setdefault(lhs, []).append(node)
+        rhs_users.setdefault(rhs, []).append(node)
+
+    for rhs, mms in rhs_users.items():
+        # There must be at least two matmuls for a merge to make sense.
+        if len(mms) < 2:
+            continue
+
+        # All matmuls must not depend on each other directly or indirectly
+        # in order for the merge to be possible.
+        if not are_nodes_independent(mms):
+            continue
+
+        lhs_vals = [mm.args[0] for mm in mms]
+
+        # Merge the matmul.
+        # Collect a list of LHS operands and the single RHS operand.
+        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
+        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs
+
+        # Concatenate all the LHS operands.
+        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})
+
+        # Multiply the concatenated LHS operands with the one RHS. This will produce
+        # the same results as all the individual matmuls involving rhs in the original graph,
+        # but they will all be concatenated together.
+        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
+
+        # Split the result of the merged matmul using the shapes of the LHS operands
+        # to ascertain how large each chunk should be.
+        merge_mm_split = gm.graph.call_function(
+            split_result_tensors, (merge_mm, lhs), {}
+        )
+        merge_mm_res = [
+            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
+            for out in range(len(lhs))
+        ]
+
+        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
+        for old, new in zip(mms, merge_mm_res):
+            old.replace_all_uses_with(new)
+            gm.graph.erase_node(old)
+
+        # All of the new nodes created above were inserted at the end, so we need to sort
+        # the nodes topologically to make sure all definitions precede uses.
+        legalize_graph(gm)
+
+    gm.recompile()
+    gm.graph.lint()
+    return gm
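+
+# Usage sketch (illustrative; `TwoMatmuls` is a hypothetical module computing
+# torch.matmul(a, w) and torch.matmul(b, w) with a shared right-hand side w):
+#   merged = merge_matmul(TwoMatmuls())
+# `merged` produces the same outputs via a single concatenated matmul followed by a split.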
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/meta_tracer.py b/MLPY/Lib/site-packages/torch/fx/experimental/meta_tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..143c96be65de2d0931ef26b93c6d12b3f8efe6d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/meta_tracer.py
@@ -0,0 +1,268 @@
+import torch
+import torch.fx
+import warnings
+import functools
+import builtins
+
+from typing import Any, Callable, Dict, Optional, Union
+
+def embedding_override(self, input):
+    return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
+
+
+def nn_layernorm_override(self, input):
+    return input
+
+
+def torch_relu_override(x):
+    return x
+
+
+def torch_nn_relu_override(self, x):
+    return x
+
+
+def functional_relu_override(x, inplace=False):
+    assert not inplace, "don't support in-place functional.relu for metatensor analysis"
+    return x
+
+
+def torch_where_override(condition, x, y):
+    # torch.where returns the broadcasted tensor of condition, x, and y,
+    # so hack it by using addition
+    return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
+
+
+def torch_abs_override(input, *, out=None):
+    assert out is None, "Don't support in-place abs for MetaTensor analysis"
+    return input
+
+manual_meta_overrides : Dict[Callable, Callable] = {
+    torch.nn.Embedding: embedding_override,
+    torch.nn.LayerNorm: nn_layernorm_override,
+    torch.relu: torch_relu_override,
+    torch.nn.functional.relu: functional_relu_override,
+    torch.nn.ReLU: torch_nn_relu_override,
+    torch.where: torch_where_override,
+    torch.abs: torch_abs_override,
+}
+
+def gen_constructor_wrapper(target):
+    @functools.wraps(target)
+    def wrapper(*args, **kwargs):
+        proxy = None
+
+        def check_has_proxy(v):
+            if isinstance(v, torch.fx.Proxy):
+                nonlocal proxy
+                proxy = v
+        torch.fx.node.map_aggregate(args, check_has_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
+
+        if proxy is not None:
+            return proxy.tracer.create_proxy('call_function', target, args, kwargs)
+        else:
+            return target(*args, **kwargs)
+    return wrapper, target
+
+class MetaProxy(torch.fx.Proxy):
+    def install_tensor_meta(self, tensor_meta):
+        self._tensor_meta = tensor_meta
+
+    def size(self, dim=None):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.size(*[dim] if dim else [])
+        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
+
+    def dim(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.dim()
+        return self.tracer.create_proxy('call_method', 'dim', (self,), {})
+
+    @property
+    def shape(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.shape
+        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
+
+    @property
+    def dtype(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.dtype
+        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
+
+    @property
+    def device(self):
+        # Hack so we can track when devices are used. During meta-tensor propagation,
+        # replace these values with a constant 'meta'
+        return MetaDeviceAttribute(self, 'device')
+
+    def __getattr__(self, k):
+        if k == '_tensor_meta':
+            return self.__getattribute__(k)
+        # note: not added to the graph yet, if this is a method call
+        # we peephole optimize to the method invocation
+        return MetaAttribute(self, k)
+
+class MetaAttribute(MetaProxy):
+    def __init__(self, root, attr: str):
+
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._node = None
+
+    @property
+    def node(self):
+        # the node for attributes is added lazily, since most will just be method calls
+        # which do not rely on the getitem call
+        if self._node is None:
+            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+class MetaDeviceAttribute(MetaAttribute):
+    pass
+
+def proxys_to_metas(v):
+    if isinstance(v, MetaDeviceAttribute):
+        return 'meta'
+    if isinstance(v, torch.fx.Proxy):
+        assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
+        assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
+        return v._tensor_meta
+    return v
+
+class MetaTracer(torch.fx.Tracer):
+    allow_insert_stateless_mods : bool = True
+
+    _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
+
+    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
+        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+        if kind == 'placeholder' and target in self.meta_args:
+            rv.install_tensor_meta(self.meta_args[target])
+            return rv
+
+        if target in self.orig_fns:
+            # NOTE: tensor constructors in PyTorch define the `device` argument as
+            # *keyword-only*. That is why this works. If you add methods to
+            # _TORCH_METHODS_TO_PATCH that do not define `device` as keyword-only,
+            # this will break and you will likely see issues where we cannot infer
+            # the size of the output.
+            if 'device' in kwargs:
+                kwargs['device'] = 'meta'
+
+        try:
+            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
+            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
+
+            if kind == 'call_function':
+                meta_target = manual_meta_overrides.get(target, target)
+                meta_out = meta_target(*args_metas, **kwargs_metas)
+            elif kind == 'call_method':
+                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
+            elif kind == 'call_module':
+                assert hasattr(self, 'orig_forward')
+                self._disable_module_getattr = True
+                try:
+                    mod = self.root.get_submodule(target)
+                    mod_type = type(mod)
+                    if mod_type in manual_meta_overrides:
+                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
+                    else:
+                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+                finally:
+                    self._disable_module_getattr = False
+            elif kind == 'get_attr':
+                self._disable_module_getattr = True
+                try:
+                    attr_itr = self.root
+                    atoms = target.split('.')
+                    for atom in atoms:
+                        attr_itr = getattr(attr_itr, atom)
+                    assert isinstance(attr_itr, torch.Tensor)
+                    meta_out = attr_itr.to(device='meta')
+                finally:
+                    self._disable_module_getattr = False
+            else:
+                return rv
+
+            # TODO
+            assert isinstance(rv, torch.fx.Proxy), "Don't support composite output yet"
+            rv.install_tensor_meta(meta_out)
+        except Exception as e:
+            warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
+
+        return rv
+
+    def getattr(self, attr, attr_val, parameter_proxy_cache):
+        if getattr(self, '_disable_module_getattr', False):
+            return attr_val
+        else:
+            return super().getattr(attr, attr_val, parameter_proxy_cache)
+
+    def call_module(self, m, forward, args, kwargs):
+        self.orig_forward = forward
+        return super().call_module(m, forward, args, kwargs)
+
+    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
+        """
+        Helper method which tries to insert a module that was not declared as submodule.
+        """
+        idx = 0
+        mod_name = mod.__class__.__name__.lower()
+        path = f"{mod_name}_{idx}"
+        while hasattr(self.root, path):
+            path = f"{mod_name}_{idx}"
+            idx += 1
+
+        self.root.add_module(path, mod)
+        return path
+
+    def path_of_module(self, mod: torch.nn.Module) -> str:
+        try:
+            return super().path_of_module(mod)
+        except NameError as e:
+            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
+                path = self._insert_module_as_submodule(mod)
+                self.prev_module = path
+                return path
+            raise
+
+    def proxy(self, node):
+        return MetaProxy(node, self)
+
+    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
+        assert isinstance(meta_args, dict)
+        self.meta_args = meta_args
+
+        self.patched_torch_methods = {
+            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+        }
+        self.orig_fns = set()
+
+        for name, (wrapper, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, wrapper)
+            self.orig_fns.add(orig)
+
+        try:
+            graph = super().trace(root, concrete_args)
+            graph._tracer_extras = {'meta_args': meta_args}
+            return graph
+        finally:
+            for name, (_, orig) in self.patched_torch_methods.items():
+                setattr(torch, name, orig)
+
+
+def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
+                   meta_args : Optional[Dict[str, torch.Tensor]] = None,
+                   concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
+    tracer = MetaTracer()
+    graph = tracer.trace(root, meta_args, concrete_args)
+    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    gm = torch.fx.GraphModule(tracer.root, graph, name)
+    return gm
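+
+# Usage sketch (illustrative; `MyModule` is a placeholder for a module taking one
+# tensor argument named `x`):
+#   gm = symbolic_trace(MyModule(), meta_args={'x': torch.empty(2, 3, device='meta')})
+# Shape and dtype information is propagated through meta tensors during tracing, so
+# data-dependent accesses such as `x.shape` resolve to concrete values.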
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ecd695974a387abdb7fcb3618ddd14e6b82b0ef
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4fc77613d9b6045f052969042946aa8eb3f37fc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ef26ebb6e7c993de92b1ee7f91f47c12bc01021
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f35becf9565832439eca13c5188318830d7c1363
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13418d3a9018626d3394e019432391c09c10cfac
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efa521704e2d028ad3625871d8b5dce1d692b871
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ac5b598e7b80738631efed7315124e4267edce2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c82ea5aebef10983a91ff568ebc68e40ca02d280
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b187f07878befdc1ef027fa35bbf83f3d542040a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
@@ -0,0 +1,557 @@
+from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
+    op_mod, op_gt, op_lt, op_neq, op_eq
+from torch.fx.tensor_type import TensorType, Dyn
+
+
+class Constraint:
+    pass
+
+
+class Conj(Constraint):
+    def __init__(self, conjuncts):
+        """
+        :param conjuncts: Conjunction of constraints
+        """
+        self.conjucts = conjuncts
+
+    def __eq__(self, other):
+        if isinstance(other, Conj):
+            return self.conjucts == other.conjucts
+        else:
+            return False
+
+    def __repr__(self):
+        return f'And({self.conjucts})'
+
+
+class Disj(Constraint):
+    def __init__(self, disjuncts):
+        """
+        :param disjuncts: Disjunction of constraints
+        """
+        self.disjuncts = disjuncts
+
+    def __eq__(self, other):
+        if isinstance(other, Disj):
+            return self.disjuncts == other.disjuncts
+        else:
+            return False
+
+    def __repr__(self):
+        return f'Or({self.disjuncts})'
+
+
+class Prod(Constraint):
+    def __init__(self, products):
+        """
+        :param products: lists of dimensions to multiply
+        """
+        self.products = products
+
+    def __eq__(self, other):
+        if isinstance(other, Prod):
+            return self.products == other.products
+        else:
+            return False
+
+    def __repr__(self):
+        return f'Product({self.products})'
+
+
+class T(Constraint):
+    """
+    True
+    """
+    def __init__(self):
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, T)
+
+    def __repr__(self):
+        return 'True'
+
+class F(Constraint):
+    """
+    False
+    """
+    def __init__(self):
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, F)
+
+    def __repr__(self):
+        return 'False'
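+
+# Illustrative example: constraints compose structurally and render readably:
+#   repr(Conj([T(), Disj([F(), T()])]))  ->  'And([True, Or([False, True])])'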
+
+
+class BinaryConstraint(Constraint):
+    """
+    Represents all binary operations
+    """
+    def __init__(self, lhs, rhs, op):
+        """
+        :param lhs: lhs of the constraint
+        :param rhs: rhs of the constraint
+        :param op: string representing the operation
+        """
+        self.lhs = lhs
+        self.rhs = rhs
+        self.op = op
+
+    def __eq__(self, other):
+        if isinstance(other, BinaryConstraint):
+            return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
+        else:
+            return False
+
+    def __repr__(self):
+        return f'({self.lhs} {self.op} {self.rhs})'
+
+
+class BinConstraintT(BinaryConstraint):
+    """
+    Binary constraints about tensors
+    """
+    def __init__(self, lhs, rhs, op):
+        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
+               (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
+        super().__init__(lhs, rhs, op)
+
+    def __eq__(self, other):
+        return super().__eq__(other)
+
+
+class BinConstraintD(BinaryConstraint):
+    """
+    Binary constraints about dimensions
+    """
+    def __init__(self, lhs, rhs, op):
+        assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
+        assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
+
+        super().__init__(lhs, rhs, op)
+
+    def __eq__(self, other):
+        return super().__eq__(other)
+
+
+
+class TGreatestUpperBound(Constraint):
+    """
+    Greatest Upper bound for tensors with dynamic type
+    """
+    def __init__(self, res, rhs1, rhs2):
+        """
+        :param res: tensor variable that stores the result of the operation
+        :param rhs1: tensor or tensor variable
+        :param rhs2: tensor or tensor variable
+        """
+        self.res = res
+        self.rhs1 = rhs1
+        self.rhs2 = rhs2
+
+    def __repr__(self):
+        return f'{self.res} = {self.rhs1}⊔*{self.rhs2}'
+
+    def __eq__(self, other):
+        if isinstance(other, TGreatestUpperBound):
+            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
+        else:
+            return False
+
+
+class DGreatestUpperBound(Constraint):
+    """
+    Greatest Upper bound for dimensions
+    """
+    def __init__(self, res, rhs1, rhs2):
+        """
+        :param res: Dimension variable to store the result
+        :param rhs1: dimension variable 1
+        :param rhs2: dimension variable 2
+        """
+        assert is_dim(res)
+        assert is_dim(rhs1)
+        assert is_dim(rhs2)
+
+        self.res = res
+        self.rhs1 = rhs1
+        self.rhs2 = rhs2
+
+    def __repr__(self):
+        return f'{self.res} = {self.rhs1}⊔{self.rhs2}'
+
+    def __eq__(self, other):
+        if isinstance(other, DGreatestUpperBound):
+            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
+        else:
+            return False
+
+
+class CanReshape(Constraint):
+    """
+    can_reshape constraint
+    """
+    def __init__(self, src, target):
+        """
+        :param src: tensor variable
+        :param target: tensor
+        """
+        self.src = src
+        self.target = target
+
+    def __repr__(self):
+        return f'can-reshape({self.src}, {self.target})'
+
+    def __eq__(self, other):
+        if isinstance(other, CanReshape):
+            return self.src == other.src and self.target == other.target
+        else:
+            return False
+
+
+class IndexSelect(Constraint):
+
+    def __init__(self, tensor_size, input_var, dim_replace, index, output):
+        """
+        Args:
+            input_var: input to index_select
+            tensor_size: tensor size we are considering
+            dim_replace: the dimension of the output at "index"
+            index: location of the dimension to replace in the input
+            output: variable to store the result
+        """
+        assert isinstance(input_var, TVar)
+        assert isinstance(output, TVar)
+        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
+        assert isinstance(index, int)
+
+        self.input_var = input_var
+        self.tensor_size = tensor_size
+        self.dim_replace = dim_replace
+        self.index = index
+        self.output = output
+
+    def __repr__(self):
+
+        return f' {self.output} = ' \
+               f'IndexSelect({self.input_var}, ' \
+               f'tensor_size: {self.tensor_size}, ' \
+               f'{self.dim_replace}, ' \
+               f'{self.index})'
+
+    def __eq__(self, other):
+        if isinstance(other, IndexSelect):
+            return self.tensor_size == other.tensor_size and \
+                self.dim_replace == other.dim_replace and \
+                self.index == other.index and \
+                self.output == other.output and \
+                self.input_var == other.input_var
+        else:
+            return False
+
+
+class Transpose(Constraint):
+
+    def __init__(self, tensor_size, input_var, index1, index2, output):
+        """
+        Args:
+            tensor_size: current tensor size
+            input_var: variable to hold input
+            index1: dimension 1
+            index2: dimension 2
+            output: output that stores result
+        """
+        assert isinstance(input_var, TVar)
+        assert isinstance(output, TVar)
+        assert isinstance(index1, int)
+        assert isinstance(index2, int)
+
+        self.input_var = input_var
+        self.tensor_size = tensor_size
+        self.index1 = index1
+        self.index2 = index2
+        self.output = output
+
+    def __repr__(self):
+
+        return f' {self.output} = ' \
+               f'Transpose({self.input_var}, ' \
+               f'tensor_size: {self.tensor_size}, ' \
+               f'{self.index1}, ' \
+               f'{self.index2})'
+
+    def __eq__(self, other):
+        if isinstance(other, Transpose):
+            return self.tensor_size == other.tensor_size and \
+                self.index1 == other.index1 and \
+                self.index2 == other.index2 and \
+                self.output == other.output and \
+                self.input_var == other.input_var
+        else:
+            return False
+
+
+class GetItem(Constraint):
+
+    def __init__(self, tensor_size, index, res, input_var):
+        """
+        Constraint for getting an item given a tensor size
+        :param tensor_size: actual number representing the rank
+        :param index: actual number representing the index
+        :param res: dimension variable to carry the item we get
+        :param input_var: a tensor variable from which we will get item
+        """
+        assert isinstance(res, DVar)
+
+        self.res = res
+        self.tensor_size = tensor_size
+        self.index = index
+        self.input_var = input_var
+
+    def __repr__(self):
+        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
+
+    def __eq__(self, other):
+        if isinstance(other, GetItem):
+            return self.res == other.res and \
+                self.tensor_size == other.tensor_size and \
+                self.index == other.index and \
+                self.input_var == other.input_var
+        else:
+            return False
+
+class GetItemTensor(Constraint):
+
+    def __init__(self, tensor_size, index_tuple, res, input_var):
+        """
+        Constraint for getting an item given a tensor size.
+        When the index argument is a tuple, the result is
+        expected to be a tensor rather than a dimension.
+        :param tensor_size: actual number representing the rank
+        :param index_tuple: tuple for indexing
+        :param res: tensor variable to carry the item we get
+        :param input_var: a tensor variable from which we will get item
+        """
+        assert isinstance(res, TVar)
+
+        self.res = res
+        self.tensor_size = tensor_size
+        self.index_tuple = index_tuple
+        self.input_var = input_var
+
+    def __repr__(self):
+        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
+
+    def __eq__(self, other):
+        if isinstance(other, GetItemTensor):
+            return self.res == other.res and \
+                self.tensor_size == other.tensor_size and \
+                self.index_tuple == other.index_tuple and \
+                self.input_var == other.input_var
+        else:
+            return False
+
+class CalcConv(Constraint):
+
+    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
+        """
+        :param conv_result: the convolution result
+        :param input_var: input to convolution
+        :param c_out: output channel type
+        :param kernel: kernel tuple
+        """
+        self.conv_result = conv_result
+        self.input_var = input_var
+        self.c_out = c_out
+        self.kernel = kernel
+        self.padding = padding
+        self.stride = stride
+        self.dilation = dilation
+        self.matching_constraint = matching_constraint_vars
+
+    def __repr__(self):
+        return f'{self.conv_result} =' \
+               f' calc-conv({self.input_var},' \
+               f' {self.c_out}, {self.kernel}, ' \
+               f'{self.padding}, {self.stride},' \
+               f' {self.dilation})'
+
+    def __eq__(self, other):
+        if isinstance(other, CalcConv):
+            return self.conv_result == other.conv_result and self.input_var == other.input_var and \
+                self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
+                and self.stride == other.stride and self.dilation == other.dilation \
+                and self.matching_constraint == other.matching_constraint
+        else:
+            return False
+
+
+class CalcMaxPool(Constraint):
+
+    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
+        """
+        :param maxpool_result: the result of maxpool
+        :param input_var: input to maxpool
+        :param kernel: kernel tuple
+        """
+        self.maxpool_result = maxpool_result
+        self.input_var = input_var
+        self.kernel = kernel
+        self.padding = padding
+        self.stride = stride
+        self.dilation = dilation
+        self.matching_constraint = matching_constraint_vars
+
+    def __repr__(self):
+        return f'{self.maxpool_result} =' \
+               f' calc-maxpool({self.input_var},' \
+               f'  {self.kernel}, ' \
+               f'{self.padding}, {self.stride},' \
+               f' {self.dilation})'
+
+    def __eq__(self, other):
+        if isinstance(other, CalcMaxPool):
+            return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
+                and self.kernel == other.kernel and self.padding == other.padding \
+                and self.stride == other.stride and self.dilation == other.dilation \
+                and self.matching_constraint == other.matching_constraint
+        else:
+            return False
+
+
+class ApplyBroadcasting(Constraint):
+    def __init__(self, res1, res2, input1, input2):
+        """
+        :param res1: resulting tensor 1
+        :param res2: resulting tensor 2
+        :param input1: tensor variable 1
+        :param input2: tensor variable 2
+        """
+        self.res1 = res1
+        self.res2 = res2
+        self.input1 = input1
+        self.input2 = input2
+
+    def __eq__(self, other):
+        if isinstance(other, ApplyBroadcasting):
+            return self.res1 == other.res1 \
+                and self.res2 == other.res2 \
+                and self.input1 == other.input1 \
+                and self.input2 == other.input2
+        else:
+            return False
+
+    def __repr__(self):
+        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'
+
+
+class CalcProduct(Constraint):
+    """
+    Given correct dimensions, calculate the product for flatten accounting for Dyn
+    """
+    def __init__(self, start, end, flattened, dims_to_flatten):
+        """
+        :param start: start index
+        :param end: end index
+        :param flattened: variable to store the product
+        :param dims_to_flatten: the dimensions that we will flatten
+        """
+        assert isinstance(dims_to_flatten, list)
+        assert isinstance(flattened, TVar)
+        assert isinstance(start, int)
+        assert isinstance(end, int)
+
+        self.start = start
+        self.end = end
+        self.dims_to_flatten = dims_to_flatten
+        self.flattened = flattened
+
+    def __eq__(self, other):
+        if isinstance(other, CalcProduct):
+            return self.start == other.start and self.end == other.end and \
+                self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
+
+        else:
+            return False
+
+    def __repr__(self):
+        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
+
+
+class TVar:
+    """
+    Tensor variable with no tensor constructor
+    """
+    def __init__(self, tvar):
+        """
+        :param tvar: tensor variable
+        """
+        self.tvar = tvar
+
+    def __repr__(self):
+        return f'TV({self.tvar})'
+
+    def __eq__(self, other):
+        if isinstance(other, TVar):
+            return self.tvar == other.tvar
+        else:
+            return False
+
+
+class DVar:
+    """
+    Dimension variable
+    """
+    def __init__(self, c):
+        """
+        :param c: character or number
+        """
+        self.c = c
+
+    def __repr__(self):
+        return f'DV({self.c})'
+
+    def __eq__(self, other):
+        if isinstance(other, DVar):
+            return self.c == other.c
+        else:
+            return False
+
+
+class BVar:
+    """
+    Boolean variable
+    """
+    def __init__(self, c):
+        """
+        :param c: character or number
+        """
+        self.c = c
+
+    def __repr__(self):
+        return f'BV({self.c})'
+
+    def __eq__(self, other):
+        if isinstance(other, BVar):
+            return self.c == other.c
+        else:
+            return False
+
+
+def is_algebraic_expression(constraint):
+    if isinstance(constraint, BinConstraintD):
+        return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
+    else:
+        return isinstance(constraint, Prod)
+
+
+def is_bool_expr(constraint):
+    if isinstance(constraint, BinConstraintD):
+        return constraint.op in [op_gt, op_lt, op_neq, op_eq]
+    else:
+        return isinstance(constraint, (BVar, Conj, Disj))
+
+def is_dim(d):
+    return isinstance(d, (DVar, int)) or d == Dyn
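+
+
+# Hedged usage sketch for the predicates above (illustrative helper, never called by
+# the solver): is_dim accepts dimension variables, ints and Dyn, while algebraic and
+# boolean expressions are recognized by the operation stored on a BinConstraintD.
+def _example_predicates():
+    d = DVar('d')
+    assert is_dim(d) and is_dim(3) and is_dim(Dyn)
+    assert is_algebraic_expression(BinConstraintD(d, 2, op_mul))
+    assert is_bool_expr(BinConstraintD(d, 2, op_lt))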
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a25bf62c824b97eb23852526a57baff91e5e517
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -0,0 +1,1279 @@
+import torch
+import operator
+import warnings
+from typing import Callable, Dict, Iterable
+
+from torch.fx._symbolic_trace import _assert_is_none
+from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
+    Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
+    TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
+from torch.fx.experimental.migrate_gradual_types.operation import \
+    op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
+from torch.fx.node import Target, Node
+from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
+    gen_bvar
+
+from torch.fx.tensor_type import Dyn, TensorType
+from torch.nn.modules.conv import Conv2d
+from torch.nn.modules.batchnorm import BatchNorm2d
+
+_INFERENCE_RULES: Dict[Target, Callable] = {}
+
+MAX_TENSOR_RANK = 4
+
+def register_inference_rule(call_target):
+    def register(fn):
+        if call_target in _INFERENCE_RULES:
+            raise RuntimeError(f'Inference rule already registered for {call_target}!')
+        _INFERENCE_RULES[call_target] = fn
+        return fn
+    return register
+
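+
+# Hedged sketch of how a rule is registered (the target `my_op` below is hypothetical;
+# the real rules for torch/operator targets follow). A rule receives the node, the
+# symbol table, the constraints accumulated so far and the fresh-variable counter, and
+# returns a list of constraints together with the updated counter:
+#
+#   @register_inference_rule(my_op)
+#   def my_op_inference_rule(n: Node, symbols, constraints, counter):
+#       output, counter = gen_tvar(counter)
+#       symbols[n] = output
+#       return [BinConstraintT(symbols[n.args[0]], output, op_eq)], counter
+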
+
+def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
+    d, counter = gen_tensor_dims(n, counter)
+    c1 = BinConstraintT(input, TensorType(d), op_eq)
+    start_dim = n if start_dim == -1 else abs(start_dim)
+    end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1
+    c2 = CalcProduct(start_dim, end_dim, flattened, d)
+    nat_constraints = gen_nat_constraints(d)
+    return Conj([c1, c2, *nat_constraints]), counter
+
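+# Worked example for the index normalization above (hedged, illustrative): for a
+# rank-4 input with start_dim=1 and end_dim=-1, start_dim stays abs(1) = 1 and
+# end_dim becomes 4 + (-1) + 1 = 4, so CalcProduct is asked for the product of the
+# generated dims d between indices 1 and 4.
+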
+
+@register_inference_rule(getattr)
+def get_attr_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    If the attribute is "device" then the tensor shape is preserved
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], str)
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    input = symbols[n.args[0]]
+    attr = n.args[1]
+
+    if attr == 'device':
+        return [BinConstraintT(input, output, op_eq)], counter
+    else:
+        raise NotImplementedError('Not yet implemented')
+
+@register_inference_rule(torch.bmm)
+def bmm_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Constraints that match the input to a size 3 tensor
+    and switch the dimensions according to the rules
+    of batch multiplication
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    bmm_output, counter = gen_tvar(counter)
+    symbols[n] = bmm_output
+
+    bmm_input1 = symbols[n.args[0]]
+    bmm_input2 = symbols[n.args[1]]
+
+    dims_input1, counter = gen_tensor_dims(3, counter)
+    dims_input2, counter = gen_tensor_dims(3, counter)
+
+    inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
+                       BinConstraintT(bmm_input2, Dyn, op_eq),
+                       BinConstraintT(bmm_output, Dyn, op_eq)])
+
+    input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
+                       BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+                       BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])
+
+    input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
+                       BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+                       BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])
+
+    consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]
+
+    batch_size, counter = gen_dvar(counter)
+
+    inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+                               BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+                               BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
+                               *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])
+
+    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
+
+
+@register_inference_rule("index_select")
+def index_select_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We constrain the second argument to a vector or Dyn.
+    The output replaces the input with the shape of the vector
+    at the position given by the index (first argument)
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], int)
+    assert isinstance(n.args[2], Node)
+
+
+
+    index_select, counter = gen_tvar(counter)
+    symbols[n] = index_select
+
+    dims, counter = gen_tensor_dims(1, counter)
+
+    # equality constraint
+    is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
+    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
+
+    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
+                                for i in range(MAX_TENSOR_RANK)])])
+    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
+                             for i in range(MAX_TENSOR_RANK)])])
+
+    return [Disj([c2, c3])], counter
+
+
+@register_inference_rule("expand")
+def expand_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the same constraints as we do for tensor addition, but we constrain
+    the rank of this expression to be equal to len(n.args[1:]) so that only
+    those cases get considered for the output
+    """
+    assert isinstance(n.args[0], Node)
+
+    # define the output for expand
+    expand, counter = gen_tvar(counter)
+    symbols[n] = expand
+
+    # since we do not have two nodes here, we will construct an argument variable
+    e1 = symbols[n.args[0]]
+    e2, counter = gen_tvar(counter)
+
+    e2_nat_constraints = []
+    for arg in n.args[1:]:
+        assert isinstance(arg, (Node, int))
+        if isinstance(arg, Node):
+            assert isinstance(symbols[arg], DVar)
+            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))
+
+    e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)
+
+    constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)
+
+    # constraint the output size
+    dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
+    nat_constraints = gen_nat_constraints(dims)
+    c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
+    constraints += c
+
+    return constraints, counter
+
+
+@register_inference_rule(torch.nn.functional.gelu)
+@register_inference_rule(torch.nn.functional.dropout)
+@register_inference_rule(torch.nn.functional.softmax)
+@register_inference_rule("detach")
+@register_inference_rule("to")
+@register_inference_rule("int")
+@register_inference_rule("long")
+@register_inference_rule("contiguous")
+@register_inference_rule(torch.ones)
+@register_inference_rule(torch.zeros)
+def equality_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    if isinstance(n.args[0], Node):
+        input = symbols[n.args[0]]
+        if isinstance(input, TVar):
+            return [BinConstraintT(input, output, op_eq)], counter
+
+        # then we have dimension variables
+        else:
+            for arg in n.args:
+                assert isinstance(symbols[arg], DVar)
+        my_size = [symbols[arg] for arg in n.args]
+        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
+
+    elif isinstance(n.args[0], tuple):
+        # then the tuple is the size
+        assert len(n.args[0]) <= 4
+        my_size = [symbols[arg] for arg in n.args[0]]
+        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
+    else:
+        raise NotImplementedError('Method not yet implemented')
+
+
+@register_inference_rule("transpose")
+def transpose_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Can be considered as a sequence of two index selects, so we generate constraints accordingly
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], int)
+    assert isinstance(n.args[2], int)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    from_arg = symbols[n.args[0]]
+    assert isinstance(from_arg, TVar)
+
+    # input and output are dyn
+    is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])
+
+    # or input is a tensor and we actually do the replacement
+    c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])
+
+    return [Disj([is_dyn, c3])], counter
+
+
+@register_inference_rule("type_as")
+def type_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+
+    from_arg = symbols[n.args[0]]
+    to_arg = symbols[n.args[1]]
+
+    assert isinstance(from_arg, TVar)
+    assert isinstance(to_arg, TVar)
+
+    return [BinConstraintT(from_arg, to_arg, op_consistency),
+            BinConstraintT(output, to_arg, op_eq)], counter
+
+@register_inference_rule("masked_fill_")
+def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Similar to addition. For now we implement the constraints when
+    the argument is a boolean tensor. There is also a case for when
+    it is a condition. We will leave this out for now.
+    """
+
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], Node)
+
+    # We will retrieve the type variables from the symbol table
+    # and confirm they are tensor variables
+
+    e1 = symbols[n.args[0]]
+    e2 = symbols[n.args[1]]
+
+    if isinstance(e1, TVar) and isinstance(e2, TVar):
+        masked_fill_tensor, counter = gen_tvar(counter)
+        symbols[n] = masked_fill_tensor
+        return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
+    else:
+        raise NotImplementedError('Not yet implemented')
+
+
+@register_inference_rule(torch.nn.functional.embedding)
+def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    embedding_dim_weights = symbols[n.args[1]]
+
+    # will treat this as a static shape. So we will not use matching.
+    weight_dims, counter = gen_tensor_dims(2, counter)
+    equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
+    embedding_dim = weight_dims[1]
+    constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
+    return [equality_constraint] + constraints, counter
+
+
+@register_inference_rule(torch.nn.modules.sparse.Embedding)
+def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    The output shape differs from the input shape in the last dimension
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
+
+
+def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
+
+    embedding_output, counter = gen_tvar(counter)
+    symbols[n] = embedding_output
+    embedding_input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
+    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+    c2 = []
+
+    for i in range(1, MAX_TENSOR_RANK):
+        new_dims, counter = gen_tensor_dims(i, counter)
+        nat_constraints = gen_nat_constraints(new_dims)
+
+        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
+        c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
+                           BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
+                          nat_constraints)
+        c2.append(c_tensor_i)
+
+    return [Disj([c1, Disj(c2)])], counter
+
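+# Hedged reading of gen_embedding_rules (illustrative): if the input is typed
+# TensorType([b, s]), the disjunction above forces the embedding output to
+# TensorType([b, s, embedding_dim]); if the input is Dyn, the output is Dyn as well.
+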
+
+@register_inference_rule(torch.tensor)
+def tensor_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    If the tensor is a scalar, we will skip it since we
+    do not support scalars yet. We will add support in the future
+    if it's needed. For our examples so far, scalars are not needed.
+    """
+    return [], counter
+
+
+@register_inference_rule("reshape")
+@register_inference_rule("view")
+def view_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Similar to reshape but with an extra condition on the strides
+    """
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    my_view, counter = gen_tvar(counter)
+    symbols[n] = my_view
+
+
+    src_var = symbols[n.args[0]]
+    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]  # target shape
+    t2_type = []
+    num_constraints = []
+
+    for t in t2:
+        if t == -1:
+            var, counter = gen_dvar(counter)
+            t2_type.append(var)
+            num_constraints.append(BinConstraintD(var, Dyn, op_neq))
+
+        else:
+            num_constraints.append(BinConstraintD(t, Dyn, op_neq))
+            t2_type.append(t)
+
+    t2_type = TensorType(t2_type)  # type: ignore[assignment]
+
+    c1 = BinConstraintT(my_view, t2_type, op_eq)
+    c2 = CanReshape(src_var, t2_type)
+
+    # TODO: add the extra check mentioned here:
+    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
+
+    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
+
+
+@register_inference_rule("size")
+def size_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    The constraint is just lhs = rhs.
+    Ex: size = input_ids.size()
+    """
+
+
+    if len(n.args) == 1:
+        # generate the new variable
+        size, counter = gen_tvar(counter)
+        symbols[n] = size
+        input = symbols[n.args[0]]
+        c = BinConstraintT(input, size, op_eq)
+        return [c], counter
+
+    elif len(n.args) == 2:
+        # TODO: review this rule; should input = dyn; output = dyn be included here?
+        if isinstance(n.args[1], int):
+            # generate the new variable
+            size_index, counter = gen_dvar(counter)
+            symbols[n] = size_index
+            input = symbols[n.args[0]]
+            c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
+            c3 = BinConstraintD(0, size_index, op_leq)
+
+            input_dyn = BinConstraintT(input, Dyn, op_eq)
+            output_dyn = BinConstraintD(size_index, Dyn, op_eq)
+            c1 = Conj([input_dyn, output_dyn])
+
+            return [Disj([c1, Conj([Disj(c2), c3])])], counter
+
+        else:
+            raise NotImplementedError
+
+    else:
+        raise NotImplementedError
+
+
+def range_check(i, n):
+    """
+    Checks if an index i is within range of a size n list
+    Args:
+        i: index
+        n: list size
+
+    Returns: Boolean
+    """
+    if i >= 0:
+        return T() if i < n else F()
+    else:
+        return T() if i >= -n else F()
+
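+# Hedged examples for range_check (illustrative): range_check(2, 4) -> T() and
+# range_check(5, 4) -> F(); with the usual negative-index convention,
+# range_check(-1, 4) -> T() while range_check(-5, 4) -> F().
+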
+
+@register_inference_rule(torch.cumsum)
+def cumsum_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal
+    We should verify that the index is valid
+    """
+    assert isinstance(n.args[0], Node)
+    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
+    assert isinstance(arg_1, int)
+
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintT(output, Dyn, op_eq)
+    c1 = Conj([input_dyn, output_dyn])
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims)
+
+        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
+                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
+                          [range_check(arg_1, i)] + nat_constraints)
+
+        c2.append(c_tensor_i)
+    dyn_or_tensor = Disj([c1, Disj(c2)])
+    return [dyn_or_tensor], counter
+
+
+@register_inference_rule(_assert_is_none)
+def assert_inference_rule(n: Node, symbols, constraints, counter):
+    assert len(n.users) == 0
+    return [], counter
+
+
+@register_inference_rule(operator.getitem)
+def getitem_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # dimension output case
+    if isinstance(n.args[1], int):
+        # create and store the new dimension variable
+        get_item_output, counter = gen_dvar(counter)
+        symbols[n] = get_item_output
+
+        # retrieve arg variables
+        get_item_arg = symbols[n.args[0]]
+        assert isinstance(get_item_arg, TVar)
+
+
+        # if the input is dynamic, we accept any index and return
+        # a dynamic dimension as output
+        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
+        output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
+        c1 = Conj([input_dyn, output_dyn])
+
+        # if the input is a tensor,
+        # generate a getItem constraint which will be expanded based on the
+        # tensor dimension.
+
+        c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]
+
+
+        # since the output is a dimension, we make sure it's a natural number
+        # added as a conjunction to the disjunction of c2
+        c3 = BinConstraintD(0, get_item_output, op_leq)
+        return [Disj([c1, Conj([Disj(c2), c3])])], counter
+
+    # tensor output case
+    elif isinstance(n.args[1], tuple):
+        # create and store the new tensor variable
+        get_item_output, counter = gen_tvar(counter)
+        symbols[n] = get_item_output
+
+        # retrieve arg variables
+        if n.args[0] in symbols:
+            get_item_arg = symbols[n.args[0]]
+            assert isinstance(get_item_arg, TVar)
+
+            input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
+            output_dyn = BinConstraintT(get_item_output, Dyn, op_eq)  # type: ignore[assignment]
+            c1 = Conj([input_dyn, output_dyn])
+
+            c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
+                  for i in range(MAX_TENSOR_RANK)]
+        else:
+            # TODO: we should figure out why there is a key-error here.
+            return [], counter
+
+        return [Disj([c1, *c2])], counter
+
+    else:
+        raise RuntimeError('Method not yet implemented')
+
+
+@register_inference_rule(operator.gt)
+def gt_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    # We make sure this node will not be used again. We do not
+    # generate a constraint about that node. Only about the operands.
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            gt_tensor, counter = gen_tvar(counter)
+            symbols[n] = gt_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            gt_constraint = BinConstraintD(e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError('Sort Mismatch')
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            gt_constraint = BinConstraintD(e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        elif isinstance(e1, TVar) and isinstance(e2, int):
+            # then we made the wrong assumption about the argument being a tensor
+            # so we should fix the assumption
+            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')
+
+            new_e1, counter = gen_dvar(counter)
+            symbols[n.args[0]] = new_e1
+
+            gt_constraint = BinConstraintD(new_e1, e2, op_gt)
+
+            my_gt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise NotImplementedError('Method not yet implemented')
+
+    else:
+        raise NotImplementedError('Method not yet implemented')
+
+
+@register_inference_rule(operator.eq)
+def eq_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            eq_tensor, counter = gen_tvar(counter)
+            symbols[n] = eq_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            eq_constraint = BinConstraintD(e1, e2, op_eq)
+
+            my_eq, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError('Sort Mismatch')
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            eq_constraint = BinConstraintD(e1, e2, op_eq)
+
+            my_eq, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
+            return [equality_constraint], counter
+        else:
+            raise NotImplementedError('Method not yet implemented')
+    else:
+        raise NotImplementedError('Method not yet implemented')
+
+@register_inference_rule(operator.ne)
+def neq_inference_rule(n: Node, symbols, constraints, counter):
+    """
+    Translates to inconsistent in gradual types.
+    To prove inequality, we should prove that
+    tensors are either different sizes or
+    disagree on at least one dimension
+
+    This is a WIP (it works when the condition
+    is false; we are working on making this operation work
+    when the condition is true as well)
+    """
+    assert isinstance(n.args[0], Node)
+    assert isinstance(n.args[1], tuple)
+
+    # implementing for size 3 and 4
+    if len(n.args[1]) == 3:
+
+        assert isinstance(n.args[1][0], (Node, int))
+        assert isinstance(n.args[1][1], (Node, int))
+        assert isinstance(n.args[1][2], (Node, int))
+
+        lhs = symbols[n.args[0]]
+
+        b, counter = gen_tensor_dims(4, counter)
+        input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)
+
+        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
+        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
+        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
+
+        # dimensions not equal
+        neq_1 = BinConstraintD(d1, b[0], op_neq)
+        neq_2 = BinConstraintD(d2, b[1], op_neq)
+        neq_3 = BinConstraintD(d3, b[2], op_neq)
+
+        # dimensions inconsistent
+        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
+        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
+        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])
+
+        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])
+
+        # we are covering size 3 and 4 only for now
+        ne_constraint = Conj([input_is_size3, dims_inconsistent])
+
+        my_ne, counter = gen_bvar(counter)
+        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
+
+    elif len(n.args[1]) == 4:
+
+        assert isinstance(n.args[1][0], (Node, int))
+        assert isinstance(n.args[1][1], (Node, int))
+        assert isinstance(n.args[1][2], (Node, int))
+        assert isinstance(n.args[1][3], (Node, int))
+
+        lhs = symbols[n.args[0]]
+
+        b1, counter = gen_dvar(counter)
+        b2, counter = gen_dvar(counter)
+        b3, counter = gen_dvar(counter)
+        b4, counter = gen_dvar(counter)
+
+        input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)
+
+        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
+        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
+        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
+        d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]
+
+        # dimensions not equal
+        neq_1 = BinConstraintD(d1, b1, op_neq)
+        neq_2 = BinConstraintD(d2, b2, op_neq)
+        neq_3 = BinConstraintD(d3, b3, op_neq)
+        neq_4 = BinConstraintD(d4, b4, op_neq)
+
+        # dimensions inconsistent
+        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
+        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
+        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
+        dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq), neq_4])
+
+        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])
+
+        ne_constraint = Conj([input_is_size4, dims_inconsistent])
+
+        my_ne, counter = gen_bvar(counter)
+
+        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
+
+    else:
+        raise NotImplementedError('Method not yet implemented')
+
+    return [equality_constraint], counter
+
+
+@register_inference_rule(operator.lt)
+def lt_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], (Node, int))
+    assert isinstance(n.args[1], (Node, int))
+
+    # We make sure this node will not be used again. We do not
+    # generate a constraint about that node. Only about the operands.
+
+    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
+    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(e1, TVar) and isinstance(e2, TVar):
+            lt_tensor, counter = gen_tvar(counter)
+            symbols[n] = lt_tensor
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor)
+
+        elif isinstance(e1, DVar) and isinstance(e2, DVar):
+            # This is meant to be used for flow analysis only
+            lt_constraint = BinConstraintD(e1, e2, op_lt)
+
+            my_lt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
+            return [equality_constraint], counter
+
+        else:
+            raise RuntimeError('Sort Mismatch')
+
+    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
+        if isinstance(e1, DVar):
+            # This is meant to be used for flow analysis only
+            lt_constraint = BinConstraintD(e1, e2, op_lt)
+
+            my_lt, counter = gen_bvar(counter)
+            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
+            return [equality_constraint], counter
+        else:
+            raise NotImplementedError('Method not yet implemented')
+
+    else:
+        raise NotImplementedError('Method not yet implemented')
+
+
+@register_inference_rule(torch.full)
+def full_inference_rule(n: Node, symbols, constraints, counter):
+    full, counter = gen_tvar(counter)
+    symbols[n] = full
+    res = []
+
+    assert isinstance(n.args[0], Iterable)
+    for arg in n.args[0]:
+        dim = arg if isinstance(arg, int) else symbols[arg]
+        res.append(dim)
+    c = BinConstraintT(full, TensorType(list(res)), op_eq)  # type: ignore[arg-type]
+    return [c], counter
+
+
+# TODO normalize index
+@register_inference_rule(torch.arange)
+def arange_inference_rule(n: Node, symbols, constraints, counter):
+    start = 0
+    step = 1
+
+    if len(n.args) == 1:
+        end = symbols[n.args[0]]
+    else:
+        raise NotImplementedError('Not yet implemented')
+
+    # int((end - start) / step)
+    d1, counter = gen_dvar(counter)
+    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
+    arange, counter = gen_tvar(counter)
+    symbols[n] = arange
+
+    # case 1: one of the parameters is Dyn, so the resulting size is Dyn
+    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
+               BinConstraintD(start, Dyn, op_eq),
+               BinConstraintD(step, Dyn, op_eq)])
+    c2 = BinConstraintD(d1, Dyn, op_eq)
+    both_dyn = Conj([c1, c2])
+
+    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
+                BinConstraintD(start, Dyn, op_neq),
+                BinConstraintD(step, Dyn, op_neq)])
+    c22 = BinConstraintD(d1, Dyn, op_neq)
+    both_numbers = Conj([c11, c22, size_constraint])
+
+    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
+
+def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
+    # additional vars that don't correspond to expressions
+    e11, counter = gen_tvar(counter)
+    e22, counter = gen_tvar(counter)
+
+    # generate constraints
+    c1 = TGreatestUpperBound(output_var, e11, e22)
+    c2 = ApplyBroadcasting(e11, e22, e1, e2)
+    c3 = BinConstraintT(e11, e22, op_consistency)
+    return [c1, c2, c3], counter
+
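+# Hedged reading of the helper above (illustrative): for operand variables e1, e2 and
+# an output variable output_var, it introduces fresh tensor variables e11, e22 and
+# asserts
+#   output_var = e11 ⊔* e22           (greatest upper bound of the broadcast operands)
+#   e11, e22 = apply-broadcasting(e1, e2)
+#   e11 ~ e22                         (consistency)
+# so callers such as the add/mul rules below only supply e1, e2 and the output variable.
+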
+
+@register_inference_rule(operator.mul)
+@register_inference_rule(torch.ne)
+@register_inference_rule("ne")
+@register_inference_rule(torch.add)
+@register_inference_rule(operator.add)
+def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
+
+    op_code = None
+    if n.target == operator.add or n.target == torch.add:
+        op_code = op_add
+    elif n.target == operator.mul:
+        op_code = op_mul
+
+    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
+        if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+            e2 = symbols[n.args[1]]
+
+            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
+        else:
+            raise NotImplementedError('Method not yet implemented')
+
+    elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
+        if isinstance(symbols[n.args[0]], TVar):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+            return [BinConstraintT(my_output, e1, op_eq)], counter
+        elif isinstance(symbols[n.args[0]], DVar):
+            my_output, counter = gen_dvar(counter)
+            symbols[n] = my_output
+            e1 = symbols[n.args[0]]
+
+            # we will propagate the runtime value here since this is regular addition
+            c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
+                      BinConstraintD(0, my_output, op_leq)])
+            return [c], counter
+
+    elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
+        if isinstance(symbols[n.args[1]], TVar):
+            my_output, counter = gen_tvar(counter)
+            symbols[n] = my_output
+            e2 = symbols[n.args[1]]
+            return [BinConstraintT(my_output, e2, op_eq)], counter
+        elif isinstance(symbols[n.args[1]], DVar):
+            my_output, counter = gen_dvar(counter)
+            symbols[n] = my_output
+            e2 = symbols[n.args[1]]
+
+            # we will propagate the runtime value here since this is regular addition
+            c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
+                      BinConstraintD(0, my_output, op_leq)])
+            return [c], counter
+
+        else:
+            raise NotImplementedError('Method not yet implemented')
+
+    else:
+        # TODO generate add constraints for scalar addition
+        raise NotImplementedError('Addition not yet implemented')
+
+
+@register_inference_rule(torch.flatten)
+def flatten_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    flattened, counter = gen_tvar(counter)
+    symbols[n] = flattened
+
+    input = symbols[n.args[0]]
+
+    # set the default start and end dims
+    start_dim = 1
+    end_dim = -1
+
+    if len(n.args) > 1:
+        assert isinstance(n.args[1], int)
+        start_dim = n.args[1]
+
+    if len(n.args) > 2:
+        assert isinstance(n.args[2], int)
+        end_dim = n.args[2]
+
+    c1 = BinConstraintT(input, Dyn, op_eq)
+    c2 = BinConstraintT(flattened, Dyn, op_eq)
+    both_dyn = Conj([c1, c2])
+
+    const = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
+        const.append(c)
+
+    return [Disj([both_dyn, *const])], counter
+
+
+@register_inference_rule(torch.nn.functional.layer_norm)
+def layer_norm_functional(n: Node, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal and the input
+    should be consistent with the normalized shape n.args[1]
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
+
+
+@register_inference_rule(torch.nn.LayerNorm)
+def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal.
+    Input should be consistent with the normalized_shape
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
+
+
+def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintT(output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs, counter = gen_tensor_dims(i, counter)
+        nat_constraints = gen_nat_constraints(new_dims_rhs)
+
+        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
+                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
+                          add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
+                          nat_constraints)
+        c2.append(c_tensor_i)
+    return [Disj([c1, Disj(c2)])], counter
+
+@register_inference_rule(torch.nn.Dropout)
+@register_inference_rule(torch.nn.ReLU)
+def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output shapes should be equal.
+    """
+    assert isinstance(n.args[0], Node)
+    output, counter = gen_tvar(counter)
+    symbols[n] = output
+    input = symbols[n.args[0]]
+    assert isinstance(input, TVar)
+    return [BinConstraintT(input, output, op_eq)], counter
+
+
+@register_inference_rule(torch.nn.Linear)
+def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    """
+    Input and output sizes should be the same except for the last dimension
+    If the input is Dyn, then so should the output
+    """
+    assert isinstance(n.args[0], Node)
+    return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)
+
+
+@register_inference_rule("dim")  # type: ignore[attr-defined]
+def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    my_dim, counter = gen_dvar(counter)
+    symbols[n] = my_dim
+    input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(input, Dyn, op_eq)
+    output_dyn = BinConstraintD(my_dim, Dyn, op_eq)
+
+    c1 = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+
+        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
+                           BinConstraintD(my_dim, i, op_eq)])
+        c1.append(c_tensor_i)
+
+    return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
+
+
+@register_inference_rule(torch._C._nn.linear)  # type: ignore[attr-defined]
+def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    weight_dims, counter = gen_tensor_dims(2, counter)
+    equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
+    constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
+    return [equality_constraint] + constraints, counter
+
+
+def linear_constraints(n: Node, in_features, out_features, symbols, counter):
+    linear_output, counter = gen_tvar(counter)
+    symbols[n] = linear_output
+    linear_input = symbols[n.args[0]]
+
+    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
+    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)
+
+    c1 = Conj([input_dyn, output_dyn])
+
+    c2 = []
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
+
+        c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
+                           BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
+                          add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
+                          nat_constraints)
+        c2.append(c_tensor_i)
+    return [Disj([c1, Disj(c2)])], counter
+
+def add_layer_norm_constraints(input_dim, normalized_dim):
+    """
+    The constraints say that the type has the form: [*, 1024, 1024]
+    while normalized_dim has the form [1024, 1024]
+    Args:
+        input_dim: Input shape of layer norm
+        normalized_dim: normalized_dim parameter of the module instance
+
+    """
+
+    # in this case we return false since there's a pattern mismatch
+    if len(normalized_dim) > len(input_dim):
+        return [F()]
+
+    else:
+        constraints = []
+        for i, n in zip(reversed(input_dim), reversed(normalized_dim)):
+            constraints.append(BinConstraintD(i, n, op_consistency))
+        return constraints
+
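+# Hedged example (illustrative): input_dim = [b, 1024, 1024] with
+# normalized_dim = [1024, 1024] yields the trailing-dimension consistency
+# constraints 1024 ~ 1024 and 1024 ~ 1024; a normalized_dim longer than
+# input_dim is rejected with F().
+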
+
+def add_linear_constraints(dims1, dims2, in_features, out_features):
+    assert len(dims1) == len(dims2)
+    constraints = []
+    for i in range(len(dims1)):
+        if i == len(dims1) - 1:
+            constraints.append(BinConstraintD(dims1[i], in_features, op_consistency))
+            constraints.append(BinConstraintD(dims2[i], out_features, op_eq))
+        else:
+            constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq))
+
+    return constraints
+
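+# Hedged example (illustrative): dims1 = [d1, d2], dims2 = [e1, e2],
+# in_features = 10 and out_features = 20 produce the constraints
+#   d1 = e1,  d2 ~ 10 (consistency with in_features),  e2 = 20,
+# i.e. all but the last dimension are preserved and the last becomes out_features.
+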
+
+@register_inference_rule(torch.reshape)
+def reshape_inference_rule(n: Node, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    my_reshape, counter = gen_tvar(counter)
+    symbols[n] = my_reshape
+
+    src_var = symbols[n.args[0]]
+    t2 = n.args[1]
+    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])  # type: ignore[union-attr]
+    c1 = BinConstraintT(my_reshape, t2_type, op_eq)  # type: ignore[union-attr]
+    c2 = CanReshape(src_var, t2_type)
+
+    return [c1, c2], counter
+
+
+@register_inference_rule(BatchNorm2d)
+def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    # generate the new variable
+    batchnorm_output, counter = gen_tvar(counter)
+    symbols[n] = batchnorm_output
+    batchnorm_input = symbols[n.args[0]]
+
+    # dim vars
+    d1, counter = gen_dvar(counter)
+    d2, counter = gen_dvar(counter)
+    d3, counter = gen_dvar(counter)
+    d4, counter = gen_dvar(counter)
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching)
+    c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq)
+    return [c1, c2, *nat_constraints], counter
+
+
+@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
+def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    avg_pool, counter = gen_tvar(counter)
+
+    symbols[n] = avg_pool
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    d1, counter = gen_dvar(counter)
+    d2, counter = gen_dvar(counter)
+    d3, counter = gen_dvar(counter)
+    d4, counter = gen_dvar(counter)
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)
+
+    return [c1, c2, *nat_constraints], counter
+
+
+@register_inference_rule(Conv2d)
+def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+
+    my_conv, counter = gen_tvar(counter)
+    symbols[n] = my_conv
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
+
+    # c1 = Matching(input_var, TensorType([d1, d2, d3, d4]))
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+
+    # c2 = DConsistency(module_instance.in_channels, d2)
+    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
+
+    c3 = CalcConv(my_conv, input_var,
+                  module_instance.out_channels,
+                  module_instance.kernel_size,
+                  module_instance.padding,
+                  module_instance.stride,
+                  module_instance.dilation, [d1, d2, d3, d4])
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    return [c1, c2, c3, *nat_constraints], counter
+
+
+@register_inference_rule(torch.nn.MaxPool2d)
+def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
+    assert isinstance(n.args[0], Node)
+    maxpool, counter = gen_tvar(counter)
+    symbols[n] = maxpool
+    input_var = symbols[n.args[0]]
+
+    # dim vars
+    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
+
+    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
+
+    c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
+                     module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])
+
+    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
+
+    return [c1, c2, *nat_constraints], counter
+
+
+class ConstraintGenerator:
+    def __init__(self, traced, graph=None):
+        self.traced = traced  # traced or tracer.root
+        self.traced_params = dict(self.traced.named_parameters())
+        self.constraints = []
+        self.symbol_dict = {}
+        self.graph = traced.graph if hasattr(traced, 'graph') else graph
+
+
+    def generate_constraints(self, counter=0):
+        """
+        Iterate through every node and generate constraints
+        Effect: self.constraints will be populated with the final constraints
+        """
+        graph = self.graph
+
+        all_constraints = []
+
+        for n in graph.nodes:
+            (constraints, counter) = self.generate_constraints_node(n, counter)
+            all_constraints += constraints
+
+        return Conj(all_constraints), counter
+
+    def generate_constraints_node(self, n: Node, counter):
+        """
+        Generate constraints for the given node.
+        Currently supported operations include:
+        - Reshape
+        - Add
+        - conv2d
+        """
+
+        if n.op == 'placeholder':
+            x, counter = gen_tvar(counter)
+            self.symbol_dict[n] = x
+
+            my_type = n.type
+
+            if n.type != Dyn and (not isinstance(n.type, TensorType)):
+                if n.type == torch.nn.parameter.Parameter:
+                    # since we have a parameter, the shape must be static
+                    assert 'example_value' in n.meta
+                    my_type = TensorType(n.meta['example_value'].size())
+                else:
+                    my_type = Dyn
+
+            c1 = BinConstraintT(my_type, x, op_precision)
+            c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
+            return [c1, c2], counter
+
+        elif n.op == 'call_function':
+            if n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
+            else:
+                raise RuntimeError(f'No inference rule registered for target {n.target}!')
+
+        elif n.op == 'call_module':
+
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _INFERENCE_RULES:
+                return _INFERENCE_RULES[type(module_instance)](n,
+                                                               module_instance,
+                                                               self.symbol_dict,
+                                                               self.constraints, counter)
+            else:
+                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
+
+        elif n.op == 'call_method':
+            if n.target in _INFERENCE_RULES:
+                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
+            else:
+                raise RuntimeError(f'No inference rule registered for target {n.target}!')
+
+        elif n.op == 'get_attr':
+            t = self.traced_params.get(n.target, None)
+
+            if isinstance(t, torch.Tensor):
+                if len(t.shape) > 0:
+                    res = list(t.shape)
+                    attr_type = TensorType(res)
+                    output, counter = gen_tvar(counter)
+                    self.symbol_dict[n] = output
+                    return [BinConstraintT(output, attr_type, op_eq)], counter
+                else:
+                    # scalar parameter (rank 0): no shape constraints to generate
+                    return [], counter
+            else:
+                return [], counter
+
+        elif n.op == 'output':
+            return [], counter
+
+        else:
+            raise NotImplementedError(f"Method {n.op} not yet implemented")
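+
+# A minimal sketch of how this generator is typically driven (illustrative only;
+# MyModule is a hypothetical fx-traceable nn.Module):
+#
+#   from torch.fx import symbolic_trace
+#   traced = symbolic_trace(MyModule())
+#   generator = ConstraintGenerator(traced)
+#   constraints, counter = generator.generate_constraints(counter=0)
+#   # 'constraints' is a Conj over the constraints generated for every node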
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..614b12426c599bfa950c97a0cc095fb3ddb81afe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -0,0 +1,1040 @@
+# mypy: ignore-errors
+import copy
+import itertools
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
+from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
+    Transpose
+from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
+from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
+from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
+from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
+from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
+from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
+from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
+from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
+from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
+from torch.fx.tensor_type import TensorType, Dyn
+from typing import Callable, Dict, List
+
+_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
+
+
+def register_transformation_rule(call_target):
+    def register(fn):
+        if call_target in _TRANSFORMATION_RULES:
+            raise RuntimeError(f'Transformation rule already registered for {call_target}!')
+        _TRANSFORMATION_RULES[call_target] = fn
+        return fn
+    return register
+
+
+def valid_index(index, dims):
+    """
+    Given a list of dimensions, checks if an index is valid in the list
+    """
+    try:
+        dims[index]
+        return T()
+    except IndexError:
+        return F()
+
+
+@register_transformation_rule(Transpose)
+def transform_transpose(constraint, counter):
+    """
+    Similar to a sequence of two index-selects
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index1 = valid_index(constraint.index1, dims)
+    is_valid_index2 = valid_index(constraint.index2, dims)
+    new_dims = copy.deepcopy(dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    if is_valid_index1 == T() and is_valid_index2 == T():
+        new_dims[constraint.index1] = dims[constraint.index2]
+        new_dims[constraint.index2] = dims[constraint.index1]
+
+    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                                   *nat_constraints,
+                                   is_valid_index1, is_valid_index2,
+                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
+    return transformed_constraint, counter
+
+
+@register_transformation_rule(IndexSelect)
+def transform_index_select(constraint, counter):
+    """
+    The constraints consider the given tensor size, check whether the index is valid,
+    and if so generate a constraint for replacing the input dimension
+    with the required dimension
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    is_valid_index = valid_index(constraint.index, dims)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # if the index is valid then replace the input dimension with the new dimension;
+    # otherwise the dimension is left unchanged and the clause will contain False
+    new_dims = copy.deepcopy(dims)
+    if is_valid_index == T():
+        new_dims[constraint.index] = constraint.dim_replace
+
+    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                                   *nat_constraints,
+                                   is_valid_index,
+                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
+
+    # print(constraints)
+    return transformed_constraint, counter
+
+
+@register_transformation_rule(GetItem)
+def transform_get_item(constraint, counter):
+    """
+    Generate an equality of the form:
+    t = [a1, ..., an]
+    then generate constraints that check whether the given index is valid
+    for this particular tensor size.
+    If the index is valid, generate a constraint to get the item.
+    Note that the Dyn input case was already handled in the previous
+    step.
+    Args:
+        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
+        counter: variable tracking
+    Returns: simplified constraints for GetItem
+
+    """
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    nat_constraints = gen_nat_constraints(dims)
+
+
+    is_valid_index = valid_index(constraint.index, dims)
+
+    all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                       *nat_constraints,
+                       is_valid_index]
+
+    # if the index is valid, we generate a constraint for getting an item
+    # otherwise this clause will have been UNSAT due to the wrong index
+    if is_valid_index == T():
+        all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))
+
+    return Conj(all_constraints), counter
+
+def valid_index_tensor(index, dims):
+    """
+    If the number of slice instances exceeds the number of dimensions,
+    this is a type error, so we return False
+    """
+    slice_count = 0
+    for s in index:
+        if isinstance(s, slice):
+            slice_count += 1
+    if slice_count > len(dims):
+        return F()
+    else:
+        return T()
+
+@register_transformation_rule(GetItemTensor)
+def transform_get_item_tensor(constraint, counter):
+    """
+    When the index is a tuple, then the output will be a tensor
+    TODO: we have to check if this is the case for all HF models
+
+    The cases we are covering here are a tuple with one of:
+     - slice with default argument
+     - None
+
+     None appends 1 to the input tensor dimensions
+     so each occurrence of 'None' increases the rank by 1
+
+     slice with default arguments does not change the rank
+    """
+    assert isinstance(constraint.index_tuple, tuple)
+
+
+    # generate a result tensor of the expected size
+    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
+    nat_constraints = gen_nat_constraints(dims)
+
+    # generate a place-holder list of the right rank
+    # where "slice" does not contribute to the rank and "None" does
+    none_c = constraint.index_tuple.count(None)
+    resulting_tensor_dims = (none_c + len(dims)) * [None]
+
+    for i in range(len(constraint.index_tuple)):
+
+        # append 1 to the right location of the resulting tensor
+        if constraint.index_tuple[i] is None:
+            resulting_tensor_dims[i] = 1
+
+        elif constraint.index_tuple[i] == slice(None, None, None):
+            pass
+
+        else:
+            raise NotImplementedError('Method not yet implemented')
+
+    # append the remaining dimensions to the right location
+    dim_index = 0
+    for i in range(len(resulting_tensor_dims)):
+        if resulting_tensor_dims[i] is None:
+            resulting_tensor_dims[i] = dims[dim_index]
+            dim_index += 1
+
+    # check if the index is valid
+    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)
+
+    # check if the resulting tensor is within bounds
+    if len(resulting_tensor_dims) > 4:
+        return F(), counter
+
+    else:
+        constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
+                       BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
+                       *nat_constraints,
+                       is_valid_index]
+        return Conj(constraints), counter
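+
+# A worked example of the layout computed above (illustrative only): indexing a
+# rank-2 input with the tuple (None, slice(None), None) gives none_c == 2, so the
+# result has rank 4 and its dimensions end up as
+#
+#   [1, dims[0], 1, dims[1]]
+#
+# i.e. every None inserts a dimension of size 1 and every default slice passes an
+# input dimension through unchanged.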
+
+
+@register_transformation_rule(BinConstraintT)
+def generate_binconstraint_t(constraint, counter):
+    """
+    Transform binary constraints for tensors
+    """
+
+    # precision constraints
+    if constraint.op == op_precision:
+        if constraint.lhs == Dyn:
+            return T(), counter
+        elif isinstance(constraint.lhs, TensorType):
+            is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
+            if is_fully_static:
+                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
+            else:
+                new_dims = []
+
+                for _ in range(len(constraint.lhs.__args__)):
+                    dim, counter = gen_dvar(counter)
+                    new_dims.append(dim)
+
+                new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
+                                       new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
+                                      [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
+                                      [BinConstraintD(1, new_dim, op_leq) for
+                                       new_dim in new_dims]
+                return Conj(new_dim_constraints), counter
+
+    # matching
+    elif constraint.op == op_matching:
+        assert isinstance(constraint.rhs, TensorType)
+        d1 = constraint.rhs.__args__[0]
+        d2 = constraint.rhs.__args__[1]
+        d3 = constraint.rhs.__args__[2]
+        d4 = constraint.rhs.__args__[3]
+
+        conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
+                BinConstraintD(d1, Dyn, op_eq),
+                BinConstraintD(d2, Dyn, op_eq),
+                BinConstraintD(d3, Dyn, op_eq),
+                BinConstraintD(d4, Dyn, op_eq)]
+        return Disj([Conj(conj),
+                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter
+
+    elif constraint.op == op_consistency:
+        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
+        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)
+
+        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter
+
+    elif constraint.op == op_leq:
+        assert isinstance(constraint.rhs, int)
+        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
+        for i in range(1, constraint.rhs + 1):
+            dims = []
+            for j in range(1, i + 1):
+                dim_var, counter = gen_dvar(counter)
+                dims.append(dim_var)
+            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
+        return Disj(disj), counter
+    else:
+        return constraint, counter
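+
+# A small sketch of the precision case above (illustrative only): for a constraint
+# of the form  TensorType([2, Dyn]) ⊑ x  the left-hand side is not fully static, so
+# two fresh dimension variables n1, n2 are generated and the result is roughly
+#
+#   Conj([BinConstraintD(2, n1, op_precision),
+#         BinConstraintD(Dyn, n2, op_precision),
+#         BinConstraintT(x, TensorType([n1, n2]), op_eq),
+#         BinConstraintD(1, n1, op_leq),
+#         BinConstraintD(1, n2, op_leq)])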
+
+
+@register_transformation_rule(BinConstraintD)
+def generate_binconstraint_d(constraint, counter):
+    """
+    Transform binary constraints for dimensions
+    """
+    if constraint.op == op_precision:
+        if isinstance(constraint.lhs, int):
+            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
+        elif constraint.lhs == Dyn:
+            return T(), counter
+
+    elif constraint.op == op_consistency:
+        return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
+                     BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter
+
+    else:
+        return constraint, counter
+
+
+@register_transformation_rule(Conj)
+def generate_conj(constraint, counter):
+    """
+    Transform conjunctions
+    """
+    new = []
+    for c in constraint.conjucts:
+        new_c, counter = transform_constraint(c, counter)
+        new.append(new_c)
+    return Conj(new), counter
+
+
+@register_transformation_rule(Disj)
+def generate_disj(constraint, counter):
+    """
+    Transform disjunctions
+    """
+    new = []
+    for c in constraint.disjuncts:
+        new_c, counter = transform_constraint(c, counter)
+        new.append(new_c)
+    return Disj(new), counter
+
+
+@register_transformation_rule(TGreatestUpperBound)
+def generate_gub(constraint, counter):
+    """
+    Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
+    on dimensions
+    """
+    c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
+                     BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)])
+
+    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)
+
+    return Disj([c1, c2, c3, c4, c5]), counter
+
+
+@register_transformation_rule(DGreatestUpperBound)
+def generate_d_gub(constraint, counter):
+    """
+    Transform greatest upper bound for dimensions into equality constraints
+    """
+    c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)])
+    c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
+    c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
+    return Disj([c1, c2, c3]), counter
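+
+# An illustrative reading of the three disjuncts above, as the usual gradual
+# greatest-upper-bound table for a pair of dimensions:
+#
+#   gub(Dyn, d) = d    (c1)
+#   gub(d, Dyn) = d    (c2)
+#   gub(d, d)   = d    (c3, when both sides agree)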
+
+
+@register_transformation_rule(CalcConv)
+def generate_calc_conv(constraint, counter):
+    """
+    Transform convolution constraints
+    """
+    d, counter = gen_tensor_dims(4, counter)
+    conv_result = TensorType([d[0], d[1], d[2], d[3]])
+
+    # the convolution result is a tensor of size 4
+    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)
+
+    # the second dimension of the output is equal to the output channels
+    c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])
+
+    # the input corresponds to the output in the first dimension of the convolution
+    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
+
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
+                            BinConstraintD(0, d[1], op_leq),
+                            BinConstraintD(0, d[2], op_leq),
+                            BinConstraintD(0, d[3], op_leq)])
+
+    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
+
+
+@register_transformation_rule(CalcMaxPool)
+def generate_calc_maxpool(constraint, counter):
+    """
+    Transform maxpool constraints
+    """
+    d, counter = gen_tensor_dims(4, counter)
+    maxpool_result = TensorType([d[0], d[1], d[2], d[3]])
+
+    # the maxpool result is a tensor of size 4
+    c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)
+
+    # the input corresponds to the output in the first and second dimension of maxpool
+    c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
+    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
+    c4, c5 = calc_last_two_dims(constraint, d)
+
+    leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
+                            BinConstraintD(0, d[1], op_leq),
+                            BinConstraintD(0, d[2], op_leq),
+                            BinConstraintD(0, d[3], op_leq)])
+
+    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
+
+
+@register_transformation_rule(CalcProduct)
+def generate_calc_product(constraint, counter):
+    """
+    Transform flatten constraints
+    """
+    start = constraint.start
+    end = constraint.end
+    dims = constraint.dims_to_flatten
+    flattened = constraint.flattened
+    n = len(constraint.dims_to_flatten)
+
+    # evaluated eagerly here, since start and end are concrete integers
+    boundary_check = (0 <= start < end <= n)
+
+    c_boundary = T() if boundary_check else F()
+
+    lhs = dims[0:start]
+    rhs = dims[end:]
+    mid = dims[start:end]
+
+    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)
+
+    all_constraints = []
+
+    for p in all_possibilities:
+        p = list(p)
+        # this tells us there is a dynamic variable
+        contains_dyn = not all(constraint.op == op_neq for constraint in p)
+        if contains_dyn:
+            mid_var = [Dyn]
+            total_constraints = lhs + mid_var + rhs
+            if len(total_constraints) > 4:
+                all_constraints.append(F())
+            else:
+                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p))
+        else:
+            new_var, counter = gen_dvar(counter)
+            mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)])
+            mid_var = [new_var]
+            total_constraints = lhs + mid_var + rhs
+            if len(total_constraints) > 4:
+                all_constraints.append(F())
+            else:
+                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p))
+
+    return Conj([Disj(all_constraints), c_boundary]), counter
+
+
+@register_transformation_rule(CanReshape)
+def generate_reshape(constraint, counter):
+    """
+    Transform reshape constraints
+    """
+    d, counter = gen_tensor_dims(4, counter)
+
+    d1 = d[0]
+    d2 = d[1]
+    d3 = d[2]
+    d4 = d[3]
+
+    target = constraint.target.__args__
+
+    is_fully_static = all(d != Dyn for d in target)
+
+    # dynamic tensor
+    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
+    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
+    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
+    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
+    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)
+
+    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
+    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)
+
+    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
+    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)
+
+    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
+    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
+
+    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
+    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)
+
+    nat_d1 = BinConstraintD(0, d1, op_leq)
+    nat_d2 = BinConstraintD(0, d2, op_leq)
+    nat_d3 = BinConstraintD(0, d3, op_leq)
+    nat_d4 = BinConstraintD(0, d4, op_leq)
+
+    if is_fully_static:
+        # size 1 tensor
+        c3_tensor1 = Disj([d1_eq_dyn,
+                           (Conj([d1_neq_dyn,
+                                  BinConstraintD(d1, Prod(target), op_eq)]))])
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # size 2 tensor
+        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])
+
+        # size 3 tensor
+        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])
+
+        # size 4 tensor
+        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])
+
+        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
+                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
+
+    # then there must be exactly one occurrence of dyn
+    else:
+        new_target = []
+
+        for n in target:
+            if n != Dyn:
+                new_target.append(n)
+
+        # tensor 1
+        c3_tensor1 = Disj([d1_eq_dyn,
+                           (Conj([d1_neq_dyn,
+                                  is_dim_div_by_target(new_target, d1)]))])
+        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
+
+        # tensor 2
+        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
+        c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
+        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])
+
+        # tensor 3
+        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
+        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
+        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])
+
+        # tensor 4
+        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
+        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
+        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])
+
+        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
+                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
+
+
+@register_transformation_rule(ApplyBroadcasting)
+def generate_broadcasting(constraint, counter):
+    """
+    Transform broadcasting constraints
+    """
+    e11, e12 = constraint.res1, constraint.res2
+    e1, e2 = constraint.input1, constraint.input2
+
+    e1_dyn = BinConstraintT(e1, Dyn, op_eq)
+    e2_dyn = BinConstraintT(e2, Dyn, op_eq)
+
+    # Introduce dimensions
+    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
+    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
+
+    # dyn possibility
+    e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
+    e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])
+
+    # tensor possibility
+    # generate dimensions to create tensors of size 1
+    final_tensor_1_constraint, _, _, nat_dims_1, counter = \
+        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)
+
+    # generate dimensions to create tensors of size 2
+    final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
+        final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
+        gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)
+
+    # generate dimensions to create tensors of size 3
+    final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
+        final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
+        gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)
+
+    # generate dimensions to create tensors of size 4
+    final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
+        final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
+        gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)
+
+    final_result = Disj([
+        e1_dyn_constraint,
+        e2_dyn_constraint,
+        final_tensor_1_constraint,
+        final_tensor_2_constraint_no_padding,
+        final_tensor_2_constraint_padding_arg1,
+        final_tensor_2_constraint_padding_arg2,
+        final_tensor_3_constraint_no_padding,
+        final_tensor_3_constraint_padding_arg1,
+        final_tensor_3_constraint_padding_arg2,
+        final_tensor_4_constraint_no_padding,
+        final_tensor_4_constraint_padding_arg1,
+        final_tensor_4_constraint_padding_arg2
+    ])
+
+    return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter
+
+
+def transform_constraint(constraint: Constraint, counter: int):
+    """
+    Transforms a constraint into a simpler constraint.
+    Ex: precision and consistency are transformed to equality
+    Args:
+        constraint: constraint to be transformed
+        counter: for variable tracking
+
+    Returns: Constraint
+
+    """
+    if type(constraint) in _TRANSFORMATION_RULES:
+        return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)
+
+    else:
+        return constraint, counter
+
+
+def calc_last_two_dims(constraint, d: List[DVar]):
+    """
+    Generates constraints for the last two dimensions of a convolution or a maxpool output
+    Args:
+        constraint: CalcConv or CalcMaxPool
+        d: The list of output dimensions
+
+    Returns: Constraints for calculating the last two dimensions of the output
+
+    """
+
+    assert isinstance(constraint, (CalcConv, CalcMaxPool))
+
+    b3 = constraint.matching_constraint[2]
+    b4 = constraint.matching_constraint[3]
+
+    b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
+    b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])
+
+    d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)])
+    d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)])
+
+    # transform parameters into tuples in case they are not already
+    padding = (constraint.padding, constraint.padding) \
+        if isinstance(constraint.padding, int) else constraint.padding
+    kernel = (constraint.kernel, constraint.kernel) \
+        if isinstance(constraint.kernel, int) else constraint.kernel
+    stride = (constraint.stride, constraint.stride) \
+        if isinstance(constraint.stride, int) else constraint.stride
+    dilation = (constraint.dilation, constraint.dilation) \
+        if isinstance(constraint.dilation, int) else constraint.dilation
+
+    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
+    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
+    f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div)
+    f4 = BinConstraintD(f3, 1, op_add)
+
+    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])
+
+    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
+    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
+    f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div)
+    f44 = BinConstraintD(f33, 1, op_add)
+
+    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])
+
+    return c4, c5
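+
+# The nested BinConstraintD terms above encode the standard output-size arithmetic
+# (shown here for the third dimension; f11..f44 handle the fourth symmetrically):
+#
+#   d[2] = (b3 + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1
+#
+# which is the familiar Conv2d / MaxPool2d shape formula, with the division carried
+# out over solver integers.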
+
+
+def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
+    """
+    Generate all possibilities of being equal or not equal to dyn for my_list
+    Args:
+        my_list: List of tensor dimensions
+
+    Returns: A list of a list of constraints. Each list of constraints corresponds to
+    one possibility about the values of the dimension variables
+    """
+    # generate all possibilities of being equal or not equal to dyn for my_list
+    eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))]
+    neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))]
+    d_possibilities = []
+
+    for i in zip(eq_possibilities, neq_possibilities):
+        d_possibilities.append(list(i))
+    all_possibilities = list(itertools.product(*d_possibilities))
+    return all_possibilities
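+
+# For example, for my_list = [d1, d2] the cross product above yields four cases
+# (written informally):
+#
+#   (d1 = Dyn,  d2 = Dyn)
+#   (d1 = Dyn,  d2 != Dyn)
+#   (d1 != Dyn, d2 = Dyn)
+#   (d1 != Dyn, d2 != Dyn)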
+
+
+def is_target_div_by_dim(target: List[int], dim: List[DVar]):
+    """
+    Generate constraints to check if the product of the target dimensions is divisible by the input dimension
+    Args:
+        target: Target dimensions
+        dim: Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
+
+
+def is_dim_div_by_target(target: List[int], dim: List[DVar]):
+    """
+    Generate constraints to check if the input dimension is divisible by the product of the target dimensions
+    Args:
+        target: Target dimensions
+        dim:  Input dimensions
+
+    Returns: Constraints to check divisibility
+
+    """
+    return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)
+
+
+def gen_all_reshape_possibilities(list_of_dims, target):
+    """
+    Consider all possibilities for what the input dimensions could be (a number or dynamic),
+    then generate the appropriate constraints using multiplication or mod depending on the possibility.
+    The possibilities we consider here are the cross product of each input dimension being
+    equal or not equal to Dyn. The target is fixed because at most one of its dimensions
+    can be Dyn, and we handle those cases separately.
+
+    Args:
+        list_of_dims: The input list of dimensions
+        target: The tensor we want to reshape to
+
+    Returns: A disjunction of transformed reshape constraints
+
+    """
+    all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)
+
+    all_constraints = []
+
+    for p in all_possibilities:
+        to_multiply = []
+
+        p = list(p)
+
+        for constraint in p:
+            assert isinstance(constraint, BinConstraintD)
+            if constraint.op == op_neq:
+                to_multiply.append(constraint.lhs)
+
+        if not to_multiply:
+            all_constraints.append(Conj(p))
+
+        elif len(to_multiply) < len(list_of_dims):
+            all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]))
+        else:
+            all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims),
+                                                            Prod(target), op_eq)]))
+
+    return Disj(all_constraints)
+
+
+def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
+    """
+    Apply broadcasting to the 'index' dimension of tensor_input1.
+    Args:
+        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
+        tensor_input2: represents the second input
+        res1: broadcasted result 1
+        res2: broadcasted result 2
+        index: the index to broadcast
+        padding: If padding was used, then tensor_input1[index] does not exist
+
+    Returns: Constraints broadcasting dimension 'index' of tensor_input1 against tensor_input2
+
+    """
+    if tensor_input1[index] is None:
+        assert padding
+
+
+    if not padding:
+        # then the inputs are the same length so they all have dimensions at "index"
+        return Conj([BinConstraintD(tensor_input1[index], 1, op_eq),
+                     BinConstraintD(res1[index], res2[index], op_eq),
+                     BinConstraintD(res2[index], tensor_input2[index], op_eq)])
+
+    else:
+        # we don't set the input dimension to 1, since it doesn't exist.
+        return Conj([BinConstraintD(res1[index], res2[index], op_eq),
+                     BinConstraintD(res2[index], tensor_input2[index], op_eq)])
+
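+# An illustrative reading of the no-padding case above: if tensor_input1 has size 1
+# at 'index', both broadcasted results take tensor_input2's size there; e.g.
+# broadcasting [1, 3] against [5, 3] at index 0 forces res1[0] = res2[0] = 5.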
+
+def apply_padding(e1_var: TVar,
+                  e11: BinConstraintT,
+                  e2: BinConstraintT,
+                  e12: BinConstraintT,
+                  d2: List[DVar],
+                  d11: List[DVar],
+                  d12: List[DVar],
+                  counter: int):
+    """
+    We are considering the possibility where one input has fewer dimensions than
+    another input, so we apply padding to the broadcasted results
+
+    Args:
+        e1_var: Variable representing the first input where padding will be applied
+        e11: constraint of the form e11 = TensorType[d1, ..., dn]
+        e2:  constraint of the form e2 = TensorType[d1, ..., dn]
+        e12: constraint of the form e12 = TensorType[d1, ..., dn]
+        d2: Tensor variables for the second input
+        d11: Tensor variables for the broadcasted first input
+        d12: Tensor variables for the broadcasted second input
+        counter: variable tracking
+
+    Returns: A new constraint whose goal is to apply padding to the broadcasted result
+
+    """
+
+    res = []
+
+    # pad the shorter input with None so we can pass it to the broadcasting helper function
+    for i in range(1, len(d2)):
+
+        d1, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
+
+        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)
+
+        simulate_padding = [None] * (len(d2) - i)
+
+        assert len(simulate_padding + d1) == len(d2)
+
+        broadcast_padding = []
+
+        # for every padding size, we also consider broadcasting
+        for j in range(len(d2) - i):
+            broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))
+
+        # we consider the possibilities for broadcasting for every dimension. Since we already
+        # padded d1, we do not consider it while broadcasting
+        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
+                                                                                            d2[(len(d2) - i):],
+                                                                                            d11[(len(d2) - i):],
+                                                                                            d12[(len(d2) - i):])
+        # combine all constraints into a conjunction
+        c = Conj([e1, e11, e2, e12,
+                  *broadcast_padding,
+                  all_broadcasting_possibilities,
+                  *nat_constraints
+                  ])
+        res.append(c)
+
+    return Disj(res), counter
+
+
+def no_broadcast_dim_with_index(d1: List[DVar],
+                                d2: List[DVar],
+                                d3: List[DVar],
+                                d4: List[DVar],
+                                i: int):
+    """
+    Args:
+        d1: input 1
+        d2: input 2
+        d3: simulated broadcasting for input 1
+        d4: simulated broadcasting for input 2
+        i: the index of the dimension being checked
+
+    Returns: Constraints for when no broadcasting occurs
+    """
+    return Conj([
+        Disj([
+            Conj([BinConstraintD(d1[i], 1, op_eq),
+                  BinConstraintD(d2[i], 1, op_eq)]),
+
+            Conj([BinConstraintD(d1[i], 1, op_neq),
+                  BinConstraintD(d2[i], 1, op_neq)])]),
+
+        BinConstraintD(d1[i], d3[i], op_eq),
+        BinConstraintD(d2[i], d4[i], op_eq)])
+
+
+def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
+    """
+    Generate lists of DVar to represent tensor dimensions
+    Args:
+        num_tensors: the required number of tensors
+        dim_size: the number of dimensions for each tensor
+        counter: variable tracking
+
+    Returns: A list of a list of tensor dimensions
+
+    """
+    res = []
+
+    for _ in range(num_tensors):
+        dims, counter = gen_tensor_dims(dim_size, counter)
+        res.append(dims)
+
+    return res, counter
+
+
+def create_equality_constraints_for_broadcasting(e1: TVar,
+                                                 e2: TVar,
+                                                 e11: TVar,
+                                                 e12: TVar,
+                                                 d1: List[DVar],
+                                                 d2: List[DVar],
+                                                 d11: List[DVar],
+                                                 d12: List[DVar]):
+    """
+    Create equality constraints for when no broadcasting occurs
+    Args:
+        e1: Input 1
+        e2: Input 2
+        e11: Broadcasted input 1
+        e12: Broadcasted input 2
+        d1: Variables that store dimensions for e1
+        d2: Variables that store dimensions for e2
+        d11: Variables that store dimensions for e11
+        d12: Variables that store dimensions for e12
+
+    Returns: Four equality constraints
+
+    """
+
+    e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
+    e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
+    e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
+    e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
+    return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]
+
+
+def gen_consistency_constraints(constraint: Constraint, counter: int):
+    """
+    Args:
+        constraint: Consistency constraint on tensors
+        counter: for variable tracking
+
+    Returns: Equality and consistency constraints on dimensions
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
+        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
+
+        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
+
+        c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
+                           BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] +
+                          [BinConstraintD(d1, d2, op_consistency) for
+                           d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints)
+
+        all_constraints.append(c_tensor_i)
+
+    return all_constraints, counter
+
+
+def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
+    """
+    Args:
+        constraint: Greatest upper bound on tensors
+        counter: variable tracking
+
+    Returns: A set of equality constraints and DGreatestUpperBound constraints
+
+    """
+
+    all_constraints = []
+
+    for i in range(1, MAX_TENSOR_RANK + 1):
+        c = []
+        dims1, counter = gen_tensor_dims(i, counter)
+        c1tensor = TensorType(dims1)
+
+        dims2, counter = gen_tensor_dims(i, counter)
+        c2tensor = TensorType(dims2)
+
+        dims3, counter = gen_tensor_dims(i, counter)
+        c3tensor = TensorType(dims3)
+
+        c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
+              BinConstraintT(constraint.rhs2, c2tensor, op_eq),
+              BinConstraintT(constraint.res, c3tensor, op_eq)] + \
+            gen_nat_constraints(dims1 + dims2 + dims3)
+
+        assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
+        for i in range(len(c3tensor.__args__)):
+            c.append(DGreatestUpperBound(c3tensor.__args__[i],
+                                         c1tensor.__args__[i],
+                                         c2tensor.__args__[i]))
+
+        all_constraints.append(Conj(c))
+    return all_constraints, counter
+
+
+def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
+    """
+    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
+    We look at all combinations for all dimensions in d1 and d2
+    Args:
+        d1: input1 dimensions
+        d2: input2 dimensions
+        d11: broadcasted input1 dimensions
+        d12: broadcasted input2 dimensions
+
+    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions
+
+    """
+
+    size = len(d1)
+
+    res2 = []
+
+    for i in range(size):
+        t1 = broadcast_dim(d1, d2, d11, d12, i)
+        t2 = broadcast_dim(d2, d1, d12, d11, i)
+        t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
+
+        res2.append(Disj([t1, t2, t3]))
+
+    return Conj(res2)
+
+
+def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
+    """
+    Simulates broadcasting on e1 and e2 and returns the results
+    respectively in e11 and e12. Because of gradual types,
+    e1 and e2 may not be equal. Similarly, e11 and e12 may not
+    be equal. e11 and e12 should be guaranteed to be consistent
+    as they represent the shapes of the tensors to be added after
+    broadcasting.
+    Args:
+        e1: TVar representing the type of input 1
+        e2: TVar representing the type of input 2
+        e11: TVar representing the broadcasted input 1
+        e12: TVar representing the broadcasted input 2
+        i: The rank of the resulting type of addition
+        counter: for variable tracking
+
+    Returns: Simplified broadcasting constraints
+
+    """
+    dims, counter = gen_lists_of_dims(4, i, counter)
+    [d1, d2, d3, d4] = dims
+    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
+
+    initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
+                                                                                  d1, d2, d3, d4)
+
+    [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints
+
+    # without padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints,
+                                               generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])
+
+    # with padding, broadcast all possibilities for tensors of size i
+    final_tensor_constraint_padding_arg1, counter = \
+        apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)
+
+    final_tensor_constraint_padding_arg2, counter = \
+        apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)
+
+    return final_tensor_constraint_no_padding, \
+        final_tensor_constraint_padding_arg1, \
+        final_tensor_constraint_padding_arg2, nat_dims_i, counter
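+
+
+# A minimal sketch of how the rules in this file are applied (illustrative only;
+# t1 and t2 stand for tensor variables produced by gen_tvar):
+#
+#   c = BinConstraintT(t1, t2, op_consistency)
+#   simplified, counter = transform_constraint(c, counter)
+#   # 'simplified' is now a Disj over the Dyn case and the rank-1..4 cases, and can
+#   # be transformed again until a fixed point is reached.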
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b86d3ced1b0c0e056349c384a3c0bce10de823
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
@@ -0,0 +1,14 @@
+op_add = '+'
+op_sub = '-'
+op_mul = '*'
+op_div = '/'
+op_eq = '='
+op_neq = '!='
+op_imp = '=>'
+op_matching = '⊳'
+op_consistency = '~'
+op_precision = '⊑'
+op_leq = '≤'
+op_lt = '<'
+op_gt = '>'
+op_mod = '%'
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9376ef7eb7ceb7d63bcb2b22af0857b24a6958a0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
@@ -0,0 +1,348 @@
+from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
+from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
+from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
+from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
+from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
+from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
+from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
+from torch.fx.tensor_type import TensorType, Dyn
+
+try:
+    import z3  # type: ignore[import]
+    from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
+    HAS_Z3 = True
+
+    def transform_to_z3(constraint, counter, dimension_dict):
+        if isinstance(constraint, Conj):
+            conjuncts = []
+            for c in constraint.conjucts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                conjuncts.append(new_c)
+            return z3.And(conjuncts), counter
+
+        elif isinstance(constraint, Disj):
+            disjuncts = []
+            for c in constraint.disjuncts:
+                new_c, counter = transform_to_z3(c, counter, dimension_dict)
+                disjuncts.append(new_c)
+            return z3.Or(disjuncts), counter
+
+        elif isinstance(constraint, T):
+            return True, counter
+
+        elif isinstance(constraint, F):
+            return False, counter
+
+        elif isinstance(constraint, BinConstraintT):
+            if constraint.op == op_eq:
+                lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
+                return (lhs == rhs), counter
+
+            else:
+                raise NotImplementedError('Method not yet implemented')
+
+        elif isinstance(constraint, BinConstraintD):
+            if constraint.op == op_eq:
+
+                if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
+                    transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
+                    transformed_lhs = z3.Bool(constraint.lhs.c)
+                    return transformed_lhs == transformed_rhs, counter
+
+                elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
+                    # with dimension transformations we consider the encoding
+                    lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
+                    rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
+                    return lhs == rhs, counter
+
+                else:
+                    # then we have an algebraic expression which means that we disregard the
+                    # first element of the encoding
+                    lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
+                    rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
+                    return lhs == rhs, counter
+
+            # The assumption here is that the LHS and RHS must be dimensions
+            elif constraint.op == op_neq:
+                assert is_dim(constraint.lhs)
+                assert is_dim(constraint.rhs)
+                lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
+                if constraint.rhs == Dyn or constraint.lhs == Dyn:
+                    if constraint.rhs == Dyn:
+                        return lhs.arg(0) == 1, counter
+                    elif constraint.lhs == Dyn:
+                        return rhs.arg(0) == 1, counter
+
+                # if one of the instances is a number
+                elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
+                    if isinstance(constraint.lhs, int):
+                        return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
+
+                    elif isinstance(constraint.rhs, int):
+                        return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
+
+                else:
+                    return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
+                                  z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
+                                  z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
+
+
+            elif constraint.op == op_leq:
+                # if the dimensions are not dyn, this will come into effect
+                # there would have been another constraint specifying if a given dimension
+                # is dyn or not
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
+                return lhs <= rhs, counter
+
+            elif constraint.op == op_gt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
+                return lhs > rhs, counter
+
+            elif constraint.op == op_lt:
+                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
+                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
+                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
+                return lhs < rhs, counter
+
+            else:
+                raise NotImplementedError('operation not yet implemented')
+
+        else:
+            raise NotImplementedError('Operation not yet implemented')
+
+
+    def transform_var(tensor, counter, dimension_dict):
+        """
+        Transforms tensor variables to a format understood by z3
+        Args:
+            tensor: Tensor variable or a tensor type potentially with variable dimensions
+        Returns: Transformed variable to a z3 format
+
+        """
+        if isinstance(tensor, TensorType):
+            res = []
+            for t in tensor.__args__:
+                transformed, counter = transform_dimension(t, counter, dimension_dict)
+                res.append(transformed)
+
+            assert len(res) <= 4
+            if len(tensor.__args__) == 1:
+                return tensor_type.tensor1(res[0]), counter
+            elif len(tensor.__args__) == 2:
+                return tensor_type.tensor2(res[0], res[1]), counter
+            elif len(tensor.__args__) == 3:
+                return tensor_type.tensor3(res[0], res[1], res[2]), counter
+            elif len(tensor.__args__) == 4:
+                return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
+
+        elif tensor == Dyn:
+            return z3_dyn, counter
+
+        elif isinstance(tensor, TVar):
+            return z3.Const(tensor.tvar, tensor_type), counter
+
+    def transform_dimension(dimension, counter, dimension_dict):
+        """
+        Takes a dimension variable or a number and transforms it to a tuple
+        according to our scheme
+        Args:
+            dimension: The dimension to be transformed
+            counter: variable tracking
+
+        Returns: the encoded dimension tuple and the current counter
+
+        """
+        if dimension == Dyn:
+            counter += 1
+            return D(0, z3.Int(counter)), counter
+        elif isinstance(dimension, int):
+            return D(1, dimension), counter
+        elif isinstance(dimension, DVar):
+            if dimension.c in dimension_dict:
+                return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
+            else:
+                counter += 1
+                dimension_dict[dimension.c] = counter
+                return D(z3.Int(counter), z3.Int(dimension.c)), counter
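+
+    # The encoding used above, spelled out (an illustrative summary):
+    #
+    #   Dyn         ->  D(0, <fresh z3 Int>)                   first component 0 marks "dynamic"
+    #   3           ->  D(1, 3)                                first component 1 marks "static"
+    #   DVar('d1')  ->  D(z3.Int(<flag index>), z3.Int('d1'))  the flag is itself a solver variable
+    #
+    # so later rules can branch on arg(0) to ask whether a dimension is Dyn.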
+
+
+    def transform_algebraic_expression(expr, counter, dimension_dict):
+        """
+        Transforms an algebraic expression to z3 format
+        Args:
+            expr: Either a dimension variable or an algebraic expression
+
+        Returns: the transformed expression
+
+        """
+        assert is_algebraic_expression(expr) or is_dim(expr)
+
+        if is_dim(expr):
+            transformed, counter = transform_dimension(expr, counter, dimension_dict)
+            return transformed.arg(1), counter
+
+        elif isinstance(expr, Prod):
+
+            dims = []
+            for dim in expr.products:
+                assert is_dim(dim)
+                d, counter = transform_dimension(dim, counter, dimension_dict)
+                dims.append(d.arg(1))
+            return z3.Product(dims), counter
+
+        elif is_algebraic_expression(expr):
+
+            lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
+            rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
+
+            if expr.op == op_sub:
+                c = lhs - rhs
+
+            elif expr.op == op_add:
+                c = lhs + rhs
+
+            elif expr.op == op_div:
+                c = lhs / rhs
+
+            elif expr.op == op_mul:
+                c = lhs * rhs
+
+            elif expr.op == op_mod:
+                c = lhs % rhs
+
+            else:
+                raise NotImplementedError('operation not yet implemented')
+
+            return c, counter
+
+        else:
+            raise RuntimeError
+
+
+    def transform_all_constraints(traced, counter=0):
+        """
+        Given a trace, generates constraints and transforms them to z3 format
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(traced)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        # print(new_constraints.conjucts[0])
+        # print(*new_constraints.conjucts, sep='\n')
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+        # print(new_constraints)
+        # print(new_constraints.conjucts)
+        # new_constraints.conjucts = new_constraints.conjucts[:-1]
+        # print(*new_constraints.conjucts, sep='\n')
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+        # print(transformed)
+        return transformed
+
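+    # Illustrative usage sketch (not part of the original file); `traced` stands
+    # for a symbolically traced module accepted by ConstraintGenerator:
+    #
+    #   constraints = transform_all_constraints(traced)
+    #   s = z3.Solver()
+    #   s.add(constraints)
+    #   s.check()   # z3.sat if some assignment of shapes satisfies the traced IR
+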
+    def iterate_till_fixed_point(constraints, counter):
+        """
+        Transform constraints till reaching a fixed point
+        """
+        old_c = None
+        while old_c != constraints:
+            old_c = constraints
+            constraints, counter = transform_constraint(constraints, counter)
+        return constraints, counter
+
+    def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
+        """
+        Takes a node and a graph and generates two sets of constraints.
+        One set contains the node's constraints and the other set
+        contains the negation of the node's constraints.
+        Args:
+            tracer_root: the root for getting the module instances
+            graph: the graph so far in the tracing process
+            node: node that represents a conditional
+            counter: variable tracking
+
+        Returns: Two sets of constraints. One with a conjunction with
+        the conditional constraint and the other with a conjunction with
+        its negation.
+
+        """
+        dimension_dict = {}  # type: ignore[var-annotated]
+
+        generator = ConstraintGenerator(tracer_root, graph)
+        new_constraints, counter = generator.generate_constraints(counter)
+
+        condition_constraint = new_constraints.conjucts[-1]
+
+        # we know the constraint is a conjunction where the last constraint is about the conditional
+        # so remove the last constraint
+        new_constraints.conjucts = new_constraints.conjucts[:-1]
+
+        # transform precision, matching, consistency till obtaining a fixed point
+        new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
+
+
+        # since the function returns a list of one element, we get the first element
+        # we are only interested in the RHS in this case because the LHS just stores
+        # the result
+
+        # we make sure the constraint is of the form:
+        # c = b where b is a boolean expression
+        # and we consider b (constraint.rhs) for transformation
+        assert isinstance(condition_constraint.lhs, BVar)
+        assert is_bool_expr(condition_constraint.rhs)
+        condition_constraint_rhs = condition_constraint.rhs
+
+        # transform the condition constraint
+        condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
+
+        transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
+
+        transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
+
+        negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
+
+        return z3.And([transformed, transformed_condition_constraint]), \
+            z3.And([transformed, negation_transformed_condition_constraint])
+
+
+    def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
+        """
+        Given an IR and a node representing a conditional, evaluate the conditional
+        and its negation
+        Args:
+            tracer_root: Tracer root for module instances
+            graph: the graph traced so far
+            node: The node to be evaluated
+            counter: variable tracking
+            user_constraints: additional constraints provided by the user, if any
+
+        Returns: the results of checking the condition and its negation together with
+        the rest of the constraints
+
+        """
+
+        transformed_positive, transformed_negative = \
+            transform_all_constraints_trace_time(tracer_root, graph, node, counter)
+
+        s = z3.Solver()
+        s.add(transformed_positive)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        condition = s.check()
+
+        s = z3.Solver()
+        s.add(transformed_negative)
+        if user_constraints is not None:
+            s.add(user_constraints)
+        negation = s.check()
+        return condition, negation
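+
+    # Illustrative sketch (not part of the original file): both return values are
+    # z3 check results, so a conditional is statically decidable when exactly one
+    # branch is satisfiable, e.g.:
+    #
+    #   cond, neg = evaluate_conditional_with_constraints(tracer_root, graph, node)
+    #   if cond == z3.sat and neg == z3.unsat:
+    #       pass  # the condition always holds under the generated constraints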
+
+except ImportError:
+    HAS_Z3 = False
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..384d4a6c56bac39e7f341198bb881f5ecb52db62
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
@@ -0,0 +1,52 @@
+from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
+    BVar
+from torch.fx.experimental.migrate_gradual_types.operation import op_leq
+
+
+def gen_tvar(curr):
+    """
+    Generate a tensor variable
+    :param curr: The current counter
+    :return: a tensor variable and the updated counter
+    """
+    curr += 1
+    return TVar(curr), curr
+
+
+def gen_dvar(curr):
+    """
+    Generate a dimension variable
+    :param curr: the current counter
+    :return: a dimension variable and an updated counter
+    """
+    curr += 1
+    return DVar(curr), curr
+
+def gen_bvar(curr):
+    """
+    Generate a boolean variable
+    :param curr: the current counter
+    :return: a boolean variable and an updated counter
+    """
+    curr += 1
+    return BVar(curr), curr
+
+def gen_tensor_dims(n, curr):
+    """
+    Generate a list of tensor dimensions
+    :param n:  the number of dimensions
+    :param curr: the current counter
+    :return: a list of dimension variables and an updated counter
+    """
+    dims = []
+    for _ in range(n):
+        dvar, curr = gen_dvar(curr)
+        dims.append(dvar)
+    return dims, curr
+
+
+def gen_nat_constraints(list_of_dims):
+    """
+    Generate natural number constraints for dimensions
+    """
+    return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
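+
+
+# Illustrative usage sketch (not part of the original file): generating three
+# fresh dimension variables and their natural-number constraints.
+#
+#   dims, counter = gen_tensor_dims(3, 0)        # [DVar(1), DVar(2), DVar(3)], counter == 3
+#   nat_constraints = gen_nat_constraints(dims)  # [BinConstraintD(0, d, op_leq) for each d]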
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc3f6798e8a0e0107c4dcf8ce0377567089bad4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
@@ -0,0 +1,29 @@
+try:
+    import z3  # type: ignore[import]
+    HAS_Z3 = True
+    # dynamic type
+    dyn = z3.DeclareSort('Dyn')
+    dyn_type = z3.Const('dyn', dyn)
+
+    # dimension
+    dim = z3.Datatype('dim')
+    dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
+    dim = dim.create()
+
+    # tensors
+    tensor_type = z3.Datatype('TensorType')
+    tensor_type.declare('Dyn', ('dyn', dyn))
+    tensor_type.declare('tensor1', ('0', dim))
+    tensor_type.declare('tensor2', ('0', dim), ('1', dim))
+    tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
+    tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
+    tensor_type = tensor_type.create()
+
+    # create dimension
+    D = dim.dim
+
+    z3_dyn = tensor_type.Dyn(dyn_type)
+
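+    # Illustrative sketch (not part of the original file): with the datatypes above,
+    # a fully known 2-D tensor of shape (3, 4) and one with an unknown second
+    # dimension would be encoded as:
+    #
+    #   tensor_type.tensor2(D(1, 3), D(1, 4))
+    #   tensor_type.tensor2(D(1, 3), D(0, z3.Int(1)))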
+
+except ImportError:
+    HAS_Z3 = False
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/normalize.py b/MLPY/Lib/site-packages/torch/fx/experimental/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..4642f54c5f8a4e32053270a27b133091ca653be9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/normalize.py
@@ -0,0 +1,162 @@
+import operator
+from typing import Any, Callable, Dict, Tuple, Optional
+
+import torch
+import torch.fx
+import torch.fx as fx
+from torch.fx import Transformer, Proxy
+from torch.fx.node import Argument, Target, Node, map_aggregate
+from torch.fx.operator_schemas import (
+    normalize_module,
+    normalize_function,
+    create_type_hint,
+)
+
+from .schema_type_annotation import AnnotateTypesWithSchema
+
+
+class NormalizeArgs(Transformer):
+    """
+    Normalize arguments to Python targets. This means that
+    `args/kwargs` will be matched up to the module/functional's
+    signature and rewritten to exclusively kwargs in positional order
+    if `normalize_to_only_use_kwargs` is true. Also populates default
+    values. Does not support positional-only parameters or varargs
+    parameters (*args, **kwargs).
+
+    If the nodes have 'type' metadata, it will use it to disambiguate
+    overloads. Otherwise, it will throw an error.
+
+    Example usage:
+        m = torchvision.models.resnet18()
+        traced = torch.fx.symbolic_trace(m)
+        traced = NormalizeArgs(traced).transform()
+    """
+
+    def __init__(
+        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
+    ):
+        super().__init__(module)
+        self.node_map: Dict[Proxy, Node] = {}
+        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
+
+    def run_node(self, n: Node) -> Any:
+        args, kwargs = self.fetch_args_kwargs_from_env(n)
+
+        def get_type(arg):
+            if isinstance(arg, fx.Node):
+                return n.meta["type"] if "type" in n.meta else None
+            return type(arg)
+
+        arg_types = map_aggregate(n.args, get_type)
+        assert isinstance(arg_types, tuple)
+        arg_types = tuple([create_type_hint(i) for i in arg_types])
+        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
+        if n.op == "call_function":
+            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
+        else:
+            out = super().run_node(n)
+        if n.op != "output":
+            self.node_map[out] = n
+            out.node.meta = n.meta
+            out.node.type = n.type
+        return out
+
+    def call_function(
+        self,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Any],
+        arg_types: Optional[Tuple[Any, ...]] = None,
+        kwarg_types: Optional[Dict[str, Any]] = None,
+    ):
+        assert callable(target)
+        new_args_and_kwargs = normalize_function(
+            target,
+            args,  # type: ignore[arg-type]
+            kwargs,
+            arg_types,  # type: ignore[arg-type]
+            kwarg_types,
+            self.normalize_to_only_use_kwargs,
+        )
+        if new_args_and_kwargs:
+            new_args, new_kwargs = new_args_and_kwargs
+            return self.tracer.create_proxy(
+                "call_function", target, new_args, new_kwargs
+            )
+        else:
+            return super().call_function(target, args, kwargs)
+
+    def call_module(
+        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+    ):
+        assert isinstance(target, str)
+        new_args_and_kwargs = normalize_module(
+            self.module,
+            target,
+            args,  # type: ignore[arg-type]
+            kwargs,
+            self.normalize_to_only_use_kwargs,
+        )
+        if new_args_and_kwargs:
+            new_args, new_kwargs = new_args_and_kwargs
+            return super().call_module(target, new_args, new_kwargs)
+        else:
+            return super().call_module(target, args, kwargs)
+
+
+class NormalizeOperators(AnnotateTypesWithSchema):
+    """
+    Normalize callsites that are different ways of "spelling" the same
+    invocation into a single, canonical call. Currently supports:
+
+    1. Normalize `torch` ops that have `operator` equivalents (e.g. torch.add)
+       to the corresponding `operator` calls (e.g. operator.add) that the Tensor
+       magic methods ultimately map to, when it is possible to statically reason
+       about the call
+
+    Example usage:
+
+        m = torchvision.models.resnet18()
+
+        traced = torch.fx.symbolic_trace(m)
+
+        traced = NormalizeOperators(traced).transform()
+    """
+
+    binary_magic_method_remap: Dict[
+        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
+    ] = {
+        torch.add: operator.add,
+        torch.mul: operator.mul,
+        torch.sub: operator.sub,
+        torch.div: operator.truediv,
+        torch.floor_divide: operator.floordiv,
+        torch.remainder: operator.mod,
+        torch.eq: operator.eq,
+        torch.ne: operator.ne,
+        torch.lt: operator.lt,
+        torch.le: operator.le,
+        torch.gt: operator.gt,
+        torch.ge: operator.ge,
+    }
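+
+    # Illustrative sketch (not part of the original file): with the remap above,
+    # a call_function node targeting torch.add(x, y) is re-emitted as
+    # operator.add(x, y), which dispatches to Tensor.__add__, e.g.:
+    #
+    #   traced = NormalizeOperators(torch.fx.symbolic_trace(m)).transform()
+    #   # nodes that previously targeted torch.add now target operator.add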
+
+    def call_function(
+        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+    ):
+        # Normalize operators according to the magic methods implemented on tensors here:
+        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
+
+        assert callable(target)
+
+        if target in self.binary_magic_method_remap:
+            if len(args) != 2:
+                return super().call_function(target, args, kwargs)
+            lhs, rhs = args
+
+            return super().call_function(
+                target=self.binary_magic_method_remap[target],
+                args=(lhs, rhs),
+                kwargs={},
+            )
+
+        return super().call_function(target, args, kwargs)
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/optimization.py b/MLPY/Lib/site-packages/torch/fx/experimental/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3dc401cc39b0dc586701ff2741343d33f5df7ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/optimization.py
@@ -0,0 +1,408 @@
+import torch.fx as fx
+from torch.fx.node import Argument, Target
+from torch.nn.utils.fusion import fuse_conv_bn_eval
+from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.fx.passes.shape_prop import ShapeProp
+import copy
+from collections import defaultdict
+import torch.utils.mkldnn as th_mkldnn
+import operator
+import time
+import logging
+from enum import Enum
+
+def _parent_name(target : str) -> Tuple[str, str]:
+    """
+    Splits a qualname into parent path and last atom.
+    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
+    """
+    *parent, name = target.rsplit('.', 1)
+    return parent[0] if parent else '', name
+
+# Works for length 2 patterns with 2 modules
+def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
+    if len(node.args) == 0:
+        return False
+    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
+    for expected_type, current_node in zip(pattern, nodes):
+        if not isinstance(current_node, fx.Node):
+            return False
+        if current_node.op != 'call_module':
+            return False
+        if not isinstance(current_node.target, str):
+            return False
+        if current_node.target not in modules:
+            return False
+        if type(modules[current_node.target]) is not expected_type:
+            return False
+    return True
+
+
+def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
+    assert isinstance(node.target, str)
+    parent_name, name = _parent_name(node.target)
+    modules[node.target] = new_module
+    setattr(modules[parent_name], name, new_module)
+
+def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
+    """
+    Fuses convolution/BN layers for inference purposes. Will deepcopy your
+    model by default, but can modify the model inplace as well.
+    """
+    patterns = [(nn.Conv1d, nn.BatchNorm1d),
+                (nn.Conv2d, nn.BatchNorm2d),
+                (nn.Conv3d, nn.BatchNorm3d)]
+    if not inplace:
+        model = copy.deepcopy(model)
+    if not no_trace or not isinstance(model, torch.fx.GraphModule):
+        fx_model = fx.symbolic_trace(model)
+    else:
+        fx_model = model
+    modules = dict(fx_model.named_modules())
+    new_graph = copy.deepcopy(fx_model.graph)
+
+    for pattern in patterns:
+        for node in new_graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                bn = modules[node.target]
+                if not bn.track_running_stats:
+                    continue
+                fused_conv = fuse_conv_bn_eval(conv, bn)
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                new_graph.erase_node(node)
+    return fx.GraphModule(fx_model, new_graph)
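+
+# Illustrative usage sketch (not part of the original file); conv/BN folding is
+# meaningful for inference, so the model is assumed to be in eval mode and
+# `inp` is a hypothetical example input:
+#
+#   model = torchvision.models.resnet18().eval()
+#   fused = fuse(model)
+#   torch.testing.assert_close(fused(inp), model(inp))  # same outputs up to tolerance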
+
+def remove_dropout(model: nn.Module) -> nn.Module:
+    """
+    Removes all dropout layers from the module.
+    """
+    fx_model = fx.symbolic_trace(model)
+
+    class DropoutRemover(torch.fx.Transformer):
+        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+            if isinstance(self.submodules[target], nn.Dropout):
+                assert len(args) == 1
+                return args[0]
+            else:
+                return super().call_module(target, args, kwargs)
+    return DropoutRemover(fx_model).transform()
+
+def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
+    """
+    Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
+    """
+    new_graph = fx.Graph()
+    env: Dict[fx.Node, fx.Node] = {}
+    for input in inputs:
+        new_node = new_graph.placeholder(input.name)
+        env[input] = new_node
+    for node in nodes:
+        new_node = new_graph.node_copy(node, lambda x: env[x])
+        env[node] = new_node
+    new_graph.output([env[output] for output in outputs])
+    new_graph.lint()
+    return fx.GraphModule(orig_module, new_graph)
+
+mkldnn_supported = [
+    nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
+    torch.relu, torch.transpose, torch.sigmoid,
+    F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
+]
+# These are operators that may not be convertible into MKLDNN ops (e.g. the
+# args are scalar values). Thus, we only include them in the subgraph if their
+# arguments are already in MKLDNN.
+# TODO: Determine whether this can be removed after type inference.
+mkldnn_supported_unknown = [operator.add, operator.mul]
+mkldnn_map = {
+    nn.Conv2d: th_mkldnn.MkldnnConv2d,
+    nn.Linear: th_mkldnn.MkldnnLinear,
+    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
+}
+
+
+def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
+    """
+    For each node, if it's a module that can be preconverted into MKLDNN,
+    then we do so and create a mapping to allow us to convert from the MKLDNN
+    version of the module to the original.
+    """
+    old_modules: Dict[nn.Module, nn.Module] = {}
+    for node in nodes:
+        if node.op == 'call_module':
+            assert isinstance(node.target, str)
+            cur_module = modules[node.target]
+            if type(cur_module) in mkldnn_map:
+                new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
+                assert isinstance(new_module, nn.Module)
+                old_modules[new_module] = copy.deepcopy(cur_module)
+                replace_node_module(node, modules, new_module)
+    return old_modules
+
+def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
+    """
+    Maps each module that's been changed with `modules_to_mkldnn` back to its
+    original.
+    """
+    for node in nodes:
+        if node.op == 'call_module':
+            assert (isinstance(node.target, str))
+            cur_module = modules[node.target]
+            if cur_module in old_modules:
+                replace_node_module(node, modules, old_modules[cur_module])
+
+class MklSubgraph:
+    def __init__(self, fx_graph: fx.Graph):
+        self.fx_graph = fx_graph
+        self.nodes: List[fx.Node] = []
+        self.start_nodes: List[fx.Node] = []
+        self.end_nodes: List[fx.Node] = []
+
+def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
+    """
+    This generates a heuristic that can be passed into `optimize_for_inference` that
+    determines whether a subgraph should be run in MKL by running it with the example_inputs.
+
+    Example usage:
+        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
+        fast_model = optimization.optimize_for_inference(model, heuristic)
+    """
+    fx_model = None
+    old_modules = None
+
+    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
+        nonlocal fx_model, old_modules
+        input_nodes = graph.start_nodes
+        if fx_model is None:
+            fx_model = graph.fx_graph.owning_module
+            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
+            ShapeProp(fx_model).propagate(example_inputs)
+        sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
+        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
+        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
+
+        def benchmark(f):
+            for _ in range(warmup):
+                f()
+            begin = time.time()
+            for _ in range(iters):
+                out = f()
+            return time.time() - begin
+
+        mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
+
+        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
+        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
+        return mkl_time < no_mkl_time
+    return use_mkl_heuristic
+
+def use_mkl_length(graph: MklSubgraph) -> bool:
+    """
+    This is a heuristic that can be passed into `optimize_for_inference` that
+    determines whether a subgraph should be run in MKL by checking if there
+    are more than 2 nodes in it
+    """
+    return len(graph.nodes) > 2
+
+class UnionFind:
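+    """Disjoint-set union (union-find) over node indices: `find` applies path
+    compression and `join` unions by size. Used below to merge the "colors" of
+    MKLDNN subgraphs that turn out to be connected.
+    """
+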
+    def __init__(self, n):
+        self.parent: List[Optional[int]] = [None] * n
+        self.size: List[int] = [0] * n
+
+    def make_set(self, v: int):
+        self.parent[v] = v
+        self.size[v] = 1
+
+    def find(self, v: int) -> int:
+        par = self.parent[v]
+        if v == par:
+            return v
+        assert par is not None
+        self.parent[v] = self.find(par)
+        return cast(int, self.parent[v])
+
+    def join(self, a: int, b: int):
+        a, b = self.find(a), self.find(b)
+        if a == b:
+            return a
+        if self.size[a] < self.size[b]:
+            a, b = b, a
+        self.parent[b] = a
+        self.size[a] += self.size[b]
+
+def optimize_for_inference(
+    model: torch.nn.Module,
+    pass_config: Optional[Dict[str, Any]] = None,
+    tracer: Type[fx.Tracer] = fx.Tracer
+) -> torch.nn.Module:
+    """
+    Performs a set of optimization passes to optimize a model for the
+    purposes of inference. Specifically, the passes that are run are:
+    1. Conv/BN fusion
+    2. Dropout removal
+    3. MKL layout optimizations
+
+    The third optimization takes a function `use_mkl_heuristic` that's used
+    to determine whether a subgraph should be explicitly run in MKL layout.
+
+    Note: As FX does not currently handle aliasing, this pass currently
+    assumes nothing aliases. If that isn't true, use at your own risk.
+    """
+    default_pass_config = {
+        "conv_bn_fuse": True,
+        "remove_dropout": True,
+        "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
+    }
+    if pass_config is None:
+        pass_config = {}
+    default_pass_config.update(pass_config)
+
+    if default_pass_config["conv_bn_fuse"]:
+        model = fuse(model)
+    if default_pass_config["remove_dropout"]:
+        model = remove_dropout(model)
+    if default_pass_config["mkldnn_layout_optimize"] is False:
+        return model
+    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
+        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
+    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
+        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
+    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
+
+    cur_tracer = tracer()
+    fx_graph = cur_tracer.trace(copy.deepcopy(model))
+    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
+    modules: Dict[str, nn.Module] = dict(model.named_modules())
+
+    class MklSupport(Enum):
+        NO = 1
+        YES = 2
+        UNKNOWN = 3
+
+    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
+    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
+    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
+    # a MKLDNN node if its inputs are MKLDNN nodes.
+    for node in list(fx_graph.nodes):
+        supports_mkldnn = MklSupport.NO
+        if node.op == 'call_module':
+            cur_module = modules[node.target]
+            if type(cur_module) in mkldnn_supported:
+                supports_mkldnn = MklSupport.YES
+                sample_parameter = next(cur_module.parameters(), None)
+                if sample_parameter is not None:
+                    assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
+                    assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
+        elif node.op == 'call_function':
+            if node.target in mkldnn_supported:
+                supports_mkldnn = MklSupport.YES
+            elif node.target in mkldnn_supported_unknown:
+                supports_mkldnn = MklSupport.UNKNOWN
+
+        if supports_mkldnn != MklSupport.NO:
+            if supports_mkldnn == MklSupport.UNKNOWN:
+                if not any(arg.target == 'to_dense' for arg in node.args):
+                    continue
+            with fx_graph.inserting_before(node):
+                mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
+
+            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
+
+            with fx_graph.inserting_after(node):
+                dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
+                node.replace_all_uses_with(dense_x)
+                dense_x.args = (node,)
+
+    # Does pre-conversion of all modules into MKLDNN (when possible)
+    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
+    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]
+
+    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
+    for node in fx_graph.nodes:
+        if node.op == 'call_method' and node.target == 'to_dense':
+            prv_node = node.args[0]
+            users = list(node.users)
+            for user in users:
+                if user.op == 'call_method' and user.target == 'to_mkldnn':
+                    user.replace_all_uses_with(prv_node)
+                    fx_graph.erase_node(user)
+            if len(node.users) == 0:
+                fx_graph.erase_node(node)
+
+
+    num_nodes = len(fx_graph.nodes)
+    uf = UnionFind(num_nodes)
+
+    def get_color(n):
+        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
+            return uf.find(n.color)
+        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
+            return uf.find(n.start_color)
+        return None
+
+
+    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
+    # of input nodes (which are only `to_mkldnn` calls), output nodes
+    # (`to_dense` calls), and intermediate nodes, which are run entirely on
+    # MKLDNN layout tensors.
+    #
+    # Specifically, this code does a flood fill on a directed acyclic graph
+    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
+    # If every node only had one input, this would be sufficient. However, in
+    # the case that a node has multiple inputs coming from different start
+    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
+    # using a Disjoint Set Union.
+    for cur_idx, node in enumerate(fx_graph.nodes):
+        if node.op == 'call_method' and node.target == 'to_mkldnn':
+            node.start_color = cur_idx
+            uf.make_set(cur_idx)
+        elif node.op == 'call_method' and node.target == 'to_dense':
+            assert get_color(node.args[0]) is not None
+            node.end_color = get_color(node.args[0])
+        else:
+            cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
+
+            if len(cur_colors) == 0:
+                continue
+            assert not any(i is None for i in cur_colors)
+            cur_colors = sorted(cur_colors)
+            node.color = cur_colors[0]
+            for other_color in cur_colors[1:]:
+                uf.join(cur_colors[0], other_color)
+
+
+    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
+    for node in fx_graph.nodes:
+        if hasattr(node, 'color'):
+            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
+        if hasattr(node, 'start_color'):
+            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
+        if hasattr(node, 'end_color'):
+            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
+
+
+    # Now that we have all the subgraphs, we need to decide which MKLDNN
+    # subgraphs we actually want to keep in MKLDNN.
+    for graph in mkldnn_graphs.values():
+        if not use_mkl_heuristic(graph):
+            for node in graph.start_nodes + graph.end_nodes:
+                prv = node.args[0]
+                node.replace_all_uses_with(prv)
+                fx_graph.erase_node(node)
+            reset_modules(graph.nodes, modules, old_modules)
+
+    mkldnn_conversions = 0
+    for node in fx_graph.nodes:
+        if node.target == 'to_mkldnn' or node.target == 'to_dense':
+            mkldnn_conversions += 1
+
+    logging.getLogger(__name__).info(f"mkldnn conversions: {mkldnn_conversions}")
+    fx_graph.lint()
+    result = fx.GraphModule(model, fx_graph)
+    return result
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/partitioner_utils.py b/MLPY/Lib/site-packages/torch/fx/experimental/partitioner_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a520a0ff9845c0c0ab0c934ac9ec4487b9a19c9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/partitioner_utils.py
@@ -0,0 +1,317 @@
+from enum import Enum
+from typing import NamedTuple, Dict, List, Set
+
+from torch.fx.node import Node, map_arg
+
+
+class Partition:
+    """Partition class contains all the information about an individual partition.
+    It also provides the necessary methods for manipulating the partition.
+    """
+
+    def __init__(self, partition_id: int) -> None:
+        self.nodes: Set[Node] = set()
+        self.partition_id = partition_id
+        self.parents: Set[Partition] = set()
+        self.children: Set[Partition] = set()
+        self.bfs_level: int = -1
+        self.used_mem_bytes: int = 0
+        self.logical_device_ids: List[int] = []
+
+    def __str__(self):
+        return str(self.partition_id)
+
+    def recalculate_mem_size(self):
+        self.used_mem_bytes = 0
+        for node in self.nodes:
+            self.used_mem_bytes += get_extra_size_of(node, self.nodes)
+
+    def add_node(self, node):
+        input_nodes: Dict[Node, None] = {}
+        map_arg(node.args, input_nodes.setdefault)
+        map_arg(node.kwargs, input_nodes.setdefault)
+        # Add the current node's input nodes if they are placeholders or constants
+        for n in input_nodes:
+            if n.op in {"placeholder", "get_attr"}:
+                self.nodes.add(n)
+        self.nodes.add(node)
+        self.recalculate_mem_size()
+
+    def remove_node(self, node):
+        # Remove a node only if the node is in the partition
+        if node in self.nodes:
+            self.nodes.remove(node)
+            # Collect the node's input nodes
+            input_nodes: Dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # If an input node is a placeholder or get_attr,
+            # and it is not used by any other node in this partition,
+            # then remove this input node
+            for input_node in input_nodes:
+                if all(
+                    n not in self.nodes for n in input_node.users
+                ) and input_node.op in {"placeholder", "get_attr"}:
+                    self.nodes.remove(input_node)
+            self.recalculate_mem_size()
+
+
+class Device(NamedTuple):
+    name: str
+    available_mem_bytes: int
+    logical_id: int
+
+
+class NodeLatency(NamedTuple):
+    # Latency due to the memory bandwidth
+    mem_latency_sec: float
+    # Latency due to the computation
+    computer_latency_sec: float
+
+
+class PartitionLatency(NamedTuple):
+    # Sum of all nodes' memory latency on the critical path
+    mem_latency_sec: float
+    # Sum of all nodes' compute latency on the critical path
+    computer_latency_sec: float
+    # Latency of the critical path
+    overall_latency_sec: float
+
+
+class PartitionMode(Enum):
+    size_based = 0
+    sparse_nn = 1
+    cost_aware = 2
+    kl_based = 3
+    aot_based = 4
+
+
+class PartitionerConfig(NamedTuple):
+    devices: List[Device]
+    mode: PartitionMode = PartitionMode.size_based
+    transfer_rate_bytes_per_sec: float = 0.0
+    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
+    node_to_partition_mapping: Dict[Node, int] = {}
+    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
+    # Saturate host by replicating partitions to the remaining idle devices.
+    saturate_host: bool = False
+
+
+def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
+    """Given a node and a set of nodes,
+    this function returns the extra size needed
+    if this node is included in this set.
+    """
+    # Find all its input nodes
+    input_nodes: Dict[Node, None] = {}
+    map_arg(node.args, input_nodes.setdefault)
+    map_arg(node.kwargs, input_nodes.setdefault)
+    # Calculate total size of related nodes
+    total_size_of_input_nodes = 0
+    for n in input_nodes:
+        # Only count input nodes that are not already in this set
+        if n not in nodes:
+            size_bytes = getattr(n, "size_bytes", None)
+            if size_bytes:
+                total_size_of_input_nodes += size_bytes.output_size
+            else:
+                raise RuntimeError("node has no size_bytes attr")
+    # Don't forget the op node itself
+    size_bytes = getattr(node, "size_bytes", None)
+    if size_bytes:
+        total_size_of_input_nodes += size_bytes.total_size
+    else:
+        raise RuntimeError("node has no size_bytes attr")
+    return total_size_of_input_nodes
+
+
+def get_latency_of_one_partition(
+    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
+) -> PartitionLatency:
+    """Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
+
+    def get_top_nodes(partition: Partition) -> List[Node]:
+        """Given a partition, return a list of nodes on the top bfs level"""
+        top_nodes: List[Node] = []
+        for node in partition.nodes:
+            # Skip placeholder and get_attr nodes
+            if node.op in {"placeholder", "get_attr"}:
+                continue
+            input_nodes: Dict[Node, None] = {}
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
+            # If a node has no input nodes in this partition,
+            # or its input nodes in this partition are only placeholders and get_attrs,
+            # then this node is on the top bfs level in this partition
+            if not any(
+                n in partition.nodes and n.op not in {"placeholder", "get_attr"}
+                    for n in input_nodes
+            ):
+                top_nodes.append(node)
+        return top_nodes
+
+    def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
+        """Given a top node of a partition, this function returns
+        the latency of the critical path in the partition
+        """
+        node_latency = node_to_latency_mapping[node]
+        # Calculate the current overall latency of the partition
+        overall_latency_sec = partition_latency.overall_latency_sec + max(
+            node_latency.computer_latency_sec, node_latency.mem_latency_sec
+        )
+        # Update the mem latency of this path
+        mem_latency_sec = (
+            partition_latency.mem_latency_sec + node_latency.mem_latency_sec
+        )
+        # Update the compute latency of this path
+        computer_latency_sec = (
+            partition_latency.computer_latency_sec + node_latency.computer_latency_sec
+        )
+        # Get all users of this node that are in this partition
+        users = set(node.users).intersection(partition.nodes)
+        if users:
+            max_latency = PartitionLatency(
+                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+            )
+            for n in users:
+                # Get new partition latency recursively
+                new_partition_latency = dfs_helper(
+                    n,
+                    PartitionLatency(
+                        mem_latency_sec, computer_latency_sec, overall_latency_sec
+                    ),
+                )
+                if (
+                    new_partition_latency.overall_latency_sec
+                    > max_latency.overall_latency_sec
+                ):
+                    max_latency = new_partition_latency
+            return max_latency
+        # If there is no user, the node is at bottom of the partition
+        return PartitionLatency(
+            mem_latency_sec, computer_latency_sec, overall_latency_sec
+        )
+
+    # Main part starts
+    # Get all top level nodes of this partition
+    top_nodes = get_top_nodes(partition)
+    critical_path_latency = PartitionLatency(
+        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+    )
+    # Go through all top nodes and find the largest latency (critical path latency)
+    for node in top_nodes:
+        partition_latency = dfs_helper(
+            node,
+            PartitionLatency(
+                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
+            ),
+        )
+        if (
+            partition_latency.overall_latency_sec
+            > critical_path_latency.overall_latency_sec
+        ):
+            critical_path_latency = partition_latency
+    return critical_path_latency
+
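+# Illustrative sketch (not part of the original file): for a two-node partition
+# a -> b with
+#   node_to_latency_mapping = {a: NodeLatency(2.0, 5.0), b: NodeLatency(4.0, 1.0)}
+# the function above returns PartitionLatency(mem_latency_sec=6.0,
+# computer_latency_sec=6.0, overall_latency_sec=max(5, 2) + max(1, 4) == 9.0).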
+
+def get_partition_to_latency_mapping(
+    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
+) -> Dict[Partition, PartitionLatency]:
+    """Given all the partitions and node_to_latency_mapping dictionary,
+    return a mapping dictionary of each partition to its overall latency
+    """
+    partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
+    # Go through each partition and get its latency
+    for partition in partitions:
+        partition_latency = get_latency_of_one_partition(
+            partition, node_to_latency_mapping
+        )
+        partition_to_latency_mapping[partition] = partition_latency
+    return partition_to_latency_mapping
+
+
+def get_comm_latency_between(
+    parent_partition: Partition,
+    child_partition: Partition,
+    transfer_rate_bytes_per_sec: float,
+):
+    """Given two partitions (parent and child),
+    calculate the communication latency between the two.
+    """
+    # If two partitions are on the same device, the comm latency is 0.
+    if (
+        parent_partition.logical_device_ids != []
+        and child_partition.logical_device_ids != []
+        and parent_partition.logical_device_ids == child_partition.logical_device_ids
+    ):
+        return 0.0
+    # Keep track of the communication size between parent and child
+    comm_size = 0
+    # Keep track of all the nodes that have already been counted
+    visited_nodes = set()
+    # Go through all nodes in the child partition
+    # If a node has input nodes from the parent partition,
+    # the output size of those input nodes will be counted
+    # and added to comm_size
+    for node in child_partition.nodes:
+        input_nodes: Dict[Node, None] = {}
+        map_arg(node.args, input_nodes.setdefault)
+        map_arg(node.kwargs, input_nodes.setdefault)
+        for n in input_nodes:
+            if n in parent_partition.nodes and n not in visited_nodes:
+                size_bytes = getattr(n, "size_bytes", None)
+                if size_bytes is not None:
+                    comm_size += size_bytes.output_size
+                visited_nodes.add(n)
+    return comm_size / transfer_rate_bytes_per_sec
+
+
+def get_latency_of_partitioned_graph(
+    partitions: List[Partition],
+    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
+    transfer_rate_bytes_per_sec: float,
+):
+    """Given all partitions in a graph, find the critical path among all partitions
+    and return its latency as the latency of the whole graph
+    """
+
+    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
+        """This function helps to recursively get the latency of a path of partitions"""
+        # Update latency by adding current partition's latency
+        latency_so_far_sec += partition_to_latency_mapping[
+            partition
+        ].overall_latency_sec
+        children = partition.children
+        if partition.children:
+            max_latency_sec = 0.0
+            for child in partition.children:
+                # Calculate latency between
+                comm_latency_sec = get_comm_latency_between(
+                    partition, child, transfer_rate_bytes_per_sec
+                )
+                new_latency_sec = dfs_helper(
+                    child, latency_so_far_sec + comm_latency_sec
+                )
+                if new_latency_sec > max_latency_sec:
+                    max_latency_sec = new_latency_sec
+            return max_latency_sec
+        return latency_so_far_sec
+
+    def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
+        """This function is to return all the partitions without parents
+        as the starting points of all the paths
+        """
+        top_partitions = []
+        for partition in partitions:
+            # If a partition has no parents, then it is a top partition
+            if len(partition.parents) == 0:
+                top_partitions.append(partition)
+        return top_partitions
+
+    top_partitions = get_top_partitions(partitions)
+    critical_path_latency_sec = 0.0
+    for partition in top_partitions:
+        latency_sec = dfs_helper(partition, 0.0)
+        if latency_sec > critical_path_latency_sec:
+            critical_path_latency_sec = latency_sec
+    return critical_path_latency_sec
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/proxy_tensor.py b/MLPY/Lib/site-packages/torch/fx/experimental/proxy_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8bd1500088c0f2e19cfefa63fd40eab41b693e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/proxy_tensor.py
@@ -0,0 +1,1122 @@
+# mypy: ignore-errors
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import contextlib
+import functools
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import torch
+import torch.utils._pytree as pytree
+from torch.fx import Tracer, GraphModule
+from torch.fx.graph_module import _assign_attr
+from weakref import WeakKeyDictionary
+from collections import defaultdict
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
+from torch._dispatch.python import enable_python_dispatcher, enable_pre_dispatch
+import torch.fx as fx
+from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from contextlib import contextmanager, nullcontext
+import inspect
+from dataclasses import dataclass
+import weakref
+import operator
+from torch.utils._stats import count
+import logging
+
+from torch.overrides import TorchFunctionMode
+
+from torch.utils._python_dispatch import (
+    TorchDispatchMode,
+    _disable_infra_mode,
+    _push_mode,
+    _unset_infra_mode,
+)
+
+from ._backward_state import BackwardState
+from .sym_node import SymNode
+from ._sym_dispatch_mode import SymDispatchMode
+from torch.fx import Proxy
+import torch.fx.traceback as fx_traceback
+from torch import SymInt, SymFloat, SymBool
+from torch.utils.weak import WeakTensorKeyDictionary, WeakIdKeyDictionary, _WeakHashRef
+
+__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
+
+aten = torch.ops.aten
+prim = torch.ops.prim
+
+log = logging.getLogger(__name__)
+not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
+
+CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OperatorBase, Callable] = {}
+
+CONSTANT_NUMEL_LIMIT = 1
+
+# We currently convert all SymInt to proxies before we use them.
+# This could plausibly be handled at the Dynamo level.
+pytree.register_pytree_node(
+    torch.Size,
+    lambda xs: (list(xs), None),
+    lambda xs, _: tuple(xs),
+    flatten_with_keys_fn=lambda xs: (
+        [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
+        None,
+    ),
+)
+def fake_signature(fn, nargs):
+    """FX gets confused by varargs, de-confuse it"""
+    argnames = ",".join(f"arg{i}" for i in range(nargs))
+    return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
+
+@contextmanager
+def decompose(decomposition_table):
+    global CURRENT_DECOMPOSITION_TABLE
+    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
+    CURRENT_DECOMPOSITION_TABLE = decomposition_table
+    try:
+        yield CURRENT_DECOMPOSITION_TABLE
+    finally:
+        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
+
+# ensure we cannot collide with other properties
+proxy_slot = object()
+no_default = object()
+
+py_sym_types = (SymInt, SymFloat, SymBool)
+
+def is_sym_node(node):
+    assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
+    return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
+
+def set_proxy_slot(obj, tracer, proxy):
+    if isinstance(obj, torch.Tensor):
+        # We DO want to clobber proxies whenever we run an inplace operation
+        # on a tensor, and it affects the metadata on the proxy.
+        tracer.tensor_tracker[obj] = proxy
+    elif isinstance(obj, torch.ScriptObject):
+        # We DO want to clobber proxies, with a similar rationale as for tensors.
+        tracer.script_object_tracker[obj] = proxy
+    else:
+        # NB: Never clobber pre-existing proxy.  Although the proxies
+        # are in principle equivalent, when we do graph partitioning
+        # we need there not to be spurious dependencies on tangent inputs.
+        # This works because primals get their SymInts set first, and
+        # THEN later we allocate tangent inputs.  Make sure if a SymInt
+        # is derivable from a primal that we use that.
+        assert isinstance(obj, py_sym_types), type(obj)
+        if obj not in tracer.symnode_tracker:
+            tracer.symnode_tracker[obj] = proxy
+
+def has_proxy_slot(obj, tracer):
+    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
+    return get_proxy_slot(obj, tracer, False, lambda _: True)
+
+# the default argument is what to return if the slot is not set.
+# the transform argument is handy if you need to extract a subfield from
+# the successfully looked up result (but NOT the default.)
+def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
+    if isinstance(obj, torch.Tensor):
+        tracker = tracer.tensor_tracker
+    elif isinstance(obj, torch.ScriptObject):
+        tracker = tracer.script_object_tracker
+    else:
+        assert isinstance(obj, py_sym_types), type(obj)
+        tracker = tracer.symnode_tracker
+
+    if obj not in tracker:
+        if default is no_default:
+            raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
+        return default
+    return transform(tracker[obj])
+
+def snapshot_fake(val):
+    return val.detach()
+
+def extract_val(val):
+    if is_fake(val):
+        return snapshot_fake(val)
+    elif isinstance(val, py_sym_types):
+        return val
+    elif isinstance(val, torch.ScriptObject):
+        return val
+    elif isinstance(val, BackwardState):
+        return val
+    elif isinstance(val, (list, tuple)):
+        return val.__class__([extract_val(x) for x in val])
+    elif isinstance(val, torch.Tensor):
+        if not val.is_sparse:
+            # NB: Kinda hacky, but we should try to get val as the metadata
+            # everywhere
+            # TODO: This doesn't properly track storages.  A more robust
+            # approach would be to maintain a per-trace FakeTensorMode and
+            # from_real_tensor to create fake values (don't forget to
+            # snapshot_fake)
+            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
+            with fake_tensor_mode:
+                return torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
+        else:
+            return None
+    elif isinstance(val, (int, float, bool)):
+        return val
+
+# What invariants do we have for the 'val' set on the FX node?  It has accurate
+# metadata... but only for metadata that exists "below" all other subsystems
+# (most notably autograd, but also vmap, functorch transforms, etc).  This means
+# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad,
+# grad_fn, _base (_base actually may be set due to recursive call to
+# ADInplaceOrView, but you shouldn't rely on it.)
+def set_meta(proxy, val):
+    proxy.node.meta['val'] = extract_val(val)
+    # Best effort tensor_meta setting; prefer using val!
+    if is_fake(val):
+        proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
+    elif isinstance(val, torch.Tensor) and not val.is_sparse:
+        proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
+    return proxy
+
+def thunkify(f, *args, **kwargs):
+    """
+    Delays computation of f until the returned thunk is called
+    Also caches the result
+    """
+    return functools.lru_cache(1)(functools.partial(f, *args, **kwargs))
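+
+# Illustrative sketch (not part of the original file): thunkify turns an eager
+# call into a lazily evaluated, memoized thunk; `expensive_fn` and `arg` are
+# hypothetical names.
+#
+#   thunk = thunkify(expensive_fn, arg)
+#   thunk()   # runs expensive_fn(arg) once and caches the result
+#   thunk()   # returns the cached result without recomputing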
+
+def track_tensor(tensor, proxy, *, constant, tracer):
+    def try_set_proxy_slot(outer_s, proxy_callable, *args):
+        assert callable(proxy_callable)
+        if isinstance(outer_s, SymInt):
+            set_proxy_slot(outer_s, tracer, thunkify(proxy_callable, outer_s, *args))
+    # The basic idea is that we need to associate each tensor/SymInt
+    # with a Proxy.  How do we setup this association?  We just store
+    # the proxy on the proxy slot of the object, keyed on the tracer
+    # (so that if we have multiple tracers at the same time, they
+    # don't clobber each other.)
+    for i, s in enumerate(tensor.shape):
+        try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size.int(proxy, i), x), i)
+
+    for i, s in enumerate(tensor.stride()):
+        try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride.int(proxy, i), x), i)
+
+    try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel.default(proxy), x))
+    try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset.default(proxy), x))
+    set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
+
+def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
+    def wrap_with_proxy(e, proxy, constant):
+        if isinstance(e, torch.Tensor):
+            track_tensor(e, proxy, tracer=tracer, constant=constant)
+            set_meta(proxy, e)
+        elif isinstance(e, py_sym_types):
+            # NB: eagerly set meta here, so that the numbering is in order
+            set_meta(proxy, e)
+            set_proxy_slot(e, tracer, lambda: proxy)
+        elif isinstance(e, torch.ScriptObject):
+            set_proxy_slot(e, tracer, proxy)
+            set_meta(proxy, e)
+        elif isinstance(e, (tuple, list)):
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
+            # example use case: allreduce_ returns ([tensor], work)
+            for idx, ee in enumerate(e):
+                wrap_with_proxy(ee, proxy[idx], get_constant(idx))
+        elif isinstance(e, dict):
+            # In theory we could support const-prop when proxy-tensor-tracing
+            # operators that returns dicts of tensors, but we have no use case
+            # for it today (since the only op we currently trace that can
+            # return a dict is triton_kernel_wrapper_functional/mutation,
+            # which does not participate in const-prop)
+            assert constant is None
+
+            if isinstance(proxy, fx.Proxy):
+                set_meta(proxy, e)
+
+            # example use case: triton_kernel_wrapper takes arguments as kwargs
+            for key, val in e.items():
+                wrap_with_proxy(val, proxy[key], None)
+        elif isinstance(e, BackwardState):
+            set_meta(proxy, e)
+            e.proxy = proxy
+        else:
+            # intentionally pass on primitives
+            pass
+
+
+    def get_constant(idx):
+        if constant is None:
+            return None
+        else:
+            return constant[idx]
+
+    wrap_with_proxy(inner_res, proxy_res, constant)
+
+    return inner_res
+
+
+def maybe_disable_fake_tensor_mode():
+    # TODO: figure out if this API generally makes sense and bake it into the
+    # library
+    return unset_fake_temporarily()
+
+
+@dataclass
+class _ProxyTensor:
+    proxy: Proxy
+    constant: Optional[torch.Tensor]
+
+
+def fetch_sym_proxy(tracer):
+    def inner(e):
+        n = e.node
+        if n.constant is not None:
+            return n.constant
+        if e.node.expr.is_number:
+            if isinstance(e, SymBool):
+                return bool(e.node.expr)
+            elif isinstance(e, SymInt):
+                return int(e.node.expr)
+            return float(e.node.expr)
+        else:
+            # NB: we REQUIRE all symints to be tracked
+            return get_proxy_slot(e, tracer)()
+    return inner
+
+
+def fetch_object_proxy(tracer):
+    return lambda t: get_proxy_slot(t, tracer, t)
+
+HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, FakeTensor)
+
+def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
+    unrecognized_types = []
+
+    def can_handle_tensor(x):
+        r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
+        if proxy_mode._allow_fake_constant:
+            r = r or type(x) in (torch._subclasses.FakeTensor,)
+        if not r:
+            unrecognized_types.append(type(x))
+        return r
+
+    # If there are any tensor subclasses, we need to handle those tensor subclasses first
+    # TODO: we could use types to test this
+    if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
+        not_implemented_log.debug("ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", unrecognized_types)
+        return NotImplemented
+
+    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
+    if r is not NotImplemented:
+        return r
+
+    # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
+    if not pre_dispatch and func not in [
+        torch.ops.aten.size.default, torch.ops.aten.stride.default, torch.ops.aten.storage_offset.default
+    ]:
+        with proxy_mode:
+            r = func.decompose(*args, **kwargs)
+            if r is not NotImplemented:
+                return r
+
+    tracer = proxy_mode.tracer
+    f_args, f_kwargs = pytree.tree_map_only((torch.Tensor, torch.ScriptObject), fetch_object_proxy(tracer), (args, kwargs))
+
+    # If there are SymInts, we also should not consider this constant.
+    # However, fake tensor handling of SymInts is sufficiently broken that
+    # I couldn't write a test for this case
+    all_constant = (
+        pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
+        # TODO: maybe constant SymInts should also be allowed?  Not sure if
+        # this can happen
+        and pytree.tree_all_only((SymInt, SymFloat, SymBool), lambda _: False, (args, kwargs))
+    )
+
+    if torch.Tag.data_dependent_output in func.tags:
+        # Check if all of the Tensor inputs are constants
+        if all_constant:
+            const_args, const_kwargs = pytree.tree_map_only(
+                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
+            )
+            with maybe_disable_fake_tensor_mode():
+                return func(*const_args, **const_kwargs)
+        # If any of the Tensor inputs are "real" (not FakeTensor), we may
+        # incorrectly burn in constants by allowing this access.  Raise
+        # an error in this case
+        if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only(torch.Tensor, lambda t: not is_fake(t), (args, kwargs)):
+            raise RuntimeError(
+                f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
+                "It's likely that this is caused by data-dependent control flow or similar.  "
+                "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
+                "in your make_fx call."
+            )
+    proxy_args, proxy_kwargs = pytree.tree_map_only(
+        (SymInt, SymFloat, SymBool),
+        fetch_sym_proxy(proxy_mode.tracer),
+        pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
+    )
+
+    # When we trace through a torch.tensor invocation, you never actually
+    # see a torch.ops.aten.tensor call. Instead, the way this function is
+    # implemented internally is that we allocate a plain tensor (this is
+    # *guaranteed* to be a plain tensor, we disable all modes when doing
+    # so), and then call at::lift_fresh on it (to give modes a chance to do
+    # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
+    # to be freshly allocated, so we want lift_fresh to be a no-op (directly
+    # returning the input argument).
+    #
+    # Here is the basic problem: when we trace this sequence of executions
+    # into an FX graph, what happens to this call sequence?  Traditionally,
+    # tensor constants get interned as buffers on the FX GraphModule.  But
+    # this is dangerous.  Consider:
+    #
+    #       x = torch.tensor(1)
+    #       x.add_(2)
+    #
+    # Naively, this traces into:
+    #
+    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
+    #       x = torch.ops.aten.lift_fresh(t)
+    #       x.add_(2)
+    #
+    # If lift_fresh returns t directly, the subsequent add_ call will
+    # modify the tensor constant. Really, the problem is that we've violated
+    # the invariant that the argument to lift_fresh is fresh.  So we
+    # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
+    #
+    #       t = self._tensor_constant0  # initialized to torch.tensor(1)
+    #       x = torch.ops.aten.lift_fresh_copy(t)
+    #       x.add_(2)
+    #
+    # This is what the overload modification does.
+    if func is torch.ops.aten.lift_fresh.default:
+        func = torch.ops.aten.lift_fresh_copy.default
+
+
+    proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
+                                               name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
+
+    # This makes DCE marginally less likely to DCE in-place operations.
+    # It is not strictly necessary.
+    # Kind of a hacky way to test whether an op is in-place or not
+    if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
+        if isinstance(args[0], List):
+            # e.g., c10d::allreduce_ returns a list of tensors as the first element
+            # in the output.
+            for i, a in enumerate(args[0]):
+                a.proxy = proxy_out[0][i]
+        else:
+            args[0].proxy = proxy_out
+
+    out = func(*args, **kwargs)
+
+    # In some circumstances, we will be tracing in a situation where a tensor
+    # is *statically* known to be a constant (currently, this only happens if
+    # you run torch.tensor; deterministic factory functions like torch.arange
+    # don't get this treatment).  When the tensor in question is small, it's
+    # helpful to do constant propagation in case we call item() (in which
+    # case we can return the constant value that is known, rather than raise
+    # an error).  The logic here tests whether constant propagation is possible
+    # (because all of the inputs are constant).  If so, we disable fake tensor
+    # mode (if it is on) and do true compute on the constant.
+    #
+    # It's worth highlighting that we're making a policy decision here.
+    # There is a potential that the tensor is actually quite large, and we
+    # don't actually want to run the compute.  The tensor being quite large
+    # is one of the reasons why factory functions don't get this treatment
+    # (since they can be quite large; if a parameter is initialized to a
+    # constant value it will be!)  Similarly, there is also a potential
+    # to run an operator that blows up the size of a small tensor; we don't
+    # protect against this case, but we could force, e.g., only single
+    # element constant computation by testing the numel of the result before
+    # propagating const-ness.  Similarly, we don't require the constant to
+    # live on CPU, but we could.
+    any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
+
+    constant = None
+
+    # If this is a lift, the input tensor is guaranteed to be a
+    # constant, so we keep a copy of the original argument around so
+    # we can query it if we're asked to item() it at some later point
+    if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT:
+        with maybe_disable_fake_tensor_mode():
+            constant = args[0].clone()
+    elif (
+        torch.Tag.nondeterministic_seeded not in func.tags
+        and all_constant
+        and any_constant
+        and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
+    ):
+        # NB: do NOT include factories as constants
+        with maybe_disable_fake_tensor_mode():
+            const_args, const_kwargs = pytree.tree_map_only(
+                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
+            )
+            constant = func(*const_args, **const_kwargs)
+    else:
+        constant = None
+
+    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+    return out
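+
+# Illustrative sketch (editor's note, not part of the upstream file): the
+# constant-propagation path above is what makes a data-dependent call such as
+# item() on a small torch.tensor(...) constant succeed under tracing.  A
+# minimal example, assuming make_fx from this module:
+#
+#     def f(x):
+#         return x + torch.tensor(2).item()
+#
+#     gm = make_fx(f)(torch.randn(3))   # item() yields the known constant 2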
+
+class _SymNodeDict:
+    """
+    Wrapper around a dictionary that will hash SymInts with their nodes
+    """
+    def __init__(self):
+        self.sym_node_dict = {}
+
+    def __setitem__(self, key: py_sym_types, value: Any):
+        self.sym_node_dict[key.node] = value
+
+    def __getitem__(self, key: py_sym_types):
+        return self.sym_node_dict[key.node]
+
+    def __contains__(self, key: py_sym_types):
+        return key.node in self.sym_node_dict
+
+    def get(self, key: py_sym_types, default: Any = None):
+        return self.sym_node_dict.get(key.node, default)
+
+class PythonKeyTracer(Tracer):
+    def __init__(self):
+        super().__init__(autowrap_modules=())
+        self.tensor_tracker = WeakTensorKeyDictionary()
+        self.symnode_tracker = _SymNodeDict()  # type: ignore[var-annotated]
+        self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
+
+    # In general, we don't want to make modules leaves. In principle, users of
+    # this tracer might want to override this in order to turn a couple specific
+    # modules into leaves in the traced graph.
+    def call_module(
+            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Any:
+        return forward(*args, **kwargs)
+
+    # We don't want to turn getattr calls into proxies. So we just return the actual value.
+    def getattr(self, attr, attr_val, parameter_proxy_cache):
+        return attr_val
+
+    def create_arg(self, a: Any):
+        if isinstance(a, torch.nn.Parameter):
+            for n, p in self.root.named_parameters():
+                if a is p:
+                    return self.create_node('get_attr', n, (), {})
+            qualname: Optional[str] = None
+
+            if not qualname:
+                i = 0
+                while True:
+                    qualname = f'_param_constant{i}'
+                    if not hasattr(self.root, qualname):
+                        break
+                    i += 1
+                setattr(self.root, qualname, a)
+
+            return self.create_node('get_attr', qualname, (), {})
+        elif isinstance(a, (SymInt, SymFloat, SymBool)):
+            assert a.node.constant is not None
+            return a.node.constant
+        return super().create_arg(a)
+
+    def unwrap_proxy(self, e):
+        if isinstance(e, torch.Tensor):
+            return get_proxy_slot(e, self, e, lambda e: e.proxy)
+        elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+            return get_proxy_slot(e, self, e, lambda e: e())
+        elif isinstance(e, torch.ScriptObject):
+            return get_proxy_slot(e, self, e)
+        else:
+            return e
+
+
+@torch._disable_dynamo
+def dispatch_trace(
+        root: Union[torch.nn.Module, Callable],
+        tracer: Tracer,
+        concrete_args: Optional[Tuple[Any, ...]] = None,
+) -> GraphModule:
+    graph = tracer.trace(root, concrete_args)
+    from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
+    dedupe_symints(graph)
+    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
+
+
+def wrap_key(f, tensors, tracer, pre_dispatch: bool):
+    flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
+
+    @functools.wraps(f)
+    def wrapped(*proxies):
+        flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
+        assert len(flat_proxies) == len(flat_tensors)
+        with disable_proxy_modes_tracing() as m:
+            assert isinstance(m, ProxyTorchDispatchMode)
+            track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
+
+        out = f(*tensors)
+        out = pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),
+            out
+        )
+        out = pytree.tree_map_only(
+            (SymInt, SymFloat, SymBool),
+            lambda t: get_proxy_slot(t, tracer)(),
+            out
+        )
+        return out
+
+    return wrapped
+
+ORIGINAL_ATEN = None
+@contextmanager
+def set_original_aten_op(func):
+    global ORIGINAL_ATEN
+    if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
+        ORIGINAL_ATEN = func
+        fx_traceback.current_meta['original_aten'] = func
+        try:
+            yield
+        finally:
+            ORIGINAL_ATEN = None
+            fx_traceback.current_meta['original_aten'] = None
+    else:
+        yield
+
+
+
+# This mode is **only** used for pre_dispatch tracing.
+# In particular, we need to make sure that autograd/autocast API's
+# that do not desugar into dispatcher operators stay in the graph.
+class PreDispatchTorchFunctionMode(TorchFunctionMode):
+
+    def __init__(self, tracer):
+        self.tracer = tracer
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if func in _side_effectful_need_to_be_preserved_pre_dispatch:
+            # This is for passing the export verifier, which needs to verify meta['val'].
+            # TODO(tmanlaibaatar): we should systematically couple this with the export verifier,
+            # instead of hardcoding it here.
+            node = self.tracer.create_node("call_function", func, args, {})
+            if func is torch._C._set_grad_enabled:
+                node.meta['val'] = None
+            return node
+            # Don't actually run the function! We just want to trace the calls
+            # into a graph. We don't actually want to change global autograd state.
+        return func(*args, **kwargs)
+
+
+class ProxyTorchDispatchMode(TorchDispatchMode):
+    def __init__(self, tracer, tracing_mode, pre_dispatch=False, _allow_fake_constant=False, _error_on_data_dependent_ops=True):
+        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
+        super().__init__(dk)
+        self.tracer = tracer
+        self.tracing_mode = tracing_mode
+        self.enable_tracing = True
+        self.pre_dispatch = pre_dispatch
+        self._allow_fake_constant = _allow_fake_constant
+        self._error_on_data_dependent_ops = _error_on_data_dependent_ops
+        self.sym_mode = ProxySymDispatchMode(tracer)
+        self.trace_state = {}
+        self._managers = []
+        # Indicates to our torch_dispatch dispatching infra that
+        # this is an "infra" mode with lower dispatching precedence.
+        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
+        # Every time we enter a mode, we maintain a stack telling us what the previous
+        # ProxyTorchDispatchMode state was (if there was any).
+        # This lets us properly reset the state on exit.
+        self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
+
+    @count
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        with self.sym_mode.enable(False), set_original_aten_op(func):
+            return self.inner_torch_dispatch(func, types, args, kwargs)
+
+    def __enter__(self):
+        # sym mode first, then us...
+        m = self.sym_mode.enable(True)
+        self._managers.append(m)
+        m.__enter__()
+        # Stash and store the previous proxy mode (there may or may not be one)
+        maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
+        self.enter_stack.append(maybe_prev_proxy_mode)
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        m = self._managers.pop()
+        # ...exit us first, then sym mode
+        b = super().__exit__(exc_type, exc_value, traceback)
+
+        # Re-enable the previous proxy mode, if there was one.
+        mb_previous_proxy_mode = self.enter_stack.pop()
+        if mb_previous_proxy_mode is not None:
+            _push_mode(mb_previous_proxy_mode)
+
+        if not b:
+            return m.__exit__(exc_type, exc_value, traceback)
+        else:
+            return m.__exit__(None, None, None)
+
+
+    def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
+        if not self.enable_tracing:
+            return func(*args, **kwargs)
+
+        if func in [prim.device.default]:
+            return func(*args, **kwargs)
+
+        return proxy_call(self, func, self.pre_dispatch, args, kwargs)
+
+
+class ProxySymDispatchMode(SymDispatchMode):
+    def __init__(self, tracer):
+        super().__init__()
+        self.tracer = tracer
+        # When false, we don't trace operations.  If you do this, you MUST
+        # call track_tensor/track_tensor_tree on all results of the operation
+        # to ensure we can adequately track the results
+        self.enable_tracing = True
+
+    @contextmanager
+    def enable(self, b):
+        old = self.enable_tracing
+        self.enable_tracing = b
+        try:
+            yield
+        finally:
+            self.enable_tracing = old
+
+    def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat, SymBool]):
+        n_args = tuple(
+            get_proxy_slot(a, self.tracer)().node if isinstance(a, py_sym_types) else a
+            for a in args
+        )
+
+        # func doesn't have a __torch_function__ that Proxy can interpose, so
+        # we gotta do it manually
+        n_out = self.tracer.create_node("call_function", func, n_args, {})
+        p_out = fx.Proxy(n_out, self.tracer)
+        set_meta(p_out, out)
+        return p_out
+
+    def __sym_dispatch__(self, func, types, args, kwargs):
+        if not self.enable_tracing:
+            return func(*args, **kwargs)
+
+        # Peephole optimize multiply by one
+        # NB: be careful not to trigger guards here!
+        if func == operator.mul:
+            if isinstance(args[1], int) and args[1] == 1:
+                return args[0]
+            elif isinstance(args[0], int) and args[0] == 1:
+                return args[1]
+
+        # For speed, we assume there are no nested data structures
+        # (otherwise we could use tree_map)
+        # We also assume there are no keyword arguments.
+        assert not kwargs
+        out = func(*args, **kwargs)
+
+        # If func returned a constant, we don't need to trace; we have
+        # determined that the result is constant (no matter if the inputs
+        # were symbolic) and it is no longer necessary to trace the
+        # computation.  This could occur if func triggered some guards.
+        if isinstance(out, py_sym_types):
+            # Delays tracing out the proxies on this op until we actually need it
+            p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
+            set_proxy_slot(out, self.tracer, p_out_thunk)
+
+        return out
+
+
+# TODO: I'm not sure what the point of this class is; you can just
+# make_fx through a regular Interpreter
+class DecompositionInterpreter(torch.fx.Interpreter):
+    def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
+        super().__init__(module, **kwargs)
+        self.new_graph = new_graph
+        self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
+        # Blegh
+        self.tracer.tensor_tracker = WeakTensorKeyDictionary()  # type: ignore[attr-defined]
+        self.tracer.symnode_tracker = weakref.WeakKeyDictionary()  # type: ignore[attr-defined]
+        self.decomposition_table = decomposition_table
+        if self.decomposition_table is None:
+            self.decomposition_table = {}
+        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
+
+    def placeholder(self, target, args, kwargs):
+        out = super().placeholder(target, args, kwargs)
+        proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+        # TODO handle case where the first character of target is '*'
+        return out
+
+    def get_attr(self, target, args, kwargs):
+        out = super().get_attr(target, args, kwargs)
+        proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+        return out
+
+    # call_function, call_method, call_module get traced automatically by the outer mode.
+
+    def output(self, target, args, kwargs):
+        out = super().output(target, args, kwargs)
+
+        def unwrap(e):
+            return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node)
+        self.new_graph.output(pytree.tree_map(unwrap, out))
+        return out
+
+    def run(self, *args, **kwargs):
+        # We should enter the mode at least once so that we can restore it later
+        # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
+        with decompose(self.decomposition_table), self.mode:
+            return super().run(*args, **kwargs)
+
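+
+# Illustrative sketch (editor's note, not part of the upstream file): a typical
+# use of DecompositionInterpreter is to re-trace an existing GraphModule into a
+# fresh graph while applying a decomposition table.  Here `traced_gm`,
+# `decomp_table` and `example_inputs` are assumed to already exist:
+#
+#     new_graph = torch.fx.Graph()
+#     DecompositionInterpreter(traced_gm, new_graph,
+#                              decomposition_table=decomp_table).run(*example_inputs)
+#     decomposed_gm = torch.fx.GraphModule(traced_gm, new_graph)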
+
+def wrapper_and_args_for_make_fx(func, args, kwargs):
+    # make_fx doesn't support kwargs, so we need to do this flattening
+    # and then unflatten the args before calling func
+    flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+    def wrapped(flat_args):
+        fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
+        return func(*fn_args, **fn_kwargs)
+    return wrapped, flat_args
+
+@contextmanager
+def disable_autocast_cache():
+    old_value = torch.is_autocast_cache_enabled()
+    torch.set_autocast_cache_enabled(False)
+    try:
+        yield
+    finally:
+        torch.set_autocast_cache_enabled(old_value)
+
+
+class _ModuleStackTracer(PythonKeyTracer):
+    r"""Customized version of PythonKeyTracer that retains module stack
+    information in node.meta["nn_module_stack"].
+
+    FX symbolic trace actually does this already, but it relies on `self.root`
+    being the actual module being traced. Since make_fx traces a lambda of our
+    creation, things don't work properly.
+
+    So for this version we hold onto a reference to the original module
+    (scope_root) and use that to match the path. Also when we see,
+            A
+           / \
+          B   C
+           \ /
+            D
+    we want to record the path as A.B.D by recording only one path.
+    See Note [Preserving the nn module stack metadata during export non-strict mode]  # noqa: W605
+    """
+
+    def __init__(self, scope_root):
+        super().__init__()
+        self.scope_root = scope_root
+        self.proxy_paths = WeakKeyDictionary()
+        self.proxy_modules = WeakKeyDictionary()
+        self.counter = 0
+
+        self.module_id_cache = defaultdict(list)
+        for name, mod in self.scope_root.named_modules(remove_duplicate=False):
+            self.module_id_cache[id(mod)].append(name)
+
+        self_ = self
+
+        class AttrProxy:
+            def __init__(self, base, path):
+                self.__class__ = type(
+                    base.__class__.__name__,
+                    (self.__class__, base.__class__),
+                    {},
+                )
+                self.__dict__ = base.__dict__
+                self.__class__.__module__ = base.__class__.__module__
+                self.__class__.__qualname__ = base.__class__.__qualname__
+                self_.proxy_paths[self] = path
+                self_.proxy_modules[self] = base
+
+            def __getattr__(self, name):
+                assert isinstance(self, torch.nn.Module)
+                attr_val = super().__getattr__(name)
+                if isinstance(attr_val, AttrProxy):
+                    attr_val = self_.proxy_modules[attr_val]
+                elif not isinstance(attr_val, torch.nn.Module):
+                    return attr_val
+                return AttrProxy(attr_val, self_.proxy_paths[self] + "." + name)
+
+            @property
+            def _modules(self):
+                assert "_modules" in self.__dict__
+                submodules = self.__dict__["_modules"]
+                assert isinstance(submodules, dict)
+                return {
+                    key: AttrProxy(value, self_.proxy_paths[self] + "." + str(key))
+                    for key, value in submodules.items()
+                }
+
+        self.proxy_type = AttrProxy
+
+    def path_of_module(self, mod: torch.nn.Module) -> str:
+        """
+        Use tracked access path during tracing instead of the default BFS behavior.
+        Still use all the possible module paths to verify the result.
+        """
+        if mod is self.scope_root:
+            return ""
+
+        if isinstance(mod, self.proxy_type):
+            return self.proxy_paths[mod]
+
+        return Tracer.path_of_module(self, mod)
+
+    def getattr(self, attr, attr_val, parameter_proxy_cache):
+        if not isinstance(attr_val, torch.nn.Module) or isinstance(attr_val, torch.fx.GraphModule):
+            return super().getattr(attr, attr_val, parameter_proxy_cache)
+        if isinstance(attr_val, self.proxy_type):
+            return attr_val
+        return self.proxy_type(attr_val, attr)
+
+    def trace(self, root, concrete_args):
+        res = super().trace(root, concrete_args)
+        # Since we are making AttrProxy mimic the original
+        # submodule, when someone registers a module directly
+        # to the tracer while tracing, the proxy object gets registered
+        # first. So we need to replace the proxy modules with the real ones.
+        # This can happen during HOO tracing.
+        proxy_module_names_to_be_replaced = []
+        for name, module in self.root.named_modules():
+            if module in self.proxy_modules:
+                proxy_module_names_to_be_replaced.append((name, module))
+
+        def _delete_proxy_attr(obj, target):
+            # Copied from fx/graph_module.py
+            # Customized it for proxy type
+            atoms = target.split(".")
+            path, target_submod = atoms[:-1], atoms[-1]
+            assert isinstance(obj, torch.nn.Module)
+            mod = obj
+
+            # Get the parent module
+            for item in path:
+
+                if not hasattr(mod, item):
+                    return False
+
+                mod = getattr(mod, item)
+
+                if not isinstance(mod, (self.proxy_type, torch.nn.Module)):
+                    return False
+
+            if not hasattr(mod, target_submod):
+                return False
+
+            # At least the leaf module should be proxy type.
+            if not isinstance(getattr(mod, target_submod), self.proxy_type):
+                return False
+
+            delattr(mod, target_submod)
+            return True
+
+        for (proxy_module_name, proxy_module) in proxy_module_names_to_be_replaced:
+            _delete_proxy_attr(self.root, proxy_module_name)
+            actual_module = self.proxy_modules[proxy_module]
+            _assign_attr(actual_module, self.root, proxy_module_name)
+
+        return res
+
+
+    def call_module(self, m, forward, args, kwargs):
+        """PythonKeyTracer overrides call_module to avoid the scope handling,
+        but we actually want it.
+        """
+        from torch._dynamo import OptimizedModule
+        # FIXME (tmanlaibaatar)
+        # When we call torch.compile inside HOO, we will end up
+        # invoking a module that is not registered on the root. For
+        # now, we just inline them. But once we start supporting
+        # mark_strict in export, we do need to properly handle this.
+        # Right now, it doesn't matter because current non-strict
+        # use cases don't need to work with HOO.
+        if isinstance(m, (OptimizedModule, GraphModule)):
+            return forward(*args, **kwargs)
+        return Tracer.call_module(self, m, forward, args, kwargs)
+
+
+    def is_leaf_module(self, m, module_qualified_name):
+        return False
+
+
+def make_fx(f,
+            decomposition_table=None,
+            tracing_mode="real",
+            _allow_non_fake_inputs=False,
+            *,
+            pre_dispatch=False,
+            record_module_stack=False,
+            _allow_fake_constant=False,
+            _error_on_data_dependent_ops=True):
+    assert tracing_mode in ["real", "fake", "symbolic"]
+
+    if decomposition_table is None:
+        decomposition_table = {}
+
+    if torch.ops.aten.sym_numel.default not in decomposition_table:
+        decomposition_table = {
+            **decomposition_table,
+            torch.ops.aten.sym_numel.default: torch._decomp.decompositions.sym_numel
+        }
+
+    @functools.wraps(f)
+    def wrapped(*args):
+        # Avoid importing sympy at a module level
+        from .symbolic_shapes import ShapeEnv
+
+        phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
+
+        if hasattr(f, "_orig_mod") and record_module_stack:
+            scope_root = f._orig_mod
+            fx_tracer = _ModuleStackTracer(scope_root)
+        else:
+            fx_tracer = PythonKeyTracer()
+        fake_tensor_mode: Any = nullcontext()
+        if tracing_mode == "real":
+            fake_tensor_mode = nullcontext()
+        elif tracing_mode == "fake":
+            import torch._dynamo
+            fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
+            if fake_tensor_mode is None:
+                fake_tensor_mode = FakeTensorMode(
+                    allow_fallback_kernels=True,
+                    allow_non_fake_inputs=_allow_non_fake_inputs,
+                    shape_env=ShapeEnv(),
+                    static_shapes=True,
+                )
+        elif tracing_mode == "symbolic":
+            import torch._dynamo
+            fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
+            if fake_tensor_mode is None:
+                shape_env = ShapeEnv()
+                fake_tensor_mode = FakeTensorMode(
+                    allow_fallback_kernels=False,
+                    allow_non_fake_inputs=_allow_non_fake_inputs,
+                    shape_env=shape_env)
+            else:
+                shape_env = fake_tensor_mode.shape_env
+                assert shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
+
+        else:
+            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
+
+        python_dispatcher_mode: Any = nullcontext()
+        pre_dispatch_mode: Any = nullcontext()
+        # pre-autograd tracing uses per-dispatch-key modes,
+        # which requires the python dispatcher
+        if tracing_mode == "symbolic" or pre_dispatch:
+            python_dispatcher_mode = enable_python_dispatcher()
+        if pre_dispatch:
+            pre_dispatch_mode = enable_pre_dispatch()
+
+        proxy_function_mode: Any = nullcontext()
+        if pre_dispatch:
+            proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)
+
+        proxy_mode = ProxyTorchDispatchMode(fx_tracer,
+                                            tracing_mode,
+                                            pre_dispatch=pre_dispatch,
+                                            _allow_fake_constant=_allow_fake_constant,
+                                            _error_on_data_dependent_ops=_error_on_data_dependent_ops)
+
+        arg_count = 0
+
+        def wrap_fake(x):
+            nonlocal arg_count
+            # TODO: it would be nice to line these up with the names
+            # FX will choose for the placeholders, but we don't
+            # actually know what the names will be at this point yet
+            # NB: the Source here is actually meaningless
+            from torch._dynamo.source import ConstantSource
+            source = ConstantSource(f"input{arg_count}")
+            if isinstance(x, torch.Tensor):
+                arg_count += 1
+                return fake_tensor_mode.from_tensor(x, source=source)  # type: ignore[attr-defined]
+            # NB: don't match on bools
+            elif type(x) is int and tracing_mode == "symbolic":
+                return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source)
+
+            return x
+
+        sym_mode = proxy_mode.sym_mode
+
+        wrap_fn_map = {
+            "real": lambda x: x,
+            "fake": wrap_fake,
+            "symbolic": wrap_fake,
+        }
+        args = pytree.tree_map(wrap_fn_map[tracing_mode], args)
+
+        if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS:
+            # FX doesn't support varargs, so we gotta fake up a wrapper
+            # TODO: Would be nice to fix this at the source...
+            func = fake_signature(f, len(phs))
+        else:
+            func = f
+
+        # We disable the autocast cache as the autocast cache causes type conversions on parameters to
+        # check a cache, which introduces untracked tensors into the graph
+        #
+        # We also disable tracing by any other tensor proxy-based tracers except the current. The
+        # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
+        # thus irrelevant to any external functional trace.
+        with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, pre_dispatch_mode, proxy_function_mode, \
+             sym_mode, proxy_mode, disable_autocast_cache():
+            t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
+
+        # TODO: kind of a bad way to do it, should maybe figure out a better way
+        if tracing_mode == "symbolic":
+            t.shape_env = shape_env  # type: ignore[assignment]
+        return t
+
+    return wrapped
+
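+# Illustrative sketch (editor's note, not part of the upstream file): make_fx
+# returns a wrapper that, when called with example inputs, yields a GraphModule
+# of ATen operations.  With tracing_mode="symbolic", shapes are traced through
+# a ShapeEnv rather than baked in as constants:
+#
+#     def f(x):
+#         return torch.sin(x) * 2
+#
+#     gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4))
+#     gm.print_readable()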
+
+def get_torch_dispatch_modes():
+    return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+
+
+def get_innermost_proxy_mode():
+    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+
+
+@contextlib.contextmanager
+def disable_proxy_modes_tracing():
+    return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
+
+
+def maybe_handle_decomp(proxy_mode, op, args, kwargs):
+    if op in CURRENT_DECOMPOSITION_TABLE:
+        with proxy_mode:
+            return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
+    return NotImplemented
+
+
+def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
+    """A helper function used to get the GraphModule for the given func.
+
+    It's expected to be used in the ProxyTensor tracing context.
+    It detaches the args and kwargs from the current tracer so that the trace of
+    the current graph module can be created without any side-effects.
+    """
+    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
+
+    with disable_proxy_modes_tracing():
+        gm = make_fx(wrapped, tracing_mode=tracing_mode)(all_args)
+    return gm
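+
+# Illustrative sketch (editor's note, not part of the upstream file): e.g.
+#
+#     gm = get_isolated_graphmodule(torch.add, (torch.randn(2), torch.randn(2)), {})
+#
+# returns a GraphModule for torch.add traced independently of any enclosing
+# proxy-tensor trace.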
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/recording.py b/MLPY/Lib/site-packages/torch/fx/experimental/recording.py
new file mode 100644
index 0000000000000000000000000000000000000000..755394fffbbb41d636265227c95161b94d823cc4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/recording.py
@@ -0,0 +1,458 @@
+import functools
+import itertools
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils._pytree as pytree
+
+
+__all__ = [
+    "ShapeEnvEvent",
+    "record_shapeenv_event",
+    "replay_shape_env_events",
+    "FakeTensorMeta",
+    "shape_env_check_state_equal",
+    "NotEqualError",
+]
+
+# [Note: Recording ShapeEnv Events]
+# =================================
+#
+# What is a ShapeEnv event?
+# -------------------------
+# We consider a ShapeEnv event every function call (ShapeEnv method or
+# independent function) that modifies the state of the ShapeEnv instance.
+# Such calls are recorded alongside their positional and keyword arguments,
+# so that it may be replayed over a different ShapeEnv instance.
+#
+# See [Note: ShapeEnv State Equality] for what is considered the state
+# of a ShapeEnv instance.
+#
+# What is it for?
+# ---------------
+# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
+# arbitrary state in time.
+#
+# Being able to arbitrarily replay events like this is useful, mainly for
+# translation validation bisection, i.e. if a ValidationException has been
+# raised, finding the earliest point in time where the translation validation
+# fails.
+#
+# Besides that, it also allows us to inspect the given instance and,
+# for example, check the guards that would actually be issued at that point.
+#
+# What kind of arguments can be stored in an event?
+# -------------------------------------------------
+# There's no specific rule for what cannot be used as an argument.
+# That said, pay special attention to the following cases:
+#
+#   1. Tensor inputs: there are some tests that check whether the inputs
+#      were garbage collected after execution. These will fail if there's
+#      an event that is holding a reference to those inputs.
+#
+#   2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
+#      will be automatically replaced by the new given ShapeEnv instance.
+#
+#   3. SymTypes arguments: they also hold references to ShapeEnv. So,
+#      whenever we see them, we create a new instance, replacing the
+#      ShapeEnv reference.
+#
+#   4. FX nodes: specifically, FX nodes from the FX graph for symbolic
+#      shapes. That argument must be replaced when replaying the event at
+#      ShapeEnvEvent.run, since it has to reference a node from the given
+#      instance, and not from the recorded instance.
+
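+# Illustrative sketch (editor's note, not part of the upstream file): with event
+# recording enabled, the mutations of a ShapeEnv can be replayed onto a fresh
+# instance.  Here `env` is assumed to be a ShapeEnv whose `events` list holds
+# the recorded ShapeEnvEvent objects (the first one being the constructor call):
+#
+#     from torch.fx.experimental.recording import replay_shape_env_events
+#
+#     new_env = replay_shape_env_events(env.events)
+#     # new_env now reproduces the recorded state of `env`.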
+
+# Event class for reconstructing ShapeEnv at an arbitrary point in time.
+#
+# Represents a method call that mutates ShapeEnv in a way that affects the
+# issued guards, when ShapeEnv.produce_guards is called.
+@dataclass
+class ShapeEnvEvent:
+    # ShapeEnv method.
+    f: Callable
+
+    # Arguments and keyword arguments called with.
+    args: Optional[List[Any]] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+    # List of tracked_fakes at the time the method was called.
+    tracked_fakes: Optional[List[Any]] = None
+
+    # Name of the captured event.
+    # Used for special handling of particular methods.
+    name: Optional[str] = None
+
+    # Replay itself, but using shape_env as self.
+    def run(self, shape_env=None) -> Any:
+        from torch.fx.experimental.symbolic_shapes import (
+            is_symbolic,
+            ShapeEnv,
+            SymTypes,
+        )
+
+        # Special handling for the constructor event.
+        if self.f is ShapeEnv:
+            assert shape_env is None and self.args is None and self.kwargs is not None
+            return ShapeEnv(**self.kwargs)
+
+        assert shape_env is not None
+        args = list(self.args or list())
+        kwargs = dict(self.kwargs or dict())
+
+        # Replace any argument of type ShapeEnv by the given one.
+        args, kwargs = pytree.tree_map_only(
+            ShapeEnv, lambda _: shape_env, (args, kwargs)
+        )
+
+        # Replace any argument of type SymTypes by a new instance,
+        # replacing its ShapeEnv reference.
+        args, kwargs = pytree.tree_map_only(
+            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
+            lambda a: type(a)(a.node.with_shape_env(shape_env)),
+            (args, kwargs),
+        )
+
+        # Converts FX nodes using the mapping argument.
+        def maybe_convert_node(x: Any) -> Any:
+            if not isinstance(x, torch.fx.Node):
+                # Don't do anything to x if it's not an FX node.
+                return x
+
+            # If, at some point, we created an FX node, it means that translation validation is on.
+            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
+            # we are tracking node names at shape_env.name_to_node.
+            assert hasattr(shape_env, "name_to_node")
+            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
+            assert x.name in name_to_node
+            return name_to_node[x.name]
+
+        # Replaces the value of a specific argument with the result of fn.
+        def replacearg(index: int, key: str, fn: Callable):
+            if index < len(args):
+                args[index] = fn(args[index])
+            if key in kwargs:
+                kwargs[key] = fn(kwargs[key])
+
+        if self.is_create_fx_call_function():
+            # ShapeEnv.create_fx_call_function:
+            # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
+            # They must be replaced, since a "call_function" FX node with this tuple as argument
+            # will be added to the FX graph of the new shape_env.
+            replacearg(
+                index=2,
+                key="args",
+                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
+            )
+        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
+            # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
+            # "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
+            # They must be replaced, since it will be part of a "call_function" FX node for
+            # torch._assert, which will be added to the FX graph of the new shape_env.
+            replacearg(index=3, key="fx_node", fn=maybe_convert_node)
+
+        # Actually call the method with the converted arguments.
+        return self.f(*args, **kwargs)
+
+    def __str__(self) -> str:
+        name = self.name if self.name is not None else self.f.__name__
+        return f"event: {name} ({self.args}, {self.kwargs})"
+
+    def is_create_fx_call_function(self) -> bool:
+        return self.name == "_create_fx_call_function"
+
+    def is_evaluate_expr(self) -> bool:
+        return self.name == "evaluate_expr"
+
+    def is_defer_runtime_assert(self) -> bool:
+        return self.name == "defer_runtime_assert"
+
+
+# Extracts a ShapeEnv instance from args and kwargs.
+# Specifically, it looks for:
+#   1. ShapeEnv arguments
+#   2. SymInt, SymFloat, or SymBool arguments
+# If we find more than one object of any of the above types, we
+# also check that the ShapeEnv instance is the same for all of them.
+def _extract_shape_env_and_assert_equal(args, kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
+
+    def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
+        if old is not None:
+            assert old is new, "call with different ShapeEnv"
+        return new
+
+    shape_env = None
+    for val in itertools.chain(args, kwargs.values()):
+        if isinstance(val, ShapeEnv):
+            shape_env = assert_equal(shape_env, val)
+        if isinstance(val, SymTypes) and is_symbolic(val):
+            shape_env = assert_equal(shape_env, val.node.shape_env)
+
+    return shape_env
+
+
+# Decorator for recording the given function as a replayable event.
+#
+# This decorator should be used on every function that mutates the state of
+# ShapeEnv in some way that affects the resulting issued guards (i.e. when
+# ShapeEnv.produce_guards is called).
+#
+# save_tracked_fakes: saves a snapshot of the TrackedFake list.
+# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
+#
+# When to save the list of TrackedFake?
+# =====================================
+# We should save the list of TrackedFake whenever the translation validation
+# bisection may actually stop and call the produce_guards method right after
+# the recorded function was replayed. In other words, since the bisection
+# bisects through torch._assert calls, we should save the list in all methods
+# that add a torch._assert call to the symbolic shapes FX graph.
+#
+# At the moment, there are 2 methods that save the list:
+#   - ShapeEnv.evaluate_expr
+#   - ShapeEnv.defer_runtime_assert
+def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
+    def decorator(fn: Callable) -> Callable:
+        assert callable(fn)
+        name = fn.__name__
+
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+            if isinstance(args[0], ShapeEnv) and args[0].is_recording:  # type: ignore[has-type]
+                # If ShapeEnv is already recording an event, call the wrapped
+                # function directly.
+                #
+                # NB: here, we skip the check of whether all ShapeEnv instances
+                # are equal, in favor of a faster dispatch.
+                return fn(*args, **kwargs)
+
+            # Retrieve an instance of ShapeEnv.
+            # Assumption: the collection of args and kwargs may not reference
+            # different ShapeEnv instances.
+            self = _extract_shape_env_and_assert_equal(args, kwargs)
+
+            # If we are calling this function without any ShapeEnv instance
+            # alive in its arguments, we don't record; we just call the original.
+            if self is None:
+                return fn(*args, **kwargs)
+
+            # Otherwise, start recording and call the function.
+            with self._recording():
+                # Take a snapshot of the current tracked_fakes.
+                tracked_fakes = (
+                    self._snapshot_tracked_fakes() if save_tracked_fakes else None
+                )
+                # Record the event for 'fn'.
+                event = ShapeEnvEvent(
+                    fn, list(args), kwargs, tracked_fakes, name=fn.__name__
+                )
+                self.events.append(event)
+                # Play the event on this ShapeEnv.
+                return event.run(self)
+
+        return wrapper
+
+    return decorator
+
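+# Illustrative sketch (editor's note, not part of the upstream file): inside
+# symbolic_shapes.py the decorator is applied to state-mutating ShapeEnv
+# methods; the method below is abbreviated and only the decorator placement
+# is the point:
+#
+#     class ShapeEnv:
+#         @record_shapeenv_event(save_tracked_fakes=True)
+#         def evaluate_expr(self, expr, fx_node=None):
+#             ...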
+
+# Replays the ShapeEnvEvents list.
+# It assumes the first event is the constructor call.
+#
+# FX nodes stored in the events are remapped onto the newly created ShapeEnv's graph inside ShapeEnvEvent.run.
+def replay_shape_env_events(events):
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    constructor_event = events[0]
+    assert constructor_event.f == ShapeEnv
+
+    # Constructs the new ShapeEnv.
+    shape_env = constructor_event.run()
+
+    for event in events[1:]:
+        try:
+            # Actually replays each event.
+            # FX nodes referenced by the event are re-resolved on every run, since
+            # the node list might change after each event is replayed.
+            event.run(shape_env)
+        except Exception as e:
+            raise RuntimeError(f"failed when running event: {event}") from e
+
+    return shape_env
+
+
+# FakeTensor metadata.
+# This is to be used in place of FakeTensor placeholders when calling
+# ShapeEnv.produce_guards.
+@dataclass
+class FakeTensorMeta:
+    tensor_size: Tuple[Union[int, torch.SymInt], ...]
+    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
+    tensor_storage_offset: Union[int, torch.SymInt]
+    is_nested: bool
+
+    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
+        return self.tensor_size
+
+    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
+        return self.tensor_stride
+
+    def storage_offset(self) -> Union[int, torch.SymInt]:
+        return self.tensor_storage_offset
+
+    def dim(self) -> int:
+        return len(self.tensor_size)
+
+    @staticmethod
+    def from_fake(fake) -> "FakeTensorMeta":
+        return FakeTensorMeta(
+            fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
+        )
+
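+# Illustrative sketch (editor's note, not part of the upstream file): given a
+# FakeTensor `fake`, a lightweight stand-in exposing the same size/stride
+# accessors can be built as
+#
+#     meta = FakeTensorMeta.from_fake(fake)
+#     meta.size(), meta.stride(), meta.storage_offset(), meta.dim()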
+
+# [Note: ShapeEnv State Equality]
+# ===============================
+#
+# What is considered ShapeEnv state?
+# ----------------------------------
+# We consider the state of a ShapeEnv instance to be everything that
+# is not in the inlined tuple inside the remove_nonstate_variables function.
+# That is: the fields within ShapeEnv that modify the flow of execution
+# of the program.
+#
+# So, for example: the replacements field might influence how an
+# expression is simplified. That, in turn, may result in a guard being
+# statically known (i.e. not added).
+#
+# On the other hand, var_to_stack only changes what is printed
+# on the screen, i.e. it is used only for debugging purposes. Therefore, we
+# should not consider it when comparing states.
+#
+# What to do on NotEqualError?
+# ----------------------------
+# Here are a few possible causes for getting a NotEqualError raised:
+#
+#   1. New field that does not belong in the ShapeEnv state.
+#      For example: log field of type ShapeEnvLoggerAdapter. Different
+#      ShapeEnv instances will always have different ShapeEnvLoggerAdapter
+#      instances, i.e. equality comparison would fail.
+#      Solution: add it to the inlined tuple inside remove_nonstate_variables
+#      function inside check_equal method.
+#
+#   2. New field that is not directly comparable across instances.
+#      For example: guards field of type List[ShapeGuard]. More specifically,
+#      the ShapeGuard type holds an expression and stack information
+#      for debugging purposes. When replaying the event on a new ShapeEnv
+#      instance, the stack would be different, which would trigger this error.
+#      Solution: add a special case to the map_value function inside
+#      the check_equal method.
+#
+#   3. Mutation of ShapeEnv in some unrecorded function.
+#      If a mutation of the state of ShapeEnv happens inside a function
+#      that is not recorded (and no caller in the stack is recorded),
+#      then the replayed ShapeEnv won't catch that.
+#      Solution: decorate the function with record_shapeenv_event.
+
+
+# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
+# returned by ShapeEnv.produce_guards.
+def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
+    # Collect and remove variables that don't necessarily represent the state
+    # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
+    # instance itself.
+    env1_vars = vars(env1).copy()
+    env2_vars = vars(env2).copy()
+
+    for v in non_state_variable_names:
+        if v in env1_vars:
+            env1_vars.pop(v)
+        if v in env2_vars:
+            env2_vars.pop(v)
+
+    # Function for transforming the mismatched values into string.
+    # Needed, since the order of dict and set entries might not be the same every time.
+    def value_to_str(value: Any) -> str:
+        if isinstance(value, dict):
+            return (
+                "{"
+                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
+                + "}"
+            )
+        if isinstance(value, set):
+            return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
+        return str(value)
+
+    # Compares env1_vars with env2_vars.
+    # Here, we allow the value of each field to be mapped, so that we appropriately
+    # compare the two values.
+    def compare_vars(
+        map_value: Callable[[str, Any], Any]
+    ) -> List[Tuple[str, str, str]]:
+        env1_set, env2_set = set(env1_vars), set(env2_vars)
+
+        # First, compare the set of keys in each vars dictionary.
+        if env1_set != env2_set:
+            raise NotEqualError(
+                "field set mismatch:",
+                [
+                    (
+                        "found unique fields:",
+                        str(sorted(env1_set - env2_set)),
+                        str(sorted(env2_set - env1_set)),
+                    ),
+                ],
+            )
+
+        # Then, sort the keys, and compare the mapped values of each key.
+        sorted_keys = list(env1_set)
+        sorted_keys.sort()
+
+        mapped_dict = [
+            (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
+            for k in sorted_keys
+        ]
+
+        # Return a list of tuples representing the fields that did not match
+        # alongside their respective mapped values.
+        return [
+            (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
+            for k, val1, val2 in mapped_dict
+            if val1 != val2
+        ]
+
+    # Accumulate the mismatching fields.
+    errors = compare_vars(map_value)
+
+    if len(errors) > 0:
+        raise NotEqualError("field values don't match:", errors)
+
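+# Illustrative sketch (editor's note, not part of the upstream file): the caller
+# supplies the names of non-state fields to ignore and a per-field mapping
+# function, for example
+#
+#     shape_env_check_state_equal(
+#         env1, env2,
+#         non_state_variable_names=("log", "var_to_stack"),
+#         map_value=lambda field, value: value,
+#     )
+#
+# which raises NotEqualError if any field values do not match.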
+
+class NotEqualError(Exception):
+    def __init__(
+        self,
+        msg: str,
+        mismatched: List[Tuple[str, str, str]],
+    ) -> None:
+        details = "\n".join(
+            [
+                "\n".join(
+                    [
+                        f"==> {inner_msg}",
+                        f"  >  Left: {str1}",
+                        f"  > Right: {str2}",
+                    ]
+                )
+                for inner_msg, str1, str2 in mismatched
+            ]
+        )
+
+        super().__init__(
+            f"""\
+ShapeEnv not equal: {msg}
+
+{details}
+"""
+        )
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/refinement_types.py b/MLPY/Lib/site-packages/torch/fx/experimental/refinement_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea1b75a7221a08b4abc8e0b2d421dc92e44867c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/refinement_types.py
@@ -0,0 +1,16 @@
+class Equality:
+    def __init__(self, lhs, rhs):
+        self.lhs = lhs
+        self.rhs = rhs
+
+    def __str__(self):
+        return f'{self.lhs} = {self.rhs}'
+
+    def __repr__(self):
+        return f'{self.lhs} = {self.rhs}'
+
+    def __eq__(self, other):
+        if isinstance(other, Equality):
+            return self.lhs == other.lhs and self.rhs == other.rhs
+        else:
+            return False
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/rewriter.py b/MLPY/Lib/site-packages/torch/fx/experimental/rewriter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7472c9cde89d2eb873bcc2bab4c18a9be0fee216
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/rewriter.py
@@ -0,0 +1,121 @@
+import ast
+import inspect
+import textwrap
+import copy
+import functools
+from types import FunctionType
+from typing import cast, Union, Callable, Dict, Optional, Any
+from torch.fx._symbolic_trace import Tracer
+from torch.fx.graph import Graph
+from torch._sources import normalize_source_lines
+import torch
+
+class AST_Rewriter(ast.NodeTransformer):
+    """
+    Take a FunctionType object representing a `forward` method, then
+    perform an AST rewrite to swap out nodes that are not symbolically
+    traceable with a callsite to the FX alternative.
+
+    To support swapping out an AST node, define a new `visit` method on
+    that node. For more details, see:
+    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
+    """
+
+    def rewrite(self, fn: FunctionType):
+
+        # Normalize the source lines
+        sourcelines, _ = inspect.getsourcelines(fn)
+        sourcelines = normalize_source_lines(sourcelines)
+        source = ''.join(sourcelines)
+        normalized_str = textwrap.dedent(source)
+
+        # Rewrite the original AST
+        source_ast = ast.parse(normalized_str)
+        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
+
+        # Pull out the compiled function from the newly-created Module
+        code = compile(dest_ast, "", "exec")
+        globals_dict = copy.copy(fn.__globals__)
+        keys_before = set(globals_dict.keys())
+        exec(code, globals_dict)
+        new_keys = list(set(globals_dict.keys()) - keys_before)
+        assert len(new_keys) == 1
+        fn_compiled = globals_dict[new_keys[0]]
+
+        # return the compiled function with the original globals
+        def change_func_globals(f, globals):
+            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
+            # __globals__ is a private member of the function class
+            # so we have to copy the function f and all of its members, except f.__globals__
+            g = FunctionType(
+                f.__code__,
+                globals,
+                name=f.__name__,
+                argdefs=f.__defaults__,
+                closure=f.__closure__,
+            )
+            g = functools.update_wrapper(g, f)
+            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
+            return g
+        # Return the correct FunctionType object
+        return change_func_globals(fn_compiled, globals=fn.__globals__)
+
+    def visit_Assert(self, node):
+        """
+        Swap out the Assert node (Python's `assert`) with a callsite to the
+        symbolically-traceable torch._assert function
+        """
+        # Create the Call node
+        n = ast.parse('torch._assert()', mode='eval')
+        assert isinstance(n, ast.Expression)
+        call_node = n.body
+        assert isinstance(call_node, ast.Call)
+        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
+        call_node.args = [node.test, msg]
+
+        # Ensure that the new node conforms to the Python AST grammar
+        expr_wrapper = ast.Expr(value=call_node)
+
+        # Return the new Call node to signify that we want to use it as
+        # a replacement for the original _assert node
+        return ast.copy_location(expr_wrapper, node)
+
+    def visit_AnnAssign(self, node):
+        """
+        Swap out Python's AnnAssign with an Assign node where the annotation function is called.
+        Example:
+            Original:
+                y: Tensor_Type(1,2,3, Dyn) = f2(x)
+            Output:
+                y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
+        """
+        return ast.Assign(targets=[node.target], value=ast.Call(
+            func=ast.Name(id='annotate', ctx=ast.Load()),
+            args=[node.value, node.annotation], keywords=[]))
+
+
+class RewritingTracer(Tracer):
+    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
+        return super().trace(_rewrite(root), concrete_args)
+
+
+def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
+    if isinstance(fn, torch.nn.Module):
+        # Rewrite this module's `forward` as well as the `forward`s of
+        # all of this module's recursive descendants. Return the new,
+        # rewritten module hierarchy.
+        def rewrite_module(m : torch.nn.Module):
+            class RewrittenModule(torch.nn.Module):
+                def __init__(self, orig):
+                    super().__init__()
+                    for k, v in orig.__dict__.items():
+                        if isinstance(v, torch.nn.Module):
+                            self.__dict__[k] = copy.copy(rewrite_module(v))
+                        else:
+                            self.__dict__[k] = copy.copy(v)
+            RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
+            return RewrittenModule(m)
+        return rewrite_module(fn)
+    else:
+        # Rewrite this single free function
+        return AST_Rewriter().rewrite(cast(FunctionType, fn))
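+
+
+# Illustrative usage (a sketch; `MyModule` is a placeholder, not a real class):
+#
+#     graph = RewritingTracer().trace(MyModule())
+#     gm = torch.fx.GraphModule(MyModule(), graph)
+#
+# This traces `MyModule.forward` with plain `assert` statements rewritten into
+# `torch._assert` calls so that they are recorded in the FX graph.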
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py b/MLPY/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdb979715acb9de42b586c926080ddc812df403d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py
@@ -0,0 +1,111 @@
+import torch
+import torch.fx
+import inspect
+from typing import Any, Dict, Optional, Tuple
+from torch.fx.node import Argument, Target
+from torch._jit_internal import boolean_dispatched
+from torch.fx.operator_schemas import _torchscript_type_to_python_type
+
+from torch.fx import Transformer
+
+class AnnotateTypesWithSchema(Transformer):
+    """
+    Use Python function signatures to annotate types for `Nodes` within an FX graph.
+    This pulls out Python function signatures for:
+
+        1. Standard `torch.nn` Module calls
+        2. `torch.nn.functional` calls
+        3. Attribute fetches via `get_attr`
+
+    Example usage:
+
+        m = torchvision.models.resnet18()
+
+        traced = torch.fx.symbolic_trace(m)
+
+        traced = AnnotateTypesWithSchema(traced).transform()
+
+    """
+    def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
+                 annotate_modules : bool = True, annotate_get_attrs : bool = True):
+        super().__init__(module)
+        self.annotate_functionals = annotate_functionals
+        self.annotate_modules = annotate_modules
+        self.annotate_get_attrs = annotate_get_attrs
+
+    def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
+        python_ret_type = None
+        if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
+            target_for_analysis = target
+            if target in boolean_dispatched:
+                # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
+                # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
+                # branches of the dispatch have exactly the same signature. If they do, use the `true`
+                # branch signature for analysis. Otherwise, leave this un-normalized
+                assert not isinstance(target, str)
+                dispatched = boolean_dispatched[target]
+                if_true, if_false = dispatched['if_true'], dispatched['if_false']
+                # TODO: can we emit the union of these? What are the implications on TorchScript
+                # compilation?
+                if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
+                    return super().call_function(target, args, kwargs)
+                target_for_analysis = if_true
+
+            python_ret_type = self._extract_python_return_type(target_for_analysis)
+
+        return_proxy = super().call_function(target, args, kwargs)
+        return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
+        return return_proxy
+
+    def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
+        python_ret_type = None
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+        if self.annotate_modules and hasattr(submod.__class__, '__name__'):
+            classname = submod.__class__.__name__
+            if getattr(torch.nn, classname, None) == submod.__class__:
+                python_ret_type = self._extract_python_return_type(submod.forward)
+        return_proxy = super().call_module(target, args, kwargs)
+        return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
+        return return_proxy
+
+    def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
+        attr_proxy = super().get_attr(target, args, kwargs)
+
+        if self.annotate_get_attrs:
+            module_itr = self.module
+            assert isinstance(target, str)
+            atoms = target.split('.')
+            for i, atom in enumerate(atoms):
+                if not hasattr(module_itr, atom):
+                    raise RuntimeError(f'Node referenced nonexistent target {".".join(atoms[:i])}!')
+                module_itr = getattr(module_itr, atom)
+
+            maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
+            if maybe_inferred_ts_type.success():
+                python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
+                attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
+
+        return attr_proxy
+
+    def _extract_python_return_type(self, target : Target) -> Optional[Any]:
+        """
+        Given a Python call target, try to extract the Python return annotation
+        if it is available, otherwise return None
+
+        Args:
+
+            target (Callable): Python callable to get return annotation for
+
+        Returns:
+
+            Optional[Any]: Return annotation from the `target`, or None if it was
+                not available.
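+
+        Example (illustrative sketch; ``f`` and ``g`` are placeholders)::
+
+            def f(x) -> torch.Tensor: ...
+            def g(x): ...
+
+            self._extract_python_return_type(f)  # torch.Tensor
+            self._extract_python_return_type(g)  # None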
+        """
+        assert callable(target)
+        try:
+            sig = inspect.signature(target)
+        except (ValueError, TypeError):
+            return None
+
+        return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/sym_node.py b/MLPY/Lib/site-packages/torch/fx/experimental/sym_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..e558dd22dce50d3bf11b60a589fe18dab3ca3bc6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/sym_node.py
@@ -0,0 +1,1330 @@
+"""
+This file does three things:
+- Contains the definition of SymNode
+- Installs all the magic methods into SymBool, SymFloat, SymInt at import time
+- Does not depend on sympy at import time
+
+As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
+to avoid having to load SymPy at import time, as doing so is *very* slow.
+"""
+
+import builtins
+import itertools
+import logging
+import math
+import operator
+import sys
+from functools import lru_cache, update_wrapper
+from typing import Optional, Type, TYPE_CHECKING, Union
+
+import torch
+
+# NB: The sym_* functions are used via getattr() and must be imported here.
+from torch import (  # noqa: F401
+    sym_float,
+    sym_ite,
+    sym_max,
+    sym_min,
+    sym_not,
+    SymBool,
+    SymFloat,
+    SymInt,
+)
+
+from torch.fx.experimental._sym_dispatch_mode import (
+    handle_sym_dispatch,
+    sym_function_mode,
+)
+
+if TYPE_CHECKING:
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+log = logging.getLogger(__name__)
+sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
+
+
+__all__ = ["SymNode", "method_to_operator", "magic_methods"]
+
+
+SymTypes = (SymInt, SymFloat, SymBool)
+
+
+def _to_symtype(t):
+    if t is bool:
+        return SymBool
+    if t is int:
+        return SymInt
+    if t is float:
+        return SymFloat
+    return t
+
+
+# TODO: An incomplete list
+# 1. Set variables to be equal when we do equality
+# 2. Specialize on 0/1 when we do subtraction
+class SymNode:
+    """
+    This is a type erased SymInt/SymFloat which we use to do actual operations.
+    End users don't touch this.  Magic methods are NOT defined on this object.
+    """
+
+    def __init__(
+        self,
+        expr,
+        shape_env,
+        pytype,
+        hint: Optional[Union[int, float, bool]],
+        constant=None,
+        fx_node=None,
+    ):
+        self._expr = expr
+        self.shape_env = shape_env
+        self.pytype = pytype
+        # What's the difference between hint and constant?
+        #
+        # - A constant is known to be invariant across invocations of the model;
+        #   it will always be this value.  We only really know this when we
+        #   encounter an honest-to-goodness literal (when wrapping it into
+        #   a SymNode, we set constant.)  Most of the time, constant is None
+        #
+        # - A hint is a *particular* value from the particular run we are
+        #   tracing, but it may vary the next time around.  It's useful to
+        #   keep this around, as if we need a concrete value from a SymNode,
+        #   we will return the hint and guard on the expression that produced
+        #   it giving the same hint next time around.  The hint is not
+        #   guaranteed to be set either: if you have an unbacked SymNode,
+        #   there won't be any hint; it was the result of some tensor-dependent
+        #   computation, but we don't know what it actually is because we
+        #   haven't actually run the tensor computation.
+        #
+        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
+        # in hopes that we've learned enough about the unbacked symints to
+        # discharge the hint; otherwise, you're likely to just error out.
+        #
+        # (A previous version of this system had some optimizations to only
+        # recompute when it was possible we had learned enough about the
+        # unbacked symint that a hint was now possible, but as we added more
+        # potential refinements to unbacked symints this got harder to keep
+        # in sync, so we've deleted it for now.)
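+        #
+        # Illustrative example (values are made up): tracing `x.size(0)` for a
+        # dynamic dimension on a tensor that happened to have 8 rows gives a
+        # SymNode with hint == 8 and constant None, whereas wrapping the
+        # literal 2 via `wrap_int(2)` gives a SymNode whose hint and constant
+        # are both 2.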
+        if hint is not None:
+            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
+                "Cannot create SymNode of type "
+                f"{pytype} with incompatible hint of type {type(hint)}"
+            )
+        self._hint = hint
+        self.constant: Optional[Union[int, float, bool]] = constant
+
+        # Record the FX node of the current node if we are doing translation
+        # validation. They will be used for building the input assertions for
+        # the translation validation problem.
+        self.fx_node = (
+            fx_node if self.shape_env._translation_validation_enabled else None
+        )
+
+    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
+        return SymNode(
+            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
+        )
+
+    @property
+    def expr(self):
+        return self.shape_env.replace(self._expr)
+
+    # Recompute the hint and see if we've got it now
+    # Precondition: self._hint is None
+    def _update_hint(self):
+        r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
+        if r is not None:
+            self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r
+
+    @property
+    def hint(self):
+        if self._hint is None:
+            self._update_hint()
+        return self._hint
+
+    def has_hint(self):
+        if self._hint is None:
+            self._update_hint()
+        return self._hint is not None
+
+    def require_hint(self, fallback=None):
+        if self._hint is None:
+            self._update_hint()
+        if self._hint is None:
+            if fallback is not None:
+                return fallback
+            # NB: we expect this to raise
+            return self.shape_env.size_hint(self.expr)
+        return self._hint
+
+    def maybe_as_int(self):
+        if self.expr.is_number:
+            return int(self.expr)
+        else:
+            return None
+
+    def is_int(self):
+        return self.pytype is int
+
+    def is_float(self):
+        return self.pytype is float
+
+    def is_bool(self):
+        return self.pytype is bool
+
+    def is_nested_int(self):
+        # Unbacked SymInts cannot be nested int today
+        return (
+            self._hint is not None
+            and isinstance(self._hint, SymInt)
+            and self._hint.node.is_nested_int()
+        )
+
+    def wrap_int(self, num):
+        assert type(num) is int
+        import sympy
+
+        return SymNode(
+            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
+        )
+
+    def wrap_float(self, num):
+        assert type(num) is float
+        import sympy
+
+        return SymNode(
+            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
+        )
+
+    def wrap_bool(self, num):
+        assert type(num) is bool
+        import sympy
+
+        return SymNode(
+            sympy.true if num else sympy.false,
+            self.shape_env,
+            bool,
+            num,
+            constant=num,
+            fx_node=num,
+        )
+
+    def clone(self):
+        return self
+
+    def str(self):
+        return f"{self.expr}"
+
+    def __str__(self):
+        return self.str()
+
+    def __repr__(self):
+        return self.str()
+
+    # These methods call the metaprogrammed methods, they're hand written
+    # here so we get good stack traces
+    def abs(self) -> "SymNode":
+        return self._abs()  # type: ignore[attr-defined]
+
+    def pos(self) -> "SymNode":
+        return self._pos()  # type: ignore[attr-defined]
+
+    def round(self, ndigits=None) -> "SymNode":
+        return self._round(ndigits)  # type: ignore[attr-defined]
+
+    def add(self, other) -> "SymNode":
+        return self._add(other)  # type: ignore[attr-defined]
+
+    def sub(self, other) -> "SymNode":
+        return self._sub(other)  # type: ignore[attr-defined]
+
+    def mul(self, other) -> "SymNode":
+        return self._mul(other)  # type: ignore[attr-defined]
+
+    def mod(self, other) -> "SymNode":
+        return self._mod(other)  # type: ignore[attr-defined]
+
+    def pow(self, other) -> "SymNode":
+        return self._pow(other)  # type: ignore[attr-defined]
+
+    def and_(self, other) -> "SymNode":
+        return self._and_(other)  # type: ignore[attr-defined]
+
+    def or_(self, other) -> "SymNode":
+        return self._or_(other)  # type: ignore[attr-defined]
+
+    def truediv(self, other) -> "SymNode":
+        return self._truediv(other)  # type: ignore[attr-defined]
+
+    def floordiv(self, other) -> "SymNode":
+        return self._floordiv(other)  # type: ignore[attr-defined]
+
+    def lshift(self, other) -> "SymNode":
+        return self._lshift(other)  # type: ignore[attr-defined]
+
+    def rshift(self, other) -> "SymNode":
+        return self._rshift(other)  # type: ignore[attr-defined]
+
+    def sym_not(self) -> "SymNode":  # noqa: F811
+        return self._sym_not()  # type: ignore[attr-defined]
+
+    def eq(self, other) -> "SymNode":
+        return self._eq(other)  # type: ignore[attr-defined]
+
+    def ne(self, other) -> "SymNode":
+        return self._ne(other)  # type: ignore[attr-defined]
+
+    def gt(self, other) -> "SymNode":
+        return self._gt(other)  # type: ignore[attr-defined]
+
+    def lt(self, other) -> "SymNode":
+        return self._lt(other)  # type: ignore[attr-defined]
+
+    def le(self, other) -> "SymNode":
+        return self._le(other)  # type: ignore[attr-defined]
+
+    def ge(self, other) -> "SymNode":
+        return self._ge(other)  # type: ignore[attr-defined]
+
+    def floor(self) -> "SymNode":
+        return self._floor()  # type: ignore[attr-defined]
+
+    def is_integer(self) -> "SymNode":
+        return self._is_integer()  # type: ignore[attr-defined]
+
+    def sym_float(self) -> "SymNode":  # noqa: F811
+        return self._sym_float()  # type: ignore[attr-defined]
+
+    def sym_int(self) -> "SymNode":
+        return self._sym_int()  # type: ignore[attr-defined]
+
+    def ceil(self) -> "SymNode":
+        return self._ceil()  # type: ignore[attr-defined]
+
+    def neg(self) -> "SymNode":
+        return self._neg()  # type: ignore[attr-defined]
+
+    def sym_min(self, other) -> "SymNode":  # noqa: F811
+        return self._sym_min(other)  # type: ignore[attr-defined]
+
+    def sym_max(self, other) -> "SymNode":  # noqa: F811
+        return self._sym_max(other)  # type: ignore[attr-defined]
+
+    def sym_ite(self, then_val, else_val) -> "SymNode":
+        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]
+
+    def is_contiguous(self, sizes, strides) -> "SymNode":
+        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
+        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
+        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
+        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
+        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]
+
+    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
+        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]
+
+    # Make C++ happy
+    def sym_or(self, other):
+        return self.or_(other)
+
+    def sym_and(self, other):
+        return self.and_(other)
+
+    def is_non_overlapping_and_dense(self, sizes, strides):
+        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))  # type: ignore[attr-defined]
+
+    def int_(self):
+        return self.guard_int("", 0)  # NB: uses Python backtrace
+
+    # You can manually trigger a guard with this function
+    def guard_int(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
+        try:
+            return int(r)
+        except Exception:
+            log.warning("Failed to convert to int: %s", r)
+            raise
+
+    def guard_float(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.shape_env.evaluate_expr(
+            self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
+        )
+        try:
+            return float(r)
+        except Exception:
+            log.warning("Failed to convert to float: %s", r)
+            raise
+
+    def guard_bool(self, file, line):
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
+        try:
+            return bool(r)
+        except Exception:
+            log.warning("Failed to convert to bool: %s", r)
+            raise
+
+    def expect_true(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+
+        if self.has_hint() and not free_unbacked_symbols(self.expr):
+            # OK to generate guards
+            return self.guard_bool(file, line)
+        # Generate a deferred runtime assert (this might actually end up doing
+        # a regular guard if we can!)
+        # TODO: file/line here is very important, because the assert has been
+        # deferred so you can't backtrace easily
+        return self.shape_env.defer_runtime_assert(
+            self.expr, f"{file}:{line}", fx_node=self.fx_node
+        )
+
+    def expect_size(self, file, line):
+        from torch.fx.experimental.symbolic_shapes import _advise_is_size
+
+        b = self.ge(self.wrap_int(0))
+        # Generate a deferred runtime assert
+        r = b.expect_true(file, line)
+        # Refine compile time range, but only if it's unbacked.
+        # If you refine range for hinted variables, you can end up making
+        # improper deductions since compile time reasoning may be
+        # incompatible with runtime reasoning.
+        if r and not self.has_hint():
+            _advise_is_size(SymInt(self))
+        return r
+
+    def guard_size_oblivious(self, file, line):
+        """
+        Like guard_bool, but if we encounter unbacked symbols, if those symbols
+        are size-like, we will treat them as >= 2 for the purposes of the analysis.
+
+        This CHANGES the runtime semantics, but all size-oblivious sites have been
+        audited to ensure that the runtime semantics don't change in a material way.
+        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
+        an unbacked size-one dimension, or a tensor reporting as non-contiguous when
+        it would only have been reported contiguous by virtue of being empty.
+        """
+        # TODO: use the file/line for some useful diagnostic on why a
+        # guard occurred
+        r = self.shape_env.evaluate_expr(
+            self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
+        )
+        try:
+            return bool(r)
+        except Exception:
+            log.warning("Failed to convert to bool: %s", r)
+            raise
+
+    def bool_(self):
+        return self.guard_bool("", 0)
+
+    def is_symbolic(self):
+        return True
+
+    def nested_int(self):
+        return None
+
+    def is_constant(self):
+        return False
+
+
+# TODO: this probably needs the sizes-strides eval functions
+METHOD_TO_OPERATOR = {
+    "pos": operator.pos,
+    "abs": operator.abs,
+    "add": operator.add,
+    "and": operator.and_,
+    "ceil": math.ceil,
+    "eq": operator.eq,
+    "floor": math.floor,
+    "floordiv": operator.floordiv,
+    "ge": operator.ge,
+    "gt": operator.gt,
+    "is_integer": lambda x: x.is_integer(),
+    "le": operator.le,
+    "lshift": operator.lshift,
+    "lt": operator.lt,
+    "mod": operator.mod,
+    "mul": operator.mul,
+    "ne": operator.ne,
+    "neg": operator.neg,
+    "or": operator.or_,
+    "pow": operator.pow,
+    "round": builtins.round,
+    "rshift": operator.rshift,
+    "sub": operator.sub,
+    "sym_float": sym_float,
+    "sym_ite": sym_ite,
+    "sym_max": sym_max,
+    "sym_min": sym_min,
+    "sym_not": sym_not,
+    "truediv": operator.truediv,
+}
+
+unary_magic_methods = {
+    "abs",
+    "sym_float",
+    "ceil",
+    "floor",
+    "neg",
+    "sym_not",
+    "pos",
+}
+
+
+# Adding math ops: sqrt, cos, sin, ...
+def _get_sym_node_fn(name):
+    def fn(self):
+        return getattr(self, f"_sym_{name}")()
+
+    return fn
+
+
+math_op_names = (
+    "sqrt",
+    "cos",
+    "cosh",
+    "sin",
+    "sinh",
+    "tan",
+    "tanh",
+    "asin",
+    "acos",
+    "atan",
+)
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    priv_sym_name = f"_{sym_name}"
+    setattr(SymNode, sym_name, _get_sym_node_fn(name))
+    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
+    unary_magic_methods.add(sym_name)
+    __all__.append(sym_name)
+
+
+# Unary methods that are not magic methods
+unary_nonmagic_methods = {
+    "is_integer",
+}
+
+unary_methods = unary_magic_methods | unary_nonmagic_methods
+
+# Most methods are only registered on SymInt and SymFloat
+# Some methods are only registered on SymBool
+only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
+# Methods that implicitly convert SymBool into SymInt
+bool_becomes_int_magic_methods = {"add", "sub", "mul"}
+# Methods that are also on SymBool, in addition to on SymInt and SymFloat
+also_bool_magic_methods = {"eq"}
+bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
+
+# Methods that are only for float
+only_float_magic_methods = {"is_integer"}
+
+
+magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
+
+
+always_float_magic_methods = {"truediv", "sym_float", "pow"}
+
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    always_float_magic_methods.add(sym_name)
+
+
+always_int_magic_methods = {"ceil", "floor"}
+always_bool_magic_methods = {
+    "eq",
+    "ne",
+    "gt",
+    "lt",
+    "le",
+    "ge",
+    "and",
+    "or",
+    "sym_not",
+    "is_non_overlapping_and_dense",
+    "is_integer",
+}
+
+# Methods that have a `__foo__` as well as `__rfoo__`
+
+
+def _sympy_truediv(a, b):
+    from torch.utils._sympy.functions import TrueDiv
+
+    return TrueDiv(a, b)
+
+
+def _sympy_floordiv(a, b):
+    from torch.utils._sympy.functions import FloorDiv
+
+    return FloorDiv(a, b)
+
+
+def _sympy_mod(a, b):
+    from torch.utils._sympy.functions import Mod
+
+    return Mod(a, b)
+
+
+def _sympy_pow(a, b):
+    from torch.utils._sympy.functions import Pow
+
+    return Pow(a, b)
+
+
+def _sympy_and(a, b):
+    import sympy
+
+    return sympy.And(a, b)
+
+
+def _sympy_or(a, b):
+    import sympy
+
+    return sympy.Or(a, b)
+
+
+def _sympy_lshift(a, b):
+    from torch.utils._sympy.functions import LShift
+
+    return LShift(a, b)
+
+
+def _sympy_rshift(a, b):
+    from torch.utils._sympy.functions import RShift
+
+    return RShift(a, b)
+
+
+reflectable_magic_methods = {
+    "add": operator.add,
+    "sub": operator.sub,
+    "mul": operator.mul,
+    "mod": _sympy_mod,
+    "pow": _sympy_pow,
+    "and": _sympy_and,
+    "or": _sympy_or,
+    "truediv": _sympy_truediv,
+    "floordiv": _sympy_floordiv,
+    "lshift": _sympy_lshift,
+    "rshift": _sympy_rshift,
+}
+
+
+def _floor_ceil_helper(a, fn):
+    import sympy
+
+    if isinstance(a, sympy.Mul):
+        aa = a.args
+        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
+            coef = sympy.Integer(aa[0])
+            if aa[0] == coef:  # structural equality test
+                return coef * aa[1]
+    if (
+        isinstance(a, sympy.Float)
+        and a == sympy.Integer(a)
+        or isinstance(a, sympy.Integer)
+    ):
+        return sympy.Integer(a)
+    return fn(a)
+
+
+def _sympy_floor(a):
+    import sympy
+
+    return _floor_ceil_helper(a, sympy.floor)
+
+
+def _sympy_ceil(a):
+    import sympy
+
+    return _floor_ceil_helper(a, sympy.ceiling)
+
+
+def _sympy_eq(a, b):
+    import sympy
+
+    return sympy.Eq(a, b)
+
+
+def _sympy_ne(a, b):
+    import sympy
+
+    return sympy.Ne(a, b)
+
+
+def _sympy_gt(a, b):
+    import sympy
+
+    return sympy.Gt(a, b)
+
+
+def _sympy_lt(a, b):
+    import sympy
+
+    return sympy.Lt(a, b)
+
+
+def _sympy_le(a, b):
+    import sympy
+
+    return sympy.Le(a, b)
+
+
+def _sympy_ge(a, b):
+    import sympy
+
+    return sympy.Ge(a, b)
+
+
+def _sympy_min(a, b):
+    import sympy
+
+    return sympy.Min(a, b)
+
+
+def _sympy_max(a, b):
+    import sympy
+
+    return sympy.Max(a, b)
+
+
+def _sympy_ite(a, t, f):
+    import sympy
+
+    return sympy.Piecewise((t, a), (f, True))
+
+
+current_module = sys.modules[__name__]
+
+
+def _get_sym_math_fn(name):
+    def fn(a):
+        import sympy
+
+        return getattr(sympy, name)(a)
+
+    return fn
+
+
+for name in math_op_names:
+    priv_sympy_name = f"_sympy_{name}"
+    fn = _get_sym_math_fn(name)
+    fn.__qualname__ = fn.__name__ = priv_sympy_name
+    setattr(current_module, priv_sympy_name, fn)
+
+del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
+
+
+def _sympy_abs(a):
+    import sympy
+
+    return sympy.Abs(a)
+
+
+def _sympy_round(number, ndigits=None):
+    from torch.utils._sympy.functions import Round, RoundDecimal
+
+    if ndigits is None:
+        return Round(number)
+    else:
+        return RoundDecimal(number, ndigits)
+
+
+def _sympy_sym_float(a):
+    # Cannot use sympy.Float(a) here, because it expects python literals.
+    # Multiply by 1.0 to cast to float. This is needed when the input
+    # is a SymInt, which carries the assumption that it is an integer, and
+    # SymPy would otherwise assume that the return value cannot be a float.
+    return a * 1.0
+
+
+def _sympy_is_integer(a):
+    import sympy
+
+    return sympy.Eq(sympy.floor(a), a)
+
+
+magic_methods = {
+    **reflectable_magic_methods,
+    "sym_not": operator.invert,
+    "pos": operator.pos,
+    "eq": _sympy_eq,
+    "ne": _sympy_ne,
+    "gt": _sympy_gt,
+    "lt": _sympy_lt,
+    "le": _sympy_le,
+    "ge": _sympy_ge,
+    "floor": _sympy_floor,
+    "sym_float": _sympy_sym_float,
+    "ceil": _sympy_ceil,
+    "neg": operator.neg,
+    "sym_min": _sympy_min,
+    "sym_max": _sympy_max,
+    "sym_ite": _sympy_ite,
+    "abs": _sympy_abs,
+    "round": _sympy_round,
+    "is_integer": _sympy_is_integer,
+}
+
+
+for name in math_op_names:
+    sym_name = f"sym_{name}"
+    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
+
+del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
+
+
+def sympy_is_contiguous(sizes, strides):
+    dim = len(sizes)
+    return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
+
+
+def sympy_is_contiguous_generic(sizes, strides, dim_order):
+    import sympy
+
+    dim = len(sizes)
+
+    if len(dim_order) != dim:
+        return sympy.false
+
+    is_contiguous = sympy.true
+    z = sympy.Integer(1)
+    # Contiguous if the strides make sense (or the dim is size 1)
+    for d in dim_order:
+        is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z)
+        z *= sizes[d]
+    # OR if any size is zero
+    for d in range(dim):
+        is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0))
+    return is_contiguous
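+
+# Worked example with plain integers, for intuition: for sizes [2, 3, 4],
+# strides [12, 4, 1] and dim_order [2, 1, 0], each stride equals the running
+# product z (1, 4, 12), so the expression simplifies to sympy.true; with
+# strides [12, 4, 2] the innermost check fails and, since no size is zero,
+# the result is sympy.false.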
+
+
+# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
+# happens you will need to refactor this
+
+
+def sympy_is_channels_last_contiguous_2d(sizes, strides):
+    return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
+
+
+def sympy_is_channels_last_contiguous_3d(sizes, strides):
+    return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
+
+
+def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
+    import sympy
+
+    dim = len(sizes)
+
+    if dim != len(dim_order):
+        return sympy.false
+
+    m = sympy.Integer(0)
+    r = sympy.true
+
+    # special case for trivial C dimension. default to NCHW
+    r &= sympy.Ne(strides[1], 0)
+
+    for d in dim_order:
+        r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
+        # Fallback to NCHW as default layout for ambiguous cases
+        # This is the flaw of implicit memory_format from strides.
+        # N111 tensor with identical strides for size 1 dimension;
+        # Two cases could lead us here:
+        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
+        # b. N11W contiguous Tensor sliced on the W-dimension.
+        # ([N,1,1,1]@[W,W,W,W])
+        if d == 0:
+            r &= sympy.Ne(m, strides[1])
+        # This is necessary to:
+        # 1. distinguish the memory_format of N1H1;
+        #     [H, 1, 1, 1] channels_last stride
+        #     [H, H, 1, 1] contiguous stride
+        # 2. permutation of 1C1W:
+        #     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
+        #     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
+        #     channels_last
+        m = strides[d] * sympy.Max(sizes[d], 1)
+
+    return r
+
+
+def sympy_is_channels_last_strides_2d(sizes, strides):
+    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
+
+
+def sympy_is_channels_last_strides_3d(sizes, strides):
+    return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
+
+
+def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
+    from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
+
+    return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
+
+
+sizes_strides_methods = {
+    # TODO: These could also be done with indicators, maybe it is better
+    # for reasoning to do it that way
+    "is_contiguous": sympy_is_contiguous,
+    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
+    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
+    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
+    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
+    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
+}
+
+alternate_impl_if_hinted_methods = {
+    "sym_min": builtins.min,
+    "sym_max": builtins.max,
+}
+
+
+def to_node(self, num):
+    if isinstance(num, SymTypes):
+        return num.node
+    elif type(num) is bool:
+        return self.wrap_bool(num)
+    elif type(num) is int:
+        return self.wrap_int(num)
+    elif type(num) is float:
+        return self.wrap_float(num)
+    else:
+        # NotImplemented is important so that Python tries the
+        # other magic method
+        return NotImplemented
+
+
+def wrap_node(x):
+    # TODO: let C++ also take advantage of this
+    if isinstance(x, SymNode) and x.constant is not None:
+        return x.constant
+    if x.is_int():
+        return SymInt(x)
+    elif x.is_float():
+        return SymFloat(x)
+    elif x.is_bool():
+        return SymBool(x)
+    else:
+        raise AssertionError(f"unrecognized return type {x}")
+
+
+def method_to_operator(method):
+    return METHOD_TO_OPERATOR[method]
+
+
+def _make_node_magic(method, func):
+    func = lru_cache(256)(func)
+
+    if method in magic_methods_on_operator_with_trailing_underscore:
+        method_attr = f"{method}_"
+    else:
+        method_attr = method
+
+    def binary_magic_impl(self, other):
+        from torch.fx.experimental.symbolic_shapes import safe_expand
+
+        op = method_to_operator(method)
+
+        out_hint = None
+        if self.hint is not None and other.hint is not None:
+            out_hint = op(self.hint, other.hint)
+
+        alternate_impl = alternate_impl_if_hinted_methods.get(method)
+        if alternate_impl and out_hint is not None:
+            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
+
+        if sym_function_mode():
+            return to_node(
+                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
+            )
+        assert isinstance(other, SymNode)
+        # TODO: consider constant prop here
+        try:
+            out = func(self.expr, other.expr)
+        except Exception:
+            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
+            raise
+        out = safe_expand(out)
+        sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out)
+        pytype: Type
+        # This is not strictly correct. In Python, a**b may return complex when
+        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
+        # returns a float while both arguments are ints: 2**(-1). Also, max and
+        # min do not type promote. To avoid having data-dependent control flow
+        # here, we just set the type to float if one of the args is a float. In
+        # case of a type mismatch, we assume that it will be detected during
+        # evaluation.
+        if method in always_float_magic_methods:
+            pytype = float
+        elif method in always_bool_magic_methods:
+            pytype = bool
+        elif self.pytype is float or other.pytype is float:
+            pytype = float
+        else:
+            pytype = self.pytype
+
+        if (
+            pytype is not None
+            and out_hint is not None
+            and not isinstance(out_hint, SymTypes)
+        ):
+            out_hint = pytype(out_hint)
+
+        # Create a FX node that corresponds to the operation being applied to
+        # this node.
+        fx_node, _ = self.shape_env._create_fx_call_function(
+            op, (self.fx_node, other.fx_node)
+        )
+        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
+
+    def unary_magic_impl(self):
+        from torch.fx.experimental.symbolic_shapes import safe_expand
+
+        op = method_to_operator(method)
+        if sym_function_mode():
+            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
+        # TODO: consider constant prop here
+        expr = self.expr
+        if method == "floor" or method == "ceiling":
+            expr = self.shape_env._simplify_floor_div(expr)
+
+        try:
+            out = func(expr)
+        except Exception:
+            log.warning("failed to eval %s(%s)", method, expr)
+            raise
+        sym_node_log.debug("%s %s -> %s", func, expr, out)
+        out_hint = None
+        if self.hint is not None:
+            out_hint = op(self.hint)
+        out = safe_expand(out)
+        pytype: Type
+        if method in always_int_magic_methods:
+            pytype = int
+        elif method in always_bool_magic_methods:
+            pytype = bool
+        elif method in always_float_magic_methods:
+            pytype = float
+        else:
+            pytype = self.pytype
+
+        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
+        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
+
+    if method in unary_methods:
+        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
+    elif method == "sym_ite":
+
+        def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.symbolic_shapes import safe_expand
+
+            out_hint = then_node.hint if pred_node.hint else else_node.hint
+            if sym_function_mode():
+                return to_node(
+                    pred_node,
+                    handle_sym_dispatch(
+                        sym_ite,
+                        (
+                            wrap_node(pred_node),
+                            wrap_node(then_node),
+                            wrap_node(else_node),
+                        ),
+                        {},
+                    ),
+                )
+
+            try:
+                out = func(pred_node.expr, then_node.expr, else_node.expr)
+            except Exception:
+                log.warning(
+                    "failed to eval %s(%s, %s, %s)",
+                    method,
+                    pred_node.expr,
+                    then_node.expr,
+                    else_node.expr,
+                )
+                raise
+
+            out = safe_expand(out)
+            fx_node, _ = pred_node.shape_env._create_fx_call_function(
+                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
+            )
+            return SymNode(
+                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
+            )
+
+        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
+    elif method == "round":
+
+        def round_impl(self, ndigits=None):
+            from torch.fx.experimental.symbolic_shapes import safe_expand
+
+            op = builtins.round
+            if sym_function_mode():
+                return to_node(
+                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
+                )
+
+            expr = self.expr
+            try:
+                out = func(expr, ndigits)
+            except Exception:
+                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
+                raise
+            out = safe_expand(out)
+
+            pytype = int if ndigits is None else self.pytype
+
+            out_hint = None
+            if self.hint is not None:
+                out_hint = op(self.hint, ndigits)
+
+            # Internally, None is used as a sentinel to indicate that something is not a node on an FX graph. At the
+            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
+            # without triggering asserts that check whether we are mixing FX nodes with untracked arguments. The
+            # hack below works because all round functions down the line take ndigits=None as the default in their
+            # signatures.
+            # TODO: Remove the args construction below if a different sentinel is used by FX.
+            args = [self.fx_node]
+            if ndigits is not None:
+                args.append(ndigits)
+            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
+            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
+
+        setattr(SymNode, f"_{method_attr}", round_impl)
+    else:
+        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
+
+
+def _make_node_sizes_strides(method, func):
+    # NB: don't LRU cache, lots of arguments
+
+    def sizes_strides_impl(self, sizes, strides):
+        op = getattr(sys.modules[__name__], method)
+        if sym_function_mode():
+            return to_node(
+                self,
+                handle_sym_dispatch(
+                    op,
+                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
+                    {},
+                ),
+            )
+        size_exprs = [s.expr for s in sizes]
+        stride_exprs = [s.expr for s in strides]
+        try:
+            out = func(size_exprs, stride_exprs)
+        except Exception:
+            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
+            raise
+        # bool is never expandable
+
+        size_hints = []
+        out_hint = None
+        for s in sizes:
+            if s.hint is None:
+                break
+            size_hints.append(s.hint)
+        else:
+            stride_hints = []
+            for s in strides:
+                if s.hint is None:
+                    break
+                stride_hints.append(s.hint)
+            else:
+                out_hint = op(size_hints, stride_hints)
+
+        # NB: This is the indicator function, not the actual bool!
+        pytype: Type
+        if method.endswith("_indicator"):
+            pytype = int
+        else:
+            pytype = bool
+        return SymNode(out, self.shape_env, pytype, out_hint)
+
+    setattr(SymNode, f"_{method}", sizes_strides_impl)
+
+    # TODO: This is technically hotpath, but in the ideal end state
+    # guards on this will resolve at a higher level so you never
+    # spend time in this code
+    def sizes_strides_user(sizes, strides):
+        import sympy
+
+        from torch.fx.experimental.symbolic_shapes import (
+            eval_is_non_overlapping_and_dense,
+        )
+
+        for a in itertools.chain(sizes, strides):
+            if isinstance(a, SymInt):
+                return wrap_node(
+                    getattr(a.node, method)(
+                        [to_node(a.node, b) for b in sizes],
+                        [to_node(a.node, b) for b in strides],
+                    )
+                )
+        if method == "is_non_overlapping_and_dense_indicator":
+            return eval_is_non_overlapping_and_dense(sizes, strides)
+        else:
+            # TODO: this is an awful implementation
+            return bool(
+                func(
+                    [sympy.sympify(a) for a in sizes],
+                    [sympy.sympify(a) for a in strides],
+                )
+            )
+
+    # Skip for is_non_overlapping_and_dense_indicator
+    if not hasattr(sys.modules[__name__], method):
+        setattr(sys.modules[__name__], method, sizes_strides_user)
+
+
+for method, func in magic_methods.items():
+    _make_node_magic(method, func)
+
+for method, func in sizes_strides_methods.items():
+    _make_node_sizes_strides(method, func)
+
+
+def _make_user_magic(method, user_type):
+    # User magic takes care of wrapping the other operand into a node,
+    # so that our internal logic can assume everything is nodes
+
+    if method in magic_methods_on_operator_with_trailing_underscore:
+        method_attr = f"sym_{method}"
+    else:
+        method_attr = method
+
+    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
+        if isinstance(x, (int, float, bool)):
+            return x
+        if isinstance(x, SymBool):
+            return x.node.guard_bool("", 0)
+        raise AssertionError("expect to be called with constant SymBools")
+
+    def is_constant(x):
+        if isinstance(x, (int, float, bool)):
+            return True
+        if isinstance(x, (SymInt, SymFloat, SymBool)):
+            return x.node.is_constant()
+        return False
+
+    if method in bool_becomes_int_magic_methods:
+
+        def promote(x):
+            """Implements True+True=2, which works in python but not sympy"""
+            if isinstance(x, SymBool):
+                return SymInt(x.node.wrap_int(int(x)))
+            return x
+
+    else:
+
+        def promote(x):
+            return x
+
+    # Before and after performing the operation, check if any operands are constant.
+    # If so, extract out the constant values first. If `self` itself is a
+    # constant, then "redispatch" by calling back into the operator. Sometimes
+    # this means that operations involving SymBool return plain bools.
+    # Alternatively, we could also rewrap into constant Symbool (i.e. by
+    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
+    # today for no particular reason.
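+    #
+    # For example (illustrative): an `eq` between two constant SymBools falls
+    # through to plain `operator.eq` and returns a real Python bool rather
+    # than being wrapped back into a SymBool.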
+    def unary_magic_impl(self):
+        self = promote(self)
+        if is_constant(self):
+            return (method_to_operator(method))(get_constant(self))
+        return wrap_node(getattr(self.node, method_attr)())
+
+    def binary_magic_impl(self, other):
+        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
+        self = promote(self)
+        other = promote(other)
+        if is_constant(self):
+            return (method_to_operator(method))(get_constant(self), other)
+        if is_constant(other):
+            other = get_constant(other)
+        other_node = to_node(self.node, other)
+        if other_node is NotImplemented:
+            return NotImplemented
+        ret = wrap_node(getattr(self.node, method_attr)(other_node))
+        return get_constant(ret) if is_constant(ret) else ret
+
+    def rbinary_magic_impl(self, other):
+        self = promote(self)
+        other = promote(other)
+        if is_constant(self):
+            return (method_to_operator(method))(get_constant(self), other)
+        if is_constant(other):
+            other = get_constant(other)
+        other_node = to_node(self.node, other)
+        if other_node is NotImplemented:
+            return NotImplemented
+        ret = wrap_node(getattr(other_node, method_attr)(self.node))
+        return get_constant(ret) if is_constant(ret) else ret
+
+    if method in unary_magic_methods:
+        setattr(user_type, f"__{method}__", unary_magic_impl)
+    elif method in unary_nonmagic_methods:
+        orig = getattr(user_type, method)
+        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
+    elif method == "sym_ite":
+
+        def sym_ite_magic_impl(pred, then_val, else_val):
+            pred_node = pred.node
+            then_node = to_node(pred_node, then_val)
+            else_node = to_node(pred_node, else_val)
+            if then_node is NotImplemented or else_node is NotImplemented:
+                return NotImplemented
+            assert (
+                isinstance(then_node, SymNode)
+                and isinstance(else_node, SymNode)
+                and then_node.pytype == else_node.pytype
+            )
+            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
+            return get_constant(ret) if ret.node.is_constant() else ret
+
+        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
+    elif method == "round":
+
+        def round_magic_impl(self, ndigits=None):
+            if is_constant(self):
+                return builtins.round(get_constant(self), ndigits)
+
+            return wrap_node(getattr(self.node, method)(ndigits))
+
+        setattr(user_type, f"__{method}__", round_magic_impl)
+    else:
+        setattr(user_type, f"__{method}__", binary_magic_impl)
+        if method in reflectable_magic_methods:
+            setattr(user_type, f"__r{method}__", rbinary_magic_impl)
+
+
+for method, func in magic_methods.items():  # type: ignore[assignment]
+    if method in only_bool_magic_methods:
+        _make_user_magic(method, SymBool)
+        continue
+    if method in only_float_magic_methods:
+        _make_user_magic(method, SymFloat)
+        continue
+    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
+        _make_user_magic(method, SymBool)
+    _make_user_magic(method, SymInt)
+    _make_user_magic(method, SymFloat)
+
+del method
+del func
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py b/MLPY/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..5317287e37b1c16952d5d62461f4f954a8d85bda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py
@@ -0,0 +1,4362 @@
+# mypy: ignore-errors
+
+"""
+``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
+our symbolic shapes reasoning system that is used heavily in torch.compile.  Although
+this is not generally considered public API, when writing framework code in PyTorch
+as well as extensions to PyTorch (e.g., in custom operator implementations), you may
+need to make use of these APIs to setup dynamic shapes support appropriately.
+"""
+
+import builtins
+import collections
+import functools
+import inspect
+import itertools
+import logging
+import math
+import operator
+import re
+import sys
+import threading
+import traceback
+from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from enum import Enum
+from functools import lru_cache
+from typing import (
+    Any,
+    cast,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    TYPE_CHECKING
+)
+from typing_extensions import TypeAlias
+
+import torch
+import torch.fx
+import torch.fx.traceback as fx_traceback
+from torch.fx.experimental import _config as config
+
+from torch.fx.experimental.recording import (
+    FakeTensorMeta,
+    ShapeEnvEvent,
+    record_shapeenv_event,
+    replay_shape_env_events,
+    shape_env_check_state_equal
+)
+from torch.fx.experimental.sym_node import SymNode, SymTypes
+
+# NB: The sym_* functions are used via getattr() and must be imported here.
+from torch import SymBool, SymFloat, SymInt
+from torch._guards import ShapeGuard, Source, TracingContext
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
+from torch.utils._sympy.solve import try_solve
+from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
+from torch.utils._sympy.singleton_int import SingletonInt
+from torch.utils._traceback import format_frame, CapturedTraceback
+from torch._utils_internal import signpost_event
+from torch._subclasses.meta_utils import is_sparse_any
+
+from torch._logging import LazyString
+
+if TYPE_CHECKING:
+    from torch._dynamo.source import TensorPropertySource
+
+InputList = List
+DimList = List
+
+log = logging.getLogger(__name__)
+
+class GuardOnDataDependentSymNode(RuntimeError):
+    pass
+
+import sympy
+from sympy.printing.str import StrPrinter
+from sympy.printing.precedence import precedence, PRECEDENCE
+
+aten = torch._ops.ops.aten  # type: ignore[has-type]
+
+__all__ = [
+    "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
+    "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
+    "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
+    "is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
+    "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
+    "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
+    "guard_size_oblivious",
+]
+
+# FX node metadata keys for symbolic shape FX graph.
+SHAPEENV_EVENT_KEY = "shapeenv_event"
+CURRENT_NODE_KEY = "current_node"
+
+# These are modules that contain generic code for interacting with ShapeEnv
+# which are unlikely to identify a particular interesting guard statement
+@lru_cache(None)
+def uninteresting_files() -> Set[str]:
+    import torch._inductor.sizevars
+    import torch._library.abstract_impl
+    import torch._subclasses.meta_utils
+    import torch._subclasses.fake_tensor
+    mods = [
+        sys.modules[__name__],
+        torch.fx.experimental.recording,
+        torch.fx.experimental.sym_node,
+        torch.fx.interpreter,
+        torch,
+        torch._inductor.sizevars,
+        torch._library.abstract_impl,
+        torch._subclasses.meta_utils,
+        torch._subclasses.fake_tensor,
+    ]
+    return {inspect.getfile(m) for m in mods}
+
+# We don't bother with the metaclass as all of the dispatching logic happens
+# entirely from Python
+#
+# Didn't bother with ancestors for now, unlikely to have multiple modes for
+# symints right now
+
+class ConstraintViolationError(RuntimeError):
+    pass
+
+def has_symbolic_sizes_strides(elem) -> bool:
+    return elem._has_symbolic_sizes_strides
+
+Int = Union[torch.SymInt, int]
+
+def create_contiguous(shape: Sequence[Int]) -> List[Int]:
+    strides: List[Int] = [1]
+    for dim in reversed(shape[1:]):
+        strides.append(dim * strides[-1])
+    return list(reversed(strides))
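+
+# For example, create_contiguous([2, 3, 4]) returns [12, 4, 1], the strides of
+# a freshly allocated contiguous tensor of that shape.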
+
+def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
+    """
+    Retrieve the hint for an int (based on the underlying real values as observed
+    at runtime).  If no hint is available (e.g., because of data dependent shapes),
+    use the fallback if it is not None; otherwise, raise an error.
+    """
+    if isinstance(a, torch.SymInt):
+        return a.node.require_hint(fallback)
+    assert type(a) is int, a
+    return a
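+
+# For example, hint_int(5) returns 5; for a SymInt traced from a tensor that
+# currently has 8 rows, it returns 8, and for an unbacked SymInt with no hint
+# it returns `fallback` (or raises if no fallback is given).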
+
+Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
+
+def has_hint(a: Scalar) -> bool:
+    if isinstance(a, SymTypes):
+        return a.node.has_hint()
+    return True
+
+def is_concrete_int(a: Union[int, SymInt]) -> bool:
+    r""" Utility to check if underlying object
+    in SymInt is concrete value. Also returns
+    true if integer is passed in.
+
+    Args:
+        a (SymInt or int): Object to test if it int
+    """
+    assert isinstance(a, (SymInt, int))
+
+    if isinstance(a, int):
+        return True
+
+    if isinstance(a.node.expr, sympy.core.numbers.Integer):
+        return True
+
+    return False
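+
+# For example, is_concrete_int(5) returns True, as does a SymInt whose
+# expression has already been resolved to a literal integer; a SymInt created
+# for a genuinely dynamic dimension returns False.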
+
+# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
+# So make sure only type checker evaluates this alias.
+# Xref: https://www.internalfb.com/diff/D53324783
+SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"
+
+def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
+    """
+    Perform a guard on a symbolic boolean expression in a size oblivious way.
+    This is typically used when a non-oblivious test would result in a guard
+    on a data dependent value whose value we don't know at compile time.
+    When a guard is tested this way, we may diverge in behavior from how regular
+    PyTorch semantics would treat it.  For more information, see
+    https://github.com/pytorch/pytorch/pull/118579
+    """
+    if isinstance(expr, torch.SymBool):
+        return expr.node.guard_size_oblivious("", 0)
+    else:
+        assert isinstance(expr, bool)
+        return expr
+
+def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
+    r""" Canonicalize a boolean expression by transforming it into a lt / le
+    inequality and moving all the non-constant terms to the rhs.
+    We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr
+    recursively
+    nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924
+
+    Args:
+        expr (sympy.Expr): Expression to canonicalize
+    """
+    # Canonicalise an inequality by transforming it into a lt / le
+    # inequality and moving all the constant terms to the rhs
+    # We canonicalise And / Ors / Not via cnf
+    # nb. Relational.canonical in sympy is broken
+    # https://github.com/sympy/sympy/issues/25924
+
+    if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)):
+        return expr
+
+    if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)):
+        expr = sympy.logic.boolalg.to_cnf(expr)
+    return _canonicalize_bool_expr_impl(expr)
+
+def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
+    """
+    After canonicalization, we are guaranteed to have eliminated Ge/Gt relations
+    (rewriting them to Le/Lt, respectively).
+    """
+    if isinstance(expr, (sympy.And, sympy.Or)):
+        return type(expr)(*map(canonicalize_bool_expr, expr.args))
+
+    opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
+    if isinstance(expr, tuple(opposite.keys())):
+        lhs = expr.rhs - expr.lhs
+        t = opposite[type(expr)]
+    else:
+        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
+        lhs = expr.lhs - expr.rhs
+        t = type(expr)
+    rhs = 0
+    if isinstance(lhs, sympy.Add):
+        cts = []
+        variables = []
+        for term in lhs.args:
+            if term.is_number:
+                cts.append(term)
+            else:
+                variables.append(term)
+        lhs = sympy.Add(*variables)
+        rhs = -sympy.Add(*cts)
+    return t(lhs, rhs)
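+
+# Illustrative sketch (added commentary; s0 and s1 are hypothetical sympy symbols):
+#   canonicalize_bool_expr(s0 + 3 > s1)  ->  -s0 + s1 < 3
+# i.e. the relation is flipped to Lt and the constant term is isolated on the rhs.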
+
+def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
+    r""" Utility to check if underlying object
+    in SymBool is concrete value. Also returns
+    true if integer is passed in.
+    Args:
+        a (SymBool or bool): Object to test if it bool
+    """
+    assert isinstance(a, (SymBool, bool))
+
+    if isinstance(a, bool):
+        return True
+
+    if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)):
+        return True
+
+    return False
+
+def is_nested_int(s):
+    return isinstance(s, torch.SymInt) and s.node.is_nested_int()
+
+def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
+    if isinstance(val, SymTypes):
+        # This allowance applies to the jagged layout NestedTensor case, as
+        # nested ints are not symbolic
+        if is_symbolic(val):
+            yield val.node.expr
+    elif isinstance(val, sympy.Basic):
+        yield val
+    elif isinstance(val, (int, float, bool)):
+        pass
+    elif is_sparse_any(val):
+        yield from _iterate_exprs(val.size())
+    elif isinstance(val, torch.Tensor):
+        yield from _iterate_exprs(val.size())
+        yield from _iterate_exprs(val.stride())
+        yield from _iterate_exprs(val.storage_offset())
+    elif isinstance(val, (tuple, list)):
+        for s in val:
+            yield from _iterate_exprs(s)
+    elif val is None:
+        pass
+    else:
+        raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
+
+def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]:
+    if val is None:
+        return set()
+    itr = _iterate_exprs(val)
+    # we need at least 1 to call union, so we hand code the identity
+    try:
+        first_expr = next(itr)
+    except StopIteration:
+        return set()
+
+    return first_expr.free_symbols.union(*(e.free_symbols for e in itr))
+
+def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
+    """Faster version of bool(free_symbols(val))"""
+    return not all(e.is_number for e in _iterate_exprs(val))
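+
+# Illustrative note (added commentary): for a fake tensor whose sizes are
+# (s0, s1), free_symbols(t) is {s0, s1} (plus any symbols appearing in its
+# strides and storage offset) and has_free_symbols(t) is True; for plain
+# Python ints, both report no free symbols.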
+
+# Like free_symbols, but filtered to only report unbacked symbols
+def free_unbacked_symbols(x):
+    # NB: keep synced with is_unbacked_symint
+    return {s for s in free_symbols(x) if s.name.startswith(("u", "f"))}
+
+# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
+# setup!
+def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
+    if (
+        node.op == "placeholder" and
+        "val" in node.meta and
+        isinstance(node.meta["val"], torch.SymInt) and
+        isinstance(node.meta["val"].node.expr, sympy.Symbol)
+    ):
+        return node.meta["val"].node.expr
+    return None
+
+def find_symbol_binding_fx_nodes(graph):
+    return {
+        node.meta["val"].node.expr: node
+        for node in graph.nodes
+        if is_symbol_binding_fx_node(node)
+    }
+
+def definitely_true(a):
+    """
+    Returns True only if we can tell that a is True, possibly introducing
+    a guard in the process.  If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression to return True.
+
+    When is it appropriate to use definitely_true?  First, if you can use
+    a higher level combinator like parallel_or/parallel_and, prefer using
+    those instead; they are definitely safe (modulo short-circuiting).
+    Second, it can be used if the program would behave equivalently if
+    definitely_true always returned False (parallel_or/parallel_and are
+    examples of this pattern, modulo short-circuiting).  Finally, it can even
+    be OK if the program wouldn't behave equivalently, so long as the
+    change is semantics preserving.  It can be semantics preserving if
+    the program errors in more cases than it did previously (but otherwise
+    behaves identically), or if it changes some quantity in a way that
+    doesn't matter (e.g., strides often fall in this bucket.)
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return guard_bool(a)
+        else:
+            return False
+    return bool(a)
+
+def definitely_false(a):
+    """
+    Returns True only if we can tell that a is False, possibly introducing
+    a guard in the process.  If a depends on some unbacked SymInt, we may
+    return False even though there may exist a possible value of the SymInt
+    that would cause the expression a to be False.  See definitely_true
+    for more usage guidance.
+    """
+    if isinstance(a, SymBool):
+        if a.node.has_hint():
+            return not guard_bool(a)
+        else:
+            return False
+    return not bool(a)
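+
+# Illustrative note (added commentary): for a backed SymBool (one with a runtime
+# hint), definitely_true guards on the expression and returns its value; for an
+# unbacked SymBool, both definitely_true and definitely_false return False.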
+
+def statically_known_true(x: Union[bool, SymBool]) -> bool:
+    """Returns True if x can be simplified to a constant and is true.
+
+    .. note::
+        This function doesn't introduce new guards, so the expression may end
+        up evaluating to true at runtime even if this function returns False.
+
+    Args:
+        x (bool, SymBool): The expression to try statically evaluating
+
+    """
+    if isinstance(x, SymBool):
+        expr = x.node.expr
+        shape_env = x.node.shape_env
+        try:
+            simplified = shape_env._maybe_evaluate_static(expr)
+            if simplified is not None:
+                return bool(simplified)
+        except Exception:
+            log.debug("Could not simplify %s", expr)
+        return False
+    assert isinstance(x, bool)
+    return x
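+
+# Illustrative note (added commentary; s0 is a hypothetical size symbol):
+#   statically_known_true(s0 >= 0)  # may be True if the recorded range proves it
+#   statically_known_true(s0 == 4)  # typically False; no guard is added to decide it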
+
+
+def parallel_or(*args):
+    """
+    Evaluate the logical OR of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely True.
+    """
+    if any(statically_known_true(a) for a in args):
+        return True
+    if any(definitely_true(a) for a in args):
+        return True
+    return any(args)
+
+def parallel_and(*args):
+    """
+    Evaluate the logical AND of several arguments, avoiding guarding on
+    unbacked SymInts if another argument is definitely False.
+    """
+    if any(statically_known_true(torch.sym_not(a)) for a in args):
+        return False
+    if any(definitely_false(a) for a in args):
+        return False
+    return all(args)
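+
+# Illustrative note (added commentary; u0 is a hypothetical unbacked SymInt):
+# parallel_or(True, u0 == 0) returns True without guarding on u0, and
+# parallel_and(False, u0 == 0) returns False without guarding on u0.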
+
+def sym_eq(x, y):
+    """
+    Like ==, but when run on list/tuple, it will recursively test equality
+    and use sym_and to join the results together, without guarding.
+    """
+    if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)):
+        if len(x) != len(y):
+            return False
+        return functools.reduce(operator.and_, map(sym_eq, x, y), True)
+    elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
+        return x == y
+    else:
+        raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
+
+def guard_scalar(a):
+    if isinstance(a, (SymBool, bool)):
+        return guard_bool(a)
+    elif isinstance(a, (SymInt, int)):
+        return guard_int(a)
+    elif isinstance(a, (SymFloat, float)):
+        return guard_float(a)
+    else:
+        raise AssertionError(f"unrecognized scalar {a}")
+
+
+@record_shapeenv_event()
+def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
+    upd_vr = ValueRanges(compiler_min, compiler_max)
+    old_vr = shape_env.var_to_range.get(s, ValueRanges.unknown())
+    new_vr = shape_env.var_to_range[s] = old_vr & upd_vr
+    if new_vr != old_vr:
+        log.info("_constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper)
+
+
+def _advise_is_size(a):
+    """
+    Don't use this directly; use torch._check_is_size instead.
+
+    This is a softer version of _constrain_range_for_size (with min=0,
+    max=Inf).  Instead of forcibly constraining a variable (and erroring if we
+    failed to constrain it), it will simply advise us that a size is
+    constrained in some way.  We will always defer a runtime assert for this
+    constraint if we cannot prove it at compile-time, but we only
+    *sometimes* learn useful extra information at compile-time with this
+    information.  This is in contrast to constrain_range_for_size, where if
+    you don't call that on a fresh unbacked symint, chances are we will choke.
+
+    TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
+    code.  Right now this is only really used in code with AOTAutograd trace
+    through, so it is not a big problem that this isn't supported, but in
+    principle all of this code should be Dynamo'able too.
+
+    TODO: I didn't support min/max because I didn't have a use case where this
+    actually helped.  In principle we can support it, it just makes the
+    implementation below more complicated.
+    """
+
+    # This must always succeed, because the sole allowed caller _check_is_size
+    # was responsible for expect_true'ing this
+    assert a >= 0
+
+    # NB: it's important not to constrain range for size for *hinted* SymInts,
+    # because it is not only unsound, it will immediately trip our asserts
+    # that hints have to be consistent with static analysis!  If you somehow
+    # have an unbounded SymInt that later constrains to 1, this will be
+    # inconsistent with the range
+    if (
+        isinstance(a, SymInt)
+        and isinstance(a.node, SymNode)
+        and not a.node.has_hint()
+        and isinstance(a.node.expr, sympy.Symbol)
+    ):
+        _constrain_range_for_size(a)
+
+@record_shapeenv_event()
+def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
+    """
+    This function is NOT INTENDED to be used by itself.
+    """
+
+    if isinstance(a, (SymFloat, SymBool)):
+        raise ValueError("Constraining SymFloat/SymBool is nyi")
+
+    assert isinstance(a, SymInt), "can only constrain range for SymInt"
+    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+
+    if min is None:
+        min = 0
+    if max is None:
+        max = sympy.oo
+
+    if max < min:
+        raise ValueError(
+            "Maximum value to constrain_as_size can't be less than the specified min value, "
+            "received min={min} and max={max}"
+        )
+
+    _constrain_symbol_range(
+        a.node.shape_env,
+        a.node.expr,
+        compiler_min=min,
+        compiler_max=max,
+    )
+    a.node.shape_env.size_like.add(a.node.expr)
+
+
+# inclusive both ways
+@record_shapeenv_event()
+def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
+    """
+    Applies a constraint that the passed in SymInt must lie between min-max
+    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
+    that it can be used on unbacked SymInts).  If min/max are None, we assume
+    that the dimension is unbounded in that direction.  Repeated application
+    of constrain_range intersects the ranges.  This is a fairly low level API
+    that doesn't have a lot of safety guarantees (TODO: provide higher level
+    APIs).
+
+    Currently, we use this API in the following circumstance: when we allocate
+    an unbacked SymInt, denoting an integer quantity which is data dependent,
+    we ordinarily do not know anything about what values it may take.  This
+    means that any sort of guard on it will immediately fail.  However, in
+    many cases, we know something about the unbacked SymInt: for example, we
+    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
+    narrow the possible range, declaring that negative symbols are impossible.
+    This permits us to definitely answer True to queries like 'nnz >= 0', even if
+    we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
+    actually use constrain_range to unsoundly discharge common guards: for an
+    unbacked SymInt produced by nonzero, we will also assume that it is not
+    equal to 0/1 (even though these are perfectly possible values at runtime),
+    because we generally expect graphs that are valid for N=2 to also be valid
+    for N=1.
+    """
+    if min is None:
+        min = -sympy.oo
+    if max is None:
+        max = sympy.oo
+
+    if max < min:
+        raise ValueError(
+            "Maximum value to constrain_as_size can't be less than the specified min value, "
+            "received min={min} and max={max}"
+        )
+
+    if isinstance(a, int):
+        if not (min <= a <= max):
+            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
+        return
+
+    if isinstance(a.node.expr, sympy.Integer):
+        if not (min <= int(a.node.expr) <= max):
+            raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]")
+        return
+    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+
+    # TODO: Shouldn't we install a guard if the symbol is backed?  Or is the
+    # semantics that this is an "unchecked" assert (but is this actually
+    # something useful?  Might be better to restrict only for unbacked
+    # SymInt).
+    _constrain_symbol_range(
+        a.node.shape_env,
+        a.node.expr,
+        compiler_min=min,
+        compiler_max=max,
+    )
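+
+# Illustrative usage sketch (added commentary; u0 is a hypothetical unbacked
+# SymInt, e.g. nonzero(x).size(0)):
+#   constrain_range(u0, min=0)          # declare u0 is non-negative
+#   constrain_range(u0, min=2, max=8)   # further applications intersect the range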
+
+
+@record_shapeenv_event()
+def constrain_unify(a, b):
+    """
+    Given two SymInts, constrain them so that they must be equal.  NB:
+    this will not work with SymInts that represent nontrivial expressions
+    (yet!)
+    """
+    # TODO: this does not install a deferred runtime assert yet
+
+    # TODO: Maybe dedupe this with _maybe_guard_rel?
+    if not isinstance(a, SymInt):
+        if not isinstance(b, SymInt):
+            assert a == b
+        else:
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            shape_env = b.node.shape_env
+            shape_env.replacements[b.node.expr] = sympy.Integer(a)
+    else:
+        # TODO: Actually, we can support this as long as one of them is a symbol.
+        # NB: We can't actually do "unification" as our operators are not
+        # injective
+        assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+        shape_env = a.node.shape_env
+        if not isinstance(b, SymInt):
+            shape_env.replacements[a.node.expr] = sympy.Integer(b)
+        else:
+            assert a.node.shape_env is b.node.shape_env
+            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            new_var = shape_env._find(a.node.expr)
+            shape_env.replacements[b.node.expr] = new_var
+
+# Assume that a boolean is true for the purposes of subsequent symbolic
+# reasoning.  This will keep track of corresponding runtime checks to verify
+# that the result is upheld: either as a regular guard, or as a special set
+# of asserts which are triggered when an unbacked SymInt is allocated.
+#
+# DO NOT use this function for these cases:
+#
+#  - This is inappropriate for "branching" conditions (where both
+#    true and false result in valid programs).  We will always assume
+#    the condition evaluates true, and so it will never be possible
+#    to trace the false condition when you use it.  For true branching
+#    on unbacked SymInts, you must use torch.cond; if you incorrectly
+#    use expect_true in this case, you will make the false branch
+#    unreachable (as we will simply assume that only the true branch
+#    is ever exercised).
+#
+#  - This is inappropriate for situations where you know some other system
+#    invariant guarantees that this property holds, since you don't
+#    really need to insert a runtime check in that case.  Use something
+#    like constrain_range in that case.
+#
+# This API has a hitch.  To avoid having to reimplement error reporting
+# capabilities, this function CAN return False.  The invariant is that
+# the surrounding code must raise an error when this function returns
+# False.  This is quite low level, so we recommend using other functions
+# like check() which enforce this in a more intuitive way.
+#
+# By the way, this name is a nod to the __builtin_expect macro,
+# which is used similarly (but unlike __builtin_expect, you MUST fail
+# in the unlikely branch.)  (I think expect is a good name; in recent
+# versions of C++, this is replaced with [[likely]], which is weaker
+# and not accurate for this function!)
+def expect_true(a, skip: int = 0):
+    if isinstance(a, SymBool):
+        # TODO: check perf implications of this
+        frame = inspect.currentframe()
+        for _ in range(skip + 1):  # always run this loop at least once
+            frame = frame.f_back
+        return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno)
+    assert type(a) is bool, a
+    return a
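+
+# Illustrative usage sketch (added commentary): per the contract above, the
+# caller must raise when expect_true returns False, e.g.
+#   if not expect_true(i >= 0):
+#       raise RuntimeError("expected i to be non-negative")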
+
+def guard_bool(a):
+    if isinstance(a, SymBool):
+        return a.node.guard_bool("", 0)  # NB: uses Python backtrace
+    assert type(a) is bool, a
+    return a
+
+def guard_int(a):
+    if isinstance(a, SymInt):
+        return a.node.guard_int("", 0)  # NB: uses Python backtrace
+    assert type(a) is int, a
+    return a
+
+def guard_float(a):
+    if isinstance(a, SymFloat):
+        return a.node.guard_float("", 0)  # NB: uses Python backtrace
+    assert isinstance(a, float), a
+    return a
+
+# Given a GraphModule, return all the FakeTensors for all the placeholders
+def fx_placeholder_vals(gm):
+    return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"]
+
+def fx_placeholder_targets(gm):
+    return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
+
+# Given a GraphModule and arguments to run it with, evaluate that the guards
+# for its associated ShapeEnv are satisfied by the passed arguments.  This
+# WILL check for duck sizing.
+def eval_guards(gm, *args, ignore_static=True):
+    return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static)
+
+def bind_symbols(gm, *args):
+    return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
+
+def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
+    """
+    We assert that the bounds are either Boolean, or not finite, or can be computed
+    in exact prevision via rational arithmetic.
+    The only exception to this is the rare case when the user calls `sqrt(s0)`
+    sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still)
+    """
+    assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr)
+    assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr)
+
+class DimDynamic(Enum):
+    """
+    Controls how to perform symbol allocation for a dimension.  It is always
+    sound to default this to DYNAMIC, but the policies DUCK and STATIC can
+    result in better trace-time and compile-time performance, as they reduce
+    the number of allocated symbols and generally make your graph more static.
+
+    NB: If we notice you've applied a constraint to the dimension, we will
+    force it to DYNAMIC for simplicity.
+
+    DimDynamic is controlled by a variety of higher level UX features.
+    Currently:
+
+    - In eager mode, the default policy is DUCK.
+        - The default is changed to STATIC with assume_static_by_default.
+        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
+    - In export mode, the default policy is STATIC.
+        - An individual dim is marked DYNAMIC if you mention it as dynamic_dim
+          in the constraints kwarg.
+    """
+    # Treat the dimension symbolically
+    DYNAMIC = 0
+    # Treat the dimension symbolically, but if its hint matches another
+    # dynamic dimension, unify the two symbols ("duck sizing")
+    DUCK = 1
+    # Treat the dimension statically based on its hint
+    STATIC = 2
+
+
+# NB: These constraints affect both clients and backends: given some
+# constraint C, the client must pass inputs that satisfy the constraint,
+# while a backend must not introduce guards BEYOND this constraint.
+# For clarity, we document the implications on both sides for both the client
+# and the backend.
+#
+# NB: These constraints are on a *single* dimension.  In principle, we could
+# also have multi-dimension constraints, but our guess is that this is not
+# actually useful and so we are not supporting it right now.
+#
+# NB: Strict constraints are typically only suitable for export, as in eager
+# a backend like inductor may validly introduce extra, discretionary guards
+# to improve performance of code.  A StrictMinMaxConstraint would be brittle
+# under future optimizations performed by inductor; we don't guarantee
+# eager code with StrictMinMaxConstraint will keep working in the future!
+
+@dataclass(frozen=True)
+class Constraint:
+    warn_only: bool
+
+@dataclass(frozen=True)
+class StrictMinMaxConstraint(Constraint):
+    """
+    For clients: the size at this dimension must be within 'vr' (which
+    specifies a lower and upper bound, inclusive-inclusive) AND it
+    must be non-negative and should not be 0 or 1 (but see NB below).
+
+    For backends: there must not be any guards on this dimension which
+    are not implied by the given lower and upper bound.  Regardless of
+    the lower bound, the backend can assume the size is non-negative
+    and that it is not 0 or 1.
+
+    An unbounded StrictMinMaxConstraint can be thought of as a strict version
+    of "RelaxedUnspecConstraint".
+
+    NB: Export will often unsoundly assume that a graph works for 0/1, even
+    though at trace time we assumed size is not 0 or 1.  The idea is that
+    if we produce a graph that works for a range of values, it will be OK
+    for N=0/1 too.
+    """
+    vr: ValueRanges
+
+    def render(self, source: Source):
+        """Format the constrain equation"""
+        # TODO: better printing for -oo and oo
+        return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}"
+
+@dataclass(frozen=True)
+class RelaxedUnspecConstraint(Constraint):
+    """
+    For clients: no explicit constraint; constraint is whatever is implicitly
+    inferred by guards from tracing.
+
+    For backends: there must exist at least TWO possible values for the
+    size at this dimension which satisfy the guards for this dimension.
+
+    In other words, this constraint helps us distinguish between "we don't
+    care if this dimension specializes or not" versus "this dimension must be
+    unspecialized."  However, this constraint doesn't say very much about what
+    specialization is permitted; for example, if we guard on a size being
+    even, this would still be acceptable under an unspec constraint.  This
+    makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler
+    may add constraints to otherwise dynamic dimensions; we can't assert that
+    there are NO guards as this is brittle because compilers should be able to
+    add extra constraints.  If you want to assert that there are no guards,
+    use StrictMinMaxConstraint with an unbounded ValueRanges.
+    """
+    def render(self, source: Source):
+        return f"RelaxedUnspecConstraint({source.name()})"
+
+# NB: None here indicates the client constraint is whatever is implicitly
+# inferred by guards from tracing, and that a backend can add whatever guards
+# it wants (including fully specializing the value).
+DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None]
+
+@dataclass(frozen=True)
+class EqualityConstraint(Constraint):
+    """
+    Represent and decide various kinds of equality constraints between input sources.
+
+    A "source pair" is a pair of input sources for dynamic dimensions that
+    are specified equal. We represent `source_pairs` in a union-find forest
+    so that we can efficiently check whether two such sources are transitively equal.
+
+    A "derived equality" relates an input source to an expression over a root.
+    The root can be another input source, corresponding to some dynamic dimension,
+    or a phantom symbol that does not directly represent any dynamic dimension. We
+    represent `derived_equalities` involving input sources in a transitively-closed map
+    so that we can efficiently check whether an input source is transitively equal to
+    a given expression over another input source.
+    (NOTE: In contrast, it is easy to decide whether an input source is transitively equal
+    to a given expression over a phantom symbol; such expressions are already in canonical
+    form and so the problem reduces to symbolic expression equality.)
+    """
+    source_pairs: List[Tuple[Source, Source]]
+    derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]]
+    phantom_symbols: List[sympy.Symbol]
+
+    def __post_init__(self):
+        """Pre-processing to answer queries `is_equal` and `is_derived` below.
+
+        Example: Suppose we are given:
+          source_pairs [a = b, b = c]
+          derived_equalities [d = c + 1, e = d - 1]
+        We first construct a union find with source_pairs:
+          _parents = {a: a, b: a, c: a}
+        Then we compute canonical symbolic expressions, recursively applying derived_equalities
+        until we bottom out:
+          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
+        """
+
+        # self._parents is a map from input sources to input sources where, conceptually,
+        # these are directed edges in a union-find forest
+        _parents: Dict[Source, Source] = {}
+        object.__setattr__(self, "_parents", _parents)
+        # self._defs is a map from input sources to "canonical" symbolic expressions,
+        # i.e., unary expressions with symbols that correspond to regular Dims (i.e.,
+        # not derived Dims)
+        _defs: Dict[Source, sympy.Expr] = {}
+        object.__setattr__(self, "_defs", _defs)
+
+        for source1, source2 in self.source_pairs:
+            # preprocess into a union-find forest
+            self._union(self._find(source1), self._find(source2))
+        for source, root, fn in self.derived_equalities:
+            # preprocess into a transitively-closed map
+            # NOTE(avik): we reuse the union-find forest for canonicalizing input sources
+            if isinstance(root, sympy.Symbol):
+                self._defs[self._find(source)] = fn(root)
+            else:
+                self._defs[self._find(source)] = fn(self._rewrite(root))
+
+    def _find(self, source):
+        # chase edges to find the root of this equivalence class
+        if source in self._parents:
+            return self._find(self._parents[source])
+        else:
+            return source
+
+    def _union(self, root1, root2):
+        # merge two equivalence classes by adding an edge from one root to the other
+        if root1 != root2:
+            self._parents[root1] = root2
+
+    def _rewrite(self, src):
+        # always represent the given source by the root of its equivalence class
+        src = self._find(src)
+        if src in self._defs:
+            # simply look up the definition if it exists
+            # NOTE(avik): This works because definitions are always transitively-closed;
+            # otherwise we would have to do recursive rewriting.
+            return self._defs[src]
+        else:
+            # otherwise, create a symbol representing the source
+            return sympy.Symbol(src.name())
+
+    def is_equal(self, source1, source2):
+        return (
+            # check whether source1 and source2 have the same root
+            self._find(source1) == self._find(source2) or
+            # check whether source1 is derived equal to source2
+            self.is_derived(source1, source2, lambda x: x)
+        )
+
+    def is_derived(self, src, symbol_src, fn):
+        # check whether both src and symbol_src have the same definition
+        return self._rewrite(src) == fn(self._rewrite(symbol_src))
+
+
+def _assert_symbol_context(symbolic_context):
+    assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object"
+    assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC"
+
+
+@dataclass(frozen=True)
+class SymbolicContext:
+    """
+    Data structure specifying how we should create symbols in
+    ``create_symbolic_sizes_strides_storage_offset``; e.g., should
+    they be static or dynamic.
+
+    This is an abstract base class because we are probably going to add
+    another version of this that says "use exactly these SymInts, don't
+    allocate fresh symbols."
+    """
+    pass
+
+
+@dataclass(frozen=True)
+class StatelessSymbolicContext(SymbolicContext):
+    """
+    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
+    a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``.
+    This will cause fresh symbols to be allocated.
+    """
+    dynamic_sizes: DimList[DimDynamic]
+    constraint_sizes: DimList[DimConstraint] = None
+    # If the tensor is a view, this should be populated for the base. It contains
+    # information on how to allocate symbols when recursively fakeifying the base
+    # during view fake-ification.
+    view_base_context: Optional[SymbolicContext] = None
+    # TODO: add storage offset and stride symbolic_context
+
+    def __post_init__(self):
+        if self.constraint_sizes is None:
+            object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes))
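+
+# Illustrative usage sketch (added commentary, an assumed call pattern):
+#   ctx = StatelessSymbolicContext(
+#       dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.STATIC],
+#   )
+#   # constraint_sizes defaults to [None, None] via __post_init__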
+
+
+# note [Tensor Fakification and Symbol Caching]
+#
+# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
+# The reason we do this is because there are certain classes of operations, namely,
+# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
+# state at the end of a dynamo trace is different than the fake tensor state at the beginning
+# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
+# view relationships, etc.
+#
+# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
+# transfer the memoization cache, we instead transfer the shape env. However, with this
+# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in
+# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
+# recompilations.
+#
+# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
+# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
+# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
+# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
+# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env
+# is used.
+# TODO(voz): Shape env validation
+@dataclass(frozen=True)
+class StatefulSymbolicContext(StatelessSymbolicContext):
+    """
+    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
+    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
+    will reuse a stored symbol, and a cache miss will write to this cache.
+
+    This behaves like StatelessSymbolicContext, except the cache supersedes the
+    other values - dynamic_sizes and constraint_sizes will not be read on a
+    cache hit.
+
+    It is the cache owner's responsibility to maintain the lifecycle of the cache
+    w/r/t different shape_envs, clearing, etc.
+    """
+    tensor_source: Source = None
+    # Why is this keyed on int first?
+    # That integer is actually the id of the shape_env. This cache short-circuits symbol
+    # creation, and we must store it per shape env. Now, the tracing invariant is a single
+    # shape env per tracing context, and every new frame gets a new shape_env. So where would we have
+    # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events
+    # is invoked, and creates a new shape_env. Replaying events against this new shape_env will
+    # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
+    # get recorded in var_to_val, etc.
+    # TODO(voz): consider a weakref to the shape_env here
+    shape_env_to_source_to_symbol_cache: Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None
+
+    def __post_init__(self):
+        # The None default is annoying, but required because of dataclass limitations
+        assert self.tensor_source is not None
+        if not self.shape_env_to_source_to_symbol_cache:
+            object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {})
+
+
+@dataclass(frozen=True)
+class SubclassSymbolicContext(StatefulSymbolicContext):
+    """
+    The correct symbolic context for a given inner tensor of a traceable tensor subclass
+    may differ from that of the outer symbolic context. This structure allows for this
+    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
+    """
+    inner_contexts: Dict[str, SymbolicContext] = None
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.inner_contexts is None:
+            object.__setattr__(self, 'inner_contexts', {})
+
+
+def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
+    if isinstance(val, (int, float, bool)):
+        return False
+    return val.node.is_symbolic()
+
+IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
+
+@lru_cache(256)
+def safe_expand(r):
+    if hasattr(r, 'expand'):
+        try:
+            return sympy.expand(r)
+        except RecursionError:
+            log.warning("RecursionError in sympy.expand(%s)", r)
+            return r
+    else:
+        return r
+
+def error():
+    raise AssertionError("shouldn't be hit")
+
+
+# TODO: Deduplicate this with torch/_prims_common/__init__.py
+def eval_is_non_overlapping_and_dense(sizes, strides):
+    return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))
+
+def _eval_is_non_overlapping_and_dense(sizes, strides):
+    dim = len(sizes)
+
+    # Short-circuits for tensors of rank one, which are
+    # non-overlapping and "dense" if their stride is one
+    # or it is a 0/1 element tensor
+    if dim == 1:
+        return strides[0] == 1 or sizes[0] < 2
+
+    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
+    # Sorts (length, stride) pairs by stride
+    lengths_and_strides = sorted(
+        zip(sizes, strides), key=operator.itemgetter(1)
+    )
+
+    # Unlike the C++ code, we don't move the 0/1 size dimensions to the
+    # end.  So we have to keep going for this code.
+    expected_stride = 1
+    for length, stride in lengths_and_strides:
+
+        if length == 1:
+            continue
+
+        if stride != expected_stride:
+            return False
+
+        expected_stride *= length
+
+    return True
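+
+# Illustrative note (added commentary):
+#   _eval_is_non_overlapping_and_dense([2, 3], [3, 1]) -> True   # contiguous
+#   _eval_is_non_overlapping_and_dense([2, 3], [2, 1]) -> False  # overlapping layout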
+
+
+def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
+    int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True))
+    return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()))
+
+SYMPY_INTERP = {
+    'Abs': operator.abs,
+    'Eq': operator.eq,
+    'Ne': operator.ne,
+    'Gt': operator.gt,
+    'Lt': operator.lt,
+    'Le': operator.le,
+    'Ge': operator.ge,
+    'Min': min,
+    'Max': max,
+    'Mod': operator.mod,
+    'FloorDiv': operator.floordiv,
+    'TrueDiv': operator.truediv,
+    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
+    'floor': math.floor,
+    'ceiling': math.ceil,
+    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
+    'Round': builtins.round,
+    'RoundDecimal': builtins.round,
+}
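+
+# Illustrative note (added commentary): SYMPY_INTERP maps sympy function names
+# to Python callables used when interpreting guard/assert expressions, e.g.
+#   SYMPY_INTERP['FloorDiv'](7, 2) == 3 and SYMPY_INTERP['Max'](2, 5) == 5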
+
+
+def _lru_cache(fn, maxsize=None):
+    """
+    Wrapper around lru_cache that clears when new info about shapes has been
+    updated.
+
+    Use lru_cache if the output is always the same, regardless of the
+    constraints we know now (i.e. evaluate_expr).
+
+    Use _lru_cache otherwise.
+
+    Also note that this depends on _update_version_counter being called on the
+    shape environment whenever the constraints are updated, otherwise the cache
+    will not be cleared.
+    """
+    fn_cache = lru_cache(maxsize)(fn)
+    prior_version = 0
+
+    if config.validate_shape_env_version_key:
+        prior_key = None
+
+        @functools.wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            nonlocal prior_version, prior_key
+            if prior_key is None:
+                prior_key = self._get_key()
+
+            if prior_version != self._version_counter:
+                fn_cache.cache_clear()
+                prior_version = self._version_counter
+                prior_key = self._get_key()
+            else:
+                assert prior_key == self._get_key(), \
+                    "ShapeEnv cache key changed without version being updated!"
+
+            return fn_cache(self, *args, **kwargs)
+
+    else:
+
+        @functools.wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            nonlocal prior_version
+            if prior_version != self._version_counter:
+                fn_cache.cache_clear()
+                prior_version = self._version_counter
+
+            return fn_cache(self, *args, **kwargs)
+
+    wrapper.cache_clear = fn_cache.cache_clear
+    wrapper.cache_info = fn_cache.cache_info  # type: ignore[attr-defined]
+    return wrapper
+
+
+# This is pretty similar to ShapeGuard but it also comes with a message,
+# and is exclusively used for things that MUST be true (unlike guards,
+# which can evaluate False, in which case you just choose not to use
+# a particular specialization)
+@dataclass(frozen=True)
+class RuntimeAssert:
+    expr: sympy.Expr
+    msg: str = field(repr=False)
+    stack: str = field(repr=False)
+
+
+class ShapeGuardPrinter(StrPrinter):
+    def __init__(
+        self,
+        symbol_to_source,
+        source_ref,
+        var_to_sources,
+    ):
+        super().__init__()
+        self.symbol_to_source = symbol_to_source
+        self.source_ref = source_ref
+        self.var_to_sources = var_to_sources
+
+    def _print_Not(self, expr):
+        return 'not %s' % (self.parenthesize(expr.args[0], PRECEDENCE["Not"]))
+
+    def _print_And(self, expr):
+        return self.stringify(expr.args, " and ", PRECEDENCE["And"])
+
+    def _print_Or(self, expr):
+        return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
+
+    def _print_Symbol(self, expr) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+
+        def repr_symbol_to_source():
+            return repr({
+                symbol: [s.name() for s in sources]
+                for symbol, sources in self.symbol_to_source.items()
+            })
+
+        assert self.symbol_to_source.get(expr), (
+            f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
+            f"not in {repr_symbol_to_source()}.  If this assert is failing, it could be "
+            "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
+        )
+        return self.source_ref(self.symbol_to_source[expr][0])
+
+
+class LoggingShapeGuardPrinter(ShapeGuardPrinter):
+    def __init__(self, var_to_sources):
+        super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
+
+
+class DynamicDimConstraintPrinter(StrPrinter):
+    """
+    Printer for dynamic dim constraints.
+    - Instead of t.size()[d] it prints dynamic_dim(t, d)
+    - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc.
+
+    We use this to suggest code for specifying dynamic dim constraints.
+    """
+    def __init__(self, symbol_to_source, source_name_to_debug_name):
+        super().__init__()
+        self.symbol_to_source = symbol_to_source
+        self.source_name_to_debug_name = source_name_to_debug_name
+
+    def print_source(self, source) -> str:
+        if self.source_name_to_debug_name:
+            return source.name()
+        return f"dynamic_dim({source.base.name()}, {source.idx})"
+
+    def _print_Symbol(self, expr) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+        assert self.symbol_to_source.get(expr), (
+            f"Unknown symbol {expr} created by constraints solver"
+        )
+        return self.print_source(self.symbol_to_source[expr][0])
+
+    def _print_Relational(self, expr):
+        return '{} {} {}'.format(
+            self.parenthesize(expr.lhs, precedence(expr)),
+            expr.rel_op,
+            self.parenthesize(expr.rhs, precedence(expr))
+        )
+
+
+class DimConstraints:
+    """
+    Custom solver for a system of constraints on symbolic dimensions.
+    Solutions are "static" values or simplified "dynamic" constraints.
+    """
+
+    def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name):
+        # We try to solve systems of inequalities with 1 free variable.
+        self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
+        # Among them, we prioritize solving for a free variable that has equalities.
+        # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
+        # and removing a symbol from the former => removing it from the latter.
+        self._symbols_with_equalities: Set[sympy.Symbol] = set()
+        # A solution of a free variable with equalities becomes a substitution.
+        # We use these substitutions to simplify other constraints.
+        # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
+        self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}
+
+        # In general, constraints may have // and % operations.
+        # Of course, // can be expressed in terms of / and %.
+        # Our inequality solver can handle / but not %. So we need to transform them away.
+        # We do so by using the values of variables as hints to evaluate %.
+        # For soundness we record additional congruence guards and solve them separately.
+        self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val
+        self._congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
+
+        # We do not try to (directly) solve inequalities with > 1 free variables.
+        # NOTE: free variables in these inequalities cannot also be in _substitutions.
+        self._multivariate_inequalities: Set[sympy.Expr] = set()
+
+        # We park external equalities between free variables here.
+        self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []
+
+        # Solutions come in two forms:
+        # - (static) specializations
+        # - (dynamic) inequalities / congruences
+        self._static_results: Set[str] = set()
+        self._dynamic_results: Set[str] = set()
+
+        # printer for solutions
+        self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name)
+
+        # inconsistencies found on substituting with concrete values / static solutions
+        self._inconsistencies: List[str] = []
+
+        # symbols that are marked dynamic
+        self._marked_dynamic = marked_dynamic
+
+    def rewrite_with_congruences(self, s, expr):
+        """
+        Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
+        This leaves rational operators (in particular of the form b / d) that our inequality solver can handle.
+        We solve the added congruences separately (using our congruence solver, see below).
+        """
+        def mod_handler(*args):
+            # Suppose that we have an expression of the form b % d with free variable s.
+            # Using the value of s as a "hint," we can evaluate b % d to a value k.
+            # Then we can rewrite b % d to k while adding the guard b % d == k.
+
+            # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF
+            # the original expression always evaluates to a constant value (i.e., it does not vary with s).
+            # In other words,
+            # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with
+            #   the original expression;
+            # - while it may be possible to find solutions of s with the original expression that are not
+            #   solutions with the rewritten expression, in that case the original expression cannot evaluate
+            #   to the same value for all solutions of s.
+            #
+            # Should we be worried about this incompleteness? No, because of the following reasons:
+            # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech
+            #    (i.e., "don't let perfect be the enemy of the good").
+            # 2. We already have a tradition of using hints to add guards in the compiler for making progress.
+            # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards
+            #    we generate (or simplify to) seem to be of the form b % d == k where k is a constant.
+            #
+            # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2.
+            # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
+            # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
+            base, divisor = args
+            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
+            mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val)
+            congruence = (base - mod_reduced) % divisor
+            if congruence != 0:
+                self._congruences[s].add(congruence)
+            return mod_reduced
+
+        def floor_div_handler(*args):
+            # Suppose that we have an expression of the form b // d with free variable s.
+            # Using the value of s, we can evaluate b % d to a value k.
+            # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k.
+
+            # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
+            # and eliminating b % d as above.
+            base, divisor = args
+            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
+            mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val)
+            congruence = (base - mod_reduced) % divisor
+            if congruence != 0:
+                self._congruences[s].add(congruence)
+            return (base - mod_reduced) / divisor
+
+        if expr.has(Mod):
+            expr = expr.replace(Mod, mod_handler)
+        if expr.has(FloorDiv):
+            expr = expr.replace(FloorDiv, floor_div_handler)
+        return expr
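+
+    # Illustrative sketch (added commentary; s is a hypothetical free symbol with
+    # hint value 6): rewriting s // 2 produces s/2 and records the congruence
+    # Mod(s, 2) == 0, which is then handled by the congruence solver.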
+
+    def add(self, expr) -> bool:
+        """Add an expression to the set of constraints.
+
+        Return whether the expression is a trivial constraint (i.e., an obvious tautology).
+        """
+        if expr == sympy.true:
+            return True
+        orig_expr = expr
+        orig_reduced = orig_expr.subs(self._var_to_val)
+        # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
+        # It is possible that `expr` will fail the consistency check because of
+        # precision errors. Specifically, on substituting its free symbols with
+        # their concrete values, we might end up comparing floats. Until we have
+        # a fix for this issue, we delay raising such failures. See solve().
+        if orig_reduced == sympy.false:
+            self._inconsistencies.append(f"{orig_expr} is inconsistent!")
+        if isinstance(expr, sympy.Ne):
+            # we're not going to do anything useful with these, so drop them
+            return False
+        free_symbols = expr.free_symbols
+        assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
+        if len(free_symbols) > 1:
+            # multivariate: record and move on
+            self._multivariate_inequalities.add(expr)
+        else:
+            # univariate: can solve these immediately
+            s = next(iter(free_symbols))
+            # eliminate // and % (see documentation of `rewrite_with_congruences` above)
+            old_n_congruences = len(self._congruences[s])
+            expr = self.rewrite_with_congruences(s, expr)
+            new_n_congruences = len(self._congruences[s])
+            if expr == sympy.true:
+                return old_n_congruences == new_n_congruences
+            reduced = expr.subs(self._var_to_val)
+            if reduced == sympy.false:
+                self._inconsistencies.append(
+                    f"{expr}, obtained by rewriting {orig_expr} with congruences, "
+                    "is inconsistent!"
+                )
+            if isinstance(expr, sympy.Eq):
+                # special status for symbols that have equalities (see `solve` below)
+                self._symbols_with_equalities.add(s)
+            self._univariate_inequalities[s].add(expr)
+        return False
+
+    def add_equality(self, source, expr):
+        """Add an equality constraint"""
+        if expr.is_number:
+            # specialization, right here
+            self._static_results.add(f"{source.name()} == {expr}")
+        else:
+            # these will resolve to either specializations or dynamic equality constraints
+            self._symbolic_equivalences.append((source, expr))
+
+    def _reduce_congruences(self):
+        reduced_congruences = {}
+        for s, congruences in self._congruences.items():
+            remainder_modulus_pairs = []
+            congruences_to_check = set()
+            for congruence in congruences:
+                base, divisor = congruence.args
+                # We are given a congruence of the form base % divisor == 0 with a free variable s. So:
+                # - we transform this into an equation of the form base = divisor * tmp;
+                # - we solve this equation for s to get a linear solution with free variable tmp.
+                tmp = sympy.Symbol("tmp", integer=True)
+                symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
+                # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
+                # for how to interpret the results.
+                if s == symbol:
+                    # This means the solution is of the form s = modulus*tmp + remainder.
+                    modulus, remainder = sympy.polys.polytools.div(solution, tmp)
+                    if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer):
+                        # Make sure 0 <= remainder <= modulus.
+                        remainder = remainder % modulus
+                        remainder_modulus_pairs.append((remainder, modulus))
+                        continue
+                # This means that we did not get a unique solution to the equation.
+                # No problem, we will check it.
+                congruences_to_check.add(congruence)
+            # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
+            # The solution will be a congruence of the form s = r mod m.
+            # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
+            if remainder_modulus_pairs:
+                remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs)
+                reduced_congruences[s] = {(s - remainder) % modulus}
+                substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder}
+                reduced_congruences[s].update(
+                    congruence for congruence in congruences_to_check
+                    if not sympy.checksol(congruence, substitution)
+                )
+            else:
+                reduced_congruences[s] = congruences_to_check
+
+        return reduced_congruences
+
+    def _raise_inconsistencies(self):
+        if self._inconsistencies:
+            msg = "\n".join(self._inconsistencies)
+            self._inconsistencies.clear()
+            raise ValueError(f"The following inconsistencies were found:\n{msg}")
+
+    def _force_specialization(self, s):
+        val = self._var_to_val[s]
+        self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
+        self._substitutions[s] = val
+
+    def _specialize_divisor_symbols(self):
+        for expr in self._multivariate_inequalities:
+            for atom in expr.atoms(FloorDiv, Mod):
+                _, divisor = atom.args
+                for s in divisor.free_symbols:
+                    self._force_specialization(s)
+
+        multivariate_inequalities = self._multivariate_inequalities
+        self._multivariate_inequalities = set()
+        for expr in multivariate_inequalities:
+            self.add(expr.subs(self._substitutions))
+        self._raise_inconsistencies()
+        self._univariate_inequalities = {
+            s: exprs
+            for s, exprs in self._univariate_inequalities.items()
+            if s not in self._substitutions
+        }
+        self._congruences = {
+            s: congruences
+            for s, congruences in self._congruences.items()
+            if s not in self._substitutions
+        }
+
+    def solve(self, disable_congruences=True, disable_equivalences=True):
+        """Solve the system of constraint equations to find simplified constraints
+        """
+        self._raise_inconsistencies()
+        # as long as there are symbols with equalities, solve for them
+        # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
+        while self._symbols_with_equalities:
+            s = self._symbols_with_equalities.pop()
+            exprs = self._univariate_inequalities.pop(s)
+            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
+            if isinstance(solution, sympy.And):
+                solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution)
+            assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
+            symbol, val = solution.args
+            assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
+            # because this is univariate, the solution is a specialization
+            self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
+            # add this as a substitution to simplify other constraints
+            self._substitutions[s] = val
+
+            # simplify multivariate inequalities: some of them will now become univariate!
+            multivariate_inequalities = self._multivariate_inequalities
+            self._multivariate_inequalities = set()
+            for expr in multivariate_inequalities:
+                self.add(expr.subs(s, self._substitutions[s]))
+            self._raise_inconsistencies()
+
+        self._specialize_divisor_symbols()
+
+        # solve linear congruences
+        # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
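+        # Illustration (added, not part of the original logic): for a supported congruence such
+        # as (s + 1) % 3 == 0, base = s + 1 and divisor = 3, so try_solve(Eq(s + 1, 3*tmp), s)
+        # solves to s = 3*tmp - 1, which is emitted below as the dynamic constraint s == 3*tmp - 1,
+        # where tmp is a fresh ConstantSource-backed symbol named after the debug name of s.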
+        reduced_congruences = self._reduce_congruences()
+        for s, congruences in reduced_congruences.items():
+            for congruence in congruences:
+                # any congruence that cannot be checked becomes a dynamic constraint as well
+                if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
+                    if self._is_supported_congruence(congruence):
+                        base, divisor = congruence.args
+                        tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
+                        tmp = sympy.Symbol(tmp_name, integer=True)
+                        from torch._dynamo.source import ConstantSource
+                        self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
+                        r = try_solve(sympy.Eq(base, divisor * tmp), s)
+                        self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
+                    elif disable_congruences:
+                        self._force_specialization(s)
+                        self._univariate_inequalities.pop(s, None)
+
+        # remaining symbols have only pure inequalities (no equalities)
+        for s, exprs in self._univariate_inequalities.items():
+            try:
+                solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
+                # because this is univariate, the solution is a dynamic (range) constraint
+                if isinstance(solution, sympy.Or):
+                    solution = next(iter(arg for arg in solution.args if arg.subs(self._var_to_val)))
+                if isinstance(solution, sympy.And):
+                    for arg in solution.args:
+                        self._dynamic_results.add(self._dcp.doprint(arg))
+                else:
+                    self._dynamic_results.add(self._dcp.doprint(solution))
+            except (NotImplementedError, AssertionError) as e:
+                log.warning("Failed to reduce inequalities: %s", e)
+                for expr in exprs:
+                    self._dynamic_results.add(self._dcp.doprint(expr))
+
+        # simplify symbolic equivalences: some of them will now become specializations!
+        symbolic_equivalences = self._symbolic_equivalences
+        self._symbolic_equivalences = []
+        for source, expr in symbolic_equivalences:
+            if disable_equivalences and not self._is_supported_equivalence(expr):
+                for s in expr.free_symbols:
+                    self._force_specialization(s)
+                    sexpr = self._dcp._print_Symbol(s)
+                    self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r}
+            self.add_equality(source, expr.subs(self._substitutions))
+
+        # remaining symbolic equivalences become dynamic equality constraints
+        for source, expr in self._symbolic_equivalences:
+            self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")
+
+    @classmethod
+    def _is_supported_equivalence(cls, expr):
+        # Currently supported Dim ops are linear expressions with integer coefficients.
+        # So check that expr only contains +, *, ints, and a single occurrence of a symbol.
+        # (See also documentation of dynamic_shapes._DerivedDim.)
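+        # Illustrative examples: 2*dim + 1 and 3*dim are supported; dim + other_dim and dim*dim
+        # are not, since they involve more than one symbol or a non-linear term.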
+        if isinstance(expr, (sympy.Add, sympy.Mul)):
+            lhs, rhs = expr.args
+            return (
+                (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or
+                (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs))
+            )
+        return isinstance(expr, sympy.Symbol)
+
+    @classmethod
+    def _is_supported_congruence(cls, congruence):
+        base, divisor = congruence.args
+        # Congruences that can be currently expressed with supported Dim ops are
+        # of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
+        # This allows us to derive x as b*y - a for some Dim y.
+        # (See also documentation of dynamic_shapes._DerivedDim.)
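+        # Illustrative examples: (dim + 3) % 4 == 0 is supported (giving dim = 4*y - 3 for some
+        # Dim y), whereas (2*dim) % 3 == 0 is not, since its base is not of the form x + a.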
+        if isinstance(base, sympy.Add):
+            lhs, rhs = base.args
+            cond = (
+                (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or
+                (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
+            )
+        else:
+            cond = isinstance(base, sympy.Symbol)
+        cond = cond and isinstance(divisor, sympy.Integer)
+        return cond
+
+    def forced_specializations(self):
+        """Returns a dictionary of the names of symbols to their specialized value
+        """
+        def debug_name(src):
+            name = src.name()
+            if self._dcp.source_name_to_debug_name:
+                return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
+            else:
+                return name
+
+        return {
+            debug_name(self._dcp.symbol_to_source[s][0]): val
+            for s, val in self._substitutions.items()
+            if s in self._marked_dynamic
+        }
+
+    def remove_redundant_dynamic_results(self):
+        """Remove constraints of the form 2 <= dynamic_dim(...) as 2 is the default
+        lower bound.
+        """
+        candidates_for_removal = []
+        dynamic_results = set()
+        for dc in self._dynamic_results:
+            # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
+            # There is no change in behavior since 2 is the default lower bound.
+            dc_ = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
+            if dc != dc_:
+                candidates_for_removal.append(dc_)
+            else:
+                dynamic_results.add(dc_)
+        for dc in candidates_for_removal:
+            # remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also
+            # appears as part of another constraint
+            found = False
+            for other_dc in dynamic_results:
+                if dc in other_dc:
+                    found = True
+            if not found:
+                dynamic_results.add(dc)
+        self._dynamic_results = dynamic_results
+
+    def prettify_results(
+        self,
+        original_signature: inspect.Signature,
+        constraint_violation_error=None,
+        forced_specializations=None,
+    ):
+        """Format a message for constraint violation erros"""
+        if self._dcp.source_name_to_debug_name:
+            def transform(s):
+                for k, v in self._dcp.source_name_to_debug_name.items():
+                    s = s.replace(k, v)
+                return s
+
+            results = defaultdict(dict)
+
+            def flip(op):
+                if op == "<=":
+                    return ">="
+                if op == ">=":
+                    return "<="
+                if op == "<":
+                    return ">"
+                if op == ">":
+                    return "<"
+                assert op == "=="
+                return op
+
+            def relation_with_digit(expr, op, digit):
+                if op == "<=":
+                    results[expr]["max"] = digit
+                elif op == "<":
+                    results[expr]["max"] = digit - 1
+                elif op == ">=":
+                    results[expr]["min"] = digit
+                elif op == ">":
+                    results[expr]["min"] = digit + 1
+                else:
+                    assert op == "=="
+                    results[expr]["eq"] = digit
+
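+            # Illustration (not in the original code): a transformed constraint "dim <= 128"
+            # records results["dim"]["max"] = 128, while "2 < dim" is flipped to "dim > 2"
+            # and records results["dim"]["min"] = 3.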
+            for s in self._static_results.union(self._dynamic_results):
+                t = transform(s)
+                if t == s:
+                    continue
+                left, op, right = re.split(r"( == | <= | >= | < | > )", t)
+                op = op.strip()
+                if op == "==" and left == right:
+                    continue
+                if right.isdigit():
+                    relation_with_digit(left, op, int(right))
+                elif left.isdigit():
+                    relation_with_digit(right, flip(op), int(left))
+                else:
+                    assert op == "=="
+                    results[left]["eq"] = sympy.sympify(right)
+
+            buf = ""
+            debug_names = set()
+            if forced_specializations:
+                debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys())
+                buf += (
+                    f"Specializations unexpectedly required ({', '.join(debug_names)})! "
+                    "For more information, run with TORCH_LOGS=\"+dynamic\".\n"
+                )
+                for s, val in forced_specializations.items():
+                    buf += f"  - {s} must be specialized to {val} because the guards generated for it are too complex.\n"
+
+            dims = []
+            others = []
+            match = None
+            if constraint_violation_error:
+                match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0])
+            if match is not None:
+                debug_names.update(match.expand(r'\1').split(', '))
+
+            for k, c in sorted(results.items()):
+                # if k not in debug_names:
+                #     continue
+                if "eq" in c:
+                    other = c["eq"]
+                    if isinstance(other, int):
+                        others.append(f"{k} = None  # {other}")
+                    elif self._is_supported_equivalence(other):
+                        s = next(iter(other.free_symbols))
+                        if s not in results:
+                            modulus, remainder = sympy.polys.polytools.div(other, s)
+                            c_min = c.get("min", 2)
+                            min_ = math.ceil((c_min - remainder) / modulus)
+                            c_max = c.get("max", sys.maxsize - 1)
+                            max_ = math.floor((c_max - remainder) / modulus)
+                            dims.append(f"{s} = Dim('{s}', min={min_}, max={max_})  # {c_min} <= {other} <= {c_max}")
+                        others.append(f"{k} = {other}")
+                else:
+                    min_ = c.get("min", None)
+                    if min_ == 2:
+                        min_ = None
+                    max_ = c.get("max", None)
+                    if min_ is not None and max_ is not None:
+                        dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
+                    elif min_ is not None:
+                        dims.append(f"{k} = Dim('{k}', min={min_})")
+                    elif max_ is not None:
+                        dims.append(f"{k} = Dim('{k}', max={max_})")
+                    else:
+                        dims.append(f"{k} = Dim('{k}')")
+
+            buf += "\nSuggested fixes:\n  "
+            buf += "\n  ".join(dims + others)
+
+            return buf
+
+        # Note: Model inputs are wrapped as LocalSource in dynamo.
+        # LocalSource.name() wraps the name with L[""]. We use regular
+        # expression to do the replacement to avoid traversing up
+        # the source hierarchy manually.
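+        # Example (illustrative): "L['x'].size()[0] <= 8" is rewritten by the helper below to
+        # ("x", "x.size()[0] <= 8").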
+        def extract_and_rewrite_local(dc):
+            match = re.search(r"L\['(.+?)'\]", dc)
+            if match is None:
+                return
+            arg = match.expand(r'\1')
+            dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
+            return arg, dc
+
+        def group(results, args_index):
+            groups = defaultdict(list)
+            for dc in results:
+                local = extract_and_rewrite_local(dc)
+                if local is None:
+                    # This can happen, e.g., with `assume_constant_result`.
+                    # In that case, we drop the constraint.
+                    # TODO(avik) Maybe we should generate an assertion here?
+                    continue
+                arg, dc = local
+                if arg in args_index:
+                    groups[args_index[arg]].append(dc)
+                else:
+                    # This can happen, e.g., with decorators that change the signature.
+                    # In that case, we drop the constraint. Seems hard to do better. :/
+                    # TODO(avik) Maybe warn that `arg` is not in `signature`?
+                    continue
+            sorted_groups = []
+            for idx, dcs in sorted(groups.items()):
+                _, arg = idx
+                sorted_groups.append((arg, sorted(dcs)))
+            return sorted_groups
+
+        signature = original_signature.replace(return_annotation=inspect.Signature.empty)
+        args_index = {}
+        for i, arg in enumerate(signature.parameters.keys()):
+            args_index[arg] = (i, arg)
+
+        def print_results(grouped, indent, result_fn):
+            nonlocal buf
+
+            space = False
+            for arg, results in grouped:
+                if space:
+                    buf += "\n"
+                else:
+                    space = True
+                buf += f"\n{indent}# {arg}:"
+                for result in results:
+                    buf += f"\n{indent}{result_fn(result)}"
+
+        buf = ""
+        if forced_specializations:
+            buf += (
+                "Some dynamic dimensions need to be specialized because "
+                "the constraints inferred for them are too complex to specify.\n"
+            )
+            for s, val in forced_specializations.items():
+                buf += f"  - {s}, which was marked dynamic, must be specialized to {val}.\n"
+        indent = 4 * " "
+        if self._static_results:
+            grouped_static_results = group(self._static_results, args_index)
+            buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
+            buf += f"\n```\ndef specializations{str(signature)}:"
+            print_results(
+                grouped_static_results,
+                indent,
+                lambda result: f"assert {result}",
+            )
+            buf += "\n```\n"
+        if self._dynamic_results:
+            grouped_dynamic_results = group(self._dynamic_results, args_index)
+            buf += "\nThe following dimensions CAN be dynamic."
+            buf += "\nPlease use the following code to specify the constraints they must satisfy:"
+            buf += f"\n```\ndef specify_constraints{str(signature)}:"
+            buf += f"\n{indent}return ["
+            print_results(
+                grouped_dynamic_results,
+                indent * 2,
+                lambda result: f"{result},",
+            )
+            buf += f"\n{indent}]\n```\n"
+        return buf
+
+
+TLS = threading.local()
+
+
+class ShapeEnv:
+    # This is a wrapper over the actual __init__ function.
+    #
+    # Where to add a new constructor parameter to ShapeEnv?
+    # =====================================================
+    # This __init__ function should be used only for parameters related to event recording.
+    # These are parameters that we don't wish to pass down the road to new ShapeEnv instances
+    # created from replaying events.
+    #
+    # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
+    # recording, do so in the _init function.
+    def __init__(
+        self, *,
+        should_record_events: Optional[bool] = None,
+        tracked_fakes: Optional[List[Any]] = None,
+        **kwargs
+    ) -> None:
+        self._init(**kwargs)
+
+        # Disable event recording when replaying.
+        kwargs["should_record_events"] = False
+
+        from torch.fx.experimental.validator import translation_validation_enabled
+        self._translation_validation_enabled = translation_validation_enabled()
+
+        # If not specified, enable event recording if both:
+        #   - Translation validation is on
+        #   - Translation validation bisection is not disabled
+        self.should_record_events = (
+            should_record_events
+            if should_record_events is not None
+            else (
+                self._translation_validation_enabled
+                and not config.translation_validation_no_bisect
+            )
+        )
+
+        # Enable event recording check if both:
+        #   - It should record events
+        #   - The recording check is enabled
+        self.check_recorded_events = (
+            self.should_record_events and config.check_shape_env_recorded_events
+        )
+
+        # This will make sure we only record the top-level function call.
+        self.is_recording = not self.should_record_events
+        # Keep track of the list of tracked fakes.
+        self.tracked_fakes = tracked_fakes
+        # List of events for reconstructing ShapeEnv at arbitrary points in time.
+        self.events: List[ShapeEnvEvent] = (
+            [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
+        )
+
+    # Pro-tip: if you add a new field to ShapeEnv, this affects some accept
+    # tests.  Accept their output with:
+    #
+    #   EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
+    #
+    def _init(
+        self, *,
+        allow_scalar_outputs=True,
+        allow_dynamic_output_shape_ops=True,
+        # NB: These are legacy configuration options that help us make good choices
+        # when the constraint/dynamic dims are not explicitly passed to us.
+        # Ideally we will fix all call sites to be explicit and not have
+        # implicit choices, but this apparently was pretty involved.
+        assume_static_by_default=False,
+        # Note - On 0/1 specialization
+        #
+        # The following options affect decisions we make about eager
+        # specialization.  Disabling them will increase trace time (as we do
+        # more symbolic reasoning) and can also harm the quality of generated
+        # code (because inductor may not be able to specialize for bounds
+        # being equal--although if we later respecialize because of a guard,
+        # your code may be just as good as it was before.)
+        #
+        # When True, eagerly specialize input sizes which have 0/1.
+        specialize_zero_one=True,
+        # When True, assume input sizes which have the same size are
+        # symbolically equal.
+        duck_shape=True,
+        # For debugging
+        co_fields=None,
+        # XXX Add any new settings that could affect FakeTensor evaluation
+        # to: torch._subclasses.fake_tensor._ShapeEnvSettings
+    ):
+        # Not directly used by ShapeEnv; indirectly used by FakeTensor
+        self.allow_scalar_outputs = allow_scalar_outputs
+        self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
+        self.guards: List[ShapeGuard] = []
+        # Maps symbolic ints to their original concrete values
+        # Currently populated from tensors
+        self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
+        # Maps symbolic ints to their min/max range.  These ranges
+        # are conservative: the int MUST fall in the range, but the
+        # range may contain ints which may not actually appear in
+        # practice
+        self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
+        self.source_name_to_debug_name: Dict[str, str] = {}
+        self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
+        self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
+        # Maps from sympy ints to expressions representing them
+        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
+        self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
+        # Set holds a % b expressions that evaluate to 0.
+        self.divisible: Set[sympy.Expr] = set()
+        # Set that holds "size-like" symbols.  When we perform
+        # "size-oblivious" tests, these can be assumed to be >= 2.
+        self.size_like: Set[sympy.Symbol] = set()
+        # Duck-shaping says that if two input tensors have the same size,
+        # they get assigned the same symbolic variable
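+        # (Illustration: with duck shaping, a tensor of shape (4, 4) may be assigned the sizes
+        # (s0, s0) rather than two distinct symbols, since both dimensions have the same value.)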
+        self.val_to_var: Dict[int, sympy.Expr] = {}
+        if specialize_zero_one:
+            self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}
+        self.unbacked_symfloat_counter = itertools.count()
+        self.unbacked_symint_counter = itertools.count()
+        # Similar to guards, but these MUST evaluate to true and can
+        # only be evaluated at runtime midway through (i.e., they always
+        # involve unbacked symints)
+        #
+        # For efficiency reasons, we index in the following way.  Suppose you have
+        # a runtime assert i0 + i1 <= s1.  We pick the most recently allocated
+        # symbol in the source expression and add the assert to the list for
+        # that symbol e.g., {i1: [i0 + i1 <= s1]}.
+        #
+        # We access the runtime asserts in two situations:
+        #
+        #   - When we are guarding on an expression, we will attempt to
+        #     statically evaluate it, in case the unbacked SymInts can
+        #     simplify away.  If we have a runtime assert, we may be able
+        #     to discharge the guard entirely.  We only need to attempt
+        #     runtime asserts that mention freevars of the expression in
+        #     question.
+        #
+        #   - When we are performing codegen (in Inductor for eager, or
+        #     when finalizing the export FX graph), we need to know what
+        #     extra runtime asserts to insert.  Whenever an unbacked
+        #     SymInt comes into scope, all runtime asserts involving it
+        #     become eligible for insertion (so long as all of their other
+        #     free unbacked symbols are also in scope).  We technically
+        #     can handle any choice of key by kicking inexpressible asserts
+        #     to the next unbacked symbol to wait on, but if we choose the
+        #     latest key, an assert will only show up at the moment when
+        #     we can actually codegen it.
+        self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {}
+        # This exists so we can efficiently invalidate the cache (it's used as
+        # part of the cache key); otherwise we'd have to iterate through
+        # deferred_runtime_asserts to compute its length
+        self.num_deferred_runtime_asserts = 0
+        self.assume_static_by_default = assume_static_by_default
+        self.specialize_zero_one = specialize_zero_one
+        self.duck_shape = duck_shape
+        self.log = log
+        self.log.debug("create_env")
+        self.frozen = False
+        self.dim_constraints: Optional[DimConstraints] = None
+        self.counter = collections.Counter()
+        # Mapping from sympy.Symbol to the number of guards which mention this
+        # symbol
+        self.symbol_guard_counter = collections.Counter()
+        # A selection of important fields on co_field; solely used for
+        # signpost_event
+        self.co_fields = co_fields if co_fields else {}
+
+        # Version counter used to invalidate cached values
+        self._prev_cache_key = self._get_key()
+        self._version_counter = 0
+
+        # Cache for FX nodes.
+    # Maps an already built node to a tuple of:
+        #   1. node's target
+        #   2. list of arguments
+        # This drastically reduces the size of the FX graph, avoiding
+        # duplicated nodes.
+        self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
+        self.source_to_symbol: Dict[str, sympy.Symbol] = {}
+
+        from torch.fx.experimental.validator import translation_validation_enabled
+        self._translation_validation_enabled = translation_validation_enabled()
+
+        if self._translation_validation_enabled:
+            from torch.fx.experimental.validator import TranslationValidator
+
+            self.validator = TranslationValidator()
+            self.graph = torch.fx.Graph()
+            # Create an output graph and start inserting before that.
+            # This is needed when 'deepcopy'-ing this object.
+            self.graph.inserting_before(self.graph.output(None))
+
+            # Mapping of each node name to the node itself.
+            #
+            # This is useful for matching an FX node from a recorded ShapeEnv.graph
+            # to the FX node of the ShapeEnv we are running the event on.
+            #
+            # Whenever you add a node to self.graph, you must add a mapping to this
+            # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
+            # not be valid.
+            self.name_to_node: Dict[str, torch.fx.Node] = {}
+
+    def check_equal(self, other: "ShapeEnv") -> None:
+        """Compare another ShapeEnv for equivalence
+        """
+        # ShapeEnv fields that are not relevant for the outcome of
+        # ShapeEnv.produce_guards call:
+        #   - Debugging variables
+        #   - Translation validation related variables
+        #   - Events recording related variables
+        non_state_variable_names = (
+            "counter",
+            "log",
+            "var_to_stack",
+            "fx_node_cache",
+            "graph",
+            "validator",
+            "check_recorded_events",
+            "should_record_events",
+            "is_recording",
+            "tracked_fakes",
+            "events",
+            "source_name_to_debug_name",
+            "_prev_cache_key",
+            "_version_counter",
+        )
+
+        # Mapping of the value of each to-be-compared field into the values that
+        # should actually be compared.
+        #
+        # You should modify this if, for example, a field holds both state and
+        # debugging information, e.g. ShapeGuard holds the actual guard (sympy.Expr)
+        # and the stack when it was added to the set of guards. In order to compare
+        # it, we throw away the stack information.
+        def map_value(key: str, value: Any) -> Any:
+            if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"):
+                from copy import copy
+
+                # For itertools.count(), we compare the next integer returned
+                # by the count iterators. Note that we need to copy the iterator
+                # first. Otherwise we are mutating the object.
+                return next(copy(value))
+            elif key == "guards":
+                # Transform the list of ShapeGuard into a list of expressions.
+                return [g.expr for g in value]
+            elif key == "deferred_runtime_asserts":
+                # Transform the list of RuntimeAsserts into a list of expressions.
+                return {s: [ra.expr for ra in ras] for s, ras in value.items()}
+            elif key == "name_to_node":
+                # Just compare that the sets of keys are the same.
+                return set(value.keys())
+            elif key == "symbol_guard_counter":
+                # Skip this for comparisons
+                return None
+            return value
+
+        shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
+
+    def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
+        if self.tracked_fakes is None:
+            return None
+
+        from torch._dynamo.variables.builder import TrackedFake
+
+        def maybe_transform_fake(fake: TrackedFake):
+            inner_fake = fake.fake \
+                if isinstance(fake.fake, torch.SymInt) \
+                else FakeTensorMeta.from_fake(fake.fake)
+            # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
+            # FakeTensorMeta for two reasons:
+            #   1. this is all the information we need when recording ShapeEnvEvents.
+            #   2. it works even if each TrackedFake changes its metadata.
+            return TrackedFake(inner_fake, fake.source, fake.symbolic_context)  # type: ignore[arg-type]
+
+        return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
+
+    def _last_event_index(self) -> int:
+        return len(self.events) - 1
+
+    @contextmanager
+    def _recording(self):
+        self.is_recording = True
+        try:
+            yield
+        finally:
+            self.is_recording = False
+
+    @record_shapeenv_event()
+    def freeze(self):
+        """Freeze this ShapeEnv to stop accumulating guards
+
+        A frozen ShapeEnv will ignore any further guards generated on it,
+        only emitting a warning; this may lead to accuracy problems.
+        """
+        self.frozen = True
+
+    def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
+        if not self._translation_validation_enabled:
+            return None
+        srcname = source.name()
+        if srcname not in self.source_to_symbol:
+            self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
+        return self.source_to_symbol[srcname]
+
+    def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_var(symbol, type)
+
+    def _add_target_expr(self, expr) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_target_expr(expr)
+
+    def _add_assertion(self, expr) -> None:
+        if self._translation_validation_enabled:
+            self.validator.add_assertion(expr)
+
+    def _check_translation_validate(self) -> None:
+        if self._translation_validation_enabled:
+            self.validator.validate()
+
+    @record_shapeenv_event()
+    def _create_fx_call_function(
+            self,
+            op: Callable,
+            args: Tuple,
+    ) -> Tuple[Optional[torch.fx.Node], bool]:
+        # Cache this tuple in order to avoid duplicated nodes.
+        node_key = (op, args)
+        # Flags whether the returned node was cached or not.
+        fresh = False
+
+        if self._translation_validation_enabled and node_key not in self.fx_node_cache:
+            from torch.fx.experimental.validator import z3op
+
+            # Presence of None in the arguments implies that we should ignore this operation.
+            if any(a is None for a in args):
+                # We check if we are not mixing SymNode that should not be ignored
+                # (fx_node is not None) with those that should (fx_node is None).
+                assert all(not isinstance(a, torch.fx.Node) for a in args)
+                return None, fresh
+
+            fresh = True
+            lifted_op = z3op(op, self.validator)
+
+            # If translation validation is enabled, all arguments must have their
+            # own FX node.
+            assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
+            node = self.fx_node_cache[node_key] = self.graph.call_function(lifted_op, args)
+            self.name_to_node[node.name] = node
+
+        return self.fx_node_cache.get(node_key, None), fresh
+
+    def _create_fx_placeholder_and_z3var(
+            self,
+            symbol: sympy.Symbol,
+            type: Type,
+    ) -> Optional[torch.fx.Node]:
+        if not self._translation_validation_enabled:
+            return None
+
+        node_key = (self.graph.placeholder, (symbol,))
+
+        # Check whether we have already added this symbol.
+        # If so, skip the placeholder creation, as creating
+        # it again would generate invalid Python code.
+        if node_key not in self.fx_node_cache:
+            # Add a Z3 variable according to 'type'.
+            self._add_z3var(symbol, type)
+            # Create the FX placeholder out of a mangled name.
+            mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name))
+            node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name)
+            self.name_to_node[node.name] = node
+            # Attach the 'symbol' to the placeholder so that we can retrieve
+            # the Z3 variable later.
+            node.meta["symbol"] = symbol
+
+        return self.fx_node_cache[node_key]
+
+    def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
+        if self._translation_validation_enabled and node is not None:
+            self.name_to_node.pop(node.name)
+            self.graph.erase_node(node)
+
+    def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
+        from torch._dynamo.utils import get_current_node
+
+        if self.should_record_events:
+            node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
+            node.meta[CURRENT_NODE_KEY] = get_current_node()
+
+    def _suppress_guards_tls(self):
+        return getattr(TLS, "suppress_guards", False)
+
+    @record_shapeenv_event()
+    def _suppress_guards_enter(self):
+        TLS.suppress_guards = True
+
+    @record_shapeenv_event()
+    def _suppress_guards_exit(self):
+        TLS.suppress_guards = False
+
+    @contextmanager
+    def suppress_guards(self):
+        """Context manager to ignore all guards generated inside"""
+        self._suppress_guards_enter()
+        try:
+            yield
+        finally:
+            self._suppress_guards_exit()
+
+    def _get_key(self):
+        """
+        Defines the current "state" of the guards we've accumulated in this ShapeEnv.
+        Determines when we need to invalidate our cache
+        """
+        return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts)
+
+    def _update_version_counter(self):
+        # The shape environment is queried orders of magnitude more often than
+        # it is changed, so we summarise the cache key into a linearly
+        # increasing version counter which is cheaper to check in _lru_cache
+
+        # Only update version counter if the state actually changed
+        cur_key = self._get_key()
+        if self._prev_cache_key != cur_key:
+            self._prev_cache_key = cur_key
+            self._version_counter += 1
+
+    def _produce_dyn_sizes(self,
+                           ex_size: Sequence[int],
+                           source: Source,
+                           symbolic_context: SymbolicContext
+                           ) -> List[sympy.Expr]:
+        return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context)
+
+    def _produce_dyn_sizes_from_int_tuple(self,
+                                          tensor_size: Tuple[int],
+                                          source: Source,
+                                          symbolic_context: SymbolicContext,
+                                          ) -> List[sympy.Expr]:
+        assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
+        from torch._dynamo.source import TensorPropertySource, TensorProperty
+        _assert_symbol_context(symbolic_context)
+        dynamic_dims = symbolic_context.dynamic_sizes
+        constraint_dims = symbolic_context.constraint_sizes
+        size = []
+        for i, val in enumerate(tensor_size):
+            size.append(self.create_symbol(
+                val,
+                TensorPropertySource(source, TensorProperty.SIZE, i),
+                dynamic_dims[i],
+                constraint_dims[i],
+                symbolic_context=symbolic_context
+            ))
+        return size
+
+    def create_symbolic_sizes_strides_storage_offset(
+        self,
+        ex: torch.Tensor,
+        source: Source,
+        *,
+        symbolic_context: Optional[SymbolicContext] = None,
+    ):
+        """
+        Returns a list of symbolic sizes and strides for the given tensor.
+        We try our best to express stride in terms of the sizes, so as to not
+        introduce new symbolic variables.
+        """
+
+        # Dynamo may want to wrap FakeTensors that already have SymInt sizes, e.g. in make_fx(opt_f(), tracing_mode="symbolic").
+        # We create symbols in shape_env using the backed hints behind SymInt.
+
+        # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
+        # produce_guards will trigger specializations on the outer stuff
+
+        # Case 2: when the SymInt is unbacked, we will throw a data-dependent error in require_hint().
+        #
+        # It's probably good for now but it's important to note that this approach has implications for
+        # the original shape_env when checking guards in different order.
+
+        # Example:
+        # ---------
+        # Consider a function "opt_f" as shown below:
+
+        # @torch.compile()
+        # def opt_f(x: bool, y: Tensor):
+        #   if x == True:
+        #     return y + torch.randn([4])
+        #   else:
+        #     return y
+        # Depending on the sequence of calls, we might install two different sets of guards:
+
+        # 1. opt_f(False, y):
+        #    - "x == False" (always works for any size y)
+
+        # 2. opt_f(True, y):
+        #    - Triggers recompilation and results in guards like:
+        #      - "x == True and y.size(0) == 4"
+        #      - (or "y.size(0) == 4 and x == True")
+
+        # The order of checking the guards matters. In this specific example:
+        # If the True-branch guard check precedes the False-branch one, and within the True branch the
+        # y.size(0) check precedes the x == True check, we may end up with an unnecessary shape specialization for y.
+        def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
+            assert isinstance(maybe_sym, (int, torch.SymInt))
+            if is_symbolic(maybe_sym):
+                assert maybe_sym.node.shape_env is not self, \
+                    "expect the symbol is created from an shape env other than current one."
+                return maybe_sym.node.require_hint()
+            return maybe_sym
+
+        ex_size = tuple(maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
+        ex_stride = tuple(maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
+        ex_storage_offset = maybe_specialize_sym_int_with_hint(ex.storage_offset())
+
+        return self._create_symbolic_sizes_strides_storage_offset(
+            ex_size,
+            ex_stride,
+            ex_storage_offset,
+            [_is_dim_dynamic(ex, i) for i in range(ex.dim())],
+            source,
+            symbolic_context=symbolic_context,
+        )
+
+    @record_shapeenv_event()
+    def _create_symbolic_sizes_strides_storage_offset(
+        self,
+        ex_size: Sequence[int],
+        ex_stride: Sequence[int],
+        ex_storage_offset: int,
+        is_dim_dynamic: Sequence[bool],
+        source: Source,
+        *,
+        symbolic_context: Optional[SymbolicContext] = None,
+    ):
+        dim = len(ex_size)
+
+        # Reimplement the legacy behavior
+        if symbolic_context is None:
+            constraint_dims = [None] * dim
+            dynamic_dims = []
+            for i in range(dim):
+                # NB: This is encapsulation breaking!  Legacy behavior was
+                # bad.
+                if is_dim_dynamic[i]:
+                    r = DimDynamic.DYNAMIC
+                elif self.assume_static_by_default:
+                    r = DimDynamic.STATIC
+                else:
+                    r = DimDynamic.DUCK
+                dynamic_dims.append(r)
+            dynamic_dims = [DimDynamic.DUCK] * dim
+            # symbolic_context is None - set one
+            symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
+        # We got a StatelessSymbolicContext
+        _assert_symbol_context(symbolic_context)
+        constraint_dims = symbolic_context.constraint_sizes
+        dynamic_dims = symbolic_context.dynamic_sizes
+
+        # TODO: make this configurable from outside symbolic_context; we made a
+        # decision here that if all sizes are static, we are going to
+        # specialize all of the inner strides/offset too. We don't have to
+        # do this, and arguably we should ALWAYS allow for dynamic offset,
+        # this is cheap.
+        # TODO: This should be DYNAMIC, using DUCK for BC
+        dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
+
+        assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
+        assert len(constraint_dims) == dim
+
+        from torch._dynamo.source import TensorPropertySource, TensorProperty
+        size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
+        stride: List[Optional[sympy.Expr]] = [None] * len(size)
+        for i, val in enumerate(ex_stride):
+            if val in (0, 1):
+                stride[i] = sympy.Integer(val)
+        while any(x is None for x in stride):
+            candidates = {
+                ex_size[i] * ex_stride[i]: size[i] * stride[i]
+                for i in range(len(size))
+                if stride[i] is not None and ex_stride[i] >= 0
+            }
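+            # Illustrative note (added): for a contiguous tensor of size (2, 3) with strides
+            # (3, 1), stride[1] is seeded with 1 above, candidates becomes {3: size[1]}, and
+            # the loop below resolves stride[0] = size[1] instead of allocating a new symbol.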
+
+            # iterate over unbound strides in sorted order
+            def _nested_int_aware_sort(tup):
+                return (
+                    # Order nested ints by their coefficients.
+                    # 1 here to order nested ints after non-nested-ints.
+                    (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
+                    else (0, *tup)
+                )
+            val_list = sorted(
+                [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
+                key=_nested_int_aware_sort,
+            )
+            for _, i in val_list:
+                if stride[i] is None and ex_stride[i] in candidates:
+                    stride[i] = candidates[ex_stride[i]]
+                    candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i]
+
+            if any(x is None for x in stride):
+                # bind the smallest unbound stride to a new variable
+                val, i = min(
+                    [
+                        (ex_stride[i], i)
+                        for i in range(len(stride))
+                        if stride[i] is None
+                    ], key=_nested_int_aware_sort
+                )
+                stride[i] = self.create_symbol(
+                    val,
+                    TensorPropertySource(source, TensorProperty.STRIDE, i),
+                    dynamic_dim=dynamic_strides_offset,
+                    constraint_dim=None,
+                    symbolic_context=symbolic_context,
+                )
+        assert all(x is not None for x in stride)
+
+        sym_sizes = [
+            self.create_symintnode(
+                sym,
+                hint=hint,
+                source=TensorPropertySource(source, TensorProperty.SIZE, i),
+            )
+            for i, (sym, hint) in enumerate(zip(size, ex_size))
+        ]
+        sym_stride = []
+        for i, stride_expr in enumerate(stride):
+            # NB: Don't duck size the stride; instead use the expression
+            # we computed
+            assert stride_expr is not None
+            sym_stride.append(self.create_symintnode(
+                stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)))
+        sym_storage_offset = self.create_symintnode(
+            self.create_symbol(
+                ex_storage_offset,
+                TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
+                dynamic_dim=dynamic_strides_offset,
+                constraint_dim=None,
+                symbolic_context=symbolic_context
+            ),
+            hint=ex_storage_offset,
+            source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
+        return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
+
+    @record_shapeenv_event()
+    def create_symintnode(
+            self,
+            sym: "sympy.Expr",
+            *,
+            hint: Optional[int],
+            source: Optional[Source] = None,
+    ):
+        """Create a SymInt value from a symbolic expression
+
+        If you know what the current hint value of the SymInt to be created
+        is, pass it into hint.  Otherwise, pass None and we will make our best
+        guess
+
+        """
+        source_name = source.name() if source else None
+
+        if self._translation_validation_enabled and source is not None:
+            # Create a new symbol for this source.
+            symbol = self._create_symbol_for_source(source)
+            assert symbol is not None
+
+            # Create a new FX placeholder and Z3 variable for 'symbol'.
+            fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
+
+            # Add an equality assertion for the newly created symbol and 'sym'.
+            self._add_assertion(sympy.Eq(symbol, sym))
+        else:
+            fx_node = None
+
+        if isinstance(sym, sympy.Integer):
+            if hint is not None:
+                assert int(sym) == hint
+            out = int(sym)
+        else:
+            out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
+        return out
+
+    @record_shapeenv_event()
+    def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
+        """Create a SymInt wrapping a new unspecified symbol"""
+        return self.create_symintnode(
+            self.create_unspecified_symbol(
+                value,
+                source=source,
+                dynamic_dim=dynamic_dim,
+            ),
+            hint=value,
+            source=source,
+        )
+
+    def create_symboolnode(self, sym: "sympy.Expr"):
+        """Create a SymBool object from a sympy boolean expression"""
+        # This function is only being used in serialization, so we do not track it
+        # for validation.
+        return SymBool(SymNode(sym, self, bool, None))
+
+    def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
+        is_debug = config.extended_debug_create_symbol is not None and str(symbol) in config.extended_debug_create_symbol.split(',')
+        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
+        log.info(
+            "%s %s [%s, %s]%s (%s)%s",
+            prefix, symbol, vr.lower, vr.upper, maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug
+        )
+
+    @record_shapeenv_event()
+    def create_unbacked_symfloat(self):
+        """Create a symbolic float without a hint value
+        """
+        symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
+        self.counter["create_unbacked_symbol"] += 1
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = ValueRanges.unknown()
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
+
+        self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, vr)
+
+        return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node))
+
+    @record_shapeenv_event()
+    def create_unbacked_symint(self):
+        """Create a symbolic integer without a hint value
+        """
+        symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
+        self.counter["create_unbacked_symbol"] += 1
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = self._default_unspecified_value_range()
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
+
+        self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr)
+
+        return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node))
+
+    def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
+        """Check if a sympy symbol matches the naming convention for unbacked symbols
+        """
+        # NB: keep synced with free_unbacked_symbols
+        return str(symbol).startswith("u")
+
+    @record_shapeenv_event()
+    def create_unbacked_symbool(self):
+        """Create a symbolic boolean without a hint value
+        """
+        symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True)
+        self.counter["create_unbacked_symbol"] += 1
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
+        vr = self.var_to_range[symbol] = ValueRanges(0, 1)
+
+        # Create a new FX placeholder and Z3 variable for 'symbol'.
+        fx_node = self._create_fx_placeholder_and_z3var(symbol, bool)
+
+        self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, vr)
+
+        return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node))
+
+    @record_shapeenv_event()
+    def create_unspecified_symbol(
+        self,
+        val: Union[int, SymInt],
+        source: Source,
+        dynamic_dim: DimDynamic = DimDynamic.DUCK,
+        constraint_dim: DimConstraint = None,  # NB: includes None
+    ) -> "sympy.Expr":
+        """Create a symbol with an unspecified value
+
+        Compared to standard symbols we do not assume the value is positive,
+        nor do we specialize on zero or one values.
+        """
+        # 'positive' is None for unspecified symbols, since we can't
+        # assume that it will be either positive or negative.
+
+        # We don't want to specialize zero/one values for unspecified symbols,
+        # so that we always get a new symbol regardless of val.
+        return self.create_symbol(
+            val,
+            source,
+            dynamic_dim,
+            constraint_dim,
+            positive=None,
+            do_not_specialize_zero_one=True,
+            symbolic_context=None)
+
+    @record_shapeenv_event()
+    def create_symbol(
+        self,
+        val: int,
+        source: Source,
+        dynamic_dim: DimDynamic = DimDynamic.DUCK,
+        constraint_dim: DimConstraint = None,  # NB: includes None
+        positive: Optional[bool] = True,
+        do_not_specialize_zero_one: bool = False,
+        symbolic_context=None,
+    ) -> "sympy.Expr":
+        """Create a new symbol which is tracked by this ShapeEnv
+        """
+        # see note [Tensor Fakification and Symbol Caching]
+        source_name = source.name()
+        if (isinstance(symbolic_context, StatefulSymbolicContext)
+                and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache):
+            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}
+
+        if (isinstance(symbolic_context, StatefulSymbolicContext)
+                and source_name
+                and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])):
+            return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name]
+
+        if do_not_specialize_zero_one:
+            specialize_zero_one = False
+        else:
+            specialize_zero_one = self.specialize_zero_one
+
+        assert isinstance(source, Source), f"{type(source)} {source}"
+        assert not (positive and val < 0), f"positive set for negative value: {val}"
+        # It's always sound to allocate a symbol as DYNAMIC.  If the user
+        # constrained the symbol, force it to DYNAMIC, because our
+        # constraint code will do weird stuff if, e.g., it's duck shaped
+        if constraint_dim is not None:
+            dynamic_dim = DimDynamic.DYNAMIC
+
+        if dynamic_dim is DimDynamic.STATIC:
+            out = sympy.Integer(val)
+            if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
+                symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
+            return out
+
+        elif dynamic_dim is DimDynamic.DUCK:
+            # duck_shape can be used to globally turn off duck shaping, even
+            # if it was requested
+            duck = self.duck_shape
+        elif dynamic_dim is DimDynamic.DYNAMIC:
+            duck = False
+        else:
+            raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
+
+        if val in (0, 1) and specialize_zero_one:
+            r = self.val_to_var[val]
+        elif not duck or val not in self.val_to_var:
+            # If we're not duck shaping, we always create a new symbol
+            # Even if we're duck shaping, if we haven't seen this particular
+            # value before, we also create a new symbol
+            sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
+            # We always associate vars to vals
+            if isinstance(val, int):
+                self.var_to_val[sympy_expr] = sympy.Integer(val)
+            else:
+                # Only used for jagged layout nested tensors
+                self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())
+
+            # Do the appending later, because we always want to populate this
+            self.var_to_sources[sympy_expr] = []
+            # Create a Z3 variable for the new symbol.
+            self._add_z3var(sympy_expr, int)
+
+            if duck:
+                # Make sure to reuse this symbol for subsequent duck shaping
+                self.val_to_var[val] = sympy_expr
+
+            if isinstance(val, int):
+                if positive:
+                    # Add assertions for the newly created symbols
+                    self._add_assertion(sympy_expr > 1)
+
+                    # Apply default range, which assumes not zero-one
+                    self.var_to_range[sympy_expr] = self._default_value_range()
+                else:
+                    self.var_to_range[sympy_expr] = self._default_unspecified_value_range()
+
+                # Small performance optimization: if we have a min-max constraint,
+                # we can proactively narrow to that range
+                if isinstance(constraint_dim, StrictMinMaxConstraint):
+                    assert not duck
+                    self.var_to_range[sympy_expr] &= constraint_dim.vr
+
+                vr = self.var_to_range[sympy_expr]
+
+                if val not in vr:
+                    raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")
+
+                range_str = f"[{vr.lower}, {vr.upper}]"
+            else:
+                # Skip var_range logic for SingletonInt
+                # Only used for jagged layout nested tensors
+                range_str = ""
+
+            r = sympy_expr
+
+            is_debug = (
+                config.extended_debug_create_symbol is not None and
+                str(sympy_expr) in config.extended_debug_create_symbol.split(',')
+            )
+            fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
+            self.log.info(
+                "create_symbol %s = %s for %s %s%s (%s)%s",
+                sympy_expr, val, source.name(), range_str,
+                maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug
+            )
+
+            self.counter["create_symbol"] += 1
+        else:
+            # This implements duck-shaping: input sizes that match are assigned
+            # the same symint
+            r = self.val_to_var[val]
+            self.log.debug("create_symbol %s duck sized %s", r, source.name())
+
+        if isinstance(r, sympy.Symbol):
+            r_sources = self.var_to_sources[r]
+            r_sources.append(source)
+            if not source.is_ephemeral() and r_sources[0].is_ephemeral():
+                # prefer non-ephemeral source first since it may be guarded on later
+                r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
+
+            # This ensures we get zeros in symbol_guard_counter, which makes
+            # some queries simpler (since we will accumulate mass on 0 this
+            # way)
+            self.symbol_guard_counter[r] = 0
+
+        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
+            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r
+        return r
+
+    def _debug_name(self, source):
+        src_name = source.name()
+        return self.source_name_to_debug_name.get(src_name, src_name)
+
+    def _render_range_for_constraint_violation(self, source, c):
+        if isinstance(c, StrictMinMaxConstraint):
+            lower, upper = c.vr.lower, c.vr.upper
+            default = self._default_value_range()
+            if lower <= default.lower:
+                lower = None
+            if upper >= default.upper:
+                upper = None
+            c_render = f"{self._debug_name(source)} = {source.name()} in the specified range"
+            if lower is not None and upper is not None:
+                c_render += f" {lower} <= {self._debug_name(source)} <= {upper}"
+            elif lower is None and upper is not None:
+                c_render += f" {self._debug_name(source)} <= {upper}"
+            elif lower is not None and upper is None:
+                c_render += f" {lower} <= {self._debug_name(source)}"
+            return c_render
+        return c.render(source)
+
+    def produce_guards(
+        self,
+        placeholders,
+        sources,
+        source_ref=lambda n: n.name(),
+        *,
+        input_contexts: Optional[DimList[SymbolicContext]] = None,
+        # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
+        # (See docs on EqualityConstraint for details of the encoding.)
+        equalities_inputs: Optional[EqualityConstraint] = None,
+        _simplified=False,
+        # Indicates if we should produce guards for known static values.
+        ignore_static=True,
+    ) -> List[str]:
+        """
+        Generates a list of guard strings which, when evaluated in a context that
+        defines tensors for all the sources, return True or False depending
+        on whether the corresponding guards hold.  Primarily used by Dynamo,
+        but this is also helpful for manual testing of guards (see
+        evaluate_guards_for_args)
+
+        For convenience in testing, a source is allowed to be a str,
+        in which case we will assume it is a LocalSource
+
+        _simplified lets you omit duck sizing, equality, and 0/1 guards.
+        This is useful for testing when you don't care about the boilerplate
+        guards, and it may be helpful for user output too (be careful though;
+        some equality guards are nontrivial!  It would be nice to get simplified
+        output to print them too).  It's private because it's not
+        intended for normal use
+        """
+        self.log.info("produce_guards")
+
+        # Check if we get to the same ShapeEnv state by replaying the recorded events.
+        # This will create a new ShapeEnv instance, and call all recorded function
+        # calls on this new instance. Finally, it will check whether this new instance
+        # has equal state.
+        #
+        # It's important that we do it at the beginning of this function, since it modifies
+        # self.dim_constraints through its execution. Changes that happen in this method
+        # aren't interesting, since this is the function call we wish to reproduce at the
+        # end. If we wish to simply reproduce ShapeEnv instances even after this call,
+        # this method should also be recorded.
+        if self.check_recorded_events:
+            shape_env = replay_shape_env_events(self.events)
+            self.check_equal(shape_env)
+
+        assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})"
+        Tensorlike = (torch.Tensor, FakeTensorMeta)
+
+        def _create_no_constraints_context(t):
+            return StatelessSymbolicContext(
+                # Ignored; only the constraints part is relevant below.
+                dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
+                constraint_sizes=[None] * t.dim()
+            )
+
+        # Expand optional inputs, or verify invariants are upheld
+        if input_contexts is None:
+            input_contexts = [
+                _create_no_constraints_context(t) if isinstance(t, Tensorlike)
+                else None for t in placeholders
+            ]
+        else:
+            assert len(input_contexts) == len(placeholders)
+            for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
+                if isinstance(t, Tensorlike):
+                    if context is None:
+                        input_contexts[i] = _create_no_constraints_context(t)
+                else:
+                    assert isinstance(t, (SymInt, int))
+                    assert not isinstance(context, list)
+
+        # It took a lot of sweat to figure out the algorithm here.  Let's
+        # explain how it works.
+        #
+        # The ShapeEnv lifecycle looks something like this:
+        #
+        # - For each input, you either generate a fresh Sympy symbol (s0) to
+        #   represent its value (a binding site), or you reuse some
+        #   preexisting symbol or expression, skipping the symbol allocation
+        #   (e.g., duck sizing to a preexisting symbol, or expressing a
+        #   stride as a multiplication of a separate stride and size.)
+        #   Naively, you might expect to bind a fresh Sympy symbol for
+        #   every input, but this is fairly wasteful as most of these
+        #   symbols immediately simplify away, and if you don't eagerly
+        #   specialize, e.g., 0/1 symbols, you end up with very complicated
+        #   expressions that are not optimizable in practice.
+        #
+        # - You perform some compute on these symbols, occasionally
+        #   introducing guards on boolean expressions on these symbols.
+        #   In particular, whenever we guard on equality (_maybe_guard_rel),
+        #   we can simplify shapes; e.g., when s0 == s1 * 2, we can now
+        #   replace all occurrences of s0 with s1 * 2.  Sometimes, a
+        #   boolean expression evaluation doesn't introduce a guard, as
+        #   the guard is already entailed by the simplifications we have
+        #   applied.
+        #
+        # - In the end, you have a bunch of replacements (saying how to
+        #   simplify shapes) and a bunch of guards (all the equality guards
+        #   are trivial, because they're covered by the replacements).
+        #
+        # From the ShapeEnv, we must generate a Python expression that, when
+        # evaluated on a set of inputs, tells us whether or not these boolean
+        # expressions would have evaluated in the same way.  However,
+        # we cannot easily compute this, as we elide recording boolean
+        # expressions when we think they are vacuously true.  Thus, we seek
+        # an approximation: we must generate an expression that, if true, would have
+        # produced an "equivalent" ShapeEnv, which would answer guard
+        # expressions in the same way.
+        #
+        # Our notion of equivalence is a bit subtle.  For example, consider
+        # the ShapeEnv created from an input of size (5, 4) versus (4, 4)
+        # (no other guards.)  Duck sizing would generate (s0, s1) in the first
+        # case but (s0, s0) in the second.  We do NOT assume that size
+        # variables are disjoint; so in fact a graph that assumes the input
+        # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
+        # vice versa.  However, consider an analogous case (1,) versus (2,).
+        # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
+        # subsume the (1,) graph because we assume that any size variable
+        # is NOT 0/1 (and make simplifications according to this; e.g., if
+        # we queried s0 == 0, we would immediately return False without
+        # returning a guard.)
+        #
+        # So, it is perhaps easier to flip things on their head: the guard
+        # expressions we generate here say what simplifications are valid,
+        # and what are not.  Below, we explain each of the guard expressions
+        # we generate.
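+        # Hedged illustration (added commentary, hypothetical source names):
+        # continuing the (5, 4) vs (4, 4) example above, the (4, 4) trace duck
+        # sizes to (s0, s0), so produce_guards emits roughly a guard of the form
+        #
+        #     L['x'].size()[1] == L['x'].size()[0]
+        #
+        # together with range guards such as 2 <= L['x'].size()[0].  A later
+        # (5, 4) input fails that guard and triggers recompilation, while any
+        # (n, n) input with n >= 2 passes.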
+
+        # TODO: Make this more efficient by binding all the size/stride/offsets
+        # to locals before performing tests on them.
+
+        from torch._dynamo.source import TensorPropertySource, TensorProperty, NegateSource
+
+        # Actual codegen must be delayed as we don't necessarily know what
+        # the symbol mapping is
+        input_guards = []
+
+        symbol_to_source = collections.defaultdict(list)
+        symbol_to_constraints = collections.defaultdict(set)
+        constraint_violations : List[Tuple[bool, Callable[[], str]]] = []
+
+        def record_constraint_violation(warn_only, debug_name, msg, hint=None):
+            constraint_violations.append(
+                (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg)
+            )
+
+        def is_dim(src):
+            return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE
+
+        if equalities_inputs:
+            source_index = {}
+            for i, src in enumerate(sources):
+                source_index[src.name()] = i
+
+            def get_expression(tensor_dim_src):
+                fake = placeholders[source_index[tensor_dim_src.base.name()]]
+                symint = fake.shape[tensor_dim_src.idx]
+                if isinstance(symint, torch.SymInt):
+                    return symint.node.expr
+                else:
+                    assert type(symint) is int, f"Expected int, got {type(symint)}"
+                    return symint
+
+            for src1, src2 in equalities_inputs.source_pairs:
+                expr1, expr2 = get_expression(src1), get_expression(src2)
+                # Check whether given input shape values satisfy a specified equation s = s'.
+                # - Raise when the equation was violated by the given input shape values.
+                # - Otherwise issue a guard to constrain them.
+                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
+                if not concrete_val:
+                    raise ConstraintViolationError(
+                        f"{src1.name()} = {expr1.subs(self.var_to_val)}"
+                        " is not equal to "
+                        f"{src2.name()} = {expr2.subs(self.var_to_val)}"
+                    )
+
+            for src, root, fn in equalities_inputs.derived_equalities:
+                expr1 = get_expression(src)
+                # recall that root is either a phantom symbol or an input source
+                expr2, debug_name = (
+                    (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol)
+                    else (get_expression(root), self._debug_name(root))
+                )
+                expr2_ = fn(expr2)
+                # Check whether given input shape values satisfy a specified equation s = fn(s').
+                # - Raise when the equation was violated by the given input shape values.
+                # - Otherwise issue a guard to constrain them.
+                concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
+                if not concrete_val:
+                    raise ConstraintViolationError(
+                        f"Expected input {src.name()} to be equal to "
+                        f"{fn(sympy.Symbol(debug_name))}, "
+                        f"where {debug_name} = {expr2.subs(self.var_to_val)}, "
+                        f"but got {expr1.subs(self.var_to_val)}"
+                    )
+
+            for phantom_symbol in equalities_inputs.phantom_symbols:
+                # we created additional phantom symbols that are not input shape dimensions
+                symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol])
+
+        # How do we know what the value of s0 is?  Fresh variables can only be
+        # bound by inputs, so there MUST be some other input which binds the
+        # variable.  If there is no such input, this is an error in our
+        # system.  We record where all symbols come from, to help you diagnose
+        # why those symbols didn't occur.
+        #
+        # In fact, generally speaking it is only possible for the "outermost"
+        # user of a ShapeEnv to evaluate the guards, because some inputs may
+        # not be available to inner levels.  For example, Dynamo can guard on
+        # tensors that never actually become graph arguments (they are
+        # pruned).  In this case, only Dynamo knows about these arguments.
+        def track_symint(source, val, constraint=None):
+            log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint)
+            assert not isinstance(val, SymInt) or is_symbolic(val)
+
+            if isinstance(val, SymInt) and val.node.maybe_as_int() is not None:
+                val = val.node.maybe_as_int()
+
+            if isinstance(val, SymInt):
+                s = val.node.expr
+                if isinstance(s, sympy.Symbol):
+                    symbol_to_source[s].append(source)
+                    if constraint is not None:
+                        symbol_to_constraints[s].add(constraint)
+                elif isinstance(-s, sympy.Symbol):
+                    symbol_to_source[-s].append(NegateSource(source))
+                else:
+                    constraint_violated = False
+                    if isinstance(constraint, StrictMinMaxConstraint):
+                        # try inferring the ranges of the expr s
+                        sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols}
+                        if all(vr is not None for vr in sym_vrs.values()):
+                            expr_vr = bound_sympy(s, sym_vrs)
+                            if expr_vr != constraint.vr:
+                                # the expr and constraint ranges don't match
+                                constraint_violated = True
+                        else:
+                            # some of the free symbols in s don't have ranges
+                            constraint_violated = True
+                    elif isinstance(constraint, RelaxedUnspecConstraint):
+                        if s.is_number:
+                            i = int(s)
+                            # Don't complain about 0/1 specialization, we
+                            # expect to have to compile in this case anyway
+                            if i not in (0, 1):
+                                constraint_violated = True
+                    if constraint_violated:
+                        def hint(s):
+                            sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
+                            return f"{sexpr}."
+
+                        var_with_range = self._render_range_for_constraint_violation(source, constraint)
+                        msg = (
+                            f"Not all values of {var_with_range} are valid because "
+                            f"{self._debug_name(source)} was inferred to be equal to "
+                        )
+                        record_constraint_violation(
+                            constraint.warn_only,
+                            self._debug_name(source),
+                            msg,
+                            hint=functools.partial(hint, s),
+                        )
+
+                input_guards.append((source, s))
+            else:
+                s = sympy.Integer(val)
+                input_guards.append((source, s))
+                constraint_violated = False
+                if isinstance(constraint, StrictMinMaxConstraint):
+                    constraint_violated = True
+                elif isinstance(constraint, RelaxedUnspecConstraint):
+                    # Don't complain about 0/1 specialization, we
+                    # expect to have to compile in this case anyway
+                    if val not in (0, 1):
+                        constraint_violated = True
+                if constraint_violated:
+                    var_with_range = self._render_range_for_constraint_violation(source, constraint)
+                    msg = (
+                        f"Not all values of {var_with_range} are valid because "
+                        f"{self._debug_name(source)} was inferred to be a constant ({val})."
+                    )
+                    record_constraint_violation(constraint.warn_only, self._debug_name(source), msg)
+
+        for t, source, context in zip(placeholders, sources, input_contexts):
+            if isinstance(source, str):
+                from torch._dynamo.source import LocalSource
+                source = LocalSource(source)
+            assert isinstance(source, Source)
+            if t is None:
+                continue
+            if isinstance(t, (SymInt, int)):
+                track_symint(source, t)
+                continue
+            assert isinstance(t, Tensorlike)
+            if is_traceable_wrapper_subclass(t):
+                from torch._dynamo.source import AttrSource
+
+                assert isinstance(context, SubclassSymbolicContext)
+
+                # For subclasses, we need to track symints on BOTH the outer
+                # and inner tensors.
+                sources_tensors_constraints = [
+                    (source, t, context.constraint_sizes)
+                ]
+                attrs, _ = t.__tensor_flatten__()
+                for attr in attrs:
+                    inner_t = getattr(t, attr)
+                    inner_context = context.inner_contexts[attr]
+                    sources_tensors_constraints.append((
+                        AttrSource(source, attr),
+                        inner_t,
+                        inner_context.constraint_sizes
+                    ))
+            else:
+                sources_tensors_constraints = [(source, t, context.constraint_sizes)]
+
+            for src, curr_t, constraint in sources_tensors_constraints:
+                if is_sparse_any(curr_t):
+                    for i, ss in enumerate(curr_t.size()):
+                        property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
+                        track_symint(property_source, ss, constraint[i])
+                else:
+                    for i, ss in enumerate(curr_t.size()):
+                        property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
+                        track_symint(property_source, ss, constraint[i])
+                    for i, ss in enumerate(curr_t.stride()):
+                        track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
+                    track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())
+
+        # 1. Every input must equal the final simplified symbolic expression
+        #    stored on the placeholder.  Given a placeholder (s0*2, s1),
+        #    if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
+        #    This does a lot of work: it covers duck sizing and equality guards.
+        exprs = []
+        self.dim_constraints = DimConstraints(
+            symbol_to_source,
+            self.var_to_val,
+            set(symbol_to_constraints.keys()),
+            self.source_name_to_debug_name,
+        )
+
+        if not _simplified:
+            for source, expr in input_guards:
+                if self._translation_validation_enabled:
+                    # Ignore sources that were not turned into SymInts.
+                    srcname = source.name()
+                    if srcname in self.source_to_symbol:
+                        self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr))
+
+                # Small optimization
+                if (
+                    isinstance(expr, sympy.Symbol) and
+                    symbol_to_source.get(expr) and
+                    source == symbol_to_source[expr][0]
+                ):
+                    continue
+
+                # This logic excludes static values found on tensors from guarding, because
+                # dynamo's check_tensor_fn does that (see guards.cpp).
+                # However, for non-tensor sources, we still need to guard here.
+                if ignore_static and isinstance(source, TensorPropertySource):
+                    if expr.is_number:
+                        self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
+                        continue
+
+                if is_dim(source):
+                    self.dim_constraints.add_equality(source, expr)
+
+                sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
+                exprs.append(f"{source_ref(source)} == {sexpr}")
+                if (
+                    isinstance(source, TensorPropertySource)
+                    and source.prop is TensorProperty.SIZE
+                    and equalities_inputs
+                    and len(expr.free_symbols) == 1
+                ):
+                    symbol = next(iter(expr.free_symbols))
+                    if (
+                        isinstance(expr, sympy.Symbol) and
+                        expr in symbol_to_constraints and
+                        not equalities_inputs.is_equal(source, symbol_to_source[expr][0])
+                    ):
+                        msg = (
+                            f"The values of {self._debug_name(source)} = {source.name()} and "
+                            f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
+                            "must always be equal."
+                        )
+                        record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)
+
+                    if (
+                        not isinstance(expr, sympy.Symbol) and
+                        symbol in symbol_to_constraints and
+                        not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.subs(symbol, x))
+                    ):
+                        src = symbol_to_source[symbol][0]
+                        msg = (
+                            f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
+                            f"the values of {self._debug_name(src)} = {src.name()} by "
+                            f"{self._debug_name(source)} = {expr.subs(symbol, sympy.sympify(self._debug_name(src)))}."
+                        )
+                        record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)
+
+                # NB: Not necessary to report constraint violations here:
+                # constraints are guaranteed to be on symbols (we've already
+                # caught constants and non-atomic expressions), so we only
+                # have relational constraints, but we don't support those
+                # at the moment
+
+        # 2. Every guard must evaluate to True (but remember many guards
+        #    like s0 == s1*2 become trivial due to simplification)
+        issued = set()
+
+        def issue_guard(guard: ShapeGuard) -> None:
+            expr = self.simplify(guard.expr)
+
+            # Avoid re-issuing the same guard.
+            if expr in issued:
+                return
+
+            issued.add(expr)
+
+            try:
+                is_trivial = False
+                if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]):
+                    is_trivial = self.dim_constraints.add(expr)
+                guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
+                exprs.append(guard_expr)
+                self._add_target_expr(expr)
+                # A non-relational constraint on a single sizevar can violate
+                # a constraint
+                if not is_trivial and len(expr.free_symbols) == 1:
+                    symbol = next(iter(expr.free_symbols))
+                    source = symbol_to_source[symbol][0]
+                    constraints = symbol_to_constraints[symbol]
+                    for c in constraints:
+                        if isinstance(c, StrictMinMaxConstraint):
+                            var_with_range = self._render_range_for_constraint_violation(source, c)
+                            msg = (
+                                f"Not all values of {var_with_range} "
+                                f"satisfy the generated guard {guard_expr}."
+                            )
+                            record_constraint_violation(c.warn_only, self._debug_name(source), msg)
+                        elif isinstance(c, RelaxedUnspecConstraint):
+                            # This is fine; we allow guards here as long as they
+                            # didn't constrain the symbol to a single value (we don't
+                            # actually know this; it depends on our
+                            # ValueRanges reasoning capability)
+                            pass
+                        else:
+                            raise AssertionError(f"unrecognized constraint {c}")
+            except Exception:
+                self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
+                raise
+
+        # First, issue all the non-trivial guards.
+        for guard in self.guards:
+            if self._maybe_evaluate_static(guard.expr) is not None:
+                continue
+            issue_guard(guard)
+
+        # 3. Every symbol must be within its value range (this handles 0/1
+        # specialization too).
+        for symbol, sources in symbol_to_source.items():
+            r = self.var_to_range.get(symbol)
+            if r is None:
+                continue
+
+            assert sources
+            assert symbol.is_integer
+            bounds = []
+            if r.lower != -sympy.oo:
+                if any(is_dim(source) for source in sources):
+                    self.dim_constraints.add(sympy.Ge(symbol, r.lower))
+                # Only print lower bound in simplified mode if it is not the
+                # default
+                if not _simplified or r.lower != self._default_value_range().lower:
+                    bounds.append(str(r.lower))
+            bounds.append(source_ref(sources[0]))
+            # NB: This looks like an off-by-one error but it's not: the
+            # upper bound may be sys.maxsize - 1 because we intentionally
+            # exclude sys.maxsize from our bounds to deal with direct
+            # == INT_MAX guards, but it's still dumb to actually test it.
+            # Note that you can be off by a pretty large constant and it
+            # won't matter because sizes in practice will be nowhere near
+            # the 64-bit limit.
+            if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
+                if any(is_dim(source) for source in sources):
+                    self.dim_constraints.add(sympy.Le(symbol, r.upper))
+                # nontrivial upper bound is always interesting
+                bounds.append(str(r.upper))
+            if len(bounds) > 1:
+                exprs.append(" <= ".join(bounds))
+
+                # Check constraints
+                constraints = symbol_to_constraints[symbol]
+                for c in constraints:
+                    if isinstance(c, StrictMinMaxConstraint):
+                        # NB: By default, we have a restrictive range
+                        # 2 <= s0 <= sys.maxsize - 1.  But export users generally
+                        # expect to be able to specify nice ranges like [0, oo]
+                        if not (c.vr & self._default_value_range()).issubset(r):
+                            source = sources[0]
+
+                            expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper))
+                            guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
+                            var_with_range = self._render_range_for_constraint_violation(source, c)
+                            msg = (
+                                f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}"
+                            )
+                            record_constraint_violation(
+                                c.warn_only,
+                                self._debug_name(source),
+                                msg,
+                            )
+
+        if constraint_violations:
+            warn_msgs = []
+            error_msgs = []
+            debug_names = set()
+            for warn_only, debug_name, msg in constraint_violations:
+                if warn_only:
+                    msg = f"  {len(warn_msgs) + 1}. {msg()}"
+                    warn_msgs.append(msg)
+                else:
+                    msg = f"  - {msg()}"
+                    error_msgs.append(msg)
+                    debug_names.add(debug_name)
+            if len(error_msgs) > 0:
+                debug_names = ', '.join(debug_names)
+                err = '\n'.join(error_msgs)
+                raise ConstraintViolationError(
+                    f"Constraints violated ({debug_names})! "
+                    "For more information, run with TORCH_LOGS=\"+dynamic\".\n"
+                    f"{err}"
+                )
+            elif len(warn_msgs) > 0:
+                log.debug("%s Warning only constraints violated", len(warn_msgs))
+
+        signpost_event(
+            "dynamic",
+            "produce_guards",
+            {
+                **self.co_fields,
+                **self.counter,
+                "num_guards": len(exprs),
+                "free_symbols": sum(1 for v in symbol_to_source.values() if v),
+                # The keys are meaningless from an aggregate perspective, so
+                # don't include them.  Biggest first.
+                "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True),
+            },
+        )
+
+        if self._translation_validation_enabled:
+            from torch.fx.experimental.validator import PopulateValidator
+
+            # Add all deferred runtime assertions; these are not technically
+            # handled by produce_guards but we need to put them in the target
+            # set
+            for ras in self.deferred_runtime_asserts.values():
+                for ra in ras:
+                    self._add_target_expr(ra.expr)
+
+            # Add value range bound guards for all symbols with no trivial bounds.
+            # Reason: '_maybe_evaluate_static' may eliminate guards based on the
+            # refined value ranges.
+            for sym, vr in self.var_to_range.items():
+                if vr.lower != -sympy.oo:
+                    self._add_target_expr(sympy.Le(vr.lower, sym))
+                if vr.upper != sympy.oo:
+                    self._add_target_expr(sympy.Le(sym, vr.upper))
+
+            # Before validating, populate the input of the validator with the
+            # built FX graph.
+            with fx_traceback.preserve_node_meta():
+                PopulateValidator(self.graph, self.validator).run()
+
+        self._check_translation_validate()
+        return exprs
+
+    def produce_guards_expression(self, placeholders, ignore_static=True):
+        """
+        Expected to be used with evaluate_guards_expression(). Produces the guards
+        for the given placeholders and returns a string expression to be evaluated
+        by evaluate_guards_expression given concrete values for the placeholders.
+        """
+        from torch._dynamo.source import LocalSource
+        arg_names = [f"t{i}" for i in range(len(placeholders))]
+        guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
+        if guards:
+            return " and ".join(guards)
+        return None
+
+    def evaluate_guards_expression(self, code, args):
+        """
+        Expected to be used with produce_guards_expression(). Evaluates an expression
+        generated by produce_guards_expression for the given concrete args.
+        """
+        arg_names = [f"t{i}" for i in range(len(args))]
+        return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+
+    def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
+        """Generate guards for a graph's placeholder values and evaluate the guards with args
+        """
+        code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
+        if code:
+            return self.evaluate_guards_expression(code, args)
+        return True
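+    # Usage sketch (added commentary; fake_x / real_x are hypothetical values,
+    # where fake_x is a fake tensor with symbolic sizes and real_x a concrete
+    # tensor):
+    #
+    #     code = shape_env.produce_guards_expression([fake_x])
+    #     # code looks roughly like "L['t0'].size()[0] == L['t0'].size()[1] and ..."
+    #     ok = shape_env.evaluate_guards_expression(code, [real_x])
+    #     # or, equivalently, in one step:
+    #     ok = shape_env.evaluate_guards_for_args([fake_x], [real_x])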
+
+    def bind_symbols(self, placeholders, args):
+        """
+        Given a paired list of placeholders (fake tensors with
+        symbolic sizes) and concrete arguments (regular tensors
+        with real sizes), returns a dictionary mapping each
+        symbol to its real value.  So for example, if you
+        have a placeholder with size (s0, s1), binding
+        (2, 4) to it will give you {s0: 2, s1: 4}.  This is
+        not guaranteed to bind ALL symbols in the ShapeEnv;
+        we can't bind a symbol if it doesn't occur in any placeholder,
+        and symbols that already have replacements won't get bindings.
+
+        This is a little duplicative with evaluate_guards but
+        it's different enough that it seemed cleanest to make
+        another copy.  This assumes the guards are already checked,
+        though if it's cheap we'll check for shenanigans
+        """
+        bindings: Dict[sympy.Symbol, int] = {}
+
+        def bind_symint(arg, val):
+            if isinstance(val, SymInt):
+                s = val.node.expr
+
+                if isinstance(s, sympy.Symbol):
+                    if s in bindings:
+                        assert bindings[s] == arg, f"{bindings[s]} != {arg}"
+                    else:
+                        bindings[s] = arg
+                elif isinstance(-s, sympy.Symbol):
+                    if -s in bindings:
+                        assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}"
+                    else:
+                        bindings[-s] = -arg
+
+        for t, arg in zip(placeholders, args):
+            if t is None:
+                continue
+            if isinstance(t, SymInt):
+                bind_symint(arg, t)
+                continue
+            assert isinstance(t, torch.Tensor)
+            for i, s in enumerate(t.size()):
+                bind_symint(arg.size(i), s)
+            for i, s in enumerate(t.stride()):
+                bind_symint(arg.stride(i), s)
+            bind_symint(arg.storage_offset(), t.storage_offset())
+
+        return bindings
+
+    def get_nontrivial_guards(self):
+        """Returns a list of guard expressions that aren't statically known (i.e. not trivial)"""
+        return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]
+
+    def format_guards(self, verbose=False):
+        """Format this shape env's guard expressions with optional traceback info if verbose"""
+        def format_tb(tb):
+            if not verbose:
+                return ""
+            return f"\n   Guarded at:\n{''.join('   ' + l for l in tb.format())}"
+
+        return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
+
+    def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
+        """Given a sympy expression, computes a ValueRanges bound for what values it can be"""
+        var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols}
+        if size_oblivious:
+            # Clamp values of size-like variables
+            for x in self.size_like & var_to_range.keys():
+                if var_to_range[x] is not None:
+                    var_to_range[x] &= ValueRanges(2, sympy.oo)
+        return bound_sympy(expr, var_to_range)
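+    # Small worked example (added commentary, hypothetical values): if s0's
+    # recorded range is [2, 9], bound_sympy(2*s0 + 1) yields roughly ValueRanges(5, 19).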
+
+    @_lru_cache
+    def _maybe_evaluate_static(
+        self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
+        expect_rational=True, size_oblivious: bool = False
+    ) -> "Optional[sympy.Expr]":
+        """
+        Tries to evaluate expr without introducing guards
+
+        If unbacked_only == True, then we only do substitutions on
+        unbacked SymInts (leaving regular hinted integers alone).  This could
+        result in an expression that still contains backed SymInts, which you
+        could then potentially guard on.
+
+        Use compute_hint == True if you are trying to compute a non-binding
+        hint for the particular hint values of backed SymInts, e.g., if
+        s0 happens to be 3 this run, compute_hint will substitute s0 with 3.
+        """
+        expr = self.simplify(expr)
+
+        if compute_hint:
+            expr = expr.xreplace(self.var_to_val)
+
+        expr = canonicalize_bool_expr(expr)
+
+        symbols = list(expr.free_symbols)
+
+        # Apply known runtime asserts
+        for s in symbols:
+            # Unbacked symints only
+            if s in self.var_to_val:
+                continue
+
+            subst = {}
+
+            def add_expr(expr):
+                # Expr and negation
+                subst[canonicalize_bool_expr(expr)] = sympy.true
+                subst[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false
+                if isinstance(expr, sympy.Rel):
+                    # multiplying by -1 changes the direction of the inequality
+                    dual = type(expr)(-expr.rhs, -expr.lhs)
+                    subst[canonicalize_bool_expr(dual)] = sympy.true
+                    subst[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false
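+                    # Added example: for s0 < 5 this also records the equivalent
+                    # -5 < -s0 (and its negation), so either spelling of the
+                    # inequality is recognized during substitution.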
+
+            for e in itertools.chain(self.guards, self.deferred_runtime_asserts.get(s, ())):
+                e = e.expr
+                if compute_hint:
+                    e = canonicalize_bool_expr(e.xreplace(self.var_to_val))
+                add_expr(e)
+                # Other relational expressions this expression implies
+                if isinstance(e, sympy.Eq):
+                    add_expr(sympy.Le(e.lhs, e.rhs))
+                    add_expr(sympy.Ge(e.lhs, e.rhs))
+                elif isinstance(e, sympy.Lt):
+                    add_expr(sympy.Le(e.lhs, e.rhs))
+                    add_expr(sympy.Ne(e.lhs, e.rhs))
+
+            # NB: this helps us deal with And/Or connectives
+            expr = expr.subs(subst)
+
+        # Simplify making use of value range lower bound
+        new_shape_env = {}
+        new_range_env = {}
+        for idx, k in enumerate(symbols):
+            if isinstance(self.var_to_val.get(k, None), SingletonInt):
+                # Skip var_to_range logic for SingletonInt which is only used
+                # for jagged layout NestedTensors today
+                continue
+            vr = self.var_to_range[k]
+            if size_oblivious and k in self.size_like:
+                lower = max(2, vr.lower)
+            else:
+                lower = vr.lower
+            # Don't do anything if we don't have a nontrivial lower bound
+            # Also don't do anything if we asked only to simplify unbacked
+            # SymInt
+            if (
+                lower < (-sys.maxsize - 1) // 2 or
+                (unbacked_only and k in self.var_to_val)
+            ):
+                new_range_env[k] = vr
+                continue
+            # Positive means >= 1
+            # Positive - 1 means >= 0
+            # Positive + lower - 1 means >= lower
+            # The new symbol 's' is "too low", so when we substitute it in
+            # we have to increase it by offset (and conversely, the new
+            # variables have to have their value range bounds adjusted as
+            # well)
+            s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
+            offset = lower - 1
+            new_shape_env[k] = s + offset
+            new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
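+            # Added example (hypothetical values): if k's range is [5, 10], then
+            # offset = 4, k is replaced by shape_i + 4, and shape_i gets the
+            # shifted range [1, 6].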
+
+        def replace(expr, repl):
+            return expr.xreplace(repl)
+
+        try:
+            new_expr = replace(expr, new_shape_env)
+        except RecursionError:
+            log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
+            self.counter["sympy_recursion_error"] += 1
+            return None
+
+        floor_div_replace = {}
+        for atom in new_expr.atoms(FloorDiv):
+            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
+        new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
+        # TODO: when unbacked_only, can sometimes early return even when there
+        # are still free symbols
+        if new_expr.is_number:
+            return new_expr
+
+        # Check if the range can solve it statically
+        out = bound_sympy(new_expr, new_range_env)
+        if expect_rational:
+            _assert_bound_is_rational(new_expr, out)
+            if out.is_singleton():
+                return out.lower
+
+        return new_expr if unbacked_only else None
+
+    @_lru_cache
+    def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
+        """Apply symbol replacements to any symbols in the given expression
+        """
+        replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols}
+        return safe_expand(expr.xreplace(replacements))
+
+    @_lru_cache
+    def _update_divisible(self):
+        new_divisible = set()
+        for k in self.divisible:
+            res = self.replace(k)
+            if not res.is_number:
+                new_divisible.add(k)
+
+        self.divisible = new_divisible
+        self._update_version_counter()
+
+    @_lru_cache
+    def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
+        """Use known constraints and replacements to simplify the given expr
+        """
+        expr = self.replace(expr)
+        # TODO it would seem that this pass is not necessary given the
+        # below replacement of // with /, but for nested FloorDivs
+        # the non-recursive replacement doesn't work, and
+        # recursive makes it hard to look up divisibility,
+        # because existing divisibility info has FloorDiv in it, not /.
+        # For now, just do a separate pass to catch the common nested case.
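+        # Added example: the nested case handled below is a // (a // c), which
+        # simplifies to c when both c and a // c are known to divide a
+        # (e.g. 12 // (12 // 3) == 3).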
+        if expr.has(FloorDiv):
+            self._update_divisible()
+            div_replacements = {}
+            for atom in expr.atoms(FloorDiv):
+                base, divisor = atom.args
+                if isinstance(divisor, FloorDiv):
+                    base1, divisor1 = divisor.args
+                    if self.replace(Mod(base, divisor)) in self.divisible and \
+                            base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
+                        div_replacements[atom] = divisor1
+            expr = expr.xreplace(div_replacements)
+            expr = safe_expand(expr)
+        if expr.has(FloorDiv):
+            div_replacements = {}
+            pows = expr.atoms(sympy.Pow)
+            rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
+            for fd in expr.atoms(FloorDiv):
+                base, divisor = fd.args
+                if self.replace(Mod(base, divisor)) in self.divisible:
+                    div_replacements[fd] = base / divisor
+            new_expr = expr.xreplace(div_replacements)
+            new_expr = safe_expand(new_expr)
+            new_pows = new_expr.atoms(sympy.Pow)
+            new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer))
+            # divisions simplified away
+            if new_pows.issubset(pows) and new_rationals.issubset(rationals):
+                expr = new_expr
+        return expr
+
+    @lru_cache(256)
+    def size_hint(self, expr: "sympy.Expr", *, allow_none=False):
+        """
+        Gets a size hint for a given expression from the underlying shapes we had.
+        Does not introduce a guard, so only use this when you can guarantee that
+        your code is still valid for arbitrary shapes (such as optimization decisions)
+        """
+        result_expr = safe_expand(expr).xreplace(self.var_to_val)
+        if not result_expr.is_number:
+
+            from torch.utils._sympy.singleton_int import SingletonInt
+
+            if isinstance(result_expr, SingletonInt):
+                return None
+            r = self._maybe_evaluate_static(result_expr, compute_hint=True)
+            if r is not None:
+                return r
+            if allow_none:
+                return None
+            raise self._make_data_dependent_error(result_expr, expr)
+        return result_expr
+
+    # NB: keep in sync with size_hint
+    @lru_cache(256)
+    def has_hint(self, expr: "sympy.Expr"):
+        result_expr = safe_expand(expr).xreplace(self.var_to_val)
+        return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None
+
+    def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None):
+        # TODO: in a Dynamo context, having user code, and having the
+        # name of the local, will be much better
+        size_like_symbols = []
+        for s in expr.free_symbols:
+            stacktrace = ''.join(self.var_to_stack[s].format())
+            self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
+            if s in self.size_like:
+                size_like_symbols.append(s)
+        size_oblivious_result_msg = ""
+        if size_oblivious_result is not None:
+            size_oblivious_result_msg = (
+                f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n"
+                "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n"
+            )
+        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True)
+        return GuardOnDataDependentSymNode(
+            f"Could not guard on data-dependent expression {expr} (unhinted: {unhinted_expr}).  "
+            f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
+            f"{size_oblivious_result_msg}"
+            "Potential framework code culprit (scroll up for full backtrace):\n"
+            f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n"
+            "For more information, run with TORCH_LOGS=\"dynamic\"\n"
+            "For extended logs when we create symbols, also add "
+            f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
+            "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
+            "For more debugging help, see "
+            "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" +
+            maybe_extra_debug
+            # TODO: Help text about how to use our runtime tests to fix this
+            # problem
+        )
+
+    def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
+        """
+        Adds or updates a replacement for a symbol.
+        Use this instead of `self.replacements[a] = tgt`.
+        """
+
+        # Precondition: a == tgt
+        assert isinstance(a, sympy.Symbol)
+
+        # Handles nested tensor symbolic variables which don't have
+        # var_to_range bounds
+        tgt_bound = None
+        if a in self.var_to_range:
+            src_bound = self.var_to_range[a]
+
+            # If you have x in [2, maxint], then 2*x in [4, 2*maxint].
+            # But we don't really care that the max bound says we can
+            # go beyond the maximum integer size, because we aren't
+            # using bigints anyway.  Arguably, ValueRanges should know
+            # to do this truncation automatically (to avoid doing
+            # bigint compute in range analysis), but right now it doesn't
+            # so we need to get rid of some unnecessary precision.
+            int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
+
+            def issubset(x, y):
+                return (x & int_range).issubset(y & int_range)
+
+            # First, refine the value range of a based on the computed value range
+            # of tgt.  This is always OK to do, even if we decide not to do the
+            # substitution in the end.  This might be a no-op, if a already has
+            # a tighter bound
+            tgt_bound = self.bound_sympy(tgt)
+            self.var_to_range[a] = src_bound & tgt_bound
+
+            # Next, check if we can update the range of free symbols in tgt
+            # based on the range in a. But only do it if:
+            #  - the source bound non-trivially improves over what we get out of
+            #    the existing bounds.
+            #  - the replacement is univariate and we can invert the tgt expression
+            if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
+                b = next(iter(tgt.free_symbols))
+                # Try to invert the equality
+                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
+                if r is not None:
+                    b_bound = self.bound_sympy(r[1])
+                    self.var_to_range[b] = b_bound & self.var_to_range[b]
+                    tgt_bound = self.bound_sympy(tgt)
+                    assert issubset(tgt_bound, src_bound)
+
+            # TODO: Should we propagate size-like-ness?
+            #
+            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
+            # to become size-like.
+            #
+            # Cons: if u0 is size-like, what about u0 - 1 == u1?  You CAN'T
+            # propagate in this case, because what if u0 == 0, then u1 is negative
+            # and clearly isn't a size.  So, at minimum, any f(x) whose value
+            # range isn't [0, inf] given x in [0, inf] cannot propagate
+            # size-like-ness.  But there are many situations where you could
+            # imagine u1 is going to be size-like and actually you just didn't
+            # have a refined enough value range on u0.  Since even innocuous
+            # looking arithmetic operations can destroy size-like-ness, it's
+            # best to not propagate it at all and force the user to annotate it
+            # as necessary.
+            #
+            # Compromise: we preserve size-like-ness only for exact equality
+            # and nothing else.
+            if a in self.size_like and isinstance(tgt, sympy.Symbol):
+                self.size_like.add(tgt)
+            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
+                self.size_like.add(a)
+
+            # Now, decide if we will do the substitution.
+            #
+            #  - If the source has a non-trivial range, only substitute if
+            #    we preserve this range.  Note that we may have propagated
+            #    the src_range to free variables in tgt when tgt is univariate
+            #    and we could find an inverse, which helps us achieve this.
+            #    This ensures we never "forget" about user defined ranges,
+            #    even if they end up being defined on composite formulas
+            #    like s0 + s1.
+            #
+            #  - If the variable is unbacked, only substitute if the substitution
+            #    would preserve the bounds also under size-like-ness conditions.
+
+            if not issubset(tgt_bound, src_bound):
+                self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
+                return
+            elif a in self.size_like:
+                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
+                # This is morally equivalent to self.bound_sympy(a, size_oblivious=True)
+                # but handles substitutions like u0 == 0
+                src_bound_so = self.var_to_range[a]
+                if src_bound_so.upper >= 2:
+                    src_bound_so &= ValueRanges(2, sympy.oo)
+                if not issubset(tgt_bound_so, src_bound_so):
+                    self.log.debug("skipped set_replacement %s = %s (%s) "
+                                   "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
+                    return
+
+        if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)):
+            # specializing to a constant, which is likely unexpected
+
+            # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g.,
+            # when adding a to self.replacements, and again when simplifying an expression containing a.
+            # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is,
+            # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant.
+            if a not in self.replacements or tgt != self.replacements[a]:
+                self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
+                self.log.debug("SPECIALIZATION", stack_info=True)
+        log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
+        self.replacements[a] = tgt
+        self._update_version_counter()
+
+        # When specializing 'a == tgt', the equality should also be conveyed to
+        # Z3, in case an expression uses 'a'.
+        self._add_target_expr(sympy.Eq(a, tgt))
+
+    def _add_divisible(self, expr: "sympy.Expr"):
+        self.divisible.add(expr)
+        self._update_version_counter()
+
+    @_lru_cache
+    @record_shapeenv_event()
+    def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
+        """
+        Implements a DSU-like (union-find) algorithm to find the expression that currently represents the symbol a.
+        Also handles transitive non-identity replacements, e.g.:
+
+        a: b + c
+        c: d
+        """
+        if a not in self.replacements:
+            return a
+        res = self.replacements[a]
+        cur_replace = {s: self._find(s) for s in res.free_symbols}
+        self._set_replacement(a, self.replacements[a].xreplace(cur_replace), "find")
+        return self.replacements[a]
+
+    @lru_cache(256)
+    def _maybe_guard_rel(self, expr: "sympy.Rel") -> None:
+        """
+        The relational expression is assumed (guarded) to be true.  Use this
+        information to simplify shapes (e.g. a == b or a % 5 == 0).
+        """
+        assert isinstance(expr, sympy.Rel)
+
+        # A good example of what goes wrong if you don't do this is
+        # python test/functorch/test_aotdispatch.py -k
+        # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
+        if isinstance(expr, sympy.Ne):
+            return
+
+        free = list(expr.free_symbols)
+
+        assert len(free) > 0, f"The expression should not be static by this point: {expr}"
+        # Bail out on really gnarly expressions so we don't blow up
+        if len(free) > 5:
+            return
+
+        # Prioritize unbacked symints for solving by ordering them last.
+        # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
+        #   (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
+        # Prefer to simplify out symbols with ephemeral sources.
+        def _smart_symbol_sort(x):
+            has_only_ephemeral_sources = (
+                x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
+            )
+            size = self.size_hint(x, allow_none=True) or sys.maxsize
+            name = x.name
+            # 1 puts ephemeral sourced symbols first when sorting in reverse
+            return (1 if has_only_ephemeral_sources else 0, size, name)
+
+        free = sorted(free, key=_smart_symbol_sort, reverse=True)  # type: ignore[attr-defined]
+        lhs = expr.lhs
+        rhs = expr.rhs
+
+        self._refine_ranges(expr)
+
+        # The rest of this stuff is for equality only
+        if not isinstance(expr, sympy.Eq):
+            return
+
+        if not expr.has(Mod):
+            try:
+                floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
+                if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms):
+                    raise NotImplementedError
+                # short-circuit when no solving is needed
+                if isinstance(lhs, sympy.Symbol) and free_unbacked_symbols(lhs):
+                    self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
+                elif isinstance(rhs, sympy.Symbol) and free_unbacked_symbols(rhs):
+                    self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
+                else:
+                    r = try_solve(expr, free[0], floordiv_inequality=False)
+                    if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
+                        new_var = self._find(r[1])
+                        ok = False
+                        if self.is_unbacked_symint(free[0]):
+                            # If you have i0 + i1 + i2 = s0, don't substitute i2 =
+                            # s0 - i0 - i1.  Arguably this should be OK but the
+                            # runtime assert machinery is very delicate right now
+                            # so this causes things to fail e.g.,
+                            # test_split_unbacked_sizes
+                            ok = len(free_unbacked_symbols(new_var)) <= 1
+                            msg = "solve_unbacked"
+                        else:
+                            # Never substitute backed with unbacked
+                            ok = len(free_unbacked_symbols(new_var)) == 0
+                            msg = "solve_backed"
+                        if ok:
+                            self._set_replacement(cast(sympy.Symbol, free[0]), new_var, msg)
+            except NotImplementedError:
+                pass
+        if expr.has(Mod):
+            mod_expr = next(iter(expr.atoms(Mod)))
+            try:
+                r = try_solve(expr, mod_expr, floordiv_inequality=False)
+                if r is not None and r[1] == 0:
+                    self._add_divisible(mod_expr)
+                    # This is a little bit of extra logic to make things like
+                    # torch.empty(i0, q).view(c, -1, q) work out
+                    p, q = mod_expr.args
+                    if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2:
+                        c, i0 = p.args
+                        # Given Mod(c * i0, q) == 0
+                        if (
+                            isinstance(c, sympy.Number) and
+                            isinstance(i0, sympy.Symbol) and
+                            self.is_unbacked_symint(i0)
+                        ):
+                            # We know q divides c * i0, which means (q / gcd(q, c))
+                            # divides i0, so we can rewrite i0 as (q / gcd(q, c)) * i1
+                            d = q / sympy.gcd(q, c)
+                            i1 = self.create_unbacked_symint().node.expr
+                            # Propagate the value ranges.  It doesn't really
+                            # matter if we use truediv or floordiv, because we
+                            # have established divisibility.
+                            self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv(
+                                self.var_to_range[i0], ValueRanges.wrap(d)
+                            )
+                            # Propagate size-like-ness
+                            if i0 in self.size_like:
+                                self.size_like.add(i1)
+                            self._set_replacement(i0, d * i1, "divisibility")
+
+            except NotImplementedError:
+                pass
+        return
+
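A small standalone check of the arithmetic behind the "divisibility" replacement above: from Mod(c * i0, q) == 0 the code substitutes i0 = (q / gcd(q, c)) * i1 for a fresh symbol i1 (here a plain sympy symbol standing in for `create_unbacked_symint()`):

```python
import sympy

c, q = sympy.Integer(6), sympy.Integer(4)
i1 = sympy.Symbol("i1", integer=True, nonnegative=True)

d = q / sympy.gcd(q, c)   # 4 / gcd(4, 6) == 2
replacement = d * i1      # i0 := 2*i1

# With this substitution, Mod(c * i0, q) vanishes for any integer value of i1:
for k in range(5):
    assert sympy.Mod(c * replacement, q).subs(i1, k) == 0
```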
+    # See: Note - On 0/1 specialization
+    # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
+    # as a sentinel sometimes.  Your sizevar isn't going to be
+    # anywhere near the max 64-bit integer anyway.
+    def _default_value_range(self) -> ValueRanges:
+        lower = 2 if self.specialize_zero_one else 0
+        return ValueRanges(lower, sys.maxsize - 1)
+
+    def _default_unspecified_value_range(self) -> ValueRanges:
+        return ValueRanges(-sys.maxsize - 1, sys.maxsize)
+
+    @_lru_cache
+    def _simplify_floor_div(self, expr):
+        floor_divs = tuple(expr.atoms(FloorDiv))
+        # we expect floor_divs to be exact,
+        # and thus add the guards for the exact floordivs,
+        # even if tracing doesn't require them otherwise
+        for fd in reversed(floor_divs):
+            base, divisor = fd.args
+            mod_expr = Mod(base, divisor)
+            eq_expr = sympy.Eq(mod_expr, 0)
+            # add necessary mod guards
+            self.evaluate_expr(eq_expr)
+        return self.simplify(expr)
+
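To see why guarding `Mod(base, divisor) == 0` makes the floor divisions exact, here is a standalone sympy sketch using `sympy.floor` as a stand-in for torch's `FloorDiv`:

```python
import sympy

a = sympy.Symbol("a", integer=True, positive=True)
k = sympy.Symbol("k", integer=True, positive=True)

exact = sympy.floor(a / 8)      # stand-in for FloorDiv(a, 8)
# Once Mod(a, 8) == 0 is guarded we may write a = 8*k, and the floor disappears:
assert exact.subs(a, 8 * k) == k
```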
+    # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen
+    # and if so issue a warning
+    def _check_frozen(self, expr, concrete_val):
+        if self.frozen:
+            self.counter["ignored_backward_guard"] += 1
+            signpost_event(
+                "dynamic",
+                "evaluate_expr_frozen",
+                {
+                    **self.co_fields,
+                    "ignored_guard": f"{expr} == {concrete_val}",
+                    # no version = original state (this signpost is expected)
+                    # version 2 = dynamic backwards is eagerly compiled
+                    "version": 2,
+                },
+            )
+            log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
+
+
+    def _get_stack_summary(self, is_debug: bool = False):
+        fsummary = None
+        frame = inspect.currentframe()
+        try:
+            while frame is not None:
+                if frame.f_code.co_filename not in uninteresting_files():
+                    fsummary = traceback.FrameSummary(
+                        frame.f_code.co_filename,
+                        frame.f_lineno,
+                        frame.f_code.co_name,
+                    )
+                    break
+                frame = frame.f_back
+        finally:
+            del frame
+
+        # NB: this stack is truncated, but it's fine because the main
+        # stack_info will give you the rest of the info you need
+        maybe_user_loc = ""
+        user_tb = TracingContext.extract_stack()
+        if user_tb:
+            maybe_user_loc = " at " + format_frame(user_tb[-1])
+
+        maybe_extra_debug = ""
+        if is_debug and user_tb:
+            maybe_extra_debug = (
+                '\nUser Stack (most recent call last):\n' +
+                '  (snipped, see stack below for prefix)\n' +
+                ''.join(traceback.format_list(user_tb))
+            )
+        if is_debug and config.extended_debug_cpp:
+            cpp_stack = CapturedTraceback.extract(cpp=True)
+            maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format())
+
+        return fsummary, maybe_user_loc, maybe_extra_debug
+
+    def _log_guard(self, prefix: str, g, forcing_spec: bool):
+        if self.log.isEnabledFor(logging.INFO):
+            str_g = str(g)
+            is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added
+            fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
+            self.log.info(
+                "%s %s [guard added]%s (%s)%s",
+                prefix if not forcing_spec else f"{prefix} (forcing_spec)",
+                str_g,
+                maybe_user_loc,
+                format_frame(fsummary),
+                maybe_extra_debug,
+                stack_info=is_debug,
+            )
+
+    @lru_cache(256)
+    @record_shapeenv_event(save_tracked_fakes=True)
+    def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
+                      expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False):
+        """
+        Given an expression, evaluates it, adding guards if necessary
+        """
+
+        # TODO: split conjunctions and evaluate them separately
+
+        @lru_cache(None)
+        def compute_concrete_val():
+            if hint is None:
+                return self.size_hint(orig_expr)
+            else:
+                return sympy.sympify(hint)
+
+        # Check if:
+        #   1. 'translation_validation' is set
+        #   2. the corresponding 'fx_node' is not 'None'
+        #   3. the guard should not be suppressed
+        #   4. the evaluation is not size-oblivious
+        #
+        # If all of the above hold, we create an FX node representing the
+        # actual expression to be guarded.
+        node = None
+        fresh = False
+        if (
+                self._translation_validation_enabled
+                and fx_node is not None
+                and not self._suppress_guards_tls()
+                and not size_oblivious
+        ):
+            concrete_val = compute_concrete_val()
+            if concrete_val is sympy.true:
+                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            elif concrete_val is sympy.false:
+                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+            else:
+                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
+                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+            assert node is not None
+            # If this is a fresh node, we have to remember the event index that
+            # corresponds to this assertion node.
+            # Reason: so that, given an assertion node, we can replay the ShapeEnv
+            # events until the point where this assertion node was freshly created.
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        # After creating the FX node corresponding to orig_expr, we must make sure that
+        # no error will be raised until the end of this function.
+        #
+        # Reason: the translation validation may become invalid otherwise.
+        #
+        # If an error is raised before the end of this function, we remove the FX node
+        # inserted, and re-raise the error.
+        guard = None
+        tb = None
+
+        try:
+            if orig_expr.is_number:
+                self.log.debug("eval %s [trivial]", orig_expr)
+                # NB: don't test float as there may be precision issues
+                if isinstance(hint, (int, bool)):
+                    assert orig_expr == hint, f"{orig_expr} != {hint}"
+                return orig_expr
+
+            expr = orig_expr
+
+            static_expr = self._maybe_evaluate_static(expr,
+                                                      expect_rational=expect_rational,
+                                                      size_oblivious=size_oblivious)
+            if static_expr is not None:
+                self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
+                # NB: don't test float as there may be precision issues
+                if isinstance(hint, (int, bool)):
+                    assert static_expr == hint, f"{static_expr} != {hint}"
+                return static_expr
+
+            if not (expr.free_symbols <= self.var_to_val.keys()):
+                # TODO: dedupe this with _maybe_evaluate_static
+                # Attempt to eliminate the unbacked SymInt
+                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                    size_oblivious_result = None
+                    if not size_oblivious:
+                        size_oblivious_result = self._maybe_evaluate_static(
+                            expr,
+                            expect_rational=expect_rational,
+                            size_oblivious=True
+                        )
+
+                    raise self._make_data_dependent_error(
+                        expr.xreplace(self.var_to_val),
+                        expr,
+                        size_oblivious_result=size_oblivious_result
+                    )
+                expr = new_expr
+
+            concrete_val = compute_concrete_val()
+            self._check_frozen(expr, concrete_val)
+
+            if (
+                    config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                    and isinstance(hint, bool)
+                    and isinstance(expr, (sympy.Eq, sympy.Ne))
+            ):
+                expr = sympy.Not(expr)
+
+            # Turn this into a boolean expression; we no longer need to consult
+            # concrete_val after this point.
+            suppress_maybe_guard_rel = False
+            if concrete_val is sympy.true:
+                g = expr
+            elif concrete_val is sympy.false:
+                g = sympy.Not(expr)
+            else:
+                # WARNING: we cannot actually do simplifications on guards
+                # on floating point values, because Sympy generally does not
+                # think expressions on integers can ever be equal to floating
+                # point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False).  Without
+                # very clear algebraic laws that hold for floating point, such
+                # simplifications are error prone anyway, so be sure not to
+                # maybe_guard_rel in those cases.
+                if not isinstance(concrete_val, sympy.Integer):
+                    suppress_maybe_guard_rel = True
+                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+            if isinstance(g, sympy.Rel):
+                # TODO: If we successfully eliminate a symbol via equality, it
+                # is not actually necessary to save a guard for the equality,
+                # as we will implicitly generate a guard when we match that
+                # input against the symbol.  Probably the easiest way to
+                # implement this is to have maybe_guard_rel return a bool
+                # saying if it "subsumed" the guard (and therefore the guard
+                # is no longer necessary)
+                self._maybe_guard_rel(g)
+
+            if not self._suppress_guards_tls():
+                stack = CapturedTraceback.extract(skip=1)
+                guard = ShapeGuard(g, stack)
+                # TODO: deal with duplicate guards somehow
+                self.guards.append(guard)
+        except Exception:
+            if fresh:
+                self._remove_fx_node(node)
+            raise
+        else:
+            if not self._suppress_guards_tls():
+                assert guard is not None
+
+                self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                for s in g.free_symbols:
+                    self.symbol_guard_counter[s] += 1
+                    # Forcing_spec to avoid infinite recursion
+                    if (
+                        not forcing_spec and
+                        config.symbol_guard_limit_before_specialize is not None and
+                        self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize
+                    ):
+                        # Force specialization
+                        self.log.info(
+                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                            config.symbol_guard_limit_before_specialize,
+                            s
+                        )
+                        self.evaluate_expr(s, forcing_spec=True)
+            else:
+                self.log.debug("eval %s [guard suppressed]", g)
+
+        return concrete_val
+
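The branch above that folds `expr` and the observed `concrete_val` into a single boolean guard `g` can be sketched in isolation (an illustrative helper, not the ShapeEnv API):

```python
import sympy

def as_boolean_guard(expr, concrete_val):
    if concrete_val is sympy.true:
        return expr
    if concrete_val is sympy.false:
        return sympy.Not(expr)
    # Non-boolean evaluation: guard that the expression equals the observed value.
    return sympy.Eq(expr, concrete_val)

s0 = sympy.Symbol("s0", integer=True, positive=True)
print(as_boolean_guard(sympy.Eq(s0, 3), sympy.true))   # Eq(s0, 3)
print(as_boolean_guard(sympy.Eq(s0, 3), sympy.false))  # Ne(s0, 3)
print(as_boolean_guard(s0 + 1, sympy.Integer(4)))      # Eq(s0 + 1, 4)
```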
+    def cleanup(self):
+        """
+        Break reference cycles.
+
+        This destroys the stacks. If you really want to keep them, we
+        just need some way to break references on code objects.
+        """
+        for g in self.guards:
+            g.stack.cleanup()
+        for s in self.var_to_stack.values():
+            s.cleanup()
+        for ras in self.deferred_runtime_asserts.values():
+            for ra in ras:
+                ra.stack.cleanup()
+
+    @record_shapeenv_event(save_tracked_fakes=True)
+    def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
+        """Create an assert that is checked at runtime
+
+        Args:
+            orig_expr (sympy.Expr): Boolean expression to assert is true
+            msg (str): Message to display on assertion failure
+            fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
+                to the expression, if applicable
+
+        """
+        expr = orig_expr
+
+        # TODO: split conjunctions and evaluate them separately
+
+        static_expr = self._maybe_evaluate_static(expr)
+        if static_expr is not None:
+            self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr)
+            return static_expr
+
+        # Attempt to eliminate the unbacked SymInt
+        new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+        if new_expr.free_symbols <= self.var_to_val.keys():
+            # Do a normal guard
+            return self.evaluate_expr(new_expr, fx_node=fx_node)
+        # NB: Don't use new_expr as expr; it could contain gunk like shape0
+        # which we don't want to guard on
+
+        # OK, we're definitely doing a runtime assert now
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+        ):
+            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            assert node is not None
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        self._check_frozen(expr, sympy.true)
+
+        # eliminate symbols on equality tests / refine ranges
+        if isinstance(expr, sympy.Rel):
+            self._maybe_guard_rel(expr)
+
+        if not self._suppress_guards_tls():
+            # canonicalise to remove equations that are trivially equal
+            orig_expr = expr
+            expr = canonicalize_bool_expr(expr)
+            stack = CapturedTraceback.extract(skip=1)
+            ra = RuntimeAssert(expr, msg, stack)
+            # TODO: Do this in a way that is less janky than int(s.name[1:])
+            cands = sorted([s for s in expr.free_symbols if s.name.startswith("u")], key=lambda s: int(s.name[1:]))
+            self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra)
+            self.num_deferred_runtime_asserts += 1
+            self._update_version_counter()
+            self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
+        else:
+            self.log.debug("runtime_assert %s [guard suppressed]", expr)
+
+        return True
+
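A small sketch of how each deferred assert is keyed: among the unbacked symbols ("u0", "u1", ...) appearing in the expression, the highest-numbered one is chosen, so the assert is only emitted once that symbol exists:

```python
import sympy

expr = sympy.Lt(sympy.Symbol("u0") + sympy.Symbol("u3"), sympy.Symbol("s2"))
cands = sorted(
    (s for s in expr.free_symbols if s.name.startswith("u")),
    key=lambda s: int(s.name[1:]),
)
print(cands[-1])  # u3 -> deferred_runtime_asserts[u3].append(ra)
```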
+    # Refines the ranges of the variables present in 'expr'.
+    #
+    # This function tries to refine the range of each variable inside
+    # 'expr' by reasoning about it, specifically when 'expr' is a
+    # 'sympy.Relational' operation.
+    #
+    # It mainly does 3 things:
+    #   1. Tries to isolate a variable on the left-hand side
+    #   2. Computes the value range of the right-hand side
+    #   3. Updates the value range of the variable, if the new range is tighter
+    def _refine_ranges(self, expr: sympy.Expr) -> None:
+        expr = self.simplify(expr)
+
+        for symbol in expr.free_symbols:
+            assert isinstance(symbol, sympy.Symbol)
+
+            if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
+                # Skip var_to_range logic for SingletonInt which is only used
+                # for jagged layout NestedTensors today
+                continue
+
+            r = try_solve(expr, symbol)
+
+            if r is None or not (symbol.is_integer and r[1].is_integer):
+                # Range refinement only supports integer symbols for now.
+                # There are lots of SymPy bugs when it comes to comparing
+                # reals and integers, so we skip that for now.
+                continue
+
+            r_expr, rhs = r
+            vr = self.var_to_range[symbol]
+            lower, upper = vr.lower, vr.upper
+
+            rhs_vr = bound_sympy(rhs, self.var_to_range)
+            _assert_bound_is_rational(rhs, rhs_vr)
+
+            # Let's suppose that we have a preexisting range for x [0, 100].
+            # Now, we issue a guard x > y, where the range for y is [50, 150].
+            # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
+            # refining x to [51, 100], since x must be greater than y, but the lowest
+            # y could be is 50.
+            #
+            # sympy.Eq may update both lower and upper bounds.
+            # sympy.G{t,e} may update the lower bound, only.
+            # sympy.L{t,e} may update the upper bound, only.
+            if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)):
+                # Strictly greater relations allow us to refine a bit more, since
+                # x > y implies that the lower bound for x is: y + 1.
+                lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
+            if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)):
+                upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))
+
+            # Do nothing if the new value range is no better than what we already have.
+            if vr == ValueRanges(lower, upper):
+                continue
+
+            # Updates the range and the guards corresponding to each bound of the symbol.
+            self.var_to_range[symbol] = ValueRanges(lower, upper)
+            # Clears the cache, since this update can change the result.
+            self._maybe_evaluate_static.cache_clear()
+
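The refinement arithmetic described in the comment above, checked in isolation with a hypothetical helper (integer symbols assumed):

```python
def refine_lower_for_gt(x_range, y_range):
    # x > y: x's lower bound becomes y's lower bound + 1 (strict, integer case).
    lower, upper = x_range
    y_lower, _ = y_range
    if lower < y_lower:
        lower = y_lower + 1
    return (lower, upper)

# x in [0, 100], guard x > y with y in [50, 150]  =>  x refined to [51, 100]
print(refine_lower_for_gt((0, 100), (50, 150)))  # (51, 100)
```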
+def _is_int(expr):
+    return isinstance(expr, SymInt) and expr.node.expr.is_number
+
+# WARNING: This is legacy, DO NOT USE
+def _is_dim_dynamic(t, d):
+    return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__init__.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e36933d6bf060e76a60df991002ac37ef52440
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__init__.py
@@ -0,0 +1,4 @@
+# mypy: disable-error-code=attr-defined
+from .core import unify, reify  # noqa: F403
+from .more import unifiable  # noqa: F403
+from .variable import var, isvar, vars, variables, Var  # noqa: F403
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a7e074ec4d8ee3801985889f4a60e608f4d42e1
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65d0ef03b065aafedcc7c76c61b1355a015bb9d6
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2beff5703af8b4f703958fba5db2c662f4380b4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e892fbbd0e2d94a4faab12fc195ccb28fb939a53
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43f41e3b7a86313f551ee92c6c5b958833bd4390
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1da9101f537cb0f6cf0fdcb428cfe3ba99ae9e10
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79887ebff41997b16520f5ac7ea617dd08d21e40
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac168095063f023aa84379ff3c3f361402694acc
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/core.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..5594c534bb8014f47564ec2eac4488f0887d849a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/core.py
@@ -0,0 +1,118 @@
+from collections.abc import Iterator  # type: ignore[import]
+from functools import partial
+
+from .unification_tools import assoc  # type: ignore[import]
+from .utils import transitive_get as walk
+from .variable import isvar
+from .dispatch import dispatch
+
+__all__ = ["reify", "unify"]
+
+###############
+# Reification #
+###############
+
+@dispatch(Iterator, dict)
+def _reify(t, s):
+    return map(partial(reify, s=s), t)
+    # return (reify(arg, s) for arg in t)
+_reify
+
+@dispatch(tuple, dict)  # type: ignore[no-redef]
+def _reify(t, s):
+    return tuple(reify(iter(t), s))
+_reify
+
+@dispatch(list, dict)  # type: ignore[no-redef]
+def _reify(t, s):
+    return list(reify(iter(t), s))
+_reify
+
+@dispatch(dict, dict)  # type: ignore[no-redef]
+def _reify(d, s):
+    return {k: reify(v, s) for k, v in d.items()}
+_reify
+
+@dispatch(object, dict)  # type: ignore[no-redef]
+def _reify(o, s):
+    return o  # catch all, just return the object
+
+def reify(e, s):
+    """ Replace variables of expression with substitution
+    >>> # xdoctest: +SKIP
+    >>> x, y = var(), var()
+    >>> e = (1, x, (3, y))
+    >>> s = {x: 2, y: 4}
+    >>> reify(e, s)
+    (1, 2, (3, 4))
+    >>> e = {1: x, 3: (y, 5)}
+    >>> reify(e, s)
+    {1: 2, 3: (4, 5)}
+    """
+    if isvar(e):
+        return reify(s[e], s) if e in s else e
+    return _reify(e, s)
+
+###############
+# Unification #
+###############
+
+seq = tuple, list, Iterator
+
+@dispatch(seq, seq, dict)
+def _unify(u, v, s):
+    if len(u) != len(v):
+        return False
+    for uu, vv in zip(u, v):  # avoiding recursion
+        s = unify(uu, vv, s)
+        if s is False:
+            return False
+    return s
+#
+# @dispatch((set, frozenset), (set, frozenset), dict)
+# def _unify(u, v, s):
+#     i = u & v
+#     u = u - i
+#     v = v - i
+#     return _unify(sorted(u), sorted(v), s)
+#
+#
+# @dispatch(dict, dict, dict)
+# def _unify(u, v, s):
+#     if len(u) != len(v):
+#         return False
+#     for key, uval in iteritems(u):
+#         if key not in v:
+#             return False
+#         s = unify(uval, v[key], s)
+#         if s is False:
+#             return False
+#     return s
+#
+#
+# @dispatch(object, object, dict)
+# def _unify(u, v, s):
+#     return False  # catch all
+
+
+@dispatch(object, object, dict)
+def unify(u, v, s):  # no check at the moment
+    """ Find substitution so that u == v while satisfying s
+    >>> x = var('x')
+    >>> unify((1, x), (1, 2), {})
+    {~x: 2}
+    """
+    u = walk(u, s)
+    v = walk(v, s)
+    if u == v:
+        return s
+    if isvar(u):
+        return assoc(s, u, v)
+    if isvar(v):
+        return assoc(s, v, u)
+    return _unify(u, v, s)
+unify
+
+@dispatch(object, object)  # type: ignore[no-redef]
+def unify(u, v):
+    return unify(u, v, {})
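A hedged usage sketch of the `reify`/`unify` pair defined in this file, assuming the vendored package imports cleanly from this checked-in path:

```python
from torch.fx.experimental.unification import unify, reify, var

x, y = var("x"), var("y")
s = unify((1, x, (3, y)), (1, 2, (3, 4)), {})
print(s)                     # {~x: 2, ~y: 4}
print(reify((x, y, x), s))   # (2, 4, 2)
```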
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/dispatch.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d6d2d7efde128dd7fa9f78f414df757b3d87a5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/dispatch.py
@@ -0,0 +1,6 @@
+from functools import partial
+from .multipledispatch import dispatch  # type: ignore[import]
+
+namespace = {}  # type: ignore[var-annotated]
+
+dispatch = partial(dispatch, namespace=namespace)
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/match.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/match.py
new file mode 100644
index 0000000000000000000000000000000000000000..56c04e0134e231396daa6d5f36563365a4409752
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/match.py
@@ -0,0 +1,121 @@
+from .core import unify, reify  # type: ignore[attr-defined]
+from .variable import isvar
+from .utils import _toposort, freeze
+from .unification_tools import groupby, first  # type: ignore[import]
+
+
+class Dispatcher:
+    def __init__(self, name):
+        self.name = name
+        self.funcs = {}
+        self.ordering = []
+
+    def add(self, signature, func):
+        self.funcs[freeze(signature)] = func
+        self.ordering = ordering(self.funcs)
+
+    def __call__(self, *args, **kwargs):
+        func, s = self.resolve(args)
+        return func(*args, **kwargs)
+
+    def resolve(self, args):
+        n = len(args)
+        for signature in self.ordering:
+            if len(signature) != n:
+                continue
+            s = unify(freeze(args), signature)
+            if s is not False:
+                result = self.funcs[signature]
+                return result, s
+        raise NotImplementedError("No match found. \nKnown matches: "
+                                  + str(self.ordering) + "\nInput: " + str(args))
+
+    def register(self, *signature):
+        def _(func):
+            self.add(signature, func)
+            return self
+        return _
+
+
+class VarDispatcher(Dispatcher):
+    """ A dispatcher that calls functions with variable names
+    >>> # xdoctest: +SKIP
+    >>> d = VarDispatcher('d')
+    >>> x = var('x')
+    >>> @d.register('inc', x)
+    ... def f(x):
+    ...     return x + 1
+    >>> @d.register('double', x)
+    ... def f(x):
+    ...     return x * 2
+    >>> d('inc', 10)
+    11
+    >>> d('double', 10)
+    20
+    """
+    def __call__(self, *args, **kwargs):
+        func, s = self.resolve(args)
+        d = {k.token: v for k, v in s.items()}
+        return func(**d)
+
+
+global_namespace = {}  # type: ignore[var-annotated]
+
+
+def match(*signature, **kwargs):
+    namespace = kwargs.get('namespace', global_namespace)
+    dispatcher = kwargs.get('Dispatcher', Dispatcher)
+
+    def _(func):
+        name = func.__name__
+
+        if name not in namespace:
+            namespace[name] = dispatcher(name)
+        d = namespace[name]
+
+        d.add(signature, func)
+
+        return d
+    return _
+
+
+def supercedes(a, b):
+    """ ``a`` is a more specific match than ``b`` """
+    if isvar(b) and not isvar(a):
+        return True
+    s = unify(a, b)
+    if s is False:
+        return False
+    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
+    if reify(a, s) == a:
+        return True
+    if reify(b, s) == b:
+        return False
+
+
+# Taken from multipledispatch
+def edge(a, b, tie_breaker=hash):
+    """ A should be checked before B
+    Tie broken by tie_breaker, defaults to ``hash``
+    """
+    if supercedes(a, b):
+        if supercedes(b, a):
+            return tie_breaker(a) > tie_breaker(b)
+        else:
+            return True
+    return False
+
+
+# Taken from multipledispatch
+def ordering(signatures):
+    """ A sane ordering of signatures to check, first to last
+    Topological sort of edges as given by ``edge`` and ``supercedes``
+    """
+    signatures = list(map(tuple, signatures))
+    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
+    edges = groupby(first, edges)
+    for s in signatures:
+        if s not in edges:
+            edges[s] = []
+    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[attr-defined, assignment]
+    return _toposort(edges)
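The `ordering` function above topologically sorts signatures so that more specific ones are tried first. A self-contained sketch of the same idea; `more_specific` and `toposort` are illustrative stand-ins, not this module's `supercedes`/`_toposort`:

```python
from collections import defaultdict, deque

def more_specific(a, b):
    # For plain type tuples: a goes before b if every slot of a subclasses b's slot.
    return len(a) == len(b) and a != b and all(map(issubclass, a, b))

def toposort(nodes, edges):
    indeg = {n: 0 for n in nodes}
    succ = defaultdict(list)
    for a, b in edges:
        succ[a].append(b)
        indeg[b] += 1
    queue = deque(n for n in nodes if indeg[n] == 0)
    out = []
    while queue:
        n = queue.popleft()
        out.append(n)
        for m in succ[n]:
            indeg[m] -= 1
            if indeg[m] == 0:
                queue.append(m)
    return out

sigs = [(object, object), (int, object), (int, int)]
edges = [(a, b) for a in sigs for b in sigs if more_specific(a, b)]
print(toposort(sigs, edges))  # [(int, int), (int, object), (object, object)]
```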
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/more.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/more.py
new file mode 100644
index 0000000000000000000000000000000000000000..86d13b155b6de946802bd1459ecf7a2f2783c909
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/more.py
@@ -0,0 +1,117 @@
+from .core import unify, reify  # type: ignore[attr-defined]
+from .dispatch import dispatch
+
+
+def unifiable(cls):
+    """ Register standard unify and reify operations on class
+    This uses the type and __dict__ or __slots__ attributes to define the
+    nature of the term
+    See Also:
+    >>> # xdoctest: +SKIP
+    >>> class A(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    >>> unifiable(A)
+    <class 'unification.more.A'>
+    >>> x = var('x')
+    >>> a = A(1, 2)
+    >>> b = A(1, x)
+    >>> unify(a, b, {})
+    {~x: 2}
+    """
+    _unify.add((cls, cls, dict), unify_object)
+    _reify.add((cls, dict), reify_object)
+
+    return cls
+
+
+#########
+# Reify #
+#########
+
+
+def reify_object(o, s):
+    """ Reify a Python object with a substitution
+    >>> # xdoctest: +SKIP
+    >>> class Foo(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    ...     def __str__(self):
+    ...         return "Foo(%s, %s)"%(str(self.a), str(self.b))
+    >>> x = var('x')
+    >>> f = Foo(1, x)
+    >>> print(f)
+    Foo(1, ~x)
+    >>> print(reify_object(f, {x: 2}))
+    Foo(1, 2)
+    """
+    if hasattr(o, '__slots__'):
+        return _reify_object_slots(o, s)
+    else:
+        return _reify_object_dict(o, s)
+
+
+def _reify_object_dict(o, s):
+    obj = object.__new__(type(o))
+    d = reify(o.__dict__, s)
+    if d == o.__dict__:
+        return o
+    obj.__dict__.update(d)
+    return obj
+
+
+def _reify_object_slots(o, s):
+    attrs = [getattr(o, attr) for attr in o.__slots__]
+    new_attrs = reify(attrs, s)
+    if attrs == new_attrs:
+        return o
+    else:
+        newobj = object.__new__(type(o))
+        for slot, attr in zip(o.__slots__, new_attrs):
+            setattr(newobj, slot, attr)
+        return newobj
+
+
+@dispatch(slice, dict)
+def _reify(o, s):
+    """ Reify a Python ``slice`` object """
+    return slice(*reify((o.start, o.stop, o.step), s))
+
+
+#########
+# Unify #
+#########
+
+
+def unify_object(u, v, s):
+    """ Unify two Python objects
+    Unifies their type and ``__dict__`` attributes
+    >>> # xdoctest: +SKIP
+    >>> class Foo(object):
+    ...     def __init__(self, a, b):
+    ...         self.a = a
+    ...         self.b = b
+    ...     def __str__(self):
+    ...         return "Foo(%s, %s)"%(str(self.a), str(self.b))
+    >>> x = var('x')
+    >>> f = Foo(1, x)
+    >>> g = Foo(1, 2)
+    >>> unify_object(f, g, {})
+    {~x: 2}
+    """
+    if type(u) != type(v):
+        return False
+    if hasattr(u, '__slots__'):
+        return unify([getattr(u, slot) for slot in u.__slots__],
+                     [getattr(v, slot) for slot in v.__slots__],
+                     s)
+    else:
+        return unify(u.__dict__, v.__dict__, s)
+
+
+@dispatch(slice, slice, dict)
+def _unify(u, v, s):
+    """ Unify a Python ``slice`` object """
+    return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)
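A hedged usage sketch for the `slice` handlers registered above; importing the vendored package runs these `@dispatch` registrations, assuming it imports cleanly at this path:

```python
from torch.fx.experimental.unification import unify, reify, var

x = var("x")
print(unify(slice(0, x, None), slice(0, 10, None), {}))  # expected: {~x: 10}
print(reify(slice(0, x, None), {x: 10}))                 # slice(0, 10, None)
```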
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..527b66d546e372ed1bfd1f1f5ec9ab3772436979
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
@@ -0,0 +1,3 @@
+from .core import dispatch
+from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
+                         MDNotImplementedError)
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1199d4f22232280ad60ca9c6881242fd2f8acb88
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c587ef5c96b54a49ea509eb0f1607bba6dc75b0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd351e35bd62ad0cf613ecfd3dc7d0062ae62575
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a6561dea6279d24b774c5d239c90e73362bdc12
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3037e8169aa2b76041bd502c1c939168418bae29
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de6dbb990a26986b5e25e66c014f7846315f7b86
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
new file mode 100644
index 0000000000000000000000000000000000000000..021f4fdf5c9b5fa715d3933e2b00aaddcfd73d66
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
@@ -0,0 +1,119 @@
+from .utils import _toposort, groupby
+from .variadic import isvariadic
+
+__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
+           "edge", "ordering"]
+
+class AmbiguityWarning(Warning):
+    pass
+
+
+def supercedes(a, b):
+    """ A is consistent and strictly more specific than B """
+    if len(a) < len(b):
+        # only case is if a is empty and b is variadic
+        return not a and len(b) == 1 and isvariadic(b[-1])
+    elif len(a) == len(b):
+        return all(map(issubclass, a, b))
+    else:
+        # len(a) > len(b)
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                assert p1 == len(a) - 1
+                return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
+            elif isvariadic(cur_b):
+                assert p2 == len(b) - 1
+                if not issubclass(cur_a, cur_b):
+                    return False
+                p1 += 1
+        return p2 == len(b) - 1 and p1 == len(a)
+
+
+def consistent(a, b):
+    """ It is possible for an argument list to satisfy both A and B """
+
+    # Need to check for empty args
+    if not a:
+        return not b or isvariadic(b[0])
+    if not b:
+        return not a or isvariadic(a[0])
+
+    # Non-empty args check for mutual subclasses
+    if len(a) == len(b):
+        return all(issubclass(aa, bb) or issubclass(bb, aa)
+                   for aa, bb in zip(a, b))
+    else:
+        p1 = 0
+        p2 = 0
+        while p1 < len(a) and p2 < len(b):
+            cur_a = a[p1]
+            cur_b = b[p2]
+            if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
+                return False
+            if not (isvariadic(cur_a) or isvariadic(cur_b)):
+                p1 += 1
+                p2 += 1
+            elif isvariadic(cur_a):
+                p2 += 1
+            elif isvariadic(cur_b):
+                p1 += 1
+        # We only need to check for variadic ends
+        # Variadic types are guaranteed to be the last element
+        return (isvariadic(cur_a) and p2 == len(b) or  # type: ignore[possibly-undefined]
+                isvariadic(cur_b) and p1 == len(a))  # type: ignore[possibly-undefined]
+
+
+def ambiguous(a, b):
+    """ A is consistent with B but neither is strictly more specific """
+    return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
+
+
+def ambiguities(signatures):
+    """ All signature pairs such that A is ambiguous with B """
+    signatures = list(map(tuple, signatures))
+    return {(a, b) for a in signatures for b in signatures
+            if hash(a) < hash(b)
+            and ambiguous(a, b)
+            and not any(supercedes(c, a) and supercedes(c, b)
+            for c in signatures)}
+
+
+def super_signature(signatures):
+    """ A signature that would break ambiguities """
+    n = len(signatures[0])
+    assert all(len(s) == n for s in signatures)
+
+    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
+            for i in range(n)]
+
+
+def edge(a, b, tie_breaker=hash):
+    """ A should be checked before B
+    Tie broken by tie_breaker, defaults to ``hash``
+    """
+    # A is checked before B if A supercedes B and either B does not supercede A,
+    # or B does and the tie_breaker prefers A.
+    return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
+
+
+def ordering(signatures):
+    """ A sane ordering of signatures to check, first to last
+    Topological sort of edges as given by ``edge`` and ``supercedes``
+    """
+    signatures = list(map(tuple, signatures))
+    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
+    edges = groupby(lambda x: x[0], edges)
+    for s in signatures:
+        if s not in edges:
+            edges[s] = []
+    edges = {k: [b for a, b in v] for k, v in edges.items()}  # type: ignore[assignment, attr-defined]
+    return _toposort(edges)
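A quick illustration of these helpers on the classic ambiguous pair `(int, object)` vs `(object, int)`, assuming the vendored module imports at the checked-in path:

```python
from torch.fx.experimental.unification.multipledispatch.conflict import (
    ambiguous, consistent, super_signature, supercedes,
)

print(supercedes((int,), (object,)))                     # True: int is strictly more specific
print(consistent((int, object), (object, int)))          # True: e.g. (int, int) satisfies both
print(ambiguous((int, object), (object, int)))           # True: neither supercedes the other
print(super_signature([(int, object), (object, int)]))   # [<class 'int'>, <class 'int'>]: (int, int) breaks the tie
```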
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..e767aa4beab222d23e0ded688f620f2483e8c555
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
@@ -0,0 +1,83 @@
+import inspect
+import sys
+
+from .dispatcher import Dispatcher, MethodDispatcher
+
+global_namespace = {}  # type: ignore[var-annotated]
+
+__all__ = ["dispatch", "ismethod"]
+
+def dispatch(*types, **kwargs):
+    """ Dispatch function on the types of the inputs
+    Supports dispatch on all non-keyword arguments.
+    Collects implementations based on the function name.  Ignores namespaces.
+    If ambiguous type signatures occur a warning is raised when the function is
+    defined suggesting the additional method to break the ambiguity.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> @dispatch(int)
+        ... def f(x):
+        ...     return x + 1
+        >>> @dispatch(float)
+        ... def f(x):
+        ...     return x - 1
+        >>> # xdoctest: +SKIP
+        >>> f(3)
+        4
+        >>> f(3.0)
+        2.0
+        >>> # Specify an isolated namespace with the namespace keyword argument
+        >>> my_namespace = {}
+        >>> @dispatch(int, namespace=my_namespace)
+        ... def foo(x):
+        ...     return x + 1
+        >>> # Dispatch on instance methods within classes
+        >>> class MyClass(object):
+        ...     @dispatch(list)
+        ...     def __init__(self, data):
+        ...         self.data = data
+        ...     @dispatch(int)
+        ...     def __init__(self, datum):
+        ...         self.data = [datum]
+        >>> MyClass([1, 2, 3]).data
+        [1, 2, 3]
+        >>> MyClass(3).data
+        [3]
+    """
+    namespace = kwargs.get('namespace', global_namespace)
+
+    types = tuple(types)
+
+    def _df(func):
+        name = func.__name__
+
+        if ismethod(func):
+            dispatcher = inspect.currentframe().f_back.f_locals.get(  # type: ignore[union-attr]
+                name,  # type: ignore[union-attr]
+                MethodDispatcher(name),
+            )
+        else:
+            if name not in namespace:
+                namespace[name] = Dispatcher(name)
+            dispatcher = namespace[name]
+
+        dispatcher.add(types, func)
+        return dispatcher
+    return _df
+
+
+def ismethod(func):
+    """ Is func a method?
+    Note that this has to work as the method is defined but before the class is
+    defined.  At this stage methods look like functions.
+    """
+    if hasattr(inspect, "signature"):
+        signature = inspect.signature(func)
+        return signature.parameters.get('self', None) is not None
+    else:
+        if sys.version_info.major < 3:
+            spec = inspect.getargspec(func)  # type: ignore[attr-defined]
+        else:
+            spec = inspect.getfullargspec(func)  # type: ignore[union-attr, assignment]
+        return spec and spec.args and spec.args[0] == 'self'
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb4a1d5ce86a9c4d59aade063b906a1b237318d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
@@ -0,0 +1,430 @@
+from warnings import warn
+import inspect
+from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
+from .utils import expand_tuples
+from .variadic import Variadic, isvariadic
+import itertools as itl
+
+__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
+           "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
+
+class MDNotImplementedError(NotImplementedError):
+    """ A NotImplementedError for multiple dispatch """
+
+
+def ambiguity_warn(dispatcher, ambiguities):
+    """ Raise warning when ambiguity is detected
+    Parameters
+    ----------
+    dispatcher : Dispatcher
+        The dispatcher on which the ambiguity was detected
+    ambiguities : set
+        Set of type signature pairs that are ambiguous within this dispatcher
+    See Also:
+        Dispatcher.add
+        warning_text
+    """
+    warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
+
+
+def halt_ordering():
+    """Deprecated interface to temporarily disable ordering.
+    """
+    warn(
+        'halt_ordering is deprecated, you can safely remove this call.',
+        DeprecationWarning,
+    )
+
+
+def restart_ordering(on_ambiguity=ambiguity_warn):
+    """Deprecated interface to temporarily resume ordering.
+    """
+    warn(
+        'restart_ordering is deprecated, if you would like to eagerly order '
+        'the dispatchers, you should call the ``reorder()`` method on each'
+        ' dispatcher.',
+        DeprecationWarning,
+    )
+
+
+def variadic_signature_matches_iter(types, full_signature):
+    """Check if a set of input types matches a variadic signature.
+    Notes
+    -----
+    The algorithm is as follows:
+    Initialize the current signature to the first in the sequence
+    For each type in `types`:
+        If the current signature is variadic
+            If the type matches the signature
+                yield True
+            Else
+                Try to get the next signature
+                If no signatures are left we can't possibly have a match
+                    so yield False
+        Else
+            yield True if the type matches the current signature
+            Get the next signature
+    """
+    sigiter = iter(full_signature)
+    sig = next(sigiter)
+    for typ in types:
+        matches = issubclass(typ, sig)
+        yield matches
+        if not isvariadic(sig):
+            # we're not matching a variadic argument, so move to the next
+            # element in the signature
+            sig = next(sigiter)
+    else:
+        try:
+            sig = next(sigiter)
+        except StopIteration:
+            assert isvariadic(sig)
+            yield True
+        else:
+            # We have signature items left over, so all of our arguments
+            # haven't matched
+            yield False
+
+
+def variadic_signature_matches(types, full_signature):
+    # No arguments always matches a variadic signature
+    assert full_signature
+    return all(variadic_signature_matches_iter(types, full_signature))
+
+
+class Dispatcher:
+    """ Dispatch methods based on type signature
+    Use ``dispatch`` to add implementations
+    Examples
+    --------
+    >>> # xdoctest: +SKIP("bad import name")
+    >>> from multipledispatch import dispatch
+    >>> @dispatch(int)
+    ... def f(x):
+    ...     return x + 1
+    >>> @dispatch(float)
+    ... def f(x):
+    ...     return x - 1
+    >>> f(3)
+    4
+    >>> f(3.0)
+    2.0
+    """
+    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
+
+    def __init__(self, name, doc=None):
+        self.name = self.__name__ = name
+        self.funcs = {}
+        self.doc = doc
+
+        self._cache = {}
+
+    def register(self, *types, **kwargs):
+        """ register dispatcher with new implementation
+        >>> # xdoctest: +SKIP
+        >>> f = Dispatcher('f')
+        >>> @f.register(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> @f.register(float)
+        ... def dec(x):
+        ...     return x - 1
+        >>> @f.register(list)
+        ... @f.register(tuple)
+        ... def reverse(x):
+        ...     return x[::-1]
+        >>> f(1)
+        2
+        >>> f(1.0)
+        0.0
+        >>> f([1, 2, 3])
+        [3, 2, 1]
+        """
+        def _df(func):
+            self.add(types, func, **kwargs)   # type: ignore[call-arg]
+            return func
+        return _df
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
+            return sig.parameters.values()
+
+    @classmethod
+    def get_func_annotations(cls, func):
+        """ get annotations of function positional parameters
+        """
+        params = cls.get_func_params(func)
+        if params:
+            Parameter = inspect.Parameter
+
+            params = (param for param in params
+                      if param.kind in
+                      (Parameter.POSITIONAL_ONLY,
+                       Parameter.POSITIONAL_OR_KEYWORD))
+
+            annotations = tuple(
+                param.annotation
+                for param in params)
+
+            if all(ann is not Parameter.empty for ann in annotations):
+                return annotations
+
+    def add(self, signature, func):
+        """ Add new types/method pair to dispatcher
+        >>> # xdoctest: +SKIP
+        >>> D = Dispatcher('add')
+        >>> D.add((int, int), lambda x, y: x + y)
+        >>> D.add((float, float), lambda x, y: x + y)
+        >>> D(1, 2)
+        3
+        >>> D(1, 2.0)
+        Traceback (most recent call last):
+        ...
+        NotImplementedError: Could not find signature for add: <int, float>
+        >>> # When ``add`` detects an ambiguity it calls the ``on_ambiguity`` callback
+        >>> # with the dispatcher itself, and a set of ambiguous type signature pairs
+        >>> # as inputs.  See ``ambiguity_warn`` for an example.
+        """
+        # Handle annotations
+        if not signature:
+            annotations = self.get_func_annotations(func)
+            if annotations:
+                signature = annotations
+
+        # Handle union types
+        if any(isinstance(typ, tuple) for typ in signature):
+            for typs in expand_tuples(signature):
+                self.add(typs, func)
+            return
+
+        new_signature = []
+
+        for index, typ in enumerate(signature, start=1):
+            if not isinstance(typ, (type, list)):
+                str_sig = ', '.join(c.__name__ if isinstance(c, type)
+                                    else str(c) for c in signature)
+                raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
+                                f"In signature: <{str_sig}>\n"
+                                f"In function: {self.name}")
+
+            # handle variadic signatures
+            if isinstance(typ, list):
+                if index != len(signature):
+                    raise TypeError(
+                        'Variadic signature must be the last element'
+                    )
+
+                if len(typ) != 1:
+                    raise TypeError(
+                        'Variadic signature must contain exactly one element. '
+                        'To use a variadic union type place the desired types '
+                        'inside of a tuple, e.g., [(int, str)]'
+                    )
+                new_signature.append(Variadic[typ[0]])
+            else:
+                new_signature.append(typ)
+
+        self.funcs[tuple(new_signature)] = func
+        self._cache.clear()
+
+        try:
+            del self._ordering
+        except AttributeError:
+            pass
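+
+    # Hedged usage sketch (hypothetical names): a trailing list marks a
+    # variadic tail, e.g.
+    #
+    #     D = Dispatcher('sum_all')
+    #     D.add((int, [int]), lambda first, *rest: first + sum(rest))
+    #
+    # which is stored internally under the signature (int, Variadic[int]).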
+
+    @property
+    def ordering(self):
+        try:
+            return self._ordering
+        except AttributeError:
+            return self.reorder()
+
+    def reorder(self, on_ambiguity=ambiguity_warn):
+        self._ordering = od = ordering(self.funcs)
+        amb = ambiguities(self.funcs)
+        if amb:
+            on_ambiguity(self, amb)
+        return od
+
+    def __call__(self, *args, **kwargs):
+        types = tuple([type(arg) for arg in args])
+        try:
+            func = self._cache[types]
+        except KeyError as e:
+            func = self.dispatch(*types)
+            if not func:
+                raise NotImplementedError(
+                    f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
+            self._cache[types] = func
+        try:
+            return func(*args, **kwargs)
+
+        except MDNotImplementedError as e:
+            funcs = self.dispatch_iter(*types)
+            next(funcs)  # skip the first match, which already raised MDNotImplementedError
+            for func in funcs:
+                try:
+                    return func(*args, **kwargs)
+                except MDNotImplementedError:
+                    pass
+
+            raise NotImplementedError(
+                "Matching functions for "
+                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
+
+    def __str__(self):
+        return f""
+    __repr__ = __str__
+
+    def dispatch(self, *types):
+        """Determine appropriate implementation for this type signature
+        This method is internal.  Users should call this object as a function.
+        Implementation resolution occurs within the ``__call__`` method.
+        >>> # xdoctest: +SKIP
+        >>> from multipledispatch import dispatch
+        >>> @dispatch(int)
+        ... def inc(x):
+        ...     return x + 1
+        >>> implementation = inc.dispatch(int)
+        >>> implementation(3)
+        4
+        >>> print(inc.dispatch(float))
+        None
+        See Also:
+          ``multipledispatch.conflict`` - module to determine resolution order
+        """
+
+        if types in self.funcs:
+            return self.funcs[types]
+
+        try:
+            return next(self.dispatch_iter(*types))
+        except StopIteration:
+            return None
+
+    def dispatch_iter(self, *types):
+
+        n = len(types)
+        for signature in self.ordering:
+            if len(signature) == n and all(map(issubclass, types, signature)):
+                result = self.funcs[signature]
+                yield result
+            elif len(signature) and isvariadic(signature[-1]):
+                if variadic_signature_matches(types, signature):
+                    result = self.funcs[signature]
+                    yield result
+
+    def resolve(self, types):
+        """ Determine appropriate implementation for this type signature
+        .. deprecated:: 0.4.4
+            Use ``dispatch(*types)`` instead
+        """
+        warn("resolve() is deprecated, use dispatch(*types)",
+             DeprecationWarning)
+
+        return self.dispatch(*types)
+
+    def __getstate__(self):
+        return {'name': self.name,
+                'funcs': self.funcs}
+
+    def __setstate__(self, d):
+        self.name = d['name']
+        self.funcs = d['funcs']
+        self._ordering = ordering(self.funcs)
+        self._cache = {}
+
+    @property
+    def __doc__(self):
+        docs = [f"Multiply dispatched method: {self.name}"]
+
+        if self.doc:
+            docs.append(self.doc)
+
+        other = []
+        for sig in self.ordering[::-1]:
+            func = self.funcs[sig]
+            if func.__doc__:
+                s = f'Inputs: <{str_signature(sig)}>\n'
+                s += '-' * len(s) + '\n'
+                s += func.__doc__.strip()
+                docs.append(s)
+            else:
+                other.append(str_signature(sig))
+
+        if other:
+            docs.append('Other signatures:\n    ' + '\n    '.join(other))
+
+        return '\n\n'.join(docs)
+
+    def _help(self, *args):
+        return self.dispatch(*map(type, args)).__doc__
+
+    def help(self, *args, **kwargs):
+        """ Print docstring for the function corresponding to inputs """
+        print(self._help(*args))
+
+    def _source(self, *args):
+        func = self.dispatch(*map(type, args))
+        if not func:
+            raise TypeError("No function found")
+        return source(func)
+
+    def source(self, *args, **kwargs):
+        """ Print source code for the function corresponding to inputs """
+        print(self._source(*args))
+
+
+def source(func):
+    s = f'File: {inspect.getsourcefile(func)}\n\n'
+    s = s + inspect.getsource(func)
+    return s
+
+
+class MethodDispatcher(Dispatcher):
+    """ Dispatch methods based on type signature
+    See Also:
+        Dispatcher
+    """
+    __slots__ = ('obj', 'cls')
+
+    @classmethod
+    def get_func_params(cls, func):
+        if hasattr(inspect, "signature"):
+            sig = inspect.signature(func)
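+            # Skip the first parameter (``self``); dispatch only on the rest.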
+            return itl.islice(sig.parameters.values(), 1, None)
+
+    def __get__(self, instance, owner):
+        self.obj = instance
+        self.cls = owner
+        return self
+
+    def __call__(self, *args, **kwargs):
+        types = tuple([type(arg) for arg in args])
+        func = self.dispatch(*types)
+        if not func:
+            raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
+        return func(self.obj, *args, **kwargs)
+
+
+def str_signature(sig):
+    """ String representation of type signature
+    >>> str_signature((int, float))
+    'int, float'
+    """
+    return ', '.join(cls.__name__ for cls in sig)
+
+
+def warning_text(name, amb):
+    """ The text for ambiguity warnings """
+    text = f"\nAmbiguities exist in dispatched function {name}\n\n"
+    text += "The following signatures may result in ambiguous behavior:\n"
+    for pair in amb:
+        text += "\t" + \
+                ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
+    text += "\n\nConsider making the following additions:\n\n"
+    text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+                         + f')\ndef {name}(...)' for s in amb])
+    return text
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd060d95eeefe3648e0de143271bbac0629b71e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
@@ -0,0 +1,125 @@
+from collections import OrderedDict
+
+__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
+
+def raises(err, lamda):
+    try:
+        lamda()
+        return False
+    except err:
+        return True
+
+
+def expand_tuples(L):
+    """
+    >>> expand_tuples([1, (2, 3)])
+    [(1, 2), (1, 3)]
+    >>> expand_tuples([1, 2])
+    [(1, 2)]
+    """
+    if not L:
+        return [()]
+    elif not isinstance(L[0], tuple):
+        rest = expand_tuples(L[1:])
+        return [(L[0],) + t for t in rest]
+    else:
+        rest = expand_tuples(L[1:])
+        return [(item,) + t for t in rest for item in L[0]]
+
+
+# Taken from theano/theano/gof/sched.py
+# Avoids licensing issues because this was written by Matthew Rocklin
+def _toposort(edges):
+    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
+    inputs:
+        edges - a dict of the form {a: {b, c}} where b and c depend on a
+    outputs:
+        L - an ordered list of nodes that satisfy the dependencies of edges
+    >>> _toposort({1: (2, 3), 2: (3, )})
+    [1, 2, 3]
+    >>> # Closely follows the wikipedia page [2]
+    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
+    >>> # Communications of the ACM
+    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
+    """
+    incoming_edges = reverse_dict(edges)
+    incoming_edges = OrderedDict((k, set(val))
+                                 for k, val in incoming_edges.items())
+    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
+    L = []
+
+    while S:
+        n, _ = S.popitem()
+        L.append(n)
+        for m in edges.get(n, ()):
+            assert n in incoming_edges[m]
+            incoming_edges[m].remove(n)
+            if not incoming_edges[m]:
+                S[m] = None
+    if any(incoming_edges.get(v, None) for v in edges):
+        raise ValueError("Input has cycles")
+    return L
+
+
+def reverse_dict(d):
+    """Reverses direction of dependence dict
+    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
+    >>> reverse_dict(d)  # doctest: +SKIP
+    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
+    :note: dict ordering is not guaranteed to be deterministic. Since we
+        iterate over the input dict, the output of this function depends
+        on the dict order, so its output order should be considered
+        non-deterministic.
+    """
+    result = OrderedDict()  # type: ignore[var-annotated]
+    for key in d:
+        for val in d[key]:
+            result[val] = result.get(val, tuple()) + (key, )
+    return result
+
+
+# Taken from toolz
+# Avoids licensing issues because this version was authored by Matthew Rocklin
+def groupby(func, seq):
+    """ Group a collection by a key function
+    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
+    >>> groupby(len, names)  # doctest: +SKIP
+    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
+    >>> iseven = lambda x: x % 2 == 0
+    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
+    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
+    See Also:
+        ``countby``
+    """
+
+    d = OrderedDict()  # type: ignore[var-annotated]
+    for item in seq:
+        key = func(item)
+        if key not in d:
+            d[key] = list()
+        d[key].append(item)
+    return d
+
+
+def typename(type):
+    """Get the name of `type`.
+    Parameters
+    ----------
+    type : Union[Type, Tuple[Type]]
+    Returns
+    -------
+    str
+        The name of `type` or a tuple of the names of the types in `type`.
+    Examples
+    --------
+    >>> typename(int)
+    'int'
+    >>> typename((int, float))
+    '(int, float)'
+    """
+    try:
+        return type.__name__
+    except AttributeError:
+        if len(type) == 1:
+            return typename(*type)
+        return f"({', '.join(map(typename, type))})"
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab008d6fcad276ef66076eb4a426bf3e390688cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
@@ -0,0 +1,91 @@
+from .utils import typename
+
+__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
+
+class VariadicSignatureType(type):
+    # checking if subclass is a subclass of self
+    def __subclasscheck__(cls, subclass):
+        other_type = (subclass.variadic_type if isvariadic(subclass)
+                      else (subclass,))
+        return subclass is cls or all(
+            issubclass(other, cls.variadic_type) for other in other_type  # type: ignore[attr-defined]
+        )
+
+    def __eq__(cls, other):
+        """
+        Return True if other has the same variadic type
+        Parameters
+        ----------
+        other : object (type)
+            The object (type) to check
+        Returns
+        -------
+        bool
+            Whether or not `other` is equal to `self`
+        """
+        return (isvariadic(other) and
+                set(cls.variadic_type) == set(other.variadic_type))  # type: ignore[attr-defined]
+
+    def __hash__(cls):
+        return hash((type(cls), frozenset(cls.variadic_type)))  # type: ignore[attr-defined]
+
+
+def isvariadic(obj):
+    """Check whether the type `obj` is variadic.
+    Parameters
+    ----------
+    obj : type
+        The type to check
+    Returns
+    -------
+    bool
+        Whether or not `obj` is variadic
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> isvariadic(int)
+    False
+    >>> isvariadic(Variadic[int])
+    True
+    """
+    return isinstance(obj, VariadicSignatureType)
+
+
+class VariadicSignatureMeta(type):
+    """A metaclass that overrides ``__getitem__`` on the class. This is used to
+    generate a new type for Variadic signatures. See the Variadic class for
+    examples of how this behaves.
+    """
+    def __getitem__(cls, variadic_type):
+        if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
+            raise ValueError("Variadic types must be type or tuple of types"
+                             " (Variadic[int] or Variadic[(int, float)]")
+
+        if not isinstance(variadic_type, tuple):
+            variadic_type = variadic_type,
+        return VariadicSignatureType(
+            f'Variadic[{typename(variadic_type)}]',
+            (),
+            dict(variadic_type=variadic_type, __slots__=())
+        )
+
+
+class Variadic(metaclass=VariadicSignatureMeta):
+    """A class whose getitem method can be used to generate a new type
+    representing a specific variadic signature.
+    Examples
+    --------
+    >>> # xdoctest: +SKIP
+    >>> Variadic[int]  # any number of int arguments
+    <class 'multipledispatch.variadic.Variadic[int]'>
+    >>> Variadic[(int, str)]  # any number of one of int or str arguments
+    <class 'multipledispatch.variadic.Variadic[(int, str)]'>
+    >>> issubclass(int, Variadic[int])
+    True
+    >>> issubclass(int, Variadic[(int, str)])
+    True
+    >>> issubclass(str, Variadic[(int, str)])
+    True
+    >>> issubclass(float, Variadic[(int, str)])
+    False
+    """
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0c3fafa32ed4f3f49dec01e0b25d02c009291a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py
@@ -0,0 +1,395 @@
+import collections
+import operator
+from functools import reduce
+from collections.abc import Mapping
+
+__all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
+           'valfilter', 'keyfilter', 'itemfilter',
+           'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
+
+
+def _get_factory(f, kwargs):
+    factory = kwargs.pop('factory', dict)
+    if kwargs:
+        raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
+    return factory
+
+
+def merge(*dicts, **kwargs):
+    """ Merge a collection of dictionaries
+
+    >>> merge({1: 'one'}, {2: 'two'})
+    {1: 'one', 2: 'two'}
+
+    Later dictionaries have precedence
+
+    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
+    {1: 2, 3: 3, 4: 4}
+
+    See Also:
+        merge_with
+    """
+    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
+        dicts = dicts[0]
+    factory = _get_factory(merge, kwargs)
+
+    rv = factory()
+    for d in dicts:
+        rv.update(d)
+    return rv
+
+
+def merge_with(func, *dicts, **kwargs):
+    """ Merge dictionaries and apply function to combined values
+
+    A key may occur in more than one dict, and all values mapped from the key
+    will be passed to the function as a list, such as func([val1, val2, ...]).
+
+    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
+    {1: 11, 2: 22}
+
+    >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30})  # doctest: +SKIP
+    {1: 1, 2: 2, 3: 30}
+
+    See Also:
+        merge
+    """
+    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
+        dicts = dicts[0]
+    factory = _get_factory(merge_with, kwargs)
+
+    result = factory()
+    for d in dicts:
+        for k, v in d.items():
+            if k not in result:
+                result[k] = [v]
+            else:
+                result[k].append(v)
+    return valmap(func, result, factory)
+
+
+def valmap(func, d, factory=dict):
+    """ Apply function to values of dictionary
+
+    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
+    >>> valmap(sum, bills)  # doctest: +SKIP
+    {'Alice': 65, 'Bob': 45}
+
+    See Also:
+        keymap
+        itemmap
+    """
+    rv = factory()
+    rv.update(zip(d.keys(), map(func, d.values())))
+    return rv
+
+
+def keymap(func, d, factory=dict):
+    """ Apply function to keys of dictionary
+
+    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
+    >>> keymap(str.lower, bills)  # doctest: +SKIP
+    {'alice': [20, 15, 30], 'bob': [10, 35]}
+
+    See Also:
+        valmap
+        itemmap
+    """
+    rv = factory()
+    rv.update(zip(map(func, d.keys()), d.values()))
+    return rv
+
+
+def itemmap(func, d, factory=dict):
+    """ Apply function to items of dictionary
+
+    >>> accountids = {"Alice": 10, "Bob": 20}
+    >>> itemmap(reversed, accountids)  # doctest: +SKIP
+    {10: "Alice", 20: "Bob"}
+
+    See Also:
+        keymap
+        valmap
+    """
+    rv = factory()
+    rv.update(map(func, d.items()))
+    return rv
+
+
+def valfilter(predicate, d, factory=dict):
+    """ Filter items in dictionary by value
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> valfilter(iseven, d)
+    {1: 2, 3: 4}
+
+    See Also:
+        keyfilter
+        itemfilter
+        valmap
+    """
+    rv = factory()
+    for k, v in d.items():
+        if predicate(v):
+            rv[k] = v
+    return rv
+
+
+def keyfilter(predicate, d, factory=dict):
+    """ Filter items in dictionary by key
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> keyfilter(iseven, d)
+    {2: 3, 4: 5}
+
+    See Also:
+        valfilter
+        itemfilter
+        keymap
+    """
+    rv = factory()
+    for k, v in d.items():
+        if predicate(k):
+            rv[k] = v
+    return rv
+
+
+def itemfilter(predicate, d, factory=dict):
+    """ Filter items in dictionary by item
+
+    >>> def isvalid(item):
+    ...     k, v = item
+    ...     return k % 2 == 0 and v < 4
+
+    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
+    >>> itemfilter(isvalid, d)
+    {2: 3}
+
+    See Also:
+        keyfilter
+        valfilter
+        itemmap
+    """
+    rv = factory()
+    for item in d.items():
+        if predicate(item):
+            k, v = item
+            rv[k] = v
+    return rv
+
+
+def assoc(d, key, value, factory=dict):
+    """ Return a new dict with new key value pair
+
+    New dict has d[key] set to value. Does not modify the initial dictionary.
+
+    >>> assoc({'x': 1}, 'x', 2)
+    {'x': 2}
+    >>> assoc({'x': 1}, 'y', 3)   # doctest: +SKIP
+    {'x': 1, 'y': 3}
+    """
+    d2 = factory()
+    d2.update(d)
+    d2[key] = value
+    return d2
+
+
+def dissoc(d, *keys, **kwargs):
+    """ Return a new dict with the given key(s) removed.
+
+    New dict has d[key] deleted for each supplied key.
+    Does not modify the initial dictionary.
+
+    >>> dissoc({'x': 1, 'y': 2}, 'y')
+    {'x': 1}
+    >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
+    {}
+    >>> dissoc({'x': 1}, 'y') # Ignores missing keys
+    {'x': 1}
+    """
+    factory = _get_factory(dissoc, kwargs)
+    d2 = factory()
+
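+    # Heuristic: when relatively few keys are removed, copy the whole dict and
+    # delete from the copy; otherwise rebuild from the surviving keys.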
+    if len(keys) < len(d) * .6:
+        d2.update(d)
+        for key in keys:
+            if key in d2:
+                del d2[key]
+    else:
+        remaining = set(d)
+        remaining.difference_update(keys)
+        for k in remaining:
+            d2[k] = d[k]
+    return d2
+
+
+def assoc_in(d, keys, value, factory=dict):
+    """ Return a new dict with new, potentially nested, key value pair
+
+    >>> purchase = {'name': 'Alice',
+    ...             'order': {'items': ['Apple', 'Orange'],
+    ...                       'costs': [0.50, 1.25]},
+    ...             'credit card': '5555-1234-1234-1234'}
+    >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
+    {'credit card': '5555-1234-1234-1234',
+     'name': 'Alice',
+     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
+    """
+    return update_in(d, keys, lambda x: value, value, factory)
+
+
+def update_in(d, keys, func, default=None, factory=dict):
+    """ Update value in a (potentially) nested dictionary
+
+    inputs:
+    d - dictionary on which to operate
+    keys - list or tuple giving the location of the value to be changed in d
+    func - function to operate on that value
+
+    If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
+    original dictionary with v replaced by func(v), but does not mutate the
+    original dictionary.
+
+    If k0 is not a key in d, update_in creates nested dictionaries to the depth
+    specified by the keys, with the innermost value set to func(default).
+
+    >>> inc = lambda x: x + 1
+    >>> update_in({'a': 0}, ['a'], inc)
+    {'a': 1}
+
+    >>> transaction = {'name': 'Alice',
+    ...                'purchase': {'items': ['Apple', 'Orange'],
+    ...                             'costs': [0.50, 1.25]},
+    ...                'credit card': '5555-1234-1234-1234'}
+    >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
+    {'credit card': '5555-1234-1234-1234',
+     'name': 'Alice',
+     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
+
+    >>> # updating a value when k0 is not in d
+    >>> update_in({}, [1, 2, 3], str, default="bar")
+    {1: {2: {3: 'bar'}}}
+    >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
+    {1: 'foo', 2: {3: {4: 1}}}
+    """
+    ks = iter(keys)
+    k = next(ks)
+
+    rv = inner = factory()
+    rv.update(d)
+
+    for key in ks:
+        if k in d:
+            d = d[k]
+            dtemp = factory()
+            dtemp.update(d)
+        else:
+            d = dtemp = factory()
+
+        inner[k] = inner = dtemp
+        k = key
+
+    if k in d:
+        inner[k] = func(d[k])
+    else:
+        inner[k] = func(default)
+    return rv
+
+
+def get_in(keys, coll, default=None, no_default=False):
+    """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
+
+    If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
+    ``no_default`` is specified, in which case it raises KeyError or IndexError.
+
+    ``get_in`` is a generalization of ``operator.getitem`` for nested data
+    structures such as dictionaries and lists.
+
+    >>> transaction = {'name': 'Alice',
+    ...                'purchase': {'items': ['Apple', 'Orange'],
+    ...                             'costs': [0.50, 1.25]},
+    ...                'credit card': '5555-1234-1234-1234'}
+    >>> get_in(['purchase', 'items', 0], transaction)
+    'Apple'
+    >>> get_in(['name'], transaction)
+    'Alice'
+    >>> get_in(['purchase', 'total'], transaction)
+    >>> get_in(['purchase', 'items', 'apple'], transaction)
+    >>> get_in(['purchase', 'items', 10], transaction)
+    >>> get_in(['purchase', 'total'], transaction, 0)
+    0
+    >>> get_in(['y'], {}, no_default=True)
+    Traceback (most recent call last):
+        ...
+    KeyError: 'y'
+
+    See Also:
+        itertoolz.get
+        operator.getitem
+    """
+    try:
+        return reduce(operator.getitem, keys, coll)
+    except (KeyError, IndexError, TypeError):
+        if no_default:
+            raise
+        return default
+
+
+def getter(index):
+    if isinstance(index, list):
+        if len(index) == 1:
+            index = index[0]
+            return lambda x: (x[index],)
+        elif index:
+            return operator.itemgetter(*index)
+        else:
+            return lambda x: ()
+    else:
+        return operator.itemgetter(index)
+
+
+def groupby(key, seq):
+    """ Group a collection by a key function
+
+    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
+    >>> groupby(len, names)  # doctest: +SKIP
+    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
+
+    >>> iseven = lambda x: x % 2 == 0
+    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
+    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
+
+    Non-callable keys imply grouping on a member.
+
+    >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
+    ...                    {'name': 'Bob', 'gender': 'M'},
+    ...                    {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
+    {'F': [{'gender': 'F', 'name': 'Alice'}],
+     'M': [{'gender': 'M', 'name': 'Bob'},
+           {'gender': 'M', 'name': 'Charlie'}]}
+
+    Not to be confused with ``itertools.groupby``
+
+    See Also:
+        countby
+    """
+    if not callable(key):
+        key = getter(key)
+    d = collections.defaultdict(lambda: [].append)  # type: ignore[var-annotated]
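+    # Each value is the bound ``append`` of a fresh list; storing the bound
+    # method avoids an attribute lookup per item, and the underlying lists are
+    # recovered below via ``__self__``.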
+    for item in seq:
+        d[key(item)](item)
+    rv = {}
+    for k, v in d.items():
+        rv[k] = v.__self__  # type: ignore[var-annotated, attr-defined]
+    return rv
+
+
+def first(seq):
+    """ The first element in a sequence
+
+    >>> first('ABC')
+    'A'
+    """
+    return next(iter(seq))
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/utils.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4887c8f715489c8ce3ecb0616c24b1975f792048
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/utils.py
@@ -0,0 +1,105 @@
+__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
+def hashable(x):
+    try:
+        hash(x)
+        return True
+    except TypeError:
+        return False
+
+
+def transitive_get(key, d):
+    """ Transitive dict.get
+    >>> d = {1: 2, 2: 3, 3: 4}
+    >>> d.get(1)
+    2
+    >>> transitive_get(1, d)
+    4
+    """
+    while hashable(key) and key in d:
+        key = d[key]
+    return key
+
+
+def raises(err, lamda):
+    try:
+        lamda()
+        return False
+    except err:
+        return True
+
+
+# Taken from theano/theano/gof/sched.py
+# Avoids licensing issues because this was written by Matthew Rocklin
+def _toposort(edges):
+    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
+    inputs:
+        edges - a dict of the form {a: {b, c}} where b and c depend on a
+    outputs:
+        L - an ordered list of nodes that satisfy the dependencies of edges
+    >>> # xdoctest: +SKIP
+    >>> _toposort({1: (2, 3), 2: (3, )})
+    [1, 2, 3]
+    Closely follows the wikipedia page [2]
+    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
+    Communications of the ACM
+    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
+    """
+    incoming_edges = reverse_dict(edges)
+    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
+    S = ({v for v in edges if v not in incoming_edges})
+    L = []
+
+    while S:
+        n = S.pop()
+        L.append(n)
+        for m in edges.get(n, ()):
+            assert n in incoming_edges[m]
+            incoming_edges[m].remove(n)
+            if not incoming_edges[m]:
+                S.add(m)
+    if any(incoming_edges.get(v, None) for v in edges):
+        raise ValueError("Input has cycles")
+    return L
+
+
+def reverse_dict(d):
+    """Reverses direction of dependence dict
+    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
+    >>> reverse_dict(d)  # doctest: +SKIP
+    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
+    :note: dict ordering is not guaranteed to be deterministic. Since we
+        iterate over the input dict, the output of this function depends
+        on the dict order, so its output order should be considered
+        non-deterministic.
+    """
+    result = {}  # type: ignore[var-annotated]
+    for key in d:
+        for val in d[key]:
+            result[val] = result.get(val, tuple()) + (key, )
+    return result
+
+
+def xfail(func):
+    try:
+        func()
+        raise Exception("XFailed test passed")  # pragma:nocover
+    except Exception:
+        pass
+
+
+def freeze(d):
+    """ Freeze container to hashable form
+    >>> freeze(1)
+    1
+    >>> freeze([1, 2])
+    (1, 2)
+    >>> freeze({1: 2}) # doctest: +SKIP
+    frozenset([(1, 2)])
+    """
+    if isinstance(d, dict):
+        return frozenset(map(freeze, d.items()))
+    if isinstance(d, set):
+        return frozenset(map(freeze, d))
+    if isinstance(d, (tuple, list)):
+        return tuple(map(freeze, d))
+    return d
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unification/variable.py b/MLPY/Lib/site-packages/torch/fx/experimental/unification/variable.py
new file mode 100644
index 0000000000000000000000000000000000000000..778d2e1cbbdbbd6d9e6c96127f15a349243bd915
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unification/variable.py
@@ -0,0 +1,85 @@
+from contextlib import contextmanager
+from .utils import hashable
+from .dispatch import dispatch
+
+_global_logic_variables = set()  # type: ignore[var-annotated]
+_glv = _global_logic_variables
+
+
+class Var:
+    """ Logic Variable """
+
+    _id = 1
+
+    def __new__(cls, *token):
+        if len(token) == 0:
+            token = f"_{Var._id}"  # type: ignore[assignment]
+            Var._id += 1
+        elif len(token) == 1:
+            token = token[0]
+
+        obj = object.__new__(cls)
+        obj.token = token  # type: ignore[attr-defined]
+        return obj
+
+    def __str__(self):
+        return "~" + str(self.token)  # type: ignore[attr-defined]
+    __repr__ = __str__
+
+    def __eq__(self, other):
+        return type(self) == type(other) and self.token == other.token  # type: ignore[attr-defined]
+
+    def __hash__(self):
+        return hash((type(self), self.token))  # type: ignore[attr-defined]
+
+
+def var():
+    return lambda *args: Var(*args)
+
+
+def vars():
+    return lambda n: [var() for i in range(n)]
+
+
+@dispatch(Var)
+def isvar(v):
+    return True
+
+isvar
+
+
+@dispatch(object)  # type: ignore[no-redef]
+def isvar(o):
+    return not not _glv and hashable(o) and o in _glv
+
+
+@contextmanager
+def variables(*variables):
+    """
+    Context manager for logic variables
+
+    Example:
+        >>> # xdoctest: +SKIP("undefined vars")
+        >>> from __future__ import with_statement
+        >>> with variables(1):
+        ...     print(isvar(1))
+        True
+        >>> print(isvar(1))
+        False
+        >>> # Normal approach
+        >>> from unification import unify
+        >>> x = var('x')
+        >>> unify(x, 1)
+        {~x: 1}
+        >>> # Context Manager approach
+        >>> with variables('x'):
+        ...     print(unify('x', 1))
+        {'x': 1}
+    """
+    old_global_logic_variables = _global_logic_variables.copy()
+    _global_logic_variables.update(set(variables))
+    try:
+        yield
+    finally:
+        _global_logic_variables.clear()
+        _global_logic_variables.update(old_global_logic_variables)
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/unify_refinements.py b/MLPY/Lib/site-packages/torch/fx/experimental/unify_refinements.py
new file mode 100644
index 0000000000000000000000000000000000000000..b30cadfd04c3cdd40023a360419af0587d09fdd8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/unify_refinements.py
@@ -0,0 +1,120 @@
+from torch.fx.experimental.graph_gradual_typechecker import Refine
+from torch.fx.tensor_type import TensorType
+from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]
+
+
+def infer_symbolic_types_single_pass(traced):
+    """
+    Calls our symbolic inferencer once.
+    """
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+def infer_symbolic_types(traced):
+    """
+    Calls our symbolic inferencer twice.
+    This is useful when one pass is not enough
+    to infer all the information, such as in the
+    case of broadcasting.
+    """
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+    r = Refine(traced)
+    r.refine()
+    mgu = unify_eq(r.constraints)
+    substitute_all_types(traced.graph, mgu)
+
+    r.symbolic_relations()
+
+def convert_eq(list_of_eq):
+    """
+    Convert equality constraints into the format
+    expected by the unification library.
+    """
+    lhs = []
+    rhs = []
+    for eq in list_of_eq:
+        lhs.append(eq.lhs)
+        rhs.append(eq.rhs)
+    return tuple(lhs), tuple(rhs)
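+    # Hedged sketch: for constraints with ``.lhs``/``.rhs`` such as
+    # [a == b, c == d], this returns ((a, c), (b, d)) for ``unify`` to consume.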
+
+
+def unify_eq(list_of_eq):
+    """
+    Apply unification to a set of
+    equality constraints
+    """
+    lhs, rhs = convert_eq(list_of_eq)
+    return unify(lhs, rhs)
+
+
+def substitute_solution_one_type(mapping, t):
+    """
+    Apply the most general unifier to a type
+    """
+    if isinstance(t, Var):
+        if t in mapping.keys():
+            return mapping[t]
+        else:
+            return t
+
+    elif isinstance(t, TensorType):
+        new_type = []
+        for typ in t.__args__:
+            if typ in mapping.keys():
+                new_type.append(mapping[typ])
+            else:
+                new_type.append(typ)
+        return TensorType(tuple(new_type))
+
+    elif isinstance(t, list):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return new_type
+
+    elif isinstance(t, tuple):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return tuple(new_type)
+
+    else:
+        return t
+
+
+def substitute_all_types(graph, mapping):
+    """
+    Apply the most general unifier to all types in a graph
+    until a fixed point is reached. If the input and output
+    graphs are the same, we have converged.
+    """
+    flag = True
+    while flag:
+        flag = False
+        for k in mapping:
+            old_mapping_val = mapping[k]
+            if mapping[k] in mapping.keys():
+                new_key = mapping[k]
+                mapping[k] = mapping[new_key]
+            if old_mapping_val != mapping[k]:
+                flag = True
+
+    for n in graph.nodes:
+        n.type = substitute_solution_one_type(mapping, n.type)
+
+def check_for_type_equality(g1, g2):
+    """
+    An equality check to be used for fixed points.
+    We do not use graph equality but rather type
+    equality.
+    """
+    for n, m in zip(g1.nodes, g2.nodes):
+        if n.type != m.type:
+            return False
+    return True
diff --git a/MLPY/Lib/site-packages/torch/fx/experimental/validator.py b/MLPY/Lib/site-packages/torch/fx/experimental/validator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a732c06f202bd3664305793d76650978e0cea4e1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/experimental/validator.py
@@ -0,0 +1,766 @@
+import functools
+import logging
+import math
+import operator
+import sympy
+import builtins
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+import torch.fx
+import torch.fx.traceback as fx_traceback
+
+from torch._dynamo.exc import TorchDynamoException
+from torch.fx.node import Argument, Target
+from torch.utils._sympy.interp import sympy_interp
+
+log = logging.getLogger(__name__)
+
+try:
+    import z3  # type: ignore[import]
+
+    # Translation Validation for Dynamo guards
+    # ========================================
+    #
+    # Checks whether optimizations applied to the collected guards are
+    # valid. In other words, whether the guard function we actually run
+    # does not have false positives (unsound).
+    #
+    # In order to do so, we build the guards using 2 different information
+    # attached to each 'SymNode':
+    #   1. SymPy expressions
+    #   2. FX nodes
+    #
+    # SymPy expressions have implicit optimizations baked within itself,
+    # which may have a few bugs. On the other hand, we build the FX graph
+    # manually, with no optimizations enabled. This gives us access to
+    # the "ground truth".
+    #
+    # We then convert into Z3 expressions both the SymPy expressions
+    # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
+    # and the FX nodes (see [Note: PopulateValidator]) that go through
+    # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
+    # (see [Note: TranslationValidator])
+
+    # Better Z3 to string implementation (for a small fraction of Z3).
+    #
+    # Here are the things we clean before showing the Z3 expression:
+    #   - Rename a few ops (e.g. "Distinct" ==> "!=")
+    #
+    #   - Ignore ToInt and ToReal operations:
+    #     usually they don't really matter
+    #
+    #   - Transform (ToInt (/ ...)) into (idiv ...):
+    #     this is the pattern for floor division
+    #
+    #   - Collect a chain of the same operations into one
+    def z3str(e: z3.ExprRef) -> str:
+        assert z3.is_expr(e), f"unsupported expression type: {e}"
+
+        def get_args_str(e: z3.ExprRef) -> List[str]:
+            return [z3str(e.arg(i)) for i in range(e.num_args())]
+
+        # First, we simplify the given expression.
+        # This is done using rewriting rules, so shouldn't take long.
+        e = z3.simplify(e)
+
+
+        # Only support function applications.
+        # Even Z3 "variables" are, in fact, function applications.
+        if not z3.is_app(e):
+            raise ValueError(f"can't print Z3 expression: {e}")
+
+        if z3.is_int_value(e) or z3.is_rational_value(e):
+            return e.as_string()  # type: ignore[attr-defined]
+
+        decl = e.decl()
+        kind = decl.kind()
+        op = str(decl)
+        args = get_args_str(e)
+
+        if kind == z3.Z3_OP_POWER:
+            op = "pow"
+
+        elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
+            # Collect the arguments of chains of ADD and MUL.
+            # This is safe, since they are associative.
+
+            def collect_str_args(e):
+                if not (z3.is_app(e) and e.decl().kind() == kind):
+                    return [z3str(e)]
+                else:
+                    return [
+                        x
+                        for i in range(e.num_args())
+                        for x in collect_str_args(e.arg(i))
+                    ]
+
+            args = collect_str_args(e)
+
+        elif kind == z3.Z3_OP_NOT:
+            # Revert some conversions that z3.simplify applies:
+            #   - a != b ==> (Not (== a b)) ==> (!= a b)
+            #   - a < b ==> (Not (<= b a)) ==> (> b a)
+            #   - a > b ==> (Not (<= a b)) ==> (> a b)
+
+            assert e.num_args() == 1
+            arg = e.arg(0)
+
+            assert z3.is_app(arg)
+            argkind = arg.decl().kind()
+
+            logic_inverse = {
+                z3.Z3_OP_EQ: "!=",
+                z3.Z3_OP_LE: ">",
+                z3.Z3_OP_GE: "<",
+            }
+
+            if argkind in logic_inverse:
+                op = logic_inverse[argkind]
+                args = get_args_str(arg)
+
+        elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
+            assert e.num_args() == 1
+            argstr = z3str(e.arg(0))
+
+            # Check if it's the floor division pattern.
+            if argstr.startswith("(/"):
+                return "(idiv" + argstr[2:]
+
+            # Otherwise, just ignore it.
+            return argstr
+
+        elif kind == z3.Z3_OP_UNINTERPRETED:
+            assert e.num_args() == 0
+            return str(decl)
+
+        string = op + " " + " ".join(args)
+        return f"({string.rstrip()})"
+
+    # Implementation of Python semantics as Z3 expressions.
+    #
+    # Z3's Real-Int theory has operators whose semantics differ from those of
+    # Python. Therefore, in order to get it right, we need to implement
+    # the (Python) semantics we are relying on in Z3.
+    @dataclass
+    class _Z3Ops:
+        # Validator used for adding assertions as needed.
+        # e.g. div(a, b) requires b != 0.
+        validator: "TranslationValidator"
+
+        # The 2 functions below are used for conditionally casting between
+        # integer and reals.
+        #
+        # Returns a real expression from 'x'.
+        @staticmethod
+        def to_real(x: z3.ArithRef) -> z3.ArithRef:
+            return x if x.is_real() else z3.ToReal(x)
+
+        # Returns an integer expression from 'x'.
+        @staticmethod
+        def to_int(x: z3.ArithRef) -> z3.ArithRef:
+            return x if x.is_int() else z3.ToInt(x)
+
+        # Implements Python division semantics.
+        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            self.validator.add_assertion(denominator != 0)  # type: ignore[arg-type]
+            return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)
+
+        def floor(self, number: z3.ArithRef) -> z3.ArithRef:
+            # Z3 ToInt function rounds a real number towards negative infinity.
+            return _Z3Ops.to_int(number)
+
+        # Python semantics for 'FloorDiv' states that before applying the floor
+        # function, the operands are converted to their common type.
+        def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            cast_result_to_real = numerator.is_real() or denominator.is_real()
+            result = _Z3Ops.to_int(self.div(numerator, denominator))
+            # Since the 'result' is already an integer, we just have to check
+            # whether we should cast it to real.
+            return _Z3Ops.to_real(result) if cast_result_to_real else result
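+            # Worked example (hedged): Python's -7 // 2 is -4; this encoding
+            # agrees because ToInt rounds -3.5 toward negative infinity.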
+
+        def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(
+                self.floor(number) < number,
+                self.floor(number + 1),
+                number
+            )  # type: ignore[return-value]
+
+        def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(a > b, a, b)  # type: ignore[return-value]
+
+        def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
+            return z3.If(a < b, a, b)  # type: ignore[return-value]
+
+        # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
+        # It should work with both integer and reals.
+        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
+            return p - self.floordiv(p, q) * q
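+            # Worked example (hedged): Python's (-7) % 2 is 1; the identity
+            # above gives -7 - (-4) * 2 = 1, matching Python semantics.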
+
+        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
+            # Z3 can't handle complex numbers very well.
+            self.validator.add_assertion(z3.Or(base != 0, exp > 0))  # type: ignore[arg-type]
+            return base ** exp
+
+        def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
+            # Square-root:
+            # 1. Only work with reals
+            number = _Z3Ops.to_real(number)
+            # 2. The number should be positive or zero.
+            #    Otherwise, Z3 returns 'unknown'.
+            self.validator.add_assertion(number >= 0)
+            return number ** 0.5
+
+        def abs(self, number: z3.ArithRef) -> z3.ArithRef:
+            return z3.Abs(number)
+
+        def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
+            if ndigits is not None:
+                raise ValueError("round(..., ndigits=) is currently not supported by shape validations.")
+
+            # Python's builtin 'round' implements the 'round half to even' strategy
+            # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
+            # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
+            # floating point numbers, which is different from real numbers that we are dealing with here.
+            # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and
+            # 'round half down' (ceil(x - 0.5)).
+            # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ...
+            # to round down, i.e. use the 'round half down' strategy
+            return z3.If(
+                self.mod(number, z3.IntVal(2)) == 0.5,
+                self.ceil(number - 0.5),
+                self.floor(number + 0.5),
+            )
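+            # Worked example (hedged): Python's round(2.5) == 2 and
+            # round(3.5) == 4; the branch above applies ceil(x - 0.5) exactly
+            # when x mod 2 == 0.5, reproducing "round half to even".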
+
+    # Lifts a callable to be used in Z3.
+    #
+    # This function replaces the given 'op' by a function that:
+    #
+    #   1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
+    #
+    #   2. Calls an operation that corresponds to 'op', but works with Z3
+    #      inhabitants (left as is if it works as is)
+    def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
+        # Operations that take booleans as their arguments.
+        # This is needed because the arguments of some FX nodes were
+        # literal integers instead of booleans. So, whenever this flag
+        # is set, we also convert ints to booleans.
+        boolean_ops = {operator.not_, operator.and_, operator.or_}
+        as_bool = op in boolean_ops
+
+        # Lifts the function into 'z3.ExprRef' domain.
+        def lift(func):
+            def wrap(a) -> z3.ExprRef:
+                if isinstance(a, (z3.ArithRef, z3.BoolRef)):
+                    return a
+                # Convert it into a Z3 value, if it is some of the supported
+                # types below.
+                if isinstance(a, bool) or (as_bool and isinstance(a, int)):
+                    return z3.BoolVal(bool(a))
+                if isinstance(a, (int, sympy.Integer)):
+                    return z3.IntVal(int(a))
+                if isinstance(a, (float, sympy.Float)):
+                    return z3.RealVal(float(a))
+                raise ValueError(f"can't lift type: {type(a)}")
+
+            @functools.wraps(func)
+            def wrapper(*args):
+                # Lifts the arguments into a list of Z3 inhabitants.
+                wrapped_args = (wrap(a) for a in args)
+                # Run the function on the Z3 expressions.
+                return func(*wrapped_args)
+
+            return wrapper
+
+        ops = _Z3Ops(validator)
+        replacement_map = {
+            # Operator module.
+            operator.not_: lift(z3.Not),
+            operator.and_: lift(z3.And),
+            operator.or_: lift(z3.Or),
+            operator.floordiv: lift(ops.floordiv),
+            operator.truediv: lift(ops.div),
+            operator.mod: lift(ops.mod),
+            operator.abs: lift(ops.abs),
+            builtins.round: lift(ops.round),
+
+            # Math module.
+            math.ceil: lift(ops.ceil),
+            math.floor: lift(ops.floor),
+
+            # Torch module.
+            torch.sym_float: lift(ops.to_real),
+            torch.sym_max: lift(ops.max),
+            torch.sym_min: lift(ops.min),
+            torch.sym_ite: lift(lambda b, t, f: t if b else f),
+            torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
+            # Not lifted because we only use this function as a
+            # marker for adding the expression as validator input.
+            torch._assert: torch._assert,
+        }
+        return replacement_map[op] if op in replacement_map else lift(op)
+
+    # Processes an FX graph, populating the given validator.
+    #
+    # [Note: PopulateValidator]
+    # This class walks through each node in the FX graph, translating
+    # them into the Z3 world.
+    #
+    # Then, whenever it finds an 'torch._assert' call_function operation,
+    # it adds the Z3 expression corresponding to the argument as validator
+    # input.
+    class PopulateValidator(torch.fx.Interpreter):
+        def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
+            # Reference to the translation validator.
+            self.validator = validator
+
+            # Build the graph module and call `Interpreter` constructor.
+            module = torch.fx.GraphModule(root={}, graph=graph)
+            super().__init__(module, garbage_collect_values=True)
+
+        def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+            symbol = fx_traceback.get_current_meta()["symbol"]
+            return self.validator.z3var(symbol)
+
+        def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+            if target != torch._assert:
+                # Actually runs the node target function (which is already
+                # lifted) with its arguments.
+                return super().call_function(target, args, kwargs)
+            # Adds the Z3 expression corresponding to the first argument
+            # as a validator input.
+            assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} "
+            self.validator.add_source_expr(args[0])  # type: ignore[arg-type]
+
+    # Translates SymPy expressions into Z3 expressions.
+    #
+    # [Note: SympyToZ3]
+    # At the time of the translation, all free variables present in the
+    # SymPy expression being translated must be already mapped to a Z3
+    # integer variable.
+    class SympyToZ3:
+        OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
+
+        def __init__(
+                self,
+                validator: "TranslationValidator",
+        ) -> None:
+            self._validator = validator
+            self._ops = _Z3Ops(self._validator)
+
+        def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
+            if dtype is torch.int64:
+                return z3.IntVal(int(value))
+            if dtype is torch.double:
+                return z3.RealVal(float(value))
+            if dtype is torch.bool:
+                return z3.BoolVal(bool(value))
+            raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
+
+        def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.div(numerator, denominator)
+
+        def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.floordiv(numerator, denominator)
+
+        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.floordiv(numerator, denominator)
+
+        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.pow(base, exp)
+
+        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
+            return self._ops.mod(p, q)
+
+        def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef:
+            return self._ops.round(number, ndigits)
+
+        def __getattr__(self, name: str) -> Any:
+            REPLACEMENT = {
+                "and_": z3.And,
+                "or_": z3.Or,
+                "not_": z3.Not,
+                "floor": self._ops.floor,
+                "ceil": self._ops.ceil,
+                "minimum": self._ops.min,
+                "maximum": self._ops.max,
+            }
+
+            if name in REPLACEMENT:
+                return REPLACEMENT[name]
+            if name in self.OPERATOR_HANDLES:
+                return getattr(operator, name)
+            raise AttributeError(f"unhandled operator: {name}")
+
+        def run(self, expr: sympy.Basic) -> z3.ExprRef:
+            return sympy_interp(self, self._validator.symbols, expr)  # type: ignore[arg-type]
+
+    # Dynamo guards translation validator.
+    #
+    # [Note: TranslationValidator]
+    # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
+    # That is: whether those (target) guards only yield TRUE whenever the original,
+    # unoptimized, (source) guards yield TRUE.
+    #
+    # More concretely, given 'source' and 'target' guard expressions, we wish to
+    # check whether the following expression is satisfiable:
+    #
+    # Not(And(source)) AND And(target)
+    #
+    # i.e. whether there is an assignment of the free variables where the opposite
+    # happens: target is TRUE, but source is FALSE.
+    class TranslationValidator:
+        def __init__(self) -> None:
+            log.debug("new instance")
+
+            # Mapping of SymPy symbols to Z3 variables.
+            self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
+
+            # Set of source Z3 expressions.
+            # They represent the generated guards without any kind of
+            # simplification or transformation.
+            self._source_exprs: Set[z3.BoolRef] = set()
+
+            # Set of target Z3 expressions.
+            # They represent the actual checked guards at runtime. They might
+            # be simplified or transformed versions of the source guards.
+            self._target_exprs: Set[z3.BoolRef] = set()
+
+            # Set of Z3 expressions representing assertions over both the
+            # source and target expressions.
+            self._assertions: Set[z3.BoolRef] = set()
+
+        # Retrieves the corresponding Z3 variable.
+        def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
+            assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
+            return self.symbols[symbol]
+
+        # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist.
+        def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef:
+            if symbol in self.symbols:
+                return self.symbols[symbol]
+
+            log.debug("new variable: %s (%s)", symbol.name, type.__name__)
+
+            if type is int:
+                var = z3.Int(symbol.name)
+
+                # If 'symbol' is positive (SymPy assumption), we have to
+                # convey it to Z3 as well.
+                if symbol.is_positive:  # type: ignore[attr-defined]
+                    self._target_exprs.add(var > 0)
+            elif type is float:
+                var = z3.Real(symbol.name)
+            elif type is bool:
+                var = z3.Bool(symbol.name)
+            else:
+                raise RuntimeError(f"unsupported type for Z3 variable: {type}")
+
+            self.symbols[symbol] = var
+            return var
+
+        # Checks whether all symbols were already added.
+        def _check_freesymbols(self, e: sympy.Basic) -> None:
+            for s in e.free_symbols:
+                assert isinstance(s, sympy.Symbol)
+                # Call 'z3var' just to check whether there's already a
+                # Z3 variable corresponding to 's'.
+                self.z3var(s)
+
+
+        def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
+            z3expr = SympyToZ3(self).run(e)
+            assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}"
+            return z3expr
+
+        def add_source_expr(self, e: z3.BoolRef) -> None:
+            if e not in self._source_exprs:
+                log.debug("add source guard: %s", z3str(e))
+            self._source_exprs.add(e)
+
+        def add_target_expr(self, e: sympy.Expr) -> None:
+            self._check_freesymbols(e)
+            z3expr = self.to_z3_boolean_expr(e)
+            if e not in self._target_exprs:
+                log.debug("add target guard: %s", z3str(z3expr))
+            self._target_exprs.add(z3expr)
+
+        def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
+            if isinstance(e, sympy.Basic):
+                self._check_freesymbols(e)
+                ref = self.to_z3_boolean_expr(e)
+            else:
+                ref = e
+            assert isinstance(ref, z3.BoolRef)
+            if ref not in self._assertions:
+                log.debug("add assertion: %s", z3str(ref))
+            self._assertions.add(ref)
+
+        def validate(self) -> None:
+            from torch._dynamo.utils import dynamo_timed
+
+            if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
+                # If there are no source/target expressions, there's nothing we really
+                # wish to prove. So, we just return.
+                return None
+
+            # Here, we use "QF_NRA" logic for the solver:
+            #   "Quantifier-free Non-linear Real Arithmetic".
+            #
+            # Most of the guards expressions have:
+            #   1. arithmetic between integer and reals
+            #   2. no quantifiers
+            #   3. potentially non-linear.
+            #
+            # Although there's also "QF_NIRA" (mixed integer-real arithmetic),
+            # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
+            solver = z3.SolverFor("QF_NRA")
+            # Set a timeout for finding a solution.
+            solver.set(timeout=translation_validation_timeout())
+
+            # Add all the assertions to the solver.
+            for assertion in self._assertions:
+                solver.add(assertion)
+
+            # "Is there any case where it's TRUE for the target expressions,
+            #  but FALSE for the source expressions?"
+            solver.add(z3.Not(z3.And(*self._source_exprs)))
+            solver.add(*self._target_exprs)
+
+            log.debug("translation validation: start")
+            r = dynamo_timed()(solver.check)()
+            if r == z3.sat:
+                # Target expressions are unsound.
+                # Log the found model and the source expressions that failed.
+                model = solver.model()
+                raise ValidationException(
+                    model, self._assertions, self._target_exprs,
+                    failed_source_exprs=[
+                        inp for inp in self._source_exprs if not model.evaluate(inp)
+                    ]
+                )
+            else:
+                if r == z3.unknown:
+                    # Could not find a solution. It didn't fail, but it also
+                    # didn't succeed. Canceling the validation execution (keyboard
+                    # interrupt) also gets to this branch.
+                    log.warning("translation validation: could not validate: got z3.unknown")
+                else:
+                    # Target expressions are sound.
+                    assert r == z3.unsat
+                    log.debug("translation validation: success")
+
+except ImportError:
+    _HAS_Z3 = False
+
+    __all__ = [
+        "translation_validation_enabled", "translation_validation_timeout",
+        "ValidationException", "BisectValidationException",
+    ]
+
+else:
+    _HAS_Z3 = True
+
+    __all__ = [
+        "z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator",
+        "translation_validation_enabled", "translation_validation_timeout",
+        "ValidationException", "BisectValidationException",
+    ]
+
+from torch.fx.experimental import _config as config
+
+def translation_validation_enabled() -> bool:
+    # Check every time this function is called, in case the Dynamo
+    # option is set but Z3 is not installed.
+    _assert_z3_installed_if_tv_set()
+    return _HAS_Z3 and config.translation_validation
+
+
+def translation_validation_timeout() -> int:
+    return config.translation_validation_timeout
+
+
+def _assert_z3_installed_if_tv_set():
+    assert _HAS_Z3 or not config.translation_validation, (
+        "translation validation requires Z3 package. Please, either install "
+        "z3-solver or disable translation validation."
+    )
+
+
+class ValidationException(TorchDynamoException):
+    def __init__(self, model, assertions, target_exprs, failed_source_exprs):
+        assert _HAS_Z3
+
+        def symbolstr(sym) -> str:
+            return f"{sym}: {model[sym]}"
+
+        def joinlines(xs) -> str:
+            return "\n".join(f"  ==> {x}" for x in xs)
+
+        model_str = joinlines(sorted(map(symbolstr, model)))
+        assertions_str = joinlines(sorted(map(z3str, assertions)))
+        target_exprs_str = joinlines(sorted(map(z3str, target_exprs)))
+        failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs)))
+
+        self.msg = "translation validation failed."
+        self.details = f"""\
+Model:
+{model_str}
+
+Assertions:
+{assertions_str}
+
+Target Expressions:
+{target_exprs_str}
+
+Failed Source Expressions:
+{failed_source_exprs_str}"""
+
+    def __str__(self):
+        return f"{self.msg}\n\n{self.details}"
+
+
+class BisectValidationException(TorchDynamoException):
+    def __init__(self, validation_exc, expr, failed_action, traced_node):
+        self.msg = f"translation validation failed when {failed_action}: {expr}"
+        self.details = f"""\
+Failure occurred while running node:
+    {traced_node.format_node()}
+
+{validation_exc.details}"""
+
+    def __str__(self):
+        return f"{self.msg}\n\n{self.details}"
+
+# Checks when this module is loaded.
+_assert_z3_installed_if_tv_set()
+
+# Translation validation bisection.
+#
+# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
+# the earliest ValidationException.
+#
+# As guards are added by ShapeEnv.evaluate_expr calls, simplification errors may
+# occur silently. This function tries to pin down exactly the point at which
+# things went wrong from a validation perspective.
+def bisect(shape_env):
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
+    from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events
+
+    events = shape_env.events
+
+    # Retrieves the ShapeEnvEvent associated with node.
+    def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
+        assert SHAPEENV_EVENT_KEY in node.meta
+        return events[node.meta[SHAPEENV_EVENT_KEY]]
+
+    # Creates a new instance of 'fake', updating every symbolic value's ShapeEnv
+    # reference to the one given as an argument.
+    #
+    # This is needed so as not to simplify a symbolic expression using a ShapeEnv
+    # "from the future", where it may have a different set of replacements.
+    def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
+        if isinstance(fake, int):
+            return fake
+        if isinstance(fake, torch.SymInt):
+            return torch.SymInt(fake.node.with_shape_env(shape_env))
+        assert isinstance(fake, FakeTensorMeta)
+        return FakeTensorMeta(
+            tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
+            tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
+            new_with_shape_env(shape_env, fake.storage_offset()),
+            fake.is_nested,
+        )
+
+    # Checks whether the given shape_env fails when produce_guards is called.
+    def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
+        assert tracked_fakes is not None
+        try:
+            # This produce_guards call is a best-effort replication, since we
+            # don't populate EqualityConstraint list. Reason: we would also have
+            # to save OutputGraph.tracked_fakes_id_to_source.
+            shape_env.produce_guards(
+                [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
+                [a.source for a in tracked_fakes],
+                input_contexts=[a.symbolic_context for a in tracked_fakes],
+            )
+            return None
+        except ValidationException as e:
+            return e
+
+    # Checks whether the ShapeEnv reconstructed by replaying the events until
+    # node is created fails when produce_guards is called.
+    def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
+        number = node.meta[SHAPEENV_EVENT_KEY]
+        # Reconstruct the ShapeEnv by replaying the events up to (and including) index 'number'.
+        shape_env = replay_shape_env_events(events[:number + 1])
+        shape_env.graph.lint()
+        return check_shapeenv_fails(shape_env, events[number].tracked_fakes)
+
+    last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())
+
+    if not last_exception:
+        # We don't actually fail due to a produce_guards call.
+        # Stop and don't bisect.
+        log.info("translation validation succeeded: no errors found.")
+        return
+
+    if not shape_env.should_record_events or config.translation_validation_no_bisect:
+        # Bisection is off.
+        # Return the last ValidationException we got.
+        raise last_exception
+
+    # Cache the raised exception (if any) at each bisection point.
+    exception = {}
+
+    # Bisection happens on the assertion nodes of the recorded FX graph for
+    # dynamic shapes.
+    assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]
+
+    # Preparing the indices for binary search.
+    left, mid, right = 0, 0, len(assert_nodes) - 1
+
+    while left < right:
+        mid = (left + right) // 2
+
+        node = assert_nodes[mid]
+        log.debug("bisecting at %s: %s", mid, get_node_event(node))
+
+        # Check whether the new shape_env raises a ValidationException or not.
+        exception[mid] = check_node_fails(node)
+
+        if exception[mid]:
+            right = mid
+        else:
+            left = mid + 1
+
+    assert left in exception and isinstance(exception[left], ValidationException)
+
+    node = assert_nodes[left]
+    event = get_node_event(node)
+
+    if event.is_evaluate_expr():
+        failed_action = "evaluating"
+    else:
+        assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
+        failed_action = "adding runtime assert"
+
+    args = event.args
+    assert args is not None
+    assert len(args) >= 2, (
+        f"bisecting expects {event.name} to have at least 2 positional arguments. "
+        f"Got: {len(args)}"
+    )
+    assert isinstance(args[1], sympy.Basic), (
+        f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
+        f"Got: {type(args[1])}"
+    )
+
+    raise BisectValidationException(
+        exception[left],
+        expr=args[1],
+        failed_action=failed_action,
+        traced_node=node.meta[CURRENT_NODE_KEY],
+    )
diff --git a/MLPY/Lib/site-packages/torch/fx/graph.py b/MLPY/Lib/site-packages/torch/fx/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab867dc732d02c5bc4b6e85d88b6866d611e31f1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/graph.py
@@ -0,0 +1,1653 @@
+from collections import defaultdict
+from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
+import torch.utils._pytree as pytree
+from . import _pytree as fx_pytree
+from ._compatibility import compatibility
+
+import contextlib
+from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
+from dataclasses import dataclass
+from contextlib import contextmanager
+import copy
+import enum
+import torch
+import keyword
+import re
+import builtins
+import math
+import warnings
+import inspect
+
+__all__ = ["PythonCode", "CodeGen", "Graph"]
+
+if TYPE_CHECKING:
+    from .graph_module import GraphModule  # noqa: F401
+    from ._symbolic_trace import Tracer   # noqa: F401
+
+
+# Mapping of builtins to their `typing` equivalent.
+_origin_type_map = {
+    list: List,
+    dict: Dict,
+    set: Set,
+    frozenset: FrozenSet,
+    tuple: Tuple,
+}
+
+
+# Signature for functions that transform the body (`list[str]`) of the
+# generated code.
+TransformCodeFunc = Callable[[List[str]], List[str]]
+
+
+class _CustomBuiltin(NamedTuple):
+    """Additional objs that we add to every graph's globals.
+
+    The repr() for some standard library objects is not valid Python code without
+    an import. For common objects of this sort, we bundle them in the globals of
+    every FX graph.
+    """
+    # How to import this object from the standard library.
+    import_str: str
+    # The actual object, produced from that import string.
+    obj: Any
+
+_custom_builtins: Dict[str, _CustomBuiltin] = {}
+
+
+def _register_custom_builtin(name: str, import_str: str, obj: Any):
+    _custom_builtins[name] = _CustomBuiltin(import_str, obj)
+
+
+_register_custom_builtin('inf', 'from math import inf', math.inf)
+_register_custom_builtin('nan', 'from math import nan', math.nan)
+_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None))
+_register_custom_builtin('torch', 'import torch', torch)
+_register_custom_builtin('device', 'from torch import device', torch.device)
+_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree)
+_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree)
+
+
+def _is_magic(x: str) -> bool:
+    return x.startswith('__') and x.endswith('__')
+
+
+def _snake_case(s: str) -> str:
+    """
+    Transforms the given string ``s`` to a Python-style variable name
+
+    Examples:
+        ``mod.snake_case`` -> ``mod.snake_case``
+        ``mod.pascalCase`` -> ``mod.pascal_case``
+        ``mod.ALL_CAPS`` -> ``mod.all_caps``
+    """
+    chars = []
+    prev_lower = False
+    for c in s:
+        if prev_lower and c.isupper():
+            chars.append('_')
+        chars.append(c.lower())
+        prev_lower = c.islower()
+    return ''.join(chars)
+
+
+def _is_from_torch(obj: Any) -> bool:
+    module_name = getattr(obj, '__module__', None)
+    if module_name is not None:
+        base_module = module_name.partition('.')[0]
+        return (
+            base_module == 'torch' and
+            not module_name.startswith("torch._dynamo.") and
+            not module_name.startswith("torch._inductor.")
+        )
+
+    name = getattr(obj, '__name__', None)
+    # Exclude `torch` itself: `torch.torch.torch.torch` also resolves to the torch
+    # module, so the attribute check below would otherwise match it.
+    if name is not None and name != 'torch':
+        for guess in [torch, torch.nn.functional]:
+            if getattr(guess, name, None) is obj:
+                return True
+
+    return False
+
+
+class _Namespace:
+    """A context for associating names uniquely with objects.
+
+    The following invariants are enforced:
+    - Each object gets a single name.
+    - Each name is unique within a given namespace.
+    - Names generated do not shadow builtins, unless the object is indeed that builtin.
+    """
+    def __init__(self):
+        self._obj_to_name: Dict[Any, str] = {}
+        self._unassociated_names = set()
+        self._used_names: Set[str] = set()
+        self._base_count: Dict[str, int] = defaultdict(int)
+
+        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
+        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")
+
+    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
+        """Create a unique name.
+
+        Arguments:
+            candidate: used as the basis for the unique name, relevant to the user.
+            obj: If not None, an object that will be associated with the unique name.
+        """
+        if obj is not None and obj in self._obj_to_name:
+            return self._obj_to_name[obj]
+
+        # delete all characters that are illegal in a Python identifier
+        candidate = self._illegal_char_regex.sub('_', candidate)
+
+        if not candidate:
+            candidate = '_unnamed'
+
+        if candidate[0].isdigit():
+            candidate = f'_{candidate}'
+
+        match = self._name_suffix_regex.match(candidate)
+        if match is None:
+            base = candidate
+            num = None
+        else:
+            base, num_str = match.group(1, 2)
+            num = int(num_str)
+
+        candidate = base if num is None else f'{base}_{num}'
+        if not num:
+            num = self._base_count[base]
+
+        while candidate in self._used_names or self._is_illegal_name(candidate, obj):
+            num += 1
+            candidate = f'{base}_{num}'
+
+        self._used_names.add(candidate)
+        self._base_count[base] = num
+        if obj is None:
+            self._unassociated_names.add(candidate)
+        else:
+            self._obj_to_name[obj] = candidate
+        return candidate
+
+    def associate_name_with_obj(self, name: str, obj: Any):
+        """Associate a unique name with an object.
+
+        Neither `name` nor `obj` should be associated already.
+        """
+        assert obj not in self._obj_to_name
+        assert name in self._unassociated_names
+        self._obj_to_name[obj] = name
+        self._unassociated_names.remove(name)
+
+    def _is_illegal_name(self, name: str, obj: Any) -> bool:
+        # 1. keywords are never allowed as names.
+        if name in keyword.kwlist:
+            return True
+
+        # 2. Can't shadow a builtin name, unless you *are* that builtin.
+        if name in builtins.__dict__:
+            return obj is not builtins.__dict__[name]
+
+        # 3. Can't shadow our custom builtins either
+        if name in _custom_builtins:
+            return obj is not _custom_builtins[name].obj
+
+        return False
+
+    def _rename_object(self, obj: Any, name: str):
+        assert obj in self._obj_to_name
+        self._obj_to_name[obj] = name
+        self._used_names.add(name)
+
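+# As a rough illustration of _Namespace behaviour (hypothetical values):
+#
+#     ns = _Namespace()
+#     ns.create_name("add", None)   # -> "add"
+#     ns.create_name("add", None)   # -> "add_1"  (deduplicated)
+#     ns.create_name("for", None)   # -> "for_1"  (keywords are never used as names)
+#     ns.create_name("1x", None)    # -> "_1x"    (identifiers cannot start with a digit)
+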
+dtype_abbrs = {
+    torch.bfloat16: 'bf16',
+    torch.float64: 'f64',
+    torch.float32: 'f32',
+    torch.float16: 'f16',
+    torch.float8_e4m3fn: 'f8e4m3fn',
+    torch.float8_e5m2: 'f8e5m2',
+    torch.float8_e4m3fnuz: 'f8e4m3fnuz',
+    torch.float8_e5m2fnuz: 'f8e5m2fnuz',
+    torch.complex32: 'c32',
+    torch.complex64: 'c64',
+    torch.complex128: 'c128',
+    torch.int8: 'i8',
+    torch.int16: 'i16',
+    torch.int32: 'i32',
+    torch.int64: 'i64',
+    torch.bool: 'b8',
+    torch.uint8: 'u8',
+    torch.uint32: 'u32',
+    torch.uint64: 'u64',
+}
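+
+# These abbreviations are used for the verbose per-node annotations emitted by
+# CodeGen below, e.g. a float32 tensor of shape (3, 4) is annotated as "f32[3, 4]".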
+
+@compatibility(is_backward_compatible=True)
+@dataclass
+class PythonCode:
+    """
+    Represents all the information necessary to exec or save a graph as Python code.
+    """
+    # Python source code for the forward function definition.
+    src: str
+    # Values in global scope during execution of `src`.
+    globals: Dict[str, Any]
+    # Optional mapping from the forward function's line number to
+    # node index.
+    _lineno_map: Optional[Dict[int, Optional[int]]]
+
+
+def _format_target(base: str, target: str) -> str:
+    elems = target.split('.')
+    r = base
+    for e in elems:
+        if not e.isidentifier():
+            r = f'getattr({r}, "{e}")'
+        else:
+            r = f'{r}.{e}'
+    return r
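+
+# For example (illustrative only): _format_target('self', 'linear.weight') gives
+# 'self.linear.weight', while a path component that is not a valid identifier falls
+# back to getattr, e.g. _format_target('self', 'layers.0.weight') gives
+# 'getattr(self.layers, "0").weight'.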
+
+class _InsertPoint:
+    def __init__(self, graph, new_insert):
+        self.graph = graph
+        self.orig_insert, graph._insert = graph._insert, new_insert
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, type, value, tb):
+        self.graph._insert = self.orig_insert
+
+class _node_list:
+    def __init__(self, graph: 'Graph', direction: str = '_next'):
+        assert direction in ['_next', '_prev']
+        self.graph = graph
+        self.direction = direction
+
+    def __len__(self):
+        return self.graph._len
+
+    def __iter__(self):
+        root = self.graph._root
+        if self.direction == "_next":
+            cur = root._next
+            while cur is not root:
+                if not cur._erased:
+                    yield cur
+                cur = cur._next
+        else:
+            assert self.direction == "_prev"
+            cur = root._prev
+            while cur is not root:
+                if not cur._erased:
+                    yield cur
+                cur = cur._prev
+
+    def __reversed__(self):
+        return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
+
+class _PyTreeInfo(NamedTuple):
+    """
+    Contains extra info stored when we're using Pytrees
+    """
+    orig_args: List[str]
+    in_spec: pytree.TreeSpec
+    out_spec: Optional[pytree.TreeSpec]
+
+@dataclass(frozen=True)
+class _ParsedStackTrace:
+    """
+    Represents the top-most frame of a parsed stack trace
+    """
+    file: str
+    lineno: str
+    name: str
+    code: str
+
+# get File:lineno code from stack_trace
+def _parse_stack_trace(stack_trace: str):
+    if stack_trace is None:
+        return None
+    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
+    lines = stack_trace.strip().split('\n')
+    # stacktrace should have innermost frame last, so we
+    # iterate backwards to find the first line that starts
+    # with 'File '
+    summary_str = ""
+    for idx in range(len(lines) - 2, -1, -1):
+        line = lines[idx].strip()
+        matches = pattern.match(line)
+        if matches:
+            file = matches.group(1)
+            lineno = matches.group(2)
+            name = matches.group(3)
+            # next line should be the code
+            code = lines[idx + 1].strip()
+            return _ParsedStackTrace(file, lineno, name, code)
+    return None
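+
+# For instance, given a (hypothetical) stack trace whose last two lines are
+#   File "example.py", line 10, in forward
+#     return x + 1
+# _parse_stack_trace returns
+# _ParsedStackTrace(file="example.py", lineno="10", name="forward", code="return x + 1").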
+
+@compatibility(is_backward_compatible=False)
+class CodeGen:
+    def __init__(self):
+        self._body_transformer: Optional[TransformCodeFunc] = None
+        self._func_name: str = "forward"
+
+    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
+        """
+        Given the free variables and a return annotation, generates the beginning of the FX function.
+        By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
+        """
+        # If the original function didn't have self as its first argument, we
+        # would have added it.
+        if len(free_vars) == 0 or free_vars[0] != 'self':
+            free_vars.insert(0, 'self')
+        return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
+
+    def generate_output(self, output_args: Argument) -> str:
+        """
+        Given the output arguments, generates the return statement of the FX function.
+        Note: The returned statement should not be indented.
+        """
+        return f'return {repr(output_args)}'
+
+    def process_inputs(self, *args: Any) -> Any:
+        """
+        Transforms the inputs so that the graph can take them as arguments, as
+        non-default codegen may result in the inputs to the function being
+        different from the inputs to the graph.
+
+        If the graph were directly runnable, this invariant should hold:
+        `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
+        """
+        return args
+
+    def process_outputs(self, outputs: Any) -> Any:
+        """
+        Transforms the outputs of the graph so that they match what the codegen's function returns.
+
+        See ``process_inputs`` for more details.
+        """
+        return outputs
+
+    def additional_globals(self) -> List[Tuple[str, Any]]:
+        """
+        If your codegen uses extra global values, add tuples of (identifier, reference to the value) here.
+        For example, return [('List', typing.List)] if you need ``List`` in the global context.
+        """
+        return []
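+
+    # A minimal sketch of a custom codegen (hypothetical subclass that only changes
+    # the generated function name):
+    #
+    #     class RenamedCodeGen(CodeGen):
+    #         def __init__(self):
+    #             super().__init__()
+    #             self._func_name = "my_forward"
+    #
+    #     graph._codegen = RenamedCodeGen()  # then recompile the owning GraphModule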
+
+    def _gen_python_code(
+        self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,
+    ) -> PythonCode:
+        free_vars: List[str] = []
+        body: List[str] = []
+        globals_: Dict[str, Any] = {}
+        wrapped_fns: Dict[str, None] = {}
+
+        # Wrap string in list to pass by reference
+        maybe_return_annotation : List[str] = ['']
+
+        def add_global(name_hint: str, obj: Any):
+            """Add an obj to be tracked as a global.
+
+            We call this for names that reference objects external to the
+            Graph, like functions or types.
+
+            Returns: the global name that should be used to reference 'obj' in generated source.
+            """
+            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
+                # HACK: workaround for how torch custom ops are registered. We
+                # can't import them like normal modules so they must retain their
+                # fully qualified name.
+                return _get_qualified_name(obj)
+
+            # normalize the name hint to get a proper identifier
+            global_name = namespace.create_name(name_hint, obj)
+
+            if global_name in globals_:
+                assert globals_[global_name] is obj
+                return global_name
+            globals_[global_name] = obj
+            return global_name
+
+        # Pre-fill the globals table with registered builtins.
+        for name, (_, obj) in _custom_builtins.items():
+            add_global(name, obj)
+
+        def type_repr(o : Any):
+            if o == ():
+                # Empty tuple is used for empty tuple type annotation Tuple[()]
+                return '()'
+
+            typename = _type_repr(o)
+
+            if hasattr(o, '__origin__'):
+                # This is a generic type, e.g. typing.List[torch.Tensor]
+                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+                origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+                if hasattr(o, '__args__'):
+                    # Assign global names for each of the inner type variables.
+                    args = [type_repr(arg) for arg in o.__args__]
+
+                    if len(args) == 0:
+                        # Bare type, such as `typing.Tuple` with no subscript
+                        # This code path is used in Python < 3.9
+                        return origin_typename
+
+                    return f'{origin_typename}[{",".join(args)}]'
+                else:
+                    # Bare type, such as `typing.Tuple` with no subscript
+                    # This code path is used in Python 3.9+
+                    return origin_typename
+
+            # Common case: this is a regular module name like 'foo.bar.baz'
+            return add_global(typename, o)
+
+        def _get_repr(arg: Any) -> str:
+            # Handle NamedTuples (if it has `_fields`) via add_global.
+            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+                qualified_name = _get_qualified_name(type(arg))
+                global_name = add_global(qualified_name, type(arg))
+                return f"{global_name}{repr(tuple(arg))}"
+            elif isinstance(arg, torch._ops.OpOverload):
+                qualified_name = _get_qualified_name(arg)
+                global_name = add_global(qualified_name, arg)
+                return f"{global_name}"
+            elif isinstance(arg, enum.Enum):
+                cls = arg.__class__
+                clsname = add_global(cls.__name__, cls)
+                return f"{clsname}.{arg.name}"
+            return repr(arg)
+
+        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
+            args_s = ', '.join(_get_repr(a) for a in args)
+            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+            if args_s and kwargs_s:
+                return f'{args_s}, {kwargs_s}'
+            return args_s or kwargs_s
+
+        # Run through reverse nodes and record the first instance of a use
+        # of a given node. This represents the *last* use of the node in the
+        # execution order of the program, which we will use to free unused
+        # values
+        node_to_last_use : Dict[Node, Node] = {}
+        user_to_last_uses : Dict[Node, List[Node]] = {}
+
+        def register_last_uses(n : Node, user : Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+                user_to_last_uses.setdefault(user, []).append(n)
+
+        for node in reversed(nodes):
+            map_arg(node.args, lambda n: register_last_uses(n, node))
+            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+        def delete_unused_values(user : Node):
+            """
+            Delete values after their last use. This ensures that values that are
+            not used in the remainder of the code are freed and the memory usage
+            of the code is optimal.
+            """
+            if user.op == 'placeholder':
+                return
+            if user.op == 'output':
+                body.append('\n')
+                return
+            nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(nodes_to_delete):
+                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
+                body.append(f';  {to_delete_str}\n')
+            else:
+                body.append('\n')
+
+        prev_stacktrace = None
+
+        def append_stacktrace_summary(node : Node):
+            """
+            Append a summary of the stacktrace to the generated code. This is
+            useful for debugging.
+            """
+            nonlocal prev_stacktrace
+
+            if node.op not in {'placeholder', 'output'}:
+                if node.stack_trace:
+                    if node.stack_trace != prev_stacktrace:
+                        prev_stacktrace = node.stack_trace
+                        summary_str = ""
+
+                        parsed_stack_trace = _parse_stack_trace(node.stack_trace)
+
+                        if parsed_stack_trace is not None:
+                            lineno = parsed_stack_trace.lineno
+                            code = parsed_stack_trace.code
+                            name = parsed_stack_trace.name
+                            summary_str = f'File: {parsed_stack_trace.file}:{lineno} in {name}, code: {code}'
+
+                        body.append(f'\n# {summary_str}\n')
+                elif prev_stacktrace != "":
+                    prev_stacktrace = ""
+                    body.append('\n# No stacktrace found for following nodes\n')
+
+        def stringify_shape(shape : torch.Size) -> str:
+            return f"[{', '.join(str(x) for x in shape)}]"
+
+        def emit_node(node : Node):
+            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+
+            if verbose:
+                # override annotation with more detailed information
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch.fx.experimental.proxy_tensor import py_sym_types
+                from torch.fx.passes.shape_prop import TensorMetadata
+
+                meta_val = node.meta.get('val', node.meta.get('tensor_meta', None))
+
+                # use string as annotation, to make it valid python code
+                if isinstance(meta_val, FakeTensor):
+                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
+                elif isinstance(meta_val, py_sym_types):
+                    maybe_type_annotation = f': "Sym({meta_val})"'
+                elif isinstance(meta_val, TensorMetadata):
+                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
+
+            if node.op == 'placeholder':
+                assert isinstance(node.target, str)
+                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
+                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
+                raw_name = node.target.replace('*', '')
+                if raw_name != repr(node):
+                    body.append(f'{repr(node)} = {raw_name}\n')
+                return
+            elif node.op == 'call_method':
+                assert isinstance(node.target, str)
+                body.append(
+                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
+                    f'({_format_args(node.args[1:], node.kwargs)})')
+                return
+            elif node.op == 'call_function':
+                assert callable(node.target)
+                # pretty print operators
+                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
+                    assert isinstance(node.args, tuple)
+                    body.append(f'{repr(node)}{maybe_type_annotation} = '
+                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
+                    return
+
+                # pretty print inplace operators; required for jit.script to work properly
+                # not currently supported in normal FX graphs, but generated by torchdynamo
+                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
+                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))};  '
+                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
+                    return
+
+                qualified_name = _get_qualified_name(node.target)
+                global_name = add_global(qualified_name, node.target)
+                # Special case for getattr: node.args could have 2 or 3 arguments.
+                # 2 arguments: plain attribute access; 3 arguments: fall through to a
+                # getattr call with a default value.
+                if global_name == 'getattr' and \
+                   isinstance(node.args, tuple) and \
+                   isinstance(node.args[1], str) and \
+                   node.args[1].isidentifier() and \
+                   len(node.args) == 2:
+                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
+                    return
+                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+                if node.meta.get('is_wrapped', False):
+                    wrapped_fns.setdefault(global_name)
+                return
+            elif node.op == 'call_module':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = '
+                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                return
+            elif node.op == 'get_attr':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                return
+            elif node.op == 'output':
+                if node.type is not None:
+                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+                body.append(self.generate_output(node.args[0]))
+                return
+            raise NotImplementedError(f'node: {node.op} {node.target}')
+
+        for i, node in enumerate(nodes):
+            # NOTE: emit_node does not emit a string with newline. It depends
+            # on delete_unused_values to append one
+            if verbose:
+                append_stacktrace_summary(node)
+            # emit a counter comment to keep track of
+            # node index, which will be deleted later
+            # after going through _body_transformer
+            body.append(f"# COUNTER: {i}\n")
+            emit_node(node)
+            delete_unused_values(node)
+
+        if len(body) == 0:
+            # If the Graph has no non-placeholder nodes, no lines for the body
+            # have been emitted. To continue to have valid Python code, emit a
+            # single pass statement
+            body.append('pass\n')
+
+        if len(wrapped_fns) > 0:
+            wrap_name = add_global('wrap', torch.fx.wrap)
+            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+        else:
+            wrap_stmts = ''
+
+        if self._body_transformer:
+            body = self._body_transformer(body)
+
+        for name, value in self.additional_globals():
+            add_global(name, value)
+
+        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+
+        # remove counter and generate lineno to node index mapping
+        lineno_map: Dict[int, Optional[int]] = {}
+        prologue_len = prologue.count('\n') + 1
+        new_lines: List[str] = []
+        cur_idx = None
+        for line in ''.join(body).split('\n'):
+            counter = re.search(r"# COUNTER: (\d+)", line)
+            if counter and counter.group(1) is not None:
+                cur_idx = int(counter.group(1))
+            else:
+                lineno_map[len(new_lines) + prologue_len] = cur_idx
+                new_lines.append(line)
+
+        code = "\n".join(new_lines).lstrip('\n')
+        code = '\n'.join('    ' + line for line in code.split('\n'))
+
+        fn_code = f"""
+{wrap_stmts}
+
+{prologue}
+{code}"""
+        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)
+
+
+# Ideally, we'd like to refactor all of the pytree logic into this codegen
+# class. Unfortunately, there are 3 areas we currently need extra logic in FX.
+# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
+# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
+#    Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
+# 3. We currently can't register the pytree imports with `add_global` - not sure why.
+class _PyTreeCodeGen(CodeGen):
+    def __init__(self, pytree_info: _PyTreeInfo):
+        super().__init__()
+        self.pytree_info: _PyTreeInfo = pytree_info
+
+    def process_inputs(self, *inputs: Any) -> Any:
+        flat_args = pytree.arg_tree_leaves(*inputs)
+        return flat_args
+
+    def process_outputs(self, out: Any) -> Any:
+        if self.pytree_info is None or self.pytree_info.out_spec is None:
+            return out
+        if not isinstance(out, (list, tuple)):
+            out = [out]
+        assert self.pytree_info.out_spec is not None
+        return pytree.tree_unflatten(out, self.pytree_info.out_spec)
+
+    def gen_fn_def(self, free_vars, maybe_return_annotation):
+        # Given a user function/model:
+        #   myargs = [myargs0, myargs1]
+        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
+        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
+        #
+        # The generated code flattens all keywords into positional arguments for `forward()`
+        #   e.g. forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
+        #
+        # Within `forward`, `tree_flatten_spec` still parses args and kwargs separately
+        #   e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
+        #                           {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
+        #                          self._in_spec)
+        #
+        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
+        #   e.g. tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
+        if self.pytree_info is None:
+            return super().gen_fn_def(free_vars, maybe_return_annotation)
+
+        fn_args = self.pytree_info.orig_args
+        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
+        if has_orig_self:
+            free_vars.insert(0, 'self')
+        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
+
+        if len(free_vars) > 0:  # pytree has placeholders in it
+            # when kwargs is present, in_spec is tuple(args, kwargs)
+            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
+                self.pytree_info.in_spec.num_children == 2 and \
+                self.pytree_info.in_spec.children_specs[0].type == tuple and \
+                self.pytree_info.in_spec.children_specs[1].type == dict
+            fn_kwargs = '{}'
+            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
+            if has_args_kwargs_tuple:
+                count_args = self.pytree_info.in_spec.children_specs[0].num_children
+                fn_args = self.pytree_info.orig_args[:count_args]
+                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
+                                  self.pytree_info.in_spec.children_specs[1].context,
+                                  self.pytree_info.orig_args[count_args:])) + '}'
+                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
+
+            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
+            # we need to split it to two lines:
+            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
+            # one for code: `var1, var2, = function_call()`
+            without_annotation = [x.split(":")[0] for x in free_vars]
+            has_annotation = [x + "; " for x in free_vars if ":" in x]
+            if len(has_annotation) > 0:
+                fn_definition += "\n    " + "".join(has_annotation) + "\n"
+            fn_definition += f"""
+    {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
+        return fn_definition
+
+    def generate_output(self, output_args):
+        if self.pytree_info and self.pytree_info.out_spec:
+            return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
+        else:
+            return super().generate_output(output_args)
+
+@compatibility(is_backward_compatible=True)
+class Graph:
+    """
+    ``Graph`` is the main data structure used in the FX Intermediate Representation.
+    It consists of a series of ``Node`` s, each representing callsites (or other
+    syntactic constructs). The list of ``Node`` s, taken together, constitute a
+    valid Python function.
+
+    For example, the following code
+
+    .. code-block:: python
+
+        import torch
+        import torch.fx
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 4))
+                self.linear = torch.nn.Linear(4, 5)
+
+            def forward(self, x):
+                return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
+
+        m = MyModule()
+        gm = torch.fx.symbolic_trace(m)
+
+    Will produce the following Graph::
+
+        print(gm.graph)
+
+    .. code-block:: text
+
+        graph(x):
+            %linear_weight : [num_users=1] = self.linear.weight
+            %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
+            %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
+            %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
+            %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
+            %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
+            return topk_1
+
+    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
+    """
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
+                 tracer_extras: Optional[Dict[str, Any]] = None):
+        """
+        Construct an empty Graph.
+        """
+        self._root : Node = Node(self, '', 'root', '', (), {})
+        self._used_names : Dict[str, int] = {}  # base name -> number
+        self._insert = self._root.prepend
+        self._len = 0
+        self._graph_namespace = _Namespace()
+        self._owning_module = owning_module
+        self._tracer_cls = tracer_cls
+        self._tracer_extras = tracer_extras
+        self._codegen = CodeGen()
+        self._co_fields : Dict[str, Any] = {}
+
+    @property
+    def owning_module(self):
+        return self._owning_module
+
+    @owning_module.setter
+    def owning_module(self, mod: Optional["GraphModule"]):
+        self._owning_module = mod
+
+    @property
+    def nodes(self) -> _node_list:
+        """
+        Get the list of Nodes that constitute this Graph.
+
+        Note that this ``Node`` list representation is a doubly-linked list. Mutations
+        during iteration (e.g. delete a Node, add a Node) are safe.
+
+        Returns:
+
+            A doubly-linked list of Nodes. Note that ``reversed`` can be called on
+            this list to switch iteration order.
+        """
+        return _node_list(self)
+
+    @compatibility(is_backward_compatible=True)
+    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
+        """
+        Copy all nodes from a given graph into ``self``.
+
+        Args:
+
+            g (Graph): The source graph from which to copy Nodes.
+
+            val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
+                from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
+                in with values in it already to override copying of certain values.
+
+        Returns:
+
+            The value in ``self`` that is now equivalent to the output value in ``g``,
+            if ``g`` had an ``output`` node. ``None`` otherwise.
+        """
+        for node in g.nodes:
+            if node in val_map:
+                continue
+            if node.op == 'output':
+                rv = map_arg(node.args[0], lambda n: val_map[n])
+                return rv if not return_output_node else (rv, node)
+            val_map[node] = self.node_copy(node, lambda n : val_map[n])
+        return None
+
+    def __deepcopy__(self, memo=None) -> 'Graph':
+        """
+        Explicitly implement __deepcopy__ to prevent excessive recursion depth
+        from the default implementation. This uses graph_copy to copy the nodes
+        in an iterative way, rather than recursively. It also populates the
+        memoization table to prevent unnecessary copies (e.g. references to
+        nodes or other parts of the Graph from a custom GraphModule implementation).
+        """
+        memo = memo if memo else {}
+        g = Graph(tracer_cls=self._tracer_cls)
+        output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
+        g._codegen = copy.deepcopy(self._codegen)
+        assert isinstance(output_vals, tuple)
+        output_val, old_output_node = output_vals
+        new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None))
+        new_output_node.meta = copy.copy(old_output_node.meta)
+        return g
+
+    @compatibility(is_backward_compatible=True)
+    def create_node(self, op: str, target: 'Target',
+                    args: Optional[Tuple['Argument', ...]] = None,
+                    kwargs: Optional[Dict[str, 'Argument']] = None,
+                    name: Optional[str] = None,
+                    type_expr: Optional[Any] = None) -> Node:
+        """
+        Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
+        Note that the current insert-point can be set via :meth:`Graph.inserting_before`
+        and :meth:`Graph.inserting_after`.
+
+        Args:
+            op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
+                'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
+                described in the ``Graph`` docstring.
+
+            args (Optional[Tuple[Argument, ...]]): a tuple of arguments to this node.
+
+            kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node
+
+            name (Optional[str]): an optional string name for the ``Node``.
+                This will influence the name of the value assigned to in the
+                Python generated code.
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+        Returns:
+
+            The newly-created and inserted node.
+        """
+        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
+        args = () if args is None else args
+        kwargs = {} if kwargs is None else kwargs
+        assert isinstance(args, tuple), "args must be a tuple"
+        assert isinstance(kwargs, dict), "kwargs must be a dict"
+
+        candidate = name if name is not None else self._target_to_str(target)
+        name = self._graph_namespace.create_name(candidate, None)
+        n = Node(self, name, op, target, args, kwargs, type_expr)
+
+        self._graph_namespace.associate_name_with_obj(name, n)
+
+        self._insert(n)
+        self._len += 1
+        return n
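+
+    # For example (an illustrative sketch, assuming `n` is an existing Node in this
+    # graph that produces a Tensor):
+    #
+    #     with graph.inserting_after(n):
+    #         relu_node = graph.create_node('call_function', torch.relu, args=(n,))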
+
+    @compatibility(is_backward_compatible=False)
+    def process_inputs(self, *args):
+        """
+        Processes args so that they can be passed to the FX graph.
+        """
+        return self._codegen.process_inputs(*args)
+
+    @compatibility(is_backward_compatible=False)
+    def process_outputs(self, out):
+        return self._codegen.process_outputs(out)
+
+
+    @compatibility(is_backward_compatible=True)
+    def erase_node(self, to_erase : Node) -> None:
+        """
+        Erases a ``Node`` from the ``Graph``. Throws an exception if
+        there are still users of that node in the ``Graph``.
+
+        Args:
+
+            to_erase (Node): The ``Node`` to erase from the ``Graph``.
+        """
+        if len(to_erase.users) > 0:
+            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
+                               f'users in the graph: {to_erase.users}!')
+        if to_erase.graph != self:
+            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
+        if to_erase._erased:
+            warnings.warn(f"erase_node({to_erase}) on an already erased node")
+            return
+
+        to_erase._remove_from_list()
+        to_erase._erased = True  # iterators may retain handles to erased nodes
+        self._len -= 1
+
+        # Null out this Node's argument nodes so that the Nodes referred to
+        # can update their ``users`` accordingly
+        new_args = map_arg(to_erase.args, lambda n: None)
+        assert isinstance(new_args, tuple)
+        to_erase.args = new_args
+        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
+        assert isinstance(new_kwargs, dict)
+        to_erase.kwargs = new_kwargs
+
+    @compatibility(is_backward_compatible=True)
+    def inserting_before(self, n: Optional[Node] = None):
+        """Set the point at which create_node and companion methods will insert into the graph.
+        When used within a 'with' statement, this will temporarily set the insert point and
+        then restore it when the with statement exits::
+
+            with g.inserting_before(n):
+                ... # inserting before node n
+            ... # insert point restored to what it was previously
+            g.inserting_before(n) #  set the insert point permanently
+
+        Args:
+
+            n (Optional[Node]): The node before which to insert. If None this will insert before
+                the beginning of the entire graph.
+
+        Returns:
+            A resource manager that will restore the insert point on ``__exit__``.
+        """
+        if n is None:
+            return self.inserting_after(self._root)
+        assert n.graph == self, "Node to insert before is not in graph."
+        return _InsertPoint(self, n.prepend)
+
+    @compatibility(is_backward_compatible=True)
+    def inserting_after(self, n: Optional[Node] = None):
+        """Set the point at which create_node and companion methods will insert into the graph.
+        When used within a 'with' statement, this will temporarily set the insert point and
+        then restore it when the with statement exits::
+
+            with g.inserting_after(n):
+                ... # inserting after node n
+            ... # insert point restored to what it was previously
+            g.inserting_after(n) #  set the insert point permanently
+
+        Args:
+
+            n (Optional[Node]): The node after which to insert. If None this will insert after
+                the beginning of the entire graph.
+
+        Returns:
+            A resource manager that will restore the insert point on ``__exit__``.
+        """
+        if n is None:
+            return self.inserting_before(self._root)
+        assert n.graph == self, "Node to insert after is not in graph."
+        return _InsertPoint(self, n.append)
+
+    @compatibility(is_backward_compatible=True)
+    def placeholder(self, name: str, type_expr: Optional[Any] = None,
+                    default_value : Any = inspect.Signature.empty) -> Node:
+        """
+        Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
+        a function input.
+
+        Args:
+
+            name (str): A name for the input value. This corresponds to the name
+                of the positional argument to the function this ``Graph`` represents.
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have. This is needed in some
+                cases for proper code generation (e.g. when the function is used
+                subsequently in TorchScript compilation).
+
+            default_value (Any): The default value this function argument should take
+                on. NOTE: ``None`` is a valid default value; to specify that the
+                parameter has *no* default value, pass ``inspect.Signature.empty``
+                for this argument.
+
+        .. note::
+            The same insertion point and type expression rules apply for this method
+            as ``Graph.create_node``.
+        """
+        args = () if default_value is inspect.Signature.empty else (default_value,)
+        return self.create_node('placeholder', name, args=args, type_expr=type_expr)
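+
+    # e.g. (illustrative): graph.placeholder('x', type_expr=torch.Tensor) adds an input
+    # named `x`; since no default_value is given, the generated signature has no default
+    # for `x`.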
+
+    @compatibility(is_backward_compatible=True)
+    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
+        fetch of an attribute from the ``Module`` hierarchy.
+
+        Args:
+
+            qualified_name (str): the fully-qualified name of the attribute to be retrieved.
+                For example, if the traced Module has a submodule named ``foo``, which has a
+                submodule named ``bar``, which has an attribute named ``baz``, the qualified
+                name ``foo.bar.baz`` should be passed as ``qualified_name``.
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+
+        Returns:
+
+            The newly-created and inserted ``get_attr`` node.
+
+        .. note::
+            The same insertion point and type expression rules apply for this method
+            as ``Graph.create_node``.
+        """
+        def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool:
+            module_path, _, name = qualified_name.rpartition(".")
+
+            try:
+                submod: torch.nn.Module = mod.get_submodule(module_path)
+            except AttributeError:
+                warnings.warn(f"Failed to fetch module {module_path}!")
+                return False
+
+            if not hasattr(submod, name):
+                return False
+
+            res = getattr(submod, name)
+
+            if (not isinstance(res, torch.nn.Module)
+                    and not isinstance(res, torch.nn.Parameter)
+                    and name not in submod._buffers):
+                return False
+
+            return True
+
+        if (self.owning_module and
+                not _get_attr_reference_exists(self.owning_module, qualified_name)):
+            warnings.warn("Attempted to insert a get_attr Node with no "
+                          "underlying reference in the owning "
+                          "GraphModule! Call "
+                          "GraphModule.add_submodule to add the "
+                          "necessary submodule, "
+                          "GraphModule.add_parameter to add the "
+                          "necessary Parameter, or "
+                          "nn.Module.register_buffer to add the "
+                          "necessary buffer", stacklevel=2)
+        return self.create_node('get_attr', qualified_name, type_expr=type_expr)
+
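+    # A minimal usage sketch (assumptions: ``g`` is owned by a GraphModule whose
+    # root has a submodule ``linear`` with a ``weight`` Parameter; the names are
+    # illustrative only):
+    #
+    #     w = g.get_attr('linear.weight')   # fetches self.linear.weight at runtime
+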
+    @compatibility(is_backward_compatible=True)
+    def call_module(self,
+                    module_name: str,
+                    args: Optional[Tuple['Argument', ...]] = None,
+                    kwargs: Optional[Dict[str, 'Argument']] = None,
+                    type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
+        represents a call to the forward() function of a ``Module`` in the ``Module``
+        hierarchy.
+
+        Args:
+
+            module_name (str): The qualified name of the ``Module`` in the ``Module``
+                hierarchy to be called. For example, if the traced ``Module`` has a
+                submodule named ``foo``, which has a submodule named ``bar``, the
+                qualified name ``foo.bar`` should be passed as ``module_name`` to
+                call that module.
+
+            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
+                to the called method. Note that this should *not* include a ``self`` argument.
+
+            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
+                to the called method
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+        Returns:
+
+            The newly-created and inserted ``call_module`` node.
+
+        .. note::
+            The same insertion point and type expression rules apply for this method
+            as :meth:`Graph.create_node`.
+        """
+        if (self.owning_module and
+                self.owning_module.get_submodule(module_name) is None):
+            warnings.warn("Attempted to insert a call_module Node with "
+                          "no underlying reference in the owning "
+                          "GraphModule! Call "
+                          "GraphModule.add_submodule to add the "
+                          "necessary submodule")
+        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
+
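+    # A minimal usage sketch (assumptions: the owning module has a submodule
+    # ``linear`` and ``x`` is an existing Node in ``g``):
+    #
+    #     out = g.call_module('linear', args=(x,))   # calls self.linear(x)
+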
+    @compatibility(is_backward_compatible=True)
+    def call_method(self,
+                    method_name: str,
+                    args: Optional[Tuple['Argument', ...]] = None,
+                    kwargs: Optional[Dict[str, 'Argument']] = None,
+                    type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
+        represents a call to a given method on the 0th element of ``args``.
+
+        Args:
+
+            method_name (str): The name of the method to apply to the self argument.
+                For example, if args[0] is a ``Node`` representing a ``Tensor``,
+                then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
+
+            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
+                to the called method. Note that this *should* include a ``self`` argument.
+
+            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
+                to the called method
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+        Returns:
+
+            The newly created and inserted ``call_method`` node.
+
+        .. note::
+            The same insertion point and type expression rules apply for this method
+            as :meth:`Graph.create_node`.
+        """
+        return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
+
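+    # A minimal usage sketch (assumption: ``x`` is an existing Node producing a
+    # Tensor). Note that ``x`` itself serves as the ``self`` argument:
+    #
+    #     y = g.call_method('relu', args=(x,))   # equivalent to x.relu()
+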
+    @compatibility(is_backward_compatible=True)
+    def call_function(self,
+                      the_function: Callable[..., Any],
+                      args: Optional[Tuple['Argument', ...]] = None,
+                      kwargs: Optional[Dict[str, 'Argument']] = None,
+                      type_expr: Optional[Any] = None) -> Node:
+        """
+        Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
+        represents a call to a Python callable, specified by ``the_function``.
+
+        Args:
+
+            the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
+                operator, Python function, or member of the ``builtins`` or ``operator``
+                namespaces.
+
+            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
+                to the called function.
+
+            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
+                to the called function
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+        Returns:
+
+            The newly created and inserted ``call_function`` node.
+
+        .. note::
+            The same insertion point and type expression rules apply for this method
+            as :meth:`Graph.create_node`.
+        """
+        return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
+
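+    # A minimal usage sketch (assumptions: ``x`` and ``y`` are existing Nodes and
+    # the standard library ``operator`` module is in scope):
+    #
+    #     z = g.call_function(operator.add, args=(x, y))   # z = x + y
+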
+    @compatibility(is_backward_compatible=True)
+    def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
+        """
+        Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
+        the graph of node to the graph of self. Example::
+
+            # Copying all the nodes in `g` into `new_graph`
+            g : torch.fx.Graph = ...
+            new_graph = torch.fx.Graph()
+            value_remap = {}
+            for node in g.nodes:
+                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
+
+        Args:
+
+            node (Node): The node to copy into ``self``.
+
+            arg_transform (Callable[[Node], Argument]): A function that transforms
+                ``Node`` arguments in node's ``args`` and ``kwargs`` into the
+                equivalent argument in ``self``. In the simplest case, this should
+                retrieve a value out of a table mapping Nodes in the original
+                graph to ``self``.
+        """
+        args = map_arg(node.args, arg_transform)
+        kwargs = map_arg(node.kwargs, arg_transform)
+        assert isinstance(args, tuple)
+        assert isinstance(kwargs, dict)
+        result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type)
+        result_node.meta = copy.copy(node.meta)
+        return result_node
+
+    @compatibility(is_backward_compatible=True)
+    def output(self, result: 'Argument', type_expr: Optional[Any] = None):
+        """
+        Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
+        a ``return`` statement in Python code. ``result`` is the value that should
+        be returned.
+
+        Args:
+
+            result (Argument): The value to be returned.
+
+            type_expr (Optional[Any]): an optional type annotation representing the
+                Python type the output of this node will have.
+
+        .. note::
+
+            The same insertion point and type expression rules apply for this method
+            as ``Graph.create_node``.
+        """
+        return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
+
+    def _target_to_str(self, target : Target) -> str:
+        if callable(target):
+            op = target.__name__
+        else:
+            assert isinstance(target, str)
+            op = target
+            if _is_magic(op):
+                op = op[2:-2]
+        op = _snake_case(op)
+        return op
+
+    @compatibility(is_backward_compatible=True)
+    def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode:
+        """
+        Turn this ``Graph`` into valid Python code.
+
+        Args:
+
+            root_module (str): The name of the root module on which to look-up
+                qualified name targets. This is usually 'self'.
+
+        Returns:
+
+            A PythonCode object, consisting of two fields:
+                src: the Python source code representing the object
+                globals: a dictionary of global names in `src` -> the objects that they reference.
+        """
+        # NOTE: [Graph Namespaces]
+        #
+        # There are two types of symbols in generated Python source code:
+        # locals and globals.
+        #   Locals are locally defined by the output of a node in the Graph.
+        #   Globals are references to external objects, like functions or types.
+        #
+        # When generating Python code, we need to make sure to name things
+        # appropriately. In particular:
+        # - All names should be unique, to avoid weird shadowing bugs.
+        # - These names need to be consistent, e.g. an object should always be
+        #   referenced by the same name.
+        #
+        # To do this, we create a new namespace just for this source. All names
+        # that get printed must come from this namespace.
+        #
+        # Why can't we re-use node.name? Because it was generated within the
+        # namespace `self._graph_namespace`. In order to provide uniqueness
+        # over both locals (node.name) *and* globals, we create a completely
+        # new namespace to put all identifiers in.
+        namespace = _Namespace()
+
+        # Override Node's repr to generate a valid name within our namespace.
+        # Since repr() is designed to produce a valid Python expression, it
+        # makes sense to re-use it. This way, it's easy to print something like
+        # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
+        # implemented cooperatively to allow this.
+        def node_repr(n: Node):
+            return namespace.create_name(n.name, n)
+
+        @contextmanager
+        def override_node_repr(graph: Graph):
+            orig_repr_fns = {}
+            for node in graph.nodes:
+                orig_repr_fns[node] = node._repr_fn
+                node._repr_fn = node_repr
+            try:
+                yield None
+            finally:
+                # restore the original repr functions
+                for node in graph.nodes:
+                    node._repr_fn = orig_repr_fns[node]
+
+        with override_node_repr(self):
+            return self._python_code(root_module, namespace, verbose=verbose)
+
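+    # A minimal usage sketch (assumption: ``g`` is an already-populated Graph):
+    #
+    #     code = g.python_code(root_module='self')
+    #     print(code.src)       # generated Python source for forward()
+    #     print(code.globals)   # external objects referenced by that source
+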
+    def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
+        return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose)
+
+
+    def __str__(self) -> str:
+        """
+        Return a human-readable (not machine-readable) string representation
+        of this Graph
+        """
+        placeholder_names : List[str] = []
+        # This is a one-element array just so ``format_node`` can modify the closed
+        # over value
+        maybe_return_typename : List[str] = ['']
+
+        node_strs = [node.format_node(placeholder_names) for node in self.nodes]
+        param_str = ', '.join(placeholder_names)
+        s = f'graph({param_str}){maybe_return_typename[0]}:'
+        for node_str in node_strs:
+            if node_str:
+                s += '\n    ' + node_str
+        return s
+
+    @compatibility(is_backward_compatible=True)
+    def print_tabular(self):
+        """
+        Prints the intermediate representation of the graph in tabular
+        format. Note that this API requires the ``tabulate`` module to be
+        installed.
+        """
+        try:
+            from tabulate import tabulate
+        except ImportError:
+            print("`print_tabular` relies on the library `tabulate`, "
+                  "which could not be found on this machine. Run `pip "
+                  "install tabulate` to install the library.")
+            raise
+
+        node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
+                      for n in self.nodes]
+        print(tabulate(node_specs,
+              headers=['opcode', 'name', 'target', 'args', 'kwargs']))
+
+    @compatibility(is_backward_compatible=True)
+    def lint(self):
+        """
+        Runs various checks on this Graph to make sure it is well-formed. In
+        particular:
+        - Checks Nodes have correct ownership (owned by this graph)
+        - Checks Nodes appear in topological order
+        - If this Graph has an owning GraphModule, checks that targets
+        exist in that GraphModule
+        """
+
+        # Check topo order
+        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
+            context_str = f' of Node \'{n}\' ' if n else ' '
+            if arg.graph is not self:
+                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
+                                   f'but was used as an argument! If you are copying nodes from another graph, make '
+                                   f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
+            if arg not in seen_values:
+                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
+                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')
+
+        seen_names : Set[str] = set()
+        seen_values : Set[Node] = set()
+        for node in self.nodes:
+            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
+                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
+            if node.graph is not self:
+                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
+            map_arg(node.args, lambda arg: check_arg(arg, node))
+            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
+            seen_values.add(node)
+
+            if node.name in seen_names:
+                raise RuntimeError(f'Node redefined name {node.name}!')
+            seen_names.add(node.name)
+
+        # Check targets are legit
+        if self.owning_module:
+            for node in self.nodes:
+                if node.op == 'call_function':
+                    if not callable(node.target):
+                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
+                                         'a Callable is expected')
+                else:
+                    if not isinstance(node.target, str):
+                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
+                                         'a str is expected')
+                if node.op in ['get_attr', 'call_module']:
+                    target_atoms = node.target.split('.')
+                    m_itr = self.owning_module
+                    for i, atom in enumerate(target_atoms):
+                        new_m_itr = getattr(m_itr, atom, None)
+                        seen_qualname = '.'.join(target_atoms[:i])
+                        if new_m_itr is None:
+                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
+                                               f'{atom} of {seen_qualname}')
+                        if (node.op == "call_module"
+                                and not isinstance(new_m_itr, torch.nn.Module)):
+                            raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
+                                               'not reference an nn.Module')
+                        elif (node.op == "get_attr"
+                              and not isinstance(new_m_itr, torch.nn.Module)
+                              and not isinstance(new_m_itr, torch.nn.Parameter)
+                              and atom not in m_itr._buffers):
+                            warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
+                                          'not reference an nn.Module, nn.Parameter, or buffer, which is '
+                                          'what \'get_attr\' Nodes typically target')
+                        else:
+                            m_itr = new_m_itr
+
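+    # A minimal usage sketch (assumption: ``g`` is a Graph that was just edited,
+    # typically while owned by a GraphModule):
+    #
+    #     g.lint()   # raises if nodes are mis-owned, out of order, or mis-targeted
+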
+    @compatibility(is_backward_compatible=True)
+    def eliminate_dead_code(self):
+        """
+        Remove all dead code from the graph, based on each node's number of
+        users, and whether the nodes have any side effects. The graph must be
+        topologically sorted before calling.
+
+        Returns:
+          bool: Whether the graph was changed as a result of the pass.
+
+        Example:
+
+        Before dead code is eliminated, `a` from `a = x + 1` below has no users
+        and thus can be eliminated from the graph without having an effect.
+
+        .. code-block:: python
+
+            def forward(self, x):
+                a = x + 1
+                return x + self.attr_1
+
+        After dead code is eliminated, `a = x + 1` has been removed, and the rest
+        of `forward` remains.
+
+        .. code-block:: python
+
+            def forward(self, x):
+                return x + self.attr_1
+
+        .. warning::
+
+            Dead code elimination has some heuristics to avoid removing
+            side-effectful nodes (see Node.is_impure) but in general coverage
+            is very bad, so you should assume that this method is not sound
+            to call unless you know that your FX graph consists entirely
+            of functional operations.
+        """
+        # Lint the graph first to make sure it's topologically sorted; otherwise
+        # DCE below will not behave as expected.
+        self.lint()
+
+        # Reverse iterate so that when we remove a node, any nodes used as an
+        # input to that node have an updated user count that no longer reflects
+        # the removed node.
+        changed = False
+        for node in reversed(self.nodes):
+            if not node.is_impure() and len(node.users) == 0:
+                self.erase_node(node)
+                changed = True
+
+        return changed
+
+    @compatibility(is_backward_compatible=False)
+    def set_codegen(self, codegen: CodeGen):
+        self._codegen = codegen
+
+    @compatibility(is_backward_compatible=False)
+    def on_generate_code(
+        self,
+        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
+    ):
+        """Register a transformer function when python code is generated
+
+        Args:
+            make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
+                a function that returns a code transformer to be registered.
+                This function is called by `on_generate_code` to obtain the
+                code transformer.
+
+                This function is also given as its input the currently
+                registered code transformer (or None if nothing is registered),
+                in case it is not desirable to overwrite it. This is useful to
+                chain code transformers together.
+
+        Returns:
+            a context manager that, when used in a ``with`` statement, automatically
+            restores the previously registered code transformer.
+
+        Example:
+
+        .. code-block:: python
+
+
+            gm: fx.GraphModule = ...
+
+            # This is a code transformer we want to register. This code
+            # transformer prepends a pdb import and trace statement at the very
+            # beginning of the generated torch.fx code to allow for manual
+            # debugging with the PDB library.
+            def insert_pdb(body):
+                return ["import pdb; pdb.set_trace()\\n", *body]
+
+            # Registers `insert_pdb`, and overwrites the current registered
+            # code transformer (given by `_` to the lambda):
+            gm.graph.on_generate_code(
+                lambda _: insert_pdb
+            )
+
+            # Or alternatively, registers a code transformer which first
+            # runs `body` through existing registered transformer, then
+            # through `insert_pdb`:
+            gm.graph.on_generate_code(
+                lambda current_trans: (
+                    lambda body: insert_pdb(
+                        current_trans(body) if current_trans
+                        else body
+                    )
+                )
+            )
+
+            gm.recompile()
+            gm(*inputs)  # drops into pdb
+
+
+        This function can also be used as a context manager, with the benefit of
+        automatically restoring the previously registered code transformer:
+
+        .. code-block:: python
+
+            # ... continue from previous example
+
+            with gm.graph.on_generate_code(lambda _: insert_pdb):
+                # do more stuff with `gm`...
+                gm.recompile()
+                gm(*inputs)  # drops into pdb
+
+            # now the previous code transformer is restored (but `gm`'s code with pdb
+            # remains - that means you can run `gm` with pdb here too, until you
+            # run the next `recompile()`).
+        """
+        on_gen_code_old = self._codegen._body_transformer
+        self._codegen._body_transformer = make_transformer(on_gen_code_old)
+
+        @contextlib.contextmanager
+        def on_generate_code_context_manager():
+            try:
+                yield
+            finally:
+                self._codegen._body_transformer = on_gen_code_old
+
+        return on_generate_code_context_manager()
+
+
+reflectable_magic_methods = {
+    'add': '{} + {}',
+    'sub': '{} - {}',
+    'mul': '{} * {}',
+    'floordiv': '{} // {}',
+    'truediv': '{} / {}',
+    'div': '{} / {}',
+    'mod': '{} % {}',
+    'pow': '{} ** {}',
+    'lshift': '{} << {}',
+    'rshift': '{} >> {}',
+    'and_': '{} & {}',
+    'or_': '{} | {}',
+    'xor': '{} ^ {}',
+    'getitem': '{}[{}]',
+    'matmul': '{} @ {}',
+}
+
+magic_methods = dict({
+    'eq': '{} == {}',
+    'ne': '{} != {}',
+    'lt': '{} < {}',
+    'gt': '{} > {}',
+    'le': '{} <= {}',
+    'ge': '{} >= {}',
+    'pos': '+{}',
+    'neg': '-{}',
+    'invert': '~{}'}, **reflectable_magic_methods)
+
+inplace_methods = {
+    'iadd': '{} += {}',
+    'iand': '{} &= {}',
+    'ifloordiv': '{} //= {}',
+    'ilshift': '{} <<= {}',
+    'imod': '{} %= {}',
+    'imul': '{} *= {}',
+    'imatmul': '{} @= {}',
+    'ior': '{} |= {}',
+    'ipow': '{} **= {}',
+    'irshift': '{} >>= {}',
+    'isub': '{} -= {}',
+    'itruediv': '{} /= {}',
+    'ixor': '{} ^= {}',
+    'setitem': '{}[{}] = {}',
+}
diff --git a/MLPY/Lib/site-packages/torch/fx/graph_module.py b/MLPY/Lib/site-packages/torch/fx/graph_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c1d8357ad0212bc5261a540cf3beeb61f545e9f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/graph_module.py
@@ -0,0 +1,884 @@
+import contextlib
+import copy
+import itertools
+import linecache
+import os
+import sys
+import traceback
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+
+import torch
+import torch.nn as nn
+import torch.overrides
+from torch.nn.modules.module import _addindent
+from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
+
+from ._compatibility import compatibility
+from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
+
+__all__ = [
+    "reduce_graph_module",
+    "reduce_package_graph_module",
+    "reduce_deploy_graph_module",
+    "GraphModule",
+]
+
+_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
+
+# Normal exec loses the source code, however we can work with
+# the linecache module to recover it.
+# Using _exec_with_source will add it to our local cache
+# and then tools like TorchScript will be able to get source info.
+class _EvalCacheLoader:
+    def __init__(self):
+        self.eval_cache = {}
+        self.next_id = 0
+
+    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
+        """Store the source in a private cache, and add a lazy entry in linecache
+        that allows the source to be retrieved by 'filename'.
+
+        Args:
+            src (str): The module source to cache
+            globals (dict): The module globals
+
+        Returns:
+            str: The cache key (and dummy filename) generated for src.
+        """
+
+        key = self._get_key()
+        if co_fields:
+            key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
+        self.eval_cache[key] = src
+
+        # Don't mutate globals so that this loader is only used
+        # to populate linecache, and doesn't interact with other modules
+        # that might check `__loader__`
+        globals_copy = globals.copy()
+        globals_copy["__file__"] = key
+        globals_copy["__name__"] = key
+        globals_copy["__loader__"] = self
+        linecache.lazycache(key, globals_copy)
+
+        return key
+
+    # Part of the loader protocol (PEP 302)
+    # linecache will use this method when trying to find source code
+    def get_source(self, module_name) -> Optional[str]:
+        if module_name in self.eval_cache:
+            return self.eval_cache[module_name]
+        return None
+
+    def _get_key(self):
+        key = f".{self.next_id}"
+        self.next_id += 1
+        return key
+
+
+_loader = _EvalCacheLoader()
+
+
+def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
+    key = _loader.cache(src, globals, co_fields)
+    exec(compile(src, key, "exec"), globals)
+
+
+def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
+    return _method_from_src(
+        method_name="forward", src=src, globals=globals, co_fields=co_fields
+    )
+
+
+def _method_from_src(
+    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
+) -> Callable:
+    # avoid mutating the passed in dict
+    globals_copy = globals.copy()
+    _exec_with_source(src, globals_copy, co_fields)
+    fn = globals_copy[method_name]
+    del globals_copy[method_name]
+    return fn
+
+
+def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
+    if name in _custom_builtins:
+        return _custom_builtins[name].import_str
+    if _is_from_torch(name):
+        return "import torch"
+    module_name, attr_name = importer.get_name(obj)
+    return f"from {module_name} import {attr_name} as {name}"
+
+
+def _format_import_block(globals: Dict[str, Any], importer: Importer):
+    import_strs: Set[str] = set()
+    for name, obj in globals.items():
+        import_strs.add(_format_import_statement(name, obj, importer))
+    # Sort the imports so we have a stable import block that allows us to
+    # hash the graph module and get a consistent key for use in a cache.
+    return "\n".join(sorted(import_strs))
+
+
+@compatibility(is_backward_compatible=True)
+def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+    # BC: attribute name was changed from `code` to `_code` to facilitate
+    # making `code` into a property and adding a docstring to it
+    fn_src = body.get("_code") or body["code"]
+    forward = _forward_from_src(import_block + fn_src, {})
+    return _deserialize_graph_module(forward, body)
+
+
+@compatibility(is_backward_compatible=True)
+def reduce_package_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+) -> torch.nn.Module:
+    forward = importer.import_module(generated_module_name).forward
+    return _deserialize_graph_module(forward, body)
+
+
+@compatibility(is_backward_compatible=True)
+def reduce_deploy_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], import_block: str
+) -> torch.nn.Module:
+    ns = {}
+    ns["__builtins__"] = importer.patched_builtins
+    fn_src = body.get("_code")
+    assert fn_src is not None
+    forward = _forward_from_src(import_block + fn_src, ns)
+    return _deserialize_graph_module(forward, body)
+
+
+# We create a dummy class here because symbolic_trace pulls the forward()
+# function off of the class, rather than the instance. This class is used
+# in _deserialize_graph_module() below.
+class _CodeOnlyModule(torch.nn.Module):
+    def __init__(self, body):
+        super().__init__()
+        self.__dict__ = body
+
+
+def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module:
+    """
+    Deserialize a GraphModule given the dictionary of the original module,
+    using the code to reconstruct the graph. We delete the actual graph before
+    saving the dictionary so that changes to the in-memory graph format do not
+    get serialized.
+    """
+
+    # Try to retrieve the forward source in a backward-compatible way
+    _CodeOnlyModule.forward = forward
+
+    tracer_cls = body.get("_tracer_cls")
+    if tracer_cls is None:
+        from ._symbolic_trace import Tracer
+
+        tracer_cls = Tracer
+
+    graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule")
+
+    # This is a workaround for a mypy linter issue related to
+    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
+    cls_tracer: Any = tracer_cls
+
+    class KeepModules(cls_tracer):
+        # we shouldn't trace into any of the submodules,
+        # because they were not traced in the original GraphModule
+        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
+            return True
+
+    com = _CodeOnlyModule(body)
+
+    tracer_extras = body.get("_tracer_extras", {})
+    graph = KeepModules().trace(com, **tracer_extras)
+
+    # Manually set Tracer class on the reconstructed Graph, to avoid
+    # referencing the private local subclass KeepModules.
+    graph._tracer_cls = tracer_cls
+    from ._lazy_graph_module import _make_graph_module
+    gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls)
+
+    # The GraphModule constructor only retains attributes referenced by the graph.
+    # In this case, our goal is to return a GraphModule as close to identical
+    # as possible to the one put into the package. If any additional attributes
+    # were present in body, we should keep them.
+    for k, v in body.items():
+        if not hasattr(gm, k):
+            setattr(gm, k, v)
+    return gm
+
+
+# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
+# This installs empty Modules where none exist yet if they are subpaths of target
+def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
+    *prefix, field = target.split(".")
+    for item in prefix:
+        f = getattr(from_module, item)
+        t = getattr(to_module, item, None)
+        if f is t:
+            # we have already installed one of its parents
+            # (e.g. target = root.linear.weight, but we have already installed root.linear)
+            # once we install a parent, we no longer need to copy the children
+            # since all the needed properties will already be present
+            return
+
+        if t is None:
+            t = torch.nn.Module()
+            setattr(to_module, item, t)
+        from_module, to_module = f, t
+
+    orig = getattr(from_module, field)
+    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
+    # So, we register it as a named buffer in the target module.
+    if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
+        to_module.register_buffer(field, orig)
+    else:
+        setattr(to_module, field, orig)
+
+
+# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'
+# This installs empty Modules where none exist yet if they are subpaths of target
+def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
+    *prefix, field = target.split(".")
+    for item in prefix:
+        t = getattr(to_module, item, None)
+
+        if t is None:
+            t = torch.nn.Module()
+            setattr(to_module, item, t)
+        to_module = t
+
+    # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
+    # So, we register it as a named buffer in the target module.
+    if isinstance(from_obj, torch.Tensor) and not isinstance(
+        from_obj, torch.nn.Parameter
+    ):
+        to_module.register_buffer(field, from_obj)
+    else:
+        setattr(to_module, field, from_obj)
+
+
+class _WrappedCall:
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid
+    # symbolically-traced code was run with an invalid input, the
+    # user would see the source of the error as coming from
+    # `File "`, where N is some number. We use
+    # this function to generate a more informative error message. We
+    # return the traceback itself, a message explaining that the
+    # error occurred in a traced Module's generated forward
+    # function, and five lines of context surrounding the faulty
+    # line
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # auxiliary variables (for readability)
+        err_lineno = frame_summary.lineno
+        assert err_lineno is not None
+        line = frame_summary.line
+        assert line is not None
+        err_line_len = len(line)
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # constituent substrings of the error message
+        tb_repr = traceback.format_exc()
+        custom_msg = (
+            "Call using an FX-traced Module, "
+            f"line {err_lineno} of the traced Module's "
+            "generated forward function:"
+        )
+        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
+
+        # joined message
+        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = (
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
+            )  # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(
+                    _WrappedCall._generate_error_message(topmost_framesummary),
+                    file=sys.stderr,
+                )
+                raise e.with_traceback(None)  # noqa: TRY200
+            else:
+                raise e
+
+@compatibility(is_backward_compatible=True)
+class GraphModule(torch.nn.Module):
+    """
+    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a
+    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
+    from that ``graph``.
+
+    .. warning::
+
+        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
+        regenerated. However, if you edit the contents of the ``graph`` without reassigning
+        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
+        code.
+    """
+
+    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
+        # each instance of a graph module needs its own forward method
+        # so create a new singleton class for each instance.
+        # it is a subclass of the user-defined class, the only difference
+        # is an extra layer to install the forward method
+
+        # address issue described at https://github.com/pytorch/pytorch/issues/63883
+        # in other words, traverse class hierarchy to fix the redundant class definition problem
+        for t in cls.__mro__:
+            c = t.__qualname__.split(".")[-1]
+            if c != "GraphModuleImpl":
+                cls = t
+                break
+
+        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
+            pass
+
+        return super().__new__(GraphModuleImpl)
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, Dict[str, Any]],
+        graph: Graph,
+        class_name: str = "GraphModule",
+    ):
+        """
+        Construct a GraphModule.
+
+        Args:
+
+            root (Union[torch.nn.Module, Dict[str, Any]]):
+                ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
+                In the case that ``root`` is a Module, any references to Module-based objects (via qualified
+                name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
+                within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
+                In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
+                looked up directly in the dict's keys. The object mapped to by the Dict will be copied
+                over into the appropriate place within the GraphModule's module hierarchy.
+
+            graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
+
+            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
+                error messages will report as originating from ``GraphModule``. It may be helpful to set this
+                to ``root``'s original name or a name that makes sense within the context of your transform.
+        """
+        super().__init__()
+        self.__class__.__name__ = class_name
+        if isinstance(root, torch.nn.Module):
+            if hasattr(root, "training"):
+                self.training = root.training
+
+            # When we pickle/unpickle graph module, we don't want to drop any module or attributes.
+            if isinstance(root, _CodeOnlyModule):
+                for k, _ in root.named_children():
+                    _copy_attr(root, self, k)
+
+                for k, _ in root.named_buffers():
+                    _copy_attr(root, self, k)
+
+                for k, _ in root.named_parameters():
+                    _copy_attr(root, self, k)
+
+            for node in graph.nodes:
+                if node.op in ["get_attr", "call_module"]:
+                    assert isinstance(node.target, str)
+                    _copy_attr(root, self, node.target)
+        elif isinstance(root, dict):
+            targets_to_copy = []
+            for node in graph.nodes:
+                if node.op in ["get_attr", "call_module"]:
+                    assert isinstance(node.target, str)
+                    if node.target not in root:
+                        raise RuntimeError(
+                            "Node "
+                            + str(node)
+                            + " referenced target "
+                            + node.target
+                            + " but that target was not provided in ``root``!"
+                        )
+                    targets_to_copy.append(node.target)
+            # Sort targets in ascending order of the # of atoms.
+            # This will ensure that less deeply nested attributes are assigned
+            # before more deeply nested attributes. For example, foo.bar
+            # will be assigned before foo.bar.baz. Otherwise, we might assign
+            # the user-provided ``foo.bar`` and wipe out the previously-assigned
+            # ``foo.bar.baz``
+            targets_to_copy.sort(key=lambda t: t.count("."))
+            for target_to_copy in targets_to_copy:
+                _assign_attr(root[target_to_copy], self, target_to_copy)
+        else:
+            raise RuntimeError("Unsupported type " + str(root) + " passed for root!")
+
+        self.graph = graph
+
+        # Store the Tracer class responsible for creating a Graph separately as part of the
+        # GraphModule state, except when the Tracer is defined in a local namespace.
+        # Locally defined Tracers are not pickleable. This is needed because torch.package will
+        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
+        # to re-create the Graph during deserialization.
+        self._tracer_cls = None
+        if (
+            self.graph._tracer_cls
+            and "" not in self.graph._tracer_cls.__qualname__
+        ):
+            self._tracer_cls = self.graph._tracer_cls
+
+        self._tracer_extras = {}
+        if self.graph._tracer_extras:
+            self._tracer_extras = self.graph._tracer_extras
+
+        # Dictionary to store metadata
+        self.meta: Dict[str, Any] = {}
+        self._replace_hook = None
+
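+    # A minimal construction sketch (assumptions: ``MyModule`` is a hypothetical
+    # nn.Module and the graph comes from ``torch.fx.symbolic_trace``):
+    #
+    #     traced = torch.fx.symbolic_trace(MyModule())   # already a GraphModule
+    #     gm = GraphModule(MyModule(), traced.graph)     # rebuild from root + graph
+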
+    # TorchScript breaks trying to compile the graph setter because of the
+    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
+    #
+    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
+    __jit_unused_properties__ = ["graph"]
+
+    @property
+    def graph(self) -> Graph:
+        """
+        Return the ``Graph`` underlying this ``GraphModule``
+        """
+        return self._graph
+
+    @graph.setter
+    def graph(self, g: Graph) -> None:
+        """
+        Set the underlying ``Graph`` for this ``GraphModule``. This will internally
+        recompile the ``GraphModule`` so that the generated ``forward()`` function
+        corresponds to ``g``
+        """
+        assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}"
+        self._graph = g
+        g.owning_module = self
+        self.recompile()
+
+    @compatibility(is_backward_compatible=False)
+    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
+        """Dumps out module to ``folder`` with ``module_name`` so that it can be
+        imported with ``from <folder> import <module_name>``
+
+        Args:
+
+            folder (Union[str, os.PathLike]): The folder to write the code out to
+
+            module_name (str): Top-level name to use for the ``Module`` while
+                writing out the code
+        """
+        folder = Path(folder)
+        Path(folder).mkdir(exist_ok=True)
+        torch.save(self.state_dict(), folder / "state_dict.pt")
+        tab = " " * 4
+        custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()])
+        model_str = f"""
+import torch
+{custom_builtins}
+
+from torch.nn import *
+class {module_name}(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+"""
+
+        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
+            safe_reprs = [
+                nn.Linear,
+                nn.Conv1d,
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.BatchNorm1d,
+                nn.BatchNorm2d,
+                nn.BatchNorm3d,
+            ]
+            if type(module) in safe_reprs:
+                return f"{module.__repr__()}"
+            else:
+                return None
+
+        blobified_modules = []
+        for module_name, module in self.named_children():
+            module_str = _gen_model_repr(module_name, module)
+            if module_str is None:
+                module_file = folder / f"{module_name}.pt"
+                torch.save(module, module_file)
+                blobified_modules.append(module_name)
+                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
+                module_str = f"torch.load(r'{module_file}') # {module_repr}"
+            model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+        for buffer_name, buffer in self._buffers.items():
+            if buffer is None:
+                continue
+            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
+
+        for param_name, param in self._parameters.items():
+            if param is None:
+                continue
+            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
+
+        model_str += (
+            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
+        )
+        model_str += f"{_addindent(self.code, 4)}\n"
+
+        module_file = folder / "module.py"
+        module_file.write_text(model_str)
+
+        init_file = folder / "__init__.py"
+        init_file.write_text("from .module import *")
+
+        if len(blobified_modules) > 0:
+            warnings.warn(
+                "Was not able to save the following children modules as reprs -"
+                f"saved as pickled files instead: {blobified_modules}"
+            )
+
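+    # A minimal usage sketch (assumptions: ``gm`` is a GraphModule and the folder
+    # and module names are illustrative):
+    #
+    #     gm.to_folder('exported_model', module_name='ExportedModel')
+    #     # afterwards: from exported_model import ExportedModel
+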
+    @compatibility(is_backward_compatible=True)
+    def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
+        """
+        Adds the given submodule to ``self``.
+
+        This installs empty Modules where none exist yet if they are
+        subpaths of ``target``.
+
+        Args:
+            target: The fully-qualified string name of the new submodule
+                (See example in ``nn.Module.get_submodule`` for how to
+                specify a fully-qualified string.)
+            m: The submodule itself; the actual object we want to
+                install in the current Module
+
+        Return:
+            bool: Whether or not the submodule could be inserted. For
+                this method to return True, each object in the chain
+                denoted by ``target`` must either a) not exist yet,
+                or b) reference an ``nn.Module`` (not a parameter or
+                other attribute)
+        """
+        *prefix, field = target.split(".")
+        mod: torch.nn.Module = self
+
+        for item in prefix:
+
+            submod = getattr(mod, item, None)
+
+            if submod is None:
+                submod = torch.nn.Module()
+                setattr(mod, item, submod)
+
+            if not isinstance(submod, torch.nn.Module):
+                return False
+
+            mod = submod
+
+        mod.add_module(field, m)
+        return True
+
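+    # A minimal usage sketch (assumptions: ``gm`` is a GraphModule and the target
+    # path ``backbone.head`` is illustrative):
+    #
+    #     ok = gm.add_submodule('backbone.head', torch.nn.Linear(16, 4))
+    #     # ``ok`` is False if some intermediate attribute is not an nn.Module
+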
+    @compatibility(is_backward_compatible=True)
+    def delete_submodule(self, target: str) -> bool:
+        """
+        Deletes the given submodule from ``self``.
+
+        The module will not be deleted if ``target`` is not a valid
+        target.
+
+        Args:
+            target: The fully-qualified string name of the new submodule
+                (See example in ``nn.Module.get_submodule`` for how to
+                specify a fully-qualified string.)
+
+        Returns:
+            bool: Whether or not the target string referenced a
+                submodule we want to delete. A return value of ``False``
+                means that the ``target`` was not a valid reference to
+                a submodule.
+        """
+        atoms = target.split(".")
+        path, target_submod = atoms[:-1], atoms[-1]
+        mod: torch.nn.Module = self
+
+        # Get the parent module
+        for item in path:
+
+            if not hasattr(mod, item):
+                return False
+
+            mod = getattr(mod, item)
+
+            if not isinstance(mod, torch.nn.Module):
+                return False
+
+        if not hasattr(mod, target_submod):
+            return False
+
+        if not isinstance(getattr(mod, target_submod), torch.nn.Module):
+            return False
+
+        delattr(mod, target_submod)
+        return True
+
+    @compatibility(is_backward_compatible=True)
+    def delete_all_unused_submodules(self) -> None:
+        """
+        Deletes all unused submodules from ``self``.
+
+        A Module is considered "used" if any one of the following is
+        true:
+        1. It has children that are used
+        2. Its forward is called directly via a ``call_module`` node
+        3. It has a non-Module attribute that is used from a
+        ``get_attr`` node
+
+        This method can be called to clean up an ``nn.Module`` without
+        manually calling ``delete_submodule`` on each unused submodule.
+        """
+        used: List[str] = []
+
+        for node in self.graph.nodes:
+
+            if node.op == "call_module" or node.op == "get_attr":
+
+                # A list of strings representing the different parts
+                # of the path. For example, `foo.bar.baz` gives us
+                # ["foo", "bar", "baz"]
+                fullpath = node.target.split(".")
+
+                # If we're looking at multiple parts of a path, join
+                # them with a dot. Otherwise, return that single
+                # element without doing anything to it.
+                def join_fn(x: str, y: str) -> str:
+                    return ".".join([x, y] if y else [x])
+
+                # Progressively collect all the names of intermediate
+                # modules. For example, if we have the target
+                # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
+                # `foo.bar.baz` to the list.
+                used.extend(itertools.accumulate(fullpath, join_fn))
+
+                # For a `call_module` node, also register all recursive submodules
+                # as used
+                if node.op == "call_module":
+                    try:
+                        submod = self.get_submodule(node.target)
+
+                        for submod_name, _ in submod.named_modules():
+                            if submod_name != "":
+                                used.append(".".join([node.target, submod_name]))
+                    except AttributeError:
+                        # Node referenced nonexistent submodule, don't need to
+                        # worry about GCing anything
+                        pass
+
+        to_delete = [name for name, _ in self.named_modules() if name not in used]
+
+        for name in to_delete:
+            self.delete_submodule(name)
+
+    @property
+    def code(self) -> str:
+        """
+        Return the Python code generated from the ``Graph`` underlying this
+        ``GraphModule``.
+        """
+        if not hasattr(self, "_code"):
+            raise RuntimeError(
+                "Code has not been generated! Please report a bug to PyTorch"
+            )
+        return self._code
+
+    @compatibility(is_backward_compatible=True)
+    def recompile(self) -> PythonCode:
+        """
+        Recompile this GraphModule from its ``graph`` attribute. This should be
+        called after editing the contained ``graph``, otherwise the generated
+        code of this ``GraphModule`` will be out of date.
+        """
+        if isinstance(self._graph._codegen, _PyTreeCodeGen):
+            self._in_spec = self._graph._codegen.pytree_info.in_spec
+            self._out_spec = self._graph._codegen.pytree_info.out_spec
+        python_code = self._graph.python_code(root_module="self")
+        self._code = python_code.src
+        self._lineno_map = python_code._lineno_map
+
+        cls = type(self)
+        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
+        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
+
+        # Determine whether this class explicitly defines a __call__ implementation
+        # to wrap. If it does, save it in order to have wrapped_call invoke it.
+        # If it does not, wrapped_call can use a dynamic call to super() instead.
+        # In most cases, super().__call__ should be torch.nn.Module.__call__.
+        # We do not want to hold a reference to Module.__call__ here; doing so will
+        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
+        cls_call = cls.__call__ if "__call__" in vars(cls) else None
+
+        if "_wrapped_call" not in vars(cls):
+            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]
+
+        def call_wrapped(self, *args, **kwargs):
+            return self._wrapped_call(self, *args, **kwargs)
+
+        cls.__call__ = call_wrapped  # type: ignore[method-assign]
+
+        return python_code
+
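+    # A minimal usage sketch (assumption: ``gm`` is a GraphModule whose graph was
+    # edited in place, e.g. by erasing or rewiring nodes):
+    #
+    #     gm.graph.lint()
+    #     gm.recompile()   # regenerate gm.code and gm.forward from gm.graph
+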
+    # Passing Tracer as argument allows subclasses extending fx.GraphModule
+    # define their own Tracer (extending fx.Tracer).
+    def __reduce_deploy__(self, importer: Importer):
+        dict_without_graph = self.__dict__.copy()
+        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
+        del dict_without_graph["_graph"]
+
+        python_code = self.recompile()
+        import_block = _format_import_block(python_code.globals, importer)
+        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
+
+    def __reduce_package__(self, exporter: PackageExporter):
+        dict_without_graph = self.__dict__.copy()
+        dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
+        del dict_without_graph["_graph"]
+
+        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
+        python_code = self.recompile()
+        import_block = _format_import_block(python_code.globals, exporter.importer)
+        module_code = import_block + self.code
+        exporter.save_source_string(generated_module_name, module_code)
+        return (
+            reduce_package_graph_module,
+            (dict_without_graph, generated_module_name),
+        )
+
+    def __reduce__(self):
+        """
+        Serialization of GraphModule. We serialize only the generated code, not
+        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
+        backward-compatibility guarantees, whereas Python source code does.
+        On the deserialization side, we symbolically trace through the generated
+        code to regenerate the underlying ``Graph``
+        """
+        dict_without_graph = self.__dict__.copy()
+
+        python_code = self.recompile()
+        import_block = _format_import_block(python_code.globals, sys_importer)
+        del dict_without_graph["_graph"]
+        return (reduce_graph_module, (dict_without_graph, import_block))
+
+    def _deepcopy_init(self):
+        return GraphModule.__init__
+
+    # because __reduce__ is defined for serialization,
+    # we need to define deepcopy otherwise it will call __reduce__
+    # and cause symbolic tracing to occur every time we try to copy the object
+    def __deepcopy__(self, memo):
+        res = type(self).__new__(type(self))
+        memo[id(self)] = res
+        fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
+        self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"])
+        # hooks are lost during `GraphModule.__init__`, so we need to copy them
+        # over explicitly. Right now we only copy state_dict-related hooks to
+        # reduce BC-related issues; forward/backward-related hooks can be
+        # copied in the future as well if needed.
+        extra_preserved_attrs = [
+            "_state_dict_hooks",
+            "_load_state_dict_pre_hooks",
+            "_load_state_dict_post_hooks",
+            "_replace_hook",
+        ]
+        for attr in extra_preserved_attrs:
+            if attr in self.__dict__:
+                setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
+        res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
+        if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
+            for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
+                setattr(res, attr_name, attr)
+        return res
+
+    def __copy__(self):
+        from ._lazy_graph_module import _make_graph_module
+        res = _make_graph_module(self, self.graph)
+        res.meta = getattr(self, "meta", {})
+        return res
+
+    @compatibility(is_backward_compatible=False)
+    def print_readable(self, print_output=True):
+        """
+        Return the Python code generated for the current GraphModule and its child GraphModules.
+        """
+        verbose_python_code = self._graph.python_code(root_module="self", verbose=True)
+        module_code = verbose_python_code.src
+        module_code = module_code.lstrip("\n")
+        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
+        module_code = _addindent(module_code, 4)
+
+        submodule_code_list = [""]
+        for submodule in self.children():
+            if isinstance(submodule, GraphModule):
+                submodule_code_list.append(submodule.print_readable(print_output=False))
+        submodule_code = "\n".join(submodule_code_list)
+        submodule_code = _addindent(submodule_code, 4)
+
+        output = module_code + submodule_code
+        if print_output:
+            print(module_code + submodule_code)
+        return output
+
+    def __str__(self) -> str:
+        orig_str = super().__str__()
+        print_readable_reminder = (
+            "# To see more debug info, please use `graph_module.print_readable()`"
+        )
+        return "\n".join([orig_str, self._code, print_readable_reminder])
+
+    def _replicate_for_data_parallel(self):
+        new_gm = self.__copy__()
+        new_gm._is_replica = True
+        return new_gm
+
+    @contextlib.contextmanager
+    def _set_replace_hook(self, f):
+        """
+        Takes a callable that will be invoked every time a node is replaced
+        with a new node or a node's name changes. The callable receives three
+        arguments: the old node being changed, the name of the new node, and
+        the user node that consumes the old node being replaced.
+        """
+        assert callable(f), "Replace hook must be a callable."
+        prev, self._replace_hook = self._replace_hook, f
+        try:
+            yield
+        finally:
+            self._replace_hook = prev
+
+
+# workarounds for issues in __torch_function__
+
+# WAR for __torch_function__ not handling tensor lists,
+# fix is in https://github.com/pytorch/pytorch/pull/34725
+# orig_cat = torch.cat
+# def patched_cat(*args, **kwargs):
+#     tensors = args[0]
+#     for t in tensors:
+#         if isinstance(t, Proxy):
+#             return t.__torch_function__(patched_cat, (), args, kwargs)
+#     return orig_cat(*args, **kwargs)
+# patched_cat.__module__ = 'torch'
+# patched_cat.__name__ = 'cat'
+# torch.cat = patched_cat
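+
+
+# Illustrative sketch (not part of the upstream torch.fx API): it exercises the
+# serialization and copying paths defined above. Pickling goes through
+# ``__reduce__`` (only the generated code is serialized, and the ``Graph`` is
+# regenerated on load), while ``copy.deepcopy`` goes through ``__deepcopy__``
+# and avoids re-tracing the user code. The helper name
+# ``_example_graph_module_roundtrip`` is hypothetical.
+def _example_graph_module_roundtrip():
+    import copy
+    import pickle
+
+    import torch
+    import torch.fx
+
+    class M(torch.nn.Module):
+        def forward(self, x):
+            return torch.relu(x) + 1
+
+    gm = torch.fx.symbolic_trace(M())
+    # Round-trip through pickle: only the generated Python code is stored.
+    gm_pickled = pickle.loads(pickle.dumps(gm))
+    # Deep copy: re-initializes from a code-only module instead of re-tracing.
+    gm_copied = copy.deepcopy(gm)
+    x = torch.randn(2, 3)
+    assert torch.equal(gm(x), gm_pickled(x))
+    assert torch.equal(gm(x), gm_copied(x))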
diff --git a/MLPY/Lib/site-packages/torch/fx/immutable_collections.py b/MLPY/Lib/site-packages/torch/fx/immutable_collections.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb87bf1bc80901568e75efe13b7d119709b49f8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/immutable_collections.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, Iterable, List, Tuple
+
+from torch.utils._pytree import (
+    _dict_flatten,
+    _dict_flatten_with_keys,
+    _dict_unflatten,
+    _list_flatten,
+    _list_flatten_with_keys,
+    _list_unflatten,
+    Context,
+    register_pytree_node,
+)
+
+from ._compatibility import compatibility
+
+
+__all__ = ["immutable_list", "immutable_dict"]
+
+_help_mutation = """\
+If you are attempting to modify the kwargs or args of a torch.fx.Node object,
+instead create a new copy of it and assign the copy to the node:
+    new_args = ... # copy and mutate args
+    node.args = new_args
+"""
+
+
+def _no_mutation(self, *args, **kwargs):
+    raise NotImplementedError(
+        f"'{type(self).__name__}' object does not support mutation. {_help_mutation}",
+    )
+
+
+def _create_immutable_container(base, mutable_functions):
+    container = type("immutable_" + base.__name__, (base,), {})
+    for attr in mutable_functions:
+        setattr(container, attr, _no_mutation)
+    return container
+
+
+immutable_list = _create_immutable_container(
+    list,
+    [
+        "__delitem__",
+        "__iadd__",
+        "__imul__",
+        "__setitem__",
+        "append",
+        "clear",
+        "extend",
+        "insert",
+        "pop",
+        "remove",
+    ],
+)
+immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
+immutable_list.__hash__ = lambda self: hash(tuple(self))
+
+compatibility(is_backward_compatible=True)(immutable_list)
+
+immutable_dict = _create_immutable_container(
+    dict,
+    [
+        "__delitem__",
+        "__setitem__",
+        "clear",
+        "pop",
+        "popitem",
+        "update",
+    ],
+)
+immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),))
+immutable_dict.__hash__ = lambda self: hash(tuple(self.items()))
+compatibility(is_backward_compatible=True)(immutable_dict)
+
+
+# Register immutable collections for PyTree operations
+def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
+    return _dict_flatten(d)
+
+
+def _immutable_dict_unflatten(
+    values: Iterable[Any],
+    context: Context,
+) -> Dict[Any, Any]:
+    return immutable_dict(_dict_unflatten(values, context))
+
+
+def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
+    return _list_flatten(d)
+
+
+def _immutable_list_unflatten(
+    values: Iterable[Any],
+    context: Context,
+) -> List[Any]:
+    return immutable_list(_list_unflatten(values, context))
+
+
+register_pytree_node(
+    immutable_dict,
+    _immutable_dict_flatten,
+    _immutable_dict_unflatten,
+    serialized_type_name="torch.fx.immutable_collections.immutable_dict",
+    flatten_with_keys_fn=_dict_flatten_with_keys,
+)
+register_pytree_node(
+    immutable_list,
+    _immutable_list_flatten,
+    _immutable_list_unflatten,
+    serialized_type_name="torch.fx.immutable_collections.immutable_list",
+    flatten_with_keys_fn=_list_flatten_with_keys,
+)
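+
+
+# Illustrative sketch (not part of the upstream torch.fx API): demonstrates the
+# behavior wired up above -- mutating the immutable containers raises, and the
+# pytree registration round-trips flatten/unflatten back to the immutable
+# types. The helper name ``_example_immutable_containers`` is hypothetical.
+def _example_immutable_containers():
+    import torch.utils._pytree as pytree
+
+    il = immutable_list([1, 2, 3])
+    try:
+        il.append(4)
+    except NotImplementedError:
+        pass  # mutation is rejected, with a hint to rebuild node.args instead
+
+    d = immutable_dict({"a": 1, "b": 2})
+    leaves, spec = pytree.tree_flatten(d)
+    rebuilt = pytree.tree_unflatten(leaves, spec)
+    assert type(rebuilt) is immutable_dict
+    assert rebuilt == {"a": 1, "b": 2}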
diff --git a/MLPY/Lib/site-packages/torch/fx/interpreter.py b/MLPY/Lib/site-packages/torch/fx/interpreter.py
new file mode 100644
index 0000000000000000000000000000000000000000..267c394acf406c65699a317e15d1c9914c77bdfd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/interpreter.py
@@ -0,0 +1,512 @@
+from .graph_module import GraphModule
+from ._lazy_graph_module import _make_graph_module
+from .graph import Graph
+from .node import Argument, Node, Target, map_arg, map_aggregate
+from .proxy import Proxy
+from ._symbolic_trace import Tracer
+from ._compatibility import compatibility
+from . import config
+import torch.fx.traceback as fx_traceback
+import torch
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+import inspect
+from contextlib import contextmanager
+from torch.hub import tqdm
+
+__all__ = ['Interpreter', 'Transformer']
+
+@compatibility(is_backward_compatible=True)
+class Interpreter:
+    """
+    An Interpreter executes an FX graph Node-by-Node. This pattern
+    can be useful for many things, including writing code
+    transformations as well as analysis passes.
+
+    Methods in the Interpreter class can be overridden to customize
+    the behavior of execution. The map of overrideable methods
+    in terms of call hierarchy::
+
+        run()
+            +-- run_node
+                +-- placeholder()
+                +-- get_attr()
+                +-- call_function()
+                +-- call_method()
+                +-- call_module()
+                +-- output()
+
+    Example:
+
+        Suppose we want to swap all instances of ``torch.neg`` with
+        ``torch.sigmoid`` and vice versa (including their ``Tensor``
+        method equivalents). We could subclass Interpreter like so::
+
+            class NegSigmSwapInterpreter(Interpreter):
+                def call_function(self, target : Target,
+                                  args : Tuple, kwargs : Dict) -> Any:
+                    if target == torch.sigmoid:
+                        return torch.neg(*args, **kwargs)
+                    return super().call_function(target, args, kwargs)
+
+                def call_method(self, target : Target,
+                                args : Tuple, kwargs : Dict) -> Any:
+                    if target == 'neg':
+                        call_self, *args_tail = args
+                        return call_self.sigmoid(*args_tail, **kwargs)
+                    return super().call_method(target, args, kwargs)
+
+            def fn(x):
+                return torch.sigmoid(x).neg()
+
+            gm = torch.fx.symbolic_trace(fn)
+            input = torch.randn(3, 4)
+            result = NegSigmSwapInterpreter(gm).run(input)
+            torch.testing.assert_close(result, torch.neg(input).sigmoid())
+
+    Args:
+        module (torch.nn.Module): The module to be executed
+        garbage_collect_values (bool): Whether to delete values after their last
+            use within the Module's execution. This ensures optimal memory usage during
+            execution. This can be disabled to, for example, examine all of the intermediate
+            values in the execution by looking at the ``Interpreter.env`` attribute.
+        graph (Optional[Graph]): If passed, the interpreter will execute this
+            graph instead of `module.graph`, using the provided `module`
+            argument to satisfy any requests for state.
+    """
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
+        self.module = module
+        self.submodules = dict(self.module.named_modules())
+        if graph is not None:
+            self.graph = graph
+        else:
+            self.graph = self.module.graph
+        self.env : Dict[Node, Any] = {}
+        self.name = "Interpreter"
+        self.garbage_collect_values = garbage_collect_values
+        self.extra_traceback = True
+
+        if self.garbage_collect_values:
+            # Run through reverse nodes and record the first instance of a use
+            # of a given node. This represents the *last* use of the node in the
+            # execution order of the program, which we will use to free unused
+            # values
+            node_to_last_use : Dict[Node, Node] = {}
+            self.user_to_last_uses : Dict[Node, List[Node]] = {}
+
+            def register_last_uses(n : Node, user : Node):
+                if n not in node_to_last_use:
+                    node_to_last_use[n] = user
+                    self.user_to_last_uses.setdefault(user, []).append(n)
+
+            for node in reversed(self.graph.nodes):
+                map_arg(node.args, lambda n: register_last_uses(n, node))
+                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+    @compatibility(is_backward_compatible=True)
+    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
+        """
+        Run `module` via interpretation and return the result.
+
+        Args:
+            *args: The arguments to the Module to run, in positional order
+            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
+                This is a dict mapping `Node` to any value. This can be used, for example, to
+                pre-populate results for certain `Nodes` so as to do only partial evaluation within
+                the interpreter.
+            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
+                process_outputs function first before using them.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        self.env = initial_env if initial_env is not None else {}
+
+        # Positional function args are consumed left-to-right by
+        # `placeholder` nodes. Use an iterator to keep track of
+        # position and extract those values.
+        if enable_io_processing:
+            args = self.graph.process_inputs(*args)
+        self.args_iter : Iterator[Any] = iter(args)
+        pbar = tqdm(total=len(self.graph.nodes),
+                    desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
+                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)
+
+        for node in self.graph.nodes:
+            pbar.update(1)
+            if node in self.env:
+                # Short circuit if we have this value. This could
+                # be used, for example, for partial evaluation
+                # where the caller has pre-populated `env` with
+                # values for a subset of the program.
+                continue
+
+            try:
+                self.env[node] = self.run_node(node)
+            except Exception as e:
+                if self.extra_traceback:
+                    msg = f"While executing {node.format_node()}"
+                    msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
+                    msg += f"\nOriginal traceback:\n{node.stack_trace}"
+                    e.args = (msg,) + e.args[1:]
+                    if isinstance(e, KeyError):
+                        raise RuntimeError(*e.args) from e
+                raise
+
+            if self.garbage_collect_values:
+                for to_delete in self.user_to_last_uses.get(node, []):
+                    del self.env[to_delete]
+
+            if node.op == 'output':
+                output_val = self.env[node]
+                return self.graph.process_outputs(output_val) if enable_io_processing else output_val
+
+    @compatibility(is_backward_compatible=True)
+    def boxed_run(self, args_list):
+        """
+        Run `module` via interpretation and return the result.  This uses the "boxed"
+        calling convention, where you pass a list of arguments, which will be cleared
+        by the interpreter.  This ensures that input tensors are promptly deallocated.
+        """
+        args_iter = iter(args_list)
+        env = {}
+        for n in self.graph.nodes:
+            if n.op == "placeholder":
+                env[n] = next(args_iter)
+        args_list.clear()
+        return self.run(initial_env=env)
+
+    @contextmanager
+    def _set_current_node(self, node):
+        with fx_traceback.set_current_meta(node):
+            yield
+
+    @compatibility(is_backward_compatible=True)
+    def run_node(self, n : Node) -> Any:
+        """
+        Run a specific node ``n`` and return the result.
+        Calls into placeholder, get_attr, call_function,
+        call_method, call_module, or output depending
+        on ``node.op``
+
+        Args:
+            n (Node): The Node to execute
+
+        Returns:
+            Any: The result of executing ``n``
+        """
+        with self._set_current_node(n):
+            args, kwargs = self.fetch_args_kwargs_from_env(n)
+            assert isinstance(args, tuple)
+            assert isinstance(kwargs, dict)
+            return getattr(self, n.op)(n.target, args, kwargs)
+
+    # Main Node running APIs
+    @compatibility(is_backward_compatible=True)
+    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``placeholder`` node. Note that this is stateful:
+        ``Interpreter`` maintains an internal iterator over
+        arguments passed to ``run`` and this method returns
+        next() on that iterator.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Returns:
+            Any: The argument value that was retrieved.
+        """
+        assert isinstance(target, str)
+        if target.startswith('*'):
+            # For a starred parameter e.g. `*args`, retrieve all
+            # remaining values from the args list.
+            return list(self.args_iter)
+        else:
+            try:
+                return next(self.args_iter)
+            except StopIteration as si:
+                if len(args) > 0:
+                    return args[0]
+                else:
+                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
+
+    @compatibility(is_backward_compatible=True)
+    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``get_attr`` node. Will retrieve an attribute
+        value from the ``Module`` hierarchy of ``self.module``.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return:
+            Any: The value of the attribute that was retrieved
+        """
+        assert isinstance(target, str)
+        return self.fetch_attr(target)
+
+    @compatibility(is_backward_compatible=True)
+    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_function`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return:
+            Any: The value returned by the function invocation
+        """
+        assert not isinstance(target, str)
+
+        # Execute the function and return the result
+        return target(*args, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_method`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return:
+            Any: The value returned by the method invocation
+        """
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args
+
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        return getattr(self_obj, target)(*args_tail, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute a ``call_module`` node and return the result.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return:
+            Any: The value returned by the module invocation
+        """
+        # Retrieve executed args and kwargs values from the environment
+
+        # Execute the method and return the result
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+
+        return submod(*args, **kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        """
+        Execute an ``output`` node. This really just retrieves
+        the value referenced by the ``output`` node and returns it.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+
+        Return:
+            Any: The return value referenced by the output node
+        """
+        return args[0]
+
+    # Helper methods
+    @compatibility(is_backward_compatible=True)
+    def fetch_attr(self, target : str):
+        """
+        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
+
+        Args:
+            target (str): The fully-qualified name of the attribute to fetch
+
+        Return:
+            Any: The value of the attribute.
+        """
+        target_atoms = target.split('.')
+        attr_itr = self.module
+        for i, atom in enumerate(target_atoms):
+            if not hasattr(attr_itr, atom):
+                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+            attr_itr = getattr(attr_itr, atom)
+        return attr_itr
+
+    @compatibility(is_backward_compatible=True)
+    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
+        """
+        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
+        from the current execution environment.
+
+        Args:
+            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.
+
+        Return:
+            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
+        """
+        args = self.map_nodes_to_values(n.args, n)
+        assert isinstance(args, tuple)
+        kwargs = self.map_nodes_to_values(n.kwargs, n)
+        assert isinstance(kwargs, dict)
+        return args, kwargs
+
+    @compatibility(is_backward_compatible=True)
+    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
+        """
+        Recursively descend through ``args`` and look up the concrete value
+        for each ``Node`` in the current execution environment.
+
+        Args:
+            args (Argument): Data structure within which to look up concrete values
+
+            n (Node): Node to which ``args`` belongs. This is only used for error reporting.
+        """
+        def load_arg(n_arg : Node) -> Any:
+            if n_arg not in self.env:
+                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
+                                   f'to diagnose such issues')
+            return self.env[n_arg]
+        return map_arg(args, load_arg)
+
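+
+
+# Illustrative sketch (not part of the upstream torch.fx API): shows the two
+# entry points documented above -- ``run`` with ordinary positional arguments,
+# and ``boxed_run``, whose boxed calling convention takes ownership of the
+# argument list and clears it so inputs can be freed promptly. The helper name
+# ``_example_interpreter_usage`` is hypothetical.
+def _example_interpreter_usage():
+    import torch
+    import torch.fx
+
+    def f(x):
+        return torch.relu(x) + 1
+
+    gm = torch.fx.symbolic_trace(f)
+    interp = Interpreter(gm)
+
+    x = torch.randn(4)
+    assert torch.equal(interp.run(x), f(x))
+
+    # boxed_run empties the list it is given once placeholders are bound.
+    boxed_args = [x]
+    out = interp.boxed_run(boxed_args)
+    assert boxed_args == []
+    assert torch.equal(out, f(x))
+
+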
+@compatibility(is_backward_compatible=True)
+class Transformer(Interpreter):
+    """
+    ``Transformer`` is a special type of interpreter that produces a
+    new ``Module``. It exposes a ``transform()`` method that returns
+    the transformed ``Module``. ``Transformer`` does not require
+    arguments to run, as ``Interpreter`` does. ``Transformer`` works
+    entirely symbolically.
+
+    Example:
+
+        Suppose we want to swap all instances of ``torch.neg`` with
+        ``torch.sigmoid`` and vice versa (including their ``Tensor``
+        method equivalents). We could subclass ``Transformer`` like so::
+
+            class NegSigmSwapXformer(Transformer):
+                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+                    if target == torch.sigmoid:
+                        return torch.neg(*args, **kwargs)
+                    return super().call_function(target, args, kwargs)
+
+                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+                    if target == 'neg':
+                        call_self, *args_tail = args
+                        return call_self.sigmoid(*args_tail, **kwargs)
+                    return super().call_method(target, args, kwargs)
+
+            def fn(x):
+                return torch.sigmoid(x).neg()
+
+            gm = torch.fx.symbolic_trace(fn)
+
+            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
+            input = torch.randn(3, 4)
+            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
+
+    Args:
+        module (GraphModule): The ``Module`` to be transformed.
+    """
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, module):
+        super().__init__(module)
+        self.new_graph = Graph()
+        self.new_graph.set_codegen(module.graph._codegen)
+
+        class TransformerTracer(Tracer):
+            def __init__(self, graph: Graph):
+                super().__init__()
+                self.graph = graph
+                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]
+
+            def is_leaf_module(self, _, __) -> bool:
+                return True
+
+        self.tracer = TransformerTracer(self.new_graph)
+        self.tracer.root = module
+
+    @compatibility(is_backward_compatible=True)
+    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
+        """
+        Execute a ``placeholder`` node. In ``Transformer``, this is
+        overridden to insert a new ``placeholder`` into the output
+        graph.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+        """
+        assert isinstance(target, str)
+        default_value = next(iter(args)) if args else inspect.Signature.empty
+        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
+
+    @compatibility(is_backward_compatible=True)
+    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
+        """
+        Execute a ``get_attr`` node. In ``Transformer``, this is
+        overridden to insert a new ``get_attr`` node into the output
+        graph.
+
+        Args:
+            target (Target): The call target for this node. See
+                the ``Node`` documentation for
+                details on semantics
+            args (Tuple): Tuple of positional args for this invocation
+            kwargs (Dict): Dict of keyword arguments for this invocation
+        """
+        assert isinstance(target, str)
+        return self.tracer.create_proxy("get_attr", target, args, kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        # Override so that the leaf module policy from `self.tracer` is respected.
+        assert isinstance(target, str)
+        submod = self.fetch_attr(target)
+        return self.tracer.call_module(submod, submod.forward, args, kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
+        # Override so that functions that were wrapped are still wrapped.
+        return self.tracer.create_proxy('call_function', target, args, kwargs)
+
+    @compatibility(is_backward_compatible=True)
+    def transform(self) -> GraphModule:
+        """
+        Transform ``self.module`` and return the transformed
+        ``GraphModule``.
+        """
+        with fx_traceback.preserve_node_meta():
+            result = super().run(enable_io_processing=False)
+        if result is not None:
+            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
+                return a.node if isinstance(a, Proxy) else a
+            self.new_graph.output(map_aggregate(result, strip_proxy))
+        return _make_graph_module(self.module, self.new_graph)
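+
+
+# Illustrative sketch (not part of the upstream torch.fx API): with no
+# overridden handlers, ``Transformer`` re-emits every node symbolically and
+# ``transform()`` returns a new, behaviorally equivalent ``GraphModule``. The
+# helper name ``_example_identity_transform`` is hypothetical.
+def _example_identity_transform():
+    import torch
+    import torch.fx
+
+    def f(x):
+        return torch.sigmoid(x).neg()
+
+    gm = torch.fx.symbolic_trace(f)
+    new_gm = Transformer(gm).transform()
+    x = torch.randn(3, 4)
+    torch.testing.assert_close(new_gm(x), gm(x))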
diff --git a/MLPY/Lib/site-packages/torch/fx/node.py b/MLPY/Lib/site-packages/torch/fx/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7130b87fd03f051d531f57f0d2b276af61bd234
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/node.py
@@ -0,0 +1,726 @@
+# mypy: ignore-errors
+
+# Nodes represent a definition of a value in our graph of operators.
+from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
+from ._compatibility import compatibility
+from .immutable_collections import immutable_dict, immutable_list
+import torch
+import builtins
+import types
+import inspect
+import warnings
+from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
+from .._ops import ops as _ops
+
+if TYPE_CHECKING:
+    from .graph import Graph
+
+__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
+
+BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
+                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload]
+base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]
+
+Target = Union[Callable[..., Any], str]
+
+Argument = Optional[Union[
+    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
+    List[Any],  # actually Argument
+    Dict[str, Any],  # actually Argument
+    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
+    range,
+    'Node',
+    BaseArgumentTypes
+]]
+
+_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
+    torch._C._set_grad_enabled,
+    torch.amp._enter_autocast,
+    torch.amp._exit_autocast,
+}
+
+# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
+# or add logic to correctly mark all inplace ops as side effectful.
+_side_effectful_functions: Set[Callable] = {
+    torch._assert,
+    torch._assert_async,
+    _ops.aten._assert_async.msg,
+    _ops.aten._assert_scalar.default,
+    _ops.aten.copy_.default,
+    _ops.aten.index_put_.default,
+    _ops.aten.sym_constrain_range.default,
+    _ops.aten.sym_constrain_range_for_size.default,
+    _ops.profiler._record_function_enter,
+    _ops.profiler._record_function_enter_new,
+    _ops.profiler._record_function_exit,
+    _ops.inductor.accumulate_grad_.default,
+    _ops.inductor.resize_storage_bytes_.default,
+} | _side_effectful_need_to_be_preserved_pre_dispatch
+
+
+@compatibility(is_backward_compatible=False)
+def has_side_effect(fn: Callable) -> None:
+    _side_effectful_functions.add(fn)
+    return fn
+
+
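+# Illustrative sketch (not part of the upstream torch.fx API): registering a
+# callable via ``has_side_effect`` marks ``call_function`` nodes targeting it
+# as impure, so dead-code elimination keeps those calls even when their result
+# is unused. ``_log_value`` and the helper name are hypothetical.
+def _example_has_side_effect():
+    def _log_value(x):
+        print("fx saw value:", x)
+
+    has_side_effect(_log_value)
+    assert _log_value in _side_effectful_functions
+
+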
+# this is fixed on master, WAR for 1.5
+def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
+    name = orig_method.__name__
+    module = orig_method.__module__
+    if module is not None:
+        return module
+    for guess in [torch, torch.nn.functional]:
+        if getattr(guess, name, None) is orig_method:
+            return guess.__name__
+    raise RuntimeError(f'cannot find module for {orig_method}')
+
+# Borrowed from CPython typing module
+# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
+def _type_repr(obj):
+    """Return the repr() of an object, special-casing types (internal helper).
+    If obj is a type, we return a shorter version than the default
+    type.__repr__, based on the module and qualified name, which is
+    typically enough to uniquely identify a type.  For everything
+    else, we fall back on repr(obj).
+    """
+    if isinstance(obj, type):
+        if obj.__module__ == 'builtins':
+            return obj.__qualname__
+        return f'{obj.__module__}.{obj.__qualname__}'
+    if obj is ...:
+        return '...'
+    if isinstance(obj, types.FunctionType):
+        return obj.__name__
+    return repr(obj)
+
+def _get_qualified_name(func: Callable[..., Any]) -> str:
+    # things like getattr just appear in builtins
+    if getattr(builtins, func.__name__, None) is func:
+        return func.__name__
+    # torch.Tensor.{fn}
+    if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
+       and func is getattr(torch.Tensor, func.__name__, None)):
+        return f"torch.Tensor.{func.__name__}"
+    name = func.__name__
+    if name == "<lambda>":
+        # For lambdas, try to get their defining name in the module
+        try:
+            name = inspect.getsource(func).split("=")[0].strip()
+        except Exception as e:
+            raise RuntimeError("Unable to represent lambda") from e
+    module = _find_module_of_method(func)
+    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
+    # Fixup segment_reduce mismatch
+    if module == "torch" and name == "segment_reduce":
+        name = "_" + name
+    return f'{module}.{name}'
+
+def _format_arg(arg, max_list_len=float('inf')) -> str:
+    if hasattr(arg, '_custom_fx_repr_fn'):
+        return arg._custom_fx_repr_fn()
+    elif isinstance(arg, list):
+        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
+        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
+        return f'[{items}{maybe_len}]'
+    elif isinstance(arg, tuple):
+        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
+        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
+        maybe_comma = ',' if len(arg) == 1 else ''
+        return f'({items}{maybe_comma}{maybe_len})'
+    elif isinstance(arg, dict):
+        items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
+        return f'{{{items_str}}}'
+
+    if isinstance(arg, Node):
+        return '%' + str(arg)
+    else:
+        return str(arg)
+
+@compatibility(is_backward_compatible=True)
+class Node:
+    """
+    ``Node`` is the data structure that represents individual operations within
+    a ``Graph``. For the most part, Nodes represent callsites to various entities,
+    such as operators, methods, and Modules (some exceptions include nodes that
+    specify function inputs and outputs). Each ``Node`` has a function specified
+    by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:
+
+    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
+      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
+      denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
+      the function parameters (e.g. ``x``) in the graph printout.
+    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
+      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
+      ``args`` and ``kwargs`` are don't-care
+    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
+      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
+      following the Python calling convention
+    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
+      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
+      ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*.
+    - ``call_method`` calls a method on a value. ``name`` is as previous. ``target`` is the string name of the method
+      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the method on,
+      *including the self argument*
+    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
+      in the Graph printout.
+    """
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
+                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
+                 return_type : Optional[Any] = None) -> None:
+        """
+        Instantiate an instance of ``Node``. Note: most often, you want to use the
+        Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
+        than instantiating a ``Node`` directly.
+
+        Args:
+            graph (Graph): The ``Graph`` to which this ``Node`` should belong.
+
+            name (str): The name to which the output of this ``Node`` should be assigned
+
+            op (str): The opcode for this ``Node``. Can be one of 'placeholder',
+                'call_method', 'call_module', 'call_function', 'get_attr',
+                'output'
+
+            target ('Target'): The target this op should call. See the broader
+                ``Node`` docstring for more details.
+
+            args (Tuple['Argument']): The args to be passed to ``target``
+
+            kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target``
+
+            return_type (Optional[Any]): The python type expression representing the
+                type of the output of this node. This field can be used for
+                annotation of values in the generated code or for other types
+                of analyses.
+        """
+        self.graph = graph
+        self.name = name  # unique name of value being created
+        assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
+        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
+        if op == 'call_function':
+            if not callable(target):
+                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
+                                 'but a Callable is expected')
+        else:
+            if not isinstance(target, str):
+                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
+                                 'but a str is expected')
+        self.target = target  # for method/module/function, the name of the method/module/function/attr
+        # being invoked, e.g add, layer1, or torch.add
+
+        # All `Node`-valued inputs. Key is the Node, value is don't-care.
+        # The public API for this is `all_input_nodes`, this private attribute
+        # should not be accessed directly.
+        self._input_nodes : Dict[Node, None] = {}
+        self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x))  # type: ignore[arg-type]
+
+        # All of the nodes that use the value produced by this Node
+        # Note one user may correspond to several uses, e.g. the node for ``x + x``
+        # would appear once here, but represents two uses.
+        #
+        # Is a dict to act as an "ordered set". Keys are significant, values are don't-care
+        self.users : Dict[Node, None] = {}
+        # Type expression representing the output value of this node.
+        # This should contain the same class of Type objects that would appear
+        # as type annotations for function inputs/outputs.
+        #
+        # For placeholder nodes, this value will be used to type-annotate the
+        # generated function parameters.
+        # For the return node, this value will be used to type-annotate the
+        # generated function return type. (Note this is a special case. ``return``
+        # does not produce a value, it's more of a notation. Thus, this value
+        # describes the type of args[0] in the ``return`` node.)
+        self.type : Optional[Any] = return_type
+        self._prev = self
+        self._next = self
+        self._erased = False
+
+        # If set, use this fn to print this node
+        self._repr_fn : Optional[Callable[[Node], str]] = None
+
+        # Dictionary to store metadata passes need to do their
+        # transformations. This metadata is preserved across node copies
+        self.meta : Dict[str, Any] = {}
+
+    @property
+    def next(self) -> 'Node':
+        """
+        Returns the next ``Node`` in the linked list of Nodes.
+
+        Returns:
+
+            The next ``Node`` in the linked list of Nodes.
+        """
+        return self._next
+
+    @property
+    def prev(self) -> 'Node':
+        """
+        Returns the previous ``Node`` in the linked list of Nodes.
+
+        Returns:
+
+            The previous ``Node`` in the linked list of Nodes.
+        """
+        return self._prev
+
+    @compatibility(is_backward_compatible=True)
+    def prepend(self, x: 'Node') -> None:
+        """
+        Insert x before this node in the list of nodes in the graph. Example::
+
+            Before: p -> self
+                    bx -> x -> ax
+            After:  p -> x -> self
+                    bx -> ax
+
+        Args:
+            x (Node): The node to put before this node. Must be a member of the same graph.
+        """
+        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
+        if self == x:
+            warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
+            return
+        x._remove_from_list()
+        p = self._prev
+        p._next, x._prev = x, p
+        x._next, self._prev = self, x
+
+    @compatibility(is_backward_compatible=True)
+    def append(self, x: 'Node') -> None:
+        """
+        Insert ``x`` after this node in the list of nodes in the graph.
+        Equivalent to ``self.next.prepend(x)``
+
+        Args:
+            x (Node): The node to put after this node. Must be a member of the same graph.
+        """
+        self._next.prepend(x)
+
+    def _remove_from_list(self):
+        p, n = self._prev, self._next
+        p._next, n._prev = n, p
+
+    @property
+    def args(self) -> Tuple[Argument, ...]:
+        """
+        The tuple of arguments to this ``Node``. The interpretation of arguments
+        depends on the node's opcode. See the :class:`Node` docstring for more
+        information.
+
+        Assignment to this property is allowed. All accounting of uses and users
+        is updated automatically on assignment.
+        """
+        return self._args
+
+    @args.setter
+    def args(self, a : Tuple[Argument, ...]):
+        """
+        Set the tuple of arguments to this Node. The interpretation of arguments
+        depends on the node's opcode. See the ``fx.Graph`` docstring for more
+        information.
+        """
+        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
+        # set `args` is via direct assignment, i.e. `node.args = new_args`
+        self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs)  # type: ignore[arg-type]
+
+    @property
+    def kwargs(self) -> Dict[str, Argument]:
+        """
+        The dict of keyword arguments to this ``Node``. The interpretation of arguments
+        depends on the node's opcode. See the :class:`Node` docstring for more
+        information.
+
+        Assignment to this property is allowed. All accounting of uses and users
+        is updated automatically on assignment.
+        """
+        return self._kwargs
+
+    @kwargs.setter
+    def kwargs(self, k : Dict[str, Argument]):
+        """
+        Set the dict of kwargs to this Node. The interpretation of arguments
+        depends on the node's opcode. See the ``fx.Graph`` docstring for more
+        information.
+        """
+        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
+        # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs`
+        self.__update_args_kwargs(self._args, map_arg(k, lambda x: x))  # type: ignore[arg-type]
+
+    @property
+    def all_input_nodes(self) -> List['Node']:
+        """
+        Return all Nodes that are inputs to this Node. This is equivalent to
+        iterating over ``args`` and ``kwargs`` and only collecting the values that
+        are Nodes.
+
+        Returns:
+
+            List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
+            ``Node``, in that order.
+        """
+        return list(self._input_nodes.keys())
+
+    @compatibility(is_backward_compatible=True)
+    def update_arg(self, idx : int, arg : Argument) -> None:
+        """
+        Update an existing positional argument to contain the new value
+        ``arg``. After calling, ``self.args[idx] == arg``.
+
+        Args:
+
+            idx (int): The index into ``self.args`` of the element to update
+            arg (Argument): The new argument value to write into ``args``
+        """
+        args = list(self.args)
+        args[idx] = arg
+        self.args = tuple(args)
+
+    @compatibility(is_backward_compatible=True)
+    def insert_arg(self, idx : int, arg : Argument) -> None:
+        """
+        Insert a positional argument into the argument list at the given index.
+
+        Args:
+
+            idx (int): The index in ``self.args`` before which the new argument will be inserted.
+            arg (Argument): The new argument value to insert into ``args``
+        """
+        assert 0 <= idx <= len(self.args), "insert_arg index must be between 0 and len(self.args)"
+        args_left = self.args[:idx]
+        args_right = self.args[idx:]
+
+        self._args = args_left + (arg,) + args_right
+
+        _new_input_nodes = {}
+        map_arg(arg, _new_input_nodes.setdefault)
+
+        for new_use in _new_input_nodes.keys():
+            if new_use not in self._input_nodes:
+                self._input_nodes.setdefault(new_use)
+                new_use.users.setdefault(self)
+
+    @compatibility(is_backward_compatible=True)
+    def update_kwarg(self, key : str, arg : Argument) -> None:
+        """
+        Update an existing keyword argument to contain the new value
+        ``arg``. After calling, ``self.kwargs[key] == arg``.
+
+        Args:
+
+            key (str): The key in ``self.kwargs`` of the element to update
+            arg (Argument): The new argument value to write into ``kwargs``
+        """
+        kwargs = dict(self.kwargs)
+        kwargs[key] = arg
+        self.kwargs = kwargs
+
+    @property
+    def stack_trace(self) -> Optional[str]:
+        """
+        Return the Python stack trace that was recorded during tracing, if any.
+        When traced with fx.Tracer, this property is usually populated by
+        `Tracer.create_proxy`. To record stack traces during tracing for debug purposes,
+        set `record_stack_traces = True` on the `Tracer` instance.
+        When traced with dynamo, this property will be populated by default by
+        `OutputGraph.create_proxy`.
+
+        stack_trace would have the innermost frame at the end of the string.
+        """
+        return self.meta.get("stack_trace", None)
+
+    @stack_trace.setter
+    def stack_trace(self, trace : Optional[str]):
+        self.meta["stack_trace"] = trace
+
+    def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']):
+        """
+        This API is internal. Do *not* call it directly.
+        """
+        self._args = new_args
+        self._kwargs = new_kwargs
+
+        for old_use in self._input_nodes.keys():
+            old_use.users.pop(self)
+
+        self._input_nodes = {}
+        map_arg(self._args, self._input_nodes.setdefault)
+        map_arg(self._kwargs, self._input_nodes.setdefault)
+
+        for new_use in self._input_nodes.keys():
+            new_use.users.setdefault(self)
+
+    def __repr__(self) -> str:
+        if self._repr_fn:
+            return self._repr_fn(self)
+        return self.name
+
+    def _pretty_print_target(self, target):
+        """
+        Make target printouts more user-friendly.
+        1) builtins will be printed as `builtins.xyz`
+        2) operators will be printed as `operator.xyz`
+        3) other callables will be printed with qualified name, e.g. torch.add
+        """
+        if isinstance(target, str):
+            return target
+        if hasattr(target, '__module__'):
+            if not hasattr(target, '__name__'):
+                # Just to be defensive, if we don't have `__name__`, get the
+                # qualname. Not sure if this happens for any members of `operator`
+                # or `builtins`. This fallback path is not as good, since e.g.
+                # things in `operator` have `_operator` as their __module__.
+                return _get_qualified_name(target)
+            if target.__module__ == 'builtins':
+                return f'builtins.{target.__name__}'
+            elif target.__module__ == '_operator':
+                return f'operator.{target.__name__}'
+        return _get_qualified_name(target)
+
+    @compatibility(is_backward_compatible=True)
+    def format_node(self,
+                    placeholder_names: Optional[List[str]] = None,
+                    maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
+        """
+        Return a descriptive string representation of ``self``.
+
+        This method can be used with no arguments as a debugging
+        utility.
+
+        This function is also used internally in the ``__str__`` method
+        of ``Graph``. Together, the strings in ``placeholder_names``
+        and ``maybe_return_typename`` make up the signature of the
+        autogenerated ``forward`` function in this Graph's surrounding
+        GraphModule. ``placeholder_names`` and ``maybe_return_typename``
+        should not be used otherwise.
+
+        Args:
+            placeholder_names: A list that will store formatted strings
+                representing the placeholders in the generated
+                ``forward`` function. Internal use only.
+            maybe_return_typename: A single-element list that will store
+                a formatted string representing the output of the
+                generated ``forward`` function. Internal use only.
+
+        Returns:
+            str: If 1) we're using ``format_node`` as an internal helper
+                in the ``__str__`` method of ``Graph``, and 2) ``self``
+                is a placeholder Node, return ``None``. Otherwise,
+                return a descriptive string representation of the
+                current Node.
+        """
+        if self.op == 'placeholder':
+            assert isinstance(self.target, str)
+            arg_str = self.target
+            arg_str += f': {_type_repr(self.type)}' if self.type else ''
+            if placeholder_names:
+                placeholder_names.append(arg_str)
+                return None
+            maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
+            default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
+            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
+        elif self.op == 'get_attr':
+            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
+            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
+                   f'{self.op}[target={self._pretty_print_target(self.target)}]'
+        elif self.op == 'output':
+            if self.type and maybe_return_typename:
+                maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
+            return f'return {self.args[0]}'
+        else:
+            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
+            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
+                   f'{self.op}[target={self._pretty_print_target(self.target)}](' \
+                   f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
+
+    @compatibility(is_backward_compatible=True)
+    def replace_all_uses_with(self,
+                              replace_with : 'Node',
+                              delete_user_cb: Callable[['Node'], bool] = lambda user: True,
+                              *,
+                              propagate_meta=False
+                              ) -> List['Node']:
+        """
+        Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
+
+        Args:
+
+            replace_with (Node): The node to replace all uses of ``self`` with.
+            delete_user_cb (Callable): Callback that is called to determine
+              whether a given user of the self node should be removed.
+            propagate_meta (bool): Whether or not to copy all properties
+              on the .meta field of the original node onto the replacement node.
+              For safety, this is only valid to do if the replacement node
+              doesn't already have an existing .meta field.
+
+        Returns:
+
+            The list of Nodes on which this change was made.
+        """
+        if propagate_meta:
+            assert len(replace_with.meta) == 0, \
+                'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \
+                'but replace_with already has .meta keys'
+            for k, v in self.meta.items():
+                replace_with.meta[k] = v
+        to_process = list(self.users)
+        skipped = []
+        m = self.graph.owning_module
+        for use_node in to_process:
+            if not delete_user_cb(use_node):
+                skipped.append(use_node)
+                continue
+
+            def maybe_replace_node(n : Node) -> Node:
+                if n == self:
+                    return replace_with
+                else:
+                    return n
+
+            if getattr(m, "_replace_hook", None):
+                m._replace_hook(old=self, new=replace_with.name, user=use_node)
+
+            new_args = map_arg(use_node.args, maybe_replace_node)
+            new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
+            assert isinstance(new_args, tuple)
+            assert isinstance(new_kwargs, dict)
+            use_node.__update_args_kwargs(new_args, new_kwargs)
+
+        assert len(self.users) - len(skipped) == 0
+        return [n for n in to_process if n not in skipped]
+
+    @compatibility(is_backward_compatible=False)
+    def is_impure(self):
+        """
+        Returns whether this op is impure, i.e. if its op is a placeholder or
+        output, or if a call_function or call_module which is impure.
+
+        Returns:
+
+            bool: If the op is impure or not.
+        """
+        if self.op in {"placeholder", "output"}:
+            return True
+
+        # Check if an impure function.
+        if self.op == "call_function":
+            return self.target in _side_effectful_functions
+
+        # Check if an impure module.
+        if self.op == "call_module":
+            assert (
+                self.graph.owning_module is not None
+            ), "self.graph.owning_module not set for purity check"
+            target_mod = self.graph.owning_module.get_submodule(self.target)
+            assert (
+                target_mod is not None
+            ), f"Did not find expected submodule target {self.target}"
+            return getattr(target_mod, "_is_impure", False)
+
+        return False
+
+    @compatibility(is_backward_compatible=False)
+    def normalized_arguments(
+            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
+            kwarg_types : Optional[Dict[str, Any]] = None,
+            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
+        """
+        Returns normalized arguments to Python targets. This means that
+        `args/kwargs` will be matched up to the module/functional's
+        signature and return exclusively kwargs in positional order
+        if `normalize_to_only_use_kwargs` is true.
+        Also populates default values. Does not support positional-only
+        parameters or varargs parameters.
+
+        Supports module calls.
+
+        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
+
+        Args:
+            root (torch.nn.Module): Module upon which to resolve module targets.
+            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
+            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
+            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
+
+        Returns:
+
+            Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
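+
+        Example (an illustrative sketch, assuming ``node`` is a ``call_function``
+        Node for ``torch.add`` inside a traced module ``gm``; the argument types
+        are passed to disambiguate overloads)::
+
+            >>> # xdoctest: +SKIP
+            >>> pair = node.normalized_arguments(gm, arg_types=(torch.Tensor, torch.Tensor))
+            >>> pair  # an ArgsKwargsPair of (args, kwargs), or None on failure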
+        """
+        if self.op == 'call_function':
+            assert callable(self.target)
+            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs)  # type: ignore[arg-type]
+        elif self.op == 'call_module':
+            assert isinstance(self.target, str)
+            return normalize_module(root, self.target, self.args, self.kwargs, normalize_to_only_use_kwargs)  # type: ignore[arg-type]
+
+        return None
+
+    @compatibility(is_backward_compatible=True)
+    def replace_input_with(self, old_input: 'Node', new_input: 'Node'):
+        """
+        Loop through input nodes of ``self``, and replace all instances of
+        ``old_input`` with ``new_input``.
+
+        Args:
+
+            old_input (Node): The old input node to be replaced.
+            new_input (Node): The new input node to replace ``old_input``.
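+
+        Example (an illustrative sketch, assuming ``node`` currently consumes
+        ``old`` as one of its inputs and ``new`` is another Node in the same graph)::
+
+            >>> # xdoctest: +SKIP
+            >>> node.replace_input_with(old, new)
+            >>> assert old not in node.all_input_nodes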
+        """
+        def maybe_replace_node(n : Node) -> Node:
+            return new_input if n == old_input else n
+
+        m = self.graph.owning_module
+        if getattr(m, "_replace_hook", None):
+            m._replace_hook(old=old_input, new=new_input.name, user=self)
+
+        new_args = map_arg(self.args, maybe_replace_node)
+        new_kwargs = map_arg(self.kwargs, maybe_replace_node)
+        assert isinstance(new_args, tuple)
+        assert isinstance(new_kwargs, dict)
+        self.__update_args_kwargs(new_args, new_kwargs)
+
+    def _rename(self, candidate: str):
+        if candidate == self.name:
+            return
+        name = self.graph._graph_namespace.create_name(candidate, None)
+        self.name = name
+        self.graph._graph_namespace._rename_object(self, name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == 'name' and hasattr(self, "name"):
+            m = self.graph.owning_module
+            if getattr(m, "_replace_hook", None):
+                assert isinstance(value, str)
+                for user in self.users:
+                    m._replace_hook(old=self, new=value, user=user)
+        object.__setattr__(self, name, value)
+
+
+@compatibility(is_backward_compatible=True)
+def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
+    """
+    Apply fn to each Node appearing in a. a may be a list, tuple, slice, or dict with string keys.
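+
+    Example (an illustrative sketch, assuming ``n`` is a Node from a traced graph
+    and ``env`` maps each old Node to its replacement Node)::
+
+        >>> # xdoctest: +SKIP
+        >>> new_args = map_arg(n.args, lambda node: env[node])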
+    """
+    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
+    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
+
+@compatibility(is_backward_compatible=True)
+def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
+    """
+    Apply fn to each leaf value appearing in a. a may be a list, tuple, slice, or dict with string keys.
+    """
+    if isinstance(a, tuple):
+        t = tuple(map_aggregate(elem, fn) for elem in a)
+        # Support NamedTuple (if it has `_fields`) by repacking into original type.
+        return t if not hasattr(a, '_fields') else type(a)(*t)
+    elif isinstance(a, list):
+        return immutable_list(map_aggregate(elem, fn) for elem in a)
+    elif isinstance(a, dict):
+        return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
+    elif isinstance(a, slice):
+        return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
+    else:
+        return fn(a)
diff --git a/MLPY/Lib/site-packages/torch/fx/operator_schemas.py b/MLPY/Lib/site-packages/torch/fx/operator_schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..c781fc533e67445a2ecd3ebd5627ed28eb45d2c4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/operator_schemas.py
@@ -0,0 +1,441 @@
+import torch
+import inspect
+import numbers
+import types
+import typing
+import enum
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
+from torch._jit_internal import boolean_dispatched
+from ._compatibility import compatibility
+from torch._ops import OpOverloadPacket, OpOverload
+
+if TYPE_CHECKING:
+    from .node import Argument
+
+__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
+           "type_matches", "normalize_function", "normalize_module"]
+
+@compatibility(is_backward_compatible=False)
+class ArgsKwargsPair(NamedTuple):
+    """
+    Simple named tuple for wrapping args/kwargs pairs.
+    """
+    args: Tuple[Any, ...]
+    kwargs: Dict[str, Any]
+
+_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}
+
+def _nonzero_schemas():
+    signatures = []
+
+    def nonzero(self):
+        pass
+    signatures.append(inspect.signature(nonzero))
+
+    def nonzero(self, *, as_tuple : bool):  # type: ignore[no-redef]
+        pass
+    signatures.append(inspect.signature(nonzero))
+
+    return signatures
+
+_manual_overrides[torch.nonzero] = _nonzero_schemas()
+
+class _FakeGlobalNamespace:
+    def __getattr__(self, name):
+        if name == 'torch':
+            return torch
+        raise RuntimeError('Expected a torch namespace lookup')
+
+_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
+                      'number' : numbers.Number, 'Future' : torch.jit.Future,
+                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
+                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
+                      'Storage': torch.UntypedStorage,
+                      't': typing.TypeVar('t')}
+for k in dir(typing):
+    _type_eval_globals[k] = getattr(typing, k)
+
+def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
+    """
+    Convert a TorchScript type to a Python type (including subtypes) via
+    eval'ing the annotation_str. _type_eval_globals sets up expressions
+    like "List" and "Future" to map to actual types (typing.List and jit.Future)
+    """
+    return eval(ts_type.annotation_str, _type_eval_globals)
+
+def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
+    from inspect import Parameter
+    parameters : List[Parameter] = []
+    for arg in ts_schema.arguments:
+        arg_type = _torchscript_type_to_python_type(arg.type)
+        default = arg.default_value if arg.has_default_value() else Parameter.empty
+        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
+        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
+        # argument name. Downstream, if someone converts that positional argument to a keyword
+        # argument, the name mismatch will break things, so here we're going to normalize the
+        # name to "input"
+        name = arg.name if arg.name != 'self' else 'input'
+        kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
+        # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
+        if name == "from":
+            assert kind == Parameter.POSITIONAL_OR_KEYWORD
+            # The ParameterKind type is an internal implementation detail of the
+            # inspect package, which makes it hard to annotate here
+            kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
+            # This turns all previous arguments into positional-only arguments
+            for idx, p in enumerate(parameters):
+                assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
+                parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
+        parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
+    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
+    if len(return_types) == 0:
+        return_type = None
+    elif len(return_types) == 1:
+        return_type = return_types[0]
+    else:
+        return_type = tuple(return_types)
+
+    return inspect.Signature(parameters, return_annotation=return_type)
+
+_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}
+
+def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
+    # Cached as it's called in the hot path of FakeTensor dispatch
+    cache_key = ts_schema.name, ts_schema.overload_name
+    cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
+    if cache_val is not None:
+        return cache_val
+
+    res = _torchscript_schema_to_signature_impl(ts_schema)
+    _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
+    return res
+
+@compatibility(is_backward_compatible=False)
+def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
+    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
+
+    if signatures and schemas:
+        matched_schemas = []
+
+        # Iterate through all of the schemas and record every one that the given
+        # args/kwargs successfully bind to. The mutability check below is only
+        # performed when exactly one schema matches unambiguously.
+        for candidate_signature, schema in zip(signatures, schemas):
+            try:
+                candidate_signature.bind(*args, **kwargs)
+                matched_schemas.append((candidate_signature, schema))
+            except TypeError as e:
+                continue
+
+        def throw_if_mutable(schema):
+            if schema.is_mutable:
+                raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
+                                   f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
+                                   f'are not supported')
+
+        if len(matched_schemas) == 0:
+            # Did not match any schema. Cannot check for mutation
+            pass
+        elif len(matched_schemas) == 1:
+            # Matched exactly one schema, unambiguous
+            _, schema_to_check = matched_schemas[0]
+            throw_if_mutable(schema_to_check)
+        else:
+            # Ambiguous schema match. Since mutability checking is best effort,
+            # do nothing.
+            pass
+
+@compatibility(is_backward_compatible=False)
+def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
+    """
+    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
+    objects corresponding to the overloads of that op. May return `None` if a signature
+    could not be retrieved.
+
+    Args:
+        op (Callable): An operator on the `torch` namespace to look up a signature for
+
+    Returns:
+        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
+            operator, or None if the operator signatures could not be retrieved. If
+            return_schemas=True, returns a tuple containing the optional Python signatures
+            and the optional TorchScript Function signature
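+
+    Example (an illustrative sketch)::
+
+        >>> # xdoctest: +SKIP
+        >>> signatures = get_signature_for_torch_op(torch.add)
+        >>> len(signatures)  # one inspect.Signature per overload of torch.add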
+    """
+    if isinstance(op, OpOverload):
+        schemas = [op._schema]
+    elif isinstance(op, OpOverloadPacket):
+        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
+    else:
+        override = _manual_overrides.get(op)
+        if override:
+            return (override, None) if return_schemas else None
+
+        aten_fn = torch.jit._builtins._find_builtin(op)
+
+        if aten_fn is None:
+            return (None, None) if return_schemas else None
+        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
+
+    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
+    return (signatures, schemas) if return_schemas else signatures
+
+@compatibility(is_backward_compatible=False)
+def create_type_hint(x):
+    try:
+        if isinstance(x, (list, tuple)):
+            # todo(chilli): Figure out the right way for mypy to handle this
+            if isinstance(x, list):
+                def ret_type(x):
+                    return List[x]  # type: ignore[valid-type]
+            else:
+                def ret_type(x):
+                    return Tuple[x, ...]
+            if len(x) == 0:
+                return ret_type(Any)
+            base_type = x[0]
+            for t in x:
+                if issubclass(t, base_type):
+                    continue
+                elif issubclass(base_type, t):
+                    base_type = t
+                else:
+                    return ret_type(Any)
+            return ret_type(base_type)
+    except Exception as e:
+        # We tried to create a type hint for list but failed.
+        warnings.warn(f"We were not able to successfully create type hint from the type {x}")
+    return x
+
+@compatibility(is_backward_compatible=False)
+def type_matches(signature_type : Any, argument_type : Any):
+    sig_origin_type = getattr(signature_type, '__origin__', signature_type)
+
+    if signature_type is argument_type:
+        return True
+
+    # Union types in signature. Given type needs to match one of the
+    # contained types in the Union
+    if sig_origin_type is typing.Union and signature_type != argument_type:
+        sig_contained = signature_type.__args__
+        return any(type_matches(c, argument_type) for c in sig_contained)
+
+    if signature_type is List[int] and argument_type is int:
+        # int can be promoted to List[int]
+        return True
+
+    if getattr(signature_type, '__origin__', None) in {list, List}:
+        sig_el_type = signature_type.__args__[0]
+        if not inspect.isclass(sig_el_type):
+            warnings.warn(
+                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
+            return False
+        if getattr(argument_type, '__origin__', None) in {list, List}:
+            return issubclass(argument_type.__args__[0], sig_el_type)
+
+        def is_homogeneous_tuple(t):
+            if getattr(t, "__origin__", None) not in {tuple, Tuple}:
+                return False
+            contained = t.__args__
+            if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
+                return True
+            return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)
+
+        # Tuple[T] is accepted for List[T] parameters
+        return is_homogeneous_tuple(argument_type)
+
+    # Dtype is an int in schemas
+    if signature_type is int and argument_type is torch.dtype:
+        return True
+
+    if signature_type is numbers.Number and argument_type in {int, float}:
+        return True
+    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
+        return issubclass(argument_type, signature_type)
+
+    return False
+
+@compatibility(is_backward_compatible=False)
+def normalize_function(
+        target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
+        kwarg_types : Optional[Dict[str, Any]] = None,
+        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
+    """
+    Returns normalized arguments to PyTorch functions. This means that
+    `args/kwargs` will be matched up to the functional's
+    signature, and will be returned exclusively as kwargs in positional order if
+    `normalize_to_only_use_kwargs` is True.
+    Also populates default values. Does not support positional-only
+    parameters or varargs parameters (*args, **kwargs). Does not support modules.
+
+    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
+
+    Args:
+        target (Callable): Function that we are normalizing
+        args (Tuple[Any]): Tuple of args to the function
+        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
+        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
+        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
+        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
+
+    Returns:
+
+        Returns normalized_args_and_kwargs, or `None` if not successful.
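+
+    Example (an illustrative sketch, assuming ``t1`` and ``t2`` are tensors; the
+    argument types are passed to disambiguate the overloads of ``torch.add``)::
+
+        >>> # xdoctest: +SKIP
+        >>> pair = normalize_function(torch.add, (t1, t2),
+        ...                           arg_types=(torch.Tensor, torch.Tensor),
+        ...                           normalize_to_only_use_kwargs=True)
+        >>> pair.kwargs  # expected to contain 'input', 'other' and the default 'alpha'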
+    """
+    if kwargs is None:
+        kwargs = {}
+    new_args_and_kwargs = None
+    if not isinstance(target, types.BuiltinFunctionType) and not (
+        isinstance(target, (OpOverloadPacket, OpOverload))
+    ):
+        target_for_analysis = target
+        if target in boolean_dispatched:
+            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
+            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
+            # branches of the dispatch have exactly the same signature. If they do, use the `true`
+            # branch signature for analysis. Otherwise, leave this un-normalized
+            assert not isinstance(target, str)
+            dispatched = boolean_dispatched[target]
+            if_true, if_false = dispatched['if_true'], dispatched['if_false']
+            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
+                return None
+            target_for_analysis = if_true
+
+        assert callable(target_for_analysis)
+        sig = inspect.signature(inspect.unwrap(target_for_analysis))
+        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
+    else:
+        assert callable(target)
+        torch_op_schemas = get_signature_for_torch_op(target)
+        matched_schemas = []
+        if torch_op_schemas:
+            # Iterate through all of the schemas until we find one that matches.
+            # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
+            # values. If none matches, `new_args_and_kwargs` will be None
+            for candidate_signature in torch_op_schemas:
+                try:
+                    candidate_signature.bind(*args, **kwargs)
+                    matched_schemas.append(candidate_signature)
+                except TypeError as e:
+                    continue
+
+            if len(matched_schemas) == 0:
+                # Did not match any schema. Cannot normalize
+                pass
+            elif len(matched_schemas) == 1:
+                # Matched exactly one schema, unambiguous
+                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
+                                                                             normalize_to_only_use_kwargs)
+            else:
+                if arg_types is not None or kwarg_types is not None:
+                    arg_types = arg_types if arg_types else cast(Tuple[Any], ())
+                    kwarg_types = kwarg_types if kwarg_types else {}
+                    for candidate_signature in torch_op_schemas:
+                        sig_matches = True
+                        try:
+                            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
+                            for arg_name, arg_type in bound_types.arguments.items():
+                                param = candidate_signature.parameters[arg_name]
+                                sig_matches = sig_matches and type_matches(param.annotation, arg_type)
+                        except TypeError as e:
+                            sig_matches = False
+                        if sig_matches:
+                            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
+                                                                                         normalize_to_only_use_kwargs)
+                            break
+                else:
+                    # Matched more than one schema. In this situation, the caller must provide the types of
+                    # the arguments of the overload they expect.
+                    schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
+                    raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
+                                       f'the schema match was ambiguous! Please provide argument types to '
+                                       f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')
+
+    return new_args_and_kwargs
+
+@compatibility(is_backward_compatible=False)
+def normalize_module(
+        root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
+        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
+    """
+    Returns normalized arguments to PyTorch modules. This means that
+    `args/kwargs` will be matched up to the signature of the module's
+    `forward()` method, and will be returned exclusively as kwargs in
+    positional order if `normalize_to_only_use_kwargs` is True.
+    Also populates default values. Does not support positional-only
+    parameters or varargs parameters (*args, **kwargs).
+
+    Args:
+        root (nn.Module): root module upon which we query modules
+        target (str): Qualified name of the submodule whose call we are normalizing
+        args (Tuple[Any]): Tuple of args to the function
+        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
+        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
+
+    Returns:
+
+        Returns normalized_args_and_kwargs, or `None` if not successful.
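+
+    Example (an illustrative sketch, assuming ``traced`` is a GraphModule whose
+    original module has a ``torch.nn.Linear`` submodule named ``linear`` and ``x``
+    is a suitable input tensor)::
+
+        >>> # xdoctest: +SKIP
+        >>> pair = normalize_module(traced, 'linear', (x,),
+        ...                         normalize_to_only_use_kwargs=True)
+        >>> pair.kwargs  # expected to be {'input': x}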
+    """
+    try:
+        submod = root.get_submodule(target)
+    except AttributeError as e:
+        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
+                           f"have that target!") from e
+    if hasattr(submod.__class__, '__name__'):
+        classname = submod.__class__.__name__
+        if getattr(torch.nn, classname, None) == submod.__class__:
+            sig = inspect.signature(inspect.unwrap(submod.forward))
+            if kwargs is None:
+                kwargs = {}
+            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
+                                                                         normalize_to_only_use_kwargs)
+            return new_args_and_kwargs
+    return None
+
+def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
+                                           kwargs : Dict[str, Any],
+                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
+    """
+    Given a call target, args, and kwargs, return the arguments normalized into
+    an ArgsKwargsPair, or None if the type signature is not supported by
+    this normalization.
+
+    Args:
+
+        sig (inspect.Signature): Signature object for the target
+        args (Tuple): Arguments that appear at the callsite for `target`
+        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
+        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
+
+    Returns:
+
+        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
+            this target is not supported.
+    """
+
+    # Don't currently support positional-only
+    # or varargs (*args, **kwargs) signatures
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
+    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
+        # Add an exception for one signature, which is common for random/uniform, i.e.:
+        # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
+        # `from` is a Python keyword, so functions with that signature should have
+        # positional-only args, but at the same time they could be dispatched as kwargs
+        if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
+            return None
+
+    bound_args = sig.bind(*args, **kwargs)
+    bound_args.apply_defaults()
+
+    new_kwargs : Dict[str, Any] = {}
+    new_args : List[Any] = []
+    for i, param in enumerate(sig.parameters):
+        if not normalize_to_only_use_kwargs and i < len(args):
+            new_args.append(bound_args.arguments[param])
+        else:
+            new_kwargs[param] = bound_args.arguments[param]
+
+    return ArgsKwargsPair(tuple(new_args), new_kwargs)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36ba1c8585df0ef2821ab0c8b31d170ce2bb4d59
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/__init__.py
@@ -0,0 +1,11 @@
+from . import graph_drawer
+from . import graph_manipulation
+from . import net_min_base
+from . import operator_support
+from . import param_fetch
+from . import reinplace
+from . import shape_prop
+from . import split_module
+from . import split_utils
+from . import splitter_base
+from . import tools_common
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a197d5f9245b23dce87e52f516888bb273532aa0
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0033dce8bbe76577842405651580ba0c16979340
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..598dfb125f4000d62ca3f52c404c65ba35262d2c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9748353172c084bbc68d03b863a0a9c258c52f79
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02e9c23a8d246c551fcca65666cc6f211ee9d37d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50d150eb796e2779d91a78ad8bda83bed5c269ef
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b8ef7232f6b6ce8e635143f5e615bfd909e3e86
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e2a985dbac76e73c69fdd6d089385a16d19f154
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26de3f26cc6ad0626e466a091aacd1b5199bb353
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cdfa4d687e3aead08cca0b271136cc715dd294f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df9d7595ffc76c60bc3f839abf053993f353128c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3b72cfc2d06d6dbd5edbbc4918bed3a5ad12c64
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e063011b0c824b0702d865e71d4b3bde665c8ec8
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..767688822ff587bdae95b90e5ffac479b6debf6d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3db16cfef9ae20a5633983b629b13d7e8f6e7db9
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/MLPY/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0fc6476c4a645d9211e5f27267858265b067de
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py
@@ -0,0 +1,44 @@
+import operator
+
+import torch
+
+
+def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
+    """
+    Annotate the type of getitem nodes, inferred from the type of the sequence node.
+    If the sequence node is not annotated with a type, do nothing.
+    Currently supports getitem nodes whose sequence node is a Tuple, List, or NamedTuple.
+
+    This is helpful since annotations on local names within a function are lost during FX transforms.
+    Adding back known type annotations for getitem nodes improves jit scriptability.
+
+    Args:
+        graph (Graph): The graph to be annotated
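+
+    Example (an illustrative sketch, assuming ``gm`` is a ``torch.fx.GraphModule``
+    whose graph contains ``operator.getitem`` nodes with annotated sequence nodes)::
+
+        >>> # xdoctest: +SKIP
+        >>> annotate_getitem_nodes(gm.graph)
+        >>> # getitem nodes now carry the inferred element type in their .type field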
+    """
+    for node in graph.nodes:
+        if node.target == operator.getitem:
+            sequence_node, index_node = node.args
+            if not sequence_node.type:
+                continue
+            # container types
+            if hasattr(sequence_node.type, "_name"):
+                parameterized_types = sequence_node.type.__args__
+                if sequence_node.type._name == "Tuple":
+                    if len(parameterized_types) == 2 and isinstance(
+                        parameterized_types[1], type(...)
+                    ):
+                        node.type = parameterized_types[0]
+                    else:
+                        assert len(parameterized_types) > index_node
+                        node_type = parameterized_types[index_node]
+                        node.type = node_type
+                elif sequence_node.type._name == "List":
+                    assert len(parameterized_types) == 1
+                    node.type = parameterized_types[0]
+            # NamedTuple type
+            elif hasattr(sequence_node.type, "__annotations__"):
+                if sequence_node.type == torch.Tensor:
+                    continue
+                sequence_node_field_types = sequence_node.type.__annotations__
+                field_name = sequence_node.type._fields[index_node]
+                node.type = sequence_node_field_types[field_name]
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/backends/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/backends/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc670b27419e4a024a6558bb88b3d4819467244c
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42dd1cb9e3edbd0f3470c4e2942a47affaa35539
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py b/MLPY/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..40f261e41fc72cc6a1e04e7ea76e23c9da278e6b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py
@@ -0,0 +1,56 @@
+import torch
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupport
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch.fx.passes.fake_tensor_prop import FakeTensorProp
+from torch.utils import _pytree as pytree
+
+import operator
+
+class CudaGraphsSupport(OperatorSupport):
+    # TODO: why is submodules passed here
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        if node.op not in CALLABLE_NODE_OPS:
+            return False
+
+        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
+            return False
+
+        if node.target in [operator.getitem]:
+            return True
+
+        found_not_cuda = False
+
+        def meta_fk(meta):
+            return meta["val"] if "val" in meta else meta["fake_result"]
+
+        def find_not_cuda(t):
+            nonlocal found_not_cuda
+            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
+                found_not_cuda = True
+
+        for n in node.all_input_nodes:
+            pytree.tree_map_(find_not_cuda, meta_fk(n.meta))
+
+        pytree.tree_map_(find_not_cuda, meta_fk(node.meta))
+
+        # NB: factory function is accounted for because the result would be
+        # cpu or cuda
+
+        return not found_not_cuda
+
+def partition_cudagraphs(gm, inputs):
+    """
+    Partition an FX graph into sub-GraphModules that can be validly run under
+    CUDA graphs.  For a subgraph to be runnable under CUDA graphs, all of the
+    operations must involve CUDA tensors only.
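+
+    Example (an illustrative sketch, assuming ``gm`` is an FX GraphModule of the
+    model and ``inputs`` is a list of example input tensors)::
+
+        >>> # xdoctest: +SKIP
+        >>> fused = partition_cudagraphs(gm, inputs)
+        >>> # CUDA-only operations are now grouped into fused submodules of `fused`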
+    """
+
+    FakeTensorProp(gm).propagate(*inputs)
+    supported_ops = CudaGraphsSupport()
+    # TODO: single node partition may be wrong due to the pessimization
+    # from copying in and out the data.  Check in benchmarks, perhaps
+    partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
+    partitions = partitioner.propose_partitions()
+    fused_graph = partitioner.fuse_partitions(partitions)
+    return fused_graph
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/dialect/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8a3db5fa0f83c4785bdc671d8e550e61742978a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d5a0c4074f0162ac79c57ed810104f4a0892566
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52449e2652d43ccf801bf424fcc701b7d25e8221
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cbfb54f2b116d6c4b63a33f80f5dfadf3af20a5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py
@@ -0,0 +1,112 @@
+from typing import Dict, Tuple, Any
+
+import torch
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.utils._pytree import tree_flatten
+
+from torch.fx import GraphModule, Graph
+from torch.fx import Node
+
+aten = torch.ops.aten
+
+
+# stateful ops are banned from CSE
+rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm}  # noqa: E501,B950
+
+inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_}  # noqa: E501
+
+
+@torch.fx._compatibility.compatibility(is_backward_compatible=False)
+def get_CSE_banned_ops():
+    return rand_ops.union(inplace_ops)
+
+
+@torch.fx._compatibility.compatibility(is_backward_compatible=False)
+class CSEPass(PassBase):
+
+    def __init__(self, banned_ops=None):
+        """
+        This version of the CSE pass aims to be dialect agnostic; it is implemented purely based on the connectivity between fx.Node instances.
+
+        For functional dialects, users only need to specify the random ops in the ban list.
+
+        Warning: the CSE pass cannot be safely applied to an FX graph in non-functional dialects.
+        If your dialect contains stateful operators, please customize the banned_ops.
+
+        """
+        if banned_ops is None:
+            banned_ops = set()
+        self.banned_ops = banned_ops
+        super().__init__()
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        """
+        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph
+
+        Example usage:
+
+        from torch.fx.experimental.proxy_tensor import make_fx
+        def f(a):
+            b = a * a
+            c = a * a
+            return b+c
+
+        p = CSEPass()
+        traced_graph = make_fx(f)(torch.tensor(1))
+        print(traced_graph)
+        result = p(traced_graph)
+        print(result.graph_module)
+        """
+        def get_aten_target(node):
+            if hasattr(node.target, 'overloadpacket'):
+                return node.target.overloadpacket
+            return node.target
+
+        modified = False
+        new_graph = Graph()
+        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
+        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
+        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
+        for n in graph_module.graph.nodes:
+            # The placeholder, output, and get_attr nodes are copied to the new graph without change;
+            # nodes whose target is in banned_ops (e.g. random operations) are also never CSE'd away
+            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
+                new_node = new_graph.node_copy(n, lambda x: env[x])
+                env[n] = new_node
+            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
+                # substitute args and kwargs members to their mapping in env if exists
+                # specs can be used to reconstruct nested list/dictionaries
+                def substitute(arg_list):
+                    arg_list, spec = tree_flatten(arg_list)
+                    for i in range(len(arg_list)):
+                        v = arg_list[i]
+                        if isinstance(v, Node) and v in env:
+                            arg_list[i] = env[v]
+                    return tuple(arg_list), spec
+                args, args_spec = substitute(n.args)
+                kwargs, kwargs_spec = substitute(n.kwargs)
+
+                # each token corresponds to a unique node
+                # nodes with the same token can be substituted
+                token = {"target": n.target, "args": args, "args_spec": args_spec,
+                         "kwargs": kwargs, "kwargs_spec": kwargs_spec}
+
+                # hash substituted args to a number, do not hash specs because specs are not hashable
+                hash_arg = hash((args, kwargs))
+                hash_val = (n.target, hash_arg)
+
+                # check if a node has a substitute and can be eliminated
+                hash_val_in_hash_env = hash_val in hash_env
+                if hash_val_in_hash_env and token_map[hash_val] == token:
+                    modified = True  # substitution happens and the graph is modified
+                    env[n] = hash_env[hash_val]
+                    continue
+
+                new_node = new_graph.node_copy(n, lambda x: env[x])
+                env[n] = new_node
+                if not hash_val_in_hash_env:
+                    hash_env[hash_val] = new_node
+                    token_map[hash_val] = token
+
+        csed_gm = GraphModule(graph_module, new_graph)
+        return PassResult(csed_gm, modified)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py b/MLPY/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a12caf9e9b6c45f0c8255de8d971c517203cfab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py
@@ -0,0 +1,73 @@
+from typing import Optional
+
+import torch.fx
+from torch.fx import Node
+from torch.fx._compatibility import compatibility
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
+from torch.fx.node import map_aggregate
+
+__all__ = ['FakeTensorProp']
+
+@compatibility(is_backward_compatible=False)
+class FakeTensorProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph Node-by-Node and record a fake tensor representing
+    the metadata for the node.  Unlike ShapeProp, (1) this propagation
+    is cheap--it does the propagation with meta tensors which do not actually
+    store data, and (2) the fake tensors have much more fine-grained information,
+    e.g., they have accurate alias information that can be consulted by looking
+    at the storages.
+
+    Args:
+         module (GraphModule): The module to be executed
+         mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
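+
+    Example (an illustrative sketch, assuming ``gm`` is a traced ``GraphModule``
+    and ``x`` is a representative input tensor)::
+
+        >>> # xdoctest: +SKIP
+        >>> FakeTensorProp(gm).propagate(x)
+        >>> # each node in gm.graph now records a fake value in node.meta['val']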
+    """
+    def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
+        super().__init__(module)
+        if mode is None:
+            mode = FakeTensorMode()
+        self._mode = mode
+
+    def run_node(self, n: Node):
+        import sympy
+        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+
+        result = super().run_node(n)
+        sym = None
+        if (
+            'val' in n.meta and
+            isinstance(v := n.meta['val'], torch.SymInt) and
+            isinstance(v.node.expr, sympy.Symbol) and free_unbacked_symbols(v)
+        ):
+            sym = v
+
+        def extract_val(obj):
+            if isinstance(obj, FakeTensor):
+                return snapshot_fake(obj)
+            elif isinstance(obj, torch.Tensor):
+                # TODO: How is it possible that we get a non fake tensor?  We
+                # should be running under the mode...
+                return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
+            elif isinstance(obj, py_sym_types):
+                return obj
+            else:
+                return None
+
+        meta = map_aggregate(result, extract_val)
+        if meta is not None:
+            n.meta['val'] = meta
+            if sym is not None:
+                torch._check(meta == v)
+        return result
+
+    def propagate(self, *args):
+        fake_args = [
+            self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
+            for a in args
+        ]
+        return self.propagate_dont_convert_inputs(*fake_args)
+
+    def propagate_dont_convert_inputs(self, *args):
+        with self._mode:
+            return super().run(*args)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/graph_drawer.py b/MLPY/Lib/site-packages/torch/fx/passes/graph_drawer.py
new file mode 100644
index 0000000000000000000000000000000000000000..afe821f8dd7189c22b0f673f6dee79b06b85185f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/graph_drawer.py
@@ -0,0 +1,421 @@
+
+import hashlib
+import torch
+import torch.fx
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from torch.fx.node import _get_qualified_name, _format_arg
+from torch.fx.graph import _parse_stack_trace
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx._compatibility import compatibility
+from itertools import chain
+
+__all__ = ['FxGraphDrawer']
+try:
+    import pydot
+    HAS_PYDOT = True
+except ImportError:
+    HAS_PYDOT = False
+
+_COLOR_MAP = {
+    "placeholder": '"AliceBlue"',
+    "call_module": "LemonChiffon1",
+    "get_param": "Yellow2",
+    "get_attr": "LightGrey",
+    "output": "PowderBlue",
+}
+
+_HASH_COLOR_MAP = [
+    "CadetBlue1",
+    "Coral",
+    "DarkOliveGreen1",
+    "DarkSeaGreen1",
+    "GhostWhite",
+    "Khaki1",
+    "LavenderBlush1",
+    "LightSkyBlue",
+    "MistyRose1",
+    "MistyRose2",
+    "PaleTurquoise2",
+    "PeachPuff1",
+    "Salmon",
+    "Thistle1",
+    "Thistle3",
+    "Wheat1",
+]
+
+_WEIGHT_TEMPLATE = {
+    "fillcolor": "Salmon",
+    "style": '"filled,rounded"',
+    "fontcolor": "#000000",
+}
+
+if HAS_PYDOT:
+    @compatibility(is_backward_compatible=False)
+    class FxGraphDrawer:
+        """
+        Visualize a torch.fx.Graph with graphviz
+        Basic usage:
+            g = FxGraphDrawer(symbolic_traced, "resnet18")
+            g.get_dot_graph().write_svg("a.svg")
+        """
+
+        def __init__(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool = False,
+            ignore_parameters_and_buffers: bool = False,
+            skip_node_names_in_args: bool = True,
+            parse_stack_trace: bool = False,
+            dot_graph_shape: Optional[str] = None,
+        ):
+            self._name = name
+            self.dot_graph_shape = (
+                dot_graph_shape if dot_graph_shape is not None else "record"
+            )
+            _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape
+
+            self._dot_graphs = {
+                name: self._to_dot(
+                    graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
+                )
+            }
+
+            for node in graph_module.graph.nodes:
+                if node.op != "call_module":
+                    continue
+
+                leaf_node = self._get_leaf_node(graph_module, node)
+
+                if not isinstance(leaf_node, torch.fx.GraphModule):
+                    continue
+
+
+                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
+                    leaf_node,
+                    f"{name}_{node.target}",
+                    ignore_getattr,
+                    ignore_parameters_and_buffers,
+                    skip_node_names_in_args,
+                    parse_stack_trace,
+                )
+
+        def get_dot_graph(self, submod_name=None) -> pydot.Dot:
+            """
+            Visualize a torch.fx.Graph with graphviz
+            Example:
+                >>> # xdoctest: +REQUIRES(module:pydot)
+                >>> # define module
+                >>> class MyModule(torch.nn.Module):
+                >>>     def __init__(self):
+                >>>         super().__init__()
+                >>>         self.linear = torch.nn.Linear(4, 5)
+                >>>     def forward(self, x):
+                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
+                >>> module = MyModule()
+                >>> # trace the module
+                >>> symbolic_traced = torch.fx.symbolic_trace(module)
+                >>> # setup output file
+                >>> import ubelt as ub
+                >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
+                >>> fpath = dpath / 'linear.svg'
+                >>> # draw the graph
+                >>> g = FxGraphDrawer(symbolic_traced, "linear")
+                >>> g.get_dot_graph().write_svg(fpath)
+            """
+            if submod_name is None:
+                return self.get_main_dot_graph()
+            else:
+                return self.get_submod_dot_graph(submod_name)
+
+        def get_main_dot_graph(self) -> pydot.Dot:
+            return self._dot_graphs[self._name]
+
+        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
+            return self._dot_graphs[f"{self._name}_{submod_name}"]
+
+        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
+            return self._dot_graphs
+
+        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
+
+            template = {
+                "shape": self.dot_graph_shape,
+                "fillcolor": "#CAFFE3",
+                "style": '"filled,rounded"',
+                "fontcolor": "#000000",
+            }
+            if node.op in _COLOR_MAP:
+                template["fillcolor"] = _COLOR_MAP[node.op]
+            else:
+                # Use a random color for each node; based on its name so it's stable.
+                target_name = node._pretty_print_target(node.target)
+                target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
+                template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
+            return template
+
+        def _get_leaf_node(
+            self, module: torch.nn.Module, node: torch.fx.Node
+        ) -> torch.nn.Module:
+            py_obj = module
+            assert isinstance(node.target, str)
+            atoms = node.target.split(".")
+            for atom in atoms:
+                if not hasattr(py_obj, atom):
+                    raise RuntimeError(
+                        str(py_obj) + " does not have attribute " + atom + "!"
+                    )
+                py_obj = getattr(py_obj, atom)
+            return py_obj
+
+        def _typename(self, target: Any) -> str:
+            if isinstance(target, torch.nn.Module):
+                ret = torch.typename(target)
+            elif isinstance(target, str):
+                ret = target
+            else:
+                ret = _get_qualified_name(target)
+
+            # Escape "{" and "}" to prevent dot files like:
+            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
+            # which triggers `Error: bad label format (...)` from dot
+            return ret.replace("{", r"\{").replace("}", r"\}")
+
+        # Shorten the file path to avoid drawing long boxes,
+        # e.g. the full path '/home/weif/pytorch/test.py'
+        # becomes the short path 'pytorch/test.py'
+        def _shorten_file_name(
+            self,
+            full_file_name: str,
+            truncate_to_last_n: int = 2,
+        ):
+            splits = full_file_name.split('/')
+            if len(splits) >= truncate_to_last_n:
+                return '/'.join(splits[-truncate_to_last_n:])
+            return full_file_name
+
+
+        def _get_node_label(
+            self,
+            module: torch.fx.GraphModule,
+            node: torch.fx.Node,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> str:
+            def _get_str_for_args_kwargs(arg):
+                if isinstance(arg, tuple):
+                    prefix, suffix = r"|args=(\l", r",\n)\l"
+                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
+                elif isinstance(arg, dict):
+                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
+                    arg_strs_list = [
+                        f"{k}: {_format_arg(v, max_list_len=8)}"
+                        for k, v in arg.items()
+                    ]
+                else:  # Fall back to nothing in unexpected case.
+                    return ""
+
+                # Strip out node names if requested.
+                if skip_node_names_in_args:
+                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
+                if len(arg_strs_list) == 0:
+                    return ""
+                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
+                if len(arg_strs_list) == 1:
+                    arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
+                return arg_strs.replace("{", r"\{").replace("}", r"\}")
+
+
+            label = "{" + f"name=%{node.name}|op_code={node.op}\n"
+
+            if node.op == "call_module":
+                leaf_module = self._get_leaf_node(module, node)
+                label += r"\n" + self._typename(leaf_module) + r"\n|"
+                extra = ""
+                if hasattr(leaf_module, "__constants__"):
+                    extra = r"\n".join(
+                        [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
+                    )
+                label += extra + r"\n"
+            else:
+                label += f"|target={self._typename(node.target)}" + r"\n"
+                if len(node.args) > 0:
+                    label += _get_str_for_args_kwargs(node.args)
+                if len(node.kwargs) > 0:
+                    label += _get_str_for_args_kwargs(node.kwargs)
+                label += f"|num_users={len(node.users)}" + r"\n"
+
+            tensor_meta = node.meta.get('tensor_meta')
+            label += self._tensor_meta_to_label(tensor_meta)
+
+            # for original fx graph
+            # print buf=buf0, n_origin=6
+            buf_meta = node.meta.get('buf_meta', None)
+            if buf_meta is not None:
+                label += f"|buf={buf_meta.name}" + r"\n"
+                label += f"|n_origin={buf_meta.n_origin}" + r"\n"
+
+            # for original fx graph
+            # print file:lineno code
+            if parse_stack_trace and node.stack_trace is not None:
+                parsed_stack_trace = _parse_stack_trace(node.stack_trace)
+                fname = self._shorten_file_name(parsed_stack_trace.file)
+                label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"
+
+
+            return label + "}"
+
+        def _tensor_meta_to_label(self, tm) -> str:
+            if tm is None:
+                return ""
+            elif isinstance(tm, TensorMetadata):
+                return self._stringify_tensor_meta(tm)
+            elif isinstance(tm, list):
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            elif isinstance(tm, dict):
+                result = ""
+                for v in tm.values():
+                    result += self._tensor_meta_to_label(v)
+                return result
+            elif isinstance(tm, tuple):
+                result = ""
+                for item in tm:
+                    result += self._tensor_meta_to_label(item)
+                return result
+            else:
+                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
+
+        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
+            result = ""
+            if not hasattr(tm, "dtype"):
+                print("tm", tm)
+            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
+            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
+            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
+            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
+            if tm.is_quantized:
+                assert tm.qparams is not None
+                assert "qscheme" in tm.qparams
+                qscheme = tm.qparams["qscheme"]
+                if qscheme in {
+                        torch.per_tensor_affine,
+                        torch.per_tensor_symmetric,
+                }:
+                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
+                    result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
+                elif qscheme in {
+                        torch.per_channel_affine,
+                        torch.per_channel_symmetric,
+                        torch.per_channel_affine_float_qparams,
+                }:
+                    result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
+                    result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
+                    result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
+                else:
+                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
+                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
+            return result
+
+        def _get_tensor_label(self, t: torch.Tensor) -> str:
+            return str(t.dtype) + str(list(t.shape)) + r"\n"
+
+        # when parse_stack_trace=True
+        # print file:lineno code
+        def _to_dot(
+            self,
+            graph_module: torch.fx.GraphModule,
+            name: str,
+            ignore_getattr: bool,
+            ignore_parameters_and_buffers: bool,
+            skip_node_names_in_args: bool,
+            parse_stack_trace: bool,
+        ) -> pydot.Dot:
+            """
+            Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
+            If ignore_parameters_and_buffers is True, the parameters and buffers
+            created with the module will not be added as nodes and edges.
+            """
+
+            # "TB" means top-to-bottom rank direction in layout
+            dot_graph = pydot.Dot(name, rankdir="TB")
+
+
+            buf_name_to_subgraph = {}
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                style = self._get_node_style(node)
+                dot_node = pydot.Node(
+                    node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
+                )
+
+                current_graph = dot_graph
+
+                buf_meta = node.meta.get('buf_meta', None)
+                if buf_meta is not None and buf_meta.n_origin > 1:
+                    buf_name = buf_meta.name
+                    if buf_name not in buf_name_to_subgraph:
+                        buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
+                    current_graph = buf_name_to_subgraph.get(buf_name)
+
+                current_graph.add_node(dot_node)
+
+                def get_module_params_or_buffers():
+                    for pname, ptensor in chain(
+                        leaf_module.named_parameters(), leaf_module.named_buffers()
+                    ):
+                        pname1 = node.name + "." + pname
+                        label1 = (
+                            pname1 + "|op_code=get_" + "parameter"
+                            if isinstance(ptensor, torch.nn.Parameter)
+                            else "buffer" + r"\l"
+                        )
+                        dot_w_node = pydot.Node(
+                            pname1,
+                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
+                            **_WEIGHT_TEMPLATE,
+                        )
+                        dot_graph.add_node(dot_w_node)
+                        dot_graph.add_edge(pydot.Edge(pname1, node.name))
+
+                if node.op == "call_module":
+                    leaf_module = self._get_leaf_node(graph_module, node)
+
+                    if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
+                        get_module_params_or_buffers()
+
+            for subgraph in buf_name_to_subgraph.values():
+                subgraph.set('color', 'royalblue')
+                subgraph.set('penwidth', '2')
+                dot_graph.add_subgraph(subgraph)
+
+            for node in graph_module.graph.nodes:
+                if ignore_getattr and node.op == "get_attr":
+                    continue
+
+                for user in node.users:
+                    dot_graph.add_edge(pydot.Edge(node.name, user.name))
+
+            return dot_graph
+
+else:
+    if not TYPE_CHECKING:
+        @compatibility(is_backward_compatible=False)
+        class FxGraphDrawer:
+            def __init__(
+                self,
+                graph_module: torch.fx.GraphModule,
+                name: str,
+                ignore_getattr: bool = False,
+                ignore_parameters_and_buffers: bool = False,
+                skip_node_names_in_args: bool = True,
+                parse_stack_trace: bool = False,
+                dot_graph_shape: Optional[str] = None,
+            ):
+                raise RuntimeError('FxGraphDrawer requires the pydot package to be installed. Please install '
+                                   'pydot through your favorite Python package manager.')
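+
+
+# ---------------------------------------------------------------------------
+# Editor's note: a minimal, hedged usage sketch, not part of the upstream
+# module. It assumes pydot is installed and that FxGraphDrawer exposes the
+# `get_dot_graph()` accessor defined earlier in this file; `_TinyNet` is a
+# hypothetical module used purely for illustration and nothing here runs at
+# import time.
+def _example_draw_fx_graph() -> str:
+    import torch
+    import torch.fx
+
+    class _TinyNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(4, 2)
+
+        def forward(self, x):
+            return torch.relu(self.linear(x))
+
+    gm = torch.fx.symbolic_trace(_TinyNet())
+    drawer = FxGraphDrawer(gm, "tiny_net")
+    # Return the DOT source text; callers could also render it, e.g. with
+    # pydot's create_svg().
+    return drawer.get_dot_graph().to_string()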
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/graph_manipulation.py b/MLPY/Lib/site-packages/torch/fx/passes/graph_manipulation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf16efebf644937fdef7ce7dd8f504a210da134
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/graph_manipulation.py
@@ -0,0 +1,110 @@
+from typing import Any, Dict, List, NamedTuple, Optional
+
+import torch
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import (
+    map_arg,
+    Node,
+    Target,
+)
+from torch.fx.passes.shape_prop import ShapeProp
+
+__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
+           'get_size_of_node']
+
+@compatibility(is_backward_compatible=False)
+def replace_target_nodes_with(
+    fx_module: GraphModule,
+    old_op: str,
+    old_target: Target,
+    new_op: str,
+    new_target: Target,
+):
+    """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
+    and updates them to match the new op code and target"""
+    new_graph = Graph()
+    val_map: Dict[Node, Node] = {}
+    for node in fx_module.graph.nodes:
+        if node.op == old_op and node.target == old_target:
+            args = map_arg(node.args, lambda n: val_map[n])
+            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
+            assert isinstance(args, tuple)
+            assert isinstance(kwargs, dict)
+            val_map[node] = new_graph.create_node(
+                new_op, new_target, args, kwargs, node.name
+            )
+        else:
+            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
+    fx_module.graph = new_graph
+
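+
+# Editor's illustrative sketch (hedged, not part of the upstream API): shows
+# how replace_target_nodes_with might be used to rewrite every call_function
+# node targeting torch.add into one targeting torch.mul. `_AddNet` is a
+# hypothetical toy module.
+def _example_replace_add_with_mul() -> GraphModule:
+    import torch.fx
+
+    class _AddNet(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.add(x, y)
+
+    gm = torch.fx.symbolic_trace(_AddNet())
+    # Rewrite add -> mul; the helper above installs the rewritten graph on gm.
+    replace_target_nodes_with(gm, "call_function", torch.add, "call_function", torch.mul)
+    gm.recompile()
+    return gm
+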
+
+@compatibility(is_backward_compatible=False)
+class size_bytes(NamedTuple):
+    output_size: int
+    total_size: int
+
+
+@compatibility(is_backward_compatible=False)
+def get_size_of_all_nodes(
+    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
+) -> None:
+    """Given an FX graph module, update each node with its total size (weights + bias + output)
+    and its output size (output only). For a non-module node, the total size is just the output size.
+    The sizes are stored on each node as ``node.size_bytes``; nothing is returned."""
+    if args is not None:
+        # Mark shape and dtype for each node (node.shape and node.dtype)
+        ShapeProp(fx_module).propagate(*args)
+    # Calculate the total size of the whole fx graph
+    total_size_of_graph = 0.0
+    for node in fx_module.graph.nodes:
+        if node.op == "output":
+            break
+        node.size_bytes = get_size_of_node(fx_module, node)
+    return
+
+
+@compatibility(is_backward_compatible=False)
+def get_tensor_meta(node: Node) -> Any:
+    tensor_meta = node.meta.get("tensor_meta")
+
+    if not tensor_meta:
+        raise RuntimeError(
+            f"Node {node} has no tensor metadata associated with it! "
+            f"Check that shape propagation has run."
+        )
+
+    return tensor_meta
+
+
+@compatibility(is_backward_compatible=False)
+def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
+    """Given a node with node.dtype and node.shape, return its total size and its output size.
+    total_size = weights + bias + output_size
+    """
+    # Total num of elements
+    total_num_of_elems = 0
+    # For a module, consider all parameters
+    if node.op == "call_module":
+        submodule_dict = dict(fx_module.named_modules())
+        submodule = submodule_dict[node.target]
+        parameters = submodule.named_parameters()
+    # named_parameters() yields (name, parameter) pairs
+        for name, p in parameters:
+            total_num_of_elems += p.numel()
+    # Don't forget the output size
+    # node.shape is the shape of this node's output
+    tensor_meta = get_tensor_meta(node)
+    output_elem = tensor_meta.shape.numel()
+    total_num_of_elems += output_elem
+    # Assume for now if it's quantized then it's qint8 or quint8
+    if tensor_meta.is_quantized:
+        size_per_elem_bytes = torch._empty_affine_quantized(
+            [], dtype=tensor_meta.dtype
+        ).element_size()
+    else:
+        size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
+    total_size = size_per_elem_bytes * total_num_of_elems
+    output_size = size_per_elem_bytes * output_elem
+    return size_bytes(output_size, total_size)
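+
+
+# Editor's illustrative sketch (hedged, not part of the upstream API): shows
+# how get_size_of_all_nodes might be used. `_LinearNet` is a hypothetical toy
+# module and the sample input shape is arbitrary.
+def _example_node_sizes() -> Dict[str, size_bytes]:
+    import torch.fx
+
+    class _LinearNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            return self.fc(x)
+
+    gm = torch.fx.symbolic_trace(_LinearNet())
+    # Runs ShapeProp and attaches `size_bytes` to every non-output node.
+    get_size_of_all_nodes(gm, [torch.randn(2, 3)])
+    return {n.name: n.size_bytes for n in gm.graph.nodes if n.op != "output"}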
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/infra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6398778292d887c23e3c69c4eb1f75fd9c516d2e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/infra/__init__.py
@@ -0,0 +1,2 @@
+
+from . import pass_manager
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f51957750e851408b2f6bd0ccdcff2da6c430f4b
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7aa121543c091945f67c8a2c1fb67440147c966
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3012f7811617112f42077f55b5b0d4c0a4932616
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37bb9842130a3f019142b1ce74fa8913aa84398f
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/partitioner.py b/MLPY/Lib/site-packages/torch/fx/passes/infra/partitioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc0093618f5833874b529bbea16c45138c400491
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/infra/partitioner.py
@@ -0,0 +1,329 @@
+from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
+import collections
+import itertools
+import logging
+
+from copy import copy
+from typing import Dict, Iterable, List, Optional, Sequence, Set
+
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node, _get_qualified_name
+from torch.fx.passes.operator_support import OperatorSupportBase
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+class Partition:
+    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
+        self.id = id
+        self.nodes: Set[Node] = set(nodes) if nodes is not None else set()
+
+    def __repr__(self) -> str:
+        return str(self.nodes)
+
+    def add_node(self, node: Node):
+        self.nodes.add(node)
+
+    def remove_node(self, node: Node):
+        self.nodes.remove(node)
+
+    def size(self):
+        return len(self.nodes)
+
+class _DependencyViewer:
+    def __init__(self, graph_module: GraphModule):
+        self.upstreams = collections.defaultdict(set)
+        self.downstreams = collections.defaultdict(set)
+
+        for node in graph_module.graph.nodes:
+            for input_node in node.all_input_nodes:
+                # add input_node and input_node's upstream dependency
+                self.upstreams[node].add(input_node)
+                self.upstreams[node].update(self.upstreams[input_node])
+
+        for node in reversed(graph_module.graph.nodes):
+            for output_node in node.users:
+                # add output_node and output_node's downstream dependency
+                self.downstreams[node].add(output_node)
+                self.downstreams[node].update(self.downstreams[output_node])
+
+    def downstreams_of(self, node: Node) -> Set[Node]:
+        return self.downstreams[node]
+
+    def upstreams_of(self, node: Node) -> Set[Node]:
+        return self.upstreams[node]
+
+class CapabilityBasedPartitioner:
+
+    def __init__(self,
+                 graph_module: GraphModule,
+                 operator_support: OperatorSupportBase,
+                 allows_single_node_partition: bool = False,
+                 non_compute_ops: Optional[Sequence[str]] = None,
+                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+                 ) -> None:
+        self.graph_module = graph_module
+        self.operator_support = operator_support
+        self.allows_single_node_partition = allows_single_node_partition
+        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
+        self.allowed_single_node_partition_ops = (
+            allowed_single_node_partition_ops
+            if allowed_single_node_partition_ops is not None
+            else []
+        )
+        self.dependency_viewer = _DependencyViewer(graph_module)
+
+    def __is_node_supported(self, node: Node) -> bool:
+        return (
+            self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
+        )
+
+    def propose_partitions(self) -> List[Partition]:
+        # partition_map is a mapping from partition id to a set of partition id's.
+        # The value set contains all the partition ids that can be reached by doing a
+        # DFS starting from the partition id in the key.
+        partition_map : Dict[int, Set] = collections.defaultdict(set)
+
+        # assumption: nodes in the candidate list are sorted in topological order
+        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
+        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
+        new_partition_id = itertools.count()
+
+        # Try to merge partition other_id into partition self_id.
+        # The merge only happens if the resulting graph contains no cyclic dependency.
+        # Returns `True` when the merge happens, `False` otherwise.
+        def maybe_merge_partition(self_id: int, other_id: int):
+            # merged_nodes is the union of the nodes in the two partitions to be merged
+            merged_nodes = copy(partitions_by_id[self_id].nodes)
+            merged_nodes.update(partitions_by_id[other_id].nodes)
+
+            def dfs_iter_find_cycle(all_user_nodes: List[Node]):
+                for user_node in all_user_nodes:
+                    visited_partition_ids = set()
+
+                    for path_node in self.dependency_viewer.downstreams_of(user_node):
+                        # If any of the nodes in the dfs path of this node are in the merged_nodes
+                        # list then there is a cycle in the graph.
+                        if path_node in merged_nodes:
+                            return True
+
+                        # If any of the nodes in the dfs path of this node are in the assignment
+                        # map then we have to make sure that the partitions that these nodes belong
+                        # to do not form a cycle with the current partitions being merged. This means
+                        # iterating through all the nodes in all the partitions that are traversed in
+                        # the dfs path and checking if they are in the merged_nodes list.
+                        if path_node in assignment:
+                            partition_id = assignment[path_node]
+                            # If the partition id has already been visited then we know that it doesn't
+                            # form a cycle with the current partitions being merged.
+                            if partition_id in visited_partition_ids:
+                                continue
+                            p_map = partition_map[partition_id]
+                            if self_id in p_map or other_id in p_map:
+                                return True
+
+                            visited_partition_ids.add(partition_id)
+
+                return False
+
+            # check if merge would create cyclic dependency.
+            all_user_nodes = []
+            for node in merged_nodes:
+                for user_node in node.users:
+                    if user_node not in merged_nodes:
+                        all_user_nodes.append(user_node)
+
+            if dfs_iter_find_cycle(all_user_nodes):
+                # Return False, indicating a cyclic dependency was found and the
+                # merge is aborted.
+                return False
+
+            # no cyclic dependency found, move forward with the merge
+            # updating partition nodes
+            partitions_by_id[self_id].nodes = merged_nodes
+            # updating assignment map
+            for node in partitions_by_id[other_id].nodes:
+                assignment[node] = self_id
+            # delete other partition
+            del partitions_by_id[other_id]
+
+            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
+            del partition_map[other_id]
+
+            return True
+
+        def merge_single_node(node: Node, id: Optional[int]):
+            def _update_partition_map(node: Node, id: int):
+                # Iterate through all the downstream nodes of this node and update the partition map
+                # to indicate that there is a path from the partition id of this node to the target
+                # partition id.
+                downstream_nodes = self.dependency_viewer.downstreams_of(node)
+                for curr_node in downstream_nodes:
+                    target_id = assignment.get(curr_node, None)
+                    if target_id is not None:
+                        partition_map[id].add(target_id)
+
+                # Iterate through all the upstream nodes of this node and update the partition map
+                # to indicate that there is a path from the partition id of the upstream node to the
+                # current node's partition id.
+                upstream_nodes = self.dependency_viewer.upstreams_of(node)
+                for curr_node in upstream_nodes:
+                    source_id = assignment.get(curr_node, None)
+                    if source_id is not None:
+                        partition_map[source_id].add(id)
+
+            if node in assignment:
+                partitions_by_id[assignment[node]].remove_node(node)
+
+            if id is None:
+                assignment.pop(node)
+            elif id not in partitions_by_id:
+                assignment[node] = id
+                partitions_by_id[id] = Partition(id=id, nodes=[node])
+                _update_partition_map(node, id)
+            else:
+                assignment[node] = id
+                partitions_by_id[id].add_node(node)
+                _update_partition_map(node, id)
+
+        logger.debug("Proposing partitions...")
+
+        for node in reversed(self.graph_module.graph.nodes):
+            # Use a Dict as an ordered set to ensure a deterministic partitioning result; the values don't matter.
+            merge_candidates: Dict[int, None] = {}
+
+            # Note that a limited form of horizontal fusion is enabled:
+            #   when `node` is not supported, the code below still attempts to fuse the consumers of `node`.
+            #
+            # There is no knob to disable horizontal fusion yet; it can be short-cut
+            # by adding an `else` block here to skip horizontal fusion.
+            if self.__is_node_supported(node) and node not in assignment:
+                partition_id = next(new_partition_id)
+                merge_single_node(node, partition_id)
+                merge_candidates[partition_id] = None
+
+            # merge all possible partitions
+            for node in assignment:
+                merge_candidates[assignment[node]] = None
+
+            merge_candidates_list = list(merge_candidates.keys())
+            if len(merge_candidates_list) > 1:
+                self_id = merge_candidates_list[0]
+                for other_id in merge_candidates_list[1:]:
+                    # note: merge partition `other_id` into partition `self_id` if
+                    # it doesn't create cyclic dependency in the graph, otherwise,
+                    # this is a no-op
+                    maybe_merge_partition(self_id, other_id)
+
+        # post processing to re-assign "getitem" nodes into upstream partition
+        logger.debug("Reassigning getitem nodes to its producer node's partition...")
+        nodes_reassignment: Dict[Node, int] = {}
+        for node in self.graph_module.graph.nodes:
+            is_tuple_output = True
+            for user in node.users:
+                if user.op != "call_function" or \
+                   _get_qualified_name(user.target) != "_operator.getitem":     # type: ignore[arg-type]
+                    is_tuple_output = False
+                    break
+
+            # node has tuple outputs; re-assign all following getitem nodes into node's partition
+            if is_tuple_output:
+                id = assignment.get(node, None)     # type: ignore[arg-type]
+                for user in node.users:
+                    if assignment.get(user, None) != id:    # type: ignore[arg-type]
+                        nodes_reassignment[user] = id  # type: ignore[assignment]
+        for node, id in nodes_reassignment.items():
+            merge_single_node(node, id)
+
+        # filter out single node partitions
+        if not self.allows_single_node_partition:
+            logger.debug("Filtering out single node partitions...")
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            partitions_to_remove: List[int] = []
+            for id, partition in partitions_by_id.items():
+                compute_node_count = 0
+                for node in partition.nodes:
+                    if node.op == "call_function":
+                        assert callable(node.target)
+                        if _get_qualified_name(node.target) not in non_compute_ops:
+                            compute_node_count += 1
+                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
+                            compute_node_count += 1
+                if compute_node_count <= 1:
+                    partitions_to_remove.append(id)
+            for id in partitions_to_remove:
+                del partitions_by_id[id]
+
+        logger.debug("Partitions proposed:")
+        for id, partition in partitions_by_id.items():
+            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])
+
+        return list(partitions_by_id.values())
+
+    def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
+        logger.debug("Fusing partitions...")
+        # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
+        return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])
+
+    # Remove non-compute ops that sit at the boundary of a partition.
+    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
+        non_compute_ops = set(self.non_compute_ops)
+
+        def is_non_compute_node(node: Node):
+            return node.op == "call_function" and \
+                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]
+
+        # cache transparent nodes
+        transparent_input_nodes: Dict[Node, bool] = {}
+        transparent_output_nodes: Dict[Node, bool] = {}
+
+        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
+            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
+                return True
+            if node in transparent_input_nodes:
+                return transparent_input_nodes[node]
+            if is_non_compute_node(node):
+                for input_n in node.all_input_nodes:
+                    if not is_transparent_input_node(input_n, partition, removed_nodes):
+                        transparent_input_nodes[node] = False
+                        return False
+                transparent_input_nodes[node] = True
+                return True
+            transparent_input_nodes[node] = False
+            return False
+
+        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
+            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
+                return True
+            if node in transparent_output_nodes:
+                return transparent_output_nodes[node]
+            if is_non_compute_node(node):
+                for output_n in node.users:
+                    if not is_transparent_output_node(output_n, partition, removed_nodes):
+                        transparent_output_nodes[node] = False
+                        return False
+                transparent_output_nodes[node] = True
+                return True
+            transparent_output_nodes[node] = False
+            return False
+
+        for partition in partitions:
+            # Note it's ok to use `set` here, since we only query whether a node
+            # has been removed; we never iterate over the nodes inside the set.
+            remove_node: Set[Node] = set()
+            for node in partition.nodes:
+                if is_non_compute_node(node) and \
+                    (is_transparent_input_node(node, partition.nodes, remove_node) or
+                     is_transparent_output_node(node, partition.nodes, remove_node)):
+                    remove_node.add(node)
+
+            if len(remove_node) != 0:
+                partition.nodes = partition.nodes - remove_node
+
+    def partition_and_fuse(self) -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions)
+        return fused_gm
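+
+
+# Editor's illustrative sketch (hedged, not part of the upstream API): a
+# minimal end-to-end use of CapabilityBasedPartitioner. `_AddMulSupport` and
+# `_Net` are hypothetical; only add/mul call_function nodes are treated as
+# supported, so they are fused into one submodule while relu stays outside.
+def _example_partition_and_fuse() -> GraphModule:
+    import operator
+    import torch
+    import torch.fx
+    from torch.fx.passes.operator_support import OperatorSupportBase
+
+    class _AddMulSupport(OperatorSupportBase):
+        def is_node_supported(self, submodules, node: Node) -> bool:
+            return node.op == "call_function" and node.target in (
+                torch.add, torch.mul, operator.add, operator.mul,
+            )
+
+    class _Net(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.relu(torch.add(x, y) * y)
+
+    gm = torch.fx.symbolic_trace(_Net())
+    partitioner = CapabilityBasedPartitioner(gm, _AddMulSupport())
+    partitions = partitioner.propose_partitions()
+    return partitioner.fuse_partitions(partitions)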
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_base.py b/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb027a90e0a4541006a255d245c343daac645abc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_base.py
@@ -0,0 +1,75 @@
+import abc
+from collections import namedtuple
+from typing import Optional
+
+from torch.fx.graph_module import GraphModule
+from torch.fx._compatibility import compatibility
+
+
+__all__ = ['PassResult', 'PassBase']
+
+@compatibility(is_backward_compatible=False)
+class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
+    """
+    Result of a pass:
+        graph_module: The modified graph module
+        modified: A flag for if the pass has modified the graph module
+    """
+    def __new__(cls, graph_module, modified):
+        return super().__new__(cls, graph_module, modified)
+
+@compatibility(is_backward_compatible=False)
+class PassBase(abc.ABC):
+    """
+    Base interface for implementing passes.
+
+    It is required to implement the `call` function so that instances of a
+    Pass can be passed directly to the PassManager and invoked as functions.
+
+    An instance of a class implementing this interface can be added directly
+    to the PassManager's `passes` attribute.
+    """
+
+    def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
+        """
+        Runs the precondition check, the pass itself, and the postcondition check.
+        """
+
+        self.requires(graph_module)
+        res = self.call(graph_module)
+        self.ensures(graph_module)
+        return res
+
+    @abc.abstractmethod
+    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
+        """
+        The pass that is run through the given graph module. To implement a
+        pass, it is required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run a pass on
+        """
+        pass
+
+    def requires(self, graph_module: GraphModule) -> None:  # noqa: B027
+        """
+        This function will be called before the pass is run and will check that
+        the given graph module contains the preconditions needed to run the
+        pass. It is not required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run checks on
+        """
+        pass
+
+    def ensures(self, graph_module: GraphModule) -> None:  # noqa: B027
+        """
+        This function will be called after the pass is run and will check that
+        the given graph module contains the postconditions needed to run the
+        pass. It is not required to implement this function.
+
+        Args:
+            graph_module: The graph module we will run checks on
+        """
+        pass
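+
+
+# Editor's illustrative sketch (hedged, not part of the upstream API): a
+# minimal PassBase subclass. It runs FX's built-in dead-code elimination and
+# reports whether anything changed; the precondition simply lints the graph.
+class _ExampleDeadCodeEliminationPass(PassBase):
+    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
+        # eliminate_dead_code() returns True when at least one node was removed.
+        changed = graph_module.graph.eliminate_dead_code()
+        if changed:
+            graph_module.recompile()
+        return PassResult(graph_module, changed)
+
+    def requires(self, graph_module: GraphModule) -> None:
+        # Precondition sketch: the incoming graph should pass FX's linter.
+        graph_module.graph.lint()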
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_manager.py b/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d4580af2f4ba06b44b6c9d6297eb130fab4147
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/infra/pass_manager.py
@@ -0,0 +1,303 @@
+import inspect
+import logging
+from queue import Queue
+from functools import wraps
+from typing import Callable, Dict, List
+
+import torch.nn as nn
+from torch.fx.graph_module import GraphModule
+from torch.fx._compatibility import compatibility
+from torch.fx.passes.infra.pass_base import PassResult
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
+
+@compatibility(is_backward_compatible=False)
+def pass_result_wrapper(fn: Callable) -> Callable:
+    """
+    Wrapper for passes which currently do not return a PassResult.
+    This wrapper makes them return a PassResult containing the modified object
+    and True for the "modified" flag.
+
+    Args:
+        fn (Callable[Module, Any])
+
+    Returns:
+        wrapped_fn (Callable[Module, PassResult])
+    """
+    if fn is None:
+        return None
+
+    @wraps(fn)
+    def wrapped_fn(gm):
+        res = fn(gm)
+        if res is None:
+            return PassResult(gm, True)
+        if isinstance(res, PassResult):
+            return res
+        elif isinstance(res, nn.Module):
+            return PassResult(res, True)
+
+    if not inspect.isfunction(fn):
+        wrapped_fn.__name__ = type(fn).__name__
+
+    return wrapped_fn
+
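+# Editor's illustrative sketch (hedged): wrapping a legacy pass that mutates
+# the module in place and returns nothing. `_legacy_dce_pass` is hypothetical.
+def _example_wrap_legacy_pass() -> Callable:
+    def _legacy_dce_pass(gm: GraphModule) -> None:
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
+
+    # The wrapper turns the None return value into PassResult(gm, True).
+    return pass_result_wrapper(_legacy_dce_pass)
+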
+def _validate_pass_schedule_constraint(
+    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
+) -> None:
+    for i, a in enumerate(passes):
+        for j, b in enumerate(passes[i + 1 :]):
+            if constraint(a, b):
+                continue
+            raise RuntimeError(
+                f"pass schedule constraint violated. Expected {a} before {b}"
+                f" but found {a} at index {i} and {b} at index {i + j + 1} in pass"
+                f" list."
+            )
+
+def _topological_sort_passes(
+    passes: List[Callable], constraints: List[Callable]
+) -> List[Callable]:
+    """
+    Args
+        passes: Passes that we are ordering
+        constraints: Constraints applied on these passes
+
+    Returns
+        A topologically sorted list of callables. Raises a RuntimeError if the
+        constraints imply a circular dependency.
+    """
+    if len(constraints) == 0:
+        return passes
+
+    # Construct a graph mapping nodes to a list of their users
+    graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
+    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
+    candidates: Queue = Queue()
+    for a in passes:
+        for b in passes:
+            if a == b:
+                continue
+
+            for constraint in constraints:
+                if not constraint(a, b):
+                    graph[b].append(a)
+                    indegree_map[a] += 1
+
+        if indegree_map[a] == 0:
+            candidates.put(a)
+
+    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
+    sorted_passes: List[Callable] = []
+
+    while not candidates.empty():
+        p = candidates.get()
+        sorted_passes.append(p)
+        visited[p] = True
+
+        for n in graph[p]:
+            if not visited[n]:
+                indegree_map[n] -= 1
+                if indegree_map[n] == 0:
+                    candidates.put(n)
+
+    # Check if there are unvisited nodes (aka cycles in the graph)
+    cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
+    if len(cycle_passes) != 0:
+        error = f"Circular dependency detected within the following passes: {cycle_passes}"
+        raise RuntimeError(error)
+
+    return sorted_passes
+
+@compatibility(is_backward_compatible=False)
+def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
+    """
+    Defines a partial order ('depends on' function) where `this` must occur
+    before `that`.
+
+    For example, the following pass list and constraint list would be invalid.
+    ```
+    passes = [pass_b, pass_a]
+
+    constraints = [
+        this_before_that_pass_constraint(pass_a, pass_b)
+    ]
+    ```
+
+    Args:
+        this (Callable): pass which should occur first
+        that (Callable): pass which should occur later
+
+    Returns:
+        depends_on (Callable[[Callable, Callable], bool])
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        if a == that and b == this:
+            return False
+        return True
+
+    return depends_on
+
+
+@compatibility(is_backward_compatible=False)
+class PassManager:
+    """
+    Construct a PassManager.
+
+    Collects passes and constraints. This defines the pass schedule, manages
+    pass constraints and pass execution.
+
+    Args:
+        passes (Optional[List[Callable]]): List of passes. A pass is a
+            callable which modifies an object and returns a PassResult
+        constraints (Optional[List[Callable]]): List of constraints. A
+            constraint is a callable which takes two passes (A, B) and returns
+            True if A depends on B and False otherwise. See the implementation of
+            `this_before_that_pass_constraint` for an example.
+        steps (int): Max number of times we run the passes (default = 1).
+        run_checks_after_each_pass (bool): Whether to run checks and linting
+            after each pass
+        suppress_check_failures (bool): Whether to suppress, rather than raise,
+            errors when checks fail
+    """
+
+    passes: List[Callable[[nn.Module], PassResult]]
+    constraints: List[Callable[[Callable, Callable], bool]]
+    _validated: bool = False
+    steps: int = 1
+
+    def __init__(
+        self,
+        passes=None,
+        constraints=None,
+        steps=None,
+        run_checks_after_each_pass: bool = False,
+        suppress_check_failures: bool = False,
+    ):
+        self.passes = passes or []
+        self.constraints = constraints or []
+        if steps:
+            self.steps = steps
+
+        self.run_checks_after_each_pass = run_checks_after_each_pass
+        self.suppress_check_failures = suppress_check_failures
+
+    def add_pass(self, _pass: Callable):
+        """
+        Adds a pass into the current list of passes.
+        """
+        self.passes.append(_pass)
+        self._validated = False
+
+    def add_constraint(self, constraint: Callable):
+        """
+        Adds a constraint into the current list of constraints.
+        """
+        self.constraints.append(constraint)
+        self._validated = False
+
+    def validate_constraints(self):
+        """
+        Validates the current pass schedule defined by `self.passes` against
+        all constraints in `self.constraints`.
+        """
+        if self._validated:
+            return
+        for constraint in self.constraints:
+            _validate_pass_schedule_constraint(constraint, self.passes)
+        self._validated = True
+
+    def solve_constraints(self):
+        """
+        Finds a valid traversal order based on the given constraints and orders
+        the passes based on this order.
+
+        If a circular dependency exists between the constraints, an error is
+        raised, since no valid ordering of the passes exists.
+        """
+        self.passes = _topological_sort_passes(self.passes, self.constraints)
+        self._validated = True
+
+    def add_checks(self, check: Callable) -> None:
+        """
+        Adds a function which runs various checks on a given graph module.
+        This function is run before and after each pass if the
+        `run_checks_after_each_pass` flag is enabled.
+        """
+        sig = inspect.signature(check)
+
+        if len(list(sig.parameters.values())) != 1:
+            raise TypeError("PassManager check function should only take in one variable, a module")
+
+        setattr(self, "check", check)  # noqa: B010
+
+    def check(self, module: nn.Module) -> None:
+        pass
+
+    def __call__(self, module: nn.Module) -> PassResult:
+        """
+        Runs a list of passes in the order based on `self.passes` on the given
+        graph module. Each time a pass is run, checks and linting will be run on
+        the graph module if `run_checks_after_each_pass` is set.
+
+        If the module is a graph module, we will run the list of passes until
+        the graph stops changing, or until `steps` number of times.
+        """
+        # Order the passes based on the constraints
+        if not self._validated:
+            self.solve_constraints()
+
+        # Check graph invariants
+        self.check(module)
+
+        # Run the set of passes `steps` number of times or until the graph stops
+        # changing
+        overall_modified = False
+        for _ in range(self.steps):
+            modified = False
+
+            # Run the set of passes on the graph module
+            for i, fn in enumerate(self.passes):
+                fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
+                logger.debug("Running pass '%s'", fn_name)
+
+                try:
+                    res = fn(module)
+
+                    if not isinstance(res, PassResult) and not hasattr(
+                        res, "graph_module"
+                    ):
+                        raise TypeError(
+                            f"The result of the pass {fn_name} should be type PassResult. "
+                            + "Please wrap it with pass_result_wrapper()."
+                        )
+                    module = res.graph_module
+                    modified = modified or res.modified
+
+                    if isinstance(module, GraphModule):
+                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
+                        module.recompile()
+
+                    # Check graph invariants
+                    if self.run_checks_after_each_pass:
+                        self.check(module)
+
+                except Exception as e:
+                    prev_pass_names = [
+                        p.__name__ if inspect.isfunction(p) else type(p).__name__
+                        for p in self.passes[:i]
+                    ]
+                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
+                    raise Exception(msg) from e
+
+            # If the graph no longer changes, then we can stop running these passes
+            overall_modified = overall_modified or modified
+            if not modified:
+                break
+
+        return PassResult(module, overall_modified)
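+
+
+# Editor's illustrative sketch (hedged, not part of the upstream API): builds
+# a PassManager over two toy passes with an ordering constraint and runs it on
+# a graph module. `_noop_pass` and `_dce_pass` are hypothetical.
+def _example_run_pass_manager(graph_module: GraphModule) -> PassResult:
+    def _noop_pass(module):
+        # A pass that changes nothing but still reports a PassResult.
+        return PassResult(module, False)
+
+    def _dce_pass(module):
+        changed = module.graph.eliminate_dead_code()
+        module.recompile()
+        return PassResult(module, changed)
+
+    pm = PassManager(
+        passes=[_dce_pass, _noop_pass],
+        # Require that the no-op pass runs before dead-code elimination; the
+        # manager topologically sorts the passes to satisfy this.
+        constraints=[this_before_that_pass_constraint(_noop_pass, _dce_pass)],
+        steps=2,
+    )
+    return pm(graph_module)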
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/net_min_base.py b/MLPY/Lib/site-packages/torch/fx/passes/net_min_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2d6c1fb0d88a072480a3b2f74032d3bd4869999
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/net_min_base.py
@@ -0,0 +1,731 @@
+import logging
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.fx
+
+from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
+
+from .shape_prop import ShapeProp
+from .split_utils import split_by_tags
+from .tools_common import (
+    CALLABLE_NODE_OPS,
+    FxNetAccFusionsFinder,
+    Names,
+    NodeList,
+    NodeSet,
+    TensorOrTensors,
+    Tensors,
+)
+
+__all__ = [
+    "FxNetMinimizerBadModuleError",
+    "FxNetMinimizerRunFuncError",
+    "FxNetMinimizerResultMismatchError",
+]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerBadModuleError(Exception):
+    """
+    Raised if we fail to split out a minimize module.
+    """
+
+    pass
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerRunFuncError(Exception):
+    """
+    Raised if an error occurs in the run_a or run_b functions.
+    """
+
+    pass
+
+
+@compatibility(is_backward_compatible=False)
+class FxNetMinimizerResultMismatchError(Exception):
+    """
+    Raised if the compare function reports that the results mismatch.
+    """
+
+    pass
+
+
+@dataclass
+class _MinimizerSettingBase:
+    """
+    Args:
+    `accumulate_error`: Instead of using run_a's inputs for both submodules being
+    compared, use each submodule's previous outputs as inputs, so that errors
+    accumulate across submodules.
+
+    `traverse_method`: "sequential", "binary" or "accumulate".
+    Determines how the nodes in the FX module are traversed.
+
+    `find_all`: Minimizer will go through the entire model and return all problematic nodes.
+
+    `return_intermediate`: If True, when the model is run via the `run_nodes()`
+    function, the intermediate results of all the ops are returned as output.
+    """
+
+    accumulate_error: bool = False
+    traverse_method: str = "sequential"
+    find_all: bool = False
+    return_intermediate: bool = False
+
+    def __str__(self):
+        settings_str = "FX Minimizer Settings:\n"
+
+        for k, v in vars(self).items():
+            settings_str += f"\t{k}: {v}\n"
+
+        return settings_str
+
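+
+# Editor's illustrative sketch (hedged): configuring minimizer settings for a
+# concrete minimizer (e.g. a subclass of the _MinimizerBase class defined
+# below). The particular values are arbitrary and purely for illustration.
+def _example_minimizer_settings() -> _MinimizerSettingBase:
+    return _MinimizerSettingBase(
+        traverse_method="binary",  # binary-search traversal instead of sequential
+        find_all=True,             # collect every culprit node, not just the first
+    )
+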
+
+class _MinimizerBase:
+    """
+    This class is used to automatically find problematic nodes in a model. It takes an FX
+    GraphModule and generates submodules while traversing the graph. Two functions,
+    `run_a` and `run_b`, are then used to run the same submodule, and a function `compare_fn`
+    is used to compare the results.
+
+    Currently we provide two ways to traverse the graph and generate submodules:
+        1. Sequential traversal: traverse the graph node by node and generate
+           one submodule per single node.
+        2. Binary search: do a binary-search-style traversal on the graph.
+
+    For internal users, a guide can be found here: https://fb.quip.com/HDtuAgiKGfkP.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Tensors,
+        compare_fn: Callable[
+            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
+        ],
+        settings: _MinimizerSettingBase,
+        module_exporter: Optional[
+            Callable[
+                [List[torch.Tensor], torch.fx.GraphModule, str],
+                None
+            ]
+        ] = None,
+    ):
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        self.sample_input = sample_input
+        self.compare_fn = compare_fn
+        self.module_exporter = module_exporter
+        self.settings = settings
+
+        # Stores outputs of run_a function
+        self.a_outputs: Dict[str, Any] = {}
+
+        # Stores outputs of run_b function
+        self.b_outputs: Dict[str, Any] = {}
+
+        # Stores the results of compare_fn
+        self.results: Dict[Any, Any] = {}
+
+        # Stores the report for the runs
+        self.reports: List[List[str]] = []
+
+        # Current iteration
+        self.iteration: int = 0
+
+        callable_nodes = {
+            node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
+        }
+        ShapeProp(self.module).propagate(*self.sample_input)
+        self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
+
+        # Check that the number of inputs in sample_input matches the number of placeholders
+        placeholders = [
+            node.name for node in self.module.graph.nodes if node.op == "placeholder"
+        ]
+        assert len(placeholders) == len(self.sample_input)
+
+        # Store sample_input
+        for i, name in enumerate(placeholders):
+            self.a_outputs[name] = sample_input[i]
+            self.b_outputs[name] = sample_input[i]
+
+    def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
+        """
+        Run `mod` with `inputs` and generate output. The output will be compared with
+        output of run_b().
+        """
+        raise RuntimeError("run_a() is not implemented.")
+
+    def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
+        """
+        Run `mod` with `inputs` and generate output. The output will be compared with
+        output of run_a().
+        """
+        raise RuntimeError("run_b() is not implemented.")
+
+    def _store_outputs(
+        self,
+        a_result: TensorOrTensors,
+        b_result: TensorOrTensors,
+        submodule: torch.fx.GraphModule,
+    ):
+        """
+        Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
+        self.b_outputs, so that we can use them when executing subsequent nodes that
+        take those outputs as inputs.
+
+        Args:
+            a_result: Output of self.run_a(). Could be a tensor or tensors.
+            b_result: Output of self.run_b(). Could be a tensor or tensors.
+            submodule: The module that generates a_result and b_result.
+        """
+        output_node = next(
+            node for node in submodule.graph.nodes if node.op == "output"
+        )
+
+        # Only one output
+        if isinstance(output_node.args[0], torch.fx.Node):
+            self.a_outputs[output_node.args[0].name] = a_result
+            self.b_outputs[output_node.args[0].name] = b_result
+        # Multiple outputs
+        else:
+            for i, arg in enumerate(output_node.args[0]):
+                self.a_outputs[arg.name] = a_result[i]
+                self.b_outputs[arg.name] = b_result[i]
+
+    def _get_submod_inputs(
+        self, main_module: torch.fx.GraphModule, submod_path: str
+    ) -> Tuple[Tensors, Tensors]:
+        """
+        Try to get submodule inputs from stored outputs. If they are not found,
+        capture the inputs with a forward pre-hook on the submodule.
+
+        If accumulate_error is False, use a_input for run_a() and run_b()
+        otherwise use a_input for run_a and b_input for run_b.
+
+        Args:
+            main_module: Top-level fx module.
+            submod_path: Path to the submodule we want to run and compare results.
+
+        Returns:
+            a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
+            b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
+        """
+        a_input = []
+        b_input = []
+        submodule = getattr(main_module, submod_path)
+        placeholders = [
+            node.name for node in submodule.graph.nodes if node.op == "placeholder"
+        ]
+
+        # If all placeholders can be found in the stored outputs, use the stored
+        # outputs as inputs. Otherwise, capture the inputs with a forward
+        # pre-hook on the submodule.
+        if set(placeholders) <= self.a_outputs.keys():
+            for name in placeholders:
+                a_input.append(self.a_outputs[name])
+                b_input.append(self.b_outputs[name])
+        else:
+            if self.settings.accumulate_error:
+                print(f"Can't find previous stored outputs named {placeholders}!")
+
+            def get_inputs(self: torch.nn.Module, inputs: Any):
+                nonlocal a_input
+                a_input = inputs
+
+            # Use a forward pre-hook to capture the inputs to the submodule
+            handle = submodule.register_forward_pre_hook(get_inputs)
+            main_module(*self.sample_input)
+            handle.remove()
+
+            b_input = a_input
+
+        if not self.settings.accumulate_error:
+            return a_input, a_input
+
+        return a_input, b_input
+
+    def _tag_nodes(self, selected_nodes: NodeSet):
+        """
+        Tag selected nodes with the tag "minimize". Nodes with the same tag will
+        be split into the same submodule afterwards.
+
+        Args:
+            selected_nodes: Nodes that we want to minimize. We will tag those nodes
+                with "minimize", all preceding nodes with "main_0" and all following
+                nodes with "main_1".
+        """
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            if node in selected_nodes:
+                node.tag = "minimize"
+            elif any(
+                n.tag in {"minimize", "main_1"}
+                for n in node.all_input_nodes
+                if n.op in CALLABLE_NODE_OPS
+            ):
+                node.tag = "main_1"
+            else:
+                node.tag = "main_0"
+
+    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
+        """
+        Split self.module so that one submodule consists of `nodes` and only `nodes`.
+
+        Args:
+            nodes: Nodes that we want to include in the minimize submodule.
+
+        Returns:
+            split_module (torch.fx.GraphModule): the module after split.
+            submodule_name (str): the name of the submodule that consists of `nodes`.
+        """
+        # Color provided nodes
+        self._tag_nodes(nodes)
+
+        # Split module based on coloring
+        split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
+
+        # Find submodule containing colored nodes
+        submodule_name: str = ""
+        for child_name, _ in split_module.named_children():
+            # Skip submodules we're not interested in at the moment
+            if "minimize" not in child_name:
+                continue
+
+            if submodule_name == "":
+                submodule_name = child_name
+            else:
+                raise FxNetMinimizerBadModuleError(
+                    f"Expected only one minimize submodule with nodes {nodes}"
+                )
+
+        if submodule_name == "":
+            raise FxNetMinimizerBadModuleError(
+                f"Minimize submodule was not found with nodes {nodes}"
+            )
+
+        return split_module, submodule_name
+
+    def _run_and_compare(
+        self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
+    ):
+        """
+        Run the submodule in `split_module` that has name `submod_name`
+        using `self.run_a` and `self.run_b` and compare their results.
+
+        Args:
+            split_module: Main module that contains the minimize submodule.
+            submod_name: Name of the minimize submodule.
+            output_names: Names of the nodes we want to output. If None, we
+                will use the original output.
+        """
+        submodule = getattr(split_module, submod_name)
+        a_input, b_input = self._get_submod_inputs(split_module, submod_name)
+
+        if len(self.reports) == 0:
+            self.reports.append([])
+            self.iteration = 1
+
+        report = self.reports[self.iteration - 1]
+        report.append("Run and compare ...")
+
+        if output_names:
+            output_nodes: NodeList = []
+            for node in submodule.graph.nodes:
+                if node.op == "output":
+                    submodule.graph.erase_node(node)
+
+                if node.name in output_names:
+                    output_nodes.append(node)
+
+            submodule.graph.output(
+                output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
+            )
+            submodule.graph.lint()
+            submodule.recompile()
+
+        # Use name of args in output node as key to store comparison result
+        for node in submodule.graph.nodes:
+            if node.op == "output":
+                result_key = map_arg(node.args, lambda x: x.name)
+
+        try:
+            a_result = self.run_a(submodule, a_input)
+            b_result = self.run_b(submodule, b_input)
+            self._store_outputs(a_result, b_result, submodule)
+        except Exception as e:
+            report.append(f"Exception raised when running {submod_name}: {e}")
+            raise FxNetMinimizerRunFuncError(  # noqa: TRY200
+                f"Exception raised when running {submod_name}: {e}"
+            )
+
+        # Compare results
+        names: Names = output_names
+        if output_names is None:
+            names = [str(v) for v in result_key]  # type: ignore[possibly-undefined]
+
+        numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
+
+        self.results[result_key] = numeric_result  # type: ignore[possibly-undefined]
+        report.append(f"Numerical accuracy = {numeric_result}")
+        if not bool_result:
+            report.append(f"Result mismatch for {result_key}")
+            if self.module_exporter:
+                self.module_exporter(
+                    List[torch.Tensor](a_input), submodule, str(result_key[0]) + "_cpu",
+                )
+                self.module_exporter(
+                    List[torch.Tensor](b_input), submodule, str(result_key[0]) + "_acc",
+                )
+            raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
+
+    def _binary_search_impl(
+        self, all_nodes: NodeList, start_idx: int, end_idx: int
+    ) -> NodeSet:
+        """
+        Recursive binary search implementation.
+        """
+        nodes: NodeList = all_nodes[start_idx:end_idx]
+
+        report: List[str] = []
+        self.reports.append(report)
+        self.iteration += 1
+        report.append(f"Binary search iteration {self.iteration}.")
+        report.append(
+            f"From node index {start_idx} to {end_idx-1}. "
+            f"Size of the interested node list is {len(nodes)}"
+        )
+
+        cur_nodes: NodeSet = set(nodes)
+
+        for node in nodes:
+            if node in self.fusions:
+                cur_nodes.update(self.fusions[node])
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [])
+        except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
+
+            if len(nodes) == 1:
+                report.append(
+                    f"This is the last node in the sub-module. "
+                    f"Search in the current branch is successful with culprit = {cur_nodes}."
+                )
+                self.print_report(report)
+                return cur_nodes
+
+            report.append(
+                "Proceed to split and lower the halves of the current "
+                "sub-module individually."
+            )
+            self.print_report(report)
+
+            mid = len(nodes) // 2
+            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
+
+            if len(culprits) != 0 and not self.settings.find_all:
+                return culprits
+
+            culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
+
+            if len(culprits) == 0:
+                report.append(
+                    f"Further split and lowering found no errors. "
+                    f"Unable to minimize the submodule with list of nodes: {nodes}"
+                )
+                self.print_report(report)
+
+            return culprits
+        else:
+            report.append("No discrepancy found.")
+            self.print_report(report)
+            return set()
+
+    def _binary_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        Binary search on `nodes` for culprit.
+        """
+        return self._binary_search_impl(nodes, 0, len(nodes))
+
+    def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        Traverse `nodes` one by one and determine if any of them is a culprit.
+        """
+        culprits: NodeSet = set()
+
+        for node in nodes:
+            report: List[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Sequential traverse iteration {self.iteration}.")
+            report.append(f"Visit node: {node.name}")
+
+            _LOGGER.info("Visit node: %s", node.name)
+            cur_nodes: NodeSet = {node}
+
+            if node in self.fusions:
+                cur_nodes = self.fusions[node]
+
+            try:
+                split_module, submod_name = self._build_submodule(cur_nodes)
+                self._run_and_compare(split_module, submod_name, [node.name])
+                self.print_report(report)
+            except (FxNetMinimizerResultMismatchError):
+                culprits.add(node)
+                report.append(f"Found culprit from numeric error: {node}")
+                self.print_report(report)
+                if not self.settings.find_all:
+                    return culprits
+            except (FxNetMinimizerRunFuncError):
+                culprits.update(cur_nodes)
+                report.append(f"Found culprit from run error: {node}")
+                self.print_report(report)
+                if not self.settings.find_all:
+                    return culprits
+
+        return culprits
+
+    def _defined_traverse(self, nodes: NodeList) -> NodeSet:
+        """
+        Run the user-defined `nodes` and determine whether they contain a culprit.
+        """
+        culprits: NodeSet = set()
+
+        first_node_name = nodes[0].name
+        output_node_name = nodes[-1].name
+        report = [f"Defined graph from {first_node_name} to {output_node_name}"]
+        cur_nodes: NodeSet = set(nodes)
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [output_node_name])
+            self.print_report(report)
+        except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
+            report.append(f"Found culprit {cur_nodes}")
+            self.print_report(report)
+            return culprits
+
+        return culprits
+
+    def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
+        culprits: NodeSet = set()
+        nodes_to_run: NodeSet = set()
+
+        # find_all is not supported for accumulate traversal because all the
+        # ops run on NNPI. So we return after the first op that raises error.
+        if self.settings.find_all:
+            print("'Find All' mode is not supported in accumulate traversal.")
+            return culprits
+
+        for node in nodes:
+            report: List[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Accumulate traverse iteration {self.iteration}.")
+
+            nodes_to_run.add(node)
+
+            node_name = node.name
+            if node_name is not None and isinstance(node_name, tuple):
+                node_name = node_name[0]
+            assert node_name is not None and isinstance(
+                node_name, str
+            ), f"minimize: node_name: {node_name}"
+
+            report.append(f"Add node: {node_name}")
+
+            try:
+                split_module, submod_name = self._build_submodule(nodes_to_run)
+                self._run_and_compare(split_module, submod_name, [node_name])
+                self.print_report(report)
+            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
+                culprits.add(node)
+                report.append(f"Found culprit {node}")
+                self.print_report(report)
+                return culprits
+
+        return culprits
+
+    def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
+        """
+        Skip certain nodes in graph based on settings
+        """
+        culprits: NodeSet = set()
+        nodes: NodeList = all_nodes[start_idx:end_idx]
+
+        report: List[str] = []
+        self.reports.append(report)
+        self.iteration += 1
+        report.append(f" Nodes block {self.iteration}.")
+        report.append(
+            f"From node index {start_idx} to {end_idx-1}. "
+            f"Size of the interested node list is {len(nodes)}"
+        )
+
+        cur_nodes: NodeSet = set(nodes)
+
+        for node in nodes:
+            if node in self.fusions:
+                cur_nodes.update(self.fusions[node])
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, [])
+        except (FxNetMinimizerResultMismatchError):
+            culprits.update(cur_nodes)
+            report.append(f"Found culprit from numeric error: {cur_nodes}")
+            self.print_report(report)
+            return culprits
+        except (FxNetMinimizerRunFuncError):
+            culprits.update(cur_nodes)
+            report.append(f"Found culprit from run error: {node}")
+            self.print_report(report)
+            return culprits
+        else:
+            report.append("No discrepancy found.")
+            self.print_report(report)
+            return set()
+
+
+    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
+        """
+        Skip certain nodes in graph based on settings
+        """
+        start_idx = 0
+        num_nodes = len(all_nodes)
+        idx = 0
+        culprits = set()
+        while idx < num_nodes:
+            node = all_nodes[idx]
+            if (node.name in skip_nodes):  # skip the node
+                if idx > start_idx:
+                    culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
+                start_idx = idx + 1
+            elif idx == num_nodes - 1 and start_idx <= idx:  # last node
+                culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
+            idx += 1
+
+        return culprits
+
+
+
+    def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
+        """
+        Collect nodes in the model between the nodes named `start` and `end`.
+        These two nodes are also included.
+        """
+        nodes: NodeList = []
+        add_node = start is None
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            if node.name == start:
+                add_node = True
+
+            if add_node:
+                nodes.append(node)
+
+            if node.name == end:
+                break
+
+        return nodes
+
+    def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
+        """
+        Run part of the model from `start` node to `end` node. If `start` is None
+        then we start from the beginning of the model. If `end` is None then we
+        stop at the end of the model.
+
+        Args:
+            start: The name of the node which is the first node of the submodule
+                we want to run. If set to None, then we'll start with the first
+                node of the model.
+            end: The name of the node which is the last node of the submodule we
+                want to run. If set to None, we'll end with the last node of the
+                model.
+        """
+        nodes = self._collect_nodes(start, end)
+        cur_nodes = set(nodes)
+
+        for node in nodes:
+            if node in self.fusions:
+                cur_nodes.update(self.fusions[node])
+
+        output_names = []
+        if self.settings.return_intermediate:
+            output_names = [node.name for node in nodes]
+
+        try:
+            split_module, submod_name = self._build_submodule(cur_nodes)
+            self._run_and_compare(split_module, submod_name, output_names)
+        except (
+            FxNetMinimizerRunFuncError,
+            FxNetMinimizerResultMismatchError,
+        ) as e:
+            print(e)
+
+    def print_report(self, report: List[str]):
+        for i in range(len(report)):
+            if i > 0:
+                print(" . " + report[i])
+            else:
+                print(report[i])
+
+    def print_reports(self):
+        for report in self.reports:
+            self.print_report(report)
+
+    def minimize(
+        self, start: Optional[str] = None, end: Optional[str] = None, skip_nodes: Optional[List] = None,
+    ) -> NodeSet:
+        """
+        Minimize the model from the node named `start` to the node named `end`, based
+        on self.settings. Find culprits that cause FxNetMinimizerRunFuncError or
+        FxNetMinimizerResultMismatchError errors.
+
+        Args:
+            start: The name of the node where we want to start minimizing. If set
+                to None, then we'll start with the first node of the model.
+            end: The name of the node where we want to terminate minimizing. If
+                set to None, we'll end with the last node of the model.
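+            skip_nodes: Only used when traverse_method is "skip". Names of nodes
+                whose positions split the graph into the blocks that get checked.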
+
+        Returns:
+            nodes: The set of nodes that cause FxNetMinimizerRunFuncError or
+                FxNetMinimizerResultMismatchError errors during minimizing.
+        """
+
+        print(self.settings)
+        print(self.module.graph)
+
+        nodes = self._collect_nodes(start, end)
+
+        if self.settings.traverse_method == "sequential":
+            return self._sequential_traverse(nodes)
+
+        if self.settings.traverse_method == "binary":
+            return self._binary_traverse(nodes)
+
+        if self.settings.traverse_method == "accumulate":
+            return self._accumulate_traverse(nodes)
+
+        if self.settings.traverse_method == "skip":
+            if (skip_nodes is None):
+                raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
+            return self._skip_traverse(nodes, skip_nodes)
+
+        if self.settings.traverse_method == "defined":
+            return self._defined_traverse(nodes)
+
+        raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/operator_support.py b/MLPY/Lib/site-packages/torch/fx/passes/operator_support.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac5df5ec226c7c3fead84ebb7e8f22886ef027bc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/operator_support.py
@@ -0,0 +1,217 @@
+import abc
+import typing as t
+
+import torch
+import torch.fx
+from torch.fx._compatibility import compatibility
+from .shape_prop import TensorMetadata
+from .tools_common import get_node_target, CALLABLE_NODE_OPS
+
+
+__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
+
+# fx.Node.target typename, as returned by `get_node_target()`
+TargetTypeName = str
+
+# Arguments' dtypes for a given node, see `OperatorSupport`
+SupportedArgumentDTypes = t.Optional[
+    t.Tuple[
+        t.Sequence[t.Sequence[torch.dtype]],
+        t.Dict[str, t.Sequence[torch.dtype]],
+    ]
+]
+
+SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
+
+
+@compatibility(is_backward_compatible=False)
+class OperatorSupportBase(abc.ABC):
+    """Interface for determining if a fx.Node is supported by a backend"""
+    @abc.abstractmethod
+    def is_node_supported(
+        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        raise NotImplementedError()
+
+
+@compatibility(is_backward_compatible=False)
+class OperatorSupport(OperatorSupportBase):
+    """
+    `_support_dict` maps node.target typename to supported inputs dtypes.
+
+    node.target typename is retrieved using helper function `get_node_target()`
+
+    If the supported input dtypes entry is None, any dtype is supported; otherwise
+    we expect a tuple like (([dtypes], ...), {"name": [dtypes], ...}).
+
+    The first tuple ([dtypes], ...) indicates what dtypes are supported for
+    inputs in node.args, and the second dict {"name": [dtypes], ...} indicates
+    what dtypes are supported for inputs in node.kwargs.
+
+    For inputs in args, put None where we don't want to check the dtype,
+    e.g. (None, [torch.float]) means we don't care about the dtype of the
+    first input in args. Inputs in kwargs that are not listed will not be
+    checked.
+    """
+
+    _support_dict: SupportDict
+
+    def __init__(
+        self,
+        support_dict: t.Optional[SupportDict] = None
+    ):
+        self._support_dict = support_dict or {}
+
+    def is_node_supported(
+        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+    ) -> bool:
+        """
+        Args:
+            `submodules`: mapping from module name to the module. This can be
+                          retrieved by calling model.named_modules().
+
+            `node`: a Fx node that we want to determine whether it's supported.
+
+        Returns:
+            `is_supported`: whether the arg `node` is supported.
+        """
+        if node.op not in CALLABLE_NODE_OPS:
+            return True
+
+        target = get_node_target(submodules, node)
+
+        # Target not found in _support_dict meaning that we don't support this op at all
+        if target not in self._support_dict:
+            return False
+
+        # The rule for target is None meaning that we accept any dtype
+        if self._support_dict[target] is None:
+            return True
+
+        args_dtypes, kwargs_dtypes = self._support_dict[target]  # type: ignore[misc]
+
+        # Check args dtypes
+        for i, dtypes in enumerate(args_dtypes):
+            if len(node.args) <= i:
+                break
+
+            # None indicates we don't care about the dtype of args[i]
+            if dtypes is None:
+                continue
+
+            # If arg is not a node then we don't check it
+            if not isinstance(node.args[i], torch.fx.Node):
+                continue
+
+            arg_dtype = _get_arg_dtype(node.args[i])  # type: ignore[arg-type]
+            if arg_dtype not in dtypes:
+                return False
+
+        # Check kwargs dtypes
+        for k, dtypes in kwargs_dtypes.items():
+            if k not in node.kwargs:
+                continue
+
+            # If arg is not a node then we don't check it
+            if not isinstance(node.kwargs[k], torch.fx.Node):
+                continue
+
+            kwarg_dtype = _get_arg_dtype(node.kwargs[k])  # type: ignore[arg-type]
+            if kwarg_dtype not in dtypes:
+                return False
+
+        return True
+
+
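+# A minimal sketch of a support dict; the target names and dtype constraints below
+# are illustrative only:
+#
+#     support_dict = {
+#         "torch.relu": None,                                    # any dtype accepted
+#         "torch.add": (([torch.float16, torch.float32],), {}),  # constrain args[0] only
+#     }
+#     op_support = OperatorSupport(support_dict)
+#     # op_support.is_node_supported(dict(gm.named_modules()), node)
+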
+# ======================================================================
+# Functional interfaces and utils for defining basic operator support logic
+# and composing them into more complex ones
+# ======================================================================
+
+IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
+
+
+@compatibility(is_backward_compatible=False)
+def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
+    """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
+
+    `IsNodeSupported` has the same call signature as
+    `OperatorSupportBase.is_node_supported`
+    """
+    class FunctionalOperatorSupport(OperatorSupportBase):
+        def is_node_supported(
+                self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
+        ) -> bool:
+            return is_node_supported(submodules, node)
+    return FunctionalOperatorSupport()
+
+
+@compatibility(is_backward_compatible=False)
+def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
+    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
+    instance by evaluating each input `OperatorSupportBase` instance, and returns False if
+    any of it reports False.
+    """
+    def _chain(submods, node) -> bool:
+        return all(
+            x.is_node_supported(submods, node)
+            for x in op_support
+        )
+    return create_op_support(_chain)
+
+
+@compatibility(is_backward_compatible=False)
+def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
+    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
+    instance by evaluating each input `OperatorSupportBase` instance, and returns True if
+    any of them reports True.
+    """
+    def _any_chain(submods, node) -> bool:
+        return any(
+            x.is_node_supported(submods, node)
+            for x in op_support
+        )
+    return create_op_support(_any_chain)
+
+
+@compatibility(is_backward_compatible=False)
+class OpSupports:
+    """A set of atomic `OperatorSupportBase` instances that can be combined together
+    to form more complex operator support logic.
+    """
+    @classmethod
+    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
+        """Report a node as non-supported, if any of its arguments is of dtype"""
+
+        def _decline_if_input_dtype(
+            submodules: t.Mapping[str, torch.nn.Module],
+            node: torch.fx.Node,
+        ) -> bool:
+            for arg in node.all_input_nodes:
+                arg_dtype = _get_arg_dtype(arg)
+                if arg_dtype == dtype:
+                    return False
+            return True
+        return create_op_support(_decline_if_input_dtype)
+
+    @classmethod
+    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
+        """
+        If a node has a name that is in the disallow set, report it as non-supported.
+        """
+        def _decline_if_node_in_names(
+            submodules: t.Mapping[str, torch.nn.Module],
+            node: torch.fx.Node,
+        ) -> bool:
+            if node.name in disallow_set:
+                return False
+            else:
+                return True
+        return create_op_support(_decline_if_node_in_names)
+
+
+def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
+    assert isinstance(arg, torch.fx.Node)
+    tensor_meta = arg.meta.get("tensor_meta")  # type: ignore[union-attr]
+    dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
+    return dtype
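+
+
+# A small sketch of composing support rules; `support_dict` here is an illustrative
+# placeholder (see `OperatorSupport` above):
+#
+#     no_fp16_inputs = OpSupports.decline_if_input_dtype(torch.float16)
+#     backend_support = chain(OperatorSupport(support_dict), no_fp16_inputs)
+#     # `any_chain(...)` would instead accept a node if *any* rule accepts it.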
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/param_fetch.py b/MLPY/Lib/site-packages/torch/fx/passes/param_fetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ec7305d8191862d31257c6d29ecf3863873e19
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/param_fetch.py
@@ -0,0 +1,66 @@
+from torch.fx.graph_module import GraphModule
+from typing import Any, Callable, Dict, List, Tuple, Type
+import torch
+import torch.nn as nn
+
+from torch.fx._compatibility import compatibility
+
+__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
+
+# A matching method maps the attribute name of the current version to the attribute name of `target_version`.
+@compatibility(is_backward_compatible=False)
+def default_matching(name: str, target_version: int) -> str:
+    """Default matching method
+    """
+    return name
+
+# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
+# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
+# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
+module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
+    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
+    torch.nn.modules.conv.Conv2d: (
+        1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
+    ),
+    torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
+    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
+    torch.nn.modules.pooling.MaxPool2d: (
+        1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
+    ),
+    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
+}
+
+@compatibility(is_backward_compatible=False)
+def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
+    """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
+    after checking module's version is compatible with the `module_fetch_book`.
+    """
+    attrs_for_lowering: Dict[str, Any] = {}
+    attrs_for_lowering["name"] = torch.typename(mod)
+
+    if type(mod) in module_fetch_book:
+        version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
+        if version < mod._version:
+            raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
+                               "please upgrade the module_fetch_book, open an issue and @842974287 "
+                               "or report a bug to AIACC team directly.")
+        for attr in param_to_fetch:
+            attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
+    else:
+        raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
+                           "please add it to the module_fetch_book, open an issue and @842974287 "
+                           "or report a bug to AIACC team directly.")
+    return attrs_for_lowering
+
+@compatibility(is_backward_compatible=False)
+def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
+    """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
+    """
+    submodules = dict(fx_module.named_modules())
+
+    for node in fx_module.graph.nodes:
+        if node.op == "call_module":
+            if isinstance(submodules[node.target], GraphModule):
+                lift_lowering_attrs_to_nodes(submodules[node.target])
+            else:
+                node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
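+
+
+# A minimal usage sketch; `my_model` is an illustrative placeholder:
+#
+#     gm = torch.fx.symbolic_trace(my_model)
+#     lift_lowering_attrs_to_nodes(gm)
+#     # Each leaf `call_module` node now carries `node.attrs_for_lowering`, e.g.
+#     # {"name": "torch.nn.modules.linear.Linear", "weight": ..., "bias": ...}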
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/pass_manager.py b/MLPY/Lib/site-packages/torch/fx/passes/pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf93182d37fbef674f79580c4702fecdf640a26
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/pass_manager.py
@@ -0,0 +1,257 @@
+from functools import wraps
+from inspect import unwrap
+from typing import Callable, List, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "PassManager",
+    "inplace_wrapper",
+    "log_hook",
+    "loop_pass",
+    "this_before_that_pass_constraint",
+    "these_before_those_pass_constraint",
+]
+
+# For callables which modify an object in place and return something other than
+# the object on which they act
+def inplace_wrapper(fn: Callable) -> Callable:
+    """
+    Convenience wrapper for passes which modify an object inplace. This
+    wrapper makes them return the modified object instead.
+
+    Args:
+        fn (Callable[Object, Any])
+
+    Returns:
+        wrapped_fn (Callable[Object, Object])
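+
+    For example (`my_pass` below is a toy stand-in that mutates a dict and
+    returns a bool):
+
+    ```
+    def my_pass(d: Dict) -> bool:
+        d['foo'] = 'bar'
+        return True
+
+    pm = PassManager(passes=[inplace_wrapper(my_pass)])
+    # pm(d) returns `d` itself instead of the bool returned by my_pass
+    ```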
+    """
+
+    @wraps(fn)
+    def wrapped_fn(gm):
+        val = fn(gm)
+        return gm
+
+    return wrapped_fn
+
+def log_hook(fn: Callable, level=logging.INFO) -> Callable:
+    """
+    Logs callable output.
+
+    This is useful for logging output of passes. Note inplace_wrapper replaces
+    the pass output with the modified object. If we want to log the original
+    output, apply this wrapper before inplace_wrapper.
+
+
+    ```
+    def my_pass(d: Dict) -> bool:
+        changed = False
+        if 'foo' in d:
+            d['foo'] = 'bar'
+            changed = True
+        return changed
+
+    pm = PassManager(
+        passes=[
+            inplace_wrapper(log_hook(my_pass))
+        ]
+    )
+    ```
+
+    Args:
+        fn (Callable[Type1, Type2])
+        level: logging level (e.g. logging.INFO)
+
+    Returns:
+        wrapped_fn (Callable[Type1, Type2])
+    """
+    @wraps(fn)
+    def wrapped_fn(gm):
+        val = fn(gm)
+        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
+        return val
+
+    return wrapped_fn
+
+
+
+def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
+    """
+    Convenience wrapper for passes which need to be applied multiple times.
+
+    Exactly one of `n_iter` or `predicate` must be specified.
+
+    Args:
+        base_pass (Callable[Object, Object]): pass to be applied in loop
+        n_iter (int, optional): number of times to loop pass
+        predicate (Callable[Object, bool], optional):
+
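+    Example (`count_pass` below is a toy stand-in operating on a dict):
+
+    ```
+    def count_pass(d: Dict) -> Dict:
+        d['count'] = d.get('count', 0) + 1
+        return d
+
+    run_three_times = loop_pass(count_pass, n_iter=3)
+    run_until_ten = loop_pass(count_pass, predicate=lambda d: d.get('count', 0) < 10)
+    ```
+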
+    """
+    assert (n_iter is not None) ^ (
+        predicate is not None
+    ), "Exactly one of `n_iter`or `predicate` must be specified."
+
+    @wraps(base_pass)
+    def new_pass(source):
+        output = source
+        if n_iter is not None and n_iter > 0:
+            for _ in range(n_iter):
+                output = base_pass(output)
+        elif predicate is not None:
+            while predicate(output):
+                output = base_pass(output)
+        else:
+            raise RuntimeError(
+                f"loop_pass must be given positive int n_iter (given "
+                f"{n_iter}) xor predicate (given {predicate})"
+            )
+        return output
+
+    return new_pass
+
+
+# Pass Schedule Constraints:
+#
+# Implemented as 'depends on' operators. A constraint is satisfied iff a list
+# has a valid partial ordering according to this comparison operator.
+def _validate_pass_schedule_constraint(
+    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
+):
+    for i, a in enumerate(passes):
+        for j, b in enumerate(passes[i + 1 :], start=i + 1):
+            if constraint(a, b):
+                continue
+            raise RuntimeError(
+                f"pass schedule constraint violated. Expected {a} before {b}"
+                f" but found {a} at index {i} and {b} at index{j} in pass"
+                f" list."
+            )
+
+
+def this_before_that_pass_constraint(this: Callable, that: Callable):
+    """
+    Defines a partial order ('depends on' function) where `this` must occur
+    before `that`.
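+
+    Example (`pass_a` and `pass_b` are illustrative passes):
+
+    ```
+    passes = [pass_a, pass_b]
+    constraints = [this_before_that_pass_constraint(pass_a, pass_b)]
+    pm = PassManager(passes, constraints)
+    pm.validate()  # would raise if pass_b were scheduled before pass_a
+    ```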
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        if a == that and b == this:
+            return False
+        return True
+
+    return depends_on
+
+
+def these_before_those_pass_constraint(these: Callable, those: Callable):
+    """
+    Defines a partial order ('depends on' function) where `these` must occur
+    before `those`, where the inputs are 'unwrapped' before comparison.
+
+    For example, the following pass list and constraint list would be invalid.
+    ```
+    passes = [
+        loop_pass(pass_b, 3),
+        loop_pass(pass_a, 5),
+    ]
+
+    constraints = [
+        these_before_those_pass_constraint(pass_a, pass_b)
+    ]
+    ```
+
+    Args:
+        these (Callable): pass which should occur first
+        those (Callable): pass which should occur later
+
+    Returns:
+        depends_on (Callable[[Object, Object], bool])
+    """
+
+    def depends_on(a: Callable, b: Callable):
+        if unwrap(a) == those and unwrap(b) == these:
+            return False
+        return True
+
+    return depends_on
+
+
+class PassManager:
+    """
+    Construct a PassManager.
+
+    Collects passes and constraints. This defines the pass schedule, manages
+    pass constraints and pass execution.
+
+    Args:
+        passes (Optional[List[Callable]]): list of passes. A pass is a
+            callable which modifies an object and returns the modified object.
+        constraints (Optional[List[Callable]]): list of constraints. A
+            constraint is a callable which takes two passes (A, B) and returns
+            True if A depends on B and False otherwise. See implementation of
+            `this_before_that_pass_constraint` for example.
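+
+    Example (the passes and `graph_module` below are illustrative; each pass takes
+    and returns the object being transformed):
+
+    ```
+    pm = PassManager(passes=[pass_a, pass_b])
+    pm.add_constraint(this_before_that_pass_constraint(pass_a, pass_b))
+    transformed = pm(graph_module)
+    ```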
+    """
+
+    passes: List[Callable]
+    constraints: List[Callable]
+    _validated: bool = False
+
+    def __init__(
+        self,
+        passes=None,
+        constraints=None,
+    ):
+        self.passes = passes or []
+        self.constraints = constraints or []
+
+    @classmethod
+    def build_from_passlist(cls, passes):
+        pm = PassManager(passes)
+        # TODO(alexbeloi): add constraint management/validation
+        return pm
+
+    def add_pass(self, _pass: Callable):
+        self.passes.append(_pass)
+        self._validated = False
+
+    def add_constraint(self, constraint):
+        self.constraints.append(constraint)
+        self._validated = False
+
+    def remove_pass(self, _passes: List[str]):
+        if _passes is None:
+            return
+        passes_left = []
+        for ps in self.passes:
+            if ps.__name__ not in _passes:
+                passes_left.append(ps)
+        self.passes = passes_left
+        self._validated = False
+
+    def replace_pass(self, _target, _replacement):
+        passes_left = []
+        for ps in self.passes:
+            if ps.__name__ == _target.__name__:
+                passes_left.append(_replacement)
+            else:
+                passes_left.append(ps)
+        self.passes = passes_left
+        self._validated = False
+
+    def validate(self):
+        """
+        Validates that current pass schedule defined by `self.passes` is valid
+        according to all constraints in `self.constraints`
+        """
+        if self._validated:
+            return
+        for constraint in self.constraints:
+            _validate_pass_schedule_constraint(constraint, self.passes)
+        self._validated = True
+
+    def __call__(self, source):
+        self.validate()
+        out = source
+        for _pass in self.passes:
+            out = _pass(out)
+        return out
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/reinplace.py b/MLPY/Lib/site-packages/torch/fx/passes/reinplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f8bfe4eb2f77023755781a9e67cf99e25fe8117
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/reinplace.py
@@ -0,0 +1,675 @@
+import torch
+from torch.fx import Node
+from torch.fx._compatibility import compatibility
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+from torch.utils._pytree import tree_map_only
+from torch.utils import _pytree as pytree
+from torch.multiprocessing.reductions import StorageWeakRef
+
+import _operator
+from enum import Enum
+import itertools
+from typing import Set, Dict
+from collections import defaultdict
+
+__all__ = ['reinplace']
+
+class _ViewType(Enum):
+    NonView = 0
+    SingleOutputView = 1
+    MultiOutputView = 2
+
+def _is_view_op(tgt):
+    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
+        schema = tgt._schema
+        if len(schema.arguments) > 0:
+            first_arg = schema.arguments[0]
+            # check if op is a view
+            return first_arg.alias_info is not None and not first_arg.alias_info.is_write
+
+def _get_view_type(tgt) -> _ViewType:
+    if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
+        schema = tgt._schema
+        if len(schema.arguments) > 0:
+            first_arg = schema.arguments[0]
+            # check if op is a view
+            if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
+                # check if op is a multi-output view
+                if '*' in first_arg.alias_info.after_set:
+                    return _ViewType.MultiOutputView
+                else:
+                    return _ViewType.SingleOutputView
+    return _ViewType.NonView
+
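+# For example (based on the alias annotations in the ATen schemas):
+#   aten.diagonal.default -> _ViewType.SingleOutputView
+#   aten.split.Tensor     -> _ViewType.MultiOutputView
+#   aten.add.Tensor       -> _ViewType.NonView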
+
+# Stores a bunch of metadata related to functionalization on each node.
+# Relevant metadata:
+# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTensors instead of Tensors)
+#   The fake tensor output from running the current node
+# n.meta['view_of']: Node
+#   If the current node n is a view of some base tensor, the 'view_of' field tells us which
+#   view node was used to generate the current node (a view tensor).
+#   This information actually makes `fake_result` redundant, but we can use `fake_result`
+#   to sanity check that our aliasing information is correct.
+@compatibility(is_backward_compatible=False)
+class _FunctionalizationMetadataProp(torch.fx.Interpreter):
+
+    def run_node(self, node: Node):
+        self.node_counter += 1
+        result = super().run_node(node)
+        node.meta['fake_result'] = result
+        node.meta['node_idx'] = self.node_counter
+
+        # (1) Update metadata with the list of nodes that are used by this node
+        # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
+        # We don't want to treat it as "being used as an input".
+        node_args = node.args
+        if node.target is torch.ops.aten.copy_.default:
+            node_args = node_args[1:]
+
+        # (2) Update metadata to track aliasing information about view tensor nodes.
+        if node.op == 'call_function':
+            view_type = _get_view_type(node.target)
+            if view_type == _ViewType.SingleOutputView:
+                assert isinstance(node.args[0], Node)
+                node.meta['view_of'] = node.args[0]
+            elif view_type == _ViewType.MultiOutputView:
+                self.multi_output_view_nodes[node] = node.args[0]
+
+            # Check if we returned a multi-output view,
+            # and we're now grabbing the individual views from the output.
+            #
+            # For multi-output views, we want to map each output view to the base,
+            # but this mapping involves two separate nodes in FX IR.
+            # e.g. "a, b = x_1.split(...)" becomes:
+            #    %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
+            #    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
+            #    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
+            # And we'd like to set:
+            #    getitem1.meta['view_of'] = x_1
+            elif node.target is _operator.getitem:
+                list_arg = node.args[0]
+                maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
+                if maybe_base_of_view is not None:
+                    # Note: we could also track indexing info here for multi-output views.
+                    # I don't think this metadata is strictly needed for de-functionalization.
+                    assert isinstance(maybe_base_of_view, Node)
+                    node.meta['view_of'] = maybe_base_of_view
+
+        if 'view_of' in node.meta:
+            # We're linking the current node with its first argument as views.
+            # Assert here that this is actually the case, and their storages are the same.
+            assert isinstance(node.meta['fake_result'], FakeTensor)
+            assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
+            view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
+            base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
+            assert view_storage == base_storage
+        return result
+
+
+
+    def propagate(self, *args):
+        self.multi_output_view_nodes = {}
+        self.node_counter = -1
+
+        with FakeTensorMode() as mode:
+            fake_args = [mode.from_tensor(a) for a in args]
+            return super().run(*fake_args)
+
+def _schemas_match(functional_schema, inplace_schema):
+    names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
+    arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
+        a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
+    # for the inplace op, its first argument should be mutable
+    assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
+    # and its remaining arguments shouldn't be.
+    assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
+    return names_match and arg_types_match
+
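+# For example, `aten.add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)` matches
+# `aten.add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1)`: the names differ
+# only by the trailing underscore and the argument types line up pairwise.
+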
+# TODO: this should be beefed up to be able to properly re-inplace with:
+# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
+# - out= ops (e.g. angle -> angle.out)
+# TODO: we should also figure this info out using torchgen.
+def _maybe_get_inplace_op(op):
+    # __module__ seems broken; it returns torch._ops.aten which doesn't exist
+    if not isinstance(op, torch._ops.OpOverload):
+        return None
+    # Some view ops have inplace variants (as_strided_, etc),
+    # but we do NOT want the reinplacing pass to directly add these into the program.
+    # (they'll require extra special handling, and aren't really useful for perf anyway)
+    if _is_view_op(op):
+        return None
+    op_namespace = op.__module__.split(".")[-1]
+    op_base_name = op.overloadpacket.__name__
+    maybe_namespace_module = getattr(torch.ops, op_namespace)
+    maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
+    if maybe_inplace_op is None:
+        return None
+
+    inplace_overloads = [
+        getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
+    ]
+    inplace_overloads_with_matching_schemas = [
+        f
+        for f in inplace_overloads
+        if _schemas_match(op._schema, f._schema)
+    ]
+    # Just because foo() and foo_() are both existing operators,
+    # They aren't guaranteed to have compatible schemas.
+    # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
+    # Even though several overloads of pow_ exist.
+    if len(inplace_overloads_with_matching_schemas) == 0:
+        return None
+    assert len(inplace_overloads_with_matching_schemas) == 1
+    inplace_op = inplace_overloads_with_matching_schemas[0]
+    return inplace_op
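+
+# For example, _maybe_get_inplace_op(torch.ops.aten.add.Tensor) is expected to return
+# torch.ops.aten.add_.Tensor, while view ops such as torch.ops.aten.transpose.int
+# return None.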
+
+_VIEW_INVERSE_MAP = {
+    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
+    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
+    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
+    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
+}
+
+# This function, given a set of (aliased) tensor nodes,
+# returns any nodes in the graph that *use* any of the aliases and occur *after* op_index
+# in the node ordering.
+def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
+    def _add_if_tensor(x, set_):
+        if isinstance(x, FakeTensor):
+            set_.add(StorageWeakRef(x._typed_storage()))
+
+    nodes_used_after = set()
+    for t in tensor_aliases:
+        # get all nodes that use the current alias
+        usage_nodes = t.users
+        for n in usage_nodes:
+            # We only care about usages after the current node
+            if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
+                continue
+            # We also don't care about intermediate view ops.
+            # They only matter if their output is then used elsewhere
+            # (either in an out-of-place op, or as an output to the function).
+            if n in tensor_aliases:
+                if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
+                    continue
+            nodes_used_after.add(n)
+    return nodes_used_after
+
+# Given an op that we're trying to re-inplace, "b = foo(a)",
+# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
+# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF
+# there are any aliases in the alias_set(a) that satisfy:
+# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
+# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
+#     as "alias"
+def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
+    def matching_view_metadata(a, b):
+        return a.size() == b.size() and \
+            a.stride() == b.stride() and \
+            a.storage_offset() == b.storage_offset()
+
+    view_inverse_nodes = set()
+    # Go through them in node order, so we can see chains of view_scatter ops.
+    for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
+        if n.target not in _VIEW_INVERSE_MAP:
+            continue
+        base = n.args[0]
+        mutated_view = n.args[1]
+        assert isinstance(base, Node)
+        assert isinstance(base.meta['fake_result'], FakeTensor)
+        assert isinstance(mutated_view, Node)
+        assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
+        # Check that this view_inverse op actually corresponds to taking the inverse
+        # of one of our existing self_alias nodes.
+        original_view = _VIEW_INVERSE_MAP[n.target]
+        for self_alias in self_aliases:
+            # We're looking for some alias of the self arg, "alias",
+            # that was created from some op `alias = foo(base, args...)`
+            # such that the current _scatter op "inverts" that foo call.
+            # We can check that by running the original op again, and checking that the strides match.
+            if 'view_of' not in self_alias.meta:
+                continue
+            self_alias_base = self_alias.meta['view_of']
+            try:
+                # Here we're trying to re-use the args from the view_scatter call inside of the corresponding
+                # view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
+                # of the current alias we're looking at.
+                view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
+                expected_metadata = self_alias.meta['fake_result']
+                # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
+                if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
+                        matching_view_metadata(view_replay_metadata, expected_metadata):
+                    view_inverse_nodes.add(n)
+            except Exception:
+                continue
+
+    return view_inverse_nodes
+
+
+@compatibility(is_backward_compatible=True)
+def reinplace(gm, *sample_args):
+    """
+    Given an fx.GraphModule, modifies it to perform "reinplacing",
+    mutating the nodes of the graph.
+    We look for out-of-place op call sites like `b = a.add(...)`,
+    and convert them to be inplace (`b = a.add_(...)`),
+    as long as the input to the current operator ("a") isn't re-used
+    anywhere later in the graph.
+
+    This pass currently expects to operate on a **functional, ATen** graph.
+    This can be obtained by running `make_fx(functionalize(f))`.
+
+    Sample inputs are needed to determine aliasing relationships of the inputs.
+    In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
+    inputs to the program.
+
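+    A typical invocation looks like the following; `f` is an illustrative function,
+    with `make_fx` from torch.fx.experimental.proxy_tensor and `functionalize` from
+    torch.func:
+
+        def f(x):
+            return torch.add(x, x)
+
+        gm = make_fx(functionalize(f))(torch.ones(4))
+        reinplace(gm, torch.ones(4))
+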
+    Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
+
+    (1) Perform some initial checks on the metadata of "a" and "args..."
+        that can disqualify them from being reinplaced.
+
+      (1a) Check that the self argument we're attempting to reinplace
+           has acceptable dtype/size metadata to reinplace with.
+
+           For example, if we have:
+             a = torch.ones(1)
+             b = torch.ones(10)
+             out = torch.add(a, b)
+           We can't turn that into
+             a.add_(b)
+           Because that would require resizing "a".
+
+           Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
+           because that would require changing a's dtype (from e.g. float32 to bool).
+           Note that in this specific example, we could technically do better..
+
+           If we see the pattern:
+             a_1 = a.ge(b)
+             a_2 = aten._to_copy(a_1, a.dtype)
+           Then this should be valid to completely re-inplace
+           (this is exactly what functionalization will emit when it sees a.ge_(b)).
+
+           This optimization is only really important for user programs
+           that directly use inplace comparison ops though.
+
+           We also cannot re-inplace on tensors that have overlapping memory,
+           e.g. torch.ones(1).expand(4, 4).add_(1)
+
+      (1b) Check if "a" is an alias of any of the program inputs.
+
+          If it is, skip and move to the next node.
+          Inplace'ing an op that would cause it to mutate a program input is not sound,
+          because that would be a side effect visible to the user.
+
+          NOTE: there's a future optimization that we should make:
+          if "a" is a (alias of a)  program input, but later in the program
+          there is a node that looks like "a.copy_(...)",
+          Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
+          which will later be overwritten by the copy_() call.
+
+          This will be an important optimization to have for programs that mutate
+          their inputs. It currently isn't implemented though.
+
+      (1c) Check if "a" and "args..." alias
+
+          For example, re-inplacing to create code like the below
+          isn't guaranteed to be sound:
+
+            aten.mul_(a, a)
+
+    (2) Check that "a" and all of its outstanding aliases are not used anywhere
+        later in the graph. If this is the case, then it's safe to re-inplace
+        to "b = foo_(a)".
+
+        There are a few caveats to this, explained in more detail below:
+        (a) If "a" is used later as an argument to a view op, that is okay.
+            It's only a problem if "a" (or that view) is later passed
+            into a normal operator, or if it is returned as the program output.
+        (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
+            Most ATen kernels don't make any guarantees that this is sound,
+            e.g. if you do aten.mul_(a, a).
+            So we'll just ban re-inplacing in this case.
+        (c) If "a" is used as an input into a view "inverse" / "scatter"
+            operator, it is potentially fine to re-inplace
+            (and remove that scatter operator from the graph).
+            See below for a more detailed example.
+
+        NOTE: there is an optimization in this step that is crucial
+        to fully recovering performance from functionalization.
+
+        Given this program:
+        def f(x):
+            a = torch.ops.aten.add(x, x)
+            b = torch.ops.aten.diagonal(a)
+            torch.ops.aten.fill_(b, 0)
+            return a
+
+        Functionalization will emit the following:
+        def f(x):
+            a = torch.ops.aten.add(x, x)
+            b = torch.ops.aten.diagonal(a, 0, 1)
+            b_updated = torch.ops.aten.fill(b, 0)
+            a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
+            return a_updated
+
+        Ordinarily, we would not be able to reinplace the fill,
+        because "b" aliases with "a" which is used by the diagonal_scatter call.
+
+        "re-inplacing" is on the hook for figuring out that it is ok to
+        completely remove the expensive diagonal_scatter call, if we re-inplace the add().
+
+        So, for every `alias in alias_set(a)`, instead of checking
+        that "alias" is not used anywhere later in the graph,
+        we check that
+            EITHER:
+          (a) alias is not used anywhere later in the graph
+            OR:
+          (b) alias is used exactly once later on in the graph,
+              in the following op:
+
+                out = foo_scatter(alias, x, args...)
+
+              where the following must hold:
+                (i) "foo_scatter" is the "inverse" operator for foo.
+                    This only applies to "foo" ops that are view operators,
+                    which view into a subset of the original tensor's memory.
+                    In practice, there are ~4 operators where this applies:
+                      diagonal -> diagonal_scatter
+                      slice -> slice_scatter
+                      select -> select_scatter
+                      as_strided -> as_strided_scatter
+                (ii) "args..." are the same between the foo() and foo_scatter() calls.
+
+    (3) Perform the actual re-inplacing on foo!
+
+      (3b) is the common case, but special care is needed for {view}_scatter (3a)
+
+      (3a) {view}_scatter ops.
+
+        Consider this program:
+          a = torch.zeros(2, 2)
+          b = torch.ones(2)
+          a[0] = b
+
+        Post functionalization, that will look like:
+          a = torch.zeros(2, 2)
+          b = torch.ones(2)
+          a_updated = torch.select_scatter(a, b, 0, 0)
+
+        In this case though, there is no "functional" op to re-inplace!
+        Instead, we'd like to directly remove the select_scatter call.
+        We already know from (3) that this is valid,
+        because "a" has no later usages in the graph.
+
+        We perform the re-inplacing on the {view}_scatter op like so
+        Before:
+          a_updated = torch.select_scatter(a, b, args...)
+        After:
+          a_slice = a.select(args...)
+          a_slice.copy_(b)
+
+      (3b) Otherwise, replace the functional op with its inplace variant.
+        Before:
+          b = foo(a, args...)
+        After:
+          a.foo_(args...)
+
+    (4) Finally, after converting either:
+          Before:
+            b = foo(a)
+          After:
+            foo_(a)
+        or
+          Before:
+            b = {slice}_scatter(a, mutated_slice, args...)
+          After:
+            slice = {slice}(a, args...)
+            slice.copy_(mutated_slice)
+
+        We now need to find all later nodes that use "b" as an argument
+        and update them to take in "a" instead.
+
+        Note that for the majority of inplace ops, this isn't actually necessary
+        (because most inplace ops return "self" as their output).
+        This isn't generally true for all mutable ops though, which is why
+        we need to actually replace all of the arguments.
+
+        We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
+        That maps a given tensor storage to the set of all nodes that take in that storage
+        as an input.
+        Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
+        together.
+
+    (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
+        during step (3) get manually deleted from the graph.
+        Their outputs are no longer used, so technically standard DCE would be able
+        to do this, but we can no longer run FX's DCE pass now that we have mutable
+        ops in the graph.
+    """
+    _FunctionalizationMetadataProp(gm).propagate(*sample_args)
+
+    # Useful debug printing
+    # def _print(x):
+    #     if isinstance(x, FakeTensor):
+    #         print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
+
+    # for n in gm.graph.nodes:
+    #     print(n.format_node())
+    #     if hasattr(n, 'meta'):
+    #         print(f'node_idx: {n.meta["node_idx"]}')
+    #         if 'fake_result' in n.meta:
+    #             tree_map(_print, n.meta['fake_result'])
+    #         if 'view_of' in n.meta:
+    #             print(f'view_of: {str(n.meta["view_of"])}')
+    #     print()
+
+    # We need to know which nodes correspond to inputs (or their aliases)
+    # so we know not to re-inplace them.
+    # NOTE: later, we'll need to add an optimization for fully recovering performance
+    # on programs that mutate inputs.
+    input_storages = {
+        StorageWeakRef(
+            node.meta['fake_result']._typed_storage()
+        ) for node in gm.graph.nodes if node.op == 'placeholder'}
+
+
+    # We also need to know for a given node, what are all of its aliasing nodes.
+    storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
+    for n in gm.graph.nodes:
+        if 'fake_result' in n.meta:
+            # Tree-mapping because some ops can return lists of tensors.
+            def _add_to_map(x):
+                if isinstance(x, FakeTensor):
+                    storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
+            pytree.tree_map_(_add_to_map, n.meta['fake_result'])
+
+    # inplace-ify functional ops, subject to the constraints written below.
+    all_later_view_inverse_nodes_to_delete = set()
+    for idx, node in enumerate(gm.graph.nodes):
+        if node.op == 'call_function':
+
+            # Today, the re-inplace pass directly acts on:
+            # - functional ops with an inplace variant
+            # - {view}_scatter ops that can be potentially removed from the graph.
+            # Both of these ops take in tensor first args, so filtering on this condition
+            # makes the later code simpler.
+            # We should revisit this at some point though, particularly when we also want
+            # the reinplacer to be able to handle out= and mutable operators
+            # and tensorlist first args (like `_foreach_` ops).
+            if not isinstance(node.target, torch._ops.OpOverload):
+                continue
+            if len(node.target._schema.arguments) < 1:
+                continue
+            if type(node.target._schema.arguments[0].type) != torch.TensorType:
+                continue
+
+            # Step 1a: Check that the self argument we're attempting to reinplace
+            # has the same size/stride as the output.
+            # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
+            # As it would require resizing scalar_tensor.
+            # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
+            # this is probably an optimization to revisit later).
+            self_arg = node.args[0]
+            self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
+            node_flattened = pytree.tree_leaves(node.meta['fake_result'])
+            self_has_wrong_metadata = False
+            if len(self_flattened) == len(node_flattened):
+                for self_meta, node_meta in zip(self_flattened, node_flattened):
+                    if self_meta.numel() != node_meta.numel():
+                        self_has_wrong_metadata = True
+                    if self_meta.dtype != node_meta.dtype:
+                        self_has_wrong_metadata = True
+                    # We also cannot re-inplace on tensors that have internal memory overlap.
+                    # e.g. torch.ones(1).expand(4, 4).add_(1)
+                    if torch._debug_has_internal_overlap(self_meta) == 1:
+                        self_has_wrong_metadata = True
+            # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
+            # since users should never really be calling the functional "torch.ops.aten.resize"
+            # op directly in their programs.
+            if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
+                continue
+
+            # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
+            self_arg_name = self_arg.name
+            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
+            if self_arg_storage in input_storages:
+                # TODO: later, add the optimization for handling `copy_()` calls in the graph.
+                continue
+            if len([x for x in node.args if x is self_arg]) > 1:
+                # Step 1c:
+                # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
+                # so we prevent re-inplacing in this case.
+                continue
+
+            self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
+            self_aliases = storage_to_nodes[self_arg_storage]
+
+            # First, we find all later usages of any of the aliases of self_arg.
+            later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
+            # Then, we check if any of those later usages are actually view_scatter ops
+            # that are safe to fully remove.
+            later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
+
+            # Step 2: Check to see if the input to the op is re-used later in the graph.
+            # If not (same goes for its aliases), then this op is safe to re-inplace.
+            # This is a slightly roundabout way to check that there are no later usages of the current self argument.
+            # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
+            can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
+            if not can_reinplace:
+                continue
+
+            # Step 3a: Special handling for when we see *_scatter operators.
+            # When we see an operator like `b = torch.slice_scatter(a, ...)`,
+            # instead of trying to "inplace" it into a.slice_scatter_(...),
+            # we would prefer to remove it from the graph entirely,
+            # and instead copy_() the slice directly into the larger tensor.
+            # See the description of the algorithm for a full example.
+            if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
+                view_op = _VIEW_INVERSE_MAP[node.target]
+                # Before:
+                #   base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
+                # After:
+                #   slice = torch.ops.aten.slice.default(base, args...)
+                #   slice.copy_(mutated_slice)
+                with gm.graph.inserting_before(node):
+                    mutated_slice_node = node.args[1]
+                    remaining_slice_args = node.args[2:]
+                    slice_node = gm.graph.create_node(
+                        'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
+                    copy_node = gm.graph.create_node(
+                        'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
+                # Add the slice_scatter node to our "nodes to delete" list.
+                all_later_view_inverse_nodes_to_delete.add(node)
+
+
+            else:
+                # Step 3b: Check to see if this operator has an inplace variant.
+                maybe_inplace_op = _maybe_get_inplace_op(node.target)
+                if maybe_inplace_op is None:
+                    continue
+                # And if so, replace it with its inplace variant.
+                node.target = maybe_inplace_op
+
+            # At this point, 'storage_to_nodes' will be stale.
+            # Now that we're inplacing `b = foo(a)`, we need to effectively
+            # union together the dict values for b and a's storage.
+            # Hmm... morally I think we also want to keep the `fake_result` metadata
+            # up to date here, but I'm not sure how easy it is to do.
+            # Maybe it's fine to wait until the end of the pass to update it.
+            curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
+            storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
+            storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
+
+            # Need to remember the view_scatter nodes we found so we can remove them later.
+            all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
+
+            # Step 4:
+            # Now that we've replaced b = a.foo() with a.foo_(),
+            # We need to replace any later usages of "b" with "a"
+            for old in itertools.chain([node], later_view_inverse_node_usages):
+                new = old.args[0]
+                nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
+                for node_to_update in nodes_to_update:
+                    new_args = []
+                    args = node_to_update.args
+
+                    def replace_arg(a):
+                        if a == old:
+                            return new
+                        return a
+
+                    # First, replace usages of "b" with "a"
+                    node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
+                    node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
+
+                    # Second, update our storage_to_nodes data structure.
+                    old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
+                    node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
+
+                    old_res_storage = {
+                        StorageWeakRef(
+                            x._typed_storage()
+                        ) for x in old_flattened_res if isinstance(x, FakeTensor)}
+                    node_res_storage = {
+                        StorageWeakRef(
+                            x._typed_storage()
+                        ) for x in node_flattened_res if isinstance(x, FakeTensor)}
+
+                    # This will happen if we're updating a view op, e.g. replacing
+                    #     x = view(old)
+                    # with
+                    #     x = view(new)
+                    # When that happens, we need to make sure to keep our
+                    # storage mapping up to date.
+                    #
+                    # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
+                    # or multiple tensors that all share the same storage.
+                    # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
+                    if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
+                        new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
+                        new_res_storage = {
+                            StorageWeakRef(
+                                x._typed_storage()
+                            ) for x in new_flattened_res if isinstance(x, FakeTensor)}
+                        assert len(new_res_storage) == 1
+                        (old_ref,) = old_res_storage
+                        (new_ref,) = new_res_storage
+                        (node_ref,) = node_res_storage
+                        # Technically, "old_ref" and all its aliases will remain
+                        # in our mapping.
+                        # That should be fine though, since we deleted "old"
+                        # from the graph at this point.
+                        storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
+                        storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
+
+    # Step 5: delete any _scatter nodes that we de-functionalized
+    # Need to take care not to delete any of these nodes until after *all* modifications
+    # to the graph are finished.
+    for to_delete in all_later_view_inverse_nodes_to_delete:
+        gm.graph.erase_node(to_delete)
+
+
+    gm.recompile()
+    return gm
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/shape_prop.py b/MLPY/Lib/site-packages/torch/fx/passes/shape_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a665340c7b159cf6d50a4d3be888c6df40e1f8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/shape_prop.py
@@ -0,0 +1,195 @@
+# mypy: ignore-errors
+
+import torch
+import torch.fx
+import traceback
+
+from torch._dispatch.python import enable_python_dispatcher
+from torch.fx.node import Node, map_aggregate
+from typing import Any, Tuple, NamedTuple, Optional, Dict
+from torch.fx._compatibility import compatibility
+from torch._guards import detect_fake_mode
+
+__all__ = ['TensorMetadata', 'ShapeProp']
+
+@compatibility(is_backward_compatible=True)
+class TensorMetadata(NamedTuple):
+    # TensorMetadata is a structure containing pertinent information
+    # about a tensor within a PyTorch program.
+
+    # General Tensor metadata
+    shape : torch.Size
+    dtype : torch.dtype
+    requires_grad : bool
+    stride : Tuple[int, ...]
+    memory_format : Optional[torch.memory_format]
+
+    # Quantization metadata
+    is_quantized : bool
+    qparams: Dict[str, Any]
+
+def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
+    """
+    Extract a TensorMetadata NamedTuple describing `result`.
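+
+    A minimal illustrative sketch (the exact values depend on the input tensor;
+    shown here for a small contiguous float32 tensor):
+
+        t = torch.randn(2, 3)
+        meta = _extract_tensor_metadata(t)
+        # meta.shape == torch.Size([2, 3]), meta.dtype == torch.float32
+        # meta.stride == (3, 1), meta.memory_format == torch.contiguous_format
+        # meta.is_quantized is False and meta.qparams == {}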
+    """
+    shape = result.shape
+    dtype = result.dtype
+    requires_grad = result.requires_grad
+    stride = result.stride()
+
+    memory_format = None
+
+    if include_contiguity:
+        memory_formats = {
+            torch.contiguous_format,
+            torch.channels_last,
+            torch.channels_last_3d,
+        }
+        for query_format in memory_formats:
+            if result.is_contiguous(memory_format=query_format):
+                memory_format = query_format
+                break
+
+    is_quantized = result.is_quantized
+    qparams: Dict[str, Any] = {}
+    if is_quantized:
+        qscheme = result.qscheme()
+        qparams["qscheme"] = qscheme
+        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
+        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
+            # In this branch, scale and zero_point are expected to be tensors,
+            # we store the values as immutable_list in TensorMetadata for
+            # easier serialization downstream
+            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
+            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
+            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]
+
+    return TensorMetadata(
+        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
+
+@compatibility(is_backward_compatible=True)
+class ShapeProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph Node-by-Node and
+    record the shape and type of the result
+    into the corresponding node.
+
+    Example:
+         In this example, we record the shape
+         and data type of a module given
+         an example input ``torch.randn(50, D_in)``.
+         We print the name, shape and dtype of each node.
+
+        class TwoLayerNet(torch.nn.Module):
+            def __init__(self, D_in, H, D_out):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(D_in, H)
+                self.linear2 = torch.nn.Linear(H, D_out)
+            def forward(self, x):
+                h_relu = self.linear1(x).clamp(min=0)
+                y_pred = self.linear2(h_relu)
+                return y_pred
+        N, D_in, H, D_out = 64, 1000, 100, 10
+        x = torch.randn(N, D_in)
+        y = torch.randn(N, D_out)
+        model = TwoLayerNet(D_in, H, D_out)
+        gm = torch.fx.symbolic_trace(model)
+        sample_input = torch.randn(50, D_in)
+        ShapeProp(gm).propagate(sample_input)
+
+        for node in gm.graph.nodes:
+            print(node.name, node.meta['tensor_meta'].dtype,
+                node.meta['tensor_meta'].shape)
+
+        The output of this code is:
+
+        x torch.float32 torch.Size([50, 1000])
+        linear1 torch.float32 torch.Size([50, 100])
+        clamp_1 torch.float32 torch.Size([50, 100])
+        linear2 torch.float32 torch.Size([50, 10])
+        output torch.float32 torch.Size([50, 10])
+
+    Args:
+         module (GraphModule): The module to be executed
+         fake_mode (FakeTensorMode): A fake mode for copying the gm
+
+    """
+    def __init__(self, gm, fake_mode=None):
+        super().__init__(gm)
+        if fake_mode is None:
+            fake_mode = detect_fake_mode()
+        if fake_mode is not None:
+            from torch._dynamo.utils import deepcopy_to_fake_tensor
+            # Note:
+            # We need fake execution because the inputs are fake; however, we cannot fakify the module
+            # itself, because we need to write tensor_meta onto the nodes of the real module. So we run
+            # a fake copy of the module to produce a result (in run_node below), extract the tensor
+            # metadata from it, and then keep going.
+            #
+            # If we fakified the module itself, we would write to the wrong nodes, and downstream fusion
+            # would be missing the tensor_meta.
+            #
+            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
+            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
+            self.fake_mode = fake_mode
+        else:
+            self.fake_module = None
+            self.fake_mode = None
+
+        self.real_module = self.module
+
+    def run_node(self, n : Node) -> Any:
+        try:
+            if self.fake_module is not None:
+                # Hacky swap. Alternatively, we could do this with overriding
+                # call_module and get_attr.
+                self.module = self.fake_module
+            try:
+                if self.fake_mode is not None:
+                    with self.fake_mode, enable_python_dispatcher():
+                        result = super().run_node(n)
+                else:
+                    result = super().run_node(n)
+            finally:
+                self.module = self.real_module
+        except Exception as e:
+            traceback.print_exc()
+            raise RuntimeError(
+                f"ShapeProp error for: node={n.format_node()} with "
+                f"meta={n.meta}"
+            ) from e
+
+        found_tensor = False
+
+        def extract_tensor_meta(obj):
+            if isinstance(obj, torch.Tensor):
+                nonlocal found_tensor
+                found_tensor = True
+                return _extract_tensor_metadata(obj)
+            else:
+                return obj
+
+        meta = map_aggregate(result, extract_tensor_meta)
+        if found_tensor:
+            n.meta['tensor_meta'] = meta
+
+        n.meta['type'] = type(result)
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation, recording the shape and type of
+        each node, and return the result.
+
+        Args:
+            *args (Tensor): the sample input.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        if self.fake_mode is not None:
+            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
+        else:
+            fake_args = args
+        return super().run(*fake_args)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/split_module.py b/MLPY/Lib/site-packages/torch/fx/passes/split_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..b36022dabfdfc7fe34f07f5e98cabb8c84d0bb04
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/split_module.py
@@ -0,0 +1,514 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
+from collections import OrderedDict
+import logging
+
+import torch
+from torch.fx._compatibility import compatibility
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import Node
+
+if TYPE_CHECKING:
+    import sympy  # noqa: F401
+
+__all__ = ["Partition", "split_module"]
+_LOGGER = logging.getLogger(__name__)
+
+@compatibility(is_backward_compatible=True)
+class Partition:
+    def __init__(self, name: str):
+        self.name: str = name
+        self.submod_name = f"submod_{name}"
+        self.node_names: List[str] = []
+        self.inputs: Dict[str, None] = {}
+        self.outputs: Dict[str, None] = {}
+        self.dependencies: Dict[str, None] = {}
+        self.dependents: Dict[str, None] = {}
+        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+        self.environment: Dict[Node, Node] = {}
+        self.targets: Dict[str, Any] = {}
+
+    def __repr__(self) -> str:
+        return (
+            f"name: {self.name},\n"
+            f" nodes: {self.node_names},\n"
+            f" inputs: {self.inputs},\n"
+            f" outputs: {self.outputs},\n"
+            f" partitions depended on: {self.dependencies},\n"
+            f" partition dependents: {self.dependents}"
+        )
+
+
+# Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
+def split_module(
+    m: GraphModule,
+    root_m: torch.nn.Module,
+    split_callback: Callable[[Node], int],
+    qualname_map: Optional[Dict[str, str]] = None,
+    keep_original_order: Optional[bool] = False,
+    keep_original_node_name: Optional[bool] = False,
+):
+    """
+    Creates subgraphs out of main graph
+
+    Args:
+        m (GraphModule): Graph module to split
+        root_m (torch.nn.Module): root nn module. Not currently used. Included
+            because the root nn module is usually transformed via
+            torch.fx._symbolic_trace.symbolic_trace (see example below)
+        split_callback (Callable[[Node], int]): Callable function
+            that maps a given Node instance to a numeric partition identifier.
+            split_module will use this function as the policy for which operations
+            appear in which partitions in the output Module.
+        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
+            mapping from new target names in the module after split to old target
+            names in the original module.
+        keep_original_order: Optional[bool]: keep the original order of the GraphModule
+            or use the topological order of the newly constructed GraphModule
+
+
+    Returns:
+        GraphModule: the module after split.
+
+    Example:
+
+        This is a sample setup:
+
+            import torch
+            from torch.fx.symbolic_trace import symbolic_trace
+            from torch.fx.graph_module import GraphModule
+            from torch.fx.node import Node
+            from torch.fx.passes.split_module import split_module
+
+            class MyModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 4))
+                    self.linear = torch.nn.Linear(4, 5)
+
+                def forward(self, x, y):
+                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
+                    w = self.linear(y).clamp(min=0.0, max=1.0)
+                    return z + w
+
+            # symbolically trace model
+            my_module = MyModule()
+            my_module_traced = symbolic_trace(my_module)
+
+            # random mod partitioning
+            partition_counter = 0
+            NPARTITIONS = 3
+
+            def mod_partition(node: Node):
+                global partition_counter
+                partition = partition_counter % NPARTITIONS
+                partition_counter = (partition_counter + 1) % NPARTITIONS
+                return partition
+
+            # split module into a module with submodules
+            module_with_submodules = split_module(
+                my_module_traced, my_module, mod_partition
+            )
+
+        The output looks like this; the original graph is broken into partitions:
+
+            > print(module_with_submodules)
+            GraphModule(
+                (submod_0): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_1): GraphModule(
+                    (linear): Linear(in_features=4, out_features=5, bias=True)
+                )
+                (submod_2): GraphModule()
+            )
+
+            def forward(self, x, y):
+                param = self.param
+                submod_0 = self.submod_0(x, param, y);  x = param = y = None
+                getitem = submod_0[0]
+                getitem_1 = submod_0[1];  submod_0 = None
+                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
+                getitem_2 = submod_1[0]
+                getitem_3 = submod_1[1];  submod_1 = None
+                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
+                return submod_2
+
+        The output of the split module is the same as the output of the input traced module.
+        This is an example within a test setting:
+
+            > orig_out = my_module_traced(x, y)
+            > submodules_out = module_with_submodules(x, y)
+            > self.assertEqual(orig_out, submodules_out)
+            True
+    """
+
+    def construct_graph(
+        node: Node,
+        base_mod_env: Dict[str, Node],
+        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
+    ):
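+        # Copies a placeholder or get_attr node from the original module into the
+        # base (top-level) graph being constructed, recording the new node in
+        # base_mod_env and, for get_attr, the resolved attribute value in
+        # base_mod_attrs. Nodes with any other op are left untouched.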
+        if node.op == "placeholder":
+            default_value = (
+                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+            )
+            if keep_original_node_name:
+                args = () if default_value is inspect.Signature.empty else (default_value,)
+                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
+            else:
+                base_mod_env[node.name] = base_mod_graph.placeholder(
+                    node.target, type_expr=node.type, default_value=default_value
+                )
+            base_mod_env[node.name].meta = node.meta.copy()
+        elif node.op == "get_attr":
+            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
+            base_mod_env[node.name].meta = node.meta.copy()
+            attr_val = m
+            for atom in node.target.split("."):  # type: ignore[union-attr]
+                if not hasattr(attr_val, atom):
+                    raise AttributeError(f"Node target {node.target} not found!")
+                attr_val = getattr(attr_val, atom)
+            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
+        return base_mod_env, base_mod_attrs
+
+    partitions: Dict[str, Partition] = {}
+    orig_nodes: Dict[str, Node] = {}
+    symbol_to_node: Dict["sympy.Symbol", Node] = {}
+
+    def record_cross_partition_use(
+        def_node: Node, use_node: Optional[Node]
+    ):  # noqa: B950
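+        # When a value defined in one partition is used in a different partition
+        # (or by the top-level output, in which case use_node is None), record it
+        # as an output of the defining partition and an input of the using
+        # partition, and record the dependency edge between the two partitions.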
+        from torch.fx.experimental.symbolic_shapes import free_symbols
+
+        defined = getattr(def_node, "_fx_partition", None)
+        used = getattr(use_node, "_fx_partition", None)
+        if defined != used:
+            if defined is not None:
+                def_partition = partitions[defined]
+                def_partition.outputs.setdefault(def_node.name)
+                if used is not None:
+                    def_partition.dependents.setdefault(used)
+
+            if used is not None:
+                use_partition = partitions[used]
+                use_partition.inputs.setdefault(def_node.name)
+                if (def_val := def_node.meta.get("example_value")) is not None:
+                    for s in sorted(free_symbols(def_val), key=str):
+                        use_partition.inputs.setdefault(symbol_to_node[s].name)
+                if defined is not None:
+                    use_partition.dependencies.setdefault(defined)
+
+    def instantiate_node_partition_mapping(node):
+        partition_name = str(split_callback(node))
+
+        # add node to partitions
+        partition = partitions.get(partition_name)
+        if partition is None:
+            partitions[partition_name] = partition = Partition(partition_name)
+
+        partition.node_names.append(node.name)
+        node._fx_partition = partition_name
+
+    # Global State Nodes are nodes which, through their global state effects,
+    # "taint" all downstream nodes while they are active.
+    GLOBAL_STATE_NODES = [
+        torch.amp._enter_autocast,
+        torch.amp._exit_autocast,
+        torch._C._set_grad_enabled
+    ]
+
+    # For grad regions:
+    # ------------------------
+    # 1. first region: we do nothing
+    # 2. subsequent regions: we insert the set_grad at the beginning
+    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
+
+    # For autocast regions:
+    # ------------------------
+    # 1. first region: we will only insert the _exit at the end
+    # 2. intermediate regions: we will insert both the
+    #    _enter at the beginning and _exit at the end
+    # 3. last region: we will only insert _enter at the beginning
+    # We will do so in the order in which the autocasts were instantiated.
+    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
+    autocast_exits: Dict[Node, Optional[Node]] = {}
+
+    active_grad = None
+    active_autocasts = set()
+
+    import sympy  # noqa: F811
+
+    for node in m.graph.nodes:
+        if node.op in ["placeholder", "get_attr", "output"]:
+            if (
+                node.op == "placeholder" and
+                (val := node.meta.get("example_value")) is not None and
+                isinstance(val, torch.SymInt) and
+                isinstance(val.node.expr, sympy.Symbol)
+            ):
+                symbol_to_node[val.node.expr] = node
+            continue
+
+        instantiate_node_partition_mapping(node)
+
+        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
+            if node.target == torch._C._set_grad_enabled:
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], bool)
+                active_grad = node
+                grad_regions[active_grad] = set({split_callback(node)})
+            elif node.target == torch.amp._enter_autocast:
+                # Should all be python constants
+                assert all(not isinstance(arg, Node) for arg in node.args)
+                active_autocasts.add(node)
+                autocast_regions[node] = set({split_callback(node)})
+                autocast_exits[node] = None
+            elif node.target == torch.amp._exit_autocast:
+                assert len(node.args) == 1
+                autocast_regions[node.args[0]].add(split_callback(node))
+                active_autocasts.remove(node.args[0])
+                autocast_exits[node.args[0]] = node
+
+        if active_grad is not None:
+            grad_regions[active_grad].add(split_callback(node))
+
+        for a in active_autocasts:
+            autocast_regions[a].add(split_callback(node))
+
+    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
+
+    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
+    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
+
+    if _LOGGER.isEnabledFor(logging.DEBUG):
+        _LOGGER.debug("autocast_regions: %s", autocast_regions)
+        _LOGGER.debug("grad_regions: %s", grad_regions)
+
+    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
+
+    # split nodes into partitions
+    highest_partition = -1
+    for node in m.graph.nodes:
+        orig_nodes[node.name] = node
+
+        # TODO currently placeholders/parameters aren't put into random partitions,
+        # rather they're added to the graphs where they are used down below
+        if node.op in ["placeholder", "get_attr"]:
+            continue
+        if node.op == "output":
+            torch.fx.graph.map_arg(
+                node.args[0], lambda n: record_cross_partition_use(n, None)
+            )
+            continue
+
+        if assert_monotonically_increasing:
+            pid = split_callback(node)
+            assert highest_partition <= pid, \
+                ("autocast or set_grad_enabled require monotonically increasing partitions:"
+                 f"highest: {highest_partition}, this node's: {pid}")
+            highest_partition = pid
+
+        # do not capture cross-partition dependencies for global state nodes as they will be
+        # self-contained - their setup and unwind will be isolated to each partition submodule.
+        if node.target not in GLOBAL_STATE_NODES:
+            torch.fx.graph.map_arg(
+                node.args, lambda def_node: record_cross_partition_use(def_node, node)
+            )
+            torch.fx.graph.map_arg(
+                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
+            )  # noqa: B950
+
+    original_partition_order = list(partitions.keys())
+    # find partitions with no dependencies
+    root_partitions: List[str] = []
+    for partition_name, partition in partitions.items():
+        if not len(partition.dependencies):
+            root_partitions.append(partition_name)
+
+    # check partitions for circular dependencies and create topological partition ordering
+    sorted_partitions: List[str] = []
+    while root_partitions:
+        root_partition = root_partitions.pop()
+        sorted_partitions.append(root_partition)
+        for dependent in partitions[root_partition].dependents:
+            partitions[dependent].dependencies.pop(root_partition)
+            if not partitions[dependent].dependencies:
+                root_partitions.append(dependent)
+    if len(sorted_partitions) != len(partitions):
+        raise RuntimeError("cycle exists between partitions!")
+
+    # Enter prelude
+    for regions_mapping in [autocast_regions, grad_regions]:
+        for node, regions in regions_mapping.items():
+            assert len(regions) > 0
+            partitions[str(regions[0])].environment[node] = node
+            for r in regions[1:]:
+                partition = partitions[str(r)]
+                new_node = partition.graph.create_node(
+                    op=node.op,
+                    target=node.target,
+                    args=tuple(arg for arg in node.args),
+                    kwargs={},
+                    type_expr=node.type,
+                )
+                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
+                partition.environment[node] = new_node
+
+    # add placeholders to partition inputs
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+        for inp in partition.inputs:
+            placeholder = partition.graph.placeholder(
+                inp,
+                type_expr=orig_nodes[inp].type,
+            )
+            placeholder.meta = orig_nodes[inp].meta.copy()
+            partition.environment[orig_nodes[inp]] = placeholder
+
+    # Transform nodes and collect targets for partition's submodule
+    for node in m.graph.nodes:
+        if hasattr(node, "_fx_partition"):
+            partition = partitions[node._fx_partition]
+
+            # swap out old graph nodes in kw/args with references to new nodes in this submodule
+            environment = partition.environment
+            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
+            gathered_kwargs = torch.fx.graph.map_arg(
+                node.kwargs, lambda n: environment[n]
+            )
+
+            if node.op not in ["call_module", "get_attr"]:
+                target = node.target
+            else:
+                target_atoms = node.target.split(".")
+                target_attr = m
+                for atom in target_atoms:
+                    if not hasattr(target_attr, atom):
+                        raise AttributeError(f"Operator target {node.target} not found!")
+                    target_attr = getattr(target_attr, atom)
+                # target = target_atoms[-1]
+                target = "_".join(target_atoms)
+                partition.targets[target] = target_attr
+                # Fill in the passed-in mapping from new qualname to old qualname
+                if qualname_map is not None:
+                    # When creating the split module later, the submodules will have
+                    # path prefix matching the corresponding partition's submod_name
+                    qualname = f"{partition.submod_name}.{target}"
+                    qualname_map[qualname] = node.target
+
+            assert isinstance(gathered_args, tuple)
+            assert isinstance(gathered_kwargs, dict)
+            name = node.name if keep_original_node_name else None
+            new_node = partition.graph.create_node(
+                op=node.op,
+                target=target,
+                args=gathered_args,
+                kwargs=gathered_kwargs,
+                type_expr=node.type,
+                name=name,
+            )
+            new_node.meta = node.meta.copy()
+            partition.environment[node] = new_node
+
+    # Exit epilogue
+    for regions_mapping in [autocast_regions]:
+        for node in reversed(regions_mapping):
+            regions = regions_mapping[node]
+            assert len(regions) > 0
+            for r in regions[:-1]:
+                partition = partitions[str(r)]
+                exit_node = autocast_exits[node]
+                assert exit_node is not None, "Missing exit node"
+                new_node = partition.graph.create_node(
+                    op=exit_node.op,
+                    target=exit_node.target,
+                    args=(partition.environment[node],),
+                    kwargs={},
+                    type_expr=exit_node.type,
+                )
+                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?
+
+    # original module environment dict mapping node names to nodes
+    orig_mod_env: Dict[str, Node] = {}
+    # Set up values to construct base module
+    base_mod_env: Dict[str, Node] = {}
+    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
+    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
+    if not keep_original_order:
+        for node in m.graph.nodes:
+            base_mod_env, base_mod_attrs = construct_graph(
+                node, base_mod_env, base_mod_attrs
+            )
+
+    else:
+        # Go through the graph to construct the mapping dict
+        for node in m.graph.nodes:
+            orig_mod_env[node.name] = node
+
+    # Do some things iterating over the partitions in topological order again:
+    # 1) Finish off submodule Graphs by setting corresponding outputs
+    # 2) Construct GraphModules for each submodule
+    # 3) Construct the base graph by emitting calls to those submodules in
+    #    topological order or original order specified by keep_original_order
+
+    construct_order_partitions = (
+        sorted_partitions if not keep_original_order else original_partition_order
+    )
+
+    already_constructed_attr_nodes = set()
+    for partition_name in construct_order_partitions:
+        partition = partitions[partition_name]
+
+        # Set correct output values
+        output_vals = tuple(
+            partition.environment[orig_nodes[name]] for name in partition.outputs
+        )
+
+        # skip output node generation if there are no output values
+        num_output_vals = len(output_vals)
+        if num_output_vals == 1:
+            partition.graph.output(output_vals[0])
+        elif num_output_vals > 1:
+            partition.graph.output(output_vals)
+
+        if keep_original_order:
+            # first get the attr nodes required by this partition
+            orig_mod_attr_nodes: List[Node] = [
+                orig_mod_env[key] for key in partition.inputs
+            ]
+            # Construct GraphModule for this partition
+            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
+                if node in already_constructed_attr_nodes:
+                    continue
+                base_mod_env, base_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
+        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
+            partition.targets, partition.graph
+        )  # noqa: B950
+
+        # Emit call in base graph to this submodule
+        output_val = base_mod_graph.call_module(
+            partition.submod_name,
+            tuple(base_mod_env[name] for name in partition.inputs),
+        )
+
+        num_outputs = len(partition.outputs)
+        if num_outputs > 1:
+            # Unpack multiple return values from submodule
+            output_val_proxy = torch.fx.proxy.Proxy(output_val)
+            for i, output_name in enumerate(partition.outputs):
+                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
+        elif num_outputs == 1:
+            base_mod_env[next(iter(partition.outputs))] = output_val
+
+    for node in m.graph.nodes:
+        if node.op == "output":
+            base_mod_graph.output(
+                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
+            )  # noqa: B950
+
+    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/split_utils.py b/MLPY/Lib/site-packages/torch/fx/passes/split_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f682a42dbb91b6b8b88f7b3f9e854724b718e30a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/split_utils.py
@@ -0,0 +1,302 @@
+import copy
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import torch.fx
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import map_arg
+from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
+
+from .tools_common import NodeList
+
+__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
+
+
+@compatibility(is_backward_compatible=False)
+def getattr_recursive(obj, name):
+    for layer in name.split("."):
+        if hasattr(obj, layer):
+            obj = getattr(obj, layer)
+        else:
+            return None
+    return obj
+
+
+@compatibility(is_backward_compatible=False)
+def setattr_recursive(obj, attr, value):
+    if "." not in attr:
+        setattr(obj, attr, value)
+    else:
+        layer = attr.split(".")
+        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
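+
+# Illustrative usage sketch (assumes a hypothetical model with a nested submodule
+# reachable at the dotted path "encoder.linear"):
+#
+#   layer = getattr_recursive(model, "encoder.linear")  # model.encoder.linear, or None if missing
+#   setattr_recursive(model, "encoder.linear", torch.nn.Identity())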
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Component:
+    """
+    A component serves as a container for a subgraph we want to create afterwards.
+    """
+
+    graph: torch.fx.Graph
+    order: int
+    name: str
+
+    # Stores the placeholder nodes in `graph`.
+    input_placeholders: List = field(default_factory=list)
+
+    # Stores the nodes in the original graph that are placeholders in `graph`.
+    orig_inputs: List = field(default_factory=list)
+
+    # Stores the nodes in the original graph that are outputs in `graph`.
+    orig_outputs: List = field(default_factory=list)
+
+    # Mapping from get_attr node in original graph to get_attr node in `graph`.
+    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
+    constructor_args: List[str] = field(default_factory=list)
+    gm: Optional[torch.fx.GraphModule] = None
+
+
+@compatibility(is_backward_compatible=False)
+def split_by_tags(
+    gm: torch.fx.GraphModule,
+    tags: List[str],
+    return_fqn_mapping: bool = False,
+    return_tuple: bool = False,
+    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
+) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
+    """
+    Splits a GraphModule using tags on its graph nodes. We honor the order of
+    tags. For example, we have tags = ["a", "b", "c"], the function will create
+    the initial submodules in the order of "a", "b", "c".
+
+    To set a tag:
+    gm.graph.nodes[idx].tag = "mytag"
+
+    This will result in all nodes with the same tag being extracted and placed in their
+    own submodule. For placeholder, output and get_attr nodes, the tag is ignored. Placeholder
+    and output nodes are created when needed, while get_attr nodes get copied to the submodules
+    where they are used.
+
+    Given the following module def:
+
+    class SimpleModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(...)
+            self.linear2 = torch.nn.Linear(...)
+            self.linear3 = torch.nn.Linear(...)
+
+        def forward(self, in1, in2):
+            r1 = self.linear1(in1)
+            r2 = self.linear2(in2)
+            r3 = torch.cat([r1, r2])
+            return self.linear3(r3)
+
+    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
+
+    ro:
+    def forward(self, in1):
+        self = self.root
+        linear1 = self.linear1(in1)
+        return linear1
+
+    main:
+    def forward(self, in2, linear1):
+        self = self.root
+        linear2 = self.linear2(in2)
+        cat_1 = torch.cat([linear1, linear2])
+        linear3 = self.linear3(cat_1)
+        return linear3
+
+    The resulting base (top-level) module after split:
+    def forward(self, in1, in2):
+        self = self.root
+        ro_0 = self.ro_0(in1)
+        main_1 = self.main_1(in2, ro_0)
+        return main_1
+
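+    A minimal call sketch, assuming `gm` is the symbolically traced SimpleModule above
+    (every call_module/call_function/call_method node must carry a tag before calling
+    split_by_tags; using the single tag "main" here is illustrative and puts everything
+    into one submodule):
+
+        for node in gm.graph.nodes:
+            if node.op in ("call_module", "call_function", "call_method"):
+                node.tag = "main"
+        split_gm = split_by_tags(gm, ["main"])
+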
+    Returns:
+        split_gm: torch fx graph after split
+        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
+            after split for call_module and get_attr.
+    """
+
+    def flatten(x: torch.fx.node.Argument) -> NodeList:
+        """
+        Stores nodes in x to a list and returns the list.
+        """
+        r: NodeList = []
+        map_arg(x, r.append)
+        return r
+
+    # Mapping from node in original module to node in created submodule.
+    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Mapping from node in original module or created submodules to
+    # corresponding component.
+    node_to_component: Dict[torch.fx.Node, Component] = {}
+
+    # Mapping from tag to the corresponding component.
+    tag_to_component: Dict[str, Component] = {}
+
+    # Stores all components.
+    all_components: List[Component] = []
+
+    # Stores nodes that will be used in main graph.
+    used_in_main: Dict[torch.fx.Node, None] = {}
+
+    # Main graph after split.
+    main_g = torch.fx.Graph()
+
+    # Mapping from node in original module to node in main graph after split.
+    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Output node of original module.
+    output_node: Optional[torch.fx.Node] = None
+
+    # Create a component for each tag; we don't expect to create other components afterwards.
+    for tag in tags:
+        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
+        all_components.append(comp)
+        tag_to_component[tag] = comp
+
+    # Traverse the nodes in original graph and take care of them.
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            if output_node is not None:
+                raise RuntimeError("Multiple output nodes in graph!")
+            output_node = node
+            continue
+
+        # Placeholders in the original graph get copied to main graph.
+        if node.op == "placeholder":
+            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
+            main_remapping[node].meta = copy.copy(node.meta)
+            continue
+
+        # Get_attr nodes are ignored because we are not tagging them.
+        # Instead, we copy them directly to the submodules that use them afterwards.
+        if node.op == "get_attr":
+            continue
+
+        # Now we process callable nodes, which are nodes with op of call_module,
+        # call_function or call_method. Every callable node should be tagged.
+        assert hasattr(node, "tag")
+
+        upstream_components = [
+            node_to_component[x]
+            for x in flatten(node.args) + flatten(node.kwargs)
+            if x.op not in {"placeholder", "get_attr"}
+        ]
+
+        comp = tag_to_component[node.tag]
+        node_to_component[node] = comp
+
+        # Max order of upstream components.
+        mx = max((c.order for c in upstream_components), default=0)
+
+        # Expect the component for `node` to have an order no lower than that of its upstream components.
+        assert comp.order >= mx
+
+        # Map an input of `node` to the corresponding node in the component's graph.
+        def remap_func(x):
+            # If input is a get_attr node, copy it to current component's graph.
+            # Returns the get_attr node in current component's graph.
+            if x.op == "get_attr":
+                if x not in comp.getattr_maps:
+                    comp.getattr_maps[x] = comp.graph.get_attr(
+                        x.target, type_expr=x.type
+                    )
+                return comp.getattr_maps[x]
+
+            # If input is not a placeholder, it should have been put into a component
+            # already. If it's the current component then we return the corresponding
+            # node in the component.
+            if x.op != "placeholder" and node_to_component[x] == comp:
+                return node_remapping[x]
+
+            # If input is a placeholder or it's in another component, we want to make it
+            # a placeholder in the current component's graph.
+            if x not in comp.orig_inputs:
+                comp.orig_inputs.append(x)
+                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
+                placeholder.meta = copy.copy(x.meta)
+                comp.input_placeholders.append(placeholder)
+                used_in_main[x] = None
+
+            return comp.input_placeholders[comp.orig_inputs.index(x)]
+
+        n = comp.graph.node_copy(node, remap_func)
+        n.tag = node.tag  # type: ignore[attr-defined]
+        node_remapping[node] = n
+        node_to_component[n] = comp
+
+    if output_node is None:
+        raise RuntimeError("Graph had no output node!")
+
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            # We don't need components mapping for nodes of type "get_attr"
+            # that are consumed by the output. Only need to make sure we create
+            # corresponding counterparts in the resulting graph.
+            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
+        else:
+            # All component results consumed by the output node should be
+            # marked as "used in main".
+            used_in_main[x] = None
+
+    # If a node is used in main graph then we mark it as an output in the component
+    # it belongs to.
+    for n in used_in_main:
+        if n.op != "placeholder":
+            node_to_component[n].orig_outputs.append(n)
+
+    # Now we create a graphmodule for each component.
+    orig_to_split_fqn_mapping: Dict[str, str] = {}
+    for comp in all_components:
+        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
+
+        if return_tuple:
+            comp.graph.output(outs)
+        else:
+            # Take care of the args of the FX output node. If there's a single
+            # output then the output node's args look like (output_single); if
+            # there are multiple outputs then they look like
+            # ((output_0, output_1, ...)).
+            comp.graph.output(outs[0] if len(outs) == 1 else outs)
+
+        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
+            gm, subgraph=comp.graph, comp_name=comp.name
+        )
+        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
+
+        # Create a call_module node in main graph.
+        main_node = main_g.call_module(
+            comp.name,
+            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
+            kwargs=None,
+        )
+
+        if len(outs) == 1 and not return_tuple:
+            main_remapping[comp.orig_outputs[0]] = main_node
+        else:
+            for i, o in enumerate(comp.orig_outputs):
+                # Use Proxy to record getitem access.
+                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]
+
+    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
+    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
+    main_g._codegen = gm.graph._codegen
+
+    # If the output nodes consumes get_attr directly in the original graph,
+    # then we need to make sure get_attr is copied to the new graph.
+    for x in flatten(output_node.args[0]):
+        if x.op == "get_attr":
+            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]
+
+    result_gm = GraphModuleCls(main_root, main_g)
+    if return_fqn_mapping:
+        return result_gm, orig_to_split_fqn_mapping
+
+    return result_gm
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/splitter_base.py b/MLPY/Lib/site-packages/torch/fx/passes/splitter_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e4d93f99b56e1ba161a3b5a17d395f6a6c8e18
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/splitter_base.py
@@ -0,0 +1,871 @@
+import argparse
+import copy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
+import logging
+
+import torch
+from torch.fx.passes.graph_manipulation import get_size_of_node
+from torch.fx.node import map_arg
+from torch.fx._compatibility import compatibility
+
+from .operator_support import (
+    get_node_target,
+    OperatorSupportBase,
+)
+from .graph_drawer import FxGraphDrawer
+from .shape_prop import ShapeProp
+from .split_utils import split_by_tags
+from .tools_common import (
+    FxNetAccFusionsFinder,
+    CALLABLE_NODE_OPS,
+    Tensors,
+    NodeList,
+    NodeSet,
+    is_node_output_tensor,
+)
+
+
+__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
+_LOGGER = logging.getLogger(__name__)
+
+DEFAULT_MIN_ACC_MODULE_SIZE = 1
+DEFAULT_SKIP_FUSION = False
+DEFAULT_ALLOW_NON_TENSOR = False
+
+class _SplitterSettingBase:
+    def __init__(
+        self,
+        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
+        skip_fusion=DEFAULT_SKIP_FUSION,
+        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
+    ):
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--min-acc-module-size",
+            "--min_acc_module_size",
+            required=False,
+            type=int,
+            help="Minimum size limit of an accelerator subgraph.",
+        )
+        parser.add_argument(
+            "--skip-fusion",
+            "--skip_fusion",
+            default=False,
+            action="store_true",
+            help="If true then no fusion groups. Fusion group is used to "
+            "enforce no non-tensor data flow between submodules. If we don't "
+            "have this constrain, setting this to false is recommended as it "
+            "can reduce overhead.",
+        )
+        parser.add_argument(
+            "--allow-non-tensor",
+            "--allow_non_tensor",
+            default=False,
+            action="store_true",
+            help="For some backends non-tensor data flow between cpu and them "
+            "are not allowed. Therefore, if a node supported by accelerator but "
+            "it has non-tensor inputs or outputs to a cpu node we would want to "
+            "consider it as a cpu node during splitting. However, for some backends "
+            "we might not care about non-tensor data flow and we can set this option "
+            "to true to disable the functionality that prevent non-tensor data flow.",
+        )
+        args, unknown = parser.parse_known_args()
+
+        self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
+        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
+        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
+
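+# Illustrative sketch (not part of the upstream file): constructing the settings
+# object programmatically. Any recognized command-line flags passed to the host
+# process (e.g. --min-acc-module-size) take precedence over these keyword
+# arguments, because the constructor calls parse_known_args().
+def _example_splitter_settings():  # pragma: no cover
+    settings = _SplitterSettingBase(min_acc_module_size=2, skip_fusion=True)
+    return settings.min_acc_module_size, settings.skip_fusion, settings.allow_non_tensor
+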
+
+@compatibility(is_backward_compatible=False)
+class FxNetAccNodesFinder:
+    """
+    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
+    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
+
+    I.e. if we have a chain:
+
+    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
+
+    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
+
+    This behavior can be turned off by passing allow_non_tensor=True.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        operator_support: OperatorSupportBase,
+        allow_non_tensor: bool,
+    ):
+        self.module = module
+        self.operator_support = operator_support
+        self.allow_non_tensor = allow_non_tensor
+
+    def reduce_acc_nodes_non_tensor_input_helper(
+        self, cpu_worklist: NodeList
+    ):
+        """
+        Transitively excludes nodes from ACC supported set.
+        For every node in the worklist:
+        - removes its downstream ACC nodes from ACC supported set,
+        - if any downstream ACC node produces non-tensor output,
+          then it gets added into the worklist.
+        """
+        while cpu_worklist:
+            node = cpu_worklist.pop(0)
+
+            for user in node.users:
+                if user in self.acc_nodes:
+                    self.acc_nodes.remove(user)
+                    if not is_node_output_tensor(user):
+                        cpu_worklist.append(user)
+
+    def reduce_acc_nodes_non_tensor_input(self):
+        """
+        Excludes nodes from ACC supported set that have direct
+        upstream CPU nodes that produce non-tensor outputs.
+        """
+        non_tensor_cpu_nodes: NodeList = []
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if node in self.acc_nodes:
+                continue
+            if is_node_output_tensor(node):
+                continue
+            non_tensor_cpu_nodes.append(node)
+
+        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
+
+    def reduce_acc_nodes_non_tensor_output(self):
+        """
+        Excludes nodes from ACC supported set that produce non-tensor
+        outputs and have downstream CPU nodes.
+        """
+        while True:
+            new_cpu_nodes: NodeList = []
+
+            for acc_node in self.acc_nodes:
+                if is_node_output_tensor(acc_node):
+                    continue
+                for user in acc_node.users:
+                    if user not in self.acc_nodes:
+                        new_cpu_nodes.append(acc_node)
+                        break
+
+            if not new_cpu_nodes:
+                break
+
+            for new_cpu_node in new_cpu_nodes:
+                self.acc_nodes.remove(new_cpu_node)
+
+            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
+
+    def __call__(self) -> NodeSet:
+        submodules = dict(self.module.named_modules())
+        self.acc_nodes = {
+            n
+            for n in self.module.graph.nodes
+            if n.op in CALLABLE_NODE_OPS
+            and self.operator_support.is_node_supported(submodules, n)
+        }
+
+        if not self.allow_non_tensor:
+            self.reduce_acc_nodes_non_tensor_input()
+            self.reduce_acc_nodes_non_tensor_output()
+
+        return self.acc_nodes
+
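+# Illustrative sketch (not part of the upstream file): one way the finder above
+# might be invoked. `gm` would be a shape-propagated GraphModule and `op_support`
+# any OperatorSupportBase implementation; both are placeholders here.
+def _example_find_acc_nodes(gm: torch.fx.GraphModule, op_support: OperatorSupportBase) -> NodeSet:  # pragma: no cover
+    finder = FxNetAccNodesFinder(gm, op_support, allow_non_tensor=False)
+    return finder()
+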
+@compatibility(is_backward_compatible=False)
+class FxNetSplitterInternalError(Exception):
+    pass
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class Subgraph:
+    is_acc: bool
+    nodes: NodeList
+
+
+@compatibility(is_backward_compatible=False)
+class SplitResult(NamedTuple):
+    """
+    Stores the results of the splitter.
+
+    Attributes:
+        split_module: root module after splitting.
+        submodule_inputs: a dict that maps submodule name to its inputs.
+        non_acc_submodule_prefix: the prefix for non-acc submodules. For
+            acc submodules the prefix is always "_run_on_acc_".
+    """
+
+    split_module: torch.fx.GraphModule
+    submodule_inputs: Dict[str, Any]
+    non_acc_submodule_prefix: str
+
+
+@compatibility(is_backward_compatible=False)
+def generate_inputs_for_submodules(
+    model: torch.nn.Module,
+    inputs: Sequence[Any],
+    target_submodules: Iterable[str],
+    deepcopy: bool = False,
+) -> Dict[str, Any]:
+    """
+    Generate inputs for the targeted submodules in the given model. Note that if two submodule names refer to the same
+    object, this function doesn't work.
+
+    Args:
+        model: root model.
+        inputs: inputs to the root model.
+        target_submodules: submodules that we want to generate inputs for.
+
+    Returns:
+        A dict that maps from submodule name to its inputs.
+    """
+
+    handles = []
+    results = {}
+    submodule_to_names = {mod: name for name, mod in model.named_modules()}
+
+    def pre_forward(module, module_inputs):
+        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs
+
+    for name, mod in model.named_modules():
+        if name in target_submodules:
+            handles.append(mod.register_forward_pre_hook(pre_forward))
+
+    def clean_up_handles():
+        for h in handles:
+            h.remove()
+
+    try:
+        with torch.no_grad():
+            model(*inputs)
+    except Exception as e:
+        clean_up_handles()
+        raise e
+
+    clean_up_handles()
+    return results
+
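+# Illustrative sketch (not part of the upstream file): capturing the inputs seen
+# by a named submodule during one forward pass. The submodule name used here is
+# hypothetical.
+def _example_capture_submodule_inputs(model: torch.nn.Module, sample_inputs: Sequence[Any]) -> Dict[str, Any]:  # pragma: no cover
+    # Returns e.g. {"_run_on_acc_0": (tensor_0, ...)} for the requested names.
+    return generate_inputs_for_submodules(model, sample_inputs, ["_run_on_acc_0"], deepcopy=True)
+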
+
+class _SplitterBase:
+    """
+    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
+    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
+    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
+
+    Given the following graph:
+          ==> b ==>
+        //         \\
+       a             d
+        \\         //
+          ==> c ==>
+
+    class SimpleModule(torch.nn.Module):
+        def forward(self, a):
+            b = torch.sin(a)
+            c = torch.cos(a)
+            d = b + c
+            return d
+
+    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
+    we will get the following split result:
+
+    main:
+    def forward(self, a):
+        run_on_acc_0_0 = self._run_on_acc_0_0(a)
+        getitem = run_on_acc_0_0[0]
+        getitem_1 = run_on_acc_0_0[1]
+        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
+        return run_on_cpu_1_1
+
+    _run_on_acc_0_0:
+    def forward(self, a):
+        sin_1 = torch.sin(a)
+        cos_1 = torch.cos(a)
+        return (sin_1, cos_1)
+
+    _run_on_cpu_1_1:
+    def forward(self, sin_1, cos_1):
+        add_1 = sin_1 + cos_1
+        return add_1
+    """
+
+    # PCIe bandwidth for the backend, default to 100 GB/s
+    PCIe_BW = 100 * 2 ** 30
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        sample_input: Sequence[Any],
+        operator_support: OperatorSupportBase,
+        settings: _SplitterSettingBase,
+        non_acc_submodule_name: str = "_run_on_cpu_",
+    ):
+        """
+        Preprocesses graph before splitting:
+        - finds nodes supported by ACC,
+        - finds fusion groups for ACC nodes having non-tensor IO,
+        - builds a graph of direct dependencies,
+        - builds a map of fused nodes to their fusions.
+        As a result we get self.acc_nodes, self.deps and self.fusions.
+        """
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        ShapeProp(self.module).propagate(*sample_input)
+
+        self.settings = settings
+        self.operator_support = operator_support
+        self.sample_input = sample_input
+        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
+
+        if self.settings.skip_fusion:
+            self.fusions = {}
+        else:
+            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
+
+        # Modify deps to add more deps for fused nodes
+        self.deps = self.find_deps()
+        self.update_deps_for_fusions()
+
+        self.non_acc_submodule_name = non_acc_submodule_name
+        self._node_submodule_map: Dict[str, str] = {}
+
+    # ===============================================================
+    # Helpers for ctor and initial state
+    # ===============================================================
+
+    def get_node_submodule_map(self) -> Dict[str, str]:
+        """ Returns a map from node name to submodule name, e.g.
+            node: main_module_impl_impl_over_arch_unary_multiple_embedding
+              _pooling_embedding_pooling_sparse_entity_equivalence_key
+              _proxy_embedding_bag
+            maps to submodule name of: _run_on_acc_1
+        """
+        return self._node_submodule_map
+
+    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
+        """
+        Builds a graph of node dependencies. Leaf nodes don't have any
+        dependencies and the "output" node doesn't have nodes depending on it.
+
+        Resulting graph has only direct dependencies, i.e. there are no
+        transitive dependencies.
+        """
+        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            for user in node.users:
+                if user.op != "output":
+                    deps[user].add(node)
+        return deps
+
+    def update_deps_for_fusions(self):
+        """
+        Updates graph of dependencies so that:
+        - nodes from the same fusion depend on the same set of outer nodes,
+        - outer nodes depending on a fusion depend on all nodes in that fusion.
+        """
+        for node in self.fusions:
+            fusion = self.fusions[node]
+            for fused_neighbor in fusion:
+                self.deps[node].update(self.deps[fused_neighbor] - fusion)
+
+                for user in fused_neighbor.users:
+                    if user not in fusion:
+                        self.deps[user].add(node)
+
+    # ===============================================================
+    # Helpers for preview
+    # ===============================================================
+
+    def _lower_model_to_backend(
+        self, mod: torch.fx.GraphModule, inputs: Tensors
+    ) -> torch.nn.Module:
+        """
+        Lower the model to a backend.
+        """
+
+        return mod
+
+    def _find_culprit(
+        self, mod: torch.fx.GraphModule, inputs: Tensors
+    ) -> str:
+        """
+        When an error occurs during lowering or running the lowered mod, we use this
+        function to find the culprits in `mod` that cause the error.
+        """
+
+        return "Unable to find a culprit because _find_culprit() function is not implemented."
+
+    def _draw_graph_based_on_node_support(
+        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
+    ):
+        color_map = {
+            "default": "AliceBlue",
+            "supported": "chartreuse1",
+            "unsupported": "crimson",
+        }
+
+        class CustomDrawer(FxGraphDrawer):
+            def _get_node_style(self, node):
+                template = super()._get_node_style(node)
+                if node in supported_nodes:
+                    template["fillcolor"] = color_map["supported"]
+                elif node.op in CALLABLE_NODE_OPS:
+                    template["fillcolor"] = color_map["unsupported"]
+                else:
+                    template["fillcolor"] = color_map["default"]
+
+                return template
+
+        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
+        dot_graph = drawer.get_main_dot_graph()
+        dot_graph.write_raw("node_support.dot")
+
+    def node_support_preview(self, dump_graph: bool = False):
+        submodules = dict(self.module.named_modules())
+
+        supported_nodes: NodeList = []
+        supported_node_types = defaultdict(set)
+        unsupported_node_types = defaultdict(set)
+
+        def get_dtype(arg):
+            tensor_meta = arg.meta.get("tensor_meta")
+            return getattr(tensor_meta, "dtype", None)
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            target = get_node_target(submodules, node)
+
+            # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
+            arg_dtypes = [
+                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
+                for arg in node.args
+            ]
+
+            # Find the index just past the last non-None element; it is 0 if all elements are None.
+            last_index = len(arg_dtypes) - next(
+                (
+                    i
+                    for i, dtype in enumerate(reversed(arg_dtypes))
+                    if dtype is not None
+                ),
+                len(arg_dtypes),
+            )
+
+            # Strip None elements at the end.
+            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
+            kwarg_dtypes_tuple = tuple(
+                (k, get_dtype(arg))
+                for k, arg in node.kwargs.items()
+                if isinstance(arg, torch.fx.Node)
+            )
+
+            if self.operator_support.is_node_supported(submodules, node):
+                supported_nodes.append(node)
+                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
+            else:
+                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
+
+        if dump_graph:
+            self._draw_graph_based_on_node_support(self.module, supported_nodes)
+
+        reports = "\nSupported node types in the model:\n"
+        for t, dtypes in supported_node_types.items():
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
+                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+
+        reports += "\nUnsupported node types in the model:\n"
+        for t, dtypes in unsupported_node_types.items():
+            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
+                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
+
+        print(reports)
+
+        # Return reports for testing purpose
+        return reports
+
+    def split_preview(self, dump_graph: bool = False):
+        reports = ""
+        subgraphs = self.put_nodes_into_subgraphs()
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
+        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
+        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
+        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
+
+        for i, subgraph in enumerate(subgraphs):
+            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
+            reports += f"{len(subgraph.nodes)} node(s)\n"
+
+        self.tag(subgraphs)
+        split_mod = self.split(remove_tag=True)
+        split_mod.eval()
+
+        if dump_graph:
+            drawer = FxGraphDrawer(
+                split_mod, "preview", ignore_getattr=True
+            )
+            dot_graphs = drawer.get_all_dot_graphs()
+            for name, dot_graph in dot_graphs.items():
+                dot_graph.write_raw(f"{name}.dot")
+
+        max_qps: float = self.PCIe_BW
+        bottleneck_module = ""
+
+        for node in split_mod.graph.nodes:
+            if node.op == "call_module" and "acc" in node.target:
+                reports += f"\nProcessing acc submodule {node.target}\n"
+
+                submod = getattr(split_mod, node.target)
+
+                def get_submod_inputs(main_mod, submod, example_inputs):
+                    sub_inputs = None
+
+                    def get_inputs(self, inputs):
+                        nonlocal sub_inputs
+                        sub_inputs = inputs
+
+                    handle = submod.register_forward_pre_hook(get_inputs)
+                    main_mod(*example_inputs)
+                    handle.remove()
+                    return sub_inputs
+
+                submod_inputs = get_submod_inputs(
+                    split_mod, submod, self.sample_input
+                )
+                ShapeProp(submod).propagate(*submod_inputs)
+
+                total_input_bytes = 0
+                total_output_bytes = 0
+
+                reports += "Checking inputs...\n"
+                for n in submod.graph.nodes:
+                    if n.op == "placeholder":
+                        if not is_node_output_tensor(n):
+                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
+                        else:
+                            total_input_bytes += get_size_of_node(submod, n)[0]
+                    if n.op == "output":
+                        output_node = n
+
+                reports += "Checking outputs...\n"
+
+                def get_bytes(node: torch.fx.Node):
+                    nonlocal total_output_bytes
+                    nonlocal reports
+                    if not is_node_output_tensor(node):
+                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
+                    else:
+                        total_output_bytes += get_size_of_node(submod, node)[0]
+
+                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
+                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
+                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
+                reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
+
+                if qps < max_qps:
+                    max_qps = qps
+                    bottleneck_module = node.target
+
+                try:
+                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during lowering!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                    continue
+
+                try:
+                    lowered_submod(*submod_inputs)
+                except RuntimeError:
+                    reports += "Run into an error during inference!\n"
+                    reports += self._find_culprit(submod, submod_inputs)
+                else:
+                    reports += "Lowering and running succeed!\n"
+
+        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
+        reports += f" bottleneck is submodule {bottleneck_module}."
+        print(reports)
+
+        # return the reports for testing purposes
+        return reports
+
+    # ===============================================================
+    # Helpers for extend_acc_subgraph() method
+    # ===============================================================
+
+    def find_reverse_deps(
+        self, tag_id: Optional[int] = None
+    ) -> Dict[torch.fx.Node, NodeSet]:
+        """
+        Builds reversed node dependencies (node -> its users). If tag_id is specified,
+        we ignore users that are in a later subgraph, i.e. users with a greater tag_id.
+        """
+        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
+
+        for node in self.module.graph.nodes:
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+
+            for user in node.users:
+                if user.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
+                    result[node].add(user)
+
+        return result
+
+    def update_reverse_deps_for_fusions(
+        self, deps: Dict[torch.fx.Node, NodeSet]
+    ):
+        processed_node = set()
+
+        for node, fusion in self.fusions.items():
+            if node in processed_node:
+                continue
+
+            new_dep = set()
+
+            # Create a new dependency set which include all the
+            # dependencies of the nodes in the fusion group
+            for n in fusion:
+                new_dep.update(deps[n])
+
+            # Exclude nodes in the fusion
+            new_dep.difference_update(fusion)
+
+            # Update dependency
+            for n in fusion:
+                deps[n] = new_dep
+
+                for arg in n.all_input_nodes:
+                    if arg not in fusion:
+                        deps[arg].update(fusion)
+
+                processed_node.add(n)
+
+    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
+        """
+        Finds parent nodes of the `tag` subgraph.
+
+        Traverses the inputs of nodes in the subgraph; if an input doesn't belong to the subgraph
+        and is not a placeholder, we consider it a parent node of the subgraph.
+        """
+        parent_nodes = set()
+
+        for node in self.module.graph.nodes:
+            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
+                for arg in node.all_input_nodes:
+                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
+                        parent_nodes.add(arg)
+
+        return parent_nodes
+
+    def extend_acc_subgraph(self, tag: str):
+        """
+        Extend the acc subgraph identified by `tag`, walking in the reverse topological direction.
+        """
+        # Dict that maps a node to its users, ignoring users that
+        # are in a subgraph with a greater tag
+        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
+        self.update_reverse_deps_for_fusions(deps)
+
+        # Parent nodes of the subgraph
+        parent_nodes = self.find_parent_nodes_of_subgraph(tag)
+
+        visited_nodes: NodeSet = set()
+
+        while parent_nodes:
+            node = None
+
+            # Find an acc node that depends only on visited nodes
+            for n in parent_nodes:
+                if deps[n] <= visited_nodes and n in self.acc_nodes:
+                    node = n
+                    break
+
+            if node is None:
+                break
+
+            # Put the node into `tag` subgraph
+            node.tag = tag  # type: ignore[attr-defined]
+            parent_nodes.remove(node)
+            visited_nodes.add(node)
+
+            # If node is in a fusion group, add all fusion buddies to parent nodes
+            if node in self.fusions:
+                for fusion_node in self.fusions[node]:
+                    if fusion_node not in visited_nodes:
+                        parent_nodes.add(fusion_node)
+
+            # Add inputs of the node to parent nodes
+            for arg in node.all_input_nodes:
+                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
+                    parent_nodes.add(arg)
+
+    # ===============================================================
+    # Helpers for split() method
+    # ===============================================================
+
+    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+        """
+        Finds nodes that consume module inputs or get_attr nodes.
+        """
+        starter_cpu_nodes: NodeSet = set()
+        starter_acc_nodes: NodeSet = set()
+        for node in self.module.graph.nodes:
+            if node.op not in {"placeholder", "get_attr"}:
+                continue
+            for user in node.users:
+                if user in self.acc_nodes:
+                    starter_acc_nodes.add(user)
+                else:
+                    starter_cpu_nodes.add(user)
+        return starter_cpu_nodes, starter_acc_nodes
+
+    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
+        # We start graph traversal from leaf nodes
+        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
+        visited_nodes: NodeSet = set()
+
+        # Determine which subgraph to start from based on which subgraph has
+        # a node with zero dependencies
+        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
+
+        current_subgraph_nodes: NodeList = []
+
+        # Result accumulator
+        subgraphs: List[Subgraph] = []
+        while current_cpu_nodes or current_acc_nodes:
+            # Find the first node that should belong to the current subgraph and has all dependencies resolved
+            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
+            node = next(
+                (n for n in current_nodes if self.deps[n] <= visited_nodes),
+                None,
+            )
+
+            # If nothing was found, then it's time to flip the mode and start a new subgraph
+            if node is None:
+                if not current_subgraph_nodes:
+                    raise FxNetSplitterInternalError("Subgraph can't be empty")
+
+                subgraphs.append(
+                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+                )
+                acc_subgraph = not acc_subgraph
+                current_subgraph_nodes = []
+                continue
+
+            current_nodes.remove(node)
+            visited_nodes.add(node)
+            current_subgraph_nodes.append(node)
+
+            # Add fusion buddies
+            if node in self.fusions:
+                if node in self.acc_nodes:
+                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
+                else:
+                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)
+
+            # Put depending nodes into the queue
+            for user in node.users:
+                if user.op not in CALLABLE_NODE_OPS:
+                    continue
+
+                # Add downstream nodes
+                if user in self.acc_nodes:
+                    current_acc_nodes.add(user)
+                else:
+                    current_cpu_nodes.add(user)
+
+        # Check if the last subgraph was not created
+        if current_subgraph_nodes:
+            subgraphs.append(
+                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
+            )
+
+        if not subgraphs:
+            raise FxNetSplitterInternalError("Couldn't create subgraphs")
+
+        return subgraphs
+
+    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+        """
+        This pass finds acc subgraphs with fewer nodes than the specified minimum size and merges
+        them with adjacent CPU subgraphs.
+        """
+        result: List[Subgraph] = []
+        for subgraph in subgraphs:
+            if subgraph.is_acc:
+                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
+                    result.append(subgraph)
+                else:
+                    print(
+                        "Eliminating acc subgraph because it's smaller than the threshold: "
+                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
+                    )
+                    if result:
+                        result[-1].nodes.extend(subgraph.nodes)
+                    else:
+                        subgraph.is_acc = False
+                        result.append(subgraph)
+            else:
+                if result and not result[-1].is_acc:
+                    result[-1].nodes.extend(subgraph.nodes)
+                else:
+                    result.append(subgraph)
+        return result
+
+    def tag(self, subgraphs: List[Subgraph]):
+        self.tags: List[str] = []
+        for subgraph in subgraphs:
+            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
+            self.tags.append(tag)
+            for node in subgraph.nodes:
+                if hasattr(node, "tag"):
+                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")
+
+                node.tag = tag  # type: ignore[attr-defined]
+                self._node_submodule_map[node.name] = tag
+
+    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
+        split_module = split_by_tags(self.module, self.tags)
+        if remove_tag:
+            for node in self.module.graph.nodes:
+                if hasattr(node, "tag"):
+                    del node.tag
+        return split_module
+
+    def __call__(self) -> torch.fx.GraphModule:
+        subgraphs = self.put_nodes_into_subgraphs()
+        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
+        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
+        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
+        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
+        self.tag(subgraphs)
+        return self.split()
+
+    def generate_split_results(self) -> SplitResult:
+        split_module = self()
+        submodule_names = []
+        for name, mod in split_module.named_children():
+            submodule_names.append(name)
+        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
+        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
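+
+
+# Illustrative sketch (not part of the upstream file): the typical end-to-end flow
+# for a splitter built on _SplitterBase. The operator-support object and sample
+# inputs are placeholders supplied by the caller.
+def _example_split_workflow(
+    module: torch.fx.GraphModule,
+    sample_input: Sequence[Any],
+    operator_support: OperatorSupportBase,
+) -> SplitResult:  # pragma: no cover
+    splitter = _SplitterBase(module, sample_input, operator_support, _SplitterSettingBase())
+    splitter.node_support_preview()           # print which node types are supported
+    return splitter.generate_split_results()  # split and collect per-submodule inputs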
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/tests/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..061466ab8991a904303af8b28067103c4ba47bc4
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c537637b705f543b820a842e795913231650492
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py b/MLPY/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bcbbcbdd76710e0015e85361635296ffb0bd876
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py
@@ -0,0 +1,58 @@
+import unittest
+
+from ..pass_manager import (
+    inplace_wrapper,
+    PassManager,
+    these_before_those_pass_constraint,
+    this_before_that_pass_constraint,
+)
+
+
+class TestPassManager(unittest.TestCase):
+    def test_pass_manager_builder(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        pm = PassManager(passes)
+        pm.validate()
+
+    def test_this_before_that_pass_constraint(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        pm = PassManager(passes)
+
+        # add unfulfillable constraint
+        pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
+
+        self.assertRaises(RuntimeError, pm.validate)
+
+    def test_these_before_those_pass_constraint(self) -> None:
+        passes = [lambda x: 2 * x for _ in range(10)]
+        constraint = these_before_those_pass_constraint(passes[-1], passes[0])
+        pm = PassManager(
+            [inplace_wrapper(p) for p in passes]
+        )
+
+        # add unfulfillable constraint
+        pm.add_constraint(constraint)
+
+        self.assertRaises(RuntimeError, pm.validate)
+
+    def test_two_pass_managers(self) -> None:
+        """Make sure we can construct the PassManager twice and not share any
+        state between them"""
+
+        passes = [lambda x: 2 * x for _ in range(3)]
+        constraint = these_before_those_pass_constraint(passes[0], passes[1])
+        pm1 = PassManager()
+        for p in passes:
+            pm1.add_pass(p)
+        pm1.add_constraint(constraint)
+        output1 = pm1(1)
+        self.assertEqual(output1, 2 ** 3)
+
+        passes = [lambda x: 3 * x for _ in range(3)]
+        constraint = these_before_those_pass_constraint(passes[0], passes[1])
+        pm2 = PassManager()
+        for p in passes:
+            pm2.add_pass(p)
+        pm2.add_constraint(constraint)
+        output2 = pm2(1)
+        self.assertEqual(output2, 3 ** 3)
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/tools_common.py b/MLPY/Lib/site-packages/torch/fx/passes/tools_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ecc8daf3db75132c8d101ca34c27207737094eb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/tools_common.py
@@ -0,0 +1,273 @@
+from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
+import collections
+from dataclasses import dataclass
+
+import torch
+import torch.fx
+from torch.fx.node import _get_qualified_name
+from torch.fx._compatibility import compatibility
+
+__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']
+
+Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
+TensorOrTensors = Union[torch.Tensor, Tensors]
+NodeList = List[torch.fx.Node]
+NodeSet = Set[torch.fx.Node]
+Names = List[str]
+CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
+
+
+@compatibility(is_backward_compatible=False)
+def get_acc_ops_name(k):
+    if isinstance(k, str):
+        return k
+    elif k.__module__ and "acc_ops" in k.__module__:
+        return f"acc_ops.{k.__name__}"
+    else:
+        module = k.__module__.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
+        return f"{module if module else ''}.{k.__name__}"
+
+
+@compatibility(is_backward_compatible=False)
+def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
+    """
+    Given a `node` returns its target typename.
+
+    For "call_method" node, return node.target which is the name of that method being called.
+    This could potential lead to conflict but should be okay because normally it's on a tensor.
+
+    For "call_function" node, return typename of node.target.
+
+    For "call_module" node, return typename of the module that node.target point to.
+
+    If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
+    "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
+    """
+
+    assert node.op in CALLABLE_NODE_OPS, (
+        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
+    )
+
+    if node.op == "call_module":
+        assert isinstance(node.target, str)
+        submod = submodules[node.target]
+        submod_type = getattr(submod, "_base_class_origin", type(submod))
+        return get_acc_ops_name(submod_type)
+    elif node.op == "call_function":
+        target: Any = node.target
+        return (
+            f"acc_ops.{target.__name__}"
+            if target.__module__ is not None and "acc_ops" in target.__module__
+            else _get_qualified_name(target)
+        )
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
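+# Illustrative sketch (not part of the upstream file): what get_node_target
+# typically yields for each callable op kind in a traced module. The module here
+# is a placeholder.
+def _example_node_targets(gm: torch.fx.GraphModule) -> List[str]:  # pragma: no cover
+    submodules = dict(gm.named_modules())
+    # e.g. "torch.relu" for a call_function node, "torch.nn.modules.linear.Linear"
+    # for a call_module node, and the raw method name (e.g. "size") for call_method.
+    return [
+        get_node_target(submodules, n)
+        for n in gm.graph.nodes
+        if n.op in CALLABLE_NODE_OPS
+    ]
+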
+@compatibility(is_backward_compatible=False)
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+    """Checks if the node output produces a Tensor or not.
+
+    NOTE: This requires `ShapeProp` to be run on the containing fx graph before
+    calling this function. This is because it works by checking the `type`
+    metadata on the node, which is produced by `ShapeProp`.
+    """
+    type_ = node.meta.get("type", None)
+    return type_ is not None and issubclass(type_, torch.Tensor)
+
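+# Illustrative sketch (not part of the upstream file): is_node_output_tensor only
+# gives meaningful answers after shape propagation has populated node.meta["type"].
+# The GraphModule and example input are placeholders.
+def _example_tensor_output_nodes(gm: torch.fx.GraphModule, example_input: torch.Tensor) -> NodeList:  # pragma: no cover
+    from torch.fx.passes.shape_prop import ShapeProp  # local import; not needed at module level here
+    ShapeProp(gm).propagate(example_input)
+    return [n for n in gm.graph.nodes if is_node_output_tensor(n)]
+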
+@compatibility(is_backward_compatible=False)
+class FxNetAccFusionsFinder:
+    """
+    Finds groups of connected ACC nodes that pass non-tensor data between each other.
+    Such groups are called fusion groups.
+    """
+
+    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
+        self.module = module
+        self.nodes = list(module.graph.nodes)
+        self.acc_nodes = acc_nodes
+
+    @dataclass
+    class FusionGroup:
+        # The smallest index among the nodes in the fusion group, after topologically sorting all the nodes in the model.
+        top_node_idx: int
+
+        # Nodes in this fusion group.
+        nodes: NodeSet
+
+        # Inputs to this fusion group.
+        inputs: NodeSet
+
+        # Nodes in the fusion group that haven't been processed yet.
+        nodes_need_process: NodeSet
+
+        def add_node(self, node):
+            """
+            Add a node to fusion group.
+            """
+            if node in self.nodes:
+                return
+
+            self.nodes_need_process.add(node)
+            self.nodes.add(node)
+            self.inputs.discard(node)
+            self.inputs.update(
+                {
+                    n
+                    for n in node.all_input_nodes
+                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
+                }
+            )
+
+    def recursive_add_node(
+        self,
+        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
+        inputs: Union[NodeSet, NodeList],
+        visited: Optional[NodeSet] = None,
+    ):
+        """
+        Start from the inputs and go in reverse topological order. If any upstream node
+        is in the fusion group, add all the nodes on this path to the fusion group.
+        """
+        for arg in inputs:
+            # skip the node if already seen
+            if visited is not None:
+                if arg in visited:
+                    continue
+                visited.add(arg)
+
+            # Skip placeholder and get_attr because they won't be in the fusion group.
+            if arg.op not in CALLABLE_NODE_OPS:
+                continue
+
+            # If the node has smaller idx, it's already an upstream node of the fusion
+            # group. We don't need to check it anymore.
+            if self.nodes.index(arg) < fusion_group.top_node_idx:
+                continue
+
+            # If the node is in the fusion group, return True.
+            if arg in fusion_group.nodes:
+                return True
+
+            # Check the upstream nodes of the node, if any of them is in the fusion group
+            # we'll add this node to fusion group and return True.
+            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
+                fusion_group.add_node(arg)
+                return True
+
+        return False
+
+    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
+        result: Dict[torch.fx.Node, NodeSet] = {}
+        acc_nodes = list(self.acc_nodes)
+
+        for node in acc_nodes:
+            if node in result:
+                continue
+            if node.op not in CALLABLE_NODE_OPS:
+                continue
+            if "tensor_meta" in node.meta:
+                continue
+            if node not in self.acc_nodes:
+                continue
+
+            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
+                top_node_idx=self.nodes.index(node),
+                nodes={node},
+                inputs=set(node.all_input_nodes),
+                nodes_need_process={node},
+            )
+            while fusion_group.nodes_need_process:
+                node = fusion_group.nodes_need_process.pop()
+                self.recursive_add_node(
+                    fusion_group,
+                    fusion_group.inputs,
+                    visited=set(),
+                )
+
+                # Optionally add downstream nodes
+                if "tensor_meta" not in node.meta:
+                    for user in node.users:
+                        if user.op not in CALLABLE_NODE_OPS:
+                            continue
+                        if user in fusion_group.nodes:
+                            continue
+
+                        fusion_group.add_node(user)
+                        self.recursive_add_node(
+                            fusion_group,
+                            fusion_group.inputs,
+                            visited=set(),
+                        )
+
+                # Add some upstream nodes
+                for arg in node.all_input_nodes:
+                    if arg.op not in CALLABLE_NODE_OPS:
+                        continue
+                    if "tensor_meta" in arg.meta:
+                        continue
+                    if arg in fusion_group.nodes:
+                        continue
+
+                    fusion_group.add_node(arg)
+                    fusion_group.top_node_idx = min(
+                        fusion_group.top_node_idx, self.nodes.index(arg)
+                    )
+                    self.recursive_add_node(
+                        fusion_group,
+                        fusion_group.inputs,
+                        visited=set(),
+                    )
+
+            if not (set(fusion_group.nodes) <= self.acc_nodes):
+                self.acc_nodes -= fusion_group.nodes
+            else:
+                for n in fusion_group.nodes:
+                    result[n] = fusion_group.nodes
+
+        return result
+
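+# Illustrative sketch (not part of the upstream file): finding fusion groups for a
+# set of acc nodes. `acc_nodes` would usually come from FxNetAccNodesFinder in
+# splitter_base.py; here it is a placeholder.
+def _example_find_fusions(gm: torch.fx.GraphModule, acc_nodes: NodeSet) -> Dict[torch.fx.Node, NodeSet]:  # pragma: no cover
+    # Each key is a node that belongs to some fusion group; its value is the whole group.
+    return FxNetAccFusionsFinder(gm, acc_nodes)()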
+
+@compatibility(is_backward_compatible=False)
+def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Replace the graph of the given GraphModule with one that contains the same nodes as the
+    original, but in topologically sorted order.
+
+    This is used, for example, by the merge_matmul transformation, which disturbs the topologically
+    sorted order of its input GraphModule, so that this order is restored before further transformation.
+
+    Arguments:
+        gm: The graph module to topologically sort. It is modified in-place.
+
+    Returns:
+        The graph module, sorted in place.
+    """
+    indeg = dict.fromkeys(gm.graph.nodes, 0)
+    new_graph = torch.fx.Graph()
+    # Track how many unfulfilled dependencies each node has
+    for node in gm.graph.nodes:
+        for user in node.users:
+            indeg[user] += 1
+    queue: collections.deque = collections.deque()
+    # Add all nodes with no dependencies to the queue
+    for node in gm.graph.nodes:
+        if indeg[node] == 0:
+            queue.append(node)
+    env: Dict[torch.fx.Node, torch.fx.Node] = {}
+    # Pop nodes from the queue, and add nodes that have had all their
+    # dependencies fulfilled
+    while len(queue) > 0:
+        cur = queue.popleft()
+        env[cur] = new_graph.node_copy(cur, lambda x: env[x])
+        for user in cur.users:
+            indeg[user] -= 1
+            if indeg[user] == 0:
+                queue.append(user)
+    # If the new graph's size is not as large as the old one, then there must be
+    # a cycle (i.e. some node's dependencies were not satisfied.)
+    if len(new_graph.nodes) < len(gm.graph.nodes):
+        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
+    new_graph._codegen = gm.graph._codegen
+    gm.graph = new_graph
+    return gm
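+
+
+# Illustrative sketch (not part of the upstream file): legalize_graph is handy after
+# passes that insert nodes out of topological order, e.g. manual graph surgery.
+def _example_legalize(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:  # pragma: no cover
+    gm = legalize_graph(gm)  # rebuilds gm.graph with nodes in topological order
+    gm.recompile()           # regenerate the generated forward() for the new order
+    return gm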
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__init__.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bbbabf92272a3979ec9197c5ed3cd38352d2472
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/__init__.py
@@ -0,0 +1 @@
+from .common import lift_subgraph_as_module, HolderModule, compare_graphs
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71804d0bf6960ed9a8f85c819fe679be9719dd3a
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e47d0d25b587619c669383b34021b5223aa2683
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f869bffbc69c0930d72cf8d69ecba8127de88c62
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb62a5d57b31fd079d6e8f68c77a7c6c818ae966
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..749025d795688e58a9b58c6d882d9d48e560cec2
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22e6407d92eb285b07967d18393e0431e626679d
Binary files /dev/null and b/MLPY/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc differ
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/common.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4a04920d5b17d997a93588154103f11247a562
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/common.py
@@ -0,0 +1,95 @@
+from typing import Dict, Tuple
+
+from torch.fx._compatibility import compatibility
+from torch.fx.graph import Graph
+
+from torch.fx.graph_module import GraphModule
+from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+from torch.nn import Module
+
+
+__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
+
+
+@compatibility(is_backward_compatible=False)
+class HolderModule(Module):
+    """
+    HolderModule is used to copy all the attributes from the original module to the submodules
+    that use those attributes.
+    """
+
+    def __init__(self, d):
+        super().__init__()
+        for k, v in d.items():
+            self.add_module(k, v)
+
+
+@compatibility(is_backward_compatible=False)
+def lift_subgraph_as_module(
+    gm: GraphModule,
+    subgraph: Graph,
+    comp_name: str = "",
+    class_name: str = "GraphModule",
+) -> Tuple[GraphModule, Dict[str, str]]:
+    """
+    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
+
+    Args:
+        gm (GraphModule): parent graph module
+
+        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
+
+        comp_name (str): name for the new component
+
+        class_name (str): name for the submodule
+
+    """
+
+    # Loop through all module calls (call_module) and param fetches (get_attr)
+    # in this component, creating HolderModules as necessary to match the path.
+    # e.g. if the original module has a get_attr node that fetches "conv.weight",
+    # we create a HolderModule as root -> add a HolderModule named "conv" ->
+    # make "weight" an attribute of the "conv" HolderModule that points to conv.weight in
+    # the original module.
+    submodule = HolderModule({})
+    orig_to_split_fqn_mapping: Dict[str, str] = {}
+    for n in subgraph.nodes:
+        if n.op not in ("call_module", "get_attr"):
+            continue
+
+        target = n.target
+        assert isinstance(target, str)
+        target_name_parts = target.split(".")
+        curr = submodule
+        orig_gm = gm
+
+        for name in target_name_parts[:-1]:
+            if not hasattr(curr, name):
+                curr.add_module(name, HolderModule({}))
+
+            curr = getattr(curr, name)
+            orig_gm = getattr(orig_gm, name)
+
+        leaf_node_name = target_name_parts[-1]
+        leaf_node = getattr(orig_gm, leaf_node_name)
+
+        orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
+        # Relies on custom __setattr__ magic.
+        setattr(curr, leaf_node_name, leaf_node)
+
+    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
+
+
+@compatibility(is_backward_compatible=False)
+def compare_graphs(left: Graph, right: Graph) -> bool:
+    """
+    Return True if two graphs are identical, i.e. they
+        - have the same number of outputs in the same order
+        - have the same number of inputs in the same order
+        - have the same set of nodes, and identical connectivity
+    """
+
+    matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
+    matches = matcher.match(right)
+
+    return len(matches) > 0
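+
+
+# Illustrative sketch (not part of the upstream file): comparing the graphs of two
+# independently traced copies of the same module. The modules are placeholders
+# supplied by the caller.
+def _example_compare_traced(module_a: Module, module_b: Module) -> bool:  # pragma: no cover
+    import torch.fx
+    graph_a = torch.fx.symbolic_trace(module_a).graph
+    graph_b = torch.fx.symbolic_trace(module_b).graph
+    return compare_graphs(graph_a, graph_b)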
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..54bf8364c4d4633d5a054870337bb0bba05d04db
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py
@@ -0,0 +1,233 @@
+import copy
+from queue import SimpleQueue
+from typing import List, Dict, Tuple
+
+import torch.fx
+from torch.fx.graph_module import GraphModule
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
+from torch.fx.passes.utils import lift_subgraph_as_module
+from torch.fx._compatibility import compatibility
+
+@compatibility(is_backward_compatible=False)
+def topo_sort(nodes: NodeList) -> NodeList:
+    # sort nodes according to the topological order
+    indegree_map = dict.fromkeys(nodes, 0)
+    candidates: SimpleQueue = SimpleQueue()
+
+    for node in nodes:
+        for n in node.all_input_nodes:
+            if n in indegree_map:
+                indegree_map[node] += 1
+        if indegree_map[node] == 0:
+            candidates.put(node)
+
+    sorted_nodes: NodeList = list()
+    while not candidates.empty():
+        node = candidates.get()
+        sorted_nodes.append(node)
+
+        for n in node.users:
+            if n in indegree_map:
+                indegree_map[n] -= 1
+                if indegree_map[n] == 0:
+                    candidates.put(n)
+
+    assert len(nodes) == len(sorted_nodes), "topologically sorted nodes don't have the same length as the input nodes"
+
+    return sorted_nodes
+
+
+@compatibility(is_backward_compatible=False)
+def validate_partition(partition: NodeList) -> bool:
+    # verify the partition doesn't form a dependency cycle in the original graph
+    # returns True for a valid partition, False for an invalid one
+
+    partition_set = set(partition)
+
+    outputs: NodeList = list()
+    for node in partition_set:
+        for user_node in node.users:
+            if user_node not in partition_set:
+                # external user node, need to expose as an output
+                outputs.append(user_node)
+
+    # Perform BFS on the partition outputs.
+    # If it reaches a node within the partition, then it found a cycle.
+    # This function takes the ownership of `root_nodes` and may modify it.
+    def bfs_find_cycle(root_nodes: NodeList) -> bool:
+        # Set used to exclude nodes that have already been visited.
+        # If a node has been visited, that node and all its children have
+        # been checked for cycles.
+        visited: NodeSet = set()
+
+        # Start with `root_nodes` and traverse through (toward child nodes)
+        # their connected sub-graph. Nodes in `visited` won't be added
+        # to `queue` again.
+        queue: NodeList = root_nodes
+        while queue:
+            current = queue.pop()
+            visited.add(current)
+            if current in partition_set:
+                # Started from partition's `output` nodes, and reached
+                # another node in partition. Cycle!
+                return True
+            for user_node in current.users:
+                if user_node in visited:
+                    continue
+                queue.append(user_node)
+        # `root_nodes` don't cause cycle.
+        return False
+
+    # Use all output nodes as roots to traverse
+    # the graph to check cycles.
+    if bfs_find_cycle(outputs):
+        return False
+
+    return True
+
+
+@compatibility(is_backward_compatible=False)
+def fuse_as_graphmodule(gm: GraphModule,
+                        nodes: NodeList,
+                        module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
+
+    """
+    Fuse nodes in graph_module into a GraphModule.
+
+    Args:
+        gm (GraphModule): target graph_module
+
+        nodes (List[Node]): list of nodes in `gm` to fuse, where the nodes must be topologically sorted
+
+        module_name: class name for the fused GraphModule
+
+    Returns:
+        fused_gm (GraphModule): fused graph module, whose nodes are copies of `nodes` in `gm`
+
+        original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm`
+
+        original_outputs (Tuple[Node, ...]): nodes of `nodes` in the original `gm` whose outputs are consumed outside of `nodes`
+
+    """
+
+    # assumption: nodes are already sorted in topo order
+
+    for node in nodes:
+        assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
+        assert not node._erased, f"{node} has been removed from owning graph"
+        assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}"
+
+    # validates partition doesn't introduce dependency circles in the graph
+    assert validate_partition(nodes), "Invalid partition, found dependency cycles"
+
+    subgraph = Graph()
+
+    node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
+    node_map: Dict[Node, Node] = {}       # mapping of nodes from old graph to new graph
+
+    # handles inputs through graph.node_copy's arg_transform functions
+    def remap_inputs(x):
+        if x.op == "get_attr":
+            # TODO: do we really need to copy the get_attr node into the graph?
+            # do something here
+            pass
+
+        if x in nodes:
+            # x is inside subgraph, return the copied node
+            # the node should have been copied already, as we are copying the graph in topological order
+            return node_map[x]
+
+        if x not in node_to_placeholder:
+            # x is not in subgraph, create a new placeholder for subgraph
+            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
+            # copy all meta fields, even if some fields might be irrelevant for the placeholder node
+            placeholder_node.meta = copy.copy(x.meta)
+            node_to_placeholder[x] = placeholder_node
+
+        return node_to_placeholder[x]
+
+    # copy nodes in topological order
+    for node in nodes:
+        new_node = subgraph.node_copy(node, remap_inputs)
+        node_map[node] = new_node
+
+    # handles outputs
+    output_mapping: Dict[Node, Node] = {}  # mapping from old output to new outputs
+
+    for node in nodes:
+        for user_node in node.users:
+            if user_node not in nodes:
+                # external user node, need to expose as an output
+                output_mapping[node] = node_map[node]
+
+    # outs contain nodes in the new subgraph
+    outs = tuple(output_mapping.values())
+
+    # Take care of the args of the FX output node. If there's a single
+    # output, the output node's args look like (output_single,); if there
+    # are multiple outputs, they look like ((output_0, output_1, ...),).
+    subgraph.output(outs[0] if len(outs) == 1 else outs)
+
+    # lint to ensure correctness
+    subgraph.lint()
+    fused_gm: GraphModule
+    fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)
+
+    # sub_gm's input nodes in the original module
+    original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
+
+    # sub_gm's output nodes in the original module
+    original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())
+
+    return fused_gm, original_inputs, original_outputs
+
+
+@compatibility(is_backward_compatible=False)
+def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
+    # add sub_gm into gm
+    submodule_name = sub_gm.__class__.__name__
+    gm.add_submodule(submodule_name, sub_gm)
+
+    # Create a call_module node in main graph.
+    module_node = gm.graph.call_module(
+        submodule_name,
+        args=orig_inputs,
+        kwargs=None)
+
+    if len(orig_outputs) == 1:
+        # main_remapping[comp.orig_outputs[0]] = module_node
+        orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
+    else:
+        for i, orig_output in enumerate(orig_outputs):
+            # Use Proxy to record getitem access.
+            proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
+            orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
+    return gm
+
+@compatibility(is_backward_compatible=False)
+def erase_nodes(gm: GraphModule, nodes: NodeList):
+
+    # erase original nodes in reverse topological order
+    for node in reversed(nodes):
+        gm.graph.erase_node(node)
+
+
+@compatibility(is_backward_compatible=False)
+def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
+    for partition_id, nodes in enumerate(partitions):
+        sorted_nodes = topo_sort(nodes)
+
+        submodule_name = "fused_" + str(partition_id)
+        sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)
+
+        insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
+
+        erase_nodes(gm, sorted_nodes)
+
+    # topological sort original gm with newly created sub_gm
+    legalize_graph(gm)
+
+    return gm
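+
+
+# Illustrative sketch (not part of the upstream module): fusing a hand-picked
+# group of nodes into a single submodule, assuming FX's default node naming
+# ("add", "relu"). After the call, the main graph holds a call_module node
+# targeting the generated "fused_0" submodule.
+def _example_fuse_by_partitions() -> GraphModule:
+    import torch
+    from torch.fx import symbolic_trace
+
+    class Toy(torch.nn.Module):
+        def forward(self, x):
+            return torch.relu(x + 1) * 2
+
+    gm = symbolic_trace(Toy())
+    nodes = {n.name: n for n in gm.graph.nodes}
+    fuse_by_partitions(gm, [[nodes["add"], nodes["relu"]]])
+    gm.recompile()
+    print(gm.code)  # the add/relu pair is now hidden behind the fused submodule
+    return gm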
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1f9e84998ff0bf01517ff09bc7bccccdbfd634
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py
@@ -0,0 +1,400 @@
+from dataclasses import dataclass, field
+from collections import defaultdict
+import copy
+import torch
+from torch.fx import (
+    Node,
+    Graph,
+)
+from torch.fx._compatibility import compatibility
+from typing import Dict, List, Set, Any, Union, Tuple
+import logging
+import os
+
+__all__ = ['SubgraphMatcher', 'InternalMatch']
+
+# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger():
+    logger = logging.getLogger(__name__)
+
+    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter("%(filename)s > %(message)s")
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+logger = _init_logger()
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class InternalMatch:
+    # Nodes from which the match was found
+    anchors: List[Node]
+    # Maps nodes in the pattern subgraph to nodes in the larger graph
+    nodes_map: Dict[Node, Node] = field(default_factory=dict)
+
+    # nodes in the target graph that matched the pattern's placeholder nodes
+    placeholder_nodes: List[Node] = field(default_factory=list)
+
+    # nodes in matched subgraph returned by output
+    returning_nodes: List[Node] = field(default_factory=list)
+
+    # map from a string name to a node in the target graph
+    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
+    name_node_map: Dict[str, Node] = field(default_factory=dict)
+
+    def __copy__(self):
+        return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
+                             placeholder_nodes=self.placeholder_nodes.copy(),
+                             returning_nodes=self.returning_nodes.copy())
+
+@compatibility(is_backward_compatible=False)
+class SubgraphMatcher:
+    def __init__(self, pattern: Graph,
+                 match_output: bool = False,
+                 match_placeholder: bool = False,
+                 remove_overlapping_matches: bool = True,
+                 ignore_literals: bool = False) -> None:
+        """
+        Args:
+            pattern: the targeted matching pattern, represented in fx.Graph.
+            match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
+                If False, output node is ignored during match.
+            match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
+                the targeted pattern. If False, placeholder nodes will be used as wildcards.
+            remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
+                will be returned.
+            ignore_literals: If True, will not check if literals are equal and
+                will instead treat them as wildcards.
+        """
+
+        self.pattern = pattern
+        self.match_output = match_output
+        self.match_placeholder = match_placeholder
+        self.remove_overlapping_matches = remove_overlapping_matches
+        self.ignore_literals = ignore_literals
+
+        if len(pattern.nodes) == 0:
+            raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")
+
+        for node in pattern.nodes:
+            if node.op != "output":
+                assert len(node.users) > 0, \
+                       "SubgraphMatcher cannot be initialized with a pattern that has dead code"
+
+        # TODO: assert pattern is a connected graph
+
+        self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
+        output_node = next(iter(reversed(pattern.nodes)))
+        # nodes returned by outputs
+        self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
+
+        self.pattern_anchors: List[Node] = []
+        if match_output:
+            self.pattern_anchors = [output_node]
+        else:
+            # If a node has output_node as the ONLY user, then this node is a graph sink,
+            # and should be matched against as an anchor
+            self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]
+
+    def _match_attributes(self, pn: Node, gn: Node) -> bool:
+        # Attribute matching is complicated. Right now we only support matching constant tensors.
+        assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
+        assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."
+
+        # TODO(tmanlaibaatar) should probably make this actual API
+        def _getattr(model: torch.fx.GraphModule, attr_name: str):
+            *prefix, field = attr_name.split(".")
+            t = model
+            for item in prefix:
+                t = getattr(t, item, None)  # type: ignore[assignment]
+                assert t is not None
+
+            return getattr(t, field)
+
+        pn_value = _getattr(pn.graph.owning_module, pn.target)
+        gn_value = _getattr(gn.graph.owning_module, gn.target)
+
+        if type(pn_value) != type(gn_value):
+            return False
+
+        # Don't require exact match on tensor values.
+        if isinstance(pn_value, torch.Tensor):
+            return isinstance(gn_value, torch.Tensor)
+        else:
+            raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
+        return False
+
+    def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
+        # if exact match for placeholder is not required, then use placeholder as a wildcard
+        if not self.match_placeholder and pn.op == "placeholder":
+            return True
+
+        if pn.op == gn.op:
+            if pn.op == "placeholder" or pn.op == "output":
+                return True
+            elif pn.op == "get_attr":
+                return self._match_attributes(pn, gn)
+            return pn.target == gn.target
+        return False
+
+    def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
+        # `lookup` represents all the nodes in `original_graph`
+        # that are part of `pattern`
+
+        # Placeholders can be used by other nodes in the graphs
+        lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
+
+        for gn, pn in lookup.items():
+            # nodes returned by output are allowed to be used in other areas of the graph
+            if pn in self.pattern_returning_nodes:
+                continue
+
+            for user in gn.users:
+                # If this node has users that were not in `lookup`, then it must leak out of the
+                # pattern subgraph
+                if user not in lookup:
+                    return False
+        return True
+
+    def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
+        non_overlapping_matches: List[InternalMatch] = list()
+        nodes_matched: Set[Node] = set()
+
+        for match in matches:
+            found_overlap = False
+            for pn, gn in match.nodes_map.items():
+                if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
+                    found_overlap = True
+                    break
+
+            if not found_overlap:
+                non_overlapping_matches.append(match)
+                for pn, gn in match.nodes_map.items():
+                    if pn.op not in {"placeholder", "output"}:
+                        nodes_matched.add(gn)
+        return non_overlapping_matches
+
+    def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
+        assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
+
+        if isinstance(pn, Node) and not isinstance(gn, Node):
+            if pn.op == "placeholder":
+                # Check if we've already matched these nodes in the current
+                # traversal
+                if pn in match.nodes_map:
+                    return match.nodes_map[pn] == gn
+
+                match.nodes_map[pn] = gn
+                return True
+            else:
+                return False
+        elif not isinstance(pn, Node) and isinstance(gn, Node):
+            return False
+        else:
+            return type(gn) == type(pn) and gn == pn
+
+    def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
+        logger.info("  matching %s to %s", pn, gn)
+
+        assert isinstance(pn, Node) and isinstance(gn, Node), f"pn and gn must be Node, pn: {pn}, gn: {gn}"
+
+        # Check if we've already matched these nodes in the current
+        # traversal
+        if pn in match.nodes_map:
+            return match.nodes_map[pn] == gn
+
+        # TODO: use a more efficient way to check if gn is matched before: two-way dict
+        if gn in match.nodes_map.values():
+            return False
+
+        if not self._nodes_are_equal(pn, gn):
+            return False
+
+        # Optimistically mark `pn` as a match for `gn`, and save a local copy of match
+        saved_match = copy.copy(match)
+        match.nodes_map[pn] = gn
+
+        # Placeholder is a wildcard and can be matched with any python object
+        # (including list/tuple)
+        if pn.op == "placeholder":
+            return True
+
+        # Recursively traverse upwards to check if `pn` is a true
+        # match for `gn`
+        match_found = True
+
+        def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
+            if len(args1) != len(args2):
+                return False
+
+            for a1, a2 in zip(args1, args2):
+                if isinstance(a1, Node) and isinstance(a2, Node):
+                    matched = self._match_nodes(a1, a2, match)
+                elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
+                    matched = _match_args(a1, a2)
+                else:
+                    matched = self._match_literals(a1, a2, match) or self.ignore_literals
+
+                if not matched:
+                    return False
+
+            return True
+
+        # Flatten all args/kwargs into 1 list of args
+        pn_args, gn_args = None, None
+        if (
+            (len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and
+            pn.op == "call_function" and
+            isinstance(pn.target, torch._ops.OpOverload)
+        ):
+            args_schema = pn.target._schema.arguments
+
+            def get_all_arguments(orig_args, orig_kwargs):
+                all_args = []
+                for i, schema in enumerate(args_schema):
+                    if schema.name in orig_kwargs:
+                        all_args.append(orig_kwargs[schema.name])
+                    elif not schema.kwarg_only and i < len(orig_args):
+                        all_args.append(orig_args[i])
+                    else:
+                        all_args.append(schema.default_value)
+                return all_args
+
+            pn_args = get_all_arguments(pn.args, pn.kwargs)
+            gn_args = get_all_arguments(gn.args, gn.kwargs)
+
+        elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()):
+            pn_args = list(pn.args)
+            gn_args = list(gn.args)
+            pn_args.extend(list(pn.kwargs.values()))
+            gn_args.extend(list(gn.kwargs.values()))
+        else:
+            match_found = False
+
+        match_found = (
+            match_found and
+            pn_args is not None and
+            gn_args is not None and
+            _match_args(pn_args, gn_args)
+        )
+
+        if not match_found:
+            # revert to saved_match before matching with current node
+            match = copy.copy(saved_match)
+            return False
+
+        return True
+
+    def match(self, graph: Graph) -> List[InternalMatch]:
+        """
+        Returns:
+            The matched subgraphs.
+            The returned subgraphs are fully self-contained, meaning the nodes (except placeholders
+            and nodes returned by output) can only be consumed by nodes within the matched subgraph.
+
+        Subgraph pattern matcher is implemented with the backtracking style in the following steps:
+
+        1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
+        are the "sinks" (nodes with no user other than the output node) of the pattern graph.
+        One pattern graph could have multiple anchors if it has multiple return values.
+
+        2. In the target graph, we identify the potential candidate nodes that can be matched
+        with each anchor. These anchor-candidate pairs are the starting points for
+        pairwise per-node matching.
+
+        3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
+        the pattern and target graphs. Every pattern node along the traversal path is compared
+        against the corresponding target node. If any comparison fails, the match for this
+        anchor-candidate pair fails. A match is found when the DFS finishes traversing the graph.
+        See `self._match_nodes` for more details.
+
+        4. In the case of multiple anchors, every anchor will need to find a match using step 3.
+        In addition, the matches found between anchors need to have a common intersection node
+        in order for the match to be valid. This is implemented with backtracking. See `backtracking`
+        for more details.
+
+        Notice: graph traversal must be done in reverse order because a tensor can have multiple
+        consumers but only a single producer. Only with reverse-order traversal can we jointly
+        walk the pattern and target graphs along a deterministic path.
+
+        Warning: In theory, this backtracking algorithm has **exponential** worst-case time
+        complexity. In practice, however, it's unlikely to blow up.
+
+        """
+        from torch.fx.passes.utils.fuser_utils import validate_partition
+
+        # find candidate nodes to match with pattern anchors
+        match_candidates: Dict[Node, List[Node]] = defaultdict(list)
+        for pattern_anchor in self.pattern_anchors:
+            for node in graph.nodes:
+                if self._nodes_are_equal(pattern_anchor, node):
+                    match_candidates[pattern_anchor].append(node)
+        match_candidates_list = list(match_candidates.items())
+
+        logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
+
+        matches: List[InternalMatch] = []
+
+        def backtracking(anchor_index, match):
+            if anchor_index == len(match_candidates_list):
+                match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
+                match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
+                matches.append(match)
+
+                logger.info("Found a match: %s\n", match)
+                return
+
+            pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
+            saved_match = copy.copy(match)
+
+            for node in candidate_nodes:
+                logger.info("Trying to match anchor %s to %s", pattern_anchor, node)
+
+                match_found = self._match_nodes(pattern_anchor, node, match)
+                if match_found:
+                    # match next anchor
+                    backtracking(anchor_index + 1, match)
+                else:
+                    logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)
+
+                # revert to saved_match before matching with current anchor
+                match = copy.copy(saved_match)
+
+        match = InternalMatch(anchors=self.pattern_anchors)
+        if match_candidates_list:
+            backtracking(0, match)
+
+        # filter out the matches where the subgraph is not fully_contained
+        before = len(matches)
+        matches = [match for match in matches if self._is_contained(match.nodes_map)]
+        after = len(matches)
+        if before != after:
+            logger.info("Filtered out %s matches because they are not fully contained", before - after)
+
+        # filter out the matches that form a cycle if the subgraph is fused
+        valid_matches = []
+        for match in matches:
+            matched_compute_nodes = \
+                [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
+            if validate_partition(matched_compute_nodes):
+                valid_matches.append(match)
+        if len(valid_matches) != len(matches):
+            logger.info("Filtered out %s matches because the matched subgraph would form a cycle if fused",
+                        len(matches) - len(valid_matches))
+        matches = valid_matches
+
+        if self.remove_overlapping_matches:
+            before = len(valid_matches)
+            matches = self._remove_overlapping_matches(valid_matches)
+            after = len(matches)
+            if before != after:
+                logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)
+
+        logger.info("Matches returned: %s", matches)
+
+        return matches
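+
+
+# Illustrative sketch (not part of the upstream module): matching an
+# add-followed-by-relu pattern inside a larger traced graph, assuming
+# torch.fx.symbolic_trace for both the pattern and the target.
+def _example_subgraph_matcher() -> List[InternalMatch]:
+    from torch.fx import symbolic_trace
+
+    def pattern(x):
+        return torch.relu(x + 1)
+
+    def target(x):
+        y = torch.relu(x + 1)
+        return y * 2
+
+    matcher = SubgraphMatcher(symbolic_trace(pattern).graph)
+    matches = matcher.match(symbolic_trace(target).graph)
+    # expected: one InternalMatch whose nodes_map sends pattern nodes to target nodes
+    return matches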
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e3b59aed14e1af47f7116f44ace6b8c09668d33
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py
@@ -0,0 +1,113 @@
+from typing import Dict, List, Tuple
+
+from torch.fx import Graph, GraphModule, Node
+
+from torch.fx._compatibility import compatibility
+from .matcher_utils import InternalMatch, SubgraphMatcher
+
+__all__ = ["SubgraphMatcherWithNameNodeMap"]
+
+
+def _split_to_graph_and_name_node_map(
+    gm: GraphModule,
+) -> Tuple[GraphModule, Dict[str, Node]]:
+    from torch.fx.graph import _PyTreeInfo
+    from torch.utils._pytree import tree_flatten, tree_unflatten
+
+    name_node_map = {}
+    for n in gm.graph.nodes:
+        if n.op == "output":
+            assert gm._out_spec is not None
+            output = tree_unflatten(n.args[0], gm._out_spec)
+            assert isinstance(
+                output, tuple
+            ), "Expecting the pattern graph to return a tuple"
+            assert (
+                len(output) >= 2
+            ), "Expecting the pattern graph to have at least two outputs"
+            *out, name_node_map = output
+            flattened, out_spec = tree_flatten(out)
+            assert isinstance(
+                name_node_map, Dict
+            ), "Expecting the input graph to have a dict output as the last element"
+            n.args = (flattened,)
+            orig_pytree_info = gm._graph._codegen.pytree_info
+            gm._graph._codegen.pytree_info = _PyTreeInfo(
+                orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
+            )
+    gm.recompile()
+    return gm, name_node_map
+
+
+@compatibility(is_backward_compatible=False)
+class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
+    """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name,
+    this requires pattern to have specific format (returning and additional dictionary at the output,
+    that has node name as key, and the node in the pattern graph as value, see Example for more details)
+
+    The difference from SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during
+    initialization, since we need to modify the graph (which requires recompiling the GraphModule).
+
+    Example::
+        def pattern(x, weight):
+            conv = F.conv2d(x, weight)
+            relu = F.relu(conv)
+            return relu, {"conv": conv, "relu": relu}
+
+        def target_graph(x, weight):
+            conv = F.conv2d(x, weight)
+            relu = F.relu(conv)
+            relu *= 2
+            return relu
+
+        pattern_gm = capture_pre_autograd_graph(pattern, example_inputs)
+        target_gm = capture_pre_autograd_graph(target_graph, example_inputs)
+        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
+        matches = matcher.match(target_gm)
+        for match in matches:
+            match.name_node_map["conv"].meta["annotation"] = ...
+
+    """
+
+    def __init__(
+        self,
+        pattern_gm: GraphModule,
+        match_output: bool = False,
+        match_placeholder: bool = False,
+        remove_overlapping_matches: bool = True,
+        ignore_literals: bool = False,
+    ) -> None:
+        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
+        self.name_node_map = name_node_map
+        super().__init__(
+            pattern_gm.graph,
+            match_output,
+            match_placeholder,
+            remove_overlapping_matches,
+            ignore_literals,
+        )
+
+    def match(self, graph: Graph) -> List[InternalMatch]:
+        """The returned InternalMatch will have name_node_map populated with a map
+        from node name (str) to the target node, e.g.
+        {"conv": target_conv_ndoe, "relu": target_relu_node}
+
+        This requires the pattern graph to return an additional
+        output mapping node names to nodes, e.g. instead of:
+        ```
+        def pattern(...):
+            ...
+            return relu
+        ```
+        we should do:
+        ```
+        def pattern(...):
+            ...
+            return relu, {"conv": conv, "relu": relu}
+        ```
+        """
+        internal_matches = super().match(graph)
+        for internal_match in internal_matches:
+            for k, n in self.name_node_map.items():
+                internal_match.name_node_map[k] = internal_match.nodes_map[n]
+        return internal_matches
diff --git a/MLPY/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/MLPY/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aea6cb4191cdc4d04b1e6dc941ac35d2b8490dc2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py
@@ -0,0 +1,144 @@
+from dataclasses import dataclass, field
+from torch.fx.graph import Graph
+from torch.fx.node import Node
+from torch.fx._compatibility import compatibility
+from typing import Dict, List, Any, Type, Optional, Callable
+import logging
+import os
+
+
+__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
+
+# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
+def _init_logger():
+    logger = logging.getLogger(__name__)
+
+    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter("%(filename)s > %(message)s")
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+logger = _init_logger()
+
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class SourcePartition:
+    # Nodes in a particular partition
+    nodes: List[Node]
+
+    # The source these nodes decomposed from
+    source: Any
+
+    # Nodes in the graph that are needed as inputs to the partition
+    input_nodes: List[Node] = field(default_factory=list)
+
+    # Nodes in the partition that are being used by nodes outside of the
+    # partition
+    output_nodes: List[Node] = field(default_factory=list)
+
+    # Parameters that are being used
+    params: List[Node] = field(default_factory=list)
+
+
+@compatibility(is_backward_compatible=False)
+def get_source_partitions(
+    graph: Graph,
+    wanted_sources: List[Any],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Dict[Any, List[SourcePartition]]:
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_sources: List of sources of nodes that were decomposed from this
+            source. This can be a function (ex. torch.nn.functional.linear) or a
+            leaf module type (ex. torch.nn.Linear).
+
+    Returns:
+        Dictionary mapping sources that were given to a list of SourcePartitions
+        that correspond to the list of nodes that were decomposed from the given
+        source.
+    """
+    modules: Dict[Type, Dict[str, List[Node]]] = {}
+
+    for node in graph.nodes:
+        # The "source_fn_stack" metadata should contain a tuple of a unique name
+        # for the source and either the source function (if the node was
+        # decomposed from a function) or the module type (if the node was
+        # decomposed from a leaf module).
+
+        if (source_fn_st := node.meta.get("source_fn_stack", None)) is None:
+            continue
+
+        source_fn = source_fn_st[-1]
+        if source_fn[1] not in wanted_sources:
+            continue
+
+        diff_modules = modules.setdefault(source_fn[1], {})
+        partition = diff_modules.setdefault(source_fn[0], [])
+        partition.append(node)
+
+    def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, Node) and arg not in nodes:
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node)
+
+            for user in node.users.keys():
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: Dict[Type[Any], List[SourcePartition]] = {}
+
+    if filter_fn:
+        # apply filter_fn to every node of each partition and drop the partitions
+        # that don't satisfy the filter condition
+        filtered_modules = {}
+        for tp, name_to_partition in modules.items():
+            filtered_name_to_partition = {
+                name: partition
+                for name, partition in name_to_partition.items()
+                if all(map(filter_fn, partition))
+            }
+            filtered_modules[tp] = filtered_name_to_partition
+        modules = filtered_modules
+
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
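+
+# Illustrative sketch (not part of the upstream module): grouping the nodes a
+# torch.nn.Linear layer decomposed into. The "source_fn_stack" metadata this
+# relies on is populated by export-style tracing (torch.export.export is
+# assumed here), not by plain symbolic_trace; if the metadata is absent the
+# result is simply an empty dict.
+def _example_get_source_partitions() -> Dict[Any, List[SourcePartition]]:
+    import torch
+
+    class Toy(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            return torch.relu(self.linear(x))
+
+    ep = torch.export.export(Toy(), (torch.randn(2, 4),))
+    return get_source_partitions(
+        ep.graph_module.graph,
+        [torch.nn.Linear, torch.nn.functional.linear],
+    )
+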
+
+@compatibility(is_backward_compatible=False)
+def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
+    """
+    Given two subgraphs A and B (in the form of a list of nodes), checks if
+    A has nodes connecting to at least one node in B -- aka there exists a node
+    in B that uses a node in A (not the other way around).
+    """
+
+    for node in reversed(subgraph1.nodes):
+        for user in node.users.keys():
+            if user in subgraph2.nodes:
+                return True
+    return False
diff --git a/MLPY/Lib/site-packages/torch/fx/proxy.py b/MLPY/Lib/site-packages/torch/fx/proxy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec706d1318f429181547894f24cf60c033545b6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/proxy.py
@@ -0,0 +1,565 @@
+# mypy: ignore-errors
+
+import enum
+import dis
+import copy
+import sys
+import torch
+import inspect
+import operator
+import traceback
+import collections
+
+from dataclasses import is_dataclass, fields
+
+
+from .graph import magic_methods, reflectable_magic_methods, Graph
+from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
+from .node import Target, Node, Argument, base_types, map_aggregate
+from ._compatibility import compatibility
+from .operator_schemas import check_for_mutable_operation
+import torch.fx.traceback as fx_traceback
+
+__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
+           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
+           'ScopeContextManager']
+
+
+@compatibility(is_backward_compatible=False)
+class Scope:
+    """ Scope object that records the module path and the module type
+    of a module. Scope is used to track the information of the module
+    that contains a Node in a Graph of GraphModule. For example::
+
+        class Sub(torch.nn.Module):
+            def forward(self, x):
+                # This will be a call_method Node in GraphModule,
+                # scope for this would be (module_path="sub", module_type=Sub)
+                return x.transpose(1, 2)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                self.sub = Sub()
+
+            def forward(self, x):
+                # This will be a call_method Node as well,
+                # scope for this would be (module_path="", None)
+                x = x.transpose(1, 2)
+                x = self.sub(x)
+                return x
+
+    """
+
+    def __init__(self, module_path: str, module_type: Any):
+        super().__init__()
+        self.module_path = module_path
+        self.module_type = module_type
+
+
+@compatibility(is_backward_compatible=False)
+class ScopeContextManager:
+    """ A context manager to track the Scope of Node during symbolic tracing.
+    When entering a forward function of a Module, we'll update the scope information of
+    the current module, and when we exit, we'll restore the previous scope information.
+    """
+
+    def __init__(
+        self,
+        scope: Scope,
+        current_scope: Scope,
+    ):
+        super().__init__()
+        # Keep a copy of prev scope to restore on exit
+        self._prev_scope = copy.copy(scope)
+        # Update scope to current scope
+        scope.module_path = current_scope.module_path
+        scope.module_type = current_scope.module_type
+        # Save a reference so we can restore it
+        self._scope = scope
+
+    def __enter__(self):
+        return self._scope
+
+    def __exit__(self, *args):
+        self._scope.module_path = self._prev_scope.module_path
+        self._scope.module_type = self._prev_scope.module_type
+        return
+
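+
+# Illustrative sketch (not part of the upstream module): how a tracer updates
+# and restores the current Scope while descending into a submodule's forward().
+def _example_scope_context_manager():
+    current = Scope("", None)                        # tracer-wide scope
+    with ScopeContextManager(current, Scope("sub", torch.nn.Linear)):
+        assert current.module_path == "sub"          # updated inside the submodule
+        assert current.module_type is torch.nn.Linear
+    assert current.module_path == ""                 # restored on exit
+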
+
+_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node", "quantization_tag"]
+
+
+@compatibility(is_backward_compatible=True)
+class TracerBase:
+    graph: Graph
+    record_stack_traces : bool = False
+    # Feature flag for mutable schema checking
+    # Enabled by default in 1.12
+    check_mutable_operations : bool = False
+    # Feature flag for assert tracing
+    trace_asserts : bool = False
+    # Feature flag for proxying accesses to buffer values
+    proxy_buffer_attributes : bool = False
+
+    # Name of the function to be traced. It will only be used when
+    # ``root`` is an instance of ``nn.Module``
+    traced_func_name: str = "forward"
+
+    # The current Scope (module path and module type) of the module being traced
+    scope : Scope
+
+    # Records the module call stack
+    module_stack: OrderedDict[str, Tuple[str, Any]]
+
+    # Mapping of node name to module scope
+    node_name_to_scope: Dict[str, Tuple[str, type]]
+
+    @compatibility(is_backward_compatible=True)
+    def create_node(self, kind : str, target : Target,
+                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
+                    type_expr : Optional[Any] = None) -> Node:
+        """
+        Inserts a graph node given target, args, kwargs, and name.
+
+        This method can be overridden to do extra checking, validation, or
+        modification of values used in node creation. For example, one might
+        want to disallow in-place operations from being recorded.
+        """
+        if kind == 'call_function' and self.check_mutable_operations:
+            check_for_mutable_operation(target, args, kwargs)
+
+        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
+        # TODO node_name_to_scope will be deprecated in favor of
+        # node.meta['nn_module_stack']
+        self.node_name_to_scope[node.name] = (
+            self.scope.module_path,
+            self.scope.module_type,
+        )
+        # Optionally set stack trace on the created Node for debugging purposes
+        if fx_traceback.has_preserved_node_meta():
+            current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
+
+            stack_trace = current_meta.get("stack_trace")
+            if stack_trace:
+                node.stack_trace = stack_trace
+            # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
+            # If other meta fields are needed, they can be added here
+            for field in _COPY_META_FIELDS:
+                if field in current_meta:
+                    node.meta[field] = copy.copy(current_meta[field])
+
+            # Here we decrement to account for the sequence_nr having
+            # just been incremented while tracing this lowered aten op.
+            new_seq_nr = torch.autograd._get_sequence_nr() - 1
+            # The sequence_nr increments every time a new autograd Node
+            # is created. During the FWD pass we store the sequence_nr
+            # corresponding to the last autograd Node created on this fx
+            # node's meta.  A single aten op can create multiple autograd
+            # nodes as is the case with in-place foreach ops. During the
+            # BWD pass we retrieve the sequence_nr stored on the current
+            # executing autograd Node. See NOTE [ Sequence Number ].
+            if current_meta.get("in_grad_fn", 0) > 0:
+                new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
+            node.meta["seq_nr"] = new_seq_nr
+
+        elif self.module_stack:
+            node.meta['nn_module_stack'] = copy.copy(self.module_stack)
+        return node
+
+    @compatibility(is_backward_compatible=True)
+    def proxy(self, node: Node) -> 'Proxy':
+        return Proxy(node, self)
+
+    @compatibility(is_backward_compatible=True)
+    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
+                     name: Optional[str] = None, type_expr : Optional[Any] = None,
+                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
+        '''
+        Create a Node from the given arguments, then return the Node
+        wrapped in a Proxy object.
+
+        If kind = 'placeholder', then we're creating a Node that
+        represents the parameter of a function. If we need to encode
+        a default parameter, we use the ``args`` tuple. ``args`` is
+        otherwise empty for ``placeholder`` Nodes.
+        '''
+
+        args_ = self.create_arg(args)
+        kwargs_ = self.create_arg(kwargs)
+        assert isinstance(args_, tuple)
+        assert isinstance(kwargs_, dict)
+
+        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
+
+        if not proxy_factory_fn:
+            proxy = self.proxy(node)
+        else:
+            proxy = proxy_factory_fn(node)
+
+        if self.record_stack_traces and not proxy.node.stack_trace:
+            user_frame = self._find_user_frame()
+            if user_frame:
+                summary = traceback.extract_stack(user_frame)
+                tb_lines = summary.format()
+                # stack_trace would have innermost frame at the bottom
+                proxy.node.stack_trace = ''.join(tb_lines)
+
+        return proxy
+
+    def _find_user_frame(self):
+        """
+        Find the Python stack frame executing the user code during
+        symbolic tracing.
+        """
+        # We have to do a little dance here. Basically, walk up the callstack and
+        # record the first frame not in the pytorch source. This is the frame executing
+        # the user code during tracing.
+        frame = inspect.currentframe()
+
+        pt_files = ['torch/fx/proxy.py',
+                    'torch/fx/_symbolic_trace.py',
+                    'torch/fx/experimental/proxy_tensor.py',
+                    'torch/_ops.py',
+                    'torch/_tensor.py',
+                    'torch/utils/_python_dispatch.py',
+                    'torch/_prims_common/wrappers.py',
+                    'torch/_refs/__init__.py',
+                    'torch/_refs/nn/functional/__init__.py',
+                    'torch/utils/_stats.py',
+                    ]
+        while frame:
+            frame = frame.f_back
+            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
+                break
+
+        if not frame:
+            return None
+
+        return frame
+
+    @compatibility(is_backward_compatible=True)
+    def create_arg(self, a: Any) -> Argument:
+        """
+        A method that lowers the objects seen as arguments during symbolic evaluation
+        into Argument types that can be stored in IR.
+
+        Can be overridden to support more trace-specific types.
+        """
+        if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
+            return a.__fx_create_arg__(self)
+        # aggregates
+        elif isinstance(a, tuple) and hasattr(a, '_fields'):
+            # NamedTuple constructors don't seem to like getting a generator
+            # expression as an argument to their constructor, so build this
+            # intermediate tuple and unpack it into the NamedTuple constructor
+            args = tuple(self.create_arg(elem) for elem in a)
+            return type(a)(*args)  # type: ignore[arg-type]
+        elif isinstance(a, (tuple, list)):
+            return type(a)(self.create_arg(elem) for elem in a)
+        elif isinstance(a, dict):
+            r = {}
+            for k, v in a.items():
+                # Check for invalid dict keys. We do not want a Proxy to appear
+                # anywhere within the key. Since keys can be collection types,
+                # we iterate through the key with map_aggregate
+                k = self.create_arg(k)
+
+                def no_node(arg):
+                    if isinstance(arg, Node):
+                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
+                                           f"Node. Got key: {k}")
+                map_aggregate(k, no_node)
+
+                r[k] = self.create_arg(v)
+            return r
+        elif isinstance(a, slice):
+            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
+
+        elif isinstance(a, range):
+            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
+
+        elif isinstance(a, torch._ops.OpOverload):
+            return a
+
+        if isinstance(a, Proxy):
+            # base case: we unwrap the Proxy object
+            return a.node
+
+        if is_dataclass(a):
+            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
+            return self.create_node("call_function", a.__class__, (), kwargs)
+
+        elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
+            return a
+        raise NotImplementedError(f"argument of type: {type(a)}")
+
+    @compatibility(is_backward_compatible=True)
+    def to_bool(self, obj: 'Proxy') -> bool:
+        """Called when a proxy object is being converted to a boolean, such as
+        when used in control flow.  Normally we don't know what to do because
+        we don't know the value of the proxy, but a custom tracer can attach more
+        information to the graph node using create_node and can choose to return a value.
+        """
+        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
+
+    @compatibility(is_backward_compatible=True)
+    def iter(self, obj: 'Proxy') -> Iterator:
+        """Called when a proxy object is being iterated over, such as
+        when used in control flow.  Normally we don't know what to do because
+        we don't know the value of the proxy, but a custom tracer can attach more
+        information to the graph node using create_node and can choose to return an iterator.
+        """
+        raise TraceError('Proxy object cannot be iterated. This can be '
+                         'attempted when the Proxy is used in a loop or'
+                         ' as a *args or **kwargs function argument. '
+                         'See the torch.fx docs on pytorch.org for a '
+                         'more detailed explanation of what types of '
+                         'control flow can be traced, and check out the'
+                         ' Proxy docstring for help troubleshooting '
+                         'Proxy iteration errors')
+
+    @compatibility(is_backward_compatible=True)
+    def keys(self, obj: 'Proxy') -> Any:
+        """Called when a proxy object is has the keys() method called.
+        This is what happens when ** is called on a proxy. This should return an
+        iterator it ** is suppose to work in your custom tracer.
+        """
+        return Attribute(obj, 'keys')()
+
+
+# used in Proxy object when just appending to the graph while not tracing.
+@compatibility(is_backward_compatible=True)
+class GraphAppendingTracer(TracerBase):
+    def __init__(self, graph: Graph):
+        super().__init__()
+        self.graph = graph
+        self.scope = Scope("", None)
+        self.module_stack = collections.OrderedDict()
+        self.node_name_to_scope = {}
+
+@compatibility(is_backward_compatible=False)
+def assert_fn(x):
+    assert x
+
+@compatibility(is_backward_compatible=True)
+class TraceError(ValueError):
+    pass
+
+@compatibility(is_backward_compatible=True)
+class Proxy:
+    """
+    ``Proxy`` objects are ``Node`` wrappers that flow through the
+    program during symbolic tracing and record all the operations
+    (``torch`` function calls, method calls, operators) that they touch
+    into the growing FX Graph.
+
+    If you're doing graph transforms, you can wrap your own ``Proxy``
+    method around a raw ``Node`` so that you can use the overloaded
+    operators to add additional things to a ``Graph``.
+
+    ``Proxy`` objects cannot be iterated. In other words, the symbolic
+    tracer will throw an error if a ``Proxy`` is used in a loop or as
+    an ``*args``/``**kwargs`` function argument.
+
+    There are two main ways around this:
+    1. Factor out the untraceable logic into a top-level function and
+    use ``fx.wrap`` on it.
+    2. If the control flow is static (i.e. the loop trip count is
+    based on some hyperparameter), the code can be kept in its original
+    position and refactored into something like::
+
+        for i in range(self.some_hyperparameter):
+            indexed_item = proxied_value[i]
+
+    For a more detailed description into the Proxy internals, check out
+    the "Proxy" section in `torch/fx/OVERVIEW.md`
+    """
+
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
+        if tracer is None:
+            # This allows you to create a Proxy object around a raw Node
+            tracer = GraphAppendingTracer(node.graph)
+        self.tracer = tracer
+        self.node = node
+
+    def __repr__(self) -> str:
+        return f'Proxy({self.node.name})'
+
+    def __getattr__(self, k) -> 'Attribute':
+        # note: not added to the graph yet, if this is a method call
+        # we peephole optimize to the method invocation
+        return Attribute(self, k)
+
+    def __call__(self, *args, **kwargs) -> 'Proxy':
+        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)
+
+    def __iter__(self) -> Iterator['Proxy']:
+        frame = inspect.currentframe()
+        assert frame is not None
+        calling_frame = frame.f_back
+        assert calling_frame is not None
+        inst_list = list(dis.get_instructions(calling_frame.f_code))
+        if sys.version_info >= (3, 11):
+            from bisect import bisect_left
+            inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
+        else:
+            inst_idx = calling_frame.f_lasti // 2
+        inst = inst_list[inst_idx]
+        if inst.opname == 'UNPACK_SEQUENCE':
+            return (self[i] for i in range(inst.argval))  # type: ignore[index]
+
+        return self.tracer.iter(self)
+
+    def __abs__(self):
+        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
+
+    def __bool__(self) -> bool:
+        if self.tracer.trace_asserts:
+            # check if this boolean is used in an assertion, bytecode pattern for assertions
+            # is pretty stable for Python 3.7--3.9
+            frame = inspect.currentframe()
+            assert frame is not None
+            calling_frame = frame.f_back
+            assert calling_frame is not None
+            insts = list(dis.get_instructions(calling_frame.f_code))
+            if sys.version_info >= (3, 11):
+                from bisect import bisect_left
+                cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
+            else:
+                cur = calling_frame.f_lasti // 2
+            inst = insts[cur]
+
+            if inst.opname == 'POP_JUMP_IF_TRUE':
+                first = insts[cur + 1]
+                assert inst.arg is not None
+                last = insts[inst.arg // 2 - 1]
+                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
+                                      or first.opname == 'LOAD_ASSERTION_ERROR')
+                if starts_with_assert and last.opname == 'RAISE_VARARGS':
+                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
+                    return True
+
+        return self.tracer.to_bool(self)
+
+    @compatibility(is_backward_compatible=True)
+    def keys(self):
+        return self.tracer.keys(self)
+
+    def __len__(self):
+        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
+                           "this call to be recorded, please call torch.fx.wrap('len') at "
+                           "module scope")
+
+    @classmethod
+    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
+        args = args if args else ()
+        kwargs = kwargs if kwargs else {}
+
+        tracers : Dict[Any, None] = {}
+
+        def find_tracer(a):
+            if isinstance(a, cls):
+                tracers[a.tracer] = None
+        torch.fx.node.map_aggregate(args, find_tracer)
+        torch.fx.node.map_aggregate(kwargs, find_tracer)
+
+        if len(tracers) > 1:
+            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
+                               f'trying to trace operations {orig_method}')
+        tracer = next(iter(tracers.keys()))
+
+        if isinstance(orig_method, torch._C.ScriptMethod):
+            args = (orig_method.owner,) + args
+            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
+        if torch.overrides.is_tensor_method_or_property(orig_method):
+            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
+        else:
+            if isinstance(orig_method, torch._ops.HigherOrderOperator):
+                # TODO: Define how to symbolically trace HigherOrderOperators
+                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
+            return tracer.create_proxy('call_function', orig_method, args, kwargs,
+                                       name=tracer.graph._target_to_str(orig_method.__name__))
+
+
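+# Illustrative sketch (not part of the upstream module): wrapping raw Nodes in
+# Proxy objects to append operations to an existing Graph without running a
+# full symbolic trace (a GraphAppendingTracer is created implicitly).
+def _example_proxy_graph_appending():
+    from torch.fx import Graph, GraphModule
+
+    g = Graph()
+    x = Proxy(g.placeholder("x"))
+    y = x + 1            # recorded as call_function(operator.add)
+    z = y.relu()         # recorded as call_method("relu")
+    g.output(z.node)
+    gm = GraphModule(torch.nn.Module(), g)
+    return gm(torch.ones(2))   # expected: tensor([2., 2.])
+
+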
+@compatibility(is_backward_compatible=True)
+class Attribute(Proxy):
+    @compatibility(is_backward_compatible=True)
+    def __init__(self, root: Proxy, attr: str):
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._node: Optional[Node] = None
+
+    @property
+    def node(self):
+        # the node for attributes is added lazily, since most will just be method calls,
+        # which do not rely on a separate getattr node being added to the graph
+        if self._node is None:
+            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+
+@compatibility(is_backward_compatible=False)
+class ParameterProxy(Proxy):
+    """
+    A special proxy which lets "shape", "size", "dim", and a few other
+    attribute accesses pass through to the underlying module parameter object,
+    so that conditional tests on these attributes will not throw an exception during tracing.
+    """
+    def __init__(self, tracer: TracerBase, node: Node, name, param):
+        super().__init__(node, tracer)
+        assert isinstance(param, torch.nn.Parameter)
+        self.param = param
+        self.name = name
+
+    def __repr__(self) -> str:
+        return f'ParameterProxy({self.name})'
+
+    @property
+    def shape(self):
+        return self.param.shape
+
+    def size(self):
+        return self.param.size()
+
+    def dim(self):
+        return self.param.dim()
+
+    @property
+    def ndim(self):
+        return self.param.ndim
+
+    def numel(self):
+        return self.param.numel()
+
+    def nelement(self):
+        return self.param.nelement()
+
+
+for method in magic_methods:
+    def _scope(method):
+        def impl(*args, **kwargs):
+            tracer = args[0].tracer
+            target = getattr(operator, method)
+            return tracer.create_proxy('call_function', target, args, kwargs)
+        impl.__name__ = method
+        as_magic = f'__{method.strip("_")}__'
+        setattr(Proxy, as_magic, impl)
+    _scope(method)
+
+def _define_reflectable(orig_method_name):
+    method_name = f'__r{orig_method_name.strip("_")}__'
+
+    def impl(self, rhs):
+        target = getattr(operator, orig_method_name)
+        return self.tracer.create_proxy('call_function', target, (rhs, self), {})
+    impl.__name__ = method_name
+    impl.__qualname__ = method_name
+    setattr(Proxy, method_name, impl)
+
+for orig_method_name in reflectable_magic_methods:
+    _define_reflectable(orig_method_name)
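+
+
+# Illustrative sketch (not part of the upstream module): reflected magic
+# methods let a literal appear on the left-hand side of an expression that
+# involves a Proxy.
+def _example_reflected_op() -> Graph:
+    g = Graph()
+    x = Proxy(g.placeholder("x"))
+    y = 10 - x           # dispatches to Proxy.__rsub__, recording operator.sub(10, x)
+    g.output(y.node)
+    return g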
diff --git a/MLPY/Lib/site-packages/torch/fx/subgraph_rewriter.py b/MLPY/Lib/site-packages/torch/fx/subgraph_rewriter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7364794cf863acdc4935f519d869250ef63d3a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/subgraph_rewriter.py
@@ -0,0 +1,349 @@
+from .graph_module import GraphModule
+from .graph import Graph
+from .node import Node
+from ._symbolic_trace import symbolic_trace
+from ._compatibility import compatibility
+
+import copy
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
+import torch
+
+if TYPE_CHECKING:
+    from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
+
+__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
+
+@compatibility(is_backward_compatible=True)
+class Match(NamedTuple):
+    # Node from which the match was found
+    anchor: Node
+    # Maps nodes in the pattern subgraph to nodes in the larger graph
+    nodes_map: Dict[Node, Node]
+
+@compatibility(is_backward_compatible=False)
+@dataclass
+class ReplacedPatterns:
+    # Node from which the match was found
+    anchor: Node
+    # Maps nodes in the pattern subgraph to nodes in the larger graph
+    nodes_map: Dict[Node, Node]
+    # List of nodes that were added into the graph
+    replacements: List[Node]
+
+def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
+    gm.delete_all_unused_submodules()
+
+    if isinstance(replacement, GraphModule):
+        replacement.graph.lint()
+
+    def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
+        module_path, _, attr_name = target.rpartition(".")
+        try:
+            mod: torch.nn.Module = gm.get_submodule(module_path)
+        except AttributeError:
+            return None
+        attr = getattr(mod, attr_name, None)
+        return attr
+
+    for node in gm.graph.nodes:
+        if node.op == "call_module" or node.op == "get_attr":
+
+            gm_attr = try_get_attr(gm, node.target)
+            replacement_attr = try_get_attr(replacement, node.target)
+
+            # CASE 1: This target already exists as an attribute in our
+            # result GraphModule. Whether or not it exists in
+            # `replacement`, the existing submodule takes precedence.
+            if gm_attr is not None:
+                continue
+
+            # CASE 2: The target exists as an attribute in `replacement`
+            # only, so we need to copy it over.
+            elif replacement_attr is not None:
+                new_attr = copy.deepcopy(replacement_attr)
+                if isinstance(replacement_attr, torch.nn.Module):
+                    gm.add_submodule(node.target, new_attr)
+                else:
+                    setattr(gm, node.target, new_attr)
+
+            # CASE 3: The target doesn't exist as an attribute in `gm`
+            # or `replacement`
+            else:
+                raise RuntimeError("Attempted to create a \"", node.op,
+                                   "\" node during subgraph rewriting "
+                                   f"with target {node.target}, but "
+                                   "the referenced attribute does not "
+                                   "exist in the replacement GraphModule")
+
+    gm.graph.lint()
+
+
+@compatibility(is_backward_compatible=True)
+def replace_pattern(
+    gm: GraphModule,
+    pattern: Union[Callable, GraphModule],
+    replacement: Union[Callable, GraphModule]
+) -> List[Match]:
+    """
+    Matches all possible non-overlapping sets of operators and their
+    data dependencies (``pattern``) in the Graph of a GraphModule
+    (``gm``), then replaces each of these matched subgraphs with another
+    subgraph (``replacement``).
+
+    Args:
+        ``gm``: The GraphModule that wraps the Graph to operate on
+        ``pattern``: The subgraph to match in ``gm`` for replacement
+        ``replacement``: The subgraph to replace ``pattern`` with
+
+    Returns:
+        List[Match]: A list of ``Match`` objects representing the places
+        in the original graph that ``pattern`` was matched to. The list
+        is empty if there are no matches. ``Match`` is defined as:
+
+        .. code-block:: python
+
+            class Match(NamedTuple):
+                # Node from which the match was found
+                anchor: Node
+                # Maps nodes in the pattern subgraph to nodes in the larger graph
+                nodes_map: Dict[Node, Node]
+
+    Examples:
+
+    .. code-block:: python
+
+        import torch
+        from torch.fx import symbolic_trace, subgraph_rewriter
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, w1, w2):
+                m1 = torch.cat([w1, w2]).sum()
+                m2 = torch.cat([w1, w2]).sum()
+                return x + torch.max(m1) + torch.max(m2)
+
+        def pattern(w1, w2):
+            return torch.cat([w1, w2]).sum()
+
+        def replacement(w1, w2):
+            return torch.stack([w1, w2])
+
+        traced_module = symbolic_trace(M())
+
+        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
+
+    The above code will first match ``pattern`` in the ``forward``
+    method of ``traced_module``. Pattern-matching is done based on
+    use-def relationships, not node names. For example, if you had
+    ``p = torch.cat([a, b])`` in ``pattern``, you could match
+    ``m = torch.cat([a, b])`` in the original ``forward`` function,
+    despite the variable names being different (``p`` vs ``m``).
+
+    The ``return`` statement in ``pattern`` is matched based on its
+    value only; it may or may not match to the ``return`` statement in
+    the larger graph. In other words, the pattern doesn't have to extend
+    to the end of the larger graph.
+
+    When the pattern is matched, it will be removed from the larger
+    function and replaced by ``replacement``. If there are multiple
+    matches for ``pattern`` in the larger function, each non-overlapping
+    match will be replaced. In the case of a match overlap, the first
+    found match in the set of overlapping matches will be replaced.
+    ("First" here being defined as the first in a topological ordering
+    of the Nodes' use-def relationships. In most cases, the first Node
+    is the parameter that appears directly after ``self``, while the
+    last Node is whatever the function returns.)
+
+    One important thing to note is that the parameters of the
+    ``pattern`` Callable must be used in the Callable itself,
+    and the parameters of the ``replacement`` Callable must match
+    the pattern. The first rule is why, in the above code block, the
+    ``forward`` function has parameters ``x, w1, w2``, but the
+    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
+    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
+    As an example of the second rule, consider replacing
+
+    .. code-block:: python
+
+        def pattern(x, y):
+            return torch.neg(x) + torch.relu(y)
+
+    with
+
+    .. code-block:: python
+
+        def replacement(x, y):
+            return torch.relu(x)
+
+    In this case, ``replacement`` needs the same number of parameters
+    as ``pattern`` (both ``x`` and ``y``), even though the parameter
+    ``y`` isn't used in ``replacement``.
+
+    After calling ``subgraph_rewriter.replace_pattern``, the generated
+    Python code looks like this:
+
+    .. code-block:: python
+
+        def forward(self, x, w1, w2):
+            stack_1 = torch.stack([w1, w2])
+            sum_1 = stack_1.sum()
+            stack_2 = torch.stack([w1, w2])
+            sum_2 = stack_2.sum()
+            max_1 = torch.max(sum_1)
+            add_1 = x + max_1
+            max_2 = torch.max(sum_2)
+            add_2 = add_1 + max_2
+            return add_2
+    """
+    match_and_replacements = _replace_pattern(gm, pattern, replacement)
+    return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
+
+
+# Experimental API, not backward compatible
+@compatibility(is_backward_compatible=False)
+def replace_pattern_with_filters(
+    gm: GraphModule,
+    pattern: Union[Callable, Graph, GraphModule],
+    replacement: Union[Callable, Graph, GraphModule],
+    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
+    ignore_literals: bool = False,
+) -> List[ReplacedPatterns]:
+    """
+    See replace_pattern for documentation. This function is an overload with an additional ``match_filters`` argument.
+
+    Args:
+        ``match_filters``: A list of functions that take in
+            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
+            whether the match satisfies the condition.
+            See matcher_utils.py for definition of InternalMatch.
+    """
+
+    return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals)
+
+
+def _replace_pattern(
+    gm: GraphModule,
+    pattern: Union[Callable, Graph, GraphModule],
+    replacement: Union[Callable, Graph, GraphModule],
+    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
+    ignore_literals: bool = False,
+) -> List[ReplacedPatterns]:
+
+    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
+
+    if match_filters is None:
+        match_filters = []
+
+    # Get the graphs for `gm`, `pattern`, `replacement`
+    original_graph: Graph = gm.graph
+
+    if isinstance(pattern, GraphModule):
+        pattern_graph = pattern.graph
+    elif isinstance(pattern, Graph):
+        pattern_graph = pattern
+    else:
+        pattern_graph = symbolic_trace(pattern).graph
+
+    if isinstance(replacement, GraphModule):
+        replacement_graph = replacement.graph
+    elif isinstance(replacement, Graph):
+        replacement_graph = replacement
+    else:
+        replacement_graph = symbolic_trace(replacement).graph
+
+    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
+                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
+    _matches: List[InternalMatch] = matcher.match(original_graph)
+
+    # Filter out matches that don't match the filter
+    _matches = [
+        m for m in _matches
+        if all(match_filter(m, original_graph, pattern_graph)
+               for match_filter in match_filters)
+    ]
+
+    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
+
+    # As we progressively replace nodes, we'll need to keep track of how the match results should change
+    match_changed_node: Dict[Node, Node] = {}
+
+    match_and_replacements = []
+    for match in _matches:
+
+        # Build the mapping between the replacement graph's inputs and the
+        # producer nodes of the matched inputs in the original graph
+
+        # Initialize `val_map` with mappings from placeholder nodes in
+        # `replacement` to their corresponding node in `original_graph`
+        assert len(match.placeholder_nodes) == len(replacement_placeholders)
+        val_map: Dict[Node, Node] = {}
+        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
+            if isinstance(gn, Node):
+                val_map[rn] = match_changed_node.get(gn, gn)
+                if gn != val_map[rn]:
+                    # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
+                    gn_ind = match.placeholder_nodes.index(gn)
+                    match.placeholder_nodes[gn_ind] = match_changed_node[gn]
+                    map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
+                    match.nodes_map[map_key] = match_changed_node[gn]
+            else:
+                val_map[rn] = gn
+
+        # Copy the replacement graph over
+        user_nodes: Set[Node] = set()
+        for n in match.returning_nodes:
+            for user in n.users:
+                user_nodes.add(user)
+        assert user_nodes, "The returning_nodes should have at least one user node"
+
+        if len(user_nodes) == 1:
+            first_user_node = next(iter(user_nodes))
+        else:
+            # If there are multiple user nodes, we need to find the first user node
+            # in the current execution order of the `original_graph`
+            for n in original_graph.nodes:
+                if n in user_nodes:
+                    first_user_node = n
+                    break
+
+        with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
+            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)
+
+        if isinstance(copied_returning_nodes, Node):
+            copied_returning_nodes = (copied_returning_nodes, )
+
+        # Get a list of nodes that have been replaced into the graph
+        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]
+
+        # Hook the output Node of the replacement subgraph in to the
+        # original Graph at the correct location
+        assert len(match.returning_nodes) == len(copied_returning_nodes)
+        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
+            gn.replace_all_uses_with(copied_node)
+            match_changed_node[gn] = copied_node
+        # Remove the original nodes
+        for node in reversed(pattern_graph.nodes):
+            if node.op != "placeholder" and node.op != "output":
+                gn = match.nodes_map[node]
+                gm.graph.erase_node(gn)
+
+        match_and_replacements.append(
+            ReplacedPatterns(
+                anchor=match.anchors[0],
+                nodes_map=match.nodes_map,
+                replacements=replacement_nodes
+            )
+        )
+
+    # Update the passed-in GraphModule to reflect the new state of
+    # `original_graph`
+    gm.recompile()
+
+    # If `replacement` was an nn.Module, we'll need to make sure that
+    # all the submodules have been copied over correctly
+    if isinstance(replacement, torch.nn.Module):
+        _replace_attributes(gm, replacement)
+
+    return match_and_replacements
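`replace_pattern_with_filters` above is documented only by reference to `replace_pattern`; here is a hedged usage sketch under the signature shown in this file. The module, pattern, replacement, and filter below are illustrative, not taken from the diff:

```python
import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.neg(x) + torch.relu(y)

def pattern(a, b):
    return torch.neg(a) + torch.relu(b)

def replacement(a, b):
    return torch.relu(a) + torch.neg(b)

def small_pattern_only(match, original_graph, pattern_graph):
    # A match filter receives (InternalMatch, original Graph, pattern Graph)
    # and returns True to keep the match; this one accepts small patterns.
    return len(pattern_graph.nodes) < 10

traced = symbolic_trace(M())
replaced = subgraph_rewriter.replace_pattern_with_filters(
    traced, pattern, replacement, match_filters=[small_pattern_only]
)
# `replaced` is a list of ReplacedPatterns (anchor, nodes_map, replacements)
```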
diff --git a/MLPY/Lib/site-packages/torch/fx/tensor_type.py b/MLPY/Lib/site-packages/torch/fx/tensor_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec0105c846d2a9249c9a930738db2ec8b8b2aab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/tensor_type.py
@@ -0,0 +1,104 @@
+from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
+
+from ._compatibility import compatibility
+
+
+@compatibility(is_backward_compatible=False)
+class TensorType:
+    """
+    TensorType defines a type for tensors, which consists of a list of dimensions.
+    Example:
+        class M(torch.nn.Module):
+            def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))):
+                return torch.add(x, y)
+    """
+
+    def __init__(self, dim):
+        self.__origin__ = TensorType
+        self.__args__ = dim
+
+    def __repr__(self):
+        return f'TensorType[{self.__args__}]'
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return list(self.__args__) == list(other.__args__)
+        else:
+            return False
+
+    @staticmethod
+    def __class_getitem__(*args):
+        if len(args) == 1 and isinstance(args[0], tuple):
+            args = args[0]
+        return TensorType(tuple(args))
+
+
+class _DynType:
+    """
+    _DynType defines a type which stands for the absence of type information.
+    """
+    def __init__(self):
+        self.__name__ = '_DynType'
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __str__(self):
+        return "Dyn"
+
+    def __repr__(self):
+        return "Dyn"
+
+
+Dyn = _DynType()
+
+@compatibility(is_backward_compatible=False)
+def is_consistent(t1, t2):
+    """
+    A binary relation denoted by ~ that determines if t1 is consistent with t2.
+    The relation is reflexive and symmetric, but not transitive.
+    Returns True if t1 and t2 are consistent and False otherwise.
+    Example:
+        Dyn ~ TensorType((1,2,3))
+        int ~ Dyn
+        int ~ int
+        TensorType((1,Dyn,3)) ~ TensorType((1,2,3))
+    """
+
+    if t1 == t2:
+        return True
+
+    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
+        return True
+
+    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
+        return len(t1.__args__) == len(t2.__args__) and \
+            all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
+    else:
+        return False
+
+
+@compatibility(is_backward_compatible=False)
+def is_more_precise(t1, t2):
+    """
+    A binary relation denoted by <= that determines if t1 is more precise than t2.
+    The relation is reflexive and transitive.
+    Returns True if t1 is more precise than t2 and False otherwise.
+    Example:
+        TensorType((1,2,3)) <= Dyn
+        int <= Dyn
+        int <= int
+        TensorType((1,2,3)) <= TensorType((1,Dyn,3))
+    """
+    if t1 == t2:
+        return True
+
+    if isinstance(t2, _DynType):
+        return True
+
+    if isinstance(t1, TensorType) and isinstance(t2, TensorType):
+        return len(t1.__args__) == len(t2.__args__) and \
+            all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
+
+    else:
+        return False
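A small, hedged self-check of the relations defined in this file, using only the classes and functions declared above (imported from `torch.fx.tensor_type`):

```python
from torch.fx.tensor_type import Dyn, TensorType, is_consistent, is_more_precise

t = TensorType((1, 2, 3))      # explicit constructor
u = TensorType[1, Dyn, 3]      # __class_getitem__ builds the same kind of object

# Consistency (~): Dyn is consistent with anything; ranks must agree.
assert is_consistent(t, u)
assert is_consistent(Dyn, t)
assert not is_consistent(t, TensorType((1, 2)))

# Precision (<=): every type is more precise than Dyn, and a fixed dimension
# is more precise than Dyn in the same position.
assert is_more_precise(t, Dyn)
assert is_more_precise(t, u)        # (1, 2, 3) <= (1, Dyn, 3)
assert not is_more_precise(u, t)
```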
diff --git a/MLPY/Lib/site-packages/torch/fx/traceback.py b/MLPY/Lib/site-packages/torch/fx/traceback.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ea48ca1107523575a36ff822b10b81373e2046
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/fx/traceback.py
@@ -0,0 +1,99 @@
+import traceback
+from contextlib import contextmanager
+from typing import List, Any, Dict
+from ._compatibility import compatibility
+
+__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
+           'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
+           'format_stack', 'set_current_meta', 'get_current_meta']
+
+current_meta: Dict[str, Any] = {}
+should_preserve_node_meta = False
+
+
+@compatibility(is_backward_compatible=False)
+@contextmanager
+def preserve_node_meta():
+    global should_preserve_node_meta
+
+    saved_should_preserve_node_meta = should_preserve_node_meta
+    try:
+        should_preserve_node_meta = True
+        yield
+    finally:
+        should_preserve_node_meta = saved_should_preserve_node_meta
+
+
+@compatibility(is_backward_compatible=False)
+def set_stack_trace(stack : List[str]):
+    global current_meta
+
+    if should_preserve_node_meta and stack:
+        current_meta["stack_trace"] = "".join(stack)
+
+
+@compatibility(is_backward_compatible=False)
+def set_grad_fn_seq_nr(seq_nr):
+    global current_meta
+
+    if should_preserve_node_meta:
+        # The seq_nr is captured by eager mode in the grad_fn during forward
+        current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
+        current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
+
+
+@compatibility(is_backward_compatible=False)
+def reset_grad_fn_seq_nr():
+    # NB: reset state properly, this would be helpful towards supporting
+    #     reentrant autograd if we actually wanted to do that.
+    global current_meta
+    if should_preserve_node_meta:
+        current_level = current_meta.get("in_grad_fn", 0)
+        assert current_level > 0
+        if current_level == 1:
+            del current_meta["in_grad_fn"]
+            del current_meta["grad_fn_seq_nr"]
+        else:
+            current_meta["in_grad_fn"] = current_level - 1
+            current_meta["grad_fn_seq_nr"].pop()
+
+
+@compatibility(is_backward_compatible=False)
+def format_stack() -> List[str]:
+    if should_preserve_node_meta:
+        return [current_meta.get("stack_trace", "")]
+    else:
+        # fallback to traceback.format_stack()
+        return traceback.format_list(traceback.extract_stack()[:-1])
+
+
+@compatibility(is_backward_compatible=False)
+def has_preserved_node_meta() -> bool:
+    return should_preserve_node_meta
+
+
+@compatibility(is_backward_compatible=False)
+@contextmanager
+def set_current_meta(node):
+    global current_meta
+    if should_preserve_node_meta and node.meta:
+        saved_meta = current_meta
+        try:
+            current_meta = node.meta.copy()
+
+            # Append (node.name, node.target) onto "from_node" for provenance tracking
+            if "from_node" not in current_meta:
+                current_meta["from_node"] = [(node.name, node.target)]
+            elif current_meta["from_node"][-1][0] != node.name:
+                current_meta["from_node"].append((node.name, node.target))
+
+            yield
+        finally:
+            current_meta = saved_meta
+    else:
+        yield
+
+
+@compatibility(is_backward_compatible=False)
+def get_current_meta() -> Dict[str, Any]:
+    return current_meta
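The helpers above are meant to be used together: a graph transform enters `preserve_node_meta()` once, then wraps each node it re-creates in `set_current_meta(node)` so that `get_current_meta()` and `format_stack()` report the preserved provenance instead of the live Python stack. A hedged sketch (`node` stands for any `torch.fx.Node` with a populated `meta` dict; the function name is illustrative):

```python
import torch.fx
import torch.fx.traceback as fx_traceback

def recreate_with_provenance(node: torch.fx.Node) -> None:
    with fx_traceback.preserve_node_meta():
        assert fx_traceback.has_preserved_node_meta()
        with fx_traceback.set_current_meta(node):
            # Node-creation code placed here can read the preserved metadata,
            # e.g. to copy stack_trace / from_node onto newly created nodes.
            meta = fx_traceback.get_current_meta()
            stack = fx_traceback.format_stack()  # preserved stack_trace, if any
            print(meta.get("from_node"), stack[0][:80])
```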
diff --git a/MLPY/Lib/site-packages/torch/hub.py b/MLPY/Lib/site-packages/torch/hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e3a9e55397b59799b4eeb3846c1ceb480eab4c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/hub.py
@@ -0,0 +1,764 @@
+import contextlib
+import errno
+import hashlib
+import json
+import os
+import re
+import shutil
+import sys
+import tempfile
+import torch
+import uuid
+import warnings
+import zipfile
+from pathlib import Path
+from typing import Dict, Optional, Any
+from urllib.error import HTTPError, URLError
+from urllib.request import urlopen, Request
+from urllib.parse import urlparse  # noqa: F401
+from torch.serialization import MAP_LOCATION
+
+class _Faketqdm:  # type: ignore[no-redef]
+
+    def __init__(self, total=None, disable=False,
+                 unit=None, *args, **kwargs):
+        self.total = total
+        self.disable = disable
+        self.n = 0
+        # Ignore all extra *args and **kwargs unless you want to reinvent tqdm
+
+    def update(self, n):
+        if self.disable:
+            return
+
+        self.n += n
+        if self.total is None:
+            sys.stderr.write(f"\r{self.n:.1f} bytes")
+        else:
+            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
+        sys.stderr.flush()
+
+    # Don't bother implementing; use real tqdm if you want
+    def set_description(self, *args, **kwargs):
+        pass
+
+    def write(self, s):
+        sys.stderr.write(f"{s}\n")
+
+    def close(self):
+        self.disable = True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.disable:
+            return
+
+        sys.stderr.write('\n')
+
+try:
+    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
+except ImportError:
+    tqdm = _Faketqdm
+
+__all__ = [
+    'download_url_to_file',
+    'get_dir',
+    'help',
+    'list',
+    'load',
+    'load_state_dict_from_url',
+    'set_dir',
+]
+
+# matches bfd8deac from resnet18-bfd8deac.pth
+HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+
+_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
+ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
+ENV_TORCH_HOME = 'TORCH_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+VAR_DEPENDENCY = 'dependencies'
+MODULE_HUBCONF = 'hubconf.py'
+READ_DATA_CHUNK = 128 * 1024
+_hub_dir: Optional[str] = None
+
+
+@contextlib.contextmanager
+def _add_to_sys_path(path):
+    sys.path.insert(0, path)
+    try:
+        yield
+    finally:
+        sys.path.remove(path)
+
+
+# Copied from tools/shared/module_loader to be included in torch package
+def _import_module(name, path):
+    import importlib.util
+    from importlib.abc import Loader
+    spec = importlib.util.spec_from_file_location(name, path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    assert isinstance(spec.loader, Loader)
+    spec.loader.exec_module(module)
+    return module
+
+
+def _remove_if_exists(path):
+    if os.path.exists(path):
+        if os.path.isfile(path):
+            os.remove(path)
+        else:
+            shutil.rmtree(path)
+
+
+def _git_archive_link(repo_owner, repo_name, ref):
+    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip
+    return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}"
+
+
+def _load_attr_from_module(module, func_name):
+    # Check if callable is defined in the module
+    if func_name not in dir(module):
+        return None
+    return getattr(module, func_name)
+
+
+def _get_torch_home():
+    torch_home = os.path.expanduser(
+        os.getenv(ENV_TORCH_HOME,
+                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
+                                         DEFAULT_CACHE_DIR), 'torch')))
+    return torch_home
+
+
+def _parse_repo_info(github):
+    if ':' in github:
+        repo_info, ref = github.split(':')
+    else:
+        repo_info, ref = github, None
+    repo_owner, repo_name = repo_info.split('/')
+
+    if ref is None:
+        # The ref wasn't specified by the user, so we need to figure out the
+        # default branch: main or master. Our assumption is that if main exists
+        # then it's the default branch, otherwise it's master.
+        try:
+            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
+                ref = 'main'
+        except HTTPError as e:
+            if e.code == 404:
+                ref = 'master'
+            else:
+                raise
+        except URLError as e:
+            # No internet connection, need to check for cache as last resort
+            for possible_ref in ("main", "master"):
+                if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
+                    ref = possible_ref
+                    break
+            if ref is None:
+                raise RuntimeError(
+                    "It looks like there is no internet connection and the "
+                    f"repo could not be found in the cache ({get_dir()})"
+                ) from e
+    return repo_owner, repo_name, ref
+
+
+def _read_url(url):
+    with urlopen(url) as r:
+        return r.read().decode(r.headers.get_content_charset('utf-8'))
+
+
+def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
+    # Use urlopen to avoid depending on local git.
+    headers = {'Accept': 'application/vnd.github.v3+json'}
+    token = os.environ.get(ENV_GITHUB_TOKEN)
+    if token is not None:
+        headers['Authorization'] = f'token {token}'
+    for url_prefix in (
+            f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
+            f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+        page = 0
+        while True:
+            page += 1
+            url = f'{url_prefix}?per_page=100&page={page}'
+            response = json.loads(_read_url(Request(url, headers=headers)))
+            # Empty response means no more data to process
+            if not response:
+                break
+            for br in response:
+                if br['name'] == ref or br['commit']['sha'].startswith(ref):
+                    return
+
+    raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
+                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
+
+
+def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
+    # Setup hub_dir to save downloaded files
+    hub_dir = get_dir()
+    os.makedirs(hub_dir, exist_ok=True)
+    # Parse github repo information
+    repo_owner, repo_name, ref = _parse_repo_info(github)
+    # GitHub allows branch names containing a slash '/', which causes
+    # confusion with paths on both Linux and Windows. Backslashes are not
+    # allowed in GitHub branch names, so there is no need to worry about them.
+    normalized_br = ref.replace('/', '_')
+    # GitHub renames the folder repo-v1.x.x to repo-1.x.x.
+    # We don't know the folder name until we download the zip file and
+    # inspect it, so to check whether a cached repo exists we need to
+    # normalize the folder names.
+    owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
+    repo_dir = os.path.join(hub_dir, owner_name_branch)
+    # Check that the repo is in the trusted list
+    _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
+
+    use_cache = (not force_reload) and os.path.exists(repo_dir)
+
+    if use_cache:
+        if verbose:
+            sys.stderr.write(f'Using cache found in {repo_dir}\n')
+    else:
+        # Validate the tag/branch is from the original repo instead of a forked repo
+        if not skip_validation:
+            _validate_not_a_forked_repo(repo_owner, repo_name, ref)
+
+        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
+        _remove_if_exists(cached_file)
+
+        try:
+            url = _git_archive_link(repo_owner, repo_name, ref)
+            sys.stderr.write(f'Downloading: \"{url}\" to {cached_file}\n')
+            download_url_to_file(url, cached_file, progress=False)
+        except HTTPError as err:
+            if err.code == 300:
+                # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
+                # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags
+                # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
+                # Here, we do the same as git: we throw a warning, and assume the user wanted the branch
+                warnings.warn(
+                    f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
+                    "Torchhub will now assume that it's a branch. "
+                    "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
+                    "refs/tags/tag_name as the ref. That might require using skip_validation=True."
+                )
+                disambiguated_branch_ref = f"refs/heads/{ref}"
+                url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
+                download_url_to_file(url, cached_file, progress=False)
+            else:
+                raise
+
+        with zipfile.ZipFile(cached_file) as cached_zipfile:
+            extracted_repo_name = cached_zipfile.infolist()[0].filename
+            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
+            _remove_if_exists(extracted_repo)
+            # Unzip the code and rename the base folder
+            cached_zipfile.extractall(hub_dir)
+
+        _remove_if_exists(cached_file)
+        _remove_if_exists(repo_dir)
+        shutil.move(extracted_repo, repo_dir)  # rename the repo
+
+    return repo_dir
+
+
+def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
+    hub_dir = get_dir()
+    filepath = os.path.join(hub_dir, "trusted_list")
+
+    if not os.path.exists(filepath):
+        Path(filepath).touch()
+    with open(filepath) as file:
+        trusted_repos = tuple(line.strip() for line in file)
+
+    # To minimize friction of introducing the new trust_repo mechanism, we consider that
+    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
+    trusted_repos_legacy = next(os.walk(hub_dir))[1]
+
+    owner_name = '_'.join([repo_owner, repo_name])
+    is_trusted = (
+        owner_name in trusted_repos
+        or owner_name_branch in trusted_repos_legacy
+        or repo_owner in _TRUSTED_REPO_OWNERS
+    )
+
+    # TODO: Remove `None` option in 2.0 and change the default to "check"
+    if trust_repo is None:
+        if not is_trusted:
+            warnings.warn(
+                "You are about to download and run code from an untrusted repository. In a future release, this won't "
+                "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
+                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
+                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
+                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
+                f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
+        return
+
+    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
+        response = input(
+            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
+            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
+        if response.lower() in ("y", "yes"):
+            if is_trusted:
+                print("The repository is already trusted.")
+        elif response.lower() in ("n", "no", ""):
+            raise Exception("Untrusted repository.")
+        else:
+            raise ValueError(f"Unrecognized response {response}.")
+
+    # At this point we're sure that the user trusts the repo (or wants to trust it)
+    if not is_trusted:
+        with open(filepath, "a") as file:
+            file.write(owner_name + "\n")
+
+
+def _check_module_exists(name):
+    import importlib.util
+    return importlib.util.find_spec(name) is not None
+
+
+def _check_dependencies(m):
+    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
+
+    if dependencies is not None:
+        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
+        if len(missing_deps):
+            raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}")
+
+
+def _load_entry_from_hubconf(m, model):
+    if not isinstance(model, str):
+        raise ValueError('Invalid input: model should be a string of the entrypoint function name')
+
+    # Note that if a missing dependency is imported at the top level of hubconf, it will
+    # throw before this function. It's a chicken-and-egg situation: we have to load
+    # hubconf to know what the dependencies are, but importing hubconf requires the
+    # missing package. This is fine; Python will raise a proper error message for users.
+    _check_dependencies(m)
+
+    func = _load_attr_from_module(m, model)
+
+    if func is None or not callable(func):
+        raise RuntimeError(f'Cannot find callable {model} in hubconf')
+
+    return func
+
+
+def get_dir():
+    r"""
+    Get the Torch Hub cache directory used for storing downloaded models & weights.
+
+    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
+    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
+    ``$XDG_CACHE_HOME`` follows the X Desktop Group (XDG) specification for the
+    Linux filesystem layout, with a default value of ``~/.cache`` if the
+    environment variable is not set.
+    """
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_HUB'):
+        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+
+    if _hub_dir is not None:
+        return _hub_dir
+    return os.path.join(_get_torch_home(), 'hub')
+
+
+def set_dir(d):
+    r"""
+    Optionally set the Torch Hub directory used to save downloaded models & weights.
+
+    Args:
+        d (str): path to a local folder to save downloaded models & weights.
+    """
+    global _hub_dir
+    _hub_dir = os.path.expanduser(d)
+
+
+def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
+    r"""
+    List all callable entrypoints available in the repo specified by ``github``.
+
+    Args:
+        github (str): a string with format "repo_owner/repo_name[:ref]" with an optional
+            ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if
+            it exists, and otherwise ``master``.
+            Example: 'pytorch/vision:0.10'
+        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
+            Default is ``False``.
+        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
+            specified by the ``github`` argument properly belongs to the repo owner. This will make
+            requests to the GitHub API; you can specify a non-default GitHub token by setting the
+            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
+        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
+            This parameter was introduced in v1.12 and helps ensure that users
+            only run code from repos that they trust.
+
+            - If ``False``, a prompt will ask the user whether the repo should
+              be trusted.
+            - If ``True``, the repo will be added to the trusted list and loaded
+              without requiring explicit confirmation.
+            - If ``"check"``, the repo will be checked against the list of
+              trusted repos in the cache. If it is not present in that list, the
+              behaviour will fall back onto the ``trust_repo=False`` option.
+            - If ``None``: this will raise a warning, inviting the user to set
+              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
+              is only present for backward compatibility and will be removed in
+              v2.0.
+
+            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
+        verbose (bool, optional): If ``False``, mute messages about hitting
+            local caches. Note that the message about first download cannot be
+            muted. Default is ``True``.
+
+    Returns:
+        list: The available callable entrypoints
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
+        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
+    """
+    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
+                                    skip_validation=skip_validation)
+
+    with _add_to_sys_path(repo_dir):
+        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
+        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
+
+    # We treat functions that start with '_' as internal helper functions and exclude them
+    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+
+    return entrypoints
+
+
+def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
+    r"""
+    Show the docstring of entrypoint ``model``.
+
+    Args:
+        github (str): a string with format ``repo_owner/repo_name[:ref]`` with an optional
+            ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed
+            to be ``main`` if it exists, and otherwise ``master``.
+            Example: 'pytorch/vision:0.10'
+        model (str): a string of entrypoint name defined in repo's ``hubconf.py``
+        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
+            Default is ``False``.
+        skip_validation (bool, optional): if ``False``, torchhub will check that the ref
+            specified by the ``github`` argument properly belongs to the repo owner. This will make
+            requests to the GitHub API; you can specify a non-default GitHub token by setting the
+            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
+        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
+            This parameter was introduced in v1.12 and helps ensure that users
+            only run code from repos that they trust.
+
+            - If ``False``, a prompt will ask the user whether the repo should
+              be trusted.
+            - If ``True``, the repo will be added to the trusted list and loaded
+              without requiring explicit confirmation.
+            - If ``"check"``, the repo will be checked against the list of
+              trusted repos in the cache. If it is not present in that list, the
+              behaviour will fall back onto the ``trust_repo=False`` option.
+            - If ``None``: this will raise a warning, inviting the user to set
+              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
+              is only present for backward compatibility and will be removed in
+              v2.0.
+
+            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
+        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
+    """
+    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
+                                    skip_validation=skip_validation)
+
+    with _add_to_sys_path(repo_dir):
+        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
+        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
+
+    entry = _load_entry_from_hubconf(hub_module, model)
+
+    return entry.__doc__
+
+
+def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
+         skip_validation=False,
+         **kwargs):
+    r"""
+    Load a model from a github repo or a local directory.
+
+    Note: Loading a model is the typical use case, but this can also be used
+    for loading other objects such as tokenizers, loss functions, etc.
+
+    If ``source`` is 'github', ``repo_or_dir`` is expected to be
+    of the form ``repo_owner/repo_name[:ref]`` with an optional
+    ref (a tag or a branch).
+
+    If ``source`` is 'local', ``repo_or_dir`` is expected to be a
+    path to a local directory.
+
+    Args:
+        repo_or_dir (str): If ``source`` is 'github',
+            this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with
+            an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified,
+            the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.
+            If ``source`` is 'local'  then it should be a path to a local directory.
+        model (str): the name of a callable (entrypoint) defined in the
+            repo/dir's ``hubconf.py``.
+        *args (optional): the corresponding args for callable ``model``.
+        source (str, optional): 'github' or 'local'. Specifies how
+            ``repo_or_dir`` is to be interpreted. Default is 'github'.
+        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
+            This parameter was introduced in v1.12 and helps ensure that users
+            only run code from repos that they trust.
+
+            - If ``False``, a prompt will ask the user whether the repo should
+              be trusted.
+            - If ``True``, the repo will be added to the trusted list and loaded
+              without requiring explicit confirmation.
+            - If ``"check"``, the repo will be checked against the list of
+              trusted repos in the cache. If it is not present in that list, the
+              behaviour will fall back onto the ``trust_repo=False`` option.
+            - If ``None``: this will raise a warning, inviting the user to set
+              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
+              is only present for backward compatibility and will be removed in
+              v2.0.
+
+            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
+        force_reload (bool, optional): whether to force a fresh download of
+            the github repo unconditionally. Does not have any effect if
+            ``source = 'local'``. Default is ``False``.
+        verbose (bool, optional): If ``False``, mute messages about hitting
+            local caches. Note that the message about first download cannot be
+            muted. Does not have any effect if ``source = 'local'``.
+            Default is ``True``.
+        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
+            specified by the ``github`` argument properly belongs to the repo owner. This will make
+            requests to the GitHub API; you can specify a non-default GitHub token by setting the
+            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
+        **kwargs (optional): the corresponding kwargs for callable ``model``.
+
+    Returns:
+        The output of the ``model`` callable when called with the given
+        ``*args`` and ``**kwargs``.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
+        >>> # from a github repo
+        >>> repo = 'pytorch/vision'
+        >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
+        >>> # from a local directory
+        >>> path = '/some/local/path/pytorch/vision'
+        >>> # xdoctest: +SKIP
+        >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
+    """
+    source = source.lower()
+
+    if source not in ('github', 'local'):
+        raise ValueError(
+            f'Unknown source: "{source}". Allowed values: "github" | "local".')
+
+    if source == 'github':
+        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
+                                           verbose=verbose, skip_validation=skip_validation)
+
+    model = _load_local(repo_or_dir, model, *args, **kwargs)
+    return model
+
+
+def _load_local(hubconf_dir, model, *args, **kwargs):
+    r"""
+    Load a model from a local directory with a ``hubconf.py``.
+
+    Args:
+        hubconf_dir (str): path to a local directory that contains a
+            ``hubconf.py``.
+        model (str): name of an entrypoint defined in the directory's
+            ``hubconf.py``.
+        *args (optional): the corresponding args for callable ``model``.
+        **kwargs (optional): the corresponding kwargs for callable ``model``.
+
+    Returns:
+        a single model with corresponding pretrained weights.
+
+    Example:
+        >>> # xdoctest: +SKIP("stub local path")
+        >>> path = '/some/local/path/pytorch/vision'
+        >>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
+    """
+    with _add_to_sys_path(hubconf_dir):
+        hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
+        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
+
+        entry = _load_entry_from_hubconf(hub_module, model)
+        model = entry(*args, **kwargs)
+
+    return model
+
+
+def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
+                         progress: bool = True) -> None:
+    r"""Download object at the given URL to a local path.
+
+    Args:
+        url (str): URL of the object to download
+        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
+        hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
+            Default: None
+        progress (bool, optional): whether or not to display a progress bar to stderr
+            Default: True
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
+        >>> # xdoctest: +REQUIRES(POSIX)
+        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
+
+    """
+    file_size = None
+    req = Request(url, headers={"User-Agent": "torch.hub"})
+    u = urlopen(req)
+    meta = u.info()
+    if hasattr(meta, 'getheaders'):
+        content_length = meta.getheaders("Content-Length")
+    else:
+        content_length = meta.get_all("Content-Length")
+    if content_length is not None and len(content_length) > 0:
+        file_size = int(content_length[0])
+
+    # We deliberately save it in a temp file and move it after
+    # download is complete. This prevents a local working checkpoint
+    # being overridden by a broken download.
+    # We deliberately do not use NamedTemporaryFile to avoid restrictive
+    # file permissions being applied to the downloaded file.
+    dst = os.path.expanduser(dst)
+    for seq in range(tempfile.TMP_MAX):
+        tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
+        try:
+            f = open(tmp_dst, 'w+b')
+        except FileExistsError:
+            continue
+        break
+    else:
+        raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
+
+    try:
+        if hash_prefix is not None:
+            sha256 = hashlib.sha256()
+        with tqdm(total=file_size, disable=not progress,
+                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+            while True:
+                buffer = u.read(READ_DATA_CHUNK)
+                if len(buffer) == 0:
+                    break
+                f.write(buffer)  # type: ignore[possibly-undefined]
+                if hash_prefix is not None:
+                    sha256.update(buffer)  # type: ignore[possibly-undefined]
+                pbar.update(len(buffer))
+
+        f.close()
+        if hash_prefix is not None:
+            digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
+            if digest[:len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
+        shutil.move(f.name, dst)
+    finally:
+        f.close()
+        if os.path.exists(f.name):
+            os.remove(f.name)
+
+
+# Hub used to automatically extract zip files that were manually compressed by users.
+# The legacy zip format expects only one file, produced by torch.save() < 1.6, inside the zip.
+# We should remove this support since the zipfile format is now the default for torch.save().
+def _is_legacy_zip_format(filename: str) -> bool:
+    if zipfile.is_zipfile(filename):
+        infolist = zipfile.ZipFile(filename).infolist()
+        return len(infolist) == 1 and not infolist[0].is_dir()
+    return False
+
+
+def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
+    warnings.warn('Falling back to the old format < 1.6. This support will be '
+                  'deprecated in favor of default zipfile format introduced in 1.6. '
+                  'Please redo torch.save() to save it in the new zipfile format.')
+    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
+    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
+    #       E.g. resnet18-5c106cde.pth which is widely used.
+    with zipfile.ZipFile(filename) as f:
+        members = f.infolist()
+        if len(members) != 1:
+            raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+        f.extractall(model_dir)
+        extracted_name = members[0].filename
+        extracted_file = os.path.join(model_dir, extracted_name)
+    return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
+
+
+def load_state_dict_from_url(
+    url: str,
+    model_dir: Optional[str] = None,
+    map_location: MAP_LOCATION = None,
+    progress: bool = True,
+    check_hash: bool = False,
+    file_name: Optional[str] = None,
+    weights_only: bool = False,
+) -> Dict[str, Any]:
+    r"""Loads the Torch serialized object at the given URL.
+
+    If downloaded file is a zip file, it will be automatically
+    decompressed.
+
+    If the object is already present in `model_dir`, it's deserialized and
+    returned.
+    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
+    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
+
+    Args:
+        url (str): URL of the object to download
+        model_dir (str, optional): directory in which to save the object
+        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
+        progress (bool, optional): whether or not to display a progress bar to stderr.
+            Default: True
+        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
+            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+            digits of the SHA256 hash of the contents of the file. The hash is used to
+            ensure unique names and to verify the contents of the file.
+            Default: False
+        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
+        weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects.
+            Recommended for untrusted sources. See :func:`~torch.load` for more details.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
+        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
+
+    """
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_MODEL_ZOO'):
+        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+    if model_dir is None:
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+
+    if _is_legacy_zip_format(cached_file):
+        return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)
+    return torch.load(cached_file, map_location=map_location, weights_only=weights_only)
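The docstrings above already carry per-function examples; what they do not show is how ``set_dir``/``get_dir`` and ``trust_repo`` combine in a typical session. A hedged sketch (the cache path is arbitrary, the calls require network access, and ``pytorch/vision`` / ``resnet18`` are simply the repo and entrypoint used in the docstrings above):

```python
import torch

# Redirect the hub cache before anything is downloaded.
torch.hub.set_dir("/tmp/my_torch_hub")      # hypothetical cache location
print(torch.hub.get_dir())                  # -> /tmp/my_torch_hub

# trust_repo=True adds the repo to the trusted list and skips the prompt.
entrypoints = torch.hub.list("pytorch/vision", trust_repo=True)
model = torch.hub.load("pytorch/vision", "resnet18", trust_repo=True)
```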
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ATen.h b/MLPY/Lib/site-packages/torch/include/ATen/ATen.h
new file mode 100644
index 0000000000000000000000000000000000000000..60a33d74a04a0a1ae07d1e000b8c73bc3ca3cda4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ATen.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#if !defined(_MSC_VER) && __cplusplus < 201703L
+#error C++17 or later compatible compiler is required to use ATen.
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// TODO: try to remove this
+// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/AccumulateType.h b/MLPY/Lib/site-packages/torch/include/ATen/AccumulateType.h
new file mode 100644
index 0000000000000000000000000000000000000000..d7a26b07c647477e71f807d7c00f16ae8a4ab3b0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/AccumulateType.h
@@ -0,0 +1,153 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// Defines the accumulation type for a scalar type.
+// Example:
+//   using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;
+//
+// Accumulation types are an important concept in numeric computing
+// because you frequently want to perform intermediate computations
+// at a higher precision than the input and output precision, to avoid
+// compounding internal rounding errors.  Accumulation is the most
+// well-known intermediate computation (it is of great importance for
+// sum reduction and matrix multiply, for example), but in PyTorch
+// acc_type ends up getting used for all sorts of other intermediate
+// computations, so it perhaps would be more accurately (ahem) called an
+// "accurate" type.  acc_type is especially important for reduced
+// precision operations like float16 and bfloat16, where relatively
+// benign looking inputs can easily end up overflowing/underflowing.
+//
+// acc_type is parametrized by whether or not you are running on CUDA,
+// because on CUDA double precision operations are expensive
+// and so by default, we don't actually want to use double as an
+// acc_type on CUDA.  A lot of things are typed out below, but
+// basically, the table is generated by a few rules:
+//
+//  If bool:
+//      Use 'bool' as acc_type.
+//  If floating point:
+//      If CUDA, use 'float' as acc_type (unless scalar_t is double),
+//      otherwise (CPU) use 'double'
+//  If integral:
+//      Use 'int64_t' as acc_type
+//
+// You're not forced to use this template; if you happen to know
+// something specific about your use case, you can specify your own
+// desired behavior.  This template, however, will give you a reasonable
+// default that will work for all dtypes supported in PyTorch.
+
+#if defined(__CUDACC__)
+#include 
+#include 
+#elif defined(__HIPCC__)
+#include 
+#include 
+#endif
+
+namespace at {
+
+template <typename T, c10::DeviceType D>
+struct AccumulateTypeDevice {};
+
+template <typename T, bool>
+struct AccumulateType {};
+
+template <typename T>
+struct AccumulateType<T, false> {
+  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
+};
+
+template <typename T>
+struct AccumulateType<T, true> {
+  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
+};
+
+template <typename T, c10::DeviceType device>
+using acc_type_device = typename AccumulateTypeDevice<T, device>::type;
+
+template <typename T, bool is_cuda>
+using acc_type = typename AccumulateType<T, is_cuda>::type;
+
+#define ACC_TYPE(t, acc_t, device_type)         \
+  template <>                                   \
+  struct AccumulateTypeDevice<t, device_type> { \
+    using type = acc_t;                         \
+  };
+#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
+#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
+#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
+
+MPS_ACC_TYPE(BFloat16, float);
+MPS_ACC_TYPE(Half, float);
+MPS_ACC_TYPE(Float8_e5m2, float);
+MPS_ACC_TYPE(Float8_e4m3fn, float);
+MPS_ACC_TYPE(Float8_e5m2fnuz, float);
+MPS_ACC_TYPE(Float8_e4m3fnuz, float);
+MPS_ACC_TYPE(float, float);
+MPS_ACC_TYPE(double, float);
+MPS_ACC_TYPE(int8_t, int64_t);
+MPS_ACC_TYPE(uint8_t, int64_t);
+MPS_ACC_TYPE(char, int64_t);
+MPS_ACC_TYPE(int16_t, int64_t);
+MPS_ACC_TYPE(int32_t, int64_t);
+MPS_ACC_TYPE(int64_t, int64_t);
+MPS_ACC_TYPE(bool, bool);
+MPS_ACC_TYPE(c10::complex, c10::complex);
+MPS_ACC_TYPE(c10::complex, c10::complex);
+MPS_ACC_TYPE(c10::complex, c10::complex);
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+CUDA_ACC_TYPE(half, float);
+#endif
+CUDA_ACC_TYPE(BFloat16, float);
+CUDA_ACC_TYPE(Half, float);
+CUDA_ACC_TYPE(Float8_e5m2, float);
+CUDA_ACC_TYPE(Float8_e4m3fn, float);
+CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
+CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
+CUDA_ACC_TYPE(float, float);
+CUDA_ACC_TYPE(double, double);
+CUDA_ACC_TYPE(int8_t, int64_t);
+CUDA_ACC_TYPE(uint8_t, int64_t);
+CUDA_ACC_TYPE(char, int64_t);
+CUDA_ACC_TYPE(int16_t, int64_t);
+CUDA_ACC_TYPE(int32_t, int64_t);
+CUDA_ACC_TYPE(int64_t, int64_t);
+CUDA_ACC_TYPE(bool, bool);
+CUDA_ACC_TYPE(c10::complex, c10::complex);
+CUDA_ACC_TYPE(c10::complex, c10::complex);
+CUDA_ACC_TYPE(c10::complex, c10::complex);
+
+CPU_ACC_TYPE(BFloat16, float);
+CPU_ACC_TYPE(Half, float);
+CPU_ACC_TYPE(Float8_e5m2, float);
+CPU_ACC_TYPE(Float8_e4m3fn, float);
+CPU_ACC_TYPE(Float8_e5m2fnuz, float);
+CPU_ACC_TYPE(Float8_e4m3fnuz, float);
+CPU_ACC_TYPE(float, double);
+CPU_ACC_TYPE(double, double);
+CPU_ACC_TYPE(int8_t, int64_t);
+CPU_ACC_TYPE(uint8_t, int64_t);
+CPU_ACC_TYPE(char, int64_t);
+CPU_ACC_TYPE(int16_t, int64_t);
+CPU_ACC_TYPE(int32_t, int64_t);
+CPU_ACC_TYPE(int64_t, int64_t);
+CPU_ACC_TYPE(bool, bool);
+CPU_ACC_TYPE(c10::complex, c10::complex);
+CPU_ACC_TYPE(c10::complex, c10::complex);
+CPU_ACC_TYPE(c10::complex, c10::complex);
+
+TORCH_API c10::ScalarType toAccumulateType(
+    c10::ScalarType type,
+    c10::DeviceType device);
+TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
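+
+// For example, per the tables above (illustrative, not exhaustive):
+//   toAccumulateType(at::kHalf,  c10::DeviceType::CPU)  == at::kFloat
+//   toAccumulateType(at::kFloat, c10::DeviceType::CPU)  == at::kDouble
+//   toAccumulateType(at::kFloat, c10::DeviceType::CUDA) == at::kFloat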
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ArrayRef.h b/MLPY/Lib/site-packages/torch/include/ATen/ArrayRef.h
new file mode 100644
index 0000000000000000000000000000000000000000..8c1febe4654361afa6b90cd38898b90cf8a8d17f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ArrayRef.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Backend.h b/MLPY/Lib/site-packages/torch/include/ATen/Backend.h
new file mode 100644
index 0000000000000000000000000000000000000000..34b3b191549d2be6218da30bc2acab3baa215888
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Backend.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Backtrace.h b/MLPY/Lib/site-packages/torch/include/ATen/Backtrace.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d6eba46720207605fd2b6640ce48c9ae0bffd20
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Backtrace.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..3135125e8d348b2b363617be3cc4a703fe814443
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h
@@ -0,0 +1,343 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+/*
+ * The basic strategy for apply is as follows:
+ *
+ * 1. Starting with the outermost index, loop until we reach a dimension where
+ * the data is no longer contiguous, i.e. the stride at that dimension is not
+ * equal to the size of the tensor defined by the outer dimensions. Let's call
+ * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
+ * A is equal to the entire Tensor. Let's call the inner tensor B.
+ *
+ * 2. We loop through the indices in B, starting at its outermost dimension. For
+ * example, if B is a 2x2 matrix, then we do:
+ *
+ * B[0][0]
+ * B[0][1]
+ * B[1][0]
+ * B[1][1]
+ *
+ * We set the offset into the underlying storage as (storageOffset + stride_B *
+ * index_B), i.e. basically we compute the offset into the storage as we would
+ * normally for a Tensor. But because we are guaranteed the subsequent data is
+ * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
+ * the operation, without having to follow the order described by the strides of
+ * A.
+ *
+ * 3. As an optimization, we merge dimensions of A that are contiguous in
+ * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
+ * then the first two dimensions can be merged for the purposes of APPLY,
+ * reducing the number of nested loops.
+ */
+
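+/*
+ * Illustrative example (not part of the original header): a 10x2 view narrowed
+ * from a contiguous 10x4 tensor has sizes {10, 2} and strides {4, 1}. Only the
+ * innermost dimension is contiguous, so conceptually the elements are visited
+ * as 10 strided chunks (storage offsets 0, 4, 8, ...) of 2 contiguous values
+ * each, rather than by following every dimension's stride for every element.
+ */
+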
+inline Tensor sort_strides(Tensor& tensor_) {
+  IntArrayRef strides = tensor_.strides();
+  std::vector indices;
+  indices.reserve(tensor_.ndimension());
+  for (const auto i : c10::irange(tensor_.ndimension())) {
+    indices.push_back(i);
+  }
+  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
+    return strides[i1] > strides[i2];
+  });
+  Tensor tensor = tensor_.permute(indices);
+  return tensor;
+}
+
+template 
+struct strided_tensor_iter_fixed {
+ public:
+  T* data_ = NULL;
+  int64_t dim_ = 0;
+
+  int64_t counter_[N] = {0};
+  int64_t sizes_[N] = {0};
+  int64_t strides_[N] = {0};
+
+  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
+  void operator=(strided_tensor_iter_fixed const& x) = delete;
+  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
+  strided_tensor_iter_fixed(
+      Tensor& tensor,
+      C10_UNUSED bool sort_strides = false)
+      : data_(tensor.data_ptr()) {
+    std::memset(counter_, 0, sizeof(int64_t) * N);
+    if (tensor.dim() > 0) {
+      std::memcpy(
+          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
+      std::memcpy(
+          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
+    }
+    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
+  }
+};
+
+template 
+struct strided_tensor_iter {
+ private:
+ public:
+  T* data_ = NULL;
+  int64_t dim_;
+
+  std::vector counter_;
+  std::vector sizes_;
+  std::vector strides_;
+
+  strided_tensor_iter(strided_tensor_iter const&) = delete;
+  void operator=(strided_tensor_iter const& x) = delete;
+  strided_tensor_iter(strided_tensor_iter&&) = default;
+  strided_tensor_iter(Tensor& tensor)
+      : data_(tensor.data_ptr()),
+        dim_(tensor.ndimension()),
+        counter_(dim_, 0),
+        sizes_(tensor.sizes().vec()),
+        strides_(tensor.strides().vec()) {
+    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
+  }
+};
+
+inline bool _all_equal_numel(at::ArrayRef tensors) {
+  if (tensors.empty())
+    return true;
+  int64_t all_numel = tensors[0].numel();
+  for (const auto i : c10::irange(1, tensors.size())) {
+    if (tensors[i].numel() != all_numel)
+      return false;
+  }
+  return true;
+}
+
+inline std::string _all_equal_numel_error(at::ArrayRef tensors) {
+  std::ostringstream oss;
+  oss << "inconsistent tensor size, expected ";
+  for (size_t i = 0; i < tensors.size() - 1; i++) {
+    oss << tensors[i].sizes() << ", ";
+  }
+  oss << "and " << tensors[tensors.size() - 1].sizes()
+      << " to have the same number of elements, but got ";
+  for (size_t i = 0; i < tensors.size() - 1; i++) {
+    oss << tensors[i].numel() << ", ";
+  }
+  oss << "and " << tensors[tensors.size() - 1].numel()
+      << " elements respectively";
+  return oss.str();
+}
+
+inline bool _apply_preamble(ArrayRef tensors) {
+  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
+  checkLayout("CPU_tensor_apply", tensors, kStrided);
+  if (!_all_equal_numel(tensors))
+    AT_ERROR(_all_equal_numel_error(tensors));
+  // An empty tensor has no elements
+  for (auto& t : tensors)
+    if (t.numel() == 0)
+      return false;
+  return true;
+}
+
+inline int64_t _max_dim_tensors(ArrayRef tensors) {
+  int64_t dim = 0;
+  for (auto& t : tensors)
+    dim = std::max(dim, t.ndimension());
+  return dim;
+}
+
+inline void iterate(int64_t /*size*/){};
+
+template 
+inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
+  iter.counter_[iter.dim_ - 1] += size;
+  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
+  iterate(size, iter_tail...);
+}
+
+inline bool iterate_continue() {
+  return true;
+};
+
+template 
+inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
+  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
+      iterate_continue(iter_tail...);
+}
+
+inline int64_t max_iterate_size() {
+  return std::numeric_limits::max();
+};
+
+template 
+inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
+  return std::min(
+      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
+      max_iterate_size(iter_tail...));
+}
+
+inline void iterate_overflow(){};
+
+template 
+inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
+  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
+    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
+      if (iter.counter_[i] == iter.sizes_[i]) {
+        iter.counter_[i] = 0;
+        iter.counter_[i - 1]++;
+        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
+            iter.strides_[i - 1];
+      }
+    }
+  }
+  iterate_overflow(iter_tail...);
+}
+
+inline void forward(int64_t /*offset*/){};
+
+template 
+inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
+  int64_t multi = offset;
+  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
+    int64_t inc = multi % iter.sizes_[i];
+    multi = multi / iter.sizes_[i];
+    iter.data_ = iter.data_ + inc * iter.strides_[i];
+    iter.counter_[i] += inc;
+  }
+  forward(offset, iter_tail...);
+}
+
+inline int64_t max_dim() {
+  return 0;
+}
+
+template 
+inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
+  return std::max(iter.dim_, max_dim(iter_tail...));
+}
+
+inline void apply_op(){};
+
+template 
+inline void apply_op(
+    int64_t numel,
+    int64_t offset,
+    const Op& op,
+    Args... iters) {
+  // For 0-dim tensors
+  if (numel == 1 && max_dim(iters...) == 0) {
+    op(*iters.data_...);
+    return;
+  }
+  if (offset > 0)
+    forward(offset, iters...);
+  // Splitting this into chunks helps the compiler create faster assembly
+  for (int64_t i = 0; i < numel;) {
+    for (; iterate_continue(iters...) && i < numel;) {
+      op(*iters.data_...);
+      iterate(1, iters...);
+      i++;
+    }
+    iterate_overflow(iters...);
+  }
+}
+
+/*
+  Apply a pointwise operator to a sequence of tensors.
+
+  The calling convention for op is a function/functor that takes one reference
+  of type scalar per given tensor (the iterators are dereferenced before op is
+  called). For example, to compute a = b * c, op would be of the form:
+  [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
+    a_val = b_val * c_val; };
+*/
+
+template 
+inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
+  if (!_apply_preamble({tensor1, tensor2}))
+    return;
+  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter_fixed(tensor1),
+        strided_tensor_iter_fixed(tensor2));
+  } else {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter(tensor1),
+        strided_tensor_iter(tensor2));
+  }
+}
+
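+// Illustrative sketch (hypothetical helper, not part of the original header):
+// using CPU_tensor_apply2 with the calling convention described above to
+// compute a += b pointwise on float tensors.
+inline void example_apply2_add_(Tensor& a, Tensor& b) {
+  CPU_tensor_apply2<float, float>(
+      a, b, [](float& a_val, const float& b_val) { a_val += b_val; });
+}
+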
+template 
+inline void CPU_tensor_apply3(
+    Tensor tensor1,
+    Tensor tensor2,
+    Tensor tensor3,
+    const Op op) {
+  if (!_apply_preamble({tensor1, tensor2, tensor3}))
+    return;
+  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter_fixed(tensor1),
+        strided_tensor_iter_fixed(tensor2),
+        strided_tensor_iter_fixed(tensor3));
+  } else {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter(tensor1),
+        strided_tensor_iter(tensor2),
+        strided_tensor_iter(tensor3));
+  }
+}
+
+template <
+    typename scalar1,
+    typename scalar2,
+    typename scalar3,
+    typename scalar4,
+    typename Op>
+inline void CPU_tensor_apply4(
+    Tensor tensor1,
+    Tensor tensor2,
+    Tensor tensor3,
+    Tensor tensor4,
+    const Op op) {
+  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
+    return;
+  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter_fixed(tensor1),
+        strided_tensor_iter_fixed(tensor2),
+        strided_tensor_iter_fixed(tensor3),
+        strided_tensor_iter_fixed(tensor4));
+  } else {
+    apply_op(
+        tensor1.numel(),
+        0,
+        op,
+        strided_tensor_iter(tensor1),
+        strided_tensor_iter(tensor2),
+        strided_tensor_iter(tensor3),
+        strided_tensor_iter(tensor4));
+  }
+}
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..a4d75a9245fc877e844dc1db699d68c286fe6a5c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include 
+#include 
+
+// This file creates a fake allocator that just throws exceptions if
+// it is actually used.
+
+// state passed to the allocator is the std::function called
+// when the blob is released by ATen
+
+namespace at {
+
+static void* cpu_fixed_malloc(void*, ptrdiff_t) {
+  AT_ERROR("attempting to resize a tensor view of an external blob");
+}
+
+static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) {
+  AT_ERROR("attempting to resize a tensor view of an external blob");
+}
+
+static void cpu_fixed_free(void* state, void* allocation) {
+  auto on_release = static_cast*>(state);
+  (*on_release)(allocation);
+  delete on_release;
+}
+
+static Allocator CPU_fixed_allocator = {
+    cpu_fixed_malloc,
+    cpu_fixed_realloc,
+    cpu_fixed_free};
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..fb55baadc951e7a56b5e6c3b832e1868cb64684d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..78548339cb9c38e5891d1a1606a8eaa48bc0e5df
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h
@@ -0,0 +1,576 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..34dc2c57b29e05e6efa46a6951001f626d7d9046
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
+  ~CPUGeneratorImpl() override = default;
+
+  // CPUGeneratorImpl methods
+  std::shared_ptr clone() const;
+  void set_current_seed(uint64_t seed) override;
+  void set_offset(uint64_t offset) override;
+  uint64_t get_offset() const override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_state(const c10::TensorImpl& new_state) override;
+  c10::intrusive_ptr get_state() const override;
+  static c10::DeviceType device_type();
+  uint32_t random();
+  uint64_t random64();
+  c10::optional next_float_normal_sample();
+  c10::optional next_double_normal_sample();
+  void set_next_float_normal_sample(c10::optional randn);
+  void set_next_double_normal_sample(c10::optional randn);
+  at::mt19937 engine();
+  void set_engine(at::mt19937 engine);
+
+ private:
+  CPUGeneratorImpl* clone_impl() const override;
+  at::mt19937 engine_;
+  c10::optional next_float_normal_sample_;
+  c10::optional next_double_normal_sample_;
+};
+
+namespace detail {
+
+TORCH_API const Generator& getDefaultCPUGenerator();
+TORCH_API Generator
+createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
+
+} // namespace detail
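+
+// Illustrative usage sketch (hedged: check_generator and Generator::mutex()
+// are recalled from ATen's Generator API and are not declared in this file):
+//
+//   at::Generator gen = at::detail::createCPUGenerator(/*seed_val=*/42);
+//   std::lock_guard<std::mutex> lock(gen.mutex());
+//   auto* impl = at::check_generator<at::CPUGeneratorImpl>(gen);
+//   uint64_t bits = impl->random64(); // raw bits from the mt19937 engine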
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..2fae8914e35a412df8a615d3f403c4aee5ba758e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a805f1717f26c93472b0c6603d638d6176652dc8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h
@@ -0,0 +1,614 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..573ac8e18c2548bde3f97e1489be102d9bd89f3d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include 
+
+namespace at::caching {
+
+// Some systems (just cudagraphs currently) will persist a static tensor output
+// whose TensorImpl does not change across iterations. For these tensors, caching
+// dtype conversions is invalid. Additionally, there will be an extra reference
+// count to these cached tensors that would prevent buffer inplacing and other
+// checks on tensor uniqueness. If we are not using these systems, the enabled
+// flag will be false and we will avoid the hash lookup.
+
+TORCH_API bool is_cached_tensor(const at::Tensor& t);
+TORCH_API void add_cached_tensor(const at::Tensor& t);
+TORCH_API void remove_cached_tensor(const at::Tensor& t);
+TORCH_API void set_cached_tensors_enabled(bool enable);
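+
+// Illustrative call pattern (hypothetical call sites, not part of this header):
+//
+//   at::caching::set_cached_tensors_enabled(true); // e.g. while cudagraphs is active
+//   at::caching::add_cached_tensor(static_output);
+//   ...
+//   if (!at::caching::is_cached_tensor(t)) { /* caching dtype conversions is OK */ }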
+
+// For gradient buffer stealing we will adjust the use count of tensors
+// which are persisted by cudagraphs, just as we need to adjust reference
+// count of tensors with hooks.
+TORCH_API size_t adjusted_use_count(const at::Tensor& t);
+
+} // namespace at::caching
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CollapseDims.h b/MLPY/Lib/site-packages/torch/include/ATen/CollapseDims.h
new file mode 100644
index 0000000000000000000000000000000000000000..b7ca0d9db788470049ff8ce48a433217ffeb5cc3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CollapseDims.h
@@ -0,0 +1,94 @@
+#include 
+#include 
+
+namespace at {
+
+/*
+[collapse dims] Updates sizes and strides to reflect a "collapse" of
+the info, possibly excluding the optional excludeDim. A "collapsed" version
+of the info is the fewest dims that order the tensor's elements in the same
+way as the original info. If excludeDim is specified, the collapse is the
+fewest dims that order the tensor's elements as the original and preserve the
+excluded dimension, unless the tensor collapses to a point.
+
+This function returns a pair of values.
+
+1) The (new) index of the preserved dimension if excludeDim is
+specified. 0 if the tensor is collapsed to a point. -1
+otherwise.
+
+2) The new number of dimensions.
+*/
+template 
+inline std::pair collapse_dims(
+    T* sizes,
+    T* strides,
+    int64_t dims,
+    const int excludeDim = -1) {
+  TORCH_CHECK(
+      excludeDim >= -1 && excludeDim < dims,
+      "expected excluded dim between -1 and dims - 1");
+
+  int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
+  int64_t newIndex = -1;
+  int64_t oldIndex = 0;
+  int64_t remappedExcludedDim = -1;
+
+  while (oldIndex < dims) {
+    // Finds a dimension to collapse into
+    for (; oldIndex < stopDim; ++oldIndex) {
+      if (sizes[oldIndex] == 1) {
+        continue;
+      }
+
+      ++newIndex;
+      sizes[newIndex] = sizes[oldIndex];
+      strides[newIndex] = strides[oldIndex];
+      ++oldIndex;
+      break;
+    }
+
+    // Collapses dims
+    for (; oldIndex < stopDim; ++oldIndex) {
+      if (sizes[oldIndex] == 1) {
+        continue;
+      }
+
+      if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
+        sizes[newIndex] *= sizes[oldIndex];
+        strides[newIndex] = strides[oldIndex];
+      } else {
+        ++newIndex;
+        sizes[newIndex] = sizes[oldIndex];
+        strides[newIndex] = strides[oldIndex];
+      }
+    }
+
+    // Handles excludeDim being set (oldIndex == excludeDim)
+    if (oldIndex != dims) {
+      // Preserves excluded dimension
+      ++newIndex;
+      sizes[newIndex] = sizes[oldIndex];
+      strides[newIndex] = strides[oldIndex];
+      remappedExcludedDim = newIndex;
+
+      // Restarts iteration after excludeDim
+      ++oldIndex;
+      stopDim = dims;
+    }
+  }
+
+  // Handles special case of all dims size 1
+  if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
+    dims = 1;
+    sizes[0] = 1;
+    strides[0] = 1;
+
+    return std::pair(0, 1);
+  }
+
+  dims = newIndex + 1;
+  return std::pair(remappedExcludedDim, dims);
+}
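+
+// Worked example (illustrative sketch, not part of the original header): a
+// 3x3x3x3 view narrowed from a contiguous 3x3x4x3 tensor has strides
+// {36, 12, 3, 1}. Dimensions 0-1 and 2-3 can merge (36 == 3 * 12 and
+// 3 == 3 * 1), but 1-2 cannot because of the narrow, so the arrays are
+// rewritten in place to sizes {9, 9}, strides {12, 1}.
+inline std::pair<int64_t, int64_t> collapse_dims_example() {
+  int64_t sizes[4] = {3, 3, 3, 3};
+  int64_t strides[4] = {36, 12, 3, 1};
+  // Returns {-1, 2}: no excluded dim was requested and 2 dims remain.
+  return collapse_dims(sizes, strides, 4);
+}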
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..c4564ba4f32f3a84b638058a84edda3bb49230b4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a1f2556cedee558c7ada96911a13e77ce1d6107d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h
@@ -0,0 +1,542 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..6f96cd9c9d7ba2a487c5b4943b2e50fa2e2d2b99
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..22915229c825ab0cf1aa8488a3bdb67931b96601
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h
@@ -0,0 +1,323 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..91de7d33c69904c252bc999926c1073464bb3e8f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
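Editor's note: the note above describes the {DispatchKey}Functions.h / {DispatchKey}Functions_inl.h split that breaks the TensorBody.h include cycle. Below is a minimal sketch of the same idea with hypothetical file and type names (Widget.h, FooFunctions.h, FooFunctions_inl.h are illustrative, not ATen headers):

    // Widget.h -- plays the role of TensorBody.h: fully defines the class,
    // then pulls in the inline API at the very end.
    #pragma once
    struct Widget {
      int value = 0;
    };
    #include "FooFunctions_inl.h"  // safe: Widget is complete at this point

    // FooFunctions_inl.h -- plays the role of CPUFunctions_inl.h: the API that
    // needs the full Widget definition (e.g. for defaulted arguments). It does
    // NOT include Widget.h, which is what breaks the cycle.
    #pragma once
    inline int foo(const Widget& w, int scale = 2) { return w.value * scale; }

    // FooFunctions.h -- plays the role of CPUFunctions.h: a thin header that
    // other translation units include directly.
    #pragma once
    #include "Widget.h"  // transitively brings in FooFunctions_inl.h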
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..7fed37f3e886e0692495f4b3d1a8dacabc0380e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h
@@ -0,0 +1,500 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
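Editor's note: the #error guard near the top of this header fires when a translation unit that defines TORCH_ASSERT_ONLY_METHOD_OPERATORS (opting into fine-grained, per-operator includes) also pulls in this umbrella header, which depends on every operator. A hedged sketch of the two include styles in a consumer .cpp; the per-operator path ATen/ops/add.h is assumed here for illustration:

    // consumer.cpp -- sketch only.

    // Fine-grained route: only the operators actually used are included, so the
    // file is recompiled only when those operators change.
    #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
    #include <ATen/core/Tensor.h>
    #include <ATen/ops/add.h>   // assumed per-operator header layout

    at::Tensor add_one(const at::Tensor& t) {
      return at::add(t, 1);
    }

    // Umbrella route (omit TORCH_ASSERT_ONLY_METHOD_OPERATORS): a single
    //   #include <ATen/Functions.h>
    // brings in the whole C++ API, at the cost of recompiling whenever any
    // operator is added or changed -- exactly what the #error above warns about.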
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..8b065d6c742bf58aa05ef33b03c16f7f61bddbe4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5b7c77fbca654b74b428193f1da16b348f6d325
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
@@ -0,0 +1,25 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Config.h b/MLPY/Lib/site-packages/torch/include/ATen/Config.h
new file mode 100644
index 0000000000000000000000000000000000000000..7c22566d9c7b87ac9d41156b1d7ba71f7f6fefda
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Config.h
@@ -0,0 +1,22 @@
+#pragma once
+
+// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
+// obvious if you forgot to include Config.h
+//    c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
+//
+// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
+
+#define AT_MKLDNN_ENABLED() 1
+#define AT_MKLDNN_ACL_ENABLED() 0
+#define AT_MKL_ENABLED() 1
+#define AT_MKL_SEQUENTIAL() 0
+#define AT_POCKETFFT_ENABLED() 0
+#define AT_NNPACK_ENABLED() 0
+#define CAFFE2_STATIC_LINK_CUDA() 0
+#define AT_BUILD_WITH_BLAS() 1
+#define AT_BUILD_WITH_LAPACK() 1
+#define AT_PARALLEL_OPENMP 1
+#define AT_PARALLEL_NATIVE 0
+#define AT_PARALLEL_NATIVE_TBB 0
+#define AT_BLAS_F2C() 0
+#define AT_BLAS_USE_CBLAS_DOT() 0
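Editor's note: the comment at the top of Config.h asks consumers to test these switches with the function-like form, e.g. #if AT_MKL_ENABLED(), rather than #ifdef. A short sketch of why that convention catches a missing include:

    #include <ATen/Config.h>

    void fft_dispatch() {
    #if AT_MKL_ENABLED()
      // MKL-backed implementation.
      //
      // If Config.h had not been included, the unknown identifier AT_MKL_ENABLED
      // would be replaced by 0 inside #if, and "0 ()" is not a valid expression,
      // so the build fails loudly. A plain #ifdef AT_MKL_ENABLED would instead
      // silently fall into the #else branch.
    #else
      // Generic fallback implementation.
    #endif
    }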
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Context.h b/MLPY/Lib/site-packages/torch/include/ATen/Context.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8e2d98216334f124c52ba2765e69be140f00196
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Context.h
@@ -0,0 +1,560 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at {
+
+class Tensor;
+
+enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
+
+class TORCH_API Context {
+ public:
+  Context();
+
+  const Generator& defaultGenerator(Device device) {
+    c10::DeviceType device_type = device.type();
+    initCUDAIfNeeded(device_type);
+    initHIPIfNeeded(device_type);
+    if (device_type == at::kCPU) {
+      return at::detail::getDefaultCPUGenerator();
+    } else if (device_type == at::kCUDA) {
+      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
+    } else if (device_type == at::kMPS) {
+      return at::detail::getMPSHooks().getDefaultMPSGenerator();
+    } else if (device_type == at::kXPU) {
+      return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
+    } else if (device_type == at::kIPU) {
+      return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
+    } else if (device_type == at::kPrivateUse1) {
+      return at::GetPrivateUse1HooksInterface()->getDefaultGenerator(
+          device.index());
+    } else {
+      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
+    }
+  }
+  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
+      c10::optional opt_device_type = c10::nullopt) {
+    c10::DeviceType device_type = opt_device_type.has_value()
+        ? opt_device_type.value()
+        : at::getAccelerator(true).value();
+    if (device_type == at::kCUDA) {
+      return at::detail::getCUDAHooks();
+    } else if (device_type == at::kMPS) {
+      return at::detail::getMPSHooks();
+    } else if (device_type == at::kPrivateUse1) {
+      return at::detail::getPrivateUse1Hooks();
+    } else {
+      AT_ERROR(
+          c10::DeviceTypeName(device_type), " device type not an accelerator.");
+    }
+  }
+  Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
+    initCUDAIfNeeded(device_type);
+    initHIPIfNeeded(device_type);
+    initXPUIfNeeded(device_type);
+    if (device_type == at::kCPU) {
+      return c10::DeviceType::CPU;
+    } else if (device_type == at::kCUDA) {
+      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
+    } else if (device_type == at::kXPU) {
+      return at::detail::getXPUHooks().getDeviceFromPtr(data);
+    } else if (device_type == at::kPrivateUse1) {
+      return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data);
+    } else {
+      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
+    }
+  }
+  static bool isPinnedPtr(const void* data) {
+    return detail::getCUDAHooks().isPinnedPtr(data);
+  }
+  static bool hasOpenMP();
+  static bool hasMKL();
+  static bool hasLAPACK();
+  static bool hasMKLDNN();
+  static bool hasMAGMA() {
+    return detail::getCUDAHooks().hasMAGMA();
+  }
+  static bool hasCUDA() {
+    return detail::getCUDAHooks().hasCUDA();
+  }
+  static bool hasMTIA() {
+    return detail::getMTIAHooks().hasMTIA();
+  }
+  static bool hasCUDART() {
+    return detail::getCUDAHooks().hasCUDART();
+  }
+  static long versionCUDART() {
+    return detail::getCUDAHooks().versionCUDART();
+  }
+  static bool hasCuDNN() {
+    return detail::getCUDAHooks().hasCuDNN();
+  }
+  static long versionCuDNN() {
+    return detail::getCUDAHooks().versionCuDNN();
+  }
+  static bool hasCuSOLVER() {
+    return detail::getCUDAHooks().hasCuSOLVER();
+  }
+  static bool hasHIP() {
+    return detail::getHIPHooks().hasHIP();
+  }
+  static bool hasMPS() {
+    return detail::getMPSHooks().hasMPS();
+  }
+  static bool hasIPU() {
+    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
+  }
+  static bool hasXLA() {
+    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
+  }
+  static bool hasXPU() {
+    return detail::getXPUHooks().hasXPU();
+  }
+  static bool hasLazy() {
+    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
+  }
+  static bool hasORT() {
+    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::ORT);
+  }
+  // Defined in the header so that getNonVariableType can inline the
+  // call_once check; getNonVariableType is called fairly frequently.
+  void lazyInitCUDA() {
+    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
+  }
+  void lazyInitHIP() {
+    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
+  }
+  void lazyInitXPU() {
+    c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
+  }
+  void lazyInitPrivateUse1() {
+    c10::call_once(thp_init, [&] {
+      if (isPrivateUse1HooksRegistered()) {
+        at::GetPrivateUse1HooksInterface()->initPrivateUse1();
+      }
+    });
+  }
+  static const at::cuda::NVRTC& getNVRTC() {
+    return detail::getCUDAHooks().nvrtc();
+  }
+
+  static bool setFlushDenormal(bool on);
+
+  // NB: This method *purely* reflects whether or not a user requested
+  // that CuDNN be enabled; it doesn't actually say anything about
+  // whether or not CuDNN is actually usable.  Use cudnn_is_acceptable
+  // to test this instead.
+  bool userEnabledCuDNN() const;
+  void setUserEnabledCuDNN(bool e);
+  bool userEnabledMkldnn() const;
+  void setUserEnabledMkldnn(bool e);
+  bool benchmarkCuDNN() const;
+  void setBenchmarkCuDNN(bool);
+  int benchmarkLimitCuDNN() const;
+  void setBenchmarkLimitCuDNN(int);
+  bool deterministicCuDNN() const;
+  void setDeterministicCuDNN(bool);
+  bool userEnabledNNPACK() const;
+  void setUserEnabledNNPACK(bool e);
+
+  // Note [Disabling Fused SDP Kernels]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Flash and Memory Efficient SDP kernels are enabled by default.
+  // However, they can be disabled by setting
+  // at::globalContext().setUserEnabledFlashSDP(false) flag.
+  // This is useful for debugging purposes. For example, if you want to
+  // compare the performance of the flash SDP kernels with the unfused
+  // kernel, you can disable the flash SDP kernels. By disabling
+  // the math SDP kernel, you can force your code to use flash kernels.
+  // The math SDP kernel can be disabled by setting
+  // at::globalContext().setUserEnabledMathSDP(false) flag.
+  void setSDPUseFlash(bool);
+  bool userEnabledFlashSDP() const;
+
+  void setSDPUseMemEfficient(bool);
+  bool userEnabledMemEfficientSDP() const;
+
+  void setSDPUseMath(bool);
+  bool userEnabledMathSDP() const;
+
+  void setSDPUseCuDNN(bool);
+  bool userEnabledCuDNNSDP() const;
+
+  at::LinalgBackend linalgPreferredBackend() const;
+  void setLinalgPreferredBackend(at::LinalgBackend);
+
+  // Note [Enabling Deterministic Operations]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Operations in PyTorch that normally act nondeterministically, but have an
+  // alternate deterministic implementation, should satisfy the following
+  // requirements:
+  //
+  // * Include this comment: "See Note [Enabling Deterministic Operations]"
+  //
+  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
+  // toggle
+  //   between nondeterministic and deterministic implementations.
+  //
+  // * Have an entry in the list of PyTorch operations that toggle between
+  // nondeterministic
+  //   and deterministic implementations, in the docstring of
+  //   `use_deterministic_algorithms()` in torch/__init__.py
+  //
+  // `example_func()` below shows an example of toggling between
+  // nondeterministic and deterministic implementations:
+  //
+  //    void example_func() {
+  //      // See Note [Enabling Deterministic Operations]
+  //      if (at::globalContext().deterministicAlgorithms()) {
+  //        example_func_deterministic();
+  //      } else {
+  //        example_func_nondeterministic();
+  //      }
+  //    }
+
+  bool deterministicAlgorithms() const;
+  bool deterministicAlgorithmsWarnOnly() const;
+  void setDeterministicAlgorithms(bool, bool);
+  bool deterministicFillUninitializedMemory() const;
+  void setDeterministicFillUninitializedMemory(bool);
+
+  // Note [Writing Nondeterministic Operations]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Operations in PyTorch that act nondeterministically and do not have an
+  // alternate deterministic implementation should satisfy the following
+  // requirements:
+  //
+  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
+  //
+  // * Include a comment explaining why the operation is nondeterministic.
+  //
+  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
+  //   of the time, this should be accomplished by calling
+  //   `at::globalContext().alertNotDeterministic()`.  However, if the
+  //   nondeterministic behavior is caused by the CuBLAS workspace
+  //   configuration in CUDA >= 10.2,
+  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
+  //   called instead (in this case, a comment explaining why the operation is
+  //   nondeterministic is not necessary). See below for details on these
+  //   methods.
+  //
+  // * Have an entry in the list of nondeterministic PyTorch operations in the
+  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
+  //
+  // * Have a test function in `test/test_torch.py` whose name begins with
+  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
+  //   configuration is the reason for nondeterminism, the operation should be
+  //   included in the `test_cublas_config_nondeterministic_alert` test. Any new
+  //   tests should ideally follow a pattern similar to the existing ones.
+  //
+  // `example_func()` below shows an example of the comments and error-throwing
+  // code for a nondeterministic operation:
+  //
+  //    void example_func() {
+  //      // See Note [Writing Nondeterministic Operations]
+  //      // Nondeterministic because 
+  //      at::globalContext().alertNondeterministic("example_func");
+  //      ...
+  //    }
+
+  // Throws an error if `Context::deterministicAlgorithms()` is true
+  static void alertNotDeterministic(c10::string_view const& caller);
+
+  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
+  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
+  // ":4096:8". For more details:
+  // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+  void alertCuBLASConfigNotDeterministic() const;
+
+  void setFloat32MatmulPrecision(const std::string& s);
+  bool allowTF32CuDNN() const;
+  void setAllowTF32CuDNN(bool);
+  bool allowTF32CuBLAS() const;
+  void setAllowTF32CuBLAS(bool);
+  Float32MatmulPrecision float32MatmulPrecision() const;
+  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
+  bool allowFP16ReductionCuBLAS() const;
+  void setAllowFP16ReductionCuBLAS(bool);
+  bool allowBF16ReductionCuBLAS() const;
+  void setAllowBF16ReductionCuBLAS(bool);
+  at::QEngine qEngine() const;
+  void setQEngine(at::QEngine e);
+  static const std::vector& supportedQEngines();
+  static bool isXNNPACKAvailable();
+  void setCheckSparseTensorInvariants(bool e);
+  bool checkSparseTensorInvariants() const;
+  // This method is used to release the original weight after pre-packing.
+  // It should be called once before loading/running the model.
+  // NB: By default it is set to true for mobile builds.
+  void setReleaseWeightsWhenPrepacking(bool e);
+  bool releaseWeightsWhenPrepacking() const;
+
+  void setDisplayVmapFallbackWarnings(bool enabled);
+  bool areVmapFallbackWarningsEnabled() const;
+
+  void setDefaultMobileCPUAllocator();
+  void unsetDefaultMobileCPUAllocator();
+  bool allowFP16ReductionCPU() const;
+  void setAllowFP16ReductionCPU(bool);
+
+ private:
+  void initCUDAIfNeeded(c10::DeviceType p) {
+    if (p == c10::DeviceType::CUDA) {
+      lazyInitCUDA();
+    }
+  }
+  void initHIPIfNeeded(c10::DeviceType p) {
+    if (p == c10::DeviceType::HIP) {
+      lazyInitHIP();
+    }
+  }
+  void initXPUIfNeeded(c10::DeviceType p) {
+    if (p == c10::DeviceType::XPU) {
+      lazyInitXPU();
+    }
+  }
+  static bool checkCuBLASConfigDeterministic();
+  c10::once_flag thc_init;
+  c10::once_flag thh_init;
+  c10::once_flag thx_init;
+  c10::once_flag thp_init;
+  bool enabled_cudnn = true;
+  bool deterministic_cudnn = false;
+  bool _deterministic_algorithms = false;
+  bool _deterministic_algorithms_warn_only = false;
+  bool _deterministic_fill_uninitialized_memory = true;
+  bool enabled_flashSDP = true;
+  bool enabled_mem_efficientSDP = true;
+  bool enabled_mathSDP = true;
+  bool enabled_cudnnSDP = false;
+#ifdef USE_ROCM
+  bool benchmark_cudnn = true;
+#else
+  bool benchmark_cudnn = false;
+#endif
+  Float32MatmulPrecision float32_matmul_precision =
+      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
+      ? at::Float32MatmulPrecision::HIGH
+      : at::Float32MatmulPrecision::HIGHEST;
+  int benchmark_limit_cudnn = 10;
+  bool allow_tf32_cudnn = true;
+  bool allow_fp16_reduction_cublas = true;
+  bool allow_bf16_reduction_cublas = true;
+  bool enabled_mkldnn = true;
+  bool enabled_nnpack = true;
+  at::LinalgBackend linalg_preferred_backend =
+      c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
+      ? at::LinalgBackend::Cusolver
+      : at::LinalgBackend::Default;
+#ifdef C10_MOBILE
+  bool release_original_weights = true;
+#else
+  bool release_original_weights = false;
+#endif
+  bool display_vmap_fallback_warnings_ = false;
+  c10::optional quantized_engine = c10::nullopt;
+  bool enable_sparse_tensor_invariant_checks = false;
+  bool allow_fp16_reduction_cpu = false;
+
+  Allocator* prev_allocator_ptr_{nullptr};
+};
+
+TORCH_API Context& globalContext();
+
+static inline void init() {
+  globalContext();
+}
+
+TORCH_API Allocator* getCPUAllocator();
+
+static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
+    Backend p,
+    ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      p, s);
+}
+
+static inline DeprecatedTypeProperties& CPU(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::CPU, s);
+}
+
+static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::CUDA, s);
+}
+
+static inline DeprecatedTypeProperties& HIP(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::HIP, s);
+}
+
+static inline DeprecatedTypeProperties& MPS(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::MPS, s);
+}
+
+static inline bool hasCUDA() {
+  return globalContext().hasCUDA();
+}
+
+static inline bool hasMTIA() {
+  return globalContext().hasMTIA();
+}
+
+static inline bool hasHIP() {
+  return globalContext().hasHIP();
+}
+
+static inline bool hasIPU() {
+  return globalContext().hasIPU();
+}
+
+static inline bool hasXLA() {
+  return globalContext().hasXLA();
+}
+
+static inline bool hasMPS() {
+  return globalContext().hasMPS();
+}
+
+static inline bool hasORT() {
+  return globalContext().hasORT();
+}
+
+static inline bool hasXPU() {
+  return globalContext().hasXPU();
+}
+
+// Despite its name, this function returns the number of *CUDA* GPUs.
+static inline size_t getNumGPUs() {
+  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
+  // FUNCTION.  If you are interested in interrogating the number of
+  // devices for a specific device type, add that function to the
+  // relevant library (e.g., similar to at::cuda::device_count())
+  if (hasCUDA() && hasHIP()) {
+    throw std::runtime_error(
+        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
+        "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
+        "means HIP.  Rebuild PyTorch with one or the other disabled.");
+  } else if (hasCUDA()) {
+    return detail::getCUDAHooks().getNumGPUs();
+  } else if (hasHIP()) {
+    return detail::getHIPHooks().getNumGPUs();
+  } else {
+    return 0;
+  }
+}
+
+static inline bool hasOpenMP() {
+  return globalContext().hasOpenMP();
+}
+
+static inline bool hasMKL() {
+  return globalContext().hasMKL();
+}
+
+static inline bool hasLAPACK() {
+  return globalContext().hasLAPACK();
+}
+
+static inline bool hasMAGMA() {
+  return globalContext().hasMAGMA();
+}
+
+static inline bool hasMKLDNN() {
+  return globalContext().hasMKLDNN();
+}
+
+static inline void manual_seed(uint64_t seed) {
+  auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard lock(gen.mutex());
+    gen.set_current_seed(seed);
+  }
+  // NB: Sometimes we build with CUDA, but we don't have any GPUs
+  // available. In that case, we must not seed CUDA; it will fail!
+  const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
+  if (hasCUDA() && cuda_num_gpus > 0) {
+    for (const auto i : c10::irange(cuda_num_gpus)) {
+      auto cuda_gen = globalContext().defaultGenerator(
+          Device(at::kCUDA, static_cast(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard lock(cuda_gen.mutex());
+        cuda_gen.set_current_seed(seed);
+      }
+    }
+  }
+
+  const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
+  if (hasXPU() && xpu_num_gpus) {
+    for (const auto i : c10::irange(xpu_num_gpus)) {
+      auto xpu_gen = globalContext().defaultGenerator(
+          Device(at::kXPU, static_cast(i)));
+      {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard lock(xpu_gen.mutex());
+        xpu_gen.set_current_seed(seed);
+      }
+    }
+  }
+
+  if (hasMPS()) {
+    auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard lock(mps_gen.mutex());
+    mps_gen.set_current_seed(seed);
+  }
+}
+
+// When the global flag `allow_tf32` is set to true, cuBLAS handles are
+// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
+// For some operators, such as addmv, TF32 offers no performance improvement
+// but causes precision loss. To help this case, this class implements
+// a RAII guard that can be used to quickly disable TF32 within its scope.
+//
+// Usage:
+//     NoTF32Guard disable_tf32;
+struct TORCH_API NoTF32Guard {
+  NoTF32Guard();
+  ~NoTF32Guard();
+  static bool should_disable_tf32();
+
+ private:
+  bool changed = false;
+};
+
+struct TORCH_API ROCmBackwardPassGuard {
+  ROCmBackwardPassGuard();
+  ~ROCmBackwardPassGuard();
+  static bool is_backward_pass();
+};
+
+} // namespace at
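Editor's note: a hedged sketch of how the toggles documented above are typically driven from C++ through at::globalContext(); the specific values are arbitrary examples, not recommendations:

    #include <ATen/Context.h>

    void configure_runtime() {
      at::Context& ctx = at::globalContext();

      // See Note [Enabling Deterministic Operations]: error (not merely warn)
      // when an op has no deterministic implementation.
      ctx.setDeterministicAlgorithms(/*deterministic=*/true, /*warn_only=*/false);

      // See Note [Disabling Fused SDP Kernels]: force the math fallback by
      // turning the fused kernels off.
      ctx.setSDPUseFlash(false);
      ctx.setSDPUseMemEfficient(false);
      ctx.setSDPUseMath(true);

      // Seed the default generator of every available backend in one call.
      at::manual_seed(42);
    }

    void precision_sensitive_section() {
      at::NoTF32Guard disable_tf32;  // TF32 stays off until this guard is destroyed
      // ... matmuls that must not use TF32 ...
    }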
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/DLConvertor.h b/MLPY/Lib/site-packages/torch/include/ATen/DLConvertor.h
new file mode 100644
index 0000000000000000000000000000000000000000..70254bc97a4f61b531bc8491d0ba7b93d253fc63
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/DLConvertor.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+// this convertor will:
+// 1) take a Tensor object and wrap it in the DLPack tensor
+// 2) take a dlpack tensor and convert it to the ATen Tensor
+
+namespace at {
+
+TORCH_API ScalarType toScalarType(const DLDataType& dtype);
+TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
+TORCH_API Tensor fromDLPack(DLManagedTensor* src);
+C10_DEPRECATED_MESSAGE("Please migrate to a non-const variant")
+inline Tensor fromDLPack(const DLManagedTensor* src) {
+  return fromDLPack(const_cast(src));
+}
+TORCH_API Tensor
+fromDLPack(DLManagedTensor* src, std::function deleter);
+TORCH_API DLDataType getDLDataType(const Tensor& t);
+TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
+
+} // namespace at
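Editor's note: a hedged round-trip sketch for the converter declared above. The ownership comment reflects the usual DLPack convention (fromDLPack consumes the DLManagedTensor) and is stated as an assumption rather than a guarantee:

    #include <ATen/ATen.h>
    #include <ATen/DLConvertor.h>

    void dlpack_round_trip() {
      at::Tensor src = at::rand({2, 3});

      // Wrap the ATen tensor in a DLPack capsule; the capsule keeps src's
      // storage alive through its deleter.
      DLManagedTensor* managed = at::toDLPack(src);

      // Hand the capsule back to ATen. fromDLPack takes ownership of `managed`,
      // so it must not be deleted or reused afterwards.
      at::Tensor roundtripped = at::fromDLPack(managed);

      TORCH_CHECK(roundtripped.sizes() == src.sizes());
    }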
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Device.h b/MLPY/Lib/site-packages/torch/include/ATen/Device.h
new file mode 100644
index 0000000000000000000000000000000000000000..77626cce2465850485e137b148845ee38b9ebb4d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Device.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h b/MLPY/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea564ed66b2e755e441c0338cf30290c90af96fa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+
+// This file defines the top level Accelerator concept for PyTorch.
+// A device is an accelerator per the definition here if:
+// - It is mutually exclusive with all other accelerators
+// - It performs asynchronous compute via a Stream/Event system
+// - It provides a set of common APIs as defined by AcceleratorHooksInterface
+//
+// As of today, accelerator devices are (in no particular order):
+// CUDA, MTIA, PrivateUse1
+// We want to add once all the proper APIs are supported and tested:
+// HIP, MPS, XPU
+
+namespace at {
+
+// Ensures that only one accelerator is available (at
+// compile time if possible) and return it.
+// When checked is true, the returned optional always has a value.
+TORCH_API std::optional getAccelerator(bool checked = false);
+
+} // namespace at
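Editor's note: a small sketch of querying the active accelerator with the API above; with checked=false the optional may be empty (e.g. a CPU-only build), while checked=true is documented to always return a value:

    #include <ATen/DeviceAccelerator.h>
    #include <iostream>
    #include <optional>

    void report_accelerator() {
      std::optional<c10::DeviceType> acc = at::getAccelerator(/*checked=*/false);
      if (acc.has_value()) {
        std::cout << "accelerator: " << c10::DeviceTypeName(*acc) << '\n';
      } else {
        std::cout << "no accelerator available in this build/runtime\n";
      }
    }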
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/DeviceGuard.h b/MLPY/Lib/site-packages/torch/include/ATen/DeviceGuard.h
new file mode 100644
index 0000000000000000000000000000000000000000..0cd52c27cd0b984f96097b8efed095a0e8a84016
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/DeviceGuard.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include  // TensorList whyyyyy
+
+namespace at {
+
+// Are you here because you're wondering why DeviceGuard(tensor) no
+// longer works?  For code organization reasons, we have temporarily(?)
+// removed this constructor from DeviceGuard.  The new way to
+// spell it is:
+//
+//    OptionalDeviceGuard guard(device_of(tensor));
+
+/// Return the Device of a Tensor, if the Tensor is defined.
+inline c10::optional device_of(const Tensor& t) {
+  if (t.defined()) {
+    return c10::make_optional(t.device());
+  } else {
+    return c10::nullopt;
+  }
+}
+
+inline c10::optional device_of(const c10::optional& t) {
+  return t.has_value() ? device_of(t.value()) : c10::nullopt;
+}
+
+/// Return the Device of a TensorList, if the list is non-empty and
+/// the first Tensor is defined.  (This function implicitly assumes
+/// that all tensors in the list have the same device.)
+inline c10::optional device_of(ITensorListRef t) {
+  if (!t.empty()) {
+    return device_of(t.front());
+  } else {
+    return c10::nullopt;
+  }
+}
+
+} // namespace at
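Editor's note: the comment at the top of this header points to OptionalDeviceGuard as the replacement for the removed DeviceGuard(tensor) constructor; a minimal sketch (the c10/core/DeviceGuard.h include is an assumption about where c10::OptionalDeviceGuard lives):

    #include <ATen/DeviceGuard.h>
    #include <c10/core/DeviceGuard.h>  // assumed location of c10::OptionalDeviceGuard

    void run_on_tensors_device(const at::Tensor& t) {
      // device_of(t) is nullopt for an undefined tensor, making the guard a no-op.
      c10::OptionalDeviceGuard guard(at::device_of(t));
      // ... work launched here is directed at t's device ...
    }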
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/DimVector.h b/MLPY/Lib/site-packages/torch/include/ATen/DimVector.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a854a378782824f756ff054d39965d259054351
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/DimVector.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Dimname.h b/MLPY/Lib/site-packages/torch/include/ATen/Dimname.h
new file mode 100644
index 0000000000000000000000000000000000000000..9a93a8e38f8f25d42131a320ecf54a55c59bb481
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Dimname.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/Dispatch.h
new file mode 100644
index 0000000000000000000000000000000000000000..d08a04b45e4244b15e845eecac8dc365112ee3f0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Dispatch.h
@@ -0,0 +1,808 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __CUDACC__
+#include  // For CUDA_VERSION
+#endif
+
+#ifdef TEMPLATE_SELECTIVE_BUILD
+#include 
+#else
+namespace at {
+/**
+ * The method should_include_kernel_dtype() returns true/false
+ * based on whether the switching code for a specific dtype should be
+ * included based on build time constants generated from tracing model
+ * execution. This method will be implemented via code-generation and
+ * included in this file when code-gen is ready.
+ */
+inline constexpr bool should_include_kernel_dtype(
+    const char* /*kernel_tag_str*/,
+    at::ScalarType /*scalar_type*/
+) {
+  return true;
+}
+} // namespace at
+#endif
+
+/**
+ * In the Facebook internal build (using BUCK), this macro is enabled by
+ * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
+ * binary.
+ */
+#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
+namespace at {
+namespace detail {
+TORCH_API void record_kernel_function_dtype(std::string name);
+}
+} // namespace at
+
+#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
+  at::detail::record_kernel_function_dtype(           \
+      std::string(NAME) + "$" + toString(enum_type));
+#else
+#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
+#endif
+
+#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type)   \
+  do {                                                \
+    if constexpr (!at::should_include_kernel_dtype(   \
+                      at_dispatch_name, enum_type)) { \
+      AT_ERROR(                                       \
+          "dtype '",                                  \
+          toString(enum_type),                        \
+          "' not selected for kernel tag ",           \
+          at_dispatch_name);                          \
+    }                                                 \
+  } while (0)
+
+#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)           \
+  case enum_type: {                                                     \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                        \
+    using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT; \
+    return __VA_ARGS__();                                               \
+  }
+
+#define AT_DISPATCH_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
+
+#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...)            \
+  case enum_type: {                                                   \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
+    using scalar_t = scalar_type;                                     \
+    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
+    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
+    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
+    return __VA_ARGS__();                                             \
+  }
+
+#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                           \
+    enum_type, scalar_type, bitwidth, qmin, qmax, ...)                \
+  case enum_type: {                                                   \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                      \
+    using scalar_t = scalar_type;                                     \
+    using underlying_t C10_UNUSED = typename scalar_t::underlying;    \
+    const auto& SCALAR_TYPE C10_UNUSED = enum_type;                   \
+    const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
+    C10_UNUSED int bit_width = bitwidth;                              \
+    C10_UNUSED int64_t quant_min = qmin;                              \
+    C10_UNUSED int64_t quant_max = qmax;                              \
+    return __VA_ARGS__();                                             \
+  }
+
+namespace detail {
+
+inline at::ScalarType scalar_type(at::ScalarType s) {
+  return s;
+}
+
+C10_DEPRECATED_MESSAGE(
+    "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
+    "pass an at::ScalarType instead")
+inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
+  return t.scalarType();
+}
+
+C10_DEPRECATED_MESSAGE(
+    "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
+    "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
+inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
+
+C10_DEPRECATED_MESSAGE(
+    "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
+    "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
+    "instead")
+inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
+
+} // namespace detail
+
+// The AT_DISPATCH_* family of macros provides the ability to
+// conveniently generate specializations of a kernel over all of the
+// dtypes we care about in PyTorch.  We call it "dispatch" because
+// we are "dispatching" to the correct, dtype-specific kernel.
+//
+// A standard usage looks like:
+//
+//      AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
+//          // Your code here, with 'scalar_t' now defined to
+//          // be the dtype in question
+//      });
+//
+// There are many variations of this macro, so it's important to
+// understand exactly /which/ dtypes you want to get instantiated, as
+// well as what the "default" set is.
+//
+// The default set of dtypes that are instantiated (e.g., by
+// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
+// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
+// but NOT booleans (bool), half-precision floats (Half) or
+// complex number (c10::complex, c10::complex).
+// This "cut" is somewhat historical (the default types are the
+// ones that TH historically supported), but it also reflects the
+// fact that the non-default types are "poorly" behaved (booleans
+// are NOT integers mod 2, half precision operations ~essentially
+// don't exist on CPU, complex numbers are an experimental application).
+//
+// Here are the questions you should generally ask to decide which
+// dispatch you want:
+//
+// 1. Is this an integral or floating point specific operation?
+//    (If so, you'll want one of the FLOATING or INTEGRAL macros.)
+//
+// 2. Should half be supported?  (If you're on CPU, the answer is almost
+//    definitely no.  If you do want support, use one of the AND_HALF
+//    macros)
+//
+// Much rarer situations:
+//
+// 3. Should bool be supported?  (You often have to write your kernel
+//    differently if arithmetic operations are involved.)  If so,
+//    Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
+//
+// 4. Should complex be supported?  The answer is almost always no,
+//    unless you are working on "generic" code that should work on
+//    all dtypes.
+//
+// Parameters:
+// -----------
+//
+// 1. The NAME argument is a "tag" that is used to trace and then
+//    conditionally compile fragments of the case statements such
+//    that the kernel functions are specialized only for the dtypes
+//    that are needed. The NAME parameter *must* be a build time
+//    const char* (can't be std::string, etc...)
+//
+// Please ensure that the NAME is unique for every implementation
+// or you run the risk of over-including code for the kernel
+// functions. There is no risk of missing out on any code, so
+// it's mostly a risk of a Type-2 error, and not a Type-1 error.
+//
+// Switch-like syntax:
+// -------------------
+// There is also a switch-case like syntax which is useful if a kernel
+// needs to be specialized for particular scalar types
+//
+//      AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
+//          AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
+//            op_integral(iter);
+//          })
+//          AT_DISPATCH_CASE_FLOATING_TYPES([&] {
+//            op_floating(iter);
+//          })
+//          AT_DISPATCH_CASE(kBool, [&] {
+//            op_bool(iter);
+//          })
+//      );
+//
+// For each AT_DISPATCH_FOO macro, there is a corresponding
+// AT_DISPATCH_CASE_FOO macro which can be used inside of an
+// AT_DISPATCH_SWITCH block.
+
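Editor's note: the usage snippets in the comment above are schematic; below is a hedged, self-contained sketch of a small CPU helper dispatched over the floating-point dtypes (the function name and in-place scaling are illustrative, not an ATen operator, and the tensor is assumed contiguous):

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Scales a contiguous floating-point CPU tensor in place; the lambda is
    // instantiated once per dtype with scalar_t bound to float / double.
    void scale_(at::Tensor& self, double factor) {
      AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "scale_", [&] {
        scalar_t* data = self.data_ptr<scalar_t>();
        const auto s = static_cast<scalar_t>(factor);
        for (int64_t i = 0; i < self.numel(); ++i) {
          data[i] *= s;
        }
      });
    }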
+// NB: the the_type variable is not used, but we have kept it for
+// backwards compatibility.  It's probably not used by anyone though;
+// but we're just being safe (and it doesn't hurt.)  Note we must
+// use it to shut up warnings about unused store.
+
+#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
+  [&] {                                                                     \
+    const auto& the_type = TYPE;                                            \
+    constexpr const char* at_dispatch_name = NAME;                          \
+    /* don't use TYPE again in case it is an expensive or side-effect op */ \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
+    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
+    switch (_st) {                                                          \
+      __VA_ARGS__                                                           \
+      default:                                                              \
+        AT_ERROR(                                                           \
+            '"',                                                            \
+            at_dispatch_name,                                               \
+            "\" not implemented for '",                                     \
+            toString(_st),                                                  \
+            "'");                                                           \
+    }                                                                       \
+  }()
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
+  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...)   \
+  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                        \
+      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                               \
+      TYPE,                                                         \
+      NAME,                                                         \
+      AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                                \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND2(       \
+    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                          \
+      TYPE,                                    \
+      NAME,                                    \
+      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(    \
+          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3(   \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)  \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND3(                    \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE,                                                 \
+      NAME,                                                 \
+      AT_DISPATCH_CASE_FLOATING_TYPES_AND3(                 \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND4(                                 \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                    \
+      TYPE,                                                              \
+      NAME,                                                              \
+      AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                              \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
+  AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
+
+#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
+  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                           \
+      TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(    \
+    SCALARTYPE, TYPE, NAME, ...)                        \
+  AT_DISPATCH_SWITCH(                                   \
+      TYPE,                                             \
+      NAME,                                             \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
+          SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(  \
+    SCALARTYPE1, SCALARTYPE2, ...)                         \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(    \
+    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)          \
+  AT_DISPATCH_SWITCH(                                   \
+      TYPE,                                             \
+      NAME,                                             \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
+          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(  \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)            \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(        \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE,                                                 \
+      NAME,                                                 \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(     \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(    \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)   \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(                     \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                    \
+      TYPE,                                                              \
+      NAME,                                                              \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4(                  \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5(                 \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)                \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5(    \
+    SCALARTYPE1,                                        \
+    SCALARTYPE2,                                        \
+    SCALARTYPE3,                                        \
+    SCALARTYPE4,                                        \
+    SCALARTYPE5,                                        \
+    TYPE,                                               \
+    NAME,                                               \
+    ...)                                                \
+  AT_DISPATCH_SWITCH(                                   \
+      TYPE,                                             \
+      NAME,                                             \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
+          SCALARTYPE1,                                  \
+          SCALARTYPE2,                                  \
+          SCALARTYPE3,                                  \
+          SCALARTYPE4,                                  \
+          SCALARTYPE5,                                  \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6(  \
+    SCALARTYPE1,                                           \
+    SCALARTYPE2,                                           \
+    SCALARTYPE3,                                           \
+    SCALARTYPE4,                                           \
+    SCALARTYPE5,                                           \
+    SCALARTYPE6,                                           \
+    ...)                                                   \
+  AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6(    \
+    SCALARTYPE1,                                        \
+    SCALARTYPE2,                                        \
+    SCALARTYPE3,                                        \
+    SCALARTYPE4,                                        \
+    SCALARTYPE5,                                        \
+    SCALARTYPE6,                                        \
+    TYPE,                                               \
+    NAME,                                               \
+    ...)                                                \
+  AT_DISPATCH_SWITCH(                                   \
+      TYPE,                                             \
+      NAME,                                             \
+      AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
+          SCALARTYPE1,                                  \
+          SCALARTYPE2,                                  \
+          SCALARTYPE3,                                  \
+          SCALARTYPE4,                                  \
+          SCALARTYPE5,                                  \
+          SCALARTYPE6,                                  \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
+
+#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                               \
+      TYPE,                                                         \
+      NAME,                                                         \
+      AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES(...)        \
+  AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_QINT_TYPES(...)                      \
+  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__)   \
+  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
+  AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
+
+#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                           \
+      TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)               \
+  AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
+  AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
+
+#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                     \
+  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
+      at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
+  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
+      at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)       \
+  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
+      at::kQInt32,                                                        \
+      at::qint32,                                                         \
+      CHAR_BIT * sizeof(int),                                             \
+      INT_MIN,                                                            \
+      INT_MAX,                                                            \
+      __VA_ARGS__)                                                        \
+  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
+      at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                 \
+  AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                     \
+      at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
+
+#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                        \
+      TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
+  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
+  AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                      \
+      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                          \
+      TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                      \
+      TYPE,                                                                \
+      NAME,                                                                \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                           \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                         \
+      TYPE,                                                                   \
+      NAME,                                                                   \
+      AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(  \
+    SCALARTYPE1, SCALARTYPE2, ...)                    \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(    \
+    SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)     \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
+          SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND3(        \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)       \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND3(                         \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE,                                                 \
+      NAME,                                                 \
+      AT_DISPATCH_CASE_ALL_TYPES_AND3(                      \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(  \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...)       \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE,                                                 \
+      NAME,                                                 \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(          \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(         \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)        \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                          \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                    \
+      TYPE,                                                              \
+      NAME,                                                              \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                       \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                      \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                     \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(  \
+    SCALARTYPE1,                                      \
+    SCALARTYPE2,                                      \
+    SCALARTYPE3,                                      \
+    SCALARTYPE4,                                      \
+    SCALARTYPE5,                                      \
+    SCALARTYPE6,                                      \
+    ...)                                              \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    SCALARTYPE6,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          SCALARTYPE6,                             \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7(  \
+    SCALARTYPE1,                                      \
+    SCALARTYPE2,                                      \
+    SCALARTYPE3,                                      \
+    SCALARTYPE4,                                      \
+    SCALARTYPE5,                                      \
+    SCALARTYPE6,                                      \
+    SCALARTYPE7,                                      \
+    ...)                                              \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    SCALARTYPE6,                                   \
+    SCALARTYPE7,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          SCALARTYPE6,                             \
+          SCALARTYPE7,                             \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8(  \
+    SCALARTYPE1,                                      \
+    SCALARTYPE2,                                      \
+    SCALARTYPE3,                                      \
+    SCALARTYPE4,                                      \
+    SCALARTYPE5,                                      \
+    SCALARTYPE6,                                      \
+    SCALARTYPE7,                                      \
+    SCALARTYPE8,                                      \
+    ...)                                              \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    SCALARTYPE6,                                   \
+    SCALARTYPE7,                                   \
+    SCALARTYPE8,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          SCALARTYPE6,                             \
+          SCALARTYPE7,                             \
+          SCALARTYPE8,                             \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_BIT_TYPES(...)                  \
+  AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
+
+#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
+
+#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)     \
+  AT_DISPATCH_SWITCH(                                \
+      TYPE,                                          \
+      NAME,                                          \
+      AT_PRIVATE_CASE_TYPE_USING_HINT(               \
+          at::ScalarType::Int, index_t, __VA_ARGS__) \
+          AT_PRIVATE_CASE_TYPE_USING_HINT(           \
+              at::ScalarType::Long, index_t, __VA_ARGS__))
+
+// ----------------------------------------------------------------------------
+// DEPRECATED MACROS, DON'T USE THESE
+// ----------------------------------------------------------------------------
+
+#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
+  detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF();  \
+  AT_DISPATCH_SWITCH(                                   \
+      TYPE,                                             \
+      NAME,                                             \
+      AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
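For orientation, here is a minimal sketch of how the dispatch macros above are typically consumed by a CPU kernel. It is an illustrative example only, not part of the vendored header: the function name my_clamp_min and its debug string are hypothetical.

#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

// Hypothetical kernel: clamp every element of `self` to at least `bound`.
void my_clamp_min(at::Tensor& out, const at::Tensor& self, double bound) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "my_clamp_min",
      [&] {
        // Inside the lambda, scalar_t is bound to the concrete C++ type
        // (float, double, at::Half or at::BFloat16) selected at runtime.
        const scalar_t* src = self.data_ptr<scalar_t>();
        scalar_t* dst = out.data_ptr<scalar_t>();
        const scalar_t lo = static_cast<scalar_t>(bound);
        for (int64_t i = 0; i < self.numel(); ++i) {
          dst[i] = src[i] < lo ? lo : src[i];
        }
      });
}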
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Dispatch_v2.h b/MLPY/Lib/site-packages/torch/include/ATen/Dispatch_v2.h
new file mode 100644
index 0000000000000000000000000000000000000000..8960e2f4d96eb68632ecae60fdc94515963e984a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Dispatch_v2.h
@@ -0,0 +1,186 @@
+#include <ATen/Dispatch.h>
+
+// This is a new implementation of the AT_DISPATCH macro family from
+// ATen/Dispatch.h
+//
+// The intended usage is:
+//
+//  ScalarType scalar_type;
+//
+//  AT_DISPATCH_V2(
+//    scalar_type,
+//    "debug string",
+//    AT_WRAP([&] {
+//      ... code to specialize with scalar_t ...
+//    }),
+//    kHalf,
+//    AT_EXPAND(AT_ALL_TYPES),
+//    ... as many types arguments as needed ...
+//  )
+//
+// For example, given an old style:
+//
+//  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+//    kComplexHalf,
+//    kHalf,
+//    self.scalar_type(),
+//    "_local_scalar_dense_cpu",
+//    [&] {
+//      scalar_t value = *self.data_ptr<scalar_t>();
+//      r = Scalar(value);
+//    }
+//  )
+//
+// You now write:
+//
+//  AT_DISPATCH_V2(
+//    self.scalar_type(),
+//    "_local_scalar_dense_cpu",
+//    AT_WRAP([&] {
+//      scalar_t value = *self.data_ptr<scalar_t>();
+//      r = Scalar(value);
+//    }),
+//    AT_EXPAND(AT_ALL_TYPES),
+//    AT_EXPAND(AT_COMPLEX_TYPES),
+//    kComplexHalf,
+//    kHalf,
+//  )
+//
+// Notably, it sports the following improvements:
+//
+//  - It is not necessary to specify the arity (e.g.,
+//    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
+//    when using the macro
+//
+//  - It is not necessary to specify each dtype individually; if
+//    there is a set of related dtypes and you want to dispatch
+//    over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
+//    in your argument list.
+//
+// However, you must remember to wrap the payload body in AT_WRAP, or commas
+// inside your lambda will be improperly handled.  Furthermore, if you add more
+// entries to ScalarType than can be supported by this macro, it will fail
+// with an obscure error (due to attempting to concatenate AT_AP with
+// something that is not a number).
+//
+// The implementation strategy is to use the count arguments trick
+// (e.g., as described in https://stackoverflow.com/a/2124385/23845)
+// to discover how many dtypes have been passed, and then dispatch to a
+// hand-written macro for each arity that applies as many DISPATCH_CASE as
+// necessary.  The hand-written macros can be regenerated for other arities
+// with the script below.
+//
+// There is some delicacy in the implementation in controlling when
+// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD.  I mostly
+// relied on GPT4 to help me get it right.
+
+// Public API macros
+
+// See documentation above
+#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
+
+// This macro lets you pass an arbitrary expression that may contain internal
+// commas to another macro without having the commas causing the expression
+// to be interpreted as being multiple arguments
+#define AT_WRAP(...) __VA_ARGS__
+
+#define AT_FLOAT8_TYPES                                          \
+  c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
+      c10::kFloat8_e4m3fnuz
+
+#define AT_INTEGRAL_TYPES \
+  c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
+#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
+#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
+#define AT_INTEGRAL_TYPES_V2 \
+  AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
+#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
+#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
+// NB: not *actually* all types
+#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
+#define AT_ALL_TYPES_AND_COMPLEX \
+  AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
+
+// Helper macros
+
+#define AT_AP_VAR(N, T, ...) \
+  AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
+#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
+#define AT_CONCAT_AUX(a, b) a##b
+#define AT_EXPAND(X) X
+
+// Ensure we never have too many scalar types for the expansion here to
+// support.  To bump this, you must regenerate the macros below.
+static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 45);
+
+// Python code to regenerate the generated code below:
+#if 0
+
+num_args = 45
+
+nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
+args = ', '.join(f'_{i}' for i in range(1, num_args+1))
+
+print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
+print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
+
+for i in range(1, num_args+1):
+    args = ', '.join(f'_{i}' for i in range(1, i+1))
+    cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
+    print(f'#define AT_AP{i}(N, {args}) {cases}')
+
+#endif
+
+// Begin generated code
+// clang-format off
+
+#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
+#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N
+#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
+#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
+#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
+#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
+#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
+#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
+#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
+#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
+#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
+#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
+#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
+#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
+#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
+#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
+#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
+#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
+#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
+#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
+#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
+#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
+#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
+#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
+#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
+#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
+#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
+#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
+#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
+#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
+#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
+#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
+#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
+#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
+#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
+#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
+#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
+#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
+#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
+#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
+#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
+#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
+#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
+#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
+#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
+#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
+#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
+// End generated code
+// clang-format on
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/DynamicLibrary.h b/MLPY/Lib/site-packages/torch/include/ATen/DynamicLibrary.h
new file mode 100644
index 0000000000000000000000000000000000000000..5e8a2b6d4c10efe0804210e7ae7fc45549a9d166
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/DynamicLibrary.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+class DynamicLibraryError : public Error {
+  using Error::Error;
+};
+
+} // namespace c10
+
+namespace at {
+
+struct DynamicLibrary {
+  AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
+
+  TORCH_API DynamicLibrary(
+      const char* name,
+      const char* alt_name = nullptr,
+      bool leak_handle = false);
+
+  TORCH_API void* sym(const char* name);
+
+  TORCH_API ~DynamicLibrary();
+
+ private:
+  bool leak_handle;
+  void* handle = nullptr;
+};
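+
+// Illustrative usage sketch (the library and symbol names are made up; sym()
+// returns a raw pointer that the caller casts to the expected type):
+//
+//   at::DynamicLibrary lib("libfoo.so");
+//   using add_one_fn = int (*)(int);
+//   auto add_one = reinterpret_cast<add_one_fn>(lib.sym("foo_add_one"));
+//   int y = add_one(41);  // handle is released when `lib` is destroyed,
+//                         // unless leak_handle was set to true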
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/EmptyTensor.h b/MLPY/Lib/site-packages/torch/include/ATen/EmptyTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a380a34b965e347a15c6cfc4a6d25aa9a62e773
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/EmptyTensor.h
@@ -0,0 +1,160 @@
+#pragma once
+#include 
+
+namespace at::detail {
+
+inline void check_size_nonnegative(ArrayRef<int64_t> size) {
+  for (const auto& x : size) {
+    TORCH_CHECK(
+        x >= 0,
+        "Trying to create tensor with negative dimension ",
+        x,
+        ": ",
+        size);
+  }
+}
+
+inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
+  for (const auto& x : size) {
+    TORCH_CHECK(
+        x.expect_size(__FILE__, __LINE__),
+        "Trying to create tensor with negative dimension ",
+        x,
+        ": ",
+        size);
+  }
+}
+
+TORCH_API size_t computeStorageNbytesContiguous(
+    IntArrayRef sizes,
+    size_t itemsize,
+    size_t storage_offset = 0);
+TORCH_API SymInt computeStorageNbytesContiguous(
+    SymIntArrayRef sizes,
+    const SymInt& itemsize,
+    const SymInt& storage_offset = 0);
+TORCH_API size_t computeStorageNbytes(
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    size_t itemsize,
+    size_t storage_offset = 0);
+TORCH_API SymInt computeStorageNbytes(
+    SymIntArrayRef sizes,
+    SymIntArrayRef strides,
+    const SymInt& itemsize,
+    const SymInt& storage_offset = 0);
+
+TORCH_API TensorBase empty_generic(
+    IntArrayRef size,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type,
+    c10::optional memory_format_opt);
+
+TORCH_API TensorBase empty_strided_generic(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type);
+
+TORCH_API TensorBase empty_strided_symint_generic(
+    SymIntArrayRef size,
+    SymIntArrayRef stride,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type);
+
+TORCH_API TensorBase empty_cpu(
+    IntArrayRef size,
+    ScalarType dtype,
+    bool pin_memory = false,
+    c10::optional memory_format_opt = c10::nullopt);
+
+TORCH_API TensorBase empty_cpu(
+    IntArrayRef size,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt,
+    c10::optional memory_format_opt);
+
+TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
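+
+// Illustrative call (internal factory, not the public at::empty API; the
+// shape and dtype below are just examples):
+//
+//   at::TensorBase t = at::detail::empty_cpu({2, 3}, at::kFloat);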
+
+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    bool pin_memory = false);
+
+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt);
+
+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions& options);
+
+TORCH_API TensorBase empty_meta(
+    IntArrayRef size,
+    ScalarType dtype,
+    c10::optional memory_format_opt = c10::nullopt);
+
+TORCH_API TensorBase empty_meta(
+    IntArrayRef size,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt,
+    c10::optional memory_format_opt);
+
+TORCH_API TensorBase empty_symint_meta(
+    SymIntArrayRef size,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt,
+    c10::optional memory_format_opt);
+
+TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
+
+TORCH_API TensorBase
+empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
+
+TORCH_API TensorBase empty_strided_meta(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt);
+
+TORCH_API TensorBase empty_strided_meta(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions& options);
+
+TORCH_API TensorBase empty_strided_symint_meta(
+    SymIntArrayRef size,
+    SymIntArrayRef stride,
+    ScalarType dtype);
+
+TORCH_API TensorBase empty_strided_symint_meta(
+    SymIntArrayRef size,
+    SymIntArrayRef stride,
+    c10::optional dtype_opt,
+    c10::optional layout_opt,
+    c10::optional device_opt,
+    c10::optional pin_memory_opt);
+
+TORCH_API TensorBase empty_strided_symint_meta(
+    SymIntArrayRef size,
+    SymIntArrayRef stride,
+    const TensorOptions& options);
+
+} // namespace at::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ExpandBase.h b/MLPY/Lib/site-packages/torch/include/ATen/ExpandBase.h
new file mode 100644
index 0000000000000000000000000000000000000000..d59a2714455873cf776242bd04157130911c8b28
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ExpandBase.h
@@ -0,0 +1,30 @@
+#include 
+
+// Broadcasting utilities for working with TensorBase
+namespace at {
+namespace internal {
+TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
+} // namespace internal
+
+inline c10::MaybeOwned expand_size(
+    const TensorBase& self,
+    IntArrayRef size) {
+  if (size.equals(self.sizes())) {
+    return c10::MaybeOwned::borrowed(self);
+  }
+  return c10::MaybeOwned::owned(
+      at::internal::expand_slow_path(self, size));
+}
+c10::MaybeOwned expand_size(TensorBase&& self, IntArrayRef size) =
+    delete;
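+
+// Illustrative use of expand_size (the tensor `t` and target sizes are
+// hypothetical): when `t` already has sizes {2, 3} the result borrows it,
+// otherwise it owns the expanded view. The deleted rvalue overload above
+// keeps callers from passing a temporary whose expansion would dangle.
+//
+//   c10::MaybeOwned<TensorBase> e = at::expand_size(t, {2, 3});
+//   const TensorBase& v = *e;  // valid only while `t` is alive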
+
+inline c10::MaybeOwned expand_inplace(
+    const TensorBase& tensor,
+    const TensorBase& to_expand) {
+  return expand_size(to_expand, tensor.sizes());
+}
+c10::MaybeOwned expand_inplace(
+    const TensorBase& tensor,
+    TensorBase&& to_expand) = delete;
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ExpandUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/ExpandUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..04710d2796a14d16ead11b0b4d3fae7925f9ab2d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ExpandUtils.h
@@ -0,0 +1,527 @@
+#pragma once
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b);
+TORCH_API std::vector infer_size_symint(
+    SymIntArrayRef a,
+    SymIntArrayRef b);
+TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
+TORCH_API SymDimVector
+infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
+
+// Named type instead of a pair/tuple so that we can be sure to
+// construct the vectors in place and get NRVO.
+template 
+struct InferExpandGeometryResult {
+  Container sizes;
+  Container strides;
+  explicit InferExpandGeometryResult(size_t ndim)
+      : sizes(ndim), strides(ndim) {}
+  explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
+      : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
+};
+
+TORCH_API std::tuple, std::vector>
+inferExpandGeometry(
+    IntArrayRef tensor_sizes,
+    IntArrayRef tensor_strides,
+    IntArrayRef sizes);
+
+TORCH_API InferExpandGeometryResult inferExpandGeometry_dimvector(
+    IntArrayRef tensor_sizes,
+    IntArrayRef tensor_strides,
+    IntArrayRef sizes);
+
+TORCH_API std::vector infer_dense_strides(
+    IntArrayRef tensor_sizes,
+    IntArrayRef tensor_strides);
+
+// True if input shapes are expandable
+// NOTE: infer_size does a similar check; please keep them in sync if a change
+// is needed
+inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
+  size_t ndim1 = shape1.size();
+  size_t ndim2 = shape2.size();
+  size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
+
+  for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
+    if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
+        shape2[ndim2] == 1) {
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
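+
+// For example (sizes compared from the trailing dimension, as in broadcasting):
+//   are_expandable({3, 1, 4}, {2, 4})  -> true,  the shapes broadcast to {3, 2, 4}
+//   are_expandable({3, 2, 4}, {3, 4})  -> false, 2 vs 3 and neither is 1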
+
+// avoid copy-construction of Tensor by using a reference_wrapper.
+inline void check_defined(
+    std::initializer_list> tensors,
+    const char* api_name) {
+  for (auto& t : tensors) {
+    if (!t.get().defined()) {
+      AT_ERROR(api_name, "(...) called with an undefined Tensor");
+    }
+  }
+}
+
+// NOTE [ ExpandUtils Borrowing ]
+//
+// Functions in ExpandUtils return `c10::MaybeOwned` because
+// expansion may not actually be needed, in which case we can improve
+// efficiency by returning
+// `c10::MaybeOwned::borrowed(to_expand)`. However, this means
+// that you need to be careful: the returned `c10::MaybeOwned`
+// must not outlive the original `Tensor` object that `to_expand`
+// referred to! The deleted rvalue reference overloads of these
+// functions help with this by preventing trivial use of a temporary
+// resulting from a function call, but it is still possible to make a
+// mistake.
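+//
+// A minimal sketch of the intended pattern (tensor names are hypothetical):
+//
+//   at::Tensor a = ..., b = ...;            // both must stay alive below
+//   auto [ea, eb] = at::expand_outplace(a, b, "add");
+//   at::Tensor out = ea->add(*eb);          // MaybeOwned dereferences like a pointer
+//
+// Passing a temporary (e.g. the result of `b.clone()` not bound to a name)
+// is rejected at compile time by the deleted rvalue-reference overloads.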
+
+inline c10::MaybeOwned expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand) {
+  if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
+    return c10::MaybeOwned::borrowed(to_expand);
+  }
+  return c10::MaybeOwned::owned(
+      to_expand.expand_symint(tensor.sym_sizes()));
+}
+
+inline c10::MaybeOwned expand_inplace(
+    const Tensor& tensor,
+    Tensor&& to_expand) = delete;
+
+inline c10::MaybeOwned expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand,
+    const char* api_name) {
+  check_defined({tensor, to_expand}, api_name);
+  return expand_inplace(tensor, to_expand);
+}
+
+inline c10::MaybeOwned expand_inplace(
+    const Tensor& tensor,
+    Tensor&& to_expand,
+    const char* api_name) = delete;
+
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand1,
+    const Tensor& to_expand2) {
+  if (tensor.sizes().equals(to_expand1.sizes()) &&
+      tensor.sizes().equals((to_expand2.sizes()))) {
+    return std::make_tuple(
+        c10::MaybeOwned::borrowed(to_expand1),
+        c10::MaybeOwned::borrowed(to_expand2));
+  }
+
+  return std::make_tuple(
+      c10::MaybeOwned::owned(to_expand1.expand(tensor.sizes())),
+      c10::MaybeOwned::owned(to_expand2.expand(tensor.sizes())));
+}
+
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    Tensor&& to_expand1,
+    const Tensor& to_expand2) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand1,
+    Tensor&& to_expand2) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
+    delete;
+
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    const char* api_name) {
+  check_defined({tensor, to_expand1, to_expand2}, api_name);
+  return expand_inplace(tensor, to_expand1, to_expand2);
+}
+
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    const char* api_name) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    const char* api_name) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_inplace(
+    const Tensor& tensor,
+    Tensor&& to_expand1,
+    Tensor&& to_expand2,
+    const char* api_name) = delete;
+
+// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
+  auto s1 = to_expand1.sym_sizes();
+  auto s2 = to_expand2.sym_sizes();
+  if (s1.equals(s2)) {
+    return std::make_tuple(
+        c10::MaybeOwned::borrowed(to_expand1),
+        c10::MaybeOwned::borrowed(to_expand2));
+  }
+
+  auto expanded_size = infer_size_symdimvector(s1, s2);
+  return std::make_tuple(
+      c10::MaybeOwned::owned(to_expand1.expand_symint(expanded_size)),
+      c10::MaybeOwned::owned(to_expand2.expand_symint(expanded_size)));
+}
+
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
+
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    const char* api_name) {
+  check_defined({to_expand1, to_expand2}, api_name);
+  return expand_outplace(to_expand1, to_expand2);
+}
+
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    const char* api_name) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    const char* api_name) = delete;
+inline std::tuple, c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    Tensor&& to_expand2,
+    const char* api_name) = delete;
+
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    const Tensor& to_expand3) {
+  if (to_expand1.sizes().equals(to_expand2.sizes()) &&
+      to_expand1.sizes().equals(to_expand3.sizes())) {
+    return std::make_tuple(
+        c10::MaybeOwned::borrowed(to_expand1),
+        c10::MaybeOwned::borrowed(to_expand2),
+        c10::MaybeOwned::borrowed(to_expand3));
+  }
+
+  auto expanded_size12 =
+      infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
+  auto expanded_size =
+      infer_size_dimvector(expanded_size12, to_expand3.sizes());
+  return std::make_tuple(
+      c10::MaybeOwned::owned(to_expand1.expand(expanded_size)),
+      c10::MaybeOwned::owned(to_expand2.expand(expanded_size)),
+      c10::MaybeOwned::owned(to_expand3.expand(expanded_size)));
+}
+
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    const Tensor& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    const Tensor& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    Tensor&& to_expand2,
+    const Tensor& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    Tensor&& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    Tensor&& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    Tensor&& to_expand3) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
+    delete;
+
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    const Tensor& to_expand3,
+    const char* api_name) {
+  check_defined({to_expand1, to_expand2, to_expand3}, api_name);
+  return expand_outplace(to_expand1, to_expand2, to_expand3);
+}
+
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    const Tensor& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    const Tensor& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    Tensor&& to_expand2,
+    const Tensor& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    const Tensor& to_expand2,
+    Tensor&& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    const Tensor& to_expand2,
+    Tensor&& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    const Tensor& to_expand1,
+    Tensor&& to_expand2,
+    Tensor&& to_expand3,
+    const char* api_name) = delete;
+inline std::tuple<
+    c10::MaybeOwned,
+    c10::MaybeOwned,
+    c10::MaybeOwned>
+expand_outplace(
+    Tensor&& to_expand1,
+    Tensor&& to_expand2,
+    Tensor&& to_expand3,
+    const char* api_name) = delete;
+
+inline c10::MaybeOwned expand_size(
+    const Tensor& to_expand,
+    IntArrayRef sizes) {
+  if (to_expand.sizes().equals(sizes)) {
+    return c10::MaybeOwned::borrowed(to_expand);
+  }
+
+  return c10::MaybeOwned::owned(to_expand.expand(sizes));
+}
+
+inline c10::MaybeOwned expand_size(
+    Tensor&& to_expand,
+    IntArrayRef sizes) = delete;
+
+inline c10::MaybeOwned expand_size(
+    const Tensor& to_expand,
+    IntArrayRef sizes,
+    const char* api_name) {
+  check_defined({to_expand}, api_name);
+  return expand_size(to_expand, sizes);
+}
+
+inline c10::MaybeOwned expand_size(
+    Tensor&& to_expand,
+    IntArrayRef sizes,
+    const char* api_name) = delete;
+
+inline std::vector expand_outplace(TensorList to_expand) {
+  // expands a list of Tensors; ignores undefined (null) tensors
+  bool first = true;
+  DimVector sizes;
+  for (const auto i : c10::irange(to_expand.size())) {
+    if (!to_expand[i].defined()) {
+      continue;
+    } else if (first) {
+      sizes = to_expand[i].sizes();
+      first = false;
+    } else {
+      sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
+    }
+  }
+
+  std::vector result(to_expand.size());
+  for (const auto i : c10::irange(to_expand.size())) {
+    if (!to_expand[i].defined()) {
+      continue;
+    } else if (to_expand[i].sizes().equals(sizes)) {
+      result[i] = to_expand[i];
+    } else {
+      result[i] = to_expand[i].expand(sizes);
+    }
+  }
+  return result;
+}
+
+template 
+inline Tensor _sum_to(
+    Tensor tensor,
+    const c10::ArrayRef shape,
+    bool always_return_non_view = false) {
+  if (shape.size() == 0) {
+    return tensor.sum();
+  }
+
+  auto sizes = at::symint::sizes(tensor);
+  c10::SmallVector reduce_dims;
+  const int64_t leading_dims = sizes.size() - shape.size();
+  for (const auto i : c10::irange(leading_dims)) {
+    reduce_dims.push_back(i);
+  }
+  for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
+    if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
+      reduce_dims.push_back(i);
+    }
+  }
+
+  if (!reduce_dims.empty()) {
+    tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
+  }
+
+  if (always_return_non_view) {
+    // This is only actually used by the functionalization pass.
+    // We want to be able to guarantee that this function doesn't return a view
+    // of the input.
+    return leading_dims > 0 ? at::symint::view_copy(tensor, shape)
+                            : tensor.clone();
+  } else {
+    return leading_dims > 0 ? at::symint::view(tensor, shape) : tensor;
+  }
+}
+
+inline Tensor sum_to(
+    Tensor tensor,
+    const c10::SymIntArrayRef shape,
+    bool always_return_non_view = false) {
+  return _sum_to(std::move(tensor), shape, always_return_non_view);
+}
+
+// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
+// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
+inline Tensor sum_to(
+    Tensor tensor,
+    const IntArrayRef shape,
+    bool always_return_non_view = false) {
+  return _sum_to(std::move(tensor), shape, always_return_non_view);
+}
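+
+// Worked example: summing a tensor of shape {2, 3, 4} down to {3, 1} reduces
+// over dims {0, 2} with keepdim=true (giving {1, 3, 1}) and then views the
+// result as {3, 1}.
+//
+//   at::Tensor big = at::rand({2, 3, 4});
+//   at::Tensor small = at::sum_to(big, at::IntArrayRef{3, 1});  // sizes {3, 1}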
+
+static inline bool is_expandable_to(
+    SymIntArrayRef shape,
+    c10::SymIntArrayRef desired) {
+  size_t ndim = shape.size();
+  size_t target_dim = desired.size();
+  if (ndim > target_dim) {
+    return false;
+  }
+  for (const auto i : c10::irange(ndim)) {
+    const auto& size = shape[ndim - i - 1];
+    const auto& target = desired[target_dim - i - 1];
+    if (size != target && size != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
+  auto sym_shape = c10::SymIntArrayRef(
+      reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
+  auto sym_desired = c10::SymIntArrayRef(
+      reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
+  return is_expandable_to(sym_shape, sym_desired);
+}
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Formatting.h b/MLPY/Lib/site-packages/torch/include/ATen/Formatting.h
new file mode 100644
index 0000000000000000000000000000000000000000..e23b27ffd373180a1857a5491694eff11705f9a1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Formatting.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h b/MLPY/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h
new file mode 100644
index 0000000000000000000000000000000000000000..6430caadfa947f57f76f1d7e218b4b4d60140f8b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::functorch {
+
+// NOTE [functorch TLS in pytorch/pytorch]
+//
+// functorch lives out-of-tree. However, it has some TLS that needs to be
+// propagated. The solution for that is we store a pointer to the TLS
+// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
+// include whatever functorch needs.
+//
+// We need to store a pointer due to the indirection:
+// inside functorch, we will create a subclass of FunctorchTLSBase called
+// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
+// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
+// yet.
+//
+// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
+// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
+// We can't directly pass around FunctorchTLSBase (without a pointer) because
+// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
+// more elements.
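+//
+// A hypothetical sketch of such a subclass on the functorch side (the name
+// FuncTorchTLSImpl and its member are illustrative, not the real
+// implementation):
+//
+//   struct FuncTorchTLSImpl : public FuncTorchTLSBase {
+//     std::unique_ptr<FuncTorchTLSBase> deepcopy() const override {
+//       return std::make_unique<FuncTorchTLSImpl>(*this);
+//     }
+//     int64_t checkSupportsSingleLevelAutogradFunction() const override { return 0; }
+//     void checkSupportsCppAutogradFunction() const override {}
+//     void checkSupportsInplaceRequiresGrad() const override {}
+//     void checkSupportsRetainGrad() const override {}
+//     std::vector<int64_t> dynamic_layer_stack;  // illustrative metadata
+//   };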
+struct TORCH_API FuncTorchTLSBase {
+  virtual ~FuncTorchTLSBase() = default;
+  virtual std::unique_ptr deepcopy() const = 0;
+
+  virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
+  virtual void checkSupportsCppAutogradFunction() const = 0;
+  virtual void checkSupportsInplaceRequiresGrad() const = 0;
+  virtual void checkSupportsRetainGrad() const = 0;
+};
+
+// returns deepcopy of the functorch tls
+TORCH_API std::unique_ptr getCopyOfFuncTorchTLS();
+
+// sets the functorch tls. always does a deep copy.
+TORCH_API void setFuncTorchTLS(
+    const std::shared_ptr& state);
+
+// get a mutable reference to the functorch tls
+TORCH_API std::unique_ptr& functorchTLSAccessor();
+
+} // namespace at::functorch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..2753121e1da2ae541b5afe748a3c05690f01676e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h
@@ -0,0 +1,126 @@
+#pragma once
+
+#include 
+
+namespace at::functionalization {
+
+// See Note [Functionalization Pass In Core]
+
+// ViewMeta is a class used by the functionalization pass to navigate between
+// a base tensor and a view tensor.
+// For example, if I call `b = a.view1(...)`
+// the functionalization pass will generate and store a ViewMeta on b that looks
+// like:
+//
+// ViewMeta(
+//   [](const Tensor& base, int64_t mutated_view_idx) {
+//     return base.view1(...);
+//   },
+//   [](const at::Tensor& base, const at::Tensor& mutated_view,
+//   int64_t mutated_view_idx) -> at::Tensor {
+//     return at::functionalization::impl::view1_inverse(base, mutated_view,
+//     ...);
+//   }
+//
+// The forward_fn lambda describes how to replay view1 on a tensor.
+//
+// The reverse_fn lambda describes how, given a tensor that is already a view,
+// how to get the corresponding base tensor. See Note [Functionalization Pass:
+// View Inverses] for details.
+struct ViewMeta {
+  ViewMeta(
+      std::function forward,
+      std::function reverse,
+      bool is_multi_output = false,
+      int64_t out_idx = 0)
+      : forward_fn(std::move(forward)),
+        reverse_fn(std::move(reverse)),
+        out_index(out_idx),
+        is_multi_output(is_multi_output) {}
+
+  std::function forward_fn;
+  std::function reverse_fn;
+  // See Note [out_idx in ViewMeta]
+  int64_t out_index;
+
+  // Tells us if this is a multi-output view
+  bool is_multi_output;
+
+  // Returns a copy of the current ViewMeta, if out_idx matches the current
+  // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
+  // functions, but a new out index.
+  ViewMeta to_out_idx(int64_t out_idx);
+};
+
+// FunctionalStorageImpl is a subclass of StorageImpl used by the
+// functionalization pass. It has no underlying data (similar to meta storage).
+// It also knows how to reflect mutations to tensors in the absence of a valid
+// data pointer.
+//
+// A storage represents the state shared by (potentially multiple) views of the
+// same tensor. For example, in the following code:
+//
+// b = a.view1(...)
+// c = b.view2(...)
+// b.add_(1)
+// --> storage.add_update(b, {view1_meta})
+//
+// The call to add_(1) will result in a call to alias.add_update(b,
+// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
+// c is used in an expression (e.g. you try to print c, or pass it to an
+// operator). Doing so will involve "syncing" c. First we apply any pending
+// updates to the alias, and then we regenerate c by replaying its views off of
+// the updated alias. E.g:
+//
+// print(str(c))
+// --> c.sync_()
+//     --> alias.apply_updates() // after this, the alias will be updated to
+//     reflect the mutation to b
+struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
+ public:
+  struct Update {
+    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+    const at::Tensor new_val;
+    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+    const std::vector view_metas;
+  };
+
+  explicit FunctionalStorageImpl(const Tensor& value);
+
+  void add_update(
+      const Tensor& updated_val,
+      const std::vector& view_metas);
+  bool apply_updates();
+  const Tensor& base() {
+    return base_;
+  }
+  size_t generation() const {
+    return generation_;
+  }
+  void freeze() {
+    frozen_ = true;
+  }
+
+  ~FunctionalStorageImpl() override = default;
+
+ private:
+  // NB: base_ should always point to a tensor BELOW the current
+  // functionalization layer. This is mainly to avoid reference cycles. e.g.
+  // given `b = a.view(...)` Both a.storage_ and b.storage_ are a
+  // FunctionalStorageImpl containing an Alias, which contains a Tensor
+  // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
+  // should point not to a, but to a's unwrapped value, a.value_. See Note
+  // [Functionalization: Alias Removal] for a diagram that shows this
+  // visually.
+  at::Tensor base_;
+  std::vector updates_;
+  // generation_ gets incremented every time a mutation is queued onto the
+  // alias. It is used to determine if a given tensor is "up to date", or if it
+  // needs to be regenerated from the alias.
+  size_t generation_ = 0;
+  // If frozen, no more mutations are allowed on this storage.  Once frozen, a
+  // storage cannot be unfrozen.
+  bool frozen_ = false;
+};
+
+} // namespace at::functionalization
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h b/MLPY/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef40ae5e931bc8eda6f00c16cac9b67eed7276ba
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
@@ -0,0 +1,408 @@
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+
+// Note [Functionalization Pass In Core]
+// The Functionalization pass is used to remove aliasing from a pytorch program.
+//
+// This is useful for backends that don't support aliasing, like XLA and Vulkan.
+// It's also necessary in order to remove mutation from a program, which is
+// needed in Functorch.
+//
+// Consider this program:
+// a = torch.ones(...)
+// b = a.view(...)
+// b.add_(1)
+//
+// In this program, b is meant to alias with a due to the use of view(). At the
+// end of the program, both a and b are full of 2's. However, backends that
+// don't support aliasing aren't able to correctly implement the view()
+// operator. Instead, they can opt into the Functionalization pass, which will
+// sit between the user and the backend, and provide the necessary aliasing
+// logic.
+//
+// The functionalization pass will turn the above program into a slightly
+// different program that has the same semantics, transparently to the user,
+// that backends like XLA/Vulkan are able to implement:
+//
+// a = torch.ones(...)
+// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
+//                       # XLA/Vulkan can implement this!
+// b.add_(1)
+// a.add_(1)  # Our functionalization pass machinery knows that a and b are
+//            # aliased - it applies b's mutation to a too.
+//
+// So, how does the functionalization pass keep track of which tensors are
+// aliased? The pass works by wrapping EVERY tensor in the program inside of a
+// FunctionalTensorWrapper, which knows about its alias'd tensors.
+//
+// See Note [Functionalization: Alias Removal] for details on the aliasing
+// machinery. See Note [Functionalization: Mutation Removal] for details on
+// mutation removal.
+struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
+  explicit FunctionalTensorWrapper(const Tensor& value);
+  // Additional constructor to create a FunctionalTensorWrapper directly from an
+  // underlying tensor that was created from a view. For example, the code b =
+  // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
+  // view1_meta)
+  explicit FunctionalTensorWrapper(
+      const Tensor& view_value,
+      const FunctionalTensorWrapper* base,
+      const functionalization::ViewMeta& meta);
+
+  // Get the underlying, actual tensor, that doesn't know anything about
+  // functionalization.
+  const Tensor& value() const {
+    return value_;
+  };
+  // The concept of "level" is only ever important to functorch; it's exposed
+  // here as more of a hook for functorch to use.
+  int64_t level() const {
+    return level_;
+  };
+  void set_level(int64_t level) {
+    level_ = level;
+  }
+  bool has_metadata_mutation() const {
+    return has_metadata_mutation_;
+  };
+
+  // Denotes a mutation that's hidden from autograd,
+  // e.g. for the purposes of passing a tensor to a triton kernel
+  void mark_mutation_hidden_from_autograd() {
+    mutation_hidden_from_autograd_counter_++;
+  }
+  void mark_mutation_during_no_grad_or_inference_mode() {
+    mutation_during_no_grad_or_inference_mode_++;
+  }
+  // Are all the mutations happening to the tensor hidden from autograd
+  bool are_all_mutations_hidden_from_autograd() const {
+    return mutation_hidden_from_autograd_counter_ == mutation_counter_;
+  }
+  // Did all mutations happen under no_grad or inference_mode
+  // (We also need to ignore mutations fully hidden from autograd here)
+  bool are_all_mutations_under_no_grad_or_inference_mode() const {
+    return mutation_hidden_from_autograd_counter_ +
+        mutation_during_no_grad_or_inference_mode_ ==
+        mutation_counter_;
+  }
+
+  // Sync's the underlying tensor with its alias, if it's out of date. This
+  // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
+  // Replay the views (if any) to regenerate the current tensor off of the
+  // updated alias.
+  void sync_();
+  // Performs step (1) of the sync. This is its own public API because it's
+  // needed by view_inplace ops like transpose_. See Note [Functionalization
+  // Pass - Inplace View Ops]
+  void regenerate_from_base();
+  // Performs step (2) of the sync. This is its own public API because it's
+  // needed by functorch. functorch wants to make sure that all input tensors to
+  // a functionalized program have been properly synced so it can properly
+  // propagate mutations to inputs. It can't just call sync_(), because the
+  // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
+  // a noop. We use the reference count on storage_ to determine if the wrapper
+  // is aliased, and by the time functorch is ready to propagate updates to
+  // inputs, any intermediate views of the input created by the program will
+  // have been deallocated. This function also returns whether or not the base
+  // actually had any updates to apply.
+  bool apply_updates();
+  // Takes the current state of value_ and snapshots it, sending it as a pending
+  // update to the alias.
+  void commit_update();
+  // When any tensor is mutated, the tensor increments its alias's "generation".
+  // Separately, each tensor maintains its own "generation" counter, which is
+  // used to determine if it's up-to-date with its alias. The act of syncing a
+  // tensor will set a tensor's generation equal to its alias's generation.
+  bool is_up_to_date() const;
+  // Freezes the storage of this tensor, preventing subsequent mutations
+  void freeze_storage() const;
+  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
+  // describing the series of view ops that ran to generate the current tensor
+  // from the base tensor. This method is used by inplace-view ops like
+  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
+  // tensor by replaying the views off of the alias.
+  void mutate_view_meta(const at::functionalization::ViewMeta& meta);
+
+  // Custom implementation of self.set_(src)
+  void set__impl(const FunctionalTensorWrapper* other);
+
+  // Returns whether the current tensor's data was ever mutated
+  bool has_data_mutation();
+  //
+  // Returns whether the current FunctionalTensorWrapper
+  // experienced a set_() call.
+  bool was_storage_changed() {
+    return was_storage_changed_;
+  }
+
+  // The functionalization pass can be used to remove mutations.
+  // It does so by replacing any mutation op with its corresponding
+  // out-of-place op, followed by a call to replace_(). e.g:
+  //
+  // a.add_(1)
+  //
+  // will turn into:
+  //
+  // tmp = a.add(1)
+  // a.replace_(tmp)
+  //
+  // replace_() swaps out the wrapped tensor, value_, with tmp.
+  void replace_(const Tensor& other);
+
+  bool is_multi_output_view() {
+    return is_multi_output_view_;
+  }
+
+  // See Note[resize_() in functionalization pass]
+  void maybe_replace_storage(const Tensor& other);
+
+  // Replaces the storage with a new functional storage,
+  // and clears the view_metas_ stack.
+  // WARNING: Calling this function will sever the aliasing relationship between
+  // the current FunctionalTensorWrapper and any of its outstanding aliases.
+  // Please only call if you know what you're doing.
+  void _unsafe_reset_storage();
+
+  c10::intrusive_ptr shallow_copy_and_detach(
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const override;
+
+  c10::intrusive_ptr shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+
+  ~FunctionalTensorWrapper() override = default;
+
+  // FunctionalTensorWrapper overrides all custom size/stride functions,
+  // so that if the inner tensor has a custom implementation
+  // we make sure to call that implementation.
+  at::IntArrayRef sizes_custom() const override;
+  at::IntArrayRef strides_custom() const override;
+  int64_t dim_custom() const override;
+  int64_t numel_custom() const override;
+  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
+  c10::SymIntArrayRef sym_sizes_custom() const override;
+  c10::SymInt sym_size_custom(int64_t d) const override;
+  c10::SymIntArrayRef sym_strides_custom() const override;
+  c10::SymInt sym_storage_offset_custom() const override;
+  c10::Device device_custom() const override;
+
+ private:
+  const char* tensorimpl_type_name() const override;
+  void set_constructor_metadata();
+  functionalization::FunctionalStorageImpl* functional_storage_impl() const;
+
+  // This is used to re-implement shallow_copy_and_detach for
+  // FunctionalTensorWrapper. The implementation is identical, but we just need
+  // to return a subclass instead of a plain TensorImpl.
+  // TODO: maybe it's possible to arrange for that to happen automatically
+  // without an override here?
+  template 
+  c10::intrusive_ptr shallow_copy_and_detach_core(
+      VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const;
+
+  void shallow_copy_from(const c10::intrusive_ptr& impl) override;
+  void copy_tensor_metadata_and_refresh(
+      const FunctionalTensorWrapper* src_impl,
+      FunctionalTensorWrapper* dest_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const;
+
+  // Note that value is not taken by reference: internally, the wrapper will
+  // change the value tensor that it points to over time.
+  Tensor value_;
+  int64_t level_{};
+  // These two counters are used for identifying
+  // whether all the mutations on a given tensor are hidden from autograd or
+  // not. If we have an input mutation that is hidden from autograd, then once
+  // we convert the input mutation to a copy_() we know it will be safe to hide
+  // the copy_() from autograd as well.
+  uint64_t mutation_counter_ = 0;
+  uint64_t mutation_hidden_from_autograd_counter_ = 0;
+  uint64_t mutation_during_no_grad_or_inference_mode_ = 0;
+  bool has_metadata_mutation_ = false;
+  bool is_multi_output_view_ = false;
+  // Did the tensor experience a set_() call.
+  bool was_storage_changed_ = false;
+
+  size_t generation_ = 0;
+  std::vector view_metas_;
+
+ protected:
+  static void copy_tensor_metadata(
+      const FunctionalTensorWrapper* src_impl,
+      FunctionalTensorWrapper* dest_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change);
+};
+
+// Utility functions for the functionalization pass.
+
+namespace functionalization {
+namespace impl {
+
+TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
+    const Tensor& tensor) {
+  auto functional_impl =
+      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
+  return functional_impl;
+}
+
+TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
+TORCH_API bool isFunctionalTensor(const c10::optional& t);
+TORCH_API bool isFunctionalTensor(
+    const c10::List>& t_list);
+TORCH_API bool isFunctionalTensor(ITensorListRef list);
+
+TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
+TORCH_API c10::optional to_functional_tensor(
+    const c10::optional& tensor);
+TORCH_API c10::List> to_functional_tensor(
+    const c10::List>& t_list);
+TORCH_API std::vector to_functional_tensor(ITensorListRef t_list);
+
+TORCH_API void freeze_functional_tensor(const Tensor& tensor);
+
+TORCH_API Tensor
+from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
+TORCH_API c10::optional from_functional_tensor(
+    const c10::optional& t,
+    bool assert_functional = true);
+TORCH_API c10::List> from_functional_tensor(
+    const c10::List>& t_list);
+TORCH_API std::vector from_functional_tensor(ITensorListRef t_list);
+
+TORCH_API void sync(const at::Tensor& t);
+TORCH_API void sync(const c10::optional& t);
+TORCH_API void sync(const c10::List>& t_list);
+TORCH_API void sync(ITensorListRef t_list);
+
+TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
+TORCH_API void replace_(
+    const ITensorListRef functional_tensor,
+    ITensorListRef other);
+
+TORCH_API void commit_update(const Tensor& functional_tensor);
+TORCH_API void commit_update(ITensorListRef functional_tensor);
+
+TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
+
+TORCH_API void mark_mutation_hidden_from_autograd(
+    const Tensor& functional_tensor);
+
+TORCH_API bool are_all_mutations_hidden_from_autograd(
+    const Tensor& functional_tensor);
+
+TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
+    const Tensor& functional_tensor);
+
+// These two methods are XLA-specific logic and are no-ops
+// for the normal functionalization flow.
+TORCH_API void propagate_xla_data(
+    const Tensor& functional_tensor,
+    const Tensor& other);
+TORCH_API void propagate_xla_data(
+    const ITensorListRef functional_tensor,
+    ITensorListRef other);
+
+Tensor create_functional_tensor_with_view_meta(
+    const Tensor& view_to_wrap,
+    const Tensor& base,
+    functionalization::ViewMeta meta,
+    int64_t out_idx = 0);
+std::vector create_functional_tensor_with_view_meta(
+    ITensorListRef view_to_wrap,
+    const Tensor& base,
+    const functionalization::ViewMeta& meta);
+
+void mutate_view_meta(
+    const Tensor& self,
+    const functionalization::ViewMeta& meta);
+
+void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
+void set_sizes_strides_offset(
+    const std::vector& outs,
+    const std::vector& meta_outs);
+
+//  ~~~~~ TLS used in functionalization ~~~~~
+
+TORCH_API bool getFunctionalizationReapplyViewsTLS();
+TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
+
+class TORCH_API FunctionalizationReapplyViewsGuard {
+ public:
+  FunctionalizationReapplyViewsGuard(bool reapply_views)
+      : prev_(getFunctionalizationReapplyViewsTLS()) {
+    setFunctionalizationReapplyViewsTLS(reapply_views);
+  }
+
+  ~FunctionalizationReapplyViewsGuard() {
+    setFunctionalizationReapplyViewsTLS(prev_);
+  }
+
+  FunctionalizationReapplyViewsGuard(
+      const FunctionalizationReapplyViewsGuard&) = delete;
+  FunctionalizationReapplyViewsGuard operator=(
+      const FunctionalizationReapplyViewsGuard&) = delete;
+  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
+      delete;
+  FunctionalizationReapplyViewsGuard operator=(
+      FunctionalizationReapplyViewsGuard&&) = delete;
+
+ private:
+  bool prev_;
+};
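+
+// Illustrative RAII usage: the previous TLS value is restored when the guard
+// goes out of scope.
+//
+//   {
+//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
+//         /*reapply_views=*/true);
+//     // ... code that should observe getFunctionalizationReapplyViewsTLS() == true
+//   }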
+
+} // namespace impl
+
+// Helper function to call an out-of-place composite aten kernel that may use
+// mutations / views internally, and functionalize them.
+TORCH_API void functionalize_op_helper(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack);
+
+template 
+struct _functionalize_aten_op final {};
+
+template 
+struct _functionalize_aten_op final {
+  static ReturnType call(
+      typename c10::maybe_keep_symint::type... args) {
+    using FuncType = ReturnType(
+        typename c10::maybe_keep_symint::type...);
+    auto op = c10::Dispatcher::singleton()
+                  .findSchemaOrThrow(
+                      (const char*)Op::name, (const char*)Op::overload_name)
+                  .typed();
+
+    return c10::impl::BoxedKernelWrapper::call(
+        c10::BoxedKernel::makeFromFunction(),
+        op,
+        // BoxedKernelWrapper knows to ignore this keyset argument,
+        // because functionalize_op_helper doesn't take in a DispatchKeySet
+        c10::DispatchKeySet(),
+        args...);
+  }
+};
+
+template 
+using functionalize_aten_op =
+    _functionalize_aten_op;
+
+template 
+using functionalize_aten_op_symint =
+    _functionalize_aten_op;
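+
+// Illustrative usage sketch (assumes the ATEN_OP macro from ATen/Operators.h;
+// the operator chosen here is just an example):
+//
+//   at::Tensor out = at::functionalization::functionalize_aten_op<
+//       ATEN_OP(matmul)>::call(a, b);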
+
+} // namespace functionalization
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Functions.h b/MLPY/Lib/site-packages/torch/include/ATen/Functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..a2ca4df21178cb487fcaca8b2908a150824a28d3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Functions.h
@@ -0,0 +1,1427 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Functions.h
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from <ATen/ops/{my_operator}.h> and   \
+  see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
+//
+// In ATen, certain generated headers files include the definitions of
+// every single operator in PyTorch. Unfortunately this means every
+// time an operator signature is updated or changed in
+// native_functions.yaml, you (and every other PyTorch developer) need
+// to recompile every source file that includes any of these headers.
+//
+// To break up these header dependencies, and improve incremental
+// build times for all PyTorch developers. These headers are split
+// into per-operator headers in the `ATen/ops` folder. This limits
+// incremental builds to only changes to methods of `Tensor`, or files
+// that use the specific operator being changed. With `at::sum` as an
+// example, you should include
+//
+//   #include <ATen/ops/sum.h>               // instead of ATen/Functions.h
+//   #include <ATen/ops/sum_native.h>        // instead of ATen/NativeFunctions.h
+//   #include <ATen/ops/sum_ops.h>           // instead of ATen/Operators.h
+//   #include <ATen/ops/sum_cpu_dispatch.h>  // instead of ATen/CPUFunctions.h
+//
+// However, even if you're careful to use this in your own code.
+// `Functions.h` might be included indirectly through another header
+// without you realising. To avoid this, you can add
+//
+//   #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+//
+// to the top of your source file. This way any time the non-specific
+// headers are included, the compiler will error out.
+//
+// Also, be aware that `ops` are not available in all build
+// configurations (namely fb-internal) so you must guard these
+// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
+//
+//   #ifndef AT_PER_OPERATOR_HEADERS
+//   #include 
+//   #else
+//   #include 
+//   #endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+
+
+// Special C++ only overloads for std()-like functions (See gh-40287)
+// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
+// So, for example std(0) would select the std(unbiased=False) overload
+TORCH_API inline Tensor var(const Tensor& self, int dim) {
+  return at::var(self, IntArrayRef{dim});
+}
+TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
+  return at::var_mean(self, IntArrayRef{dim});
+}
+TORCH_API inline Tensor std(const Tensor& self, int dim) {
+  return at::std(self, IntArrayRef{dim});
+}
+TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
+  return at::std_mean(self, IntArrayRef{dim});
+}
+
+inline int64_t numel(const Tensor& tensor) {
+  return tensor.numel();
+}
+
+inline int64_t size(const Tensor& tensor, int64_t dim) {
+  return tensor.size(dim);
+}
+
+inline int64_t stride(const Tensor& tensor, int64_t dim) {
+  return tensor.stride(dim);
+}
+
+inline bool is_complex(const Tensor& tensor) {
+  return tensor.is_complex();
+}
+
+inline bool is_floating_point(const Tensor& tensor) {
+  return tensor.is_floating_point();
+}
+
+inline bool is_signed(const Tensor& tensor) {
+  return tensor.is_signed();
+}
+
+inline bool is_inference(const Tensor& tensor) {
+  return tensor.is_inference();
+}
+
+inline bool _is_zerotensor(const Tensor& tensor) {
+  return tensor._is_zerotensor();
+}
+
+inline bool is_conj(const Tensor& tensor) {
+  return tensor.is_conj();
+}
+
+inline Tensor conj(const Tensor& tensor) {
+  return tensor.conj();
+}
+
+inline bool is_neg(const Tensor& tensor) {
+  return tensor.is_neg();
+}
+
+}
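Editor's note: a minimal usage sketch of the int-dim overloads above (not part of the vendored header); it assumes only <ATen/ATen.h> and the function name is illustrative.

#include <ATen/ATen.h>
#include <tuple>

void reduction_overload_example() {
  at::Tensor t = at::rand({4, 5});
  // With the overloads above, a bare int selects the dim-reduction form
  // instead of converting to the `unbiased` bool parameter.
  at::Tensor col_std = at::std(t, /*dim=*/0);                        // shape [5]
  std::tuple<at::Tensor, at::Tensor> vm = at::var_mean(t, /*dim=*/1);
  at::Tensor var = std::get<0>(vm);                                  // shape [4]
  at::Tensor mean = std::get<1>(vm);                                 // shape [4]
  (void)col_std; (void)var; (void)mean;
}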
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Generator.h b/MLPY/Lib/site-packages/torch/include/ATen/Generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..741e39f29dae4cca6cb39f8b1d385bb14ed1b6c5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Generator.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/InferSize.h b/MLPY/Lib/site-packages/torch/include/ATen/InferSize.h
new file mode 100644
index 0000000000000000000000000000000000000000..853425357e3d7d4697adfbaf1e67f5c295d2760f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/InferSize.h
@@ -0,0 +1,87 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// Infers the size of a dim with size -1, if it exists. Also checks that new
+// shape is compatible with the number of elements.
+//
+// templated to handle std::vector and DimVector use cases, see
+// below
+//
+template <typename InputArrayRef, typename NumelType, typename ResultVec>
+inline void infer_size_impl(
+    InputArrayRef shape,
+    NumelType numel,
+    ResultVec& res) {
+  NumelType newsize = 1;
+  // N.B. this is an index, not a sym dim!
+  auto infer_dim = c10::optional<int64_t>();
+  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
+    if (shape[dim] == -1) {
+      if (infer_dim) {
+        throw std::runtime_error("only one dimension can be inferred");
+      }
+      infer_dim = dim;
+    } else if (shape[dim] >= 0) {
+      newsize *= shape[dim];
+    } else {
+      AT_ERROR("invalid shape dimension ", shape[dim]);
+    }
+  }
+
+  if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
+    if (infer_dim) {
+      // We have a degree of freedom here to select the dimension size; follow
+      // NumPy semantics and just bail.  However, a nice error message is needed
+      // because users often use `view` as a way to flatten & unflatten
+      // dimensions and will otherwise be confused why
+      //   empty_tensor.view( 0, 0)
+      // works yet
+      //   empty_tensor.view(-1, 0)
+      // doesn't.
+      TORCH_CHECK(
+          newsize != 0,
+          "cannot reshape tensor of 0 elements into shape ",
+          shape,
+          " because the unspecified dimension size -1 can be any "
+          "value and is ambiguous");
+      res[*infer_dim] = numel / newsize;
+    }
+    return;
+  }
+
+  std::ostringstream ss;
+  ss << "shape '" << shape << "' is invalid for input of size " << numel;
+  throw std::runtime_error(ss.str());
+}
+
+inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
+  auto res = shape.vec();
+  infer_size_impl(shape, numel, res);
+  return res;
+}
+
+inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
+  auto res = at::DimVector(shape);
+  infer_size_impl(shape, numel, res);
+  return res;
+}
+
+inline at::SymDimVector infer_size_dv(
+    c10::SymIntArrayRef shape,
+    c10::SymInt numel) {
+  auto res = at::SymDimVector(shape);
+  infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
+      shape, std::move(numel), res);
+  return res;
+}
+
+} // namespace at
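Editor's note: a short sketch of how infer_size resolves a -1 dimension, mirroring what Tensor::view does; the include path is taken from the file location above and the helper name is illustrative.

#include <ATen/InferSize.h>
#include <cassert>
#include <vector>

void infer_size_example() {
  // 24 elements reshaped to {2, -1, 4}: the -1 slot must be inferred as 3.
  std::vector<int64_t> requested = {2, -1, 4};
  std::vector<int64_t> inferred = at::infer_size(requested, /*numel=*/24);
  assert((inferred == std::vector<int64_t>{2, 3, 4}));
}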
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h b/MLPY/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h
new file mode 100644
index 0000000000000000000000000000000000000000..58289fb41c6f66b85ca17297864e1639f0a78441
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include 
+
+namespace at {
+
+// Represents the initial TensorOptions, before the "defaults" are ever changed.
+// This is designed to be used in library code, where the explicit devices,
+// dtypes, etc. are known. NOTE: this is not a stable API.
+inline TensorOptions initialTensorOptions() {
+  return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
+      false);
+}
+
+} // namespace at
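Editor's note: a hedged sketch of initialTensorOptions() in use; `make_scratch_buffer` is a made-up name for illustration only.

#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>

at::Tensor make_scratch_buffer() {
  // Always CPU / float / strided / requires_grad=false, independent of any
  // user-changed global defaults.
  return at::empty({16}, at::initialTensorOptions());
}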
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Layout.h b/MLPY/Lib/site-packages/torch/include/ATen/Layout.h
new file mode 100644
index 0000000000000000000000000000000000000000..11bda768d2fc435e5aa32c764097ef158fe4a315
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Layout.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h b/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..7a4a1961a5f57d0aed6a4bd9b07ae2ff7e094d8a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h
@@ -0,0 +1,25 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at {
+
+// If an operator doesn't have a batching rule implemented then we fallback
+// to this implementation. The fallback only works on out-of-place operators
+// that return only tensors with new memory. (e.g., no in-place operators, no
+// view operations).
+//
+// The fallback effectively takes all of the BatchedTensors in `stack`, slices
+// them, and runs `op` on all of the corresponding slices to produce slices
+// of the outputs. The output slices then get `torch.stack`ed to create the
+// final returns.
+//
+// The performance of the fallback is not very good because it introduces an
+// extra copy from stacking the sliced outputs. Because of this, we prefer to
+// write batching rules for operators whenever possible.
+void batchedTensorForLoopFallback(
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack);
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..c606c6c1d423364cfb25d54cab682835d7b3074e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
@@ -0,0 +1,160 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+// We assume this in a few other places in the codebase,
+// but there isn't a centralized definition.
+constexpr int64_t kVmapMaxTensorDims = 64;
+
+// The valid vmap levels range from [0, 64). This effectively means that we
+// support a maximum of 64 nested vmaps.
+constexpr int64_t kVmapNumLevels = 64;
+
+// Store this number of elements of BatchDims on the stack. Most people will
+// probably use <= 5 nested vmaps, but adjust this number as necessary.
+constexpr int64_t kBatchDimsStackSize = 5;
+
+// a BatchDim represents a "private" dimension on a Tensor created inside of
+// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
+// is being vmap'ed over and the `level` being an identifier for which vmap
+// said dimension was created inside. The `dim` corresponds to a "physical
+// dim" - it is a dimension index on the underlying physical tensor that is
+// being vmapped over.
+struct BatchDim {
+  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
+  int64_t dim() const {
+    return dim_;
+  }
+  int64_t level() const {
+    return level_;
+  }
+
+ private:
+  int64_t dim_;
+  int64_t level_;
+};
+
+using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
+using BatchDimsRef = ArrayRef<BatchDim>;
+
+// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+//
+// The batch dimensions are treated as being "private"; they are not
+// user-visible. For example, in the following Tensor,
+//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
+// dimensions 0 and 1 are batch dimensions.
+//
+// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
+// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
+// tensor.
+struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
+  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
+
+  // Returns a reference to BatchDims that represent which dimensions of this
+  // tensor are private.
+  BatchDimsRef bdims() const {
+    return bdims_;
+  }
+
+  // BatchedTensorImpl wraps a Tensor
+  const Tensor& value() const {
+    return value_;
+  };
+
+  // Given a public dimension index, return the dimension index in the
+  // underlying value() tensor. For example, if we have
+  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
+  //    dim=2)])
+  // bt.actualDim(0) -> 1
+  // bt.actualDim(1) -> 3
+  // bt.actualDim(2) -> Error
+  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
+
+  // We have to override this because we opted into CustomStrides
+  IntArrayRef strides_custom() const override;
+  // Override a bunch of methods inherited from TensorImpl to return error
+  // messages.
+  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
+  void set_size(int64_t dim, int64_t new_size) override;
+  void set_stride(int64_t dim, int64_t new_stride) override;
+  void set_storage_offset(int64_t storage_offset) override;
+#ifdef DEBUG
+  bool has_storage() const override;
+#endif
+
+ private:
+  // see NOTE: [BatchedTensorImpl levels invariant]
+  void checkInvariants() const;
+  const char* tensorimpl_type_name() const override;
+
+  Tensor value_;
+
+  // Note: [BatchedTensorImpl levels invariant]
+  // There is an invariant that the BatchDims must be stored in increasing
+  // `level` order. That is, for i < j, bdims_[i].level must be less than
+  // bdims_[j].level.
+  BatchDims bdims_;
+};
+
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+inline bool isBatchedTensor(const Tensor& tensor) {
+  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
+}
+
+// It is unsafe to call this on a Tensor that is not backed by a
+// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
+inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
+  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
+}
+
+inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
+  if (!isBatchedTensor(tensor)) {
+    return nullptr;
+  }
+  return unsafeGetBatchedImpl(tensor);
+}
+
+// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
+inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
+    BatchDimsRef bdims) {
+  std::bitset<kVmapMaxTensorDims> is_bdim;
+  for (const auto& bdim : bdims) {
+    is_bdim.set(bdim.dim());
+  }
+  return is_bdim;
+}
+
+// Creates a bitset for all of the levels present in `bdims`
+inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
+  std::bitset<kVmapNumLevels> result;
+  for (const auto& bdim : bdims) {
+    result.set(bdim.level());
+  }
+  return result;
+}
+
+inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
+  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
+  return out;
+}
+
+// Use this to construct a BatchedTensor from a regular Tensor
+TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
+
+// Adds a batch dim to `tensor`, returning a BatchedTensor
+TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
+
+// Checks if an inplace operation on self and other is "vmap compatible".
+// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
+TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
+
+} // namespace at
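Editor's note: a sketch of the legacy BatchedTensor helpers declared above (internal API; behavior may differ across releases); the function name and the chosen vmap level are illustrative.

#include <ATen/ATen.h>
#include <ATen/LegacyBatchedTensorImpl.h>

void batched_dim_mapping_example() {
  at::Tensor base = at::ones({2, 3, 5, 7});
  // Mark dim 0 as a private batch dim belonging to vmap level 1.
  at::Tensor bt = at::addBatchDim(base, /*level=*/1, /*dim=*/0);

  if (at::BatchedTensorImpl* impl = at::maybeGetBatchedImpl(bt)) {
    // The logical view is (3, 5, 7); logical dim 0 maps to physical dim 1.
    int64_t physical_dim = impl->actualDim(/*dim=*/0);
    (void)physical_dim;
  }
}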
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h b/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h
new file mode 100644
index 0000000000000000000000000000000000000000..dfb093566ccbe05a23e1d474cad84166496eb402
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include 
+
+namespace at::impl {
+
+// VmapMode contains a thread local count of how many nested vmaps
+// we are currently inside. That number is known as the `vmap level`.
+// VmapMode is used in the implementation of the Python `torch.vmap` API.
+//
+// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
+
+struct TORCH_API VmapMode {
+  // Returns the vmap level, aka the count of how many nested vmaps we're in.
+  static int64_t current_vmap_level();
+
+  // Increment the count of nested vmaps. If this causes the vmap level to be
+  // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
+  static int64_t increment_nesting();
+
+  // Decrements the count of nested vmaps. If this causes the vmap level to be
+  // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
+  static int64_t decrement_nesting();
+};
+
+} // namespace at::impl
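Editor's note: a hypothetical RAII wrapper around the increment/decrement pair above, purely to illustrate how nesting is tracked; no such guard is provided by this header.

#include <ATen/LegacyVmapMode.h>

struct VmapLevelGuard {  // hypothetical, not part of the header above
  VmapLevelGuard() : level_(at::impl::VmapMode::increment_nesting()) {}
  ~VmapLevelGuard() { at::impl::VmapMode::decrement_nesting(); }
  int64_t level() const { return level_; }

 private:
  int64_t level_;
};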
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h b/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h
new file mode 100644
index 0000000000000000000000000000000000000000..13af3ad08ad24f59d81bf6d4ade0cb925d3a5b95
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h
@@ -0,0 +1,183 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+// This file contains abstractions used for transforming *logical* vmap
+// arguments into *physical* arguments. (Keep reading for definitions of these
+// terms).
+
+// NOTE: [Logical vs physical args]
+// Consider the following vmap.
+//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
+// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
+// with batch dims 0 and 2:
+//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
+//
+// We say the *logical* view of the tensor has size [3] -- tensors inside
+// `func` appear to have size [3].
+// However, the *physical* underlying tensor (the one passed to vmap) has size
+// [2, 3, 4].
+//
+// This notion of logical vs physical also extends to non-tensor arguments.
+// Consider the previous tensor; let's assume the user called
+// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
+// dimension they are reducing over is dim 0 but the physical dim is dim 1
+// (the first non-batch dimension)
+
+// Forward declared; see NOTE: [What is a VmapPhysicalView?]
+struct VmapPhysicalView;
+
+// Most PyTorch operators take 4 or fewer inputs.
+constexpr int64_t kVmapTransformStaticInputSize = 4;
+using VmapPhysicalViewVec =
+    SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
+
+// Pytorch generally advertises good performance for <= 5 dims.
+// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
+// dimensions to get 8. Adjust this number as necessary
+constexpr int64_t kVmapStaticDimVecSize = 8;
+using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
+using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
+
+// NOTE: [What is an VmapTransform?]
+// An *VmapTransform* converts logical views of tensors to physical views.
+//
+// Batching rules use VmapTransforms to convert logical arguments to
+// physical arguments, then call one or more at:: operators that handle the
+// physical arguments, and then convert the physical result back to a logical
+// argument.
+
+// VmapTransform for operators that take tensors with multiple batch dims.
+// Given one or more logical views on Tensors, `logicalToPhysical`
+// permutes all of the batch dims to the front of the tensor, aligns
+// and expands the batch dims to match each other (according to their `level`),
+// and returns a VmapPhysicalView on the tensor(s).
+struct TORCH_API MultiBatchVmapTransform {
+  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
+  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
+};
+
+// VmapTransform for operators that broadcast all inputs.
+// Given some logical views on Tensors, `logicalToPhysical`:
+// - permutes all of the batch dims to the front of the tensors
+// - aligns all the batch dims to the collective levels of all of the tensors.
+//   If a tensor does not have a batch dim for a vmap level, then it receives
+//   a size-one dimension for said level.
+// - aligns the non-batch dims to have the same dimensionality, adding extra
+//   size-1 dimensions in between the batch dimensions and the non-batch
+//   dimensions so that the batch dimensions are lined up from the right.
+//
+// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
+// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
+// tensors of size (B, 1, 2) and (B, 3, 2).
+//
+// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
+// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
+// actually *need* to return a tensor of size (1, 2) for the second tensor
+// because the broadcasting operation takes care of that for us, but we do
+// it anyways to keep things simple.
+struct TORCH_API BroadcastingVmapTransform {
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
+};
+
+// Forward declared, if you're reading this file head to toe, don't worry about
+// it yet.
+struct VmapPhysicalToLogicalMap;
+
+// NOTE: [What is a VmapPhysicalView?]
+// VmapPhysicalView represents a physical view on a Tensor.
+//
+// One can use it to further convert logical dimension indices, logical shapes,
+// and more to their physical variants, or convert a new (physical) tensor into
+// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
+//
+// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
+// the front and some levels that correspond to said batch dimensions.
+//
+// The levels bitset specifies which vmap levels correspond to the batch
+// dimensions at the front of the tensor. In particular, the number of set bits
+// corresponds to the number of batch dimensions on `tensor` and the rightmost
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
+// this point in time.
+// For example, given:
+//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
+// than or equal to 3.
+//   bitset: 010100
+//              ^
+//              |
+//   levels: 012345
+struct TORCH_API VmapPhysicalView {
+  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
+      : levels_(levels), tensor_(std::move(tensor)) {
+    TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
+  }
+
+  Tensor& tensor() {
+    return tensor_;
+  }
+  const Tensor& tensor() const {
+    return tensor_;
+  }
+
+  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
+  //
+  // For example, given:
+  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
+  //
+  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
+  // This is because the size of levels tells us that the first two dimensions
+  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
+  // a physical dim of `n + 2`.
+  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
+  int64_t getPhysicalDim(int64_t logical_dim) const;
+
+  // Returns a VmapPhysicalToLogicalMap object. This can be used for
+  // mapping a physical tensor to a new logical tensor (BatchedTensor)
+  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
+
+  // Maps a logical shape to a physical shape by pre-pending the batch
+  // sizes to the logical shape.
+  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
+
+  int64_t numBatchDims() const;
+
+ private:
+  int64_t numLogicalDims() const;
+
+  std::bitset<kVmapNumLevels> levels_;
+  Tensor tensor_;
+};
+
+// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
+// to a logical one (BatchedTensor). It holds some levels that are used to do
+// the mapping and assumes that the batch dimensions in the physical tensor all
+// occur at the front of the tensor.
+struct TORCH_API VmapPhysicalToLogicalMap {
+  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
+      : levels_(levels) {}
+
+  // Maps a physical tensor to a new logical tensor (BatchedTensor).
+  // Assumes that all of the "batch dimensions" are at the front
+  // of the physical tensor. For example, given:
+  // - x = rank-4 Tensor with size 2, 3, 5, 7
+  // - levels = (2, 4)
+  // Returns:
+  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
+  Tensor apply(const Tensor& physical_tensor) const;
+
+  // Given a vector of physical tensors,
+  // 1. maps each tensor to a new logical tensor. Assumes that all of the
+  //    "batch dimensions" are at the front of the physical tensors.
+  // 2. stores the new logical tensors back into the passed-in vector. This is
+  //    to avoid additional dynamic allocations.
+  void applyInplace(std::vector<Tensor>& physical_tensors) const;
+
+  std::bitset<kVmapNumLevels> levels_;
+};
+
+} // namespace at
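Editor's note: a sketch of how a batching rule might use the transforms above to translate a logical dim into a physical one; this is an internal API and the function below is illustrative, not an actual PyTorch batching rule.

#include <ATen/ATen.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapTransforms.h>

at::Tensor sum_batching_rule_sketch(const at::Tensor& logical, int64_t logical_dim) {
  // Move all batch dims to the front and view the underlying physical tensor.
  auto physical_view = at::MultiBatchVmapTransform::logicalToPhysical(logical);
  // Translate the user-facing (logical) dim into a physical dim index.
  int64_t physical_dim = physical_view.getPhysicalDim(logical_dim);
  at::Tensor physical_result = physical_view.tensor().sum(physical_dim);
  // Re-wrap the physical result as a logical BatchedTensor.
  return physical_view.getPhysicalToLogicalMap().apply(physical_result);
}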
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/LinalgBackend.h b/MLPY/Lib/site-packages/torch/include/ATen/LinalgBackend.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b084d189d7fb61cc0f67ccc0be15614be7e490c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/LinalgBackend.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+namespace at {
+
+enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
+
+inline std::string LinalgBackendToString(at::LinalgBackend backend) {
+  switch (backend) {
+    case LinalgBackend::Default:
+      return "at::LinalgBackend::Default";
+    case LinalgBackend::Cusolver:
+      return "at::LinalgBackend::Cusolver";
+    case LinalgBackend::Magma:
+      return "at::LinalgBackend::Magma";
+    default:
+      TORCH_CHECK(false, "Unknown linalg backend");
+  }
+}
+
+inline std::ostream& operator<<(
+    std::ostream& stream,
+    at::LinalgBackend backend) {
+  return stream << LinalgBackendToString(backend);
+}
+
+} // namespace at
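Editor's note: trivial usage of the stream operator defined above; the wrapper name is illustrative.

#include <ATen/LinalgBackend.h>
#include <iostream>

void print_linalg_backend(at::LinalgBackend backend) {
  std::cout << backend << '\n';  // e.g. "at::LinalgBackend::Magma"
}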
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MapAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/MapAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..17af2f8947abb4501412cc50057d4889d2a0a237
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MapAllocator.h
@@ -0,0 +1,139 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+enum MappedAllocatorModes {
+  ALLOCATOR_MAPPED_SHARED = 1,
+  ALLOCATOR_MAPPED_SHAREDMEM = 2,
+  ALLOCATOR_MAPPED_EXCLUSIVE = 4,
+  ALLOCATOR_MAPPED_NOCREATE = 8,
+  ALLOCATOR_MAPPED_KEEPFD = 16,
+  ALLOCATOR_MAPPED_FROMFD = 32,
+  ALLOCATOR_MAPPED_UNLINK = 64
+};
+
+// Sentinel value/type to help distinguish the file descriptor constructor from
+// the non-file descriptor constructor
+enum WithFd { WITH_FD };
+
+TORCH_API std::string NewProcessWideShmHandle();
+
+class TORCH_API MapAllocator {
+ public:
+  MapAllocator(c10::string_view filename, int flags, size_t size);
+  MapAllocator(
+      WithFd,
+      c10::string_view filename,
+      int fd,
+      int flags,
+      size_t size);
+  MapAllocator(const MapAllocator&) = delete;
+  MapAllocator& operator=(const MapAllocator&) = delete;
+  MapAllocator(MapAllocator&&) = delete;
+  MapAllocator& operator=(MapAllocator&&) = delete;
+
+  const char* filename() const {
+    return filename_.c_str();
+  }
+  int fd() const {
+#ifdef _WIN32
+    TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
+#else
+    return fd_;
+#endif
+  }
+  ptrdiff_t size() const {
+    return size_;
+  }
+  // Return a pointer to the actual data for this allocator
+  // (in the case of the refcounted allocator, this is offset
+  // from the base pointer.)
+  virtual void* data() const {
+    return base_ptr_;
+  }
+
+  static MapAllocator* fromDataPtr(const at::DataPtr&);
+  static at::DataPtr makeDataPtr(
+      c10::string_view filename,
+      int flags,
+      size_t size,
+      size_t* actual_size_out);
+  static at::DataPtr makeDataPtr(
+      WithFd,
+      const char* filename,
+      int fd,
+      int flags,
+      size_t size,
+      size_t* actual_size_out);
+
+  // Closes the data.  Helps us avoid destructor shenanigans
+  virtual void close();
+
+  // This is very dangerous.  You have to redefine this destructor for each
+  // subclass
+  virtual ~MapAllocator();
+
+ protected:
+  bool closed_ = false;
+  std::string filename_;
+  int flags_ = 0;
+  ptrdiff_t size_; /* mapped size */
+#ifdef _WIN32
+  void* handle_;
+  void* event_;
+  std::string eventname_;
+#else
+  int fd_ = -1;
+#endif
+  void* base_ptr_ = nullptr;
+};
+
+// Base-from-member idiom
+struct TORCH_API RefcountedMapAllocatorArgCheck {
+  RefcountedMapAllocatorArgCheck(int flags);
+};
+
+class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
+                                         public MapAllocator {
+ public:
+  RefcountedMapAllocator(const char* filename, int flags, size_t size);
+  RefcountedMapAllocator(
+      WithFd,
+      const char* filename,
+      int fd,
+      int flags,
+      size_t size);
+
+  static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
+  static at::DataPtr makeDataPtr(
+      const char* filename,
+      int flags,
+      size_t size,
+      size_t* actual_size_out);
+  static at::DataPtr makeDataPtr(
+      WithFd,
+      const char* filename,
+      int fd,
+      int flags,
+      size_t size,
+      size_t* actual_size_out);
+
+  void* data() const override;
+
+  void incref();
+  int decref();
+  void close() override;
+
+  ~RefcountedMapAllocator() override {
+    RefcountedMapAllocator::close();
+  }
+
+ protected:
+  void checkFlags();
+  void initializeAlloc();
+};
+
+} // namespace at
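Editor's note: a sketch of creating a memory-mapped DataPtr with the allocator above; the file path is a placeholder and the flag is just one of the MappedAllocatorModes bits.

#include <ATen/MapAllocator.h>
#include <cstddef>

at::DataPtr map_shared_buffer(size_t nbytes) {
  size_t actual_size = 0;
  // "/tmp/example_mapped_buffer" is a placeholder path for illustration only.
  return at::MapAllocator::makeDataPtr(
      /*filename=*/"/tmp/example_mapped_buffer",
      /*flags=*/at::ALLOCATOR_MAPPED_SHARED,
      /*size=*/nbytes,
      &actual_size);
}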
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MatrixRef.h b/MLPY/Lib/site-packages/torch/include/ATen/MatrixRef.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e803b09f9dc5f592a301d94aa858021371ddc0d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MatrixRef.h
@@ -0,0 +1,109 @@
+#pragma once
+#include 
+#include 
+
+#include 
+
+namespace at {
+/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
+/// we can easily view it as a multidimensional array.
+///
+/// Like ArrayRef, this class does not own the underlying data, it is expected
+/// to be used in situations where the data resides in some other buffer.
+///
+/// This is intended to be trivially copyable, so it should be passed by
+/// value.
+///
+/// For now, 2D only (so the copies are actually cheap, without having
+/// to write a SmallVector class) and contiguous only (so we can
+/// return non-strided ArrayRef on index).
+///
+/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
+template <typename T>
+class MatrixRef {
+ public:
+  typedef size_t size_type;
+
+ private:
+  /// Underlying ArrayRef
+  ArrayRef<T> arr;
+
+  /// Stride of dim 0 (outer dimension)
+  size_type stride0;
+
+  // Stride of dim 1 is assumed to be 1
+
+ public:
+  /// Construct an empty Matrixref.
+  /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
+
+  /// Construct an MatrixRef from an ArrayRef and outer stride.
+  /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
+      : arr(arr), stride0(stride0) {
+    TORCH_CHECK(
+        arr.size() % stride0 == 0,
+        "MatrixRef: ArrayRef size ",
+        arr.size(),
+        " not divisible by stride ",
+        stride0)
+  }
+
+  /// @}
+  /// @name Simple Operations
+  /// @{
+
+  /// empty - Check if the matrix is empty.
+  bool empty() const {
+    return arr.empty();
+  }
+
+  const T* data() const {
+    return arr.data();
+  }
+
+  /// size - Get size a dimension
+  size_t size(size_t dim) const {
+    if (dim == 0) {
+      return arr.size() / stride0;
+    } else if (dim == 1) {
+      return stride0;
+    } else {
+      TORCH_CHECK(
+          0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
+    }
+  }
+
+  size_t numel() const {
+    return arr.size();
+  }
+
+  /// equals - Check for element-wise equality.
+  bool equals(MatrixRef RHS) const {
+    return stride0 == RHS.stride0 && arr.equals(RHS.arr);
+  }
+
+  /// @}
+  /// @name Operator Overloads
+  /// @{
+  ArrayRef<T> operator[](size_t Index) const {
+    return arr.slice(Index * stride0, stride0);
+  }
+
+  /// Disallow accidental assignment from a temporary.
+  ///
+  /// The declaration here is extra complicated so that "arrayRef = {}"
+  /// continues to select the move assignment operator.
+  template <typename U>
+  std::enable_if_t<std::is_same_v<U, T>, MatrixRef>& operator=(
+      U&& Temporary) = delete;
+
+  /// Disallow accidental assignment from a temporary.
+  ///
+  /// The declaration here is extra complicated so that "arrayRef = {}"
+  /// continues to select the move assignment operator.
+  template <typename U>
+  std::enable_if_t<std::is_same_v<U, T>, MatrixRef>& operator=(
+      std::initializer_list<U>) = delete;
+};
+
+} // end namespace at
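Editor's note: MatrixRef viewing a flat contiguous buffer as a 2-D matrix, per the class comment above; names and values are illustrative.

#include <ATen/MatrixRef.h>
#include <cassert>
#include <vector>

void matrix_ref_example() {
  std::vector<float> buffer = {1, 2, 3, 4, 5, 6};  // 2 rows x 3 columns
  at::MatrixRef<float> m(c10::ArrayRef<float>(buffer), /*stride0=*/3);
  assert(m.size(0) == 2 && m.size(1) == 3);
  assert(m[1][2] == 6.0f);  // row 1, column 2
}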
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MemoryOverlap.h b/MLPY/Lib/site-packages/torch/include/ATen/MemoryOverlap.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8427eef13cdd1741262f4dcdb84900389157e22
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MemoryOverlap.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+struct TensorImpl;
+}
+
+namespace at {
+class TensorBase;
+
+// MemOverlap: Whether or not there is memory overlap
+//
+// No: Absolutely no memory overlap
+// Yes: Absolutely yes memory overlap
+// TooHard: There might be memory overlap, but it was too expensive to compute.
+//
+// NB: Please update the python test for these if you renumber them.
+enum class MemOverlap { No, Yes, TooHard };
+
+enum class MemOverlapStatus { Full, Partial, No, TooHard };
+
+TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
+TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
+
+TORCH_API void assert_no_internal_overlap(const TensorBase& t);
+TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
+
+TORCH_API MemOverlapStatus
+get_overlap_status(const TensorBase& a, const TensorBase& b);
+TORCH_API MemOverlapStatus
+get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
+
+TORCH_API void assert_no_partial_overlap(
+    const TensorBase& a,
+    const TensorBase& b);
+void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
+
+TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
+TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
+
+} // namespace at
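Editor's note: the overlap assertions above as an out-of-place kernel might call them before writing into `out`; the wrapper name is illustrative.

#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>

void check_write_safety(const at::Tensor& out, const at::Tensor& input) {
  // Throws if `out` aliases itself (e.g. an expanded view)...
  at::assert_no_internal_overlap(out);
  // ...or if `out` and `input` partially share the same storage region.
  at::assert_no_partial_overlap(out, input);
}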
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..2bd95d4fcf9785268519848d69f378f5e4bbdacb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..614fdf1e725bfe64454b8210e88a28c1e96af529
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h
@@ -0,0 +1,324 @@
+#pragma once
+// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/MethodOperators.h b/MLPY/Lib/site-packages/torch/include/ATen/MethodOperators.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9848f67d4b24fcb5d69f0396de5930271b4ac64
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/MethodOperators.h
@@ -0,0 +1,443 @@
+#pragma once
+
+// @generated by torchgen/gen.py from MethodOperators.h
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,             \
+  meaning the file will need to be re-compiled every time an operator      \
+  is changed or added. Consider if your change would be better placed in   \
+  another file, or if a more specific header might achieve the same goal.  \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace _ops {
+
+} // namespace _ops
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NamedTensor.h b/MLPY/Lib/site-packages/torch/include/ATen/NamedTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..b18f8d95b195a19fc5c78cc941b7ce6de28f4534
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NamedTensor.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..930957b44f168dfaa5a20311756b563b88cb2870
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h
@@ -0,0 +1,215 @@
+#pragma once
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
+
+inline bool has_names(const ITensorListRef& tensors) {
+  return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
+    return t.has_names();
+  });
+}
+
+// Converts dim to a positional index. Errors if `dim` cannot be used to
+// refer to any dimension of tensor.
+TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
+TORCH_API std::vector<int64_t> dimnames_to_positions(
+    const Tensor& tensor,
+    DimnameList dims);
+
+// Unifies two DimnameList to produce a third. This is useful for implementing
+// the named inference rule for binary broadcasting operations like add.
+//
+// There are three main constraints:
+// 1) Check matching: Names must match positionally from the right.
+// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
+//    the same index from the right in other.
+// 3) The output names are obtained by unifying the names individually from the
+// right.
+TORCH_API std::vector<Dimname> unify_from_right(
+    DimnameList names,
+    DimnameList other,
+    const char* action = "broadcast");
+
+[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
+  TORCH_CHECK(
+      false,
+      op_name,
+      ": You passed a dimname (string) to this op in place of a dimension "
+      "index but it does not yet support this behavior. Please pass a dimension "
+      "index to work around this.");
+}
+
+// [NOTE] Writing name inference rules
+//
+// Operators that support named tensors are either composed of operations that
+// support named tensors or implement some name inference rule. An op that
+// implements its own name inference rule generally looks like the following:
+//
+// Tensor op(...) {
+//   perform_shape_checks(...);
+//   # (1)
+//   auto maybe_outnames = compute_outnames(...);
+//   auto result = [&]() {
+//     NoNamesGuard guard;
+//     return op_impl(...);
+//   }();
+//   # (2)
+//   propagate_names_if_nonempty(result, maybe_outnames);
+//
+// Each op has (1) a compute outnames step and (2) a propagate names step.
+//
+// compute_outnames is responsible for checking that input names match and
+// determining what the output names should be. It returns either:
+// - {} (if the inputs tensors are all unnamed)
+// - non-empty outnames.
+//
+// propagate_names_if_nonempty propagates the outnames if they exist to the
+// result tensors.
+//
+// The {} case is an optimization; if the user does not use named tensors they
+// pay no perf cost for it.
+
+namespace namedinference {
+
+const Tensor& propagate_names_if_present_and_nonempty(
+    const Tensor& result,
+    c10::optional<DimnameList> maybe_names,
+    bool validate_names = false);
+// Propagates `names` to `result` if `names` is not empty.
+// `names` can be empty; see [NOTE] Writing name inference rules
+// If `names` is not empty, `names.size()` should equal `result.dim()`.
+// When in doubt, use this overload instead of the others.
+TORCH_API const Tensor& propagate_names_if_nonempty(
+    const Tensor& result,
+    DimnameList maybe_names,
+    bool validate_names = false);
+
+// Propagates `names` to `result`. Only use this if we are certain that there
+// are names to propagate (that names is not empty).
+TORCH_API const Tensor& propagate_names(
+    const Tensor& result,
+    DimnameList names,
+    bool validate_names = false);
+
+// Propagates all names from src to result.
+TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
+
+// Propagates all names except for those at the excluded_idxs.
+TORCH_API void propagate_names_except(
+    const Tensor& result,
+    const Tensor& src,
+    IntArrayRef excluded_idxs);
+
+// Used for reduction ops that have a `keepdim` arg.
+TORCH_API void propagate_names_for_reduction(
+    const Tensor& result,
+    const Tensor& src,
+    IntArrayRef excluded_idxs,
+    bool keepdim);
+
+TORCH_API void propagate_names_for_expand(
+    const Tensor& result,
+    const Tensor& self);
+
+TORCH_API std::vector<Dimname> compute_cat_outnames(
+    const MaterializedITensorListRef& tensors);
+
+TORCH_API std::vector<Dimname> compute_broadcast_outnames(
+    const Tensor& self,
+    const Tensor& other);
+
+TORCH_API std::vector<Dimname> broadcast_to_outnames(
+    const Tensor& tensor,
+    const Tensor& reference_tensor,
+    const char* op_name);
+
+TORCH_API std::vector<Dimname> compute_matmul_outnames(
+    const Tensor& self,
+    const Tensor& other);
+
+TORCH_API std::vector<Dimname> compute_cdist_outnames(
+    const Tensor& self,
+    const Tensor& other);
+
+TORCH_API std::vector<Dimname> compute_bmm_outnames(
+    const Tensor& result,
+    const Tensor& self,
+    const Tensor& other);
+
+TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
+TORCH_API std::vector<Dimname> compute_squeeze_outnames(
+    const Tensor& tensor,
+    std::bitset<dim_bitset_size> dims);
+
+std::vector<Dimname> compute_diagonal_outnames(
+    const Tensor& tensor,
+    int64_t dim1,
+    int64_t dim2);
+
+// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
+
+TORCH_API TensorImpl* propagate_names_if_nonempty(
+    TensorImpl* result,
+    DimnameList maybe_names,
+    bool validate_names = false);
+
+TORCH_API TensorImpl* propagate_names(
+    TensorImpl* result,
+    DimnameList names,
+    bool validate_names = false);
+
+TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
+
+TORCH_API inline void propagate_names(
+    const TensorBase& result,
+    DimnameList names,
+    bool validate_names = false) {
+  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
+}
+
+TORCH_API inline void propagate_names_if_nonempty(
+    const TensorBase& result,
+    DimnameList names,
+    bool validate_names = false) {
+  propagate_names_if_nonempty(
+      result.unsafeGetTensorImpl(), names, validate_names);
+}
+
+TORCH_API inline void propagate_names(
+    const TensorBase& result,
+    const TensorBase& src) {
+  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
+}
+
+// result = m1 @ m2 + bias
+TORCH_API std::vector<Dimname> propagate_names_for_addmm(
+    const Tensor& m1,
+    const Tensor& m2,
+    const Tensor& bias);
+
+TORCH_API std::vector<Dimname> propagate_names_for_addmv(
+    const Tensor& mat,
+    const Tensor& vec,
+    const Tensor& bias);
+
+TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
+
+TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
+    const Tensor& result,
+    const Tensor& self,
+    const Tensor& other,
+    const Tensor& bias);
+
+TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
+
+} // namespace namedinference
+
+} // namespace at
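
The name-inference helpers above are C++-internal, but their observable behaviour matches the public named-tensor API. A minimal Python sketch (not part of the diff) illustrating the two cases the header comment describes -- outnames propagated when inputs are named, and the free {} fast path when they are not:

    import torch

    x = torch.randn(3, 4, names=('N', 'C'))
    y = torch.randn(3, 4, names=('N', 'C'))
    print((x + y).names)      # ('N', 'C'): outnames propagated to the result
    print(x.sum('C').names)   # ('N',): reductions drop the reduced dim's name

    u = torch.randn(3, 4)     # unnamed inputs take the {} fast path
    print((u + u).names)      # (None, None): no name checking or propagation
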
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NativeFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/NativeFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee53762fbe68c3754aa9fe321446d24cc38aad84
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NativeFunctions.h
@@ -0,0 +1,1317 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeFunctions.h
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the      \
+  file will need to be re-compiled every time an operator is changed or added.  \
+  Consider including a specific operator from  \
+  and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..58fb5c2eee20f90c01cc6ea4afedf6b14b686dc9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h
@@ -0,0 +1,1303 @@
+#pragma once
+
+// @generated by torchgen/gen.py from NativeMetaFunctions.h
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+namespace meta {
+
+
+
+} // namespace meta
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..af9c50182715f62e4e3991c403a62855e07fab5a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h
@@ -0,0 +1,283 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+struct NestedTensorImpl;
+inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
+int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
+
+struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
+  explicit NestedTensorImpl(
+      Storage storage,
+      c10::DispatchKeySet key_set,
+      const caffe2::TypeMeta data_type,
+      at::Tensor nested_sizes,
+      at::Tensor nested_strides,
+      at::Tensor storage_offsets);
+
+  explicit NestedTensorImpl(
+      const at::Tensor& buffer,
+      at::Tensor nested_sizes,
+      at::Tensor nested_strides,
+      at::Tensor storage_offsets);
+  // assume contiguous, `nested_strides` and `offsets`
+  // can be inferred from `nested_sizes`
+  explicit NestedTensorImpl(
+      const at::Tensor& buffer,
+      const at::Tensor& nested_sizes);
+
+  // This constructor is used when creating view tensors from nested tensors
+  explicit NestedTensorImpl(
+      c10::TensorImpl::ImplType impl_type,
+      const at::Tensor& base_tensor,
+      at::Tensor nested_sizes,
+      at::Tensor nested_strides,
+      at::Tensor storage_offsets);
+
+  // TODO: don't expose private implementation details like this; in
+  // particular, resizing this tensor will mess up our dim() and
+  // callers cannot fix it.
+  const Tensor& get_nested_sizes() const {
+    return nested_sizes_;
+  }
+  // TODO: don't expose private implementation details like this
+  const Tensor& get_nested_strides() const {
+    return nested_strides_;
+  }
+  const Tensor& get_storage_offsets() const {
+    return storage_offsets_;
+  }
+  // Returns nullopt if the ith dimension is irregular. The ith dimension
+  // of a NestedTensor is regular if the unbound tensors match in
+  // size at the (i-1)th dimension.
+  c10::optional<int64_t> opt_size(int64_t d) const;
+
+  int64_t size(int64_t d) const {
+    c10::optional<int64_t> optional_size = this->opt_size(d);
+    TORCH_CHECK(
+        optional_size.has_value(),
+        "Given dimension ",
+        d,
+        " is irregular and does not have a size.");
+    return *optional_size;
+  }
+  /**
+   * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
+   *
+   * The buffer tensor created by this function shares the same storage_impl as
+   * the original nested tensor, and therefore can be seen as a view.
+   *
+   * @return A newly constructed view tensor
+   */
+  at::Tensor get_buffer() const {
+    TORCH_CHECK(
+        nested_tensor_impl_is_contiguous(this),
+        "NestedTensor must be contiguous to get buffer.");
+    return get_unsafe_storage_as_tensor();
+  }
+  /**
+   * If possible use get_buffer() instead. This function returns the storage
+   * as a tensor directly, which is not safe to use in general. If using this
+   * function, the caller must account for nested_sizes,
+   * nested_strides and storage_offsets.
+   *
+   * @return A newly constructed view tensor
+   */
+  at::Tensor get_unsafe_storage_as_tensor() const {
+    auto buffer_key_set_ = generate_buffer_key_set();
+    const auto buffer_size = get_buffer_size();
+    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
+        c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
+    buffer_tensor_impl->set_sizes_contiguous(
+        c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
+    return Tensor(buffer_tensor_impl);
+  }
+
+  size_t get_buffer_size() const {
+    return storage_.nbytes() / data_type_.itemsize();
+  }
+
+ protected:
+  const char* tensorimpl_type_name() const override;
+
+  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
+  // with real implementations
+  int64_t numel_custom() const override;
+  c10::SymInt sym_numel_custom() const override;
+  bool is_contiguous_custom(MemoryFormat) const override;
+  int64_t size_custom(int64_t d) const override {
+    return this->size(d);
+  }
+  c10::SymInt sym_size_custom(int64_t d) const override {
+    return c10::SymInt{this->size(d)};
+  }
+  IntArrayRef sizes_custom() const override;
+  c10::SymIntArrayRef sym_sizes_custom() const override;
+  IntArrayRef strides_custom() const override;
+  c10::SymIntArrayRef sym_strides_custom() const override;
+
+  // this one is real
+  int64_t dim_custom() const override;
+
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const override;
+
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
+    copy_tensor_metadata(
+        /*src_impl=*/impl.get(),
+        /*dest_impl=*/this,
+        /*version_counter=*/version_counter(),
+        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+  }
+
+ private:
+  // Must be called after any changes to our dim() to sync the state
+  // to TensorImpl.
+  void refresh_dim();
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+  const at::Tensor nested_sizes_, nested_strides_;
+  // The starting positions of the underlying tensors in the contiguous buffer,
+  // i.e. the buffer memory offsets needed to reach the underlying tensors.
+  // The reason to keep this metadata is that, without a strong enough
+  // constraint, it cannot be derived from `nested_sizes_`
+  // and `nested_strides_`:
+  // 1. when the buffer has blanks, e.g. [tensor1, blank, tensor2]
+  //    (this can happen e.g. after slicing a nested tensor)
+  // 2. when multiple tensors share the same memory
+  // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
+  // Some strong enough constraints are:
+  // 1. every underlying tensor is contiguous in memory
+  //    && nesting in ascending order
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+  const at::Tensor storage_offsets_;
+  // NOTE: -1 here means the size is missing
+  // Optional to allow it to be computed lazily from nested.
+  // TODO: maybe we can remove this metadata since
+  //       we can compute it from `nested_sizes_`
+  mutable c10::optional<std::vector<int64_t>> opt_sizes_;
+
+  template <typename VariableVersion>
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
+      VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const;
+
+  /**
+   * Generates a non-nested key_set from a nested tensor.
+   *
+   * For many nested tensor kernel implementations, a buffer tensor is
+   * generated and redispatched to a non-nested kernel; this function
+   * generates the key set used by that buffer tensor.
+   *
+   * @return Appropriate key set for non-nested tensor
+   */
+  inline c10::DispatchKeySet generate_buffer_key_set() const {
+    auto buffer_key_set = this->key_set();
+    const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
+    // Remove nested tensor specific keys
+    buffer_key_set = buffer_key_set -
+        c10::DispatchKeySet{
+            c10::DispatchKey::NestedTensor,
+            c10::DispatchKey::AutogradNestedTensor};
+
+    // Add dense tensor specific keys
+    buffer_key_set =
+        buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
+    buffer_key_set = Autograd
+        ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
+        : buffer_key_set;
+
+    return buffer_key_set;
+  }
+};
+
+inline NestedTensorImpl* get_nested_tensor_impl_or_null(
+    const at::Tensor& tensor) {
+  if (tensor.is_nested()) {
+    return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
+  }
+  return nullptr;
+}
+
+inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
+  TORCH_CHECK(
+      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
+  return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
+}
+
+inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
+  int64_t ntensors = nt->size(0);
+  if (ntensors == 0) {
+    return true;
+  }
+  const Tensor &sizemat = nt->get_nested_sizes(),
+               &stridemat = nt->get_nested_strides();
+  int64_t* offsets_ptr = nt->get_storage_offsets().data_ptr<int64_t>();
+  int64_t orig_dim = sizemat.size(1);
+  // nesting scalars
+  if (orig_dim == 0) {
+    // each scalar must be contiguous
+    // if there is blank memory between underlying scalars
+    for (int64_t i = 0; i < ntensors; i++) {
+      if (offsets_ptr[i] != i) {
+        return false;
+      }
+    }
+  }
+  // nesting tensors
+  else {
+    // if any underlying tensor is non-contiguous
+    const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
+                  *stridemat_ptr = stridemat.data_ptr<int64_t>();
+    for (int64_t i = 0; i < ntensors; i++) {
+      if (stridemat_ptr[orig_dim - 1] != 1) {
+        return false;
+      }
+      int64_t product = sizemat_ptr[orig_dim - 1];
+      for (int64_t j = orig_dim - 2; j >= 0; j--) {
+        if (stridemat_ptr[j] != product) {
+          return false;
+        }
+        product *= sizemat_ptr[j];
+      }
+      sizemat_ptr += orig_dim;
+      stridemat_ptr += orig_dim;
+    }
+    // if there is blank memory between underlying tensors
+    if (offsets_ptr[0] != 0) {
+      return false;
+    }
+    sizemat_ptr = sizemat.data_ptr<int64_t>();
+    stridemat_ptr = stridemat.data_ptr<int64_t>();
+    for (int64_t i = 1; i < ntensors; i++) {
+      if (offsets_ptr[i] !=
+          offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
+        return false;
+      }
+      sizemat_ptr += orig_dim;
+      stridemat_ptr += orig_dim;
+    }
+  }
+  // everything is fine
+  return true;
+}
+
+inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
+  return get_nested_tensor_impl(tensor)->get_nested_sizes();
+}
+
+} // namespace at::native
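
As a rough illustration of the opt_size()/size() contract documented above, the public nested-tensor API surfaces the same regular-vs-irregular distinction. A sketch (not part of the diff), assuming the torch.nested prototype API is available in this build:

    import torch

    nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
    print(nt.size(0))   # 2: number of nested components
    print(nt.size(2))   # 5: regular dimension, every component agrees
    # nt.size(1) raises a RuntimeError: dimension 1 is irregular (2 vs. 3),
    # mirroring opt_size() returning nullopt and size() throwing above.
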
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/NumericUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/NumericUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ccef4482d530839205e4ceec8b0d69c9e1565a15
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/NumericUtils.h
@@ -0,0 +1,203 @@
+#pragma once
+
+#ifdef __HIPCC__
+#include 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at {
+
+// std::isnan isn't performant to use on integral types; it will
+// (uselessly) convert to floating point and then do the test.
+// This function is.
+
+template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
+  return false;
+}
+
+template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  return ::isnan(val);
+#else
+  return std::isnan(val);
+#endif
+}
+
+template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return std::isnan(val.real()) || std::isnan(val.imag());
+}
+
+template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return at::_isnan(static_cast<float>(val));
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
+  return at::_isnan(static_cast<float>(val));
+}
+
+inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
+  return at::_isnan(static_cast<float>(val));
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+template <
+    typename T,
+    std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+// std::isinf isn't performant to use on integral types; it will
+// (uselessly) convert to floating point and then do the test.
+// This function is.
+
+template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
+inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
+  return false;
+}
+
+template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
+inline C10_HOST_DEVICE bool _isinf(T val) {
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  return ::isinf(val);
+#else
+  return std::isinf(val);
+#endif
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Half val) {
+  return at::_isinf(static_cast<float>(val));
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
+  return at::_isinf(static_cast<float>(val));
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
+  return val.isinf();
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
+  return false;
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
+  return false;
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
+  return false;
+}
+
+template <typename T>
+C10_HOST_DEVICE inline T exp(T x) {
+  static_assert(
+      !std::is_same_v<T, double>,
+      "this template must be used with float or less precise type");
+#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
+  // use __expf fast approximation for peak bandwidth
+  return __expf(x);
+#else
+  return ::exp(x);
+#endif
+}
+
+template <>
+C10_HOST_DEVICE inline double exp(double x) {
+  return ::exp(x);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline T log(T x) {
+  static_assert(
+      !std::is_same_v<T, double>,
+      "this template must be used with float or less precise type");
+#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
+  // use __logf fast approximation for peak bandwidth
+  return __logf(x);
+#else
+  return ::log(x);
+#endif
+}
+
+template <>
+C10_HOST_DEVICE inline double log(double x) {
+  return ::log(x);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline T log1p(T x) {
+  static_assert(
+      !std::is_same_v<T, double>,
+      "this template must be used with float or less precise type");
+#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
+  // use __logf fast approximation for peak bandwidth
+  // NOTE: There is no __log1pf so unfortunately we lose precision.
+  return __logf(1.0f + x);
+#else
+  return ::log1p(x);
+#endif
+}
+
+template <>
+C10_HOST_DEVICE inline double log1p(double x) {
+  return ::log1p(x);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline T tan(T x) {
+  static_assert(
+      !std::is_same_v<T, double>,
+      "this template must be used with float or less precise type");
+#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
+  // use __tanf fast approximation for peak bandwidth
+  return __tanf(x);
+#else
+  return ::tan(x);
+#endif
+}
+
+template <>
+C10_HOST_DEVICE inline double tan(double x) {
+  return ::tan(x);
+}
+
+} // namespace at
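
The integral overloads above exist so callers can ask "is this NaN/Inf?" generically without paying for a float conversion. The same semantics are visible from Python (a small sketch, not part of the diff):

    import torch

    ints = torch.arange(5)                 # integral dtype can never hold NaN/Inf
    print(torch.isnan(ints).any().item())  # False
    floats = torch.tensor([0.0, float('inf'), float('nan')])
    print(torch.isinf(floats))             # tensor([False,  True, False])
    print(torch.isnan(floats))             # tensor([False, False,  True])
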
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/OpMathType.h b/MLPY/Lib/site-packages/torch/include/ATen/OpMathType.h
new file mode 100644
index 0000000000000000000000000000000000000000..64b1364a8bb72db5916c331835aa76bfd96e7995
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/OpMathType.h
@@ -0,0 +1,69 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// For FP16 or BFloat16 inputs, ops should perform internal math in FP32.
+template <typename scalar_t>
+struct OpMathType {
+  using type = scalar_t;
+};
+template <>
+struct OpMathType<at::Half> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::BFloat16> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e5m2> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e4m3fn> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e5m2fnuz> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e4m3fnuz> {
+  using type = float;
+};
+template <>
+struct OpMathType<c10::complex<Half>> {
+  using type = c10::complex<float>;
+};
+
+template <typename scalar_t>
+using opmath_type = typename OpMathType<scalar_t>::type;
+
+namespace {
+
+inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
+  switch (type) {
+#define DEFINE_CASE(scalar_t, TypeNum) \
+  case ScalarType::TypeNum:            \
+    return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;
+
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
+#undef DEFINE_CASE
+
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
+  }
+}
+
+} // namespace
+
+} // namespace at
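
One user-visible consequence of opmath_type: on builds where reduction kernels accumulate in the FP32 opmath type, summing many small FP16 values stays accurate even though inputs and output are half precision. A hedged sketch (exact numerics depend on the kernel and device):

    import torch

    x = torch.full((10_000,), 0.01, dtype=torch.float16)
    # Accumulation happens in float32 internally, so the result is ~100.0
    # rather than stalling once the running sum dwarfs each fp16 addend.
    print(x.sum())
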
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..23805376faad47f149d4bf57823ad1a473ed16ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h
@@ -0,0 +1,187 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// An "Opaque" TensorImpl -- there are no strides and (for now)
+// even data() is not supported (thus no pointer arithmetic).
+
+// NOTE: We could allow data() in the future, but would have to ensure pointer
+// arithmetic code is properly guarded.
+//
+// NOTE: This does not support resize_ (and other metadata-changing ops) because
+// of `shallow_copy_and_detach`. We would need to define an interface to
+// "shallow copy" in order to add support.
+
+template <typename OpaqueHandle>
+struct TORCH_API OpaqueTensorImpl : public TensorImpl {
+  // public constructor for now...
+  OpaqueTensorImpl(
+      at::DispatchKeySet key_set,
+      const caffe2::TypeMeta data_type,
+      c10::Device device,
+      OpaqueHandle opaque_handle,
+      c10::IntArrayRef sizes,
+      bool is_non_overlapping_and_dense = true)
+      : TensorImpl(key_set, data_type, device),
+        opaque_handle_(std::move(opaque_handle)) {
+    set_storage_access_should_throw();
+    set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
+    sizes_and_strides_.set_sizes(sizes);
+    refresh_numel();
+    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
+    is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
+  }
+
+  // Destructor doesn't call release_resources because it's
+  // unnecessary; don't forget to change that if needed!
+  void release_resources() override {
+    TensorImpl::release_resources();
+    opaque_handle_ = {};
+  }
+
+  void set_size(int64_t dim, int64_t new_size) override {
+    AT_ERROR("opaque tensors do not have set_size");
+  }
+
+  void set_stride(int64_t dim, int64_t new_stride) override {
+    AT_ERROR("opaque tensors do not have set_stride");
+  }
+
+  void set_storage_offset(int64_t storage_offset) override {
+    AT_ERROR("opaque tensors do not have set_storage_offset");
+  }
+
+#ifdef DEBUG
+  bool has_storage() const override {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
+    return false;
+  }
+#endif
+
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) const override {
+    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
+        key_set(),
+        dtype(),
+        device(),
+        opaque_handle_,
+        sizes_and_strides_.sizes_arrayref());
+    copy_tensor_metadata(
+        /*src_opaque_impl=*/this,
+        /*dest_opaque_impl=*/impl.get(),
+        /*version_counter=*/version_counter,
+        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
+    impl->refresh_numel();
+    return impl;
+  }
+
+  /**
+   * Return a TensorImpl that is a shallow-copy of this TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`,
+   * see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override {
+    auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
+        key_set(),
+        dtype(),
+        device(),
+        opaque_handle_,
+        sizes_and_strides_.sizes_arrayref());
+    copy_tensor_metadata(
+        /*src_opaque_impl=*/this,
+        /*dest_opaque_impl=*/impl.get(),
+        /*version_counter=*/std::move(version_counter),
+        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
+    impl->refresh_numel();
+    return impl;
+  }
+
+  /**
+   * Shallow-copies data from another TensorImpl into this TensorImpl.
+   *
+   * For why this function doesn't check this TensorImpl's
+   * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
+   */
+  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
+    AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
+    auto opaque_impl =
+        static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
+    copy_tensor_metadata(
+        /*src_impl=*/opaque_impl,
+        /*dest_impl=*/this,
+        /*version_counter=*/version_counter(),
+        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
+    refresh_numel();
+  }
+
+  const OpaqueHandle& opaque_handle() const {
+    return opaque_handle_;
+  }
+
+  OpaqueHandle& unsafe_opaque_handle() {
+    return opaque_handle_;
+  }
+
+ protected:
+  /**
+   * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
+   * storage_offset) from one TensorImpl to another TensorImpl.
+   *
+   * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
+   * [ TensorImpl Shallow-Copying ].
+   */
+  static void copy_tensor_metadata(
+      const OpaqueTensorImpl* src_opaque_impl,
+      OpaqueTensorImpl* dest_opaque_impl,
+      const c10::VariableVersion& version_counter,
+      bool allow_tensor_metadata_change) {
+    TensorImpl::copy_tensor_metadata(
+        src_opaque_impl,
+        dest_opaque_impl,
+        version_counter,
+        allow_tensor_metadata_change);
+
+    // OpaqueTensorImpl-specific fields.
+    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
+  }
+
+  static void copy_tensor_metadata(
+      const OpaqueTensorImpl* src_opaque_impl,
+      OpaqueTensorImpl* dest_opaque_impl,
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) {
+    TensorImpl::copy_tensor_metadata(
+        src_opaque_impl,
+        dest_opaque_impl,
+        std::move(version_counter),
+        allow_tensor_metadata_change);
+
+    // OpaqueTensorImpl-specific fields.
+    dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
+  }
+
+ private:
+  const char* tensorimpl_type_name() const override {
+    return "OpaqueTensorImpl";
+  }
+
+  OpaqueHandle opaque_handle_;
+};
+
+} // namespace at
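
The shallow-copy-and-detach comments above describe an impl that shares the underlying handle while taking on its own metadata and version counter. A rough Python-level analogue for ordinary dense tensors (the mapping to detach() is an approximation, not a claim about this exact code path):

    import torch

    a = torch.randn(3, requires_grad=True)
    b = a.detach()                       # shallow copy of the TensorImpl metadata
    print(b.data_ptr() == a.data_ptr())  # True: the storage is shared
    print(b.requires_grad)               # False: detached autograd identity
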
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Operators.h b/MLPY/Lib/site-packages/torch/include/ATen/Operators.h
new file mode 100644
index 0000000000000000000000000000000000000000..5e3118f98030fb8d660e55931b771b794e82244d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Operators.h
@@ -0,0 +1,1358 @@
+#pragma once
+
+// @generated by torchgen/gen.py from Operators.h
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,             \
+  meaning the file will need to be re-compiled every time an operator      \
+  is changed or added. Consider if your change would be better placed in   \
+  another file, or if a more specific header might achieve the same goal.  \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from    \
+  and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// Extension writers: do you write wrapper functions? Are you frustrated with
+// resolving overloads of operators? Are you frustrated with dealing with
+// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
+// further, this is the utility for you.
+//
+// Given an operator schema: aten::op.overload(...
+//
+// Use ATEN_FN2(op, overload) to get a *function* version of the operator
+// that is guaranteed to not be overloaded. This means that you can safely
+// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
+//
+// Given an operator schema without an overload name: aten::op(...
+//
+// Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
+//
+// There is some interesting behavior for out= operations.
+// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
+// that is, the order of arguments is exactly what it looks like in the schema.
+
+#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call
+#define ATEN_FN(op_name) at::_ops::op_name::call
+
+// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time
+// metadata about a given aten operator.
+// Notable data on the class includes:
+// - ATEN_OP2(add, Tensor)::name // returns the string name: "add"
+// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor"
+// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
+// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+
+#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload
+#define ATEN_OP(op_name) at::_ops::op_name
+
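+// Illustrative usage sketch (editor's note, not part of the generated header;
+// `a` and `b` are hypothetical at::Tensor inputs). Assuming the aten::add.Tensor
+// overload used as the example in the comments above:
+//
+//   auto add_fn = ATEN_FN2(add, Tensor);   // decays to a plain function pointer,
+//                                          // so decltype(&ATEN_FN2(add, Tensor)) works
+//   at::Tensor out = add_fn(a, b, /*alpha=*/1);
+//
+//   // Compile-time metadata for the same overload (see the ATEN_OP2 list above):
+//   //   ATEN_OP2(add, Tensor)::name, ::overload_name, ::schema, ::schema_str
+//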
+// WARNING: Please do not call any of the ops in the _ops namespace directly.
+// Use the ATEN_FN macros. We do not guarantee stability of the naming
+// scheme for the functions in at::_ops
+
+// See Note [The ATen Operators API] for details of the at::_ops namespace
+
+namespace at {
+namespace _ops {
+
+} // namespace _ops
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/PTThreadPool.h b/MLPY/Lib/site-packages/torch/include/ATen/PTThreadPool.h
new file mode 100644
index 0000000000000000000000000000000000000000..d18d80161296db96fc6cc0c89ba4546490b6e5a4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/PTThreadPool.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+class TORCH_API PTThreadPool : public c10::ThreadPool {
+ public:
+  explicit PTThreadPool(int pool_size, int numa_node_id = -1)
+      : c10::ThreadPool(pool_size, numa_node_id, []() {
+          c10::setThreadName("PTThreadPool");
+          at::init_num_threads();
+        }) {}
+};
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/PadNd.h b/MLPY/Lib/site-packages/torch/include/ATen/PadNd.h
new file mode 100644
index 0000000000000000000000000000000000000000..612631ec6bc042ff7b02955620981e107a2fa8fe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/PadNd.h
@@ -0,0 +1,28 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+
+enum class padding_mode {
+  reflect,
+  replicate,
+  circular,
+  constant,
+};
+
+static inline c10::string_view padding_mode_string(padding_mode m) {
+  switch (m) {
+    case padding_mode::reflect:
+      return "reflect";
+    case padding_mode::replicate:
+      return "replicate";
+    case padding_mode::circular:
+      return "circular";
+    case padding_mode::constant:
+      return "constant";
+  }
+  TORCH_CHECK(false, "Invalid padding mode (", static_cast(m), ")");
+}
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Parallel-inl.h b/MLPY/Lib/site-packages/torch/include/ATen/Parallel-inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..966aa4b6371df7442cade150cae890bd772e4491
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Parallel-inl.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+template <class F>
+inline void parallel_for(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const F& f) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
+  if (begin >= end) {
+    return;
+  }
+
+#ifdef INTRA_OP_PARALLEL
+  at::internal::lazy_init_num_threads();
+  const auto numiter = end - begin;
+  const bool use_parallel =
+      (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
+       at::get_num_threads() > 1);
+  if (!use_parallel) {
+    internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
+    f(begin, end);
+    return;
+  }
+
+  internal::invoke_parallel(
+      begin, end, grain_size, [&](int64_t begin, int64_t end) {
+        c10::ParallelGuard guard(true);
+        f(begin, end);
+      });
+#else
+  internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
+  f(begin, end);
+#endif
+}
+
+template <class scalar_t, class F, class SF>
+inline scalar_t parallel_reduce(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const scalar_t ident,
+    const F& f,
+    const SF& sf) {
+  TORCH_CHECK(grain_size >= 0);
+  if (begin >= end) {
+    return ident;
+  }
+
+#ifdef INTRA_OP_PARALLEL
+  at::internal::lazy_init_num_threads();
+  const auto max_threads = at::get_num_threads();
+  const bool use_parallel =
+      ((end - begin) > grain_size && !at::in_parallel_region() &&
+       max_threads > 1);
+  if (!use_parallel) {
+    internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
+    return f(begin, end, ident);
+  }
+
+  c10::SmallVector<scalar_t, 64> results(max_threads, ident);
+  internal::invoke_parallel(
+      begin,
+      end,
+      grain_size,
+      [&](const int64_t my_begin, const int64_t my_end) {
+        const auto tid = at::get_thread_num();
+        c10::ParallelGuard guard(true);
+        results[tid] = f(my_begin, my_end, ident);
+      });
+
+  scalar_t result = ident;
+  for (auto partial_result : results) {
+    result = sf(result, partial_result);
+  }
+  return result;
+#else
+  internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
+  return f(begin, end, ident);
+#endif
+}
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Parallel.h b/MLPY/Lib/site-packages/torch/include/ATen/Parallel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7261fed38968b84d9dfca9f63c30af742cdec4e6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Parallel.h
@@ -0,0 +1,160 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+inline int64_t divup(int64_t x, int64_t y) {
+  return (x + y - 1) / y;
+}
+
+// Called during new thread initialization
+TORCH_API void init_num_threads();
+
+// Sets the number of threads to be used in parallel region
+TORCH_API void set_num_threads(int);
+
+// Returns the maximum number of threads that may be used in a parallel region
+TORCH_API int get_num_threads();
+
+// Returns the current thread number (starting from 0)
+// in the current parallel region, or 0 in the sequential region
+TORCH_API int get_thread_num();
+
+// Checks whether the code runs in parallel region
+TORCH_API bool in_parallel_region();
+
+namespace internal {
+
+// Initialise num_threads lazily at first parallel call
+inline void lazy_init_num_threads() {
+  thread_local bool init = false;
+  if (C10_UNLIKELY(!init)) {
+    at::init_num_threads();
+    init = true;
+  }
+}
+
+TORCH_API void set_thread_num(int);
+
+class TORCH_API ThreadIdGuard {
+ public:
+  ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
+    set_thread_num(new_id);
+  }
+
+  ~ThreadIdGuard() {
+    set_thread_num(old_id_);
+  }
+
+ private:
+  int old_id_;
+};
+
+} // namespace internal
+
+/*
+parallel_for
+
+begin: index at which to start applying user function
+
+end: index at which to stop applying user function
+
+grain_size: number of elements per chunk. impacts the degree of parallelization
+
+f: user function applied in parallel to the chunks, signature:
+  void f(int64_t begin, int64_t end)
+
+Warning: parallel_for does NOT copy thread local
+states from the current thread to the worker threads.
+This means for example that Tensor operations CANNOT be used in the
+body of your function, only data pointers.
+*/
+template <class F>
+inline void parallel_for(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const F& f);
+
+/*
+parallel_reduce
+
+begin: index at which to start applying reduction
+
+end: index at which to stop applying reduction
+
+grain_size: number of elements per chunk. impacts number of elements in
+intermediate results tensor and degree of parallelization.
+
+ident: identity for binary combination function sf. sf(ident, x) needs to return
+x.
+
+f: function for reduction over a chunk. f needs to be of signature scalar_t
+f(int64_t partial_begin, int64_t partial_end, scalar_t identity)
+
+sf: function to combine two partial results. sf needs to be of signature
+scalar_t sf(scalar_t x, scalar_t y)
+
+For example, you might have a tensor of 10000 entries and want to sum together
+all the elements. parallel_reduce with a grain_size of 2500 will then allocate
+an intermediate result tensor with 4 elements. Then it will execute the function
+"f" you provide and pass the beginning and end index of these chunks, so
+0-2499, 2500-4999, etc. and the combination identity. It will then write out
+the result from each of these chunks into the intermediate result tensor. After
+that it'll reduce the partial results from each chunk into a single number using
+the combination function sf and the identity ident. For a total summation this
+would be "+" and 0 respectively. This is similar to tbb's approach [1], where
+you need to provide a function to accumulate a subrange, a function to combine
+two partial results and an identity.
+
+Warning: parallel_reduce does NOT copy thread local
+states from the current thread to the worker threads.
+This means for example that Tensor operations CANNOT be used in the
+body of your function, only data pointers.
+
+[1] https://software.intel.com/en-us/node/506154
+*/
+template <class scalar_t, class F, class SF>
+inline scalar_t parallel_reduce(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const scalar_t ident,
+    const F& f,
+    const SF& sf);
+
+// Returns a detailed string describing parallelization settings
+TORCH_API std::string get_parallel_info();
+
+// Sets number of threads used for inter-op parallelism
+TORCH_API void set_num_interop_threads(int);
+
+// Returns the number of threads used for inter-op parallelism
+TORCH_API int get_num_interop_threads();
+
+// Launches inter-op parallel task
+TORCH_API void launch(std::function<void()> func);
+namespace internal {
+void launch_no_thread_state(std::function<void()> fn);
+} // namespace internal
+
+// Launches intra-op parallel task
+TORCH_API void intraop_launch(std::function<void()> func);
+
+// Returns number of intra-op threads used by default
+TORCH_API int intraop_default_num_threads();
+
+} // namespace at
+
+#if AT_PARALLEL_OPENMP
+#include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
+#elif AT_PARALLEL_NATIVE
+#include <ATen/ParallelNative.h> // IWYU pragma: keep
+#elif AT_PARALLEL_NATIVE_TBB
+#include <ATen/ParallelNativeTBB.h> // IWYU pragma: keep
+#endif
+
+#include <ATen/Parallel-inl.h> // IWYU pragma: keep
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ParallelFuture.h b/MLPY/Lib/site-packages/torch/include/ATen/ParallelFuture.h
new file mode 100644
index 0000000000000000000000000000000000000000..f05e79b333c8dd3a7dfb7874c102f534ba354a2c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ParallelFuture.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+// Launches intra-op parallel task, returns a future
+TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
+    std::function<void()> func);
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ParallelNative.h b/MLPY/Lib/site-packages/torch/include/ATen/ParallelNative.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd572a697eb4d7fb0f7bce45b4b887713c6e5534
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ParallelNative.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+#define INTRA_OP_PARALLEL
+
+namespace at::internal {
+
+TORCH_API void invoke_parallel(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const std::function& f);
+
+} // namespace at::internal
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ParallelNativeTBB.h b/MLPY/Lib/site-packages/torch/include/ATen/ParallelNativeTBB.h
new file mode 100644
index 0000000000000000000000000000000000000000..0378a733a6a2762838c76cc191e65996beb20747
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ParallelNativeTBB.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+#ifdef _WIN32
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#endif
+#include <tbb/tbb.h>
+
+#define INTRA_OP_PARALLEL
+
+namespace at::internal {
+
+template <typename F>
+inline void invoke_parallel(
+    const int64_t begin,
+    const int64_t end,
+    const int64_t grain_size,
+    const F& f) {
+  // Choose number of tasks based on grain size and number of threads.
+  int64_t chunk_size = divup((end - begin), get_num_threads());
+  // Make sure each task is at least grain_size size.
+  chunk_size = std::max(grain_size, chunk_size);
+
+  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
+  std::exception_ptr eptr;
+  tbb::parallel_for(
+      tbb::blocked_range<int64_t>(begin, end, chunk_size),
+      [&eptr, &err_flag, f](const tbb::blocked_range<int64_t>& r) {
+        try {
+          internal::ThreadIdGuard tid_guard(
+              tbb::this_task_arena::current_thread_index());
+          f(r.begin(), r.end());
+        } catch (...) {
+          if (!err_flag.test_and_set()) {
+            eptr = std::current_exception();
+          }
+        }
+      },
+      tbb::static_partitioner{});
+  if (eptr) {
+    std::rethrow_exception(eptr);
+  }
+}
+
+} // namespace at::internal
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h b/MLPY/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h
new file mode 100644
index 0000000000000000000000000000000000000000..40a8830c764543d90f8b0180fa3a91039a537d38
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#ifdef _OPENMP
+#define INTRA_OP_PARALLEL
+
+#include <omp.h>
+#endif
+
+#ifdef _OPENMP
+namespace at::internal {
+template <typename F>
+inline void invoke_parallel(
+    int64_t begin,
+    int64_t end,
+    int64_t grain_size,
+    const F& f) {
+  std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
+  std::exception_ptr eptr;
+
+#pragma omp parallel
+  {
+    // choose number of tasks based on grain size and number of threads
+    // can't use num_threads clause due to bugs in GOMP's thread pool (See
+    // #32008)
+    int64_t num_threads = omp_get_num_threads();
+    if (grain_size > 0) {
+      num_threads = std::min(num_threads, divup((end - begin), grain_size));
+    }
+
+    int64_t tid = omp_get_thread_num();
+    int64_t chunk_size = divup((end - begin), num_threads);
+    int64_t begin_tid = begin + tid * chunk_size;
+    if (begin_tid < end) {
+      try {
+        internal::ThreadIdGuard tid_guard(tid);
+        f(begin_tid, std::min(end, chunk_size + begin_tid));
+      } catch (...) {
+        if (!err_flag.test_and_set()) {
+          eptr = std::current_exception();
+        }
+      }
+    }
+  }
+  if (eptr) {
+    std::rethrow_exception(eptr);
+  }
+}
+} // namespace at::internal
+#endif // _OPENMP
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h b/MLPY/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h
new file mode 100644
index 0000000000000000000000000000000000000000..d7ca10fd8895bf40842b6d3d6c7adff367d49fc7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::impl {
+
+enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
+
+struct TORCH_API PythonTorchFunctionTLS {
+  static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
+  static TorchFunctionDisabledState get_disabled_state();
+
+  static void push_onto_stack(std::shared_ptr<c10::SafePyObject> mode);
+  static const std::shared_ptr<c10::SafePyObject> pop_stack();
+  static const std::shared_ptr<c10::SafePyObject>& get_stack_at(int64_t idx);
+  static int64_t stack_len();
+
+  static const PythonTorchFunctionTLS& get_state();
+  static void set_state(const PythonTorchFunctionTLS& state);
+
+ private:
+  // The mode TLS is split into
+  //   - disabled_state, which says which parts of torch function are disabled
+  //   - stack_, which is a vector of modes representing the stack of user
+  //   defined modes
+  TorchFunctionDisabledState disabled_state_ =
+      TorchFunctionDisabledState::ENABLED;
+  std::vector> stack_;
+};
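+// Illustrative sketch (editor's note, not part of the original header): saving
+// and restoring the whole torch-function TLS around a scoped change, using only
+// the static accessors declared above.
+//
+//   const auto saved = at::impl::PythonTorchFunctionTLS::get_state();
+//   at::impl::PythonTorchFunctionTLS::set_disabled_state(
+//       at::impl::TorchFunctionDisabledState::ALL_DISABLED);
+//   // ... run code with __torch_function__ handling fully disabled ...
+//   at::impl::PythonTorchFunctionTLS::set_state(saved);
+//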
+
+TORCH_API bool torch_function_mode_enabled();
+
+} // namespace at::impl
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..d9d205bb300f8fb8c6fe5a11b29549a2912c46d5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h
@@ -0,0 +1,24791 @@
+#pragma once
+
+// @generated by torchgen/gen.py from RedispatchFunctions.h
+
+#ifdef TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider using the at::_ops::{name}::redispatch() interface by including     \
+  the specific operator from 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+namespace redispatch {
+    
+    // aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Byte(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Byte::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Char(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Char::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Double(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Double::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Float(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Float::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Int(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Int::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Long(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Long::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Short(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Short::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor
+    inline at::Tensor _cast_Half(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) {
+        return at::_ops::_cast_Half::redispatch(dispatchKeySet, self, non_blocking);
+    }
+    
+    // aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+    inline void __dispatch__backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList inputs, const c10::optional<at::Tensor> & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false) {
+        return at::_ops::_backward::redispatch(dispatchKeySet, self, inputs, gradient, retain_graph, create_graph);
+    }
+    
+    // aten::set_data(Tensor(a!) self, Tensor new_data) -> ()
+    inline void __dispatch_set_data(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & new_data) {
+        return at::_ops::set_data::redispatch(dispatchKeySet, self, new_data);
+    }
+    
+    // aten::data(Tensor self) -> Tensor
+    inline at::Tensor __dispatch_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::data::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_leaf(Tensor self) -> bool
+    inline bool __dispatch_is_leaf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_leaf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::output_nr(Tensor self) -> int
+    inline int64_t __dispatch_output_nr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::output_nr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_version(Tensor self) -> int
+    inline int64_t __dispatch__version(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_version::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)
+    inline at::Tensor & __dispatch_requires_grad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool requires_grad=true) {
+        return at::_ops::requires_grad_::redispatch(dispatchKeySet, self, requires_grad);
+    }
+    
+    // aten::retain_grad(Tensor(a!) self) -> ()
+    inline void __dispatch_retain_grad(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::retain_grad::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::retains_grad(Tensor self) -> bool
+    inline bool __dispatch_retains_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::retains_grad::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)
+    inline at::Tensor _fw_primal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) {
+        return at::_ops::_fw_primal::redispatch(dispatchKeySet, self, level);
+    }
+    
+    // aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
+    inline at::Tensor _make_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) {
+        return at::_ops::_make_dual::redispatch(dispatchKeySet, primal, tangent, level);
+    }
+    
+    // aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
+    inline ::std::tuple<at::Tensor,at::Tensor> _unpack_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dual, int64_t level) {
+        return at::_ops::_unpack_dual::redispatch(dispatchKeySet, dual, level);
+    }
+    
+    // aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor
+    inline at::Tensor _new_zeros_with_same_feature_meta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) {
+        return at::_ops::_new_zeros_with_same_feature_meta::redispatch(dispatchKeySet, self, other, self_num_batch_dims);
+    }
+    
+    // aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool
+    inline bool _has_same_storage_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::_has_same_storage_numel::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
+    inline at::Tensor & rename_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::optional names) {
+        return at::_ops::rename_::redispatch(dispatchKeySet, self, names);
+    }
+    
+    // aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
+    inline at::Tensor rename(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::DimnameList> names) {
+        return at::_ops::rename::redispatch(dispatchKeySet, self, names);
+    }
+    
+    // aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
+    inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) {
+        return at::_ops::align_to::redispatch(dispatchKeySet, self, names);
+    }
+    
+    // aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
+    inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx) {
+        return at::_ops::align_to_ellipsis_idx::redispatch(dispatchKeySet, self, order, ellipsis_idx);
+    }
+    
+    // aten::align_as(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor align_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::align_as::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::align_tensors(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> align_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::align_tensors::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::_assert_async(Tensor self) -> ()
+    inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_assert_async::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_assert_async.msg(Tensor self, str assert_msg) -> ()
+    inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg) {
+        return at::_ops::_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg);
+    }
+    
+    // aten::_assert_scalar(Scalar self, str assert_msg) -> ()
+    inline void _assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg) {
+        return at::_ops::_assert_scalar::redispatch(dispatchKeySet, self, assert_msg);
+    }
+    
+    // aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
+    inline at::Tensor _functional_assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token) {
+        return at::_ops::_functional_assert_scalar::redispatch(dispatchKeySet, self, assert_msg, dep_token);
+    }
+    
+    // aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
+    inline at::Tensor _functional_assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token) {
+        return at::_ops::_functional_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg, dep_token);
+    }
+    
+    // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
+    inline void _assert_tensor_metadata(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalIntArrayRef size=c10::nullopt, at::OptionalIntArrayRef stride=c10::nullopt, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*size)) : c10::nullopt, stride.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*stride)) : c10::nullopt, dtype);
+    }
+    
+    // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
+    inline void _assert_tensor_metadata_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalSymIntArrayRef size=c10::nullopt, at::OptionalSymIntArrayRef stride=c10::nullopt, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size, stride, dtype);
+    }
+    
+    // aten::_print(str s) -> ()
+    inline void _print(c10::DispatchKeySet dispatchKeySet, c10::string_view s) {
+        return at::_ops::_print::redispatch(dispatchKeySet, s);
+    }
+    
+    // aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
+    inline void sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, c10::optional<int64_t> min=c10::nullopt, c10::optional<int64_t> max=c10::nullopt) {
+        return at::_ops::sym_constrain_range::redispatch(dispatchKeySet, size, min, max);
+    }
+    
+    // aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()
+    inline void sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, c10::optional<int64_t> min=c10::nullopt, c10::optional<int64_t> max=c10::nullopt) {
+        return at::_ops::sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max);
+    }
+    
+    // aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+    inline at::Tensor _functional_sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max, const at::Tensor & dep_token) {
+        return at::_ops::_functional_sym_constrain_range::redispatch(dispatchKeySet, size, min, max, dep_token);
+    }
+    
+    // aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+    inline at::Tensor _functional_sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max, const at::Tensor & dep_token) {
+        return at::_ops::_functional_sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max, dep_token);
+    }
+    
+    // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::_make_dep_token::redispatch(dispatchKeySet, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_make_dep_token::redispatch(dispatchKeySet, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
+    inline at::Tensor refine_names(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) {
+        return at::_ops::refine_names::redispatch(dispatchKeySet, self, names);
+    }
+    
+    // aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool
+    inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank) {
+        return at::_ops::_use_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank);
+    }
+    
+    // aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool
+    inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank) {
+        return at::_ops::_use_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank);
+    }
+    
+    // aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) {
+        return at::_ops::_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity);
+    }
+    
+    // aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity) {
+        return at::_ops::_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity);
+    }
+    
+    // aten::_use_cudnn_rnn_flatten_weight() -> bool
+    inline bool _use_cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet) {
+        return at::_ops::_use_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor
+    inline at::Tensor _cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) {
+        return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor
+    inline at::Tensor _cudnn_rnn_flatten_weight_symint(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) {
+        return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional);
+    }
+    
+    // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _cudnn_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state);
+    }
+    
+    // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _cudnn_rnn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state);
+    }
+    
+    // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+    inline ::std::tuple> _cudnn_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional & cx, const at::Tensor & output, const c10::optional & grad_output, const c10::optional & grad_hy, const c10::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) {
+        return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask);
+    }
+    
+    // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+    inline ::std::tuple> _cudnn_rnn_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional & cx, const at::Tensor & output, const c10::optional & grad_output, const c10::optional & grad_hy, const c10::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) {
+        return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask);
+    }
+    
+    // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::TensorOptions options) {
+        return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_debug_has_internal_overlap(Tensor self) -> int
+    inline int64_t _debug_has_internal_overlap(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_debug_has_internal_overlap::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _fused_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_fused_dropout::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor
+    inline at::Tensor _masked_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale) {
+        return at::_ops::_masked_scale::redispatch(dispatchKeySet, self, mask, scale);
+    }
+    
+    // aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> native_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, c10::optional<bool> train) {
+        return at::_ops::native_dropout::redispatch(dispatchKeySet, input, p, train);
+    }
+    
+    // aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
+    inline at::Tensor native_dropout_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale) {
+        return at::_ops::native_dropout_backward::redispatch(dispatchKeySet, grad_output, mask, scale);
+    }
+    
+    // aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _sobol_engine_draw(c10::DispatchKeySet dispatchKeySet, const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, c10::optional<at::ScalarType> dtype) {
+        return at::_ops::_sobol_engine_draw::redispatch(dispatchKeySet, quasi, n, sobolstate, dimension, num_generated, dtype);
+    }
+    
+    // aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)
+    inline at::Tensor & _sobol_engine_ff_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated) {
+        return at::_ops::_sobol_engine_ff_::redispatch(dispatchKeySet, self, n, sobolstate, dimension, num_generated);
+    }
+    
+    // aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)
+    inline at::Tensor & _sobol_engine_scramble_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & ltm, int64_t dimension) {
+        return at::_ops::_sobol_engine_scramble_::redispatch(dispatchKeySet, self, ltm, dimension);
+    }
+    
+    // aten::_sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!)
+    inline at::Tensor & _sobol_engine_initialize_state_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dimension) {
+        return at::_ops::_sobol_engine_initialize_state_::redispatch(dispatchKeySet, self, dimension);
+    }
+    
+    // aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor
+    inline at::Tensor _reshape_from_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & shape) {
+        return at::_ops::_reshape_from_tensor::redispatch(dispatchKeySet, self, shape);
+    }
+    
+    // aten::_shape_as_tensor(Tensor self) -> Tensor
+    inline at::Tensor _shape_as_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_shape_as_tensor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::dropout(Tensor input, float p, bool train) -> Tensor
+    inline at::Tensor dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) {
+        return at::_ops::dropout::redispatch(dispatchKeySet, input, p, train);
+    }
+    
+    // aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+    inline at::Tensor & dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) {
+        return at::_ops::dropout_::redispatch(dispatchKeySet, self, p, train);
+    }
+    
+    // aten::feature_dropout(Tensor input, float p, bool train) -> Tensor
+    inline at::Tensor feature_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) {
+        return at::_ops::feature_dropout::redispatch(dispatchKeySet, input, p, train);
+    }
+    
+    // aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+    inline at::Tensor & feature_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) {
+        return at::_ops::feature_dropout_::redispatch(dispatchKeySet, self, p, train);
+    }
+    
+    // aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor
+    inline at::Tensor alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) {
+        return at::_ops::alpha_dropout::redispatch(dispatchKeySet, input, p, train);
+    }
+    
+    // aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+    inline at::Tensor & alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) {
+        return at::_ops::alpha_dropout_::redispatch(dispatchKeySet, self, p, train);
+    }
+    
+    // aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
+    inline at::Tensor feature_alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) {
+        return at::_ops::feature_alpha_dropout::redispatch(dispatchKeySet, input, p, train);
+    }
+    
+    // aten::feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+    inline at::Tensor & feature_alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) {
+        return at::_ops::feature_alpha_dropout_::redispatch(dispatchKeySet, self, p, train);
+    }
+    
+    // aten::abs(Tensor self) -> Tensor
+    inline at::Tensor abs(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::abs::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::abs_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & abs_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::abs_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & abs_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::abs_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & abs_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::abs_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::absolute(Tensor self) -> Tensor
+    inline at::Tensor absolute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::absolute::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::absolute_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & absolute_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::absolute_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & absolute_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & absolute_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::angle(Tensor self) -> Tensor
+    inline at::Tensor angle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::angle::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & angle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::angle_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & angle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::angle_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::view_as_real(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor view_as_real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::view_as_real::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::view_as_complex(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor view_as_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::view_as_complex::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sgn(Tensor self) -> Tensor
+    inline at::Tensor sgn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sgn::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sgn_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sgn_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sgn_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sgn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sgn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor chalf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::chalf::redispatch(dispatchKeySet, self, memory_format);
+    }
+    
+    // aten::real(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::real::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::imag(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor imag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::imag::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_conj(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_conj::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::conj(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor __dispatch_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::conj::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_conj_physical(Tensor self) -> Tensor
+    inline at::Tensor _conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_conj_physical::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::conj_physical(Tensor self) -> Tensor
+    inline at::Tensor conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::conj_physical::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::conj_physical_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & conj_physical_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::conj_physical_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::resolve_conj(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor resolve_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::resolve_conj::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::resolve_neg(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor resolve_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::resolve_neg::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_neg_view(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _neg_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_neg_view::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::acos(Tensor self) -> Tensor
+    inline at::Tensor acos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::acos::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::acos_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & acos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::acos_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & acos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::acos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & acos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::acos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arccos(Tensor self) -> Tensor
+    inline at::Tensor arccos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arccos::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arccos_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arccos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arccos_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arccos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arccos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
+    inline at::Tensor avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) {
+        return at::_ops::avg_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad);
+    }
+    
+    // aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
+    inline at::Tensor adaptive_avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool1d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> adaptive_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_max_pool1d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::add_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::add__Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu__Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+    inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu_Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu__Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+    inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::add_Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::add__Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor addmv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmv::redispatch(dispatchKeySet, self, mat, vec, beta, alpha);
+    }
+    
+    // aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & addmv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmv_::redispatch(dispatchKeySet, self, mat, vec, beta, alpha);
+    }
+    
+    // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addmv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out);
+    }
+    
+    // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addmv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out);
+    }
+    
+    // aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor addr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addr::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha);
+    }
+    
+    // aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & addr_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addr_::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha);
+    }
+    
+    // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out);
+    }
+    
+    // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out);
+    }
+    
+    // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor
+    inline at::Tensor affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners);
+    }
+    
+    // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor
+    inline at::Tensor affine_grid_generator_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, size, align_corners);
+    }
+    
+    // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor
+    inline at::Tensor affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(size), align_corners);
+    }
+    
+    // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor
+    inline at::Tensor affine_grid_generator_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, size, align_corners);
+    }
+    
+    // aten::_is_all_true(Tensor self) -> Tensor
+    inline at::Tensor _is_all_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_is_all_true::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_is_any_true(Tensor self) -> Tensor
+    inline at::Tensor _is_any_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_is_any_true::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_test_check_tensor(Tensor self) -> Tensor
+    inline at::Tensor _test_check_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_test_check_tensor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor _test_functorch_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::_test_functorch_fallback::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+    inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::all_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+    inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) {
+        return at::_ops::all_dims::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) {
+        return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+    inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::all_dimname::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
+    inline bool allclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) {
+        return at::_ops::allclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan);
+    }
+    
+    // aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+    inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::any_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+    inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) {
+        return at::_ops::any_dims::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) {
+        return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+    inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::any_dimname::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::TensorOptions options={}) {
+        return at::_ops::arange::redispatch(dispatchKeySet, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::arange::redispatch(dispatchKeySet, end, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) {
+        return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::TensorOptions options={}) {
+        return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & end) {
+        return at::_ops::arange_out::redispatch(dispatchKeySet, end, out);
+    }
+    
+    // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::Tensor & out) {
+        return at::_ops::arange_out::redispatch(dispatchKeySet, end, out);
+    }
+    
+    // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) {
+        return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out);
+    }
+    
+    // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
+        return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out);
+    }
+    
+    // aten::_dim_arange(Tensor like, int dim) -> Tensor
+    inline at::Tensor _dim_arange(c10::DispatchKeySet dispatchKeySet, const at::Tensor & like, int64_t dim) {
+        return at::_ops::_dim_arange::redispatch(dispatchKeySet, like, dim);
+    }
+    
+    // aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+    inline at::Tensor argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::argmax::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+    inline at::Tensor argmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::argmin::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::acosh(Tensor self) -> Tensor
+    inline at::Tensor acosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::acosh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::acosh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & acosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::acosh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & acosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & acosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arccosh(Tensor self) -> Tensor
+    inline at::Tensor arccosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arccosh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arccosh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arccosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arccosh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arccosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arccosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::asinh(Tensor self) -> Tensor
+    inline at::Tensor asinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::asinh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::asinh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & asinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::asinh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & asinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & asinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arcsinh(Tensor self) -> Tensor
+    inline at::Tensor arcsinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arcsinh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arcsinh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arcsinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arcsinh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arcsinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arcsinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::atanh(Tensor self) -> Tensor
+    inline at::Tensor atanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::atanh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atanh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & atanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::atanh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arctanh(Tensor self) -> Tensor
+    inline at::Tensor arctanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arctanh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arctanh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arctanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arctanh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+    inline at::Tensor as_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+    }
+    
+    // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+    inline at::Tensor as_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided::redispatch(dispatchKeySet, self, size, stride, storage_offset);
+    }
+    
+    // aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+    inline const at::Tensor & as_strided_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+    }
+    
+    // aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+    inline const at::Tensor & as_strided__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_::redispatch(dispatchKeySet, self, size, stride, storage_offset);
+    }
+    
+    // aten::asin(Tensor self) -> Tensor
+    inline at::Tensor asin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::asin::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::asin_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & asin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::asin_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & asin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::asin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & asin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::asin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arcsin(Tensor self) -> Tensor
+    inline at::Tensor arcsin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arcsin::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arcsin_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arcsin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arcsin_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arcsin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arcsin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::atan(Tensor self) -> Tensor
+    inline at::Tensor atan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::atan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atan_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & atan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::atan_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::atan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::atan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arctan(Tensor self) -> Tensor
+    inline at::Tensor arctan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::arctan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arctan_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & arctan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::arctan_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::atleast_1d(Tensor self) -> Tensor
+    inline at::Tensor atleast_1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::atleast_1d::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> atleast_1d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::atleast_1d_Sequence::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::atleast_2d(Tensor self) -> Tensor
+    inline at::Tensor atleast_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::atleast_2d::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> atleast_2d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::atleast_2d_Sequence::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::atleast_3d(Tensor self) -> Tensor
+    inline at::Tensor atleast_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::atleast_3d::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> atleast_3d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::atleast_3d_Sequence::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor baddbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::baddbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha);
+    }
+    
+    // aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & baddbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::baddbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha);
+    }
+    
+    // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & baddbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out);
+    }
+    
+    // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & baddbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out);
+    }
+    
+    // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) {
+        return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) {
+        return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
+    inline at::Tensor batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled) {
+        return at::_ops::batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled);
+    }
+    
+    // aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
+    inline at::Tensor quantized_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) {
+        return at::_ops::quantized_batch_norm::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point);
+    }
+    
+    // aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t> _batch_norm_impl_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled) {
+        return at::_ops::_batch_norm_impl_index::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled);
+    }
+    
+    // aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _batch_norm_impl_index_backward(c10::DispatchKeySet dispatchKeySet, int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var_transform, bool train, double eps, ::std::array<bool,3> output_mask, const at::Tensor & reservedSpace) {
+        return at::_ops::_batch_norm_impl_index_backward::redispatch(dispatchKeySet, impl_index, input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, train, eps, output_mask, reservedSpace);
+    }
+    
+    // aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
+    inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
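+    // Note: the generated `_out` wrappers move the out tensor to the front so the trailing
+    // arguments can keep their schema defaults, whereas the `_outf` wrappers preserve the
+    // schema order (out last) and take every argument explicitly; both forward to the same
+    // at::_ops::*_out::redispatch entry point.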
+    // aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli__Tensor::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p=0.5, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli__float::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor
+    inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli_p::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor
+    inline at::Tensor bilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+        return at::_ops::bilinear::redispatch(dispatchKeySet, input1, input2, weight, bias);
+    }
+    
+    // aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+    inline at::Tensor binary_cross_entropy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy::redispatch(dispatchKeySet, self, target, weight, reduction);
+    }
+    
+    // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out);
+    }
+    
+    // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & out) {
+        return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out);
+    }
+    
+    // aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+    inline at::Tensor binary_cross_entropy_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction);
+    }
+    
+    // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input);
+    }
+    
+    // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & grad_input) {
+        return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input);
+    }
+    
+    // aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
+    inline at::Tensor binary_cross_entropy_with_logits(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & pos_weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy_with_logits::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction);
+    }
+    
+    // aten::bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
+    inline at::Tensor bincount(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & weights={}, int64_t minlength=0) {
+        return at::_ops::bincount::redispatch(dispatchKeySet, self, weights, minlength);
+    }
+    
+    // aten::bitwise_not(Tensor self) -> Tensor
+    inline at::Tensor bitwise_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::bitwise_not::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & bitwise_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::bitwise_not_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::copysign_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::copysign__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::copysign_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::copysign__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_lazy_clone(Tensor self) -> Tensor
+    inline at::Tensor _lazy_clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_lazy_clone::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::logical_not(Tensor self) -> Tensor
+    inline at::Tensor logical_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::logical_not::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::logical_not_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & logical_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::logical_not_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::logical_xor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor logical_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_xor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & logical_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_xor_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logical_and(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor logical_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_and::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & logical_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_and_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logical_or(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor logical_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_or::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & logical_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_or_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logical_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) {
+        return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) {
+        return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::bmm(Tensor self, Tensor mat2) -> Tensor
+    inline at::Tensor bmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::bmm::redispatch(dispatchKeySet, self, mat2);
+    }
+    
+    // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
+    
+    // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) {
+        return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
+    
+    // aten::broadcast_tensors(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> broadcast_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::broadcast_tensors::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
+    inline at::Tensor broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
+    inline at::Tensor broadcast_to_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, size);
+    }
+    
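+    // Note: schemas declared with SymInt[] arguments generate two wrappers: the plain one
+    // accepts at::IntArrayRef and widens it via c10::fromIntArrayRefSlow, while the
+    // `_symint` variant forwards a c10::SymIntArrayRef unchanged so symbolic shape
+    // information is preserved across the redispatch.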
+    // aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
+    inline at::Tensor _sparse_broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_sparse_broadcast_to::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::cat(Tensor[] tensors, int dim=0) -> Tensor
+    inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim=0) {
+        return at::_ops::cat::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::ITensorListRef & tensors, int64_t dim=0) {
+        return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out) {
+        return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor
+    inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::cat_names::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) {
+        return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concat(Tensor[] tensors, int dim=0) -> Tensor
+    inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::concat::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) {
+        return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor
+    inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::concat_names::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) {
+        return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor
+    inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::concatenate::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) {
+        return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor
+    inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::concatenate_names::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) {
+        return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) {
+        return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::block_diag(Tensor[] tensors) -> Tensor
+    inline at::Tensor block_diag(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::block_diag::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::ceil(Tensor self) -> Tensor
+    inline at::Tensor ceil(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::ceil::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::ceil_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & ceil_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::ceil_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ceil_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ceil_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::chain_matmul(Tensor[] matrices) -> Tensor
+    inline at::Tensor chain_matmul(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices) {
+        return at::_ops::chain_matmul::redispatch(dispatchKeySet, matrices);
+    }
+    
+    // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & chain_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList matrices) {
+        return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out);
+    }
+    
+    // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & chain_matmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices, at::Tensor & out) {
+        return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out);
+    }
+    
+    // aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unsafe_chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) {
+        return at::_ops::unsafe_chunk::redispatch(dispatchKeySet, self, chunks, dim);
+    }
+    
+    // aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) {
+        return at::_ops::chunk::redispatch(dispatchKeySet, self, chunks, dim);
+    }
+    
+    // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections, int64_t dim=0) {
+        return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim);
+    }
+    
+    // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt sections, int64_t dim=0) {
+        return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim);
+    }
+    
+    // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices, int64_t dim=0) {
+        return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(indices), dim);
+    }
+    
+    // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim=0) {
+        return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, indices, dim);
+    }
+    
+    // aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim=0) {
+        return at::_ops::tensor_split_tensor_indices_or_sections::redispatch(dispatchKeySet, self, tensor_indices_or_sections, dim);
+    }
+    
+    // aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+    inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clamp::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+    inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clamp_Tensor::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+    inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clamp_::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+    inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clamp__Tensor::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out) {
+        return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out) {
+        return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
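+    // Note: clamp (and clip further down) is exposed with two parameter lists: the base
+    // schema takes optional Scalar bounds, while the `.Tensor` overloads take optional
+    // Tensor bounds, which is why each entry appears with both c10::optional<at::Scalar>
+    // and c10::optional<at::Tensor> arguments.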
+    // aten::clamp_max(Tensor self, Scalar max) -> Tensor
+    inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max) {
+        return at::_ops::clamp_max::redispatch(dispatchKeySet, self, max);
+    }
+    
+    // aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
+    inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max) {
+        return at::_ops::clamp_max_Tensor::redispatch(dispatchKeySet, self, max);
+    }
+    
+    // aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)
+    inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & max) {
+        return at::_ops::clamp_max_::redispatch(dispatchKeySet, self, max);
+    }
+    
+    // aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)
+    inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & max) {
+        return at::_ops::clamp_max__Tensor::redispatch(dispatchKeySet, self, max);
+    }
+    
+    // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & max) {
+        return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out);
+    }
+    
+    // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max, at::Tensor & out) {
+        return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out);
+    }
+    
+    // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & max) {
+        return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out);
+    }
+    
+    // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max, at::Tensor & out) {
+        return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out);
+    }
+    
+    // aten::clamp_min(Tensor self, Scalar min) -> Tensor
+    inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min) {
+        return at::_ops::clamp_min::redispatch(dispatchKeySet, self, min);
+    }
+    
+    // aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
+    inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min) {
+        return at::_ops::clamp_min_Tensor::redispatch(dispatchKeySet, self, min);
+    }
+    
+    // aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)
+    inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min) {
+        return at::_ops::clamp_min_::redispatch(dispatchKeySet, self, min);
+    }
+    
+    // aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)
+    inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & min) {
+        return at::_ops::clamp_min__Tensor::redispatch(dispatchKeySet, self, min);
+    }
+    
+    // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min) {
+        return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out);
+    }
+    
+    // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min, at::Tensor & out) {
+        return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out);
+    }
+    
+    // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & min) {
+        return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out);
+    }
+    
+    // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min, at::Tensor & out) {
+        return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out);
+    }
+    
+    // aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+    inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clip::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+    inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clip_Tensor::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+    inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clip_::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+    inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clip__Tensor::redispatch(dispatchKeySet, self, min, max);
+    }
+    
+    // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max=c10::nullopt) {
+        return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max, at::Tensor & out) {
+        return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Tensor> & min={}, const c10::optional<at::Tensor> & max={}) {
+        return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max, at::Tensor & out) {
+        return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out);
+    }
+    
+    // aten::cudnn_is_acceptable(Tensor self) -> bool
+    inline bool cudnn_is_acceptable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::cudnn_is_acceptable::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::complex(Tensor real, Tensor imag) -> Tensor
+    inline at::Tensor complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag) {
+        return at::_ops::complex::redispatch(dispatchKeySet, real, imag);
+    }
+    
+    // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & complex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & real, const at::Tensor & imag) {
+        return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out);
+    }
+    
+    // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & complex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag, at::Tensor & out) {
+        return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out);
+    }
+    
+    // aten::polar(Tensor abs, Tensor angle) -> Tensor
+    inline at::Tensor polar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle) {
+        return at::_ops::polar::redispatch(dispatchKeySet, abs, angle);
+    }
+    
+    // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & polar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & abs, const at::Tensor & angle) {
+        return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out);
+    }
+    
+    // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & polar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle, at::Tensor & out) {
+        return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out);
+    }
+    
+    // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+    inline at::Tensor constant_pad_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) {
+        return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value);
+    }
+    
+    // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+    inline at::Tensor constant_pad_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) {
+        return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, pad, value);
+    }
+    
+    // aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
+    inline at::Tensor __dispatch_contiguous(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::MemoryFormat memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::contiguous::redispatch(dispatchKeySet, self, memory_format);
+    }
+    
+    // aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+    inline at::Tensor convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
+        return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups);
+    }
+    
+    // aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+    inline at::Tensor convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
+        return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups);
+    }
+    
+    // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : c10::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask);
+    }
+    
+    // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask);
+    }
+    
+    // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+    inline at::Tensor convolution_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
+        return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups);
+    }
+    
+    // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+    inline at::Tensor convolution_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
+        return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups);
+    }
+    
+    // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> convolution_backward_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask);
+    }
+    
+    // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> convolution_backward_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask);
+    }
+    
+    // aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+    inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
+        return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32);
+    }
+    
+    // aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+    inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
+        return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32);
+    }
+    
+    // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
+    inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) {
+        return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled);
+    }
+    
+    // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
+    inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) {
+        return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled);
+    }
+    
+    // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _convolution_mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _convolution_mode_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & ggI, const c10::optional<at::Tensor> & ggW, const c10::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask);
+    }
+    
+    // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward_symint(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & ggI, const c10::optional<at::Tensor> & ggW, const c10::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask);
+    }
+    
+    // aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
+    inline at::Tensor conv_tbc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) {
+        return at::_ops::conv_tbc::redispatch(dispatchKeySet, self, weight, bias, pad);
+    }
+    
+    // aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> conv_tbc_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) {
+        return at::_ops::conv_tbc_backward::redispatch(dispatchKeySet, self, input, weight, bias, pad);
+    }
+    
+    // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) {
+        return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation);
+    }
+    
+    // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) {
+        return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation);
+    }
+    
+    // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) {
+        return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation);
+    }
+    
+    // aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+    inline at::Tensor copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
+    // aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+    inline at::Tensor & copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy_::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
+    // aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
+    inline at::Tensor _copy_from(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) {
+        return at::_ops::_copy_from::redispatch(dispatchKeySet, self, dst, non_blocking);
+    }
+    
+    // aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor
+    inline at::Tensor _copy_from_and_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst) {
+        return at::_ops::_copy_from_and_resize::redispatch(dispatchKeySet, self, dst);
+    }
+    
+    // aten::cos(Tensor self) -> Tensor
+    inline at::Tensor cos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::cos::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::cos_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & cos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::cos_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::cos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::cos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::cosh(Tensor self) -> Tensor
+    inline at::Tensor cosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::cosh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::cosh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & cosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::cosh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
+    inline at::Tensor cosine_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::cosine_embedding_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction);
+    }
+    
+    // aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+    inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::count_nonzero_dim_IntList::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::count_nonzero(Tensor self, int? dim=None) -> Tensor
+    inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::count_nonzero::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
+    inline at::Tensor cov(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t correction=1, const c10::optional<at::Tensor> & fweights={}, const c10::optional<at::Tensor> & aweights={}) {
+        return at::_ops::cov::redispatch(dispatchKeySet, self, correction, fweights, aweights);
+    }
+    
+    // aten::corrcoef(Tensor self) -> Tensor
+    inline at::Tensor corrcoef(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::corrcoef::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
+    inline at::Tensor cudnn_affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) {
+        return at::_ops::cudnn_affine_grid_generator::redispatch(dispatchKeySet, theta, N, C, H, W);
+    }
+    
+    // aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta
+    inline at::Tensor cudnn_affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) {
+        return at::_ops::cudnn_affine_grid_generator_backward::redispatch(dispatchKeySet, grad, N, C, H, W);
+    }
+    
+    // aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> cudnn_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon) {
+        return at::_ops::cudnn_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon);
+    }
+    
+    // aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon, const at::Tensor & reserveSpace) {
+        return at::_ops::cudnn_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace);
+    }
+    
+    // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+    inline at::Tensor cudnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32);
+    }
+    
+    // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+    inline at::Tensor cudnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
+    }
+    
+    // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+    inline at::Tensor cudnn_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32);
+    }
+    
+    // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+    inline at::Tensor cudnn_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
+    }
+    
+    // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _mps_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _mps_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups);
+    }
+    
+    // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> mps_convolution_transpose_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,2> output_mask) {
+        return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask);
+    }
+    
+    // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> mps_convolution_transpose_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,2> output_mask) {
+        return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask);
+    }
+    
+    // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor cudnn_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor cudnn_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor cudnn_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor cudnn_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
+    inline at::Tensor cudnn_grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid) {
+        return at::_ops::cudnn_grid_sampler::redispatch(dispatchKeySet, self, grid);
+    }
+    
+    // aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid)
+    inline ::std::tuple<at::Tensor,at::Tensor> cudnn_grid_sampler_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) {
+        return at::_ops::cudnn_grid_sampler_backward::redispatch(dispatchKeySet, self, grid, grad_output);
+    }
+    
+    // aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::cummax::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) {
+        return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::cummax_dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
+    inline void _cummax_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) {
+        return at::_ops::_cummax_helper::redispatch(dispatchKeySet, self, values, indices, dim);
+    }
+    
+    // aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::cummin::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) {
+        return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::cummin_dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices);
+    }
+    
+    // aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
+    inline void _cummin_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) {
+        return at::_ops::_cummin_helper::redispatch(dispatchKeySet, self, values, indices, dim);
+    }
+    
+    // aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
+    inline at::Tensor cummaxmin_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim) {
+        return at::_ops::cummaxmin_backward::redispatch(dispatchKeySet, grad, input, indices, dim);
+    }
+    
+    // aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+    inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod_::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod_dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+    inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod__dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor
+    inline at::Tensor cumprod_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output) {
+        return at::_ops::cumprod_backward::redispatch(dispatchKeySet, grad, input, dim, output);
+    }
+    
+    // aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+    inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum_::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum_dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+    inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum__dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+    inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) {
+        return at::_ops::cumulative_trapezoid_x::redispatch(dispatchKeySet, y, x, dim);
+    }
+    
+    // aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+    inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) {
+        return at::_ops::cumulative_trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim);
+    }
+    
+    // aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
+    inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) {
+        return at::_ops::ctc_loss_IntList::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
+    }
+    
+    // aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
+    inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) {
+        return at::_ops::ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
+    }
+    
+    // aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity);
+    }
+    
+    // aten::_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity);
+    }
+    
+    // aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor
+    inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_backward::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity);
+    }
+    
+    // aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor
+    inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_backward_Tensor::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity);
+    }
+    
+    // aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
+    inline at::Tensor diag_embed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) {
+        return at::_ops::diag_embed::redispatch(dispatchKeySet, self, offset, dim1, dim2);
+    }
+    
+    // aten::diagflat(Tensor self, int offset=0) -> Tensor
+    inline at::Tensor diagflat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0) {
+        return at::_ops::diagflat::redispatch(dispatchKeySet, self, offset);
+    }
+    
+    // aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
+    inline at::Tensor diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) {
+        return at::_ops::diagonal::redispatch(dispatchKeySet, self, offset, dim1, dim2);
+    }
+    
+    // aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
+    inline at::Tensor linalg_diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) {
+        return at::_ops::linalg_diagonal::redispatch(dispatchKeySet, A, offset, dim1, dim2);
+    }
+    
+    // aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
+    inline at::Tensor diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) {
+        return at::_ops::diagonal_Dimname::redispatch(dispatchKeySet, self, outdim, dim1, dim2, offset);
+    }
+    
+    // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
+    inline at::Tensor diagonal_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+        return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2);
+    }
+    
+    // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
+    inline at::Tensor diagonal_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+        return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2);
+    }
+    
+    // aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
+    inline at::Tensor & fill_diagonal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & fill_value, bool wrap=false) {
+        return at::_ops::fill_diagonal_::redispatch(dispatchKeySet, self, fill_value, wrap);
+    }
+    
+    // aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor
+    inline at::Tensor diff(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const c10::optional<at::Tensor> & prepend={}, const c10::optional<at::Tensor> & append={}) {
+        return at::_ops::diff::redispatch(dispatchKeySet, self, n, dim, prepend, append);
+    }
+    
+    // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diff_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const c10::optional<at::Tensor> & prepend={}, const c10::optional<at::Tensor> & append={}) {
+        return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out);
+    }
+    
+    // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diff_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, int64_t dim, const c10::optional<at::Tensor> & prepend, const c10::optional<at::Tensor> & append, at::Tensor & out) {
+        return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out);
+    }
+    
+    // aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & spacing=c10::nullopt, c10::optional<int64_t> dim=c10::nullopt, int64_t edge_order=1) {
+        return at::_ops::gradient_scalarint::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order=1) {
+        return at::_ops::gradient_scalararray::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order=1) {
+        return at::_ops::gradient_array::redispatch(dispatchKeySet, self, dim, edge_order);
+    }
+    
+    // aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef<at::Scalar> spacing, c10::optional<int64_t> dim=c10::nullopt, int64_t edge_order=1) {
+        return at::_ops::gradient_scalarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef<at::Scalar> spacing, at::IntArrayRef dim, int64_t edge_order=1) {
+        return at::_ops::gradient_scalarrayarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, c10::optional<int64_t> dim=c10::nullopt, int64_t edge_order=1) {
+        return at::_ops::gradient_tensorarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order=1) {
+        return at::_ops::gradient_tensorarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order);
+    }
+    
+    // aten::div.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::div_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::div__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+    inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+    inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode, at::Tensor & out) {
+        return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::div.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::div_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::div__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+    inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+    inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::divide.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::divide_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::divide__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::divide.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::divide_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::divide__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+    inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::divide_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+    inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::divide__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<c10::string_view> rounding_mode, at::Tensor & out) {
+        return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+    inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::divide_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+    inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::divide__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode);
+    }
+    
+    // aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::true_divide_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::true_divide__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & true_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & true_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::true_divide_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::true_divide__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::dot(Tensor self, Tensor tensor) -> Tensor
+    inline at::Tensor dot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) {
+        return at::_ops::dot::redispatch(dispatchKeySet, self, tensor);
+    }
+    
+    // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor) {
+        return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out);
+    }
+    
+    // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor, at::Tensor & out) {
+        return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out);
+    }
+    
+    // aten::vdot(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor vdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::vdot::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & vdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & vdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor
+    inline at::Tensor einsum(c10::DispatchKeySet dispatchKeySet, c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path=c10::nullopt) {
+        return at::_ops::einsum::redispatch(dispatchKeySet, equation, tensors, path);
+    }
+    
+    // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+    inline at::Tensor embedding(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) {
+        return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse);
+    }
+    
+    // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+    inline at::Tensor embedding_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) {
+        return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse);
+    }
+    
+    // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
+    inline at::Tensor embedding_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+        return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
+    }
+    
+    // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
+    inline at::Tensor embedding_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
+        return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse);
+    }
+    
+    // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor
+    inline at::Tensor embedding_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) {
+        return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq);
+    }
+    
+    // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor
+    inline at::Tensor embedding_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) {
+        return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq);
+    }
+    
+    // aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
+    inline at::Tensor & embedding_renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) {
+        return at::_ops::embedding_renorm_::redispatch(dispatchKeySet, self, indices, max_norm, norm_type);
+    }
+    
+    // aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+    inline at::Tensor embedding_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) {
+        return at::_ops::embedding_sparse_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq);
+    }
+    
+    // aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _embedding_bag_forward_only(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const c10::optional<at::Tensor> & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_forward_only::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx);
+    }
+    
+    // aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _rowwise_prune(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype) {
+        return at::_ops::_rowwise_prune::redispatch(dispatchKeySet, weight, mask, compressed_indices_dtype);
+    }
+    
+    // aten::row_stack(Tensor[] tensors) -> Tensor
+    inline at::Tensor row_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::row_stack::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & row_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & row_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const c10::optional<at::Tensor> & per_sample_weights={}, bool include_last_offset=false) {
+        return at::_ops::embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset);
+    }
+    
+    // aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, bool include_last_offset, c10::optional<int64_t> padding_idx) {
+        return at::_ops::embedding_bag_padding_idx::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx);
+    }
+    
+    // aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const c10::optional<at::Tensor> & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx);
+    }
+    
+    // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_sparse_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx);
+    }
+    
+    // aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
+    inline at::Tensor _embedding_bag_per_sample_weights_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_per_sample_weights_backward::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx);
+    }
+    
+    // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) {
+        return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) {
+        return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) {
+        return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) {
+        return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) {
+        return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) {
+        return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, double scale, int64_t zero_point, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, scale, zero_point, memory_format);
+    }
+    
+    // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, double scale, int64_t zero_point, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, scale, zero_point, memory_format);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+    inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+    inline const at::Tensor & resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format);
+    }
+    
+    // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+    inline const at::Tensor & resize__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_::redispatch(dispatchKeySet, self, size, memory_format);
+    }
+    
+    // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)
+    inline const at::Tensor & _resize_output_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device);
+    }
+    
+    // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)
+    inline const at::Tensor & _resize_output__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, size, device);
+    }
+    
+    // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out);
+    }
+    
+    // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out);
+    }
+    
+    // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out);
+    }
+    
+    // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out);
+    }
+    
+    // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::empty_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) {
+        return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) {
+        return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::erf(Tensor self) -> Tensor
+    inline at::Tensor erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::erf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erf_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & erf_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::erf_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::erfc(Tensor self) -> Tensor
+    inline at::Tensor erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::erfc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erfc_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & erfc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::erfc_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::exp(Tensor self) -> Tensor
+    inline at::Tensor exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::exp::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::exp_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & exp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::exp_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::exp2(Tensor self) -> Tensor
+    inline at::Tensor exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::exp2::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::exp2_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & exp2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::exp2_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::expm1(Tensor self) -> Tensor
+    inline at::Tensor expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::expm1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::expm1_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & expm1_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::expm1_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+    inline at::Tensor expand(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) {
+        return at::_ops::expand::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit);
+    }
+    
+    // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+    inline at::Tensor expand_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) {
+        return at::_ops::expand::redispatch(dispatchKeySet, self, size, implicit);
+    }
+    
+    // aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)
+    inline at::Tensor expand_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::expand_as::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options={}) {
+        return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options={}) {
+        return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::TensorOptions options={}) {
+        return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::TensorOptions options={}) {
+        return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) {
+        return at::_ops::eye_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) {
+        return at::_ops::eye_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) {
+        return at::_ops::eye_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) {
+        return at::_ops::eye_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, int64_t m) {
+        return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out);
+    }
+    
+    // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::Tensor & out) {
+        return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out);
+    }
+    
+    // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, c10::SymInt m) {
+        return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out);
+    }
+    
+    // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::Tensor & out) {
+        return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out);
+    }
+    
+    // aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)
+    inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim=0, int64_t end_dim=-1) {
+        return at::_ops::flatten_using_ints::redispatch(dispatchKeySet, self, start_dim, end_dim);
+    }
+    
+    // aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)
+    inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim) {
+        return at::_ops::flatten_named_out_dim::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim);
+    }
+    
+    // aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)
+    inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) {
+        return at::_ops::flatten_using_names::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim);
+    }
+    
+    // aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
+    inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim) {
+        return at::_ops::flatten_DimnameList::redispatch(dispatchKeySet, self, dims, out_dim);
+    }
+    
+    // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)
+    inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::IntArrayRef sizes) {
+        return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes));
+    }
+    
+    // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)
+    inline at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes) {
+        return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, sizes);
+    }
+    
+    // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)
+    inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) {
+        return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes), names);
+    }
+    
+    // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)
+    inline at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) {
+        return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, sizes, names);
+    }
+    
+    // aten::fill.Scalar(Tensor self, Scalar value) -> Tensor
+    inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value) {
+        return at::_ops::fill_Scalar::redispatch(dispatchKeySet, self, value);
+    }
+    
+    // aten::fill.Tensor(Tensor self, Tensor value) -> Tensor
+    inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value) {
+        return at::_ops::fill_Tensor::redispatch(dispatchKeySet, self, value);
+    }
+    
+    // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
+    inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & value) {
+        return at::_ops::fill__Scalar::redispatch(dispatchKeySet, self, value);
+    }
+    
+    // aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
+    inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & value) {
+        return at::_ops::fill__Tensor::redispatch(dispatchKeySet, self, value);
+    }
+    
+    // aten::floor(Tensor self) -> Tensor
+    inline at::Tensor floor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::floor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::floor_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & floor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::floor_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::floor_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::floor_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::floor_divide(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::floor_divide::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::floor_divide__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::floor_divide_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::floor_divide__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::frac(Tensor self) -> Tensor
+    inline at::Tensor frac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::frac::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::frac_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & frac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::frac_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & frac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::frac_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & frac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::frac_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) {
+        return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) {
+        return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value) {
+        return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out);
+    }
+    
+    // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) {
+        return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out);
+    }
+    
+    // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Scalar & fill_value) {
+        return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out);
+    }
+    
+    // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) {
+        return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out);
+    }
+    
+    // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, c10::optional<bool> shared=c10::nullopt, c10::optional<int64_t> size=0, at::TensorOptions options={}) {
+        return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, c10::optional<bool> shared, c10::optional<int64_t> size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gcd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gcd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gcd(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor gcd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gcd::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & gcd_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gcd_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lcm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lcm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lcm(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor lcm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lcm::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & lcm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lcm_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+    inline at::Tensor grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::grid_sampler::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners);
+    }
+    
+    // aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+    inline at::Tensor grid_sampler_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::grid_sampler_2d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners);
+    }
+    
+    // aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> grid_sampler_2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask) {
+        return at::_ops::grid_sampler_2d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask);
+    }
+    
+    // aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+    inline at::Tensor _grid_sampler_2d_cpu_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::_grid_sampler_2d_cpu_fallback::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners);
+    }
+    
+    // aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _grid_sampler_2d_cpu_fallback_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::_grid_sampler_2d_cpu_fallback_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
+    }
+    
+    // aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+    inline at::Tensor grid_sampler_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::grid_sampler_3d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners);
+    }
+    
+    // aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> grid_sampler_3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask) {
+        return at::_ops::grid_sampler_3d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask);
+    }
+    
+    // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) {
+        return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) {
+        return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) {
+        return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) {
+        return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::TensorOptions options={}) {
+        return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::TensorOptions options={}) {
+        return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) {
+        return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) {
+        return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::TensorOptions options={}) {
+        return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor
+    inline at::Tensor hinge_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, double margin=1.0, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::hinge_embedding_loss::redispatch(dispatchKeySet, self, target, margin, reduction);
+    }
+    
+    // aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor
+    inline at::Tensor group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t num_groups, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & bias={}, double eps=1e-05, bool cudnn_enabled=true) {
+        return at::_ops::group_norm::redispatch(dispatchKeySet, input, num_groups, weight, bias, eps, cudnn_enabled);
+    }
+    
+    // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) {
+        return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps);
+    }
+    
+    // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_group_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) {
+        return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps);
+    }
+    
+    // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_group_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask);
+    }
+    
+    // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_group_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask);
+    }
+    
+    // aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
+    inline at::Tensor _fft_r2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) {
+        return at::_ops::_fft_r2c::redispatch(dispatchKeySet, self, dim, normalization, onesided);
+    }
+    
+    // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_r2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) {
+        return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out);
+    }
+    
+    // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_r2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided, at::Tensor & out) {
+        return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out);
+    }
+    
+    // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
+    inline at::Tensor _fft_c2r(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
+        return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size);
+    }
+    
+    // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
+    inline at::Tensor _fft_c2r_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) {
+        return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size);
+    }
+    
+    // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2r_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
+        return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out);
+    }
+    
+    // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2r_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size, at::Tensor & out) {
+        return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out);
+    }
+    
+    // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2r_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) {
+        return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out);
+    }
+    
+    // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2r_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out) {
+        return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out);
+    }
+    
+    // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+    inline at::Tensor _fft_c2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) {
+        return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward);
+    }
+    
+    // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+    inline at::Tensor _fft_c2c_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) {
+        return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, dim, normalization, forward);
+    }
+    
+    // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) {
+        return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out);
+    }
+    
+    // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) {
+        return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out);
+    }
+    
+    // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2c_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) {
+        return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out);
+    }
+    
+    // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fft_c2c_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) {
+        return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out);
+    }
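+    
+    // For schema arguments declared as SymInt[], the IntArrayRef overloads convert their
+    // sizes via c10::fromIntArrayRefSlow before redispatching, while the *_symint overloads
+    // pass the c10::SymIntArrayRef through unchanged.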
+    
+    // aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
+    inline void _validate_compressed_sparse_indices(c10::DispatchKeySet dispatchKeySet, bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz) {
+        return at::_ops::_validate_compressed_sparse_indices::redispatch(dispatchKeySet, is_crow, compressed_idx, plain_idx, cdim, dim, nnz);
+    }
+    
+    // aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int
+    inline int64_t _cufft_get_plan_cache_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) {
+        return at::_ops::_cufft_get_plan_cache_size::redispatch(dispatchKeySet, device_index);
+    }
+    
+    // aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int
+    inline int64_t _cufft_get_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) {
+        return at::_ops::_cufft_get_plan_cache_max_size::redispatch(dispatchKeySet, device_index);
+    }
+    
+    // aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()
+    inline void _cufft_set_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index, int64_t max_size) {
+        return at::_ops::_cufft_set_plan_cache_max_size::redispatch(dispatchKeySet, device_index, max_size);
+    }
+    
+    // aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> ()
+    inline void _cufft_clear_plan_cache(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) {
+        return at::_ops::_cufft_clear_plan_cache::redispatch(dispatchKeySet, device_index);
+    }
+    
+    // aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+    inline at::Tensor index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
+        return at::_ops::index_Tensor::redispatch(dispatchKeySet, self, indices);
+    }
+    
+    // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
+        return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out);
+    }
+    
+    // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, at::Tensor & out) {
+        return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out);
+    }
+    
+    // aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+    inline at::Tensor _unsafe_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
+        return at::_ops::_unsafe_index_Tensor::redispatch(dispatchKeySet, self, indices);
+    }
+    
+    // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) {
+        return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out);
+    }
+    
+    // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, at::Tensor & out) {
+        return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out);
+    }
+    
+    // aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
+    inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) {
+        return at::_ops::index_copy_::redispatch(dispatchKeySet, self, dim, index, source);
+    }
+    
+    // aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
+    inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) {
+        return at::_ops::index_copy::redispatch(dispatchKeySet, self, dim, index, source);
+    }
+    
+    // aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)
+    inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) {
+        return at::_ops::index_copy__dimname::redispatch(dispatchKeySet, self, dim, index, source);
+    }
+    
+    // aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor
+    inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) {
+        return at::_ops::index_copy_dimname::redispatch(dispatchKeySet, self, dim, index, source);
+    }
+    
+    // aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
+    inline at::Tensor & index_put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false) {
+        return at::_ops::index_put_::redispatch(dispatchKeySet, self, indices, values, accumulate);
+    }
+    
+    // aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+    inline at::Tensor index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false) {
+        return at::_ops::index_put::redispatch(dispatchKeySet, self, indices, values, accumulate);
+    }
+    
+    // aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+    inline at::Tensor _unsafe_index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false) {
+        return at::_ops::_unsafe_index_put::redispatch(dispatchKeySet, self, indices, values, accumulate);
+    }
+    
+    // aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
+    inline at::Tensor & _index_put_impl_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) {
+        return at::_ops::_index_put_impl_::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe);
+    }
+    
+    // aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
+    inline at::Tensor instance_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) {
+        return at::_ops::instance_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled);
+    }
+    
+    // aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor
+    inline at::Tensor isclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) {
+        return at::_ops::isclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan);
+    }
+    
+    // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) {
+        return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+    inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Tensor_Tensor::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert);
+    }
+    
+    // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert, at::Tensor & out) {
+        return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor
+    inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Tensor_Scalar::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert);
+    }
+    
+    // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) {
+        return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out);
+    }
+    
+    // aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+    inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) {
+        return at::_ops::isin_Scalar_Tensor::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert);
+    }
+    
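+    // The three isin overload families above cover the Tensor_Tensor, Tensor_Scalar and
+    // Scalar_Tensor schema variants; which at::_ops entry is reached depends only on the
+    // argument types. A hedged caller-side sketch, assuming the public at::isin wrapper
+    // (illustrative only, not a declaration of this header):
+    //
+    //   at::Tensor elements = at::arange(6);                  // [0, 1, 2, 3, 4, 5]
+    //   at::Tensor pool     = at::arange(0, 6, 2);            // [0, 2, 4]
+    //   at::Tensor mask = at::isin(elements, pool);           // Tensor_Tensor variant
+    //   at::Tensor eq3  = at::isin(elements, at::Scalar(3));  // Tensor_Scalar variant
+    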
+    // aten::isnan(Tensor self) -> Tensor
+    inline at::Tensor isnan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isnan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_distributed(Tensor self) -> bool
+    inline bool is_distributed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_distributed::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_floating_point(Tensor self) -> bool
+    inline bool __dispatch_is_floating_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_floating_point::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_complex(Tensor self) -> bool
+    inline bool __dispatch_is_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_complex::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_conj(Tensor self) -> bool
+    inline bool __dispatch_is_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_conj::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_is_zerotensor(Tensor self) -> bool
+    inline bool __dispatch__is_zerotensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_is_zerotensor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_neg(Tensor self) -> bool
+    inline bool __dispatch_is_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_neg::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::isreal(Tensor self) -> Tensor
+    inline at::Tensor isreal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isreal::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_nonzero(Tensor self) -> bool
+    inline bool is_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_nonzero::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_same_size(Tensor self, Tensor other) -> bool
+    inline bool is_same_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::is_same_size::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::is_signed(Tensor self) -> bool
+    inline bool __dispatch_is_signed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_signed::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_inference(Tensor self) -> bool
+    inline bool __dispatch_is_inference(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_inference::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
+    inline at::Tensor kl_div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, bool log_target=false) {
+        return at::_ops::kl_div::redispatch(dispatchKeySet, self, target, reduction, log_target);
+    }
+    
+    // aten::kron(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor kron(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::kron::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kron_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kron_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) {
+        return at::_ops::kthvalue::redispatch(dispatchKeySet, self, k, dim, keepdim);
+    }
+    
+    // aten::kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) {
+        return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices);
+    }
+    
+    // aten::kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices);
+    }
+    
+    // aten::kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::kthvalue_dimname::redispatch(dispatchKeySet, self, k, dim, keepdim);
+    }
+    
+    // aten::kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices);
+    }
+    
+    // aten::kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices);
+    }
+    
+    // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
+    inline at::Tensor layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & bias={}, double eps=1e-05, bool cudnn_enable=true) {
+        return at::_ops::layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, cudnn_enable);
+    }
+    
+    // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
+    inline at::Tensor layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & bias={}, double eps=1e-05, bool cudnn_enable=true) {
+        return at::_ops::layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, cudnn_enable);
+    }
+    
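+    // The `_symint` wrapper above is the symbolic-shape twin of the plain overload: the
+    // IntArrayRef form converts concrete sizes via c10::fromIntArrayRefSlow before
+    // redispatching, while the SymIntArrayRef form forwards symbolic sizes untouched.
+    // A hedged usage sketch with the public at::layer_norm wrapper (illustrative only):
+    //
+    //   at::Tensor x = at::randn({8, 16});
+    //   at::Tensor y = at::layer_norm(x, /*normalized_shape=*/{16});
+    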
+    // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps) {
+        return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps);
+    }
+    
+    // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps) {
+        return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps);
+    }
+    
+    // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask);
+    }
+    
+    // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask);
+    }
+    
+    // aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+    inline at::Tensor nan_to_num(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> nan=c10::nullopt, c10::optional<double> posinf=c10::nullopt, c10::optional<double> neginf=c10::nullopt) {
+        return at::_ops::nan_to_num::redispatch(dispatchKeySet, self, nan, posinf, neginf);
+    }
+    
+    // aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)
+    inline at::Tensor & nan_to_num_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::optional<double> nan=c10::nullopt, c10::optional<double> posinf=c10::nullopt, c10::optional<double> neginf=c10::nullopt) {
+        return at::_ops::nan_to_num_::redispatch(dispatchKeySet, self, nan, posinf, neginf);
+    }
+    
+    // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nan_to_num_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<double> nan=c10::nullopt, c10::optional<double> posinf=c10::nullopt, c10::optional<double> neginf=c10::nullopt) {
+        return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out);
+    }
+    
+    // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nan_to_num_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf, at::Tensor & out) {
+        return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out);
+    }
+    
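+    // nan_to_num takes three independent optional replacement values; passing c10::nullopt
+    // keeps the default substitution (0.0 for NaN, the dtype's largest/smallest finite value
+    // for +inf/-inf). A hedged sketch with the public at::nan_to_num wrapper (illustrative
+    // only; assumes <cmath> for NAN and INFINITY):
+    //
+    //   at::Tensor t = at::tensor({1.0f, NAN, INFINITY});
+    //   at::Tensor cleaned = at::nan_to_num(t, /*nan=*/0.0, /*posinf=*/1e9);
+    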
+    // aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+    inline at::Tensor linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+        return at::_ops::linear::redispatch(dispatchKeySet, input, weight, bias);
+    }
+    
+    // aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask) {
+        return at::_ops::linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask);
+    }
+    
+    // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+        return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out);
+    }
+    
+    // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::Tensor & out) {
+        return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out);
+    }
+    
+    // aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
+    inline at::Tensor mkldnn_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+        return at::_ops::mkldnn_linear::redispatch(dispatchKeySet, self, weight, bias);
+    }
+    
+    // aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor
+    inline at::Tensor mkldnn_linear_backward_input(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) {
+        return at::_ops::mkldnn_linear_backward_input::redispatch(dispatchKeySet, input_size, grad_output, weight);
+    }
+    
+    // aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> mkldnn_linear_backward_weights(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) {
+        return at::_ops::mkldnn_linear_backward_weights::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined);
+    }
+    
+    // aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask) {
+        return at::_ops::mkldnn_linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask);
+    }
+    
+    // aten::_cslt_compress(Tensor input) -> Tensor
+    inline at::Tensor _cslt_compress(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) {
+        return at::_ops::_cslt_compress::redispatch(dispatchKeySet, input);
+    }
+    
+    // aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
+    inline at::Tensor _cslt_sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const c10::optional<at::Tensor> & bias={}, const c10::optional<at::Tensor> & alpha={}, c10::optional<at::ScalarType> out_dtype=c10::nullopt, bool transpose_result=false, int64_t alg_id=0) {
+        return at::_ops::_cslt_sparse_mm::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result, alg_id);
+    }
+    
+    // aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int
+    inline int64_t _cslt_sparse_mm_search(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const c10::optional<at::Tensor> & bias={}, const c10::optional<at::Tensor> & alpha={}, c10::optional<at::ScalarType> out_dtype=c10::nullopt, bool transpose_result=false) {
+        return at::_ops::_cslt_sparse_mm_search::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result);
+    }
+    
+    // aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
+    inline at::Tensor _sparse_semi_structured_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const c10::optional<at::Tensor> & bias={}, c10::optional<c10::string_view> activation=c10::nullopt, c10::optional<at::ScalarType> out_dtype=c10::nullopt) {
+        return at::_ops::_sparse_semi_structured_linear::redispatch(dispatchKeySet, input, weight, meta, bias, activation, out_dtype);
+    }
+    
+    // aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor
+    inline at::Tensor _mixed_dtypes_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const c10::optional<at::Tensor> & bias={}, c10::optional<c10::string_view> activation=c10::nullopt) {
+        return at::_ops::_mixed_dtypes_linear::redispatch(dispatchKeySet, input, weight, scale, bias, activation);
+    }
+    
+    // aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
+    inline at::Tensor fbgemm_linear_int8_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) {
+        return at::_ops::fbgemm_linear_int8_weight_fp32_activation::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias);
+    }
+    
+    // aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
+    inline at::Tensor fbgemm_linear_int8_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) {
+        return at::_ops::fbgemm_linear_int8_weight::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias);
+    }
+    
+    // aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)
+    inline ::std::tuple<at::Tensor,at::Tensor,double,int64_t> fbgemm_linear_quantize_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) {
+        return at::_ops::fbgemm_linear_quantize_weight::redispatch(dispatchKeySet, input);
+    }
+    
+    // aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
+    inline at::Tensor fbgemm_pack_gemm_matrix_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) {
+        return at::_ops::fbgemm_pack_gemm_matrix_fp16::redispatch(dispatchKeySet, input);
+    }
+    
+    // aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+    inline at::Tensor fbgemm_linear_fp16_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) {
+        return at::_ops::fbgemm_linear_fp16_weight_fp32_activation::redispatch(dispatchKeySet, input, packed_weight, bias);
+    }
+    
+    // aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+    inline at::Tensor fbgemm_linear_fp16_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) {
+        return at::_ops::fbgemm_linear_fp16_weight::redispatch(dispatchKeySet, input, packed_weight, bias);
+    }
+    
+    // aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor
+    inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) {
+        return at::_ops::fbgemm_pack_quantized_matrix::redispatch(dispatchKeySet, input);
+    }
+    
+    // aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
+    inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t K, int64_t N) {
+        return at::_ops::fbgemm_pack_quantized_matrix_KN::redispatch(dispatchKeySet, input, K, N);
+    }
+    
+    // aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor ldexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ldexp_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & ldexp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ldexp_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ldexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ldexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) {
+        return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) {
+        return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) {
+        return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) {
+        return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory);
+    }
+    
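+    // Each linspace factory above comes in two forms: one packing dtype/layout/device/
+    // pin_memory into at::TensorOptions, and one taking the four optionals unpacked exactly
+    // as the dispatcher schema does. A hedged sketch using the public at::linspace wrapper
+    // (illustrative only, not a declaration of this header):
+    //
+    //   at::Tensor xs = at::linspace(0.0, 1.0, /*steps=*/5,
+    //                                at::TensorOptions().dtype(at::kDouble));
+    //   // xs == [0.00, 0.25, 0.50, 0.75, 1.00]
+    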
+    // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps) {
+        return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::Tensor & out) {
+        return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps) {
+        return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::Tensor & out) {
+        return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps) {
+        return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::Tensor & out) {
+        return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps) {
+        return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::Tensor & out) {
+        return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out);
+    }
+    
+    // aten::log(Tensor self) -> Tensor
+    inline at::Tensor log(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & log_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::log_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::log_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::log_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log10(Tensor self) -> Tensor
+    inline at::Tensor log10(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log10::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log10_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & log10_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::log10_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log10_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::log10_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log10_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::log10_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log1p(Tensor self) -> Tensor
+    inline at::Tensor log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log1p::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log1p_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & log1p_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::log1p_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log2(Tensor self) -> Tensor
+    inline at::Tensor log2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log2::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log2_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & log2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::log2_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::log2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::log2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logaddexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logaddexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logaddexp(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor logaddexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logaddexp::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logaddexp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logaddexp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logaddexp2(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor logaddexp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::logaddexp2::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::xlogy_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::xlogy_Scalar_Self::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::xlogy_Scalar_Other::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::xlogy__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::xlogy__Scalar_Other::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) {
+        return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) {
+        return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) {
+        return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) {
+        return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0) {
+        return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) {
+        return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0) {
+        return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) {
+        return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0) {
+        return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) {
+        return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0) {
+        return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) {
+        return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out);
+    }
+    
+    // aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
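+    // log_softmax's optional dtype casts the input before the computation (commonly to
+    // float for half-precision inputs to keep the reduction stable); c10::nullopt keeps
+    // the input dtype. A hedged sketch with the public at::log_softmax wrapper
+    // (illustrative only, not a declaration of this header):
+    //
+    //   at::Tensor logits = at::randn({4, 10}).to(at::kHalf);
+    //   at::Tensor logp = at::log_softmax(logits, /*dim=*/1, at::kFloat);
+    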
+    // aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+    inline at::Tensor _log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float);
+    }
+    
+    // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) {
+        return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+    inline at::Tensor _log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) {
+        return at::_ops::_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype);
+    }
+    
+    // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) {
+        return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out);
+    }
+    
+    // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & out) {
+        return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out);
+    }
+    
+    // aten::_logcumsumexp(Tensor self, int dim) -> Tensor
+    inline at::Tensor _logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::_logcumsumexp::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) {
+        return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) {
+        return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::logcumsumexp(Tensor self, int dim) -> Tensor
+    inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::logcumsumexp::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) {
+        return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) {
+        return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor
+    inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::logcumsumexp_dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & out) {
+        return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::logsumexp::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) {
+        return at::_ops::logsumexp_names::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) {
+        return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
+    inline at::Tensor margin_ranking_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::margin_ranking_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction);
+    }
+    
+    // aten::matmul(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::matmul::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> matmul_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array<bool,2> mask) {
+        return at::_ops::matmul_backward::redispatch(dispatchKeySet, grad, self, other, mask);
+    }
+    
+    // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::matrix_power(Tensor self, int n) -> Tensor
+    inline at::Tensor matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) {
+        return at::_ops::matrix_power::redispatch(dispatchKeySet, self, n);
+    }
+    
+    // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) {
+        return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out);
+    }
+    
+    // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) {
+        return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out);
+    }
+    
+    // aten::matrix_exp(Tensor self) -> Tensor
+    inline at::Tensor matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::matrix_exp::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor
+    inline at::Tensor matrix_exp_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad) {
+        return at::_ops::matrix_exp_backward::redispatch(dispatchKeySet, self, grad);
+    }
+    
+    // aten::_aminmax(Tensor self) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_aminmax::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::_aminmax_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)
+    inline ::std::tuple<at::Tensor,at::Tensor> aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::aminmax::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & max, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
+        return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max);
+    }
+    
+    // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim, bool keepdim, at::Tensor & min, at::Tensor & max) {
+        return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max);
+    }
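+    // Note (added): aminmax computes the (min, max) pair in a single pass and returns named outputs;
+    // the underscore-prefixed _aminmax overloads above are the older internal variants of the same reduction.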
+    
+    // aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
+    inline at::Tensor _compute_linear_combination(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients) {
+        return at::_ops::_compute_linear_combination::redispatch(dispatchKeySet, input, coefficients);
+    }
+    
+    // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _compute_linear_combination_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & coefficients) {
+        return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out);
+    }
+    
+    // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _compute_linear_combination_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients, at::Tensor & out) {
+        return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out);
+    }
+    
+    // aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::max_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values);
+    }
+    
+    // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) {
+        return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values);
+    }
+    
+    // aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::max_names_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values);
+    }
+    
+    // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) {
+        return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values);
+    }
+    
+    // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
+    inline at::Tensor value_selecting_reduction_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim) {
+        return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim);
+    }
+    
+    // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
+    inline at::Tensor value_selecting_reduction_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim) {
+        return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, sizes, keepdim);
+    }
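+    // Note (added): ops declared with SymInt arguments get two wrappers here: the plain overload converts
+    // concrete integer sizes with c10::fromIntArrayRefSlow, while the "*_symint" overload passes
+    // c10::SymIntArrayRef through unchanged for symbolic-shape tracing.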
+    
+    // aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+    inline at::Tensor amax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) {
+        return at::_ops::amax::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & amax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) {
+        return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & amax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> max_pool1d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool1d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor mkldnn_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor mkldnn_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor mkldnn_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor mkldnn_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor quantized_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor quantized_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor quantized_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+    inline at::Tensor max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
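+    // Note (added): for the pooling ops above, an empty stride list (the default shown in the schema comments)
+    // is interpreted by the kernels as stride == kernel_size.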
+    
+    // aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::mean::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::mean_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::mean_names_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor nanmean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::nanmean::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanmean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanmean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::median(Tensor self) -> Tensor
+    inline at::Tensor median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::median::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::median_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::median_names_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::nanmedian(Tensor self) -> Tensor
+    inline at::Tensor nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::nanmedian::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::nanmedian_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::nanmedian_names_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::min_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices);
+    }
+    
+    // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) {
+        return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices);
+    }
+    
+    // aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::min_names_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices);
+    }
+    
+    // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) {
+        return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices);
+    }
+    
+    // aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+    inline at::Tensor amin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) {
+        return at::_ops::amin::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & amin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) {
+        return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & amin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _mps_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor _mps_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups);
+    }
+    
+    // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask);
+    }
+    
+    // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask);
+    }
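+    // Note (added): output_mask selects which of the gradients (input, weight, bias) the backward actually
+    // computes; entries set to false are typically returned as undefined tensors.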
+    
+    // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor mkldnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor mkldnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups);
+    }
+    
+    // aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> mkldnn_rnn_layer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) {
+        return at::_ops::mkldnn_rnn_layer::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
+    }
+    
+    // aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> mkldnn_rnn_layer_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) {
+        return at::_ops::mkldnn_rnn_layer_backward::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace);
+    }
+    
+    // aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon) {
+        return at::_ops::miopen_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon);
+    }
+    
+    // aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon) {
+        return at::_ops::miopen_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon);
+    }
+    
+    // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_depthwise_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+    inline at::Tensor miopen_depthwise_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic);
+    }
+    
+    // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor miopen_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor miopen_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor miopen_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+    inline at::Tensor miopen_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups);
+    }
+    
+    // aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> miopen_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::miopen_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state);
+    }
+    
+    // aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,::std::vector<at::Tensor>> miopen_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask) {
+        return at::_ops::miopen_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask);
+    }
+    
+    // aten::mm(Tensor self, Tensor mat2) -> Tensor
+    inline at::Tensor mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::mm::redispatch(dispatchKeySet, self, mat2);
+    }
+    
+    // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
+    
+    // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) {
+        return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
+    
+    // aten::_int_mm(Tensor self, Tensor mat2) -> Tensor
+    inline at::Tensor _int_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::_int_mm::redispatch(dispatchKeySet, self, mat2);
+    }
+    
+    // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _int_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
+    
+    // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _int_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) {
+        return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out);
+    }
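+    // Note (added): _int_mm multiplies two int8 matrices and accumulates into an int32 result;
+    // the int4-packed and int8 weight variants below follow the same redispatch pattern.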
+    
+    // aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
+    inline at::Tensor _convert_weight_to_int4pack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t innerKTiles) {
+        return at::_ops::_convert_weight_to_int4pack::redispatch(dispatchKeySet, self, innerKTiles);
+    }
+    
+    // aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+    inline at::Tensor _weight_int4pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) {
+        return at::_ops::_weight_int4pack_mm::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScaleAndZeros);
+    }
+    
+    // aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
+    inline at::Tensor _weight_int8pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales) {
+        return at::_ops::_weight_int8pack_mm::redispatch(dispatchKeySet, self, mat2, scales);
+    }
+    
+    // aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor
+    inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & dense) {
+        return at::_ops::_sparse_mm::redispatch(dispatchKeySet, sparse, dense);
+    }
+    
+    // aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor
+    inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce) {
+        return at::_ops::_sparse_mm_reduce::redispatch(dispatchKeySet, sparse, dense, reduce);
+    }
+    
+    // aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor _sparse_sparse_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::_sparse_sparse_matmul::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) {
+        return at::_ops::mode::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) {
+        return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::mode_dimname::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) {
+        return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices);
+    }
+    
+    // aten::mul.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::mul_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::mul__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::mul.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::mul_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::mul__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::multiply_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::multiply__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multiply_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multiply_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::multiply_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::multiply__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::mv(Tensor self, Tensor vec) -> Tensor
+    inline at::Tensor mv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec) {
+        return at::_ops::mv::redispatch(dispatchKeySet, self, vec);
+    }
+    
+    // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec) {
+        return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out);
+    }
+    
+    // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec, at::Tensor & out) {
+        return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out);
+    }
+    
+    // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mvlgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) {
+        return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mvlgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) {
+        return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::mvlgamma(Tensor self, int p) -> Tensor
+    inline at::Tensor mvlgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) {
+        return at::_ops::mvlgamma::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)
+    inline at::Tensor & mvlgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t p) {
+        return at::_ops::mvlgamma_::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
+    inline at::Tensor narrow_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) {
+        return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
+    inline at::Tensor narrow_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
+        return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & narrow_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) {
+        return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out);
+    }
+    
+    // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & narrow_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length, at::Tensor & out) {
+        return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out);
+    }
+    
+    // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & narrow_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
+        return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out);
+    }
+    
+    // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & narrow_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, at::Tensor & out) {
+        return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out);
+    }
+    
+    // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+    inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) {
+        return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+    inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
+        return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+    inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, int64_t length) {
+        return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+    inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length) {
+        return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length);
+    }
+    
+    // aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps) {
+        return at::_ops::native_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps);
+    }
+    
+    // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps) {
+        return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) {
+        return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps);
+    }
+    
+    // aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_training(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_no_training::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps);
+    }
+    
+    // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) {
+        return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, bool training, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_no_stats::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps);
+    }
+    
+    // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, bool training, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) {
+        return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd);
+    }
+    
+    // aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> batch_norm_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps) {
+        return at::_ops::batch_norm_stats::redispatch(dispatchKeySet, input, eps);
+    }
+    
+    // aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
+    inline at::Tensor batch_norm_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) {
+        return at::_ops::batch_norm_elemt::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps);
+    }
+    
+    // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & batch_norm_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) {
+        return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out);
+    }
+    
+    // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & batch_norm_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out) {
+        return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out);
+    }
+    
+    // aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> batch_norm_gather_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, int64_t count) {
+        return at::_ops::batch_norm_gather_stats::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count);
+    }
+    
+    // aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> batch_norm_gather_stats_with_counts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, const at::Tensor & counts) {
+        return at::_ops::batch_norm_gather_stats_with_counts::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts);
+    }
+    
+    // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_batch_norm_backward::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask);
+    }
+    
+    // aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g) {
+        return at::_ops::batch_norm_backward_reduce::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g);
+    }
+    
+    // aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor
+    inline at::Tensor batch_norm_backward_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) {
+        return at::_ops::batch_norm_backward_elemt::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count);
+    }
+    
+    // aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> batch_norm_update_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum) {
+        return at::_ops::batch_norm_update_stats::redispatch(dispatchKeySet, input, running_mean, running_var, momentum);
+    }
+    
+    // aten::is_vulkan_available() -> bool
+    inline bool is_vulkan_available(c10::DispatchKeySet dispatchKeySet) {
+        return at::_ops::is_vulkan_available::redispatch(dispatchKeySet);
+    }
+    
+    // aten::_nnpack_available() -> bool
+    inline bool _nnpack_available(c10::DispatchKeySet dispatchKeySet) {
+        return at::_ops::_nnpack_available::redispatch(dispatchKeySet);
+    }
+    
+    // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor
+    inline at::Tensor _nnpack_spatial_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) {
+        return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor
+    inline at::Tensor _nnpack_spatial_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) {
+        return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, padding, stride);
+    }
+    
+    // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::ones::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::ones::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) {
+        return at::_ops::ones_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::ones_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::ones_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::ones_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
+    inline at::Tensor pairwise_distance(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, double eps=1e-06, bool keepdim=false) {
+        return at::_ops::pairwise_distance::redispatch(dispatchKeySet, x1, x2, p, eps, keepdim);
+    }
+    
+    // aten::cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
+    inline at::Tensor cdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, c10::optional<int64_t> compute_mode=c10::nullopt) {
+        return at::_ops::cdist::redispatch(dispatchKeySet, x1, x2, p, compute_mode);
+    }
+    
+    // aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+    inline at::Tensor _euclidean_dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2) {
+        return at::_ops::_euclidean_dist::redispatch(dispatchKeySet, x1, x2);
+    }
+    
+    // aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
+    inline at::Tensor _cdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode) {
+        return at::_ops::_cdist_forward::redispatch(dispatchKeySet, x1, x2, p, compute_mode);
+    }
+    
+    // aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
+    inline at::Tensor _cdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) {
+        return at::_ops::_cdist_backward::redispatch(dispatchKeySet, grad, x1, x2, p, cdist);
+    }
+    
+    // aten::pdist(Tensor self, float p=2) -> Tensor
+    inline at::Tensor pdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) {
+        return at::_ops::pdist::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::_pdist_forward(Tensor self, float p=2) -> Tensor
+    inline at::Tensor _pdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) {
+        return at::_ops::_pdist_forward::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor
+    inline at::Tensor _pdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) {
+        return at::_ops::_pdist_backward::redispatch(dispatchKeySet, grad, self, p, pdist);
+    }
+    
+    // aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor
+    inline at::Tensor cosine_similarity(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, int64_t dim=1, double eps=1e-08) {
+        return at::_ops::cosine_similarity::redispatch(dispatchKeySet, x1, x2, dim, eps);
+    }
+    
+    // aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)
+    inline at::Tensor permute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::permute::redispatch(dispatchKeySet, self, dims);
+    }
+    
+    // aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+    inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) {
+        return at::_ops::movedim_intlist::redispatch(dispatchKeySet, self, source, destination);
+    }
+    
+    // aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+    inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) {
+        return at::_ops::movedim_int::redispatch(dispatchKeySet, self, source, destination);
+    }
+    
+    // aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+    inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) {
+        return at::_ops::moveaxis_intlist::redispatch(dispatchKeySet, self, source, destination);
+    }
+    
+    // aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+    inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) {
+        return at::_ops::moveaxis_int::redispatch(dispatchKeySet, self, source, destination);
+    }
+    
+    // aten::numpy_T(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor numpy_T(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::numpy_T::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::matrix_H(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor matrix_H(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::matrix_H::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::mT(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor mT(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::mT::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::mH(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor mH(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::mH::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::adjoint(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor adjoint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::adjoint::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
+    inline at::Tensor pixel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor) {
+        return at::_ops::pixel_shuffle::redispatch(dispatchKeySet, self, upscale_factor);
+    }
+    
+    // aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
+    inline at::Tensor pixel_unshuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor) {
+        return at::_ops::pixel_unshuffle::redispatch(dispatchKeySet, self, downscale_factor);
+    }
+    
+    // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor
+    inline at::Tensor channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) {
+        return at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups);
+    }
+    
+    // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor
+    inline at::Tensor channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) {
+        return at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups);
+    }
+    
+    // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor
+    inline at::Tensor native_channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) {
+        return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups);
+    }
+    
+    // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor
+    inline at::Tensor native_channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) {
+        return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups);
+    }
+    
+    // aten::is_pinned(Tensor self, Device? device=None) -> bool
+    inline bool is_pinned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Device> device=c10::nullopt) {
+        return at::_ops::is_pinned::redispatch(dispatchKeySet, self, device);
+    }
+    
+    // aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
+    inline at::Tensor pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Device> device=c10::nullopt) {
+        return at::_ops::pin_memory::redispatch(dispatchKeySet, self, device);
+    }
+    
+    // aten::_pin_memory(Tensor self, Device? device=None) -> Tensor
+    inline at::Tensor _pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Device> device=c10::nullopt) {
+        return at::_ops::_pin_memory::redispatch(dispatchKeySet, self, device);
+    }
+    
+    // aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor
+    inline at::Tensor pinverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond=1e-15) {
+        return at::_ops::pinverse::redispatch(dispatchKeySet, self, rcond);
+    }
+    
+    // aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor
+    inline at::Tensor poisson_nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction) {
+        return at::_ops::poisson_nll_loss::redispatch(dispatchKeySet, input, target, log_input, full, eps, reduction);
+    }
+    
+    // aten::rad2deg(Tensor self) -> Tensor
+    inline at::Tensor rad2deg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::rad2deg::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::rad2deg_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & rad2deg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::rad2deg_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rad2deg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rad2deg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::deg2rad(Tensor self) -> Tensor
+    inline at::Tensor deg2rad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::deg2rad::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::deg2rad_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & deg2rad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::deg2rad_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & deg2rad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & deg2rad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::TensorOptions options={}) {
+        return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::rand::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options={}) {
+        return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options={}) {
+        return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) {
+        return at::_ops::rand_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::rand_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out);
+    }
+    
+    // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out);
+    }
+    
+    // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::rand_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::rand_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint::redispatch(dispatchKeySet, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint::redispatch(dispatchKeySet, high, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size) {
+        return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size) {
+        return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out);
+    }
+    
+    // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out);
+    }
+    
+    // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out);
+    }
+    
+    // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out);
+    }
+    
+    // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size) {
+        return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size) {
+        return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out);
+    }
+    
+    // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out);
+    }
+    
+    // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out);
+    }
+    
+    // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out);
+    }
+    
+    // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::randn::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options={}) {
+        return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::TensorOptions options={}) {
+        return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) {
+        return at::_ops::randn_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::randn_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::Generator> generator) {
+        return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out);
+    }
+    
+    // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out);
+    }
+    
+    // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randn_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::randn_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options=at::kLong) {
+        return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options=at::kLong) {
+        return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::optional<at::Generator> generator, at::TensorOptions options=at::kLong) {
+        return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) {
+        return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) {
+        return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) {
+        return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) {
+        return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out);
+    }
+    
+    // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, c10::optional<at::Generator> generator) {
+        return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out);
+    }
+    
+    // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out);
+    }
+    
+    // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, c10::optional<at::Generator> generator) {
+        return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out);
+    }
+    
+    // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out);
+    }
+    
+    // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step=1, at::TensorOptions options={}) {
+        return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) {
+        return at::_ops::range::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::range::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end) {
+        return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out);
+    }
+    
+    // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::Tensor & out) {
+        return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out);
+    }
+    
+    // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) {
+        return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out);
+    }
+    
+    // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) {
+        return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out);
+    }
+    
+    // aten::ravel(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor ravel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::ravel::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::reciprocal(Tensor self) -> Tensor
+    inline at::Tensor reciprocal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::reciprocal::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & reciprocal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::reciprocal_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reciprocal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::neg(Tensor self) -> Tensor
+    inline at::Tensor neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::neg::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::neg_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & neg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::neg_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & neg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::neg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & neg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::neg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::negative(Tensor self) -> Tensor
+    inline at::Tensor negative(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::negative::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::negative_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & negative_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::negative_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & negative_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::negative_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & negative_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::negative_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor
+    inline at::Tensor repeat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats) {
+        return at::_ops::repeat::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats));
+    }
+    
+    // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor
+    inline at::Tensor repeat_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats) {
+        return at::_ops::repeat::redispatch(dispatchKeySet, self, repeats);
+    }
+    
+    // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt);
+    }
+    
+    // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, c10::optional<c10::SymInt> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size);
+    }
+    
+    // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt);
+    }
+    
+    // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, c10::optional<int64_t> dim=c10::nullopt, c10::optional<c10::SymInt> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size);
+    }
+    
+    // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t repeats, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt);
+    }
+    
+    // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+    inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt repeats, c10::optional<int64_t> dim=c10::nullopt, c10::optional<c10::SymInt> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size);
+    }
+    
+    // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+    inline at::Tensor reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) {
+        return at::_ops::reshape::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shape));
+    }
+    
+    // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+    inline at::Tensor reshape_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shape) {
+        return at::_ops::reshape::redispatch(dispatchKeySet, self, shape);
+    }
+    
+    // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor _reshape_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor _reshape_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+    inline at::Tensor _reshape_alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) {
+        return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+    inline at::Tensor _reshape_alias_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
+        return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, size, stride);
+    }
+    
+    // aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor
+    inline at::Tensor _mkldnn_reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) {
+        return at::_ops::_mkldnn_reshape::redispatch(dispatchKeySet, self, shape);
+    }
+    
+    // aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)
+    inline at::Tensor reshape_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::reshape_as::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::round(Tensor self) -> Tensor
+    inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::round::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::round_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::round_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::round_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::round_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::round.decimals(Tensor self, *, int decimals) -> Tensor
+    inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals) {
+        return at::_ops::round_decimals::redispatch(dispatchKeySet, self, decimals);
+    }
+    
+    // aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!)
+    inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t decimals) {
+        return at::_ops::round__decimals::redispatch(dispatchKeySet, self, decimals);
+    }
+    
+    // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals) {
+        return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out);
+    }
+    
+    // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) {
+        return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out);
+    }
+    
+    // aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+    inline at::Tensor rrelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::rrelu::redispatch(dispatchKeySet, self, lower, upper, training, generator);
+    }
+    
+    // aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & rrelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::rrelu_::redispatch(dispatchKeySet, self, lower, upper, training, generator);
+    }
+    
+    // aten::relu(Tensor self) -> Tensor
+    inline at::Tensor relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::relu::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::relu_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::relu_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::relu6(Tensor self) -> Tensor
+    inline at::Tensor relu6(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::relu6::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::relu6_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & relu6_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::relu6_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::prelu(Tensor self, Tensor weight) -> Tensor
+    inline at::Tensor prelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) {
+        return at::_ops::prelu::redispatch(dispatchKeySet, self, weight);
+    }
+    
+    // aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor
+    inline at::Tensor _prelu_kernel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) {
+        return at::_ops::_prelu_kernel::redispatch(dispatchKeySet, self, weight);
+    }
+    
+    // aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _prelu_kernel_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight) {
+        return at::_ops::_prelu_kernel_backward::redispatch(dispatchKeySet, grad_output, self, weight);
+    }
+    
+    // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gelu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view approximate="none") {
+        return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out);
+    }
+    
+    // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gelu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate, at::Tensor & out) {
+        return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out);
+    }
+    
+    // aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)
+    inline at::Tensor & gelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::string_view approximate="none") {
+        return at::_ops::gelu_::redispatch(dispatchKeySet, self, approximate);
+    }
+    
+    // aten::gelu(Tensor self, *, str approximate='none') -> Tensor
+    inline at::Tensor gelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate="none") {
+        return at::_ops::gelu::redispatch(dispatchKeySet, self, approximate);
+    }
+    
+    // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & gelu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") {
+        return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input);
+    }
+    
+    // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & gelu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate, at::Tensor & grad_input) {
+        return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input);
+    }
+    
+    // aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
+    inline at::Tensor gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") {
+        return at::_ops::gelu_backward::redispatch(dispatchKeySet, grad_output, self, approximate);
+    }
+    
+    // aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor
+    inline at::Tensor infinitely_differentiable_gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self) {
+        return at::_ops::infinitely_differentiable_gelu_backward::redispatch(dispatchKeySet, grad, self);
+    }
+    
+    // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) {
+        return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out);
+    }
+    
+    // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) {
+        return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out);
+    }
+    
+    // aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+    inline at::Tensor hardshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) {
+        return at::_ops::hardshrink::redispatch(dispatchKeySet, self, lambd);
+    }
+    
+    // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) {
+        return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input);
+    }
+    
+    // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) {
+        return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input);
+    }
+    
+    // aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
+    inline at::Tensor hardshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) {
+        return at::_ops::hardshrink_backward::redispatch(dispatchKeySet, grad_out, self, lambd);
+    }
+    
+    // aten::rsqrt(Tensor self) -> Tensor
+    inline at::Tensor rsqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::rsqrt::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & rsqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::rsqrt_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)
+    inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, int64_t index) {
+        return at::_ops::select_Dimname::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+    inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) {
+        return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+    inline at::Tensor select_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) {
+        return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index);
+    }
+    
+    // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index);
+    }
+    
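+    // Note: ops whose schema uses SymInt are also emitted twice. The plain overload
+    // accepts int64_t / at::IntArrayRef and converts (directly, or via
+    // c10::fromIntArrayRefSlow for arrays) before redispatching, while the `_symint`
+    // overload passes c10::SymInt values through unchanged; both reach the same
+    // at::_ops::select_backward::redispatch call.
+    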
+    // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor
+    inline at::Tensor _nested_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, int64_t index) {
+        return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index);
+    }
+    
+    // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor
+    inline at::Tensor _nested_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index) {
+        return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index);
+    }
+    
+    // aten::selu(Tensor self) -> Tensor
+    inline at::Tensor selu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::selu::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & selu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::selu_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor
+    inline at::Tensor celu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1.0) {
+        return at::_ops::celu::redispatch(dispatchKeySet, self, alpha);
+    }
+    
+    // aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
+    inline at::Tensor & celu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1.0) {
+        return at::_ops::celu_::redispatch(dispatchKeySet, self, alpha);
+    }
+    
+    // aten::silu(Tensor self) -> Tensor
+    inline at::Tensor silu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::silu::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::silu_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & silu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::silu_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & silu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::silu_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & silu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::silu_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & silu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & silu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) {
+        return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor silu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::silu_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::mish(Tensor self) -> Tensor
+    inline at::Tensor mish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::mish::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::mish_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & mish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::mish_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::mish_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::mish_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor mish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::mish_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::sigmoid(Tensor self) -> Tensor
+    inline at::Tensor sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sigmoid::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sigmoid_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::logit(Tensor self, float? eps=None) -> Tensor
+    inline at::Tensor logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::logit::redispatch(dispatchKeySet, self, eps);
+    }
+    
+    // aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
+    inline at::Tensor & logit_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::logit_::redispatch(dispatchKeySet, self, eps);
+    }
+    
+    // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out);
+    }
+    
+    // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> eps, at::Tensor & out) {
+        return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out);
+    }
+    
+    // aten::sin(Tensor self) -> Tensor
+    inline at::Tensor sin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sin::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sin_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sin_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sinc(Tensor self) -> Tensor
+    inline at::Tensor sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sinc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sinc_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sinc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sinc_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sinh(Tensor self) -> Tensor
+    inline at::Tensor sinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sinh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sinh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sinh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::detach(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor detach(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::detach::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::detach_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & detach_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::detach_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::size.int(Tensor self, int dim) -> int
+    inline int64_t __dispatch_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::size_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::size.Dimname(Tensor self, Dimname dim) -> int
+    inline int64_t size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::size_Dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::sym_size.int(Tensor self, int dim) -> SymInt
+    inline c10::SymInt __dispatch_sym_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::sym_size_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::sym_numel(Tensor self) -> SymInt
+    inline c10::SymInt __dispatch_sym_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sym_numel::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sym_storage_offset(Tensor self) -> SymInt
+    inline c10::SymInt __dispatch_sym_storage_offset(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sym_storage_offset::redispatch(dispatchKeySet, self);
+    }
+    
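+    // Note: size/stride and their sym_* counterparts are emitted under a
+    // `__dispatch_` prefix. The reason is not visible in this hunk; presumably the
+    // unprefixed names are reserved for hand-written helpers elsewhere in ATen, so
+    // the generated redispatch entry points avoid colliding with them.
+    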
+    // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+    inline at::Tensor slice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+    }
+    
+    // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+    inline at::Tensor slice_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step);
+    }
+    
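+    // Note: in the non-symint slice overload above, each optional int64_t bound is
+    // wrapped at the call site
+    // (start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt),
+    // so an absent bound is forwarded as c10::nullopt rather than a concrete value.
+    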
+    // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
+    inline at::Tensor slice_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+        return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step);
+    }
+    
+    // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
+    inline at::Tensor slice_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
+        return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step);
+    }
+    
+    // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+    inline at::Tensor slice_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+    }
+    
+    // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+    inline at::Tensor slice_inverse_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start, end, step);
+    }
+    
+    // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+    inline at::Tensor slice_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+    }
+    
+    // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+    inline at::Tensor slice_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start, end, step);
+    }
+    
+    // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) {
+        return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index);
+    }
+    
+    // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index);
+    }
+    
+    // aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
+    inline at::Tensor diagonal_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) {
+        return at::_ops::diagonal_scatter::redispatch(dispatchKeySet, self, src, offset, dim1, dim2);
+    }
+    
+    // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+    inline at::Tensor as_strided_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+    }
+    
+    // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+    inline at::Tensor as_strided_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, size, stride, storage_offset);
+    }
+    
+    // aten::smm(Tensor self, Tensor mat2) -> Tensor
+    inline at::Tensor smm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) {
+        return at::_ops::smm::redispatch(dispatchKeySet, self, mat2);
+    }
+    
+    // aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::softmax_int::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out);
+    }
+    
+    // aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+    inline at::Tensor _softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_softmax::redispatch(dispatchKeySet, self, dim, half_to_float);
+    }
+    
+    // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) {
+        return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+    inline at::Tensor _softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) {
+        return at::_ops::_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype);
+    }
+    
+    // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) {
+        return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input);
+    }
+    
+    // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & grad_input) {
+        return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input);
+    }
+    
+    // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unsafe_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) {
+        return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unsafe_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) {
+        return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) {
+        return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) {
+        return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_size, int64_t dim=0) {
+        return at::_ops::split_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_size), dim);
+    }
+    
+    // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim=0) {
+        return at::_ops::split_sizes::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unsafe_split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim);
+    }
+    
+    // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unsafe_split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim);
+    }
+    
+    // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim);
+    }
+    
+    // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim);
+    }
+    
+    // aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) {
+        return at::_ops::hsplit_int::redispatch(dispatchKeySet, self, sections);
+    }
+    
+    // aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) {
+        return at::_ops::hsplit_array::redispatch(dispatchKeySet, self, indices);
+    }
+    
+    // aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) {
+        return at::_ops::vsplit_int::redispatch(dispatchKeySet, self, sections);
+    }
+    
+    // aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) {
+        return at::_ops::vsplit_array::redispatch(dispatchKeySet, self, indices);
+    }
+    
+    // aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) {
+        return at::_ops::dsplit_int::redispatch(dispatchKeySet, self, sections);
+    }
+    
+    // aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) {
+        return at::_ops::dsplit_array::redispatch(dispatchKeySet, self, indices);
+    }
+    
+    // aten::squeeze(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::squeeze::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
+    inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::squeeze_dim::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)
+    inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::squeeze_dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+    inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::squeeze_dims::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::squeeze_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
+    inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) {
+        return at::_ops::squeeze__dim::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+    inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::squeeze__dims::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)
+    inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim) {
+        return at::_ops::squeeze__dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor sspaddmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::sspaddmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha);
+    }
+    
+    // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sspaddmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sspaddmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
+    inline at::Tensor _chunk_cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks) {
+        return at::_ops::_chunk_cat::redispatch(dispatchKeySet, tensors, dim, num_chunks);
+    }
+    
+    // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _chunk_cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim, int64_t num_chunks) {
+        return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out);
+    }
+    
+    // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _chunk_cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out) {
+        return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out);
+    }
+    
+    // aten::stack(Tensor[] tensors, int dim=0) -> Tensor
+    inline at::Tensor stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::stack::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) {
+        return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::_stack(Tensor[] tensors, int dim=0) -> Tensor
+    inline at::Tensor _stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::_stack::redispatch(dispatchKeySet, tensors, dim);
+    }
+    
+    // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) {
+        return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) {
+        return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out);
+    }
+    
+    // aten::hstack(Tensor[] tensors) -> Tensor
+    inline at::Tensor hstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::hstack::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::vstack(Tensor[] tensors) -> Tensor
+    inline at::Tensor vstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::vstack::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & vstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & vstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::dstack(Tensor[] tensors) -> Tensor
+    inline at::Tensor dstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::dstack::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+    inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<at::Tensor> & window, bool normalized, c10::optional<bool> onesided=c10::nullopt, c10::optional<bool> return_complex=c10::nullopt) {
+        return at::_ops::stft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, normalized, onesided, return_complex);
+    }
+    
+    // aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+    inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const c10::optional<at::Tensor> & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, c10::optional<bool> onesided=c10::nullopt, c10::optional<bool> return_complex=c10::nullopt) {
+        return at::_ops::stft_center::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex);
+    }
+    
+    // aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
+    inline at::Tensor istft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const c10::optional<at::Tensor> & window={}, bool center=true, bool normalized=false, c10::optional<bool> onesided=c10::nullopt, c10::optional<int64_t> length=c10::nullopt, bool return_complex=false) {
+        return at::_ops::istft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex);
+    }
+    
+    // aten::stride.int(Tensor self, int dim) -> int
+    inline int64_t __dispatch_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::stride_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::stride.Dimname(Tensor self, Dimname dim) -> int
+    inline int64_t stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::stride_Dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::sym_stride.int(Tensor self, int dim) -> SymInt
+    inline c10::SymInt __dispatch_sym_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::sym_stride_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum_dim_IntList::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum_dim_DimnameList::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
+    inline at::Tensor _nested_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) {
+        return at::_ops::_nested_sum_backward::redispatch(dispatchKeySet, grad, self, dim, keepdim);
+    }
+    
+    // aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor nansum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::nansum::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nansum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nansum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor sum_to_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor sum_to_size_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::sqrt(Tensor self) -> Tensor
+    inline at::Tensor sqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sqrt::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sqrt_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sqrt_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::square(Tensor self) -> Tensor
+    inline at::Tensor square(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::square::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::square_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & square_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::square_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & square_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::square_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & square_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::square_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::std(Tensor self, bool unbiased=True) -> Tensor
+    inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) {
+        return at::_ops::std::redispatch(dispatchKeySet, self, unbiased);
+    }
+    
+    // aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+    inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+    inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) {
+        return at::_ops::std_mean::redispatch(dispatchKeySet, self, unbiased);
+    }
+    
+    // aten::std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) {
+        return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out) {
+        return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+    inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) {
+        return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+    inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out) {
+        return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
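+    // Editorial note (illustrative only, not part of the generated code): the wrappers
+    // above come in pairs -- `std_out` takes the output tensor first and keeps trailing
+    // defaults, while `std_outf` takes every argument explicitly with `out` last.
+    // A minimal sketch, assuming this header is the at::redispatch layer as in upstream
+    // PyTorch and that `ks`, `self`, and `out` are placeholders held by the caller:
+    //
+    //   at::Tensor out = at::empty({0}, self.options());
+    //   at::redispatch::std_outf(ks, self, /*dim=*/c10::nullopt,
+    //                            /*correction=*/c10::nullopt, /*keepdim=*/false, out);
+    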
+    // aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod_dim_int::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod_dim_Dimname::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::t(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::t::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::t_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & t_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::t_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::tan(Tensor self) -> Tensor
+    inline at::Tensor tan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::tan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::tan_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & tan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::tan_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::tan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::tan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::tanh(Tensor self) -> Tensor
+    inline at::Tensor tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::tanh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::tanh_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & tanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::tanh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor
+    inline at::Tensor tensordot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) {
+        return at::_ops::tensordot::redispatch(dispatchKeySet, self, other, dims_self, dims_other);
+    }
+    
+    // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tensordot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) {
+        return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out);
+    }
+    
+    // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tensordot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other, at::Tensor & out) {
+        return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out);
+    }
+    
+    // aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
+    inline at::Tensor threshold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) {
+        return at::_ops::threshold::redispatch(dispatchKeySet, self, threshold, value);
+    }
+    
+    // aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
+    inline at::Tensor & threshold_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) {
+        return at::_ops::threshold_::redispatch(dispatchKeySet, self, threshold, value);
+    }
+    
+    // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & threshold_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) {
+        return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out);
+    }
+    
+    // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & threshold_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out);
+    }
+    
+    // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & threshold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) {
+        return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input);
+    }
+    
+    // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & threshold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold, at::Tensor & grad_input) {
+        return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input);
+    }
+    
+    // aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
+    inline at::Tensor threshold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) {
+        return at::_ops::threshold_backward::redispatch(dispatchKeySet, grad_output, self, threshold);
+    }
+    
+    // aten::tile(Tensor self, SymInt[] dims) -> Tensor
+    inline at::Tensor tile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::tile::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dims));
+    }
+    
+    // aten::tile(Tensor self, SymInt[] dims) -> Tensor
+    inline at::Tensor tile_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dims) {
+        return at::_ops::tile::redispatch(dispatchKeySet, self, dims);
+    }
+    
+    // aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+    inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::transpose_int::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)
+    inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim0, at::Dimname dim1) {
+        return at::_ops::transpose_Dimname::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor
+    inline at::Tensor _mkldnn_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::_mkldnn_transpose::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+    inline at::Tensor & transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::transpose_::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::_mkldnn_transpose_::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::one_hot(Tensor self, int num_classes=-1) -> Tensor
+    inline at::Tensor one_hot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_classes=-1) {
+        return at::_ops::one_hot::redispatch(dispatchKeySet, self, num_classes);
+    }
+    
+    // aten::flip(Tensor self, int[] dims) -> Tensor
+    inline at::Tensor flip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::flip::redispatch(dispatchKeySet, self, dims);
+    }
+    
+    // aten::fliplr(Tensor self) -> Tensor
+    inline at::Tensor fliplr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::fliplr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::flipud(Tensor self) -> Tensor
+    inline at::Tensor flipud(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::flipud::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+    inline at::Tensor roll(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) {
+        return at::_ops::roll::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims);
+    }
+    
+    // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+    inline at::Tensor roll_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) {
+        return at::_ops::roll::redispatch(dispatchKeySet, self, shifts, dims);
+    }
+    
+    // aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
+    inline at::Tensor rot90(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) {
+        return at::_ops::rot90::redispatch(dispatchKeySet, self, k, dims);
+    }
+    
+    // aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+    inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) {
+        return at::_ops::trapezoid_x::redispatch(dispatchKeySet, y, x, dim);
+    }
+    
+    // aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+    inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) {
+        return at::_ops::trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim);
+    }
+    
+    // aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+    inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) {
+        return at::_ops::trapz_x::redispatch(dispatchKeySet, y, x, dim);
+    }
+    
+    // aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
+    inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, double dx=1, int64_t dim=-1) {
+        return at::_ops::trapz_dx::redispatch(dispatchKeySet, y, dx, dim);
+    }
+    
+    // aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _transform_bias_rescale_qkv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
+        return at::_ops::_transform_bias_rescale_qkv::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads);
+    }
+    
+    // aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor
+    inline at::Tensor _nested_tensor_from_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) {
+        return at::_ops::_nested_tensor_from_mask::redispatch(dispatchKeySet, t, mask, mask_check);
+    }
+    
+    // aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool
+    inline bool _nested_tensor_from_mask_left_aligned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask) {
+        return at::_ops::_nested_tensor_from_mask_left_aligned::redispatch(dispatchKeySet, t, mask);
+    }
+    
+    // aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
+    inline at::Tensor _nested_from_padded(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
+        return at::_ops::_nested_from_padded::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213);
+    }
+    
+    // aten::_nested_tensor_size(Tensor self) -> Tensor
+    inline at::Tensor _nested_tensor_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_size::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_tensor_strides(Tensor self) -> Tensor
+    inline at::Tensor _nested_tensor_strides(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_strides::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor
+    inline at::Tensor _nested_tensor_storage_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_storage_offsets::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor
+    inline at::Tensor _nested_from_padded_and_nested_example(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example) {
+        return at::_ops::_nested_from_padded_and_nested_example::redispatch(dispatchKeySet, padded, nt_example);
+    }
+    
+    // aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)
+    inline at::Tensor _nested_view_from_buffer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) {
+        return at::_ops::_nested_view_from_buffer::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets);
+    }
+    
+    // aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor
+    inline at::Tensor _nested_view_from_buffer_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) {
+        return at::_ops::_nested_view_from_buffer_copy::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets);
+    }
+    
+    // aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
+    inline at::Tensor _nested_view_from_jagged(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const c10::optional<at::Tensor> & lengths={}, int64_t ragged_idx=1) {
+        return at::_ops::_nested_view_from_jagged::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx);
+    }
+    
+    // aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
+    inline at::Tensor _nested_view_from_jagged_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const c10::optional<at::Tensor> & lengths={}, int64_t ragged_idx=1) {
+        return at::_ops::_nested_view_from_jagged_copy::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx);
+    }
+    
+    // aten::_nested_get_values(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _nested_get_values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_get_values::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_get_values_copy(Tensor self) -> Tensor
+    inline at::Tensor _nested_get_values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_get_values_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_get_offsets(Tensor self) -> Tensor
+    inline at::Tensor _nested_get_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_get_offsets::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_get_lengths(Tensor self) -> Tensor
+    inline at::Tensor _nested_get_lengths(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_get_lengths::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_get_ragged_idx(Tensor self) -> int
+    inline int64_t _nested_get_ragged_idx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nested_get_ragged_idx::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nested_get_jagged_dummy(Tensor any) -> Tensor
+    inline at::Tensor _nested_get_jagged_dummy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & any) {
+        return at::_ops::_nested_get_jagged_dummy::redispatch(dispatchKeySet, any);
+    }
+    
+    // aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
+    inline at::Tensor _trilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) {
+        return at::_ops::_trilinear::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
+    }
+    
+    // aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
+    inline at::Tensor triplet_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin=1.0, double p=2, double eps=1e-06, bool swap=false, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::triplet_margin_loss::redispatch(dispatchKeySet, anchor, positive, negative, margin, p, eps, swap, reduction);
+    }
+    
+    // aten::trunc(Tensor self) -> Tensor
+    inline at::Tensor trunc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::trunc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::trunc_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & trunc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::trunc_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & trunc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & trunc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::fix(Tensor self) -> Tensor
+    inline at::Tensor fix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::fix::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::fix_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & fix_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::fix_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fix_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::fix_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fix_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::fix_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::type_as(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor type_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::type_as::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool
+    inline bool _has_compatible_shallow_copy_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & from) {
+        return at::_ops::_has_compatible_shallow_copy_type::redispatch(dispatchKeySet, self, from);
+    }
+    
+    // aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _unique(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false) {
+        return at::_ops::_unique::redispatch(dispatchKeySet, self, sorted, return_inverse);
+    }
+    
+    // aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> unique_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::unique_dim::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts);
+    }
+    
+    // aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> unique_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::unique_consecutive::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim);
+    }
+    
+    // aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> unique_dim_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::unique_dim_consecutive::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts);
+    }
+    
+    // aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _unique2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::_unique2::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts);
+    }
+    
+    // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor _unsafe_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor _unsafe_view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+    inline at::Tensor unsqueeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::unsqueeze::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
+    inline at::Tensor & unsqueeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) {
+        return at::_ops::unsqueeze_::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor
+    inline at::Tensor vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, c10::optional<int64_t> N=c10::nullopt, bool increasing=false) {
+        return at::_ops::vander::redispatch(dispatchKeySet, x, N, increasing);
+    }
+    
+    // aten::var(Tensor self, bool unbiased=True) -> Tensor
+    inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) {
+        return at::_ops::var::redispatch(dispatchKeySet, self, unbiased);
+    }
+    
+    // aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+    inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+    inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) {
+        return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out) {
+        return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+    inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) {
+        return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out);
+    }
+    
+    // aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+    inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out) {
+        return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out);
+    }
+    
+    // aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) {
+        return at::_ops::var_mean::redispatch(dispatchKeySet, self, unbiased);
+    }
+    
+    // aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) {
+        return at::_ops::var_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim);
+    }
+    
+    // aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim);
+    }
+    
+    // aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)
+    inline at::Tensor view_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::view_as::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+    inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::where_self::redispatch(dispatchKeySet, condition, self, other);
+    }
+    
+    // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & where_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out);
+    }
+    
+    // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & where_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out);
+    }
+    
+    // aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor
+    inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::where_ScalarSelf::redispatch(dispatchKeySet, condition, self, other);
+    }
+    
+    // aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor
+    inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::where_ScalarOther::redispatch(dispatchKeySet, condition, self, other);
+    }
+    
+    // aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor
+    inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other) {
+        return at::_ops::where_Scalar::redispatch(dispatchKeySet, condition, self, other);
+    }
+    
+    // aten::where(Tensor condition) -> Tensor[]
+    inline ::std::vector<at::Tensor> where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition) {
+        return at::_ops::where::redispatch(dispatchKeySet, condition);
+    }
+    
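+    // Editorial note (illustrative only, not part of the generated code): `where` is
+    // overloaded on Tensor/Scalar operands, plus a single-argument form that returns the
+    // nonzero coordinates as ::std::vector<at::Tensor>, one index tensor per dimension.
+    // A hedged sketch, assuming the usual at::redispatch namespace and placeholder
+    // variables `ks`, `cond`, `a`, `b` supplied by the caller:
+    //
+    //   auto picked  = at::redispatch::where(ks, cond, a, b);   // elementwise select
+    //   auto indices = at::redispatch::where(ks, cond);         // coordinate tensors
+    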
+    // aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
+    inline at::Tensor norm_except_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, int64_t pow=2, int64_t dim=0) {
+        return at::_ops::norm_except_dim::redispatch(dispatchKeySet, v, pow, dim);
+    }
+    
+    // aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor
+    inline at::Tensor _weight_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) {
+        return at::_ops::_weight_norm::redispatch(dispatchKeySet, v, g, dim);
+    }
+    
+    // aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _weight_norm_interface(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) {
+        return at::_ops::_weight_norm_interface::redispatch(dispatchKeySet, v, g, dim);
+    }
+    
+    // aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _weight_norm_interface_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) {
+        return at::_ops::_weight_norm_interface_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim);
+    }
+    
+    // aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _weight_norm_differentiable_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) {
+        return at::_ops::_weight_norm_differentiable_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim);
+    }
+    
+    // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::TensorOptions options={}) {
+        return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::zeros::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::zeros::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
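+    // Editorial note (illustrative only, not part of the generated code): each factory
+    // above is emitted twice, once taking a packed at::TensorOptions and once taking the
+    // unpacked optional dtype/layout/device/pin_memory fields, with *_symint variants
+    // for symbolic shapes. A minimal sketch of equivalent calls, assuming the
+    // at::redispatch namespace from upstream PyTorch and a placeholder DispatchKeySet `ks`:
+    //
+    //   auto a = at::redispatch::zeros(ks, {2, 3}, at::TensorOptions().dtype(at::kFloat));
+    //   auto b = at::redispatch::zeros(ks, {2, 3}, at::kFloat, c10::nullopt,
+    //                                  c10::nullopt, c10::nullopt);
+    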
+    // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) {
+        return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::zeros_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::zeros_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format);
+    }
+    
+    // aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor
+    inline at::Tensor _standard_gamma_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output) {
+        return at::_ops::_standard_gamma_grad::redispatch(dispatchKeySet, self, output);
+    }
+    
+    // aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor
+    inline at::Tensor _standard_gamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_standard_gamma::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor
+    inline at::Tensor _dirichlet_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) {
+        return at::_ops::_dirichlet_grad::redispatch(dispatchKeySet, x, alpha, total);
+    }
+    
+    // aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
+    inline at::Tensor _sample_dirichlet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_sample_dirichlet::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::poisson(Tensor self, Generator? generator=None) -> Tensor
+    inline at::Tensor poisson(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::poisson::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor
+    inline at::Tensor binomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::binomial::redispatch(dispatchKeySet, count, prob, generator);
+    }
+    
+    // aten::native_norm(Tensor self, Scalar p=2) -> Tensor
+    inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) {
+        return at::_ops::native_norm::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
+    inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) {
+        return at::_ops::native_norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype);
+    }
+    
+    // aten::_sparse_sum(Tensor self) -> Tensor
+    inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_sparse_sum::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor
+    inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) {
+        return at::_ops::_sparse_sum_dtype::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor
+    inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::_sparse_sum_dim::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor
+    inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype) {
+        return at::_ops::_sparse_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor
+    inline at::Tensor _sparse_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::_sparse_sum_backward::redispatch(dispatchKeySet, grad, self, dim);
+    }
+    
+    // aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_csr_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_csr_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_csr_prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_csr_prod_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype);
+    }
+    
+    // aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_softmax_int::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+    inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_sparse_softmax::redispatch(dispatchKeySet, self, dim, half_to_float);
+    }
+    
+    // aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+    inline at::Tensor _sparse_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) {
+        return at::_ops::_sparse_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self);
+    }
+    
+    // aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+    inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_sparse_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float);
+    }
+    
+    // aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+    inline at::Tensor _sparse_log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) {
+        return at::_ops::_sparse_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self);
+    }
+    
+    // aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
+    inline at::Tensor _spdiags(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout=c10::nullopt) {
+        return at::_ops::_spdiags::redispatch(dispatchKeySet, diagonals, offsets, shape, layout);
+    }
+    
+    // aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::ScalarType dtype) {
+        return at::_ops::norm_ScalarOpt_dtype::redispatch(dispatchKeySet, self, p, dtype);
+    }
+    
+    // aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) {
+        return at::_ops::norm_Scalar::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) {
+        return at::_ops::norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype);
+    }
+    
+    // aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::norm_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim);
+    }
+    
+    // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) {
+        return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out);
+    }
+    
+    // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out);
+    }
+    
+    // aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) {
+        return at::_ops::norm_names_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype);
+    }
+    
+    // aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim=false) {
+        return at::_ops::norm_names_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim);
+    }
+    
+    // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) {
+        return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim=false) {
+        return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out);
+    }
+    
+    // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out);
+    }
+    
+    // aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
+    inline ::std::tuple<at::Tensor,at::Tensor> frexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::frexp_Tensor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> frexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & mantissa, at::Tensor & exponent, const at::Tensor & self) {
+        return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent);
+    }
+    
+    // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> frexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & mantissa, at::Tensor & exponent) {
+        return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent);
+    }
+    
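+    // Illustrative note (not part of the generated API): the wrappers in this header
+    // forward to at::_ops::<op>::redispatch with an explicit DispatchKeySet, which is
+    // how a kernel re-enters the dispatcher below its own key. A hedged usage sketch,
+    // assuming these functions live in the at::redispatch namespace (as in upstream
+    // ATen) and that `ks` is the DispatchKeySet handed to the calling kernel:
+    //
+    //   auto [mantissa, exponent] = at::redispatch::frexp(ks, x);
+    //   // elementwise: x == mantissa * exp2(exponent)
+    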
+    // aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor frobenius_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::frobenius_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & frobenius_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & frobenius_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor
+    inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim=false) {
+        return at::_ops::nuclear_norm::redispatch(dispatchKeySet, self, keepdim);
+    }
+    
+    // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool keepdim=false) {
+        return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out);
+    }
+    
+    // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim, at::Tensor & out) {
+        return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out);
+    }
+    
+    // aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::nuclear_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::clone::redispatch(dispatchKeySet, self, memory_format);
+    }
+    
+    // aten::positive(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor positive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::positive::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+    inline const at::Tensor & resize_as_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_as_::redispatch(dispatchKeySet, self, the_template, memory_format);
+    }
+    
+    // aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
+    inline const at::Tensor & resize_as_sparse_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) {
+        return at::_ops::resize_as_sparse_::redispatch(dispatchKeySet, self, the_template);
+    }
+    
+    // aten::zero_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & zero_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::zero_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub__Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+    inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub_Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub__Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
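+    // Illustrative note (not part of the generated API): sub computes
+    // self - alpha * other, and subtract below is an alias of sub. A hedged sketch,
+    // assuming the at::redispatch namespace, a DispatchKeySet `ks` and tensors `a`, `b`:
+    //
+    //   at::Tensor d = at::redispatch::sub(ks, a, b, /*alpha=*/2);  // d = a - 2*b
+    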
+    // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & subtract_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & subtract_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::subtract_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::subtract__Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+    inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::subtract_Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::subtract__Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::rsub_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & heaviside_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & values) {
+        return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out);
+    }
+    
+    // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & heaviside_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values, at::Tensor & out) {
+        return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out);
+    }
+    
+    // aten::heaviside(Tensor self, Tensor values) -> Tensor
+    inline at::Tensor heaviside(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values) {
+        return at::_ops::heaviside::redispatch(dispatchKeySet, self, values);
+    }
+    
+    // aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+    inline at::Tensor & heaviside_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & values) {
+        return at::_ops::heaviside_::redispatch(dispatchKeySet, self, values);
+    }
+    
+    // aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+    inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::rsub_Scalar::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor _sparse_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::_sparse_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha);
+    }
+    
+    // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_sampled_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_sampled_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor sparse_sampled_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::sparse_sampled_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha);
+    }
+    
+    // aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _sparse_mm_reduce_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::string_view reduce) {
+        return at::_ops::_sparse_mm_reduce_impl::redispatch(dispatchKeySet, self, other, reduce);
+    }
+    
+    // aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _sparse_mm_reduce_impl_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array<bool,2> output_mask) {
+        return at::_ops::_sparse_mm_reduce_impl_backward::redispatch(dispatchKeySet, self, grad_out, weight, reduce, arg_out, output_mask);
+    }
+    
+    // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha);
+    }
+    
+    // aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & addmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addmm_::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha);
+    }
+    
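+    // Illustrative note (not part of the generated API): addmm computes
+    // beta * self + alpha * (mat1 @ mat2); the _out/_outf pair differ only in where
+    // the destination tensor appears in the argument list. A hedged sketch, assuming
+    // the at::redispatch namespace, a DispatchKeySet `ks`, a bias matrix `c` and
+    // factors `m1`, `m2`:
+    //
+    //   at::Tensor y = at::redispatch::addmm(ks, c, m1, m2, /*beta=*/1, /*alpha=*/1);
+    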
+    // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _addmm_activation_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) {
+        return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out);
+    }
+    
+    // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _addmm_activation_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu, at::Tensor & out) {
+        return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out);
+    }
+    
+    // aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
+    inline at::Tensor _addmm_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) {
+        return at::_ops::_addmm_activation::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu);
+    }
+    
+    // aten::_scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _scaled_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const c10::optional<at::Tensor> & bias={}, c10::optional<at::ScalarType> out_dtype=c10::nullopt, const c10::optional<at::Tensor> & scale_a={}, const c10::optional<at::Tensor> & scale_b={}, const c10::optional<at::Tensor> & scale_result={}, bool use_fast_accum=false) {
+        return at::_ops::_scaled_mm::redispatch(dispatchKeySet, self, mat2, bias, out_dtype, scale_a, scale_b, scale_result, use_fast_accum);
+    }
+    
+    // aten::_scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _scaled_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & out_amax, const at::Tensor & self, const at::Tensor & mat2, const c10::optional<at::Tensor> & bias={}, c10::optional<at::ScalarType> out_dtype=c10::nullopt, const c10::optional<at::Tensor> & scale_a={}, const c10::optional<at::Tensor> & scale_b={}, const c10::optional<at::Tensor> & scale_result={}, bool use_fast_accum=false) {
+        return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, bias, out_dtype, scale_a, scale_b, scale_result, use_fast_accum, out, out_amax);
+    }
+    
+    // aten::_scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _scaled_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const c10::optional<at::Tensor> & bias, c10::optional<at::ScalarType> out_dtype, const c10::optional<at::Tensor> & scale_a, const c10::optional<at::Tensor> & scale_b, const c10::optional<at::Tensor> & scale_result, bool use_fast_accum, at::Tensor & out, at::Tensor & out_amax) {
+        return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, bias, out_dtype, scale_a, scale_b, scale_result, use_fast_accum, out, out_amax);
+    }
+    
+    // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
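+    // Illustrative note (not part of the generated API): each compressed-sparse
+    // factory above is emitted twice, once taking a packed at::TensorOptions and once
+    // taking the unpacked dtype/layout/device/pin_memory optionals. A hedged sketch of
+    // the packed form, assuming the at::redispatch namespace, a DispatchKeySet `ks`,
+    // and CSR component tensors `crow`, `col`, `vals` describing a 2x3 matrix:
+    //
+    //   at::Tensor csr = at::redispatch::sparse_csr_tensor(
+    //       ks, crow, col, vals, {2, 3}, at::TensorOptions().dtype(at::kFloat));
+    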
+    // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::TensorOptions options) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) {
+        return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) {
+        return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) {
+        return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) {
+        return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) {
+        return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options={}, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, dtype, layout, device, pin_memory, is_coalesced);
+    }
+    
+    // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced);
+    }
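+    // [editorial note, not part of the generated file] Every operator in this redispatch header appears
+    // in the same shapes: a convenience overload that unpacks at::TensorOptions into the explicit
+    // dtype/layout/device/pin_memory optionals, an overload taking those optionals directly, and a
+    // *_symint variant that accepts c10::SymIntArrayRef sizes for symbolic shapes. Illustrative call only
+    // (assumes a valid c10::DispatchKeySet `ks` and suitable `indices`/`values` tensors):
+    //   at::Tensor t = _sparse_coo_tensor_unsafe(ks, indices, values, {3, 3}, at::TensorOptions().dtype(at::kFloat));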
+    
+    // aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> ()
+    inline void _validate_sparse_coo_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_validate_sparse_coo_tensor_args::redispatch(dispatchKeySet, indices, values, size, is_coalesced);
+    }
+    
+    // aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> ()
+    inline void _validate_sparse_compressed_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout) {
+        return at::_ops::_validate_sparse_compressed_tensor_args::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, layout);
+    }
+    
+    // aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
+    inline void _validate_sparse_csr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size) {
+        return at::_ops::_validate_sparse_csr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size);
+    }
+    
+    // aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
+    inline void _validate_sparse_csc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size) {
+        return at::_ops::_validate_sparse_csc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size);
+    }
+    
+    // aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
+    inline void _validate_sparse_bsr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size) {
+        return at::_ops::_validate_sparse_bsr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size);
+    }
+    
+    // aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
+    inline void _validate_sparse_bsc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size) {
+        return at::_ops::_validate_sparse_bsc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::TensorOptions options) {
+        return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, dtype, layout, device, pin_memory, is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+    inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, dtype, layout, device, pin_memory, is_coalesced);
+    }
+    
+    // aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim);
+    }
+    
+    // aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_and_clear_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize_and_clear_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim);
+    }
+    
+    // aten::sparse_mask(Tensor self, Tensor mask) -> Tensor
+    inline at::Tensor sparse_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) {
+        return at::_ops::sparse_mask::redispatch(dispatchKeySet, self, mask);
+    }
+    
+    // aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
+    inline at::Tensor _sparse_mask_projection(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) {
+        return at::_ops::_sparse_mask_projection::redispatch(dispatchKeySet, self, mask, accumulate_matches);
+    }
+    
+    // aten::_to_cpu(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> _to_cpu(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::_to_cpu::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
+    inline at::Tensor to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<bool> masked_grad=c10::nullopt) {
+        return at::_ops::to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad);
+    }
+    
+    // aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
+    inline at::Tensor _to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<bool> masked_grad=c10::nullopt) {
+        return at::_ops::_to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad);
+    }
+    
+    // aten::to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor
+    inline at::Tensor to_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, c10::optional<bool> masked_grad=c10::nullopt) {
+        return at::_ops::to_dense_backward::redispatch(dispatchKeySet, grad, input, masked_grad);
+    }
+    
+    // aten::sparse_dim(Tensor self) -> int
+    inline int64_t sparse_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sparse_dim::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_dimI(Tensor self) -> int
+    inline int64_t _dimI(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_dimI::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::dense_dim(Tensor self) -> int
+    inline int64_t dense_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::dense_dim::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_dimV(Tensor self) -> int
+    inline int64_t _dimV(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_dimV::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_nnz(Tensor self) -> int
+    inline int64_t _nnz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_nnz::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::coalesce(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::coalesce::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_coalesce(Tensor self) -> Tensor
+    inline at::Tensor _coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_coalesce::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_coalesced(Tensor self) -> bool
+    inline bool is_coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::is_coalesced::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_values(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_values::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
+    inline at::Tensor & _coalesced_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool coalesced) {
+        return at::_ops::_coalesced_::redispatch(dispatchKeySet, self, coalesced);
+    }
+    
+    // aten::indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::values(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::values::redispatch(dispatchKeySet, self);
+    }
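+    // [editorial note, not part of the generated file] For sparse COO tensors, _indices()/_values()
+    // return the raw (possibly duplicated) index/value buffers without requiring coalescence, while
+    // indices()/values() expect an already-coalesced tensor; coalesce() produces the canonical form
+    // with sorted, de-duplicated indices and is_coalesced() reports the cached flag. Hedged sketch,
+    // assuming `sp` is a sparse COO tensor and `ks` a valid DispatchKeySet:
+    //   at::Tensor c = coalesce(ks, sp);
+    //   at::Tensor idx = indices(ks, c), val = values(ks, c);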
+    
+    // aten::crow_indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor crow_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::crow_indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::col_indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor col_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::col_indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::ccol_indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor ccol_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::ccol_indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::row_indices(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor row_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::row_indices::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hspmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mat1, const at::Tensor & mat2) {
+        return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out);
+    }
+    
+    // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hspmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2, at::Tensor & out) {
+        return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out);
+    }
+    
+    // aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor
+    inline at::Tensor hspmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2) {
+        return at::_ops::hspmm::redispatch(dispatchKeySet, mat1, mat2);
+    }
+    
+    // aten::copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+    inline at::Tensor & copy_sparse_to_sparse_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy_sparse_to_sparse_::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
+    // aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) {
+        return at::_ops::unbind_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]
+    inline ::std::vector<at::Tensor> unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) {
+        return at::_ops::unbind_Dimname::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+    inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) {
+        return at::_ops::to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim);
+    }
+    
+    // aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+    inline at::Tensor _to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) {
+        return at::_ops::_to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim);
+    }
+    
+    // aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+    inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Layout> layout=c10::nullopt, at::OptionalIntArrayRef blocksize=c10::nullopt, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim);
+    }
+    
+    // aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+    inline at::Tensor _to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Layout> layout=c10::nullopt, at::OptionalIntArrayRef blocksize=c10::nullopt, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim);
+    }
+    
+    // aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+    inline at::Tensor to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim);
+    }
+    
+    // aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+    inline at::Tensor _to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim);
+    }
+    
+    // aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+    inline at::Tensor to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim);
+    }
+    
+    // aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+    inline at::Tensor _to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim);
+    }
+    
+    // aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+    inline at::Tensor to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim);
+    }
+    
+    // aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+    inline at::Tensor _to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim);
+    }
+    
+    // aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+    inline at::Tensor to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim);
+    }
+    
+    // aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+    inline at::Tensor _to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim);
+    }
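+    // [editorial note, not part of the generated file] The to_sparse_* family converts between sparse
+    // layouts: the CSR/CSC variants take an optional dense_dim, while the blocked BSR/BSC variants
+    // additionally require a 2-element blocksize that is expected to evenly divide the corresponding
+    // dimensions. Illustrative only, assuming `mat` has dimensions divisible by 2:
+    //   at::Tensor bsr = to_sparse_bsr(ks, mat, {2, 2});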
+    
+    // aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _to_sparse_semi_structured(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense) {
+        return at::_ops::_to_sparse_semi_structured::redispatch(dispatchKeySet, dense);
+    }
+    
+    // aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor to_mkldnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::to_mkldnn::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor
+    inline at::Tensor mkldnn_reorder_conv2d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=c10::nullopt) {
+        return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*input_size)) : c10::nullopt);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor
+    inline at::Tensor mkldnn_reorder_conv2d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=c10::nullopt) {
+        return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor mkldnn_reorder_conv3d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    inline at::Tensor mkldnn_reorder_conv3d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups);
+    }
+    
+    // aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor
+    inline at::Tensor to_mkldnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input) {
+        return at::_ops::to_mkldnn_backward::redispatch(dispatchKeySet, grad, input);
+    }
+    
+    // aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor
+    inline at::Tensor quantize_per_tensor_dynamic(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) {
+        return at::_ops::quantize_per_tensor_dynamic::redispatch(dispatchKeySet, self, dtype, reduce_range);
+    }
+    
+    // aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
+    inline at::Tensor quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor::redispatch(dispatchKeySet, self, scale, zero_point, dtype);
+    }
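+    // [editorial note, not part of the generated file] quantize_per_tensor performs affine quantization,
+    // roughly q = clamp(round(x / scale) + zero_point, qmin, qmax), with qmin/qmax implied by the target
+    // dtype (e.g. quint8 -> [0, 255]); dequantize recovers x ~= (q - zero_point) * scale. Illustrative only:
+    //   at::Tensor q = quantize_per_tensor(ks, x, /*scale=*/0.1, /*zero_point=*/128, at::kQUInt8);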
+    
+    // aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor
+    inline at::Tensor quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, dtype);
+    }
+    
+    // aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]
+    inline ::std::vector<at::Tensor> quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor_tensors::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype);
+    }
+    
+    // aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor
+    inline at::Tensor quantize_per_channel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) {
+        return at::_ops::quantize_per_channel::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype);
+    }
+    
+    // aten::dequantize.self(Tensor self) -> Tensor
+    inline at::Tensor dequantize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::dequantize_self::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::dequantize.tensors(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> dequantize(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::dequantize_tensors::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::q_scale(Tensor self) -> float
+    inline double q_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::q_scale::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::q_zero_point(Tensor self) -> int
+    inline int64_t q_zero_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::q_zero_point::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::q_per_channel_scales(Tensor self) -> Tensor
+    inline at::Tensor q_per_channel_scales(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::q_per_channel_scales::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::q_per_channel_zero_points(Tensor self) -> Tensor
+    inline at::Tensor q_per_channel_zero_points(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::q_per_channel_zero_points::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::q_per_channel_axis(Tensor self) -> int
+    inline int64_t q_per_channel_axis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::q_per_channel_axis::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::int_repr(Tensor self) -> Tensor
+    inline at::Tensor int_repr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::int_repr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor
+    inline at::Tensor _make_per_tensor_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point) {
+        return at::_ops::_make_per_tensor_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point);
+    }
+    
+    // aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor
+    inline at::Tensor _make_per_channel_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) {
+        return at::_ops::_make_per_channel_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point, axis);
+    }
+    
+    // aten::qscheme(Tensor self) -> QScheme
+    inline at::QScheme qscheme(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::qscheme::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
+    inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max);
+    }
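+    // [editorial note, not part of the generated file] fake_quantize_per_tensor_affine simulates
+    // quantization in floating point: out = (clamp(round(x / scale) + zero_point, quant_min, quant_max)
+    // - zero_point) * scale. The *_cachemask variants below also return a bool mask of the elements that
+    // stayed inside [quant_min, quant_max]; the matching *_backward ops use that mask to zero gradients
+    // for clamped elements.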
+    
+    // aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
+    inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_tensor_affine_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max);
+    }
+    
+    // aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+    inline ::std::tuple<at::Tensor,at::Tensor> fake_quantize_per_tensor_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_tensor_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max);
+    }
+    
+    // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+    inline ::std::tuple<at::Tensor,at::Tensor> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max);
+    }
+    
+    // aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+    inline at::Tensor fake_quantize_per_tensor_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) {
+        return at::_ops::fake_quantize_per_tensor_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask);
+    }
+    
+    // aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+    inline at::Tensor _fake_quantize_learnable_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor);
+    }
+    
+    // aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _fake_quantize_learnable_per_tensor_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_tensor_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, quant_min, quant_max, grad_factor);
+    }
+    
+    // aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
+    inline at::Tensor fake_quantize_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_channel_affine::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max);
+    }
+    
+    // aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+    inline ::std::tuple<at::Tensor,at::Tensor> fake_quantize_per_channel_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_channel_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max);
+    }
+    
+    // aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+    inline at::Tensor fake_quantize_per_channel_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) {
+        return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask);
+    }
+    
+    // aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+    inline at::Tensor _fake_quantize_learnable_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_channel_affine::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor);
+    }
+    
+    // aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _fake_quantize_learnable_per_channel_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_channel_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor);
+    }
+    
+    // aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor
+    inline at::Tensor fused_moving_avg_obs_fake_quant(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) {
+        return at::_ops::fused_moving_avg_obs_fake_quant::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
+    }
+    
+    // aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)
+    inline ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) {
+        return at::_ops::_fused_moving_avg_obs_fq_helper::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
+    }
+    
+    // aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
+    inline ::std::tuple<double,int64_t> _choose_qparams_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool reduce_range=false) {
+        return at::_ops::_choose_qparams_per_tensor::redispatch(dispatchKeySet, self, reduce_range);
+    }
+    
+    // aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor
+    inline at::Tensor _saturate_weight_to_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight) {
+        return at::_ops::_saturate_weight_to_fp16::redispatch(dispatchKeySet, weight);
+    }
+    
+    // aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> choose_qparams_optimized(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width) {
+        return at::_ops::choose_qparams_optimized::redispatch(dispatchKeySet, input, numel, n_bins, ratio, bit_width);
+    }
+    
+    // aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
+    inline at::Tensor _autocast_to_reduced_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) {
+        return at::_ops::_autocast_to_reduced_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype);
+    }
+    
+    // aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)
+    inline at::Tensor _autocast_to_full_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled) {
+        return at::_ops::_autocast_to_full_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled);
+    }
+    
+    // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::_to_copy::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::_to_copy::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, memory_format);
+    }
+    
+    // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+    }
+    
+    // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) {
+        return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format);
+    }
+    
+    // aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::to_device::redispatch(dispatchKeySet, self, device, dtype, non_blocking, copy, memory_format);
+    }
+    
+    // aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::to_dtype::redispatch(dispatchKeySet, self, dtype, non_blocking, copy, memory_format);
+    }
+    
+    // aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::to_other::redispatch(dispatchKeySet, self, other, non_blocking, copy, memory_format);
+    }
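+    // [editorial note, not part of the generated file] The to(...) overloads above mirror Tensor.to():
+    // the dtype_layout form unpacks TensorOptions, and the device/dtype/other forms convert to a target
+    // device, dtype, or another tensor's options. Unlike _to_copy, to() may return `self` unchanged when
+    // no conversion is needed and copy=false. Illustrative only:
+    //   at::Tensor y = to(ks, x, at::kFloat, /*non_blocking=*/false, /*copy=*/false);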
+    
+    // aten::meshgrid(Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::meshgrid::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]
+    inline ::std::vector<at::Tensor> meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, c10::string_view indexing) {
+        return at::_ops::meshgrid_indexing::redispatch(dispatchKeySet, tensors, indexing);
+    }
+    
+    // aten::cartesian_prod(Tensor[] tensors) -> Tensor
+    inline at::Tensor cartesian_prod(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::cartesian_prod::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
+    inline at::Tensor combinations(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t r=2, bool with_replacement=false) {
+        return at::_ops::combinations::redispatch(dispatchKeySet, self, r, with_replacement);
+    }
+    
+    // aten::item(Tensor self) -> Scalar
+    inline at::Scalar item(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::item::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType
+    inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Tensor & other) {
+        return at::_ops::result_type_Tensor::redispatch(dispatchKeySet, tensor, other);
+    }
+    
+    // aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType
+    inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Scalar & other) {
+        return at::_ops::result_type_Scalar::redispatch(dispatchKeySet, tensor, other);
+    }
+    
+    // aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType
+    inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar, const at::Tensor & tensor) {
+        return at::_ops::result_type_Scalar_Tensor::redispatch(dispatchKeySet, scalar, tensor);
+    }
+    
+    // aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType
+    inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar1, const at::Scalar & scalar2) {
+        return at::_ops::result_type_Scalar_Scalar::redispatch(dispatchKeySet, scalar1, scalar2);
+    }
+    
+    // aten::can_cast(ScalarType from, ScalarType to) -> bool
+    inline bool can_cast(c10::DispatchKeySet dispatchKeySet, at::ScalarType from, at::ScalarType to) {
+        return at::_ops::can_cast::redispatch(dispatchKeySet, from, to);
+    }
+    
+    // aten::promote_types(ScalarType type1, ScalarType type2) -> ScalarType
+    inline at::ScalarType promote_types(c10::DispatchKeySet dispatchKeySet, at::ScalarType type1, at::ScalarType type2) {
+        return at::_ops::promote_types::redispatch(dispatchKeySet, type1, type2);
+    }
+    
+    // aten::_local_scalar_dense(Tensor self) -> Scalar
+    inline at::Scalar _local_scalar_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_local_scalar_dense::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _lstm_mps(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::_lstm_mps::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+    inline ::std::tuple<at::Tensor,::std::vector<at::Tensor>,::std::vector<at::Tensor>> lstm_mps_backward(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_y, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::lstm_mps_backward::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _thnn_fused_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const c10::optional<at::Tensor> & input_bias={}, const c10::optional<at::Tensor> & hidden_bias={}) {
+        return at::_ops::_thnn_fused_lstm_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias);
+    }
+    
+    // aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _thnn_fused_lstm_cell_backward_impl(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) {
+        return at::_ops::_thnn_fused_lstm_cell_backward_impl::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias);
+    }
+    
+    // aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_fused_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) {
+        return at::_ops::_thnn_fused_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias);
+    }
+    
+    // aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_differentiable_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias, const at::Tensor & cx, const at::Tensor & cy) {
+        return at::_ops::_thnn_differentiable_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, input_gates, hidden_gates, input_bias, hidden_bias, cx, cy);
+    }
+    
+    // aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _thnn_fused_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias={}, const c10::optional<at::Tensor> & hidden_bias={}) {
+        return at::_ops::_thnn_fused_gru_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias);
+    }
+    
+    // aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_fused_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
+        return at::_ops::_thnn_fused_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, workspace, has_bias);
+    }
+    
+    // aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _thnn_differentiable_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias) {
+        return at::_ops::_thnn_differentiable_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias);
+    }
+    
+    // aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::lstm_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) {
+        return at::_ops::lstm_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);
+    }
+    
+    // aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::gru_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) {
+        return at::_ops::gru_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);
+    }
+    
+    // aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::rnn_tanh_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) {
+        return at::_ops::rnn_tanh_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);
+    }
+    
+    // aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::rnn_relu_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    }
+    
+    // aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) {
+        return at::_ops::rnn_relu_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional);
+    }
+    
+    // aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const c10::optional<at::Tensor> & b_ih={}, const c10::optional<at::Tensor> & b_hh={}) {
+        return at::_ops::lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh);
+    }
+    
+    // aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+    inline at::Tensor gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const c10::optional<at::Tensor> & b_ih={}, const c10::optional<at::Tensor> & b_hh={}) {
+        return at::_ops::gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh);
+    }
+    
+    // aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+    inline at::Tensor rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const c10::optional<at::Tensor> & b_ih={}, const c10::optional<at::Tensor> & b_hh={}) {
+        return at::_ops::rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh);
+    }
+    
+    // aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+    inline at::Tensor rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const c10::optional<at::Tensor> & b_ih={}, const c10::optional<at::Tensor> & b_hh={}) {
+        return at::_ops::rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh);
+    }
+    
+    // aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> quantized_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) {
+        return at::_ops::quantized_lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh);
+    }
+    
+    // aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+    inline at::Tensor quantized_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) {
+        return at::_ops::quantized_gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh);
+    }
+    
+    // aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+    inline at::Tensor quantized_rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) {
+        return at::_ops::quantized_rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh);
+    }
+    
+    // aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+    inline at::Tensor quantized_rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) {
+        return at::_ops::quantized_rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh);
+    }
+    
+    // aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _pack_padded_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) {
+        return at::_ops::_pack_padded_sequence::redispatch(dispatchKeySet, input, lengths, batch_first);
+    }
+    
+    // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+    inline at::Tensor _pack_padded_sequence_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) {
+        return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(input_size), batch_sizes, batch_first);
+    }
+    
+    // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+    inline at::Tensor _pack_padded_sequence_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) {
+        return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, input_size, batch_sizes, batch_first);
+    }
+    
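+    // [Editorial sketch, not part of the generated file] The two overloads above show the
+    // SymInt convention used throughout this header: the plain overload takes concrete
+    // at::IntArrayRef sizes and widens them with c10::fromIntArrayRefSlow, while the
+    // *_symint overload forwards c10::SymIntArrayRef unchanged so symbolic shapes survive
+    // tracing. A minimal caller sketch, with hypothetical function and variable names,
+    // assuming it sits in the same namespace as these wrappers:
+    //
+    //   at::Tensor pack_backward_static(c10::DispatchKeySet ks, const at::Tensor & grad,
+    //                                   at::IntArrayRef input_size, const at::Tensor & batch_sizes,
+    //                                   bool batch_first) {
+    //       // Concrete sizes: the non-symint overload widens them to SymInt internally.
+    //       return _pack_padded_sequence_backward(ks, grad, input_size, batch_sizes, batch_first);
+    //   }
+    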
+    // aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _pad_packed_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length) {
+        return at::_ops::_pad_packed_sequence::redispatch(dispatchKeySet, data, batch_sizes, batch_first, padding_value, total_length);
+    }
+    
+    // aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)
+    inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source) {
+        return at::_ops::set__source_Storage::redispatch(dispatchKeySet, self, source);
+    }
+    
+    // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+    inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) {
+        return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+    inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) {
+        return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride);
+    }
+    
+    // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+    inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) {
+        return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+    inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) {
+        return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride);
+    }
+    
+    // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
+    inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source) {
+        return at::_ops::set__source_Tensor::redispatch(dispatchKeySet, self, source);
+    }
+    
+    // aten::set_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::set_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lift(Tensor self) -> Tensor
+    inline at::Tensor lift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::lift::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lift_fresh(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor lift_fresh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::lift_fresh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lift_fresh_copy(Tensor self) -> Tensor
+    inline at::Tensor lift_fresh_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::lift_fresh_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::is_set_to(Tensor self, Tensor tensor) -> bool
+    inline bool is_set_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) {
+        return at::_ops::is_set_to::redispatch(dispatchKeySet, self, tensor);
+    }
+    
+    // aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
+    inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
+        return at::_ops::masked_fill__Scalar::redispatch(dispatchKeySet, self, mask, value);
+    }
+    
+    // aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
+    inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
+        return at::_ops::masked_fill_Scalar::redispatch(dispatchKeySet, self, mask, value);
+    }
+    
+    // aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
+    inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
+        return at::_ops::masked_fill__Tensor::redispatch(dispatchKeySet, self, mask, value);
+    }
+    
+    // aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
+    inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
+        return at::_ops::masked_fill_Tensor::redispatch(dispatchKeySet, self, mask, value);
+    }
+    
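+    // [Editorial sketch, not part of the generated file] Every wrapper in this header takes an
+    // explicit c10::DispatchKeySet and re-enters the dispatcher with it, which is how a kernel
+    // registered for one dispatch key can forward to the remaining backends without recursing
+    // into itself. A hedged illustration only; the key and function names below are hypothetical:
+    //
+    //   at::Tensor masked_fill_forwarding(c10::DispatchKeySet ks, const at::Tensor & self,
+    //                                     const at::Tensor & mask, const at::Scalar & value) {
+    //       // Drop our own (hypothetical) key so the call proceeds to the next kernel in line.
+    //       c10::DispatchKeySet lower = ks.remove(c10::DispatchKey::TESTING_ONLY_GenericWrapper);
+    //       return masked_fill(lower, self, mask, value);
+    //   }
+    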
+    // aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)
+    inline at::Tensor & masked_scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) {
+        return at::_ops::masked_scatter_::redispatch(dispatchKeySet, self, mask, source);
+    }
+    
+    // aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
+    inline at::Tensor masked_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) {
+        return at::_ops::masked_scatter::redispatch(dispatchKeySet, self, mask, source);
+    }
+    
+    // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor
+    inline at::Tensor masked_scatter_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, at::IntArrayRef sizes) {
+        return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, c10::fromIntArrayRefSlow(sizes));
+    }
+    
+    // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor
+    inline at::Tensor masked_scatter_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes) {
+        return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, sizes);
+    }
+    
+    // aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor
+    inline at::Tensor _masked_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_masked_softmax::redispatch(dispatchKeySet, self, mask, dim, mask_type);
+    }
+    
+    // aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor
+    inline at::Tensor _masked_softmax_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::_masked_softmax_backward::redispatch(dispatchKeySet, grad_output, output, mask, dim);
+    }
+    
+    // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+    inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+    inline at::Tensor view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::view::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
+    inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) {
+        return at::_ops::view_dtype::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
+    inline at::Tensor & put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) {
+        return at::_ops::put_::redispatch(dispatchKeySet, self, index, source, accumulate);
+    }
+    
+    // aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
+    inline at::Tensor put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) {
+        return at::_ops::put::redispatch(dispatchKeySet, self, index, source, accumulate);
+    }
+    
+    // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) {
+        return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out);
+    }
+    
+    // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out);
+    }
+    
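+    // [Editorial sketch, not part of the generated file] index_add_out and index_add_outf above
+    // expose the same operator with two argument orders: *_out takes the out tensor first so the
+    // trailing default (alpha=1) stays usable, while *_outf keeps the schema order with out last.
+    // A hedged usage sketch (variable names are hypothetical):
+    //
+    //   // Given a pre-allocated `result`, these two calls are equivalent:
+    //   //   index_add_out(ks, result, self, /*dim=*/0, index, source);              // alpha defaults to 1
+    //   //   index_add_outf(ks, self, /*dim=*/0, index, source, /*alpha=*/1, result);
+    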
+    // aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & index_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) {
+        return at::_ops::index_add_::redispatch(dispatchKeySet, self, dim, index, source, alpha);
+    }
+    
+    // aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) {
+        return at::_ops::index_add::redispatch(dispatchKeySet, self, dim, index, source, alpha);
+    }
+    
+    // aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+    inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) {
+        return at::_ops::index_add_dimname::redispatch(dispatchKeySet, self, dim, index, source, alpha);
+    }
+    
+    // aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out);
+    }
+    
+    // aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self, at::Tensor & out) {
+        return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out);
+    }
+    
+    // aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)
+    inline at::Tensor & index_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::index_reduce_::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self);
+    }
+    
+    // aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
+    inline at::Tensor index_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::index_reduce::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self);
+    }
+    
+    // aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+    inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::index_fill__int_Scalar::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+    inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::index_fill_int_Scalar::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)
+    inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) {
+        return at::_ops::index_fill__int_Tensor::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
+    inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) {
+        return at::_ops::index_fill_int_Tensor::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)
+    inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::index_fill__Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)
+    inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) {
+        return at::_ops::index_fill__Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+    inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::index_fill_Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor
+    inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) {
+        return at::_ops::index_fill_Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_src::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+    inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter__src::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out);
+    }
+    
+    // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) {
+        return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out);
+    }
+    
+    // aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::scatter_value::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+    inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::scatter__value::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) {
+        return at::_ops::scatter_reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce);
+    }
+    
+    // aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)
+    inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) {
+        return at::_ops::scatter__reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce);
+    }
+    
+    // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) {
+        return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out);
+    }
+    
+    // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, at::Tensor & out) {
+        return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out);
+    }
+    
+    // aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) {
+        return at::_ops::scatter_value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce);
+    }
+    
+    // aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
+    inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) {
+        return at::_ops::scatter__value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce);
+    }
+    
+    // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) {
+        return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out);
+    }
+    
+    // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce, at::Tensor & out) {
+        return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out);
+    }
+    
+    // aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_dimname_src::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+    inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::scatter_dimname_value::redispatch(dispatchKeySet, self, dim, index, value);
+    }
+    
+    // aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+    inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_add::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+    inline at::Tensor & scatter_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_add_::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out);
+    }
+    
+    // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) {
+        return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out);
+    }
+    
+    // aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+    inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) {
+        return at::_ops::scatter_add_dimname::redispatch(dispatchKeySet, self, dim, index, src);
+    }
+    
+    // aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor
+    inline at::Tensor scatter_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::scatter_reduce_two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self);
+    }
+    
+    // aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)
+    inline at::Tensor & scatter_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::scatter_reduce__two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self);
+    }
+    
+    // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) {
+        return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out);
+    }
+    
+    // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scatter_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out) {
+        return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out);
+    }
+    
+    // aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::eq__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::eq__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_and_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_and_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_and_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_and__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_and__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__and___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__and___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__iand___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__iand___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_or_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_or_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_or_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_or__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_or__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__or___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__or___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__ior___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__ior___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_xor_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_xor_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_xor_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_xor__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_xor__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__xor___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__xor___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__ixor___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__ixor___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__lshift___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__lshift___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__ilshift___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__ilshift___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_left_shift_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_left_shift__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_left_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_left_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_left_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__rshift___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__rshift___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__irshift___Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__irshift___Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_right_shift_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_right_shift__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_right_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_right_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_right_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+    inline at::Tensor & tril_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::tril_::redispatch(dispatchKeySet, self, diagonal);
+    }
+    
+    // aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+    inline at::Tensor & triu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::triu_::redispatch(dispatchKeySet, self, diagonal);
+    }
+    
+    // aten::digamma_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & digamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::digamma_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)
+    inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
+        return at::_ops::lerp__Scalar::redispatch(dispatchKeySet, self, end, weight);
+    }
+    
+    // aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)
+    inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) {
+        return at::_ops::lerp__Tensor::redispatch(dispatchKeySet, self, end, weight);
+    }
+    
+    // aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+    inline at::Tensor & addbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha);
+    }
+    
+    // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out);
+    }
+    
+    // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out);
+    }
+    
+    // aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    inline at::Tensor addbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::addbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha);
+    }
+    
+    // aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random__from::redispatch(dispatchKeySet, self, from, to, generator);
+    }
+    
+    // aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random__to::redispatch(dispatchKeySet, self, to, generator);
+    }
+    
+    // aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & uniform_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double from=0, double to=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::uniform_::redispatch(dispatchKeySet, self, from, to, generator);
+    }
+    
+    // aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & cauchy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double median=0, double sigma=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::cauchy_::redispatch(dispatchKeySet, self, median, sigma, generator);
+    }
+    
+    // aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & log_normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=1, double std=2, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::log_normal_::redispatch(dispatchKeySet, self, mean, std, generator);
+    }
+    
+    // aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & exponential_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double lambd=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::exponential_::redispatch(dispatchKeySet, self, lambd, generator);
+    }
+    
+    // aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & geometric_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::geometric_::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
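+    // Illustrative usage sketch -- not part of the generated op surface. It assumes these
+    // wrappers live in the at::redispatch namespace of ATen/RedispatchFunctions.h and that
+    // c10::after_autograd_keyset is available, as in the generated autograd kernels. Every
+    // in-place sampler above does one thing: forward its arguments, plus the caller-supplied
+    // DispatchKeySet, to at::_ops::<op>::redispatch, so a kernel can re-enter the dispatcher
+    // below its own key.
+    inline at::Tensor & example_redispatch_uniform_(c10::DispatchKeySet ks, at::Tensor & self) {
+        // Mask out the autograd keys, then redispatch aten::uniform_ with its default
+        // range [0, 1) and no explicit generator.
+        return uniform_(ks & c10::after_autograd_keyset, self, /*from=*/0, /*to=*/1, c10::nullopt);
+    }
+    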
+    // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) {
+        return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::diag(Tensor self, int diagonal=0) -> Tensor
+    inline at::Tensor diag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::diag::redispatch(dispatchKeySet, self, diagonal);
+    }
+    
+    // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out);
+    }
+    
+    // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<int64_t> dim, at::Tensor & out) {
+        return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out);
+    }
+    
+    // aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor
+    inline at::Tensor cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::cross::redispatch(dispatchKeySet, self, other, dim);
+    }
+    
+    // aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & triu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & triu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) {
+        return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::triu(Tensor self, int diagonal=0) -> Tensor
+    inline at::Tensor triu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::triu::redispatch(dispatchKeySet, self, diagonal);
+    }
+    
+    // aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tril_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tril_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) {
+        return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out);
+    }
+    
+    // aten::tril(Tensor self, int diagonal=0) -> Tensor
+    inline at::Tensor tril(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) {
+        return at::_ops::tril::redispatch(dispatchKeySet, self, diagonal);
+    }
+    
+    // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) {
+        return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) {
+        return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory);
+    }
+    
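+    // Illustrative sketch (same namespace assumption as the sketch above, and the function
+    // name is hypothetical): the factory ops tril_indices / triu_indices expose the same
+    // information twice -- once packed into an at::TensorOptions, once as the four optionals
+    // the dispatcher schema actually carries. The packed overload simply unpacks into the
+    // optional overload, so the two calls below produce the same tensor.
+    inline at::Tensor example_tril_indices_long(c10::DispatchKeySet ks, int64_t row, int64_t col) {
+        // Packed form; kLong matches the schema default dtype=long.
+        at::Tensor a = tril_indices(ks, row, col, /*offset=*/0, at::TensorOptions().dtype(at::kLong));
+        // Unpacked form; only dtype is set, the remaining optionals stay unset.
+        at::Tensor b = tril_indices(ks, row, col, /*offset=*/0, at::ScalarType::Long, c10::nullopt, c10::nullopt, c10::nullopt);
+        return a.equal(b) ? a : b;
+    }
+    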
+    // aten::trace(Tensor self) -> Tensor
+    inline at::Tensor trace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::trace::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
+    inline at::Tensor trace_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef sizes) {
+        return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(sizes));
+    }
+    
+    // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
+    inline at::Tensor trace_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef sizes) {
+        return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, sizes);
+    }
+    
+    // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ne.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ne_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ne.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ne_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ne__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ne__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::not_equal_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::not_equal_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::not_equal__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::not_equal__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::eq.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::eq_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::eq.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::eq_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ge.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ge_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::ge.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ge_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::ge__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::ge__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater_equal_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater_equal_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater_equal__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater_equal__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::le.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::le_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::le.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::le_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::le__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::le__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less_equal_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less_equal_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less_equal__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less_equal__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gt.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::gt_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::gt.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gt_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::gt__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::gt__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::greater.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::greater__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::greater__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lt.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::lt_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::lt.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lt_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::lt__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::lt__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::less.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::less__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::less__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & take_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index) {
+        return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out);
+    }
+    
+    // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & take_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, at::Tensor & out) {
+        return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out);
+    }
+    
+    // aten::take(Tensor self, Tensor index) -> Tensor
+    inline at::Tensor take(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index) {
+        return at::_ops::take::redispatch(dispatchKeySet, self, index);
+    }
+    
+    // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & take_along_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out);
+    }
+    
+    // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & take_along_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::optional<int64_t> dim, at::Tensor & out) {
+        return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out);
+    }
+    
+    // aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
+    inline at::Tensor take_along_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::take_along_dim::redispatch(dispatchKeySet, self, indices, dim);
+    }
+    
+    // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index) {
+        return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, at::Tensor & out) {
+        return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::index_select(Tensor self, int dim, Tensor index) -> Tensor
+    inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index) {
+        return at::_ops::index_select::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) {
+        return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, at::Tensor & out) {
+        return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
+    inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) {
+        return at::_ops::index_select_dimname::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
+    inline at::Tensor index_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index) {
+        return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(self_sizes), dim, index);
+    }
+    
+    // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
+    inline at::Tensor index_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index) {
+        return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, self_sizes, dim, index);
+    }
+    
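+    // Illustrative sketch (same namespace assumption; the function name is hypothetical):
+    // ops whose schema uses SymInt[] get two wrappers -- a concrete at::IntArrayRef one that
+    // converts through c10::fromIntArrayRefSlow, and a *_symint one that forwards symbolic
+    // sizes untouched. With concrete shapes the two are interchangeable.
+    inline at::Tensor example_index_select_backward(c10::DispatchKeySet ks, const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index) {
+        // Concrete path: plain int64_t sizes are wrapped into SymInts by the wrapper itself.
+        // The symbolic path would be index_select_backward_symint(ks, grad,
+        // c10::fromIntArrayRefSlow(self_sizes), dim, index) and yields the same result here.
+        return index_select_backward(ks, grad, self_sizes, dim, index);
+    }
+    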
+    // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) {
+        return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out);
+    }
+    
+    // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) {
+        return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out);
+    }
+    
+    // aten::masked_select(Tensor self, Tensor mask) -> Tensor
+    inline at::Tensor masked_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) {
+        return at::_ops::masked_select::redispatch(dispatchKeySet, self, mask);
+    }
+    
+    // aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor
+    inline at::Tensor masked_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask) {
+        return at::_ops::masked_select_backward::redispatch(dispatchKeySet, grad, input, mask);
+    }
+    
+    // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::nonzero(Tensor self) -> Tensor
+    inline at::Tensor nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::nonzero::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nonzero_static_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, int64_t fill_value=-1) {
+        return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out);
+    }
+    
+    // aten::nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nonzero_static_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value, at::Tensor & out) {
+        return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out);
+    }
+    
+    // aten::nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
+    inline at::Tensor nonzero_static(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value=-1) {
+        return at::_ops::nonzero_static::redispatch(dispatchKeySet, self, size, fill_value);
+    }
+    
+    // aten::nonzero_numpy(Tensor self) -> Tensor[]
+    inline ::std::vector<at::Tensor> nonzero_numpy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::nonzero_numpy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::argwhere(Tensor self) -> Tensor
+    inline at::Tensor argwhere(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::argwhere::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) {
+        return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out);
+    }
+    
+    // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) {
+        return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out);
+    }
+    
+    // aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+    inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) {
+        return at::_ops::gather::redispatch(dispatchKeySet, self, dim, index, sparse_grad);
+    }
+    
+    // aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor
+    inline at::Tensor gather_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) {
+        return at::_ops::gather_backward::redispatch(dispatchKeySet, grad, self, dim, index, sparse_grad);
+    }
+    
+    // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) {
+        return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out);
+    }
+    
+    // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) {
+        return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out);
+    }
+    
+    // aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+    inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) {
+        return at::_ops::gather_dimname::redispatch(dispatchKeySet, self, dim, index, sparse_grad);
+    }
+    
+    // aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+    inline at::Tensor _gather_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad) {
+        return at::_ops::_gather_sparse_backward::redispatch(dispatchKeySet, self, dim, index, grad);
+    }
+    
+    // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addcmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addcmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+    inline at::Tensor addcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcmul::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+    inline at::Tensor & addcmul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcmul_::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & addcdiv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+    inline at::Tensor addcdiv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcdiv::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+    inline at::Tensor & addcdiv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) {
+        return at::_ops::addcdiv_::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor
+    inline at::Tensor cross_entropy_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100, double label_smoothing=0.0) {
+        return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing);
+    }
+    
+    // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor
+    inline at::Tensor cross_entropy_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100, double label_smoothing=0.0) {
+        return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing);
+    }
+    
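+    // Illustrative sketch (same namespace assumption; the function name is hypothetical):
+    // optional tensor arguments such as `weight` travel as c10::optional<at::Tensor>, where
+    // {} means "no weight", and the remaining defaults mirror the schema
+    // (reduction=Mean, ignore_index=-100, label_smoothing=0.0).
+    inline at::Tensor example_unweighted_cross_entropy(c10::DispatchKeySet ks, const at::Tensor & logits, const at::Tensor & target) {
+        // Unweighted, mean-reduced cross entropy with the schema defaults spelled out.
+        return cross_entropy_loss(ks, logits, target, /*weight=*/{}, at::Reduction::Mean, /*ignore_index=*/-100, /*label_smoothing=*/0.0);
+    }
+    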
+    // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> triangular_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & X, at::Tensor & M, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) {
+        return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M);
+    }
+    
+    // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> triangular_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular, at::Tensor & X, at::Tensor & M) {
+        return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M);
+    }
+    
+    // aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
+    inline ::std::tuple<at::Tensor,at::Tensor> triangular_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) {
+        return at::_ops::triangular_solve::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular);
+    }
+    
+    // aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()
+    inline void _linalg_check_errors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & info, c10::string_view api_name, bool is_matrix) {
+        return at::_ops::_linalg_check_errors::redispatch(dispatchKeySet, info, api_name, is_matrix);
+    }
+    
+    // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_solve_triangular_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) {
+        return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out);
+    }
+    
+    // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_solve_triangular_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular, at::Tensor & out) {
+        return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out);
+    }
+    
+    // aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
+    inline at::Tensor linalg_solve_triangular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) {
+        return at::_ops::linalg_solve_triangular::redispatch(dispatchKeySet, self, B, upper, left, unitriangular);
+    }
+    
+    // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor
+    inline at::Tensor linalg_vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, c10::optional<int64_t> N=c10::nullopt) {
+        return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N.has_value() ? c10::make_optional(c10::SymInt(*N)) : c10::nullopt);
+    }
+    
+    // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor
+    inline at::Tensor linalg_vander_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, c10::optional<c10::SymInt> N=c10::nullopt) {
+        return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N);
+    }
+    
+    // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & V, const at::Tensor & self, bool some=true, bool compute_uv=true) {
+        return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V);
+    }
+    
+    // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, bool compute_uv, at::Tensor & U, at::Tensor & S, at::Tensor & V) {
+        return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V);
+    }
+    
+    // aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true, bool compute_uv=true) {
+        return at::_ops::svd::redispatch(dispatchKeySet, self, some, compute_uv);
+    }
+    
+    // aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)
+    inline at::Tensor swapaxes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t axis0, int64_t axis1) {
+        return at::_ops::swapaxes::redispatch(dispatchKeySet, self, axis0, axis1);
+    }
+    
+    // aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)
+    inline at::Tensor & swapaxes_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t axis0, int64_t axis1) {
+        return at::_ops::swapaxes_::redispatch(dispatchKeySet, self, axis0, axis1);
+    }
+    
+    // aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+    inline at::Tensor swapdims(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::swapdims::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+    inline at::Tensor & swapdims_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::swapdims_::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) {
+        return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) {
+        return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::cholesky(Tensor self, bool upper=False) -> Tensor
+    inline at::Tensor cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) {
+        return at::_ops::cholesky::redispatch(dispatchKeySet, self, upper);
+    }
+    
+    // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, bool upper=false) {
+        return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out);
+    }
+    
+    // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper, at::Tensor & out) {
+        return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out);
+    }
+    
+    // aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
+    inline at::Tensor cholesky_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper=false) {
+        return at::_ops::cholesky_solve::redispatch(dispatchKeySet, self, input2, upper);
+    }
+    
+    // aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor
+    inline at::Tensor _cholesky_solve_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper) {
+        return at::_ops::_cholesky_solve_helper::redispatch(dispatchKeySet, self, A, upper);
+    }
+    
+    // aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor
+    inline at::Tensor cholesky_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) {
+        return at::_ops::cholesky_inverse::redispatch(dispatchKeySet, self, upper);
+    }
+    
+    // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) {
+        return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cholesky_inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) {
+        return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & self, bool some=true) {
+        return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R);
+    }
+    
+    // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, at::Tensor & Q, at::Tensor & R) {
+        return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R);
+    }
+    
+    // aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)
+    inline ::std::tuple<at::Tensor,at::Tensor> qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true) {
+        return at::_ops::qr::redispatch(dispatchKeySet, self, some);
+    }
+    
+    // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> geqrf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & a, at::Tensor & tau, const at::Tensor & self) {
+        return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau);
+    }
+    
+    // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> geqrf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & a, at::Tensor & tau) {
+        return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau);
+    }
+    
+    // aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)
+    inline ::std::tuple<at::Tensor,at::Tensor> geqrf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::geqrf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::orgqr(Tensor self, Tensor input2) -> Tensor
+    inline at::Tensor orgqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2) {
+        return at::_ops::orgqr::redispatch(dispatchKeySet, self, input2);
+    }
+    
+    // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & orgqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2) {
+        return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out);
+    }
+    
+    // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & orgqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, at::Tensor & out) {
+        return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out);
+    }
+    
+    // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ormqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) {
+        return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out);
+    }
+    
+    // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ormqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose, at::Tensor & out) {
+        return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out);
+    }
+    
+    // aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
+    inline at::Tensor ormqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) {
+        return at::_ops::ormqr::redispatch(dispatchKeySet, self, input2, input3, left, transpose);
+    }
+    
+    // aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _lu_with_info(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool pivot=true, bool check_errors=true) {
+        return at::_ops::_lu_with_info::redispatch(dispatchKeySet, self, pivot, check_errors);
+    }
+    
+    // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) {
+        return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out);
+    }
+    
+    // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots, at::Tensor & out) {
+        return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out);
+    }
+    
+    // aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+    inline at::Tensor lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) {
+        return at::_ops::lu_solve::redispatch(dispatchKeySet, self, LU_data, LU_pivots);
+    }
+    
+    // aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> lu_unpack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) {
+        return at::_ops::lu_unpack::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots);
+    }
+    
+    // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> lu_unpack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) {
+        return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U);
+    }
+    
+    // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> lu_unpack_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, at::Tensor & P, at::Tensor & L, at::Tensor & U) {
+        return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U);
+    }
+    
+    // aten::multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multinomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t num_samples, bool replacement=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out);
+    }
+    
+    // aten::multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multinomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out);
+    }
+    
+    // aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
+    inline at::Tensor multinomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::multinomial::redispatch(dispatchKeySet, self, num_samples, replacement, generator);
+    }
+    
+    // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::lgamma_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & lgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::lgamma_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lgamma(Tensor self) -> Tensor
+    inline at::Tensor lgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::lgamma::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::digamma(Tensor self) -> Tensor
+    inline at::Tensor digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::digamma::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) {
+        return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out);
+    }
+    
+    // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out);
+    }
+    
+    // aten::polygamma(int n, Tensor self) -> Tensor
+    inline at::Tensor polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) {
+        return at::_ops::polygamma::redispatch(dispatchKeySet, n, self);
+    }
+    
+    // aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
+    inline at::Tensor & polygamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n) {
+        return at::_ops::polygamma_::redispatch(dispatchKeySet, self, n);
+    }
+    
+    // aten::erfinv(Tensor self) -> Tensor
+    inline at::Tensor erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::erfinv::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erfinv_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & erfinv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::erfinv_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::i0(Tensor self) -> Tensor
+    inline at::Tensor i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::i0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::i0_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & i0_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::i0_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sign(Tensor self) -> Tensor
+    inline at::Tensor sign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::sign::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sign_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & sign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::sign_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::sign_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::sign_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::signbit(Tensor self) -> Tensor
+    inline at::Tensor signbit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::signbit::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & signbit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & signbit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
+    inline at::Tensor dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) {
+        return at::_ops::dist::redispatch(dispatchKeySet, self, other, p);
+    }
+    
+    // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & atan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & atan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::atan2_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::atan2(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor atan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::atan2::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::arctan2(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor arctan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::arctan2::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & arctan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & arctan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::arctan2_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
+        return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out);
+    }
+    
+    // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out) {
+        return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out);
+    }
+    
+    // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) {
+        return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out);
+    }
+    
+    // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out) {
+        return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out);
+    }
+    
+    // aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
+    inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
+        return at::_ops::lerp_Scalar::redispatch(dispatchKeySet, self, end, weight);
+    }
+    
+    // aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
+    inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) {
+        return at::_ops::lerp_Tensor::redispatch(dispatchKeySet, self, end, weight);
+    }
+    
+    // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & histc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) {
+        return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out);
+    }
+    
+    // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & histc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max, at::Tensor & out) {
+        return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out);
+    }
+    
+    // aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
+    inline at::Tensor histc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) {
+        return at::_ops::histc::redispatch(dispatchKeySet, self, bins, min, max);
+    }
+    
+    // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, const at::Tensor & bins, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges);
+    }
+    
+    // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) {
+        return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges);
+    }
+    
+    // aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+    inline ::std::tuple<at::Tensor,at::Tensor> histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogram_bins_tensor::redispatch(dispatchKeySet, self, bins, weight, density);
+    }
+    
+    // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, int64_t bins=100, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges);
+    }
+    
+    // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, c10::optional<at::ArrayRef<double>> range, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) {
+        return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges);
+    }
+    
+    // aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+    inline ::std::tuple<at::Tensor,at::Tensor> histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogram_bin_ct::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
+    inline ::std::vector<at::Tensor> _histogramdd_bin_edges(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_bin_edges::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
+    inline at::Tensor _histogramdd_from_bin_cts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_from_bin_cts::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
+    inline at::Tensor _histogramdd_from_bin_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_from_bin_tensors::redispatch(dispatchKeySet, self, bins, weight, density);
+    }
+    
+    // aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+    inline ::std::tuple<at::Tensor,::std::vector<at::Tensor>> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogramdd::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+    inline ::std::tuple<at::Tensor,::std::vector<at::Tensor>> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogramdd_int_bins::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+    inline ::std::tuple<at::Tensor,::std::vector<at::Tensor>> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::histogramdd_TensorList_bins::redispatch(dispatchKeySet, self, bins, range, weight, density);
+    }
+    
+    // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::fmod_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::fmod__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmod_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmod__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hypot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hypot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::hypot(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor hypot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::hypot::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & hypot_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::hypot_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & igamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & igamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::igamma(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor igamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igamma::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & igamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igamma_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & igammac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & igammac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::igammac(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor igammac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igammac::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & igammac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::igammac_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nextafter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nextafter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::nextafter(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor nextafter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::nextafter::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & nextafter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::nextafter_::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::remainder_Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+    inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::remainder__Scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::remainder_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+    inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::remainder__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::remainder_Scalar_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::min(Tensor self) -> Tensor
+    inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::min::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::fmin(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor fmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmin::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::max(Tensor self) -> Tensor
+    inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::max::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::fmax(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor fmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmax::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::maximum(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor maximum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::maximum::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & maximum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & maximum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::max.other(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::max_other::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::minimum(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor minimum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::minimum::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & minimum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & minimum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::min.other(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::min_other::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+    inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::quantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation);
+    }
+    
+    // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) {
+        return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+    inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::quantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation);
+    }
+    
+    // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) {
+        return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+    inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::nanquantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation);
+    }
+    
+    // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) {
+        return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+    inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::nanquantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation);
+    }
+    
+    // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") {
+        return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) {
+        return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out);
+    }
+    
+    // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool descending=false) {
+        return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices);
+    }
+    
+    // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices);
+    }
+    
+    // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::optional<bool> stable, int64_t dim=-1, bool descending=false) {
+        return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices);
+    }
+    
+    // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices);
+    }
+    
+    // aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) {
+        return at::_ops::sort::redispatch(dispatchKeySet, self, dim, descending);
+    }
+    
+    // aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<bool> stable, int64_t dim=-1, bool descending=false) {
+        return at::_ops::sort_stable::redispatch(dispatchKeySet, self, stable, dim, descending);
+    }
+    
+    // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool descending=false) {
+        return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices);
+    }
+    
+    // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices);
+    }
+    
+    // aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::optional<bool> stable, at::Dimname dim, bool descending=false) {
+        return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices);
+    }
+    
+    // aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<bool> stable, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices);
+    }
+    
+    // aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) {
+        return at::_ops::sort_dimname::redispatch(dispatchKeySet, self, dim, descending);
+    }
+    
+    // aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<bool> stable, at::Dimname dim, bool descending=false) {
+        return at::_ops::sort_dimname_stable::redispatch(dispatchKeySet, self, stable, dim, descending);
+    }
+    
+    // aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & msort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::msort_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & msort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::msort_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::msort(Tensor self) -> Tensor
+    inline at::Tensor msort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::msort::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
+    inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) {
+        return at::_ops::argsort::redispatch(dispatchKeySet, self, dim, descending);
+    }
+    
+    // aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor
+    inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) {
+        return at::_ops::argsort_stable::redispatch(dispatchKeySet, self, stable, dim, descending);
+    }
+    
+    // aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor
+    inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) {
+        return at::_ops::argsort_dimname::redispatch(dispatchKeySet, self, dim, descending);
+    }
+    
+    // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> topk_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) {
+        return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices);
+    }
+    
+    // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> topk_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices);
+    }
+    
+    // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> topk_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) {
+        return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices);
+    }
+    
+    // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> topk_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) {
+        return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices);
+    }
+    
+    // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> topk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) {
+        return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted);
+    }
+    
+    // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+    inline ::std::tuple<at::Tensor,at::Tensor> topk_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) {
+        return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted);
+    }
+    
+    // aten::all(Tensor self) -> Tensor
+    inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::all::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::any(Tensor self) -> Tensor
+    inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::any::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) {
+        return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out);
+    }
+    
+    // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm, at::Tensor & out) {
+        return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out);
+    }
+    
+    // aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
+    inline at::Tensor renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) {
+        return at::_ops::renorm::redispatch(dispatchKeySet, self, p, dim, maxnorm);
+    }
+    
+    // aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
+    inline at::Tensor & renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) {
+        return at::_ops::renorm_::redispatch(dispatchKeySet, self, p, dim, maxnorm);
+    }
+    
+    // aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
+    inline at::Tensor unfold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) {
+        return at::_ops::unfold::redispatch(dispatchKeySet, self, dimension, size, step);
+    }
+    
+    // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
+    inline at::Tensor unfold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) {
+        return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step);
+    }
+    
+    // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
+    inline at::Tensor unfold_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) {
+        return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step);
+    }
+    
+    // aten::equal(Tensor self, Tensor other) -> bool
+    inline bool equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::equal::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) {
+        return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+    inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::pow_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) {
+        return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) {
+        return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor
+    inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) {
+        return at::_ops::pow_Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) {
+        return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+    inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::pow_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+    inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::pow__Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+    inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::pow__Tensor::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) {
+        return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+    inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::float_power_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) {
+        return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) {
+        return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
+    inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) {
+        return at::_ops::float_power_Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) {
+        return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+    inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::float_power_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+    inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) {
+        return at::_ops::float_power__Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+    inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) {
+        return at::_ops::float_power__Tensor::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=0, double std=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_::redispatch(dispatchKeySet, self, mean, std, generator);
+    }
+    
+    // aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
+    inline at::Tensor normal_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=0, double std=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_functional::redispatch(dispatchKeySet, self, mean, std, generator);
+    }
+    
+    // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, double std=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
+    inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_Tensor_float::redispatch(dispatchKeySet, mean, std, generator);
+    }
+    
+    // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, const at::Tensor & std, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
+    inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_float_Tensor::redispatch(dispatchKeySet, mean, std, generator);
+    }
+    
+    // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, const at::Tensor & std, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out);
+    }
+    
+    // aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
+    inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_Tensor_Tensor::redispatch(dispatchKeySet, mean, std, generator);
+    }
+    
+    // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, c10::optional<at::Generator> generator=c10::nullopt, at::TensorOptions options={}) {
+        return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, c10::optional<at::Generator> generator=c10::nullopt, at::TensorOptions options={}) {
+        return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, at::IntArrayRef size, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out);
+    }
+    
+    // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, c10::SymIntArrayRef size, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out);
+    }
+    
+    // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_symint_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out);
+    }
+    
+    // aten::alias(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::alias::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()
+    inline void _amp_foreach_non_finite_check_and_unscale_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) {
+        return at::_ops::_amp_foreach_non_finite_check_and_unscale_::redispatch(dispatchKeySet, self, found_inf, inv_scale);
+    }
+    
+    // aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
+    inline at::Tensor & _amp_update_scale_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) {
+        return at::_ops::_amp_update_scale_::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval);
+    }
+    
+    // aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_add_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_add__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add_List::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
+    inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add__List::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_add_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_add__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add_Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()
+    inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add__Tensor::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_sub_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_sub__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_sub_List::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
+    inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_sub__List::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_sub_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_sub__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_mul_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_mul__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_mul_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_mul__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_mul_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_mul__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_mul_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
+    inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_mul__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_div_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_div__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_div_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_div__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_div_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_div__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_div_Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
+    inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_div__Tensor::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_max_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_max__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_max_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_max__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_max_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_max__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_min_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_min__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_min_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_min__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_min_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_min__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_maximum_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_maximum__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_maximum_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_maximum__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_maximum_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_maximum__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_minimum_Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+    inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_minimum__Scalar::redispatch(dispatchKeySet, self, scalar);
+    }
+    
+    // aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_minimum_List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+    inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_minimum__List::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_minimum_ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+    inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_minimum__ScalarList::redispatch(dispatchKeySet, self, scalars);
+    }
+    
+    // aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcdiv_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcdiv_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcdiv_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
+    inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcdiv__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+    inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcdiv__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()
+    inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcdiv__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcmul_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcmul_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcmul_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
+    inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcmul__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value);
+    }
+    
+    // aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+    inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcmul__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()
+    inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcmul__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars);
+    }
+    
+    // aten::_foreach_abs(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_abs(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_abs::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_abs_(Tensor(a!)[] self) -> ()
+    inline void _foreach_abs_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_abs_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_acos(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_acos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_acos::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_acos_(Tensor(a!)[] self) -> ()
+    inline void _foreach_acos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_acos_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_asin(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_asin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_asin::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_asin_(Tensor(a!)[] self) -> ()
+    inline void _foreach_asin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_asin_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_atan(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_atan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_atan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_atan_(Tensor(a!)[] self) -> ()
+    inline void _foreach_atan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_atan_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_ceil(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_ceil(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_ceil::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_ceil_(Tensor(a!)[] self) -> ()
+    inline void _foreach_ceil_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_ceil_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_cos(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_cos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_cos::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_cos_(Tensor(a!)[] self) -> ()
+    inline void _foreach_cos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_cos_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_cosh(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_cosh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_cosh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_cosh_(Tensor(a!)[] self) -> ()
+    inline void _foreach_cosh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_cosh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_erf(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_erf(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_erf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_erf_(Tensor(a!)[] self) -> ()
+    inline void _foreach_erf_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_erf_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_erfc(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_erfc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_erfc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_erfc_(Tensor(a!)[] self) -> ()
+    inline void _foreach_erfc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_erfc_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_exp(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_exp(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_exp::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_exp_(Tensor(a!)[] self) -> ()
+    inline void _foreach_exp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_exp_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_expm1(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_expm1(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_expm1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_expm1_(Tensor(a!)[] self) -> ()
+    inline void _foreach_expm1_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_expm1_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_floor(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_floor(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_floor::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_floor_(Tensor(a!)[] self) -> ()
+    inline void _foreach_floor_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_floor_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_frac(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_frac(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_frac::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_frac_(Tensor(a!)[] self) -> ()
+    inline void _foreach_frac_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_frac_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) {
+        return at::_ops::_foreach_lerp_List::redispatch(dispatchKeySet, self, tensors1, weights);
+    }
+    
+    // aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> ()
+    inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) {
+        return at::_ops::_foreach_lerp__List::redispatch(dispatchKeySet, self, tensors1, weights);
+    }
+    
+    // aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) {
+        return at::_ops::_foreach_lerp_Scalar::redispatch(dispatchKeySet, self, tensors1, weight);
+    }
+    
+    // aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> ()
+    inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) {
+        return at::_ops::_foreach_lerp__Scalar::redispatch(dispatchKeySet, self, tensors1, weight);
+    }
+    
+    // aten::_foreach_lgamma(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_lgamma(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_lgamma::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_lgamma_(Tensor(a!)[] self) -> ()
+    inline void _foreach_lgamma_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_lgamma_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_log(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log_(Tensor(a!)[] self) -> ()
+    inline void _foreach_log_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log10(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_log10(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log10::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log10_(Tensor(a!)[] self) -> ()
+    inline void _foreach_log10_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log10_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log1p(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_log1p(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log1p::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log1p_(Tensor(a!)[] self) -> ()
+    inline void _foreach_log1p_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log1p_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log2(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_log2(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log2::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_log2_(Tensor(a!)[] self) -> ()
+    inline void _foreach_log2_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_log2_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_neg(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_neg(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_neg::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_neg_(Tensor(a!)[] self) -> ()
+    inline void _foreach_neg_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_neg_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_norm(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord=2) {
+        return at::_ops::_foreach_norm_Scalar::redispatch(dispatchKeySet, self, ord);
+    }
+    
+    // aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) {
+        return at::_ops::_foreach_pow_List::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) {
+        return at::_ops::_foreach_pow_Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> exponent) {
+        return at::_ops::_foreach_pow_ScalarList::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, at::TensorList exponent) {
+        return at::_ops::_foreach_pow_ScalarAndTensor::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> ()
+    inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) {
+        return at::_ops::_foreach_pow__List::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> ()
+    inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) {
+        return at::_ops::_foreach_pow__Scalar::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> ()
+    inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> exponent) {
+        return at::_ops::_foreach_pow__ScalarList::redispatch(dispatchKeySet, self, exponent);
+    }
+    
+    // aten::_foreach_reciprocal(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_reciprocal(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_reciprocal::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_reciprocal_(Tensor(a!)[] self) -> ()
+    inline void _foreach_reciprocal_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_reciprocal_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_round(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_round(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_round::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_round_(Tensor(a!)[] self) -> ()
+    inline void _foreach_round_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_round_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sigmoid(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sigmoid(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sigmoid::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sigmoid_(Tensor(a!)[] self) -> ()
+    inline void _foreach_sigmoid_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sigmoid_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sign(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sign(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sign::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sign_(Tensor(a!)[] self) -> ()
+    inline void _foreach_sign_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sign_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sin(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sin::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sin_(Tensor(a!)[] self) -> ()
+    inline void _foreach_sin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sin_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sinh(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sinh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sinh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sinh_(Tensor(a!)[] self) -> ()
+    inline void _foreach_sinh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sinh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sqrt(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_sqrt(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sqrt::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_sqrt_(Tensor(a!)[] self) -> ()
+    inline void _foreach_sqrt_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_sqrt_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_tan(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_tan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_tan::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_tan_(Tensor(a!)[] self) -> ()
+    inline void _foreach_tan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_tan_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_tanh(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_tanh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_tanh::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_tanh_(Tensor(a!)[] self) -> ()
+    inline void _foreach_tanh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_tanh_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_trunc(Tensor[] self) -> Tensor[]
+    inline ::std::vector<at::Tensor> _foreach_trunc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_trunc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_trunc_(Tensor(a!)[] self) -> ()
+    inline void _foreach_trunc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_trunc_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_zero_(Tensor(a!)[] self) -> ()
+    inline void _foreach_zero_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_zero_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()
+    inline void _foreach_copy_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) {
+        return at::_ops::_foreach_copy_::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
+    // aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
+    inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) {
+        return at::_ops::bucketize_Tensor::redispatch(dispatchKeySet, self, boundaries, out_int32, right);
+    }
+    
+    // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) {
+        return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out);
+    }
+    
+    // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) {
+        return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out);
+    }
+    
+    // aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
+    inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) {
+        return at::_ops::bucketize_Scalar::redispatch(dispatchKeySet, self, boundaries, out_int32, right);
+    }
+    
+    // aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor
+    inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, c10::optional<c10::string_view> side=c10::nullopt, const c10::optional<at::Tensor> & sorter={}) {
+        return at::_ops::searchsorted_Tensor::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter);
+    }
+    
+    // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, c10::optional<c10::string_view> side=c10::nullopt, const c10::optional<at::Tensor> & sorter={}) {
+        return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out);
+    }
+    
+    // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<at::Tensor> & sorter, at::Tensor & out) {
+        return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out);
+    }
+    
+    // aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor
+    inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, c10::optional<c10::string_view> side=c10::nullopt, const c10::optional<at::Tensor> & sorter={}) {
+        return at::_ops::searchsorted_Scalar::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter);
+    }
+    
+    // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, c10::optional<c10::string_view> side=c10::nullopt, const c10::optional<at::Tensor> & sorter={}) {
+        return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out);
+    }
+    
+    // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<at::Tensor> & sorter, at::Tensor & out) {
+        return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out);
+    }
+    
+    // aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor
+    inline at::Tensor _convert_indices_from_coo_to_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32=false) {
+        return at::_ops::_convert_indices_from_coo_to_csr::redispatch(dispatchKeySet, self, size, out_int32);
+    }
+    
+    // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convert_indices_from_coo_to_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, bool out_int32=false) {
+        return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out);
+    }
+    
+    // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convert_indices_from_coo_to_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32, at::Tensor & out) {
+        return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out);
+    }
+    
+    // aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor
+    inline at::Tensor _convert_indices_from_csr_to_coo(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) {
+        return at::_ops::_convert_indices_from_csr_to_coo::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose);
+    }
+    
+    // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convert_indices_from_csr_to_coo_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) {
+        return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out);
+    }
+    
+    // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convert_indices_from_csr_to_coo_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose, at::Tensor & out) {
+        return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out);
+    }
+    
+    // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mse_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mse_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) {
+        return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+    inline at::Tensor mse_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::mse_loss::redispatch(dispatchKeySet, self, target, reduction);
+    }
+    
+    // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & mse_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input);
+    }
+    
+    // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & mse_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) {
+        return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input);
+    }
+    
+    // aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+    inline at::Tensor mse_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::mse_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction);
+    }
+    
+    // aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+    inline at::Tensor l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::l1_loss::redispatch(dispatchKeySet, self, target, reduction);
+    }
+    
+    // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multi_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out);
+    }
+    
+    // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multi_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & out) {
+        return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out);
+    }
+    
+    // aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
+    inline at::Tensor multi_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multi_margin_loss::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction);
+    }
+    
+    // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & multi_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input);
+    }
+    
+    // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & multi_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const c10::optional<at::Tensor> & weight, int64_t reduction, at::Tensor & grad_input) {
+        return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input);
+    }
+    
+    // aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor
+    inline at::Tensor multi_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multi_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction);
+    }
+    
+    // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multilabel_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & multilabel_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) {
+        return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+    inline at::Tensor multilabel_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::multilabel_margin_loss::redispatch(dispatchKeySet, self, target, reduction);
+    }
+    
+    // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> multilabel_margin_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & is_target, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target);
+    }
+    
+    // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> multilabel_margin_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & output, at::Tensor & is_target) {
+        return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target);
+    }
+    
+    // aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)
+    inline ::std::tuple<at::Tensor,at::Tensor> multilabel_margin_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::multilabel_margin_loss_forward::redispatch(dispatchKeySet, self, target, reduction);
+    }
+    
+    // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & multilabel_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) {
+        return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input);
+    }
+    
+    // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & multilabel_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target, at::Tensor & grad_input) {
+        return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input);
+    }
+    
+    // aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor
+    inline at::Tensor multilabel_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) {
+        return at::_ops::multilabel_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target);
+    }
+    
+    // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) {
+        return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) {
+        return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) {
+        return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) {
+        return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) {
+        return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) {
+        return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) {
+        return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) {
+        return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index) {
+        return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) {
+        return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index) {
+        return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) {
+        return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+    inline ::std::tuple<at::Tensor,at::Tensor> nll_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index) {
+        return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+    inline ::std::tuple<at::Tensor,at::Tensor> nll_loss_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index) {
+        return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) {
+        return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) {
+        return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+    inline at::Tensor nll_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight);
+    }
+    
+    // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+    inline at::Tensor nll_loss_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight);
+    }
+    
+    // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) {
+        return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) {
+        return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) {
+        return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) {
+        return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out);
+    }
+    
+    // aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) {
+        return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+    inline at::Tensor nll_loss2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) {
+        return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index) {
+        return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index) {
+        return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> nll_loss2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight);
+    }
+    
+    // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+    inline ::std::tuple<at::Tensor,at::Tensor> nll_loss2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index) {
+        return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+    inline ::std::tuple<at::Tensor,at::Tensor> nll_loss2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index) {
+        return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index);
+    }
+    
+    // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) {
+        return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & nll_loss2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) {
+        return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input);
+    }
+    
+    // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+    inline at::Tensor nll_loss2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight);
+    }
+    
+    // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+    inline at::Tensor nll_loss2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) {
+        return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight);
+    }
+    
+    // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & smooth_l1_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) {
+        return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out);
+    }
+    
+    // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & smooth_l1_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out) {
+        return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out);
+    }
+    
+    // aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
+    inline at::Tensor smooth_l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) {
+        return at::_ops::smooth_l1_loss::redispatch(dispatchKeySet, self, target, reduction, beta);
+    }
+    
+    // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & smooth_l1_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) {
+        return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input);
+    }
+    
+    // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & smooth_l1_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & grad_input) {
+        return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input);
+    }
+    
+    // aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
+    inline at::Tensor smooth_l1_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) {
+        return at::_ops::smooth_l1_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta);
+    }
+    
+    // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & huber_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) {
+        return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out);
+    }
+    
+    // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & huber_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & out) {
+        return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out);
+    }
+    
+    // aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor
+    inline at::Tensor huber_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) {
+        return at::_ops::huber_loss::redispatch(dispatchKeySet, self, target, reduction, delta);
+    }
+    
+    // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & huber_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
+        return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input);
+    }
+    
+    // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & huber_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input) {
+        return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input);
+    }
+    
+    // aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
+    inline at::Tensor huber_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) {
+        return at::_ops::huber_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta);
+    }
+    
+    // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & soft_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & soft_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) {
+        return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out);
+    }
+    
+    // aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+    inline at::Tensor soft_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::soft_margin_loss::redispatch(dispatchKeySet, self, target, reduction);
+    }
+    
+    // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & soft_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input);
+    }
+    
+    // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & soft_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) {
+        return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input);
+    }
+    
+    // aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+    inline at::Tensor soft_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) {
+        return at::_ops::soft_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction);
+    }
+    
+    // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & elu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) {
+        return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out);
+    }
+    
+    // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & elu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, at::Tensor & out) {
+        return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out);
+    }
+    
+    // aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
+    inline at::Tensor elu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) {
+        return at::_ops::elu::redispatch(dispatchKeySet, self, alpha, scale, input_scale);
+    }
+    
+    // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & elu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) {
+        return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input);
+    }
+    
+    // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & elu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result, at::Tensor & grad_input) {
+        return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input);
+    }
+    
+    // aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor
+    inline at::Tensor elu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) {
+        return at::_ops::elu_backward::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result);
+    }
+    
+    // aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
+    inline at::Tensor & elu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) {
+        return at::_ops::elu_::redispatch(dispatchKeySet, self, alpha, scale, input_scale);
+    }
+    
+    // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=-1) {
+        return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) {
+        return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::glu(Tensor self, int dim=-1) -> Tensor
+    inline at::Tensor glu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1) {
+        return at::_ops::glu::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & glu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) {
+        return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input);
+    }
+    
+    // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & glu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, at::Tensor & grad_input) {
+        return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input);
+    }
+    
+    // aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
+    inline at::Tensor glu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) {
+        return at::_ops::glu_backward::redispatch(dispatchKeySet, grad_output, self, dim);
+    }
+    
+    // aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor
+    inline at::Tensor glu_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) {
+        return at::_ops::glu_jvp::redispatch(dispatchKeySet, glu, x, dx, dim);
+    }
+    
+    // aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor
+    inline at::Tensor glu_backward_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) {
+        return at::_ops::glu_backward_jvp::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim);
+    }
+    
+    // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardsigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardsigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::hardsigmoid(Tensor self) -> Tensor
+    inline at::Tensor hardsigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::hardsigmoid::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & hardsigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::hardsigmoid_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardsigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardsigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) {
+        return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor hardsigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::hardsigmoid_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardtanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) {
+        return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out);
+    }
+    
+    // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardtanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & out) {
+        return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out);
+    }
+    
+    // aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+    inline at::Tensor hardtanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) {
+        return at::_ops::hardtanh::redispatch(dispatchKeySet, self, min_val, max_val);
+    }
+    
+    // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardtanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) {
+        return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input);
+    }
+    
+    // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & hardtanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & grad_input) {
+        return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input);
+    }
+    
+    // aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
+    inline at::Tensor hardtanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) {
+        return at::_ops::hardtanh_backward::redispatch(dispatchKeySet, grad_output, self, min_val, max_val);
+    }
+    
+    // aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!)
+    inline at::Tensor & hardtanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) {
+        return at::_ops::hardtanh_::redispatch(dispatchKeySet, self, min_val, max_val);
+    }
+    
+    // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardswish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardswish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::hardswish(Tensor self) -> Tensor
+    inline at::Tensor hardswish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::hardswish::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::hardswish_(Tensor(a!) self) -> Tensor(a!)
+    inline at::Tensor & hardswish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) {
+        return at::_ops::hardswish_::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor hardswish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::hardswish_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & leaky_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & negative_slope=0.01) {
+        return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out);
+    }
+    
+    // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & leaky_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope, at::Tensor & out) {
+        return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out);
+    }
+    
+    // aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
+    inline at::Tensor leaky_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope=0.01) {
+        return at::_ops::leaky_relu::redispatch(dispatchKeySet, self, negative_slope);
+    }
+    
+    // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & leaky_relu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) {
+        return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input);
+    }
+    
+    // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & leaky_relu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result, at::Tensor & grad_input) {
+        return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input);
+    }
+    
+    // aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
+    inline at::Tensor leaky_relu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) {
+        return at::_ops::leaky_relu_backward::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result);
+    }
+    
+    // aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
+    inline at::Tensor & leaky_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & negative_slope=0.01) {
+        return at::_ops::leaky_relu_::redispatch(dispatchKeySet, self, negative_slope);
+    }
+    
+    // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::log_sigmoid(Tensor self) -> Tensor
+    inline at::Tensor log_sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log_sigmoid::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & buffer, const at::Tensor & self) {
+        return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer);
+    }
+    
+    // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> log_sigmoid_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & output, at::Tensor & buffer) {
+        return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer);
+    }
+    
+    // aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
+    inline ::std::tuple<at::Tensor,at::Tensor> log_sigmoid_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::log_sigmoid_forward::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & log_sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) {
+        return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input);
+    }
+    
+    // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & log_sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer, at::Tensor & grad_input) {
+        return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input);
+    }
+    
+    // aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
+    inline at::Tensor log_sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) {
+        return at::_ops::log_sigmoid_backward::redispatch(dispatchKeySet, grad_output, self, buffer);
+    }
+    
+    // aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rrelu_with_noise_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out);
+    }
+    
+    // aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rrelu_with_noise_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out);
+    }
+    
+    // aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+    inline at::Tensor rrelu_with_noise(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::rrelu_with_noise::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator);
+    }
+    
+    // aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
+    inline at::Tensor rrelu_with_noise_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) {
+        return at::_ops::rrelu_with_noise_backward::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result);
+    }
+    
+    // aten::rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+    inline at::Tensor & rrelu_with_noise_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::rrelu_with_noise_::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator);
+    }
+    
+    // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softplus_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) {
+        return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out);
+    }
+    
+    // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softplus_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out) {
+        return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out);
+    }
+    
+    // aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
+    inline at::Tensor softplus(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) {
+        return at::_ops::softplus::redispatch(dispatchKeySet, self, beta, threshold);
+    }
+    
+    // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & softplus_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) {
+        return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input);
+    }
+    
+    // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & softplus_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input) {
+        return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input);
+    }
+    
+    // aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor
+    inline at::Tensor softplus_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) {
+        return at::_ops::softplus_backward::redispatch(dispatchKeySet, grad_output, self, beta, threshold);
+    }
+    
+    // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) {
+        return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out);
+    }
+    
+    // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & softshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) {
+        return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out);
+    }
+    
+    // aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+    inline at::Tensor softshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) {
+        return at::_ops::softshrink::redispatch(dispatchKeySet, self, lambd);
+    }
+    
+    // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & softshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) {
+        return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input);
+    }
+    
+    // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & softshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) {
+        return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input);
+    }
+    
+    // aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
+    inline at::Tensor softshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) {
+        return at::_ops::softshrink_backward::redispatch(dispatchKeySet, grad_output, self, lambd);
+    }
+    
+    // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+    inline at::Tensor adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size));
+    }
+    
+    // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+    inline at::Tensor adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
+    inline at::Tensor mkldnn_adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor mkldnn_adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+    inline at::Tensor _adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size));
+    }
+    
+    // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+    inline at::Tensor _adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor _adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+    inline at::Tensor adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size));
+    }
+    
+    // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+    inline at::Tensor adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+    inline at::Tensor _adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size));
+    }
+    
+    // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+    inline at::Tensor _adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) {
+        return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input);
+    }
+    
+    // aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
+    inline at::Tensor _adaptive_avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::_adaptive_avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self);
+    }
+    
+    // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> adaptive_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices);
+    }
+    
+    // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> adaptive_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices);
+    }
+    
+    // aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> adaptive_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_max_pool2d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input);
+    }
+    
+    // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input);
+    }
+    
+    // aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+    inline at::Tensor adaptive_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, indices);
+    }
+    
+    // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> adaptive_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices);
+    }
+    
+    // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> adaptive_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices);
+    }
+    
+    // aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> adaptive_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::adaptive_max_pool3d::redispatch(dispatchKeySet, self, output_size);
+    }
+    
+    // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input);
+    }
+    
+    // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & adaptive_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input);
+    }
+    
+    // aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+    inline at::Tensor adaptive_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) {
+        return at::_ops::adaptive_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, indices);
+    }
+    
+    // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, c10::optional<int64_t> divisor_override=c10::nullopt) {
+        return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out);
+    }
+    
+    // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & out) {
+        return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out);
+    }
+    
+    // aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+    inline at::Tensor avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, c10::optional<int64_t> divisor_override=c10::nullopt) {
+        return at::_ops::avg_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
+    }
+    
+    // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override) {
+        return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input);
+    }
+    
+    // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & grad_input) {
+        return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input);
+    }
+    
+    // aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+    inline at::Tensor avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override) {
+        return at::_ops::avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
+    }
+    
+    // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, c10::optional<int64_t> divisor_override=c10::nullopt) {
+        return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out);
+    }
+    
+    // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & out) {
+        return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out);
+    }
+    
+    // aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+    inline at::Tensor avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, c10::optional<int64_t> divisor_override=c10::nullopt) {
+        return at::_ops::avg_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
+    }
+    
+    // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override) {
+        return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input);
+    }
+    
+    // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & grad_input) {
+        return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input);
+    }
+    
+    // aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+    inline at::Tensor avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override) {
+        return at::_ops::avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
+    }
+    
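+    // NOTE (illustrative sketch, not part of the generated file): every wrapper here forwards to
+    // the matching at::_ops::<op>::redispatch entry with the caller-supplied DispatchKeySet, so
+    // the operator is re-entered at the next key in that set. Assuming these wrappers live in the
+    // at::redispatch namespace, a hypothetical call site inside a fallback kernel might be:
+    //
+    //   // `ks` is the DispatchKeySet handed to the fallback; `self` is the pooled input.
+    //   at::Tensor y = at::redispatch::avg_pool3d(ks, self, /*kernel_size=*/{2, 2, 2});
+    //
+    // The defaulted arguments (stride={}, padding=0, ceil_mode=false, count_include_pad=true,
+    // divisor_override=c10::nullopt) mirror the schema defaults shown in the aten:: comments.
+    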
+    // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fractional_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) {
+        return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices);
+    }
+    
+    // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fractional_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) {
+        return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices);
+    }
+    
+    // aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> fractional_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) {
+        return at::_ops::fractional_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples);
+    }
+    
+    // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & fractional_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) {
+        return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input);
+    }
+    
+    // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & fractional_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input);
+    }
+    
+    // aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
+    inline at::Tensor fractional_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) {
+        return at::_ops::fractional_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices);
+    }
+    
+    // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fractional_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) {
+        return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices);
+    }
+    
+    // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fractional_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) {
+        return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices);
+    }
+    
+    // aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> fractional_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) {
+        return at::_ops::fractional_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples);
+    }
+    
+    // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & fractional_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) {
+        return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input);
+    }
+    
+    // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & fractional_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input);
+    }
+    
+    // aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
+    inline at::Tensor fractional_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) {
+        return at::_ops::fractional_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices);
+    }
+    
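+    // NOTE (illustrative sketch, not part of the generated file): each op with an out variant is
+    // emitted twice: an "_out" overload taking the output tensor(s) first, and an "_outf" overload
+    // taking them last in schema order; both redispatch to the same _ops::<op>_out entry. A hedged
+    // example, assuming `output` and `indices` were allocated elsewhere:
+    //
+    //   at::redispatch::fractional_max_pool2d_out(ks, output, indices, self, {2, 2}, {5, 5}, samples);
+    //   at::redispatch::fractional_max_pool2d_outf(ks, self, {2, 2}, {5, 5}, samples, output, indices);
+    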
+    // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_pool2d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
+    }
+    
+    // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_pool2d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) {
+        return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
+    }
+    
+    // aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> max_pool2d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool2d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & max_pool2d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) {
+        return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input);
+    }
+    
+    // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & max_pool2d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input);
+    }
+    
+    // aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
+    inline at::Tensor max_pool2d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) {
+        return at::_ops::max_pool2d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices);
+    }
+    
+    // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_pool3d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
+    }
+    
+    // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> max_pool3d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) {
+        return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
+    }
+    
+    // aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> max_pool3d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool3d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode);
+    }
+    
+    // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & max_pool3d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) {
+        return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input);
+    }
+    
+    // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & max_pool3d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) {
+        return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input);
+    }
+    
+    // aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
+    inline at::Tensor max_pool3d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) {
+        return at::_ops::max_pool3d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices);
+    }
+    
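+    // NOTE (illustrative sketch, not part of the generated file): the *_with_indices ops return a
+    // (values, int64 indices) pair, and the backward variants consume those saved indices to route
+    // grad_output back to the argmax locations. A hedged example (C++17 structured bindings):
+    //
+    //   auto [pooled, idx] = at::redispatch::max_pool2d_with_indices(ks, self, /*kernel_size=*/{3, 3});
+    //   at::Tensor grad_self = at::redispatch::max_pool2d_with_indices_backward(
+    //       ks, grad_out, self, {3, 3}, /*stride=*/{3, 3}, /*padding=*/{0, 0},
+    //       /*dilation=*/{1, 1}, /*ceil_mode=*/false, idx);
+    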
+    // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) {
+        return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) {
+        return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out);
+    }
+    
+    // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out);
+    }
+    
+    // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
+    inline at::Tensor max_unpool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) {
+        return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size));
+    }
+    
+    // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
+    inline at::Tensor max_unpool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) {
+        return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, output_size);
+    }
+    
+    // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out);
+    }
+    
+    // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out);
+    }
+    
+    // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out);
+    }
+    
+    // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_unpool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out);
+    }
+    
+    // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
+    inline at::Tensor max_unpool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding);
+    }
+    
+    // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
+    inline at::Tensor max_unpool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, output_size, stride, padding);
+    }
+    
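+    // NOTE (illustrative sketch, not part of the generated file): ops whose schema uses SymInt[]
+    // get two overloads: the plain at::IntArrayRef one, which widens concrete sizes through
+    // c10::fromIntArrayRefSlow before redispatching, and a *_symint one that forwards a
+    // c10::SymIntArrayRef untouched (useful when sizes are symbolic, e.g. under tracing/export).
+    //
+    //   at::Tensor u  = at::redispatch::max_unpool2d(ks, pooled, idx, /*output_size=*/{8, 8});
+    //   // equivalent symbolic form, assuming `sym_sizes` is a c10::SymIntArrayRef:
+    //   at::Tensor u2 = at::redispatch::max_unpool2d_symint(ks, pooled, idx, sym_sizes);
+    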
+    // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor reflection_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor reflection_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor reflection_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor reflection_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
+    // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor reflection_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor reflection_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor reflection_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor reflection_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
+    // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor reflection_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor reflection_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & reflection_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor reflection_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor reflection_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
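+    // NOTE (illustrative sketch, not part of the generated file): reflection padding mirrors the
+    // input without repeating the border element, so each padding amount must be smaller than the
+    // corresponding input dimension; replication padding (below) repeats the edge value and has no
+    // such restriction. Hedged example for an (N, C, H, W) input:
+    //
+    //   // padding order is {left, right, top, bottom}
+    //   at::Tensor r = at::redispatch::reflection_pad2d(ks, x, /*padding=*/{1, 1, 2, 2});
+    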
+    // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor replication_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor replication_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor replication_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+    inline at::Tensor replication_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
+    // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor replication_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor replication_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor replication_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+    inline at::Tensor replication_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
+    // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out);
+    }
+    
+    // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor replication_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor replication_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, padding);
+    }
+    
+    // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input);
+    }
+    
+    // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & replication_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) {
+        return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input);
+    }
+    
+    // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor replication_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) {
+        return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+    inline at::Tensor replication_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) {
+        return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding);
+    }
+    
+    // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor
+    inline at::Tensor _pad_circular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad) {
+        return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad));
+    }
+    
+    // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor
+    inline at::Tensor _pad_circular_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad) {
+        return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, pad);
+    }
+    
+    // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor
+    inline at::Tensor _pad_enum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, int64_t mode, c10::optional<double> value=c10::nullopt) {
+        return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value);
+    }
+    
+    // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor
+    inline at::Tensor _pad_enum_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, c10::optional<double> value=c10::nullopt) {
+        return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, pad, mode, value);
+    }
+    
+    // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
+    inline at::Tensor pad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, c10::string_view mode="constant", c10::optional<double> value=c10::nullopt) {
+        return at::_ops::pad::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value);
+    }
+    
+    // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
+    inline at::Tensor pad_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode="constant", c10::optional<double> value=c10::nullopt) {
+        return at::_ops::pad::redispatch(dispatchKeySet, self, pad, mode, value);
+    }
+    
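+    // NOTE (illustrative sketch, not part of the generated file): aten::pad is the generic entry
+    // point behind torch.nn.functional.pad; `pad` lists amounts starting from the last dimension
+    // ({last_left, last_right, ...}), `mode` is one of "constant", "reflect", "replicate" or
+    // "circular", and `value` applies only to constant padding. Hedged example:
+    //
+    //   at::Tensor a = at::redispatch::pad(ks, self, /*pad=*/{1, 1, 2, 2}, "reflect");
+    //   at::Tensor b = at::redispatch::pad(ks, self, {1, 1}, "constant", /*value=*/0.0);
+    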
+    // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
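+    // NOTE (illustrative sketch, not part of the generated file): the .vec upsample overloads take
+    // an optional output_size and optional per-dimension scale_factors; callers are expected to
+    // supply exactly one of the two. The IntArrayRef overload wraps a present output_size through
+    // c10::fromIntArrayRefSlow before redispatching. Hedged example for the 1-D case:
+    //
+    //   // by explicit output size:
+    //   at::Tensor a = at::redispatch::upsample_linear1d(ks, input, at::IntArrayRef{64},
+    //                                                    /*align_corners=*/false, c10::nullopt);
+    //   // by scale factor instead:
+    //   const double scales[] = {2.0};
+    //   at::Tensor b = at::redispatch::upsample_linear1d(ks, input, c10::nullopt,
+    //                                                    /*align_corners=*/false, at::ArrayRef<double>(scales));
+    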
+    // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, align_corners, scale_factors);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors);
+    }
+    
+    // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
+    // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
+    // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, scale_factors);
+    }
+    
+    // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+        return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors);
+    }
+    
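+    // NB: The fixed-rank overloads below (SymInt[1], SymInt[2], ...) follow the
+    // usual out-variant convention: `*_out` takes `out` first with the optional
+    // scale arguments defaulted, `*_outf` takes every argument explicitly with
+    // `out` last, and the `_symint` forms repeat both shapes for
+    // c10::SymIntArrayRef sizes. All of them redispatch to the same `.out` op.
+    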
+    // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out);
+    }
+    
+    // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out);
+    }
+    
+    // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out);
+    }
+    
+    // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out);
+    }
+    
+    // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
+    inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales);
+    }
+    
+    // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
+    inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, output_size, align_corners, scales);
+    }
+    
+    // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input);
+    }
+    
+    // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input);
+    }
+    
+    // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input);
+    }
+    
+    // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_linear1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input);
+    }
+    
+    // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
+    inline at::Tensor upsample_linear1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales);
+    }
+    
+    // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
+    inline at::Tensor upsample_linear1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales);
+    }
+    
+    // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bilinear2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bilinear2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bilinear2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bilinear2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_bicubic2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bicubic2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_bicubic2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_bicubic2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+    }
+    
+    // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_trilinear3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_trilinear3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_trilinear3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w);
+    }
+    
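+    // NB: The nearest-neighbor wrappers below drop `align_corners` and take a
+    // single optional `scales` in the 1d case, but otherwise mirror the
+    // out/outf/_symint structure of the linear and cubic wrappers above.
+    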
+    // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out);
+    }
+    
+    // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out);
+    }
+    
+    // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out);
+    }
+    
+    // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out);
+    }
+    
+    // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out);
+    }
+    
+    // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out);
+    }
+    
+    // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out);
+    }
+    
+    // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out);
+    }
+    
+    // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+    inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales);
+    }
+    
+    // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+    inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, output_size, scales);
+    }
+    
+    // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales);
+    }
+    
+    // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, output_size, scales);
+    }
+    
+    // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input);
+    }
+    
+    // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input);
+    }
+    
+    // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input);
+    }
+    
+    // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input);
+    }
+    
+    // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+    inline at::Tensor upsample_nearest1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales);
+    }
+    
+    // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+    inline at::Tensor upsample_nearest1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales);
+    }
+    
+    // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales);
+    }
+    
+    // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out) {
+        return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out);
+    }
+    
+    // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & upsample_nearest3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & _upsample_nearest_exact3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & grad_input) {
+        return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input);
+    }
+    
+    // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w);
+    }
+    
+    // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor upsample_nearest3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w);
+    }
+    
+    // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+    inline at::Tensor _upsample_nearest_exact3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional<double> scales_d=c10::nullopt, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt) {
+        return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
+    }
+    
+    // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) {
+        return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input);
+    }
+    
+    // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) {
+        return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input);
+    }
+    
+    // aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
+    inline at::Tensor sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) {
+        return at::_ops::sigmoid_backward::redispatch(dispatchKeySet, grad_output, output);
+    }
+    
+    // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & logit_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input);
+    }
+    
+    // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & logit_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::optional<double> eps, at::Tensor & grad_input) {
+        return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input);
+    }
+    
+    // aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
+    inline at::Tensor logit_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::logit_backward::redispatch(dispatchKeySet, grad_output, self, eps);
+    }
+    
+    // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & tanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) {
+        return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input);
+    }
+    
+    // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+    inline at::Tensor & tanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) {
+        return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input);
+    }
+    
+    // aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor
+    inline at::Tensor tanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) {
+        return at::_ops::tanh_backward::redispatch(dispatchKeySet, grad_output, output);
+    }
+    
+    // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out);
+    }
+    
+    // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out);
+    }
+    
+    // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation);
+    }
+    
+    // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out);
+    }
+    
+    // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_transpose3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out);
+    }
+    
+    // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation);
+    }
+    
+    // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & thnn_conv2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) {
+        return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & thnn_conv2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & thnn_conv2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) {
+        return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out);
+    }
+    
+    // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & thnn_conv2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out);
+    }
+    
+    // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor
+    inline at::Tensor thnn_conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) {
+        return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor
+    inline at::Tensor thnn_conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) {
+        return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding);
+    }
+    
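+    // NOTE: illustrative sketch, not part of the generated API surface. These
+    // wrappers forward to at::_ops::<op>::redispatch with an explicit
+    // DispatchKeySet. Assuming they live in the usual at::redispatch namespace,
+    // a caller that already holds a key set (e.g. a custom kernel that wants to
+    // re-enter the dispatcher below the keys it has already handled) might
+    // invoke one roughly like this:
+    //
+    //   at::Tensor my_conv(c10::DispatchKeySet ks, const at::Tensor & input,
+    //                      const at::Tensor & weight) {
+    //       // kernel_size {3, 3}; bias/stride/padding fall back to defaults
+    //       return at::redispatch::thnn_conv2d(ks, input, weight, {3, 3});
+    //   }
+    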
+    // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & _slow_conv2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output);
+    }
+    
+    // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & _slow_conv2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) {
+        return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output);
+    }
+    
+    // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & _slow_conv2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) {
+        return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output);
+    }
+    
+    // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & _slow_conv2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) {
+        return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output);
+    }
+    
+    // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor
+    inline at::Tensor _slow_conv2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor
+    inline at::Tensor _slow_conv2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) {
+        return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding);
+    }
+    
+    // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias);
+    }
+    
+    // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) {
+        return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias);
+    }
+    
+    // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) {
+        return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias);
+    }
+    
+    // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) {
+        return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _slow_conv2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array<bool,3> output_mask) {
+        return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _slow_conv2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array<bool,3> output_mask) {
+        return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask);
+    }
+    
+    // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _conv_depthwise2d_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) {
+        return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _conv_depthwise2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, const at::Tensor & out) {
+        return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _conv_depthwise2d_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) {
+        return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _conv_depthwise2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, const at::Tensor & out) {
+        return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor
+    inline at::Tensor _conv_depthwise2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) {
+        return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor
+    inline at::Tensor _conv_depthwise2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) {
+        return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation);
+    }
+    
+    // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor
+    inline at::Tensor conv_depthwise3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) {
+        return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor
+    inline at::Tensor conv_depthwise3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) {
+        return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation);
+    }
+    
+    // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) {
+        return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) {
+        return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out);
+    }
+    
+    // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) {
+        return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out);
+    }
+    
+    // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) {
+        return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out);
+    }
+    
+    // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor
+    inline at::Tensor slow_conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) {
+        return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor
+    inline at::Tensor slow_conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) {
+        return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding);
+    }
+    
+    // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output);
+    }
+    
+    // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) {
+        return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output);
+    }
+    
+    // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) {
+        return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output);
+    }
+    
+    // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+    inline at::Tensor & slow_conv3d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) {
+        return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output);
+    }
+    
+    // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor
+    inline at::Tensor slow_conv3d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding) {
+        return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding));
+    }
+    
+    // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor
+    inline at::Tensor slow_conv3d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) {
+        return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding);
+    }
+    
+    // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_dilated2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_dilated2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation);
+    }
+    
+    // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_dilated3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation));
+    }
+    
+    // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor
+    inline at::Tensor slow_conv_dilated3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation);
+    }
+    
+    // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col2im_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col2im_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col2im_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col2im_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+    inline at::Tensor col2im(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::col2im::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride);
+    }
+    
+    // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+    inline at::Tensor col2im_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::col2im::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride);
+    }
+    
+    // aten::column_stack(Tensor[] tensors) -> Tensor
+    inline at::Tensor column_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::column_stack::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & column_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & column_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & im2col_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & im2col_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out);
+    }
+    
+    // aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+    inline at::Tensor im2col(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
+        return at::_ops::im2col::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride);
+    }
+    
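+    // NOTE: illustrative sketch, not part of the generated header. im2col and
+    // col2im above are (up to summation over overlaps) inverse operations:
+    // im2col unfolds sliding local blocks of the input into columns, while
+    // col2im folds columns back onto the spatial grid, summing values that
+    // land on the same position. Using the public at:: wrappers:
+    //
+    //   // x: [N, C, H, W]
+    //   auto cols   = at::im2col(x, /*kernel_size=*/{3, 3}, /*dilation=*/{1, 1},
+    //                            /*padding=*/{1, 1}, /*stride=*/{1, 1});
+    //   auto folded = at::col2im(cols, /*output_size=*/{x.size(2), x.size(3)},
+    //                            /*kernel_size=*/{3, 3}, /*dilation=*/{1, 1},
+    //                            /*padding=*/{1, 1}, /*stride=*/{1, 1});
+    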
+    // aten::isfinite(Tensor self) -> Tensor
+    inline at::Tensor isfinite(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isfinite::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::isinf(Tensor self) -> Tensor
+    inline at::Tensor isinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isinf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::record_stream(Tensor(a!) self, Stream s) -> ()
+    inline void record_stream(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Stream s) {
+        return at::_ops::record_stream::redispatch(dispatchKeySet, self, s);
+    }
+    
+    // aten::isposinf(Tensor self) -> Tensor
+    inline at::Tensor isposinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isposinf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isposinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isposinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::isneginf(Tensor self) -> Tensor
+    inline at::Tensor isneginf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::isneginf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isneginf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isneginf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor
+    inline at::Tensor _add_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t batch_dim, int64_t level) {
+        return at::_ops::_add_batch_dim::redispatch(dispatchKeySet, self, batch_dim, level);
+    }
+    
+    // aten::_remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor
+    inline at::Tensor _remove_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, int64_t batch_size, int64_t out_dim) {
+        return at::_ops::_remove_batch_dim::redispatch(dispatchKeySet, self, level, batch_size, out_dim);
+    }
+    
+    // aten::special_entr(Tensor self) -> Tensor
+    inline at::Tensor special_entr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_entr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_entr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_entr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_ndtri(Tensor self) -> Tensor
+    inline at::Tensor special_ndtri(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_ndtri::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_ndtri_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_ndtri_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_log_ndtr(Tensor self) -> Tensor
+    inline at::Tensor special_log_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_log_ndtr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_log_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_log_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_expm1(Tensor self) -> Tensor
+    inline at::Tensor special_expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_expm1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_exp2(Tensor self) -> Tensor
+    inline at::Tensor special_exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_exp2::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_psi(Tensor self) -> Tensor
+    inline at::Tensor special_psi(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_psi::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_psi_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_psi_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_digamma(Tensor self) -> Tensor
+    inline at::Tensor special_digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_digamma::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_gammaln(Tensor self) -> Tensor
+    inline at::Tensor special_gammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_gammaln::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erf(Tensor self) -> Tensor
+    inline at::Tensor special_erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_erf::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfc(Tensor self) -> Tensor
+    inline at::Tensor special_erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_erfc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfcx(Tensor self) -> Tensor
+    inline at::Tensor special_erfcx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_erfcx::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfcx_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfcx_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfinv(Tensor self) -> Tensor
+    inline at::Tensor special_erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_erfinv::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_ndtr(Tensor self) -> Tensor
+    inline at::Tensor special_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_ndtr::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_xlog1py(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_xlog1py::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_xlog1py_self_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_xlog1py_other_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_xlogy::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_xlogy_self_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_xlogy_other_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_zeta::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
+    inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_zeta_self_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
+    inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_zeta_other_scalar::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_i0(Tensor self) -> Tensor
+    inline at::Tensor special_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_i0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i0e(Tensor self) -> Tensor
+    inline at::Tensor special_i0e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_i0e::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i0e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i0e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i1(Tensor self) -> Tensor
+    inline at::Tensor special_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_i1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i1e(Tensor self) -> Tensor
+    inline at::Tensor special_i1e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_i1e::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i1e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_i1e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_logit(Tensor self, float? eps=None) -> Tensor
+    inline at::Tensor special_logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::special_logit::redispatch(dispatchKeySet, self, eps);
+    }
+    
+    // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<double> eps=c10::nullopt) {
+        return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out);
+    }
+    
+    // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> eps, at::Tensor & out) {
+        return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out);
+    }
+    
+    // aten::special_polygamma(int n, Tensor self) -> Tensor
+    inline at::Tensor special_polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) {
+        return at::_ops::special_polygamma::redispatch(dispatchKeySet, n, self);
+    }
+    
+    // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) {
+        return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out);
+    }
+    
+    // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out);
+    }
+    
+    // aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+    inline at::Tensor special_logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::special_logsumexp::redispatch(dispatchKeySet, self, dim, keepdim);
+    }
+    
+    // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) {
+        return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) {
+        return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out);
+    }
+    
+    // aten::special_expit(Tensor self) -> Tensor
+    inline at::Tensor special_expit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_expit::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_expit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_expit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_sinc(Tensor self) -> Tensor
+    inline at::Tensor special_sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_sinc::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_round(Tensor self, *, int decimals=0) -> Tensor
+    inline at::Tensor special_round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals=0) {
+        return at::_ops::special_round::redispatch(dispatchKeySet, self, decimals);
+    }
+    
+    // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals=0) {
+        return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out);
+    }
+    
+    // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) {
+        return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out);
+    }
+    
+    // aten::special_log1p(Tensor self) -> Tensor
+    inline at::Tensor special_log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_log1p::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor special_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::special_log_softmax::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammainc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammainc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_gammainc(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor special_gammainc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_gammainc::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammaincc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_gammaincc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::special_gammaincc(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor special_gammaincc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::special_gammaincc::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::special_multigammaln(Tensor self, int p) -> Tensor
+    inline at::Tensor special_multigammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) {
+        return at::_ops::special_multigammaln::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_multigammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) {
+        return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_multigammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) {
+        return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor special_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::special_softmax::redispatch(dispatchKeySet, self, dim, dtype);
+    }
+    
+    // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_fft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_fft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_ifft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_ifft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_rfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_rfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_irfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_irfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_hfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_hfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_hfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_hfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_hfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_hfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n, dim, norm);
+    }
+    
+    // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ihfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ihfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? c10::make_optional(c10::SymInt(*n)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ihfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<c10::SymInt> n=c10::nullopt, int64_t dim=-1, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ihfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<c10::SymInt> n, int64_t dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out);
+    }
+    
+    // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_fft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_fft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_ifft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_ifft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_rfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_rfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_irfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_irfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_hfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_hfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfft2_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfft2_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfft2_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfft2_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::IntArrayRef dim={-2,-1}, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_fftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_fftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_ifftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_ifftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_ifftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_rfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_rfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_irfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_irfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_irfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, at::Tensor & out) {
+        return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_hfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_hfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfftn_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfftn_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_hfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm);
+    }
+    
+    // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+    inline at::Tensor fft_ihfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s, dim, norm);
+    }
+    
+    // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfftn_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*s)) : c10::nullopt, dim, norm, out);
+    }
+    
+    // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfftn_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, c10::optional<c10::string_view> norm=c10::nullopt) {
+        return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & fft_ihfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, c10::optional<c10::string_view> norm, const at::Tensor & out) {
+        return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out);
+    }
+    
+    // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) {
+        return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) {
+        return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out);
+    }
+    
+    // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_fftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) {
+        return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out);
+    }
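+    
+    // Illustrative usage sketch (not part of the generated API; variable names are placeholders):
+    // user code normally reaches these wrappers through the public at:: namespace rather than
+    // calling redispatch directly, e.g.
+    //
+    //   at::Tensor freqs = at::fft_fftfreq(8, /*d=*/0.5, at::TensorOptions().dtype(at::kFloat));
+    //
+    // The TensorOptions overload above unpacks dtype/layout/device/pin_memory before redispatching,
+    // while the second overload forwards those optionals verbatim.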
+    
+    // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) {
+        return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    }
+    
+    // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+        return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) {
+        return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out);
+    }
+    
+    // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fft_rfftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) {
+        return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out);
+    }
+    
+    // aten::fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor
+    inline at::Tensor fft_fftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt) {
+        return at::_ops::fft_fftshift::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor
+    inline at::Tensor fft_ifftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt) {
+        return at::_ops::fft_ifftshift::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_cholesky_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false, bool check_errors=false) {
+        return at::_ops::linalg_cholesky_ex::redispatch(dispatchKeySet, self, upper, check_errors);
+    }
+    
+    // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_cholesky_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & L, at::Tensor & info, const at::Tensor & self, bool upper=false, bool check_errors=false) {
+        return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info);
+    }
+    
+    // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_cholesky_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, bool check_errors, at::Tensor & L, at::Tensor & info) {
+        return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info);
+    }
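+    
+    // Illustrative usage sketch (not part of the generated API; A is a placeholder tensor): the
+    // factor/info pair is normally consumed through the public wrapper, e.g.
+    //
+    //   auto [L, info] = at::linalg_cholesky_ex(A, /*upper=*/false, /*check_errors=*/false);
+    //
+    // Note the *_out/*_outf convention used throughout this header: *_out takes the output tensors
+    // first so the trailing arguments can keep their defaults, while *_outf keeps the schema order
+    // with the outputs last.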
+    
+    // aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
+    inline at::Tensor linalg_cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) {
+        return at::_ops::linalg_cholesky::redispatch(dispatchKeySet, self, upper);
+    }
+    
+    // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) {
+        return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) {
+        return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out);
+    }
+    
+    // aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
+    inline at::Tensor linalg_cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) {
+        return at::_ops::linalg_cross::redispatch(dispatchKeySet, self, other, dim);
+    }
+    
+    // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) {
+        return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out);
+    }
+    
+    // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim, at::Tensor & out) {
+        return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out);
+    }
+    
+    // aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_lu_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) {
+        return at::_ops::linalg_lu_factor::redispatch(dispatchKeySet, A, pivot);
+    }
+    
+    // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_lu_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A, bool pivot=true) {
+        return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots);
+    }
+    
+    // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_lu_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & LU, at::Tensor & pivots) {
+        return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots);
+    }
+    
+    // aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_lu_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true, bool check_errors=false) {
+        return at::_ops::linalg_lu_factor_ex::redispatch(dispatchKeySet, A, pivot, check_errors);
+    }
+    
+    // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, bool pivot=true, bool check_errors=false) {
+        return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info);
+    }
+    
+    // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, bool check_errors, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) {
+        return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info);
+    }
+    
+    // aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_lu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) {
+        return at::_ops::linalg_lu::redispatch(dispatchKeySet, A, pivot);
+    }
+    
+    // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & A, bool pivot=true) {
+        return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U);
+    }
+    
+    // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_lu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U) {
+        return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U);
+    }
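+    
+    // Illustrative usage sketch (not part of the generated API; A is a placeholder tensor):
+    //
+    //   auto [LU, pivots] = at::linalg_lu_factor(A);          // compact LU with partial pivoting
+    //   auto [P, L, U]    = at::linalg_lu(A, /*pivot=*/true);  // explicit P, L, U factors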
+    
+    // aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor
+    inline at::Tensor linalg_lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) {
+        return at::_ops::linalg_lu_solve::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint);
+    }
+    
+    // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) {
+        return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out);
+    }
+    
+    // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint, at::Tensor & out) {
+        return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out);
+    }
+    
+    // aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) {
+        return at::_ops::_linalg_det::redispatch(dispatchKeySet, A);
+    }
+    
+    // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) {
+        return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots);
+    }
+    
+    // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots) {
+        return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots);
+    }
+    
+    // aten::linalg_det(Tensor A) -> Tensor
+    inline at::Tensor linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) {
+        return at::_ops::linalg_det::redispatch(dispatchKeySet, A);
+    }
+    
+    // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) {
+        return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out);
+    }
+    
+    // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) {
+        return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out);
+    }
+    
+    // aten::det(Tensor self) -> Tensor
+    inline at::Tensor det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::det::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_ldl_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false, bool check_errors=false) {
+        return at::_ops::linalg_ldl_factor_ex::redispatch(dispatchKeySet, self, hermitian, check_errors);
+    }
+    
+    // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_ldl_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info, const at::Tensor & self, bool hermitian=false, bool check_errors=false) {
+        return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info);
+    }
+    
+    // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_ldl_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, bool check_errors, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info) {
+        return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info);
+    }
+    
+    // aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_ldl_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false) {
+        return at::_ops::linalg_ldl_factor::redispatch(dispatchKeySet, self, hermitian);
+    }
+    
+    // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_ldl_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, const at::Tensor & self, bool hermitian=false) {
+        return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots);
+    }
+    
+    // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_ldl_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, at::Tensor & LD, at::Tensor & pivots) {
+        return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots);
+    }
+    
+    // aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_ldl_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) {
+        return at::_ops::linalg_ldl_solve::redispatch(dispatchKeySet, LD, pivots, B, hermitian);
+    }
+    
+    // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_ldl_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) {
+        return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out);
+    }
+    
+    // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_ldl_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out);
+    }
+    
+    // aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> linalg_lstsq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, c10::optional<double> rcond=c10::nullopt, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_lstsq::redispatch(dispatchKeySet, self, b, rcond, driver);
+    }
+    
+    // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> linalg_lstsq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values, const at::Tensor & self, const at::Tensor & b, c10::optional<double> rcond=c10::nullopt, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values);
+    }
+    
+    // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> linalg_lstsq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, c10::optional<double> rcond, c10::optional<c10::string_view> driver, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values) {
+        return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values);
+    }
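+    
+    // Illustrative usage sketch (not part of the generated API; A and b are placeholder tensors):
+    // the schema's optional float/str arguments surface here as c10::optional<double> and
+    // c10::optional<c10::string_view>, e.g.
+    //
+    //   auto [solution, residuals, rank, singular_values] =
+    //       at::linalg_lstsq(A, b, /*rcond=*/c10::nullopt, /*driver=*/"gelsd");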
+    
+    // aten::linalg_matmul(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor linalg_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::linalg_matmul::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+    inline at::Tensor linalg_vecdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) {
+        return at::_ops::linalg_vecdot::redispatch(dispatchKeySet, x, y, dim);
+    }
+    
+    // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_vecdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) {
+        return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out);
+    }
+    
+    // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_vecdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim, at::Tensor & out) {
+        return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out);
+    }
+    
+    // aten::linalg_matrix_exp(Tensor self) -> Tensor
+    inline at::Tensor linalg_matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::linalg_matrix_exp::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) {
+        return at::_ops::_linalg_slogdet::redispatch(dispatchKeySet, A);
+    }
+    
+    // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) {
+        return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots);
+    }
+    
+    // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots) {
+        return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots);
+    }
+    
+    // aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) {
+        return at::_ops::linalg_slogdet::redispatch(dispatchKeySet, A);
+    }
+    
+    // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & A) {
+        return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet);
+    }
+    
+    // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet) {
+        return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet);
+    }
+    
+    // aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
+    inline ::std::tuple<at::Tensor,at::Tensor> slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::slogdet::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & self) {
+        return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet);
+    }
+    
+    // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet) {
+        return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet);
+    }
+    
+    // aten::logdet(Tensor self) -> Tensor
+    inline at::Tensor logdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::logdet::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_eig(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::linalg_eig::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_eig_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & self) {
+        return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors);
+    }
+    
+    // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_eig_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & eigenvalues, at::Tensor & eigenvectors) {
+        return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors);
+    }
+    
+    // aten::_linalg_eigvals(Tensor self) -> Tensor
+    inline at::Tensor _linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_linalg_eigvals::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::linalg_eigvals(Tensor self) -> Tensor
+    inline at::Tensor linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::linalg_eigvals::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_eigvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_eigvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
+    inline ::std::tuple<at::Tensor,at::Tensor> _linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) {
+        return at::_ops::_linalg_eigh::redispatch(dispatchKeySet, A, UPLO, compute_v);
+    }
+    
+    // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) {
+        return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors);
+    }
+    
+    // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO, bool compute_v, at::Tensor & eigenvalues, at::Tensor & eigenvectors) {
+        return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors);
+    }
+    
+    // aten::linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") {
+        return at::_ops::linalg_eigh::redispatch(dispatchKeySet, self, UPLO);
+    }
+    
+    // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigvals, at::Tensor & eigvecs, const at::Tensor & self, c10::string_view UPLO="L") {
+        return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs);
+    }
+    
+    // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs) {
+        return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs);
+    }
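+    
+    // Illustrative usage sketch (not part of the generated API; A is a placeholder Hermitian tensor):
+    // UPLO selects which triangle of the input is read ("L" or "U"), e.g.
+    //
+    //   auto [eigenvalues, eigenvectors] = at::linalg_eigh(A, /*UPLO=*/"L");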
+    
+    // aten::linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
+    inline at::Tensor linalg_eigvalsh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") {
+        return at::_ops::linalg_eigvalsh::redispatch(dispatchKeySet, self, UPLO);
+    }
+    
+    // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_eigvalsh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view UPLO="L") {
+        return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out);
+    }
+    
+    // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_eigvalsh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & out) {
+        return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out);
+    }
+    
+    // aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor
+    inline at::Tensor linalg_householder_product(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau) {
+        return at::_ops::linalg_householder_product::redispatch(dispatchKeySet, input, tau);
+    }
+    
+    // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_householder_product_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tau) {
+        return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out);
+    }
+    
+    // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_householder_product_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau, at::Tensor & out) {
+        return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out);
+    }
+    
+    // aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_inv_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors=false) {
+        return at::_ops::linalg_inv_ex::redispatch(dispatchKeySet, A, check_errors);
+    }
+    
+    // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_inv_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & inverse, at::Tensor & info, const at::Tensor & A, bool check_errors=false) {
+        return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info);
+    }
+    
+    // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_inv_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info) {
+        return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info);
+    }
+    
+    // aten::linalg_inv(Tensor A) -> Tensor
+    inline at::Tensor linalg_inv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) {
+        return at::_ops::linalg_inv::redispatch(dispatchKeySet, A);
+    }
+    
+    // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_inv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) {
+        return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out);
+    }
+    
+    // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_inv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) {
+        return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out);
+    }
+    
+    // aten::inverse(Tensor self) -> Tensor
+    inline at::Tensor inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::inverse::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::inner(Tensor self, Tensor other) -> Tensor
+    inline at::Tensor inner(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::inner::redispatch(dispatchKeySet, self, other);
+    }
+    
+    // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & inner_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & inner_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::outer(Tensor self, Tensor vec2) -> Tensor
+    inline at::Tensor outer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) {
+        return at::_ops::outer::redispatch(dispatchKeySet, self, vec2);
+    }
+    
+    // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & outer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) {
+        return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out);
+    }
+    
+    // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & outer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) {
+        return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out);
+    }
+    
+    // aten::ger(Tensor self, Tensor vec2) -> Tensor
+    inline at::Tensor ger(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) {
+        return at::_ops::ger::redispatch(dispatchKeySet, self, vec2);
+    }
+    
+    // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ger_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) {
+        return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out);
+    }
+    
+    // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ger_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) {
+        return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out);
+    }
+    
+    // aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & ord=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype);
+    }
+    
+    // aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_norm_ord_str::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype);
+    }
+    
+    // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & ord=c10::nullopt, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & ord, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
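+    
+    // Illustrative usage sketch (not part of the generated API; x is a placeholder tensor): the
+    // Scalar-ord overloads bind aten::linalg_norm, while the string-ord overloads bind the
+    // linalg_norm.ord_str variant, e.g.
+    //
+    //   at::Tensor n2  = at::linalg_norm(x);          // ord=None -> 2-norm / Frobenius norm
+    //   at::Tensor nuc = at::linalg_norm(x, "nuc");   // string ords route to .ord_str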
+    
+    // aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor linalg_vector_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_vector_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype);
+    }
+    
+    // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_vector_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_vector_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_matrix_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype);
+    }
+    
+    // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_matrix_norm_str_ord::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype);
+    }
+    
+    // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::_linalg_svd::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver);
+    }
+    
+    // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh);
+    }
+    
+    // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) {
+        return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh);
+    }
+    
+    // aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=true, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_svd::redispatch(dispatchKeySet, A, full_matrices, driver);
+    }
+    
+    // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=true, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh);
+    }
+    
+    // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) {
+        return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh);
+    }
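+    
+    // Illustrative usage sketch (not part of the generated API; A is a placeholder tensor):
+    // linalg_svd returns (U, S, Vh), while _linalg_svd additionally exposes compute_uv for
+    // internal callers, e.g.
+    //
+    //   auto [U, S, Vh] = at::linalg_svd(A, /*full_matrices=*/false);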
+    
+    // aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor
+    inline at::Tensor linalg_svdvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_svdvals::redispatch(dispatchKeySet, A, driver);
+    }
+    
+    // aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_svdvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, c10::optional<c10::string_view> driver=c10::nullopt) {
+        return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out);
+    }
+    
+    // aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_svdvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::optional<c10::string_view> driver, at::Tensor & out) {
+        return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out);
+    }
+    
+    // aten::linalg_cond(Tensor self, Scalar? p=None) -> Tensor
+    inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p=c10::nullopt) {
+        return at::_ops::linalg_cond::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p=c10::nullopt) {
+        return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::Tensor & out) {
+        return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::linalg_cond.p_str(Tensor self, str p) -> Tensor
+    inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p) {
+        return at::_ops::linalg_cond_p_str::redispatch(dispatchKeySet, self, p);
+    }
+    
+    // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view p) {
+        return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p, at::Tensor & out) {
+        return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false) {
+        return at::_ops::linalg_pinv_atol_rtol_tensor::redispatch(dispatchKeySet, self, atol, rtol, hermitian);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false) {
+        return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & atol, const c10::optional<at::Tensor> & rtol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian=false) {
+        return at::_ops::linalg_pinv_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian=false) {
+        return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian=false) {
+        return at::_ops::linalg_pinv::redispatch(dispatchKeySet, self, rcond, hermitian);
+    }
+    
+    // aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) {
+        return at::_ops::linalg_pinv_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian);
+    }
+    
+    // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double rcond, bool hermitian=false) {
+        return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) {
+        return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out);
+    }
+    
+    // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out);
+    }
+    
+    // aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) {
+        return at::_ops::_linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors);
+    }
+    
+    // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) {
+        return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info);
+    }
+    
+    // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) {
+        return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info);
+    }
+    
+    // aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) {
+        return at::_ops::linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors);
+    }
+    
+    // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) {
+        return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info);
+    }
+    
+    // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & info) {
+        return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info);
+    }
+    
+    // aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
+    inline at::Tensor linalg_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true) {
+        return at::_ops::linalg_solve::redispatch(dispatchKeySet, A, B, left);
+    }
+    
+    // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, const at::Tensor & B, bool left=true) {
+        return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out);
+    }
+    
+    // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, at::Tensor & out) {
+        return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out);
+    }
+    
+    // aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor
+    inline at::Tensor linalg_tensorinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind=2) {
+        return at::_ops::linalg_tensorinv::redispatch(dispatchKeySet, self, ind);
+    }
+    
+    // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_tensorinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t ind=2) {
+        return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out);
+    }
+    
+    // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_tensorinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind, at::Tensor & out) {
+        return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out);
+    }
+    
+    // aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor
+    inline at::Tensor linalg_tensorsolve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=c10::nullopt) {
+        return at::_ops::linalg_tensorsolve::redispatch(dispatchKeySet, self, other, dims);
+    }
+    
+    // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_tensorsolve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=c10::nullopt) {
+        return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out);
+    }
+    
+    // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_tensorsolve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims, at::Tensor & out) {
+        return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out);
+    }
+    
+    // aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
+    inline ::std::tuple<at::Tensor,at::Tensor> linalg_qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode="reduced") {
+        return at::_ops::linalg_qr::redispatch(dispatchKeySet, A, mode);
+    }
+    
+    // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & A, c10::string_view mode="reduced") {
+        return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R);
+    }
+    
+    // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+    inline ::std::tuple<at::Tensor &,at::Tensor &> linalg_qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R) {
+        return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R);
+    }
+    
+    // aten::linalg_matrix_power(Tensor self, int n) -> Tensor
+    inline at::Tensor linalg_matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) {
+        return at::_ops::linalg_matrix_power::redispatch(dispatchKeySet, self, n);
+    }
+    
+    // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) {
+        return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out);
+    }
+    
+    // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) {
+        return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_tensor::redispatch(dispatchKeySet, input, atol, rtol, hermitian);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const c10::optional<at::Tensor> & atol={}, const c10::optional<at::Tensor> & rtol={}, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & atol, const c10::optional<at::Tensor> & rtol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<double> atol, c10::optional<double> rtol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank::redispatch(dispatchKeySet, self, tol, hermitian);
+    }
+    
+    // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double tol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor
+    inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian);
+    }
+    
+    // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) {
+        return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out);
+    }
+    
+    // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian, at::Tensor & out) {
+        return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out);
+    }
+    
+    // aten::linalg_multi_dot(Tensor[] tensors) -> Tensor
+    inline at::Tensor linalg_multi_dot(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::linalg_multi_dot::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_multi_dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_multi_dot_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
+    inline at::Tensor nested_to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=c10::nullopt) {
+        return at::_ops::nested_to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size);
+    }
+    
+    // aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
+    inline at::Tensor _test_serialization_subcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_test_serialization_subcmul::redispatch(dispatchKeySet, self, other, alpha);
+    }
+    
+    // aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor
+    inline at::Tensor _test_parallel_materialize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_parallel, bool skip_first=false) {
+        return at::_ops::_test_parallel_materialize::redispatch(dispatchKeySet, self, num_parallel, skip_first);
+    }
+    
+    // aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor
+    inline at::Tensor _test_optional_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) {
+        return at::_ops::_test_optional_intlist::redispatch(dispatchKeySet, values, addends);
+    }
+    
+    // aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor
+    inline at::Tensor _test_optional_filled_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) {
+        return at::_ops::_test_optional_filled_intlist::redispatch(dispatchKeySet, values, addends);
+    }
+    
+    // aten::_test_optional_floatlist(Tensor values, float[]? addends) -> Tensor
+    inline at::Tensor _test_optional_floatlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, c10::optional<at::ArrayRef<double>> addends) {
+        return at::_ops::_test_optional_floatlist::redispatch(dispatchKeySet, values, addends);
+    }
+    
+    // aten::_test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
+    inline at::Tensor _test_string_default(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, c10::string_view a="\"'\\", c10::string_view b="\"'\\") {
+        return at::_ops::_test_string_default::redispatch(dispatchKeySet, dummy, a, b);
+    }
+    
+    // aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor
+    inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a=1, int64_t b=1) {
+        return at::_ops::_test_ambiguous_defaults_a::redispatch(dispatchKeySet, dummy, a, b);
+    }
+    
+    // aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
+    inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a, c10::string_view b) {
+        return at::_ops::_test_ambiguous_defaults_b::redispatch(dispatchKeySet, dummy, a, b);
+    }
+    
+    // aten::_test_warn_in_autograd(Tensor self) -> Tensor
+    inline at::Tensor _test_warn_in_autograd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_test_warn_in_autograd::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
+    inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_test_autograd_multiple_dispatch_fullcoverage::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
+    inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool b) {
+        return at::_ops::_test_autograd_multiple_dispatch_ntonly::redispatch(dispatchKeySet, self, b);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)
+    inline at::Tensor _test_autograd_multiple_dispatch_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_test_autograd_multiple_dispatch_view::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor
+    inline at::Tensor _test_autograd_multiple_dispatch_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_test_autograd_multiple_dispatch_view_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+    inline at::Tensor segment_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths={}, const c10::optional<at::Tensor> & indices={}, const c10::optional<at::Tensor> & offsets={}, int64_t axis=0, bool unsafe=false, const c10::optional<at::Scalar> & initial=c10::nullopt) {
+        return at::_ops::segment_reduce::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial);
+    }
+    
+    // aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
+    inline at::Tensor _segment_reduce_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths={}, const c10::optional<at::Tensor> & offsets={}, int64_t axis=0, const c10::optional<at::Scalar> & initial=c10::nullopt) {
+        return at::_ops::_segment_reduce_backward::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial);
+    }
+    
+    // aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
+    inline at::Tensor pad_sequence(c10::DispatchKeySet dispatchKeySet, at::TensorList sequences, bool batch_first=false, double padding_value=0.0) {
+        return at::_ops::pad_sequence::redispatch(dispatchKeySet, sequences, batch_first, padding_value);
+    }
+    
+    // aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor
+    inline at::Tensor flatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) {
+        return at::_ops::flatten_dense_tensors::redispatch(dispatchKeySet, tensors);
+    }
+    
+    // aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]
+    inline ::std::vector<at::Tensor> unflatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & flat, at::TensorList tensors) {
+        return at::_ops::unflatten_dense_tensors::redispatch(dispatchKeySet, flat, tensors);
+    }
+    
+    // aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    inline at::Tensor _nested_tensor_from_tensor_list(c10::DispatchKeySet dispatchKeySet, at::TensorList list, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt) {
+        return at::_ops::_nested_tensor_from_tensor_list::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory);
+    }
+    
+    // aten::_fw_primal_copy(Tensor self, int level) -> Tensor
+    inline at::Tensor _fw_primal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) {
+        return at::_ops::_fw_primal_copy::redispatch(dispatchKeySet, self, level);
+    }
+    
+    // aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor
+    inline at::Tensor _make_dual_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) {
+        return at::_ops::_make_dual_copy::redispatch(dispatchKeySet, primal, tangent, level);
+    }
+    
+    // aten::view_as_real_copy(Tensor self) -> Tensor
+    inline at::Tensor view_as_real_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::view_as_real_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::view_as_complex_copy(Tensor self) -> Tensor
+    inline at::Tensor view_as_complex_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::view_as_complex_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_conj_copy(Tensor self) -> Tensor
+    inline at::Tensor _conj_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_conj_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_neg_view_copy(Tensor self) -> Tensor
+    inline at::Tensor _neg_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_neg_view_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+    inline at::Tensor as_strided_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+    }
+    
+    // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+    inline at::Tensor as_strided_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, size, stride, storage_offset);
+    }
+    
+    // aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor
+    inline at::Tensor _sparse_broadcast_to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_sparse_broadcast_to_copy::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor
+    inline at::Tensor diagonal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) {
+        return at::_ops::diagonal_copy::redispatch(dispatchKeySet, self, offset, dim1, dim2);
+    }
+    
+    // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
+    inline at::Tensor expand_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) {
+        return at::_ops::expand_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit);
+    }
+    
+    // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
+    inline at::Tensor expand_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) {
+        return at::_ops::expand_copy::redispatch(dispatchKeySet, self, size, implicit);
+    }
+    
+    // aten::permute_copy(Tensor self, int[] dims) -> Tensor
+    inline at::Tensor permute_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::permute_copy::redispatch(dispatchKeySet, self, dims);
+    }
+    
+    // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor
+    inline at::Tensor _reshape_alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) {
+        return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor
+    inline at::Tensor _reshape_alias_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
+        return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, size, stride);
+    }
+    
+    // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) {
+        return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
+    inline at::Tensor select_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index);
+    }
+    
+    // aten::detach_copy(Tensor self) -> Tensor
+    inline at::Tensor detach_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::detach_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+    inline at::Tensor slice_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+    }
+    
+    // aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+    inline at::Tensor slice_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step);
+    }
+    
+    // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> split_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) {
+        return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> split_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) {
+        return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim);
+    }
+    
+    // aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> split_with_sizes_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim);
+    }
+    
+    // aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> split_with_sizes_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, split_sizes, dim);
+    }
+    
+    // aten::squeeze_copy(Tensor self) -> Tensor
+    inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::squeeze_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor
+    inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::squeeze_copy_dim::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+    inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::squeeze_copy_dims::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::t_copy(Tensor self) -> Tensor
+    inline at::Tensor t_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::t_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
+    inline at::Tensor transpose_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::transpose_copy_int::redispatch(dispatchKeySet, self, dim0, dim1);
+    }
+    
+    // aten::unsqueeze_copy(Tensor self, int dim) -> Tensor
+    inline at::Tensor unsqueeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) {
+        return at::_ops::unsqueeze_copy::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::_indices_copy(Tensor self) -> Tensor
+    inline at::Tensor _indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_values_copy(Tensor self) -> Tensor
+    inline at::Tensor _values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::_values_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::indices_copy(Tensor self) -> Tensor
+    inline at::Tensor indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::values_copy(Tensor self) -> Tensor
+    inline at::Tensor values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::values_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::crow_indices_copy(Tensor self) -> Tensor
+    inline at::Tensor crow_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::crow_indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::col_indices_copy(Tensor self) -> Tensor
+    inline at::Tensor col_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::col_indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::ccol_indices_copy(Tensor self) -> Tensor
+    inline at::Tensor ccol_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::ccol_indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::row_indices_copy(Tensor self) -> Tensor
+    inline at::Tensor row_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::row_indices_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]
+    inline ::std::vector<at::Tensor> unbind_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) {
+        return at::_ops::unbind_copy_int::redispatch(dispatchKeySet, self, dim);
+    }
+    
+    // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unbind_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t dim=0) {
+        return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unbind_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::TensorList out) {
+        return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) {
+        return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) {
+        return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) {
+        return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) {
+        return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_with_sizes_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out);
+    }
+    
+    // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_with_sizes_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
+        return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out);
+    }
+    
+    // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_with_sizes_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out);
+    }
+    
+    // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void split_with_sizes_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) {
+        return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out);
+    }
+    
+    // aten::view_copy(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::view_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size));
+    }
+    
+    // aten::view_copy(Tensor self, SymInt[] size) -> Tensor
+    inline at::Tensor view_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::view_copy::redispatch(dispatchKeySet, self, size);
+    }
+    
+    // aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor
+    inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) {
+        return at::_ops::view_copy_dtype::redispatch(dispatchKeySet, self, dtype);
+    }
+    
+    // aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor
+    inline at::Tensor unfold_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) {
+        return at::_ops::unfold_copy::redispatch(dispatchKeySet, self, dimension, size, step);
+    }
+    
+    // aten::alias_copy(Tensor self) -> Tensor
+    inline at::Tensor alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::alias_copy::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+    inline at::Tensor to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=c10::nullopt) {
+        return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt);
+    }
+    
+    // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+    inline at::Tensor to_padded_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=c10::nullopt) {
+        return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size);
+    }
+    
+    // aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor
+    inline at::Tensor _nested_tensor_softmax_with_shape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & query) {
+        return at::_ops::_nested_tensor_softmax_with_shape::redispatch(dispatchKeySet, self, query);
+    }
+    
+    // aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
+    inline at::Tensor _transformer_encoder_layer_fwd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const c10::optional<at::Tensor> & mask={}, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_transformer_encoder_layer_fwd::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type);
+    }
+    
+    // aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _native_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask={}, bool need_weights=true, bool average_attn_weights=true, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_native_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type);
+    }
+    
+    // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
+    inline at::Tensor scaled_dot_product_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & attn_mask={}, double dropout_p=0.0, bool is_causal=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::scaled_dot_product_attention::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale);
+    }
+    
+    // aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
+    inline int64_t _fused_sdp_choice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & attn_mask={}, double dropout_p=0.0, bool is_causal=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_fused_sdp_choice::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale);
+    }
+    
+    // aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor> _scaled_dot_product_attention_math(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & attn_mask={}, double dropout_p=0.0, bool is_causal=false, const c10::optional<at::Tensor> & dropout_mask={}, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_attention_math::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale);
+    }
+    
+    // aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_flash_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_flash_attention::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+    }
+    
+    // aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+    inline ::std::tuple<at::Tensor,at::Tensor> _scaled_dot_product_flash_attention_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, const c10::optional<at::Tensor> & attn_mask={}, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_flash_attention_for_cpu::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, attn_mask, scale);
+    }
+    
+    // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale);
+    }
+    
+    // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale);
+    }
+    
+    // aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_flash_attention_for_cpu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const c10::optional<at::Tensor> & attn_mask={}, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_flash_attention_for_cpu_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, attn_mask, scale);
+    }
+    
+    // aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_efficient_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & attn_bias, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_efficient_attention::redispatch(dispatchKeySet, query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale);
+    }
+    
+    // aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array<bool,4> grad_input_mask, bool is_causal=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, attn_bias, out, logsumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale);
+    }
+    
+    // aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _scaled_dot_product_cudnn_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_scaled_dot_product_cudnn_attention::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, return_debug_mask, scale);
+    }
+    
+    // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
+    }
+    
+    // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _flash_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & cum_seq_q, const c10::optional<at::Tensor> & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale);
+    }
+    
+    // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale);
+    }
+    
+    // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, c10::optional<double> scale=c10::nullopt) {
+        return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale);
+    }
+    
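+    // Illustrative note (not part of the generated schema): the *_symint overloads above take
+    // c10::SymInt sizes and redispatch to the same at::_ops entry as their int64_t counterparts;
+    // concrete integers convert implicitly to c10::SymInt. A minimal hedged sketch, assuming a
+    // DispatchKeySet `ks` and already-built attention tensors from a prior forward call:
+    //
+    //   auto grads = _flash_attention_backward_symint(
+    //       ks, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k,
+    //       /*max_q=*/c10::SymInt(128), /*max_k=*/c10::SymInt(128),
+    //       /*dropout_p=*/0.0, /*is_causal=*/false, philox_seed, philox_offset);
+    //
+    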
+    // aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, int? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt> _efficient_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & cu_seqlens_q, const c10::optional<at::Tensor> & cu_seqlens_k, c10::optional<int64_t> max_seqlen_q, c10::optional<int64_t> max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, c10::optional<double> scale=c10::nullopt, const c10::optional<at::Tensor> & causal_diagonal={}, const c10::optional<at::Tensor> & seqlen_k={}) {
+        return at::_ops::_efficient_attention_forward::redispatch(dispatchKeySet, query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, custom_mask_type, compute_log_sumexp, scale, causal_diagonal, seqlen_k);
+    }
+    
+    // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & bias, const at::Tensor & out, const c10::optional<at::Tensor> & cu_seqlens_q, const c10::optional<at::Tensor> & cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, c10::optional<double> scale=c10::nullopt, c10::optional<int64_t> num_splits_key=c10::nullopt) {
+        return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key);
+    }
+    
+    // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _efficient_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const c10::optional<at::Tensor> & bias, const at::Tensor & out, const c10::optional<at::Tensor> & cu_seqlens_q, const c10::optional<at::Tensor> & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, c10::optional<double> scale=c10::nullopt, c10::optional<int64_t> num_splits_key=c10::nullopt) {
+        return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key);
+    }
+    
+    // aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
+    inline at::Tensor _triton_scaled_dot_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) {
+        return at::_ops::_triton_scaled_dot_attention::redispatch(dispatchKeySet, q, k, v, dropout_p);
+    }
+    
+    // aten::_fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+    inline at::Tensor & _fill_mem_eff_dropout_mask_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double dropout_p, int64_t seed, int64_t offset) {
+        return at::_ops::_fill_mem_eff_dropout_mask_::redispatch(dispatchKeySet, self, dropout_p, seed, offset);
+    }
+    
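+    // Illustrative note (not part of the generated schema): wrappers whose aten name ends in `_`
+    // (such as _fill_mem_eff_dropout_mask_ above) mutate their first tensor argument in place and
+    // return it by reference, mirroring the Tensor(a!) annotation in the schema comment. A hedged
+    // usage sketch, assuming a DispatchKeySet `ks` and a preallocated mask tensor `self`:
+    //
+    //   at::Tensor & same = _fill_mem_eff_dropout_mask_(ks, self, /*dropout_p=*/0.1,
+    //                                                   /*seed=*/0, /*offset=*/0);
+    //   // `same` aliases `self`; no new storage is allocated.
+    //
+    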
+    // aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor
+    inline at::Tensor _triton_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask={}) {
+        return at::_ops::_triton_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask);
+    }
+    
+    // aten::special_airy_ai(Tensor x) -> Tensor
+    inline at::Tensor special_airy_ai(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) {
+        return at::_ops::special_airy_ai::redispatch(dispatchKeySet, x);
+    }
+    
+    // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_airy_ai_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) {
+        return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_airy_ai_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) {
+        return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out);
+    }
+    
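+    // Illustrative note (not part of the generated schema): each op above is emitted in three
+    // shapes: a functional wrapper that allocates its result, an `_out` wrapper that takes the
+    // destination first, and an `_outf` wrapper that keeps the schema order with `out` last. All
+    // three redispatch to the same at::_ops entry. A hedged sketch, assuming a DispatchKeySet `ks`
+    // and tensors `x`, `out`:
+    //
+    //   at::Tensor y = special_airy_ai(ks, x);        // functional form, allocates y
+    //   special_airy_ai_out(ks, out, x);              // out-first variant
+    //   special_airy_ai_outf(ks, x, out);             // schema-order variant, writes into out
+    //
+    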
+    // aten::special_bessel_j0(Tensor self) -> Tensor
+    inline at::Tensor special_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_bessel_j0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_j1(Tensor self) -> Tensor
+    inline at::Tensor special_bessel_j1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_bessel_j1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_j1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_j1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_y0(Tensor self) -> Tensor
+    inline at::Tensor special_bessel_y0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_bessel_y0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_y0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_y0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_y1(Tensor self) -> Tensor
+    inline at::Tensor special_bessel_y1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_bessel_y1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_y1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_bessel_y1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
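+    // Illustrative note (not part of the generated schema): the Chebyshev, Hermite, Laguerre, and
+    // Legendre polynomial families in this file each expose three input combinations: (Tensor x,
+    // Tensor n), (Scalar x, Tensor n) via the `_x_scalar` op, and (Tensor x, Scalar n) via the
+    // `_n_scalar` op, plus matching `_out`/`_outf` forms. A hedged sketch, assuming a
+    // DispatchKeySet `ks` and a tensor `x`:
+    //
+    //   at::Tensor t3 = special_chebyshev_polynomial_t(ks, x, at::Scalar(3));   // n as Scalar
+    //   at::Tensor tx = special_chebyshev_polynomial_t(ks, at::Scalar(0.5), x); // x as Scalar
+    //
+    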
+    // aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_h::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_h_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_hermite_polynomial_h_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_he::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_he_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_hermite_polynomial_he_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_laguerre_polynomial_l::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_laguerre_polynomial_l_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_laguerre_polynomial_l_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_legendre_polynomial_p::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_legendre_polynomial_p_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_legendre_polynomial_p_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_modified_bessel_i0(Tensor self) -> Tensor
+    inline at::Tensor special_modified_bessel_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_i0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_i1(Tensor self) -> Tensor
+    inline at::Tensor special_modified_bessel_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_i1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_k0(Tensor self) -> Tensor
+    inline at::Tensor special_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_k0::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_k1(Tensor self) -> Tensor
+    inline at::Tensor special_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_k1::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor
+    inline at::Tensor special_scaled_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) {
+        return at::_ops::special_scaled_modified_bessel_k0::redispatch(dispatchKeySet, x);
+    }
+    
+    // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_scaled_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) {
+        return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_scaled_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) {
+        return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor
+    inline at::Tensor special_scaled_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) {
+        return at::_ops::special_scaled_modified_bessel_k1::redispatch(dispatchKeySet, x);
+    }
+    
+    // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_scaled_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) {
+        return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_scaled_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) {
+        return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+    inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) {
+        return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out);
+    }
+    
+    // aten::special_spherical_bessel_j0(Tensor x) -> Tensor
+    inline at::Tensor special_spherical_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) {
+        return at::_ops::special_spherical_bessel_j0::redispatch(dispatchKeySet, x);
+    }
+    
+    // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_spherical_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) {
+        return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & special_spherical_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) {
+        return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out);
+    }
+    
+    // aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor
+    inline at::Tensor _foobar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) {
+        return at::_ops::_foobar::redispatch(dispatchKeySet, self, arg1, arg2, arg3);
+    }
+    
+    // aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd_::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+    inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd__tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf);
+    }
+    
+    // aten::_propagate_xla_data(Tensor input, Tensor output) -> ()
+    inline void _propagate_xla_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & output) {
+        return at::_ops::_propagate_xla_data::redispatch(dispatchKeySet, input, output);
+    }
+    
+    // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _new_zeros_with_same_feature_meta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) {
+        return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out);
+    }
+    
+    // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _new_zeros_with_same_feature_meta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims, at::Tensor & out) {
+        return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out);
+    }
+    
+    // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _cudnn_ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) {
+        return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1);
+    }
+    
+    // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _cudnn_ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_rnn_flatten_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) {
+        return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_rnn_flatten_weight_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) {
+        return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_rnn_flatten_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) {
+        return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out);
+    }
+    
+    // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_rnn_flatten_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) {
+        return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out);
+    }
+    
+    // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
+        return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _cudnn_rnn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const c10::optional<at::Tensor> & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
+        return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void _cudnn_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask) {
+        return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void _cudnn_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) {
+        return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void _cudnn_rnn_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask) {
+        return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void _cudnn_rnn_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) {
+        return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_init_dropout_state_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double dropout, bool train, int64_t dropout_seed) {
+        return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out);
+    }
+    
+    // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cudnn_init_dropout_state_outf(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::Tensor & out) {
+        return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out);
+    }
+    
+    // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fused_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1);
+    }
+    
+    // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fused_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1);
+    }
+    
+    // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, double scale) {
+        return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out);
+    }
+    
+    // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out) {
+        return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out);
+    }
+    
+    // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> native_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double p, c10::optional<bool> train) {
+        return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1);
+    }
+    
+    // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> native_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, c10::optional<bool> train, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1);
+    }
+    
+    // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_dropout_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & mask, double scale) {
+        return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out);
+    }
+    
+    // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_dropout_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale, at::Tensor & out) {
+        return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out);
+    }
+    
+    // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out);
+    }
+    
+    // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out) {
+        return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out);
+    }
+    
+    // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & affine_grid_generator_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) {
+        return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out);
+    }
+    
+    // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & affine_grid_generator_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out) {
+        return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out);
+    }
+    
+    // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_functorch_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_functorch_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) {
+        return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) {
+        return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) {
+        return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) {
+        return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) {
+        return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out);
+    }
+    
+    // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point, at::Tensor & out) {
+        return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out);
+    }
+    
+    // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor
+    inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli_Tensor::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=0.5, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_with_logits_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight={}, const c10::optional<at::Tensor> & pos_weight={}, int64_t reduction=at::Reduction::Mean) {
+        return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out);
+    }
+    
+    // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binary_cross_entropy_with_logits_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & pos_weight, int64_t reduction, at::Tensor & out) {
+        return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out);
+    }
+    
+    // aten::bincount.out(Tensor self, Tensor? weights=None, int minlength=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bincount_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Tensor> & weights={}, int64_t minlength=0) {
+        return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out);
+    }
+    
+    // aten::bincount.out(Tensor self, Tensor? weights=None, int minlength=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bincount_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Tensor> & weights, int64_t minlength, at::Tensor & out) {
+        return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out);
+    }
+    
+    // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) {
+        return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) {
+        return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) {
+        return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) {
+        return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & block_diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) {
+        return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & block_diag_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) {
+        return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & constant_pad_nd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) {
+        return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out);
+    }
+    
+    // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & constant_pad_nd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out);
+    }
+    
+    // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & constant_pad_nd_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) {
+        return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out);
+    }
+    
+    // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & constant_pad_nd_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out);
+    }
+    
+    // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
+        return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out);
+    }
+    
+    // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) {
+        return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out);
+    }
+    
+    // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
+        return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out);
+    }
+    
+    // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out);
+    }
+    
+    // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : c10::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : c10::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
+        return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out);
+    }
+    
+    // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) {
+        return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out);
+    }
+    
+    // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) {
+        return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out);
+    }
+    
+    // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & convolution_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out);
+    }
+    
+    // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> convolution_backward_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2);
+    }
+    
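+    // The wrappers in this section follow a common pattern: `*_out` overloads take the output
+    // tensor(s) first and keep the schema defaults, `*_outf` overloads take the output
+    // tensor(s) last with every argument spelled out, and, where the schema is declared with
+    // SymInt arguments, additional `*_symint_*` overloads accept c10::SymInt /
+    // c10::SymIntArrayRef directly while the plain overloads convert their at::IntArrayRef
+    // arguments via c10::fromIntArrayRefSlow before redispatching to the same _ops entry.
+    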
+    // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
+        return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out);
+    }
+    
+    // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out);
+    }
+    
+    // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
+        return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out);
+    }
+    
+    // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out);
+    }
+    
+    // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_tbc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) {
+        return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out);
+    }
+    
+    // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_tbc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad, at::Tensor & out) {
+        return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out);
+    }
+    
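+    // Illustrative call comparison for the conv_tbc pair above (ks, out, self, weight and
+    // bias are hypothetical caller-provided values):
+    //   conv_tbc_out(ks, out, self, weight, bias);             // schema default pad=0 applies
+    //   conv_tbc_outf(ks, self, weight, bias, /*pad=*/0, out); // every argument is explicit
+    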
+    // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) {
+        return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _copy_from_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) {
+        return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out);
+    }
+    
+    // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _copy_from_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out) {
+        return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out);
+    }
+    
+    // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _copy_from_and_resize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst) {
+        return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out);
+    }
+    
+    // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _copy_from_and_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, at::Tensor & out) {
+        return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out);
+    }
+    
+    // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) {
+        return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dim, at::Tensor & out) {
+        return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
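+    // Note that the two count_nonzero_out/count_nonzero_outf pairs above are C++ overloads of
+    // the same names: the at::IntArrayRef form redispatches to count_nonzero_dim_IntList_out,
+    // while the optional single-dim form redispatches to count_nonzero_out.
+    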
+    // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) {
+        return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out);
+    }
+    
+    // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) {
+        return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out);
+    }
+    
+    // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_affine_grid_generator_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) {
+        return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out);
+    }
+    
+    // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_affine_grid_generator_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) {
+        return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out);
+    }
+    
+    // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> cudnn_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon) {
+        return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3);
+    }
+    
+    // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> cudnn_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) {
+        return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3);
+    }
+    
+    // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> cudnn_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon, const at::Tensor & reserveSpace) {
+        return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2);
+    }
+    
+    // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> cudnn_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon, const at::Tensor & reserveSpace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2);
+    }
+    
+    // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) {
+        return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out);
+    }
+    
+    // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out);
+    }
+    
+    // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out);
+    }
+    
+    // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mps_convolution_transpose_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,2> output_mask) {
+        return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1);
+    }
+    
+    // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mps_convolution_transpose_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,2> output_mask, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1);
+    }
+    
+    // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mps_convolution_transpose_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,2> output_mask) {
+        return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1);
+    }
+    
+    // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mps_convolution_transpose_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,2> output_mask, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1);
+    }
+    
+    // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out);
+    }
+    
+    // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out);
+    }
+    
+    // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_add_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out);
+    }
+    
+    // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_convolution_add_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const c10::optional<at::Scalar> & alpha, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out);
+    }
+    
+    // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_grid_sampler_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & grid) {
+        return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out);
+    }
+    
+    // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cudnn_grid_sampler_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, at::Tensor & out) {
+        return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out);
+    }
+    
+    // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cudnn_grid_sampler_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) {
+        return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1);
+    }
+    
+    // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> cudnn_grid_sampler_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1);
+    }
+    
+    // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1);
+    }
+    
+    // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1);
+    }
+    
+    // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1);
+    }
+    
+    // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1);
+    }
+    
+    // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _ctc_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) {
+        return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out);
+    }
+    
+    // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _ctc_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity, at::Tensor & out) {
+        return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out);
+    }
+    
+    // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diag_embed_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) {
+        return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out);
+    }
+    
+    // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diag_embed_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) {
+        return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+        return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) {
+        return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+        return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) {
+        return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out);
+    }
+    
+    // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) {
+        return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, c10::optional<c10::string_view> rounding_mode, at::Tensor & out) {
+        return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out);
+    }
+    
+    // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) {
+        return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out);
+    }
+    
+    // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) {
+        return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out);
+    }
+    
+    // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) {
+        return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out);
+    }
+    
+    // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) {
+        return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out);
+    }
+    
+    // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) {
+        return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out);
+    }
+    
+    // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, at::Tensor & out) {
+        return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out);
+    }
+    
+    // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) {
+        return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out);
+    }
+    
+    // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out) {
+        return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out);
+    }
+    
+    // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) {
+        return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out);
+    }
+    
+    // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & embedding_renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type, at::Tensor & out) {
+        return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out);
+    }
+    
+    // aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor
+    inline at::Tensor embedding_renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) {
+        return at::_ops::embedding_renorm::redispatch(dispatchKeySet, self, indices, max_norm, norm_type);
+    }
+    
+    // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _embedding_bag_forward_only_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const c10::optional<at::Tensor> & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3);
+    }
+    
+    // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _embedding_bag_forward_only_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) {
+        return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3);
+    }
+    
+    // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _embedding_bag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const c10::optional<at::Tensor> & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3);
+    }
+    
+    // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _embedding_bag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) {
+        return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3);
+    }
+    
+    // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out);
+    }
+    
+    // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx, at::Tensor & out) {
+        return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out);
+    }
+    
+    // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out);
+    }
+    
+    // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx, at::Tensor & out) {
+        return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out);
+    }
+    
+    // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_per_sample_weights_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) {
+        return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out);
+    }
+    
+    // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _embedding_bag_per_sample_weights_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out) {
+        return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out);
+    }
+    
+    // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out);
+    }
+    
+    // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out);
+    }
+    
+    // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_permuted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef physical_layout) {
+        return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out);
+    }
+    
+    // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_permuted_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) {
+        return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out);
+    }
+    
+    // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_permuted_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, at::IntArrayRef physical_layout) {
+        return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out);
+    }
+    
+    // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_permuted_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) {
+        return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out);
+    }
+    
+    // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) {
+        return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_strided_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
+        return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out);
+    }
+    
+    // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) {
+        return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out);
+    }
+    
+    // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value) {
+        return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out);
+    }
+    
+    // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_full_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) {
+        return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out);
+    }
+    
+    // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value) {
+        return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out);
+    }
+    
+    // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_full_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) {
+        return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out);
+    }
+    
+    // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_zeros_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_ones_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & new_ones_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, double scale=1, int64_t zero_point=0, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out);
+    }
+    
+    // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, double scale, int64_t zero_point, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out);
+    }
+    
+    // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, double scale=1, int64_t zero_point=0, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out);
+    }
+    
+    // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, double scale, int64_t zero_point, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_per_channel_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_per_channel_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format=MemoryFormat::Contiguous) {
+        return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out);
+    }
+    
+    // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out);
+    }
+    
+    // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out);
+    }
+    
+    // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format, const at::Tensor & out) {
+        return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out);
+    }
+    
+    // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out);
+    }
+    
+    // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format, const at::Tensor & out) {
+        return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out);
+    }
+    
+    // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format);
+    }
+    
+    // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor resize_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize::redispatch(dispatchKeySet, self, size, memory_format);
+    }
+    
+    // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _resize_output_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out);
+    }
+    
+    // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _resize_output_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device, const at::Tensor & out) {
+        return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out);
+    }
+    
+    // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _resize_output_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out);
+    }
+    
+    // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & _resize_output_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device, const at::Tensor & out) {
+        return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out);
+    }
+    
+    // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor
+    inline at::Tensor _resize_output(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device);
+    }
+    
+    // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor
+    inline at::Tensor _resize_output_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) {
+        return at::_ops::_resize_output::redispatch(dispatchKeySet, self, size, device);
+    }
+    
+    // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & qtensor, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out);
+    }
+    
+    // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out);
+    }
+    
+    // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef stride) {
+        return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_strided_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
+        return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out);
+    }
+    
+    // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) {
+        return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out);
+    }
+    
+    // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & value) {
+        return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out);
+    }
+    
+    // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out);
+    }
+    
+    // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & value) {
+        return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out);
+    }
+    
+    // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value, at::Tensor & out) {
+        return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out);
+    }
+    
+    // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::DimnameList> names) {
+        return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out);
+    }
+    
+    // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out);
+    }
+    
+    // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out);
+    }
+    
+    // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & full_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out);
+    }
+    
+    // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & from_file_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::string_view filename, c10::optional<bool> shared=c10::nullopt, c10::optional<int64_t> size=0) {
+        return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out);
+    }
+    
+    // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & from_file_outf(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, c10::optional<bool> shared, c10::optional<int64_t> size, at::Tensor & out) {
+        return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out);
+    }
+    
+    // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & grid_sampler_2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & grid_sampler_2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) {
+        return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> grid_sampler_2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask) {
+        return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1);
+    }
+    
+    // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> grid_sampler_2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1);
+    }
+    
+    // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _grid_sampler_2d_cpu_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _grid_sampler_2d_cpu_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) {
+        return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & grid_sampler_3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
+        return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & grid_sampler_3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) {
+        return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out);
+    }
+    
+    // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> grid_sampler_3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask) {
+        return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1);
+    }
+    
+    // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> grid_sampler_3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1);
+    }
+    
+    // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) {
+        return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) {
+        return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) {
+        return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) {
+        return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) {
+        return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) {
+        return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) {
+        return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) {
+        return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha) {
+        return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out);
+    }
+    
+    // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::Tensor & out) {
+        return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out);
+    }
+    
+    // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha, double beta) {
+        return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out);
+    }
+    
+    // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::Tensor & out) {
+        return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out);
+    }
+    
+    // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) {
+        return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) {
+        return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out);
+    }
+    
+    // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) {
+        return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) {
+        return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out);
+    }
+    
+    // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double beta) {
+        return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out);
+    }
+    
+    // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::Tensor & out) {
+        return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out);
+    }
+    
+    // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) {
+        return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) {
+        return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_group_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2);
+    }
+    
+    // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false) {
+        return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out);
+    }
+    
+    // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out) {
+        return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out);
+    }
+    
+    // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _index_put_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) {
+        return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out);
+    }
+    
+    // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _index_put_impl_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate, bool unsafe, at::Tensor & out) {
+        return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out);
+    }
+    
+    // aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor
+    inline at::Tensor _index_put_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) {
+        return at::_ops::_index_put_impl::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe);
+    }
+    
+    // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isnan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isnan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps) {
+        return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps) {
+        return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_layer_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2);
+    }
+    
+    // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask) {
+        return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2);
+    }
+    
+    // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias={}) {
+        return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out);
+    }
+    
+    // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::Tensor & out) {
+        return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out);
+    }
+    
+    // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_linear_backward_input_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) {
+        return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out);
+    }
+    
+    // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_linear_backward_input_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight, at::Tensor & out) {
+        return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out);
+    }
+    
+    // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mkldnn_linear_backward_weights_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) {
+        return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1);
+    }
+    
+    // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> mkldnn_linear_backward_weights_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1);
+    }
+    
+    // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask) {
+        return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2);
+    }
+    
+    // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> matmul_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array<bool,2> mask) {
+        return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1);
+    }
+    
+    // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> matmul_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array<bool,2> mask, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1);
+    }
+    
+    // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self) {
+        return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1);
+    }
+    
+    // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1);
+    }
+    
+    // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, int64_t dim, bool keepdim=false) {
+        return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1);
+    }
+    
+    // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1);
+    }
+    
+    // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) {
+        return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantized_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) {
+        return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out);
+    }
+    
+    // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::median_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::median_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out);
+    }
+    
+    // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mps_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out);
+    }
+    
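+    // Note: for schemas declared with SymInt arguments, the IntArrayRef overloads convert their
+    // integer arrays with c10::fromIntArrayRefSlow before redispatching, whereas the `_symint_`
+    // overloads forward the SymInt arguments unchanged.
+    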
+    // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mps_convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mps_convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mps_convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,3> output_mask) {
+        return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> mps_convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2);
+    }
+    
+    // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
+        return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) {
+        return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out);
+    }
+    
+    // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out);
+    }
+    
+    // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_rnn_layer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) {
+        return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3);
+    }
+    
+    // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_rnn_layer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) {
+        return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3);
+    }
+    
+    // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_rnn_layer_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) {
+        return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6);
+    }
+    
+    // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> mkldnn_rnn_layer_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6) {
+        return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6);
+    }
+    
+    // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> miopen_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon) {
+        return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2);
+    }
+    
+    // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> miopen_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2);
+    }
+    
+    // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> miopen_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon) {
+        return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2);
+    }
+    
+    // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> miopen_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_var, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2);
+    }
+    
+    // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_depthwise_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_depthwise_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_depthwise_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) {
+        return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & miopen_depthwise_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) {
+        return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out);
+    }
+    
+    // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> miopen_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state) {
+        return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> miopen_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
+        return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void miopen_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask) {
+        return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()
+    inline void miopen_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const c10::optional<at::Tensor> & cx, const at::Tensor & output, const c10::optional<at::Tensor> & grad_output, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const c10::optional<at::Tensor> & dropout_state, const at::Tensor & reserve, ::std::array<bool,4> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) {
+        return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3);
+    }
+    
+    // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sparse_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sparse_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_functional::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps);
+    }
+    
+    // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_no_training_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) {
+        return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2);
+    }
+    
+    // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _native_batch_norm_legit_no_training_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2);
+    }
+    
+    // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double eps) {
+        return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1);
+    }
+    
+    // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1);
+    }
+    
+    // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_gather_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, int64_t count) {
+        return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1);
+    }
+    
+    // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_gather_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1);
+    }
+    
+    // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_gather_stats_with_counts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, const at::Tensor & counts) {
+        return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1);
+    }
+    
+    // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_gather_stats_with_counts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1);
+    }
+    
+    // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask) {
+        return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2);
+    }
+    
+    // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> native_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2);
+    }
+    
+    // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> batch_norm_backward_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g) {
+        return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3);
+    }
+    
+    // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> batch_norm_backward_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) {
+        return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3);
+    }
+    
+    // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & batch_norm_backward_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) {
+        return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out);
+    }
+    
+    // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & batch_norm_backward_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const c10::optional<at::Tensor> & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out) {
+        return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out);
+    }
+    
+    // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_update_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum) {
+        return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1);
+    }
+    
+    // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> batch_norm_update_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, double momentum, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1);
+    }
+    
+    // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nnpack_spatial_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) {
+        return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nnpack_spatial_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nnpack_spatial_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) {
+        return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out);
+    }
+    
+    // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nnpack_spatial_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out) {
+        return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out);
+    }
+    
+    // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ones_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _euclidean_dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2) {
+        return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out);
+    }
+    
+    // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _euclidean_dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, at::Tensor & out) {
+        return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out);
+    }
+    
+    // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode) {
+        return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out);
+    }
+    
+    // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode, at::Tensor & out) {
+        return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out);
+    }
+    
+    // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) {
+        return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out);
+    }
+    
+    // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist, at::Tensor & out) {
+        return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out);
+    }
+    
+    // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=2) {
+        return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, at::Tensor & out) {
+        return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) {
+        return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out);
+    }
+    
+    // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist, at::Tensor & out) {
+        return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out);
+    }
+    
+    // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t upscale_factor) {
+        return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out);
+    }
+    
+    // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor, at::Tensor & out) {
+        return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out);
+    }
+    
+    // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_unshuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t downscale_factor) {
+        return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out);
+    }
+    
+    // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_unshuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor, at::Tensor & out) {
+        return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out);
+    }
+    
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t groups) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+    
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups, at::Tensor & out) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+    
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt groups) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+    
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+    
+    // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pin_memory_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Device> device=c10::nullopt) {
+        return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out);
+    }
+    
+    // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pin_memory_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Device> device, at::Tensor & out) {
+        return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out);
+    }
+    
+    // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scalar_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & s) {
+        return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out);
+    }
+    
+    // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scalar_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::Tensor & out) {
+        return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out);
+    }
+    
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+    
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+    
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+    
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+    
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+    
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+    
+    // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t high, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+    
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+    
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt high, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+    
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+    
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t low, int64_t high, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+    
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+    
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt low, c10::SymInt high, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+    
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+    
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+    
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+    
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+    
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+    
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+    
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::optional<at::Generator> generator, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+    
+    // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef repeats) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out);
+    }
+    
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats, at::Tensor & out) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out);
+    }
+    
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef repeats) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out);
+    }
+    
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats, at::Tensor & out) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out);
+    }
+    
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt, out);
+    }
+    
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, c10::optional<int64_t> output_size, at::Tensor & out) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt, out);
+    }
+    
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, c10::optional<c10::SymInt> output_size=c10::nullopt) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out);
+    }
+    
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, c10::optional<c10::SymInt> output_size, at::Tensor & out) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out);
+    }
+    
+    // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_reshape_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shape) {
+        return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out);
+    }
+    
+    // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_reshape_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape, at::Tensor & out) {
+        return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out);
+    }
+    
+    // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::relu_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::relu_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out);
+    }
+    
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index, at::Tensor & out) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out);
+    }
+    
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out);
+    }
+    
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, at::Tensor & out) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out);
+    }
+    
+    // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & celu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1.0) {
+        return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out);
+    }
+    
+    // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & celu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out);
+    }
+    
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out);
+    }
+    
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step, at::Tensor & out) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out);
+    }
+    
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out);
+    }
+    
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, at::Tensor & out) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out);
+    }
+    
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step, out);
+    }
+    
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step, at::Tensor & out) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step, out);
+    }
+    
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out);
+    }
+    
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step, at::Tensor & out) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out);
+    }
+    
+    // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) {
+        return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out);
+    }
+    
+    // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index, at::Tensor & out) {
+        return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out);
+    }
+    
+    // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out);
+    }
+    
+    // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index, at::Tensor & out) {
+        return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out);
+    }
+    
+    // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) {
+        return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) {
+        return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out);
+    }
+    
+    // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt, out);
+    }
+    
+    // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset, at::Tensor & out) {
+        return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt, out);
+    }
+    
+    // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out);
+    }
+    
+    // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset, at::Tensor & out) {
+        return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out);
+    }
+    
+    // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) {
+        return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) {
+        return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) {
+        return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) {
+        return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out);
+    }
+    
+    // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_with_sizes_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out);
+    }
+    
+    // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_with_sizes_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
+        return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out);
+    }
+    
+    // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_with_sizes_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) {
+        return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out);
+    }
+    
+    // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+    inline void unsafe_split_with_sizes_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) {
+        return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out);
+    }
+    
+    // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> std_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1);
+    }
+    
+    // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> std_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1);
+    }
+    
+    // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out);
+    }
+    
+    // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) {
+        return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out);
+    }
+    
+    // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & flip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out);
+    }
+    
+    // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & flip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) {
+        return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out);
+    }
+    
+    // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & roll_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) {
+        return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out);
+    }
+    
+    // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & roll_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) {
+        return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out);
+    }
+    
+    // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & roll_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) {
+        return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out);
+    }
+    
+    // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & roll_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) {
+        return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out);
+    }
+    
+    // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rot90_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) {
+        return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out);
+    }
+    
+    // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rot90_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::IntArrayRef dims, at::Tensor & out) {
+        return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out);
+    }
+    
+    // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) {
+        return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2);
+    }
+    
+    // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _transform_bias_rescale_qkv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2);
+    }
+    
+    // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_from_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) {
+        return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out);
+    }
+    
+    // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_from_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out) {
+        return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out);
+    }
+    
+    // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_from_padded_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) {
+        return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out);
+    }
+    
+    // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_from_padded_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out) {
+        return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out);
+    }
+    
+    // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_size_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_size_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_strides_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_strides_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_storage_offsets_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_storage_offsets_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_from_padded_and_nested_example_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & nt_example) {
+        return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out);
+    }
+    
+    // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_from_padded_and_nested_example_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out) {
+        return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out);
+    }
+    
+    // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_view_from_buffer_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) {
+        return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out);
+    }
+    
+    // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_view_from_buffer_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets, at::Tensor & out) {
+        return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out);
+    }
+    
+    // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_view_from_jagged_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const c10::optional<at::Tensor> & lengths={}, int64_t ragged_idx=1) {
+        return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, out);
+    }
+    
+    // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_view_from_jagged_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const c10::optional<at::Tensor> & lengths, int64_t ragged_idx, at::Tensor & out) {
+        return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, out);
+    }
+    
+    // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_get_values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_get_values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _trilinear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) {
+        return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out);
+    }
+    
+    // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _trilinear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim, at::Tensor & out) {
+        return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out);
+    }
+    
+    // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _unique_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, bool sorted=true, bool return_inverse=false) {
+        return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1);
+    }
+    
+    // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _unique_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1);
+    }
+    
+    // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2);
+    }
+    
+    // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse, bool return_counts, c10::optional<int64_t> dim, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2);
+    }
+    
+    // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_dim_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> unique_dim_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _unique2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) {
+        return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _unique2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2);
+    }
+    
+    // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _unsafe_view_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _unsafe_view_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _unsafe_view_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _unsafe_view_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> var_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional<at::Scalar> & correction=c10::nullopt, bool keepdim=false) {
+        return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1);
+    }
+    
+    // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> var_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1);
+    }
+    
+    // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _weight_norm_interface_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) {
+        return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1);
+    }
+    
+    // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _weight_norm_interface_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1);
+    }
+    
+    // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _weight_norm_interface_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) {
+        return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1);
+    }
+    
+    // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _weight_norm_interface_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1);
+    }
+    
+    // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, c10::optional<at::DimnameList> names) {
+        return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, c10::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+    
+    // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _efficientzerotensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _efficientzerotensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _efficientzerotensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) {
+        return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _efficientzerotensor_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zeros_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _standard_gamma_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & output) {
+        return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out);
+    }
+    
+    // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _standard_gamma_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output, at::Tensor & out) {
+        return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out);
+    }
+    
+    // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _standard_gamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _standard_gamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _dirichlet_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) {
+        return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out);
+    }
+    
+    // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _dirichlet_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total, at::Tensor & out) {
+        return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out);
+    }
+    
+    // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sample_dirichlet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sample_dirichlet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & poisson_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & poisson_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & count, const at::Tensor & prob, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out);
+    }
+    
+    // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & binomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out);
+    }
+    
+    // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) {
+        return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) {
+        return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) {
+        return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) {
+        return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sum_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out);
+    }
+    
+    // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_sum_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) {
+        return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out);
+    }
+    
+    // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_csr_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_csr_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_csr_prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_csr_prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out);
+    }
+    
+    // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) {
+        return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) {
+        return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out);
+    }
+    
+    // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out);
+    }
+    
+    // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) {
+        return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) {
+        return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out);
+    }
+    
+    // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) {
+        return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out);
+    }
+    
+    // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out);
+    }
+    
+    // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _spdiags_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout=c10::nullopt) {
+        return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out);
+    }
+    
+    // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _spdiags_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout, at::Tensor & out) {
+        return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out);
+    }
+    
+    // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::ScalarType dtype) {
+        return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out);
+    }
+    
+    // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::optional<at::Scalar> & p, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out);
+    }
+    
+    // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) {
+        return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) {
+        return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out);
+    }
+    
+    // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clone_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & clone_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+    
+    // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_as_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out);
+    }
+    
+    // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_as_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, c10::optional<at::MemoryFormat> memory_format, const at::Tensor & out) {
+        return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out);
+    }
+    
+    // aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor
+    inline at::Tensor resize_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::resize_as::redispatch(dispatchKeySet, self, the_template, memory_format);
+    }
+    
+    // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_as_sparse_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template) {
+        return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out);
+    }
+    
+    // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & resize_as_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, const at::Tensor & out) {
+        return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out);
+    }
+    
+    // aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor
+    inline at::Tensor resize_as_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) {
+        return at::_ops::resize_as_sparse::redispatch(dispatchKeySet, self, the_template);
+    }
+    
+    // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::zero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & zero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::zero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::zero(Tensor self) -> Tensor
+    inline at::Tensor zero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::zero::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) {
+        return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) {
+        return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out);
+    }
+    
+    // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_coo_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) {
+        return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_coo_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size) {
+        return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<bool> is_coalesced, at::Tensor & out) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<bool> is_coalesced=c10::nullopt) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out);
+    }
+    
+    // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, c10::optional<bool> is_coalesced, at::Tensor & out) {
+        return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out);
+    }
+    
+    // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out);
+    }
+    
+    // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) {
+        return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out);
+    }
+    
+    // aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
+    inline at::Tensor sparse_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim);
+    }
+    
+    // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_and_clear_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out);
+    }
+    
+    // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline const at::Tensor & sparse_resize_and_clear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) {
+        return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out);
+    }
+    
+    // aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
+    inline at::Tensor sparse_resize_and_clear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) {
+        return at::_ops::sparse_resize_and_clear::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim);
+    }
+    
+    // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) {
+        return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out);
+    }
+    
+    // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & sparse_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) {
+        return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out);
+    }
+    
+    // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_mask_projection_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) {
+        return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out);
+    }
+    
+    // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_mask_projection_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out) {
+        return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out);
+    }
+    
+    // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_dense_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<bool> masked_grad=c10::nullopt) {
+        return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out);
+    }
+    
+    // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_dense_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<bool> masked_grad, at::Tensor & out) {
+        return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out);
+    }
+    
+    // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _coalesce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _coalesce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _coalesced_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool coalesced) {
+        return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out);
+    }
+    
+    // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _coalesced_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced, at::Tensor & out) {
+        return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out);
+    }
+    
+    // aten::_coalesced(Tensor self, bool coalesced) -> Tensor
+    inline at::Tensor _coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced) {
+        return at::_ops::_coalesced::redispatch(dispatchKeySet, self, coalesced);
+    }
+    
+    // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copy_sparse_to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & copy_sparse_to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) {
+        return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+    inline at::Tensor copy_sparse_to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) {
+        return at::_ops::copy_sparse_to_sparse::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
+    // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t sparse_dim) {
+        return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out);
+    }
+    
+    // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out);
+    }
+    
+    // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Layout> layout=c10::nullopt, at::OptionalIntArrayRef blocksize=c10::nullopt, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out);
+    }
+    
+    // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Layout> layout, at::OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_csc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_csc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<int64_t> dense_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_bsr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_bsr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_bsc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim=c10::nullopt) {
+        return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out);
+    }
+    
+    // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_sparse_bsc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim, at::Tensor & out) {
+        return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out);
+    }
+    
+    // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_mkldnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::ScalarType> dtype=c10::nullopt) {
+        return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_mkldnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::ScalarType> dtype, at::Tensor & out) {
+        return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv2d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=c10::nullopt) {
+        return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*input_size)) : c10::nullopt, out);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv2d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::OptionalIntArrayRef input_size, at::Tensor & out) {
+        return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*input_size)) : c10::nullopt, out);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=c10::nullopt) {
+        return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out);
+    }
+    
+    // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out) {
+        return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv3d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1) {
+        return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv3d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) {
+        return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) {
+        return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, out);
+    }
+    
+    // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, out);
+    }
+    
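+    // Note: operators whose schema uses SymInt arguments get two C++ spellings.
+    // The plain overloads accept at::IntArrayRef / int64_t and convert through
+    // c10::fromIntArrayRefSlow before redispatching; the `_symint_` overloads pass
+    // c10::SymIntArrayRef / c10::SymInt through unchanged, which is what symbolic-
+    // shape tracing paths use. Illustrative only (ks, out, weight, sym_padding are
+    // placeholders, not defined in this header):
+    //
+    //   // concrete integer shapes
+    //   at::redispatch::mkldnn_reorder_conv2d_weight_out(ks, out, weight, /*padding=*/{1, 1});
+    //   // symbolic shapes
+    //   at::redispatch::mkldnn_reorder_conv2d_weight_symint_out(ks, out, weight, sym_padding);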
+    // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_dynamic_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) {
+        return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out);
+    }
+    
+    // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_dynamic_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range, at::Tensor & out) {
+        return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out);
+    }
+    
+    // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out);
+    }
+    
+    // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out);
+    }
+    
+    // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out);
+    }
+    
+    // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out);
+    }
+    
+    // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> ()
+    inline void quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) {
+        return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out);
+    }
+    
+    // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> ()
+    inline void quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out) {
+        return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out);
+    }
+    
+    // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_channel_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) {
+        return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out);
+    }
+    
+    // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & quantize_per_channel_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out);
+    }
+    
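+    // The quantize_* out variants follow the same pattern: the caller supplies a
+    // preallocated `out` tensor and the op is redispatched with the remaining
+    // schema arguments. A hedged sketch (ks, q_out, input are placeholders, and
+    // q_out must already be an appropriately allocated quantized tensor):
+    //
+    //   at::redispatch::quantize_per_tensor_out(ks, q_out, input,
+    //                                           /*scale=*/0.1, /*zero_point=*/0,
+    //                                           at::kQInt8);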
+    // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dequantize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dequantize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> ()
+    inline void dequantize_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors) {
+        return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> ()
+    inline void dequantize_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::TensorList out) {
+        return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out);
+    }
+    
+    // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & q_per_channel_scales_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & q_per_channel_scales_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & q_per_channel_zero_points_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & q_per_channel_zero_points_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & int_repr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & int_repr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_per_tensor_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point) {
+        return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out);
+    }
+    
+    // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_per_tensor_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::Tensor & out) {
+        return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out);
+    }
+    
+    // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_per_channel_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) {
+        return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out);
+    }
+    
+    // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_per_channel_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, at::Tensor & out) {
+        return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out);
+    }
+    
+    // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fake_quantize_per_tensor_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fake_quantize_per_tensor_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out);
+    }
+    
+    // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) {
+        return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out);
+    }
+    
+    // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fake_quantize_per_channel_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) {
+        return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> fake_quantize_per_channel_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1);
+    }
+    
+    // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fake_quantize_learnable_per_channel_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) {
+        return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out);
+    }
+    
+    // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fake_quantize_learnable_per_channel_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) {
+        return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out);
+    }
+    
+    // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fused_moving_avg_obs_fq_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) {
+        return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1);
+    }
+    
+    // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _fused_moving_avg_obs_fq_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1);
+    }
+    
+    // aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)
+    inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) {
+        return at::_ops::_fused_moving_avg_obs_fq_helper_functional::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
+    }
+    
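+    // _fused_moving_avg_obs_fq_helper comes in two flavours: the `.out` overloads
+    // above mutate running_min/running_max/scale/zero_point in place and write into
+    // out0/out1, while the `_functional` overload leaves its inputs untouched and
+    // returns all six results (output, mask, running_min_out, running_max_out,
+    // scale_out, zero_point_out) as fresh tensors. A hedged sketch, with `ks` and
+    // the argument tensors as placeholders:
+    //
+    //   auto [y, mask, rmin, rmax, s, zp] = at::redispatch::_fused_moving_avg_obs_fq_helper_functional(
+    //       ks, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point,
+    //       /*averaging_const=*/0.01, /*quant_min=*/0, /*quant_max=*/255, /*ch_axis=*/0);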
+    // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool non_blocking=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) {
+        return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out);
+    }
+    
+    // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking, c10::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out);
+    }
+    
+    // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _lstm_mps_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5);
+    }
+    
+    // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _lstm_mps_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5) {
+        return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5);
+    }
+    
+    // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()
+    inline void lstm_mps_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::TensorList out1, at::TensorList out2, const c10::optional<at::Tensor> & grad_y, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
+        return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2);
+    }
+    
+    // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()
+    inline void lstm_mps_backward_outf(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_y, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2) {
+        return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2);
+    }
+    
+    // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_lstm_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const c10::optional<at::Tensor> & input_bias={}, const c10::optional<at::Tensor> & hidden_bias={}) {
+        return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2);
+    }
+    
+    // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_lstm_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2);
+    }
+    
+    // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_lstm_cell_backward_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) {
+        return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2);
+    }
+    
+    // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_lstm_cell_backward_impl_outf(c10::DispatchKeySet dispatchKeySet, const c10::optional<at::Tensor> & grad_hy, const c10::optional<at::Tensor> & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2);
+    }
+    
+    // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias={}, const c10::optional<at::Tensor> & hidden_bias={}) {
+        return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1);
+    }
+    
+    // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const c10::optional<at::Tensor> & input_bias, const c10::optional<at::Tensor> & hidden_bias, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1);
+    }
+    
+    // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) {
+        return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _thnn_fused_gru_cell_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) {
+        return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4);
+    }
+    
+    // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _pack_padded_sequence_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) {
+        return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1);
+    }
+    
+    // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _pack_padded_sequence_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1);
+    }
+    
+    // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source) {
+        return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out);
+    }
+    
+    // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, at::Tensor & out) {
+        return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out);
+    }
+    
+    // aten::set.source_Storage(Tensor self, Storage source) -> Tensor
+    inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source) {
+        return at::_ops::set_source_Storage::redispatch(dispatchKeySet, self, source);
+    }
+    
+    // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) {
+        return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) {
+        return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out);
+    }
+    
+    // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) {
+        return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out);
+    }
+    
+    // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor
+    inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) {
+        return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+    }
+    
+    // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor
+    inline at::Tensor set_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) {
+        return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride);
+    }
+    
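+    // The set.source_Storage_storage_offset family rebinds a tensor to an existing
+    // Storage at a given offset/size/stride; as with the SymInt ops above, the
+    // plain overloads take int64_t/IntArrayRef and convert via fromIntArrayRefSlow,
+    // while the `_symint` overloads accept SymInt values directly. Illustrative
+    // call shape only (ks, self, storage are placeholders, not defined here):
+    //
+    //   at::redispatch::set(ks, self, storage, /*storage_offset=*/0,
+    //                       /*size=*/{2, 3}, /*stride=*/{3, 1});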
+    // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & source) {
+        return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out);
+    }
+    
+    // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source, at::Tensor & out) {
+        return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out);
+    }
+    
+    // aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor
+    inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source) {
+        return at::_ops::set_source_Tensor::redispatch(dispatchKeySet, self, source);
+    }
+    
+    // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::set_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::set_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::set(Tensor self) -> Tensor
+    inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
+        return at::_ops::set::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::lift_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::lift_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lift_fresh_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & lift_fresh_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) {
+        return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out);
+    }
+    
+    // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out);
+    }
+    
+    // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) {
+        return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out);
+    }
+    
+    // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value, at::Tensor & out) {
+        return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out);
+    }
+    
+    // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) {
+        return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out);
+    }
+    
+    // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & masked_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source, at::Tensor & out) {
+        return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out);
+    }
+    
+    // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out);
+    }
+    
+    // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, c10::optional<int64_t> dim, c10::optional<int64_t> mask_type, at::Tensor & out) {
+        return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out);
+    }
+    
+    // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_softmax_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt) {
+        return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out);
+    }
+    
+    // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _masked_softmax_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim, at::Tensor & out) {
+        return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out);
+    }
+    
+    // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) {
+        return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out);
+    }
+    
+    // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate, at::Tensor & out) {
+        return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out);
+    }
+    
+    // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) {
+        return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) {
+        return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) {
+        return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value, at::Tensor & out) {
+        return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out);
+    }
+    
+    // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) {
+        return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) {
+        return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) {
+        return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out);
+    }
+    
+    // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out);
+    }
+    
+    // aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor
+    inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_from::redispatch(dispatchKeySet, self, from, to, generator);
+    }
+    
+    // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out);
+    }
+    
+    // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out);
+    }
+    
+    // aten::random.to(Tensor self, int to, *, Generator? generator=None) -> Tensor
+    inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_to::redispatch(dispatchKeySet, self, to, generator);
+    }
+    
+    // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out);
+    }
+    
+    // aten::random(Tensor self, *, Generator? generator=None) -> Tensor
+    inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::random::redispatch(dispatchKeySet, self, generator);
+    }
+    
+    // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & uniform_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double from=0, double to=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out);
+    }
+    
+    // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & uniform_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out);
+    }
+    
+    // aten::uniform(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor
+    inline at::Tensor uniform(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from=0, double to=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::uniform::redispatch(dispatchKeySet, self, from, to, generator);
+    }
+    
+    // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cauchy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double median=0, double sigma=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out);
+    }
+    
+    // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & cauchy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median, double sigma, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out);
+    }
+    
+    // aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor
+    inline at::Tensor cauchy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median=0, double sigma=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::cauchy::redispatch(dispatchKeySet, self, median, sigma, generator);
+    }
+    
+    // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=1, double std=2, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out);
+    }
+    
+    // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & log_normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out);
+    }
+    
+    // aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor
+    inline at::Tensor log_normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=1, double std=2, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::log_normal::redispatch(dispatchKeySet, self, mean, std, generator);
+    }
+    
+    // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exponential_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double lambd=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out);
+    }
+    
+    // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & exponential_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out);
+    }
+    
+    // aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor
+    inline at::Tensor exponential(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::exponential::redispatch(dispatchKeySet, self, lambd, generator);
+    }
+    
+    // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & geometric_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & geometric_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out);
+    }
+    
+    // aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor
+    inline at::Tensor geometric(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::geometric::redispatch(dispatchKeySet, self, p, generator);
+    }
+    
+    // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tril_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) {
+        return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out);
+    }
+    
+    // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & tril_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) {
+        return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out);
+    }
+    
+    // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & triu_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) {
+        return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out);
+    }
+    
+    // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & triu_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) {
+        return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out);
+    }
+    
+    // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & trace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::trace_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & trace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::trace_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cholesky_solve_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & A, bool upper) {
+        return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out);
+    }
+    
+    // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _cholesky_solve_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, at::Tensor & out) {
+        return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out);
+    }
+    
+    // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) {
+        return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out);
+    }
+    
+    // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p, at::Tensor & out) {
+        return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out);
+    }
+    
+    // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> ()
+    inline void _histogramdd_bin_edges_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out);
+    }
+    
+    // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> ()
+    inline void _histogramdd_bin_edges_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range, const c10::optional<at::Tensor> & weight, bool density, at::TensorList out) {
+        return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out);
+    }
+    
+    // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _histogramdd_from_bin_cts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out);
+    }
+    
+    // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _histogramdd_from_bin_cts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, c10::optional<at::ArrayRef<double>> range, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out) {
+        return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out);
+    }
+    
+    // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _histogramdd_from_bin_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight={}, bool density=false) {
+        return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out);
+    }
+    
+    // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _histogramdd_from_bin_tensors_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const c10::optional<at::Tensor> & weight, bool density, at::Tensor & out) {
+        return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out);
+    }
+    
+    // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) {
+        return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) {
+        return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argsort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) {
+        return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out);
+    }
+    
+    // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & argsort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out) {
+        return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out);
+    }
+    
+    // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) {
+        return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out);
+    }
+    
+    // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) {
+        return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out);
+    }
+    
+    // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) {
+        return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out);
+    }
+    
+    // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) {
+        return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out);
+    }
+    
+    // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=0, double std=1, c10::optional<at::Generator> generator=c10::nullopt) {
+        return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out);
+    }
+    
+    // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, c10::optional<at::Generator> generator, at::Tensor & out) {
+        return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out);
+    }
+    
+    // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()
+    inline void _amp_foreach_non_finite_check_and_unscale_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) {
+        return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out);
+    }
+    
+    // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()
+    inline void _amp_foreach_non_finite_check_and_unscale_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out) {
+        return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out);
+    }
+    
+    // aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,at::Tensor> _amp_foreach_non_finite_check_and_unscale(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale) {
+        return at::_ops::_amp_foreach_non_finite_check_and_unscale::redispatch(dispatchKeySet, self, found_inf, inv_scale);
+    }
+    
+    // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _amp_update_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) {
+        return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out);
+    }
+    
+    // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _amp_update_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out) {
+        return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out);
+    }
+    
+    // aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)
+    inline ::std::tuple<at::Tensor,at::Tensor> _amp_update_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) {
+        return at::_ops::_amp_update_scale::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval);
+    }
+    
+    // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) {
+        return at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out) {
+        return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) {
+        return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) {
+        return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out);
+    }
+    
+    // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) {
+        return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) {
+        return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) {
+        return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) {
+        return at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) {
+        return at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out);
+    }
+    
+    // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) {
+        return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) {
+        return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out);
+    }
+    
+    // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out);
+    }
+    
+    // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) {
+        return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) {
+        return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) {
+        return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) {
+        return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out);
+    }
+    
+    // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars) {
+        return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out) {
+        return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) {
+        return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) {
+        return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out);
+    }
+    
+    // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_abs_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_abs_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_acos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_acos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_asin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_asin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_atan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_atan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_ceil_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_ceil_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_cos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_cos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_cosh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_cosh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_erf_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_erf_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_erfc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_erfc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_exp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_exp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_expm1_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_expm1_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_floor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_floor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_frac_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_frac_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, at::TensorList weights) {
+        return at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out);
+    }
+    
+    // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out) {
+        return at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out);
+    }
+    
+    // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) {
+        return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out);
+    }
+    
+    // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out) {
+        return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out);
+    }
+    
+    // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lgamma_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_lgamma_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log10_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log10_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log1p_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log1p_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log2_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_log2_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_neg_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_neg_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_norm_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & ord=2) {
+        return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, out);
+    }
+    
+    // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_norm_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord, at::TensorList out) {
+        return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, out);
+    }
+    
+    // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList exponent) {
+        return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent, at::TensorList out) {
+        return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & exponent) {
+        return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent, at::TensorList out) {
+        return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef<at::Scalar> exponent) {
+        return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef<at::Scalar> exponent, at::TensorList out) {
+        return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out);
+    }
+    
+    // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_reciprocal_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_round_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_round_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sign_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sign_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sinh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sinh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sqrt_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_sqrt_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_tan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_tan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_tanh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_tanh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_trunc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_trunc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_zero_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) {
+        return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_zero_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) {
+        return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out
+    inline ::std::vector<at::Tensor> _foreach_zero(c10::DispatchKeySet dispatchKeySet, at::TensorList self) {
+        return at::_ops::_foreach_zero::redispatch(dispatchKeySet, self);
+    }
+    
+    // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList src, bool non_blocking=false) {
+        return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> ()
+    inline void _foreach_copy_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out) {
+        return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out);
+    }
+    
+    // aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
+    inline ::std::vector<at::Tensor> _foreach_copy(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) {
+        return at::_ops::_foreach_copy::redispatch(dispatchKeySet, self, src, non_blocking);
+    }
+    
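+    // Note on the wrapper pattern used throughout this header: each out-variant op is
+    // exposed as a pair that differs only in argument order. The `_out` form takes the
+    // destination tensor(s) right after the DispatchKeySet, while the `_outf` form takes
+    // them last, matching the schema shown in the comment above each function; both
+    // forward to the same at::_ops::<op>::redispatch call. A minimal usage sketch
+    // (variable names here are hypothetical, not part of this header):
+    //
+    //   c10::DispatchKeySet ks(c10::DispatchKey::CPU);
+    //   _foreach_copy_out(ks, outs, selfs, srcs);                           // destination list first
+    //   _foreach_copy_outf(ks, selfs, srcs, /*non_blocking=*/false, outs);  // destination list last
+    //
+    // For ops whose schema declares SymInt[] arguments, the `_symint` overloads accept
+    // c10::SymIntArrayRef directly, while the plain overloads convert at::IntArrayRef
+    // via c10::fromIntArrayRefSlow before redispatching.
+    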
+    // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) {
+        return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out);
+    }
+    
+    // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) {
+        return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out);
+    }
+    
+    // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) {
+        return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out);
+    }
+    
+    // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim, at::Tensor & out) {
+        return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out);
+    }
+    
+    // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_backward_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) {
+        return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out);
+    }
+    
+    // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & glu_backward_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim, at::Tensor & out) {
+        return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out);
+    }
+    
+    // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardswish_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & hardswish_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rrelu_with_noise_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) {
+        return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out);
+    }
+    
+    // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rrelu_with_noise_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result, at::Tensor & out) {
+        return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out);
+    }
+    
+    // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) {
+        return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out);
+    }
+    
+    // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) {
+        return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array<bool,3> output_mask) {
+        return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array<bool,3> output_mask) {
+        return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2);
+    }
+    
+    // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array<bool,3> output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) {
+        return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2);
+    }
+    
+    // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_depthwise3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) {
+        return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_depthwise3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_depthwise3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) {
+        return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & conv_depthwise3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) {
+        return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out);
+    }
+    
+    // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) {
+        return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slow_conv_dilated3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) {
+        return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out);
+    }
+    
+    // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & isinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & linalg_matrix_exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) {
+        return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) {
+        return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_filled_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) {
+        return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_filled_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) {
+        return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_floatlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, c10::optional<at::ArrayRef<double>> addends) {
+        return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_optional_floatlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, c10::optional<at::ArrayRef<double>> addends, at::Tensor & out) {
+        return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out);
+    }
+    
+    // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_warn_in_autograd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_warn_in_autograd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_autograd_multiple_dispatch_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_autograd_multiple_dispatch_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & segment_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths={}, const c10::optional<at::Tensor> & indices={}, const c10::optional<at::Tensor> & offsets={}, int64_t axis=0, bool unsafe=false, const c10::optional<at::Scalar> & initial=c10::nullopt) {
+        return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out);
+    }
+    
+    // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & segment_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths, const c10::optional<at::Tensor> & indices, const c10::optional<at::Tensor> & offsets, int64_t axis, bool unsafe, const c10::optional<at::Scalar> & initial, at::Tensor & out) {
+        return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out);
+    }
+    
+    // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _segment_reduce_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths={}, const c10::optional<at::Tensor> & offsets={}, int64_t axis=0, const c10::optional<at::Scalar> & initial=c10::nullopt) {
+        return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out);
+    }
+    
+    // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _segment_reduce_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const c10::optional<at::Tensor> & lengths, const c10::optional<at::Tensor> & offsets, int64_t axis, const c10::optional<at::Scalar> & initial, at::Tensor & out) {
+        return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out);
+    }
+    
+    // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_from_tensor_list_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList list, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt) {
+        return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out);
+    }
+    
+    // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _nested_tensor_from_tensor_list_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList list, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, at::Tensor & out) {
+        return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out);
+    }
+    
+    // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fw_primal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t level) {
+        return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out);
+    }
+    
+    // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _fw_primal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, at::Tensor & out) {
+        return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out);
+    }
+    
+    // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_dual_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) {
+        return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out);
+    }
+    
+    // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _make_dual_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level, at::Tensor & out) {
+        return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out);
+    }
+    
+    // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_as_real_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_as_real_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_as_complex_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_as_complex_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _conj_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _conj_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _neg_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _neg_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt, out);
+    }
+    
+    // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset, at::Tensor & out) {
+        return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt, out);
+    }
+    
+    // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset=c10::nullopt) {
+        return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out);
+    }
+    
+    // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & as_strided_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset, at::Tensor & out) {
+        return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out);
+    }
+    
+    // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_broadcast_to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _sparse_broadcast_to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) {
+        return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out);
+    }
+    
+    // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & diagonal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) {
+        return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out);
+    }
+    
+    // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expand_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) {
+        return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out);
+    }
+    
+    // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expand_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) {
+        return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out);
+    }
+    
+    // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expand_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) {
+        return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out);
+    }
+    
+    // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & expand_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out) {
+        return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out);
+    }
+    
+    // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & permute_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) {
+        return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out);
+    }
+    
+    // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & permute_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) {
+        return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out);
+    }
+    
+    // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _reshape_alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) {
+        return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _reshape_alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) {
+        return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out);
+    }
+    
+    // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _reshape_alias_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
+        return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out);
+    }
+    
+    // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _reshape_alias_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) {
+        return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out);
+    }
+    
+    // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t index) {
+        return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index, at::Tensor & out) {
+        return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) {
+        return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out);
+    }
+    
+    // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & detach_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & detach_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) {
+        return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step, out);
+    }
+    
+    // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step, at::Tensor & out) {
+        return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step, out);
+    }
+    
+    // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out);
+    }
+    
+    // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step, at::Tensor & out) {
+        return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out);
+    }
+    
+    // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) {
+        return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) {
+        return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) {
+        return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) {
+        return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & t_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & t_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & transpose_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) {
+        return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out);
+    }
+    
+    // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & transpose_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) {
+        return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out);
+    }
+    
+    // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unsqueeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) {
+        return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unsqueeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) {
+        return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out);
+    }
+    
+    // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & crow_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & crow_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & col_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ccol_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & ccol_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & row_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & row_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) {
+        return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) {
+        return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out);
+    }
+    
+    // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) {
+        return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) {
+        return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out);
+    }
+    
+    // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype) {
+        return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, at::Tensor & out) {
+        return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out);
+    }
+    
+    // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) {
+        return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out);
+    }
+    
+    // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & unfold_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out) {
+        return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out);
+    }
+    
+    // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out);
+    }
+    
+    // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_padded_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=c10::nullopt) {
+        return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, out);
+    }
+    
+    // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_padded_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt, out);
+    }
+    
+    // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_padded_tensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=c10::nullopt) {
+        return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out);
+    }
+    
+    // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & to_padded_tensor_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size, at::Tensor & out) {
+        return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out);
+    }
+    
+    // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _transformer_encoder_layer_fwd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const c10::optional<at::Tensor> & mask={}, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out);
+    }
+    
+    // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _transformer_encoder_layer_fwd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const c10::optional<at::Tensor> & mask, c10::optional<int64_t> mask_type, at::Tensor & out) {
+        return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out);
+    }
+    
+    // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _native_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask={}, bool need_weights=true, bool average_attn_weights=true, c10::optional<int64_t> mask_type=c10::nullopt) {
+        return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1);
+    }
+    
+    // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))
+    inline ::std::tuple<at::Tensor &,at::Tensor &> _native_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask, bool need_weights, bool average_attn_weights, c10::optional<int64_t> mask_type, at::Tensor & out0, at::Tensor & out1) {
+        return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1);
+    }
+    
+    // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _triton_scaled_dot_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) {
+        return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out);
+    }
+    
+    // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _triton_scaled_dot_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out) {
+        return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out);
+    }
+    
+    // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _triton_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask={}) {
+        return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out);
+    }
+    
+    // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _triton_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const c10::optional<at::Tensor> & mask, at::Tensor & out) {
+        return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out);
+    }
+    
+    // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _foobar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) {
+        return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out);
+    }
+    
+    // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _foobar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1, bool arg2, bool arg3, at::Tensor & out) {
+        return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out);
+    }
+    
+    // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adam_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_adamw_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf);
+    }
+    
+    // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()
+    inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale, const c10::optional<at::Tensor> & found_inf, at::TensorList out) {
+        return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out);
+    }
+    
+    // aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)
+    inline ::std::tuple<::std::vector<at::Tensor>,::std::vector<at::Tensor>,::std::vector<at::Tensor>> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<at::Tensor> & grad_scale={}, const c10::optional<at::Tensor> & found_inf={}) {
+        return at::_ops::_fused_sgd_tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf);
+    }
+} // namespace redispatch
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h b/MLPY/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h
new file mode 100644
index 0000000000000000000000000000000000000000..efd957c9e256baddec2135a1408f57a202aa1242
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h
@@ -0,0 +1,3099 @@
+// This file contains all native_functions that can be registered to
+// and the schema string that they should be registered with
+
+Tensor _cast_Byte(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Char(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Double(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Float(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Int(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Long(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Short(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _cast_Half(const Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"}
+void _backward(const Tensor & self, TensorList inputs, const c10::optional<Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph); // {"schema": "aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", "dispatch": "False", "default": "True"}
+void set_data(Tensor & self, const Tensor & new_data); // {"schema": "aten::set_data(Tensor(a!) self, Tensor new_data) -> ()", "dispatch": "False", "default": "True"}
+Tensor data(const Tensor & self); // {"schema": "aten::data(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+bool is_leaf(const Tensor & self); // {"schema": "aten::is_leaf(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+int64_t output_nr(const Tensor & self); // {"schema": "aten::output_nr(Tensor self) -> int", "dispatch": "False", "default": "True"}
+int64_t _version(const Tensor & self); // {"schema": "aten::_version(Tensor self) -> int", "dispatch": "False", "default": "True"}
+Tensor & requires_grad_(Tensor & self, bool requires_grad); // {"schema": "aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+void retain_grad(Tensor & self); // {"schema": "aten::retain_grad(Tensor(a!) self) -> ()", "dispatch": "False", "default": "True"}
+bool retains_grad(const Tensor & self); // {"schema": "aten::retains_grad(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+Tensor _fw_primal(const Tensor & self, int64_t level); // {"schema": "aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor _make_dual(const Tensor & primal, const Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> _unpack_dual(const Tensor & dual, int64_t level); // {"schema": "aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)", "dispatch": "False", "default": "True"}
+Tensor _new_zeros_with_same_feature_meta(const Tensor & self, const Tensor & other, int64_t self_num_batch_dims); // {"schema": "aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor", "dispatch": "True", "default": "True"}
+bool _has_same_storage_numel(const Tensor & self, const Tensor & other); // {"schema": "aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"}
+Tensor & rename_(Tensor & self, c10::optional<DimnameList> names); // {"schema": "aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor rename(const Tensor & self, c10::optional<DimnameList> names); // {"schema": "aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor align_to(const Tensor & self, DimnameList names); // {"schema": "aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor align_to(const Tensor & self, DimnameList order, int64_t ellipsis_idx); // {"schema": "aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor align_as(const Tensor & self, const Tensor & other); // {"schema": "aten::align_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> align_tensors(TensorList tensors); // {"schema": "aten::align_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+void _assert_async(const Tensor & self); // {"schema": "aten::_assert_async(Tensor self) -> ()", "dispatch": "True", "default": "False"}
+void _assert_async(const Tensor & self, c10::string_view assert_msg); // {"schema": "aten::_assert_async.msg(Tensor self, str assert_msg) -> ()", "dispatch": "True", "default": "False"}
+void _assert_scalar(const Scalar & self, c10::string_view assert_msg); // {"schema": "aten::_assert_scalar(Scalar self, str assert_msg) -> ()", "dispatch": "True", "default": "True"}
+Tensor _functional_assert_scalar(const Scalar & self, c10::string_view assert_msg, const Tensor & dep_token); // {"schema": "aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _functional_assert_async(const Tensor & self, c10::string_view assert_msg, const Tensor & dep_token); // {"schema": "aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "False"}
+void _assert_tensor_metadata(const Tensor & a, OptionalSymIntArrayRef size, OptionalSymIntArrayRef stride, c10::optional<ScalarType> dtype); // {"schema": "aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()", "dispatch": "False", "default": "True"}
+void _print(c10::string_view s); // {"schema": "aten::_print(str s) -> ()", "dispatch": "True", "default": "True"}
+void sym_constrain_range(const Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max); // {"schema": "aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"}
+void sym_constrain_range_for_size(const Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max); // {"schema": "aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"}
+Tensor _functional_sym_constrain_range(const Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max, const Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _functional_sym_constrain_range_for_size(const Scalar & size, c10::optional<int64_t> min, c10::optional<int64_t> max, const Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _make_dep_token(c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor refine_names(const Tensor & self, DimnameList names); // {"schema": "aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"}
+bool _use_cudnn_ctc_loss(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"}
+bool _use_cudnn_ctc_loss(const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _cudnn_ctc_loss(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // {"schema": "aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _cudnn_ctc_loss(const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // {"schema": "aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+bool _use_cudnn_rnn_flatten_weight(); // {"schema": "aten::_use_cudnn_rnn_flatten_weight() -> bool", "dispatch": "False", "default": "True"}
+Tensor _cudnn_rnn_flatten_weight(TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional); // {"schema": "aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _cudnn_rnn(const Tensor & input, TensorList weight, int64_t weight_stride0, const c10::optional<Tensor> & weight_buf, const Tensor & hx, const c10::optional<Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state); // {"schema": "aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,::std::vector<Tensor>> _cudnn_rnn_backward(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & weight_buf, const Tensor & hx, const c10::optional<Tensor> & cx, const Tensor & output, const c10::optional<Tensor> & grad_output, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state, const Tensor & reserve, ::std::array<bool,4> output_mask); // {"schema": "aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"}
+Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"}
+int64_t _debug_has_internal_overlap(const Tensor & self); // {"schema": "aten::_debug_has_internal_overlap(Tensor self) -> int", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _fused_dropout(const Tensor & self, double p, c10::optional<Generator> generator); // {"schema": "aten::_fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _masked_scale(const Tensor & self, const Tensor & mask, double scale); // {"schema": "aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> native_dropout(const Tensor & input, double p, c10::optional<bool> train); // {"schema": "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor native_dropout_backward(const Tensor & grad_output, const Tensor & mask, double scale); // {"schema": "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _sobol_engine_draw(const Tensor & quasi, int64_t n, const Tensor & sobolstate, int64_t dimension, int64_t num_generated, c10::optional<ScalarType> dtype); // {"schema": "aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor & _sobol_engine_ff_(Tensor & self, int64_t n, const Tensor & sobolstate, int64_t dimension, int64_t num_generated); // {"schema": "aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & _sobol_engine_scramble_(Tensor & self, const Tensor & ltm, int64_t dimension); // {"schema": "aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & _sobol_engine_initialize_state_(Tensor & self, int64_t dimension); // {"schema": "aten::_sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor _reshape_from_tensor(const Tensor & self, const Tensor & shape); // {"schema": "aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _shape_as_tensor(const Tensor & self); // {"schema": "aten::_shape_as_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor dropout(const Tensor & input, double p, bool train); // {"schema": "aten::dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & dropout_(Tensor & self, double p, bool train); // {"schema": "aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor feature_dropout(const Tensor & input, double p, bool train); // {"schema": "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & feature_dropout_(Tensor & self, double p, bool train); // {"schema": "aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor alpha_dropout(const Tensor & input, double p, bool train); // {"schema": "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & alpha_dropout_(Tensor & self, double p, bool train); // {"schema": "aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor feature_alpha_dropout(const Tensor & input, double p, bool train); // {"schema": "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & feature_alpha_dropout_(Tensor & self, double p, bool train); // {"schema": "aten::feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor abs(const Tensor & self); // {"schema": "aten::abs(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & abs_(Tensor & self); // {"schema": "aten::abs_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & abs_out(const Tensor & self, Tensor & out); // {"schema": "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor absolute(const Tensor & self); // {"schema": "aten::absolute(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & absolute_(Tensor & self); // {"schema": "aten::absolute_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & absolute_out(const Tensor & self, Tensor & out); // {"schema": "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor angle(const Tensor & self); // {"schema": "aten::angle(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & angle_out(const Tensor & self, Tensor & out); // {"schema": "aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor view_as_real(const Tensor & self); // {"schema": "aten::view_as_real(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor view_as_complex(const Tensor & self); // {"schema": "aten::view_as_complex(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor sgn(const Tensor & self); // {"schema": "aten::sgn(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sgn_(Tensor & self); // {"schema": "aten::sgn_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sgn_out(const Tensor & self, Tensor & out); // {"schema": "aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor chalf(const Tensor & self, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor real(const Tensor & self); // {"schema": "aten::real(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor imag(const Tensor & self); // {"schema": "aten::imag(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _conj(const Tensor & self); // {"schema": "aten::_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor conj(const Tensor & self); // {"schema": "aten::conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _conj_physical(const Tensor & self); // {"schema": "aten::_conj_physical(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor conj_physical(const Tensor & self); // {"schema": "aten::conj_physical(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & conj_physical_out(const Tensor & self, Tensor & out); // {"schema": "aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & conj_physical_(Tensor & self); // {"schema": "aten::conj_physical_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor resolve_conj(const Tensor & self); // {"schema": "aten::resolve_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor resolve_neg(const Tensor & self); // {"schema": "aten::resolve_neg(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _neg_view(const Tensor & self); // {"schema": "aten::_neg_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor acos(const Tensor & self); // {"schema": "aten::acos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & acos_(Tensor & self); // {"schema": "aten::acos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & acos_out(const Tensor & self, Tensor & out); // {"schema": "aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arccos(const Tensor & self); // {"schema": "aten::arccos(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arccos_(Tensor & self); // {"schema": "aten::arccos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arccos_out(const Tensor & self, Tensor & out); // {"schema": "aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor avg_pool1d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad); // {"schema": "aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor adaptive_avg_pool1d(const Tensor & self, IntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> adaptive_max_pool1d(const Tensor & self, IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor add(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & add_(Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & add_out(const Tensor & self, const Tensor & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _add_relu(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _add_relu_(Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::_add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _add_relu_out(const Tensor & self, const Tensor & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _add_relu(const Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _add_relu_(Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor add(const Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & add_(Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor addmv(const Tensor & self, const Tensor & mat, const Tensor & vec, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & addmv_(Tensor & self, const Tensor & mat, const Tensor & vec, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & addmv_out(const Tensor & self, const Tensor & mat, const Tensor & vec, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor addr(const Tensor & self, const Tensor & vec1, const Tensor & vec2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & addr_(Tensor & self, const Tensor & vec1, const Tensor & vec2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & addr_out(const Tensor & self, const Tensor & vec1, const Tensor & vec2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor affine_grid_generator(const Tensor & theta, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor affine_grid_generator_backward(const Tensor & grad, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _is_all_true(const Tensor & self); // {"schema": "aten::_is_all_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _is_any_true(const Tensor & self); // {"schema": "aten::_is_any_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _test_check_tensor(const Tensor & self); // {"schema": "aten::_test_check_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_functorch_fallback(const Tensor & self, const Tensor & other); // {"schema": "aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor all(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor all(const Tensor & self, OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & all_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & out); // {"schema": "aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & all_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor all(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & all_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & out); // {"schema": "aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+bool allclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool", "dispatch": "True", "default": "True"}
+Tensor any(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor any(const Tensor & self, OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & any_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & out); // {"schema": "aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & any_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor any(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & any_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & out); // {"schema": "aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor arange(const Scalar & end, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor arange(const Scalar & start, const Scalar & end, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor arange(const Scalar & start, const Scalar & end, const Scalar & step, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & arange_out(const Scalar & end, Tensor & out); // {"schema": "aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & arange_out(const Scalar & start, const Scalar & end, const Scalar & step, Tensor & out); // {"schema": "aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _dim_arange(const Tensor & like, int64_t dim); // {"schema": "aten::_dim_arange(Tensor like, int dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor argmax(const Tensor & self, c10::optional<int64_t> dim, bool keepdim); // {"schema": "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & argmax_out(const Tensor & self, c10::optional<int64_t> dim, bool keepdim, Tensor & out); // {"schema": "aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor argmin(const Tensor & self, c10::optional<int64_t> dim, bool keepdim); // {"schema": "aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & argmin_out(const Tensor & self, c10::optional<int64_t> dim, bool keepdim, Tensor & out); // {"schema": "aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor acosh(const Tensor & self); // {"schema": "aten::acosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & acosh_(Tensor & self); // {"schema": "aten::acosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & acosh_out(const Tensor & self, Tensor & out); // {"schema": "aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arccosh(const Tensor & self); // {"schema": "aten::arccosh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arccosh_(Tensor & self); // {"schema": "aten::arccosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arccosh_out(const Tensor & self, Tensor & out); // {"schema": "aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor asinh(const Tensor & self); // {"schema": "aten::asinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & asinh_(Tensor & self); // {"schema": "aten::asinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & asinh_out(const Tensor & self, Tensor & out); // {"schema": "aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arcsinh(const Tensor & self); // {"schema": "aten::arcsinh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arcsinh_(Tensor & self); // {"schema": "aten::arcsinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arcsinh_out(const Tensor & self, Tensor & out); // {"schema": "aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor atanh(const Tensor & self); // {"schema": "aten::atanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & atanh_(Tensor & self); // {"schema": "aten::atanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & atanh_out(const Tensor & self, Tensor & out); // {"schema": "aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arctanh(const Tensor & self); // {"schema": "aten::arctanh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arctanh_(Tensor & self); // {"schema": "aten::arctanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arctanh_out(const Tensor & self, Tensor & out); // {"schema": "aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor as_strided(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset); // {"schema": "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)", "dispatch": "True", "default": "False"}
+const Tensor & as_strided_(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset); // {"schema": "aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor asin(const Tensor & self); // {"schema": "aten::asin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & asin_(Tensor & self); // {"schema": "aten::asin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & asin_out(const Tensor & self, Tensor & out); // {"schema": "aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arcsin(const Tensor & self); // {"schema": "aten::arcsin(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arcsin_(Tensor & self); // {"schema": "aten::arcsin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arcsin_out(const Tensor & self, Tensor & out); // {"schema": "aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor atan(const Tensor & self); // {"schema": "aten::atan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & atan_(Tensor & self); // {"schema": "aten::atan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & atan_out(const Tensor & self, Tensor & out); // {"schema": "aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor arctan(const Tensor & self); // {"schema": "aten::arctan(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arctan_(Tensor & self); // {"schema": "aten::arctan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arctan_out(const Tensor & self, Tensor & out); // {"schema": "aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor atleast_1d(const Tensor & self); // {"schema": "aten::atleast_1d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> atleast_1d(TensorList tensors); // {"schema": "aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor atleast_2d(const Tensor & self); // {"schema": "aten::atleast_2d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> atleast_2d(TensorList tensors); // {"schema": "aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor atleast_3d(const Tensor & self); // {"schema": "aten::atleast_3d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> atleast_3d(TensorList tensors); // {"schema": "aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & baddbmm_(Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & baddbmm_out(const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bartlett_window(int64_t window_length, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bartlett_window(int64_t window_length, bool periodic, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor batch_norm(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor quantized_batch_norm(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & mean, const Tensor & var, double eps, double output_scale, int64_t output_zero_point); // {"schema": "aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t> _batch_norm_impl_index(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> _batch_norm_impl_index_backward(int64_t impl_index, const Tensor & input, const Tensor & grad_output, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, const c10::optional<Tensor> & save_mean, const c10::optional<Tensor> & save_var_transform, bool train, double eps, ::std::array<bool,3> output_mask, const Tensor & reservedSpace); // {"schema": "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor bernoulli(const Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bernoulli_out(const Tensor & self, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & bernoulli_(Tensor & self, const Tensor & p, c10::optional<Generator> generator); // {"schema": "aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & bernoulli_(Tensor & self, double p, c10::optional<Generator> generator); // {"schema": "aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bernoulli(const Tensor & self, double p, c10::optional<Generator> generator); // {"schema": "aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bilinear(const Tensor & input1, const Tensor & input2, const Tensor & weight, const c10::optional<Tensor> & bias); // {"schema": "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor binary_cross_entropy(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & binary_cross_entropy_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, Tensor & out); // {"schema": "aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor binary_cross_entropy_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & binary_cross_entropy_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, Tensor & grad_input); // {"schema": "aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor binary_cross_entropy_with_logits(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & pos_weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bincount(const Tensor & self, const c10::optional<Tensor> & weights, int64_t minlength); // {"schema": "aten::bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor bitwise_not(const Tensor & self); // {"schema": "aten::bitwise_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_not_(Tensor & self); // {"schema": "aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_not_out(const Tensor & self, Tensor & out); // {"schema": "aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & copysign_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor copysign(const Tensor & self, const Tensor & other); // {"schema": "aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & copysign_(Tensor & self, const Tensor & other); // {"schema": "aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor copysign(const Tensor & self, const Scalar & other); // {"schema": "aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & copysign_(Tensor & self, const Scalar & other); // {"schema": "aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & copysign_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _lazy_clone(const Tensor & self); // {"schema": "aten::_lazy_clone(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor logical_not(const Tensor & self); // {"schema": "aten::logical_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logical_not_(Tensor & self); // {"schema": "aten::logical_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logical_not_out(const Tensor & self, Tensor & out); // {"schema": "aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logical_xor(const Tensor & self, const Tensor & other); // {"schema": "aten::logical_xor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logical_xor_(Tensor & self, const Tensor & other); // {"schema": "aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logical_xor_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logical_and(const Tensor & self, const Tensor & other); // {"schema": "aten::logical_and(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logical_and_(Tensor & self, const Tensor & other); // {"schema": "aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logical_and_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logical_or(const Tensor & self, const Tensor & other); // {"schema": "aten::logical_or(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logical_or_(Tensor & self, const Tensor & other); // {"schema": "aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logical_or_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor blackman_window(int64_t window_length, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor blackman_window(int64_t window_length, bool periodic, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bmm(const Tensor & self, const Tensor & mat2); // {"schema": "aten::bmm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bmm_out(const Tensor & self, const Tensor & mat2, Tensor & out); // {"schema": "aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> broadcast_tensors(TensorList tensors); // {"schema": "aten::broadcast_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor broadcast_to(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _sparse_broadcast_to(const Tensor & self, IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor cat(const ITensorListRef & tensors, int64_t dim); // {"schema": "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cat_out(const ITensorListRef & tensors, int64_t dim, Tensor & out); // {"schema": "aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cat(TensorList tensors, Dimname dim); // {"schema": "aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & cat_out(TensorList tensors, Dimname dim, Tensor & out); // {"schema": "aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor concat(TensorList tensors, int64_t dim); // {"schema": "aten::concat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & concat_out(TensorList tensors, int64_t dim, Tensor & out); // {"schema": "aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor concat(TensorList tensors, Dimname dim); // {"schema": "aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & concat_out(TensorList tensors, Dimname dim, Tensor & out); // {"schema": "aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor concatenate(TensorList tensors, int64_t dim); // {"schema": "aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & concatenate_out(TensorList tensors, int64_t dim, Tensor & out); // {"schema": "aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor concatenate(TensorList tensors, Dimname dim); // {"schema": "aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & concatenate_out(TensorList tensors, Dimname dim, Tensor & out); // {"schema": "aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor block_diag(TensorList tensors); // {"schema": "aten::block_diag(Tensor[] tensors) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor ceil(const Tensor & self); // {"schema": "aten::ceil(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ceil_(Tensor & self); // {"schema": "aten::ceil_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ceil_out(const Tensor & self, Tensor & out); // {"schema": "aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor chain_matmul(TensorList matrices); // {"schema": "aten::chain_matmul(Tensor[] matrices) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & chain_matmul_out(TensorList matrices, Tensor & out); // {"schema": "aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> unsafe_chunk(const Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> chunk(const Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> tensor_split(const Tensor & self, c10::SymInt sections, int64_t dim); // {"schema": "aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> tensor_split(const Tensor & self, c10::SymIntArrayRef indices, int64_t dim); // {"schema": "aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> tensor_split(const Tensor & self, const Tensor & tensor_indices_or_sections, int64_t dim); // {"schema": "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+Tensor clamp(const Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max); // {"schema": "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor clamp(const Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max); // {"schema": "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & clamp_(Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max); // {"schema": "aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_(Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max); // {"schema": "aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_out(const Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max, Tensor & out); // {"schema": "aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & clamp_out(const Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max, Tensor & out); // {"schema": "aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor clamp_max(const Tensor & self, const Scalar & max); // {"schema": "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor clamp_max(const Tensor & self, const Tensor & max); // {"schema": "aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & clamp_max_(Tensor & self, const Scalar & max); // {"schema": "aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_max_(Tensor & self, const Tensor & max); // {"schema": "aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_max_out(const Tensor & self, const Scalar & max, Tensor & out); // {"schema": "aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & clamp_max_out(const Tensor & self, const Tensor & max, Tensor & out); // {"schema": "aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor clamp_min(const Tensor & self, const Scalar & min); // {"schema": "aten::clamp_min(Tensor self, Scalar min) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor clamp_min(const Tensor & self, const Tensor & min); // {"schema": "aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & clamp_min_(Tensor & self, const Scalar & min); // {"schema": "aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_min_(Tensor & self, const Tensor & min); // {"schema": "aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clamp_min_out(const Tensor & self, const Scalar & min, Tensor & out); // {"schema": "aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & clamp_min_out(const Tensor & self, const Tensor & min, Tensor & out); // {"schema": "aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor clip(const Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max); // {"schema": "aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor clip(const Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max); // {"schema": "aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & clip_(Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max); // {"schema": "aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & clip_(Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max); // {"schema": "aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & clip_out(const Tensor & self, const c10::optional<Scalar> & min, const c10::optional<Scalar> & max, Tensor & out); // {"schema": "aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & clip_out(const Tensor & self, const c10::optional<Tensor> & min, const c10::optional<Tensor> & max, Tensor & out); // {"schema": "aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+bool cudnn_is_acceptable(const Tensor & self); // {"schema": "aten::cudnn_is_acceptable(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+Tensor complex(const Tensor & real, const Tensor & imag); // {"schema": "aten::complex(Tensor real, Tensor imag) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & complex_out(const Tensor & real, const Tensor & imag, Tensor & out); // {"schema": "aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor polar(const Tensor & abs, const Tensor & angle); // {"schema": "aten::polar(Tensor abs, Tensor angle) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & polar_out(const Tensor & abs, const Tensor & angle, Tensor & out); // {"schema": "aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor constant_pad_nd(const Tensor & self, c10::SymIntArrayRef pad, const Scalar & value); // {"schema": "aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor contiguous(const Tensor & self, MemoryFormat memory_format); // {"schema": "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor convolution(const Tensor & input, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple convolution_backward(const Tensor & grad_output, const Tensor & input, const Tensor & weight, OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor convolution_overrideable(const Tensor & input, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple convolution_backward_overrideable(const Tensor & grad_output, const Tensor & input, const Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "True"}
+Tensor _convolution(const Tensor & input, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32); // {"schema": "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _convolution(const Tensor & input, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled); // {"schema": "aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _convolution_mode(const Tensor & input, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple _convolution_double_backward(const c10::optional & ggI, const c10::optional & ggW, const c10::optional & ggb, const Tensor & gO, const Tensor & weight, const Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor conv1d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv2d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv3d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv1d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding=\"valid\", SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv2d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding=\"valid\", SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv3d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding=\"valid\", SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv_tbc(const Tensor & self, const Tensor & weight, const Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> conv_tbc_backward(const Tensor & self, const Tensor & input, const Tensor & weight, const Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor conv_transpose1d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv_transpose2d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor conv_transpose3d(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor copy(const Tensor & self, const Tensor & src, bool non_blocking); // {"schema": "aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking); // {"schema": "aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _copy_from(const Tensor & self, const Tensor & dst, bool non_blocking); // {"schema": "aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _copy_from_and_resize(const Tensor & self, const Tensor & dst); // {"schema": "aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor cos(const Tensor & self); // {"schema": "aten::cos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cos_(Tensor & self); // {"schema": "aten::cos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cos_out(const Tensor & self, Tensor & out); // {"schema": "aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cosh(const Tensor & self); // {"schema": "aten::cosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cosh_(Tensor & self); // {"schema": "aten::cosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cosh_out(const Tensor & self, Tensor & out); // {"schema": "aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cosine_embedding_loss(const Tensor & input1, const Tensor & input2, const Tensor & target, double margin, int64_t reduction); // {"schema": "aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor count_nonzero(const Tensor & self, IntArrayRef dim); // {"schema": "aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor count_nonzero(const Tensor & self, c10::optional<int64_t> dim); // {"schema": "aten::count_nonzero(Tensor self, int? dim=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor cov(const Tensor & self, int64_t correction, const c10::optional<Tensor> & fweights, const c10::optional<Tensor> & aweights); // {"schema": "aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor corrcoef(const Tensor & self); // {"schema": "aten::corrcoef(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor cudnn_affine_grid_generator(const Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid", "dispatch": "True", "default": "False"}
+Tensor cudnn_affine_grid_generator_backward(const Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta", "dispatch": "True", "default": "False"}
+::std::tuple cudnn_batch_norm(const Tensor & input, const Tensor & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple cudnn_batch_norm_backward(const Tensor & input, const Tensor & grad_output, const Tensor & weight, const c10::optional & running_mean, const c10::optional & running_var, const c10::optional & save_mean, const c10::optional & save_var, double epsilon, const Tensor & reserveSpace); // {"schema": "aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor cudnn_convolution(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & cudnn_convolution_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, Tensor & out); // {"schema": "aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cudnn_convolution_transpose(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _mps_convolution_transpose(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple mps_convolution_transpose_backward(const Tensor & self, const Tensor & grad_output, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor cudnn_convolution_relu(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor cudnn_convolution_add_relu(const Tensor & self, const Tensor & weight, const Tensor & z, const c10::optional & alpha, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor cudnn_grid_sampler(const Tensor & self, const Tensor & grid); // {"schema": "aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> cudnn_grid_sampler_backward(const Tensor & self, const Tensor & grid, const Tensor & grad_output); // {"schema": "aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> cummax(const Tensor & self, int64_t dim); // {"schema": "aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> cummax_out(const Tensor & self, int64_t dim, Tensor & values, Tensor & indices); // {"schema": "aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> cummax(const Tensor & self, Dimname dim); // {"schema": "aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> cummax_out(const Tensor & self, Dimname dim, Tensor & values, Tensor & indices); // {"schema": "aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+void _cummax_helper(const Tensor & self, Tensor & values, Tensor & indices, int64_t dim); // {"schema": "aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> cummin(const Tensor & self, int64_t dim); // {"schema": "aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> cummin_out(const Tensor & self, int64_t dim, Tensor & values, Tensor & indices); // {"schema": "aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> cummin(const Tensor & self, Dimname dim); // {"schema": "aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> cummin_out(const Tensor & self, Dimname dim, Tensor & values, Tensor & indices); // {"schema": "aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+void _cummin_helper(const Tensor & self, Tensor & values, Tensor & indices, int64_t dim); // {"schema": "aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()", "dispatch": "True", "default": "False"}
+Tensor cummaxmin_backward(const Tensor & grad, const Tensor & input, const Tensor & indices, int64_t dim); // {"schema": "aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor cumprod(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cumprod_(Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cumprod_out(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cumprod(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & cumprod_(Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & cumprod_out(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor cumprod_backward(const Tensor & grad, const Tensor & input, int64_t dim, const Tensor & output); // {"schema": "aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor cumsum(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cumsum_(Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cumsum_out(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cumsum(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & cumsum_(Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & cumsum_out(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor cumulative_trapezoid(const Tensor & y, const Tensor & x, int64_t dim); // {"schema": "aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor cumulative_trapezoid(const Tensor & y, const Scalar & dx, int64_t dim); // {"schema": "aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor ctc_loss(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor ctc_loss(const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _ctc_loss(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _ctc_loss(const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _ctc_loss_backward(const Tensor & grad, const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor & neg_log_likelihood, const Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _ctc_loss_backward(const Tensor & grad, const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, const Tensor & neg_log_likelihood, const Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor diag_embed(const Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor diagflat(const Tensor & self, int64_t offset); // {"schema": "aten::diagflat(Tensor self, int offset=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor diagonal(const Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor linalg_diagonal(const Tensor & A, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor diagonal(const Tensor & self, Dimname outdim, Dimname dim1, Dimname dim2, int64_t offset); // {"schema": "aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor diagonal_backward(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fill_diagonal_(Tensor & self, const Scalar & fill_value, bool wrap); // {"schema": "aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor diff(const Tensor & self, int64_t n, int64_t dim, const c10::optional<Tensor> & prepend, const c10::optional<Tensor> & append); // {"schema": "aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & diff_out(const Tensor & self, int64_t n, int64_t dim, const c10::optional<Tensor> & prepend, const c10::optional<Tensor> & append, Tensor & out); // {"schema": "aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, const c10::optional<Scalar> & spacing, c10::optional<int64_t> dim, int64_t edge_order); // {"schema": "aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, const Scalar & spacing, IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, ArrayRef<Scalar> spacing, c10::optional<int64_t> dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, TensorList spacing, c10::optional<int64_t> dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> gradient(const Tensor & self, TensorList spacing, IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor div(const Tensor & self, const Tensor & other); // {"schema": "aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & div_(Tensor & self, const Tensor & other); // {"schema": "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & div_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor div(const Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & div_(Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & div_out(const Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode, Tensor & out); // {"schema": "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor div(const Tensor & self, const Scalar & other); // {"schema": "aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & div_(Tensor & self, const Scalar & other); // {"schema": "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor div(const Tensor & self, const Scalar & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & div_(Tensor & self, const Scalar & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor divide(const Tensor & self, const Tensor & other); // {"schema": "aten::divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & divide_(Tensor & self, const Tensor & other); // {"schema": "aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & divide_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor divide(const Tensor & self, const Scalar & other); // {"schema": "aten::divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & divide_(Tensor & self, const Scalar & other); // {"schema": "aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor divide(const Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & divide_(Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & divide_out(const Tensor & self, const Tensor & other, c10::optional<c10::string_view> rounding_mode, Tensor & out); // {"schema": "aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor divide(const Tensor & self, const Scalar & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & divide_(Tensor & self, const Scalar & other, c10::optional<c10::string_view> rounding_mode); // {"schema": "aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor true_divide(const Tensor & self, const Tensor & other); // {"schema": "aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & true_divide_(Tensor & self, const Tensor & other); // {"schema": "aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & true_divide_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor true_divide(const Tensor & self, const Scalar & other); // {"schema": "aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & true_divide_(Tensor & self, const Scalar & other); // {"schema": "aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor dot(const Tensor & self, const Tensor & tensor); // {"schema": "aten::dot(Tensor self, Tensor tensor) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & dot_out(const Tensor & self, const Tensor & tensor, Tensor & out); // {"schema": "aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor vdot(const Tensor & self, const Tensor & other); // {"schema": "aten::vdot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & vdot_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor einsum(c10::string_view equation, TensorList tensors, OptionalIntArrayRef path); // {"schema": "aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor embedding(const Tensor & weight, const Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor embedding_backward(const Tensor & grad, const Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor embedding_dense_backward(const Tensor & grad_output, const Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & embedding_renorm_(Tensor & self, const Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor embedding_sparse_backward(const Tensor & grad, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple _embedding_bag_forward_only(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple _rowwise_prune(const Tensor & weight, const Tensor & mask, ScalarType compressed_indices_dtype); // {"schema": "aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor row_stack(TensorList tensors); // {"schema": "aten::row_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & row_stack_out(TensorList tensors, Tensor & out); // {"schema": "aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional & per_sample_weights, bool include_last_offset); // {"schema": "aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional & per_sample_weights, bool include_last_offset, c10::optional padding_idx); // {"schema": "aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple _embedding_bag(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _embedding_bag_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _embedding_bag_sparse_backward(const Tensor & grad, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, const Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _embedding_bag_dense_backward(const Tensor & grad, const Tensor & indices, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _embedding_bag_per_sample_weights_backward(const Tensor & grad, const Tensor & weight, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, int64_t mode, int64_t padding_idx); // {"schema": "aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor empty(IntArrayRef size, c10::optional names, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor empty(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor empty_permuted(c10::SymIntArrayRef size, IntArrayRef physical_layout, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor new_empty(const Tensor & self, c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor new_empty_strided(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor new_full(const Tensor & self, c10::SymIntArrayRef size, const Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor new_zeros(const Tensor & self, c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor new_ones(const Tensor & self, c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _empty_affine_quantized(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, double scale, int64_t zero_point, c10::optional memory_format); // {"schema": "aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const Tensor & scales, const Tensor & zero_points, int64_t axis, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"}
+const Tensor & resize_(const Tensor & self, c10::SymIntArrayRef size, c10::optional memory_format); // {"schema": "aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+const Tensor & _resize_output_(const Tensor & self, c10::SymIntArrayRef size, Device device); // {"schema": "aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor empty_quantized(IntArrayRef size, const Tensor & qtensor, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & empty_out(c10::SymIntArrayRef size, c10::optional memory_format, Tensor & out); // {"schema": "aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor empty_like(const Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor empty_strided(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor erf(const Tensor & self); // {"schema": "aten::erf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & erf_(Tensor & self); // {"schema": "aten::erf_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & erf_out(const Tensor & self, Tensor & out); // {"schema": "aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor erfc(const Tensor & self); // {"schema": "aten::erfc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & erfc_(Tensor & self); // {"schema": "aten::erfc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & erfc_out(const Tensor & self, Tensor & out); // {"schema": "aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor exp(const Tensor & self); // {"schema": "aten::exp(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & exp_(Tensor & self); // {"schema": "aten::exp_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & exp_out(const Tensor & self, Tensor & out); // {"schema": "aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor exp2(const Tensor & self); // {"schema": "aten::exp2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & exp2_(Tensor & self); // {"schema": "aten::exp2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & exp2_out(const Tensor & self, Tensor & out); // {"schema": "aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor expm1(const Tensor & self); // {"schema": "aten::expm1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & expm1_(Tensor & self); // {"schema": "aten::expm1_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & expm1_out(const Tensor & self, Tensor & out); // {"schema": "aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor expand(const Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor expand_as(const Tensor & self, const Tensor & other); // {"schema": "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor eye(c10::SymInt n, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor eye(c10::SymInt n, c10::SymInt m, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & eye_out(c10::SymInt n, Tensor & out); // {"schema": "aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & eye_out(c10::SymInt n, c10::SymInt m, Tensor & out); // {"schema": "aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor flatten(const Tensor & self, int64_t start_dim, int64_t end_dim); // {"schema": "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor flatten(const Tensor & self, int64_t start_dim, int64_t end_dim, Dimname out_dim); // {"schema": "aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor flatten(const Tensor & self, Dimname start_dim, Dimname end_dim, Dimname out_dim); // {"schema": "aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor flatten(const Tensor & self, DimnameList dims, Dimname out_dim); // {"schema": "aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor unflatten(const Tensor & self, int64_t dim, c10::SymIntArrayRef sizes); // {"schema": "aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor unflatten(const Tensor & self, Dimname dim, c10::SymIntArrayRef sizes, DimnameList names); // {"schema": "aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor fill(const Tensor & self, const Scalar & value); // {"schema": "aten::fill.Scalar(Tensor self, Scalar value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor fill(const Tensor & self, const Tensor & value); // {"schema": "aten::fill.Tensor(Tensor self, Tensor value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fill_(Tensor & self, const Scalar & value); // {"schema": "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & fill_(Tensor & self, const Tensor & value); // {"schema": "aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor floor(const Tensor & self); // {"schema": "aten::floor(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & floor_(Tensor & self); // {"schema": "aten::floor_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & floor_out(const Tensor & self, Tensor & out); // {"schema": "aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor floor_divide(const Tensor & self, const Tensor & other); // {"schema": "aten::floor_divide(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & floor_divide_(Tensor & self, const Tensor & other); // {"schema": "aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & floor_divide_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor floor_divide(const Tensor & self, const Scalar & other); // {"schema": "aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & floor_divide_(Tensor & self, const Scalar & other); // {"schema": "aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor frac(const Tensor & self); // {"schema": "aten::frac(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & frac_(Tensor & self); // {"schema": "aten::frac_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & frac_out(const Tensor & self, Tensor & out); // {"schema": "aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor full(IntArrayRef size, const Scalar & fill_value, c10::optional names, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor full(c10::SymIntArrayRef size, const Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & full_out(c10::SymIntArrayRef size, const Scalar & fill_value, Tensor & out); // {"schema": "aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor full_like(const Tensor & self, const Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); // {"schema": "aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor from_file(c10::string_view filename, c10::optional shared, c10::optional size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & gcd_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor gcd(const Tensor & self, const Tensor & other); // {"schema": "aten::gcd(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & gcd_(Tensor & self, const Tensor & other); // {"schema": "aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & lcm_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor lcm(const Tensor & self, const Tensor & other); // {"schema": "aten::lcm(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & lcm_(Tensor & self, const Tensor & other); // {"schema": "aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor grid_sampler(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor grid_sampler_2d(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> grid_sampler_2d_backward(const Tensor & grad_output, const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask); // {"schema": "aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _grid_sampler_2d_cpu_fallback(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> _grid_sampler_2d_cpu_fallback_backward(const Tensor & grad_output, const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor grid_sampler_3d(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> grid_sampler_3d_backward(const Tensor & grad_output, const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask); // {"schema": "aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor hann_window(int64_t window_length, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hann_window(int64_t window_length, bool periodic, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hamming_window(int64_t window_length, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hamming_window(int64_t window_length, bool periodic, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hamming_window(int64_t window_length, bool periodic, double alpha, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hamming_window(int64_t window_length, bool periodic, double alpha, double beta, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor kaiser_window(int64_t window_length, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor kaiser_window(int64_t window_length, bool periodic, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor kaiser_window(int64_t window_length, bool periodic, double beta, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor hinge_embedding_loss(const Tensor & self, const Tensor & target, double margin, int64_t reduction); // {"schema": "aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor group_norm(const Tensor & input, int64_t num_groups, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, double eps, bool cudnn_enabled); // {"schema": "aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> native_group_norm(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps); // {"schema": "aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & rstd, const c10::optional<Tensor> & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array<bool,3> output_mask); // {"schema": "aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _fft_r2c(const Tensor & self, IntArrayRef dim, int64_t normalization, bool onesided); // {"schema": "aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _fft_r2c_out(const Tensor & self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor & out); // {"schema": "aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _fft_c2r(const Tensor & self, IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size); // {"schema": "aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _fft_c2r_out(const Tensor & self, IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, Tensor & out); // {"schema": "aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _fft_c2c(const Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward); // {"schema": "aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _fft_c2c_out(const Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, Tensor & out); // {"schema": "aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+void _validate_compressed_sparse_indices(bool is_crow, const Tensor & compressed_idx, const Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz); // {"schema": "aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()", "dispatch": "True", "default": "False"}
+int64_t _cufft_get_plan_cache_size(DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"}
+int64_t _cufft_get_plan_cache_max_size(DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"}
+void _cufft_set_plan_cache_max_size(DeviceIndex device_index, int64_t max_size); // {"schema": "aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()", "dispatch": "False", "default": "True"}
+void _cufft_clear_plan_cache(DeviceIndex device_index); // {"schema": "aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> ()", "dispatch": "False", "default": "True"}
+Tensor index(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices); // {"schema": "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_out(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, Tensor & out); // {"schema": "aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _unsafe_index(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices); // {"schema": "aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_copy_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, Tensor & out); // {"schema": "aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source); // {"schema": "aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source); // {"schema": "aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_copy_(Tensor & self, Dimname dim, const Tensor & index, const Tensor & source); // {"schema": "aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor index_copy(const Tensor & self, Dimname dim, const Tensor & index, const Tensor & source); // {"schema": "aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & index_put_(Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate); // {"schema": "aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor index_put(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate); // {"schema": "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _unsafe_index_put(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate); // {"schema": "aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _index_put_impl_(Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor instance_norm(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & isin_out(const Tensor & elements, const Tensor & test_elements, bool assume_unique, bool invert, Tensor & out); // {"schema": "aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor isin(const Tensor & elements, const Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & isin_out(const Tensor & elements, const Scalar & test_element, bool assume_unique, bool invert, Tensor & out); // {"schema": "aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor isin(const Tensor & elements, const Scalar & test_element, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & isin_out(const Scalar & element, const Tensor & test_elements, bool assume_unique, bool invert, Tensor & out); // {"schema": "aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor isin(const Scalar & element, const Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor isnan(const Tensor & self); // {"schema": "aten::isnan(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+bool is_distributed(const Tensor & self); // {"schema": "aten::is_distributed(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_floating_point(const Tensor & self); // {"schema": "aten::is_floating_point(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_complex(const Tensor & self); // {"schema": "aten::is_complex(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_conj(const Tensor & self); // {"schema": "aten::is_conj(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool _is_zerotensor(const Tensor & self); // {"schema": "aten::_is_zerotensor(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_neg(const Tensor & self); // {"schema": "aten::is_neg(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+Tensor isreal(const Tensor & self); // {"schema": "aten::isreal(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+bool is_nonzero(const Tensor & self); // {"schema": "aten::is_nonzero(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_same_size(const Tensor & self, const Tensor & other); // {"schema": "aten::is_same_size(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"}
+bool is_signed(const Tensor & self); // {"schema": "aten::is_signed(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+bool is_inference(const Tensor & self); // {"schema": "aten::is_inference(Tensor self) -> bool", "dispatch": "False", "default": "True"}
+Tensor kl_div(const Tensor & self, const Tensor & target, int64_t reduction, bool log_target); // {"schema": "aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor kron(const Tensor & self, const Tensor & other); // {"schema": "aten::kron(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & kron_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> kthvalue(const Tensor & self, int64_t k, int64_t dim, bool keepdim); // {"schema": "aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> kthvalue_out(const Tensor & self, int64_t k, int64_t dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> kthvalue(const Tensor & self, int64_t k, Dimname dim, bool keepdim); // {"schema": "aten::kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> kthvalue_out(const Tensor & self, int64_t k, Dimname dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+Tensor layer_norm(const Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, double eps, bool cudnn_enable); // {"schema": "aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> native_layer_norm(const Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, double eps); // {"schema": "aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> native_layer_norm_backward(const Tensor & grad_out, const Tensor & input, c10::SymIntArrayRef normalized_shape, const Tensor & mean, const Tensor & rstd, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, ::std::array<bool,3> output_mask); // {"schema": "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor nan_to_num(const Tensor & self, c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf); // {"schema": "aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & nan_to_num_(Tensor & self, c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf); // {"schema": "aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & nan_to_num_out(const Tensor & self, c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf, Tensor & out); // {"schema": "aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor linear(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias); // {"schema": "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> linear_backward(const Tensor & self, const Tensor & grad_output, const Tensor & weight, ::std::array<bool,3> output_mask); // {"schema": "aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor & linear_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, Tensor & out); // {"schema": "aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor mkldnn_linear(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias); // {"schema": "aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_linear_backward_input(IntArrayRef input_size, const Tensor & grad_output, const Tensor & weight); // {"schema": "aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> mkldnn_linear_backward_weights(const Tensor & grad_output, const Tensor & input, const Tensor & weight, bool bias_defined); // {"schema": "aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> mkldnn_linear_backward(const Tensor & self, const Tensor & grad_output, const Tensor & weight, ::std::array<bool,3> output_mask); // {"schema": "aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _cslt_compress(const Tensor & input); // {"schema": "aten::_cslt_compress(Tensor input) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _cslt_sparse_mm(const Tensor & compressed_A, const Tensor & dense_B, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & alpha, c10::optional<ScalarType> out_dtype, bool transpose_result, int64_t alg_id); // {"schema": "aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor", "dispatch": "True", "default": "False"}
+int64_t _cslt_sparse_mm_search(const Tensor & compressed_A, const Tensor & dense_B, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & alpha, c10::optional<ScalarType> out_dtype, bool transpose_result); // {"schema": "aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int", "dispatch": "True", "default": "False"}
+Tensor _sparse_semi_structured_linear(const Tensor & input, const Tensor & weight, const Tensor & meta, const c10::optional<Tensor> & bias, c10::optional<c10::string_view> activation, c10::optional<ScalarType> out_dtype); // {"schema": "aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _mixed_dtypes_linear(const Tensor & input, const Tensor & weight, const Tensor & scale, const c10::optional<Tensor> & bias, c10::optional<c10::string_view> activation); // {"schema": "aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor fbgemm_linear_int8_weight_fp32_activation(const Tensor & input, const Tensor & weight, const Tensor & packed, const Tensor & col_offsets, const Scalar & weight_scale, const Scalar & weight_zero_point, const Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fbgemm_linear_int8_weight(const Tensor & input, const Tensor & weight, const Tensor & packed, const Tensor & col_offsets, const Scalar & weight_scale, const Scalar & weight_zero_point, const Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,double,int64_t> fbgemm_linear_quantize_weight(const Tensor & input); // {"schema": "aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)", "dispatch": "False", "default": "True"}
+Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor & input); // {"schema": "aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fbgemm_linear_fp16_weight_fp32_activation(const Tensor & input, const Tensor & packed_weight, const Tensor & bias); // {"schema": "aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fbgemm_linear_fp16_weight(const Tensor & input, const Tensor & packed_weight, const Tensor & bias); // {"schema": "aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fbgemm_pack_quantized_matrix(const Tensor & input); // {"schema": "aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fbgemm_pack_quantized_matrix(const Tensor & input, int64_t K, int64_t N); // {"schema": "aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor ldexp(const Tensor & self, const Tensor & other); // {"schema": "aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & ldexp_(Tensor & self, const Tensor & other); // {"schema": "aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & ldexp_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linspace(const Scalar & start, const Scalar & end, int64_t steps, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor linspace(const Tensor & start, const Tensor & end, int64_t steps, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor linspace(const Tensor & start, const Scalar & end, int64_t steps, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor linspace(const Scalar & start, const Tensor & end, int64_t steps, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linspace_out(const Scalar & start, const Scalar & end, int64_t steps, Tensor & out); // {"schema": "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & linspace_out(const Tensor & start, const Tensor & end, int64_t steps, Tensor & out); // {"schema": "aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & linspace_out(const Tensor & start, const Scalar & end, int64_t steps, Tensor & out); // {"schema": "aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & linspace_out(const Scalar & start, const Tensor & end, int64_t steps, Tensor & out); // {"schema": "aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor log(const Tensor & self); // {"schema": "aten::log(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & log_(Tensor & self); // {"schema": "aten::log_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & log_out(const Tensor & self, Tensor & out); // {"schema": "aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor log10(const Tensor & self); // {"schema": "aten::log10(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & log10_(Tensor & self); // {"schema": "aten::log10_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & log10_out(const Tensor & self, Tensor & out); // {"schema": "aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor log1p(const Tensor & self); // {"schema": "aten::log1p(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & log1p_(Tensor & self); // {"schema": "aten::log1p_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & log1p_out(const Tensor & self, Tensor & out); // {"schema": "aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor log2(const Tensor & self); // {"schema": "aten::log2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & log2_(Tensor & self); // {"schema": "aten::log2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & log2_out(const Tensor & self, Tensor & out); // {"schema": "aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & logaddexp_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logaddexp(const Tensor & self, const Tensor & other); // {"schema": "aten::logaddexp(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logaddexp2_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logaddexp2(const Tensor & self, const Tensor & other); // {"schema": "aten::logaddexp2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor xlogy(const Tensor & self, const Tensor & other); // {"schema": "aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor xlogy(const Scalar & self, const Tensor & other); // {"schema": "aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor xlogy(const Tensor & self, const Scalar & other); // {"schema": "aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & xlogy_(Tensor & self, const Tensor & other); // {"schema": "aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & xlogy_(Tensor & self, const Scalar & other); // {"schema": "aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & xlogy_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & xlogy_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & xlogy_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor logspace(const Scalar & start, const Scalar & end, int64_t steps, double base, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor logspace(const Tensor & start, const Tensor & end, int64_t steps, double base, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor logspace(const Tensor & start, const Scalar & end, int64_t steps, double base, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor logspace(const Scalar & start, const Tensor & end, int64_t steps, double base, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logspace_out(const Scalar & start, const Scalar & end, int64_t steps, double base, Tensor & out); // {"schema": "aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & logspace_out(const Tensor & start, const Tensor & end, int64_t steps, double base, Tensor & out); // {"schema": "aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logspace_out(const Tensor & start, const Scalar & end, int64_t steps, double base, Tensor & out); // {"schema": "aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & logspace_out(const Scalar & start, const Tensor & end, int64_t steps, double base, Tensor & out); // {"schema": "aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor log_softmax(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & log_softmax_out(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor log_softmax(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _log_softmax(const Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _log_softmax_out(const Tensor & self, int64_t dim, bool half_to_float, Tensor & out); // {"schema": "aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _log_softmax_backward_data(const Tensor & grad_output, const Tensor & output, int64_t dim, ScalarType input_dtype); // {"schema": "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _log_softmax_backward_data_out(const Tensor & grad_output, const Tensor & output, int64_t dim, ScalarType input_dtype, Tensor & out); // {"schema": "aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _logcumsumexp(const Tensor & self, int64_t dim); // {"schema": "aten::_logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _logcumsumexp_out(const Tensor & self, int64_t dim, Tensor & out); // {"schema": "aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logcumsumexp(const Tensor & self, int64_t dim); // {"schema": "aten::logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logcumsumexp_out(const Tensor & self, int64_t dim, Tensor & out); // {"schema": "aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor logcumsumexp(const Tensor & self, Dimname dim); // {"schema": "aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & logcumsumexp_out(const Tensor & self, Dimname dim, Tensor & out); // {"schema": "aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor logsumexp(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logsumexp_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor logsumexp(const Tensor & self, DimnameList dim, bool keepdim); // {"schema": "aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & logsumexp_out(const Tensor & self, DimnameList dim, bool keepdim, Tensor & out); // {"schema": "aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor margin_ranking_loss(const Tensor & input1, const Tensor & input2, const Tensor & target, double margin, int64_t reduction); // {"schema": "aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor matmul(const Tensor & self, const Tensor & other); // {"schema": "aten::matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> matmul_backward(const Tensor & grad, const Tensor & self, const Tensor & other, ::std::array<bool,2> mask); // {"schema": "aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor & matmul_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor matrix_power(const Tensor & self, int64_t n); // {"schema": "aten::matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & matrix_power_out(const Tensor & self, int64_t n, Tensor & out); // {"schema": "aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor matrix_exp(const Tensor & self); // {"schema": "aten::matrix_exp(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor matrix_exp_backward(const Tensor & self, const Tensor & grad); // {"schema": "aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _aminmax(const Tensor & self); // {"schema": "aten::_aminmax(Tensor self) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _aminmax(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> aminmax(const Tensor & self, c10::optional<int64_t> dim, bool keepdim); // {"schema": "aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> aminmax_out(const Tensor & self, c10::optional<int64_t> dim, bool keepdim, Tensor & min, Tensor & max); // {"schema": "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)", "dispatch": "True", "default": "False"}
+Tensor _compute_linear_combination(const Tensor & input, const Tensor & coefficients); // {"schema": "aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _compute_linear_combination_out(const Tensor & input, const Tensor & coefficients, Tensor & out); // {"schema": "aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> max(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> max_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & max, Tensor & max_values); // {"schema": "aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> max(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> max_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & max, Tensor & max_values); // {"schema": "aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+Tensor value_selecting_reduction_backward(const Tensor & grad, int64_t dim, const Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim); // {"schema": "aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor amax(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & amax_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> max_pool1d_with_indices(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor max_pool1d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor max_pool2d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor max_pool2d_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_max_pool2d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_max_pool2d_backward(const Tensor & grad_output, const Tensor & output, const Tensor & input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_max_pool3d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_max_pool3d_backward(const Tensor & grad_output, const Tensor & output, const Tensor & input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor quantized_max_pool1d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor quantized_max_pool2d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor quantized_max_pool3d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor max_pool3d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor mean(const Tensor & self, c10::optional<ScalarType> dtype); // {"schema": "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor mean(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mean_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mean(const Tensor & self, DimnameList dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & mean_out(const Tensor & self, DimnameList dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nanmean(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nanmean_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor median(const Tensor & self); // {"schema": "aten::median(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> median(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> median_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> median(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> median_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+Tensor nanmedian(const Tensor & self); // {"schema": "aten::nanmedian(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> nanmedian(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> nanmedian_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> nanmedian(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> nanmedian_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> min(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> min_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & min, Tensor & min_indices); // {"schema": "aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> min(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> min_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & min, Tensor & min_indices); // {"schema": "aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+Tensor amin(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & amin_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _mps_convolution(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> mps_convolution_backward(const Tensor & self, const Tensor & grad_output, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,3> output_mask); // {"schema": "aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor mkldnn_convolution(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor> mkldnn_rnn_layer(const Tensor & input, const Tensor & weight0, const Tensor & weight1, const Tensor & weight2, const Tensor & weight3, const Tensor & hx_, const Tensor & cx_, bool reverse, IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train); // {"schema": "aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor,Tensor,Tensor> mkldnn_rnn_layer_backward(const Tensor & input, const Tensor & weight1, const Tensor & weight2, const Tensor & weight3, const Tensor & weight4, const Tensor & hx_, const Tensor & cx_tmp, const Tensor & output, const Tensor & hy_, const Tensor & cy_, const c10::optional<Tensor> & grad_output, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, IntArrayRef batch_sizes, bool batch_first, const Tensor & workspace); // {"schema": "aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> miopen_batch_norm(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> miopen_batch_norm_backward(const Tensor & input, const Tensor & grad_output, const Tensor & weight, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, const c10::optional<Tensor> & save_mean, const c10::optional<Tensor> & save_var, double epsilon); // {"schema": "aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor miopen_convolution(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor miopen_convolution_transpose(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor miopen_depthwise_convolution(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor miopen_convolution_relu(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor miopen_convolution_add_relu(const Tensor & self, const Tensor & weight, const Tensor & z, const c10::optional<Scalar> & alpha, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> miopen_rnn(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & hx, const c10::optional<Tensor> & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state); // {"schema": "aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,::std::vector<Tensor>> miopen_rnn_backward(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & weight_buf, const Tensor & hx, const c10::optional<Tensor> & cx, const Tensor & output, const c10::optional<Tensor> & grad_output, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state, const Tensor & reserve, ::std::array<bool,4> output_mask); // {"schema": "aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"}
+Tensor mm(const Tensor & self, const Tensor & mat2); // {"schema": "aten::mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mm_out(const Tensor & self, const Tensor & mat2, Tensor & out); // {"schema": "aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _int_mm(const Tensor & self, const Tensor & mat2); // {"schema": "aten::_int_mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _int_mm_out(const Tensor & self, const Tensor & mat2, Tensor & out); // {"schema": "aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _convert_weight_to_int4pack(const Tensor & self, int64_t innerKTiles); // {"schema": "aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _weight_int4pack_mm(const Tensor & self, const Tensor & mat2, int64_t qGroupSize, const Tensor & qScaleAndZeros); // {"schema": "aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _weight_int8pack_mm(const Tensor & self, const Tensor & mat2, const Tensor & scales); // {"schema": "aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_mm(const Tensor & sparse, const Tensor & dense); // {"schema": "aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_mm(const Tensor & sparse, const Tensor & dense, c10::string_view reduce); // {"schema": "aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_sparse_matmul(const Tensor & self, const Tensor & other); // {"schema": "aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> mode(const Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> mode_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> mode(const Tensor & self, Dimname dim, bool keepdim); // {"schema": "aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> mode_out(const Tensor & self, Dimname dim, bool keepdim, Tensor & values, Tensor & indices); // {"schema": "aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+Tensor mul(const Tensor & self, const Tensor & other); // {"schema": "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mul_(Tensor & self, const Tensor & other); // {"schema": "aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mul_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mul(const Tensor & self, const Scalar & other); // {"schema": "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mul_(Tensor & self, const Scalar & other); // {"schema": "aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor multiply(const Tensor & self, const Tensor & other); // {"schema": "aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & multiply_(Tensor & self, const Tensor & other); // {"schema": "aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & multiply_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor multiply(const Tensor & self, const Scalar & other); // {"schema": "aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & multiply_(Tensor & self, const Scalar & other); // {"schema": "aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor mv(const Tensor & self, const Tensor & vec); // {"schema": "aten::mv(Tensor self, Tensor vec) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mv_out(const Tensor & self, const Tensor & vec, Tensor & out); // {"schema": "aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mvlgamma_out(const Tensor & self, int64_t p, Tensor & out); // {"schema": "aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mvlgamma(const Tensor & self, int64_t p); // {"schema": "aten::mvlgamma(Tensor self, int p) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mvlgamma_(Tensor & self, int64_t p); // {"schema": "aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor narrow_copy(const Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & narrow_copy_out(const Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, Tensor & out); // {"schema": "aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor narrow(const Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor narrow(const Tensor & self, int64_t dim, const Tensor & start, c10::SymInt length); // {"schema": "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> native_batch_norm(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double momentum, double eps); // {"schema": "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_batch_norm_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double momentum, double eps, Tensor & out, Tensor & save_mean, Tensor & save_invstd); // {"schema": "aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _native_batch_norm_legit(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, Tensor & running_mean, Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _native_batch_norm_legit_no_training(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & running_mean, const Tensor & running_var, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _native_batch_norm_legit_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, Tensor & running_mean, Tensor & running_var, bool training, double momentum, double eps, Tensor & out, Tensor & save_mean, Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _native_batch_norm_legit(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _native_batch_norm_legit_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, bool training, double momentum, double eps, Tensor & out, Tensor & save_mean, Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> batch_norm_stats(const Tensor & input, double eps); // {"schema": "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor batch_norm_elemt(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & mean, const Tensor & invstd, double eps); // {"schema": "aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & batch_norm_elemt_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & mean, const Tensor & invstd, double eps, Tensor & out); // {"schema": "aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> batch_norm_gather_stats(const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum, double eps, int64_t count); // {"schema": "aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> batch_norm_gather_stats_with_counts(const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum, double eps, const Tensor & counts); // {"schema": "aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> native_batch_norm_backward(const Tensor & grad_out, const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, const c10::optional<Tensor> & save_mean, const c10::optional<Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask); // {"schema": "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor> batch_norm_backward_reduce(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & weight, bool input_g, bool weight_g, bool bias_g); // {"schema": "aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor batch_norm_backward_elemt(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & weight, const Tensor & sum_dy, const Tensor & sum_dy_xmu, const Tensor & count); // {"schema": "aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> batch_norm_update_stats(const Tensor & input, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum); // {"schema": "aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+bool is_vulkan_available(); // {"schema": "aten::is_vulkan_available() -> bool", "dispatch": "False", "default": "True"}
+bool _nnpack_available(); // {"schema": "aten::_nnpack_available() -> bool", "dispatch": "False", "default": "True"}
+Tensor _nnpack_spatial_convolution(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride); // {"schema": "aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor ones(IntArrayRef size, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor ones(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ones_out(c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor ones_like(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor pairwise_distance(const Tensor & x1, const Tensor & x2, double p, double eps, bool keepdim); // {"schema": "aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor cdist(const Tensor & x1, const Tensor & x2, double p, c10::optional<int64_t> compute_mode); // {"schema": "aten::cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _euclidean_dist(const Tensor & x1, const Tensor & x2); // {"schema": "aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _cdist_forward(const Tensor & x1, const Tensor & x2, double p, c10::optional<int64_t> compute_mode); // {"schema": "aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _cdist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, double p, const Tensor & cdist); // {"schema": "aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor pdist(const Tensor & self, double p); // {"schema": "aten::pdist(Tensor self, float p=2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _pdist_forward(const Tensor & self, double p); // {"schema": "aten::_pdist_forward(Tensor self, float p=2) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _pdist_backward(const Tensor & grad, const Tensor & self, double p, const Tensor & pdist); // {"schema": "aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor cosine_similarity(const Tensor & x1, const Tensor & x2, int64_t dim, double eps); // {"schema": "aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor permute(const Tensor & self, IntArrayRef dims); // {"schema": "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor movedim(const Tensor & self, IntArrayRef source, IntArrayRef destination); // {"schema": "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor movedim(const Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor moveaxis(const Tensor & self, IntArrayRef source, IntArrayRef destination); // {"schema": "aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor moveaxis(const Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor numpy_T(const Tensor & self); // {"schema": "aten::numpy_T(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor matrix_H(const Tensor & self); // {"schema": "aten::matrix_H(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor mT(const Tensor & self); // {"schema": "aten::mT(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor mH(const Tensor & self); // {"schema": "aten::mH(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor adjoint(const Tensor & self); // {"schema": "aten::adjoint(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor pixel_shuffle(const Tensor & self, int64_t upscale_factor); // {"schema": "aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor pixel_unshuffle(const Tensor & self, int64_t downscale_factor); // {"schema": "aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor channel_shuffle(const Tensor & self, c10::SymInt groups); // {"schema": "aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor native_channel_shuffle(const Tensor & self, c10::SymInt groups); // {"schema": "aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"}
+bool is_pinned(const Tensor & self, c10::optional<Device> device); // {"schema": "aten::is_pinned(Tensor self, Device? device=None) -> bool", "dispatch": "True", "default": "True"}
+Tensor pin_memory(const Tensor & self, c10::optional<Device> device); // {"schema": "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _pin_memory(const Tensor & self, c10::optional<Device> device); // {"schema": "aten::_pin_memory(Tensor self, Device? device=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor pinverse(const Tensor & self, double rcond); // {"schema": "aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor poisson_nll_loss(const Tensor & input, const Tensor & target, bool log_input, bool full, double eps, int64_t reduction); // {"schema": "aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor rad2deg(const Tensor & self); // {"schema": "aten::rad2deg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & rad2deg_(Tensor & self); // {"schema": "aten::rad2deg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rad2deg_out(const Tensor & self, Tensor & out); // {"schema": "aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor deg2rad(const Tensor & self); // {"schema": "aten::deg2rad(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & deg2rad_(Tensor & self); // {"schema": "aten::deg2rad_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & deg2rad_out(const Tensor & self, Tensor & out); // {"schema": "aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor scalar_tensor(const Scalar & s, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor rand(c10::SymIntArrayRef size, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor rand(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor rand(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor rand(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & rand_out(c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rand_out(c10::SymIntArrayRef size, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor rand_like(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor randint_like(const Tensor & self, c10::SymInt high, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randint_like(const Tensor & self, c10::SymInt low, c10::SymInt high, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randn(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randn(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randn(c10::SymIntArrayRef size, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randn(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & randn_out(c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & randn_out(c10::SymIntArrayRef size, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor randn_like(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randperm(c10::SymInt n, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor randperm(c10::SymInt n, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & randperm_out(c10::SymInt n, Tensor & out); // {"schema": "aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randperm_out(c10::SymInt n, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor range(const Scalar & start, const Scalar & end, const Scalar & step, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor range(const Scalar & start, const Scalar & end, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & range_out(const Scalar & start, const Scalar & end, Tensor & out); // {"schema": "aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & range_out(const Scalar & start, const Scalar & end, const Scalar & step, Tensor & out); // {"schema": "aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ravel(const Tensor & self); // {"schema": "aten::ravel(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor reciprocal(const Tensor & self); // {"schema": "aten::reciprocal(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & reciprocal_(Tensor & self); // {"schema": "aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & reciprocal_out(const Tensor & self, Tensor & out); // {"schema": "aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor neg(const Tensor & self); // {"schema": "aten::neg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & neg_(Tensor & self); // {"schema": "aten::neg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & neg_out(const Tensor & self, Tensor & out); // {"schema": "aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor negative(const Tensor & self); // {"schema": "aten::negative(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & negative_(Tensor & self); // {"schema": "aten::negative_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & negative_out(const Tensor & self, Tensor & out); // {"schema": "aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor repeat(const Tensor & self, c10::SymIntArrayRef repeats); // {"schema": "aten::repeat(Tensor self, SymInt[] repeats) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor repeat_interleave(const Tensor & repeats, c10::optional<c10::SymInt> output_size); // {"schema": "aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor repeat_interleave(const Tensor & self, const Tensor & repeats, c10::optional<int64_t> dim, c10::optional<c10::SymInt> output_size); // {"schema": "aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor repeat_interleave(const Tensor & self, c10::SymInt repeats, c10::optional<int64_t> dim, c10::optional<c10::SymInt> output_size); // {"schema": "aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor reshape(const Tensor & self, c10::SymIntArrayRef shape); // {"schema": "aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _reshape_copy(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _reshape_alias(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor _mkldnn_reshape(const Tensor & self, IntArrayRef shape); // {"schema": "aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor reshape_as(const Tensor & self, const Tensor & other); // {"schema": "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor round(const Tensor & self); // {"schema": "aten::round(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & round_(Tensor & self); // {"schema": "aten::round_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & round_out(const Tensor & self, Tensor & out); // {"schema": "aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor round(const Tensor & self, int64_t decimals); // {"schema": "aten::round.decimals(Tensor self, *, int decimals) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & round_(Tensor & self, int64_t decimals); // {"schema": "aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & round_out(const Tensor & self, int64_t decimals, Tensor & out); // {"schema": "aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor rrelu(const Tensor & self, const Scalar & lower, const Scalar & upper, bool training, c10::optional<Generator> generator); // {"schema": "aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & rrelu_(Tensor & self, const Scalar & lower, const Scalar & upper, bool training, c10::optional<Generator> generator); // {"schema": "aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor relu(const Tensor & self); // {"schema": "aten::relu(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & relu_(Tensor & self); // {"schema": "aten::relu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor relu6(const Tensor & self); // {"schema": "aten::relu6(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & relu6_(Tensor & self); // {"schema": "aten::relu6_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor prelu(const Tensor & self, const Tensor & weight); // {"schema": "aten::prelu(Tensor self, Tensor weight) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _prelu_kernel(const Tensor & self, const Tensor & weight); // {"schema": "aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _prelu_kernel_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight); // {"schema": "aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor & gelu_out(const Tensor & self, c10::string_view approximate, Tensor & out); // {"schema": "aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & gelu_(Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor gelu(const Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & gelu_backward_out(const Tensor & grad_output, const Tensor & self, c10::string_view approximate, Tensor & grad_input); // {"schema": "aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor gelu_backward(const Tensor & grad_output, const Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"}
+Tensor infinitely_differentiable_gelu_backward(const Tensor & grad, const Tensor & self); // {"schema": "aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & hardshrink_out(const Tensor & self, const Scalar & lambd, Tensor & out); // {"schema": "aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardshrink(const Tensor & self, const Scalar & lambd); // {"schema": "aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & hardshrink_backward_out(const Tensor & grad_out, const Tensor & self, const Scalar & lambd, Tensor & grad_input); // {"schema": "aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardshrink_backward(const Tensor & grad_out, const Tensor & self, const Scalar & lambd); // {"schema": "aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor rsqrt(const Tensor & self); // {"schema": "aten::rsqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & rsqrt_(Tensor & self); // {"schema": "aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rsqrt_out(const Tensor & self, Tensor & out); // {"schema": "aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor select(const Tensor & self, Dimname dim, int64_t index); // {"schema": "aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor select(const Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor select_backward(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index); // {"schema": "aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _nested_select_backward(const Tensor & grad_output, const Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor selu(const Tensor & self); // {"schema": "aten::selu(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & selu_(Tensor & self); // {"schema": "aten::selu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor celu(const Tensor & self, const Scalar & alpha); // {"schema": "aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & celu_(Tensor & self, const Scalar & alpha); // {"schema": "aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor silu(const Tensor & self); // {"schema": "aten::silu(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & silu_(Tensor & self); // {"schema": "aten::silu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & silu_out(const Tensor & self, Tensor & out); // {"schema": "aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & silu_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & grad_input); // {"schema": "aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor silu_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor mish(const Tensor & self); // {"schema": "aten::mish(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mish_(Tensor & self); // {"schema": "aten::mish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mish_out(const Tensor & self, Tensor & out); // {"schema": "aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mish_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sigmoid(const Tensor & self); // {"schema": "aten::sigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sigmoid_(Tensor & self); // {"schema": "aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sigmoid_out(const Tensor & self, Tensor & out); // {"schema": "aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logit(const Tensor & self, c10::optional<double> eps); // {"schema": "aten::logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & logit_(Tensor & self, c10::optional<double> eps); // {"schema": "aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & logit_out(const Tensor & self, c10::optional<double> eps, Tensor & out); // {"schema": "aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sin(const Tensor & self); // {"schema": "aten::sin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sin_(Tensor & self); // {"schema": "aten::sin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sin_out(const Tensor & self, Tensor & out); // {"schema": "aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sinc(const Tensor & self); // {"schema": "aten::sinc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sinc_(Tensor & self); // {"schema": "aten::sinc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sinc_out(const Tensor & self, Tensor & out); // {"schema": "aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sinh(const Tensor & self); // {"schema": "aten::sinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sinh_(Tensor & self); // {"schema": "aten::sinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sinh_out(const Tensor & self, Tensor & out); // {"schema": "aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor detach(const Tensor & self); // {"schema": "aten::detach(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor & detach_(Tensor & self); // {"schema": "aten::detach_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+int64_t size(const Tensor & self, int64_t dim); // {"schema": "aten::size.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"}
+int64_t size(const Tensor & self, Dimname dim); // {"schema": "aten::size.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"}
+c10::SymInt sym_size(const Tensor & self, int64_t dim); // {"schema": "aten::sym_size.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"}
+c10::SymInt sym_numel(const Tensor & self); // {"schema": "aten::sym_numel(Tensor self) -> SymInt", "dispatch": "False", "default": "True"}
+c10::SymInt sym_storage_offset(const Tensor & self); // {"schema": "aten::sym_storage_offset(Tensor self) -> SymInt", "dispatch": "False", "default": "True"}
+Tensor slice(const Tensor & self, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step); // {"schema": "aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor slice_backward(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step); // {"schema": "aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor slice_inverse(const Tensor & self, const Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step); // {"schema": "aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor slice_scatter(const Tensor & self, const Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step); // {"schema": "aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor select_scatter(const Tensor & self, const Tensor & src, int64_t dim, c10::SymInt index); // {"schema": "aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor diagonal_scatter(const Tensor & self, const Tensor & src, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor as_strided_scatter(const Tensor & self, const Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset); // {"schema": "aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor smm(const Tensor & self, const Tensor & mat2); // {"schema": "aten::smm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor softmax(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & softmax_out(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor softmax(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _softmax(const Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _softmax_out(const Tensor & self, int64_t dim, bool half_to_float, Tensor & out); // {"schema": "aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _softmax_backward_data(const Tensor & grad_output, const Tensor & output, int64_t dim, ScalarType input_dtype); // {"schema": "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _softmax_backward_data_out(const Tensor & grad_output, const Tensor & output, int64_t dim, ScalarType input_dtype, Tensor & grad_input); // {"schema": "aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> unsafe_split(const Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> split(const Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> split(const Tensor & self, c10::SymIntArrayRef split_size, int64_t dim); // {"schema": "aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> unsafe_split_with_sizes(const Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> split_with_sizes(const Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> hsplit(const Tensor & self, int64_t sections); // {"schema": "aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> hsplit(const Tensor & self, IntArrayRef indices); // {"schema": "aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> vsplit(const Tensor & self, int64_t sections); // {"schema": "aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> vsplit(const Tensor & self, IntArrayRef indices); // {"schema": "aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> dsplit(const Tensor & self, int64_t sections); // {"schema": "aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> dsplit(const Tensor & self, IntArrayRef indices); // {"schema": "aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
+Tensor squeeze(const Tensor & self); // {"schema": "aten::squeeze(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor squeeze(const Tensor & self, int64_t dim); // {"schema": "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor squeeze(const Tensor & self, Dimname dim); // {"schema": "aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor squeeze(const Tensor & self, IntArrayRef dim); // {"schema": "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_(Tensor & self); // {"schema": "aten::squeeze_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_(Tensor & self, int64_t dim); // {"schema": "aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_(Tensor & self, IntArrayRef dim); // {"schema": "aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_(Tensor & self, Dimname dim); // {"schema": "aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor sspaddmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & sspaddmm_out(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _chunk_cat(TensorList tensors, int64_t dim, int64_t num_chunks); // {"schema": "aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _chunk_cat_out(TensorList tensors, int64_t dim, int64_t num_chunks, Tensor & out); // {"schema": "aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor stack(TensorList tensors, int64_t dim); // {"schema": "aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & stack_out(TensorList tensors, int64_t dim, Tensor & out); // {"schema": "aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _stack(TensorList tensors, int64_t dim); // {"schema": "aten::_stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _stack_out(TensorList tensors, int64_t dim, Tensor & out); // {"schema": "aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor hstack(TensorList tensors); // {"schema": "aten::hstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & hstack_out(TensorList tensors, Tensor & out); // {"schema": "aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor vstack(TensorList tensors); // {"schema": "aten::vstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & vstack_out(TensorList tensors, Tensor & out); // {"schema": "aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor dstack(TensorList tensors); // {"schema": "aten::dstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & dstack_out(TensorList tensors, Tensor & out); // {"schema": "aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor stft(const Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<Tensor> & window, bool normalized, c10::optional<bool> onesided, c10::optional<bool> return_complex); // {"schema": "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor stft(const Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<Tensor> & window, bool center, c10::string_view pad_mode, bool normalized, c10::optional<bool> onesided, c10::optional<bool> return_complex); // {"schema": "aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode=\"reflect\", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor istft(const Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<Tensor> & window, bool center, bool normalized, c10::optional<bool> onesided, c10::optional<int64_t> length, bool return_complex); // {"schema": "aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor", "dispatch": "False", "default": "True"}
+int64_t stride(const Tensor & self, int64_t dim); // {"schema": "aten::stride.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"}
+int64_t stride(const Tensor & self, Dimname dim); // {"schema": "aten::stride.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"}
+c10::SymInt sym_stride(const Tensor & self, int64_t dim); // {"schema": "aten::sym_stride.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"}
+Tensor sum(const Tensor & self, c10::optional<ScalarType> dtype); // {"schema": "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sum(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sum(const Tensor & self, DimnameList dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & sum_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & sum_out(const Tensor & self, DimnameList dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor _nested_sum_backward(const Tensor & grad, const Tensor & self, OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor nansum(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & nansum_out(const Tensor & self, OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sum_to_size(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sqrt(const Tensor & self); // {"schema": "aten::sqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sqrt_(Tensor & self); // {"schema": "aten::sqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sqrt_out(const Tensor & self, Tensor & out); // {"schema": "aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor square(const Tensor & self); // {"schema": "aten::square(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & square_(Tensor & self); // {"schema": "aten::square_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & square_out(const Tensor & self, Tensor & out); // {"schema": "aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor std(const Tensor & self, bool unbiased); // {"schema": "aten::std(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor std(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor std(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> std_mean(const Tensor & self, bool unbiased); // {"schema": "aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> std_mean(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> std_mean(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> std_mean(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> std_mean(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor & std_out(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor & out); // {"schema": "aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & std_out(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out); // {"schema": "aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor std(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & std_out(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim, Tensor & out); // {"schema": "aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor std(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & std_out(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out); // {"schema": "aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor prod(const Tensor & self, c10::optional<ScalarType> dtype); // {"schema": "aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor prod(const Tensor & self, int64_t dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & prod_out(const Tensor & self, int64_t dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor prod(const Tensor & self, Dimname dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & prod_out(const Tensor & self, Dimname dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor t(const Tensor & self); // {"schema": "aten::t(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor & t_(Tensor & self); // {"schema": "aten::t_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor tan(const Tensor & self); // {"schema": "aten::tan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & tan_(Tensor & self); // {"schema": "aten::tan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & tan_out(const Tensor & self, Tensor & out); // {"schema": "aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor tanh(const Tensor & self); // {"schema": "aten::tanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & tanh_(Tensor & self); // {"schema": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & tanh_out(const Tensor & self, Tensor & out); // {"schema": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor tensordot(const Tensor & self, const Tensor & other, IntArrayRef dims_self, IntArrayRef dims_other); // {"schema": "aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & tensordot_out(const Tensor & self, const Tensor & other, IntArrayRef dims_self, IntArrayRef dims_other, Tensor & out); // {"schema": "aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor threshold(const Tensor & self, const Scalar & threshold, const Scalar & value); // {"schema": "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & threshold_(Tensor & self, const Scalar & threshold, const Scalar & value); // {"schema": "aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & threshold_out(const Tensor & self, const Scalar & threshold, const Scalar & value, Tensor & out); // {"schema": "aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & threshold_backward_out(const Tensor & grad_output, const Tensor & self, const Scalar & threshold, Tensor & grad_input); // {"schema": "aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor threshold_backward(const Tensor & grad_output, const Tensor & self, const Scalar & threshold); // {"schema": "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor tile(const Tensor & self, c10::SymIntArrayRef dims); // {"schema": "aten::tile(Tensor self, SymInt[] dims) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor transpose(const Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor transpose(const Tensor & self, Dimname dim0, Dimname dim1); // {"schema": "aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & transpose_(Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _mkldnn_transpose_(Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor one_hot(const Tensor & self, int64_t num_classes); // {"schema": "aten::one_hot(Tensor self, int num_classes=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor flip(const Tensor & self, IntArrayRef dims); // {"schema": "aten::flip(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor fliplr(const Tensor & self); // {"schema": "aten::fliplr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor flipud(const Tensor & self); // {"schema": "aten::flipud(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor roll(const Tensor & self, c10::SymIntArrayRef shifts, IntArrayRef dims); // {"schema": "aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor rot90(const Tensor & self, int64_t k, IntArrayRef dims); // {"schema": "aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor trapezoid(const Tensor & y, const Tensor & x, int64_t dim); // {"schema": "aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor trapezoid(const Tensor & y, const Scalar & dx, int64_t dim); // {"schema": "aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor trapz(const Tensor & y, const Tensor & x, int64_t dim); // {"schema": "aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor trapz(const Tensor & y, double dx, int64_t dim); // {"schema": "aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> _transform_bias_rescale_qkv(const Tensor & qkv, const Tensor & qkv_bias, int64_t num_heads); // {"schema": "aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _nested_tensor_from_mask(const Tensor & t, const Tensor & mask, bool mask_check); // {"schema": "aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor", "dispatch": "True", "default": "False"}
+bool _nested_tensor_from_mask_left_aligned(const Tensor & t, const Tensor & mask); // {"schema": "aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool", "dispatch": "True", "default": "False"}
+Tensor _nested_from_padded(const Tensor & padded, const Tensor & cpu_nested_shape_example, bool fuse_transform_0213); // {"schema": "aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_tensor_size(const Tensor & self); // {"schema": "aten::_nested_tensor_size(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_tensor_strides(const Tensor & self); // {"schema": "aten::_nested_tensor_strides(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_tensor_storage_offsets(const Tensor & self); // {"schema": "aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_from_padded_and_nested_example(const Tensor & padded, const Tensor & nt_example); // {"schema": "aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_view_from_buffer(const Tensor & self, const Tensor & nested_size, const Tensor & nested_strides, const Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor _nested_view_from_buffer_copy(const Tensor & self, const Tensor & nested_size, const Tensor & nested_strides, const Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _nested_view_from_jagged(const Tensor & self, const Tensor & offsets, const Tensor & dummy, const c10::optional<Tensor> & lengths, int64_t ragged_idx); // {"schema": "aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor _nested_view_from_jagged_copy(const Tensor & self, const Tensor & offsets, const Tensor & dummy, const c10::optional<Tensor> & lengths, int64_t ragged_idx); // {"schema": "aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _nested_get_values(const Tensor & self); // {"schema": "aten::_nested_get_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor _nested_get_values_copy(const Tensor & self); // {"schema": "aten::_nested_get_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _nested_get_offsets(const Tensor & self); // {"schema": "aten::_nested_get_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_get_lengths(const Tensor & self); // {"schema": "aten::_nested_get_lengths(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+int64_t _nested_get_ragged_idx(const Tensor & self); // {"schema": "aten::_nested_get_ragged_idx(Tensor self) -> int", "dispatch": "True", "default": "False"}
+Tensor _nested_get_jagged_dummy(const Tensor & any); // {"schema": "aten::_nested_get_jagged_dummy(Tensor any) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _trilinear(const Tensor & i1, const Tensor & i2, const Tensor & i3, IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim); // {"schema": "aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor triplet_margin_loss(const Tensor & anchor, const Tensor & positive, const Tensor & negative, double margin, double p, double eps, bool swap, int64_t reduction); // {"schema": "aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor trunc(const Tensor & self); // {"schema": "aten::trunc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & trunc_(Tensor & self); // {"schema": "aten::trunc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & trunc_out(const Tensor & self, Tensor & out); // {"schema": "aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor fix(const Tensor & self); // {"schema": "aten::fix(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fix_(Tensor & self); // {"schema": "aten::fix_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & fix_out(const Tensor & self, Tensor & out); // {"schema": "aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor type_as(const Tensor & self, const Tensor & other); // {"schema": "aten::type_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+bool _has_compatible_shallow_copy_type(const Tensor & self, const Tensor & from); // {"schema": "aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _unique(const Tensor & self, bool sorted, bool return_inverse); // {"schema": "aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> unique_dim(const Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> unique_consecutive(const Tensor & self, bool return_inverse, bool return_counts, c10::optional<int64_t> dim); // {"schema": "aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> unique_dim_consecutive(const Tensor & self, int64_t dim, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _unique2(const Tensor & self, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _unsafe_view(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor unsqueeze(const Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor & unsqueeze_(Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor vander(const Tensor & x, c10::optional<int64_t> N, bool increasing); // {"schema": "aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor var(const Tensor & self, bool unbiased); // {"schema": "aten::var(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor var(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor var(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & var_out(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor & out); // {"schema": "aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & var_out(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out); // {"schema": "aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor var(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & var_out(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim, Tensor & out); // {"schema": "aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor var(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & var_out(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out); // {"schema": "aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> var_mean(const Tensor & self, bool unbiased); // {"schema": "aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> var_mean(const Tensor & self, OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> var_mean(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> var_mean(const Tensor & self, DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> var_mean(const Tensor & self, DimnameList dim, const c10::optional<Scalar> & correction, bool keepdim); // {"schema": "aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor view_as(const Tensor & self, const Tensor & other); // {"schema": "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other); // {"schema": "aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & where_out(const Tensor & condition, const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor where(const Tensor & condition, const Scalar & self, const Tensor & other); // {"schema": "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor where(const Tensor & condition, const Tensor & self, const Scalar & other); // {"schema": "aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor where(const Tensor & condition, const Scalar & self, const Scalar & other); // {"schema": "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> where(const Tensor & condition); // {"schema": "aten::where(Tensor condition) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim); // {"schema": "aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _weight_norm(const Tensor & v, const Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _weight_norm_interface(const Tensor & v, const Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _weight_norm_interface_backward(const Tensor & grad_w, const Tensor & saved_v, const Tensor & saved_g, const Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _weight_norm_differentiable_backward(const Tensor & grad_w, const Tensor & saved_v, const Tensor & saved_g, const Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor zeros(IntArrayRef size, c10::optional<DimnameList> names, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _efficientzerotensor(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor zeros(c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & zeros_out(c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor zeros_like(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _standard_gamma_grad(const Tensor & self, const Tensor & output); // {"schema": "aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _standard_gamma(const Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _dirichlet_grad(const Tensor & x, const Tensor & alpha, const Tensor & total); // {"schema": "aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sample_dirichlet(const Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor poisson(const Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::poisson(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor binomial(const Tensor & count, const Tensor & prob, c10::optional<Generator> generator); // {"schema": "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor native_norm(const Tensor & self, const Scalar & p); // {"schema": "aten::native_norm(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor native_norm(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_sum(const Tensor & self); // {"schema": "aten::_sparse_sum(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_sum(const Tensor & self, ScalarType dtype); // {"schema": "aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_sum(const Tensor & self, IntArrayRef dim); // {"schema": "aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _sparse_sum(const Tensor & self, IntArrayRef dim, ScalarType dtype); // {"schema": "aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_sum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim); // {"schema": "aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_csr_sum(const Tensor & self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_csr_prod(const Tensor & self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_softmax(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_softmax(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_softmax(const Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_softmax_backward_data(const Tensor & grad_output, const Tensor & output, int64_t dim, const Tensor & self); // {"schema": "aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_log_softmax(const Tensor & self, int64_t dim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_log_softmax(const Tensor & self, Dimname dim, c10::optional<ScalarType> dtype); // {"schema": "aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_log_softmax(const Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_log_softmax_backward_data(const Tensor & grad_output, const Tensor & output, int64_t dim, const Tensor & self); // {"schema": "aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _spdiags(const Tensor & diagonals, const Tensor & offsets, IntArrayRef shape, c10::optional<Layout> layout); // {"schema": "aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor norm(const Tensor & self, const c10::optional<Scalar> & p, ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor norm(const Tensor & self, const Scalar & p); // {"schema": "aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor norm(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim, ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor norm(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim); // {"schema": "aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & norm_out(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim, ScalarType dtype, Tensor & out); // {"schema": "aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & norm_out(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor norm(const Tensor & self, const c10::optional<Scalar> & p, DimnameList dim, bool keepdim, ScalarType dtype); // {"schema": "aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor norm(const Tensor & self, const c10::optional<Scalar> & p, DimnameList dim, bool keepdim); // {"schema": "aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & norm_out(const Tensor & self, const c10::optional<Scalar> & p, DimnameList dim, bool keepdim, ScalarType dtype, Tensor & out); // {"schema": "aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & norm_out(const Tensor & self, const c10::optional<Scalar> & p, DimnameList dim, bool keepdim, Tensor & out); // {"schema": "aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> frexp(const Tensor & self); // {"schema": "aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> frexp_out(const Tensor & self, Tensor & mantissa, Tensor & exponent); // {"schema": "aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)", "dispatch": "True", "default": "False"}
+Tensor frobenius_norm(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & frobenius_norm_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nuclear_norm(const Tensor & self, bool keepdim); // {"schema": "aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nuclear_norm_out(const Tensor & self, bool keepdim, Tensor & out); // {"schema": "aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nuclear_norm(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nuclear_norm_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor clone(const Tensor & self, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor positive(const Tensor & self); // {"schema": "aten::positive(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+const Tensor & resize_as_(const Tensor & self, const Tensor & the_template, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+const Tensor & resize_as_sparse_(const Tensor & self, const Tensor & the_template); // {"schema": "aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & zero_(Tensor & self); // {"schema": "aten::zero_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & sub_out(const Tensor & self, const Tensor & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sub(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sub_(Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor sub(const Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sub_(Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & subtract_out(const Tensor & self, const Tensor & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor subtract(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & subtract_(Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor subtract(const Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & subtract_(Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor rsub(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & heaviside_out(const Tensor & self, const Tensor & values, Tensor & out); // {"schema": "aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor heaviside(const Tensor & self, const Tensor & values); // {"schema": "aten::heaviside(Tensor self, Tensor values) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & heaviside_(Tensor & self, const Tensor & values); // {"schema": "aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor rsub(const Tensor & self, const Scalar & other, const Scalar & alpha); // {"schema": "aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _sparse_addmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sparse_sampled_addmm_out(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sparse_sampled_addmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _sparse_mm_reduce_impl(const Tensor & self, const Tensor & other, c10::string_view reduce); // {"schema": "aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _sparse_mm_reduce_impl_backward(const Tensor & self, const Tensor & grad_out, const Tensor & weight, c10::string_view reduce, const Tensor & arg_out, ::std::array<bool,2> output_mask); // {"schema": "aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor & addmm_out(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor addmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & addmm_(Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _addmm_activation_out(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, bool use_gelu, Tensor & out); // {"schema": "aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _addmm_activation(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, bool use_gelu); // {"schema": "aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> _scaled_mm(const Tensor & self, const Tensor & mat2, const c10::optional<Tensor> & bias, c10::optional<ScalarType> out_dtype, const c10::optional<Tensor> & scale_a, const c10::optional<Tensor> & scale_b, const c10::optional<Tensor> & scale_result, bool use_fast_accum); // {"schema": "aten::_scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> _scaled_mm_out(const Tensor & self, const Tensor & mat2, const c10::optional<Tensor> & bias, c10::optional<ScalarType> out_dtype, const c10::optional<Tensor> & scale_a, const c10::optional<Tensor> & scale_b, const c10::optional<Tensor> & scale_result, bool use_fast_accum, Tensor & out, Tensor & out_amax); // {"schema": "aten::_scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+Tensor sparse_compressed_tensor(const Tensor & compressed_indices, const Tensor & plain_indices, const Tensor & values, c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sparse_csr_tensor(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_csc_tensor(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_bsr_tensor(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_bsc_tensor(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_compressed_tensor(const Tensor & compressed_indices, const Tensor & plain_indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sparse_csr_tensor(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_csc_tensor(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_bsr_tensor(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_bsc_tensor(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_compressed_tensor_unsafe(const Tensor & compressed_indices, const Tensor & plain_indices, const Tensor & values, c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_csr_tensor_unsafe(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_csc_tensor_unsafe(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_bsr_tensor_unsafe(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_bsc_tensor_unsafe(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_coo_tensor(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor sparse_coo_tensor(const Tensor & indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor sparse_coo_tensor(const Tensor & indices, const Tensor & values, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _sparse_coo_tensor_unsafe(const Tensor & indices, const Tensor & values, c10::SymIntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced); // {"schema": "aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"}
+void _validate_sparse_coo_tensor_args(const Tensor & indices, const Tensor & values, IntArrayRef size, c10::optional<bool> is_coalesced); // {"schema": "aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> ()", "dispatch": "False", "default": "True"}
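+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): building a COO
+// tensor from an indices/values pair as declared above, then coalescing duplicate coordinates.
+//   at::Tensor indices = at::randint(0, 2, {2, 3});                    // 2 x nnz coordinate matrix (int64)
+//   at::Tensor values  = at::ones({3});
+//   at::_validate_sparse_coo_tensor_args(indices, values, {2, 2});     // throws on malformed inputs
+//   at::Tensor sp = at::sparse_coo_tensor(indices, values, {2, 2}).coalesce();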
+void _validate_sparse_compressed_tensor_args(const Tensor & compressed_indices, const Tensor & plain_indices, const Tensor & values, IntArrayRef size, Layout layout); // {"schema": "aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> ()", "dispatch": "False", "default": "True"}
+void _validate_sparse_csr_tensor_args(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size); // {"schema": "aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()", "dispatch": "False", "default": "True"}
+void _validate_sparse_csc_tensor_args(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size); // {"schema": "aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()", "dispatch": "False", "default": "True"}
+void _validate_sparse_bsr_tensor_args(const Tensor & crow_indices, const Tensor & col_indices, const Tensor & values, IntArrayRef size); // {"schema": "aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()", "dispatch": "False", "default": "True"}
+void _validate_sparse_bsc_tensor_args(const Tensor & ccol_indices, const Tensor & row_indices, const Tensor & values, IntArrayRef size); // {"schema": "aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()", "dispatch": "False", "default": "True"}
+Tensor _sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const Tensor & indices, const Tensor & values, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<bool> is_coalesced); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor", "dispatch": "True", "default": "False"}
+const Tensor & sparse_resize_(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+const Tensor & sparse_resize_and_clear_(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sparse_mask(const Tensor & self, const Tensor & mask); // {"schema": "aten::sparse_mask(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _sparse_mask_projection(const Tensor & self, const Tensor & mask, bool accumulate_matches); // {"schema": "aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _to_cpu(TensorList tensors); // {"schema": "aten::_to_cpu(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor to_dense(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad); // {"schema": "aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_dense(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad); // {"schema": "aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_dense_backward(const Tensor & grad, const Tensor & input, c10::optional<bool> masked_grad); // {"schema": "aten::to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"}
+int64_t sparse_dim(const Tensor & self); // {"schema": "aten::sparse_dim(Tensor self) -> int", "dispatch": "True", "default": "False"}
+int64_t _dimI(const Tensor & self); // {"schema": "aten::_dimI(Tensor self) -> int", "dispatch": "True", "default": "False"}
+int64_t dense_dim(const Tensor & self); // {"schema": "aten::dense_dim(Tensor self) -> int", "dispatch": "True", "default": "False"}
+int64_t _dimV(const Tensor & self); // {"schema": "aten::_dimV(Tensor self) -> int", "dispatch": "True", "default": "False"}
+int64_t _nnz(const Tensor & self); // {"schema": "aten::_nnz(Tensor self) -> int", "dispatch": "True", "default": "False"}
+Tensor coalesce(const Tensor & self); // {"schema": "aten::coalesce(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _coalesce(const Tensor & self); // {"schema": "aten::_coalesce(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+bool is_coalesced(const Tensor & self); // {"schema": "aten::is_coalesced(Tensor self) -> bool", "dispatch": "True", "default": "True"}
+Tensor _indices(const Tensor & self); // {"schema": "aten::_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor _values(const Tensor & self); // {"schema": "aten::_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor & _coalesced_(Tensor & self, bool coalesced); // {"schema": "aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor indices(const Tensor & self); // {"schema": "aten::indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor values(const Tensor & self); // {"schema": "aten::values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor crow_indices(const Tensor & self); // {"schema": "aten::crow_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor col_indices(const Tensor & self); // {"schema": "aten::col_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor ccol_indices(const Tensor & self); // {"schema": "aten::ccol_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor row_indices(const Tensor & self); // {"schema": "aten::row_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
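+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): the accessor ops
+// above expose sparse metadata; indices()/values() require a coalesced COO tensor.
+//   at::Tensor sp = at::rand({4, 4}).to_sparse().coalesce();
+//   int64_t nnz = sp._nnz();
+//   at::Tensor idx = sp.indices();   // sparse_dim() x nnz
+//   at::Tensor val = sp.values();    // nnz (x dense dims, if any)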
+Tensor & hspmm_out(const Tensor & mat1, const Tensor & mat2, Tensor & out); // {"schema": "aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hspmm(const Tensor & mat1, const Tensor & mat2); // {"schema": "aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & copy_sparse_to_sparse_(Tensor & self, const Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> unbind(const Tensor & self, int64_t dim); // {"schema": "aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> unbind(const Tensor & self, Dimname dim); // {"schema": "aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]", "dispatch": "False", "default": "True"}
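+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): aten::unbind.int
+// returns one view per slice along the chosen dimension.
+//   at::Tensor t = at::arange(6).reshape({2, 3});
+//   std::vector<at::Tensor> rows = at::unbind(t, /*dim=*/0);   // two tensors of shape {3}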
+Tensor to_sparse(const Tensor & self, int64_t sparse_dim); // {"schema": "aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse(const Tensor & self, int64_t sparse_dim); // {"schema": "aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_sparse(const Tensor & self, c10::optional<Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse(const Tensor & self, c10::optional<Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_sparse_csr(const Tensor & self, c10::optional<int64_t> dense_dim); // {"schema": "aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse_csr(const Tensor & self, c10::optional<int64_t> dense_dim); // {"schema": "aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_sparse_csc(const Tensor & self, c10::optional<int64_t> dense_dim); // {"schema": "aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse_csc(const Tensor & self, c10::optional<int64_t> dense_dim); // {"schema": "aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_sparse_bsr(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse_bsr(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_sparse_bsc(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _to_sparse_bsc(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim); // {"schema": "aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _to_sparse_semi_structured(const Tensor & dense); // {"schema": "aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor to_mkldnn(const Tensor & self, c10::optional<ScalarType> dtype); // {"schema": "aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"}
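+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): round-tripping a
+// dense tensor through the layout-conversion ops declared above.
+//   at::Tensor dense = at::rand({4, 6});
+//   at::Tensor csr   = dense.to_sparse_csr();                          // aten::to_sparse_csr
+//   at::Tensor crow  = csr.crow_indices(), col = csr.col_indices();
+//   at::Tensor back  = csr.to_dense();                                 // aten::to_dense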
+Tensor mkldnn_reorder_conv2d_weight(const Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, OptionalSymIntArrayRef input_size); // {"schema": "aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor mkldnn_reorder_conv3d_weight(const Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor to_mkldnn_backward(const Tensor & grad, const Tensor & input); // {"schema": "aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor quantize_per_tensor_dynamic(const Tensor & self, ScalarType dtype, bool reduce_range); // {"schema": "aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor quantize_per_tensor(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype); // {"schema": "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor quantize_per_tensor(const Tensor & self, const Tensor & scale, const Tensor & zero_point, ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> quantize_per_tensor(TensorList tensors, const Tensor & scales, const Tensor & zero_points, ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]", "dispatch": "True", "default": "False"}
+Tensor quantize_per_channel(const Tensor & self, const Tensor & scales, const Tensor & zero_points, int64_t axis, ScalarType dtype); // {"schema": "aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor dequantize(const Tensor & self); // {"schema": "aten::dequantize.self(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> dequantize(TensorList tensors); // {"schema": "aten::dequantize.tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "True", "default": "False"}
+double q_scale(const Tensor & self); // {"schema": "aten::q_scale(Tensor self) -> float", "dispatch": "True", "default": "False"}
+int64_t q_zero_point(const Tensor & self); // {"schema": "aten::q_zero_point(Tensor self) -> int", "dispatch": "True", "default": "False"}
+Tensor q_per_channel_scales(const Tensor & self); // {"schema": "aten::q_per_channel_scales(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor q_per_channel_zero_points(const Tensor & self); // {"schema": "aten::q_per_channel_zero_points(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+int64_t q_per_channel_axis(const Tensor & self); // {"schema": "aten::q_per_channel_axis(Tensor self) -> int", "dispatch": "True", "default": "False"}
+Tensor int_repr(const Tensor & self); // {"schema": "aten::int_repr(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _make_per_tensor_quantized_tensor(const Tensor & self, double scale, int64_t zero_point); // {"schema": "aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _make_per_channel_quantized_tensor(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis); // {"schema": "aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor", "dispatch": "True", "default": "False"}
+QScheme qscheme(const Tensor & self); // {"schema": "aten::qscheme(Tensor self) -> QScheme", "dispatch": "True", "default": "False"}
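+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): quantizing a float
+// tensor per-tensor and reading back the qparams exposed by the ops above.
+//   at::Tensor x   = at::rand({2, 3});
+//   at::Tensor q   = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0, at::kQInt8);
+//   double s       = q.q_scale();        // 0.1
+//   int64_t zp     = q.q_zero_point();   // 0
+//   at::Tensor raw = q.int_repr();       // underlying int8 representation
+//   at::Tensor y   = at::dequantize(q);  // back to float, with quantization error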
+Tensor fake_quantize_per_tensor_affine(const Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fake_quantize_per_tensor_affine(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> fake_quantize_per_tensor_affine_cachemask(const Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(const Tensor & self, const Tensor & scale, const Tensor & zero_point, const Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"}
+Tensor fake_quantize_per_tensor_affine_cachemask_backward(const Tensor & grad, const Tensor & mask); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _fake_quantize_learnable_per_tensor_affine(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _fake_quantize_learnable_per_tensor_affine_backward(const Tensor & grad, const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor fake_quantize_per_channel_affine(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> fake_quantize_per_channel_affine_cachemask(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"}
+Tensor fake_quantize_per_channel_affine_cachemask_backward(const Tensor & grad, const Tensor & mask); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _fake_quantize_learnable_per_channel_affine(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _fake_quantize_learnable_per_channel_affine_backward(const Tensor & grad, const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
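+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): fake-quantize keeps
+// the float dtype but snaps values to the grid defined by (scale, zero_point, quant_min, quant_max).
+//   at::Tensor w  = at::randn({8});
+//   at::Tensor fq = at::fake_quantize_per_tensor_affine(w, /*scale=*/0.05, /*zero_point=*/0,
+//                                                       /*quant_min=*/-128, /*quant_max=*/127);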
+Tensor fused_moving_avg_obs_fake_quant(const Tensor & self, const Tensor & observer_on, const Tensor & fake_quant_on, Tensor & running_min, Tensor & running_max, Tensor & scale, Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _fused_moving_avg_obs_fq_helper(const Tensor & self, const Tensor & observer_on, const Tensor & fake_quant_on, Tensor & running_min, Tensor & running_max, Tensor & scale, Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"}
+::std::tuple<double,int64_t> _choose_qparams_per_tensor(const Tensor & self, bool reduce_range); // {"schema": "aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)", "dispatch": "False", "default": "True"}
+Tensor _saturate_weight_to_fp16(const Tensor & weight); // {"schema": "aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> choose_qparams_optimized(const Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width); // {"schema": "aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor _autocast_to_reduced_precision(const Tensor & self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype); // {"schema": "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _autocast_to_full_precision(const Tensor & self, bool cuda_enabled, bool cpu_enabled); // {"schema": "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor _to_copy(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, bool non_blocking, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor to(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, bool non_blocking, bool copy, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor to(const Tensor & self, const Tensor & other, bool non_blocking, bool copy, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> meshgrid(TensorList tensors); // {"schema": "aten::meshgrid(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+::std::vector<Tensor> meshgrid(TensorList tensors, c10::string_view indexing); // {"schema": "aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]", "dispatch": "False", "default": "True"}
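+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): the aten::to
+// overloads above cover dtype/device/layout conversion; meshgrid expands 1-D tensors into coordinate grids.
+//   at::Tensor x = at::rand({2, 2});
+//   at::Tensor h = x.to(at::kHalf);                                                       // aten::to.dtype
+//   at::Tensor d = x.to(at::TensorOptions().dtype(at::kDouble), /*non_blocking=*/false, /*copy=*/true);
+//   auto grids   = at::meshgrid({at::arange(3), at::arange(2)}, "ij");                    // aten::meshgrid.indexing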
+Tensor cartesian_prod(TensorList tensors); // {"schema": "aten::cartesian_prod(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor combinations(const Tensor & self, int64_t r, bool with_replacement); // {"schema": "aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor", "dispatch": "False", "default": "True"}
+Scalar item(const Tensor & self); // {"schema": "aten::item(Tensor self) -> Scalar", "dispatch": "False", "default": "True"}
+ScalarType result_type(const Tensor & tensor, const Tensor & other); // {"schema": "aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType", "dispatch": "False", "default": "True"}
+ScalarType result_type(const Tensor & tensor, const Scalar & other); // {"schema": "aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType", "dispatch": "False", "default": "True"}
+ScalarType result_type(const Scalar & scalar, const Tensor & tensor); // {"schema": "aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType", "dispatch": "False", "default": "True"}
+ScalarType result_type(const Scalar & scalar1, const Scalar & scalar2); // {"schema": "aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType", "dispatch": "False", "default": "True"}
+bool can_cast(ScalarType from, ScalarType to); // {"schema": "aten::can_cast(ScalarType from, ScalarType to) -> bool", "dispatch": "False", "default": "True"}
+ScalarType promote_types(ScalarType type1, ScalarType type2); // {"schema": "aten::promote_types(ScalarType type1, ScalarType type2) -> ScalarType", "dispatch": "False", "default": "True"}
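+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): the type-promotion
+// helpers above mirror Python's torch.result_type / torch.promote_types / torch.can_cast.
+//   at::ScalarType p = at::promote_types(at::kInt, at::kFloat);            // kFloat
+//   bool ok          = at::can_cast(/*from=*/at::kFloat, /*to=*/at::kInt); // false
+//   at::ScalarType r = at::result_type(at::rand({2}), at::Scalar(2.5));    // kFloat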
+Scalar _local_scalar_dense(const Tensor & self); // {"schema": "aten::_local_scalar_dense(Tensor self) -> Scalar", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor,Tensor> _lstm_mps(const Tensor & input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,::std::vector<Tensor>,::std::vector<Tensor>> lstm_mps_backward(const c10::optional<Tensor> & grad_y, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & z_state, const Tensor & cell_state_fwd, const Tensor & input, const Tensor & layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _thnn_fused_lstm_cell(const Tensor & input_gates, const Tensor & hidden_gates, const Tensor & cx, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias); // {"schema": "aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _thnn_fused_lstm_cell_backward_impl(const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & cx, const Tensor & cy, const Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _thnn_fused_lstm_cell_backward(const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & cx, const Tensor & cy, const Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _thnn_differentiable_lstm_cell_backward(const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & input_gates, const Tensor & hidden_gates, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias, const Tensor & cx, const Tensor & cy); // {"schema": "aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _thnn_fused_gru_cell(const Tensor & input_gates, const Tensor & hidden_gates, const Tensor & hx, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias); // {"schema": "aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _thnn_fused_gru_cell_backward(const Tensor & grad_hy, const Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _thnn_differentiable_gru_cell_backward(const Tensor & grad_hy, const Tensor & input_gates, const Tensor & hidden_gates, const Tensor & hx, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias); // {"schema": "aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> lstm(const Tensor & input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> lstm(const Tensor & data, const Tensor & batch_sizes, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> gru(const Tensor & input, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> gru(const Tensor & data, const Tensor & batch_sizes, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> rnn_tanh(const Tensor & input, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> rnn_tanh(const Tensor & data, const Tensor & batch_sizes, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> rnn_relu(const Tensor & input, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> rnn_relu(const Tensor & data, const Tensor & batch_sizes, const Tensor & hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> lstm_cell(const Tensor & input, TensorList hx, const Tensor & w_ih, const Tensor & w_hh, const c10::optional<Tensor> & b_ih, const c10::optional<Tensor> & b_hh); // {"schema": "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor gru_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const c10::optional<Tensor> & b_ih, const c10::optional<Tensor> & b_hh); // {"schema": "aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor rnn_tanh_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const c10::optional<Tensor> & b_ih, const c10::optional<Tensor> & b_hh); // {"schema": "aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor rnn_relu_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const c10::optional<Tensor> & b_ih, const c10::optional<Tensor> & b_hh); // {"schema": "aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"}
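+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): one LSTM cell step
+// with the cell ops above; weights use the {4*hidden, input} / {4*hidden, hidden} layout, biases are optional.
+//   int64_t batch = 2, in_sz = 4, hidden = 3;
+//   at::Tensor x    = at::randn({batch, in_sz});
+//   at::Tensor h0   = at::zeros({batch, hidden}), c0 = at::zeros({batch, hidden});
+//   at::Tensor w_ih = at::randn({4 * hidden, in_sz}), w_hh = at::randn({4 * hidden, hidden});
+//   auto hc = at::lstm_cell(x, {h0, c0}, w_ih, w_hh);            // returns (h1, c1)
+//   at::Tensor h1 = std::get<0>(hc), c1 = std::get<1>(hc);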
+::std::tuple<Tensor,Tensor> quantized_lstm_cell(const Tensor & input, TensorList hx, const Tensor & w_ih, const Tensor & w_hh, const Tensor & b_ih, const Tensor & b_hh, const Tensor & packed_ih, const Tensor & packed_hh, const Tensor & col_offsets_ih, const Tensor & col_offsets_hh, const Scalar & scale_ih, const Scalar & scale_hh, const Scalar & zero_point_ih, const Scalar & zero_point_hh); // {"schema": "aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor quantized_gru_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const Tensor & b_ih, const Tensor & b_hh, const Tensor & packed_ih, const Tensor & packed_hh, const Tensor & col_offsets_ih, const Tensor & col_offsets_hh, const Scalar & scale_ih, const Scalar & scale_hh, const Scalar & zero_point_ih, const Scalar & zero_point_hh); // {"schema": "aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor quantized_rnn_relu_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const Tensor & b_ih, const Tensor & b_hh, const Tensor & packed_ih, const Tensor & packed_hh, const Tensor & col_offsets_ih, const Tensor & col_offsets_hh, const Scalar & scale_ih, const Scalar & scale_hh, const Scalar & zero_point_ih, const Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor quantized_rnn_tanh_cell(const Tensor & input, const Tensor & hx, const Tensor & w_ih, const Tensor & w_hh, const Tensor & b_ih, const Tensor & b_hh, const Tensor & packed_ih, const Tensor & packed_hh, const Tensor & col_offsets_ih, const Tensor & col_offsets_hh, const Scalar & scale_ih, const Scalar & scale_hh, const Scalar & zero_point_ih, const Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _pack_padded_sequence(const Tensor & input, const Tensor & lengths, bool batch_first); // {"schema": "aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor _pack_padded_sequence_backward(const Tensor & grad, c10::SymIntArrayRef input_size, const Tensor & batch_sizes, bool batch_first); // {"schema": "aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> _pad_packed_sequence(const Tensor & data, const Tensor & batch_sizes, bool batch_first, const Scalar & padding_value, int64_t total_length); // {"schema": "aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+Tensor & set_(Tensor & self, Storage source); // {"schema": "aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & set_(Tensor & self, Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & set_(Tensor & self, const Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & set_(Tensor & self, const Tensor & source); // {"schema": "aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & set_(Tensor & self); // {"schema": "aten::set_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor lift(const Tensor & self); // {"schema": "aten::lift(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor lift_fresh(const Tensor & self); // {"schema": "aten::lift_fresh(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor lift_fresh_copy(const Tensor & self); // {"schema": "aten::lift_fresh_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+bool is_set_to(const Tensor & self, const Tensor & tensor); // {"schema": "aten::is_set_to(Tensor self, Tensor tensor) -> bool", "dispatch": "True", "default": "False"}
+Tensor & masked_fill_(Tensor & self, const Tensor & mask, const Scalar & value); // {"schema": "aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor masked_fill(const Tensor & self, const Tensor & mask, const Scalar & value); // {"schema": "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & masked_fill_(Tensor & self, const Tensor & mask, const Tensor & value); // {"schema": "aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & value); // {"schema": "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source); // {"schema": "aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source); // {"schema": "aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor masked_scatter_backward(const Tensor & grad_output, const Tensor & mask, c10::SymIntArrayRef sizes); // {"schema": "aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor", "dispatch": "True", "default": "True"}
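+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): masked_fill writes a
+// scalar wherever the boolean mask is true; masked_scatter_ copies elements from a source tensor instead.
+//   at::Tensor x      = at::zeros({2, 3});
+//   at::Tensor mask   = x.eq(0.0);                                   // all true here
+//   at::Tensor filled = x.masked_fill(mask, -1.0);                   // aten::masked_fill.Scalar
+//   x.masked_scatter_(mask, at::arange(6, at::kFloat).reshape({2, 3}));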
+Tensor _masked_softmax(const Tensor & self, const Tensor & mask, c10::optional<int64_t> dim, c10::optional<int64_t> mask_type); // {"schema": "aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _masked_softmax_backward(const Tensor & grad_output, const Tensor & output, const Tensor & mask, c10::optional<int64_t> dim); // {"schema": "aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor view(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor view(const Tensor & self, ScalarType dtype); // {"schema": "aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor & put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate); // {"schema": "aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor put(const Tensor & self, const Tensor & index, const Tensor & source, bool accumulate); // {"schema": "aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_add_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, const Scalar & alpha, Tensor & out); // {"schema": "aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & index_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, const Scalar & alpha); // {"schema": "aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, const Scalar & alpha); // {"schema": "aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor index_add(const Tensor & self, Dimname dim, const Tensor & index, const Tensor & source, const Scalar & alpha); // {"schema": "aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & index_reduce_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, c10::string_view reduce, bool include_self, Tensor & out); // {"schema": "aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & index_reduce_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor index_reduce(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"}
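+// Illustrative sketch (not part of the generated declarations; assumes ATen is available): index_add
+// accumulates source rows into the rows of self selected by index; index_reduce applies a named reduction.
+//   at::Tensor self  = at::zeros({5, 3});
+//   at::Tensor index = at::arange(0, 5, 2);                           // {0, 2, 4}, int64
+//   at::Tensor src   = at::ones({3, 3});
+//   at::Tensor added = self.index_add(0, index, src, /*alpha=*/2);    // rows 0,2,4 += 2 * src rows
+//   at::Tensor red   = self.index_reduce(0, index, src, "amax", /*include_self=*/true);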
+Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar & value); // {"schema": "aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value); // {"schema": "aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & value); // {"schema": "aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & value); // {"schema": "aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & index_fill_(Tensor & self, Dimname dim, const Tensor & index, const Scalar & value); // {"schema": "aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & index_fill_(Tensor & self, Dimname dim, const Tensor & index, const Tensor & value); // {"schema": "aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor index_fill(const Tensor & self, Dimname dim, const Tensor & index, const Scalar & value); // {"schema": "aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor index_fill(const Tensor & self, Dimname dim, const Tensor & index, const Tensor & value); // {"schema": "aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, Tensor & out); // {"schema": "aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value); // {"schema": "aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Scalar & value); // {"schema": "aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_out(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value, Tensor & out); // {"schema": "aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce, Tensor & out); // {"schema": "aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_out(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value, c10::string_view reduce, Tensor & out); // {"schema": "aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor scatter(const Tensor & self, Dimname dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor scatter(const Tensor & self, Dimname dim, const Tensor & index, const Scalar & value); // {"schema": "aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_add_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, Tensor & out); // {"schema": "aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor scatter_add(const Tensor & self, Dimname dim, const Tensor & index, const Tensor & src); // {"schema": "aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor scatter_reduce(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & scatter_reduce_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scatter_reduce_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src, c10::string_view reduce, bool include_self, Tensor & out); // {"schema": "aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & eq_(Tensor & self, const Scalar & other); // {"schema": "aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & eq_(Tensor & self, const Tensor & other); // {"schema": "aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_and_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & bitwise_and_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bitwise_and(const Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_and(const Scalar & self, const Tensor & other); // {"schema": "aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_and(const Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_and_(Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_and_(Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor __and__(const Tensor & self, const Scalar & other); // {"schema": "aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor __and__(const Tensor & self, const Tensor & other); // {"schema": "aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & __iand__(Tensor & self, const Scalar & other); // {"schema": "aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & __iand__(Tensor & self, const Tensor & other); // {"schema": "aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & bitwise_or_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & bitwise_or_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bitwise_or(const Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_or(const Scalar & self, const Tensor & other); // {"schema": "aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_or(const Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_or_(Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_or_(Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor __or__(const Tensor & self, const Scalar & other); // {"schema": "aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor __or__(const Tensor & self, const Tensor & other); // {"schema": "aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & __ior__(Tensor & self, const Scalar & other); // {"schema": "aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & __ior__(Tensor & self, const Tensor & other); // {"schema": "aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & bitwise_xor_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & bitwise_xor_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bitwise_xor(const Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_xor(const Scalar & self, const Tensor & other); // {"schema": "aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor bitwise_xor(const Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_xor_(Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_xor_(Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor __xor__(const Tensor & self, const Scalar & other); // {"schema": "aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor __xor__(const Tensor & self, const Tensor & other); // {"schema": "aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & __ixor__(Tensor & self, const Scalar & other); // {"schema": "aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & __ixor__(Tensor & self, const Tensor & other); // {"schema": "aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor __lshift__(const Tensor & self, const Scalar & other); // {"schema": "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor __lshift__(const Tensor & self, const Tensor & other); // {"schema": "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & __ilshift__(Tensor & self, const Scalar & other); // {"schema": "aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & __ilshift__(Tensor & self, const Tensor & other); // {"schema": "aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bitwise_left_shift(const Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_left_shift_(Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_left_shift_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bitwise_left_shift(const Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_left_shift_(Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_left_shift_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bitwise_left_shift(const Scalar & self, const Tensor & other); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor __rshift__(const Tensor & self, const Scalar & other); // {"schema": "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor __rshift__(const Tensor & self, const Tensor & other); // {"schema": "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & __irshift__(Tensor & self, const Scalar & other); // {"schema": "aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & __irshift__(Tensor & self, const Tensor & other); // {"schema": "aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bitwise_right_shift(const Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_right_shift_(Tensor & self, const Tensor & other); // {"schema": "aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_right_shift_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bitwise_right_shift(const Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bitwise_right_shift_(Tensor & self, const Scalar & other); // {"schema": "aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_right_shift_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bitwise_right_shift(const Scalar & self, const Tensor & other); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & tril_(Tensor & self, int64_t diagonal); // {"schema": "aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & triu_(Tensor & self, int64_t diagonal); // {"schema": "aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & digamma_(Tensor & self); // {"schema": "aten::digamma_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & lerp_(Tensor & self, const Tensor & end, const Scalar & weight); // {"schema": "aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & lerp_(Tensor & self, const Tensor & end, const Tensor & weight); // {"schema": "aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & addbmm_(Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & addbmm_out(const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor addbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar & beta, const Scalar & alpha); // {"schema": "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & random_(Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<Generator> generator); // {"schema": "aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & random_(Tensor & self, int64_t to, c10::optional<Generator> generator); // {"schema": "aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & random_(Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & uniform_(Tensor & self, double from, double to, c10::optional<Generator> generator); // {"schema": "aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & cauchy_(Tensor & self, double median, double sigma, c10::optional<Generator> generator); // {"schema": "aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & log_normal_(Tensor & self, double mean, double std, c10::optional<Generator> generator); // {"schema": "aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & exponential_(Tensor & self, double lambd, c10::optional<Generator> generator); // {"schema": "aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & geometric_(Tensor & self, double p, c10::optional<Generator> generator); // {"schema": "aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & diag_out(const Tensor & self, int64_t diagonal, Tensor & out); // {"schema": "aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor diag(const Tensor & self, int64_t diagonal); // {"schema": "aten::diag(Tensor self, int diagonal=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & cross_out(const Tensor & self, const Tensor & other, c10::optional<int64_t> dim, Tensor & out); // {"schema": "aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor cross(const Tensor & self, const Tensor & other, c10::optional<int64_t> dim); // {"schema": "aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & triu_out(const Tensor & self, int64_t diagonal, Tensor & out); // {"schema": "aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor triu(const Tensor & self, int64_t diagonal); // {"schema": "aten::triu(Tensor self, int diagonal=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & tril_out(const Tensor & self, int64_t diagonal, Tensor & out); // {"schema": "aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor tril(const Tensor & self, int64_t diagonal); // {"schema": "aten::tril(Tensor self, int diagonal=0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor tril_indices(int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor triu_indices(int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor trace(const Tensor & self); // {"schema": "aten::trace(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor trace_backward(const Tensor & grad, c10::SymIntArrayRef sizes); // {"schema": "aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & ne_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ne(const Tensor & self, const Scalar & other); // {"schema": "aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ne_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ne(const Tensor & self, const Tensor & other); // {"schema": "aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ne_(Tensor & self, const Scalar & other); // {"schema": "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ne_(Tensor & self, const Tensor & other); // {"schema": "aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & not_equal_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor not_equal(const Tensor & self, const Scalar & other); // {"schema": "aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & not_equal_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor not_equal(const Tensor & self, const Tensor & other); // {"schema": "aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & not_equal_(Tensor & self, const Scalar & other); // {"schema": "aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & not_equal_(Tensor & self, const Tensor & other); // {"schema": "aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & eq_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor eq(const Tensor & self, const Scalar & other); // {"schema": "aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & eq_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor eq(const Tensor & self, const Tensor & other); // {"schema": "aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ge_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ge(const Tensor & self, const Scalar & other); // {"schema": "aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ge_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ge(const Tensor & self, const Tensor & other); // {"schema": "aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & ge_(Tensor & self, const Scalar & other); // {"schema": "aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ge_(Tensor & self, const Tensor & other); // {"schema": "aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & greater_equal_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor greater_equal(const Tensor & self, const Scalar & other); // {"schema": "aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & greater_equal_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor greater_equal(const Tensor & self, const Tensor & other); // {"schema": "aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & greater_equal_(Tensor & self, const Scalar & other); // {"schema": "aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & greater_equal_(Tensor & self, const Tensor & other); // {"schema": "aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & le_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor le(const Tensor & self, const Scalar & other); // {"schema": "aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & le_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor le(const Tensor & self, const Tensor & other); // {"schema": "aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & le_(Tensor & self, const Scalar & other); // {"schema": "aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & le_(Tensor & self, const Tensor & other); // {"schema": "aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & less_equal_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor less_equal(const Tensor & self, const Scalar & other); // {"schema": "aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & less_equal_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor less_equal(const Tensor & self, const Tensor & other); // {"schema": "aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & less_equal_(Tensor & self, const Scalar & other); // {"schema": "aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & less_equal_(Tensor & self, const Tensor & other); // {"schema": "aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & gt_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor gt(const Tensor & self, const Scalar & other); // {"schema": "aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & gt_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor gt(const Tensor & self, const Tensor & other); // {"schema": "aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & gt_(Tensor & self, const Scalar & other); // {"schema": "aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & gt_(Tensor & self, const Tensor & other); // {"schema": "aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & greater_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor greater(const Tensor & self, const Scalar & other); // {"schema": "aten::greater.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & greater_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor greater(const Tensor & self, const Tensor & other); // {"schema": "aten::greater.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & greater_(Tensor & self, const Scalar & other); // {"schema": "aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & greater_(Tensor & self, const Tensor & other); // {"schema": "aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & lt_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor lt(const Tensor & self, const Scalar & other); // {"schema": "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & lt_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor lt(const Tensor & self, const Tensor & other); // {"schema": "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & lt_(Tensor & self, const Scalar & other); // {"schema": "aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & lt_(Tensor & self, const Tensor & other); // {"schema": "aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & less_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor less(const Tensor & self, const Scalar & other); // {"schema": "aten::less.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & less_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor less(const Tensor & self, const Tensor & other); // {"schema": "aten::less.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & less_(Tensor & self, const Scalar & other); // {"schema": "aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & less_(Tensor & self, const Tensor & other); // {"schema": "aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & take_out(const Tensor & self, const Tensor & index, Tensor & out); // {"schema": "aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor take(const Tensor & self, const Tensor & index); // {"schema": "aten::take(Tensor self, Tensor index) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & take_along_dim_out(const Tensor & self, const Tensor & indices, c10::optional<int64_t> dim, Tensor & out); // {"schema": "aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor take_along_dim(const Tensor & self, const Tensor & indices, c10::optional<int64_t> dim); // {"schema": "aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & index_select_out(const Tensor & self, int64_t dim, const Tensor & index, Tensor & out); // {"schema": "aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor index_select(const Tensor & self, int64_t dim, const Tensor & index); // {"schema": "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & index_select_out(const Tensor & self, Dimname dim, const Tensor & index, Tensor & out); // {"schema": "aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor index_select(const Tensor & self, Dimname dim, const Tensor & index); // {"schema": "aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor index_select_backward(const Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor & index); // {"schema": "aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & masked_select_out(const Tensor & self, const Tensor & mask, Tensor & out); // {"schema": "aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor masked_select(const Tensor & self, const Tensor & mask); // {"schema": "aten::masked_select(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor masked_select_backward(const Tensor & grad, const Tensor & input, const Tensor & mask); // {"schema": "aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nonzero_out(const Tensor & self, Tensor & out); // {"schema": "aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor nonzero(const Tensor & self); // {"schema": "aten::nonzero(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & nonzero_static_out(const Tensor & self, int64_t size, int64_t fill_value, Tensor & out); // {"schema": "aten::nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor nonzero_static(const Tensor & self, int64_t size, int64_t fill_value); // {"schema": "aten::nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> nonzero_numpy(const Tensor & self); // {"schema": "aten::nonzero_numpy(Tensor self) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor argwhere(const Tensor & self); // {"schema": "aten::argwhere(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & gather_out(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad, Tensor & out); // {"schema": "aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad); // {"schema": "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor gather_backward(const Tensor & grad, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad); // {"schema": "aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & gather_out(const Tensor & self, Dimname dim, const Tensor & index, bool sparse_grad, Tensor & out); // {"schema": "aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor gather(const Tensor & self, Dimname dim, const Tensor & index, bool sparse_grad); // {"schema": "aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _gather_sparse_backward(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & grad); // {"schema": "aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & addcmul_out(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value, Tensor & out); // {"schema": "aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value); // {"schema": "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & addcmul_(Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value); // {"schema": "aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & addcdiv_out(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value, Tensor & out); // {"schema": "aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value); // {"schema": "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & addcdiv_(Tensor & self, const Tensor & tensor1, const Tensor & tensor2, const Scalar & value); // {"schema": "aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor cross_entropy_loss(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, double label_smoothing); // {"schema": "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> triangular_solve_out(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular, Tensor & X, Tensor & M); // {"schema": "aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> triangular_solve(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular); // {"schema": "aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)", "dispatch": "True", "default": "True"}
+void _linalg_check_errors(const Tensor & info, c10::string_view api_name, bool is_matrix); // {"schema": "aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()", "dispatch": "True", "default": "True"}
+Tensor & linalg_solve_triangular_out(const Tensor & self, const Tensor & B, bool upper, bool left, bool unitriangular, Tensor & out); // {"schema": "aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor linalg_solve_triangular(const Tensor & self, const Tensor & B, bool upper, bool left, bool unitriangular); // {"schema": "aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor linalg_vander(const Tensor & x, c10::optional<c10::SymInt> N); // {"schema": "aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> svd_out(const Tensor & self, bool some, bool compute_uv, Tensor & U, Tensor & S, Tensor & V); // {"schema": "aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv); // {"schema": "aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)", "dispatch": "False", "default": "True"}
+Tensor swapaxes(const Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor & swapaxes_(Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor swapdims(const Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "False", "default": "True"}
+Tensor & swapdims_(Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & cholesky_out(const Tensor & self, bool upper, Tensor & out); // {"schema": "aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor cholesky(const Tensor & self, bool upper); // {"schema": "aten::cholesky(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & cholesky_solve_out(const Tensor & self, const Tensor & input2, bool upper, Tensor & out); // {"schema": "aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor cholesky_solve(const Tensor & self, const Tensor & input2, bool upper); // {"schema": "aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _cholesky_solve_helper(const Tensor & self, const Tensor & A, bool upper); // {"schema": "aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor cholesky_inverse(const Tensor & self, bool upper); // {"schema": "aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & cholesky_inverse_out(const Tensor & self, bool upper, Tensor & out); // {"schema": "aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> qr_out(const Tensor & self, bool some, Tensor & Q, Tensor & R); // {"schema": "aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> qr(const Tensor & self, bool some); // {"schema": "aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> geqrf_out(const Tensor & self, Tensor & a, Tensor & tau); // {"schema": "aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> geqrf(const Tensor & self); // {"schema": "aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)", "dispatch": "True", "default": "False"}
+Tensor orgqr(const Tensor & self, const Tensor & input2); // {"schema": "aten::orgqr(Tensor self, Tensor input2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & orgqr_out(const Tensor & self, const Tensor & input2, Tensor & out); // {"schema": "aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & ormqr_out(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose, Tensor & out); // {"schema": "aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose); // {"schema": "aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor,Tensor> _lu_with_info(const Tensor & self, bool pivot, bool check_errors); // {"schema": "aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "False", "default": "True"}
+Tensor & lu_solve_out(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots, Tensor & out); // {"schema": "aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots); // {"schema": "aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor> lu_unpack(const Tensor & LU_data, const Tensor & LU_pivots, bool unpack_data, bool unpack_pivots); // {"schema": "aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> lu_unpack_out(const Tensor & LU_data, const Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, Tensor & P, Tensor & L, Tensor & U); // {"schema": "aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"}
+Tensor & multinomial_out(const Tensor & self, int64_t num_samples, bool replacement, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, c10::optional<Generator> generator); // {"schema": "aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & lgamma_out(const Tensor & self, Tensor & out); // {"schema": "aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & lgamma_(Tensor & self); // {"schema": "aten::lgamma_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor lgamma(const Tensor & self); // {"schema": "aten::lgamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & digamma_out(const Tensor & self, Tensor & out); // {"schema": "aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor digamma(const Tensor & self); // {"schema": "aten::digamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & polygamma_out(int64_t n, const Tensor & self, Tensor & out); // {"schema": "aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor polygamma(int64_t n, const Tensor & self); // {"schema": "aten::polygamma(int n, Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & polygamma_(Tensor & self, int64_t n); // {"schema": "aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor erfinv(const Tensor & self); // {"schema": "aten::erfinv(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & erfinv_(Tensor & self); // {"schema": "aten::erfinv_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & erfinv_out(const Tensor & self, Tensor & out); // {"schema": "aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor i0(const Tensor & self); // {"schema": "aten::i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & i0_(Tensor & self); // {"schema": "aten::i0_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & i0_out(const Tensor & self, Tensor & out); // {"schema": "aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sign(const Tensor & self); // {"schema": "aten::sign(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sign_(Tensor & self); // {"schema": "aten::sign_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sign_out(const Tensor & self, Tensor & out); // {"schema": "aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor signbit(const Tensor & self); // {"schema": "aten::signbit(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & signbit_out(const Tensor & self, Tensor & out); // {"schema": "aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor dist(const Tensor & self, const Tensor & other, const Scalar & p); // {"schema": "aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & atan2_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & atan2_(Tensor & self, const Tensor & other); // {"schema": "aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor atan2(const Tensor & self, const Tensor & other); // {"schema": "aten::atan2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor arctan2(const Tensor & self, const Tensor & other); // {"schema": "aten::arctan2(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & arctan2_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & arctan2_(Tensor & self, const Tensor & other); // {"schema": "aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & lerp_out(const Tensor & self, const Tensor & end, const Scalar & weight, Tensor & out); // {"schema": "aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & lerp_out(const Tensor & self, const Tensor & end, const Tensor & weight, Tensor & out); // {"schema": "aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor lerp(const Tensor & self, const Tensor & end, const Scalar & weight); // {"schema": "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor lerp(const Tensor & self, const Tensor & end, const Tensor & weight); // {"schema": "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & histc_out(const Tensor & self, int64_t bins, const Scalar & min, const Scalar & max, Tensor & out); // {"schema": "aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor histc(const Tensor & self, int64_t bins, const Scalar & min, const Scalar & max); // {"schema": "aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> histogram_out(const Tensor & self, const Tensor & bins, const c10::optional<Tensor> & weight, bool density, Tensor & hist, Tensor & bin_edges); // {"schema": "aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> histogram(const Tensor & self, const Tensor & bins, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> histogram_out(const Tensor & self, int64_t bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density, Tensor & hist, Tensor & bin_edges); // {"schema": "aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> histogram(const Tensor & self, int64_t bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _histogramdd_bin_edges(const Tensor & self, IntArrayRef bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]", "dispatch": "True", "default": "False"}
+Tensor _histogramdd_from_bin_cts(const Tensor & self, IntArrayRef bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _histogramdd_from_bin_tensors(const Tensor & self, TensorList bins, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,::std::vector<Tensor>> histogramdd(const Tensor & self, IntArrayRef bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,::std::vector<Tensor>> histogramdd(const Tensor & self, int64_t bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,::std::vector<Tensor>> histogramdd(const Tensor & self, TensorList bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density); // {"schema": "aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"}
+Tensor & fmod_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor fmod(const Tensor & self, const Scalar & other); // {"schema": "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fmod_(Tensor & self, const Scalar & other); // {"schema": "aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & fmod_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor fmod(const Tensor & self, const Tensor & other); // {"schema": "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fmod_(Tensor & self, const Tensor & other); // {"schema": "aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hypot_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hypot(const Tensor & self, const Tensor & other); // {"schema": "aten::hypot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & hypot_(Tensor & self, const Tensor & other); // {"schema": "aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & igamma_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor igamma(const Tensor & self, const Tensor & other); // {"schema": "aten::igamma(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & igamma_(Tensor & self, const Tensor & other); // {"schema": "aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & igammac_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor igammac(const Tensor & self, const Tensor & other); // {"schema": "aten::igammac(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & igammac_(Tensor & self, const Tensor & other); // {"schema": "aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & nextafter_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor nextafter(const Tensor & self, const Tensor & other); // {"schema": "aten::nextafter(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & nextafter_(Tensor & self, const Tensor & other); // {"schema": "aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & remainder_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor remainder(const Tensor & self, const Scalar & other); // {"schema": "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & remainder_(Tensor & self, const Scalar & other); // {"schema": "aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & remainder_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor remainder(const Tensor & self, const Tensor & other); // {"schema": "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & remainder_(Tensor & self, const Tensor & other); // {"schema": "aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor remainder(const Scalar & self, const Tensor & other); // {"schema": "aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor min(const Tensor & self); // {"schema": "aten::min(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & min_out(const Tensor & self, Tensor & out); // {"schema": "aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor fmin(const Tensor & self, const Tensor & other); // {"schema": "aten::fmin(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fmin_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max(const Tensor & self); // {"schema": "aten::max(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor fmax(const Tensor & self, const Tensor & other); // {"schema": "aten::fmax(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fmax_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor maximum(const Tensor & self, const Tensor & other); // {"schema": "aten::maximum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & maximum_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max(const Tensor & self, const Tensor & other); // {"schema": "aten::max.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & max_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & max_out(const Tensor & self, Tensor & out); // {"schema": "aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor minimum(const Tensor & self, const Tensor & other); // {"schema": "aten::minimum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & minimum_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & min_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor min(const Tensor & self, const Tensor & other); // {"schema": "aten::min.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor quantile(const Tensor & self, const Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & quantile_out(const Tensor & self, const Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, Tensor & out); // {"schema": "aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor quantile(const Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & quantile_out(const Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, Tensor & out); // {"schema": "aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nanquantile(const Tensor & self, const Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nanquantile_out(const Tensor & self, const Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, Tensor & out); // {"schema": "aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nanquantile(const Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & nanquantile_out(const Tensor & self, double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation, Tensor & out); // {"schema": "aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> sort_out(const Tensor & self, int64_t dim, bool descending, Tensor & values, Tensor & indices); // {"schema": "aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> sort_out(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending, Tensor & values, Tensor & indices); // {"schema": "aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> sort(const Tensor & self, int64_t dim, bool descending); // {"schema": "aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> sort(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending); // {"schema": "aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> sort_out(const Tensor & self, Dimname dim, bool descending, Tensor & values, Tensor & indices); // {"schema": "aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> sort_out(const Tensor & self, c10::optional<bool> stable, Dimname dim, bool descending, Tensor & values, Tensor & indices); // {"schema": "aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> sort(const Tensor & self, Dimname dim, bool descending); // {"schema": "aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor,Tensor> sort(const Tensor & self, c10::optional<bool> stable, Dimname dim, bool descending); // {"schema": "aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"}
+Tensor & msort_out(const Tensor & self, Tensor & out); // {"schema": "aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor msort(const Tensor & self); // {"schema": "aten::msort(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor argsort(const Tensor & self, int64_t dim, bool descending); // {"schema": "aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor argsort(const Tensor & self, bool stable, int64_t dim, bool descending); // {"schema": "aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor argsort(const Tensor & self, Dimname dim, bool descending); // {"schema": "aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> topk_out(const Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, Tensor & values, Tensor & indices); // {"schema": "aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> topk(const Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted); // {"schema": "aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"}
+Tensor all(const Tensor & self); // {"schema": "aten::all(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & all_out(const Tensor & self, Tensor & out); // {"schema": "aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor any(const Tensor & self); // {"schema": "aten::any(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & any_out(const Tensor & self, Tensor & out); // {"schema": "aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & renorm_out(const Tensor & self, const Scalar & p, int64_t dim, const Scalar & maxnorm, Tensor & out); // {"schema": "aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor renorm(const Tensor & self, const Scalar & p, int64_t dim, const Scalar & maxnorm); // {"schema": "aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & renorm_(Tensor & self, const Scalar & p, int64_t dim, const Scalar & maxnorm); // {"schema": "aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor unfold(const Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)", "dispatch": "True", "default": "False"}
+Tensor unfold_backward(const Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step); // {"schema": "aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor", "dispatch": "True", "default": "False"}
+bool equal(const Tensor & self, const Tensor & other); // {"schema": "aten::equal(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "False"}
+Tensor & pow_out(const Tensor & self, const Tensor & exponent, Tensor & out); // {"schema": "aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor pow(const Tensor & self, const Tensor & exponent); // {"schema": "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & pow_out(const Scalar & self, const Tensor & exponent, Tensor & out); // {"schema": "aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor pow(const Scalar & self, const Tensor & exponent); // {"schema": "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & pow_out(const Tensor & self, const Scalar & exponent, Tensor & out); // {"schema": "aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor pow(const Tensor & self, const Scalar & exponent); // {"schema": "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & pow_(Tensor & self, const Scalar & exponent); // {"schema": "aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & pow_(Tensor & self, const Tensor & exponent); // {"schema": "aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & float_power_out(const Tensor & self, const Tensor & exponent, Tensor & out); // {"schema": "aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor float_power(const Tensor & self, const Tensor & exponent); // {"schema": "aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & float_power_out(const Scalar & self, const Tensor & exponent, Tensor & out); // {"schema": "aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor float_power(const Scalar & self, const Tensor & exponent); // {"schema": "aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & float_power_out(const Tensor & self, const Scalar & exponent, Tensor & out); // {"schema": "aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor float_power(const Tensor & self, const Scalar & exponent); // {"schema": "aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & float_power_(Tensor & self, const Scalar & exponent); // {"schema": "aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & float_power_(Tensor & self, const Tensor & exponent); // {"schema": "aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & normal_(Tensor & self, double mean, double std, c10::optional<Generator> generator); // {"schema": "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor normal_functional(const Tensor & self, double mean, double std, c10::optional<Generator> generator); // {"schema": "aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & normal_out(const Tensor & mean, double std, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor normal(const Tensor & mean, double std, c10::optional<Generator> generator); // {"schema": "aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & normal_out(double mean, const Tensor & std, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor normal(double mean, const Tensor & std, c10::optional<Generator> generator); // {"schema": "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & normal_out(const Tensor & mean, const Tensor & std, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor normal(const Tensor & mean, const Tensor & std, c10::optional<Generator> generator); // {"schema": "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor normal(double mean, double std, c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory); // {"schema": "aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & normal_out(double mean, double std, c10::SymIntArrayRef size, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor alias(const Tensor & self); // {"schema": "aten::alias(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+void _amp_foreach_non_finite_check_and_unscale_(TensorList self, Tensor & found_inf, const Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()", "dispatch": "True", "default": "False"}
+Tensor & _amp_update_scale_(Tensor & self, Tensor & growth_tracker, const Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_add(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_add_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_add(TensorList self, TensorList other, const Scalar & alpha); // {"schema": "aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_add_(TensorList self, TensorList other, const Scalar & alpha); // {"schema": "aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_add(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_add_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_add(TensorList self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_add_(TensorList self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sub(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sub_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sub(TensorList self, TensorList other, const Scalar & alpha); // {"schema": "aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sub_(TensorList self, TensorList other, const Scalar & alpha); // {"schema": "aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sub(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sub_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_mul(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_mul_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_mul(TensorList self, TensorList other); // {"schema": "aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_mul_(TensorList self, TensorList other); // {"schema": "aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_mul(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_mul_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_mul(TensorList self, const Tensor & other); // {"schema": "aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_mul_(TensorList self, const Tensor & other); // {"schema": "aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_div(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_div_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_div(TensorList self, TensorList other); // {"schema": "aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_div_(TensorList self, TensorList other); // {"schema": "aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_div(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_div_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_div(TensorList self, const Tensor & other); // {"schema": "aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_div_(TensorList self, const Tensor & other); // {"schema": "aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_max(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_max_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_max(TensorList self, TensorList other); // {"schema": "aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_max_(TensorList self, TensorList other); // {"schema": "aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_max(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_max_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_min(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_min_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_min(TensorList self, TensorList other); // {"schema": "aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_min_(TensorList self, TensorList other); // {"schema": "aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_clamp_min(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_clamp_min_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_maximum(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_maximum_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_maximum(TensorList self, TensorList other); // {"schema": "aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_maximum_(TensorList self, TensorList other); // {"schema": "aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_maximum(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_maximum_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_minimum(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_minimum_(TensorList self, const Scalar & scalar); // {"schema": "aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_minimum(TensorList self, TensorList other); // {"schema": "aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_minimum_(TensorList self, TensorList other); // {"schema": "aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_minimum(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_minimum_(TensorList self, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcdiv(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value); // {"schema": "aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcdiv(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcdiv(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars); // {"schema": "aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_addcdiv_(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value); // {"schema": "aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_addcdiv_(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_addcdiv_(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars); // {"schema": "aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcmul(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value); // {"schema": "aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcmul(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_addcmul(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars); // {"schema": "aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_addcmul_(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value); // {"schema": "aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_addcmul_(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars); // {"schema": "aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_addcmul_(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars); // {"schema": "aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_abs(TensorList self); // {"schema": "aten::_foreach_abs(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_abs_(TensorList self); // {"schema": "aten::_foreach_abs_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_acos(TensorList self); // {"schema": "aten::_foreach_acos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_acos_(TensorList self); // {"schema": "aten::_foreach_acos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_asin(TensorList self); // {"schema": "aten::_foreach_asin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_asin_(TensorList self); // {"schema": "aten::_foreach_asin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_atan(TensorList self); // {"schema": "aten::_foreach_atan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_atan_(TensorList self); // {"schema": "aten::_foreach_atan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_ceil(TensorList self); // {"schema": "aten::_foreach_ceil(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_ceil_(TensorList self); // {"schema": "aten::_foreach_ceil_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_cos(TensorList self); // {"schema": "aten::_foreach_cos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_cos_(TensorList self); // {"schema": "aten::_foreach_cos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_cosh(TensorList self); // {"schema": "aten::_foreach_cosh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_cosh_(TensorList self); // {"schema": "aten::_foreach_cosh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_erf(TensorList self); // {"schema": "aten::_foreach_erf(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_erf_(TensorList self); // {"schema": "aten::_foreach_erf_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_erfc(TensorList self); // {"schema": "aten::_foreach_erfc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_erfc_(TensorList self); // {"schema": "aten::_foreach_erfc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_exp(TensorList self); // {"schema": "aten::_foreach_exp(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_exp_(TensorList self); // {"schema": "aten::_foreach_exp_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_expm1(TensorList self); // {"schema": "aten::_foreach_expm1(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_expm1_(TensorList self); // {"schema": "aten::_foreach_expm1_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_floor(TensorList self); // {"schema": "aten::_foreach_floor(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_floor_(TensorList self); // {"schema": "aten::_foreach_floor_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_frac(TensorList self); // {"schema": "aten::_foreach_frac(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_frac_(TensorList self); // {"schema": "aten::_foreach_frac_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_lerp(TensorList self, TensorList tensors1, TensorList weights); // {"schema": "aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_lerp_(TensorList self, TensorList tensors1, TensorList weights); // {"schema": "aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_lerp(TensorList self, TensorList tensors1, const Scalar & weight); // {"schema": "aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_lerp_(TensorList self, TensorList tensors1, const Scalar & weight); // {"schema": "aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_lgamma(TensorList self); // {"schema": "aten::_foreach_lgamma(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_lgamma_(TensorList self); // {"schema": "aten::_foreach_lgamma_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_log(TensorList self); // {"schema": "aten::_foreach_log(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_log_(TensorList self); // {"schema": "aten::_foreach_log_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_log10(TensorList self); // {"schema": "aten::_foreach_log10(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_log10_(TensorList self); // {"schema": "aten::_foreach_log10_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_log1p(TensorList self); // {"schema": "aten::_foreach_log1p(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_log1p_(TensorList self); // {"schema": "aten::_foreach_log1p_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_log2(TensorList self); // {"schema": "aten::_foreach_log2(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_log2_(TensorList self); // {"schema": "aten::_foreach_log2_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_neg(TensorList self); // {"schema": "aten::_foreach_neg(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_neg_(TensorList self); // {"schema": "aten::_foreach_neg_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_norm(TensorList self, const Scalar & ord); // {"schema": "aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_pow(TensorList self, TensorList exponent); // {"schema": "aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_pow(TensorList self, const Scalar & exponent); // {"schema": "aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_pow(TensorList self, ArrayRef<Scalar> exponent); // {"schema": "aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_pow(const Scalar & self, TensorList exponent); // {"schema": "aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_pow_(TensorList self, TensorList exponent); // {"schema": "aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_pow_(TensorList self, const Scalar & exponent); // {"schema": "aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_pow_(TensorList self, ArrayRef<Scalar> exponent); // {"schema": "aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_reciprocal(TensorList self); // {"schema": "aten::_foreach_reciprocal(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_reciprocal_(TensorList self); // {"schema": "aten::_foreach_reciprocal_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_round(TensorList self); // {"schema": "aten::_foreach_round(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_round_(TensorList self); // {"schema": "aten::_foreach_round_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sigmoid(TensorList self); // {"schema": "aten::_foreach_sigmoid(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sigmoid_(TensorList self); // {"schema": "aten::_foreach_sigmoid_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sign(TensorList self); // {"schema": "aten::_foreach_sign(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sign_(TensorList self); // {"schema": "aten::_foreach_sign_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sin(TensorList self); // {"schema": "aten::_foreach_sin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sin_(TensorList self); // {"schema": "aten::_foreach_sin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sinh(TensorList self); // {"schema": "aten::_foreach_sinh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sinh_(TensorList self); // {"schema": "aten::_foreach_sinh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_sqrt(TensorList self); // {"schema": "aten::_foreach_sqrt(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_sqrt_(TensorList self); // {"schema": "aten::_foreach_sqrt_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_tan(TensorList self); // {"schema": "aten::_foreach_tan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_tan_(TensorList self); // {"schema": "aten::_foreach_tan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_tanh(TensorList self); // {"schema": "aten::_foreach_tanh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_tanh_(TensorList self); // {"schema": "aten::_foreach_tanh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+::std::vector<Tensor> _foreach_trunc(TensorList self); // {"schema": "aten::_foreach_trunc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "False"}
+void _foreach_trunc_(TensorList self); // {"schema": "aten::_foreach_trunc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_zero_(TensorList self); // {"schema": "aten::_foreach_zero_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "False"}
+void _foreach_copy_(TensorList self, TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()", "dispatch": "True", "default": "False"}
+Tensor bucketize(const Tensor & self, const Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & bucketize_out(const Tensor & self, const Tensor & boundaries, bool out_int32, bool right, Tensor & out); // {"schema": "aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor bucketize(const Scalar & self, const Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor searchsorted(const Tensor & sorted_sequence, const Tensor & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<Tensor> & sorter); // {"schema": "aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & searchsorted_out(const Tensor & sorted_sequence, const Tensor & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<Tensor> & sorter, Tensor & out); // {"schema": "aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor searchsorted(const Tensor & sorted_sequence, const Scalar & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<Tensor> & sorter); // {"schema": "aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & searchsorted_out(const Tensor & sorted_sequence, const Scalar & self, bool out_int32, bool right, c10::optional<c10::string_view> side, const c10::optional<Tensor> & sorter, Tensor & out); // {"schema": "aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _convert_indices_from_coo_to_csr(const Tensor & self, int64_t size, bool out_int32); // {"schema": "aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _convert_indices_from_coo_to_csr_out(const Tensor & self, int64_t size, bool out_int32, Tensor & out); // {"schema": "aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _convert_indices_from_csr_to_coo(const Tensor & crow_indices, const Tensor & col_indices, bool out_int32, bool transpose); // {"schema": "aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _convert_indices_from_csr_to_coo_out(const Tensor & crow_indices, const Tensor & col_indices, bool out_int32, bool transpose, Tensor & out); // {"schema": "aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & mse_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & out); // {"schema": "aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mse_loss(const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & mse_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, Tensor & grad_input); // {"schema": "aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mse_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor l1_loss(const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & multi_margin_loss_out(const Tensor & self, const Tensor & target, const Scalar & p, const Scalar & margin, const c10::optional<Tensor> & weight, int64_t reduction, Tensor & out); // {"schema": "aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor multi_margin_loss(const Tensor & self, const Tensor & target, const Scalar & p, const Scalar & margin, const c10::optional<Tensor> & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & multi_margin_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const Scalar & p, const Scalar & margin, const c10::optional<Tensor> & weight, int64_t reduction, Tensor & grad_input); // {"schema": "aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor multi_margin_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const Scalar & p, const Scalar & margin, const c10::optional<Tensor> & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & multilabel_margin_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & out); // {"schema": "aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor multilabel_margin_loss(const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> multilabel_margin_loss_forward_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & output, Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> multilabel_margin_loss_forward(const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)", "dispatch": "True", "default": "False"}
+Tensor & multilabel_margin_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, const Tensor & is_target, Tensor & grad_input); // {"schema": "aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor multilabel_margin_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, const Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & nll_loss_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, Tensor & out); // {"schema": "aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nll_loss_nd(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor nll_loss(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> nll_loss_forward_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, Tensor & output, Tensor & total_weight); // {"schema": "aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> nll_loss_forward(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "True"}
+Tensor & nll_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const Tensor & total_weight, Tensor & grad_input); // {"schema": "aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor nll_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const Tensor & total_weight); // {"schema": "aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & nll_loss2d_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, Tensor & out); // {"schema": "aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nll_loss2d(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, Tensor & output, Tensor & total_weight); // {"schema": "aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> nll_loss2d_forward(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "False"}
+Tensor & nll_loss2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const Tensor & total_weight, Tensor & grad_input); // {"schema": "aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, int64_t reduction, c10::SymInt ignore_index, const Tensor & total_weight); // {"schema": "aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & smooth_l1_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, double beta, Tensor & out); // {"schema": "aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor smooth_l1_loss(const Tensor & self, const Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & smooth_l1_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, double beta, Tensor & grad_input); // {"schema": "aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor smooth_l1_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & huber_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, double delta, Tensor & out); // {"schema": "aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor huber_loss(const Tensor & self, const Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & huber_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, double delta, Tensor & grad_input); // {"schema": "aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor huber_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & soft_margin_loss_out(const Tensor & self, const Tensor & target, int64_t reduction, Tensor & out); // {"schema": "aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor soft_margin_loss(const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & soft_margin_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, Tensor & grad_input); // {"schema": "aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor soft_margin_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & elu_out(const Tensor & self, const Scalar & alpha, const Scalar & scale, const Scalar & input_scale, Tensor & out); // {"schema": "aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor elu(const Tensor & self, const Scalar & alpha, const Scalar & scale, const Scalar & input_scale); // {"schema": "aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & elu_backward_out(const Tensor & grad_output, const Scalar & alpha, const Scalar & scale, const Scalar & input_scale, bool is_result, const Tensor & self_or_result, Tensor & grad_input); // {"schema": "aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor elu_backward(const Tensor & grad_output, const Scalar & alpha, const Scalar & scale, const Scalar & input_scale, bool is_result, const Tensor & self_or_result); // {"schema": "aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & elu_(Tensor & self, const Scalar & alpha, const Scalar & scale, const Scalar & input_scale); // {"schema": "aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & glu_out(const Tensor & self, int64_t dim, Tensor & out); // {"schema": "aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor glu(const Tensor & self, int64_t dim); // {"schema": "aten::glu(Tensor self, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & glu_backward_out(const Tensor & grad_output, const Tensor & self, int64_t dim, Tensor & grad_input); // {"schema": "aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor glu_backward(const Tensor & grad_output, const Tensor & self, int64_t dim); // {"schema": "aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor glu_jvp(const Tensor & glu, const Tensor & x, const Tensor & dx, int64_t dim); // {"schema": "aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor glu_backward_jvp(const Tensor & grad_x, const Tensor & grad_glu, const Tensor & x, const Tensor & dgrad_glu, const Tensor & dx, int64_t dim); // {"schema": "aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & hardsigmoid_out(const Tensor & self, Tensor & out); // {"schema": "aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardsigmoid(const Tensor & self); // {"schema": "aten::hardsigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & hardsigmoid_(Tensor & self); // {"schema": "aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hardsigmoid_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & grad_input); // {"schema": "aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardsigmoid_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & hardtanh_out(const Tensor & self, const Scalar & min_val, const Scalar & max_val, Tensor & out); // {"schema": "aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardtanh(const Tensor & self, const Scalar & min_val, const Scalar & max_val); // {"schema": "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & hardtanh_backward_out(const Tensor & grad_output, const Tensor & self, const Scalar & min_val, const Scalar & max_val, Tensor & grad_input); // {"schema": "aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardtanh_backward(const Tensor & grad_output, const Tensor & self, const Scalar & min_val, const Scalar & max_val); // {"schema": "aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & hardtanh_(Tensor & self, const Scalar & min_val, const Scalar & max_val); // {"schema": "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & hardswish_out(const Tensor & self, Tensor & out); // {"schema": "aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardswish(const Tensor & self); // {"schema": "aten::hardswish(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & hardswish_(Tensor & self); // {"schema": "aten::hardswish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor hardswish_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & leaky_relu_out(const Tensor & self, const Scalar & negative_slope, Tensor & out); // {"schema": "aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor leaky_relu(const Tensor & self, const Scalar & negative_slope); // {"schema": "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & leaky_relu_backward_out(const Tensor & grad_output, const Tensor & self, const Scalar & negative_slope, bool self_is_result, Tensor & grad_input); // {"schema": "aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor leaky_relu_backward(const Tensor & grad_output, const Tensor & self, const Scalar & negative_slope, bool self_is_result); // {"schema": "aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & leaky_relu_(Tensor & self, const Scalar & negative_slope); // {"schema": "aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & log_sigmoid_out(const Tensor & self, Tensor & out); // {"schema": "aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor log_sigmoid(const Tensor & self); // {"schema": "aten::log_sigmoid(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple<Tensor &,Tensor &> log_sigmoid_forward_out(const Tensor & self, Tensor & output, Tensor & buffer); // {"schema": "aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> log_sigmoid_forward(const Tensor & self); // {"schema": "aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)", "dispatch": "True", "default": "False"}
+Tensor & log_sigmoid_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & buffer, Tensor & grad_input); // {"schema": "aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor log_sigmoid_backward(const Tensor & grad_output, const Tensor & self, const Tensor & buffer); // {"schema": "aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & rrelu_with_noise_out(const Tensor & self, const Tensor & noise, const Scalar & lower, const Scalar & upper, bool training, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor rrelu_with_noise(const Tensor & self, const Tensor & noise, const Scalar & lower, const Scalar & upper, bool training, c10::optional<Generator> generator); // {"schema": "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar & lower, const Scalar & upper, bool training, bool self_is_result); // {"schema": "aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & rrelu_with_noise_(Tensor & self, const Tensor & noise, const Scalar & lower, const Scalar & upper, bool training, c10::optional<Generator> generator); // {"schema": "aten::rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & softplus_out(const Tensor & self, const Scalar & beta, const Scalar & threshold, Tensor & out); // {"schema": "aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor softplus(const Tensor & self, const Scalar & beta, const Scalar & threshold); // {"schema": "aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & softplus_backward_out(const Tensor & grad_output, const Tensor & self, const Scalar & beta, const Scalar & threshold, Tensor & grad_input); // {"schema": "aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor softplus_backward(const Tensor & grad_output, const Tensor & self, const Scalar & beta, const Scalar & threshold); // {"schema": "aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & softshrink_out(const Tensor & self, const Scalar & lambd, Tensor & out); // {"schema": "aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor softshrink(const Tensor & self, const Scalar & lambd); // {"schema": "aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & softshrink_backward_out(const Tensor & grad_output, const Tensor & self, const Scalar & lambd, Tensor & grad_input); // {"schema": "aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor softshrink_backward(const Tensor & grad_output, const Tensor & self, const Scalar & lambd); // {"schema": "aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & adaptive_avg_pool2d_out(const Tensor & self, c10::SymIntArrayRef output_size, Tensor & out); // {"schema": "aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor adaptive_avg_pool2d(const Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor mkldnn_adaptive_avg_pool2d(const Tensor & self, IntArrayRef output_size); // {"schema": "aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & mkldnn_adaptive_avg_pool2d_out(const Tensor & self, IntArrayRef output_size, Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor mkldnn_adaptive_avg_pool2d_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _adaptive_avg_pool2d(const Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _adaptive_avg_pool2d_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & adaptive_avg_pool3d_out(const Tensor & self, c10::SymIntArrayRef output_size, Tensor & out); // {"schema": "aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor adaptive_avg_pool3d(const Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _adaptive_avg_pool3d(const Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & adaptive_avg_pool3d_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & grad_input); // {"schema": "aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _adaptive_avg_pool3d_backward(const Tensor & grad_output, const Tensor & self); // {"schema": "aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> adaptive_max_pool2d_out(const Tensor & self, IntArrayRef output_size, Tensor & out, Tensor & indices); // {"schema": "aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> adaptive_max_pool2d(const Tensor & self, IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor & adaptive_max_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor adaptive_max_pool2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & indices); // {"schema": "aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> adaptive_max_pool3d_out(const Tensor & self, IntArrayRef output_size, Tensor & out, Tensor & indices); // {"schema": "aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> adaptive_max_pool3d(const Tensor & self, IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor & adaptive_max_pool3d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor adaptive_max_pool3d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & indices); // {"schema": "aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & avg_pool2d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, Tensor & out); // {"schema": "aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor avg_pool2d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override); // {"schema": "aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & avg_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, Tensor & grad_input); // {"schema": "aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor avg_pool2d_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override); // {"schema": "aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & avg_pool3d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, Tensor & out); // {"schema": "aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor avg_pool3d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override); // {"schema": "aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & avg_pool3d_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, Tensor & grad_input); // {"schema": "aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor avg_pool3d_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override); // {"schema": "aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> fractional_max_pool2d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & random_samples, Tensor & output, Tensor & indices); // {"schema": "aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> fractional_max_pool2d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & random_samples); // {"schema": "aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor & fractional_max_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor fractional_max_pool2d_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & indices); // {"schema": "aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> fractional_max_pool3d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & random_samples, Tensor & output, Tensor & indices); // {"schema": "aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> fractional_max_pool3d(const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & random_samples); // {"schema": "aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor & fractional_max_pool3d_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor fractional_max_pool3d_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor & indices); // {"schema": "aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor &,Tensor &> max_pool2d_with_indices_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out, Tensor & indices); // {"schema": "aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> max_pool2d_with_indices(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"}
+Tensor & max_pool2d_with_indices_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max_pool2d_with_indices_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices); // {"schema": "aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> max_pool3d_with_indices_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out, Tensor & indices); // {"schema": "aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"}
+::std::tuple<Tensor,Tensor> max_pool3d_with_indices(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor & max_pool3d_with_indices_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices, Tensor & grad_input); // {"schema": "aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max_pool3d_with_indices_backward(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices); // {"schema": "aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & max_unpool2d_out(const Tensor & self, const Tensor & indices, c10::SymIntArrayRef output_size, Tensor & out); // {"schema": "aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max_unpool2d(const Tensor & self, const Tensor & indices, c10::SymIntArrayRef output_size); // {"schema": "aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & max_unpool3d_out(const Tensor & self, const Tensor & indices, c10::SymIntArrayRef output_size, IntArrayRef stride, IntArrayRef padding, Tensor & out); // {"schema": "aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor max_unpool3d(const Tensor & self, const Tensor & indices, c10::SymIntArrayRef output_size, IntArrayRef stride, IntArrayRef padding); // {"schema": "aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & reflection_pad1d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad1d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & reflection_pad1d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad1d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & reflection_pad2d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad2d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & reflection_pad2d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad2d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & reflection_pad3d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad3d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & reflection_pad3d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor reflection_pad3d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & replication_pad1d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad1d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & replication_pad1d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad1d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & replication_pad2d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad2d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & replication_pad2d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad2d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & replication_pad3d_out(const Tensor & self, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad3d(const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & replication_pad3d_backward_out(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding, Tensor & grad_input); // {"schema": "aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor replication_pad3d_backward(const Tensor & grad_output, const Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _pad_circular(const Tensor & self, c10::SymIntArrayRef pad); // {"schema": "aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _pad_enum(const Tensor & self, c10::SymIntArrayRef pad, int64_t mode, c10::optional<double> value); // {"schema": "aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor pad(const Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode, c10::optional<double> value); // {"schema": "aten::pad(Tensor self, SymInt[] pad, str mode=\"constant\", float? value=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_linear1d(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_bilinear2d(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _upsample_bilinear2d_aa(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_trilinear3d(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_bicubic2d(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _upsample_bicubic2d_aa(const Tensor & input, OptionalSymIntArrayRef output_size, bool align_corners, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_nearest1d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _upsample_nearest_exact1d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_nearest2d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _upsample_nearest_exact2d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor upsample_nearest3d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _upsample_nearest_exact3d(const Tensor & input, OptionalSymIntArrayRef output_size, c10::optional<ArrayRef<double>> scale_factors); // {"schema": "aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & upsample_linear1d_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales, Tensor & out); // {"schema": "aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_linear1d(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales); // {"schema": "aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_linear1d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales, Tensor & grad_input); // {"schema": "aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_linear1d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales); // {"schema": "aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_bilinear2d_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & out); // {"schema": "aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_bilinear2d(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_bilinear2d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & grad_input); // {"schema": "aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_bilinear2d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _upsample_bilinear2d_aa_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & out); // {"schema": "aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _upsample_bilinear2d_aa(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _upsample_bilinear2d_aa_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & grad_input); // {"schema": "aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _upsample_bilinear2d_aa_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_bicubic2d_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & out); // {"schema": "aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_bicubic2d(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_bicubic2d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & grad_input); // {"schema": "aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_bicubic2d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _upsample_bicubic2d_aa_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & out); // {"schema": "aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _upsample_bicubic2d_aa(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _upsample_bicubic2d_aa_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & grad_input); // {"schema": "aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _upsample_bicubic2d_aa_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_trilinear3d_out(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & out); // {"schema": "aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_trilinear3d(const Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_trilinear3d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w, Tensor & grad_input); // {"schema": "aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_trilinear3d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w); // {"schema": "aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest1d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales, Tensor & out); // {"schema": "aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact1d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales, Tensor & out); // {"schema": "aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest1d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales); // {"schema": "aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact1d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales); // {"schema": "aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest1d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales, Tensor & grad_input); // {"schema": "aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact1d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales, Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest1d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales); // {"schema": "aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact1d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales); // {"schema": "aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest2d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_h, c10::optional scales_w, Tensor & out); // {"schema": "aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact2d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_h, c10::optional scales_w, Tensor & out); // {"schema": "aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest2d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact2d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest2d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_h, c10::optional scales_w, Tensor & grad_input); // {"schema": "aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact2d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_h, c10::optional scales_w, Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest2d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact2d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest3d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, Tensor & out); // {"schema": "aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact3d_out(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, Tensor & out); // {"schema": "aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest3d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact3d(const Tensor & self, c10::SymIntArrayRef output_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & upsample_nearest3d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, Tensor & grad_input); // {"schema": "aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & _upsample_nearest_exact3d_backward_out(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w, Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor upsample_nearest3d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _upsample_nearest_exact3d_backward(const Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, c10::optional scales_d, c10::optional scales_h, c10::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sigmoid_backward_out(const Tensor & grad_output, const Tensor & output, Tensor & grad_input); // {"schema": "aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor sigmoid_backward(const Tensor & grad_output, const Tensor & output); // {"schema": "aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & logit_backward_out(const Tensor & grad_output, const Tensor & self, c10::optional eps, Tensor & grad_input); // {"schema": "aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor logit_backward(const Tensor & grad_output, const Tensor & self, c10::optional eps); // {"schema": "aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & tanh_backward_out(const Tensor & grad_output, const Tensor & output, Tensor & grad_input); // {"schema": "aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor tanh_backward(const Tensor & grad_output, const Tensor & output); // {"schema": "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & slow_conv_transpose2d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, Tensor & out); // {"schema": "aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor slow_conv_transpose2d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & slow_conv_transpose3d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, Tensor & out); // {"schema": "aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor slow_conv_transpose3d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & thnn_conv2d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor thnn_conv2d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & _slow_conv2d_forward_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, Tensor & output); // {"schema": "aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _slow_conv2d_forward(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple _slow_conv2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, Tensor & grad_input, Tensor & grad_weight, Tensor & grad_bias); // {"schema": "aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"}
+::std::tuple _slow_conv2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask); // {"schema": "aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "False"}
+const Tensor & _conv_depthwise2d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, const Tensor & out); // {"schema": "aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _conv_depthwise2d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor conv_depthwise3d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & slow_conv3d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, Tensor & out); // {"schema": "aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor slow_conv3d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & slow_conv3d_forward_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, Tensor & output); // {"schema": "aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor slow_conv3d_forward(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor slow_conv_dilated2d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor slow_conv_dilated3d(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & col2im_out(const Tensor & self, c10::SymIntArrayRef output_size, IntArrayRef kernel_size, IntArrayRef dilation, IntArrayRef padding, IntArrayRef stride, Tensor & out); // {"schema": "aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor col2im(const Tensor & self, c10::SymIntArrayRef output_size, IntArrayRef kernel_size, IntArrayRef dilation, IntArrayRef padding, IntArrayRef stride); // {"schema": "aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor column_stack(TensorList tensors); // {"schema": "aten::column_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & column_stack_out(TensorList tensors, Tensor & out); // {"schema": "aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & im2col_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef dilation, IntArrayRef padding, IntArrayRef stride, Tensor & out); // {"schema": "aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor im2col(const Tensor & self, IntArrayRef kernel_size, IntArrayRef dilation, IntArrayRef padding, IntArrayRef stride); // {"schema": "aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor isfinite(const Tensor & self); // {"schema": "aten::isfinite(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor isinf(const Tensor & self); // {"schema": "aten::isinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+void record_stream(Tensor & self, Stream s); // {"schema": "aten::record_stream(Tensor(a!) self, Stream s) -> ()", "dispatch": "True", "default": "False"}
+Tensor isposinf(const Tensor & self); // {"schema": "aten::isposinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & isposinf_out(const Tensor & self, Tensor & out); // {"schema": "aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor isneginf(const Tensor & self); // {"schema": "aten::isneginf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & isneginf_out(const Tensor & self, Tensor & out); // {"schema": "aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _add_batch_dim(const Tensor & self, int64_t batch_dim, int64_t level); // {"schema": "aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _remove_batch_dim(const Tensor & self, int64_t level, int64_t batch_size, int64_t out_dim); // {"schema": "aten::_remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor special_entr(const Tensor & self); // {"schema": "aten::special_entr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_entr_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_ndtri(const Tensor & self); // {"schema": "aten::special_ndtri(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_ndtri_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_log_ndtr(const Tensor & self); // {"schema": "aten::special_log_ndtr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_log_ndtr_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_expm1(const Tensor & self); // {"schema": "aten::special_expm1(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_expm1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_exp2(const Tensor & self); // {"schema": "aten::special_exp2(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_exp2_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_psi(const Tensor & self); // {"schema": "aten::special_psi(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_psi_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_digamma(const Tensor & self); // {"schema": "aten::special_digamma(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_digamma_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_gammaln(const Tensor & self); // {"schema": "aten::special_gammaln(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_gammaln_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_erf(const Tensor & self); // {"schema": "aten::special_erf(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_erf_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_erfc(const Tensor & self); // {"schema": "aten::special_erfc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_erfc_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_erfcx(const Tensor & self); // {"schema": "aten::special_erfcx(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_erfcx_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_erfinv(const Tensor & self); // {"schema": "aten::special_erfinv(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_erfinv_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_ndtr(const Tensor & self); // {"schema": "aten::special_ndtr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_ndtr_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_xlog1py(const Tensor & self, const Tensor & other); // {"schema": "aten::special_xlog1py(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_xlog1py(const Scalar & self, const Tensor & other); // {"schema": "aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_xlog1py(const Tensor & self, const Scalar & other); // {"schema": "aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_xlog1py_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_xlog1py_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_xlog1py_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_xlogy(const Tensor & self, const Tensor & other); // {"schema": "aten::special_xlogy(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor special_xlogy(const Scalar & self, const Tensor & other); // {"schema": "aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor special_xlogy(const Tensor & self, const Scalar & other); // {"schema": "aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_xlogy_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & special_xlogy_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & special_xlogy_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_zeta(const Tensor & self, const Tensor & other); // {"schema": "aten::special_zeta(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_zeta(const Scalar & self, const Tensor & other); // {"schema": "aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_zeta(const Tensor & self, const Scalar & other); // {"schema": "aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_zeta_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_zeta_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_zeta_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_i0(const Tensor & self); // {"schema": "aten::special_i0(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_i0_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_i0e(const Tensor & self); // {"schema": "aten::special_i0e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_i0e_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_i1(const Tensor & self); // {"schema": "aten::special_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_i1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_i1e(const Tensor & self); // {"schema": "aten::special_i1e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_i1e_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_logit(const Tensor & self, c10::optional eps); // {"schema": "aten::special_logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_logit_out(const Tensor & self, c10::optional eps, Tensor & out); // {"schema": "aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_polygamma(int64_t n, const Tensor & self); // {"schema": "aten::special_polygamma(int n, Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_polygamma_out(int64_t n, const Tensor & self, Tensor & out); // {"schema": "aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_logsumexp(const Tensor & self, IntArrayRef dim, bool keepdim); // {"schema": "aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_logsumexp_out(const Tensor & self, IntArrayRef dim, bool keepdim, Tensor & out); // {"schema": "aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_expit(const Tensor & self); // {"schema": "aten::special_expit(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_expit_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_sinc(const Tensor & self); // {"schema": "aten::special_sinc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_sinc_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_round(const Tensor & self, int64_t decimals); // {"schema": "aten::special_round(Tensor self, *, int decimals=0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_round_out(const Tensor & self, int64_t decimals, Tensor & out); // {"schema": "aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_log1p(const Tensor & self); // {"schema": "aten::special_log1p(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_log1p_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_log_softmax(const Tensor & self, int64_t dim, c10::optional dtype); // {"schema": "aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_gammainc_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_gammainc(const Tensor & self, const Tensor & other); // {"schema": "aten::special_gammainc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_gammaincc_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_gammaincc(const Tensor & self, const Tensor & other); // {"schema": "aten::special_gammaincc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor special_multigammaln(const Tensor & self, int64_t p); // {"schema": "aten::special_multigammaln(Tensor self, int p) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & special_multigammaln_out(const Tensor & self, int64_t p, Tensor & out); // {"schema": "aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor special_softmax(const Tensor & self, int64_t dim, c10::optional dtype); // {"schema": "aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fft_fft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_fft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ifft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_ifft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_rfft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_rfft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_irfft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_irfft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_hfft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_hfft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ihfft(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm); // {"schema": "aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_ihfft_out(const Tensor & self, c10::optional n, int64_t dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_fft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_fft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ifft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_ifft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_rfft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_rfft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_irfft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_irfft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_hfft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+const Tensor & fft_hfft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, const Tensor & out); // {"schema": "aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ihfft2(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+const Tensor & fft_ihfft2_out(const Tensor & self, OptionalSymIntArrayRef s, IntArrayRef dim, c10::optional norm, const Tensor & out); // {"schema": "aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_fftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_fftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ifftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_ifftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_rfftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_rfftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_irfftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & fft_irfftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, Tensor & out); // {"schema": "aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_hfftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+const Tensor & fft_hfftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, const Tensor & out); // {"schema": "aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_ihfftn(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm); // {"schema": "aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"}
+const Tensor & fft_ihfftn_out(const Tensor & self, OptionalSymIntArrayRef s, OptionalIntArrayRef dim, c10::optional norm, const Tensor & out); // {"schema": "aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor fft_fftfreq(int64_t n, double d, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fft_fftfreq_out(int64_t n, double d, Tensor & out); // {"schema": "aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor fft_rfftfreq(int64_t n, double d, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & fft_rfftfreq_out(int64_t n, double d, Tensor & out); // {"schema": "aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor fft_fftshift(const Tensor & self, OptionalIntArrayRef dim); // {"schema": "aten::fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor fft_ifftshift(const Tensor & self, OptionalIntArrayRef dim); // {"schema": "aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple linalg_cholesky_ex(const Tensor & self, bool upper, bool check_errors); // {"schema": "aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_cholesky_ex_out(const Tensor & self, bool upper, bool check_errors, Tensor & L, Tensor & info); // {"schema": "aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)", "dispatch": "True", "default": "False"}
+Tensor linalg_cholesky(const Tensor & self, bool upper); // {"schema": "aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_cholesky_out(const Tensor & self, bool upper, Tensor & out); // {"schema": "aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_cross(const Tensor & self, const Tensor & other, int64_t dim); // {"schema": "aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linalg_cross_out(const Tensor & self, const Tensor & other, int64_t dim, Tensor & out); // {"schema": "aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_lu_factor(const Tensor & A, bool pivot); // {"schema": "aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_lu_factor_out(const Tensor & A, bool pivot, Tensor & LU, Tensor & pivots); // {"schema": "aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_lu_factor_ex(const Tensor & A, bool pivot, bool check_errors); // {"schema": "aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_lu_factor_ex_out(const Tensor & A, bool pivot, bool check_errors, Tensor & LU, Tensor & pivots, Tensor & info); // {"schema": "aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_lu(const Tensor & A, bool pivot); // {"schema": "aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_lu_out(const Tensor & A, bool pivot, Tensor & P, Tensor & L, Tensor & U); // {"schema": "aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"}
+Tensor linalg_lu_solve(const Tensor & LU, const Tensor & pivots, const Tensor & B, bool left, bool adjoint); // {"schema": "aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linalg_lu_solve_out(const Tensor & LU, const Tensor & pivots, const Tensor & B, bool left, bool adjoint, Tensor & out); // {"schema": "aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple _linalg_det(const Tensor & A); // {"schema": "aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"}
+::std::tuple _linalg_det_out(const Tensor & A, Tensor & result, Tensor & LU, Tensor & pivots); // {"schema": "aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)", "dispatch": "True", "default": "False"}
+Tensor linalg_det(const Tensor & A); // {"schema": "aten::linalg_det(Tensor A) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_det_out(const Tensor & A, Tensor & out); // {"schema": "aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor det(const Tensor & self); // {"schema": "aten::det(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+::std::tuple linalg_ldl_factor_ex(const Tensor & self, bool hermitian, bool check_errors); // {"schema": "aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_ldl_factor_ex_out(const Tensor & self, bool hermitian, bool check_errors, Tensor & LD, Tensor & pivots, Tensor & info); // {"schema": "aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_ldl_factor(const Tensor & self, bool hermitian); // {"schema": "aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_ldl_factor_out(const Tensor & self, bool hermitian, Tensor & LD, Tensor & pivots); // {"schema": "aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)", "dispatch": "False", "default": "True"}
+Tensor linalg_ldl_solve(const Tensor & LD, const Tensor & pivots, const Tensor & B, bool hermitian); // {"schema": "aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linalg_ldl_solve_out(const Tensor & LD, const Tensor & pivots, const Tensor & B, bool hermitian, Tensor & out); // {"schema": "aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_lstsq(const Tensor & self, const Tensor & b, c10::optional rcond, c10::optional driver); // {"schema": "aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_lstsq_out(const Tensor & self, const Tensor & b, c10::optional rcond, c10::optional driver, Tensor & solution, Tensor & residuals, Tensor & rank, Tensor & singular_values); // {"schema": "aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)", "dispatch": "True", "default": "False"}
+Tensor linalg_matmul(const Tensor & self, const Tensor & other); // {"schema": "aten::linalg_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matmul_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_vecdot(const Tensor & x, const Tensor & y, int64_t dim); // {"schema": "aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_vecdot_out(const Tensor & x, const Tensor & y, int64_t dim, Tensor & out); // {"schema": "aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_exp(const Tensor & self); // {"schema": "aten::linalg_matrix_exp(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple _linalg_slogdet(const Tensor & A); // {"schema": "aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"}
+::std::tuple _linalg_slogdet_out(const Tensor & A, Tensor & sign, Tensor & logabsdet, Tensor & LU, Tensor & pivots); // {"schema": "aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_slogdet(const Tensor & A); // {"schema": "aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_slogdet_out(const Tensor & A, Tensor & sign, Tensor & logabsdet); // {"schema": "aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)", "dispatch": "False", "default": "True"}
+::std::tuple slogdet(const Tensor & self); // {"schema": "aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"}
+::std::tuple slogdet_out(const Tensor & self, Tensor & sign, Tensor & logabsdet); // {"schema": "aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)", "dispatch": "False", "default": "True"}
+Tensor logdet(const Tensor & self); // {"schema": "aten::logdet(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
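+// [Editor's note: illustrative sketch, not part of the generated header.] For the
+// determinant family above, linalg_slogdet is the numerically safer entry point;
+// assuming the at:: bindings (<ATen/ATen.h>):
+//
+//   at::Tensor A = at::randn({4, 4});
+//   auto [sign, logabsdet] = at::linalg_slogdet(A);  // det(A) == sign * exp(logabsdet)
+//   at::Tensor d = at::linalg_det(A);                // direct determinant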
+::std::tuple linalg_eig(const Tensor & self); // {"schema": "aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_eig_out(const Tensor & self, Tensor & eigenvalues, Tensor & eigenvectors); // {"schema": "aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "True", "default": "False"}
+Tensor _linalg_eigvals(const Tensor & self); // {"schema": "aten::_linalg_eigvals(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor linalg_eigvals(const Tensor & self); // {"schema": "aten::linalg_eigvals(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_eigvals_out(const Tensor & self, Tensor & out); // {"schema": "aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple _linalg_eigh(const Tensor & A, c10::string_view UPLO, bool compute_v); // {"schema": "aten::_linalg_eigh(Tensor A, str UPLO=\"L\", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "True"}
+::std::tuple _linalg_eigh_out(const Tensor & A, c10::string_view UPLO, bool compute_v, Tensor & eigenvalues, Tensor & eigenvectors); // {"schema": "aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO=\"L\", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_eigh(const Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigh(Tensor self, str UPLO=\"L\") -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_eigh_out(const Tensor & self, c10::string_view UPLO, Tensor & eigvals, Tensor & eigvecs); // {"schema": "aten::linalg_eigh.eigvals(Tensor self, str UPLO=\"L\", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "False", "default": "True"}
+Tensor linalg_eigvalsh(const Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigvalsh(Tensor self, str UPLO=\"L\") -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_eigvalsh_out(const Tensor & self, c10::string_view UPLO, Tensor & out); // {"schema": "aten::linalg_eigvalsh.out(Tensor self, str UPLO=\"L\", *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
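+// [Editor's note: illustrative sketch, not part of the generated header.]
+// linalg_eigh / linalg_eigvalsh above expect a symmetric (or Hermitian) input and
+// only read the triangle selected by UPLO; assuming the at:: bindings:
+//
+//   at::Tensor M = at::randn({5, 5});
+//   at::Tensor S = (M + M.t()) / 2;             // symmetrize
+//   auto [w, V] = at::linalg_eigh(S);           // UPLO="L" by default
+//   at::Tensor w_only = at::linalg_eigvalsh(S); // eigenvalues only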
+Tensor linalg_householder_product(const Tensor & input, const Tensor & tau); // {"schema": "aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & linalg_householder_product_out(const Tensor & input, const Tensor & tau, Tensor & out); // {"schema": "aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_inv_ex(const Tensor & A, bool check_errors); // {"schema": "aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_inv_ex_out(const Tensor & A, bool check_errors, Tensor & inverse, Tensor & info); // {"schema": "aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)", "dispatch": "True", "default": "False"}
+Tensor linalg_inv(const Tensor & A); // {"schema": "aten::linalg_inv(Tensor A) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_inv_out(const Tensor & A, Tensor & out); // {"schema": "aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor inverse(const Tensor & self); // {"schema": "aten::inverse(Tensor self) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & inverse_out(const Tensor & self, Tensor & out); // {"schema": "aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor inner(const Tensor & self, const Tensor & other); // {"schema": "aten::inner(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & inner_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor outer(const Tensor & self, const Tensor & vec2); // {"schema": "aten::outer(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & outer_out(const Tensor & self, const Tensor & vec2, Tensor & out); // {"schema": "aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor ger(const Tensor & self, const Tensor & vec2); // {"schema": "aten::ger(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & ger_out(const Tensor & self, const Tensor & vec2, Tensor & out); // {"schema": "aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_norm(const Tensor & self, const c10::optional & ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype); // {"schema": "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor linalg_norm(const Tensor & self, c10::string_view ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype); // {"schema": "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_norm_out(const Tensor & self, const c10::optional & ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype, Tensor & out); // {"schema": "aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & linalg_norm_out(const Tensor & self, c10::string_view ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype, Tensor & out); // {"schema": "aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_vector_norm(const Tensor & self, const Scalar & ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype); // {"schema": "aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linalg_vector_norm_out(const Tensor & self, const Scalar & ord, OptionalIntArrayRef dim, bool keepdim, c10::optional dtype, Tensor & out); // {"schema": "aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor linalg_matrix_norm(const Tensor & self, const Scalar & ord, IntArrayRef dim, bool keepdim, c10::optional dtype); // {"schema": "aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_norm_out(const Tensor & self, const Scalar & ord, IntArrayRef dim, bool keepdim, c10::optional dtype, Tensor & out); // {"schema": "aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_norm(const Tensor & self, c10::string_view ord, IntArrayRef dim, bool keepdim, c10::optional dtype); // {"schema": "aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_norm_out(const Tensor & self, c10::string_view ord, IntArrayRef dim, bool keepdim, c10::optional dtype, Tensor & out); // {"schema": "aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple _linalg_svd(const Tensor & A, bool full_matrices, bool compute_uv, c10::optional driver); // {"schema": "aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "True", "default": "True"}
+::std::tuple _linalg_svd_out(const Tensor & A, bool full_matrices, bool compute_uv, c10::optional driver, Tensor & U, Tensor & S, Tensor & Vh); // {"schema": "aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_svd(const Tensor & A, bool full_matrices, c10::optional driver); // {"schema": "aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_svd_out(const Tensor & A, bool full_matrices, c10::optional driver, Tensor & U, Tensor & S, Tensor & Vh); // {"schema": "aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "False", "default": "True"}
+Tensor linalg_svdvals(const Tensor & A, c10::optional driver); // {"schema": "aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_svdvals_out(const Tensor & A, c10::optional driver, Tensor & out); // {"schema": "aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
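+// [Editor's note: illustrative sketch, not part of the generated header.]
+// linalg_svd defaults to full_matrices=True; the reduced factorization is the
+// common choice. Assuming the at:: bindings:
+//
+//   at::Tensor A = at::randn({6, 4});
+//   auto [U, S, Vh] = at::linalg_svd(A, /*full_matrices=*/false);
+//   at::Tensor sv = at::linalg_svdvals(A);  // singular values only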
+Tensor linalg_cond(const Tensor & self, const c10::optional & p); // {"schema": "aten::linalg_cond(Tensor self, Scalar? p=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_cond_out(const Tensor & self, const c10::optional & p, Tensor & out); // {"schema": "aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_cond(const Tensor & self, c10::string_view p); // {"schema": "aten::linalg_cond.p_str(Tensor self, str p) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_cond_out(const Tensor & self, c10::string_view p, Tensor & out); // {"schema": "aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_pinv(const Tensor & self, const c10::optional & atol, const c10::optional & rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & linalg_pinv_out(const Tensor & self, const c10::optional & atol, const c10::optional & rtol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor linalg_pinv(const Tensor & self, c10::optional atol, c10::optional rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_pinv_out(const Tensor & self, c10::optional atol, c10::optional rtol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_pinv(const Tensor & self, double rcond, bool hermitian); // {"schema": "aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor linalg_pinv(const Tensor & self, const Tensor & rcond, bool hermitian); // {"schema": "aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_pinv_out(const Tensor & self, double rcond, bool hermitian, Tensor & out); // {"schema": "aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor & linalg_pinv_out(const Tensor & self, const Tensor & rcond, bool hermitian, Tensor & out); // {"schema": "aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple _linalg_solve_ex(const Tensor & A, const Tensor & B, bool left, bool check_errors); // {"schema": "aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"}
+::std::tuple _linalg_solve_ex_out(const Tensor & A, const Tensor & B, bool left, bool check_errors, Tensor & result, Tensor & LU, Tensor & pivots, Tensor & info); // {"schema": "aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info)", "dispatch": "True", "default": "False"}
+::std::tuple linalg_solve_ex(const Tensor & A, const Tensor & B, bool left, bool check_errors); // {"schema": "aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_solve_ex_out(const Tensor & A, const Tensor & B, bool left, bool check_errors, Tensor & result, Tensor & info); // {"schema": "aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)", "dispatch": "False", "default": "True"}
+Tensor linalg_solve(const Tensor & A, const Tensor & B, bool left); // {"schema": "aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_solve_out(const Tensor & A, const Tensor & B, bool left, Tensor & out); // {"schema": "aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
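+// [Editor's note: illustrative sketch, not part of the generated header.]
+// linalg_solve solves A X = B (or X A = B with left=false) and, unlike
+// linalg_solve_ex, throws on a singular A. Assuming the at:: bindings:
+//
+//   at::Tensor A = at::randn({3, 3});
+//   at::Tensor b = at::randn({3});
+//   at::Tensor x = at::linalg_solve(A, b);        // left=true by default
+//   auto [x2, info] = at::linalg_solve_ex(A, b);  // nonzero info flags a failed factorization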
+Tensor linalg_tensorinv(const Tensor & self, int64_t ind); // {"schema": "aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_tensorinv_out(const Tensor & self, int64_t ind, Tensor & out); // {"schema": "aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_tensorsolve(const Tensor & self, const Tensor & other, OptionalIntArrayRef dims); // {"schema": "aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_tensorsolve_out(const Tensor & self, const Tensor & other, OptionalIntArrayRef dims, Tensor & out); // {"schema": "aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+::std::tuple linalg_qr(const Tensor & A, c10::string_view mode); // {"schema": "aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)", "dispatch": "True", "default": "True"}
+::std::tuple linalg_qr_out(const Tensor & A, c10::string_view mode, Tensor & Q, Tensor & R); // {"schema": "aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "True", "default": "False"}
+Tensor linalg_matrix_power(const Tensor & self, int64_t n); // {"schema": "aten::linalg_matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_power_out(const Tensor & self, int64_t n, Tensor & out); // {"schema": "aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_rank(const Tensor & input, const c10::optional & atol, const c10::optional & rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_rank_out(const Tensor & input, const c10::optional & atol, const c10::optional & rtol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_rank(const Tensor & self, c10::optional atol, c10::optional rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_rank_out(const Tensor & self, c10::optional atol, c10::optional rtol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_rank(const Tensor & self, double tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_rank_out(const Tensor & self, double tol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_matrix_rank(const Tensor & input, const Tensor & tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_matrix_rank_out(const Tensor & input, const Tensor & tol, bool hermitian, Tensor & out); // {"schema": "aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor linalg_multi_dot(TensorList tensors); // {"schema": "aten::linalg_multi_dot(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor & linalg_multi_dot_out(TensorList tensors, Tensor & out); // {"schema": "aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"}
+Tensor nested_to_padded_tensor(const Tensor & self, double padding, OptionalIntArrayRef output_size); // {"schema": "aten::nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_serialization_subcmul(const Tensor & self, const Tensor & other, const Scalar & alpha); // {"schema": "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_parallel_materialize(const Tensor & self, int64_t num_parallel, bool skip_first); // {"schema": "aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _test_optional_intlist(const Tensor & values, OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _test_optional_filled_intlist(const Tensor & values, OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _test_optional_floatlist(const Tensor & values, c10::optional<ArrayRef<double>> addends); // {"schema": "aten::_test_optional_floatlist(Tensor values, float[]? addends) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _test_string_default(const Tensor & dummy, c10::string_view a, c10::string_view b); // {"schema": "aten::_test_string_default(Tensor dummy, str a=\"\\\"'\\\\\", str b='\"\\'\\\\') -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_ambiguous_defaults(const Tensor & dummy, int64_t a, int64_t b); // {"schema": "aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_ambiguous_defaults(const Tensor & dummy, int64_t a, c10::string_view b); // {"schema": "aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b=\"2\") -> Tensor", "dispatch": "False", "default": "True"}
+Tensor _test_warn_in_autograd(const Tensor & self); // {"schema": "aten::_test_warn_in_autograd(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _test_autograd_multiple_dispatch(const Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _test_autograd_multiple_dispatch(const Tensor & self, bool b); // {"schema": "aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _test_autograd_multiple_dispatch_view(const Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"}
+Tensor _test_autograd_multiple_dispatch_view_copy(const Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor segment_reduce(const Tensor & data, c10::string_view reduce, const c10::optional & lengths, const c10::optional & indices, const c10::optional & offsets, int64_t axis, bool unsafe, const c10::optional & initial); // {"schema": "aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _segment_reduce_backward(const Tensor & grad, const Tensor & output, const Tensor & data, c10::string_view reduce, const c10::optional & lengths, const c10::optional & offsets, int64_t axis, const c10::optional & initial); // {"schema": "aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value); // {"schema": "aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor", "dispatch": "False", "default": "True"}
+Tensor flatten_dense_tensors(TensorList tensors); // {"schema": "aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"}
+::std::vector unflatten_dense_tensors(const Tensor & flat, TensorList tensors); // {"schema": "aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"}
+Tensor _nested_tensor_from_tensor_list(TensorList list, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); // {"schema": "aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _fw_primal_copy(const Tensor & self, int64_t level); // {"schema": "aten::_fw_primal_copy(Tensor self, int level) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _make_dual_copy(const Tensor & primal, const Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor view_as_real_copy(const Tensor & self); // {"schema": "aten::view_as_real_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor view_as_complex_copy(const Tensor & self); // {"schema": "aten::view_as_complex_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _conj_copy(const Tensor & self); // {"schema": "aten::_conj_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _neg_view_copy(const Tensor & self); // {"schema": "aten::_neg_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor as_strided_copy(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional storage_offset); // {"schema": "aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _sparse_broadcast_to_copy(const Tensor & self, IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor diagonal_copy(const Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor expand_copy(const Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor permute_copy(const Tensor & self, IntArrayRef dims); // {"schema": "aten::permute_copy(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _reshape_alias_copy(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor select_copy(const Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor detach_copy(const Tensor & self); // {"schema": "aten::detach_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor slice_copy(const Tensor & self, int64_t dim, c10::optional start, c10::optional end, c10::SymInt step); // {"schema": "aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"}
+::std::vector split_copy(const Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"}
+::std::vector split_with_sizes_copy(const Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"}
+Tensor squeeze_copy(const Tensor & self); // {"schema": "aten::squeeze_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor squeeze_copy(const Tensor & self, int64_t dim); // {"schema": "aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor squeeze_copy(const Tensor & self, IntArrayRef dim); // {"schema": "aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor t_copy(const Tensor & self); // {"schema": "aten::t_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor transpose_copy(const Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor unsqueeze_copy(const Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_copy(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _indices_copy(const Tensor & self); // {"schema": "aten::_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor _values_copy(const Tensor & self); // {"schema": "aten::_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor indices_copy(const Tensor & self); // {"schema": "aten::indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor values_copy(const Tensor & self); // {"schema": "aten::values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor crow_indices_copy(const Tensor & self); // {"schema": "aten::crow_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor col_indices_copy(const Tensor & self); // {"schema": "aten::col_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor ccol_indices_copy(const Tensor & self); // {"schema": "aten::ccol_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor row_indices_copy(const Tensor & self); // {"schema": "aten::row_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+::std::vector unbind_copy(const Tensor & self, int64_t dim); // {"schema": "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"}
+void unbind_copy_out(const Tensor & self, int64_t dim, TensorList out); // {"schema": "aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void split_copy_out(const Tensor & self, c10::SymInt split_size, int64_t dim, TensorList out); // {"schema": "aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void split_with_sizes_copy_out(const Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, TensorList out); // {"schema": "aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+Tensor view_copy(const Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor view_copy(const Tensor & self, ScalarType dtype); // {"schema": "aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor unfold_copy(const Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor alias_copy(const Tensor & self); // {"schema": "aten::alias_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
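+// [Editor's note: illustrative sketch, not part of the generated header.] The
+// *_copy ops above are the non-aliasing counterparts of the usual view ops: they
+// return fresh tensors instead of views, which functionalized/export paths rely
+// on. Assuming the at:: bindings:
+//
+//   at::Tensor A = at::randn({4, 4});
+//   at::Tensor d = at::diagonal_copy(A);  // owns its storage, unlike A.diagonal()
+//   at::Tensor t = at::t_copy(A);         // transposed copy, not a view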
+Tensor to_padded_tensor(const Tensor & self, double padding, OptionalSymIntArrayRef output_size); // {"schema": "aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _nested_tensor_softmax_with_shape(const Tensor & self, const Tensor & query); // {"schema": "aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor _transformer_encoder_layer_fwd(const Tensor & src, int64_t embed_dim, int64_t num_heads, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const Tensor & norm_weight_1, const Tensor & norm_bias_1, const Tensor & norm_weight_2, const Tensor & norm_bias_2, const Tensor & ffn_weight_1, const Tensor & ffn_bias_1, const Tensor & ffn_weight_2, const Tensor & ffn_bias_2, const c10::optional & mask, c10::optional mask_type); // {"schema": "aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor", "dispatch": "True", "default": "False"}
+::std::tuple _native_multi_head_attention(const Tensor & query, const Tensor & key, const Tensor & value, int64_t embed_dim, int64_t num_head, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, const c10::optional & mask, bool need_weights, bool average_attn_weights, c10::optional mask_type); // {"schema": "aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor scaled_dot_product_attention(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & attn_mask, double dropout_p, bool is_causal, c10::optional scale); // {"schema": "aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor", "dispatch": "False", "default": "True"}
+int64_t _fused_sdp_choice(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & attn_mask, double dropout_p, bool is_causal, c10::optional scale); // {"schema": "aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_attention_math(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & attn_mask, double dropout_p, bool is_causal, const c10::optional & dropout_mask, c10::optional scale); // {"schema": "aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"}
+::std::tuple _scaled_dot_product_flash_attention(const Tensor & query, const Tensor & key, const Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_flash_attention_for_cpu(const Tensor & query, const Tensor & key, const Tensor & value, double dropout_p, bool is_causal, const c10::optional & attn_mask, c10::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_flash_attention_backward(const Tensor & grad_out, const Tensor & query, const Tensor & key, const Tensor & value, const Tensor & out, const Tensor & logsumexp, const Tensor & cum_seq_q, const Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const Tensor & philox_seed, const Tensor & philox_offset, c10::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward(const Tensor & grad_out, const Tensor & query, const Tensor & key, const Tensor & value, const Tensor & out, const Tensor & logsumexp, double dropout_p, bool is_causal, const c10::optional & attn_mask, c10::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_efficient_attention(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, c10::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_efficient_attention_backward(const Tensor & grad_out_, const Tensor & query, const Tensor & key, const Tensor & value, const Tensor & attn_bias, const Tensor & out, const Tensor & logsumexp, const Tensor & philox_seed, const Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal, c10::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple _scaled_dot_product_cudnn_attention(const Tensor & query, const Tensor & key, const Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional scale); // {"schema": "aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)", "dispatch": "True", "default": "False"}
+::std::tuple _flash_attention_forward(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & cum_seq_q, const c10::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, c10::optional scale); // {"schema": "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"}
+::std::tuple _flash_attention_backward(const Tensor & grad_out, const Tensor & query, const Tensor & key, const Tensor & value, const Tensor & out, const Tensor & logsumexp, const Tensor & cum_seq_q, const Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const Tensor & philox_seed, const Tensor & philox_offset, c10::optional scale); // {"schema": "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+::std::tuple _efficient_attention_forward(const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & bias, const c10::optional & cu_seqlens_q, const c10::optional & cu_seqlens_k, c10::optional max_seqlen_q, c10::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp, c10::optional scale, const c10::optional & causal_diagonal, const c10::optional & seqlen_k); // {"schema": "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, int? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)", "dispatch": "True", "default": "False"}
+::std::tuple _efficient_attention_backward(const Tensor & grad_out_, const Tensor & query, const Tensor & key, const Tensor & value, const c10::optional & bias, const Tensor & out, const c10::optional & cu_seqlens_q, const c10::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const Tensor & logsumexp, double dropout_p, const Tensor & philox_seed, const Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, c10::optional scale, c10::optional num_splits_key); // {"schema": "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"}
+Tensor _triton_scaled_dot_attention(const Tensor & q, const Tensor & k, const Tensor & v, double dropout_p); // {"schema": "aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor", "dispatch": "True", "default": "False"}
+Tensor & _fill_mem_eff_dropout_mask_(Tensor & self, double dropout_p, int64_t seed, int64_t offset); // {"schema": "aten::_fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _triton_multi_head_attention(const Tensor & query, const Tensor & key, const Tensor & value, int64_t embed_dim, int64_t num_head, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, const c10::optional & mask); // {"schema": "aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor", "dispatch": "True", "default": "False"}
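+// [Editor's note: illustrative sketch, not part of the generated header.] Of the
+// attention ops above, scaled_dot_product_attention is the public entry point;
+// the _scaled_dot_product_* / _flash_attention_* / _efficient_attention_* kernels
+// are backend-selected implementations. Assuming the at:: bindings:
+//
+//   at::Tensor q = at::randn({2, 8, 128, 64});  // (batch, heads, seq, head_dim)
+//   at::Tensor k = at::randn({2, 8, 128, 64});
+//   at::Tensor v = at::randn({2, 8, 128, 64});
+//   at::Tensor o = at::scaled_dot_product_attention(q, k, v);  // no mask, dropout_p=0, is_causal=false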
+Tensor special_airy_ai(const Tensor & x); // {"schema": "aten::special_airy_ai(Tensor x) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_airy_ai_out(const Tensor & x, Tensor & out); // {"schema": "aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_bessel_j0(const Tensor & self); // {"schema": "aten::special_bessel_j0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_bessel_j0_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_bessel_j1(const Tensor & self); // {"schema": "aten::special_bessel_j1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_bessel_j1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_bessel_y0(const Tensor & self); // {"schema": "aten::special_bessel_y0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_bessel_y0_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_bessel_y1(const Tensor & self); // {"schema": "aten::special_bessel_y1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_bessel_y1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_chebyshev_polynomial_t(const Tensor & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_t(const Scalar & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_t(const Tensor & x, const Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_t_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_chebyshev_polynomial_t_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_t_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_u(const Tensor & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_u(const Scalar & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_u(const Tensor & x, const Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_u_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_chebyshev_polynomial_u_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_u_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_v(const Tensor & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_v(const Scalar & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_v(const Tensor & x, const Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_v_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_chebyshev_polynomial_v_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_v_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_w(const Tensor & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_w(const Scalar & x, const Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_chebyshev_polynomial_w(const Tensor & x, const Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_w_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_chebyshev_polynomial_w_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_chebyshev_polynomial_w_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_h(const Tensor & x, const Tensor & n); // {"schema": "aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_h(const Scalar & x, const Tensor & n); // {"schema": "aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_h(const Tensor & x, const Scalar & n); // {"schema": "aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_hermite_polynomial_h_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_hermite_polynomial_h_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_hermite_polynomial_h_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_he(const Tensor & x, const Tensor & n); // {"schema": "aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_he(const Scalar & x, const Tensor & n); // {"schema": "aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_hermite_polynomial_he(const Tensor & x, const Scalar & n); // {"schema": "aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_hermite_polynomial_he_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_hermite_polynomial_he_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_hermite_polynomial_he_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_laguerre_polynomial_l(const Tensor & x, const Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_laguerre_polynomial_l(const Scalar & x, const Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_laguerre_polynomial_l(const Tensor & x, const Scalar & n); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_laguerre_polynomial_l_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_laguerre_polynomial_l_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_laguerre_polynomial_l_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_legendre_polynomial_p(const Tensor & x, const Tensor & n); // {"schema": "aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_legendre_polynomial_p(const Scalar & x, const Tensor & n); // {"schema": "aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_legendre_polynomial_p(const Tensor & x, const Scalar & n); // {"schema": "aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_legendre_polynomial_p_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_legendre_polynomial_p_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_legendre_polynomial_p_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_modified_bessel_i0(const Tensor & self); // {"schema": "aten::special_modified_bessel_i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_modified_bessel_i0_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_modified_bessel_i1(const Tensor & self); // {"schema": "aten::special_modified_bessel_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_modified_bessel_i1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_modified_bessel_k0(const Tensor & self); // {"schema": "aten::special_modified_bessel_k0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_modified_bessel_k0_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_modified_bessel_k1(const Tensor & self); // {"schema": "aten::special_modified_bessel_k1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_modified_bessel_k1_out(const Tensor & self, Tensor & out); // {"schema": "aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_scaled_modified_bessel_k0(const Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_scaled_modified_bessel_k0_out(const Tensor & x, Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_scaled_modified_bessel_k1(const Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_scaled_modified_bessel_k1_out(const Tensor & x, Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor special_shifted_chebyshev_polynomial_t(const Tensor & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_t(const Scalar & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_t(const Tensor & x, const Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_t_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_shifted_chebyshev_polynomial_t_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_t_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_u(const Tensor & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_u(const Scalar & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_u(const Tensor & x, const Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_u_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_shifted_chebyshev_polynomial_u_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_u_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_v(const Tensor & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_v(const Scalar & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_v(const Tensor & x, const Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_v_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_shifted_chebyshev_polynomial_v_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_v_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_w(const Tensor & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_w(const Scalar & x, const Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor special_shifted_chebyshev_polynomial_w(const Tensor & x, const Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_w_out(const Tensor & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor & special_shifted_chebyshev_polynomial_w_out(const Scalar & x, const Tensor & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & special_shifted_chebyshev_polynomial_w_out(const Tensor & x, const Scalar & n, Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor special_spherical_bessel_j0(const Tensor & x); // {"schema": "aten::special_spherical_bessel_j0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & special_spherical_bessel_j0_out(const Tensor & x, Tensor & out); // {"schema": "aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"}
+Tensor _foobar(const Tensor & self, bool arg1, bool arg2, bool arg3); // {"schema": "aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor", "dispatch": "True", "default": "False"}
+void _fused_adam_(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _fused_adam_(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _fused_adamw_(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _fused_adamw_(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _fused_sgd_(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _fused_sgd_(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, const Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"}
+void _propagate_xla_data(const Tensor & input, const Tensor & output); // {"schema": "aten::_propagate_xla_data(Tensor input, Tensor output) -> ()", "dispatch": "False", "default": "True"}
+Tensor & _new_zeros_with_same_feature_meta_out(const Tensor & self, const Tensor & other, int64_t self_num_batch_dims, Tensor & out); // {"schema": "aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _cudnn_ctc_loss_out(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, Tensor & out0, Tensor & out1); // {"schema": "aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _cudnn_rnn_flatten_weight_out(TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, Tensor & out); // {"schema": "aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &,Tensor &> _cudnn_rnn_out(const Tensor & input, TensorList weight, int64_t weight_stride0, const c10::optional<Tensor> & weight_buf, const Tensor & hx, const c10::optional<Tensor> & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3, Tensor & out4); // {"schema": "aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"}
+void _cudnn_rnn_backward_out(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & weight_buf, const Tensor & hx, const c10::optional<Tensor> & cx, const Tensor & output, const c10::optional<Tensor> & grad_output, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const c10::optional<Tensor> & dropout_state, const Tensor & reserve, ::std::array<bool,4> output_mask, Tensor & out0, Tensor & out1, Tensor & out2, TensorList out3); // {"schema": "aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"}
+Tensor & _cudnn_init_dropout_state_out(double dropout, bool train, int64_t dropout_seed, Tensor & out); // {"schema": "aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _fused_dropout_out(const Tensor & self, double p, c10::optional<Generator> generator, Tensor & out0, Tensor & out1); // {"schema": "aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _masked_scale_out(const Tensor & self, const Tensor & mask, double scale, Tensor & out); // {"schema": "aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> native_dropout_out(const Tensor & input, double p, c10::optional<bool> train, Tensor & out0, Tensor & out1); // {"schema": "aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & native_dropout_backward_out(const Tensor & grad_output, const Tensor & mask, double scale, Tensor & out); // {"schema": "aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _conj_physical_out(const Tensor & self, Tensor & out); // {"schema": "aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _add_relu_out(const Tensor & self, const Scalar & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & add_out(const Tensor & self, const Scalar & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & affine_grid_generator_out(const Tensor & theta, c10::SymIntArrayRef size, bool align_corners, Tensor & out); // {"schema": "aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_functorch_fallback_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bartlett_window_out(int64_t window_length, Tensor & out); // {"schema": "aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bartlett_window_out(int64_t window_length, bool periodic, Tensor & out); // {"schema": "aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantized_batch_norm_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & mean, const Tensor & var, double eps, double output_scale, int64_t output_zero_point, Tensor & out); // {"schema": "aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bernoulli_out(const Tensor & self, const Tensor & p, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor bernoulli(const Tensor & self, const Tensor & p, c10::optional<Generator> generator); // {"schema": "aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & bernoulli_out(const Tensor & self, double p, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & binary_cross_entropy_with_logits_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & pos_weight, int64_t reduction, Tensor & out); // {"schema": "aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bincount_out(const Tensor & self, const c10::optional<Tensor> & weights, int64_t minlength, Tensor & out); // {"schema": "aten::bincount.out(Tensor self, Tensor? weights=None, int minlength=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & blackman_window_out(int64_t window_length, Tensor & out); // {"schema": "aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & blackman_window_out(int64_t window_length, bool periodic, Tensor & out); // {"schema": "aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & block_diag_out(TensorList tensors, Tensor & out); // {"schema": "aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & constant_pad_nd_out(const Tensor & self, c10::SymIntArrayRef pad, const Scalar & value, Tensor & out); // {"schema": "aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & convolution_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, Tensor & out); // {"schema": "aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> convolution_backward_out(const Tensor & grad_output, const Tensor & input, const Tensor & weight, OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & convolution_overrideable_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, Tensor & out); // {"schema": "aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> convolution_backward_overrideable_out(const Tensor & grad_output, const Tensor & input, const Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & _convolution_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, Tensor & out); // {"schema": "aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & conv_tbc_out(const Tensor & self, const Tensor & weight, const Tensor & bias, int64_t pad, Tensor & out); // {"schema": "aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & copy_out(const Tensor & self, const Tensor & src, bool non_blocking, Tensor & out); // {"schema": "aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _copy_from_out(const Tensor & self, const Tensor & dst, bool non_blocking, Tensor & out); // {"schema": "aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _copy_from_and_resize_out(const Tensor & self, const Tensor & dst, Tensor & out); // {"schema": "aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & count_nonzero_out(const Tensor & self, IntArrayRef dim, Tensor & out); // {"schema": "aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & count_nonzero_out(const Tensor & self, c10::optional<int64_t> dim, Tensor & out); // {"schema": "aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cudnn_affine_grid_generator_out(const Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cudnn_affine_grid_generator_backward_out(const Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &> cudnn_batch_norm_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, bool training, double exponential_average_factor, double epsilon, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3); // {"schema": "aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> cudnn_batch_norm_backward_out(const Tensor & input, const Tensor & grad_output, const Tensor & weight, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, const c10::optional<Tensor> & save_mean, const c10::optional<Tensor> & save_var, double epsilon, const Tensor & reserveSpace, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & cudnn_convolution_transpose_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, Tensor & out); // {"schema": "aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _mps_convolution_transpose_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> mps_convolution_transpose_backward_out(const Tensor & self, const Tensor & grad_output, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array<bool,2> output_mask, Tensor & out0, Tensor & out1); // {"schema": "aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & cudnn_convolution_relu_out(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cudnn_convolution_add_relu_out(const Tensor & self, const Tensor & weight, const Tensor & z, const c10::optional<Scalar> & alpha, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & cudnn_grid_sampler_out(const Tensor & self, const Tensor & grid, Tensor & out); // {"schema": "aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> cudnn_grid_sampler_backward_out(const Tensor & self, const Tensor & grid, const Tensor & grad_output, Tensor & out0, Tensor & out1); // {"schema": "aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _ctc_loss_out(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, bool zero_infinity, Tensor & out0, Tensor & out1); // {"schema": "aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _ctc_loss_out(const Tensor & log_probs, const Tensor & targets, const Tensor & input_lengths, const Tensor & target_lengths, int64_t blank, bool zero_infinity, Tensor & out0, Tensor & out1); // {"schema": "aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _ctc_loss_backward_out(const Tensor & grad, const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor & neg_log_likelihood, const Tensor & log_alpha, int64_t blank, bool zero_infinity, Tensor & out); // {"schema": "aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & diag_embed_out(const Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, Tensor & out); // {"schema": "aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & diagonal_backward_out(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, Tensor & out); // {"schema": "aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & div_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & div_out(const Tensor & self, const Scalar & other, c10::optional<c10::string_view> rounding_mode, Tensor & out); // {"schema": "aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & embedding_out(const Tensor & weight, const Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, Tensor & out); // {"schema": "aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & embedding_dense_backward_out(const Tensor & grad_output, const Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, Tensor & out); // {"schema": "aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & embedding_renorm_out(const Tensor & self, const Tensor & indices, double max_norm, double norm_type, Tensor & out); // {"schema": "aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor embedding_renorm(const Tensor & self, const Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &> _embedding_bag_forward_only_out(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<Tensor> & per_sample_weights, bool include_last_offset, int64_t padding_idx, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3); // {"schema": "aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &> _embedding_bag_out(const Tensor & weight, const Tensor & indices, const Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<Tensor> & per_sample_weights, bool include_last_offset, int64_t padding_idx, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3); // {"schema": "aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"}
+Tensor & _embedding_bag_dense_backward_out(const Tensor & grad, const Tensor & indices, const Tensor & offset2bag, const Tensor & bag_size, const Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor> & per_sample_weights, int64_t padding_idx, Tensor & out); // {"schema": "aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _embedding_bag_per_sample_weights_backward_out(const Tensor & grad, const Tensor & weight, const Tensor & indices, const Tensor & offsets, const Tensor & offset2bag, int64_t mode, int64_t padding_idx, Tensor & out); // {"schema": "aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & empty_out(IntArrayRef size, c10::optional<DimnameList> names, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & empty_permuted_out(c10::SymIntArrayRef size, IntArrayRef physical_layout, Tensor & out); // {"schema": "aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & new_empty_out(const Tensor & self, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & new_empty_strided_out(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, Tensor & out); // {"schema": "aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & new_full_out(const Tensor & self, c10::SymIntArrayRef size, const Scalar & fill_value, Tensor & out); // {"schema": "aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & new_zeros_out(const Tensor & self, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & new_ones_out(const Tensor & self, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _empty_affine_quantized_out(c10::SymIntArrayRef size, double scale, int64_t zero_point, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _empty_per_channel_affine_quantized_out(c10::SymIntArrayRef size, const Tensor & scales, const Tensor & zero_points, int64_t axis, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+const Tensor & resize_out(const Tensor & self, c10::SymIntArrayRef size, c10::optional<MemoryFormat> memory_format, const Tensor & out); // {"schema": "aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor resize(const Tensor & self, c10::SymIntArrayRef size, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+const Tensor & _resize_output_out(const Tensor & self, c10::SymIntArrayRef size, Device device, const Tensor & out); // {"schema": "aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _resize_output(const Tensor & self, c10::SymIntArrayRef size, Device device); // {"schema": "aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & empty_quantized_out(IntArrayRef size, const Tensor & qtensor, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & empty_like_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & empty_strided_out(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, Tensor & out); // {"schema": "aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & fill_out(const Tensor & self, const Scalar & value, Tensor & out); // {"schema": "aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & fill_out(const Tensor & self, const Tensor & value, Tensor & out); // {"schema": "aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & floor_divide_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & full_out(IntArrayRef size, const Scalar & fill_value, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & full_like_out(const Tensor & self, const Scalar & fill_value, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & from_file_out(c10::string_view filename, c10::optional<bool> shared, c10::optional<int64_t> size, Tensor & out); // {"schema": "aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & grid_sampler_2d_out(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, Tensor & out); // {"schema": "aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> grid_sampler_2d_backward_out(const Tensor & grad_output, const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask, Tensor & out0, Tensor & out1); // {"schema": "aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _grid_sampler_2d_cpu_fallback_out(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, Tensor & out); // {"schema": "aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & grid_sampler_3d_out(const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, Tensor & out); // {"schema": "aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> grid_sampler_3d_backward_out(const Tensor & grad_output, const Tensor & input, const Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array<bool,2> output_mask, Tensor & out0, Tensor & out1); // {"schema": "aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & hann_window_out(int64_t window_length, Tensor & out); // {"schema": "aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hann_window_out(int64_t window_length, bool periodic, Tensor & out); // {"schema": "aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hamming_window_out(int64_t window_length, Tensor & out); // {"schema": "aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hamming_window_out(int64_t window_length, bool periodic, Tensor & out); // {"schema": "aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, double beta, Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & kaiser_window_out(int64_t window_length, Tensor & out); // {"schema": "aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & kaiser_window_out(int64_t window_length, bool periodic, Tensor & out); // {"schema": "aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & kaiser_window_out(int64_t window_length, bool periodic, double beta, Tensor & out); // {"schema": "aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_group_norm_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_group_norm_backward_out(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & rstd, const c10::optional<Tensor> & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & index_put_out(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate, Tensor & out); // {"schema": "aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _index_put_impl_out(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate, bool unsafe, Tensor & out); // {"schema": "aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _index_put_impl(const Tensor & self, const c10::List<c10::optional<Tensor>> & indices, const Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & isnan_out(const Tensor & self, Tensor & out); // {"schema": "aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_layer_norm_out(const Tensor & input, c10::SymIntArrayRef normalized_shape, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, double eps, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_layer_norm_backward_out(const Tensor & grad_out, const Tensor & input, c10::SymIntArrayRef normalized_shape, const Tensor & mean, const Tensor & rstd, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> linear_backward_out(const Tensor & self, const Tensor & grad_output, const Tensor & weight, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_linear_out(const Tensor & self, const Tensor & weight, const c10::optional<Tensor> & bias, Tensor & out); // {"schema": "aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_linear_backward_input_out(IntArrayRef input_size, const Tensor & grad_output, const Tensor & weight, Tensor & out); // {"schema": "aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> mkldnn_linear_backward_weights_out(const Tensor & grad_output, const Tensor & input, const Tensor & weight, bool bias_defined, Tensor & out0, Tensor & out1); // {"schema": "aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> mkldnn_linear_backward_out(const Tensor & self, const Tensor & grad_output, const Tensor & weight, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> matmul_backward_out(const Tensor & grad, const Tensor & self, const Tensor & other, ::std::array<bool,2> mask, Tensor & out0, Tensor & out1); // {"schema": "aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _aminmax_out(const Tensor & self, Tensor & out0, Tensor & out1); // {"schema": "aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _aminmax_out(const Tensor & self, int64_t dim, bool keepdim, Tensor & out0, Tensor & out1); // {"schema": "aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & max_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_max_pool2d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_max_pool2d_backward_out(const Tensor & grad_output, const Tensor & output, const Tensor & input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_max_pool3d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_max_pool3d_backward_out(const Tensor & grad_output, const Tensor & output, const Tensor & input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantized_max_pool1d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantized_max_pool2d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantized_max_pool3d_out(const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor & out); // {"schema": "aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & median_out(const Tensor & self, Tensor & out); // {"schema": "aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & nanmedian_out(const Tensor & self, Tensor & out); // {"schema": "aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
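+// Editorial note (hedged, not part of the generated header): each declaration above carries
+// its ATen schema string plus two flags in a JSON trailer. This style of generated header is
+// typically consumed by out-of-tree backends; "dispatch": "True" is understood to mean the
+// operator has backend-specific kernels registered, and "default": "True" that a composite
+// default implementation exists which external backends may fall back to.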
+Tensor & _mps_convolution_out(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple mps_convolution_backward_out(const Tensor & self, const Tensor & grad_output, const Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_convolution_out(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple mkldnn_rnn_layer_out(const Tensor & input, const Tensor & weight0, const Tensor & weight1, const Tensor & weight2, const Tensor & weight3, const Tensor & hx_, const Tensor & cx_, bool reverse, IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3); // {"schema": "aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"}
+::std::tuple mkldnn_rnn_layer_backward_out(const Tensor & input, const Tensor & weight1, const Tensor & weight2, const Tensor & weight3, const Tensor & weight4, const Tensor & hx_, const Tensor & cx_tmp, const Tensor & output, const Tensor & hy_, const Tensor & cy_, const c10::optional & grad_output, const c10::optional & grad_hy, const c10::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, IntArrayRef batch_sizes, bool batch_first, const Tensor & workspace, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3, Tensor & out4, Tensor & out5, Tensor & out6); // {"schema": "aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))", "dispatch": "True", "default": "True"}
+::std::tuple miopen_batch_norm_out(const Tensor & input, const Tensor & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double exponential_average_factor, double epsilon, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple miopen_batch_norm_backward_out(const Tensor & input, const Tensor & grad_output, const Tensor & weight, const c10::optional & running_mean, const c10::optional & running_var, const c10::optional & save_mean, const c10::optional & save_var, double epsilon, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & miopen_convolution_out(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, Tensor & out); // {"schema": "aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & miopen_convolution_transpose_out(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, Tensor & out); // {"schema": "aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & miopen_depthwise_convolution_out(const Tensor & self, const Tensor & weight, const c10::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, Tensor & out); // {"schema": "aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple miopen_rnn_out(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & hx, const c10::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const c10::optional & dropout_state, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3, Tensor & out4); // {"schema": "aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"}
+void miopen_rnn_backward_out(const Tensor & input, TensorList weight, int64_t weight_stride0, const Tensor & weight_buf, const Tensor & hx, const c10::optional & cx, const Tensor & output, const c10::optional & grad_output, const c10::optional & grad_hy, const c10::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const c10::optional & dropout_state, const Tensor & reserve, ::std::array output_mask, Tensor & out0, Tensor & out1, Tensor & out2, TensorList out3); // {"schema": "aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"}
+Tensor & _sparse_sparse_matmul_out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mul_out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor> _native_batch_norm_legit_functional(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & running_mean, const Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _native_batch_norm_legit_no_training_out(const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & bias, const Tensor & running_mean, const Tensor & running_var, double momentum, double eps, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> batch_norm_stats_out(const Tensor & input, double eps, Tensor & out0, Tensor & out1); // {"schema": "aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> batch_norm_gather_stats_out(const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum, double eps, int64_t count, Tensor & out0, Tensor & out1); // {"schema": "aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> batch_norm_gather_stats_with_counts_out(const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum, double eps, const Tensor & counts, Tensor & out0, Tensor & out1); // {"schema": "aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> native_batch_norm_backward_out(const Tensor & grad_out, const Tensor & input, const c10::optional<Tensor> & weight, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, const c10::optional<Tensor> & save_mean, const c10::optional<Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &> batch_norm_backward_reduce_out(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & weight, bool input_g, bool weight_g, bool bias_g, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3); // {"schema": "aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"}
+Tensor & batch_norm_backward_elemt_out(const Tensor & grad_out, const Tensor & input, const Tensor & mean, const Tensor & invstd, const c10::optional<Tensor> & weight, const Tensor & sum_dy, const Tensor & sum_dy_xmu, const Tensor & count, Tensor & out); // {"schema": "aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> batch_norm_update_stats_out(const Tensor & input, const c10::optional<Tensor> & running_mean, const c10::optional<Tensor> & running_var, double momentum, Tensor & out0, Tensor & out1); // {"schema": "aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _nnpack_spatial_convolution_out(const Tensor & input, const Tensor & weight, const c10::optional<Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, Tensor & out); // {"schema": "aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ones_out(IntArrayRef size, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ones_like_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _euclidean_dist_out(const Tensor & x1, const Tensor & x2, Tensor & out); // {"schema": "aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _cdist_forward_out(const Tensor & x1, const Tensor & x2, double p, c10::optional<int64_t> compute_mode, Tensor & out); // {"schema": "aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _cdist_backward_out(const Tensor & grad, const Tensor & x1, const Tensor & x2, double p, const Tensor & cdist, Tensor & out); // {"schema": "aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _pdist_forward_out(const Tensor & self, double p, Tensor & out); // {"schema": "aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _pdist_backward_out(const Tensor & grad, const Tensor & self, double p, const Tensor & pdist, Tensor & out); // {"schema": "aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & pixel_shuffle_out(const Tensor & self, int64_t upscale_factor, Tensor & out); // {"schema": "aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & pixel_unshuffle_out(const Tensor & self, int64_t downscale_factor, Tensor & out); // {"schema": "aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & channel_shuffle_out(const Tensor & self, c10::SymInt groups, Tensor & out); // {"schema": "aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _pin_memory_out(const Tensor & self, c10::optional<Device> device, Tensor & out); // {"schema": "aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & scalar_tensor_out(const Scalar & s, Tensor & out); // {"schema": "aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rand_out(c10::SymIntArrayRef size, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rand_out(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rand_like_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randint_like_out(const Tensor & self, c10::SymInt high, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randint_like_out(const Tensor & self, c10::SymInt low, c10::SymInt high, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randn_out(c10::SymIntArrayRef size, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randn_out(c10::SymIntArrayRef size, c10::optional<Generator> generator, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & randn_like_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & repeat_out(const Tensor & self, c10::SymIntArrayRef repeats, Tensor & out); // {"schema": "aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & repeat_interleave_out(const Tensor & repeats, c10::optional<c10::SymInt> output_size, Tensor & out); // {"schema": "aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _mkldnn_reshape_out(const Tensor & self, IntArrayRef shape, Tensor & out); // {"schema": "aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & relu_out(const Tensor & self, Tensor & out); // {"schema": "aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & select_backward_out(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, Tensor & out); // {"schema": "aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & celu_out(const Tensor & self, const Scalar & alpha, Tensor & out); // {"schema": "aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & slice_backward_out(const Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, Tensor & out); // {"schema": "aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & slice_scatter_out(const Tensor & self, const Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step, Tensor & out); // {"schema": "aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & select_scatter_out(const Tensor & self, const Tensor & src, int64_t dim, c10::SymInt index, Tensor & out); // {"schema": "aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & diagonal_scatter_out(const Tensor & self, const Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, Tensor & out); // {"schema": "aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & as_strided_scatter_out(const Tensor & self, const Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset, Tensor & out); // {"schema": "aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void unsafe_split_out(const Tensor & self, c10::SymInt split_size, int64_t dim, TensorList out); // {"schema": "aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void unsafe_split_with_sizes_out(const Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, TensorList out); // {"schema": "aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+Tensor & sum_out(const Tensor & self, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> std_mean_out(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out0, Tensor & out1); // {"schema": "aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & prod_out(const Tensor & self, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _mkldnn_transpose_out(const Tensor & self, int64_t dim0, int64_t dim1, Tensor & out); // {"schema": "aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & flip_out(const Tensor & self, IntArrayRef dims, Tensor & out); // {"schema": "aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & roll_out(const Tensor & self, c10::SymIntArrayRef shifts, IntArrayRef dims, Tensor & out); // {"schema": "aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rot90_out(const Tensor & self, int64_t k, IntArrayRef dims, Tensor & out); // {"schema": "aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _transform_bias_rescale_qkv_out(const Tensor & qkv, const Tensor & qkv_bias, int64_t num_heads, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & _nested_tensor_from_mask_out(const Tensor & t, const Tensor & mask, bool mask_check, Tensor & out); // {"schema": "aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_from_padded_out(const Tensor & padded, const Tensor & cpu_nested_shape_example, bool fuse_transform_0213, Tensor & out); // {"schema": "aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_tensor_size_out(const Tensor & self, Tensor & out); // {"schema": "aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_tensor_strides_out(const Tensor & self, Tensor & out); // {"schema": "aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_tensor_storage_offsets_out(const Tensor & self, Tensor & out); // {"schema": "aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_from_padded_and_nested_example_out(const Tensor & padded, const Tensor & nt_example, Tensor & out); // {"schema": "aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_view_from_buffer_copy_out(const Tensor & self, const Tensor & nested_size, const Tensor & nested_strides, const Tensor & offsets, Tensor & out); // {"schema": "aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_view_from_jagged_copy_out(const Tensor & self, const Tensor & offsets, const Tensor & dummy, const c10::optional<Tensor> & lengths, int64_t ragged_idx, Tensor & out); // {"schema": "aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_get_values_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _trilinear_out(const Tensor & i1, const Tensor & i2, const Tensor & i3, IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim, Tensor & out); // {"schema": "aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
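+// Hedged usage sketch (not part of the generated file): the multi-output "*_out" variants
+// declared below, such as _unique_out, bind one caller-allocated tensor per result. Through
+// the ATen C++ frontend the out tensors lead the argument list, e.g.:
+//   at::Tensor self = at::randint(0, 5, {10});
+//   at::Tensor values = at::empty({0}, self.options());
+//   at::Tensor inverse = at::empty({0}, self.options().dtype(at::kLong));
+//   at::_unique_out(values, inverse, self, /*sorted=*/true, /*return_inverse=*/true);
+// The trailing out0/out1 parameters shown in these declarations reflect the internal
+// dispatcher ordering, which the frontend exposes via the generated "_outf" overloads.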
+::std::tuple<Tensor &,Tensor &> _unique_out(const Tensor & self, bool sorted, bool return_inverse, Tensor & out0, Tensor & out1); // {"schema": "aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> unique_dim_out(const Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> unique_consecutive_out(const Tensor & self, bool return_inverse, bool return_counts, c10::optional<int64_t> dim, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> unique_dim_consecutive_out(const Tensor & self, int64_t dim, bool return_inverse, bool return_counts, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _unique2_out(const Tensor & self, bool sorted, bool return_inverse, bool return_counts, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & _unsafe_view_out(const Tensor & self, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> var_mean_out(const Tensor & self, OptionalIntArrayRef dim, const c10::optional<Scalar> & correction, bool keepdim, Tensor & out0, Tensor & out1); // {"schema": "aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _weight_norm_interface_out(const Tensor & v, const Tensor & g, int64_t dim, Tensor & out0, Tensor & out1); // {"schema": "aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _weight_norm_interface_backward_out(const Tensor & grad_w, const Tensor & saved_v, const Tensor & saved_g, const Tensor & saved_norms, int64_t dim, Tensor & out0, Tensor & out1); // {"schema": "aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & zeros_out(IntArrayRef size, c10::optional<DimnameList> names, Tensor & out); // {"schema": "aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _efficientzerotensor_out(c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & zeros_like_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _standard_gamma_grad_out(const Tensor & self, const Tensor & output, Tensor & out); // {"schema": "aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _standard_gamma_out(const Tensor & self, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _dirichlet_grad_out(const Tensor & x, const Tensor & alpha, const Tensor & total, Tensor & out); // {"schema": "aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sample_dirichlet_out(const Tensor & self, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & poisson_out(const Tensor & self, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & binomial_out(const Tensor & count, const Tensor & prob, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & native_norm_out(const Tensor & self, const Scalar & p, Tensor & out); // {"schema": "aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & native_norm_out(const Tensor & self, const c10::optional<Scalar> & p, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_sum_out(const Tensor & self, IntArrayRef dim, Tensor & out); // {"schema": "aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_sum_backward_out(const Tensor & grad, const Tensor & self, IntArrayRef dim, Tensor & out); // {"schema": "aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_csr_sum_out(const Tensor & self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_csr_prod_out(const Tensor & self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_softmax_out(const Tensor & self, int64_t dim, bool half_to_float, Tensor & out); // {"schema": "aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_softmax_backward_data_out(const Tensor & grad_output, const Tensor & output, int64_t dim, const Tensor & self, Tensor & out); // {"schema": "aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_log_softmax_out(const Tensor & self, int64_t dim, bool half_to_float, Tensor & out); // {"schema": "aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_log_softmax_backward_data_out(const Tensor & grad_output, const Tensor & output, int64_t dim, const Tensor & self, Tensor & out); // {"schema": "aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _spdiags_out(const Tensor & diagonals, const Tensor & offsets, IntArrayRef shape, c10::optional<Layout> layout, Tensor & out); // {"schema": "aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & norm_out(const Tensor & self, const c10::optional<Scalar> & p, ScalarType dtype, Tensor & out); // {"schema": "aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & norm_out(const Tensor & self, const Scalar & p, Tensor & out); // {"schema": "aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & clone_out(const Tensor & self, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+const Tensor & resize_as_out(const Tensor & self, const Tensor & the_template, c10::optional<MemoryFormat> memory_format, const Tensor & out); // {"schema": "aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor resize_as(const Tensor & self, const Tensor & the_template, c10::optional<MemoryFormat> memory_format); // {"schema": "aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"}
+const Tensor & resize_as_sparse_out(const Tensor & self, const Tensor & the_template, const Tensor & out); // {"schema": "aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor resize_as_sparse(const Tensor & self, const Tensor & the_template); // {"schema": "aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & zero_out(const Tensor & self, Tensor & out); // {"schema": "aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor zero(const Tensor & self); // {"schema": "aten::zero(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sub_out(const Tensor & self, const Scalar & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rsub_out(const Tensor & self, const Tensor & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rsub_out(const Tensor & self, const Scalar & other, const Scalar & alpha, Tensor & out); // {"schema": "aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_addmm_out(const Tensor & self, const Tensor & mat1, const Tensor & mat2, const Scalar & beta, const Scalar & alpha, Tensor & out); // {"schema": "aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & sparse_coo_tensor_out(IntArrayRef size, Tensor & out); // {"schema": "aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_coo_tensor_with_dims_out(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size, Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const Tensor & indices, const Tensor & values, c10::optional<bool> is_coalesced, Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+const Tensor & sparse_resize_out(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const Tensor & out); // {"schema": "aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor sparse_resize(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"}
+const Tensor & sparse_resize_and_clear_out(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const Tensor & out); // {"schema": "aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor sparse_resize_and_clear(const Tensor & self, IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & sparse_mask_out(const Tensor & self, const Tensor & mask, Tensor & out); // {"schema": "aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_mask_projection_out(const Tensor & self, const Tensor & mask, bool accumulate_matches, Tensor & out); // {"schema": "aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_dense_out(const Tensor & self, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad, Tensor & out); // {"schema": "aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _coalesce_out(const Tensor & self, Tensor & out); // {"schema": "aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _coalesced_out(const Tensor & self, bool coalesced, Tensor & out); // {"schema": "aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor _coalesced(const Tensor & self, bool coalesced); // {"schema": "aten::_coalesced(Tensor self, bool coalesced) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & copy_sparse_to_sparse_out(const Tensor & self, const Tensor & src, bool non_blocking, Tensor & out); // {"schema": "aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor copy_sparse_to_sparse(const Tensor & self, const Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_out(const Tensor & self, int64_t sparse_dim, Tensor & out); // {"schema": "aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_out(const Tensor & self, c10::optional<Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim, Tensor & out); // {"schema": "aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_csr_out(const Tensor & self, c10::optional<int64_t> dense_dim, Tensor & out); // {"schema": "aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_csc_out(const Tensor & self, c10::optional<int64_t> dense_dim, Tensor & out); // {"schema": "aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_bsr_out(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim, Tensor & out); // {"schema": "aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _to_sparse_bsc_out(const Tensor & self, IntArrayRef blocksize, c10::optional<int64_t> dense_dim, Tensor & out); // {"schema": "aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & to_mkldnn_out(const Tensor & self, c10::optional<ScalarType> dtype, Tensor & out); // {"schema": "aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_reorder_conv2d_weight_out(const Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, OptionalSymIntArrayRef input_size, Tensor & out); // {"schema": "aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_reorder_conv3d_weight_out(const Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, Tensor & out); // {"schema": "aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantize_per_tensor_dynamic_out(const Tensor & self, ScalarType dtype, bool reduce_range, Tensor & out); // {"schema": "aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantize_per_tensor_out(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype, Tensor & out); // {"schema": "aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & quantize_per_tensor_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, ScalarType dtype, Tensor & out); // {"schema": "aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void quantize_per_tensor_out(TensorList tensors, const Tensor & scales, const Tensor & zero_points, ScalarType dtype, TensorList out); // {"schema": "aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+Tensor & quantize_per_channel_out(const Tensor & self, const Tensor & scales, const Tensor & zero_points, int64_t axis, ScalarType dtype, Tensor & out); // {"schema": "aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & dequantize_out(const Tensor & self, Tensor & out); // {"schema": "aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void dequantize_out(TensorList tensors, TensorList out); // {"schema": "aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+Tensor & q_per_channel_scales_out(const Tensor & self, Tensor & out); // {"schema": "aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & q_per_channel_zero_points_out(const Tensor & self, Tensor & out); // {"schema": "aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & int_repr_out(const Tensor & self, Tensor & out); // {"schema": "aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _make_per_tensor_quantized_tensor_out(const Tensor & self, double scale, int64_t zero_point, Tensor & out); // {"schema": "aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _make_per_channel_quantized_tensor_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, Tensor & out); // {"schema": "aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> fake_quantize_per_tensor_affine_cachemask_out(const Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, Tensor & out0, Tensor & out1); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, const Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, Tensor & out0, Tensor & out1); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _fake_quantize_learnable_per_tensor_affine_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> fake_quantize_per_channel_affine_cachemask_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, Tensor & out0, Tensor & out1); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _fake_quantize_learnable_per_channel_affine_out(const Tensor & self, const Tensor & scale, const Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _fused_moving_avg_obs_fq_helper_out(const Tensor & self, const Tensor & observer_on, const Tensor & fake_quant_on, Tensor & running_min, Tensor & running_max, Tensor & scale, Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, Tensor & out0, Tensor & out1); // {"schema": "aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor,Tensor> _fused_moving_avg_obs_fq_helper_functional(const Tensor & self, const Tensor & observer_on, const Tensor & fake_quant_on, const Tensor & running_min, const Tensor & running_max, const Tensor & scale, const Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)", "dispatch": "True", "default": "True"}
+Tensor & _to_copy_out(const Tensor & self, bool non_blocking, c10::optional<MemoryFormat> memory_format, Tensor & out); // {"schema": "aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &,Tensor &,Tensor &> _lstm_mps_out(const Tensor & input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3, Tensor & out4, Tensor & out5); // {"schema": "aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"}
+void lstm_mps_backward_out(const c10::optional<Tensor> & grad_y, const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & z_state, const Tensor & cell_state_fwd, const Tensor & input, const Tensor & layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, Tensor & out0, TensorList out1, TensorList out2); // {"schema": "aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _thnn_fused_lstm_cell_out(const Tensor & input_gates, const Tensor & hidden_gates, const Tensor & cx, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _thnn_fused_lstm_cell_backward_impl_out(const c10::optional<Tensor> & grad_hy, const c10::optional<Tensor> & grad_cy, const Tensor & cx, const Tensor & cy, const Tensor & workspace, bool has_bias, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _thnn_fused_gru_cell_out(const Tensor & input_gates, const Tensor & hidden_gates, const Tensor & hx, const c10::optional<Tensor> & input_bias, const c10::optional<Tensor> & hidden_bias, Tensor & out0, Tensor & out1); // {"schema": "aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &,Tensor &,Tensor &> _thnn_fused_gru_cell_backward_out(const Tensor & grad_hy, const Tensor & workspace, bool has_bias, Tensor & out0, Tensor & out1, Tensor & out2, Tensor & out3, Tensor & out4); // {"schema": "aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _pack_padded_sequence_out(const Tensor & input, const Tensor & lengths, bool batch_first, Tensor & out0, Tensor & out1); // {"schema": "aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & set_out(const Tensor & self, Storage source, Tensor & out); // {"schema": "aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor set(const Tensor & self, Storage source); // {"schema": "aten::set.source_Storage(Tensor self, Storage source) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & set_out(const Tensor & self, Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, Tensor & out); // {"schema": "aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor set(const Tensor & self, Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & set_out(const Tensor & self, const Tensor & source, Tensor & out); // {"schema": "aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor set(const Tensor & self, const Tensor & source); // {"schema": "aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & set_out(const Tensor & self, Tensor & out); // {"schema": "aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor set(const Tensor & self); // {"schema": "aten::set(Tensor self) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & lift_out(const Tensor & self, Tensor & out); // {"schema": "aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & lift_fresh_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & masked_fill_out(const Tensor & self, const Tensor & mask, const Scalar & value, Tensor & out); // {"schema": "aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & masked_fill_out(const Tensor & self, const Tensor & mask, const Tensor & value, Tensor & out); // {"schema": "aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & masked_scatter_out(const Tensor & self, const Tensor & mask, const Tensor & source, Tensor & out); // {"schema": "aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _masked_softmax_out(const Tensor & self, const Tensor & mask, c10::optional<int64_t> dim, c10::optional<int64_t> mask_type, Tensor & out); // {"schema": "aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _masked_softmax_backward_out(const Tensor & grad_output, const Tensor & output, const Tensor & mask, c10::optional<int64_t> dim, Tensor & out); // {"schema": "aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & put_out(const Tensor & self, const Tensor & index, const Tensor & source, bool accumulate, Tensor & out); // {"schema": "aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & index_fill_out(const Tensor & self, int64_t dim, const Tensor & index, const Scalar & value, Tensor & out); // {"schema": "aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & index_fill_out(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & value, Tensor & out); // {"schema": "aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_and_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_or_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_xor_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & __lshift___out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & __lshift___out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_left_shift_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & __rshift___out(const Tensor & self, const Scalar & other, Tensor & out); // {"schema": "aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & __rshift___out(const Tensor & self, const Tensor & other, Tensor & out); // {"schema": "aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & bitwise_right_shift_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & random_out(const Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor random(const Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<Generator> generator); // {"schema": "aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & random_out(const Tensor & self, int64_t to, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor random(const Tensor & self, int64_t to, c10::optional<Generator> generator); // {"schema": "aten::random.to(Tensor self, int to, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & random_out(const Tensor & self, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor random(const Tensor & self, c10::optional<Generator> generator); // {"schema": "aten::random(Tensor self, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & uniform_out(const Tensor & self, double from, double to, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor uniform(const Tensor & self, double from, double to, c10::optional<Generator> generator); // {"schema": "aten::uniform(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & cauchy_out(const Tensor & self, double median, double sigma, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor cauchy(const Tensor & self, double median, double sigma, c10::optional<Generator> generator); // {"schema": "aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & log_normal_out(const Tensor & self, double mean, double std, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor log_normal(const Tensor & self, double mean, double std, c10::optional<Generator> generator); // {"schema": "aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & exponential_out(const Tensor & self, double lambd, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor exponential(const Tensor & self, double lambd, c10::optional<Generator> generator); // {"schema": "aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
+Tensor & geometric_out(const Tensor & self, double p, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor geometric(const Tensor & self, double p, c10::optional<Generator> generator); // {"schema": "aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"}
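+// NOTE (editorial sketch, not part of the upstream generated file): the random/uniform/cauchy/
+// log_normal/exponential/geometric declarations above are the out-of-place counterparts of the
+// in-place RNG methods (Tensor::uniform_, Tensor::exponential_, ...). A kernel registered against
+// the uniform.out signature above would be invoked roughly like this (names are illustrative):
+//
+//   Tensor src = at::zeros({4});
+//   Tensor dst = at::empty_like(src);
+//   // fill dst with samples from U(0, 1); the generator may be c10::nullopt for the default RNG
+//   uniform_out(src, /*from=*/0.0, /*to=*/1.0, c10::nullopt, dst);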
+Tensor & tril_indices_out(int64_t row, int64_t col, int64_t offset, Tensor & out); // {"schema": "aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & triu_indices_out(int64_t row, int64_t col, int64_t offset, Tensor & out); // {"schema": "aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & trace_out(const Tensor & self, Tensor & out); // {"schema": "aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _cholesky_solve_helper_out(const Tensor & self, const Tensor & A, bool upper, Tensor & out); // {"schema": "aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & dist_out(const Tensor & self, const Tensor & other, const Scalar & p, Tensor & out); // {"schema": "aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void _histogramdd_bin_edges_out(const Tensor & self, IntArrayRef bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density, TensorList out); // {"schema": "aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+Tensor & _histogramdd_from_bin_cts_out(const Tensor & self, IntArrayRef bins, c10::optional<ArrayRef<double>> range, const c10::optional<Tensor> & weight, bool density, Tensor & out); // {"schema": "aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _histogramdd_from_bin_tensors_out(const Tensor & self, TensorList bins, const c10::optional<Tensor> & weight, bool density, Tensor & out); // {"schema": "aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & remainder_out(const Scalar & self, const Tensor & other, Tensor & out); // {"schema": "aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & argsort_out(const Tensor & self, bool stable, int64_t dim, bool descending, Tensor & out); // {"schema": "aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & unfold_backward_out(const Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, Tensor & out); // {"schema": "aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & normal_out(const Tensor & self, double mean, double std, c10::optional<Generator> generator, Tensor & out); // {"schema": "aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void _amp_foreach_non_finite_check_and_unscale_out(TensorList self, Tensor & found_inf, const Tensor & inv_scale, TensorList out); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector<Tensor>,Tensor> _amp_foreach_non_finite_check_and_unscale(TensorList self, const Tensor & found_inf, const Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)", "dispatch": "True", "default": "True"}
+Tensor & _amp_update_scale_out(const Tensor & self, Tensor & growth_tracker, const Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, Tensor & out); // {"schema": "aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor,Tensor> _amp_update_scale(const Tensor & self, const Tensor & growth_tracker, const Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)", "dispatch": "True", "default": "True"}
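+// NOTE (editorial): _amp_foreach_non_finite_check_and_unscale and _amp_update_scale back
+// torch.cuda.amp's GradScaler: the first unscales a whole TensorList of gradients in one fused
+// pass and sets found_inf if any inf/NaN is seen, the second grows or backs off the loss scale
+// based on that flag. The .out / functional pairs above are the non-mutating forms used by
+// functionalization; only the aliasing differs, not the math.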
+void _foreach_add_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_add_out(TensorList self, TensorList other, const Scalar & alpha, TensorList out); // {"schema": "aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_add_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_add_out(TensorList self, const Tensor & other, const Scalar & alpha, TensorList out); // {"schema": "aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sub_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sub_out(TensorList self, TensorList other, const Scalar & alpha, TensorList out); // {"schema": "aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sub_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_mul_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_mul_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_mul_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_mul_out(TensorList self, const Tensor & other, TensorList out); // {"schema": "aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_div_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_div_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_div_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_div_out(TensorList self, const Tensor & other, TensorList out); // {"schema": "aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_max_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_max_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_max_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_min_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_min_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_clamp_min_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_maximum_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_maximum_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_maximum_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_minimum_out(TensorList self, const Scalar & scalar, TensorList out); // {"schema": "aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_minimum_out(TensorList self, TensorList other, TensorList out); // {"schema": "aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_minimum_out(TensorList self, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcdiv_out(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value, TensorList out); // {"schema": "aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcdiv_out(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcdiv_out(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars, TensorList out); // {"schema": "aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcmul_out(TensorList self, TensorList tensor1, TensorList tensor2, const Scalar & value, TensorList out); // {"schema": "aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcmul_out(TensorList self, TensorList tensor1, TensorList tensor2, ArrayRef<Scalar> scalars, TensorList out); // {"schema": "aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_addcmul_out(TensorList self, TensorList tensor1, TensorList tensor2, const Tensor & scalars, TensorList out); // {"schema": "aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_abs_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_acos_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_asin_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_atan_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_ceil_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_cos_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_cosh_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_erf_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_erfc_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_exp_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_expm1_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_floor_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_frac_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_lerp_out(TensorList self, TensorList tensors1, TensorList weights, TensorList out); // {"schema": "aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_lerp_out(TensorList self, TensorList tensors1, const Scalar & weight, TensorList out); // {"schema": "aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_lgamma_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_log_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_log10_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_log1p_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_log2_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_neg_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_norm_out(TensorList self, const Scalar & ord, TensorList out); // {"schema": "aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_pow_out(TensorList self, TensorList exponent, TensorList out); // {"schema": "aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_pow_out(TensorList self, const Scalar & exponent, TensorList out); // {"schema": "aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_pow_out(TensorList self, ArrayRef<Scalar> exponent, TensorList out); // {"schema": "aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_reciprocal_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_round_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sigmoid_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sign_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sin_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sinh_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_sqrt_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_tan_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_tanh_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_trunc_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+void _foreach_zero_out(TensorList self, TensorList out); // {"schema": "aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> _foreach_zero(TensorList self); // {"schema": "aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out", "dispatch": "True", "default": "True"}
+void _foreach_copy_out(TensorList self, TensorList src, bool non_blocking, TensorList out); // {"schema": "aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::vector<Tensor> _foreach_copy(TensorList self, TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out", "dispatch": "True", "default": "True"}
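+// NOTE (editorial sketch): the _foreach_* declarations above are fused "horizontal" ops that
+// apply one elementwise operation across an entire TensorList in a few kernel launches instead
+// of one launch per tensor; torch.optim's foreach code paths dispatch to them. Calling the
+// Scalar overload of _foreach_add.out declared above directly would look roughly like:
+//
+//   std::vector<Tensor> params = {at::ones({2}), at::ones({3})};
+//   std::vector<Tensor> outs   = {at::empty({2}), at::empty({3})};   // pre-allocated outputs
+//   _foreach_add_out(params, /*scalar=*/1, outs);                    // outs[i] = params[i] + 1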
+Tensor & bucketize_out(const Scalar & self, const Tensor & boundaries, bool out_int32, bool right, Tensor & out); // {"schema": "aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & glu_jvp_out(const Tensor & glu, const Tensor & x, const Tensor & dx, int64_t dim, Tensor & out); // {"schema": "aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & glu_backward_jvp_out(const Tensor & grad_x, const Tensor & grad_glu, const Tensor & x, const Tensor & dgrad_glu, const Tensor & dx, int64_t dim, Tensor & out); // {"schema": "aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & hardswish_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & out); // {"schema": "aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & rrelu_with_noise_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar & lower, const Scalar & upper, bool training, bool self_is_result, Tensor & out); // {"schema": "aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & mkldnn_adaptive_avg_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _adaptive_avg_pool2d_out(const Tensor & self, c10::SymIntArrayRef output_size, Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _adaptive_avg_pool2d_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _adaptive_avg_pool3d_out(const Tensor & self, c10::SymIntArrayRef output_size, Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _adaptive_avg_pool3d_backward_out(const Tensor & grad_output, const Tensor & self, Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &,Tensor &> _slow_conv2d_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array<bool,3> output_mask, Tensor & out0, Tensor & out1, Tensor & out2); // {"schema": "aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"}
+Tensor & conv_depthwise3d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, Tensor & out); // {"schema": "aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & slow_conv_dilated2d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, Tensor & out); // {"schema": "aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & slow_conv_dilated3d_out(const Tensor & self, const Tensor & weight, c10::SymIntArrayRef kernel_size, const c10::optional<Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, Tensor & out); // {"schema": "aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & isinf_out(const Tensor & self, Tensor & out); // {"schema": "aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & linalg_matrix_exp_out(const Tensor & self, Tensor & out); // {"schema": "aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_optional_intlist_out(const Tensor & values, OptionalIntArrayRef addends, Tensor & out); // {"schema": "aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_optional_filled_intlist_out(const Tensor & values, OptionalIntArrayRef addends, Tensor & out); // {"schema": "aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_optional_floatlist_out(const Tensor & values, c10::optional<ArrayRef<double>> addends, Tensor & out); // {"schema": "aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_warn_in_autograd_out(const Tensor & self, Tensor & out); // {"schema": "aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_autograd_multiple_dispatch_out(const Tensor & self, Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _test_autograd_multiple_dispatch_view_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & segment_reduce_out(const Tensor & data, c10::string_view reduce, const c10::optional<Tensor> & lengths, const c10::optional<Tensor> & indices, const c10::optional<Tensor> & offsets, int64_t axis, bool unsafe, const c10::optional<Scalar> & initial, Tensor & out); // {"schema": "aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _segment_reduce_backward_out(const Tensor & grad, const Tensor & output, const Tensor & data, c10::string_view reduce, const c10::optional<Tensor> & lengths, const c10::optional<Tensor> & offsets, int64_t axis, const c10::optional<Scalar> & initial, Tensor & out); // {"schema": "aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _nested_tensor_from_tensor_list_out(TensorList list, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, Tensor & out); // {"schema": "aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _fw_primal_copy_out(const Tensor & self, int64_t level, Tensor & out); // {"schema": "aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _make_dual_copy_out(const Tensor & primal, const Tensor & tangent, int64_t level, Tensor & out); // {"schema": "aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & view_as_real_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & view_as_complex_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _conj_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _neg_view_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & as_strided_copy_out(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset, Tensor & out); // {"schema": "aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _sparse_broadcast_to_copy_out(const Tensor & self, IntArrayRef size, Tensor & out); // {"schema": "aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & diagonal_copy_out(const Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, Tensor & out); // {"schema": "aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & expand_copy_out(const Tensor & self, c10::SymIntArrayRef size, bool implicit, Tensor & out); // {"schema": "aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & permute_copy_out(const Tensor & self, IntArrayRef dims, Tensor & out); // {"schema": "aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _reshape_alias_copy_out(const Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, Tensor & out); // {"schema": "aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & select_copy_out(const Tensor & self, int64_t dim, c10::SymInt index, Tensor & out); // {"schema": "aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & detach_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & slice_copy_out(const Tensor & self, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step, Tensor & out); // {"schema": "aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_copy_out(const Tensor & self, int64_t dim, Tensor & out); // {"schema": "aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & squeeze_copy_out(const Tensor & self, IntArrayRef dim, Tensor & out); // {"schema": "aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & t_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & transpose_copy_out(const Tensor & self, int64_t dim0, int64_t dim1, Tensor & out); // {"schema": "aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & unsqueeze_copy_out(const Tensor & self, int64_t dim, Tensor & out); // {"schema": "aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _values_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & values_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & crow_indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & col_indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & ccol_indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & row_indices_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & view_copy_out(const Tensor & self, c10::SymIntArrayRef size, Tensor & out); // {"schema": "aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & view_copy_out(const Tensor & self, ScalarType dtype, Tensor & out); // {"schema": "aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & unfold_copy_out(const Tensor & self, int64_t dimension, int64_t size, int64_t step, Tensor & out); // {"schema": "aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & alias_copy_out(const Tensor & self, Tensor & out); // {"schema": "aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & to_padded_tensor_out(const Tensor & self, double padding, OptionalSymIntArrayRef output_size, Tensor & out); // {"schema": "aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _transformer_encoder_layer_fwd_out(const Tensor & src, int64_t embed_dim, int64_t num_heads, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const Tensor & norm_weight_1, const Tensor & norm_bias_1, const Tensor & norm_weight_2, const Tensor & norm_bias_2, const Tensor & ffn_weight_1, const Tensor & ffn_bias_1, const Tensor & ffn_weight_2, const Tensor & ffn_bias_2, const c10::optional<Tensor> & mask, c10::optional<int64_t> mask_type, Tensor & out); // {"schema": "aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+::std::tuple<Tensor &,Tensor &> _native_multi_head_attention_out(const Tensor & query, const Tensor & key, const Tensor & value, int64_t embed_dim, int64_t num_head, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, const c10::optional<Tensor> & mask, bool need_weights, bool average_attn_weights, c10::optional<int64_t> mask_type, Tensor & out0, Tensor & out1); // {"schema": "aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"}
+Tensor & _triton_scaled_dot_attention_out(const Tensor & q, const Tensor & k, const Tensor & v, double dropout_p, Tensor & out); // {"schema": "aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _triton_multi_head_attention_out(const Tensor & query, const Tensor & key, const Tensor & value, int64_t embed_dim, int64_t num_head, const Tensor & qkv_weight, const Tensor & qkv_bias, const Tensor & proj_weight, const Tensor & proj_bias, const c10::optional<Tensor> & mask, Tensor & out); // {"schema": "aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+Tensor & _foobar_out(const Tensor & self, bool arg1, bool arg2, bool arg3, Tensor & out); // {"schema": "aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"}
+void _fused_adam_out(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf, TensorList out); // {"schema": "aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>> _fused_adam(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"}
+void _fused_adam_out(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf, TensorList out); // {"schema": "aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>> _fused_adam(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"}
+void _fused_adamw_out(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf, TensorList out); // {"schema": "aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>,::std::vector<Tensor>> _fused_adamw(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf); // {"schema": "aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"}
+void _fused_adamw_out(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional<Tensor> & grad_scale, const c10::optional<Tensor> & found_inf, TensorList out); // {"schema": "aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(TensorList self, TensorList grads, TensorList exp_avgs, TensorList exp_avg_sqs, TensorList max_exp_avg_sqs, TensorList state_steps, const Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const c10::optional & grad_scale, const c10::optional & found_inf); // {"schema": "aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"}
+void _fused_sgd_out(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional & grad_scale, const c10::optional & found_inf, TensorList out); // {"schema": "aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional & grad_scale, const c10::optional & found_inf); // {"schema": "aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"}
+void _fused_sgd_out(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, const Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional & grad_scale, const c10::optional & found_inf, TensorList out); // {"schema": "aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"}
+::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(TensorList self, TensorList grads, TensorList momentum_buffer_list, double weight_decay, double momentum, const Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const c10::optional & grad_scale, const c10::optional & found_inf); // {"schema": "aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h b/MLPY/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e5708ce33e4718344562a0996472ddbed438737
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+
+namespace impl {
+
+struct TORCH_API SavedTensorDefaultHooksTLS {
+  // PyObject is defined in c10/util/python_stub.h
+  std::stack<std::pair<PyObject*, PyObject*>> stack;
+
+  // See NOTE: [Disabling SavedTensorDefaultHooks] for context
+  // NOTE: [disabled_error_message invariant]
+  // disabled_error_message is nullopt IFF Saved Tensor hooks are enabled
+  // We did this for efficiency (so we didn't have to keep a separate bool
+  // around)
+  c10::optional<std::string> disabled_error_message;
+};
+
+} // namespace impl
+
+struct TORCH_API SavedTensorDefaultHooks {
+  static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static void pop_hooks();
+  static std::pair<PyObject*, PyObject*> get_hooks();
+  static void lazy_initialize();
+  static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
+  static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
+
+  static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
+  static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
+
+  // NOTE: [Disabling SavedTensorDefaultHooks]
+  // A developer of a PyTorch feature may choose to disable SavedTensorDefault
+  // hooks, especially if their feature does not work with it. If they are
+  // disabled, then the following will raise an error:
+  // - Attempting to push_hooks
+  // - calling disable(message) with a non-zero stack (from get_stack) size
+  static void disable(const std::string& error_message);
+  static void enable();
+  static bool is_enabled();
+  static const c10::optional<std::string>& get_disabled_error_message();
+};
+
+} // namespace at
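The header above only declares the hook-management API; the following is a minimal, hypothetical sketch (not part of the patch) showing how a feature might use `disable`/`enable` to opt out of saved-tensor hooks. The function name `run_without_saved_tensor_hooks` and the error text are illustrative.

```cpp
#include <ATen/SavedTensorHooks.h>

// Hypothetical helper: while disabled, push_hooks() raises the given message.
void run_without_saved_tensor_hooks() {
  at::SavedTensorDefaultHooks::disable(
      "this feature does not support saved-tensor hooks");
  // ... feature-specific work that must not interact with pack/unpack hooks ...
  if (!at::SavedTensorDefaultHooks::is_enabled()) {
    at::SavedTensorDefaultHooks::enable();  // restore normal behavior
  }
}
```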
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/Scalar.h b/MLPY/Lib/site-packages/torch/include/ATen/Scalar.h
new file mode 100644
index 0000000000000000000000000000000000000000..6dec39dd3c32cef073fec4891ab16a71c58e8077
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/Scalar.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ScalarOps.h b/MLPY/Lib/site-packages/torch/include/ATen/ScalarOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..533ba88771c62e3fc4fd95490d1c1a1421c412f4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ScalarOps.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::detail {
+// When filling a 1-element CPU tensor with a number, we want to skip
+// everything else and manipulate the data ptr directly.
+// Ideally this fast pass should be implemented in TensorIterator,
+// but we also want to skip compute_types, which is not avoidable
+// in TensorIterator for now.
+Tensor& scalar_fill(Tensor& self, const Scalar& value);
+TORCH_API Tensor scalar_tensor_static(
+    const Scalar& s,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Device> device_opt);
+} // namespace at::detail
+
+// This is in the c10 namespace because we use ADL to find the functions in it.
+namespace c10 {
+
+// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
+// way to implement this without going through Derived Types (which are not part
+// of core).
+inline at::Tensor scalar_to_tensor(
+    const Scalar& s,
+    const Device device = at::kCPU) {
+  // This is the fast track we have for CPU scalar tensors.
+  if (device == at::kCPU) {
+    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
+  }
+  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
+}
+
+} // namespace c10
+
+namespace at::native {
+
+inline Tensor wrapped_scalar_tensor(
+    const Scalar& scalar,
+    const Device device = at::kCPU) {
+  auto tensor = scalar_to_tensor(scalar, device);
+  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
+  return tensor;
+}
+
+} // namespace at::native
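As a usage note (not part of the patch), here is a hypothetical sketch of the helpers declared above; `as_wrapped_number` is an illustrative name.

```cpp
#include <ATen/ScalarOps.h>

// Turn a C++ Scalar into a 0-dim CPU tensor flagged as a "wrapped number",
// which is how type promotion distinguishes literals from real tensors.
at::Tensor as_wrapped_number(const at::Scalar& s) {
  // c10::scalar_to_tensor(s) alone would return the plain 0-dim tensor via the
  // fast CPU path above; wrapped_scalar_tensor additionally sets the flag.
  return at::native::wrapped_scalar_tensor(s);
}
```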
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/autocast_mode.h b/MLPY/Lib/site-packages/torch/include/ATen/autocast_mode.h
new file mode 100644
index 0000000000000000000000000000000000000000..b29bd694747b7ea2c9f21d4a9603292b84215654
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/autocast_mode.h
@@ -0,0 +1,647 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at::autocast {
+
+TORCH_API bool is_enabled();
+TORCH_API void set_enabled(bool enabled);
+TORCH_API void clear_cache();
+TORCH_API int increment_nesting();
+TORCH_API int decrement_nesting();
+TORCH_API bool is_cpu_enabled();
+TORCH_API void set_cpu_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_gpu_dtype();
+TORCH_API at::ScalarType get_autocast_cpu_dtype();
+TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
+TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_xpu_enabled();
+TORCH_API void set_xpu_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_xpu_dtype();
+TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_ipu_enabled();
+TORCH_API void set_ipu_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_ipu_dtype();
+TORCH_API void set_autocast_ipu_dtype(at::ScalarType dtype);
+TORCH_API bool is_hpu_enabled();
+TORCH_API void set_hpu_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_hpu_dtype();
+TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_xla_enabled();
+TORCH_API void set_xla_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_xla_dtype();
+TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
+TORCH_API bool is_privateuseone_enabled();
+TORCH_API void set_privateuseone_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
+TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
+TORCH_API bool is_autocast_cache_enabled();
+TORCH_API void set_autocast_cache_enabled(bool enabled);
+
+namespace {
+inline bool is_autocast_eligible(
+    const Tensor& tensor,
+    c10::DeviceType device_type) {
+  switch (device_type) {
+    case c10::DeviceType::CUDA:
+      return (tensor.is_cuda() || tensor.is_xla()) &&
+          tensor.is_floating_point();
+    case c10::DeviceType::CPU:
+      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
+          tensor.is_floating_point();
+    case c10::DeviceType::XPU:
+      return tensor.is_xpu() && tensor.is_floating_point();
+    case c10::DeviceType::IPU:
+      return tensor.is_ipu() && tensor.is_floating_point();
+    case c10::DeviceType::HPU:
+      return tensor.is_hpu() && tensor.is_floating_point();
+    case c10::DeviceType::XLA:
+      return tensor.is_xla() && tensor.is_floating_point();
+    case c10::DeviceType::PrivateUse1:
+      return tensor.is_privateuseone() && tensor.is_floating_point();
+    default:
+      return false;
+  }
+}
+} // namespace
+
+inline DispatchKey get_autocast_dispatch_key_from_device_type(
+    c10::DeviceType device_type) {
+  switch (device_type) {
+    case c10::DeviceType::CUDA:
+      return DispatchKey::Autocast;
+    case c10::DeviceType::CPU:
+      return DispatchKey::AutocastCPU;
+    case c10::DeviceType::XPU:
+      return DispatchKey::AutocastXPU;
+    case c10::DeviceType::IPU:
+      return DispatchKey::AutocastIPU;
+    case c10::DeviceType::HPU:
+      return DispatchKey::AutocastHPU;
+    case c10::DeviceType::XLA:
+      return DispatchKey::AutocastXLA;
+    case c10::DeviceType::PrivateUse1:
+      return DispatchKey::AutocastPrivateUse1;
+    default:
+      throw std::runtime_error(
+          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
+  }
+}
+
+inline at::ScalarType get_lower_precision_fp_from_device_type(
+    c10::DeviceType device_type) {
+  switch (device_type) {
+    case c10::DeviceType::CUDA:
+      return get_autocast_gpu_dtype();
+    case c10::DeviceType::CPU:
+      return get_autocast_cpu_dtype();
+    case c10::DeviceType::XPU:
+      return get_autocast_xpu_dtype();
+    case c10::DeviceType::IPU:
+      return get_autocast_ipu_dtype();
+    case c10::DeviceType::HPU:
+      return get_autocast_hpu_dtype();
+    case c10::DeviceType::XLA:
+      return get_autocast_xla_dtype();
+    case c10::DeviceType::PrivateUse1:
+      return get_autocast_privateuseone_dtype();
+    default:
+      throw std::runtime_error(
+          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
+  }
+}
+
+/********************************************************************
+Logic to extract the promote type from any Tensor or TensorList args.
+********************************************************************/
+
+// Overload to catch Tensor args.
+// If nextArg is floating-point, compare its scalar_type with our
+// current best guess for the promote type, and update if necessary.
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const Tensor& nextArg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  if (current == at::kDouble) {
+    AT_ERROR("promote type is double in at::autocast::prioritize");
+    return current;
+  }
+  at::ScalarType lower_precision_fp =
+      get_lower_precision_fp_from_device_type(device_type);
+  if (is_autocast_eligible(nextArg, device_type)) {
+    auto next = nextArg.scalar_type();
+    if (next == at::kDouble) {
+      return current; // ignores double tensors
+    } else if (current == at::kFloat || next == at::kFloat) {
+      return at::kFloat; // prioritizes float over lower_precision_fp
+    } else if (current == lower_precision_fp && next == lower_precision_fp) {
+      return lower_precision_fp;
+    } else {
+      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
+      return current;
+    }
+  } else {
+    return current;
+  }
+}
+
+// Overload to catch TensorList args (for e.g. cat, stack).
+// Reuses the overload above to process each Tensor in the list.
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const TensorList& list,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor, device_type);
+  }
+  return current;
+}
+
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const ITensorListRef& list,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor, device_type);
+  }
+  return current;
+}
+
+// Template to catch non-Tensor args (no-op that returns current best guess)
+template <typename T>
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    T nextArg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  return current;
+}
+
+// Overload for the tail case.
+inline at::ScalarType promote_type(
+    at::ScalarType current,
+    c10::DeviceType device_type) {
+  return current;
+}
+
+// Unpack args and determine if incoming lower_precision_fp tensors need to be
+// promoted to float32. Non-Tensor arguments are ignored.
+template <typename Arg0, typename... Args>
+inline at::ScalarType promote_type(
+    at::ScalarType current,
+    c10::DeviceType device_type,
+    Arg0 arg0,
+    Args... args) {
+  auto new_current = prioritize(current, arg0, device_type);
+  return promote_type(new_current, device_type, args...);
+}
+
+/****************************************************
+Logic to apply cached casting to any Tensor argument.
+****************************************************/
+inline bool is_eligible(
+    const Tensor& arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  return (
+      arg.defined() && is_autocast_eligible(arg, device_type) &&
+      (arg.scalar_type() != at::kDouble));
+}
+
+// Overload to catch Tensor args
+TORCH_API Tensor cached_cast(
+    at::ScalarType to_type,
+    const Tensor& arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA);
+
+// Overload to process optional<Tensor>
+inline c10::optional<Tensor> cached_cast(
+    at::ScalarType to_type,
+    const c10::optional<Tensor>& arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  if (arg.has_value()) {
+    return cached_cast(to_type, *arg, device_type);
+  } else {
+    return c10::nullopt;
+  }
+}
+
+// Overload to process TensorLists
+inline std::vector<Tensor> cached_cast(
+    at::ScalarType to_type,
+    const TensorList& arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.emplace_back(cached_cast(to_type, t, device_type));
+  }
+  return vec;
+}
+
+inline std::vector<Tensor> cached_cast(
+    at::ScalarType to_type,
+    const ITensorListRef& arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.emplace_back(cached_cast(to_type, t, device_type));
+  }
+  return vec;
+}
+
+// Template to catch non-Tensor args.
+template <typename T>
+inline T cached_cast(
+    at::ScalarType to_type,
+    T arg,
+    c10::DeviceType device_type = c10::DeviceType::CUDA) {
+  return arg;
+}
+
+/*******************************************************
+Logic to flip an output dtype flag.
+Keep it simple for now by assuming only one such flag is
+present in the argument list.  If I ever need a function
+with more than one flag I'll figure out something else.
+The policy is:
+If the user has explicitly specified a dtype, respect it.
+Otherwise, set it to the autocast type.
+********************************************************/
+
+// Overload to catch dtype flags
+c10::optional<ScalarType> inline set_opt_dtype(
+    at::ScalarType to_type,
+    const c10::optional<ScalarType>& dtype) {
+  return dtype.has_value() ? dtype : to_type;
+}
+
+// Template to catch other args
+template <typename T>
+inline T set_opt_dtype(at::ScalarType to_type, T arg) {
+  return arg;
+}
+
+template <typename... Args>
+inline bool firstarg_is_eligible(
+    c10::DeviceType device_type,
+    const Tensor& arg,
+    Args... args) {
+  return is_eligible(arg, device_type);
+}
+
+template <typename... Args>
+inline at::ScalarType type_from_firstarg(
+    c10::DeviceType device_type,
+    at::ScalarType to_type,
+    const Tensor& arg,
+    Args... args) {
+  return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
+}
+
+// Policies correspond to op categories that need code-divergent handling.
+// Wrapper templates below are specialized based on a policy template parameter.
+enum class CastPolicy : uint8_t {
+  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
+                          // running the op. Currently, lower_precision_fp is
+                          // fp16 for AutocastCUDA, and is defined by user
+                          // (default bf16) for AutocastCPU or other device.
+  fp32, // Cast all inputs to at::kFloat before running the op.
+  fp32_set_opt_dtype, // Treats functions (like softmax) that
+                      //  1. we'd like to run in fp32 and
+                      //  2. have a c10::optional<ScalarType> arg that controls
+                      //  the output type.
+                      // fp32_set_opt_dtype wrappers' policy is: if the output
+                      // type is already set, don't touch it, otherwise, set
+                      // it to at::kFloat.
+  fp32_append_dtype, // Treats functions (like norm) that
+                     //  1. we'd like to run in fp32 and
+                     //  2. have some overloads that accept an output type and
+                     //  other overloads that don't.
+                     // fp32_append_dtype wrappers wrap the overloads that don't
+                     // have an output dtype.
+                     // The wrapper policy is:  append at::kFloat to the args,
+                     // and redispatch to the type-aware overload.
+  promote, // Run in the widest dtype among several args.
+};
+
+/********************************************************************************************************
+Templates to provide wrapper functions
+
+I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
+extract args and return type. (see also
+https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
+
+This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
+of (in my case several specializations of) an interior "WrapFunction_".
+Interior WrapFunction_ specializations are defined for each CastPolicy.
+********************************************************************************************************/
+
+// Base template for WrapFunction_, which is specialized to contain a "call"
+// method for each CastPolicy
+template <
+    CastPolicy policy,
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class ArgList>
+struct WrapFunction_ {};
+
+// CastPolicy::lower_precision_fp General_DeviceType
+template <
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class... Args>
+struct WrapFunction_<
+    CastPolicy::lower_precision_fp,
+    device_type,
+    Redispatch,
+    F,
+    Ret,
+    guts::typelist::typelist<Args...>> {
+  static Ret call(Args... args) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(
+        get_autocast_dispatch_key_from_device_type(device_type));
+    return (*F)(cached_cast(
+        get_lower_precision_fp_from_device_type(device_type),
+        args,
+        device_type)...);
+  }
+};
+
+// CastPolicy::fp32 General_DeviceType
+template <
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class... Args>
+struct WrapFunction_<
+    CastPolicy::fp32,
+    device_type,
+    Redispatch,
+    F,
+    Ret,
+    guts::typelist::typelist<Args...>> {
+  static Ret call(Args... args) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(
+        get_autocast_dispatch_key_from_device_type(device_type));
+    return (*F)(cached_cast(at::kFloat, args, device_type)...);
+  }
+};
+
+// CastPolicy::fp32_set_opt_dtype General_DeviceType
+template <
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class... Args>
+struct WrapFunction_<
+    CastPolicy::fp32_set_opt_dtype,
+    device_type,
+    Redispatch,
+    F,
+    Ret,
+    guts::typelist::typelist<Args...>> {
+  static Ret call(Args... args) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(
+        get_autocast_dispatch_key_from_device_type(device_type));
+    if (firstarg_is_eligible(device_type, args...)) {
+      return (*F)(set_opt_dtype(at::kFloat, args)...);
+    } else {
+      // If ineligible, calls F with unaltered args.  Does not set opt dtype,
+      // because setting opt dtype explicitly may interfere with internal
+      // implicit promotion decisions.
+      return (*F)(args...);
+    }
+  }
+};
+
+// CastPolicy::fp32_append_dtype General_DeviceType
+template <
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class... Args>
+struct WrapFunction_<
+    CastPolicy::fp32_append_dtype,
+    device_type,
+    Redispatch,
+    F,
+    Ret,
+    guts::typelist::typelist<Args...>> {
+  static Ret call(Args... args) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(
+        get_autocast_dispatch_key_from_device_type(device_type));
+    at::ScalarType out_type =
+        type_from_firstarg(device_type, at::kFloat, args...);
+    return (*F)(args..., out_type);
+  }
+};
+
+// CastPolicy::promote General_DeviceType
+template <
+    c10::DeviceType device_type,
+    class Redispatch,
+    Redispatch* F,
+    class Ret,
+    class... Args>
+struct WrapFunction_<
+    CastPolicy::promote,
+    device_type,
+    Redispatch,
+    F,
+    Ret,
+    guts::typelist::typelist<Args...>> {
+  static Ret call(Args... args) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(
+        get_autocast_dispatch_key_from_device_type(device_type));
+    auto to_type = promote_type(
+        get_lower_precision_fp_from_device_type(device_type),
+        device_type,
+        args...);
+    return (*F)(cached_cast(to_type, args, device_type)...);
+  }
+};
+
+// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
+// core/boxing/impl/WrapFunctionIntoFunctor.h)
+template <
+    CastPolicy policy,
+    c10::DeviceType device_type,
+    class Registered, // The signature for which we're registering.  The
+                      // dispatcher's calling code invokes our registered
+                      // functions with arguments matching Registered, so we
+                      // register WrapFunction_::call methods with a matching
+                      // signature to properly field those arguments.
+    // guts::function_traits below extracts return_type and
+    // parameter_types from Registered, which WrapFunction_
+    // templates above use to declare their call methods.
+    class Redispatch, // The signature for the function we're redispatching to.
+                      // In most cases this is the same as Registered, but for
+                      // some ops (for example, ops where we append a dtype)
+                      // it's useful to redispatch to a function with a
+                      // different signature.
+    Redispatch* F> // The actual function we're redispatching to.
+struct WrapFunction final {
+  using type = WrapFunction_<
+      policy,
+      device_type,
+      Redispatch,
+      F,
+      typename guts::function_traits<Registered>::return_type,
+      typename guts::function_traits<Registered>::parameter_types>;
+};
+
+/*****************************************************************************************************************
+This section performs load-time registration for autocast wrappers.
+
+It's debatable at what level operations should be patched.  We'd like casts to
+be autograd-exposed and precede autograd history recording, so that for
+lower_precision_fp ops, input tensors are saved for backward in
+lower_precision_fp rather than fp32.  Saving inputs in lower_precision_fp
+can significantly reduce a model's memory footprint.
+
+Option 1 (strawman):  Patch only at the level of explicit calls into
+cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
+guaranteed to use Tensor Cores, therefore they're the ones that will benefit
+most from lower_precision_fp.   Potential pitfall:  convolutions (and other ops)
+are wrapped in several layers of at::* calls.  If one of those happens to record
+autograd history, then we've lost the opportunity to save inputs in
+lower_precision_fp.
+
+Option 2:  Patch the Python-exposed surface of calls, to make 100% sure autograd
+history recording can't sneak in ahead of autocast.  This mirrors Apex most
+closely.
+
+I think Option 2 is the right answer for all ops, not just convolutions. Option
+2 is what I implement here.
+*****************************************************************************************************************/
+
+/********************************************************************************************************************
+Explicit registration for out-of-place ops
+
+The stuff below could be codegenned.  Ed said
+> you are going to have to write the function definition at some point, I
+wouldn't try to get clever about it. Therefore, for the moment, this is all
+copy-pasted in from VariableTypeEverything.cpp with appropriate substitutions.
+********************************************************************************************************************/
+
+} // namespace at::autocast
+
+#define ADD_NS(RAW_OP) at::RAW_OP
+
+// Common cases where registration signature matches redispatch signature
+// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
+#define KERNEL(DISPATCHKEY, OP, POLICY)       \
+  m.impl(                                     \
+      TORCH_SELECTIVE_NAME("aten::" #OP),     \
+      &::at::autocast::WrapFunction<          \
+          ::at::autocast::CastPolicy::POLICY, \
+          DISPATCHKEY,                        \
+          decltype(ATEN_FN(OP)),              \
+          decltype(ATEN_FN(OP)),              \
+          &ATEN_FN(OP)>::type::call);
+
+#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY)      \
+  m.impl(                                               \
+      TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
+      &::at::autocast::WrapFunction<                    \
+          ::at::autocast::CastPolicy::POLICY,           \
+          DISPATCHKEY,                                  \
+          decltype(ATEN_FN2(OP, OVERLOAD)),             \
+          decltype(ATEN_FN2(OP, OVERLOAD)),             \
+          &ATEN_FN2(OP, OVERLOAD)>::type::call);
+
+// Less-common but still useful case: redispatching to a function
+// with a new signature (e.g. appending a dtype)
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(      \
+    DISPATCHKEY,                                    \
+    REDISPATCH_FUNC,                                \
+    REGISTER_NAME,                                  \
+    REGISTER_SIGNATURE,                             \
+    REDISPATCH_SIGNATURE,                           \
+    POLICY)                                         \
+  m.impl(                                           \
+      TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
+      &::at::autocast::WrapFunction<                \
+          ::at::autocast::CastPolicy::POLICY,       \
+          DISPATCHKEY,                              \
+          REGISTER_SIGNATURE,                       \
+          REDISPATCH_SIGNATURE,                     \
+          &REDISPATCH_FUNC>::type::call);
+
+// KERNEL_CPU/KERNEL_CPU2/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
+// registration for AutocastCPU
+#define KERNEL_CPU(OP, POLICY) KERNEL(c10::DeviceType::CPU, OP, POLICY)
+
+#define KERNEL_CPU2(OP, OVERLOAD, POLICY) \
+  KERNEL2(c10::DeviceType::CPU, OP, OVERLOAD, POLICY)
+
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
+    REDISPATCH_FUNC,                               \
+    REGISTER_NAME,                                 \
+    REGISTER_SIGNATURE,                            \
+    REDISPATCH_SIGNATURE,                          \
+    POLICY)                                        \
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
+      c10::DeviceType::CPU,                        \
+      REDISPATCH_FUNC,                             \
+      REGISTER_NAME,                               \
+      REGISTER_SIGNATURE,                          \
+      REDISPATCH_SIGNATURE,                        \
+      POLICY)
+
+// KERNEL_CUDA/KERNEL_CUDA2/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
+// registration for AutocastCUDA
+#define KERNEL_CUDA(OP, POLICY) KERNEL(c10::DeviceType::CUDA, OP, POLICY)
+
+#define KERNEL_CUDA2(OP, OVERLOAD, POLICY) \
+  KERNEL2(c10::DeviceType::CUDA, OP, OVERLOAD, POLICY)
+
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
+    REDISPATCH_FUNC,                                \
+    REGISTER_NAME,                                  \
+    REGISTER_SIGNATURE,                             \
+    REDISPATCH_SIGNATURE,                           \
+    POLICY)                                         \
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
+      c10::DeviceType::CUDA,                        \
+      REDISPATCH_FUNC,                              \
+      REGISTER_NAME,                                \
+      REGISTER_SIGNATURE,                           \
+      REDISPATCH_SIGNATURE,                         \
+      POLICY)
+
+// KERNEL_PRIVATEUSEONE/KERNEL_PRIVATEUSEONE2/
+// KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
+// registration for AutocastPrivateUse1
+#define KERNEL_PRIVATEUSEONE(OP, POLICY) \
+  KERNEL(c10::DeviceType::PrivateUse1, OP, POLICY)
+
+#define KERNEL_PRIVATEUSEONE2(OP, OVERLOAD, POLICY) \
+  KERNEL2(c10::DeviceType::PrivateUse1, OP, OVERLOAD, POLICY)
+
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
+    REDISPATCH_FUNC,                                         \
+    REGISTER_NAME,                                           \
+    REGISTER_SIGNATURE,                                      \
+    REDISPATCH_SIGNATURE,                                    \
+    POLICY)                                                  \
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                     \
+      c10::DeviceType::PrivateUse1,                          \
+      REDISPATCH_FUNC,                                       \
+      REGISTER_NAME,                                         \
+      REGISTER_SIGNATURE,                                    \
+      REDISPATCH_SIGNATURE,                                  \
+      POLICY)
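For context (not part of the patch), a minimal sketch of how the KERNEL_* macros above are typically consumed, assuming `TORCH_LIBRARY_IMPL` from `<torch/library.h>` and `ATEN_FN` from `<ATen/Operators.h>`; the chosen ops and policies are illustrative.

```cpp
#include <ATen/Operators.h>
#include <ATen/autocast_mode.h>
#include <torch/library.h>

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // Run matmul in the lower-precision autocast dtype (bf16 by default on CPU).
  KERNEL_CPU(mm, lower_precision_fp)
  // Force cos to execute in fp32 under autocast.
  KERNEL_CPU(cos, fp32)
}
```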
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ceil_div.h b/MLPY/Lib/site-packages/torch/include/ATen/ceil_div.h
new file mode 100644
index 0000000000000000000000000000000000000000..7eb9940e57d8bd97cef964acb8650466d663da17
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/ceil_div.h
@@ -0,0 +1,24 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+
+/**
+   Computes ceil(a / b)
+*/
+template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
+C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) {
+  return (a + b - 1) / b;
+}
+
+/**
+   Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest
+   multiple of b
+*/
+template <typename T>
+C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) {
+  return ceil_div(a, b) * b;
+}
+
+} // namespace at
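A brief usage sketch (not part of the patch) of the two helpers above, e.g. for sizing a launch grid; the function names are illustrative.

```cpp
#include <ATen/ceil_div.h>

int64_t blocks_for(int64_t n, int64_t block) {
  return at::ceil_div(n, block);   // e.g. ceil_div(1000, 128) == 8
}

int64_t padded_size(int64_t n, int64_t block) {
  return at::round_up(n, block);   // e.g. round_up(1000, 128) == 1024
}
```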
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/code_template.h b/MLPY/Lib/site-packages/torch/include/ATen/code_template.h
new file mode 100644
index 0000000000000000000000000000000000000000..45872bb07daedbecfb59ca46dc6c507dc16a6aac
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/code_template.h
@@ -0,0 +1,243 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::jit {
+
+// A template environment is a mapping from template variable names, e.g.,
+// identifier (corresponding to $identifier) to their expansions.
+//
+// This template environment supports storing strings, numbers and lists
+// of strings, and can be chained together (so that lookup proceeds in
+// in the top level environment, and then recurses into a parent
+// environment if the key is not found.)
+struct TemplateEnv {
+  TemplateEnv() = default;
+  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
+
+  using string_list = std::vector<std::string>;
+
+  // Add a string 'v' to the map at key 'k'.
+  void s(const std::string& k, const std::string& v) {
+    strings_[k] = v;
+    lists_.erase(k);
+  }
+
+  // Add a number 'v' to the map at key 'k'
+  template <typename T>
+  void d(const std::string& k, const T& v) {
+    strings_[k] = c10::to_string(v);
+    lists_.erase(k);
+  }
+
+  // Retrieve the string representation of the value stored at 'k' from the map.
+  // Raises an exception if the key is not found.
+  const std::string& s(const std::string& k) const {
+    if (strings_.count(k) == 0) {
+      if (parent) {
+        return parent->s(k);
+      }
+      notFound(k);
+    }
+    return strings_.at(k);
+  }
+
+  // Store a list of strings 'v' in the map at 'k'.
+  void v(const std::string& k, const string_list& v) {
+    lists_[k] = v;
+    strings_.erase(k);
+  }
+
+  // Retrieve a list of strings stored at 'k' from the map.
+  // Raises an exception if the key is not found.
+  const string_list& v(const std::string& k) const {
+    if (lists_.count(k) == 0) {
+      if (parent) {
+        return parent->v(k);
+      }
+      notFound(k);
+    }
+    return lists_.at(k);
+  }
+
+  // Test if the value stored at 'k' is a string (as opposed to a list).
+  bool keyIsString(const std::string& k) const {
+    if (strings_.count(k) > 0)
+      return true;
+    if (lists_.count(k) > 0)
+      return false;
+    if (parent)
+      return parent->keyIsString(k);
+    notFound(k);
+  }
+
+ private:
+  [[noreturn]] void notFound(const std::string& k) const {
+    std::stringstream ss;
+    ss << "key not found: " << k;
+    throw std::logic_error(ss.str());
+  }
+
+  std::unordered_map<std::string, std::string> strings_;
+  std::unordered_map<std::string, string_list> lists_;
+  TemplateEnv* parent{nullptr};
+};
+
+/*
+# Match $identifier or ${identifier} and replace with the value in env.
+# If this identifier is at the beginning of whitespace on a line
+# and its value is a list then it is treated as
+# block substitution by indenting all lines of all elements.
+# If the identifier is on a line starting with non-whitespace and a list
+# then it is comma separated. ${,foo} will insert a comma before the list
+# if this list is not empty and ${foo,} will insert one after.
+*/
+struct CodeTemplate {
+  /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
+
+  std::string format(const TemplateEnv& env) const {
+    std::stringstream out;
+    size_t pos = 0;
+    size_t indent = 0;
+    bool all_whitespace = true;
+    while (pos < template_text.size()) {
+      char c = template_text[pos];
+      if (c == '$') {
+        std::stringstream kss;
+        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+        bool comma_before;
+        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+        bool comma_after;
+        size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
+        std::string k = kss.str();
+        bool is_string = env.keyIsString(k);
+        if (all_whitespace) {
+          if (is_string)
+            emitStringWithIndents(out, indent, env.s(k));
+          else
+            emitLinesIndented(out, indent, env.v(k));
+        } else {
+          if (is_string)
+            out << env.s(k);
+          else
+            emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
+        }
+        all_whitespace = false;
+        pos = new_pos;
+      } else {
+        out << c;
+        if (!isspace(c))
+          all_whitespace = false;
+        indent++;
+        if (c == '\n') {
+          indent = 0;
+          all_whitespace = true;
+        }
+        pos++;
+      }
+    }
+    return out.str();
+  }
+
+ private:
+  using string_list = std::vector<std::string>;
+  char charAt(size_t p) const {
+    if (p >= template_text.size())
+      throw std::logic_error("EOS found in key");
+    return template_text[p];
+  }
+  size_t parseKey(
+      size_t pos,
+      std::ostream& k,
+      bool& comma_before,
+      bool& comma_after) const {
+    comma_before = false;
+    comma_after = false;
+    pos++;
+    if (charAt(pos) == '{') {
+      pos++;
+      if (charAt(pos) == ',') {
+        comma_before = true;
+        pos++;
+      }
+      pos = parseIdent(pos, k);
+      if (charAt(pos) == ',') {
+        comma_after = true;
+        pos++;
+      }
+      if (charAt(pos) != '}')
+        throw std::logic_error("missing terminating '}'");
+      pos++;
+      return pos;
+    } else {
+      return parseIdent(pos, k);
+    }
+  }
+  size_t parseIdent(size_t pos, std::ostream& k) const {
+    while (pos < template_text.size() &&
+           (isalnum(template_text[pos]) || template_text[pos] == '_')) {
+      k << template_text[pos];
+      pos++;
+    }
+    return pos;
+  }
+  void emitCommaSeparatedList(
+      std::ostream& out,
+      const string_list& strings,
+      bool comma_before,
+      bool comma_after) const {
+    if (comma_before && !strings.empty())
+      out << ", ";
+    for (const auto i : c10::irange(strings.size())) {
+      if (i > 0)
+        out << ", ";
+      out << strings[i];
+    }
+    if (comma_after && !strings.empty())
+      out << ", ";
+  }
+  // These indentation functions follow the convention that they never emit
+  // leading or trailing newlines when the input string does not have leading
+  // or trailing newlines. It's the responsibility of the calling function
+  // to indent correctly in the context.
+  void emitIndent(std::ostream& out, size_t indent) const {
+    for (C10_UNUSED const auto i : c10::irange(indent)) {
+      out << " ";
+    }
+  }
+  void emitStringWithIndents(
+      std::ostream& out,
+      size_t indent,
+      const std::string& str) const {
+    for (auto c : str) {
+      out << c;
+      if (c == '\n') {
+        emitIndent(out, indent);
+      }
+    }
+  }
+  void emitLinesIndented(
+      std::stringstream& out,
+      size_t indent,
+      const string_list& strings) const {
+    for (const auto i : c10::irange(strings.size())) {
+      if (i > 0)
+        emitIndent(out, indent);
+      emitStringWithIndents(out, indent, strings[i]);
+      if (i + 1 != strings.size())
+        out << "\n";
+    }
+  }
+  std::string template_text;
+};
+
+static inline std::string format(const std::string& fmt, TemplateEnv& env) {
+  return CodeTemplate(fmt).format(env);
+}
+
+} // namespace at::jit
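To illustrate the substitution rules described in the comment above, a small hypothetical sketch (not part of the patch); `emit_decl` and the template text are illustrative.

```cpp
#include <ATen/code_template.h>

std::string emit_decl() {
  at::jit::TemplateEnv env;
  env.s("name", "my_kernel");                         // $name -> string
  env.v("args", {"float* out", "const float* in"});   // ${args} -> comma-separated list
  env.d("n", 4);                                      // $n -> number stored as a string
  at::jit::CodeTemplate decl("void $name(${args});  // unroll factor $n");
  return decl.format(env);
  // -> "void my_kernel(float* out, const float* in);  // unroll factor 4"
}
```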
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h
new file mode 100644
index 0000000000000000000000000000000000000000..8f411e535837a17c272762ccbd2714e15a1466cd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ATenOpList.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ATenOpList.h
new file mode 100644
index 0000000000000000000000000000000000000000..6dfed2b9398544bb43938cdcc8243cfb10d9be32
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ATenOpList.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+struct OperatorName;
+}
+
+namespace at {
+
+// check if an op is a custom op (i.e. did not come from native_functions.yaml)
+TORCH_API bool is_custom_op(const c10::OperatorName& opName);
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h
new file mode 100644
index 0000000000000000000000000000000000000000..263e339c5bd6c7d4362771bc078ca8d980e042ec
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h
@@ -0,0 +1,46 @@
+#pragma once
+#include 
+
+// Forward declarations of core ATen types used in dispatch functions
+namespace c10 {
+
+template<typename T>
+class List;
+template<typename T>
+class IListRef;
+class Stream;
+class Scalar;
+class SymInt;
+class SymIntList;
+struct Storage;
+struct TensorOptions;
+template <typename T>
+class ArrayRef;
+template <typename T>
+class OptionalArrayRef;
+
+}  // namespace c10
+
+namespace at {
+
+class Tensor;
+class OptionalTensorRef;
+struct Dimname;
+struct Generator;
+using TensorList = c10::ArrayRef<Tensor>;
+using ITensorListRef = c10::IListRef<Tensor>;
+using IOptTensorListRef = c10::IListRef<OptionalTensorRef>;
+using DimnameList = c10::ArrayRef<Dimname>;
+using IntArrayRef = c10::ArrayRef<int64_t>;
+using OptionalIntArrayRef = c10::OptionalArrayRef<int64_t>;
+using OptionalSymIntArrayRef = c10::OptionalArrayRef<c10::SymInt>;
+
+using c10::Stream;
+using c10::Storage;
+using c10::QScheme;
+using c10::Scalar;
+using c10::SymInt;
+using c10::SymIntList;
+using c10::TensorOptions;
+
+}  // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_pch.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_pch.h
new file mode 100644
index 0000000000000000000000000000000000000000..a0f32460ebe4ead8ca3c01d2c1f58bfce900942a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ATen_pch.h
@@ -0,0 +1,165 @@
+// This global header must not depend on native_functions.yaml or
+// incremental builds will be next to useless
+#pragma push_macro("TORCH_ASSERT_NO_OPERATORS")
+#define TORCH_ASSERT_NO_OPERATORS
+
+// This macro doesn't work if defined after the first time inttypes.h
+// is included, so won't work anywhere if not defined here.
+#ifndef __STDC_FORMAT_MACROS
+#define __STDC_FORMAT_MACROS
+#endif
+#include 
+
+// This list of headers was generated using a script that finds
+// high-impact headers and then manually tweaked to remove OS specific
+// or duplicate headers (e.g.  and ) and to remove
+// "impl" headers (e.g BFloat16-inl.h or complex_math.h in c10).
+
+// To generate the initial list:
+// 1. Build pytorch from scratch with all build caching disabled
+// 2. Generate a build trace with ninjatracing (https://github.com/nico/ninjatracing)
+//    $ ninjatracing /path/to/pytorch/build/.ninja_log > trace_all.json
+// 3. Run pch_gen.py from https://github.com/peterbell10/build_analysis/
+//    $ python pch_gen.py --threshold .80 --target torch_cpu --build_dir /path/to/pytorch/build --trace trace_all.json
+//    Where the threshold can be tweaked until c10 and some of ATen
+//    core are included but TORCH_ASSERT_NO_OPERATORS still passes.
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#pragma pop_macro("TORCH_ASSERT_NO_OPERATORS")
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Array.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Array.h
new file mode 100644
index 0000000000000000000000000000000000000000..c81a3cffbfd59e277bc5b7be6b1aca77db52c1ce
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Array.h
@@ -0,0 +1,39 @@
+#pragma once
+
+// A fixed-size array type usable from both host and
+// device code.
+
+#include 
+#include 
+
+namespace at { namespace detail {
+
+template <typename T, int size_>
+struct Array {
+  T data[size_];
+
+  C10_HOST_DEVICE T operator[](int i) const {
+    return data[i];
+  }
+  C10_HOST_DEVICE T& operator[](int i) {
+    return data[i];
+  }
+#if defined(USE_ROCM)
+  C10_HOST_DEVICE Array() = default;
+  C10_HOST_DEVICE Array(const Array&) = default;
+  C10_HOST_DEVICE Array& operator=(const Array&) = default;
+#else
+  Array() = default;
+  Array(const Array&) = default;
+  Array& operator=(const Array&) = default;
+#endif
+  static constexpr int size(){return size_;}
+  // Fill the array with x.
+  C10_HOST_DEVICE Array(T x) {
+    for (int i = 0; i < size_; i++) {
+      data[i] = x;
+    }
+  }
+};
+
+}}
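A short sketch (not part of the patch) of the fixed-size array above, e.g. for passing a small shape to a kernel by value; `make_shape` is an illustrative name.

```cpp
#include <ATen/core/Array.h>

at::detail::Array<int, 3> make_shape() {
  at::detail::Array<int, 3> shape(1);   // fill constructor: {1, 1, 1}
  shape[0] = 128;
  static_assert(at::detail::Array<int, 3>::size() == 3, "compile-time size");
  return shape;
}
```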
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Backtrace.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Backtrace.h
new file mode 100644
index 0000000000000000000000000000000000000000..684825dc2ba32d0dd84284f08591ec0ec314980f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Backtrace.h
@@ -0,0 +1,2 @@
+#include 
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h b/MLPY/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce83c43497192016cdb26de022a52ee30020b28b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h
@@ -0,0 +1,25 @@
+#include 
+
+namespace c10 { namespace impl {
+
+inline c10::optional<MemoryFormat>
+check_tensor_options_and_extract_memory_format(
+    const TensorOptions& options,
+    c10::optional<MemoryFormat> memory_format) {
+  TORCH_CHECK(
+      options.requires_grad_opt() == c10::nullopt ||
+      options.requires_grad_opt().value() == false,
+      "Operators taking TensorOptions cannot take a TensorOptions with "
+      "options.requires_grad set as true. This isn't implemented yet.");
+  TORCH_CHECK(
+      !(options.has_memory_format() && memory_format.has_value()),
+      "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
+      "the redundant setter.");
+  if (memory_format.has_value()) {
+    return memory_format;
+  } else {
+    return options.memory_format_opt();
+  }
+}
+
+}} // namespace impl namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h b/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h
new file mode 100644
index 0000000000000000000000000000000000000000..5c95fc31149c7cde43bf62d114408e17270cef62
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h
@@ -0,0 +1,139 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+namespace at {
+
+class Tensor;
+
+// This class specifies a Backend and a ScalarType. Currently, it primarily
+// serves as a replacement return value for Tensor::type(). Previously,
+// Tensor::type() returned Type&, but we are changing Type to not be
+// dtype-specific.
+class TORCH_API DeprecatedTypeProperties {
+ public:
+  DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
+    : backend_(backend), scalar_type_(scalar_type) {}
+
+  Backend backend() const {
+    return backend_;
+  }
+
+  Layout layout() const {
+    return layout_from_backend(backend_);
+  }
+
+  bool is_sparse() const {
+    return layout_from_backend(backend()) == kSparse;
+  }
+
+  bool is_sparse_csr() const {
+    return layout_from_backend(backend()) == kSparseCsr;
+  }
+
+  c10::DeviceType device_type() const {
+    return backendToDeviceType(backend_);
+  }
+
+  bool is_cuda() const {
+    return backendToDeviceType(backend_) == kCUDA;
+  }
+
+  ScalarType scalarType() const {
+    return scalar_type_;
+  }
+
+  caffe2::TypeMeta typeMeta() const {
+    return scalarTypeToTypeMeta(scalar_type_);
+  }
+
+  bool operator==(const DeprecatedTypeProperties& other) const {
+    return backend_ == other.backend() && scalar_type_ == other.scalarType();
+  }
+
+  bool operator!=(const DeprecatedTypeProperties& other) const {
+    return !(*this == other);
+  }
+
+  std::string toString() const {
+    std::string base_str;
+    if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) {
+      base_str = "UndefinedType";
+    } else {
+      base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type";
+    }
+    return base_str;
+  }
+
+  DeprecatedTypeProperties & toBackend(Backend b) const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        b, scalar_type_);
+  }
+
+  DeprecatedTypeProperties & toScalarType(ScalarType s) const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        backend_, s);
+  }
+
+  DeprecatedTypeProperties & cpu() const {
+    return toBackend(Backend::CPU);
+  }
+
+  DeprecatedTypeProperties & cuda() const {
+    return toBackend(Backend::CUDA);
+  }
+
+  DeprecatedTypeProperties & hip() const {
+    return toBackend(Backend::HIP);
+  }
+
+  DeprecatedTypeProperties & privateUser1() const {
+    return toBackend(Backend::PrivateUse1);
+  }
+
+  /// Constructs the `TensorOptions` from a type and a `device_index`.
+  TensorOptions options(int16_t device_index = -1) const {
+    return TensorOptions().dtype(typeMeta())
+                          .device(device_type(), static_cast<c10::DeviceIndex>(device_index))
+                          .layout(layout());
+  }
+
+  /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
+  /// the device type matches the device type of the type.
+  TensorOptions options(c10::optional<Device> device_opt) const {
+    if (!device_opt.has_value()) {
+      return options(-1);
+    } else {
+      Device device = device_opt.value();
+      AT_ASSERT(device.type() == device_type());
+      return options(device.index());
+    }
+  }
+
+  operator TensorOptions() const {
+    return options();
+  }
+
+  int64_t id() const {
+    return static_cast<int64_t>(backend()) *
+        static_cast<int64_t>(ScalarType::NumOptions) +
+        static_cast<int64_t>(scalarType());
+  }
+
+  Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
+  Storage unsafeStorageFromTH(void * th_pointer, bool retain) const;
+  Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<c10::Device> to_device={}) const;
+
+ private:
+  Backend backend_;
+  ScalarType scalar_type_;
+};
+
+}  // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h b/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h
new file mode 100644
index 0000000000000000000000000000000000000000..fcf7a88f8d0ad26f1dc21a687a6696f48108af5f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h
@@ -0,0 +1,32 @@
+#pragma once
+
+// In order to preserve bc, we make DeprecatedTypeProperties instances unique
+// just like they are for Type.
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+class DeprecatedTypeProperties;
+
+struct TORCH_API DeprecatedTypePropertiesDeleter {
+  void operator()(DeprecatedTypeProperties * ptr);
+};
+
+class TORCH_API DeprecatedTypePropertiesRegistry {
+ public:
+  DeprecatedTypePropertiesRegistry();
+
+  DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) const;
+
+private:
+  std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter> registry
+    [static_cast<int>(Backend::NumOptions)]
+    [static_cast<int>(ScalarType::NumOptions)];
+};
+
+TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
+
+} // namespace at
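A short sketch (assumed usage, not part of the header) of how the registry hands out the canonical per-(Backend, ScalarType) instance:

```cpp
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>

// One instance exists per (Backend, ScalarType) pair, so reference identity
// of the returned objects behaves like the old per-Type uniqueness.
at::DeprecatedTypeProperties& cpu_float =
    at::globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
        at::Backend::CPU, at::ScalarType::Float);
```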
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Dict.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Dict.h
new file mode 100644
index 0000000000000000000000000000000000000000..7808d52d32f9348b96b2195119a3255cd4f9b276
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Dict.h
@@ -0,0 +1,397 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+struct IValue;
+template<class Key, class Value> class Dict;
+struct Type;
+
+namespace impl {
+
+using valid_dict_key_types = guts::typelist::typelist<
+  int64_t,
+  std::string,
+  double,
+  c10::complex<double>,
+  bool,
+  at::Tensor
+>;
+}
+
+namespace detail {
+
+struct DictKeyHash {
+  size_t operator()(const IValue& ivalue) const;
+};
+
+struct DictKeyEqualTo {
+  bool operator()(const IValue& lhs, const IValue& rhs) const;
+};
+
+struct DictImpl final : public c10::intrusive_ptr_target {
+  using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
+  struct DictElementTypes final {
+    TypePtr keyType;
+    TypePtr valueType;
+  };
+
+  explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
+  : dict(std::move(dict_))
+  , elementTypes(std::move(elementTypes_)) {}
+  dict_map_type dict;
+
+  DictElementTypes elementTypes;
+
+  intrusive_ptr<DictImpl> copy() const;
+  friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
+};
+
+}
+
+namespace impl {
+template<class Key, class Value, class Iterator> class DictIterator;
+
+/**
+ * A reference to an entry in the Dict.
+ * Use the `key()` and `value()` methods to read the element.
+ */
+template<class Key, class Value, class Iterator>
+class DictEntryRef final {
+public:
+  explicit DictEntryRef(Iterator iterator)
+  : iterator_(std::move(iterator)) {}
+
+  decltype(auto) key() const {
+    return iterator_->first.template to<Key>();
+  }
+
+  decltype(auto) value() const {
+    return iterator_->second.template to<Value>();
+  }
+
+  template<class Value_>
+  void setValue(Value_&& value) const {
+    static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of setValue()");
+    iterator_->second = Value(std::forward<Value_>(value));
+  }
+
+private:
+  // allow copying and moving, but only our friends (i.e. the Dict class) can do
+  // it. Copying/moving this reference wrapper would be too ambiguous to allow it
+  // in the public API.
+  DictEntryRef(const DictEntryRef&) = default;
+  DictEntryRef& operator=(const DictEntryRef&) = default;
+  DictEntryRef(DictEntryRef&&) noexcept = default;
+  DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default;
+
+  Iterator iterator_;
+  friend class DictIterator<Key, Value, Iterator>;
+  friend class Dict<Key, Value>;
+};
+
+// this wraps map_type::iterator to make sure user code can't rely
+// on it being the type of the underlying map.
+template<class Key, class Value, class Iterator>
+class DictIterator final {
+public:
+   // C++17 friendly std::iterator implementation
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = DictEntryRef<Key, Value, Iterator>;
+  using difference_type = std::ptrdiff_t;
+  using pointer = value_type*;
+  using reference = value_type&;
+
+  explicit DictIterator() = default;
+  ~DictIterator() = default;
+
+  DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
+  DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
+  DictIterator& operator=(const DictIterator& rhs) {
+    entryRef_ = rhs.entryRef_;
+    return *this;
+  }
+  DictIterator& operator=(DictIterator&& rhs) noexcept {
+    entryRef_ = std::move(rhs.entryRef_);
+    return *this;
+  }
+
+  DictIterator& operator++() {
+      ++entryRef_.iterator_;
+      return *this;
+  }
+
+  DictIterator operator++(int) {
+      DictIterator copy(*this);
+      ++*this;
+      return copy;
+  }
+
+  const DictEntryRef<Key, Value, Iterator>& operator*() const {
+      return entryRef_;
+  }
+
+  const DictEntryRef<Key, Value, Iterator>* operator->() const {
+    return &entryRef_;
+  }
+
+  friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_;
+  }
+
+private:
+  explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {}
+
+  const Iterator& get_iterator_() const {
+    return entryRef_.iterator_;
+  }
+
+  friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() == rhs.get_iterator_();
+  }
+
+  friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() != rhs.get_iterator_();
+  }
+
+  friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() < rhs.get_iterator_();
+  }
+
+  friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() <= rhs.get_iterator_();
+  }
+
+  friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() > rhs.get_iterator_();
+  }
+
+  friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) {
+    return lhs.get_iterator_() >= rhs.get_iterator_();
+  }
+
+  DictEntryRef<Key, Value, Iterator> entryRef_;
+
+  friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>;
+  friend class Dict<Key, Value>;
+};
+
+template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);
+template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
+}
+
+/**
+ * An object of this class stores a map from Key to Value.
+ *
+ * This is a pointer type. After a copy, both Dicts
+ * will share the same storage:
+ *
+ * > Dict<int, string> a;
+ * > Dict<int, string> b = a;
+ * > b.insert(3, "three");
+ * > ASSERT("three" == a.at(3));
+ *
+ * We use this class in the PyTorch kernel API because that
+ * allows us to do optimizations and switch out the underlying
+ * map implementation without breaking backwards compatibility
+ * for the kernel API.
+ */
+template<class Key, class Value>
+class Dict final {
+private:
+  static_assert((std::is_same<IValue, Key>::value && std::is_same<IValue, Value>::value) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string.");
+
+  // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map.
+  // We intentionally don't offer conversion from/to
+  // order_preserving_flat_hash_map, return references to it or something like that,
+  // because such operations would get expensive if we switch out
+  // the actual map implementation.
+  // This is an intrusive_ptr because Dict is a pointer type.
+  // Invariant: This will never be a nullptr, there will always be a valid
+  // DictImpl.
+  c10::intrusive_ptr<detail::DictImpl> impl_;
+
+  explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl);
+  friend struct IValue;
+  template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>);
+  template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>);
+
+public:
+  using key_type = Key;
+  using mapped_type = Value;
+  using size_type = typename detail::DictImpl::dict_map_type::size_type;
+  using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
+
+  /**
+   * Creates an empty dict.
+   */
+  explicit Dict();
+
+  /**
+   * Create a generic dict with runtime type information.
+   * This only works for c10::impl::GenericDict and is not part of the public API
+   * but only supposed to be used internally by PyTorch.
+   */
+  explicit Dict(TypePtr keyType, TypePtr valueType);
+
+  ~Dict() = default;
+
+  Dict(const Dict&) = default;
+  Dict& operator=(const Dict&) = default;
+
+  /**
+   * Create a new Dict pointing to a deep copy of the same data.
+   * The Dict returned is a new dict with separate storage.
+   * Changes in it are not reflected in the original dict or vice versa.
+   */
+  Dict copy() const;
+
+  /**
+   * Returns an iterator to the first element of the container.
+   * If the container is empty, the returned iterator will be equal to end().
+   */
+  iterator begin() const;
+
+  /**
+   * Returns an iterator to the element following the last element of the container.
+   * This element acts as a placeholder; attempting to access it results in undefined behavior.
+   */
+  iterator end() const;
+
+  /**
+   * Checks if the container has no elements.
+   */
+  bool empty() const;
+
+  /**
+   * Returns the number of elements in the container.
+   */
+  size_type size() const;
+
+  /**
+   * Erases all elements from the container. After this call, size() returns zero.
+   * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators.
+   */
+  void clear() const;
+
+  /**
+   * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key.
+   * May invalidate any references, pointers, or iterators referring to contained elements.
+   *
+   * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place.
+   */
+  template<class Key_, class Value_>
+  std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const;
+
+  /**
+   * If an element with the given key already exists, it is overwritten with the given value.
+   * Otherwise, a new element with the given key and value are inserted.
+   * May invalidate any references, pointers, or iterators referring to contained elements.
+   *
+   * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated.
+   */
+  template<class Key_, class Value_>
+  std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const;
+
+  /**
+   * Removes the element pointed to by iter.
+   * May invalidate any references, pointers, or iterators referring to contained elements.
+   * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter.
+   */
+  void erase(iterator iter) const;
+
+  /**
+   * Removes the element with the given key, if it exists.
+   * May invalidate any references, pointers, or iterators referring to contained elements.
+   *
+   * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't.
+   */
+  C10_NODISCARD size_t erase(const Key& key) const;
+
+  /**
+   * Returns the mapped value of the element with key equivalent to key.
+   * If no such element exists, an exception of type std::out_of_range is thrown.
+   */
+  Value at(const Key& key) const;
+
+  /**
+   * Finds an element with key equivalent to key.
+   *
+   * @return Iterator to an element with key equivalent to key.
+   *         If no such element is found, past-the-end (see end()) iterator is returned.
+   */
+  iterator find(const Key& key) const;
+
+  /**
+   * Checks if there is an element with key equivalent to key in the container.
+   *
+   * @return true if there is such an element, otherwise false.
+   */
+  bool contains(const Key& key) const;
+
+  /**
+   * Increase the capacity so that at least count elements can be stored without
+   * having to reallocate or rehash.
+   */
+  void reserve(size_type count) const;
+
+  /**
+   * Value equality comparison. This function implements Python-like semantics for
+   * equality: two dicts with the same identity (e.g. same pointer) trivially
+   * compare equal, otherwise each element is compared for equality.
+   */
+  template <class Key_, class Value_>
+  friend bool operator==(
+      const Dict<Key_, Value_>& lhs,
+      const Dict<Key_, Value_>& rhs);
+  template <class Key_, class Value_>
+  friend bool operator!=(
+      const Dict<Key_, Value_>& lhs,
+      const Dict<Key_, Value_>& rhs);
+
+  /**
+   * Identity comparison. Returns true if and only if `rhs` represents the same
+   * Dict object as `this`.
+   */
+  bool is(const Dict& rhs) const;
+
+  // private API for now because the return type will change to TypePtr
+  // instead of optional once types are mandatory.
+  TypePtr keyType() const;
+  TypePtr valueType() const;
+
+  // [unsafe set type]
+  // These functions mutate the tagged type of this dictionary in place.
+  // There is no checking that the members of the dictionary are instances
+  // of the new types, nor is there a check that other IValues which
+  // hold references to this dictionary have the right static type.
+  // This functionality is used only in the unpickler, where at
+  // creation type the real type of the dictionary is unknown, but
+  // then later recovered from the static type information of the
+  // unpickled object.
+  void unsafeSetKeyType(TypePtr t);
+  void unsafeSetValueType(TypePtr t);
+};
+
+namespace impl {
+// GenericDict is how IValue stores dicts. It is, however, not part of the
+// public API. Kernels should use Dicts with concrete Key, Value types instead
+// (maybe except for some internal prim ops).
+using GenericDict = Dict<IValue, IValue>;
+
+}
+}
+
+namespace torch {
+  template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
+}
+
+#include <ATen/core/Dict_inl.h>  // IWYU pragma: keep
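A hedged usage sketch (not part of the header) illustrating the pointer semantics and `copy()` behavior described in the class comment above:

```cpp
#include <ATen/core/Dict.h>
#include <string>

void dict_semantics() {
  c10::Dict<int64_t, std::string> a;
  a.insert(3, "three");

  c10::Dict<int64_t, std::string> b = a;          // shares storage: pointer semantics
  b.insert_or_assign(4, "four");
  bool visible_through_a = a.contains(4);         // true

  c10::Dict<int64_t, std::string> c = a.copy();   // independent deep copy
  c.insert(5, "five");
  bool a_unchanged = !a.contains(5);              // true
}
```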
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Dict_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Dict_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..9340a06ac7479838f30d26dafb30beb16c2f4c9d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Dict_inl.h
@@ -0,0 +1,209 @@
+#pragma once
+
+#include 
+#include 
+
+namespace c10 {
+namespace detail {
+inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
+  if (lhs.isTensor() && rhs.isTensor()) {
+    // for tensors, we compare only by identity (following how it's done in Python).
+    return lhs.is(rhs);
+  }
+  // Otherwise, we first compare by identity for efficiency, then by value (see:
+  // [container equality])
+  return _fastEqualsForContainer(lhs, rhs);
+}
+}
+
+template<class T> decltype(auto) getTypePtr();
+std::string toString(const Type& type);
+
+namespace impl {
+
+template<class Key, class Value>
+Dict<Key, Value> toTypedDict(GenericDict dict) {
+  TORCH_INTERNAL_ASSERT(*getTypePtr<Key>() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Key types mismatch.");
+  TORCH_INTERNAL_ASSERT(*getTypePtr<Value>() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()), ">. Value types mismatch.");
+
+  return Dict<Key, Value>(std::move(dict.impl_));
+}
+
+template<class Key, class Value>
+GenericDict toGenericDict(Dict<Key, Value> dict) {
+  return GenericDict(std::move(dict.impl_));
+}
+}
+
+namespace detail {
+
+inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
+  if (ivalue.isInt()) {
+    return std::hash<int64_t>()(ivalue.toInt());
+  } else if (ivalue.isString()) {
+    return std::hash<c10::string_view>()(ivalue.toStringView());
+  } else if (ivalue.isDouble()) {
+    return std::hash<double>()(ivalue.toDouble());
+  } else if (ivalue.isComplexDouble()) {
+    return c10::hash<c10::complex<double>>()(ivalue.toComplexDouble());
+  } else if (ivalue.isBool()) {
+    return std::hash<bool>()(ivalue.toBool());
+  } else if (ivalue.isTensor()) {
+    return std::hash<TensorImpl*>()(ivalue.toTensor().unsafeGetTensorImpl());
+  } else if (ivalue.isDevice()) {
+    return std::hash<Device>()(ivalue.toDevice());
+  } else {
+    throw std::runtime_error(
+        "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
+  }
+}
+
+inline intrusive_ptr<DictImpl> DictImpl::copy() const {
+  return make_intrusive<DictImpl>(dict, elementTypes);
+}
+
+}
+
+template<class Key, class Value>
+Dict<Key, Value>::Dict()
+  :Dict(make_intrusive<detail::DictImpl>(
+      detail::DictImpl::dict_map_type(),
+      detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
+  static_assert(!std::is_same<Key, IValue>::value, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
+  static_assert(!std::is_same<Value, IValue>::value, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
+}
+
+template
+Dict::Dict(TypePtr keyType, TypePtr valueType)
+: Dict(make_intrusive(
+    detail::DictImpl::dict_map_type(),
+    detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
+  static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict.");
+  static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict.");
+}
+
+template
+Dict::Dict(c10::intrusive_ptr&& impl): impl_(std::move(impl)) {}
+
+template
+Dict Dict::copy() const {
+  return Dict(impl_->copy());
+}
+
+template
+typename Dict::iterator Dict::begin() const {
+  return iterator{impl_->dict.begin()};
+}
+
+template
+typename Dict::iterator Dict::end() const {
+  return iterator{impl_->dict.end()};
+}
+
+template
+bool Dict::empty() const {
+  return impl_->dict.empty();
+}
+
+template
+typename Dict::size_type Dict::size() const {
+  return impl_->dict.size();
+}
+
+template
+void Dict::clear() const {
+  impl_->dict.clear();
+}
+
+template
+template
+std::pair::iterator, bool> Dict::insert(Key_&& key, Value_&& value) const {
+  static_assert(std::is_constructible::value, "Wrong type for the key argument of Dict::insert");
+  static_assert(std::is_constructible::value, "Wrong type for the value argument of Dict::insert");
+  auto inserted = impl_->dict.emplace(
+      Key(std::forward(key)),
+      Value(std::forward(value)));
+  return {iterator{inserted.first}, inserted.second};
+}
+
+template
+template
+std::pair::iterator, bool> Dict::insert_or_assign(Key_&& key, Value_&& value) const {
+  static_assert(std::is_constructible::value, "Wrong type for the key argument of Dict::insert_or_assign");
+  static_assert(std::is_constructible::value, "Wrong type for the value argument of Dict::insert_or_assign");
+  auto inserted = impl_->dict.insert_or_assign(
+    Key(std::forward(key)),
+    Value(std::forward(value)));
+  return {iterator{inserted.first}, inserted.second};
+}
+
+template
+void Dict::erase(iterator iter) const {
+  impl_->dict.erase(iter.entryRef_.iterator_);
+}
+
+template
+C10_NODISCARD size_t Dict::erase(const Key& key) const {
+  return impl_->dict.erase(key);
+}
+
+template
+Value Dict::at(const Key& key) const {
+  return impl_->dict.at(key).template to();
+}
+
+template
+typename Dict::iterator Dict::find(const Key& key) const {
+  return iterator{impl_->dict.find(key)};
+}
+
+template
+bool Dict::contains(const Key& key) const {
+  return end() != find(key);
+}
+
+template
+void Dict::reserve(size_type count) const {
+  impl_->dict.reserve(count);
+}
+
+template
+TypePtr Dict::keyType() const {
+  return impl_->elementTypes.keyType;
+}
+
+template
+TypePtr Dict::valueType() const {
+  return impl_->elementTypes.valueType;
+}
+template 
+void Dict::unsafeSetKeyType(TypePtr t) {
+  impl_->elementTypes.keyType = std::move(t);
+}
+
+template 
+void Dict::unsafeSetValueType(TypePtr t) {
+  impl_->elementTypes.valueType = std::move(t);
+}
+
+template 
+bool operator==(const Dict& lhs, const Dict& rhs) {
+  // Dicts with the same identity trivially compare equal.
+  if (lhs.impl_ == rhs.impl_) {
+    return true;
+  }
+
+  // Otherwise compare the values
+  return *lhs.impl_ == *rhs.impl_;
+}
+
+template 
+bool operator!=(const Dict& lhs, const Dict& rhs) {
+  return !(lhs == rhs);
+}
+
+template 
+bool Dict::is(const Dict& rhs) const {
+  return this->impl_ == rhs.impl_;
+}
+}
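A small sketch (assumed usage) of the key semantics implemented above: tensor keys are hashed and compared by identity, mirroring Python dict behavior:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>

void tensor_keyed_dict() {
  c10::Dict<at::Tensor, int64_t> d;
  at::Tensor k = at::zeros({2});
  d.insert(k, 1);
  bool same_object = d.contains(k);                // true: same tensor identity
  bool equal_values = d.contains(at::zeros({2}));  // false: distinct tensor, even if equal
}
```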
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/DimVector.h b/MLPY/Lib/site-packages/torch/include/ATen/core/DimVector.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d0318b7e3bd6b6207c9b2e333b6fdf99eaf0585
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/DimVector.h
@@ -0,0 +1,13 @@
+#pragma once
+#include 
+
+namespace at {
+
+// Re-declaring 'DimVector' type and size inside 'at' namespace.
+// This is done to avoid modifying every use into their 'c10'
+// equivalent.
+
+using c10::kDimVectorStaticSize;
+using c10::DimVector;
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Dimname.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Dimname.h
new file mode 100644
index 0000000000000000000000000000000000000000..9ac2abe3ac0ae8a78af55a426a325b072e32439d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Dimname.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+enum class NameType: uint8_t { BASIC, WILDCARD };
+
+struct TORCH_API Dimname {
+  static Dimname fromSymbol(Symbol name);
+  static Dimname wildcard();
+  static bool isValidName(const std::string& name);
+
+  NameType type() const { return type_; }
+  Symbol symbol() const { return name_; }
+
+  bool isBasic() const { return type_ == NameType::BASIC; }
+  bool isWildcard() const { return type_ == NameType::WILDCARD; }
+
+  bool matches(Dimname other) const;
+  c10::optional<Dimname> unify(Dimname other) const;
+
+ private:
+  Dimname(Symbol name)
+    : name_(name), type_(NameType::BASIC) {}
+  Dimname(Symbol name, NameType type)
+    : name_(name), type_(type) {}
+
+  Symbol name_;
+  NameType type_;
+};
+
+using DimnameList = c10::ArrayRef<Dimname>;
+
+TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
+
+inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
+  return lhs.symbol() == rhs.symbol();
+}
+
+inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
+  return !(lhs == rhs);
+}
+
+} // namespace at
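A brief sketch (assumed usage; `at::Symbol::dimname` is taken from the interned-strings API and is an assumption here) of basic vs. wildcard names:

```cpp
#include <ATen/core/Dimname.h>

void dimname_example() {
  at::Dimname n = at::Dimname::fromSymbol(at::Symbol::dimname("channels"));
  at::Dimname w = at::Dimname::wildcard();
  bool ok = n.matches(w);                     // a wildcard matches any name
  c10::optional<at::Dimname> u = n.unify(w);  // unification keeps "channels"
}
```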
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h b/MLPY/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae4a73662fc74d6b75177ea87bbe533034011696
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h
@@ -0,0 +1,337 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+/**
+ * Distributions kernel adapted from THRandom.cpp
+ * The kernels try to follow std::random distributions signature
+ * For instance: in ATen
+ *      auto gen = at::detail::createCPUGenerator();
+ *      at::uniform_real_distribution<double> uniform(0, 1);
+ *      auto sample = uniform(gen.get());
+ *
+ *      vs std::random
+ *
+ *      std::mt19937 gen;
+ *      std::uniform_real_distribution<double> uniform(0, 1);
+ *      auto sample = uniform(gen);
+ */
+
+
+namespace at {
+namespace {
+
+/**
+ * Samples a discrete uniform distribution in the range [base, base+range) of type T
+ */
+template 
+struct uniform_int_from_to_distribution {
+
+  C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {}
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    if ((
+      std::is_same<T, int64_t>::value ||
+      std::is_same<T, double>::value ||
+      std::is_same<T, float>::value ||
+      std::is_same<T, at::BFloat16>::value) && range_ >= 1ULL << 32)
+    {
+      return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
+    } else {
+      return transformation::uniform_int_from_to<T>(generator->random(), range_, base_);
+    }
+  }
+
+  private:
+    uint64_t range_;
+    int64_t base_;
+};
+
+/**
+ * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)]
+ */
+template 
+struct uniform_int_full_range_distribution {
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    return transformation::uniform_int_full_range<T>(generator->random64());
+  }
+
+};
+
+/**
+ * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types
+ * and [0, 2^mantissa] for floating-point types.
+ */
+template 
+struct uniform_int_distribution {
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    if constexpr (std::is_same_v<T, double> || std::is_same_v<T, int64_t>) {
+      return transformation::uniform_int<T>(generator->random64());
+    } else {
+      return transformation::uniform_int<T>(generator->random());
+    }
+  }
+
+};
+
+/**
+ * Samples a uniform distribution in the range [from, to) of type T
+ */
+template 
+struct uniform_real_distribution {
+
+  C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
+    TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
+    from_ = from;
+    to_ = to;
+  }
+
+  template 
+  C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
+    if constexpr (std::is_same_v<T, double>) {
+      return transformation::uniform_real<T>(generator->random64(), from_, to_);
+    } else {
+      return transformation::uniform_real<T>(generator->random(), from_, to_);
+    }
+  }
+
+  private:
+    T from_;
+    T to_;
+};
+
+// The SFINAE checks introduced in #39816 looks overcomplicated and must revisited
+// https://github.com/pytorch/pytorch/issues/40052
+#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)              \
+template                                                 \
+struct has_member_##member                                           \
+{                                                                    \
+    typedef char yes;                                                \
+    typedef long no;                                                 \
+    template  static yes test(decltype(&U::member));     \
+    template  static no test(...);                       \
+    static constexpr bool value = sizeof(test(0)) == sizeof(yes); \
+}
+
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
+DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
+
+#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE)                                      \
+                                                                                                    \
+template ::value &&                                   \
+            has_member_set_next_##TYPE##_normal_sample::value                                  \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) {  \
+  if (generator->next_##TYPE##_normal_sample()) {                                                   \
+    *ret = *(generator->next_##TYPE##_normal_sample());                                             \
+    generator->set_next_##TYPE##_normal_sample(c10::optional());                              \
+    return true;                                                                                    \
+  }                                                                                                 \
+  return false;                                                                                     \
+}                                                                                                   \
+                                                                                                    \
+template ::value ||                                  \
+            !has_member_set_next_##TYPE##_normal_sample::value                                 \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) {  \
+  return false;                                                                                     \
+}                                                                                                   \
+                                                                                                    \
+template ::value                                  \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
+  generator->set_next_##TYPE##_normal_sample(cache);                                                \
+}                                                                                                   \
+                                                                                                    \
+template ::value                                 \
+          ), int> = 0>                                                                              \
+C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \
+}
+
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
+DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
+
+/**
+ * Samples a normal distribution using the Box-Muller method
+ * Takes mean and standard deviation as inputs
+ * Note that Box-muller method returns two samples at a time.
+ * Hence, we cache the "next" sample in the CPUGeneratorImpl class.
+ */
+template 
+struct normal_distribution {
+
+  C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in);
+    mean = mean_in;
+    stdv = stdv_in;
+  }
+
+  template 
+  C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
+    dist_acctype<T> ret;
+    // return cached values if available
+    if constexpr (std::is_same_v<T, double>) {
+      if (maybe_get_next_double_normal_sample(generator, &ret)) {
+        return transformation::normal(ret, mean, stdv);
+      }
+    } else {
+      if (maybe_get_next_float_normal_sample(generator, &ret)) {
+        return transformation::normal(ret, mean, stdv);
+      }
+    }
+    // otherwise generate new normal values
+    uniform_real_distribution<T> uniform(0.0, 1.0);
+    const dist_acctype<T> u1 = uniform(generator);
+    const dist_acctype<T> u2 = uniform(generator);
+    const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log1p(-u2));
+    const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
+    if constexpr (std::is_same_v<T, double>) {
+      maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
+    } else {
+      maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
+    }
+    ret = r * ::cos(theta);
+    return transformation::normal(ret, mean, stdv);
+  }
+
+  private:
+    T mean;
+    T stdv;
+};
+
+template <typename T>
+struct DiscreteDistributionType { using type = float; };
+
+template <> struct DiscreteDistributionType<double> { using type = double; };
+
+/**
+ * Samples a bernoulli distribution given a probability input
+ */
+template 
+struct bernoulli_distribution {
+
+  C10_HOST_DEVICE inline bernoulli_distribution(T p_in) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1);
+    p = p_in;
+  }
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    uniform_real_distribution<T> uniform(0.0, 1.0);
+    return transformation::bernoulli<T>(uniform(generator), p);
+  }
+
+  private:
+    T p;
+};
+
+/**
+ * Samples a geometric distribution given a probability input
+ */
+template 
+struct geometric_distribution {
+
+  C10_HOST_DEVICE inline geometric_distribution(T p_in) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1);
+    p = p_in;
+  }
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    uniform_real_distribution<T> uniform(0.0, 1.0);
+    return transformation::geometric<T>(uniform(generator), p);
+  }
+
+  private:
+    T p;
+};
+
+/**
+ * Samples an exponential distribution given a lambda input
+ */
+template 
+struct exponential_distribution {
+
+  C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {}
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    uniform_real_distribution<T> uniform(0.0, 1.0);
+    return transformation::exponential<T>(uniform(generator), lambda);
+  }
+
+  private:
+    T lambda;
+};
+
+/**
+ * Samples a cauchy distribution given median and sigma as inputs
+ */
+template 
+struct cauchy_distribution {
+
+  C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {}
+
+  template 
+  C10_HOST_DEVICE inline T operator()(RNG generator) {
+    uniform_real_distribution<T> uniform(0.0, 1.0);
+    return transformation::cauchy<T>(uniform(generator), median, sigma);
+  }
+
+  private:
+    T median;
+    T sigma;
+};
+
+/**
+ * Samples a lognormal distribution
+ * Takes mean and standard deviation as inputs
+ * Outputs two samples at a time
+ */
+template 
+struct lognormal_distribution {
+
+  C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) {
+    TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
+    mean = mean_in;
+    stdv = stdv_in;
+  }
+
+  template
+  C10_HOST_DEVICE inline T operator()(RNG generator){
+    normal_distribution<T> normal(mean, stdv);
+    return transformation::log_normal<T>(normal(generator));
+  }
+
+  private:
+    T mean;
+    T stdv;
+};
+}
+} // namespace at
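A sketch of the usage pattern from the header comment (assuming the CPU generator factory declared in `ATen/CPUGeneratorImpl.h`); the lock follows Note [Acquire lock when using random generators]:

```cpp
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/DistributionsHelper.h>
#include <mutex>

void sample_example() {
  at::Generator gen = at::detail::createCPUGenerator(/*seed_val=*/42);
  std::lock_guard<std::mutex> lock(gen.mutex());
  auto* rng = at::check_generator<at::CPUGeneratorImpl>(gen);

  at::uniform_real_distribution<double> uniform(0.0, 1.0);
  double u = uniform(rng);                          // sample in [0, 1)

  at::normal_distribution<double> normal(/*mean_in=*/0.0, /*stdv_in=*/1.0);
  double n = normal(rng);                           // Box-Muller; second sample is cached
}
```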
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Formatting.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Formatting.h
new file mode 100644
index 0000000000000000000000000000000000000000..05b22d474582873e382a8bfde55758c951e154e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Formatting.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+
+namespace c10 {
+TORCH_API std::ostream& operator<<(std::ostream& out, Backend b);
+TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s);
+TORCH_API std::string toString(const Scalar& s);
+}
+namespace at {
+
+TORCH_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
+TORCH_API std::ostream& print(
+    std::ostream& stream,
+    const Tensor& tensor,
+    int64_t linesize);
+static inline std::ostream& operator<<(std::ostream & out, const Tensor & t) {
+  return print(out,t,80);
+}
+TORCH_API void print(const Tensor & t, int64_t linesize=80);
+}
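A tiny sketch (assumed usage) of the printing helpers declared above:

```cpp
#include <ATen/ATen.h>
#include <iostream>

void print_example(const at::Tensor& t) {
  std::cout << t << '\n';           // operator<< forwards to print() with linesize 80
  at::print(t, /*linesize=*/120);   // allow wider rows before wrapping
  std::cout << t.type() << '\n';    // DeprecatedTypeProperties, e.g. "CPUFloatType"
}
```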
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Generator.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..0dfefa1217177ad0d309f2913e25b1651c25cc56
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Generator.h
@@ -0,0 +1,190 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+// For the record I don't think this is a correct pimpl idiom.
+// Including Impl header in interface header defeats the purpose
+// because you can't change Impl private members without forcing
+// everything that included the interface to rebuild.
+// Impl should be forward-declared in the interface header instead.
+#include 
+
+/**
+ * Note [Generator]
+ * ~~~~~~~~~~~~~~~~
+ * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm to
+ * generate a seemingly random sequence of numbers that may later be used to create
+ * a random distribution. Such an engine almost always maintains a state and requires a
+ * seed to start off the creation of random numbers. Often times, users have
+ * found it beneficial to be able to explicitly create, retain, and destroy
+ * PRNG states and also be able to have control over the seed value.
+ *
+ * A Generator in ATen gives users the ability to read, write and modify a PRNG engine.
+ * For instance, it does so by letting users seed a PRNG engine, fork the state of the
+ * engine, etc.
+ *
+ * By default, there is one generator per device, and a device's generator is
+ * lazily created. A user can use the torch.Generator() api to create their own generator.
+ */
+
+/**
+ * Note [Acquire lock when using random generators]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Generator and its derived classes are NOT thread-safe. Please note that most of the
+ * places where we have inserted locking for generators are historically based, and we
+ * haven't actually checked that everything is truly thread safe (and it probably isn't).
+ * Please use the public mutex_ when using any methods from these classes, except for the
+ * read-only methods. You can learn about the usage by looking into the unittests
+ * (aten/src/ATen/cpu_generator_test.cpp) and other places where we have used lock_guard.
+ *
+ * TODO: Look into changing the threading semantics of Generators in ATen (e.g., making
+ * them non-thread safe and instead making the generator state splittable, to accommodate
+ * forks into other threads).
+ */
+
+namespace at {
+
+class Tensor;
+
+struct TORCH_API Generator {
+  Generator() = default;
+
+  explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
+   : impl_(std::move(gen_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("GeneratorImpl with nullptr is not supported");
+    }
+  }
+
+  bool operator==(const Generator& rhs) const {
+    return this->impl_ == rhs.impl_;
+  }
+
+  bool operator!=(const Generator& rhs) const {
+    return !((*this) == rhs);
+  }
+
+  bool defined() const {
+    return static_cast<bool>(impl_);
+  }
+
+  c10::GeneratorImpl* unsafeGetGeneratorImpl() const {
+    return impl_.get();
+  }
+
+  c10::GeneratorImpl* unsafeReleaseGeneratorImpl() {
+    return impl_.release();
+  }
+
+  const c10::intrusive_ptr<c10::GeneratorImpl>& getIntrusivePtr() const {
+    return impl_;
+  }
+
+  void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
+  // Sets the offset of Generator state to the desired offset. This is currently
+  // supported for only Philox based Generators, i.e., CUDA and MPS.
+  void set_offset(uint64_t offset) { impl_->set_offset(offset); }
+
+  // Returns the offset of Generator state. This is currently supported for only
+  // Philox based Generators, i.e., CUDA and MPS.
+  uint64_t get_offset() const { return impl_->get_offset(); }
+
+  uint64_t current_seed() const { return impl_->current_seed(); }
+
+  uint64_t seed() { return impl_->seed(); }
+
+  // Implementation not inlined to prevent cycle reference between
+  // `ATen/core/Generator.h` and `ATen/core/Tensor.h`
+  void set_state(const at::Tensor& new_state);
+
+  at::Tensor get_state() const;
+
+  std::mutex& mutex() {
+    return impl_->mutex_;
+  }
+
+  DispatchKeySet key_set() const {
+    return impl_->key_set();
+  }
+
+  Device device() const { return impl_->device(); }
+
+  inline void set_pyobj(PyObject* pyobj) const noexcept {
+    impl_->set_pyobj(pyobj);
+  }
+
+  inline PyObject* pyobj() const noexcept {
+    return impl_->pyobj();
+  }
+
+  template<typename T>
+  T* get() const { return static_cast<T*>(impl_.get()); }
+
+  Generator clone() const {
+    return Generator(impl_->clone());
+  }
+
+ private:
+  c10::intrusive_ptr<c10::GeneratorImpl> impl_;
+};
+
+template<class Impl, class... Args>
+Generator make_generator(Args&&... args) {
+  return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
+}
+
+/**
+ * Utility function to static cast input Generator* to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+template <typename T>
+static inline T * check_generator(c10::optional<Generator> gen) {
+  TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
+  TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
+  TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
+  return gen->get<T>();
+}
+
+/**
+ * Utility function used in tensor implementations, which
+ * supplies the default generator to tensors, if an input generator
+ * is not supplied. The input Generator* is also static casted to
+ * the backend generator type (CPU/CUDAGeneratorImpl etc.)
+ */
+template <typename T>
+static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {
+  return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
+}
+
+namespace detail {
+
+/**
+ * Helper function for checking the validity of new random generator
+ * state. Right now following conditions are checked:
+ *
+ * - The new state tensor must be a torch.ByteTensor
+ * - Data of the new state tensor must be contiguous
+ */
+static inline void check_rng_state(const c10::TensorImpl& new_state) {
+  TORCH_CHECK_TYPE(
+    new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
+    "RNG state must be a torch.ByteTensor"
+  );
+
+  TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
+}
+
+} // namespace detail
+
+} // namespace at
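A minimal sketch (assuming the CPU backend implementation in `ATen/CPUGeneratorImpl.h`) of creating, seeding, and cloning a generator as described in the notes above:

```cpp
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/Generator.h>
#include <mutex>

void generator_example() {
  at::Generator gen = at::make_generator<at::CPUGeneratorImpl>(/*seed_in=*/123);
  {
    // Generators are not thread-safe; hold the public mutex around state changes.
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(7);
  }
  uint64_t seed = gen.current_seed();            // 7
  at::Generator forked = gen.clone();            // independent copy of the RNG state
  auto* impl = gen.get<at::CPUGeneratorImpl>();  // backend-specific access
}
```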
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h b/MLPY/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h
new file mode 100644
index 0000000000000000000000000000000000000000..2daa607f02ec720d30a3d9a10c8b429bb1703716
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/GeneratorForPrivateuseone.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
+
+c10::optional<GeneratorFuncType>& GetGeneratorPrivate();
+
+class TORCH_API _GeneratorRegister {
+ public:
+  explicit _GeneratorRegister(const GeneratorFuncType& func);
+};
+
+TORCH_API at::Generator GetGeneratorForPrivateuse1(
+    c10::DeviceIndex device_index);
+
+/**
+ * This is used to register Generator to PyTorch for `privateuse1` key.
+ *
+ * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
+ *
+ * class CustomGeneratorImpl : public c10::GeneratorImpl {
+ *   CustomGeneratorImpl(DeviceIndex device_index = -1);
+ *   explicit ~CustomGeneratorImpl() override = default;
+ *   ...
+ * };
+ *
+ * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
+ *   return at::make_generator<CustomGeneratorImpl>(id);
+ * }
+ */
+
+#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
+  static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate);
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef.h b/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef.h
new file mode 100644
index 0000000000000000000000000000000000000000..d0e1f9e063e3d87e729c037d2b61c913288619f0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef.h
@@ -0,0 +1,631 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+/*
+ * [Note: IListRef]
+ * Wrapper around different API containers (e.g. boxed and unboxed).
+ *
+ * What is it?
+ * ===========
+ * It is a tagged union of both boxed and unboxed API containers.
+ * Working implementations:
+ *
+ * - `IListRef<at::Tensor>`
+ * - `IListRef<at::OptionalTensorRef>`
+ *
+ * Note that `IListRef` is a view type. Meaning that it won't own the
+ * tensors it holds. It's intended to be used only as argument parameters.
+ * Specifically, where these 2 worlds overlap.
+ *
+ * What is this for?
+ * =================
+ * Historically, PyTorch has maintained 2 different APIs: the unboxed
+ * (called from C++ API and Python eager mode) and boxed APIs (called
+ * from the TorchScript JIT, mobile interpreter, and boxed fallbacks).
+ *
+ * Calling unboxed kernels from the boxed "world" and vice-versa may
+ * result in non-negligible overhead. Lists are one of those types:
+ *
+ * - Boxed world: `c10::List<Tensor>`
+ * - Unboxed world: `c10::ArrayRef<Tensor>`
+ *
+ * In this context, `c10::IListRef` solves this problem by wrapping those
+ * 2 container types, so that we don't need to convert from one to
+ * the other.
+ *
+ * (see https://github.com/pytorch/pytorch/issues/66328)
+ *
+ * What does it do?
+ * ================
+ * This container wraps around the different tagged containers
+ * (currently, only boxed and unboxed), without incurring in extra
+ * overhead for converting from one to another. It does so while
+ * exposing usual container methods, which dispatch to corresponding
+ * implementations.
+ *
+ * While it works with different container types, it introduces
+ * overhead for repeatedly calling member functions (since those will
+ * get dispatched, again). Therefore, you should only use it to iterate
+ * through the list up to one time. If you need to do more complex things,
+ * call `materialize()` first.
+ *
+ * Adding support for a new Tag
+ * ============================
+ * Suppose we want to add a new tag: `Chest`. Here are the steps
+ * we would have to go through:
+ *
+ * 1. Add a line for it in the macro `TORCH_ILISTREF_FORALL_TAGS`.
+ *
+ *   #define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
+ *     ...
+ *     _(Chest, ##__VA_ARGS__)
+ *
+ * 2. Add type aliases, union members, and constructors.
+ *
+ *   template 
+ *   class IListRef {
+ *     ...
+ *     using chest_type =
+ *       typename detail::IListRefTagImpl::list_type;
+ *     ...
+ *     IListRef(...) : tag_(IListRefTag::Chest) {
+ *       ...
+ *     }
+ *     ...
+ *     union Payload {
+ *       ...
+ *       chest_type chest;
+ *       ...
+ *     };
+ *     ...
+ *   };
+ *
+ * 3. Add a default implementation for it (in 'IListRef_inl.h'). It's
+ *    preferable to make the default implementation work for `T = Tensor`
+ *    (both `Unboxed` and `Boxed` do it).
+ *
+ *   template 
+ *   class IListRefTagImplBase {
+ *    public:
+ *     using elem_type = ListElemT;
+ *     using list_type = ChestContainer;
+ *
+ *     static const list_type& unwrap(const IListRef& ilist) { ... }
+ *
+ *     static typename list_type::const_iterator& unwrap(
+ *         IListRefIterator& it) { ... }
+ *
+ *     static const typename list_type::const_iterator& unwrap(
+ *         const IListRefIterator& it) { ... }
+ *
+ *     static IListRefConstRef iterator_get(
+ *         const typename list_type::const_iterator& it) { ... }
+ *   }
+ *
+ * 4. Add an specialization for each of the already supported types.
+ *    Finally, for consistency, add them to the tracking list.
+ *    (see [Note: IListRefTagImpl Specializations])
+ *
+ *   template <>
+ *   class IListRefTagImpl
+ *       : public IListRefTagImplBase {};
+ *
+ * Adding support for a new Type
+ * =============================
+ * Suppose we want to add support for a new type: `Matrix`.
+ * Here are the steps we would have to go through:
+ *
+ * 1. Add an specialization for each of the existing tags.
+ *    For consistency, add them to the tracking list.
+ *    (see [Note: IListRefTagImpl Specializations])
+ *
+ *   template <>
+ *   class IListRefTagImpl
+ *       : public IListRefTagImplBase {};
+ *
+ *   template <>
+ *   class IListRefTagImpl
+ *       : public IListRefTagImplBase {};
+ *
+ * Common Problems
+ * ===============
+ * 1. One of `IListRef(Iterator)` methods are failing to compile.
+ *
+ *     That may be happening because the container type you added
+ *     is not compatible with the code written for that method. If
+ *     that's true, then you might have to transform that code into
+ *     a static method call (see `List::operator[]` method).
+ *
+ * 2. Can't make `IListRefIterator::operator*` return a const-reference.
+ *
+ *    First, keep in mind that we assume that boxed containers will
+ *    have to deal with `IValue` (e.g. `c10::List`). In this context,
+ *    what may be happening is that `IValue` doesn't store internally
+ *    your type `T`. Instead, it constructs a new `T` every time
+ *    you try to get `T` for it (see `IListRef`).
+ */
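A hedged sketch of a kernel-style consumer of this wrapper (the common `at::ITensorListRef` alias for `c10::IListRef<at::Tensor>` is assumed from the ATen headers):

```cpp
#include <ATen/ATen.h>

// Single pass over either a boxed c10::List<at::Tensor> or an unboxed
// ArrayRef of tensors, without converting between the two representations.
int64_t count_defined(c10::IListRef<at::Tensor> tensors) {
  int64_t n = 0;
  for (const at::Tensor& t : tensors) {  // each dereference dispatches on the tag
    if (t.defined()) {
      ++n;
    }
  }
  // For repeated or random access, call tensors.materialize() once instead.
  return n;
}
```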
+
+namespace c10 {
+template 
+class IListRef;
+
+/*
+ * Applies arbitrary macros to each `IListRefTag`.
+ */
+#define TORCH_ILISTREF_FORALL_TAGS(_, ...) \
+  _(Unboxed, ##__VA_ARGS__)                \
+  _(Boxed, ##__VA_ARGS__)                  \
+  _(Materialized, ##__VA_ARGS__)
+
+/*
+ * Defines a "switch-case" for `TAG`. Inside, it executes `BODY`,
+ * while bringing to scope:
+ *
+ * - `ImplT`: the implementation class for `TAG`
+ * - `this_`: the result of unwrapping `this`
+ */
+#define TORCH_ILISTREF_UNWRAP_CASE(TAG, BODY)                        \
+  case c10::IListRefTag::TAG: {                                      \
+    using ImplT = c10::detail::IListRefTagImpl; \
+    auto& this_ = ImplT::unwrap(*this);                              \
+    BODY                                                             \
+  } break;
+
+/*
+ * Dispatches the unwrap call, depending on `TAG`, followed by
+ * the execution of `BODY`. It aborts if `TAG` is not a `IListRefTag`.
+ *
+ * This macro is useful because it allows us to handle different
+ * types (that correspond to different tags) to be implemented
+ * only once. We can do it even when the implementation of the
+ * different tags aren't syntatically the same, by dispatching
+ * it to a function (e.g. `ImplT::(this_)`).
+ */
+#define TORCH_ILISTREF_UNWRAP(TAG, BODY)                         \
+  switch (TAG) {                                                 \
+    TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
+    break;                                                       \
+    default:                                                     \
+      TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");     \
+  }
+
+enum class IListRefTag {
+#define DEFINE_TAG(tag, ...) tag,
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_TAG)
+#undef DEFINE_TAG
+      None
+};
+
+namespace detail {
+/*
+ * Type alias that specifies whether we return a reference or a copy of `T`.
+ *
+ * What is this for?
+ * =================
+ * Since values in the boxed world are represented by an `IValue`, we also
+ * depend on whether it can be converted to a const-reference (`Tensor`) or
+ * has to create a new copy of `T` (`OptionalTensorRef`).
+ */
+template <typename T>
+using IListRefConstRef = typename ivalue_to_const_ref_overload_return<T>::type;
+
+/*
+ * Interface that implements key functions for each `IListRefTag` type.
+ *
+ * What is this for?
+ * =================
+ * Given an `IListRef(Iterator)`, some methods have to be implemented
+ * differently for each `TAG`. Therefore, the methods inside this class
+ * are used as dispatch targets for the different `IListRefTag` values.
+ *
+ * You should create an specialization of this class for each possible
+ * combination of `IListRefTag` type (except `None`) and element types
+ * (e.g. `Tensor`).
+ *
+ * What does it do?
+ * ================
+ * 1. defines static methods to be used as dispatch targets by both
+ *    `IListRef` and `IListRefIterator` (see the implementation of
+ *    `IListRefTagImplBase`).
+ *
+ * 2. defines the `elem_type` and `list_type` aliases that will be
+ *    used in the definition of `IListRef`. In general, we should do
+ *    so by inheriting from `IListRefTagImplBase`.
+ *
+ * [Note: IListRefTagImpl Specialization]
+ * ======================================
+ * For `IListRef(Iterator)`:
+ * - 
+ * - 
+ * - 
+ *
+ * For `IListRef(Iterator)`:
+ * - 
+ * - 
+ * - 
+ */
+template <IListRefTag TAG, typename T>
+class IListRefTagImpl {};
+
+/*
+ * Base implementation of `IListRefTagImpl` methods.
+ *
+ * What is this for?
+ * =================
+ * This should make adding specializations for new types easier. For
+ * example, one should be able to add a new type just by making its
+ * `IListRefTagImpl` specialization inherit from `IListRefTagImplBase`.
+ *
+ * You should create a partial specialization for this class only if
+ * you introduce a new `IListRefTag`. The idea being that there is one
+ * default implementation for each possible value of `IListRefTag`.
+ *
+ * What does it do?
+ * ================
+ * 1. defines `elem_type` as an alias to `ListElemT`.
+ *
+ * 1. defines `list_type` as an alias to the default container type
+ *    that will hold a collection of `elem_type`. The idea being that
+ *    all types tagged as `TAG` will have `list_type` as its container,
+ *    with different `elem_type`.
+ *
+ * 3. defines the default implementation for each of the methods that
+ *    are supposed to be defined on `IListRefTagImpl` specializations.
+ *
+ * 4. inheriting from `IListRefTagImplBase` also means
+ *    that the payload of the type `IListRef` will be of type `list_type`
+ *    when it is tagged as `TAG`.
+ */
+template <IListRefTag TAG, typename T, typename ListElemT>
+class IListRefTagImplBase {};
+
+/*
+ * Materialized container for `IListRef`.
+ *
+ * What is this for?
+ * =================
+ * Container that groups `T` references together. This exchanges the
+ * overhead of every method call from `IListRef` for a dynamic allocation.
+ *
+ * You should use this container instead of `IListRef` if:
+ *
+ *   - You are going to iterate the list more than once
+ *   - You need to repeatedly access arbitrary elements (using `operator[]`)
+ * What does it do?
+
+ * ================
+ * Removes the reference (&) from the type, and wraps it into a
+ * `std::reference_wrapper`. If `IListRefConstRef` is not a
+ * reference type, then it's left unchanged.
+ */
+template <typename T>
+using _MaterializedIListRefElem = typename std::conditional<
+    std::is_reference<T>::value,
+    typename std::reference_wrapper<typename std::remove_reference<T>::type>,
+    T>::type;
+
+template <typename T>
+using MaterializedIListRefElem = _MaterializedIListRefElem<IListRefConstRef<T>>;
+
+template <typename T>
+using MaterializedIListRef = std::vector<MaterializedIListRefElem<T>>;
+
+} // namespace detail
+
+/*
+ * Iterator for `IListRef`.
+ *
+ * What is it?
+ * ===========
+ * Currently, a `std::bidirectional_iterator` that wraps the iterator
+ * types defined for each of the `IListRefTag`.
+ *
+ * One should be able to use it, as if it were the unwrapped
+ * iterators themselves.
+
+ * What does it do?
+ * ================
+ * Similarly to `IListRef`, this is a wrapper class. Specifically, it
+ * wraps each container's `const_iterator` type alias. So, for example,
+ * given that the container for `IListRefTag::Boxed` is `c10::List`, this
+ * iterator will wrap a `c10::List::const_iterator`.
+ *
+ * [Note: MSVC Iterator Debug]
+ * ===========================
+ * MSVC `vector::iterator` implementation (used in the boxed variant)
+ * makes it so this union's destructor, copy-constructor (assignment), and
+ * move-constructor (assignment) are implicitly deleted.
+ *
+ * Therefore, we need to explicitly define them as needed. Follows a list
+ * of places where these are needed and their reason:
+ *
+ *   - `Payload` destructor:
+ *     it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is set to 2.
+ *
+ *   - `IListRefIterator` destructor:
+ *     same as above. However, we need to explicitly call the variant
+ *     destructor explicitly.
+ *
+ *   - `IListRefIterator` copy-constructor:
+ *     it is deleted only if the macro `_ITERATOR_DEBUG_LEVEL` is different
+ *     than 0.
+ */
+template <typename T>
+class IListRefIterator {
+ private:
+#define DEFINE_FRIEND_CLASS(TAG, ...)                        \
+  friend class detail::IListRefTagImpl; \
+  friend class detail::IListRefTagImplBase<                  \
+      IListRefTag::TAG,                                      \
+      T,                                                     \
+      typename detail::IListRefTagImpl::elem_type>;
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
+#undef DEFINE_FRIEND_CLASS
+
+ public:
+  // C++17 friendly std::iterator implementation
+  using iterator_category = std::bidirectional_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T*;
+  using reference = T&;
+
+  using unboxed_iterator_type = typename detail::
+      IListRefTagImpl<IListRefTag::Unboxed, T>::list_type::const_iterator;
+  using boxed_iterator_type = typename detail::
+      IListRefTagImpl<IListRefTag::Boxed, T>::list_type::const_iterator;
+  using materialized_iterator_type =
+      typename detail::MaterializedIListRef<T>::const_iterator;
+
+  IListRefIterator() : tag_(IListRefTag::None) {}
+
+#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL != 0
+  // See [Note: MSVC Iterator Debug]
+  IListRefIterator(const IListRefIterator& iterator)
+      : tag_(iterator.tag_) {
+    switch (tag_) {
+      case IListRefTag::Boxed:
+        payload_.boxed_iterator = iterator.payload_.boxed_iterator;
+        break;
+      case IListRefTag::Unboxed:
+        payload_.unboxed_iterator = iterator.payload_.unboxed_iterator;
+        break;
+      case IListRefTag::Materialized:
+        payload_.materialized_iterator = iterator.payload_.materialized_iterator;
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
+    }
+  }
+#endif
+
+#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
+  // See [Note: MSVC Iterator Debug]
+  ~IListRefIterator() noexcept(false) {
+    switch (tag_) {
+      case IListRefTag::Boxed:
+        payload_.boxed_iterator.~boxed_iterator_type();
+        break;
+      case IListRefTag::Unboxed:
+        payload_.unboxed_iterator.~unboxed_iterator_type();
+        break;
+      case IListRefTag::Materialized:
+        payload_.materialized_iterator.~materialized_iterator_type();
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");
+    }
+  }
+#endif
+
+  IListRefIterator(boxed_iterator_type boxed) : tag_(IListRefTag::Boxed) {
+    payload_.boxed_iterator = boxed;
+  }
+
+  IListRefIterator(unboxed_iterator_type unboxed) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed_iterator = unboxed;
+  }
+
+  IListRefIterator(materialized_iterator_type materialized) : tag_(IListRefTag::Materialized) {
+    payload_.materialized_iterator = materialized;
+  }
+
+  detail::IListRefConstRef operator*() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::iterator_get(this_); });
+  }
+
+  IListRefIterator& operator++() {
+    TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
+    return *this;
+  }
+
+  IListRefIterator operator++(int) {
+    auto old = *this;
+    TORCH_ILISTREF_UNWRAP(tag_, { ++this_; });
+    return old;
+  }
+
+  IListRefIterator& operator--() {
+    TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
+    return *this;
+  }
+
+  IListRefIterator operator--(int) {
+    auto old = *this;
+    TORCH_ILISTREF_UNWRAP(tag_, { --this_; });
+    return old;
+  }
+
+  bool operator==(const IListRefIterator& rhs) const {
+    if (tag_ != rhs.tag_) {
+      return false;
+    }
+    TORCH_ILISTREF_UNWRAP(tag_, {
+      auto& rhs_it = ImplT::unwrap(rhs);
+      return this_ == rhs_it;
+    });
+  }
+
+  bool operator!=(const IListRefIterator& rhs) const {
+    return !(*this == rhs);
+  }
+
+ private:
+  union Payload {
+    boxed_iterator_type boxed_iterator;
+    unboxed_iterator_type unboxed_iterator;
+    materialized_iterator_type materialized_iterator;
+    void* _init_ptr;
+    Payload() : _init_ptr(nullptr) {}
+#if defined(_MSC_VER)
+    // See [Note: MSVC Iterator Debug]
+    ~Payload() {}
+#endif
+  };
+
+  Payload payload_;
+  IListRefTag tag_;
+};
+
+/*
+ * See [Note: IListRef]
+ */
+template 
+class IListRef {
+ private:
+#define DEFINE_FRIEND_CLASS(TAG, ...)                        \
+  friend class detail::IListRefTagImpl; \
+  friend class detail::IListRefTagImplBase<                  \
+      IListRefTag::TAG,                                      \
+      T,                                                     \
+      typename detail::IListRefTagImpl::elem_type>;
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_FRIEND_CLASS)
+#undef DEFINE_FRIEND_CLASS
+
+ public:
+  using unboxed_type =
+      typename detail::IListRefTagImpl::list_type;
+  using boxed_type =
+      typename detail::IListRefTagImpl::list_type;
+  using materialized_type =
+      typename detail::MaterializedIListRef;
+
+  using iterator = IListRefIterator;
+  using const_iterator = IListRefIterator;
+  using reverse_iterator = std::reverse_iterator;
+  using value_type = typename iterator::value_type;
+
+  IListRef() : tag_(IListRefTag::None) {}
+
+  IListRef(const boxed_type& boxed) : tag_(IListRefTag::Boxed) {
+    payload_.boxed = &boxed;
+  }
+
+  IListRef(const unboxed_type& unboxed) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = unboxed;
+  }
+
+  IListRef(const std::initializer_list& list) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = at::ArrayRef(list);
+  }
+
+  template <
+      typename... UnboxedConstructorArgs,
+      typename = std::enable_if_t<
+          std::is_constructible::value>>
+  IListRef(UnboxedConstructorArgs&&... args) : tag_(IListRefTag::Unboxed) {
+    payload_.unboxed = unboxed_type(std::forward(args)...);
+  }
+
+  IListRef(const materialized_type& materialized) : tag_(IListRefTag::Materialized) {
+    payload_.materialized = &materialized;
+  }
+
+  size_t size() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.size(); });
+  }
+
+  bool empty() const {
+    return size() == 0;
+  }
+
+  iterator begin() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.begin(); });
+  }
+
+  iterator end() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return this_.end(); });
+  }
+
+  detail::IListRefConstRef front() const {
+    TORCH_ILISTREF_UNWRAP(tag_, { return ImplT::front(this_); });
+  }
+
+  /*
+   * Materializes the `IListRef` into a `std::vector`.
+   *
+   * This should be used when one wishes to either:
+   *
+   *   - iterate over the list more than once: each `IListRefIterator`
+   *     member function call has to go through a switch, introducing
+   *     non-negligible overhead
+   *
+   *   - randomly access an arbitrary element using `operator[]`:
+   *     same reason as above
+   */
+  detail::MaterializedIListRef materialize() const {
+    if (isMaterialized()) {
+      return toMaterialized();
+    }
+
+    detail::MaterializedIListRef materialized;
+    materialized.reserve(size());
+    for (const auto& t : *this) {
+      materialized.emplace_back(t);
+    }
+    return materialized;
+  }
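+
+  /*
+   * A minimal usage sketch of `materialize()`, assuming a hypothetical
+   * `two_pass_sum` helper that needs to traverse the same list twice:
+   *
+   * ```
+   *  int64_t two_pass_sum(at::ITensorListRef tensors) {
+   *    // Materialize once, so the per-element tag switch inside
+   *    // `IListRefIterator` is not paid again on the second traversal.
+   *    auto materialized = tensors.materialize();
+   *    int64_t total = 0;
+   *    for (const at::Tensor& t : materialized) total += t.numel();
+   *    for (const at::Tensor& t : materialized) total += t.dim();
+   *    return total;
+   *  }
+   * ```
+   */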
+
+#define DEFINE_CHECK(TAG, ...)    \
+  bool is##TAG() const {          \
+    return tag_ == IListRefTag::TAG; \
+  }
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_CHECK);
+#undef DEFINE_CHECK
+
+  bool isNone() const {
+    return tag_ == IListRefTag::None;
+  }
+
+#define DEFINE_CASTING(TAG, ...)                                          \
+  const typename detail::IListRefTagImpl::list_type& \
+      to##TAG() const {                                                   \
+    TORCH_INTERNAL_ASSERT(is##TAG());                                     \
+    return detail::IListRefTagImpl::unwrap(*this);   \
+  }
+  TORCH_ILISTREF_FORALL_TAGS(DEFINE_CASTING);
+#undef DEFINE_CASTING
+
+ private:
+  union Payload {
+    const boxed_type* boxed;
+    unboxed_type unboxed;
+    const materialized_type* materialized;
+    Payload() : boxed(nullptr) {}
+  };
+
+  Payload payload_;
+  IListRefTag tag_;
+};
+
+} // namespace c10
+
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bee60a6bafedf55bb24704040dd7e4db2ce9b795
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/IListRef_inl.h
@@ -0,0 +1,201 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class Tensor;
+class OptionalTensorRef;
+}
+
+namespace c10 {
+namespace detail {
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Unboxed`.
+ */
+template 
+class IListRefTagImplBase {
+ public:
+  using elem_type = ListElemT;
+  using list_type = ArrayRef;
+
+  /*
+   * These `unwrap` static methods unwrap the inner containers out
+   * of `IListRef` (and `IListRefIterator`). They are required when
+   * the macro `TORCH_ILISTREF_UNWRAP` is called.
+   */
+  static const list_type& unwrap(const IListRef& ilist) {
+    return ilist.payload_.unboxed;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator& it) {
+    return it.payload_.unboxed_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator& it) {
+    return it.payload_.unboxed_iterator;
+  }
+
+  /*
+   * We have these functions (besides the `unwrap`s above) because the
+   * implementations of both `IListRef::operator[]` and `IListRefIterator::operator*`
+   * weren't syntactically equal for the existing tags at the time
+   * (`Unboxed` and `Boxed`).
+   */
+  static IListRefConstRef front(const list_type& lst) {
+    return lst.front();
+  }
+
+  static IListRefConstRef iterator_get(
+      const typename list_type::const_iterator& it) {
+    return *it;
+  }
+};
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Boxed`.
+ */
+template 
+class IListRefTagImplBase {
+ public:
+  using elem_type = ListElemT;
+  using list_type = List;
+
+  static const list_type& unwrap(const IListRef& ilist) {
+    return *ilist.payload_.boxed;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator& it) {
+    return it.payload_.boxed_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator& it) {
+    return it.payload_.boxed_iterator;
+  }
+
+  static IListRefConstRef front(const list_type& lst) {
+    return lst[0];
+  }
+
+  static IListRefConstRef iterator_get(
+      const typename list_type::const_iterator& it) {
+    return (*it).get().toTensor();
+  }
+};
+
+/*
+ * Specializations of `IListRefTagImplBase` that implement the default
+ * implementation for `IListRefTag::Materialized`.
+ */
+template 
+class IListRefTagImplBase> {
+ public:
+  using elem_type = MaterializedIListRefElem;
+  using list_type = MaterializedIListRef;
+
+  static const list_type& unwrap(const IListRef& ilist) {
+    return *ilist.payload_.materialized;
+  }
+
+  static typename list_type::const_iterator& unwrap(IListRefIterator& it) {
+    return it.payload_.materialized_iterator;
+  }
+
+  static const typename list_type::const_iterator& unwrap(
+      const IListRefIterator& it) {
+    return it.payload_.materialized_iterator;
+  }
+
+  static IListRefConstRef front(const list_type& lst) {
+    return lst[0];
+  }
+
+  static IListRefConstRef iterator_get(
+      const typename list_type::const_iterator& it) {
+    return *it;
+  }
+};
+
+/*
+ * [Note: ITensorListRef]
+ * Specializations necessary for the `IListRef<at::Tensor>` type.
+ *
+ * Since the default implementations are usually done with supporting
+ * `Tensor` in mind, we only have to inherit from the base implementations.
+ */
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase {};
+
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase {};
+
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase<
+          IListRefTag::Materialized,
+          at::Tensor,
+          MaterializedIListRefElem> {};
+
+/*
+ * [Note: IOptTensorListRef]
+ * Specializations necessary for the `IListRef<at::OptionalTensorRef>` type.
+ *
+ * We can't get an `at::OptionalTensorRef` directly from an instance of
+ * `List<optional<Tensor>>` (the type that corresponds to the boxed world).
+ *
+ * So, the default implementation won't help us. Thus, we have to implement
+ * this method ourselves.
+ */
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase {};
+
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase> {
+
+ public:
+  /*
+   * Given an instance of the types corresponding to the `Boxed` tag, we override
+   * the default implementation, so that we can return a `at::OptionalTensorRef`.
+   */
+  static IListRefConstRef iterator_get(
+      const typename list_type::const_iterator& it) {
+    const auto& ivalue = (*it).get();
+    if (!ivalue.isNone()) {
+        const auto& tensor = ivalue.toTensor();
+        return (tensor.defined()) ? tensor : at::OptionalTensorRef{};
+    }
+    return {};
+  }
+};
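+
+/*
+ * A small sketch of how the specialization above is observed from the
+ * caller's side, assuming a hypothetical `count_defined` helper:
+ *
+ * ```
+ *  size_t count_defined(at::IOptTensorListRef tensors) {
+ *    size_t n = 0;
+ *    // Dereferencing the iterator yields an `at::OptionalTensorRef`,
+ *    // which is empty for None / undefined entries.
+ *    for (const auto& opt : tensors) {
+ *      if (opt.has_value()) ++n;
+ *    }
+ *    return n;
+ *  }
+ * ```
+ */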
+
+template <>
+class IListRefTagImpl
+    : public IListRefTagImplBase<
+          IListRefTag::Materialized,
+          at::OptionalTensorRef,
+          MaterializedIListRefElem> {};
+
+} // namespace detail
+} // namespace c10
+
+namespace at {
+
+// [Note: ITensorListRef]
+using ITensorListRef = c10::IListRef<at::Tensor>;
+using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
+using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
+// [Note: IOptTensorListRef]
+using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
+using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
+using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h
new file mode 100644
index 0000000000000000000000000000000000000000..1244d0fda87ffd3846c6c6352ed46a7de45b5678
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/LegacyTypeDispatch.h
@@ -0,0 +1,111 @@
+#pragma once
+
+// The legacy mechanism for dispatching operators in ATen is a Type
+// object, which is essentially a giant virtual dispatch table
+// for every operation we support dynamically dispatching over.
+//
+// This has been deprecated in favor of ATenDispatch, and in the future,
+// c10 dispatcher.
+// TODO: Clean up what remains here
+
+#include 
+
+namespace at {
+
+// A RAII, thread local (!) guard that will disable dispatch to variable
+// handler.
+//
+// NOTE [ Treating Variables as non-Variables in type dispatch ]
+//
+// What exactly does AutoDispatchBelowAutograd do?  The short answer is, it causes
+// dispatches on ATen functions to go to the non-variable implementation,
+// bypassing autograd handling (and also profiling and tracing).
+//
+// To understand why this guard exists, it's helpful to understand the history
+// behind how Variable was implemented.  Previously, Variables were implemented
+// as a wrapper on Tensors; so the act of processing a Variable involved
+// unwrapping the underlying Tensor, and then calling the underlying base
+// operation on /that/ unwrapped Tensor.
+//
+// However, after the Variable/Tensor merge, there is no concept of unwrapping
+// a tensor anymore.  If you just call the operation on the same variable
+// again inside your VariableType handler, you'll dispatch back to
+// VariableType, which is not what we want.
+//
+// The solution to the above problem is to add `at::AutoDispatchBelowAutograd`, which
+// when enabled will cause `legacyTensorType()` and `getType()` to always return
+// non-Variable type, even if the tensor being called on is a variable.
+
+/* Note [AutoDispatchBelowAutograd]
+ * AutoDispatchBelowAutograd is **INTERNAL ONLY**; it should only be used
+ * for kernel implementations and customized C++ kernels.
+ * If you are looking for a guard to run a workload in inference mode, please use
+ * the c10::InferenceMode RAII guard, which is the user-facing API.
+ * In the past, AutoDispatchBelowAutograd (or its older version, AutoNonVariableTypeMode)
+ * was used in user code for inference-only workloads, but this risked silently
+ * producing wrong results in some edge cases. For example:
+ * ```
+ *  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
+ *  torch::Tensor out = s * s;
+ *  {
+ *    at::AutoDispatchBelowAutograd guard;
+ *    s.add_(1);  // Skips version bump on `s`.
+ *  }
+ *  // WRONG GRADIENT! s.grad() are now computed using `s` value after the
+ *  // inplace update.
+ *  out.backward(torch::ones_like(out));
+ * ```
+ * Users should use `c10::InferenceMode` here so that it'll properly throw an
+ * error saying "one of the variables needed for gradient computation has been modified."
+ */
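+
+/*
+ * A sketch of the recommended replacement for the user-code pattern above,
+ * using the user-facing guard instead of AutoDispatchBelowAutograd:
+ * ```
+ *  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
+ *  torch::Tensor out = s * s;
+ *  {
+ *    c10::InferenceMode guard;
+ *    s.add_(1);  // The version bump on `s` is not skipped here.
+ *  }
+ *  // Throws instead of silently computing a wrong gradient:
+ *  // "one of the variables needed for gradient computation has been modified"
+ *  out.backward(torch::ones_like(out));
+ * ```
+ */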
+struct TORCH_API AutoDispatchBelowAutograd {
+  AutoDispatchBelowAutograd() :
+    autograd_guard_(c10::autograd_dispatch_keyset) {
+  }
+
+  // disable all autograd dispatch keys
+  c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
+};
+
+// TODO: AutoNonVariableTypeMode should be removed in release 1.10.
+struct TORCH_API AutoNonVariableTypeMode {
+  AutoNonVariableTypeMode(bool enabled = true) :
+    autograd_guard_(c10::autograd_dispatch_keyset) {
+    TORCH_WARN_ONCE("AutoNonVariableTypeMode is deprecated and will be removed in 1.10 release. "
+        "For kernel implementations please use AutoDispatchBelowADInplaceOrView instead, "
+        "If you are looking for a user facing API to enable running your inference-only "
+        "workload, please use c10::InferenceMode. Using AutoDispatchBelowADInplaceOrView in user code "
+        "is under risk of producing silent wrong result in some edge cases. "
+        "See Note [AutoDispatchBelowAutograd] for more details.");
+    TORCH_INTERNAL_ASSERT(enabled);
+  }
+
+  // disable all autograd dispatch keys
+  c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
+};
+
+struct TORCH_API AutoDispatchSkipFunctionalize {
+  AutoDispatchSkipFunctionalize() :
+    dispatch_key_guard_(c10::DispatchKeySet(c10::DispatchKey::Functionalize)) {
+  }
+  c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
+};
+
+/* Note [AutoDispatchBelowADInplaceOrView]
+ * AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
+ * before we split inplace & view ops out of VariableType kernel.
+ * Note this guard is used in VariableType kernels for functional ops
+ * as well as ADInplaceOrView kernels for inplace/view ops to enforce the
+ * Invariant:
+ *   Once you are in VariableType/ADInplaceOrView kernel for an op,
+ *   you never go back to a kernel on same dispatch key until
+ *   you finish the current op.
+ */
+struct TORCH_API AutoDispatchBelowADInplaceOrView {
+  AutoDispatchBelowADInplaceOrView() :
+    dispatch_key_guard_(c10::autograd_dispatch_keyset_with_ADInplaceOrView) {
+  }
+  // disable Autograd & ADInplaceOrView dispatch keys
+  c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
+};
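+
+/*
+ * A sketch of how a kernel might use this guard, assuming a hypothetical
+ * functional op `my_op` whose autograd kernel redispatches to a backend
+ * implementation `my_op_impl`:
+ * ```
+ *  at::Tensor my_op_autograd(const at::Tensor& self) {
+ *    // ... set up whatever autograd metadata is needed ...
+ *    at::AutoDispatchBelowADInplaceOrView guard;
+ *    // The redispatch below skips the Autograd and ADInplaceOrView keys,
+ *    // upholding the invariant described in the note above.
+ *    return my_op_impl(self);
+ *  }
+ * ```
+ */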
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/List.h b/MLPY/Lib/site-packages/torch/include/ATen/core/List.h
new file mode 100644
index 0000000000000000000000000000000000000000..30316e388457d1cb1148127dd90c78b860589478
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/List.h
@@ -0,0 +1,490 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+class Tensor;
+}
+namespace c10 {
+struct IValue;
+template class List;
+struct Type;
+
+namespace detail {
+
+struct ListImpl final : public c10::intrusive_ptr_target {
+  using list_type = std::vector;
+
+  explicit TORCH_API ListImpl(list_type list_, TypePtr elementType_);
+
+  list_type list;
+
+  TypePtr elementType;
+
+  intrusive_ptr copy() const {
+    return make_intrusive(list, elementType);
+  }
+  friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
+};
+}
+
+namespace impl {
+
+template class ListIterator;
+
+template class ListElementReference;
+
+template
+void swap(ListElementReference&& lhs, ListElementReference&& rhs);
+
+template
+bool operator==(const ListElementReference& lhs, const T& rhs);
+
+template
+bool operator==(const T& lhs, const ListElementReference& rhs);
+
+template
+struct ListElementConstReferenceTraits {
+  // In the general case, we use IValue::to().
+  using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return::type;
+};
+
+// There is no to() overload for c10::optional.
+template<>
+struct ListElementConstReferenceTraits> {
+  using const_reference = c10::optional>;
+};
+
+template
+class ListElementReference final {
+public:
+  operator std::conditional_t<
+      std::is_reference::type>::value,
+      const T&,
+      T>() const;
+
+  ListElementReference& operator=(T&& new_value) &&;
+
+  ListElementReference& operator=(const T& new_value) &&;
+
+  // assigning another ref to this assigns the underlying value
+  ListElementReference& operator=(ListElementReference&& rhs) && noexcept;
+
+  const IValue& get() const& {
+    return *iterator_;
+  }
+
+  friend void swap(ListElementReference&& lhs, ListElementReference&& rhs);
+
+  ListElementReference(const ListElementReference&) = delete;
+  ListElementReference& operator=(const ListElementReference&) = delete;
+
+private:
+  ListElementReference(Iterator iter)
+  : iterator_(iter) {}
+
+  // allow moving, but only our friends (i.e. the List class) can move us
+  ListElementReference(ListElementReference&&) noexcept = default;
+  ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
+    iterator_ = std::move(rhs.iterator_);
+    return *this;
+  }
+
+  friend class List;
+  friend class ListIterator;
+
+  Iterator iterator_;
+};
+
+// this wraps vector::iterator to make sure user code can't rely
+// on it being the type of the underlying vector.
+template 
+class ListIterator final {
+ public:
+   // C++17 friendly std::iterator implementation
+  using iterator_category = std::random_access_iterator_tag;
+  using value_type = T;
+  using difference_type = std::ptrdiff_t;
+  using pointer = T*;
+  using reference = ListElementReference;
+
+  explicit ListIterator() = default;
+  ~ListIterator() = default;
+
+  ListIterator(const ListIterator&) = default;
+  ListIterator(ListIterator&&) noexcept = default;
+  ListIterator& operator=(const ListIterator&) = default;
+  ListIterator& operator=(ListIterator&&) noexcept = default;
+
+  ListIterator& operator++() {
+      ++iterator_;
+      return *this;
+  }
+
+  ListIterator operator++(int) {
+      ListIterator copy(*this);
+      ++*this;
+      return copy;
+  }
+
+  ListIterator& operator--() {
+      --iterator_;
+      return *this;
+  }
+
+  ListIterator operator--(int) {
+      ListIterator copy(*this);
+      --*this;
+      return copy;
+  }
+
+  ListIterator& operator+=(typename List::size_type offset) {
+      iterator_ += offset;
+      return *this;
+  }
+
+  ListIterator& operator-=(typename List::size_type offset) {
+      iterator_ -= offset;
+      return *this;
+  }
+
+  ListIterator operator+(typename List::size_type offset) const {
+    return ListIterator{iterator_ + offset};
+  }
+
+  ListIterator operator-(typename List::size_type offset) const {
+    return ListIterator{iterator_ - offset};
+  }
+
+  friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ - rhs.iterator_;
+  }
+
+  ListElementReference operator*() const {
+    return {iterator_};
+  }
+
+  ListElementReference operator[](typename List::size_type offset) const {
+    return {iterator_ + offset};
+  }
+
+private:
+  explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
+
+  Iterator iterator_;
+
+  friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ == rhs.iterator_;
+  }
+
+  friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
+    return !(lhs == rhs);
+  }
+
+  friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ < rhs.iterator_;
+  }
+
+  friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ <= rhs.iterator_;
+  }
+
+  friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ > rhs.iterator_;
+  }
+
+  friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
+    return lhs.iterator_ >= rhs.iterator_;
+  }
+
+  friend class ListIterator;
+  friend class List;
+};
+
+template <class T> List<T> toTypedList(List<IValue> list);
+template <class T> List<IValue> toList(List<T>&& list);
+template <class T> List<IValue> toList(const List<T>& list);
+const IValue* ptr_to_first_element(const List<IValue>& list);
+}
+
+/**
+ * An object of this class stores a list of values of type T.
+ *
+ * This is a pointer type. After a copy, both Lists
+ * will share the same storage:
+ *
+ * > List<string> a;
+ * > List<string> b = a;
+ * > b.push_back("three");
+ * > ASSERT("three" == a.get(0));
+ *
+ * We use this class in the PyTorch kernel API instead of
+ * std::vector<T>, because that allows us to do optimizations
+ * and switch out the underlying list implementation without
+ * breaking backwards compatibility for the kernel API.
+ */
+template
+class List final {
+private:
+  // This is an intrusive_ptr because List is a pointer type.
+  // Invariant: This will never be a nullptr, there will always be a valid
+  // ListImpl.
+  c10::intrusive_ptr impl_;
+
+  using internal_reference_type = impl::ListElementReference;
+  using internal_const_reference_type = typename impl::ListElementConstReferenceTraits::const_reference;
+
+public:
+  using value_type = T;
+  using size_type = typename c10::detail::ListImpl::list_type::size_type;
+  using iterator = impl::ListIterator;
+  using const_iterator = impl::ListIterator;
+  using reverse_iterator = impl::ListIterator;
+
+  /**
+   * Constructs an empty list.
+   */
+  explicit List();
+
+  /**
+   * Constructs a list with some initial values.
+   * Example:
+   *   List a({2, 3, 4});
+   */
+  List(std::initializer_list initial_values);
+  explicit List(ArrayRef initial_values);
+
+  /**
+   * Create a generic list with runtime type information.
+   * This only works for c10::impl::GenericList and is not part of the public API
+   * but only supposed to be used internally by PyTorch.
+   */
+  explicit List(TypePtr elementType);
+
+  List(const List&) = default;
+  List& operator=(const List&) = default;
+
+  /**
+   * Create a new List pointing to a deep copy of the same data.
+   * The List returned is a new list with separate storage.
+   * Changes in it are not reflected in the original list or vice versa.
+   */
+  List copy() const;
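+
+  /*
+   * A short sketch contrasting the shared-storage semantics described above
+   * with `copy()`, assuming `int64_t` elements:
+   *
+   * > List<int64_t> a({1, 2, 3});
+   * > List<int64_t> b = a;        // shares storage with `a`
+   * > List<int64_t> c = a.copy(); // separate storage
+   * > b.push_back(4);             // visible through `a` as well
+   * > c.push_back(5);             // not visible through `a`
+   */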
+
+  /**
+   * Returns the element at specified location pos, with bounds checking.
+   * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
+   */
+  internal_const_reference_type get(size_type pos) const;
+
+  /**
+   * Moves out the element at the specified location pos and returns it, with bounds checking.
+   * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
+   * The list contains an invalid element at position pos afterwards. Any operations
+   * on it before re-setting it are invalid.
+   */
+  value_type extract(size_type pos) const;
+
+  /**
+   * Returns a reference to the element at specified location pos, with bounds checking.
+   * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
+   *
+   * You cannot store the reference, but you can read it and assign new values to it:
+   *
+   *   List list = ...;
+   *   list[2] = 5;
+   *   int64_t v = list[1];
+   */
+  internal_const_reference_type operator[](size_type pos) const;
+
+  internal_reference_type operator[](size_type pos);
+
+  /**
+   * Assigns a new value to the element at location pos.
+   */
+  void set(size_type pos, const value_type& value) const;
+
+  /**
+   * Assigns a new value to the element at location pos.
+   */
+  void set(size_type pos, value_type&& value) const;
+
+  /**
+   * Returns an iterator to the first element of the container.
+   * If the container is empty, the returned iterator will be equal to end().
+   */
+  iterator begin() const;
+
+  /**
+   * Returns an iterator to the element following the last element of the container.
+   * This element acts as a placeholder; attempting to access it results in undefined behavior.
+   */
+  iterator end() const;
+
+  /**
+   * Checks if the container has no elements.
+   */
+  bool empty() const;
+
+  /**
+   * Returns the number of elements in the container
+   */
+  size_type size() const;
+
+  /**
+   * Increase the capacity of the vector to a value that's greater or equal to new_cap.
+   */
+  void reserve(size_type new_cap) const;
+
+  /**
+   * Erases all elements from the container. After this call, size() returns zero.
+   * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
+   */
+  void clear() const;
+
+  /**
+   * Inserts value before pos.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  iterator insert(iterator pos, const T& value) const;
+
+  /**
+   * Inserts value before pos.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  iterator insert(iterator pos, T&& value) const;
+
+  /**
+   * Inserts a new element into the container directly before pos.
+   * The new element is constructed with the given arguments.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  template
+  iterator emplace(iterator pos, Args&&... value) const;
+
+  /**
+   * Appends the given element value to the end of the container.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void push_back(const T& value) const;
+
+  /**
+   * Appends the given element value to the end of the container.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void push_back(T&& value) const;
+
+  /**
+   * Appends the given list to the end of the container. Uses at most one memory allocation.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void append(List lst) const;
+
+  /**
+   * Appends the given element value to the end of the container.
+   * The new element is constructed with the given arguments.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  template
+  void emplace_back(Args&&... args) const;
+
+  /**
+   * Removes the element at pos.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  iterator erase(iterator pos) const;
+
+  /**
+   * Removes the elements in the range [first, last).
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  iterator erase(iterator first, iterator last) const;
+
+  /**
+   * Removes the last element of the container.
+   * Calling pop_back on an empty container is undefined.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void pop_back() const;
+
+  /**
+   * Resizes the container to contain count elements.
+   * If the current size is less than count, additional default-inserted elements are appended.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void resize(size_type count) const;
+
+  /**
+   * Resizes the container to contain count elements.
+   * If the current size is less than count, additional copies of value are appended.
+   * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
+   */
+  void resize(size_type count, const T& value) const;
+
+  /**
+   * Value equality comparison. This function implements Python-like semantics for
+   * equality: two lists with the same identity (e.g. same pointer) trivially
+   * compare equal, otherwise each element is compared for equality.
+   */
+  template 
+  friend bool operator==(const List& lhs, const List& rhs);
+
+  template 
+  friend bool operator!=(const List& lhs, const List& rhs);
+
+  /**
+   * Identity comparison. Returns true if and only if `rhs` represents the same
+   * List object as `this`.
+   */
+  bool is(const List& rhs) const;
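+
+  /*
+   * A small sketch of the difference between value equality and identity,
+   * assuming `int64_t` elements:
+   *
+   * > List<int64_t> a({1, 2});
+   * > List<int64_t> b({1, 2});
+   * > bool same_values   = (a == b);  // true: elements compare equal
+   * > bool same_identity = a.is(b);   // false: different underlying storage
+   */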
+
+  std::vector vec() const;
+
+  /**
+   * Returns the number of Lists currently pointing to this same list.
+   * If this is the only instance pointing to this list, returns 1.
+   */
+  // TODO Test use_count
+  size_t use_count() const;
+
+  TypePtr elementType() const;
+
+  // See [unsafe set type] for why this exists.
+  void unsafeSetElementType(TypePtr t);
+
+private:
+  explicit List(c10::intrusive_ptr&& elements);
+  explicit List(const c10::intrusive_ptr& elements);
+  friend struct IValue;
+  template friend List impl::toTypedList(List);
+  template friend List impl::toList(List&&);
+  template friend List impl::toList(const List&);
+  friend const IValue* impl::ptr_to_first_element(const List& list);
+};
+
+namespace impl {
+// GenericList is how IValue stores lists. It is, however, not part of the
+// public API. Kernels should use Lists with concrete types instead
+// (maybe except for some internal prim ops).
+using GenericList = List<IValue>;
+
+const IValue* ptr_to_first_element(const GenericList& list);
+
+}
+}
+
+namespace torch {
+  template <class T> using List = c10::List<T>;
+}
+
+#include   // IWYU pragma: keep
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/List_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/List_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..3866f77ad866551432b08b975bb30382edbf8ef5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/List_inl.h
@@ -0,0 +1,360 @@
+#pragma once
+
+#include 
+#include 
+
+namespace c10 {
+
+template decltype(auto) getTypePtr();
+std::string toString(const Type& type);
+
+template
+List::List(c10::intrusive_ptr&& elements)
+: impl_(std::move(elements)) {}
+
+template
+List::List(const c10::intrusive_ptr& elements)
+: impl_(elements) {}
+
+template
+List::List()
+: List(make_intrusive(
+  typename c10::detail::ListImpl::list_type(),
+  getTypePtr())) {
+  static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead.");
+}
+
+template
+List::List(ArrayRef values)
+: List(make_intrusive(
+    typename c10::detail::ListImpl::list_type(),
+    getTypePtr())) {
+  static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType).");
+  impl_->list.reserve(values.size());
+  for (const T& element : values) {
+    impl_->list.push_back(element);
+  }
+}
+
+template
+List::List(std::initializer_list initial_values)
+: List(ArrayRef(initial_values)) {
+  static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType).");
+}
+
+template
+List::List(TypePtr elementType)
+: List(make_intrusive(
+    typename c10::detail::ListImpl::list_type(),
+    std::move(elementType))) {
+  static_assert(std::is_same::value || std::is_same>::value,
+                "This constructor is only valid for c10::impl::GenericList or List.");
+}
+
+namespace impl {
+template
+List toTypedList(impl::GenericList list) {
+  // If there are other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
+  // because upcasting would allow people to add types into the new list that would break the old list.
+  // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
+  // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
+  // without having to copy it. This is also used to provide backwards compatibility with some old models
+  // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
+  // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
+  // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
+  TORCH_CHECK(*list.impl_->elementType == *getTypePtr()
+    || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr()))
+    , "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(*getTypePtr()), ">. Types mismatch.");
+  return List(std::move(list.impl_));
+}
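+
+/*
+ * A sketch of the covariant cast that the check above permits, assuming a
+ * generic list of Tensors with a single owner (use_count() == 1), e.g. one
+ * freshly produced by deserialization:
+ *
+ * > c10::impl::GenericList generic(c10::TensorType::get());
+ * > // Sole owner, so upcasting the element type cannot break another list:
+ * > auto typed =
+ * >     c10::impl::toTypedList<c10::optional<at::Tensor>>(std::move(generic));
+ */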
+
+template
+impl::GenericList toList(List&& list) {
+  return GenericList(std::move(list.impl_));
+}
+template
+impl::GenericList toList(const List& list) {
+  return GenericList(list.impl_);
+}
+}
+
+template
+List List::copy() const {
+  return List(impl_->copy());
+}
+
+namespace detail {
+  template
+  T list_element_to(T element) {
+    return element;
+  }
+  template
+  T list_element_to(const IValue& element) {
+    return element.template to();
+  }
+  template
+  T list_element_to(IValue&& element) {
+    return std::move(element).template to();
+  }
+  template
+  struct ListElementFrom {
+    static IValue from(const T& element) {
+      return element;
+    }
+    static IValue from(T&& element) {
+      return std::move(element);
+    }
+  };
+  template<>
+  struct ListElementFrom {
+    static const IValue& from(const IValue& element) {
+      return element;
+    }
+    static IValue&& from(IValue&& element) {
+      return std::move(element);
+    }
+  };
+}
+
+namespace impl {
+
+template 
+ListElementReference::operator std::conditional_t<
+    std::is_reference::type>::value,
+    const T&,
+    T>() const {
+  return iterator_->template to();
+}
+
+template
+ListElementReference& ListElementReference::operator=(T&& new_value) && {
+  *iterator_ = c10::detail::ListElementFrom::from(std::move(new_value));
+  return *this;
+}
+
+template
+ListElementReference& ListElementReference::operator=(const T& new_value) && {
+  *iterator_ = c10::detail::ListElementFrom::from(new_value);
+  return *this;
+}
+
+template
+ListElementReference& ListElementReference::operator=(ListElementReference&& rhs) && noexcept {
+  *iterator_ = *rhs.iterator_;
+  return *this;
+}
+
+template
+void swap(ListElementReference&& lhs, ListElementReference&& rhs) {
+  std::swap(*lhs.iterator_, *rhs.iterator_);
+}
+
+template
+bool operator==(const ListElementReference& lhs, const T& rhs) {
+  const T& lhs_tmp = lhs;
+  return lhs_tmp == rhs;
+}
+
+template
+inline bool operator==(const T& lhs, const ListElementReference& rhs) {
+  return rhs == lhs;
+}
+
+template
+inline typename ListElementConstReferenceTraits::const_reference
+list_element_to_const_ref(const IValue& element) {
+  return element.template to();
+}
+
+template<>
+inline typename ListElementConstReferenceTraits>::const_reference
+list_element_to_const_ref>(const IValue& element) {
+  return element.toOptionalStringRef();
+}
+
+} // namespace impl
+
+template
+void List::set(size_type pos, const value_type& value) const {
+  impl_->list.at(pos) = c10::detail::ListElementFrom::from(value);
+}
+
+template
+void List::set(size_type pos, value_type&& value) const {
+  impl_->list.at(pos) = c10::detail::ListElementFrom::from(std::move(value));
+}
+
+template
+typename List::internal_const_reference_type List::get(size_type pos) const {
+  return operator[](pos);
+}
+
+template
+typename List::internal_const_reference_type List::operator[](size_type pos) const {
+  return c10::impl::list_element_to_const_ref(impl_->list.at(pos));
+}
+
+template
+typename List::internal_reference_type List::operator[](size_type pos) {
+  static_cast<void>(impl_->list.at(pos)); // Throw the exception if it is out of range.
+  return {impl_->list.begin() + static_cast<typename decltype(impl_->list)::difference_type>(pos)};
+}
+
+template
+typename List::value_type List::extract(size_type pos) const {
+  auto& elem = impl_->list.at(pos);
+  auto result = c10::detail::list_element_to(std::move(elem));
+  // Reset the list element to a T() instead of None to keep it correctly typed
+  elem = c10::detail::ListElementFrom::from(T{});
+  return result;
+}
+
+template
+typename List::iterator List::begin() const {
+  return iterator(impl_->list.begin());
+}
+
+template
+typename List::iterator List::end() const {
+  return iterator(impl_->list.end());
+}
+
+template
+bool List::empty() const {
+  return impl_->list.empty();
+}
+
+template
+typename List::size_type List::size() const {
+  return impl_->list.size();
+}
+
+template
+void List::reserve(size_type new_cap) const {
+  impl_->list.reserve(new_cap);
+}
+
+template
+void List::clear() const {
+  impl_->list.clear();
+}
+
+template
+typename List::iterator List::insert(iterator pos, const T& value) const {
+  return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(value)) };
+}
+
+template
+typename List::iterator List::insert(iterator pos, T&& value) const {
+  return iterator { impl_->list.insert(pos.iterator_, c10::detail::ListElementFrom::from(std::move(value))) };
+}
+
+template
+template
+typename List::iterator List::emplace(iterator pos, Args&&... value) const {
+  // TODO Use list_element_from?
+  return iterator { impl_->list.emplace(pos.iterator_, std::forward(value)...) };
+}
+
+template
+void List::push_back(const T& value) const {
+  impl_->list.push_back(c10::detail::ListElementFrom::from(value));
+}
+
+template
+void List::push_back(T&& value) const {
+  impl_->list.push_back(c10::detail::ListElementFrom::from(std::move(value)));
+}
+
+template
+void List::append(List b) const {
+  if (b.use_count() == 1) {
+    impl_->list.insert(impl_->list.end(), make_move_iterator(b.impl_->list.begin()), make_move_iterator(b.impl_->list.end()));
+  } else {
+    impl_->list.insert(impl_->list.end(), b.impl_->list.begin(), b.impl_->list.end());
+  }
+}
+
+template
+template
+void List::emplace_back(Args&&... args) const {
+  // TODO Use list_element_from?
+  impl_->list.push_back(T(std::forward(args)...));
+}
+
+template
+typename List::iterator List::erase(iterator pos) const {
+  return iterator { impl_->list.erase(pos.iterator_) };
+}
+
+template
+typename List::iterator List::erase(iterator first, iterator last) const {
+  return iterator { impl_->list.erase(first.iterator_, last.iterator_) };
+}
+
+template
+void List::pop_back() const {
+  impl_->list.pop_back();
+}
+
+template
+void List::resize(size_type count) const {
+  impl_->list.resize(count, T{});
+}
+
+template
+void List::resize(size_type count, const T& value) const {
+  impl_->list.resize(count, value);
+}
+
+template
+bool operator==(const List& lhs, const List& rhs) {
+  // Lists with the same identity trivially compare equal.
+  if (lhs.impl_ == rhs.impl_) {
+    return true;
+  }
+
+  // Otherwise, just compare values directly.
+  return *lhs.impl_ == *rhs.impl_;
+}
+
+template
+bool operator!=(const List& lhs, const List& rhs) {
+  return !(lhs == rhs);
+}
+
+template
+bool List::is(const List& rhs) const {
+  return this->impl_ == rhs.impl_;
+}
+
+template
+std::vector List::vec() const {
+  std::vector result(begin(), end());
+  return result;
+}
+
+template
+size_t List::use_count() const {
+  return impl_.use_count();
+}
+
+template 
+TypePtr List::elementType() const {
+  return impl_->elementType;
+}
+
+template 
+void List::unsafeSetElementType(TypePtr t) {
+  impl_->elementType = std::move(t);
+}
+
+namespace impl {
+
+inline const IValue* ptr_to_first_element(const GenericList& list) {
+  return &list.impl_->list[0];
+}
+
+}
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h b/MLPY/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h
new file mode 100644
index 0000000000000000000000000000000000000000..7aaebf8289e5c3ce80411846ea8f24ca29c3f620
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/MT19937RNGEngine.h
@@ -0,0 +1,194 @@
+#pragma once
+
+#include 
+
+// define constants like M_PI and C keywords for MSVC
+#ifdef _MSC_VER
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES
+#endif
+#include 
+#endif
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+constexpr int MERSENNE_STATE_N = 624;
+constexpr int MERSENNE_STATE_M = 397;
+constexpr uint32_t MATRIX_A = 0x9908b0df;
+constexpr uint32_t UMASK = 0x80000000;
+constexpr uint32_t LMASK = 0x7fffffff;
+
+/**
+ * Note [Mt19937 Engine implementation]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Originally implemented in:
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/MT2002/CODES/MTARCOK/mt19937ar-cok.c
+ * and modified with C++ constructs. Moreover the state array of the engine
+ * has been modified to hold 32 bit uints instead of 64 bits.
+ *
+ * Note that we reimplemented mt19937 instead of using std::mt19937 because,
+ * at::mt19937 turns out to be faster in the pytorch codebase. PyTorch builds with -O2
+ * by default and following are the benchmark numbers (benchmark code can be found at
+ * https://github.com/syed-ahmed/benchmark-rngs):
+ *
+ * with -O2
+ * Time to get 100000000 philox randoms with at::uniform_real_distribution = 0.462759s
+ * Time to get 100000000 at::mt19937 randoms with at::uniform_real_distribution = 0.39628s
+ * Time to get 100000000 std::mt19937 randoms with std::uniform_real_distribution = 0.352087s
+ * Time to get 100000000 std::mt19937 randoms with at::uniform_real_distribution = 0.419454s
+ *
+ * std::mt19937 is faster when used in conjunction with std::uniform_real_distribution,
+ * however we can't use std::uniform_real_distribution because of this bug:
+ * http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524. Plus, even if we used
+ * std::uniform_real_distribution and filtered out the 1's, it is a different algorithm
+ * than what's in pytorch currently and that messes up the tests in tests_distributions.py.
+ * The other option, using std::mt19937 with at::uniform_real_distribution is a tad bit slower
+ * than at::mt19937 with at::uniform_real_distribution and hence, we went with the latter.
+ *
+ * Copyright notice:
+ * A C-program for MT19937, with initialization improved 2002/2/10.
+ * Coded by Takuji Nishimura and Makoto Matsumoto.
+ * This is a faster version by taking Shawn Cokus's optimization,
+ * Matthe Bellew's simplification, Isaku Wada's real version.
+ *
+ * Before using, initialize the state by using init_genrand(seed)
+ * or init_by_array(init_key, key_length).
+ *
+ * Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ *   1. Redistributions of source code must retain the above copyright
+ *   notice, this list of conditions and the following disclaimer.
+ *
+ *   2. Redistributions in binary form must reproduce the above copyright
+ *   notice, this list of conditions and the following disclaimer in the
+ *   documentation and/or other materials provided with the distribution.
+ *
+ *   3. The names of its contributors may not be used to endorse or promote
+ *   products derived from this software without specific prior written
+ *   permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ *
+ * Any feedback is very welcome.
+ * http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html
+ * email: m-mat @ math.sci.hiroshima-u.ac.jp (remove space)
+ */
+
+/**
+ * mt19937_data_pod is used to get POD data in and out
+ * of mt19937_engine. Used in torch.get_rng_state and
+ * torch.set_rng_state functions.
+ */
+struct mt19937_data_pod {
+  uint64_t seed_;
+  int left_;
+  bool seeded_;
+  uint32_t next_;
+  std::array<uint32_t, MERSENNE_STATE_N> state_;
+};
+
+class mt19937_engine {
+public:
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  inline explicit mt19937_engine(uint64_t seed = 5489) {
+    init_with_uint32(seed);
+  }
+
+  inline mt19937_data_pod data() const {
+    return data_;
+  }
+
+  inline void set_data(const mt19937_data_pod& data) {
+    data_ = data;
+  }
+
+  inline uint64_t seed() const {
+    return data_.seed_;
+  }
+
+  inline bool is_valid() {
+    if ((data_.seeded_ == true)
+      && (data_.left_ > 0 && data_.left_ <= MERSENNE_STATE_N)
+      && (data_.next_ <= MERSENNE_STATE_N)) {
+      return true;
+    }
+    return false;
+  }
+
+  inline uint32_t operator()() {
+    if (--(data_.left_) == 0) {
+        next_state();
+    }
+    uint32_t y = *(data_.state_.data() + data_.next_++);
+    y ^= (y >> 11);
+    y ^= (y << 7) & 0x9d2c5680;
+    y ^= (y << 15) & 0xefc60000;
+    y ^= (y >> 18);
+
+    return y;
+  }
+
+private:
+  mt19937_data_pod data_;
+
+  inline void init_with_uint32(uint64_t seed) {
+    data_.seed_ = seed;
+    data_.seeded_ = true;
+    data_.state_[0] = seed & 0xffffffff;
+    for (const auto j : c10::irange(1, MERSENNE_STATE_N)) {
+      data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j);
+    }
+    data_.left_ = 1;
+    data_.next_ = 0;
+  }
+
+  inline uint32_t mix_bits(uint32_t u, uint32_t v) {
+    return (u & UMASK) | (v & LMASK);
+  }
+
+  inline uint32_t twist(uint32_t u, uint32_t v) {
+    return (mix_bits(u,v) >> 1) ^ (v & 1 ? MATRIX_A : 0);
+  }
+
+  inline void next_state() {
+    uint32_t* p = data_.state_.data();
+    data_.left_ = MERSENNE_STATE_N;
+    data_.next_ = 0;
+
+    for(int j = MERSENNE_STATE_N - MERSENNE_STATE_M + 1; --j; p++) {
+      *p = p[MERSENNE_STATE_M] ^ twist(p[0], p[1]);
+    }
+
+    for(int j = MERSENNE_STATE_M; --j; p++) {
+      *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], p[1]);
+    }
+
+    *p = p[MERSENNE_STATE_M - MERSENNE_STATE_N] ^ twist(p[0], data_.state_[0]);
+  }
+
+};
+
+typedef mt19937_engine mt19937;
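+
+/*
+ * A minimal usage sketch of the engine defined above:
+ *
+ * > at::mt19937 gen(5489);  // construct with an explicit seed
+ * > uint32_t r = gen();     // draw one 32-bit random number
+ * >
+ * > // The POD state can be captured and restored, which is how
+ * > // torch.get_rng_state / torch.set_rng_state round-trip the generator.
+ * > at::mt19937_data_pod snapshot = gen.data();
+ * > gen.set_data(snapshot);
+ */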
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/NamedTensor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/NamedTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..55db7ca5aae53673fbe90b47e3c90a3c88784afe
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/NamedTensor.h
@@ -0,0 +1,139 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+class TensorBase;
+
+// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
+// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
+// so we have a couple of workarounds.
+//
+// In the long term, we'll move Dimname to c10 and everything in this file
+// can be refactored out. The main blocker for that is that "c10::Symbol"
+// actually exists outside of c10 and needs to be moved in.
+
+// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
+// XXX: Ideally we would just put optional<vector<Dimname>> into TensorImpl.
+//
+// This class has an important invariant: there must be at least ONE
+// non-wildcard dimname.
+struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
+  // This enum is to remind people that the invariant on constructors is that
+  // the list of dimnames must have at least one non-wildcard
+  enum HAS_NON_WILDCARD {
+    HasNonWildcard
+  };
+
+  explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
+    : names_(names.vec()) {
+    check_invariants();
+  }
+  explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector&& names)
+    : names_(std::move(names)) {
+    check_invariants();
+  }
+
+  std::unique_ptr clone() const override {
+    return std::make_unique(HasNonWildcard, names_);
+  }
+
+  DimnameList names() const { return names_; }
+
+  // Used for an assertion in TensorImpl.h
+  int64_t slow_dim() const override {
+    return names_.size();
+  }
+
+  void check_invariants() const {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
+  }
+
+  void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
+    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
+    std::copy(new_names.begin(), new_names.end(), names_.begin());
+    check_invariants();
+  }
+
+  void set_names(HAS_NON_WILDCARD, std::vector&& new_names) {
+    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
+    names_ = std::move(new_names);
+    check_invariants();
+  }
+
+  // INVARIANT: at least one Dimname is non-WILDCARD
+  std::vector names_;
+};
+
+// When NamesMode is disabled, then all operations ignore tensors' names fields.
+// Concretely speaking, all tensors are treated as having nullopt names.
+struct TORCH_API NamesMode {
+  static bool is_enabled();
+  static void set_enabled(bool enabled);
+};
+
+
+// A RAII, thread local (!) guard that enables or disables names upon
+// construction, and sets it back to the original value upon destruction.
+struct TORCH_API NoNamesGuard {
+  NoNamesGuard() : prev_mode(NamesMode::is_enabled()), initialized(true) {
+    NamesMode::set_enabled(false);
+  }
+  ~NoNamesGuard() {
+    if (initialized) {
+      reset();
+    }
+  }
+  void reset() {
+    TORCH_INTERNAL_ASSERT(initialized);
+    NamesMode::set_enabled(prev_mode);
+  }
+ private:
+  bool prev_mode;
+  bool initialized;
+};
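+
+/*
+ * A small sketch of the guard above, for code that must temporarily ignore
+ * dimension names:
+ *
+ * > {
+ * >   at::NoNamesGuard guard;  // NamesMode disabled from here on
+ * >   // ... ops in this scope behave as if all tensors were unnamed ...
+ * > }                          // previous NamesMode value restored
+ */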
+
+void check_names_valid_for(const TensorBase& tensor, DimnameList names);
+void check_names_valid_for(size_t tensor_dim, DimnameList names);
+
+// Sets the names of `tensor` to be `names`.
+TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, c10::optional names);
+TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names);
+
+constexpr size_t kMaxNamedTensorDim = 64;
+
+DimnameList default_names(size_t len);
+
+namespace impl {
+
+// Some helper functions on TensorImpl. Useful for working with names in TH.
+// XXX: Ideally these would exist as methods on TensorImpl
+TORCH_API void internal_set_names_inplace(TensorImpl* impl, c10::optional names, bool validate_names);
+TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names);
+
+void check_names_valid_for(TensorImpl* impl, DimnameList names);
+
+// Returns true if the tensor's names exist and are not all 'None'.
+// Returns false if the tensor's names don't exist (were not allocated),
+// or if all names are 'None'.
+// We treat not-allocated-names the same as allocated names that are all 'None'.
+TORCH_API bool has_names(const TensorImpl* impl);
+
+// Returns the names of the tensor's dimensions.
+// Unnamed tensors are treated as having 'None' in all dimensions; this method
+// would return a DimnameList of all 'None's for an unnamed tensor.
+TORCH_API DimnameList get_names(const TensorImpl* impl);
+
+// This is more of an implementation detail; one should use impl::get_names /
+// Tensor::names() whenever possible because it provides a cleaner API.
+// Returns the names of the tensor if they have been allocated; returns nullopt
+// instead if they haven't been. The names of a tensor are not allocated if a
+// tensor is constructed with names=None.
+TORCH_API c10::optional get_opt_names(const TensorImpl* impl);
+
+} // namespace impl
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..067e18717dcd51b755e4babb096e1fa3a52ce43c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/NestedIntSymNodeImpl.h
@@ -0,0 +1,186 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+// The motivating usecase for this is to represent the ragged size structure
+// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
+// allows us to simply return [B, j0, D] if someone queries for the size of our
+// tensor.
+//
+// Morally we define comparison between two nested ints to return true if
+// that comparison holds for all corresponding elements of the arrays they
+// represent. Comparison between a nested int and a plain int is defined
+// similarly.
+//
+// To simulate this desired behavior but also avoid the O(N) cost of checking,
+// we associate each raggedness pattern with an integer "id" that can be used as
+// a proxy to evaluate equality. We also constrain the range of values for this
+// as to enable inequality checks.
+//
+// We also support a positive integer scalar "coeff" that is used for computing
+// strides. For example, given a [B, j0, D] tensor, it can be strided in two
+// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
+// differentiate the two cases.
+//
+// During tracing the strides of the outputs need to be a function of the size
+// and strides of the inputs so it is important that NestedIntSymNode itself is
+// able to express this.
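+//
+// A small sketch of the semantics above (values are illustrative): with
+// j0 = NestedIntSymNodeImpl(/*val=*/0, /*coeff=*/1) standing in for the
+// ragged dimension,
+//
+//   j0.str()  returns "j0"
+//   j0 >= 1   is True  (every nested int is known to be >= 2, see the note
+//             on inequalities below)
+//   j0 == 0   is False
+//   j0 <  3   is indeterminate and errors out, as explained in that note.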
+class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
+ public:
+  // CAUTION: you should probably not be constructing these directly; please
+  // use the higher-level API in Python instead (TODO: actually introduce that).
+  explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
+      : val_(val), coeff_(coeff) {}
+
+  bool bool_() override {
+    return false;
+  }
+
+  bool is_int() override {
+    return true;
+  }
+
+  bool is_float() override {
+    return false;
+  }
+
+  bool is_bool() override {
+    return false;
+  }
+
+  bool is_nested_int() const override {
+    return true;
+  }
+
+  bool has_hint() override {
+    return true;
+  }
+
+  c10::SymNode wrap_int(int64_t num) override {
+    return SymNode(c10::make_intrusive>(num));
+  };
+
+  int64_t guard_int(const char* file, int64_t line) override {
+    TORCH_CHECK(false);
+  }
+
+  double guard_float(const char* file, int64_t line) override {
+    TORCH_CHECK(false, "not a float");
+  }
+
+  bool guard_bool(const char* file, int64_t line) override {
+    TORCH_CHECK(false, "not a bool");
+  }
+
+  int64_t int_() override {
+    TORCH_CHECK(false);
+  }
+
+  std::string str() override {
+    if (coeff_ == 1) {
+      return "j" + std::to_string(val_);
+    }
+    return std::to_string(coeff_) + "*j" + std::to_string(val_);
+  }
+
+  // NOTE [ Inequalities with nested int ]
+  //
+  // The semantics of nested int when it comes to relations is that it is
+  // treated as integer known to be within a certain range,
+  //
+  //     j0 \in [2, int64_t::max]
+  //
+  // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
+  // This is a useful default range for the raggedness pattern of a jagged
+  // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
+  // specialization checks.
+  //
+  // [ Indeterminate inequalities error out ]
+  //
+  // Given the semantics defined above, certain relations like j0 < 3 are thus
+  // indeterminate. In our impl today, evaluating such relations errors out.
+  //
+  // It may seem convenient to just define indeterminate relations to return
+  // False, but the implementation we maintain in parallel using sympy does not
+  // allow this.
+  //
+  // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
+  // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
+  // would mean that if we define the indeterminate j0 >= 3 to be
+  // False, the (also indeterminate) j0 < 3 would evaluate to True!
+  //
+  // [ Coefficients are assumed positive ]
+  //
+  // For the purpose of computing inequalities, we consider the coefficient of
+  // the nested int to be a positive integer.
+  //
+  // Thus, no modifications are needed to the logic since
+  // j0 >= k implies coeff * j0 >= k
+  //
+  c10::SymNode eq(const c10::SymNode& other) override;
+  c10::SymNode ne(const c10::SymNode& other) override;
+  c10::SymNode ge(const c10::SymNode& other) override;
+  c10::SymNode gt(const c10::SymNode& other) override;
+  c10::SymNode lt(const c10::SymNode& other) override;
+  c10::SymNode le(const c10::SymNode& other) override;
+  c10::SymNode mul(const c10::SymNode& other) override;
+
+  c10::optional nested_int() override {
+    return val_;
+  }
+
+  c10::optional nested_int_coeff() override {
+    return coeff_;
+  }
+
+  bool is_symbolic() override {
+    return false;
+  }
+
+#define DEFINE_BINARY_NOT_SUPPORTED(name)                           \
+  c10::SymNode name(const c10::SymNode& other) override {           \
+    TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
+  }
+
+  DEFINE_BINARY_NOT_SUPPORTED(add)
+  DEFINE_BINARY_NOT_SUPPORTED(sub)
+  DEFINE_BINARY_NOT_SUPPORTED(truediv)
+  DEFINE_BINARY_NOT_SUPPORTED(pow)
+  DEFINE_BINARY_NOT_SUPPORTED(floordiv)
+  DEFINE_BINARY_NOT_SUPPORTED(mod)
+  DEFINE_BINARY_NOT_SUPPORTED(sym_min)
+  DEFINE_BINARY_NOT_SUPPORTED(sym_max)
+  DEFINE_BINARY_NOT_SUPPORTED(sym_and)
+  DEFINE_BINARY_NOT_SUPPORTED(sym_or)
+
+#undef DEFINE_BINARY_NOT_SUPPORTED
+
+#define DEFINE_NOT_SUPPORTED(name)                                     \
+  c10::SymNode name() override {                                       \
+    TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
+  }
+
+  DEFINE_NOT_SUPPORTED(sym_not)
+  DEFINE_NOT_SUPPORTED(ceil)
+  DEFINE_NOT_SUPPORTED(floor)
+  DEFINE_NOT_SUPPORTED(neg)
+  DEFINE_NOT_SUPPORTED(clone)
+  DEFINE_NOT_SUPPORTED(sym_float)
+
+#undef DEFINE_NOT_SUPPORTED
+
+ private:
+  int64_t val_;
+  int64_t coeff_;
+};
+
+} // namespace c10
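A brief editorial sketch of the comparison semantics described in the notes above. nested_int_demo is a hypothetical function, and constructing the node directly is itself an assumption (the header advises using a higher-level API instead).

#include <ATen/core/NestedIntSymNodeImpl.h>
#include <iostream>

// Hypothetical demo of the documented semantics: j0 behaves like an integer
// known to lie in [2, INT64_MAX], with an optional positive coefficient.
void nested_int_demo() {
  c10::SymNode j0 =
      c10::make_intrusive<c10::NestedIntSymNodeImpl>(/*val=*/0, /*coeff=*/1);
  std::cout << j0->str() << std::endl;  // prints "j0"

  // Determinate query: j0 >= 2 holds over the whole assumed range, so ge()
  // yields a constant-true node.
  c10::SymNode ge_two = j0->ge(j0->wrap_int(2));
  (void)ge_two;

  // Indeterminate query: per the note above, j0 < 3 cannot be decided and is
  // expected to raise a TORCH_CHECK error rather than return false.
  // c10::SymNode lt_three = j0->lt(j0->wrap_int(3));
}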
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h b/MLPY/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h
new file mode 100644
index 0000000000000000000000000000000000000000..e061933486045108de1c672c4343c6fa08ae6629
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/PhiloxRNGEngine.h
@@ -0,0 +1,242 @@
+#pragma once
+
+// define constants like M_PI and C keywords for MSVC
+#ifdef _MSC_VER
+#define _USE_MATH_DEFINES
+#include 
+#endif
+
+
+#ifdef __CUDACC__
+#include 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// typedefs for holding vector data
+namespace detail {
+
+typedef at::detail::Array UINT4;
+typedef at::detail::Array UINT2;
+typedef at::detail::Array DOUBLE2;
+typedef at::detail::Array FLOAT2;
+
+} // namespace detail
+
+/**
+ * Note [Philox Engine implementation]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Originally implemented in PyTorch's fusion compiler
+ * Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+ * for details regarding the engine.
+ *
+ * Note that currently this implementation of the philox engine is not used
+ * anywhere except for tests in cpu_generator_test.cpp. However, this engine
+ * will replace curandStatePhilox4_32_10_t in the future.
+ *
+ * The philox engine takes a seed value, a subsequence
+ * for starting the generation and an offset for the subsequence.
+ * Think of this engine as an algorithm producing a huge array. We are
+ * parallelizing this array by partitioning the huge array and assigning
+ * a thread index to each partition. In other words, each seed value
+ * (there are 2^64 possible seed values) gives a sub array of size
+ * 2^128 (each element in that array is a 128-bit number). The reasoning
+ * behind the array being of size 2^128 is that there are 2^64 possible
+ * thread index values and an array of size 2^64 for each of those
+ * thread indices. Hence 2^64 * 2^64 = 2^128 for each seed value.
+ *
+ * In short, this generator can produce 2^64 (seed values) * 2^128 (number
+ * of elements in an array given by a seed value) = 2^192 values.
+ *
+ * Arguments:
+ * seed:        Seed values could be any number from 0 to 2^64-1.
+ * subsequence: Subsequence is just the cuda thread indexing with:
+ *              - blockIdx.x * blockDim.x + threadIdx.x
+ * offset:      The offset variable in PhiloxEngine decides how many 128-bit
+ *              random numbers to skip (i.e. how many groups of 4, 32-bit numbers to skip)
+ *              and hence determines the total number of random numbers that can be
+ *              generated for the given subsequence.
+ */
+
+class philox_engine {
+public:
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721,
+                                 uint64_t subsequence = 0,
+                                 uint64_t offset = 0) {
+
+    reset_state(seed, subsequence);
+    incr_n(offset);
+  }
+
+  C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721,
+                                 uint64_t subsequence = 0) {
+    key_[0] = static_cast(seed);
+    key_[1] = static_cast(seed >> 32);
+    counter_ = detail::UINT4(0);
+    counter_[2] = static_cast(subsequence);
+    counter_[3] = static_cast(subsequence >> 32);
+    STATE = 0;
+  }
+
+  /**
+   * Set the offset field of Philox Generator to the desired offset.
+   */
+  C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
+    counter_[0] = static_cast(offset);
+    counter_[1] = static_cast(offset >> 32);
+  }
+
+  /**
+   * Gets the current offset of the Philox Generator.
+   */
+  C10_HOST_DEVICE uint64_t get_offset() const {
+    uint64_t lo = static_cast(counter_[0]);
+    uint64_t hi = static_cast(counter_[1]) << 32;
+    return lo | hi;
+  }
+
+  /**
+   * Produces a unique 32-bit pseudo-random number on every invocation. Bookkeeps state to avoid waste.
+   */
+  C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { // 10 here to preserve back-compat behavior
+    if(STATE == 0) {
+      detail::UINT4 counter = counter_;
+      detail::UINT2 key = key_;
+      output_ = rand(counter, key, n_rounds);
+      incr();
+    }
+    uint32_t ret = output_[static_cast(STATE)];
+    STATE = (STATE + 1) & 3;
+    return ret;
+  }
+
+  inline float randn(uint32_t n_rounds) {
+    #ifdef __CUDA_ARCH__
+    AT_ASSERT(false, "Unsupported invocation of randn on CUDA");
+    #endif
+    if(STATE == 0) {
+      detail::UINT4 counter = counter_;
+      detail::UINT2 key = key_;
+      output_ = rand(counter, key, n_rounds);
+      incr();
+    }
+    // TODO(min-jean-cho) change to Polar method, a more efficient version of Box-Muller method
+    // TODO(voz) We use std:: below, and thus need a separate impl for CUDA.
+    float u1 = 1 - uint32_to_uniform_float(output_[0]); // uint32_to_uniform_float returns [0,1), we need (0,1] to avoid passing 0 to log.
+    float u2 = 1 - uint32_to_uniform_float(output_[1]);
+    return static_cast(std::sqrt(-2.0 * std::log(u1)) * std::cos(2.0 * M_PI * u2));
+  }
+
+  /**
+   * Skips N 128-bit numbers in a subsequence.
+   */
+  C10_HOST_DEVICE inline void incr_n(uint64_t n) {
+    uint32_t nlo = static_cast(n);
+    uint32_t nhi = static_cast(n >> 32);
+    counter_[0] += nlo;
+    // if overflow in x has occurred, carry over to nhi
+    if (counter_[0] < nlo) {
+      nhi++;
+      // if overflow in nhi has occurred during carry over,
+      // propagate that overflow to y and exit to increment z
+      // otherwise return
+      counter_[1] += nhi;
+      if(nhi != 0) {
+        if (nhi <= counter_[1]) {
+          return;
+        }
+      }
+    } else {
+      // if overflow in y has occurred during addition,
+      // exit to increment z
+      // otherwise return
+      counter_[1] += nhi;
+      if (nhi <= counter_[1]) {
+        return;
+      }
+    }
+    if (++counter_[2])
+      return;
+    ++counter_[3];
+  }
+
+  /**
+   * Skips one 128-bit number in a subsequence.
+   */
+  C10_HOST_DEVICE inline void incr() {
+    if (++counter_[0])
+      return;
+    if (++counter_[1])
+      return;
+    if (++counter_[2]) {
+      return;
+    }
+    ++counter_[3];
+  }
+
+private:
+  detail::UINT4 counter_;
+  detail::UINT4 output_;
+  detail::UINT2 key_;
+  uint32_t STATE;
+
+  C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b,
+                                    uint32_t *result_high) {
+    #ifdef __CUDA_ARCH__
+      *result_high = __umulhi(a, b);
+      return a*b;
+    #else
+      const uint64_t product = static_cast(a) * b;
+      *result_high = static_cast(product >> 32);
+      return static_cast(product);
+    #endif
+  }
+
+  C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) {
+    uint32_t hi0 = 0;
+    uint32_t hi1 = 0;
+    uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0);
+    uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1);
+    detail::UINT4 ret;
+    ret[0] = hi1 ^ ctr[1] ^ in_key[0];
+    ret[1] = lo1;
+    ret[2] = hi0 ^ ctr[3] ^ in_key[1];
+    ret[3] = lo0;
+    return ret;
+  }
+
+  C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) {
+      // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
+      constexpr float scale = 4.6566127342e-10;
+      return static_cast(value & 0x7FFFFFFF) * scale;
+  }
+
+
+
+  C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) {
+    for (uint32_t round = 0; round < (n_rounds - 1); round++) {
+        counter = single_round(counter, key);
+        key[0] += (kPhilox10A); key[1] += (kPhilox10B);
+      }
+    return single_round(counter, key);
+  }
+
+
+  static const uint32_t kPhilox10A = 0x9E3779B9;
+  static const uint32_t kPhilox10B = 0xBB67AE85;
+  static const uint32_t kPhiloxSA = 0xD2511F53;
+  static const uint32_t kPhiloxSB = 0xCD9E8D57;
+};
+
+typedef philox_engine Philox4_32;
+
+} // namespace at
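An editorial usage sketch of the engine documented above; sample_stream is a hypothetical helper showing how seed, subsequence, and offset select and advance a stream.

#include <ATen/core/PhiloxRNGEngine.h>
#include <cstdint>

// Hypothetical helper: each (seed, subsequence) pair selects an independent
// stream; the offset skips whole 128-bit blocks within that stream.
uint32_t sample_stream(uint64_t seed, uint64_t thread_idx) {
  at::Philox4_32 engine(seed, /*subsequence=*/thread_idx, /*offset=*/0);
  uint32_t a = engine();  // first 32-bit value of a 128-bit block
  uint32_t b = engine();  // second value from the same block, no recompute
  return a ^ b;
}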
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..eab730ccef58e4b548eb12469cb8428860dc483d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/PythonFallbackKernel.h
@@ -0,0 +1,28 @@
+#pragma once
+#include 
+
+namespace at {
+namespace impl {
+
+struct TORCH_API RestorePythonTLSSnapshot {
+  RestorePythonTLSSnapshot();
+  ~RestorePythonTLSSnapshot();
+
+private:
+  c10::impl::LocalDispatchKeySet saved_;
+  c10::impl::ForceDispatchKeyGuard guard_;
+};
+
+
+// RAII guard to make working with the above TLS safer.
+struct TORCH_API MaybeSetTLSOnEntryGuard {
+public:
+  MaybeSetTLSOnEntryGuard();
+  ~MaybeSetTLSOnEntryGuard();
+
+private:
+  bool value_set_;
+};
+
+} // namespace impl
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h b/MLPY/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h
new file mode 100644
index 0000000000000000000000000000000000000000..979a21ef13a5647a9e839ca1b2632adbfbd2cf4d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/PythonOpRegistrationTrampoline.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include 
+
+// TODO: this can probably live in c10
+
+namespace at {
+namespace impl {
+
+class TORCH_API PythonOpRegistrationTrampoline final {
+  static std::atomic interpreter_;
+
+public:
+  //  Returns true if you successfully registered yourself (that means
+  //  you are in the hot seat for doing the operator registrations!)
+  static bool registerInterpreter(c10::impl::PyInterpreter*);
+
+  // Returns nullptr if no interpreter has been registered yet.
+  static c10::impl::PyInterpreter* getInterpreter();
+};
+
+} // namespace impl
+} // namespace at
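An editorial sketch of the registration contract described above; maybe_register is a hypothetical caller.

#include <ATen/core/PythonOpRegistrationTrampoline.h>

// Hypothetical caller: only the first interpreter to register "wins" and is
// responsible for performing the Python operator registrations.
void maybe_register(c10::impl::PyInterpreter* interp) {
  if (at::impl::PythonOpRegistrationTrampoline::registerInterpreter(interp)) {
    // This interpreter is in the hot seat: do the op registrations here.
  } else {
    // Someone else registered first; their interpreter can be queried.
    c10::impl::PyInterpreter* owner =
        at::impl::PythonOpRegistrationTrampoline::getInterpreter();
    (void)owner;
  }
}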
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h b/MLPY/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h
new file mode 100644
index 0000000000000000000000000000000000000000..320bb0e859785c0813e763412dbf92dfb92b98a3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/QuantizerBase.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+class Tensor;
+struct QTensorImpl;
+struct Quantizer;
+using ConstQuantizerPtr = const c10::intrusive_ptr&;
+using QuantizerPtr = c10::intrusive_ptr;
+
+/**
+ * Quantizer is the class for storing all the information
+ * that's necessary to perform quantize and dequantize
+ * operation.
+ *
+ * We might have different types of quantization schemes and this is
+ * the base class for all quantizers.
+ *
+ * QTensorImpl will hold a pointer to Quantizer so that we can support
+ * different quantization schemes on Tensor.
+ *
+ * For example, the most common quantization scheme, Affine Quantization,
+ * requires scale and zero_point as parameters; we store scale and zero_point
+ * inside the instance and use them to quantize a float Tensor or
+ * dequantize a quantized Tensor.
+ *
+ * When you add a new type of leaf Quantizer class, please also
+ * make sure to add a corresponding QScheme enum value, since
+ * they should have a one-to-one mapping.
+ *
+ * Note about intrusive_ptr:
+ * A quantized Tensor holds an intrusive_ptr to a Quantizer, and multiple Tensors can
+ * share the same Quantizer. Quantizer should be immutable.
+ */
+struct TORCH_API Quantizer : public c10::intrusive_ptr_target {
+  const ScalarType scalar_type_;
+  explicit Quantizer(ScalarType scalar_type) : scalar_type_(scalar_type) {}
+  ~Quantizer() override;
+
+  // Copied from torch/csrc/jit/ir/scope.h
+  QuantizerPtr intrusive_from_this() {
+    c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
+                                           // from a raw `this` pointer
+                                           // so we need to bump the refcount
+                                           // to account for this ownership
+    return c10::intrusive_ptr::reclaim(this);
+  }
+
+  /**
+   * Each concrete Quantizer type should have a unique QScheme type.
+   */
+  virtual QScheme qscheme() const = 0;
+
+  ScalarType scalar_type() const {
+    return scalar_type_;
+  }
+
+  /**
+   * quantize a float Tensor into a quantized Tensor.
+   */
+  virtual Tensor quantize(const Tensor& t) = 0;
+
+  /**
+   * dequantize a quantized Tensor into a float Tensor.
+   */
+  virtual Tensor dequantize(const Tensor& t) = 0;
+
+  /**
+   * dequantize a quantized Tensor into a float Tensor, out= variant
+   */
+  virtual Tensor& dequantize_out(Tensor& out, const Tensor& t) = 0;
+
+  /**
+   * Compare against `other` for equality.
+   */
+  virtual bool equalTo(QuantizerPtr other) const = 0;
+};
+
+} // namespace at
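An editorial sketch of what a concrete leaf quantizer implementing the interface above could look like. MyAffineQuantizer is hypothetical; its quantize/dequantize bodies are deliberately elided, and it only illustrates overriding the virtuals and pairing with an existing QScheme value, as the note above requires.

#include <ATen/core/QuantizerBase.h>
#include <c10/util/Exception.h>

// Hypothetical per-tensor affine quantizer: stores scale/zero_point and pairs
// with the existing kPerTensorAffine QScheme.
struct MyAffineQuantizer : public at::Quantizer {
  double scale_;
  int64_t zero_point_;

  MyAffineQuantizer(at::ScalarType t, double scale, int64_t zero_point)
      : at::Quantizer(t), scale_(scale), zero_point_(zero_point) {}

  c10::QScheme qscheme() const override {
    return c10::kPerTensorAffine;
  }

  at::Tensor quantize(const at::Tensor& /*t*/) override {
    // A real implementation would compute round(t / scale_) + zero_point_ and
    // pack into the quantized dtype; elided in this sketch.
    TORCH_CHECK(false, "sketch only");
  }
  at::Tensor dequantize(const at::Tensor& /*t*/) override {
    TORCH_CHECK(false, "sketch only");
  }
  at::Tensor& dequantize_out(at::Tensor& /*out*/, const at::Tensor& /*t*/) override {
    TORCH_CHECK(false, "sketch only");
  }
  bool equalTo(at::QuantizerPtr other) const override {
    return other && other->qscheme() == qscheme() &&
           other->scalar_type() == scalar_type();
  }
};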
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Range.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Range.h
new file mode 100644
index 0000000000000000000000000000000000000000..eb79331a2fa8e6520929badeeab10868d9f6f23e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Range.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+
+struct Range {
+  Range(int64_t begin, int64_t end)
+    : begin(begin)
+    , end(end) {}
+
+  int64_t size() const { return end - begin; }
+
+  Range operator/(int64_t divisor) {
+    return Range(begin / divisor, end / divisor);
+  }
+
+  int64_t begin;
+  int64_t end;
+};
+
+std::ostream& operator<<(std::ostream& out, const Range& range);
+
+}  // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Reduction.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Reduction.h
new file mode 100644
index 0000000000000000000000000000000000000000..04a94e25e9fc6014bc13ffa0749a3791cceedf94
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Reduction.h
@@ -0,0 +1,16 @@
+#pragma once
+
+namespace at {
+namespace Reduction {
+
+// NB: Keep this in sync with Reduction class in torch/nn/_reduction.py
+// These constants control the reduction behavior of loss functions.
+// Ideally, this would be a scoped enum, but jit doesn't support that
+enum Reduction {
+  None,             // Do not reduce
+  Mean,             // (Possibly weighted) mean of losses
+  Sum,              // Sum losses
+  END
+};
+} // namespace Reduction
+} // namespace at
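An editorial sketch of how these constants are typically consumed; apply_reduction is a hypothetical helper, not an ATen function.

#include <ATen/ATen.h>

// Hypothetical helper: finish a tensor of per-element losses according to the
// requested Reduction constant.
at::Tensor apply_reduction(const at::Tensor& losses, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::None:
      return losses;
    case at::Reduction::Mean:
      return losses.mean();
    case at::Reduction::Sum:
      return losses.sum();
    default:
      TORCH_CHECK(false, "unknown reduction: ", reduction);
  }
}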
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Scalar.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Scalar.h
new file mode 100644
index 0000000000000000000000000000000000000000..7c1649491a2b69236e633783b794d97faf3546b1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Scalar.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ScalarType.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ScalarType.h
new file mode 100644
index 0000000000000000000000000000000000000000..b83740b82dc25709e2aa8d2252c7fad88b4638dd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ScalarType.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Tensor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..0c3c53f9577c0b9aa544e8c8975bb8ef5b3fb228
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Tensor.h
@@ -0,0 +1,92 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class TORCH_API OptionalTensorRef {
+ public:
+  OptionalTensorRef() = default;
+
+  ~OptionalTensorRef() {
+    ref_.unsafeReleaseTensorImpl();
+  }
+
+  OptionalTensorRef(const TensorBase& src)
+      : ref_(Tensor::unsafe_borrow_t{}, src) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
+  }
+
+  OptionalTensorRef(const OptionalTensorRef& rhs)
+      : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
+
+  OptionalTensorRef& operator=(OptionalTensorRef rhs) {
+    std::swap(ref_, rhs.ref_);
+    return *this;
+  }
+
+  bool has_value() const {
+    return ref_.defined();
+  }
+
+  const Tensor& getTensorRef() const & {
+    return ref_;
+  }
+
+  const Tensor& operator*() const & {
+    return ref_;
+  }
+
+  const Tensor* operator->() const & {
+    return &ref_;
+  }
+
+  operator bool() const {
+    return ref_.defined();
+  }
+
+ private:
+  Tensor ref_;
+};
+
+// Used to convert a TensorBase (that may be undefined) to an at::Tensor
+// without bumping refcount.
+class TORCH_API TensorRef {
+ public:
+  ~TensorRef() {
+    ref_.unsafeReleaseTensorImpl();
+  }
+
+  TensorRef(const TensorBase& src)
+      : ref_(Tensor::unsafe_borrow_t{}, src) {}
+
+  const Tensor& operator*() const & {
+    return ref_;
+  }
+ private:
+  Tensor ref_;
+};
+
+template 
+auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t {
+  // Return the grad argument in case of a hook with void return type to have an
+  // std::function with Tensor return type
+  static_assert(std::is_same::value,
+                "Expected hook to return void");
+  return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) {
+    TensorRef grad(grad_base);
+    fn(*grad);
+    return Tensor();
+  });
+}
+
+template 
+auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t {
+  return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) {
+    TensorRef grad(grad_base);
+    Tensor ret = fn(*grad);
+    return TensorBase(std::move(ret));
+  });
+}
+
+} // namespace at
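An editorial sketch of the two hook flavours wrapped above; attach_hooks is hypothetical and assumes the tensor requires grad (hooks are only meaningful with autograd enabled).

#include <ATen/ATen.h>
#include <iostream>

// Hypothetical usage of the two register_hook overloads above. The tensor is
// assumed to require grad; hooks fire when its gradient is computed.
void attach_hooks(const at::Tensor& t) {
  // void-returning hook: the wrapper above supplies the empty Tensor return.
  t.register_hook([](const at::Tensor& grad) {
    std::cout << "grad norm: " << grad.norm().item<double>() << std::endl;
  });

  // Tensor-returning hook: the returned Tensor replaces the gradient.
  t.register_hook([](const at::Tensor& grad) { return grad * 2; });
}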
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h
new file mode 100644
index 0000000000000000000000000000000000000000..f5e4cbf7b991dcb7f3033f70a63f2190c85fc278
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorAccessor.h
@@ -0,0 +1,276 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
+// is used to enable the __restrict__ keyword/modifier for the data
+// passed to cuda.
+template 
+struct DefaultPtrTraits {
+  typedef T* PtrType;
+};
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template 
+struct RestrictPtrTraits {
+  typedef T* __restrict__ PtrType;
+};
+#endif
+
+// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
+// For CUDA tensors they are used in device code (only). This means that we restrict ourselves
+// to functions and types available there (e.g. IntArrayRef isn't).
+
+// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class TensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessorBase(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : data_(data_), sizes_(sizes_), strides_(strides_) {}
+  C10_HOST IntArrayRef sizes() const {
+    return IntArrayRef(sizes_,N);
+  }
+  C10_HOST IntArrayRef strides() const {
+    return IntArrayRef(strides_,N);
+  }
+  C10_HOST_DEVICE index_t stride(index_t i) const {
+    return strides_[i];
+  }
+  C10_HOST_DEVICE index_t size(index_t i) const {
+    return sizes_[i];
+  }
+  C10_HOST_DEVICE PtrType data() {
+    return data_;
+  }
+  C10_HOST_DEVICE const PtrType data() const {
+    return data_;
+  }
+protected:
+  PtrType data_;
+  const index_t* sizes_;
+  const index_t* strides_;
+};
+
+// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
+// `Tensor.accessor()`.
+// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
+// indexing on the device uses `TensorAccessor`s.
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class TensorAccessor : public TensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase(data_,sizes_,strides_) {}
+
+  C10_HOST_DEVICE TensorAccessor operator[](index_t i) {
+    return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
+  }
+
+  C10_HOST_DEVICE const TensorAccessor operator[](index_t i) const {
+    return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
+  }
+};
+
+template class PtrTraits, typename index_t>
+class TensorAccessor : public TensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+
+  C10_HOST_DEVICE TensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase(data_,sizes_,strides_) {}
+  C10_HOST_DEVICE T & operator[](index_t i) {
+    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
+    return this->data_[this->strides_[0]*i];
+  }
+  C10_HOST_DEVICE const T & operator[](index_t i) const {
+    return this->data_[this->strides_[0]*i];
+  }
+};
+
+
+// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used for
+// CUDA `Tensor`s on the host.
+// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
+// in order to transfer them to the device when calling kernels.
+// On the device, indexing into multidimensional tensors yields `TensorAccessor`s.
+// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
+// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
+// on the device, so those functions are host only.
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class GenericPackedTensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  C10_HOST GenericPackedTensorAccessorBase(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : data_(data_) {
+    std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
+    std::copy(strides_, strides_ + N, std::begin(this->strides_));
+  }
+
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template ::value>::type>
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  C10_HOST GenericPackedTensorAccessorBase(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : data_(data_) {
+    for (const auto i : c10::irange(N)) {
+      this->sizes_[i] = sizes_[i];
+      this->strides_[i] = strides_[i];
+    }
+  }
+
+  C10_HOST_DEVICE index_t stride(index_t i) const {
+    return strides_[i];
+  }
+  C10_HOST_DEVICE index_t size(index_t i) const {
+    return sizes_[i];
+  }
+  C10_HOST_DEVICE PtrType data() {
+    return data_;
+  }
+  C10_HOST_DEVICE const PtrType data() const {
+    return data_;
+  }
+protected:
+  PtrType data_;
+  // NOLINTNEXTLINE(*c-arrays*)
+  index_t sizes_[N];
+  // NOLINTNEXTLINE(*c-arrays*)
+  index_t strides_[N];
+  C10_HOST void bounds_check_(index_t i) const {
+    TORCH_CHECK_INDEX(
+        0 <= i && i < index_t{N},
+        "Index ",
+        i,
+        " is not within bounds of a tensor of dimension ",
+        N);
+  }
+};
+
+template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+
+  C10_HOST GenericPackedTensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {}
+
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template ::value>::type>
+  C10_HOST GenericPackedTensorAccessor(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {}
+
+  C10_DEVICE TensorAccessor operator[](index_t i) {
+    index_t* new_sizes = this->sizes_ + 1;
+    index_t* new_strides = this->strides_ + 1;
+    return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+  }
+
+  C10_DEVICE const TensorAccessor operator[](index_t i) const {
+    const index_t* new_sizes = this->sizes_ + 1;
+    const index_t* new_strides = this->strides_ + 1;
+    return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+  }
+
+  /// Returns a PackedTensorAccessor of the same dimension after transposing the
+  /// two dimensions given. Does not actually move elements; transposition is
+  /// made by permuting the size/stride arrays. If the dimensions are not valid,
+  /// asserts.
+  C10_HOST GenericPackedTensorAccessor transpose(
+      index_t dim1,
+      index_t dim2) const {
+    this->bounds_check_(dim1);
+    this->bounds_check_(dim2);
+    GenericPackedTensorAccessor result(
+        this->data_, this->sizes_, this->strides_);
+    std::swap(result.strides_[dim1], result.strides_[dim2]);
+    std::swap(result.sizes_[dim1], result.sizes_[dim2]);
+    return result;
+  }
+};
+
+template class PtrTraits, typename index_t>
+class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase {
+public:
+  typedef typename PtrTraits::PtrType PtrType;
+  C10_HOST GenericPackedTensorAccessor(
+      PtrType data_,
+      const index_t* sizes_,
+      const index_t* strides_)
+      : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {}
+
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template ::value>::type>
+  C10_HOST GenericPackedTensorAccessor(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {}
+
+  C10_DEVICE T & operator[](index_t i) {
+    return this->data_[this->strides_[0] * i];
+  }
+  C10_DEVICE const T& operator[](index_t i) const {
+    return this->data_[this->strides_[0]*i];
+  }
+
+  // Same as in the general N-dimensional case, but note that in the
+  // 1-dimensional case the returned PackedTensorAccessor will always be an
+  // identical copy of the original
+  C10_HOST GenericPackedTensorAccessor transpose(
+      index_t dim1,
+      index_t dim2) const {
+    this->bounds_check_(dim1);
+    this->bounds_check_(dim2);
+    return GenericPackedTensorAccessor(
+        this->data_, this->sizes_, this->strides_);
+  }
+};
+
+
+// Can't put this directly into the macro function args because of commas
+#define AT_X GenericPackedTensorAccessor
+
+// Old name for `GenericPackedTensorAccessor`
+template  class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X)
+
+#undef AT_X
+
+template  class PtrTraits = DefaultPtrTraits>
+using PackedTensorAccessor32 = GenericPackedTensorAccessor;
+
+template  class PtrTraits = DefaultPtrTraits>
+using PackedTensorAccessor64 = GenericPackedTensorAccessor;
+} // namespace at
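An editorial sketch of the two entry points mentioned in the comments above; trace_cpu is hypothetical, and the CUDA half is shown only as comments since it would need a .cu kernel.

#include <ATen/ATen.h>
#include <algorithm>

// Hypothetical CPU example: Tensor::accessor checks dtype and rank, then
// indexes through sizes/strides without further dispatch.
float trace_cpu(const at::Tensor& t) {
  auto a = t.accessor<float, 2>();
  float sum = 0;
  for (int64_t i = 0; i < std::min(a.size(0), a.size(1)); ++i) {
    sum += a[i][i];
  }
  return sum;
}

// CUDA sketch (comments only): the packed accessor is built on the host and
// passed by value to the kernel, where operator[] yields TensorAccessors.
//   __global__ void kernel(at::PackedTensorAccessor32<float, 2> a) { ... }
//   kernel<<<grid, block>>>(t.packed_accessor32<float, 2>());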
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBase.h b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBase.h
new file mode 100644
index 0000000000000000000000000000000000000000..8102105aef7acd0eec70b22ae471e91fafd14192
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBase.h
@@ -0,0 +1,1055 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace torch { namespace autograd {
+
+struct Node;
+
+}} // namespace torch::autograd
+
+namespace at {
+
+class Tensor;
+class TensorBase;
+
+// Convert Tensor to TensorBase without any need to include Tensor.h
+TORCH_API const TensorBase& get_tensor_base(const Tensor& t);
+
+namespace impl {
+inline bool variable_excluded_from_dispatch() {
+#ifdef C10_MOBILE
+  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
+  return true;
+#else
+  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
+#endif
+}
+
+}
+
+// NOTE: [Tensor vs. TensorBase]
+//
+// Tensor, being the central data structure in PyTorch, gets used and
+// its header included almost everywhere. Unfortunately this means
+// every time an operator signature is updated or changed in
+// native_functions.yaml, you (and every other PyTorch developer) need
+// to recompile all of ATen and its dependencies.
+//
+// TensorBase aims to break up these header dependencies, and improve
+// incremental build times for all PyTorch developers. TensorBase
+// represents a reference counted handle to TensorImpl, exactly the
+// same as Tensor. However, TensorBase doesn't have code generated
+// methods in its API and thus no dependence on native_functions.yaml.
+//
+// Usage tips
+// ----------
+// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
+//   or .cu file to ensure it has no header dependencies on
+//   native_functions.yaml (direct or indirect).
+// - Tensor inherits from TensorBase, so functions taking
+//   `const TensorBase &` are callable with Tensor as well.
+// - TensorBase can be converted to a Tensor with `Tensor(tensor_base)`,
+//   but this requires a reference-count bump. OptionalTensorRef on
+//   the other hand can materialize a `const Tensor &` without
+//   touching the reference-count.
+class TORCH_API TensorBase {
+ public:
+  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
+
+ protected:
+  // Create a Tensor with a +0 reference count. Special care must be
+  // taken to avoid decrementing this reference count at destruction
+  // time. Intended to support MaybeOwnedTraits.
+  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
+      : impl_(c10::intrusive_ptr::reclaim(rhs.impl_.get())) {}
+  friend MaybeOwnedTraits;
+
+ public:
+  TensorBase() = default;
+  // This constructor should not be used by end users and is an implementation
+  // detail invoked by autogenerated code.
+  explicit TensorBase(
+      c10::intrusive_ptr tensor_impl)
+      : impl_(std::move(tensor_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("TensorImpl with nullptr is not supported");
+    }
+  }
+  TensorBase(const TensorBase&) = default;
+  TensorBase(TensorBase&&) noexcept = default;
+
+ public:
+  // Creates a new wrapper from TensorImpl. Intentionally a free method because
+  // it should be used with care. Checks necessary invariants
+  static TensorBase wrap_tensor_impl(
+      c10::intrusive_ptr tensor_impl) {
+    TensorBase r(std::move(tensor_impl));
+    r.enforce_invariants();
+    return r;
+  }
+
+  int64_t dim() const {
+    return impl_->dim();
+  }
+  int64_t storage_offset() const {
+    return impl_->storage_offset();
+  }
+
+  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
+    if (is_contiguous(memory_format)) {
+      return *this;
+    } else {
+      return __dispatch_contiguous(memory_format);
+    }
+  }
+
+  /// Should be used if *this can reasonably be expected to be contiguous and
+  /// performance is important.
+  /// Compared to contiguous, it saves a reference count
+  /// increment/decrement if *this is already contiguous, at the cost
+  /// in all cases of an extra pointer of stack usage, an extra branch
+  /// to access, and an extra branch at destruction time.
+  c10::MaybeOwned expect_contiguous(
+      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
+
+  // Use .contiguous() instead. Trying to borrow from a prvalue
+  // will only lead to trouble and dangling references.
+  c10::MaybeOwned expect_contiguous(
+      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
+
+  const TensorBase& fill_(const c10::Scalar& scalar) const;
+  const TensorBase& zero_() const;
+
+  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const;
+
+  bool is_complex() const {
+    return at::isComplexType(this->scalar_type());
+  }
+
+  bool is_floating_point() const {
+    return at::isFloatingType(this->scalar_type());
+  }
+
+  bool is_signed() const {
+    return at::isSignedType(this->scalar_type());
+  }
+
+  c10::SymInt sym_size(int64_t dim) const {
+    return impl_->sym_size(dim);
+  }
+
+  c10::SymInt sym_stride(int64_t dim) const {
+    const auto sizes = this->sym_strides();
+    const auto ndim = static_cast(sizes.size());
+    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
+    return sizes[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
+
+  }
+
+  int64_t size(int64_t dim) const {
+    return impl_->size(dim);
+  }
+
+  int64_t stride(int64_t dim) const {
+    const auto strides = this->strides();
+    const auto ndim = static_cast(strides.size());
+    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
+    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
+  }
+
+  TensorImpl * unsafeGetTensorImpl() const {
+    return impl_.get();
+  }
+  TensorImpl * unsafeReleaseTensorImpl() {
+    return impl_.release();
+  }
+  const c10::intrusive_ptr& getIntrusivePtr() const {
+    return impl_;
+  }
+
+  c10::intrusive_ptr unsafeReleaseIntrusivePtr() {
+    return std::move(impl_);
+  }
+
+  bool defined() const {
+    return impl_;
+  }
+
+  void reset() {
+    impl_.reset();
+  }
+
+#if defined (_MSC_VER)
+  TensorBase& operator=(const TensorBase& x) & {
+    impl_ = x.impl_;
+    return *this;
+  };
+  TensorBase& operator=(TensorBase&& x) & noexcept {
+    impl_ = std::move(x.impl_);
+    return *this;
+  }
+#else
+  TensorBase& operator=(const TensorBase& x) & = default;
+  TensorBase& operator=(TensorBase&& x) & noexcept = default;
+#endif
+
+  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
+  TensorBase& operator=(const TensorBase&) && = delete;
+  TensorBase& operator=(TensorBase&&) && noexcept = delete;
+
+  bool is_same(const TensorBase& other) const noexcept {
+    return impl_ == other.impl_;
+  }
+  size_t use_count() const noexcept {
+    return impl_.use_count();
+  }
+  size_t weak_use_count() const noexcept {
+    return impl_.weak_use_count();
+  }
+
+  std::string toString() const;
+
+  IntArrayRef sizes() const {
+    return impl_->sizes();
+  }
+  c10::SymIntArrayRef sym_sizes() const {
+    return impl_->sym_sizes();
+  }
+  c10::SymIntArrayRef sym_strides() const {
+    return impl_->sym_strides();
+  }
+  IntArrayRef strides() const {
+    return impl_->strides();
+  }
+  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
+  c10::optional opt_names() const {
+    return impl::get_opt_names(unsafeGetTensorImpl());
+  }
+  // See impl::get_names in ATen/NamedTensor.h for docs.
+  DimnameList names() const {
+    return impl::get_names(unsafeGetTensorImpl());
+  }
+  int64_t ndimension() const {
+    return dim();
+  }
+
+  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
+    return impl_->is_contiguous(memory_format);
+  }
+
+  bool is_non_overlapping_and_dense() const {
+    return impl_->is_non_overlapping_and_dense();
+  }
+
+  at::MemoryFormat suggest_memory_format(
+      bool channels_last_strides_exact_match = false) const {
+    // Setting channels_last_strides_exact_match to true forces the function to
+    // also check the strides of 0- and 1-sized dimensions.
+    if (layout() == at::kStrided) {
+      if (impl_->is_strides_like_channels_last()) {
+        if (!channels_last_strides_exact_match ||
+            get_channels_last_strides_2d(sizes()) == strides()) {
+          return at::MemoryFormat::ChannelsLast;
+        }
+      }
+      else if (impl_->is_strides_like_channels_last_3d()) {
+        if (!channels_last_strides_exact_match ||
+            get_channels_last_strides_3d(sizes()) == strides()) {
+          return at::MemoryFormat::ChannelsLast3d;
+        }
+      }
+    }
+    return at::MemoryFormat::Contiguous;
+  }
+
+  // Total bytes consumed by the "view" of elements of the array.  Does not
+  // include size of metadata.  The number reported here does not necessarily
+  // correspond to the true physical memory consumed by a tensor; instead,
+  // it reports the memory the tensor would take *if* it were contiguous.
+  // Defined to be numel() * itemsize()
+  size_t nbytes() const {
+    TORCH_CHECK(layout () != at::kSparse,
+                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
+                "tensors, add the nbytes of the indices and values.  If you want the size of the  " \
+                "equivalent dense tensor, multiply numel() by element_size()");
+    return impl_->numel() * impl_->itemsize();
+  }
+
+  c10::SymInt sym_nbytes() const {
+    TORCH_CHECK(layout () != at::kSparse,
+                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
+                "tensors, add the nbytes of the indices and values.  If you want the size of the  " \
+                "equivalent dense tensor, multiply numel() by element_size()");
+    return impl_->sym_numel() * impl_->itemsize();
+  }
+
+  int64_t numel() const {
+    return impl_->numel();
+  }
+
+  c10::SymInt sym_numel() const {
+    return impl_->sym_numel();
+  }
+
+  c10::SymInt sym_storage_offset() const {
+    return impl_->sym_storage_offset();
+  }
+
+  // Length of one array element in bytes.  This is the traditional
+  // Numpy naming.
+  size_t itemsize() const {
+    return impl_->itemsize();
+  }
+
+  // Same as itemsize().  This is the PyTorch naming.
+  int64_t element_size() const {
+    return static_cast(impl_->itemsize());
+  }
+
+  DispatchKeySet key_set() const {
+    return impl_->key_set();
+  }
+  ScalarType scalar_type() const {
+    return typeMetaToScalarType(impl_->dtype());
+  }
+  bool has_storage() const {
+    return defined() && impl_->has_storage();
+  }
+  const Storage& storage() const {
+    return impl_->storage();
+  }
+  bool is_alias_of(const at::TensorBase& other) const{
+    return impl_->storage().is_alias_of(other.storage());
+  }
+
+  // Move the storage backend to shm based
+  // to enable memory sharing across processes.
+  //
+  // NB1: the ideal behavior of this API still requires further discussion
+  // but for now we are inclined to keep it consistent with existing THP behavior
+  // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
+  // so we don't assert on anything here and rely on caller knowing
+  // what it's doing.
+  //
+  // NB2: this currently provides Linux fd based shm support only
+  // to simplify the storage lifetime management logic in ATen
+  // and similarly for now we are not adding support for file system based
+  // shm support like in THP due to additional GC manager support needed
+  // to prevent leaks.
+  // As such, calling this from non supported systems (e.g. Windows) would fail.
+  void share_memory_() {
+    at::share_memory_(*this);
+  }
+
+  inline bool _is_zerotensor() const {
+    return impl_->_is_zerotensor();
+  }
+
+  inline void _set_zero(bool zero) const {
+    impl_->_set_zero(zero);
+  }
+
+  inline bool is_conj() const {
+    return impl_->is_conj();
+  }
+
+  // sets the conjugate bit of a tensor.
+  // NOTE: The conjugate bit is supposed to be a read-only field. Only change this if you are sure
+  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
+  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
+  inline void _set_conj(bool conjugate) const {
+    impl_->_set_conj(conjugate);
+  }
+
+  inline bool is_neg() const {
+    return impl_->is_neg();
+  }
+
+  // sets the negative bit of a tensor.
+  // NOTE: The negative bit is supposed to be a read-only field. Only change this if you are sure
+  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
+  // bit to determine if a negation needs to be materialized.
+  inline void _set_neg(bool negative) const {
+    impl_->_set_neg(negative);
+  }
+
+  /// Returns a `Tensor`'s layout.
+  Layout layout() const {
+    return impl_->layout();
+  }
+
+  /// Returns a `Tensor`'s dtype (`TypeMeta`).
+  caffe2::TypeMeta dtype() const {
+    return impl_->dtype();
+  }
+
+  /// Returns a `Tensor`'s device.
+  inline Device device() const {
+    return impl_->device();
+  }
+
+  /// Returns a `Tensor`'s device index.
+  DeviceIndex get_device() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->get_device();
+  }
+
+  /// Returns if a `Tensor` has CPU backend.
+  bool is_cpu() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_cpu();
+  }
+
+  /// Returns if a `Tensor` has CUDA backend.
+  bool is_cuda() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_cuda();
+  }
+
+  /// Returns if a `Tensor` has IPU backend.
+  bool is_ipu() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ipu();
+  }
+
+  /// Returns if a `Tensor` has XPU backend.
+  bool is_xpu() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_xpu();
+  }
+
+  /// Returns if a `Tensor` has XLA backend.
+  bool is_xla() const {
+    return impl_->is_xla();
+  }
+
+  /// Returns if a `Tensor` has MTIA backend.
+  bool is_mtia() const {
+    return impl_->is_mtia();
+  }
+
+  /// Returns if a `Tensor` has HPU backend.
+  bool is_hpu() const {
+    return impl_->is_hpu();
+  }
+
+  /// Returns if a `Tensor` has Lazy backend.
+  bool is_lazy() const {
+    return impl_->is_lazy();
+  }
+
+  /// Returns if a `Tensor` has HIP backend.
+  bool is_hip() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_hip();
+  }
+
+  /// Returns if a `Tensor` has VE backend.
+  bool is_ve() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ve();
+  }
+
+  /// Returns if a `Tensor` has PrivateUse1 backend.
+  bool is_privateuseone() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_privateuseone();
+  }
+
+  /// Returns if a `Tensor` has sparse backend.
+  bool is_sparse() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_sparse();
+  }
+
+  /// Returns if a `Tensor` has a sparse CSR backend.
+  bool is_sparse_csr() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_sparse_csr();
+  }
+
+  /// Returns if a `Tensor` is an mkldnn tensor.
+  bool is_mkldnn() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_mkldnn();
+  }
+
+  /// Returns if a `Tensor` is an MPS tensor.
+  bool is_mps() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_mps();
+  }
+
+  /// Returns if a `Tensor` is an ORT tensor.
+  bool is_ort() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ort();
+  }
+
+  /// Returns if a `Tensor` is a Vulkan tensor.
+  bool is_vulkan() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_vulkan();
+  }
+
+  /// Returns if a `Tensor` is a Metal tensor.
+  bool is_metal() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_metal();
+  }
+
+  /// Returns if a `Tensor` has quantized backend.
+  bool is_quantized() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_quantized();
+  }
+
+  /// Returns if a `Tensor` is a meta tensor.  Meta tensors can
+  /// also have other designations.
+  bool is_meta() const {
+    return impl_->is_meta();
+  }
+
+  /// Returns if a `Tensor` is an inference tensor.
+  bool is_inference() const {
+    return impl_->is_inference();
+  }
+
+  // Returns if a `Tensor` is a NestedTensor.
+  bool is_nested() const {
+    return impl_->is_nested();
+  }
+
+  /// If a tensor is a quantized tensor, returns its quantizer
+  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
+  QuantizerPtr quantizer() const;
+
+  /// Returns if a `Tensor` has any dimension names
+  bool has_names() const {
+    // If a user is using unnamed tensors, then we can short-circuit right here.
+    // Otherwise, impl::has_names attempts to retrieve names.
+    if (!impl_->has_named_tensor_meta()) {
+      return false;
+    }
+    return impl::has_names(unsafeGetTensorImpl());
+  }
+
+  /// Returns a `Tensor`'s dimension names data structure
+  const NamedTensorMeta* get_named_tensor_meta() const {
+    return static_cast(impl_->named_tensor_meta());
+  }
+
+  NamedTensorMeta* get_named_tensor_meta() {
+    return static_cast(impl_->named_tensor_meta());
+  }
+
+  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
+  /// TensorOptions.h.
+  TensorOptions options() const {
+    return TensorOptions().dtype(dtype())
+                          .device(device())
+                          .layout(layout());
+  }
+
+  const void* const_data_ptr() const {
+    return this->unsafeGetTensorImpl()->data();
+  }
+
+  void* mutable_data_ptr() const {
+    return this->unsafeGetTensorImpl()->mutable_data();
+  }
+
+  // TODO(#97856) Make this return a const pointer. This currently
+  //              returns a non-const pointer because of the large
+  //              number of clients that we still want to audit before
+  //              migrating to mutable_data_ptr().
+  void* data_ptr() const {
+    return mutable_data_ptr();
+  }
+
+  template ::value, int> = 0>
+  const T* const_data_ptr() const;
+
+  template ::value, int> = 0>
+  const std::remove_const_t* const_data_ptr() const;
+
+  template 
+  T* mutable_data_ptr() const;
+
+  // Legacy interface during the migration to indicate that a callsite
+  // has not been audited for mutability.
+  //
+  // Do not add new uses of this, use const_data_ptr() if possible,
+  // mutable_data_ptr() otherwise.
+  //
+  // TODO(#97856) Make this return a const pointer. This is currently
+  //              const because of the vast number of clients that
+  //              rely on this.
+  template 
+  T* data_ptr() const;
+
+  // Purposely not defined here to avoid inlining
+  void print() const;
+
+  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
+  // dimension.
+  template
+  TensorAccessor accessor() const& {
+    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()");
+    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
+    T* ptr = nullptr;
+    if constexpr (std::is_const::value) {
+      ptr = const_data_ptr();
+    } else {
+      ptr = mutable_data_ptr();
+    }
+    return TensorAccessor(ptr,sizes().data(),strides().data());
+  }
+  template
+  TensorAccessor accessor() && = delete;
+
+  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
+  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
+  // cast the data pointer to a __restrict__ pointer.
+  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
+  // as an argument.
+  template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  GenericPackedTensorAccessor generic_packed_accessor() const& {
+    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()");
+    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
+    T* ptr = nullptr;
+    if constexpr (std::is_const::value) {
+      ptr = const_data_ptr();
+    } else {
+      ptr = mutable_data_ptr();
+    }
+    return GenericPackedTensorAccessor(static_cast::PtrType>(ptr),sizes().data(),strides().data());
+  }
+  template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  GenericPackedTensorAccessor generic_packed_accessor() && = delete;
+
+  template class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor32 packed_accessor32() const& {
+    TORCH_CHECK(
+        impl_->numel() <=
+            static_cast(std::numeric_limits::max()),
+        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
+    return generic_packed_accessor();
+  }
+  template class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor32 packed_accessor32() && = delete;
+
+  template class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor64 packed_accessor64() const& {
+    return generic_packed_accessor();
+  }
+  template class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor64 packed_accessor64() && = delete;
+
+  // ~~~~~ Autograd API ~~~~~
+
+  /// \fn bool is_leaf() const;
+  ///
+  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
+  ///
+  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
+  /// created by the user. This means that they are not the result of an operation and so
+  /// `grad_fn()` is `nullptr`.
+  ///
+  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
+  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
+  ///
+  /// Example:
+  /// @code
+  /// auto a = torch::rand(10, torch::requires_grad());
+  /// std::cout << a.is_leaf() << std::endl; // prints `true`
+  ///
+  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
+  /// std::cout << b.is_leaf() << std::endl; // prints `false`
+  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
+  ///
+  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
+  /// std::cout << c.is_leaf() << std::endl; // prints `false`
+  /// // c was created by the addition operation
+  ///
+  /// auto d = torch::rand(10).cuda();
+  /// std::cout << d.is_leaf() << std::endl; // prints `true`
+  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
+  ///
+  /// auto e = torch::rand(10).cuda().requires_grad_();
+  /// std::cout << e.is_leaf() << std::endl; // prints `true`
+  /// // e requires gradients and has no operations creating it
+  ///
+  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
+  /// std::cout << f.is_leaf() << std::endl; // prints `true`
+  /// // f requires grad, has no operation creating it
+  /// @endcode
+
+  /// \fn void backward(const Tensor & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false, c10::optional inputs=c10::nullopt) const;
+  ///
+  /// Computes the gradient of current tensor with respect to graph leaves.
+  ///
+  /// The graph is differentiated using the chain rule. If the tensor is
+  /// non-scalar (i.e. its data has more than one element) and requires
+  /// gradient, the function additionally requires specifying ``gradient``.
+  /// It should be a tensor of matching type and location, that contains
+  /// the gradient of the differentiated function w.r.t. this Tensor.
+  ///
+  /// This function accumulates gradients in the leaves - you might need to
+  /// zero them before calling it.
+  ///
+  /// \param gradient Gradient w.r.t. the
+  ///     tensor. If it is a tensor, it will be automatically converted
+  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
+  ///     None values can be specified for scalar Tensors or ones that
+  ///     don't require grad. If a None value would be acceptable then
+  ///     this argument is optional.
+  /// \param retain_graph If ``false``, the graph used to compute
+  ///     the grads will be freed. Note that in nearly all cases setting
+  ///     this option to True is not needed and often can be worked around
+  ///     in a much more efficient way. Defaults to the value of
+  ///     ``create_graph``.
+  /// \param create_graph If ``true``, the graph of the derivative will
+  ///     be constructed, allowing higher order derivative
+  ///     products to be computed. Defaults to ``false``.
+  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
+  ///     provided, the gradient is accumulated into all the leaf Tensors
+  ///     that were used to compute the current tensor.
+  ///     When inputs are provided and a given input is not a leaf,
+  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
+  ///     It is an implementation detail on which the user should not rely.
+  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
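+  ///
+  /// A minimal usage sketch (illustrative only, not part of the upstream
+  /// documentation): a non-scalar output needs an explicit ``gradient`` of the
+  /// same shape, while a scalar output does not.
+  /// @code
+  /// auto x = torch::rand({3}, torch::requires_grad());
+  /// auto y = x * 2;                   // non-scalar, so backward() needs a gradient
+  /// y.backward(torch::ones_like(y));  // accumulates into x.grad()
+  ///
+  /// auto loss = (x * x).sum();        // scalar output: no gradient argument needed
+  /// loss.backward();
+  /// @endcode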
+
+  /// \fn Tensor detach() const;
+  ///
+  /// Returns a new Tensor, detached from the current graph.
+  /// The result will never require gradient.
+
+  /// \fn Tensor & detach_() const;
+  ///
+  /// Detaches the Tensor from the graph that created it, making it a leaf.
+  /// Views cannot be detached in-place.
+
+  /// \fn void retain_grad() const;
+  ///
+  /// Enables this Tensor to have its :attr:`grad` populated during
+  /// :func:`backward`. This is a no-op for leaf tensors.
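+  ///
+  /// A minimal sketch (illustrative only): retain_grad() lets a non-leaf Tensor
+  /// keep its gradient after a call to backward().
+  /// @code
+  /// auto x = torch::rand({3}, torch::requires_grad());
+  /// auto y = x * 2;      // non-leaf
+  /// y.retain_grad();
+  /// y.sum().backward();
+  /// std::cout << y.grad() << std::endl; // a tensor of ones instead of an undefined tensor
+  /// @endcode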
+
+  /// \fn bool retains_grad() const;
+  ///
+  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+  /// populated during :func:`backward`, ``false`` otherwise.
+
+  const TensorBase& set_requires_grad(bool requires_grad) const {
+    impl_->set_requires_grad(requires_grad);
+    return *this;
+  }
+  bool requires_grad() const {
+    return impl_->requires_grad();
+  }
+
+  // The Forward AD API functions below are low level and are not to be used by end
+  // users who should use the API provided in torch/csrc/autograd.h
+
+  /// This function returns the forward gradient for this Tensor at the given level.
+  const Tensor& _fw_grad(uint64_t level) const {
+    return impl_->_fw_grad(level, *this);
+  }
+
+  /// This function can be used to set the value of the forward grad.
+  /// Note that the given new_grad might not be used directly if it has different
+  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
+  /// new_grad content will be copied into a new Tensor
+  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
+    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
+  }
+
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::TensorBase tensor_data() const;
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which create a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::TensorBase variable_data() const;
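+
+  // A minimal sketch of the semantics described above (illustrative only):
+  //   at::TensorBase v = torch::rand({2, 2}, torch::requires_grad());
+  //   at::TensorBase d = v.variable_data();  // same storage, fresh autograd history
+  //   // d.is_alias_of(v) == true, d.requires_grad() == false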
+
+  // Gradient Node and Edges
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
+  /// the pointer returned will be null.
+  ///
+  /// For View Variables:
+  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
+  /// re-create the grad_fn to express the up-to-date view relationship between
+  /// this and the base Variable.
+  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
+
+  // Hooks
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  template 
+  using hook_return_void_t = std::enable_if_t>::value, unsigned>;
+  template 
+  using hook_return_var_t = std::enable_if_t, TensorBase>::value, unsigned>;
+
+  /// Registers a backward hook.
+  ///
+  /// The hook will be called every time a gradient with respect to the Tensor is computed.
+  /// The hook should have one of the following signatures:
+  /// ```
+  /// hook(TensorBase grad) -> TensorBase
+  /// ```
+  /// ```
+  /// hook(TensorBase grad) -> void
+  /// ```
+  /// The hook should not modify its argument, but it can optionally return a new gradient
+  /// which will be used in place of `grad`.
+  ///
+  /// This function returns the index of the hook in the list, which can be used to remove the hook.
+  ///
+  /// Example:
+  /// @code
+  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
+  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
+  /// v.backward(torch::tensor({1., 2., 3.}));
+  /// // This prints:
+  /// // ```
+  /// //  2
+  /// //  4
+  /// //  6
+  /// // [ CPUFloatType{3} ]
+  /// // ```
+  /// std::cout << v.grad() << std::endl;
+  /// v.remove_hook(h);  // removes the hook
+  /// @endcode
+  template <typename T>
+  hook_return_void_t<T> register_hook(T&& hook) const;
+  template <typename T>
+  hook_return_var_t<T> register_hook(T&& hook) const;
+
+protected:
+  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;
+
+public:
+
+  /// Remove hook at given position
+  void remove_hook(unsigned pos) const;
+
+  // Variable methods
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  bool is_leaf() const;
+
+  int64_t output_nr() const;
+
+  void set_data(const TensorBase & new_data) const;
+
+  TensorBase data() const;
+
+  int64_t _version() const;
+
+  void retain_grad() const;
+
+  bool retains_grad() const;
+
+  const TensorBase& requires_grad_(bool _requires_grad=true) const;
+
+  // View Variables
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  /// Returns true if this `Variable` is a view of another `Variable`.
+  bool is_view() const;
+
+  /// Returns the `Variable` that this `Variable` is a view of. If this
+  /// `Variable` is not a view, throw a `std::runtime_error`.
+  const TensorBase& _base() const;
+
+  // Miscellaneous
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  const std::string& name() const;
+
+protected:
+  void enforce_invariants();
+  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
+
+private:
+  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
+};
+
+inline DeviceIndex get_device(const TensorBase& self) {
+  return self.get_device();
+}
+
+template <typename T>
+auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
+  // Return the grad argument in case of a hook with void return type to have an
+  // std::function with Tensor return type
+  static_assert(std::is_same<decltype(hook(TensorBase())), void>::value,
+                "Expected hook to return void");
+  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
+    fn(grad);
+    return TensorBase();
+  });
+}
+
+template <typename T>
+auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
+  return _register_hook(std::forward<T>(hook));
+}
+
+namespace detail {
+// Helper creator for the Tensor class which doesn't require the user to pass
+// in an intrusive_ptr; instead it converts the arguments passed in to the
+// requested intrusive_ptr type.
+template <typename T, typename... Args>
+TensorBase make_tensor_base(Args&&... args) {
+  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
+}
+
+} // namespace detail
+
+static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
+  return legacyExtractDispatchKey(t.key_set());
+}
+
+} // namespace at
+
+namespace c10 {
+template <>
+struct MaybeOwnedTraits<at::TensorBase> {
+  using owned_type = at::TensorBase;
+  using borrow_type = at::TensorBase;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    // NOTE: this can be implemented without the special
+    // unsafe_borrow_t Tensor constructor as
+    //
+    // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl()));
+    //
+    // but that hurts inlining due to the nullptr check in the
+    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
+    // that from.impl_ isn't null because from is a valid Tensor, so
+    // we needn't do the check again. (using __builtin_assume can
+    // avoid this, but wouldn't be portable to MSVC.)
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseTensorImpl();
+    // See above note: this can be implemented with public API
+    // similarly to createBorrow(), but that would hurt inlining.
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+    return true;
+  }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
+} // namespace c10
+
+namespace at {
+
+inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
+    const c10::optional<TensorBase>& opt) {
+  return opt.has_value()
+    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
+    : c10::MaybeOwned<TensorBase>::owned(std::in_place);
+}
+
+inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
+  if (is_contiguous(memory_format)) {
+    return c10::MaybeOwned<TensorBase>::borrowed(*this);
+  } else {
+    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
+  }
+}
+
+namespace symint {
+
+template <typename T>
+using enable_if_symint = std::enable_if_t<std::is_same<T, c10::SymInt>::value>;
+template <typename T>
+using enable_if_int = std::enable_if_t<std::is_same<T, int64_t>::value>;
+
+template <typename T, typename = enable_if_symint<T>>
+c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
+template <typename T, typename = enable_if_int<T>>
+IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }
+
+template <typename T, typename = enable_if_symint<T>>
+c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
+template <typename T, typename = enable_if_int<T>>
+int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }
+
+template <typename T, typename = enable_if_symint<T>>
+c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
+template <typename T, typename = enable_if_int<T>>
+IntArrayRef strides(const TensorBase& t) { return t.strides(); }
+
+template <typename T, typename = enable_if_symint<T>>
+c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
+template <typename T, typename = enable_if_int<T>>
+int64_t numel(const TensorBase& t) { return t.numel(); }
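+
+// A minimal usage sketch (illustrative only, for some TensorBase `t`): the
+// overload is selected purely by the template argument, so callers (e.g.
+// generated code) can be templated over the index type without branching.
+//   at::symint::sizes<int64_t>(t);      // IntArrayRef via t.sizes()
+//   at::symint::sizes<c10::SymInt>(t);  // c10::SymIntArrayRef via t.sym_sizes()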
+
+} // namespace symint
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBody.h b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBody.h
new file mode 100644
index 0000000000000000000000000000000000000000..41e3dc0fc3e04bd9db687af8d199022bbb8b7160
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/TensorBody.h
@@ -0,0 +1,5731 @@
+#pragma once
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+#include 
+
+namespace c10{
+template<class T> class List;
+template<class T> class IListRef;
+}
+namespace at {
+struct Generator;
+struct Type;
+class DeprecatedTypeProperties;
+class Tensor;
+} // namespace at
+namespace at {
+namespace indexing {
+struct TensorIndex;
+} // namespace indexing
+} // namespace at
+
+namespace torch { namespace autograd {
+
+struct Node;
+
+}} // namespace torch::autograd
+
+namespace at {
+
+class OptionalTensorRef;
+class TensorRef;
+class Tensor;
+using TensorList = ArrayRef<Tensor>;
+using ITensorList = c10::IListRef<Tensor>;
+
+using Stream = c10::Stream;
+
+// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
+// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
+//
+// For example:
+//
+// void func(Tensor a) {
+//   Tensor b = a;
+//   ...
+// }
+//
+// In this example, when we say Tensor b = a, we are creating a new object that points to the
+// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
+// destructor decrements the reference count by calling release() on the TensorImpl it points to.
+// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
+//
+// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
+// special care must be taken to handle this.
+class TORCH_API Tensor: public TensorBase {
+ protected:
+  // Create a Tensor with a +0 reference count. Special care must be
+  // taken to avoid decrementing this reference count at destruction
+  // time. Intended to support MaybeOwnedTraits.
+  explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {}
+  friend MaybeOwnedTraits<Tensor>;
+  friend OptionalTensorRef;
+  friend TensorRef;
+
+ public:
+  Tensor() = default;
+  // This constructor should not be used by end users and is an implementation
+  // detail invoked by autogenerated code.
+  explicit Tensor(
+      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
+      : TensorBase(std::move(tensor_impl)) {}
+  Tensor(const Tensor &tensor) = default;
+  Tensor(Tensor &&tensor) = default;
+
+  // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount
+  explicit Tensor(const TensorBase &base): TensorBase(base) {}
+  /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {}
+
+  // Creates a new wrapper from TensorImpl. Intentionally a free method because
+  // it should be used with care. Checks necessary invariants
+  static Tensor wrap_tensor_impl(
+      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
+    return TensorBase::wrap_tensor_impl(std::move(tensor_impl));
+  }
+
+  Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
+    return TensorBase::contiguous(memory_format);
+  }
+
+  Tensor conj() const {
+    if (!this->is_complex()) {
+      return *this;
+    }
+
+    switch (this->layout()) {
+      case at::kSparse:
+      case at::kSparseCsr:
+      case at::kSparseCsc:
+      case at::kSparseBsr:
+      case at::kSparseBsc:
+        return this->conj_physical();
+      default:
+        return this->_conj();
+    }
+  }
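+
+  // A minimal usage sketch (illustrative only): for dense complex tensors,
+  // conj() is a lazy view that can be materialized on demand.
+  //   auto z = torch::randn({2}, torch::kComplexFloat);
+  //   auto zc = z.conj();            // zc.is_conj() == true, no data copied
+  //   auto zm = zc.resolve_conj();   // materializes the conjugated values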
+
+  // Aliased by Dimname overloads, so need explicit using
+  using TensorBase::size;
+  using TensorBase::sym_size;
+  using TensorBase::stride;
+
+  /// Should be used if *this can reasonably be expected to be contiguous and
+  /// performance is important.
+  /// Compared to contiguous, it saves a reference count
+  /// increment/decrement if *this is already contiguous, at the cost
+  /// in all cases of an extra pointer of stack usage, an extra branch
+  /// to access, and an extra branch at destruction time.
+  c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
+
+  // Use .contiguous() instead. Trying to borrow from a prvalue Tensor
+  // will only lead to trouble and dangling references.
+  c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
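+
+  // A minimal usage sketch (illustrative only):
+  //   void process(const at::Tensor& t) {
+  //     c10::MaybeOwned<at::Tensor> c = t.expect_contiguous();
+  //     // *c borrows t when it is already contiguous, otherwise owns a
+  //     // contiguous copy; either way c->data_ptr() is safe to read here.
+  //   }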
+
+  // The following overloads are very intriguing.  Consider the following
+  // program:
+  //
+  //    x[1] = 3;
+  //
+  // We would expect that the first entry of x is written to 3.  But how can we
+  // actually achieve this?  x[1] evaluates to a tensor...
+  //
+  // The answer is, using a ref-qualifier.  x[1] is an rvalue, which cannot be
+  // (profitably) assigned to in the traditional sense, so we overload
+  // assignment to mean, "Actually, copy 3 into the tensor data."  This is done
+  // with an rvalue-reference ref-qualified overload (the methods with && at the
+  // end of their type.)
+  //
+  // There's one more fly in the ointment: We also want
+  //
+  //    Tensor x = y;
+  //
+  // to work, and we want it NOT to copy.  So we need a traditional operator=
+  // overload.  But we MUST specify a mutable lvalue ref-qualifier, to
+  // disambiguate the traditional overload from the rvalue-reference
+  // ref-qualified overload.  Otherwise, it will be ambiguous, because
+  // a non ref-qualified method is eligible for all situations.
+
+  // Unfortunately, we have to write these constructors out manually
+  // to work around an MSVC bug:
+  //    error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &':
+  //    multiple versions of a defaulted special member functions are not allowed
+  // Tensor& operator=(const Tensor&) & = default;
+  // Tensor& operator=(Tensor&&) & = default;
+
+  // Also MSVC will wrongly issue the following warning with the aforementioned fix
+  //    warning C4522: 'at::Tensor': multiple assignment operators specified
+  // Let's just skip the warning.
+  //
+  // TODO: temporarily disabled
+
+  Tensor& operator=(const TensorBase& x) & {
+    impl_ = x.getIntrusivePtr();
+    return *this;
+  }
+  Tensor& operator=(TensorBase&& x) & noexcept {
+    impl_ = x.unsafeReleaseIntrusivePtr();
+    return *this;
+  }
+
+  Tensor& operator=(const Tensor &x) & {
+    return operator=(static_cast<const TensorBase&>(x));
+  }
+  Tensor& operator=(Tensor &&x) & noexcept {
+    return operator=(static_cast<TensorBase&&>(x));
+  }
+
+  Tensor& operator=(const Scalar &v) && {
+    return fill_(v);
+  }
+  Tensor& operator=(const Tensor &rhs) && {
+    return copy_(rhs);
+  }
+  Tensor& operator=(Tensor&& rhs) && {
+    return copy_(rhs);
+  }
+
+  C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().")
+  DeprecatedTypeProperties & type() const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        dispatchKeyToBackend(legacyExtractDispatchKey(key_set())),
+        scalar_type());
+  }
+
+  Tensor toType(ScalarType t) const {
+    return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // TODO: Deprecate me
+  Tensor toBackend(Backend b) const {
+    return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())")
+  bool is_variable() const noexcept {
+    return !at::impl::variable_excluded_from_dispatch();
+  }
+
+  template <typename T>
+  C10_DEPRECATED_MESSAGE("Tensor.data<T>() is deprecated. Please use Tensor.data_ptr<T>() instead.")
+  T * data() const {
+    return data_ptr<T>();
+  }
+
+  template <typename T>
+  T item() const;
+
+  template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
+  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const & {
+    return generic_packed_accessor<T,N,PtrTraits,index_t>();
+  }
+  template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
+  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() && = delete;
+
+  Tensor operator~() const {
+    return bitwise_not();
+  }
+  Tensor operator-() const {
+    return neg();
+  }
+  Tensor& operator+=(const Tensor & other) {
+    return add_(other);
+  }
+  Tensor& operator+=(const Scalar & other) {
+    return add_(other);
+  }
+  Tensor& operator-=(const Tensor & other) {
+    return sub_(other);
+  }
+  Tensor& operator-=(const Scalar & other) {
+    return sub_(other);
+  }
+  Tensor& operator*=(const Tensor & other) {
+    return mul_(other);
+  }
+  Tensor& operator*=(const Scalar & other) {
+    return mul_(other);
+  }
+  Tensor& operator/=(const Tensor & other) {
+    return div_(other);
+  }
+  Tensor& operator/=(const Scalar & other) {
+    return div_(other);
+  }
+  Tensor& operator&=(const Tensor & other) {
+    return bitwise_and_(other);
+  }
+  Tensor& operator|=(const Tensor & other) {
+    return bitwise_or_(other);
+  }
+  Tensor& operator^=(const Tensor & other) {
+    return bitwise_xor_(other);
+  }
+  Tensor operator[](const Scalar & index) const {
+    if (!index.isIntegral(false)) {
+      TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars");
+    }
+    return this->operator[](index.toLong());
+  }
+  Tensor operator[](const Tensor & index) const {
+    // These properties are checked in the Scalar constructor, but we already
+    // check them here to provide more useful diagnostics for the user.
+    if (!index.defined()) {
+      TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined");
+    }
+    if (index.dim() != 0) {
+      TORCH_CHECK_INDEX(false,
+                        "Can only index with tensors that are scalars (zero-dim)");
+    }
+    // The Scalar(Tensor) constructor is explicit, so we need to call it.
+    return this->operator[](index.item<int64_t>());
+  }
+  Tensor operator[](int64_t index) const {
+    return select(0, index);
+  }
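+
+  // A minimal usage sketch (illustrative only, for some Tensor `t`): all three
+  // overloads end up in select(0, index), so these are equivalent ways to take
+  // the first row.
+  //   auto row0 = t[0];
+  //   auto row0_again = t.select(0, 0);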
+
+  Tensor index(ArrayRef<at::indexing::TensorIndex> indices) const;
+  Tensor index(std::initializer_list<at::indexing::TensorIndex> indices) const;
+
+  Tensor & index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs);
+  Tensor & index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v);
+  Tensor & index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs);
+  Tensor & index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v);
+
+  Tensor cpu() const {
+    return to(options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // TODO: The Python version also accepts arguments
+  Tensor cuda() const {
+    return to(options().device(c10::DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor hip() const {
+    return to(options().device(c10::DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor ve() const {
+    return to(options().device(c10::DeviceType::VE), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor vulkan() const {
+    return to(options().device(c10::DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor metal() const {
+    return to(options().device(c10::DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor meta() const {
+    return to(options().device(c10::DeviceType::Meta), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // ~~~~~ Autograd API ~~~~~
+
+  /// \fn bool is_leaf() const;
+  ///
+  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
+  ///
+  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
+  /// created by the user. This means that they are not the result of an operation and so
+  /// `grad_fn()` is `nullptr`.
+  ///
+  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
+  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
+  ///
+  /// Example:
+  /// @code
+  /// auto a = torch::rand(10, torch::requires_grad());
+  /// std::cout << a.is_leaf() << std::endl; // prints `true`
+  ///
+  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
+  /// std::cout << b.is_leaf() << std::endl; // prints `false`
+  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
+  ///
+  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
+  /// std::cout << c.is_leaf() << std::endl; // prints `false`
+  /// // c was created by the addition operation
+  ///
+  /// auto d = torch::rand(10).cuda();
+  /// std::cout << d.is_leaf() << std::endl; // prints `true`
+  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
+  ///
+  /// auto e = torch::rand(10).cuda().requires_grad_();
+  /// std::cout << e.is_leaf() << std::endl; // prints `true`
+  /// // e requires gradients and has no operations creating it
+  ///
+  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
+  /// std::cout << f.is_leaf() << std::endl; // prints `true`
+  /// // f requires grad, has no operation creating it
+  /// @endcode
+
+  /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const;
+  ///
+  /// Computes the gradient of current tensor with respect to graph leaves.
+  ///
+  /// The graph is differentiated using the chain rule. If the tensor is
+  /// non-scalar (i.e. its data has more than one element) and requires
+  /// gradient, the function additionally requires specifying ``gradient``.
+  /// It should be a tensor of matching type and location, that contains
+  /// the gradient of the differentiated function w.r.t. this Tensor.
+  ///
+  /// This function accumulates gradients in the leaves - you might need to
+  /// zero them before calling it.
+  ///
+  /// \param gradient Gradient w.r.t. the
+  ///     tensor. If it is a tensor, it will be automatically converted
+  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
+  ///     None values can be specified for scalar Tensors or ones that
+  ///     don't require grad. If a None value would be acceptable then
+  ///     this argument is optional.
+  /// \param retain_graph If ``false``, the graph used to compute
+  ///     the grads will be freed. Note that in nearly all cases setting
+  ///     this option to True is not needed and often can be worked around
+  ///     in a much more efficient way. Defaults to the value of
+  ///     ``create_graph``.
+  /// \param create_graph If ``true``, the graph of the derivative will
+  ///     be constructed, allowing higher order derivative
+  ///     products to be computed. Defaults to ``false``.
+  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
+  ///     provided, the gradient is accumulated into all the leaf Tensors
+  ///     that were used to compute the current tensor.
+  ///     When inputs are provided and a given input is not a leaf,
+  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
+  ///     It is an implementation detail on which the user should not rely.
+  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
+  void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const {
+    // NB: Adding this wrapper to _backward here because we'd like our
+    // 'backwards' api to accept the 'inputs' argument optionally. Since code gen
+    // currently does not support optional of TensorList our approach is to replace
+    // backward in native_functions.yaml with _backward and call it here instead.
+    if (inputs.has_value()) {
+      TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty")
+      this->_backward(inputs.value(), gradient, retain_graph, create_graph);
+    } else {
+      this->_backward({}, gradient, retain_graph, create_graph);
+    }
+  }
+
+  /// \fn Tensor detach() const;
+  ///
+  /// Returns a new Tensor, detached from the current graph.
+  /// The result will never require gradient.
+
+  /// \fn Tensor & detach_() const;
+  ///
+  /// Detaches the Tensor from the graph that created it, making it a leaf.
+  /// Views cannot be detached in-place.
+
+  /// \fn void retain_grad() const;
+  ///
+  /// Enables this Tensor to have its :attr:`grad` populated during
+  /// :func:`backward`. This is a no-op for leaf tensors.
+
+  /// \fn bool retains_grad() const;
+  ///
+  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+  /// populated during :func:`backward`, ``false`` otherwise.
+
+  const Tensor& set_requires_grad(bool requires_grad) const {
+    TensorBase::set_requires_grad(requires_grad);
+    return *this;
+  }
+
+  /// Return a mutable reference to the gradient. This is conventionally
+  /// used as `t.grad() = x` to set a gradient to a completely new tensor.
+  /// Note that this function works with a non-const Tensor and is not
+  /// thread safe.
+  Tensor& mutable_grad() const {
+    return impl_->mutable_grad();
+  }
+
+  /// This function returns an undefined tensor by default and returns a defined tensor
+  /// the first time a call to `backward()` computes gradients for this Tensor.
+  /// The attribute will then contain the gradients computed and future calls
+  /// to `backward()` will accumulate (add) gradients into it.
+  const Tensor& grad() const {
+    const Tensor& maybe_grad = impl_->grad();
+    if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) {
+      TORCH_WARN(
+        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
+        "attribute won't be populated during autograd.backward(). If you indeed want the .grad "
+        "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. "
+        "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor "
+        "instead. See github.com/pytorch/pytorch/pull/30531 for more informations.");
+    }
+    return maybe_grad;
+  }
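+
+  // A minimal usage sketch (illustrative only): grad() stays undefined until
+  // backward() has populated it for a leaf Tensor.
+  //   auto w = torch::ones({3}, torch::requires_grad());
+  //   (w * 2).sum().backward();
+  //   std::cout << w.grad() << std::endl;  // a tensor of 2s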
+
+  // The Forward AD API functions below are low level and are not to be used by end
+  // users who should use the API provided in torch/csrc/autograd.h
+
+  /// This function returns the forward gradient for this Tensor at the given level.
+  const Tensor& _fw_grad(uint64_t level) const {
+    return impl_->_fw_grad(level, *this);
+  }
+
+  /// This function can be used to set the value of the forward grad.
+  /// Note that the given new_grad might not be used directly if it has different
+  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
+  /// new_grad content will be copied into a new Tensor
+  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
+    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
+  }
+
+
+  // STOP.  Thinking of adding a method here, which only makes use
+  // of other ATen methods?  Define it in native_functions.yaml.
+
+  //example
+  //Tensor * add(Tensor & b);
+  void __dispatch__backward(at::TensorList inputs, const c10::optional & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false) const;
+  void __dispatch_set_data(const at::Tensor & new_data) const;
+  at::Tensor __dispatch_data() const;
+  bool __dispatch_is_leaf() const;
+  int64_t __dispatch_output_nr() const;
+  int64_t __dispatch__version() const;
+  at::Tensor & __dispatch_requires_grad_(bool requires_grad=true) const;
+  void __dispatch_retain_grad() const;
+  bool __dispatch_retains_grad() const;
+  at::Tensor _fw_primal(int64_t level) const;
+  at::Tensor & rename_(c10::optional names) const;
+  at::Tensor rename(c10::optional names) const;
+  at::Tensor align_to(at::DimnameList names) const;
+  at::Tensor align_to(at::DimnameList order, int64_t ellipsis_idx) const;
+  at::Tensor align_as(const at::Tensor & other) const;
+  at::Tensor refine_names(at::DimnameList names) const;
+  at::Tensor abs() const;
+  at::Tensor & abs_() const;
+  at::Tensor absolute() const;
+  at::Tensor & absolute_() const;
+  at::Tensor angle() const;
+  at::Tensor sgn() const;
+  at::Tensor & sgn_() const;
+  at::Tensor chalf(c10::optional memory_format=c10::nullopt) const;
+  at::Tensor _conj() const;
+  at::Tensor __dispatch_conj() const;
+  at::Tensor _conj_physical() const;
+  at::Tensor conj_physical() const;
+  at::Tensor & conj_physical_() const;
+  at::Tensor resolve_conj() const;
+  at::Tensor resolve_neg() const;
+  at::Tensor _neg_view() const;
+  at::Tensor acos() const;
+  at::Tensor & acos_() const;
+  at::Tensor arccos() const;
+  at::Tensor & arccos_() const;
+  at::Tensor add(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor & add_(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor add(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor & add_(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor addmv(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor & addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor & addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor _is_all_true() const;
+  at::Tensor _is_any_true() const;
+  at::Tensor all(int64_t dim, bool keepdim=false) const;
+  at::Tensor all(at::OptionalIntArrayRef dim, bool keepdim=false) const;
+  at::Tensor all(at::Dimname dim, bool keepdim=false) const;
+  bool allclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
+  at::Tensor any(int64_t dim, bool keepdim=false) const;
+  at::Tensor any(at::OptionalIntArrayRef dim, bool keepdim=false) const;
+  at::Tensor any(at::Dimname dim, bool keepdim=false) const;
+  at::Tensor argmax(c10::optional dim=c10::nullopt, bool keepdim=false) const;
+  at::Tensor argmin(c10::optional dim=c10::nullopt, bool keepdim=false) const;
+  at::Tensor acosh() const;
+  at::Tensor & acosh_() const;
+  at::Tensor arccosh() const;
+  at::Tensor & arccosh_() const;
+  at::Tensor asinh() const;
+  at::Tensor & asinh_() const;
+  at::Tensor arcsinh() const;
+  at::Tensor & arcsinh_() const;
+  at::Tensor atanh() const;
+  at::Tensor & atanh_() const;
+  at::Tensor arctanh() const;
+  at::Tensor & arctanh_() const;
+  at::Tensor as_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  at::Tensor as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  const at::Tensor & as_strided_(at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  const at::Tensor & as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  at::Tensor asin() const;
+  at::Tensor & asin_() const;
+  at::Tensor arcsin() const;
+  at::Tensor & arcsin_() const;
+  at::Tensor atan() const;
+  at::Tensor & atan_() const;
+  at::Tensor arctan() const;
+  at::Tensor & arctan_() const;
+  at::Tensor baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor & baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor bernoulli(c10::optional generator=c10::nullopt) const;
+  at::Tensor & bernoulli_(const at::Tensor & p, c10::optional generator=c10::nullopt) const;
+  at::Tensor & bernoulli_(double p=0.5, c10::optional generator=c10::nullopt) const;
+  at::Tensor bernoulli(double p, c10::optional generator=c10::nullopt) const;
+  at::Tensor bincount(const c10::optional & weights={}, int64_t minlength=0) const;
+  at::Tensor bitwise_not() const;
+  at::Tensor & bitwise_not_() const;
+  at::Tensor copysign(const at::Tensor & other) const;
+  at::Tensor & copysign_(const at::Tensor & other) const;
+  at::Tensor copysign(const at::Scalar & other) const;
+  at::Tensor & copysign_(const at::Scalar & other) const;
+  at::Tensor _lazy_clone() const;
+  at::Tensor logical_not() const;
+  at::Tensor & logical_not_() const;
+  at::Tensor logical_xor(const at::Tensor & other) const;
+  at::Tensor & logical_xor_(const at::Tensor & other) const;
+  at::Tensor logical_and(const at::Tensor & other) const;
+  at::Tensor & logical_and_(const at::Tensor & other) const;
+  at::Tensor logical_or(const at::Tensor & other) const;
+  at::Tensor & logical_or_(const at::Tensor & other) const;
+  at::Tensor bmm(const at::Tensor & mat2) const;
+  at::Tensor broadcast_to(at::IntArrayRef size) const;
+  at::Tensor broadcast_to_symint(c10::SymIntArrayRef size) const;
+  at::Tensor ceil() const;
+  at::Tensor & ceil_() const;
+  ::std::vector unsafe_chunk(int64_t chunks, int64_t dim=0) const;
+  ::std::vector chunk(int64_t chunks, int64_t dim=0) const;
+  ::std::vector tensor_split(int64_t sections, int64_t dim=0) const;
+  ::std::vector tensor_split_symint(c10::SymInt sections, int64_t dim=0) const;
+  ::std::vector tensor_split(at::IntArrayRef indices, int64_t dim=0) const;
+  ::std::vector tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim=0) const;
+  ::std::vector tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim=0) const;
+  at::Tensor clamp(const c10::optional & min, const c10::optional & max=c10::nullopt) const;
+  at::Tensor clamp(const c10::optional & min={}, const c10::optional & max={}) const;
+  at::Tensor & clamp_(const c10::optional & min, const c10::optional & max=c10::nullopt) const;
+  at::Tensor & clamp_(const c10::optional & min={}, const c10::optional & max={}) const;
+  at::Tensor clamp_max(const at::Scalar & max) const;
+  at::Tensor clamp_max(const at::Tensor & max) const;
+  at::Tensor & clamp_max_(const at::Scalar & max) const;
+  at::Tensor & clamp_max_(const at::Tensor & max) const;
+  at::Tensor clamp_min(const at::Scalar & min) const;
+  at::Tensor clamp_min(const at::Tensor & min) const;
+  at::Tensor & clamp_min_(const at::Scalar & min) const;
+  at::Tensor & clamp_min_(const at::Tensor & min) const;
+  at::Tensor clip(const c10::optional & min, const c10::optional & max=c10::nullopt) const;
+  at::Tensor clip(const c10::optional & min={}, const c10::optional & max={}) const;
+  at::Tensor & clip_(const c10::optional & min, const c10::optional & max=c10::nullopt) const;
+  at::Tensor & clip_(const c10::optional & min={}, const c10::optional & max={}) const;
+  at::Tensor __dispatch_contiguous(at::MemoryFormat memory_format=MemoryFormat::Contiguous) const;
+  at::Tensor & copy_(const at::Tensor & src, bool non_blocking=false) const;
+  at::Tensor cos() const;
+  at::Tensor & cos_() const;
+  at::Tensor cosh() const;
+  at::Tensor & cosh_() const;
+  at::Tensor count_nonzero(at::IntArrayRef dim) const;
+  at::Tensor count_nonzero(c10::optional dim=c10::nullopt) const;
+  at::Tensor cov(int64_t correction=1, const c10::optional & fweights={}, const c10::optional & aweights={}) const;
+  at::Tensor corrcoef() const;
+  ::std::tuple cummax(int64_t dim) const;
+  ::std::tuple cummax(at::Dimname dim) const;
+  ::std::tuple cummin(int64_t dim) const;
+  ::std::tuple cummin(at::Dimname dim) const;
+  at::Tensor cumprod(int64_t dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor & cumprod_(int64_t dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor cumprod(at::Dimname dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor & cumprod_(at::Dimname dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor cumsum(int64_t dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor & cumsum_(int64_t dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor cumsum(at::Dimname dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor & cumsum_(at::Dimname dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const;
+  at::Tensor diagflat(int64_t offset=0) const;
+  at::Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const;
+  at::Tensor diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) const;
+  at::Tensor & fill_diagonal_(const at::Scalar & fill_value, bool wrap=false) const;
+  at::Tensor diff(int64_t n=1, int64_t dim=-1, const c10::optional & prepend={}, const c10::optional & append={}) const;
+  at::Tensor div(const at::Tensor & other) const;
+  at::Tensor & div_(const at::Tensor & other) const;
+  at::Tensor div(const at::Tensor & other, c10::optional rounding_mode) const;
+  at::Tensor & div_(const at::Tensor & other, c10::optional rounding_mode) const;
+  at::Tensor div(const at::Scalar & other) const;
+  at::Tensor & div_(const at::Scalar & other) const;
+  at::Tensor div(const at::Scalar & other, c10::optional rounding_mode) const;
+  at::Tensor & div_(const at::Scalar & other, c10::optional rounding_mode) const;
+  at::Tensor divide(const at::Tensor & other) const;
+  at::Tensor & divide_(const at::Tensor & other) const;
+  at::Tensor divide(const at::Scalar & other) const;
+  at::Tensor & divide_(const at::Scalar & other) const;
+  at::Tensor divide(const at::Tensor & other, c10::optional rounding_mode) const;
+  at::Tensor & divide_(const at::Tensor & other, c10::optional rounding_mode) const;
+  at::Tensor divide(const at::Scalar & other, c10::optional rounding_mode) const;
+  at::Tensor & divide_(const at::Scalar & other, c10::optional rounding_mode) const;
+  at::Tensor true_divide(const at::Tensor & other) const;
+  at::Tensor & true_divide_(const at::Tensor & other) const;
+  at::Tensor true_divide(const at::Scalar & other) const;
+  at::Tensor & true_divide_(const at::Scalar & other) const;
+  at::Tensor dot(const at::Tensor & tensor) const;
+  at::Tensor vdot(const at::Tensor & other) const;
+  at::Tensor new_empty(at::IntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_empty(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_empty_symint(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) const;
+  at::Tensor new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) const;
+  at::Tensor new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const;
+  at::Tensor new_full(at::IntArrayRef size, const at::Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) const;
+  at::Tensor new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_zeros(at::IntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_zeros(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_zeros_symint(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_ones(at::IntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_ones(at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  at::Tensor new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options={}) const;
+  at::Tensor new_ones_symint(c10::SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) const;
+  const at::Tensor & resize_(at::IntArrayRef size, c10::optional memory_format=c10::nullopt) const;
+  const at::Tensor & resize__symint(c10::SymIntArrayRef size, c10::optional memory_format=c10::nullopt) const;
+  at::Tensor erf() const;
+  at::Tensor & erf_() const;
+  at::Tensor erfc() const;
+  at::Tensor & erfc_() const;
+  at::Tensor exp() const;
+  at::Tensor & exp_() const;
+  at::Tensor exp2() const;
+  at::Tensor & exp2_() const;
+  at::Tensor expm1() const;
+  at::Tensor & expm1_() const;
+  at::Tensor expand(at::IntArrayRef size, bool implicit=false) const;
+  at::Tensor expand_symint(c10::SymIntArrayRef size, bool implicit=false) const;
+  at::Tensor expand_as(const at::Tensor & other) const;
+  at::Tensor flatten(int64_t start_dim=0, int64_t end_dim=-1) const;
+  at::Tensor flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const;
+  at::Tensor flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const;
+  at::Tensor flatten(at::DimnameList dims, at::Dimname out_dim) const;
+  at::Tensor unflatten(int64_t dim, at::IntArrayRef sizes) const;
+  at::Tensor unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const;
+  at::Tensor unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const;
+  at::Tensor unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const;
+  at::Tensor & fill_(const at::Scalar & value) const;
+  at::Tensor & fill_(const at::Tensor & value) const;
+  at::Tensor floor() const;
+  at::Tensor & floor_() const;
+  at::Tensor floor_divide(const at::Tensor & other) const;
+  at::Tensor & floor_divide_(const at::Tensor & other) const;
+  at::Tensor floor_divide(const at::Scalar & other) const;
+  at::Tensor & floor_divide_(const at::Scalar & other) const;
+  at::Tensor frac() const;
+  at::Tensor & frac_() const;
+  at::Tensor gcd(const at::Tensor & other) const;
+  at::Tensor & gcd_(const at::Tensor & other) const;
+  at::Tensor lcm(const at::Tensor & other) const;
+  at::Tensor & lcm_(const at::Tensor & other) const;
+  at::Tensor index(const c10::List> & indices) const;
+  at::Tensor & index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const;
+  at::Tensor index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const;
+  at::Tensor & index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const;
+  at::Tensor index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const;
+  at::Tensor & index_put_(const c10::List> & indices, const at::Tensor & values, bool accumulate=false) const;
+  at::Tensor index_put(const c10::List> & indices, const at::Tensor & values, bool accumulate=false) const;
+  at::Tensor isclose(const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
+  at::Tensor isnan() const;
+  bool is_distributed() const;
+  bool __dispatch_is_floating_point() const;
+  bool __dispatch_is_complex() const;
+  bool __dispatch_is_conj() const;
+  bool __dispatch__is_zerotensor() const;
+  bool __dispatch_is_neg() const;
+  at::Tensor isreal() const;
+  bool is_nonzero() const;
+  bool is_same_size(const at::Tensor & other) const;
+  bool __dispatch_is_signed() const;
+  bool __dispatch_is_inference() const;
+  at::Tensor kron(const at::Tensor & other) const;
+  ::std::tuple kthvalue(int64_t k, int64_t dim=-1, bool keepdim=false) const;
+  ::std::tuple kthvalue(int64_t k, at::Dimname dim, bool keepdim=false) const;
+  at::Tensor nan_to_num(c10::optional nan=c10::nullopt, c10::optional posinf=c10::nullopt, c10::optional neginf=c10::nullopt) const;
+  at::Tensor & nan_to_num_(c10::optional nan=c10::nullopt, c10::optional posinf=c10::nullopt, c10::optional neginf=c10::nullopt) const;
+  at::Tensor ldexp(const at::Tensor & other) const;
+  at::Tensor & ldexp_(const at::Tensor & other) const;
+  at::Tensor log() const;
+  at::Tensor & log_() const;
+  at::Tensor log10() const;
+  at::Tensor & log10_() const;
+  at::Tensor log1p() const;
+  at::Tensor & log1p_() const;
+  at::Tensor log2() const;
+  at::Tensor & log2_() const;
+  at::Tensor logaddexp(const at::Tensor & other) const;
+  at::Tensor logaddexp2(const at::Tensor & other) const;
+  at::Tensor xlogy(const at::Tensor & other) const;
+  at::Tensor xlogy(const at::Scalar & other) const;
+  at::Tensor & xlogy_(const at::Tensor & other) const;
+  at::Tensor & xlogy_(const at::Scalar & other) const;
+  at::Tensor log_softmax(int64_t dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor log_softmax(at::Dimname dim, c10::optional dtype=c10::nullopt) const;
+  at::Tensor logcumsumexp(int64_t dim) const;
+  at::Tensor logcumsumexp(at::Dimname dim) const;
+  at::Tensor logsumexp(at::IntArrayRef dim, bool keepdim=false) const;
+  at::Tensor logsumexp(at::DimnameList dim, bool keepdim=false) const;
+  at::Tensor matmul(const at::Tensor & other) const;
+  at::Tensor matrix_power(int64_t n) const;
+  at::Tensor matrix_exp() const;
+  ::std::tuple aminmax(c10::optional dim=c10::nullopt, bool keepdim=false) const;
+  ::std::tuple max(int64_t dim, bool keepdim=false) const;
+  ::std::tuple max(at::Dimname dim, bool keepdim=false) const;
+  at::Tensor amax(at::IntArrayRef dim={}, bool keepdim=false) const;
+  at::Tensor mean(c10::optional dtype=c10::nullopt) const;
+  at::Tensor mean(at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor mean(at::DimnameList dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor nanmean(at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor median() const;
+  ::std::tuple median(int64_t dim, bool keepdim=false) const;
+  ::std::tuple median(at::Dimname dim, bool keepdim=false) const;
+  at::Tensor nanmedian() const;
+  ::std::tuple nanmedian(int64_t dim, bool keepdim=false) const;
+  ::std::tuple nanmedian(at::Dimname dim, bool keepdim=false) const;
+  ::std::tuple min(int64_t dim, bool keepdim=false) const;
+  ::std::tuple min(at::Dimname dim, bool keepdim=false) const;
+  at::Tensor amin(at::IntArrayRef dim={}, bool keepdim=false) const;
+  at::Tensor mm(const at::Tensor & mat2) const;
+  ::std::tuple mode(int64_t dim=-1, bool keepdim=false) const;
+  ::std::tuple mode(at::Dimname dim, bool keepdim=false) const;
+  at::Tensor mul(const at::Tensor & other) const;
+  at::Tensor & mul_(const at::Tensor & other) const;
+  at::Tensor mul(const at::Scalar & other) const;
+  at::Tensor & mul_(const at::Scalar & other) const;
+  at::Tensor multiply(const at::Tensor & other) const;
+  at::Tensor & multiply_(const at::Tensor & other) const;
+  at::Tensor multiply(const at::Scalar & other) const;
+  at::Tensor & multiply_(const at::Scalar & other) const;
+  at::Tensor mv(const at::Tensor & vec) const;
+  at::Tensor mvlgamma(int64_t p) const;
+  at::Tensor & mvlgamma_(int64_t p) const;
+  at::Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const;
+  at::Tensor narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const;
+  at::Tensor narrow(int64_t dim, int64_t start, int64_t length) const;
+  at::Tensor narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const;
+  at::Tensor narrow(int64_t dim, const at::Tensor & start, int64_t length) const;
+  at::Tensor narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const;
+  at::Tensor permute(at::IntArrayRef dims) const;
+  at::Tensor movedim(at::IntArrayRef source, at::IntArrayRef destination) const;
+  at::Tensor movedim(int64_t source, int64_t destination) const;
+  at::Tensor moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const;
+  at::Tensor moveaxis(int64_t source, int64_t destination) const;
+  at::Tensor numpy_T() const;
+  at::Tensor matrix_H() const;
+  at::Tensor mT() const;
+  at::Tensor mH() const;
+  at::Tensor adjoint() const;
+  bool is_pinned(c10::optional device=c10::nullopt) const;
+  at::Tensor pin_memory(c10::optional device=c10::nullopt) const;
+  at::Tensor pinverse(double rcond=1e-15) const;
+  at::Tensor rad2deg() const;
+  at::Tensor & rad2deg_() const;
+  at::Tensor deg2rad() const;
+  at::Tensor & deg2rad_() const;
+  at::Tensor ravel() const;
+  at::Tensor reciprocal() const;
+  at::Tensor & reciprocal_() const;
+  at::Tensor neg() const;
+  at::Tensor & neg_() const;
+  at::Tensor negative() const;
+  at::Tensor & negative_() const;
+  at::Tensor repeat(at::IntArrayRef repeats) const;
+  at::Tensor repeat_symint(c10::SymIntArrayRef repeats) const;
+  at::Tensor repeat_interleave(const at::Tensor & repeats, c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) const;
+  at::Tensor repeat_interleave_symint(const at::Tensor & repeats, c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) const;
+  at::Tensor repeat_interleave(int64_t repeats, c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) const;
+  at::Tensor repeat_interleave_symint(c10::SymInt repeats, c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) const;
+  at::Tensor reshape(at::IntArrayRef shape) const;
+  at::Tensor reshape_symint(c10::SymIntArrayRef shape) const;
+  at::Tensor _reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const;
+  at::Tensor _reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const;
+  at::Tensor reshape_as(const at::Tensor & other) const;
+  at::Tensor round() const;
+  at::Tensor & round_() const;
+  at::Tensor round(int64_t decimals) const;
+  at::Tensor & round_(int64_t decimals) const;
+  at::Tensor relu() const;
+  at::Tensor & relu_() const;
+  at::Tensor prelu(const at::Tensor & weight) const;
+  at::Tensor hardshrink(const at::Scalar & lambd=0.5) const;
+  at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const;
+  at::Tensor rsqrt() const;
+  at::Tensor & rsqrt_() const;
+  at::Tensor select(at::Dimname dim, int64_t index) const;
+  at::Tensor select(int64_t dim, int64_t index) const;
+  at::Tensor select_symint(int64_t dim, c10::SymInt index) const;
+  at::Tensor sigmoid() const;
+  at::Tensor & sigmoid_() const;
+  at::Tensor logit(c10::optional eps=c10::nullopt) const;
+  at::Tensor & logit_(c10::optional eps=c10::nullopt) const;
+  at::Tensor sin() const;
+  at::Tensor & sin_() const;
+  at::Tensor sinc() const;
+  at::Tensor & sinc_() const;
+  at::Tensor sinh() const;
+  at::Tensor & sinh_() const;
+  at::Tensor detach() const;
+  at::Tensor & detach_() const;
+  int64_t size(at::Dimname dim) const;
+  at::Tensor slice(int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) const;
+  at::Tensor slice_symint(int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) const;
+  at::Tensor slice_inverse(const at::Tensor & src, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) const;
+  at::Tensor slice_inverse_symint(const at::Tensor & src, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) const;
+  at::Tensor slice_scatter(const at::Tensor & src, int64_t dim=0, c10::optional<int64_t> start=c10::nullopt, c10::optional<int64_t> end=c10::nullopt, int64_t step=1) const;
+  at::Tensor slice_scatter_symint(const at::Tensor & src, int64_t dim=0, c10::optional<c10::SymInt> start=c10::nullopt, c10::optional<c10::SymInt> end=c10::nullopt, c10::SymInt step=1) const;
+  at::Tensor select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const;
+  at::Tensor select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const;
+  at::Tensor diagonal_scatter(const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const;
+  at::Tensor as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  at::Tensor as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional storage_offset=c10::nullopt) const;
+  at::Tensor smm(const at::Tensor & mat2) const;
+  at::Tensor softmax(int64_t dim, c10::optional<at::ScalarType> dtype=c10::nullopt) const;
+  at::Tensor softmax(at::Dimname dim, c10::optional<at::ScalarType> dtype=c10::nullopt) const;
+  ::std::vector<at::Tensor> unsafe_split(int64_t split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> unsafe_split_symint(c10::SymInt split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split(int64_t split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split_symint(c10::SymInt split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split(at::IntArrayRef split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split_symint(c10::SymIntArrayRef split_size, int64_t dim=0) const;
+  ::std::vector<at::Tensor> unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const;
+  ::std::vector<at::Tensor> unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split_with_sizes(at::IntArrayRef split_sizes, int64_t dim=0) const;
+  ::std::vector<at::Tensor> split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim=0) const;
+  ::std::vector<at::Tensor> hsplit(int64_t sections) const;
+  ::std::vector<at::Tensor> hsplit(at::IntArrayRef indices) const;
+  ::std::vector<at::Tensor> vsplit(int64_t sections) const;
+  ::std::vector<at::Tensor> vsplit(at::IntArrayRef indices) const;
+  ::std::vector<at::Tensor> dsplit(int64_t sections) const;
+  ::std::vector<at::Tensor> dsplit(at::IntArrayRef indices) const;
+  at::Tensor squeeze() const;
+  at::Tensor squeeze(int64_t dim) const;
+  at::Tensor squeeze(at::Dimname dim) const;
+  at::Tensor squeeze(at::IntArrayRef dim) const;
+  at::Tensor & squeeze_() const;
+  at::Tensor & squeeze_(int64_t dim) const;
+  at::Tensor & squeeze_(at::IntArrayRef dim) const;
+  at::Tensor & squeeze_(at::Dimname dim) const;
+  at::Tensor sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor stft(int64_t n_fft, c10::optional hop_length, c10::optional win_length, const c10::optional & window, bool normalized, c10::optional onesided=c10::nullopt, c10::optional return_complex=c10::nullopt) const;
+  at::Tensor stft(int64_t n_fft, c10::optional hop_length=c10::nullopt, c10::optional win_length=c10::nullopt, const c10::optional & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, c10::optional onesided=c10::nullopt, c10::optional return_complex=c10::nullopt) const;
+  at::Tensor istft(int64_t n_fft, c10::optional hop_length=c10::nullopt, c10::optional win_length=c10::nullopt, const c10::optional & window={}, bool center=true, bool normalized=false, c10::optional onesided=c10::nullopt, c10::optional length=c10::nullopt, bool return_complex=false) const;
+  int64_t stride(at::Dimname dim) const;
+  at::Tensor sum(c10::optional dtype=c10::nullopt) const;
+  at::Tensor sum(at::OptionalIntArrayRef dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor sum(at::DimnameList dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor nansum(at::OptionalIntArrayRef dim=c10::nullopt, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor sum_to_size(at::IntArrayRef size) const;
+  at::Tensor sum_to_size_symint(c10::SymIntArrayRef size) const;
+  at::Tensor sqrt() const;
+  at::Tensor & sqrt_() const;
+  at::Tensor square() const;
+  at::Tensor & square_() const;
+  at::Tensor std(bool unbiased) const;
+  at::Tensor std(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) const;
+  at::Tensor std(at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional & correction=c10::nullopt, bool keepdim=false) const;
+  at::Tensor std(at::DimnameList dim, bool unbiased, bool keepdim=false) const;
+  at::Tensor std(at::DimnameList dim, const c10::optional & correction=c10::nullopt, bool keepdim=false) const;
+  at::Tensor prod(c10::optional dtype=c10::nullopt) const;
+  at::Tensor prod(int64_t dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor prod(at::Dimname dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const;
+  at::Tensor t() const;
+  at::Tensor & t_() const;
+  at::Tensor tan() const;
+  at::Tensor & tan_() const;
+  at::Tensor tanh() const;
+  at::Tensor & tanh_() const;
+  at::Tensor tile(at::IntArrayRef dims) const;
+  at::Tensor tile_symint(c10::SymIntArrayRef dims) const;
+  at::Tensor transpose(int64_t dim0, int64_t dim1) const;
+  at::Tensor transpose(at::Dimname dim0, at::Dimname dim1) const;
+  at::Tensor & transpose_(int64_t dim0, int64_t dim1) const;
+  at::Tensor flip(at::IntArrayRef dims) const;
+  at::Tensor fliplr() const;
+  at::Tensor flipud() const;
+  at::Tensor roll(at::IntArrayRef shifts, at::IntArrayRef dims={}) const;
+  at::Tensor roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) const;
+  at::Tensor rot90(int64_t k=1, at::IntArrayRef dims={0,1}) const;
+  at::Tensor _nested_tensor_size() const;
+  at::Tensor _nested_tensor_strides() const;
+  at::Tensor _nested_tensor_storage_offsets() const;
+  at::Tensor trunc() const;
+  at::Tensor & trunc_() const;
+  at::Tensor fix() const;
+  at::Tensor & fix_() const;
+  at::Tensor type_as(const at::Tensor & other) const;
+  at::Tensor unsqueeze(int64_t dim) const;
+  at::Tensor & unsqueeze_(int64_t dim) const;
+  at::Tensor var(bool unbiased) const;
+  at::Tensor var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) const;
+  at::Tensor var(at::OptionalIntArrayRef dim=c10::nullopt, const c10::optional & correction=c10::nullopt, bool keepdim=false) const;
+  at::Tensor var(at::DimnameList dim, bool unbiased, bool keepdim=false) const;
+  at::Tensor var(at::DimnameList dim, const c10::optional & correction=c10::nullopt, bool keepdim=false) const;
+  at::Tensor view_as(const at::Tensor & other) const;
+  at::Tensor where(const at::Tensor & condition, const at::Tensor & other) const;
+  at::Tensor where(const at::Tensor & condition, const at::Scalar & other) const;
+  at::Tensor norm(const c10::optional & p, at::ScalarType dtype) const;
+  at::Tensor norm(const at::Scalar & p=2) const;
+  at::Tensor norm(const c10::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const;
+  at::Tensor norm(const c10::optional & p, at::IntArrayRef dim, bool keepdim=false) const;
+  at::Tensor norm(const c10::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const;
+  at::Tensor norm(const c10::optional & p, at::DimnameList dim, bool keepdim=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> frexp() const;
+  at::Tensor clone(c10::optional<at::MemoryFormat> memory_format=c10::nullopt) const;
+  at::Tensor positive() const;
+  const at::Tensor & resize_as_(const at::Tensor & the_template, c10::optional memory_format=c10::nullopt) const;
+  const at::Tensor & resize_as_sparse_(const at::Tensor & the_template) const;
+  at::Tensor & zero_() const;
+  at::Tensor sub(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor & sub_(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor sub(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor & sub_(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor subtract(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor & subtract_(const at::Tensor & other, const at::Scalar & alpha=1) const;
+  at::Tensor subtract(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor & subtract_(const at::Scalar & other, const at::Scalar & alpha=1) const;
+  at::Tensor heaviside(const at::Tensor & values) const;
+  at::Tensor & heaviside_(const at::Tensor & values) const;
+  at::Tensor addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor & addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor _addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) const;
+  const at::Tensor & sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const;
+  const at::Tensor & sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const;
+  at::Tensor sparse_mask(const at::Tensor & mask) const;
+  at::Tensor _sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches=false) const;
+  at::Tensor to_dense(c10::optional dtype=c10::nullopt, c10::optional masked_grad=c10::nullopt) const;
+  at::Tensor _to_dense(c10::optional dtype=c10::nullopt, c10::optional masked_grad=c10::nullopt) const;
+  int64_t sparse_dim() const;
+  int64_t _dimI() const;
+  int64_t dense_dim() const;
+  int64_t _dimV() const;
+  int64_t _nnz() const;
+  at::Tensor coalesce() const;
+  bool is_coalesced() const;
+  at::Tensor _indices() const;
+  at::Tensor _values() const;
+  at::Tensor & _coalesced_(bool coalesced) const;
+  at::Tensor indices() const;
+  at::Tensor values() const;
+  at::Tensor crow_indices() const;
+  at::Tensor col_indices() const;
+  at::Tensor ccol_indices() const;
+  at::Tensor row_indices() const;
+  ::std::vector<at::Tensor> unbind(int64_t dim=0) const;
+  ::std::vector<at::Tensor> unbind(at::Dimname dim) const;
+  at::Tensor to_sparse(int64_t sparse_dim) const;
+  at::Tensor _to_sparse(int64_t sparse_dim) const;
+  at::Tensor to_sparse(c10::optional layout=c10::nullopt, at::OptionalIntArrayRef blocksize=c10::nullopt, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor _to_sparse(c10::optional layout=c10::nullopt, at::OptionalIntArrayRef blocksize=c10::nullopt, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor to_sparse_csr(c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor _to_sparse_csr(c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor to_sparse_csc(c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor _to_sparse_csc(c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor to_sparse_bsr(at::IntArrayRef blocksize, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor _to_sparse_bsr(at::IntArrayRef blocksize, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor to_sparse_bsc(at::IntArrayRef blocksize, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor _to_sparse_bsc(at::IntArrayRef blocksize, c10::optional dense_dim=c10::nullopt) const;
+  at::Tensor to_mkldnn(c10::optional dtype=c10::nullopt) const;
+  at::Tensor dequantize() const;
+  double q_scale() const;
+  int64_t q_zero_point() const;
+  at::Tensor q_per_channel_scales() const;
+  at::Tensor q_per_channel_zero_points() const;
+  int64_t q_per_channel_axis() const;
+  at::Tensor int_repr() const;
+  at::QScheme qscheme() const;
+  at::Tensor _autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const;
+  at::Tensor _autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const;
+  at::Tensor to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const;
+  at::Tensor to(c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, bool copy, c10::optional memory_format) const;
+  at::Tensor to(at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const;
+  at::Tensor to(at::ScalarType dtype, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const;
+  at::Tensor to(const at::Tensor & other, bool non_blocking=false, bool copy=false, c10::optional memory_format=c10::nullopt) const;
+  at::Scalar item() const;
+  at::Tensor & set_(at::Storage source) const;
+  at::Tensor & set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const;
+  at::Tensor & set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const;
+  at::Tensor & set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) const;
+  at::Tensor & set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) const;
+  at::Tensor & set_(const at::Tensor & source) const;
+  at::Tensor & set_() const;
+  bool is_set_to(const at::Tensor & tensor) const;
+  at::Tensor & masked_fill_(const at::Tensor & mask, const at::Scalar & value) const;
+  at::Tensor masked_fill(const at::Tensor & mask, const at::Scalar & value) const;
+  at::Tensor & masked_fill_(const at::Tensor & mask, const at::Tensor & value) const;
+  at::Tensor masked_fill(const at::Tensor & mask, const at::Tensor & value) const;
+  at::Tensor & masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const;
+  at::Tensor masked_scatter(const at::Tensor & mask, const at::Tensor & source) const;
+  at::Tensor view(at::IntArrayRef size) const;
+  at::Tensor view_symint(c10::SymIntArrayRef size) const;
+  at::Tensor view(at::ScalarType dtype) const;
+  at::Tensor & put_(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const;
+  at::Tensor put(const at::Tensor & index, const at::Tensor & source, bool accumulate=false) const;
+  at::Tensor & index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const;
+  at::Tensor index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const;
+  at::Tensor index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) const;
+  at::Tensor & index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const;
+  at::Tensor index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) const;
+  at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor & index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const;
+  at::Tensor index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const;
+  at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor & index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const;
+  at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const;
+  at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const;
+  at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const;
+  at::Tensor scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const;
+  at::Tensor & scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const;
+  at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const;
+  at::Tensor scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor & scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const;
+  at::Tensor scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) const;
+  at::Tensor & scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) const;
+  at::Tensor & eq_(const at::Scalar & other) const;
+  at::Tensor & eq_(const at::Tensor & other) const;
+  at::Tensor bitwise_and(const at::Scalar & other) const;
+  at::Tensor bitwise_and(const at::Tensor & other) const;
+  at::Tensor & bitwise_and_(const at::Scalar & other) const;
+  at::Tensor & bitwise_and_(const at::Tensor & other) const;
+  at::Tensor __and__(const at::Scalar & other) const;
+  at::Tensor __and__(const at::Tensor & other) const;
+  at::Tensor & __iand__(const at::Scalar & other) const;
+  at::Tensor & __iand__(const at::Tensor & other) const;
+  at::Tensor bitwise_or(const at::Scalar & other) const;
+  at::Tensor bitwise_or(const at::Tensor & other) const;
+  at::Tensor & bitwise_or_(const at::Scalar & other) const;
+  at::Tensor & bitwise_or_(const at::Tensor & other) const;
+  at::Tensor __or__(const at::Scalar & other) const;
+  at::Tensor __or__(const at::Tensor & other) const;
+  at::Tensor & __ior__(const at::Scalar & other) const;
+  at::Tensor & __ior__(const at::Tensor & other) const;
+  at::Tensor bitwise_xor(const at::Scalar & other) const;
+  at::Tensor bitwise_xor(const at::Tensor & other) const;
+  at::Tensor & bitwise_xor_(const at::Scalar & other) const;
+  at::Tensor & bitwise_xor_(const at::Tensor & other) const;
+  at::Tensor __xor__(const at::Scalar & other) const;
+  at::Tensor __xor__(const at::Tensor & other) const;
+  at::Tensor & __ixor__(const at::Scalar & other) const;
+  at::Tensor & __ixor__(const at::Tensor & other) const;
+  at::Tensor __lshift__(const at::Scalar & other) const;
+  at::Tensor __lshift__(const at::Tensor & other) const;
+  at::Tensor & __ilshift__(const at::Scalar & other) const;
+  at::Tensor & __ilshift__(const at::Tensor & other) const;
+  at::Tensor bitwise_left_shift(const at::Tensor & other) const;
+  at::Tensor & bitwise_left_shift_(const at::Tensor & other) const;
+  at::Tensor bitwise_left_shift(const at::Scalar & other) const;
+  at::Tensor & bitwise_left_shift_(const at::Scalar & other) const;
+  at::Tensor __rshift__(const at::Scalar & other) const;
+  at::Tensor __rshift__(const at::Tensor & other) const;
+  at::Tensor & __irshift__(const at::Scalar & other) const;
+  at::Tensor & __irshift__(const at::Tensor & other) const;
+  at::Tensor bitwise_right_shift(const at::Tensor & other) const;
+  at::Tensor & bitwise_right_shift_(const at::Tensor & other) const;
+  at::Tensor bitwise_right_shift(const at::Scalar & other) const;
+  at::Tensor & bitwise_right_shift_(const at::Scalar & other) const;
+  at::Tensor & tril_(int64_t diagonal=0) const;
+  at::Tensor & triu_(int64_t diagonal=0) const;
+  at::Tensor & digamma_() const;
+  at::Tensor & lerp_(const at::Tensor & end, const at::Scalar & weight) const;
+  at::Tensor & lerp_(const at::Tensor & end, const at::Tensor & weight) const;
+  at::Tensor & addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) const;
+  at::Tensor & random_(int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & random_(int64_t to, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & random_(c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & uniform_(double from=0, double to=1, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & cauchy_(double median=0, double sigma=1, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & log_normal_(double mean=1, double std=2, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & exponential_(double lambd=1, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & geometric_(double p, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor diag(int64_t diagonal=0) const;
+  at::Tensor cross(const at::Tensor & other, c10::optional dim=c10::nullopt) const;
+  at::Tensor triu(int64_t diagonal=0) const;
+  at::Tensor tril(int64_t diagonal=0) const;
+  at::Tensor trace() const;
+  at::Tensor ne(const at::Scalar & other) const;
+  at::Tensor ne(const at::Tensor & other) const;
+  at::Tensor & ne_(const at::Scalar & other) const;
+  at::Tensor & ne_(const at::Tensor & other) const;
+  at::Tensor not_equal(const at::Scalar & other) const;
+  at::Tensor not_equal(const at::Tensor & other) const;
+  at::Tensor & not_equal_(const at::Scalar & other) const;
+  at::Tensor & not_equal_(const at::Tensor & other) const;
+  at::Tensor eq(const at::Scalar & other) const;
+  at::Tensor eq(const at::Tensor & other) const;
+  at::Tensor ge(const at::Scalar & other) const;
+  at::Tensor ge(const at::Tensor & other) const;
+  at::Tensor & ge_(const at::Scalar & other) const;
+  at::Tensor & ge_(const at::Tensor & other) const;
+  at::Tensor greater_equal(const at::Scalar & other) const;
+  at::Tensor greater_equal(const at::Tensor & other) const;
+  at::Tensor & greater_equal_(const at::Scalar & other) const;
+  at::Tensor & greater_equal_(const at::Tensor & other) const;
+  at::Tensor le(const at::Scalar & other) const;
+  at::Tensor le(const at::Tensor & other) const;
+  at::Tensor & le_(const at::Scalar & other) const;
+  at::Tensor & le_(const at::Tensor & other) const;
+  at::Tensor less_equal(const at::Scalar & other) const;
+  at::Tensor less_equal(const at::Tensor & other) const;
+  at::Tensor & less_equal_(const at::Scalar & other) const;
+  at::Tensor & less_equal_(const at::Tensor & other) const;
+  at::Tensor gt(const at::Scalar & other) const;
+  at::Tensor gt(const at::Tensor & other) const;
+  at::Tensor & gt_(const at::Scalar & other) const;
+  at::Tensor & gt_(const at::Tensor & other) const;
+  at::Tensor greater(const at::Scalar & other) const;
+  at::Tensor greater(const at::Tensor & other) const;
+  at::Tensor & greater_(const at::Scalar & other) const;
+  at::Tensor & greater_(const at::Tensor & other) const;
+  at::Tensor lt(const at::Scalar & other) const;
+  at::Tensor lt(const at::Tensor & other) const;
+  at::Tensor & lt_(const at::Scalar & other) const;
+  at::Tensor & lt_(const at::Tensor & other) const;
+  at::Tensor less(const at::Scalar & other) const;
+  at::Tensor less(const at::Tensor & other) const;
+  at::Tensor & less_(const at::Scalar & other) const;
+  at::Tensor & less_(const at::Tensor & other) const;
+  at::Tensor take(const at::Tensor & index) const;
+  at::Tensor take_along_dim(const at::Tensor & indices, c10::optional dim=c10::nullopt) const;
+  at::Tensor index_select(int64_t dim, const at::Tensor & index) const;
+  at::Tensor index_select(at::Dimname dim, const at::Tensor & index) const;
+  at::Tensor masked_select(const at::Tensor & mask) const;
+  at::Tensor nonzero() const;
+  at::Tensor nonzero_static(int64_t size, int64_t fill_value=-1) const;
+  ::std::vector<at::Tensor> nonzero_numpy() const;
+  at::Tensor argwhere() const;
+  at::Tensor gather(int64_t dim, const at::Tensor & index, bool sparse_grad=false) const;
+  at::Tensor gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) const;
+  at::Tensor addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const;
+  at::Tensor & addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const;
+  at::Tensor addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const;
+  at::Tensor & addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) const;
+  ::std::tuple<at::Tensor,at::Tensor> triangular_solve(const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
+  ::std::tuple<at::Tensor,at::Tensor,at::Tensor> svd(bool some=true, bool compute_uv=true) const;
+  at::Tensor swapaxes(int64_t axis0, int64_t axis1) const;
+  at::Tensor & swapaxes_(int64_t axis0, int64_t axis1) const;
+  at::Tensor swapdims(int64_t dim0, int64_t dim1) const;
+  at::Tensor & swapdims_(int64_t dim0, int64_t dim1) const;
+  at::Tensor cholesky(bool upper=false) const;
+  at::Tensor cholesky_solve(const at::Tensor & input2, bool upper=false) const;
+  at::Tensor cholesky_inverse(bool upper=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> qr(bool some=true) const;
+  ::std::tuple<at::Tensor,at::Tensor> geqrf() const;
+  at::Tensor orgqr(const at::Tensor & input2) const;
+  at::Tensor ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) const;
+  at::Tensor lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const;
+  at::Tensor multinomial(int64_t num_samples, bool replacement=false, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor & lgamma_() const;
+  at::Tensor lgamma() const;
+  at::Tensor digamma() const;
+  at::Tensor polygamma(int64_t n) const;
+  at::Tensor & polygamma_(int64_t n) const;
+  at::Tensor erfinv() const;
+  at::Tensor & erfinv_() const;
+  at::Tensor i0() const;
+  at::Tensor & i0_() const;
+  at::Tensor sign() const;
+  at::Tensor & sign_() const;
+  at::Tensor signbit() const;
+  at::Tensor dist(const at::Tensor & other, const at::Scalar & p=2) const;
+  at::Tensor & atan2_(const at::Tensor & other) const;
+  at::Tensor atan2(const at::Tensor & other) const;
+  at::Tensor arctan2(const at::Tensor & other) const;
+  at::Tensor & arctan2_(const at::Tensor & other) const;
+  at::Tensor lerp(const at::Tensor & end, const at::Scalar & weight) const;
+  at::Tensor lerp(const at::Tensor & end, const at::Tensor & weight) const;
+  at::Tensor histc(int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) const;
+  ::std::tuple<at::Tensor,at::Tensor> histogram(const at::Tensor & bins, const c10::optional<at::Tensor> & weight={}, bool density=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> histogram(int64_t bins=100, c10::optional<at::ArrayRef<double>> range=c10::nullopt, const c10::optional<at::Tensor> & weight={}, bool density=false) const;
+  at::Tensor fmod(const at::Scalar & other) const;
+  at::Tensor & fmod_(const at::Scalar & other) const;
+  at::Tensor fmod(const at::Tensor & other) const;
+  at::Tensor & fmod_(const at::Tensor & other) const;
+  at::Tensor hypot(const at::Tensor & other) const;
+  at::Tensor & hypot_(const at::Tensor & other) const;
+  at::Tensor igamma(const at::Tensor & other) const;
+  at::Tensor & igamma_(const at::Tensor & other) const;
+  at::Tensor igammac(const at::Tensor & other) const;
+  at::Tensor & igammac_(const at::Tensor & other) const;
+  at::Tensor nextafter(const at::Tensor & other) const;
+  at::Tensor & nextafter_(const at::Tensor & other) const;
+  at::Tensor remainder(const at::Scalar & other) const;
+  at::Tensor & remainder_(const at::Scalar & other) const;
+  at::Tensor remainder(const at::Tensor & other) const;
+  at::Tensor & remainder_(const at::Tensor & other) const;
+  at::Tensor min() const;
+  at::Tensor fmin(const at::Tensor & other) const;
+  at::Tensor max() const;
+  at::Tensor fmax(const at::Tensor & other) const;
+  at::Tensor maximum(const at::Tensor & other) const;
+  at::Tensor max(const at::Tensor & other) const;
+  at::Tensor minimum(const at::Tensor & other) const;
+  at::Tensor min(const at::Tensor & other) const;
+  at::Tensor quantile(const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const;
+  at::Tensor quantile(double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const;
+  at::Tensor nanquantile(const at::Tensor & q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const;
+  at::Tensor nanquantile(double q, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false, c10::string_view interpolation="linear") const;
+  ::std::tuple<at::Tensor,at::Tensor> sort(int64_t dim=-1, bool descending=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> sort(c10::optional<bool> stable, int64_t dim=-1, bool descending=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> sort(at::Dimname dim, bool descending=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> sort(c10::optional<bool> stable, at::Dimname dim, bool descending=false) const;
+  at::Tensor msort() const;
+  at::Tensor argsort(int64_t dim=-1, bool descending=false) const;
+  at::Tensor argsort(bool stable, int64_t dim=-1, bool descending=false) const;
+  at::Tensor argsort(at::Dimname dim, bool descending=false) const;
+  ::std::tuple<at::Tensor,at::Tensor> topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const;
+  ::std::tuple<at::Tensor,at::Tensor> topk_symint(c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) const;
+  at::Tensor all() const;
+  at::Tensor any() const;
+  at::Tensor renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const;
+  at::Tensor & renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const;
+  at::Tensor unfold(int64_t dimension, int64_t size, int64_t step) const;
+  bool equal(const at::Tensor & other) const;
+  at::Tensor pow(const at::Tensor & exponent) const;
+  at::Tensor pow(const at::Scalar & exponent) const;
+  at::Tensor & pow_(const at::Scalar & exponent) const;
+  at::Tensor & pow_(const at::Tensor & exponent) const;
+  at::Tensor float_power(const at::Tensor & exponent) const;
+  at::Tensor float_power(const at::Scalar & exponent) const;
+  at::Tensor & float_power_(const at::Scalar & exponent) const;
+  at::Tensor & float_power_(const at::Tensor & exponent) const;
+  at::Tensor & normal_(double mean=0, double std=1, c10::optional<at::Generator> generator=c10::nullopt) const;
+  at::Tensor alias() const;
+  at::Tensor isfinite() const;
+  at::Tensor isinf() const;
+  void record_stream(at::Stream s) const;
+  at::Tensor isposinf() const;
+  at::Tensor isneginf() const;
+  at::Tensor det() const;
+  ::std::tuple<at::Tensor,at::Tensor> slogdet() const;
+  at::Tensor logdet() const;
+  at::Tensor inverse() const;
+  at::Tensor inner(const at::Tensor & other) const;
+  at::Tensor outer(const at::Tensor & vec2) const;
+  at::Tensor ger(const at::Tensor & vec2) const;
+  at::Tensor to_padded_tensor(double padding, at::OptionalIntArrayRef output_size=c10::nullopt) const;
+  at::Tensor to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size=c10::nullopt) const;
+
+  // Special C++ only overloads for std()-like functions (See gh-40287)
+  // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
+  // So, for example std(0) would select the std(unbiased=False) overload
+
+  Tensor var(int dim) const {
+    return var(IntArrayRef{dim});
+  }
+
+  Tensor std(int dim) const {
+    return std(IntArrayRef{dim});
+  }
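+
+  // A hedged usage sketch (not part of the generated declarations above): with
+  // these overloads, an int argument reduces over that dimension instead of
+  // converting to bool and selecting the `unbiased` overload.
+  //
+  //   torch::Tensor t = torch::rand({3, 4});
+  //   auto v = t.var(0);  // dispatches to var(IntArrayRef{0}), reducing dim 0
+  //   auto s = t.std(0);  // likewise picks the dim-based std overload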
+
+  // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the
+  // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet.
+  // Before that change, we make this method to maintain BC for C++ usage like
+  // `x.to(y.dtype)`.
+  // TODO: remove following two after at::kDouble and its friends are TypeMeta's.
+  inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+  inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
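+
+  // A hedged usage sketch (not part of the generated declarations): since
+  // Tensor::dtype() returns a caffe2::TypeMeta, these overloads keep code like
+  // the following compiling until at::kDouble and friends become TypeMeta's.
+  //
+  //   torch::Tensor x = torch::ones({2});
+  //   torch::Tensor y = torch::zeros({2}, torch::kDouble);
+  //   auto z = x.to(y.dtype());  // routed through typeMetaToScalarType(...)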
+
+  template <typename F, typename... Args>
+  decltype(auto) m(F func, Args&&... params) const {
+    return func(*this, std::forward<Args>(params)...);
+  }
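+
+  // A hedged usage sketch (not part of the generated declarations): `m` merely
+  // forwards `*this` as the first argument of `func`, so a callable can be
+  // applied in a fluent, method-like style.
+  //
+  //   torch::Tensor t = torch::randn({3});
+  //   auto r = t.m([](const torch::Tensor& x) { return x.relu(); });
+  //   auto c = t.m([](const torch::Tensor& x, double lo, double hi) {
+  //     return x.clamp(lo, hi);
+  //   }, -1.0, 1.0);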
+
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::Tensor tensor_data() const {
+    return TensorBase::tensor_data();
+  }
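+
+  // A hedged usage sketch (not part of the generated declarations): the result
+  // shares storage with the Variable, but metadata changes on it do not
+  // propagate back.
+  //
+  //   auto v = torch::zeros({2, 3}, torch::requires_grad());
+  //   auto t = v.tensor_data();  // same storage, no autograd history
+  //   t.resize_({6});            // metadata change: `v` still reports sizes {2, 3}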
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which create a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::Tensor variable_data() const {
+    return TensorBase::variable_data();
+  }
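+
+  // A hedged usage sketch (not part of the generated declarations), mirroring
+  // Python's `tensor.data`: the result shares storage but starts a fresh
+  // autograd history, and its metadata is locked against modification.
+  //
+  //   auto v = torch::randn({4}, torch::requires_grad());
+  //   auto d = v.variable_data();
+  //   TORCH_CHECK(!d.requires_grad());  // detached from v's autograd graph
+  //   d.mul_(2);  // writes through the shared storage without tracking grad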
+
+  // Hooks
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  template <typename T>
+  using hook_return_void_t = std::enable_if_t<std::is_void<typename c10::invoke_result_t<T&, Tensor>>::value, unsigned>;
+  template <typename T>
+  using hook_return_var_t = std::enable_if_t<std::is_same<typename c10::invoke_result_t<T&, Tensor>, Tensor>::value, unsigned>;
+
+  /// Registers a backward hook.
+  ///
+  /// The hook will be called every time a gradient with respect to the Tensor is computed.
+  /// The hook should have one of the following signatures:
+  /// ```
+  /// hook(Tensor grad) -> Tensor
+  /// ```
+  /// ```
+  /// hook(Tensor grad) -> void
+  /// ```
+  /// The hook should not modify its argument, but it can optionally return a new gradient
+  /// which will be used in place of `grad`.
+  ///
+  /// This function returns the index of the hook in the list, which can be used to remove the hook.
+  ///
+  /// Example:
+  /// @code
+  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
+  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
+  /// v.backward(torch::tensor({1., 2., 3.}));
+  /// // This prints:
+  /// // ```
+  /// //  2
+  /// //  4
+  /// //  6
+  /// // [ CPUFloatType{3} ]
+  /// // ```
+  /// std::cout << v.grad() << std::endl;
+  /// v.remove_hook(h);  // removes the hook
+  /// @endcode
+  template <typename T>
+  hook_return_void_t<T> register_hook(T&& hook) const;
+  template <typename T>
+  hook_return_var_t<T> register_hook(T&& hook) const;
+
+  // Variable methods
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  Tensor data() const {
+    return TensorBase::data();
+  }
+
+  void _backward(TensorList inputs, const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph, bool create_graph) const;
+
+  const Tensor& requires_grad_(bool _requires_grad=true) const {
+    TensorBase::requires_grad_(_requires_grad);
+    return *this;
+  }
+};
+
+namespace detail {
+// Helper creator for the Tensor class which doesn't require users to pass in
+// an intrusive_ptr; instead, it converts the arguments passed to the
+// requested intrusive_ptr type.
+template <typename T, typename... Args>
+Tensor make_tensor(Args&&... args) {
+  return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
+}
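+
+// A hedged usage sketch (not part of the generated code); the TensorImpl
+// constructor arguments (storage, key_set, dtype) are illustrative placeholders.
+//
+//   at::Tensor t = at::detail::make_tensor<c10::TensorImpl>(
+//       std::move(storage), key_set, dtype);
+//   // equivalent to at::Tensor(c10::make_intrusive<c10::TensorImpl>(...))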
+
+} // namespace detail
+
+} // namespace at
+
+
+namespace at {
+
+// aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+inline void Tensor::__dispatch__backward(at::TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) const {
+    return at::_ops::_backward::call(const_cast<Tensor&>(*this), inputs, gradient, retain_graph, create_graph);
+}
+
+// aten::set_data(Tensor(a!) self, Tensor new_data) -> ()
+inline void Tensor::__dispatch_set_data(const at::Tensor & new_data) const {
+    return at::_ops::set_data::call(const_cast(*this), new_data);
+}
+
+// aten::data(Tensor self) -> Tensor
+inline at::Tensor Tensor::__dispatch_data() const {
+    return at::_ops::data::call(const_cast(*this));
+}
+
+// aten::is_leaf(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_leaf() const {
+    return at::_ops::is_leaf::call(const_cast(*this));
+}
+
+// aten::output_nr(Tensor self) -> int
+inline int64_t Tensor::__dispatch_output_nr() const {
+    return at::_ops::output_nr::call(const_cast(*this));
+}
+
+// aten::_version(Tensor self) -> int
+inline int64_t Tensor::__dispatch__version() const {
+    return at::_ops::_version::call(const_cast(*this));
+}
+
+// aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)
+inline at::Tensor & Tensor::__dispatch_requires_grad_(bool requires_grad) const {
+    return at::_ops::requires_grad_::call(const_cast(*this), requires_grad);
+}
+
+// aten::retain_grad(Tensor(a!) self) -> ()
+inline void Tensor::__dispatch_retain_grad() const {
+    return at::_ops::retain_grad::call(const_cast(*this));
+}
+
+// aten::retains_grad(Tensor self) -> bool
+inline bool Tensor::__dispatch_retains_grad() const {
+    return at::_ops::retains_grad::call(const_cast(*this));
+}
+
+// aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)
+inline at::Tensor Tensor::_fw_primal(int64_t level) const {
+    return at::_ops::_fw_primal::call(const_cast(*this), level);
+}
+
+// aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
+inline at::Tensor & Tensor::rename_(c10::optional<at::DimnameList> names) const {
+    return at::_ops::rename_::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
+inline at::Tensor Tensor::rename(c10::optional<at::DimnameList> names) const {
+    return at::_ops::rename::call(const_cast<Tensor&>(*this), names);
+}
+
+// aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::align_to(at::DimnameList names) const {
+    return at::_ops::align_to::call(const_cast(*this), names);
+}
+
+// aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
+inline at::Tensor Tensor::align_to(at::DimnameList order, int64_t ellipsis_idx) const {
+    return at::_ops::align_to_ellipsis_idx::call(const_cast(*this), order, ellipsis_idx);
+}
+
+// aten::align_as(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::align_as(const at::Tensor & other) const {
+    return at::_ops::align_as::call(const_cast(*this), other);
+}
+
+// aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::refine_names(at::DimnameList names) const {
+    return at::_ops::refine_names::call(const_cast(*this), names);
+}
+
+// aten::abs(Tensor self) -> Tensor
+inline at::Tensor Tensor::abs() const {
+    return at::_ops::abs::call(const_cast(*this));
+}
+
+// aten::abs_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::abs_() const {
+    return at::_ops::abs_::call(const_cast(*this));
+}
+
+// aten::absolute(Tensor self) -> Tensor
+inline at::Tensor Tensor::absolute() const {
+    return at::_ops::absolute::call(const_cast(*this));
+}
+
+// aten::absolute_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::absolute_() const {
+    return at::_ops::absolute_::call(const_cast(*this));
+}
+
+// aten::angle(Tensor self) -> Tensor
+inline at::Tensor Tensor::angle() const {
+    return at::_ops::angle::call(const_cast(*this));
+}
+
+// aten::sgn(Tensor self) -> Tensor
+inline at::Tensor Tensor::sgn() const {
+    return at::_ops::sgn::call(const_cast(*this));
+}
+
+// aten::sgn_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sgn_() const {
+    return at::_ops::sgn_::call(const_cast(*this));
+}
+
+// aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+inline at::Tensor Tensor::chalf(c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::chalf::call(const_cast<Tensor&>(*this), memory_format);
+}
+
+// aten::_conj(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::_conj() const {
+    return at::_ops::_conj::call(const_cast(*this));
+}
+
+// aten::conj(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::__dispatch_conj() const {
+    return at::_ops::conj::call(const_cast(*this));
+}
+
+// aten::_conj_physical(Tensor self) -> Tensor
+inline at::Tensor Tensor::_conj_physical() const {
+    return at::_ops::_conj_physical::call(const_cast(*this));
+}
+
+// aten::conj_physical(Tensor self) -> Tensor
+inline at::Tensor Tensor::conj_physical() const {
+    return at::_ops::conj_physical::call(const_cast(*this));
+}
+
+// aten::conj_physical_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::conj_physical_() const {
+    return at::_ops::conj_physical_::call(const_cast(*this));
+}
+
+// aten::resolve_conj(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::resolve_conj() const {
+    return at::_ops::resolve_conj::call(const_cast(*this));
+}
+
+// aten::resolve_neg(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::resolve_neg() const {
+    return at::_ops::resolve_neg::call(const_cast(*this));
+}
+
+// aten::_neg_view(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::_neg_view() const {
+    return at::_ops::_neg_view::call(const_cast(*this));
+}
+
+// aten::acos(Tensor self) -> Tensor
+inline at::Tensor Tensor::acos() const {
+    return at::_ops::acos::call(const_cast(*this));
+}
+
+// aten::acos_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::acos_() const {
+    return at::_ops::acos_::call(const_cast(*this));
+}
+
+// aten::arccos(Tensor self) -> Tensor
+inline at::Tensor Tensor::arccos() const {
+    return at::_ops::arccos::call(const_cast(*this));
+}
+
+// aten::arccos_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arccos_() const {
+    return at::_ops::arccos_::call(const_cast(*this));
+}
+
+// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::add_Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::add_(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::add__Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::add(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::add_Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::add_(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::add__Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::addmv(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmv::call(const_cast(*this), mat, vec, beta, alpha);
+}
+
+// aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addmv_(const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmv_::call(const_cast(*this), mat, vec, beta, alpha);
+}
+
+// aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::addr(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addr::call(const_cast(*this), vec1, vec2, beta, alpha);
+}
+
+// aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addr_(const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addr_::call(const_cast(*this), vec1, vec2, beta, alpha);
+}
+
+// aten::_is_all_true(Tensor self) -> Tensor
+inline at::Tensor Tensor::_is_all_true() const {
+    return at::_ops::_is_all_true::call(const_cast(*this));
+}
+
+// aten::_is_any_true(Tensor self) -> Tensor
+inline at::Tensor Tensor::_is_any_true() const {
+    return at::_ops::_is_any_true::call(const_cast(*this));
+}
+
+// aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::all(int64_t dim, bool keepdim) const {
+    return at::_ops::all_dim::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::all(at::OptionalIntArrayRef dim, bool keepdim) const {
+    return at::_ops::all_dims::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::all(at::Dimname dim, bool keepdim) const {
+    return at::_ops::all_dimname::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
+inline bool Tensor::allclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const {
+    return at::_ops::allclose::call(const_cast(*this), other, rtol, atol, equal_nan);
+}
+
+// aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::any(int64_t dim, bool keepdim) const {
+    return at::_ops::any_dim::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::any(at::OptionalIntArrayRef dim, bool keepdim) const {
+    return at::_ops::any_dims::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::any(at::Dimname dim, bool keepdim) const {
+    return at::_ops::any_dimname::call(const_cast(*this), dim, keepdim);
+}
+
+// aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::argmax(c10::optional<int64_t> dim, bool keepdim) const {
+    return at::_ops::argmax::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::argmin(c10::optional<int64_t> dim, bool keepdim) const {
+    return at::_ops::argmin::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::acosh(Tensor self) -> Tensor
+inline at::Tensor Tensor::acosh() const {
+    return at::_ops::acosh::call(const_cast(*this));
+}
+
+// aten::acosh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::acosh_() const {
+    return at::_ops::acosh_::call(const_cast(*this));
+}
+
+// aten::arccosh(Tensor self) -> Tensor
+inline at::Tensor Tensor::arccosh() const {
+    return at::_ops::arccosh::call(const_cast(*this));
+}
+
+// aten::arccosh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arccosh_() const {
+    return at::_ops::arccosh_::call(const_cast(*this));
+}
+
+// aten::asinh(Tensor self) -> Tensor
+inline at::Tensor Tensor::asinh() const {
+    return at::_ops::asinh::call(const_cast(*this));
+}
+
+// aten::asinh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::asinh_() const {
+    return at::_ops::asinh_::call(const_cast(*this));
+}
+
+// aten::arcsinh(Tensor self) -> Tensor
+inline at::Tensor Tensor::arcsinh() const {
+    return at::_ops::arcsinh::call(const_cast(*this));
+}
+
+// aten::arcsinh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arcsinh_() const {
+    return at::_ops::arcsinh_::call(const_cast(*this));
+}
+
+// aten::atanh(Tensor self) -> Tensor
+inline at::Tensor Tensor::atanh() const {
+    return at::_ops::atanh::call(const_cast(*this));
+}
+
+// aten::atanh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::atanh_() const {
+    return at::_ops::atanh_::call(const_cast(*this));
+}
+
+// aten::arctanh(Tensor self) -> Tensor
+inline at::Tensor Tensor::arctanh() const {
+    return at::_ops::arctanh::call(const_cast(*this));
+}
+
+// aten::arctanh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arctanh_() const {
+    return at::_ops::arctanh_::call(const_cast(*this));
+}
+
+// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+inline at::Tensor Tensor::as_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) const {
+    return at::_ops::as_strided::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+}
+
+// aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+inline at::Tensor Tensor::as_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) const {
+    return at::_ops::as_strided::call(const_cast<Tensor&>(*this), size, stride, storage_offset);
+}
+
+// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+inline const at::Tensor & Tensor::as_strided_(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) const {
+    return at::_ops::as_strided_::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+}
+
+// aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+inline const at::Tensor & Tensor::as_strided__symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) const {
+    return at::_ops::as_strided_::call(const_cast<Tensor&>(*this), size, stride, storage_offset);
+}
+
+// aten::asin(Tensor self) -> Tensor
+inline at::Tensor Tensor::asin() const {
+    return at::_ops::asin::call(const_cast(*this));
+}
+
+// aten::asin_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::asin_() const {
+    return at::_ops::asin_::call(const_cast(*this));
+}
+
+// aten::arcsin(Tensor self) -> Tensor
+inline at::Tensor Tensor::arcsin() const {
+    return at::_ops::arcsin::call(const_cast(*this));
+}
+
+// aten::arcsin_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arcsin_() const {
+    return at::_ops::arcsin_::call(const_cast(*this));
+}
+
+// aten::atan(Tensor self) -> Tensor
+inline at::Tensor Tensor::atan() const {
+    return at::_ops::atan::call(const_cast(*this));
+}
+
+// aten::atan_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::atan_() const {
+    return at::_ops::atan_::call(const_cast(*this));
+}
+
+// aten::arctan(Tensor self) -> Tensor
+inline at::Tensor Tensor::arctan() const {
+    return at::_ops::arctan::call(const_cast(*this));
+}
+
+// aten::arctan_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::arctan_() const {
+    return at::_ops::arctan_::call(const_cast(*this));
+}
+
+// aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::baddbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::baddbmm::call(const_cast(*this), batch1, batch2, beta, alpha);
+}
+
+// aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::baddbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::baddbmm_::call(const_cast(*this), batch1, batch2, beta, alpha);
+}
+
+// aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
+inline at::Tensor Tensor::bernoulli(c10::optional<at::Generator> generator) const {
+    return at::_ops::bernoulli::call(const_cast<Tensor&>(*this), generator);
+}
+
+// aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::bernoulli_(const at::Tensor & p, c10::optional<at::Generator> generator) const {
+    return at::_ops::bernoulli__Tensor::call(const_cast<Tensor&>(*this), p, generator);
+}
+
+// aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::bernoulli_(double p, c10::optional<at::Generator> generator) const {
+    return at::_ops::bernoulli__float::call(const_cast<Tensor&>(*this), p, generator);
+}
+
+// aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor
+inline at::Tensor Tensor::bernoulli(double p, c10::optional<at::Generator> generator) const {
+    return at::_ops::bernoulli_p::call(const_cast<Tensor&>(*this), p, generator);
+}
+
+// aten::bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
+inline at::Tensor Tensor::bincount(const c10::optional<at::Tensor> & weights, int64_t minlength) const {
+    return at::_ops::bincount::call(const_cast<Tensor&>(*this), weights, minlength);
+}
+
+// aten::bitwise_not(Tensor self) -> Tensor
+inline at::Tensor Tensor::bitwise_not() const {
+    return at::_ops::bitwise_not::call(const_cast(*this));
+}
+
+// aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_not_() const {
+    return at::_ops::bitwise_not_::call(const_cast(*this));
+}
+
+// aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::copysign(const at::Tensor & other) const {
+    return at::_ops::copysign_Tensor::call(const_cast(*this), other);
+}
+
+// aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::copysign_(const at::Tensor & other) const {
+    return at::_ops::copysign__Tensor::call(const_cast(*this), other);
+}
+
+// aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::copysign(const at::Scalar & other) const {
+    return at::_ops::copysign_Scalar::call(const_cast(*this), other);
+}
+
+// aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::copysign_(const at::Scalar & other) const {
+    return at::_ops::copysign__Scalar::call(const_cast(*this), other);
+}
+
+// aten::_lazy_clone(Tensor self) -> Tensor
+inline at::Tensor Tensor::_lazy_clone() const {
+    return at::_ops::_lazy_clone::call(const_cast(*this));
+}
+
+// aten::logical_not(Tensor self) -> Tensor
+inline at::Tensor Tensor::logical_not() const {
+    return at::_ops::logical_not::call(const_cast<Tensor&>(*this));
+}
+
+// aten::logical_not_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::logical_not_() const {
+    return at::_ops::logical_not_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::logical_xor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::logical_xor(const at::Tensor & other) const {
+    return at::_ops::logical_xor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::logical_xor_(const at::Tensor & other) const {
+    return at::_ops::logical_xor_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logical_and(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::logical_and(const at::Tensor & other) const {
+    return at::_ops::logical_and::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::logical_and_(const at::Tensor & other) const {
+    return at::_ops::logical_and_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logical_or(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::logical_or(const at::Tensor & other) const {
+    return at::_ops::logical_or::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::logical_or_(const at::Tensor & other) const {
+    return at::_ops::logical_or_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bmm(Tensor self, Tensor mat2) -> Tensor
+inline at::Tensor Tensor::bmm(const at::Tensor & mat2) const {
+    return at::_ops::bmm::call(const_cast<Tensor&>(*this), mat2);
+}
+
+// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
+inline at::Tensor Tensor::broadcast_to(at::IntArrayRef size) const {
+    return at::_ops::broadcast_to::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size));
+}
+
+// aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
+inline at::Tensor Tensor::broadcast_to_symint(c10::SymIntArrayRef size) const {
+    return at::_ops::broadcast_to::call(const_cast<Tensor&>(*this), size);
+}
+
+// aten::ceil(Tensor self) -> Tensor
+inline at::Tensor Tensor::ceil() const {
+    return at::_ops::ceil::call(const_cast<Tensor&>(*this));
+}
+
+// aten::ceil_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::ceil_() const {
+    return at::_ops::ceil_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::unsafe_chunk(int64_t chunks, int64_t dim) const {
+    return at::_ops::unsafe_chunk::call(const_cast<Tensor&>(*this), chunks, dim);
+}
+
+// aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::chunk(int64_t chunks, int64_t dim) const {
+    return at::_ops::chunk::call(const_cast<Tensor&>(*this), chunks, dim);
+}
+
+// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::tensor_split(int64_t sections, int64_t dim) const {
+    return at::_ops::tensor_split_sections::call(const_cast<Tensor&>(*this), sections, dim);
+}
+
+// aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::tensor_split_symint(c10::SymInt sections, int64_t dim) const {
+    return at::_ops::tensor_split_sections::call(const_cast<Tensor&>(*this), sections, dim);
+}
+
+// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::tensor_split(at::IntArrayRef indices, int64_t dim) const {
+    return at::_ops::tensor_split_indices::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(indices), dim);
+}
+
+// aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::tensor_split_symint(c10::SymIntArrayRef indices, int64_t dim) const {
+    return at::_ops::tensor_split_indices::call(const_cast<Tensor&>(*this), indices, dim);
+}
+
+// aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::tensor_split(const at::Tensor & tensor_indices_or_sections, int64_t dim) const {
+    return at::_ops::tensor_split_tensor_indices_or_sections::call(const_cast<Tensor&>(*this), tensor_indices_or_sections, dim);
+}
+
+// aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+inline at::Tensor Tensor::clamp(const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max) const {
+    return at::_ops::clamp::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+inline at::Tensor Tensor::clamp(const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max) const {
+    return at::_ops::clamp_Tensor::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_(const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max) const {
+    return at::_ops::clamp_::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_(const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max) const {
+    return at::_ops::clamp__Tensor::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clamp_max(Tensor self, Scalar max) -> Tensor
+inline at::Tensor Tensor::clamp_max(const at::Scalar & max) const {
+    return at::_ops::clamp_max::call(const_cast<Tensor&>(*this), max);
+}
+
+// aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
+inline at::Tensor Tensor::clamp_max(const at::Tensor & max) const {
+    return at::_ops::clamp_max_Tensor::call(const_cast<Tensor&>(*this), max);
+}
+
+// aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_max_(const at::Scalar & max) const {
+    return at::_ops::clamp_max_::call(const_cast<Tensor&>(*this), max);
+}
+
+// aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_max_(const at::Tensor & max) const {
+    return at::_ops::clamp_max__Tensor::call(const_cast<Tensor&>(*this), max);
+}
+
+// aten::clamp_min(Tensor self, Scalar min) -> Tensor
+inline at::Tensor Tensor::clamp_min(const at::Scalar & min) const {
+    return at::_ops::clamp_min::call(const_cast<Tensor&>(*this), min);
+}
+
+// aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
+inline at::Tensor Tensor::clamp_min(const at::Tensor & min) const {
+    return at::_ops::clamp_min_Tensor::call(const_cast<Tensor&>(*this), min);
+}
+
+// aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_min_(const at::Scalar & min) const {
+    return at::_ops::clamp_min_::call(const_cast<Tensor&>(*this), min);
+}
+
+// aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)
+inline at::Tensor & Tensor::clamp_min_(const at::Tensor & min) const {
+    return at::_ops::clamp_min__Tensor::call(const_cast<Tensor&>(*this), min);
+}
+
+// aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+inline at::Tensor Tensor::clip(const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max) const {
+    return at::_ops::clip::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+inline at::Tensor Tensor::clip(const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max) const {
+    return at::_ops::clip_Tensor::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+inline at::Tensor & Tensor::clip_(const c10::optional<at::Scalar> & min, const c10::optional<at::Scalar> & max) const {
+    return at::_ops::clip_::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+inline at::Tensor & Tensor::clip_(const c10::optional<at::Tensor> & min, const c10::optional<at::Tensor> & max) const {
+    return at::_ops::clip__Tensor::call(const_cast<Tensor&>(*this), min, max);
+}
+
+// aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
+inline at::Tensor Tensor::__dispatch_contiguous(at::MemoryFormat memory_format) const {
+    return at::_ops::contiguous::call(const_cast<Tensor&>(*this), memory_format);
+}
+
+// aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+inline at::Tensor & Tensor::copy_(const at::Tensor & src, bool non_blocking) const {
+    return at::_ops::copy_::call(const_cast<Tensor&>(*this), src, non_blocking);
+}
+
+// aten::cos(Tensor self) -> Tensor
+inline at::Tensor Tensor::cos() const {
+    return at::_ops::cos::call(const_cast<Tensor&>(*this));
+}
+
+// aten::cos_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::cos_() const {
+    return at::_ops::cos_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::cosh(Tensor self) -> Tensor
+inline at::Tensor Tensor::cosh() const {
+    return at::_ops::cosh::call(const_cast<Tensor&>(*this));
+}
+
+// aten::cosh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::cosh_() const {
+    return at::_ops::cosh_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+inline at::Tensor Tensor::count_nonzero(at::IntArrayRef dim) const {
+    return at::_ops::count_nonzero_dim_IntList::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::count_nonzero(Tensor self, int? dim=None) -> Tensor
+inline at::Tensor Tensor::count_nonzero(c10::optional<int64_t> dim) const {
+    return at::_ops::count_nonzero::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
+inline at::Tensor Tensor::cov(int64_t correction, const c10::optional<at::Tensor> & fweights, const c10::optional<at::Tensor> & aweights) const {
+    return at::_ops::cov::call(const_cast<Tensor&>(*this), correction, fweights, aweights);
+}
+
+// aten::corrcoef(Tensor self) -> Tensor
+inline at::Tensor Tensor::corrcoef() const {
+    return at::_ops::corrcoef::call(const_cast<Tensor&>(*this));
+}
+
+// aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::cummax(int64_t dim) const {
+    return at::_ops::cummax::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::cummax(at::Dimname dim) const {
+    return at::_ops::cummax_dimname::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::cummin(int64_t dim) const {
+    return at::_ops::cummin::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::cummin(at::Dimname dim) const {
+    return at::_ops::cummin_dimname::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::cumprod(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumprod::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+inline at::Tensor & Tensor::cumprod_(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumprod_::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::cumprod(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumprod_dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+inline at::Tensor & Tensor::cumprod_(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumprod__dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::cumsum(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumsum::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+inline at::Tensor & Tensor::cumsum_(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumsum_::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::cumsum(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumsum_dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+inline at::Tensor & Tensor::cumsum_(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::cumsum__dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
+inline at::Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const {
+    return at::_ops::diag_embed::call(const_cast<Tensor&>(*this), offset, dim1, dim2);
+}
+
+// aten::diagflat(Tensor self, int offset=0) -> Tensor
+inline at::Tensor Tensor::diagflat(int64_t offset) const {
+    return at::_ops::diagflat::call(const_cast<Tensor&>(*this), offset);
+}
+
+// aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
+inline at::Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const {
+    return at::_ops::diagonal::call(const_cast<Tensor&>(*this), offset, dim1, dim2);
+}
+
+// aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
+inline at::Tensor Tensor::diagonal(at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset) const {
+    return at::_ops::diagonal_Dimname::call(const_cast<Tensor&>(*this), outdim, dim1, dim2, offset);
+}
+
+// aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
+inline at::Tensor & Tensor::fill_diagonal_(const at::Scalar & fill_value, bool wrap) const {
+    return at::_ops::fill_diagonal_::call(const_cast<Tensor&>(*this), fill_value, wrap);
+}
+
+// aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor
+inline at::Tensor Tensor::diff(int64_t n, int64_t dim, const c10::optional<at::Tensor> & prepend, const c10::optional<at::Tensor> & append) const {
+    return at::_ops::diff::call(const_cast<Tensor&>(*this), n, dim, prepend, append);
+}
+
+// aten::div.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::div(const at::Tensor & other) const {
+    return at::_ops::div_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::div_(const at::Tensor & other) const {
+    return at::_ops::div__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+inline at::Tensor Tensor::div(const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::div_Tensor_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+inline at::Tensor & Tensor::div_(const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::div__Tensor_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::div.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::div(const at::Scalar & other) const {
+    return at::_ops::div_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::div_(const at::Scalar & other) const {
+    return at::_ops::div__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+inline at::Tensor Tensor::div(const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::div_Scalar_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+inline at::Tensor & Tensor::div_(const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::div__Scalar_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::divide.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::divide(const at::Tensor & other) const {
+    return at::_ops::divide_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::divide_(const at::Tensor & other) const {
+    return at::_ops::divide__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::divide.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::divide(const at::Scalar & other) const {
+    return at::_ops::divide_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::divide_(const at::Scalar & other) const {
+    return at::_ops::divide__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+inline at::Tensor Tensor::divide(const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::divide_Tensor_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+inline at::Tensor & Tensor::divide_(const at::Tensor & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::divide__Tensor_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+inline at::Tensor Tensor::divide(const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::divide_Scalar_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+inline at::Tensor & Tensor::divide_(const at::Scalar & other, c10::optional<c10::string_view> rounding_mode) const {
+    return at::_ops::divide__Scalar_mode::call(const_cast<Tensor&>(*this), other, rounding_mode);
+}
+
+// aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::true_divide(const at::Tensor & other) const {
+    return at::_ops::true_divide_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::true_divide_(const at::Tensor & other) const {
+    return at::_ops::true_divide__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::true_divide(const at::Scalar & other) const {
+    return at::_ops::true_divide_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::true_divide_(const at::Scalar & other) const {
+    return at::_ops::true_divide__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::dot(Tensor self, Tensor tensor) -> Tensor
+inline at::Tensor Tensor::dot(const at::Tensor & tensor) const {
+    return at::_ops::dot::call(const_cast<Tensor&>(*this), tensor);
+}
+
+// aten::vdot(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::vdot(const at::Tensor & other) const {
+    return at::_ops::vdot::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty(at::IntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_empty::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_empty::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+}
+
+// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_empty::call(const_cast<Tensor&>(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_empty::call(const_cast<Tensor&>(*this), size, dtype, layout, device, pin_memory);
+}
+
+// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options) const {
+    return at::_ops::new_empty_strided::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_empty_strided::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory);
+}
+
+// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options) const {
+    return at::_ops::new_empty_strided::call(const_cast<Tensor&>(*this), size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_empty_strided_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_empty_strided::call(const_cast<Tensor&>(*this), size, stride, dtype, layout, device, pin_memory);
+}
+
+// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const {
+    return at::_ops::new_full::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_full(at::IntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_full::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory);
+}
+
+// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options) const {
+    return at::_ops::new_full::call(const_cast<Tensor&>(*this), size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_full_symint(c10::SymIntArrayRef size, const at::Scalar & fill_value, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_full::call(const_cast<Tensor&>(*this), size, fill_value, dtype, layout, device, pin_memory);
+}
+
+// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_zeros::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_zeros(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_zeros::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+}
+
+// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_zeros::call(const_cast<Tensor&>(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_zeros_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_zeros::call(const_cast<Tensor&>(*this), size, dtype, layout, device, pin_memory);
+}
+
+// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_ones(at::IntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_ones::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_ones(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_ones::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+}
+
+// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, at::TensorOptions options) const {
+    return at::_ops::new_ones::call(const_cast<Tensor&>(*this), size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+}
+
+// aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+inline at::Tensor Tensor::new_ones_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) const {
+    return at::_ops::new_ones::call(const_cast<Tensor&>(*this), size, dtype, layout, device, pin_memory);
+}
+
+// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+inline const at::Tensor & Tensor::resize_(at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::resize_::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), memory_format);
+}
+
+// aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+inline const at::Tensor & Tensor::resize__symint(c10::SymIntArrayRef size, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::resize_::call(const_cast<Tensor&>(*this), size, memory_format);
+}
+
+// aten::erf(Tensor self) -> Tensor
+inline at::Tensor Tensor::erf() const {
+    return at::_ops::erf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::erf_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::erf_() const {
+    return at::_ops::erf_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::erfc(Tensor self) -> Tensor
+inline at::Tensor Tensor::erfc() const {
+    return at::_ops::erfc::call(const_cast<Tensor&>(*this));
+}
+
+// aten::erfc_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::erfc_() const {
+    return at::_ops::erfc_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::exp(Tensor self) -> Tensor
+inline at::Tensor Tensor::exp() const {
+    return at::_ops::exp::call(const_cast<Tensor&>(*this));
+}
+
+// aten::exp_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::exp_() const {
+    return at::_ops::exp_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::exp2(Tensor self) -> Tensor
+inline at::Tensor Tensor::exp2() const {
+    return at::_ops::exp2::call(const_cast<Tensor&>(*this));
+}
+
+// aten::exp2_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::exp2_() const {
+    return at::_ops::exp2_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::expm1(Tensor self) -> Tensor
+inline at::Tensor Tensor::expm1() const {
+    return at::_ops::expm1::call(const_cast<Tensor&>(*this));
+}
+
+// aten::expm1_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::expm1_() const {
+    return at::_ops::expm1_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+inline at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit) const {
+    return at::_ops::expand::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size), implicit);
+}
+
+// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+inline at::Tensor Tensor::expand_symint(c10::SymIntArrayRef size, bool implicit) const {
+    return at::_ops::expand::call(const_cast<Tensor&>(*this), size, implicit);
+}
+
+// aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)
+inline at::Tensor Tensor::expand_as(const at::Tensor & other) const {
+    return at::_ops::expand_as::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)
+inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const {
+    return at::_ops::flatten_using_ints::call(const_cast<Tensor&>(*this), start_dim, end_dim);
+}
+
+// aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)
+inline at::Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, at::Dimname out_dim) const {
+    return at::_ops::flatten_named_out_dim::call(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
+}
+
+// aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)
+inline at::Tensor Tensor::flatten(at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) const {
+    return at::_ops::flatten_using_names::call(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
+}
+
+// aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
+inline at::Tensor Tensor::flatten(at::DimnameList dims, at::Dimname out_dim) const {
+    return at::_ops::flatten_DimnameList::call(const_cast<Tensor&>(*this), dims, out_dim);
+}
+
+// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)
+inline at::Tensor Tensor::unflatten(int64_t dim, at::IntArrayRef sizes) const {
+    return at::_ops::unflatten_int::call(const_cast<Tensor&>(*this), dim, c10::fromIntArrayRefSlow(sizes));
+}
+
+// aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)
+inline at::Tensor Tensor::unflatten_symint(int64_t dim, c10::SymIntArrayRef sizes) const {
+    return at::_ops::unflatten_int::call(const_cast<Tensor&>(*this), dim, sizes);
+}
+
+// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::unflatten(at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) const {
+    return at::_ops::unflatten_Dimname::call(const_cast<Tensor&>(*this), dim, c10::fromIntArrayRefSlow(sizes), names);
+}
+
+// aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)
+inline at::Tensor Tensor::unflatten_symint(at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) const {
+    return at::_ops::unflatten_Dimname::call(const_cast<Tensor&>(*this), dim, sizes, names);
+}
+
+// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
+inline at::Tensor & Tensor::fill_(const at::Scalar & value) const {
+    return at::_ops::fill__Scalar::call(const_cast<Tensor&>(*this), value);
+}
+
+// aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
+inline at::Tensor & Tensor::fill_(const at::Tensor & value) const {
+    return at::_ops::fill__Tensor::call(const_cast<Tensor&>(*this), value);
+}
+
+// aten::floor(Tensor self) -> Tensor
+inline at::Tensor Tensor::floor() const {
+    return at::_ops::floor::call(const_cast<Tensor&>(*this));
+}
+
+// aten::floor_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::floor_() const {
+    return at::_ops::floor_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::floor_divide(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::floor_divide(const at::Tensor & other) const {
+    return at::_ops::floor_divide::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::floor_divide_(const at::Tensor & other) const {
+    return at::_ops::floor_divide__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::floor_divide(const at::Scalar & other) const {
+    return at::_ops::floor_divide_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::floor_divide_(const at::Scalar & other) const {
+    return at::_ops::floor_divide__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::frac(Tensor self) -> Tensor
+inline at::Tensor Tensor::frac() const {
+    return at::_ops::frac::call(const_cast<Tensor&>(*this));
+}
+
+// aten::frac_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::frac_() const {
+    return at::_ops::frac_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::gcd(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::gcd(const at::Tensor & other) const {
+    return at::_ops::gcd::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::gcd_(const at::Tensor & other) const {
+    return at::_ops::gcd_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lcm(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::lcm(const at::Tensor & other) const {
+    return at::_ops::lcm::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::lcm_(const at::Tensor & other) const {
+    return at::_ops::lcm_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+inline at::Tensor Tensor::index(const c10::List<c10::optional<at::Tensor>> & indices) const {
+    return at::_ops::index_Tensor::call(const_cast<Tensor&>(*this), indices);
+}
+
+// aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
+inline at::Tensor & Tensor::index_copy_(int64_t dim, const at::Tensor & index, const at::Tensor & source) const {
+    return at::_ops::index_copy_::call(const_cast<Tensor&>(*this), dim, index, source);
+}
+
+// aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
+inline at::Tensor Tensor::index_copy(int64_t dim, const at::Tensor & index, const at::Tensor & source) const {
+    return at::_ops::index_copy::call(const_cast<Tensor&>(*this), dim, index, source);
+}
+
+// aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)
+inline at::Tensor & Tensor::index_copy_(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const {
+    return at::_ops::index_copy__dimname::call(const_cast<Tensor&>(*this), dim, index, source);
+}
+
+// aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor
+inline at::Tensor Tensor::index_copy(at::Dimname dim, const at::Tensor & index, const at::Tensor & source) const {
+    return at::_ops::index_copy_dimname::call(const_cast<Tensor&>(*this), dim, index, source);
+}
+
+// aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
+inline at::Tensor & Tensor::index_put_(const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) const {
+    return at::_ops::index_put_::call(const_cast<Tensor&>(*this), indices, values, accumulate);
+}
+
+// aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+inline at::Tensor Tensor::index_put(const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) const {
+    return at::_ops::index_put::call(const_cast<Tensor&>(*this), indices, values, accumulate);
+}
+
+// aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor
+inline at::Tensor Tensor::isclose(const at::Tensor & other, double rtol, double atol, bool equal_nan) const {
+    return at::_ops::isclose::call(const_cast<Tensor&>(*this), other, rtol, atol, equal_nan);
+}
+
+// aten::isnan(Tensor self) -> Tensor
+inline at::Tensor Tensor::isnan() const {
+    return at::_ops::isnan::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_distributed(Tensor self) -> bool
+inline bool Tensor::is_distributed() const {
+    return at::_ops::is_distributed::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_floating_point(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_floating_point() const {
+    return at::_ops::is_floating_point::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_complex(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_complex() const {
+    return at::_ops::is_complex::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_conj(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_conj() const {
+    return at::_ops::is_conj::call(const_cast<Tensor&>(*this));
+}
+
+// aten::_is_zerotensor(Tensor self) -> bool
+inline bool Tensor::__dispatch__is_zerotensor() const {
+    return at::_ops::_is_zerotensor::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_neg(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_neg() const {
+    return at::_ops::is_neg::call(const_cast<Tensor&>(*this));
+}
+
+// aten::isreal(Tensor self) -> Tensor
+inline at::Tensor Tensor::isreal() const {
+    return at::_ops::isreal::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_nonzero(Tensor self) -> bool
+inline bool Tensor::is_nonzero() const {
+    return at::_ops::is_nonzero::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_same_size(Tensor self, Tensor other) -> bool
+inline bool Tensor::is_same_size(const at::Tensor & other) const {
+    return at::_ops::is_same_size::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::is_signed(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_signed() const {
+    return at::_ops::is_signed::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_inference(Tensor self) -> bool
+inline bool Tensor::__dispatch_is_inference() const {
+    return at::_ops::is_inference::call(const_cast<Tensor&>(*this));
+}
+
+// aten::kron(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::kron(const at::Tensor & other) const {
+    return at::_ops::kron::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const {
+    return at::_ops::kthvalue::call(const_cast<Tensor&>(*this), k, dim, keepdim);
+}
+
+// aten::kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::kthvalue(int64_t k, at::Dimname dim, bool keepdim) const {
+    return at::_ops::kthvalue_dimname::call(const_cast<Tensor&>(*this), k, dim, keepdim);
+}
+
+// aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+inline at::Tensor Tensor::nan_to_num(c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf) const {
+    return at::_ops::nan_to_num::call(const_cast<Tensor&>(*this), nan, posinf, neginf);
+}
+
+// aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)
+inline at::Tensor & Tensor::nan_to_num_(c10::optional<double> nan, c10::optional<double> posinf, c10::optional<double> neginf) const {
+    return at::_ops::nan_to_num_::call(const_cast<Tensor&>(*this), nan, posinf, neginf);
+}
+
+// aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::ldexp(const at::Tensor & other) const {
+    return at::_ops::ldexp_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::ldexp_(const at::Tensor & other) const {
+    return at::_ops::ldexp_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::log(Tensor self) -> Tensor
+inline at::Tensor Tensor::log() const {
+    return at::_ops::log::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::log_() const {
+    return at::_ops::log_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log10(Tensor self) -> Tensor
+inline at::Tensor Tensor::log10() const {
+    return at::_ops::log10::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log10_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::log10_() const {
+    return at::_ops::log10_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log1p(Tensor self) -> Tensor
+inline at::Tensor Tensor::log1p() const {
+    return at::_ops::log1p::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log1p_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::log1p_() const {
+    return at::_ops::log1p_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log2(Tensor self) -> Tensor
+inline at::Tensor Tensor::log2() const {
+    return at::_ops::log2::call(const_cast<Tensor&>(*this));
+}
+
+// aten::log2_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::log2_() const {
+    return at::_ops::log2_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::logaddexp(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::logaddexp(const at::Tensor & other) const {
+    return at::_ops::logaddexp::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::logaddexp2(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::logaddexp2(const at::Tensor & other) const {
+    return at::_ops::logaddexp2::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::xlogy(const at::Tensor & other) const {
+    return at::_ops::xlogy_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::xlogy(const at::Scalar & other) const {
+    return at::_ops::xlogy_Scalar_Other::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::xlogy_(const at::Tensor & other) const {
+    return at::_ops::xlogy__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::xlogy_(const at::Scalar & other) const {
+    return at::_ops::xlogy__Scalar_Other::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::log_softmax(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::log_softmax_int::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::log_softmax(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::log_softmax_Dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::logcumsumexp(Tensor self, int dim) -> Tensor
+inline at::Tensor Tensor::logcumsumexp(int64_t dim) const {
+    return at::_ops::logcumsumexp::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor
+inline at::Tensor Tensor::logcumsumexp(at::Dimname dim) const {
+    return at::_ops::logcumsumexp_dimname::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::logsumexp(at::IntArrayRef dim, bool keepdim) const {
+    return at::_ops::logsumexp::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::logsumexp(at::DimnameList dim, bool keepdim) const {
+    return at::_ops::logsumexp_names::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::matmul(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::matmul(const at::Tensor & other) const {
+    return at::_ops::matmul::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::matrix_power(Tensor self, int n) -> Tensor
+inline at::Tensor Tensor::matrix_power(int64_t n) const {
+    return at::_ops::matrix_power::call(const_cast<Tensor&>(*this), n);
+}
+
+// aten::matrix_exp(Tensor self) -> Tensor
+inline at::Tensor Tensor::matrix_exp() const {
+    return at::_ops::matrix_exp::call(const_cast<Tensor&>(*this));
+}
+
+// aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::aminmax(c10::optional<int64_t> dim, bool keepdim) const {
+    return at::_ops::aminmax::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::max(int64_t dim, bool keepdim) const {
+    return at::_ops::max_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::max(at::Dimname dim, bool keepdim) const {
+    return at::_ops::max_names_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::amax(at::IntArrayRef dim, bool keepdim) const {
+    return at::_ops::amax::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::mean(c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::mean::call(const_cast<Tensor&>(*this), dtype);
+}
+
+// aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::mean(at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::mean_dim::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::mean(at::DimnameList dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::mean_names_dim::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::nanmean(at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::nanmean::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::median(Tensor self) -> Tensor
+inline at::Tensor Tensor::median() const {
+    return at::_ops::median::call(const_cast<Tensor&>(*this));
+}
+
+// aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::median(int64_t dim, bool keepdim) const {
+    return at::_ops::median_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::median(at::Dimname dim, bool keepdim) const {
+    return at::_ops::median_names_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::nanmedian(Tensor self) -> Tensor
+inline at::Tensor Tensor::nanmedian() const {
+    return at::_ops::nanmedian::call(const_cast<Tensor&>(*this));
+}
+
+// aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::nanmedian(int64_t dim, bool keepdim) const {
+    return at::_ops::nanmedian_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::nanmedian(at::Dimname dim, bool keepdim) const {
+    return at::_ops::nanmedian_names_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::min(int64_t dim, bool keepdim) const {
+    return at::_ops::min_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::min(at::Dimname dim, bool keepdim) const {
+    return at::_ops::min_names_dim::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::amin(at::IntArrayRef dim, bool keepdim) const {
+    return at::_ops::amin::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::mm(Tensor self, Tensor mat2) -> Tensor
+inline at::Tensor Tensor::mm(const at::Tensor & mat2) const {
+    return at::_ops::mm::call(const_cast<Tensor&>(*this), mat2);
+}
+
+// aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::mode(int64_t dim, bool keepdim) const {
+    return at::_ops::mode::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::mode(at::Dimname dim, bool keepdim) const {
+    return at::_ops::mode_dimname::call(const_cast<Tensor&>(*this), dim, keepdim);
+}
+
+// aten::mul.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::mul(const at::Tensor & other) const {
+    return at::_ops::mul_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::mul_(const at::Tensor & other) const {
+    return at::_ops::mul__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::mul.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::mul(const at::Scalar & other) const {
+    return at::_ops::mul_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::mul_(const at::Scalar & other) const {
+    return at::_ops::mul__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::multiply(const at::Tensor & other) const {
+    return at::_ops::multiply_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::multiply_(const at::Tensor & other) const {
+    return at::_ops::multiply__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::multiply(const at::Scalar & other) const {
+    return at::_ops::multiply_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::multiply_(const at::Scalar & other) const {
+    return at::_ops::multiply__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::mv(Tensor self, Tensor vec) -> Tensor
+inline at::Tensor Tensor::mv(const at::Tensor & vec) const {
+    return at::_ops::mv::call(const_cast<Tensor&>(*this), vec);
+}
+
+// aten::mvlgamma(Tensor self, int p) -> Tensor
+inline at::Tensor Tensor::mvlgamma(int64_t p) const {
+    return at::_ops::mvlgamma::call(const_cast<Tensor&>(*this), p);
+}
+
+// aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)
+inline at::Tensor & Tensor::mvlgamma_(int64_t p) const {
+    return at::_ops::mvlgamma_::call(const_cast<Tensor&>(*this), p);
+}
+
+// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
+inline at::Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const {
+    return at::_ops::narrow_copy::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
+inline at::Tensor Tensor::narrow_copy_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const {
+    return at::_ops::narrow_copy::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+inline at::Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const {
+    return at::_ops::narrow::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+inline at::Tensor Tensor::narrow_symint(int64_t dim, c10::SymInt start, c10::SymInt length) const {
+    return at::_ops::narrow::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+inline at::Tensor Tensor::narrow(int64_t dim, const at::Tensor & start, int64_t length) const {
+    return at::_ops::narrow_Tensor::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+inline at::Tensor Tensor::narrow_symint(int64_t dim, const at::Tensor & start, c10::SymInt length) const {
+    return at::_ops::narrow_Tensor::call(const_cast<Tensor&>(*this), dim, start, length);
+}
+
+// aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)
+inline at::Tensor Tensor::permute(at::IntArrayRef dims) const {
+    return at::_ops::permute::call(const_cast<Tensor&>(*this), dims);
+}
+
+// aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+inline at::Tensor Tensor::movedim(at::IntArrayRef source, at::IntArrayRef destination) const {
+    return at::_ops::movedim_intlist::call(const_cast<Tensor&>(*this), source, destination);
+}
+
+// aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+inline at::Tensor Tensor::movedim(int64_t source, int64_t destination) const {
+    return at::_ops::movedim_int::call(const_cast<Tensor&>(*this), source, destination);
+}
+
+// aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+inline at::Tensor Tensor::moveaxis(at::IntArrayRef source, at::IntArrayRef destination) const {
+    return at::_ops::moveaxis_intlist::call(const_cast<Tensor&>(*this), source, destination);
+}
+
+// aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+inline at::Tensor Tensor::moveaxis(int64_t source, int64_t destination) const {
+    return at::_ops::moveaxis_int::call(const_cast<Tensor&>(*this), source, destination);
+}
+
+// aten::numpy_T(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::numpy_T() const {
+    return at::_ops::numpy_T::call(const_cast<Tensor&>(*this));
+}
+
+// aten::matrix_H(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::matrix_H() const {
+    return at::_ops::matrix_H::call(const_cast<Tensor&>(*this));
+}
+
+// aten::mT(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::mT() const {
+    return at::_ops::mT::call(const_cast<Tensor&>(*this));
+}
+
+// aten::mH(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::mH() const {
+    return at::_ops::mH::call(const_cast<Tensor&>(*this));
+}
+
+// aten::adjoint(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::adjoint() const {
+    return at::_ops::adjoint::call(const_cast<Tensor&>(*this));
+}
+
+// aten::is_pinned(Tensor self, Device? device=None) -> bool
+inline bool Tensor::is_pinned(c10::optional<at::Device> device) const {
+    return at::_ops::is_pinned::call(const_cast<Tensor&>(*this), device);
+}
+
+// aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
+inline at::Tensor Tensor::pin_memory(c10::optional<at::Device> device) const {
+    return at::_ops::pin_memory::call(const_cast<Tensor&>(*this), device);
+}
+
+// aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor
+inline at::Tensor Tensor::pinverse(double rcond) const {
+    return at::_ops::pinverse::call(const_cast<Tensor&>(*this), rcond);
+}
+
+// aten::rad2deg(Tensor self) -> Tensor
+inline at::Tensor Tensor::rad2deg() const {
+    return at::_ops::rad2deg::call(const_cast(*this));
+}
+
+// aten::rad2deg_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::rad2deg_() const {
+    return at::_ops::rad2deg_::call(const_cast(*this));
+}
+
+// aten::deg2rad(Tensor self) -> Tensor
+inline at::Tensor Tensor::deg2rad() const {
+    return at::_ops::deg2rad::call(const_cast(*this));
+}
+
+// aten::deg2rad_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::deg2rad_() const {
+    return at::_ops::deg2rad_::call(const_cast(*this));
+}
+
+// aten::ravel(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::ravel() const {
+    return at::_ops::ravel::call(const_cast(*this));
+}
+
+// aten::reciprocal(Tensor self) -> Tensor
+inline at::Tensor Tensor::reciprocal() const {
+    return at::_ops::reciprocal::call(const_cast(*this));
+}
+
+// aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::reciprocal_() const {
+    return at::_ops::reciprocal_::call(const_cast(*this));
+}
+
+// aten::neg(Tensor self) -> Tensor
+inline at::Tensor Tensor::neg() const {
+    return at::_ops::neg::call(const_cast(*this));
+}
+
+// aten::neg_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::neg_() const {
+    return at::_ops::neg_::call(const_cast(*this));
+}
+
+// aten::negative(Tensor self) -> Tensor
+inline at::Tensor Tensor::negative() const {
+    return at::_ops::negative::call(const_cast(*this));
+}
+
+// aten::negative_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::negative_() const {
+    return at::_ops::negative_::call(const_cast(*this));
+}
+
+// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor
+inline at::Tensor Tensor::repeat(at::IntArrayRef repeats) const {
+    return at::_ops::repeat::call(const_cast(*this), c10::fromIntArrayRefSlow(repeats));
+}
+
+// aten::repeat(Tensor self, SymInt[] repeats) -> Tensor
+inline at::Tensor Tensor::repeat_symint(c10::SymIntArrayRef repeats) const {
+    return at::_ops::repeat::call(const_cast(*this), repeats);
+}
+
+// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+inline at::Tensor Tensor::repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> dim, c10::optional<int64_t> output_size) const {
+    return at::_ops::repeat_interleave_self_Tensor::call(const_cast<Tensor&>(*this), repeats, dim, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt);
+}
+
+// aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+inline at::Tensor Tensor::repeat_interleave_symint(const at::Tensor & repeats, c10::optional<int64_t> dim, c10::optional<c10::SymInt> output_size) const {
+    return at::_ops::repeat_interleave_self_Tensor::call(const_cast<Tensor&>(*this), repeats, dim, output_size);
+}
+
+// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+inline at::Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional<int64_t> dim, c10::optional<int64_t> output_size) const {
+    return at::_ops::repeat_interleave_self_int::call(const_cast<Tensor&>(*this), repeats, dim, output_size.has_value() ? c10::make_optional(c10::SymInt(*output_size)) : c10::nullopt);
+}
+
+// aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+inline at::Tensor Tensor::repeat_interleave_symint(c10::SymInt repeats, c10::optional<int64_t> dim, c10::optional<c10::SymInt> output_size) const {
+    return at::_ops::repeat_interleave_self_int::call(const_cast<Tensor&>(*this), repeats, dim, output_size);
+}
+
+// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+inline at::Tensor Tensor::reshape(at::IntArrayRef shape) const {
+    return at::_ops::reshape::call(const_cast(*this), c10::fromIntArrayRefSlow(shape));
+}
+
+// aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+inline at::Tensor Tensor::reshape_symint(c10::SymIntArrayRef shape) const {
+    return at::_ops::reshape::call(const_cast(*this), shape);
+}
+
+// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+inline at::Tensor Tensor::_reshape_alias(at::IntArrayRef size, at::IntArrayRef stride) const {
+    return at::_ops::_reshape_alias::call(const_cast(*this), c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+}
+
+// aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+inline at::Tensor Tensor::_reshape_alias_symint(c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const {
+    return at::_ops::_reshape_alias::call(const_cast(*this), size, stride);
+}
+
+// aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)
+inline at::Tensor Tensor::reshape_as(const at::Tensor & other) const {
+    return at::_ops::reshape_as::call(const_cast(*this), other);
+}
+
+// aten::round(Tensor self) -> Tensor
+inline at::Tensor Tensor::round() const {
+    return at::_ops::round::call(const_cast(*this));
+}
+
+// aten::round_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::round_() const {
+    return at::_ops::round_::call(const_cast(*this));
+}
+
+// aten::round.decimals(Tensor self, *, int decimals) -> Tensor
+inline at::Tensor Tensor::round(int64_t decimals) const {
+    return at::_ops::round_decimals::call(const_cast(*this), decimals);
+}
+
+// aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!)
+inline at::Tensor & Tensor::round_(int64_t decimals) const {
+    return at::_ops::round__decimals::call(const_cast(*this), decimals);
+}
+
+// aten::relu(Tensor self) -> Tensor
+inline at::Tensor Tensor::relu() const {
+    return at::_ops::relu::call(const_cast(*this));
+}
+
+// aten::relu_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::relu_() const {
+    return at::_ops::relu_::call(const_cast(*this));
+}
+
+// aten::prelu(Tensor self, Tensor weight) -> Tensor
+inline at::Tensor Tensor::prelu(const at::Tensor & weight) const {
+    return at::_ops::prelu::call(const_cast(*this), weight);
+}
+
+// aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+inline at::Tensor Tensor::hardshrink(const at::Scalar & lambd) const {
+    return at::_ops::hardshrink::call(const_cast(*this), lambd);
+}
+
+// aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
+inline at::Tensor Tensor::hardshrink_backward(const at::Tensor & grad_out, const at::Scalar & lambd) const {
+    return at::_ops::hardshrink_backward::call(grad_out, const_cast(*this), lambd);
+}
+
+// aten::rsqrt(Tensor self) -> Tensor
+inline at::Tensor Tensor::rsqrt() const {
+    return at::_ops::rsqrt::call(const_cast(*this));
+}
+
+// aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::rsqrt_() const {
+    return at::_ops::rsqrt_::call(const_cast(*this));
+}
+
+// aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)
+inline at::Tensor Tensor::select(at::Dimname dim, int64_t index) const {
+    return at::_ops::select_Dimname::call(const_cast(*this), dim, index);
+}
+
+// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+inline at::Tensor Tensor::select(int64_t dim, int64_t index) const {
+    return at::_ops::select_int::call(const_cast(*this), dim, index);
+}
+
+// aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+inline at::Tensor Tensor::select_symint(int64_t dim, c10::SymInt index) const {
+    return at::_ops::select_int::call(const_cast(*this), dim, index);
+}
+
+// aten::sigmoid(Tensor self) -> Tensor
+inline at::Tensor Tensor::sigmoid() const {
+    return at::_ops::sigmoid::call(const_cast(*this));
+}
+
+// aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sigmoid_() const {
+    return at::_ops::sigmoid_::call(const_cast(*this));
+}
+
+// aten::logit(Tensor self, float? eps=None) -> Tensor
+inline at::Tensor Tensor::logit(c10::optional<double> eps) const {
+    return at::_ops::logit::call(const_cast<Tensor&>(*this), eps);
+}
+
+// aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
+inline at::Tensor & Tensor::logit_(c10::optional<double> eps) const {
+    return at::_ops::logit_::call(const_cast<Tensor&>(*this), eps);
+}
+
+// aten::sin(Tensor self) -> Tensor
+inline at::Tensor Tensor::sin() const {
+    return at::_ops::sin::call(const_cast(*this));
+}
+
+// aten::sin_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sin_() const {
+    return at::_ops::sin_::call(const_cast(*this));
+}
+
+// aten::sinc(Tensor self) -> Tensor
+inline at::Tensor Tensor::sinc() const {
+    return at::_ops::sinc::call(const_cast(*this));
+}
+
+// aten::sinc_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sinc_() const {
+    return at::_ops::sinc_::call(const_cast(*this));
+}
+
+// aten::sinh(Tensor self) -> Tensor
+inline at::Tensor Tensor::sinh() const {
+    return at::_ops::sinh::call(const_cast(*this));
+}
+
+// aten::sinh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sinh_() const {
+    return at::_ops::sinh_::call(const_cast(*this));
+}
+
+// aten::detach(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::detach() const {
+    return at::_ops::detach::call(const_cast(*this));
+}
+
+// aten::detach_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::detach_() const {
+    return at::_ops::detach_::call(const_cast(*this));
+}
+
+// aten::size.Dimname(Tensor self, Dimname dim) -> int
+inline int64_t Tensor::size(at::Dimname dim) const {
+    return at::_ops::size_Dimname::call(const_cast(*this), dim);
+}
+
+// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+inline at::Tensor Tensor::slice(int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) const {
+    return at::_ops::slice_Tensor::call(const_cast<Tensor&>(*this), dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+}
+
+// aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+inline at::Tensor Tensor::slice_symint(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) const {
+    return at::_ops::slice_Tensor::call(const_cast<Tensor&>(*this), dim, start, end, step);
+}
+
+// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+inline at::Tensor Tensor::slice_inverse(const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) const {
+    return at::_ops::slice_inverse::call(const_cast<Tensor&>(*this), src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+}
+
+// aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+inline at::Tensor Tensor::slice_inverse_symint(const at::Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) const {
+    return at::_ops::slice_inverse::call(const_cast<Tensor&>(*this), src, dim, start, end, step);
+}
+
+// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+inline at::Tensor Tensor::slice_scatter(const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) const {
+    return at::_ops::slice_scatter::call(const_cast<Tensor&>(*this), src, dim, start.has_value() ? c10::make_optional(c10::SymInt(*start)) : c10::nullopt, end.has_value() ? c10::make_optional(c10::SymInt(*end)) : c10::nullopt, step);
+}
+
+// aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+inline at::Tensor Tensor::slice_scatter_symint(const at::Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) const {
+    return at::_ops::slice_scatter::call(const_cast<Tensor&>(*this), src, dim, start, end, step);
+}
+
+// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+inline at::Tensor Tensor::select_scatter(const at::Tensor & src, int64_t dim, int64_t index) const {
+    return at::_ops::select_scatter::call(const_cast<Tensor&>(*this), src, dim, index);
+}
+
+// aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+inline at::Tensor Tensor::select_scatter_symint(const at::Tensor & src, int64_t dim, c10::SymInt index) const {
+    return at::_ops::select_scatter::call(const_cast<Tensor&>(*this), src, dim, index);
+}
+
+// aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
+inline at::Tensor Tensor::diagonal_scatter(const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2) const {
+    return at::_ops::diagonal_scatter::call(const_cast<Tensor&>(*this), src, offset, dim1, dim2);
+}
+
+// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+inline at::Tensor Tensor::as_strided_scatter(const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) const {
+    return at::_ops::as_strided_scatter::call(const_cast<Tensor&>(*this), src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
+}
+
+// aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+inline at::Tensor Tensor::as_strided_scatter_symint(const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) const {
+    return at::_ops::as_strided_scatter::call(const_cast<Tensor&>(*this), src, size, stride, storage_offset);
+}
+
+// aten::smm(Tensor self, Tensor mat2) -> Tensor
+inline at::Tensor Tensor::smm(const at::Tensor & mat2) const {
+    return at::_ops::smm::call(const_cast(*this), mat2);
+}
+
+// aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::softmax(int64_t dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::softmax_int::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::softmax(at::Dimname dim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::softmax_Dimname::call(const_cast<Tensor&>(*this), dim, dtype);
+}
+
+// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::unsafe_split(int64_t split_size, int64_t dim) const {
+    return at::_ops::unsafe_split_Tensor::call(const_cast<Tensor&>(*this), split_size, dim);
+}
+
+// aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::unsafe_split_symint(c10::SymInt split_size, int64_t dim) const {
+    return at::_ops::unsafe_split_Tensor::call(const_cast<Tensor&>(*this), split_size, dim);
+}
+
+// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split(int64_t split_size, int64_t dim) const {
+    return at::_ops::split_Tensor::call(const_cast<Tensor&>(*this), split_size, dim);
+}
+
+// aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split_symint(c10::SymInt split_size, int64_t dim) const {
+    return at::_ops::split_Tensor::call(const_cast<Tensor&>(*this), split_size, dim);
+}
+
+// aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split(at::IntArrayRef split_size, int64_t dim) const {
+    return at::_ops::split_sizes::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(split_size), dim);
+}
+
+// aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split_symint(c10::SymIntArrayRef split_size, int64_t dim) const {
+    return at::_ops::split_sizes::call(const_cast<Tensor&>(*this), split_size, dim);
+}
+
+// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::unsafe_split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const {
+    return at::_ops::unsafe_split_with_sizes::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(split_sizes), dim);
+}
+
+// aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::unsafe_split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const {
+    return at::_ops::unsafe_split_with_sizes::call(const_cast<Tensor&>(*this), split_sizes, dim);
+}
+
+// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split_with_sizes(at::IntArrayRef split_sizes, int64_t dim) const {
+    return at::_ops::split_with_sizes::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(split_sizes), dim);
+}
+
+// aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::split_with_sizes_symint(c10::SymIntArrayRef split_sizes, int64_t dim) const {
+    return at::_ops::split_with_sizes::call(const_cast<Tensor&>(*this), split_sizes, dim);
+}
+
+// aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::hsplit(int64_t sections) const {
+    return at::_ops::hsplit_int::call(const_cast<Tensor&>(*this), sections);
+}
+
+// aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::hsplit(at::IntArrayRef indices) const {
+    return at::_ops::hsplit_array::call(const_cast<Tensor&>(*this), indices);
+}
+
+// aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::vsplit(int64_t sections) const {
+    return at::_ops::vsplit_int::call(const_cast<Tensor&>(*this), sections);
+}
+
+// aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::vsplit(at::IntArrayRef indices) const {
+    return at::_ops::vsplit_array::call(const_cast<Tensor&>(*this), indices);
+}
+
+// aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::dsplit(int64_t sections) const {
+    return at::_ops::dsplit_int::call(const_cast<Tensor&>(*this), sections);
+}
+
+// aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::dsplit(at::IntArrayRef indices) const {
+    return at::_ops::dsplit_array::call(const_cast<Tensor&>(*this), indices);
+}
+
+// aten::squeeze(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::squeeze() const {
+    return at::_ops::squeeze::call(const_cast(*this));
+}
+
+// aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
+inline at::Tensor Tensor::squeeze(int64_t dim) const {
+    return at::_ops::squeeze_dim::call(const_cast(*this), dim);
+}
+
+// aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)
+inline at::Tensor Tensor::squeeze(at::Dimname dim) const {
+    return at::_ops::squeeze_dimname::call(const_cast(*this), dim);
+}
+
+// aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+inline at::Tensor Tensor::squeeze(at::IntArrayRef dim) const {
+    return at::_ops::squeeze_dims::call(const_cast(*this), dim);
+}
+
+// aten::squeeze_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::squeeze_() const {
+    return at::_ops::squeeze_::call(const_cast(*this));
+}
+
+// aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
+inline at::Tensor & Tensor::squeeze_(int64_t dim) const {
+    return at::_ops::squeeze__dim::call(const_cast(*this), dim);
+}
+
+// aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+inline at::Tensor & Tensor::squeeze_(at::IntArrayRef dim) const {
+    return at::_ops::squeeze__dims::call(const_cast(*this), dim);
+}
+
+// aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)
+inline at::Tensor & Tensor::squeeze_(at::Dimname dim) const {
+    return at::_ops::squeeze__dimname::call(const_cast(*this), dim);
+}
+
+// aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::sspaddmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::sspaddmm::call(const_cast(*this), mat1, mat2, beta, alpha);
+}
+
+// aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+inline at::Tensor Tensor::stft(int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<at::Tensor> & window, bool normalized, c10::optional<bool> onesided, c10::optional<bool> return_complex) const {
+    return at::_ops::stft::call(const_cast<Tensor&>(*this), n_fft, hop_length, win_length, window, normalized, onesided, return_complex);
+}
+
+// aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+inline at::Tensor Tensor::stft(int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<at::Tensor> & window, bool center, c10::string_view pad_mode, bool normalized, c10::optional<bool> onesided, c10::optional<bool> return_complex) const {
+    return at::_ops::stft_center::call(const_cast<Tensor&>(*this), n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex);
+}
+
+// aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
+inline at::Tensor Tensor::istft(int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const c10::optional<at::Tensor> & window, bool center, bool normalized, c10::optional<bool> onesided, c10::optional<int64_t> length, bool return_complex) const {
+    return at::_ops::istft::call(const_cast<Tensor&>(*this), n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex);
+}
+
+// aten::stride.Dimname(Tensor self, Dimname dim) -> int
+inline int64_t Tensor::stride(at::Dimname dim) const {
+    return at::_ops::stride_Dimname::call(const_cast(*this), dim);
+}
+
+// aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::sum(c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::sum::call(const_cast<Tensor&>(*this), dtype);
+}
+
+// aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::sum(at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::sum_dim_IntList::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::sum(at::DimnameList dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::sum_dim_DimnameList::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::nansum(at::OptionalIntArrayRef dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::nansum::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor
+inline at::Tensor Tensor::sum_to_size(at::IntArrayRef size) const {
+    return at::_ops::sum_to_size::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size));
+}
+
+// aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor
+inline at::Tensor Tensor::sum_to_size_symint(c10::SymIntArrayRef size) const {
+    return at::_ops::sum_to_size::call(const_cast<Tensor&>(*this), size);
+}
+
+// aten::sqrt(Tensor self) -> Tensor
+inline at::Tensor Tensor::sqrt() const {
+    return at::_ops::sqrt::call(const_cast(*this));
+}
+
+// aten::sqrt_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sqrt_() const {
+    return at::_ops::sqrt_::call(const_cast(*this));
+}
+
+// aten::square(Tensor self) -> Tensor
+inline at::Tensor Tensor::square() const {
+    return at::_ops::square::call(const_cast(*this));
+}
+
+// aten::square_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::square_() const {
+    return at::_ops::square_::call(const_cast(*this));
+}
+
+// aten::std(Tensor self, bool unbiased=True) -> Tensor
+inline at::Tensor Tensor::std(bool unbiased) const {
+    return at::_ops::std::call(const_cast<Tensor&>(*this), unbiased);
+}
+
+// aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const {
+    return at::_ops::std_dim::call(const_cast<Tensor&>(*this), dim, unbiased, keepdim);
+}
+
+// aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::std(at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim) const {
+    return at::_ops::std_correction::call(const_cast<Tensor&>(*this), dim, correction, keepdim);
+}
+
+// aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::std(at::DimnameList dim, bool unbiased, bool keepdim) const {
+    return at::_ops::std_names_dim::call(const_cast<Tensor&>(*this), dim, unbiased, keepdim);
+}
+
+// aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::std(at::DimnameList dim, const c10::optional<at::Scalar> & correction, bool keepdim) const {
+    return at::_ops::std_correction_names::call(const_cast<Tensor&>(*this), dim, correction, keepdim);
+}
+
+// aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::prod(c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::prod::call(const_cast<Tensor&>(*this), dtype);
+}
+
+// aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::prod_dim_int::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::prod(at::Dimname dim, bool keepdim, c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::prod_dim_Dimname::call(const_cast<Tensor&>(*this), dim, keepdim, dtype);
+}
+
+// aten::t(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::t() const {
+    return at::_ops::t::call(const_cast(*this));
+}
+
+// aten::t_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::t_() const {
+    return at::_ops::t_::call(const_cast(*this));
+}
+
+// aten::tan(Tensor self) -> Tensor
+inline at::Tensor Tensor::tan() const {
+    return at::_ops::tan::call(const_cast(*this));
+}
+
+// aten::tan_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::tan_() const {
+    return at::_ops::tan_::call(const_cast(*this));
+}
+
+// aten::tanh(Tensor self) -> Tensor
+inline at::Tensor Tensor::tanh() const {
+    return at::_ops::tanh::call(const_cast(*this));
+}
+
+// aten::tanh_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::tanh_() const {
+    return at::_ops::tanh_::call(const_cast(*this));
+}
+
+// aten::tile(Tensor self, SymInt[] dims) -> Tensor
+inline at::Tensor Tensor::tile(at::IntArrayRef dims) const {
+    return at::_ops::tile::call(const_cast(*this), c10::fromIntArrayRefSlow(dims));
+}
+
+// aten::tile(Tensor self, SymInt[] dims) -> Tensor
+inline at::Tensor Tensor::tile_symint(c10::SymIntArrayRef dims) const {
+    return at::_ops::tile::call(const_cast(*this), dims);
+}
+
+// aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+inline at::Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const {
+    return at::_ops::transpose_int::call(const_cast(*this), dim0, dim1);
+}
+
+// aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)
+inline at::Tensor Tensor::transpose(at::Dimname dim0, at::Dimname dim1) const {
+    return at::_ops::transpose_Dimname::call(const_cast(*this), dim0, dim1);
+}
+
+// aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+inline at::Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const {
+    return at::_ops::transpose_::call(const_cast(*this), dim0, dim1);
+}
+
+// aten::flip(Tensor self, int[] dims) -> Tensor
+inline at::Tensor Tensor::flip(at::IntArrayRef dims) const {
+    return at::_ops::flip::call(const_cast(*this), dims);
+}
+
+// aten::fliplr(Tensor self) -> Tensor
+inline at::Tensor Tensor::fliplr() const {
+    return at::_ops::fliplr::call(const_cast(*this));
+}
+
+// aten::flipud(Tensor self) -> Tensor
+inline at::Tensor Tensor::flipud() const {
+    return at::_ops::flipud::call(const_cast(*this));
+}
+
+// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+inline at::Tensor Tensor::roll(at::IntArrayRef shifts, at::IntArrayRef dims) const {
+    return at::_ops::roll::call(const_cast(*this), c10::fromIntArrayRefSlow(shifts), dims);
+}
+
+// aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+inline at::Tensor Tensor::roll_symint(c10::SymIntArrayRef shifts, at::IntArrayRef dims) const {
+    return at::_ops::roll::call(const_cast(*this), shifts, dims);
+}
+
+// aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
+inline at::Tensor Tensor::rot90(int64_t k, at::IntArrayRef dims) const {
+    return at::_ops::rot90::call(const_cast(*this), k, dims);
+}
+
+// aten::_nested_tensor_size(Tensor self) -> Tensor
+inline at::Tensor Tensor::_nested_tensor_size() const {
+    return at::_ops::_nested_tensor_size::call(const_cast(*this));
+}
+
+// aten::_nested_tensor_strides(Tensor self) -> Tensor
+inline at::Tensor Tensor::_nested_tensor_strides() const {
+    return at::_ops::_nested_tensor_strides::call(const_cast(*this));
+}
+
+// aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor
+inline at::Tensor Tensor::_nested_tensor_storage_offsets() const {
+    return at::_ops::_nested_tensor_storage_offsets::call(const_cast(*this));
+}
+
+// aten::trunc(Tensor self) -> Tensor
+inline at::Tensor Tensor::trunc() const {
+    return at::_ops::trunc::call(const_cast(*this));
+}
+
+// aten::trunc_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::trunc_() const {
+    return at::_ops::trunc_::call(const_cast(*this));
+}
+
+// aten::fix(Tensor self) -> Tensor
+inline at::Tensor Tensor::fix() const {
+    return at::_ops::fix::call(const_cast(*this));
+}
+
+// aten::fix_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::fix_() const {
+    return at::_ops::fix_::call(const_cast(*this));
+}
+
+// aten::type_as(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::type_as(const at::Tensor & other) const {
+    return at::_ops::type_as::call(const_cast(*this), other);
+}
+
+// aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+inline at::Tensor Tensor::unsqueeze(int64_t dim) const {
+    return at::_ops::unsqueeze::call(const_cast(*this), dim);
+}
+
+// aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
+inline at::Tensor & Tensor::unsqueeze_(int64_t dim) const {
+    return at::_ops::unsqueeze_::call(const_cast(*this), dim);
+}
+
+// aten::var(Tensor self, bool unbiased=True) -> Tensor
+inline at::Tensor Tensor::var(bool unbiased) const {
+    return at::_ops::var::call(const_cast<Tensor&>(*this), unbiased);
+}
+
+// aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) const {
+    return at::_ops::var_dim::call(const_cast<Tensor&>(*this), dim, unbiased, keepdim);
+}
+
+// aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::var(at::OptionalIntArrayRef dim, const c10::optional<at::Scalar> & correction, bool keepdim) const {
+    return at::_ops::var_correction::call(const_cast<Tensor&>(*this), dim, correction, keepdim);
+}
+
+// aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::var(at::DimnameList dim, bool unbiased, bool keepdim) const {
+    return at::_ops::var_names_dim::call(const_cast<Tensor&>(*this), dim, unbiased, keepdim);
+}
+
+// aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::var(at::DimnameList dim, const c10::optional<at::Scalar> & correction, bool keepdim) const {
+    return at::_ops::var_correction_names::call(const_cast<Tensor&>(*this), dim, correction, keepdim);
+}
+
+// aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)
+inline at::Tensor Tensor::view_as(const at::Tensor & other) const {
+    return at::_ops::view_as::call(const_cast(*this), other);
+}
+
+// aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Tensor & other) const {
+    return at::_ops::where_self::call(condition, const_cast(*this), other);
+}
+
+// aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::where(const at::Tensor & condition, const at::Scalar & other) const {
+    return at::_ops::where_ScalarOther::call(condition, const_cast(*this), other);
+}
+
+// aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
+inline at::Tensor Tensor::norm(const c10::optional<at::Scalar> & p, at::ScalarType dtype) const {
+    return at::_ops::norm_ScalarOpt_dtype::call(const_cast<Tensor&>(*this), p, dtype);
+}
+
+// aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor
+inline at::Tensor Tensor::norm(const at::Scalar & p) const {
+    return at::_ops::norm_Scalar::call(const_cast<Tensor&>(*this), p);
+}
+
+// aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+inline at::Tensor Tensor::norm(const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) const {
+    return at::_ops::norm_ScalarOpt_dim_dtype::call(const_cast<Tensor&>(*this), p, dim, keepdim, dtype);
+}
+
+// aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::norm(const c10::optional<at::Scalar> & p, at::IntArrayRef dim, bool keepdim) const {
+    return at::_ops::norm_ScalarOpt_dim::call(const_cast<Tensor&>(*this), p, dim, keepdim);
+}
+
+// aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+inline at::Tensor Tensor::norm(const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) const {
+    return at::_ops::norm_names_ScalarOpt_dim_dtype::call(const_cast<Tensor&>(*this), p, dim, keepdim, dtype);
+}
+
+// aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
+inline at::Tensor Tensor::norm(const c10::optional<at::Scalar> & p, at::DimnameList dim, bool keepdim) const {
+    return at::_ops::norm_names_ScalarOpt_dim::call(const_cast<Tensor&>(*this), p, dim, keepdim);
+}
+
+// aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::frexp() const {
+    return at::_ops::frexp_Tensor::call(const_cast<Tensor&>(*this));
+}
+
+// aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+inline at::Tensor Tensor::clone(c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::clone::call(const_cast<Tensor&>(*this), memory_format);
+}
+
+// aten::positive(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::positive() const {
+    return at::_ops::positive::call(const_cast(*this));
+}
+
+// aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+inline const at::Tensor & Tensor::resize_as_(const at::Tensor & the_template, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::resize_as_::call(const_cast<Tensor&>(*this), the_template, memory_format);
+}
+
+// aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
+inline const at::Tensor & Tensor::resize_as_sparse_(const at::Tensor & the_template) const {
+    return at::_ops::resize_as_sparse_::call(const_cast(*this), the_template);
+}
+
+// aten::zero_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::zero_() const {
+    return at::_ops::zero_::call(const_cast(*this));
+}
+
+// aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::sub(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::sub_Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::sub_(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::sub__Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::sub(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::sub_Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::sub_(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::sub__Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::subtract(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract_Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::subtract_(const at::Tensor & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract__Tensor::call(const_cast(*this), other, alpha);
+}
+
+// aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::subtract(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract_Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::subtract_(const at::Scalar & other, const at::Scalar & alpha) const {
+    return at::_ops::subtract__Scalar::call(const_cast(*this), other, alpha);
+}
+
+// aten::heaviside(Tensor self, Tensor values) -> Tensor
+inline at::Tensor Tensor::heaviside(const at::Tensor & values) const {
+    return at::_ops::heaviside::call(const_cast(*this), values);
+}
+
+// aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+inline at::Tensor & Tensor::heaviside_(const at::Tensor & values) const {
+    return at::_ops::heaviside_::call(const_cast(*this), values);
+}
+
+// aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::addmm(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmm::call(const_cast(*this), mat1, mat2, beta, alpha);
+}
+
+// aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addmm_(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addmm_::call(const_cast(*this), mat1, mat2, beta, alpha);
+}
+
+// aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
+inline at::Tensor Tensor::_addmm_activation(const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu) const {
+    return at::_ops::_addmm_activation::call(const_cast(*this), mat1, mat2, beta, alpha, use_gelu);
+}
+
+// aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+inline const at::Tensor & Tensor::sparse_resize_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const {
+    return at::_ops::sparse_resize_::call(const_cast(*this), size, sparse_dim, dense_dim);
+}
+
+// aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+inline const at::Tensor & Tensor::sparse_resize_and_clear_(at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const {
+    return at::_ops::sparse_resize_and_clear_::call(const_cast(*this), size, sparse_dim, dense_dim);
+}
+
+// aten::sparse_mask(Tensor self, Tensor mask) -> Tensor
+inline at::Tensor Tensor::sparse_mask(const at::Tensor & mask) const {
+    return at::_ops::sparse_mask::call(const_cast(*this), mask);
+}
+
+// aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
+inline at::Tensor Tensor::_sparse_mask_projection(const at::Tensor & mask, bool accumulate_matches) const {
+    return at::_ops::_sparse_mask_projection::call(const_cast(*this), mask, accumulate_matches);
+}
+
+// aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
+inline at::Tensor Tensor::to_dense(c10::optional<at::ScalarType> dtype, c10::optional<bool> masked_grad) const {
+    return at::_ops::to_dense::call(const_cast<Tensor&>(*this), dtype, masked_grad);
+}
+
+// aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
+inline at::Tensor Tensor::_to_dense(c10::optional<at::ScalarType> dtype, c10::optional<bool> masked_grad) const {
+    return at::_ops::_to_dense::call(const_cast<Tensor&>(*this), dtype, masked_grad);
+}
+
+// aten::sparse_dim(Tensor self) -> int
+inline int64_t Tensor::sparse_dim() const {
+    return at::_ops::sparse_dim::call(const_cast(*this));
+}
+
+// aten::_dimI(Tensor self) -> int
+inline int64_t Tensor::_dimI() const {
+    return at::_ops::_dimI::call(const_cast(*this));
+}
+
+// aten::dense_dim(Tensor self) -> int
+inline int64_t Tensor::dense_dim() const {
+    return at::_ops::dense_dim::call(const_cast(*this));
+}
+
+// aten::_dimV(Tensor self) -> int
+inline int64_t Tensor::_dimV() const {
+    return at::_ops::_dimV::call(const_cast(*this));
+}
+
+// aten::_nnz(Tensor self) -> int
+inline int64_t Tensor::_nnz() const {
+    return at::_ops::_nnz::call(const_cast(*this));
+}
+
+// aten::coalesce(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::coalesce() const {
+    return at::_ops::coalesce::call(const_cast(*this));
+}
+
+// aten::is_coalesced(Tensor self) -> bool
+inline bool Tensor::is_coalesced() const {
+    return at::_ops::is_coalesced::call(const_cast(*this));
+}
+
+// aten::_indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::_indices() const {
+    return at::_ops::_indices::call(const_cast(*this));
+}
+
+// aten::_values(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::_values() const {
+    return at::_ops::_values::call(const_cast(*this));
+}
+
+// aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
+inline at::Tensor & Tensor::_coalesced_(bool coalesced) const {
+    return at::_ops::_coalesced_::call(const_cast(*this), coalesced);
+}
+
+// aten::indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::indices() const {
+    return at::_ops::indices::call(const_cast(*this));
+}
+
+// aten::values(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::values() const {
+    return at::_ops::values::call(const_cast(*this));
+}
+
+// aten::crow_indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::crow_indices() const {
+    return at::_ops::crow_indices::call(const_cast(*this));
+}
+
+// aten::col_indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::col_indices() const {
+    return at::_ops::col_indices::call(const_cast(*this));
+}
+
+// aten::ccol_indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::ccol_indices() const {
+    return at::_ops::ccol_indices::call(const_cast(*this));
+}
+
+// aten::row_indices(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::row_indices() const {
+    return at::_ops::row_indices::call(const_cast(*this));
+}
+
+// aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::unbind(int64_t dim) const {
+    return at::_ops::unbind_int::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]
+inline ::std::vector<at::Tensor> Tensor::unbind(at::Dimname dim) const {
+    return at::_ops::unbind_Dimname::call(const_cast<Tensor&>(*this), dim);
+}
+
+// aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+inline at::Tensor Tensor::to_sparse(int64_t sparse_dim) const {
+    return at::_ops::to_sparse_sparse_dim::call(const_cast<Tensor&>(*this), sparse_dim);
+}
+
+// aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+inline at::Tensor Tensor::_to_sparse(int64_t sparse_dim) const {
+    return at::_ops::_to_sparse_sparse_dim::call(const_cast<Tensor&>(*this), sparse_dim);
+}
+
+// aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::to_sparse(c10::optional<at::Layout> layout, at::OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::to_sparse::call(const_cast<Tensor&>(*this), layout, blocksize, dense_dim);
+}
+
+// aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::_to_sparse(c10::optional<at::Layout> layout, at::OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::_to_sparse::call(const_cast<Tensor&>(*this), layout, blocksize, dense_dim);
+}
+
+// aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::to_sparse_csr(c10::optional<int64_t> dense_dim) const {
+    return at::_ops::to_sparse_csr::call(const_cast<Tensor&>(*this), dense_dim);
+}
+
+// aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::_to_sparse_csr(c10::optional<int64_t> dense_dim) const {
+    return at::_ops::_to_sparse_csr::call(const_cast<Tensor&>(*this), dense_dim);
+}
+
+// aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::to_sparse_csc(c10::optional<int64_t> dense_dim) const {
+    return at::_ops::to_sparse_csc::call(const_cast<Tensor&>(*this), dense_dim);
+}
+
+// aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::_to_sparse_csc(c10::optional<int64_t> dense_dim) const {
+    return at::_ops::_to_sparse_csc::call(const_cast<Tensor&>(*this), dense_dim);
+}
+
+// aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::to_sparse_bsr(at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::to_sparse_bsr::call(const_cast<Tensor&>(*this), blocksize, dense_dim);
+}
+
+// aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::_to_sparse_bsr(at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::_to_sparse_bsr::call(const_cast<Tensor&>(*this), blocksize, dense_dim);
+}
+
+// aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::to_sparse_bsc(at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::to_sparse_bsc::call(const_cast<Tensor&>(*this), blocksize, dense_dim);
+}
+
+// aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+inline at::Tensor Tensor::_to_sparse_bsc(at::IntArrayRef blocksize, c10::optional<int64_t> dense_dim) const {
+    return at::_ops::_to_sparse_bsc::call(const_cast<Tensor&>(*this), blocksize, dense_dim);
+}
+
+// aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
+inline at::Tensor Tensor::to_mkldnn(c10::optional<at::ScalarType> dtype) const {
+    return at::_ops::to_mkldnn::call(const_cast<Tensor&>(*this), dtype);
+}
+
+// aten::dequantize.self(Tensor self) -> Tensor
+inline at::Tensor Tensor::dequantize() const {
+    return at::_ops::dequantize_self::call(const_cast(*this));
+}
+
+// aten::q_scale(Tensor self) -> float
+inline double Tensor::q_scale() const {
+    return at::_ops::q_scale::call(const_cast(*this));
+}
+
+// aten::q_zero_point(Tensor self) -> int
+inline int64_t Tensor::q_zero_point() const {
+    return at::_ops::q_zero_point::call(const_cast(*this));
+}
+
+// aten::q_per_channel_scales(Tensor self) -> Tensor
+inline at::Tensor Tensor::q_per_channel_scales() const {
+    return at::_ops::q_per_channel_scales::call(const_cast(*this));
+}
+
+// aten::q_per_channel_zero_points(Tensor self) -> Tensor
+inline at::Tensor Tensor::q_per_channel_zero_points() const {
+    return at::_ops::q_per_channel_zero_points::call(const_cast(*this));
+}
+
+// aten::q_per_channel_axis(Tensor self) -> int
+inline int64_t Tensor::q_per_channel_axis() const {
+    return at::_ops::q_per_channel_axis::call(const_cast(*this));
+}
+
+// aten::int_repr(Tensor self) -> Tensor
+inline at::Tensor Tensor::int_repr() const {
+    return at::_ops::int_repr::call(const_cast(*this));
+}
+
+// aten::qscheme(Tensor self) -> QScheme
+inline at::QScheme Tensor::qscheme() const {
+    return at::_ops::qscheme::call(const_cast(*this));
+}
+
+// aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
+inline at::Tensor Tensor::_autocast_to_reduced_precision(bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) const {
+    return at::_ops::_autocast_to_reduced_precision::call(const_cast(*this), cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype);
+}
+
+// aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)
+inline at::Tensor Tensor::_autocast_to_full_precision(bool cuda_enabled, bool cpu_enabled) const {
+    return at::_ops::_autocast_to_full_precision::call(const_cast(*this), cuda_enabled, cpu_enabled);
+}
+
+// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+inline at::Tensor Tensor::to(at::TensorOptions options, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::to_dtype_layout::call(const_cast<Tensor&>(*this), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
+}
+
+// aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+inline at::Tensor Tensor::to(c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::to_dtype_layout::call(const_cast<Tensor&>(*this), dtype, layout, device, pin_memory, non_blocking, copy, memory_format);
+}
+
+// aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+inline at::Tensor Tensor::to(at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::to_device::call(const_cast<Tensor&>(*this), device, dtype, non_blocking, copy, memory_format);
+}
+
+// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+inline at::Tensor Tensor::to(at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::to_dtype::call(const_cast<Tensor&>(*this), dtype, non_blocking, copy, memory_format);
+}
+
+// aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+inline at::Tensor Tensor::to(const at::Tensor & other, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) const {
+    return at::_ops::to_other::call(const_cast<Tensor&>(*this), other, non_blocking, copy, memory_format);
+}
+
+// aten::item(Tensor self) -> Scalar
+inline at::Scalar Tensor::item() const {
+    return at::_ops::item::call(const_cast(*this));
+}
+
+// aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)
+inline at::Tensor & Tensor::set_(at::Storage source) const {
+    return at::_ops::set__source_Storage::call(const_cast(*this), source);
+}
+
+// aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+inline at::Tensor & Tensor::set_(at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const {
+    return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+}
+
+// aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+inline at::Tensor & Tensor::set__symint(at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const {
+    return at::_ops::set__source_Storage_storage_offset::call(const_cast(*this), source, storage_offset, size, stride);
+}
+
+// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+inline at::Tensor & Tensor::set_(const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride) const {
+    return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
+}
+
+// aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+inline at::Tensor & Tensor::set__symint(const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) const {
+    return at::_ops::set__source_Tensor_storage_offset::call(const_cast(*this), source, storage_offset, size, stride);
+}
+
+// aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
+inline at::Tensor & Tensor::set_(const at::Tensor & source) const {
+    return at::_ops::set__source_Tensor::call(const_cast(*this), source);
+}
+
+// aten::set_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::set_() const {
+    return at::_ops::set_::call(const_cast(*this));
+}
+
+// aten::is_set_to(Tensor self, Tensor tensor) -> bool
+inline bool Tensor::is_set_to(const at::Tensor & tensor) const {
+    return at::_ops::is_set_to::call(const_cast(*this), tensor);
+}
+
+// aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
+inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Scalar & value) const {
+    return at::_ops::masked_fill__Scalar::call(const_cast(*this), mask, value);
+}
+
+// aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
+inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Scalar & value) const {
+    return at::_ops::masked_fill_Scalar::call(const_cast(*this), mask, value);
+}
+
+// aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
+inline at::Tensor & Tensor::masked_fill_(const at::Tensor & mask, const at::Tensor & value) const {
+    return at::_ops::masked_fill__Tensor::call(const_cast(*this), mask, value);
+}
+
+// aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
+inline at::Tensor Tensor::masked_fill(const at::Tensor & mask, const at::Tensor & value) const {
+    return at::_ops::masked_fill_Tensor::call(const_cast(*this), mask, value);
+}
+
+// aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)
+inline at::Tensor & Tensor::masked_scatter_(const at::Tensor & mask, const at::Tensor & source) const {
+    return at::_ops::masked_scatter_::call(const_cast(*this), mask, source);
+}
+
+// aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
+inline at::Tensor Tensor::masked_scatter(const at::Tensor & mask, const at::Tensor & source) const {
+    return at::_ops::masked_scatter::call(const_cast<Tensor&>(*this), mask, source);
+}
+
+// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+inline at::Tensor Tensor::view(at::IntArrayRef size) const {
+    return at::_ops::view::call(const_cast<Tensor&>(*this), c10::fromIntArrayRefSlow(size));
+}
+
+// aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+inline at::Tensor Tensor::view_symint(c10::SymIntArrayRef size) const {
+    return at::_ops::view::call(const_cast<Tensor&>(*this), size);
+}
+
+// aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
+inline at::Tensor Tensor::view(at::ScalarType dtype) const {
+    return at::_ops::view_dtype::call(const_cast<Tensor&>(*this), dtype);
+}
+
+// aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
+inline at::Tensor & Tensor::put_(const at::Tensor & index, const at::Tensor & source, bool accumulate) const {
+    return at::_ops::put_::call(const_cast<Tensor&>(*this), index, source, accumulate);
+}
+
+// aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
+inline at::Tensor Tensor::put(const at::Tensor & index, const at::Tensor & source, bool accumulate) const {
+    return at::_ops::put::call(const_cast<Tensor&>(*this), index, source, accumulate);
+}
+
+// aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::index_add_(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const {
+    return at::_ops::index_add_::call(const_cast<Tensor&>(*this), dim, index, source, alpha);
+}
+
+// aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::index_add(int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const {
+    return at::_ops::index_add::call(const_cast<Tensor&>(*this), dim, index, source, alpha);
+}
+
+// aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::index_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) const {
+    return at::_ops::index_add_dimname::call(const_cast<Tensor&>(*this), dim, index, source, alpha);
+}
+
+// aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)
+inline at::Tensor & Tensor::index_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const {
+    return at::_ops::index_reduce_::call(const_cast<Tensor&>(*this), dim, index, source, reduce, include_self);
+}
+
+// aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
+inline at::Tensor Tensor::index_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) const {
+    return at::_ops::index_reduce::call(const_cast<Tensor&>(*this), dim, index, source, reduce, include_self);
+}
+
+// aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::index_fill__int_Scalar::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::index_fill_int_Scalar::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)
+inline at::Tensor & Tensor::index_fill_(int64_t dim, const at::Tensor & index, const at::Tensor & value) const {
+    return at::_ops::index_fill__int_Tensor::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
+inline at::Tensor Tensor::index_fill(int64_t dim, const at::Tensor & index, const at::Tensor & value) const {
+    return at::_ops::index_fill_int_Tensor::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)
+inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::index_fill__Dimname_Scalar::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)
+inline at::Tensor & Tensor::index_fill_(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const {
+    return at::_ops::index_fill__Dimname_Tensor::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::index_fill_Dimname_Scalar::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor
+inline at::Tensor Tensor::index_fill(at::Dimname dim, const at::Tensor & index, const at::Tensor & value) const {
+    return at::_ops::index_fill_Dimname_Tensor::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter_src::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter__src::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::scatter_value::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::scatter__value::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor
+inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const {
+    return at::_ops::scatter_reduce::call(const_cast<Tensor&>(*this), dim, index, src, reduce);
+}
+
+// aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) const {
+    return at::_ops::scatter__reduce::call(const_cast<Tensor&>(*this), dim, index, src, reduce);
+}
+
+// aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor
+inline at::Tensor Tensor::scatter(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const {
+    return at::_ops::scatter_value_reduce::call(const_cast<Tensor&>(*this), dim, index, value, reduce);
+}
+
+// aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_(int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) const {
+    return at::_ops::scatter__value_reduce::call(const_cast<Tensor&>(*this), dim, index, value, reduce);
+}
+
+// aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter_dimname_src::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+inline at::Tensor Tensor::scatter(at::Dimname dim, const at::Tensor & index, const at::Scalar & value) const {
+    return at::_ops::scatter_dimname_value::call(const_cast<Tensor&>(*this), dim, index, value);
+}
+
+// aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+inline at::Tensor Tensor::scatter_add(int64_t dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter_add::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_add_(int64_t dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter_add_::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+inline at::Tensor Tensor::scatter_add(at::Dimname dim, const at::Tensor & index, const at::Tensor & src) const {
+    return at::_ops::scatter_add_dimname::call(const_cast<Tensor&>(*this), dim, index, src);
+}
+
+// aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor
+inline at::Tensor Tensor::scatter_reduce(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const {
+    return at::_ops::scatter_reduce_two::call(const_cast<Tensor&>(*this), dim, index, src, reduce, include_self);
+}
+
+// aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)
+inline at::Tensor & Tensor::scatter_reduce_(int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) const {
+    return at::_ops::scatter_reduce__two::call(const_cast<Tensor&>(*this), dim, index, src, reduce, include_self);
+}
+
+// aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::eq_(const at::Scalar & other) const {
+    return at::_ops::eq__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::eq_(const at::Tensor & other) const {
+    return at::_ops::eq__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_and(const at::Scalar & other) const {
+    return at::_ops::bitwise_and_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_and(const at::Tensor & other) const {
+    return at::_ops::bitwise_and_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_and_(const at::Scalar & other) const {
+    return at::_ops::bitwise_and__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_and_(const at::Tensor & other) const {
+    return at::_ops::bitwise_and__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__and__(const at::Scalar & other) const {
+    return at::_ops::__and___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__and__(const at::Tensor & other) const {
+    return at::_ops::__and___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__iand__(const at::Scalar & other) const {
+    return at::_ops::__iand___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__iand__(const at::Tensor & other) const {
+    return at::_ops::__iand___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_or(const at::Scalar & other) const {
+    return at::_ops::bitwise_or_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_or(const at::Tensor & other) const {
+    return at::_ops::bitwise_or_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_or_(const at::Scalar & other) const {
+    return at::_ops::bitwise_or__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_or_(const at::Tensor & other) const {
+    return at::_ops::bitwise_or__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__or__(const at::Scalar & other) const {
+    return at::_ops::__or___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__or__(const at::Tensor & other) const {
+    return at::_ops::__or___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ior__(const at::Scalar & other) const {
+    return at::_ops::__ior___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ior__(const at::Tensor & other) const {
+    return at::_ops::__ior___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_xor(const at::Scalar & other) const {
+    return at::_ops::bitwise_xor_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_xor(const at::Tensor & other) const {
+    return at::_ops::bitwise_xor_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_xor_(const at::Scalar & other) const {
+    return at::_ops::bitwise_xor__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_xor_(const at::Tensor & other) const {
+    return at::_ops::bitwise_xor__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__xor__(const at::Scalar & other) const {
+    return at::_ops::__xor___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__xor__(const at::Tensor & other) const {
+    return at::_ops::__xor___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ixor__(const at::Scalar & other) const {
+    return at::_ops::__ixor___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ixor__(const at::Tensor & other) const {
+    return at::_ops::__ixor___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__lshift__(const at::Scalar & other) const {
+    return at::_ops::__lshift___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__lshift__(const at::Tensor & other) const {
+    return at::_ops::__lshift___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ilshift__(const at::Scalar & other) const {
+    return at::_ops::__ilshift___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__ilshift__(const at::Tensor & other) const {
+    return at::_ops::__ilshift___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_left_shift(const at::Tensor & other) const {
+    return at::_ops::bitwise_left_shift_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_left_shift_(const at::Tensor & other) const {
+    return at::_ops::bitwise_left_shift__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_left_shift(const at::Scalar & other) const {
+    return at::_ops::bitwise_left_shift_Tensor_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_left_shift_(const at::Scalar & other) const {
+    return at::_ops::bitwise_left_shift__Tensor_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::__rshift__(const at::Scalar & other) const {
+    return at::_ops::__rshift___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::__rshift__(const at::Tensor & other) const {
+    return at::_ops::__rshift___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::__irshift__(const at::Scalar & other) const {
+    return at::_ops::__irshift___Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::__irshift__(const at::Tensor & other) const {
+    return at::_ops::__irshift___Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::bitwise_right_shift(const at::Tensor & other) const {
+    return at::_ops::bitwise_right_shift_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_right_shift_(const at::Tensor & other) const {
+    return at::_ops::bitwise_right_shift__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::bitwise_right_shift(const at::Scalar & other) const {
+    return at::_ops::bitwise_right_shift_Tensor_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::bitwise_right_shift_(const at::Scalar & other) const {
+    return at::_ops::bitwise_right_shift__Tensor_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+inline at::Tensor & Tensor::tril_(int64_t diagonal) const {
+    return at::_ops::tril_::call(const_cast<Tensor&>(*this), diagonal);
+}
+
+// aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+inline at::Tensor & Tensor::triu_(int64_t diagonal) const {
+    return at::_ops::triu_::call(const_cast<Tensor&>(*this), diagonal);
+}
+
+// aten::digamma_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::digamma_() const {
+    return at::_ops::digamma_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)
+inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Scalar & weight) const {
+    return at::_ops::lerp__Scalar::call(const_cast<Tensor&>(*this), end, weight);
+}
+
+// aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)
+inline at::Tensor & Tensor::lerp_(const at::Tensor & end, const at::Tensor & weight) const {
+    return at::_ops::lerp__Tensor::call(const_cast<Tensor&>(*this), end, weight);
+}
+
+// aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addbmm_(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addbmm_::call(const_cast<Tensor&>(*this), batch1, batch2, beta, alpha);
+}
+
+// aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+inline at::Tensor Tensor::addbmm(const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) const {
+    return at::_ops::addbmm::call(const_cast<Tensor&>(*this), batch1, batch2, beta, alpha);
+}
+
+// aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::random_(int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator) const {
+    return at::_ops::random__from::call(const_cast<Tensor&>(*this), from, to, generator);
+}
+
+// aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::random_(int64_t to, c10::optional<at::Generator> generator) const {
+    return at::_ops::random__to::call(const_cast<Tensor&>(*this), to, generator);
+}
+
+// aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::random_(c10::optional<at::Generator> generator) const {
+    return at::_ops::random_::call(const_cast<Tensor&>(*this), generator);
+}
+
+// aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::uniform_(double from, double to, c10::optional<at::Generator> generator) const {
+    return at::_ops::uniform_::call(const_cast<Tensor&>(*this), from, to, generator);
+}
+
+// aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::cauchy_(double median, double sigma, c10::optional<at::Generator> generator) const {
+    return at::_ops::cauchy_::call(const_cast<Tensor&>(*this), median, sigma, generator);
+}
+
+// aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::log_normal_(double mean, double std, c10::optional<at::Generator> generator) const {
+    return at::_ops::log_normal_::call(const_cast<Tensor&>(*this), mean, std, generator);
+}
+
+// aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::exponential_(double lambd, c10::optional<at::Generator> generator) const {
+    return at::_ops::exponential_::call(const_cast<Tensor&>(*this), lambd, generator);
+}
+
+// aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::geometric_(double p, c10::optional<at::Generator> generator) const {
+    return at::_ops::geometric_::call(const_cast<Tensor&>(*this), p, generator);
+}
+
+// aten::diag(Tensor self, int diagonal=0) -> Tensor
+inline at::Tensor Tensor::diag(int64_t diagonal) const {
+    return at::_ops::diag::call(const_cast<Tensor&>(*this), diagonal);
+}
+
+// aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor
+inline at::Tensor Tensor::cross(const at::Tensor & other, c10::optional<int64_t> dim) const {
+    return at::_ops::cross::call(const_cast<Tensor&>(*this), other, dim);
+}
+
+// aten::triu(Tensor self, int diagonal=0) -> Tensor
+inline at::Tensor Tensor::triu(int64_t diagonal) const {
+    return at::_ops::triu::call(const_cast<Tensor&>(*this), diagonal);
+}
+
+// aten::tril(Tensor self, int diagonal=0) -> Tensor
+inline at::Tensor Tensor::tril(int64_t diagonal) const {
+    return at::_ops::tril::call(const_cast<Tensor&>(*this), diagonal);
+}
+
+// aten::trace(Tensor self) -> Tensor
+inline at::Tensor Tensor::trace() const {
+    return at::_ops::trace::call(const_cast<Tensor&>(*this));
+}
+
+// aten::ne.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::ne(const at::Scalar & other) const {
+    return at::_ops::ne_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ne.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::ne(const at::Tensor & other) const {
+    return at::_ops::ne_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::ne_(const at::Scalar & other) const {
+    return at::_ops::ne__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::ne_(const at::Tensor & other) const {
+    return at::_ops::ne__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::not_equal(const at::Scalar & other) const {
+    return at::_ops::not_equal_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::not_equal(const at::Tensor & other) const {
+    return at::_ops::not_equal_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::not_equal_(const at::Scalar & other) const {
+    return at::_ops::not_equal__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::not_equal_(const at::Tensor & other) const {
+    return at::_ops::not_equal__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::eq.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::eq(const at::Scalar & other) const {
+    return at::_ops::eq_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::eq.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::eq(const at::Tensor & other) const {
+    return at::_ops::eq_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ge.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::ge(const at::Scalar & other) const {
+    return at::_ops::ge_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ge.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::ge(const at::Tensor & other) const {
+    return at::_ops::ge_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::ge_(const at::Scalar & other) const {
+    return at::_ops::ge__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::ge_(const at::Tensor & other) const {
+    return at::_ops::ge__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::greater_equal(const at::Scalar & other) const {
+    return at::_ops::greater_equal_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::greater_equal(const at::Tensor & other) const {
+    return at::_ops::greater_equal_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::greater_equal_(const at::Scalar & other) const {
+    return at::_ops::greater_equal__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::greater_equal_(const at::Tensor & other) const {
+    return at::_ops::greater_equal__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::le.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::le(const at::Scalar & other) const {
+    return at::_ops::le_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::le.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::le(const at::Tensor & other) const {
+    return at::_ops::le_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::le_(const at::Scalar & other) const {
+    return at::_ops::le__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::le_(const at::Tensor & other) const {
+    return at::_ops::le__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::less_equal(const at::Scalar & other) const {
+    return at::_ops::less_equal_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::less_equal(const at::Tensor & other) const {
+    return at::_ops::less_equal_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::less_equal_(const at::Scalar & other) const {
+    return at::_ops::less_equal__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::less_equal_(const at::Tensor & other) const {
+    return at::_ops::less_equal__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::gt.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::gt(const at::Scalar & other) const {
+    return at::_ops::gt_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::gt.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::gt(const at::Tensor & other) const {
+    return at::_ops::gt_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::gt_(const at::Scalar & other) const {
+    return at::_ops::gt__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::gt_(const at::Tensor & other) const {
+    return at::_ops::gt__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::greater(const at::Scalar & other) const {
+    return at::_ops::greater_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::greater(const at::Tensor & other) const {
+    return at::_ops::greater_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::greater_(const at::Scalar & other) const {
+    return at::_ops::greater__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::greater_(const at::Tensor & other) const {
+    return at::_ops::greater__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lt.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::lt(const at::Scalar & other) const {
+    return at::_ops::lt_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lt.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::lt(const at::Tensor & other) const {
+    return at::_ops::lt_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::lt_(const at::Scalar & other) const {
+    return at::_ops::lt__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::lt_(const at::Tensor & other) const {
+    return at::_ops::lt__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::less(const at::Scalar & other) const {
+    return at::_ops::less_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::less(const at::Tensor & other) const {
+    return at::_ops::less_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::less_(const at::Scalar & other) const {
+    return at::_ops::less__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::less_(const at::Tensor & other) const {
+    return at::_ops::less__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::take(Tensor self, Tensor index) -> Tensor
+inline at::Tensor Tensor::take(const at::Tensor & index) const {
+    return at::_ops::take::call(const_cast<Tensor&>(*this), index);
+}
+
+// aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
+inline at::Tensor Tensor::take_along_dim(const at::Tensor & indices, c10::optional<int64_t> dim) const {
+    return at::_ops::take_along_dim::call(const_cast<Tensor&>(*this), indices, dim);
+}
+
+// aten::index_select(Tensor self, int dim, Tensor index) -> Tensor
+inline at::Tensor Tensor::index_select(int64_t dim, const at::Tensor & index) const {
+    return at::_ops::index_select::call(const_cast<Tensor&>(*this), dim, index);
+}
+
+// aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
+inline at::Tensor Tensor::index_select(at::Dimname dim, const at::Tensor & index) const {
+    return at::_ops::index_select_dimname::call(const_cast<Tensor&>(*this), dim, index);
+}
+
+// aten::masked_select(Tensor self, Tensor mask) -> Tensor
+inline at::Tensor Tensor::masked_select(const at::Tensor & mask) const {
+    return at::_ops::masked_select::call(const_cast<Tensor&>(*this), mask);
+}
+
+// aten::nonzero(Tensor self) -> Tensor
+inline at::Tensor Tensor::nonzero() const {
+    return at::_ops::nonzero::call(const_cast<Tensor&>(*this));
+}
+
+// aten::nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
+inline at::Tensor Tensor::nonzero_static(int64_t size, int64_t fill_value) const {
+    return at::_ops::nonzero_static::call(const_cast<Tensor&>(*this), size, fill_value);
+}
+
+// aten::nonzero_numpy(Tensor self) -> Tensor[]
+inline ::std::vector<at::Tensor> Tensor::nonzero_numpy() const {
+    return at::_ops::nonzero_numpy::call(const_cast<Tensor&>(*this));
+}
+
+// aten::argwhere(Tensor self) -> Tensor
+inline at::Tensor Tensor::argwhere() const {
+    return at::_ops::argwhere::call(const_cast<Tensor&>(*this));
+}
+
+// aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+inline at::Tensor Tensor::gather(int64_t dim, const at::Tensor & index, bool sparse_grad) const {
+    return at::_ops::gather::call(const_cast<Tensor&>(*this), dim, index, sparse_grad);
+}
+
+// aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+inline at::Tensor Tensor::gather(at::Dimname dim, const at::Tensor & index, bool sparse_grad) const {
+    return at::_ops::gather_dimname::call(const_cast<Tensor&>(*this), dim, index, sparse_grad);
+}
+
+// aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+inline at::Tensor Tensor::addcmul(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const {
+    return at::_ops::addcmul::call(const_cast<Tensor&>(*this), tensor1, tensor2, value);
+}
+
+// aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addcmul_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const {
+    return at::_ops::addcmul_::call(const_cast<Tensor&>(*this), tensor1, tensor2, value);
+}
+
+// aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+inline at::Tensor Tensor::addcdiv(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const {
+    return at::_ops::addcdiv::call(const_cast<Tensor&>(*this), tensor1, tensor2, value);
+}
+
+// aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+inline at::Tensor & Tensor::addcdiv_(const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) const {
+    return at::_ops::addcdiv_::call(const_cast<Tensor&>(*this), tensor1, tensor2, value);
+}
+
+// aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::triangular_solve(const at::Tensor & A, bool upper, bool transpose, bool unitriangular) const {
+    return at::_ops::triangular_solve::call(const_cast<Tensor&>(*this), A, upper, transpose, unitriangular);
+}
+
+// aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
+inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> Tensor::svd(bool some, bool compute_uv) const {
+    return at::_ops::svd::call(const_cast<Tensor&>(*this), some, compute_uv);
+}
+
+// aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)
+inline at::Tensor Tensor::swapaxes(int64_t axis0, int64_t axis1) const {
+    return at::_ops::swapaxes::call(const_cast<Tensor&>(*this), axis0, axis1);
+}
+
+// aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)
+inline at::Tensor & Tensor::swapaxes_(int64_t axis0, int64_t axis1) const {
+    return at::_ops::swapaxes_::call(const_cast<Tensor&>(*this), axis0, axis1);
+}
+
+// aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+inline at::Tensor Tensor::swapdims(int64_t dim0, int64_t dim1) const {
+    return at::_ops::swapdims::call(const_cast<Tensor&>(*this), dim0, dim1);
+}
+
+// aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+inline at::Tensor & Tensor::swapdims_(int64_t dim0, int64_t dim1) const {
+    return at::_ops::swapdims_::call(const_cast<Tensor&>(*this), dim0, dim1);
+}
+
+// aten::cholesky(Tensor self, bool upper=False) -> Tensor
+inline at::Tensor Tensor::cholesky(bool upper) const {
+    return at::_ops::cholesky::call(const_cast<Tensor&>(*this), upper);
+}
+
+// aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
+inline at::Tensor Tensor::cholesky_solve(const at::Tensor & input2, bool upper) const {
+    return at::_ops::cholesky_solve::call(const_cast<Tensor&>(*this), input2, upper);
+}
+
+// aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor
+inline at::Tensor Tensor::cholesky_inverse(bool upper) const {
+    return at::_ops::cholesky_inverse::call(const_cast<Tensor&>(*this), upper);
+}
+
+// aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::qr(bool some) const {
+    return at::_ops::qr::call(const_cast<Tensor&>(*this), some);
+}
+
+// aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::geqrf() const {
+    return at::_ops::geqrf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::orgqr(Tensor self, Tensor input2) -> Tensor
+inline at::Tensor Tensor::orgqr(const at::Tensor & input2) const {
+    return at::_ops::orgqr::call(const_cast<Tensor&>(*this), input2);
+}
+
+// aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
+inline at::Tensor Tensor::ormqr(const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose) const {
+    return at::_ops::ormqr::call(const_cast<Tensor&>(*this), input2, input3, left, transpose);
+}
+
+// aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+inline at::Tensor Tensor::lu_solve(const at::Tensor & LU_data, const at::Tensor & LU_pivots) const {
+    return at::_ops::lu_solve::call(const_cast<Tensor&>(*this), LU_data, LU_pivots);
+}
+
+// aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
+inline at::Tensor Tensor::multinomial(int64_t num_samples, bool replacement, c10::optional<at::Generator> generator) const {
+    return at::_ops::multinomial::call(const_cast<Tensor&>(*this), num_samples, replacement, generator);
+}
+
+// aten::lgamma_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::lgamma_() const {
+    return at::_ops::lgamma_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::lgamma(Tensor self) -> Tensor
+inline at::Tensor Tensor::lgamma() const {
+    return at::_ops::lgamma::call(const_cast<Tensor&>(*this));
+}
+
+// aten::digamma(Tensor self) -> Tensor
+inline at::Tensor Tensor::digamma() const {
+    return at::_ops::digamma::call(const_cast<Tensor&>(*this));
+}
+
+// aten::polygamma(int n, Tensor self) -> Tensor
+inline at::Tensor Tensor::polygamma(int64_t n) const {
+    return at::_ops::polygamma::call(n, const_cast<Tensor&>(*this));
+}
+
+// aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
+inline at::Tensor & Tensor::polygamma_(int64_t n) const {
+    return at::_ops::polygamma_::call(const_cast<Tensor&>(*this), n);
+}
+
+// aten::erfinv(Tensor self) -> Tensor
+inline at::Tensor Tensor::erfinv() const {
+    return at::_ops::erfinv::call(const_cast<Tensor&>(*this));
+}
+
+// aten::erfinv_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::erfinv_() const {
+    return at::_ops::erfinv_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::i0(Tensor self) -> Tensor
+inline at::Tensor Tensor::i0() const {
+    return at::_ops::i0::call(const_cast<Tensor&>(*this));
+}
+
+// aten::i0_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::i0_() const {
+    return at::_ops::i0_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::sign(Tensor self) -> Tensor
+inline at::Tensor Tensor::sign() const {
+    return at::_ops::sign::call(const_cast<Tensor&>(*this));
+}
+
+// aten::sign_(Tensor(a!) self) -> Tensor(a!)
+inline at::Tensor & Tensor::sign_() const {
+    return at::_ops::sign_::call(const_cast<Tensor&>(*this));
+}
+
+// aten::signbit(Tensor self) -> Tensor
+inline at::Tensor Tensor::signbit() const {
+    return at::_ops::signbit::call(const_cast<Tensor&>(*this));
+}
+
+// aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
+inline at::Tensor Tensor::dist(const at::Tensor & other, const at::Scalar & p) const {
+    return at::_ops::dist::call(const_cast<Tensor&>(*this), other, p);
+}
+
+// aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::atan2_(const at::Tensor & other) const {
+    return at::_ops::atan2_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::atan2(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::atan2(const at::Tensor & other) const {
+    return at::_ops::atan2::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::arctan2(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::arctan2(const at::Tensor & other) const {
+    return at::_ops::arctan2::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::arctan2_(const at::Tensor & other) const {
+    return at::_ops::arctan2_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
+inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Scalar & weight) const {
+    return at::_ops::lerp_Scalar::call(const_cast<Tensor&>(*this), end, weight);
+}
+
+// aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
+inline at::Tensor Tensor::lerp(const at::Tensor & end, const at::Tensor & weight) const {
+    return at::_ops::lerp_Tensor::call(const_cast<Tensor&>(*this), end, weight);
+}
+
+// aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
+inline at::Tensor Tensor::histc(int64_t bins, const at::Scalar & min, const at::Scalar & max) const {
+    return at::_ops::histc::call(const_cast<Tensor&>(*this), bins, min, max);
+}
+
+// aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::histogram(const at::Tensor & bins, const c10::optional<at::Tensor> & weight, bool density) const {
+    return at::_ops::histogram_bins_tensor::call(const_cast<Tensor&>(*this), bins, weight, density);
+}
+
+// aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::histogram(int64_t bins, c10::optional<at::ArrayRef<double>> range, const c10::optional<at::Tensor> & weight, bool density) const {
+    return at::_ops::histogram_bin_ct::call(const_cast<Tensor&>(*this), bins, range, weight, density);
+}
+
+// aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::fmod(const at::Scalar & other) const {
+    return at::_ops::fmod_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::fmod_(const at::Scalar & other) const {
+    return at::_ops::fmod__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::fmod(const at::Tensor & other) const {
+    return at::_ops::fmod_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::fmod_(const at::Tensor & other) const {
+    return at::_ops::fmod__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::hypot(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::hypot(const at::Tensor & other) const {
+    return at::_ops::hypot::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::hypot_(const at::Tensor & other) const {
+    return at::_ops::hypot_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::igamma(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::igamma(const at::Tensor & other) const {
+    return at::_ops::igamma::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::igamma_(const at::Tensor & other) const {
+    return at::_ops::igamma_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::igammac(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::igammac(const at::Tensor & other) const {
+    return at::_ops::igammac::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::igammac_(const at::Tensor & other) const {
+    return at::_ops::igammac_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::nextafter(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::nextafter(const at::Tensor & other) const {
+    return at::_ops::nextafter::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::nextafter_(const at::Tensor & other) const {
+    return at::_ops::nextafter_::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor
+inline at::Tensor Tensor::remainder(const at::Scalar & other) const {
+    return at::_ops::remainder_Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+inline at::Tensor & Tensor::remainder_(const at::Scalar & other) const {
+    return at::_ops::remainder__Scalar::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::remainder(const at::Tensor & other) const {
+    return at::_ops::remainder_Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+inline at::Tensor & Tensor::remainder_(const at::Tensor & other) const {
+    return at::_ops::remainder__Tensor::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::min(Tensor self) -> Tensor
+inline at::Tensor Tensor::min() const {
+    return at::_ops::min::call(const_cast<Tensor&>(*this));
+}
+
+// aten::fmin(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::fmin(const at::Tensor & other) const {
+    return at::_ops::fmin::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::max(Tensor self) -> Tensor
+inline at::Tensor Tensor::max() const {
+    return at::_ops::max::call(const_cast<Tensor&>(*this));
+}
+
+// aten::fmax(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::fmax(const at::Tensor & other) const {
+    return at::_ops::fmax::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::maximum(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::maximum(const at::Tensor & other) const {
+    return at::_ops::maximum::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::max.other(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::max(const at::Tensor & other) const {
+    return at::_ops::max_other::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::minimum(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::minimum(const at::Tensor & other) const {
+    return at::_ops::minimum::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::min.other(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::min(const at::Tensor & other) const {
+    return at::_ops::min_other::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+inline at::Tensor Tensor::quantile(const at::Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation) const {
+    return at::_ops::quantile::call(const_cast<Tensor&>(*this), q, dim, keepdim, interpolation);
+}
+
+// aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+inline at::Tensor Tensor::quantile(double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation) const {
+    return at::_ops::quantile_scalar::call(const_cast<Tensor&>(*this), q, dim, keepdim, interpolation);
+}
+
+// aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+inline at::Tensor Tensor::nanquantile(const at::Tensor & q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation) const {
+    return at::_ops::nanquantile::call(const_cast<Tensor&>(*this), q, dim, keepdim, interpolation);
+}
+
+// aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+inline at::Tensor Tensor::nanquantile(double q, c10::optional<int64_t> dim, bool keepdim, c10::string_view interpolation) const {
+    return at::_ops::nanquantile_scalar::call(const_cast<Tensor&>(*this), q, dim, keepdim, interpolation);
+}
+
+// aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::sort(int64_t dim, bool descending) const {
+    return at::_ops::sort::call(const_cast<Tensor&>(*this), dim, descending);
+}
+
+// aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::sort(c10::optional<bool> stable, int64_t dim, bool descending) const {
+    return at::_ops::sort_stable::call(const_cast<Tensor&>(*this), stable, dim, descending);
+}
+
+// aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::sort(at::Dimname dim, bool descending) const {
+    return at::_ops::sort_dimname::call(const_cast<Tensor&>(*this), dim, descending);
+}
+
+// aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::sort(c10::optional<bool> stable, at::Dimname dim, bool descending) const {
+    return at::_ops::sort_dimname_stable::call(const_cast<Tensor&>(*this), stable, dim, descending);
+}
+
+// aten::msort(Tensor self) -> Tensor
+inline at::Tensor Tensor::msort() const {
+    return at::_ops::msort::call(const_cast<Tensor&>(*this));
+}
+
+// aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
+inline at::Tensor Tensor::argsort(int64_t dim, bool descending) const {
+    return at::_ops::argsort::call(const_cast<Tensor&>(*this), dim, descending);
+}
+
+// aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor
+inline at::Tensor Tensor::argsort(bool stable, int64_t dim, bool descending) const {
+    return at::_ops::argsort_stable::call(const_cast<Tensor&>(*this), stable, dim, descending);
+}
+
+// aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor
+inline at::Tensor Tensor::argsort(at::Dimname dim, bool descending) const {
+    return at::_ops::argsort_dimname::call(const_cast<Tensor&>(*this), dim, descending);
+}
+
+// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const {
+    return at::_ops::topk::call(const_cast<Tensor&>(*this), k, dim, largest, sorted);
+}
+
+// aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::topk_symint(c10::SymInt k, int64_t dim, bool largest, bool sorted) const {
+    return at::_ops::topk::call(const_cast<Tensor&>(*this), k, dim, largest, sorted);
+}
+
+// aten::all(Tensor self) -> Tensor
+inline at::Tensor Tensor::all() const {
+    return at::_ops::all::call(const_cast<Tensor&>(*this));
+}
+
+// aten::any(Tensor self) -> Tensor
+inline at::Tensor Tensor::any() const {
+    return at::_ops::any::call(const_cast<Tensor&>(*this));
+}
+
+// aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
+inline at::Tensor Tensor::renorm(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const {
+    return at::_ops::renorm::call(const_cast<Tensor&>(*this), p, dim, maxnorm);
+}
+
+// aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
+inline at::Tensor & Tensor::renorm_(const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) const {
+    return at::_ops::renorm_::call(const_cast<Tensor&>(*this), p, dim, maxnorm);
+}
+
+// aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
+inline at::Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const {
+    return at::_ops::unfold::call(const_cast<Tensor&>(*this), dimension, size, step);
+}
+
+// aten::equal(Tensor self, Tensor other) -> bool
+inline bool Tensor::equal(const at::Tensor & other) const {
+    return at::_ops::equal::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+inline at::Tensor Tensor::pow(const at::Tensor & exponent) const {
+    return at::_ops::pow_Tensor_Tensor::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+inline at::Tensor Tensor::pow(const at::Scalar & exponent) const {
+    return at::_ops::pow_Tensor_Scalar::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+inline at::Tensor & Tensor::pow_(const at::Scalar & exponent) const {
+    return at::_ops::pow__Scalar::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+inline at::Tensor & Tensor::pow_(const at::Tensor & exponent) const {
+    return at::_ops::pow__Tensor::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+inline at::Tensor Tensor::float_power(const at::Tensor & exponent) const {
+    return at::_ops::float_power_Tensor_Tensor::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+inline at::Tensor Tensor::float_power(const at::Scalar & exponent) const {
+    return at::_ops::float_power_Tensor_Scalar::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+inline at::Tensor & Tensor::float_power_(const at::Scalar & exponent) const {
+    return at::_ops::float_power__Scalar::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+inline at::Tensor & Tensor::float_power_(const at::Tensor & exponent) const {
+    return at::_ops::float_power__Tensor::call(const_cast<Tensor&>(*this), exponent);
+}
+
+// aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
+inline at::Tensor & Tensor::normal_(double mean, double std, c10::optional<at::Generator> generator) const {
+    return at::_ops::normal_::call(const_cast<Tensor&>(*this), mean, std, generator);
+}
+
+// aten::alias(Tensor(a) self) -> Tensor(a)
+inline at::Tensor Tensor::alias() const {
+    return at::_ops::alias::call(const_cast<Tensor&>(*this));
+}
+
+// aten::isfinite(Tensor self) -> Tensor
+inline at::Tensor Tensor::isfinite() const {
+    return at::_ops::isfinite::call(const_cast<Tensor&>(*this));
+}
+
+// aten::isinf(Tensor self) -> Tensor
+inline at::Tensor Tensor::isinf() const {
+    return at::_ops::isinf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::record_stream(Tensor(a!) self, Stream s) -> ()
+inline void Tensor::record_stream(at::Stream s) const {
+    return at::_ops::record_stream::call(const_cast<Tensor&>(*this), s);
+}
+
+// aten::isposinf(Tensor self) -> Tensor
+inline at::Tensor Tensor::isposinf() const {
+    return at::_ops::isposinf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::isneginf(Tensor self) -> Tensor
+inline at::Tensor Tensor::isneginf() const {
+    return at::_ops::isneginf::call(const_cast<Tensor&>(*this));
+}
+
+// aten::det(Tensor self) -> Tensor
+inline at::Tensor Tensor::det() const {
+    return at::_ops::det::call(const_cast<Tensor&>(*this));
+}
+
+// aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
+inline ::std::tuple<at::Tensor,at::Tensor> Tensor::slogdet() const {
+    return at::_ops::slogdet::call(const_cast<Tensor&>(*this));
+}
+
+// aten::logdet(Tensor self) -> Tensor
+inline at::Tensor Tensor::logdet() const {
+    return at::_ops::logdet::call(const_cast<Tensor&>(*this));
+}
+
+// aten::inverse(Tensor self) -> Tensor
+inline at::Tensor Tensor::inverse() const {
+    return at::_ops::inverse::call(const_cast<Tensor&>(*this));
+}
+
+// aten::inner(Tensor self, Tensor other) -> Tensor
+inline at::Tensor Tensor::inner(const at::Tensor & other) const {
+    return at::_ops::inner::call(const_cast<Tensor&>(*this), other);
+}
+
+// aten::outer(Tensor self, Tensor vec2) -> Tensor
+inline at::Tensor Tensor::outer(const at::Tensor & vec2) const {
+    return at::_ops::outer::call(const_cast<Tensor&>(*this), vec2);
+}
+
+// aten::ger(Tensor self, Tensor vec2) -> Tensor
+inline at::Tensor Tensor::ger(const at::Tensor & vec2) const {
+    return at::_ops::ger::call(const_cast<Tensor&>(*this), vec2);
+}
+
+// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+inline at::Tensor Tensor::to_padded_tensor(double padding, at::OptionalIntArrayRef output_size) const {
+    return at::_ops::to_padded_tensor::call(const_cast<Tensor&>(*this), padding, output_size.has_value() ? c10::make_optional(c10::fromIntArrayRefSlow(*output_size)) : c10::nullopt);
+}
+
+// aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+inline at::Tensor Tensor::to_padded_tensor_symint(double padding, at::OptionalSymIntArrayRef output_size) const {
+    return at::_ops::to_padded_tensor::call(const_cast<Tensor&>(*this), padding, output_size);
+}
+} // namespace at
+
+
+namespace c10 {
+template <>
+struct MaybeOwnedTraits<at::Tensor> {
+  using owned_type = at::Tensor;
+  using borrow_type = at::Tensor;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    // NOTE: this can be implemented without the special
+    // unsafe_borrow_t Tensor constructor as
+    //
+    // return borrow_type(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
+    //
+    // but that hurts inlining due to the nullptr check in the
+    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
+    // that from.impl_ isn't null because from is a valid Tensor, so
+    // we needn't do the check again. (using __builtin_assume can
+    // avoid this, but wouldn't be portable to MSVC.)
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseTensorImpl();
+    // See above note: this can be implemented with public API
+    // similarly to createBorrow(), but that would hurt inlining.
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+    return true;
+  }
+};
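+
+// Usage sketch (illustrative only; maybe_owned_borrow_example is not part of the
+// generated header): with the traits above, c10::MaybeOwned<at::Tensor> can hold
+// either a borrowed view of an existing Tensor (no refcount bump) or an owned one.
+inline bool maybe_owned_borrow_example(const at::Tensor& t) {
+  // Borrowing aliases the caller's TensorImpl without touching its refcount.
+  c10::MaybeOwned<at::Tensor> borrowed = c10::MaybeOwned<at::Tensor>::borrowed(t);
+  // Owning stores a refcounted at::Tensor inside the MaybeOwned.
+  c10::MaybeOwned<at::Tensor> owned = c10::MaybeOwned<at::Tensor>::owned(at::Tensor(t));
+  return borrowed->unsafeGetTensorImpl() == t.unsafeGetTensorImpl() && owned->defined() == t.defined();
+}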
+
+template <>
+struct ExclusivelyOwnedTraits<at::Tensor> {
+  using repr_type = at::Tensor;
+  using pointer_type = at::Tensor*;
+  using const_pointer_type = const at::Tensor*;
+
+  static repr_type nullRepr() {
+    return at::Tensor();
+  }
+
+  template <class... Args>
+  static repr_type createInPlace(Args&&... args) {
+    return at::Tensor(std::forward<Args>(args)...);
+  }
+
+  static repr_type moveToRepr(at::Tensor&& x) {
+    return std::move(x);
+  }
+
+  static void destroyOwned(at::Tensor& x) {
+    return ExclusivelyOwnedTraits<at::TensorBase>::destroyOwned(x);
+  }
+
+  static at::Tensor take(at::Tensor& x) {
+    return std::move(x);
+  }
+
+  static pointer_type getImpl(repr_type& x) {
+    return &x;
+  }
+
+  static const_pointer_type getImpl(const repr_type& x) {
+    return &x;
+  }
+};
+} // namespace c10
+
+namespace at {
+
+inline c10::MaybeOwned<Tensor> borrow_from_optional_tensor(
+    const c10::optional<Tensor>& opt) {
+  return opt.has_value()
+    ? c10::MaybeOwned<Tensor>::borrowed(*opt)
+    : c10::MaybeOwned<Tensor>::owned(std::in_place);
+}
+
+inline c10::MaybeOwned<Tensor> Tensor::expect_contiguous(MemoryFormat memory_format) const & {
+  if (is_contiguous(memory_format)) {
+    return c10::MaybeOwned<Tensor>::borrowed(*this);
+  } else {
+    return c10::MaybeOwned<Tensor>::owned(__dispatch_contiguous(memory_format));
+  }
+}
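+
+// Usage sketch (illustrative only; example_bias_numel is hypothetical): kernels
+// that take an optional bias and need contiguous memory typically combine the two
+// helpers above so the common case does no refcount bump and no copy.
+inline int64_t example_bias_numel(const c10::optional<Tensor>& bias_opt) {
+  // Borrowed when a bias is present, an owned undefined Tensor otherwise.
+  c10::MaybeOwned<Tensor> bias = at::borrow_from_optional_tensor(bias_opt);
+  if (!bias->defined()) {
+    return 0;
+  }
+  // Borrows when already contiguous, otherwise owns a contiguous copy.
+  c10::MaybeOwned<Tensor> contig = bias->expect_contiguous(c10::MemoryFormat::Contiguous);
+  return contig->numel();
+}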
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ad5d009e70553593a4d20106eacc8908717d8ab3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/TorchDispatchUtils.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace impl {
+
+TORCH_API bool tensor_has_dispatch(const at::Tensor& t);
+TORCH_API bool tensorlist_has_dispatch(at::ITensorListRef li);
+TORCH_API bool tensorlist_has_dispatch(const c10::List<c10::optional<at::Tensor>>& li);
+using c10::impl::dispatch_mode_enabled;
+
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h b/MLPY/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..2052e4c47df6a17159070091a323c1e604512590
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/TransformationHelper.h
@@ -0,0 +1,173 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+// Using DistAccumType in accumulate types for distributions.
+// Note: Ideally we'd be using ATen/AccumulateType.h but looks
+// like there is some inconsistency in how accumulate types
+// are mapped currently, e.g. for the cpu side, float is mapped
+// to double.
+template <typename T>
+struct DistAccumType {  };
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <> struct DistAccumType<half> { using type = float; };
+#endif
+template <> struct DistAccumType<BFloat16> { using type = float; };
+template <> struct DistAccumType<Half> { using type = float; };
+template <> struct DistAccumType<float> { using type = float; };
+template <> struct DistAccumType<double> { using type = double; };
+
+template <typename T>
+using dist_acctype = typename DistAccumType<T>::type;
+
+namespace transformation {
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when both `from` and `to` are specified.
+ * `range` is `to - from`
+ * `base` is `from`
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_from_to(V val, uint64_t range, int64_t base) {
+  return static_cast<T>(static_cast<int64_t>((val % range) + base));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when `from=min_value(int64_t)` and to=None
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
+  return static_cast<T>(static_cast<int64_t>(val));
+}
+
+/**
+ * A transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`.
+ * In order to prevent compiler warnings reported in GitHub issue 46391, T can't be float or double
+ * in this overloaded version
+ */
+template <typename T, typename V>
+C10_HOST_DEVICE inline typename std::enable_if<!(std::is_floating_point<T>::value), T>::type uniform_int(V val) {
+  if constexpr (std::is_same_v<T, bool>) {
+    return static_cast<bool>(val & 1);
+  } else if constexpr (std::is_same_v<T, int64_t>) {
+    return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+  } else if constexpr (std::is_same_v<T, at::Half> || std::is_same<T, at::BFloat16>::value) {
+    return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+  } else if constexpr (std::is_integral_v<T>) {
+    return static_cast<T>(val % (static_cast<uint64_t>(std::numeric_limits<T>::max()) + 1));
+  } else {
+    assert(false);
+    return 0;
+  }
+}
+
+/**
+ * An overloaded transformation function for `torch.Tensor.random_()`, when used without specifying `from` and `to`,
+ * added to fix compiler warnings reported in GitHub issue 46391. T is either float or double in this version.
+ */
+template<typename T, typename V>
+C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type uniform_int(V val) {
+  return static_cast<T>(val % static_cast<uint64_t>((1ULL << std::numeric_limits<T>::digits) + 1));
+}
+
+template <typename T, typename V>
+C10_HOST_DEVICE inline dist_acctype<T> uniform_real(V val, T from, T to) {
+  constexpr auto MASK = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
+  constexpr auto DIVISOR = static_cast<dist_acctype<T>>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
+  dist_acctype<T> x = (val & MASK) * DIVISOR;
+  return (x * (to - from) + from);
+}
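+
+// Worked sketch (illustrative only; uniform_real_example is not part of this
+// header): for T = double, std::numeric_limits<double>::digits is 53, so MASK
+// keeps the low 53 bits of the raw draw and DIVISOR equals 2^-53; a raw value
+// with all of those bits set maps to just under `to`.
+inline double uniform_real_example(uint64_t raw_bits) {
+  // Maps one raw 64-bit engine output to a double in [10.0, 20.0).
+  return uniform_real<double, uint64_t>(raw_bits, 10.0, 20.0);
+}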
+
+/**
+ * Transforms normally distributed `val` with mean 0.0 and standard deviation 1.0 to
+ * normally distributed with `mean` and standard deviation `std`.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T normal(T val, T mean, T std) {
+  return val * std + mean;
+}
+
+/**
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
+ * Cauchy distribution with location parameter `median` and scale parameter `sigma`.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
+  // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
+  // __tanf overflows and returns `inf/-inf` when (val > 1 - eps) or (val < 0 + eps),
+  // thus we clip those values.
+  constexpr T eps = std::numeric_limits<T>::epsilon();
+  constexpr T one_minus_eps = 1 - eps;
+  constexpr T zero_plus_eps = 0 + eps;
+  val = (val > one_minus_eps ? one_minus_eps : val);
+  val = (val < zero_plus_eps ? zero_plus_eps : val);
+  return median + sigma * at::tan(c10::pi<T> * (val - static_cast<T>(0.5)));
+}
+
+template <>
+C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
+  // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
+  return median + sigma * at::tan(c10::pi<double> * (val - static_cast<double>(0.5)));
+}
+
+/**
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
+ * exponentially distributed with `lambda` parameter of the distribution.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T exponential(T val, T lambda) {
+  // https://en.wikipedia.org/wiki/Exponential_distribution#Generating_exponential_variates
+  // Different implementations for CUDA and CPU to preserve original logic
+  // TODO: must be investigated and unified!!!
+  // https://github.com/pytorch/pytorch/issues/38662
+#if defined(__CUDACC__) || defined(__HIPCC__)
+      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
+      // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
+      // we need log to be not 0, and not underflow when converted to half
+      // fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
+  auto log = val >= static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2
+      ? -std::numeric_limits<T>::epsilon() / 2
+      : at::log(val);
+  return static_cast<T>(-1.0) / lambda * log;
+#else
+  return static_cast<T>(-1.0) / lambda * at::log1p(-val);
+#endif
+}
+
+/**
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
+ * geometrically distributed with success probability `p`.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T geometric(T val, T p) {
+  // https://en.wikipedia.org/wiki/Geometric_distribution#Related_distributions
+  return static_cast<T>(::ceil(at::log(val) / at::log1p(-p)));
+}
+
+/**
+ * Transforms normally distributed `val` to log-normally distributed.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T log_normal(T val) {
+  // https://en.wikipedia.org/wiki/Log-normal_distribution#Mode,_median,_quantiles
+  return at::exp(val);
+}
+
+/**
+ * Transforms uniformly distributed `val` between 0.0 and 1.0 to
+ * bernoulli distributed with success probability `p`.
+ */
+template <typename T>
+C10_HOST_DEVICE inline T bernoulli(T val, T p) {
+  return val < p;
+}
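+
+// Usage sketch (illustrative only; transformation_pipeline_example is not part of
+// this header): the transformations above are pure functions of a uniform draw, so
+// any 64-bit engine output can drive them; raw_bits stands in for a generator value.
+inline bool transformation_pipeline_example(uint64_t raw_bits) {
+  // One raw 64-bit value -> uniform double in [0, 1).
+  double u = uniform_real<double, uint64_t>(raw_bits, 0.0, 1.0);
+  // Exponential(lambda = 2.0) via the inverse-CDF form above, plus a Bernoulli(0.5) flip.
+  double e = exponential<double>(u, 2.0);
+  bool heads = bernoulli<double>(u, 0.5) != 0.0;
+  return heads && e >= 0.0;
+}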
+
+}} // namespace at::transformation
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..49612392cc4f66224d30e9480522acca886fd293
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/UndefinedTensorImpl.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h b/MLPY/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h
new file mode 100644
index 0000000000000000000000000000000000000000..a47ad1586d70587faf7dea99d50b20dbea3a344f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/UnsafeFromTH.h
@@ -0,0 +1,21 @@
+#pragma once
+#include 
+
+namespace at {
+
+inline Tensor unsafeTensorFromTH(void * th_pointer, bool retain) {
+  auto tensor_impl = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(static_cast<TensorImpl*>(th_pointer));
+  if (retain && tensor_impl.get() != UndefinedTensorImpl::singleton()) {
+    c10::raw::intrusive_ptr::incref(tensor_impl.get());
+  }
+  return Tensor(std::move(tensor_impl));
+}
+
+inline Storage unsafeStorageFromTH(void * th_pointer, bool retain) {
+  if (retain && th_pointer) {
+    c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
+  }
+  return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
+}
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..f315a092b63b7993527e6483782a8678bb1583be
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/VariableHooksInterface.h
@@ -0,0 +1,75 @@
+#pragma once
+
+#include 
+#include 
+
+// A little explanation about why this file exists at all.  We have
+// a few methods on Tensor class which require access to reified access to
+// AutogradMeta.  In open source, this isn't a big deal: we just access
+// torch/csrc/autograd/variable.h from aten/src/ATen/core/Tensor.cpp and
+// we can put the definitions inline.  This is because everything gets balled
+// into a single dynamic library in the end.
+//
+// However, inside our Facebook internal version of our build system, we
+// have a split between aten and torch/csrc.  So we cannot simply just
+// cross this boundary.  "Now wait," you might say, "Why don't we just
+// merge the libraries inside Facebook".  Well, the problem is that there
+// are some downstream applications which are at binary size limit, and
+// incorporating all of the extra code from libtorch would push them
+// over (admarket/adreview/service:adreviewservice, see also
+// https://github.com/pytorch/pytorch/pull/29299)  So if you want to do that,
+// we have to fix all of the services like this.
+//
+// I didn't want to block eliminating Tensor-Variable on this work, so I
+// had to introduce another dynamic dispatch to get to the variable
+// implementations (which live in torch/csrc/autograd/variable.cpp, FYI).
+//
+// I also considered using our existing dynamic dispatch mechanism, c10
+// dispatcher, to do this.  However, (1) some of the functions on Tensor
+// have weird signatures that are not supported by autograd, and (2)
+// see this bug https://github.com/pytorch/pytorch/issues/30102
+
+namespace torch { namespace autograd {
+
+struct Node;
+
+}} // namespace torch::autograd
+
+namespace at {
+namespace impl {
+
+struct TORCH_API VariableHooksInterface {
+  virtual ~VariableHooksInterface() = default;
+  virtual TensorBase tensor_data(const TensorBase&) const = 0;
+  virtual TensorBase variable_data(const TensorBase&) const = 0;
+  virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(const TensorBase&) const = 0;
+  virtual unsigned _register_hook(
+      const TensorBase&,
+      std::function<TensorBase(const TensorBase&)> hook) const = 0;
+  virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
+  virtual bool is_view(const TensorBase&) const = 0;
+  virtual const TensorBase& base(const TensorBase&) const = 0;
+  virtual const std::string& name(const TensorBase&) const = 0;
+  virtual bool is_leaf(const TensorBase&) const = 0;
+  virtual int64_t output_nr(const TensorBase&) const = 0;
+  virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
+  virtual TensorBase data(const TensorBase&) const = 0;
+  virtual int64_t _version(const TensorBase&) const = 0;
+  virtual void retain_grad(const TensorBase&) const = 0;
+  virtual bool retains_grad(const TensorBase&) const = 0;
+  virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;
+  virtual void requires_grad_(const TensorBase&, bool) const = 0;
+  virtual void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) const = 0;
+};
+
+TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
+TORCH_API VariableHooksInterface* GetVariableHooks();
+TORCH_API bool HasVariableHooks();
+
+struct TORCH_API VariableHooksRegisterer {
+  explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
+    SetVariableHooks(hooks);
+  }
+};
+
+}} // namespace at::impl
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Variadic.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Variadic.h
new file mode 100644
index 0000000000000000000000000000000000000000..22007cf39eff2b832732d68c7bb0bd9429a49865
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Variadic.h
@@ -0,0 +1,95 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at {
+
+// This class allows you to write variadic functions which
+// call a (possibly overloaded) function on each argument,
+// in order.  This is most commonly used in autogenerated code,
+// where it is convenient to have a function that can uniformly
+// take arguments of different types.  If your arguments
+// are homogeneous, consider using a std::initializer_list instead.
+//
+// For examples of this in use, see torch/csrc/utils/variadic.h
+template <typename F>
+struct IterArgs {
+  template <typename... Args>
+  inline F& apply() {
+    return self();
+  }
+
+  // NB: Use perfect forwarding here, otherwise we'll make value
+  // copies of all arguments!
+  template <typename T, typename... Args>
+  inline F& apply(T&& arg, Args&&... args) {
+    self()(std::forward<T>(arg));
+    if (self().short_circuit()) {
+      return self();
+    } else {
+      return apply(std::forward<Args>(args)...);
+    }
+  }
+
+  // Here are some handy overloads which provide sensible
+  // defaults for container-like structures that one might
+  // be interested in recursing into.  You can enable them
+  // by adding:
+  //
+  //    using IterArgs<YourStructName>::operator()
+  //
+  // to your struct.  These are not enabled by default because
+  // you may be able to process these structures more efficiently
+  // than handling them one-by-one.
+
+  template <typename T>
+  void operator()(c10::IListRef<T> args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
+  template <typename T>
+  void operator()(at::ArrayRef<T> args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
+  template <typename T>
+  void operator()(const torch::List<T>& args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
+  // NB: we need to specify std::vector manually as C++ won't
+  // do an implicit conversion to make a template deduction go through.
+  template <typename T>
+  void operator()(const std::vector<T>& args) {
+    self()(at::ArrayRef<T>{args});
+  }
+
+  constexpr bool short_circuit() const {
+    return false;
+  }
+
+ private:
+  inline F& self() {
+    return *static_cast<F*>(this);
+  }
+};
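+
+// Minimal CRTP sketch (illustrative only; SumInts is not part of this header):
+// a visitor derived from IterArgs that folds every argument into a running total.
+struct SumInts : IterArgs<SumInts> {
+  int64_t total = 0;
+  void operator()(int64_t x) {
+    total += x;
+  }
+};
+// apply() forwards each argument to operator() in order and returns the visitor,
+// so SumInts().apply(1, 2, 3).total evaluates to 6.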
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/Vitals.h b/MLPY/Lib/site-packages/torch/include/ATen/core/Vitals.h
new file mode 100644
index 0000000000000000000000000000000000000000..0fbaa61f37c9f32951c52855b22efac8d73ac74d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/Vitals.h
@@ -0,0 +1,96 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+namespace vitals {
+
+TORCH_API bool torchVitalEnabled();
+
+struct TORCH_API TorchVitalAttr {
+  // always initialized to empty
+  std::string value = "";
+  template <typename T>
+  TorchVitalAttr& operator<<(const T& t) {
+    if (torchVitalEnabled()) {
+      std::stringstream ss;
+      ss << t;
+      value += ss.str();
+    }
+    return *this;
+  }
+
+  template <typename T>
+  void write(const T& t, bool force) {
+    if (force || torchVitalEnabled()) {
+      std::stringstream ss;
+      ss << t;
+      value = ss.str();
+    }
+  }
+};
+
+struct TORCH_API TorchVital {
+  std::string name;
+  std::unordered_map<std::string, TorchVitalAttr> attrs;
+
+  explicit TorchVital(std::string n) : name(std::move(n)) {}
+  TorchVital(const TorchVital&) = default;
+  TorchVital(TorchVital&&) = default;
+  TorchVital() = delete;
+
+  TorchVitalAttr& create(const std::string& attr);
+  TorchVitalAttr& create(const std::string& attr, bool force);
+  friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt);
+
+  ~TorchVital();
+};
+
+std::ostream& operator<<(std::ostream& os, TorchVital const& tv);
+
+// A way to access vitals by string names instead of by global reference.
+// This enables access to vitals from the PythonAPI.
+class TORCH_API APIVitals {
+ public:
+  bool vitals_enabled;
+
+  // Set any vital sign that was added to the map.
+  bool setVital(
+      const std::string& vital_name,
+      const std::string& attr_name,
+      const std::string& value,
+      bool force = false);
+  std::string readVitals();
+
+  APIVitals();
+
+  // Ensure this stays a singleton
+  APIVitals(APIVitals const& other) = delete;
+  APIVitals(APIVitals&& other) = delete;
+  APIVitals& operator=(const APIVitals&) = delete;
+  APIVitals& operator=(APIVitals&&) = delete;
+
+ private:
+  std::unordered_map<std::string, TorchVital> name_map_;
+};
+
+extern TORCH_API APIVitals VitalsAPI;
+
+} // namespace vitals
+} // namespace at
+
+#define TORCH_VITAL_DECLARE(name) \
+  TORCH_API at::vitals::TorchVital TorchVital_##name;
+
+#define TORCH_VITAL_DEFINE(name) \
+  TORCH_API at::vitals::TorchVital TorchVital_##name(#name);
+
+#define TORCH_VITAL_BASE(name) TorchVital_##name
+
+#define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr)
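+
+// Usage sketch (comments only, illustrative): one translation unit defines a vital,
+// then streams attribute text into it; writes are no-ops unless torchVitalEnabled()
+// returns true or force is passed.
+//
+//   TORCH_VITAL_DEFINE(Example);                        // at namespace scope
+//   TORCH_VITAL(Example, used) << true;                 // appends "1" to the attr
+//   TORCH_VITAL_BASE(Example).create("version", /*force=*/true).write("2.1", true);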
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/alias_info.h b/MLPY/Lib/site-packages/torch/include/ATen/core/alias_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..9670e92646c44d7ca23700010f7ea971bc0b7989
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/alias_info.h
@@ -0,0 +1,151 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+/**
+ * class AliasInfo
+ *
+ * Data structure to hold aliasing information for an `Argument`. They can be
+ * nested to represent aliasing information on contained types.
+ *
+ * There is a `beforeSet` which describes the aliasing information before the
+ * operator executes, and an `afterSet` that describes aliasing info
+ * after execution.
+ */
+class AliasInfo {
+ public:
+  // Symbol for the set that can alias anything
+  static Symbol wildcardSet() {
+    static const Symbol wc = Symbol::fromQualString("alias::*");
+    return wc;
+  }
+
+  void setIsWrite(bool isWrite) {
+    isWrite_ = isWrite;
+  }
+
+  bool isWrite() const {
+    return isWrite_;
+  }
+
+  void addBeforeSet(Symbol aliasSet) {
+    beforeSets_.insert(aliasSet);
+  }
+
+  void addAfterSet(Symbol aliasSet) {
+    afterSets_.insert(aliasSet);
+  }
+
+  const std::unordered_set<Symbol>& beforeSets() const {
+    return beforeSets_;
+  }
+
+  const std::unordered_set<Symbol>& afterSets() const {
+    return afterSets_;
+  }
+
+  Symbol beforeSet() const {
+    AT_ASSERT(beforeSets_.size() == 1);
+    return *beforeSets_.begin();
+  }
+
+  bool isWildcardBefore() const {
+    return beforeSets_.count(wildcardSet()) != 0;
+  }
+
+  bool isWildcardAfter() const {
+    return afterSets_.count(wildcardSet()) != 0;
+  }
+
+  // the alias info for the contained types of the type
+  // e.g. if this is an annotation on List[T], `sets` refers to
+  // the alias sets that the list may be in
+  // while containedTypes()[0] refers to the sets that members of the list
+  // may be in
+  void addContainedType(AliasInfo aliasInfo) {
+    containedTypes_.push_back(std::move(aliasInfo));
+  }
+  const std::vector<AliasInfo>& containedTypes() const {
+    return containedTypes_;
+  }
+
+ private:
+  std::unordered_set<Symbol> beforeSets_;
+  std::unordered_set<Symbol> afterSets_;
+  std::vector<AliasInfo> containedTypes_;
+  bool isWrite_ = false;
+};
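+
+// Construction sketch (illustrative only; make_written_alias_example is not part of
+// this header): the schema annotation `Tensor(a!)` corresponds to an AliasInfo that
+// belongs to alias set `a` before and after the call and writes to its argument.
+inline AliasInfo make_written_alias_example() {
+  AliasInfo info;
+  info.addBeforeSet(Symbol::fromQualString("alias::a"));
+  info.addAfterSet(Symbol::fromQualString("alias::a"));
+  info.setIsWrite(true);
+  return info;
+}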
+
+inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
+  return lhs.isWrite() == rhs.isWrite()
+      && lhs.beforeSets() == rhs.beforeSets()
+      && lhs.afterSets() == rhs.afterSets()
+      && lhs.containedTypes() == rhs.containedTypes();
+}
+
+// this does match the way things are represented in the schema
+inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
+  out << "(";
+  bool first = true;
+  for (const auto& set : aliasInfo.beforeSets()) {
+    if (first) {
+      first = false;
+    } else {
+      out << "|";
+    }
+    out << set.toUnqualString();
+  }
+  if (aliasInfo.isWrite()) {
+    out << "!";
+  }
+  if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
+    out << " -> ";
+    first = true;
+    for (const auto& set : aliasInfo.afterSets()) {
+      if (first) {
+        first = false;
+      } else {
+        out << "|";
+      }
+      out << set.toUnqualString();
+    }
+  }
+  out << ")";
+  return out;
+}
+} // namespace c10
+
+namespace std {
+template <>
+  struct hash<c10::AliasInfo> {
+    size_t operator()(const c10::AliasInfo& aliasInfo) const {
+      auto hash = std::hash<bool>()(aliasInfo.isWrite());
+
+      // NOTE: for unordered_set hashes, we couldn't use hash_combine
+      // because hash_combine is order dependent. Instead, we choose to
+      // use XOR as the combining function as XOR is commutative.
+      size_t before_set_hash_seed = 0;
+      for (auto &e: aliasInfo.beforeSets()) {
+        auto symbol_hash = std::hash<c10::Symbol>()(e);
+        before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
+      }
+      size_t after_set_hash_seed = 0;
+      for (auto &e: aliasInfo.afterSets()) {
+        auto symbol_hash = std::hash<c10::Symbol>()(e);
+        after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
+      }
+
+      hash = c10::hash_combine(hash, before_set_hash_seed);
+      hash = c10::hash_combine(hash, after_set_hash_seed);
+      for (auto &e: aliasInfo.containedTypes()) {
+        auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
+        hash = c10::hash_combine(hash, contained_type_hash);
+      }
+      return hash;
+    }
+  };
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h b/MLPY/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h
new file mode 100644
index 0000000000000000000000000000000000000000..8348b554d6f189e1b35d86a087822f67286dc235
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h
@@ -0,0 +1,2213 @@
+#pragma once
+
+// @generated by torchgen/gen.py from aten_interned_strings.h
+
+#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on native_functions.yaml,          \
+  meaning the file will need to be re-compiled every time an operator   \
+  is changed or added. Consider if including <ATen/core/symbol.h> for   \
+  the c10::Symbol class would be sufficient, or if your change would be \
+  better placed in another file.
+#endif
+
+// ATen symbols correspond exactly to operators defined in ATen. Every
+// symbol here corresponds exactly to an ATen operation defined in
+// native_functions.yaml; attributes are in one-to-one correspondence
+// with their ATen name.
+
+#define FORALL_ATEN_BASE_SYMBOLS(_) \
+_(aten, __and__) \
+_(aten, __iand__) \
+_(aten, __ilshift__) \
+_(aten, __ior__) \
+_(aten, __irshift__) \
+_(aten, __ixor__) \
+_(aten, __lshift__) \
+_(aten, __or__) \
+_(aten, __rshift__) \
+_(aten, __xor__) \
+_(aten, _adaptive_avg_pool2d) \
+_(aten, _adaptive_avg_pool2d_backward) \
+_(aten, _adaptive_avg_pool3d) \
+_(aten, _adaptive_avg_pool3d_backward) \
+_(aten, _add_batch_dim) \
+_(aten, _add_relu) \
+_(aten, _add_relu_) \
+_(aten, _addmm_activation) \
+_(aten, _aminmax) \
+_(aten, _amp_foreach_non_finite_check_and_unscale) \
+_(aten, _amp_foreach_non_finite_check_and_unscale_) \
+_(aten, _amp_update_scale) \
+_(aten, _amp_update_scale_) \
+_(aten, _assert_async) \
+_(aten, _assert_scalar) \
+_(aten, _assert_tensor_metadata) \
+_(aten, _autocast_to_full_precision) \
+_(aten, _autocast_to_reduced_precision) \
+_(aten, _backward) \
+_(aten, _batch_norm_impl_index) \
+_(aten, _batch_norm_impl_index_backward) \
+_(aten, _cast_Byte) \
+_(aten, _cast_Char) \
+_(aten, _cast_Double) \
+_(aten, _cast_Float) \
+_(aten, _cast_Half) \
+_(aten, _cast_Int) \
+_(aten, _cast_Long) \
+_(aten, _cast_Short) \
+_(aten, _cdist_backward) \
+_(aten, _cdist_forward) \
+_(aten, _cholesky_solve_helper) \
+_(aten, _choose_qparams_per_tensor) \
+_(aten, _chunk_cat) \
+_(aten, _coalesce) \
+_(aten, _coalesced) \
+_(aten, _coalesced_) \
+_(aten, _compute_linear_combination) \
+_(aten, _conj) \
+_(aten, _conj_copy) \
+_(aten, _conj_physical) \
+_(aten, _conv_depthwise2d) \
+_(aten, _convert_indices_from_coo_to_csr) \
+_(aten, _convert_indices_from_csr_to_coo) \
+_(aten, _convert_weight_to_int4pack) \
+_(aten, _convolution) \
+_(aten, _convolution_double_backward) \
+_(aten, _convolution_mode) \
+_(aten, _copy_from) \
+_(aten, _copy_from_and_resize) \
+_(aten, _cslt_compress) \
+_(aten, _cslt_sparse_mm) \
+_(aten, _cslt_sparse_mm_search) \
+_(aten, _ctc_loss) \
+_(aten, _ctc_loss_backward) \
+_(aten, _cudnn_ctc_loss) \
+_(aten, _cudnn_init_dropout_state) \
+_(aten, _cudnn_rnn) \
+_(aten, _cudnn_rnn_backward) \
+_(aten, _cudnn_rnn_flatten_weight) \
+_(aten, _cufft_clear_plan_cache) \
+_(aten, _cufft_get_plan_cache_max_size) \
+_(aten, _cufft_get_plan_cache_size) \
+_(aten, _cufft_set_plan_cache_max_size) \
+_(aten, _cummax_helper) \
+_(aten, _cummin_helper) \
+_(aten, _debug_has_internal_overlap) \
+_(aten, _dimI) \
+_(aten, _dimV) \
+_(aten, _dim_arange) \
+_(aten, _dirichlet_grad) \
+_(aten, _efficient_attention_backward) \
+_(aten, _efficient_attention_forward) \
+_(aten, _efficientzerotensor) \
+_(aten, _embedding_bag) \
+_(aten, _embedding_bag_backward) \
+_(aten, _embedding_bag_dense_backward) \
+_(aten, _embedding_bag_forward_only) \
+_(aten, _embedding_bag_per_sample_weights_backward) \
+_(aten, _embedding_bag_sparse_backward) \
+_(aten, _empty_affine_quantized) \
+_(aten, _empty_per_channel_affine_quantized) \
+_(aten, _euclidean_dist) \
+_(aten, _fake_quantize_learnable_per_channel_affine) \
+_(aten, _fake_quantize_learnable_per_channel_affine_backward) \
+_(aten, _fake_quantize_learnable_per_tensor_affine) \
+_(aten, _fake_quantize_learnable_per_tensor_affine_backward) \
+_(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \
+_(aten, _fft_c2c) \
+_(aten, _fft_c2r) \
+_(aten, _fft_r2c) \
+_(aten, _fill_mem_eff_dropout_mask) \
+_(aten, _fill_mem_eff_dropout_mask_) \
+_(aten, _flash_attention_backward) \
+_(aten, _flash_attention_forward) \
+_(aten, _foobar) \
+_(aten, _foreach_abs) \
+_(aten, _foreach_abs_) \
+_(aten, _foreach_acos) \
+_(aten, _foreach_acos_) \
+_(aten, _foreach_add) \
+_(aten, _foreach_add_) \
+_(aten, _foreach_addcdiv) \
+_(aten, _foreach_addcdiv_) \
+_(aten, _foreach_addcmul) \
+_(aten, _foreach_addcmul_) \
+_(aten, _foreach_asin) \
+_(aten, _foreach_asin_) \
+_(aten, _foreach_atan) \
+_(aten, _foreach_atan_) \
+_(aten, _foreach_ceil) \
+_(aten, _foreach_ceil_) \
+_(aten, _foreach_clamp_max) \
+_(aten, _foreach_clamp_max_) \
+_(aten, _foreach_clamp_min) \
+_(aten, _foreach_clamp_min_) \
+_(aten, _foreach_copy) \
+_(aten, _foreach_copy_) \
+_(aten, _foreach_cos) \
+_(aten, _foreach_cos_) \
+_(aten, _foreach_cosh) \
+_(aten, _foreach_cosh_) \
+_(aten, _foreach_div) \
+_(aten, _foreach_div_) \
+_(aten, _foreach_erf) \
+_(aten, _foreach_erf_) \
+_(aten, _foreach_erfc) \
+_(aten, _foreach_erfc_) \
+_(aten, _foreach_exp) \
+_(aten, _foreach_exp_) \
+_(aten, _foreach_expm1) \
+_(aten, _foreach_expm1_) \
+_(aten, _foreach_floor) \
+_(aten, _foreach_floor_) \
+_(aten, _foreach_frac) \
+_(aten, _foreach_frac_) \
+_(aten, _foreach_lerp) \
+_(aten, _foreach_lerp_) \
+_(aten, _foreach_lgamma) \
+_(aten, _foreach_lgamma_) \
+_(aten, _foreach_log) \
+_(aten, _foreach_log10) \
+_(aten, _foreach_log10_) \
+_(aten, _foreach_log1p) \
+_(aten, _foreach_log1p_) \
+_(aten, _foreach_log2) \
+_(aten, _foreach_log2_) \
+_(aten, _foreach_log_) \
+_(aten, _foreach_maximum) \
+_(aten, _foreach_maximum_) \
+_(aten, _foreach_minimum) \
+_(aten, _foreach_minimum_) \
+_(aten, _foreach_mul) \
+_(aten, _foreach_mul_) \
+_(aten, _foreach_neg) \
+_(aten, _foreach_neg_) \
+_(aten, _foreach_norm) \
+_(aten, _foreach_pow) \
+_(aten, _foreach_pow_) \
+_(aten, _foreach_reciprocal) \
+_(aten, _foreach_reciprocal_) \
+_(aten, _foreach_round) \
+_(aten, _foreach_round_) \
+_(aten, _foreach_sigmoid) \
+_(aten, _foreach_sigmoid_) \
+_(aten, _foreach_sign) \
+_(aten, _foreach_sign_) \
+_(aten, _foreach_sin) \
+_(aten, _foreach_sin_) \
+_(aten, _foreach_sinh) \
+_(aten, _foreach_sinh_) \
+_(aten, _foreach_sqrt) \
+_(aten, _foreach_sqrt_) \
+_(aten, _foreach_sub) \
+_(aten, _foreach_sub_) \
+_(aten, _foreach_tan) \
+_(aten, _foreach_tan_) \
+_(aten, _foreach_tanh) \
+_(aten, _foreach_tanh_) \
+_(aten, _foreach_trunc) \
+_(aten, _foreach_trunc_) \
+_(aten, _foreach_zero) \
+_(aten, _foreach_zero_) \
+_(aten, _functional_assert_async) \
+_(aten, _functional_assert_scalar) \
+_(aten, _functional_sym_constrain_range) \
+_(aten, _functional_sym_constrain_range_for_size) \
+_(aten, _fused_adam) \
+_(aten, _fused_adam_) \
+_(aten, _fused_adamw) \
+_(aten, _fused_adamw_) \
+_(aten, _fused_dropout) \
+_(aten, _fused_moving_avg_obs_fq_helper) \
+_(aten, _fused_moving_avg_obs_fq_helper_functional) \
+_(aten, _fused_sdp_choice) \
+_(aten, _fused_sgd) \
+_(aten, _fused_sgd_) \
+_(aten, _fw_primal) \
+_(aten, _fw_primal_copy) \
+_(aten, _gather_sparse_backward) \
+_(aten, _grid_sampler_2d_cpu_fallback) \
+_(aten, _grid_sampler_2d_cpu_fallback_backward) \
+_(aten, _has_compatible_shallow_copy_type) \
+_(aten, _has_same_storage_numel) \
+_(aten, _histogramdd_bin_edges) \
+_(aten, _histogramdd_from_bin_cts) \
+_(aten, _histogramdd_from_bin_tensors) \
+_(aten, _index_put_impl) \
+_(aten, _index_put_impl_) \
+_(aten, _indices) \
+_(aten, _indices_copy) \
+_(aten, _int_mm) \
+_(aten, _is_all_true) \
+_(aten, _is_any_true) \
+_(aten, _is_zerotensor) \
+_(aten, _lazy_clone) \
+_(aten, _linalg_check_errors) \
+_(aten, _linalg_det) \
+_(aten, _linalg_eigh) \
+_(aten, _linalg_eigvals) \
+_(aten, _linalg_slogdet) \
+_(aten, _linalg_solve_ex) \
+_(aten, _linalg_svd) \
+_(aten, _local_scalar_dense) \
+_(aten, _log_softmax) \
+_(aten, _log_softmax_backward_data) \
+_(aten, _logcumsumexp) \
+_(aten, _lstm_mps) \
+_(aten, _lu_with_info) \
+_(aten, _make_dep_token) \
+_(aten, _make_dual) \
+_(aten, _make_dual_copy) \
+_(aten, _make_per_channel_quantized_tensor) \
+_(aten, _make_per_tensor_quantized_tensor) \
+_(aten, _masked_scale) \
+_(aten, _masked_softmax) \
+_(aten, _masked_softmax_backward) \
+_(aten, _mixed_dtypes_linear) \
+_(aten, _mkldnn_reshape) \
+_(aten, _mkldnn_transpose) \
+_(aten, _mkldnn_transpose_) \
+_(aten, _mps_convolution) \
+_(aten, _mps_convolution_transpose) \
+_(aten, _native_batch_norm_legit) \
+_(aten, _native_batch_norm_legit_functional) \
+_(aten, _native_batch_norm_legit_no_training) \
+_(aten, _native_multi_head_attention) \
+_(aten, _neg_view) \
+_(aten, _neg_view_copy) \
+_(aten, _nested_from_padded) \
+_(aten, _nested_from_padded_and_nested_example) \
+_(aten, _nested_get_jagged_dummy) \
+_(aten, _nested_get_lengths) \
+_(aten, _nested_get_offsets) \
+_(aten, _nested_get_ragged_idx) \
+_(aten, _nested_get_values) \
+_(aten, _nested_get_values_copy) \
+_(aten, _nested_select_backward) \
+_(aten, _nested_sum_backward) \
+_(aten, _nested_tensor_from_mask) \
+_(aten, _nested_tensor_from_mask_left_aligned) \
+_(aten, _nested_tensor_from_tensor_list) \
+_(aten, _nested_tensor_size) \
+_(aten, _nested_tensor_softmax_with_shape) \
+_(aten, _nested_tensor_storage_offsets) \
+_(aten, _nested_tensor_strides) \
+_(aten, _nested_view_from_buffer) \
+_(aten, _nested_view_from_buffer_copy) \
+_(aten, _nested_view_from_jagged) \
+_(aten, _nested_view_from_jagged_copy) \
+_(aten, _new_zeros_with_same_feature_meta) \
+_(aten, _nnpack_available) \
+_(aten, _nnpack_spatial_convolution) \
+_(aten, _nnz) \
+_(aten, _pack_padded_sequence) \
+_(aten, _pack_padded_sequence_backward) \
+_(aten, _pad_circular) \
+_(aten, _pad_enum) \
+_(aten, _pad_packed_sequence) \
+_(aten, _pdist_backward) \
+_(aten, _pdist_forward) \
+_(aten, _pin_memory) \
+_(aten, _prelu_kernel) \
+_(aten, _prelu_kernel_backward) \
+_(aten, _print) \
+_(aten, _propagate_xla_data) \
+_(aten, _remove_batch_dim) \
+_(aten, _reshape_alias) \
+_(aten, _reshape_alias_copy) \
+_(aten, _reshape_copy) \
+_(aten, _reshape_from_tensor) \
+_(aten, _resize_output) \
+_(aten, _resize_output_) \
+_(aten, _rowwise_prune) \
+_(aten, _sample_dirichlet) \
+_(aten, _saturate_weight_to_fp16) \
+_(aten, _scaled_dot_product_attention_math) \
+_(aten, _scaled_dot_product_cudnn_attention) \
+_(aten, _scaled_dot_product_efficient_attention) \
+_(aten, _scaled_dot_product_efficient_attention_backward) \
+_(aten, _scaled_dot_product_flash_attention) \
+_(aten, _scaled_dot_product_flash_attention_backward) \
+_(aten, _scaled_dot_product_flash_attention_for_cpu) \
+_(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \
+_(aten, _scaled_mm) \
+_(aten, _segment_reduce_backward) \
+_(aten, _shape_as_tensor) \
+_(aten, _slow_conv2d_backward) \
+_(aten, _slow_conv2d_forward) \
+_(aten, _sobol_engine_draw) \
+_(aten, _sobol_engine_ff) \
+_(aten, _sobol_engine_ff_) \
+_(aten, _sobol_engine_initialize_state) \
+_(aten, _sobol_engine_initialize_state_) \
+_(aten, _sobol_engine_scramble) \
+_(aten, _sobol_engine_scramble_) \
+_(aten, _softmax) \
+_(aten, _softmax_backward_data) \
+_(aten, _sparse_addmm) \
+_(aten, _sparse_broadcast_to) \
+_(aten, _sparse_broadcast_to_copy) \
+_(aten, _sparse_bsc_tensor_unsafe) \
+_(aten, _sparse_bsr_tensor_unsafe) \
+_(aten, _sparse_compressed_tensor_unsafe) \
+_(aten, _sparse_coo_tensor_unsafe) \
+_(aten, _sparse_coo_tensor_with_dims) \
+_(aten, _sparse_coo_tensor_with_dims_and_tensors) \
+_(aten, _sparse_csc_tensor_unsafe) \
+_(aten, _sparse_csr_prod) \
+_(aten, _sparse_csr_sum) \
+_(aten, _sparse_csr_tensor_unsafe) \
+_(aten, _sparse_log_softmax) \
+_(aten, _sparse_log_softmax_backward_data) \
+_(aten, _sparse_mask_projection) \
+_(aten, _sparse_mm) \
+_(aten, _sparse_mm_reduce_impl) \
+_(aten, _sparse_mm_reduce_impl_backward) \
+_(aten, _sparse_semi_structured_linear) \
+_(aten, _sparse_softmax) \
+_(aten, _sparse_softmax_backward_data) \
+_(aten, _sparse_sparse_matmul) \
+_(aten, _sparse_sum) \
+_(aten, _sparse_sum_backward) \
+_(aten, _spdiags) \
+_(aten, _stack) \
+_(aten, _standard_gamma) \
+_(aten, _standard_gamma_grad) \
+_(aten, _test_ambiguous_defaults) \
+_(aten, _test_autograd_multiple_dispatch) \
+_(aten, _test_autograd_multiple_dispatch_view) \
+_(aten, _test_autograd_multiple_dispatch_view_copy) \
+_(aten, _test_check_tensor) \
+_(aten, _test_functorch_fallback) \
+_(aten, _test_optional_filled_intlist) \
+_(aten, _test_optional_floatlist) \
+_(aten, _test_optional_intlist) \
+_(aten, _test_parallel_materialize) \
+_(aten, _test_serialization_subcmul) \
+_(aten, _test_string_default) \
+_(aten, _test_warn_in_autograd) \
+_(aten, _thnn_differentiable_gru_cell_backward) \
+_(aten, _thnn_differentiable_lstm_cell_backward) \
+_(aten, _thnn_fused_gru_cell) \
+_(aten, _thnn_fused_gru_cell_backward) \
+_(aten, _thnn_fused_lstm_cell) \
+_(aten, _thnn_fused_lstm_cell_backward) \
+_(aten, _thnn_fused_lstm_cell_backward_impl) \
+_(aten, _to_copy) \
+_(aten, _to_cpu) \
+_(aten, _to_dense) \
+_(aten, _to_sparse) \
+_(aten, _to_sparse_bsc) \
+_(aten, _to_sparse_bsr) \
+_(aten, _to_sparse_csc) \
+_(aten, _to_sparse_csr) \
+_(aten, _to_sparse_semi_structured) \
+_(aten, _transform_bias_rescale_qkv) \
+_(aten, _transformer_encoder_layer_fwd) \
+_(aten, _trilinear) \
+_(aten, _triton_multi_head_attention) \
+_(aten, _triton_scaled_dot_attention) \
+_(aten, _unique) \
+_(aten, _unique2) \
+_(aten, _unpack_dual) \
+_(aten, _unsafe_index) \
+_(aten, _unsafe_index_put) \
+_(aten, _unsafe_view) \
+_(aten, _upsample_bicubic2d_aa) \
+_(aten, _upsample_bicubic2d_aa_backward) \
+_(aten, _upsample_bilinear2d_aa) \
+_(aten, _upsample_bilinear2d_aa_backward) \
+_(aten, _upsample_nearest_exact1d) \
+_(aten, _upsample_nearest_exact1d_backward) \
+_(aten, _upsample_nearest_exact2d) \
+_(aten, _upsample_nearest_exact2d_backward) \
+_(aten, _upsample_nearest_exact3d) \
+_(aten, _upsample_nearest_exact3d_backward) \
+_(aten, _use_cudnn_ctc_loss) \
+_(aten, _use_cudnn_rnn_flatten_weight) \
+_(aten, _validate_compressed_sparse_indices) \
+_(aten, _validate_sparse_bsc_tensor_args) \
+_(aten, _validate_sparse_bsr_tensor_args) \
+_(aten, _validate_sparse_compressed_tensor_args) \
+_(aten, _validate_sparse_coo_tensor_args) \
+_(aten, _validate_sparse_csc_tensor_args) \
+_(aten, _validate_sparse_csr_tensor_args) \
+_(aten, _values) \
+_(aten, _values_copy) \
+_(aten, _version) \
+_(aten, _weight_int4pack_mm) \
+_(aten, _weight_int8pack_mm) \
+_(aten, _weight_norm) \
+_(aten, _weight_norm_differentiable_backward) \
+_(aten, _weight_norm_interface) \
+_(aten, _weight_norm_interface_backward) \
+_(aten, abs) \
+_(aten, abs_) \
+_(aten, absolute) \
+_(aten, absolute_) \
+_(aten, acos) \
+_(aten, acos_) \
+_(aten, acosh) \
+_(aten, acosh_) \
+_(aten, adaptive_avg_pool1d) \
+_(aten, adaptive_avg_pool2d) \
+_(aten, adaptive_avg_pool3d) \
+_(aten, adaptive_avg_pool3d_backward) \
+_(aten, adaptive_max_pool1d) \
+_(aten, adaptive_max_pool2d) \
+_(aten, adaptive_max_pool2d_backward) \
+_(aten, adaptive_max_pool3d) \
+_(aten, adaptive_max_pool3d_backward) \
+_(aten, add) \
+_(aten, add_) \
+_(aten, addbmm) \
+_(aten, addbmm_) \
+_(aten, addcdiv) \
+_(aten, addcdiv_) \
+_(aten, addcmul) \
+_(aten, addcmul_) \
+_(aten, addmm) \
+_(aten, addmm_) \
+_(aten, addmv) \
+_(aten, addmv_) \
+_(aten, addr) \
+_(aten, addr_) \
+_(aten, adjoint) \
+_(aten, affine_grid_generator) \
+_(aten, affine_grid_generator_backward) \
+_(aten, alias) \
+_(aten, alias_copy) \
+_(aten, align_as) \
+_(aten, align_tensors) \
+_(aten, align_to) \
+_(aten, all) \
+_(aten, allclose) \
+_(aten, alpha_dropout) \
+_(aten, alpha_dropout_) \
+_(aten, amax) \
+_(aten, amin) \
+_(aten, aminmax) \
+_(aten, angle) \
+_(aten, any) \
+_(aten, arange) \
+_(aten, arccos) \
+_(aten, arccos_) \
+_(aten, arccosh) \
+_(aten, arccosh_) \
+_(aten, arcsin) \
+_(aten, arcsin_) \
+_(aten, arcsinh) \
+_(aten, arcsinh_) \
+_(aten, arctan) \
+_(aten, arctan2) \
+_(aten, arctan2_) \
+_(aten, arctan_) \
+_(aten, arctanh) \
+_(aten, arctanh_) \
+_(aten, argmax) \
+_(aten, argmin) \
+_(aten, argsort) \
+_(aten, argwhere) \
+_(aten, as_strided) \
+_(aten, as_strided_) \
+_(aten, as_strided_copy) \
+_(aten, as_strided_scatter) \
+_(aten, asin) \
+_(aten, asin_) \
+_(aten, asinh) \
+_(aten, asinh_) \
+_(aten, atan) \
+_(aten, atan2) \
+_(aten, atan2_) \
+_(aten, atan_) \
+_(aten, atanh) \
+_(aten, atanh_) \
+_(aten, atleast_1d) \
+_(aten, atleast_2d) \
+_(aten, atleast_3d) \
+_(aten, avg_pool1d) \
+_(aten, avg_pool2d) \
+_(aten, avg_pool2d_backward) \
+_(aten, avg_pool3d) \
+_(aten, avg_pool3d_backward) \
+_(aten, baddbmm) \
+_(aten, baddbmm_) \
+_(aten, bartlett_window) \
+_(aten, batch_norm) \
+_(aten, batch_norm_backward_elemt) \
+_(aten, batch_norm_backward_reduce) \
+_(aten, batch_norm_elemt) \
+_(aten, batch_norm_gather_stats) \
+_(aten, batch_norm_gather_stats_with_counts) \
+_(aten, batch_norm_stats) \
+_(aten, batch_norm_update_stats) \
+_(aten, bernoulli) \
+_(aten, bernoulli_) \
+_(aten, bilinear) \
+_(aten, binary_cross_entropy) \
+_(aten, binary_cross_entropy_backward) \
+_(aten, binary_cross_entropy_with_logits) \
+_(aten, bincount) \
+_(aten, binomial) \
+_(aten, bitwise_and) \
+_(aten, bitwise_and_) \
+_(aten, bitwise_left_shift) \
+_(aten, bitwise_left_shift_) \
+_(aten, bitwise_not) \
+_(aten, bitwise_not_) \
+_(aten, bitwise_or) \
+_(aten, bitwise_or_) \
+_(aten, bitwise_right_shift) \
+_(aten, bitwise_right_shift_) \
+_(aten, bitwise_xor) \
+_(aten, bitwise_xor_) \
+_(aten, blackman_window) \
+_(aten, block_diag) \
+_(aten, bmm) \
+_(aten, broadcast_tensors) \
+_(aten, broadcast_to) \
+_(aten, bucketize) \
+_(aten, can_cast) \
+_(aten, cartesian_prod) \
+_(aten, cat) \
+_(aten, cauchy) \
+_(aten, cauchy_) \
+_(aten, ccol_indices) \
+_(aten, ccol_indices_copy) \
+_(aten, cdist) \
+_(aten, ceil) \
+_(aten, ceil_) \
+_(aten, celu) \
+_(aten, celu_) \
+_(aten, chain_matmul) \
+_(aten, chalf) \
+_(aten, channel_shuffle) \
+_(aten, cholesky) \
+_(aten, cholesky_inverse) \
+_(aten, cholesky_solve) \
+_(aten, choose_qparams_optimized) \
+_(aten, chunk) \
+_(aten, clamp) \
+_(aten, clamp_) \
+_(aten, clamp_max) \
+_(aten, clamp_max_) \
+_(aten, clamp_min) \
+_(aten, clamp_min_) \
+_(aten, clip) \
+_(aten, clip_) \
+_(aten, clone) \
+_(aten, coalesce) \
+_(aten, col2im) \
+_(aten, col_indices) \
+_(aten, col_indices_copy) \
+_(aten, column_stack) \
+_(aten, combinations) \
+_(aten, complex) \
+_(aten, concat) \
+_(aten, concatenate) \
+_(aten, conj) \
+_(aten, conj_physical) \
+_(aten, conj_physical_) \
+_(aten, constant_pad_nd) \
+_(aten, contiguous) \
+_(aten, conv1d) \
+_(aten, conv2d) \
+_(aten, conv3d) \
+_(aten, conv_depthwise3d) \
+_(aten, conv_tbc) \
+_(aten, conv_tbc_backward) \
+_(aten, conv_transpose1d) \
+_(aten, conv_transpose2d) \
+_(aten, conv_transpose3d) \
+_(aten, convolution) \
+_(aten, convolution_backward) \
+_(aten, convolution_backward_overrideable) \
+_(aten, convolution_overrideable) \
+_(aten, copy) \
+_(aten, copy_) \
+_(aten, copy_sparse_to_sparse) \
+_(aten, copy_sparse_to_sparse_) \
+_(aten, copysign) \
+_(aten, copysign_) \
+_(aten, corrcoef) \
+_(aten, cos) \
+_(aten, cos_) \
+_(aten, cosh) \
+_(aten, cosh_) \
+_(aten, cosine_embedding_loss) \
+_(aten, cosine_similarity) \
+_(aten, count_nonzero) \
+_(aten, cov) \
+_(aten, cross) \
+_(aten, cross_entropy_loss) \
+_(aten, crow_indices) \
+_(aten, crow_indices_copy) \
+_(aten, ctc_loss) \
+_(aten, cudnn_affine_grid_generator) \
+_(aten, cudnn_affine_grid_generator_backward) \
+_(aten, cudnn_batch_norm) \
+_(aten, cudnn_batch_norm_backward) \
+_(aten, cudnn_convolution) \
+_(aten, cudnn_convolution_add_relu) \
+_(aten, cudnn_convolution_relu) \
+_(aten, cudnn_convolution_transpose) \
+_(aten, cudnn_grid_sampler) \
+_(aten, cudnn_grid_sampler_backward) \
+_(aten, cudnn_is_acceptable) \
+_(aten, cummax) \
+_(aten, cummaxmin_backward) \
+_(aten, cummin) \
+_(aten, cumprod) \
+_(aten, cumprod_) \
+_(aten, cumprod_backward) \
+_(aten, cumsum) \
+_(aten, cumsum_) \
+_(aten, cumulative_trapezoid) \
+_(aten, data) \
+_(aten, deg2rad) \
+_(aten, deg2rad_) \
+_(aten, dense_dim) \
+_(aten, dequantize) \
+_(aten, det) \
+_(aten, detach) \
+_(aten, detach_) \
+_(aten, detach_copy) \
+_(aten, diag) \
+_(aten, diag_embed) \
+_(aten, diagflat) \
+_(aten, diagonal) \
+_(aten, diagonal_backward) \
+_(aten, diagonal_copy) \
+_(aten, diagonal_scatter) \
+_(aten, diff) \
+_(aten, digamma) \
+_(aten, digamma_) \
+_(aten, dist) \
+_(aten, div) \
+_(aten, div_) \
+_(aten, divide) \
+_(aten, divide_) \
+_(aten, dot) \
+_(aten, dropout) \
+_(aten, dropout_) \
+_(aten, dsplit) \
+_(aten, dstack) \
+_(aten, einsum) \
+_(aten, elu) \
+_(aten, elu_) \
+_(aten, elu_backward) \
+_(aten, embedding) \
+_(aten, embedding_backward) \
+_(aten, embedding_bag) \
+_(aten, embedding_dense_backward) \
+_(aten, embedding_renorm) \
+_(aten, embedding_renorm_) \
+_(aten, embedding_sparse_backward) \
+_(aten, empty) \
+_(aten, empty_like) \
+_(aten, empty_permuted) \
+_(aten, empty_quantized) \
+_(aten, empty_strided) \
+_(aten, eq) \
+_(aten, eq_) \
+_(aten, equal) \
+_(aten, erf) \
+_(aten, erf_) \
+_(aten, erfc) \
+_(aten, erfc_) \
+_(aten, erfinv) \
+_(aten, erfinv_) \
+_(aten, exp) \
+_(aten, exp2) \
+_(aten, exp2_) \
+_(aten, exp_) \
+_(aten, expand) \
+_(aten, expand_as) \
+_(aten, expand_copy) \
+_(aten, expm1) \
+_(aten, expm1_) \
+_(aten, exponential) \
+_(aten, exponential_) \
+_(aten, eye) \
+_(aten, fake_quantize_per_channel_affine) \
+_(aten, fake_quantize_per_channel_affine_cachemask) \
+_(aten, fake_quantize_per_channel_affine_cachemask_backward) \
+_(aten, fake_quantize_per_tensor_affine) \
+_(aten, fake_quantize_per_tensor_affine_cachemask) \
+_(aten, fake_quantize_per_tensor_affine_cachemask_backward) \
+_(aten, fbgemm_linear_fp16_weight) \
+_(aten, fbgemm_linear_fp16_weight_fp32_activation) \
+_(aten, fbgemm_linear_int8_weight) \
+_(aten, fbgemm_linear_int8_weight_fp32_activation) \
+_(aten, fbgemm_linear_quantize_weight) \
+_(aten, fbgemm_pack_gemm_matrix_fp16) \
+_(aten, fbgemm_pack_quantized_matrix) \
+_(aten, feature_alpha_dropout) \
+_(aten, feature_alpha_dropout_) \
+_(aten, feature_dropout) \
+_(aten, feature_dropout_) \
+_(aten, fft_fft) \
+_(aten, fft_fft2) \
+_(aten, fft_fftfreq) \
+_(aten, fft_fftn) \
+_(aten, fft_fftshift) \
+_(aten, fft_hfft) \
+_(aten, fft_hfft2) \
+_(aten, fft_hfftn) \
+_(aten, fft_ifft) \
+_(aten, fft_ifft2) \
+_(aten, fft_ifftn) \
+_(aten, fft_ifftshift) \
+_(aten, fft_ihfft) \
+_(aten, fft_ihfft2) \
+_(aten, fft_ihfftn) \
+_(aten, fft_irfft) \
+_(aten, fft_irfft2) \
+_(aten, fft_irfftn) \
+_(aten, fft_rfft) \
+_(aten, fft_rfft2) \
+_(aten, fft_rfftfreq) \
+_(aten, fft_rfftn) \
+_(aten, fill) \
+_(aten, fill_) \
+_(aten, fill_diagonal) \
+_(aten, fill_diagonal_) \
+_(aten, fix) \
+_(aten, fix_) \
+_(aten, flatten) \
+_(aten, flatten_dense_tensors) \
+_(aten, flip) \
+_(aten, fliplr) \
+_(aten, flipud) \
+_(aten, float_power) \
+_(aten, float_power_) \
+_(aten, floor) \
+_(aten, floor_) \
+_(aten, floor_divide) \
+_(aten, floor_divide_) \
+_(aten, fmax) \
+_(aten, fmin) \
+_(aten, fmod) \
+_(aten, fmod_) \
+_(aten, frac) \
+_(aten, frac_) \
+_(aten, fractional_max_pool2d) \
+_(aten, fractional_max_pool2d_backward) \
+_(aten, fractional_max_pool3d) \
+_(aten, fractional_max_pool3d_backward) \
+_(aten, frexp) \
+_(aten, frobenius_norm) \
+_(aten, from_file) \
+_(aten, full) \
+_(aten, full_like) \
+_(aten, fused_moving_avg_obs_fake_quant) \
+_(aten, gather) \
+_(aten, gather_backward) \
+_(aten, gcd) \
+_(aten, gcd_) \
+_(aten, ge) \
+_(aten, ge_) \
+_(aten, gelu) \
+_(aten, gelu_) \
+_(aten, gelu_backward) \
+_(aten, geometric) \
+_(aten, geometric_) \
+_(aten, geqrf) \
+_(aten, ger) \
+_(aten, glu) \
+_(aten, glu_backward) \
+_(aten, glu_backward_jvp) \
+_(aten, glu_jvp) \
+_(aten, gradient) \
+_(aten, greater) \
+_(aten, greater_) \
+_(aten, greater_equal) \
+_(aten, greater_equal_) \
+_(aten, grid_sampler) \
+_(aten, grid_sampler_2d) \
+_(aten, grid_sampler_2d_backward) \
+_(aten, grid_sampler_3d) \
+_(aten, grid_sampler_3d_backward) \
+_(aten, group_norm) \
+_(aten, gru) \
+_(aten, gru_cell) \
+_(aten, gt) \
+_(aten, gt_) \
+_(aten, hamming_window) \
+_(aten, hann_window) \
+_(aten, hardshrink) \
+_(aten, hardshrink_backward) \
+_(aten, hardsigmoid) \
+_(aten, hardsigmoid_) \
+_(aten, hardsigmoid_backward) \
+_(aten, hardswish) \
+_(aten, hardswish_) \
+_(aten, hardswish_backward) \
+_(aten, hardtanh) \
+_(aten, hardtanh_) \
+_(aten, hardtanh_backward) \
+_(aten, heaviside) \
+_(aten, heaviside_) \
+_(aten, hinge_embedding_loss) \
+_(aten, histc) \
+_(aten, histogram) \
+_(aten, histogramdd) \
+_(aten, hsplit) \
+_(aten, hspmm) \
+_(aten, hstack) \
+_(aten, huber_loss) \
+_(aten, huber_loss_backward) \
+_(aten, hypot) \
+_(aten, hypot_) \
+_(aten, i0) \
+_(aten, i0_) \
+_(aten, igamma) \
+_(aten, igamma_) \
+_(aten, igammac) \
+_(aten, igammac_) \
+_(aten, im2col) \
+_(aten, imag) \
+_(aten, index) \
+_(aten, index_add) \
+_(aten, index_add_) \
+_(aten, index_copy) \
+_(aten, index_copy_) \
+_(aten, index_fill) \
+_(aten, index_fill_) \
+_(aten, index_put) \
+_(aten, index_put_) \
+_(aten, index_reduce) \
+_(aten, index_reduce_) \
+_(aten, index_select) \
+_(aten, index_select_backward) \
+_(aten, indices) \
+_(aten, indices_copy) \
+_(aten, infinitely_differentiable_gelu_backward) \
+_(aten, inner) \
+_(aten, instance_norm) \
+_(aten, int_repr) \
+_(aten, inverse) \
+_(aten, is_coalesced) \
+_(aten, is_complex) \
+_(aten, is_conj) \
+_(aten, is_distributed) \
+_(aten, is_floating_point) \
+_(aten, is_inference) \
+_(aten, is_leaf) \
+_(aten, is_neg) \
+_(aten, is_nonzero) \
+_(aten, is_pinned) \
+_(aten, is_same_size) \
+_(aten, is_set_to) \
+_(aten, is_signed) \
+_(aten, is_vulkan_available) \
+_(aten, isclose) \
+_(aten, isfinite) \
+_(aten, isin) \
+_(aten, isinf) \
+_(aten, isnan) \
+_(aten, isneginf) \
+_(aten, isposinf) \
+_(aten, isreal) \
+_(aten, istft) \
+_(aten, item) \
+_(aten, kaiser_window) \
+_(aten, kl_div) \
+_(aten, kron) \
+_(aten, kthvalue) \
+_(aten, l1_loss) \
+_(aten, layer_norm) \
+_(aten, lcm) \
+_(aten, lcm_) \
+_(aten, ldexp) \
+_(aten, ldexp_) \
+_(aten, le) \
+_(aten, le_) \
+_(aten, leaky_relu) \
+_(aten, leaky_relu_) \
+_(aten, leaky_relu_backward) \
+_(aten, lerp) \
+_(aten, lerp_) \
+_(aten, less) \
+_(aten, less_) \
+_(aten, less_equal) \
+_(aten, less_equal_) \
+_(aten, lgamma) \
+_(aten, lgamma_) \
+_(aten, lift) \
+_(aten, lift_fresh) \
+_(aten, lift_fresh_copy) \
+_(aten, linalg_cholesky) \
+_(aten, linalg_cholesky_ex) \
+_(aten, linalg_cond) \
+_(aten, linalg_cross) \
+_(aten, linalg_det) \
+_(aten, linalg_diagonal) \
+_(aten, linalg_eig) \
+_(aten, linalg_eigh) \
+_(aten, linalg_eigvals) \
+_(aten, linalg_eigvalsh) \
+_(aten, linalg_householder_product) \
+_(aten, linalg_inv) \
+_(aten, linalg_inv_ex) \
+_(aten, linalg_ldl_factor) \
+_(aten, linalg_ldl_factor_ex) \
+_(aten, linalg_ldl_solve) \
+_(aten, linalg_lstsq) \
+_(aten, linalg_lu) \
+_(aten, linalg_lu_factor) \
+_(aten, linalg_lu_factor_ex) \
+_(aten, linalg_lu_solve) \
+_(aten, linalg_matmul) \
+_(aten, linalg_matrix_exp) \
+_(aten, linalg_matrix_norm) \
+_(aten, linalg_matrix_power) \
+_(aten, linalg_matrix_rank) \
+_(aten, linalg_multi_dot) \
+_(aten, linalg_norm) \
+_(aten, linalg_pinv) \
+_(aten, linalg_qr) \
+_(aten, linalg_slogdet) \
+_(aten, linalg_solve) \
+_(aten, linalg_solve_ex) \
+_(aten, linalg_solve_triangular) \
+_(aten, linalg_svd) \
+_(aten, linalg_svdvals) \
+_(aten, linalg_tensorinv) \
+_(aten, linalg_tensorsolve) \
+_(aten, linalg_vander) \
+_(aten, linalg_vecdot) \
+_(aten, linalg_vector_norm) \
+_(aten, linear) \
+_(aten, linear_backward) \
+_(aten, linspace) \
+_(aten, log) \
+_(aten, log10) \
+_(aten, log10_) \
+_(aten, log1p) \
+_(aten, log1p_) \
+_(aten, log2) \
+_(aten, log2_) \
+_(aten, log_) \
+_(aten, log_normal) \
+_(aten, log_normal_) \
+_(aten, log_sigmoid) \
+_(aten, log_sigmoid_backward) \
+_(aten, log_sigmoid_forward) \
+_(aten, log_softmax) \
+_(aten, logaddexp) \
+_(aten, logaddexp2) \
+_(aten, logcumsumexp) \
+_(aten, logdet) \
+_(aten, logical_and) \
+_(aten, logical_and_) \
+_(aten, logical_not) \
+_(aten, logical_not_) \
+_(aten, logical_or) \
+_(aten, logical_or_) \
+_(aten, logical_xor) \
+_(aten, logical_xor_) \
+_(aten, logit) \
+_(aten, logit_) \
+_(aten, logit_backward) \
+_(aten, logspace) \
+_(aten, logsumexp) \
+_(aten, lshift) \
+_(aten, lstm) \
+_(aten, lstm_cell) \
+_(aten, lstm_mps_backward) \
+_(aten, lt) \
+_(aten, lt_) \
+_(aten, lu_solve) \
+_(aten, lu_unpack) \
+_(aten, mH) \
+_(aten, mT) \
+_(aten, margin_ranking_loss) \
+_(aten, masked_fill) \
+_(aten, masked_fill_) \
+_(aten, masked_scatter) \
+_(aten, masked_scatter_) \
+_(aten, masked_scatter_backward) \
+_(aten, masked_select) \
+_(aten, masked_select_backward) \
+_(aten, matmul) \
+_(aten, matmul_backward) \
+_(aten, matrix_H) \
+_(aten, matrix_exp) \
+_(aten, matrix_exp_backward) \
+_(aten, matrix_power) \
+_(aten, max) \
+_(aten, max_pool1d) \
+_(aten, max_pool1d_with_indices) \
+_(aten, max_pool2d) \
+_(aten, max_pool2d_backward) \
+_(aten, max_pool2d_with_indices) \
+_(aten, max_pool2d_with_indices_backward) \
+_(aten, max_pool3d) \
+_(aten, max_pool3d_with_indices) \
+_(aten, max_pool3d_with_indices_backward) \
+_(aten, max_unpool2d) \
+_(aten, max_unpool3d) \
+_(aten, maximum) \
+_(aten, mean) \
+_(aten, median) \
+_(aten, meshgrid) \
+_(aten, min) \
+_(aten, minimum) \
+_(aten, miopen_batch_norm) \
+_(aten, miopen_batch_norm_backward) \
+_(aten, miopen_convolution) \
+_(aten, miopen_convolution_add_relu) \
+_(aten, miopen_convolution_relu) \
+_(aten, miopen_convolution_transpose) \
+_(aten, miopen_depthwise_convolution) \
+_(aten, miopen_rnn) \
+_(aten, miopen_rnn_backward) \
+_(aten, mish) \
+_(aten, mish_) \
+_(aten, mish_backward) \
+_(aten, mkldnn_adaptive_avg_pool2d) \
+_(aten, mkldnn_adaptive_avg_pool2d_backward) \
+_(aten, mkldnn_convolution) \
+_(aten, mkldnn_linear) \
+_(aten, mkldnn_linear_backward) \
+_(aten, mkldnn_linear_backward_input) \
+_(aten, mkldnn_linear_backward_weights) \
+_(aten, mkldnn_max_pool2d) \
+_(aten, mkldnn_max_pool2d_backward) \
+_(aten, mkldnn_max_pool3d) \
+_(aten, mkldnn_max_pool3d_backward) \
+_(aten, mkldnn_reorder_conv2d_weight) \
+_(aten, mkldnn_reorder_conv3d_weight) \
+_(aten, mkldnn_rnn_layer) \
+_(aten, mkldnn_rnn_layer_backward) \
+_(aten, mm) \
+_(aten, mode) \
+_(aten, moveaxis) \
+_(aten, movedim) \
+_(aten, mps_convolution_backward) \
+_(aten, mps_convolution_transpose_backward) \
+_(aten, mse_loss) \
+_(aten, mse_loss_backward) \
+_(aten, msort) \
+_(aten, mul) \
+_(aten, mul_) \
+_(aten, multi_margin_loss) \
+_(aten, multi_margin_loss_backward) \
+_(aten, multilabel_margin_loss) \
+_(aten, multilabel_margin_loss_backward) \
+_(aten, multilabel_margin_loss_forward) \
+_(aten, multinomial) \
+_(aten, multiply) \
+_(aten, multiply_) \
+_(aten, mv) \
+_(aten, mvlgamma) \
+_(aten, mvlgamma_) \
+_(aten, nan_to_num) \
+_(aten, nan_to_num_) \
+_(aten, nanmean) \
+_(aten, nanmedian) \
+_(aten, nanquantile) \
+_(aten, nansum) \
+_(aten, narrow) \
+_(aten, narrow_copy) \
+_(aten, native_batch_norm) \
+_(aten, native_batch_norm_backward) \
+_(aten, native_channel_shuffle) \
+_(aten, native_dropout) \
+_(aten, native_dropout_backward) \
+_(aten, native_group_norm) \
+_(aten, native_group_norm_backward) \
+_(aten, native_layer_norm) \
+_(aten, native_layer_norm_backward) \
+_(aten, native_norm) \
+_(aten, ne) \
+_(aten, ne_) \
+_(aten, neg) \
+_(aten, neg_) \
+_(aten, negative) \
+_(aten, negative_) \
+_(aten, nested_to_padded_tensor) \
+_(aten, new_empty) \
+_(aten, new_empty_strided) \
+_(aten, new_full) \
+_(aten, new_ones) \
+_(aten, new_zeros) \
+_(aten, nextafter) \
+_(aten, nextafter_) \
+_(aten, nll_loss) \
+_(aten, nll_loss2d) \
+_(aten, nll_loss2d_backward) \
+_(aten, nll_loss2d_forward) \
+_(aten, nll_loss_backward) \
+_(aten, nll_loss_forward) \
+_(aten, nll_loss_nd) \
+_(aten, nonzero) \
+_(aten, nonzero_numpy) \
+_(aten, nonzero_static) \
+_(aten, norm) \
+_(aten, norm_except_dim) \
+_(aten, normal) \
+_(aten, normal_) \
+_(aten, normal_functional) \
+_(aten, not_equal) \
+_(aten, not_equal_) \
+_(aten, nuclear_norm) \
+_(aten, numpy_T) \
+_(aten, one_hot) \
+_(aten, ones) \
+_(aten, ones_like) \
+_(aten, orgqr) \
+_(aten, ormqr) \
+_(aten, outer) \
+_(aten, output_nr) \
+_(aten, pad) \
+_(aten, pad_sequence) \
+_(aten, pairwise_distance) \
+_(aten, pdist) \
+_(aten, permute) \
+_(aten, permute_copy) \
+_(aten, pin_memory) \
+_(aten, pinverse) \
+_(aten, pixel_shuffle) \
+_(aten, pixel_unshuffle) \
+_(aten, poisson) \
+_(aten, poisson_nll_loss) \
+_(aten, polar) \
+_(aten, polygamma) \
+_(aten, polygamma_) \
+_(aten, positive) \
+_(aten, pow) \
+_(aten, pow_) \
+_(aten, prelu) \
+_(aten, prod) \
+_(aten, promote_types) \
+_(aten, put) \
+_(aten, put_) \
+_(aten, q_per_channel_axis) \
+_(aten, q_per_channel_scales) \
+_(aten, q_per_channel_zero_points) \
+_(aten, q_scale) \
+_(aten, q_zero_point) \
+_(aten, qr) \
+_(aten, qscheme) \
+_(aten, quantile) \
+_(aten, quantize_per_channel) \
+_(aten, quantize_per_tensor) \
+_(aten, quantize_per_tensor_dynamic) \
+_(aten, quantized_batch_norm) \
+_(aten, quantized_gru_cell) \
+_(aten, quantized_lstm_cell) \
+_(aten, quantized_max_pool1d) \
+_(aten, quantized_max_pool2d) \
+_(aten, quantized_max_pool3d) \
+_(aten, quantized_rnn_relu_cell) \
+_(aten, quantized_rnn_tanh_cell) \
+_(aten, rad2deg) \
+_(aten, rad2deg_) \
+_(aten, rand) \
+_(aten, rand_like) \
+_(aten, randint) \
+_(aten, randint_like) \
+_(aten, randn) \
+_(aten, randn_like) \
+_(aten, random) \
+_(aten, random_) \
+_(aten, randperm) \
+_(aten, range) \
+_(aten, ravel) \
+_(aten, real) \
+_(aten, reciprocal) \
+_(aten, reciprocal_) \
+_(aten, record_stream) \
+_(aten, refine_names) \
+_(aten, reflection_pad1d) \
+_(aten, reflection_pad1d_backward) \
+_(aten, reflection_pad2d) \
+_(aten, reflection_pad2d_backward) \
+_(aten, reflection_pad3d) \
+_(aten, reflection_pad3d_backward) \
+_(aten, relu) \
+_(aten, relu6) \
+_(aten, relu6_) \
+_(aten, relu_) \
+_(aten, remainder) \
+_(aten, remainder_) \
+_(aten, rename) \
+_(aten, rename_) \
+_(aten, renorm) \
+_(aten, renorm_) \
+_(aten, repeat) \
+_(aten, repeat_interleave) \
+_(aten, replication_pad1d) \
+_(aten, replication_pad1d_backward) \
+_(aten, replication_pad2d) \
+_(aten, replication_pad2d_backward) \
+_(aten, replication_pad3d) \
+_(aten, replication_pad3d_backward) \
+_(aten, requires_grad) \
+_(aten, requires_grad_) \
+_(aten, reshape) \
+_(aten, reshape_as) \
+_(aten, resize) \
+_(aten, resize_) \
+_(aten, resize_as) \
+_(aten, resize_as_) \
+_(aten, resize_as_sparse) \
+_(aten, resize_as_sparse_) \
+_(aten, resolve_conj) \
+_(aten, resolve_neg) \
+_(aten, result_type) \
+_(aten, retain_grad) \
+_(aten, retains_grad) \
+_(aten, rnn_relu) \
+_(aten, rnn_relu_cell) \
+_(aten, rnn_tanh) \
+_(aten, rnn_tanh_cell) \
+_(aten, roll) \
+_(aten, rot90) \
+_(aten, round) \
+_(aten, round_) \
+_(aten, row_indices) \
+_(aten, row_indices_copy) \
+_(aten, row_stack) \
+_(aten, rrelu) \
+_(aten, rrelu_) \
+_(aten, rrelu_with_noise) \
+_(aten, rrelu_with_noise_) \
+_(aten, rrelu_with_noise_backward) \
+_(aten, rshift) \
+_(aten, rsqrt) \
+_(aten, rsqrt_) \
+_(aten, rsub) \
+_(aten, scalar_tensor) \
+_(aten, scaled_dot_product_attention) \
+_(aten, scatter) \
+_(aten, scatter_) \
+_(aten, scatter_add) \
+_(aten, scatter_add_) \
+_(aten, scatter_reduce) \
+_(aten, scatter_reduce_) \
+_(aten, searchsorted) \
+_(aten, segment_reduce) \
+_(aten, select) \
+_(aten, select_backward) \
+_(aten, select_copy) \
+_(aten, select_scatter) \
+_(aten, selu) \
+_(aten, selu_) \
+_(aten, set) \
+_(aten, set_) \
+_(aten, set_data) \
+_(aten, sgn) \
+_(aten, sgn_) \
+_(aten, sigmoid) \
+_(aten, sigmoid_) \
+_(aten, sigmoid_backward) \
+_(aten, sign) \
+_(aten, sign_) \
+_(aten, signbit) \
+_(aten, silu) \
+_(aten, silu_) \
+_(aten, silu_backward) \
+_(aten, sin) \
+_(aten, sin_) \
+_(aten, sinc) \
+_(aten, sinc_) \
+_(aten, sinh) \
+_(aten, sinh_) \
+_(aten, size) \
+_(aten, slice) \
+_(aten, slice_backward) \
+_(aten, slice_copy) \
+_(aten, slice_inverse) \
+_(aten, slice_scatter) \
+_(aten, slogdet) \
+_(aten, slow_conv3d) \
+_(aten, slow_conv3d_forward) \
+_(aten, slow_conv_dilated2d) \
+_(aten, slow_conv_dilated3d) \
+_(aten, slow_conv_transpose2d) \
+_(aten, slow_conv_transpose3d) \
+_(aten, smm) \
+_(aten, smooth_l1_loss) \
+_(aten, smooth_l1_loss_backward) \
+_(aten, soft_margin_loss) \
+_(aten, soft_margin_loss_backward) \
+_(aten, softmax) \
+_(aten, softplus) \
+_(aten, softplus_backward) \
+_(aten, softshrink) \
+_(aten, softshrink_backward) \
+_(aten, sort) \
+_(aten, sparse_bsc_tensor) \
+_(aten, sparse_bsr_tensor) \
+_(aten, sparse_compressed_tensor) \
+_(aten, sparse_coo_tensor) \
+_(aten, sparse_csc_tensor) \
+_(aten, sparse_csr_tensor) \
+_(aten, sparse_dim) \
+_(aten, sparse_mask) \
+_(aten, sparse_resize) \
+_(aten, sparse_resize_) \
+_(aten, sparse_resize_and_clear) \
+_(aten, sparse_resize_and_clear_) \
+_(aten, sparse_sampled_addmm) \
+_(aten, special_airy_ai) \
+_(aten, special_bessel_j0) \
+_(aten, special_bessel_j1) \
+_(aten, special_bessel_y0) \
+_(aten, special_bessel_y1) \
+_(aten, special_chebyshev_polynomial_t) \
+_(aten, special_chebyshev_polynomial_u) \
+_(aten, special_chebyshev_polynomial_v) \
+_(aten, special_chebyshev_polynomial_w) \
+_(aten, special_digamma) \
+_(aten, special_entr) \
+_(aten, special_erf) \
+_(aten, special_erfc) \
+_(aten, special_erfcx) \
+_(aten, special_erfinv) \
+_(aten, special_exp2) \
+_(aten, special_expit) \
+_(aten, special_expm1) \
+_(aten, special_gammainc) \
+_(aten, special_gammaincc) \
+_(aten, special_gammaln) \
+_(aten, special_hermite_polynomial_h) \
+_(aten, special_hermite_polynomial_he) \
+_(aten, special_i0) \
+_(aten, special_i0e) \
+_(aten, special_i1) \
+_(aten, special_i1e) \
+_(aten, special_laguerre_polynomial_l) \
+_(aten, special_legendre_polynomial_p) \
+_(aten, special_log1p) \
+_(aten, special_log_ndtr) \
+_(aten, special_log_softmax) \
+_(aten, special_logit) \
+_(aten, special_logsumexp) \
+_(aten, special_modified_bessel_i0) \
+_(aten, special_modified_bessel_i1) \
+_(aten, special_modified_bessel_k0) \
+_(aten, special_modified_bessel_k1) \
+_(aten, special_multigammaln) \
+_(aten, special_ndtr) \
+_(aten, special_ndtri) \
+_(aten, special_polygamma) \
+_(aten, special_psi) \
+_(aten, special_round) \
+_(aten, special_scaled_modified_bessel_k0) \
+_(aten, special_scaled_modified_bessel_k1) \
+_(aten, special_shifted_chebyshev_polynomial_t) \
+_(aten, special_shifted_chebyshev_polynomial_u) \
+_(aten, special_shifted_chebyshev_polynomial_v) \
+_(aten, special_shifted_chebyshev_polynomial_w) \
+_(aten, special_sinc) \
+_(aten, special_softmax) \
+_(aten, special_spherical_bessel_j0) \
+_(aten, special_xlog1py) \
+_(aten, special_xlogy) \
+_(aten, special_zeta) \
+_(aten, split) \
+_(aten, split_copy) \
+_(aten, split_with_sizes) \
+_(aten, split_with_sizes_copy) \
+_(aten, sqrt) \
+_(aten, sqrt_) \
+_(aten, square) \
+_(aten, square_) \
+_(aten, squeeze) \
+_(aten, squeeze_) \
+_(aten, squeeze_copy) \
+_(aten, sspaddmm) \
+_(aten, stack) \
+_(aten, std) \
+_(aten, std_mean) \
+_(aten, stft) \
+_(aten, stride) \
+_(aten, sub) \
+_(aten, sub_) \
+_(aten, subtract) \
+_(aten, subtract_) \
+_(aten, sum) \
+_(aten, sum_to_size) \
+_(aten, svd) \
+_(aten, swapaxes) \
+_(aten, swapaxes_) \
+_(aten, swapdims) \
+_(aten, swapdims_) \
+_(aten, sym_constrain_range) \
+_(aten, sym_constrain_range_for_size) \
+_(aten, sym_numel) \
+_(aten, sym_size) \
+_(aten, sym_storage_offset) \
+_(aten, sym_stride) \
+_(aten, t) \
+_(aten, t_) \
+_(aten, t_copy) \
+_(aten, take) \
+_(aten, take_along_dim) \
+_(aten, tan) \
+_(aten, tan_) \
+_(aten, tanh) \
+_(aten, tanh_) \
+_(aten, tanh_backward) \
+_(aten, tensor_split) \
+_(aten, tensordot) \
+_(aten, thnn_conv2d) \
+_(aten, threshold) \
+_(aten, threshold_) \
+_(aten, threshold_backward) \
+_(aten, tile) \
+_(aten, to) \
+_(aten, to_dense) \
+_(aten, to_dense_backward) \
+_(aten, to_mkldnn) \
+_(aten, to_mkldnn_backward) \
+_(aten, to_padded_tensor) \
+_(aten, to_sparse) \
+_(aten, to_sparse_bsc) \
+_(aten, to_sparse_bsr) \
+_(aten, to_sparse_csc) \
+_(aten, to_sparse_csr) \
+_(aten, topk) \
+_(aten, trace) \
+_(aten, trace_backward) \
+_(aten, transpose) \
+_(aten, transpose_) \
+_(aten, transpose_copy) \
+_(aten, trapezoid) \
+_(aten, trapz) \
+_(aten, triangular_solve) \
+_(aten, tril) \
+_(aten, tril_) \
+_(aten, tril_indices) \
+_(aten, triplet_margin_loss) \
+_(aten, triu) \
+_(aten, triu_) \
+_(aten, triu_indices) \
+_(aten, true_divide) \
+_(aten, true_divide_) \
+_(aten, trunc) \
+_(aten, trunc_) \
+_(aten, type_as) \
+_(aten, unbind) \
+_(aten, unbind_copy) \
+_(aten, unflatten) \
+_(aten, unflatten_dense_tensors) \
+_(aten, unfold) \
+_(aten, unfold_backward) \
+_(aten, unfold_copy) \
+_(aten, uniform) \
+_(aten, uniform_) \
+_(aten, unique_consecutive) \
+_(aten, unique_dim) \
+_(aten, unique_dim_consecutive) \
+_(aten, unsafe_chunk) \
+_(aten, unsafe_split) \
+_(aten, unsafe_split_with_sizes) \
+_(aten, unsqueeze) \
+_(aten, unsqueeze_) \
+_(aten, unsqueeze_copy) \
+_(aten, upsample_bicubic2d) \
+_(aten, upsample_bicubic2d_backward) \
+_(aten, upsample_bilinear2d) \
+_(aten, upsample_bilinear2d_backward) \
+_(aten, upsample_linear1d) \
+_(aten, upsample_linear1d_backward) \
+_(aten, upsample_nearest1d) \
+_(aten, upsample_nearest1d_backward) \
+_(aten, upsample_nearest2d) \
+_(aten, upsample_nearest2d_backward) \
+_(aten, upsample_nearest3d) \
+_(aten, upsample_nearest3d_backward) \
+_(aten, upsample_trilinear3d) \
+_(aten, upsample_trilinear3d_backward) \
+_(aten, value_selecting_reduction_backward) \
+_(aten, values) \
+_(aten, values_copy) \
+_(aten, vander) \
+_(aten, var) \
+_(aten, var_mean) \
+_(aten, vdot) \
+_(aten, view) \
+_(aten, view_as) \
+_(aten, view_as_complex) \
+_(aten, view_as_complex_copy) \
+_(aten, view_as_real) \
+_(aten, view_as_real_copy) \
+_(aten, view_copy) \
+_(aten, vsplit) \
+_(aten, vstack) \
+_(aten, where) \
+_(aten, xlogy) \
+_(aten, xlogy_) \
+_(aten, zero) \
+_(aten, zero_) \
+_(aten, zeros) \
+_(aten, zeros_like)
+
+#define FORALL_ATTR_BASE_SYMBOLS(_) \
+_(attr, A) \
+_(attr, B) \
+_(attr, C) \
+_(attr, H) \
+_(attr, HxW) \
+_(attr, K) \
+_(attr, L) \
+_(attr, LD) \
+_(attr, LU) \
+_(attr, LU_data) \
+_(attr, LU_pivots) \
+_(attr, M) \
+_(attr, N) \
+_(attr, P) \
+_(attr, Q) \
+_(attr, R) \
+_(attr, S) \
+_(attr, U) \
+_(attr, UPLO) \
+_(attr, V) \
+_(attr, Vh) \
+_(attr, W) \
+_(attr, X) \
+_(attr, a) \
+_(attr, abs) \
+_(attr, accumulate) \
+_(attr, accumulate_matches) \
+_(attr, activation) \
+_(attr, addends) \
+_(attr, adjoint) \
+_(attr, alg_id) \
+_(attr, align_corners) \
+_(attr, allow_tf32) \
+_(attr, alpha) \
+_(attr, amsgrad) \
+_(attr, anchor) \
+_(attr, angle) \
+_(attr, any) \
+_(attr, api_name) \
+_(attr, append) \
+_(attr, approximate) \
+_(attr, arg1) \
+_(attr, arg2) \
+_(attr, arg3) \
+_(attr, arg_out) \
+_(attr, assert_msg) \
+_(attr, assume_unique) \
+_(attr, atol) \
+_(attr, attn_bias) \
+_(attr, attn_mask) \
+_(attr, average_attn_weights) \
+_(attr, averaging_const) \
+_(attr, aweights) \
+_(attr, axis) \
+_(attr, axis0) \
+_(attr, axis1) \
+_(attr, b) \
+_(attr, b_hh) \
+_(attr, b_ih) \
+_(attr, bag_size) \
+_(attr, base) \
+_(attr, batch1) \
+_(attr, batch2) \
+_(attr, batch_dim) \
+_(attr, batch_first) \
+_(attr, batch_size) \
+_(attr, batch_sizes) \
+_(attr, benchmark) \
+_(attr, beta) \
+_(attr, beta1) \
+_(attr, beta2) \
+_(attr, bias) \
+_(attr, bias_defined) \
+_(attr, bias_g) \
+_(attr, bias_requires_grad) \
+_(attr, bias_sizes) \
+_(attr, bidirectional) \
+_(attr, bin_edges) \
+_(attr, bins) \
+_(attr, bit_width) \
+_(attr, blank) \
+_(attr, blocksize) \
+_(attr, boundaries) \
+_(attr, buffer) \
+_(attr, causal_diagonal) \
+_(attr, ccol_indices) \
+_(attr, cdim) \
+_(attr, cdist) \
+_(attr, ceil_mode) \
+_(attr, cell_state_fwd) \
+_(attr, center) \
+_(attr, ch_axis) \
+_(attr, check_errors) \
+_(attr, chunks) \
+_(attr, coalesced) \
+_(attr, coefficients) \
+_(attr, col) \
+_(attr, col_indices) \
+_(attr, col_offsets) \
+_(attr, col_offsets_hh) \
+_(attr, col_offsets_ih) \
+_(attr, compressed_A) \
+_(attr, compressed_idx) \
+_(attr, compressed_indices) \
+_(attr, compressed_indices_dtype) \
+_(attr, compute_log_sumexp) \
+_(attr, compute_mode) \
+_(attr, compute_uv) \
+_(attr, compute_v) \
+_(attr, condition) \
+_(attr, copy) \
+_(attr, correction) \
+_(attr, count) \
+_(attr, count_include_pad) \
+_(attr, counts) \
+_(attr, cpu_dtype) \
+_(attr, cpu_enabled) \
+_(attr, cpu_nested_shape_example) \
+_(attr, create_graph) \
+_(attr, crow_indices) \
+_(attr, cu_seqlens_k) \
+_(attr, cu_seqlens_q) \
+_(attr, cuda_dtype) \
+_(attr, cuda_enabled) \
+_(attr, cudnn_enable) \
+_(attr, cudnn_enabled) \
+_(attr, cum_seq_k) \
+_(attr, cum_seq_q) \
+_(attr, custom_mask_type) \
+_(attr, cx) \
+_(attr, cx_) \
+_(attr, cx_tmp) \
+_(attr, cy) \
+_(attr, cy_) \
+_(attr, d) \
+_(attr, dampening) \
+_(attr, data) \
+_(attr, decimals) \
+_(attr, delta) \
+_(attr, dense) \
+_(attr, dense_B) \
+_(attr, dense_dim) \
+_(attr, density) \
+_(attr, dep_token) \
+_(attr, descending) \
+_(attr, destination) \
+_(attr, deterministic) \
+_(attr, device) \
+_(attr, device_index) \
+_(attr, dgrad_glu) \
+_(attr, diagonal) \
+_(attr, diagonals) \
+_(attr, dilation) \
+_(attr, dim) \
+_(attr, dim0) \
+_(attr, dim1) \
+_(attr, dim2) \
+_(attr, dimension) \
+_(attr, dims) \
+_(attr, dims_other) \
+_(attr, dims_self) \
+_(attr, divisor_override) \
+_(attr, downscale_factor) \
+_(attr, driver) \
+_(attr, dropout) \
+_(attr, dropout_mask) \
+_(attr, dropout_p) \
+_(attr, dropout_seed) \
+_(attr, dropout_state) \
+_(attr, dst) \
+_(attr, dtype) \
+_(attr, dual) \
+_(attr, dummy) \
+_(attr, dx) \
+_(attr, edge_order) \
+_(attr, eigenvalues) \
+_(attr, eigenvectors) \
+_(attr, eigvals) \
+_(attr, eigvecs) \
+_(attr, element) \
+_(attr, elements) \
+_(attr, ellipsis_idx) \
+_(attr, embed_dim) \
+_(attr, end) \
+_(attr, end_dim) \
+_(attr, eps) \
+_(attr, epsilon) \
+_(attr, equal_nan) \
+_(attr, equation) \
+_(attr, exp_avg_sqs) \
+_(attr, exp_avgs) \
+_(attr, expand1) \
+_(attr, expand2) \
+_(attr, expand3) \
+_(attr, exponent) \
+_(attr, exponential_average_factor) \
+_(attr, fake_quant_enabled) \
+_(attr, fake_quant_on) \
+_(attr, ffn_bias_1) \
+_(attr, ffn_bias_2) \
+_(attr, ffn_weight_1) \
+_(attr, ffn_weight_2) \
+_(attr, filename) \
+_(attr, fill_value) \
+_(attr, flat) \
+_(attr, forward) \
+_(attr, found_inf) \
+_(attr, from) \
+_(attr, full) \
+_(attr, full_matrices) \
+_(attr, fuse_transform_0213) \
+_(attr, fweights) \
+_(attr, g) \
+_(attr, gO) \
+_(attr, generator) \
+_(attr, ggI) \
+_(attr, ggW) \
+_(attr, ggb) \
+_(attr, glu) \
+_(attr, grad) \
+_(attr, grad_bias) \
+_(attr, grad_cy) \
+_(attr, grad_factor) \
+_(attr, grad_glu) \
+_(attr, grad_hy) \
+_(attr, grad_in) \
+_(attr, grad_input) \
+_(attr, grad_input_mask) \
+_(attr, grad_out) \
+_(attr, grad_out_) \
+_(attr, grad_output) \
+_(attr, grad_scale) \
+_(attr, grad_w) \
+_(attr, grad_weight) \
+_(attr, grad_x) \
+_(attr, grad_y) \
+_(attr, gradient) \
+_(attr, grads) \
+_(attr, grid) \
+_(attr, group) \
+_(attr, groups) \
+_(attr, growth_interval) \
+_(attr, growth_tracker) \
+_(attr, half_to_float) \
+_(attr, has_bias) \
+_(attr, has_biases) \
+_(attr, hermitian) \
+_(attr, hidden_bias) \
+_(attr, hidden_gates) \
+_(attr, hidden_size) \
+_(attr, high) \
+_(attr, hist) \
+_(attr, hop_length) \
+_(attr, hx) \
+_(attr, hx_) \
+_(attr, hy_) \
+_(attr, i1) \
+_(attr, i2) \
+_(attr, i3) \
+_(attr, ignore_index) \
+_(attr, imag) \
+_(attr, impl_index) \
+_(attr, implicit) \
+_(attr, include_last_offset) \
+_(attr, include_self) \
+_(attr, increasing) \
+_(attr, ind) \
+_(attr, index) \
+_(attr, indexing) \
+_(attr, indices) \
+_(attr, info) \
+_(attr, initial) \
+_(attr, innerKTiles) \
+_(attr, input) \
+_(attr, input1) \
+_(attr, input2) \
+_(attr, input3) \
+_(attr, input_bias) \
+_(attr, input_dtype) \
+_(attr, input_g) \
+_(attr, input_gates) \
+_(attr, input_lengths) \
+_(attr, input_scale) \
+_(attr, input_size) \
+_(attr, input_sizes) \
+_(attr, inputs) \
+_(attr, interpolation) \
+_(attr, interpolation_mode) \
+_(attr, inv_scale) \
+_(attr, inverse) \
+_(attr, invert) \
+_(attr, invstd) \
+_(attr, is_causal) \
+_(attr, is_coalesced) \
+_(attr, is_crow) \
+_(attr, is_first_step) \
+_(attr, is_matrix) \
+_(attr, is_result) \
+_(attr, is_target) \
+_(attr, k) \
+_(attr, keepdim) \
+_(attr, kernel_size) \
+_(attr, key) \
+_(attr, label_smoothing) \
+_(attr, lambd) \
+_(attr, largest) \
+_(attr, last_dim_size) \
+_(attr, layersOutputs) \
+_(attr, layout) \
+_(attr, left) \
+_(attr, length) \
+_(attr, lengths) \
+_(attr, level) \
+_(attr, like) \
+_(attr, list) \
+_(attr, log_alpha) \
+_(attr, log_input) \
+_(attr, log_probs) \
+_(attr, log_target) \
+_(attr, logabsdet) \
+_(attr, logsumexp) \
+_(attr, low) \
+_(attr, lower) \
+_(attr, lr) \
+_(attr, ltm) \
+_(attr, m) \
+_(attr, mantissa) \
+_(attr, margin) \
+_(attr, mask) \
+_(attr, mask_check) \
+_(attr, mask_type) \
+_(attr, masked_grad) \
+_(attr, mat) \
+_(attr, mat1) \
+_(attr, mat2) \
+_(attr, matrices) \
+_(attr, max) \
+_(attr, max_exp_avg_sqs) \
+_(attr, max_k) \
+_(attr, max_norm) \
+_(attr, max_q) \
+_(attr, max_seqlen_k) \
+_(attr, max_seqlen_q) \
+_(attr, max_size) \
+_(attr, max_val) \
+_(attr, max_values) \
+_(attr, maximize) \
+_(attr, maximum_indices) \
+_(attr, maxnorm) \
+_(attr, mean) \
+_(attr, median) \
+_(attr, memory_format) \
+_(attr, meta) \
+_(attr, min) \
+_(attr, min_indices) \
+_(attr, min_val) \
+_(attr, minlength) \
+_(attr, mode) \
+_(attr, momentum) \
+_(attr, momentum_buffer_list) \
+_(attr, n) \
+_(attr, n_bins) \
+_(attr, n_fft) \
+_(attr, names) \
+_(attr, nan) \
+_(attr, need_weights) \
+_(attr, neg_log_likelihood) \
+_(attr, negative) \
+_(attr, negative_slope) \
+_(attr, neginf) \
+_(attr, nested_size) \
+_(attr, nested_strides) \
+_(attr, nesterov) \
+_(attr, new_data) \
+_(attr, nnz) \
+_(attr, noise) \
+_(attr, non_blocking) \
+_(attr, norm) \
+_(attr, norm_bias_1) \
+_(attr, norm_bias_2) \
+_(attr, norm_first) \
+_(attr, norm_type) \
+_(attr, norm_weight_1) \
+_(attr, norm_weight_2) \
+_(attr, normalization) \
+_(attr, normalized) \
+_(attr, normalized_shape) \
+_(attr, nt_example) \
+_(attr, num_chunks) \
+_(attr, num_classes) \
+_(attr, num_generated) \
+_(attr, num_groups) \
+_(attr, num_head) \
+_(attr, num_heads) \
+_(attr, num_layers) \
+_(attr, num_parallel) \
+_(attr, num_samples) \
+_(attr, num_splits_key) \
+_(attr, num_weights) \
+_(attr, numel) \
+_(attr, observer_on) \
+_(attr, offset) \
+_(attr, offset2bag) \
+_(attr, offsets) \
+_(attr, onesided) \
+_(attr, ord) \
+_(attr, order) \
+_(attr, other) \
+_(attr, out) \
+_(attr, out0) \
+_(attr, out1) \
+_(attr, out2) \
+_(attr, out3) \
+_(attr, out4) \
+_(attr, out5) \
+_(attr, out6) \
+_(attr, out_amax) \
+_(attr, out_dim) \
+_(attr, out_dtype) \
+_(attr, out_int32) \
+_(attr, outdim) \
+_(attr, output) \
+_(attr, output_mask) \
+_(attr, output_padding) \
+_(attr, output_scale) \
+_(attr, output_size) \
+_(attr, output_zero_point) \
+_(attr, p) \
+_(attr, packed) \
+_(attr, packed_hh) \
+_(attr, packed_ih) \
+_(attr, packed_weight) \
+_(attr, pad) \
+_(attr, pad_mode) \
+_(attr, padded) \
+_(attr, padding) \
+_(attr, padding_idx) \
+_(attr, padding_mode) \
+_(attr, padding_value) \
+_(attr, params) \
+_(attr, path) \
+_(attr, pdist) \
+_(attr, per_row_fake_quant) \
+_(attr, per_sample_weights) \
+_(attr, periodic) \
+_(attr, philox_offset) \
+_(attr, philox_seed) \
+_(attr, physical_layout) \
+_(attr, pin_memory) \
+_(attr, pivot) \
+_(attr, pivots) \
+_(attr, plain_idx) \
+_(attr, plain_indices) \
+_(attr, pos_weight) \
+_(attr, posinf) \
+_(attr, positive) \
+_(attr, pow) \
+_(attr, prepend) \
+_(attr, primal) \
+_(attr, prob) \
+_(attr, proj_bias) \
+_(attr, proj_size) \
+_(attr, proj_weight) \
+_(attr, q) \
+_(attr, qGroupSize) \
+_(attr, qScaleAndZeros) \
+_(attr, qkv) \
+_(attr, qkv_bias) \
+_(attr, qkv_weight) \
+_(attr, qtensor) \
+_(attr, quant_max) \
+_(attr, quant_min) \
+_(attr, quasi) \
+_(attr, query) \
+_(attr, r) \
+_(attr, ragged_idx) \
+_(attr, random_samples) \
+_(attr, range) \
+_(attr, rank) \
+_(attr, ratio) \
+_(attr, rcond) \
+_(attr, real) \
+_(attr, reduce) \
+_(attr, reduce_range) \
+_(attr, reduction) \
+_(attr, repeats) \
+_(attr, replacement) \
+_(attr, requires_grad) \
+_(attr, reserve) \
+_(attr, reserveSpace) \
+_(attr, reservedSpace) \
+_(attr, residuals) \
+_(attr, result) \
+_(attr, retain_graph) \
+_(attr, return_complex) \
+_(attr, return_counts) \
+_(attr, return_debug_mask) \
+_(attr, return_inverse) \
+_(attr, reverse) \
+_(attr, right) \
+_(attr, rounding_mode) \
+_(attr, row) \
+_(attr, row_indices) \
+_(attr, rstd) \
+_(attr, rtol) \
+_(attr, running_max) \
+_(attr, running_mean) \
+_(attr, running_min) \
+_(attr, running_var) \
+_(attr, s) \
+_(attr, save_invstd) \
+_(attr, save_mean) \
+_(attr, save_var) \
+_(attr, save_var_transform) \
+_(attr, saved_g) \
+_(attr, saved_norms) \
+_(attr, saved_v) \
+_(attr, scalar) \
+_(attr, scalar1) \
+_(attr, scalar2) \
+_(attr, scalars) \
+_(attr, scale) \
+_(attr, scale_a) \
+_(attr, scale_b) \
+_(attr, scale_backoff_factor) \
+_(attr, scale_factors) \
+_(attr, scale_grad_by_freq) \
+_(attr, scale_growth_factor) \
+_(attr, scale_hh) \
+_(attr, scale_ih) \
+_(attr, scale_result) \
+_(attr, scales) \
+_(attr, scales_d) \
+_(attr, scales_h) \
+_(attr, scales_w) \
+_(attr, sections) \
+_(attr, seed) \
+_(attr, self) \
+_(attr, self_is_result) \
+_(attr, self_num_batch_dims) \
+_(attr, self_or_result) \
+_(attr, self_sizes) \
+_(attr, seqlen_k) \
+_(attr, sequences) \
+_(attr, shape) \
+_(attr, shared) \
+_(attr, shifts) \
+_(attr, side) \
+_(attr, sigma) \
+_(attr, sign) \
+_(attr, singular_values) \
+_(attr, size) \
+_(attr, sizes) \
+_(attr, skip_first) \
+_(attr, sobolstate) \
+_(attr, solution) \
+_(attr, some) \
+_(attr, sorted) \
+_(attr, sorted_sequence) \
+_(attr, sorter) \
+_(attr, source) \
+_(attr, spacing) \
+_(attr, sparse) \
+_(attr, sparse_dim) \
+_(attr, sparse_grad) \
+_(attr, split_size) \
+_(attr, split_sizes) \
+_(attr, src) \
+_(attr, stable) \
+_(attr, start) \
+_(attr, start_dim) \
+_(attr, state_steps) \
+_(attr, std) \
+_(attr, step) \
+_(attr, steps) \
+_(attr, storage_offset) \
+_(attr, stride) \
+_(attr, sum_dy) \
+_(attr, sum_dy_xmu) \
+_(attr, sumdim) \
+_(attr, swap) \
+_(attr, symmetric_quant) \
+_(attr, t) \
+_(attr, tangent) \
+_(attr, target) \
+_(attr, target_lengths) \
+_(attr, targets) \
+_(attr, tau) \
+_(attr, tensor) \
+_(attr, tensor1) \
+_(attr, tensor2) \
+_(attr, tensor_indices_or_sections) \
+_(attr, tensors) \
+_(attr, tensors1) \
+_(attr, test_element) \
+_(attr, test_elements) \
+_(attr, the_template) \
+_(attr, theta) \
+_(attr, threshold) \
+_(attr, to) \
+_(attr, tol) \
+_(attr, total) \
+_(attr, total_length) \
+_(attr, total_weight) \
+_(attr, train) \
+_(attr, training) \
+_(attr, transpose) \
+_(attr, transpose_result) \
+_(attr, transposed) \
+_(attr, type1) \
+_(attr, type2) \
+_(attr, unbiased) \
+_(attr, unitriangular) \
+_(attr, unpack_data) \
+_(attr, unpack_pivots) \
+_(attr, unroll_dim) \
+_(attr, unsafe) \
+_(attr, upper) \
+_(attr, upscale_factor) \
+_(attr, use_fast_accum) \
+_(attr, use_gelu) \
+_(attr, use_input_stats) \
+_(attr, v) \
+_(attr, value) \
+_(attr, values) \
+_(attr, var) \
+_(attr, vec) \
+_(attr, vec1) \
+_(attr, vec2) \
+_(attr, w_hh) \
+_(attr, w_ih) \
+_(attr, weight) \
+_(attr, weight0) \
+_(attr, weight1) \
+_(attr, weight2) \
+_(attr, weight3) \
+_(attr, weight4) \
+_(attr, weight_arr) \
+_(attr, weight_buf) \
+_(attr, weight_decay) \
+_(attr, weight_g) \
+_(attr, weight_scale) \
+_(attr, weight_stride0) \
+_(attr, weight_zero_point) \
+_(attr, weights) \
+_(attr, win_length) \
+_(attr, window) \
+_(attr, window_length) \
+_(attr, with_replacement) \
+_(attr, workspace) \
+_(attr, wrap) \
+_(attr, x) \
+_(attr, x1) \
+_(attr, x2) \
+_(attr, y) \
+_(attr, z) \
+_(attr, z_state) \
+_(attr, zero_infinity) \
+_(attr, zero_point) \
+_(attr, zero_point_hh) \
+_(attr, zero_point_ih) \
+_(attr, zero_points)
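+
+// The FORALL_* lists above follow the X-macro pattern: a consumer supplies
+// the `_` macro and the list stamps out one declaration per entry. A minimal
+// sketch, assuming a hypothetical consumer macro MY_DECLARE (illustrative
+// only, not part of this header):
+//
+//   #define MY_DECLARE(ns, s) constexpr const char* ns##_##s = #ns "::" #s;
+//   FORALL_ATTR_BASE_SYMBOLS(MY_DECLARE)  // defines attr_A, attr_B, ...
+//   #undef MY_DECLARE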
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/blob.h b/MLPY/Lib/site-packages/torch/include/ATen/core/blob.h
new file mode 100644
index 0000000000000000000000000000000000000000..7aa52ea67a6b52894d30f2c020f4f64952ad7af6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/blob.h
@@ -0,0 +1,208 @@
+#pragma once
+
+#include <cstddef>
+#include <sstream>
+#include <type_traits>
+#include <typeinfo>
+#include <vector>
+
+#include <c10/util/intrusive_ptr.h>
+#include <c10/util/typeid.h>
+#include <c10/macros/Macros.h>
+
+namespace caffe2 {
+
+class Tensor;
+
+/**
+ * @brief Blob is a general container that hosts a typed pointer.
+ *
+ * A Blob hosts a pointer as well as its type, and takes charge of deleting it
+ * properly when the blob is deallocated or re-allocated with a new type. A blob
+ * could contain anything, although the most common case is to contain a Tensor.
+ */
+class TORCH_API Blob final : public c10::intrusive_ptr_target {
+ public:
+  /**
+   * Initializes an empty Blob.
+   */
+  Blob() noexcept : meta_(), pointer_(nullptr), has_ownership_(false) {}
+  ~Blob() override {
+    Reset();
+  }
+
+  Blob(Blob&& other) noexcept : Blob() {
+    swap(other);
+  }
+
+  Blob& operator=(Blob&& other) noexcept {
+    Blob(std::move(other)).swap(*this);
+    return *this;
+  }
+
+  /**
+   * Checks if the content stored in the blob is of type T.
+   */
+  template <class T>
+  bool IsType() const noexcept {
+    return meta_.Match<T>();
+  }
+
+  /**
+   * Returns the meta info of the blob.
+   */
+  const TypeMeta meta() const noexcept {
+    return meta_;
+  }
+
+  /**
+   * Returns a printable typename of the blob.
+   */
+  c10::string_view TypeName() const noexcept {
+    return meta_.name();
+  }
+
+  /**
+   * @brief Gets the const reference of the stored object. The code checks if
+   * the stored object is of the desired type.
+   */
+  // TODO(jerryzh): add a Get(c10::DeviceType) function?
+  template <class T>
+  const T& Get() const {
+    TORCH_INTERNAL_ASSERT(
+        IsType<T>(),
+        "wrong type for the Blob instance. Blob contains ",
+        meta_.name(),
+        " while caller expects ",
+        TypeMeta::TypeName<T>());
+    // TODO: after we add Get(c10::DeviceType)
+    // and changed all the callsites, we can add
+    // a static assert here to enforce T != Tensor
+    return *static_cast<const T*>(pointer_);
+  }
+
+  const void* GetRaw() const noexcept {
+    return pointer_;
+  }
+  void* GetRaw() noexcept {
+    return pointer_;
+  }
+
+  /**
+   * @brief Gets a mutable pointer to the stored object.
+   *
+   * If the current object is not of the right type, a new object is created
+   * and the old object is freed. Note that type T should have a default
+   * constructor. Otherwise, create the object yourself first, and use
+   * Reset().
+   */
+  template <class T>
+  T* GetMutable() {
+    static_assert(
+        std::is_default_constructible<T>::value,
+        "GetMutable can't be called with non-default-constructible types. "
+        "Try using specialized methods");
+    if (IsType<T>()) {
+      return static_cast<T*>(pointer_);
+    } else {
+      // TODO Re-enable logging
+      // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
+      return Reset(new T());
+    }
+  }
+
+  template <class T>
+  T* GetMutableOrNull() {
+    if (IsType<T>()) {
+      return static_cast<T*>(pointer_);
+    } else {
+      return nullptr;
+    }
+  }
+
+  /**
+   * Sets the underlying object to the allocated one. The Blob then takes over
+   * the ownership of the passed in pointer. If there is already an object in
+   * the Blob, the old object is freed.
+   *
+   * This is used when the underlying class T does not have a default ctor, or
+   * complex initializations needs to be done outside the blob.
+   */
+  template <class T>
+  T* Reset(T* allocated) {
+    free_();
+    meta_ = TypeMeta::Make<T>();
+    pointer_ = static_cast<void*>(allocated);
+    has_ownership_ = true;
+    return allocated;
+  }
+
+  /**
+   * Sets the underlying object to the allocated one, but does not take over
+   * the ownership of the passed in pointer. If there is already an object in
+   * the Blob, the old object is freed.
+   *
+   * Unlike Reset, this does not take over the ownership of the pointer and the
+   * caller is responsible for making sure that the lifetime of the allocated
+   * blob outlasts the lifetime of any access to this blob, until another Reset
+   * call is made or the blob is destructed.
+   */
+  template <class T>
+  typename std::remove_const<T>::type* ShareExternal(
+      typename std::remove_const<T>::type* allocated) {
+    return static_cast<T*>(ShareExternal(
+        static_cast<void*>(allocated),
+        TypeMeta::Make<typename std::remove_const<T>::type>()));
+  }
+
+  void* ShareExternal(void* allocated, const TypeMeta meta) {
+    free_();
+    meta_ = meta;
+    pointer_ = allocated;
+    has_ownership_ = false;
+    return allocated;
+  }
+
+  /**
+   * Resets the Blob to an empty one.
+   */
+  void Reset() {
+    free_();
+    pointer_ = nullptr;
+    meta_ = TypeMeta();
+    has_ownership_ = false;
+  }
+
+  /**
+   * @brief Swaps the underlying storage of two blobs.
+   */
+  void swap(Blob& rhs) {
+    using std::swap;
+    swap(meta_, rhs.meta_);
+    swap(pointer_, rhs.pointer_);
+    swap(has_ownership_, rhs.has_ownership_);
+  }
+
+ private:
+  void free_() {
+    if (has_ownership_ && pointer_ != nullptr) {
+      (*meta_.deleteFn())(pointer_);
+    }
+  }
+
+  TypeMeta meta_;
+  void* pointer_;
+  bool has_ownership_;
+
+  C10_DISABLE_COPY_AND_ASSIGN(Blob);
+};
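+
+// A minimal usage sketch of the Blob API above, assuming a hypothetical
+// default-constructible payload type MyPayload (illustrative only):
+//
+//   caffe2::Blob blob;
+//   MyPayload* p = blob.GetMutable<MyPayload>();  // creates and owns a MyPayload
+//   if (blob.IsType<MyPayload>()) {
+//     const MyPayload& ref = blob.Get<MyPayload>();
+//   }
+//   blob.Reset(new MyPayload());  // replaces the payload; Blob owns the new pointer
+//   blob.Reset();                 // frees the payload and empties the blob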
+
+inline void swap(Blob& lhs, Blob& rhs) {
+  lhs.swap(rhs);
+}
+
+inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
+  return out << "Blob[" << v.TypeName() << "]";
+}
+
+} // namespace caffe2
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..b0adf28a2f1937cc728a9efe484cbd47d7937b87
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel.h
@@ -0,0 +1,176 @@
+#pragma once
+
+#include <ATen/core/boxing/OperatorKernel.h>
+#include <c10/core/DispatchKeySet.h>
+#include <c10/util/intrusive_ptr.h>
+
+namespace c10 {
+
+struct IValue;
+using Stack = std::vector<IValue>;
+
+class OperatorHandle;
+class KernelFunction;
+
+// This kernel implements the behavior of falling through to the next available
+// registered dispatch key.  The implementation of this function is FAST; it is
+// no overhead to fallthrough to the next key.  See cpp file for some more
+// implementation notes; notably, this does NOT actually go through the
+// boxing/unboxing codepath.
+TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
+
+// Note [Ambiguity in AutogradOther kernel]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// This error-reporting kernel is registered to the AutogradOther entry in the
+// dispatch table when there is both a CompositeImplicitAutograd kernel and a
+// backend kernel for ANY backend that maps to AutogradOther.  To see why
+// this is necessary in the AutogradOther case, it's helpful to first see
+// why everything works out fine for a backend that has a reserved Autograd
+// entry (see rule 2.2 in [Note] DispatchTable computation):
+//
+//    CPU   AutogradCPU
+//    reg?  registers with...
+//    -------------------------------------------------
+//    y     Autograd registration takes precedence
+//          over CompositeImplicitAutograd.
+//          This is good, because the CPU specific backend
+//          implementation is more specialized and typically better;
+//          if we used the composite, we would bypass it.
+//          (NB: the Autograd key is guaranteed to exist because
+//          the autograd codegen requires it!)
+//
+//    n     CompositeImplicitAutograd takes precedence.
+//          This is also good, because the Autograd
+//          registration (if it exists) would try to redispatch
+//          to the (non-existent) CPU implementation; by
+//          using the composite, we ensure the operator
+//          actually works.
+//
+// As you can see, when we have a specific Autograd key (AutogradCPU), we can
+// decide whether or not to use the CompositeImplicitAutograd kernel or the
+// Autograd kernel based on whether or not the backend kernel exists.
+//
+// However, for AutogradOther (which is the catchall autograd kernel for
+// everything that doesn't have a specific Autograd key), we can't do this
+// trick because there isn't any unique backend to peek at to disambiguate;
+// if there are some backends that have implementations they prefer Autograd,
+// but unimplemented backends would prefer CompositeImplicitAutograd.  Rather
+// than arbitrarily pick one or the other, we just register a kernel that raises
+// an error and let the user decide how to proceed.
+TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
+
+// Note [named_not_supported_kernel]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// This kernel implements reporting an error message saying that named tensor is
+// not supported.  This kernel doesn't rely on the Stack, and so it is special
+// cased in the dispatcher to be triggered before we attempt boxing (so we can
+// give a good error message in cases when boxing is not supported).  When
+// boxing is universally supported this can be removed.
+[[noreturn]] TORCH_API void named_not_supported_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
+
+/**
+ * BoxedKernel is similar to a std::function storing a boxed kernel.
+ */
+class TORCH_API BoxedKernel final {
+public:
+  // This is how boxed kernels are actually stored
+  //
+  // Note [Plumbing Keys Through The Dispatcher]
+  // Benchmarks have shown that it is expensive for the dispatcher to read from thread-local storage (TLS)
+  // upon every dispatch call into order to compute which kernel to dispatch to.
+  //
+  // To mitigate this, we've updated the calling convention inside the dispatcher to expect every kernel that it stores
+  // to have a first argument of type DispatchKeySet.
+  //
+  // What are the invariants of the DispatchKeySet when it gets passed to a kernel?
+  // - All keys to the left of the current dispatch key have been masked out.
+  //   (e.g. a Tracing kernel that takes in the DispatchKeySet will expect the highest bit to be DispatchKey::Tracer)
+  // - All other keys that dispatcher normally would have computed through TLS + global state + op arguments
+  //   are still in the set.
+  //
+  // Kernels can then opt into using this keyset to save the dispatcher from doing repeated work during redispatches:
+  // recalculating the highest-priority dispatch key, which involves reading from TLS. Instead, the kernels that opt in will
+  // calculate an updated DispatchKeySet directly from the old one, and pass the updated set directly into the dispatcher
+  // upon redispatching.
+  //
+  // This is an opt-in mechanism: Kernels can automatically opt in by setting the first argument in their signature
+  // to be of type DispatchKeySet. See the kernels in VariableTypeEverything.cpp and TraceTypeEverything.cpp for examples.
+  //
+  // The mechanism for optionally passing that DispatchKeySet into the kernel lives in make_boxed_from_unboxed_functor.h.
+  // See Note [Plumbing Keys Through The Dispatcher 2] for details.
+  using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
+  // This is the public API for how boxed kernels are defined
+  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
+  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);
+
+  BoxedKernel();
+
+  // Fast path for dispatch to allow not touching the boxed kernel in
+  // the common case where unboxed is available.
+  bool isValid() const;
+  bool isFallthrough() const;
+
+  /**
+   * Call the function with boxed arguments.
+   */
+  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
+
+  /**
+   * Create a KernelFunction from a boxed function.
+   *
+   * Example:
+   *
+   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
+   * > BoxedFunction func = BoxedKernel::makeFromFunction<&boxed_func>();
+   */
+  template<BoxedKernelFunction* func>
+  static BoxedKernel makeFromFunction();
+
+  /**
+   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
+   * See Note [Plumbing Keys Through The Dispatcher] for details.
+   */
+  template<BoxedKernelFunction_withDispatchKeys* func>
+  static BoxedKernel makeFromFunction();
+
+  /**
+   * Create a KernelFunction from a boxed functor.
+   *
+   * Example:
+   *
+   * > class MyFunctor final : public c10::OperatorKernel {
+   * >   public:
+   * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
+   * > };
+   * > BoxedKernel func = BoxedKernel::makeFromFunctor(std::make_unique<MyFunctor>());
+   */
+  template<class KernelFunctor>
+  static BoxedKernel makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);
+
+
+  static BoxedKernel makeFallthrough();
+  static BoxedKernel makeAmbiguousAutogradOther();
+  static BoxedKernel makeNamedNotSupported();
+
+private:
+
+  friend class KernelFunction;
+
+  template<BoxedKernelFunction* func>
+  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
+
+  template<BoxedKernelFunction_withDispatchKeys* func>
+  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
+
+  explicit BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func);
+
+  OperatorKernel* getFunctor() const;
+  InternalBoxedKernelFunction* getFnPtr() const;
+
+  c10::intrusive_ptr<OperatorKernel> functor_;
+  InternalBoxedKernelFunction* boxed_kernel_func_;
+};
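+
+// A minimal sketch of defining and wrapping a boxed kernel with the API
+// above, assuming a hypothetical kernel name my_boxed_kernel (illustrative
+// only; the op-specific stack handling is elided):
+//
+//   void my_boxed_kernel(const c10::OperatorHandle& op, c10::Stack* stack) {
+//     // pop the arguments (IValues) off the stack, compute, and push the
+//     // results back in schema order
+//   }
+//   auto kernel = c10::BoxedKernel::makeFromFunction<&my_boxed_kernel>();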
+
+}  // namespace c10
+
+#include <ATen/core/boxing/BoxedKernel_impl.h>
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..4520f0de4585bf50b2b786396afbec3bfa9ff782
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/BoxedKernel_impl.h
@@ -0,0 +1,99 @@
+#pragma once
+
+namespace c10 {
+
+inline BoxedKernel::BoxedKernel()
+    : functor_()
+, boxed_kernel_func_(nullptr)
+{}
+
+inline BoxedKernel::BoxedKernel(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func)
+: functor_(std::move(functor))
+, boxed_kernel_func_(boxed_kernel_func)
+{}
+
+template<BoxedKernel::BoxedKernelFunction* func>
+inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack) {
+    // Note that we're dropping the DispatchKeySet argument.
+    // See Note [Plumbing Keys Through The Dispatcher 2] for details.
+    func(opHandle, stack);
+}
+
+template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
+inline void BoxedKernel::make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet ks, Stack* stack) {
+    // See Note [Plumbing Keys Through The Dispatcher 2] for details.
+    func(opHandle, ks, stack);
+}
+
+inline bool BoxedKernel::isValid() const {
+    return boxed_kernel_func_ != nullptr;
+}
+
+inline bool BoxedKernel::isFallthrough() const {
+    return boxed_kernel_func_ == &fallthrough_kernel;
+}
+
+inline void BoxedKernel::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        boxed_kernel_func_ != nullptr,
+        "Tried to call BoxedKernel::callBoxed() on an uninitialized BoxedKernel."
+    );
+    (*boxed_kernel_func_)(functor_.get(), opHandle, dispatchKeySet, stack);
+}
+
+template<BoxedKernel::BoxedKernelFunction* func>
+inline BoxedKernel BoxedKernel::makeFromFunction() {
+    return BoxedKernel(
+        nullptr,  // no functor_ object
+        &make_boxed_function<func>
+    );
+}
+
+template<BoxedKernel::BoxedKernelFunction_withDispatchKeys* func>
+inline BoxedKernel BoxedKernel::makeFromFunction() {
+    return BoxedKernel(
+        nullptr,  // no functor_ object
+        &make_boxed_function<func>
+    );
+}
+
+inline BoxedKernel BoxedKernel::makeFallthrough() {
+    return BoxedKernel(
+        nullptr,  // no functor_ object
+        &fallthrough_kernel
+    );
+}
+
+inline BoxedKernel BoxedKernel::makeAmbiguousAutogradOther() {
+    return BoxedKernel(
+        nullptr,  // no functor_ object
+        &ambiguous_autogradother_kernel
+    );
+}
+
+inline BoxedKernel BoxedKernel::makeNamedNotSupported() {
+    return BoxedKernel(
+        nullptr,  // no functor_ object
+        &named_not_supported_kernel
+    );
+}
+
+template<class KernelFunctor>
+inline BoxedKernel BoxedKernel::makeFromFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
+    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call BoxedKernel::makeFromFunctor, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+    return BoxedKernel(
+        std::move(kernelFunctor),
+        [](OperatorKernel* kernel, const OperatorHandle& op, DispatchKeySet ks, Stack* stack) {
+          (*static_cast<KernelFunctor*>(kernel))(op, ks, stack);
+        }
+    );
+}
+
+inline OperatorKernel* BoxedKernel::getFunctor() const {
+  return functor_.get();
+}
+inline BoxedKernel::InternalBoxedKernelFunction* BoxedKernel::getFnPtr() const {
+  return boxed_kernel_func_;
+}
+
+}  // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h
new file mode 100644
index 0000000000000000000000000000000000000000..41d9467e03d561c6ad46d0ff2e9d095a76b12ef3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction.h
@@ -0,0 +1,260 @@
+#pragma once
+
+#include <ATen/core/ATen_fwd.h>
+#include <ATen/core/boxing/BoxedKernel.h>
+#include <ATen/core/stack.h>
+#include <c10/core/DispatchKeySet.h>
+#include <c10/util/intrusive_ptr.h>
+#include <c10/util/TypeList.h>
+#include <type_traits>
+
+namespace c10 {
+
+using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
+
+class OperatorHandle;
+struct OperatorKernel;
+class KernelFunction;
+
+template <typename T>
+using has_symint =
+  std::disjunction<
+    std::is_same<c10::SymInt, T>,
+    std::is_same<c10::SymIntArrayRef, T>,
+    std::is_same<at::OptionalSymIntArrayRef, T>,
+    std::is_same<c10::optional<c10::SymInt>, T>
+  >;
+
+template <typename T>
+struct remove_symint {
+  using type = T;
+};
+
+template <>
+struct remove_symint<c10::SymInt> {
+  using type = int64_t;
+};
+
+template <>
+struct remove_symint<at::OptionalSymIntArrayRef> {
+  using type = OptionalIntArrayRef;
+};
+
+template <>
+struct remove_symint<c10::SymIntArrayRef> {
+  using type = c10::IntArrayRef;
+};
+
+template <>
+struct remove_symint<c10::optional<c10::SymInt>> {
+  using type = c10::optional<int64_t>;
+};
+
+
+template <bool symint, typename T>
+struct maybe_keep_symint final {};
+
+template <typename T>
+struct maybe_keep_symint<true, T> { using type = T; };
+
+template <typename T>
+struct maybe_keep_symint<false, T> { using type = typename remove_symint<T>::type; };
+
+template <typename FuncType>
+using fn_has_symint = typename guts::typelist::true_for_any_type<
+  has_symint,
+  typename guts::infer_function_traits<FuncType>::type::parameter_types
+>;
+
+template <typename FuncType>
+struct fn_remove_symint;
+
+template <typename Ret, typename... Args>
+struct fn_remove_symint<Ret(Args...)> {
+  using type = Ret(typename remove_symint<Args>::type...);
+};
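+
+// Illustrative examples of how the traits above rewrite SymInt-carrying
+// signatures into their plain-int counterparts (sketch, not exhaustive):
+//   has_symint<c10::SymInt>::value                        == true
+//   has_symint<int64_t>::value                            == false
+//   remove_symint<c10::SymIntArrayRef>::type              == c10::IntArrayRef
+//   fn_remove_symint<void(c10::SymInt, at::Tensor)>::type == void(int64_t, at::Tensor)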
+
+/**
+ * KernelFunction is similar to std::function but stores a kernel function.
+ * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
+ * and call it in a boxed or unboxed way. If the way it was created doesn't
+ * match the way it was called, it will do boxing or unboxing as necessary.
+ */
+class TORCH_API KernelFunction final {
+public:
+  using InternalBoxedKernelFunction = BoxedKernel::InternalBoxedKernelFunction;
+  using BoxedKernelFunction = BoxedKernel::BoxedKernelFunction;
+  using BoxedKernelFunction_withDispatchKeys = BoxedKernel::BoxedKernelFunction_withDispatchKeys;
+
+  KernelFunction();
+
+  // Fast path for dispatch to allow not touching the boxed kernel in
+  // the common case where unboxed is available.
+  bool isValidUnboxed() const;
+  bool isValidSymUnboxed() const;
+  bool isValid() const;
+  bool isFallthrough() const;
+
+  /**
+   * Call the function in a boxed way.
+   * If the kernel function was created with an unboxed function,
+   * this will call an unboxing wrapper which then calls into that
+   * unboxed function.
+   *
+   * Example:
+   *
+   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
+   * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
+   * > Tensor result = func.callBoxed(stack);
+   *
+   * Or, with an unboxed implementation:
+   *
+   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
+   * >      [] (Tensor a, bool b) -> Tensor {...});
+   * > Tensor result = func.callBoxed(stack);
+   */
+  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
+
+  /**
+   * Call the function in an unboxed way.
+   * If the kernel function was created with a boxed function,
+   * this will box all inputs and then call into that boxed function.
+   *
+   * Note that this doesn't work for all types yet.
+   *
+   * Example:
+   *
+   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
+   * >      [] (Tensor a, bool b) -> Tensor {...});
+   * > Tensor result = func.call(tensor1, true);
+   *
+   * Or, with a boxed implementation:
+   *
+   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
+   * > KernelFunction func = KernelFunction::makeFromBoxedFunction(&boxed_func);
+   * > Tensor result = func.call(tensor1, true);
+   */
+  template<class Return, class... Args>
+  Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;
+
+  /**
+   * Create a KernelFunction from a BoxedKernel.
+   */
+  static KernelFunction makeFromBoxedKernel(BoxedKernel boxed_fn);
+
+  /**
+   * Create a KernelFunction from a boxed function.
+   *
+   * Example:
+   *
+   * > void boxed_func(OperatorKernel*, Stack* stack) {...}
+   * > KernelFunction func = KernelFunction::makeFromBoxedFunction<&boxed_func>();
+   */
+  template<BoxedKernelFunction* func>
+  static KernelFunction makeFromBoxedFunction();
+
+  /**
+   * TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
+   * See Note [Plumbing Keys Through The Dispatcher] for details.
+   */
+  template<BoxedKernelFunction_withDispatchKeys* func>
+  static KernelFunction makeFromBoxedFunction();
+
+  /**
+   * Create a KernelFunction from an unboxed functor.
+   *
+   * Example:
+   *
+   * > class MyFunctor final : public c10::OperatorKernel {
+   * >   public:
+   * >     Tensor operator()(Tensor a, Tensor b) {...}
+   * > };
+   * > KernelFunction func = KernelFunction::makeFromUnboxedFunctor(std::make_unique<MyFunctor>());
+   */
+  template<bool AllowLegacyTypes = false, class KernelFunctor>
+  static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
+
+  /**
+   * Create a KernelFunction from a boxed functor.
+   *
+   * Example:
+   *
+   * > class MyFunctor final : public c10::OperatorKernel {
+   * >   public:
+   * >     void operator()(const OperatorHandle&, DispatchKeySet, Stack*) {...}
+   * > };
+   * > KernelFunction func = KernelFunction::makeFromBoxedFunctor(std::make_unique<MyFunctor>());
+   */
+  template<class KernelFunctor>
+  static KernelFunction makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor);
+
+  /**
+   * Create a KernelFunction from an unboxed function.
+   * This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
+   * because knowing the function pointer as a template argument (i.e. at
+   * compile time) allows the compiler to inline the function into its
+   * unboxing wrapper and yields better performance when calling the function.
+   *
+   * Example:
+   *
+   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
+   * > KernelFunction func = KernelFunction::makeFromUnboxedFunction(TORCH_FN(unboxed_func));
+   */
+  template<class FuncPtr, bool AllowLegacyTypes = false>
+  static KernelFunction makeFromUnboxedFunction(FuncPtr);
+
+  /**
+   * Create a KernelFunction from an unboxed function.
+   * KernelFunction::makeFromUnboxedFunction is usually a better choice than
+   * this if you know the function pointer at compile time, see doc comment
+   * there for an explanation.
+   *
+   * Example:
+   *
+   * > Tensor unboxed_func(Tensor a, Tensor b) {...}
+   * > KernelFunction func = KernelFunction::makeFromUnboxedRuntimeFunction(&unboxed_func);
+   */
+  template<bool AllowLegacyTypes = false, class FuncType>
+  static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);
+
+  static KernelFunction makeFallthrough();
+  static KernelFunction makeAmbiguousAutogradOther();
+  static KernelFunction makeNamedNotSupported();
+
+  /**
+   * Create a KernelFunction from an unboxed lambda.
+   *
+   * Example:
+   *
+   * > KernelFunction func = KernelFunction::makeFromUnboxedLambda(
+   * >      [] (Tensor a, bool b) -> Tensor {...});
+   */
+  template<bool AllowLegacyTypes = false, class Lambda>
+  static std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
+  template<bool AllowLegacyTypes = false, class Lambda>
+  static std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> makeFromUnboxedLambda(Lambda&& lambda);
+
+  std::string dumpState() const;
+  // For testing internal invariants only
+  bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
+
+private:
+
+  explicit KernelFunction(
+      std::unique_ptr<OperatorKernel> functor,
+      InternalBoxedKernelFunction* boxed_kernel_func,
+      void* unboxed_kernel_func,
+      void* sym_unboxed_kernel_func);
+  explicit KernelFunction(
+      BoxedKernel boxed_fn,
+      void* unboxed_kernel_func,
+      void* sym_unboxed_kernel_func);
+
+  BoxedKernel boxed_kernel_func_;
+  void* unboxed_kernel_func_;
+  void* sym_unboxed_kernel_func_;
+};
+
+}
+
+#include <ATen/core/boxing/KernelFunction_impl.h>
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..fe5b8d39f72054dbe034b3dc35e247c66eefb2cd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/KernelFunction_impl.h
@@ -0,0 +1,229 @@
+#include <ATen/core/boxing/impl/boxing.h>
+#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
+#include <ATen/core/boxing/impl/WrapFunctionIntoFunctor.h>
+#include <ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h>
+
+#include <c10/util/C++17.h>
+#include <type_traits>
+
+namespace c10 {
+
+inline KernelFunction::KernelFunction()
+    : boxed_kernel_func_()
+    , unboxed_kernel_func_(nullptr)
+    , sym_unboxed_kernel_func_(nullptr)
+{}
+
+inline KernelFunction::KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
+  : boxed_kernel_func_(std::move(functor), boxed_kernel_func)
+  , unboxed_kernel_func_(unboxed_kernel_func)
+  , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
+{}
+
+inline KernelFunction::KernelFunction(BoxedKernel boxed_fn, void* unboxed_kernel_func, void* sym_unboxed_kernel_func = nullptr)
+  : boxed_kernel_func_(std::move(boxed_fn))
+  , unboxed_kernel_func_(unboxed_kernel_func)
+  , sym_unboxed_kernel_func_(sym_unboxed_kernel_func)
+{}
+
+inline bool KernelFunction::isValidUnboxed() const {
+  return unboxed_kernel_func_ != nullptr;
+}
+
+inline bool KernelFunction::isValidSymUnboxed() const {
+  return sym_unboxed_kernel_func_ != nullptr;
+}
+
+inline bool KernelFunction::isValid() const {
+  return boxed_kernel_func_.isValid();
+}
+
+inline bool KernelFunction::isFallthrough() const {
+  return boxed_kernel_func_.isFallthrough();
+}
+
+inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const {
+  boxed_kernel_func_.callBoxed(opHandle, dispatchKeySet, stack);
+}
+
+template<class Return, class... Args>
+inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
+    using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...);
+    ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
+    return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
+}
+
+// This template requires you to explicitly specify the argument you want to
+// forward; it doesn't work if you try to deduce it
+// NB: keep this in sync with cloneWithRealTypes in function_schema.cpp
+
+template <typename T>
+inline typename remove_symint<T>::type unpackSymInt(T x) { return x; }
+
+template <>
+inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
+  return x.guard_int(__FILE__, __LINE__);
+}
+
+template <>
+inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(c10::SymIntArrayRef x) {
+  return C10_AS_INTARRAYREF_SLOW(x);
+}
+
+template <>
+inline typename remove_symint<c10::optional<c10::SymInt>>::type unpackSymInt(c10::optional<c10::SymInt> x) {
+  return x.has_value() ? c10::make_optional(x->guard_int(__FILE__, __LINE__)) : c10::nullopt;
+}
+
+template <>
+inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(at::OptionalSymIntArrayRef x) {
+  return x.has_value() ? c10::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : c10::nullopt;
+}
+
+template<class Return, class... Args>
+C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
+    // note: Args above is intentionally not Args&&. We don't want perfect
+    // forwarding, which would require Args to be deduced, but instead we
+    // want callers to explicitly specify the Args.
+
+    if constexpr (std::disjunction_v<has_symint<std::decay_t<Args>>...>) {
+      if (sym_unboxed_kernel_func_ != nullptr) {
+        auto *functor = boxed_kernel_func_.getFunctor();
+        return callUnboxedKernelFunction<Return, Args...>(
+            sym_unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
+      }
+
+      if (unboxed_kernel_func_ != nullptr) {
+        auto *functor = boxed_kernel_func_.getFunctor();
+        return callUnboxedKernelFunction<Return, typename remove_symint<Args>::type...>(
+            unboxed_kernel_func_, functor, dispatchKeySet, unpackSymInt<Args>(args)...);
+      }
+    } else {
+      if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
+        auto *functor = boxed_kernel_func_.getFunctor();
+        return callUnboxedKernelFunction<Return, Args...>(
+            unboxed_kernel_func_, functor, dispatchKeySet, std::forward<Args>(args)...);
+      }
+    }
+
+    return impl::BoxedKernelWrapper<Return(Args...)>::call(
+        boxed_kernel_func_,
+        opHandle,
+        dispatchKeySet,
+        std::forward<Args>(args)...
+    );
+}
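+
+// Illustrative sketch (not part of the upstream header): because Args is not
+// deduced, callers spell out Return and Args so they match the registered
+// kernel's signature exactly. Operator and handle names are placeholders.
+//
+// > at::Tensor out = kernel.call<at::Tensor, const at::Tensor&, const at::Tensor&>(
+// >     opHandle, dispatchKeySet, self, other);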
+
+inline KernelFunction KernelFunction::makeFromBoxedKernel(BoxedKernel boxed_fn) {
+  return KernelFunction(std::move(boxed_fn), nullptr);  // no unboxed function pointer
+}
+
+template<BoxedKernelFunction* func>
+inline KernelFunction KernelFunction::makeFromBoxedFunction() {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeFromFunction<func>());
+}
+
+template<BoxedKernelFunction_withDispatchKeys* func>
+inline KernelFunction KernelFunction::makeFromBoxedFunction() {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeFromFunction<func>());
+}
+
+inline KernelFunction KernelFunction::makeFallthrough() {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeFallthrough());
+}
+
+inline KernelFunction KernelFunction::makeAmbiguousAutogradOther() {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeAmbiguousAutogradOther());
+}
+
+inline KernelFunction KernelFunction::makeNamedNotSupported() {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeNamedNotSupported());
+}
+
+template<bool AllowLegacyTypes, class KernelFunctor>
+inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
+#ifndef NDEBUG
+  // This assertion is costly for build time so it's debug-gated.
+    static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
+#endif
+    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+
+    auto* unboxed_fn = &impl::wrap_kernel_functor_unboxed<KernelFunctor>::call;
+    void* void_unboxed_fn = reinterpret_cast<void*>(unboxed_fn);
+    bool is_symint = fn_has_symint<decltype(unboxed_fn)>::value;
+    return KernelFunction(
+        std::move(kernelFunctor),
+        &impl::make_boxed_from_unboxed_functor<KernelFunctor, AllowLegacyTypes>::call,
+        is_symint ? nullptr : void_unboxed_fn,
+        is_symint ? void_unboxed_fn : nullptr
+    );
+}
+
+template<class KernelFunctor>
+inline KernelFunction KernelFunction::makeFromBoxedFunctor(std::unique_ptr<KernelFunctor> kernelFunctor) {
+  return KernelFunction::makeFromBoxedKernel(
+      BoxedKernel::makeFromFunctor(std::move(kernelFunctor)));
+}
+
+template<class FuncPtr, bool AllowLegacyTypes>
+inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
+    static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
+    static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
+    static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");
+
+#if !defined(C10_MOBILE)
+    (void)func_ptr; // Suppress unused variable warning
+    return makeFromUnboxedFunctor<AllowLegacyTypes, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>(
+        guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>()
+    );
+#else
+    // On mobile, we rather want to optimize for binary size than for performance,
+    // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
+    // instead.
+    return makeFromUnboxedRuntimeFunction(func_ptr.func_ptr());
+#endif
+}
+
+template<bool AllowLegacyTypes, class FuncType>
+inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
+    static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
+    static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
+    TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");
+
+    return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
+        guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func)
+    );
+}
+
+template<bool AllowLegacyTypes, class Lambda>
+inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
+    static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
+
+#if !defined(C10_MOBILE)
+    return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
+        guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
+    );
+#else
+    // On mobile, we rather want to optimize for binary size than for performance,
+    // so let's not inline the kernel into the wrapper but use makeFromUnboxedRuntimeFunction
+    // instead.
+    using FuncType = typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type;
+    return makeFromUnboxedRuntimeFunction<AllowLegacyTypes, FuncType>(lambda);
+#endif
+}
+
+template<bool AllowLegacyTypes, class Lambda>
+inline std::enable_if_t<!guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
+    static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
+
+    return makeFromUnboxedFunctor<AllowLegacyTypes, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(
+        guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>(std::forward<Lambda>(lambda))
+    );
+}
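+
+// Illustrative sketch (not part of the upstream header) of how these factory
+// functions are typically used; `my_add` is a placeholder free function and the
+// default template arguments are assumed:
+//
+// > at::Tensor my_add(const at::Tensor& a, const at::Tensor& b) { return a + b; }
+// > auto k1 = KernelFunction::makeFromUnboxedFunction(TORCH_FN(my_add));
+// > auto k2 = KernelFunction::makeFromUnboxedLambda(
+// >     [](const at::Tensor& a, const at::Tensor& b) { return a + b; });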
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7bda1820b4ffdf210f59df8443a9f725564a6104
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/OperatorKernel.h
@@ -0,0 +1,27 @@
+#pragma once
+#include <c10/util/intrusive_ptr.h>
+
+namespace c10 {
+
+/**
+ * Inherit from OperatorKernel to implement a c10 kernel.
+ *
+ * Example:
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ *
+ * The kernel class is allowed to have members but these are equivalent
+ * to global variables. The kernel implementation is responsible for
+ * preventing race conditions on them.
+ *
+ * See below for how to register this kernel with PyTorch.
+ */
+struct TORCH_API OperatorKernel : public c10::intrusive_ptr_target {
+  ~OperatorKernel() override = default;
+};
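+
+// Illustrative sketch (not part of the upstream header): one common way such a
+// functor has been registered, assuming the torch::RegisterOperators API is
+// available; the schema string and namespace are placeholders.
+//
+// > static auto registry = torch::RegisterOperators().op(
+// >     "my_namespace::my_kernel(Tensor a, Tensor b) -> Tensor",
+// >     torch::RegisterOperators::options().kernel<my_kernel_cpu>(c10::DispatchKey::CPU));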
+
+}  // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h
new file mode 100644
index 0000000000000000000000000000000000000000..fa4811722a47e932eafebdcceac4e844c43fc2e2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+namespace impl {
+  namespace detail {
+    template<class FuncPtr, class ReturnType, class ParameterList> class WrapFunctionIntoFunctor_ {};
+    template<class FuncPtr, class ReturnType, class... Parameters>
+    class WrapFunctionIntoFunctor_<FuncPtr, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
+    public:
+      C10_ALWAYS_INLINE decltype(auto) operator()(Parameters... args) {
+        return (*FuncPtr::func_ptr())(std::forward<Parameters>(args)...);
+      }
+    };
+  }
+
+  // WrapFunctionIntoFunctor: Wraps a compile time function pointer into a kernel functor.
+  // Since it is a compile time function pointer, many compilers can inline it
+  // into the wrapper and you don't get any performance overhead for wrapping.
+  template<class FuncPtr>
+  struct WrapFunctionIntoFunctor final {
+    static_assert(c10::is_compile_time_function_pointer<FuncPtr>::value, "WrapFunctionIntoFunctor can only wrap functions created with TORCH_FN.");
+    using type = detail::WrapFunctionIntoFunctor_<
+        FuncPtr,
+        typename guts::function_traits<typename FuncPtr::FuncType>::return_type,
+        typename guts::function_traits<typename FuncPtr::FuncType>::parameter_types
+    >;
+  };
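+
+  // Illustrative sketch (not part of the upstream header); `my_add` is a
+  // placeholder free function:
+  //
+  // > at::Tensor my_add(const at::Tensor& a, const at::Tensor& b);
+  // > using MyAddFunctor = WrapFunctionIntoFunctor<TORCH_FN(my_add)>::type;
+  // > // MyAddFunctor is an OperatorKernel whose operator() forwards to my_add.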
+}
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h
new file mode 100644
index 0000000000000000000000000000000000000000..a12160b47f494b3deb455205956e7d271bece967
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+
+namespace impl {
+  namespace detail {
+    template<class FuncType, class ReturnType, class ParameterList> class WrapFunctionIntoRuntimeFunctor_ {};
+    template<class FuncType, class ReturnType, class... Parameters>
+    class WrapFunctionIntoRuntimeFunctor_<FuncType, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
+    public:
+      template<class FuncType_>
+      explicit WrapFunctionIntoRuntimeFunctor_(FuncType_&& kernel_func)
+      : kernel_func_(std::forward<FuncType_>(kernel_func)) {}
+
+      decltype(auto) operator()(Parameters... args) {
+        return kernel_func_(std::forward<Parameters>(args)...);
+      }
+
+    private:
+      FuncType kernel_func_;
+    };
+  }
+
+  // WrapFunctionIntoRuntimeFunctor: Wraps any runtime functor into a functor that
+  // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
+  // This can, for example, be used for lambdas, functors or even function pointers.
+  // In the case of function pointers, since it is a runtime function pointer,
+  // there is an overhead for calling it whenever the kernel is invoked.
+  template<class FuncType>
+  using WrapFunctionIntoRuntimeFunctor = detail::WrapFunctionIntoRuntimeFunctor_<
+      FuncType,
+      typename guts::infer_function_traits_t<FuncType>::return_type,
+      typename guts::infer_function_traits_t<FuncType>::parameter_types
+  >;
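+
+  // Illustrative sketch (not part of the upstream header): wrapping a lambda,
+  // which pays one indirect call each time the kernel is invoked:
+  //
+  // > auto lambda = [](const at::Tensor& a) { return a * 2; };
+  // > WrapFunctionIntoRuntimeFunctor<decltype(lambda)> functor(std::move(lambda));
+  // > at::Tensor out = functor(input);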
+}
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h
new file mode 100644
index 0000000000000000000000000000000000000000..041b261031c4496105eb2b58dd54d7e8b570165b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/boxing.h
@@ -0,0 +1,387 @@
+#pragma once
+
+// This file contains boxing (not unboxing) logic,
+// i.e. how to make a vector<IValue> from a set of concrete arguments.
+
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+namespace c10 {
+namespace impl {
+
+//
+// utils
+//
+
+// is_mutable_tensor_ref
+template <class T> struct is_mutable_tensor_ref : std::false_type {};
+template <> struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};
+
+// is_tuple_of_mutable_tensor_refs
+//
+template <class T, class Enable = void>
+struct is_tuple_of_mutable_tensor_refs : std::false_type {};
+
+template <class T>
+struct is_tuple_of_mutable_tensor_refs<T, std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
+: guts::typelist::all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>>
+{};
+
+// has_ivalue_to<T> tests the presence/absence of instance method IValue::to<T>()
+//
+template <class T, class Enable = void>
+struct has_ivalue_to : std::false_type {};
+
+template <class T>
+struct has_ivalue_to<T, std::void_t<decltype(std::declval<IValue>().to<T>())>>
+: std::true_type
+{};
+
+//
+// boxing predicates
+//
+
+// A boxable arg type is one that IValue has a constructor for.
+template <typename T>
+using can_box =
+  std::disjunction<
+    std::is_constructible<IValue, std::decay_t<T>>,
+    // TensorOptions are not directly constructible into IValue,
+    // but torch::jit::push knows how to handle them
+    std::is_same<TensorOptions, std::decay_t<T>>
+  >;
+
+template <typename... Ts>
+using can_box_all = std::conjunction<can_box<Ts>...>;
+
+// an unboxable result is one that can be extracted from an IValue
+template <typename T>
+using can_unbox =
+   std::conjunction<
+    std::disjunction<
+      has_ivalue_to<T>,
+      // void returns are ok
+      std::is_same<void, T>
+    >,
+    std::negation<std::is_lvalue_reference<T>>
+  >;
+
+//
+// boxArgs - utility for pushing unboxed args onto IValue stack
+//
+template <class... Args>
+torch::jit::Stack boxArgs(Args... args) {
+  // TODO Reuse stack vector instead of allocating?
+  torch::jit::Stack stack;
+  stack.reserve(sizeof...(Args));
+  torch::jit::push(stack, std::forward<Args>(args)...);
+  return stack;
+}
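+
+// Illustrative sketch (not part of the upstream header): boxing two tensors and
+// an integer into an IValue stack (argument names are placeholders):
+//
+// > torch::jit::Stack stack = boxArgs<const at::Tensor&, const at::Tensor&, int64_t>(a, b, 2);
+// > // stack now holds three IValues, in argument order.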
+
+template <typename T>
+static inline constexpr size_t boxed_size_one() {
+  static_assert(!std::is_same<std::decay_t<T>, c10::TensorOptions>::value, "need to patch this path to support TensorOptions passed by reference");
+  return 1;
+}
+
+// torch::jit::push pushes 4 values for a TensorOptions; this needs to
+// be kept in sync.
+template <>
+inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
+  return 4;
+}
+
+// NOTE: this could probably be simplified with C++17 fold expressions.
+template <typename... Types>
+struct BoxedSize : std::integral_constant<size_t, 0> {};
+template <typename T, typename... Types>
+struct BoxedSize<T, Types...> : std::integral_constant<size_t, boxed_size_one<T>() + BoxedSize<Types...>::value> {};
+
+template <typename... Types>
+static inline constexpr size_t boxed_size() {
+  return BoxedSize<Types...>::value;
+}
+
+using IValueAlignedStorage = std::aligned_storage_t<sizeof(IValue), alignof(IValue)>;
+
+template <typename T>
+C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, T& arg, int& lastIdx) {
+  new (&dest[lastIdx]) IValue(arg);
+  lastIdx++;
+}
+
+C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, c10::TensorOptions options, int& lastIdx) {
+  new (&dest[lastIdx++]) IValue(c10::typeMetaToScalarType(options.dtype()));
+  new (&dest[lastIdx++]) IValue(options.layout());
+  new (&dest[lastIdx++]) IValue(options.device());
+  new (&dest[lastIdx++]) IValue(options.pinned_memory());
+}
+
+inline void boxArgsToStack(IValueAlignedStorage*, int&) {}
+
+template<typename T, typename... Args>
+C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(IValueAlignedStorage* dest, int& lastIdx, T& arg, Args &... args) {
+  boxToStack(dest, arg, lastIdx);
+  boxArgsToStack(dest, lastIdx, args...);
+}
+
+//
+// PopResult is a helper class whose specializations handle popping single and
+// multiple return values, respectively.
+//
+template 
+struct PopResult final {
+  static Result call(Stack& stack) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == 1,
+      "Boxed kernel was expected to return one value on the stack, ",
+      "but instead pushed ", stack.size(), " values."
+    );
+    return std::move(stack[0]).to();
+  }
+};
+
+template 
+struct PopResult> final {
+  using Result = std::tuple;
+
+  static Result call(Stack& stack) {
+    // for tuple return types, boxed kernel has pushed multiple values onto the stack
+    constexpr int RetCount = sizeof...(Types);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == RetCount,
+      "Boxed kernel was expected to return ", RetCount, " values on the stack, ",
+      "but instead pushed ", stack.size(), " values."
+    );
+    return pop_to_tuple_impl(stack, std::make_index_sequence());
+  }
+private:
+  // note: this has been moved into its own helper only to avoid a parse error on `indices` otherwise.
+  // I'm sure there's an incantation that slips it past the parser but eh
+  template 
+  static Result pop_to_tuple_impl(Stack& stack, std::index_sequence) {
+    return std::make_tuple((std::move(stack[indices]).to())...);
+  }
+};
+
+//
+// BoxedKernelWrapper
+//
+// For a given function type FT, BoxedKernelWrapper implements
+// a `call` method that
+// - takes a boxed kernel and unboxed arguments as specified by FT,
+// - calls `boxArgs` to box the arguments
+// - calls the boxed kernel
+// - unboxes and returns the result
+//
+// The partial specializations below handle various cases: in
+// particular, not all types appearing in op signatures are supported,
+// and ops returning references have nonstandard wrapper implementations.
+//
+
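+// Illustrative sketch (not part of the upstream header) of the flow the
+// specializations below implement, for a kernel with signature
+// at::Tensor(const at::Tensor&, const at::Scalar&):
+//
+// >   1. boxArgs(self, alpha)                -> Stack with 2 IValues
+// >   2. boxed_kernel_func.callBoxed(...)    -> kernel pops inputs, pushes 1 result
+// >   3. PopResult<at::Tensor>::call(stack)  -> at::Tensor returned to the caller
+//
+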
+// 1. The base specialization of BoxedKernelWrapper should never be instantiated.
+// A "no call method defined on BoxedKernelWrapper" compile error means that
+// an op signature has failed to trigger any of the partial specializations
+// that follow this one.
+//
+template 
+struct BoxedKernelWrapper {
+  // The reason we're not just doing straight up static_assert(false, ...) here:
+  // Basically, the way to make sure a static_assert only fires if a template
+  // is actually instantiated (rather than every time the file is parsed) is to use
+  // template parameters in the expression, e.g. FuncType here. However, since
+  // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same
+  // effect.
+  static_assert(sizeof(FuncType) != sizeof(FuncType),
+     "Function signature contains one or more unsupported parameter and/or return types. "
+     "Look for a nearby error like "
+     "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
+     "- (your function type) is the unsupported signature.");
+};
+
+//
+// 2. Supported signatures, other than those involving non-const Tensor refs -
+// i.e., "functional" ops.
+//
+
+template 
+struct BoxedKernelWrapper<
+  Result(Args...),
+  std::enable_if_t<
+    can_box_all::value && can_unbox::value && !is_tuple_of_mutable_tensor_refs::value,
+    void
+  >
+> {
+  static Result call(
+    const BoxedKernel& boxed_kernel_func,
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    Args... args
+  ) {
+    torch::jit::Stack stack = boxArgs(std::forward(args)...);
+    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
+
+    if constexpr (!std::is_same_v) {
+        // op has pushed one or more values onto the stack.
+        return PopResult::call(stack);
+    } else {
+      // op returns void, boxed kernel has pushed nothing onto stack.
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+          stack.empty(),
+          "Boxed kernel was expected to return no values on the stack, ",
+          "but instead returned ", stack.size(), " values."
+      );
+    }
+  }
+};
+
+//
+// 3. in-place ops take a single non-const Tensor reference
+// as their first argument, and return it.
+//
+// Note: all signatures matching this pattern are assumed to be for such ops.
+// Because of this, the generated BoxedKernelWrapper specializations simply
+// return the in-place argument.
+//
+
+template 
+struct BoxedKernelWrapper<
+  at::Tensor&(at::Tensor&, OtherArgs...),
+  std::enable_if_t::value, void>
+> {
+  static at::Tensor& call(
+    const BoxedKernel& boxed_kernel_func,
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    at::Tensor& outArg, OtherArgs... otherArgs
+  ) {
+    torch::jit::Stack stack = boxArgs(outArg, std::forward(otherArgs)...);
+    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == 1,
+      "Boxed kernel was expected to return a single value on the stack, ",
+      "but instead returned ", stack.size(), " values."
+    );
+
+    return outArg;
+  }
+};
+
+//
+// 3.5. In-process migration to make in-place ops take and return
+// const references instead.
+template 
+struct BoxedKernelWrapper<
+  const at::Tensor&(const at::Tensor&, OtherArgs...),
+  std::enable_if_t::value, void>
+> {
+  static const at::Tensor& call(
+    const BoxedKernel& boxed_kernel_func,
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    const at::Tensor& outArg, OtherArgs... otherArgs
+  ) {
+    torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
+    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == 1,
+      "Boxed kernel was expected to return a single value on the stack, ",
+      "but instead returned ", stack.size(), " values."
+    );
+
+    return outArg;
+  }
+};
+
+//
+// 4. out of place ops that take a single non-const Tensor reference as their
+// final argument, and also return it.
+//
+// Note: all signatures matching this pattern are assumed to be for such ops.
+// This assumption permits the generated BoxedKernelWrapper specializations to simply
+// return out arguments.
+//
+template 
+struct BoxedKernelWrapper<
+  at::Tensor&(FirstArg, RestArgs...),
+  std::enable_if_t<
+    can_box_all::value
+    // this skips over in-place kernels with a non-const Tensor
+    // arg at the front, so those can unambiguously trigger the preceding specialization.
+    && !is_mutable_tensor_ref::value,
+    void
+  >
+> {
+  static at::Tensor& call(
+    const BoxedKernel& boxed_kernel_func,
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    FirstArg firstArg, RestArgs... restArgs
+  ) {
+    torch::jit::Stack stack = boxArgs(std::forward(firstArg), std::forward(restArgs)...);
+    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == 1,
+      "Boxed kernel was expected to return a single value on the stack, ",
+      "but instead returned ", stack.size(), " values."
+    );
+
+    // reusing restArgs after it has been forwarded here is ok because we know
+    // that the last element is of type `Tensor&`.
+    return std::get(std::tuple{restArgs...});
+  }
+};
+
+//
+// 5. out of place ops that take multiple non-const Tensor references as their
+// final arguments, and return them in a std::tuple.
+//
+// Note: all signatures matching this pattern are assumed to be for such ops.
+// This assumption permits the generated BoxedKernelWrapper specializations to simply
+// return the out arguments.
+//
+template 
+struct BoxedKernelWrapper<
+  Result(Args...),
+  std::enable_if_t<
+    can_box_all::value && is_tuple_of_mutable_tensor_refs::value,
+    void
+  >
+> {
+  static Result call(
+    const BoxedKernel& boxed_kernel_func,
+    const OperatorHandle& opHandle,
+    DispatchKeySet dispatchKeySet,
+    Args... args
+  ) {
+    using ArgTuple = std::tuple;
+    constexpr int RetCount = std::tuple_size();
+
+    torch::jit::Stack stack = boxArgs(std::forward(args)...);
+    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == RetCount,
+      "Boxed kernel was expected to return ", RetCount, " values on the stack, ",
+      "but instead returned ", stack.size(), " values."
+    );
+
+    // reusing args after it has been forwarded here is ok because we know
+    // that the last RetCount elements are of type `Tensor&`.
+    auto result = guts::tuple_take(ArgTuple{std::forward(args)...});
+    static_assert(
+        std::is_same::value,
+        "The parameter list of an op returning a tuple of Tensor references "
+            "must end with an equal number of Tensor reference parameters."
+    );
+    return result;
+  }
+};
+
+} // impl
+} // c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
new file mode 100644
index 0000000000000000000000000000000000000000..91bf0bff104adfd08bb85d05f744e7886370148a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -0,0 +1,600 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace c10 {
+
+using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
+class OperatorHandle;
+
+/*
+ * [Note: Argument forwarding in the dispatcher]
+ *
+ * The dispatcher uses a somewhat unusual way to forward arguments through several layers of
+ * wrapper functions. This can be confusing because an experienced C++ programmer would look at this
+ * and think "oh this is supposed to be forwarding a universal reference but the && is missing. This is a bug.".
+ * It is not a bug. The common way in C++ to forward arguments is to use universal references:
+ *
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
+ *
+ * but that relies on inferring the correct reference type (i.e. value vs & vs &&) from the argument.
+ * In our case, we cannot rely on the argument as supplied by the caller, because that could infer a
+ * different reference type than was used in the kernel function. The correct reference type
+ * is dictated by the kernel signature and must be identical since we cast function pointers
+ * through void* pointers and mismatches would be UB. So we need a forwarding pattern that determines
+ * the reference type to use by looking at the explicitly supplied operator signature, not by looking at
+ * the argument we're calling it with.
+ *
+ * What does std::forward do, exactly?
+ * ------------------------------------
+ * std::forward<T>(t) is a way to cast t to the reference type supplied in T.
+ * Let's assume decay_t<T> == U and T is either U or some reference of U.
+ *  - std::forward<U&>(t) will return U&, no matter what kind of reference t is.
+ *  - std::forward<U&&>(t) will return U&&, no matter what kind of reference t is.
+ *  - std::forward<U>(t) will return U&& (not U!), no matter what kind of reference t is.
+ *
+ * For universal references, that means that in the following function
+ * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
+ *
+ *  - when called with arg being an rvalue reference or non-reference value, T gets inferred to be
+ *    a non-reference U, and std::forward<T>(arg) will return U&&, correctly moving the argument.
+ *  - when called with arg being an lvalue reference, T gets inferred to be U& because that's the only
+ *    way to match the signature (in C++, a type that is (T&)&& will collapse to T&).
+ *    That means std::forward<T>(arg) will return U& and the value will not be moved but passed on as
+ *    an lvalue reference.
+ *
+ * How do we use that?
+ * ------------------------------------
+ * But std::forward can also be used outside of the common "universal forwarding" pattern to change
+ * reference types. So instead of following the common C++ pattern, we notice what
+ * std::forward<T>() actually does, and that is it takes a value and changes its reference to the
+ * type of reference passed in as T. If we don't infer T but explicitly specify it, we can use this
+ * to forward based on an explicitly specified reference type instead of the inferred argument type.
+ *
+ * This is why many of the dispatcher functions look like
+ * > template<class T> void func(T t) { func2(std::forward<T>(t)); }
+ * instead of the common
+ * > template<class T> void func(T&& t) { func2(std::forward<T>(t)); }
+ *
+ * and are expected to be called by explicitly specifying the template parameters in a way that matches
+ * the expected operator signature at each call site.
+ */
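+
+// Illustrative sketch (not part of the upstream header) of the pattern described
+// above; all names are placeholders:
+//
+// > template<class T> void inner(T&& t);
+// > template<class T> void outer(T t) {
+// >   // T is specified explicitly at the call site, so this casts `t` to the
+// >   // reference type dictated by the kernel signature, not by the caller:
+// >   inner(std::forward<T>(t));
+// > }
+// > // call site: outer<const at::Tensor&>(tensor);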
+
+namespace impl {
+  // supported_primitive_arg_types defines which primitive types we allow in
+  // kernel functions as arguments or returns.
+  // Additionally, we support lists, dicts and optionals containing these types.
+  using supported_primitive_arg_types = guts::typelist::typelist<
+    int64_t,
+    double,
+    bool,
+    c10::string_view,
+    at::Tensor,
+    at::Scalar,
+    c10::QScheme,
+    c10::ScalarType,
+    c10::Device,
+    c10::DeviceIndex,
+    c10::Layout,
+    c10::MemoryFormat,
+    at::Dimname
+  >;
+
+  // We have an unboxed functor in hand that takes C++ arguments, and
+  // we're building a boxed functor wrapper for it that takes IValues.
+  // So "outside" is boxed and "inside" is unboxed.
+  //
+  // So a valid input type is one that our boxed functor wrapper can
+  // unbox from an IValue into a C++ value.
+  //
+  // Whereas a valid output type is one that our wrapper can receive
+  // as a C++ value from the unboxed functor, and box into an IValue.
+
+  //
+  // assert_is_valid_input_type
+  // checks that T can be unboxed from an IValue into a C++ value.
+  //
+
+  template
+  struct assert_is_valid_input_type {
+    assert_is_valid_input_type() {
+      if constexpr (guts::typelist::contains::value) {
+        /* everything is ok, this is a primitive type */
+      } else {
+        /* otherwise this must be an instance of a valid custom class, since it can only
+           have been created via IValue(x), which ensures this. */
+      }
+    }
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {};
+
+  template 
+  struct TypeCheckHelper;
+
+  template 
+  struct TypeCheckHelper {};
+
+  template 
+  struct TypeCheckHelper
+  : TypeCheckHelper {
+    assert_is_valid_input_type check;
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : TypeCheckHelper {};
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(guts::typelist::contains::value,
+      "You tried to register a kernel with an unsupported input type: Dict where Key is invalid. We only support int64_t, double, bool, and string.");
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(AllowDeprecatedTypes,
+      "You tried to register a kernel with an unsupported input type: std::unordered_map. Please use Dict instead.");
+    static_assert(guts::typelist::contains::value,
+      "You tried to register a kernel with an unsupported input type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string.");
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead.");
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported input type: ArrayRef. Please use List, List or Tensor instead.");
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported input type: OptionalArrayRef. Please use List, List or Tensor instead.");
+  };
+
+  template
+  struct assert_is_valid_input_type, AllowDeprecatedTypes>
+  : assert_is_valid_input_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported input type: std::array. Please use std::array instead.");
+  };
+
+  template
+  struct assert_is_valid_input_type::value>> {
+    // There is no reason to support float when we have double. Keep the API lean.
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported input type: float. Please use double instead.");
+  };
+  template
+  struct assert_is_valid_input_type::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead.");
+  };
+  template
+  struct assert_is_valid_input_type, T>::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported input type: vector. Please use List instead.");
+  };
+  template
+  struct assert_is_valid_input_type::value && !guts::typelist::contains::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead.");
+  };
+  template
+  struct assert_is_valid_input_type::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
+  };
+
+  // TODO: it probably would be good to tighten this up quite a bit more with
+  // an explicit list for everything
+
+  //
+  // assert_is_valid_output_type
+  //
+
+  template
+  struct assert_is_valid_output_type {
+    assert_is_valid_output_type() {
+      if constexpr(guts::typelist::contains::value) {
+        /* everything is ok, this is a primitive type */
+      } else {
+        /* otherwise T is verified to be a registered custom class in the IValue
+          constructor, so no benefit in double-checking here */
+      }
+    }
+  };
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {};
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {};
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {
+    static_assert(guts::typelist::contains::value,
+      "You tried to register a kernel with an unsupported output type: Dict where Key is invalid. We only support int64_t, double, bool, and string.");
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported output type: Dict. Please use Dict or Dict.");
+  };
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {
+    static_assert(AllowDeprecatedTypes,
+      "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict instead.");
+    static_assert(guts::typelist::contains::value,
+      "You tried to register a kernel with an unsupported output type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string.");
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict or Dict.");
+  };
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported output type: List. Please use List, List or Tensor instead.");
+  };
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported output type: std::vector. Please use List, List or Tensor instead.");
+    // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector. Please use List instead.");
+  };
+
+  template
+  struct assert_is_valid_output_type, AllowDeprecatedTypes>
+  : assert_is_valid_output_type {
+    static_assert(!std::is_same::value,
+      "You tried to register a kernel with an unsupported output type: std::array. Please use std::array instead.");
+  };
+
+  // The following specialisations of assert_is_valid_output_type are technically not
+  // necessary since we would hit the base case and show an error message
+  // there if they didn't exist, but we can show a better error message
+  // in some common error scenarios.
+  template
+  struct assert_is_valid_output_type::value>> {
+    // There is no reason to support float when we have double. Keep the API lean.
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported output type: float. Please use double instead.");
+  };
+  template
+  struct assert_is_valid_output_type::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead.");
+  };
+  template
+  struct assert_is_valid_output_type, T>::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported output type: vector. Please use List instead.");
+  };
+  template
+  struct assert_is_valid_output_type::value && !guts::typelist::contains::value>> {
+    static_assert(guts::false_t::value,
+      "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead.");
+  };
+
+  // ivalue_to_arg
+
+  template
+  struct decay_if_not_tensor final {
+    using type = std::decay_t;
+  };
+
+  template<>
+  struct decay_if_not_tensor final {
+    using type = at::Tensor&;
+  };
+
+  template<>
+  struct decay_if_not_tensor final {
+    using type = const at::Tensor&;
+  };
+
+  template
+  struct ivalue_to_arg final {
+    static decltype(auto) call(IValue& v) {
+      assert_is_valid_input_type();
+      return std::move(v).to();
+    }
+  };
+
+  // The following two specializations take advantage of specialized
+  // `toTensor()` overloads on IValue to avoid copying.
+  template
+  struct ivalue_to_arg final {
+    // We cannot use the default implementation if they asked for a
+    // `at::Tensor&` because it moves from the IValue, so it can't get
+    // an lvalue reference.
+    static at::Tensor& call(IValue& v) {
+      // Tensor& is valid, don't bother asserting
+      return v.toTensor();
+    }
+  };
+
+  template
+  struct ivalue_to_arg final {
+    // We should not use the default implementation if they asked for
+    // a `const at::Tensor&` because it moves from the IValue and they
+    // didn't ask for that.
+    static const at::Tensor& call(IValue& v) {
+      // const Tensor& is valid, don't bother asserting
+      return v.toTensor();
+    }
+  };
+
+  template
+  struct ivalue_to_arg final {
+    static List call(IValue& v) {
+      return v.toTensorList();
+    }
+  };
+
+  template
+  struct ivalue_to_arg, AllowDeprecatedTypes> final {
+    // If an argument is ArrayRef, convert the IValue to a std::vector and pass that
+    // to the operator. std::vector is implicitly convertible to ArrayRef.
+    static std::vector call(IValue& v) {
+      return ivalue_to_arg, AllowDeprecatedTypes>::call(v);
+    }
+  };
+  template
+  struct ivalue_to_arg final {
+    static std::vector call(IValue& v) {
+      if (v.isIntList()) {
+        std::vector r;
+        auto src = v.toIntList();
+        std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
+        return r;
+      } else {
+        return ivalue_to_arg, AllowDeprecatedTypes>::call(v);
+      }
+    }
+  };
+  template
+  struct ivalue_to_arg, AllowDeprecatedTypes> final {
+    static OptionalArray call(IValue& v) {
+      if (v.isIntList()) {
+        std::vector r;
+        auto src = v.toIntList();
+        std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
+        return OptionalArray(std::move(r));
+      } else {
+        return std::move(v).to>();
+      }
+    }
+  };
+  template
+  struct ivalue_to_arg>, AllowDeprecatedTypes> final {
+    // If an argument is optional>, convert the IValue to an optional> and pass that
+    // to the operator. OptionalArray is basically a optional> but implicitly convertible
+    // to optional>.
+    static OptionalArray call(IValue& v) {
+      return ivalue_to_arg, AllowDeprecatedTypes>::call(v);
+    }
+  };
+
+  template
+  struct ivalue_to_arg, AllowDeprecatedTypes> final {
+    // If an argument is OptionalArrayRef, convert the IValue to an
+    // optional> and pass that to the operator. OptionalArray
+    // is basically a optional> but implicitly convertible to
+    // OptionalArrayRef
+    static OptionalArray call(IValue& v) {
+      return ivalue_to_arg, AllowDeprecatedTypes>::call(v);
+    }
+  };
+
+  // return_to_ivalue
+  template
+  struct return_to_ivalue final {};
+
+  template
+  struct return_to_ivalue::value>> final {
+    static IValue call(T&& v) {
+      assert_is_valid_output_type();
+      return c10::ivalue::from(std::move(v));
+    }
+    static IValue copy(const T& v) {
+      assert_is_valid_output_type();
+      return IValue(v);
+    }
+  };
+
+  // Special case to allow kernels to return `Tensor&`.
+  // TODO Delete this once kernels don't do that anymore
+  template
+  struct return_to_ivalue final {
+    static IValue call(at::Tensor& v) {
+      return c10::ivalue::from(v);
+    }
+    static IValue copy(at::Tensor& v) {
+      return IValue(v);
+    }
+  };
+
+  // wrap_kernel_functor_unboxed_
+
+  template<class KernelFunctor, class OpSignature>
+  struct wrap_kernel_functor_unboxed_ final {};
+
+  // This specialization is for kernels with a first argument that is NOT of type DispatchKeySet
+  // This includes kernels with 0 arguments.
+  template<class KernelFunctor, class ReturnType, class... ParameterTypes>
+  struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(ParameterTypes...)> final {
+    static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
+      "Return type mismatch");
+    static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
+      "Parameter types mismatch");
+
+    // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
+    static ReturnType call(OperatorKernel* functor, DispatchKeySet, ParameterTypes... args) {
+      KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
+      // Note [Plumbing Keys Through The Dispatcher 2]
+      // See Note [Plumbing Keys Through The Dispatcher] for the background.
+      // This functor explicitly takes in a dispatchKeySet and drops it on the floor- it does not forward it to the registered kernel.
+      //
+      // This is due to the calling convention within the dispatcher, which expects all registered kernels to have a first argument of type
+      // DispatchKeySet.
+      // This is not the case for pretty much all manually written kernels, however- this functor serves to separate the calling convention
+      // of the dispatcher from the calling convention of manually written kernels.
+      return (*functor_)(std::forward<ParameterTypes>(args)...);
+    }
+  };
+
+  // This specialization is for kernels with a first argument of type DispatchKeySet
+  template<class KernelFunctor, class ReturnType, class... ParameterTypes>
+  struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(DispatchKeySet, ParameterTypes...)> final {
+    static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
+      "Return type mismatch");
+    static_assert(std::is_same<guts::typelist::typelist<DispatchKeySet, ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
+      "Parameter types mismatch");
+
+    // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
+    static ReturnType call(OperatorKernel* functor, DispatchKeySet dispatchKeySet, ParameterTypes... args) {
+      KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
+      // We're explicitly taking in a dispatchKeySet and forwarding it to the registered kernel.
+      // See Note [Plumbing Keys Through The Dispatcher 2] for details.
+      return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
+    }
+  };
+
+  template<class KernelFunctor>
+  using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<KernelFunctor, typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
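+
+  // Illustrative sketch (not part of the upstream header): a functor that opts in
+  // to receiving the DispatchKeySet, matched by the second specialization above:
+  //
+  // > struct MyRedispatchingKernel final : c10::OperatorKernel {
+  // >   at::Tensor operator()(c10::DispatchKeySet ks, const at::Tensor& a) {
+  // >     // may use `ks` to redispatch below the current dispatch key
+  // >     return a;
+  // >   }
+  // > };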
+
+  // call_functor_with_args_from_stack
+
+  template
+  std::decay_t::return_type>
+  call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence, guts::typelist::typelist*) {
+    (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
+
+    // We're explicitly filtering out DispatchKeySet from the argument list.
+    // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
+    // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
+    // See Note [Plumbing Keys Through The Dispatcher] for the background.
+    return wrap_kernel_functor_unboxed::call(functor, dispatchKeySet,
+      ivalue_to_arg::type, AllowDeprecatedTypes>::call(
+        torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
+    )...);
+  }
+
+  template
+  std::decay_t::return_type>
+  call_functor_with_args_from_stack(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack) {
+    // We're explicitly filtering out DispatchKeySet from the argument list.
+    // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
+    // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
+    // See Note [Plumbing Keys Through The Dispatcher] for the background.
+    using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func::parameter_types;
+    constexpr size_t num_ivalue_args = guts::typelist::size::value;
+    return call_functor_with_args_from_stack_(functor, dispatchKeySet, stack, std::make_index_sequence(), static_cast(nullptr));
+  }
+
+  // push_outputs
+
+  template
+  struct push_outputs final {
+    // Contrary to [Note: Argument forwarding in the dispatcher], we use OutputType&& here
+    // to avoid one extra call to the move constructor in this case. This is still not a
+    // universal reference though because OutputType is an explicitly specified class
+    // template parameter.
+    static void call(OutputType&& output, Stack* stack) {
+      torch::jit::push(*stack, return_to_ivalue::call(std::forward(output)));
+    }
+    static void copy(const OutputType& output, Stack* stack) {
+      torch::jit::push(*stack, return_to_ivalue::copy(output));
+    }
+  };
+  template
+  struct push_outputs, AllowDeprecatedTypes> final {
+    static void call(std::tuple&& output, Stack* stack) {
+      call_(std::move(output), stack, std::make_index_sequence());
+    }
+    static void copy(const std::tuple& output, Stack* stack) {
+      copy_(output, stack, std::make_index_sequence());
+    }
+
+  private:
+    template
+    static void call_(std::tuple&& output, Stack* stack, std::index_sequence) {
+      torch::jit::push(*stack, return_to_ivalue::call(std::forward(std::get(output)))...);
+    }
+    template
+    static void copy_(const std::tuple& output, Stack* stack, std::index_sequence) {
+      torch::jit::push(*stack, return_to_ivalue::copy(std::get(output))...);
+    }
+  };
+  template
+  struct push_outputs final {
+    static void call(int /*dummy*/, Stack* /*stack*/) {
+    }
+    static void copy(int /*dummy*/, Stack* /*stack*/) {
+    }
+  };
+
+  // make_boxed_from_unboxed_functor
+
+  template
+  struct make_boxed_from_unboxed_functor final {
+    static_assert(std::is_base_of::value,
+      "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+
+    static void call(OperatorKernel* functor, const OperatorHandle&, DispatchKeySet dispatchKeySet, Stack* stack) {
+      using ReturnType = typename guts::infer_function_traits_t::return_type;
+      // We're explicitly filtering out DispatchKeySet from the argument list.
+      // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
+      // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
+      // See Note [Plumbing Keys Through The Dispatcher] for the background.
+      using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func::parameter_types;
+      constexpr bool has_outputs = !std::is_same::value;
+      constexpr size_t num_inputs = guts::typelist::size::value;
+      if constexpr (has_outputs) {
+        // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value
+        // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`.
+        // [Note: VC++ and 'std': ambiguous symbol]
+        using ReturnType_ = ::std::decay_t;
+        ReturnType_ output = call_functor_with_args_from_stack(functor, dispatchKeySet, stack);
+        torch::jit::drop(*stack, num_inputs);
+        // See note [ VC++ and 'std': ambiguous symbol]
+        push_outputs::call(::std::move(output), stack);
+      } else {
+        call_functor_with_args_from_stack(functor, dispatchKeySet, stack);
+        torch::jit::drop(*stack, num_inputs);
+      }
+    }
+  };
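+
+  // Illustrative sketch (not part of the upstream header) of the boxed calling
+  // convention this wrapper implements, for a kernel at::Tensor(Tensor, Scalar):
+  //
+  // > // stack before call: [self, alpha]   (num_inputs == 2)
+  // > // stack after call:  [result]        (inputs dropped, one output pushed)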
+} // namespace impl
+
+} // namespace c10
+
+namespace torch {
+  using OperatorKernel = c10::OperatorKernel;
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h
new file mode 100644
index 0000000000000000000000000000000000000000..7d6d8134698c649114a31f5bed05419e51a1d7fa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/boxing/impl/test_helpers.h
@@ -0,0 +1,124 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+template
+inline std::vector makeStack(Inputs&&... inputs) {
+  return {std::forward(inputs)...};
+}
+
+inline at::Tensor dummyTensor(c10::DispatchKeySet ks, bool requires_grad=false) {
+  auto* allocator = c10::GetCPUAllocator();
+  int64_t nelements = 1;
+  auto dtype = caffe2::TypeMeta::Make();
+  int64_t size_bytes = nelements * dtype.itemsize();
+  auto storage_impl = c10::make_intrusive(
+      c10::StorageImpl::use_byte_size_t(),
+      size_bytes,
+      allocator->allocate(size_bytes),
+      allocator,
+      /*resizable=*/true);
+  at::Tensor t = at::detail::make_tensor(storage_impl, ks, dtype);
+  // TODO: We add this to simulate the ideal case where we only have Autograd backend keys
+  //       on Tensor when it requires grad. But currently Autograd keys are added in TensorImpl
+  //       constructor by default.
+  if (!requires_grad) {
+    t.unsafeGetTensorImpl()->remove_autograd_key();
+  }
+  return t;
+}
+
+inline at::Tensor dummyTensor(c10::DispatchKey dispatch_key, bool requires_grad=false) {
+  return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
+}
+
+template
+inline std::vector callOp(const c10::OperatorHandle& op, Args... args) {
+  auto stack = makeStack(std::forward(args)...);
+  op.callBoxed(&stack);
+  return stack;
+}
+
+template
+inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
+  return op.typed().call(std::forward(args)...);
+}
+
+template
+inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::DispatchKey dispatchKey, Args... args) {
+  return op.typed().callWithDispatchKey(dispatchKey, std::forward(args)...);
+}
+
+template
+inline Result callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle& op, c10::DispatchKeySet ks, Args... args) {
+  return op.typed().redispatch(ks, std::forward(args)...);
+}
+
+inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) {
+  auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
+  EXPECT_ANY_THROW(
+    callOp(*op, dummyTensor(dispatch_key), 5);
+  );
+}
+
+inline void expectDoesntFindOperator(const char* op_name) {
+  auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
+  EXPECT_FALSE(op.has_value());
+}
+
+template
+inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
+  try {
+    std::forward(functor)();
+  } catch (const Exception& e) {
+    EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
+    return;
+  }
+  ADD_FAILURE() << "Expected to throw exception containing \""
+    << expectMessageContains << "\" but didn't throw";
+}
+
+template
+void expectListEquals(c10::ArrayRef expected, std::array actual) {
+  EXPECT_EQ(expected.size(), actual.size());
+  for (const auto i : c10::irange(expected.size())) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+template
+void expectListEquals(c10::ArrayRef expected, c10::ArrayRef actual) {
+  EXPECT_EQ(expected.size(), actual.size());
+  for (const auto i : c10::irange(expected.size())) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+template
+void expectListEquals(c10::ArrayRef expected, c10::List actual) {
+  EXPECT_EQ(expected.size(), actual.size());
+  for (const auto i : c10::irange(expected.size())) {
+    EXPECT_EQ(expected[i], actual.get(i));
+  }
+}
+
+template
+void expectListEquals(c10::ArrayRef expected, std::vector actual) {
+  EXPECT_EQ(expected.size(), actual.size());
+  for (const auto i : c10::irange(expected.size())) {
+    EXPECT_EQ(expected[i], actual[i]);
+  }
+}
+
+// NB: This is not really sound, but all of the type sets constructed here
+// are singletons so it's fine
+static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
+  return legacyExtractDispatchKey(t.key_set());
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/builtin_function.h b/MLPY/Lib/site-packages/torch/include/ATen/core/builtin_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..19b1f2d579d7acde75f22cff1f1ba6d2bf318725
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/builtin_function.h
@@ -0,0 +1,88 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace torch {
+namespace jit {
+
+struct BuiltinOpFunction : public Function {
+  BuiltinOpFunction(
+      c10::QualifiedName qualname,
+      c10::FunctionSchema schema,
+      std::function<void(Stack&)> callable,
+      std::string doc_string = "")
+      : name_(std::move(qualname)),
+        callable_(std::move(callable)),
+        schema_(std::move(schema)),
+        doc_string_(std::move(doc_string)) {
+    TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
+  }
+
+  c10::string_view doc_string() const override {
+    return doc_string_;
+  }
+
+  void run(Stack& stack) override {
+    callable_(stack);
+  }
+
+  c10::intrusive_ptr runAsync(
+      Stack& stack,
+      TaskLauncher /* not used */) override {
+    run(stack);
+    auto res = c10::make_intrusive(stack.front().type());
+    res->markCompleted(std::move(stack.front()));
+    return res;
+  }
+
+  const c10::QualifiedName& qualname() const override {
+    return name_;
+  }
+
+  // if this isn't yet defined, run its method_creator function
+  void ensure_defined() override {
+    // nop
+  }
+
+  const c10::FunctionSchema& getSchema() const override {
+    return schema_;
+  }
+
+  size_t num_inputs() const override {
+    return schema_.arguments().size();
+  }
+
+  Function& setSchema(c10::FunctionSchema schema) override {
+    schema_ = std::move(schema);
+    return *this;
+  }
+
+  bool call(Stack& stack, c10::optional, c10::function_ref) override {
+    run(stack);
+    return false;
+  }
+
+  bool call(Stack& stack, c10::function_ref) override {
+    run(stack);
+    return false;
+  }
+
+  ~BuiltinOpFunction() override = default;
+
+ private:
+  c10::QualifiedName name_;
+
+  std::function<void(Stack&)> callable_;
+
+  c10::FunctionSchema schema_;
+
+  std::string doc_string_;
+};
+
+} // namespace jit
+} // namespace torch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/class_type.h b/MLPY/Lib/site-packages/torch/include/ATen/core/class_type.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd40c149e784ba7a593a5314e381a4c35db0a73d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/class_type.h
@@ -0,0 +1,441 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+
+namespace torch {
+namespace jit {
+struct CompilationUnit;
+struct Function;
+} // namespace jit
+} // namespace torch
+
+namespace c10 {
+
+struct FunctionSchema;
+
+// This enumerator represents the 'kind' of an attribute - a buffer, a parameter, or neither.
+// This state is mutually exclusive. Buffers and Parameters can only appear on modules.
+enum class AttributeKind {
+  BUFFER,
+  PARAMETER,
+  REGULAR_ATTRIBUTE
+};
+
+// This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr).
+// Note: This structure does not represent the value of the attribute.
+struct TORCH_API ClassAttribute {
+  public:
+  ClassAttribute(AttributeKind kind,
+  TypePtr attributeType,
+  std::string attributeName) :
+    kind_(kind),
+    attributeType_(std::move(attributeType)),
+    attributeName_(std::move(attributeName)) {}
+
+  AttributeKind getKind() const {
+    return kind_;
+  }
+
+  const TypePtr& getType() const {
+    return attributeType_;
+  }
+
+  const std::string& getName() const {
+    return attributeName_;
+  }
+
+  private:
+  AttributeKind kind_;
+  TypePtr attributeType_;
+  std::string attributeName_;
+};
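+
+// Illustrative sketch (not part of the upstream header): describing a module
+// parameter named "weight" of Tensor type:
+//
+// > ClassAttribute attr(AttributeKind::PARAMETER, TensorType::get(), "weight");
+// > TORCH_INTERNAL_ASSERT(attr.getKind() == AttributeKind::PARAMETER);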
+
+/**
+ * User Defined Types
+ */
+
+struct ClassType;
+using ClassTypePtr = std::shared_ptr;
+using ::torch::jit::CompilationUnit;
+
+// This represents a class in TorchScript.
+struct TORCH_API ClassType : public NamedType {
+  // This represents an attribute of a class; a name associated with an attribute, and a
+  // getter and (optional) setter for that attribute.
+  struct Property {
+    std::string name;
+    torch::jit::Function* getter;
+    torch::jit::Function* setter;
+  };
+
+  // Create a class type with name `name` and its methods stored in `cu`.
+  static ClassTypePtr create(
+      c10::optional qualifiedName,
+      std::weak_ptr cu,
+      bool is_module = false,
+      std::string doc_string = "",
+      std::vector unresolved_class_attributes = {});
+
+  bool equals(const Type& rhs) const override {
+    if (this == &rhs) {
+      return true;
+    }
+    if (auto user_rhs = rhs.castRaw()) {
+      const auto& lhs_name = name().value();
+      const auto& rhs_name = user_rhs->name().value();
+
+      return lhs_name == rhs_name &&
+          this->compilation_unit() == user_rhs->compilation_unit();
+    }
+    return false;
+  }
+
+  std::string str() const override {
+     return annotation_str();
+  }
+
+  std::string repr_str() const override {
+    std::stringstream ss;
+    ss << str()
+       << " (of Python compilation unit at: " << compilation_unit().get() << ")";
+    return ss.str();
+  }
+
+  const std::vector& methods() const;
+
+  TypePtr findAttribute(const std::string& name) const {
+    size_t pos = 0;
+    for (const auto& attr : attributes_) {
+      if (name == attr.getName()) {
+        break;
+      }
+      ++pos;
+    }
+
+    if (pos >= attributes_.size()) {
+      return nullptr;
+    }
+    return attributes_[pos].getType();
+  }
+
+  const TypePtr& getAttribute(const std::string& name) const {
+    auto slot = findAttributeSlot(name);
+    TORCH_CHECK(
+        slot,
+        repr_str(),
+        " does not have an attribute with name '",
+        name,
+        "'");
+    return attributes_[*slot].getType();
+  }
+
+  size_t numAttributes() const {
+    return attributes_.size();
+  }
+
+  const TypePtr& getAttribute(size_t slot) const {
+    AT_ASSERT(slot < attributes_.size());
+    return attributes_.at(slot).getType();
+  }
+
+  const std::string getAttributeName(size_t slot) const {
+    AT_ASSERT(slot < attributes_.size());
+    return attributes_[slot].getName();
+  }
+
+  void checkNotExist(const std::string& name, const std::string& what) const;
+
+  // Attributes are stored in a specific slot at runtime for efficiency.
+  // When emitting instructions we specify the slot so that attribute access is
+  // a constant-time lookup.
+  c10::optional<size_t> findAttributeSlot(const std::string& name) const {
+    size_t slot = 0;
+    for (const auto& attr : attributes_) {
+      if (name == attr.getName()) {
+        return slot;
+      }
+      slot++;
+    }
+    return c10::nullopt;
+  }
+  size_t getAttributeSlot(const std::string& name) const {
+    if (auto r = findAttributeSlot(name)) {
+      return *r;
+    }
+    TORCH_CHECK(
+        false,
+        repr_str(),
+        " does not have an attribute with name '",
+        name,
+        "'");
+  }
+
+  bool hasAttribute(const std::string& name) const {
+    return std::find_if(
+               attributes_.cbegin(),
+               attributes_.cend(),
+               [&](const ClassAttribute& attr) { return attr.getName() == name; }) !=
+        attributes_.cend();
+  }
+
+  bool isUnresolvedClassAttribute(const std::string& name) const;
+
+  at::ArrayRef<TypePtr> containedTypes() const override {
+    return attributeTypes_;
+  }
+
+  size_t addAttribute(
+      const std::string& name,
+      TypePtr type,
+      bool is_parameter = false,
+      bool is_buffer = false);
+
+  // [Internal Only] Remove an attribute from the ClassType.
+  // The caller is responsible for making sure the modification is safe:
+  // it is unsafe to keep existing allocations
+  // of this object around, and any code that works on
+  // the attribute is now invalid. Only newly created code is
+  // valid again.
+  void unsafeRemoveAttribute(const std::string& name);
+
+  // [Internal Only] Change the type of an attribute of the ClassType,
+  // The caller is responsible to make sure the modification is safe:
+  // it is unsafe to maintain uses of the old type of the attribute,
+  // and any code that works on the attribute is now invalid.
+  // Only newly created code is valid again.
+  void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty);
+
+  // Add attribute \p NAME if it doesn't exist or verify that it has a
+  // compatible type otherwise.
+  size_t addOrCheckAttribute(
+      const std::string& name,
+      TypePtr ty,
+      bool is_parameter = false,
+      bool is_buffer = false) {
+    auto slot_idx = findAttributeSlot(name);
+    if (!slot_idx) {
+      return addAttribute(name, std::move(ty), is_parameter, is_buffer);
+    }
+
+    TORCH_CHECK(
+        is_parameter == this->is_parameter(*slot_idx),
+        "Parameter field mismatch for the field '",
+        name,
+        "'");
+    const TypePtr& atype = getAttribute(*slot_idx);
+    TORCH_CHECK(
+      ty->isSubtypeOf(*atype),
+      ty->repr_str(),
+      " is not compatible with the type ",
+      atype->repr_str(),
+      " for the field '",
+      name,
+      "'");
+    return *slot_idx;
+  }
+
+  // Get the property with the given \p name, if it exists on the class.
+  c10::optional<Property> getProperty(const std::string& name);
+  // Add a property named \p name with \p getter and \p setter as its getter and setter.
+  void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter);
+  // Get a list of all properties.
+  const std::vector<Property>& properties() const {
+    return properties_;
+  }
+
+  bool hasConstant(const std::string& name) const {
+    return std::find_if(
+               constantNames_.cbegin(),
+               constantNames_.cend(),
+               [&](const std::string& constant) { return constant == name; }) !=
+        constantNames_.cend();
+  }
+
+  size_t addConstant(const std::string& name, const IValue& value);
+
+  c10::optional<size_t> findConstantSlot(const std::string& name) const;
+
+  size_t getConstantSlot(const std::string& name) const {
+    if (auto r = findConstantSlot(name)) {
+      return *r;
+    }
+    TORCH_CHECK(
+        false,
+        repr_str(),
+        " does not have constant field with the name '",
+        name,
+        "'");
+  }
+
+  const std::string& getConstantName(size_t slot) const;
+
+  const std::string& doc_string() const {
+    return doc_string_;
+  }
+
+  IValue getConstant(const std::string& name) const;
+
+  IValue getConstant(size_t slot) const;
+
+  c10::optional<IValue> findConstant(const std::string& name) const;
+
+  size_t numConstants() const;
+
+  at::ArrayRef<std::string> constantNames() const {
+    return constantNames_;
+  }
+
+  at::ArrayRef<IValue> constantValues() const;
+
+  // [Internal Only] Remove a constant from the ClassType.
+  // The caller is responsible for making sure the modification is safe:
+  // it is unsafe to keep existing allocations
+  // of this object around, and any code that works on
+  // the constant is now invalid. Only newly created code is
+  // valid again.
+  void unsafeRemoveConstant(const std::string& name);
+
+  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+    auto ptr = ClassType::create(name(), compilation_unit_, is_module());
+    AT_ASSERT(numAttributes() == contained_types.size());
+    for(size_t i = 0; i < attributes_.size(); ++i) {
+      AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i]));
+      ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i]));
+    }
+    // Copy methods over
+    for (const auto& method : methods()) {
+      ptr->addMethod(method);
+    }
+    return ptr;
+  }
+
+  bool is_module() const override {
+    return isModule_;
+  }
+
+  const std::vector<ClassAttribute>& getAttributes() const {
+    return attributes_;
+  }
+
+  bool is_parameter(size_t slot) const {
+    TORCH_INTERNAL_ASSERT(
+        is_module(), "asking for parameterSlots of non-Module");
+    return attributes_.at(slot).getKind() == AttributeKind::PARAMETER;
+  }
+
+  bool is_buffer(size_t slot) const {
+    TORCH_INTERNAL_ASSERT(
+        is_module(), "asking for bufferWrittenSlots of non-Module");
+    return attributes_.at(slot).getKind() == AttributeKind::BUFFER;
+  }
+
+  void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
+  void addForwardHook(torch::jit::Function* hook_ptr);
+  torch::jit::Function* findForwardPreHook(const std::string& name) const;
+  torch::jit::Function* findForwardHook(const std::string& name) const;
+  const std::vector<torch::jit::Function*>& getForwardHooks() const;
+  const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
+
+  void checkForwardPreHookSchema(
+      int pre_hook_idx,
+      const FunctionSchema& pre_hook_schema) const;
+  void checkForwardHookSchema(
+      int hook_idx,
+      const FunctionSchema& hook_schema) const;
+
+  void addMethod(torch::jit::Function* method);
+  torch::jit::Function* findMethod(const std::string& name) const;
+  torch::jit::Function& getMethod(const std::string& name) const;
+  torch::jit::Function* findHook(const std::string& name) const;
+  torch::jit::Function& getHook(const std::string& name) const;
+  bool hasMethod(const std::string& name) const;
+
+  torch::jit::Function* findStaticMethod(const std::string& name) const;
+  void addStaticMethod(torch::jit::Function* method);
+
+  // [Internal Only] Remove a method from the ClassType.
+  // The caller is responsible for making sure the modification is safe:
+  // it is unsafe to keep existing allocations
+  // of this object around, and any code that works on
+  // the method is now invalid. Only newly created code is
+  // valid again.
+  // Note this method is intended for freezing only.
+  void unsafeRemoveMethod(const std::string& name);
+
+  std::shared_ptr<CompilationUnit> compilation_unit();
+
+  std::shared_ptr<const CompilationUnit> compilation_unit() const;
+
+  // Generate a refined version of this class.
+  // It has the same name but the slot Types are subtypes of
+  // the original slots. It is only valid to refine a class type in a context
+  // where it is known that there are no assignments to the object's slots
+  // that would invalidate the refinement.
+  // These variants are not registered in the global class table.
+  ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  static const TypeKind Kind = TypeKind::ClassType;
+
+ private:
+  ClassType(
+      c10::optional<QualifiedName> name,
+      std::weak_ptr<CompilationUnit> cu,
+      bool is_module = false,
+      std::string doc_string = "",
+      std::vector<std::string> unresolved_class_attributes = {});
+
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    const auto& n = name().value();
+    return n.qualifiedName();
+  }
+
+  void addAttribute(ClassAttribute classAttribute);
+  std::string getForwardPreHookErrorMessage(int pre_hook_idx) const;
+  std::string getForwardHookErrorMessage(int hook_idx) const;
+
+  // Mapping of attribute names -> their type.
+  // NOTE: this does not contain methods, which are stored in the module
+  // TODO: once modules support arbitrary ivalue attributes, we don't need this
+  // anymore.
+  // TODO: This is better represented as an OrderedDict, but alas it is not yet
+  // available from c10
+
+  // Mapping of constant names -> their value.
+  std::vector<std::string> constantNames_;
+  std::vector<IValue> constantValues_;
+  // Holds method attributes
+  std::weak_ptr<CompilationUnit> compilation_unit_;
+
+  // Holds all attributes; attribute details are found on ClassAttribute.
+  std::vector<ClassAttribute> attributes_;
+  // Mirrors attributes_; it only exists because the `containedTypes()` method returns an ArrayRef.
+  // Never fill this without using the appropriate provideNewClassAttribute method.
+  std::vector<TypePtr> attributeTypes_;
+
+  // List of methods associated with this class.
+  std::vector<torch::jit::Function*> methods_;
+  std::vector<torch::jit::Function*> staticmethods_;
+
+  // List of hooks to be run before/after forward.
+  std::vector<torch::jit::Function*> forward_hooks_;
+  std::vector<torch::jit::Function*> forward_pre_hooks_;
+
+  // List of properties exposed by this class.
+  std::vector<Property> properties_;
+
+  bool isModule_ = false;
+
+  // Doc string of class.
+  std::string doc_string_ = "";
+
+  // For error reporting accesses to class level attributes.
+  std::vector<std::string> unresolved_class_attributes_;
+};
+
+}
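+
+// Example (illustrative sketch; `cu` and the class/attribute names below are
+// assumptions, not taken from this header). Given a
+// shared_ptr<torch::jit::CompilationUnit> `cu`, a module type with one
+// parameter attribute could be built and queried by slot like this:
+//
+//   auto cls = c10::ClassType::create(
+//       c10::QualifiedName("__torch__.MyModule"), cu, /*is_module=*/true);
+//   cls->addAttribute("weight", c10::TensorType::get(), /*is_parameter=*/true);
+//   size_t slot = cls->getAttributeSlot("weight");    // constant-time lookup
+//   const c10::TypePtr& ty = cls->getAttribute(slot); // the TensorType above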
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/custom_class.h b/MLPY/Lib/site-packages/torch/include/ATen/core/custom_class.h
new file mode 100644
index 0000000000000000000000000000000000000000..601af3eb48c1222edfb07a5f98b3d02f4e4c5a57
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/custom_class.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+struct ClassType;
+using ClassTypePtr = std::shared_ptr<ClassType>;
+
+TORCH_API c10::ClassTypePtr getCustomClassTypeImpl(const std::type_index &tindex);
+
+template <typename T>
+const c10::ClassTypePtr& getCustomClassType() {
+  // Classes are never unregistered from getCustomClassTypeMap and the
+  // hash lookup can be a hot path, so just cache.
+  // For the same reason, it's fine if this ends up getting duplicated across
+  // DSO boundaries for whatever reason.
+  static c10::ClassTypePtr cache = getCustomClassTypeImpl(
+      std::type_index(typeid(T)));
+  return cache;
+}
+
+}
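+
+// Example (illustrative sketch; `MyQueue` is a hypothetical custom class, not
+// defined here). Assuming it was registered as a TorchBind custom class, its
+// cached ClassType can be looked up through the intrusive_ptr type:
+//
+//   const c10::ClassTypePtr& t =
+//       c10::getCustomClassType<c10::intrusive_ptr<MyQueue>>();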
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ffa8df7e7bef13ff96e5473591e529845409b11
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/CppSignature.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+namespace impl {
+
+// A CppSignature object holds RTTI information about a C++ function signature at runtime
+// and can compare them or get a debug-printable name.
+class TORCH_API CppSignature final {
+public:
+    CppSignature(const CppSignature&) = default;
+    CppSignature(CppSignature&&) noexcept = default;
+    CppSignature& operator=(const CppSignature&) = default;
+    CppSignature& operator=(CppSignature&&) noexcept = default;
+
+    template<class FuncType>
+    static CppSignature make() {
+        // Normalize functors, lambdas, function pointers, etc. into the plain function type
+        // The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
+        // We do this to guarantee that all CppSignature's for an operator will match, even if they're registered
+        // with different calling conventions.
+        // See Note [Plumbing Keys Through The Dispatcher]
+        using decayed_function_type = typename c10::remove_DispatchKeySet_arg_from_func<std::decay_t<FuncType>>::func_type;
+
+        return CppSignature(std::type_index(typeid(decayed_function_type)));
+    }
+
+    std::string name() const {
+        return c10::demangle(signature_.name());
+    }
+
+    friend bool operator==(const CppSignature& lhs, const CppSignature& rhs) {
+        if (lhs.signature_ == rhs.signature_) {
+            return true;
+        }
+        // Without RTLD_GLOBAL, the type_index comparison could yield false because
+        // they point to different instances of the RTTI data, but the types would
+        // still be the same. Let's check for that case too.
+        // Note that there still is a case where this might not work, i.e. when
+        // linking libraries of different compilers together, they might have
+        // different ways to serialize a type name. That, together with a missing
+        // RTLD_GLOBAL, would still fail this.
+        if (0 == strcmp(lhs.signature_.name(), rhs.signature_.name())) {
+            return true;
+        }
+
+        return false;
+    }
+
+private:
+    explicit CppSignature(std::type_index signature): signature_(std::move(signature)) {}
+    std::type_index signature_;
+};
+
+inline bool operator!=(const CppSignature& lhs, const CppSignature& rhs) {
+    return !(lhs == rhs);
+}
+
+}
+}
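+
+// Example (illustrative sketch): assuming the normalization described in
+// make() above, a lambda and the plain function type it models compare equal:
+//
+//   auto lambda = [](int a, int b) -> int { return a + b; };
+//   auto sig_functor = c10::impl::CppSignature::make<decltype(lambda)>();
+//   auto sig_plain   = c10::impl::CppSignature::make<int(int, int)>();
+//   TORCH_INTERNAL_ASSERT(sig_functor == sig_plain);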
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h
new file mode 100644
index 0000000000000000000000000000000000000000..8c39b5f0a4bf12244d0d1b38ba4e2f33d72463ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -0,0 +1,242 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+namespace impl {
+
+// Take a DispatchKeySet for a Tensor and determine what the actual dispatch
+// DispatchKey should be, taking into account TLS, and skipping backends which
+// fall through.
+//
+// Unlike Tensor::key_set(), the value of this on a tensor can change depending
+// on TLS.
+//
+// NB: If there is no valid dispatch key, this will return Undefined
+static inline DispatchKeySet computeDispatchKeySet(
+    DispatchKeySet ks,
+    // The key mask lets us eliminate (by zero entries) keys which should not
+    // be considered for dispatch.  There are two cases when we use this:
+    //
+    // - If an operator's dispatch table contains a fallthrough entry, we
+    //   should bypass it entirely when finding the key
+    // - If a user invokes with redispatch, the mask lets us
+    //   zero out the key at which the user asked us to stop.
+    //
+    // These excluded backends are NOT tracked in the TLS, but must be applied
+    // AFTER TLS (since the backend may have been introduced for consideration
+    // by the included TLS), which is why you have to pass them in to this
+    // function (as opposed to just applying it to the input 'ks').
+    DispatchKeySet key_mask
+) {
+  c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
+  // TODO: It's a bit irritating that we have to do logical ORs here, it would
+  // be nice to only do one.  Can always_included be folded into the TLS?  Well,
+  // it's a bit troublesome, because fastpath TLS access requires the type of
+  // the TLS in question to be zero-initialized, so you don't actually win
+  // anything in that case.
+  return (((ks | local.included_) - local.excluded_) & key_mask);
+}
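+
+// Worked example (illustrative; the key sets below are made up): if the
+// tensor's keys are {CPU, AutogradCPU}, the TLS excludes {AutogradCPU}
+// (e.g. inside an autograd-disabled region), nothing extra is included, and
+// the key mask is full, then
+//   ((ks | included) - excluded) & mask = {CPU}
+// so dispatch proceeds directly to the CPU kernel.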
+
+}
+
+namespace detail {
+  // A small gadget to extract the DispatchKeySet from types which are known
+  // to have it.  Used to extract dispatch keys from unboxed calls.
+  struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
+    DispatchKeySet ts;
+    void operator()(const at::Tensor& x) {
+      ts = ts | x.key_set();
+    }
+    void operator()(const c10::optional<at::Tensor>& x) {
+      if (x.has_value()) {
+        ts = ts | x->key_set();
+      }
+    }
+    void operator()(at::ArrayRef<at::Tensor> xs) {
+      for (const auto& x : xs) {
+        ts = ts | x.key_set();
+      }
+    }
+    // Tensor?[] translates to this case.
+    void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
+      for (c10::optional<at::Tensor> x : xs) {
+        if (x.has_value()) {
+          ts = ts | x.value().key_set();
+        }
+      }
+    }
+    // Structured Tensor[] translates to this case
+    void operator()(const at::ITensorListRef& xs) {
+      for (const auto& x : xs) {
+        ts = ts | x.key_set();
+      }
+    }
+    [[noreturn]] void operator()(at::ArrayRef<c10::optional<at::Tensor>>) {
+      // Just checking that the handling of Tensor?[] didn't change.
+      TORCH_INTERNAL_ASSERT(false);
+    }
+    void operator()(const at::Generator& gen) {
+      if (gen.defined()) {
+        ts = ts | gen.key_set();
+      }
+    }
+    void operator()(const c10::optional<at::Generator>& gen) {
+      if (gen.has_value() && gen->defined()) {
+        ts = ts | gen->key_set();
+      }
+    }
+    template <typename T>
+    void operator()(const T&) {
+      // do nothing
+    }
+  };
+
+  // NB: take by const reference (Don't do universal forwarding here! You
+  // don't want to move into this function!)
+  template <typename... Args>
+  DispatchKeySet multi_dispatch_key_set(const Args&... args) {
+    return MultiDispatchKeySet().apply(args...).ts;
+  }
+}
+
+/**
+ * An instance of DispatchKeyExtractor knows how to get a dispatch key given
+ * a list of arguments for an operator call.
+ *
+ * The instance is specific for a certain operator as:
+ *  - In boxed dispatch, different operators have different ways to extract
+ *    the dispatch key (e.g. different numbers of arguments), and we precompute
+ *    the stack locations we should look at; and
+ *  - In all dispatch, some backends should be excluded from dispatch because
+ *    they have been registered as fallthrough.  The set of excluded backends
+ *    varies from operator to operator, as some operators may have overridden the
+ *    fallthrough with custom behavior.
+ *
+ *   Note - this should maintain identical impl to the py dispatcher key extraction logic
+ *   at pytorch/torch/dispatcher.py
+ */
+struct TORCH_API DispatchKeyExtractor final {
+public:
+  static DispatchKeyExtractor make(const FunctionSchema& schema) {
+    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
+  }
+
+  static DispatchKeyExtractor makeUninitialized() {
+    return DispatchKeyExtractor(c10::utils::bitset());
+  }
+
+  void registerSchema(const FunctionSchema& schema) {
+    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
+    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
+  }
+  void deregisterSchema() {
+    dispatch_arg_indices_reverse_ = c10::utils::bitset();
+  }
+
+  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
+    DispatchKeySet ks;
+    dispatch_arg_indices_reverse_.for_each_set_bit([&] (size_t reverse_arg_index) {
+      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
+      if (C10_LIKELY(ivalue.isTensor())) {
+        // NB: Take care not to introduce a refcount bump (there's
+        // no safe toTensorRef method, alas)
+        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
+      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
+        for (const at::Tensor& tensor : ivalue.toTensorList()) {
+          ks = ks | tensor.key_set();
+        }
+      }
+      // Tensor?[] translates to a c10::List so we need to peek inside
+      else if (C10_UNLIKELY(ivalue.isList())) {
+        for (const auto& elt : ivalue.toListRef()) {
+          if (elt.isTensor()) {
+            ks = ks | elt.toTensor().key_set();
+          }
+        }
+      }
+    });
+    // Keys that are fallthrough should be skipped
+    if (requiresBitsetPerBackend_) {
+      auto backend_idx = ks.getBackendIndex();
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
+    } else {
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
+    }
+  }
+
+  template<class... Args>
+  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
+    auto ks = detail::multi_dispatch_key_set(args...);
+    // Keys that are fallthrough should be skipped
+    if (requiresBitsetPerBackend_) {
+      auto backend_idx = ks.getBackendIndex();
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
+    } else {
+      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
+    }
+  }
+
+  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);
+
+  std::string dumpState() const;
+  void checkInvariants(const FunctionSchema& schema) const;
+
+private:
+  static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
+    TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
+        "The function schema has ", schema.arguments().size(),
+        " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
+    c10::utils::bitset dispatch_arg_indices_reverse;
+    for (const auto index : c10::irange(schema.arguments().size())) {
+      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
+          schema.arguments()[index].type()->isSubtypeOf(
+              *ListType::ofTensors()) ||
+          schema.arguments()[index].type()->isSubtypeOf(
+              *ListType::ofOptionalTensors()) ||
+          schema.arguments()[index].type()->isSubtypeOf(
+              *OptionalType::ofTensor())) {
+        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
+      }
+    }
+    return dispatch_arg_indices_reverse;
+  }
+
+  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
+  : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
+  , nonFallthroughKeys_(DispatchKeySet::FULL)
+  , requiresBitsetPerBackend_(false) {
+    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
+      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
+    }
+  }
+
+  // this is a bitset that has ones for each argument index which has to be
+  // considered for dispatch. This avoids having to iterate over the stack
+  // to find all the tensors. The bits are stored in reverse order, i.e.
+  // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
+  // the top of the stack (i.e. the i-th last argument of the function)
+  // is relevant for dispatch.
+  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that just means you must do the
+  // fallthrough
+  c10::utils::bitset dispatch_arg_indices_reverse_;
+
+  // Set of functionality keys for which the operator does NOT have fallthrough kernel.
+  DispatchKeySet nonFallthroughKeys_;
+  // Set of functionality keys for which the operator does NOT have fallthrough kernel, defined PER BACKEND.
+  // This is only needed if we know that the operator has a different set of fallthroughs defined for some backends.
+  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
+  // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast path),
+  // or if we need to fall back to the slower path and check nonFallthroughKeysPerBackend_
+  bool requiresBitsetPerBackend_;
+};
+
+}
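+
+// Example (illustrative sketch; the schema below is hypothetical):
+//   my::op(Tensor self, int n, Tensor other) -> Tensor
+// makeBitsetForDispatchArgs() sets the bits for the reversed indices of the
+// two Tensor arguments (positions 2 and 0 counted from the top of the stack),
+// so getDispatchKeySetBoxed() only peeks at those two stack slots when
+// computing the dispatch key.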
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h
new file mode 100644
index 0000000000000000000000000000000000000000..dc08b28e9bd80cdb882c9c04dfdac12c45cf516d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/Dispatcher.h
@@ -0,0 +1,795 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#ifndef NDEBUG
+#include 
+#endif
+
+namespace c10 {
+
+TORCH_API bool show_dispatch_trace();
+TORCH_API void dispatch_trace_nesting_incr();
+TORCH_API void dispatch_trace_nesting_decr();
+TORCH_API int64_t dispatch_trace_nesting_value();
+
+struct DispatchTraceNestingGuard {
+  DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); }
+  ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); }
+};
+
+class TORCH_API OperatorHandle;
+template<class FuncType> class TypedOperatorHandle;
+
+/**
+ * Implement this interface and register your instance with the dispatcher
+ * to get notified when operators are registered or deregistered with
+ * the dispatcher.
+ *
+ * NB: registration events only occur when a 'def' occurs; we don't trigger
+ * on 'impl' or 'fallback' calls.
+ */
+class TORCH_API OpRegistrationListener {
+public:
+  virtual ~OpRegistrationListener();
+
+  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
+  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
+};
+
+namespace detail {
+class RegistrationListenerList;
+}
+class SchemaRegistrationHandleRAII;
+
+/**
+ * Top-level dispatch interface for dispatching via the dynamic dispatcher.
+ * Most end users shouldn't use this directly; if you're trying to register
+ * ops look in op_registration
+ */
+class TORCH_API Dispatcher final {
+private:
+  // For direct access to backend fallback information
+  friend class impl::OperatorEntry;
+
+  struct OperatorDef final {
+    explicit OperatorDef(OperatorName&& op_name)
+    : op(std::move(op_name)) {}
+
+    impl::OperatorEntry op;
+
+    // These refer to the number of outstanding RegistrationHandleRAII
+    // for this operator.  def_count reflects only def() registrations
+    // (in the new world, this should only ever be 1, but old style
+    // registrations may register the schema multiple times, which
+    // will increase this count).  def_and_impl_count reflects the number
+    // of combined def() and impl() registrations.  When the last def() gets
+    // unregistered, we must immediately call the Deregistered listeners, but we
+    // must not actually delete the handle as there are other outstanding RAII
+    // destructors which will try to destruct and they had better still have a
+    // working operator handle in this case
+    size_t def_count = 0;
+    size_t def_and_impl_count = 0;
+  };
+  friend class OperatorHandle;
+  template<class FuncType> friend class TypedOperatorHandle;
+
+  struct Guard final {
+    Guard() : alive(true), mutex() {}
+    std::atomic alive;
+    std::mutex mutex;
+  };
+
+public:
+  ~Dispatcher();
+
+  // Implementation note: this class abstracts over the fact that we have per-operator
+  // dispatch tables.  This could be easily adjusted to have a single global hash
+  // table.
+  static Dispatcher& realSingleton();
+
+  C10_ALWAYS_INLINE static Dispatcher& singleton() {
+#if !defined C10_MOBILE
+    // Implemented inline so that steady-state code needn't incur
+    // function-call overhead. We can't just inline `realSingleton`
+    // because the function-local static would get duplicated across
+    // all DSOs that include & use this header, leading to multiple
+    // singleton instances.
+    static Dispatcher& s = realSingleton();
+    return s;
+#else
+    // For C10_MOBILE, we should never inline a static function that
+    // has a static member, since the generated code calls
+    // __cxa_guard_acquire and __cxa_guard_release which help
+    // implement exactly once semantics for the initialization of the
+    // static Dispatcher& s above (for the non-mobile case). That
+    // additional code when duplicated across all operator stubs
+    // for every backend results in a lot of additional code
+    // being generated by the compiler.
+    return realSingleton();
+#endif
+  }
+
+  // ------------------------------------------------------------------------
+  //
+  // Accessing operators by schema
+  //
+  // ------------------------------------------------------------------------
+
+  /**
+   * Looks for an operator schema with the given name and overload name
+   * and returns it if it is registered WITH A SCHEMA.
+   * Returns nullopt otherwise.
+   */
+  c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
+
+  /**
+   * Variant of findSchema that results in less code generated at the call site.
+   * It (1) takes const char* pointer rather than OperatorName (so we skip
+   * generating std::string constructor calls at the call site), and (2)
+   * it raises an exception if the operator is not found (so we skip
+   * generating exception raising code at the call site)
+   *
+   * Irritatingly, we still have to generate the handful of instructions
+   * for dealing with an exception being thrown during static initialization
+   * (e.g. __cxa_guard_abort).  If we could annotate this method noexcept we
+   * could avoid this code too, but as the name of the function suggests,
+   * it does throw exceptions.
+   */
+  OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
+
+  // Like findSchema, but also returns OperatorHandle even if there is no schema
+  c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
+
+  // Returns a list of all operator names present in the operatorLookupTable_
+  const std::vector<OperatorName> getAllOpNames();
+
+  // ------------------------------------------------------------------------
+  //
+  // Invoking operators
+  //
+  // ------------------------------------------------------------------------
+
+  template<class Return, class... Args>
+  Return call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const;
+
+
+  template<class Return, class... Args>
+  static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
+
+  // Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculation.
+  // This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
+  // Note that this version of redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask out the highest priority key.
+  // See Note [Plumbing Keys Through The Dispatcher]
+  template<class Return, class... Args>
+  Return redispatch(const TypedOperatorHandle<Return(Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const;
+
+  // Invoke an operator via the boxed calling convention using an IValue stack
+  void callBoxed(const OperatorHandle& op, Stack* stack) const;
+  void callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const;
+
+  // TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
+  // See Note [Plumbing Keys Through The Dispatcher]
+  void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;
+
+  bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
+    auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
+    if (dispatch_ix < 0) return false;
+    return backendFallbackKernels_[dispatch_ix].kernel.isValid();
+  }
+
+  // Used by torchdeploy/multipy for multiple interpreters racing.
+  void waitForDef(const FunctionSchema& schema);
+  void waitForImpl(const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key);
+
+  // ------------------------------------------------------------------------
+  //
+  // Performing registrations (NON user public; use op_registration)
+  //
+  // ------------------------------------------------------------------------
+
+  /**
+   * Register a new operator schema.
+   *
+   * If a schema with the same operator name and overload name already exists,
+   * this function will check that both schemas are exactly identical.
+   */
+  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});
+
+  /**
+   * Register a kernel to the dispatch table for an operator.
+   * If dispatch_key is nullopt, then this registers a fallback kernel.
+   *
+   * @return A RAII object that manages the lifetime of the registration.
+   *         Once that object is destructed, the kernel will be deregistered.
+   */
+  // NB: steals the inferred function schema, as we may need to hold on to
+  // it for a bit until the real schema turns up
+  RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
+
+  /**
+   * Given an operator, tells the Dispatcher that we have implemented an abstract impl
+   * for this op in the given Python module. Call this a "pystub".
+   */
+  RegistrationHandleRAII registerAbstractImplPyStub(const OperatorName& op_name, const char* pymodule, const char* context);
+
+  /**
+   * Given an operator, throws if we have an abstract impl pystub.
+   */
+  void throwIfHasAbstractImplPyStub(OperatorName op_name);
+
+  c10::optional<std::pair<const char*, const char*>> getAbstractImplPyStub(OperatorName op_name);
+
+  /**
+   * Register a new operator by name.
+   */
+  RegistrationHandleRAII registerName(OperatorName op_name);
+
+  /**
+   * Register a fallback kernel for a backend.
+   * If an operator is called but there is no concrete kernel for the dispatch
+   * key of the given operator arguments, it will check if there is such a
+   * fallback kernel for the given dispatch key and, if yes, call that one.
+   */
+  RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);
+
+  /**
+   * Use to register whenever we had a TORCH_LIBRARY declaration in the frontend
+   * API.  These invocations are only permitted once per program, so we raise
+   * an error if this is called again for the same namespace.
+   */
+  RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
+
+  // ------------------------------------------------------------------------
+  //
+  // Listeners on registrations
+  //
+  // ------------------------------------------------------------------------
+
+  /**
+   * Add a listener that gets called whenever a new op is registered or an existing
+   * op is deregistered. Immediately after registering, this listener gets called
+   * for all previously registered ops, so it can be used to keep track of ops
+   * registered with this dispatcher.
+   */
+  RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
+
+  void checkInvariants() const;
+
+  //
+  // ------------------------------------------------------------------------
+  //
+  // Assertions
+  //
+  // ------------------------------------------------------------------------
+
+  /**
+   * For testing purposes.
+   * Returns a list of all operators that were created through calls to registerImpl(),
+   * without any corresponding calls to registerDef(). After static initialization
+   * is done this is almost certainly a bug, as the created OperatorHandle won't have
+   * any schema associated with it and users calling the op through the dispatcher
+   * won't be able to access it
+   *
+   * Note that we cannot enforce this invariant "as we go" during static initialization,
+   * due to undefined static initialization order - we have no guarantees over the order
+   * in which .def() and .impl() calls are registered in the dispatcher at static
+   * initialization time. So this function should only be called after static initialization.
+   */
+  std::vector<OperatorHandle> findDanglingImpls() const;
+
+  /**
+   * Useful for inspecting global Dispatcher registration state.
+   * Returns the names of all operators with a kernel registered for the specified DispatchKey.
+   * If no DispatchKey is specified, it returns all registered operators.
+   */
+  std::vector<OperatorName> getRegistrationsForDispatchKey(c10::optional<DispatchKey> k) const;
+
+private:
+  Dispatcher();
+
+  static int64_t sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey);
+  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey);
+  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, c10::ArrayRef<const c10::IValue> args);
+
+  #ifdef FBCODE_CAFFE2
+  static bool profilingOperatorEvents();
+  static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref);
+  static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref);
+  #endif // FBCODE_CAFFE2
+
+  OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
+  OperatorHandle findOrRegisterName_(const OperatorName& op_name);
+
+  void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
+  void deregisterImpl_(
+    const OperatorHandle& op,
+    const OperatorName& op_name,
+    c10::optional<DispatchKey> dispatch_key,
+    impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
+  void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
+  void deregisterFallback_(DispatchKey dispatchKey);
+  void deregisterLibrary_(const std::string& ns);
+  void cleanup(const OperatorHandle& op, const OperatorName& op_name);
+  void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);
+
+  std::list<OperatorDef> operators_;
+#if !defined(C10_MOBILE)
+  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
+#else
+  RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
+#endif
+  // Map from namespace to debug string (saying, e.g., where the library was defined)
+  ska::flat_hash_map<std::string, std::string> libraries_;
+
+  std::array backendFallbackKernels_;
+
+  std::unique_ptr<detail::RegistrationListenerList> listeners_;
+
+  // This condition variable gets notified whenever we add a new def/impl to the
+  // dispatch table.  This is primarily used by multipy/torchdeploy, when
+  // we have multiple interpreters trying to register to the dispatch table.
+  // In this situation, whenever the non-primary interpreter would have tried
+  // to register to the dispatch table, instead it will check to see if the
+  // expected registration has already been made, and if it hasn't, wait on
+  // this condition variable to see if it was just racing with the primary
+  // interpreter.
+  //
+  // We expect it to be rare for there to be any waiters on this condition
+  // variable.  This is mostly just to help give better diagnostics if
+  // something goes horribly wrong
+  std::condition_variable cond_var_;
+
+  // Protect concurrent access to the dispatcher.  We store this in a
+  // `shared_ptr` as we return callbacks that call back into dispatcher methods,
+  // and we need to be able to handle and guard against the event when the
+  // `Dispatcher` has been destroyed before the callbacks fire.
+  std::shared_ptr guard_;
+};
+
+/**
+ * This is a handle to an operator schema registered with the dispatcher.
+ * This handle can be used to register kernels with the dispatcher or
+ * to lookup a kernel for a certain set of arguments.
+ */
+class TORCH_API OperatorHandle {
+  template <typename T> friend struct std::hash;
+
+public:
+  OperatorHandle(OperatorHandle&&) noexcept = default;
+  OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
+  OperatorHandle(const OperatorHandle&) = default;
+  OperatorHandle& operator=(const OperatorHandle&) = default;
+  // NOLINTNEXTLINE(performance-trivially-destructible)
+  ~OperatorHandle();
+
+  const OperatorName& operator_name() const {
+    return operatorDef_->op.operator_name();
+  }
+
+  bool hasSchema() const {
+    return operatorDef_->op.hasSchema();
+  }
+
+  const FunctionSchema& schema() const {
+    return operatorDef_->op.schema();
+  }
+
+  const std::string& debug() const {
+    return operatorDef_->op.debug();
+  }
+
+  std::string dumpState() const {
+    return operatorDef_->op.dumpState();
+  }
+
+  bool hasKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.hasKernelForDispatchKey(k);
+  }
+
+  bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
+    return operatorDef_->op.hasKernelForAnyDispatchKey(k);
+  }
+
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
+  }
+
+  std::string dumpComputedTable() const {
+    return operatorDef_->op.dumpComputedTable();
+  }
+
+  void checkInvariants() const {
+    return operatorDef_->op.checkInvariants();
+  }
+
+  c10::ArrayRef<at::Tag> getTags() const {
+    return operatorDef_->op.getTags();
+  }
+
+  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
+    operatorDef_->op.setReportErrorCallback_(std::move(callback));
+  }
+
+  bool hasTag(const at::Tag& tag) const {
+    for(const auto& tag_: getTags()) {
+      if (tag == tag_) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  template<class FuncType>
+  TypedOperatorHandle<FuncType> typed() const {
+    // NB: This assert is not 100% sound: you can retrieve a typed() operator
+    // handle prior to ANY C++ signature being registered on the operator
+    // and the check will say everything is OK (at which point you can then
+    // smuggle in a kernel that is typed incorrectly).  For everything
+    // in core library this won't happen, because all the static registrations
+    // will be done by the time a typed() handle is acquired.
+#if !defined C10_MOBILE
+    operatorDef_->op.assertSignatureIsCorrect<FuncType>();
+    if (fn_has_symint<FuncType>::value) {
+      operatorDef_->op.assertSignatureIsCorrect<typename fn_remove_symint<FuncType>::type>();
+    }
+#endif
+    return TypedOperatorHandle<FuncType>(operatorIterator_);
+  }
+
+  void callBoxed(Stack* stack) const {
+    c10::Dispatcher::singleton().callBoxed(*this, stack);
+  }
+
+  void callBoxed(Stack& stack) const {
+    callBoxed(&stack);
+  }
+
+  void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
+    c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
+  }
+
+  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
+    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
+  }
+
+  template <typename F>
+  PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
+    return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
+  }
+
+  bool operator==(const OperatorHandle& other) const {
+    return operatorDef_ == other.operatorDef_;
+  }
+
+  bool operator!=(const OperatorHandle& other) const {
+    return operatorDef_ != other.operatorDef_;
+  }
+
+private:
+  explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
+  : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator)  {}
+  friend class Dispatcher;
+  template<class FuncType> friend class TypedOperatorHandle;
+
+  // Storing a direct pointer to the OperatorDef even though we
+  // already have the iterator saves an instruction in the critical
+  // dispatch path. The iterator is effectively a
+  // pointer-to-std::list-node, and (at least in libstdc++'s
+  // implementation) the element is at an offset 16 bytes from that,
+  // because the prev/next pointers come first in the list node
+  // struct. So, an add instruction would be necessary to convert from the
+  // iterator to an OperatorDef*.
+  Dispatcher::OperatorDef* operatorDef_;
+
+  // We need to store this iterator in order to make
+  // Dispatcher::cleanup() fast -- it runs a lot on program
+  // termination (and presumably library unloading).
+  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
+};
+
+/**
+ * This is a handle to an operator schema registered with the dispatcher.
+ * It holds the same information as an OperatorHandle, but it is templated
+ * on the operator arguments and allows calling the operator in an
+ * unboxed way.
+ */
+template<class FuncType>
+class TypedOperatorHandle final {
+  static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed was not a valid function type");
+};
+template<class Return, class... Args>
+class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
+public:
+  TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
+  TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
+  TypedOperatorHandle(const TypedOperatorHandle&) = default;
+  TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;
+
+  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
+  C10_ALWAYS_INLINE Return call(Args... args) const {
+    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
+  }
+
+  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
+  C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
+    return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
+  }
+
+private:
+  explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
+  : OperatorHandle(operatorIterator) {}
+  friend class OperatorHandle;
+};
+
+namespace detail {
+template<class... Args> inline void unused_arg_(const Args&...) {}
+
+// CaptureKernelCall is intended to capture return values from Dispatcher
+// unboxed kernel calls. A record function may request to get outputs from the
+// kernel calls. For boxed kernels, it's straightforward, the returned values
+// are in the stack object. The stack can be passed to record functions. For
+// unboxed kernels, we need to handle different kinds of return values, cache
+// them temporarily, then release the values for the actual function call
+// return.
+template <typename ReturnType>
+struct CaptureKernelCall {
+  template <typename F, typename... Args>
+  CaptureKernelCall(
+      const F& kernel,
+      const TypedOperatorHandle<ReturnType(Args...)>& op,
+      const DispatchKeySet& dispatchKeySet,
+      Args&&... args)
+      // Calls the kernel and captures the result in output_.
+      : output_{kernel.template call<ReturnType, Args...>(
+            op,
+            dispatchKeySet,
+            std::forward<Args>(args)...)} {}
+  // Wraps the return values in a Stack.
+  Stack getOutputs() {
+    Stack stack;
+    impl::push_outputs<ReturnType>::copy(output_, &stack);
+    return stack;
+  }
+  // Since we are returning the output_, we don't expect the output_ to be used
+  // afterward. Copy elision and RVO do not apply to class data members. Using
+  // move semantics to avoid copies when possible.
+  ReturnType release() && {
+    return std::move(output_);
+  }
+
+ private:
+  ReturnType output_;
+};
+
+// Handle the lvalue reference differently since it should not be moved.
+template <>
+inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
+  return output_;
+}
+
+// Handle case where the kernel returns void.
+template <>
+struct CaptureKernelCall<void> {
+  template <typename F, typename... Args>
+  CaptureKernelCall(
+      const F& kernel,
+      const TypedOperatorHandle<void(Args...)>& op,
+      const DispatchKeySet& dispatchKeySet,
+      Args&&... args) {
+    // Call the kernel; there is no return value to capture for void.
+    kernel.template call<void, Args...>(
+        op, dispatchKeySet, std::forward<Args>(args)...);
+  }
+  Stack getOutputs() {
+    return Stack();
+  }
+  void release() && {}
+};
+
+} // namespace detail
+
+// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
+template<class Return, class... Args>
+inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
+  // If callbacks need inputs, we box the arguments and pass them to the guard.
+  // Note: For perf reasons we wouldn't want to prematurely box the arguments.
+  at::RecordFunction guard(std::move(stepCallbacks));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
+  auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
+  auto& schema = op.schema();
+  auto schema_ref = std::reference_wrapper(schema);
+  constexpr auto num_boxed_args = impl::boxed_size<Args...>();
+  if constexpr (num_boxed_args != 0) {
+    if (guard.needsInputs()) {
+      // If we used std::array here, we would
+      // have to spend time default constructing the IValues in
+      // boxedArgs. aligned_storage has no such requirement.
+      impl::IValueAlignedStorage boxedArgs[num_boxed_args];
+      // For debugging only; could be removed (but the compiler will do
+      // that for us and it's nice to have the extra assurance of
+      // correctness from our debug builds).
+      int lastArgIdx = 0;
+      impl::boxArgsToStack(boxedArgs, lastArgIdx, args...);
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(lastArgIdx == num_boxed_args);
+      // I don't *think* we need std::launder here, because IValue has
+      // no subclasses and no const or reference fields.
+      runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(reinterpret_cast<IValue*>(boxedArgs), num_boxed_args));
+      for (size_t ii = 0; ii < num_boxed_args; ++ii) {
+        reinterpret_cast<IValue*>(&boxedArgs[ii])->~IValue();
+      }
+    } else {
+      runRecordFunction(guard, schema_ref, dispatchKey);
+    }
+  } else {
+    runRecordFunction(guard, schema_ref, dispatchKey);
+  }
+
+  if (C10_UNLIKELY(guard.needsOutputs())) {
+    // Calls the kernel and captures the output temporarily to pass to
+    // RecordFunction.
+    detail::CaptureKernelCall<Return> captureKernelCall(
+        kernel, op, dispatchKeySet, std::forward<Args>(args)...);
+    guard.setOutputs(captureKernelCall.getOutputs());
+    // Releases the captured output to return to caller.
+    return std::move(captureKernelCall).release();
+  }
+
+  // keeping the guard alive while executing the kernel
+  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
+}
+
+// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
+template<class Return, class... Args>
+C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
+  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
+  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
+    .template getDispatchKeySetUnboxed<Args...>(args...);
+#ifndef NDEBUG
+  DispatchTraceNestingGuard debug_guard;
+  if (show_dispatch_trace()) {
+      auto nesting_value = dispatch_trace_nesting_value();
+      for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
+      std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
+  }
+#endif
+  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
+#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
+  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
+  if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
+    return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
+  }
+#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
+
+#ifdef FBCODE_CAFFE2
+  if(profilingOperatorEvents()) {
+    struct FireOpRAII {
+       FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) : schema_ref_(schema_ref) {
+           fireOpStartUSDT(schema_ref);
+        }
+       ~FireOpRAII() { fireOpEndUSDT(schema_ref_); }
+       at::RecordFunction::schema_ref_t schema_ref_;
+    } event(op.schema());
+    return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
+  } else {
+    return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
+  }
+#else
+    return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
+#endif // FBCODE_CAFFE2
+}
+
+// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
+template<class Return, class... Args>
+inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return(Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
+  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
+  // do not use RecordFunction on redispatch
+#ifndef NDEBUG
+  DispatchTraceNestingGuard debug_guard;
+  if (show_dispatch_trace()) {
+      auto nesting_value = dispatch_trace_nesting_value();
+      for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
+      std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
+  }
+#endif
+  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
+  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
+}
+
+inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  const auto& entry = op.operatorDef_->op;
+  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
+#ifndef NDEBUG
+  DispatchTraceNestingGuard debug_guard;
+  if (show_dispatch_trace()) {
+      auto nesting_value = dispatch_trace_nesting_value();
+      for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
+      std::cerr << "[callBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
+  }
+#endif
+  const auto& kernel = entry.lookup(dispatchKeySet);
+#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
+  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
+  if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
+    at::RecordFunction guard(std::move(*step_callbacks));
+    auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
+    auto& schema = op.schema();
+    auto schema_ref = std::reference_wrapper(schema);
+    guard.needsInputs() ? runRecordFunction(guard, schema_ref, dispatchKey, c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
+                        : runRecordFunction(guard, schema_ref, dispatchKey);
+
+    // keeping the guard alive while executing the kernel
+    kernel.callBoxed(op, dispatchKeySet, stack);
+
+    if (C10_UNLIKELY(guard.needsOutputs())) {
+      guard.setOutputs(*stack);
+    }
+    return;
+  }
+#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
+  kernel.callBoxed(op, dispatchKeySet, stack);
+}
+
+// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
+inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  const auto& entry = op.operatorDef_->op;
+  // We still compute this as we're obligated to pass it on to the internal
+  // kernel, if it is a boxed fallback
+  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
+  const auto& kernel = ([&]() {
+    if (op.hasKernelForDispatchKey(dk)) {
+      return entry.kernelForDispatchKey(dk);
+    } else {
+      auto idx = getDispatchTableIndexForDispatchKey(dk);
+      TORCH_INTERNAL_ASSERT(idx >= 0);
+      return backendFallbackKernels_[idx].kernel;
+    }
+  })();
+  kernel.callBoxed(op, dispatchKeySet, stack);
+}
+
+inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  const auto& entry = op.operatorDef_->op;
+#ifndef NDEBUG
+  DispatchTraceNestingGuard debug_guard;
+  if (show_dispatch_trace()) {
+      auto nesting_value = dispatch_trace_nesting_value();
+      for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
+      std::cerr << "[redispatchBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
+  }
+#endif
+  const auto& kernel = entry.lookup(dispatchKeySet);
+  return kernel.callBoxed(op, dispatchKeySet, stack);
+}
+
+} // namespace c10
+
+namespace std {
+
+template <>
+struct hash<c10::OperatorHandle> {
+  size_t operator()(const c10::OperatorHandle& op) const noexcept {
+    return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
+  }
+};
+
+} // namespace std
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef2efd55af04ee5c9d2bb01683a31d4688435ccd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/ObservedOperators.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+struct TORCH_API ObservedOperators {
+  ObservedOperators() = delete;
+
+  static bool isObserved(const OperatorName& name);
+
+  static std::unordered_set<std::string>& getUnobservedOperatorList();
+};
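+
+// Illustrative usage sketch (hypothetical call site): isObserved() is keyed
+// by OperatorName, so a typical query looks like
+//
+//   bool observed = c10::ObservedOperators::isObserved(
+//       c10::OperatorName("aten::add", "Tensor"));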
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h
new file mode 100644
index 0000000000000000000000000000000000000000..8ebdca6edee2b3534b0f3b8fbd6ddb7018c2b465
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorEntry.h
@@ -0,0 +1,313 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#ifdef C10_MOBILE
+#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+#endif
+
+namespace c10 {
+
+class Dispatcher;
+
+namespace impl {
+
+// This data structure represents a kernel that was registered to us from a
+// user.  Unlike KernelFunction, AnnotatedKernel contains some extra metadata
+// about the kernel that isn't necessary for actual dispatching (this is why
+// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
+// giving good error messages.
+struct AnnotatedKernel final {
+  AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
+    : kernel(std::move(k))
+    , inferred_function_schema(std::move(s))
+    , debug(std::move(d))
+    {}
+  AnnotatedKernel() = default;
+  KernelFunction kernel;
+  std::unique_ptr<FunctionSchema> inferred_function_schema;
+  // A little debug string to help us identify the kernel in question.
+  // Most importantly it records the TORCH_LIBRARY block that did the
+  // registration.
+  std::string debug;
+};
+
+// This data structure represents operator schema, with metadata specifying
+// where the registration of this schema occurred
+struct AnnotatedSchema final {
+  AnnotatedSchema(FunctionSchema s, std::string d)
+    : schema(std::move(s))
+    , debug(std::move(d))
+    {}
+  FunctionSchema schema;
+  std::string debug;
+};
+
+// Internal data structure that records information about a specific operator.
+// It's not part of the public API; typically, users will interact with
+// OperatorHandle instead.
+//
+// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
+// lock (this is important because some methods in OperatorEntry access
+// dispatcher state)
+class TORCH_API OperatorEntry final {
+public:
+  explicit OperatorEntry(OperatorName&& operator_name);
+
+  OperatorEntry(const OperatorEntry&) = delete;
+  OperatorEntry(OperatorEntry&&) noexcept = delete;
+  OperatorEntry& operator=(const OperatorEntry&) = delete;
+  OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;
+
+  const FunctionSchema& schema() const {
+    TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet");
+    return schema_->schema;
+  }
+  const std::string& debug() const {
+    TORCH_INTERNAL_ASSERT(schema_.has_value());
+    return schema_->debug;
+  }
+  bool hasSchema() const {
+    return schema_.has_value();
+  }
+
+  bool isObserved() const {
+    return is_observed_;
+  }
+
+  // We may allocate an OperatorEntry for an operator even when we don't
+  // have a schema.  When we receive the schema registration, we post
+  // facto register a schema.
+  //
+  // NB: registerSchema/deregisterSchema are not idempotent; if you
+  // attempt to register a schema when one is already present or vice
+  // versa that is an error.  (Refcounting for the registrations is
+  // handled in the OperatorHandle in Dispatcher)
+  void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
+  void deregisterSchema();
+
+  const OperatorName& operator_name() const {
+    return name_;
+  }
+
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+  using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
+#else
+  using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
+#endif
+  using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;
+
+  // Why are kernels and fallback asymmetric?  It has to do with ownership.
+  // Kernels and the computed dispatch tables for them are canonically
+  // owned by OperatorEntry, but backend fallbacks are specified once
+  // and apply for all operators, so they should be owned by Dispatcher.
+  // However, the registration of a backend fallback affects the
+  // state of the computed dispatch table, so when a backend fallback
+  // is updated, we need to update the operator tables too.  Thus,
+  // registerKernel is the mechanism by which we give kernels to
+  // operator entry to own (and update dispatch table), but we only
+  // need a non-owning mechanism to update fallback.
+
+  // Precondition: Dispatcher::mutex_ is held
+  // Postcondition: caller is responsible for disposing of the kernel
+  AnnotatedKernelContainerIterator registerKernel(
+    const Dispatcher& dispatcher,
+    c10::optional<DispatchKey> dispatch_key,
+    KernelFunction kernel,
+    c10::optional<CppSignature> cpp_signature,
+    std::unique_ptr<FunctionSchema> inferred_function_schema,
+    std::string debug
+  );
+
+  // Precondition: Dispatcher::mutex_ is held
+  void deregisterKernel_(
+    const Dispatcher& dispatcher,
+    c10::optional<DispatchKey> dispatch_key,
+    AnnotatedKernelContainerIterator kernel
+  );
+
+  // Precondition: Dispatcher::mutex_ is held
+  void updateFallback(
+    const Dispatcher& dispatcher,
+    DispatchKey dispatch_key
+  );
+
+  // Precondition: Dispatcher::mutex_ is held
+  void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
+    TORCH_INTERNAL_ASSERT(schema_.has_value());
+    schema_->schema.setAliasAnalysis(a);
+  }
+
+  std::string dumpComputedTable() const;
+  std::string dumpState() const;
+  void checkInvariants() const;
+
+  const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }
+
+  // Asserts that the given FuncType is correct for calling this operator in an unboxed way.
+  template<class FuncType>
+  inline void assertSignatureIsCorrect() {
+    assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
+  }
+
+  void assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const;
+
+  [[noreturn]] void reportError(DispatchKey dispatchKey) const;
+
+  const KernelFunction& lookup(DispatchKeySet ks) const {
+    const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
+    if (C10_UNLIKELY(idx == -1)) {
+      reportError(ks.highestPriorityTypeId());
+    }
+    const auto& kernel = dispatchTable_[idx];
+    // A valid kernel *always* has a boxed kernel and *may* have an
+    // unboxed kernel. However, we typically do unboxed calls in at::
+    // APIs, where the kernel 1) will very likely be valid and 2)
+    // should have an unboxed kernel. Checking the unboxed kernel
+    // first will allow us to avoid touching the boxed kernel at all
+    // in the common case.
+    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
+      if (!kernel.isValid()) {
+        reportError(ks.highestPriorityTypeId());
+      }
+    }
+    return kernel;
+  }
+
+  std::string listAllDispatchKeys() const;
+
+  // Returns true if kernel_ has entry for any key in ks.
+  //
+  // Invariant: There are no alias keys in the passed-in dispatch key set.
+  // Note [No Alias Keys in DispatchKeySet]
+  // Alias keys should be checked using `hasKernelForDispatchKey`
+  // Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
+  // have a value > 63 (causing overflow).
+  bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
+  // Returns true if kernel_ has entry for a particular key.
+  bool hasKernelForDispatchKey(DispatchKey k) const;
+  // Retrieves the kernel entry at a particular key.  Symmetric with
+  // hasKernelForDispatchKey.  To get the AnnotatedKernel, see
+  // getKernelForDispatchKey (private)
+  const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
+  // Returns true if the "computed table" has an entry for a particular key.
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const;
+  // Returns all the operator tags added at the time of registration
+  const std::vector<at::Tag>& getTags() const;
+  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
+
+  template <typename F>
+  PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
+    return py_cache_.ptr_or(self_interpreter, slow_accessor);
+  }
+
+private:
+
+  OperatorName name_;
+  c10::optional<AnnotatedSchema> schema_;
+  #ifndef C10_MOBILE
+    std::vector<at::Tag> tags_;
+  #endif
+  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
+  DispatchKeyExtractor dispatchKeyExtractor_;
+  // Pointer to the torch.ops.ns.op.overload object for speed
+  c10::PyHandleCache py_cache_;
+
+  // kernels_ stores all registered kernels for the corresponding dispatch key
+  // and catchAllKernels_ stores the catch-all kernels.
+  // If an operator library gets loaded that overwrites an already existing kernel,
+  // both kernels will be in that list but only the newer one will be in
+  // dispatchTable. If any of the kernels go away (say the library gets
+  // unloaded), we remove the kernel from this list and update the
+  // dispatchTable if necessary.
+  // Kernels in the list are ordered by registration time descendingly,
+  // newer registrations are before older registrations.
+  // We do not combine dispatchTable and kernels into one hash map because
+  // kernels is a larger data structure and accessed quite infrequently
+  // while dispatchTable is accessed often and should be kept small to fit
+  // into CPU caches.
+  // Invariants:
+  //  - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
+  //  - dispatchTable[dispatch_key] does not exist if and only if
+  //    kernels_[dispatch_key] does not exist
+  //  - If kernels_[dispatch_key] exists, then it has elements.
+  //    It is never an empty list.
+  //
+  // Why do we do that?
+  // -----
+  // We mostly do this to enable Jupyter notebooks where a cell registering
+  // a kernel could be executed multiple times and the later execution
+  // should overwrite the earlier one. Note that this still fails when the
+  // function schema changed between the executions, but it works as long
+  // as the function schema didn't change. A better solution would be to
+  // unload the old extension library from the Jupyter cell when the cell is
+  // re-executed and then only allow one kernel here, i.e. error if a kernel
+  // is already registered, but that's a lot of effort to implement and
+  // currently not high-pri.
+  ska::flat_hash_map<DispatchKey,
+#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
+                     std::array<AnnotatedKernel, 1>
+#else
+                     std::list<AnnotatedKernel>
+#endif
+                     > kernels_;
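+
+  // For illustration (hypothetical state, not real code): after two libraries
+  // have registered a CPU kernel for this operator,
+  //
+  //   kernels_[DispatchKey::CPU]  == { newer_kernel, older_kernel }
+  //   dispatchTable_[CPU index]   == newer_kernel.kernel
+  //
+  // and removing newer_kernel pops it from the list and restores
+  // older_kernel.kernel into dispatchTable_.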
+
+  const AnnotatedKernel& missingKernel() const;
+  const AnnotatedKernel& ambiguousAutogradOtherKernel() const;
+
+  // cpp_signature_ stores function signature if any of
+  // the kernels was created in a way that allowed us to know the function
+  // signature (i.e. by supplying an unboxed C++ kernel function).
+  // If this is set, it will be used to check that future kernel
+  // registrations match and it will be used in unboxed function calls
+  // to verify their arguments against the known function signature.
+  struct CppSignatureWithDebug {
+    CppSignature signature;
+    std::string debug;
+    c10::optional<DispatchKey> dispatch_key;
+  };
+  c10::optional<CppSignatureWithDebug> cpp_signature_;
+  c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
+
+  // A Python custom error handler for OperatorEntry::reportError
+  std::unique_ptr<c10::SafePyObject> report_error_callback_;
+
+  // Whether this operator needs to be observed with RecordFunction
+  const bool is_observed_;
+
+  [[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
+  const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
+  std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
+    const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
+  ) const;
+  // This function re-establishes the invariant that dispatchTable
+  // contains the front element from the kernels list for a given runtime dispatch key.
+  void updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
+  // Like above, but also handles alias dispatch keys.
+  void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
+  // Like above, but for ALL entries in the dispatch table.
+  void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
+  // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
+  const AnnotatedKernel* getKernelForDispatchKey(DispatchKey dispatch_key) const;
+};
+
+} // namespace impl
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h
new file mode 100644
index 0000000000000000000000000000000000000000..d542bc942279d06f791306177e7a4462cd917caf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/OperatorOptions.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+
+enum class AliasAnalysisKind : uint8_t {
+  INTERNAL_SPECIAL_CASE,
+  CONSERVATIVE, // The most conservative alias analysis type, assumes
+                // side-effects. This is the default analysis.
+  FROM_SCHEMA,
+  PURE_FUNCTION
+};
+
+#if !defined(_MSC_VER)
+constexpr // Our current MSVC version has a bug that doesn't allow this to be constexpr.
+#endif
+inline const char* toString(AliasAnalysisKind aliasAnalysisKind) {
+  return (aliasAnalysisKind == AliasAnalysisKind::CONSERVATIVE)
+      ? "CONSERVATIVE"
+      : (aliasAnalysisKind == AliasAnalysisKind::FROM_SCHEMA)
+          ? "FROM_SCHEMA"
+          : (aliasAnalysisKind == AliasAnalysisKind::PURE_FUNCTION)
+              ? "PURE_FUNCTION"
+              : (aliasAnalysisKind == AliasAnalysisKind::INTERNAL_SPECIAL_CASE)
+                  ? "INTERNAL_SPECIAL_CASE"
+                  : "UNKNOWN";
+}
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h
new file mode 100644
index 0000000000000000000000000000000000000000..a26f491a0ce9c84b0e1d5d8ef3ead0d0592d4b31
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dispatch/RegistrationHandleRAII.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <functional>
+
+namespace c10 {
+
+class RegistrationHandleRAII final {
+public:
+  explicit RegistrationHandleRAII(std::function<void()> onDestruction)
+      : onDestruction_(std::move(onDestruction)) {}
+
+  ~RegistrationHandleRAII() {
+    if (onDestruction_) {
+      onDestruction_();
+    }
+  }
+
+  RegistrationHandleRAII(const RegistrationHandleRAII&) = delete;
+  RegistrationHandleRAII& operator=(const RegistrationHandleRAII&) = delete;
+
+  RegistrationHandleRAII(RegistrationHandleRAII&& rhs) noexcept
+      : onDestruction_(std::move(rhs.onDestruction_)) {
+    rhs.onDestruction_ = nullptr;
+  }
+
+  RegistrationHandleRAII& operator=(RegistrationHandleRAII&& rhs) noexcept {
+    onDestruction_ = std::move(rhs.onDestruction_);
+    rhs.onDestruction_ = nullptr;
+    return *this;
+  }
+
+private:
+  std::function<void()> onDestruction_;
+};
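+
+// Illustrative sketch (hypothetical callback): registration APIs hand back
+// this RAII object so that dropping the handle undoes the registration.
+//
+//   {
+//     c10::RegistrationHandleRAII handle(
+//         [] { /* deregister the kernel/schema here */ });
+//     // ... the registration stays alive within this scope ...
+//   } // handle destroyed -> onDestruction_() runs and deregisters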
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/dynamic_type.h b/MLPY/Lib/site-packages/torch/include/ATen/core/dynamic_type.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b2e3970670c20c3dd2e8f3246cb59998451e313
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/dynamic_type.h
@@ -0,0 +1,239 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace c10 {
+
+using DynamicTypeBits = std::uint32_t;
+#define DYNAMIC_TYPE_BIT(x) (1u << x)
+
+constexpr DynamicTypeBits kDynamicCovariantTypeBit = DYNAMIC_TYPE_BIT(31);
+constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);
+
+constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
+constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
+constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
+constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
+constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
+constexpr DynamicTypeBits kDynamicTupleTypeBit = DYNAMIC_TYPE_BIT(8);
+constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
+
+#define FORALL_DYNAMIC_TYPES(_)                                              \
+  _(Tensor, DYNAMIC_TYPE_BIT(0), 1)                                          \
+  _(None, kDynamicNoneTypeBit, 1)                                            \
+  _(Bool, DYNAMIC_TYPE_BIT(2), 1)                                            \
+  _(Int, kDynamicIntTypeBit, 1)                                              \
+  _(Float, kDynamicFloatTypeBit, 1)                                          \
+  _(Complex, kDynamicComplexTypeBit, 1)                                      \
+  _(Number,                                                                  \
+    (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit),    \
+    1)                                                                       \
+  _(String, DYNAMIC_TYPE_BIT(6), 1)                                          \
+  _(List, kDynamicListTypeBit, 0)                                            \
+  _(Tuple, (kDynamicTupleTypeBit | kDynamicCovariantTypeBit), 0)             \
+  _(Dict, DYNAMIC_TYPE_BIT(9), 0)                                            \
+  _(Class, kDynamicClassTypeBit, 0)                                          \
+  _(Optional,                                                                \
+    (DYNAMIC_TYPE_BIT(11) | kDynamicNoneTypeBit | kDynamicCovariantTypeBit), \
+    0)                                                                       \
+  _(AnyList, (kDynamicListTypeBit | kDynamicAnyTypeBit), 1)                  \
+  _(AnyTuple,                                                                \
+    (kDynamicTupleTypeBit | kDynamicCovariantTypeBit | kDynamicAnyTypeBit),  \
+    1)                                                                       \
+  _(DeviceObj, DYNAMIC_TYPE_BIT(12), 1)                                      \
+  _(StreamObj, DYNAMIC_TYPE_BIT(13), 1)                                      \
+  _(Capsule, DYNAMIC_TYPE_BIT(14), 1)                                        \
+  _(Generator, DYNAMIC_TYPE_BIT(15), 1)                                      \
+  _(Storage, DYNAMIC_TYPE_BIT(16), 1)                                        \
+  _(Var, DYNAMIC_TYPE_BIT(17), 0)                                            \
+  _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1)                \
+  _(QScheme, DYNAMIC_TYPE_BIT(18), 1)                                        \
+  _(Quantizer, DYNAMIC_TYPE_BIT(19), 1)                                      \
+  _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1)                                        \
+  _(RRef, DYNAMIC_TYPE_BIT(21), 0)                                           \
+  _(Future, DYNAMIC_TYPE_BIT(22), 0)                                         \
+  _(Await, DYNAMIC_TYPE_BIT(23), 0)                                          \
+  _(Any, 0xffffffff, 1)
+
+#define FORALL_DYNAMIC_TYPES_FAKE(_) \
+  _(ScalarType, kDynamicIntTypeBit, 1)                                \
+  _(Layout, kDynamicIntTypeBit, 1)                                        \
+  _(SymInt, kDynamicIntTypeBit, 1)                                        \
+  _(MemoryFormat, kDynamicIntTypeBit, 1)
+
+#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
+  FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
+  FORALL_DYNAMIC_TYPES_FAKE(FORWARD_DECL_TYPE)
+#undef FORWARD_DECL_TYPE
+
+class DynamicType;
+using DynamicTypePtr = std::shared_ptr<DynamicType>;
+
+/**
+ * DynamicType is designed as a low dependency type system for TorchScript. The
+ * existing JIT types are used for both compilation and runtime, which makes
+ * sense for server contexts because we often compile and run the model in
+ * the same process. However, this doesn't hold for mobile devices, where we
+ * always compile a model ahead of time, so dependencies that are not actually
+ * needed still get built into the mobile runtime and bloat the binary size by
+ * design. Every basic type like Int, Bool or String will bring their
+ * vtable, typeinfo, constructor, destructor and even more data from their
+ * specializations for STL types to the binary causing a long tail bloat.
+ *
+ * The core problem is about the complexity to implement and maintain a single
+ * type system for both analysis and execution purposes. Although they should
+ * have exactly the same semantics, in practice implementing a unified abstraction
+ * adds conceptual and representational overhead for both sides of the world.
+ *
+ * To address the issues, DynamicType implements a minimal subset of JIT types
+ * and uses a generic algorithm to test all subtyping relations. To achieve
+ * this, we assign each dynamic type a single integer tag to represent its
+ * semantics. More specifically, a dynamic type is defined as a set of "control
+ * bits" and "data bits", where control bits describe the special behavior when
+ * testing a type and data bits map to identity of each nominal type. We use bit
+ * operations to perform all the tests.
+ *
+ * For example, a "covariant bit" is a control bit used to describe if a type
+ * is covariant, right now the most used one is tuple type, and in addition to
+ * the control bit, tuple type's data bit is the 8th bit from the LSB. Control
+ * bits start from MSB and data bits start from LSB.
+ *
+ * If two types are equal, then they are subtype of each other, also if the bits
+ * from one type tag is subset of the other tag, it automatically becomes a
+ * subtype of the other. This simplifies the subtyping logic a lot, and over the
+ * long term it is possible to adopt this scheme on the server side as well.
+ * Special cases can be added but they generally should not take too much code
+ * size.
+ *
+ * DynamicType may or may not inherit from c10::Type because it's not the core
+ * requirement of DynamicType to interface with existing JIT types, but we might
+ * want to inherit from c10::Type to reduce the migration cost.
+ */
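+// Illustrative example of the bit scheme described above (values taken from
+// the FORALL_DYNAMIC_TYPES table in this header): Number is the union of the
+// Int, Float and Complex data bits, so a subset test on the tags is enough to
+// conclude that Int is a subtype of Number.
+//
+//   constexpr auto kInt    = kDynamicIntTypeBit;
+//   constexpr auto kNumber =
+//       kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit;
+//   static_assert((kInt & kNumber) == kInt,
+//                 "Int's bits are a subset of Number's, so Int <: Number");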
+class DynamicType : public SharedType {
+  using ClassTypePtr = std::shared_ptr<ClassType>;
+
+  /**
+   * An implementation detail to support NamedTuple.
+   */
+  struct LabeledDynamicType {
+    c10::optional<std::string> label;
+    DynamicTypePtr ty;
+    explicit LabeledDynamicType(DynamicTypePtr t) : ty(std::move(t)) {}
+
+    bool equals(const LabeledDynamicType& other) const;
+    bool isSubtypeOf(const LabeledDynamicType& other) const;
+  };
+
+ public:
+  // TODO Change Ptr to DynamicTypePtr when all migrations are done.
+  using Ptr = TypePtr;
+  using ElementType = DynamicType;
+  ~DynamicType() override;
+
+  struct Arguments {
+    Arguments() = default;
+    Arguments(c10::ArrayRef<TypePtr>);
+    Arguments(const std::vector<c10::string_view>&, c10::ArrayRef<TypePtr>);
+    std::vector<LabeledDynamicType> elems;
+  };
+
+  enum class Tag : DynamicTypeBits {
+#define DYNAMIC_TYPE_ITEM(NAME, VAL, _) NAME = VAL,
+    FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_ITEM)
+    FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_ITEM)
+#undef DYNAMIC_TYPE_ITEM
+  };
+
+  bool equals(const Type& rhs) const override;
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+  std::string str() const override;
+  static const TypeKind Kind = TypeKind::DynamicType;
+  static TORCH_API DynamicTypePtr create(Type& ty);
+
+  explicit DynamicType(Tag, Arguments);
+  explicit DynamicType(Tag, c10::string_view, Arguments);
+
+  TypePtr containedType(size_t) const override;
+  size_t containedTypeSize() const override;
+  Tag tag() const {
+    return tag_;
+  }
+  const c10::optional<std::string>& name() const {
+    return name_;
+  }
+  const Arguments& arguments() const {
+    return arguments_;
+  }
+  TORCH_API TypeKind dynamicKind() const;
+
+  // Should be used only on the server side to restore static type information.
+#ifndef C10_MOBILE
+  TORCH_API
+#endif
+  TypePtr fallback() const;
+
+ private:
+  bool symmetric() const override {
+    return false;
+  }
+  friend struct Type;
+  static std::shared_ptr create(const Type& ty);
+  DynamicType(const Type& other);
+  bool equals(const DynamicType& other) const;
+
+  template <typename F>
+  bool compareArguments(const DynamicType& other, F&& f) const {
+    if (arguments_.elems.size() != other.arguments_.elems.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < arguments_.elems.size(); i++) {
+      if (!f(arguments_.elems[i], other.arguments_.elems[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  Tag tag_;
+  c10::optional<std::string> name_;
+  union {
+    Arguments arguments_;
+    ClassTypePtr class_;
+  };
+};
+
+template <typename T>
+struct DynamicTypeTrait {
+  C10_NOINLINE static auto tagValue() {
+    TORCH_CHECK(false);
+    return DynamicType::Tag::Any;
+  }
+};
+
+namespace detail {
+C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
+}
+
+#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE)      \
+  template <>                                              \
+  struct TORCH_API DynamicTypeTrait<NAME##Type> {        \
+    C10_ERASE static auto tagValue() {                     \
+      return DynamicType::Tag::NAME;                       \
+    }                                                      \
+    static constexpr bool isBaseType = IS_BASE_TYPE;       \
+    template <typename T = const DynamicTypePtr&>         \
+    static std::enable_if_t<isBaseType, T> getBaseType() { \
+      static auto type = detail::makeBaseType(tagValue()); \
+      return type;                                         \
+    }                                                      \
+  }; // namespace c10
+FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
+FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_TAG_VALUE)
+#undef DYNAMIC_TYPE_TAG_VALUE
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/enum_tag.h b/MLPY/Lib/site-packages/torch/include/ATen/core/enum_tag.h
new file mode 100644
index 0000000000000000000000000000000000000000..0e5448211db5a6a9215f0fd794f07c6b61771e87
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/enum_tag.h
@@ -0,0 +1,20 @@
+#pragma once
+
+// @generated by torchgen/gen.py from enum_tag.h
+
+namespace at {
+    // Enum of valid tags obtained from the entries in tags.yaml
+    enum class Tag {
+        core,
+        data_dependent_output,
+        dynamic_output_shape,
+        generated,
+        inplace_view,
+        needs_fixed_stride_order,
+        nondeterministic_bitwise,
+        nondeterministic_seeded,
+        pointwise,
+        pt2_compliant_tag,
+        view_copy
+    };
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/enum_type.h b/MLPY/Lib/site-packages/torch/include/ATen/core/enum_type.h
new file mode 100644
index 0000000000000000000000000000000000000000..3cd67fd89778fa2df32a9b0e8585c3337419c3cb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/enum_type.h
@@ -0,0 +1,101 @@
+#pragma once
+
+#include 
+
+#include 
+
+namespace c10 {
+
+struct EnumType;
+using EnumTypePtr = std::shared_ptr<EnumType>;
+using EnumNameValue = std::pair<std::string, IValue>;
+struct TORCH_API EnumType : public NamedType {
+  friend struct Type;
+  static const TypeKind Kind = TypeKind::EnumType;
+
+  static EnumTypePtr create(
+      const c10::QualifiedName& qualified_class_name,
+      TypePtr value,
+      std::vector<EnumNameValue> enum_names_values,
+      std::weak_ptr<::torch::jit::CompilationUnit> cu) {
+    switch (value->kind()) {
+      case TypeKind::IntType:
+      case TypeKind::FloatType:
+      case TypeKind::StringType:
+        return EnumTypePtr(new EnumType(
+            qualified_class_name,
+            std::move(value),
+            std::move(enum_names_values),
+            std::move(cu)));
+      default:
+        AT_ERROR(
+            "Cannot create Enum with value type '",
+            value->str(),
+            "', only int, float and string are supported");
+    }
+  }
+
+  std::string str() const override {
+    return "Enum<" + annotation_str() + ">";
+  }
+
+  std::string repr_str() const override {
+    return str();
+  }
+
+  const TypePtr& getValueType() const {
+    return value_type_;
+  }
+
+  bool equals(const Type& rhs) const override {
+    if (auto* enum_rhs = rhs.castRaw<EnumType>()) {
+      return name().value() == enum_rhs->name().value() &&
+          *getValueType() == *(enum_rhs->getValueType()) &&
+          this->compilation_unit() == enum_rhs->compilation_unit();
+    }
+    return false;
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  std::shared_ptr<const ::torch::jit::CompilationUnit> compilation_unit()
+      const {
+    auto cu = cu_.lock();
+    return cu;
+  }
+
+  const QualifiedName& qualifiedClassName() const {
+    return name().value();
+  }
+
+  at::ArrayRef<TypePtr> containedTypes() const override {
+    return value_type_;
+  }
+
+  const at::ArrayRef<EnumNameValue> enumNamesValues() const {
+    return enum_names_values_;
+  }
+
+ private:
+  EnumType(
+      c10::QualifiedName qualified_class_name,
+      TypePtr value_type,
+      std::vector<EnumNameValue> enum_names_values,
+      std::weak_ptr<::torch::jit::CompilationUnit> cu)
+      : NamedType(TypeKind::EnumType, std::move(qualified_class_name)),
+        value_type_(std::move(value_type)),
+        enum_names_values_(std::move(enum_names_values)),
+        cu_(std::move(cu)) {}
+
+  std::string annotation_str_impl(
+      C10_UNUSED TypePrinter printer = nullptr) const override {
+    const auto& n = name().value();
+    return n.qualifiedName();
+  }
+
+  TypePtr value_type_;
+  std::vector<EnumNameValue> enum_names_values_;
+  std::weak_ptr<::torch::jit::CompilationUnit> cu_;
+};
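+
+// Illustrative sketch (hypothetical names): creating an enum type whose values
+// are strings. create() rejects any value type other than int, float, or string.
+//
+//   auto color = c10::EnumType::create(
+//       c10::QualifiedName("mymodule.Color"),
+//       c10::StringType::get(),
+//       {{"RED", c10::IValue("red")}, {"GREEN", c10::IValue("green")}},
+//       std::weak_ptr<torch::jit::CompilationUnit>());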
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/function.h b/MLPY/Lib/site-packages/torch/include/ATen/core/function.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef64da980b5c6f855a06cb2f5b12bae14da17da6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/function.h
@@ -0,0 +1,111 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+struct FunctionSchema;
+};
+
+namespace at {
+TORCH_API void launch(std::function<void()> func);
+}
+
+namespace torch {
+namespace jit {
+
+struct Graph;
+struct Code;
+
+namespace mobile {
+struct Code;
+}
+
+using Stack = std::vector<at::IValue>;
+using Kwargs = std::unordered_map<std::string, at::IValue>;
+struct RecursiveMethodCallError : public std::exception {};
+using TaskLauncher = std::function<void(std::function<void()>)>;
+
+TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast=false);
+
+// A Function is a pure Graph with no implicit `self` object bound.
+// It contains schema information and the executor that manages the
+// execution of the function. Method is a wrapper around an
+// underlying Function that also provides a `self` object.
+struct TORCH_API Function {
+  Function() = default;
+  Function(const Function&) = default;
+  Function& operator=(const Function&) = default;
+  Function(Function&&) noexcept = default;
+  Function& operator=(Function&&) noexcept = default;
+  virtual c10::string_view doc_string() const {
+    static constexpr c10::string_view no_doc_string = "";
+    return no_doc_string;
+  }
+
+  virtual bool isGraphFunction() const {
+    return false;
+  }
+
+  virtual void run(Stack& stack) = 0;
+
+  virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
+      Stack& /*stack*/,
+      C10_UNUSED TaskLauncher taskLauncher = at::launch) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
+    return {};
+  }
+
+  at::IValue operator()(
+    Stack stack,
+    const Kwargs& kwargs = Kwargs()) {
+    getSchema().checkAndNormalizeInputs(stack, kwargs);
+    run(stack);
+    return stack.front();
+  }
+
+  virtual const c10::QualifiedName& qualname() const = 0;
+
+  const std::string& name() const {
+    return qualname().name();
+  }
+
+  // if this isn't yet defined, run its method_creator function
+  virtual void ensure_defined() = 0;
+
+  virtual const c10::FunctionSchema& getSchema() const = 0;
+
+  virtual size_t num_inputs() const = 0;
+
+  virtual Function& setSchema(c10::FunctionSchema schema) = 0;
+
+  // call() defines how different interpreter implementations interacts with
+  // Function objects. Basically interpreters need to provide a callback to
+  // communicate to Functions what to do if provided a Code object.
+  // Alternatively we could design the signature to return an optional Code
+  // object, but that requires special handling of the null case in the
+  // interpreter, and the fallback behavior is defined by the Function itself
+  // rather than by the interpreter, so a callback approach is more reasonable than
+  // returning values.
+  // If call() returns true, then callback completes successfully, otherwise
+  // call() returns false.
+
+  // Overload for server interpreter, a bailout size is needed for graph executor.
+  virtual bool call(Stack&, c10::optional<size_t>, c10::function_ref<void(const Code&)>) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
+    return false;
+  }
+
+  // Overload for mobile interpreter.
+  virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
+    return false;
+  }
+
+  virtual ~Function() = default;
+};
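+
+// Illustrative sketch (hypothetical function object): operator() checks and
+// normalizes the inputs against the schema, runs the function, and returns the
+// first stack entry. Assuming `fn` refers to a torch::jit::Function whose
+// schema takes two ints:
+//
+//   torch::jit::Stack stack{at::IValue(2), at::IValue(3)};
+//   at::IValue result = fn(stack);   // calls run(stack) after schema checks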
+} // namespace jit
+} // namespace torch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema.h b/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema.h
new file mode 100644
index 0000000000000000000000000000000000000000..b0ab8d744da2a55d9bb0f4eb699f66274a04bcc5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema.h
@@ -0,0 +1,687 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+// schema as used in the compiler for resolving function calls and reporting
+// errors. These objects should be constructed from C10 schema once those
+// are available.
+
+struct Argument;
+struct FunctionSchema;
+
+using AliasTypeSet = std::vector<TypePtr>;
+
+bool operator==(const Argument& lhs, const Argument& rhs);
+
+struct Argument {
+  Argument(
+      std::string name = "",
+      const TypePtr& type = nullptr,
+      c10::optional<int32_t> N = c10::nullopt,
+      c10::optional<IValue> default_value = c10::nullopt,
+      bool kwarg_only = false,
+      c10::optional<AliasInfo> alias_info = c10::nullopt)
+    : Argument(std::move(name), type, type, N, std::move(default_value), kwarg_only, std::move(alias_info)) {}
+
+  Argument(
+      std::string name,
+      TypePtr fake_type,
+      TypePtr real_type,
+      c10::optional<int32_t> N = c10::nullopt,
+      c10::optional<IValue> default_value = c10::nullopt,
+      bool kwarg_only = false,
+      c10::optional<AliasInfo> alias_info = c10::nullopt)
+      : name_(std::move(name)),
+        type_(fake_type ? std::move(fake_type) : TensorType::get()),
+        real_type_(real_type ? std::move(real_type) : type_),
+        N_(N),
+        default_value_(std::move(default_value)),
+        alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
+        kwarg_only_(kwarg_only) {
+    // this is a softly-enforced invariant for out arguments.
+    bool is_alias = alias_info_ != nullptr && alias_info_->isWrite();
+    is_out_ = kwarg_only_ && is_alias;
+  }
+
+  Argument(Argument&& rhs) noexcept = default;
+
+  Argument(const Argument& rhs)
+      : name_(rhs.name_),
+        type_(rhs.type_),
+        real_type_(rhs.real_type_),
+        N_(rhs.N_),
+        default_value_(rhs.default_value_),
+        alias_info_(rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr),
+        kwarg_only_(rhs.kwarg_only_),
+        is_out_(rhs.is_out_) {}
+
+  Argument& operator=(Argument&& rhs) = default;
+
+  Argument& operator=(const Argument& rhs) {
+    if (this != &rhs) {
+      name_ = rhs.name_;
+      type_ = rhs.type_;
+      real_type_ = rhs.real_type_;
+      N_ = rhs.N_;
+      default_value_ = rhs.default_value_;
+      alias_info_ = rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr;
+      kwarg_only_ = rhs.kwarg_only_;
+      is_out_ = rhs.is_out_;
+    }
+    return *this;
+  }
+
+  const std::string& name() const {
+    return name_;
+  }
+  const TypePtr& type() const {
+    return type_;
+  }
+  // if type() is non-null, this is guaranteed to be non-null (if no real
+  // type was provided, this takes on type()'s value)
+  const TypePtr& real_type() const {
+    return real_type_;
+  }
+  c10::optional<int32_t> N() const {
+    return N_;
+  }
+  const c10::optional<IValue>& default_value() const {
+    return default_value_;
+  }
+  bool kwarg_only() const {
+    return kwarg_only_;
+  }
+
+  bool is_out() const {
+    return is_out_;
+  }
+
+  C10_NODISCARD const AliasInfo* alias_info() const {
+    return alias_info_.get();
+  }
+
+  bool is_inferred_type() const {
+    bool is_inferred_type = false;
+    TORCH_INTERNAL_ASSERT(type_);
+    if (auto pt = type_->cast<TensorType>()) {
+      if (pt->isInferredType()) {
+        is_inferred_type = true;
+      }
+    }
+    return is_inferred_type;
+  }
+
+  std::string formatTypeMismatchMsg(const std::string& actual_type) const {
+    std::string inferred_type_hint;
+    if (is_inferred_type()) {
+      inferred_type_hint = c10::str(
+          "Inferred '",
+          name(),
+          "' to be of type 'Tensor' ",
+          "because it was not annotated with an explicit type.\n");
+    }
+    return c10::str(
+        "Expected a value of type '",
+        type()->repr_str(),
+        "' for argument '",
+        name(),
+        "' but instead found type '",
+        actual_type,
+        "'.\n",
+        inferred_type_hint);
+  }
+
+  Argument cloneWithType(TypePtr new_type) const {
+    return Argument(
+        name_,
+        std::move(new_type),
+        N_,
+        default_value_,
+        kwarg_only_,
+        alias_info_ ? c10::optional<AliasInfo>(*alias_info_) : c10::nullopt);
+  }
+
+  // this function checks whether this Argument is backward compatible with
+  // the old one. we consider the following cases are backward compatible:
+  //   1) two arguments are equal
+  //   2) this arg's type should be subtype of old
+  //   3) this arg must provide the same default value if old arg has one,
+  bool isBackwardCompatibleWith(
+      const Argument& old,
+      std::ostream* why_not=nullptr) const;
+
+  // this function checks whether this Argument is forward compatible with
+  // the old one. we consider the following cases are forward compatible:
+  //   1) two arguments are equal
+  //   2) this arg's type should be subtype of old
+  //   3) this arg must provide the same default value if old arg has one,
+  bool isForwardCompatibleWith(
+      const Argument& old,
+      std::ostream* why_not = nullptr) const;
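+
+  // For example (hypothetical argument, following rule 3 above): a new
+  // `int dim=0` stays backward compatible with an older `int dim=0`, but
+  // changing the default to `int dim=1` or dropping it breaks backward
+  // compatibility even though the type is unchanged.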
+
+ private:
+  std::string name_;
+  TypePtr type_;
+  TypePtr real_type_; // this is ScalarType, not int, e.g.
+  // for list types, an optional statically known length for the list
+  // e.g. for int[3]: type = ListType::ofInts(), N = 3
+  // If present, this will allow scalars to be broadcast to this length to
+  // become a list.
+  c10::optional<int32_t> N_;
+
+  c10::optional<IValue> default_value_;
+  // AliasInfo is huge, so let's only allocate memory for it if
+  // necessary (which it isn't during schema parsing on startup, to
+  // give a pertinent example).
+  std::unique_ptr<AliasInfo> alias_info_;
+  // is this only specifiable as a keyword argument?
+  bool kwarg_only_;
+  // marks if the argument is out variant of the schema
+  bool is_out_;
+};
+
+inline bool operator==(const Argument& lhs, const Argument& rhs) {
+  return lhs.name() == rhs.name()
+          && *lhs.type() == *rhs.type()
+          && lhs.N() == rhs.N()
+          && lhs.default_value() == rhs.default_value()
+          && lhs.kwarg_only() == rhs.kwarg_only()
+          && (lhs.alias_info() == rhs.alias_info()
+              || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr
+                   && *lhs.alias_info() == *rhs.alias_info()));
+}
+
+inline bool operator!=(const Argument& lhs, const Argument& rhs) {
+  return !(lhs == rhs);
+}
+
+enum struct TORCH_API SchemaArgType { input, output };
+
+/**
+ * struct SchemaArgument
+ *
+ * Structure used to represent arguments or returns for a schema.
+ */
+struct TORCH_API SchemaArgument {
+  SchemaArgType type;
+  size_t index;
+  SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {}
+  bool operator==(const SchemaArgument& rhs) const {
+    return type == rhs.type && index == rhs.index;
+  }
+};
+
+bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
+
+struct TORCH_API FunctionSchema {
+  FunctionSchema(
+      std::string name,
+      std::string overload_name,
+      std::vector<Argument> arguments,
+      std::vector<Argument> returns,
+      bool is_vararg = false,
+      bool is_varret = false)
+      : name_({std::move(name), std::move(overload_name)}),
+        arguments_(std::move(arguments)),
+        returns_(std::move(returns)),
+        is_vararg_(is_vararg),
+        is_varret_(is_varret) {
+    checkSchema();
+  }
+
+  FunctionSchema(
+      Symbol name,
+      std::string overload_name,
+      std::vector<Argument> arguments,
+      std::vector<Argument> returns,
+      bool is_vararg = false,
+      bool is_varret = false)
+      : FunctionSchema(
+            name.toQualString(),
+            std::move(overload_name),
+            std::move(arguments),
+            std::move(returns),
+            is_vararg,
+            is_varret) {
+    checkSchema();
+  }
+
+  // Checks whether this schema is backward compatible with the old one.
+  // The following conditions must be true:
+  // [Function structure] The new schema's name, overload-name, varargs, and
+  //      return arity are the same.
+  // [Output Narrowing] The new schema's output type must be the same class
+  //      or inherit from the old schema's output type.
+  // [Argument count] The new schema must have at least as many arguments as
+  //      the old schema (considering the list of positional and kwargs).
+  // [Arg Compatibility] Every argument in the old schema has a corresponding
+  //      argument in the new schema that:
+  //        * is at the same position.
+  //        * has the same name.
+  //        * is either positional, or kwarg and the old argument was kwarg.
+  //        * has the same type, or the old argument's type inherits from the
+  //          new argument's type.
+  // [Default Values] Every new argument must have a default value.
+  // E.g.
+  //   OK    f_new(a, b, c=1) => f_old(a, b)
+  //   NOK   f_new(a, c=1, *, b) => f_old(a, *, b)
+  //   OK    f_new(a, b, *, c) => f_old(a, *, b, c)
+  //   NOK   f_new(a, *, b, c) -> f_old(a, b, *, c)
+  //   NOK   f_new(a, *, c, b) => f_old(a, *, b, c)
+  //   OK    f_new(a, *, b, c, d=1) => f_old(a, *, b, c)
+  bool isBackwardCompatibleWith(
+      const FunctionSchema& old,
+      std::ostream* why_not = nullptr) const;
+
+  // Checks whether this schema is forward compatible with the old one.
+  // The following conditions must be true:
+  // [Function structure] The new schema's name, overload-name, varargs, and
+  //      return arity are the same.
+  // [Output Narrowing] The new schema's output type must be the same class
+  //      or inherit from the old schema's output type.
+  // [Arg Compatibility] Every argument in the old schema has a corresponding
+  //      argument in the new schema that:
+  //        * is at the same position.
+  //        * has the same name.
+  //        * is either positional, or kwarg and the old argument was kwarg.
+  //        * has the same type, or the old argument's type inherits from the
+  //          new argument's type.
+  // [Default Values] Every new argument must have a default value.
+  //         Each default value type should NOT be a container type.
+  // [Positioning] All defaults arguments MUST go after either old
+  //         default arguments or the end of positional arguments
+  //         and right BEFORE all out arguments
+  bool isForwardCompatibleWith(
+      const FunctionSchema& old,
+      std::ostringstream& why_not) const;
+
+ private:
+  OperatorName name_;
+  std::vector<Argument> arguments_;
+  std::vector<Argument> returns_;
+  // if true then this schema takes an arbitrary number of additional arguments
+  // after the argument specified in arguments
+  // currently this is used primarily to represent 'primitive' operators whose
+  // arguments are not checked by schema
+  bool is_vararg_;
+  bool is_varret_;
+
+  // if no alias information is directly specified, what kind of "default"
+  // alias information should we infer?
+  // NB: due to alias analysis kind merging, this may be nullopt.  Eventually
+  // this should always be set no matter what
+  c10::optional<AliasAnalysisKind> alias_kind_;
+
+  template <typename T>
+  void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
+
+  void checkSchema() const {
+    bool seen_default_arg = false;
+    for (const auto& arg : arguments()) {
+      if (arg.default_value()) {
+        seen_default_arg = true;
+      } else {
+        // we have historically serialized broadcasting lists wo/default values,
+        // so to not break BC allow lists here
+        if (arg.type()->kind() == ListType::Kind) {
+          continue;
+        }
+        TORCH_INTERNAL_ASSERT(
+            !seen_default_arg || arg.kwarg_only(),
+            "Non-default positional argument follows default argument. Parameter ",
+            arg.name(),
+            " in ",
+            *this);
+      }
+    }
+  }
+
+ public:
+
+  void dump() const;
+
+  const OperatorName& operator_name() const {
+    return name_;
+  }
+  const std::string& name() const {
+    return name_.name;
+  }
+  const std::string& overload_name() const {
+    return name_.overload_name;
+  }
+  const std::vector<Argument>& arguments() const {
+    return arguments_;
+  }
+  const std::vector<Argument>& returns() const {
+    return returns_;
+  }
+  bool is_vararg() const {
+    return is_vararg_;
+  }
+  bool is_varret() const {
+    return is_varret_;
+  }
+  bool is_aliasing(const c10::SchemaArgument &argument) const {
+    TORCH_INTERNAL_ASSERT(
+    argument.index < getCorrectList(argument.type).size(),
+    "Invalid index for schema.");
+    const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
+    return aliasInfo;
+  }
+  bool is_mutable() const {
+    return std::any_of(
+        arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) {
+          const AliasInfo* aliasInfo = arg.alias_info();
+          return aliasInfo && aliasInfo->isWrite();
+        });
+  }
+  bool is_mutable(const c10::SchemaArgument &argument) const {
+    TORCH_INTERNAL_ASSERT(
+        argument.index < getCorrectList(argument.type).size(),
+        "Invalid index for schema.");
+    const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info();
+    return aliasInfo && aliasInfo->isWrite();
+  }
+  bool is_mutable(c10::string_view name) const {
+    c10::optional<int> index = argumentIndexWithName(name);
+    TORCH_INTERNAL_ASSERT(
+        index != c10::nullopt, "Schema has no argument named ", name);
+
+    return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
+  }
+
+  // Returns whether lhs and rhs may alias directly.
+  // This does not account for cases where lhs or rhs are a container that
+  // may contain elements that alias the other argument.
+  // FunctionSchema::may_contain_alias will include that functionality.
+  bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const;
+
+  // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container
+  // that may contain elements that alias the other argument.
+  // bidirectional = false only returns whether lhs may contain an alias of rhs
+  // while bidirectional = true returns both directions.
+  bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const;
+
+  // Returns whether the two AliasTypeSets contain any similarities
+  // ie: whether the two type sets can alias.
+  bool canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const;
+
+  // Recursively Finds all contained types within the AliasTypeSet.
+  c10::optional<AliasTypeSet> getAliasTypeSetContainedTypes(const c10::optional<AliasTypeSet> &aliasTypeSet) const;
+
+  // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp.
+  // Used to map types to a type such that all types that can alias will be mapped to the same type.
+  // For example, calling this method on 'Optional[List[int]]' is the same as calling this method
+  // on 'List[int]'.
+  c10::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) const;
+
+  // Returns either arguments() or returns() depending on the SchemaArgType
+  // output => returns(), input => arguments()
+  const std::vector<Argument>& getCorrectList(SchemaArgType type) const;
+
+  c10::optional<int> argumentIndexWithName(c10::string_view name) const {
+    for (const auto i : c10::irange(arguments().size())) {
+      if(name == arguments()[i].name())
+        return i;
+    }
+    return c10::nullopt;
+  }
+  FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
+    return FunctionSchema(
+        std::move(name),
+        std::move(overload_name),
+        arguments(),
+        returns(),
+        is_vararg(),
+        is_varret()
+        );
+  }
+  FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
+    return FunctionSchema(
+        name(),
+        overload_name(),
+        std::move(new_arguments),
+        returns(),
+        is_vararg(),
+        is_varret());
+  }
+  FunctionSchema cloneWithReturns(std::vector<Argument> new_returns) const {
+    return FunctionSchema(
+        name(),
+        overload_name(),
+        arguments(),
+        std::move(new_returns),
+        is_vararg(),
+        is_varret());
+  }
+
+  std::string formatTypeMismatchMsg(
+      const Argument& expected,
+      const std::string& actual_type,
+      c10::optional<size_t> position = c10::nullopt,
+      c10::optional<const IValue*> value = c10::nullopt) const;
+
+  FunctionSchema cloneWithRemappedTypes(
+      const std::function<TypePtr(TypePtr)> type_map) const;
+
+  FunctionSchema cloneWithRealTypes(bool with_symint=true) const;
+
+  // Check that inputs have the correct types and appends any missing default
+  // values.
+  template <typename T = c10::PlatformType>
+  void checkAndNormalizeInputs(
+      std::vector<IValue>& inputs,
+      const std::unordered_map<std::string, IValue>& kwargs =
+          std::unordered_map<std::string, IValue>{}) const;
+
+  std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;
+
+  bool hasAnyAliasInfo() const {
+    for (const auto& arg : arguments_) {
+      if (arg.alias_info() != nullptr) {
+        return true;
+      }
+    }
+    for (const auto& ret : returns_) {
+      if (ret.alias_info() != nullptr) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+
+  // TODO remove the mutation here
+  bool isDefaultAliasAnalysisKind() const {
+    return !alias_kind_;
+  }
+  AliasAnalysisKind aliasAnalysis() const {
+    return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE);
+  }
+  void setAliasAnalysis(AliasAnalysisKind v) {
+    alias_kind_ = v;
+  }
+
+  c10::optional getNamespace() const {
+    return name_.getNamespace();
+  }
+
+  // Returns true if we successfully set the namespace (as there
+  // was none set, and false otherwise)
+  bool setNamespaceIfNotSet(const char* ns) {
+    return name_.setNamespaceIfNotSet(ns);
+  }
+
+  // can a function with this schema be substituted for a function of rhs's
+  // schema and have the program typecheck?
+  // as_method - if true, treat this schema as a method and ignore
+  // the first argument, which will be the object in both cases
+  bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
+};
+
+inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+  return lhs.name() == rhs.name()
+     && lhs.overload_name() == rhs.overload_name()
+     && lhs.arguments() == rhs.arguments()
+     && lhs.returns() == rhs.returns()
+     && lhs.is_vararg() == rhs.is_vararg()
+     && lhs.is_varret() == rhs.is_varret();
+}
+
+inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+  return !(lhs == rhs);
+}
+
+// print out Argument, which is compatible with FunctionSchema parser
+// full format: Type(alias)? name=default_value
+inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
+
+  // for adjusting the ? position.
+  // in schema, we have Tensor?(a!) input, and t(a!)?.
+  // however, t?(a!) doesn't work with schema parser.
+  // so we always use Type(alias)? format
+  // real_type versus fake_type: in order to be compatible with FunctionSchema
+  // parser, printing an argument with either MemoryFormat or Layout type should
+  // give us the original schema string, hence printing out real_type.
+  auto type = arg.real_type();
+  bool is_opt = type->kind() == OptionalType::Kind;
+  auto unopt_type = is_opt ? type->castRaw<OptionalType>()->getElementType() : type;
+
+  if (unopt_type->kind() == ListType::Kind) {
+    // sized lists get size N from arg, not type
+    auto list = unopt_type->cast<ListType>();
+    out << list->getElementType()->str();
+    if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){
+      out << arg.alias_info()->containedTypes()[0];
+    }
+    std::string N = "";
+    if (arg.N()) {
+        N = std::to_string(*arg.N());
+    }
+    out << "[" << N << "]";
+  } else {
+    out << unopt_type->str();
+  }
+
+  // print alias info if it has beforeSets.
+  if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) {
+    out << *arg.alias_info();
+  }
+
+  if (is_opt) {
+    out << "?";
+  }
+
+  if (!arg.name().empty()) {
+    out << " " << arg.name();
+  }
+
+  if (arg.default_value()) {
+    out << "=";
+    if ((type->kind() == c10::TypeKind::StringType ||
+        unopt_type->kind() == c10::TypeKind::StringType) &&
+        arg.default_value().value().isString()) {
+      printQuotedString(out, arg.default_value().value().toStringRef());
+    } else if (type->kind() == TypeKind::ListType && type->castRaw<ListType>()->getElementType()->kind() == c10::TypeKind::IntType) {
+      // We want to faithfully replicate JIT schema.
+      // in native_functions.yaml defaults for int arrays with a single value always look like
+      //   int[2] stride=1
+      // instead of
+      //   int[2] stride=[1, 1]
+      auto default_val = arg.default_value().value().toIntList();
+      if (default_val.size() > 1) {
+        auto all_defaults_the_same = true;
+        for (const auto i : c10::irange(1, default_val.size())) {
+          if (default_val[0] != default_val[i]) all_defaults_the_same = false;
+        }
+        if (all_defaults_the_same) {
+          out << default_val[0];
+        } else {
+          out << arg.default_value().value();
+        }
+      } else {
+        out << arg.default_value().value();
+      }
+    } else {
+      out << arg.default_value().value();
+    }
+  }
+
+  return out;
+}
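+
+// Illustrative examples of the format produced above (hypothetical arguments):
+//
+//   Tensor self        -> plain tensor argument
+//   Tensor?(a!) out    -> optional, alias-annotated tensor argument
+//   int[2] stride=1    -> sized int list whose identical defaults collapse to
+//                         a single value, as noted in the comment above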
+
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
+
+inline std::string toString(const FunctionSchema& schema) {
+  std::ostringstream str;
+  str << schema;
+  return str.str();
+}
+
+} // namespace c10
+
+namespace std {
+template<>
+  struct hash<c10::SchemaArgument> {
+    size_t operator()(const c10::SchemaArgument& arg) const
+    {
+      return c10::hash_combine(std::hash<size_t>()(arg.index), std::hash<size_t>()(static_cast<size_t>(arg.type)));
+    }
+  };
+template<>
+  struct hash<c10::Argument> {
+    size_t operator()(const c10::Argument& arg) const
+    {
+      auto hash = std::hash<std::string>{}(arg.name());
+      auto type_hash = std::hash<c10::TypePtr>{}(arg.type());
+      auto kwarg_only_hash = std::hash<bool>{}(arg.kwarg_only());
+      hash = c10::hash_combine(hash, type_hash);
+      hash = c10::hash_combine(hash, kwarg_only_hash);
+      // hashing optional fields if they exist
+      if (arg.default_value()) {
+        auto default_value_hash = c10::hash<c10::IValue>{}(arg.default_value().value());
+        hash = c10::hash_combine(hash, default_value_hash);
+      }
+      if (arg.N()) {
+        auto N_hash = std::hash{}(*arg.N());
+        hash = c10::hash_combine(hash, N_hash);
+      }
+      if (arg.alias_info()) {
+        auto alias_info_hash = std::hash{}(*arg.alias_info());
+        hash = c10::hash_combine(hash, alias_info_hash);
+      }
+      return hash;
+    }
+  };
+template<>
+  struct hash<c10::FunctionSchema> {
+    size_t operator()(const c10::FunctionSchema& schema) const
+    {
+      auto hash = std::hash<c10::OperatorName>{}(schema.operator_name());
+      auto args_hash = c10::hash<std::vector<c10::Argument>>{}(schema.arguments());
+      auto returns_hash = c10::hash<std::vector<c10::Argument>>{}(schema.returns());
+      auto is_vararg_hash = std::hash<bool>{}(schema.is_vararg());
+      auto is_varret_hash = std::hash<bool>{}(schema.is_varret());
+      hash = c10::hash_combine(hash, args_hash);
+      hash = c10::hash_combine(hash, returns_hash);
+      hash = c10::hash_combine(hash, is_vararg_hash);
+      hash = c10::hash_combine(hash, is_varret_hash);
+      return hash;
+    }
+  };
+} // namespace std
+
+
+#include <ATen/core/function_schema_inl.h>  // IWYU pragma: keep
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..34edfa3e7750ae47f7abf5094dffaa63237f6c6d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/function_schema_inl.h
@@ -0,0 +1,483 @@
+#pragma once
+#include <ostream>
+#include <sstream>
+
+// note: the Windows build doesn't find symbols in operator files unless
+// this is a header file
+
+namespace c10 {
+
+inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
+  // eventually this should look almost identical to python arg parser, but
+  // it is simpler for now to work directly on this schema
+
+  out << schema.name();
+  if (!schema.overload_name().empty()) {
+    out << "." << schema.overload_name();
+  }
+  out << "(";
+
+  bool seen_kwarg_only = false;
+  for (const auto i : c10::irange(schema.arguments().size())) {
+    if (i > 0) out << ", ";
+    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
+      out << "*, ";
+      seen_kwarg_only = true;
+    }
+    out << schema.arguments()[i];
+  }
+
+  if(schema.is_vararg()) {
+    if(!schema.arguments().empty())
+      out << ", ";
+    out << "...";
+  }
+
+  out << ") -> ";
+
+  const auto& returns = schema.returns();
+
+  /*
+   * We should skip parentheses if we return a single item and it's not varret,
+   * or if we return nothing but varret.
+   *
+   * We need special handling for the schema
+   *   aten::items.str(Dict(str, t) self) -> (str,t)[]
+   * Even though this schema returns a single item, we need to add parentheses.
+   * This is necessary so the printed schema can be parsed by the C++ SchemaParser.
+   * Without the extra parentheses, the parser sees the first parenthesis in '(str,t)' and mistakenly
+   * treats the return type as a tuple. An alternative would be to enhance the Lexer
+   * to look ahead multiple tokens to accurately decide whether the return type is
+   * a tuple.
+   */
+  bool need_paren = !(
+    (returns.size() == 1 && !schema.is_varret()) ||
+    (returns.empty() && schema.is_varret()));
+
+  if (returns.size() == 1 && !schema.is_varret()) {
+    std::stringstream return_ss;
+    return_ss << returns.at(0);
+    auto return_str = return_ss.str();
+
+    // Enclose the single return item in parentheses if the return type
+    // starts with a left parenthesis.
+    //
+    // There are 2 cases:
+    // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
+    //    Without the extra parentheses, the C++ schema parser cannot parse it.
+    // 2. something like '-> ((str, str))'. Extra parentheses are needed so the return
+    //    type is a single tuple rather than two strings.
+    // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
+    // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15)
+    // also covers this case.
+    if (!return_str.empty() && return_str.front() == '(') {
+      need_paren = true;
+    }
+  }
+
+  if (need_paren) {
+    out << "(";
+  }
+  for (const auto i : c10::irange(returns.size())) {
+    if (i > 0) {
+      out << ", ";
+    }
+    out << returns.at(i);
+  }
+  if (schema.is_varret()) {
+    if (!returns.empty()) {
+      out << ", ";
+    }
+    out << "...";
+  }
+  if (need_paren) {
+    out << ")";
+  }
+  return out;
+}
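+// Illustrative sketch of the resulting schema strings (derived from the code
+// above; the operators shown are just examples):
+//   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+//   aten::items.str(Dict(str, t) self) -> ((str, t)[])   // extra parens, per the comment above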
+
+inline size_t findFirstOutArg(const std::vector<Argument>& args) {
+  // find the start of out args in the schema
+  for (const auto out_start_idx : c10::irange(args.size())) {
+    if (args.at(out_start_idx).is_out()) {
+      return out_start_idx;
+    }
+  }
+  return args.size();
+}
+
+inline bool Argument::isBackwardCompatibleWith(
+      const Argument& old,
+      std::ostream* why_not) const {
+    const Argument* lhs = this;
+    const Argument* rhs = &old;
+    if (!(lhs->name() == rhs->name()
+        && lhs->N() == rhs->N()
+          && (lhs->alias_info() == rhs->alias_info()
+              || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
+                  && *lhs->alias_info() == *rhs->alias_info())))) {
+      return false;
+    }
+    if (lhs->kwarg_only() && !rhs->kwarg_only()) {
+      return false;
+    }
+    if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
+      return false;
+    }
+    if (rhs->default_value().has_value() &&
+        lhs->default_value() != rhs->default_value()) {
+      return false;
+    }
+    return true;
+}
+
+inline bool Argument::isForwardCompatibleWith(
+    const Argument& old,
+    std::ostream* why_not) const {
+  const Argument* lhs = this;
+  const Argument* rhs = &old;
+  if (!(lhs->name() == rhs->name()
+      && lhs->N() == rhs->N()
+        && (lhs->alias_info() == rhs->alias_info()
+            || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
+                && *lhs->alias_info() == *rhs->alias_info())))) {
+    return false;
+  }
+  if (lhs->kwarg_only() && !rhs->kwarg_only()) {
+    return false;
+  }
+  if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
+    return false;
+  }
+  if (rhs->default_value().has_value() &&
+      lhs->default_value() != rhs->default_value()) {
+    return false;
+  }
+  if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
+    return false;
+  }
+  return true;
+}
+
+inline std::string FunctionSchema::formatTypeMismatchMsg(
+    const Argument& expected,
+    const std::string& actual_type,
+    c10::optional<size_t> position,
+    c10::optional<std::string> value) const {
+  std::string position_str;
+  if (position) {
+    position_str = c10::str("Position: ", *position, "\n");
+  }
+  std::string value_str;
+  if (value) {
+    value_str = c10::str("Value: ", *value, "\n");
+  }
+  return c10::str(
+      name(),
+      "() ",
+      expected.formatTypeMismatchMsg(actual_type),
+      position_str,
+      value_str,
+      "Declaration: ",
+      *this);
+}
+
+inline bool FunctionSchema::isBackwardCompatibleWith(
+    const FunctionSchema& old,
+    std::ostream* why_not) const {
+  if (!(name() == old.name()
+        && overload_name() == old.overload_name()
+        // we are conservative on is_vararg and is_varret,
+        // since they are only used by internal operators
+        && is_vararg() == old.is_vararg()
+        && is_varret() == old.is_varret()
+        && returns().size() == old.returns().size()
+        && arguments().size() >= old.arguments().size())) {
+    return false;
+  }
+  for (const auto i : c10::irange(returns().size())) {
+    // Backwards compatibility requires covariance on argument types
+    // (i.e. more generic), and contravariance on return types (i.e.
+    //  more specific).
+    if (!old.returns().at(i).isBackwardCompatibleWith(
+          returns().at(i),
+          why_not)) {
+      return false;
+    }
+  }
+
+  // we want to test both out and default args separately
+  size_t old_out_start_idx = findFirstOutArg(old.arguments());
+  size_t new_out_start_idx = findFirstOutArg(arguments());
+
+  // make sure the (non-out) default args are backward compatible
+  for (const auto i : c10::irange(old_out_start_idx)) {
+    if (!arguments().at(i).isBackwardCompatibleWith(
+          old.arguments().at(i), why_not)) {
+      return false;
+    }
+  }
+
+  // Validate that every newly added argument provides a default value
+  for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
+    if (!arguments().at(i).default_value()) {
+      if (why_not) {
+        *why_not
+            << "Function schema not backward compatible since the new argument '"
+            << arguments().at(i).name() << "' of type "
+            << arguments().at(i).type()->str()
+            << " did not provide a default value.";
+      }
+      return false;
+    }
+  }
+
+  // now compare the out args
+  for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) {
+    if (!arguments()
+             .at(i - old_out_start_idx + new_out_start_idx)
+             .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
+      return false;
+    }
+  }
+
+  return true;
+}
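+// Example (editorial sketch) of what the check above allows: adding a new
+// trailing non-out argument is backward compatible only if it has a default,
+// e.g. old "foo(Tensor self) -> Tensor" vs. new
+// "foo(Tensor self, int alpha=1) -> Tensor" passes, while a new argument
+// without a default, or a changed number of returns, fails.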
+
+inline bool FunctionSchema::isForwardCompatibleWith(
+    const FunctionSchema& old,
+    std::ostringstream& why_not) const {
+  if (!(name() == old.name() &&
+        overload_name() == old.overload_name()
+        // we are conservative on is_vararg and is_varret,
+        // since they are only used by internal operators
+        && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
+        returns().size() == old.returns().size())) {
+    return false;
+  }
+
+  // we want to test both out and default args separately
+  size_t old_out_start_idx = findFirstOutArg(old.arguments());
+  size_t new_out_start_idx = findFirstOutArg(arguments());
+
+  if (old.arguments().size() - old_out_start_idx !=
+      arguments().size() - new_out_start_idx) {
+    if (why_not) {
+      why_not << "Function schema should have the "
+              << "same number of out arguments";
+    }
+    return false;
+  }
+
+  // make sure the (non-out) default args are forward compatible
+  for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
+    if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
+      if (why_not) {
+        why_not
+            << "'" << arguments().at(i).name() << "'"
+            << " is not forward compatible with the older version of the schema";
+      }
+      return false;
+    }
+  }
+
+  // Validate that every newly added argument provides a default value
+  for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
+    if (!arguments().at(i).default_value()) {
+      if (why_not) {
+        why_not
+            << "Function schema is not forward compatible since the new argument '"
+            << arguments().at(i).name() << "' of type "
+            << arguments().at(i).type()->str()
+            << " did not provide a default value.";
+      }
+      return false;
+    }
+
+    auto default_val = arguments().at(i).default_value().value();
+    if (default_val.isList() || default_val.isGenericDict()) {
+      if (why_not) {
+        why_not
+            << "Function schema is not forward compatible since the new argument '"
+            << arguments().at(i).name() << "' of type "
+            << arguments().at(i).type()->str() << " has a container type "
+            << "as its default value.";
+      }
+      return false;
+    }
+  }
+
+  // now compare the out args
+  for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
+    if (!arguments()
+             .at(i - old_out_start_idx + new_out_start_idx)
+             .isForwardCompatibleWith(old.arguments().at(i))) {
+      if (why_not) {
+        why_not << "Out argument '"
+                << "'" << arguments().at(i).name()
+                << " is not FC with the older version of the schema";
+      }
+      return false;
+    }
+  }
+
+  return true;
+}
+
+template <typename T>
+inline void FunctionSchema::checkArg(
+    const IValue& value,
+    const Argument& argument,
+    optional<size_t> pos) const {
+  if (value.isTensor() && argument.type() == TensorType::get()) {
+    // Fast-path for the common case
+    return;
+  }
+  if (!value.type<T>()->isSubtypeOf(*argument.type())) {
+    TORCH_CHECK(
+        false,
+        formatTypeMismatchMsg(
+            argument, value.type<T>()->repr_str(), pos));
+  }
+}
+
+inline std::string FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
+  // First check if any of the kwargs are unknown, i.e. don't match the name of
+  // any argument in the schema.
+  for (const auto& kwarg : kwargs) {
+    if (!std::count_if(
+            arguments().begin(),
+            arguments().end(),
+            [&kwarg](const Argument& argument) {
+              return argument.name() == kwarg;
+            })) {
+      return c10::str(
+          "Unknown keyword argument '",
+          kwarg,
+          "' for operator '",
+          name(),
+          "'. Schema: ",
+          *this);
+    }
+  }
+  // If there are unconsumed kwargs but none of them were unknown, the first
+  // positional argument present in the kwargs is duplicated.
+  for (const auto& argument : arguments()) {
+    if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
+      AT_ASSERT(!argument.default_value());
+      return c10::str(
+          "Argument '",
+          argument.name(),
+          "' specified both as positional and ",
+          "keyword argument. Schema: ",
+          *this);
+    }
+  }
+  return "";
+}
+
+template <typename T>
+inline void FunctionSchema::checkAndNormalizeInputs(
+    std::vector<IValue>& inputs,
+    const std::unordered_map<std::string, IValue>& kwargs) const {
+  // Do we have more inputs than the schema accepts?
+  TORCH_CHECK(
+      inputs.size() <= arguments().size(),
+      "Expected at most ",
+      arguments().size(),
+      " argument(s) for operator '",
+      name(),
+      "', but received ",
+      inputs.size(),
+      " argument(s). Declaration: ",
+      *this);
+
+  size_t consumed_kwargs = 0;
+  for (const auto pos : c10::irange(arguments().size())) {
+    const auto& argument = arguments()[pos];
+    if (pos < inputs.size()) {
+      checkArg<T>(inputs[pos], argument, pos);
+      continue;
+    }
+    auto it = kwargs.find(argument.name());
+    if (it != kwargs.end()) {
+      checkArg<T>(it->second, argument, nullopt);
+      inputs.push_back(it->second);
+      consumed_kwargs++;
+      continue;
+    }
+    if (argument.default_value()) {
+      inputs.push_back(*argument.default_value());
+      continue;
+    }
+    AT_ERROR(
+        name(),
+        "() is missing value for argument '",
+        argument.name(),
+        "'. Declaration: ",
+        *this);
+  }
+  if (consumed_kwargs != kwargs.size()) {
+    std::vector<std::string> names;
+    names.reserve(kwargs.size());
+    for(const auto& k : kwargs) {
+      names.emplace_back(k.first);
+    }
+    throw std::runtime_error(findErrorInKwargs(names));
+  }
+}
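+// Editorial usage sketch (hypothetical schema "(Tensor self, int alpha=1)"):
+// after the call, 'stack' holds one IValue per declared argument, in
+// positional order, with kwargs and defaults filled in:
+//   std::vector<c10::IValue> stack = {self_tensor};
+//   std::unordered_map<std::string, c10::IValue> kwargs = {{"alpha", 2}};
+//   schema.checkAndNormalizeInputs(stack, kwargs);  // stack == {self_tensor, 2}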
+
+inline FunctionSchema FunctionSchema::cloneWithRemappedTypes(
+    const std::function<TypePtr(TypePtr)> type_map) const {
+  auto update_args = [&](const std::vector<Argument>& args) {
+    std::vector<Argument> new_args;
+    new_args.reserve(args.size());
+    for(const Argument& arg : args) {
+      new_args.emplace_back(arg.cloneWithType(type_map(arg.type())));
+    }
+    return new_args;
+  };
+  return FunctionSchema(
+      name(),
+      overload_name(),
+      update_args(arguments()),
+      update_args(returns()),
+      is_vararg(),
+      is_varret());
+}
+
+// covariant subtyping of list of Arguments
+inline bool isSubtypeOfList(
+    ArrayRef<Argument> child,
+    ArrayRef<Argument> parent,
+    std::ostream* why_not) {
+  if (child.size() != parent.size()) {
+    return false;
+  }
+  for (const auto i : c10::irange(child.size())) {
+    const Argument& c = child[i];
+    const Argument& p = parent[i];
+    if (c.name() != p.name()) {
+      return false;
+    }
+    if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline bool FunctionSchema::isSubtypeOf(
+    const FunctionSchema& rhs,
+    bool as_method,
+    std::ostream* why_not) const {
+  size_t start = as_method ? 1 : 0;
+  // functions are contravariant in arguments but covariant in returns
+  return isSubtypeOfList(
+             ArrayRef<Argument>(rhs.arguments()).slice(start),
+             ArrayRef<Argument>(arguments()).slice(start),
+             why_not) &&
+      isSubtypeOfList(returns(), rhs.returns(), why_not);
+}
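+// Editorial sketch of the variance rule above (hypothetical schemas):
+//   this: f(Scalar x) -> int   is a subtype of   rhs: f(int x) -> Scalar
+// because arguments are contravariant (int is a subtype of Scalar, so the
+// subtype accepts everything the parent accepts) and returns are covariant
+// (int can be used wherever a Scalar is expected); argument names must match
+// position-for-position.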
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/functional.h b/MLPY/Lib/site-packages/torch/include/ATen/core/functional.h
new file mode 100644
index 0000000000000000000000000000000000000000..20e2d60445fe938ceb25f41508ce9761b69c41c5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/functional.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <vector>
+
+namespace c10 {
+
+// The passed in function must take T by value (T), or by
+// const reference (const T&); taking T by non-const reference
+// will result in an error like:
+//
+//    error: no type named 'type' in 'class std::result_of<F(T)>'
+//
+// No explicit template parameters are required.
+
+// Overload for explicit function and ArrayRef
+template<typename F, typename T>
+inline auto fmap(const T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
+  std::vector<decltype(fn(*inputs.begin()))> r;
+  r.reserve(inputs.size());
+  for(const auto & input : inputs)
+    r.push_back(fn(input));
+  return r;
+}
+
+// C++ forbids taking an address of a constructor, so here's a workaround...
+// Overload for constructor (R) application
+template<typename R, typename T>
+inline std::vector<R> fmap(const T& inputs) {
+  std::vector<R> r;
+  r.reserve(inputs.size());
+  for(auto & input : inputs)
+    r.push_back(R(input));
+  return r;
+}
+
+template<typename F, typename T>
+inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
+  std::vector<T> r;
+  r.reserve(inputs.size());
+  for(auto & input : inputs) {
+    if (fn(input)) {
+      r.push_back(input);
+    }
+  }
+  return r;
+}
+
+template<typename F, typename T>
+inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
+  return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
+}
+
+} // namespace c10
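+// Usage sketch (editorial):
+//   std::vector<int> xs = {1, 2, 3, 4};
+//   auto doubled = c10::fmap(xs, [](int x) { return x * 2; });        // {2, 4, 6, 8}
+//   auto evens   = c10::filter(xs, [](int x) { return x % 2 == 0; }); // {2, 4}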
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/grad_mode.h b/MLPY/Lib/site-packages/torch/include/ATen/core/grad_mode.h
new file mode 100644
index 0000000000000000000000000000000000000000..5e7dc5b0ad1ca9ca11f325cb6c5985ffa9815efc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/grad_mode.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <c10/core/GradMode.h>
+#include <c10/macros/Macros.h>
+
+namespace at {
+  using GradMode = c10::GradMode;
+  using AutoGradMode = c10::AutoGradMode;
+  using NoGradGuard = c10::NoGradGuard;
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings.h b/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings.h
new file mode 100644
index 0000000000000000000000000000000000000000..ff02c53f7f52afa3dcde61e7a84af865387c4f63
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings.h
@@ -0,0 +1,358 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+namespace c10 {
+
+#define FORALL_NS_SYMBOLS(_)         \
+  _(namespaces, prim)                \
+  _(namespaces, prims)               \
+  _(namespaces, nvprims)             \
+  _(namespaces, aten)                \
+  _(namespaces, cuda)                \
+  _(namespaces, onnx)                \
+  _(namespaces, attr)                \
+  _(namespaces, scope)               \
+  _(namespaces, user)                \
+  _(namespaces, _caffe2)             \
+  _(namespaces, dimname)             \
+  _(namespaces, namespaces)          \
+  _(prim, Assign)                    \
+  _(prim, BroadcastingChunk)         \
+  _(prim, BroadcastSizes)            \
+  _(prim, ReductionSizes)            \
+  _(prim, Constant)                  \
+  _(prim, ChunkSizes)                \
+  _(prim, ConstantMKLDNNTensor)      \
+  _(prim, BroadcastMKLDNNTensors)    \
+  _(prim, MKLDNNGroup)               \
+  _(prim, MKLDNNHardSwish)           \
+  _(prim, MKLDNNHardSigmoid)         \
+  _(prim, MKLDNNHardTanh)            \
+  _(prim, MKLDNNClamp)               \
+  _(prim, StaticRuntimeCopyOuts)     \
+  _(prim, Drop)                      \
+  _(prim, Eval)                      \
+  _(prim, Expand) /* onnx */         \
+  _(prim, FusionGroup)               \
+  _(prim, CudaFusionGroup)           \
+  _(prim, CudaFusionGuard)           \
+  _(prim, oneDNNFusionGroup)         \
+  _(prim, oneDNNFusionGuard)         \
+  _(prim, FunctionalGraph)           \
+  _(prim, add_optional)              \
+  _(prim, view_copy)                 \
+  _(prim, permute_copy)              \
+  _(prim, reshape_copy)              \
+  _(prim, squeeze_copy)              \
+  _(prim, t_copy)                    \
+  _(prim, transpose_copy)            \
+  _(prim, unsqueeze_copy)            \
+  _(prim, flatten_copy)              \
+  _(prim, expand_copy)               \
+  _(prim, expand_as_copy)            \
+  _(prim, DifferentiableGraph)       \
+  _(prim, TensorExprGroup)           \
+  _(prim, TensorExprDynamicGroup)    \
+  _(prim, StaticSubgraph)            \
+  _(prim, If)                        \
+  _(prim, Jump) /* debug */          \
+  _(prim, JumpNZ) /* debug */        \
+  _(prim, JumpZ) /* debug */         \
+  _(prim, Load)                      \
+  _(prim, Loop)                      \
+  _(prim, Param)                     \
+  _(prim, PackPadded) /* onnx */     \
+  _(prim, PadPacked) /* onnx */      \
+  _(prim, Placeholder) /* debug */   \
+  _(prim, Print)                     \
+  _(prim, EmptyListLiteral)          \
+  _(prim, LegacyTypedConstructor)    \
+  _(prim, PythonOp)                  \
+  _(prim, IgnoredPythonOp)           \
+  _(prim, Reverse)                   \
+  _(prim, Return)                    \
+  _(prim, ReturnStmt)                \
+  _(prim, BreakStmt)                 \
+  _(prim, ContinueStmt)              \
+  _(prim, ComprehensionScope)        \
+  _(prim, Store)                     \
+  _(prim, AutogradZero)              \
+  _(prim, AutogradAnyNonZero)        \
+  _(prim, AutogradAllNonZero)        \
+  _(prim, AutogradAllZero)           \
+  _(prim, Starred)                   \
+  _(prim, TupleConstruct)            \
+  _(prim, TupleUnpack)               \
+  _(prim, TupleIndex)                \
+  _(prim, TupleSlice)                \
+  _(prim, ListConstruct)             \
+  _(prim, ListUnpack)                \
+  _(prim, DictConstruct)             \
+  _(prim, ModuleContainerIndex)      \
+  _(prim, EnumName)                  \
+  _(prim, EnumValue)                 \
+  _(prim, StringIndex)               \
+  _(prim, NumToTensor)               \
+  _(prim, Uninitialized)             \
+  _(prim, VarConcat)                 \
+  _(prim, VarStack)                  \
+  _(prim, With)                      \
+  _(prim, Enter)                     \
+  _(prim, Exit)                      \
+  _(prim, IfThenElse)                \
+  _(aten, Bool)                      \
+  _(aten, Int)                       \
+  _(aten, FloatImplicit)             \
+  _(aten, ComplexImplicit)           \
+  _(aten, IntImplicit)               \
+  _(aten, ScalarImplicit)            \
+  _(aten, Float)                     \
+  _(aten, Complex)                   \
+  _(aten, str)                       \
+  _(aten, Delete)                    \
+  _(prim, device)                    \
+  _(prim, dtype)                     \
+  _(prim, layout)                    \
+  _(prim, id)                        \
+  _(prim, requires_grad)             \
+  _(prim, MakeTestTensor) /* test */ \
+  _(prim, AutogradAdd)               \
+  _(prim, GradOf)                    \
+  _(aten, grad)                      \
+  _(aten, backward)                  \
+  _(prim, Guard)                     \
+  _(prim, BailOut)                   \
+  _(prim, TypeCheck)                 \
+  _(prim, RequiresGradCheck)         \
+  _(prim, FallbackGraph)             \
+  _(prim, FusedConcat)               \
+  _(prim, ConstantChunk)             \
+  _(prim, MMTreeReduce)              \
+  _(prim, MMBatchSide)               \
+  _(prim, list)                      \
+  _(prim, dict)                      \
+  _(prim, min)                       \
+  _(prim, max)                       \
+  _(prim, abs)                       \
+  _(aten, divmod)                    \
+  _(prim, zip)                       \
+  _(prim, enumerate)                 \
+  _(prim, range)                     \
+  _(prim, rangelist)                 \
+  _(prim, isinstance)                \
+  _(prim, tolist)                    \
+  _(prim, unchecked_cast)            \
+  _(aten, _grad_sum_to_size)         \
+  _(aten, _size_if_not_equal)        \
+  _(aten, _ncf_unsqueeze)            \
+  _(aten, warn)                      \
+  _(aten, sorted)                    \
+  _(aten, floordiv)                  \
+  _(aten, __range_length)            \
+  _(aten, __derive_index)            \
+  _(aten, __round_to_zero_floordiv)  \
+  _(aten, is_scripting)              \
+  _(aten, _unwrap_optional)          \
+  _(prim, fork)                      \
+  _(prim, awaitable)                 \
+  _(prim, forkClosure)               \
+  _(prim, awaitableClosure)          \
+  _(prim, awaitable_nowait)          \
+  _(prim, awaitable_wait)            \
+  _(prim, RaiseException)            \
+  _(prim, Closure)                   \
+  _(prim, CreateObject)              \
+  _(prim, SetAttr)                   \
+  _(prim, GetAttr)                   \
+  _(prim, HasAttr)                   \
+  _(prim, profile)                   \
+  _(prim, profile_ivalue)            \
+  _(prim, AddStatValue)              \
+  _(prim, TimePoint)                 \
+  _(prim, CallFunction)              \
+  _(prim, CallMethod)                \
+  _(prim, LoopContinuation)          \
+  _(prim, annotate)                  \
+  _(prim, TracedModuleForward)       \
+  _(prim, TracedFork)                \
+  _(prim, TracedAttr)                \
+  _(prim, rpc_async)                 \
+  _(prim, rpc_sync)                  \
+  _(prim, rpc_remote)                \
+  _(prim, is_cuda)                   \
+  _(aten, append)                    \
+  _(aten, as_tensor)                 \
+  _(aten, adaptive_avg_pool2d_backward) \
+  _(aten, dim)                       \
+  _(aten, format)                    \
+  _(aten, percentFormat)             \
+  _(aten, __not__)                   \
+  _(aten, __is__)                    \
+  _(aten, __isnot__)                 \
+  _(aten, _ger)                      \
+  _(aten, __getitem__)               \
+  _(aten, _set_item)                 \
+  _(aten, manual_seed)               \
+  _(aten, device)                    \
+  _(aten, hash)                      \
+  _(aten, len)                       \
+  _(aten, list)                      \
+  _(aten, dict)                      \
+  _(aten, wait)                      \
+  _(aten, save)                      \
+  _(aten, keys)                      \
+  _(aten, ord)                       \
+  _(aten, chr)                       \
+  _(aten, hex)                       \
+  _(aten, oct)                       \
+  _(aten, clear)                     \
+  _(aten, setdefault)                \
+  _(aten, bin)                       \
+  _(aten, pop)                       \
+  _(aten, insert)                    \
+  _(aten, tensor)                    \
+  _(prim, unchecked_unwrap_optional) \
+  _(aten, __contains__)              \
+  _(prim, BailoutTemplate)           \
+  _(prim, grad)                      \
+  _(cuda, _set_device)               \
+  _(cuda, set_stream)                \
+  _(cuda, _current_device)           \
+  _(cuda, synchronize)               \
+  _(aten, has_torch_function)        \
+  _(aten, is_autocast_enabled)       \
+  _(aten, is_autocast_cpu_enabled)   \
+  _(aten, is_autocast_xla_enabled)   \
+  FORALL_ATEN_BASE_SYMBOLS(_)        \
+  _(onnx, Add)                       \
+  _(onnx, Concat)                    \
+  _(onnx, Constant)                  \
+  _(onnx, ConstantFill)              \
+  _(onnx, Div)                       \
+  _(onnx, GRU)                       \
+  _(onnx, Gather)                    \
+  _(onnx, Gemm)                      \
+  _(onnx, LSTM)                      \
+  _(onnx, MatMul)                    \
+  _(onnx, Min)                       \
+  _(onnx, Max)                       \
+  _(onnx, Mul)                       \
+  _(onnx, Pow)                       \
+  _(onnx, RNN)                       \
+  _(onnx, Shape)                     \
+  _(onnx, Size)                      \
+  _(onnx, Slice)                     \
+  _(onnx, Softmax)                   \
+  _(onnx, Squeeze)                   \
+  _(onnx, Sub)                       \
+  _(onnx, Transpose)                 \
+  _(onnx, Unsqueeze)                 \
+  _(onnx, Loop)                      \
+  _(onnx, If)                        \
+  _(onnx, Reshape)                   \
+  _(onnx, Expand)                    \
+  _(onnx, Equal)                     \
+  _(onnx, Greater)                   \
+  _(onnx, GreaterOrEqual)            \
+  _(onnx, Less)                      \
+  _(onnx, LessOrEqual)               \
+  _(onnx, Not)                       \
+  _(aten, ATen)                      \
+  _(onnx, Split)                     \
+  _(onnx, ConstantOfShape)           \
+  _(onnx, Cast)                      \
+  _(onnx, Mod)                       \
+  _(onnx, Sqrt)                      \
+  _(onnx, SplitToSequence)           \
+  _(onnx, SequenceAt)                \
+  _(onnx, SequenceConstruct)         \
+  _(onnx, SequenceEmpty)             \
+  _(onnx, SequenceInsert)            \
+  _(onnx, SequenceErase)             \
+  _(onnx, ConcatFromSequence)        \
+  _(onnx, Identity)                  \
+  _(onnx, SoftmaxCrossEntropyLoss)   \
+  _(onnx, NegativeLogLikelihoodLoss) \
+  _(onnx, LogSoftmax)                \
+  _(onnx, ReduceL1)                  \
+  _(onnx, ReduceL2)                  \
+  _(onnx, Conv)                      \
+  _(onnx, BatchNormalization)        \
+  _(onnx, ReduceMean)                \
+  _(onnx, ReduceProd)                \
+  _(onnx, Relu)                      \
+  _(onnx, Neg)                       \
+  _(onnx, NonZero)                   \
+  _(onnx, Range)                     \
+  _(onnx, Tile)                      \
+  _(onnx, Where)                     \
+  _(onnx, Optional)                  \
+  _(onnx, OptionalGetElement)        \
+  _(onnx, OptionalHasElement)        \
+  FORALL_ATTR_BASE_SYMBOLS(_)        \
+  _(attr, Subgraph)                  \
+  _(attr, ReverseSubgraph)           \
+  _(attr, f_real_outputs)            \
+  _(attr, df_input_vjps)             \
+  _(attr, df_input_captured_inputs)  \
+  _(attr, df_input_captured_outputs) \
+  _(attr, df_output_vjps)            \
+  _(attr, axes)                      \
+  _(attr, symbolic_shape_inputs)     \
+  _(attr, allow_stack_outputs)       \
+  _(attr, striding_inputs_desc)      \
+  _(attr, striding_outputs_desc)     \
+  _(attr, broadcast)                 \
+  _(attr, direction)                 \
+  _(attr, ends)                      \
+  _(attr, inplace)                   \
+  _(attr, input_as_shape)            \
+  _(attr, is_zero)                   \
+  _(attr, num_none)                  \
+  _(attr, num_present)               \
+  _(attr, perm)                      \
+  _(attr, starts)                    \
+  _(attr, profiled_type)             \
+  _(attr, transA)                    \
+  _(attr, transB)                    \
+  _(attr, name)                      \
+  _(attr, module)                    \
+  _(attr, beg)                       \
+  _(attr, idx)                       \
+  _(attr, split)                     \
+  _(attr, slot)                      \
+  _(attr, kinds)                     \
+  _(attr, types)                     \
+  _(attr, scope)                     \
+  _(attr, keepdims)                  \
+  _(attr, cache_id)                  \
+  _(attr, new_axis)                  \
+  _(attr, warn_id)                   \
+  _(attr, output_layouts)            \
+  _(attr, allowzero)                 \
+  _(attr, seen_none)                 \
+  _(attr, overload_name)             \
+  _(attr, node_stack_idx)
+
+enum class _keys : unique_t {
+    #define DEFINE_KEY(ns, s) ns##_##s,
+    FORALL_NS_SYMBOLS(DEFINE_KEY)
+    #undef DEFINE_KEY
+    num_symbols
+};
+
+#define DEFINE_SYMBOL(ns, s) \
+  namespace ns { constexpr Symbol s(static_cast<unique_t>(_keys::ns##_##s)); }
+FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
+#undef DEFINE_SYMBOL
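+// Editorial example: the expansion above defines constants such as
+// prim::Constant, aten::add, and attr::name, each a c10::Symbol whose
+// underlying value is the matching _keys enumerator, e.g.
+//   constexpr Symbol k = prim::Constant;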
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h b/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee651b41e66729816ce68b20b365689714aab086
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/interned_strings_class.h
@@ -0,0 +1,34 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+struct TORCH_API InternedStrings {
+  InternedStrings();
+  Symbol symbol(const std::string& s);
+  std::pair<const char*, const char*> string(Symbol sym);
+  Symbol ns(Symbol sym);
+
+ private:
+  // prereq - holding mutex_
+  Symbol _symbol(const std::string& s);
+  std::pair<const char*, const char*> customString(Symbol sym);
+  std::unordered_map<std::string, Symbol> string_to_sym_;
+
+  struct SymbolInfo {
+    Symbol ns;
+    std::string qual_name;
+    std::string unqual_name;
+  };
+  std::vector<SymbolInfo> sym_to_info_;
+
+  std::mutex mutex_;
+};
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue.h
new file mode 100644
index 0000000000000000000000000000000000000000..3cae44fc00dbc6c6a0f814e641b47781a11a7fa1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue.h
@@ -0,0 +1,1555 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace torch {
+class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
+namespace jit {
+using ::torch::CustomClassHolder;
+struct Function;
+struct CompilationUnit;
+struct Module;
+} // namespace jit
+} // namespace torch
+namespace c10 {
+template <class Key, class Value>
+class Dict;
+template <class T>
+class List;
+template <class T>
+class IListRef;
+struct IValue;
+struct ClassType;
+struct Type;
+class RRefInterface;
+
+struct ClassType;
+using ClassTypePtr = std::shared_ptr<ClassType>;
+
+TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs);
+
+TORCH_API torch::jit::Function* checkObjectSortSchema(
+    const c10::ClassTypePtr& t,
+    std::stringstream& why_not);
+
+// A comparator that checks ordering of two IValues of same type.
+typedef std::function<bool(const IValue& a, const IValue& b)> IValueComparator;
+
+TORCH_API IValueComparator getLessThanComparator(const IValue& v);
+TORCH_API IValueComparator getGreaterThanComparator(const IValue& v);
+
+namespace ivalue {
+struct Tuple;
+struct Future;
+struct Await;
+struct ConstantString;
+struct GenericDict;
+struct Object;
+struct PyObjectHolder;
+struct EnumHolder;
+// We need a ComplexHolder because currently the payloads in the Union
+// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
+// to fit in the IValue directly, we indirect complex numbers through an
+// intrusive pointer to ComplexHolder (which contains a c10::complex<double>).
+struct ComplexHolder : c10::intrusive_ptr_target {
+ public:
+  template <typename T>
+  ComplexHolder(c10::complex<T> c) {
+    val = convert<decltype(val), c10::complex<T>>(c);
+  }
+  ComplexHolder() = default;
+  c10::complex<double> val;
+};
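+// Editorial sketch: constructing an IValue from a c10::complex<double> boxes
+// it through a ComplexHolder behind an intrusive_ptr, since the 8-byte
+// payload cannot hold 128 bits directly, e.g.
+//   c10::IValue iv(c10::complex<double>(1.0, 2.0));
+//   auto z = iv.toComplexDouble();  // c10::complex<double>(1.0, 2.0)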
+
+// Similar to ComplexHolder, for StreamData3
+struct StreamData3Holder : c10::intrusive_ptr_target {
+ public:
+  StreamData3Holder(struct c10::StreamData3 d) : val(d) {}
+  StreamData3Holder() = delete;
+  struct c10::StreamData3 val;
+};
+
+} // namespace ivalue
+
+// This is an owning wrapper for a c10::optional<std::vector<T>>
+// that can be implicitly converted to a (non-owning) optional<ArrayRef<T>>.
+// Its purpose is to be used in generated code to keep the vector alive
+// either until the end of a statement (as a temporary), or as a saved arg
+// in autograd.
+template <typename T>
+struct OptionalArray {
+  c10::optional<std::vector<T>> list;
+
+  OptionalArray() = default;
+  OptionalArray(std::vector<T> val) : list(std::move(val)) {}
+
+  // Used when saving an argument for the backwards pass.
+  OptionalArray& operator=(c10::optional<ArrayRef<T>> ref) {
+    if (ref) {
+      list = std::vector<T>(ref->begin(), ref->end());
+    } else {
+      list = nullopt;
+    }
+    return *this;
+  }
+
+  // Used when saving an argument for the backwards pass.
+  OptionalArray& operator=(c10::OptionalArrayRef<T> ref) {
+    if (ref) {
+      list = std::vector<T>(ref->begin(), ref->end());
+    } else {
+      list = nullopt;
+    }
+    return *this;
+  }
+
+  operator c10::optional<c10::ArrayRef<T>>() {
+    if (!list) {
+      return nullopt;
+    }
+    return *list;
+  }
+
+  operator c10::OptionalArrayRef<T>() {
+    if (!list) {
+      return nullopt;
+    }
+    return *list;
+  }
+};
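+// Editorial usage sketch: OptionalArray keeps an optional list alive while
+// still converting to the non-owning view types, e.g.
+//   c10::OptionalArray<int64_t> sizes;
+//   sizes = c10::OptionalArrayRef<int64_t>({2, 3});  // copied into the owned vector
+//   c10::OptionalArrayRef<int64_t> view = sizes;     // non-owning view of the copy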
+
+// Capsule is an internal implementation detail of custom C++ classes. We
+// define it as an owning wrapper for
+// c10::intrusive_ptr<torch::CustomClassHolder>. This wrapper serves as
+// an abstraction over the type-erased custom class object pointer. It also
+// allows pybind11 to treat this as a standalone class to register as a
+// separate type caster, instead of as a custom pointer holder, which the
+// pointer-holder type caster would try to "unwrap" automatically.
+struct Capsule {
+  c10::intrusive_ptr<torch::CustomClassHolder> obj_ptr;
+  explicit Capsule(c10::intrusive_ptr<torch::CustomClassHolder> ptr)
+      : obj_ptr(std::move(ptr)) {}
+};
+
+// IValue is the generic tagged union used by the interpreter to hold
+// all value types.
+// It is a 16-byte object with an 8-byte payload and an 8-byte tag.
+// The tag is currently 4 bytes to determine the type, and 1 byte
+// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
+// retain/release calls.
+
+#define TORCH_FORALL_TAGS(_) \
+  _(None)                    \
+  _(Tensor)                  \
+  _(Storage)                 \
+  _(Double)                  \
+  _(ComplexDouble)           \
+  _(Int)                     \
+  _(SymInt)                  \
+  _(SymFloat)                \
+  _(SymBool)                 \
+  _(Bool)                    \
+  _(Tuple)                   \
+  _(String)                  \
+  _(Blob)                    \
+  _(GenericList)             \
+  _(GenericDict)             \
+  _(Future)                  \
+  _(Await)                   \
+  _(Device)                  \
+  _(Stream)                  \
+  _(Object)                  \
+  _(PyObject)                \
+  _(Uninitialized)           \
+  _(Capsule)                 \
+  _(RRef)                    \
+  _(Quantizer)               \
+  _(Generator)               \
+  _(Enum)
+
+// [doxygen private]
+// These methods are not actually private but we don't want to document them, so
+// they are marked `@private`, which hides them on the doxygen documentation for
+// this page.
+
+/// IValue (Interpreter Value) is a tagged union over the types
+/// supported by the TorchScript interpreter. IValues contain their
+/// values as an `IValue::Payload`, which holds primitive types
+/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
+/// and all other types as a `c10::intrusive_ptr`. In order to
+/// optimize performance of the destructor and related operations by
+/// making the `Tensor` and `c10::intrusive_ptr` paths generate the
+/// same code, we represent a null `c10::intrusive_ptr` as
+/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
+///
+/// IValues are used as inputs to and outputs from the TorchScript interpreter.
+/// To retrieve the value contained within an IValue, use the `.toX()` methods,
+/// where `X` is the type you are trying to get. Note that neither the `.toX()`
+/// methods nor the templated `.to` functions do any kind of casting, they
+/// only unwrap the contained value. For example:
+///
+/// \rst
+/// .. code-block:: cpp
+///
+///   // Make the IValue
+///   torch::IValue my_ivalue(26);
+///   std::cout << my_ivalue << "\n";
+///
+///   // Unwrap the IValue
+///   int64_t my_int = my_ivalue.toInt();
+///   std::cout << my_int << "\n";
+///
+///   // This will throw an error!
+///   // `my_ivalue` is tagged as an int and cannot be used as another type
+///   torch::Tensor my_tensor = my_ivalue.toTensor();
+/// \endrst
+struct TORCH_API IValue final {
+  IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag) {
+    if (isIntrusivePtr() &&
+        payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+      c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
+    }
+  }
+
+  IValue(IValue&& rhs) noexcept : tag(rhs.tag) {
+    moveFrom(std::move(rhs));
+  }
+
+  /// @private [doxygen private]
+  ~IValue() {
+    destroy();
+  }
+
+  C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
+    if (&rhs == this) {
+      return *this;
+    }
+
+    destroy();
+    moveFrom(std::move(rhs));
+    return *this;
+  }
+
+  IValue& operator=(IValue const& rhs) & {
+    *this = IValue(rhs);
+    return *this;
+  }
+
+  void dump() const;
+
+  /**
+   * Equality comparison. The semantics are the same as Python's `==`:
+   * 1. Numerical types are compared by value.
+   * 2. Tensors compute element-wise equality, returning a BoolTensor (see:
+   * `torch.eq()`)
+   * 3. Strings are compared by value.
+   * 4. Sequence types (list, tuple) are compared lexicographically by
+   *    comparing their elements. Different sequence types never compare equal.
+   * 5. Mappings (dict) must have equal (key, value) pairs.
+   * 6. If not listed above, the default behavior is to test identity
+   * equality (e.g. pointer equality).
+   *
+   * Why does this return an IValue instead of a bool? Because in PyTorch,
+   * `tensor1 == tensor2` returns a `BoolTensor`, not a bool.
+   *
+   * NOTE: we (like Python) assume that identity equality implies value equality
+   * for efficiency.
+   * TODO: need to support customizing equality
+   */
+  IValue equals(const IValue& rhs) const;
+  /**
+   * This implements the same semantics as `bool(lhs == rhs)` in Python, which
+   * is the same as `equals()` except for Tensor types.
+   */
+  TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs);
+  TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs);
+
+  /**
+   * Identity comparison. Checks if `this` is the same object as `rhs`. The
+   * semantics are the same as Python's `is` operator.
+   *
+   * NOTE: Like in Python, this operation is poorly defined for primitive types
+   * like numbers and strings. Prefer to use `==` unless you really want to
+   * check identity equality.
+   */
+  bool is(const IValue& rhs) const;
+
+  /**
+   * Hashing for IValues. Returns an IValue-boxed int.
+   *
+   * Some notes:
+   * - Like eager, Tensors are hashed by looking at the pointer. This is not
+   *   strictly correct because two value-equal tensors with different tensor
+   *   pointers will hash differently, but we choose to reproduce the eager
+   *   semantics.
+   * - Hashing is not defined on all built-in IValue types (e.g. list and
+   *   dict), following Python. Calling `hash()` on these types will throw.
+   */
+  IValue hash() const {
+    return (int64_t)IValue::hash(*this);
+  }
+  // This is defined because `c10::hash` dispatches to a function of this
+  // signature. See the member function `hash()`.
+  static size_t hash(const IValue& iv);
+
+  /**
+   * @private [doxygen private]
+   * [container equality]
+   * This is an equality implementation that assumes objects with the same
+   * identity equal themselves, for efficiency reasons. We primarily have this
+   * for consistency, because Python does the same thing. This actually
+   * provokes user-visible changes in behavior due to quirks in torch:
+   *      [tensor1] == [tensor1] -> True (because container equality will first
+   *        compare identity)
+   *      [tensor1] == [tensor1_copy] -> RuntimeError:
+   *        Boolean value of Tensor with more than one value is ambiguous
+   */
+  TORCH_API friend bool _fastEqualsForContainer(
+      const IValue& lhs,
+      const IValue& rhs);
+
+ private:
+  static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) {
+    if (a.is_sparse()) {
+      return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
+    }
+    if (b.is_sparse()) {
+      return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
+    }
+    if (a.is_sparse_csr()) {
+      return isAliasOf(a.values(), b) || isAliasOf(a.crow_indices(), b) ||
+          isAliasOf(a.col_indices(), b);
+    }
+    if (b.is_sparse_csr()) {
+      return isAliasOf(a, b.values()) || isAliasOf(a, b.crow_indices()) ||
+          isAliasOf(a, b.col_indices());
+    }
+
+    // Opaque tensors such as the ones constructed by the MKL-DNN backend
+    // don't have storage so we just compare their TensorImpls.
+    // TODO: Find way to expose alias info for opaque tensors.
+    if (!a.has_storage() || !b.has_storage()) {
+      return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
+    }
+
+    return a.is_alias_of(b);
+  }
+
+  template <typename T>
+  bool isListOf() const;
+
+ public:
+  /// @private [doxygen private]
+  bool isAliasOf(const IValue& rhs) const {
+    if (this->tag != rhs.tag) {
+      // Trivially don't alias if the type is different
+      return false;
+    }
+
+    // Tensors should be compared based on internal storage
+    if (this->isTensor()) {
+      return isAliasOf(this->toTensor(), rhs.toTensor());
+    }
+
+    if (!isIntrusivePtr()) {
+      // Primitive types don't alias anything
+      return false;
+    }
+
+    AT_ASSERT(rhs.isIntrusivePtr());
+
+    // Other types can be compared by their ptr value
+    return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
+  }
+
+  /// @private [doxygen private]
+  size_t use_count() const noexcept {
+    if (isTensor()) {
+      return payload.as_tensor.use_count();
+    }
+
+    if (!isIntrusivePtrLegacyBehavior()) {
+      return 1;
+    }
+
+    if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
+      return 0;
+    }
+    return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
+  }
+
+  /// @private [doxygen private]
+  void swap(IValue& rhs) noexcept {
+    if (isTensor() && rhs.isTensor()) {
+      std::swap(payload.as_tensor, rhs.payload.as_tensor);
+    } else if (isTensor()) {
+      at::Tensor t = std::move(payload.as_tensor);
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // payload.as_tensor.~Tensor();
+      payload.u = rhs.payload.u;
+      new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
+    } else if (rhs.isTensor()) {
+      rhs.swap(*this);
+      return;
+    } else {
+      std::swap(payload.u, rhs.payload.u);
+    }
+    std::swap(tag, rhs.tag);
+  }
+
+  // Accessors for subtypes are arranged together below
+  // While some of these accessors could be generated through templates,
+  // we prefer to write them manually for clarity
+
+  IValue(at::TensorBase t) : tag(Tag::Tensor) {
+    new (&payload.as_tensor) at::Tensor(std::move(t));
+  }
+  bool isTensor() const {
+    return Tag::Tensor == tag;
+  }
+
+ private:
+  // Outlined error path so that toTensor() can be inlined.
+  [[noreturn]] void reportToTensorTypeError() const;
+
+ public:
+  at::Tensor toTensor() &&;
+  at::Tensor& toTensor() &;
+  const at::Tensor& toTensor() const&;
+  at::TensorImpl* unsafeToTensorImpl() const {
+    TORCH_INTERNAL_ASSERT(isTensor());
+    return payload.as_tensor.unsafeGetTensorImpl();
+  }
+
+  IValue(at::Storage s) : tag(Tag::Storage) {
+    payload.u.as_intrusive_ptr =
+        null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
+  }
+  bool isStorage() const {
+    return Tag::Storage == tag;
+  }
+  c10::Storage toStorage() &&;
+  c10::Storage toStorage() const&;
+
+  const IValue& toIValue() const {
+    return *this;
+  }
+  IValue& toIValue() {
+    return *this;
+  }
+
+  /// @private [doxygen private]
+  IValue(intrusive_ptr<caffe2::Blob> blob) : tag(Tag::Blob) {
+    // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
+    // and store it as a Tensor instead.
+    payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
+  }
+
+  /// @private [doxygen private]
+  bool isBlob() const {
+    return Tag::Blob == tag;
+  }
+
+  /// @private [doxygen private]
+  c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
+
+  /// @private [doxygen private]
+  c10::intrusive_ptr<caffe2::Blob> toBlob() const&;
+
+  // Capsule. No new callsites of these APIs should
+  // be introduced.
+  static inline IValue make_capsule(
+      intrusive_ptr<torch::CustomClassHolder> blob);
+  bool isCapsule() const {
+    return Tag::Capsule == tag;
+  }
+  c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
+  c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const&;
+
+  // Custom C++ classes
+  template <
+      typename T,
+      std::enable_if_t<
+          std::is_base_of<torch::CustomClassHolder, T>::value,
+          int> = 0>
+  IValue(intrusive_ptr<T> custom_class);
+  bool isCustomClass() const;
+  template <typename T>
+  c10::intrusive_ptr<T> toCustomClass() &&;
+  template <typename T>
+  c10::intrusive_ptr<T> toCustomClass() const&;
+
+  // Tuple
+  IValue(c10::intrusive_ptr<ivalue::Tuple> v);
+
+  template <
+      typename... Args,
+      std::enable_if_t<
+          !std::disjunction<
+              std::is_lvalue_reference<Args>...,
+              std::negation<std::is_constructible<IValue, Args>>...>::value,
+          std::nullptr_t> = nullptr>
+  IValue(const std::tuple<Args...>& t);
+  template <
+      typename... Args,
+      std::enable_if_t<
+          !std::disjunction<
+              std::is_lvalue_reference<Args>...,
+              std::negation<std::is_constructible<IValue, Args>>...>::value,
+          std::nullptr_t> = nullptr>
+  IValue(std::tuple<Args...>&& t);
+  bool isTuple() const {
+    return Tag::Tuple == tag;
+  }
+  c10::intrusive_ptr<ivalue::Tuple> toTuple() &&;
+  c10::intrusive_ptr<ivalue::Tuple> toTuple() const&;
+  C10_NODISCARD ivalue::Tuple& toTupleRef() const;
+
+  // Double
+  IValue(double d) : tag(Tag::Double) {
+    payload.u.as_double = d;
+  }
+  bool isDouble() const {
+    return Tag::Double == tag;
+  }
+  double toDouble() const {
+    AT_ASSERT(isDouble());
+    return payload.u.as_double;
+  }
+
+  // ComplexDouble
+  template <typename T>
+  IValue(c10::complex<T> c);
+  bool isComplexDouble() const {
+    return Tag::ComplexDouble == tag;
+  }
+  c10::complex<double> toComplexDouble() const;
+
+  // Future
+  IValue(c10::intrusive_ptr<ivalue::Future> v);
+  bool isFuture() const {
+    return Tag::Future == tag;
+  }
+  c10::intrusive_ptr<ivalue::Future> toFuture() &&;
+  c10::intrusive_ptr<ivalue::Future> toFuture() const&;
+
+  IValue(c10::intrusive_ptr<ivalue::Await> v);
+  bool isAwait() const {
+    return Tag::Await == tag;
+  }
+  c10::intrusive_ptr<ivalue::Await> toAwait() &&;
+  c10::intrusive_ptr<ivalue::Await> toAwait() const&;
+
+  // RRef
+  IValue(c10::intrusive_ptr<c10::RRefInterface> v);
+  bool isRRef() const {
+    return Tag::RRef == tag;
+  }
+  c10::intrusive_ptr<c10::RRefInterface> toRRef() &&;
+  c10::intrusive_ptr<c10::RRefInterface> toRRef() const&;
+
+  // Quantizer
+  IValue(c10::intrusive_ptr<at::Quantizer> v);
+  bool isQuantizer() const {
+    return Tag::Quantizer == tag;
+  }
+  c10::intrusive_ptr<at::Quantizer> toQuantizer() &&;
+  c10::intrusive_ptr<at::Quantizer> toQuantizer() const&;
+
+  // Int
+  IValue(int64_t i) : tag(Tag::Int) {
+    payload.u.as_int = i;
+  }
+
+  IValue(const c10::SymInt& i) {
+    if (auto mi = i.maybe_as_int()) {
+      tag = Tag::Int;
+      payload.u.as_int = *mi;
+    } else {
+      tag = Tag::SymInt;
+      payload.u.as_intrusive_ptr = i.toSymNode().release();
+    }
+  }
+
+  bool isSymInt() const {
+    return Tag::SymInt == tag;
+  }
+
+  c10::SymInt toSymInt() &&;
+  c10::SymInt toSymInt() const&;
+
+  IValue(const c10::SymFloat& i) {
+    if (i.is_symbolic()) {
+      tag = Tag::SymFloat;
+      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
+    } else {
+      tag = Tag::Double;
+      payload.u.as_double = i.as_float_unchecked();
+    }
+  }
+
+  bool isSymFloat() const {
+    return Tag::SymFloat == tag;
+  }
+
+  c10::SymFloat toSymFloat() &&;
+  c10::SymFloat toSymFloat() const&;
+
+  IValue(const c10::SymBool& i) {
+    if (auto mi = i.maybe_as_bool()) {
+      tag = Tag::Bool;
+      payload.u.as_int = *mi;
+    } else {
+      tag = Tag::SymBool;
+      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
+    }
+  }
+
+  bool isSymBool() const {
+    return Tag::SymBool == tag;
+  }
+
+  c10::SymBool toSymBool() &&;
+  c10::SymBool toSymBool() const&;
+
+  // allow you to pass literals (3, 4) without ambiguity
+  IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {}
+
+  bool isInt() const {
+    return Tag::Int == tag;
+  }
+
+  int64_t toInt() const {
+    AT_ASSERT(isInt());
+    return payload.u.as_int;
+  }
+
+  // Bool
+  IValue(bool b) : tag(Tag::Bool) {
+#if defined(__clang__) && defined(__x86_64__)
+    // Initializing the entire payload stops valgrind from reporting
+    // "jump or move depends on uninitialised value" in the IValue copy constructor.
+    // See https://github.com/pytorch/pytorch/issues/37117
+    payload.u.as_int = b;
+#else
+    payload.u.as_bool = b;
+#endif
+  }
+  bool isBool() const {
+    return Tag::Bool == tag;
+  }
+  bool toBool() const {
+    AT_ASSERT(isBool());
+    return payload.u.as_bool;
+  }
+
+  // IntList
+  bool isIntList() const;
+  bool isSymIntList() const;
+  c10::List<int64_t> toIntList() &&;
+  c10::List<int64_t> toIntList() const&;
+  std::vector<int64_t> toIntVector() const;
+  std::vector<c10::SymInt> toSymIntVector() const;
+  at::DimVector toDimVector() const;
+
+  // ConstantString
+  IValue(c10::intrusive_ptr<ivalue::ConstantString> v);
+  IValue(std::string v);
+  IValue(const char* v) : IValue(std::string(v)) {}
+  IValue(c10::string_view v) : IValue(std::string(v)){};
+  bool isString() const {
+    return Tag::String == tag;
+  }
+  c10::intrusive_ptr<ivalue::ConstantString> toString() &&;
+  c10::intrusive_ptr<ivalue::ConstantString> toString() const&;
+  const std::string& toStringRef() const;
+  c10::optional<std::reference_wrapper<const std::string>> toOptionalStringRef()
+      const;
+  c10::string_view toStringView() const;
+
+  // DoubleList
+  bool isDoubleList() const;
+  c10::List<double> toDoubleList() &&;
+  c10::List<double> toDoubleList() const&;
+  std::vector<double> toDoubleVector() const;
+
+  // ComplexDoubleList
+  bool isComplexDoubleList() const;
+  c10::List<c10::complex<double>> toComplexDoubleList() &&;
+  c10::List<c10::complex<double>> toComplexDoubleList() const&;
+  std::vector<c10::complex<double>> toComplexDoubleVector() const;
+
+  // BoolList
+  bool isBoolList() const;
+  c10::List<bool> toBoolList() &&;
+  c10::List<bool> toBoolList() const&;
+
+  // TensorList
+  bool isTensorList() const;
+  c10::List<at::Tensor> toTensorList() &&;
+  c10::List<at::Tensor> toTensorList() const&;
+  std::vector<at::Tensor> toTensorVector() const;
+
+  // OptionalTensorList
+  bool isOptionalTensorList() const;
+  c10::List<c10::optional<at::Tensor>> toOptionalTensorList() &&;
+  c10::List<c10::optional<at::Tensor>> toOptionalTensorList() const&;
+  std::vector<c10::optional<at::Tensor>> toOptionalTensorVector() const;
+
+  // GenericList
+  IValue(c10::List<IValue> v);
+  bool isList() const {
+    return Tag::GenericList == tag;
+  }
+  c10::List<IValue> toList() &&;
+  c10::List<IValue> toList() const&;
+  c10::ArrayRef<IValue> toListRef() const;
+
+  // Some template constructors of IValue calls another constructor recursively.
+  // This SFINAEs the called constructor exists.
+  template 
+  using enable_if_ivalue_constructible =
+      std::enable_if_t::value, std::nullptr_t>;
+
+  // The rule for lists is more complicated; the generic constructor is only
+  // acceptable if your element isn't SymInt.  If you do have a SymInt element,
+  // then you must also, at construction time, check if you can decay the list
+  // into an int list (this is MANDATORY, as at a use site we may expect
+  // toIntList to work even if at the call site you had a SymIntArrayRef
+  // argument).  In practice, only SymIntArrayRef is used this way, so we
+  // didn't bother making it work for the other constructors, we just make sure
+  // they're not selectable.
+  template <class T>
+  using enable_if_list_is_ivalue_constructible = std::enable_if_t<
+      std::is_constructible<IValue, T>::value &&
+          !std::is_same<T, c10::SymInt>::value,
+      std::nullptr_t>;
+
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(c10::List<T>&& v);
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(const c10::List<T>& v);
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(at::ArrayRef<T> v);
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(const std::vector<T>& v);
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(std::vector<T>&& v);
+  template <class T, size_t N>
+  IValue(std::array<T, N> v);
+
+  // Manual constructors for lists of symints, which decay to int list if
+  // possible.  To avoid ambiguous overload situations, we template them
+  // to prevent implicit conversions
+  template <class T>
+  using enable_if_symint =
+      std::enable_if_t<std::is_same<T, c10::SymInt>::value, std::nullptr_t>;
+
+  template <class T, enable_if_symint<T> = nullptr>
+  IValue(at::ArrayRef<T> v);
+  template <class T, enable_if_symint<T> = nullptr>
+  IValue(at::OptionalArrayRef<T> v);
+  template <class T, enable_if_symint<T> = nullptr>
+  IValue(const std::vector<T>& v);
+  template <class T, enable_if_symint<T> = nullptr>
+  IValue(std::vector<T>&& v);
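+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // As the comment above describes, constructing an IValue from a SymInt list
+  // whose elements all carry plain integer hints is expected to decay to an
+  // ordinary int list, so a later toIntList() at a use site still works; only
+  // genuinely symbolic elements keep the SymIntList representation.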
+
+
+  template <class T>
+  using enable_if_ilist_is_ivalue_constructible = std::enable_if_t<
+      std::is_constructible<IValue, T>::value &&
+          std::is_constructible<IValue, typename IListRef<T>::boxed_type>::
+              value &&
+          !std::is_same<T, c10::SymInt>::value,
+      std::nullptr_t>;
+
+  template <class T, enable_if_ilist_is_ivalue_constructible<T> = nullptr>
+  IValue(c10::IListRef<T> v);
+
+  // GenericDict
+  IValue(c10::Dict<IValue, IValue> v);
+  bool isGenericDict() const {
+    return Tag::GenericDict == tag;
+  }
+  c10::Dict<IValue, IValue> toGenericDict() &&;
+  c10::Dict<IValue, IValue> toGenericDict() const&;
+
+  template <class Key, class Value>
+  IValue(c10::Dict<Key, Value> v);
+
+  template <class Key, class Value>
+  /// \cond
+  /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN
+  C10_DEPRECATED_MESSAGE(
+      "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.")
+      /// \endcond
+      IValue(std::unordered_map<Key, Value> v);
+
+  template <class T, enable_if_ivalue_constructible<T> = nullptr>
+  IValue(c10::optional<T> v);
+  template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
+  IValue(c10::OptionalArrayRef<T> v);
+  IValue(c10::nullopt_t);
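+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // Sketch of the dict constructors above, using a typed c10::Dict:
+  //   c10::Dict<std::string, int64_t> d;
+  //   d.insert("answer", 42);
+  //   c10::IValue iv(d);                  // boxed as a GenericDict
+  //   auto generic = iv.toGenericDict();  // c10::Dict<IValue, IValue> view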
+
+  // ClassType
+  IValue(c10::intrusive_ptr<ivalue::Object> v);
+  bool isObject() const {
+    return tag == Tag::Object;
+  }
+  c10::intrusive_ptr<ivalue::Object> toObject() &&;
+  c10::intrusive_ptr<ivalue::Object> toObject() const&;
+  ivalue::Object& toObjectRef() const;
+
+  torch::jit::Module toModule() const;
+  bool isModule() const;
+
+  // PyObject
+  IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v);
+  bool isPyObject() const {
+    return tag == Tag::PyObject;
+  }
+  c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() &&;
+  c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const&;
+  PyObject* toPyObject() const;
+
+  // Enum
+  explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v);
+  bool isEnum() const {
+    return tag == Tag::Enum;
+  }
+  c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&;
+  c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;
+
+  // None
+  IValue() : tag(Tag::None) {}
+  bool isNone() const {
+    return Tag::None == tag;
+  }
+  std::string toNone() const {
+    AT_ASSERT(isNone());
+    return "None";
+  }
+
+  static IValue uninitialized() {
+    auto i = IValue();
+    i.tag = Tag::Uninitialized;
+    return i;
+  }
+
+  // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
+  IValue(const at::Scalar& s) : IValue() {
+    // NB: do the symbolic versions first, as isFloatingPoint is true
+    // for both SymFloat and double
+    if (s.isSymInt()) {
+      tag = Tag::SymInt;
+      payload.u.as_intrusive_ptr = s.toSymInt().toSymNode().release();
+    } else if (s.isSymFloat()) {
+      tag = Tag::SymFloat;
+      payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
+    } else if (s.isSymBool()) {
+      tag = Tag::SymBool;
+      payload.u.as_intrusive_ptr = s.toSymBool().toSymNodeImpl().release();
+    } else if (s.isFloatingPoint()) {
+      tag = Tag::Double;
+      payload.u.as_double = s.toDouble();
+    } else if (s.isComplex()) {
+      *this = s.toComplexDouble();
+    } else if (s.isBoolean()) {
+      tag = Tag::Bool;
+      payload.u.as_bool = s.toBool();
+    } else {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+          s.isIntegral(false), "Unknown type in Scalar");
+      tag = Tag::Int;
+      payload.u.as_int = s.toLong();
+    }
+  }
+
+  bool isScalar() const {
+    return isDouble() || isInt() || isComplexDouble() || isBool() ||
+        isSymInt() || isSymFloat() || isSymBool();
+  }
+
+  at::Scalar toScalar() const {
+    if (isDouble())
+      return toDouble();
+    else if (isInt())
+      return toInt();
+    else if (isComplexDouble())
+      return toComplexDouble();
+    else if (isBool())
+      return toBool();
+    else if (isSymInt())
+      return toSymInt();
+    else if (isSymFloat())
+      return toSymFloat();
+    else if (isSymBool())
+      return toSymBool();
+    throw std::runtime_error("IValue is not a Scalar");
+  }
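+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // Round-tripping a Scalar through IValue, e.g. for a double:
+  //   c10::IValue iv(at::Scalar(2.5));  // tagged as Double
+  //   TORCH_INTERNAL_ASSERT(iv.isScalar());
+  //   double d = iv.toScalar().toDouble();  // 2.5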
+
+  // Device
+  IValue(c10::Device d) : tag(Tag::Device) {
+    payload.u.as_device.type = d.type();
+    payload.u.as_device.index = d.index();
+  }
+  bool isDevice() const {
+    return Tag::Device == tag;
+  }
+  c10::Device toDevice() const {
+    AT_ASSERT(isDevice());
+    return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
+  }
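+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // Device is stored inline in the payload (type + index), e.g.:
+  //   c10::IValue dev(c10::Device(c10::DeviceType::CPU));
+  //   TORCH_INTERNAL_ASSERT(dev.isDevice() && dev.toDevice().is_cpu());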
+
+  // Stream
+  IValue(c10::Stream s) : tag(Tag::Stream) {
+    auto v = c10::make_intrusive<ivalue::StreamData3Holder>(s.pack3());
+    payload.u.as_intrusive_ptr = v.release();
+  }
+  c10::Stream toStream() &&;
+  c10::Stream toStream() const&;
+  bool isStream() const {
+    return Tag::Stream == tag;
+  }
+
+  // ScalarType
+  IValue(ScalarType t)
+      : IValue(static_cast<std::underlying_type<ScalarType>::type>(t)) {}
+  at::ScalarType toScalarType() const {
+    return static_cast<at::ScalarType>(toInt());
+  }
+
+  // Layout
+  IValue(Layout l)
+      : IValue(static_cast<std::underlying_type<Layout>::type>(l)) {}
+  at::Layout toLayout() const {
+    return static_cast<at::Layout>(toInt());
+  }
+
+  // MemoryFormat
+  IValue(MemoryFormat m)
+      : IValue(static_cast<std::underlying_type<MemoryFormat>::type>(m)) {}
+  at::MemoryFormat toMemoryFormat() const {
+    return static_cast<at::MemoryFormat>(toInt());
+  }
+
+  // QScheme
+  IValue(at::QScheme qscheme) : tag(Tag::Int) {
+    payload.u.as_int = static_cast<int64_t>(qscheme);
+  }
+
+  at::QScheme toQScheme() const {
+    return static_cast<at::QScheme>(toInt());
+  }
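+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // ScalarType, Layout, MemoryFormat and QScheme all reuse the Int tag, so:
+  //   c10::IValue st(at::kFloat);
+  //   st.isInt() == true;  st.toScalarType() == at::kFloat;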
+
+  // Dimname
+  IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {}
+
+  at::Dimname toDimname() const {
+    return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef()));
+  }
+
+  // Generator
+  IValue(at::Generator g) : tag(Tag::Generator) {
+    payload.u.as_intrusive_ptr =
+        null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
+  }
+  bool isGenerator() const {
+    return Tag::Generator == tag;
+  }
+  at::Generator toGenerator() &&;
+  at::Generator toGenerator() const&;
+
+  // for debugging
+  std::string tagKind() const {
+    switch (tag) {
+#define DEFINE_CASE(x) \
+  case Tag::x:         \
+    return #x;
+      TORCH_FORALL_TAGS(DEFINE_CASE)
+#undef DEFINE_CASE
+    }
+    return "InvalidTag(" + std::to_string(static_cast(tag)) + ")";
+  }
+
+  // generic v.to<at::Tensor>() implementations
+  // that can be used in special functions like pop/push
+  // that use template meta-programming.
+  // prefer the directly named methods when you can,
+  // since they are simpler to understand
+
+  // Note: if you get linker errors saying one of these is missing,
+  // change it to ... && = delete; and you will see better error messages for
+  // why. However, we cannot commit this because some compiler versions barf on
+  // it.
+  template <typename T>
+  T to() &&;
+  template <typename T>
+  typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to()
+      const&;
+
+  // ToOptional: convert an IValue to the Optional obj that accepts both T and
+  // None
+  template <typename T>
+  optional<T> toOptional();
+  template <typename T>
+  optional<T> toOptional() const;
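+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // The generic accessors above are typically used from templated code:
+  //   c10::IValue iv(1.5);
+  //   double d = iv.to<double>();
+  //   c10::optional<double> od = iv.toOptional<double>();  // nullopt for None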
+
+  /// @private [doxygen private]
+  /// this is a shallow comparison of two IValues to test the object identity
+  bool isSameIdentity(const IValue& rhs) const;
+
+  // Computes the "official" string representation of an IValue. This produces a
+  // TorchScript expression that can be used to recreate an IValue with the same
+  // value (e.g. when we are printing constants in the serializer).
+  //
+  // Callers can use `customFormatter` to override how `repr()` prints out an
+  // IValue. This is useful if you have some other environment where you can
+  // look up values, and you want to print a reference to that environment (like
+  // the serializer's constant table).
+  //
+  // repr() is not necessarily defined on all objects!
+  std::ostream& repr(
+      std::ostream& stream,
+      std::function<bool(std::ostream&, const IValue& v)> customFormatter)
+      const;
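+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // A hedged sketch of repr(): the formatter returns true once it has handled
+  // the value, so a formatter that always returns false should leave
+  // everything to the default printing:
+  //   std::ostringstream oss;
+  //   iv.repr(oss, [](std::ostream&, const IValue&) { return false; });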
+
+  // Computes an "informal" string representation of an IValue. This should be
+  // used for debugging, or servicing `print()`-like functions.
+  // This is different from `repr()` in that there is no expectation that we can
+  // exactly reconstruct an IValue from the output; feel free to use a
+  // concise/pretty form
+  TORCH_API friend std::ostream& operator<<(std::ostream& out, const IValue& v);
+
+  bool isPtrType() const {
+    if (isTensor()) {
+      return payload.as_tensor.defined();
+    }
+    return isIntrusivePtrLegacyBehavior();
+  }
+
+  /// @private [doxygen private]
+  const void* internalToPointer() const {
+    TORCH_INTERNAL_ASSERT(
+        isPtrType(), "Can only call internalToPointer() for pointer types");
+    if (isTensor()) {
+      return payload.as_tensor.unsafeGetTensorImpl();
+    } else {
+      return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
+          ? payload.u.as_intrusive_ptr
+          : nullptr;
+    }
+  }
+
+  template <typename T = c10::PlatformType>
+  TypePtr type() const;
+
+  // Detect aliased tensors.
+  struct HashAliasedIValue {
+    size_t hashTensor(const at::Tensor& ten) const {
+      if (ten.is_sparse()) {
+        // COO sparse tensors have a "values" tensor and an "indices" tensor
+        // so this will detect overlap of sparse tensors that share a values
+        // tensor, but not sparse tensors that share an indices tensor.
+        return hashTensor(ten._values());
+      } else if (ten.is_sparse_csr()) {
+        // CSR sparse tensors likewise have a "values" tensor (plus index
+        // tensors), so this will detect overlap of sparse tensors that share
+        // a values tensor, but not ones that only share index tensors.
+        return hashTensor(ten.values());
+      } else if (!ten.has_storage()) {
+        // Opaque tensors such as the ones constructed by the MKL-DNN backend
+        // don't have storage so we just use their TensorImpls.
+        // TODO: Find way to expose alias info for opaque tensors.
+        return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
+      } else {
+        return reinterpret_cast<size_t>(ten.storage().unsafeGetStorageImpl());
+      }
+    }
+    size_t operator()(const IValue& val) const {
+      if (val.isTensor()) {
+        return hashTensor(val.toTensor());
+      }
+      // If it is not a Tensor, then two mutable IValues alias each other only
+      // if they are the same pointer.
+      return val.payload.u.as_int;
+    }
+  };
+
+  struct CompAliasedIValues {
+    bool operator()(const IValue& lhs, const IValue& rhs) const {
+      return lhs.isAliasOf(rhs);
+    }
+  };
+
+  using HashAliasedIValues =
+      std::unordered_set<IValue, HashAliasedIValue, CompAliasedIValues>;
+  using HashAliasedIValueMap =
+      std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
+
+  // Checks if this and rhs have a subvalue in common.
+  // [t1,t2] and [t2, t3] returns true.
+  bool overlaps(const IValue& rhs) const;
+
+  // Inserts all subvalues of this in subValues.
+  void getSubValues(HashAliasedIValues& subValues) const;
+
+  // Apply visitor to every subvalue.
+  // TODO: There are several places that recurse over IValue. This is fragile.
+  // This visitor should be used to recurse over ivalues.
+  void visit(const std::function<bool(const IValue&)>& visitor) const;
+  IValue deepcopy(c10::optional<at::Device> device = c10::nullopt) const;
+  IValue deepcopy(
+      HashAliasedIValueMap& memo,
+      c10::optional<at::Device> device = c10::nullopt) const;
+
+ private:
+  static c10::intrusive_ptr_target* null_to_undefined_tensor(
+      c10::intrusive_ptr_target* p) {
+    return p ? p
+             : static_cast<c10::intrusive_ptr_target*>(
+                   c10::UndefinedTensorImpl::singleton());
+  }
+
+  static bool ptrEqual(const IValue& lhs, const IValue& rhs);
+  // NOTE: IValue tags are intentionally private. In the future we may encode
+  // this value differently (e.g. using NaN boxing), and this would make it more
+  // costly to determine the tag for all types vs just determining if something
+  // is a particular type. Instead we want clients to use the `isX` methods when
+  // possible. If for perf. reasons you really, absolutely, must have a jump
+  // table, then we can revisit this.
+  enum class Tag : uint32_t {
+#define DEFINE_TAG(x) x,
+    TORCH_FORALL_TAGS(DEFINE_TAG)
+#undef DEFINE_TAG
+  };
+
+#define COUNT_TAG(x) 1 +
+  static constexpr auto kNumTags = TORCH_FORALL_TAGS(COUNT_TAG) 0;
+#undef COUNT_TAG
+
+  template <
+      class T,
+      class NullType = c10::detail::intrusive_target_default_null_type<T>>
+  c10::intrusive_ptr<T, NullType> moveToIntrusivePtr();
+  template <
+      typename T,
+      class NullType = c10::detail::intrusive_target_default_null_type<T>>
+  c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;
+
+  void destroy() {
+    // We carefully construct this call to both 1) avoid UB by using
+    // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
+    // the compiler to generate the same code for each case. It is
+    // surprisingly difficult to get this right.
+    if (isTensor() || isIntrusivePtr()) {
+      c10::intrusive_ptr_target* p = isTensor()
+          ? payload.as_tensor.unsafeGetTensorImpl()
+          : payload.u.as_intrusive_ptr;
+      c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::
+          reclaim(p);
+      // No need to make this destructor call!
+      // payload.as_tensor.~Tensor();
+    }
+  }
+
+  C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
+    if (rhs.isTensor()) {
+      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
+      // As far as I can tell, omitting the usual explicit destructor call
+      // is not UB in and of itself, and it's a slight perf win. The
+      // destructor is a no-op, because the moved-from Tensor is
+      // effectively an intrusive_ptr in the null state, so we don't need
+      // the behavior for correctness reasons either. Leaving this
+      // explanatory comment, including commented-out destructor call, to
+      // make this abundantly clear.
+      //
+      // rhs.payload.as_tensor.~Tensor();
+    } else {
+      payload.u = rhs.payload.u;
+    }
+    tag = rhs.tag;
+    rhs.clearToNone();
+  }
+
+  void clearToNone() noexcept {
+    payload.u.as_int = 0;
+    tag = Tag::None;
+  }
+
+ private:
+  // This is the source of truth for isIntrusivePtr; edit results here
+  // as needed and isIntrusivePtr will pick them up.
+  // NOLINTBEGIN(bugprone-branch-clone)
+  static constexpr bool isIntrusivePtrConstexpr(Tag tag) {
+    switch (tag) {
+      case Tag::None:
+        return false;
+      case Tag::Tensor:
+        return false;
+      case Tag::Storage:
+        return true;
+      case Tag::Generator:
+        return true;
+      case Tag::Double:
+        return false;
+      case Tag::ComplexDouble:
+        return true;
+      case Tag::Int:
+        return false;
+      case Tag::SymInt:
+        return true;
+      case Tag::SymFloat:
+        return true;
+      case Tag::SymBool:
+        return true;
+      case Tag::Bool:
+        return false;
+      case Tag::Tuple:
+        return true;
+      case Tag::String:
+        return true;
+      case Tag::Blob:
+        return true;
+      case Tag::GenericList:
+        return true;
+      case Tag::GenericDict:
+        return true;
+      case Tag::Future:
+        return true;
+      case Tag::Await:
+        return true;
+      case Tag::Device:
+        return false;
+      case Tag::Stream:
+        return true;
+      case Tag::Object:
+        return true;
+      case Tag::PyObject:
+        return true;
+      case Tag::Uninitialized:
+        return false;
+      case Tag::Capsule:
+        return true;
+      case Tag::RRef:
+        return true;
+      case Tag::Quantizer:
+        return true;
+      case Tag::Enum:
+        return true;
+    }
+    return false;
+  }
+  // NOLINTEND(bugprone-branch-clone)
+
+ public:
+  // Don't edit this just to add results for new tags; edit
+  // isIntrusivePtrConstexpr above.
+  bool isIntrusivePtr() const {
+    // Implementation NOTE: the switch in isIntrusivePtrConstexpr
+    // above is the previous production implementation of this
+    // function. We observed that, at least on x86_64, the generated
+    // instruction sequence was a similar bit vector test to what we
+    // have manually implemented below, except that there was an extra
+    // "bounds check" branch confirming, essentially, that `tag <
+    // kNumTags` and providing a consistent result in that case. We
+    // don't care about the result if tag is out of bounds, so we'd
+    // like to eliminate that comparison and branch; manually
+    // implementing this function as a bit test is the simplest way I
+    // could find to accomplish that elimination.
+    static constexpr uint32_t kTruthTableBitVector =
+#define TRUTH_TABLE_ENTRY(tag) \
+  (uint32_t(isIntrusivePtrConstexpr(Tag::tag)) << uint32_t(Tag::tag)) |
+        TORCH_FORALL_TAGS(TRUTH_TABLE_ENTRY)
+#undef TRUTH_TABLE_ENTRY
+            0;
+
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        static_cast<uint32_t>(tag) < kNumTags,
+        "unexpected tag ",
+        static_cast<int>(tag));
+    return kTruthTableBitVector & (1 << (uint32_t(tag) % 32));
+  }
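+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // Worked example of the bit-vector test above: if, say, Tag::String had the
+  // numeric value 12, the check reduces to
+  //   kTruthTableBitVector & (1u << 12)
+  // i.e. a single AND with no bounds-check branch on the tag value.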
+
+  // Storage and Generator were treated specially when
+  // is_intrusive_ptr was stored as explicit state. This getter
+  // preserves the old behavior for use with WeakIValue for now.
+  bool isIntrusivePtrLegacyBehavior() const {
+    if (tag == Tag::Storage || tag == Tag::Generator) {
+      return payload.u.as_intrusive_ptr !=
+          c10::UndefinedTensorImpl::singleton();
+    } else {
+      return isIntrusivePtr();
+    }
+  }
+
+  union Payload {
+    // [TriviallyCopyablePayload]
+    // We use a nested union here so that we can make the copy easy
+    // and efficient in the non-tensor (i.e., trivially copyable)
+    // case. Specifically, we do not have to do a switch-on-tag to
+    // figure out which union member to assign; we can just use
+    // TriviallyCopyablePayload::operator=.
+    union TriviallyCopyablePayload {
+      TriviallyCopyablePayload() : as_int(0) {}
+      int64_t as_int;
+      double as_double;
+      bool as_bool;
+      // Invariant: never nullptr; null state is represented as
+      // c10::UndefinedTensorImpl::singleton() for consistency of
+      // representation with Tensor.
+      c10::intrusive_ptr_target* as_intrusive_ptr;
+      struct {
+        c10::DeviceType type;
+        DeviceIndex index;
+      } as_device;
+    } u;
+    at::Tensor as_tensor;
+    Payload() : u() {}
+    ~Payload() {}
+  };
+
+  IValue(const Payload& p, Tag t) : tag(t) {
+    if (isTensor()) {
+      new (&payload.as_tensor) at::Tensor(p.as_tensor);
+    } else {
+      payload.u = p.u;
+    }
+  }
+
+  template <typename T>
+  struct TagType {};
+
+  friend MaybeOwnedTraits<IValue>;
+
+  Payload payload;
+  Tag tag{IValue::Tag::None};
+  friend struct WeakIValue;
+};
+
+struct TORCH_API WeakIValue final {
+  WeakIValue() = default;
+
+  WeakIValue(const WeakIValue& rhs)
+      : payload(rhs.payload),
+        tag(rhs.tag),
+        is_intrusive_ptr(rhs.is_intrusive_ptr) {
+    if (is_intrusive_ptr &&
+        payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+      c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+    }
+  }
+  WeakIValue(const IValue& rhs)
+      : tag(rhs.tag), is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) {
+    if (rhs.isTensor()) {
+      payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
+      is_intrusive_ptr = true;
+    } else {
+      payload = rhs.payload.u;
+    }
+    if (is_intrusive_ptr) {
+      if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+        c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
+      }
+    }
+  }
+  WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
+    swap(rhs);
+  }
+  ~WeakIValue() {
+    if (is_intrusive_ptr &&
+        payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
+      c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
+    }
+  }
+  WeakIValue& operator=(WeakIValue&& rhs) & noexcept {
+    WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None
+    return *this;
+  }
+  WeakIValue& operator=(WeakIValue const& rhs) & {
+    WeakIValue(rhs).swap(*this);
+    return *this;
+  }
+  void swap(WeakIValue& rhs) noexcept {
+    std::swap(payload, rhs.payload);
+    std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
+    std::swap(tag, rhs.tag);
+  }
+
+  bool isSameIdentity(const WeakIValue& rhs) const {
+    return payload.as_int == rhs.payload.as_int && tag == rhs.tag &&
+        is_intrusive_ptr == rhs.is_intrusive_ptr;
+  }
+
+  IValue lock() const {
+    if (!is_intrusive_ptr) {
+      IValue::Payload newPayload;
+      newPayload.u = payload;
+      return IValue(newPayload, tag);
+    }
+    if (IValue::Tag::Tensor == tag) {
+      auto temp =
+          c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::
+              reclaim(static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
+      c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(
+          temp.lock());
+      temp.release();
+      if (!ip) {
+        return IValue();
+      } else {
+        return IValue(at::Tensor(std::move(ip)));
+      }
+    } else {
+      auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
+          payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+              ? nullptr
+              : payload.as_intrusive_ptr);
+      IValue::Payload pl;
+      pl.u.as_intrusive_ptr = temp.lock().release();
+      temp.release();
+      if (!pl.u.as_intrusive_ptr) {
+        return IValue();
+      } else {
+        return IValue(pl, tag);
+      }
+    }
+  }
+
+  size_t use_count() const noexcept {
+    if (!is_intrusive_ptr) {
+      return 1;
+    }
+    auto temp = c10::weak_intrusive_ptr<
+        c10::intrusive_ptr_target,
+        c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
+    size_t result = temp.use_count();
+    temp.release();
+    return result;
+  }
+
+  size_t weak_use_count() const noexcept {
+    if (!is_intrusive_ptr) {
+      return 1;
+    }
+    auto temp = c10::weak_intrusive_ptr<
+        c10::intrusive_ptr_target,
+        c10::UndefinedTensorImpl>::reclaim(payload.as_intrusive_ptr);
+    size_t result = temp.weak_use_count();
+    temp.release();
+    return result;
+  }
+  size_t hash() const {
+    return payload.as_int;
+  }
+
+ private:
+  using Payload = IValue::Payload::TriviallyCopyablePayload;
+  Payload payload;
+  IValue::Tag tag{IValue::Tag::None};
+  bool is_intrusive_ptr{false};
+};
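+
+// [Editorial note, illustrative only -- not part of the upstream header]
+// Typical WeakIValue usage: take a non-owning handle and try to revive it.
+//   c10::WeakIValue weak(someTensorIValue);
+//   c10::IValue strong = weak.lock();  // None if the object has been freed
+//   if (!strong.isNone()) { /* still alive, safe to use */ }
+// Here `someTensorIValue` is a hypothetical IValue holding a tensor.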
+
+// An owning pointer to a type. When the type is class type, it requires a pair
+// of shared_ptrs to the class type and its owning CU, so that the class type is
+// guaranteed to stay alive as long as we hold this object.
+struct TORCH_API StrongTypePtr {
+  StrongTypePtr(std::shared_ptr<torch::jit::CompilationUnit> cu, TypePtr type);
+
+  std::shared_ptr<torch::jit::CompilationUnit> cu_;
+  TypePtr type_;
+};
+
+// [Constant Object Weak CompilationUnit Reference]
+// A non-owning pointer to a type. When a class gets inserted as a constant
+// into a graph, using a strong pointer would create a circular reference
+// from Object -> CompilationUnit and CompilationUnit -> Graph (which owns the
+// Constant Object).
+struct TORCH_API WeakTypePtr {
+  WeakTypePtr(std::weak_ptr<torch::jit::CompilationUnit> cu, TypePtr type);
+
+  std::weak_ptr<torch::jit::CompilationUnit> cu_;
+  TypePtr type_;
+};
+
+// internal build errors with std::variant :/
+struct WeakOrStrongCompilationUnit {
+  explicit WeakOrStrongCompilationUnit(
+      std::shared_ptr<torch::jit::CompilationUnit> shared_cu)
+      : strong_ptr_(std::move(shared_cu)), weak_ptr_(c10::nullopt) {}
+
+  explicit WeakOrStrongCompilationUnit(
+      std::weak_ptr<torch::jit::CompilationUnit> weak_cu)
+      : strong_ptr_(c10::nullopt), weak_ptr_(std::move(weak_cu)) {}
+
+  std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const {
+    TORCH_INTERNAL_ASSERT(strong_ptr_ != c10::nullopt);
+    return *strong_ptr_;
+  }
+
+  std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const {
+    TORCH_INTERNAL_ASSERT(weak_ptr_ != c10::nullopt);
+    return *weak_ptr_;
+  }
+
+  bool holdingStrongRef() const {
+    return strong_ptr_ != c10::nullopt;
+  }
+
+  bool holdingEmptyStrongRef() const {
+    return holdingStrongRef() && *strong_ptr_ == nullptr;
+  }
+
+  c10::optional<std::shared_ptr<torch::jit::CompilationUnit>> strong_ptr_;
+  c10::optional<std::weak_ptr<torch::jit::CompilationUnit>> weak_ptr_;
+};
+
+// An Object will hold a non-owning CompilationUnit reference if it is a
+// Constant in the graph and an owning reference otherwise.
+struct TORCH_API WeakOrStrongTypePtr {
+  explicit WeakOrStrongTypePtr(WeakTypePtr weak)
+      : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))),
+        type_(std::move(weak.type_)) {}
+  explicit WeakOrStrongTypePtr(StrongTypePtr strong)
+      : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))),
+        type_(std::move(strong.type_)) {}
+  explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type)
+      : cu_(std::move(cu)), type_(std::move(type)) {}
+  WeakTypePtr asWeakTypePtr() const;
+
+  WeakOrStrongCompilationUnit cu_;
+  TypePtr type_;
+
+  bool holds_strong_ref() const {
+    return cu_.holdingStrongRef();
+  }
+
+  bool holds_empty_strong_ref() const {
+    return cu_.holdingEmptyStrongRef();
+  }
+};
+
+} // namespace c10
+
+#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a9b25a0451b3ce3e710b009f52583926ad657350
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_inl.h
@@ -0,0 +1,2545 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace torch {
+namespace jit {
+struct Function;
+struct CompilationUnit;
+} // namespace jit
+TORCH_API bool isCustomClass(const c10::IValue& v);
+} // namespace torch
+namespace c10 {
+struct IValue;
+struct ClassType;
+struct TupleType;
+struct EnumType;
+struct InferredType;
+
+// For custom class __init__ registration, we need to pass in a function
+// that looks like this: [](IValue x, args...)
+
+// However, make_boxed_from_unboxed_functor.h automatically sets the input types
+// of the function by introspecting the types of the functor (which is IValue in
+// this case). However, we need the type it binds to be Foo.
+
+// Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
+// which getTypePtr can recover the original class pointer.
+
+template <class T>
+struct tagged_capsule {
+  IValue ivalue;
+};
+
+template <typename T, typename NullType>
+c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
+  auto t = c10::intrusive_ptr<T, NullType>::reclaim(
+      payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
+      ? NullType::singleton()
+      : static_cast<T*>(payload.u.as_intrusive_ptr));
+  clearToNone();
+  return t;
+}
+template <typename T, typename NullType>
+c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
+  if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
+    return c10::intrusive_ptr<T, NullType>();
+  }
+  c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
+  return c10::intrusive_ptr<T, NullType>::reclaim(
+      static_cast<T*>(payload.u.as_intrusive_ptr));
+}
+
+template <class T, class U>
+intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
+  return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
+}
+
+template <class T, class U>
+intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
+  return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
+}
+
+inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
+  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::Future>();
+}
+inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
+  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
+  return toIntrusivePtr<ivalue::Future>();
+}
+inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
+  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::Await>();
+}
+inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
+  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
+  return toIntrusivePtr<ivalue::Await>();
+}
+inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
+  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
+  return moveToIntrusivePtr<c10::RRefInterface>();
+}
+inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
+  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
+  return toIntrusivePtr<c10::RRefInterface>();
+}
+inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
+  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
+  return moveToIntrusivePtr<at::Quantizer>();
+}
+inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
+  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
+  return toIntrusivePtr<at::Quantizer>();
+}
+inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
+  AT_ASSERT(isString(), "Expected String but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::ConstantString>();
+}
+inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
+  AT_ASSERT(isString(), "Expected String but got ", tagKind());
+  return toIntrusivePtr<ivalue::ConstantString>();
+}
+inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
+  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::Object>();
+}
+inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
+  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
+  return toIntrusivePtr<ivalue::Object>();
+}
+inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
+    toPyObjectHolder() && {
+  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::PyObjectHolder>();
+}
+inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
+    const& {
+  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
+  return toIntrusivePtr<ivalue::PyObjectHolder>();
+}
+inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
+  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
+  return moveToIntrusivePtr<ivalue::EnumHolder>();
+}
+inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
+  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
+  return toIntrusivePtr<ivalue::EnumHolder>();
+}
+inline c10::complex<double> IValue::toComplexDouble() const {
+  TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
+  auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
+  return (*ptr).val;
+}
+inline at::Tensor IValue::toTensor() && {
+  if (C10_UNLIKELY(!isTensor())) {
+    reportToTensorTypeError();
+  }
+  auto result = std::move(payload.as_tensor);
+  // As far as I can tell, omitting the usual explicit destructor call
+  // is not UB in and of itself, and it's a slight perf win. The
+  // destructor is a no-op, because the moved-from Tensor is
+  // effectively an intrusive_ptr in the null state, so we don't need
+  // the behavior for correctness reasons either. Leaving this
+  // explanatory comment, including commented-out destructor call, to
+  // make this abundantly clear.
+  //
+  // payload.as_tensor.~Tensor();
+  clearToNone();
+  return result;
+}
+inline at::Tensor& IValue::toTensor() & {
+  if (C10_UNLIKELY(!isTensor())) {
+    reportToTensorTypeError();
+  }
+  return payload.as_tensor;
+}
+inline const at::Tensor& IValue::toTensor() const& {
+  if (C10_UNLIKELY(!isTensor())) {
+    reportToTensorTypeError();
+  }
+  return payload.as_tensor;
+}
+inline c10::Storage IValue::toStorage() && {
+  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
+  return c10::Storage(
+      moveToIntrusivePtr<at::StorageImpl>());
+}
+inline c10::Storage IValue::toStorage() const& {
+  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
+  return c10::Storage(toIntrusivePtr<at::StorageImpl>());
+}
+inline c10::Stream IValue::toStream() && {
+  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
+  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
+  return c10::Stream::unpack3((*ptr).val.stream_id,
+                              (*ptr).val.device_index,
+                              (*ptr).val.device_type);
+}
+inline c10::Stream IValue::toStream() const& {
+  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
+  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
+  return c10::Stream::unpack3((*ptr).val.stream_id,
+                              (*ptr).val.device_index,
+                              (*ptr).val.device_type);
+}
+inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
+  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
+  return moveToIntrusivePtr<caffe2::Blob>();
+}
+inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
+  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
+  return toIntrusivePtr<caffe2::Blob>();
+}
+inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && {
+  TORCH_INTERNAL_ASSERT(isCapsule());
+  return moveToIntrusivePtr<torch::CustomClassHolder>();
+}
+inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& {
+  TORCH_INTERNAL_ASSERT(isCapsule());
+  return toIntrusivePtr<torch::CustomClassHolder>();
+}
+inline at::Generator IValue::toGenerator() && {
+  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
+  return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
+}
+inline at::Generator IValue::toGenerator() const& {
+  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
+  return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
+}
+inline c10::SymInt IValue::toSymInt() && {
+  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
+  if (isSymInt()) {
+    return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymInt(payload.u.as_int);
+  }
+}
+inline c10::SymInt IValue::toSymInt() const& {
+  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
+  if (isSymInt()) {
+    return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymInt(payload.u.as_int);
+  }
+}
+inline c10::SymFloat IValue::toSymFloat() && {
+  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
+  if (isSymFloat()) {
+    return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymFloat(payload.u.as_double);
+  }
+}
+inline c10::SymFloat IValue::toSymFloat() const& {
+  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
+  if (isSymFloat()) {
+    return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymFloat(payload.u.as_double);
+  }
+}
+inline c10::SymBool IValue::toSymBool() && {
+  AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
+  if (isSymBool()) {
+    return c10::SymBool(moveToIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymBool(payload.u.as_bool);
+  }
+}
+
+inline c10::SymBool IValue::toSymBool() const& {
+  AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
+  if (isSymBool()) {
+    return c10::SymBool(toIntrusivePtr<c10::SymNodeImpl>());
+  } else {
+    return c10::SymBool(payload.u.as_bool);
+  }
+}
+
+namespace ivalue {
+
+void TORCH_API
+checkCustomClassType(const ClassType* expected_type, const Type* actual_type);
+
+template <typename T>
+using Shared = c10::intrusive_ptr<T>;
+
+// string
+struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
+ private:
+   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+  const std::string str_;
+
+ public:
+  ConstantString(std::string str) : str_(std::move(str)) {}
+  ConstantString(c10::string_view str) : str_(std::string(str)) {}
+  static c10::intrusive_ptr create(std::string str_);
+  static c10::intrusive_ptr create(c10::string_view str_);
+  static c10::intrusive_ptr create(const char* str_);
+
+  const std::string& string() const {
+    return str_;
+  }
+  c10::string_view string_view() const {
+    return str_;
+  }
+
+  operator const std::string&() const {
+    return string();
+  }
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& out,
+      const ConstantString& v);
+};
+
+struct Future;
+
+struct TORCH_API TupleElements {
+ private:
+  size_t inlineSize_;
+  // We represent TupleElements this way to save doing a heap
+  // allocation in the common (at least for unpickling) case where we
+  // have only 3 elements. We have our own union instead of
+  // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
+  // stores the begin/end/capacity pointers, which would be a waste of
+  // space in our use case.
+  union {
+    std::vector<IValue> elementsVector_;
+    // Don't want to declare a std::array<IValue, 3> because the convenient
+    // iteration and size members are a footgun in this case -- the
+    // actual size of the array may be smaller than 3!
+    // NOLINTNEXTLINE(*c-arrays*)
+    IValue elementsInline_[3];
+  };
+
+  void destroyInline() {
+   for (const auto ii : c10::irange(inlineSize_)) {
+     elementsInline_[ii].~IValue();
+   }
+  }
+ public:
+
+  using iterator = IValue*;
+  using const_iterator = const IValue*;
+
+  TupleElements() : inlineSize_(0) {
+    new (&elementsVector_) std::vector<IValue>();
+  }
+
+  explicit TupleElements(std::vector<IValue> elements)
+  : inlineSize_(0), elementsVector_(std::move(elements)) {}
+
+  explicit TupleElements(c10::ArrayRef<IValue> elements)
+  : inlineSize_(elements.size() <= 3 ? elements.size() : 0) {
+    switch (inlineSize_) {
+      case 3:
+        new (&elementsInline_[2]) IValue(elements[2]);
+        [[fallthrough]];
+      case 2:
+        new (&elementsInline_[1]) IValue(elements[1]);
+        [[fallthrough]];
+      case 1:
+        new (&elementsInline_[0]) IValue(elements[0]);
+        break;
+      case 0:
+        new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end());
+        break;
+    }
+  }
+
+  explicit TupleElements(IValue&& e1)
+  : inlineSize_(1) {
+    new (&elementsInline_[0]) IValue(std::move(e1));
+  }
+
+  explicit TupleElements(IValue&& e1, IValue&& e2)
+  : inlineSize_(2) {
+    new (&elementsInline_[0]) IValue(std::move(e1));
+    new (&elementsInline_[1]) IValue(std::move(e2));
+  }
+
+  explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3)
+  : inlineSize_(3) {
+    new (&elementsInline_[0]) IValue(std::move(e1));
+    new (&elementsInline_[1]) IValue(std::move(e2));
+    new (&elementsInline_[2]) IValue(std::move(e3));
+  }
+
+  ~TupleElements() {
+    if (inlineSize_) {
+      destroyInline();
+    } else {
+      elementsVector_.~vector();
+    }
+  }
+
+  // It would be nice to make this noncopyable to prevent people from
+  // writing code like `auto output =
+  // forward(...).toTupleRef().elements()` (which does refcount bumps on
+  // each element, unlike the more efficient but verbose
+  // ```
+  // auto outputIntrusivePtr = forward(...).toTuple();
+  // const auto& output = outputIntrusivePtr->elements();
+  // ```
+  // ), but there is simply an overwhelming amount of code that does
+  // it the inefficient way.
+  // See also operator std::vector<IValue> below.
+  TupleElements(const TupleElements& rhs)
+  : inlineSize_(rhs.inlineSize_) {
+    if (rhs.inlineSize_) {
+      for (const auto  ii : c10::irange(inlineSize_)) {
+        new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
+      }
+    } else {
+      new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
+    }
+  }
+
+  TupleElements& operator=(const TupleElements& rhs) {
+    if (inlineSize_) {
+      if (rhs.inlineSize_) {
+        for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
+          elementsInline_[ii] = rhs.elementsInline_[ii];
+        }
+        if (rhs.inlineSize_ > inlineSize_) {
+          for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
+            new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
+          }
+        } else {
+          for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
+            elementsInline_[ii].~IValue();
+          }
+        }
+      } else {
+        destroyInline();
+        new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
+      }
+    } else {
+      if (rhs.inlineSize_) {
+        elementsVector_.~vector();
+        for (const auto ii : c10::irange(rhs.inlineSize_)) {
+          new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
+        }
+      } else {
+        elementsVector_ = rhs.elementsVector_;
+      }
+    }
+    inlineSize_ = rhs.inlineSize_;
+    return *this;
+  }
+
+  TupleElements(TupleElements&& rhs) noexcept
+  : inlineSize_(rhs.inlineSize_) {
+    if (inlineSize_) {
+      for (const auto ii : c10::irange(inlineSize_)) {
+        new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
+      }
+    } else {
+      new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
+    }
+  }
+
+  TupleElements& operator=(TupleElements&& rhs) noexcept {
+    if (inlineSize_) {
+      if (rhs.inlineSize_) {
+        for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
+          elementsInline_[ii] = std::move(rhs.elementsInline_[ii]);
+        }
+        if (rhs.inlineSize_ > inlineSize_) {
+          for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
+            new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
+          }
+        } else {
+          for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
+            elementsInline_[ii].~IValue();
+          }
+        }
+      } else {
+        destroyInline();
+        new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
+      }
+    } else {
+      if (rhs.inlineSize_) {
+        elementsVector_.~vector();
+        for (const auto ii : c10::irange(rhs.inlineSize_)) {
+          new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
+        }
+      } else {
+        elementsVector_ = std::move(rhs.elementsVector_);
+      }
+    }
+    inlineSize_ = rhs.inlineSize_;
+    return *this;
+  }
+
+  C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const {
+    if (inlineSize_) {
+      return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
+    } else {
+      return elementsVector_;
+    }
+  }
+
+  // Mimic implicit conversion from std::vector to ArrayRef.
+  operator c10::ArrayRef<IValue>() const {
+    return asArrayRef();
+  }
+
+  static size_t hash(const TupleElements& v) {
+    return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef());
+  }
+
+  void setContents(std::vector<IValue>&& contents) {
+    if (inlineSize_) {
+      destroyInline();
+      new (&elementsVector_) std::vector<IValue>(std::move(contents));
+      inlineSize_ = 0;
+    } else {
+      elementsVector_ = std::move(contents);
+    }
+  }
+
+  C10_NODISCARD bool empty() const {
+    return inlineSize_ ? false : elementsVector_.empty();
+  }
+
+  C10_NODISCARD size_t size() const {
+    return inlineSize_ ? inlineSize_ : elementsVector_.size();
+  }
+
+  C10_NODISCARD IValue& operator[](size_t idx) {
+    if (inlineSize_) {
+      return elementsInline_[idx];
+    } else {
+      return elementsVector_[idx];
+    }
+  }
+
+  C10_NODISCARD const IValue& operator[](size_t idx) const {
+    if (inlineSize_) {
+      return elementsInline_[idx];
+    } else {
+      return elementsVector_[idx];
+    }
+  }
+
+  C10_NODISCARD IValue& at(size_t idx) {
+    if (inlineSize_) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
+      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
+      return elementsInline_[idx];
+    } else {
+      return elementsVector_.at(idx);
+    }
+  }
+
+  C10_NODISCARD const IValue& at(size_t idx) const {
+    if (inlineSize_) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
+      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
+      return elementsInline_[idx];
+    } else {
+      TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
+      return elementsVector_.at(idx);
+    }
+  }
+
+  C10_NODISCARD iterator begin() {
+    if (inlineSize_) {
+      return elementsInline_;
+    } else {
+      return elementsVector_.data();
+    }
+  }
+
+  C10_NODISCARD iterator end() {
+    if (inlineSize_) {
+      return elementsInline_ + inlineSize_;
+    } else {
+      return elementsVector_.data() + elementsVector_.size();
+    }
+  }
+
+  C10_NODISCARD const_iterator begin() const {
+    if (inlineSize_) {
+      return elementsInline_;
+    } else {
+      return elementsVector_.data();
+    }
+  }
+
+  C10_NODISCARD const_iterator end() const {
+    if (inlineSize_) {
+      return elementsInline_ + inlineSize_;
+    } else {
+      return elementsVector_.data() + elementsVector_.size();
+    }
+  }
+
+  C10_NODISCARD const_iterator cbegin() const {
+    return begin();
+  }
+
+  C10_NODISCARD const_iterator cend() const {
+    return end();
+  }
+
+  C10_NODISCARD std::vector<IValue> vec() const & {
+    return asArrayRef().vec();
+  }
+
+  C10_NODISCARD IValue& back() {
+    return *(end() - 1);
+  }
+
+  C10_NODISCARD const IValue& back() const {
+    return *(end() - 1);
+  }
+
+  C10_NODISCARD std::vector<IValue> vec() && {
+    std::vector<IValue> result;
+    result.reserve(size());
+    for (auto&& iv : *this) {
+      result.push_back(std::move(iv));
+    }
+    return result;
+  }
+
+  // More compatibility shims for the overwhelming amount of code that
+  // likes to copy tuple elements into a vector; see comment above the
+  // copy constructor.
+  operator std::vector<IValue>() const & {
+    return vec();
+  }
+
+  operator std::vector<IValue>() && {
+    return vec();
+  }
+};
+
+template <typename T>
+struct TupleTypeFactory {};
+
+template <>
+struct TORCH_API TupleTypeFactory<TupleType> {
+  static TupleTypePtr create(std::vector<TypePtr> types) {
+    return TupleType::create(std::move(types));
+  }
+  static TupleTypePtr fallback(const Type& type);
+};
+
+template <>
+struct TORCH_API TupleTypeFactory<c10::DynamicType> {
+  static DynamicTypePtr create(const std::vector<TypePtr>& elemTypes);
+  static DynamicTypePtr fallback(const Type&);
+};
+
+struct TORCH_API Tuple : c10::intrusive_ptr_target {
+ private:
+  TupleElements elements_;
+  mutable c10::TypePtr type_; // lazily computed for unnamed tuples
+
+ public:
+  // named tuples have additional type information, so we
+  // directly create them tagged
+  static c10::intrusive_ptr<Tuple> createNamed(
+      std::vector<IValue> elements_,
+      c10::TypePtr type_) {
+    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
+  }
+
+  static c10::intrusive_ptr<Tuple> createNamed(
+      TupleElements elements_,
+      std::shared_ptr<TupleType> type_) {
+    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
+  }
+
+  static c10::intrusive_ptr<Tuple> createNamed(
+      std::initializer_list<IValue> elements_,
+      std::shared_ptr<TupleType> type_) {
+    return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_));
+  }
+
+  // MSVC apparently can't disambiguate the other two overloads of
+  // create when passed an initializer_list without this.
+  static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) {
+    return create(c10::ArrayRef<IValue>(elements_));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
+    return c10::make_intrusive<Tuple>(std::move(elements_));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(TupleElements elements_) {
+    return c10::make_intrusive<Tuple>(std::move(elements_));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) {
+    return create(TupleElements(elements_));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(IValue e1) {
+    return c10::make_intrusive<Tuple>(std::move(e1));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) {
+    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2));
+  }
+
+  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) {
+    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3));
+  }
+
+ private:
+  // Workaround inability to use `>` operator in template argument list.
+  template <typename... Args>
+  static constexpr bool hasMoreThanThreeArgs() {
+    return sizeof...(Args) > 3;
+  }
+
+ public:
+  template <typename... Args>
+  static c10::intrusive_ptr<Tuple> create(Args&&... elements_) {
+    switch (sizeof...(Args)) {
+      case 1:
+      case 2:
+      case 3:
+        return create(IValue(std::forward<Args>(elements_))...);
+      default:
+        return create(
+            std::vector<IValue>{IValue(std::forward<Args>(elements_))...});
+    }
+  }
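+
+  // [Editorial note, illustrative only -- not part of the upstream header]
+  // Sketch of the factory functions above:
+  //   auto tup = c10::ivalue::Tuple::create(1, 2.0, true);
+  //   const auto& elems = tup->elements();  // borrow, avoiding element copies
+  //   TORCH_INTERNAL_ASSERT(elems.size() == 3);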
+
+  // Again, it would be nice to make this noncopyable, but there's a
+  // lot of extant code that copies Tuples.
+  // Tuple(const Tuple& rhs) = delete;
+
+  const TupleElements& elements() const& {
+    return elements_;
+  }
+
+  TupleElements elements() && {
+    return std::move(elements_);
+  }
+
+  void setElements(std::vector<IValue>&& elements) {
+    elements_.setContents(std::move(elements));
+  }
+
+  void setElements(TupleElements&& elements) {
+    elements_ = std::move(elements);
+  }
+
+  void unsafeSetElement(size_t idx, const IValue& element) {
+    elements_[idx] = element;
+  }
+
+  void unsafeSetElement(size_t idx, IValue&& element) {
+    elements_[idx] = std::move(element);
+  }
+
+  size_t size() const {
+    return elements_.size();
+  }
+
+  template 
+  std::shared_ptr type() const {
+    if (!type_) {
+      type_ = TupleTypeFactory::create(fmap(elements(), [&](const IValue& v) {
+        return v.type();
+      }));
+    }
+    if (auto t = type_->cast()) {
+      return t;
+    }
+    return TupleTypeFactory::fallback(*type_);
+  }
+
+  static size_t hash(const Tuple& t) {
+    return c10::get_hash(t.elements());
+  }
+
+  TORCH_API friend bool operator==(
+      const ivalue::Tuple& lhs,
+      const ivalue::Tuple& rhs);
+
+ private:
+  // NOTE: If we try to avoid the overloads without
+  // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we
+  // end up having to call (part of) the shared_ptr destructor for
+  // `type` even though we should know statically it won't do
+  // anything.
+  explicit Tuple(std::vector<IValue> elements)
+    : elements_(std::move(elements)){}
+
+  explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
+    : elements_(std::move(elements)), type_(std::move(type)) {}
+
+  explicit Tuple(TupleElements&& elements)
+    : elements_(std::move(elements)) {}
+
+  explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type)
+    : elements_(std::move(elements)), type_(std::move(type)) {}
+
+  explicit Tuple(IValue&& e1)
+    : elements_(std::move(e1)) {}
+
+  explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type)
+    : elements_(std::move(e1)), type_(std::move(type)) {}
+
+  explicit Tuple(IValue&& e1, IValue&& e2)
+    : elements_(std::move(e1), std::move(e2)) {}
+
+  explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type)
+    : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {}
+
+  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3)
+    : elements_(std::move(e1), std::move(e2), std::move(e3)) {}
+
+  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type)
+    : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {}
+
+  friend class c10::intrusive_ptr<Tuple>;
+};
+
+struct Object;
+struct PyObjectHolder;
+struct EnumHolder;
+} // namespace ivalue
+
+// Future
+struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
+ private:
+  // Keep this private in order to force users to go through make_intrusive and
+  // thus prevent creating a Future that's not held by an intrusive_ptr.
+  explicit Future(TypePtr type, std::vector<c10::Device> devices={})
+      : type_(std::move(type)),
+        impl_(getTypeOfDevices(devices)),
+        devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {}
+
+  friend c10::intrusive_ptr<Future>;
+
+  struct FutureCallback {
+    std::function<void(Future&)> callback;
+    bool uses_future; // whether the Future& passed in is actually used
+
+    template <typename T>
+    FutureCallback(T callback, bool uses_future)
+        : callback(std::move(callback)), uses_future(uses_future) {}
+  };
+
+ public:
+  Future(const Future&) = delete;
+  Future(Future&&) = delete;
+  Future& operator=(const Future&) = delete;
+  Future& operator=(Future&&) = delete;
+
+  struct TORCH_API FutureError final : public std::exception {
+    explicit FutureError(std::string&& error_msg_)
+        : error_msg(std::move(error_msg_)) {}
+
+    FutureError() = default;
+
+    const char* what() const noexcept override {
+      return error_msg.c_str();
+    }
+
+    std::string error_msg;
+  };
+
+  /**
+   * Wait on the future until it completes.
+   */
+  void wait() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    finished_cv_.wait(lock, [&]() -> bool { return completed_; });
+    synchronizeWithCurrentStreams();
+  }
+
+  /**
+   * Wait on the future until it completes and throw an
+   * exception if an error exists.
+   */
+  void waitAndThrow() {
+    wait();
+
+    if (eptr_) {
+      std::rethrow_exception(eptr_);
+    }
+  }
+
+  /**
+   * Explicitly mark the future as completed with the output value. Optionally,
+   * the storages for all tensors in IValue can be passed as well. The DataPtrs
+   * of these storages are used to synchronize CUDA streams. If storages isn't
+   * given we will attempt to extract it from the value, if we need to (this
+   * happens if a non-empty set of devices was given to the constructor). Thus
+   * one only needs to provide storages when 1) they cannot be extracted through
+   * IValue::getSubValues() or through pickling in case of Python object; or
+   * when 2) customized storage extraction is more efficient.
+   */
+  using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
+  void markCompleted(
+      IValue value,
+      c10::optional<std::vector<WeakStorage>> storages = c10::nullopt) {
+    // Start by performing all steps that can throw, before setting any field.
+    // Do this before even acquiring the mutex, because extractStorages might
+    // acquire the GIL, which could lead to a lock inversion with our mutex.
+    // See https://github.com/pytorch/pytorch/issues/58239.
+    std::vector<WeakStorage> actualStorages;
+    std::vector<c10::Device> usedDevices;
+    try {
+      // FIXME We should always extract DataPtrs, in order to catch the case of
+      // users using CUDA values but forgetting to set devices, which currently
+      // leads to a silent synchronization/correctness issue. However, as this
+      // might worsen perf in CPU-only cases, we should only do so after careful
+      // benchmarks.
+      if (impl_.type() != c10::kCPU) {
+        actualStorages =
+            storages.has_value() ? std::move(*storages) : extractStorages(value);
+        usedDevices = getDevicesOfStorages(impl_, actualStorages);
+        ensureIsSubsetOfDevices(usedDevices, devices_);
+      }
+    } catch (const std::exception&) {
+      setError(std::current_exception());
+      return;
+    }
+
+    std::unique_lock<std::mutex> lock(mutex_);
+    TORCH_CHECK(
+        !completed(),
+        "Attempting to mark a completed Future as complete again. Note that "
+        "a Future can only be marked completed once.");
+
+    // Only set value_ and completed_ flag once all checks and preparation steps
+    // have returned successfully to allow for proper error propagation.
+    value_ = std::move(value);
+    completed_ = true;
+
+    currentDevice_ = impl_.getDevice();
+    storages_ = std::move(actualStorages);
+    for (const c10::Device& device : usedDevices) {
+      c10::Event event(impl_.type());
+      event.record(impl_.getStream(device));
+      events_.push_back(std::move(event));
+    }
+
+    std::vector<FutureCallback> cbs;
+    cbs.swap(callbacks_);
+    lock.unlock();
+
+    finished_cv_.notify_all();
+    for (auto& callback : cbs) {
+      invokeCallback(std::move(callback.callback), callback.uses_future);
+    }
+  }
+
+  void markCompleted() {
+    markCompleted(IValue{});
+  }
+
+  void setError(std::exception_ptr eptr) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    setErrorInternal(std::move(eptr), lock);
+  }
+
+  void setErrorIfNeeded(std::exception_ptr eptr) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (completed_) {
+      // This should be rare and shouldn't cause log spew. It's important to
+      // log errors, and that's why we have this log here.
+      std::string msg = c10::str(
+          "Skipping setting following error on the Future since "
+          "it is already marked completed (this is not necessarily "
+          "an error):\n",
+          tryRetrieveErrorMessageInternal(std::move(eptr)));
+      if (eptr_) {
+        msg += c10::str(
+            ", \nOriginal exception:\n",
+            tryRetrieveErrorMessageInternal(eptr_));
+      }
+      LOG(INFO) << msg;
+      return;
+    } else {
+      setErrorInternal(std::move(eptr), lock);
+    }
+  }
+
+  // Get the result of the current future.
+  IValue value() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    AT_ASSERT(completed());
+    if (eptr_) {
+      std::rethrow_exception(eptr_);
+    }
+    return value_;
+  }
+
+  // This accessor should only be used if we know that the future is
+  // completed() with no error.
+  const IValue& constValue() const {
+    std::unique_lock<std::mutex> lock(mutex_);
+    AT_ASSERT(completed());
+    TORCH_INTERNAL_ASSERT(
+      !eptr_,
+      "value() accessor should only be used when future is not completed with ",
+      "an error, but future had the following error: ",
+      tryRetrieveErrorMessageInternal(eptr_)
+    );
+    return value_;
+  }
+
+  // This accessor should only be used if we know that the future is
+  // completed() with no error.
+  const std::vector<WeakStorage>& storages() const {
+    std::unique_lock lock(mutex_);
+    AT_ASSERT(completed());
+    AT_ASSERT(!eptr_);
+    return storages_;
+  }
+
+  /**
+   * Add a callback to the future.
+   * The callbacks will be executed once the future completes.
+   * If the future has already completed,
+   * this function will execute the callback immediately.
+   */
+  template <typename T>
+  void addCallback(T callback, bool uses_future = true) {
+#if __cpp_lib_is_invocable >= 201703
+    static_assert(
+        std::is_invocable_r<void, T, Future&>::value,
+        "The callback must have signature void(Future&)");
+#endif
+
+    std::unique_lock lock(mutex_);
+    if (completed()) {
+      lock.unlock();
+      invokeCallback(std::move(callback), uses_future);
+      return;
+    }
+    callbacks_.emplace_back(std::move(callback), uses_future);
+  }
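+
+  // Illustrative sketch (not part of the original header): how a caller might
+  // attach a completion callback. The handle `fut` is hypothetical. If the
+  // future is already completed, addCallback runs the lambda immediately;
+  // otherwise it runs when markCompleted()/setError() fires.
+  //
+  //   c10::intrusive_ptr<ivalue::Future> fut =
+  //       c10::make_intrusive<ivalue::Future>(IntType::get());
+  //   fut->addCallback([](ivalue::Future& f) {
+  //     if (f.hasError()) { /* inspect f.exception_ptr() */ }
+  //     else { int64_t v = f.value().toInt(); (void)v; }
+  //   });
+  //   fut->markCompleted(IValue(42));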
+
+  /**
+   * Add a callback to the future, and return another Future to hold the return
+   * value of the callback. This is necessary when the callback provider needs
+   * to know for sure when the callback has finished.
+   */
+  template <typename T>
+  c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
+    using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
+#if __cpp_lib_is_invocable >= 201703
+    static_assert(
+        std::disjunction<
+            std::is_invocable_r<IValue, T, Future&>,
+            std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
+        "The callback must have signature IValue(Future&) or "
+        "std::tuple<IValue, std::vector<Storage>>(Future&)");
+#endif
+    auto childFut = createInstance(::std::move(type));
+    addCallback([childFut,
+                 cb = std::move(callback)](Future& parentFut) mutable {
+      try {
+        if constexpr (::std::is_convertible_v<typename std::invoke_result_t<T, Future&>, IValueWithStorages>) {
+          auto [ivalue, storages] = cb(parentFut);
+          childFut->markCompleted(::std::move(ivalue), ::std::move(storages));
+        } else {
+          childFut->markCompleted(cb(parentFut));
+        }
+      } catch (std::exception&) {
+        childFut->setError(std::current_exception());
+      }
+    });
+    return childFut;
+  }
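+
+  // Illustrative sketch (assumption, not in the original source): chaining
+  // work with then(). `parent` and the callback body are hypothetical; the
+  // returned child future completes with the callback's IValue, or with its
+  // exception if the callback throws.
+  //
+  //   c10::intrusive_ptr<ivalue::Future> parent =
+  //       c10::make_intrusive<ivalue::Future>(IntType::get());
+  //   auto child = parent->then(
+  //       [](ivalue::Future& f) { return IValue(f.value().toInt() + 1); },
+  //       IntType::get());
+  //   parent->markCompleted(IValue(1));  // child now completes with 2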
+
+  template <typename T>
+  c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
+#if __cpp_lib_is_invocable >= 201703
+    static_assert(
+        std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
+        "The callback must have signature c10::intrusive_ptr<Future>(Future&)");
+#endif
+    auto childFut = createInstance(std::move(type));
+    addCallback(
+        [childFut, cb = std::move(callback)](Future& parentFut) mutable {
+          c10::intrusive_ptr<Future> intermediateFut;
+          try {
+            intermediateFut = cb(parentFut);
+          } catch (std::exception&) {
+            childFut->setError(std::current_exception());
+            return;
+          }
+          intermediateFut->addCallback(
+              [childFut = std::move(childFut)](Future& intermediateFut) {
+                if (intermediateFut.hasError()) {
+                  childFut->setError(intermediateFut.exception_ptr());
+                } else {
+                  childFut->markCompleted(
+                      intermediateFut.value(), intermediateFut.storages());
+                }
+              });
+        });
+    return childFut;
+  }
+
+  // Tries to retrieve the error message from std::exception_ptr.
+  std::string tryRetrieveErrorMessage() const {
+    TORCH_CHECK(hasError(), "No error present on the future.");
+    std::unique_lock lock(mutex_);
+    return tryRetrieveErrorMessageInternal(eptr_);
+  }
+
+  // Check if the current future has completed
+  bool completed() const {
+    return completed_;
+  }
+
+  bool hasValue() const {
+    std::unique_lock lock(mutex_);
+    return completed_ && !eptr_;
+  }
+
+  bool hasError() const {
+    std::unique_lock lock(mutex_);
+    return eptr_ ? true : false;
+  }
+
+  std::exception_ptr exception_ptr() const {
+    std::unique_lock lock(mutex_);
+    return eptr_;
+  }
+
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& out,
+      const Future& v);
+
+  const TypePtr& elementType() const {
+    return type_;
+  }
+
+  const std::vector<c10::Device>& devices() const {
+    return devices_;
+  }
+
+  // This method should be used when one intends to manually create a child
+  // future, for example when implementing a customized version of then().
+  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
+    return c10::make_intrusive<Future>(std::move(type), devices_);
+  }
+
+ private:
+
+  // This method should always be used when invoking a callback (regardless of
+  // how/when that happens) as it will ensure that the proper "environment" is
+  // set up before running the callback, as in, it will set up the CUDA streams,
+  // synchronize them with the value, and so on (if needed).
+  template <typename T>
+  void invokeCallback(T callback, bool uses_future) {
+#if __cpp_lib_is_invocable >= 201703
+    static_assert(
+        std::is_invocable_r<void, T, Future&>::value,
+        "The callback must have signature void(Future&)");
+#endif
+
+    // The synchronization performed below shouldn't be needed when the future
+    // is not used by the callback.
+    if (uses_future) {
+      c10::OptionalDeviceGuard deviceGuard(currentDevice_);
+
+      std::vector<c10::Stream> streams;
+      streams.reserve(devices_.size());
+      for (const c10::Device& device : devices_) {
+        streams.push_back(impl_.getStreamFromGlobalPool(device));
+      }
+      c10::MultiStreamGuard streamGuard(streams);
+      synchronizeWithCurrentStreams();
+      callback(*this);
+    } else {
+      callback(*this);
+    }
+  }
+
+  // This method should be called before this future's value is used, as it
+  // ensures that the CUDA streams that are "current" at the callsite properly
+  // synchronize with the value.
+  void synchronizeWithCurrentStreams() {
+    for (c10::Event& event : events_) {
+      event.block(impl_.getStream(event.device()));
+    }
+
+    for (const WeakStorage& weak_storage : storages_) {
+      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
+      if (!storage) {
+        continue;
+      }
+      if (!storage->device().is_cpu()) {
+        impl_.recordDataPtrOnStream(
+            storage->data_ptr(), impl_.getStream(storage->device()));
+      }
+    }
+  }
+
+  void setErrorInternal(
+      std::exception_ptr eptr,
+      std::unique_lock<std::mutex>& lock) {
+    TORCH_CHECK(
+        !eptr_,
+        "Error already set on this Future: ",
+        tryRetrieveErrorMessageInternal(eptr_),
+        ", trying to set error: ",
+        tryRetrieveErrorMessageInternal(eptr));
+    TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
+    completed_ = true;
+    eptr_ = std::move(eptr);
+
+    std::vector cbs;
+    cbs.swap(callbacks_);
+    lock.unlock();
+
+    finished_cv_.notify_all();
+    for (auto& callback : cbs) {
+      invokeCallback(std::move(callback.callback), callback.uses_future);
+    }
+  }
+
+  // Tries to retrieve the error message from std::exception_ptr.
+  std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
+    try {
+      std::rethrow_exception(std::move(eptr));
+    } catch (const std::exception& e) {
+      return e.what();
+    } catch (...) {
+      return "Unknown Exception Type";
+    }
+  }
+
+  // Defined in ivalue.cpp.
+  static std::vector<WeakStorage> extractStorages(
+      const at::IValue& value);
+
+  static std::vector<c10::Device> getDevicesOfStorages(
+      const c10::impl::VirtualGuardImpl& impl,
+      const std::vector<WeakStorage>& storages) {
+    c10::DeviceIndex deviceCount = impl.deviceCount();
+    std::vector<bool> isDeviceUsed(deviceCount, false);
+    for (const WeakStorage& weak_storage : storages) {
+      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
+      if (!storage) {
+        continue;
+      }
+      c10::Device device = storage->device();
+      if (!device.is_cpu()) {
+        TORCH_CHECK_VALUE(
+            device.type() == impl.type(),
+            "Expected all data ptrs to be on a device of type ",
+            impl.type(),
+            ", got one on device ",
+            device);
+        isDeviceUsed[device.index()] = true;
+      }
+    }
+    std::vector<c10::Device> devices;
+    for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) {
+      if (isDeviceUsed[idx]) {
+        devices.emplace_back(impl.type(), idx);
+      }
+    }
+    return devices;
+  }
+
+  static std::string formatSetOfDevices(
+      const std::vector<c10::Device>& devices) {
+    if (devices.empty()) {
+      return "(none)";
+    }
+    std::ostringstream oss;
+    oss << devices[0];
+    for (const auto idx : c10::irange(1, devices.size())) {
+      if (idx == devices.size() - 1) {
+        oss << " and ";
+      } else {
+        oss << ", ";
+      }
+      oss << devices[idx];
+    }
+    return oss.str();
+  }
+
+  static c10::DeviceType getTypeOfDevices(
+      const std::vector<c10::Device>& devices) {
+    if (devices.empty()) {
+      return c10::kCPU;
+    }
+    c10::DeviceType deviceType = devices[0].type();
+    for (const auto idx : c10::irange(1, devices.size())) {
+      TORCH_CHECK_VALUE(
+          devices[idx].type() == deviceType,
+          "Expected all devices to be of the same type, but got a mismatch between ",
+          devices[0],
+          " and ",
+          devices[idx]);
+    }
+    return deviceType;
+  }
+
+  // We need devices to be sorted in order to use ensureIsSubsetOfDevices.
+  static std::vector<c10::Device> sortAndDeduplicateDevices(
+      const c10::impl::VirtualGuardImpl& /*impl*/,
+      std::vector<c10::Device> devices) {
+    std::sort(
+      devices.begin(), devices.end(),
+      [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
+    // Deduplicate by compacting.
+    size_t targetIdx = 0;
+    for (const auto sourceIdx : c10::irange(devices.size())) {
+      TORCH_CHECK_VALUE(
+          devices[sourceIdx].has_index(),
+          "Expected devices to have indices, got ", devices[sourceIdx]);
+      if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) {
+        // It's a duplicate, skip it.
+        continue;
+      }
+      if (sourceIdx != targetIdx) {
+        devices[targetIdx] = devices[sourceIdx];
+      }
+      targetIdx++;
+    }
+    // If there were duplicates there's now a gap at the end: trim it. Resizing
+    // requires the item type to be default-constructible (which c10::Device is
+    // not) because in principle it could be required to create new items. Since
+    // we know we'll shrink the vector, we provide a custom dummy value instead.
+    devices.resize(targetIdx, c10::Device(c10::kCPU));
+    return devices;
+  }
+
+  static void ensureIsSubsetOfDevices(
+      const std::vector<c10::Device>& subset,
+      const std::vector<c10::Device>& superset) {
+    // We assume the devices in both vectors have the same consistent type, and
+    // their indices are unique and sorted.
+    std::vector<c10::Device> excessDevices;
+    std::set_difference(
+        subset.begin(),
+        subset.end(),
+        superset.begin(),
+        superset.end(),
+        std::back_inserter(excessDevices),
+        [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
+    TORCH_CHECK_VALUE(
+        excessDevices.empty(),
+        "The result contained tensors residing on device(s) ",
+        formatSetOfDevices(excessDevices),
+        " which are not among the expected device(s) ",
+        formatSetOfDevices(superset));
+  }
+
+  mutable std::mutex mutex_;
+  std::atomic_bool completed_ = {false}; // is this future complete
+  std::condition_variable finished_cv_;
+
+  IValue value_; // when finished the value
+  TypePtr type_;
+  std::vector callbacks_;
+  std::exception_ptr eptr_;
+
+  // An upcast pointer to a virtual class which allows us to manipulate events,
+  // streams, ... in a generic way, without an explicit dependency on CUDA.
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+  const c10::impl::VirtualGuardImpl impl_;
+
+  // The device that was current when markCompleted was called, which we'll
+  // restore when invoking callbacks. It's optional because we'll only store it
+  // if the future completes successfully.
+  optional<c10::Device> currentDevice_;
+
+  // The events that correspond to the completion of the async I/O kernels. They
+  // are recorded on the appropriate streams when the future is marked completed
+  // and can then be queried/waited/blocked on. There is one event for each
+  // distinct device on which the value's tensors reside.
+  std::vector<c10::Event> events_;
+
+  // A cached version of the storages extracted from the value when the future
+  // is first marked completed.
+  std::vector<WeakStorage> storages_;
+
+  // The bounding set of devices that this future, and any of its children, is
+  // allowed to use. This is a superset of the set of devices used by the events
+  // above. We need this to know what streams (for which devices) to set as
+  // current when invoking a callback, thus allowing the callback to use devices
+  // that the parent future didn't use. This field is set to the value provided
+  // in the constructor and will be "inherited" by all child futures.
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
+  const std::vector<c10::Device> devices_;
+};
+
+struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target {
+ private:
+  explicit Await(TypePtr elType, std::function<IValue()> fn)
+      : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {}
+
+  explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { }
+
+  friend c10::intrusive_ptr<Await>;
+
+ public:
+  Await(const Await&) = delete;
+  Await(Await&&) = delete;
+  Await& operator=(const Await&) = delete;
+  Await& operator=(Await&&) = delete;
+
+  IValue wait() {
+    if (!completed_) {
+      TORCH_CHECK(fn_, "Incompleted Await: fn can't be None");
+      value_ = fn_();
+      completed_ = true;
+      args_ = {};
+    }
+    return value_;
+  }
+
+  IValue value() {
+    TORCH_CHECK(completed_, "Await must be completed");
+    return value_;
+  }
+
+  void setFn(std::function<IValue()> fn) {
+    fn_ = std::move(fn);
+  }
+
+  bool completed() {
+    return completed_;
+  }
+
+  void markCompleted(IValue value) {
+    value_ = std::move(value);
+    completed_ = true;
+  }
+
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& out,
+      const Await& v);
+
+  const TypePtr& elementType() const {
+    return elType_;
+  }
+
+  const TypePtr& type() const {
+    return type_;
+  }
+
+  void setArgs(std::vector<IValue> args) {
+    args_ = std::move(args);
+  }
+
+  std::vector<IValue>& args() {
+    return args_;
+  }
+
+ private:
+  TypePtr elType_;
+  TypePtr type_;
+  std::vector<IValue> args_;
+  std::function<IValue()> fn_;
+  IValue value_;
+  bool completed_{};
+};
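+
+// Illustrative sketch (assumption, not part of the original header): given an
+// existing handle `aw` of type c10::intrusive_ptr<ivalue::Await>, wait() runs
+// the stored fn_ once and caches the result, after which value() is safe.
+//
+//   IValue v = aw->wait();        // computes lazily on first call
+//   if (aw->completed()) {
+//     IValue w = aw->value();     // returns the cached result
+//   }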
+
+// Input is a list of Futures with the same target type.
+// Output is a Future to the List of completed Futures.
+TORCH_API intrusive_ptr<ivalue::Future> collectAll(
+    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
+// Input is a List of Futures with the same target type.
+// Output is a Future that will be updated with a seen value.
+TORCH_API intrusive_ptr<ivalue::Future> collectAny(
+    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
+
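+// Illustrative sketch (assumption, not in the original source): waiting on a
+// group of futures. `futs` is a hypothetical, already-populated
+// c10::List<c10::intrusive_ptr<ivalue::Future>>.
+//
+//   auto all = collectAll(futs);  // completes once every future in `futs` completes
+//   auto any = collectAny(futs);  // completes as soon as one of them completes
+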
+// User-defined object.
+struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
+ public:
+  // In general, class types hold a shared_ptr to their owning CompilationUnit,
+  // so that their type and methods do not get deallocated while the object
+  // exists. However, the CompilationUnit holds ownership of the type's graphs,
+  // so inserting a constant object into a Graph would create a reference cycle
+  // if that constant object held a shared_ptr to its CU. For these objects we
+  // instantiate them with non-owning references to their CU.
+  Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
+    slots_.resize(numSlots);
+  }
+
+  Object(StrongTypePtr type, size_t numSlots)
+      : type_(WeakOrStrongTypePtr(std::move(type))) {
+    slots_.resize(numSlots);
+  }
+
+  static c10::intrusive_ptr<Object> create(
+      WeakOrStrongTypePtr type,
+      size_t numSlots) {
+    return c10::make_intrusive<Object>(std::move(type), numSlots);
+  }
+
+  static c10::intrusive_ptr<Object> create(
+      StrongTypePtr type,
+      size_t numSlots) {
+    return c10::make_intrusive<Object>(std::move(type), numSlots);
+  }
+
+  static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots);
+
+  /**
+   * Slot API.
+   *
+   * Attributes are stored as a simple vector so that lookups are fast at
+   * runtime. A "slot" is just an index into that vector, which can be computed
+   * statically if you have access to the class type. Use this API if you are
+   * writing compiler stuff.
+   */
+  void setSlot(size_t slot, IValue v) {
+    if (slot >= slots_.size()) {
+      // for module types, it is possible that the members of the class have
+      // expanded after the object was created. In this case, we expand
+      // the slots to the right size
+      resizeObject(slot);
+    }
+    slots_[slot] = std::move(v);
+  }
+
+  const IValue& getSlot(size_t slot) const {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size());
+    // NOTE: This lookup is fairly hot, so we use unchecked access to the
+    // vector.  Errors should still be detectable with ASan.
+    return slots_[slot];
+  }
+
+  void unsafeRemoveSlot(size_t slot) {
+    TORCH_CHECK(slot < slots_.size());
+    slots_.erase(slots_.begin() + static_cast<std::ptrdiff_t>(slot));
+  }
+
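+  // Illustrative sketch (assumption, not in the original source): slots are
+  // positional. `obj` is a hypothetical c10::intrusive_ptr<ivalue::Object>
+  // created with two slots.
+  //
+  //   obj->setSlot(0, IValue(1.5));
+  //   obj->setSlot(1, IValue(std::string("name")));
+  //   double d = obj->getSlot(0).toDouble();
+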
+  /**
+   * Attribute API.
+   *
+   * Wrappers around the slot stuff so that users can access attributes
+   * directly. Use this API if you are a user.
+   *
+   * Note: Unlike in Python, TorchScript must make a distinction between
+   * attributes (which are IValues) and methods (which are Methods). If you
+   * want a method, use `obj.type()->getMethod()`
+   */
+  IValue getAttr(const std::string& name) const;
+  void setAttr(const std::string& name, IValue v);
+  // Remove an attribute by name. The caller is responsible for the safety of
+  // this operation. We do not remove the attribute from the type because the
+  // type might be shared by multiple objects. After removing an attribute,
+  // the object is therefore in an inconsistent state: its Type declares more
+  // attributes than the object has attribute slots. The user needs to restore
+  // consistency by removing the attribute from the type as well.
+  void unsafeRemoveAttr(const std::string& name);
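+
+  // Illustrative sketch (assumption, not in the original source): name-based
+  // access resolves the slot through the object's ClassType. `obj` is a
+  // hypothetical module-like object with a "weight" attribute.
+  //
+  //   obj->setAttr("weight", IValue(at::ones({2, 2})));
+  //   at::Tensor w = obj->getAttr("weight").toTensor();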
+
+  std::string name() const;
+
+  const std::vector<IValue>& slots() const {
+    return slots_;
+  }
+  std::shared_ptr<ClassType> type() const;
+
+  std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
+    if (type_.holds_strong_ref()) {
+      return type_.cu_.getStrongRefOrThrow();
+    } else {
+      auto weak_ptr = type_.cu_.getWeakRefOrThrow();
+      return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr);
+    }
+  }
+
+  c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const;
+
+  void unsafe_make_weak_compilation_ref() {
+    type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr());
+  }
+
+  c10::intrusive_ptr<Object> copy() const;
+
+  c10::intrusive_ptr<Object> deepcopy(
+      c10::optional<at::Device> device = c10::nullopt) const;
+
+  c10::intrusive_ptr<Object> deepcopy(
+      IValue::HashAliasedIValueMap& memo,
+      c10::optional<at::Device> device = c10::nullopt) const;
+
+  bool is_weak_compilation_ref() const {
+    return !type_.holds_strong_ref();
+  }
+
+  bool is_empty_strong_compilation_ref() const {
+    return type_.holds_empty_strong_ref();
+  }
+
+ private:
+  void resizeObject(size_t slot);
+  WeakOrStrongTypePtr type_;
+  std::vector<IValue> slots_;
+};
+
+// virtual ivalue PyObjectHolder that hold a py::object, we make this virtual
+// because the py::object and refcounting logic should happen in libtorch_python
+// see concrete implementation in python_ivalue.h
+struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
+ public:
+  virtual PyObject* getPyObject() = 0;
+  virtual c10::InferredType tryToInferType() = 0;
+  virtual IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt) = 0;
+  virtual std::string toStr() = 0;
+  virtual std::vector<at::Tensor> extractTensors() = 0;
+
+  ~PyObjectHolder() override = default;
+};
+
+struct ivalue::EnumHolder : c10::intrusive_ptr_target {
+ public:
+  EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value)
+      : type_(std::move(type)),
+        name_(std::move(name)),
+        value_(std::move(value)) {}
+
+  bool is(const ivalue::EnumHolder& rhs) {
+    return *this == rhs;
+  }
+
+  friend bool operator==(
+      const ivalue::EnumHolder& lhs,
+      const ivalue::EnumHolder& rhs);
+
+  TORCH_API friend std::ostream& operator<<(
+      std::ostream& out,
+      const ivalue::EnumHolder& v);
+
+  TORCH_API const std::string& qualifiedClassName() const;
+
+  const std::string& unqualifiedClassName() const;
+
+  const std::string& name() const {
+    return name_;
+  }
+
+  const IValue& value() const {
+    return value_;
+  }
+
+  std::shared_ptr<EnumType> type() const {
+    return type_;
+  }
+
+ private:
+  std::shared_ptr<EnumType> type_;
+  std::string name_;
+  IValue value_;
+};
+
+#undef TORCH_FORALL_TAGS
+
+namespace detail {
+
+struct _guarded_unsigned_long_unique_dummy final {
+  _guarded_unsigned_long_unique_dummy(int64_t){};
+};
+using _guarded_unsigned_long = std::conditional_t<
+    std::is_same<unsigned long, uint32_t>::value ||
+        std::is_same<unsigned long, uint64_t>::value,
+    _guarded_unsigned_long_unique_dummy,
+    unsigned long>;
+
+} // namespace detail
+
+inline ivalue::Object& IValue::toObjectRef() const {
+  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
+  return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
+}
+
+// note: when adding a DEFINE_TO case here you should also add a
+// toX method to IValue. These named methods are much more discoverable
+// than the to templated function.
+
+#define DEFINE_TO(T, method_name)                          \
+  template <>                                              \
+  inline T IValue::to<T>()&& {                             \
+    return static_cast<T>(std::move(*this).method_name()); \
+  }                                                        \
+  template <>                                              \
+  inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \
+    typedef c10::detail::ivalue_to_const_ref_overload_return<T>::type return_type;          \
+    return static_cast<return_type>(this->method_name());                                   \
+  }
+
+DEFINE_TO(at::Tensor, toTensor)
+DEFINE_TO(at::Storage, toStorage)
+DEFINE_TO(c10::Stream, toStream)
+DEFINE_TO(float, toDouble)
+DEFINE_TO(double, toDouble)
+DEFINE_TO(c10::complex<double>, toComplexDouble)
+DEFINE_TO(unsigned char, toInt)
+DEFINE_TO(signed char, toInt)
+DEFINE_TO(unsigned short, toInt)
+DEFINE_TO(short, toInt)
+DEFINE_TO(int, toInt)
+DEFINE_TO(uint32_t, toInt)
+DEFINE_TO(uint64_t, toInt)
+DEFINE_TO(detail::_guarded_unsigned_long, toInt)
+DEFINE_TO(int64_t, toInt)
+DEFINE_TO(bool, toBool)
+DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob);
+DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
+DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
+DEFINE_TO(at::Scalar, toScalar)
+DEFINE_TO(c10::List<int64_t>, toIntList)
+DEFINE_TO(c10::List<double>, toDoubleList)
+DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
+DEFINE_TO(c10::List<bool>, toBoolList)
+DEFINE_TO(c10::List<at::Tensor>, toTensorList)
+DEFINE_TO(c10::impl::GenericList, toList)
+DEFINE_TO(c10::impl::GenericDict, toGenericDict)
+DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
+DEFINE_TO(std::string, toStringRef)
+DEFINE_TO(c10::string_view, toStringView)
+DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
+DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
+DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
+DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer)
+DEFINE_TO(IValue, toIValue)
+DEFINE_TO(c10::Device, toDevice)
+DEFINE_TO(at::ScalarType, toScalarType)
+DEFINE_TO(at::Layout, toLayout)
+DEFINE_TO(at::MemoryFormat, toMemoryFormat)
+DEFINE_TO(at::QScheme, toQScheme)
+DEFINE_TO(at::Dimname, toDimname)
+DEFINE_TO(at::Generator, toGenerator)
+DEFINE_TO(c10::SymInt, toSymInt)
+DEFINE_TO(c10::SymFloat, toSymFloat)
+DEFINE_TO(c10::SymBool, toSymBool)
+
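+// Illustrative sketch (assumption, not part of the original header): the
+// DEFINE_TO instantiations above are what make the typed accessors below
+// compile. The rvalue overload moves the payload out; the const& overload
+// returns a reference or a copy per ivalue_to_const_ref_overload_return.
+//
+//   IValue iv(at::ones({2}));
+//   const at::Tensor& ref = iv.to<at::Tensor>();        // const& overload, no copy
+//   at::Tensor owned = std::move(iv).to<at::Tensor>();  // rvalue overload
+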
+template 
+struct _fake_type {};
+
+// generic_to<T> converts an IValue from a generic list or generic dict
+// to a concrete list/dict type like List<T>, Dict<...> or optional<T>.
+// Note that in the case of lists, this only works for IValue-based lists,
+// i.e. not for int64_t, double, ...
+// generic_to is an implementation detail of IValue::to and not
+// supposed to be called directly.
+// The _fake_type parameter allows us to overload
+// based on the return type.
+template <class Elem>
+// TODO this is deprecated but we don't throw a warning because a lot of ops in
+// native_functions.yaml still return std::vector.
+// C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
+// and deprecated. Please use torch::List<T> instead.")
+std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
+  // We need to do a deep copy of the vector because there might be other
+  // references to this same IValue that also use the list. We can't just
+  // move the elements out.
+  auto list = std::move(ivalue).to<List<Elem>>();
+  std::vector<Elem> result;
+  result.reserve(list.size());
+  for (Elem v : list) {
+    result.push_back(std::move(v));
+  }
+  return result;
+}
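+
+// Illustrative sketch (assumption, not in the original source): converting a
+// list-typed IValue into std::vector copies each element, so other references
+// to the same underlying list stay valid.
+//
+//   IValue iv(c10::List<int64_t>({1, 2, 3}));
+//   std::vector<int64_t> v = iv.to<std::vector<int64_t>>();  // deep copy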
+
+template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() && {
+  static_assert(
+      std::is_base_of<torch::CustomClassHolder, T>::value == true,
+      "toCustomClass requires that template parameter T must inherit "
+      "from torch::CustomClassHolder");
+  auto obj = toObject();
+  TORCH_CHECK(
+      obj->slots().size() == 1,
+      "Tried to cast IValue to custom class but it did "
+      "not contain a custom class!");
+  const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
+  ivalue::checkCustomClassType(expected_type, type().get());
+  auto userObj =
+      c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+  return userObj;
+}
+
+template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() const& {
+  static_assert(
+      std::is_base_of<torch::CustomClassHolder, T>::value == true,
+      "toCustomClass requires that template parameter T must inherit "
+      "from torch::CustomClassHolder");
+  auto obj = toObject();
+  TORCH_CHECK(
+      obj->slots().size() == 1,
+      "Tried to cast IValue to custom class but it did "
+      "not contain a custom class!");
+  const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
+  ivalue::checkCustomClassType(expected_type, type().get());
+  auto userObj =
+      c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+  return userObj;
+}
+
+template <typename T>
+T generic_to(IValue ivalue, _fake_type<T>) {
+  using ElemType = typename std::remove_pointer<T>::type::element_type;
+  return std::move(ivalue).toCustomClass<ElemType>();
+}
+
+template <typename T>
+tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
+  return tagged_capsule<T>{std::move(ivalue)};
+}
+
+template <typename Elem>
+c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
+  return impl::toTypedList<Elem>(std::move(ivalue).toList());
+}
+
+template 
+static T createVectorLikeFromList(const c10::detail::ListImpl* impl) {
+  T result;
+  result.reserve(impl->list.size());
+  for (const auto & i : impl->list) {
+    result.push_back(i.to());
+  }
+  return result;
+}
+
+template 
+static std::vector createVectorFromList(const c10::detail::ListImpl* impl) {
+  return createVectorLikeFromList>(impl);
+}
+
+template 
+std::vector createVectorFromList(const c10::List& impl) {
+  std::vector result;
+  result.reserve(impl.size());
+  for (size_t i = 0, N = impl.size(); i < N; ++i) {
+    result.push_back(impl[i]);
+  }
+  return result;
+}
+
+template 
+OptionalArray generic_to(IValue ivalue, _fake_type>) {
+  if (ivalue.isNone()) {
+    return {};
+  }
+  return createVectorFromList(
+    std::move(ivalue).to>()
+  );
+}
+
+namespace detail {
+template 
+std::array generic_to_array(
+    IValue ivalue,
+    _fake_type>,
+    std::index_sequence) {
+  // We need to do a deep copy of the array because there might be other
+  // references to this same IValue that also use the list. We can't just
+  // move the elements out.
+  auto list = std::move(ivalue).to>();
+  TORCH_CHECK(
+      list.size() == sizeof...(I),
+      "Tried to convert a List with ",
+      list.size(),
+      " elements to a fixed-size array of size ",
+      sizeof...(I));
+  return {list[I]...};
+}
+} // namespace detail
+
+template 
+std::array generic_to(
+    IValue ivalue,
+    _fake_type> ft) {
+  return detail::generic_to_array(ivalue, ft, std::make_index_sequence());
+}
+
+template 
+c10::Dict generic_to(
+    IValue ivalue,
+    _fake_type>) {
+  return impl::toTypedDict(std::move(ivalue).toGenericDict());
+}
+
+template 
+C10_DEPRECATED_MESSAGE(
+    "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.")
+std::unordered_map generic_to(
+    IValue ivalue,
+    _fake_type>) {
+  std::unordered_map specialized_dict;
+
+  for (const auto& item : std::move(ivalue).toGenericDict()) {
+    specialized_dict[item.key().template to()] = item.value().template to();
+  }
+
+  return specialized_dict;
+}
+
+template 
+c10::optional generic_to(IValue ivalue, _fake_type>) {
+  if (ivalue.isNone()) {
+    return c10::nullopt;
+  }
+  return std::move(ivalue).to();
+}
+
+namespace detail {
+template 
+Tuple generic_to_tuple_impl(
+    const ivalue::TupleElements& t,
+    std::index_sequence) {
+  return std::make_tuple(
+      t[INDEX].to::type>()...);
+}
+} // namespace detail
+
+template <
+    typename... Args,
+    typename Indices = std::make_index_sequence,
+    std::enable_if_t<
+        !std::disjunction<
+            std::is_lvalue_reference...,
+            std::negation>...>::value,
+        std::nullptr_t> = nullptr>
+std::tuple generic_to(const IValue& ivalue, _fake_type>) {
+  const auto& vals = ivalue.toTupleRef().elements();
+  TORCH_CHECK(vals.size() == sizeof...(Args));
+  return detail::generic_to_tuple_impl>(vals, Indices{});
+}
+
+template 
+inline T IValue::to() && {
+  return generic_to(std::move(*this), _fake_type{});
+}
+
+template <>
+inline c10::optional IValue::to() && {
+  // In the default implementation, the IValue is destroyed with std::move.
+  // But if the unboxed type is optional we cannot destroy
+  // the IValue.
+  return generic_to(*this, _fake_type>{});
+}
+
+template 
+inline typename c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& {
+  return generic_to(*this, _fake_type{});
+}
+
+inline c10::List IValue::toIntList() && {
+  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+  return c10::List(moveToIntrusivePtr());
+}
+inline c10::List IValue::toIntList() const& {
+  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+  return c10::List(toIntrusivePtr());
+}
+inline std::vector IValue::toIntVector() const {
+  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toIntVector on null intrusive_ptr IValue");
+  return createVectorFromList(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline std::vector IValue::toSymIntVector() const {
+  AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toSymIntVector on null intrusive_ptr IValue");
+  return createVectorFromList(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline at::DimVector IValue::toDimVector() const {
+  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toDimVector on null intrusive_ptr IValue");
+  return createVectorLikeFromList(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline c10::List IValue::toDoubleList() && {
+  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+  return c10::List(moveToIntrusivePtr());
+}
+inline c10::List IValue::toDoubleList() const& {
+  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+  return c10::List(toIntrusivePtr());
+}
+inline std::vector IValue::toDoubleVector() const {
+  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toDoubleVector on null intrusive_ptr IValue");
+  return createVectorFromList(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline c10::List> IValue::toComplexDoubleList() && {
+  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
+  return c10::List>(moveToIntrusivePtr());
+}
+inline c10::List> IValue::toComplexDoubleList() const& {
+  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
+  return c10::List>(toIntrusivePtr());
+}
+inline std::vector> IValue::toComplexDoubleVector() const {
+  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toComplexDoubleVector on null intrusive_ptr IValue");
+  return createVectorFromList>(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline c10::List IValue::toBoolList() && {
+  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
+  return c10::List(moveToIntrusivePtr());
+}
+inline c10::List IValue::toBoolList() const& {
+  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
+  return c10::List(toIntrusivePtr());
+}
+inline c10::List IValue::toTensorList() && {
+  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+  return c10::List(moveToIntrusivePtr());
+}
+inline c10::List IValue::toTensorList() const& {
+  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+  return c10::List(toIntrusivePtr());
+}
+inline std::vector IValue::toTensorVector() const {
+  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toTensorVector on null intrusive_ptr IValue");
+  return createVectorFromList(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline c10::List> IValue::toOptionalTensorList() && {
+  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
+  return c10::List>(moveToIntrusivePtr());
+}
+inline c10::List> IValue::toOptionalTensorList() const& {
+  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
+  return c10::List>(toIntrusivePtr());
+}
+inline std::vector> IValue::toOptionalTensorVector() const {
+  AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toOptionalTensorVector on null intrusive_ptr IValue");
+  return createVectorFromList>(
+      static_cast(payload.u.as_intrusive_ptr));
+}
+inline c10::List IValue::toList() && {
+  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
+  return c10::List(moveToIntrusivePtr());
+}
+inline c10::List IValue::toList() const& {
+  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
+  return c10::List(toIntrusivePtr());
+}
+inline c10::ArrayRef IValue::toListRef() const {
+  AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toListRef on null intrusive_ptr IValue");
+  return static_cast(payload.u.as_intrusive_ptr)
+      ->list;
+}
+inline c10::Dict IValue::toGenericDict() && {
+  AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
+  return c10::Dict(moveToIntrusivePtr());
+}
+inline c10::Dict IValue::toGenericDict() const& {
+  AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
+  return c10::Dict(toIntrusivePtr());
+}
+inline c10::intrusive_ptr IValue::toTuple() && {
+  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
+  return moveToIntrusivePtr();
+}
+inline c10::intrusive_ptr IValue::toTuple() const& {
+  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
+  return toIntrusivePtr();
+}
+inline ivalue::Tuple& IValue::toTupleRef() const {
+  AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toTupleRef on null intrusive_ptr IValue");
+  return *static_cast(
+      payload.u.as_intrusive_ptr);
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Tuple) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+template <
+    typename... Args,
+    std::enable_if_t<
+        !std::disjunction<
+            std::is_lvalue_reference...,
+            std::negation>...>::value,
+        std::nullptr_t>>
+inline IValue::IValue(const std::tuple& t)
+    : IValue(c10::guts::apply(c10::ivalue::Tuple::create, t)) {
+}
+
+template <
+    typename... Args,
+    std::enable_if_t<
+        !std::disjunction<
+            std::is_lvalue_reference...,
+            std::negation>...>::value,
+        std::nullptr_t>>
+inline IValue::IValue(std::tuple&& t)
+    : IValue(c10::guts::apply(c10::ivalue::Tuple::create, std::move(t))) {
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::String) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+inline IValue::IValue(std::string v)
+    : IValue(ivalue::ConstantString::create(std::move(v))) {}
+
+inline IValue::IValue(c10::impl::GenericList v)
+    : tag(Tag::GenericList) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
+}
+
+template >
+inline IValue::IValue(c10::List&& v) : IValue(impl::toList(std::move(v))) {}
+template >
+inline IValue::IValue(const c10::List& v) : IValue(impl::toList(v)) {}
+template >
+inline IValue::IValue(at::ArrayRef v) : IValue(c10::List()) {
+  auto list = to>();
+  list.reserve(v.size());
+  for (const auto& e : v) {
+    list.push_back(e);
+  }
+}
+template >
+inline IValue::IValue(at::ArrayRef v) : IValue() {
+  auto vi = c10::asIntArrayRefSlowOpt(v);
+  if (vi.has_value()) {
+    // This list is entirely integers; ensure it is typed as
+    // an IntList so toIntList works
+    *this = IValue(*vi);
+  } else {
+    // This list has SymInts; type it as a SymInt
+    *this = IValue(impl::toList(c10::List()));
+    auto list = to>();
+    list.reserve(v.size());
+    for (const auto& e : v) {
+      list.push_back(e);
+    }
+  }
+}
+template >
+inline IValue::IValue(at::OptionalArrayRef mb_v) : IValue() {
+  if (!mb_v.has_value()) return;
+  *this = IValue(*mb_v);
+}
+template >
+inline IValue::IValue(const std::vector& v) : IValue() {
+  *this = IValue(at::ArrayRef(v));
+}
+template >
+inline IValue::IValue(std::vector&& v) : IValue() {
+  auto vi = c10::asIntArrayRefSlowOpt(v);
+  if (vi.has_value()) {
+    // This list is entirely integers; ensure it is typed as
+    // an IntList so toIntList works
+    *this = IValue(*vi);
+  } else {
+    // This list has SymInts; type it as a SymInt
+    *this = IValue(impl::toList(c10::List()));
+    auto list = to>();
+    list.reserve(v.size());
+    for (auto& e : v) {
+      list.push_back(std::move(e));
+    }
+  }
+}
+template >
+inline IValue::IValue(const std::vector& v) : IValue(c10::List()) {
+  auto list = to>();
+  list.reserve(v.size());
+  for (const auto& e : v) {
+    list.push_back(e);
+  }
+}
+
+template >
+inline IValue::IValue(std::vector&& v) : IValue(c10::List()) {
+  auto list = to>();
+  list.reserve(v.size());
+  if constexpr (std::is_same_v) {
+    for (auto e : v) {
+      list.push_back(e);
+    }
+  } else {
+    for (auto& e : v) {
+      list.push_back(std::move(e));
+    }
+  }
+}
+
+template >
+inline IValue::IValue(c10::OptionalArrayRef v) : IValue() {
+  if (v.has_value()) {
+    *this = IValue(std::move(*v));
+  }
+}
+
+template 
+inline IValue::IValue(std::array v) : IValue(c10::List()) {
+  auto list = to>();
+  list.reserve(v.size());
+  for (auto& e : v) {
+    list.push_back(std::move(e));
+  }
+}
+
+template >
+inline IValue::IValue(c10::IListRef v) : IValue() {
+  constexpr bool boxed_type_constructs_ivalue =
+      std::is_constructible::boxed_type>::value;
+  // First, we try to use the boxed value.
+  // If we fail (either it's not in the boxed state, or its boxed type
+  // can not construct an IValue), we fallback to copying the list.
+  if (boxed_type_constructs_ivalue && v.isBoxed()) {
+    *this = IValue(impl::toList(v.toBoxed()));
+  } else {
+    c10::List list;
+    list.reserve(v.size());
+    for (const auto& t : v) {
+      list.push_back(t);
+    }
+    *this = IValue(impl::toList(std::move(list)));
+  }
+}
+
+inline IValue::IValue(c10::impl::GenericDict v)
+    : tag(Tag::GenericDict) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
+}
+template 
+inline IValue::IValue(c10::Dict v)
+    : IValue(impl::toGenericDict(std::move(v))) {}
+
+template 
+inline IValue::IValue(std::unordered_map v)
+    : IValue(Dict()) {
+  auto dict = to>();
+  dict.reserve(v.size());
+  for (auto& e : v) {
+    dict.insert(std::move(e.first), std::move(e.second));
+  }
+}
+
+template >
+inline IValue::IValue(c10::optional v) : IValue() {
+  if (v.has_value()) {
+    *this = IValue(std::move(*v));
+  }
+}
+
+inline IValue::IValue(c10::nullopt_t) : IValue() {}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Object) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::PyObject) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Enum) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue IValue::make_capsule(
+    intrusive_ptr blob) {
+  IValue iv;
+  iv.tag = Tag::Capsule;
+  iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
+  return iv;
+}
+
+template <
+    typename T,
+    std::enable_if_t::value, int>>
+IValue::IValue(c10::intrusive_ptr custom_class) : tag(Tag::Object) {
+  auto classType = []() {
+    try {
+      return c10::getCustomClassType>();
+    } catch (const c10::Error&) {
+      throw c10::Error(
+          "Trying to instantiate a class that isn't a registered custom class: " +
+          std::string(c10::util::get_fully_qualified_type_name()),
+          "");
+    }
+  }();
+  auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
+  ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
+
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Future) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Await) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::RRef) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+inline IValue::IValue(c10::intrusive_ptr v)
+    : tag(Tag::Quantizer) {
+  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
+}
+
+template 
+inline IValue::IValue(c10::complex c)
+    : tag(Tag::ComplexDouble) {
+  auto v = c10::make_intrusive(c);
+  payload.u.as_intrusive_ptr = v.release();
+}
+
+inline const std::string& IValue::toStringRef() const {
+  AT_ASSERT(isString(), "Expected String but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toStringRef on null intrusive_ptr IValue");
+  return static_cast(
+             payload.u.as_intrusive_ptr)
+      ->string();
+}
+inline c10::optional<std::reference_wrapper<const std::string>> IValue::
+    toOptionalStringRef() const {
+  if (isNone()) {
+    return c10::nullopt;
+  }
+  AT_ASSERT(isString(), "Expected optional but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toOptionalStringRef on null intrusive_ptr IValue");
+  return std::reference_wrapper(
+      static_cast(payload.u.as_intrusive_ptr)
+          ->string());
+}
+
+inline c10::string_view IValue::toStringView() const {
+  AT_ASSERT(isString(), "Expected String but got ", tagKind());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
+      "called toStringView on null intrusive_ptr IValue");
+  return static_cast(
+        payload.u.as_intrusive_ptr)
+    ->string_view();
+}
+
+inline PyObject* IValue::toPyObject() const {
+  return toPyObjectHolder()->getPyObject();
+}
+
+template 
+inline optional IValue::toOptional() {
+  if (this->isNone()) {
+    return nullopt;
+  }
+  return this->to();
+}
+
+template 
+inline optional IValue::toOptional() const {
+  if (this->isNone()) {
+    return nullopt;
+  }
+  return this->to();
+}
+
+inline bool IValue::isCustomClass() const {
+  return torch::isCustomClass(*this);
+}
+
+inline bool IValue::isSameIdentity(const IValue& rhs) const {
+  // We choose to not use memcmp for payload check due to potential random
+  // padding characters on union type
+
+  // Semantics:
+  // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
+  // Str) return value equality
+  // 2. If it is a tensor type, we need to take undefined tensor into account
+  // 3. Undefined_tensor is None and vice versa should be true
+  // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when
+  // the pointed-to object is the same.
+  // 5. False for all other comparisons.
+  if (this->isNone() && rhs.isNone()) {
+    return true;
+  } else if (this->isBool() && rhs.isBool()) {
+    // for bool type, do equality check
+    return this->toBool() == rhs.toBool();
+  } else if (this->isTensor() && rhs.isTensor()) {
+    return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
+  } else if (this->isTensor() && rhs.isNone()) {
+    // special case: undefined tensor and None are the same identity
+    return !this->payload.as_tensor.defined();
+  } else if (this->isNone() && rhs.isTensor()) {
+    // special case: undefined tensor and None are the same identity
+    return !rhs.payload.as_tensor.defined();
+  } else if (this->isInt() && rhs.isInt()) {
+    return this->toInt() == rhs.toInt();
+  } else if (this->isDouble() && rhs.isDouble()) {
+    return this->toDouble() == rhs.toDouble();
+  } else if (this->isString() && rhs.isString()) {
+    return this->toStringRef() == rhs.toStringRef();
+  } else {
+    // for objects holding in IValue, do shallow compare on pointer address to
+    // testify the identity
+    return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
+        this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
+  }
+}
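+
+// Illustrative sketch (assumption, not in the original source) of the identity
+// semantics documented above:
+//
+//   IValue none;                    // None
+//   IValue undef((at::Tensor()));   // undefined tensor
+//   none.isSameIdentity(undef);     // true: None and undefined tensor match
+//   IValue a(1.0), b(1.0);
+//   a.isSameIdentity(b);            // true: primitives compare by value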
+
+namespace ivalue {
+namespace detail {
+
+template <typename T>
+IValue from_(T&& x, std::true_type) {
+  return IValue(std::forward<T>(x));
+}
+template <typename T>
+IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
+  return IValue(std::move(x));
+}
+template <typename T>
+IValue from_(T&& /*x*/, std::false_type) {
+  static_assert(
+      guts::false_t<T>::value,
+      "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
+  return IValue();
+}
+} // namespace detail
+
+template <typename T>
+IValue from(T&& x) {
+  return detail::from_(
+      std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
+}
+
+} // namespace ivalue
+
+
+template <>
+struct MaybeOwnedTraits<IValue> {
+  using owned_type = IValue;
+  using borrow_type = IValue;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    if (!from.isPtrType()) {
+      return from;
+    }
+    if (from.isTensor()) {
+      return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
+    } else {
+      return IValue(from.payload, from.tag);
+    }
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.clearToNone();
+    if (!rhs.isPtrType()) {
+      lhs = rhs;
+    } else if (rhs.isTensor()) {
+      lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
+    } else {
+      lhs = IValue(rhs.payload, rhs.tag);
+    }
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.clearToNone();
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type&) {
+    return true;
+  }
+};
+
+template <>
+struct IValue::TagType<c10::Type> {
+  static TORCH_API c10::TypePtr get(const IValue&);
+};
+
+template <>
+struct IValue::TagType<c10::DynamicType> {
+  static TORCH_API c10::TypePtr get(const IValue&);
+};
+
+template <typename T>
+TypePtr IValue::type() const {
+  return IValue::TagType<T>::get(*this);
+}
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_to.h b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_to.h
new file mode 100644
index 0000000000000000000000000000000000000000..f750de76cfa9dc1ae0b1ef975526b38d70eb8bb0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/ivalue_to.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include 
+
+namespace at {
+class Tensor;
+} // namespace at
+
+namespace c10 {
+struct IValue;
+namespace detail {
+// Determine the return type of `IValue::to() const &`. It's a const
+// reference when possible and a copy otherwise. It is in this
+// separate header so that List can use it as well.
+template<typename T>
+struct ivalue_to_const_ref_overload_return {
+  using type = T;
+};
+
+template<>
+struct ivalue_to_const_ref_overload_return<at::Tensor> {
+  using type = const at::Tensor&;
+};
+
+template<>
+struct ivalue_to_const_ref_overload_return<std::string> {
+  using type = const std::string&;
+};
+
+template<>
+struct ivalue_to_const_ref_overload_return<IValue> {
+  using type = const IValue&;
+};
+
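+// Illustrative sketch (assumption, not part of the original header): these
+// specializations mean that, for a hypothetical IValue `iv` holding a tensor,
+// iv.to<at::Tensor>() const& yields `const at::Tensor&` (no copy), while a
+// primitive such as int64_t falls through the primary template and is copied.
+//
+//   const at::Tensor& t = iv.to<at::Tensor>();  // reference, borrows from iv
+//   int64_t n = iv2.to<int64_t>();              // plain copy
+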
+} // namespace detail
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type.h b/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f3a855c1f847f9ea19789cf16a697c87bb77443
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type.h
@@ -0,0 +1,2425 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace torch {
+namespace jit {
+struct Function;
+} // namespace jit
+} // namespace torch
+
+namespace c10 {
+
+template<class Key, class Value>
+class Dict;
+struct IValue;
+struct FunctionSchema;
+struct NamedType;
+using OptNameList = c10::optional<std::vector<std::string>>;
+
+void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
+void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
+
+inline bool is_contiguous_strides(
+    const IntArrayRef sizes,
+    const IntArrayRef strides) {
+  int n_dim = static_cast<int>(sizes.size());
+  if (n_dim == 0) {
+    return true;
+  }
+
+  if (strides[n_dim - 1] != 1) {
+    return false;
+  }
+
+  for (int i = n_dim - 2; i >= 0; i--) {
+    if (strides[i] != strides[i + 1] * sizes[i + 1]) {
+      return false;
+    }
+  }
+  return true;
+}
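+
+// Illustrative sketch (assumption, not in the original source): for a
+// contiguous 2x3 tensor the strides are {3, 1}, which passes the check;
+// a transposed layout such as {1, 3} does not.
+//
+//   is_contiguous_strides({2, 3}, {3, 1});  // true
+//   is_contiguous_strides({2, 3}, {1, 3});  // false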
+
+struct AnyType;
+using AnyTypePtr = SingletonTypePtr<AnyType>;
+// Any is the top of the type hierarchy, all other types are subtypes
+// T <: Any, forall T
+struct TORCH_API AnyType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Any";
+  }
+  static const TypeKind Kind = TypeKind::AnyType;
+  // global singleton
+  static AnyTypePtr get();
+
+ private:
+  AnyType() : Type(TypeKind::AnyType) {}
+};
+
+inline std::string toString(const Type& type) {
+  return type.str();
+}
+
+// Shim for compatibility with code that uses TypePtr.
+inline std::string toString(const TypePtr& typePtr) {
+  return toString(*typePtr);
+}
+
+inline bool operator!=(const Type& lhs, const Type& rhs) {
+  return !(lhs == rhs);
+}
+
+// common base for all types that have a single sub element
+// e.g. Future[T], Optional[T], List[T]
+template <TypeKind K, typename T>
+struct SingleElementType : public SharedType {
+  static const TypeKind Kind = K;
+
+  const TypePtr& getElementType() const {
+    return elem;
+  }
+
+  bool hasFreeVariables() const override {
+    return getElementType()->hasFreeVariables();
+  }
+
+  at::ArrayRef<TypePtr> containedTypes() const override {
+    return elem;
+  }
+
+  bool equals(const Type& rhs) const override {
+    if (auto rhs_ = rhs.cast<T>()) {
+      return *getElementType() == *rhs_->getElementType();
+    }
+    return false;
+  }
+
+ protected:
+  SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
+    if (!this->elem) {
+      throw std::runtime_error(c10::str(
+            "Can not create ", typeKindToString(Kind), " with None type"));
+    }
+  }
+
+ private:
+  TypePtr elem;
+};
+
+struct UnionType;
+using UnionTypePtr = std::shared_ptr<UnionType>;
+struct TORCH_API UnionType : public SharedType {
+  friend struct Type;
+
+  static const TypeKind Kind = TypeKind::UnionType;
+
+  bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
+
+  std::string str() const override;
+
+  static UnionTypePtr create(std::vector<TypePtr> reference);
+
+  bool equals(const Type& rhs) const override;
+
+  bool isUnionType() const override {
+    return true;
+  }
+
+  at::ArrayRef<TypePtr> containedTypes() const override {
+    return types_;
+  }
+
+  // For testing purposes only
+  at::ArrayRef<TypePtr> getTypes() const {
+    return types_;
+  }
+
+  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+    return create(std::move(contained_types));
+  }
+
+  bool canHoldType(const Type& type) const;
+
+  bool hasFreeVariables() const override {
+    return has_free_variables_;
+  }
+
+  c10::optional<TypePtr> toOptional() const;
+
+  c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
+
+ protected:
+    explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
+    std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
+    std::string unionStr(
+        TypePrinter printer = nullptr,
+        bool is_annotation_str = false) const;
+    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+    bool has_free_variables_;
+    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+    std::vector<TypePtr> types_;
+    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+    bool can_hold_none_;
+
+};
+
+struct OptionalType;
+using OptionalTypePtr = std::shared_ptr<OptionalType>;
+// This type represents an optional type. There is one `Optional` for
+// each element type. `Optional[T]` can accept both `T` and
+// `None`(`c10::nullopt` in C++)
+// Subtype hierarchy for Optional:
+//     - Optional[T] <: Optional[R] iff T <: R
+//     - T <: Optional[R] if T <: R
+//     - None <: Optional[T] for all T
+//     - Optional[T] == Union[T, None] for all T
+struct TORCH_API OptionalType : public UnionType {
+  static OptionalTypePtr create(const TypePtr& contained);
+
+  static const TypeKind Kind = TypeKind::OptionalType;
+
+  friend struct Type;
+
+  bool equals(const Type& rhs) const override;
+
+  const TypePtr& getElementType() const {
+    return contained_;
+  }
+
+  at::ArrayRef<TypePtr> containedTypes() const override {
+    return contained_;
+  }
+
+  std::string str() const override {
+    std::stringstream ss;
+    ss << getElementType()->str() << "?";
+    return ss.str();
+  }
+
+  TypePtr createWithContained(
+      std::vector<TypePtr> contained_types) const override {
+    AT_ASSERT(contained_types.size() == 1);
+    return create(contained_types[0]);
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  bool isUnionType() const override {
+    return true;
+  }
+
+  // common cast Optional[Tensor] for undefined tensor type
+  static TypePtr ofTensor();
+  //
+  // global singleton
+  static TypePtr get(TypePtr inner);
+
+ private:
+  explicit OptionalType(const TypePtr& contained);
+
+  TypePtr contained_;
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    std::stringstream ss;
+    ss << "Optional[" << getElementType()->annotation_str(std::move(printer)) << "]";
+    return ss.str();
+  }
+};
+
+template 
+inline c10::optional merge_primitive(
+    const c10::optional& a,
+    const c10::optional& b) {
+  if (a.has_value() && b.has_value() && a.value() == b.value()) {
+    return a;
+  }
+  return c10::optional{};
+}
+
+// If we see `a + b + c` and know that a, b, and c are the same size and have
+// two dimensions (WxH), then we can generate a fused kernel for them. That
+// fused kernel would likely have indexing math to handle both the W and H
+// dimensions. However, if we knew the WxH dimensions were contiguous, we could
+// pretend we only have a single dimension, simplifying the indexing logic.
+// This can be done even if the dimensions are transposed,
+// as long as a, b, and c are transposed in the same way.
+// We'd like the compiler to be able to do this dimensionality reduction,
+// but simply knowing sizes is not enough.
+// We can extend profiling to also record stride information.
+// Rather than recording specific strides,
+// we simply order the strides from smallest to largest with `stride_indices`.
+// A contiguity marker on the smallest stride (c0) indicates that the stride is
+// precisely 1; otherwise, a contiguity marker means that
+// $stride_n = size_{n-1} * stride_{n-1}$.
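+//
+// Worked example (illustrative only, not part of the original comment): a
+// contiguous 2x3 tensor has sizes [2, 3] and strides [3, 1]. Ordering the
+// strides from smallest to largest gives stride_indices [1, 0]; the smallest
+// stride is exactly 1, so it carries the contiguity marker, and stride_0
+// satisfies stride_0 = size_1 * stride_1 = 3 * 1 = 3, so it is marked
+// contiguous as well.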
+struct TORCH_API Stride {
+  Stride() = default;
+  Stride(
+      const c10::optional& stride_index,
+      c10::optional contiguous,
+      const c10::optional& stride)
+      : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {}
+
+  bool operator==(const Stride& b) const {
+    return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ &&
+        stride_ == b.stride_;
+  }
+
+  bool isComplete() const {
+    return stride_index_ && contiguous_ && stride_;
+  }
+
+  c10::optional stride_index_;
+  c10::optional contiguous_;
+  c10::optional stride_;
+};
+
+template <>
+inline c10::optional merge_primitive(
+    const c10::optional& a,
+    const c10::optional& b) {
+  c10::optional left = a;
+  c10::optional right = b;
+  if (!left.has_value()) {
+    left = {Stride()};
+  }
+  if (!right.has_value()) {
+    right = {Stride()};
+  }
+
+  auto merged_index =
+      merge_primitive(left->stride_index_, right->stride_index_);
+  auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_);
+  auto merged_stride = merge_primitive(left->stride_, right->stride_);
+  auto r = Stride(merged_index, merged_cont, merged_stride);
+  // normalize
+  if (!r.stride_index_.has_value() && !r.contiguous_.has_value() &&
+      !r.stride_.has_value()) {
+    return c10::optional{};
+  }
+
+  return r;
+}
+
+struct TORCH_API ShapeSymbol {
+  // needed for use in `std::map`
+  ShapeSymbol() : value_(-1) {}
+  // is this symbol a fixed/static dimension
+  bool is_static() const {
+    return value_ >= 0;
+  };
+  bool operator==(const ShapeSymbol& b) const {
+    return value_ == b.value_;
+  }
+  bool operator<(const ShapeSymbol& b) const {
+    return value_ < b.value_;
+  }
+
+  static ShapeSymbol fromStaticSize(int64_t val) {
+    return ShapeSymbol(val);
+  }
+  int64_t static_size() const {
+    TORCH_CHECK(is_static());
+    return value_;
+  };
+
+  int64_t value() const {
+    return value_;
+  };
+
+  static ShapeSymbol newSymbol() {
+    return fromStaticSize(-static_cast(++num_symbols));
+  };
+  friend TORCH_API std::ostream& operator<<(
+      std::ostream& os,
+      const ShapeSymbol& s);
+
+ private:
+  ShapeSymbol(int64_t val) : value_(val) {}
+  int64_t value_;
+  static std::atomic num_symbols;
+};
+
+inline ShapeSymbol merge_primitive(
+    const ShapeSymbol& a,
+    const ShapeSymbol& b) {
+  if (a.is_static() && b.is_static() && a == b) {
+    return a;
+  }
+  return ShapeSymbol::newSymbol();
+}
+
+// Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown
+// dims, partially known and fully known shapes are all supported.
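+//
+// Construction sketch (illustrative only, using the constructors declared
+// below; not part of the original header):
+//   SymbolicShape unranked;                               // rank unknown
+//   SymbolicShape ranked(c10::optional<size_t>(3));       // rank 3, dims unknown
+//   std::vector<c10::optional<int64_t>> d = {2, c10::nullopt};
+//   SymbolicShape partial(d);                             // [2, *]
+//   SymbolicShape complete(c10::IntArrayRef({2, 3, 4}));  // [2, 3, 4]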
+struct TORCH_API SymbolicShape {
+  // Unranked shape constructor.
+  SymbolicShape() : dims_(c10::nullopt) {}
+
+  // Known rank but unknown dimensions.
+  SymbolicShape(c10::optional rank) : dims_(c10::nullopt) {
+    if(!rank) {
+      return;
+    }
+
+    std::vector shape_symbols;
+    shape_symbols.reserve(*rank);
+    for(size_t i = 0; i < *rank; ++i) {
+      shape_symbols.push_back(ShapeSymbol::newSymbol());
+    }
+    dims_ = shape_symbols;
+  }
+
+  // Mix of known and unknown dimensions.
+  SymbolicShape(const std::vector>& dims) {
+    std::vector shape_symbols;
+    shape_symbols.reserve(dims.size());
+    for(c10::optional dim: dims) {
+      if(!dim) {
+        shape_symbols.push_back(ShapeSymbol::newSymbol());
+      } else {
+        shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim));
+      }
+    }
+    dims_ = shape_symbols;
+  }
+
+  void dump() const;
+
+  SymbolicShape(std::vector dims) : dims_(std::move(dims)) {}
+
+  SymbolicShape(c10::IntArrayRef dims) {
+    std::vector shape_symbols;
+    shape_symbols.reserve(dims.size());
+    for(int64_t dim : dims) {
+      shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim));
+    }
+    dims_ = shape_symbols;
+  }
+
+  ShapeSymbol operator[](size_t i) const {
+    if (!dims_) {
+      throw std::runtime_error("Rank isn't fixed");
+    }
+    return (*dims_).at(i);
+  }
+
+  ShapeSymbol at(size_t i) const {
+    if (!dims_) {
+      throw std::runtime_error("Rank isn't fixed");
+    }
+    return (*dims_).at(i);
+  }
+
+  // Returns rank or nullopt in case of unranked shape.
+  c10::optional rank() const {
+    if(!dims_) {
+      return c10::nullopt;
+    }
+    return dims_->size();
+  }
+
+  c10::optional> sizes() const {
+    return dims_;
+  }
+
+  c10::optional> symbolicDims() const {
+    if (!dims_) {
+      return c10::nullopt;
+    }
+    auto symbolic_dims = std::vector();
+    for (const ShapeSymbol& s : *dims_) {
+      symbolic_dims.push_back(!s.is_static());
+    }
+    return symbolic_dims;
+  }
+
+  // Checks whether the shape is fully defined/complete, i.e. the rank and the
+  // size of every dimension are known.
+  bool isComplete() const {
+    if(!dims_) {
+      return false;
+    }
+    for(auto d : *dims_) {
+      if(!d.is_static()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Create a new SymbolicShape that is the result of merging self and another
+  // SymbolicShape. Only dimensions that are static and equal will be
+  // preserved.
+  // If either of the two shapes is of unknown rank, or their ranks do not
+  // match, the result will be unranked.
+  SymbolicShape merge(const SymbolicShape& other) const;
+
+  friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) {
+    return lhs.dims_ == rhs.dims_;
+  }
+
+  friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) {
+    return !(lhs == rhs);
+  }
+
+  private:
+    c10::optional> dims_;
+};
+
+namespace detail {
+inline bool isComplete(const Stride& s) {
+  return s.isComplete();
+}
+
+template
+inline bool isComplete(const T& /*t*/) {
+  return true;
+}
+}
+
+template 
+struct VaryingShape {
+  using ListOfOptionalElements = std::vector>;
+  VaryingShape(const std::vector& vec)
+      : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
+
+  VaryingShape(c10::ArrayRef vec)
+      : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {}
+
+  VaryingShape(c10::optional size = c10::nullopt) : dims_(c10::nullopt) {
+    if (size) {
+      dims_ = ListOfOptionalElements(*size);
+    }
+  }
+
+  VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {}
+
+  VaryingShape(size_t size) : VaryingShape(c10::optional(size)) {}
+
+  bool operator==(const VaryingShape& other) const {
+    return dims_ == other.dims_;
+  }
+
+  const c10::optional &operator[](size_t i) const {
+    if (!dims_) {
+      throw std::runtime_error("Rank isn't fixed");
+    }
+    return (*dims_).at(i);
+  }
+
+  c10::optional size() const {
+    if (!dims_) {
+      return c10::nullopt;
+    }
+    const auto& dims = dims_.value();
+    return dims.size();
+  }
+
+  const c10::optional& sizes() const {
+    return dims_;
+  }
+
+  TORCH_API VaryingShape merge(const VaryingShape& other) const;
+
+  c10::optional> concrete_sizes() const {
+    if (!dims_) {
+      return c10::nullopt;
+    }
+    std::vector sizes;
+    sizes.reserve(dims_.value().size());
+    for (auto d : *dims_) {
+      if (!d) {
+        return c10::nullopt;
+      }
+      sizes.push_back(d.value());
+    }
+    return sizes;
+  }
+
+  bool isComplete() const {
+    if (!dims_) {
+      return false;
+    }
+    for (auto d : *dims_) {
+      if (!d || !detail::isComplete(*d)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  c10::optional dims_;
+};
+
+struct TensorType;
+// TODO: investigate making this SingletonOrSharedTypePtr
+using TensorTypePtr = std::shared_ptr;
+// This type represents a single Tensor with a specific size
+struct TORCH_API TensorType : public SharedType {
+  static TensorTypePtr create(const at::Tensor& t);
+
+  // used by TensorType::create(size_t dim), which is in turn used by
+  // shape_analysis.cpp
+  static TensorTypePtr create(
+      c10::optional scalar_type,
+      c10::optional device,
+      const VaryingShape& sizes,
+      const VaryingShape& strides,
+      c10::optional requires_grad,
+      c10::optional undefined = false,
+      bool tensor_contiguity = false);
+
+  static TensorTypePtr create(
+      c10::optional scalar_type,
+      c10::optional device,
+      const SymbolicShape& sizes,
+      const VaryingShape& stride_,
+      c10::optional requires_grad,
+      c10::optional undefined = false);
+
+  static TensorTypePtr create(
+      c10::optional scalar_type,
+      c10::optional device,
+      c10::optional dim,
+      c10::optional requires_grad);
+
+  // overloaded create: the variadic template argument could not distinguish
+  // an initializer list
+  static TensorTypePtr createContiguous(
+      at::ScalarType scalar_type,
+      at::Device device,
+      at::IntArrayRef sizes);
+
+  static TypePtr fromNumberType(const Type& typ);
+  static TypePtr fromBoolType();
+
+  c10::optional dim() const {
+    return sizes().size();
+  }
+
+  VaryingShape sizes() const;
+
+  VaryingShape strides() const;
+
+  const VaryingShape& stride_properties() const {
+    return strides_;
+  }
+
+  c10::optional device() const {
+    return device_;
+  }
+  c10::optional scalarType() const {
+    return scalar_type_;
+  }
+  c10::optional requiresGrad() const {
+    return requires_grad_;
+  }
+  bool requires_grad() const override {
+    return requires_grad_ ? *requires_grad_ : true;
+  }
+
+  bool equals(const Type& rhs) const override;
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  std::string str() const override;
+
+  std::string repr_str() const override {
+    if (isInferredType()) {
+      return str() + " (inferred)";
+    } else {
+      return str();
+    }
+  }
+
+  c10::optional numel() const {
+    size_t prod = 1;
+    const auto& shape = sizes();
+
+    for (size_t i = 0; i < shape.size(); i++) {
+      if (!shape[i]) {
+        return c10::optional{};
+      }
+      prod *= shape[i].value();
+    }
+    return prod;
+  }
+
+  TensorTypePtr withRequiresGrad(c10::optional s) {
+    auto copy = clone();
+    copy->requires_grad_ = s;
+    return copy;
+  }
+
+  TensorTypePtr withScalarType(c10::optional st) {
+    auto copy = clone();
+    copy->scalar_type_ = st;
+    return copy;
+  }
+
+  TensorTypePtr withDim(c10::optional d) {
+    auto copy = clone();
+    // withDim is only used by the legacy executor,
+    // which only cares about the rank, so create dummy symbols.
+    copy->sizes_ = SymbolicShape(d);
+    copy->strides_ = VaryingShape(d);
+    return copy;
+  }
+
+  TensorTypePtr withStrides(VaryingShape sstrides) const {
+    auto cloned = clone();
+    cloned->strides_ = std::move(sstrides);
+    return cloned;
+  }
+
+  TensorTypePtr withSizesStrides(
+      at::IntArrayRef sizes,
+      at::IntArrayRef strides) const {
+    auto cloned = clone();
+    auto ssizes = SymbolicShape(sizes);
+    cloned->sizes_ = ssizes;
+    cloned->strides_ = computeStrideProps(sizes, strides);
+    return cloned;
+  }
+
+  TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const {
+    auto cloned = clone();
+    cloned->sizes_ = std::move(ssizes);
+    return cloned;
+  }
+
+  TensorTypePtr withSizes(at::IntArrayRef sizes) const {
+    return withSizesStrides(
+        sizes, contiguousStridesOf(sizes));
+  }
+
+  TensorTypePtr withDevice(const c10::optional device) const {
+    auto copy = clone();
+    copy->device_ = device;
+    return copy;
+  }
+
+  TensorTypePtr dimensionedOnly() const {
+    auto copy = clone();
+    copy->sizes_ = SymbolicShape(sizes().size());
+    copy->strides_ = VaryingShape(sizes().size());
+    return copy;
+  }
+
+  TensorTypePtr contiguous() const {
+    auto cloned = clone();
+    TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
+    auto strides = computeStrideProps(
+        *sizes().concrete_sizes(),
+        contiguousStridesOf(*sizes().concrete_sizes()));
+    cloned->strides_ = strides;
+    return cloned;
+  }
+
+  const SymbolicShape& symbolic_sizes() const;
+
+  TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const;
+
+  bool matchTensor(const at::Tensor& t);
+
+  // Is all information about the type specified except for autograd?
+  // This replaces the notion of a 'CompleteTensorType' that used to exist
+  // in the type hierarchy. Excluding requires_grad and undefined allows
+  // this to match the old behavior.
+  bool isComplete() const {
+    return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete();
+  }
+
+  bool isInferredType() const {
+    return is_inferred_;
+  }
+
+  static TensorTypePtr getInferred() {
+    static auto valueInferred = TensorType::create(
+        /*scalar_type=*/{},
+        /*device=*/{},
+        /*sizes=*/SymbolicShape(),
+        /*stride=*/VaryingShape{},
+        /*requires_grad=*/{},
+        /*undefined=*/false);
+    valueInferred->is_inferred_ = true;
+    return valueInferred;
+  }
+
+  // this property is used by GuardElimination
+  // please see `checkInputs` for more details
+  bool isSummarized() const {
+    return !(isComplete() && requiresGrad().has_value() &&
+             undefined().has_value());
+  }
+
+  TensorTypePtr withUndefined() {
+    auto r = clone();
+    r->undefined_ = true;
+    return r;
+  }
+
+  TensorTypePtr withPossiblyUndefined() {
+    auto r = clone();
+    r->undefined_ = c10::nullopt;
+    return r;
+  }
+
+  c10::optional undefined() const { return undefined_; }
+
+  static const TensorTypePtr& get();
+
+  static const TypeKind Kind = TypeKind::TensorType;
+
+  static std::vector contiguousStridesOf(
+      at::IntArrayRef in_sizes,
+      at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
+    auto contiguous_fn = [](const at::IntArrayRef& sizes,
+                            const std::vector& dim_order) {
+      std::vector strides(sizes.size());
+      if (sizes.empty()) // zero-dim case
+        return strides;
+
+      strides[dim_order[0]] = 1;
+      for (size_t i = 1; i < dim_order.size(); i++) {
+        auto cur_dim = dim_order[i];
+        auto pre_dim = dim_order[i - 1];
+        strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
+      }
+      return strides;
+    };
+
+    std::vector dim_order(in_sizes.size());
+    if (memory_format == MemoryFormat::ChannelsLast) {
+      dim_order = {1, 3, 2, 0};
+    } else if (memory_format == MemoryFormat::ChannelsLast3d) {
+      dim_order = {1, 4, 3, 2, 0};
+    } else {
+      auto ndims = in_sizes.size();
+      for (size_t i = 0; i < ndims; i++) {
+        dim_order[i] = static_cast(ndims - i - 1); // Reverse
+      }
+    }
+    return contiguous_fn(in_sizes, dim_order);
+  }
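+  // For instance (illustrative): with in_sizes = {1, 3, 4, 4} (NCHW), the
+  // default Contiguous path yields strides {48, 16, 4, 1}, while
+  // MemoryFormat::ChannelsLast walks the dims in the order {1, 3, 2, 0} and
+  // yields strides {48, 1, 12, 3}.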
+
+ private:
+  TensorType(
+      c10::optional scalar_type,
+      c10::optional device,
+      SymbolicShape sizes,
+      VaryingShape strides,
+      c10::optional requires_grad,
+      c10::optional undefined = false);
+
+  TensorTypePtr clone() const {
+    return TensorTypePtr(new TensorType(
+        scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_));
+  }
+
+  static VaryingShape computeStrideProps(
+      at::IntArrayRef sizes,
+      at::IntArrayRef strides,
+      bool tensor_contiguity = false);
+
+  c10::optional scalar_type_;
+  c10::optional device_;
+  SymbolicShape sizes_;
+  VaryingShape strides_;
+  c10::optional requires_grad_;
+  // We exploit the fact that certain tensors must be zero in the autograd to
+  // optimize gradient computation. Such zero tensors are currently implemented
+  // with `UndefinedTensorImpl`. They can be handled only by special operators
+  // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false.
+  // Normally, `undefined_` is set to false, unless a type was created
+  // with `withUndefined`.
+  // This also means that `undefined` tensors will fail the
+  // `subtypeOf(TensorType::get())` check.
+  // undefined_ may become `c10::nullopt` if the tensor was observed to be both
+  // defined and undefined. However, no tensor type starts out with
+  // `undefined_` set to `c10::nullopt`.
+  c10::optional undefined_;
+  // Represents whether or not this type was inferred.
+  bool is_inferred_ = false;
+};
+
+struct ListType;
+using ListTypePtr = std::shared_ptr;
+struct TORCH_API ListType
+    : public SingleElementType {
+  // It's not exactly a singleton, but there should be exactly one instance of
+  // List[T] for every T
+  friend struct Type;
+  template 
+  static ListTypePtr create(T&&... all) {
+    return ListTypePtr(
+        new ListType(std::forward(all)...)); // NOLINT(modernize-make-shared)
+  }
+
+  std::string str() const override {
+    std::stringstream ss;
+    ss << getElementType()->str() << "[]";
+    return ss.str();
+  }
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    return create(std::move(contained_types.at(0)));
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  // global singleton
+  // Given an inner type T and an identifier,
+  // this function will return the global singleton type pointer
+  // for the type List[T].
+  // The extra "identifier" argument is needed because we have multiple container types
+  // that all re-use this function (List, array, etc.)
+  static TypePtr get(const std::string& identifier, TypePtr inner);
+
+  // common cast List[Tensor]
+  static ListTypePtr ofTensors();
+  static ListTypePtr ofOptionalTensors();
+  static ListTypePtr ofInts();
+  static ListTypePtr ofSymInts();
+  static ListTypePtr ofFloats();
+  static ListTypePtr ofComplexDoubles();
+  static ListTypePtr ofBools();
+  static ListTypePtr ofStrings();
+  static ListTypePtr ofNumbers();
+
+ private:
+  ListType(TypePtr elem) : SingleElementType(std::move(elem)) {}
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    std::stringstream ss;
+    ss << "List[" << getElementType()->annotation_str(std::move(printer)) << "]";
+    return ss.str();
+  }
+};
+
+struct DictType;
+using DictTypePtr = std::shared_ptr;
+struct TORCH_API DictType : public SharedType {
+  friend struct Type;
+  static const TypeKind Kind = TypeKind::DictType;
+
+  static DictTypePtr create(TypePtr key, TypePtr value) {
+    auto kind = key->kind();
+    if (auto dyn = key->castRaw()) {
+      kind = dyn->dynamicKind();
+    }
+    switch (kind) {
+      case TypeKind::AnyType:
+      case TypeKind::IntType:
+      case TypeKind::BoolType:
+      case TypeKind::FloatType:
+      case TypeKind::ComplexType:
+      case TypeKind::StringType:
+      case TypeKind::TensorType:
+      case TypeKind::DeviceObjType:
+        return DictTypePtr(new DictType(std::move(key), std::move(value)));
+      default:
+        AT_ERROR(
+            "Cannot create dict for key type '",
+            key->str(),
+            "', only int, float, complex, Tensor, device and string keys are supported");
+    }
+  }
+
+  // aligned with the format in FunctionSchema
+  std::string str() const override {
+    std::stringstream ss;
+    ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
+       << ")";
+    return ss.str();
+  }
+
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    if (contained_types.size() != 2) {
+      throw std::runtime_error("Expected 2 contained types");
+    }
+    return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
+  }
+
+  const TypePtr& getKeyType() const {
+    return types.at(0);
+  }
+
+  const TypePtr& getValueType() const {
+    return types.at(1);
+  }
+
+  bool hasFreeVariables() const override {
+    return has_free_variables;
+  }
+
+  at::ArrayRef containedTypes() const override {
+    return types;
+  }
+
+  bool equals(const Type& rhs) const override {
+    if (auto* dict_rhs = rhs.castRaw()) {
+      return *getKeyType() == *(dict_rhs->getKeyType()) &&
+          *getValueType() == *(dict_rhs->getValueType());
+    }
+    return false;
+  }
+
+  // global singleton
+  // Given key and value types and an identifier,
+  // this function will return the global singleton type pointer
+  // for the type Dict[K, V].
+  // The extra "identifier" argument is needed because we have multiple container types
+  // that all re-use this function (Dict and unordered_map)
+  static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val);
+
+ private:
+  DictType(TypePtr key, TypePtr value)
+      : SharedType(TypeKind::DictType),
+        has_free_variables(
+            key->hasFreeVariables() || value->hasFreeVariables()) {
+    types.reserve(2);
+    types.push_back(std::move(key));
+    types.push_back(std::move(value));
+  }
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
+
+  std::vector types;
+  bool has_free_variables;
+};
+
+struct FutureType;
+using FutureTypePtr = std::shared_ptr;
+
+struct TORCH_API FutureType
+    : public SingleElementType {
+  friend struct Type;
+  template 
+  static FutureTypePtr create(TypePtr elem) {
+    return FutureTypePtr(
+        new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
+  }
+
+  std::string str() const override {
+    std::stringstream ss;
+    ss << "Future(" << getElementType()->str() << ")";
+    return ss.str();
+  }
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    return create(std::move(contained_types.at(0)));
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
+    if (Type::isSubtypeOfExt(rhs, why_not)) {
+      return true;
+    }
+    if (auto rhs_ = rhs.castRaw()) {
+      return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
+    }
+    return false;
+  }
+
+ private:
+  FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {}
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    std::stringstream ss;
+    ss << "Future[" << getElementType()->annotation_str(std::move(printer)) << "]";
+    return ss.str();
+  }
+};
+
+struct AwaitType;
+using AwaitTypePtr = std::shared_ptr;
+
+struct TORCH_API AwaitType
+    : public SingleElementType {
+  friend struct Type;
+  template 
+  static AwaitTypePtr create(TypePtr elem) {
+    return AwaitTypePtr(
+        new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared)
+  }
+
+  std::string str() const override {
+    std::stringstream ss;
+    ss << "Await(" << getElementType()->str() << ")";
+    return ss.str();
+  }
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    return create(std::move(contained_types.at(0)));
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
+    if (Type::isSubtypeOfExt(rhs, why_not)) {
+      return true;
+    }
+    if (auto rhs_ = rhs.castRaw()) {
+      return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not);
+    }
+    return false;
+  }
+
+ private:
+  AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {}
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    std::stringstream ss;
+    ss << "Await[" << getElementType()->annotation_str(printer) << "]";
+    return ss.str();
+  }
+};
+
+struct RRefType;
+using RRefTypePtr = std::shared_ptr;
+
+struct TORCH_API RRefType
+    : public SingleElementType {
+  friend struct Type;
+  template 
+  static RRefTypePtr create(TypePtr elem) {
+    return RRefTypePtr(
+        new RRefType(std::move(elem))); // NOLINT(modernize-make-shared)
+  }
+
+  std::string str() const override {
+    std::stringstream ss;
+    ss << "RRef(" << getElementType()->str() << ")";
+    return ss.str();
+  }
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    return create(std::move(contained_types.at(0)));
+  }
+
+ private:
+  RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {}
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    std::stringstream ss;
+    ss << "RRef[" << getElementType()->annotation_str(std::move(printer)) << "]";
+    return ss.str();
+  }
+};
+
+// Any should never appear in a named type like a class, namedtuple or
+// interface. If it does, then dynamic type information will be lost in the
+// Pickler, leading to hard-to-track-down bugs that will only occur
+// after saving or loading a model. This is because we rely on the
+// static types in named types to reconstruct type tags of loaded
+// values. Lifting this restriction requires solving the serialization
+// problem first.
+TORCH_API void checkNoAny(
+    const Type& base,
+    const char* what,
+    const std::string& attrname,
+    const TypePtr& attrtype);
+
+struct TupleType;
+using TupleTypePtr = std::shared_ptr;
+using NameList = std::vector;
+// This type represents a Tuple
+struct TORCH_API TupleType : public NamedType {
+
+  static TupleTypePtr createNamed(const c10::optional& name,
+      const std::vector& field_names,
+      const std::vector& field_types,
+      std::vector& field_defaults);
+
+  static TupleTypePtr createNamed(const c10::optional& name,
+      const std::vector& field_names,
+      const std::vector& field_types);
+
+  static TupleTypePtr createNamed(const c10::optional& name,
+      const std::vector& field_names,
+      const std::vector& field_types);
+
+  static TupleTypePtr create(
+      std::vector types) {
+    return TupleTypePtr(new TupleType(
+        std::move(types),
+        c10::nullopt,
+        nullptr)); // NOLINT(modernize-make-shared)
+  }
+  static TupleTypePtr create() {
+    return create({});
+  }
+
+  at::ArrayRef elements() const {
+    return elements_;
+  }
+
+  bool equals(const Type& rhs) const override;
+  bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;
+
+  std::string str() const override;
+  bool hasFreeVariables() const override {
+    return has_free_variables_;
+  }
+  at::ArrayRef containedTypes() const override {
+    return elements_;
+  }
+  TypePtr createWithContained(
+      std::vector contained_types) const override {
+    return std::shared_ptr(
+        new TupleType(std::move(contained_types), name(), schema()));
+  }
+  const std::shared_ptr& schema() const {
+    return schema_;
+  }
+  c10::optional> names() const;
+
+  static const TypeKind Kind = TypeKind::TupleType;
+
+ private:
+  template 
+  static TupleTypePtr createWithSpec(
+      const c10::optional& name,
+      const std::vector& field_names,
+      const std::vector& field_types,
+      std::vector& field_defaults);
+
+  TupleType(
+      std::vector elements_,
+      c10::optional name,
+      std::shared_ptr schema);
+
+  bool compare(
+      const Type& rhs,
+      const std::function& fn) const {
+    if (rhs.kind() != kind()) {
+      return false;
+    }
+
+    const auto& l_elements = elements();
+    const auto& r_elements = rhs.castRaw()->elements();
+    if (l_elements.size() != r_elements.size())
+      return false;
+    for (size_t i = 0; i < l_elements.size(); ++i) {
+      if (!fn(*l_elements[i], *r_elements[i]))
+        return false;
+    }
+    return true;
+  }
+
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
+
+  std::vector elements_;
+  bool has_free_variables_;
+  std::shared_ptr schema_;
+};
+
+// The common supertype of all Enums, only used in operator registration.
+// EnumType <: AnyEnumType for all Enums
+struct AnyEnumType;
+using AnyEnumTypePtr = SingletonTypePtr;
+struct TORCH_API AnyEnumType final : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "AnyEnumType";
+  }
+  static const TypeKind Kind = TypeKind::AnyEnumType;
+  // global singleton
+  static AnyEnumTypePtr get();
+private:
+  AnyEnumType()
+  : Type(TypeKind::AnyEnumType) {}
+};
+
+struct NumberType;
+using NumberTypePtr = SingletonTypePtr;
+// This type represents a Python number
+// Subtype hierarchy for Number Types (NumberType as the base type):
+// IntType <: NumberType
+// FloatType <: NumberType
+// ComplexType <: NumberType
+//
+// WARNING: if you add a new subtype of NumberType that is not
+// represented by a global singleton, you need to change NumberTypePtr
+// to a SingletonOrSharedTypePtr and deal with NumberType needing to
+// both inherit and not inherit from SharedType!
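+//
+// For example (illustrative): IntType::get()->isSubtypeOf(*NumberType::get())
+// holds, so an `int` value can be passed where a `Scalar` is expected, while
+// the reverse check NumberType::get()->isSubtypeOf(*IntType::get()) does not.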
+struct TORCH_API NumberType : public Type {
+  bool equals(const Type& rhs) const override;
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  std::string str() const override {
+    return "Scalar"; // match what PythonArgParser says for clarity
+  }
+  static const TypeKind Kind = TypeKind::NumberType;
+  // global singleton
+  static NumberTypePtr get();
+
+ protected:
+  NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {}
+
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "number"; // technically not a valid python type, but
+                     // we need to use it when parsing back in annotations
+                     // for implicit conversions
+  }
+};
+
+struct FloatType;
+using FloatTypePtr = SingletonTypePtr;
+// This type represents a Python float number
+struct TORCH_API FloatType : public NumberType {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "float";
+  }
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
+    // NOLINTNEXTLINE(bugprone-parent-virtual-call)
+    return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
+  }
+  static const TypeKind Kind = TypeKind::FloatType;
+  // global singleton
+  static FloatTypePtr get();
+
+ private:
+  FloatType() : NumberType(TypeKind::FloatType) {}
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "float";
+  }
+};
+
+struct ComplexType;
+using ComplexTypePtr = SingletonTypePtr;
+// This type represents a Python complex number
+struct TORCH_API ComplexType : public NumberType {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "complex";
+  }
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
+    // NOLINTNEXTLINE(bugprone-parent-virtual-call)
+    return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
+  }
+  static const TypeKind Kind = TypeKind::ComplexType;
+  // global singleton
+  static ComplexTypePtr get();
+
+ private:
+  ComplexType() : NumberType(TypeKind::ComplexType) {}
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "complex";
+  }
+};
+
+// We need to introduce `SymIntType` to represent the `SymInt` type
+// used in function schemas, e.g. `aten::narrow_copy(... SymInt length)`.
+// `SymInt` will be used to enable tracing arithmetic operations on
+// dimension values. Please see [SymInt.h] for more information.
+struct SymIntType;
+using SymIntTypePtr = SingletonTypePtr;
+struct TORCH_API SymIntType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "SymInt";
+  }
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    return "int";
+  }
+  static const TypeKind Kind = TypeKind::SymIntType;
+  // global singleton
+  static SymIntTypePtr get();
+
+ private:
+  SymIntType() : Type(TypeKind::SymIntType) {}
+};
+
+struct SymFloatType;
+using SymFloatTypePtr = SingletonTypePtr;
+struct TORCH_API SymFloatType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "SymFloat";
+  }
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    return "float";
+  }
+  static const TypeKind Kind = TypeKind::SymFloatType;
+  // global singleton
+  static SymFloatTypePtr get();
+
+ private:
+  SymFloatType() : Type(TypeKind::SymFloatType) {}
+};
+
+struct SymBoolType;
+using SymBoolTypePtr = SingletonTypePtr;
+struct TORCH_API SymBoolType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "SymBool";
+  }
+  std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
+    return "bool";
+  }
+  static const TypeKind Kind = TypeKind::SymBoolType;
+  // global singleton
+  static SymBoolTypePtr get();
+
+ private:
+  SymBoolType() : Type(TypeKind::SymBoolType) {}
+};
+
+struct IntType;
+using IntTypePtr = SingletonTypePtr;
+// This type represents a Python int number
+struct TORCH_API IntType : public NumberType {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "int";
+  }
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override {
+    // NOLINTNEXTLINE(bugprone-parent-virtual-call)
+    return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
+  }
+  static const TypeKind Kind = TypeKind::IntType;
+  // global singleton
+  static IntTypePtr get();
+
+ private:
+  IntType() : NumberType(TypeKind::IntType) {}
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "int";
+  }
+};
+
+struct BoolType;
+using BoolTypePtr = SingletonTypePtr;
+// This type represents a Python bool value
+struct TORCH_API BoolType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "bool";
+  }
+  static const TypeKind Kind = TypeKind::BoolType;
+  // global singleton
+  static BoolTypePtr get();
+
+ private:
+  BoolType() : Type(TypeKind::BoolType) {}
+};
+
+struct StringType;
+using StringTypePtr = SingletonTypePtr;
+// This type represents a Python string
+struct TORCH_API StringType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    // we only use "str" (not "string") in both FunctionSchema and script
+    return annotation_str();
+  }
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "str";
+  }
+  static const TypeKind Kind = TypeKind::StringType;
+  // global singleton
+  static StringTypePtr get();
+
+ private:
+  StringType() : Type(TypeKind::StringType) {}
+};
+
+struct StorageType;
+using StorageTypePtr = SingletonTypePtr;
+struct TORCH_API StorageType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return annotation_str();
+  }
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return "Storage";
+  }
+  static const TypeKind Kind = TypeKind::StorageType;
+  // global singleton
+  static StorageTypePtr get();
+
+ private:
+  StorageType() : Type(TypeKind::StorageType) {}
+};
+
+struct FunctionType;
+using FunctionTypePtr = std::shared_ptr;
+struct TORCH_API FunctionType : public NamedType {
+  static FunctionTypePtr create(torch::jit::Function* function) {
+    return FunctionTypePtr(
+        new FunctionType(function)); // NOLINT(modernize-make-shared)
+  }
+  bool equals(const Type& rhs) const override {
+    if (auto func_type = rhs.cast()) {
+      return func_type->function_ == function_;
+    }
+
+    return false;
+  }
+  std::string str() const override {
+    return "Function";
+  }
+  torch::jit::Function* function() const {
+    return function_;
+  }
+  static const TypeKind Kind = TypeKind::FunctionType;
+
+ private:
+  FunctionType(torch::jit::Function* function);
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    const auto& n = name().value();
+    return n.qualifiedName();
+  }
+  torch::jit::Function* function_;
+};
+
+struct NoneType;
+using NoneTypePtr = SingletonTypePtr;
+// This type represents a Python None
+struct TORCH_API NoneType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "NoneType";
+  }
+  bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override;
+
+  static const TypeKind Kind = TypeKind::NoneType;
+  // global singleton
+  static NoneTypePtr get();
+
+ private:
+  NoneType() : Type(TypeKind::NoneType) {}
+};
+
+struct GeneratorType;
+using GeneratorTypePtr = SingletonTypePtr;
+// This type represents a Generator
+struct TORCH_API GeneratorType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Generator";
+  }
+  static const TypeKind Kind = TypeKind::GeneratorType;
+  // global singleton
+  static GeneratorTypePtr get();
+
+ private:
+  GeneratorType() : Type(TypeKind::GeneratorType) {}
+};
+
+struct QuantizerType;
+using QuantizerTypePtr = SingletonTypePtr;
+// This type represents a Quantizer
+struct TORCH_API QuantizerType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Quantizer";
+  }
+  static const TypeKind Kind = TypeKind::QuantizerType;
+  // global singleton
+  static QuantizerTypePtr get();
+
+ private:
+  QuantizerType() : Type(TypeKind::QuantizerType) {}
+};
+
+struct QSchemeType;
+using QSchemeTypePtr = SingletonTypePtr;
+// This type represents a QScheme
+struct TORCH_API QSchemeType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "QScheme";
+  }
+  static const TypeKind Kind = TypeKind::QSchemeType;
+  // global singleton
+  static QSchemeTypePtr get();
+
+ private:
+  QSchemeType() : Type(TypeKind::QSchemeType) {}
+};
+
+struct DeviceObjType;
+using DeviceObjTypePtr = SingletonTypePtr;
+// This type represents a Device
+struct TORCH_API DeviceObjType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Device";
+  }
+  static const TypeKind Kind = TypeKind::DeviceObjType;
+  // global singleton
+  static DeviceObjTypePtr get();
+
+ private:
+  DeviceObjType() : Type(TypeKind::DeviceObjType) {}
+};
+
+struct StreamObjType;
+using StreamObjTypePtr = SingletonTypePtr;
+// This type represents a Stream
+struct TORCH_API StreamObjType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Stream";
+  }
+  static const TypeKind Kind = TypeKind::StreamObjType;
+  // global singleton
+  static StreamObjTypePtr get();
+
+private:
+  StreamObjType() : Type(TypeKind::StreamObjType) {}
+};
+
+struct VarType;
+using VarTypePtr = std::shared_ptr;
+// This type represents a type variable, used in FunctionSchema
+struct VarType : public SharedType {
+  static VarTypePtr create(std::string name_) {
+    return VarTypePtr(new VarType(std::move(name_)));
+  }
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return name();
+  }
+  const std::string& name() const {
+    return name_;
+  }
+  bool hasFreeVariables() const override {
+    return true;
+  }
+  static const TypeKind Kind = TypeKind::VarType;
+
+ private:
+  VarType(std::string name_)
+      : SharedType(TypeKind::VarType), name_(std::move(name_)) {}
+  std::string name_;
+};
+
+struct CapsuleType;
+using CapsuleTypePtr = SingletonTypePtr;
+// This type represents a Python Capsule.
+// It does not appear in the IR and is only used during runtime
+struct TORCH_API CapsuleType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "Capsule";
+  }
+  static const TypeKind Kind = TypeKind::CapsuleType;
+  // global singleton
+  static CapsuleTypePtr get();
+private:
+  CapsuleType()
+  : Type(TypeKind::CapsuleType) {}
+};
+
+struct PyObjectType;
+using PyObjectTypePtr = SingletonTypePtr;
+// This type represents a PyObject Type
+struct TORCH_API PyObjectType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "PyObject";
+  }
+  static const TypeKind Kind = TypeKind::PyObjectType;
+  // global singleton
+  static PyObjectTypePtr get();
+private:
+  PyObjectType()
+  : Type(TypeKind::PyObjectType) {}
+};
+
+enum class TypeVerbosity {
+  None,
+  Type,
+  TypeAndStride,
+  Full,
+  Symbolic,
+  Default = Full,
+};
+
+TORCH_API TypeVerbosity type_verbosity();
+
+TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t);
+template 
+TORCH_API std::ostream& operator<<(
+    std::ostream& out,
+    const VaryingShape& t);
+TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s);
+TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s);
+TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
+// what is the type, ignoring extra size/shape information?
+// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
+
+// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
+// subtypes as simply "Tensor"; we also create a new version of any
+// container types in which internal Tensors have undergone the same
+// operation. This is used for type comparisons between two Tensor types
+// (`unshapedType` means that we don't falsely return `false` for e.g.
+// Tensors of different dimensions). It's also used in the alias
+// analysis pass.
+// Be careful with calls because this can be very slow. If calling this
+// on a graph, use `EraseShapeInformation` in shape_analysis.h
+inline TypePtr unshapedType(const TypePtr& type) {
+  if (type->isSubtypeOf(*TensorType::get())) {
+    return TensorType::get();
+  }
+  at::ArrayRef contained = type->containedTypes();
+  if (contained.empty()) {
+    return type;
+  }
+  return type->withContained(fmap(type->containedTypes(), unshapedType));
+}
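+
+// Example (illustrative sketch, not part of the original header):
+//   TypePtr shaped = ListType::create(
+//       TensorType::createContiguous(at::kFloat, at::kCPU, {2, 3}));
+//   TypePtr unshaped = unshapedType(shaped); // List[Tensor], shape erased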
+
+inline TypePtr TensorType::fromNumberType(const Type& typ) {
+  if (typ.isSubtypeOf(*IntType::get())) {
+    return TensorType::createContiguous(at::kLong, at::kCPU, {});
+  } else if (typ.isSubtypeOf(*FloatType::get())) {
+    return TensorType::createContiguous(at::kDouble, at::kCPU, {});
+  } else if (typ.isSubtypeOf(*BoolType::get())) {
+    return TensorType::createContiguous(at::kBool, at::kCPU, {});
+  } else if (typ.kind() == NumberType::Kind) {
+    return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt);
+  }
+  TORCH_CHECK(false, "Unknown number type: ", typ.str());
+}
+inline TypePtr TensorType::fromBoolType() {
+  return TensorType::createContiguous(at::kBool, at::kCPU, {});
+}
+
+inline c10::optional tryScalarTypeFromJitType(const Type& type) {
+  if (type == *FloatType::get()) {
+    return at::typeMetaToScalarType(c10::get_default_dtype());
+  } else if (type == *IntType::get()) {
+    return at::ScalarType::Long;
+  } else if (type == *BoolType::get()) {
+    return at::ScalarType::Bool;
+  }
+  return c10::nullopt;
+}
+
+inline at::ScalarType scalarTypeFromJitType(const Type& type) {
+  auto result = tryScalarTypeFromJitType(type);
+  TORCH_CHECK(
+      result,
+      "Add new condition, expected Float, Complex, Int, or Bool but got",
+      type.str());
+  return *result;
+}
+
+// Attempt to find the correct supertype of the two types `t1` and `t2`.
+// If no supertype is found, then nullopt will be returned if
+// `default_to_union` is false, and `Union[t1, t2]` will be returned
+// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
+// then `t2` will be returned (and vice versa).
+//
+// Two different tensor types will return Dynamic.
+//
+// Currently we choose not to support returning a NumberType for
+// two types from the set of {FloatType, IntType, ComplexType}, because
+// there is a lack of operator support for NumberType.
+//
+// If `type_hint` is an `InterfaceType`, then we can use that as a
+// potential supertype for `ClassType`s in the list. Otherwise, we have
+// no way to find and use some common interface type.
+TORCH_API c10::optional unifyTypes(
+    const TypePtr& t1,
+    const TypePtr& t2,
+    bool default_to_union = false,
+    const TypePtr& type_hint = nullptr);
+
+TORCH_API c10::optional unifyTypeList(
+    at::ArrayRef elements,
+    std::ostream& why_not,
+    bool default_to_union = false,
+    const TypePtr& type_hint = nullptr);
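+
+// Example (illustrative): unifying `int` with `Optional[int]` yields
+// `Optional[int]`, since `int` is a refinement of `Optional[int]`; when no
+// supertype exists, the result is c10::nullopt, or `Union[t1, t2]` when
+// `default_to_union` is true.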
+
+namespace detail {
+template 
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return ([]() {
+      try {
+        return getCustomClassType();
+      } catch(const c10::Error&) {
+        TORCH_CHECK(
+            false,
+            "Type ",
+            c10::util::get_fully_qualified_type_name(),
+            " could not be converted to any of the known types."
+        );
+      }
+    }());
+  }
+};
+
+template 
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return getTypePtr_::call();
+  }
+};
+
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return AnyType::get();
+  }
+};
+
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return TensorType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return StorageType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return StreamObjType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return FloatType::get();
+  }
+};
+template <>
+struct getTypePtr_> final {
+  static decltype(auto) call() {
+    return ComplexType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return SymIntType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return SymFloatType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return FloatType::get();
+  }
+};
+
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return SymBoolType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return BoolType::get();
+  }
+};
+
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return DeviceObjType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return BoolType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return NumberType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return QSchemeType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return TypeFactory::create(
+        TypeFactory::get());
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return StringType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return StringType::get();
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return StringType::get();
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per vector" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = ListType::get("vector", inner_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per ArrayRef" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = ListType::get("ArrayRef", inner_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_ final {
+  static const auto& call() {
+    static auto type = ListType::create(getMaybeFakeTypePtr_::call());
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per List" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = ListType::get("List", inner_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    static auto type = ListType::get("List", inner_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per array" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    // (Concatenating the length onto the end of the string because we want a unique
+    // type_ptr created for every std::array type).
+    static auto type = ListType::get(std::string("array") + std::to_string(N), inner_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_key_type = getMaybeFakeTypePtr_::call();
+    static auto inner_val_type = getMaybeFakeTypePtr_::call();
+    // The "per unordered_map" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type);
+    return type;
+  }
+};
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_key_type = getMaybeFakeTypePtr_::call();
+    static auto inner_val_type = getMaybeFakeTypePtr_::call();
+    // The "per Dict" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = DictType::get("Dict", inner_key_type, inner_val_type);
+    return type;
+  }
+};
+
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per optional" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = OptionalType::get(inner_type);
+    return type;
+  }
+};
+
+
+template<>
+struct getTypePtr_ final {
+  static const auto& call() {
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    // The "per optional" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto type = OptionalType::get(inner_type);
+    return type;
+  }
+};
+
+template 
+struct getMaybeFakeTypePtr_ final {
+  static const auto& call() {
+    // The "per optional" static singleton needs to live in a .cpp file,
+    // otherwise we'll end up with one singleton instance per shared library.
+    static auto inner_type = getMaybeFakeTypePtr_::call();
+    static auto type = OptionalType::get(inner_type);
+    return type;
+  }
+};
+
+template 
+struct getMaybeFakeTypePtr_, fake> final {
+  static const auto& call() {
+    static auto type = ([]() {
+      std::vector contained_types = {
+        (getMaybeFakeTypePtr_::call())...
+      };
+      return TupleType::create(std::move(contained_types));
+    })();
+    return type;
+  }
+};
+template <>
+struct getTypePtr_ final {
+  static decltype(auto) call() {
+    return NoneType::get();
+  }
+};
+} // namespace detail
+template 
+inline decltype(auto) getTypePtr() {
+  // TODO: static_assert that a templated function exists, and throw a friendly
+  // error message if not
+  return detail::getMaybeFakeTypePtr_::call();
+}
+
+template 
+inline TypePtr getTypePtrCopy() {
+  // TODO: static_assert that a templated function exists, and throw a friendly
+  // error message if not
+  return getTypePtr();
+}
+
+template 
+inline decltype(auto) getFakeTypePtr() {
+  return detail::getMaybeFakeTypePtr_::call();
+}
+
+template 
+inline TypePtr getFakeTypePtrCopy() {
+  return getFakeTypePtr();
+}
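+
+// Usage sketch (illustrative, assuming the specializations above):
+//   auto t1 = c10::getTypePtr<int64_t>();               // int
+//   auto t2 = c10::getTypePtr<c10::List<at::Tensor>>(); // List[Tensor]
+//   auto t3 = c10::getTypePtr<c10::optional<double>>(); // Optional[float]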
+
+using TypeEnv = std::unordered_map;
+struct MatchTypeReturn {
+  MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {}
+  static MatchTypeReturn Success() {
+    return MatchTypeReturn();
+  }
+  bool success() const {
+    return !reason_.has_value();
+  }
+  const std::string& reason() const {
+    return reason_.value();
+  }
+
+ private:
+  MatchTypeReturn()
+  : reason_(c10::nullopt) {}
+  c10::optional<std::string> reason_; // if there is no match, this contains the reason
+};
+
+// attempt to match the type variables in formal to actual, adding them to type_env.
+// If no match is possible this returns a MatchTypeReturn with r.success() == false
+// and a r.reason() that describes why it could not match.
+// note: It is possible to successfully match a formal, but for type variables
+// in the formal to still not be defined. In particular, None matches Optional[T]
+// but does not define the value of T.
+TORCH_API MatchTypeReturn
+matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env);
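+
+// Example (illustrative): matching the formal `List[t]` against the actual
+// `List[int]` succeeds and binds t -> int in `type_env`; matching `List[t]`
+// against `str` fails, and the returned reason() describes the mismatch.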
+
+// replace type variables appearing in `type` with the values in
+// `type_env`. Returns nullptr if a variable used in `type`
+// does not appear in `type_env`
+TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env);
+
+TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
+
+struct InterfaceType;
+using InterfaceTypePtr = std::shared_ptr;
+
+// Interfaces are a list of abstract methods that a class might meet.
+// If a class provides those methods, it implicitly meets the interface.
+
+// Subtype relations for Interface with ClassType:
+// lhs (ClassType or InterfaceType) is a subtype of rhs if:
+// 1. lhs methods are a superset of rhs methods
+// 2. if rhs is a module interface, lhs must be a module interface or a module itself
+struct TORCH_API InterfaceType : public NamedType {
+  static InterfaceTypePtr create(
+      QualifiedName qualifiedName, bool is_module=false);
+
+  bool equals(const Type& rhs) const override {
+    if (auto user_rhs = rhs.castRaw()) {
+      return isSubTypeImpl(*this, *user_rhs, nullptr) &&
+          isSubTypeImpl(*user_rhs, *this, nullptr);
+    }
+    return false;
+  }
+
+  std::string str() const override {
+    return std::string("InterfaceType<") + name()->name() + ">";
+  }
+
+  bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
+
+  // try to find a method of this interface,
+  // returns nullptr if not found.
+  const FunctionSchema* getMethod(const std::string& name) const;
+  void addMethod(FunctionSchema schema);
+  const std::vector<FunctionSchema>& methods() const {
+    return *methods_;
+  }
+
+  bool is_module() const override {
+    return is_module_;
+  }
+  static const TypeKind Kind = TypeKind::InterfaceType;
+  ~InterfaceType() override;
+ private:
+  InterfaceType(QualifiedName name, bool is_module);
+  static bool isSubTypeImpl(
+      const InterfaceType& lhs,
+      const InterfaceType& rhs,
+      std::ostream* why_not);
+
+  std::string annotation_str_impl(C10_UNUSED TypePrinter printer = nullptr) const override {
+    return name()->qualifiedName();
+  }
+
+  // shared_ptr so that this header does not have to depend on
+  // FunctionSchema.h
+  std::shared_ptr<std::vector<FunctionSchema>> methods_;
+  // flag to distinguish if it's an interface type from a module or not
+  bool is_module_;
+};
+
+template 
+struct EnumerationType : public Type {
+static const TypeKind Kind = K;
+
+bool equals(const Type& rhs) const override {
+  return rhs.kind() == kind();
+}
+
+protected:
+EnumerationType() : Type(Kind) {}
+};
+
+// WARNING: These enumeration types below DO NOT actually get parsed out
+// from the logical schema strings; instead they are mapped as ints. To
+// observe these types, use real_type() instead of type() on Argument.
+
+struct ScalarTypeType;
+using ScalarTypeTypePtr = SingletonTypePtr;
+struct TORCH_API ScalarTypeType : public EnumerationType {
+std::string str() const override {
+return "ScalarType";
+}
+static const TypeKind Kind = TypeKind::ScalarTypeType;
+// global singleton
+static ScalarTypeTypePtr get();
+
+private:
+ScalarTypeType() : EnumerationType() {}
+};
+
+struct MemoryFormatType;
+using MemoryFormatTypePtr = SingletonTypePtr;
+struct TORCH_API MemoryFormatType : public EnumerationType {
+std::string str() const override {
+return "MemoryFormat";
+}
+static const TypeKind Kind = TypeKind::MemoryFormatType;
+// global singleton
+static MemoryFormatTypePtr get();
+
+private:
+MemoryFormatType() : EnumerationType() {}
+};
+
+struct LayoutType;
+using LayoutTypePtr = SingletonTypePtr;
+struct TORCH_API LayoutType : public EnumerationType {
+std::string str() const override {
+return "Layout";
+}
+static const TypeKind Kind = TypeKind::LayoutType;
+// global singleton
+static LayoutTypePtr get();
+
+private:
+LayoutType() : EnumerationType() {}
+};
+
+namespace detail {
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return ScalarTypeType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return LayoutType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return MemoryFormatType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+template <>
+struct getMaybeFakeTypePtr_ final {
+  static decltype(auto) call() {
+    return IntType::get();
+  }
+};
+} // namespace detail
+
+// the common supertype of all lists,
+// List[T] <: AnyList for all T
+struct AnyListType;
+using AnyListTypePtr = SingletonTypePtr;
+struct TORCH_API AnyListType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "list";
+  }
+  static const TypeKind Kind = TypeKind::AnyListType;
+  // global singleton
+  static AnyListTypePtr get();
+private:
+  AnyListType()
+  : Type(TypeKind::AnyListType) {}
+};
+
+// the common supertype of all tuples,
+// Tuple[T...] <: AnyTuple for all T
+struct AnyTupleType;
+using AnyTupleTypePtr = SingletonTypePtr;
+struct TORCH_API AnyTupleType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+
+  std::string str() const override {
+    return "tuple";
+  }
+  static const TypeKind Kind = TypeKind::AnyTupleType;
+
+  // global singleton
+  static AnyTupleTypePtr get();
+private:
+  AnyTupleType()
+  : Type(TypeKind::AnyTupleType) {}
+};
+
+// the common supertype of all classes,
+// ClassType <: AnyClassType for all classes
+struct AnyClassType;
+using AnyClassTypePtr = SingletonTypePtr;
+struct TORCH_API AnyClassType : public Type {
+  bool equals(const Type& rhs) const override {
+    return rhs.kind() == kind();
+  }
+  std::string str() const override {
+    return "AnyClassType";
+  }
+  static const TypeKind Kind = TypeKind::AnyClassType;
+  // global singleton
+  static AnyClassTypePtr get();
+private:
+  AnyClassType()
+  : Type(TypeKind::AnyClassType) {}
+};
+
+template<>
+inline typename detail::CastReturnType::type Type::cast() {
+  if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
+      kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
+    return std::static_pointer_cast(static_cast(this)->shared_from_this());
+  }
+  return nullptr;
+}
+
+template<>
+inline typename detail::CastConstReturnType::type Type::cast() const {
+  if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
+      kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
+    return std::static_pointer_cast(static_cast(this)->shared_from_this());
+  }
+  return nullptr;
+}
+
+template<>
+inline const NamedType* Type::castRaw() const {
+  if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
+      kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
+    return static_cast(this);
+  }
+  return nullptr;
+}
+
+// Used as a return type when inferring the IValue type of a Python object.
+struct InferredType {
+  /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {}
+  /* implicit */ InferredType(std::string reason)
+      : type_(nullptr), reason_(std::move(reason)) {}
+  TypePtr type() const {
+    TORCH_INTERNAL_ASSERT(
+        type_,
+        "Tried to get the type from an InferredType but the type is null. ",
+        "Reason: ",
+        reason_);
+    return type_;
+  }
+  bool success() const {
+    return type_ != nullptr;
+  }
+  const std::string& reason() const {
+    TORCH_INTERNAL_ASSERT(!type_);
+    return reason_;
+  }
+
+private:
+  TypePtr type_;
+  std::string reason_;
+};
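+// Illustrative usage sketch (`inferTypeFrom` is a hypothetical caller-side
+// helper, shown only to demonstrate the success/failure protocol):
+//   InferredType inferred = inferTypeFrom(obj);
+//   if (!inferred.success()) {
+//     TORCH_CHECK(false, "could not infer IValue type: ", inferred.reason());
+//   }
+//   TypePtr t = inferred.type();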
+
+TORCH_API bool containsAnyType(const TypePtr& type);
+
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type_base.h b/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..73f153ef523e0b54310aac0159e4677971075712
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/jit_type_base.h
@@ -0,0 +1,719 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace c10 {
+
+#define C10_FORALL_TYPES(_) \
+  _(AnyType)                \
+  _(EnumType)               \
+  _(AnyEnumType)            \
+  _(TensorType)             \
+  _(StorageType)            \
+  _(TupleType)              \
+  _(ListType)               \
+  _(DictType)               \
+  _(NumberType)             \
+  _(FloatType)              \
+  _(ComplexType)            \
+  _(FutureType)             \
+  _(AwaitType)              \
+  _(RRefType)               \
+  _(IntType)                \
+  _(NoneType)               \
+  _(StringType)             \
+  _(GeneratorType)          \
+  _(QuantizerType)          \
+  _(BoolType)               \
+  _(OptionalType)           \
+  _(VarType)                \
+  _(DeviceObjType)          \
+  _(StreamObjType)          \
+  _(FunctionType)           \
+  _(ClassType)              \
+  _(PyObjectType)           \
+  _(CapsuleType)            \
+  _(InterfaceType)          \
+  _(QSchemeType)            \
+  _(ScalarTypeType)         \
+  _(LayoutType)             \
+  _(MemoryFormatType)       \
+  _(AnyListType)            \
+  _(AnyTupleType)           \
+  _(AnyClassType)           \
+  _(SymIntType)             \
+  _(SymFloatType)           \
+  _(SymBoolType)            \
+  _(UnionType)              \
+  _(DynamicType)
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+  C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
+TORCH_API const char* typeKindToString(TypeKind kind);
+
+struct Type;
+struct SharedType;
+
+// Use this to customize how a Type is printed using `annotation_str()`. If
+// c10::nullopt is returned, `annotation_str()` falls through to its default
+// implementation.
+using TypePrinter = std::function(const Type&)>;
+
+namespace detail {
+template 
+struct IsSingletonType : public std::integral_constant {};
+} // namespace detail
+#define TORCH_DECLARE_SINGLETON(Type) \
+  struct Type;                                                          \
+  namespace detail { \
+  template <> struct IsSingletonType : public std::integral_constant {}; \
+  }
+
+TORCH_DECLARE_SINGLETON(AnyType);
+TORCH_DECLARE_SINGLETON(AnyEnumType);
+TORCH_DECLARE_SINGLETON(NumberType);
+TORCH_DECLARE_SINGLETON(FloatType);
+TORCH_DECLARE_SINGLETON(ComplexType);
+TORCH_DECLARE_SINGLETON(IntType);
+TORCH_DECLARE_SINGLETON(BoolType);
+TORCH_DECLARE_SINGLETON(StringType);
+TORCH_DECLARE_SINGLETON(StorageType);
+TORCH_DECLARE_SINGLETON(NoneType);
+TORCH_DECLARE_SINGLETON(GeneratorType);
+TORCH_DECLARE_SINGLETON(QuantizerType);
+TORCH_DECLARE_SINGLETON(QSchemeType);
+TORCH_DECLARE_SINGLETON(DeviceObjType);
+TORCH_DECLARE_SINGLETON(StreamObjType);
+TORCH_DECLARE_SINGLETON(CapsuleType);
+TORCH_DECLARE_SINGLETON(PyObjectType);
+TORCH_DECLARE_SINGLETON(ScalarTypeType);
+TORCH_DECLARE_SINGLETON(LayoutType);
+TORCH_DECLARE_SINGLETON(MemoryFormatType);
+TORCH_DECLARE_SINGLETON(AnyListType);
+TORCH_DECLARE_SINGLETON(AnyTupleType);
+TORCH_DECLARE_SINGLETON(AnyClassType);
+
+namespace detail {
+template 
+struct CastReturnType {
+  using type = std::shared_ptr;
+};
+
+template 
+struct CastReturnType::value>::type> {
+  using type = SingletonTypePtr;
+};
+
+template 
+struct CastConstReturnType {
+  using type = std::shared_ptr;
+};
+
+template 
+struct CastConstReturnType::value>::type> {
+  using type = SingletonTypePtr;
+};
+
+template 
+struct as_shared_type {
+  using type = SharedType*;
+};
+
+template 
+struct as_shared_type {
+  using type = const SharedType *;
+};
+} // namespace detail
+
+struct TORCH_API Type {
+  friend TORCH_API bool operator==(const Type& lhs, const Type& rhs);
+  private:
+  TypeKind kind_;
+
+  protected:
+  Type(TypeKind kind) : kind_(kind) {}
+
+  Type(const Type&) = default;
+  Type& operator=(const Type&) = default;
+  Type(Type&&) noexcept = default;
+  Type& operator=(Type&&) noexcept = default;
+
+  virtual std::string annotation_str_impl(TypePrinter /*printer*/) const {
+    return str();
+  }
+  // a == b
+  virtual bool equals(const Type& rhs) const = 0;
+  // a == b <=> b == a
+  virtual bool symmetric() const {
+    return true;
+  }
+
+ public:
+  template 
+  class SingletonOrSharedTypePtr {
+   public:
+    using element_type = typename std::shared_ptr::element_type;
+
+    SingletonOrSharedTypePtr() = default;
+
+    /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x)
+        : repr_(std::move(x)) {}
+
+    template ::value, bool> = true>
+    /* implicit */ SingletonOrSharedTypePtr(std::shared_ptr x)
+        : repr_(std::move(x)) {}
+
+    /* implicit */ SingletonOrSharedTypePtr(std::nullptr_t)
+        : repr_(nullptr) {}
+
+    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p)
+        : repr_(p) {}
+
+    template ::value, bool> = true>
+    /* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr p)
+        : repr_(SingletonTypePtr(p.get())) {}
+
+
+    // We need to support construction from T* for pybind. The problem
+    // is that it's not clear if we are supposed to be taking shared
+    // ownership or not.
+    //
+    // Case 1: if T is known statically to derive from SharedType, we should use
+    // shared_from_this() and take shared_ownership.
+    //
+    // Case 2: if T is exactly Type, we need to do a dynamic_cast to
+    // check if it's a SharedType and do the right thing.
+    //
+    // Case 3: Otherwise, T is not a SharedType. (debug-check this
+    // assumption!) Use a singleton pointer.
+
+    template ::value, bool> = true>
+    /* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast::type>(p)->shared_from_this()) {}
+
+    template ::value, bool> = true>
+    /* implicit */ SingletonOrSharedTypePtr(T* p) {
+      if (auto* shared_p = dynamic_cast::type>(p)) {
+        repr_ = Repr(shared_p->shared_from_this());
+      } else {
+        repr_ = Repr(p);
+      }
+    }
+
+    template ::value && !std::is_base_of::value, bool> = true>
+    /* implicit */ SingletonOrSharedTypePtr(T* p)
+        : repr_(p) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast::type>(p) == nullptr);
+    }
+
+    SingletonOrSharedTypePtr(const SingletonOrSharedTypePtr&) = default;
+    SingletonOrSharedTypePtr(SingletonOrSharedTypePtr&&) noexcept = default;
+    SingletonOrSharedTypePtr& operator=(const SingletonOrSharedTypePtr&) = default;
+    SingletonOrSharedTypePtr& operator=(SingletonOrSharedTypePtr&&) noexcept = default;
+
+    T* get() const {
+      return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast(repr_.rawRepr().first);
+    }
+
+    operator bool() const {
+      return repr_.isNonNull();
+    }
+
+    bool operator==(std::nullptr_t) const {
+      return !repr_.isNonNull();
+    }
+
+    bool operator!=(std::nullptr_t) const {
+      return repr_.isNonNull();
+    }
+
+    template , void>::value, bool> = true>
+    U& operator*() const {
+      return *get();
+    }
+
+    T* operator->() const {
+      return get();
+    }
+
+  private:
+    // NOTE: SharedPtrWrapper exists to work around a baffling bug in
+    // nvcc; see comment in destroy() below.
+    struct SharedPtrWrapper {
+      SharedPtrWrapper(std::shared_ptr &&x)
+          : repr_(std::move(x)) {}
+      std::shared_ptr repr_;
+    };
+    union Repr {
+      Repr() : Repr(nullptr) {}
+
+      explicit Repr(std::shared_ptr x)
+          : shared_(std::move(x)) {}
+
+      explicit Repr(std::nullptr_t)
+          : singletonRepr_(nullptr) {}
+
+      explicit Repr(SingletonTypePtr p)
+          : singletonRepr_(p.get()) {}
+
+      ~Repr() {
+        destroy();
+      }
+
+      // NOTE: the only non-UB way to access our null state is through
+      // rawRepr(), because our copy operation doesn't preserve which
+      // union member is active for null pointers.
+      Repr(const Repr& rhs) {
+        if (rhs.isSharedAndNonNull()) {
+          new (&shared_) SharedPtrWrapper(rhs.shared_);
+        } else {
+          singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first);
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
+          singletonRepr_.unused_ = nullptr;
+        }
+      }
+
+      Repr(Repr&& rhs) noexcept {
+        if (rhs.isSharedAndNonNull()) {
+          new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
+        } else {
+          singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first);
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
+          singletonRepr_.unused_ = nullptr;
+        }
+      }
+
+      Repr& operator=(const Repr& rhs) {
+        if (&rhs == this) {
+          return *this;
+        }
+        if (rhs.isSharedAndNonNull()) {
+          if (isSharedAndNonNull()) {
+            shared_ = rhs.shared_;
+          } else {
+            new (&shared_) SharedPtrWrapper(rhs.shared_);
+          }
+        } else {
+          if (isSharedAndNonNull()) {
+            destroy();
+          }
+          singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first);
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
+          singletonRepr_.unused_ = nullptr;
+        }
+        return *this;
+      }
+
+      Repr& operator=(Repr&& rhs) noexcept {
+        if (&rhs == this) {
+          return *this;
+        }
+        if (rhs.isSharedAndNonNull()) {
+          if (isSharedAndNonNull()) {
+            shared_ = std::move(rhs.shared_);
+          } else {
+            new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
+          }
+        } else {
+          if (isSharedAndNonNull()) {
+            destroy();
+          }
+          singletonRepr_.singleton_ = static_cast(rhs.rawRepr().first);
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
+          singletonRepr_.unused_ = nullptr;
+        }
+        return *this;
+      }
+
+      SharedPtrWrapper shared_;
+
+      struct SingletonRepr {
+        explicit SingletonRepr(T* s) : singleton_(s) {}
+        T* singleton_;
+        void* unused_ = nullptr;
+      } singletonRepr_;
+      struct RawRepr {
+        void* first;
+        void* nullIfSingleton_;
+      };
+
+      // It is UB to read the singleton part of Repr if it was
+      // constructed as a shared_ptr and vice versa, but memcpying out
+      // the representation is always OK, so here's an accessor to obey
+      // the letter of the law.
+      RawRepr rawRepr() const {
+        RawRepr repr{};
+        memcpy(&repr, reinterpret_cast(this), sizeof(RawRepr));
+        return repr;
+      }
+
+      bool isNonNull() const {
+        auto repr = rawRepr();
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
+        return repr.first != nullptr;
+      }
+
+      bool isSharedAndNonNull() const {
+        return rawRepr().nullIfSingleton_ != nullptr;
+      }
+
+     private:
+      void destroy() {
+        if (isSharedAndNonNull()) {
+          // Without SharedPtrWrapper, this line would read
+          // `shared_.~shared_ptr()` and nvcc would complain with
+          // "error: expected primary-expression before '>' token"
+          // referring to the "t" in "shared_ptr". SharedPtrWrapper
+          // exists to work around this compiler bug.
+          shared_.~SharedPtrWrapper();
+        }
+      }
+    } repr_;
+  };
+
+  using TypePtr = SingletonOrSharedTypePtr;
+  using Ptr = TypePtr;
+  using ElementType = Type;
+
+  // subtyping relation. By default, we return true for the case
+  // when the type is exactly equal or if this <: T where rhs = Optional[T]
+
+  // if this returns false and the why_not stream is non-null, it contains
+  // additional details that describe why this is not a subtype of 'rhs'.
+  // This additional information should only contain details that are not
+  // obvious from the annotation_str() that describes the type. For instance it
+  // is clear that `int <: str` is false but not clear why `Foo <: InterfaceBar`
+  // might be false.
+  virtual bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const;
+  virtual bool is_module() const;
+  bool isSubtypeOf(const Type& rhs) const {
+    return isSubtypeOfExt(rhs, nullptr);
+  }
+  // Compatibility shims to accommodate existing code that passes shared_ptrs
+  // around. Ideally, we would just delete this, but it should be harmless.
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOf(const std::shared_ptr& rhs) const {
+    return isSubtypeOf(*rhs);
+  }
+
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOf(const SingletonOrSharedTypePtr& rhs) const {
+    return isSubtypeOf(*rhs);
+  }
+
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOf(SingletonTypePtr rhs) const {
+    return isSubtypeOf(*rhs);
+  }
+
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOfExt(const SingletonOrSharedTypePtr& rhs, std::ostream* why_not) const {
+    return isSubtypeOfExt(*rhs, why_not);
+  }
+
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOfExt(const std::shared_ptr& rhs, std::ostream* why_not) const {
+    return isSubtypeOfExt(*rhs, why_not);
+  }
+
+  template 
+  typename std::enable_if::value, bool>::type
+  isSubtypeOfExt(SingletonTypePtr rhs, std::ostream* why_not) const {
+    return isSubtypeOfExt(*rhs, why_not);
+  }
+
+  // How this type will appear in FunctionSchema declarations
+  virtual std::string str() const = 0;
+
+  // How this type will appear as if it were a type annotation in Python
+  // which is sometimes different from how it appears in declarations (e.g.
+  // int[] vs List[int])
+  //
+  // Takes a custom printer that users can pass in to customize the output of
+  // this method.
+  std::string annotation_str(TypePrinter printer) const {
+    if (printer) {
+      // the printer can return nullopt to fall through to the default impl
+      if (auto renamed = printer(*this)) {
+        return *renamed;
+      }
+    }
+    return annotation_str_impl(std::move(printer));
+  }
+  std::string annotation_str() const {
+    // Overload instead of defining a default value for `printer` to help
+    // debuggers out.
+    return annotation_str(nullptr);
+  }
+
+  // Returns a human readable string that includes additional information like
+  // "type is inferred rather than explicitly defined" to help construct more
+  // user-friendly messages.
+  virtual std::string repr_str() const {
+    return annotation_str();
+  }
+
+  TypeKind kind() const {
+    return kind_;
+  }
+
+  virtual bool isUnionType() const {
+    return false;
+  }
+
+  virtual bool requires_grad() const {
+    for (const auto& ct : containedTypes()) {
+      if (ct->requires_grad()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Dynamically cast this object to the subclass indicated by the
+  // template variable, returning nullptr if the cast is invalid.
+  template ::value, bool> = true>
+  typename detail::CastReturnType::type cast() {
+    if (T::Kind == kind()) {
+      return std::static_pointer_cast(static_cast(this)->shared_from_this());
+    }
+    return nullptr;
+  }
+  template ::value, bool> = true>
+  typename detail::CastReturnType::type cast() {
+    if (T::Kind == kind()) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get());
+      return typename detail::CastReturnType::type(static_cast(this));
+    }
+    return nullptr;
+  }
+  template ::value, bool> = true>
+  typename detail::CastConstReturnType::type cast() const {
+    if (T::Kind == kind()) {
+      return std::static_pointer_cast(static_cast(this)->shared_from_this());
+    }
+    return nullptr;
+  }
+  template ::value, bool> = true>
+  typename detail::CastConstReturnType::type cast() const {
+    if (T::Kind == kind()) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this == T::get().get());
+      return typename detail::CastConstReturnType::type(static_cast(this));
+    }
+    return nullptr;
+  }
+  template 
+  T* castRaw() {
+    if (T::Kind == kind()) {
+      return static_cast(this);
+    }
+    return nullptr;
+  }
+  template 
+  const T* castRaw() const {
+    if (T::Kind == kind()) {
+      return static_cast(this);
+    }
+    return nullptr;
+  }
+  template 
+  auto expect() {
+    auto r = cast();
+    AT_ASSERT(r);
+    return r;
+  }
+  template 
+  auto expect() const {
+    auto r = cast();
+    AT_ASSERT(r);
+    return r;
+  }
+  template 
+  T& expectRef() {
+    auto* r = castRaw();
+    AT_ASSERT(r);
+    return *r;
+  }
+  template 
+  const T& expectRef() const {
+    auto* r = castRaw();
+    AT_ASSERT(r);
+    return *r;
+  }
+  virtual ~Type() = default;
+  virtual bool hasFreeVariables() const {
+    return false;
+  }
+  // list of types this type contains, e.g. for a List, the element type of
+  // the list; for a Tuple, the types of the tuple elements
+  virtual at::ArrayRef containedTypes() const {
+    return {};
+  }
+  virtual TypePtr containedType(size_t i) const {
+    return containedTypes().at(i);
+  }
+  virtual size_t containedTypeSize() const {
+    return containedTypes().size();
+  }
+  // create a new version of this type, replacing its contained types with
+  // contained_types
+  TypePtr withContained(std::vector contained_types);
+  // per-type constructor; you only need to override this if
+  // containedTypes() is not empty
+  virtual TypePtr createWithContained(
+      std::vector /*contained_types*/) const {
+    AT_ERROR(
+        "type with contained types did not overload createWithContained: ",
+        str());
+  }
+
+};
+
+template 
+using SingletonOrSharedTypePtr = Type::SingletonOrSharedTypePtr;
+
+
+template 
+bool operator==(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) {
+  return (void*)x.get() == (void*)y.get();
+}
+
+template 
+bool operator==(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) {
+  return (void*)x.get() == (void*)y.get();
+}
+
+template 
+bool operator==(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) {
+  return (void*)x.get() == (void*)y.get();
+}
+
+template 
+bool operator==(const SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) {
+  return (void*)x.get() == (void*)y.get();
+}
+
+template 
+bool operator==(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) {
+  return (void*)x.get() == (void*)y.get();
+}
+
+template 
+bool operator!=(const SingletonOrSharedTypePtr& x, const SingletonOrSharedTypePtr& y) {
+  return !(x == y);
+}
+
+template 
+bool operator!=(const SingletonOrSharedTypePtr& x, const std::shared_ptr& y) {
+  return !(x == y);
+}
+
+template 
+bool operator!=(const std::shared_ptr& x, const SingletonOrSharedTypePtr& y) {
+  return !(x == y);
+}
+
+template 
+bool operator!=(const SingletonOrSharedTypePtr& x, const SingletonTypePtr& y) {
+  return !(x == y);
+}
+
+template 
+bool operator!=(const SingletonTypePtr& x, const SingletonOrSharedTypePtr& y) {
+  return !(x == y);
+}
+
+using TypePtr = SingletonOrSharedTypePtr<Type>;
+using ConstTypePtr = SingletonOrSharedTypePtr<const Type>;
+
+// Explicitly enable MaybeOwned>, rather than allowing
+// MaybeOwned to be used for any type right away.
+template 
+struct MaybeOwnedTraits>
+    : public MaybeOwnedTraitsGenericImpl> {};
+
+// Base class for Types that are guaranteed to be owned by std::shared_ptr.
+struct TORCH_API SharedType : public Type, public std::enable_shared_from_this {
+  using Type::Type;
+};
+
+inline TypePtr Type::withContained(std::vector contained_types) {
+  auto current_contained = containedTypes();
+  // Types with no contained_types don't need this call. Check before calling!
+  //
+  // (We can't support this efficiently because types without
+  // contained types may be singletons, in which case
+  // shared_from_this will crash; we would have to provide a virtual
+  // typeptr_from_this or isSingleton.)
+  TORCH_INTERNAL_ASSERT(!current_contained.empty() && current_contained.size() == contained_types.size());
+  if (current_contained.equals(contained_types)) {
+    return std::static_pointer_cast(static_cast(this)->shared_from_this());
+  }
+  return createWithContained(std::move(contained_types));
+}
+
+
+TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) {
+  if (C10_UNLIKELY(!rhs.symmetric())) {
+    return rhs.equals(lhs);
+  }
+  return lhs.equals(rhs);
+}
+
+struct NamedType;
+using NamedTypePtr = std::shared_ptr;
+using ConstNamedTypePtr = std::shared_ptr;
+
+struct TORCH_API NamedType : public SharedType {
+  NamedType(TypeKind tk, c10::optional name)
+      : SharedType(tk), name_(std::move(name)) {
+    TORCH_INTERNAL_ASSERT(
+        tk == TypeKind::TupleType || tk == TypeKind::FunctionType ||
+            tk == TypeKind::ClassType || tk == TypeKind::InterfaceType ||
+            tk == TypeKind::EnumType,
+        "If you add a new kind of NamedType, ",
+        "please update the cast specialization and this assert");
+  }
+
+  // Fully qualified name of type
+  // Looks like: "foo.bar.Baz".
+  const c10::optional& name() const {
+    return name_;
+  }
+
+ private:
+  c10::optional name_;
+};
+
+} // namespace c10
+
+namespace std {
+template 
+struct hash> {
+  size_t operator()(const c10::SingletonOrSharedTypePtr& x) const {
+    return std::hash()(x.get());
+  }
+};
+} // namespace std
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h
new file mode 100644
index 0000000000000000000000000000000000000000..e6e555f3bb47b61665088f6ee2ca7179f5b120c0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/adaption.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+/*
+ * [Note: hacky wrapper removal for optional tensor]
+ *
+ * The kernel implementation takes an optional tensor marked in the schema as
+ * Tensor? but the C++ function takes Tensor instead of the optional
+ * expected by the dispatcher.
+ *
+ * To remove the hacky wrapper, the C++ function is changed to take
+ * optional and unwrap the Tensor value at the beginning of
+ * the function, e.g.:
+ *   > c10::MaybeOwned weight_maybe_owned =
+ *   >     at::borrow_from_optional_tensor(weight_opt);
+ *   > const Tensor& weight = *weight_maybe_owned;
+ *
+ * We may want to make the kernel handle optional directly without
+ * going through the creation of a default-constructed Tensor in
+ * at::borrow_from_optional_tensor.
+ */
+
+/*
+ * [Note: hacky wrapper removal for TensorOptions]
+ *
+ * The kernel implementation takes a TensorOptions argument but the dispatcher
+ * expects separate arguments for dtype, layout, device, pin_memory.
+ *
+ * To remove the hacky wrapper, the kernel implementation is changed to take
+ * the 4 arguments (dtype, layout, device, pin_memory), and assemble the
+ * TensorOptions value at the beginning of the function, e.g.:
+ *   > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
+ *   >    .device(device).pinned_memory(pin_memory);
+ *
+ * We may want to make the kernel handle these parameters directly without going
+ * through the creation of a TensorOptions value.
+ */
+
+namespace c10 {
+namespace impl {
+
+TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
+
+inline void check_and_update_common_device(optional& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
+  // TODO: Remove this once the following issue is addressed:
+  // https://github.com/pytorch/pytorch/issues/57380
+  if (!tensor.defined()) {
+    return;
+  }
+
+  if (!common_device.has_value()) {
+    common_device = tensor.device();
+    return;
+  }
+
+  if (C10_UNLIKELY(common_device != tensor.device())) {
+    common_device_check_failure(*common_device, tensor, methodName, argName);
+  }
+}
+
+inline void check_and_update_common_device(optional& common_device, const optional& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
+  if (tensor.has_value()) {
+    check_and_update_common_device(common_device, tensor.value(), methodName, argName);
+  }
+}
+
+inline void check_and_update_common_device(optional& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
+  for (const auto& tensor : tensors) {
+    check_and_update_common_device(common_device, tensor, methodName, argName);
+  }
+}
+
+inline void check_and_update_common_device(optional& common_device, const List>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
+  for (const auto& tensor : tensors) {
+    check_and_update_common_device(common_device, tensor, methodName, argName);
+  }
+}
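+// Illustrative usage sketch, roughly how a generated wrapper calls these
+// helpers before dispatching (operator and argument names are assumptions):
+//   c10::optional<Device> common_device = c10::nullopt;
+//   check_and_update_common_device(common_device, self,  "my_op", "self");
+//   check_and_update_common_device(common_device, other, "my_op", "other");
+//   // a device mismatch between `self` and `other` is reported through
+//   // common_device_check_failure(...)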
+} // namespace impl
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e089330f5a8681d0f41e6357fbc63b6f328b79d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/infer_schema.h
@@ -0,0 +1,160 @@
+#pragma once
+
+/**
+ * This file contains functionality to take a C++ function and infer its
+ * c10::FunctionSchema.
+ */
+
+#include 
+#include 
+
+namespace c10 {
+namespace detail {
+
+namespace infer_schema {
+
+/// The templated inference code creates `ArgumentDef` instead of `Argument`,
+/// because that can be constructed at compile time and has a much smaller
+/// binary size than having calls to `Argument` constructors in the template.
+/// Creating `Argument` objects from `ArgumentDef` can then be done at
+/// runtime in a non-templated way.
+struct ArgumentDef final {
+  using GetTypeFn = TypePtr();
+  GetTypeFn* getTypeFn;
+  GetTypeFn* getFakeTypeFn;
+  constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {}
+  explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {}
+};
+
+template
+struct bool_t {};
+template<> struct bool_t : std::true_type {};
+template<> struct bool_t : std::false_type {};
+
+/// Checks the static C++ types `Types` for correctness to catch common error cases.
+template 
+constexpr int checkStaticTypes() {
+ // Give nice error messages for some of the common error cases.
+ // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
+ static_assert(std::conjunction<
+     bool_t::value || std::is_same::value || std::is_same::value || std::is_same::value>...
+   >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");
+ static_assert(std::conjunction<
+     bool_t::value>...
+   >::value, "INVALID TYPE: float is not supported as an argument type, use double instead");
+ return 0;
+}
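+// Illustrative consequences of the checks above (sketch):
+//   checkStaticTypes<int64_t, bool>();  // fine
+//   checkStaticTypes<int32_t>();        // static_assert: use int8_t/int64_t/bool
+//   checkStaticTypes<float>();          // static_assert: use double instead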
+
+template 
+constexpr std::array createArgumentVectorFromTypes(std::index_sequence) {
+  return (
+    // Check types for common errors
+    checkStaticTypes(),
+
+    // Create the return value
+    std::array{
+      ArgumentDef(&getTypePtrCopy>, &getFakeTypePtrCopy>)...}
+  );
+}
+
+/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
+/// as template arguments.
+template struct createArguments final {};
+template
+struct createArguments> final {
+  static constexpr std::array call() {
+    return createArgumentVectorFromTypes(
+        std::make_index_sequence()
+    );
+  }
+};
+
+/// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
+/// as a tuple (i.e. in the way c10 kernels return values).
+/// It can be a tuple<A, B, C> if there are three output arguments with types A, B, C.
+/// It can be an empty tuple<>, or void for kernels that don't return anything.
+/// It can be a single type A (i.e. no tuple) for the case where a kernel just
+/// returns one value.
+template struct createReturns final {};
+
+template
+struct createReturns, void> final {
+  static constexpr std::array call() {
+    return createArgumentVectorFromTypes(
+        std::make_index_sequence()
+    );
+  }
+};
+
+template
+struct createReturns::value && !guts::is_instantiation_of::value>> final {
+  static constexpr std::array call() {
+    return createReturns>::call();
+  }
+};
+
+template<>
+struct createReturns final {
+  static constexpr std::array call() {
+    return createReturns>::call();
+  }
+};
+
+template 
+struct createSingleReturn {
+  static constexpr std::array call() {
+    return createArgumentVectorFromTypes(std::make_index_sequence<1>());
+  }
+};
+
+TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef arguments, c10::ArrayRef returns);
+TORCH_API FunctionSchema make_function_schema(c10::ArrayRef arguments, c10::ArrayRef returns);
+
+/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
+/// function. Flattens std::tuple returns into multiple return types
+template 
+FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() {
+ using ReturnType = typename FunctionTraits::return_type;
+ using ParameterTypes = typename FunctionTraits::parameter_types;
+
+ // arguments and returns are computed into a std::array at compile time and embedded into the binary.
+ // The only code executed at runtime here is the one that creates a std::vector
+ // of the arguments/returns from the std::array.
+ constexpr auto arguments = createArguments::call();
+ constexpr auto returns = createReturns::call();
+
+ return make_function_schema(arguments, returns);
+}
+
+/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
+/// function. Preserves std::tuple returns as a Tuple return type
+template 
+FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
+ using ReturnType = typename FunctionTraits::return_type;
+ using ParameterTypes = typename FunctionTraits::parameter_types;
+
+ // arguments and returns are computed into a std::array at compile time and embedded into the binary.
+ // The only code executed at runtime here is the one that creates a std::vector
+ // of the arguments/returns from the std::array.
+ constexpr auto arguments = createArguments::call();
+ constexpr auto returns = createSingleReturn::call();
+
+ return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
+}
+
+}
+}
+
+template
+FunctionSchema inferFunctionSchemaFlattenedReturns() {
+  return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns>();
+}
+
+template
+FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
+  return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn>(std::move(name), std::move(overload_name));
+}
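+// Illustrative sketch (assuming a plain free-function kernel; the exact
+// textual rendering of the schema is up to make_function_schema):
+//   Tensor my_kernel(const Tensor& a, int64_t b);
+//   auto schema = inferFunctionSchemaFlattenedReturns<decltype(my_kernel)>();
+//   // -> arguments (Tensor, int) and a single Tensor return; a std::tuple
+//   //    return type would instead be flattened into multiple returns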
+
+TORCH_API c10::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h
new file mode 100644
index 0000000000000000000000000000000000000000..c2703c741fcbef3d2c45c5df2a7210b17aac6925
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_allowlist.h
@@ -0,0 +1,199 @@
+#pragma once
+
+// TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
+#ifdef TEMPLATE_SELECTIVE_BUILD
+#include 
+#endif
+
+/**
+ * This header implements functionality to build PyTorch with only a certain
+ * set of operators (+ dependencies) included.
+ *
+ * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
+ *   two ops will be included in your build.  The allowlist records operators
+ *   only, no overloads; if you include aten::add, all overloads of aten::add
+ *   will be included.
+ *
+ * Internally, this is done by removing the operator registration calls
+ * using compile time programming, and the linker will then prune all
+ * operator functions that weren't registered.
+ * See Note [Selective build] for more details
+ *
+ * WARNING: The allowlist mechanism doesn't work for all ways you could go about
+ * registering an operator.  If the dispatch key / operator name is not
+ * sufficiently obvious at compile time, then the allowlisting mechanism
+ * will fail (and the operator will be included in the binary anyway).
+ */
+
+#include 
+#include 
+#include 
+
+
+#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
+#include 
+#endif
+
+namespace c10 {
+
+namespace impl {
+
+constexpr bool allowlist_contains(string_view allowlist, string_view item);  // Forward Declare
+
+/**
+ * In selective build mode returns true/false depending on whether a build
+ * feature is available or not.
+ *
+ * In instrumenting mode (tracing mode), always returns true, and doesn't
+ * trigger any side effects.
+ */
+constexpr bool is_build_feature_available(const char* name) {
+#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
+  // Selective Build mode.
+#if !defined(TORCH_BUILD_FEATURE_ALLOWLIST)
+  (void)name;
+  return true;
+#else
+  return allowlist_contains(
+    C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST),
+    name);
+#endif
+
+#else
+  // Instrumenting mode.
+  (void)name;
+  return true;
+#endif
+}
+
+[[noreturn]] void build_feature_required_feature_not_available(const char* feature);
+
+/**
+ * Use BUILD_FEATURE_REQUIRED macro in user-code.
+ *
+ * In selective build mode becomes a no-op if the build feature passed
+ * in is available. If not available, throws an exception (c10::Error).
+ * The compiler is able to perform dead code elimination for code
+ * following this method if the build feature is not available.
+ *
+ * In instrumenting mode (tracing mode), registers (as a side effect)
+ * the presence of this specific build feature being triggered.
+ */
+#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)  // selective build mode
+
+#if defined(TORCH_BUILD_FEATURE_ALLOWLIST)
+#define BUILD_FEATURE_REQUIRED(NAME)                                 \
+  if (!c10::impl::is_build_feature_available(NAME)) {                \
+    ::c10::impl::build_feature_required_feature_not_available(NAME); \
+  }
+#else  // Everything trivially selected
+#define BUILD_FEATURE_REQUIRED(NAME)
+
+#endif
+
+#else  // trace mode
+#define BUILD_FEATURE_REQUIRED(NAME)  \
+  RECORD_FUNCTION_WITH_SCOPE(         \
+      at::RecordScope::BUILD_FEATURE, \
+      std::string(NAME),              \
+      {});
+#endif
+
+// Use this macro, and not is_build_feature_available
+#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
+
+// returns true iff allowlist contains item
+// allowlist_contains("a;bc;d", "bc") == true
+constexpr bool allowlist_contains(string_view allowlist, string_view item) {
+    // Choose a really big value for next so that if something goes wrong
+    // this code will blow up in a hopefully detectable way.
+    size_t next = std::numeric_limits<size_t>::max();
+    for (size_t cur = 0; cur <= allowlist.size(); cur = next) {
+      next = allowlist.find(';', cur);
+      if (next != string_view::npos) {
+        if (allowlist.substr(cur, next - cur).compare(item) == 0) {
+          return true;
+        }
+        next++;
+      } else {
+        if (allowlist.substr(cur).compare(item) == 0) {
+          return true;
+        }
+        break;
+      }
+    }
+    return false;
+}
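+// A few more illustrative cases for the matcher above (entries must match a
+// full ';'-separated segment, not a prefix):
+//   allowlist_contains("a;bc;d", "b")   == false
+//   allowlist_contains("a;bc;d", "d")   == true
+//   allowlist_contains("a;bc;d", "a;b") == false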
+
+// Returns true iff the given op name is on the allowlist
+// and should be registered
+constexpr bool op_allowlist_check(string_view op_name) {
+  assert(op_name.find("::") != string_view::npos);
+  // Use assert() instead of throw() due to a gcc bug. See:
+  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
+  // https://github.com/fmtlib/fmt/issues/682
+  assert(op_name.find("(") == string_view::npos);
+#if !defined(TORCH_OPERATOR_WHITELIST)
+  // If the TORCH_OPERATOR_WHITELIST parameter is not defined,
+  // all ops are to be registered
+  return true;
+#else
+  return allowlist_contains(
+    C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
+    // This function is mainly used for mobile selective builds with
+    // root operators, where the overload is included in the allowlist.
+    op_name);
+    // // Strip overload name (as allowlist doesn't contain overloads)
+    // // Another function based on this may be added when there's usage
+    // // on op names without overload.
+    // OperatorNameView::parse(op_name).name);
+#endif
+}
+
+// Returns true iff the given schema string is on the allowlist
+// and should be registered
+constexpr bool schema_allowlist_check(string_view schema) {
+#if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
+  return true;
+#else
+  return op_allowlist_check(schema.substr(0, schema.find("(")));
+#endif
+}
+
+// Returns true iff the given custom class name is on the allowlist
+// and should be registered
+constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
+#if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
+  // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined,
+  // all custom classes are to be registered
+  (void)custom_class_name;
+  return true;
+#else
+  return allowlist_contains(
+    C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST),
+    custom_class_name);
+#endif
+}
+
+// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
+// This function exists so callers can pass an arbitrary allowlist explicitly.
+constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {
+  return allowlist_contains(allowlist, schema.substr(0, schema.find("(")));
+}
+
+// Returns true iff the given dispatch key is on the allowlist
+// and should be registered.  When we turn this on, the list of valid
+// mobile dispatch keys is hard coded (but you need to make sure
+// that you have the correct set of dispatch keys for this).
+constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
+#ifdef C10_MOBILE
+  return true;
+  // Disabled for now: to be enabled later!
+  // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
+#else
+  return true;
+#endif
+}
+
+} // namespace impl
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h
new file mode 100644
index 0000000000000000000000000000000000000000..751c3bfed81c0ab8fee17fa7caec35e1ed2a645d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/op_registration/op_registration.h
@@ -0,0 +1,596 @@
+#pragma once
+
+/**
+ * Include this file if you want to register operators. It includes all
+ * functionality needed to do so for you.
+ */
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
+#include 
+#endif
+#include 
+
+namespace c10 {
+
+namespace detail {
+// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
+// We do this because every argument in a function schema is expected to be convertible
+// to an IValue, but DispatchKeySet is not a type we want the jit to be aware of.
+// See Note [Plumbing Keys Through The Dispatcher]
+template
+std::unique_ptr inferFunctionSchemaFromFunctor() {
+  using func_type = typename c10::remove_DispatchKeySet_arg_from_func::func_type;
+  return std::make_unique(inferFunctionSchemaFlattenedReturns());
+}
+}
+
+/**
+ * An instance of this class handles the registration for one or more operators.
+ * Make sure you keep the RegisterOperators instance around since it will
+ * deregister the operator it's responsible for in its destructor.
+ *
+ * Example:
+ *
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op(c10::RegisterOperators::options()
+ * >         .schema("my_op")
+ * >         .kernel(DispatchKey::CPU));
+ */
+class TORCH_API RegisterOperators final {
+public:
+  RegisterOperators() = default;
+  ~RegisterOperators() = default;
+
+  RegisterOperators(const RegisterOperators&) = delete;
+  RegisterOperators& operator=(const RegisterOperators&) = delete;
+  RegisterOperators(RegisterOperators&&) noexcept = default;
+  RegisterOperators& operator=(RegisterOperators&&) noexcept = default;
+
+  class TORCH_API Options final {
+  public:
+    Options(const Options&) = delete;
+    Options(Options&&) noexcept = delete;
+    Options& operator=(const Options&) = delete;
+    Options& operator=(Options&&) noexcept = delete;
+
+    // internal-only for registering stack based kernels
+    template
+    Options&& kernel(DispatchKey dispatch_key) && {
+      return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction(), nullopt, nullptr);
+    }
+
+    // internal-only for registering stack based catch-all kernels
+    template
+    Options&& catchAllKernel() && {
+      return std::move(*this).kernel(c10::nullopt, KernelFunction::makeFromBoxedFunction(), nullopt, nullptr);
+    }
+
+    // internal only for registering caffe2 ops
+    Options&& schema(FunctionSchema&& schema) {
+        TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration.");
+        schemaOrName_ = FunctionSchema(std::move(schema));
+        return std::move(*this);
+    }
+
+    /**
+     * Use this to specify the schema for an operator. You can also specify
+     * the operator name only to have the function signature part of the
+     * schema be inferred from the kernel function.
+     *
+     * Example:
+     *
+     * > // Infer function signature from my_kernel_cpu
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .kernel(DispatchKey::CPU));
+     * >
+     * >
+     * > // Explicitly specify full schema
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op(Tensor a) -> Tensor")
+     * >         .kernel(DispatchKey::CPU));
+     */
+    Options&& schema(const std::string& schemaOrName) {
+      TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");
+
+      #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD)
+        throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
+      #else
+        schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName);
+      #endif
+
+      return std::move(*this);
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented as a functor.
+     * The kernel is only called for inputs matching the given dispatch key.
+     * You can register multiple kernels for different dispatch keys.
+     *
+     * Example:
+     *
+     * > namespace {
+     * >   class my_kernel_cpu final : public c10::OperatorKernel {
+     * >   public:
+     * >     Tensor operator()(Tensor a, Tensor b) {...}
+     * >   };
+     * > }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .kernel(DispatchKey::CPU));
+     *
+     * The functor constructor can take arguments to configure the kernel.
+     * The arguments are defined in the kernel registration.
+     * Example:
+     *
+     * > namespace {
+     * >   class my_kernel_cpu final : public c10::OperatorKernel {
+     * >   public:
+     * >     explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
+     * >         : ... {...}
+     * >
+     * >     Tensor operator()(Tensor a, Tensor b) {...}
+     * >   };
+     * > }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .kernel(DispatchKey::CPU, "some_configuration", 3, true));
+     */
+    template
+    // enable_if: only enable it if KernelFunctor is actually a functor
+    std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && {
+      static_assert(std::is_base_of::value, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+      static_assert(std::is_constructible::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor.");
+
+      return std::move(*this).kernel(
+        dispatch_key,
+        KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)),
+        impl::CppSignature::make(),
+        detail::inferFunctionSchemaFromFunctor()
+      );
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented as a functor.
+     * The kernel is a catch-all kernel, meaning it's called independently of
+     * the input. Dispatch is disabled for this operator.
+     *
+     * Example:
+     *
+     * > namespace {
+     * >   class my_kernel_cpu final : public c10::OperatorKernel {
+     * >   public:
+     * >     Tensor operator()(Tensor a, Tensor b) {...}
+     * >   };
+     * > }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .catchAllKernel());
+     *
+     * The functor constructor can take arguments to configure the kernel.
+     * The arguments are defined in the kernel registration.
+     * Example:
+     *
+     * > namespace {
+     * >   class my_kernel_cpu final : public c10::OperatorKernel {
+     * >   public:
+     * >     explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
+     * >         : ... {...}
+     * >
+     * >     Tensor operator()(Tensor a, Tensor b) {...}
+     * >   };
+     * > }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .catchAllKernel("some_configuration", 3, true));
+     */
+    template
+    // enable_if: only enable it if KernelFunctor is actually a functor
+    std::enable_if_t::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && {
+      static_assert(std::is_base_of::value, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
+      static_assert(std::is_constructible::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor.");
+
+      return std::move(*this).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedFunctor(std::make_unique(std::forward(constructorParameters)...)),
+        impl::CppSignature::make(),
+        detail::inferFunctionSchemaFromFunctor()
+      );
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented by a function.
+     * The kernel is only called for inputs matching the given dispatch key.
+     * You can register multiple kernels for different dispatch keys.
+     *
+     * Example:
+     *
+     * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .kernel(DispatchKey::CPU));
+     */
+    template
+    // enable_if: only enable it if FuncType is actually a function
+    std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key) && {
+      static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+      static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
+
+      return std::move(*this).kernel(
+        dispatch_key,
+        KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
+        impl::CppSignature::make(),
+        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
+        detail::inferFunctionSchemaFromFunctor>::type>()
+      );
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented by a function.
+     * The kernel is a catch-all kernel, meaning it's called independent from
+     * the input. Dispatch is disabled for this operator.
+     *
+     * Example:
+     *
+     * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
+     * >
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
+     */
+    template<class FuncType, FuncType* kernel_func>
+    // enable_if: only enable it if FuncType is actually a function
+    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && {
+      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+      static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");
+
+      return std::move(*this).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
+        impl::CppSignature::make<FuncType>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
+        detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
+      );
+    }
+
+    template<class FuncType>
+    // enable_if: only enable it if FuncType is actually a function
+    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && {
+      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+      TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
+
+      return std::move(*this).kernel(
+        dispatch_key,
+        KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
+        impl::CppSignature::make<FuncType>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<FuncType*>>()
+      );
+    }
+
+    template<class FuncType>
+    // enable_if: only enable it if FuncType is actually a function
+    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && {
+      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
+      TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");
+
+      return std::move(*this).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
+        impl::CppSignature::make<FuncType>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<FuncType*>>()
+      );
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented as a lambda.
+     * The kernel is only called for inputs matching the given dispatch key.
+     * You can register multiple kernels for different dispatch keys.
+     *
+     * The lambda must be stateless, i.e. not have a capture. If your kernel
+     * needs to store some configuration parameters, write the kernel as a
+     * functor instead.
+     *
+     * Example:
+     *
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...}));
+     */
+    template<class Lambda>
+    // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
+    std::enable_if_t<
+        guts::is_functor<std::decay_t<Lambda>>::value
+        && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+        Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && {
+      static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead.");
+
+      // We don't support stateful lambdas (i.e. lambdas with a capture), because their
+      // behavior would be nonobvious. A functor kernel with cache gets a new instance of
+      // its cache each time the kernel is looked up from the dispatch table.
+      // A lambda with a capture would be global and share its capture between all kernel lookups.
+      // So, instead of making users having to think about it (including the thread-safety
+      // issues this causes), let's just forbid stateful lambdas altogether.
+      static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel() instead.");
+
+      return std::move(*this).kernel(
+        dispatch_key,
+        KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)),
+        impl::CppSignature::make<std::decay_t<Lambda>>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
+      );
+    }
+
+    /**
+     * Use this to register an operator whose kernel is implemented as a lambda.
+     * The kernel is a catch-all kernel, meaning it's called independent from
+     * the input. Dispatch is disabled for this operator.
+     *
+     * The lambda must be stateless, i.e. not have a capture. If your kernel
+     * needs to store some configuration parameters, write the kernel as a
+     * functor instead.
+     *
+     * Example:
+     *
+     * > static auto registry = c10::RegisterOperators()
+     * >     .op(c10::RegisterOperators::options()
+     * >         .schema("my_op")
+     * >         .catchAllKernel([] (Tensor a) -> Tensor {...}));
+     */
+    template<class Lambda>
+    // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
+    std::enable_if_t<
+        guts::is_functor<std::decay_t<Lambda>>::value
+        && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+        Options&&> catchAllKernel(Lambda&& lambda) && {
+      static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead.");
+
+      // We don't support stateful lambdas (i.e. lambdas with a capture), because their
+      // behavior would be nonobvious.
+      // A lambda with a capture would be global and share its capture between all kernel lookups.
+      // This would be a likely source for unexpected race conditions, so we forbid it.
+      // If a kernel really needs global state, they can just have regular global state
+      // in their .cpp file next to the kernel lambda.
+      static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel() instead.");
+
+      return std::move(*this).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
+        impl::CppSignature::make<std::decay_t<Lambda>>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
+      );
+    }
+
+    Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
+      TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
+      aliasAnalysisKind_ = aliasAnalysisKind;
+      return std::move(*this);
+    }
+
+  private:
+    Options&& kernel(c10::optional<DispatchKey> dispatch_key, KernelFunction&& func, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
+      KernelRegistrationConfig config;
+      config.dispatch_key = dispatch_key;
+      config.func = std::move(func);
+      config.cpp_signature = cpp_signature;
+      config.inferred_function_schema = std::move(inferred_function_schema);
+      kernels.push_back(std::move(config));
+      return std::move(*this);
+    }
+
+    Options()
+    : schemaOrName_(c10::nullopt)
+    , kernels()
+    , aliasAnalysisKind_(c10::nullopt)
+    {}
+
+    // KernelRegistrationConfig accumulates all information from the config
+    // parameters passed to a RegisterOperators::op() call into one object.
+    struct KernelRegistrationConfig final {
+      KernelRegistrationConfig()
+        : dispatch_key(c10::nullopt)
+        , func()
+        , cpp_signature(c10::nullopt)
+        , inferred_function_schema(nullptr)
+      {}
+
+      c10::optional<DispatchKey> dispatch_key;
+      KernelFunction func;
+      c10::optional<impl::CppSignature> cpp_signature;
+      std::unique_ptr<FunctionSchema> inferred_function_schema;
+    };
+
+    c10::optional> schemaOrName_;
+
+    std::vector<KernelRegistrationConfig> kernels;
+    optional<AliasAnalysisKind> aliasAnalysisKind_;
+    friend class RegisterOperators;
+    friend class Library;
+  };
+
+  /**
+   * Call this to get an instance of registration options, which
+   * can be passed to a call to RegisterOperators::op() to specify
+   * these options for the operator registration.
+   * See class doc comment for examples.
+   */
+  static Options options() {
+    return {};
+  }
+
+  /**
+   * Call this to register an operator. See class doc comment for examples.
+   */
+  RegisterOperators&& op(Options&& options) && {
+    checkSchemaAndRegisterOp_(std::move(options));
+    return std::move(*this);
+  }
+
+  // Regular mutator version of the && version above
+  RegisterOperators& op(Options&& options) & {
+    checkSchemaAndRegisterOp_(std::move(options));
+    return *this;
+  }
+
+  /**
+   * This is a shorthand for RegisterOperators::op(Options) where you can
+   * specify the operator schema outside of the options parameter.
+   * See class doc comment for examples.
+   */
+  RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && {
+    return std::move(*this).op(std::move(options).schema(schemaOrName));
+  }
+
+  // internal only for registering caffe2 ops
+  RegisterOperators&& op(FunctionSchema schema, Options&& options) && {
+    return std::move(*this).op(std::move(options).schema(std::move(schema)));
+  }
+
+  template<class FuncType>
+  explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options())
+  : RegisterOperators() {
+    std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options));
+  }
+
+  /**
+   * This API registers an operator based on a kernel function pointer.
+   *
+   * Given a kernel
+   *
+   * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
+   *
+   * This API looks like:
+   *
+   * > static auto registry = c10::RegisterOperators()
+   * >     .op("my_op", &my_kernel_cpu);
+   *
+   * If your kernel is small and the overhead of calling it matters,
+   * then this API might be the wrong choice since the following API
+   * has a slightly lower overhead for calling into the kernel:
+   *
+   * > static auto registry = c10::RegisterOperators()
+   * >     .op("my_op", c10::RegisterOperators::options()
+   * >         .kernel());
+   *
+   * Or, alternatively, write your kernel as a functor:
+   *
+   * > namespace {
+   * >   class my_kernel_cpu final : public c10::OperatorKernel {
+   * >   public:
+   * >     Tensor operator()(Tensor a, Tensor b) {...}
+   * >   };
+   * > }
+   * >
+   * > static auto registry = c10::RegisterOperators()
+   * >     .op("my_op", c10::RegisterOperators::options()
+   * >         .kernel());
+   */
+   template<class FuncType>
+   // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction.
+   std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, RegisterOperators&&>
+   op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
+     constexpr bool AllowLegacyTypes = true;
+     return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
+       c10::nullopt,
+       KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
+       impl::CppSignature::make<FuncType>(),
+       // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
+       detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<FuncType*>>()
+     ));
+   }
+
+   /**
+    * This API registers an operator based on a kernel lambda.
+    *
+    * This API looks like:
+    *
+    * > static auto registry = c10::RegisterOperators()
+    * >     .op("my_op", [] (Tensor a, Tensor b) {...});
+    *
+    * This is equivalent to:
+    *
+    * > static auto registry = c10::RegisterOperators()
+    * >     .op("my_op", c10::RegisterOperators::options()
+    * >         .catchAllKernel([] (Tensor a, Tensor b) {...}));
+    *
+    */
+    template<class Lambda>
+    // enable_if: only enable it if Lambda is actually a stateless lambda
+    std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
+    op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
+      static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
+
+      constexpr bool AllowLegacyTypes = true;
+      return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
+        impl::CppSignature::make<std::decay_t<Lambda>>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
+      ));
+    }
+
+    template<class Lambda>
+    C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.")
+    // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda
+    std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
+    op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
+      static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");
+
+      constexpr bool AllowLegacyTypes = true;
+      return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
+        c10::nullopt,
+        KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
+        impl::CppSignature::make<std::decay_t<Lambda>>(),
+        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
+        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
+      ));
+    }
+
+private:
+  void checkSchemaAndRegisterOp_(Options&& config);
+
+  static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
+  void checkNoDuplicateKernels_(const Options& options);
+  void registerOp_(Options&& options);
+
+  std::vector<RegistrationHandleRAII> registrars_;
+};
+
+} // namespace c10
+
+namespace torch {
+  // Old-style API
+  using RegisterOperators = c10::RegisterOperators;
+}
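A minimal usage sketch of the options-based registration API documented in this header, assuming the schema string "my_namespace::my_op" and the kernel body, which are illustrative and not part of the file:

    #include <ATen/core/op_registration/op_registration.h>
    #include <ATen/core/Tensor.h>

    namespace {
    // A stateless kernel implemented as a plain function.
    at::Tensor my_add_cpu(const at::Tensor& a, const at::Tensor& b) {
      return a.add(b);
    }
    } // namespace

    // Register the schema and attach the CPU kernel via the runtime
    // function-pointer overload of Options::kernel().
    static auto registry = c10::RegisterOperators()
        .op(c10::RegisterOperators::options()
            .schema("my_namespace::my_op(Tensor a, Tensor b) -> Tensor")
            .kernel(c10::DispatchKey::CPU, &my_add_cpu));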
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/operator_name.h b/MLPY/Lib/site-packages/torch/include/ATen/core/operator_name.h
new file mode 100644
index 0000000000000000000000000000000000000000..83995e24f9122981968e01fba55deab23d904695
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/operator_name.h
@@ -0,0 +1,92 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/Optional.h>
+#include <c10/util/string_view.h>
+#include <cstring>
+#include <ostream>
+#include <string>
+#include <utility>
+
+namespace c10 {
+
+// TODO: consider storing namespace separately too
+struct OperatorName final {
+  std::string name;
+  std::string overload_name;
+  OperatorName(std::string name, std::string overload_name)
+      : name(std::move(name)), overload_name(std::move(overload_name)) {}
+
+  // TODO: These two functions below are slow!  Fix internal data structures so
+  // I don't have to manually reconstruct the namespaces!
+
+  // Return the namespace of this OperatorName, if it exists.  The
+  // returned string_view is only live as long as the OperatorName
+  // exists and name is not mutated
+  c10::optional<c10::string_view> getNamespace() const {
+    auto pos = name.find("::");
+    if (pos == std::string::npos) {
+      return c10::nullopt;
+    } else {
+      return c10::make_optional(c10::string_view(name.data(), pos));
+    }
+  }
+
+  // Returns true if we successfully set the namespace
+  bool setNamespaceIfNotSet(const char* ns) {
+    if (!getNamespace().has_value()) {
+      const auto ns_len = strlen(ns);
+      const auto old_name_size = name.size();
+      name.resize(ns_len + 2 + old_name_size);
+      // Shift current value of name to the end of the new space.
+      name.replace(name.size() - old_name_size, old_name_size, name, 0, old_name_size);
+      name.replace(0, ns_len, ns, ns_len);
+      name[ns_len] = ':';
+      name[ns_len + 1] = ':';
+      return true;
+    } else {
+      return false;
+    }
+  }
+};
+
+// Non-owning view of an OperatorName.  Unlike OperatorName, most of
+// its functions are constexpr, so it can be used for compile time
+// computations
+struct OperatorNameView final {
+  c10::string_view name;
+  c10::string_view overload_name;
+  constexpr OperatorNameView(c10::string_view name, c10::string_view overload_name)
+    : name(name), overload_name(overload_name) {}
+  // Parses strings like "foo.overload" and also "foo"
+  constexpr static OperatorNameView parse(c10::string_view full_name) {
+    auto i = full_name.find('.');
+    if (i == c10::string_view::npos) {
+      return OperatorNameView(full_name, c10::string_view());
+    } else {
+      return OperatorNameView(full_name.substr(0, i), full_name.substr(i + 1));
+    }
+  }
+};
+
+inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) {
+  return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name;
+}
+
+inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) {
+  return !operator==(lhs, rhs);
+}
+
+TORCH_API std::string toString(const OperatorName& opName);
+TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&);
+
+} // namespace c10
+
+namespace std {
+  template <>
+  struct hash<::c10::OperatorName> {
+    size_t operator()(const ::c10::OperatorName& x) const {
+      return std::hash<std::string>()(x.name) ^ (~ std::hash<std::string>()(x.overload_name));
+    }
+  };
+}
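A short sketch of how OperatorName is typically used; the operator and namespace names are illustrative:

    #include <ATen/core/operator_name.h>

    c10::OperatorName op("aten::add", "Tensor");   // name plus overload name
    auto ns = op.getNamespace();                   // optional string_view holding "aten"
    c10::OperatorName bare("my_op", "");
    bare.setNamespaceIfNotSet("my_namespace");     // name becomes "my_namespace::my_op"
    bool same = (op == op);                        // equality compares both fields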
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/qualified_name.h b/MLPY/Lib/site-packages/torch/include/ATen/core/qualified_name.h
new file mode 100644
index 0000000000000000000000000000000000000000..fcc5bdada9b276b2aa7745d87d47c850c95c6894
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/qualified_name.h
@@ -0,0 +1,161 @@
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <c10/util/Exception.h>
+#include <c10/util/irange.h>
+#include <string>
+#include <vector>
+
+namespace c10 {
+
+// Represents a name of the form "foo.bar.baz"
+struct QualifiedName {
+  QualifiedName() = default;
+
+  // `name` can be a dotted string, like "foo.bar.baz", or just a bare name.
+  /* implicit */ QualifiedName(const std::string& name) {
+    TORCH_CHECK(!name.empty());
+    // split the string into its atoms.
+    size_t startSearchFrom = 0;
+    size_t pos = name.find(delimiter_, startSearchFrom);
+
+    while (pos != std::string::npos) {
+      auto atom = name.substr(startSearchFrom, pos - startSearchFrom);
+      TORCH_INTERNAL_ASSERT(
+          !atom.empty(), "Invalid name for qualified name: '", name, "'");
+      atoms_.push_back(std::move(atom));
+      startSearchFrom = pos + 1;
+      pos = name.find(delimiter_, startSearchFrom);
+    }
+
+    auto finalAtom = name.substr(startSearchFrom);
+    TORCH_INTERNAL_ASSERT(
+        !finalAtom.empty(), "Invalid name for qualified name: '", name, "'");
+    atoms_.emplace_back(std::move(finalAtom));
+
+    cacheAccessors();
+  }
+
+  explicit QualifiedName(std::vector<std::string> atoms) : atoms_(std::move(atoms)) {
+    for (const auto& atom : atoms_) {
+      TORCH_CHECK(!atom.empty(), "Atom cannot be empty");
+      TORCH_CHECK(
+          atom.find(delimiter_) == std::string::npos,
+          "Delimiter not allowed in atom");
+    }
+
+    cacheAccessors();
+  }
+  // Unnecessary copy. Ideally we'd use something like std::string_view.
+  /* implicit */ QualifiedName(const char* name)
+      : QualifiedName(std::string(name)) {}
+
+  // `name` must be a bare name (no dots!)
+  explicit QualifiedName(const QualifiedName& prefix, std::string name) {
+    TORCH_INTERNAL_ASSERT(!name.empty());
+    TORCH_INTERNAL_ASSERT(name.find(delimiter_) == std::string::npos);
+    atoms_.insert(atoms_.begin(), prefix.atoms_.begin(), prefix.atoms_.end());
+    atoms_.push_back(std::move(name));
+
+    cacheAccessors();
+  }
+
+  // Is `this` a prefix of `other`?
+  // For example, "foo.bar" is a prefix of "foo.bar.baz"
+  bool isPrefixOf(const QualifiedName& other) const {
+    const auto& thisAtoms = atoms_;
+    const auto& otherAtoms = other.atoms_;
+
+    if (thisAtoms.size() > otherAtoms.size()) {
+      // Can't be a prefix if it's bigger
+      return false;
+    }
+    for (const auto i : c10::irange(thisAtoms.size())) {
+      if (thisAtoms[i] != otherAtoms[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // The fully qualified name, like "foo.bar.baz"
+  const std::string& qualifiedName() const {
+    return qualifiedName_;
+  }
+
+  // The leading qualifier, like "foo.bar"
+  const std::string& prefix() const {
+    return prefix_;
+  }
+
+  // The base name, like "baz"
+  const std::string& name() const {
+    return name_;
+  }
+
+  const std::vector<std::string>& atoms() const {
+    return atoms_;
+  }
+
+  bool operator==(const QualifiedName& other) const {
+    return this->qualifiedName_ == other.qualifiedName_;
+  }
+
+  bool operator!=(const QualifiedName& other) const {
+    return !(*this == other);
+  }
+
+ private:
+  static constexpr char delimiter_ = '.';
+
+  // Helper for cacheAccessors() below.
+  template <typename T>
+  std::string join(char delimiter, const T& v) {
+    std::string out;
+    size_t reserve = 0;
+    for (const auto& e : v) {
+      reserve += e.size() + 1;
+    }
+    out.reserve(reserve);
+    for (const auto i : c10::irange(v.size())) {
+      if (i != 0) {
+        out.push_back(delimiter);
+      }
+      out.append(v[i]);
+    }
+    return out;
+  }
+
+  void cacheAccessors() {
+    qualifiedName_ = join(delimiter_, atoms_);
+    if (atoms_.size() > 1) {
+      ArrayRef<std::string> view(atoms_);
+      const auto prefixView = view.slice(0, view.size() - 1);
+      prefix_ = join(delimiter_, prefixView);
+    }
+
+    if (!atoms_.empty()) {
+      name_ = atoms_.back();
+    }
+  }
+
+  // The actual list of names, like "{foo, bar, baz}"
+  std::vector<std::string> atoms_;
+
+  /*
+   * Cached accessors, derived from `atoms_`.
+   */
+  std::string qualifiedName_;
+  std::string prefix_;
+  std::string name_;
+};
+} // namespace c10
+
+namespace std {
+template <>
+struct hash<c10::QualifiedName> {
+  size_t operator()(const c10::QualifiedName& n) const noexcept {
+    return std::hash<std::string>()(n.qualifiedName());
+  }
+};
+} // namespace std
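A small illustration of the QualifiedName accessors defined above; the dotted names are illustrative:

    #include <ATen/core/qualified_name.h>

    c10::QualifiedName qn("foo.bar.baz");
    // qn.qualifiedName() == "foo.bar.baz"
    // qn.prefix()        == "foo.bar"
    // qn.name()          == "baz"
    c10::QualifiedName parent("foo.bar");
    bool is_prefix = parent.isPrefixOf(qn);    // true
    c10::QualifiedName child(parent, "qux");   // "foo.bar.qux"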
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/rref_interface.h b/MLPY/Lib/site-packages/torch/include/ATen/core/rref_interface.h
new file mode 100644
index 0000000000000000000000000000000000000000..c31ea40902dc8432fcabbbf8401a6d8acd8cc4e5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/rref_interface.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include 
+#include 
+
+namespace c10 {
+
+struct Type;
+using worker_id_t = int16_t;
+
+// This abstract class contains only user-facing APIs, and will be shared
+// between jit and distributed to implement TorchScript support.
+class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
+ public:
+  RRefInterface() = default;
+  // RRef is made NOT copyable NOT movable to prevent messing up reference
+  // counting.
+  RRefInterface(const RRefInterface& other) = delete;
+  RRefInterface(RRefInterface&& other) = delete;
+  RRefInterface& operator=(RRefInterface&& other) = delete;
+
+  ~RRefInterface() override = default;
+
+  // returns the worker id of the owner
+  virtual worker_id_t owner() const = 0;
+
+  // returns the worker name of the owner
+  virtual std::string ownerName() const = 0;
+
+  // Returns true if this is the ``OwnerRRef``
+  virtual bool isOwner() const = 0;
+
+  // Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been
+  // confirmed by its owner.
+  virtual bool confirmedByOwner() const = 0;
+
+  virtual const TypePtr type() const = 0;
+};
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/stack.h b/MLPY/Lib/site-packages/torch/include/ATen/core/stack.h
new file mode 100644
index 0000000000000000000000000000000000000000..6aac6f102d4ebdfff37980f79c29e688ed3901e0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/stack.h
@@ -0,0 +1,200 @@
+#pragma once
+
+#include <type_traits>
+
+#include <ATen/core/ivalue.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/irange.h>
+
+// TODO move this to c10 namespace
+
+namespace torch {
+namespace jit {
+
+using c10::IValue;
+using Stack = std::vector<c10::IValue>;
+
+class Operation {
+  template <typename F>
+  using accepts = std::is_constructible<std::function<void(Stack&)>, F&&>;
+
+ public:
+  template <typename F, std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
+  C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.")
+  Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
+    raw(&stack);
+  }) {}
+
+  template <typename F, std::enable_if_t<accepts<F, Stack&>::value &&
+                !std::is_same<std::decay_t<F>, Operation>::value, int> = 0>
+  Operation(F&& op): op_(std::forward<F>(op)) {}
+
+  Operation(std::nullptr_t) noexcept {}
+
+  explicit operator bool() const noexcept {
+    return op_ ? true : false;
+  }
+
+  void operator()(Stack& stack) {
+    op_(stack);
+  }
+
+  template <typename T>
+  T* target() noexcept {
+    return op_.target<T>();
+  }
+
+ private:
+  std::function<void(Stack&)> op_;
+};
+
+// An operation with N inputs and M outputs pops the last N inputs off
+// the stack and pushes its M outputs onto the stack
+// before:  I0, I1, ... IN <- stack.back()
+// after:  O0, O1, ... OM
+// operations are defined this way so that ownership of inputs can be
+// transferred to the operation and it can incrementally drop ownership of
+// tensors when they become unneeded. For large operations, like 'run an entire
+// subgraph', this functionality is very important for minimizing gpu memory
+// usage return value is the relative 'offset' to jump to for the next
+// operation:
+//   pc += 1 + offset
+// so a return value of 0 goes to the next instruction
+
+// treat the last N elements of the stack as a list, looking up
+// element i
+static inline IValue& peek(Stack& stack, size_t i, size_t N) {
+  return *(stack.end() - N + i);
+}
+static inline IValue& peek(Stack* stack, size_t i, size_t N) {
+  return peek(*stack, i, N);
+}
+static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
+  return *(stack.end() - N + i);
+}
+static inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
+  return peek(*stack, i, N);
+}
+// treat the last N elements of the stack as a list, looking up the
+// slice starting at index i and having length len
+static inline at::ArrayRef<IValue> peekSlice(
+    const Stack& stack,
+    size_t i,
+    size_t len,
+    size_t N) {
+  return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
+}
+static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
+  return peekSlice(stack, 0, N, N);
+}
+static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
+  return last(*stack, N);
+}
+static inline void drop(Stack& stack, size_t n) {
+  stack.erase(stack.end() - n, stack.end());
+}
+static inline void drop(Stack* stack, size_t n) {
+  drop(*stack, n);
+}
+static inline IValue pop(Stack& stack) {
+  auto r = std::move(stack.back());
+  stack.pop_back();
+  return r;
+}
+static inline IValue pop(Stack* stack) {
+  return pop(*stack);
+}
+static inline std::vector<IValue> pop(Stack& stack, size_t n) {
+  std::vector<IValue> result;
+  result.reserve(n);
+  for (const auto i : c10::irange(n)) {
+    result.push_back(std::move(peek(stack, i, n)));
+  }
+  drop(stack, n);
+  return result;
+}
+
+// variadic pop:
+// int64_t a; at::Tensor b;
+// pop(stack, a, b);
+// equivalent to:
+// b = pop(stack).toTensor();
+// a = pop(stack).toInt();
+template <typename... Types>
+static inline void pop(Stack& stack, Types&... args) {
+  size_t i = 0;
+  constexpr size_t N = sizeof...(args);
+  (void)std::initializer_list<int>{
+      (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
+  drop(stack, N);
+}
+template <typename... Types>
+static inline void pop(Stack* stack, Types&... args) {
+  pop(*stack, args...);
+}
+template <typename Type>
+static inline void push_one(Stack& stack, Type&& arg) {
+  stack.emplace_back(std::forward<Type>(arg));
+}
+
+static inline void push_one(Stack& stack, c10::TensorOptions options) {
+  stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
+  stack.emplace_back(options.layout());
+  stack.emplace_back(options.device());
+  stack.emplace_back(options.pinned_memory());
+}
+
+template <typename... Types>
+static inline void push(Stack& stack, Types&&... args) {
+  (void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
+}
+template <typename... Types>
+static inline void push(Stack* stack, Types&&... args) {
+  return push(*stack, std::forward<Types>(args)...);
+}
+template <class T>
+static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
+  for (T elem : elements) {
+    stack.push_back(std::move(elem));
+  }
+}
+
+// The packer here is carefully written not to make any unnecessary
+// copies.
+
+// pack takes the return values of aten functions pushes them onto the stack
+template <typename T>
+inline void pack(Stack& stack, T&& v) {
+  stack.emplace_back(std::forward<T>(v));
+}
+template <typename T>
+inline void pack(Stack* stack, T&& v) {
+  pack(*stack, std::forward<T>(v));
+}
+
+template <std::size_t remaining, typename... Args>
+struct TuplePacker {
+  // NB: *Not* a universal reference.
+  static void execute(Stack& stack, std::tuple<Args...>&& t) {
+    // NB: The move here does not "destroy" the entire tuple, that is
+    // not what std::move does; only the particular tuple index
+    // processed here gets stolen.
+    pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
+    TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
+  }
+};
+
+template <typename... Args>
+struct TuplePacker<0, Args...> {
+  static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/){};
+};
+
+template <typename... Args>
+inline void pack(Stack& stack, std::tuple<Args...>&& t) {
+  TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
+}
+
+} // namespace jit
+} // namespace torch
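A minimal sketch of how a boxed kernel might use the stack helpers above; the kernel name and tensor arithmetic are illustrative:

    #include <ATen/core/stack.h>

    void boxed_add(torch::jit::Stack& stack) {
      using torch::jit::pop;
      using torch::jit::push;
      // The callee owns the top N inputs; pop them in reverse declaration order.
      at::Tensor b = pop(stack).toTensor();
      at::Tensor a = pop(stack).toTensor();
      push(stack, a.add(b));   // push the single output
    }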
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/symbol.h b/MLPY/Lib/site-packages/torch/include/ATen/core/symbol.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e23098d098259d6e914cab655cc1f3805ef4753
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/symbol.h
@@ -0,0 +1,147 @@
+#pragma once
+#include <c10/macros/Macros.h>
+#include <cstdint>
+#include <functional>  // For std::hash
+#include <string>
+
+
+namespace c10 {
+
+// 'prim' symbols are synthetic operators that occur only in the IR
+// and don't have corresponding implementations in ATen.
+
+// 'onnx' symbols correspond to ONNX operators.  Their semantics
+// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
+// The particular version we are targeting is specified by '_onnx_opset_version'
+// in torch.onnx.symbolic_helper
+//
+// In general, most ONNX operators won't get an entry here, because they
+// are handled from the Python end.  However, you may occasionally need
+// to intern an ONNX symbol here so that you can conveniently write an
+// optimization on ONNX operations.
+
+// 'attr' symbols are attribute keys.  They are shared between both ONNX and ATen
+// operators (you disambiguate their meaning by looking at the operator itself).
+// In general, you only need to define attribute keys that are used by
+// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
+
+// Note [Symbol allocation]
+// ~~~~~~~~~~~~~~~~~~~~~~~~
+//
+//  1. Symbol namespace is split up into namespaces.
+//
+//  2. The intended access pattern for built-in symbols is onnx::MatMul
+//  in the c10 namespace (this is a Symbol).
+//
+
+// Built-in constant definition strategy:
+// - Enum is the most convenient way to generate a contiguous sequence
+//   of numbers for an identifier.
+// - However, an enum gives you a fresh type.  We want onnx::MatMul to
+//   be type Symbol, not some random enum type!
+// - Therefore, after using enums to generate the sequence of integers,
+//   we then declare constexpr Symbols to get everything the actual Symbol
+//   type we want.  Symbols must be constexpr to be valid to be "case"ed on.
+
+using unique_t = uint32_t;
+
+const std::string& domain_prefix();
+
+// A Symbol is like an interned string, but with a little extra
+// structure; it is namespaced via SymbolNamespace and the resulting
+// intern pointers support efficient namespace testing.
+struct TORCH_API Symbol {
+  explicit constexpr Symbol() : value(0) {};
+  explicit constexpr Symbol(unique_t uniq)
+  : value(uniq) {}
+
+  // Get a Symbol for a qualified string like "attr::bar"
+  static Symbol fromQualString(const std::string & s);
+
+  // Get a Symbol from a domain and an unqualified string like "org.pytorch.attr" and "bar"
+  static Symbol fromDomainAndUnqualString(const std::string & d, const std::string & s);
+
+  // Constructors for our various namespaced strings.  This will construct
+  // the appropriate namespaced string, e.g., "attr::foo" for the
+  // argument "foo", and then attempt to intern it.  DO NOT USE THIS
+  // with a string literal; attr::foo should be available in that case
+  // (and if it's not, you should add it to the built-ins list above.)
+  static Symbol attr(const std::string & s);
+  static Symbol aten(const std::string & s);
+  static Symbol cuda(const std::string & s);
+  static Symbol onnx(const std::string & s);
+  static Symbol prim(const std::string & s);
+  static Symbol user(const std::string & s);
+  static Symbol caffe2(const std::string & s);
+  static Symbol dimname(const std::string & s);
+  // TODO: eliminate me
+  static Symbol scope(const std::string & s);
+
+  bool is_attr() const;
+  bool is_aten() const;
+  bool is_cuda() const;
+  bool is_prim() const;
+  bool is_prims() const;
+  bool is_nvprims() const;
+  bool is_onnx() const;
+  bool is_user() const;
+  bool is_caffe2() const;
+  bool is_dimname() const;
+
+  // So we can switch on this
+  constexpr operator unique_t() const {
+    return value;
+  }
+
+  Symbol ns() const;
+
+  // Give a string corresponding to the unqualified version of this name, e.g.,
+  // "mm". Use this in a context where the intended namespace of the string is
+  // obvious; this is a *lossy* conversion.
+  const char * toUnqualString() const;
+
+  // Give a string corresponding to the qualified version of this name,
+  // e.g., "aten::mm".  This string format is made available to Python bindings
+  // (so we know how to parse it.)
+  const char * toQualString() const;
+
+  // This describes a symbol in a case where humans read it.  At the moment it's
+  // the same as toQualString.  This has to be a const char* returned because
+  // a lot of printf style macros use it.
+  const char * toDisplayString() const;
+
+  // Give a string corresponding to the domain name for the symbol,
+  // e.g., "org.pytorch.aten".
+  std::string domainString() const;
+
+private:
+
+  explicit Symbol(Symbol ns, const std::string & s);
+  unique_t value;
+};
+
+static inline bool operator==(Symbol lhs, Symbol rhs) {
+  return static_cast<unique_t>(lhs) == static_cast<unique_t>(rhs);
+}
+
+inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
+inline Symbol Symbol::aten(const std::string & s)  { return Symbol::fromQualString("aten::" + s); }
+inline Symbol Symbol::cuda(const std::string & s)  { return Symbol::fromQualString("cuda::" + s); }
+inline Symbol Symbol::onnx(const std::string & s)  { return Symbol::fromQualString("onnx::" + s); }
+inline Symbol Symbol::prim(const std::string & s)  { return Symbol::fromQualString("prim::" + s); }
+inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
+inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
+inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); }
+inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
+
+} // namespace c10
+
+// make symbol behave like an integer in hash tables
+namespace std {
+template <>
+struct hash<c10::Symbol> {
+  size_t operator()(c10::Symbol s) const {
+    return std::hash<uint32_t>()(static_cast<uint32_t>(s));
+  }
+};
+}
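A brief sketch of interning and querying a Symbol; the symbol names are illustrative:

    #include <ATen/core/symbol.h>

    c10::Symbol mm = c10::Symbol::fromQualString("aten::mm");
    bool is_aten = mm.is_aten();                    // true
    const char* unqual = mm.toUnqualString();       // "mm"
    const char* qual = mm.toQualString();           // "aten::mm"
    c10::Symbol attr = c10::Symbol::attr("value");  // same as fromQualString("attr::value")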
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/type_factory.h b/MLPY/Lib/site-packages/torch/include/ATen/core/type_factory.h
new file mode 100644
index 0000000000000000000000000000000000000000..771cc65c43b0431c8a43ed8595980fb1d9d819fd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/type_factory.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include <type_traits>
+#include <unordered_map>
+
+#include <ATen/core/dynamic_type.h>
+#include <ATen/core/jit_type_base.h>
+#include <c10/macros/Macros.h>
+
+namespace c10 {
+
+template <typename T>
+struct TORCH_API TypeFactoryBase {};
+
+template <>
+struct TORCH_API TypeFactoryBase<c10::DynamicType> {
+  template <typename T, typename... Args>
+  static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
+    return std::make_shared<c10::DynamicType>(
+        c10::DynamicTypeTrait<T>::tagValue(),
+        c10::DynamicType::Arguments(c10::ArrayRef<c10::TypePtr>(
+            {std::move(ty), std::forward<Args>(args)...})));
+  }
+  template <typename T>
+  static c10::DynamicTypePtr create(const std::vector<c10::TypePtr>& types) {
+    return std::make_shared<c10::DynamicType>(
+        c10::DynamicTypeTrait<T>::tagValue(),
+        c10::DynamicType::Arguments(types));
+  }
+  static c10::DynamicTypePtr createNamedTuple(
+      const std::string& name,
+      const std::vector& fields,
+      const std::vector& types) {
+    return std::make_shared(
+        c10::DynamicType::Tag::Tuple,
+        name,
+        c10::DynamicType::Arguments(fields, types));
+  }
+  template <typename T>
+  C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
+    return std::make_shared<c10::DynamicType>(
+        c10::DynamicTypeTrait<T>::tagValue(),
+        name,
+        c10::DynamicType::Arguments{});
+  }
+  template <typename T>
+  C10_ERASE static c10::DynamicTypePtr get() {
+    return DynamicTypeTrait<T>::getBaseType();
+  }
+  static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
+};
+
+using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
+
+// Helper functions for constructing DynamicTypes inline.
+template <
+    typename T,
+    std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
+C10_ERASE DynamicTypePtr dynT() {
+  return DynamicTypeFactory::get<T>();
+}
+
+template <
+    typename T,
+    typename... Args,
+    std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
+C10_ERASE DynamicTypePtr dynT(Args&&... args) {
+  return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
+}
+
+template <>
+struct TORCH_API TypeFactoryBase<c10::Type> {
+  template <typename T, typename... Args>
+  static c10::TypePtr create(TypePtr ty, Args&&... args) {
+    return T::create(std::move(ty), std::forward<Args>(args)...);
+  }
+  template <typename T>
+  static c10::TypePtr create(std::vector<c10::TypePtr> types) {
+    return T::create(std::move(types));
+  }
+  static c10::TypePtr createNamedTuple(
+      const std::string& name,
+      const std::vector& fields,
+      const std::vector& types);
+  template <typename T>
+  C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
+    return T::create(name);
+  }
+  static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
+  template <typename T>
+  C10_ERASE static c10::TypePtr get() {
+    return T::get();
+  }
+};
+
+using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
+
+using PlatformType =
+#ifdef C10_MOBILE
+    c10::DynamicType
+#else
+    c10::Type
+#endif
+    ;
+
+using TypeFactory = TypeFactoryBase<PlatformType>;
+
+} // namespace c10
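A hedged sketch of how the factory is meant to be used. Whether TypeFactory resolves to the dynamic or the full Type implementation depends on C10_MOBILE as shown above; the tuple name and field names are illustrative, and availability of c10::IntType via ATen/core/jit_type.h is assumed:

    #include <ATen/core/jit_type.h>
    #include <ATen/core/type_factory.h>

    auto point_type = c10::TypeFactory::createNamedTuple(
        "my.Point",
        {"x", "y"},
        {c10::TypeFactory::get<c10::IntType>(),
         c10::TypeFactory::get<c10::IntType>()});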
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/type_ptr.h b/MLPY/Lib/site-packages/torch/include/ATen/core/type_ptr.h
new file mode 100644
index 0000000000000000000000000000000000000000..7b183d4249201bcb7359bdb0156b9669525280b2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/type_ptr.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+
+namespace c10 {
+
+// Compatibility wrapper around a raw pointer so that existing code
+// written to deal with a shared_ptr can keep working.
+template <typename T>
+class SingletonTypePtr {
+ public:
+  /* implicit */ SingletonTypePtr(T* p) : repr_(p) {}
+
+  // We need this to satisfy Pybind11, but it shouldn't be hit.
+  explicit SingletonTypePtr(std::shared_ptr<T>) { TORCH_CHECK(false); }
+
+  using element_type = typename std::shared_ptr<T>::element_type;
+
+  template <typename U = T, std::enable_if_t<!std::is_same<std::remove_const_t<U>, void>::value, bool> = true>
+  T& operator*() const {
+    return *repr_;
+  }
+
+  T* get() const {
+    return repr_;
+  }
+
+  T* operator->() const {
+    return repr_;
+  }
+
+  operator bool() const {
+    return repr_ != nullptr;
+  }
+
+ private:
+  T* repr_{nullptr};
+};
+
+template <typename T, typename U>
+bool operator==(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) {
+  return (void*)lhs.get() == (void*)rhs.get();
+}
+
+template <typename T, typename U>
+bool operator!=(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) {
+  return !(lhs == rhs);
+}
+
+} // namespace c10
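A small sketch of the wrapper's intent: code written against shared_ptr-like type pointers keeps working when handed a raw pointer to a singleton. The MyType struct here is illustrative, not a c10 type:

    #include <ATen/core/type_ptr.h>

    struct MyType { int payload() const { return 42; } };
    static MyType my_singleton;

    c10::SingletonTypePtr<MyType> p(&my_singleton);
    int v = p->payload();                 // operator-> forwards to the raw pointer
    bool non_null = static_cast<bool>(p); // true while the pointer is non-null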
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/core/typeid.h b/MLPY/Lib/site-packages/torch/include/ATen/core/typeid.h
new file mode 100644
index 0000000000000000000000000000000000000000..d69eba920abb0059a113405faf0264cd5a9b7bab
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/core/typeid.h
@@ -0,0 +1 @@
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h b/MLPY/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h
new file mode 100644
index 0000000000000000000000000000000000000000..e9e4e3e677d16b3001188f678ef2b985319b8405
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h
@@ -0,0 +1,110 @@
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+
+// YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
+
+// This code was a temporary hack to enable embedding arbitrary C++ structures
+// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
+// IT __WILL__ BREAK.
+
+// This code has been superseded by custom classes:
+// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
+
+// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
+// IN THIS FILE**.
+
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
+
+#include <ATen/TracerMode.h>
+#include <ATen/core/Tensor.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/empty.h>
+#endif
+
+namespace at::cpp_custom_type_hack {
+
+template <typename T>
+[[deprecated(
+    "Use custom classes instead: "
+    "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool
+isa(const Tensor& packed) {
+  return (packed.scalar_type() == kByte) &&
+      (packed.storage().data_ptr().get_deleter() ==
+       caffe2::TypeMeta::Make<T>().deleteFn());
+}
+
+template <typename T>
+[[deprecated(
+    "Use custom classes instead: "
+    "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T&
+cast(const Tensor& packed) {
+  TORCH_CHECK(
+      packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
+  TORCH_CHECK(
+      packed.storage().data_ptr().get_deleter() ==
+          caffe2::TypeMeta::Make<T>().deleteFn(),
+      "Expected temporary cpp type wrapper of type ",
+      caffe2::TypeMeta::TypeName<T>());
+  return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
+}
+
+template <typename T>
+[[deprecated(
+    "Use custom classes instead: "
+    "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor
+create(std::unique_ptr<T> ptr, TensorOptions options) {
+  // None of this should trace, so turn off Tracer dispatching
+  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
+  at::tracer::impl::NoTracerDispatchMode tracer_guard;
+
+  // We store this instance away in a Tensor and register a deleter function
+  // so that we do not leak memory. On the other side, we pull out the storage's
+  // data_ptr and get the right typed pointer.
+  void* raw_ptr = ptr.release();
+  at::DataPtr at_ptr(
+      raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);
+
+  // size doesn't really matter, but we can align it to the actual size
+  // returning variables because one likely want to use this hack from python
+  auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
+  retval.storage().set_data_ptr_noswap(std::move(at_ptr));
+  return retval;
+}
+
+} // namespace at::cpp_custom_type_hack
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h
new file mode 100644
index 0000000000000000000000000000000000000000..0d7b4b9cc679c93d48f3b1be053a7ff9fb004128
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/FlushDenormal.h
@@ -0,0 +1,14 @@
+/// Flush-To-Zero and Denormals-Are-Zero mode
+///
+/// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass
+/// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64
+/// and some x86 CPUs. They result in reduced precision for values near zero,
+/// but increased performance.
+///
+/// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz
+
+namespace at::cpu {
+
+bool set_flush_denormal(bool on);
+
+}  // namespace at::cpu
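A minimal usage sketch of the function declared above; as I understand it, the return value indicates whether the FTZ/DAZ mode could actually be set on this CPU:

    #include <ATen/cpu/FlushDenormal.h>

    bool enabled = at::cpu::set_flush_denormal(true);  // enable FTZ/DAZ if supported
    if (!enabled) {
      // fall back: denormals keep full IEEE 754 behavior on this platform
    }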
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/Utils.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/Utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab06ce7d18dab2ed20a64b2105beb451dccbf189
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/Utils.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+
+namespace at::cpu {
+
+// Detect whether the CPU supports Vector Neural Network Instructions (VNNI).
+TORCH_API bool is_cpu_support_vnni();
+
+} // namespace at::cpu
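A minimal usage sketch of the detection helper declared above:

    #include <ATen/cpu/Utils.h>

    if (at::cpu::is_cpu_support_vnni()) {
      // e.g. choose a VNNI-accelerated int8 code path
    }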
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h
new file mode 100644
index 0000000000000000000000000000000000000000..032e9bfa471391b3a38e56dedd04c7a881a241f2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include <ATen/cpu/vec/functional_base.h>
+#include <ATen/cpu/vec/functional_bfloat16.h>
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..801685e23cfa5eff3a3b22bf7e77af3f033e9311
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_base.h
@@ -0,0 +1,329 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/vec.h>
+#include <c10/util/irange.h>
+
+namespace at::vec {
+
+// slow path
+template <typename scalar_t, typename Op>
+inline scalar_t vec_reduce_all(
+    const Op& vec_fun,
+    vec::Vectorized<scalar_t> acc_vec,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  scalar_t acc_arr[Vec::size()];
+  acc_vec.store(acc_arr);
+  for (const auto i : c10::irange(1, size)) {
+    std::array<scalar_t, Vec::size()> acc_arr_next = {0};
+    acc_arr_next[0] = acc_arr[i];
+    Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
+    acc_vec = vec_fun(acc_vec, acc_vec_next);
+  }
+  acc_vec.store(acc_arr);
+  return acc_arr[0];
+}
+
+template <typename scalar_t, typename Op>
+struct VecReduceAllSIMD {
+  static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
+    return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
+  }
+};
+
+#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+#if defined(CPU_CAPABILITY_AVX2)
+template <typename Op>
+struct VecReduceAllSIMD<float, Op> {
+  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+    using Vec = Vectorized<float>;
+    Vec v = acc_vec;
+    // 128-bit shuffle
+    Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
+    v = vec_fun(v, v1);
+    // 64-bit shuffle
+    v1 = _mm256_shuffle_ps(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 32-bit shuffle
+    v1 = _mm256_shuffle_ps(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    return _mm256_cvtss_f32(v);
+  }
+};
+#endif // defined(CPU_CAPABILITY_AVX2)
+#if defined(CPU_CAPABILITY_AVX512)
+template <typename Op>
+struct VecReduceAllSIMD<float, Op> {
+  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
+    using Vec = Vectorized<float>;
+    Vec v = acc_vec;
+    // 256-bit shuffle
+    Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 128-bit shuffle
+    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    // 64-bit shuffle
+    v1 = _mm512_shuffle_ps(v, v, 0x4E);
+    v = vec_fun(v, v1);
+    // 32-bit shuffle
+    v1 = _mm512_shuffle_ps(v, v, 0xB1);
+    v = vec_fun(v, v1);
+    return _mm512_cvtss_f32(v);
+  }
+};
+#endif // defined(CPU_CAPABILITY_AVX512)
+#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
+
+template <typename scalar_t, typename Op>
+inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
+  return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  if (size < Vec::size())
+    return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
+  int64_t d = Vec::size();
+  Vec acc_vec = Vec::loadu(data);
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    acc_vec = vec_fun(acc_vec, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
+  }
+  return vec_reduce_all(vec_fun, acc_vec);
+}
+
+// similar to reduce_all, but reduces into two outputs
+template , int> = 0>
+inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+    const scalar_t* data, int64_t size) {
+  using Vec = vec::Vectorized;
+  if (size < Vec::size()) {
+    auto loaded_data = Vec::loadu(data, size);
+    return std::pair(
+      vec_reduce_all(vec_fun1, loaded_data, size),
+      vec_reduce_all(vec_fun2, loaded_data, size));
+  }
+  int64_t d = Vec::size();
+  Vec acc_vec1 = Vec::loadu(data);
+  Vec acc_vec2 = Vec::loadu(data);
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    acc_vec1 = vec_fun1(acc_vec1, data_vec);
+    acc_vec2 = vec_fun2(acc_vec2, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
+    acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
+  }
+  return std::pair(
+    vec_reduce_all(vec_fun1, acc_vec1),
+    vec_reduce_all(vec_fun2, acc_vec2));
+}
+
+template , int> = 0>
+inline scalar_t map_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    int64_t size) {
+  using Vec = vec::Vectorized;
+  if (size < Vec::size())
+    return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
+  int64_t d = Vec::size();
+  Vec acc_vec = map_fun(Vec::loadu(data));
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    data_vec = map_fun(data_vec);
+    acc_vec = red_fun(acc_vec, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    data_vec = map_fun(data_vec);
+    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
+  }
+  return vec_reduce_all(red_fun, acc_vec);
+}
+
+template , int> = 0>
+inline scalar_t map2_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    const scalar_t* data2,
+    int64_t size) {
+  using Vec = vec::Vectorized;
+  if (size < Vec::size()) {
+    Vec data_vec = Vec::loadu(data, size);
+    Vec data2_vec = Vec::loadu(data2, size);
+    data_vec = map_fun(data_vec, data2_vec);
+    return vec_reduce_all(red_fun, data_vec, size);
+  }
+  int64_t d = Vec::size();
+  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    Vec data2_vec = Vec::loadu(data2 + d);
+    data_vec = map_fun(data_vec, data2_vec);
+    acc_vec = red_fun(acc_vec, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    Vec data2_vec = Vec::loadu(data2 + d, size - d);
+    data_vec = map_fun(data_vec, data2_vec);
+    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
+  }
+  return vec_reduce_all(red_fun, acc_vec);
+}
+
+template , int> = 0>
+inline scalar_t map3_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    const scalar_t* data2,
+    const scalar_t* data3,
+    int64_t size) {
+  using Vec = vec::Vectorized;
+  if (size < Vec::size()) {
+    Vec data_vec = Vec::loadu(data, size);
+    Vec data2_vec = Vec::loadu(data2, size);
+    Vec data3_vec = Vec::loadu(data3, size);
+    data_vec = map_fun(data_vec, data2_vec, data3_vec);
+    return vec_reduce_all(red_fun, data_vec, size);
+  }
+
+  int64_t d = Vec::size();
+  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(data + d);
+    Vec data2_vec = Vec::loadu(data2 + d);
+    Vec data3_vec = Vec::loadu(data3 + d);
+    data_vec = map_fun(data_vec, data2_vec, data3_vec);
+    acc_vec = red_fun(acc_vec, data_vec);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(data + d, size - d);
+    Vec data2_vec = Vec::loadu(data2 + d, size - d);
+    Vec data3_vec = Vec::loadu(data3 + d, size - d);
+    data_vec = map_fun(data_vec, data2_vec, data3_vec);
+    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
+  }
+  return vec_reduce_all(red_fun, acc_vec);
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec output_vec = vec_fun(Vec::loadu(input_data + d));
+    output_vec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
+    output_vec.store(output_data + d, size - d);
+  }
+}
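+
+// Illustrative usage sketch (editor's note, not part of the upstream header):
+// the map()/reduce() helpers in this file are driven with a lambda over
+// Vectorized values; buffer names x, y and the length N below are hypothetical.
+//
+//   std::vector<float> x(N), y(N);
+//   // y[i] = x[i] * x[i]; the scalar tail is handled by map() itself
+//   at::vec::map([](at::vec::Vectorized<float> v) { return v * v; },
+//                y.data(), x.data(), N);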
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map2(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data,
+    const scalar_t* input_data2,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec = Vec::loadu(input_data + d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d);
+    Vec output_vec = vec_fun(data_vec, data_vec2);
+    output_vec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    Vec data_vec = Vec::loadu(input_data + d, size - d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
+    Vec output_vec = vec_fun(data_vec, data_vec2);
+    output_vec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map3(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data1,
+    const scalar_t* input_data2,
+    const scalar_t* input_data3,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec1 = Vec::loadu(input_data1 + d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d);
+    Vec data_vec3 = Vec::loadu(input_data3 + d);
+    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
+    output_vec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
+    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
+    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
+    output_vec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map4(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data1,
+    const scalar_t* input_data2,
+    const scalar_t* input_data3,
+    const scalar_t* input_data4,
+    int64_t size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t d = 0;
+  for (; d < size - (size % Vec::size()); d += Vec::size()) {
+    Vec data_vec1 = Vec::loadu(input_data1 + d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d);
+    Vec data_vec3 = Vec::loadu(input_data3 + d);
+    Vec data_vec4 = Vec::loadu(input_data4 + d);
+    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
+    output_vec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
+    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
+    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
+    Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
+    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
+    output_vec.store(output_data + d, size - d);
+  }
+}
+
+} // namespace at::vec
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h
new file mode 100644
index 0000000000000000000000000000000000000000..e11401955706cc082d1b8a505c1582d156975957
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h
@@ -0,0 +1,549 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/vec.h>
+
+namespace at::vec {
+
+// BFloat16 specification
+template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
+template <> struct VecScalarType<BFloat16> { using type = float; };
+template <> struct VecScalarType<Half> { using type = float; };
+
+// This is different from at::acc_type since we only need to specialize BFloat16
+template <typename scalar_t>
+using vec_scalar_t = typename VecScalarType<scalar_t>::type;
+
+// Vector conversion between float and bfloat16/half
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(const Vectorized<scalar_t>&);
+
+template <>
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<BFloat16> (const Vectorized<BFloat16>& a) {
+  return convert_bfloat16_float(a);
+}
+
+template <>
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half> (const Vectorized<Half>& a) {
+    return convert_half_float(a);
+}
+
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline Vectorized<scalar_t> convert_from_float(const Vectorized<float>&, const Vectorized<float>&);
+
+template <>
+inline Vectorized<BFloat16> convert_from_float<BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return convert_float_bfloat16(a, b);
+}
+
+template <>
+inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return convert_float_half(a, b);
+}
+
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void load_to_float(const scalar_t *data, Vectorized<float> &out1, Vectorized<float> &out2);
+
+template <>
+inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out1, Vectorized<float> &out2) {
+  load_fp32_from_bf16(data, out1, out2);
+}
+
+template <>
+inline void load_to_float<Half> (const Half *data, Vectorized<float> &out1, Vectorized<float> &out2) {
+  load_fp32_from_fp16(data, out1, out2);
+}
+
+template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void load_to_float(const scalar_t *data, Vectorized<float> &out);
+
+template <>
+inline void load_to_float<BFloat16> (const BFloat16 *data, Vectorized<float> &out) {
+  load_fp32_from_bf16(data, out);
+}
+
+template <>
+inline void load_to_float<Half> (const Half *data, Vectorized<float> &out) {
+  load_fp32_from_fp16(data, out);
+}
+
+// Note that we already have specialized member of Vectorized<scalar_t> for BFloat16
+// so the following functions would run smoothly:
+//   using Vec = Vectorized<BFloat16>;
+//   Vec one = Vec(BFloat16(1));
+//   vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
+//
+// Then why we still need to specialize "functional"?
+//   If we do specialization at Vectorized<> level, the above example would need 3 pairs of
+//   conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/".
+//   If we do specialization at vec::map<>() level, we have only 1 pair of conversion
+//   of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only.
+//
+// The following BFloat16 functionality will only do data type conversion for input
+// and output vector (reduce functionality will only convert the final scalar back to bf16).
+// Compared to Vectorized<> specialization,
+//   1. better performance since we have less data type conversion;
+//   2. less rounding error since immediate results are kept in fp32;
+//   3. accumulation done on data type of fp32.
+//
+//  If you plan to extend this file, please ensure adding unit tests at
+//    aten/src/ATen/test/vec_test_all_types.cpp
+//
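+// Illustrative usage sketch (editor's note, not part of the upstream header):
+// reducing a BFloat16 buffer while accumulating in fp32, as described above.
+// The buffer name x and length N are hypothetical.
+//
+//   std::vector<c10::BFloat16> x(N);
+//   float max_val = at::vec::reduce_all(
+//       [](auto& a, auto& b) { return at::vec::maximum(a, b); }, x.data(), N);
+//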
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  if (size < bVec::size()) {
+    bVec data_bvec = bVec::loadu(data, size);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size > fVec::size()) {
+      data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
+      return vec_reduce_all(vec_fun, data_fvec0, fVec::size());
+    } else {
+      return vec_reduce_all(vec_fun, data_fvec0, size);
+    }
+  }
+  int64_t d = bVec::size();
+  bVec acc_bvec = bVec::loadu(data);
+  auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec);
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
+    acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size - d > fVec::size()) {
+      acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
+      acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+    } else {
+      acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
+    }
+  }
+  acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
+  return vec_reduce_all(vec_fun, acc_fvec0);
+}
+
+template <typename scalar_t, typename Op1, typename Op2, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+    const scalar_t* data, int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  if (size < bVec::size()) {
+    bVec data_bvec = bVec::loadu(data, size);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size > fVec::size()) {
+      fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
+      fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
+      return std::pair<float, float>(
+          vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()),
+          vec_reduce_all(vec_fun2, acc2_fvec, fVec::size()));
+    } else {
+      return std::pair<float, float>(
+          vec_reduce_all(vec_fun1, data_fvec0, size),
+          vec_reduce_all(vec_fun2, data_fvec0, size));
+    }
+  }
+  int64_t d = bVec::size();
+  bVec acc_bvec = bVec::loadu(data);
+  auto [acc1_fvec0, acc1_fvec1] = convert_to_float(acc_bvec);
+  auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc_bvec);
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
+    acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
+    acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
+    acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size - d > fVec::size()) {
+      acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
+      acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
+      acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
+      acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
+    } else {
+      acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
+      acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
+    }
+  }
+  acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
+  acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
+  return std::pair<float, float>(
+      vec_reduce_all(vec_fun1, acc1_fvec0),
+      vec_reduce_all(vec_fun2, acc2_fvec0));
+}
+
+template <typename scalar_t, typename MapOp, typename ReduceOp, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline float map_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  if (size < bVec::size()) {
+    bVec data_bvec = bVec::loadu(data, size);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0);
+      data_fvec1 = map_fun(data_fvec1);
+      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      return vec_reduce_all(red_fun, data_fvec0, fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0);
+      return vec_reduce_all(red_fun, data_fvec0, size);
+    }
+  }
+  int64_t d = bVec::size();
+  bVec acc_bvec = bVec::loadu(data);
+  auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec);
+  acc_fvec0 = map_fun(acc_fvec0);
+  acc_fvec1 = map_fun(acc_fvec1);
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    data_fvec0 = map_fun(data_fvec0);
+    data_fvec1 = map_fun(data_fvec1);
+    acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+    acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    if (size - d > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0);
+      data_fvec1 = map_fun(data_fvec1);
+      acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0);
+      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+    }
+  }
+  acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
+  return vec_reduce_all(red_fun, acc_fvec0);
+}
+
+template <typename scalar_t, typename MapOp, typename ReduceOp, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline float map2_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    const scalar_t* data2,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  if (size < bVec::size()) {
+    bVec data_bvec = bVec::loadu(data, size);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2, size);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    if (size > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0);
+      data_fvec1 = map_fun(data_fvec1, data2_fvec1);
+      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      return vec_reduce_all(red_fun, data_fvec0, fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0);
+      return vec_reduce_all(red_fun, data_fvec0, size);
+    }
+  }
+  int64_t d = bVec::size();
+  bVec acc_bvec = bVec::loadu(data);
+  auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec);
+  bVec acc2_bvec = bVec::loadu(data2);
+  auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec);
+  acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
+  acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2 + d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    data_fvec0 = map_fun(data_fvec0, data2_fvec0);
+    data_fvec1 = map_fun(data_fvec1, data2_fvec1);
+    acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+    acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2 + d, size - d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    if (size - d > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0);
+      data_fvec1 = map_fun(data_fvec1, data2_fvec1);
+      acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0);
+      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+    }
+  }
+  acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
+  return vec_reduce_all(red_fun, acc_fvec0);
+}
+
+template <typename scalar_t, typename MapOp, typename ReduceOp, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline float map3_reduce_all(
+    const MapOp& map_fun,
+    const ReduceOp& red_fun,
+    const scalar_t* data,
+    const scalar_t* data2,
+    const scalar_t* data3,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  if (size < bVec::size()) {
+    bVec data_bvec = bVec::loadu(data, size);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2, size);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(data3, size);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    if (size > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
+      data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
+      data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
+      return vec_reduce_all(red_fun, data_fvec0, fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
+      return vec_reduce_all(red_fun, data_fvec0, size);
+    }
+  }
+  int64_t d = bVec::size();
+  bVec acc_bvec = bVec::loadu(data);
+  auto [acc_fvec0, acc_fvec1] = convert_to_float(acc_bvec);
+  bVec acc2_bvec = bVec::loadu(data2);
+  auto [acc2_fvec0, acc2_fvec1] = convert_to_float(acc2_bvec);
+  bVec acc3_bvec = bVec::loadu(data3);
+  auto [acc3_fvec0, acc3_fvec1] = convert_to_float(acc3_bvec);
+  acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
+  acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2 + d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(data3 + d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
+    data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
+    acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+    acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(data2 + d, size - d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(data3 + d, size - d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    if (size - d > fVec::size()) {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
+      data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
+      acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
+      acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
+    } else {
+      data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
+      acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
+    }
+  }
+  acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
+  return vec_reduce_all(red_fun, acc_fvec0);
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(input_data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    fVec output_fvec0 = vec_fun(data_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(input_data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    fVec output_fvec0 = vec_fun(data_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const float* input_data,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    fVec data_fvec0 = fVec::loadu(input_data + d);
+    fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size());
+    fVec output_fvec0 = vec_fun(data_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    fVec data_fvec0, data_fvec1;
+    if (size - d > fVec::size()) {
+      data_fvec0 = fVec::loadu(input_data + d);
+      data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
+    } else {
+      // choose to align with behaviour of bVec::loadu(ptr, size),
+      // which leaves data_fvec1 uninitialized
+      data_fvec0 = fVec::loadu(input_data + d, size - d);
+    }
+    fVec output_fvec0 = vec_fun(data_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map2(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data,
+    const scalar_t* input_data2,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data_bvec = bVec::loadu(input_data + d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    bVec data_bvec = bVec::loadu(input_data + d, size - d);
+    auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
+    fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map3(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data1,
+    const scalar_t* input_data2,
+    const scalar_t* input_data3,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data1_bvec = bVec::loadu(input_data1 + d);
+    auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(input_data3 + d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
+    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
+    auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
+    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d, size - d);
+  }
+}
+
+template <typename scalar_t, typename Op, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline void map4(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* input_data1,
+    const scalar_t* input_data2,
+    const scalar_t* input_data3,
+    const scalar_t* input_data4,
+    int64_t size) {
+  using bVec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<float>;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec data1_bvec = bVec::loadu(input_data1 + d);
+    auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(input_data3 + d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    bVec data4_bvec = bVec::loadu(input_data4 + d);
+    auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec);
+    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
+    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d);
+  }
+  if (size - d > 0) {
+    bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
+    auto [data1_fvec0, data1_fvec1] = convert_to_float(data1_bvec);
+    bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
+    auto [data2_fvec0, data2_fvec1] = convert_to_float(data2_bvec);
+    bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
+    auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec);
+    bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
+    auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec);
+    fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
+    fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
+    bVec output_bvec = convert_from_float(output_fvec0, output_fvec1);
+    output_bvec.store(output_data + d, size - d);
+  }
+}
+
+} // namespace at::vec
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h
new file mode 100644
index 0000000000000000000000000000000000000000..054f457b7e006cff43c622982fe9885d17869a50
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/intrinsics.h
@@ -0,0 +1,43 @@
+#pragma once
+#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+/* GCC or clang-compatible compiler, targeting x86/x86-64 */
+#include <x86intrin.h>
+#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
+/* Clang-compatible compiler, targeting arm neon */
+#include <arm_neon.h>
+#elif defined(_MSC_VER)
+/* Microsoft C/C++-compatible compiler */
+#include <intrin.h>
+#if _MSC_VER <= 1900
+#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
+#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
+#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
+#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
+#endif
+#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
+/* GCC-compatible compiler, targeting ARM with NEON */
+#include <arm_neon.h>
+#if defined (MISSING_ARM_VLD1)
+#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
+#elif defined (MISSING_ARM_VST1)
+#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
+#endif
+#elif defined(__GNUC__) && defined(__IWMMXT__)
+/* GCC-compatible compiler, targeting ARM with WMMX */
+#include <mmintrin.h>
+#elif defined(__s390x__)
+// targets Z/architecture
+// we will include vecintrin later
+#elif (defined(__GNUC__) || defined(__xlC__)) &&                               \
+        (defined(__VEC__) || defined(__ALTIVEC__))
+/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
+#include <altivec.h>
+/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
+   with the C++ types. => Can still use __bool/__vector */
+#undef bool
+#undef vector
+#undef pixel
+#elif defined(__GNUC__) && defined(__SPE__)
+/* GCC-compatible compiler, targeting PowerPC with SPE */
+#include <spe.h>
+#endif
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h
new file mode 100644
index 0000000000000000000000000000000000000000..7b6912193ff3627c07a05ed696d519ef00cb0bc7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#if defined(CPU_CAPABILITY_AVX512)
+#include <ATen/cpu/vec/vec512/vec512.h>
+#else
+#include <ATen/cpu/vec/vec256/vec256.h>
+#endif
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
+  __at_align__ bool buffer[x.size()];
+  x.ne(Vectorized<int8_t>(0)).store(buffer);
+
+  Vectorized<bool> ret;
+  static_assert(x.size() == ret.size(), "");
+  std::memcpy(ret, buffer, ret.size() * sizeof(bool));
+  return ret;
+}
+
+template <>
+inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
+  // See NOTE [Loading boolean values]
+  return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
+}
+
+template <>
+inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
+  // See NOTE [Loading boolean values]
+  return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
+}
+
+template <typename VT>
+struct VecHoldType { using hold_type = typename VT::value_type; };
+
+template <>
+struct VecHoldType<Vectorized<BFloat16>> { using hold_type = BFloat16; };
+
+template <>
+struct VecHoldType<Vectorized<Half>> {using hold_type = Half; };
+
+template <typename VT>
+using vechold_type = typename VecHoldType<VT>::hold_type;
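+
+// Illustrative check (editor's note, not part of the upstream header):
+//   static_assert(std::is_same_v<vechold_type<Vectorized<BFloat16>>, BFloat16>);
+//   static_assert(std::is_same_v<vechold_type<Vectorized<float>>, float>);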
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..79cbec407d65c19df52f61308c0ecc1a6c9d6d9e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h
@@ -0,0 +1,452 @@
+/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7.  */
+
+__extension__ extern __inline uint8x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_u8_x2 (const uint8_t *__a)
+{
+  uint8x8x2_t ret;
+  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int8x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_s8_x2 (const int8_t *__a)
+{
+  int8x8x2_t ret;
+  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint16x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_u16_x2 (const uint16_t *__a)
+{
+  uint16x4x2_t ret;
+  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int16x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_s16_x2 (const int16_t *__a)
+{
+  int16x4x2_t ret;
+  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint32x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_u32_x2 (const uint32_t *__a)
+{
+  uint32x2x2_t ret;
+  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int32x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_s32_x2 (const int32_t *__a)
+{
+  int32x2x2_t ret;
+  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint64x1x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_u64_x2 (const uint64_t *__a)
+{
+  uint64x1x2_t ret;
+  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int64x1x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_s64_x2 (const int64_t *__a)
+{
+  int64x1x2_t ret;
+  __builtin_aarch64_simd_oi __o;
+  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float16x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_f16_x2 (const float16_t *__a)
+{
+  float16x4x2_t ret;
+  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float32x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_f32_x2 (const float32_t *__a)
+{
+  float32x2x2_t ret;
+  asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float64x1x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_f64_x2 (const float64_t *__a)
+{
+  float64x1x2_t ret;
+  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly8x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_p8_x2 (const poly8_t *__a)
+{
+  poly8x8x2_t ret;
+  asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly16x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_p16_x2 (const poly16_t *__a)
+{
+  poly16x4x2_t ret;
+  asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly64x1x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1_p64_x2 (const poly64_t *__a)
+{
+  poly64x1x2_t ret;
+  asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint8x16x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_u8_x2 (const uint8_t *__a)
+{
+  uint8x16x2_t ret;
+  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int8x16x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_s8_x2 (const int8_t *__a)
+{
+  int8x16x2_t ret;
+  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint16x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_u16_x2 (const uint16_t *__a)
+{
+  uint16x8x2_t ret;
+  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int16x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_s16_x2 (const int16_t *__a)
+{
+  int16x8x2_t ret;
+  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint32x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_u32_x2 (const uint32_t *__a)
+{
+  uint32x4x2_t ret;
+  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int32x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_s32_x2 (const int32_t *__a)
+{
+  int32x4x2_t ret;
+  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline uint64x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_u64_x2 (const uint64_t *__a)
+{
+  uint64x2x2_t ret;
+  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline int64x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_s64_x2 (const int64_t *__a)
+{
+  int64x2x2_t ret;
+  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float16x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_f16_x2 (const float16_t *__a)
+{
+  float16x8x2_t ret;
+  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float32x4x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_f32_x2 (const float32_t *__a)
+{
+  float32x4x2_t ret;
+  asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline float64x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_f64_x2 (const float64_t *__a)
+{
+  float64x2x2_t ret;
+  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly8x16x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_p8_x2 (const poly8_t *__a)
+{
+  poly8x16x2_t ret;
+  asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly16x8x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_p16_x2 (const poly16_t *__a)
+{
+  poly16x8x2_t ret;
+  asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+__extension__ extern __inline poly64x2x2_t
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vld1q_p64_x2 (const poly64_t *__a)
+{
+  poly64x2x2_t ret;
+  asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w" (ret) : "Q"(*__a));
+  return ret;
+}
+
+/* vst1x2 */
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_s64_x2 (int64_t * __a, int64x1x2_t val)
+{
+  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_u64_x2 (uint64_t * __a, uint64x1x2_t val)
+{
+  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_f64_x2 (float64_t * __a, float64x1x2_t val)
+{
+  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_s8_x2 (int8_t * __a, int8x8x2_t val)
+{
+  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_p8_x2 (poly8_t * __a, poly8x8x2_t val)
+{
+  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_s16_x2 (int16_t * __a, int16x4x2_t val)
+{
+  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_p16_x2 (poly16_t * __a, poly16x4x2_t val)
+{
+  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_s32_x2 (int32_t * __a, int32x2x2_t val)
+{
+  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_u8_x2 (uint8_t * __a, uint8x8x2_t val)
+{
+  asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_u16_x2 (uint16_t * __a, uint16x4x2_t val)
+{
+  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_u32_x2 (uint32_t * __a, uint32x2x2_t val)
+{
+  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_f16_x2 (float16_t * __a, float16x4x2_t val)
+{
+  asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_f32_x2 (float32_t * __a, float32x2x2_t val)
+{
+  asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1_p64_x2 (poly64_t * __a, poly64x1x2_t val)
+{
+  asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_s8_x2 (int8_t * __a, int8x16x2_t val)
+{
+  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_p8_x2 (poly8_t * __a, poly8x16x2_t val)
+{
+  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_s16_x2 (int16_t * __a, int16x8x2_t val)
+{
+  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_p16_x2 (poly16_t * __a, poly16x8x2_t val)
+{
+  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_s32_x2 (int32_t * __a, int32x4x2_t val)
+{
+  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_s64_x2 (int64_t * __a, int64x2x2_t val)
+{
+  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_u8_x2 (uint8_t * __a, uint8x16x2_t val)
+{
+  asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_u16_x2 (uint16_t * __a, uint16x8x2_t val)
+{
+  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_u32_x2 (uint32_t * __a, uint32x4x2_t val)
+{
+  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_u64_x2 (uint64_t * __a, uint64x2x2_t val)
+{
+  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_f16_x2 (float16_t * __a, float16x8x2_t val)
+{
+  asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_f32_x2 (float32_t * __a, float32x4x2_t val)
+{
+  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_f64_x2 (float64_t * __a, float64x2x2_t val)
+{
+  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
+}
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_p64_x2 (poly64_t * __a, poly64x2x2_t val)
+{
+  asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q" (*__a) : "w" (val));
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..d882a5dbe8753aab083e5a72fde9684d26998196
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h
@@ -0,0 +1,8 @@
+/* Workaround for missing vst1q_f32_x2 in gcc-8.  */
+
+__extension__ extern __inline void
+__attribute__ ((__always_inline__, __gnu_inline__, __artificial__))
+vst1q_f32_x2 (float32_t * __a, float32x4x2_t val)
+{
+  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q" (*__a) : "w" (val));
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h
new file mode 100644
index 0000000000000000000000000000000000000000..272c3295fca5d21ed0d7e7d67801f7acf78400d0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h
@@ -0,0 +1,307 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+
+#include <ATen/cpu/vec/vec_base.h>
+#if !(defined(__VSX__)  || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
+#include <ATen/cpu/vec/vec256/vec256_float.h>
+#include <ATen/cpu/vec/vec256/vec256_float_neon.h>
+#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
+#include <ATen/cpu/vec/vec256/vec256_double.h>
+#include <ATen/cpu/vec/vec256/vec256_int.h>
+#include <ATen/cpu/vec/vec256/vec256_qint.h>
+#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
+#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
+#elif defined(__VSX__)  || defined(CPU_CAPABILITY_VSX)
+#include <ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h>
+#else
+#include <ATen/cpu/vec/vec256/zarch/vec256_zarch.h>
+#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
+#endif
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <ostream>
+
+namespace at::vec {
+
+// Note [CPU_CAPABILITY namespace]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// This header, and all of its subheaders, will be compiled with
+// different architecture flags for each supported set of vector
+// intrinsics. So we need to make sure they aren't inadvertently
+// linked together. We do this by declaring objects in an `inline
+// namespace` which changes the name mangling, but can still be
+// accessed as `at::vec`.
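+//
+// Editor's note (illustrative, not from upstream): with CPU_CAPABILITY defined
+// to e.g. AVX2 for one translation unit and DEFAULT for another, the same
+// declarations mangle differently, roughly:
+//   at::vec::AVX2::Vectorized<float>     // TU built with AVX2 flags
+//   at::vec::DEFAULT::Vectorized<float>  // baseline TU
+// while calling code keeps spelling it at::vec::Vectorized<float>.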
+inline namespace CPU_CAPABILITY {
+
+inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
+  stream << val.val_;
+  return stream;
+}
+inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
+  stream << static_cast<int>(val.val_);
+  return stream;
+}
+inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
+  stream << static_cast<unsigned int>(val.val_);
+  return stream;
+}
+
+template <typename T>
+std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
+  T buf[Vectorized<T>::size()];
+  vec.store(buf);
+  stream << "vec[";
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    if (i != 0) {
+      stream << ", ";
+    }
+    stream << buf[i];
+  }
+  stream << "]";
+  return stream;
+}
+
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
+  return _mm256_castpd_ps(src);
+}
+
+template<>
+inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
+  return _mm256_castps_pd(src);
+}
+
+template<>
+inline Vectorized<float> cast<float, int32_t>(const Vectorized<int32_t>& src) {
+  return _mm256_castsi256_ps(src);
+}
+
+template<>
+inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src) {
+  return _mm256_castsi256_pd(src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
+inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
+  return _mm256_i64gather_pd(base_addr, vindex, scale);
+}
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
+inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
+  return _mm256_i32gather_ps(base_addr, vindex, scale);
+}
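+
+// Editor's note (illustrative, not from upstream): with the AVX2 gathers above,
+// element i of the result is read from
+//   *(const float*)((const char*)base_addr + vindex[i] * scale)
+// so scale is a byte stride (e.g. scale == 4 addresses a densely packed float
+// array indexed by vindex).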
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
+inline mask_gather(const Vectorized<double>& src, const double* base_addr,
+                   const Vectorized<int64_t>& vindex, Vectorized<double>& mask) {
+  return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
+}
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
+inline mask_gather(const Vectorized<float>& src, const float* base_addr,
+                   const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
+  return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+// Only works for inputs in the range: [-2^51, 2^51]
+// From: https://stackoverflow.com/a/41148578
+template<>
+Vectorized<int64_t>
+inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
+  auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
+  return _mm256_sub_epi64(
+      _mm256_castpd_si256(x),
+      _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
+  );
+}
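+
+// Editor's note (illustrative, not from upstream): 0x0018000000000000 passed to
+// _mm256_set1_pd is the integer 2^52 + 2^51 converted to double. Adding it to a
+// double in [-2^51, 2^51] pushes the sum into a binade where the ulp is 1, so
+// the integer part lands directly in the mantissa bits; subtracting the
+// constant's bit pattern as an int64 then recovers the rounded integer without
+// a dedicated convert instruction.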
+
+template<>
+Vectorized<int32_t>
+inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
+  return _mm256_cvttps_epi32(src);
+}
+
+// Only works for inputs in the range: [-2^51, 2^51]
+// From: https://stackoverflow.com/a/41148578
+template<>
+Vectorized<double>
+inline convert_to_fp_of_same_size<double>(const Vectorized<int64_t> &src) {
+  auto x = _mm256_add_epi64(src, _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000)));
+  return _mm256_sub_pd(
+    _mm256_castsi256_pd(x),
+    _mm256_set1_pd(0x0018000000000000)
+  );
+}
+
+template<>
+Vectorized<float>
+inline convert_to_fp_of_same_size<float>(const Vectorized<int32_t> &src) {
+  return _mm256_cvtepi32_ps(src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>>
+inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
+  // inputs:
+  //   a = {a0, a1, a3, a3}
+  //   b = {b0, b1, b2, b3}
+
+  // swap lanes:
+  //   a_swapped = {a0, a1, b0, b1}
+  //   b_swapped = {a2, a3, b2, b3}
+  auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000);  // 0, 2.   4 bits apart
+  auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001);  // 1, 3.   4 bits apart
+
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1}
+  //          {a2, b2, a3, b3}
+  return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000),  // 0, 2, 1, 3
+                        _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>>
+inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
+  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
+
+  // swap lanes:
+  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
+  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
+  // TODO: can we support caching this?
+  auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000);  // 0, 2.   4 bits apart
+  auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001);  // 1, 3.   4 bits apart
+
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
+  //          {a4, b4, a5, b5, a6, b6, a7, b7}
+  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
+                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>>
+inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1}
+  //   b = {a2, b2, a3, b3}
+
+  // group cols crossing lanes:
+  //   a_grouped = {a0, a1, b0, b1}
+  //   b_grouped = {a2, a3, b2, b3}
+  auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000);  // 0, 2, 1, 3
+  auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000);  // 0, 2, 1, 3
+
+  // swap lanes:
+  //   return {a0, a1, a2, a3}
+  //          {b0, b1, b2, b3}
+  return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
+                        _mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>>
+inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
+  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
+
+  // group cols crossing lanes:
+  //   a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
+  //   b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
+  // TODO: can we support caching this?
+  const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+  auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
+  auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);
+
+  // swap lanes:
+  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
+  //          {b0, b1, b2, b3, b4, b5, b6, b7}
+  return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
+                        _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+inline Vectorized<float> flip(const Vectorized<float> & v) {
+  const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+  return _mm256_permutevar8x32_ps(v, mask_float);
+}
+
+template<>
+inline Vectorized<double> flip(const Vectorized<double> & v) {
+  return _mm256_permute4x64_pd(v, 27);  // 27 == _MM_SHUFFLE(0, 1, 2, 3)
+}
+
+template<>
+inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
+  return _mm256_permute4x64_epi64(v, 27);  // 27 == _MM_SHUFFLE(0, 1, 2, 3)
+}
+
+template<>
+inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
+  const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+  return _mm256_permutevar8x32_epi32(v, mask_int32);
+}
+
+template<>
+inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
+  const __m256i mask = _mm256_set_epi8(
+    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
+    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
+  );
+  auto reversed = _mm256_shuffle_epi8(v, mask);
+  return _mm256_permute2x128_si256(reversed, reversed, 1);
+}
+
+inline __m256i flip8(const __m256i & v) {
+  const __m256i mask_int8 = _mm256_set_epi8(
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+  );
+  auto reversed = _mm256_shuffle_epi8(v, mask_int8);
+  return _mm256_permute2x128_si256(reversed, reversed, 1);
+}
+
+template<>
+inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
+  return flip8(v);
+}
+
+template<>
+inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
+  return flip8(v);
+}
+
+#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h
new file mode 100644
index 0000000000000000000000000000000000000000..b53184cbae191137d8400a0be093f2aa5f90d3cc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -0,0 +1,1096 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wignored-qualifiers"
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+// bfloat16 conversion
+static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
+  o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16));
+}
+
+static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) {
+  __m128i lo = _mm256_extractf128_si256(a, 0);
+  __m128i hi = _mm256_extractf128_si256(a, 1);
+  cvtbf16_fp32(lo, o1);
+  cvtbf16_fp32(hi, o2);
+}
+static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) {
+  __m256i lo = _mm256_castps_si256(a);
+  __m256i hi = _mm256_castps_si256(b);
+  __m256i nan = _mm256_set1_epi32(0xffff);
+  __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q));
+  __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q));
+  __m256i ones = _mm256_set1_epi32(0x1);
+  __m256i vec_bias = _mm256_set1_epi32(0x7fff);
+  // uint32_t lsb = (input >> 16) & 1;
+  auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones);
+  auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones);
+  // uint32_t rounding_bias = 0x7fff + lsb;
+  t_lo = _mm256_add_epi32(t_lo, vec_bias);
+  t_hi = _mm256_add_epi32(t_hi, vec_bias);
+  // input += rounding_bias;
+  t_lo = _mm256_add_epi32(t_lo, lo);
+  t_hi = _mm256_add_epi32(t_hi, hi);
+  // input = input >> 16;
+  t_lo = _mm256_srli_epi32(t_lo, 16);
+  t_hi = _mm256_srli_epi32(t_hi, 16);
+  // Check NaN before converting back to bf16
+  t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo);
+  t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi);
+
+  t_lo = _mm256_packus_epi32(t_lo, t_hi);      // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
+  return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11        01        10        00
+}
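+// Note on cvtfp32_bf16 above: this is round-to-nearest-even truncation to the
+// high 16 bits. Illustrative scalar sketch (not part of this header; names are
+// for exposition only):
+//   uint32_t bits = bit_cast<uint32_t>(x);
+//   uint32_t lsb  = (bits >> 16) & 1;   // lowest bit that survives truncation
+//   bits += 0x7fff + lsb;               // ties round toward the even result
+//   uint16_t bf16 = std::isnan(x) ? 0xffff : uint16_t(bits >> 16);
+// e.g. x = 1.0f (0x3f800000): lsb = 0, bits + 0x7fff = 0x3f807fff, bf16 = 0x3f80.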
+
+static inline __m256i merge_compare_result(const __m256& a, const __m256& b) {
+  __m256i lo = _mm256_castps_si256(a);
+  __m256i hi = _mm256_castps_si256(b);
+  lo = _mm256_srli_epi32(lo, 16);
+  hi = _mm256_srli_epi32(hi, 16);
+  auto out = _mm256_packus_epi32(lo, hi);
+  return _mm256_permute4x64_epi64(out, 0xd8);
+}
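+// Note on merge_compare_result above: its inputs are AVX compare results, so
+// every 32-bit lane is either all zeros or all ones; keeping only the high 16
+// bits of each lane preserves the mask exactly, hence no rounding is needed.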
+
+// float16 conversion
+static inline void cvtfp16_fp32(const __m128i& a, __m256& o) {
+  o = _mm256_cvtph_ps(a);
+}
+
+static inline void cvtfp16_fp32(const __m256i& a, __m256& o1, __m256& o2) {
+  __m128i lo = _mm256_extractf128_si256(a, 0);
+  __m128i hi = _mm256_extractf128_si256(a, 1);
+  cvtfp16_fp32(lo, o1);
+  cvtfp16_fp32(hi, o2);
+}
+
+static inline __m256i cvtfp32_fp16(const __m256& a, const __m256& b) {
+  __m128i lo = _mm256_cvtps_ph(
+      a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  __m128i hi = _mm256_cvtps_ph(
+      b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
+}
+
+// dtype conversion between float16/bfloat16 and float32
+template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline void cvt_to_fp32(const __m128i& a, __m256& o);
+template <> inline void cvt_to_fp32<BFloat16>(const __m128i& a, __m256& o) {
+  cvtbf16_fp32(a, o);
+}
+template <> inline void cvt_to_fp32<Half>(const __m128i& a, __m256& o) {
+  cvtfp16_fp32(a, o);
+}
+
+template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline void cvt_to_fp32(const __m256i& a, __m256& o1, __m256& o2);
+template <> inline void cvt_to_fp32<BFloat16>(const __m256i& a, __m256& o1, __m256& o2) {
+  cvtbf16_fp32(a, o1, o2);
+}
+template <> inline void cvt_to_fp32<Half>(const __m256i& a, __m256& o1, __m256& o2) {
+  cvtfp16_fp32(a, o1, o2);
+}
+
+template <typename T, bool is_compare_op = false, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline __m256i cvt_from_fp32(const __m256& a, const __m256& b);
+template <> inline __m256i cvt_from_fp32<BFloat16, false>(const __m256& a, const __m256& b) {
+  return cvtfp32_bf16(a, b);
+}
+template <> inline __m256i cvt_from_fp32<BFloat16, true>(const __m256& a, const __m256& b) {
+  return merge_compare_result(a, b);
+}
+template <> inline __m256i cvt_from_fp32<Half, false>(const __m256& a, const __m256& b) {
+  return cvtfp32_fp16(a, b);
+}
+template <> inline __m256i cvt_from_fp32<Half, true>(const __m256& a, const __m256& b) {
+  return cvtfp32_fp16(a, b);
+}
+
+template <typename T>
+class Vectorized16 {
+static_assert(
+  is_reduced_floating_point_v<T>,
+  "Support only float16 and bfloat16.");
+protected:
+  __m256i values;
+public:
+  using value_type = uint16_t;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 16;
+  }
+  Vectorized16() {}
+  Vectorized16(__m256i v) : values(v) {}
+  Vectorized16(T val) {
+    value_type uw = val.x;
+    values = _mm256_set1_epi16(uw);
+  }
+  Vectorized16(T val1, T val2, T val3, T val4,
+         T val5, T val6, T val7, T val8,
+         T val9, T val10, T val11, T val12,
+         T val13, T val14, T val15, T val16) {
+    values = _mm256_setr_epi16(
+        val1.x, val2.x, val3.x, val4.x, val5.x, val6.x, val7.x, val8.x,
+        val9.x, val10.x, val11.x, val12.x, val13.x, val14.x, val15.x, val16.x);
+  }
+  operator __m256i() const {
+    return values;
+  }
+  T& operator[](int idx) = delete;
+  const T& operator[](int idx) const  = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0));
+    return _mm256_movemask_epi8(cmp);
+  }
+  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
+    if (count == size())
+      return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
+
+    __at_align__ int16_t tmp_values[size()];
+    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
+    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmp_values));
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
+    } else if (count > 0) {
+      __at_align__ int16_t tmp_values[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
+    }
+  }
+  template <int64_t mask>
+  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
+    __at_align__ int16_t tmp_values[size()];
+    a.store(tmp_values);
+    if (mask & 0x01)
+      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
+    if (mask & 0x02)
+      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
+    if (mask & 0x04)
+      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
+    if (mask & 0x08)
+      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
+    if (mask & 0x10)
+      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
+    if (mask & 0x20)
+      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
+    if (mask & 0x40)
+      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
+    if (mask & 0x80)
+      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
+    if (mask & 0x100)
+      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
+    if (mask & 0x200)
+      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
+    if (mask & 0x400)
+      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
+    if (mask & 0x800)
+      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
+    if (mask & 0x1000)
+      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
+    if (mask & 0x2000)
+      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
+    if (mask & 0x4000)
+      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
+    if (mask & 0x8000)
+      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
+    return loadu(tmp_values);
+  }
+  static Vectorized<T> blendv(const Vectorized<T>& a,
+      const Vectorized<T>& b, const Vectorized<T>& mask) {
+    return _mm256_blendv_epi8(a.values, b.values, mask.values);
+  }
+  template<typename step_t>
+  static Vectorized<T> arange(T base = 0.f, step_t step = static_cast<step_t>(1)) {
+    return Vectorized(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
+  }
+  static Vectorized<T> set(const Vectorized<T>& a,
+      const Vectorized<T>& b, int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+    }
+    return b;
+  }
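+  // Note on set(a, b, count) above: lanes [0, count) come from b and the
+  // remaining lanes from a, since each set mask bit in the blend cascade
+  // selects the corresponding lane from b.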
+  Vectorized map(const __m256 (*const vop)(__m256)) const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    const auto o1 = vop(lo);
+    const auto o2 = vop(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized isnan() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    lo = _mm256_cmp_ps(lo, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
+    hi = _mm256_cmp_ps(hi, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
+    return merge_compare_result(lo, hi);
+  }
+  Vectorized abs() const {
+    return _mm256_andnot_si256(_mm256_set1_epi16(0x8000), values);
+  }
+  Vectorized angle() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto angle_lambda = [](__m256 values_2) {
+      const auto zero_vec = _mm256_set1_ps(0.f);
+      const auto nan_vec = _mm256_set1_ps(NAN);
+      const auto not_nan_mask = _mm256_cmp_ps(values_2, values_2, _CMP_EQ_OQ);
+      const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
+      const auto pi = _mm256_set1_ps(c10::pi);
+
+      const auto neg_mask = _mm256_cmp_ps(values_2, zero_vec, _CMP_LT_OQ);
+      auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
+      angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
+      return angle;
+    };
+    auto o1 = angle_lambda(lo);
+    auto o2 = angle_lambda(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_epi16(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return map(Sleef_acosf8_u10);
+  }
+  Vectorized acosh() const {
+    return map(Sleef_acoshf8_u10);
+  }
+  Vectorized asin() const {
+    return map(Sleef_asinf8_u10);
+  }
+  Vectorized atan() const {
+    return map(Sleef_atanf8_u10);
+  }
+  Vectorized atanh() const {
+    return map(Sleef_atanhf8_u10);
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    __m256 lo, hi;
+    __m256 b1, b2;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(b.values, b1, b2);
+    auto o1 = Sleef_atan2f8_u10(lo, b1);
+    auto o2 = Sleef_atan2f8_u10(hi, b2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    // copy sign bit (0x8000) from sign and remaining bits from values
+    __m256i mask_value = _mm256_set1_epi32(~0x80008000);
+    __m256i mask_signbit = _mm256_set1_epi32(0x80008000);
+    return Vectorized(
+      _mm256_or_si256(
+        _mm256_and_si256(values, mask_value),
+        _mm256_and_si256(sign, mask_signbit)));
+  }
+  Vectorized erf() const {
+    return map(Sleef_erff8_u10);
+  }
+  Vectorized erfc() const {
+    return map(Sleef_erfcf8_u15);
+  }
+  Vectorized erfinv() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+    for (int64_t i = 0; i < size() / 2; i++) {
+      tmp1[i] = calc_erfinv(tmp1[i]);
+      tmp2[i] = calc_erfinv(tmp2[i]);
+    }
+    auto o1 = _mm256_loadu_ps(tmp1);
+    auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized exp() const {
+    return map(Sleef_expf8_u10);
+  }
+  Vectorized exp2() const {
+    return map(Sleef_exp2f8_u10);
+  }
+  Vectorized expm1() const {
+    return map(Sleef_expm1f8_u10);
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+  Vectorized fmod(const Vectorized & q) const {
+    __m256 x_lo, x_hi;
+    cvt_to_fp32(values, x_lo, x_hi);
+    __m256 q_lo, q_hi;
+    cvt_to_fp32(q.values, q_lo, q_hi);
+    auto o1 = Sleef_fmodf8(x_lo, q_lo);
+    auto o2 = Sleef_fmodf8(x_hi, q_hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    __m256 lo, hi;
+    __m256 b1, b2;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(b.values, b1, b2);
+    auto o1 = Sleef_hypotf8_u05(lo, b1);
+    auto o2 = Sleef_hypotf8_u05(hi, b2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized i0() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+    for (int64_t i = 0; i < size() / 2; i++) {
+      tmp1[i] = calc_i0(tmp1[i]);
+      tmp2[i] = calc_i0(tmp2[i]);
+    }
+    auto o1 = _mm256_loadu_ps(tmp1);
+    auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized i0e() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    constexpr auto sz = size();
+    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+
+    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
+      tmp1[i] = calc_i0e(tmp1[i]);
+      tmp2[i] = calc_i0e(tmp2[i]);
+    }
+    const auto o1 = _mm256_loadu_ps(tmp1);
+    const auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized digamma() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    constexpr auto sz = size();
+    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+
+    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
+      tmp1[i] = calc_digamma(tmp1[i]);
+      tmp2[i] = calc_digamma(tmp2[i]);
+    }
+    const auto o1 = _mm256_loadu_ps(tmp1);
+    const auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __m256 lo, hi;
+    __m256 xlo, xhi;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(x.values, xlo, xhi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo);
+    _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi);
+    for (int64_t i = 0; i < size() / 2; ++i) {
+      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
+      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
+    }
+    auto o1 = _mm256_loadu_ps(tmp1);
+    auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+
+  Vectorized igammac(const Vectorized &x) const {
+    __m256 lo, hi;
+    __m256 xlo, xhi;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(x.values, xlo, xhi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm256_storeu_ps(reinterpret_cast(tmp2), hi);
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo);
+    _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi);
+    for (int64_t i = 0; i < size() / 2; ++i) {
+      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
+      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
+    }
+    auto o1 = _mm256_loadu_ps(tmp1);
+    auto o2 = _mm256_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized log() const {
+    return map(Sleef_logf8_u10);
+  }
+  Vectorized log2() const {
+    return map(Sleef_log2f8_u10);
+  }
+  Vectorized log10() const {
+    return map(Sleef_log10f8_u10);
+  }
+  Vectorized log1p() const {
+    return map(Sleef_log1pf8_u10);
+  }
+  Vectorized sin() const {
+    return map(Sleef_sinf8_u10);
+  }
+  Vectorized sinh() const {
+    return map(Sleef_sinhf8_u10);
+  }
+  Vectorized cos() const {
+    return map(Sleef_cosf8_u10);
+  }
+  Vectorized cosh() const {
+    return map(Sleef_coshf8_u10);
+  }
+  Vectorized ceil() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto o1 = _mm256_ceil_ps(lo);
+    auto o2 = _mm256_ceil_ps(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized floor() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto o1 = _mm256_floor_ps(lo);
+    auto o2 = _mm256_floor_ps(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized neg() const {
+    return _mm256_xor_si256(values, _mm256_set1_epi16(0x8000));
+  }
+  Vectorized round() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized tan() const {
+    return map(Sleef_tanf8_u10);
+  }
+  Vectorized tanh() const {
+    return map(Sleef_tanhf8_u10);
+  }
+  Vectorized trunc() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized lgamma() const {
+    return map(Sleef_lgammaf8_u10);
+  }
+  Vectorized sqrt() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto o1 = _mm256_sqrt_ps(lo);
+    auto o2 = _mm256_sqrt_ps(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized reciprocal() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto ones = _mm256_set1_ps(1);
+    auto o1 = _mm256_div_ps(ones, lo);
+    auto o2 = _mm256_div_ps(ones, hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized rsqrt() const {
+    __m256 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto ones = _mm256_set1_ps(1);
+    auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo));
+    auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi));
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized pow(const Vectorized &b) const {
+    __m256 lo, hi;
+    __m256 b1, b2;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(b.values, b1, b2);
+    auto o1 = Sleef_powf8_u10(lo, b1);
+    auto o2 = Sleef_powf8_u10(hi, b2);
+    return cvt_from_fp32(o1, o2);
+  }
+private:
+  template
+  Vectorized inline binary_compare(const Vectorized& b, Op op) const {
+    __m256 a_lo, a_hi;
+    __m256 b_lo, b_hi;
+    cvt_to_fp32(values, a_lo, a_hi);
+    cvt_to_fp32(b.values, b_lo, b_hi);
+    auto o1 = op(a_lo, b_lo);
+    auto o2 = op(a_hi, b_hi);
+    return cvt_from_fp32(o1, o2);
+  }
+
+public:
+  Vectorized inline operator>(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_GT_OQ); });
+  }
+  Vectorized inline operator<(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_LT_OQ); });
+  }
+  Vectorized inline operator>=(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_GE_OQ); });
+  }
+  Vectorized inline operator<=(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_LE_OQ); });
+  }
+  Vectorized inline operator==(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_EQ_OQ); });
+  }
+  Vectorized inline operator!=(const Vectorized& other) const {
+    return binary_compare(other, [](__m256 x, __m256 y) { return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); });
+  }
+};
+
+template <typename T, typename Op>
+static inline Vectorized<T> binary_op_as_fp32(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  cvt_to_fp32<T>(__m256i(a), a_lo, a_hi);
+  cvt_to_fp32<T>(__m256i(b), b_lo, b_hi);
+  auto o1 = op(a_lo, b_lo);
+  auto o2 = op(a_hi, b_hi);
+  return cvt_from_fp32<T>(o1, o2);
+}
+
+template <>
+class Vectorized<BFloat16>: public Vectorized16<BFloat16> {
+public:
+  using Vectorized16::Vectorized16;
+
+  Vectorized<BFloat16> frac() const;
+
+  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
+};
+
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_add_ps(x, y); });
+}
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_sub_ps(x, y); });
+}
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_mul_ps(x, y); });
+}
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_div_ps(x, y); });
+}
+Vectorized inline operator&(const Vectorized& a, const Vectorized& b) {
+  return _mm256_and_si256(a, b);
+}
+Vectorized inline operator|(const Vectorized& a, const Vectorized& b) {
+  return _mm256_or_si256(a, b);
+}
+Vectorized inline operator^(const Vectorized& a, const Vectorized& b) {
+  return _mm256_xor_si256(a, b);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1.0f);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized Vectorized::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
+  auto max_lo = _mm256_max_ps(a_lo, b_lo);
+  auto max_hi = _mm256_max_ps(a_hi, b_hi);
+  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm256_or_ps(max_lo, nan_lo);
+  auto o2 = _mm256_or_ps(max_hi, nan_hi);
+  return cvtfp32_bf16(o1, o2);
+}
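+// Note on maximum above: where either input is NaN, the _CMP_UNORD_Q mask is
+// all ones in that lane; OR-ing the mask into the result forces an all-ones
+// (NaN) bit pattern there, which is how the NaN propagation is achieved.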
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
+  auto min_lo = _mm256_min_ps(a_lo, b_lo);
+  auto min_hi = _mm256_min_ps(a_hi, b_hi);
+  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm256_or_ps(min_lo, nan_lo);
+  auto o2 = _mm256_or_ps(min_hi, nan_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a,
+    const Vectorized& min, const Vectorized& max) {
+  __m256 a_lo, a_hi;
+  __m256 min_lo, min_hi;
+  __m256 max_lo, max_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(min), min_lo, min_hi);
+  cvtbf16_fp32(__m256i(max), max_lo, max_hi);
+  auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo));
+  auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi));
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) {
+  __m256 a_lo, a_hi;
+  __m256 max_lo, max_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(max), max_lo, max_hi);
+  auto o1 = _mm256_min_ps(max_lo, a_lo);
+  auto o2 = _mm256_min_ps(max_hi, a_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) {
+  __m256 a_lo, a_hi;
+  __m256 min_lo, min_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(min), min_lo, min_hi);
+  auto o1 = _mm256_max_ps(min_lo, a_lo);
+  auto o2 = _mm256_max_ps(min_hi, a_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i)));
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+inline void convert(const float* src, BFloat16* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i + Vectorized::size() <= n; i += Vectorized::size()) {
+    __m256 a = _mm256_loadu_ps(&src[i]);
+    __m256 b = _mm256_loadu_ps(&src[i + 8]);
+
+    __m256i bf = cvtfp32_bf16(a, b);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert(src[i]);
+  }
+}
+
+template <>
+inline void convert(const double* src, BFloat16* dst, int64_t n) {
+  auto load_float = [](const double *src) -> __m256 {
+    // Load one float vector from an array of doubles
+    __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src));
+    __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4));
+    return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1);
+  };
+
+  int64_t i;
+  for (i = 0; i + Vectorized::size() <= n; i += Vectorized::size()) {
+    __m256 a = load_float(&src[i]);
+    __m256 b = load_float(&src[i + 8]);
+
+    __m256i bf = cvtfp32_bf16(a, b);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert(src[i]);
+  }
+}
+
+template <>
+Vectorized inline fmadd(const Vectorized& a,
+    const Vectorized& b, const Vectorized& c) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  __m256 c_lo, c_hi;
+  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
+  cvtbf16_fp32(__m256i(c), c_lo, c_hi);
+  auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo);
+  auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi);
+  return cvtfp32_bf16(o1, o2);
+}
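+// Note on fmadd above: the operands are widened to fp32 and combined with a
+// fused multiply-add, so the only roundings are the single fp32 FMA rounding
+// and the final conversion back to bf16.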
+
+template <>
+class Vectorized<Half>: public Vectorized16<Half> {
+public:
+  using Vectorized16::Vectorized16;
+
+  Vectorized<Half> frac() const;
+
+  Vectorized<Half> eq(const Vectorized<Half>& other) const;
+  Vectorized<Half> ne(const Vectorized<Half>& other) const;
+  Vectorized<Half> gt(const Vectorized<Half>& other) const;
+  Vectorized<Half> ge(const Vectorized<Half>& other) const;
+  Vectorized<Half> lt(const Vectorized<Half>& other) const;
+  Vectorized<Half> le(const Vectorized<Half>& other) const;
+};
+
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_add_ps(x, y); });
+}
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_sub_ps(x, y); });
+}
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_mul_ps(x, y); });
+}
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_div_ps(x, y); });
+}
+Vectorized inline operator&(const Vectorized& a, const Vectorized& b) {
+  return _mm256_and_si256(a, b);
+}
+Vectorized inline operator|(const Vectorized& a, const Vectorized& b) {
+  return _mm256_or_si256(a, b);
+}
+Vectorized inline operator^(const Vectorized& a, const Vectorized& b) {
+  return _mm256_xor_si256(a, b);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1.0f);
+}
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1.0f);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized Vectorized::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(b), b_lo, b_hi);
+  auto max_lo = _mm256_max_ps(a_lo, b_lo);
+  auto max_hi = _mm256_max_ps(a_hi, b_hi);
+  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm256_or_ps(max_lo, nan_lo);
+  auto o2 = _mm256_or_ps(max_hi, nan_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(b), b_lo, b_hi);
+  auto min_lo = _mm256_min_ps(a_lo, b_lo);
+  auto min_hi = _mm256_min_ps(a_hi, b_hi);
+  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm256_or_ps(min_lo, nan_lo);
+  auto o2 = _mm256_or_ps(min_hi, nan_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a,
+    const Vectorized& min, const Vectorized& max) {
+  __m256 a_lo, a_hi;
+  __m256 min_lo, min_hi;
+  __m256 max_lo, max_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(min), min_lo, min_hi);
+  cvtfp16_fp32(__m256i(max), max_lo, max_hi);
+  auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo));
+  auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi));
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) {
+  __m256 a_lo, a_hi;
+  __m256 max_lo, max_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(max), max_lo, max_hi);
+  auto o1 = _mm256_min_ps(max_lo, a_lo);
+  auto o2 = _mm256_min_ps(max_hi, a_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) {
+  __m256 a_lo, a_hi;
+  __m256 min_lo, min_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(min), min_lo, min_hi);
+  auto o1 = _mm256_max_ps(min_lo, a_lo);
+  auto o2 = _mm256_max_ps(min_hi, a_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+inline void convert(const Half* src, Half* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i)));
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+inline void convert(const float* src, Half* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i + Vectorized::size() <= n; i += Vectorized::size()) {
+    __m256 a = _mm256_loadu_ps(&src[i]);
+    __m256 b = _mm256_loadu_ps(&src[i + 8]);
+
+    __m256i c = cvtfp32_fp16(a, b);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert(src[i]);
+  }
+}
+
+template <>
+inline void convert(const double* src, Half* dst, int64_t n) {
+  auto load_float = [](const double *src) -> __m256 {
+    // Load one float vector from an array of doubles
+    __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src));
+    __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4));
+    return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1);
+  };
+
+  int64_t i;
+  for (i = 0; i + Vectorized::size() <= n; i += Vectorized::size()) {
+    __m256 a = load_float(&src[i]);
+    __m256 b = load_float(&src[i + 8]);
+
+    __m256i c = cvtfp32_fp16(a, b);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), c);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert(src[i]);
+  }
+}
+
+template <>
+Vectorized inline fmadd(const Vectorized& a,
+    const Vectorized& b, const Vectorized& c) {
+  __m256 a_lo, a_hi;
+  __m256 b_lo, b_hi;
+  __m256 c_lo, c_hi;
+  cvtfp16_fp32(__m256i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m256i(b), b_lo, b_hi);
+  cvtfp16_fp32(__m256i(c), c_lo, c_hi);
+  auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo);
+  auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+#define CONVERT_VECTORIZED_INIT(type, name) \
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
+  __m256 o1, o2; \
+  cvt_to_fp32<type>(__m256i(a), o1, o2); \
+  return std::make_tuple(o1, o2); \
+} \
+inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
+  return cvt_from_fp32<type>(__m256(a), __m256(b)); \
+}
+CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
+CONVERT_VECTORIZED_INIT(Half, half);
+
+#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+#define CONVERT_NON_VECTORIZED_INIT(type, name) \
+inline std::tuple, Vectorized> convert_##name##_float(const Vectorized& a) { \
+  constexpr int64_t K = Vectorized::size(); \
+  __at_align__ float arr[K]; \
+  __at_align__ type arr2[K]; \
+  a.store(arr2); \
+  convert(arr2, arr, K); \
+  return std::make_tuple( \
+      Vectorized::loadu(arr), \
+      Vectorized::loadu(arr + Vectorized::size())); \
+} \
+inline Vectorized convert_float_##name(const Vectorized& a, const Vectorized& b) { \
+  constexpr int64_t K = Vectorized::size(); \
+  __at_align__ float arr[K]; \
+  __at_align__ type arr2[K]; \
+  a.store(arr); \
+  b.store(arr + Vectorized::size()); \
+  convert(arr, arr2, K); \
+  return Vectorized::loadu(arr2); \
+}
+CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
+CONVERT_NON_VECTORIZED_INIT(Half, half);
+
+#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#define LOAD_FP32_VECTORIZED_INIT(type, name) \
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
+  auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); \
+  __m256 out_values; \
+  cvt_to_fp32<type>(values, out_values); \
+  out = out_values; \
+} \
+\
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
+  auto vec = Vectorized<type>::loadu(data); \
+  __m256 out1_values, out2_values; \
+  cvt_to_fp32<type>(vec, out1_values, out2_values); \
+  out1 = out1_values; \
+  out2 = out2_values; \
+}
+LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
+LOAD_FP32_VECTORIZED_INIT(Half, fp16);
+
+#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
+inline void load_fp32_from_##name(const type *data, Vectorized& out) { \
+  __at_align__ float values[Vectorized::size()]; \
+  for (const auto k : c10::irange(Vectorized::size())) { \
+    values[k] = data[k]; \
+  } \
+  out = Vectorized::loadu(values); \
+} \
+\
+inline void load_fp32_from_##name(const type *data, Vectorized& out1, Vectorized& out2) { \
+  load_fp32_from_##name(data, out1); \
+  data += Vectorized::size(); \
+  load_fp32_from_##name(data, out2); \
+}
+LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
+LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);
+
+#endif
+}} // namespace at::vec::CPU_CAPABILITY
+
+#pragma GCC diagnostic pop
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h
new file mode 100644
index 0000000000000000000000000000000000000000..a095d00637a245a64c7cfb51485dc552d5b65c60
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -0,0 +1,431 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <c10/util/complex.h>
+#include <c10/util/irange.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+template <> class Vectorized<c10::complex<double>> {
+private:
+  __m256d values;
+public:
+  using value_type = c10::complex<double>;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 2;
+  }
+  Vectorized() {}
+  Vectorized(__m256d v) : values(v) {}
+  Vectorized(c10::complex<double> val) {
+    double real_value = val.real();
+    double imag_value = val.imag();
+    values = _mm256_setr_pd(real_value, imag_value,
+                            real_value, imag_value);
+  }
+  Vectorized(c10::complex<double> val1, c10::complex<double> val2) {
+    values = _mm256_setr_pd(val1.real(), val1.imag(),
+                            val2.real(), val2.imag());
+  }
+  operator __m256d() const {
+    return values;
+  }
+  template 
+  static Vectorized> blend(const Vectorized>& a, const Vectorized>& b) {
+     // convert c10::complex index mask to V index mask: xy -> xxyy
+    static_assert (mask > -1 && mask < 4, "Unexpected mask value");
+    switch (mask) {
+      case 0:
+        return a;
+      case 1:
+        return _mm256_blend_pd(a.values, b.values, 0x03);
+      case 2:
+        return _mm256_blend_pd(a.values, b.values, 0x0c);
+      case 3: break;
+    }
+    return b;
+  }
+  static Vectorized> blendv(const Vectorized>& a, const Vectorized>& b,
+                               const Vectorized>& mask) {
+    // convert c10::complex index mask to V index mask: xy -> xxyy
+    auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values);
+    return _mm256_blendv_pd(a.values, b.values, mask_);
+
+  }
+  template
+  static Vectorized> arange(c10::complex base = 0., step_t step = static_cast(1)) {
+    return Vectorized>(base,
+                                        base + step);
+  }
+  static Vectorized> set(const Vectorized>& a, const Vectorized>& b,
+                            int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+    }
+    return b;
+  }
+  static Vectorized> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm256_loadu_pd(reinterpret_cast(ptr));
+
+    __at_align__ double tmp_values[2*size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(2*size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values,
+        reinterpret_cast(ptr),
+        count * sizeof(c10::complex));
+    return _mm256_load_pd(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm256_storeu_pd(reinterpret_cast(ptr), values);
+    } else if (count > 0) {
+      double tmp_values[2*size()];
+      _mm256_storeu_pd(reinterpret_cast(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex));
+    }
+  }
+  const c10::complex& operator[](int idx) const  = delete;
+  c10::complex& operator[](int idx) = delete;
+  Vectorized> map(c10::complex (*const f)(const c10::complex &)) const {
+    __at_align__ c10::complex tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  __m256d abs_2_() const {
+    auto val_2 = _mm256_mul_pd(values, values);     // a*a     b*b
+    return _mm256_hadd_pd(val_2, val_2);            // a*a+b*b a*a+b*b
+  }
+  __m256d abs_() const {
+    auto real = _mm256_movedup_pd(values);       // real real
+    // movehdup_pd does not exist...
+    auto imag = _mm256_permute_pd(values, 0xf);  // imag imag
+    return Sleef_hypotd4_u05(real, imag);        // abs  abs
+  }
+  Vectorized> abs() const {
+    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    return _mm256_and_pd(abs_(), real_mask);        // abs     0
+  }
+  __m256d angle_() const {
+    //angle = atan2(b/a)
+    auto b_a = _mm256_permute_pd(values, 0x05);     // b        a
+    return Sleef_atan2d4_u10(values, b_a);          // 90-angle angle
+  }
+  Vectorized> angle() const {
+    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    auto angle = _mm256_permute_pd(angle_(), 0x05); // angle    90-angle
+    return _mm256_and_pd(angle, real_mask);         // angle    0
+  }
+  Vectorized> sgn() const {
+    auto abs = abs_();
+    auto zero = _mm256_setzero_pd();
+    auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
+    auto div = values / abs;
+    return _mm256_blendv_pd(div, zero, mask);
+  }
+  __m256d real_() const {
+    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    return _mm256_and_pd(values, real_mask);
+  }
+  Vectorized> real() const {
+    return real_();
+  }
+  __m256d imag_() const {
+    const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
+                                                                     0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
+    return _mm256_and_pd(values, imag_mask);
+  }
+  Vectorized> imag() const {
+    return _mm256_permute_pd(imag_(), 0x05);           //b        a
+  }
+  __m256d conj_() const {
+    const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+    return _mm256_xor_pd(values, sign_mask);           // a       -b
+  }
+  Vectorized> conj() const {
+    return conj_();
+  }
+  Vectorized> log() const {
+    // Most trigonomic ops use the log() op to improve complex number performance.
+    return map(std::log);
+  }
+  Vectorized> log2() const {
+    const __m256d log2_ = _mm256_set1_pd(std::log(2));
+    return _mm256_div_pd(log(), log2_);
+  }
+  Vectorized> log10() const {
+    const __m256d log10_ = _mm256_set1_pd(std::log(10));
+    return _mm256_div_pd(log(), log10_);
+  }
+  Vectorized> log1p() const {
+    return map(std::log1p);
+  }
+  Vectorized> asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+    const __m256d one = _mm256_set1_pd(1);
+
+    auto conj = conj_();
+    auto b_a = _mm256_permute_pd(conj, 0x05);                         //-b        a
+    auto ab = _mm256_mul_pd(conj, b_a);                               //-ab       -ab
+    auto im = _mm256_add_pd(ab, ab);                                  //-2ab      -2ab
+
+    auto val_2 = _mm256_mul_pd(values, values);                       // a*a      b*b
+    auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05));  // a*a-b*b  b*b-a*a
+    re = _mm256_sub_pd(one, re);
+
+    auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt();         //sqrt(re + i*im)
+    auto ln = Vectorized(_mm256_add_pd(b_a, root)).log();                 //ln(iz + sqrt())
+    return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj();         //-i*ln()
+  }
+  Vectorized> acos() const {
+    // acos(x) = pi/2 - asin(x)
+    constexpr auto pi_2d = c10::pi / 2;
+    const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0);
+    return _mm256_sub_pd(pi_2, asin());
+  }
+  Vectorized> atan() const;
+  Vectorized> atanh() const {
+    return map(std::atanh);
+  }
+  Vectorized> exp() const {
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expd4_u10(values);                               //exp(a)           exp(b)
+    exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A);   //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosd4_u10(values);                        //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05),
+                                   sin_cos.x, 0x0A);                  //cos(b)           sin(b)
+    return _mm256_mul_pd(exp, cos_sin);
+  }
+  Vectorized> exp2() const {
+    // Use identity 2**x = exp(log(2) * x)
+    const __m256d ln_2 = _mm256_set1_pd(c10::ln_2);
+    Vectorized> scaled_values = _mm256_mul_pd(values, ln_2);
+    return scaled_values.exp();
+  }
+  Vectorized> expm1() const {
+    return map(std::expm1);
+  }
+  Vectorized> sin() const {
+    return map(std::sin);
+  }
+  Vectorized> sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized> cos() const {
+    return map(std::cos);
+  }
+  Vectorized> cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized> ceil() const {
+    return _mm256_ceil_pd(values);
+  }
+  Vectorized> floor() const {
+    return _mm256_floor_pd(values);
+  }
+  Vectorized> neg() const {
+    auto zero = _mm256_setzero_pd();
+    return _mm256_sub_pd(zero, values);
+  }
+  Vectorized> round() const {
+    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> tan() const {
+    return map(std::tan);
+  }
+  Vectorized> tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized> trunc() const {
+    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> sqrt() const {
+    return map(std::sqrt);
+  }
+  Vectorized> reciprocal() const;
+  Vectorized> rsqrt() const {
+    return sqrt().reciprocal();
+  }
+  Vectorized> pow(const Vectorized> &exp) const {
+    __at_align__ c10::complex x_tmp[size()];
+    __at_align__ c10::complex y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized> operator==(const Vectorized>& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
+  }
+  Vectorized> operator!=(const Vectorized>& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
+  }
+  Vectorized> operator<(const Vectorized>&) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator<=(const Vectorized>&) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator>(const Vectorized>&) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator>=(const Vectorized>&) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized> eq(const Vectorized>& other) const;
+  Vectorized> ne(const Vectorized>& other) const;
+};
+
+template <> Vectorized> inline operator+(const Vectorized> &a, const Vectorized> &b) {
+  return _mm256_add_pd(a, b);
+}
+
+template <> Vectorized> inline operator-(const Vectorized> &a, const Vectorized> &b) {
+  return _mm256_sub_pd(a, b);
+}
+
+template <> Vectorized> inline operator*(const Vectorized> &a, const Vectorized> &b) {
+  //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+  auto ac_bd = _mm256_mul_pd(a, b);         //ac       bd
+
+  auto d_c = _mm256_permute_pd(b, 0x05);    //d        c
+  d_c = _mm256_xor_pd(sign_mask, d_c);      //d       -c
+  auto ad_bc = _mm256_mul_pd(a, d_c);       //ad      -bc
+
+  auto ret = _mm256_hsub_pd(ac_bd, ad_bc);  //ac - bd  ad + bc
+  return ret;
+}
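+// Worked example for the multiply above (scalar sketch): (1+2i) * (3+4i)
+//   ac_bd = (1*3, 2*4)  = (3,  8)  -> hsub: 3 - 8    = -5  (real part)
+//   ad_bc = (1*4, 2*-3) = (4, -6)  -> hsub: 4 - (-6) = 10  (imag part)
+// which matches (1+2i)(3+4i) = -5 + 10i.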
+
+template <> Vectorized> inline operator/(const Vectorized> &a, const Vectorized> &b) {
+  //re + im*i = (a + bi)  / (c + di)
+  auto mask = _mm256_set1_pd(-0.f);
+  auto fabs_cd = _mm256_andnot_pd(mask, b);     // |c|    |d|
+  auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05);   // |d|    |c|
+  auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, fabs_dc));  // 1/sc     1/sc
+  auto a2 = _mm256_mul_pd(a, scale);         // a/sc     b/sc
+  auto b2 = _mm256_mul_pd(b, scale);         // c/sc     d/sc
+  auto acbd2 = _mm256_mul_pd(a2, b2);
+
+  const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0);
+  auto dc2 = _mm256_permute_pd(b2, 0x05);    // d/sc         c/sc
+  dc2 = _mm256_xor_pd(sign_mask, dc2);       // -d/|c,d|        c/sc
+  auto adbc2 = _mm256_mul_pd(a2, dc2);       //-ad/sc^2      bc/sc^2
+  auto res2 = _mm256_hadd_pd(acbd2, adbc2);  //(ac+bd)/sc^2  (bc-ad)/sc^2
+
+  // get the denominator
+  auto denom2 = Vectorized>(b2).abs_2_();  // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
+  res2 = _mm256_div_pd(res2, denom2);
+  return res2;
+}
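+// Note on the division above: both operands are pre-scaled by 1/max(|c|, |d|)
+// so the intermediate products and the squared denominator stay representable
+// (avoiding spurious overflow/underflow); the quotient itself is unchanged
+// because numerator and denominator are scaled by the same factor.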
+
+// reciprocal. Implement this here so we can use multiplication.
+inline Vectorized> Vectorized>::reciprocal() const{
+  //re + im*i = (a + bi)  / (c + di)
+  //re = (ac + bd)/abs_2() = c/abs_2()
+  //im = (bc - ad)/abs_2() = -d/abs_2()
+  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+  auto c_d = _mm256_xor_pd(sign_mask, values);    //c       -d
+  return _mm256_div_pd(c_d, abs_2_());
+}
+
+inline Vectorized> Vectorized>::atan() const {
+  // atan(x) = i/2 * ln((i + z)/(i - z))
+  const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0);
+  const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5);
+
+  auto sum = Vectorized(_mm256_add_pd(i, values));                      // a        1+b
+  auto sub = Vectorized(_mm256_sub_pd(i, values));                      // -a       1-b
+  auto ln = (sum/sub).log();                                        // ln((i + z)/(i - z))
+  return i_half*ln;                                                 // i/2*ln()
+}
+
+template <>
+Vectorized> inline maximum(const Vectorized>& a, const Vectorized>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ);
+  auto max = _mm256_blendv_pd(a, b, mask);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
+  return _mm256_or_pd(max, isnan);
+}
+
+template <>
+Vectorized> inline minimum(const Vectorized>& a, const Vectorized>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ);
+  auto min = _mm256_blendv_pd(a, b, mask);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
+  return _mm256_or_pd(min, isnan);
+}
+
+template <>
+Vectorized> inline operator&(const Vectorized>& a, const Vectorized>& b) {
+  return _mm256_and_pd(a, b);
+}
+
+template <>
+Vectorized> inline operator|(const Vectorized>& a, const Vectorized>& b) {
+  return _mm256_or_pd(a, b);
+}
+
+template <>
+Vectorized> inline operator^(const Vectorized>& a, const Vectorized>& b) {
+  return _mm256_xor_pd(a, b);
+}
+
+inline Vectorized> Vectorized>::eq(const Vectorized>& other) const {
+  auto eq = (*this == other);  // compares real and imag individually
+  // If both real numbers and imag numbers are equal, then the complex numbers are equal
+  return (eq.real() & eq.imag()) & Vectorized>(_mm256_set1_pd(1.0));
+}
+
+inline Vectorized> Vectorized>::ne(const Vectorized>& other) const {
+  auto ne = (*this != other);  // compares real and imag individually
+  // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+  return (ne.real() | ne.imag()) & Vectorized>(_mm256_set1_pd(1.0));
+}
+
+#endif
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h
new file mode 100644
index 0000000000000000000000000000000000000000..be44f3e94ad6c74e7f645346ac8bfc72d0441673
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -0,0 +1,468 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <c10/util/complex.h>
+#include <c10/util/irange.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+template <> class Vectorized<c10::complex<float>> {
+private:
+  __m256 values;
+public:
+  using value_type = c10::complex<float>;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+  Vectorized(__m256 v) : values(v) {}
+  Vectorized(c10::complex<float> val) {
+    float real_value = val.real();
+    float imag_value = val.imag();
+    values = _mm256_setr_ps(real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value
+                            );
+  }
+  Vectorized(c10::complex<float> val1, c10::complex<float> val2, c10::complex<float> val3, c10::complex<float> val4) {
+    values = _mm256_setr_ps(val1.real(), val1.imag(),
+                            val2.real(), val2.imag(),
+                            val3.real(), val3.imag(),
+                            val4.real(), val4.imag()
+                            );
+  }
+  operator __m256() const {
+    return values;
+  }
+  template 
+  static Vectorized> blend(const Vectorized>& a, const Vectorized>& b) {
+     // convert c10::complex index mask to V index mask: xy -> xxyy
+    static_assert(mask > -1 && mask < 16, "Unexpected mask range");
+    switch (mask) {
+      case 0:
+        return a;
+      case 1:
+        return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011
+      case 2:
+        return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100
+      case 3:
+        return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111
+      case 4:
+        return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000
+      case 5:
+        return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011
+      case 6:
+        return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100
+      case 7:
+        return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111
+      case 8:
+        return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000
+      case 9:
+        return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011
+      case 10:
+        return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100
+      case 11:
+        return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111
+      case 12:
+        return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000
+      case 13:
+        return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011
+      case 14:
+        return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100
+      default: break;
+    }
+    return b;
+  }
+  static Vectorized> blendv(const Vectorized>& a, const Vectorized>& b,
+                               const Vectorized>& mask) {
+    // convert c10::complex index mask to V index mask: xy -> xxyy
+    auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values);
+    return _mm256_blendv_ps(a.values, b.values, mask_);
+
+  }
+  template
+  static Vectorized> arange(c10::complex base = 0., step_t step = static_cast(1)) {
+    return Vectorized>(base,
+                                        base + step,
+                                        base + c10::complex(2)*step,
+                                        base + c10::complex(3)*step);
+  }
+  static Vectorized> set(const Vectorized>& a, const Vectorized>& b,
+                            int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+    return b;
+  }
+  static Vectorized> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm256_loadu_ps(reinterpret_cast(ptr));
+
+    __at_align__ float tmp_values[2*size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(2*size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values,
+        reinterpret_cast<const float*>(ptr),
+        count * sizeof(c10::complex<float>));
+    return _mm256_load_ps(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
+    } else if (count > 0) {
+      float tmp_values[2*size()];
+      _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<float>));
+    }
+  }
+  const c10::complex<float>& operator[](int idx) const  = delete;
+  c10::complex<float>& operator[](int idx) = delete;
+  Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
+    __at_align__ c10::complex<float> tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  __m256 abs_2_() const {
+    auto val_2 = _mm256_mul_ps(values, values);     // a*a     b*b
+    auto ret = _mm256_hadd_ps(val_2, val_2);        // a*a+b*b a*a+b*b
+    return _mm256_permute_ps(ret, 0xD8);
+  }
+  __m256 abs_() const {
+    auto real = _mm256_moveldup_ps(values);   // real real
+    auto imag = _mm256_movehdup_ps(values);   // imag imag
+    return Sleef_hypotf8_u05(real, imag);     // abs  abs
+  }
+  Vectorized> abs() const {
+    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    return _mm256_and_ps(abs_(), real_mask);        // abs     0
+  }
+  __m256 angle_() const {
+    //angle = atan2(b/a)
+    auto b_a = _mm256_permute_ps(values, 0xB1);     // b        a
+    return Sleef_atan2f8_u10(values, b_a);          // 90-angle angle
+  }
+  Vectorized> angle() const {
+    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle    90-angle
+    return _mm256_and_ps(angle, real_mask);         // angle    0
+  }
+  Vectorized> sgn() const {
+    auto abs = abs_();
+    auto zero = _mm256_setzero_ps();
+    auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
+    auto div = values / abs;
+    return _mm256_blendv_ps(div, zero, mask);
+  }
+  __m256 real_() const {
+    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    return _mm256_and_ps(values, real_mask);
+  }
+  Vectorized> real() const {
+    return real_();
+  }
+  __m256 imag_() const {
+    const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
+                                                                   0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF));
+    return _mm256_and_ps(values, imag_mask);
+  }
+  Vectorized> imag() const {
+    return _mm256_permute_ps(imag_(), 0xB1);        //b        a
+  }
+  __m256 conj_() const {
+    const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+    return _mm256_xor_ps(values, sign_mask);        // a       -b
+  }
+  Vectorized> conj() const {
+    return conj_();
+  }
+  Vectorized<c10::complex<float>> log() const {
+    // Most trigonometric ops use the log() op to improve complex number performance.
+    return map(std::log);
+  }
+  Vectorized> log2() const {
+    const __m256 log2_ = _mm256_set1_ps(std::log(2));
+    return _mm256_div_ps(log(), log2_);
+  }
+  Vectorized> log10() const {
+    const __m256 log10_ = _mm256_set1_ps(std::log(10));
+    return _mm256_div_ps(log(), log10_);
+  }
+  Vectorized> log1p() const {
+    return map(std::log1p);
+  }
+  Vectorized> asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+    const __m256 one = _mm256_set1_ps(1);
+
+    auto conj = conj_();
+    auto b_a = _mm256_permute_ps(conj, 0xB1);                         //-b        a
+    auto ab = _mm256_mul_ps(conj, b_a);                               //-ab       -ab
+    auto im = _mm256_add_ps(ab, ab);                                  //-2ab      -2ab
+
+    auto val_2 = _mm256_mul_ps(values, values);                       // a*a      b*b
+    auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1));  // a*a-b*b  b*b-a*a
+    re = _mm256_permute_ps(re, 0xD8);
+    re = _mm256_sub_ps(one, re);
+
+    auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt();         //sqrt(re + i*im)
+    auto ln = Vectorized(_mm256_add_ps(b_a, root)).log();                 //ln(iz + sqrt())
+    return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj();         //-i*ln()
+  }
+  Vectorized> acos() const {
+    return map(std::acos);
+  }
+  Vectorized> atan() const;
+  Vectorized> atanh() const {
+    return map(std::atanh);
+  }
+  Vectorized> exp() const {
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expf8_u10(values);                               //exp(a)           exp(b)
+    exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA);   //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosf8_u10(values);                        //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1),
+                                   sin_cos.x, 0xAA);                  //cos(b)           sin(b)
+    return _mm256_mul_ps(exp, cos_sin);
+  }
+  Vectorized<c10::complex<float>> exp2() const {
+    // Use identity 2**x = exp(log(2) * x)
+    const __m256 ln_2 = _mm256_set1_ps(c10::ln_2<float>);
+    Vectorized<c10::complex<float>> scaled_values = _mm256_mul_ps(values, ln_2);
+    return scaled_values.exp();
+  }
+  Vectorized> expm1() const {
+    return map(std::expm1);
+  }
+  Vectorized> sin() const {
+    return map(std::sin);
+  }
+  Vectorized> sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized> cos() const {
+    return map(std::cos);
+  }
+  Vectorized> cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized> ceil() const {
+    return _mm256_ceil_ps(values);
+  }
+  Vectorized> floor() const {
+    return _mm256_floor_ps(values);
+  }
+  Vectorized> neg() const {
+    auto zero = _mm256_setzero_ps();
+    return _mm256_sub_ps(zero, values);
+  }
+  Vectorized> round() const {
+    return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> tan() const {
+    return map(std::tan);
+  }
+  Vectorized> tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized> trunc() const {
+    return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> sqrt() const {
+    return map(std::sqrt);
+  }
+  Vectorized> reciprocal() const;
+  Vectorized> rsqrt() const {
+    return sqrt().reciprocal();
+  }
+  Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
+    __at_align__ c10::complex<float> x_tmp[size()];
+    __at_align__ c10::complex<float> y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized<c10::complex<float>> operator==(const Vectorized<c10::complex<float>>& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
+  }
+  Vectorized<c10::complex<float>> operator!=(const Vectorized<c10::complex<float>>& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
+  }
+  Vectorized<c10::complex<float>> operator<(const Vectorized<c10::complex<float>>& /*other*/) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<float>> operator<=(const Vectorized<c10::complex<float>>& /*other*/) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<float>> operator>(const Vectorized<c10::complex<float>>& /*other*/) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<float>> operator>=(const Vectorized<c10::complex<float>>& /*other*/) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized<c10::complex<float>> eq(const Vectorized<c10::complex<float>>& other) const;
+  Vectorized<c10::complex<float>> ne(const Vectorized<c10::complex<float>>& other) const;
+};
+
+template <> Vectorized<c10::complex<float>> inline operator+(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
+  return _mm256_add_ps(a, b);
+}
+
+template <> Vectorized<c10::complex<float>> inline operator-(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
+  return _mm256_sub_ps(a, b);
+}
+
+template <> Vectorized<c10::complex<float>> inline operator*(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
+  //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+  const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto ac_bd = _mm256_mul_ps(a, b);         //ac       bd
+
+  auto d_c = _mm256_permute_ps(b, 0xB1);    //d        c
+  d_c = _mm256_xor_ps(sign_mask, d_c);      //d       -c
+  auto ad_bc = _mm256_mul_ps(a, d_c);       //ad      -bc
+
+  auto ret = _mm256_hsub_ps(ac_bd, ad_bc);  //ac - bd  ad + bc
+  ret = _mm256_permute_ps(ret, 0xD8);
+  return ret;
+}
+
+template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
+  //re + im*i = (a + bi)  / (c + di)
+  auto mask = _mm256_set1_ps(-0.f);
+  auto fabs_cd = _mm256_andnot_ps(mask, b);     // |c|    |d|
+  auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1);   // |d|    |c|
+  auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc));  // 1/sc     1/sc
+  auto a2 = _mm256_mul_ps(a, scale);         // a/sc     b/sc
+  auto b2 = _mm256_mul_ps(b, scale);         // c/sc     d/sc
+  auto acbd2 = _mm256_mul_ps(a2, b2);
+
+  const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
+  auto dc2 = _mm256_permute_ps(b2, 0xB1);    // d/sc         c/sc
+  dc2 = _mm256_xor_ps(sign_mask, dc2);       // -d/|c,d|        c/sc
+  auto adbc2 = _mm256_mul_ps(a2, dc2);       //-ad/sc^2      bc/sc^2
+  auto res2 = _mm256_hadd_ps(acbd2, adbc2);  //(ac+bd)/sc^2  (bc-ad)/sc^2
+  res2 = _mm256_permute_ps(res2, 0xD8);
+
+  // get the denominator
+  auto denom2 = Vectorized<c10::complex<float>>(b2).abs_2_();  // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
+  res2 = _mm256_div_ps(res2, denom2);
+  return res2;
+}
+
+// reciprocal. Implement this here so we can use multiplication.
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
+  //re + im*i = (a + bi)  / (c + di)
+  //re = (ac + bd)/abs_2() = c/abs_2()
+  //im = (bc - ad)/abs_2() = d/abs_2()
+  const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto c_d = _mm256_xor_ps(sign_mask, values);    //c       -d
+  return _mm256_div_ps(c_d, abs_2_());
+}
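+
+// Illustrative check (hypothetical helper, not part of the upstream header):
+// the vectorized reciprocal above should agree, lane by lane, with the scalar
+// identity 1/(c + di) = (c - di) / (c*c + d*d).
+inline float vec256_complex_reciprocal_max_error_example(const Vectorized<c10::complex<float>>& x) {
+  constexpr int N = Vectorized<c10::complex<float>>::size();
+  __at_align__ c10::complex<float> in[N];
+  __at_align__ c10::complex<float> out[N];
+  x.store(in);
+  x.reciprocal().store(out);
+  float max_err = 0.f;
+  for (const auto i : c10::irange(N)) {
+    const float c = in[i].real();
+    const float d = in[i].imag();
+    const float denom = c * c + d * d;
+    const c10::complex<float> expected(c / denom, -d / denom);
+    float err_re = out[i].real() - expected.real();
+    float err_im = out[i].imag() - expected.imag();
+    if (err_re < 0) err_re = -err_re;
+    if (err_im < 0) err_im = -err_im;
+    if (err_re > max_err) max_err = err_re;
+    if (err_im > max_err) max_err = err_im;
+  }
+  return max_err;
+}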
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
+  // atan(x) = i/2 * ln((i + z)/(i - z))
+  const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
+  const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
+
+  auto sum = Vectorized(_mm256_add_ps(i, values));                      // a        1+b
+  auto sub = Vectorized(_mm256_sub_ps(i, values));                      // -a       1-b
+  auto ln = (sum/sub).log();                                        // ln((i + z)/(i - z))
+  return i_half*ln;                                                 // i/2*ln()
+}
+
+template <>
+Vectorized<c10::complex<float>> inline maximum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
+  auto max = _mm256_blendv_ps(a, b, mask);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  return _mm256_or_ps(max, isnan);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline minimum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
+  auto min = _mm256_blendv_ps(a, b, mask);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  return _mm256_or_ps(min, isnan);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline operator&(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
+  return _mm256_and_ps(a, b);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline operator|(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
+  return _mm256_or_ps(a, b);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
+  return _mm256_xor_ps(a, b);
+}
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
+    const Vectorized<c10::complex<float>>& other) const {
+  auto eq = (*this == other);  // compares real and imag individually
+  // If both real numbers and imag numbers are equal, then the complex numbers are equal
+  return (eq.real() & eq.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
+}
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
+    const Vectorized<c10::complex<float>>& other) const {
+  auto ne = (*this != other);  // compares real and imag individually
+  // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+  return (ne.real() | ne.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
+}
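+
+// Illustrative usage sketch (hypothetical helper, not part of the upstream
+// header): eq() above returns 1.0f + 0.0i in lanes where both the real and
+// imaginary parts match, and 0 elsewhere, so matching lanes can be counted
+// by summing the real parts.
+inline int vec256_complex_count_equal_lanes_example(
+    const Vectorized<c10::complex<float>>& a,
+    const Vectorized<c10::complex<float>>& b) {
+  __at_align__ c10::complex<float> lanes[Vectorized<c10::complex<float>>::size()];
+  a.eq(b).store(lanes);
+  int count = 0;
+  for (const auto i : c10::irange(Vectorized<c10::complex<float>>::size())) {
+    count += static_cast<int>(lanes[i].real());
+  }
+  return count;
+}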
+
+#endif
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h
new file mode 100644
index 0000000000000000000000000000000000000000..328e33a79a4e2a89d1044144280c9625321fab9a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h
@@ -0,0 +1,442 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+template <> class Vectorized<double> {
+private:
+  __m256d values;
+public:
+  using value_type = double;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+  Vectorized(__m256d v) : values(v) {}
+  Vectorized(double val) {
+    values = _mm256_set1_pd(val);
+  }
+  Vectorized(double val1, double val2, double val3, double val4) {
+    values = _mm256_setr_pd(val1, val2, val3, val4);
+  }
+  operator __m256d() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
+    return _mm256_blend_pd(a.values, b.values, mask);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                               const Vectorized& mask) {
+    return _mm256_blendv_pd(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
+    return Vectorized(base, base + step, base + 2 * step, base + 3 * step);
+  }
+  static Vectorized set(const Vectorized& a, const Vectorized& b,
+                            int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
+
+
+    __at_align__ double tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values,
+        reinterpret_cast<const double*>(ptr),
+        count * sizeof(double));
+    return _mm256_load_pd(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
+    } else if (count > 0) {
+      double tmp_values[size()];
+      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(double));
+    }
+  }
+  const double& operator[](int idx) const  = delete;
+  double& operator[](int idx) = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ);
+    return _mm256_movemask_pd(cmp);
+  }
+  Vectorized isnan() const {
+    return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
+  }
+  bool has_inf_nan() const {
+    __m256d self_sub  = _mm256_sub_pd(values, values);
+    return (_mm256_movemask_epi8(_mm256_castpd_si256(self_sub)) & 0x77777777) != 0;
+  }
+  Vectorized map(double (*const f)(double)) const {
+    __at_align__ double tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized abs() const {
+    auto mask = _mm256_set1_pd(-0.f);
+    return _mm256_andnot_pd(mask, values);
+  }
+  Vectorized angle() const {
+    const auto zero_vec = _mm256_set1_pd(0.f);
+    const auto nan_vec = _mm256_set1_pd(NAN);
+    const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
+    const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
+    const auto pi = _mm256_set1_pd(c10::pi<double>);
+
+    const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
+    auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
+    angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
+    return angle;
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_pd(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return Vectorized(Sleef_acosd4_u10(values));
+  }
+  Vectorized acosh() const {
+    return Vectorized(Sleef_acoshd4_u10(values));
+  }
+  Vectorized asin() const {
+    return Vectorized(Sleef_asind4_u10(values));
+  }
+  Vectorized atan() const {
+    return Vectorized(Sleef_atand4_u10(values));
+  }
+  Vectorized atanh() const {
+    return Vectorized(Sleef_atanhd4_u10(values));
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    return Vectorized(Sleef_atan2d4_u10(values, b));
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return Vectorized(Sleef_copysignd4(values, sign));
+  }
+  Vectorized erf() const {
+    return Vectorized(Sleef_erfd4_u10(values));
+  }
+  Vectorized erfc() const {
+    return Vectorized(Sleef_erfcd4_u15(values));
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return Vectorized(Sleef_expd4_u10(values));
+  }
+  Vectorized exp2() const {
+    return Vectorized(Sleef_exp2d4_u10(values));
+  }
+  Vectorized expm1() const {
+    return Vectorized(Sleef_expm1d4_u10(values));
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+  Vectorized fmod(const Vectorized& q) const {
+    return Vectorized(Sleef_fmodd4(values, q));
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    return Vectorized(Sleef_hypotd4_u05(values, b));
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized log() const {
+    return Vectorized(Sleef_logd4_u10(values));
+  }
+  Vectorized log2() const {
+    return Vectorized(Sleef_log2d4_u10(values));
+  }
+  Vectorized log10() const {
+    return Vectorized(Sleef_log10d4_u10(values));
+  }
+  Vectorized log1p() const {
+    return Vectorized(Sleef_log1pd4_u10(values));
+  }
+  Vectorized sin() const {
+    return Vectorized(Sleef_sind4_u10(values));
+  }
+  Vectorized sinh() const {
+    return Vectorized(Sleef_sinhd4_u10(values));
+  }
+  Vectorized cos() const {
+    return Vectorized(Sleef_cosd4_u10(values));
+  }
+  Vectorized cosh() const {
+    return Vectorized(Sleef_coshd4_u10(values));
+  }
+  Vectorized ceil() const {
+    return _mm256_ceil_pd(values);
+  }
+  Vectorized floor() const {
+    return _mm256_floor_pd(values);
+  }
+  Vectorized frac() const;
+  Vectorized neg() const {
+    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    return Vectorized(Sleef_nextafterd4(values, b));
+  }
+  Vectorized round() const {
+    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized tan() const {
+    return Vectorized(Sleef_tand4_u10(values));
+  }
+  Vectorized tanh() const {
+    return Vectorized(Sleef_tanhd4_u10(values));
+  }
+  Vectorized trunc() const {
+    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized lgamma() const {
+    return Vectorized(Sleef_lgammad4_u10(values));
+  }
+  Vectorized sqrt() const {
+    return _mm256_sqrt_pd(values);
+  }
+  Vectorized reciprocal() const {
+    return _mm256_div_pd(_mm256_set1_pd(1), values);
+  }
+  Vectorized rsqrt() const {
+    return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
+  }
+  Vectorized pow(const Vectorized &b) const {
+    return Vectorized(Sleef_powd4_u10(values, b));
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
+  }
+
+  Vectorized operator!=(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+};
+
+template <>
+Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_add_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_sub_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_mul_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_div_pd(a, b);
+}
+
+// frac. Implement this here so we can use subtraction.
+inline Vectorized<double> Vectorized<double>::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
+  Vectorized<double> max = _mm256_max_pd(a, b);
+  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  return _mm256_or_pd(max, isnan);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
+  Vectorized<double> min = _mm256_min_pd(a, b);
+  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  return _mm256_or_pd(min, isnan);
+}
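+
+// Illustrative sketch (hypothetical helper, not part of the upstream header):
+// unlike a bare _mm256_max_pd, maximum() above propagates NaN, so a NaN in
+// either input shows up in the corresponding result lane. The NAN macro is
+// already relied upon by angle() above.
+inline bool vec256_double_maximum_propagates_nan_example() {
+  Vectorized<double> a(1.0, NAN, 3.0, 4.0);
+  Vectorized<double> b(0.0);
+  // isnan() yields an all-ones lane exactly where the result is NaN, and
+  // zero_mask() then reports the lanes that are *not* NaN, so bit 1 must be
+  // missing from the mask.
+  return (maximum(a, b).isnan().zero_mask() & 0x2) == 0;
+}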
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) {
+  return _mm256_min_pd(max, _mm256_max_pd(min, a));
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) {
+  return _mm256_max_pd(min, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) {
+  return _mm256_min_pd(max, a);
+}
+
+template <>
+Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_and_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_or_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm256_xor_pd(a, b);
+}
+
+inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
+  return (*this == other) & Vectorized(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
+  return (*this != other) & Vectorized(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
+  return (*this > other) & Vectorized(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
+  return (*this >= other) & Vectorized(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
+  return (*this < other) & Vectorized(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
+  return (*this <= other) & Vectorized(1.0);
+}
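+
+// Illustrative usage sketch (hypothetical helper, not part of the upstream
+// header): in contrast to operator==, which produces all-ones bit masks,
+// eq() above produces 1.0/0.0 lanes, so equal lanes can be counted with a
+// plain horizontal sum.
+inline double vec256_double_count_equal_lanes_example(const Vectorized<double>& a,
+                                                      const Vectorized<double>& b) {
+  __at_align__ double lanes[Vectorized<double>::size()];
+  a.eq(b).store(lanes);
+  double count = 0.0;
+  for (const auto i : c10::irange(Vectorized<double>::size())) {
+    count += lanes[i];
+  }
+  return count;
+}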
+
+template <>
+inline void convert(const double* src, double* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+#ifdef CPU_CAPABILITY_AVX2
+template <>
+Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
+  return _mm256_fmadd_pd(a, b, c);
+}
+
+template <>
+Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
+  return _mm256_fmsub_pd(a, b, c);
+}
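+
+// Illustrative sketch (hypothetical helper, not part of the upstream header):
+// a simple "y += a * x" loop showing how fmadd() above is combined with
+// loadu()/store() and their explicit element counts to handle the tail.
+inline void vec256_double_axpy_example(double a, const double* x, double* y, int64_t n) {
+  const Vectorized<double> va(a);
+  int64_t i = 0;
+  for (; i + Vectorized<double>::size() <= n; i += Vectorized<double>::size()) {
+    fmadd(va, Vectorized<double>::loadu(x + i), Vectorized<double>::loadu(y + i)).store(y + i);
+  }
+  if (i < n) {
+    auto vx = Vectorized<double>::loadu(x + i, n - i);
+    auto vy = Vectorized<double>::loadu(y + i, n - i);
+    fmadd(va, vx, vy).store(y + i, n - i);
+  }
+}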
+#endif
+
+#endif
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h
new file mode 100644
index 0000000000000000000000000000000000000000..2b372f294f9036e1a7c5a1915cb37f24f7b645fd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h
@@ -0,0 +1,636 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+template <> class Vectorized<float> {
+private:
+  __m256 values;
+public:
+  using value_type = float;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+  Vectorized(__m256 v) : values(v) {}
+  Vectorized(float val) {
+    values = _mm256_set1_ps(val);
+  }
+  Vectorized(float val1, float val2, float val3, float val4,
+         float val5, float val6, float val7, float val8) {
+    values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8);
+  }
+  operator __m256() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
+    return _mm256_blend_ps(a.values, b.values, mask);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                              const Vectorized& mask) {
+    return _mm256_blendv_ps(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
+    return Vectorized(
+      base,            base +     step, base + 2 * step, base + 3 * step,
+      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
+  }
+  static Vectorized set(const Vectorized& a, const Vectorized& b,
+                           int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
+    __at_align__ float tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
+    return _mm256_loadu_ps(tmp_values);
+  }
+  void store(void* ptr, int64_t count = size()) const {
+    if (count == size()) {
+      _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
+    } else if (count > 0) {
+      float tmp_values[size()];
+      _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(float));
+    }
+  }
+  const float& operator[](int idx) const  = delete;
+  float& operator[](int idx) = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    __m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
+    return _mm256_movemask_ps(cmp);
+  }
+  Vectorized isnan() const {
+    return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
+  }
+
+  bool has_inf_nan() const {
+    __m256 self_sub  = _mm256_sub_ps(values, values);
+    return (_mm256_movemask_epi8(_mm256_castps_si256(self_sub)) & 0x77777777) != 0;
+  }
+
+  Vectorized map(float (*const f)(float)) const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized abs() const {
+    auto mask = _mm256_set1_ps(-0.f);
+    return _mm256_andnot_ps(mask, values);
+  }
+  Vectorized angle() const {
+    const auto zero_vec = _mm256_set1_ps(0.f);
+    const auto nan_vec = _mm256_set1_ps(NAN);
+    const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
+    const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
+    const auto pi = _mm256_set1_ps(c10::pi<float>);
+
+    const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
+    auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
+    angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
+    return angle;
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_ps(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return Vectorized(Sleef_acosf8_u10(values));
+  }
+  Vectorized acosh() const {
+    return Vectorized(Sleef_acoshf8_u10(values));
+  }
+  Vectorized asin() const {
+    return Vectorized(Sleef_asinf8_u10(values));
+  }
+  Vectorized atan() const {
+    return Vectorized(Sleef_atanf8_u10(values));
+  }
+  Vectorized atanh() const {
+    return Vectorized(Sleef_atanhf8_u10(values));
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    return Vectorized(Sleef_atan2f8_u10(values, b));
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return Vectorized(Sleef_copysignf8(values, sign));
+  }
+  Vectorized erf() const {
+    // constants
+    const auto neg_zero_vec = _mm256_set1_ps(-0.f);
+    const auto one_vec = _mm256_set1_ps(1.0f);
+    const auto p = _mm256_set1_ps(0.3275911f);
+    const auto p1 = _mm256_set1_ps(0.254829592f);
+    const auto p2 = _mm256_set1_ps(-0.284496736f);
+    const auto p3 = _mm256_set1_ps(1.421413741f);
+    const auto p4 = _mm256_set1_ps(-1.453152027f);
+    const auto p5 = _mm256_set1_ps(1.061405429f);
+    // sign(x)
+    auto sign_mask = _mm256_and_ps(neg_zero_vec, values);
+    auto abs_vec = _mm256_xor_ps(sign_mask, values);
+    // t = 1 / (p * abs(x) + 1)
+    auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec);
+    auto t = _mm256_div_ps(one_vec, tmp0);
+    // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
+    auto tmp1 = _mm256_fmadd_ps(p5, t, p4);
+    auto tmp2 = _mm256_fmadd_ps(tmp1, t, p3);
+    auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2);
+    auto r = _mm256_fmadd_ps(tmp3, t, p1);
+    // - exp(- x * x)
+    auto pow_2 = _mm256_mul_ps(values, values);
+    auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2);
+    // auto tmp4 = exp(neg_pow_2);
+    auto tmp4 = Vectorized(Sleef_expf8_u10(neg_pow_2));
+    auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4);
+    // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
+    auto tmp6 = _mm256_mul_ps(tmp5, t);
+    auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec);
+    return _mm256_xor_ps(sign_mask, tmp7);
+  }
+  Vectorized erfc() const {
+    return Vectorized(Sleef_erfcf8_u15(values));
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return Vectorized(Sleef_expf8_u10(values));
+  }
+  Vectorized exp2() const {
+    return Vectorized(Sleef_exp2f8_u10(values));
+  }
+  Vectorized expm1() const {
+    return Vectorized(Sleef_expm1f8_u10(values));
+  }
+  Vectorized exp_u20() const {
+    // A faster version of exp with ULP=20
+    static __m256 vec_factorial_1 =
+        _mm256_set1_ps(0.999999701f); // 1/factorial(1)
+    static __m256 vec_factorial_2 =
+        _mm256_set1_ps(0.499991506f); // 1/factorial(2)
+    static __m256 vec_factorial_3 =
+        _mm256_set1_ps(0.166676521f); // 1/factorial(3)
+    static __m256 vec_factorial_4 =
+        _mm256_set1_ps(0.0418978221f); // 1/factorial(4)
+    static __m256 vec_factorial_5 =
+        _mm256_set1_ps(0.00828929059f); // 1/factorial(5)
+    static __m256 vec_exp_log2ef =
+        (__m256)_mm256_set1_epi32(0x3fb8aa3b); // log2(e)
+    static __m256 vec_half = _mm256_set1_ps(0.5f);
+    static __m256 vec_one = _mm256_set1_ps(1.f);
+    static __m256 vec_zero = _mm256_set1_ps(0.f);
+    static __m256 vec_two = _mm256_set1_ps(2.f);
+    static __m256 vec_ln2f = (__m256)_mm256_set1_epi32(0x3f317218); // ln(2)
+    static __m256 vec_ln_flt_min = (__m256)_mm256_set1_epi32(0xc2aeac50);
+    static __m256 vec_ln_flt_max = (__m256)_mm256_set1_epi32(0x42b17218);
+    static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
+    static int n_mantissa_bits = 23;
+
+    // exp(x) =
+    // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
+    // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression
+
+    auto less_ln_flt_min_mask =
+        _mm256_cmp_ps(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/);
+    auto vec_src = _mm256_min_ps(values, vec_ln_flt_max);
+    vec_src = _mm256_max_ps(vec_src, vec_ln_flt_min);
+
+    // fx = floorf(x * log2ef + 0.5)
+    auto vec_fx = _mm256_fmadd_ps(vec_src, vec_exp_log2ef, vec_half);
+    vec_fx = _mm256_floor_ps(vec_fx);
+
+    // x = x - fx * ln2
+    auto vec_exp_poly = _mm256_fnmadd_ps(vec_fx, vec_ln2f, vec_src);
+
+    // compute polynomial
+    auto vec_res =
+        _mm256_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4);
+    vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3);
+    vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2);
+    vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1);
+    vec_res = _mm256_fmadd_ps(vec_exp_poly, vec_res, vec_one);
+
+    // compute 2^(n-1)
+    auto vec_exp_number = _mm256_sub_ps(vec_fx, vec_one);
+    auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
+    auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
+    vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
+    auto vec_two_pow_n = (__m256)vec_two_pow_n_i;
+    vec_two_pow_n =
+        _mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);
+
+    // y = y * 2^n
+    vec_res = _mm256_mul_ps(vec_res, vec_two_pow_n);
+    vec_res = _mm256_mul_ps(vec_res, vec_two);
+    return vec_res;
+  }
+  Vectorized fmod(const Vectorized& q) const {
+    return Vectorized(Sleef_fmodf8(values, q));
+  }
+  Vectorized log() const {
+    return Vectorized(Sleef_logf8_u10(values));
+  }
+  Vectorized log2() const {
+    return Vectorized(Sleef_log2f8_u10(values));
+  }
+  Vectorized log10() const {
+    return Vectorized(Sleef_log10f8_u10(values));
+  }
+  Vectorized log1p() const {
+    return Vectorized(Sleef_log1pf8_u10(values));
+  }
+  Vectorized frac() const;
+  Vectorized sin() const {
+    return Vectorized(Sleef_sinf8_u35(values));
+  }
+  Vectorized sinh() const {
+    return Vectorized(Sleef_sinhf8_u10(values));
+  }
+  Vectorized cos() const {
+    return Vectorized(Sleef_cosf8_u35(values));
+  }
+  Vectorized cosh() const {
+    return Vectorized(Sleef_coshf8_u10(values));
+  }
+  Vectorized ceil() const {
+    return _mm256_ceil_ps(values);
+  }
+  Vectorized floor() const {
+    return _mm256_floor_ps(values);
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    return Vectorized(Sleef_hypotf8_u05(values, b));
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized neg() const {
+    return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    return Vectorized(Sleef_nextafterf8(values, b));
+  }
+  Vectorized round() const {
+    return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized tan() const {
+    return Vectorized(Sleef_tanf8_u10(values));
+  }
+  Vectorized tanh() const {
+    return Vectorized(Sleef_tanhf8_u10(values));
+  }
+  Vectorized trunc() const {
+    return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized lgamma() const {
+    return Vectorized(Sleef_lgammaf8_u10(values));
+  }
+  Vectorized sqrt() const {
+    return _mm256_sqrt_ps(values);
+  }
+  Vectorized reciprocal() const {
+    return _mm256_div_ps(_mm256_set1_ps(1), values);
+  }
+  Vectorized rsqrt() const {
+    return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
+  }
+  Vectorized pow(const Vectorized &b) const {
+    return Vectorized(Sleef_powf8_u10(values, b));
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
+  }
+
+  Vectorized operator!=(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ);
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ);
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ);
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ);
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_add_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_sub_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_mul_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_div_ps(a, b);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized<float> Vectorized<float>::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  Vectorized<float> max = _mm256_max_ps(a, b);
+  Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  return _mm256_or_ps(max, isnan);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  Vectorized<float> min = _mm256_min_ps(a, b);
+  Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
+  // Exploit the fact that all-ones is a NaN.
+  return _mm256_or_ps(min, isnan);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) {
+  return _mm256_min_ps(max, _mm256_max_ps(min, a));
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) {
+  return _mm256_min_ps(max, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) {
+  return _mm256_max_ps(min, a);
+}
+
+template <>
+Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_and_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_or_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm256_xor_ps(a, b);
+}
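+
+// Illustrative sketch (hypothetical helper, not part of the upstream header):
+// the bitwise operators above act on the raw float bit patterns, which is how
+// comparison masks (all-ones / all-zeros lanes) get combined. For example, a
+// branch-free select equivalent to blendv():
+inline Vectorized<float> vec256_float_select_example(const Vectorized<float>& mask,
+                                                     const Vectorized<float>& a,
+                                                     const Vectorized<float>& b) {
+  // all-ones lanes, built locally (no static data; see the note at the top of this file)
+  const Vectorized<float> all_set(_mm256_castsi256_ps(_mm256_set1_epi32(-1)));
+  // lanes where mask is all-ones take b, the remaining lanes take a
+  return (b & mask) | (a & (mask ^ all_set));
+}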
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1.0f);
+}
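+
+// Illustrative usage sketch (hypothetical helpers, not part of the upstream
+// header): ops without a SLEEF/AVX2 kernel (calc_digamma, calc_i0, ...) are
+// lifted through map(), which stores to a stack buffer, applies a scalar
+// function per lane, and reloads. The same pattern works for custom functions:
+inline float vec256_float_cube_scalar_example(float v) {
+  return v * v * v;
+}
+inline Vectorized<float> vec256_float_cube_example(const Vectorized<float>& x) {
+  return x.map(vec256_float_cube_scalar_example);
+}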
+
+template <>
+inline void convert(const float* src, float* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+
+template <>
+Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
+  return _mm256_fmadd_ps(a, b, c);
+}
+
+template <>
+Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
+  return _mm256_fmsub_ps(a, b, c);
+}
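+
+// Illustrative scalar sketch (not part of the upstream header) of the range
+// reduction behind Vectorized<float>::exp_u20() above:
+//   exp(x) = exp(n*ln2 + r) = 2^n * exp(r),  with n = floor(x*log2(e) + 0.5),
+// where exp(r) is a degree-5 polynomial and 2^n is assembled from the float
+// exponent bits. This assumes <cmath> is reachable through the includes above;
+// the kernel above uses slightly tuned constants instead of exact factorials.
+inline float vec256_float_exp_u20_scalar_reference_example(float x) {
+  const float ln2 = 0.6931471805599453f;
+  const float log2e = 1.4426950408889634f;
+  const float n = std::floor(x * log2e + 0.5f);
+  const float r = x - n * ln2;
+  // 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5!
+  const float p =
+      1.f + r * (1.f + r * (0.5f + r * (1.f / 6.f + r * (1.f / 24.f + r * (1.f / 120.f)))));
+  return std::ldexp(p, static_cast<int>(n));
+}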
+
+// Used by Inductor CPP codegen
+template<>
+inline void transpose_mxn<float, 8, 8>(
+    const float* src,
+    int64_t ld_src,
+    float* dst,
+    int64_t ld_dst) {
+  // load from src to registers
+  // a: a0  a1  a2  a3  a4  a5  a6  a7
+  // b: b0  b1  b2  b3  b4  b5  b6  b7
+  // c: c0  c1  c2  c3  c4  c5  c6  c7
+  // d: d0  d1  d2  d3  d4  d5  d6  d7
+  // e: e0  e1  e2  e3  e4  e5  e6  e7
+  // f: f0  f1  f2  f3  f4  f5  f6  f7
+  // g: g0  g1  g2  g3  g4  g5  g6  g7
+  // h: h0  h1  h2  h3  h4  h5  h6  h7
+  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
+  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
+  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
+  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
+  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
+  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
+  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
+  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
+
+  __m256 ta, tb, tc, td, te, tf, tg, th;
+  // unpacking and interleaving 32-bit elements
+  // a0  b0  a1  b1  a4  b4  a5  b5
+  // a2  b2  a3  b3  a6  b6  a7  b7
+  // c0  d0  c1  d1 ...
+  // c2  d2  c3  d3 ...
+  // e0  f0  e1  f1 ...
+  // e2  f2  e3  f3 ...
+  // g0  h0  g1  h1 ...
+  // g2  h2  g3  h3 ...
+  ta = _mm256_unpacklo_ps(a, b);
+  tb = _mm256_unpackhi_ps(a, b);
+  tc = _mm256_unpacklo_ps(c, d);
+  td = _mm256_unpackhi_ps(c, d);
+  te = _mm256_unpacklo_ps(e, f);
+  tf = _mm256_unpackhi_ps(e, f);
+  tg = _mm256_unpacklo_ps(g, h);
+  th = _mm256_unpackhi_ps(g, h);
+
+  // unpacking and interleaving 64-bit elements
+  //  a0  b0  c0  d0  a4  b4  c4  d4
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  e0  f0  g0  h0  e4  f4  g4  h4
+  //  e1  f1  g1  h1 ...
+  //  e2  f2  g2  h2 ...
+  //  e3  f3  g3  h3 ...
+  a = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
+  b = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
+  c = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
+  d = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
+  e = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
+  f = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
+  g = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
+  h = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  e0  f0  g0  h0
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  a4  b4  c4  d4 ...
+  //  a5  b5  c5  d5 ...
+  //  a6  b6  c6  d6 ...
+  //  a7  b7  c7  d7 ...
+  ta = _mm256_permute2f128_ps(a, e, 0x20);
+  tb = _mm256_permute2f128_ps(b, f, 0x20);
+  tc = _mm256_permute2f128_ps(c, g, 0x20);
+  td = _mm256_permute2f128_ps(d, h, 0x20);
+  te = _mm256_permute2f128_ps(a, e, 0x31);
+  tf = _mm256_permute2f128_ps(b, f, 0x31);
+  tg = _mm256_permute2f128_ps(c, g, 0x31);
+  th = _mm256_permute2f128_ps(d, h, 0x31);
+
+  // store from registers to dst
+  _mm256_storeu_ps(&dst[0 * ld_dst], ta);
+  _mm256_storeu_ps(&dst[1 * ld_dst], tb);
+  _mm256_storeu_ps(&dst[2 * ld_dst], tc);
+  _mm256_storeu_ps(&dst[3 * ld_dst], td);
+  _mm256_storeu_ps(&dst[4 * ld_dst], te);
+  _mm256_storeu_ps(&dst[5 * ld_dst], tf);
+  _mm256_storeu_ps(&dst[6 * ld_dst], tg);
+  _mm256_storeu_ps(&dst[7 * ld_dst], th);
+}
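+
+// Illustrative usage sketch (hypothetical helper and buffers, not part of the
+// upstream header): transposing one 8x8 tile of a row-major rows x cols matrix
+// `src` into its transposed counterpart `dst` (cols x rows) with the kernel
+// above, assuming the transpose_mxn<T, M, N> primary template declared in
+// vec_base.h. The leading dimensions are the row strides in elements.
+inline void vec256_float_transpose_tile_example(const float* src, float* dst,
+                                                int64_t rows, int64_t cols,
+                                                int64_t tile_row, int64_t tile_col) {
+  transpose_mxn<float, 8, 8>(
+      src + tile_row * cols + tile_col, /*ld_src=*/cols,
+      dst + tile_col * rows + tile_row, /*ld_dst=*/rows);
+}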
+
+#endif
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float_neon.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float_neon.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8f9a5e74500be63df8b04539a527013ab99f33f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float_neon.h
@@ -0,0 +1,892 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+
+#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
+#include <sleef.h>
+#endif
+
+// Sleef offers vectorized versions of some transcendentals
+// such as sin, cos, tan, etc.
+// However, for now we opt for the STL versions, since we are not
+// building with Sleef for mobile yet.
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+// Right now this contains only the aarch64 implementation.
+// aarch32 is not currently supported, for the following two reasons:
+// 1. Due to differences between the aarch32 and aarch64 ISAs, intrinsics
+//    that work for aarch64 don't work for aarch32.
+// 2. Android NDK r21 has problems compiling aarch32:
+//    Clang seg faults.
+//    https://github.com/android/ndk/issues/1248
+//    https://bugs.llvm.org/show_bug.cgi?id=45824
+// Most likely we will do aarch32 support with inline asm.
+#if defined(__aarch64__)
+
+#ifdef __BIG_ENDIAN__
+#error "Big endian is not supported."
+#endif
+
+#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
+#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
+#else
+#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
+#endif
+
+template<int index, bool mask_val>
+struct BlendRegs {
+  static float32x4_t impl(
+    const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
+};
+
+template<int index>
+struct BlendRegs<index, true>{
+  static float32x4_t impl(
+      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
+    return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
+  }
+};
+
+template<int index>
+struct BlendRegs<index, false>{
+  static float32x4_t impl(
+      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
+    return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
+  }
+};
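+
+// Illustrative sketch (hypothetical helper, not part of the upstream header):
+// BlendRegs<index, mask_val>::impl copies lane `index` of b (when mask_val is
+// true) or of a (when false) into `res`; chaining the four lanes is exactly
+// what Vectorized<float>::blend below does with the bits of its compile-time mask.
+inline float32x4_t vec256_neon_blend_low_half_example(const float32x4_t& a,
+                                                      const float32x4_t& b) {
+  float32x4_t res = a;
+  res = BlendRegs<0, true>::impl(a, b, res);   // lane 0 from b
+  res = BlendRegs<1, true>::impl(a, b, res);   // lane 1 from b
+  res = BlendRegs<2, false>::impl(a, b, res);  // lane 2 from a
+  res = BlendRegs<3, false>::impl(a, b, res);  // lane 3 from a
+  return res;
+}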
+
+template <> class Vectorized<float> {
+private:
+  float32x4x2_t values;
+public:
+  using value_type = float;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+  Vectorized(float32x4x2_t v) : values(v) {}
+  Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val) } {}
+  Vectorized(float val0, float val1, float val2, float val3,
+         float val4, float val5, float val6, float val7) :
+         values{val0, val1, val2, val3, val4, val5, val6, val7} {}
+  Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
+  operator float32x4x2_t() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized blend(const Vectorized& a, const Vectorized& b) {
+    Vectorized vec;
+    // 0.
+    vec.values.val[0] =
+      BlendRegs<0, (mask & 0x01)!=0>::impl(
+          a.values.val[0], b.values.val[0], vec.values.val[0]);
+    vec.values.val[0] =
+      BlendRegs<1, (mask & 0x02)!=0>::impl(
+          a.values.val[0], b.values.val[0], vec.values.val[0]);
+    vec.values.val[0] =
+      BlendRegs<2, (mask & 0x04)!=0>::impl(
+          a.values.val[0], b.values.val[0], vec.values.val[0]);
+    vec.values.val[0] =
+      BlendRegs<3, (mask & 0x08)!=0>::impl(
+          a.values.val[0], b.values.val[0], vec.values.val[0]);
+    // 1.
+    vec.values.val[1] =
+      BlendRegs<0, (mask & 0x10)!=0>::impl(
+          a.values.val[1], b.values.val[1], vec.values.val[1]);
+    vec.values.val[1] =
+      BlendRegs<1, (mask & 0x20)!=0>::impl(
+          a.values.val[1], b.values.val[1], vec.values.val[1]);
+    vec.values.val[1] =
+      BlendRegs<2, (mask & 0x40)!=0>::impl(
+          a.values.val[1], b.values.val[1], vec.values.val[1]);
+    vec.values.val[1] =
+      BlendRegs<3, (mask & 0x80)!=0>::impl(
+          a.values.val[1], b.values.val[1], vec.values.val[1]);
+    return vec;
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                              const Vectorized& mask) {
+    // TODO
+    // NB: This requires that each value, i.e., each uint value,
+    // of the mask either all be zeros or all be 1s.
+    // We perhaps need some kind of an assert?
+    // But that will affect performance.
+    Vectorized vec(mask.values);
+    vec.values.val[0] = vbslq_f32(
+        vreinterpretq_u32_f32(vec.values.val[0]),
+        b.values.val[0],
+        a.values.val[0]);
+    vec.values.val[1] = vbslq_f32(
+        vreinterpretq_u32_f32(vec.values.val[1]),
+        b.values.val[1],
+        a.values.val[1]);
+    return vec;
+  }
+  template <typename step_t>
+  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
+    const Vectorized base_vec(base);
+    const Vectorized step_vec(step);
+    const Vectorized step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
+    return fmadd(step_sizes, step_vec, base_vec);
+  }
+  static Vectorized set(const Vectorized& a, const Vectorized& b,
+                           int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
+          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
+          vec.values.val[1] = a.values.val[1];
+          vec.values.val[0] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[0]),
+              b.values.val[0],
+              a.values.val[0]);
+          return vec;
+        }
+      case 2:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
+          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
+          vec.values.val[1] = a.values.val[1];
+          vec.values.val[0] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[0]),
+              b.values.val[0],
+              a.values.val[0]);
+          return vec;
+        }
+      case 3:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
+          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
+          vec.values.val[1] = a.values.val[1];
+          vec.values.val[0] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[0]),
+              b.values.val[0],
+              a.values.val[0]);
+          return vec;
+        }
+      case 4:
+        return Vectorized(b.values.val[0], a.values.val[1]);
+      case 5:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
+          vec.values.val[0] = b.values.val[0];
+          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
+          vec.values.val[1] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[1]),
+              b.values.val[1],
+              a.values.val[1]);
+          return vec;
+        }
+      case 6:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
+          vec.values.val[0] = b.values.val[0];
+          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
+          vec.values.val[1] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[1]),
+              b.values.val[1],
+              a.values.val[1]);
+          return vec;
+        }
+      case 7:
+        {
+          Vectorized vec;
+          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
+          vec.values.val[0] = b.values.val[0];
+          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
+          vec.values.val[1] = vbslq_f32(
+              vreinterpretq_u32_f32(vec.values.val[1]),
+              b.values.val[1],
+              a.values.val[1]);
+          return vec;
+        }
+    }
+    return b;
+  }
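+  // Note on partial loads below: when count < size(), the tail lanes are
+  // zero-filled, either with vdupq_n_f32(0.f) for the exact half-vector case
+  // or via a zero-initialized stack buffer, so unused lanes hold deterministic
+  // values.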
+  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size()) {
+      return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
+    }
+    else if (count == (size() >> 1)) {
+      Vectorized<float> res;
+      res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
+      res.values.val[1] = vdupq_n_f32(0.f);
+      return res;
+    }
+    else {
+      __at_align__ float tmp_values[size()];
+      for (const auto i : c10::irange(size())) {
+        tmp_values[i] = 0.0;
+      }
+      std::memcpy(
+          tmp_values,
+          reinterpret_cast<const float*>(ptr),
+          count * sizeof(float));
+      return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
+    }
+  }
+  void store(void* ptr, int64_t count = size()) const {
+    if (count == size()) {
+      vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
+    }
+    else if (count == (size() >> 1)) {
+      vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
+    }
+    else {
+      float tmp_values[size()];
+      vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(float));
+    }
+  }
+  inline const float32x4_t& get_low() const {
+    return values.val[0];
+  }
+  inline float32x4_t& get_low() {
+    return values.val[0];
+  }
+  inline const float32x4_t& get_high() const {
+    return values.val[1];
+  }
+  inline float32x4_t& get_high() {
+    return values.val[1];
+  }
+  // Very slow implementation of indexing.
+  // Only required because vec256_qint refers to this.
+  // Once we specialize that implementation for ARM
+  // this should be removed. TODO (kimishpatel)
+  float operator[](int idx) const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    return tmp[idx];
+  }
+  float operator[](int idx) {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    return tmp[idx];
+  }
+  // For the boolean version, checks such as "any lane set" / "all lanes zero"
+  // could be done faster in a different way.
+  int zero_mask() const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    int mask = 0;
+    for (int i = 0; i < size(); ++ i) {
+      if (tmp[i] == 0.f) {
+        mask |= (1 << i);
+      }
+    }
+    return mask;
+  }
+  Vectorized<float> isnan() const {
+    __at_align__ float tmp[size()];
+    __at_align__ float res[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      if (_isnan(tmp[i])) {
+        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
+      } else {
+        std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
+      }
+    }
+    return loadu(res);
+  };
+  bool has_inf_nan() const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      if(_isnan(tmp[i]) || _isinf(tmp[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
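+  // map() is the scalar fallback: it spills to an aligned stack buffer,
+  // applies f lane by lane, and reloads. It backs the non-SLEEF paths of the
+  // transcendental functions below (see USE_SLEEF).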
+  Vectorized<float> map(float (*const f)(float)) const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized<float> abs() const {
+    return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
+  }
+  Vectorized<float> angle() const {
+    auto zero = Vectorized<float>(0);
+    auto pi = Vectorized<float>(c10::pi<float>);
+    auto tmp = blendv(zero, pi, *this < zero);
+    return blendv(tmp, *this, isnan());
+  }
+  Vectorized<float> real() const {
+    return *this;
+  }
+  Vectorized<float> imag() const {
+    return Vectorized<float>(0.f);
+  }
+  Vectorized<float> conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
+      map(std::acos)
+    );
+  }
+  Vectorized asin() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
+      map(std::asin)
+    );
+  }
+  Vectorized atan() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
+      map(std::atan)
+    );
+  }
+  Vectorized atanh() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])),
+      map(std::atanh)
+    );
+  }
+  Vectorized atan2(const Vectorized &exp) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
+                                 Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_exp[size()];
+        store(tmp);
+        exp.store(tmp_exp);
+        for (const auto i : c10::irange(size())) {
+          tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_copysignf4(values.val[0], sign.values.val[0]),
+                                 Sleef_copysignf4(values.val[1], sign.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_sign[size()];
+        store(tmp);
+        sign.store(tmp_sign);
+        for (size_type i = 0; i < size(); i++) {
+          tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized erf() const;
+  Vectorized erfc() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
+      map(std::erfc)
+    );
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
+      map(std::exp)
+    );
+  }
+  Vectorized exp2() const {
+    return USE_SLEEF(
+        Vectorized(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
+        map(std::exp2)
+      );
+  }
+  Vectorized expm1() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
+      map(std::expm1)
+    );
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+  Vectorized fmod(const Vectorized& q) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_fmodf4(values.val[0], q.values.val[0]),
+                                 Sleef_fmodf4(values.val[1], q.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_q[size()];
+        store(tmp);
+        q.store(tmp_q);
+        for (const auto i : c10::irange(size())) {
+          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
+                                 Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_b[size()];
+        store(tmp);
+        b.store(tmp_b);
+        for (const auto i : c10::irange(size())) {
+          tmp[i] = std::hypot(tmp[i], tmp_b[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized log() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
+      map(std::log)
+    );
+  }
+  Vectorized log10() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
+      map(std::log10)
+    );
+  }
+  Vectorized log1p() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
+      map(std::log1p)
+    );
+  }
+  Vectorized log2() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
+      map(std::log2)
+    );
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_nextafterf4(values.val[0], b.values.val[0]),
+                                 Sleef_nextafterf4(values.val[1], b.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_b[size()];
+        store(tmp);
+        b.store(tmp_b);
+        for (const auto i : c10::irange(size())) {
+          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized frac() const;
+  Vectorized sin() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
+      map(std::sin)
+    );
+  }
+  Vectorized sinh() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
+      map(std::sinh)
+    );
+  }
+  Vectorized cos() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
+      map(std::cos)
+    );
+  }
+  Vectorized cosh() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
+      map(std::cosh)
+    );
+  }
+  Vectorized ceil() const {
+    return map(at::native::ceil_impl);
+  }
+  Vectorized floor() const {
+    return map(at::native::floor_impl);
+  }
+  Vectorized neg() const {
+    return Vectorized(
+        vnegq_f32(values.val[0]),
+        vnegq_f32(values.val[1]));
+  }
+  Vectorized round() const {
+    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
+    return map(at::native::round_impl);
+  }
+  Vectorized tan() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
+      map(std::tan)
+    );
+  }
+  Vectorized tanh() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
+      map(std::tanh)
+    );
+  }
+  Vectorized trunc() const {
+    float32x4_t r0 = vrndq_f32(values.val[0]);
+    float32x4_t r1 = vrndq_f32(values.val[1]);
+    return Vectorized(r0, r1);
+  }
+  Vectorized lgamma() const {
+    return USE_SLEEF(
+      Vectorized(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
+      map(std::lgamma)
+    );
+  }
+  Vectorized sqrt() const {
+    return Vectorized(
+        vsqrtq_f32(values.val[0]),
+        vsqrtq_f32(values.val[1]));
+  }
+  Vectorized reciprocal() const {
+    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
+    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
+    return Vectorized(r0, r1);
+  }
+  Vectorized rsqrt() const {
+    return this->sqrt().reciprocal();
+  }
+  Vectorized pow(const Vectorized &exp) const {
+    USE_SLEEF(
+      {
+        return Vectorized(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
+                                 Sleef_powf4_u10(values.val[1], exp.values.val[1]));
+      },
+      {
+        __at_align__ float tmp[size()];
+        __at_align__ float tmp_exp[size()];
+        store(tmp);
+        exp.store(tmp_exp);
+        for (const auto i : c10::irange(size())) {
+          tmp[i] = std::pow(tmp[i], tmp_exp[i]);
+        }
+        return loadu(tmp);
+      }
+    )
+  }
+  Vectorized operator==(const Vectorized& other) const {
+    float32x4_t r0 =
+      vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
+    float32x4_t r1 =
+      vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized operator!=(const Vectorized& other) const {
+    float32x4_t r0 = vreinterpretq_f32_u32(
+        vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
+    float32x4_t r1 = vreinterpretq_f32_u32(
+        vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    float32x4_t r0 =
+      vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
+    float32x4_t r1 =
+      vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    float32x4_t r0 =
+      vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
+    float32x4_t r1 =
+      vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    float32x4_t r0 =
+      vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
+    float32x4_t r1 =
+      vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    float32x4_t r0 =
+      vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
+    float32x4_t r1 =
+      vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
+    return Vectorized(r0, r1);
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized Vectorized::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vmaxq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vmaxq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
+  float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) {
+  return minimum(max, maximum(min, a));
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) {
+  return minimum(max, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) {
+  return maximum(min, a);
+}
+
+template <>
+Vectorized inline operator&(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
+      vreinterpretq_u32_f32(a.get_low()),
+      vreinterpretq_u32_f32(b.get_low())));
+  float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
+      vreinterpretq_u32_f32(a.get_high()),
+      vreinterpretq_u32_f32(b.get_high())));
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline operator|(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
+      vreinterpretq_u32_f32(a.get_low()),
+      vreinterpretq_u32_f32(b.get_low())));
+  float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
+      vreinterpretq_u32_f32(a.get_high()),
+      vreinterpretq_u32_f32(b.get_high())));
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline operator^(const Vectorized& a, const Vectorized& b) {
+  float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
+      vreinterpretq_u32_f32(a.get_low()),
+      vreinterpretq_u32_f32(b.get_low())));
+  float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
+      vreinterpretq_u32_f32(a.get_high()),
+      vreinterpretq_u32_f32(b.get_high())));
+  return Vectorized(r0, r1);
+}
+
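+// The comparison operators above return per-lane all-ones / all-zero bit
+// patterns. The eq/ne/gt/ge/lt/le helpers below turn those masks into numeric
+// 0.0f / 1.0f results by ANDing with 1.0f.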
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1.0f);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1.0f);
+}
+
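+// The vectorized loops below convert eight elements per iteration with two
+// 4-lane NEON conversions; a scalar tail loop handles the remainder. Note that
+// vcvtq_s32_f32 truncates toward zero, matching static_cast<int32_t>.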
+template <>
+inline void convert(const float* src, int32_t* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
+    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
+    vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = static_cast<int32_t>(src[i]);
+  }
+}
+
+template <>
+inline void convert(const int32_t* src, float* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
+    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
+    vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = static_cast<float>(src[i]);
+  }
+}
+
+template <>
+Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) {
+  float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
+  float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
+template <>
+Vectorized inline fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) {
+  float32x4_t r0 = vfmsq_f32(c.get_low(), a.get_low(), b.get_low());
+  float32x4_t r1 = vfmsq_f32(c.get_high(), a.get_high(), b.get_high());
+  return Vectorized(r0, r1);
+}
+
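+// erf() below uses the Abramowitz & Stegun 7.1.26 polynomial approximation:
+//   erf(x) ~= sign(x) * (1 - r * t * exp(-x * x)),  t = 1 / (1 + p * |x|),
+// where r is a degree-4 polynomial in t with coefficients p1..p5. The absolute
+// error of this approximation is bounded by about 1.5e-7.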
+inline Vectorized Vectorized::erf() const{
+    // constants
+    const Vectorized neg_zero_vec(-0.f);
+    const Vectorized one_vec(1.0f);
+    const Vectorized p(0.3275911f);
+    const Vectorized p1(0.254829592f);
+    const Vectorized p2(-0.284496736f);
+    const Vectorized p3(1.421413741f);
+    const Vectorized p4(-1.453152027f);
+    const Vectorized p5(1.061405429f);
+    // sign(x)
+    auto sign_mask = neg_zero_vec & *this;
+    auto abs_vec = this->abs();
+    // t = 1 / (p * abs(x) + 1)
+    auto tmp0 = fmadd(p, abs_vec, one_vec);
+    auto t = one_vec / tmp0;
+    // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
+    auto tmp1 = fmadd(p5, t, p4);
+    auto tmp2 = fmadd(tmp1, t, p3);
+    auto tmp3 = fmadd(tmp2, t, p2);
+    auto r = fmadd(tmp3, t, p1);
+    // - exp(- x * x)
+    auto pow_2 = (*this) * (*this);
+    auto neg_pow_2 = pow_2 ^ neg_zero_vec;
+    auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp.
+    auto tmp5 = tmp4 ^ neg_zero_vec;
+    // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
+    auto tmp6 = t * tmp5;
+    auto tmp7 = fmadd(tmp6, r, one_vec);
+    return tmp7 ^ sign_mask;
+}
+#endif /* defined(aarch64) */
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f337fea3bfdf20a44e69600e3368b408f537623
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h
@@ -0,0 +1,1586 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/irange.h>
+
+namespace at::vec {
+inline namespace CPU_CAPABILITY {
+
+#ifdef CPU_CAPABILITY_AVX2
+
+struct Vectorizedi {
+protected:
+  __m256i values;
+
+  static inline __m256i invert(const __m256i& v) {
+    const auto ones = _mm256_set1_epi64x(-1);
+    return _mm256_xor_si256(ones, v);
+  }
+public:
+  Vectorizedi() {}
+  Vectorizedi(__m256i v) : values(v) {}
+  operator __m256i() const {
+    return values;
+  }
+};
+
+#else
+
+struct Vectorizedi {};  // dummy definition to make Vectorizedi always defined
+
+#endif // CPU_CAPABILITY_AVX2
+
+#ifdef CPU_CAPABILITY_AVX2
+
+template <>
+class Vectorized<int64_t> : public Vectorizedi {
+private:
+  static const Vectorized ones;
+public:
+  using value_type = int64_t;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 4;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); }
+  Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
+    values = _mm256_setr_epi64x(val1, val2, val3, val4);
+  }
+  template <int64_t mask>
+  static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
+    __at_align__ int64_t tmp_values[size()];
+    a.store(tmp_values);
+    if (mask & 0x01)
+      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
+    if (mask & 0x02)
+      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
+    if (mask & 0x04)
+      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
+    if (mask & 0x08)
+      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
+    return loadu(tmp_values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    return _mm256_blendv_epi8(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<int64_t> loadu(const void* ptr) {
+    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ int64_t tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
+    return loadu(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not to be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
+    } else if (count > 0) {
+      __at_align__ int64_t tmp_values[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
+    }
+  }
+  const int64_t& operator[](int idx) const  = delete;
+  int64_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    auto zero = _mm256_set1_epi64x(0);
+    auto is_larger = _mm256_cmpgt_epi64(zero, values);
+    auto inverse = _mm256_xor_si256(values, is_larger);
+    return _mm256_sub_epi64(inverse, is_larger);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_epi64x(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmpeq_epi64(values, other.values);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    return invert(_mm256_cmpeq_epi64(values, other.values));
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmpgt_epi64(other.values, values);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi64(values, other.values));
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return _mm256_cmpgt_epi64(values, other.values);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi64(other.values, values));
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+class Vectorized<int32_t> : public Vectorizedi {
+private:
+  static const Vectorized ones;
+public:
+  using value_type = int32_t;
+  static constexpr int size() {
+    return 8;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int32_t v) { values = _mm256_set1_epi32(v); }
+  Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
+         int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
+    values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
+  }
+  template <int64_t mask>
+  static Vectorized<int32_t> blend(Vectorized<int32_t> a, Vectorized<int32_t> b) {
+    return _mm256_blend_epi32(a, b, mask);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    return _mm256_blendv_epi8(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized<int32_t>(
+      base,            base +     step, base + 2 * step, base + 3 * step,
+      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int32_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<int32_t> loadu(const void* ptr) {
+    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int32_t count) {
+    __at_align__ int32_t tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
+    return loadu(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not to be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
+    } else if (count > 0) {
+      __at_align__ int32_t tmp_values[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
+    }
+  }
+  const int32_t& operator[](int idx) const  = delete;
+  int32_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    return _mm256_abs_epi32(values);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_epi32(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmpeq_epi32(values, other.values);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    return invert(_mm256_cmpeq_epi32(values, other.values));
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmpgt_epi32(other.values, values);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi32(values, other.values));
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return _mm256_cmpgt_epi32(values, other.values);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi32(other.values, values));
+  }
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+inline void convert(const int32_t *src, float *dst, int64_t n) {
+  int64_t i;
+  // int32_t and float have same size
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (i = 0; i <= (n - Vectorized<int32_t>::size()); i += Vectorized<int32_t>::size()) {
+    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
+    auto output_vec = _mm256_cvtepi32_ps(input_vec);
+    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
+  }
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (; i < n; i++) {
+    dst[i] = static_cast<float>(src[i]);
+  }
+}
+
+template <>
+inline void convert(const int32_t *src, double *dst, int64_t n) {
+  int64_t i;
+  // int32_t has half the size of double
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
+    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
+    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
+    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
+  }
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (; i < n; i++) {
+    dst[i] = static_cast<double>(src[i]);
+  }
+}
+
+template <>
+class Vectorized<int16_t> : public Vectorizedi {
+private:
+  static const Vectorized ones;
+public:
+  using value_type = int16_t;
+  static constexpr int size() {
+    return 16;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int16_t v) { values = _mm256_set1_epi16(v); }
+  Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
+         int16_t val5, int16_t val6, int16_t val7, int16_t val8,
+         int16_t val9, int16_t val10, int16_t val11, int16_t val12,
+         int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
+    values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
+                               val9, val10, val11, val12, val13, val14, val15, val16);
+  }
+  template <int64_t mask>
+  static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
+    __at_align__ int16_t tmp_values[size()];
+    a.store(tmp_values);
+    if (mask & 0x01)
+      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
+    if (mask & 0x02)
+      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
+    if (mask & 0x04)
+      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
+    if (mask & 0x08)
+      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
+    if (mask & 0x10)
+      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
+    if (mask & 0x20)
+      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
+    if (mask & 0x40)
+      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
+    if (mask & 0x80)
+      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
+    if (mask & 0x100)
+      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
+    if (mask & 0x200)
+      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
+    if (mask & 0x400)
+      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
+    if (mask & 0x800)
+      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
+    if (mask & 0x1000)
+      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
+    if (mask & 0x2000)
+      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
+    if (mask & 0x4000)
+      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
+    if (mask & 0x8000)
+      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
+    return loadu(tmp_values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    return _mm256_blendv_epi8(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized<int16_t>(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int16_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<int16_t> loadu(const void* ptr) {
+    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int16_t count) {
+    __at_align__ int16_t tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
+    return loadu(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not to be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
+    } else if (count > 0) {
+      __at_align__ int16_t tmp_values[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
+    }
+  }
+  const int16_t& operator[](int idx) const  = delete;
+  int16_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    return _mm256_abs_epi16(values);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_epi16(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmpeq_epi16(values, other.values);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    return invert(_mm256_cmpeq_epi16(values, other.values));
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmpgt_epi16(other.values, values);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi16(values, other.values));
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return _mm256_cmpgt_epi16(values, other.values);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi16(other.values, values));
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
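+// Vectorized8 is the shared base for the 8-bit element types: it provides
+// storage, blend/blendv, arange, set, and load/store, while sign-dependent
+// operations (abs, neg, comparisons) live in the Vectorized<int8_t> and
+// Vectorized<uint8_t> specializations below.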
+template <typename T>
+class Vectorized8 : public Vectorizedi {
+  static_assert(
+    std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+    "Only int8_t/uint8_t are supported");
+protected:
+  static const Vectorized ones;
+public:
+  using value_type = T;
+  static constexpr int size() {
+    return 32;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized8() {}
+  Vectorized8(T v) { values = _mm256_set1_epi8(v); }
+  Vectorized8(T val1, T val2, T val3, T val4,
+         T val5, T val6, T val7, T val8,
+         T val9, T val10, T val11, T val12,
+         T val13, T val14, T val15, T val16,
+         T val17, T val18, T val19, T val20,
+         T val21, T val22, T val23, T val24,
+         T val25, T val26, T val27, T val28,
+         T val29, T val30, T val31, T val32) {
+    values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8,
+                              val9, val10, val11, val12, val13, val14, val15, val16,
+                              val17, val18, val19, val20, val21, val22, val23, val24,
+                              val25, val26, val27, val28, val29, val30, val31, val32);
+  }
+  template <int64_t mask>
+  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
+    __at_align__ T tmp_values[size()];
+    a.store(tmp_values);
+    if (mask & 0x01)
+      tmp_values[0] = _mm256_extract_epi8(b.values, 0);
+    if (mask & 0x02)
+      tmp_values[1] = _mm256_extract_epi8(b.values, 1);
+    if (mask & 0x04)
+      tmp_values[2] = _mm256_extract_epi8(b.values, 2);
+    if (mask & 0x08)
+      tmp_values[3] = _mm256_extract_epi8(b.values, 3);
+    if (mask & 0x10)
+      tmp_values[4] = _mm256_extract_epi8(b.values, 4);
+    if (mask & 0x20)
+      tmp_values[5] = _mm256_extract_epi8(b.values, 5);
+    if (mask & 0x40)
+      tmp_values[6] = _mm256_extract_epi8(b.values, 6);
+    if (mask & 0x80)
+      tmp_values[7] = _mm256_extract_epi8(b.values, 7);
+    if (mask & 0x100)
+      tmp_values[8] = _mm256_extract_epi8(b.values, 8);
+    if (mask & 0x200)
+      tmp_values[9] = _mm256_extract_epi8(b.values, 9);
+    if (mask & 0x400)
+      tmp_values[10] = _mm256_extract_epi8(b.values, 10);
+    if (mask & 0x800)
+      tmp_values[11] = _mm256_extract_epi8(b.values, 11);
+    if (mask & 0x1000)
+      tmp_values[12] = _mm256_extract_epi8(b.values, 12);
+    if (mask & 0x2000)
+      tmp_values[13] = _mm256_extract_epi8(b.values, 13);
+    if (mask & 0x4000)
+      tmp_values[14] = _mm256_extract_epi8(b.values, 14);
+    if (mask & 0x8000)
+      tmp_values[15] = _mm256_extract_epi8(b.values, 15);
+    if (mask & 0x010000)
+      tmp_values[16] = _mm256_extract_epi8(b.values, 16);
+    if (mask & 0x020000)
+      tmp_values[17] = _mm256_extract_epi8(b.values, 17);
+    if (mask & 0x040000)
+      tmp_values[18] = _mm256_extract_epi8(b.values, 18);
+    if (mask & 0x080000)
+      tmp_values[19] = _mm256_extract_epi8(b.values, 19);
+    if (mask & 0x100000)
+      tmp_values[20] = _mm256_extract_epi8(b.values, 20);
+    if (mask & 0x200000)
+      tmp_values[21] = _mm256_extract_epi8(b.values, 21);
+    if (mask & 0x400000)
+      tmp_values[22] = _mm256_extract_epi8(b.values, 22);
+    if (mask & 0x800000)
+      tmp_values[23] = _mm256_extract_epi8(b.values, 23);
+    if (mask & 0x1000000)
+      tmp_values[24] = _mm256_extract_epi8(b.values, 24);
+    if (mask & 0x2000000)
+      tmp_values[25] = _mm256_extract_epi8(b.values, 25);
+    if (mask & 0x4000000)
+      tmp_values[26] = _mm256_extract_epi8(b.values, 26);
+    if (mask & 0x8000000)
+      tmp_values[27] = _mm256_extract_epi8(b.values, 27);
+    if (mask & 0x10000000)
+      tmp_values[28] = _mm256_extract_epi8(b.values, 28);
+    if (mask & 0x20000000)
+      tmp_values[29] = _mm256_extract_epi8(b.values, 29);
+    if (mask & 0x40000000)
+      tmp_values[30] = _mm256_extract_epi8(b.values, 30);
+    if (mask & 0x80000000)
+      tmp_values[31] = _mm256_extract_epi8(b.values, 31);
+    return loadu(tmp_values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                               const Vectorized& mask) {
+    return _mm256_blendv_epi8(a.values, b.values, mask.values);
+  }
+  template <typename step_t>
+  static Vectorized<T> arange(T base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized<T>(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
+      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
+      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
+      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
+      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, T count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<0x1>(a, b);
+      case 2:
+        return blend<0x3>(a, b);
+      case 3:
+        return blend<0x7>(a, b);
+      case 4:
+        return blend<0xF>(a, b);
+      case 5:
+        return blend<0x1F>(a, b);
+      case 6:
+        return blend<0x3F>(a, b);
+      case 7:
+        return blend<0x7F>(a, b);
+      case 8:
+        return blend<0xFF>(a, b);
+      case 9:
+        return blend<0x1FF>(a, b);
+      case 10:
+        return blend<0x3FF>(a, b);
+      case 11:
+        return blend<0x7FF>(a, b);
+      case 12:
+        return blend<0xFFF>(a, b);
+      case 13:
+        return blend<0x1FFF>(a, b);
+      case 14:
+        return blend<0x3FFF>(a, b);
+      case 15:
+        return blend<0x7FFF>(a, b);
+      case 16:
+        return blend<0xFFFF>(a, b);
+      case 17:
+        return blend<0x1FFFF>(a, b);
+      case 18:
+        return blend<0x3FFFF>(a, b);
+      case 19:
+        return blend<0x7FFFF>(a, b);
+      case 20:
+        return blend<0xFFFFF>(a, b);
+      case 21:
+        return blend<0x1FFFFF>(a, b);
+      case 22:
+        return blend<0x3FFFFF>(a, b);
+      case 23:
+        return blend<0x7FFFFF>(a, b);
+      case 24:
+        return blend<0xFFFFFF>(a, b);
+      case 25:
+        return blend<0x1FFFFFF>(a, b);
+      case 26:
+        return blend<0x3FFFFFF>(a, b);
+      case 27:
+        return blend<0x7FFFFFF>(a, b);
+      case 28:
+        return blend<0xFFFFFFF>(a, b);
+      case 29:
+        return blend<0x1FFFFFFF>(a, b);
+      case 30:
+        return blend<0x3FFFFFFF>(a, b);
+      case 31:
+        return blend<0x7FFFFFFF>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<T> loadu(const void* ptr) {
+    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
+  }
+  static Vectorized loadu_one_fourth(const void* ptr) {
+      // Fast path if only load element number of 8.
+      // Note: We didn't merge it as fast path of loadu(const void* ptr, T count),
+      // Because loadu(const void* ptr, T count) requires zero initialization for upper 128 bits.
+      // However, by using _mm256_castsi128_si256, the upper 128 bits of the result are undefined.
+      // TODO We can use _mm256_zextsi128_si256 in the furture,
+      // since gcc 9.3 doesn't support it now.
+      __m128i input_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr));
+      return _mm256_castsi128_si256(input_128);
+  }
+  static Vectorized loadu(const void* ptr, T count) {
+    __at_align__ T tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, ptr, count * sizeof(T));
+    return loadu(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not to be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
+    } else if (count > 0) {
+      if (count == 8) {
+        // Fast path if only store element number of 8
+        _mm_storel_epi64(reinterpret_cast<__m128i*>(ptr), _mm256_castsi256_si128(values));
+      } else {
+        __at_align__ T tmp_values[size()];
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
+        std::memcpy(ptr, tmp_values, count * sizeof(T));
+      }
+    }
+  }
+  const T& operator[](int idx) const  = delete;
+  T& operator[](int idx)  = delete;
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm256_set1_epi8(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+};
+
+template<>
+class Vectorized<int8_t>: public Vectorized8<int8_t> {
+public:
+  using Vectorized8::Vectorized8;
+
+  Vectorized neg() const;
+
+  Vectorized abs() const {
+   return _mm256_abs_epi8(values);
+  }
+
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmpeq_epi8(values, other.values);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    return invert(_mm256_cmpeq_epi8(values, other.values));
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    return _mm256_cmpgt_epi8(other.values, values);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    return invert(_mm256_cmpgt_epi8(values, other.values));
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return other < *this;
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return other <= *this;
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template<>
+class Vectorized<uint8_t>: public Vectorized8<uint8_t> {
+public:
+  using Vectorized8::Vectorized8;
+
+  Vectorized neg() const;
+
+  Vectorized abs() const {
+    return *this;
+  }
+
+  Vectorized operator==(const Vectorized& other) const {
+    return _mm256_cmpeq_epi8(values, other.values);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    return invert(_mm256_cmpeq_epi8(values, other.values));
+  }
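+  // AVX2 has no unsigned 8-bit compare, so < and <= are derived from
+  // _mm256_max_epu8: max(a, b) == a exactly when a >= b, and
+  // max(a, b) == b exactly when a <= b.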
+  Vectorized operator<(const Vectorized& other) const {
+    __m256i max = _mm256_max_epu8(values, other.values);
+    return invert(_mm256_cmpeq_epi8(max, values));
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    __m256i max = _mm256_max_epu8(values, other.values);
+    return _mm256_cmpeq_epi8(max, other.values);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return other < *this;
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return other <= *this;
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm256_add_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm256_add_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm256_add_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm256_add_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm256_add_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sub_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sub_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sub_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sub_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sub_epi8(a, b);
+}
+
+// Negation. Defined here so we can utilize operator-
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+// Emulate operations with no native 64-bit support in avx,
+// by extracting each element, performing the operation pointwise,
+// then combining the results into a vector.
+template <typename op_t>
+Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const op_t& op) {
+  int64_t a0 = _mm256_extract_epi64(a, 0);
+  int64_t a1 = _mm256_extract_epi64(a, 1);
+  int64_t a2 = _mm256_extract_epi64(a, 2);
+  int64_t a3 = _mm256_extract_epi64(a, 3);
+
+  int64_t b0 = _mm256_extract_epi64(b, 0);
+  int64_t b1 = _mm256_extract_epi64(b, 1);
+  int64_t b2 = _mm256_extract_epi64(b, 2);
+  int64_t b3 = _mm256_extract_epi64(b, 3);
+
+  int64_t c0 = op(a0, b0);
+  int64_t c1 = op(a1, b1);
+  int64_t c2 = op(a2, b2);
+  int64_t c3 = op(a3, b3);
+
+  return _mm256_set_epi64x(c3, c2, c1, c0);
+}
+
+template <typename op_t>
+Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const Vectorized<int64_t>& c, const op_t& op) {
+  int64_t a0 = _mm256_extract_epi64(a, 0);
+  int64_t a1 = _mm256_extract_epi64(a, 1);
+  int64_t a2 = _mm256_extract_epi64(a, 2);
+  int64_t a3 = _mm256_extract_epi64(a, 3);
+
+  int64_t b0 = _mm256_extract_epi64(b, 0);
+  int64_t b1 = _mm256_extract_epi64(b, 1);
+  int64_t b2 = _mm256_extract_epi64(b, 2);
+  int64_t b3 = _mm256_extract_epi64(b, 3);
+
+  int64_t c0 = _mm256_extract_epi64(c, 0);
+  int64_t c1 = _mm256_extract_epi64(c, 1);
+  int64_t c2 = _mm256_extract_epi64(c, 2);
+  int64_t c3 = _mm256_extract_epi64(c, 3);
+
+  int64_t d0 = op(a0, b0, c0);
+  int64_t d1 = op(a1, b1, c1);
+  int64_t d2 = op(a2, b2, c2);
+  int64_t d3 = op(a3, b3, c3);
+
+  return _mm256_set_epi64x(d3, d2, d1, d0);
+}
+
+// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
+// This could be implemented more efficiently using epi32 instructions
+// This is also technically avx compatible, but then we'll need AVX
+// code for add as well.
+// Note: intentionally ignores undefined behavior like (-lowest * -1).
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ {return a_point * b_point;});
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return _mm256_mullo_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return _mm256_mullo_epi16(a, b);
+}
+
+template <typename T, typename Op>
+Vectorized<T> inline int_elementwise_binary_256(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
+  T values_a[Vectorized<T>::size()];
+  T values_b[Vectorized<T>::size()];
+  a.store(values_a);
+  b.store(values_b);
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    values_a[i] = op(values_a[i], values_b[i]);
+  }
+  return Vectorized<T>::loadu(values_a);
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  // We don't have an instruction for multiplying int8_t
+#ifndef CPU_CAPABILITY_AVX2
+  return int_elementwise_binary_256(a, b, std::multiplies<int8_t>());
+#else
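+  // Widen to 16-bit lanes: a_lo/b_lo sign-extend the low byte of each 16-bit
+  // lane (shift left then arithmetic shift right by 8), a_hi/b_hi sign-extend
+  // the high byte. Multiply as int16, then recombine the low bytes of both
+  // sets of products into one 256-bit result.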
+  __m256i mask00FF = _mm256_set1_epi16(0x00FF);
+  __m256i a_lo = _mm256_srai_epi16(_mm256_slli_epi16(a, 8), 8);
+  __m256i b_lo = _mm256_srai_epi16(_mm256_slli_epi16(b, 8), 8);
+  __m256i a_hi = _mm256_srai_epi16(a, 8);
+  __m256i b_hi = _mm256_srai_epi16(b, 8);
+  __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF);
+  __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8);
+  __m256i res = _mm256_or_si256(res_hi, res_lo);
+  return res;
+#endif
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  // We don't have an instruction for multiplying uint8_t
+#ifndef CPU_CAPABILITY_AVX2
+  return int_elementwise_binary_256(a, b, std::multiplies<uint8_t>());
+#else
+  __m256i mask00FF = _mm256_set1_epi16(0x00FF);
+  __m256i a_lo = _mm256_and_si256 (a, mask00FF);
+  __m256i b_lo = _mm256_and_si256 (b, mask00FF);
+  __m256i a_hi = _mm256_srli_epi16(a, 8);
+  __m256i b_hi = _mm256_srli_epi16(b, 8);
+  __m256i res_lo = _mm256_and_si256(_mm256_mullo_epi16(a_lo, b_lo), mask00FF);
+  __m256i res_hi = _mm256_slli_epi16(_mm256_mullo_epi16(a_hi, b_hi), 8);
+  __m256i res = _mm256_or_si256(res_hi, res_lo);
+  return res;
+#endif
+}
+
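+// AVX2 has no epi64 min/max instructions (those require AVX-512), so the
+// int64_t variants below use a 64-bit compare followed by a byte blend, or
+// scalar emulation when AVX2 is not available.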
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+#ifndef CPU_CAPABILITY_AVX2
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
+#else
+  __m256i cmp = _mm256_cmpgt_epi64(a, b);
+  return _mm256_blendv_epi8(a, b, cmp);
+#endif
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_min_epi32(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_min_epi16(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_min_epi8(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_min_epu8(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+#ifndef CPU_CAPABILITY_AVX2
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
+#else
+  __m256i cmp = _mm256_cmpgt_epi64(a, b);
+  return _mm256_blendv_epi8(b, a, cmp);
+#endif
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_max_epi32(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_max_epi16(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_max_epi8(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm256_max_epu8(a, b);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+#ifndef CPU_CAPABILITY_AVX2
+  return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));});
+#else
+  return minimum(maximum(a, min_val), max_val);
+#endif
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val));
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+#ifndef CPU_CAPABILITY_AVX2
+  return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);});
+#else
+  return minimum(max_val, a);
+#endif
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm256_min_epi32(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm256_min_epi16(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm256_min_epi8(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm256_min_epu8(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+#ifndef CPU_CAPABILITY_AVX2
+  return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);});
+#else
+  return maximum(min_val, a);
+#endif
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm256_max_epi32(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm256_max_epi16(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm256_max_epi8(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm256_max_epu8(min_val, a);
+}
+
+template
+Vectorized inline convert_to_int32(const T* ptr) {
+  return Vectorized::loadu(ptr);
+}
+
+template<>
+Vectorized inline convert_to_int32(const int8_t* ptr) {
+  return _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast(ptr)));
+}
+
+template<>
+Vectorized inline convert_to_int32(const uint8_t* ptr) {
+  return _mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast(ptr)));
+}
+
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_256(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_256(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_256(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_256(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_256(a, b, std::divides());
+}
+
+template>::value, int> = 0>
+inline Vectorized operator&(const Vectorized& a, const Vectorized& b) {
+  return _mm256_and_si256(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator|(const Vectorized& a, const Vectorized& b) {
+  return _mm256_or_si256(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator^(const Vectorized& a, const Vectorized& b) {
+  return _mm256_xor_si256(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator~(const Vectorized& a) {
+  return _mm256_xor_si256(a, _mm256_set1_epi32(-1));
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
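+
+// Illustrative note (explanatory comment, not upstream documentation): the
+// operators ==, !=, >, >=, <, <= return lane masks of all ones (-1) or all
+// zeros, while the eq/ne/gt/ge/lt/le methods above AND that mask with a
+// vector of ones so each lane holds 0 or 1; e.g. comparing lanes {3, 5} with
+// {4, 1} gives {0, -1} from operator> but {0, 1} from gt().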
+
+template 
+Vectorized inline shift_256_16(const Vectorized& a, const Vectorized& b) {
+  // No vector instruction for shifting int16_t, so emulating it instead.
+
+  // Control masks for shuffle operation, treating 256 bits as an
+  // array of 16-bit elements, and considering pairs of neighboring
+  // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
+  // M!=N) is set so that shuffle will move element with index M from
+  // input pair into element with index N in output pair, and element
+  // with index M in output pair will be set to all 0s.
+  __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80,
+                                    21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80,
+                                    13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80,
+                                    5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80);
+  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26,
+                                    0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18,
+                                    0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10,
+                                    0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2);
+
+  // Masks for bitwise and operation, treating 256 bits as an array of
+  // 16-bit elements, and considering them in pairs of neighboring
+  // elements.  A mask named "keep_M" (M in [0,1]) is set so that
+  // bitwise and will copy element with index M from input pair into
+  // element with the same index in output pair, while the other
+  // element in output pair will be set to all 0s.
+  __m256i keep_0 = _mm256_set1_epi32(0xFFFF);
+  __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000);
+
+  // Take each 16-bit element with idx%2==0 from input array to be
+  // shifted and extend it to 32 bits so that 0s are added to the
+  // right.  Then, perform shifting on this 32-bit number.  Upper 16
+  // bits will be proper result of shifting original 16-bit number, so
+  // write them to result array, into the same position from which
+  // corresponding input element is taken.  Also, make sure that
+  // result array elements with idx%2!=0 are set to all 0s.
+  //
+  // Note that number of bits to shift for is extended to 32 bits by
+  // adding 0s to the left.  That means this number is not properly
+  // sign-extended for negative values.  However, number of bits to
+  // shift is treated as an unsigned integer by respective shift
+  // intrinsics anyway so if negative then either with or without
+  // proper sign extension, it will be interpreted as a number greater
+  // than 32, and the shifting result will be the same.
+  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1);
+  __m256i b0 = _mm256_and_si256(b, keep_0);
+  __m256i c0;
+  if (left_shift)
+    c0 = _mm256_sllv_epi32(a0, b0);
+  else
+    c0 = _mm256_srav_epi32(a0, b0);
+  c0 = _mm256_shuffle_epi8(c0, ctl_1_0);
+
+  // Perform shifting the same way for input array elements with
+  // idx%2==1.
+  __m256i a1 = _mm256_and_si256(a, keep_1);
+  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
+  __m256i c1;
+  if (left_shift)
+    c1 = _mm256_sllv_epi32(a1, b1);
+  else
+    c1 = _mm256_srav_epi32(a1, b1);
+  c1 = _mm256_and_si256(c1, keep_1);
+
+  // Merge partial results into the final result.
+  __m256i c = _mm256_or_si256(c0, c1);
+
+  return c;
+}
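+
+// Worked example for the emulation above (explanatory comment, not upstream
+// documentation): to compute int16_t(-32768) >> 3, the element is moved into
+// the upper half of a 32-bit lane (0x80000000), shifted with
+// _mm256_srav_epi32 to 0xF0000000, and its upper 16 bits (0xF000, i.e. -4096)
+// are shuffled back into the element's original position, matching the scalar
+// result -32768 >> 3 = -4096.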
+
+template ::value || std::is_same::value, int> = 0>
+Vectorized inline shift_256_8(const Vectorized& a, const Vectorized& b) {
+  // No vector instruction for shifting int8_t/uint8_t, so emulating
+  // it instead.
+
+  // Control masks for shuffle operation, treating 256 bits as an
+  // array of 8-bit elements, and considering quadruples of
+  // neighboring elements.  Specifically, a mask named "ctl_M_N" (M,N
+  // in [0,1,2,3], and M!=N) is set so that shuffle will move element
+  // with index M from input quadruple into element with index N in
+  // output quadruple, and other elements in output quadruple will be
+  // set to all 0s.
+  __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80,
+                                    20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80,
+                                    12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80,
+                                    4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80);
+  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25,
+                                    0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17,
+                                    0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9,
+                                    0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1);
+  __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80,
+                                    21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80,
+                                    13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80,
+                                    5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80);
+  __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26,
+                                    0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18,
+                                    0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10,
+                                    0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2);
+  __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80,
+                                    22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80,
+                                    14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80,
+                                    6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80);
+  __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27,
+                                    0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19,
+                                    0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11,
+                                    0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3);
+  __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80,
+                                    0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80,
+                                    0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80,
+                                    0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80);
+  __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80,
+                                    0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80,
+                                    0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80,
+                                    0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80);
+
+  // Masks for bitwise and operation, treating 256 bits as an array of
+  // 8-bit elements, and considering them in quadruples of neighboring
+  // elements.  A mask named "keep_M" (M in [0,1,2,3]) is set so that
+  // bitwise and will copy element with index M from input quadruple
+  // into element with the same index in output quadruple, while the
+  // other elements in output quadruple will be set to all 0s.
+  __m256i keep_0 = _mm256_set1_epi32(0xFF);
+  __m256i keep_3 = _mm256_set1_epi32(0xFF000000);
+
+  // Take each 8-bit element with idx%4==0 from input array to be
+  // shifted and extend it to 32 bits so that 0s are added to the
+  // right.  Then, perform shifting on this 32-bit number.  Upper 8
+  // bits will be proper result of shifting original 8-bit number, so
+  // write them to result array, into the same position from which
+  // corresponding input element is taken.  Also, make sure that
+  // result array elements with idx%4!=0 are set to all 0s.
+  //
+  // Note that number of bits to shift for is extended to 32 bits by
+  // adding 0s to the left.  That means this number is not properly
+  // sign-extended for negative values.  However, number of bits to
+  // shift is treated as an unsigned integer by respective shift
+  // intrinsics anyway so if negative then either with or without
+  // proper sign extension, it will be interpreted as a number greater
+  // than 32, and the shifting result will be the same.
+  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3);
+  __m256i b0 = _mm256_and_si256(b, keep_0);
+  __m256i c0;
+  if (left_shift)
+    c0 = _mm256_sllv_epi32(a0, b0);
+  else
+    if constexpr (std::is_same_v)
+      c0 = _mm256_srav_epi32(a0, b0);
+    else
+      c0 = _mm256_srlv_epi32(a0, b0);
+  c0 = _mm256_shuffle_epi8(c0, ctl_3_0);
+
+  // Perform shifting the same way for input array elements with
+  // idx%4==1.
+  __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
+  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
+  __m256i c1;
+  if (left_shift)
+    c1 = _mm256_sllv_epi32(a1, b1);
+  else
+    if constexpr (std::is_same_v)
+      c1 = _mm256_srav_epi32(a1, b1);
+    else
+      c1 = _mm256_srlv_epi32(a1, b1);
+  c1 = _mm256_shuffle_epi8(c1, ctl_3_1);
+
+  // Perform shifting the same way for input array elements with
+  // idx%4==2.
+  __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
+  __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
+  __m256i c2;
+  if (left_shift)
+    c2 = _mm256_sllv_epi32(a2, b2);
+  else
+    if constexpr (std::is_same_v)
+      c2 = _mm256_srav_epi32(a2, b2);
+    else
+      c2 = _mm256_srlv_epi32(a2, b2);
+  c2 = _mm256_shuffle_epi8(c2, ctl_3_2);
+
+  // Perform shifting the same way for input array elements with
+  // idx%4==3.
+  __m256i a3 =  _mm256_and_si256(a, keep_3);
+  __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
+  __m256i c3;
+  if (left_shift)
+    c3 = _mm256_sllv_epi32(a3, b3);
+  else
+    if constexpr (std::is_same_v)
+      c3 = _mm256_srav_epi32(a3, b3);
+    else
+      c3 = _mm256_srlv_epi32(a3, b3);
+  c3 = _mm256_and_si256(c3, keep_3);
+
+  // Merge partial results into the final result.
+  __m256i c01 = _mm256_or_si256(c0, c1);
+  __m256i c23 = _mm256_or_si256(c2, c3);
+  __m256i c = _mm256_or_si256(c01, c23);
+
+  return c;
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sllv_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return _mm256_sllv_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return shift_256_16(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return shift_256_8(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return shift_256_8(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  // No vector instruction for right arithmetic shifting int64_t, so emulating it
+  // instead.
+
+  // Clamp the shift values such that shift values < 0 and > 64 are changed to 64
+  // which results in -1 for negative input and 0 for non-negative input.
+  __m256i zero = _mm256_set1_epi64x(0);
+  __m256i max_shift = _mm256_set1_epi64x(64);
+  __m256i mask = _mm256_or_si256(_mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift));
+  __m256i shift = _mm256_blendv_epi8(b, max_shift, mask);
+  // Shift the number logically to the right, thus filling the most
+  // significant bits with 0s.  Then, replace these bits with the sign
+  // bit.
+  __m256i sign_bits = _mm256_cmpgt_epi64(zero, a);
+  __m256i sign_shift = _mm256_sub_epi64(max_shift, shift);
+  __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift);
+  __m256i c = _mm256_srlv_epi64(a, shift);
+  c = _mm256_or_si256(c, sign_ext);
+
+  return c;
+}
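+
+// Worked example for the emulation above (explanatory comment, not upstream
+// documentation): for a = -8 (0xFFFF...FFF8) and b = 1, the logical shift
+// gives 0x7FFF...FFFC, sign_bits is all ones, and shifting those ones left by
+// 64 - 1 = 63 bits gives 0x8000...0000; OR-ing the two parts yields
+// 0xFFFF...FFFC, i.e. -4, which matches the arithmetic shift -8 >> 1.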
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return _mm256_srav_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return shift_256_16(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return shift_256_8(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return shift_256_8(a, b);
+}
+
+#endif
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h
new file mode 100644
index 0000000000000000000000000000000000000000..28e0b4e50a4270d784acf41c6b501620281a9db5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_qint.h
@@ -0,0 +1,1335 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/native/quantized/AffineQuantizerBase.h>
+
+#include <c10/util/irange.h>
+#include <c10/util/qint32.h>
+#include <c10/util/qint8.h>
+#include <c10/util/quint8.h>
+
+#include <array>
+#include <cmath>
+
+// This file defines Vectorized<> for the quantized types.
+//
+//
+// Currently, we simply use these classes as efficient converters between
+// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
+// where doing the arithmetic in full-precision is acceptable (e.g.
+// elementwise operators).
+//
+//
+// Conversions are as follows:
+//  Vectorized<qint8> -> 4x Vectorized<float>
+//  Vectorized<quint8> -> 4x Vectorized<float>
+//  Vectorized<qint32> -> 1x Vectorized<float>
+//
+// The size of the returned float vector is specified by the special
+// constexpr function float_num_vecs. The type of the value returned
+// from dequantize (and expected as an argument to quantize) is
+// specified by float_vec_return_type.
+//
+// When writing kernels with these vectors, it is expected that floating-
+// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
+// iterations.
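+//
+// A minimal usage sketch of that pattern (illustrative comment only, not
+// upstream code; qvec, scale_v, zp_v, scale, zp and inv_scale are placeholder
+// names):
+//
+//   auto float_vecs = qvec.dequantize(scale_v, zp_v);   // float_num_vecs() vectors
+//   for (auto& fv : float_vecs) fv = fv * fv;           // arbitrary float math
+//   auto out = Vectorized<c10::qint8>::quantize(float_vecs, scale, zp, inv_scale);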
+
+namespace at::vec {
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+
+struct Vectorizedqi {
+ protected:
+  __m256i vals __attribute__((aligned(64)));
+
+ public:
+  Vectorizedqi() {}
+  Vectorizedqi(__m256i v) : vals(v) {}
+  operator __m256i() const {
+    return vals;
+  }
+};
+
+template 
+__m256i pack_saturate_and_clamp(
+    __m256i first,
+    __m256i second,
+    T min_val,
+    T max_val);
+
+template <>
+inline __m256i pack_saturate_and_clamp(
+    __m256i /*first*/,
+    __m256i /*second*/,
+    int32_t /*min_val*/,
+    int32_t /*max_val*/) {
+  // This function is for linkage only, will not be used
+  AT_ERROR("pack_saturate_and_clamp is not supported");
+}
+
+template <>
+inline __m256i pack_saturate_and_clamp(
+    __m256i first,
+    __m256i second,
+    int8_t min_val,
+    int8_t max_val) {
+  __m256i packed_and_sat = _mm256_packs_epi16(first, second);
+  return _mm256_max_epi8(
+      _mm256_set1_epi8(min_val),
+      _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val)));
+}
+
+template <>
+inline __m256i pack_saturate_and_clamp(
+    __m256i first,
+    __m256i second,
+    uint8_t min_val,
+    uint8_t max_val) {
+  __m256i packed_and_sat = _mm256_packus_epi16(first, second);
+  return _mm256_max_epu8(
+      _mm256_set1_epi8(min_val),
+      _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
+}
+
+template 
+typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type
+inline convert_int8_to_float(at::vec::Vectorized src) {
+  // Note: this function only converts a number of input elements equal to at::vec::Vectorized<float>::size()
+  // Only handle first 8*8 bits
+  __m128i input_128 = _mm256_castsi256_si128(src);
+  // Convert from 8*uint8/int8 to 8*int32
+  __m256i input_256_int32;
+  if constexpr (std::is_same_v)
+    input_256_int32 = _mm256_cvtepu8_epi32(input_128);
+  else
+    input_256_int32 = _mm256_cvtepi8_epi32(input_128);
+  // Convert from 8*int32 to 8*float
+  return _mm256_cvtepi32_ps(input_256_int32);
+}
+
+template 
+typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type
+inline convert_float_to_int8(at::vec::Vectorized src) {
+  // Convert from float32 to int32 with truncation
+  __m256i x_values_int32 = _mm256_cvttps_epi32(src);
+
+  // Convert from int32 to int16 using signed saturation
+  __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);
+
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+
+  // Convert from int16 to uint8/int8 using unsigned saturation
+  __m256i xyzw_clamped_v = pack_saturate_and_clamp(
+      xy_packed_v, xy_packed_v, min_val, max_val);
+  __m256i permute_mask_v =
+    _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+  return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+}
+
+template 
+inline void __attribute__((always_inline)) QuantizeAvx2(
+    const float* src,
+    T* dst,
+    int len,
+    float inverse_scale,
+    int64_t zero_point) {
+  constexpr int VLEN = 8;
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+  const __m256i min_v = _mm256_set1_epi32(min_val);
+  const __m256i max_v = _mm256_set1_epi32(max_val);
+  // This is the largest int32 value < int32_max exactly representable in float
+  constexpr int32_t int32_float_max_val =
+      std::numeric_limits::max() - 127;
+  int i = 0;
+  __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
+  // clang-format off
+  static const __m256i shuffle_mask_v = _mm256_set_epi8(
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00);
+  // clang-format on
+  __m256i permute_mask_v =
+      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+  __m256i permute_mask_l8_v =
+      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+  int len_aligned = len / (VLEN * 4) * (VLEN * 4);
+  for (; i < len_aligned; i += 4 * VLEN) {
+    // x
+    __m256 x_vals = _mm256_load_ps(src + i);
+    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
+    // If the floating-point value is greater than int32_max,
+    // _mm256_cvtps_epi32 converts it to a negative value. Clip at
+    // int32_float_max_val to avoid this.
+    x_transformed_v =
+        _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
+    // y
+    __m256 y_vals = _mm256_load_ps(src + i + VLEN);
+    __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v);
+    y_transformed_v =
+        _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val));
+    // z
+    __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN);
+    __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v);
+    z_transformed_v =
+        _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val));
+    // w
+    __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN);
+    __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v);
+    w_transformed_v =
+        _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val));
+
+    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
+    __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
+    __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
+    __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
+
+    // add zero point
+    x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
+    y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point));
+    z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point));
+    w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point));
+
+    __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
+    __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
+    __m256i xyzw_clamped_v =
+        pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val);
+
+    xyzw_clamped_v =
+        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
+  }
+
+  // Additional 8-lane AVX2 version to take advantage when len is smaller
+  // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
+  for (; i < len / VLEN * VLEN; i += VLEN) {
+    __m256 x_vals = _mm256_load_ps(src + i);
+    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
+    x_transformed_v =
+        _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
+    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
+    x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
+    __m256i x_clipped_v =
+        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));
+
+    x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
+    x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
+    _mm_storel_epi64(
+        reinterpret_cast<__m128i*>(dst + i),
+        _mm256_castsi256_si128(x_clipped_v));
+  }
+
+  for (; i < len; ++i) {
+    float transformed = src[i] * inverse_scale;
+
+    // Not exactly the same behavior as the vectorized code.
+    // The vectorized code above always rounds to even in halfway cases
+    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
+    // does the same only when the current rounding mode is FE_TONEAREST.
+    // However, in practice, this should not be a problem because most cases
+    // use the default rounding mode FE_TONEAREST.
+    // Note that we cannot implement the same behavior as the vectorized code
+    // using std::round because it does rounding away from zero in halfway
+    // cases.
+    transformed = zero_point + std::nearbyint(transformed);
+    float clipped =
+        std::min(std::max(transformed, float(min_val)), float(max_val));
+    dst[i] = clipped;
+  }
+}
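+
+// Worked scalar example of the quantization performed above (explanatory
+// comment, not upstream documentation): with inverse_scale = 2.0f (scale 0.5),
+// zero_point = 10 and a uint8_t destination, src = 3.2f maps to
+// nearbyint(3.2 * 2) + 10 = 6 + 10 = 16, while src = 200.0f maps to
+// nearbyint(400) + 10 = 410, which is clamped to the uint8_t maximum 255.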
+
+template<>
+struct Vectorized<c10::qint32> : public Vectorizedqi {
+    using size_type = int;
+    static constexpr size_type size() {
+        return 8;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 1;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 1;
+    }
+
+    using float_vec_return_type = std::array, 1>;
+    using int_vec_return_type = std::array, 1>;
+    using value_type = c10::qint32::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+    Vectorized() {}
+
+    Vectorized(__m256i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::qint32& val) {
+        value_type uw = val.val_;
+        vals = _mm256_set1_epi32(uw);
+    }
+
+    void store(void* ptr, int count = size()) const {
+      if (count != size()) {
+        memcpy(ptr, &vals, count * sizeof(value_type));
+      } else {
+        _mm256_storeu_si256((__m256i*)ptr, vals);
+      }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(
+            tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return _mm256_loadu_si256((const __m256i*)tmp_values);
+    }
+
+    float_vec_return_type dequantize(
+        Vectorized scale,
+        Vectorized /*zero_point*/,
+        Vectorized scale_zp_premul) const {
+      __m256 float_vals = _mm256_cvtepi32_ps(vals);
+      return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)};
+    }
+
+    float_vec_return_type dequantize(
+        Vectorized scale,
+        Vectorized zero_point) const {
+      __m256 float_vals = _mm256_cvtepi32_ps(vals);
+      return {(Vectorized(float_vals) - zero_point) * scale};
+    }
+
+    static Vectorized quantize(
+        const float_vec_return_type& rhs,
+        float scale,
+        int32_t zero_point,
+        float /*inverse_scale*/) {
+      Vectorized retval;
+      auto rhs_data = (__m256)rhs[0];
+      at::native::quantize_vec(
+          scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
+      return retval;
+    }
+
+    Vectorized maximum(Vectorized b) const {
+      return _mm256_max_epi32(vals, b.vals);
+    }
+
+    Vectorized minimum(Vectorized b) const {
+      return _mm256_min_epi32(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm256_min_epi32(
+          _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      return {_mm256_sub_epi32(vals, b)};
+    }
+
+    static Vectorized requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m256 multiplier_v = _mm256_set1_ps(multiplier);
+      __m256i zero_point_v = _mm256_set1_epi32(zero_point);
+
+      __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v);
+      __m256i rounded = _mm256_cvtps_epi32(scaled);
+      return _mm256_add_epi32(rounded, zero_point_v);
+    }
+
+ private:
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+      vals = _mm256_loadu_si256((const __m256i*)ptr);
+    }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized inline operator*(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return _mm256_mullo_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator+(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return _mm256_add_epi32(a, b);
+}
+
+/*
+ * Convert values from int32 back to int8/uint8
+ */
+template 
+__m256i RequantizeAvx2(
+    const std::array, 4>& inp,
+    __m256 multiplier,
+    __m256i zp) {
+  static_assert(
+      std::is_same::value || std::is_same::value,
+      "Only int8_t/uint8_t are supported");
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+  __m256i permute_mask_v =
+      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+  __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier);
+  __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier);
+  __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier);
+  __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier);
+
+  __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+  __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+  __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+  __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+  /* Add zero point */
+  __m256i x_v = _mm256_add_epi32(x_rounded_v, zp);
+  __m256i y_v = _mm256_add_epi32(y_rounded_v, zp);
+  __m256i z_v = _mm256_add_epi32(z_rounded_v, zp);
+  __m256i w_v = _mm256_add_epi32(w_rounded_v, zp);
+
+  /* Pack to int16_t and saturate */
+  __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v);
+  __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v);
+
+  __m256i xyzw_clamped_v =
+      pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val);
+
+  /*
+   * xyzw_clamped_v has results in the following layout so we need to
+   * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
+   */
+  xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+  return xyzw_clamped_v;
+}
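+
+// Worked example of the requantization above (explanatory comment, not
+// upstream documentation): for a uint8_t target, an int32 lane value of 37
+// with multiplier 0.1f and zero point 128 becomes round(37 * 0.1) + 128 =
+// 4 + 128 = 132, which lies inside [0, 255] and passes through the saturating
+// pack in pack_saturate_and_clamp unchanged.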
+
+template<>
+struct Vectorized<c10::qint8> : public Vectorizedqi {
+    static constexpr int size() {
+        return 32;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 4;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 4;
+    }
+
+    using float_vec_return_type = std::array, 4>;
+    using int_vec_return_type = std::array, 4>;
+    using value_type = typename c10::qint8::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+
+    Vectorized() {}
+    Vectorized(__m256i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::qint8& val) {
+        value_type uw = val.val_;
+        vals = _mm256_set1_epi8(uw);
+    }
+
+    // This is needed because the compiler emits awful code for the default
+    // constructor for moving the enum
+    // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
+    C10_CLANG_DIAGNOSTIC_PUSH()
+    #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
+    C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
+    #endif
+    Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { }
+    C10_CLANG_DIAGNOSTIC_POP()
+
+    void store(void* ptr, int count = size()) const {
+        if (count != size()) {
+            memcpy(ptr, &vals, count * sizeof(value_type));
+        } else {
+            _mm256_storeu_si256((__m256i*)ptr, vals);
+        }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(
+            tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return _mm256_loadu_si256((const __m256i*)tmp_values);
+    }
+
+ private:
+    __m256i cvtepi8_epi32(__m128i epi8_vals) const {
+        return _mm256_cvtepi8_epi32(epi8_vals);
+    }
+
+ public:
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized /*zero_point*/,
+      Vectorized scale_neg_zp_premul) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 =
+        vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul);
+    auto val1 =
+        vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul);
+    auto val2 =
+        vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul);
+    auto val3 =
+        vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul);
+    return {val0, val1, val2, val3};
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float /*scale*/,
+      int32_t zero_point,
+      float inverse_scale) {
+    auto* rhs_data = (float*)rhs.data();
+    int8_t quantized_values[32];
+    QuantizeAvx2(
+        rhs_data, quantized_values, 32, inverse_scale, zero_point);
+    return Vectorized::loadu(quantized_values);
+  }
+
+  Vectorized maximum(Vectorized b) const {
+      return _mm256_max_epi8(vals, b.vals);
+    }
+
+  Vectorized minimum(Vectorized b) const {
+      return _mm256_min_epi8(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm256_min_epi8(
+          _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+      __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+      __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+      __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+      __m256i int32_val0 = cvtepi8_epi32(int_val0);
+      __m256i int32_val1 = cvtepi8_epi32(int_val1);
+      __m256i int32_val2 = cvtepi8_epi32(int_val2);
+      __m256i int32_val3 = cvtepi8_epi32(int_val3);
+
+      __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
+      __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
+      __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
+      __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));
+
+      __m256i int32_b0 = cvtepi8_epi32(int_b0);
+      __m256i int32_b1 = cvtepi8_epi32(int_b1);
+      __m256i int32_b2 = cvtepi8_epi32(int_b2);
+      __m256i int32_b3 = cvtepi8_epi32(int_b3);
+
+      __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
+      __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
+      __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
+      __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);
+
+      return {Vectorized(res_0),
+              Vectorized(res_1),
+              Vectorized(res_2),
+              Vectorized(res_3)};
+    }
+
+    static Vectorized requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m256 multiplier_v = _mm256_set1_ps(multiplier);
+      __m256i zero_point_v = _mm256_set1_epi32(zero_point);
+      return RequantizeAvx2(inp, multiplier_v, zero_point_v);
+    }
+
+ private:
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+        vals = _mm256_loadu_si256((const __m256i*)ptr);
+    }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+template<>
+struct Vectorized<c10::quint8> : public Vectorizedqi {
+    static constexpr int size() {
+        return 32;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 4;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 4;
+    }
+
+    using float_vec_return_type = std::array, 4>;
+    using int_vec_return_type = std::array, 4>;
+    using value_type = typename c10::quint8::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+    Vectorized() {}
+
+    Vectorized(__m256i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::quint8& val) {
+        value_type uw = val.val_;
+        vals = _mm256_set1_epi8(uw);
+    }
+
+    // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
+    C10_CLANG_DIAGNOSTIC_PUSH()
+    #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
+    C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
+    #endif
+    Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { }
+    C10_CLANG_DIAGNOSTIC_POP()
+
+    void store(void* ptr, int count = size()) const {
+        if (count != size()) {
+            memcpy(ptr, &vals, count * sizeof(value_type));
+        } else {
+            _mm256_storeu_si256((__m256i*)ptr, vals);
+        }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(
+            tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return _mm256_loadu_si256((const __m256i*)tmp_values);
+    }
+
+ private:
+    __m256i cvtepu8_epi32(__m128i epu8_vals) const {
+        return _mm256_cvtepu8_epi32(epu8_vals);
+    }
+
+ public:
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized /*zero_point*/,
+      Vectorized scale_zp_premul) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 =
+        vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul);
+    auto val1 =
+        vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul);
+    auto val2 =
+        vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul);
+    auto val3 =
+        vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul);
+    return {val0, val1, val2, val3};
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float /*scale*/,
+      int32_t zero_point,
+      float inverse_scale) {
+    auto* rhs_data = (float*)rhs.data();
+    uint8_t quantized_values[32];
+    QuantizeAvx2(
+        rhs_data, quantized_values, 32, inverse_scale, zero_point);
+    return Vectorized::loadu(quantized_values);
+  }
+
+  Vectorized maximum(Vectorized b) const {
+      return _mm256_max_epu8(vals, b.vals);
+    }
+
+  Vectorized minimum(Vectorized b) const {
+      return _mm256_min_epu8(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm256_min_epu8(
+          _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
+      __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+      __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+      __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+      __m256i int32_val0 = cvtepu8_epi32(int_val0);
+      __m256i int32_val1 = cvtepu8_epi32(int_val1);
+      __m256i int32_val2 = cvtepu8_epi32(int_val2);
+      __m256i int32_val3 = cvtepu8_epi32(int_val3);
+
+      __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
+      __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
+      __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
+      __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));
+
+      __m256i int32_b0 = cvtepu8_epi32(int_b0);
+      __m256i int32_b1 = cvtepu8_epi32(int_b1);
+      __m256i int32_b2 = cvtepu8_epi32(int_b2);
+      __m256i int32_b3 = cvtepu8_epi32(int_b3);
+
+      __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
+      __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
+      __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
+      __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);
+      return {Vectorized(res_0),
+              Vectorized(res_1),
+              Vectorized(res_2),
+              Vectorized(res_3)};
+    }
+
+    static Vectorized requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m256 multiplier_v = _mm256_set1_ps(multiplier);
+      __m256i zero_point_v = _mm256_set1_epi32(zero_point);
+      return RequantizeAvx2(inp, multiplier_v, zero_point_v);
+    }
+
+ private:
+
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+        vals = _mm256_loadu_si256((const __m256i*)ptr);
+    }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+#else
+
+// NOTE: These are low-performance implementations that we fall back on
+// if we are not building with AVX2. This may not be an issue, because
+// currently for quantization we assume the user has at least AVX512
+// installed, so these can simply act as a reference implementation.
+//
+// If in the future we relax this requirement (AVX2+), we should probably
+// revisit these implementations
+
+template <
+    typename T,
+    typename float_vec_return_type_,
+    typename int_vec_return_type_,
+    int size_>
+struct VectorizedQuantizedConverter {
+  static constexpr int size() {
+    return size_;
+  }
+
+  static constexpr int float_num_vecs() {
+    return size() / 8;
+  }
+
+  static constexpr int int_num_vecs() {
+    return size() / 8;
+  }
+
+  using float_vec_return_type = float_vec_return_type_;
+  using int_vec_return_type = int_vec_return_type_;
+
+  using value_type = typename T::underlying;
+  std::array vals;
+
+  VectorizedQuantizedConverter(T val) {
+    for (const auto i : c10::irange(size())) {
+      vals[i] = val.val_;
+    }
+  }
+
+  VectorizedQuantizedConverter(const void* ptr) {
+    memcpy(vals.data(), ptr, sizeof(value_type) * size());
+  }
+
+  void store(void* ptr, int count = size()) const {
+    memcpy(ptr, vals.data(), count * sizeof(value_type));
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point,
+      Vectorized /*scale_zp_premul*/) const {
+    float_vec_return_type rv;
+    for (const auto i : c10::irange(float_num_vecs())) {
+      float tmp_vals[8];
+      for (const auto j : c10::irange(8)) {
+        tmp_vals[j] = at::native::dequantize_val(
+            scale[j], zero_point[j], T(vals[8 * i + j]));
+      }
+      rv[i] = Vectorized(tmp_vals[0],
+          tmp_vals[1],
+          tmp_vals[2],
+          tmp_vals[3],
+          tmp_vals[4],
+          tmp_vals[5],
+          tmp_vals[6],
+          tmp_vals[7]);
+    }
+    return rv;
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    Vectorized scale_zp_premul;
+    return dequantize(scale, zero_point, scale_zp_premul);
+  }
+
+ protected:
+  VectorizedQuantizedConverter() {}
+};
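+
+// Illustrative note (explanatory comment, not upstream documentation): the
+// scalar fallback above dequantizes each element with
+// at::native::dequantize_val, i.e. scale * (q - zero_point); for example
+// scale = 0.25, zero_point = 3 and q = 11 give 0.25f * (11 - 3) = 2.0f.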
+
+template <>
+struct Vectorized : public VectorizedQuantizedConverter<
+                                 c10::qint32,
+                                 std::array, 1>,
+                                 std::array, 1>,
+                                 8> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array, 1>,
+            std::array, 1>,
+            8>() {}
+  Vectorized(c10::qint32 val)
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array, 1>,
+            std::array, 1>,
+            8>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array, 1>,
+            std::array, 1>,
+            8>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(
+        tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+    return Vectorized(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float /*inverse_scale*/) {
+    std::array qvals;
+    std::array float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 8], 8);
+    }
+
+    at::native::quantize_vec(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::qint32*)qvals.data(),
+        8 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const  {
+    return maximum(zero_point);
+  }
+
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    for (const auto i : c10::irange(size())) {
+      retval[0].vals[i] = vals[i] - b.vals[i];
+    }
+    return retval;
+  }
+
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] =
+          std::nearbyint(static_cast(inp[0].vals[i]) * multiplier) +
+          zero_point;
+    }
+    return retval;
+  }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized inline operator*(
+    const Vectorized& a,
+    const Vectorized& b) {
+  Vectorized retval;
+  for (const auto i : c10::irange(std::decay_t::size())) {
+    retval.vals[i] = a.vals[i] * b.vals[i];
+  }
+  return retval;
+}
+
+template <>
+Vectorized<c10::qint32> inline operator+(
+    const Vectorized<c10::qint32>& a,
+    const Vectorized<c10::qint32>& b) {
+  Vectorized<c10::qint32> retval;
+  for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
+    retval.vals[i] = a.vals[i] + b.vals[i];
+  }
+  return retval;
+}
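Because this fallback keeps its lanes in a plain vals[] array, the operators above behave exactly like scalar loops. A minimal usage sketch (hypothetical values; assumes this generic fallback path is the active CPU_CAPABILITY):

    // Elementwise ops on the generic Vectorized<c10::qint32> (8 lanes).
    at::vec::Vectorized<c10::qint32> a(c10::qint32(3));   // broadcast 3
    at::vec::Vectorized<c10::qint32> b(c10::qint32(4));   // broadcast 4
    auto sum  = a + b;                     // per-lane integer add
    auto prod = a * b;                     // per-lane integer multiply
    auto mx   = at::vec::maximum(a, b);    // per-lane max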
+
+template <>
+struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
+                                c10::qint8,
+                                std::array<Vectorized<float>, 4>,
+                                std::array<Vectorized<c10::qint32>, 4>,
+                                32> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>() {}
+  Vectorized(c10::qint8 val)
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(
+        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
+    return Vectorized(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float /*inverse_scale*/) {
+    std::array<value_type, size()> qvals;
+    std::array<float, float_num_vecs() * 8> float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 8], 8);
+    }
+
+    at::native::quantize_vec(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::qint8*)qvals.data(),
+        8 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const {
+    return maximum(zero_point);
+  }
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        retval[i].vals[j] =
+            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
+            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
+      }
+    }
+    return retval;
+  }
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    constexpr auto min_val = std::numeric_limits<value_type>::min();
+    constexpr auto max_val = std::numeric_limits<value_type>::max();
+    Vectorized retval;
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        int32_t rounded =
+            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
+            zero_point;
+        retval.vals[i * elem_per_int_vec + j] =
+            std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
+      }
+    }
+    return retval;
+  }
+};
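widening_subtract() and requantize_from_int() above are the two halves of the usual quantized-add recipe: subtract in int32 so nothing overflows, then rescale the int32 intermediate back into the narrow type. A scalar sketch of that flow for one int8 lane (illustrative only; multiplier would typically be input_scale / output_scale):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical one-lane model of widening_subtract + requantize_from_int.
    int8_t requantize_diff(int8_t a, int8_t b, float multiplier, int32_t zero_point) {
      int32_t wide = static_cast<int32_t>(a) - static_cast<int32_t>(b);         // widening subtract
      int32_t rounded =
          static_cast<int32_t>(std::nearbyint(wide * multiplier)) + zero_point; // rescale
      rounded = std::min<int32_t>(std::max<int32_t>(rounded, -128), 127);       // clamp to int8 range
      return static_cast<int8_t>(rounded);
    }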
+
+template <>
+Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
+  return a.maximum(b);
+}
+
+template <>
+struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
+                                 c10::quint8,
+                                 std::array<Vectorized<float>, 4>,
+                                 std::array<Vectorized<c10::qint32>, 4>,
+                                 32> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>() {}
+  Vectorized(c10::quint8 val)
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            32>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(
+        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
+    return Vectorized(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float /*inverse_scale*/) {
+    std::array<value_type, size()> qvals;
+    std::array<float, float_num_vecs() * 8> float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 8], 8);
+    }
+
+    at::native::quantize_vec(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::quint8*)qvals.data(),
+        8 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const {
+    return maximum(zero_point);
+  }
+
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        retval[i].vals[j] =
+            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
+            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
+      }
+    }
+    return retval;
+  }
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    constexpr auto min_val = std::numeric_limits<value_type>::min();
+    constexpr auto max_val = std::numeric_limits<value_type>::max();
+    Vectorized retval;
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        int32_t rounded =
+            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
+            zero_point;
+        retval.vals[i * elem_per_int_vec + j] =
+            std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
+      }
+    }
+    return retval;
+  }
+};
+
+template <>
+Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
+  return a.maximum(b);
+}
+
+#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
+}} // namespace at::vec::CPU_CAPABILITY
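For reference, the three specializations above all share the VectorizedQuantizedConverter layout: the last template argument is the number of quantized lanes, and float_num_vecs()/int_num_vecs() follow from it (8 floats per Vectorized<float>). A compile-time sketch of that geometry (the static_asserts are illustrative, not part of the header):

    // qint32 : 8 lanes  -> 1 x Vectorized<float>, 1 x Vectorized<c10::qint32>
    // qint8  : 32 lanes -> 4 x Vectorized<float>, 4 x Vectorized<c10::qint32>
    // quint8 : 32 lanes -> 4 x Vectorized<float>, 4 x Vectorized<c10::qint32>
    static_assert(at::vec::Vectorized<c10::qint32>::size() == 8, "");
    static_assert(at::vec::Vectorized<c10::qint8>::size() == 32, "");
    static_assert(at::vec::Vectorized<c10::quint8>::size() == 32, "");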
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..d904c712ed61d39267e0e9a1e580d50a7e943614
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h
@@ -0,0 +1,73 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
+    const Vectorized<BFloat16>& a) {
+  constexpr int64_t K = Vectorized<BFloat16>::size();
+  __at_align__ float arr[K];
+  __at_align__ BFloat16 arr2[K];
+  a.store(arr2);
+  convert(arr2, arr, K);
+  return std::make_tuple(
+      Vectorized<float>::loadu(arr),
+      Vectorized<float>::loadu(arr + Vectorized<float>::size()));
+}
+
+inline Vectorized<BFloat16> convert_float_bfloat16(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b) {
+  constexpr int64_t K = Vectorized<BFloat16>::size();
+  __at_align__ float arr[K];
+  __at_align__ BFloat16 arr2[K];
+  a.store(arr);
+  b.store(arr + Vectorized<float>::size());
+  convert(arr, arr2, K);
+  return Vectorized<BFloat16>::loadu(arr2);
+}
+
+inline void load_fp32_from_bf16(const c10::BFloat16* data, Vectorized<float>& out) {
+  __at_align__ float values[Vectorized<float>::size()];
+  for (const auto k : c10::irange(Vectorized<float>::size())) {
+    values[k] = data[k];
+  }
+  out = Vectorized<float>::loadu(values);
+}
+
+inline void load_fp32_from_bf16(
+    const c10::BFloat16* data,
+    Vectorized<float>& out1,
+    Vectorized<float>& out2) {
+  load_fp32_from_bf16(data, out1);
+  data += Vectorized<float>::size();
+  load_fp32_from_bf16(data, out2);
+}
+
+inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
+  __at_align__ float values[Vectorized<float>::size()];
+  for (const auto k : c10::irange(Vectorized<float>::size())) {
+    values[k] = data[k];
+  }
+  out = Vectorized<float>::loadu(values);
+}
+
+inline void load_fp32_from_fp16(
+    const c10::Half* data,
+    Vectorized<float>& out1,
+    Vectorized<float>& out2) {
+  load_fp32_from_fp16(data, out1);
+  data += Vectorized<float>::size();
+  load_fp32_from_fp16(data, out2);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
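These VSX helpers fall back to element-wise conversion through aligned scratch buffers. A short usage sketch of the round trip (hypothetical values; assumes this VSX header is the active CPU_CAPABILITY and the vec headers are included):

    using at::vec::Vectorized;

    // 16 bf16 lanes widen to two float vectors and narrow back again.
    Vectorized<c10::BFloat16> x(c10::BFloat16(1.5f));           // broadcast
    auto [lo, hi] = at::vec::convert_bfloat16_float(x);         // 2 x Vectorized<float>
    Vectorized<c10::BFloat16> y = at::vec::convert_float_bfloat16(lo, hi);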
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..c3f8ae2fc513430289ae989355b540dc20527123
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h
@@ -0,0 +1,246 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+// Note: header order is important here
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+
+namespace at {
+namespace vec {
+
+inline namespace CPU_CAPABILITY {
+
+DEFINE_CLAMP_FUNCS(c10::quint8)
+DEFINE_CLAMP_FUNCS(c10::qint8)
+DEFINE_CLAMP_FUNCS(c10::qint32)
+DEFINE_CLAMP_FUNCS(int16_t)
+DEFINE_CLAMP_FUNCS(int32_t)
+DEFINE_CLAMP_FUNCS(int64_t)
+DEFINE_CLAMP_FUNCS(float)
+DEFINE_CLAMP_FUNCS(double)
+
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      vec_madd(a.vec0(), b.vec0(), c.vec0()),
+      vec_madd(a.vec1(), b.vec1(), c.vec1())};
+}
+
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
+
+template <>
+Vectorized C10_ALWAYS_INLINE
+convert_to_int_of_same_size(const Vectorized& src) {
+  return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())};
+}
+
+template <>
+Vectorized C10_ALWAYS_INLINE
+convert_to_int_of_same_size(
+    const Vectorized& src) {
+  return Vectorized{vec_signed(src.vec0()), vec_signed(src.vec1())};
+}
+
+template <>
+inline void convert(const int32_t* src, float* dst, int64_t n) {
+  // int32_t and float have same size
+  int64_t i;
+  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
+    const int32_t* src_a = src + i;
+    float* dst_a = dst + i;
+    vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast<const vint32*>(src_a));
+    vint32 input_vec1 =
+        vec_vsx_ld(offset16, reinterpret_cast<const vint32*>(src_a));
+    vfloat32 c0 = vec_float(input_vec0);
+    vfloat32 c1 = vec_float(input_vec1);
+    vec_vsx_st(c0, offset0, dst_a);
+    vec_vsx_st(c1, offset16, dst_a);
+  }
+
+  for (; i < n; i++) {
+    dst[i] = static_cast<float>(src[i]);
+  }
+}
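The convert() specializations here follow the standard main-loop/tail pattern: process full vector-width blocks, then finish the remainder scalar by scalar. The same shape, written generically (a sketch, not code from this header):

    #include <cstdint>

    // Hypothetical generic form of the loop above: vec_body handles one
    // full block of width W starting at index i, the tail loop handles leftovers.
    template <typename SrcT, typename DstT, int W, typename VecFn>
    void convert_blocked(const SrcT* src, DstT* dst, int64_t n, VecFn vec_body) {
      int64_t i = 0;
      for (; i <= n - W; i += W) {
        vec_body(src + i, dst + i);             // full-width block
      }
      for (; i < n; ++i) {
        dst[i] = static_cast<DstT>(src[i]);     // scalar tail
      }
    }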
+
+template <>
+inline void convert(const int64_t* src, double* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
+    const int64_t* src_a = src + i;
+    double* dst_a = dst + i;
+    vint64 input_vec0 =
+        vec_vsx_ld(offset0, reinterpret_cast<const vint64*>(src_a));
+    vint64 input_vec1 =
+        vec_vsx_ld(offset16, reinterpret_cast<const vint64*>(src_a));
+    vfloat64 c0 = vec_double(input_vec0);
+    vfloat64 c1 = vec_double(input_vec1);
+    vec_vsx_st(c0, offset0, reinterpret_cast<double*>(dst_a));
+    vec_vsx_st(c1, offset16, reinterpret_cast<double*>(dst_a));
+  }
+  for (; i < n; i++) {
+    dst[i] = static_cast<double>(src[i]);
+  }
+}
+//Generic implementation to fix compiler error
+//TO-DO : Add optimized version for ppc64
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
+    const Vectorized<Half>& a) {
+  constexpr int64_t K = Vectorized<Half>::size();
+  __at_align__ float arr[K];
+  __at_align__ Half arr2[K];
+  a.store(arr2);
+  convert(arr2, arr, K);
+  return std::make_tuple(
+       Vectorized<float>::loadu(arr),
+       Vectorized<float>::loadu(arr + Vectorized<float>::size()));
+}
+
+inline Vectorized<Half> convert_float_half(
+    const Vectorized<float>& a, const Vectorized<float>& b) {
+  constexpr int64_t K = Vectorized<Half>::size();
+  __at_align__ float arr[K];
+  __at_align__ Half arr2[K];
+  a.store(arr);
+  b.store(arr + Vectorized<float>::size());
+  convert(arr, arr2, K);
+  return Vectorized<Half>::loadu(arr2);
+};
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
+    const Vectorized<double>& a,
+    const Vectorized<double>& b) {
+  // inputs:
+  //   a      = {a0, a1, a2, a3}
+  //   b      = {b0, b1, b2, b3}
+
+  vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0);
+  vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3);
+  vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0);
+  vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3);
+  //   return {a0, b0, a1, b1}
+  //          {a2, b2, a3, b3}
+  return std::make_pair(
+      Vectorized<double>{ab00, ab11}, Vectorized<double>{ab2_00, ab2_11});
+}
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
+    const Vectorized<double>& a,
+    const Vectorized<double>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1}
+  //   b = {a2, b2, a3, b3}
+  vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0);
+  vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0);
+
+  vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3);
+  vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3);
+
+  // swap lanes:
+  //   return {a0, a1, a2, a3}
+  //          {b0, b1, b2, b3}
+  return std::make_pair(
+      Vectorized<double>{aa01, aa23}, Vectorized<double>{bb_01, bb_23});
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3,, a4, a5, a6, a7}
+  //   b = {b0, b1, b2, b3,, b4, b5, b6, b7}
+
+  vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0());
+  vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0());
+
+  vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
+  vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1());
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1,, a2, b2, a3, b3}
+  //          {a4, b4, a5, b5,, a6, b6, a7, b7}
+
+  return std::make_pair(
+      Vectorized<float>{ab0011, ab2233}, Vectorized<float>{ab2_0011, ab2_2233});
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1,, a2, b2, a3, b3}
+  //   b = {a4, b4, a5, b5,, a6, b6, a7, b7}
+
+  // {a0,a2,b0,b2} {a1,a3,b1,b3}
+  vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
+  vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());
+
+  vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
+  vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);
+
+  vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
+  vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());
+
+  vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
+  vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);
+
+  // it could be done with vec_perm ,too
+  // swap lanes:
+  //   return {a0, a1, a2, a3,, a4, a5, a6, a7}
+  //          {b0, b1, b2, b3,, b4, b5, b6, b7}
+
+  return std::make_pair(
+      Vectorized<float>{aa0123, aa0123_2}, Vectorized<float>{bb0123, bb0123_2});
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
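interleave2/deinterleave2 above convert between struct-of-arrays and array-of-structs orderings two lanes at a time, and deinterleave2 undoes interleave2. A scalar model of the float case (illustrative only):

    #include <array>
    #include <utility>

    // Hypothetical scalar model of interleave2<float> on 8-lane vectors:
    // ({a0..a7}, {b0..b7}) -> ({a0,b0,a1,b1,a2,b2,a3,b3}, {a4,b4,...,a7,b7}).
    std::pair<std::array<float, 8>, std::array<float, 8>> interleave2_ref(
        const std::array<float, 8>& a, const std::array<float, 8>& b) {
      std::array<float, 8> lo{}, hi{};
      for (int i = 0; i < 4; ++i) {
        lo[2 * i] = a[i];
        lo[2 * i + 1] = b[i];
        hi[2 * i] = a[4 + i];
        hi[2 * i + 1] = b[4 + i];
      }
      return {lo, hi};
    }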
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..0f7147213550af8d7dd980482ad8ad30b1f458d0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
@@ -0,0 +1,560 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+using ComplexDbl = c10::complex<double>;
+
+template <>
+class Vectorized<ComplexDbl> {
+  union {
+    struct {
+      vfloat64 _vec0;
+      vfloat64 _vec1;
+    };
+    struct {
+      vbool64 _vecb0;
+      vbool64 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = ComplexDbl;
+  using vec_internal_type = vfloat64;
+  using vec_internal_mask_type = vbool64;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 2;
+  }
+  Vectorized() {}
+  C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
+
+  Vectorized(ComplexDbl val) {
+    double real_value = val.real();
+    double imag_value = val.imag();
+    _vec0 = vfloat64{real_value, imag_value};
+    _vec1 = vfloat64{real_value, imag_value};
+  }
+  Vectorized(ComplexDbl val1, ComplexDbl val2) {
+    _vec0 = vfloat64{val1.real(), val1.imag()};
+    _vec1 = vfloat64{val2.real(), val2.imag()};
+  }
+
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  template <int64_t mask>
+  static std::enable_if_t<mask == 0, Vectorized<ComplexDbl>>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
+    return a;
+  }
+
+  template <int64_t mask>
+  static std::enable_if_t<mask == 3, Vectorized<ComplexDbl>>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
+    return b;
+  }
+
+  template <int64_t mask>
+  static std::enable_if_t<mask == 1, Vectorized<ComplexDbl>>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template <int64_t mask>
+  static std::enable_if_t<mask == 2, Vectorized<ComplexDbl>>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
+    return {a._vec0, b._vec1};
+  }
+
+  template <int64_t mask>
+  static Vectorized C10_ALWAYS_INLINE
+  el_blend(const Vectorized& a, const Vectorized& b) {
+    const vbool64 mask_1st = VsxDblMask1(mask);
+    const vbool64 mask_2nd = VsxDblMask2(mask);
+    return {
+        (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
+        (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  static Vectorized blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // convert std::complex index mask to V index mask: xy -> xxyy
+    auto mask_complex =
+        Vectorized(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0));
+    return {
+        vec_sel(a._vec0, b._vec0, mask_complex._vecb0),
+        vec_sel(a._vec1, b._vec1, mask_complex._vecb1)};
+  }
+
+  static Vectorized C10_ALWAYS_INLINE elwise_blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
+  template <typename step_t>
+  static Vectorized arange(
+      ComplexDbl base = 0.,
+      step_t step = static_cast<step_t>(1)) {
+    return Vectorized(base, base + step);
+  }
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+    }
+    return b;
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {
+        vec_vsx_ld(offset0, reinterpret_cast(tmp_values)),
+        vec_vsx_ld(offset16, reinterpret_cast(tmp_values))};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values));
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+
+  const ComplexDbl& operator[](int idx) const = delete;
+  ComplexDbl& operator[](int idx) = delete;
+
+  Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const {
+    __at_align__ ComplexDbl tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+
+  Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const {
+    __at_align__ ComplexDbl tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+
+  Vectorized el_swapped() const {
+    vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2);
+    vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2);
+    return {v0, v1};
+  }
+
+  Vectorized el_madd(
+      const Vectorized& multiplier,
+      const Vectorized& val) const {
+    return {
+        vec_madd(_vec0, multiplier._vec0, val._vec0),
+        vec_madd(_vec1, multiplier._vec1, val._vec1)};
+  }
+
+  Vectorized el_mergeo() const {
+    vfloat64 v0 = vec_splat(_vec0, 1);
+    vfloat64 v1 = vec_splat(_vec1, 1);
+    return {v0, v1};
+  }
+
+  Vectorized el_mergee() const {
+    vfloat64 v0 = vec_splat(_vec0, 0);
+    vfloat64 v1 = vec_splat(_vec1, 0);
+    return {v0, v1};
+  }
+
+  static Vectorized el_mergee(
+      Vectorized& first,
+      Vectorized& second) {
+    return {
+        vec_mergeh(first._vec0, second._vec0),
+        vec_mergeh(first._vec1, second._vec1)};
+  }
+
+  static Vectorized el_mergeo(
+      Vectorized& first,
+      Vectorized& second) {
+    return {
+        vec_mergel(first._vec0, second._vec0),
+        vec_mergel(first._vec1, second._vec1)};
+  }
+
+  Vectorized abs_2_() const {
+    auto a = (*this).elwise_mult(*this);
+    auto permuted = a.el_swapped();
+    a = a + permuted;
+    return a;
+  }
+
+  Vectorized abs_() const {
+    auto vi = el_mergeo();
+    auto vr = el_mergee();
+    return {Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)};
+  }
+
+  Vectorized abs() const {
+    return abs_() & vd_real_mask;
+  }
+
+  Vectorized angle_() const {
+    // angle = atan2(b/a)
+    // auto b_a = _mm256_permute_pd(values, 0x05);     // b        a
+    // return Sleef_atan2d4_u10(values, b_a);          // 90-angle angle
+    Vectorized ret;
+    ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]);
+    ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]);
+    return ret;
+  }
+
+  Vectorized angle() const {
+    return angle_() & vd_real_mask;
+  }
+
+  Vectorized real_() const {
+    return *this & vd_real_mask;
+  }
+  Vectorized real() const {
+    return *this & vd_real_mask;
+  }
+  Vectorized imag_() const {
+    return *this & vd_imag_mask;
+  }
+  Vectorized imag() const {
+    return imag_().el_swapped();
+  }
+
+  Vectorized conj_() const {
+    return *this ^ vd_isign_mask;
+  }
+  Vectorized conj() const {
+    return *this ^ vd_isign_mask;
+  }
+
+  Vectorized log() const {
+    // Most trigonomic ops use the log() op to improve complex number
+    // performance.
+    return map(std::log);
+  }
+
+  Vectorized log2() const {
+    // log2eB_inv
+    auto ret = log();
+    return ret.elwise_mult(vd_log2e_inv);
+  }
+  Vectorized log10() const {
+    auto ret = log();
+    return ret.elwise_mult(vd_log10e_inv);
+  }
+
+  Vectorized log1p() const {
+    return map(std::log1p);
+  }
+
+  Vectorized asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+    auto conj = conj_();
+    auto b_a = conj.el_swapped();
+    auto ab = conj.elwise_mult(b_a);
+    auto im = ab + ab;
+    auto val_2 = (*this).elwise_mult(*this);
+    auto val_2_swapped = val_2.el_swapped();
+    auto re = horizontal_sub(val_2, val_2_swapped);
+    re = Vectorized(vd_one) - re;
+    auto root = el_blend<0x0A>(re, im).sqrt();
+    auto ln = (b_a + root).log();
+    return ln.el_swapped().conj();
+  }
+
+  Vectorized acos() const {
+    // acos(x) = pi/2 - asin(x)
+    return Vectorized(vd_pi_2) - asin();
+  }
+
+  Vectorized atan() const {
+    // atan(x) = i/2 * ln((i + z)/(i - z))
+    auto ione = Vectorized(vd_imag_one);
+    auto sum = ione + *this;
+    auto sub = ione - *this;
+    auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
+    return ln * vd_imag_half; // i/2*ln()
+  }
+  Vectorized atanh() const {
+    return map(std::atanh);
+  }
+
+  Vectorized sin() const {
+    return map(std::sin);
+  }
+  Vectorized sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized cos() const {
+    return map(std::cos);
+  }
+  Vectorized cosh() const {
+    return map(std::cosh);
+  }
+
+  Vectorized tan() const {
+    return map(std::tan);
+  }
+  Vectorized tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized ceil() const {
+    return {vec_ceil(_vec0), vec_ceil(_vec1)};
+  }
+  Vectorized floor() const {
+    return {vec_floor(_vec0), vec_floor(_vec1)};
+  }
+  Vectorized neg() const {
+    auto z = Vectorized(vd_zero);
+    return z - *this;
+  }
+  Vectorized round() const {
+    return {vec_rint(_vec0), vec_rint(_vec1)};
+  }
+
+  Vectorized trunc() const {
+    return {vec_trunc(_vec0), vec_trunc(_vec1)};
+  }
+
+  Vectorized elwise_sqrt() const {
+    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
+  }
+
+  Vectorized sqrt() const {
+    return map(std::sqrt);
+  }
+
+  Vectorized reciprocal() const {
+    // re + im*i = (a + bi)  / (c + di)
+    // re = (ac + bd)/abs_2() = c/abs_2()
+    // im = (bc - ad)/abs_2() = d/abs_2()
+    auto c_d = *this ^ vd_isign_mask; // c       -d
+    auto abs = abs_2_();
+    return c_d.elwise_div(abs);
+  }
+
+  Vectorized rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  static Vectorized horizontal_add(
+      Vectorized& first,
+      Vectorized& second) {
+    // Operates on individual floats, see _mm_hadd_ps
+    // {f0+f1, s0+s1, f2+f3, s2+s3, ...}
+    // i.e. it sums the re and im of each value and interleaves first and second:
+    // {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
+    return el_mergee(first, second) + el_mergeo(first, second);
+  }
+
+  static Vectorized horizontal_sub(
+      Vectorized& first,
+      Vectorized& second) {
+    // we will simulate it differently with 6 instructions total
+    // lets permute second so that we can add it getting horizontal sums
+    auto first_perm = first.el_swapped(); // 2perm
+    auto second_perm = second.el_swapped(); // 2perm
+    // summ
+    auto first_ret = first - first_perm; // 2sub
+    auto second_ret = second - second_perm; // 2 sub
+    // now lets choose evens
+    return el_mergee(first_ret, second_ret); // 2 mergee's
+  }
+
+  Vectorized inline operator*(const Vectorized& b) const {
+    //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+#if 1
+    // this is more vsx friendly than simulating horizontal from x86
+    auto vi = b.el_mergeo();
+    auto vr = b.el_mergee();
+    vi = vi ^ vd_rsign_mask;
+    auto ret = elwise_mult(vr);
+    auto vx_swapped = el_swapped();
+    ret = vx_swapped.el_madd(vi, ret);
+#else
+    auto ac_bd = elwise_mult(b);
+    auto d_c = b.el_swapped();
+    d_c = d_c ^ vd_isign_mask;
+    auto ad_bc = elwise_mult(d_c);
+    auto ret = horizontal_sub(ac_bd, ad_bc);
+#endif
+    return ret;
+  }
+
+  Vectorized inline operator/(const Vectorized& b) const {
+    // re + im*i = (a + bi)  / (c + di)
+    // re = (ac + bd)/abs_2()
+    // im = (bc - ad)/abs_2()
+    auto fabs_cd =  Vectorized{
+      vec_andc(b._vec0, vd_sign_mask),
+      vec_andc(b._vec1, vd_sign_mask)};       // |c|            |d|
+    auto fabs_dc =  fabs_cd.el_swapped();     // |d|            |c|
+    auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
+    auto a2 = elwise_div(scale);              // a/sc           b/sc
+    auto b2 = b.elwise_div(scale);            // c/sc           d/sc
+    auto acbd2 = a2.elwise_mult(b2);          // ac/sc^2        bd/sc^2
+    auto dc2 = b2.el_swapped();               // d/sc           c/sc
+    dc2 = dc2 ^ vd_rsign_mask;                // -d/sc          c/sc
+    auto adbc2 = a2.elwise_mult(dc2);         // -ad/sc^2       bc/sc^2
+    auto ret = horizontal_add(acbd2, adbc2);  // (ac+bd)/sc^2   (bc-ad)/sc^2
+    auto denom2 = b2.abs_2_();                // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
+    ret = ret.elwise_div(denom2);
+    return ret;
+  }
+
+  Vectorized exp() const {
+    return map(std::exp);
+  }
+  Vectorized exp2() const {
+    return map(exp2_impl);
+  }
+  Vectorized expm1() const {
+    return map(std::expm1);
+  }
+
+  Vectorized pow(const Vectorized& exp) const {
+    __at_align__ ComplexDbl x_tmp[size()];
+    __at_align__ ComplexDbl y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+
+  Vectorized sgn() const {
+    return map(at::native::sgn_impl);
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized eq(const Vectorized& other) const {
+    auto eq = (*this == other);  // compares real and imag individually
+    // If both real numbers and imag numbers are equal, then the complex numbers are equal
+    return (eq.real() & eq.imag()) & vd_one;
+  }
+  Vectorized ne(const Vectorized& other) const {
+    auto ne = (*this != other);  // compares real and imag individually
+    // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+    return (ne.real() | ne.imag()) & vd_one;
+  }
+
+  DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne)
+
+  DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add)
+  DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub)
+  DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and)
+  DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or)
+  DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor)
+  // elementwise helpers
+  DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul)
+  DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div)
+  DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt)
+  DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge)
+  DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt)
+  DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple)
+  DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max)
+};
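operator* above avoids x86-style horizontal adds: it splits b into its real parts (el_mergee) and imaginary parts (el_mergeo), flips signs where needed, and finishes with a fused multiply-add. In scalar terms the same computation is just the textbook product (a sketch for one element):

    #include <complex>

    // (a+bi)(c+di) = (ac - bd) + (ad + bc)i; the el_mergee/el_mergeo/el_madd
    // sequence computes exactly this, two complex doubles at a time.
    std::complex<double> mul_ref(std::complex<double> x, std::complex<double> y) {
      double a = x.real(), b = x.imag(), c = y.real(), d = y.imag();
      return {a * c - b * d, a * d + b * c};
    }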
+
+template <>
+Vectorized<ComplexDbl> inline maximum(
+    const Vectorized<ComplexDbl>& a,
+    const Vectorized<ComplexDbl>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
+  // auto max = _mm256_blendv_ps(a, b, mask);
+  auto mask = abs_a.elwise_lt(abs_b);
+  auto max = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
+
+  return max;
+  // Exploit the fact that all-ones is a NaN.
+  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  // return _mm256_or_ps(max, isnan);
+}
+
+template <>
+Vectorized<ComplexDbl> inline minimum(
+    const Vectorized<ComplexDbl>& a,
+    const Vectorized<ComplexDbl>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
+  // auto min = _mm256_blendv_ps(a, b, mask);
+  auto mask = abs_a.elwise_gt(abs_b);
+  auto min = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
+  return min;
+  // Exploit the fact that all-ones is a NaN.
+  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  // return _mm256_or_ps(min, isnan);
+}
+
+
+} // namespace
+} // namespace vec
+} // namespace at
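The operator/ above pre-scales both operands by max(|c|, |d|) before forming (ac+bd) and (bc-ad), which keeps the intermediate products from overflowing or underflowing for extreme inputs; the scale cancels in the final ratio. A scalar sketch of that scheme (illustrative, ignoring NaN/inf edge cases):

    #include <cmath>
    #include <complex>

    // Hypothetical scalar version of the scaled division used above.
    std::complex<double> div_scaled(std::complex<double> x, std::complex<double> y) {
      double a = x.real(), b = x.imag(), c = y.real(), d = y.imag();
      double sc = std::max(std::fabs(c), std::fabs(d));   // scale factor
      double a2 = a / sc, b2 = b / sc, c2 = c / sc, d2 = d / sc;
      double denom = c2 * c2 + d2 * d2;                   // |y|^2 / sc^2
      return {(a2 * c2 + b2 * d2) / denom, (b2 * c2 - a2 * d2) / denom};
    }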
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..8894381bfc718518788e18f17af0c5fb9d5734f1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
@@ -0,0 +1,628 @@
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+using ComplexFlt = c10::complex<float>;
+
+template <>
+class Vectorized<ComplexFlt> {
+ private:
+  union {
+    struct {
+      vfloat32 _vec0;
+      vfloat32 _vec1;
+    };
+    struct {
+      vbool32 _vecb0;
+      vbool32 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = ComplexFlt;
+  using vec_internal_type = vfloat32;
+  using vec_internal_mask_type = vbool32;
+  using size_type = int;
+
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+
+  C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
+
+  Vectorized(ComplexFlt val) {
+    float real_value = val.real();
+    float imag_value = val.imag();
+    _vec0 = vfloat32{real_value, imag_value, real_value, imag_value};
+    _vec1 = vfloat32{real_value, imag_value, real_value, imag_value};
+  }
+
+  Vectorized(ComplexFlt val1, ComplexFlt val2, ComplexFlt val3, ComplexFlt val4) {
+    _vec0 = vfloat32{val1.real(), val1.imag(), val2.real(), val2.imag()};
+    _vec1 = vfloat32{val3.real(), val3.imag(), val4.real(), val4.imag()};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    return {a._vec0, b._vec1};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxComplexMask1(mask);
+    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxComplexMask1(mask);
+    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_2nd = VsxComplexMask2(mask);
+    // generated masks
+    return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_2nd = VsxComplexMask2(mask);
+    // generated masks
+    return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxComplexMask1(mask);
+    const vbool32 mask_2nd = VsxComplexMask2(mask);
+    return {
+        (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
+        (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static Vectorized C10_ALWAYS_INLINE
+  el_blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxMask1(mask);
+    const vbool32 mask_2nd = VsxMask2(mask);
+    return {
+        (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
+        (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  static Vectorized blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // convert std::complex index mask to V index mask: xy -> xxyy
+    auto mask_complex = Vectorized(
+        vec_mergeh(mask._vec0, mask._vec0), vec_mergeh(mask._vec1, mask._vec1));
+    return {
+        vec_sel(a._vec0, b._vec0, reinterpret_cast(mask_complex._vec0)),
+        vec_sel(a._vec1, b._vec1, reinterpret_cast(mask_complex._vec1)),
+    };
+  }
+
+  static Vectorized elwise_blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    return {
+        vec_sel(a._vec0, b._vec0, reinterpret_cast(mask._vec0)),
+        vec_sel(a._vec1, b._vec1, reinterpret_cast(mask._vec1)),
+    };
+  }
+
+  template <typename step_t>
+  static Vectorized arange(
+      ComplexFlt base = 0.,
+      step_t step = static_cast<step_t>(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + ComplexFlt(2) * step,
+        base + ComplexFlt(3) * step);
+  }
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+    return b;
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {
+        vec_vsx_ld(offset0, reinterpret_cast(tmp_values)),
+        vec_vsx_ld(offset16, reinterpret_cast(tmp_values))};
+  }
+
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values));
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+
+  const ComplexFlt& operator[](int idx) const = delete;
+  ComplexFlt& operator[](int idx) = delete;
+
+  Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const {
+    __at_align__ ComplexFlt tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+
+  Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const {
+    __at_align__ ComplexFlt tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+
+  static Vectorized horizontal_add(
+      Vectorized& first,
+      Vectorized& second) {
+    // Operates on individual floats, see _mm_hadd_ps
+    // {f0+f1, s0+s1, f2+f3, s2+s3, ...}
+    // i.e. it sums the re and im of each value and interleaves first and second:
+    // {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
+    return el_mergee(first, second) + el_mergeo(first, second);
+  }
+
+  static Vectorized horizontal_sub_permD8(
+      Vectorized& first,
+      Vectorized& second) {
+    // we will simulate it differently with 6 instructions total
+    // lets permute second so that we can add it getting horizontal sums
+    auto first_perm = first.el_swapped(); // 2perm
+    auto second_perm = second.el_swapped(); // 2perm
+    // sum
+    auto first_ret = first - first_perm; // 2sub
+    auto second_ret = second - second_perm; // 2 sub
+    // now lets choose evens
+    return el_mergee(first_ret, second_ret); // 2 mergee's
+  }
+
+  Vectorized abs_2_() const {
+    auto a = (*this).elwise_mult(*this);
+    auto permuted = a.el_swapped();
+    a = a + permuted;
+    return a.el_mergee();
+  }
+
+  Vectorized abs_() const {
+    auto vi = el_mergeo();
+    auto vr = el_mergee();
+    return {Sleef_hypotf4_u05vsx(vr._vec0, vi._vec0), Sleef_hypotf4_u05vsx(vr._vec1, vi._vec1)};
+  }
+
+  Vectorized abs() const {
+    return abs_() & real_mask;
+  }
+
+  Vectorized real_() const {
+    return *this & real_mask;
+  }
+  Vectorized real() const {
+    return *this & real_mask;
+  }
+  Vectorized imag_() const {
+    return *this & imag_mask;
+  }
+  Vectorized imag() const {
+    // we can use swap_mask or sldwi
+    auto ret = imag_();
+    return {
+        vec_sldw(ret._vec0, ret._vec0, 3), vec_sldw(ret._vec1, ret._vec1, 3)};
+  }
+
+  Vectorized conj_() const {
+    return *this ^ isign_mask;
+  }
+  Vectorized conj() const {
+    return *this ^ isign_mask;
+  }
+
+  Vectorized log() const {
+    // Most trigonomic ops use the log() op to improve complex number
+    // performance.
+    return map(std::log);
+  }
+
+  Vectorized log2() const {
+    // log2eB_inv
+    auto ret = log();
+    return ret.elwise_mult(log2e_inv);
+  }
+  Vectorized log10() const {
+    auto ret = log();
+    return ret.elwise_mult(log10e_inv);
+  }
+
+  Vectorized log1p() const {
+    return map(std::log1p);
+  }
+
+  Vectorized el_swapped() const {
+    vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask);
+    vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask);
+    return {v0, v1};
+  }
+
+  Vectorized el_mergee() const {
+    // as mergee phased in , we can use vec_perm with mask
+    return {vec_mergee(_vecb0, _vecb0), vec_mergee(_vecb1, _vecb1)};
+  }
+
+  Vectorized el_mergeo() const {
+    // as mergeo phased in , we can use vec_perm with mask
+    return {vec_mergeo(_vecb0, _vecb0), vec_mergeo(_vecb1, _vecb1)};
+  }
+
+  Vectorized el_madd(
+      const Vectorized& multiplier,
+      const Vectorized& val) const {
+    return {
+        vec_madd(_vec0, multiplier._vec0, val._vec0),
+        vec_madd(_vec1, multiplier._vec1, val._vec1)};
+  }
+
+  static Vectorized el_mergee(
+      Vectorized& first,
+      Vectorized& second) {
+    return {
+        vec_mergee(first._vecb0, second._vecb0),
+        vec_mergee(first._vecb1, second._vecb1)};
+  }
+
+  static Vectorized el_mergeo(
+      Vectorized& first,
+      Vectorized& second) {
+    return {
+        vec_mergeo(first._vecb0, second._vecb0),
+        vec_mergeo(first._vecb1, second._vecb1)};
+  }
+
+  Vectorized angle_() const {
+    // angle = atan2(b/a)
+    // auto b_a = _mm256_permute_ps(values, 0xB1); // b        a
+    // return Sleef_atan2f8_u10(values, b_a); // 90-angle angle
+    Vectorized ret;
+    for (int i = 0; i < 4; i += 2) {
+      ret._vec0[i] = std::atan2(_vec0[i + 1], _vec0[i]);
+      ret._vec1[i] = std::atan2(_vec1[i + 1], _vec1[i]);
+    }
+    return ret;
+  }
+
+  Vectorized angle() const {
+    return angle_() & real_mask;
+  }
+
+  Vectorized sin() const {
+    return map(std::sin);
+  }
+  Vectorized sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized cos() const {
+    return map(std::cos);
+  }
+  Vectorized cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized ceil() const {
+    return {vec_ceil(_vec0), vec_ceil(_vec1)};
+  }
+  Vectorized floor() const {
+    return {vec_floor(_vec0), vec_floor(_vec1)};
+  }
+  Vectorized neg() const {
+    auto z = Vectorized(zero);
+    return z - *this;
+  }
+  Vectorized round() const {
+    return {vec_round(_vec0), vec_round(_vec1)};
+  }
+  Vectorized tan() const {
+    return map(std::tan);
+  }
+  Vectorized tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized trunc() const {
+    return {vec_trunc(_vec0), vec_trunc(_vec1)};
+  }
+
+  Vectorized elwise_sqrt() const {
+    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
+  }
+
+  Vectorized sqrt() const {
+    return map(std::sqrt);
+  }
+
+  Vectorized reciprocal() const {
+    // re + im*i = (a + bi)  / (c + di)
+    // re = (ac + bd)/abs_2() = c/abs_2()
+    // im = (bc - ad)/abs_2() = d/abs_2()
+    auto c_d = *this ^ isign_mask; // c       -d
+    auto abs = abs_2_();
+    return c_d.elwise_div(abs);
+  }
+
+  Vectorized rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  Vectorized pow(const Vectorized& exp) const {
+    __at_align__ ComplexFlt x_tmp[size()];
+    __at_align__ ComplexFlt y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+
+  Vectorized atan() const {
+    // atan(x) = i/2 * ln((i + z)/(i - z))
+    auto ione = Vectorized(imag_one);
+    auto sum = ione + *this;
+    auto sub = ione - *this;
+    auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
+    return ln * imag_half; // i/2*ln()
+  }
+  Vectorized atanh() const {
+    return map(std::atanh);
+  }
+
+  Vectorized acos() const {
+    // acos(x) = pi/2 - asin(x)
+    return Vectorized(pi_2) - asin();
+  }
+
+  Vectorized inline operator*(const Vectorized& b) const {
+    //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+
+#if 1
+    // this is more vsx friendly than simulating horizontal from x86
+
+    auto vi = b.el_mergeo();
+    auto vr = b.el_mergee();
+    vi = vi ^ rsign_mask;
+    auto ret = elwise_mult(vr);
+    auto vx_swapped = el_swapped();
+    ret = vx_swapped.el_madd(vi, ret);
+    return ret;
+
+#else
+
+    auto ac_bd = elwise_mult(b);
+    auto d_c = b.el_swapped();
+    d_c = d_c ^ isign_mask;
+    auto ad_bc = elwise_mult(d_c);
+    auto ret = horizontal_sub_permD8(ac_bd, ad_bc);
+    return ret;
+#endif
+  }
+
+  Vectorized inline operator/(const Vectorized& b) const {
+    // re + im*i = (a + bi)  / (c + di)
+    // re = (ac + bd)/abs_2()
+    // im = (bc - ad)/abs_2()
+    auto fabs_cd =  Vectorized{
+      vec_andc(b._vec0, sign_mask),
+      vec_andc(b._vec1, sign_mask)};          // |c|            |d|
+    auto fabs_dc =  fabs_cd.el_swapped();     // |d|            |c|
+    auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
+    auto a2 = elwise_div(scale);              // a/sc           b/sc
+    auto b2 = b.elwise_div(scale);            // c/sc           d/sc
+    auto acbd2 = a2.elwise_mult(b2);          // ac/sc^2        bd/sc^2
+    auto dc2 = b2.el_swapped();               // d/sc           c/sc
+    dc2 = dc2 ^ rsign_mask;                   // -d/sc          c/sc
+    auto adbc2 = a2.elwise_mult(dc2);         // -ad/sc^2       bc/sc^2
+    auto ret = horizontal_add(acbd2, adbc2);  // (ac+bd)/sc^2   (bc-ad)/sc^2
+    auto denom2 = b2.abs_2_();                // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
+    ret = ret.elwise_div(denom2);
+    return ret;
+  }
+
+  Vectorized asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+
+#if 1
+    auto conj = conj_();
+    auto b_a = conj.el_swapped();
+    auto ab = conj.elwise_mult(b_a);
+    auto im = ab + ab;
+    auto val_2 = (*this).elwise_mult(*this);
+    auto val_2_swapped = val_2.el_swapped();
+    auto re = horizontal_sub_permD8(val_2, val_2_swapped);
+    re = Vectorized(one) - re;
+    auto root = el_blend<0xAA>(re, im).sqrt();
+    auto ln = (b_a + root).log();
+    return ln.el_swapped().conj();
+#else
+    return map(std::asin);
+#endif
+  }
+
+  Vectorized exp() const {
+    return map(std::exp);
+  }
+  Vectorized exp2() const {
+    return map(exp2_impl);
+  }
+  Vectorized expm1() const {
+    return map(std::expm1);
+  }
+
+  Vectorized eq(const Vectorized& other) const {
+    auto eq = (*this == other);  // compares real and imag individually
+    // If both real numbers and imag numbers are equal, then the complex numbers are equal
+    return (eq.real() & eq.imag()) & one;
+  }
+  Vectorized ne(const Vectorized& other) const {
+    auto ne = (*this != other);  // compares real and imag individually
+    // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+    return (ne.real() | ne.imag()) & one;
+  }
+
+  Vectorized sgn() const {
+    return map(at::native::sgn_impl);
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  DEFINE_MEMBER_OP(operator==, ComplexFlt, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, ComplexFlt, vec_cmpne)
+
+  DEFINE_MEMBER_OP(operator+, ComplexFlt, vec_add)
+  DEFINE_MEMBER_OP(operator-, ComplexFlt, vec_sub)
+  DEFINE_MEMBER_OP(operator&, ComplexFlt, vec_and)
+  DEFINE_MEMBER_OP(operator|, ComplexFlt, vec_or)
+  DEFINE_MEMBER_OP(operator^, ComplexFlt, vec_xor)
+  // elementwise helpers
+  DEFINE_MEMBER_OP(elwise_mult, ComplexFlt, vec_mul)
+  DEFINE_MEMBER_OP(elwise_div, ComplexFlt, vec_div)
+  DEFINE_MEMBER_OP(elwise_gt, ComplexFlt, vec_cmpgt)
+  DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge)
+  DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt)
+  DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple)
+  DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max)
+};
+
+template <>
+Vectorized<ComplexFlt> inline maximum(
+    const Vectorized<ComplexFlt>& a,
+    const Vectorized<ComplexFlt>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
+  // auto max = _mm256_blendv_ps(a, b, mask);
+  auto mask = abs_a.elwise_lt(abs_b);
+  auto max = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);
+
+  return max;
+  // Exploit the fact that all-ones is a NaN.
+  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  // return _mm256_or_ps(max, isnan);
+}
+
+template <>
+Vectorized<ComplexFlt> inline minimum(
+    const Vectorized<ComplexFlt>& a,
+    const Vectorized<ComplexFlt>& b) {
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
+  // auto min = _mm256_blendv_ps(a, b, mask);
+  auto mask = abs_a.elwise_gt(abs_b);
+  auto min = Vectorized<ComplexFlt>::elwise_blendv(a, b, mask);
+  return min;
+  // Exploit the fact that all-ones is a NaN.
+  // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
+  // return _mm256_or_ps(min, isnan);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
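As with the double variant, maximum/minimum for complex values compare squared moduli (abs_2_) and then select whole complex lanes with elwise_blendv; unlike the commented-out AVX path, NaN propagation is not folded back in here. A scalar statement of the rule being applied (sketch only):

    #include <complex>

    // Hypothetical scalar rule: pick the operand with the larger |z|^2.
    std::complex<float> maximum_ref(std::complex<float> a, std::complex<float> b) {
      float na = a.real() * a.real() + a.imag() * a.imag();
      float nb = b.real() * b.real() + b.imag() * b.imag();
      return (na < nb) ? b : a;   // mask = |a|^2 < |b|^2 selects b
    }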
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..d9aeaa8650cf9deeedd123ef61a3f68804a85b20
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
@@ -0,0 +1,438 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+namespace vec {
+
+inline namespace CPU_CAPABILITY {
+
+
+template <>
+class Vectorized<double> {
+ private:
+  union {
+    struct {
+      vfloat64 _vec0;
+      vfloat64 _vec1;
+    };
+    struct {
+      vbool64 _vecb0;
+      vbool64 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = double;
+  using vec_internal_type = vfloat64;
+  using vec_internal_mask_type = vbool64;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+  C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(double scalar)
+      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
+  C10_ALWAYS_INLINE Vectorized(
+      double scalar1,
+      double scalar2,
+      double scalar3,
+      double scalar4)
+      : _vec0{vfloat64{scalar1, scalar2}}, _vec1{vfloat64{scalar3, scalar4}} {}
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  int zero_mask() const {
+    auto cmp = (*this == vd_zero);
+    return (cmp._vecb0[0] & 1) | (cmp._vecb0[1] & 2) | (cmp._vecb1[0] & 4) |
+        (cmp._vecb1[1] & 8);
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      return a;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      return b;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      return { b._vec0, a._vec1 };
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      return { a._vec0, b._vec1 };
+  }
+
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      const vbool64 mask_1st = VsxDblMask1(mask);
+      return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1 };
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+      blend(const Vectorized& a, const Vectorized& b) {
+      const vbool64 mask_1st = VsxDblMask1(mask);
+      return { (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1 };
+  }
+
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+      const vbool64 mask_2nd = VsxDblMask2(mask);
+      // generated masks
+      return { a._vec0,
+          (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+      const vbool64 mask_2nd = VsxDblMask2(mask);
+      // generated masks
+      return { b._vec0,
+          (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
+  }
+
+  template 
+  static std::enable_if_t>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+      const vbool64 mask_1st = VsxDblMask1(mask);
+      const vbool64 mask_2nd = VsxDblMask2(mask);
+      return {
+          (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
+          (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd) };
+  }
+
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // the mask used here is the result of a vec256 comparison
+
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
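+  // A minimal usage sketch (the variable names are illustrative): blendv picks
+  // b wherever the mask lane is true, so clamping negative lanes to zero is
+  //   auto x = Vectorized<double>::loadu(src);
+  //   auto y = Vectorized<double>::blendv(x, Vectorized<double>(0.0), x < Vectorized<double>(0.0));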
+  template 
+  static Vectorized arange(double base = 0., step_t step = static_cast(1)) {
+    return Vectorized(base, base + step, base + 2 * step, base + 3 * step);
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  set(const Vectorized& a, const Vectorized& b, size_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+
+    return b;
+  }
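+  // Illustrative intent of set(): take the first `count` lanes from b and the
+  // rest from a, so with the four double lanes here
+  //   set({a0,a1,a2,a3}, {b0,b1,b2,b3}, 2) == {b0, b1, a2, a3}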
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+  const double& operator[](int idx) const = delete;
+  double& operator[](int idx) = delete;
+  Vectorized map(double (*const f)(double)) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size()/2)) {
+        ret._vec0[i] = f(_vec0[i]);
+    }
+    for (const auto i : c10::irange(size()/2)) {
+        ret._vec1[i] = f(_vec1[i]);
+    }
+    return ret;
+  }
+
+  Vectorized mapbi(double (*const f)(double, double), const Vectorized& other)
+      const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size()/2)) {
+        ret._vec0[i] = f(_vec0[i], other._vec0[i]);
+    }
+    for (const auto i : c10::irange(size()/2)) {
+        ret._vec1[i] = f(_vec1[i], other._vec1[i]);
+    }
+    return ret;
+  }
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE acos() const {
+     return {Sleef_acosd2_u10(_vec0), Sleef_acosd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE asin() const {
+     return {Sleef_asind2_u10(_vec0), Sleef_asind2_u10(_vec1)};
+  }
+  Vectorized atan() const {
+     return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)};
+  }
+  Vectorized atanh() const {
+     return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)};
+  }
+  Vectorized atan2(const Vectorized& b) const {
+     return {Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)};
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return {Sleef_copysignd2(_vec0, sign._vec0), Sleef_copysignd2(_vec1, sign._vec1)};
+  }
+  Vectorized erf() const {
+     return {Sleef_erfd2_u10(_vec0), Sleef_erfd2_u10(_vec1)};
+  }
+  Vectorized erfc() const {
+     return {Sleef_erfcd2_u15(_vec0), Sleef_erfcd2_u15(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE exp() const {
+     return {Sleef_expd2_u10(_vec0), Sleef_expd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE exp2() const {
+    return {Sleef_exp2d2_u10(_vec0), Sleef_exp2d2_u10(_vec1)};
+  }
+  Vectorized expm1() const {
+     return {Sleef_expm1d2_u10(_vec0), Sleef_expm1d2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE exp_u20() const {
+     return exp();
+  }
+
+  Vectorized lgamma() const __ubsan_ignore_undefined__ {
+     return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)};
+  }
+
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+
+  Vectorized angle() const {
+    auto tmp = blendv(
+      Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+    return blendv(tmp, *this, isnan());
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  Vectorized C10_ALWAYS_INLINE log() const {
+     return {Sleef_logd2_u10(_vec0), Sleef_logd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log10() const {
+     return {Sleef_log10d2_u10(_vec0), Sleef_log10d2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log1p() const {
+     return {Sleef_log1pd2_u10(_vec0), Sleef_log1pd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log2() const {
+     return {Sleef_log2d2_u10(_vec0), Sleef_log2d2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE ceil() const {
+    return {vec_ceil(_vec0), vec_ceil(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE cos() const {
+     return {Sleef_cosd2_u10(_vec0), Sleef_cosd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE cosh() const {
+     return {Sleef_coshd2_u10(_vec0), Sleef_coshd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE floor() const {
+    return {vec_floor(_vec0), vec_floor(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {vec_neg(_vec0), vec_neg(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE round() const {
+    return {vec_rint(_vec0), vec_rint(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE sin() const {
+     return {Sleef_sind2_u10(_vec0), Sleef_sind2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE sinh() const {
+     return {Sleef_sinhd2_u10(_vec0), Sleef_sinhd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE tan() const {
+     return {Sleef_tand2_u10(_vec0), Sleef_tand2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE tanh() const {
+     return {Sleef_tanhd2_u10(_vec0), Sleef_tanhd2_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE trunc() const {
+    return {vec_trunc(_vec0), vec_trunc(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE frac() const {
+    return *this - trunc();
+  }
+
+  Vectorized C10_ALWAYS_INLINE sqrt() const {
+    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE reciprocal() const {
+    return {
+        vec_div(vd_one, _vec0), // vec_re(_vec0) would only give an estimate.
+        vec_div(vd_one, _vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  Vectorized C10_ALWAYS_INLINE pow(const Vectorized& b) const {
+     return {Sleef_powd2_u10(_vec0, b._vec0), Sleef_powd2_u10(_vec1, b._vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE fmod(const Vectorized& b) const {
+     return {Sleef_fmodd2(_vec0, b._vec0),Sleef_fmodd2(_vec1, b._vec1)};
+  }
+
+  Vectorized hypot(const Vectorized& b) const {
+     return {Sleef_hypotd2_u05(_vec0, b._vec0), Sleef_hypotd2_u05(_vec1, b._vec1)};
+  }
+
+  Vectorized nextafter(const Vectorized& b) const {
+     return {Sleef_nextafterd2(_vec0, b._vec0), Sleef_nextafterd2(_vec1, b._vec1)};
+  }
+
+  Vectorized igamma(const Vectorized& x) const {
+    return mapbi(calc_igamma, x);
+  }
+
+  Vectorized igammac(const Vectorized& x) const {
+    return mapbi(calc_igammac, x);
+  }
+
+
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+
+  Vectorized _nor() const {
+    return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)};
+  }
+
+  Vectorized isnan() const {
+    auto x = *this;
+    auto ret = (x == x);
+    return ret._nor();
+  }
+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  DEFINE_MEMBER_OP(operator==, double, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, double, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, double, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, double, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, double, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, double, vec_cmpge)
+  DEFINE_MEMBER_OP_AND_ONE(eq, double, vec_cmpeq)
+  DEFINE_MEMBER_OP_AND_ONE(ne, double, vec_cmpne)
+  DEFINE_MEMBER_OP_AND_ONE(lt, double, vec_cmplt)
+  DEFINE_MEMBER_OP_AND_ONE(le, double, vec_cmple)
+  DEFINE_MEMBER_OP_AND_ONE(gt, double, vec_cmpgt)
+  DEFINE_MEMBER_OP_AND_ONE(ge, double, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, double, vec_add)
+  DEFINE_MEMBER_OP(operator-, double, vec_sub)
+  DEFINE_MEMBER_OP(operator*, double, vec_mul)
+  DEFINE_MEMBER_OP(operator/, double, vec_div)
+  DEFINE_MEMBER_OP(maximum, double, vec_max_nan2)
+  DEFINE_MEMBER_OP(minimum, double, vec_min_nan2)
+  DEFINE_MEMBER_OP(operator&, double, vec_and)
+  DEFINE_MEMBER_OP(operator|, double, vec_or)
+  DEFINE_MEMBER_OP(operator^, double, vec_xor)
+  DEFINE_MEMBER_TERNARY_OP(madd, double, vec_madd)
+};
+template <>
+Vectorized<double> inline maximum(
+    const Vectorized<double>& a,
+    const Vectorized<double>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<double> inline minimum(
+    const Vectorized<double>& a,
+    const Vectorized<double>& b) {
+  return a.minimum(b);
+}
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..6c36cf92ed9b49afeffd0c3eef97693e842248dc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
@@ -0,0 +1,461 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
+#include <sleef.h>
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+
+inline namespace CPU_CAPABILITY {
+
+template <>
+class Vectorized<float> {
+ private:
+  union {
+    struct {
+      vfloat32 _vec0;
+      vfloat32 _vec1;
+    };
+    struct {
+      vbool32 _vecb0;
+      vbool32 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = float;
+  using vec_internal_type = vfloat32;
+  using vec_internal_mask_type = vbool32;
+  using size_type = int;
+
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+
+  C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(float scalar)
+      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
+  C10_ALWAYS_INLINE Vectorized(
+      float scalar1,
+      float scalar2,
+      float scalar3,
+      float scalar4,
+      float scalar5,
+      float scalar6,
+      float scalar7,
+      float scalar8)
+      : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}},
+        _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {}
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return {a._vec0, b._vec1};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxMask1(mask);
+    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxMask1(mask);
+    return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_2nd = VsxMask2(mask);
+    // generated masks
+    return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_2nd = VsxMask2(mask);
+    // generated masks
+    return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    const vbool32 mask_1st = VsxMask1(mask);
+    const vbool32 mask_2nd = VsxMask2(mask);
+    return {
+        (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
+        (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // the mask used here is the result of a vec256 comparison,
+    // so we can use it directly with vec_sel
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
+
+  template 
+  static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step);
+  }
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+
+    return b;
+  }
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
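+  // Sketch of the intended remainder handling (src, dst and n are illustrative
+  // names): loadu/store take an element count, so a tail shorter than size()
+  // can be processed without touching memory past the buffer:
+  //   int rem = n % Vectorized<float>::size();
+  //   auto v = Vectorized<float>::loadu(src, rem);
+  //   v.store(dst, rem);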
+
+  const float& operator[](int idx) const = delete;
+  float& operator[](int idx) = delete;
+
+  Vectorized map(float (*const f)(float)) const {
+    Vectorized ret;
+    for (int i = 0; i < size() / 2; i++) {
+      ret._vec0[i] = f(_vec0[i]);
+    }
+    for (int i = 0; i < size() / 2; i++) {
+      ret._vec1[i] = f(_vec1[i]);
+    }
+    return ret;
+  }
+
+  Vectorized mapbi(float (*const f)(float, float), const Vectorized& other)
+      const {
+    Vectorized ret;
+    for (int i = 0; i < size() / 2; i++) {
+      ret._vec0[i] = f(_vec0[i], other._vec0[i]);
+    }
+    for (int i = 0; i < size() / 2; i++) {
+      ret._vec1[i] = f(_vec1[i], other._vec1[i]);
+    }
+    return ret;
+  }
+
+  Vectorized _nor() const {
+    return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)};
+  }
+
+  Vectorized isnan() const {
+    auto x = *this;
+    auto ret = (x == x);
+    return ret._nor();
+  }
+
+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int zero_mask() const {
+    // returns an integer mask in which every zero element contributes a 1 bit
+    // and every non-zero element contributes a 0 bit
+    //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
+    auto cmp = (*this == zero);
+    // return _mm256_movemask_ps(cmp);
+    // possible simulation: mask = lvsl(0); vbpermq(vec, mask << 5)
+    vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits);
+    vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits);
+    return (result0[1] >> 12 | (result1[1] >> 8));
+  }
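+  // Worked example (illustrative, assuming the same bit order as
+  // _mm256_movemask_ps): for lanes {0.f, 3.f, 0.f, 5.f, 7.f, 0.f, 1.f, 2.f}
+  // the zero lanes are indices 0, 2 and 5, so zero_mask() == 0b00100101.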
+
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE acos() const {
+    return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE asin() const {
+    return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)};
+  }
+  Vectorized atan() const {
+    return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
+  }
+  Vectorized atanh() const {
+    return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)};
+  }
+  Vectorized atan2(const Vectorized& b) const {
+    return {Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)};
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return {Sleef_copysignf4(_vec0, sign._vec0), Sleef_copysignf4(_vec1, sign._vec1)};
+  }
+  Vectorized lgamma() const {
+    return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)};
+  }
+  Vectorized erf() const {
+    return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)};
+  }
+
+  Vectorized erfc() const {
+    return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)};
+  }
+
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+
+  Vectorized angle() const {
+    auto tmp = blendv(
+      Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+    return blendv(tmp, *this, isnan());
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  Vectorized C10_ALWAYS_INLINE exp() const {
+    return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE exp2() const {
+    return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)};
+  }
+  Vectorized expm1() const {
+    return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE exp_u20() const {
+    return exp();
+  }
+
+  Vectorized C10_ALWAYS_INLINE log() const {
+    return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log10() const {
+    return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log1p() const {
+    return {Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE log2() const {
+    return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE ceil() const {
+    return {vec_ceil(_vec0), vec_ceil(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE cos() const {
+    return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE cosh() const {
+    return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE floor() const {
+    return {vec_floor(_vec0), vec_floor(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {vec_neg(_vec0), vec_neg(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE round() const {
+    return {vec_round(_vec0), vec_round(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE sin() const {
+    return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE sinh() const {
+    return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE tan() const {
+    return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE tanh() const {
+    return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE trunc() const {
+    return {vec_trunc(_vec0), vec_trunc(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE frac() const {
+    return *this - trunc();
+  }
+
+  Vectorized C10_ALWAYS_INLINE sqrt() const {
+    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE reciprocal() const {
+    return Vectorized(one) / (*this);
+  }
+  Vectorized C10_ALWAYS_INLINE rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  Vectorized C10_ALWAYS_INLINE pow(const Vectorized& exp) const {
+    return {Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)};
+  }
+
+  Vectorized fmod(const Vectorized& b) const {
+    return {Sleef_fmodf4(_vec0, b._vec0),Sleef_fmodf4(_vec1, b._vec1)};
+  }
+
+  Vectorized hypot(const Vectorized& b) const {
+    return {Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)};
+  }
+
+  Vectorized nextafter(const Vectorized& b) const {
+    return {Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)};
+  }
+
+  Vectorized igamma(const Vectorized& x) const {
+    return mapbi(calc_igamma, x);
+  }
+
+  Vectorized igammac(const Vectorized& x) const {
+    return mapbi(calc_igammac, x);
+  }
+
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+
+  DEFINE_MEMBER_OP(operator==, float, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, float, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, float, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, float, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, float, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, float, vec_cmpge)
+  DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq)
+  DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne)
+  DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt)
+  DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple)
+  DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt)
+  DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, float, vec_add)
+  DEFINE_MEMBER_OP(operator-, float, vec_sub)
+  DEFINE_MEMBER_OP(operator*, float, vec_mul)
+  DEFINE_MEMBER_OP(operator/, float, vec_div)
+  DEFINE_MEMBER_OP(maximum, float, vec_max_nan2)
+  DEFINE_MEMBER_OP(minimum, float, vec_min_nan2)
+  DEFINE_MEMBER_OP(operator&, float, vec_and)
+  DEFINE_MEMBER_OP(operator|, float, vec_or)
+  DEFINE_MEMBER_OP(operator^, float, vec_xor)
+  DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd)
+};
+
+template <>
+Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return a.minimum(b);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..05a6b7d007f5e466d72b79602ad1b12f1ebb7dd9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h
@@ -0,0 +1,368 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+template <>
+class Vectorized<int16_t> {
+ private:
+  union {
+    struct {
+      vint16 _vec0;
+      vint16 _vec1;
+    };
+    struct {
+      vbool16 _vecb0;
+      vbool16 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = int16_t;
+  using vec_internal_type = vint16;
+  using vec_internal_mask_type = vbool16;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 16;
+  }
+  Vectorized() {}
+  C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(int16_t scalar)
+      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
+
+  C10_ALWAYS_INLINE Vectorized(
+      int16_t scalar1,
+      int16_t scalar2,
+      int16_t scalar3,
+      int16_t scalar4,
+      int16_t scalar5,
+      int16_t scalar6,
+      int16_t scalar7,
+      int16_t scalar8,
+      int16_t scalar9,
+      int16_t scalar10,
+      int16_t scalar11,
+      int16_t scalar12,
+      int16_t scalar13,
+      int16_t scalar14,
+      int16_t scalar15,
+      int16_t scalar16)
+      : _vec0{vint16{
+            scalar1,
+            scalar2,
+            scalar3,
+            scalar4,
+            scalar5,
+            scalar6,
+            scalar7,
+            scalar8}},
+        _vec1{vint16{
+            scalar9,
+            scalar10,
+            scalar11,
+            scalar12,
+            scalar13,
+            scalar14,
+            scalar15,
+            scalar16}} {}
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t<(mask & 65535) == 65535, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<(mask > 0 && mask < 255), Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr int16_t g0 = (mask & 1) * 0xffff;
+    constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
+    constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
+    constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
+    constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
+    constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
+    constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
+    constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
+    const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
+
+    return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr int16_t g0_2 = (mask & 1) * 0xffff;
+    constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
+    constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
+    constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
+    constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
+    constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
+    constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
+    constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
+
+    const vint16 mask_2nd =
+        vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
+    // generated masks
+    return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr int16_t mask2 = (mask & 65535) >> 16;
+    constexpr int16_t g0_2 = (mask & 1) * 0xffff;
+    constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
+    constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
+    constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
+    constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
+    constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
+    constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
+    constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
+
+    const vint16 mask_2nd =
+        vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
+    // generated masks
+    return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) &&
+       ((mask & 255) != 255)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr int16_t g0 = (mask & 1) * 0xffff;
+    constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
+    constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
+    constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
+    constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
+    constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
+    constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
+    constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
+    constexpr int16_t mask2 = (mask & 65535) >> 16;
+    constexpr int16_t g0_2 = (mask & 1) * 0xffff;
+    constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
+    constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
+    constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
+    constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
+    constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
+    constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
+    constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
+
+    const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
+    const vint16 mask_2nd =
+        vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
+    // generated masks
+    return {
+        (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st),
+        (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
+  }
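+  // How these compile-time masks work (a sketch of the intent): each set bit
+  // of `mask` requests the corresponding lane from b; the g* constants expand
+  // one mask bit into an all-ones or all-zeros 16-bit lane so vec_sel can
+  // consume the result as a per-lane boolean mask.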
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // the mask used here is the result of a vec256 comparison,
+    // so we can use it directly with vec_sel
+    // warning: an Intel-style bitmask will not work properly here
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
+
+  template 
+  static Vectorized arange(int16_t base = 0, step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step,
+        base + 8 * step,
+        base + 9 * step,
+        base + 10 * step,
+        base + 11 * step,
+        base + 12 * step,
+        base + 13 * step,
+        base + 14 * step,
+        base + 15 * step);
+  }
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+    }
+    return b;
+  }
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+  const int16_t& operator[](int idx) const = delete;
+  int16_t& operator[](int idx) = delete;
+
+  Vectorized angle() const {
+    return blendv(
+      Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {vec_neg(_vec0), vec_neg(_vec1)};
+  }
+
+  DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not)
+  DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge)
+  DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq)
+  DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne)
+  DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt)
+  DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple)
+  DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt)
+  DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, int16_t, vec_add)
+  DEFINE_MEMBER_OP(operator-, int16_t, vec_sub)
+  DEFINE_MEMBER_OP(operator*, int16_t, vec_mul)
+  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /)
+  DEFINE_MEMBER_OP(maximum, int16_t, vec_max)
+  DEFINE_MEMBER_OP(minimum, int16_t, vec_min)
+  DEFINE_MEMBER_OP(operator&, int16_t, vec_and)
+  DEFINE_MEMBER_OP(operator|, int16_t, vec_or)
+  DEFINE_MEMBER_OP(operator^, int16_t, vec_xor)
+};
+
+template <>
+Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
+  vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
+  vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
+  return Vectorized<int16_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
+}
+
+template <>
+Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
+  vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
+  vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
+  return Vectorized<int16_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
+}
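+// Note on the shift specializations above: b is reinterpreted as an unsigned
+// lane vector because vec_sl/vec_sr take unsigned shift counts; every lane of
+// a is shifted by the corresponding lane of b, e.g. (illustrative)
+//   Vectorized<int16_t>(1) << Vectorized<int16_t>(3)   // every lane becomes 8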
+
+template <>
+Vectorized<int16_t> inline maximum(
+    const Vectorized<int16_t>& a,
+    const Vectorized<int16_t>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<int16_t> inline minimum(
+    const Vectorized<int16_t>& a,
+    const Vectorized<int16_t>& b) {
+  return a.minimum(b);
+}
+
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..6299b43458b2b0ab6836d2d003a12f2ad8c31e6f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h
@@ -0,0 +1,298 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+template <>
+class Vectorized<int32_t> {
+ private:
+  union {
+    struct {
+      vint32 _vec0;
+      vint32 _vec1;
+    };
+    struct {
+      vbool32 _vecb0;
+      vbool32 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = int32_t;
+  using vec_internal_type = vint32;
+  using vec_internal_mask_type = vbool32;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+  C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(int32_t scalar)
+      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
+  C10_ALWAYS_INLINE Vectorized(
+      int32_t scalar1,
+      int32_t scalar2,
+      int32_t scalar3,
+      int32_t scalar4,
+      int32_t scalar5,
+      int32_t scalar6,
+      int32_t scalar7,
+      int32_t scalar8)
+      : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}},
+        _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {}
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t<(mask & 255) == 255, Vectorized> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<(mask > 0 && mask < 15), Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
+    constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
+    constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
+    constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
+    const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
+
+    return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint32_t mask2 = (mask & 255) >> 4;
+    constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
+    constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
+    constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
+    constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
+
+    const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
+    // generated masks
+    return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint32_t mask2 = (mask & 255) >> 4;
+    constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
+    constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
+    constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
+    constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
+
+    const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
+    // generated masks
+    return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) &&
+       ((mask & 15) != 15)),
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
+    constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
+    constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
+    constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
+    constexpr uint32_t mask2 = (mask & 255) >> 4;
+    constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
+    constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
+    constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
+    constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
+
+    const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
+    const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
+    // generated masks
+    return {
+        (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st),
+        (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
+  }
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // the mask used here is the result of a vec256 comparison,
+    // so we can use it directly with vec_sel
+    // warning: an Intel-style bitmask will not work properly here
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
+
+  template 
+  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step);
+  }
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+
+    return b;
+  }
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+  const int32_t& operator[](int idx) const = delete;
+  int32_t& operator[](int idx) = delete;
+
+  Vectorized angle() const {
+    return blendv(
+      Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {vec_neg(_vec0), vec_neg(_vec1)};
+  }
+
+  DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not)
+  DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge)
+  DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq)
+  DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne)
+  DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt)
+  DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple)
+  DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt)
+  DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, int32_t, vec_add)
+  DEFINE_MEMBER_OP(operator-, int32_t, vec_sub)
+  DEFINE_MEMBER_OP(operator*, int32_t, vec_mul)
+  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /)
+  DEFINE_MEMBER_OP(maximum, int32_t, vec_max)
+  DEFINE_MEMBER_OP(minimum, int32_t, vec_min)
+  DEFINE_MEMBER_OP(operator&, int32_t, vec_and)
+  DEFINE_MEMBER_OP(operator|, int32_t, vec_or)
+  DEFINE_MEMBER_OP(operator^, int32_t, vec_xor)
+};
+
+template <>
+Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
+  vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
+  vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1());
+  return Vectorized<int32_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
+}
+
+template <>
+Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
+  vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
+  vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1());
+  return Vectorized<int32_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
+}
+
+template <>
+Vectorized<int32_t> inline maximum(
+    const Vectorized<int32_t>& a,
+    const Vectorized<int32_t>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<int32_t> inline minimum(
+    const Vectorized<int32_t>& a,
+    const Vectorized<int32_t>& b) {
+  return a.minimum(b);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a52f763aa84bfd7c06006e1187f5cb38daee320
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h
@@ -0,0 +1,251 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+template <>
+class Vectorized<int64_t> {
+ private:
+  union {
+    struct {
+      vint64 _vec0;
+      vint64 _vec1;
+    };
+    struct {
+      vbool64 _vecb0;
+      vbool64 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  using value_type = int64_t;
+  using vec_internal_type = vint64;
+  using vec_internal_mask_type = vbool64;
+  using size_type = int;
+  using ElementType = signed long long;
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+  C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(int64_t scalar)
+      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
+  C10_ALWAYS_INLINE Vectorized(
+      int64_t scalar1,
+      int64_t scalar2,
+      int64_t scalar3,
+      int64_t scalar4)
+      : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {}
+
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<(mask & 15) == 15, Vectorized> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t<(mask > 0 && mask < 3), Vectorized> C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
+    constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
+    const vbool64 mask_1st = (vbool64){g0, g1};
+    return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
+    constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;
+
+    const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
+    return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t<
+      (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15,
+      Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
+    constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
+    constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
+    constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;
+
+    const vbool64 mask_1st = (vbool64){g0, g1};
+    const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
+    return {
+        (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st),
+        (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
+  }
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // the mask used here is the result of a vec256 comparison
+
+    return {
+        vec_sel(a._vec0, b._vec0, mask._vecb0),
+        vec_sel(a._vec1, b._vec1, mask._vecb1)};
+  }
+  template 
+  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
+    return Vectorized(base, base + step, base + 2 * step, base + 3 * step);
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  set(const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+
+    return b;
+  }
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      static_assert(sizeof(double) == sizeof(value_type));
+      const double* dptr = reinterpret_cast(ptr);
+      return {// treat it as double load
+              (vint64)vec_vsx_ld(offset0, dptr),
+              (vint64)vec_vsx_ld(offset16, dptr)};
+    }
+
+    __at_align__ double tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {
+        (vint64)vec_vsx_ld(offset0, tmp_values),
+        (vint64)vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      double* dptr = reinterpret_cast(ptr);
+      vec_vsx_st((vfloat64)_vec0, offset0, dptr);
+      vec_vsx_st((vfloat64)_vec1, offset16, dptr);
+    } else if (count > 0) {
+      __at_align__ double tmp_values[size()];
+      vec_vsx_st((vfloat64)_vec0, offset0, tmp_values);
+      vec_vsx_st((vfloat64)_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
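+  // Note: loadu/store above move the int64_t lanes through 64-bit double
+  // loads/stores (vec_vsx_ld/vec_vsx_st); the widths match (see the
+  // static_assert in loadu) and the bits are only reinterpreted, never
+  // converted, so e.g. (illustrative)
+  //   Vectorized<int64_t>::loadu(src).store(dst);   // bit-exact copy of 4 lanes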
+  const int64_t& operator[](int idx) const = delete;
+  int64_t& operator[](int idx) = delete;
+
+  Vectorized angle() const {
+    return blendv(
+      Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {vec_neg(_vec0), vec_neg(_vec1)};
+  }
+
+  DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not)
+  DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge)
+  DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq)
+  DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne)
+  DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt)
+  DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple)
+  DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt)
+  DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, int64_t, vec_add)
+  DEFINE_MEMBER_OP(operator-, int64_t, vec_sub)
+  DEFINE_MEMBER_OP(operator*, int64_t, vec_mul)
+  DEFINE_MEMBER_OP(operator/, int64_t, vec_div)
+  DEFINE_MEMBER_OP(maximum, int64_t, vec_max)
+  DEFINE_MEMBER_OP(minimum, int64_t, vec_min)
+  DEFINE_MEMBER_OP(operator&, int64_t, vec_and)
+  DEFINE_MEMBER_OP(operator|, int64_t, vec_or)
+  DEFINE_MEMBER_OP(operator^, int64_t, vec_xor)
+};
+
+template <>
+Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
+  vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
+  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
+  return Vectorized<int64_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
+}
+
+template <>
+Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
+  vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
+  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
+  return Vectorized<int64_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
+}
+
+template <>
+Vectorized<int64_t> inline maximum(
+    const Vectorized<int64_t>& a,
+    const Vectorized<int64_t>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<int64_t> inline minimum(
+    const Vectorized<int64_t>& a,
+    const Vectorized<int64_t>& b) {
+  return a.minimum(b);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..4687883eaa419a8012427fb19d73bd1eaf71bc89
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
@@ -0,0 +1,245 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
+#include <c10/util/qint32.h>
+#include <array>
+
+// This file defines Vectorized<> for the quantized types.
+//
+//
+// Currently, we simply use these classes as efficient converters between
+// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
+// where doing the arithmetic in full-precision is acceptable (e.g.
+// elementwise operators).
+//
+//
+// Conversions are as follows:
+//  Vectorized<qint32> -> 1x Vectorized<float>
+//
+// The size of the returned float vector is specified by the special
+// constexpr function float_num_vecs. The type of the value returned
+// from dequantize (and expected as an argument to quantize) is
+// specified by float_vec_return_type.
+//
+// When writing kernels with these vectors, it is expected that floating-
+// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
+// iterations.
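+//
+// For example (an illustrative sketch only, not part of this header), an
+// elementwise kernel would typically round-trip through float like this,
+// assuming src/dst point to qint32 data and the quantization parameters
+// (scalar and broadcast vector forms) are already in scope:
+//
+//   for (int64_t i = 0; i < n; i += Vectorized<c10::qint32>::size()) {
+//     auto qv = Vectorized<c10::qint32>::loadu(src + i);
+//     auto fvals = qv.dequantize(scale_v, zero_point_v, scale_zp_premul_v);
+//     for (auto& fv : fvals) {               // float_num_vecs() == 1 iteration
+//       fv = fv * Vectorized<float>(2.0f);   // arbitrary float-side math
+//     }
+//     Vectorized<c10::qint32>::quantize(fvals, scale, zero_point, inverse_scale)
+//         .store(dst + i);
+//   }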
+
+namespace at {
+namespace vec {
+inline namespace CPU_CAPABILITY {
+
+template <>
+struct Vectorized<c10::qint32> {
+ private:
+  union {
+    struct {
+      vint32 _vec0;
+      vint32 _vec1;
+    };
+    struct {
+      vbool32 _vecb0;
+      vbool32 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  Vectorized() {}
+
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+
+  static constexpr size_t float_num_vecs() {
+    return 1;
+  }
+  static constexpr int int_num_vecs() {
+    return 1;
+  }
+  using float_vec_return_type = std::array<Vectorized<float>, 1>;
+  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>;
+  using value_type = c10::qint32::underlying;
+  using vec_internal_type = vint32;
+  using vec_internal_mask_type = vbool32;
+  C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
+
+  Vectorized(const c10::qint32& val)
+      : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
+
+  static Vectorized<c10::qint32> C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
+    }
+
+    __at_align__ value_type tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point,
+      Vectorized<float> scale_zp_premul) const {
+    vfloat32 float_vals0 = vec_float(_vec0);
+    vfloat32 float_vals1 = vec_float(_vec1);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
+    vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
+    return {Vectorized<float>{
+        vec_madd(scale_vec0, float_vals0, scale_zp_premul0),
+        vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}};
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point) const {
+    vfloat32 float_vals0 = vec_float(_vec0);
+    vfloat32 float_vals1 = vec_float(_vec1);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 zero_point0 = zero_point.vec0();
+    vfloat32 zero_point1 = zero_point.vec1();
+    return {Vectorized<float>{
+        (float_vals0 - zero_point0) * scale_vec0,
+        (float_vals1 - zero_point1) * scale_vec1}};
+  }
+
+  static Vectorized<c10::qint32> quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    Vectorized<c10::qint32> retval;
+
+    const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
+    const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
+    vfloat32 inverse_scale_v = vec_splats(inverse_scale);
+    vfloat32 vec_zero_point = vec_splats((float)(zero_point));
+    Vectorized<float> vf0 = rhs[0];
+
+    vfloat32 vecf0 = vf0.vec0();
+    vfloat32 vecf1 = vf0.vec1();
+    vecf0 = vec_mul(vecf0, inverse_scale_v);
+    vecf1 = vec_mul(vecf1, inverse_scale_v);
+    vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
+    vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
+    vint32 veci0  = vec_signed(vecf0);
+    vint32 veci1  = vec_signed(vecf1);
+
+    veci0 = vec_max(veci0, vmin);
+    veci1 = vec_max(veci1, vmin);
+    veci0 = vec_min(veci0, vmax);
+    veci1 = vec_min(veci1, vmax);
+
+    return {veci0, veci1};
+  }
+
+  Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
+    return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
+  }
+
+  Vectorized<c10::qint32> relu6(
+      Vectorized<c10::qint32> zero_point,
+      Vectorized<c10::qint32> q_six) const {
+    vint32 max0 = vec_max(_vec0, zero_point._vec0);
+    vint32 max1 = vec_max(_vec1, zero_point._vec1);
+    return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
+  }
+
+  int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
+    return {*this - b};
+  }
+  }
+
+  static Vectorized<c10::qint32> requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
+    const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
+    vfloat32 vec_mult = vec_splats(multiplier);
+    vint32 vec_zero_point = vec_splats(zero_point);
+    Vectorized<c10::qint32> vi = inp[0];
+    vfloat32 vecf0 = vec_float(vi.vec0());
+    vfloat32 vecf1 = vec_float(vi.vec1());
+
+    vecf0 = vec_mul(vecf0, vec_mult);
+    vecf1 = vec_mul(vecf1, vec_mult);
+
+    vecf0 = vec_rint(vecf0);
+    vecf1 = vec_rint(vecf1);
+
+    vint32 veci0  = vec_add(vec_signed(vecf0),vec_zero_point);
+    vint32 veci1  = vec_add(vec_signed(vecf1),vec_zero_point);
+
+    veci0 = vec_max(veci0, vmin);
+    veci1 = vec_max(veci1, vmin);
+    veci0 = vec_min(veci0, vmax);
+    veci1 = vec_min(veci1, vmax);
+
+    return {veci0, veci1};
+  }
+
+  DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add)
+  DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub)
+  DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul)
+  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /)
+  DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max)
+  DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min)
+  DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and)
+  DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or)
+  DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor)
+};
+
+template <>
+Vectorized<c10::qint32> inline maximum(
+    const Vectorized<c10::qint32>& a,
+    const Vectorized<c10::qint32>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<c10::qint32> inline minimum(
+    const Vectorized<c10::qint32>& a,
+    const Vectorized<c10::qint32>& b) {
+  return a.minimum(b);
+}
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..2ed4242137b41831a3aaa4249f557180dc8de16a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h
@@ -0,0 +1,447 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// This file defines Vectorized<> for the quantized types.
+//
+//
+// Currently, we simply use these classes as efficient converters between
+// the quantized types and Vectorized, usually in bandwidth-bound cases
+// where doing the arithmetic in full-precision is acceptable (e.g.
+// elementwise operators).
+//
+//
+// Conversions are as follows:
+//  Vectorized<qint8> -> 4x Vectorized<float>
+//
+// The size of the returned float vector is specified by the special
+// constexpr function float_num_vecs. The type of the value returned
+// from dequantize (and expected as an argument to quantize) is
+// specified by float_vec_return_type.
+//
+// When writing kernels with these vectors, it is expected that floating-
+// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
+// iterations.
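+//
+// For example (an illustrative sketch only, not part of this header), an
+// integer-domain path such as a quantized elementwise subtraction would
+// typically widen to qint32 and then requantize, assuming a_ptr/b_ptr/out_ptr
+// point to qint8 data and multiplier/zero_point are in scope:
+//
+//   auto qa = Vectorized<c10::qint8>::loadu(a_ptr);
+//   auto qb = Vectorized<c10::qint8>::loadu(b_ptr);
+//   auto wide = qa.widening_subtract(qb);    // int_num_vecs() == 4 qint32 vectors
+//   Vectorized<c10::qint8>::requantize_from_int(wide, multiplier, zero_point)
+//       .store(out_ptr);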
+
+namespace at {
+namespace vec {
+inline namespace CPU_CAPABILITY {
+
+template <>
+struct Vectorized<c10::qint8> {
+ private:
+  union {
+    struct {
+      vint8 _vec0;
+      vint8 _vec1;
+    };
+    struct {
+      vbool8 _vecb0;
+      vbool8 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  Vectorized() {}
+  using size_type = int;
+  static constexpr size_type size() {
+    return 32;
+  }
+
+  static constexpr size_t float_num_vecs() {
+    return 4;
+  }
+  static constexpr int int_num_vecs() {
+    return 4;
+  }
+  using float_vec_return_type = std::array<Vectorized<float>, 4>;
+  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
+  using value_type = typename c10::qint8::underlying;
+  using vec_internal_type = vint8;
+  using vec_internal_mask_type = vbool8;
+  // Broadcast constructor
+  C10_ALWAYS_INLINE Vectorized(const c10::qint8& val)
+      : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {}
+
+  C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::qint8>& other)
+      : _vec0{other._vec0}, _vec1(other._vec1) {}
+
+  C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
+
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  static C10_ALWAYS_INLINE Vectorized<c10::qint8> loadu(
+      const void* ptr,
+      int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
+    }
+    __at_align__ value_type tmp_values[size()];
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+
+ public:
+  float_vec_return_type C10_ALWAYS_INLINE dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point,
+      Vectorized<float> scale_zp_premul) const {
+    vint16 vecshi0 = vec_unpackh(_vec0);
+    vint16 vecshi1 = vec_unpackl(_vec0);
+
+    vint16 vecshi2 = vec_unpackh(_vec1);
+    vint16 vecshi3 = vec_unpackl(_vec1);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+
+    vfloat32 vecf0_0 = vec_float(veci0);
+    vfloat32 vecf1_0 = vec_float(veci1);
+
+    vfloat32 vecf0_1 = vec_float(veci2);
+    vfloat32 vecf1_1 = vec_float(veci3);
+
+    vfloat32 vecf0_2 = vec_float(veci4);
+    vfloat32 vecf1_2 = vec_float(veci5);
+
+    vfloat32 vecf0_3 = vec_float(veci6);
+    vfloat32 vecf1_3 = vec_float(veci7);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
+    vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
+    return {
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
+  }
+
+  float_vec_return_type C10_ALWAYS_INLINE dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point) const {
+    vint16 vecshi0 = vec_unpackh(_vec0);
+    vint16 vecshi1 = vec_unpackl(_vec0);
+
+    vint16 vecshi2 = vec_unpackh(_vec1);
+    vint16 vecshi3 = vec_unpackl(_vec1);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+
+    vfloat32 vecf0_0 = vec_float(veci0);
+    vfloat32 vecf1_0 = vec_float(veci1);
+
+    vfloat32 vecf0_1 = vec_float(veci2);
+    vfloat32 vecf1_1 = vec_float(veci3);
+
+    vfloat32 vecf0_2 = vec_float(veci4);
+    vfloat32 vecf1_2 = vec_float(veci5);
+
+    vfloat32 vecf0_3 = vec_float(veci6);
+    vfloat32 vecf1_3 = vec_float(veci7);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 zero_point0 = zero_point.vec0();
+    vfloat32 zero_point1 = zero_point.vec1();
+    return {
+        Vectorized<float>{
+            (vecf0_0 - zero_point0) * scale_vec0,
+            (vecf1_0 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_1 - zero_point0) * scale_vec0,
+            (vecf1_1 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_2 - zero_point0) * scale_vec0,
+            (vecf1_2 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_3 - zero_point0) * scale_vec0,
+            (vecf1_3 - zero_point1) * scale_vec1}};
+  }
+
+  static Vectorized<c10::qint8> quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    // constexpr int32_t min_val = std::numeric_limits::min();
+    // constexpr int32_t max_val = std::numeric_limits::max();
+
+    vfloat32 inverse_scale_v = vec_splats(inverse_scale);
+    vfloat32 vec_zero_point = vec_splats((float)zero_point);
+    // vint32 vmin = vec_splats(min_val);
+    // vint32 vmax = vec_splats(max_val);
+
+    Vectorized<float> vf0 = rhs[0];
+    Vectorized<float> vf1 = rhs[1];
+    Vectorized<float> vf2 = rhs[2];
+    Vectorized<float> vf3 = rhs[3];
+    vfloat32 vecf0 = vf0.vec0();
+    vfloat32 vecf1 = vf0.vec1();
+    vfloat32 vecf2 = vf1.vec0();
+    vfloat32 vecf3 = vf1.vec1();
+
+    vfloat32 vecf4 = vf2.vec0();
+    vfloat32 vecf5 = vf2.vec1();
+    vfloat32 vecf6 = vf3.vec0();
+    vfloat32 vecf7 = vf3.vec1();
+
+    vecf0 = vec_mul(vecf0, inverse_scale_v);
+    vecf1 = vec_mul(vecf1, inverse_scale_v);
+    vecf2 = vec_mul(vecf2, inverse_scale_v);
+    vecf3 = vec_mul(vecf3, inverse_scale_v);
+
+    vecf4 = vec_mul(vecf4, inverse_scale_v);
+    vecf5 = vec_mul(vecf5, inverse_scale_v);
+    vecf6 = vec_mul(vecf6, inverse_scale_v);
+    vecf7 = vec_mul(vecf7, inverse_scale_v);
+
+    vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
+    vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
+    vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
+    vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
+
+    vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
+    vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
+    vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
+    vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
+
+    vint32 veci0 = vec_signed(vecf0);
+    vint32 veci1 = vec_signed(vecf1);
+    vint32 veci2 = vec_signed(vecf2);
+    vint32 veci3 = vec_signed(vecf3);
+
+    vint32 veci4 = vec_signed(vecf4);
+    vint32 veci5 = vec_signed(vecf5);
+    vint32 veci6 = vec_signed(vecf6);
+    vint32 veci7 = vec_signed(vecf7);
+
+    // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ;
+    // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ;
+    // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ;
+    // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ;
+
+    // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ;
+    // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ;
+    // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ;
+    // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ;
+    // vec_packs CLAMP already
+    vint16 vecshi0 = vec_packs(veci0, veci1);
+    vint16 vecshi1 = vec_packs(veci2, veci3);
+    vint16 vecshi2 = vec_packs(veci4, veci5);
+    vint16 vecshi3 = vec_packs(veci6, veci7);
+
+    vint8 vec0 = vec_packs(vecshi0, vecshi1);
+    vint8 vec1 = vec_packs(vecshi2, vecshi3);
+
+    return {vec0, vec1};
+  }
+
+  Vectorized<c10::qint8> C10_ALWAYS_INLINE relu(Vectorized<c10::qint8> zero_point) const {
+    return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
+  }
+
+  Vectorized<c10::qint8> C10_ALWAYS_INLINE
+  relu6(Vectorized<c10::qint8> zero_point, Vectorized<c10::qint8> q_six) const {
+    vint8 max0 = vec_max(_vec0, zero_point._vec0);
+    vint8 max1 = vec_max(_vec1, zero_point._vec1);
+    return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
+  }
+
+  int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
+    vint16 vecshi0 = vec_unpackh(_vec0);
+    vint16 vecBshi0 = vec_unpackh(b._vec0);
+    vint16 vecshi1 = vec_unpackl(_vec0);
+    vint16 vecBshi1 = vec_unpackl(b._vec0);
+
+    vint16 vecshi2 = vec_unpackh(_vec1);
+    vint16 vecBshi2 = vec_unpackh(b._vec1);
+    vint16 vecshi3 = vec_unpackl(_vec1);
+    vint16 vecBshi3 = vec_unpackl(b._vec1);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 vecBi0 = vec_unpackh(vecBshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+    vint32 vecBi1 = vec_unpackl(vecBshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 vecBi2 = vec_unpackh(vecBshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+    vint32 vecBi3 = vec_unpackl(vecBshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 vecBi4 = vec_unpackh(vecBshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+    vint32 vecBi5 = vec_unpackl(vecBshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 vecBi6 = vec_unpackh(vecBshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+    vint32 vecBi7 = vec_unpackl(vecBshi3);
+
+    return {
+        Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
+        Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
+        Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
+        Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
+  }
+
+  static Vectorized<c10::qint8> requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    vfloat32 vec_multiplier = vec_splats(multiplier);
+    vint32 vec_zero_point = vec_splats(zero_point);
+
+    Vectorized<c10::qint32> vi0 = inp[0];
+    Vectorized<c10::qint32> vi1 = inp[1];
+    Vectorized<c10::qint32> vi2 = inp[2];
+    Vectorized<c10::qint32> vi3 = inp[3];
+
+    vfloat32 vecf0 = vec_float(vi0.vec0());
+    vfloat32 vecf1 = vec_float(vi0.vec1());
+    vfloat32 vecf2 = vec_float(vi1.vec0());
+    vfloat32 vecf3 = vec_float(vi1.vec1());
+
+    vfloat32 vecf4 = vec_float(vi2.vec0());
+    vfloat32 vecf5 = vec_float(vi2.vec1());
+    vfloat32 vecf6 = vec_float(vi3.vec0());
+    vfloat32 vecf7 = vec_float(vi3.vec1());
+
+    vecf0 = vec_mul(vecf0, vec_multiplier);
+    vecf1 = vec_mul(vecf1, vec_multiplier);
+    vecf2 = vec_mul(vecf2, vec_multiplier);
+    vecf3 = vec_mul(vecf3, vec_multiplier);
+
+    vecf4 = vec_mul(vecf4, vec_multiplier);
+    vecf5 = vec_mul(vecf5, vec_multiplier);
+    vecf6 = vec_mul(vecf6, vec_multiplier);
+    vecf7 = vec_mul(vecf7, vec_multiplier);
+
+    vecf0 = vec_rint(vecf0);
+    vecf1 = vec_rint(vecf1);
+    vecf2 = vec_rint(vecf2);
+    vecf3 = vec_rint(vecf3);
+
+    vecf4 = vec_rint(vecf4);
+    vecf5 = vec_rint(vecf5);
+    vecf6 = vec_rint(vecf6);
+    vecf7 = vec_rint(vecf7);
+
+    vint32 veci0 = vec_signed(vecf0);
+    vint32 veci1 = vec_signed(vecf1);
+    vint32 veci2 = vec_signed(vecf2);
+    vint32 veci3 = vec_signed(vecf3);
+
+    vint32 veci4 = vec_signed(vecf4);
+    vint32 veci5 = vec_signed(vecf5);
+    vint32 veci6 = vec_signed(vecf6);
+    vint32 veci7 = vec_signed(vecf7);
+
+    veci0 = vec_add(veci0, vec_zero_point);
+    veci1 = vec_add(veci1, vec_zero_point);
+    veci2 = vec_add(veci2, vec_zero_point);
+    veci3 = vec_add(veci3, vec_zero_point);
+
+    veci4 = vec_add(veci4, vec_zero_point);
+    veci5 = vec_add(veci5, vec_zero_point);
+    veci6 = vec_add(veci6, vec_zero_point);
+    veci7 = vec_add(veci7, vec_zero_point);
+
+    vint16 vecshi0 = vec_packs(veci0, veci1);
+    vint16 vecshi1 = vec_packs(veci2, veci3);
+    vint16 vecshi2 = vec_packs(veci4, veci5);
+    vint16 vecshi3 = vec_packs(veci6, veci7);
+
+    vint8 vec0 = vec_packs(vecshi0, vecshi1);
+    vint8 vec1 = vec_packs(vecshi2, vecshi3);
+
+    return {vec0, vec1};
+  }
+
+  DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add)
+  DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub)
+  DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul)
+  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /)
+  DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max)
+  DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min)
+  DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and)
+  DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or)
+  DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor)
+};
+
+template <>
+Vectorized<c10::qint8> inline maximum(
+    const Vectorized<c10::qint8>& a,
+    const Vectorized<c10::qint8>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<c10::qint8> inline minimum(
+    const Vectorized<c10::qint8>& a,
+    const Vectorized<c10::qint8>& b) {
+  return a.minimum(b);
+}
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h
new file mode 100644
index 0000000000000000000000000000000000000000..85a0e79400b833db8cc5cb053f53cb78cdad00a4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h
@@ -0,0 +1,466 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+// This file defines Vectorized<> for the quantized types.
+//
+//
+// Currently, we simply use these classes as efficient converters between
+// the quantized types and Vectorized, usually in bandwidth-bound cases
+// where doing the arithmetic in full-precision is acceptable (e.g.
+// elementwise operators).
+//
+//
+// Conversions are as follows:
+//  Vectorized<quint8> -> 4x Vectorized<float>
+//
+// The size of the returned float vector is specified by the special
+// constexpr function float_num_vecs. The type of the value returned
+// from dequantize (and expected as an argument to quantize) is
+// specified by float_vec_return_type.
+//
+// When writing kernels with these vectors, it is expected that floating-
+// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
+// iterations.
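+//
+// For example (an illustrative sketch only, not part of this header), a
+// bandwidth-bound unary op would dequantize into four float vectors,
+// transform each, and quantize back, assuming src/dst point to quint8 data
+// and the quantization parameters are in scope:
+//
+//   auto qv = Vectorized<c10::quint8>::loadu(src);
+//   auto fvals = qv.dequantize(scale_v, zero_point_v);  // float_num_vecs() == 4
+//   for (auto& fv : fvals) {
+//     fv = fv * Vectorized<float>(0.5f);                // arbitrary float-side math
+//   }
+//   Vectorized<c10::quint8>::quantize(fvals, scale, zero_point, inverse_scale)
+//       .store(dst);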
+
+namespace at {
+namespace vec {
+inline namespace CPU_CAPABILITY {
+
+const vint16 mask_unsigned = vec_splats((short int)0xFF);
+template <>
+struct Vectorized<c10::quint8> {
+ private:
+  union {
+    struct {
+      vuint8 _vec0;
+      vuint8 _vec1;
+    };
+    struct {
+      vbool8 _vecb0;
+      vbool8 _vecb1;
+    };
+
+  } __attribute__((__may_alias__));
+
+ public:
+  Vectorized() {}
+  using size_type = int;
+  static constexpr size_type size() {
+    return 32;
+  }
+
+  static constexpr size_t float_num_vecs() {
+    return 4;
+  }
+  static constexpr int int_num_vecs() {
+    return 4;
+  }
+  using float_vec_return_type = std::array<Vectorized<float>, 4>;
+  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
+  using value_type = typename c10::quint8::underlying;
+  using vec_internal_type = vuint8;
+  using vec_internal_mask_type = vbool8;
+  // Broadcast constructor
+  C10_ALWAYS_INLINE Vectorized(const c10::quint8& val)
+      : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
+
+  C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::quint8>& other)
+      : _vec0{other._vec0}, _vec1(other._vec1) {}
+
+  C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
+  C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
+
+  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
+    return _vec0;
+  }
+  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
+    return _vec1;
+  }
+
+  static C10_ALWAYS_INLINE Vectorized<c10::quint8> loadu(
+      const void* ptr,
+      int count = size()) {
+    if (count == size()) {
+      return {
+          vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
+          vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
+    }
+    __at_align__ value_type tmp_values[size()];
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
+    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
+  }
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
+      vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
+    } else if (count > 0) {
+      __at_align__ value_type tmp_values[size()];
+      vec_vsx_st(_vec0, offset0, tmp_values);
+      vec_vsx_st(_vec1, offset16, tmp_values);
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
+    }
+  }
+
+ public:
+  float_vec_return_type C10_ALWAYS_INLINE dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point,
+      Vectorized<float> scale_zp_premul) const {
+    // unpacking unsigned as signed
+    vint16 vecshi0 = vec_unpackh((vint8)_vec0);
+    vint16 vecshi1 = vec_unpackl((vint8)_vec0);
+
+    vint16 vecshi2 = vec_unpackh((vint8)_vec1);
+    vint16 vecshi3 = vec_unpackl((vint8)_vec1);
+
+    // signed ->  unsigned
+    vecshi0 = vec_and(vecshi0, mask_unsigned);
+    vecshi1 = vec_and(vecshi1, mask_unsigned);
+
+    vecshi2 = vec_and(vecshi2, mask_unsigned);
+    vecshi3 = vec_and(vecshi3, mask_unsigned);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+
+    vfloat32 vecf0_0 = vec_float(veci0);
+    vfloat32 vecf1_0 = vec_float(veci1);
+
+    vfloat32 vecf0_1 = vec_float(veci2);
+    vfloat32 vecf1_1 = vec_float(veci3);
+
+    vfloat32 vecf0_2 = vec_float(veci4);
+    vfloat32 vecf1_2 = vec_float(veci5);
+
+    vfloat32 vecf0_3 = vec_float(veci6);
+    vfloat32 vecf1_3 = vec_float(veci7);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
+    vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
+    return {
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
+        Vectorized<float>{
+            vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
+            vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
+  }
+
+  float_vec_return_type C10_ALWAYS_INLINE dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point) const {
+    // unpacking unsigned as signed
+    vint16 vecshi0 = vec_unpackh((vint8)_vec0);
+    vint16 vecshi1 = vec_unpackl((vint8)_vec0);
+
+    vint16 vecshi2 = vec_unpackh((vint8)_vec1);
+    vint16 vecshi3 = vec_unpackl((vint8)_vec1);
+
+    // signed ->  unsigned
+    vecshi0 = vec_and(vecshi0, mask_unsigned);
+    vecshi1 = vec_and(vecshi1, mask_unsigned);
+
+    vecshi2 = vec_and(vecshi2, mask_unsigned);
+    vecshi3 = vec_and(vecshi3, mask_unsigned);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+
+    vfloat32 vecf0_0 = vec_float(veci0);
+    vfloat32 vecf1_0 = vec_float(veci1);
+
+    vfloat32 vecf0_1 = vec_float(veci2);
+    vfloat32 vecf1_1 = vec_float(veci3);
+
+    vfloat32 vecf0_2 = vec_float(veci4);
+    vfloat32 vecf1_2 = vec_float(veci5);
+
+    vfloat32 vecf0_3 = vec_float(veci6);
+    vfloat32 vecf1_3 = vec_float(veci7);
+    vfloat32 scale_vec0 = scale.vec0();
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 zero_point0 = zero_point.vec0();
+    vfloat32 zero_point1 = zero_point.vec1();
+    return {
+        Vectorized<float>{
+            (vecf0_0 - zero_point0) * scale_vec0,
+            (vecf1_0 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_1 - zero_point0) * scale_vec0,
+            (vecf1_1 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_2 - zero_point0) * scale_vec0,
+            (vecf1_2 - zero_point1) * scale_vec1},
+        Vectorized<float>{
+            (vecf0_3 - zero_point0) * scale_vec0,
+            (vecf1_3 - zero_point1) * scale_vec1}};
+  }
+
+  static Vectorized<c10::quint8> quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    // constexpr int32_t min_val = std::numeric_limits::min();
+    // constexpr int32_t max_val = std::numeric_limits::max();
+
+    vfloat32 vec_inverse = vec_splats(inverse_scale);
+    vfloat32 vec_zero_point = vec_splats((float)zero_point);
+    // vuint32 vmin = vec_splats(min_val);
+    // vuint32 vmax = vec_splats(max_val);
+    Vectorized<float> vf0 = rhs[0];
+    Vectorized<float> vf1 = rhs[1];
+    Vectorized<float> vf2 = rhs[2];
+    Vectorized<float> vf3 = rhs[3];
+    vfloat32 vecf0 = vf0.vec0();
+    vfloat32 vecf1 = vf0.vec1();
+    vfloat32 vecf2 = vf1.vec0();
+    vfloat32 vecf3 = vf1.vec1();
+
+    vfloat32 vecf4 = vf2.vec0();
+    vfloat32 vecf5 = vf2.vec1();
+    vfloat32 vecf6 = vf3.vec0();
+    vfloat32 vecf7 = vf3.vec1();
+
+    vecf0 = vec_mul(vecf0, vec_inverse);
+    vecf1 = vec_mul(vecf1, vec_inverse);
+    vecf2 = vec_mul(vecf2, vec_inverse);
+    vecf3 = vec_mul(vecf3, vec_inverse);
+
+    vecf4 = vec_mul(vecf4, vec_inverse);
+    vecf5 = vec_mul(vecf5, vec_inverse);
+    vecf6 = vec_mul(vecf6, vec_inverse);
+    vecf7 = vec_mul(vecf7, vec_inverse);
+
+    vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
+    vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
+    vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
+    vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
+
+    vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
+    vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
+    vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
+    vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
+
+    vint32 veci0 = vec_signed(vecf0);
+    vint32 veci1 = vec_signed(vecf1);
+    vint32 veci2 = vec_signed(vecf2);
+    vint32 veci3 = vec_signed(vecf3);
+
+    vint32 veci4 = vec_signed(vecf4);
+    vint32 veci5 = vec_signed(vecf5);
+    vint32 veci6 = vec_signed(vecf6);
+    vint32 veci7 = vec_signed(vecf7);
+
+    vint16 vecshi0 = vec_packs(veci0, veci1);
+    vint16 vecshi1 = vec_packs(veci2, veci3);
+    vint16 vecshi2 = vec_packs(veci4, veci5);
+    vint16 vecshi3 = vec_packs(veci6, veci7);
+
+    vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
+    vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
+
+    return {vec0, vec1};
+  }
+
+  Vectorized<c10::quint8> C10_ALWAYS_INLINE relu(Vectorized<c10::quint8> zero_point) const {
+    return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
+  }
+
+  Vectorized<c10::quint8> C10_ALWAYS_INLINE
+  relu6(Vectorized<c10::quint8> zero_point, Vectorized<c10::quint8> q_six) const {
+    vuint8 max0 = vec_max(_vec0, zero_point._vec0);
+    vuint8 max1 = vec_max(_vec1, zero_point._vec1);
+    return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
+  }
+
+  int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
+    vint16 vecshi0 = vec_unpackh((vint8)_vec0);
+    vint16 vecBshi0 = vec_unpackh((vint8)b._vec0);
+    vint16 vecshi1 = vec_unpackl((vint8)_vec0);
+    vint16 vecBshi1 = vec_unpackl((vint8)b._vec0);
+
+    vint16 vecshi2 = vec_unpackh((vint8)_vec1);
+    vint16 vecBshi2 = vec_unpackh((vint8)b._vec1);
+    vint16 vecshi3 = vec_unpackl((vint8)_vec1);
+    vint16 vecBshi3 = vec_unpackl((vint8)b._vec1);
+
+    vecshi0 = vec_and(vecshi0, mask_unsigned);
+    vecBshi0 = vec_and(vecBshi0, mask_unsigned);
+    vecshi1 = vec_and(vecshi1, mask_unsigned);
+    vecBshi1 = vec_and(vecBshi1, mask_unsigned);
+
+    vecshi2 = vec_and(vecshi2, mask_unsigned);
+    vecBshi2 = vec_and(vecBshi2, mask_unsigned);
+    vecshi3 = vec_and(vecshi3, mask_unsigned);
+    vecBshi3 = vec_and(vecBshi3, mask_unsigned);
+
+    vint32 veci0 = vec_unpackh(vecshi0);
+    vint32 vecBi0 = vec_unpackh(vecBshi0);
+    vint32 veci1 = vec_unpackl(vecshi0);
+    vint32 vecBi1 = vec_unpackl(vecBshi0);
+
+    vint32 veci2 = vec_unpackh(vecshi1);
+    vint32 vecBi2 = vec_unpackh(vecBshi1);
+    vint32 veci3 = vec_unpackl(vecshi1);
+    vint32 vecBi3 = vec_unpackl(vecBshi1);
+
+    vint32 veci4 = vec_unpackh(vecshi2);
+    vint32 vecBi4 = vec_unpackh(vecBshi2);
+    vint32 veci5 = vec_unpackl(vecshi2);
+    vint32 vecBi5 = vec_unpackl(vecBshi2);
+
+    vint32 veci6 = vec_unpackh(vecshi3);
+    vint32 vecBi6 = vec_unpackh(vecBshi3);
+    vint32 veci7 = vec_unpackl(vecshi3);
+    vint32 vecBi7 = vec_unpackl(vecBshi3);
+
+    return {
+        Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
+        Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
+        Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
+        Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
+  }
+
+  static Vectorized<c10::quint8> requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    vfloat32 vec_multiplier = vec_splats(multiplier);
+    vint32 vec_zero_point = vec_splats(zero_point);
+
+    Vectorized<c10::qint32> vi0 = inp[0];
+    Vectorized<c10::qint32> vi1 = inp[1];
+    Vectorized<c10::qint32> vi2 = inp[2];
+    Vectorized<c10::qint32> vi3 = inp[3];
+
+    vfloat32 vecf0 = vec_float(vi0.vec0());
+    vfloat32 vecf1 = vec_float(vi0.vec1());
+    vfloat32 vecf2 = vec_float(vi1.vec0());
+    vfloat32 vecf3 = vec_float(vi1.vec1());
+
+    vfloat32 vecf4 = vec_float(vi2.vec0());
+    vfloat32 vecf5 = vec_float(vi2.vec1());
+    vfloat32 vecf6 = vec_float(vi3.vec0());
+    vfloat32 vecf7 = vec_float(vi3.vec1());
+
+    vecf0 = vec_mul(vecf0, vec_multiplier);
+    vecf1 = vec_mul(vecf1, vec_multiplier);
+    vecf2 = vec_mul(vecf2, vec_multiplier);
+    vecf3 = vec_mul(vecf3, vec_multiplier);
+
+    vecf4 = vec_mul(vecf4, vec_multiplier);
+    vecf5 = vec_mul(vecf5, vec_multiplier);
+    vecf6 = vec_mul(vecf6, vec_multiplier);
+    vecf7 = vec_mul(vecf7, vec_multiplier);
+
+    vecf0 = vec_rint(vecf0);
+    vecf1 = vec_rint(vecf1);
+    vecf2 = vec_rint(vecf2);
+    vecf3 = vec_rint(vecf3);
+
+    vecf4 = vec_rint(vecf4);
+    vecf5 = vec_rint(vecf5);
+    vecf6 = vec_rint(vecf6);
+    vecf7 = vec_rint(vecf7);
+
+    vint32 veci0 = vec_signed(vecf0);
+    vint32 veci1 = vec_signed(vecf1);
+    vint32 veci2 = vec_signed(vecf2);
+    vint32 veci3 = vec_signed(vecf3);
+
+    vint32 veci4 = vec_signed(vecf4);
+    vint32 veci5 = vec_signed(vecf5);
+    vint32 veci6 = vec_signed(vecf6);
+    vint32 veci7 = vec_signed(vecf7);
+
+    veci0 = vec_add(veci0, vec_zero_point);
+    veci1 = vec_add(veci1, vec_zero_point);
+    veci2 = vec_add(veci2, vec_zero_point);
+    veci3 = vec_add(veci3, vec_zero_point);
+
+    veci4 = vec_add(veci4, vec_zero_point);
+    veci5 = vec_add(veci5, vec_zero_point);
+    veci6 = vec_add(veci6, vec_zero_point);
+    veci7 = vec_add(veci7, vec_zero_point);
+
+    vint16 vecshi0 = vec_packs(veci0, veci1);
+    vint16 vecshi1 = vec_packs(veci2, veci3);
+    vint16 vecshi2 = vec_packs(veci4, veci5);
+    vint16 vecshi3 = vec_packs(veci6, veci7);
+
+    vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
+    vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
+
+    return {vec0, vec1};
+  }
+
+  DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq)
+  DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne)
+  DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt)
+  DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple)
+  DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt)
+  DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge)
+  DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add)
+  DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub)
+  DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul)
+  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /)
+  DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max)
+  DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min)
+  DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and)
+  DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or)
+  DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor)
+};
+
+template <>
+Vectorized<c10::quint8> inline maximum(
+    const Vectorized<c10::quint8>& a,
+    const Vectorized<c10::quint8>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<c10::quint8> inline minimum(
+    const Vectorized<c10::quint8>& a,
+    const Vectorized<c10::quint8>& b) {
+  return a.minimum(b);
+}
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h
new file mode 100644
index 0000000000000000000000000000000000000000..c48f9fae148f123506ee82719f7a683c270836bd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h
@@ -0,0 +1,474 @@
+#pragma once
+#include 
+#include 
+#include 
+
+#if defined(__clang__)
+typedef __vector __bool char vbool8;
+typedef __vector __bool short vbool16;
+typedef __vector __bool int vbool32;
+typedef __vector __bool long long vbool64;
+using vint8    = __attribute__((vector_size(16))) signed char;
+using vint16   = __attribute__((vector_size(16))) signed short;
+using vint32   = __attribute__((vector_size(16))) signed int;
+using vint64   = __attribute__((vector_size(16))) signed long long;
+using vuint8   = __attribute__((vector_size(16))) unsigned char;
+using vuint16  = __attribute__((vector_size(16))) unsigned short;
+using vuint32  = __attribute__((vector_size(16))) unsigned int;
+using vuint64  = __attribute__((vector_size(16))) unsigned long long;
+using vfloat32 = __attribute__((vector_size(16))) float;
+using vfloat64 = __attribute__((vector_size(16))) double;
+#else
+using vbool8   =  __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char;
+using vbool16  =  __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short;
+using vbool32  =  __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) int;
+using vbool64  =  __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) long long;
+using vint8    =  __attribute__((altivec(vector__)))  signed char;
+using vint16   =  __attribute__((altivec(vector__)))  signed short;
+using vint32   =  __attribute__((altivec(vector__)))  signed int;
+using vint64   =  __attribute__((altivec(vector__)))  signed long long;
+using vuint8   =  __attribute__((altivec(vector__)))  unsigned char;
+using vuint16  =  __attribute__((altivec(vector__)))  unsigned short;
+using vuint32  =  __attribute__((altivec(vector__)))  unsigned  int;
+using vuint64  =  __attribute__((altivec(vector__)))  unsigned long long;
+using vfloat32 =  __attribute__((altivec(vector__)))  float;
+using vfloat64 =  __attribute__((altivec(vector__)))  double;
+#endif
+
+#if !defined(vec_float)
+C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) {
+  vfloat32 vec_out;
+  __asm__("xvcvsxwsp %x0,%x1" : "=wf"(vec_out) : "wa"(vec_in));
+  return vec_out;
+}
+#endif
+
+#if !defined(vec_signed)
+C10_ALWAYS_INLINE vint32 vec_signed(const vfloat32& vec_in) {
+  vint32 vec_out;
+  __asm__("xvcvspsxws %x0,%x1" : "=wa"(vec_out) : "wf"(vec_in));
+  return vec_out;
+}
+
+C10_ALWAYS_INLINE vint64 vec_signed(const vfloat64& vec_in) {
+  vint64 vec_out;
+  __asm__("xvcvdpsxds %x0,%x1" : "=wa"(vec_out) : "wd"(vec_in));
+  return vec_out;
+}
+#endif
+
+#if !defined(vec_neg)
+C10_ALWAYS_INLINE vfloat32 vec_neg(const vfloat32& vec_in) {
+  vfloat32 vec_out;
+  __asm__("xvnegsp %x0,%x1" : "=wf"(vec_out) : "wf"(vec_in));
+  return vec_out;
+}
+
+C10_ALWAYS_INLINE vfloat64 vec_neg(const vfloat64& vec_in) {
+  vfloat64 vec_out;
+  __asm__("xvnegdp %x0,%x1" : "=wd"(vec_out) : "wd"(vec_in));
+  return vec_out;
+}
+
+C10_ALWAYS_INLINE vint16 vec_neg(const vint16& vec_in) {
+  vint16 vint0 = {0, 0, 0, 0 ,0, 0, 0, 0};
+  return vec_vsubuhm(vint0, vec_in);
+}
+
+C10_ALWAYS_INLINE vint32 vec_neg(const vint32& vec_in) {
+  vint32 vint0 = {0, 0, 0, 0};
+  return vec_vsubuwm(vint0, vec_in);
+}
+
+C10_ALWAYS_INLINE vint64 vec_neg(const vint64& vec_in) {
+  return -vec_in;
+}
+#endif
+
+#if !defined(vec_sldw)
+template <unsigned int C>
+C10_ALWAYS_INLINE vfloat32
+vec_sldw_aux(const vfloat32& vec_in0, const vfloat32& vec_in1) {
+  vfloat32 vec_out;
+  __asm("xxsldwi %x0, %x1, %x2, %3 "
+        : "=wa"(vec_out)
+        : "wa"(vec_in0), "wa"(vec_in1), "I"(C));
+  return vec_out;
+}
+
+#define vec_sldw(a, b, c) vec_sldw_aux<c>(a, b)
+#endif
+
+#define vec_not(a) vec_nor(a, a)
+#if defined(__clang__) && !defined(vec_splats)
+C10_ALWAYS_INLINE vint64 vec_splats(const int64_t& a) {
+  return vec_splats(a);
+}
+#endif
+// Vectorized min/max which return a if any operand is nan
+template <typename T>
+C10_ALWAYS_INLINE T vec_min_nan(const T& a, const T& b) {
+  return vec_min(a, b);
+}
+template <typename T>
+C10_ALWAYS_INLINE T vec_max_nan(const T& a, const T& b) {
+  return vec_max(a, b);
+}
+
+// Specializations for float/double taken from Eigen
+template<>
+C10_ALWAYS_INLINE vfloat32 vec_min_nan<vfloat32>(const vfloat32& a, const vfloat32& b)
+{
+  // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+  vfloat32 ret;
+  __asm__ ("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+  return ret;
+}
+// Specializations for float/double taken from Eigen
+template<>
+C10_ALWAYS_INLINE vfloat32 vec_max_nan<vfloat32>(const vfloat32& a, const vfloat32& b)
+{
+  // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+  vfloat32 ret;
+  __asm__ ("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+  return ret;
+}
+
+template<>
+C10_ALWAYS_INLINE vfloat64 vec_min_nan<vfloat64>(const vfloat64& a, const vfloat64& b)
+{
+  // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN
+  vfloat64 ret;
+  __asm__ ("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+  return ret;
+}
+template<>
+C10_ALWAYS_INLINE vfloat64 vec_max_nan<vfloat64>(const vfloat64& a, const vfloat64& b)
+{
+  // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN
+  vfloat64 ret;
+  __asm__ ("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b));
+  return ret;
+}
+
+// Vectorizes min/max function which returns nan if any side is nan
+#define C10_VSX_VEC_NAN_PROPAG(name, type, btype, func)       \
+  C10_ALWAYS_INLINE type name(const type& a, const type& b) { \
+    type tmp = func(a, b);                                    \
+    btype nan_a = vec_cmpne(a, a);                            \
+    btype nan_b = vec_cmpne(b, b);                            \
+    tmp = vec_sel(tmp, a, nan_a);                             \
+    return vec_sel(tmp, b, nan_b);                            \
+  }
+
+C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat32, vbool32, vec_min)
+C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat32, vbool32, vec_max)
+C10_VSX_VEC_NAN_PROPAG(vec_min_nan2, vfloat64, vbool64, vec_min)
+C10_VSX_VEC_NAN_PROPAG(vec_max_nan2, vfloat64, vbool64, vec_max)
+
+#undef C10_VSX_VEC_NAN_PROPAG
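+
+// For example (illustrative values only): with a = {1.f, NaN, 3.f, 4.f} and
+// b = {NaN, 2.f, 0.f, 5.f},
+//   vec_min_nan(a, b)  -> {1.f, NaN, 0.f, 4.f}   (falls back to a on any NaN)
+//   vec_min_nan2(a, b) -> {NaN, NaN, 0.f, 4.f}   (propagates NaN from either side)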
+
+#define DEFINE_MEMBER_UNARY_OP(op, op_type, func)              \
+  Vectorized<op_type> C10_ALWAYS_INLINE op() const {           \
+    return Vectorized<op_type>{func(_vec0), func(_vec1)};      \
+  }
+
+#define DEFINE_MEMBER_OP(op, op_type, func)                                           \
+  Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const {  \
+    return Vectorized<op_type>{                                                       \
+        func(_vec0, other._vec0), func(_vec1, other._vec1)};                          \
+  }
+
+#define DEFINE_MEMBER_BITWISE_OP(op, op_type, func)                                   \
+  Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const {  \
+    return Vectorized<op_type>{                                                       \
+        func(_vecb0, other._vecb0), func(_vecb1, other._vecb1)};                      \
+  }
+
+#define DEFINE_MEMBER_TERNARY_OP(op, op_type, func)                        \
+  Vectorized<op_type> C10_ALWAYS_INLINE op(                                \
+      const Vectorized<op_type>& b, const Vectorized<op_type>& c) const {  \
+    return Vectorized<op_type>{                                            \
+        func(_vec0, b._vec0, c._vec0), func(_vec1, b._vec1, c._vec1)};     \
+  }
+
+#define DEFINE_MEMBER_EMULATE_BINARY_OP(op, op_type, binary_op)                    \
+  Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& b) const {   \
+    Vectorized<op_type>::vec_internal_type ret_0;                                  \
+    Vectorized<op_type>::vec_internal_type ret_1;                                  \
+    for (int i = 0; i < Vectorized<op_type>::size() / 2; i++) {                    \
+      ret_0[i] = _vec0[i] binary_op b._vec0[i];                                    \
+      ret_1[i] = _vec1[i] binary_op b._vec1[i];                                    \
+    }                                                                              \
+    return Vectorized<op_type>{ret_0, ret_1};                                      \
+  }
+
+
+#define DEFINE_MEMBER_OP_AND_ONE(op, op_type, func)                                   \
+  Vectorized<op_type> C10_ALWAYS_INLINE op(const Vectorized<op_type>& other) const {  \
+    using vvtype = Vectorized<op_type>::vec_internal_type;                            \
+    const vvtype v_one = vec_splats(static_cast<op_type>(1.0));                       \
+    vvtype ret0 = (vvtype)func(_vec0, other._vec0);                                   \
+    vvtype ret1 = (vvtype)func(_vec1, other._vec1);                                   \
+    return Vectorized<op_type>{vec_and(ret0, v_one), vec_and(ret1, v_one)};           \
+  }
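+
+// For example (illustrative expansion only), inside Vectorized<int64_t> the line
+//   DEFINE_MEMBER_OP(operator+, int64_t, vec_add)
+// from the macros above expands to roughly:
+//   Vectorized<int64_t> operator+(const Vectorized<int64_t>& other) const {
+//     return Vectorized<int64_t>{vec_add(_vec0, other._vec0),
+//                                vec_add(_vec1, other._vec1)};
+//   }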
+
+#define DEFINE_CLAMP_FUNCS(operand_type)                                        \
+  template <>                                                                   \
+  Vectorized<operand_type> C10_ALWAYS_INLINE clamp(                             \
+      const Vectorized<operand_type>& a,                                        \
+      const Vectorized<operand_type>& min,                                      \
+      const Vectorized<operand_type>& max) {                                    \
+    return Vectorized<operand_type>{                                            \
+        vec_min_nan(vec_max_nan(a.vec0(), min.vec0()), max.vec0()),             \
+        vec_min_nan(vec_max_nan(a.vec1(), min.vec1()), max.vec1())};            \
+  }                                                                             \
+  template <>                                                                   \
+  Vectorized<operand_type> C10_ALWAYS_INLINE clamp_min(                         \
+      const Vectorized<operand_type>& a, const Vectorized<operand_type>& min) { \
+    return Vectorized<operand_type>{                                            \
+        vec_max_nan(a.vec0(), min.vec0()),                                      \
+        vec_max_nan(a.vec1(), min.vec1())};                                     \
+  }                                                                             \
+  template <>                                                                   \
+  Vectorized<operand_type> C10_ALWAYS_INLINE clamp_max(                         \
+      const Vectorized<operand_type>& a, const Vectorized<operand_type>& max) { \
+    return Vectorized<operand_type>{                                            \
+        vec_min_nan(a.vec0(), max.vec0()),                                      \
+        vec_min_nan(a.vec1(), max.vec1())};                                     \
+  }
+
+#define DEFINE_REINTERPRET_CAST_FUNCS(                                 \
+    first_type, cast_type, cast_inner_vector_type)                     \
+  template <>                                                          \
+  C10_ALWAYS_INLINE Vectorized<cast_type> cast<cast_type, first_type>( \
+      const Vectorized<first_type>& src) {                             \
+    return Vectorized<cast_type>{(cast_inner_vector_type)src.vec0(),   \
+                                 (cast_inner_vector_type)src.vec1()};  \
+  }
+
+#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(first_type)     \
+  DEFINE_REINTERPRET_CAST_FUNCS(first_type, double, vfloat64)    \
+  DEFINE_REINTERPRET_CAST_FUNCS(first_type, float, vfloat32)     \
+  DEFINE_REINTERPRET_CAST_FUNCS(first_type, int64_t, vint64) \
+  DEFINE_REINTERPRET_CAST_FUNCS(first_type, int32_t, vint32)   \
+  DEFINE_REINTERPRET_CAST_FUNCS(first_type, int16_t, vint16)
+
+// it can be used to emulate blend faster
+constexpr int blendChoice(uint32_t mask, uint32_t half1 = 0xF, uint32_t half2 = 0xF0) {
+  uint32_t none = 0;
+  uint32_t both = half1 | half2;
+  // clamp it between 0 and both
+  mask = mask & both;
+  // return  (a._vec0, a._vec1)
+  if (mask == none) return 0;
+  // return (b._vec0,b._vec1)
+  else if (mask == both)
+    return 1;
+  // return  (b._vec0,a._vec1)
+  else if (mask == half1)
+    return 2;
+  // return  (a._vec0,b._vec1)
+  else if (mask == half2)
+    return 3;
+  // return  (*_vec0,a._vec1)
+  else if (mask > 0 && mask < half1)
+    return 4;
+  // return  (*_vec0,b._vec1)
+  else if ((mask & half2) == half2)
+    return 5;
+  // return (a._vec0,*_vec1)
+  else if ((mask & half1) == 0 && mask > half1)
+    return 6;
+  // return (b._vec0,*_vec1)
+  else if ((mask & half1) == half1 && mask > half1)
+    return 7;
+  // return (*_vec0,*_vec1)
+  return 8;
+}
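+
+// For example (illustrative values only): blendChoice(0x00) == 0 (keep a
+// entirely), blendChoice(0xFF) == 1 (take b entirely), blendChoice(0x0F) == 2
+// (b for _vec0, a for _vec1), and blendChoice(0x5A) == 8 (per-element blend
+// needed in both halves).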
+
+// it can be used to emulate blend faster
+constexpr int blendChoiceDbl(uint32_t mask) {
+  // clamp it 0 and 0xF
+  return blendChoice(mask, 0x3, 0xC);
+}
+
+constexpr vbool32 VsxMask1(uint32_t mask) {
+  uint32_t g0 = (mask & 1) * 0xffffffff;
+  uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
+  uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
+  uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
+  return (vbool32){g0, g1, g2, g3};
+}
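+
+// For example (illustrative values only): VsxMask1(0b0101) yields
+// {0xffffffff, 0x0, 0xffffffff, 0x0}, i.e. lanes 0 and 2 selected.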
+
+constexpr vbool32 VsxMask2(uint32_t mask) {
+  uint32_t mask2 = (mask & 0xFF) >> 4;
+  return VsxMask1(mask2);
+}
+
+constexpr vbool64 VsxDblMask1(uint32_t mask) {
+  uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
+  uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
+  return (vbool64){g0, g1};
+}
+
+constexpr vbool64 VsxDblMask2(uint32_t mask) {
+  uint32_t mask2 = (mask & 0xF) >> 2;
+  return VsxDblMask1(mask2);
+}
+
+constexpr int maskForComplex(uint32_t mask) {
+  mask = mask & 0xF;
+  int complex_mask = 0;
+  if (mask & 1) complex_mask |= 3;
+  if (mask & 2) complex_mask |= (3 << 2);
+  if (mask & 4) complex_mask |= (3 << 4);
+  if (mask & 8) complex_mask |= (3 << 6);
+  return complex_mask;
+}
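+
+// For example (illustrative values only): maskForComplex(0b01) == 0b11 (select
+// the first complex element, i.e. both of its float lanes) and
+// maskForComplex(0b10) == 0b1100 (select the second complex element).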
+
+constexpr int maskForComplexDbl(uint32_t mask) {
+  mask = mask & 0x3;
+  int complex_mask = 0;
+  if (mask & 1) complex_mask |= 3;
+  if (mask & 2) complex_mask |= (3 << 2);
+  return complex_mask;
+}
+
+constexpr int blendChoiceComplex(uint32_t mask) {
+  return blendChoice(maskForComplex(mask));
+}
+
+constexpr int blendChoiceComplexDbl(uint32_t mask) {
+  return blendChoiceDbl(maskForComplexDbl(mask));
+}
+
+constexpr vbool32 VsxComplexMask1(uint32_t mask) {
+  return VsxMask1(maskForComplex(mask));
+}
+
+constexpr vbool32 VsxComplexMask2(uint32_t mask) {
+  uint32_t mask2 = (mask & 0xF) >> 2;
+  return VsxMask1(maskForComplex(mask2));
+}
+
+constexpr vbool64 VsxComplexDblMask1(uint32_t mask) { return VsxDblMask1(mask); }
+
+constexpr vbool64 VsxComplexDblMask2(uint32_t mask) {
+  uint32_t mask2 = (mask & 0xF) >> 2;
+  return VsxDblMask1(mask2);
+}
+
+// constants
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+//
+constexpr int offset0 = 0;
+constexpr int offset16 = 16;
+
+// #Constants
+const vuint8 mask_zero_bits = vuint8{128, 128, 128, 128, 128, 128, 128, 128,
+                                128, 128, 128, 128, 96,  64,  32,  0};
+
+const vuint8 swap_mask =
+    vuint8{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11};
+
+const vint32 v0x7f = vec_splats(0x7f);
+const vint32 vi_0 = vec_splats((int)(0));
+const vint32 vi_1 = vec_splats((int)1);
+const vint32 vi_2 = vec_splats((int)2);
+const vint32 vi_4 = vec_splats((int)4);
+const vint32 vi_inv1 = vec_splats((int)~1);
+const vuint32 vu_29 = vec_splats(29u);
+const vuint32 vu_23 = vec_splats(23u);
+
+const vbool32 inv_mant_mask = (vbool32)vec_splats((unsigned int)~0xff800000);
+const vbool32 sign_mask = (vbool32)vec_splats((int)0x80000000);
+const vbool32 real_mask = vbool32{0xFFFFFFFF, 0x0, 0xFFFFFFFF, 0x0};
+const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF};
+const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000};
+const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0};
+
+const vbool64 vd_sign_mask  = vbool64{0x8000000000000000, 0x8000000000000000};
+const vbool64 vd_imag_mask  = vbool64{0x0, 0xFFFFFFFFFFFFFFFF};
+const vbool64 vd_real_mask  = vbool64{0xFFFFFFFFFFFFFFFF, 0x0};
+const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000};
+const vbool64 vd_rsign_mask = vbool64{0x8000000000000000, 0x0};
+
+const vfloat32 zero = vec_splats(0.f);
+const vfloat32 half = vec_splats(0.5f);
+const vfloat32 one = vec_splats(1.f);
+const vfloat32 two = vec_splats(2.0f);
+const vfloat32 _4div_pi = vec_splats(1.27323954473516f);
+const vfloat32 v_inf = (vfloat32)vec_splats(0x7f800000u);
+const vfloat32 v_minus_inf = vfloat32{ 0xff800000u, 0xff800000u, 0xff800000u, 0xff800000u };
+const vfloat32 v_nan = (vfloat32)vec_splats(0x7fffffff);
+const vfloat32 log10e_inv = vec_splats(0.43429448190325176f);
+const vfloat32 log2e_inv = vec_splats(1.4426950408889634f);
+const vfloat32 log2eB_inv = vec_splats(1.442695036924675f);
+const vfloat32 cephes_SQRTHF = vec_splats(0.707106781186547524f);
+const vfloat32 coscof_p0 = vec_splats(2.443315711809948E-005f);
+const vfloat32 coscof_p1 = vec_splats(-1.388731625493765E-003f);
+const vfloat32 coscof_p2 = vec_splats(4.166664568298827E-002f);
+const vfloat32 exp_hi = vec_splats(104.f);
+const vfloat32 exp_lo = vec_splats(-104.f);
+const vfloat32 exp_p0 = vec_splats(0.000198527617612853646278381f);
+const vfloat32 exp_p1 = vec_splats((0.00139304355252534151077271f));
+const vfloat32 exp_p2 = vec_splats(0.00833336077630519866943359f);
+const vfloat32 exp_p3 = vec_splats(0.0416664853692054748535156f);
+const vfloat32 exp_p4 = vec_splats(0.166666671633720397949219f);
+const vfloat32 exp_p5 = vec_splats(0.5f);
+const vfloat32 log_p0 = vec_splats(7.0376836292E-2f);
+const vfloat32 log_p1 = vec_splats(-1.1514610310E-1f);
+const vfloat32 log_p2 = vec_splats(1.1676998740E-1f);
+const vfloat32 log_p3 = vec_splats(-1.2420140846E-1f);
+const vfloat32 log_p4 = vec_splats(+1.4249322787E-1f);
+const vfloat32 log_p5 = vec_splats(-1.6668057665E-1f);
+const vfloat32 log_p6 = vec_splats(+2.0000714765E-1f);
+const vfloat32 log_p7 = vec_splats(-2.4999993993E-1f);
+const vfloat32 log_p8 = vec_splats(+3.3333331174E-1f);
+const vfloat32 log_q1 = vec_splats(-2.12194440e-4f);
+const vfloat32 log_q2 = vec_splats(0.693359375f);
+const vfloat32 max_logf = vec_splats(88.02969187150841f);
+const vfloat32 max_numf = vec_splats(1.7014117331926442990585209174225846272e38f);
+const vfloat32 min_inf = (vfloat32)vec_splats(0xff800000u);
+const vfloat32 min_norm_pos = (vfloat32)vec_splats(0x0800000u);
+const vfloat32 minus_cephes_dp1 = vec_splats(-0.78515625f);
+const vfloat32 minus_cephes_dp2 = vec_splats(-2.4187564849853515625e-4f);
+const vfloat32 minus_cephes_dp3 = vec_splats(-3.77489497744594108e-8f);
+const vfloat32 negln2f_hi = vec_splats(-0.693145751953125f);
+const vfloat32 negln2f_lo = vec_splats(-1.428606765330187045e-06f);
+const vfloat32 p0 = vec_splats(2.03721912945E-4f);
+const vfloat32 p1 = vec_splats(8.33028376239E-3f);
+const vfloat32 p2 = vec_splats(1.66667160211E-1f);
+const vfloat32 sincof_p0 = vec_splats(-1.9515295891E-4f);
+const vfloat32 sincof_p1 = vec_splats(8.3321608736E-3f);
+const vfloat32 sincof_p2 = vec_splats(-1.6666654611E-1f);
+const vfloat32 tanh_0p625 = vec_splats(0.625f);
+const vfloat32 tanh_half_max = vec_splats(44.014845935754205f);
+const vfloat32 tanh_p0 = vec_splats(-5.70498872745E-3f);
+const vfloat32 tanh_p1 = vec_splats(2.06390887954E-2f);
+const vfloat32 tanh_p2 = vec_splats(-5.37397155531E-2f);
+const vfloat32 tanh_p3 = vec_splats(1.33314422036E-1f);
+const vfloat32 tanh_p4 = vec_splats(-3.33332819422E-1f);
+const vfloat32 vcheck = vec_splats((float)(1LL << 24));
+const vfloat32 imag_one = vfloat32{0.f, 1.f, 0.f, 1.f};
+const vfloat32 imag_half = vfloat32{0.f, 0.5f, 0.f, 0.5f};
+const vfloat32 sqrt2_2 = vfloat32{0.70710676908493042f, 0.70710676908493042,
+                          0.70710676908493042, 0.70710676908493042};
+const vfloat32 pi_2 = vfloat32{M_PI / 2, 0.0, M_PI / 2, 0.0};
+const vfloat32 vf_89 = vfloat32{89.f, 89.f, 89.f, 89.f};
+const vfloat64 vd_one = vec_splats(1.0);
+const vfloat64 vd_zero = vec_splats(0.0);
+const vfloat64 vd_log10e_inv = vec_splats(0.43429448190325176);
+const vfloat64 vd_log2e_inv = vec_splats(1.4426950408889634);
+const vfloat64 vd_imag_one = vfloat64{0.0, 1.0};
+const vfloat64 vd_imag_half = vfloat64{0.0, 0.5};
+const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757};
+const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0};
+
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
new file mode 100644
index 0000000000000000000000000000000000000000..6284bfa6735f77a0e7d29937e0e5cbe057e69108
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -0,0 +1,2818 @@
+#include <cmath>
+#include <cstring>
+#include <limits>
+#include <type_traits>
+#include <utility>
+#if defined(__clang__)
+#include <sleef.h>
+#elif defined(__GNUC__) || defined(__GNUG__)
+#include <sleef.h>
+#include <vecintrin.h>
+#endif
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/complex.h>
+
+#define SLEEF_MEMORY_WORKAROUND
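+// When SLEEF_MEMORY_WORKAROUND is defined, sin/cos/tan below fall back to
+// element-wise std:: implementations (mapOrdinary) instead of the SLEEF
+// vector kernels.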
+
+namespace at {
+namespace vec {
+
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+template <typename T>
+constexpr bool is_zarch_implemented() {
+  return (
+      std::is_same<T, float>::value || std::is_same<T, double>::value ||
+      std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value ||
+      std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
+      std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value);
+}
+
+template <typename T>
+constexpr bool is_zarch_implemented_quant() {
+  return (
+      std::is_same<T, c10::qint32>::value ||
+      std::is_same<T, c10::qint8>::value ||
+      std::is_same<T, c10::quint8>::value);
+}
+
+template <typename T>
+constexpr bool is_zarch_implemented_complex() {
+  return std::is_same<T, c10::complex<float>>::value ||
+      std::is_same<T, c10::complex<double>>::value;
+}
+
+constexpr int offset0 = 0;
+constexpr int offset16 = 16;
+
+template <int N>
+struct VecBinaryType {
+  using type __attribute__((vector_size(16))) = uintmax_t;
+};
+
+template <>
+struct VecBinaryType<8> {
+  using type = __attribute__((vector_size(16))) unsigned long long;
+};
+
+template <>
+struct VecBinaryType<4> {
+  using type = __attribute__((vector_size(16))) unsigned int;
+};
+
+template <>
+struct VecBinaryType<2> {
+  using type = __attribute__((vector_size(16))) unsigned short;
+};
+
+template <>
+struct VecBinaryType<1> {
+  using type = __attribute__((vector_size(16))) unsigned char;
+};
+
+template <typename T>
+struct VecInnerType {
+  using Type __attribute__((vector_size(16))) = T;
+  using BinaryType = typename VecBinaryType<sizeof(T)>::type;
+  using ElementType = T;
+  static constexpr int size = 16 / sizeof(T);
+};
+
+// define for int64_t properly for load
+template <>
+struct VecInnerType<int64_t> {
+  using Type = __attribute__((vector_size(16))) signed long long;
+  using ElementType = signed long long;
+  using BinaryType = typename VecBinaryType<sizeof(signed long long)>::type;
+  static constexpr int size = 16 / sizeof(signed long long);
+};
+
+template <typename T>
+using ZSimdVect = typename VecInnerType<T>::Type;
+template <typename T>
+using ZSimdVectBinary = typename VecInnerType<T>::BinaryType;
+template <typename T>
+using ZSimdVectElement = typename VecInnerType<T>::ElementType;
+
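+// Maps a blend mask onto one of nine strategies: copy a, copy b, take whole
+// 16-byte halves from either operand, or build a per-lane vec_sel mask for
+// one or both halves (see the blend() overloads further down).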
+constexpr int blendChoiceInner(
+    const uint64_t mask,
+    const uint64_t half1 = 0xF,
+    const uint64_t half2 = 0xF0) {
+  uint64_t none = 0;
+  uint64_t both = half1 | half2;
+  // clamp it between 0 and both
+  auto res_mask = mask & both;
+  // return  (a._vec0, a._vec1)
+  if (res_mask == none)
+    return 0;
+  // return (b._vec0,b._vec1)
+  else if (res_mask == both)
+    return 1;
+  // return  (b._vec0, a._vec1)
+  else if (res_mask == half1)
+    return 2;
+  // return  (a._vec0,b._vec1)
+  else if (res_mask == half2)
+    return 3;
+  // return  (*_vec0,a._vec1)
+  else if (res_mask > 0 && res_mask < half1)
+    return 4;
+  // return  (*_vec0,b._vec1)
+  else if ((res_mask & half2) == half2)
+    return 5;
+  // return (a._vec0,*_vec1)
+  else if ((res_mask & half1) == 0 && res_mask > half1)
+    return 6;
+  // return (b._vec0,*_vec1)
+  else if ((res_mask & half1) == half1 && res_mask > half1)
+    return 7;
+  // return (*_vec0,*_vec1)
+  return 8;
+}
+
+// it can be used to emulate blend faster
+template <int Z>
+constexpr int blendChoice(const uint64_t mask) {
+  static_assert(Z < 1 || Z > 8, "not implemented");
+  return blendChoiceInner(mask);
+}
+
+template <>
+constexpr int blendChoice<1>(const uint64_t mask) {
+  return blendChoiceInner(mask, 0x0000FFFF, 0xFFFF0000);
+}
+
+template <>
+constexpr int blendChoice<2>(const uint64_t mask) {
+  return blendChoiceInner(mask, 0x00FF, 0xFF00);
+}
+
+template <>
+constexpr int blendChoice<4>(const uint64_t mask) {
+  return blendChoiceInner(mask, 0xF, 0xF0);
+}
+
+template <>
+constexpr int blendChoice<8>(const uint64_t mask) {
+  // clamp it 0 and 0xF
+  return blendChoiceInner(mask, 0x3, 0xC);
+}
+
+template <int Z>
+constexpr auto GetMask1(const uint64_t mask) {
+  return typename VecBinaryType<Z>::type{};
+}
+
+template <int Z>
+constexpr auto GetMask2(const uint64_t mask) {
+  return typename VecBinaryType<Z>::type{};
+}
+
+template <>
+constexpr auto GetMask1<1>(const uint64_t mask) {
+  constexpr uint8_t t = (int)0xFF;
+  uint8_t g0 = (mask & 1) * t;
+  uint8_t g1 = ((mask & 2) >> 1) * t;
+  uint8_t g2 = ((mask & 4) >> 2) * t;
+  uint8_t g3 = ((mask & 8) >> 3) * t;
+  uint8_t g4 = ((mask & 16) >> 4) * t;
+  uint8_t g5 = ((mask & 32) >> 5) * t;
+  uint8_t g6 = ((mask & 64) >> 6) * t;
+  uint8_t g7 = ((mask & 128) >> 7) * t;
+  uint8_t g8 = ((mask & 256) >> 8) * t;
+  uint8_t g9 = ((mask & 512) >> 9) * t;
+  uint8_t g10 = ((mask & 1024) >> 10) * t;
+  uint8_t g11 = ((mask & 2048) >> 11) * t;
+  uint8_t g12 = ((mask & 4096) >> 12) * t;
+  uint8_t g13 = ((mask & 8192) >> 13) * t;
+  uint8_t g14 = ((mask & 16384) >> 14) * t;
+  uint8_t g15 = ((mask & 32768) >> 15) * t;
+  return (typename VecBinaryType<1>::type){
+      g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15};
+}
+
+template <>
+constexpr auto GetMask2<1>(const uint64_t mask) {
+  uint64_t mask2 = (mask & 0xFFFFFFFF) >> 16;
+  return GetMask1<1>(mask2);
+}
+
+template <>
+constexpr auto GetMask1<2>(const uint64_t mask) {
+  constexpr uint16_t t = (int)0xFFFF;
+  uint16_t g0 = (mask & 1) * t;
+  uint16_t g1 = ((mask & 2) >> 1) * t;
+  uint16_t g2 = ((mask & 4) >> 2) * t;
+  uint16_t g3 = ((mask & 8) >> 3) * t;
+  uint16_t g4 = ((mask & 16) >> 4) * t;
+  uint16_t g5 = ((mask & 32) >> 5) * t;
+  uint16_t g6 = ((mask & 64) >> 6) * t;
+  uint16_t g7 = ((mask & 128) >> 7) * t;
+  return (typename VecBinaryType<2>::type){g0, g1, g2, g3, g4, g5, g6, g7};
+}
+
+template <>
+constexpr auto GetMask2<2>(const uint64_t mask) {
+  uint64_t mask2 = (mask & 0xFFFF) >> 8;
+  return GetMask1<2>(mask2);
+}
+
+template <>
+constexpr auto GetMask1<4>(const uint64_t mask) {
+  uint32_t g0 = (mask & 1) * 0xffffffff;
+  uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
+  uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
+  uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
+  return (typename VecBinaryType<4>::type){g0, g1, g2, g3};
+}
+
+template <>
+constexpr auto GetMask2<4>(const uint64_t mask) {
+  uint64_t mask2 = (mask & 0xFF) >> 4;
+  return GetMask1<4>(mask2);
+}
+
+template <>
+constexpr auto GetMask1<8>(const uint64_t mask) {
+  uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
+  uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
+  return (typename VecBinaryType<8>::type){g0, g1};
+}
+
+template <>
+constexpr auto GetMask2<8>(const uint64_t mask) {
+  uint64_t mask2 = (mask & 0xF) >> 2;
+  return GetMask1<8>(mask2);
+}
+
+template <int Z>
+constexpr int maskForComplex(uint32_t mask) {
+  return 0;
+}
+
+template <>
+constexpr int maskForComplex<8>(uint32_t mask) {
+  mask = mask & 0xF;
+  int complex_mask = 0;
+  if (mask & 1)
+    complex_mask |= 3;
+  if (mask & 2)
+    complex_mask |= (3 << 2);
+  if (mask & 4)
+    complex_mask |= (3 << 4);
+  if (mask & 8)
+    complex_mask |= (3 << 6);
+  return complex_mask;
+}
+
+template <>
+constexpr int maskForComplex<16>(uint32_t mask) {
+  mask = mask & 0x3;
+  int complex_mask = 0;
+  if (mask & 1)
+    complex_mask |= 3;
+  if (mask & 2)
+    complex_mask |= (3 << 2);
+  return complex_mask;
+}
+
+template <typename T = c10::complex<float>>
+constexpr int blend_choice() {
+  return 0xAA;
+}
+
+template <>
+constexpr int blend_choice<c10::complex<double>>() {
+  return 0x0A;
+}
+
+constexpr int64_t allbitset(int16_t x) {
+  int64_t onex = 1;
+  return (onex << x) - onex;
+}
+
+namespace { /* unnamed namespace */
+
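+// Even/odd element interleave helpers: emulated with vec_perm byte masks for
+// 32-bit lanes, and reduced to vec_mergeh/vec_mergel for 64-bit lanes.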
+ZSimdVect<float> vec_mergee(ZSimdVect<float> x, ZSimdVect<float> y) {
+  constexpr ZSimdVectBinary<uint8_t> mergee_mask{
+      0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
+  return vec_perm(x, y, mergee_mask);
+}
+
+ZSimdVect<double> vec_mergee(ZSimdVect<double> x, ZSimdVect<double> y) {
+  return vec_mergeh(x, y);
+}
+
+ZSimdVect<float> vec_mergeo(ZSimdVect<float> x, ZSimdVect<float> y) {
+  constexpr ZSimdVectBinary<uint8_t> mergeo_mask{
+      4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
+  return vec_perm(x, y, mergeo_mask);
+}
+
+ZSimdVect<double> vec_mergeo(ZSimdVect<double> x, ZSimdVect<double> y) {
+  return vec_mergel(x, y);
+}
+
+} /* unnamed namespace */
+
+//
+template <typename T>
+constexpr auto GetBpermZeroMask() {
+  return ZSimdVectBinary<uint8_t>{
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      96,
+      64,
+      32,
+      0};
+}
+
+template <>
+constexpr auto GetBpermZeroMask<double>() {
+  return ZSimdVectBinary<uint8_t>{
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      128,
+      64,
+      0};
+}
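+// The masks above pick one bit from each comparison lane (offsets 0/32/64/96
+// for 32-bit lanes, 0/64 for 64-bit lanes; index 128 selects a constant zero
+// bit). zero_mask() feeds them to vec_bperm_u128 to compress a lane-wise
+// compare result into a small integer bitmask.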
+
+constexpr auto GetSwapMaskFloat() {
+  return ZSimdVectBinary<uint8_t>{
+      4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11};
+}
+
+template <typename T>
+struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
+ public:
+  using value_type = T;
+  using vtype = ZSimdVect<T>;
+  using vmaskType = ZSimdVectBinary<T>;
+  using size_type = int;
+  // because of gcc inconsistency for int64_t we are obliged to use this, not
+  // value_type
+  using ElementType = ZSimdVectElement<T>;
+  using vinner_data = std::pair<vtype, vtype>;
+
+ private:
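+  // Storage: two native 16-byte z/Architecture vector registers, together
+  // covering the 32-byte width of the vec256 backend.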
+  vtype _vec0;
+  vtype _vec1;
+
+ public:
+  static constexpr size_type size() {
+    return VECTOR_WIDTH / sizeof(ElementType);
+  }
+  Vectorized() {}
+
+  C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {}
+  C10_ALWAYS_INLINE Vectorized(const vinner_data &v) : _vec0{v.first}, _vec1{v.second} {}
+  C10_ALWAYS_INLINE Vectorized(vtype v1, vtype v2) : _vec0{v1}, _vec1{v2} {}
+  C10_ALWAYS_INLINE Vectorized(T s)
+      : _vec0{vec_splats((ElementType)s)}, _vec1{vec_splats((ElementType)s)} {}
+
+  static Vectorized<T> C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    if (count == size()) {
+      return {
+          vec_xl(offset0, reinterpret_cast<const ElementType*>(ptr)),
+          vec_xl(offset16, reinterpret_cast<const ElementType*>(ptr))};
+    }
+
+    __at_align__ ElementType tmp_values[size()] = {};
+    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(ElementType));
+
+    return {
+        vec_xl(offset0, reinterpret_cast<const ElementType*>(tmp_values)),
+        vec_xl(offset16, reinterpret_cast<const ElementType*>(tmp_values))};
+  }
+
+  static Vectorized<T> C10_ALWAYS_INLINE
+  loadu_one_fourth(const void* ptr) {
+    // load only first 8 bytes
+    // only intended to be used with uint8_t
+    return loadu(ptr, 8 / sizeof(ElementType));
+  }
+
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      vec_xst(_vec0, offset0, reinterpret_cast<ElementType*>(ptr));
+      vec_xst(_vec1, offset16, reinterpret_cast<ElementType*>(ptr));
+    } else if (count > 0) {
+      __at_align__ ElementType tmp_values[size()];
+      vec_xst(_vec0, offset0, reinterpret_cast<ElementType*>(tmp_values));
+      vec_xst(_vec1, offset16, reinterpret_cast<ElementType*>(tmp_values));
+      std::memcpy(
+          ptr, tmp_values, std::min(count, size()) * sizeof(ElementType));
+    }
+  }
+
+  C10_ALWAYS_INLINE const vtype& vec0() const {
+    return _vec0;
+  }
+
+  C10_ALWAYS_INLINE const vtype& vec1() const {
+    return _vec1;
+  }
+
+  C10_ALWAYS_INLINE vinner_data data() const {
+    return std::make_pair<>(_vec0, _vec1);
+  }
+
+  C10_ALWAYS_INLINE operator vinner_data() const {
+    return data();
+  }
+
+  C10_ALWAYS_INLINE const vmaskType vecb0() const {
+    return (vmaskType)_vec0;
+  }
+  C10_ALWAYS_INLINE const vmaskType vecb1() const {
+    return (vmaskType)_vec1;
+  }
+
+  static Vectorized C10_ALWAYS_INLINE blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    return {
+        vec_sel(a._vec0, b._vec0, mask.vecb0()),
+        vec_sel(a._vec1, b._vec1, mask.vecb1())};
+  }
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4)
+      : _vec0{s1, s2}, _vec1{s3, s4} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4, T s5, T s6, T s7, T s8)
+      : _vec0{s1, s2, s3, s4}, _vec1{s5, s6, s7, s8} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(
+      T s1,
+      T s2,
+      T s3,
+      T s4,
+      T s5,
+      T s6,
+      T s7,
+      T s8,
+      T s9,
+      T s10,
+      T s11,
+      T s12,
+      T s13,
+      T s14,
+      T s15,
+      T s16)
+      : _vec0{s1, s2, s3, s4, s5, s6, s7, s8},
+        _vec1{s9, s10, s11, s12, s13, s14, s15, s16} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(
+      T s1,
+      T s2,
+      T s3,
+      T s4,
+      T s5,
+      T s6,
+      T s7,
+      T s8,
+      T s9,
+      T s10,
+      T s11,
+      T s12,
+      T s13,
+      T s14,
+      T s15,
+      T s16,
+      T s17,
+      T s18,
+      T s19,
+      T s20,
+      T s21,
+      T s22,
+      T s23,
+      T s24,
+      T s25,
+      T s26,
+      T s27,
+      T s28,
+      T s29,
+      T s30,
+      T s31,
+      T s32)
+      : _vec0{s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16},
+        _vec1{
+            s17,
+            s18,
+            s19,
+            s20,
+            s21,
+            s22,
+            s23,
+            s24,
+            s25,
+            s26,
+            s27,
+            s28,
+            s29,
+            s30,
+            s31,
+            s32} {}
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(base, base + step, base + 2 * step, base + 3 * step);
+  }
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step);
+  }
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step,
+        base + 8 * step,
+        base + 9 * step,
+        base + 10 * step,
+        base + 11 * step,
+        base + 12 * step,
+        base + 13 * step,
+        base + 14 * step,
+        base + 15 * step);
+  }
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + 2 * step,
+        base + 3 * step,
+        base + 4 * step,
+        base + 5 * step,
+        base + 6 * step,
+        base + 7 * step,
+        base + 8 * step,
+        base + 9 * step,
+        base + 10 * step,
+        base + 11 * step,
+        base + 12 * step,
+        base + 13 * step,
+        base + 14 * step,
+        base + 15 * step,
+        base + 16 * step,
+        base + 17 * step,
+        base + 18 * step,
+        base + 19 * step,
+        base + 20 * step,
+        base + 21 * step,
+        base + 22 * step,
+        base + 23 * step,
+        base + 24 * step,
+        base + 25 * step,
+        base + 26 * step,
+        base + 27 * step,
+        base + 28 * step,
+        base + 29 * step,
+        base + 30 * step,
+        base + 31 * step);
+  }
+
+  // blend section
+  template 
+  static std::enable_if_t(mask) == 0, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    return a;
+  }
+
+  template 
+  static std::enable_if_t(mask) == 1, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    return b;
+  }
+
+  template 
+  static std::enable_if_t(mask) == 2, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    return {b._vec0, a._vec1};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 3, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    return {a._vec0, b._vec1};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 4, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    const vmaskType mask_1st = GetMask1(mask);
+    return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 5, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    const vmaskType mask_1st = GetMask1(mask);
+    return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 6, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    const vmaskType mask_2nd = GetMask2(mask);
+    // generated masks
+    return {a._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 7, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    const vmaskType mask_2nd = GetMask2(mask);
+    // generated masks
+    return {b._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static std::enable_if_t(mask) == 8, Vectorized>
+      C10_ALWAYS_INLINE blend(const Vectorized& a, const Vectorized& b) {
+    const vmaskType mask_1st = GetMask1(mask);
+    const vmaskType mask_2nd = GetMask2(mask);
+    return {
+        (vtype)vec_sel(a._vec0, b._vec0, mask_1st),
+        (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
+  }
+
+  template 
+  static inline std::enable_if_t<(Z >= C), Vectorized> set_inner(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count) {
+    return b;
+  }
+
+  template 
+  static inline std::enable_if_t<(Z < C), Vectorized> set_inner(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count) {
+    if (count == Z)
+      return blend(a, b);
+    else
+      return set_inner(a, b, count);
+  }
+
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    if (count == 0)
+      return a;
+    return set_inner<1, size()>(a, b, count);
+  }
+
+  const ElementType& operator[](int idx) const = delete;
+  ElementType& operator[](int idx) = delete;
+
+  Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& other) const {
+    return Vectorized{_vec0 + other._vec0, _vec1 + other._vec1};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator-(const Vectorized& other) const {
+    return Vectorized{_vec0 - other._vec0, _vec1 - other._vec1};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator*(const Vectorized& other) const {
+    return Vectorized{_vec0 * other._vec0, _vec1 * other._vec1};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator/(const Vectorized& other) const {
+    return Vectorized{_vec0 / other._vec0, _vec1 / other._vec1};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator&(const Vectorized& other) const {
+    return Vectorized{
+        (vtype)(vecb0() & other.vecb0()), (vtype)(vecb1() & other.vecb1())};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator|(const Vectorized& other) const {
+    return Vectorized{
+        (vtype)(vecb0() | other.vecb0()), (vtype)(vecb1() | other.vecb1())};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator^(const Vectorized& other) const {
+    return Vectorized{
+        (vtype)(vecb0() ^ other.vecb0()), (vtype)(vecb1() ^ other.vecb1())};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator<<(const Vectorized &other) const {
+    constexpr ElementType max_shift = sizeof(ElementType) * CHAR_BIT;
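+    // Shift amounts outside [0, max_shift) would be undefined behavior for
+    // the scalar << below, so they are handled explicitly and yield 0.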
+
+    ElementType a_array[Vectorized::size()];
+    ElementType b_array[Vectorized::size()];
+    ElementType c_array[Vectorized::size()];
+
+    store(a_array);
+    other.store(b_array);
+
+    for (int i = 0; i != Vectorized<T>::size(); i++) {
+      T shift = b_array[i];
+      if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
+        c_array[i] = 0;
+      } else {
+        c_array[i] = static_cast<std::make_unsigned_t<T>>(a_array[i]) << shift;
+      }
+    }
+
+    return loadu(c_array);
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator>>(const Vectorized &other) const {
+    // right shift value to retain sign bit for signed and no bits for unsigned
+    constexpr ElementType max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v;
+
+    ElementType a_array[Vectorized::size()];
+    ElementType b_array[Vectorized::size()];
+    ElementType c_array[Vectorized::size()];
+
+    store(a_array);
+    other.store(b_array);
+
+    for (int i = 0; i != Vectorized<T>::size(); i++) {
+      T shift = b_array[i];
+      if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
+        c_array[i] = a_array[i] >> max_shift;
+      } else {
+        c_array[i] = a_array[i] >> shift;
+      }
+    }
+
+    return loadu(c_array);
+  }
+
+  Vectorized _not() const {
+    return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator==(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator!=(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)}
+        ._not();
+  }
+  Vectorized C10_ALWAYS_INLINE operator>(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmpgt(_vec0, other._vec0), vec_cmpgt(_vec1, other._vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE operator>=(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmpge(_vec0, other._vec0), vec_cmpge(_vec1, other._vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator<(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmplt(_vec0, other._vec0), vec_cmplt(_vec1, other._vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator<=(const Vectorized& other) const {
+    return Vectorized{
+        vec_cmple(_vec0, other._vec0), vec_cmple(_vec1, other._vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const {
+    return (*this == other) & Vectorized((T)1.0);
+  }
+  Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const {
+    return (*this != other) & Vectorized((T)1.0);
+  }
+  Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const {
+    return (*this > other) & Vectorized((T)1.0);
+  }
+  Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const {
+    return (*this >= other) & Vectorized((T)1.0);
+  }
+  Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const {
+    return (*this < other) & Vectorized((T)1.0);
+  }
+  Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const {
+    return (*this <= other) & Vectorized((T)1.0);
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {vec_abs(_vec0), vec_abs(_vec1)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized C10_ALWAYS_INLINE abs() const {
+    return {_vec0, _vec1};
+  }
+
+  Vectorized C10_ALWAYS_INLINE neg() const {
+    return {-_vec0, -_vec1};
+  }
+
+  Vectorized isnan() const {
+    auto x = *this;
+    auto ret = (x == x);
+    return ret._not();
+  }
+
+  bool has_inf_nan() const {
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
+        return true;
+      }
+    }
+    for (const auto i : c10::irange(size()/2)) {
+      if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized angle() const {
+    auto tmp = blendv(
+        Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+    return blendv(tmp, *this, isnan());
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized angle() const {
+    return blendv(
+        Vectorized(0), Vectorized(c10::pi), *this < Vectorized(0));
+  }
+
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return Vectorized{0};
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  int zero_mask() const {
+    auto cmp = (*this == Vectorized(0));
+    constexpr auto mask_zero_bits = GetBpermZeroMask();
+    ZSimdVectBinary result0 =
+        vec_bperm_u128((ZSimdVectBinary)cmp.vecb0(), mask_zero_bits);
+    ZSimdVectBinary result1 =
+        vec_bperm_u128((ZSimdVectBinary)cmp.vecb1(), mask_zero_bits);
+    return (result0[0] | (result1[0] << (size() / 2)));
+  }
+
+  Vectorized C10_ALWAYS_INLINE floor() const {
+    return {vec_floor(_vec0), vec_floor(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE ceil() const {
+    return {vec_ceil(_vec0), vec_ceil(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE round() const {
+    return {vec_round(_vec0), vec_round(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE rint() const {
+    return {vec_rint(_vec0), vec_rint(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE trunc() const {
+    return {vec_trunc(_vec0), vec_trunc(_vec1)};
+  }
+
+  Vectorized C10_ALWAYS_INLINE frac() const {
+    return *this - trunc();
+  }
+
+  Vectorized C10_ALWAYS_INLINE sqrt() const {
+    return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
+  }
+  Vectorized C10_ALWAYS_INLINE reciprocal() const {
+    return Vectorized((T)1) / (*this);
+  }
+  Vectorized C10_ALWAYS_INLINE rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapOrdinary(float (*const f)(float)) const {
+    float a00 = f(_vec0[0]);
+    float a01 = f(_vec0[1]);
+    float a02 = f(_vec0[2]);
+    float a03 = f(_vec0[3]);
+    float a10 = f(_vec1[0]);
+    float a11 = f(_vec1[1]);
+    float a12 = f(_vec1[2]);
+    float a13 = f(_vec1[3]);
+    return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapOrdinary(double (*const f)(double)) const {
+    return Vectorized(f(_vec0[0]), f(_vec0[1]), f(_vec1[0]), f(_vec1[1]));
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapOrdinary(
+      float (*const f)(float, float),
+      const Vectorized& b) const {
+    float a00 = f(_vec0[0], b._vec0[0]);
+    float a01 = f(_vec0[1], b._vec0[1]);
+    float a02 = f(_vec0[2], b._vec0[2]);
+    float a03 = f(_vec0[3], b._vec0[3]);
+    float a10 = f(_vec1[0], b._vec1[0]);
+    float a11 = f(_vec1[1], b._vec1[1]);
+    float a12 = f(_vec1[2], b._vec1[2]);
+    float a13 = f(_vec1[3], b._vec1[3]);
+    return Vectorized{a00, a01, a02, a03, a10, a11, a12, a13};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapOrdinary(
+      double (*const f)(double, double),
+      const Vectorized& b) const {
+    return Vectorized(
+        f(_vec0[0], b._vec0[0]),
+        f(_vec0[1], b._vec0[1]),
+        f(_vec1[0], b._vec1[0]),
+        f(_vec1[1], b._vec1[1]));
+  }
+
+  template <
+      typename FloatOp,
+      typename DoubleOp,
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapSleef(FloatOp f, DoubleOp d) const {
+    vtype a0 = f(_vec0);
+    vtype a1 = f(_vec1);
+    return Vectorized{a0, a1};
+  }
+
+  template <
+      typename FloatOp,
+      typename DoubleOp,
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapSleef(FloatOp f, DoubleOp d) const {
+    return Vectorized(d(_vec0), d(_vec1));
+  }
+
+  template <
+      typename FloatOp,
+      typename DoubleOp,
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b)
+      const {
+    vtype a0 = f(_vec0, b._vec0);
+    vtype a1 = f(_vec1, b._vec1);
+    return Vectorized{a0, a1};
+  }
+
+  template <
+      typename FloatOp,
+      typename DoubleOp,
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  inline Vectorized mapSleef(FloatOp f, DoubleOp d, const Vectorized& b)
+      const {
+    return Vectorized(d(_vec0, b._vec0), d(_vec1, b._vec1));
+  }
+
+  Vectorized acos() const {
+    return mapSleef(Sleef_acosf4_u10, Sleef_acosd2_u10);
+  }
+  Vectorized asin() const {
+    return mapSleef(Sleef_asinf4_u10, Sleef_asind2_u10);
+  }
+  Vectorized atan() const {
+    return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10);
+  }
+  Vectorized atanh() const {
+    return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10);
+  }
+
+  Vectorized erf() const {
+    return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10);
+  }
+  Vectorized erfc() const {
+    return mapSleef(Sleef_erfcf4_u15, Sleef_erfcd2_u15);
+  }
+
+  Vectorized exp() const {
+    return mapSleef(Sleef_expf4_u10, Sleef_expd2_u10);
+  }
+  Vectorized exp2() const {
+    return mapSleef(Sleef_exp2f4_u10, Sleef_exp2d2_u10);
+  }
+  Vectorized expm1() const {
+    return mapSleef(Sleef_expm1f4_u10, Sleef_expm1d2_u10);
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+
+  Vectorized log() const {
+    return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10);
+  }
+  Vectorized log2() const {
+    return mapSleef(Sleef_log2f4_u10, Sleef_log2d2_u10);
+  }
+  Vectorized log10() const {
+    return mapSleef(Sleef_log10f4_u10, Sleef_log10d2_u10);
+  }
+  Vectorized log1p() const {
+    return mapSleef(Sleef_log1pf4_u10, Sleef_log1pd2_u10);
+  }
+
+  Vectorized sin() const {
+#ifndef SLEEF_MEMORY_WORKAROUND
+    return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10);
+#else
+    return mapOrdinary(std::sin);
+#endif
+  }
+  Vectorized sinh() const {
+    return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10);
+  }
+  Vectorized cos() const {
+#ifndef SLEEF_MEMORY_WORKAROUND
+    return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10);
+#else
+    return mapOrdinary(std::cos);
+#endif
+  }
+  Vectorized cosh() const {
+    return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10);
+  }
+
+  Vectorized tan() const {
+#ifndef SLEEF_MEMORY_WORKAROUND
+    return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10);
+#else
+    return mapOrdinary(std::tan);
+#endif
+  }
+  Vectorized tanh() const {
+    return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10);
+  }
+
+  Vectorized lgamma() const {
+    return mapSleef(Sleef_lgammaf4_u10, Sleef_lgammad2_u10);
+  }
+
+  Vectorized atan2(const Vectorized& b) const {
+    return mapSleef(Sleef_atan2f4_u10, Sleef_atan2d2_u10, b);
+  }
+  Vectorized copysign(const Vectorized& sign) const {
+    return mapSleef(Sleef_copysignf4, Sleef_copysignd2, sign);
+  }
+  Vectorized fmod(const Vectorized& q) const {
+    return mapSleef(Sleef_fmodf4, Sleef_fmodd2, q);
+  }
+
+  Vectorized hypot(const Vectorized& b) const {
+    return mapSleef(Sleef_hypotf4_u05, Sleef_hypotd2_u05, b);
+  }
+
+  Vectorized pow(const Vectorized& b) const {
+    return mapSleef(Sleef_powf4_u10, Sleef_powd2_u10, b);
+  }
+
+  Vectorized nextafter(const Vectorized& b) const {
+    return mapSleef(Sleef_nextafterf4, Sleef_nextafterd2, b);
+  }
+
+  Vectorized erfinv() const {
+    return mapOrdinary(calc_erfinv);
+  }
+
+  Vectorized digamma() const {
+    return mapOrdinary(calc_digamma);
+  }
+
+  Vectorized igamma(const Vectorized& x) const {
+    return mapOrdinary(calc_igamma, x);
+  }
+
+  Vectorized igammac(const Vectorized& x) const {
+    return mapOrdinary(calc_igammac, x);
+  }
+
+  Vectorized i0() const {
+    return mapOrdinary(calc_i0);
+  }
+
+  Vectorized i0e() const {
+    return mapOrdinary(calc_i0e);
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized minimum(const Vectorized& other) const {
+    return {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)};
+  }
+
+  /* Propagates NaN if either input is a NaN. */
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized minimum(const Vectorized& other) const {
+    Vectorized tmp = {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)};
+    tmp = blendv(tmp, *this, isnan());
+    return blendv(tmp, other, other.isnan());
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized maximum(const Vectorized& other) const {
+    return {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)};
+  }
+
+  /* Propagates NaN if either input is a NaN. */
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized maximum(const Vectorized& other) const {
+    Vectorized tmp = {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)};
+    tmp = blendv(tmp, *this, isnan());
+    return blendv(tmp, other, other.isnan());
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized clamp_min(const Vectorized& min) const {
+    return {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)};
+  }
+
+  /* Keeps NaN if actual value is NaN */
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized clamp_min(const Vectorized& min) const {
+    Vectorized tmp = {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)};
+    return blendv(tmp, *this, isnan());
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized clamp_max(const Vectorized& max) const {
+    return {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)};
+  }
+
+  /* Keeps NaN if actual value is NaN */
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized clamp_max(const Vectorized& max) const {
+    Vectorized tmp = {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)};
+    return blendv(tmp, *this, isnan());
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized swapped() const {
+    auto swap_mask = GetSwapMaskFloat();
+    vtype v0 = vec_perm(_vec0, _vec0, swap_mask);
+    vtype v1 = vec_perm(_vec1, _vec1, swap_mask);
+    return {v0, v1};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized swapped() const {
+    vtype v0 = vec_permi(_vec0, _vec0, 2);
+    vtype v1 = vec_permi(_vec1, _vec1, 2);
+    return {v0, v1};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  static Vectorized mergee(Vectorized& first, Vectorized& second) {
+    return {
+        vec_mergee(first._vec0, second._vec0),
+        vec_mergee(first._vec1, second._vec1)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  static Vectorized mergeo(Vectorized& first, Vectorized& second) {
+    return {
+        vec_mergeo(first._vec0, second._vec0),
+        vec_mergeo(first._vec1, second._vec1)};
+  }
+
+  static Vectorized horizontal_add_perm(
+      Vectorized& first,
+      Vectorized& second) {
+    // we will simulate it differently with 6 instructions total
+    // lets permute second so that we can add it getting horizontal sums
+    auto first_perm = first.swapped(); // 2perm
+    auto second_perm = second.swapped(); // 2perm
+    // summ
+    auto first_ret = first + first_perm; // 2add
+    auto second_ret = second + second_perm; // 2 add
+    // now lets choose evens
+    return mergee(first_ret, second_ret); // 2 mergee's
+  }
+
+  static Vectorized horizontal_sub_perm(
+      Vectorized& first,
+      Vectorized& second) {
+    // we will simulate it differently with 6 instructions total
+    // lets permute second so that we can add it getting horizontal sums
+    auto first_perm = first.swapped(); // 2perm
+    auto second_perm = second.swapped(); // 2perm
+    // summ
+    auto first_ret = first - first_perm; // 2sub
+    auto second_ret = second - second_perm; // 2 sub
+    // now lets choose evens
+    return mergee(first_ret, second_ret); // 2 mergee's
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized mergee() const {
+    return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized mergeo() const {
+    return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized to_vec_float_helper() const {
+    int32_t values[8] = {
+      _vec0[0],
+      _vec0[1],
+      _vec0[2],
+      _vec0[3],
+      _vec0[4],
+      _vec0[5],
+      _vec0[6],
+      _vec0[7],
+    };
+
+    return Vectorized{
+      values[0], values[1], values[2], values[3],
+      values[4], values[5], values[6], values[7]
+    };
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::value, int> = 0>
+  Vectorized to_vec_uint8_helper() const {
+    // helper function for float to uint8_t conversion
+    uint8_t values[8] = {
+      static_cast(_vec0[0]),
+      static_cast(_vec0[1]),
+      static_cast(_vec0[2]),
+      static_cast(_vec0[3]),
+      static_cast(_vec1[0]),
+      static_cast(_vec1[1]),
+      static_cast(_vec1[2]),
+      static_cast(_vec1[3]),
+    };
+
+    return Vectorized{
+      values[0], values[1], values[2], values[3],
+      values[4], values[5], values[6], values[7],
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+      0, 0, 0, 0,
+    };
+  }
+};
+
+template <>
+inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
+  return a._not();
+}
+
+template <>
+inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
+  return a._not();
+}
+
+template <>
+inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
+  return a._not();
+}
+
+template <>
+inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
+  return a._not();
+}
+
+template <>
+inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
+  return a._not();
+}
+
+#define DEFINE_MAXMIN_FUNCS(operand_type)                                     \
+  template <>                                                                 \
+  Vectorized<operand_type> inline maximum(                                    \
+      const Vectorized<operand_type>& a, const Vectorized<operand_type>& b) { \
+    return a.maximum(b);                                                      \
+  }                                                                           \
+  template <>                                                                 \
+  Vectorized<operand_type> inline minimum(                                    \
+      const Vectorized<operand_type>& a, const Vectorized<operand_type>& b) { \
+    return a.minimum(b);                                                      \
+  }
+
+#define DEFINE_CLAMP_MAXMIN_FUNCS(typex)                          \
+  DEFINE_MAXMIN_FUNCS(typex)                                      \
+  template <>                                                     \
+  Vectorized<typex> C10_ALWAYS_INLINE clamp_min(                  \
+      const Vectorized<typex>& a, const Vectorized<typex>& min) { \
+    return a.clamp_min(min);                                      \
+  }                                                               \
+  template <>                                                     \
+  Vectorized<typex> C10_ALWAYS_INLINE clamp_max(                  \
+      const Vectorized<typex>& a, const Vectorized<typex>& max) { \
+    return a.clamp_max(max);                                      \
+  }                                                               \
+  template <>                                                     \
+  Vectorized<typex> C10_ALWAYS_INLINE clamp(                      \
+      const Vectorized<typex>& a,                                 \
+      const Vectorized<typex>& min,                               \
+      const Vectorized<typex>& max) {                             \
+    return clamp_max(clamp_min(a, min), max);                     \
+  }
+
+DEFINE_CLAMP_MAXMIN_FUNCS(int8_t)
+DEFINE_CLAMP_MAXMIN_FUNCS(uint8_t)
+DEFINE_CLAMP_MAXMIN_FUNCS(int16_t)
+DEFINE_CLAMP_MAXMIN_FUNCS(int32_t)
+DEFINE_CLAMP_MAXMIN_FUNCS(int64_t)
+DEFINE_CLAMP_MAXMIN_FUNCS(float)
+DEFINE_CLAMP_MAXMIN_FUNCS(double)
+
+namespace { /* unnamed namespace */
+
+#if !defined(vec_float) || __ARCH__ < 13
+#warning \
+    "float->int and int->float conversion is simulated. compile for z15 for improved performance"
+inline ZSimdVect<float> vec_int_flt(const ZSimdVect<int32_t> x) {
+  return ZSimdVect<float>{float(x[0]), float(x[1]), float(x[2]), float(x[3])};
+}
+inline ZSimdVect<int32_t> vec_flt_int(const ZSimdVect<float> x) {
+  return ZSimdVect<int32_t>{int(x[0]), int(x[1]), int(x[2]), int(x[3])};
+}
+#else
+#define vec_int_flt vec_float
+#define vec_flt_int vec_signed
+#endif
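+// The int32<->float conversions below go through the wrappers above, so they
+// also work (element-wise) on pre-arch13 targets; the 64-bit conversions use
+// vec_double/vec_signed directly.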
+
+Vectorized<float> convert_to_float(const Vectorized<int32_t>& x) {
+  return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())};
+}
+
+Vectorized<int32_t> convert_to_int(const Vectorized<float>& x) {
+  return {vec_flt_int(x.vec0()), vec_flt_int(x.vec1())};
+}
+
+Vectorized<double> convert_to_float(const Vectorized<int64_t>& x) {
+  return {vec_double(x.vec0()), vec_double(x.vec1())};
+}
+
+Vectorized<int64_t> convert_to_int(const Vectorized<double>& x) {
+  return {vec_signed(x.vec0()), vec_signed(x.vec1())};
+}
+
+} /* unnamed namespace */
+
+template <typename T, typename V>
+Vectorized<T> cast_zvector(const Vectorized<V>& x) {
+  using cast_type = typename Vectorized<T>::vtype;
+  return Vectorized<T>{(cast_type)x.vec0(), (cast_type)x.vec1()};
+}
+
+template <>
+Vectorized<float> C10_ALWAYS_INLINE fmadd(
+    const Vectorized<float>& a,
+    const Vectorized<float>& b,
+    const Vectorized<float>& c) {
+  return Vectorized<float>{
+      __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()),
+      __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())};
+}
+template <>
+Vectorized<double> C10_ALWAYS_INLINE fmadd(
+    const Vectorized<double>& a,
+    const Vectorized<double>& b,
+    const Vectorized<double>& c) {
+  return Vectorized<double>{
+      __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()),
+      __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())};
+}
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+template <>
+Vectorized C10_ALWAYS_INLINE fmadd(
+    const Vectorized& a,
+    const Vectorized& b,
+    const Vectorized& c) {
+  return Vectorized{
+      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
+}
+
+template <>
+Vectorized<int32_t> C10_ALWAYS_INLINE
+convert_to_int_of_same_size<float>(const Vectorized<float>& src) {
+  return convert_to_int(src);
+}
+
+template <>
+Vectorized<int64_t> C10_ALWAYS_INLINE
+convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
+  return convert_to_int(src);
+}
+
+template <>
+inline void convert(const int32_t* src, float* dst, int64_t n) {
+  // int32_t and float have same size
+  int64_t i;
+  for (i = 0; i <= (n - Vectorized<float>::size());
+       i += Vectorized<float>::size()) {
+    const int32_t* src_a = src + i;
+    float* dst_a = dst + i;
+    auto input_vec = Vectorized<int32_t>::loadu(src_a);
+    auto output_vec = convert_to_float(input_vec);
+    output_vec.store(dst_a);
+  }
+
+  for (; i < n; i++) {
+    dst[i] = static_cast(src[i]);
+  }
+}
+
+template <>
+inline void convert(const int64_t* src, double* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i <= (n - Vectorized<double>::size());
+       i += Vectorized<double>::size()) {
+    const int64_t* src_a = src + i;
+    double* dst_a = dst + i;
+    auto input_vec = Vectorized<int64_t>::loadu(src_a);
+    auto output_vec = convert_to_float(input_vec);
+    output_vec.store(dst_a);
+  }
+  for (; i < n; i++) {
+    dst[i] = static_cast(src[i]);
+  }
+}
+
+#define DEFINE_REINTERPRET_CAST_FUNCS(Fst, Cst)     \
+  template <>                                       \
+  C10_ALWAYS_INLINE Vectorized<Cst> cast<Cst, Fst>( \
+      const Vectorized<Fst>& src) {                 \
+    return cast_zvector<Cst, Fst>(src);             \
+  }
+
+#define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(Fst) \
+  DEFINE_REINTERPRET_CAST_FUNCS(Fst, double)      \
+  DEFINE_REINTERPRET_CAST_FUNCS(Fst, float)       \
+  DEFINE_REINTERPRET_CAST_FUNCS(Fst, int64_t)     \
+  DEFINE_REINTERPRET_CAST_FUNCS(Fst, int32_t)     \
+  DEFINE_REINTERPRET_CAST_FUNCS(Fst, int16_t)
+
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
+DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
+
+#undef DEFINE_REINTERPRET_CAST_FUNCS
+
+template <typename T>
+struct unpack_type {
+  using type = T;
+};
+template <>
+struct unpack_type<int8_t> {
+  using type = int16_t;
+};
+template <>
+struct unpack_type<uint8_t> {
+  using type = int16_t;
+};
+template <>
+struct unpack_type<int16_t> {
+  using type = int32_t;
+};
+
+template <typename T>
+struct pack_type {
+  using type = T;
+};
+template <>
+struct pack_type<int16_t> {
+  using type = int8_t;
+};
+template <>
+struct pack_type<int32_t> {
+  using type = int16_t;
+};
+
+namespace { /* unnamed namespace */
+
+template ::type>
+std::pair, Vectorized> unpack(const Vectorized& x) {
+  auto vec0 = vec_unpackh(x.vec0());
+  auto vec1 = vec_unpackl(x.vec0());
+  auto vec2 = vec_unpackh(x.vec1());
+  auto vec3 = vec_unpackl(x.vec1());
+  return {Vectorized{vec0, vec1}, Vectorized{vec2, vec3}};
+}
+
+template <>
+std::pair, Vectorized> unpack(
+    const Vectorized& x) {
+  using typeX = typename Vectorized::vtype;
+  typeX vec0 = vec_unpackh(x.vec0());
+  typeX vec1 = vec_unpackl(x.vec0());
+  typeX vec2 = vec_unpackh(x.vec1());
+  typeX vec3 = vec_unpackl(x.vec1());
+  // auto mask = Vectorized(0xFF);
+  // vec0 = vec0 & mask;
+  // vec1 = vec1 & mask;
+  // vec2 = vec2 & mask;
+  // vec3 = vec3 & mask;
+  return {
+      cast_zvector(Vectorized{vec0, vec1}),
+      cast_zvector(Vectorized{vec2, vec3})};
+}
+
+template ::type>
+Vectorized pack(const Vectorized& first, const Vectorized& second) {
+  auto vec0 = vec_packs(first.vec0(), first.vec1());
+  auto vec1 = vec_packs(second.vec0(), second.vec1());
+  return Vectorized{vec0, vec1};
+}
+
+template <>
+Vectorized pack(
+    const Vectorized& first,
+    const Vectorized& second) {
+  auto vec0 = vec_packsu(first.vec0(), first.vec1());
+  auto vec1 = vec_packsu(second.vec0(), second.vec1());
+  return Vectorized{vec0, vec1};
+}
+
+} /* unnamed namespace */
+
+//////////////////////////////////QUANT///////////////////////////////////////////
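+// Quantized vectors wrap a single integer Vectorized (vinner_type) holding
+// the raw quantized values. dequantize() converts them to float (widening the
+// 8-bit types to int32 via unpack() first), so one quantized vector expands
+// to float_num_vecs() float vectors: 1 for qint32, 4 for the 8-bit types.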
+template <typename T>
+struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
+ public:
+  using value_type = typename T::underlying;
+  using vtype = ZSimdVect<value_type>;
+  using vmaskType = ZSimdVectBinary<value_type>;
+  using vinner_type = Vectorized<value_type>;
+  using size_type = int;
+
+  static constexpr size_type size() {
+    return VECTOR_WIDTH / sizeof(value_type);
+  }
+
+  static constexpr size_t float_num_vecs() {
+    return size() / Vectorized<float>::size();
+  }
+  static constexpr int int_num_vecs() {
+    return float_num_vecs();
+  }
+  using float_vec_return_type = std::array<Vectorized<float>, float_num_vecs()>;
+  using int_vec_return_type =
+      std::array<Vectorized<c10::qint32>, int_num_vecs()>;
+
+ private:
+  vinner_type _vec;
+
+ public:
+  Vectorized() {}
+
+  explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}
+  Vectorized(const T& val) : _vec(val.val_) {}
+
+  C10_ALWAYS_INLINE const vinner_type& vec() const {
+    return _vec;
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    return Vectorized{vinner_type::loadu(ptr, count)};
+  }
+
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    _vec.store(ptr, count);
+  }
+
+  Vectorized relu(Vectorized zero_point) const {
+    return Vectorized{_vec.maximum(zero_point._vec)};
+  }
+
+  Vectorized relu6(Vectorized zero_point, Vectorized q_six) const {
+    auto ret_max = _vec.maximum(zero_point._vec);
+    auto ret_min = ret_max.minimum(q_six._vec);
+    return Vectorized{ret_min};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 1, int> = 0>
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    return {*this - b};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 1, int> = 0>
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point,
+      Vectorized scale_zp_premul) const {
+    auto float_val = convert_to_float(_vec);
+    return {fmadd(scale, float_val, scale_zp_premul)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 1, int> = 0>
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    auto float_val = convert_to_float(_vec);
+    return {(float_val - zero_point) * scale};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 1, int> = 0>
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    Vectorized vecf = rhs[0];
+    vecf = vecf * Vectorized(inverse_scale);
+    vecf = vecf.rint() + Vectorized((float)(zero_point));
+    auto veci = convert_to_int(vecf);
+
+    return Vectorized{veci};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::int_num_vecs() == 1, int> = 0>
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    Vectorized vi = inp[0];
+    auto vecf = convert_to_float(vi.vec());
+    vecf = vecf * Vectorized(multiplier);
+    vecf = vecf.rint();
+    auto veci = convert_to_int(vecf) + Vectorized(zero_point);
+
+    return Vectorized{veci};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::int_num_vecs() == 4, int> = 0>
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    auto ret16 = unpack(_vec);
+    auto ret16B = unpack(b.vec());
+    auto ret32_0 = unpack(ret16.first);
+    auto ret32_1 = unpack(ret16.second);
+    auto ret32B_0 = unpack(ret16B.first);
+    auto ret32B_1 = unpack(ret16B.second);
+
+    return {
+        Vectorized(ret32_0.first - ret32B_0.first),
+        Vectorized(ret32_0.second - ret32B_0.second),
+        Vectorized(ret32_1.first - ret32B_1.first),
+        Vectorized(ret32_1.second - ret32B_1.second)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 4, int> = 0>
+  float_vec_return_type C10_ALWAYS_INLINE dequantize(
+      Vectorized scale,
+      Vectorized zero_point,
+      Vectorized scale_zp_premul) const {
+    // unpacking unsigned as signed
+    auto ret16 = unpack(_vec);
+    auto ret32_0 = unpack(ret16.first);
+    auto ret32_1 = unpack(ret16.second);
+
+    auto vecf_0 = convert_to_float(ret32_0.first);
+    auto vecf_1 = convert_to_float(ret32_0.second);
+
+    auto vecf_2 = convert_to_float(ret32_1.first);
+    auto vecf_3 = convert_to_float(ret32_1.second);
+    return {
+        fmadd(scale, vecf_0, scale_zp_premul),
+        fmadd(scale, vecf_1, scale_zp_premul),
+        fmadd(scale, vecf_2, scale_zp_premul),
+        fmadd(scale, vecf_3, scale_zp_premul)};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 4, int> = 0>
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    // unpacking unsigned as signed
+    auto ret16 = unpack(_vec);
+    auto ret32_0 = unpack(ret16.first);
+    auto ret32_1 = unpack(ret16.second);
+
+    auto vecf_0 = convert_to_float(ret32_0.first);
+    auto vecf_1 = convert_to_float(ret32_0.second);
+
+    auto vecf_2 = convert_to_float(ret32_1.first);
+    auto vecf_3 = convert_to_float(ret32_1.second);
+
+    return {
+        (vecf_0 - zero_point) * scale,
+        (vecf_1 - zero_point) * scale,
+        (vecf_2 - zero_point) * scale,
+        (vecf_3 - zero_point) * scale };
+  }
+
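+  // Quantization of four float vectors back to the 8-bit layout: scale by
+  // inverse_scale, round, add the zero point, convert to int32, then
+  // saturating-pack 32->16->8 bits.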
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 4, int> = 0>
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    auto vec_inverse = Vectorized(inverse_scale);
+    auto vec_zero_point = Vectorized((float)zero_point);
+
+    auto vecf0 = rhs[0];
+    auto vecf2 = rhs[1];
+    auto vecf4 = rhs[2];
+    auto vecf6 = rhs[3];
+
+    vecf0 = vecf0 * vec_inverse;
+    vecf2 = vecf2 * vec_inverse;
+    vecf4 = vecf4 * vec_inverse;
+    vecf6 = vecf6 * vec_inverse;
+
+    vecf0 = vecf0.rint() + vec_zero_point;
+    vecf2 = vecf2.rint() + vec_zero_point;
+    vecf4 = vecf4.rint() + vec_zero_point;
+    vecf6 = vecf6.rint() + vec_zero_point;
+
+    auto veci0 = convert_to_int(vecf0);
+    auto veci2 = convert_to_int(vecf2);
+    auto veci4 = convert_to_int(vecf4);
+    auto veci6 = convert_to_int(vecf6);
+
+    auto vecshi0 = pack(veci0, veci2);
+    auto vecshi2 = pack(veci4, veci6);
+    auto ret = pack(vecshi0, vecshi2);
+
+    return Vectorized{ret};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t::int_num_vecs() == 4, int> = 0>
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    Vectorized vec_multiplier = Vectorized(multiplier);
+    Vectorized vec_zero_point = Vectorized(zero_point);
+
+    Vectorized vi0 = inp[0];
+    Vectorized vi1 = inp[1];
+    Vectorized vi2 = inp[2];
+    Vectorized vi3 = inp[3];
+
+    auto vecf0 = convert_to_float(vi0.vec());
+    auto vecf2 = convert_to_float(vi1.vec());
+
+    auto vecf4 = convert_to_float(vi2.vec());
+    auto vecf6 = convert_to_float(vi3.vec());
+
+    vecf0 = vecf0 * vec_multiplier;
+    vecf2 = vecf2 * vec_multiplier;
+
+    vecf4 = vecf4 * vec_multiplier;
+    vecf6 = vecf6 * vec_multiplier;
+
+    vecf0 = vecf0.rint();
+    vecf2 = vecf2.rint();
+    vecf4 = vecf4.rint();
+    vecf6 = vecf6.rint();
+
+    auto veci0 = convert_to_int(vecf0);
+    auto veci2 = convert_to_int(vecf2);
+    auto veci4 = convert_to_int(vecf4);
+    auto veci6 = convert_to_int(vecf6);
+
+    veci0 = veci0 + vec_zero_point;
+    veci2 = veci2 + vec_zero_point;
+
+    veci4 = veci4 + vec_zero_point;
+    veci6 = veci6 + vec_zero_point;
+
+    auto vecshi0 = pack(veci0, veci2);
+    auto vecshi2 = pack(veci4, veci6);
+
+    auto ret = pack(vecshi0, vecshi2);
+
+    return Vectorized{ret};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& other) const {
+    return Vectorized{_vec + other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator-(const Vectorized& other) const {
+    return Vectorized{_vec - other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator*(const Vectorized& other) const {
+    return Vectorized{_vec * other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator/(const Vectorized& other) const {
+    return Vectorized{_vec / other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator&(const Vectorized& other) const {
+    return Vectorized{_vec & other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator|(const Vectorized& other) const {
+    return Vectorized{_vec | other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator^(const Vectorized& other) const {
+    return Vectorized{_vec ^ other._vec};
+  }
+  Vectorized C10_ALWAYS_INLINE operator==(const Vectorized& other) const {
+    return Vectorized{_vec == other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator!=(const Vectorized& other) const {
+    return Vectorized{_vec != other._vec};
+  }
+  Vectorized C10_ALWAYS_INLINE operator>(const Vectorized& other) const {
+    return Vectorized{_vec > other._vec};
+  }
+  Vectorized C10_ALWAYS_INLINE operator>=(const Vectorized& other) const {
+    return Vectorized{_vec >= other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator<(const Vectorized& other) const {
+    return Vectorized{_vec < other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator<=(const Vectorized& other) const {
+    return Vectorized{_vec <= other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const {
+    return Vectorized{_vec.eq(other._vec)};
+  }
+  Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const {
+    return Vectorized{_vec.ne(other._vec)};
+  }
+  Vectorized C10_ALWAYS_INLINE gt(const Vectorized& other) const {
+    return Vectorized{_vec.gt(other._vec)};
+  }
+  Vectorized C10_ALWAYS_INLINE ge(const Vectorized& other) const {
+    return Vectorized{_vec.ge(other._vec)};
+  }
+  Vectorized C10_ALWAYS_INLINE lt(const Vectorized& other) const {
+    return Vectorized{_vec.lt(other._vec)};
+  }
+  Vectorized C10_ALWAYS_INLINE le(const Vectorized& other) const {
+    return Vectorized{_vec.le(other._vec)};
+  }
+
+  Vectorized clamp_min(const Vectorized& min) const {
+    return Vectorized{_vec.clamp_min(min._vec)};
+  }
+
+  Vectorized clamp_max(const Vectorized& max) const {
+    return Vectorized{_vec.clamp_max(max._vec)};
+  }
+
+  Vectorized minimum(const Vectorized& other) const {
+    return Vectorized{_vec.minimum(other._vec)};
+  }
+
+  Vectorized maximum(const Vectorized& other) const {
+    return Vectorized{_vec.maximum(other._vec)};
+  }
+};
+
+DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8)
+DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8)
+DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32)
+
+template 
+constexpr auto real_mask() {
+  return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFF, 0, 0xFFFFFFFF, 0};
+}
+
+template <>
+constexpr auto real_mask() {
+  return (ZSimdVect)ZSimdVectBinary{0xFFFFFFFFFFFFFFFF, 0};
+}
+
+template 
+constexpr auto image_mask() {
+  return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFF, 0, 0xFFFFFFFF};
+}
+
+template <>
+constexpr auto image_mask() {
+  return (ZSimdVect)ZSimdVectBinary{0, 0xFFFFFFFFFFFFFFFF};
+}
+
+template 
+constexpr auto rsign_mask() {
+  return ZSimdVect{-0.f, 0.f, -0.f, 0.f};
+}
+
+template <>
+constexpr auto rsign_mask() {
+  return ZSimdVect{-0.0, 0.f};
+}
+
+template 
+constexpr auto isign_mask() {
+  return ZSimdVect{0.0, -0.f, 0.0, -0.f};
+}
+
+template <>
+constexpr auto isign_mask() {
+  return ZSimdVect{0.0, -0.0};
+}
+
+template 
+constexpr auto image_one() {
+  return ZSimdVect{0, 1.f, 0, 1.f};
+}
+
+template <>
+constexpr auto image_one() {
+  return ZSimdVect{0.0, 1.0};
+}
+
+template 
+constexpr auto pi_half() {
+  return ZSimdVect{(float)(M_PI / 2.0), 0.f, (float)(M_PI / 2.0), 0.f};
+}
+
+template <>
+constexpr auto pi_half() {
+  return ZSimdVect{M_PI / 2.0, 0.0};
+}
+
+template 
+constexpr auto image_half() {
+  return ZSimdVect{0, 0.5f, 0, 0.5f};
+}
+
+template <>
+constexpr auto image_half() {
+  return ZSimdVect{0.0, 0.5};
+}
+
+template <typename U>
+constexpr U log2e_inv() {
+  return static_cast<U>(1.4426950408889634);
+}
+
+template <typename U>
+constexpr U log10e_inv() {
+  return static_cast<U>(0.43429448190325176);
+}
+
+template <typename T>
+struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
+ public:
+  using underline_type = decltype(std::declval<T>().imag());
+  using value_type = T;
+  using vtype = ZSimdVect<underline_type>;
+  using vmaskType = ZSimdVectBinary<underline_type>;
+  using vinner_type = Vectorized<underline_type>;
+  using size_type = int;
+  using vinner_data = typename Vectorized<underline_type>::vinner_data;
+
+  static constexpr size_type size() {
+    return VECTOR_WIDTH / sizeof(value_type);
+  }
+
+ private:
+  vinner_type _vec;
+
+ public:
+  Vectorized() {}
+
+  C10_ALWAYS_INLINE Vectorized(const vinner_data &v) : _vec{v.first, v.second} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s1, T s2)
+      : _vec{s1.real(), s1.imag(), s2.real(), s2.imag()} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4)
+      : _vec{
+            s1.real(),
+            s1.imag(),
+            s2.real(),
+            s2.imag(),
+            s3.real(),
+            s3.imag(),
+            s4.real(),
+            s4.imag()} {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s) {}
+
+  template  = 0>
+  C10_ALWAYS_INLINE Vectorized(T s) : Vectorized(s, s, s, s) {}
+
+  C10_ALWAYS_INLINE operator vinner_type() const {
+    return _vec;
+  }
+
+  C10_ALWAYS_INLINE const vinner_type& vec() const {
+    return _vec;
+  }
+
+  C10_ALWAYS_INLINE operator vinner_data() const {
+    return _vec.data();
+  }
+
+  C10_ALWAYS_INLINE vinner_data data() const {
+    return _vec.data();
+  }
+
+  static Vectorized C10_ALWAYS_INLINE
+  loadu(const void* ptr, int count = size()) {
+    return Vectorized{vinner_type::loadu(ptr, 2 * count)};
+  }
+
+  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
+    return _vec.store(ptr, 2 * count);
+  }
+
+  static Vectorized blendv(
+      const Vectorized& a,
+      const Vectorized& b,
+      const Vectorized& mask) {
+    // convert std::complex index mask to V index mask: xy -> xxyy
+    vinner_type vmask = mask.vec();
+    auto mask_complex = vinner_type(
+        vec_mergeh(vmask.vec0(), vmask.vec0()),
+        vec_mergeh(vmask.vec1(), vmask.vec1()));
+    return Vectorized{vinner_type::blendv(a.vec(), b.vec(), mask_complex)};
+  }
+
+  template 
+  static auto C10_ALWAYS_INLINE
+  blend(const Vectorized& a, const Vectorized& b) {
+    constexpr int mask_complex = maskForComplex(mask);
+    return Vectorized{
+        vinner_type::template blend(a.vec(), b.vec())};
+  }
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(base, base + step);
+  }
+
+  template 
+  static std::enable_if_t> arange(
+      T base = 0,
+      step_t step = static_cast(1)) {
+    return Vectorized(
+        base,
+        base + step,
+        base + value_type(2) * step,
+        base + value_type(3) * step);
+  }
+
+  template 
+  static inline std::enable_if_t<(Z >= C), Vectorized> set_inner(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count) {
+    return b;
+  }
+
+  template 
+  static inline std::enable_if_t<(Z < C), Vectorized> set_inner(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count) {
+    if (count == Z)
+      return blend(a, b);
+    else
+      return set_inner(a, b, count);
+  }
+
+  static Vectorized set(
+      const Vectorized& a,
+      const Vectorized& b,
+      size_t count = size()) {
+    if (count == 0)
+      return a;
+    return set_inner<1, size()>(a, b, count);
+  }
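+  // set_inner is a compile-time unrolled dispatch over the runtime count:
+  // it walks Z = 1..size() through template recursion and, when Z == count,
+  // emits the blend whose mask takes the first Z elements from b and the
+  // rest from a. The effect is the same as a switch over every legal count,
+  // with the branches generated from the two set_inner overloads above.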
+
+  const T& operator[](int idx) const = delete;
+  T& operator[](int idx) = delete;
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
+  Vectorized mapOrdinary(T (*const f)(const T&)) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    return Vectorized{
+        f(T(v0[0], v0[1])),
+        f(T(v0[2], v0[3])),
+        f(T(v1[0], v1[1])),
+        f(T(v1[2], v1[3]))};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
+  Vectorized mapOrdinary(T (*const f)(const T&)) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
+  Vectorized mapOrdinary(T (*const f)(T)) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    return Vectorized{
+        f(T(v0[0], v0[1])),
+        f(T(v0[2], v0[3])),
+        f(T(v1[0], v1[1])),
+        f(T(v1[2], v1[3]))};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
+  Vectorized mapOrdinary(T (*const f)(T)) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    return Vectorized{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
+  inline Vectorized mapOrdinary(
+      T (*const f)(const T&, const T&),
+      const Vectorized& b) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    auto bvec = b.vec();
+    auto b0 = bvec.vec0();
+    auto b1 = bvec.vec1();
+    T a00 = f(T(v0[0], v0[1]), T(b0[0], b0[1]));
+    T a01 = f(T(v0[2], v0[3]), T(b0[2], b0[3]));
+    T a02 = f(T(v1[0], v1[1]), T(b1[0], b1[1]));
+    T a03 = f(T(v1[2], v1[3]), T(b1[2], b1[3]));
+    return Vectorized{a00, a01, a02, a03};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
+  inline Vectorized mapOrdinary(
+      T (*const f)(const T&, const T&),
+      const Vectorized& b) const {
+    auto v0 = _vec.vec0();
+    auto v1 = _vec.vec1();
+    auto bvec = b.vec();
+    auto b0 = bvec.vec0();
+    auto b1 = bvec.vec1();
+    U a00 = f(U(v0[0], v0[1]), U(b0[0], b0[1]));
+    U a01 = f(U(v1[0], v1[1]), U(b1[0], b1[1]));
+    return Vectorized{a00, a01};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& other) const {
+    return Vectorized{_vec + other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator-(const Vectorized& other) const {
+    return Vectorized{_vec - other._vec};
+  }
+
+  Vectorized inline operator*(const Vectorized& b) const {
+    //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+    vinner_type bv = b.vec();
+#if !defined(ZVECTOR_SIMULATE_X86_MULT)
+    // this is more z arch friendly than simulating horizontal from x86
+    vinner_type vi = bv.mergeo();
+    vinner_type vr = bv.mergee();
+    vi = vi ^ rsign_mask();
+    vinner_type ret = _vec * vr;
+    vinner_type vx_swapped = _vec.swapped();
+    ret = fmadd(vx_swapped, vi, ret);
+#else
+    vinner_type ac_bd = _vec * b;
+    vinner_type d_c = bv.swapped();
+    d_c = d_c ^ isign_mask();
+    vinner_type ad_bc = _vec * d_c;
+    vinner_type ret = vinner_type::horizontal_sub_perm(ac_bd, ad_bc);
+#endif
+    return Vectorized{ret};
+  }
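+  // Lane-by-lane, the non-simulated branch above evaluates the formula in
+  // the comment as follows (a and b are this element, c and d the other):
+  //   vr      = {c, c}                 // mergee duplicates the real part
+  //   vi      = {d, d} ^ {-0.f, 0.f}   // mergeo + rsign_mask -> {-d, d}
+  //   ret     = {a*c, b*c}
+  //   swapped = {b, a}
+  //   fmadd   = {b*(-d) + a*c, a*d + b*c} = {ac - bd, ad + bc}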
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
+  static typename Vectorized::vinner_type real_neg(const typename Vectorized::vinner_type &a)
+  {
+    const auto swap_mask = ZSimdVectBinary{
+      0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31};
+
+    auto a_neg = a.neg();
+    vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask);
+    vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask);
+    return {v0, v1};
+  }
+
+  template <
+      typename U = T,
+      std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
+  static typename Vectorized::vinner_type real_neg(const typename Vectorized::vinner_type &a)
+  {
+    auto a_neg = a.neg();
+    auto v0 = vec_permi(a_neg.vec0(), a.vec0(), 1);
+    auto v1 = vec_permi(a_neg.vec1(), a.vec1(), 1);
+    return { v0, v1 };
+  }
+
+  Vectorized inline operator/(const Vectorized& b) const {
+    // Unfortunately, this breaks some tests
+    // Implement it like it's done for avx2
+    auto fabs_cd = b.vec().abs();                               // |c|    |d|
+    auto fabs_dc = fabs_cd.swapped();                           // |d|    |c|
+    auto scale = vinner_type {1.0} / maximum(fabs_cd, fabs_dc); // 1/sc     1/sc
+    auto a2 = vec() * scale;                                    // a/sc     b/sc
+    auto b2 = b.vec() * scale;                                  // c/sc     d/sc
+    auto acbd2 = a2 * b2;                                       // ac/sc^2  bd/sc^2
+
+    auto dc2 = b2.swapped();                                    // d/sc         c/sc
+    dc2 = Vectorized::real_neg(dc2);                         // -d/|c,d|        c/sc
+    auto adbc2 = a2 * dc2;                                      // -ad/sc^2      bc/sc^2
+    auto sum1 = acbd2 + acbd2.swapped();                        // (ac+bd)/sc^2  (ac+bd)/sc^2
+    auto sum2 = adbc2 + adbc2.swapped();                        // (bc-ad)/sc^2  (bc-ad)/sc^2
+    auto res2 = vinner_type::mergee(sum1, sum2);                // (ac+bd)/sc^2  (bc-ad)/sc^2
+
+    // get the denominator
+    auto denom2 = Vectorized{b2}.abs_2_();                   // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
+    res2 = res2 / denom2;
+    return Vectorized{ res2 };
+  }
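+  // The pre-scaling by 1 / max(|c|, |d|) above is the usual guard against
+  // overflow and underflow in complex division: numerator and denominator
+  // are both scaled before (ac + bd), (bc - ad) and |c|^2 + |d|^2 are formed,
+  // and the common factor cancels in the final division. Per lane this is:
+  //   sc = 1 / max(|c|, |d|);
+  //   re = ((a*sc)*(c*sc) + (b*sc)*(d*sc)) / ((c*sc)*(c*sc) + (d*sc)*(d*sc));
+  //   im = ((b*sc)*(c*sc) - (a*sc)*(d*sc)) / ((c*sc)*(c*sc) + (d*sc)*(d*sc));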
+
+  Vectorized angle2_() const {
+    auto b_a = _vec.swapped(); // b        a
+    return Vectorized{_vec.atan2(b_a).swapped()};
+  }
+
+  Vectorized angle() const {
+    return angle2_().real();
+  }
+
+  Vectorized atan() const {
+    // atan(x) = i/2 * ln((i + z)/(i - z))
+    auto ione = Vectorized{vinner_type(image_one())};
+    auto sum = ione + *this;
+    auto sub = ione - *this;
+    auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
+    return ln *
+        Vectorized{vinner_type(image_half())}; // i/2*ln()
+  }
+
+  Vectorized atanh() const {
+    return mapOrdinary(std::atanh);
+  }
+
+  Vectorized asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+#if 1
+    vinner_type cnj = conj().vec();
+    vinner_type b_a = cnj.swapped();
+    vinner_type ab = cnj * b_a;
+    vinner_type im = ab + ab;
+    vinner_type val_2 = _vec * _vec;
+    vinner_type val_2_swapped = val_2.swapped();
+    vinner_type re = vinner_type::horizontal_sub_perm(val_2, val_2_swapped);
+    re = vinner_type(static_cast(1)) - re;
+    constexpr int blend_mask =
+        blend_choice<T>(); // 0x0A for complex<double> , 0xAA for complex<float>
+    vinner_type blendx = vinner_type::template blend<blend_mask>(re, im);
+    auto root = Vectorized(blendx).sqrt();
+    auto ln = Vectorized(Vectorized(b_a) + root).log();
+    return Vectorized(ln.vec().swapped()).conj();
+#else
+    return mapOrdinary(std::asin);
+#endif
+  }
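+  // Mapping the vector steps above onto the identity in the comment:
+  //   b_a    = i*z                  ({-b, a})
+  //   blendx = 1 - z*z              (re = 1 - (a*a - b*b), im = -2*a*b)
+  //   ln     = ln(i*z + sqrt(1 - z*z))
+  //   the final swap + conj multiplies by -i, giving asin(z).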
+
+  Vectorized acos() const {
+    // acos(x) = pi/2 - asin(x)
+    return Vectorized(vinner_type(pi_half())) - asin();
+  }
+
+  Vectorized sin() const {
+    return mapOrdinary(std::sin);
+  }
+  Vectorized sinh() const {
+    return mapOrdinary(std::sinh);
+  }
+  Vectorized cos() const {
+    return mapOrdinary(std::cos);
+  }
+  Vectorized cosh() const {
+    return mapOrdinary(std::cosh);
+  }
+  Vectorized ceil() const {
+    return Vectorized{_vec.ceil()};
+  }
+  Vectorized floor() const {
+    return Vectorized{_vec.floor()};
+  }
+  Vectorized neg() const {
+    return Vectorized(_vec.neg());
+  }
+  Vectorized round() const {
+    return Vectorized{_vec.round()};
+  }
+  Vectorized tan() const {
+    return mapOrdinary(std::tan);
+  }
+  Vectorized tanh() const {
+    return mapOrdinary(std::tanh);
+  }
+  Vectorized trunc() const {
+    return Vectorized{_vec.trunc()};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator&(const Vectorized& other) const {
+    return Vectorized{_vec & other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator|(const Vectorized& other) const {
+    return Vectorized{_vec | other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator^(const Vectorized& other) const {
+    return Vectorized{_vec ^ other._vec};
+  }
+  Vectorized C10_ALWAYS_INLINE operator==(const Vectorized& other) const {
+    return Vectorized{_vec == other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE operator!=(const Vectorized& other) const {
+    return Vectorized{_vec != other._vec};
+  }
+
+  Vectorized C10_ALWAYS_INLINE eq(const Vectorized& other) const {
+    auto eq = _vec.eq(other._vec);  // compares real and imag individually
+    // If both real numbers and imag numbers are equal, then the complex numbers are equal
+    auto real = eq & vinner_type(real_mask());
+    auto imag = (eq & vinner_type(image_mask())).swapped();
+    return Vectorized{real & imag};
+  }
+  Vectorized C10_ALWAYS_INLINE ne(const Vectorized& other) const {
+    auto ne = _vec.ne(other._vec);  // compares real and imag individually
+    // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+    auto real = ne & vinner_type(real_mask());
+    auto imag = (ne & vinner_type(image_mask())).swapped();
+    return Vectorized{real | imag};
+  }
+
+  Vectorized real() const {
+    return Vectorized(_vec & vinner_type(real_mask()));
+  }
+  Vectorized imag_() const {
+    return Vectorized(_vec & vinner_type(image_mask()));
+  }
+  Vectorized imag() const {
+    return Vectorized{
+        (_vec & vinner_type(image_mask())).swapped()};
+  }
+
+  Vectorized conj() const {
+    return Vectorized(_vec ^ vinner_type(isign_mask()));
+  }
+
+  vinner_data abs_2_() const {
+    auto a = _vec * _vec;
+    a = a + a.swapped();
+    return a.mergee().data();
+  }
+
+  static T abs_helper(const T &value)
+  {
+    return T(std::abs(value));
+  }
+
+  Vectorized abs() const {
+    return mapOrdinary(abs_helper);
+  }
+
+  Vectorized exp() const {
+    return mapOrdinary(std::exp);
+  }
+
+  Vectorized exp2() const {
+    return mapOrdinary(exp2_impl);
+  }
+
+  Vectorized expm1() const {
+    return mapOrdinary(std::expm1);
+  }
+
+  Vectorized log() const {
+    return mapOrdinary(std::log);
+  }
+
+  Vectorized log2() const {
+    // log2eB_inv
+    auto ret = log();
+    return Vectorized{ret._vec * vinner_type(log2e_inv())};
+  }
+
+  Vectorized log10() const {
+    auto ret = log();
+    return Vectorized{ret._vec * vinner_type(log10e_inv())};
+  }
+
+  Vectorized log1p() const {
+    return mapOrdinary(std::log1p);
+  }
+
+  Vectorized sgn() const {
+    return mapOrdinary(at::native::sgn_impl);
+  }
+
+  Vectorized pow(const Vectorized& exp) const {
+    return mapOrdinary(std::pow, exp);
+  }
+
+  Vectorized sqrt() const {
+    return mapOrdinary(std::sqrt);
+  }
+
+  Vectorized reciprocal() const {
+    // re + im*i = (a + bi)  / (c + di)
+    // re = (ac + bd)/abs_2() = c/abs_2()
+    // im = (bc - ad)/abs_2() = d/abs_2()
+    vinner_type c_d = _vec ^ vinner_type(isign_mask());
+    vinner_type abs = abs_2_();
+    return Vectorized{c_d / abs};
+  }
+
+  Vectorized rsqrt() const {
+    return sqrt().reciprocal();
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized lt(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized le(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized gt(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized ge(const Vectorized& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+};
+
+template  = 0>
+std::pair, Vectorized> inline inner_interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  // inputs:
+  //   a      = {a0, a1, a2, a3}
+  //   b      = {b0, b1, b2, b3}
+  using vtype = typename Vectorized::vtype;
+  vtype ab00 = vec_permi(a.vec0(), b.vec0(), 0);
+  vtype ab11 = vec_permi(a.vec0(), b.vec0(), 3);
+  vtype ab2_00 = vec_permi(a.vec1(), b.vec1(), 0);
+  vtype ab2_11 = vec_permi(a.vec1(), b.vec1(), 3);
+  //   return {a0, b0, a1, b1}
+  //          {a2, b2, a3, b3}
+  return std::make_pair(
+      Vectorized{ab00, ab11}, Vectorized{ab2_00, ab2_11});
+}
+
+template  = 0>
+std::pair, Vectorized> inline inner_deinterleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1}
+  //   b = {a2, b2, a3, b3}
+  using vtype = typename Vectorized::vtype;
+  vtype aa01 = vec_permi(a.vec0(), a.vec1(), 0);
+  vtype aa23 = vec_permi(b.vec0(), b.vec1(), 0);
+
+  vtype bb_01 = vec_permi(a.vec0(), a.vec1(), 3);
+  vtype bb_23 = vec_permi(b.vec0(), b.vec1(), 3);
+
+  // swap lanes:
+  //   return {a0, a1, a2, a3}
+  //          {b0, b1, b2, b3}
+  return std::make_pair(Vectorized{aa01, aa23}, Vectorized{bb_01, bb_23});
+}
+
+template  = 0>
+std::pair, Vectorized> inline inner_interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3,, a4, a5, a6, a7}
+  //   b = {b0, b1, b2, b3,, b4, b5, b6, b7}
+  using vtype = typename Vectorized::vtype;
+  vtype ab0011 = vec_mergeh(a.vec0(), b.vec0());
+  vtype ab2233 = vec_mergel(a.vec0(), b.vec0());
+
+  vtype ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
+  vtype ab2_2233 = vec_mergel(a.vec1(), b.vec1());
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1,, a2, b2, a3, b3}
+  //          {a4, b4, a5, b5,, a6, b6, a7, b7}
+
+  return std::make_pair(
+      Vectorized{ab0011, ab2233}, Vectorized{ab2_0011, ab2_2233});
+}
+
+template  = 0>
+std::pair, Vectorized> inline inner_deinterleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1,, a2, b2, a3, b3}
+  //   b = {a4, b4, a5, b5,, a6, b6, a7, b7}
+  using vtype = typename Vectorized::vtype;
+  // {a0,a2,b0,b2} {a1,a3,b1,b3}
+  vtype a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
+  vtype a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());
+
+  vtype aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
+  vtype bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);
+
+  vtype a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
+  vtype a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());
+
+  vtype aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
+  vtype bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);
+
+  // it could be done with vec_perm ,too
+  // swap lanes:
+  //   return {a0, a1, a2, a3,, a4, a5, a6, a7}
+  //          {b0, b1, b2, b3,, b4, b5, b6, b7}
+
+  return std::make_pair(
+      Vectorized{aa0123, aa0123_2}, Vectorized{bb0123, bb0123_2});
+}
+
+template <>
+std::pair, Vectorized> inline interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_interleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_interleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_interleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline interleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_interleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline deinterleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_deinterleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline deinterleave2<
+    int32_t>(const Vectorized& a, const Vectorized& b) {
+  return inner_deinterleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline deinterleave2(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return inner_deinterleave2(a, b);
+}
+
+template <>
+std::pair, Vectorized> inline deinterleave2<
+    int64_t>(const Vectorized& a, const Vectorized& b) {
+  return inner_deinterleave2(a, b);
+}
+
+template 
+typename std::enable_if::value, at::vec::Vectorized>::type
+inline convert_int8_to_float(const Vectorized &src) {
+  // Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size()
+  // Only handle first 64 bits
+  auto vec_int = src.to_vec_float_helper();
+
+  return convert_to_float(vec_int);
+}
+
+template 
+typename std::enable_if::value, at::vec::Vectorized>::type
+inline convert_float_to_int8(const Vectorized &src) {
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+
+  auto vec_int = clamp(convert_to_int(src), Vectorized(min_val), Vectorized(max_val));
+
+  return vec_int.to_vec_uint8_helper();
+}
+
+#undef DEFINE_CLAMP_MAXMIN_FUNCS
+#undef DEFINE_MAXMIN_FUNCS
+} // namespace
+} // namespace vec
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h
new file mode 100644
index 0000000000000000000000000000000000000000..782f0d3950f5bbb2dfe4387e6aa0d1d48aafc14b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512.h
@@ -0,0 +1,275 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace vec {
+
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
+  stream << val.val_;
+  return stream;
+}
+inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
+  stream << static_cast<int>(val.val_);
+  return stream;
+}
+inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
+  stream << static_cast<int>(val.val_);
+  return stream;
+}
+
+template <typename T>
+std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
+  T buf[Vectorized<T>::size()];
+  vec.store(buf);
+  stream << "vec[";
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    if (i != 0) {
+      stream << ", ";
+    }
+    stream << buf[i];
+  }
+  stream << "]";
+  return stream;
+}
+
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+inline Vectorized cast(const Vectorized& src) {
+  return _mm512_castpd_ps(src);
+}
+
+template<>
+inline Vectorized cast(const Vectorized& src) {
+  return _mm512_castps_pd(src);
+}
+
+template<>
+inline Vectorized cast(const Vectorized& src) {
+  return _mm512_castsi512_ps(src);
+}
+
+template<>
+inline Vectorized cast(const Vectorized& src) {
+  return _mm512_castsi512_pd(src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template
+std::enable_if_t>
+inline gather(const double* base_addr, const Vectorized& vindex) {
+  return _mm512_i64gather_pd(vindex, base_addr, scale);
+}
+
+template
+std::enable_if_t>
+inline gather(const float* base_addr, const Vectorized& vindex) {
+  return _mm512_i32gather_ps(vindex, base_addr, scale);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template
+std::enable_if_t>
+inline mask_gather(const Vectorized& src, const double* base_addr,
+                   const Vectorized& vindex, Vectorized& mask) {
+  auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF));
+  auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ);
+  return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale);
+}
+
+template
+std::enable_if_t>
+inline mask_gather(const Vectorized& src, const float* base_addr,
+                   const Vectorized& vindex, Vectorized& mask) {
+  auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF));
+  auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
+  return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+Vectorized
+inline convert_to_int_of_same_size(const Vectorized &src) {
+  return _mm512_cvtpd_epi64(src);
+}
+
+template<>
+Vectorized
+inline convert_to_int_of_same_size(const Vectorized &src) {
+  return _mm512_cvttps_epi32(src);
+}
+
+template<>
+Vectorized
+inline convert_to_fp_of_same_size(const Vectorized &src) {
+  return _mm512_cvtepi64_pd(src);
+}
+
+template<>
+Vectorized
+inline convert_to_fp_of_same_size(const Vectorized &src) {
+  return _mm512_cvtepi32_ps(src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair, Vectorized>
+inline interleave2(const Vectorized& a, const Vectorized& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
+  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
+  //          {a4, b4, a5, b5, a6, b6, a7, b7}
+  __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0);
+  __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4);
+  return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
+                        _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
+}
+
+template <>
+std::pair, Vectorized>
+inline interleave2(const Vectorized& a, const Vectorized& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+  //   b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+  //
+  //  return:
+  //    {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+  //    {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+  __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
+                                  19, 3, 18, 2, 17, 1, 16, 0);
+  __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
+                                  27, 11, 26, 10, 25, 9, 24, 8);
+  return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
+                        _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair, Vectorized>
+inline deinterleave2(const Vectorized& a, const Vectorized& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
+  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
+  // output:
+  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
+  //          {b0, b1, b2, b3, b4, b5, b6, b7}
+  // The members of indices have been written in binary format for better understandability
+  __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
+  __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
+
+  return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
+                        _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
+}
+
+template <>
+std::pair, Vectorized>
+inline deinterleave2(const Vectorized& a, const Vectorized& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+  //   b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+  // output:
+  //   return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+  //          {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+  __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
+                                  14, 12, 10, 8, 6, 4, 2, 0);
+  __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17,
+                                  15, 13, 11, 9, 7, 5, 3, 1);
+
+  return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
+                        _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+inline Vectorized flip(const Vectorized & v) {
+  const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
+                                        8, 9, 10, 11, 12, 13, 14, 15);
+  return _mm512_permutexvar_ps(mask, v);
+}
+
+template<>
+inline Vectorized flip(const Vectorized & v) {
+  const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
+  return _mm512_permutexvar_pd(mask, v);
+}
+
+template<>
+inline Vectorized flip(const Vectorized & v) {
+  const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
+  return _mm512_permutexvar_epi64(mask, v);
+}
+
+template<>
+inline Vectorized flip(const Vectorized & v) {
+  const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
+                                        8, 9, 10, 11, 12, 13, 14, 15);
+  return _mm512_permutexvar_epi32(mask, v);
+}
+
+template<>
+inline Vectorized flip(const Vectorized & v) {
+  const __m512i mask = _mm512_set_epi16(
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
+  );
+  return _mm512_permutexvar_epi16(mask, v);
+}
+
+inline __m512i flip8(const __m512i & v) {
+  const __m512i mask1 = _mm512_set_epi8(
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+  );
+  const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6);
+  auto reversed_vec = _mm512_shuffle_epi8(v, mask1);
+  return _mm512_permutexvar_epi64(mask2, reversed_vec);
+}
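+// flip8 reverses all 64 byte lanes in two steps, because _mm512_shuffle_epi8
+// only permutes bytes within each 128-bit lane: mask1 reverses the 16 bytes
+// inside every lane, and the following _mm512_permutexvar_epi64 with mask2
+// reverses the order of the four 128-bit lanes themselves.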
+
+template<>
+inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
+  return flip8(v);
+}
+
+template<>
+inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
+  return flip8(v);
+}
+
+#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h
new file mode 100644
index 0000000000000000000000000000000000000000..96180ed1096da1775374d85144a644a95c6ccf1c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h
@@ -0,0 +1,1644 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+// bfloat16 conversion
+static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
+  o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
+}
+
+static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
+  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
+  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
+  cvtbf16_fp32(lo, o1);
+  cvtbf16_fp32(hi, o2);
+}
+
+static inline __m256i cvtfp32_bf16(const __m512& src) {
+  __m512i value = _mm512_castps_si512(src);
+  __m512i nan = _mm512_set1_epi32(0xffff);
+  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
+  __m512i ones = _mm512_set1_epi32(0x1);
+  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
+  // uint32_t lsb = (input >> 16) & 1;
+  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
+  // uint32_t rounding_bias = 0x7fff + lsb;
+  t_value = _mm512_add_epi32(t_value, vec_bias);
+  // input += rounding_bias;
+  t_value = _mm512_add_epi32(t_value, value);
+  // input = input >> 16;
+  t_value = _mm512_srli_epi32(t_value, 16);
+  // Check NaN before converting back to bf16
+  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
+  return _mm512_cvtusepi32_epi16(t_value);
+}
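+// The bias trick above implements round-to-nearest-even on the bits that are
+// dropped when truncating fp32 to bf16. A scalar sketch of the same idea
+// (ignoring the NaN handling done by the mask blend):
+//   uint32_t bits = bit_cast_to_uint32(f);
+//   uint32_t lsb  = (bits >> 16) & 1;   // lowest bit that survives truncation
+//   bits += 0x7fff + lsb;               // ties round toward the even mantissa
+//   uint16_t bf16 = uint16_t(bits >> 16);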
+
+static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
+  __m512i lo = _mm512_castps_si512(a);
+  __m512i hi = _mm512_castps_si512(b);
+  __m512i nan = _mm512_set1_epi32(0xffff);
+  auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
+  auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
+  __m512i ones = _mm512_set1_epi32(0x1);
+  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
+  // uint32_t lsb = (input >> 16) & 1;
+  auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
+  auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
+  // uint32_t rounding_bias = 0x7fff + lsb;
+  t_lo = _mm512_add_epi32(t_lo, vec_bias);
+  t_hi = _mm512_add_epi32(t_hi, vec_bias);
+  // input += rounding_bias;
+  t_lo = _mm512_add_epi32(t_lo, lo);
+  t_hi = _mm512_add_epi32(t_hi, hi);
+  // input = input >> 16;
+  t_lo = _mm512_srli_epi32(t_lo, 16);
+  t_hi = _mm512_srli_epi32(t_hi, 16);
+  // Check NaN before converting back to bf16
+  t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
+  t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
+
+  t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
+  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
+  return _mm512_permutexvar_epi64(idx, t_lo);
+}
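+// _mm512_packus_epi32 interleaves its two sources per 128-bit lane, so the
+// trailing _mm512_permutexvar_epi64 regroups the packed halves: the 16 values
+// converted from `a` land in the low 256 bits of the result and the 16 values
+// converted from `b` in the high 256 bits.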
+
+static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
+  __m512i lo = _mm512_castps_si512(a);
+  __m512i hi = _mm512_castps_si512(b);
+  lo = _mm512_srli_epi32(lo, 16);
+  hi = _mm512_srli_epi32(hi, 16);
+  auto out = _mm512_packus_epi32(lo, hi);
+  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
+  return _mm512_permutexvar_epi64(idx, out);
+}
+
+// float16 conversion
+static inline void cvtfp16_fp32(const __m256i& a, __m512& o) {
+  o = _mm512_cvtph_ps(a);
+}
+
+static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
+  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
+  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
+  cvtfp16_fp32(lo, o1);
+  cvtfp16_fp32(hi, o2);
+}
+
+static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) {
+  __m256i lo = _mm512_cvtps_ph(
+      a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  __m256i hi = _mm512_cvtps_ph(
+      b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo));
+  __m256 t_hi = _mm256_castsi256_ps(hi);
+  return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1));
+}
+
+// dtype conversion between float16/bfloat16 and float32
+template , int> = 0>
+inline void cvt_to_fp32(const __m256i& a, __m512& o);
+template <> inline void cvt_to_fp32(const __m256i& a, __m512& o) {
+  cvtbf16_fp32(a, o);
+}
+template <> inline void cvt_to_fp32(const __m256i& a, __m512& o) {
+  cvtfp16_fp32(a, o);
+}
+
+template , int> = 0>
+inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2);
+template <> inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) {
+  cvtbf16_fp32(a, o1, o2);
+}
+template <> inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2) {
+  cvtfp16_fp32(a, o1, o2);
+}
+
+template , int> = 0>
+inline __m512i cvt_from_fp32(const __m512& a, const __m512& b);
+template <> inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) {
+  return cvtfp32_bf16(a, b);
+}
+template <> inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) {
+  return merge_compare_result(a, b);
+}
+template <> inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) {
+  return cvtfp32_fp16(a, b);
+}
+template <> inline __m512i cvt_from_fp32(const __m512& a, const __m512& b) {
+  return cvtfp32_fp16(a, b);
+}
+
+template <typename T>
+class Vectorized16 {
+static_assert(
+  is_reduced_floating_point_v<T>,
+  "Support only float16 and bfloat16.");
+private:
+  __m512i values;
+public:
+  using value_type = uint16_t;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 32;
+  }
+  Vectorized16() {}
+  Vectorized16(__m512i v) : values(v) {}
+  Vectorized16(T val) {
+    value_type uw = val.x;
+    values = _mm512_set1_epi16(uw);
+  }
+  Vectorized16(T val1, T val2, T val3, T val4,
+         T val5, T val6, T val7, T val8,
+         T val9, T val10, T val11, T val12,
+         T val13, T val14, T val15, T val16,
+         T val17, T val18, T val19, T val20,
+         T val21, T val22, T val23, T val24,
+         T val25, T val26, T val27, T val28,
+         T val29, T val30, T val31, T val32) {
+    values = _mm512_set_epi16(
+        val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
+        val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
+        val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
+        val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
+  }
+  operator __m512i() const {
+    return values;
+  }
+  T& operator[](int idx) = delete;
+  const T& operator[](int idx) const  = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
+  }
+  static Vectorized loadu(const void* ptr, int16_t count = size()) {
+    if (count == size())
+      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
+
+    __mmask32 mask = (1ULL << count) - 1;
+    return _mm512_maskz_loadu_epi16(mask, ptr);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
+    } else if (count > 0) {
+      __mmask32 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_epi16(ptr, mask, values);
+    }
+  }
+  template 
+  static Vectorized blend(const Vectorized& a, const Vectorized& b) {
+    __at_align__ int16_t tmp_values[size()];
+    a.store(tmp_values);
+    if (mask & 0x01)
+      tmp_values[0] = b.values[31];
+    if (mask & 0x02)
+      tmp_values[1] = b.values[30];
+    if (mask & 0x04)
+      tmp_values[2] = b.values[29];
+    if (mask & 0x08)
+      tmp_values[3] = b.values[28];
+    if (mask & 0x10)
+      tmp_values[4] = b.values[27];
+    if (mask & 0x20)
+      tmp_values[5] = b.values[26];
+    if (mask & 0x40)
+      tmp_values[6] = b.values[25];
+    if (mask & 0x80)
+      tmp_values[7] = b.values[24];
+    if (mask & 0x100)
+      tmp_values[8] = b.values[23];
+    if (mask & 0x200)
+      tmp_values[9] = b.values[22];
+    if (mask & 0x400)
+      tmp_values[10] = b.values[21];
+    if (mask & 0x800)
+      tmp_values[11] = b.values[20];
+    if (mask & 0x1000)
+      tmp_values[12] = b.values[19];
+    if (mask & 0x2000)
+      tmp_values[13] = b.values[18];
+    if (mask & 0x4000)
+      tmp_values[14] = b.values[17];
+    if (mask & 0x8000)
+      tmp_values[15] = b.values[16];
+    if (mask & 0x10000)
+      tmp_values[16] = b.values[15];
+    if (mask & 0x20000)
+      tmp_values[17] = b.values[14];
+    if (mask & 0x40000)
+      tmp_values[18] = b.values[13];
+    if (mask & 0x80000)
+      tmp_values[19] = b.values[12];
+    if (mask & 0x100000)
+      tmp_values[20] = b.values[11];
+    if (mask & 0x200000)
+      tmp_values[21] = b.values[10];
+    if (mask & 0x400000)
+      tmp_values[22] = b.values[9];
+    if (mask & 0x800000)
+      tmp_values[23] = b.values[8];
+    if (mask & 0x1000000)
+      tmp_values[24] = b.values[7];
+    if (mask & 0x2000000)
+      tmp_values[25] = b.values[6];
+    if (mask & 0x4000000)
+      tmp_values[26] = b.values[5];
+    if (mask & 0x8000000)
+      tmp_values[27] = b.values[4];
+    if (mask & 0x10000000)
+      tmp_values[28] = b.values[3];
+    if (mask & 0x20000000)
+      tmp_values[29] = b.values[2];
+    if (mask & 0x40000000)
+      tmp_values[30] = b.values[1];
+    if (mask & 0x80000000)
+      tmp_values[31] = b.values[0];
+    return loadu(tmp_values);
+  }
+  static Vectorized blendv(const Vectorized& a,
+      const Vectorized& b, const Vectorized& mask) {
+    auto all_ones = _mm512_set1_epi16(0xFFFF);
+    auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
+  }
+  template
+  static Vectorized arange(T base = 0.f, step_t step = static_cast(1)) {
+    return Vectorized(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
+      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
+      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
+      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
+      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
+  }
+  static Vectorized set(const Vectorized& a,
+      const Vectorized& b, int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+      case 16:
+        return blend<65535>(a, b);
+      case 17:
+        return blend<131071>(a, b);
+      case 18:
+        return blend<262143>(a, b);
+      case 19:
+        return blend<524287>(a, b);
+      case 20:
+        return blend<1048575>(a, b);
+      case 21:
+        return blend<2097151>(a, b);
+      case 22:
+        return blend<4194303>(a, b);
+      case 23:
+        return blend<8388607>(a, b);
+      case 24:
+        return blend<16777215>(a, b);
+      case 25:
+        return blend<33554431>(a, b);
+      case 26:
+        return blend<67108863>(a, b);
+      case 27:
+        return blend<134217727>(a, b);
+      case 28:
+        return blend<268435455>(a, b);
+      case 29:
+        return blend<536870911>(a, b);
+      case 30:
+        return blend<1073741823>(a, b);
+      case 31:
+        return blend<2147483647>(a, b);
+    }
+    return b;
+  }
+  #pragma clang diagnostic push
+  #pragma clang diagnostic ignored "-Wignored-qualifiers"
+  Vectorized map(const __m512 (*const vop)(__m512)) const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    const auto o1 = vop(lo);
+    const auto o2 = vop(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized isnan() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    __mmask16 lo_mask, hi_mask;
+    __m512 zero = _mm512_set1_ps(0.0);
+    __m512i zeroi = _mm512_castps_si512(zero);
+    lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
+    lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
+    hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
+    hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
+    return merge_compare_result(lo, hi);
+  }
+  #pragma clang diagnostic pop
+  Vectorized abs() const {
+    return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
+  }
+  Vectorized angle() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    auto angle_lambda = [](__m512 values) {
+      const auto zero_vec = _mm512_set1_ps(0.f);
+      const auto nan_vec = _mm512_set1_ps(NAN);
+      const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
+      const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
+                                                           not_nan_mask, 0xFFFFFFFF);
+      const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
+                                               zero_vec, _CMP_EQ_OQ);
+      const auto pi = _mm512_set1_ps(c10::pi<float>);
+
+      const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
+      auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
+      angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
+      return angle;
+    };
+    auto o1 = angle_lambda(lo);
+    auto o2 = angle_lambda(hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_epi16(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return map(Sleef_acosf16_u10);
+  }
+  Vectorized acosh() const {
+    return map(Sleef_acoshf16_u10);
+  }
+  Vectorized asin() const {
+    return map(Sleef_asinf16_u10);
+  }
+  Vectorized atan() const {
+    return map(Sleef_atanf16_u10);
+  }
+  Vectorized atanh() const {
+    return map(Sleef_atanhf16_u10);
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    __m512 lo, hi;
+    __m512 b1, b2;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(b.values, b1, b2);
+    auto o1 = Sleef_atan2f16_u10(lo, b1);
+    auto o2 = Sleef_atan2f16_u10(hi, b2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    // copy sign bit (0x8000) from sign and remaining bits from values
+    __m512i mask_value = _mm512_set1_epi32(~0x80008000);
+    __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
+    return Vectorized(
+      _mm512_or_si512(
+        _mm512_and_si512(values, mask_value),
+        _mm512_and_si512(sign, mask_signbit)));
+  }
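+  // 0x80008000 selects the sign bit of both 16-bit halves of each 32-bit
+  // chunk, so the OR of (values & ~signbit) with (sign & signbit) is the
+  // standard scalar copysign bit trick applied to every 16-bit lane at once.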
+  Vectorized erf() const {
+    return map(Sleef_erff16_u10);
+  }
+  Vectorized erfc() const {
+    return map(Sleef_erfcf16_u15);
+  }
+  Vectorized erfinv() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+    for (int64_t i = 0; i < size() / 2; i++) {
+      tmp1[i] = calc_erfinv(tmp1[i]);
+      tmp2[i] = calc_erfinv(tmp2[i]);
+    }
+    auto o1 = _mm512_loadu_ps(tmp1);
+    auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized exp() const {
+    return map(Sleef_expf16_u10);
+  }
+  Vectorized exp2() const {
+    return map(Sleef_exp2f16_u10);
+  }
+  Vectorized expm1() const {
+    return map(Sleef_expm1f16_u10);
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+  Vectorized fmod(const Vectorized & q) const {
+    __m512 x_lo, x_hi;
+    cvt_to_fp32(values, x_lo, x_hi);
+    __m512 q_lo, q_hi;
+    cvtbf16_fp32(q.values, q_lo, q_hi);
+    auto o1 = Sleef_fmodf16(x_lo, q_lo);
+    auto o2 = Sleef_fmodf16(x_hi, q_hi);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    __m512 lo, hi;
+    __m512 b1, b2;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(b.values, b1, b2);
+    auto o1 = Sleef_hypotf16_u05(lo, b1);
+    auto o2 = Sleef_hypotf16_u05(hi, b2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized i0() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+    for (int64_t i = 0; i < size() / 2; i++) {
+      tmp1[i] = calc_i0(tmp1[i]);
+      tmp2[i] = calc_i0(tmp2[i]);
+    }
+    auto o1 = _mm512_loadu_ps(tmp1);
+    auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized i0e() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    constexpr auto sz = size();
+    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+
+    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
+      tmp1[i] = calc_i0e(tmp1[i]);
+      tmp2[i] = calc_i0e(tmp2[i]);
+    }
+    const auto o1 = _mm512_loadu_ps(tmp1);
+    const auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized digamma() const {
+    __m512 lo, hi;
+    cvt_to_fp32(values, lo, hi);
+    constexpr auto sz = size();
+    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+
+    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
+      tmp1[i] = calc_digamma(tmp1[i]);
+      tmp2[i] = calc_digamma(tmp2[i]);
+    }
+    const auto o1 = _mm512_loadu_ps(tmp1);
+    const auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __m512 lo, hi;
+    __m512 xlo, xhi;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(x.values, xlo, xhi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo);
+    _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi);
+    for (int64_t i = 0; i < size() / 2; ++i) {
+      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
+      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
+    }
+    auto o1 = _mm512_loadu_ps(tmp1);
+    auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+
+  Vectorized igammac(const Vectorized &x) const {
+    __m512 lo, hi;
+    __m512 xlo, xhi;
+    cvt_to_fp32(values, lo, hi);
+    cvt_to_fp32(x.values, xlo, xhi);
+    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmp1), lo);
+    _mm512_storeu_ps(reinterpret_cast(tmp2), hi);
+    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
+    _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo);
+    _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi);
+    for (int64_t i = 0; i < size() / 2; ++i) {
+      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
+      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
+    }
+    auto o1 = _mm512_loadu_ps(tmp1);
+    auto o2 = _mm512_loadu_ps(tmp2);
+    return cvt_from_fp32(o1, o2);
+  }
+  Vectorized<T> log() const {
+    return map(Sleef_logf16_u10);
+  }
+  Vectorized<T> log2() const {
+    return map(Sleef_log2f16_u10);
+  }
+  Vectorized<T> log10() const {
+    return map(Sleef_log10f16_u10);
+  }
+  Vectorized<T> log1p() const {
+    return map(Sleef_log1pf16_u10);
+  }
+  Vectorized<T> sin() const {
+    return map(Sleef_sinf16_u10);
+  }
+  Vectorized<T> sinh() const {
+    return map(Sleef_sinhf16_u10);
+  }
+  Vectorized<T> cos() const {
+    return map(Sleef_cosf16_u10);
+  }
+  Vectorized<T> cosh() const {
+    return map(Sleef_coshf16_u10);
+  }
+  Vectorized<T> ceil() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto o1 = _mm512_ceil_ps(lo);
+    auto o2 = _mm512_ceil_ps(hi);
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> floor() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto o1 = _mm512_floor_ps(lo);
+    auto o2 = _mm512_floor_ps(hi);
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> neg() const {
+    return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
+  }
+  Vectorized<T> round() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> tan() const {
+    return map(Sleef_tanf16_u10);
+  }
+  Vectorized<T> tanh() const {
+    return map(Sleef_tanhf16_u10);
+  }
+  Vectorized<T> trunc() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> lgamma() const {
+    return map(Sleef_lgammaf16_u10);
+  }
+  Vectorized<T> sqrt() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto o1 = _mm512_sqrt_ps(lo);
+    auto o2 = _mm512_sqrt_ps(hi);
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> reciprocal() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto ones = _mm512_set1_ps(1);
+    auto o1 = _mm512_div_ps(ones, lo);
+    auto o2 = _mm512_div_ps(ones, hi);
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> rsqrt() const {
+    __m512 lo, hi;
+    cvt_to_fp32<T>(values, lo, hi);
+    auto ones = _mm512_set1_ps(1);
+    auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
+    auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
+    return cvt_from_fp32<T>(o1, o2);
+  }
+  Vectorized<T> pow(const Vectorized<T> &b) const {
+    __m512 lo, hi;
+    __m512 b1, b2;
+    cvt_to_fp32<T>(values, lo, hi);
+    cvt_to_fp32<T>(b.values, b1, b2);
+    auto o1 = Sleef_powf16_u10(lo, b1);
+    auto o2 = Sleef_powf16_u10(hi, b2);
+    return cvt_from_fp32<T>(o1, o2);
+  }
+private:
+  template <typename Op>
+  Vectorized<T> inline binary_compare(const Vectorized<T>& b, Op op) const {
+    __m512 a_lo, a_hi;
+    __m512 b_lo, b_hi;
+    cvt_to_fp32<T>(values, a_lo, a_hi);
+    cvt_to_fp32<T>(b.values, b_lo, b_hi);
+    auto o1 = op(a_lo, b_lo);
+    auto o2 = op(a_hi, b_hi);
+    return cvt_from_fp32<T, /*is_compare_op*/true>(o1, o2);
+  }
+
+public:
+  Vectorized<T> inline operator>(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+  Vectorized<T> inline operator<(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+  Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+  Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+  Vectorized<T> inline operator==(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+  Vectorized<T> inline operator!=(const Vectorized<T>& other) const {
+    return binary_compare(other, [](__m512 x, __m512 y) {
+      auto zero_vec = _mm512_set1_epi32(0);
+      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
+      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
+    });
+  }
+};
+
+template <typename T, typename Op>
+static inline Vectorized<T> binary_op_as_fp32(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  cvt_to_fp32<T>(__m512i(a), a_lo, a_hi);
+  cvt_to_fp32<T>(__m512i(b), b_lo, b_hi);
+  auto o1 = op(a_lo, b_lo);
+  auto o2 = op(a_hi, b_hi);
+  return cvt_from_fp32<T>(o1, o2);
+}
+
+template <>
+class Vectorized<BFloat16>: public Vectorized16<BFloat16> {
+public:
+  using Vectorized16::Vectorized16;
+
+  Vectorized<BFloat16> frac() const;
+
+  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
+  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
+};
+
+Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
+}
+Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
+}
+Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
+}
+Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
+}
+Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return _mm512_and_si512(a, b);
+}
+Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return _mm512_or_si512(a, b);
+}
+Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  return _mm512_xor_si512(a, b);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
+  return (*this == other) & Vectorized<BFloat16>(1.0f);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
+  return (*this != other) & Vectorized<BFloat16>(1.0f);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
+  return (*this > other) & Vectorized<BFloat16>(1.0f);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
+  return (*this >= other) & Vectorized<BFloat16>(1.0f);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
+  return (*this < other) & Vectorized<BFloat16>(1.0f);
+}
+
+inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
+  return (*this <= other) & Vectorized<BFloat16>(1.0f);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
+  return *this - this->trunc();
+}
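+
+// Illustrative note (added example, not from the upstream source): trunc()
+// rounds toward zero, so frac() keeps the sign of its input, e.g.
+//   Vectorized<BFloat16>(1.75f).frac()  -> 0.75
+//   Vectorized<BFloat16>(-1.75f).frac() -> -0.75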
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
+  auto max_lo = _mm512_max_ps(a_lo, b_lo);
+  auto max_hi = _mm512_max_ps(a_hi, b_hi);
+  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
+  auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
+  auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm512_or_ps(max_lo, nan_lo);
+  auto o2 = _mm512_or_ps(max_hi, nan_hi);
+  return cvtfp32_bf16(o1, o2);
+}
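+
+// Rough scalar equivalent of the NaN handling above (added sketch, assuming the
+// fp32 semantics carry over to the converted bf16 lanes):
+//   float vmax = _mm512_max_ps-style max of a and b;
+//   if (std::isnan(a) || std::isnan(b)) vmax = quiet NaN;   // _CMP_UNORD_Q mask
+// OR-ing the all-ones lane pattern into the max produces that NaN payload.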
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  __m512i zero_vec = _mm512_set1_epi32(0);
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
+  auto min_lo = _mm512_min_ps(a_lo, b_lo);
+  auto min_hi = _mm512_min_ps(a_hi, b_hi);
+  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
+  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
+                                                           0xFFFFFFFF));
+  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
+                                                           0xFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm512_or_ps(min_lo, nan_lo);
+  auto o2 = _mm512_or_ps(min_hi, nan_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
+    const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
+  __m512 a_lo, a_hi;
+  __m512 min_lo, min_hi;
+  __m512 max_lo, max_hi;
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
+  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
+  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
+  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
+  __m512 a_lo, a_hi;
+  __m512 max_lo, max_hi;
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
+  auto o1 = _mm512_min_ps(max_lo, a_lo);
+  auto o2 = _mm512_min_ps(max_hi, a_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
+  __m512 a_lo, a_hi;
+  __m512 min_lo, min_hi;
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
+  auto o1 = _mm512_max_ps(min_lo, a_lo);
+  auto o2 = _mm512_max_ps(min_hi, a_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+template <>
+inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
+    auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+inline void convert(const float* src, BFloat16* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
+    __m512 a = _mm512_loadu_ps(&src[i]);
+    __m512 b = _mm512_loadu_ps(&src[i + 16]);
+
+    __m512i bf = cvtfp32_bf16(a, b);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert<BFloat16>(src[i]);
+  }
+}
+
+template <>
+inline void convert(const double* src, BFloat16* dst, int64_t n) {
+  auto load_float = [](const double *src) -> __m512 {
+    // Load one float vector from an array of doubles
+    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
+    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
+    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
+  };
+
+  int64_t i;
+  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
+    __m512 a = load_float(&src[i]);
+    __m512 b = load_float(&src[i + 16]);
+
+    __m512i bf = cvtfp32_bf16(a, b);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert<BFloat16>(src[i]);
+  }
+}
+
+template <>
+Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
+    const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  __m512 c_lo, c_hi;
+  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
+  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
+  cvtbf16_fp32(__m512i(c), c_lo, c_hi);
+  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
+  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
+  return cvtfp32_bf16(o1, o2);
+}
+
+static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) {
+  __m512i r[8];
+  // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15   e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
+  // b0-b15  f0-f15
+  // c0-c15  g0-g15
+  // d0-d15  h0-h15
+  // i0-i15  m0-m15
+  // j0-j15  n0-n15
+  // k0-k15  o0-o15
+  // l0-l15  p0-p15
+#pragma unroll(4)
+  for (int i = 0; i < 4; i++) {
+    r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01);
+    r[i + 4] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01);
+  }
+
+  // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11   e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 e10e11 f10f11
+  // u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15   e4e5 f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15
+  // u2: c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11   g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11
+  // u3: c4c5 d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15   g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15
+  // i j  m n
+  // k l  o p
+#pragma unroll(4)
+  for (int i = 0; i < 8; i += 2) {
+    u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]);
+    u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]);
+  }
+
+  // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9  e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 h8h9
+  // r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11  e2e3 f2f3 g2g3 h2h3 e10e11 f10f11 g10g11 h10h11
+  // r2: a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 c12c13 d12d13
+  // r3: a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15
+  // r4: i j k l m n o p
+  r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
+  r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
+  r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
+  r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
+  r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
+  r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
+  r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
+  r[7] = _mm512_unpackhi_epi64(u[5], u[7]);
+
+  __m512i const1 = _mm512_set_epi32(
+      0x00370035,
+      0x00330031,
+      0x00270025,
+      0x00230021,
+      0x00170015,
+      0x00130011,
+      0x00070005,
+      0x00030001,
+      0x00360034,
+      0x00320030,
+      0x00260024,
+      0x00220020,
+      0x00160014,
+      0x00120010,
+      0x00060004,
+      0x00020000);
+  __m512i const2 = _mm512_set_epi32(
+      0x003f003d,
+      0x003b0039,
+      0x002f002d,
+      0x002b0029,
+      0x001f001d,
+      0x001b0019,
+      0x000f000d,
+      0x000b0009,
+      0x003e003c,
+      0x003a0038,
+      0x002e002c,
+      0x002a0028,
+      0x001e001c,
+      0x001a0018,
+      0x000e000c,
+      0x000a0008);
+  // merge values from two regs
+  // 0-- 1--
+  // 8-- 9--
+  // 2-- 3--
+  // 10-- 11--
+  // 4-- 5--
+  // 12-- 13--
+  // 6-- 7--
+  // 14-- 15--
+#pragma unroll(4)
+  for (int i = 0; i < 4; i++) {
+    u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]);
+    u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]);
+  }
+}
+
+// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16
+// Code referred to FBGEMM:
+// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
+template<>
+inline void transpose_mxn<BFloat16, 16, 16>(
+    const BFloat16* src,
+    int64_t ld_src,
+    BFloat16* dst,
+    int64_t ld_dst) {
+  __m256i t[16];
+  // load from src to registers
+  // a: a0  a1  a2  a3  a4  a5  a6  a7  a8  a9  a10 a11 a12 a13 a14 a15
+  // b: b0  b1  b2  b3  b4  b5  b6  b7  b8  b9  b10 b11 b12 b13 b14 b15
+  // c: c0  c1  c2  c3  c4  c5  c6  c7  c8  c9  c10 c11 c12 c13 c14 c15
+  // d: d0  d1  d2  d3  d4  d5  d6  d7  d8  d9  d10 d11 d12 d13 d14 d15
+  // e: e0  e1  e2  e3  e4  e5  e6  e7  e8  e9  e10 e11 e12 e13 e14 e15
+  // f: f0  f1  f2  f3  f4  f5  f6  f7  f8  f9  f10 f11 f12 f13 f14 f15
+  // g: g0  g1  g2  g3  g4  g5  g6  g7  g8  g9  g10 g11 g12 g13 g14 g15
+  // h: h0  h1  h2  h3  h4  h5  h6  h7  h8  h9  h10 h11 h12 h13 h14 h15
+  // i: i0  i1  i2  i3  i4  i5  i6  i7  i8  i9  i10 i11 i12 i13 i14 i15
+  // j: j0  j1  j2  j3  j4  j5  j6  j7  j8  j9  j10 j11 j12 j13 j14 j15
+  // k: k0  k1  k2  k3  k4  k5  k6  k7  k8  k9  k10 k11 k12 k13 k14 k15
+  // l: l0  l1  l2  l3  l4  l5  l6  l7  l8  l9  l10 l11 l12 l13 l14 l15
+  // m: m0  m1  m2  m3  m4  m5  m6  m7  m8  m9  m10 m11 m12 m13 m14 m15
+  // n: n0  n1  n2  n3  n4  n5  n6  n7  n8  n9  n10 n11 n12 n13 n14 n15
+  // o: o0  o1  o2  o3  o4  o5  o6  o7  o8  o9  o10 o11 o12 o13 o14 o15
+  // p: p0  p1  p2  p3  p4  p5  p6  p7  p8  p9  p10 p11 p12 p13 p14 p15
+#pragma unroll(16)
+  for (int i = 0; i < 16; i++) {
+    t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
+  }
+
+  __m512i u[8];
+  _transpose_mxn_half_16_16(t, u);
+
+#pragma unroll(8)
+  for (int i = 0; i < 8; i++) {
+    _mm256_storeu_si256(
+      reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
+      _mm512_extracti32x8_epi32(u[i], 0x0));
+    _mm256_storeu_si256(
+        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
+        _mm512_extracti32x8_epi32(u[i], 0x01));
+  }
+}
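+
+// Usage sketch (added illustration; the buffer names are hypothetical): transpose
+// a row-major 16x16 bf16 tile into another row-major buffer:
+//   __at_align__ BFloat16 src_tile[16 * 16], dst_tile[16 * 16];
+//   transpose_mxn<BFloat16, 16, 16>(src_tile, /*ld_src=*/16, dst_tile, /*ld_dst=*/16);
+//   // afterwards dst_tile[j * 16 + i] holds src_tile[i * 16 + j]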
+
+// Code referred to FBGEMM:
+// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
+template<>
+inline void transpose_mxn<Half, 16, 16>(
+    const Half* src,
+    int64_t ld_src,
+    Half* dst,
+    int64_t ld_dst) {
+  __m256i t[16];
+  // load from src to registers
+  // Same matrix indices as above transpose_mxn
+#pragma unroll(16)
+  for (int i = 0; i < 16; i++) {
+    t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
+  }
+
+  __m512i u[8];
+  _transpose_mxn_half_16_16(t, u);
+
+#pragma unroll(8)
+  for (int i = 0; i < 8; i++) {
+    _mm256_storeu_si256(
+      reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
+      _mm512_extracti32x8_epi32(u[i], 0x0));
+    _mm256_storeu_si256(
+        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
+        _mm512_extracti32x8_epi32(u[i], 0x01));
+  }
+}
+
+static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) {
+  // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59
+  // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63
+  // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123
+  // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127
+  // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 ... 187
+  // t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 175 148 ... 191
+  // t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 234 203 235 208 ... 251
+  // t[7]: 196 228 197 229 198 230 199 231 204 236 205 237 206 238 207 239 212 ... 255
+  // t[8]: 256 288 257 289 258 290 259 291 264 296 265 297 266 298 267 299 272 ... 315
+  // t[9]: 260 292 261 293 262 294 263 295 268 300 269 301 270 302 271 303 276 ... 319
+  // t[10]: 320 352 321 353 322 354 323 355 328 360 329 361 330 362 331 363 336 ... 379
+  // t[11]: 324 356 325 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383
+  // t[12]: 384 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443
+  // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 ... 447
+  // t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 491 464 ... 507
+  // t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 494 463 495 468 ... 511
+  // t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555 528 ... 571
+  // ...
+  // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 980 ... 1023
+#pragma unroll(16)
+  for (int i = 0; i < 16; ++i) {
+    d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]);
+    d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]);
+  }
+
+  // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121
+  // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123
+  // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125
+  // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127
+  // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 ... 249
+  // t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 235 146 ... 251
+  // t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 173 205 237 148 ... 253
+  // t[7]: 134 166 198 230 135 167 199 231 142 174 206 238 143 175 207 239 150 ... 255
+  // t[8]: 256 288 320 352 257 289 321 353 264 296 328 360 265 297 329 361 272 ... 377
+  // t[9]: 258 290 322 354 259 291 323 355 266 298 330 362 267 299 331 363 274 ... 379
+  // t[10]: 260 292 324 356 261 293 325 357 268 300 332 364 269 301 333 365 276 ... 381
+  // t[11]: 262 294 326 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383
+  // t[12]: 384 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505
+  // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 ... 507
+  // t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 493 404 ... 509
+  // t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 431 463 495 406 ... 511
+  // t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617 528 ... 633
+  // ...
+  // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 918 ... 1023
+#pragma unroll(8)
+  for (int i = 0; i < 8; ++i) {
+    r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]);
+    r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]);
+    r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]);
+    r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]);
+  }
+
+  // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248
+  // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249
+  // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250
+  // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251
+  // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252
+  // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253
+  // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254
+  // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255
+  // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 ... 504
+  // t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 489 273 ... 505
+  // t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 426 458 490 274 ... 506
+  // t[11]: 259 291 323 355 387 419 451 483 267 299 331 363 395 427 459 491 275 ... 507
+  // t[12]: 260 292 324 356 388 420 452 484 268 300 332 364 396 428 460 492 276 ... 508
+  // t[13]: 261 293 325 357 389 421 453 485 269 301 333 365 397 429 461 493 277 ... 509
+  // t[14]: 262 294 326 358 390 422 454 486 270 302 334 366 398 430 462 494 278 ... 510
+  // t[15]: 263 295 327 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511
+  // t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760
+  // ...
+  // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 ... 1023
+#pragma unroll(4)
+  for (int i = 0; i < 4; ++i) {
+    d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]);
+    d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]);
+    d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]);
+    d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]);
+    d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]);
+    d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]);
+    d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]);
+    d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]);
+  }
+
+  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496
+  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497
+  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498
+  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499
+  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... 500
+  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 ... 501
+  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 22 ... 502
+  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 23 ... 503
+  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 24 ... 504
+  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 25 ... 505
+  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 26 ... 506
+  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 27 ... 507
+  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 28 ... 508
+  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 29 ... 509
+  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 30 ... 510
+  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 31 ... 511
+  // t[16]: 512 544 576 608 640 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008
+  // ...
+  // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 ... 1023
+  __m512i const1 = _mm512_set_epi64(
+      0x000000000000000d,
+      0x000000000000000c,
+      0x0000000000000005,
+      0x0000000000000004,
+      0x0000000000000009,
+      0x0000000000000008,
+      0x0000000000000001,
+      0x0000000000000000);
+  __m512i const2 = _mm512_set_epi64(
+      0x000000000000000f,
+      0x000000000000000e,
+      0x0000000000000007,
+      0x0000000000000006,
+      0x000000000000000b,
+      0x000000000000000a,
+      0x0000000000000003,
+      0x0000000000000002);
+#pragma unroll(8)
+  for (int i = 0; i < 8; ++i) {
+    r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/const1, d[i + 8]);
+    r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/const2, d[i + 8]);
+    r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const1, d[i + 24]);
+    r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const2, d[i + 24]);
+  }
+
+  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 ... 992
+  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 513 545 ... 993
+  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 514 546 ... 994
+  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 515 547 ... 995
+  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 516 548 ... 996
+  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 517 549 ... 997
+  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 518 550 ... 998
+  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999
+  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000
+  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 ... 1001
+  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 522 554 ... 1002
+  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 523 555 ... 1003
+  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 524 556 ... 1004
+  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005
+  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006
+  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... 1007
+  // t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 528 560 ... 1008
+  // ...
+  // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 ... 1023
+  __m512i const3 = _mm512_set_epi64(
+      0x000000000000000b,
+      0x000000000000000a,
+      0x0000000000000009,
+      0x0000000000000008,
+      0x0000000000000003,
+      0x0000000000000002,
+      0x0000000000000001,
+      0x0000000000000000);
+  __m512i const4 = _mm512_set_epi64(
+      0x000000000000000f,
+      0x000000000000000e,
+      0x000000000000000d,
+      0x000000000000000c,
+      0x0000000000000007,
+      0x0000000000000006,
+      0x0000000000000005,
+      0x0000000000000004);
+#pragma unroll(16)
+  for (int i = 0; i < 16; ++i) {
+    d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/const3, r[i + 16]);
+    d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/const4, r[i + 16]);
+  }
+}
+
+// Code referred to FBGEMM:
+// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
+template<>
+inline void transpose_mxn<BFloat16, 32, 32>(
+    const BFloat16* src,
+    int64_t ld_src,
+    BFloat16* dst,
+    int64_t ld_dst) {
+  // Load from memory
+  __m512i r[32];
+#pragma unroll(32)
+  for (int i = 0; i < 32; ++i) {
+    r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i* ld_src));
+  }
+
+  __m512i d[32];
+  _transpose_mxn_half_32_32(r, d);
+
+  // Store to dst
+#pragma unroll(32)
+  for (int i = 0; i < 32; ++i) {
+    _mm512_storeu_si512(dst + i* ld_dst, d[i]);
+  }
+}
+
+template<>
+inline void transpose_mxn<Half, 32, 32>(
+    const Half* src,
+    int64_t ld_src,
+    Half* dst,
+    int64_t ld_dst) {
+  // Load from memory
+  __m512i r[32];
+#pragma unroll(32)
+  for (int i = 0; i < 32; ++i) {
+    r[i] = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i* ld_src));
+  }
+
+  __m512i d[32];
+  _transpose_mxn_half_32_32(r, d);
+
+  // Store to dst
+#pragma unroll(32)
+  for (int i = 0; i < 32; ++i) {
+    _mm512_storeu_si512(dst + i* ld_dst, d[i]);
+  }
+}
+
+template <>
+class Vectorized<Half>: public Vectorized16<Half> {
+public:
+  using Vectorized16::Vectorized16;
+
+  Vectorized<Half> frac() const;
+
+  Vectorized<Half> eq(const Vectorized<Half>& other) const;
+  Vectorized<Half> ne(const Vectorized<Half>& other) const;
+  Vectorized<Half> gt(const Vectorized<Half>& other) const;
+  Vectorized<Half> ge(const Vectorized<Half>& other) const;
+  Vectorized<Half> lt(const Vectorized<Half>& other) const;
+  Vectorized<Half> le(const Vectorized<Half>& other) const;
+};
+
+Vectorized<Half> inline operator+(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
+}
+Vectorized<Half> inline operator-(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
+}
+Vectorized<Half> inline operator*(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
+}
+Vectorized<Half> inline operator/(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
+}
+
+Vectorized<Half> inline operator&(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return _mm512_and_si512(a, b);
+}
+Vectorized<Half> inline operator|(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return _mm512_or_si512(a, b);
+}
+Vectorized<Half> inline operator^(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  return _mm512_xor_si512(a, b);
+}
+
+inline Vectorized<Half> Vectorized<Half>::eq(const Vectorized<Half>& other) const {
+  return (*this == other) & Vectorized<Half>(1.0f);
+}
+
+inline Vectorized<Half> Vectorized<Half>::ne(const Vectorized<Half>& other) const {
+  return (*this != other) & Vectorized<Half>(1.0f);
+}
+
+inline Vectorized<Half> Vectorized<Half>::gt(const Vectorized<Half>& other) const {
+  return (*this > other) & Vectorized<Half>(1.0f);
+}
+
+inline Vectorized<Half> Vectorized<Half>::ge(const Vectorized<Half>& other) const {
+  return (*this >= other) & Vectorized<Half>(1.0f);
+}
+
+inline Vectorized<Half> Vectorized<Half>::lt(const Vectorized<Half>& other) const {
+  return (*this < other) & Vectorized<Half>(1.0f);
+}
+
+inline Vectorized<Half> Vectorized<Half>::le(const Vectorized<Half>& other) const {
+  return (*this <= other) & Vectorized<Half>(1.0f);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized<Half> Vectorized<Half>::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<Half> inline maximum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
+  auto max_lo = _mm512_max_ps(a_lo, b_lo);
+  auto max_hi = _mm512_max_ps(a_hi, b_hi);
+  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
+  auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask));
+  auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask));
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm512_or_ps(max_lo, nan_lo);
+  auto o2 = _mm512_or_ps(max_hi, nan_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<Half> inline minimum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  __m512i zero_vec = _mm512_set1_epi32(0);
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
+  auto min_lo = _mm512_min_ps(a_lo, b_lo);
+  auto min_hi = _mm512_min_ps(a_hi, b_hi);
+  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
+  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
+  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
+                                                           0xFFFFFFFF));
+  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
+                                                           0xFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  auto o1 = _mm512_or_ps(min_lo, nan_lo);
+  auto o2 = _mm512_or_ps(min_hi, nan_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized<Half> inline clamp(const Vectorized<Half>& a,
+    const Vectorized<Half>& min, const Vectorized<Half>& max) {
+  __m512 a_lo, a_hi;
+  __m512 min_lo, min_hi;
+  __m512 max_lo, max_hi;
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
+  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
+  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
+  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized<Half> inline clamp_max(const Vectorized<Half>& a, const Vectorized<Half>& max) {
+  __m512 a_lo, a_hi;
+  __m512 max_lo, max_hi;
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
+  auto o1 = _mm512_min_ps(max_lo, a_lo);
+  auto o2 = _mm512_min_ps(max_hi, a_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+Vectorized<Half> inline clamp_min(const Vectorized<Half>& a, const Vectorized<Half>& min) {
+  __m512 a_lo, a_hi;
+  __m512 min_lo, min_hi;
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
+  auto o1 = _mm512_max_ps(min_lo, a_lo);
+  auto o2 = _mm512_max_ps(min_hi, a_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+template <>
+inline void convert(const Half* src, Half* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<Half>::size()); i += Vectorized<Half>::size()) {
+    auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+inline void convert(const float* src, Half* dst, int64_t n) {
+  int64_t i;
+  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
+    __m512 a = _mm512_loadu_ps(&src[i]);
+    __m512 b = _mm512_loadu_ps(&src[i + 16]);
+
+    __m512i bf = cvtfp32_fp16(a, b);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert<Half>(src[i]);
+  }
+}
+
+template <>
+inline void convert(const double* src, Half* dst, int64_t n) {
+  auto load_float = [](const double *src) -> __m512 {
+    // Load one float vector from an array of doubles
+    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
+    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
+    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
+  };
+
+  int64_t i;
+  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
+    __m512 a = load_float(&src[i]);
+    __m512 b = load_float(&src[i + 16]);
+
+    __m512i bf = cvtfp32_fp16(a, b);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
+  }
+  for (; i < n; i++) {
+    dst[i] = c10::convert<Half>(src[i]);
+  }
+}
+
+template <>
+Vectorized<Half> inline fmadd(const Vectorized<Half>& a,
+    const Vectorized<Half>& b, const Vectorized<Half>& c) {
+  __m512 a_lo, a_hi;
+  __m512 b_lo, b_hi;
+  __m512 c_lo, c_hi;
+  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
+  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
+  cvtfp16_fp32(__m512i(c), c_lo, c_hi);
+  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
+  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
+  return cvtfp32_fp16(o1, o2);
+}
+
+#define CONVERT_VECTORIZED_INIT(type, name) \
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
+  __m512 o1, o2; \
+  cvt_to_fp32<type>(__m512i(a), o1, o2); \
+  return std::make_tuple(o1, o2); \
+} \
+\
+inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
+ return cvt_from_fp32<type>(__m512(a), __m512(b)); \
+}
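+// For reference (added note): the invocations below generate, e.g. for
+// (BFloat16, bfloat16), the pair
+//   std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>&);
+//   Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>&, const Vectorized<float>&);
+// i.e. one 32-lane reduced-precision vector maps to two 16-lane fp32 vectors and back.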
+CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
+CONVERT_VECTORIZED_INIT(Half, half);
+
+#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+#define CONVERT_NON_VECTORIZED_INIT(type, name) \
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
+  constexpr int64_t K = Vectorized<type>::size(); \
+  __at_align__ float arr[K]; \
+  __at_align__ type arr2[K]; \
+  a.store(arr2); \
+  for (const auto k : c10::irange(K)) { \
+    arr[k] = c10::convert<float>(arr2[k]); \
+  } \
+  return std::make_tuple( \
+      Vectorized<float>::loadu(arr), \
+      Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
+} \
+\
+inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
+  constexpr int64_t K = Vectorized<type>::size(); \
+  __at_align__ float arr[K]; \
+  __at_align__ type arr2[K]; \
+  a.store(arr); \
+  b.store(arr + Vectorized<float>::size()); \
+  for (const auto k : c10::irange(K)) { \
+    arr2[k] = c10::convert<type>(arr[k]); \
+  } \
+  return Vectorized<type>::loadu(arr2); \
+}
+CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
+CONVERT_NON_VECTORIZED_INIT(Half, half);
+
+#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#define LOAD_FP32_VECTORIZED_INIT(type, name) \
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
+  auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
+  __m512 out_values; \
+  cvt_to_fp32<type>(values, out_values); \
+  out = out_values; \
+} \
+\
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
+  auto vec = Vectorized<type>::loadu(data); \
+  __m512 out1_values, out2_values; \
+  cvt_to_fp32<type>(vec, out1_values, out2_values); \
+  out1 = out1_values; \
+  out2 = out2_values; \
+}
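+// Usage sketch (added illustration; variable names are hypothetical):
+//   const BFloat16* p = ...;
+//   Vectorized<float> lo, hi;
+//   load_fp32_from_bf16(p, lo, hi);  // lo holds p[0..15] as fp32, hi holds p[16..31]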
+LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
+LOAD_FP32_VECTORIZED_INIT(Half, fp16);
+
+#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
+  __at_align__ float values[Vectorized<float>::size()]; \
+  for (const auto k : c10::irange(Vectorized<float>::size())) { \
+    values[k] = data[k]; \
+  } \
+  out = Vectorized<float>::loadu(values); \
+} \
+\
+inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
+  load_fp32_from_##name(data, out1); \
+  data += Vectorized<float>::size(); \
+  load_fp32_from_##name(data, out2); \
+}
+LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
+LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);
+
+#endif
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h
new file mode 100644
index 0000000000000000000000000000000000000000..4455017576f4cc640ef92a6a1024cf2e31a0746c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -0,0 +1,512 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <c10/util/complex.h>
+#include <c10/util/irange.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+template <> class Vectorized<c10::complex<double>> {
+private:
+  __m512d values;
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+public:
+  using value_type = c10::complex<double>;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 4;
+  }
+  Vectorized() {}
+  Vectorized(__m512d v) : values(v) {}
+  Vectorized(c10::complex<double> val) {
+    double real_value = val.real();
+    double imag_value = val.imag();
+    values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value,
+                            real_value, imag_value, real_value, imag_value);
+  }
+  Vectorized(c10::complex<double> val1, c10::complex<double> val2,
+            c10::complex<double> val3, c10::complex<double> val4) {
+    values = _mm512_setr_pd(val1.real(), val1.imag(),
+                            val2.real(), val2.imag(),
+                            val3.real(), val3.imag(),
+                            val4.real(), val4.imag());
+  }
+  operator __m512d() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a,
+                                                const Vectorized<c10::complex<double>>& b) {
+     // convert c10::complex index mask to V index mask: xy -> xxyy
+    // NOLINTNEXTLINE(clang-diagnostic-warning)
+    switch (mask) {
+      case 0:
+        return a;
+      case 1:
+        return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011
+      case 2:
+        return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100
+      case 3:
+        return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111
+      case 4:
+        return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000
+      case 5:
+        return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011
+      case 6:
+        return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100
+      case 7:
+        return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111
+      case 8:
+        return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000
+      case 9:
+        return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011
+      case 10:
+        return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100
+      case 11:
+        return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111
+      case 12:
+        return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000
+      case 13:
+        return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011
+      case 14:
+        return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100
+      case 15:
+        return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111
+    }
+    return b;
+  }
+  static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a,
+                                                 const Vectorized<c10::complex<double>>& b,
+                                                 const Vectorized<c10::complex<double>>& mask) {
+    // convert c10::complex index mask to V index mask: xy -> xxyy
+    auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values);
+    auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
+    auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_pd(mmask, a.values, b.values);
+  }
+  template <typename step_t>
+  static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0.,
+                                                 step_t step = static_cast<step_t>(1)) {
+    return Vectorized<c10::complex<double>>(base,
+                                            base + c10::complex<double>(1)*step,
+                                            base + c10::complex<double>(2)*step,
+                                            base + c10::complex<double>(3)*step);
+  }
+  static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a,
+                                              const Vectorized<c10::complex<double>>& b,
+                                              int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
+
+    __at_align__ double tmp_values[2*size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(2*size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values,
+        reinterpret_cast<const double*>(ptr),
+        count * sizeof(c10::complex<double>));
+    return _mm512_load_pd(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
+    } else if (count > 0) {
+      double tmp_values[2*size()];
+      _mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
+    }
+  }
+  const c10::complex<double>& operator[](int idx) const  = delete;
+  c10::complex<double>& operator[](int idx) = delete;
+  Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
+    __at_align__ c10::complex<double> tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  // AVX512 doesn't have horizontal add & horizontal sub instructions.
+  // TODO: hadd_pd() & hsub_pd() may have scope for improvement.
+  static inline __m512d hadd_pd(__m512d a, __m512d b) {
+  __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
+  __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
+  return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
+                       _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
+  }
+  static inline __m512d hsub_pd(__m512d a, __m512d b) {
+  __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0);
+  __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1);
+  return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
+                       _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
+  }
+  __m512d abs_2_() const {
+    auto val_2 = _mm512_mul_pd(values, values);     // a*a     b*b
+    return hadd_pd(val_2, val_2);            // a*a+b*b a*a+b*b
+  }
+  __m512d abs_() const {
+    auto real = _mm512_movedup_pd(values);        // real real
+    // movehdup_pd does not exist...
+    auto imag = _mm512_permute_pd(values, 0xff);  // imag imag
+    return Sleef_hypotd8_u05(real, imag);         // abs  abs
+  }
+  Vectorized<c10::complex<double>> abs() const {
+    const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    return _mm512_and_pd(abs_(), real_mask);        // abs     0
+  }
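+  // Added example (illustrative): for a lane holding z = 3 + 4i, abs_() yields
+  // hypot(3, 4) = 5 in both slots of the pair, and abs() masks the imaginary
+  // slot, producing the complex value 5 + 0i.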
+  __m512d angle_() const {
+    //angle = atan2(b/a)
+    auto b_a = _mm512_permute_pd(values, 0x55);     // b        a
+    return Sleef_atan2d8_u10(values, b_a);          // 90-angle angle
+  }
+  Vectorized<c10::complex<double>> angle() const {
+    const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    auto angle = _mm512_permute_pd(angle_(), 0x55); // angle    90-angle
+    return _mm512_and_pd(angle, real_mask);         // angle    0
+  }
+  Vectorized<c10::complex<double>> sgn() const {
+    auto abs = abs_();
+    auto zero = _mm512_setzero_pd();
+    auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
+    auto div = values / abs;
+    return _mm512_mask_blend_pd(mask, div, zero);
+  }
+  __m512d real_() const {
+    const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
+                                                                    0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
+    return _mm512_and_pd(values, real_mask);
+  }
+  Vectorized<c10::complex<double>> real() const {
+    return real_();
+  }
+  __m512d imag_() const {
+    const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
+                                                                    0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
+                                                                    0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
+                                                                    0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
+    return _mm512_and_pd(values, imag_mask);
+  }
+  Vectorized<c10::complex<double>> imag() const {
+    return _mm512_permute_pd(imag_(), 0x55);           //b        a
+  }
+  __m512d conj_() const {
+    const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+    return _mm512_xor_pd(values, sign_mask);           // a       -b
+  }
+  Vectorized<c10::complex<double>> conj() const {
+    return conj_();
+  }
+  Vectorized<c10::complex<double>> log() const {
+    // Most trigonometric ops use the log() op to improve complex number performance.
+    return map(std::log);
+  }
+  Vectorized<c10::complex<double>> log2() const {
+    const __m512d log2_ = _mm512_set1_pd(std::log(2));
+    return _mm512_div_pd(log(), log2_);
+  }
+  Vectorized<c10::complex<double>> log10() const {
+    const __m512d log10_ = _mm512_set1_pd(std::log(10));
+    return _mm512_div_pd(log(), log10_);
+  }
+  Vectorized<c10::complex<double>> log1p() const {
+    return map(std::log1p);
+  }
+  Vectorized<c10::complex<double>> asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+    const __m512d one = _mm512_set1_pd(1);
+
+    auto conj = conj_();
+    auto b_a = _mm512_permute_pd(conj, 0x55);                         //-b        a
+    auto ab = _mm512_mul_pd(conj, b_a);                               //-ab       -ab
+    auto im = _mm512_add_pd(ab, ab);                                  //-2ab      -2ab
+
+    auto val_2 = _mm512_mul_pd(values, values);                       // a*a      b*b
+    auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55));  // a*a-b*b  b*b-a*a
+    re = _mm512_sub_pd(one, re);
+
+    auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt();         //sqrt(re + i*im)
+    auto ln = Vectorized(_mm512_add_pd(b_a, root)).log();                 //ln(iz + sqrt())
+    return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj();         //-i*ln()
+  }
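+  // Added sanity check (illustrative): for z = 0 the formula reduces to
+  // -i*ln(0 + sqrt(1 - 0)) = -i*ln(1) = 0, matching std::asin(0).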
+  Vectorized<c10::complex<double>> acos() const {
+    // acos(x) = pi/2 - asin(x)
+    constexpr auto pi_2d = c10::pi<double> / 2;
+    const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0);
+    return _mm512_sub_pd(pi_2, asin());
+  }
+  Vectorized<c10::complex<double>> atan() const;
+  Vectorized<c10::complex<double>> atanh() const {
+    return map(std::atanh);
+  }
+  Vectorized<c10::complex<double>> exp() const {
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expd8_u10(values);                               //exp(a)           exp(b)
+    exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55));   //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosd8_u10(values);                        //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55),
+                                   sin_cos.x);                  //cos(b)           sin(b)
+    return _mm512_mul_pd(exp, cos_sin);
+  }
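+  // Added example (illustrative): exp(0 + pi*i) = exp(0)*(cos(pi) + i*sin(pi)),
+  // so that lane evaluates to approximately -1 + 0i (Euler's formula).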
+  Vectorized<c10::complex<double>> exp2() const {
+    // Use identity 2**x = exp(log(2) * x)
+    const __m512d ln_2 = _mm512_set1_pd(c10::ln_2<double>);
+    Vectorized<c10::complex<double>> scaled_values = _mm512_mul_pd(values, ln_2);
+    return scaled_values.exp();
+  }
+  Vectorized<c10::complex<double>> expm1() const {
+    return map(std::expm1);
+  }
+  Vectorized<c10::complex<double>> sin() const {
+    return map(std::sin);
+  }
+  Vectorized<c10::complex<double>> sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized<c10::complex<double>> cos() const {
+    return map(std::cos);
+  }
+  Vectorized<c10::complex<double>> cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized<c10::complex<double>> ceil() const {
+    return _mm512_ceil_pd(values);
+  }
+  Vectorized<c10::complex<double>> floor() const {
+    return _mm512_floor_pd(values);
+  }
+  Vectorized<c10::complex<double>> neg() const {
+    auto zero = _mm512_setzero_pd();
+    return _mm512_sub_pd(zero, values);
+  }
+  Vectorized<c10::complex<double>> round() const {
+    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized<c10::complex<double>> tan() const {
+    return map(std::tan);
+  }
+  Vectorized<c10::complex<double>> tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized<c10::complex<double>> trunc() const {
+    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized<c10::complex<double>> sqrt() const {
+    return map(std::sqrt);
+  }
+  Vectorized<c10::complex<double>> reciprocal() const;
+  Vectorized<c10::complex<double>> rsqrt() const {
+    return sqrt().reciprocal();
+  }
+  Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
+    __at_align__ c10::complex<double> x_tmp[size()];
+    __at_align__ c10::complex<double> y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
+    auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+  Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
+    auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+  Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
+  Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
+};
+
+template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a,
+                                                              const Vectorized<c10::complex<double>> &b) {
+  return _mm512_add_pd(a, b);
+}
+
+template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a,
+                                                              const Vectorized<c10::complex<double>> &b) {
+  return _mm512_sub_pd(a, b);
+}
+
+template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a,
+                                                              const Vectorized<c10::complex<double>> &b) {
+  //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+  const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto ac_bd = _mm512_mul_pd(a, b);         //ac       bd
+
+  auto d_c = _mm512_permute_pd(b, 0x55);    //d        c
+  d_c = _mm512_xor_pd(sign_mask, d_c);      //d       -c
+  auto ad_bc = _mm512_mul_pd(a, d_c);       //ad      -bc
+
+  auto ret = Vectorized<c10::complex<double>>::hsub_pd(ac_bd, ad_bc);  //ac - bd  ad + bc
+  return ret;
+}
+
+template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a,
+                                                              const Vectorized<c10::complex<double>> &b) {
+  //re + im*i = (a + bi)  / (c + di)
+  auto mask = _mm512_set1_pd(-0.f);
+  auto fabs_cd = _mm512_andnot_pd(mask, b);     // |c|    |d|
+  auto fabs_dc = _mm512_permute_pd(fabs_cd, 0x55);   // |d|    |c|
+  auto scale = _mm512_rcp14_pd(_mm512_max_pd(fabs_cd, fabs_dc));  // 1/sc     1/sc
+  auto a2 = _mm512_mul_pd(a, scale);         // a/sc     b/sc
+  auto b2 = _mm512_mul_pd(b, scale);         // c/sc     d/sc
+  auto acbd2 = _mm512_mul_pd(a2, b2);
+
+  const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
+  auto dc2 = _mm512_permute_pd(b2, 0x55);    // d/sc         c/sc
+  dc2 = _mm512_xor_pd(sign_mask, dc2);       // -d/|c,d|        c/sc
+  auto adbc2 = _mm512_mul_pd(a2, dc2);       //-ad/sc^2      bc/sc^2
+  auto res2 = Vectorized<c10::complex<double>>::hadd_pd(acbd2, adbc2);  //(ac+bd)/sc^2  (bc-ad)/sc^2
+
+  // get the denominator
+  auto denom2 = Vectorized<c10::complex<double>>(b2).abs_2_();  // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
+  res2 = _mm512_div_pd(res2, denom2);
+  return res2;
+}
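+// The operands are pre-scaled by ~1/max(|c|, |d|) (via _mm512_rcp14_pd) before the
+// multiplies so that c^2 + d^2 stays in range; the scale factor appears in both the
+// numerator and the abs_2_() denominator, so it cancels in the final division.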
+
+// reciprocal. Implement this here so we can use multiplication.
+inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
+  //re + im*i = (a + bi)  / (c + di)
+  //re = (ac + bd)/abs_2() = c/abs_2()
+  //im = (bc - ad)/abs_2() = d/abs_2()
+  const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto c_d = _mm512_xor_pd(sign_mask, values);    //c       -d
+  return _mm512_div_pd(c_d, abs_2_());
+}
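+// i.e. 1/(c + di) = (c - di) / (c^2 + d^2): the sign mask conjugates the value and
+// abs_2_() supplies the squared magnitude in both lanes of each pair.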
+
+inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
+  // atan(x) = i/2 * ln((i + z)/(i - z))
+  const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
+  const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
+
+  auto sum = Vectorized(_mm512_add_pd(i, values));                      // a        1+b
+  auto sub = Vectorized(_mm512_sub_pd(i, values));                      // -a       1-b
+  auto ln = (sum/sub).log();                                        // ln((i + z)/(i - z))
+  return i_half*ln;                                                 // i/2*ln()
+}
+
+template <>
+Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a,
+                                                const Vectorized<c10::complex<double>>& b) {
+  auto zero_vec = _mm512_set1_epi64(0);
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ);
+  auto max = _mm512_mask_blend_pd(mask, a, b);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
+  auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
+                                      0xFFFFFFFFFFFFFFFF);
+  return _mm512_or_pd(max, _mm512_castsi512_pd(isnan));
+}
+
+template <>
+Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a,
+                                                const Vectorized<c10::complex<double>>& b) {
+  auto zero_vec = _mm512_set1_epi64(0);
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ);
+  auto min = _mm512_mask_blend_pd(mask, a, b);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q);
+  auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask,
+                                      0xFFFFFFFFFFFFFFFF);
+  return _mm512_or_pd(min, _mm512_castsi512_pd(isnan));
+}
+
+template <>
+Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a,
+                                                  const Vectorized<c10::complex<double>>& b) {
+  return _mm512_and_pd(a, b);
+}
+
+template <>
+Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a,
+                                                  const Vectorized<c10::complex<double>>& b) {
+  return _mm512_or_pd(a, b);
+}
+
+template <>
+Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a,
+                                                  const Vectorized<c10::complex<double>>& b) {
+  return _mm512_xor_pd(a, b);
+}
+
+inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
+  auto eq = (*this == other);  // compares real and imag individually
+  // If both real numbers and imag numbers are equal, then the complex numbers are equal
+  return (eq.real() & eq.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
+}
+
+inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
+  auto ne = (*this != other);  // compares real and imag individually
+  // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+  return (ne.real() | ne.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
+}
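+// Note: unlike operator==/operator!= above, which return full-width bit masks per
+// scalar lane, eq()/ne() collapse the per-component comparison and AND it with 1.0,
+// so each element ends up as 1.0 + 0i when the condition holds and 0.0 + 0i otherwise.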
+
+#endif
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h
new file mode 100644
index 0000000000000000000000000000000000000000..14dfb24e3efeec03417796a6ab15d26d8f2c84b1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -0,0 +1,1018 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <c10/util/complex.h>
+#include <c10/util/irange.h>
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+template <> class Vectorized<c10::complex<float>> {
+private:
+  __m512 values;
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+public:
+  using value_type = c10::complex<float>;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+  Vectorized(__m512 v) : values(v) {}
+  Vectorized(c10::complex<float> val) {
+    float real_value = val.real();
+    float imag_value = val.imag();
+    values = _mm512_setr_ps(real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value,
+                            real_value, imag_value);
+  }
+  Vectorized(c10::complex<float> val1, c10::complex<float> val2,
+            c10::complex<float> val3, c10::complex<float> val4,
+            c10::complex<float> val5, c10::complex<float> val6,
+            c10::complex<float> val7, c10::complex<float> val8) {
+    values = _mm512_setr_ps(val1.real(), val1.imag(),
+                            val2.real(), val2.imag(),
+                            val3.real(), val3.imag(),
+                            val4.real(), val4.imag(),
+                            val5.real(), val5.imag(),
+                            val6.real(), val6.imag(),
+                            val7.real(), val7.imag(),
+                            val8.real(), val8.imag());
+  }
+  operator __m512() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<c10::complex<float>> blend(const Vectorized<c10::complex<float>>& a,
+                                               const Vectorized<c10::complex<float>>& b) {
+    // convert c10::complex index mask to V index mask: xy -> xxyy
+    static_assert(mask > -1 && mask < 256, "Unexpected mask value");
+    // The compiler would hopefully convert this switch condition
+    // into a jump table
+    switch (mask) {
+      case 0:
+        return a;
+      case 1:
+        return _mm512_mask_blend_ps(0x03, a.values, b.values);
+      case 2:
+        return _mm512_mask_blend_ps(0x0C, a.values, b.values);
+      case 3:
+        return _mm512_mask_blend_ps(0x0F, a.values, b.values);
+      case 4:
+        return _mm512_mask_blend_ps(0x30, a.values, b.values);
+      case 5:
+        return _mm512_mask_blend_ps(0x33, a.values, b.values);
+      case 6:
+        return _mm512_mask_blend_ps(0x3C, a.values, b.values);
+      case 7:
+        return _mm512_mask_blend_ps(0x3F, a.values, b.values);
+      case 8:
+        return _mm512_mask_blend_ps(0xC0, a.values, b.values);
+      case 9:
+        return _mm512_mask_blend_ps(0xC3, a.values, b.values);
+      case 10:
+        return _mm512_mask_blend_ps(0xCC, a.values, b.values);
+      case 11:
+        return _mm512_mask_blend_ps(0xCF, a.values, b.values);
+      case 12:
+        return _mm512_mask_blend_ps(0xF0, a.values, b.values);
+      case 13:
+        return _mm512_mask_blend_ps(0xF3, a.values, b.values);
+      case 14:
+        return _mm512_mask_blend_ps(0xFC, a.values, b.values);
+      case 15:
+        return _mm512_mask_blend_ps(0xFF, a.values, b.values);
+      case 16:
+        return _mm512_mask_blend_ps(0x300, a.values, b.values);
+      case 17:
+        return _mm512_mask_blend_ps(0x303, a.values, b.values);
+      case 18:
+        return _mm512_mask_blend_ps(0x30C, a.values, b.values);
+      case 19:
+        return _mm512_mask_blend_ps(0x30F, a.values, b.values);
+      case 20:
+        return _mm512_mask_blend_ps(0x330, a.values, b.values);
+      case 21:
+        return _mm512_mask_blend_ps(0x333, a.values, b.values);
+      case 22:
+        return _mm512_mask_blend_ps(0x33C, a.values, b.values);
+      case 23:
+        return _mm512_mask_blend_ps(0x33F, a.values, b.values);
+      case 24:
+        return _mm512_mask_blend_ps(0x3C0, a.values, b.values);
+      case 25:
+        return _mm512_mask_blend_ps(0x3C3, a.values, b.values);
+      case 26:
+        return _mm512_mask_blend_ps(0x3CC, a.values, b.values);
+      case 27:
+        return _mm512_mask_blend_ps(0x3CF, a.values, b.values);
+      case 28:
+        return _mm512_mask_blend_ps(0x3F0, a.values, b.values);
+      case 29:
+        return _mm512_mask_blend_ps(0x3F3, a.values, b.values);
+      case 30:
+        return _mm512_mask_blend_ps(0x3FC, a.values, b.values);
+      case 31:
+        return _mm512_mask_blend_ps(0x3FF, a.values, b.values);
+      case 32:
+        return _mm512_mask_blend_ps(0xC00, a.values, b.values);
+      case 33:
+        return _mm512_mask_blend_ps(0xC03, a.values, b.values);
+      case 34:
+        return _mm512_mask_blend_ps(0xC0C, a.values, b.values);
+      case 35:
+        return _mm512_mask_blend_ps(0xC0F, a.values, b.values);
+      case 36:
+        return _mm512_mask_blend_ps(0xC30, a.values, b.values);
+      case 37:
+        return _mm512_mask_blend_ps(0xC33, a.values, b.values);
+      case 38:
+        return _mm512_mask_blend_ps(0xC3C, a.values, b.values);
+      case 39:
+        return _mm512_mask_blend_ps(0xC3F, a.values, b.values);
+      case 40:
+        return _mm512_mask_blend_ps(0xCC0, a.values, b.values);
+      case 41:
+        return _mm512_mask_blend_ps(0xCC3, a.values, b.values);
+      case 42:
+        return _mm512_mask_blend_ps(0xCCC, a.values, b.values);
+      case 43:
+        return _mm512_mask_blend_ps(0xCCF, a.values, b.values);
+      case 44:
+        return _mm512_mask_blend_ps(0xCF0, a.values, b.values);
+      case 45:
+        return _mm512_mask_blend_ps(0xCF3, a.values, b.values);
+      case 46:
+        return _mm512_mask_blend_ps(0xCFC, a.values, b.values);
+      case 47:
+        return _mm512_mask_blend_ps(0xCFF, a.values, b.values);
+      case 48:
+        return _mm512_mask_blend_ps(0xF00, a.values, b.values);
+      case 49:
+        return _mm512_mask_blend_ps(0xF03, a.values, b.values);
+      case 50:
+        return _mm512_mask_blend_ps(0xF0C, a.values, b.values);
+      case 51:
+        return _mm512_mask_blend_ps(0xF0F, a.values, b.values);
+      case 52:
+        return _mm512_mask_blend_ps(0xF30, a.values, b.values);
+      case 53:
+        return _mm512_mask_blend_ps(0xF33, a.values, b.values);
+      case 54:
+        return _mm512_mask_blend_ps(0xF3C, a.values, b.values);
+      case 55:
+        return _mm512_mask_blend_ps(0xF3F, a.values, b.values);
+      case 56:
+        return _mm512_mask_blend_ps(0xFC0, a.values, b.values);
+      case 57:
+        return _mm512_mask_blend_ps(0xFC3, a.values, b.values);
+      case 58:
+        return _mm512_mask_blend_ps(0xFCC, a.values, b.values);
+      case 59:
+        return _mm512_mask_blend_ps(0xFCF, a.values, b.values);
+      case 60:
+        return _mm512_mask_blend_ps(0xFF0, a.values, b.values);
+      case 61:
+        return _mm512_mask_blend_ps(0xFF3, a.values, b.values);
+      case 62:
+        return _mm512_mask_blend_ps(0xFFC, a.values, b.values);
+      case 63:
+        return _mm512_mask_blend_ps(0xFFF, a.values, b.values);
+      case 64:
+        return _mm512_mask_blend_ps(0x3000, a.values, b.values);
+      case 65:
+        return _mm512_mask_blend_ps(0x3003, a.values, b.values);
+      case 66:
+        return _mm512_mask_blend_ps(0x300C, a.values, b.values);
+      case 67:
+        return _mm512_mask_blend_ps(0x300F, a.values, b.values);
+      case 68:
+        return _mm512_mask_blend_ps(0x3030, a.values, b.values);
+      case 69:
+        return _mm512_mask_blend_ps(0x3033, a.values, b.values);
+      case 70:
+        return _mm512_mask_blend_ps(0x303C, a.values, b.values);
+      case 71:
+        return _mm512_mask_blend_ps(0x303F, a.values, b.values);
+      case 72:
+        return _mm512_mask_blend_ps(0x30C0, a.values, b.values);
+      case 73:
+        return _mm512_mask_blend_ps(0X30C3, a.values, b.values);
+      case 74:
+        return _mm512_mask_blend_ps(0x30CC, a.values, b.values);
+      case 75:
+        return _mm512_mask_blend_ps(0x30CF, a.values, b.values);
+      case 76:
+        return _mm512_mask_blend_ps(0x30F0, a.values, b.values);
+      case 77:
+        return _mm512_mask_blend_ps(0x30F3, a.values, b.values);
+      case 78:
+        return _mm512_mask_blend_ps(0x30FC, a.values, b.values);
+      case 79:
+        return _mm512_mask_blend_ps(0x30FF, a.values, b.values);
+      case 80:
+        return _mm512_mask_blend_ps(0x3300, a.values, b.values);
+      case 81:
+        return _mm512_mask_blend_ps(0X3303, a.values, b.values);
+      case 82:
+        return _mm512_mask_blend_ps(0x330C, a.values, b.values);
+      case 83:
+        return _mm512_mask_blend_ps(0x330F, a.values, b.values);
+      case 84:
+        return _mm512_mask_blend_ps(0x3330, a.values, b.values);
+      case 85:
+        return _mm512_mask_blend_ps(0x3333, a.values, b.values);
+      case 86:
+        return _mm512_mask_blend_ps(0x333C, a.values, b.values);
+      case 87:
+        return _mm512_mask_blend_ps(0X333F, a.values, b.values);
+      case 88:
+        return _mm512_mask_blend_ps(0x33C0, a.values, b.values);
+      case 89:
+        return _mm512_mask_blend_ps(0x33C3, a.values, b.values);
+      case 90:
+        return _mm512_mask_blend_ps(0x33CC, a.values, b.values);
+      case 91:
+        return _mm512_mask_blend_ps(0x33CF, a.values, b.values);
+      case 92:
+        return _mm512_mask_blend_ps(0x33F0, a.values, b.values);
+      case 93:
+        return _mm512_mask_blend_ps(0x33F3, a.values, b.values);
+      case 94:
+        return _mm512_mask_blend_ps(0x33FC, a.values, b.values);
+      case 95:
+        return _mm512_mask_blend_ps(0x33FF, a.values, b.values);
+      case 96:
+        return _mm512_mask_blend_ps(0X3C00, a.values, b.values);
+      case 97:
+        return _mm512_mask_blend_ps(0x3C03, a.values, b.values);
+      case 98:
+        return _mm512_mask_blend_ps(0x3C0C, a.values, b.values);
+      case 99:
+        return _mm512_mask_blend_ps(0x3C0F, a.values, b.values);
+      case 100:
+        return _mm512_mask_blend_ps(0x3C30, a.values, b.values);
+      case 101:
+        return _mm512_mask_blend_ps(0x3C33, a.values, b.values);
+      case 102:
+        return _mm512_mask_blend_ps(0x3C3C, a.values, b.values);
+      case 103:
+        return _mm512_mask_blend_ps(0x3C3F, a.values, b.values);
+      case 104:
+        return _mm512_mask_blend_ps(0x3CC0, a.values, b.values);
+      case 105:
+        return _mm512_mask_blend_ps(0x3CC3, a.values, b.values);
+      case 106:
+        return _mm512_mask_blend_ps(0x3CCC, a.values, b.values);
+      case 107:
+        return _mm512_mask_blend_ps(0x3CCF, a.values, b.values);
+      case 108:
+        return _mm512_mask_blend_ps(0x3CF0, a.values, b.values);
+      case 109:
+        return _mm512_mask_blend_ps(0x3CF3, a.values, b.values);
+      case 110:
+        return _mm512_mask_blend_ps(0x3CFC, a.values, b.values);
+      case 111:
+        return _mm512_mask_blend_ps(0x3CFF, a.values, b.values);
+      case 112:
+        return _mm512_mask_blend_ps(0x3F00, a.values, b.values);
+      case 113:
+        return _mm512_mask_blend_ps(0x3F03, a.values, b.values);
+      case 114:
+        return _mm512_mask_blend_ps(0x3F0C, a.values, b.values);
+      case 115:
+        return _mm512_mask_blend_ps(0x3F0F, a.values, b.values);
+      case 116:
+        return _mm512_mask_blend_ps(0x3F30, a.values, b.values);
+      case 117:
+        return _mm512_mask_blend_ps(0x3F33, a.values, b.values);
+      case 118:
+        return _mm512_mask_blend_ps(0x3F3C, a.values, b.values);
+      case 119:
+        return _mm512_mask_blend_ps(0x3F3F, a.values, b.values);
+      case 120:
+        return _mm512_mask_blend_ps(0x3FC0, a.values, b.values);
+      case 121:
+        return _mm512_mask_blend_ps(0x3FC3, a.values, b.values);
+      case 122:
+        return _mm512_mask_blend_ps(0x3FCC, a.values, b.values);
+      case 123:
+        return _mm512_mask_blend_ps(0x3FCF, a.values, b.values);
+      case 124:
+        return _mm512_mask_blend_ps(0x3FF0, a.values, b.values);
+      case 125:
+        return _mm512_mask_blend_ps(0x3FF3, a.values, b.values);
+      case 126:
+        return _mm512_mask_blend_ps(0x3FFC, a.values, b.values);
+      case 127:
+        return _mm512_mask_blend_ps(0x3FFF, a.values, b.values);
+      case 128:
+        return _mm512_mask_blend_ps(0xC000, a.values, b.values);
+      case 129:
+        return _mm512_mask_blend_ps(0xC003, a.values, b.values);
+      case 130:
+        return _mm512_mask_blend_ps(0xC00C, a.values, b.values);
+      case 131:
+        return _mm512_mask_blend_ps(0xC00F, a.values, b.values);
+      case 132:
+        return _mm512_mask_blend_ps(0xC030, a.values, b.values);
+      case 133:
+        return _mm512_mask_blend_ps(0xC033, a.values, b.values);
+      case 134:
+        return _mm512_mask_blend_ps(0xC03C, a.values, b.values);
+      case 135:
+        return _mm512_mask_blend_ps(0xC03F, a.values, b.values);
+      case 136:
+        return _mm512_mask_blend_ps(0xC0C0, a.values, b.values);
+      case 137:
+        return _mm512_mask_blend_ps(0xC0C3, a.values, b.values);
+      case 138:
+        return _mm512_mask_blend_ps(0xC0CC, a.values, b.values);
+      case 139:
+        return _mm512_mask_blend_ps(0xC0CF, a.values, b.values);
+      case 140:
+        return _mm512_mask_blend_ps(0xC0F0, a.values, b.values);
+      case 141:
+        return _mm512_mask_blend_ps(0xC0F3, a.values, b.values);
+      case 142:
+        return _mm512_mask_blend_ps(0xC0FC, a.values, b.values);
+      case 143:
+        return _mm512_mask_blend_ps(0xC0FF, a.values, b.values);
+      case 144:
+        return _mm512_mask_blend_ps(0xC300, a.values, b.values);
+      case 145:
+        return _mm512_mask_blend_ps(0xC303, a.values, b.values);
+      case 146:
+        return _mm512_mask_blend_ps(0xC30C, a.values, b.values);
+      case 147:
+        return _mm512_mask_blend_ps(0xC30F, a.values, b.values);
+      case 148:
+        return _mm512_mask_blend_ps(0xC330, a.values, b.values);
+      case 149:
+        return _mm512_mask_blend_ps(0xC333, a.values, b.values);
+      case 150:
+        return _mm512_mask_blend_ps(0xC33C, a.values, b.values);
+      case 151:
+        return _mm512_mask_blend_ps(0xC33F, a.values, b.values);
+      case 152:
+        return _mm512_mask_blend_ps(0xC3C0, a.values, b.values);
+      case 153:
+        return _mm512_mask_blend_ps(0xC3C3, a.values, b.values);
+      case 154:
+        return _mm512_mask_blend_ps(0xC3CC, a.values, b.values);
+      case 155:
+        return _mm512_mask_blend_ps(0xC3CF, a.values, b.values);
+      case 156:
+        return _mm512_mask_blend_ps(0xC3F0, a.values, b.values);
+      case 157:
+        return _mm512_mask_blend_ps(0xC3F3, a.values, b.values);
+      case 158:
+        return _mm512_mask_blend_ps(0xC3FC, a.values, b.values);
+      case 159:
+        return _mm512_mask_blend_ps(0xC3FF, a.values, b.values);
+      case 160:
+        return _mm512_mask_blend_ps(0xCC00, a.values, b.values);
+      case 161:
+        return _mm512_mask_blend_ps(0xCC03, a.values, b.values);
+      case 162:
+        return _mm512_mask_blend_ps(0xCC0C, a.values, b.values);
+      case 163:
+        return _mm512_mask_blend_ps(0xCC0F, a.values, b.values);
+      case 164:
+        return _mm512_mask_blend_ps(0xCC30, a.values, b.values);
+      case 165:
+        return _mm512_mask_blend_ps(0xCC33, a.values, b.values);
+      case 166:
+        return _mm512_mask_blend_ps(0xCC3C, a.values, b.values);
+      case 167:
+        return _mm512_mask_blend_ps(0xCC3F, a.values, b.values);
+      case 168:
+        return _mm512_mask_blend_ps(0xCCC0, a.values, b.values);
+      case 169:
+        return _mm512_mask_blend_ps(0xCCC3, a.values, b.values);
+      case 170:
+        return _mm512_mask_blend_ps(0xCCCC, a.values, b.values);
+      case 171:
+        return _mm512_mask_blend_ps(0xCCCF, a.values, b.values);
+      case 172:
+        return _mm512_mask_blend_ps(0xCCF0, a.values, b.values);
+      case 173:
+        return _mm512_mask_blend_ps(0xCCF3, a.values, b.values);
+      case 174:
+        return _mm512_mask_blend_ps(0xCCFC, a.values, b.values);
+      case 175:
+        return _mm512_mask_blend_ps(0xCCFF, a.values, b.values);
+      case 176:
+        return _mm512_mask_blend_ps(0xCF00, a.values, b.values);
+      case 177:
+        return _mm512_mask_blend_ps(0xCF03, a.values, b.values);
+      case 178:
+        return _mm512_mask_blend_ps(0xCF0C, a.values, b.values);
+      case 179:
+        return _mm512_mask_blend_ps(0xCF0F, a.values, b.values);
+      case 180:
+        return _mm512_mask_blend_ps(0xCF30, a.values, b.values);
+      case 181:
+        return _mm512_mask_blend_ps(0xCF33, a.values, b.values);
+      case 182:
+        return _mm512_mask_blend_ps(0xCF3C, a.values, b.values);
+      case 183:
+        return _mm512_mask_blend_ps(0xCF3F, a.values, b.values);
+      case 184:
+        return _mm512_mask_blend_ps(0xCFC0, a.values, b.values);
+      case 185:
+        return _mm512_mask_blend_ps(0xCFC3, a.values, b.values);
+      case 186:
+        return _mm512_mask_blend_ps(0xCFCC, a.values, b.values);
+      case 187:
+        return _mm512_mask_blend_ps(0xCFCF, a.values, b.values);
+      case 188:
+        return _mm512_mask_blend_ps(0xCFF0, a.values, b.values);
+      case 189:
+        return _mm512_mask_blend_ps(0xCFF3, a.values, b.values);
+      case 190:
+        return _mm512_mask_blend_ps(0xCFFC, a.values, b.values);
+      case 191:
+        return _mm512_mask_blend_ps(0xCFFF, a.values, b.values);
+      case 192:
+        return _mm512_mask_blend_ps(0xF000, a.values, b.values);
+      case 193:
+        return _mm512_mask_blend_ps(0xF003, a.values, b.values);
+      case 194:
+        return _mm512_mask_blend_ps(0xF00C, a.values, b.values);
+      case 195:
+        return _mm512_mask_blend_ps(0xF00F, a.values, b.values);
+      case 196:
+        return _mm512_mask_blend_ps(0xF030, a.values, b.values);
+      case 197:
+        return _mm512_mask_blend_ps(0xF033, a.values, b.values);
+      case 198:
+        return _mm512_mask_blend_ps(0xF03C, a.values, b.values);
+      case 199:
+        return _mm512_mask_blend_ps(0xF03F, a.values, b.values);
+      case 200:
+        return _mm512_mask_blend_ps(0XF0C0, a.values, b.values);
+      case 201:
+        return _mm512_mask_blend_ps(0xF0C3, a.values, b.values);
+      case 202:
+        return _mm512_mask_blend_ps(0xF0CC, a.values, b.values);
+      case 203:
+        return _mm512_mask_blend_ps(0xF0CF, a.values, b.values);
+      case 204:
+        return _mm512_mask_blend_ps(0xF0F0, a.values, b.values);
+      case 205:
+        return _mm512_mask_blend_ps(0xF0F3, a.values, b.values);
+      case 206:
+        return _mm512_mask_blend_ps(0xF0FC, a.values, b.values);
+      case 207:
+        return _mm512_mask_blend_ps(0xF0FF, a.values, b.values);
+      case 208:
+        return _mm512_mask_blend_ps(0XF300, a.values, b.values);
+      case 209:
+        return _mm512_mask_blend_ps(0xF303, a.values, b.values);
+      case 210:
+        return _mm512_mask_blend_ps(0xF30C, a.values, b.values);
+      case 211:
+        return _mm512_mask_blend_ps(0xF30F, a.values, b.values);
+      case 212:
+        return _mm512_mask_blend_ps(0xF330, a.values, b.values);
+      case 213:
+        return _mm512_mask_blend_ps(0xF333, a.values, b.values);
+      case 214:
+        return _mm512_mask_blend_ps(0XF33C, a.values, b.values);
+      case 215:
+        return _mm512_mask_blend_ps(0xF33F, a.values, b.values);
+      case 216:
+        return _mm512_mask_blend_ps(0xF3C0, a.values, b.values);
+      case 217:
+        return _mm512_mask_blend_ps(0xF3C3, a.values, b.values);
+      case 218:
+        return _mm512_mask_blend_ps(0xF3CC, a.values, b.values);
+      case 219:
+        return _mm512_mask_blend_ps(0xF3CF, a.values, b.values);
+      case 220:
+        return _mm512_mask_blend_ps(0xF3F0, a.values, b.values);
+      case 221:
+        return _mm512_mask_blend_ps(0xF3F3, a.values, b.values);
+      case 222:
+        return _mm512_mask_blend_ps(0xF3FC, a.values, b.values);
+      case 223:
+        return _mm512_mask_blend_ps(0XF3FF, a.values, b.values);
+      case 224:
+        return _mm512_mask_blend_ps(0xFC00, a.values, b.values);
+      case 225:
+        return _mm512_mask_blend_ps(0xFC03, a.values, b.values);
+      case 226:
+        return _mm512_mask_blend_ps(0xFC0C, a.values, b.values);
+      case 227:
+        return _mm512_mask_blend_ps(0xFC0F, a.values, b.values);
+      case 228:
+        return _mm512_mask_blend_ps(0xFC30, a.values, b.values);
+      case 229:
+        return _mm512_mask_blend_ps(0xFC33, a.values, b.values);
+      case 230:
+        return _mm512_mask_blend_ps(0xFC3C, a.values, b.values);
+      case 231:
+        return _mm512_mask_blend_ps(0xFC3F, a.values, b.values);
+      case 232:
+        return _mm512_mask_blend_ps(0xFCC0, a.values, b.values);
+      case 233:
+        return _mm512_mask_blend_ps(0xFCC3, a.values, b.values);
+      case 234:
+        return _mm512_mask_blend_ps(0xFCCC, a.values, b.values);
+      case 235:
+        return _mm512_mask_blend_ps(0xFCCF, a.values, b.values);
+      case 236:
+        return _mm512_mask_blend_ps(0xFCF0, a.values, b.values);
+      case 237:
+        return _mm512_mask_blend_ps(0xFCF3, a.values, b.values);
+      case 238:
+        return _mm512_mask_blend_ps(0xFCFC, a.values, b.values);
+      case 239:
+        return _mm512_mask_blend_ps(0xFCFF, a.values, b.values);
+      case 240:
+        return _mm512_mask_blend_ps(0xFF00, a.values, b.values);
+      case 241:
+        return _mm512_mask_blend_ps(0xFF03, a.values, b.values);
+      case 242:
+        return _mm512_mask_blend_ps(0xFF0C, a.values, b.values);
+      case 243:
+        return _mm512_mask_blend_ps(0xFF0F, a.values, b.values);
+      case 244:
+        return _mm512_mask_blend_ps(0xFF30, a.values, b.values);
+      case 245:
+        return _mm512_mask_blend_ps(0xFF33, a.values, b.values);
+      case 246:
+        return _mm512_mask_blend_ps(0xFF3C, a.values, b.values);
+      case 247:
+        return _mm512_mask_blend_ps(0xFF3F, a.values, b.values);
+      case 248:
+        return _mm512_mask_blend_ps(0xFFC0, a.values, b.values);
+      case 249:
+        return _mm512_mask_blend_ps(0xFFC3, a.values, b.values);
+      case 250:
+        return _mm512_mask_blend_ps(0xFFCC, a.values, b.values);
+      case 251:
+        return _mm512_mask_blend_ps(0xFFCF, a.values, b.values);
+      case 252:
+        return _mm512_mask_blend_ps(0xFFF0, a.values, b.values);
+      case 253:
+        return _mm512_mask_blend_ps(0xFFF3, a.values, b.values);
+      case 254:
+        return _mm512_mask_blend_ps(0xFFFC, a.values, b.values);
+      default: break;
+    }
+    return b;
+  }
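+  // Each bit i of the 8-bit complex mask above covers one c10::complex<float>, i.e.
+  // two consecutive float lanes, so it expands to bits 2i and 2i+1 of the __mmask16:
+  // for example mask 0b0101 (case 5) selects complex elements 0 and 2 from b and
+  // becomes the float-lane mask 0x33.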
+  static Vectorized<c10::complex<float>> blendv(const Vectorized<c10::complex<float>>& a,
+                                                const Vectorized<c10::complex<float>>& b,
+                                                const Vectorized<c10::complex<float>>& mask) {
+    // convert c10::complex index mask to V index mask: xy -> xxyy
+    auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values);
+    auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
+    auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_ps(mmask, a.values, b.values);
+  }
+  template <typename step_t>
+  static Vectorized<c10::complex<float>> arange(c10::complex<float> base = 0.,
+                                                step_t step = static_cast<step_t>(1)) {
+    return Vectorized<c10::complex<float>>(base,
+                                           base + step,
+                                           base + c10::complex<float>(2)*step,
+                                           base + c10::complex<float>(3)*step,
+                                           base + c10::complex<float>(4)*step,
+                                           base + c10::complex<float>(5)*step,
+                                           base + c10::complex<float>(6)*step,
+                                           base + c10::complex<float>(7)*step);
+  }
+  static Vectorized<c10::complex<float>> set(const Vectorized<c10::complex<float>>& a,
+                                             const Vectorized<c10::complex<float>>& b,
+                                             int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+    return b;
+  }
+  static Vectorized<c10::complex<float>> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr));
+
+    __at_align__ float tmp_values[2*size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(2*size())) {
+      tmp_values[i] = 0.0;
+    }
+    std::memcpy(
+        tmp_values,
+        reinterpret_cast<const float*>(ptr),
+        count * sizeof(c10::complex<float>));
+    return _mm512_load_ps(tmp_values);
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm512_storeu_ps(reinterpret_cast<float*>(ptr), values);
+    } else if (count > 0) {
+      float tmp_values[2*size()];
+      _mm512_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
+      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<float>));
+    }
+  }
+  // AVX512 doesn't have horizontal add & horizontal sub instructions.
+  // TODO: hadd_pd() & hsub_pd() may have scope for improvement.
+  static inline __m512 hadd_ps(__m512 a, __m512 b) {
+  __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0);
+  __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1);
+  return _mm512_add_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
+                       _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
+  }
+  static inline __m512 hsub_ps(__m512 a, __m512 b) {
+  __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0);
+  __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1);
+  return _mm512_sub_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
+                       _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
+  }
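+  // idx1 gathers the even float lanes of a and b interleaved (a0,b0,a2,b2,...) and
+  // idx2 the odd lanes (a1,b1,a3,b3,...), so adding/subtracting the two permutations
+  // yields the pairwise result (a0+a1, b0+b1, a2+a3, b2+b3, ...).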
+  const c10::complex<float>& operator[](int idx) const  = delete;
+  c10::complex<float>& operator[](int idx) = delete;
+  Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
+    __at_align__ c10::complex<float> tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  __m512 abs_2_() const {
+    auto val_2 = _mm512_mul_ps(values, values);     // a*a     b*b
+    auto ret = hadd_ps(val_2, val_2);               // a*a+b*b a*a+b*b
+    return ret;
+  }
+  __m512 abs_() const {
+    auto real = _mm512_moveldup_ps(values);    // real real
+    auto imag = _mm512_movehdup_ps(values);    // imag imag
+    return Sleef_hypotf16_u05(real, imag);     // abs  abs
+  }
+  Vectorized> abs() const {
+    const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    return _mm512_and_ps(abs_(), real_mask);        // abs     0
+  }
+  __m512 angle_() const {
+    //angle = atan2(b/a)
+    auto b_a = _mm512_permute_ps(values, 0xB1);     // b        a
+    return Sleef_atan2f16_u10(values, b_a);          // 90-angle angle
+  }
+  Vectorized> angle() const {
+    const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle    90-angle
+    return _mm512_and_ps(angle, real_mask);         // angle    0
+  }
+  Vectorized> sgn() const {
+    auto abs = abs_();
+    auto zero = _mm512_setzero_ps();
+    auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);
+    auto div = values / abs;
+    return _mm512_mask_blend_ps(mask, div, zero);
+  }
+  __m512 real_() const {
+    const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
+                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
+    return _mm512_and_ps(values, real_mask);
+  }
+  Vectorized> real() const {
+    return real_();
+  }
+  __m512 imag_() const {
+    const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
+                                                                   0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
+                                                                   0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
+                                                                   0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF));
+    return _mm512_and_ps(values, imag_mask);
+  }
+  Vectorized> imag() const {
+    return _mm512_permute_ps(imag_(), 0xB1);        //b        a
+  }
+  __m512 conj_() const {
+    const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+                                            0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+    return _mm512_xor_ps(values, sign_mask);        // a       -b
+  }
+  Vectorized> conj() const {
+    return conj_();
+  }
+  Vectorized> log() const {
+    // Most trigonomic ops use the log() op to improve complex number performance.
+    return map(std::log);
+  }
+  Vectorized> log2() const {
+    const __m512 log2_ = _mm512_set1_ps(std::log(2));
+    return _mm512_div_ps(log(), log2_);
+  }
+  Vectorized> log10() const {
+    const __m512 log10_ = _mm512_set1_ps(std::log(10));
+    return _mm512_div_ps(log(), log10_);
+  }
+  Vectorized> log1p() const {
+    return map(std::log1p);
+  }
+  Vectorized> asin() const {
+    // asin(x)
+    // = -i*ln(iz + sqrt(1 -z^2))
+    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
+    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
+    const __m512 one = _mm512_set1_ps(1);
+
+    auto conj = conj_();
+    auto b_a = _mm512_permute_ps(conj, 0xB1);                         //-b        a
+    auto ab = _mm512_mul_ps(conj, b_a);                               //-ab       -ab
+    auto im = _mm512_add_ps(ab, ab);                                  //-2ab      -2ab
+
+    auto val_2 = _mm512_mul_ps(values, values);                       // a*a      b*b
+    auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1));  // a*a-b*b  b*b-a*a
+    re = _mm512_sub_ps(one, re);
+
+    auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt();         //sqrt(re + i*im)
+    auto ln = Vectorized(_mm512_add_ps(b_a, root)).log();                 //ln(iz + sqrt())
+    return Vectorized(_mm512_permute_ps(ln.values, 0xB1)).conj();         //-i*ln()
+  }
+  Vectorized> acos() const {
+    return map(std::acos);
+  }
+  Vectorized> atan() const;
+  Vectorized> atanh() const {
+    return map(std::atanh);
+  }
+  Vectorized> exp() const {
+    //exp(a + bi)
+    // = exp(a)*(cos(b) + sin(b)i)
+    auto exp = Sleef_expf16_u10(values);                               //exp(a)           exp(b)
+    exp = _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1));   //exp(a)           exp(a)
+
+    auto sin_cos = Sleef_sincosf16_u10(values);                        //[sin(a), cos(a)] [sin(b), cos(b)]
+    auto cos_sin = _mm512_mask_blend_ps(0xAAAA, _mm512_permute_ps(sin_cos.y, 0xB1),
+                                   sin_cos.x);                  //cos(b)           sin(b)
+    return _mm512_mul_ps(exp, cos_sin);
+  }
+  Vectorized> exp2() const {
+    // Use identity 2**x = exp(log(2) * x)
+    const __m512 ln_2 = _mm512_set1_ps(c10::ln_2<float>);
+    Vectorized<c10::complex<float>> scaled_values = _mm512_mul_ps(values, ln_2);
+    return scaled_values.exp();
+  }
+  Vectorized> expm1() const {
+    return map(std::expm1);
+  }
+  Vectorized> sin() const {
+    return map(std::sin);
+  }
+  Vectorized> sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized> cos() const {
+    return map(std::cos);
+  }
+  Vectorized> cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized> ceil() const {
+    return _mm512_ceil_ps(values);
+  }
+  Vectorized> floor() const {
+    return _mm512_floor_ps(values);
+  }
+  Vectorized> neg() const {
+    auto zero = _mm512_setzero_ps();
+    return _mm512_sub_ps(zero, values);
+  }
+  Vectorized> round() const {
+    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> tan() const {
+    return map(std::tan);
+  }
+  Vectorized> tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized> trunc() const {
+    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized> sqrt() const {
+    return map(std::sqrt);
+  }
+  Vectorized> reciprocal() const;
+  Vectorized> rsqrt() const {
+    return sqrt().reciprocal();
+  }
+  Vectorized> pow(const Vectorized> &exp) const {
+    __at_align__ c10::complex x_tmp[size()];
+    __at_align__ c10::complex y_tmp[size()];
+    store(x_tmp);
+    exp.store(y_tmp);
+    for (const auto i : c10::irange(size())) {
+      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
+    }
+    return loadu(x_tmp);
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized> operator==(const Vectorized>& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF));
+  }
+  Vectorized> operator!=(const Vectorized>& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF));
+  }
+  Vectorized> operator<(const Vectorized>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator<=(const Vectorized>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator>(const Vectorized>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+  Vectorized> operator>=(const Vectorized>& other) const {
+    TORCH_CHECK(false, "not supported for complex numbers");
+  }
+
+  Vectorized> eq(const Vectorized>& other) const;
+  Vectorized> ne(const Vectorized>& other) const;
+};
+
+template <> Vectorized<c10::complex<float>> inline operator+(const Vectorized<c10::complex<float>> &a,
+                                                             const Vectorized<c10::complex<float>> &b) {
+  return _mm512_add_ps(a, b);
+}
+
+template <> Vectorized<c10::complex<float>> inline operator-(const Vectorized<c10::complex<float>> &a,
+                                                             const Vectorized<c10::complex<float>> &b) {
+  return _mm512_sub_ps(a, b);
+}
+
+template <> Vectorized<c10::complex<float>> inline operator*(const Vectorized<c10::complex<float>> &a,
+                                                             const Vectorized<c10::complex<float>> &b) {
+  //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
+  const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+                                          0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto ac_bd = _mm512_mul_ps(a, b);         //ac       bd
+
+  auto d_c = _mm512_permute_ps(b, 0xB1);    //d        c
+  d_c = _mm512_xor_ps(sign_mask, d_c);      //d       -c
+  auto ad_bc = _mm512_mul_ps(a, d_c);       //ad      -bc
+
+  auto ret = Vectorized<c10::complex<float>>::hsub_ps(ac_bd, ad_bc);  //ac - bd  ad + bc
+  return ret;
+}
+
+template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c10::complex<float>> &a,
+                                                             const Vectorized<c10::complex<float>> &b) {
+  //re + im*i = (a + bi)  / (c + di)
+  auto mask = _mm512_set1_ps(-0.f);
+  auto fabs_cd = _mm512_andnot_ps(mask, b);     // |c|    |d|
+  auto fabs_dc = _mm512_permute_ps(fabs_cd, 0xB1);   // |d|    |c|
+  auto scale = _mm512_rcp14_ps(_mm512_max_ps(fabs_cd, fabs_dc));  // 1/sc     1/sc
+  auto a2 = _mm512_mul_ps(a, scale);         // a/sc     b/sc
+  auto b2 = _mm512_mul_ps(b, scale);         // c/sc     d/sc
+  auto acbd2 = _mm512_mul_ps(a2, b2);
+
+  const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0,
+                                          -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
+  auto dc2 = _mm512_permute_ps(b2, 0xB1);    // d/sc         c/sc
+  dc2 = _mm512_xor_ps(sign_mask, dc2);       // -d/|c,d|        c/sc
+  auto adbc2 = _mm512_mul_ps(a2, dc2);       //-ad/sc^2      bc/sc^2
+  auto res2 = Vectorized<c10::complex<float>>::hadd_ps(acbd2, adbc2);  //(ac+bd)/sc^2  (bc-ad)/sc^2
+
+  // get the denominator
+  auto denom2 = Vectorized<c10::complex<float>>(b2).abs_2_();  // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
+  res2 = _mm512_div_ps(res2, denom2);
+  return res2;
+}
+
+// reciprocal. Implement this here so we can use multiplication.
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
+  //re + im*i = (a + bi)  / (c + di)
+  //re = (ac + bd)/abs_2() = c/abs_2()
+  //im = (bc - ad)/abs_2() = d/abs_2()
+  const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+                                          0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  auto c_d = _mm512_xor_ps(sign_mask, values);    //c       -d
+  return _mm512_div_ps(c_d, abs_2_());
+}
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
+  // atan(x) = i/2 * ln((i + z)/(i - z))
+  const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
+                                  0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
+  const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5,
+                                          0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
+
+  auto sum = Vectorized(_mm512_add_ps(i, values));                      // a        1+b
+  auto sub = Vectorized(_mm512_sub_ps(i, values));                      // -a       1-b
+  auto ln = (sum/sub).log();                                        // ln((i + z)/(i - z))
+  return i_half*ln;                                                 // i/2*ln()
+}
+
+template <>
+Vectorized<c10::complex<float>> inline maximum(const Vectorized<c10::complex<float>>& a,
+                                               const Vectorized<c10::complex<float>>& b) {
+  auto zero_vector = _mm512_set1_epi32(0);
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ);
+  auto max = _mm512_mask_blend_ps(mask, a, b);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q);
+  auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF);
+  return _mm512_or_ps(max, _mm512_castsi512_ps(isnan));
+}
+
+template <>
+Vectorized<c10::complex<float>> inline minimum(const Vectorized<c10::complex<float>>& a,
+                                               const Vectorized<c10::complex<float>>& b) {
+  auto zero_vector = _mm512_set1_epi32(0);
+  auto abs_a = a.abs_2_();
+  auto abs_b = b.abs_2_();
+  auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ);
+  auto min = _mm512_mask_blend_ps(mask, a, b);
+  // Exploit the fact that all-ones is a NaN.
+  auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q);
+  auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF);
+  return _mm512_or_ps(min, _mm512_castsi512_ps(isnan));
+}
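+// For complex vectors, maximum()/minimum() above order elements by squared magnitude
+// (abs_2_()) rather than elementwise, and the final OR with the unordered-compare
+// mask turns any lane whose magnitude comparison involved a NaN into a NaN payload.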
+
+template <>
+Vectorized<c10::complex<float>> inline operator&(const Vectorized<c10::complex<float>>& a,
+                                                 const Vectorized<c10::complex<float>>& b) {
+  return _mm512_and_ps(a, b);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline operator|(const Vectorized<c10::complex<float>>& a,
+                                                 const Vectorized<c10::complex<float>>& b) {
+  return _mm512_or_ps(a, b);
+}
+
+template <>
+Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<float>>& a,
+                                                 const Vectorized<c10::complex<float>>& b) {
+  return _mm512_xor_ps(a, b);
+}
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
+    const Vectorized<c10::complex<float>>& other) const {
+  auto eq = (*this == other);  // compares real and imag individually
+  // If both real numbers and imag numbers are equal, then the complex numbers are equal
+  return (eq.real() & eq.imag()) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
+}
+
+inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
+    const Vectorized<c10::complex<float>>& other) const {
+  auto ne = (*this != other);  // compares real and imag individually
+  // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
+  return (ne.real() | ne.imag()) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
+}
+
+#endif
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h
new file mode 100644
index 0000000000000000000000000000000000000000..fb12593fbc970dc2ebf69380bfbd35e3f90fb590
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_double.h
@@ -0,0 +1,467 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+template <> class Vectorized<double> {
+private:
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+public:
+  // values needs to be public for compilation with clang
+  // as vec512.h uses it
+  __m512d values;
+  using value_type = double;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  Vectorized() {}
+  Vectorized(__m512d v) : values(v) {}
+  Vectorized(double val) {
+    values = _mm512_set1_pd(val);
+  }
+  Vectorized(double val1, double val2, double val3, double val4,
+         double val5, double val6, double val7, double val8) {
+    values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8);
+  }
+  operator __m512d() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
+    return _mm512_mask_blend_pd(mask, a.values, b.values);
+  }
+  static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
+                                   const Vectorized<double>& mask) {
+    auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
+    auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_pd(mmask, a.values, b.values);
+  }
+  template <typename step_t>
+  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
+    return Vectorized(base, base + step, base + 2 * step, base + 3 * step,
+                          base + 4 * step, base + 5 * step, base + 6 * step,
+                          base + 7 * step);
+  }
+  static Vectorized set(const Vectorized& a, const Vectorized& b,
+                            int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));
+
+    __mmask8 mask = (1ULL << count) - 1;
+    return _mm512_maskz_loadu_pd(mask, ptr);
+  }
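+  // e.g. count == 3 gives mask = 0b00000111: only the first three doubles are read
+  // and the remaining lanes of the result are zeroed by the maskz load.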
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      _mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
+    } else if (count > 0) {
+      __mmask8 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_pd(reinterpret_cast<double*>(ptr), mask, values);
+    }
+  }
+  const double& operator[](int idx) const  = delete;
+  double& operator[](int idx) = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
+    return static_cast<int32_t>(cmp);
+  }
+  Vectorized isnan() const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+  bool has_inf_nan() const {
+    __m512d self_sub  = _mm512_sub_pd(values, values);
+    return (_mm512_movepi8_mask(_mm512_castpd_si512(self_sub)) & 0x7777777777777777) != 0;
+  }
+  Vectorized map(double (*const f)(double)) const {
+    __at_align__ double tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized abs() const {
+    auto mask = _mm512_set1_pd(-0.f);
+    return _mm512_andnot_pd(mask, values);
+  }
+  Vectorized angle() const {
+    const auto zero_vec = _mm512_castsi512_pd(zero_vector);
+    const auto nan_vec = _mm512_set1_pd(NAN);
+    const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ);
+    const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask,
+                                                0xFFFFFFFFFFFFFFFF);
+    const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan),
+                                             zero_vec, _CMP_EQ_OQ);
+    const auto pi = _mm512_set1_pd(c10::pi<double>);
+
+    const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ);
+    auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi);
+    angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec);
+    return angle;
+  }
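+  // i.e. angle(x) = 0 for x >= 0, pi for x < 0, with NaN inputs routed to nan_vec by
+  // the nan_mask blend above.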
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_pd(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return Vectorized(Sleef_acosd8_u10(values));
+  }
+  Vectorized acosh() const {
+    return Vectorized(Sleef_acoshd8_u10(values));
+  }
+  Vectorized asin() const {
+    return Vectorized(Sleef_asind8_u10(values));
+  }
+  Vectorized atan() const {
+    return Vectorized(Sleef_atand8_u10(values));
+  }
+  Vectorized atanh() const {
+    return Vectorized(Sleef_atanhd8_u10(values));
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    return Vectorized(Sleef_atan2d8_u10(values, b));
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return Vectorized(Sleef_copysignd8(values, sign));
+  }
+  Vectorized erf() const {
+    return Vectorized(Sleef_erfd8_u10(values));
+  }
+  Vectorized erfc() const {
+    return Vectorized(Sleef_erfcd8_u15(values));
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return Vectorized(Sleef_expd8_u10(values));
+  }
+  Vectorized exp2() const {
+    return Vectorized(Sleef_exp2d8_u10(values));
+  }
+  Vectorized expm1() const {
+    return Vectorized(Sleef_expm1d8_u10(values));
+  }
+  Vectorized exp_u20() const {
+    return exp();
+  }
+  Vectorized fmod(const Vectorized& q) const {
+    return Vectorized(Sleef_fmodd8(values, q));
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    return Vectorized(Sleef_hypotd8_u05(values, b));
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    __at_align__ double tmp[size()];
+    __at_align__ double tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized log() const {
+    return Vectorized(Sleef_logd8_u10(values));
+  }
+  Vectorized log2() const {
+    return Vectorized(Sleef_log2d8_u10(values));
+  }
+  Vectorized log10() const {
+    return Vectorized(Sleef_log10d8_u10(values));
+  }
+  Vectorized log1p() const {
+    return Vectorized(Sleef_log1pd8_u10(values));
+  }
+  Vectorized sin() const {
+    return Vectorized(Sleef_sind8_u10(values));
+  }
+  Vectorized sinh() const {
+    return Vectorized(Sleef_sinhd8_u10(values));
+  }
+  Vectorized cos() const {
+    return Vectorized(Sleef_cosd8_u10(values));
+  }
+  Vectorized cosh() const {
+    return Vectorized(Sleef_coshd8_u10(values));
+  }
+  Vectorized ceil() const {
+    return _mm512_ceil_pd(values);
+  }
+  Vectorized floor() const {
+    return _mm512_floor_pd(values);
+  }
+  Vectorized frac() const;
+  Vectorized neg() const {
+    return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    return Vectorized(Sleef_nextafterd8(values, b));
+  }
+  Vectorized round() const {
+    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized tan() const {
+    return Vectorized(Sleef_tand8_u10(values));
+  }
+  Vectorized tanh() const {
+    return Vectorized(Sleef_tanhd8_u10(values));
+  }
+  Vectorized trunc() const {
+    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized lgamma() const {
+    return Vectorized(Sleef_lgammad8_u10(values));
+  }
+  Vectorized sqrt() const {
+    return _mm512_sqrt_pd(values);
+  }
+  Vectorized reciprocal() const {
+    return _mm512_div_pd(_mm512_set1_pd(1), values);
+  }
+  Vectorized rsqrt() const {
+    return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values));
+  }
+  Vectorized pow(const Vectorized &b) const {
+    return Vectorized(Sleef_powd8_u10(values, b));
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized operator==(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized operator!=(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_UQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ);
+    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
+                                                      0xFFFFFFFFFFFFFFFF));
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+};
+
+template <>
+Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_add_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_sub_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_mul_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_div_pd(a, b);
+}
+
+// frac. Implement this here so we can use subtraction.
+inline Vectorized<double> Vectorized<double>::frac() const {
+  return *this - this->trunc();
+}
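+// A scalar sketch of frac(), for exposition only (not part of this header's API):
+//   double frac_ref(double x) { return x - std::trunc(x); }
+// Truncation (round toward zero) rather than floor keeps the fractional
+// part's sign equal to the sign of the input.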
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
+  auto zero_vec = _mm512_set1_epi64(0);
+  Vectorized<double> max = _mm512_max_pd(a, b);
+  auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
+  auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
+                                                          0xFFFFFFFFFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  return _mm512_or_pd(max, isnan);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
+  auto zero_vec = _mm512_set1_epi64(0);
+  Vectorized<double> min = _mm512_min_pd(a, b);
+  auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
+  auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
+                                                          0xFFFFFFFFFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  return _mm512_or_pd(min, isnan);
+}
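+// How the NaN propagation above works: _CMP_UNORD_Q sets a lane's mask bit
+// when either operand is NaN, that mask is expanded to all-ones in those
+// lanes, and all-ones is itself a (negative, quiet) NaN bit pattern, so
+// OR-ing it into the max/min result forces those lanes to NaN. Roughly the
+// scalar equivalent, for exposition only:
+//   double maximum_ref(double a, double b) {
+//     return (std::isnan(a) || std::isnan(b))
+//         ? std::numeric_limits<double>::quiet_NaN() : std::max(a, b);
+//   }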
+
+template <>
+Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
+  return _mm512_min_pd(max, _mm512_max_pd(min, a));
+}
+
+template <>
+Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
+  return _mm512_max_pd(min, a);
+}
+
+template <>
+Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
+  return _mm512_min_pd(max, a);
+}
+
+template <>
+Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_and_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_or_pd(a, b);
+}
+
+template <>
+Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
+  return _mm512_xor_pd(a, b);
+}
+
+inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
+  return (*this == other) & Vectorized<double>(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
+  return (*this != other) & Vectorized<double>(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
+  return (*this > other) & Vectorized<double>(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
+  return (*this >= other) & Vectorized<double>(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
+  return (*this < other) & Vectorized<double>(1.0);
+}
+
+inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
+  return (*this <= other) & Vectorized<double>(1.0);
+}
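+// Unlike the overloaded comparison operators above, which return all-ones /
+// all-zero bit masks, eq/ne/gt/ge/lt/le AND the mask with 1.0, so each lane
+// holds a numeric 1.0 or 0.0 that can be used directly in arithmetic.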
+
+template <>
+inline void convert(const double* src, double* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
+    _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
+  return _mm512_fmadd_pd(a, b, c);
+}
+
+template <>
+Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
+  return _mm512_fmsub_pd(a, b, c);
+}
+
+#endif
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h
new file mode 100644
index 0000000000000000000000000000000000000000..69a429988065b0ba3b27f734561c48445cead041
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_float.h
@@ -0,0 +1,793 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/util/irange.h>
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+#include <sleef.h>
+#endif
+
+namespace at {
+namespace vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+template <> class Vectorized<float> {
+private:
+  static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0};
+public:
+  __m512 values;
+  using value_type = float;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 16;
+  }
+  Vectorized() {}
+  Vectorized(__m512 v) : values(v) {}
+  Vectorized(float val) {
+    values = _mm512_set1_ps(val);
+  }
+  Vectorized(float val1, float val2, float val3, float val4,
+         float val5, float val6, float val7, float val8,
+         float val9, float val10, float val11, float val12,
+         float val13, float val14, float val15, float val16) {
+    values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8,
+                            val9, val10, val11, val12, val13, val14, val15, val16);
+  }
+  operator __m512() const {
+    return values;
+  }
+  template <int64_t mask>
+  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
+    return _mm512_mask_blend_ps(mask, a.values, b.values);
+  }
+  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
+                              const Vectorized<float>& mask) {
+    auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
+    auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_ps(mmask, a.values, b.values);
+  }
+  template <typename step_t>
+  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
+    return Vectorized<float>(
+      base,            base +     step, base + 2 * step, base + 3 * step,
+      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
+      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
+  }
+  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
+                           int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+    }
+    return b;
+  }
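+  // set() keeps the first `count` lanes of b and the remaining lanes of a;
+  // the runtime `count` is mapped onto the blend() template above through
+  // the switch, with each case supplying the matching lane mask.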
+  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
+    if (count == size())
+      return _mm512_loadu_ps(reinterpret_cast<const float*>(ptr));
+
+    __mmask16 mask = (1ULL << count) - 1;
+    return _mm512_maskz_loadu_ps(mask, ptr);
+  }
+  void store(void* ptr, int64_t count = size()) const {
+    if (count == size()) {
+      _mm512_storeu_ps(reinterpret_cast<float*>(ptr), values);
+    } else if (count > 0) {
+      __mmask16 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_ps(reinterpret_cast<float*>(ptr), mask, values);
+    }
+  }
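+  // For partial loads/stores, (1ULL << count) - 1 builds a __mmask16 with the
+  // low `count` bits set (e.g. count == 3 gives 0b0000000000000111), so only
+  // the first `count` floats are read or written and memory past the end of
+  // the buffer is never touched.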
+  const float& operator[](int idx) const  = delete;
+  float& operator[](int idx) = delete;
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
+    __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ);
+    return static_cast(cmp);
+  }
+  Vectorized isnan() const {
+    auto mask =  _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+  bool has_inf_nan() const {
+    __m512 self_sub  = _mm512_sub_ps(values, values);
+    return (_mm512_movepi8_mask(_mm512_castps_si512(self_sub)) & 0x7777777777777777) != 0;
+  }
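+  // has_inf_nan(): values - values is +/-0.0 for finite lanes and NaN when a
+  // lane holds an infinity or NaN. A float NaN has an all-ones exponent, so
+  // bit 23 (the top bit of the third byte) is set exactly in the non-finite
+  // lanes; the byte-wise sign-bit mask, with the bit for each lane's top byte
+  // cleared by 0x7777..., is therefore non-zero iff some element is Inf/NaN.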
+  Vectorized<float> map(float (*const f)(float)) const {
+    __at_align__ float tmp[size()];
+    store(tmp);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = f(tmp[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized abs() const {
+    auto mask = _mm512_set1_ps(-0.f);
+    return _mm512_andnot_ps(mask, values);
+  }
+  Vectorized angle() const {
+    __m512 zero_vec = _mm512_set1_ps(0.f);
+    const auto nan_vec = _mm512_set1_ps(NAN);
+    const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
+    const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
+                                                    not_nan_mask, 0xFFFFFFFF);
+    const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec),
+                                             zero_vec, _CMP_EQ_OQ);
+    const auto pi = _mm512_set1_ps(c10::pi<float>);
+
+    const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
+    auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
+    angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
+    return angle;
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_ps(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized acos() const {
+    return Vectorized(Sleef_acosf16_u10(values));
+  }
+  Vectorized acosh() const {
+    return Vectorized(Sleef_acoshf16_u10(values));
+  }
+  Vectorized asin() const {
+    return Vectorized(Sleef_asinf16_u10(values));
+  }
+  Vectorized atan() const {
+    return Vectorized(Sleef_atanf16_u10(values));
+  }
+  Vectorized atanh() const {
+    return Vectorized(Sleef_atanhf16_u10(values));
+  }
+  Vectorized atan2(const Vectorized &b) const {
+    return Vectorized(Sleef_atan2f16_u10(values, b));
+  }
+  Vectorized copysign(const Vectorized &sign) const {
+    return Vectorized(Sleef_copysignf16(values, sign));
+  }
+  Vectorized erf() const {
+    // constants
+    const auto neg_zero_vec = _mm512_set1_ps(-0.f);
+    const auto one_vec = _mm512_set1_ps(1.0f);
+    const auto p = _mm512_set1_ps(0.3275911f);
+    const auto p1 = _mm512_set1_ps(0.254829592f);
+    const auto p2 = _mm512_set1_ps(-0.284496736f);
+    const auto p3 = _mm512_set1_ps(1.421413741f);
+    const auto p4 = _mm512_set1_ps(-1.453152027f);
+    const auto p5 = _mm512_set1_ps(1.061405429f);
+    // sign(x)
+    auto sign_mask = _mm512_and_ps(neg_zero_vec, values);
+    auto abs_vec = _mm512_abs_ps(values);
+    // t = 1 / (p * abs(x) + 1)
+    auto tmp0 = _mm512_fmadd_ps(p, abs_vec, one_vec);
+    auto t = _mm512_div_ps(one_vec, tmp0);
+    // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
+    auto tmp1 = _mm512_fmadd_ps(p5, t, p4);
+    auto tmp2 = _mm512_fmadd_ps(tmp1, t, p3);
+    auto tmp3 = _mm512_fmadd_ps(tmp2, t, p2);
+    auto r = _mm512_fmadd_ps(tmp3, t, p1);
+    // - exp(- x * x)
+    auto pow_2 = _mm512_mul_ps(values, values);
+    auto neg_pow_2 = _mm512_xor_ps(neg_zero_vec, pow_2);
+    // auto tmp4 = exp(neg_pow_2);
+    auto tmp4 = Vectorized(Sleef_expf16_u10(neg_pow_2));
+    auto tmp5 = _mm512_xor_ps(neg_zero_vec, tmp4);
+    // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
+    auto tmp6 = _mm512_mul_ps(tmp5, t);
+    auto tmp7 = _mm512_fmadd_ps(tmp6, r, one_vec);
+    return _mm512_xor_ps(sign_mask, tmp7);
+  }
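+  // The erf() above follows the classic Abramowitz & Stegun 7.1.26 style
+  // polynomial approximation, erf(x) ~= sign(x) * (1 - (a1*t + ... + a5*t^5) * exp(-x*x))
+  // with t = 1 / (1 + p*|x|). A scalar sketch, for exposition only:
+  //   float erf_ref(float x) {
+  //     float t = 1.f / (1.f + 0.3275911f * std::fabs(x));
+  //     float poly = t * (0.254829592f + t * (-0.284496736f + t * (1.421413741f
+  //                  + t * (-1.453152027f + t * 1.061405429f))));
+  //     return std::copysign(1.f - poly * std::exp(-x * x), x);
+  //   }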
+  Vectorized erfc() const {
+    return Vectorized(Sleef_erfcf16_u15(values));
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return Vectorized(Sleef_expf16_u10(values));
+  }
+  Vectorized exp2() const {
+    return Vectorized(Sleef_exp2f16_u10(values));
+  }
+  Vectorized expm1() const {
+    return Vectorized(Sleef_expm1f16_u10(values));
+  }
+  Vectorized exp_u20() const {
+    // A faster version of exp with ULP=20
+    static __m512 vec_factorial_1 =
+        _mm512_set1_ps(0.999999701f); // 1/factorial(1)
+    static __m512 vec_factorial_2 =
+        _mm512_set1_ps(0.499991506f); // 1/factorial(2)
+    static __m512 vec_factorial_3 =
+        _mm512_set1_ps(0.166676521f); // 1/factorial(3)
+    static __m512 vec_factorial_4 =
+        _mm512_set1_ps(0.0418978221f); // 1/factorial(4)
+    static __m512 vec_factorial_5 =
+        _mm512_set1_ps(0.00828929059f); // 1/factorial(5)
+    static __m512 vec_exp_log2ef =
+        (__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
+    static __m512 vec_half = _mm512_set1_ps(0.5f);
+    static __m512 vec_one = _mm512_set1_ps(1.f);
+    static __m512 vec_zero = _mm512_set1_ps(0.f);
+    static __m512 vec_two = _mm512_set1_ps(2.f);
+    static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
+    static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
+    static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
+    static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
+    static int n_mantissa_bits = 23;
+
+    // exp(x) =
+    // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
+    // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression
+
+    auto less_ln_flt_min_mask =
+        _mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/);
+    auto vec_src = _mm512_min_ps(values, vec_ln_flt_max);
+    vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min);
+
+    // fx = floorf(x * log2ef + 0.5)
+    auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half);
+    auto vec_fx_i = _mm512_cvt_roundps_epi32(
+        vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
+    vec_fx = _mm512_cvtepi32_ps(vec_fx_i);
+
+    // x = x - fx * ln2
+    auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src);
+
+    // compute polynomial
+    auto vec_res =
+        _mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4);
+    vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3);
+    vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2);
+    vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1);
+    vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one);
+
+    // compute 2^(n-1)
+    auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one);
+    auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
+    auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
+    vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
+    auto vec_two_pow_n = (__m512)vec_two_pow_n_i;
+    vec_two_pow_n =
+        _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);
+
+    // y = y * 2^n
+    vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n);
+    vec_res = _mm512_mul_ps(vec_res, vec_two);
+    return vec_res;
+  }
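+  // exp_u20() uses standard range reduction: n = round(x * log2(e)) and
+  // r = x - n * ln(2), so exp(x) = 2^n * exp(r), with exp(r) evaluated by the
+  // degree-5 polynomial above. 2^(n-1) is materialized by writing (n-1) + 127
+  // into the float exponent field and the trailing multiply by 2 restores the
+  // missing factor (presumably to keep the biased exponent in range for
+  // results near FLT_MAX); inputs below ln(FLT_MIN) have their scale forced
+  // to zero so the result underflows to 0.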
+  Vectorized fmod(const Vectorized& q) const {
+    return Vectorized(Sleef_fmodf16(values, q));
+  }
+  Vectorized log() const {
+    return Vectorized(Sleef_logf16_u10(values));
+  }
+  Vectorized log2() const {
+    return Vectorized(Sleef_log2f16_u10(values));
+  }
+  Vectorized log10() const {
+    return Vectorized(Sleef_log10f16_u10(values));
+  }
+  Vectorized log1p() const {
+    return Vectorized(Sleef_log1pf16_u10(values));
+  }
+  Vectorized frac() const;
+  Vectorized sin() const {
+    return Vectorized(Sleef_sinf16_u35(values));
+  }
+  Vectorized sinh() const {
+    return Vectorized(Sleef_sinhf16_u10(values));
+  }
+  Vectorized cos() const {
+    return Vectorized(Sleef_cosf16_u35(values));
+  }
+  Vectorized cosh() const {
+    return Vectorized(Sleef_coshf16_u10(values));
+  }
+  Vectorized ceil() const {
+    return _mm512_ceil_ps(values);
+  }
+  Vectorized floor() const {
+    return _mm512_floor_ps(values);
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    return Vectorized(Sleef_hypotf16_u05(values, b));
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    __at_align__ float tmp[size()];
+    __at_align__ float tmp_x[size()];
+    store(tmp);
+    x.store(tmp_x);
+    for (const auto i : c10::irange(size())) {
+      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
+    }
+    return loadu(tmp);
+  }
+  Vectorized neg() const {
+    return _mm512_xor_ps(_mm512_set1_ps(-0.f), values);
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    return Vectorized(Sleef_nextafterf16(values, b));
+  }
+  Vectorized round() const {
+    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  }
+  Vectorized tan() const {
+    return Vectorized(Sleef_tanf16_u10(values));
+  }
+  Vectorized tanh() const {
+    return Vectorized(Sleef_tanhf16_u10(values));
+  }
+  Vectorized trunc() const {
+    return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
+  }
+  Vectorized lgamma() const {
+    return Vectorized(Sleef_lgammaf16_u10(values));
+  }
+  Vectorized sqrt() const {
+    return _mm512_sqrt_ps(values);
+  }
+  Vectorized reciprocal() const {
+    return _mm512_div_ps(_mm512_set1_ps(1), values);
+  }
+  Vectorized rsqrt() const {
+    return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values));
+  }
+  Vectorized pow(const Vectorized &b) const {
+    return Vectorized(Sleef_powf16_u10(values, b));
+  }
+  // Comparison using the _CMP_**_OQ predicate.
+  //   `O`: get false if an operand is NaN
+  //   `Q`: do not raise if an operand is NaN
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_UQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized operator>(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized operator>=(const Vectorized& other) const {
+    auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ);
+    return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask,
+                                                      0xFFFFFFFF));
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_add_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_sub_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_mul_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_div_ps(a, b);
+}
+
+// frac. Implement this here so we can use subtraction
+inline Vectorized<float> Vectorized<float>::frac() const {
+  return *this - this->trunc();
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  auto zero_vec = _mm512_set1_epi32(0);
+  auto max = _mm512_max_ps(a, b);
+  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
+  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
+                                                          0xFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  return _mm512_or_ps(max, isnan);
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <>
+Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
+  auto zero_vec = _mm512_set1_epi32(0);
+  auto min = _mm512_min_ps(a, b);
+  auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q);
+  auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask,
+                                                          0xFFFFFFFF));
+  // Exploit the fact that all-ones is a NaN.
+  return _mm512_or_ps(min, isnan);
+}
+
+template <>
+Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
+  return _mm512_min_ps(max, _mm512_max_ps(min, a));
+}
+
+template <>
+Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
+  return _mm512_min_ps(max, a);
+}
+
+template <>
+Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
+  return _mm512_max_ps(min, a);
+}
+
+template <>
+Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_and_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_or_ps(a, b);
+}
+
+template <>
+Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return _mm512_xor_ps(a, b);
+}
+
+inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
+  return (*this == other) & Vectorized<float>(1.0f);
+}
+
+inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
+  return (*this != other) & Vectorized<float>(1.0f);
+}
+
+inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
+  return (*this > other) & Vectorized<float>(1.0f);
+}
+
+inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
+  return (*this >= other) & Vectorized<float>(1.0f);
+}
+
+inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
+  return (*this < other) & Vectorized<float>(1.0f);
+}
+
+inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
+  return (*this <= other) & Vectorized<float>(1.0f);
+}
+
+template <>
+inline void convert(const float* src, float* dst, int64_t n) {
+  int64_t i;
+#pragma unroll
+  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
+    _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i));
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = src[i];
+  }
+}
+
+template <>
+Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
+  return _mm512_fmadd_ps(a, b, c);
+}
+
+template <>
+Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
+  return _mm512_fmsub_ps(a, b, c);
+}
+
+// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle)
+// Used by Inductor CPP codegen
+// Code referred to FBGEMM:
+// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
+// 16 * 6 = 96 instructions
+template<>
+inline void transpose_mxn<float, 16, 16>(
+    const float* src,
+    int64_t ld_src,
+    float* dst,
+    int64_t ld_dst) {
+  // load from src to registers
+  // a: a0  a1  a2  a3  a4  a5  a6  a7  a8  a9  a10 a11 a12 a13 a14 a15
+  // b: b0  b1  b2  b3  b4  b5  b6  b7  b8  b9  b10 b11 b12 b13 b14 b15
+  // c: c0  c1  c2  c3  c4  c5  c6  c7  c8  c9  c10 c11 c12 c13 c14 c15
+  // d: d0  d1  d2  d3  d4  d5  d6  d7  d8  d9  d10 d11 d12 d13 d14 d15
+  // e: e0  e1  e2  e3  e4  e5  e6  e7  e8  e9  e10 e11 e12 e13 e14 e15
+  // f: f0  f1  f2  f3  f4  f5  f6  f7  f8  f9  f10 f11 f12 f13 f14 f15
+  // g: g0  g1  g2  g3  g4  g5  g6  g7  g8  g9  g10 g11 g12 g13 g14 g15
+  // h: h0  h1  h2  h3  h4  h5  h6  h7  h8  h9  h10 h11 h12 h13 h14 h15
+  // i: i0  i1  i2  i3  i4  i5  i6  i7  i8  i9  i10 i11 i12 i13 i14 i15
+  // j: j0  j1  j2  j3  j4  j5  j6  j7  j8  j9  j10 j11 j12 j13 j14 j15
+  // k: k0  k1  k2  k3  k4  k5  k6  k7  k8  k9  k10 k11 k12 k13 k14 k15
+  // l: l0  l1  l2  l3  l4  l5  l6  l7  l8  l9  l10 l11 l12 l13 l14 l15
+  // m: m0  m1  m2  m3  m4  m5  m6  m7  m8  m9  m10 m11 m12 m13 m14 m15
+  // n: n0  n1  n2  n3  n4  n5  n6  n7  n8  n9  n10 n11 n12 n13 n14 n15
+  // o: o0  o1  o2  o3  o4  o5  o6  o7  o8  o9  o10 o11 o12 o13 o14 o15
+  // p: p0  p1  p2  p3  p4  p5  p6  p7  p8  p9  p10 p11 p12 p13 p14 p15
+  __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
+  __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
+  __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
+  __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
+  __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
+  __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
+  __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
+  __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
+  __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
+  __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
+  __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
+  __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
+  __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
+  __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
+  __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
+  __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
+
+  __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
+  // unpacking and interleaving 32-bit elements
+  // a0  b0  a1  b1  a4  b4  a5  b5  a8  b8  a9  b9  a12  b12 a13 b13
+  // a2  b2  a3  b3  a6  b6  a7  b7  a10 b10 a11 b11 a14  b14 a15 b15
+  // c0  d0  c1  d1 ...
+  // c2  d2  c3  d3 ...
+  // e0  f0  e1  f1 ...
+  // e2  f2  e3  f3 ...
+  // g0  h0  g1  h1 ...
+  // g2  h2  g3  h3 ...
+  // i0  ...
+  // i2  ...
+  // k0  ...
+  // k2  ...
+  // m0  ...
+  // m2  ...
+  // o0  ...
+  // o1  ...
+  ta = _mm512_unpacklo_ps(a, b);
+  tb = _mm512_unpackhi_ps(a, b);
+  tc = _mm512_unpacklo_ps(c, d);
+  td = _mm512_unpackhi_ps(c, d);
+  te = _mm512_unpacklo_ps(e, f);
+  tf = _mm512_unpackhi_ps(e, f);
+  tg = _mm512_unpacklo_ps(g, h);
+  th = _mm512_unpackhi_ps(g, h);
+  ti = _mm512_unpacklo_ps(i, j);
+  tj = _mm512_unpackhi_ps(i, j);
+  tk = _mm512_unpacklo_ps(k, l);
+  tl = _mm512_unpackhi_ps(k, l);
+  tm = _mm512_unpacklo_ps(m, n);
+  tn = _mm512_unpackhi_ps(m, n);
+  to = _mm512_unpacklo_ps(o, p);
+  tq = _mm512_unpackhi_ps(o, p);
+
+  // unpacking and interleaving 64-bit elements
+  //  a0  b0  c0  d0  a4  b4  c4  d4  a8  b8  c8  d8  a12 b12 c12 d12
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  e0  f0  g0  h0  e4  f4  g4  h4  e8  f8  g8  h8  e12 f12 g12 h12
+  //  e1  f1  g1  h1 ...
+  //  e2  f2  g2  h2 ...
+  //  e3  f3  g3  h3 ...
+  //  i0  j0  k0  l0 ...
+  //  i1  j1  k1  l1 ...
+  //  i2  j2  k2  l2 ...
+  //  i3  j3  k3  l3 ...
+  //  m0  n0  o0  p0 ...
+  //  m1  n1  o1  p1 ...
+  //  m2  n2  o2  p2 ...
+  //  m3  n3  o3  p3 ...
+  a = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+  b = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+  c = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+  d = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+  e = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+  f = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+  g = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+  h = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+  i = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+  j = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+  k = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+  l = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+  m = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+  n = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+  o = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+  p = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  a8  b8  c8  d8  e0  f0  g0  h0  e8  f8  g8  h8
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  a4  b4  c4  d4 ...
+  //  a5  b5  c5  d5 ...
+  //  a6  b6  c6  d6 ...
+  //  a7  b7  c7  d7 ...
+  //  i0  j0  k0  l0  i8  j8  k8  l8  m0  n0  o0  p0  m8  n8  o8  p8
+  //  i1  j1  k1  l1 ...
+  //  i2  j2  k2  l2 ...
+  //  i3  j3  k3  l3 ...
+  //  i4  j4  k4  l4 ...
+  //  i5  j5  k5  l5 ...
+  //  i6  j6  k6  l6 ...
+  //  i7  j7  k7  l7 ...
+  ta = _mm512_shuffle_f32x4(a, e, 0x88);
+  tb = _mm512_shuffle_f32x4(b, f, 0x88);
+  tc = _mm512_shuffle_f32x4(c, g, 0x88);
+  td = _mm512_shuffle_f32x4(d, h, 0x88);
+  te = _mm512_shuffle_f32x4(a, e, 0xdd);
+  tf = _mm512_shuffle_f32x4(b, f, 0xdd);
+  tg = _mm512_shuffle_f32x4(c, g, 0xdd);
+  th = _mm512_shuffle_f32x4(d, h, 0xdd);
+  ti = _mm512_shuffle_f32x4(i, m, 0x88);
+  tj = _mm512_shuffle_f32x4(j, n, 0x88);
+  tk = _mm512_shuffle_f32x4(k, o, 0x88);
+  tl = _mm512_shuffle_f32x4(l, p, 0x88);
+  tm = _mm512_shuffle_f32x4(i, m, 0xdd);
+  tn = _mm512_shuffle_f32x4(j, n, 0xdd);
+  to = _mm512_shuffle_f32x4(k, o, 0xdd);
+  tq = _mm512_shuffle_f32x4(l, p, 0xdd);
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  ...  o0
+  //  a1  b1  c1  d1  ...  o1
+  //  a2  b2  c2  d2  ...  o2
+  //  a3  b3  c3  d3  ...  o3
+  //  a4  ...
+  //  a5  ...
+  //  a6  ...
+  //  a7  ...
+  //  a8  ...
+  //  a9  ...
+  //  a10 ...
+  //  a11 ...
+  //  a12 ...
+  //  a13 ...
+  //  a14 ...
+  //  a15 b15 c15 d15 ...  o15
+  a = _mm512_shuffle_f32x4(ta, ti, 0x88);
+  b = _mm512_shuffle_f32x4(tb, tj, 0x88);
+  c = _mm512_shuffle_f32x4(tc, tk, 0x88);
+  d = _mm512_shuffle_f32x4(td, tl, 0x88);
+  e = _mm512_shuffle_f32x4(te, tm, 0x88);
+  f = _mm512_shuffle_f32x4(tf, tn, 0x88);
+  g = _mm512_shuffle_f32x4(tg, to, 0x88);
+  h = _mm512_shuffle_f32x4(th, tq, 0x88);
+  i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
+  j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
+  k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
+  l = _mm512_shuffle_f32x4(td, tl, 0xdd);
+  m = _mm512_shuffle_f32x4(te, tm, 0xdd);
+  n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
+  o = _mm512_shuffle_f32x4(tg, to, 0xdd);
+  p = _mm512_shuffle_f32x4(th, tq, 0xdd);
+
+  // store from registers to dst
+  _mm512_storeu_ps(&dst[0 * ld_dst], a);
+  _mm512_storeu_ps(&dst[1 * ld_dst], b);
+  _mm512_storeu_ps(&dst[2 * ld_dst], c);
+  _mm512_storeu_ps(&dst[3 * ld_dst], d);
+  _mm512_storeu_ps(&dst[4 * ld_dst], e);
+  _mm512_storeu_ps(&dst[5 * ld_dst], f);
+  _mm512_storeu_ps(&dst[6 * ld_dst], g);
+  _mm512_storeu_ps(&dst[7 * ld_dst], h);
+  _mm512_storeu_ps(&dst[8 * ld_dst], i);
+  _mm512_storeu_ps(&dst[9 * ld_dst], j);
+  _mm512_storeu_ps(&dst[10 * ld_dst], k);
+  _mm512_storeu_ps(&dst[11 * ld_dst], l);
+  _mm512_storeu_ps(&dst[12 * ld_dst], m);
+  _mm512_storeu_ps(&dst[13 * ld_dst], n);
+  _mm512_storeu_ps(&dst[14 * ld_dst], o);
+  _mm512_storeu_ps(&dst[15 * ld_dst], p);
+}
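+// The transpose above proceeds in three stages: _mm512_unpacklo/hi_ps
+// interleaves 32-bit elements from row pairs, the pd-cast unpacklo/hi then
+// interleaves 64-bit pairs so each register holds 4-element column fragments,
+// and the two rounds of _mm512_shuffle_f32x4 (selectors 0x88 / 0xdd pick the
+// even / odd 128-bit lanes) gather those fragments into full 16-element
+// columns before the final stores.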
+
+#endif
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f316ab65166420e40a31bc0a4913ce7c682ae89
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_int.h
@@ -0,0 +1,1459 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/irange.h>
+
+namespace at {
+namespace vec {
+inline namespace CPU_CAPABILITY {
+
+#ifdef CPU_CAPABILITY_AVX512
+
+struct Vectorizedi {
+protected:
+  __m512i values;
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+  static inline __m512i invert(const __m512i& v) {
+    const auto ones = _mm512_set1_epi64(-1);
+    return _mm512_xor_si512(ones, v);
+  }
+public:
+  Vectorizedi() {}
+  Vectorizedi(__m512i v) : values(v) {}
+  operator __m512i() const {
+    return values;
+  }
+};
+
+#else
+
+struct Vectorizedi {};  // dummy definition to make Vectorizedi always defined
+
+#endif // CPU_CAPABILITY_AVX512
+
+#ifdef CPU_CAPABILITY_AVX512
+
+template <>
+class Vectorized<int64_t> : public Vectorizedi {
+private:
+  static const Vectorized ones;
+public:
+  using value_type = int64_t;
+  using size_type = int;
+  static constexpr size_type size() {
+    return 8;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int64_t v) { values = _mm512_set1_epi64(v); }
+  Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4,
+         int64_t val5, int64_t val6, int64_t val7, int64_t val8) {
+    values = _mm512_setr_epi64(val1, val2, val3, val4,
+                                val5, val6, val7, val8);
+  }
+  template 
+  static Vectorized blend(Vectorized a, Vectorized b) {
+    return _mm512_mask_blend_epi64(mask, a.values, b.values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
+    auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi64(mask_, a.values, b.values);
+  }
+  template 
+  static Vectorized arange(int64_t base = 0, step_t step = static_cast(1)) {
+    return Vectorized(base,            base + step,     base + 2 * step, base + 3 * step,
+                           base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int64_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr) {
+    return _mm512_loadu_si512(reinterpret_cast(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    if (count == size()) {
+      return _mm512_loadu_si512(reinterpret_cast(ptr));
+    } else {
+      __mmask8 mask = (1ULL << count) - 1;
+      return _mm512_maskz_loadu_epi64(mask, ptr);
+    }
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
+    } else if (count > 0) {
+      __mmask8 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_epi64(ptr, mask, values);
+    }
+  }
+  const int64_t& operator[](int idx) const  = delete;
+  int64_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values);
+    auto is_larger = _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF);
+    auto inverse = _mm512_xor_si512(values, is_larger);
+    return _mm512_sub_epi64(inverse, is_larger);
+  }
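+  // Branchless absolute value: is_larger is all-ones exactly in the lanes
+  // where the value is negative; XOR with it flips the bits, and subtracting
+  // it (i.e. subtracting -1) adds one, completing the two's-complement
+  // negation. A scalar sketch, for exposition only:
+  //   int64_t abs_ref(int64_t v) {
+  //     int64_t m = v >> 63;     // 0 for v >= 0, -1 for v < 0
+  //     return (v ^ m) - m;
+  //   }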
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_epi64(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmpeq_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmpneq_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmplt_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmple_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    auto mask = _mm512_cmpgt_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    auto mask = _mm512_cmpge_epi64_mask(values, other.values);
+    return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF);
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+class Vectorized<int32_t> : public Vectorizedi {
+private:
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+  static const Vectorized ones;
+public:
+  using value_type = int32_t;
+  static constexpr int size() {
+    return 16;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int32_t v) { values = _mm512_set1_epi32(v); }
+  Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
+            int32_t val5, int32_t val6, int32_t val7, int32_t val8,
+            int32_t val9, int32_t val10, int32_t val11, int32_t val12,
+            int32_t val13, int32_t val14, int32_t val15, int32_t val16) {
+    values = _mm512_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8,
+                               val9, val10, val11, val12, val13, val14, val15, val16);
+  }
+  template 
+  static Vectorized blend(Vectorized a, Vectorized b) {
+    return _mm512_mask_blend_epi32(mask, a.values, b.values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    auto msb_one = _mm512_set1_epi32(0xFFFFFFFF);
+    auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi32(mask_, a.values, b.values);
+  }
+  template 
+  static Vectorized arange(int32_t base = 0, step_t step = static_cast(1)) {
+    return Vectorized(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int32_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<1>(a, b);
+      case 2:
+        return blend<3>(a, b);
+      case 3:
+        return blend<7>(a, b);
+      case 4:
+        return blend<15>(a, b);
+      case 5:
+        return blend<31>(a, b);
+      case 6:
+        return blend<63>(a, b);
+      case 7:
+        return blend<127>(a, b);
+      case 8:
+        return blend<255>(a, b);
+      case 9:
+        return blend<511>(a, b);
+      case 10:
+        return blend<1023>(a, b);
+      case 11:
+        return blend<2047>(a, b);
+      case 12:
+        return blend<4095>(a, b);
+      case 13:
+        return blend<8191>(a, b);
+      case 14:
+        return blend<16383>(a, b);
+      case 15:
+        return blend<32767>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr) {
+    return _mm512_loadu_si512(reinterpret_cast(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int32_t count) {
+    if (count == size()) {
+      return _mm512_loadu_si512(reinterpret_cast(ptr));
+    } else {
+      __mmask16 mask = (1ULL << count) - 1;
+      return _mm512_maskz_loadu_epi32(mask, ptr);
+    }
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
+    } else if (count > 0) {
+      __mmask16 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_epi32(ptr, mask, values);
+    }
+  }
+  const int32_t& operator[](int idx) const  = delete;
+  int32_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    return _mm512_abs_epi32(values);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_epi32(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmpeq_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmpneq_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmplt_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmple_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    auto mask = _mm512_cmpgt_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    auto mask = _mm512_cmpge_epi32_mask(values, other.values);
+    return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);
+  }
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+inline void convert(const int32_t *src, float *dst, int64_t n) {
+  int64_t i;
+  // int32_t and float have same size
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    auto input_vec = _mm512_loadu_si512(reinterpret_cast(src + i));
+    auto output_vec = _mm512_cvtepi32_ps(input_vec);
+    _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec);
+  }
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (; i < n; i++) {
+    dst[i] = static_cast(src[i]);
+  }
+}
+
+template <>
+inline void convert(const int32_t *src, double *dst, int64_t n) {
+  int64_t i;
+  // int32_t has half the size of double
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) {
+    auto input_256_vec = _mm256_loadu_si256(reinterpret_cast(src + i));
+    auto output_vec = _mm512_cvtepi32_pd(input_256_vec);
+    _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec);
+  }
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (; i < n; i++) {
+    dst[i] = static_cast(src[i]);
+  }
+}
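+// Note: the int32 -> double conversion widens, so each iteration loads eight
+// int32_t through a 256-bit load and produces eight doubles in one 512-bit
+// register; the scalar loop copies whatever elements remain.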
+
+template <>
+class Vectorized<int16_t> : public Vectorizedi {
+private:
+  static const Vectorized ones;
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+public:
+  using value_type = int16_t;
+  static constexpr int size() {
+    return 32;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized() {}
+  Vectorized(int16_t v) { values = _mm512_set1_epi16(v); }
+  Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
+         int16_t val5, int16_t val6, int16_t val7, int16_t val8,
+         int16_t val9, int16_t val10, int16_t val11, int16_t val12,
+         int16_t val13, int16_t val14, int16_t val15, int16_t val16,
+         int16_t val17, int16_t val18, int16_t val19, int16_t val20,
+         int16_t val21, int16_t val22, int16_t val23, int16_t val24,
+         int16_t val25, int16_t val26, int16_t val27, int16_t val28,
+         int16_t val29, int16_t val30, int16_t val31, int16_t val32) {
+    values = _mm512_set_epi16(val32, val31, val30, val29, val28, val27, val26, val25,
+                              val24, val23, val22, val21, val20, val19, val18, val17,
+                              val16, val15, val14, val13, val12, val11, val10, val9,
+                              val8, val7, val6, val5, val4, val3, val2, val1);
+  }
+  template 
+  static Vectorized blend(Vectorized a, Vectorized b) {
+    return _mm512_mask_blend_epi16(mask, a.values, b.values);
+  }
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                                const Vectorized& mask) {
+    auto msb_one = _mm512_set1_epi16(0xFFFF);
+    auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
+  }
+  template 
+  static Vectorized arange(int16_t base = 0, step_t step = static_cast(1)) {
+    return Vectorized(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
+      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
+      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
+      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
+      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step
+    );
+  }
+  static Vectorized
+  set(Vectorized a, Vectorized b, int16_t count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<0x1>(a, b);
+      case 2:
+        return blend<0x3>(a, b);
+      case 3:
+        return blend<0x7>(a, b);
+      case 4:
+        return blend<0xF>(a, b);
+      case 5:
+        return blend<0x1F>(a, b);
+      case 6:
+        return blend<0x3F>(a, b);
+      case 7:
+        return blend<0x7F>(a, b);
+      case 8:
+        return blend<0xFF>(a, b);
+      case 9:
+        return blend<0x1FF>(a, b);
+      case 10:
+        return blend<0x3FF>(a, b);
+      case 11:
+        return blend<0x7FF>(a, b);
+      case 12:
+        return blend<0xFFF>(a, b);
+      case 13:
+        return blend<0x1FFF>(a, b);
+      case 14:
+        return blend<0x3FFF>(a, b);
+      case 15:
+        return blend<0x7FFF>(a, b);
+      case 16:
+        return blend<0xFFFF>(a, b);
+      case 17:
+        return blend<0x1FFFF>(a, b);
+      case 18:
+        return blend<0x3FFFF>(a, b);
+      case 19:
+        return blend<0x7FFFF>(a, b);
+      case 20:
+        return blend<0xFFFFF>(a, b);
+      case 21:
+        return blend<0x1FFFFF>(a, b);
+      case 22:
+        return blend<0x3FFFFF>(a, b);
+      case 23:
+        return blend<0x7FFFFF>(a, b);
+      case 24:
+        return blend<0xFFFFFF>(a, b);
+      case 25:
+        return blend<0x1FFFFFF>(a, b);
+      case 26:
+        return blend<0x3FFFFFF>(a, b);
+      case 27:
+        return blend<0x7FFFFFF>(a, b);
+      case 28:
+        return blend<0xFFFFFFF>(a, b);
+      case 29:
+        return blend<0x1FFFFFFF>(a, b);
+      case 30:
+        return blend<0x3FFFFFFF>(a, b);
+      case 31:
+        return blend<0x7FFFFFFF>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr) {
+    return _mm512_loadu_si512(reinterpret_cast(ptr));
+  }
+  static Vectorized loadu(const void* ptr, int16_t count) {
+    if (count == size()) {
+      return _mm512_loadu_si512(reinterpret_cast(ptr));
+    } else {
+      __mmask32 mask = (1ULL << count) - 1;
+      return _mm512_maskz_loadu_epi16(mask, ptr);
+    }
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
+    } else if (count > 0) {
+      __mmask32 mask = (1ULL << count) - 1;
+      _mm512_mask_storeu_epi16(ptr, mask, values);
+    }
+  }
+  const int16_t& operator[](int idx) const  = delete;
+  int16_t& operator[](int idx)  = delete;
+  Vectorized abs() const {
+    return _mm512_abs_epi16(values);
+  }
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_epi16(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+  Vectorized neg() const;
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmpeq_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmpneq_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmplt_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmple_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    auto mask = _mm512_cmpgt_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    auto mask = _mm512_cmpge_epi16_mask(values, other.values);
+    return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF);
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <typename T>
+class Vectorized8 : public Vectorizedi {
+  static_assert(
+    std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+    "Only int8_t/uint8_t are supported");
+protected:
+  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
+  static const Vectorized ones;
+public:
+  using value_type = T;
+  static constexpr int size() {
+    return 64;
+  }
+  using Vectorizedi::Vectorizedi;
+  Vectorized8() {}
+  Vectorized8(T v) { values = _mm512_set1_epi8(v); }
+  Vectorized8(T val1, T val2, T val3, T val4,
+         T val5, T val6, T val7, T val8,
+         T val9, T val10, T val11, T val12,
+         T val13, T val14, T val15, T val16,
+         T val17, T val18, T val19, T val20,
+         T val21, T val22, T val23, T val24,
+         T val25, T val26, T val27, T val28,
+         T val29, T val30, T val31, T val32,
+         T val33, T val34, T val35, T val36,
+         T val37, T val38, T val39, T val40,
+         T val41, T val42, T val43, T val44,
+         T val45, T val46, T val47, T val48,
+         T val49, T val50, T val51, T val52,
+         T val53, T val54, T val55, T val56,
+         T val57, T val58, T val59, T val60,
+         T val61, T val62, T val63, T val64){
+    values = _mm512_set_epi8(val64, val63, val62, val61, val60, val59, val58, val57,
+                              val56, val55, val54, val53,val52, val51, val50, val49,
+                              val48, val47, val46, val45, val44, val43, val42, val41,
+                              val40, val39, val38, val37, val36, val35, val34, val33,
+                              val32, val31, val30, val29, val28, val27, val26, val25,
+                              val24, val23, val22, val21, val20, val19, val18, val17,
+                              val16, val15, val14, val13, val12, val11, val10, val9,
+                              val8, val7, val6, val5, val4, val3, val2, val1);
+  }
+  template 
+  static Vectorized blend(Vectorized a, Vectorized b) {
+    return _mm512_mask_blend_epi8(mask, a.values, b.values);
+  }
+  template 
+  static Vectorized arange(T base = 0, step_t step = static_cast(1)) {
+    return Vectorized(
+      base,             base +      step, base +  2 * step, base +  3 * step,
+      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
+      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
+      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
+      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
+      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
+      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
+      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step,
+      base + 32 * step, base + 33 * step, base + 34 * step, base + 35 * step,
+      base + 36 * step, base + 37 * step, base + 38 * step, base + 39 * step,
+      base + 40 * step, base + 41 * step, base + 42 * step, base + 43 * step,
+      base + 44 * step, base + 45 * step, base + 46 * step, base + 47 * step,
+      base + 48 * step, base + 49 * step, base + 50 * step, base + 51 * step,
+      base + 52 * step, base + 53 * step, base + 54 * step, base + 55 * step,
+      base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step,
+      base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step);
+  }
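+  // set(a, b, count) below is equivalent to blend<(1ULL << count) - 1>(a, b):
+  // the first `count` lanes are taken from b and the rest from a (e.g.
+  // set(a, b, 3) yields {b[0], b[1], b[2], a[3], ..., a[63]}). The switch is
+  // only needed because the blend mask must be a compile-time constant.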
+  static Vectorized
+  set(Vectorized a, Vectorized b, T count = size()) {
+    switch (count) {
+      case 0:
+        return a;
+      case 1:
+        return blend<0x1>(a, b);
+      case 2:
+        return blend<0x3>(a, b);
+      case 3:
+        return blend<0x7>(a, b);
+      case 4:
+        return blend<0xF>(a, b);
+      case 5:
+        return blend<0x1F>(a, b);
+      case 6:
+        return blend<0x3F>(a, b);
+      case 7:
+        return blend<0x7F>(a, b);
+      case 8:
+        return blend<0xFF>(a, b);
+      case 9:
+        return blend<0x1FF>(a, b);
+      case 10:
+        return blend<0x3FF>(a, b);
+      case 11:
+        return blend<0x7FF>(a, b);
+      case 12:
+        return blend<0xFFF>(a, b);
+      case 13:
+        return blend<0x1FFF>(a, b);
+      case 14:
+        return blend<0x3FFF>(a, b);
+      case 15:
+        return blend<0x7FFF>(a, b);
+      case 16:
+        return blend<0xFFFF>(a, b);
+      case 17:
+        return blend<0x1FFFF>(a, b);
+      case 18:
+        return blend<0x3FFFF>(a, b);
+      case 19:
+        return blend<0x7FFFF>(a, b);
+      case 20:
+        return blend<0xFFFFF>(a, b);
+      case 21:
+        return blend<0x1FFFFF>(a, b);
+      case 22:
+        return blend<0x3FFFFF>(a, b);
+      case 23:
+        return blend<0x7FFFFF>(a, b);
+      case 24:
+        return blend<0xFFFFFF>(a, b);
+      case 25:
+        return blend<0x1FFFFFF>(a, b);
+      case 26:
+        return blend<0x3FFFFFF>(a, b);
+      case 27:
+        return blend<0x7FFFFFF>(a, b);
+      case 28:
+        return blend<0xFFFFFFF>(a, b);
+      case 29:
+        return blend<0x1FFFFFFF>(a, b);
+      case 30:
+        return blend<0x3FFFFFFF>(a, b);
+      case 31:
+        return blend<0x7FFFFFFF>(a, b);
+      case 32:
+        return blend<0xFFFFFFFF>(a, b);
+      case 33:
+        return blend<0x1FFFFFFFF>(a, b);
+      case 34:
+        return blend<0x3FFFFFFFF>(a, b);
+      case 35:
+        return blend<0x7FFFFFFFF>(a, b);
+      case 36:
+        return blend<0xFFFFFFFFF>(a, b);
+      case 37:
+        return blend<0x1FFFFFFFFF>(a, b);
+      case 38:
+        return blend<0x3FFFFFFFFF>(a, b);
+      case 39:
+        return blend<0x7FFFFFFFFF>(a, b);
+      case 40:
+        return blend<0xFFFFFFFFFF>(a, b);
+      case 41:
+        return blend<0x1FFFFFFFFFF>(a, b);
+      case 42:
+        return blend<0x3FFFFFFFFFF>(a, b);
+      case 43:
+        return blend<0x7FFFFFFFFFF>(a, b);
+      case 44:
+        return blend<0xFFFFFFFFFFF>(a, b);
+      case 45:
+        return blend<0x1FFFFFFFFFFF>(a, b);
+      case 46:
+        return blend<0x3FFFFFFFFFFF>(a, b);
+      case 47:
+        return blend<0x7FFFFFFFFFFF>(a, b);
+      case 48:
+        return blend<0xFFFFFFFFFFFF>(a, b);
+      case 49:
+        return blend<0x1FFFFFFFFFFFF>(a, b);
+      case 50:
+        return blend<0x3FFFFFFFFFFFF>(a, b);
+      case 51:
+        return blend<0x7FFFFFFFFFFFF>(a, b);
+      case 52:
+        return blend<0xFFFFFFFFFFFFF>(a, b);
+      case 53:
+        return blend<0x1FFFFFFFFFFFFF>(a, b);
+      case 54:
+        return blend<0x3FFFFFFFFFFFFF>(a, b);
+      case 55:
+        return blend<0x7FFFFFFFFFFFFF>(a, b);
+      case 56:
+        return blend<0xFFFFFFFFFFFFFF>(a, b);
+      case 57:
+        return blend<0x1FFFFFFFFFFFFFF>(a, b);
+      case 58:
+        return blend<0x3FFFFFFFFFFFFFF>(a, b);
+      case 59:
+        return blend<0x7FFFFFFFFFFFFFF>(a, b);
+      case 60:
+        return blend<0xFFFFFFFFFFFFFFF>(a, b);
+      case 61:
+        return blend<0x1FFFFFFFFFFFFFFF>(a, b);
+      case 62:
+        return blend<0x3FFFFFFFFFFFFFFF>(a, b);
+      case 63:
+        return blend<0x7FFFFFFFFFFFFFFF>(a, b);
+    }
+    return b;
+  }
+  static Vectorized loadu(const void* ptr) {
+    return _mm512_loadu_si512(reinterpret_cast(ptr));
+  }
+  static Vectorized loadu_one_fourth(const void* ptr) {
+      // Fast path for loading only 16 elements.
+      // Note: this is not merged into loadu(const void* ptr, T count) as a fast
+      // path, because that overload requires the upper 384 bits to be
+      // zero-initialized, whereas _mm512_castsi128_si512 leaves the upper 384
+      // bits of the result undefined.
+      // TODO: use _mm512_zextsi128_si512 in the future, once compilers support
+      // it (gcc 9.3 does not).
+      __m128i input_128 = _mm_loadu_si128(reinterpret_cast(ptr));
+      return _mm512_castsi128_si512(input_128);
+  }
+  static Vectorized loadu(const void* ptr, T count) {
+    if (count == size()) {
+      return _mm512_loadu_si512(reinterpret_cast(ptr));
+    } else if (count == 16) {
+      // Fast path for loading only 16 elements
+      return loadu_one_fourth(ptr);
+    } else {
+      __mmask64 mask = (1ULL << count) - 1;
+      return _mm512_maskz_loadu_epi8(mask, ptr);
+    }
+  }
+  void store(void* ptr, int count = size()) const {
+    if (count == size()) {
+      // ptr need not be aligned here. See
+      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
+    } else if (count > 0) {
+      if (count == 16) {
+        // Fast path for storing only 16 elements
+        _mm_storeu_si128(
+          reinterpret_cast<__m128i*>(ptr),
+          _mm512_castsi512_si128(values));
+      } else {
+        __mmask64 mask = (1ULL << count) - 1;
+        _mm512_mask_storeu_epi8(ptr, mask, values);
+      }
+    }
+  }
+  const T& operator[](int idx) const  = delete;
+  T& operator[](int idx)  = delete;
+  Vectorized real() const {
+    return *this;
+  }
+  Vectorized imag() const {
+    return _mm512_set1_epi8(0);
+  }
+  Vectorized conj() const {
+    return *this;
+  }
+};
+
+template<>
+class Vectorized: public Vectorized8 {
+public:
+  using Vectorized8::Vectorized8;
+
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                               const Vectorized& mask) {
+    auto msb_one = _mm512_set1_epi8(0xFF);
+    auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi8(mask_, a.values, b.values);
+  }
+
+  Vectorized neg() const;
+
+  Vectorized abs() const {
+    return _mm512_abs_epi8(values);
+  }
+
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmpeq_epi8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmpneq_epi8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmplt_epi8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmple_epi8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return other < *this;
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return other <= *this;
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template<>
+class Vectorized: public Vectorized8 {
+public:
+  using Vectorized8::Vectorized8;
+
+  static Vectorized blendv(const Vectorized& a, const Vectorized& b,
+                               const Vectorized& mask) {
+    auto msb_one = _mm512_set1_epi8(0xFF);
+    auto mask_ = _mm512_cmp_epu8_mask(mask, msb_one, _MM_CMPINT_EQ);
+    return _mm512_mask_blend_epi8(mask_, a.values, b.values);
+  }
+
+  Vectorized neg() const;
+
+  Vectorized abs() const {
+    return *this;
+  }
+
+  Vectorized operator==(const Vectorized& other) const {
+    auto mask = _mm512_cmpeq_epu8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator!=(const Vectorized& other) const {
+    auto mask = _mm512_cmpneq_epu8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator<(const Vectorized& other) const {
+    auto mask = _mm512_cmplt_epu8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator<=(const Vectorized& other) const {
+    auto mask = _mm512_cmple_epu8_mask(values, other.values);
+    return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF);
+  }
+  Vectorized operator>(const Vectorized& other) const {
+    return other < *this;
+  }
+  Vectorized operator>=(const Vectorized& other) const {
+    return other <= *this;
+  }
+
+  Vectorized eq(const Vectorized& other) const;
+  Vectorized ne(const Vectorized& other) const;
+  Vectorized gt(const Vectorized& other) const;
+  Vectorized ge(const Vectorized& other) const;
+  Vectorized lt(const Vectorized& other) const;
+  Vectorized le(const Vectorized& other) const;
+};
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm512_add_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm512_add_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm512_add_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm512_add_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator+(const Vectorized& a, const Vectorized& b) {
+  return _mm512_add_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sub_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sub_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sub_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sub_epi8(a, b);
+}
+
+template <>
+Vectorized inline operator-(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sub_epi8(a, b);
+}
+
+// Negation. Defined here so we can utilize operator-
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+inline Vectorized Vectorized::neg() const {
+  return Vectorized(0) - *this;
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return _mm512_mullo_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return _mm512_mullo_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  return _mm512_mullo_epi16(a, b);
+}
+
+template <class T, typename Op>
+Vectorized<T> inline int_elementwise_binary_512(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
+  T values_a[Vectorized<T>::size()];
+  T values_b[Vectorized<T>::size()];
+  a.store(values_a);
+  b.store(values_b);
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    values_a[i] = op(values_a[i], values_b[i]);
+  }
+  return Vectorized<T>::loadu(values_a);
+}
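+
+// int_elementwise_binary_512 is the lane-by-lane scalar fallback used below for
+// operations without a direct AVX512 instruction (integer division and, in some
+// build configurations, 8-bit multiplication). Conceptually, e.g.
+//   auto c = int_elementwise_binary_512(a, b, std::divides<int32_t>());
+// divides two int32 vectors element by element through temporary arrays.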
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  // We don't have an instruction for multiplying int8_t
+#ifndef CPU_CAPABILITY_AVX512
+  return int_elementwise_binary_512(a, b, std::multiplies());
+#else
+  __m512i mask00FF = _mm512_set1_epi16(0x00FF);
+  __m512i a_lo = _mm512_srai_epi16(_mm512_slli_epi16(a, 8), 8);
+  __m512i b_lo = _mm512_srai_epi16(_mm512_slli_epi16(b, 8), 8);
+  __m512i a_hi = _mm512_srai_epi16(a, 8);
+  __m512i b_hi = _mm512_srai_epi16(b, 8);
+  __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF);
+  __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8);
+  __m512i res = _mm512_or_si512(res_hi, res_lo);
+  return res;
+#endif
+}
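+
+// The emulation above computes, per lane, the truncated product
+// (int8_t)(a[i] * b[i]): the low and high bytes of each 16-bit pair are
+// sign-extended, multiplied with _mm512_mullo_epi16, and only the low 8 bits of
+// each product are kept. For example, (-3) * 50 = -150, which truncates to 106
+// in two's complement.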
+
+template <>
+Vectorized inline operator*(const Vectorized& a, const Vectorized& b) {
+  // We don't have an instruction for multiplying uint8_t
+#ifndef CPU_CAPABILITY_AVX512
+  return int_elementwise_binary_512(a, b, std::multiplies());
+#else
+  __m512i mask00FF = _mm512_set1_epi16(0x00FF);
+  __m512i a_lo = _mm512_and_si512 (a, mask00FF);
+  __m512i b_lo = _mm512_and_si512 (b, mask00FF);
+  __m512i a_hi = _mm512_srli_epi16(a, 8);
+  __m512i b_hi = _mm512_srli_epi16(b, 8);
+  __m512i res_lo = _mm512_and_si512(_mm512_mullo_epi16(a_lo, b_lo), mask00FF);
+  __m512i res_hi = _mm512_slli_epi16(_mm512_mullo_epi16(a_hi, b_hi), 8);
+  __m512i res = _mm512_or_si512(res_hi, res_lo);
+  return res;
+#endif
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_min_epi64(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_min_epi32(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_min_epi16(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_min_epi8(a, b);
+}
+
+template <>
+Vectorized inline minimum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_min_epu8(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_max_epi64(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_max_epi32(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_max_epi16(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_max_epi8(a, b);
+}
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return _mm512_max_epi8(a, b);
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val));
+}
+
+template <>
+Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) {
+  return _mm512_min_epu8(max_val, _mm512_max_epu8(a, min_val));
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm512_min_epi64(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm512_min_epi32(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm512_min_epi16(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm512_min_epi8(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) {
+  return _mm512_min_epu8(max_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm512_max_epi64(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm512_max_epi32(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm512_max_epi16(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm512_max_epi8(min_val, a);
+}
+
+template <>
+Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) {
+  return _mm512_max_epu8(min_val, a);
+}
+
+template
+Vectorized inline convert_to_int32(const T* ptr) {
+  return Vectorized::loadu(ptr);
+}
+
+template<>
+Vectorized inline convert_to_int32(const int8_t* ptr) {
+  return _mm512_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast(ptr)));
+}
+
+template<>
+Vectorized inline convert_to_int32(const uint8_t* ptr) {
+  return _mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast(ptr)));
+}
+
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_512(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_512(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_512(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_512(a, b, std::divides());
+}
+template <>
+Vectorized inline operator/(const Vectorized& a, const Vectorized& b) {
+  return int_elementwise_binary_512(a, b, std::divides());
+}
+
+template>::value, int> = 0>
+inline Vectorized operator&(const Vectorized& a, const Vectorized& b) {
+  return _mm512_and_si512(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator|(const Vectorized& a, const Vectorized& b) {
+  return _mm512_or_si512(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator^(const Vectorized& a, const Vectorized& b) {
+  return _mm512_xor_si512(a, b);
+}
+template>::value, int> = 0>
+inline Vectorized operator~(const Vectorized& a) {
+  return _mm512_xor_si512(a, _mm512_set1_epi32(-1));
+}
+
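+// The operator==/!=/</<=/>/>= overloads return lane masks (all ones in a lane
+// where the comparison holds, zero otherwise); the eq/ne/gt/ge/lt/le members
+// below AND that mask with Vectorized(1) so each lane instead holds a boolean
+// 0 or 1, mirroring the floating-point Vectorized types.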
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::eq(const Vectorized& other) const {
+  return (*this == other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ne(const Vectorized& other) const {
+  return (*this != other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::gt(const Vectorized& other) const {
+  return (*this > other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::ge(const Vectorized& other) const {
+  return (*this >= other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::lt(const Vectorized& other) const {
+  return (*this < other) & Vectorized(1);
+}
+
+inline Vectorized Vectorized::le(const Vectorized& other) const {
+  return (*this <= other) & Vectorized(1);
+}
+
+template ::value || std::is_same::value, int> = 0>
+Vectorized inline shift_512_8(const Vectorized& a, const Vectorized& b) {
+  // There is no vector instruction for shifting int8_t/uint8_t, so the shift
+  // is emulated instead.
+
+  // Control masks for shuffle operation, treating 512 bits as an
+  // array of 8-bit elements, and considering pairs of neighboring
+  // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
+  // M!=N) is set so that shuffle will move element with index M from
+  // input pair into element with index N in output pair, and element
+  // with index M in output pair will be set to all 0s.
+  __m512i ctl_0_1 = _mm512_set_epi8(62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80,
+                                    54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80,
+                                    46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80,
+                                    38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80,
+                                    30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80,
+                                    22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80,
+                                    14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80,
+                                    6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80);
+  __m512i ctl_1_0 = _mm512_set_epi8(0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57,
+                                    0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49,
+                                    0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41,
+                                    0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33,
+                                    0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25,
+                                    0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17,
+                                    0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9,
+                                    0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1);
+
+  // Masks for bitwise and operation, treating 512 bits as an array of
+  // 8-bit elements, and considering them in pairs of neighboring
+  // elements.  A mask named "keep_M" (M in [0,1]) is set so that
+  // bitwise and will copy element with index M from input pair into
+  // element with the same index in output pair, while the other
+  // element in output pair will be set to all 0s.
+  __m512i keep_0 = _mm512_set1_epi16(0xFF);
+  __m512i keep_1 = _mm512_set1_epi16(0xFF00);
+
+  // Take each 8-bit element with idx%2==0 from the input array to be
+  // shifted and extend it to 16 bits so that 0s are added to the
+  // right.  Then, perform shifting on this 16-bit number.  The upper 8
+  // bits will be the proper result of shifting the original 8-bit
+  // number, so write them to the result array, into the same position
+  // from which the corresponding input element was taken.  Also, make
+  // sure that result array elements with idx%2!=0 are set to all 0s.
+  //
+  // Note that the number of bits to shift by is extended to 16 bits by
+  // adding 0s to the left.  That means this number is not properly
+  // sign-extended for negative values.  However, the number of bits to
+  // shift by is treated as an unsigned integer by the respective shift
+  // intrinsics anyway, so with or without proper sign extension a
+  // negative count will be interpreted as a number greater than 32, and
+  // the shifting result will be the same.
+  __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1);
+  __m512i b0 = _mm512_and_si512(b, keep_0);
+  __m512i c0;
+  if (left_shift)
+    c0 = _mm512_sllv_epi16(a0, b0);
+  else
+    if constexpr (std::is_same_v)
+      c0 = _mm512_srav_epi16(a0, b0);
+    else
+      c0 = _mm512_srlv_epi16(a0, b0);
+  c0 = _mm512_shuffle_epi8(c0, ctl_1_0);
+
+  // Perform shifting the same way for input array elements with
+  // idx%2==1.
+  __m512i a1 = _mm512_and_si512(a, keep_1);
+  __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0);
+  __m512i c1;
+  if (left_shift)
+    c1 = _mm512_sllv_epi16(a1, b1);
+  else
+    if constexpr (std::is_same_v)
+      c1 = _mm512_srav_epi16(a1, b1);
+    else
+      c1 = _mm512_srlv_epi16(a1, b1);
+  c1 = _mm512_and_si512(c1, keep_1);
+
+  // Merge partial results into the final result.
+  __m512i c = _mm512_or_si512(c0, c1);
+
+  return c;
+}
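+
+// Scalar sketch of the emulation above: each 8-bit lane is widened into the high
+// byte of a 16-bit lane, shifted there, and the high byte of the result is
+// written back, so per lane it behaves roughly like
+//   c[i] = left_shift ? T(a[i] << b[i]) : T(a[i] >> b[i]);
+// with an arithmetic right shift (sign fill) for int8_t and a logical right
+// shift (zero fill) for uint8_t.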
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sllv_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sllv_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return _mm512_sllv_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return shift_512_8(a, b);
+}
+
+template <>
+Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) {
+  return shift_512_8(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return _mm512_srav_epi64(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return _mm512_srav_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return _mm512_srav_epi16(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return shift_512_8(a, b);
+}
+
+template <>
+Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) {
+  return shift_512_8(a, b);
+}
+
+#endif
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h
new file mode 100644
index 0000000000000000000000000000000000000000..ffe9ada5c13f50492d4d69428b164db941da39c9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec512/vec512_qint.h
@@ -0,0 +1,1346 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/native/quantized/AffineQuantizerBase.h>
+
+#include <c10/util/irange.h>
+#include <c10/util/qint32.h>
+#include <c10/util/qint8.h>
+#include <c10/util/quint8.h>
+
+#include <array>
+#include <cmath>
+
+// This file defines Vectorized<> for the quantized types.
+//
+//
+// Currently, we simply use these classes as efficient converters between
+// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
+// where doing the arithmetic in full-precision is acceptable (e.g.
+// elementwise operators).
+//
+//
+// Conversions are as follows:
+//  Vectorized<qint8> -> 4x Vectorized<float>
+//  Vectorized<quint8> -> 4x Vectorized<float>
+//  Vectorized<qint32> -> 1x Vectorized<float>
+//
+// The size of the returned float vector is specified by the special
+// constexpr function float_num_vecs. The type of the value returned
+// from dequantize (and expected as an argument to quantize) is
+// specified by float_vec_return_type.
+//
+// When writing kernels with these vectors, it is expected that floating-
+// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
+// iterations.
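+//
+// A typical kernel therefore follows a pattern roughly like the sketch below
+// (the variable names are illustrative, not part of this header):
+//
+//   auto qvec   = Vectorized<c10::quint8>::loadu(src);
+//   auto floats = qvec.dequantize(scale_vec, zero_point_vec);  // float_num_vecs() vectors
+//   for (auto& fv : floats) {
+//     fv = fv * fv;  // any elementwise computation in float
+//   }
+//   Vectorized<c10::quint8>::quantize(floats, scale, zero_point, 1.0f / scale)
+//       .store(dst);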
+
+namespace at {
+namespace vec {
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
+
+struct Vectorizedqi {
+ protected:
+  __m512i vals __attribute__((aligned(64)));
+
+ public:
+  Vectorizedqi() {}
+  Vectorizedqi(__m512i v) : vals(v) {}
+  operator __m512i() const {
+    return vals;
+  }
+};
+
+
+template 
+__m512i pack_saturate_and_clamp(
+    __m512i first,
+    __m512i second,
+    T min_val,
+    T max_val);
+
+template <>
+inline __m512i pack_saturate_and_clamp(
+    __m512i first,
+    __m512i second,
+    int32_t min_val,
+    int32_t max_val) {
+  // This function is for linkage only; it will not be used.
+  AT_ERROR("pack_saturate_and_clamp is not supported");
+}
+
+template <>
+inline __m512i pack_saturate_and_clamp(
+    __m512i first,
+    __m512i second,
+    int8_t min_val,
+    int8_t max_val) {
+  __m512i packed_and_sat = _mm512_packs_epi16(first, second);
+  return _mm512_max_epi8(
+      _mm512_set1_epi8(min_val),
+      _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val)));
+}
+
+template <>
+inline __m512i pack_saturate_and_clamp(
+    __m512i first,
+    __m512i second,
+    uint8_t min_val,
+    uint8_t max_val) {
+  __m512i packed_and_sat = _mm512_packus_epi16(first, second);
+  return _mm512_max_epu8(
+      _mm512_set1_epi8(min_val),
+      _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val)));
+}
+
+template 
+typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type
+inline convert_int8_to_float(at::vec::Vectorized src) {
+  // Note: this function only converts a number of input elements equal to
+  // at::vec::Vectorized<float>::size(); only the first 16*8 bits are handled.
+  __m128i input_128 = _mm512_castsi512_si128(src);
+  // Convert from 16*uint8/int8 to 16*int32
+  __m512i input_512_extended;
+  if constexpr (std::is_same_v)
+    input_512_extended = _mm512_cvtepu8_epi32(input_128);
+  else
+    input_512_extended = _mm512_cvtepi8_epi32(input_128);
+  // Convert from 16*int32 to 16*float32
+  return _mm512_cvtepi32_ps(input_512_extended);
+}
+
+template 
+typename std::enable_if::value || std::is_same::value, at::vec::Vectorized>::type
+inline convert_float_to_int8(at::vec::Vectorized src) {
+  // Convert from float32 to int32 with truncation
+  __m512i x_values_int32 = _mm512_cvttps_epi32(src);
+
+  // Convert from int32 to int16 using signed saturation
+  __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32);
+
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+
+  // Convert from int16 to uint8/int8 using unsigned saturation
+  __m512i xyzw_clamped_v = pack_saturate_and_clamp(
+      xy_packed_v, xy_packed_v, min_val, max_val);
+  __m512i permute_mask_v =
+      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02,
+                      0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);
+  return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
+}
+
+template 
+inline void __attribute__((always_inline)) QuantizeAvx512(
+    const float* src,
+    T* dst,
+    int len,
+    float inverse_scale,
+    int64_t zero_point) {
+  constexpr int VLEN = 16;
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+  const __m512i min_v = _mm512_set1_epi32(min_val);
+  const __m512i max_v = _mm512_set1_epi32(max_val);
+  // This is the largest int32 value < int32_max exactly representable in float
+  constexpr int32_t int32_float_max_val =
+      std::numeric_limits::max() - 127;
+  int i = 0;
+  __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale);
+  // clang-format off
+  static const __m512i shuffle_mask_v = _mm512_set_epi8(
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0xff, 0xff, 0xff, 0xff,
+      0x0c, 0x08, 0x04, 0x00);
+  // clang-format on
+  __m512i permute_mask_v =
+      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02,
+                       0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);
+  __m512i permute_mask_l8_v =
+      _mm512_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                       0x00, 0x00, 0x00, 0x00, 0x0c, 0x08, 0x04, 0x00);
+  int len_aligned = len / (VLEN * 4) * (VLEN * 4);
+  for (; i < len_aligned; i += 4 * VLEN) {
+    // x
+    __m512 x_vals = _mm512_load_ps(src + i);
+    __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v);
+    // If the floating point value is greater than int32_max,
+    // _mm512_cvtps_epi32 converts it to a negative value. Clip at
+    // int32_float_max_val to avoid this.
+    x_transformed_v =
+        _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val));
+    // y
+    __m512 y_vals = _mm512_load_ps(src + i + VLEN);
+    __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v);
+    y_transformed_v =
+        _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val));
+    // z
+    __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN);
+    __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v);
+    z_transformed_v =
+        _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val));
+    // w
+    __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN);
+    __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v);
+    w_transformed_v =
+        _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val));
+
+    __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v);
+    __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v);
+    __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v);
+    __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v);
+
+    // add zero point
+    x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point));
+    y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point));
+    z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point));
+    w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point));
+
+    __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v);
+    __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v);
+    __m512i xyzw_clamped_v =
+        pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val);
+
+    xyzw_clamped_v =
+        _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v);
+  }
+
+  // Additional single-vector (16-lane) pass to handle the remainder when len is
+  // not a multiple of 4 * VLEN, based on fbgemm::QuantizeAvx2
+  // (https://github.com/pytorch/FBGEMM)
+  for (; i < len / VLEN * VLEN; i += VLEN) {
+    __m512 x_vals = _mm512_load_ps(src + i);
+    __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v);
+    x_transformed_v =
+        _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val));
+    __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v);
+    x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point));
+    __m512i x_clipped_v =
+        _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v));
+
+    x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v);
+    x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v);
+    _mm_storeu_si128(
+        reinterpret_cast<__m128i*>(dst + i),
+        _mm512_castsi512_si128(x_clipped_v));
+  }
+
+  for (; i < len; ++i) {
+    float transformed = src[i] * inverse_scale;
+
+    // Not exactly the same behavior as the vectorized code.
+    // The vectorized code above always rounds to even in halfway cases
+    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
+    // does the same only when the current rounding mode is FE_TONEAREST.
+    // However, in practice, this should not be a problem because most cases
+    // use the default rounding mode FE_TONEAREST.
+    // Note that we cannot implement the same behavior as the vectorized code
+    // using std::round because it does rounding away from zero in halfway
+    // cases.
+    transformed = zero_point + std::nearbyint(transformed);
+    float clipped =
+        std::min(std::max(transformed, float(min_val)), float(max_val));
+    dst[i] = clipped;
+  }
+}
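+
+// Worked example for the scalar tail above (the vector path matches it up to
+// rounding mode): with inverse_scale = 10 (i.e. scale = 0.1) and zero_point = 3,
+// src[i] = 0.74 gives transformed = 3 + nearbyint(7.4) = 10, which is then
+// clamped to T's [min_val, max_val] range before being stored.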
+
+template<>
+struct Vectorized : public Vectorizedqi {
+    using size_type = int;
+    static constexpr size_type size() {
+        return 16;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 1;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 1;
+    }
+
+    using float_vec_return_type = std::array, 1>;
+    using int_vec_return_type = std::array, 1>;
+    using value_type = c10::qint32::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+    Vectorized() {}
+
+    Vectorized(__m512i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::qint32& val) {
+        value_type uw = val.val_;
+        vals = _mm512_set1_epi32(uw);
+    }
+
+    void store(void* ptr, int count = size()) const {
+      if (count != size()) {
+        memcpy(ptr, &vals, count * sizeof(value_type));
+      } else {
+        _mm512_storeu_si512((__m512i*)ptr, vals);
+      }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return loadu(tmp_values);
+    }
+
+    float_vec_return_type dequantize(
+        Vectorized scale,
+        Vectorized zero_point,
+        Vectorized scale_zp_premul) const {
+      __m512 float_vals = _mm512_cvtepi32_ps(vals);
+      return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)};
+    }
+
+    float_vec_return_type dequantize(
+        Vectorized scale,
+        Vectorized zero_point) const {
+      __m512 float_vals = _mm512_cvtepi32_ps(vals);
+      return {(Vectorized(float_vals) - zero_point) * scale};
+    }
+
+    static Vectorized quantize(
+        const float_vec_return_type& rhs,
+        float scale,
+        int32_t zero_point,
+        float inverse_scale) {
+      Vectorized retval;
+      auto rhs_data = (__m512)rhs[0];
+      at::native::quantize_vec(
+          scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16);
+      return retval;
+    }
+
+    Vectorized maximum(Vectorized b) const {
+      return _mm512_max_epi32(vals, b.vals);
+    }
+
+    Vectorized minimum(Vectorized b) const {
+      return _mm512_min_epi32(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm512_min_epi32(
+          _mm512_max_epi32(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      return {_mm512_sub_epi32(vals, b)};
+    }
+
+    static Vectorized requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m512 multiplier_v = _mm512_set1_ps(multiplier);
+      __m512i zero_point_v = _mm512_set1_epi32(zero_point);
+
+      __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v);
+      __m512i rounded = _mm512_cvtps_epi32(scaled);
+      return _mm512_add_epi32(rounded, zero_point_v);
+    }
+
+ private:
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+      vals = _mm512_loadu_si512((const __m512i*)ptr);
+    }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized inline operator*(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return _mm512_mullo_epi32(a, b);
+}
+
+template <>
+Vectorized inline operator+(
+    const Vectorized& a,
+    const Vectorized& b) {
+  return _mm512_add_epi32(a, b);
+}
+
+/*
+ * Convert values from int32 back to int8/uint8
+ */
+template <typename T>
+__m512i RequantizeAvx512(
+    const std::array, 4>& inp,
+    __m512 multiplier,
+    __m512i zp) {
+  static_assert(
+      std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+      "Only int8_t/uint8_t are supported");
+  constexpr auto min_val = std::numeric_limits::min();
+  constexpr auto max_val = std::numeric_limits::max();
+  __m512i permute_mask_v =
+      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02,
+                       0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);
+  __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier);
+  __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier);
+  __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier);
+  __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier);
+
+  __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v);
+  __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v);
+  __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v);
+  __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v);
+
+  /* Add zero point */
+  __m512i x_v = _mm512_add_epi32(x_rounded_v, zp);
+  __m512i y_v = _mm512_add_epi32(y_rounded_v, zp);
+  __m512i z_v = _mm512_add_epi32(z_rounded_v, zp);
+  __m512i w_v = _mm512_add_epi32(w_rounded_v, zp);
+
+  /* Pack to int16_t and saturate */
+  __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v);
+  __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v);
+
+  __m512i xyzw_clamped_v =
+      pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val);
+
+  /*
+   * xyzw_clamped_v has results in the following layout so we need to
+   * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 x12-15 y12-15 z12-15 w12-15
+   */
+  xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
+  return xyzw_clamped_v;
+}
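+
+// Per element, RequantizeAvx512 computes
+//   out[i] = clamp<T>(zero_point + round(inp[i] * multiplier)),
+// i.e. the usual requantization of int32 accumulators back to int8_t/uint8_t.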
+
+template<>
+struct Vectorized : public Vectorizedqi {
+    static constexpr int size() {
+        return 64;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 4;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 4;
+    }
+
+    using float_vec_return_type = std::array, 4>;
+    using int_vec_return_type = std::array, 4>;
+    using value_type = typename c10::qint8::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+
+    Vectorized() {}
+    Vectorized(__m512i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::qint8& val) {
+        value_type uw = val.val_;
+        vals = _mm512_set1_epi8(uw);
+    }
+
+    // This is needed because the compiler emits awful code for the default
+    // constructor for moving the enum
+    Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { }
+
+    // This is added to avoid error: definition of implicit copy assignment operator
+    // for 'Vectorized' is deprecated because it has a user-declared
+    // copy constructor [-Werror,-Wdeprecated-copy]
+    Vectorized& operator=(const Vectorized&) = default;
+
+    void store(void* ptr, int count = size()) const {
+        if (count != size()) {
+            memcpy(ptr, &vals, count * sizeof(value_type));
+        } else {
+            _mm512_storeu_si512((__m512i*)ptr, vals);
+        }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return loadu(tmp_values);
+    }
+
+ private:
+    __m512i cvtepi8_epi32(__m128i epi8_vals) const {
+        return _mm512_cvtepi8_epi32(epi8_vals);
+    }
+
+ public:
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point,
+      Vectorized scale_neg_zp_premul) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 =
+        vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul);
+    auto val1 =
+        vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul);
+    auto val2 =
+        vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul);
+    auto val3 =
+        vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul);
+    return {val0, val1, val2, val3};
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    auto* rhs_data = (float*)rhs.data();
+    int8_t quantized_values[64];
+    QuantizeAvx512(
+        rhs_data, quantized_values, 64, inverse_scale, zero_point);
+    return Vectorized::loadu(quantized_values);
+  }
+
+  Vectorized maximum(Vectorized b) const {
+      return _mm512_max_epi8(vals, b.vals);
+    }
+
+  Vectorized minimum(Vectorized b) const {
+      return _mm512_min_epi8(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm512_min_epi8(
+          _mm512_max_epi8(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+      __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+      __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+      __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+      __m512i int32_val0 = cvtepi8_epi32(int_val0);
+      __m512i int32_val1 = cvtepi8_epi32(int_val1);
+      __m512i int32_val2 = cvtepi8_epi32(int_val2);
+      __m512i int32_val3 = cvtepi8_epi32(int_val3);
+
+      __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
+      __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
+      __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
+      __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
+
+      __m512i int32_b0 = cvtepi8_epi32(int_b0);
+      __m512i int32_b1 = cvtepi8_epi32(int_b1);
+      __m512i int32_b2 = cvtepi8_epi32(int_b2);
+      __m512i int32_b3 = cvtepi8_epi32(int_b3);
+
+      __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0);
+      __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1);
+      __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2);
+      __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3);
+
+      return {Vectorized(res_0),
+              Vectorized(res_1),
+              Vectorized(res_2),
+              Vectorized(res_3)};
+    }
+
+    static Vectorized requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m512 multiplier_v = _mm512_set1_ps(multiplier);
+      __m512i zero_point_v = _mm512_set1_epi32(zero_point);
+      return RequantizeAvx512(inp, multiplier_v, zero_point_v);
+    }
+
+ private:
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+        vals = _mm512_loadu_si512((const __m512i*)ptr);
+    }
+};
+
+template <>
+Vectorized inline maximum(const Vectorized& a, const Vectorized& b) {
+  return a.maximum(b);
+}
+
+template<>
+struct Vectorized : public Vectorizedqi {
+    static constexpr int size() {
+        return 64;
+    }
+
+    static constexpr int float_num_vecs() {
+        return 4;
+    }
+
+    static constexpr int int_num_vecs() {
+        return 4;
+    }
+
+    using float_vec_return_type = std::array, 4>;
+    using int_vec_return_type = std::array, 4>;
+    using value_type = typename c10::quint8::underlying;
+
+ public:
+    using Vectorizedqi::Vectorizedqi;
+    Vectorized() {}
+
+    Vectorized(__m512i vals_) { vals = vals_;}
+
+    // Broadcast constructor
+    Vectorized(const c10::quint8& val) {
+        value_type uw = val.val_;
+        vals = _mm512_set1_epi8(uw);
+    }
+
+    Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { }
+
+    // This is added to avoid error: definition of implicit copy assignment operator
+    // for 'Vectorized' is deprecated because it has a user-declared
+    // copy constructor [-Werror,-Wdeprecated-copy]
+    Vectorized& operator=(const Vectorized&) = default;
+
+    void store(void* ptr, int count = size()) const {
+        if (count != size()) {
+            memcpy(ptr, &vals, count * sizeof(value_type));
+        } else {
+            _mm512_storeu_si512((__m512i*)ptr, vals);
+        }
+    }
+
+    static Vectorized loadu(const void* ptr) {
+        return Vectorized(ptr);
+    }
+
+    static Vectorized loadu(const void* ptr, int64_t count) {
+        __at_align__ value_type tmp_values[size()];
+        // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
+        // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+        // instructions while a loop would be compiled to one instruction.
+        for (const auto i : c10::irange(size())) {
+          tmp_values[i] = 0;
+        }
+        std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type));
+        return loadu(tmp_values);
+    }
+
+ private:
+    __m512i cvtepu8_epi32(__m128i epu8_vals) const {
+        return _mm512_cvtepu8_epi32(epu8_vals);
+    }
+
+ public:
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point,
+      Vectorized scale_zp_premul) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 =
+        vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul);
+    auto val1 =
+        vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul);
+    auto val2 =
+        vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul);
+    auto val3 =
+        vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul);
+
+    return {val0, val1, val2, val3};
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+
+    return {val0, val1, val2, val3};
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    auto* rhs_data = (float*)rhs.data();
+    uint8_t quantized_values[64];
+    QuantizeAvx512(
+        rhs_data, quantized_values, 64, inverse_scale, zero_point);
+    return Vectorized::loadu(quantized_values);
+  }
+
+  Vectorized maximum(Vectorized b) const {
+      return _mm512_max_epu8(vals, b.vals);
+    }
+
+  Vectorized minimum(Vectorized b) const {
+      return _mm512_min_epu8(vals, b.vals);
+    }
+
+    Vectorized relu(Vectorized zero_point) const {
+        return maximum(zero_point);
+    }
+
+    Vectorized relu6(
+        Vectorized zero_point,
+        Vectorized q_six) {
+      return _mm512_min_epu8(
+          _mm512_max_epu8(vals, zero_point.vals), q_six.vals);
+    }
+
+    int_vec_return_type widening_subtract(Vectorized b) const {
+      __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
+      __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+      __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+      __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+      __m512i int32_val0 = cvtepu8_epi32(int_val0);
+      __m512i int32_val1 = cvtepu8_epi32(int_val1);
+      __m512i int32_val2 = cvtepu8_epi32(int_val2);
+      __m512i int32_val3 = cvtepu8_epi32(int_val3);
+
+      __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
+      __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
+      __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
+      __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
+
+      __m512i int32_b0 = cvtepu8_epi32(int_b0);
+      __m512i int32_b1 = cvtepu8_epi32(int_b1);
+      __m512i int32_b2 = cvtepu8_epi32(int_b2);
+      __m512i int32_b3 = cvtepu8_epi32(int_b3);
+
+      __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0);
+      __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1);
+      __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2);
+      __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3);
+      return {Vectorized<c10::qint32>(res_0),
+              Vectorized<c10::qint32>(res_1),
+              Vectorized<c10::qint32>(res_2),
+              Vectorized<c10::qint32>(res_3)};
+    }
+
+    static Vectorized<c10::quint8> requantize_from_int(
+        const int_vec_return_type& inp,
+        float multiplier,
+        int32_t zero_point) {
+      __m512 multiplier_v = _mm512_set1_ps(multiplier);
+      __m512i zero_point_v = _mm512_set1_epi32(zero_point);
+      return RequantizeAvx512<c10::quint8>(inp, multiplier_v, zero_point_v);
+    }
+
+ private:
+
+    // Load from memory constructor
+    Vectorized(const void* ptr) {
+        vals = _mm512_loadu_si512((const __m512i*)ptr);
+    }
+};
+
+template <>
+Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
+  return a.maximum(b);
+}
+
+#else
+
+// NOTE: These are low-performance implementations that we fall back on.
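+// As a rough illustration (not part of the upstream header), the per-element
+// affine (de)quantization math these fallbacks rely on is:
+//
+//   float dequantize_one(float scale, int32_t zero_point, int32_t q) {
+//     return scale * static_cast<float>(q - zero_point);
+//   }
+//   int32_t quantize_one(float inv_scale, int32_t zero_point, float x,
+//                        int32_t qmin, int32_t qmax) {
+//     int32_t q = static_cast<int32_t>(std::nearbyint(x * inv_scale)) + zero_point;
+//     return std::min(std::max(q, qmin), qmax);
+//   }
+//
+// at::native::dequantize_val and at::native::quantize_vec, used below, apply
+// this logic with the appropriate rounding and saturation for each quantized type.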
+
+template <
+    typename T,
+    typename float_vec_return_type_,
+    typename int_vec_return_type_,
+    int size_>
+struct VectorizedQuantizedConverter {
+  static constexpr int size() {
+    return size_;
+  }
+
+  static constexpr int float_num_vecs() {
+    return size() / 8;
+  }
+
+  static constexpr int int_num_vecs() {
+    return size() / 8;
+  }
+
+  using float_vec_return_type = float_vec_return_type_;
+  using int_vec_return_type = int_vec_return_type_;
+
+  using value_type = typename T::underlying;
+  std::array<value_type, size_> vals;
+
+  VectorizedQuantizedConverter(T val) {
+    for (const auto i : c10::irange(size())) {
+      vals[i] = val.val_;
+    }
+  }
+
+  VectorizedQuantizedConverter(const void* ptr) {
+    memcpy(vals.data(), ptr, sizeof(value_type) * size());
+  }
+
+  void store(void* ptr, int count = size()) const {
+    memcpy(ptr, vals.data(), count * sizeof(value_type));
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point,
+      Vectorized<float> scale_zp_premul) const {
+    float_vec_return_type rv;
+    for (const auto i : c10::irange(float_num_vecs())) {
+      float tmp_vals[16];
+      for (const auto j : c10::irange(16)) {
+        tmp_vals[j] = at::native::dequantize_val(
+            scale[j], zero_point[j], T(vals[16 * i + j]));
+      }
+      rv[i] = Vectorized<float>(tmp_vals[0],
+          tmp_vals[1],
+          tmp_vals[2],
+          tmp_vals[3],
+          tmp_vals[4],
+          tmp_vals[5],
+          tmp_vals[6],
+          tmp_vals[7],
+          tmp_vals[8],
+          tmp_vals[9],
+          tmp_vals[10],
+          tmp_vals[11],
+          tmp_vals[12],
+          tmp_vals[13],
+          tmp_vals[14],
+          tmp_vals[15]);
+    }
+    return rv;
+  }
+
+  float_vec_return_type dequantize(
+      Vectorized<float> scale,
+      Vectorized<float> zero_point) const {
+    Vectorized<float> scale_zp_premul;
+    return dequantize(scale, zero_point, scale_zp_premul);
+  }
+
+ protected:
+  VectorizedQuantizedConverter() {}
+};
+
+template <>
+struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
+                                 c10::qint32,
+                                 std::array<Vectorized<float>, 1>,
+                                 std::array<Vectorized<c10::qint32>, 1>,
+                                 16> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array<Vectorized<float>, 1>,
+            std::array<Vectorized<c10::qint32>, 1>,
+            16>() {}
+  Vectorized(c10::qint32 val)
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array<Vectorized<float>, 1>,
+            std::array<Vectorized<c10::qint32>, 1>,
+            16>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::qint32,
+            std::array<Vectorized<float>, 1>,
+            std::array<Vectorized<c10::qint32>, 1>,
+            16>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
+    return loadu(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    std::array<value_type, size()> qvals;
+    std::array<float, float_num_vecs() * 16> float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 16], 16);
+    }
+
+    at::native::quantize_vec<c10::qint32, /*precision=*/32>(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::qint32*)qvals.data(),
+        16 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const  {
+    return maximum(zero_point);
+  }
+
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    for (const auto i : c10::irange(size())) {
+      retval[0].vals[i] = vals[i] - b.vals[i];
+    }
+    return retval;
+  }
+
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] =
+          std::nearbyint(static_cast<float>(inp[0].vals[i]) * multiplier) +
+          zero_point;
+    }
+    return retval;
+  }
+};
+
+template <>
+Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
+  return a.maximum(b);
+}
+
+template <>
+Vectorized<c10::qint32> inline operator*(
+    const Vectorized<c10::qint32>& a,
+    const Vectorized<c10::qint32>& b) {
+  Vectorized<c10::qint32> retval;
+  for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
+    retval.vals[i] = a.vals[i] * b.vals[i];
+  }
+  return retval;
+}
+
+template <>
+Vectorized<c10::qint32> inline operator+(
+    const Vectorized<c10::qint32>& a,
+    const Vectorized<c10::qint32>& b) {
+  Vectorized<c10::qint32> retval;
+  for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
+    retval.vals[i] = a.vals[i] + b.vals[i];
+  }
+  return retval;
+}
+
+template <>
+struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
+                                c10::qint8,
+                                std::array<Vectorized<float>, 4>,
+                                std::array<Vectorized<c10::qint32>, 4>,
+                                64> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>() {}
+  Vectorized(c10::qint8 val)
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::qint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
+    return loadu(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    std::array<value_type, size()> qvals;
+    std::array<float, float_num_vecs() * 16> float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 16], 16);
+    }
+
+    at::native::quantize_vec<c10::qint8>(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::qint8*)qvals.data(),
+        16 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const {
+    return maximum(zero_point);
+  }
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        retval[i].vals[j] =
+            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
+            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
+      }
+    }
+    return retval;
+  }
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    constexpr auto min_val = std::numeric_limits<value_type>::min();
+    constexpr auto max_val = std::numeric_limits<value_type>::max();
+    Vectorized retval;
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        int32_t rounded =
+            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
+            zero_point;
+        retval.vals[i * elem_per_int_vec + j] =
+            std::min(std::max(rounded, min_val), max_val);
+      }
+    }
+    return retval;
+  }
+};
+
+template <>
+Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
+  return a.maximum(b);
+}
+
+template <>
+struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
+                                 c10::quint8,
+                                 std::array<Vectorized<float>, 4>,
+                                 std::array<Vectorized<c10::qint32>, 4>,
+                                 64> {
+  Vectorized()
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>() {}
+  Vectorized(c10::quint8 val)
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>(val) {}
+  Vectorized(const void* ptr)
+      : VectorizedQuantizedConverter<
+            c10::quint8,
+            std::array<Vectorized<float>, 4>,
+            std::array<Vectorized<c10::qint32>, 4>,
+            64>(ptr) {}
+
+  static Vectorized loadu(const void* ptr) {
+    return Vectorized(ptr);
+  }
+
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    __at_align__ value_type tmp_values[size()];
+    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
+    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
+    // instructions while a loop would be compiled to one instruction.
+    for (const auto i : c10::irange(size())) {
+      tmp_values[i] = 0;
+    }
+    std::memcpy(tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
+    return loadu(tmp_values);
+  }
+
+  static Vectorized quantize(
+      const float_vec_return_type& rhs,
+      float scale,
+      int32_t zero_point,
+      float inverse_scale) {
+    std::array<value_type, size()> qvals;
+    std::array<float, float_num_vecs() * 16> float_vals;
+
+    for (const auto i : c10::irange(float_num_vecs())) {
+      rhs[i].store(&float_vals[i * 16], 16);
+    }
+
+    at::native::quantize_vec<c10::quint8>(
+        scale,
+        zero_point,
+        float_vals.data(),
+        (c10::quint8*)qvals.data(),
+        16 * float_num_vecs());
+
+    return Vectorized::loadu(qvals.data());
+  }
+
+  Vectorized maximum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::max(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized minimum(Vectorized b) const {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(vals[i], b.vals[i]);
+    }
+    return retval;
+  }
+
+  Vectorized relu(Vectorized zero_point) const {
+    return maximum(zero_point);
+  }
+
+
+  Vectorized relu6(
+      Vectorized zero_point,
+      Vectorized q_six) {
+    Vectorized retval;
+    for (const auto i : c10::irange(size())) {
+      retval.vals[i] = std::min(
+          std::max(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+  int_vec_return_type widening_subtract(Vectorized b) const {
+    int_vec_return_type retval;
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        retval[i].vals[j] =
+            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
+            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
+      }
+    }
+    return retval;
+  }
+  static Vectorized requantize_from_int(
+      const int_vec_return_type& inp,
+      float multiplier,
+      int32_t zero_point) {
+    constexpr int elem_per_int_vec = size() / int_num_vecs();
+    constexpr auto min_val = std::numeric_limits<value_type>::min();
+    constexpr auto max_val = std::numeric_limits<value_type>::max();
+    Vectorized retval;
+    for (const auto i : c10::irange(int_num_vecs())) {
+      for (const auto j : c10::irange(elem_per_int_vec)) {
+        int32_t rounded =
+            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
+            zero_point;
+        retval.vals[i * elem_per_int_vec + j] =
+            std::min(std::max(rounded, min_val), max_val);
+      }
+    }
+    return retval;
+  }
+};
+
+template <>
+Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
+  return a.maximum(b);
+}
+
+#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC)
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..85dd7207272310c7be95d5a44772c418e7b20558
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_base.h
@@ -0,0 +1,1108 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with AVX]
+//
+// Note [Do not compile initializers with AVX]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// If you define a static initializer in this file, the initialization will use
+// AVX instructions because these object files are compiled with AVX enabled.
+// We need to avoid non-trivial global data in these architecture specific files
+// because there's no way to guard the global initializers with CPU capability
+// detection.
+//
+// See https://github.com/pytorch/pytorch/issues/37577 for an instance
+// of this bug in the past.
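+//
+// Illustrative only (not part of the upstream header): something like
+//
+//   static const Vectorized<float> kOnes(1.f);   // DON'T -- global initializer
+//
+// would be compiled with AVX/AVX512 instructions in this file and executed at
+// program load, before any runtime CPU-capability check can run, so it could
+// fault on machines that lack the instruction set.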
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <climits>
+#include <cmath>
+#include <cstring>
+#include <functional>
+#include <type_traits>
+
+#include <ATen/cpu/vec/intrinsics.h>
+#include <ATen/native/Math.h>
+#include <ATen/NumericUtils.h>
+#include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>
+#include <c10/util/BFloat16-math.h>
+#include <c10/util/copysign.h>
+#include <ATen/native/cpu/zmath.h>
+#include <c10/util/TypeCast.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/irange.h>
+#include <c10/util/Load.h>
+
+// These macros helped us unify vec_base.h
+#ifdef CPU_CAPABILITY_AVX512
+#if defined(__GNUC__)
+#define __at_align__ __attribute__((aligned(64)))
+#elif defined(_WIN32)
+#define __at_align__ __declspec(align(64))
+#else
+#define __at_align__
+#endif
+#define VECTOR_WIDTH 64
+#define int_vector __m512i
+#else // CPU_CAPABILITY_AVX512
+#if defined(__GNUC__)
+#define __at_align__ __attribute__((aligned(32)))
+#elif defined(_WIN32)
+#define __at_align__ __declspec(align(32))
+#else
+#define __at_align__
+#endif
+#define VECTOR_WIDTH 32
+#define int_vector __m256i
+#endif // CPU_CAPABILITY_AVX512
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+// at::Half and at::BFloat16 should be treated as floating point
+template <typename T>
+struct is_floating_point:
+    std::integral_constant<bool, std::is_floating_point<T>::value ||
+      std::is_same<T, at::Half>::value ||
+      std::is_same<T, at::BFloat16>::value> {
+};
+
+template<typename T>
+constexpr bool is_floating_point_v = is_floating_point<T>::value;
+
+template <typename T>
+struct is_reduced_floating_point:
+    std::integral_constant<bool, std::is_same<T, at::Half>::value ||
+      std::is_same<T, at::BFloat16>::value> {
+};
+
+template <typename T>
+constexpr bool is_reduced_floating_point_v = is_reduced_floating_point<T>::value;
+
+template <size_t n> struct int_of_size;
+
+#define DEFINE_INT_OF_SIZE(int_t) \
+template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }
+
+DEFINE_INT_OF_SIZE(int64_t);
+DEFINE_INT_OF_SIZE(int32_t);
+DEFINE_INT_OF_SIZE(int16_t);
+DEFINE_INT_OF_SIZE(int8_t);
+
+#undef DEFINE_INT_OF_SIZE
+
+template <typename T>
+using int_same_size_t = typename int_of_size<sizeof(T)>::type;
+
+// NOTE: If you specialize on a type, you must define all operations!
+
+// emulates Vectorized types
+#if defined(__s390x__)
+template <class T, class TEMP=void>
+#else
+template <class T>
+#endif
+struct Vectorized {
+private:
+  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
+public:
+  using value_type = T;
+  using size_type = int;
+  // Note [constexpr static function to avoid odr-usage compiler bug]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // Why, you might ask, is size defined to be a static constexpr function,
+  // rather than a more ordinary 'static constexpr int size;' variable?
+  // The problem lies within ODR rules for static constexpr members versus
+  // static constexpr functions.  First, recall that this class (along with all
+  // of its derivations) live in an anonymous namespace: they are intended to be
+  // *completely* inlined at their use-sites, because we need to compile it
+  // multiple times for different instruction sets.
+  //
+  // Because of this constraint, we CANNOT provide a single definition for
+  // any static members in this class; since we want to compile the class
+  // multiple times, there wouldn't actually be any good place to put the
+  // definition.  Now here is the problem: if we ODR-use a static constexpr
+  // member, we are *obligated* to provide a definition.  Without the
+  // definition, you get a compile error like:
+  //
+  //    relocation R_X86_64_PC32 against undefined symbol
+  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
+  //    a shared object; recompile with -fPIC
+  //
+  // If this were C++17, we could replace a static constexpr variable with
+  // an inline variable which doesn't require one definition. But we are not
+  // C++17.  So the next best thing is to replace the member with a static
+  // constexpr (and therefore inline) function, which does not require ODR
+  // either.
+  //
+  // Also, technically according to the C++ standard, we don't have to define
+  // a constexpr variable if we never odr-use it.  But it seems that some
+  // versions GCC/Clang have buggy determinations on whether or not an
+  // identifier is odr-used or not, and in any case it's hard to tell if
+  // a variable is odr-used or not.  So best to just cut the problem at the root.
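+  //
+  // Rough sketch (illustrative only, not from the upstream header) of the two
+  // alternatives discussed above:
+  //
+  //   struct VecA { static constexpr int size = 8; };
+  //     // odr-using VecA::size (e.g. binding it to a const reference) requires
+  //     // an out-of-line definition, which cannot be provided here.
+  //   struct VecB { static constexpr int size() { return 8; } };
+  //     // a constexpr member function is implicitly inline; calling it never
+  //     // odr-uses a variable, so no separate definition is needed.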
+  static constexpr size_type size() {
+    return VECTOR_WIDTH / sizeof(T);
+  }
+  Vectorized() : values{static_cast<T>(0)} {}
+  Vectorized(T val) {
+    for (int i = 0; i != size(); i++) {
+      values[i] = val;
+    }
+  }
+  template<typename... Args,
+           typename = std::enable_if_t<(sizeof...(Args) == size())>>
+  Vectorized(Args... vals) : values{vals...}{
+  }
+  // This also implies const T& operator[](int idx) const
+  inline operator const T*() const {
+    return values;
+  }
+  // This also implies T& operator[](int idx)
+  inline operator T*() {
+    return values;
+  }
+  // Return the values as char* for type punning
+  auto as_bytes() const -> const char* {
+    return reinterpret_cast<const char*>(values);
+  }
+  template <int64_t mask_>
+  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
+    int64_t mask = mask_;
+    Vectorized vector;
+    for (const auto i : c10::irange(size())) {
+      if (mask & 0x01) {
+        vector[i] = b[i];
+      } else {
+        vector[i] = a[i];
+      }
+      mask = mask >> 1;
+    }
+    return vector;
+  }
+  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
+                          const Vectorized<T>& mask) {
+    Vectorized vector;
+    int_same_size_t<T> buffer[size()];
+    mask.store(buffer);
+    for (const auto i : c10::irange(size())) {
+      if (buffer[i] & 0x01)
+       {
+        vector[i] = b[i];
+      } else {
+        vector[i] = a[i];
+      }
+    }
+    return vector;
+  }
+  template <typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
+  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
+    Vectorized vector;
+    for (const auto i : c10::irange(size())) {
+      vector.values[i] = base + i * step;
+    }
+    return vector;
+  }
+  static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) {
+    Vectorized vector;
+    for (const auto i : c10::irange(size())) {
+      if (i < count) {
+        vector[i] = b[i];
+      } else {
+        vector[i] = a[i];
+      }
+    }
+    return vector;
+  }
+  static Vectorized loadu(const void* ptr) {
+    Vectorized vector;
+    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
+    return vector;
+  }
+  static Vectorized loadu(const void* ptr, int64_t count) {
+    Vectorized vector;
+    std::memcpy(vector.values, ptr, count * sizeof(T));
+    return vector;
+  }
+  void store(void* ptr, int count = size()) const {
+    std::memcpy(ptr, values, count * sizeof(T));
+  }
+  int zero_mask() const {
+    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
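+    // Illustration (not from the upstream header): for values {0, 7, 0, 3, ...}
+    // the low bits of the returned mask are 0b...0101, i.e. bit i is set iff
+    // values[i] == 0.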
+    int mask = 0;
+    for (int i = 0; i < size(); ++ i) {
+      if (values[i] == static_cast(0)) {
+        mask |= (1 << i);
+      }
+    }
+    return mask;
+  }
+  Vectorized isnan() const {
+    Vectorized vector;
+    for (int64_t i = 0; i != size(); i++) {
+      if (_isnan(values[i])) {
+        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
+      } else {
+        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
+      }
+    }
+    return vector;
+  }
+  bool has_inf_nan() const {
+    for (int64_t i = 0; i != size(); i++) {
+      if(_isnan(values[i]) || _isinf(values[i])) {
+        return true;
+      }
+    }
+    return false;
+  }
+  Vectorized map(T (*const f)(T)) const {
+    Vectorized ret;
+    for (int64_t i = 0; i != size(); i++) {
+      ret[i] = f(values[i]);
+    }
+    return ret;
+  }
+  Vectorized map(T (*const f)(const T &)) const {
+    Vectorized ret;
+    for (int64_t i = 0; i != size(); i++) {
+      ret[i] = f(values[i]);
+    }
+    return ret;
+  }
+  template  && !c10::is_complex::value, int>::type = 0>
+  Vectorized abs() const {
+    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_abs must be T");
+    return map([](T x) -> T { return x < static_cast(0) ? -x : x; });
+  }
+  template , int>::type = 0>
+  Vectorized abs() const {
+    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "float_t_abs must be T");
+    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
+    // 0.0) properly.
+    return map([](T x) -> T { return std::abs(x); });
+  }
+  template ::value, int>::type = 0>
+  Vectorized abs() const {
+    // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_abs must be T");
+    // Specifically map() does not perform the type conversion needed by abs.
+    return map([](T x) { return static_cast(std::abs(x)); });
+  }
+
+  template ::value, int>::type = 0>
+  Vectorized sgn() const {
+    return map(at::native::sgn_impl);
+  }
+
+  template ::value, int>::type = 0>
+  Vectorized angle() const {
+    // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_angle must be T");
+    return map(at::native::angle_impl);  // compiler is unable to resolve the overload without 
+  }
+  template ::value, int>::type = 0>
+  Vectorized angle() const {
+    // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_angle must be T");
+    return map([](T x) { return static_cast(std::arg(x)); });
+  }
+  template ::value, int>::type = 0>
+  Vectorized real() const {
+    // other_t_real is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_real must be T");
+    return *this;
+  }
+  template ::value, int>::type = 0>
+  Vectorized real() const {
+    // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_real must be T");
+    return map([](T x) { return static_cast(x.real()); });
+  }
+  template ::value, int>::type = 0>
+  Vectorized imag() const {
+    // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_imag must be T");
+    return Vectorized(0);
+  }
+  template ::value, int>::type = 0>
+  Vectorized imag() const {
+    // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_imag must be T");
+    return map([](T x) { return static_cast(x.imag()); });
+  }
+  template ::value, int>::type = 0>
+  Vectorized conj() const {
+    // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_conj must be T");
+    return *this;
+  }
+  template ::value, int>::type = 0>
+  Vectorized conj() const {
+    // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_conj must be T");
+    return map([](T x) { return static_cast(std::conj(x)); });
+  }
+  Vectorized acos() const {
+    return map(std::acos);
+  }
+  Vectorized acosh() const {
+    return map(std::acosh);
+  }
+  Vectorized asin() const {
+    return map(std::asin);
+  }
+  Vectorized atan() const {
+    return map(std::atan);
+  }
+  Vectorized atanh() const {
+    return map(std::atanh);
+  }
+  Vectorized atan2(const Vectorized &exp) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = std::atan2(values[i], exp[i]);
+    }
+    return ret;
+  }
+  template <
+    typename U = T,
+    typename std::enable_if_t, int> = 0>
+  Vectorized copysign(const Vectorized &sign) const {
+    Vectorized ret;
+    for (size_type i = 0; i < size(); i++) {
+      ret[i] = c10::copysign(values[i], sign[i]);
+    }
+    return ret;
+  }
+  Vectorized erf() const {
+    return map(std::erf);
+  }
+  Vectorized erfc() const {
+    return map(std::erfc);
+  }
+  Vectorized erfinv() const {
+    return map(calc_erfinv);
+  }
+  Vectorized exp() const {
+    return map(std::exp);
+  }
+  Vectorized exp2() const {
+    return map(exp2_impl);
+  }
+  Vectorized expm1() const {
+    return map(std::expm1);
+  }
+  Vectorized exp_u20() const {
+    return map(std::exp);
+  }
+  Vectorized frac() const {
+    return *this - this->trunc();
+  }
+  template <
+    typename U = T,
+    typename std::enable_if_t, int> = 0>
+  Vectorized fmod(const Vectorized& q) const {
+    // U is for SFINAE purposes only. Make sure it is not changed.
+    static_assert(std::is_same::value, "U must be T");
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = std::fmod(values[i], q[i]);
+    }
+    return ret;
+  }
+  Vectorized log() const {
+    return map(std::log);
+  }
+  Vectorized log10() const {
+    return map(std::log10);
+  }
+  Vectorized log1p() const {
+    return map(std::log1p);
+  }
+  template ::value, int>::type = 0>
+  Vectorized log2() const {
+    // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "other_t_log2 must be T");
+    return map(std::log2);
+  }
+  template ::value, int>::type = 0>
+  Vectorized log2() const {
+    // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
+    static_assert(std::is_same::value, "complex_t_log2 must be T");
+    const T log_2 = T(std::log(2.0));
+    return Vectorized(map(std::log))/Vectorized(log_2);
+  }
+  Vectorized ceil() const {
+    return map(at::native::ceil_impl);
+  }
+  Vectorized cos() const {
+    return map(std::cos);
+  }
+  Vectorized cosh() const {
+    return map(std::cosh);
+  }
+  Vectorized floor() const {
+    return map(at::native::floor_impl);
+  }
+  Vectorized hypot(const Vectorized &b) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = std::hypot(values[i], b[i]);
+    }
+    return ret;
+  }
+  Vectorized i0() const {
+    return map(calc_i0);
+  }
+  Vectorized i0e() const {
+    return map(calc_i0e);
+  }
+  Vectorized digamma() const {
+    return map(calc_digamma);
+  }
+  Vectorized igamma(const Vectorized &x) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = calc_igamma(values[i], x[i]);
+    }
+    return ret;
+  }
+  Vectorized igammac(const Vectorized &x) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = calc_igammac(values[i], x[i]);
+    }
+    return ret;
+  }
+  Vectorized neg() const {
+    // NB: the trailing return type is needed because we need to coerce the
+    // return value back to T in the case of unary operator- incuring a
+    // promotion
+    return map([](T x) -> T { return -x; });
+  }
+  Vectorized nextafter(const Vectorized &b) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = std::nextafter(values[i], b[i]);
+    }
+    return ret;
+  }
+  Vectorized round() const {
+    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
+    return map(at::native::round_impl);
+  }
+  Vectorized sin() const {
+    return map(std::sin);
+  }
+  Vectorized sinh() const {
+    return map(std::sinh);
+  }
+  Vectorized tan() const {
+    return map(std::tan);
+  }
+  Vectorized tanh() const {
+    return map(std::tanh);
+  }
+  Vectorized trunc() const {
+    return map(at::native::trunc_impl);
+  }
+  Vectorized lgamma() const {
+    return map(std::lgamma);
+  }
+  Vectorized sqrt() const {
+    return map(std::sqrt);
+  }
+  Vectorized reciprocal() const {
+    return map([](T x) { return (T)(1) / x; });
+  }
+  Vectorized rsqrt() const {
+    return map([](T x) { return (T)1 / std::sqrt(x); });
+  }
+  Vectorized pow(const Vectorized &exp) const {
+    Vectorized ret;
+    for (const auto i : c10::irange(size())) {
+      ret[i] = std::pow(values[i], exp[i]);
+    }
+    return ret;
+  }
+private:
+  template <typename Op>
+  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
+    // All bits are set to 1 if the pred is true, otherwise 0.
+    Vectorized<T> vector;
+    for (int64_t i = 0; i != size(); i++) {
+      if (op(values[i], other.values[i])) {
+        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
+      } else {
+        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
+      }
+    }
+    return vector;
+  }
+
+public:
+  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
+  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
+  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
+  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
+  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
+  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }
+
+private:
+  template <typename Op>
+  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
+    // 1 if the pred is true, otherwise 0.
+    Vectorized<T> vector;
+    for (int i = 0; i != size(); ++ i) {
+      vector[i] = static_cast<T>(op(values[i], other.values[i]));
+    }
+    return vector;
+  }
+
+public:
+  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
+  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
+  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
+  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
+  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
+  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
+};
+
+template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] + b[i];
+  }
+  return c;
+}
+
+template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] - b[i];
+  }
+  return c;
+}
+
+template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] * b[i];
+  }
+  return c;
+}
+
+template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] / b[i];
+  }
+  return c;
+}
+
+template <class T, typename std::enable_if<!is_floating_point_v<T>, int>::type = 0>
+Vectorized<T> inline operator%(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
+  return a - a / b * b;
+}
+
+template <class T> Vectorized<T> inline operator||(
+    const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] || b[i];
+  }
+  return c;
+}
+
+// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
+// either input is a NaN.
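+// For example (illustration only): maximum({1.0, NaN, 3.0, ...}, {2.0, 0.0, NaN, ...})
+// yields {2.0, NaN, NaN, ...} -- a NaN in either lane propagates to that lane
+// of the result.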
+template <class T,
+          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = (a[i] > b[i]) ? a[i] : b[i];
+    if (_isnan(a[i])) {
+      // If either input is NaN, propagate a NaN.
+      // NOTE: The case where b[i] was NaN is handled correctly by the naive
+      // ternary operator above.
+      c[i] = a[i];
+    }
+  }
+  return c;
+}
+
+template <class T,
+          typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
+    if (_isnan(a[i])) {
+      // If either input is NaN, propagate a NaN.
+      // NOTE: The case where b[i] was NaN is handled correctly by the naive
+      // ternary operator above.
+      c[i] = a[i];
+    }
+  }
+  return c;
+}
+
+// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
+// either input is a NaN.
+template <class T,
+          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = (a[i] < b[i]) ? a[i] : b[i];
+    if (_isnan(a[i])) {
+      // If either input is NaN, propagate a NaN.
+      // NOTE: The case where b[i] was NaN is handled correctly by the naive
+      // ternary operator above.
+      c[i] = a[i];
+    }
+  }
+  return c;
+}
+
+template <class T,
+          typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
+    if (_isnan(a[i])) {
+      // If either input is NaN, propagate a NaN.
+      // NOTE: The case where b[i] was NaN is handled correctly by the naive
+      // ternary operator above.
+      c[i] = a[i];
+    }
+  }
+  return c;
+}
+
+template <class T,
+          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
+  }
+  return c;
+}
+
+template <class T,
+          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
+  }
+  return c;
+}
+
+template <class T,
+          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
+  Vectorized<T> c;
+  for (int i = 0; i != Vectorized<T>::size(); i++) {
+    c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
+  }
+  return c;
+}
+
+struct Vectorizedi;
+
+#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
+template <class T, typename Op>
+static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
+  int_vector buffer;
+#if defined(CPU_CAPABILITY_AVX2)
+  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
+  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
+#elif defined(CPU_CAPABILITY_AVX512)
+  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
+  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
+#endif
+  buffer = op(a_buffer, b_buffer);
+  __at_align__ T results[Vectorized<T>::size()];
+
+#if defined(CPU_CAPABILITY_AVX2)
+  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
+#elif defined(CPU_CAPABILITY_AVX512)
+  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
+#endif
+  return Vectorized<T>::loadu(results);
+}
+
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
+  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
+#endif
+}
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
+  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
+#endif
+}
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
+  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
+#if defined(CPU_CAPABILITY_AVX2)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
+#elif defined(CPU_CAPABILITY_AVX512)
+  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
+#endif
+}
+
+#else
+
+template <typename T>
+auto load(char const* data) -> T {
+  T ret;
+  std::memcpy(&ret, data, sizeof(ret));
+  return ret;
+}
+
+template<typename T, typename Op>
+static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
+  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
+  __at_align__ intmax_t buffer[element_no];
+  static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
+  static_assert(sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)");
+  // We should be using memcpy in order to respect the strict aliasing rule
+  // see: https://github.com/pytorch/pytorch/issues/66119
+  // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
+  // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
+  const auto* a_data = a.as_bytes();
+  const auto* b_data = b.as_bytes();
+  // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
+  for (auto& out : buffer) {
+    out = op(load(a_data), load(b_data));
+    a_data += sizeof(intmax_t);
+    b_data += sizeof(intmax_t);
+  }
+  assert(a_data == a.as_bytes() + sizeof(a));
+  assert(b_data == b.as_bytes() + sizeof(b));
+  return Vectorized<T>::loadu(buffer);
+}
+
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
+  return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
+}
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
+  return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
+}
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
+  return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
+}
+
+#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
+
+template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
+inline Vectorized<T> operator~(const Vectorized<T>& a) {
+  Vectorized<T> ones;  // All bits are 1
+  memset((T*) ones, 0xFF, VECTOR_WIDTH);
+  return a ^ ones;
+}
+
+template  Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) {
+  constexpr T max_shift = sizeof(T) * CHAR_BIT;
+  Vectorized c;
+  for (int i = 0; i != Vectorized::size(); i++) {
+    T shift = b[i];
+    if ((static_cast>(shift) < 0) || (shift >= max_shift)) {
+      c[i] = 0;
+    } else {
+      c[i] = static_cast>(a[i]) << shift;
+    }
+  }
+  return c;
+}
+
+template  Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) {
+  // right shift value to retain sign bit for signed and no bits for unsigned
+  constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v;
+  Vectorized c;
+  for (int i = 0; i != Vectorized::size(); i++) {
+    T shift = b[i];
+    if ((static_cast>(shift) < 0) || (shift >= max_shift)) {
+      c[i] = a[i] >> max_shift;
+    } else {
+      c[i] = a[i] >> shift;
+    }
+  }
+  return c;
+}
+
+template 
+inline Vectorized& operator += (Vectorized& a, const Vectorized& b) {
+  a = a + b;
+  return a;
+}
+template 
+inline Vectorized& operator -= (Vectorized& a, const Vectorized& b) {
+  a = a - b;
+  return a;
+}
+template 
+inline Vectorized& operator /= (Vectorized& a, const Vectorized& b) {
+  a = a / b;
+  return a;
+}
+template 
+inline Vectorized& operator %= (Vectorized& a, const Vectorized& b) {
+  a = a % b;
+  return a;
+}
+template 
+inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) {
+  a = a * b;
+  return a;
+}
+
+template 
+inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) {
+  a = a << b;
+  return a;
+}
+
+template 
+inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) {
+  a = a >> b;
+  return a;
+}
+
+template <typename T>
+inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
+  return a * b + c;
+}
+
+template <typename T>
+inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
+  return a * b - c;
+}
+
+template 
+std::enable_if_t>
+inline gather(T const* base_addr, const Vectorized>& vindex) {
+  static constexpr int size = Vectorized::size();
+  int_same_size_t index_arr[size];
+  vindex.store(static_cast(index_arr));
+  T buffer[size];
+  for (const auto i : c10::irange(size)) {
+    buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
+  }
+  return Vectorized::loadu(static_cast(buffer));
+}
+
+template 
+std::enable_if_t>
+inline mask_gather(const Vectorized& src, T const* base_addr,
+                   const Vectorized>& vindex, Vectorized& mask) {
+  static constexpr int size = Vectorized::size();
+  T src_arr[size];
+  int_same_size_t mask_arr[size];  // use int type so we can logical and
+  int_same_size_t index_arr[size];
+  src.store(static_cast(src_arr));
+  mask.store(static_cast(mask_arr));
+  vindex.store(static_cast(index_arr));
+  T buffer[size];
+  for (const auto i : c10::irange(size)) {
+    if (mask_arr[i] & 0x01) {  // check highest bit
+      buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
+    } else {
+      buffer[i] = src_arr[i];
+    }
+  }
+  mask = Vectorized();  // "zero out" mask
+  return Vectorized::loadu(static_cast(buffer));
+}
+
+// Cast a given vector to another type without changing the bits representation.
+// So a Vectorized of 512 bits containing all ones can be cast to a
+// Vectorized of 512 bits containing all ones (i.e., eight negative 1s).
+// A Vec of 256 bits containing all ones can be cast to a
+// Vec of 256 bits containing all ones (i.e., four negative 1s).
+// There is a struct here because we don't have static_if and I can't
+// partially specialize a templated function.
+template
+struct CastImpl {
+  static inline Vectorized apply(const Vectorized& src) {
+    src_t src_arr[Vectorized::size()];
+    src.store(static_cast(src_arr));
+    return Vectorized::loadu(static_cast(src_arr));
+  }
+};
+
+template
+struct CastImpl {
+  static inline Vectorized apply(const Vectorized& src) {
+    return src;
+  }
+};
+
+template
+inline Vectorized cast(const Vectorized& src) {
+  return CastImpl::apply(src);
+}
+
+template >
+inline Vectorized convert_to_int_of_same_size(const Vectorized& src) {
+  static_assert(sizeof(T) == sizeof(IntType));
+  static constexpr int size = Vectorized::size();
+
+  std::array src_arr;
+  src.store(static_cast(src_arr.data()));
+  std::array buffer;
+  std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
+                 [](const T& x) { return static_cast(x); });
+  return Vectorized::loadu(static_cast(buffer.data()));
+}
+
+template >
+inline Vectorized convert_to_fp_of_same_size(const Vectorized& src) {
+  static_assert(sizeof(T) == sizeof(IntType));
+  static constexpr int size = Vectorized::size();
+
+  std::array src_arr;
+  src.store(static_cast(src_arr.data()));
+  std::array buffer;
+  std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
+                 [](const IntType& x) { return static_cast(x); });
+  return Vectorized::loadu(static_cast(buffer.data()));
+}
+
+// Example inputs for AVX512:
+// a   Vectorized   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+// b   Vectorized   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+// returns:
+//           Vectorized   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+//           Vectorized   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+// Example inputs for AVX2: a           Vectorized   = {a0, b0, a1, b1, a2, b2, a3, b3}
+//               b                      Vectorized   = {a4, b4, a5, b5, a6, b6, a7, b7}
+//       returns:                       Vectorized   = {a0, a1, a2, a3, a4, a5, a6, a7}
+//                                      Vectorized   = {b0, b1, b2, b3, b4, b5, b6, b7}
+template 
+inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>>
+deinterleave2(const Vectorized& a, const Vectorized& b) {
+  static constexpr int size = Vectorized::size();
+  static constexpr int half_size = size / 2;
+  T a_arr[size];
+  T b_arr[size];
+  T buffer1[size];
+  T buffer2[size];
+  a.store(static_cast(a_arr));
+  b.store(static_cast(b_arr));
+  for (const auto i : c10::irange(half_size)) {
+    buffer1[i] = a_arr[i * 2];
+    buffer1[half_size + i] = b_arr[i * 2];
+    buffer2[i] = a_arr[i * 2 + 1];
+    buffer2[half_size + i] = b_arr[i * 2 + 1];
+  }
+  return std::make_pair(Vectorized::loadu(static_cast(buffer1)),
+                        Vectorized::loadu(static_cast(buffer2)));
+}
+
+// inverse operation of deinterleave2
+// Example inputs for AVX512:
+//  a       Vectorized   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
+//  b       Vectorized   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
+// returns, for AVX512:
+//          Vectorized   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
+//          Vectorized   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
+// Example inputs for AVX2 : a           Vectorized   = {a0, a1, a2, a3, a4, a5, a6, a7}
+//                   b                   Vectorized   = {b0, b1, b2, b3, b4, b5, b6, b7}
+//       returns:            Vectorized   = {a0, b0, a1, b1, a2, b2, a3, b3}
+//                           Vectorized   = {a4, b4, a5, b5, a6, b6, a7, b7}
+template 
+inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>>
+interleave2(const Vectorized& a, const Vectorized& b) {
+  static constexpr int size = Vectorized::size();
+  static constexpr int half_size = size / 2;
+  T a_arr[size];
+  T b_arr[size];
+  T buffer1[size];
+  T buffer2[size];
+  a.store(static_cast(a_arr));
+  b.store(static_cast(b_arr));
+  for (const auto i : c10::irange(half_size)) {
+    buffer1[i * 2] = a_arr[i];
+    buffer1[i * 2 + 1] = b_arr[i];
+    buffer2[i * 2] = a_arr[half_size + i];
+    buffer2[i * 2 + 1] = b_arr[half_size + i];
+  }
+  return std::make_pair(Vectorized::loadu(static_cast(buffer1)),
+                        Vectorized::loadu(static_cast(buffer2)));
+}
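+
+// Illustrative usage (not part of the upstream header): interleave2 and
+// deinterleave2 are exact inverses, which is convenient when switching between
+// an interleaved layout such as {re0, im0, re1, im1, ...} and the split pair
+// ({re0, re1, ...}, {im0, im1, ...}). Given two Vectorized<float> re, im:
+//
+//   auto packed = interleave2(re, im);        // pair of interleaved vectors
+//   auto split  = deinterleave2(packed.first, packed.second);
+//   // split.first == re and split.second == im again, element-wise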
+
+template 
+inline void convert(const src_T *src, dst_T *dst, int64_t n) {
+#ifndef _MSC_VER
+# pragma unroll
+#endif
+  for (C10_UNUSED const auto i : c10::irange(n)) {
+    *dst = c10::convert(c10::load(src));
+    src++;
+    dst++;
+  }
+}
+
+template 
+inline Vectorized flip(const Vectorized & data) {
+  static constexpr int size = Vectorized::size();
+  T output[size];
+  T buffer[size];
+  data.store(static_cast(buffer));
+  for (const auto i : c10::irange(size)) {
+    output[i] = buffer[size - i - 1];
+  }
+  return Vectorized::loadu(static_cast(output));
+}
+
+// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
+// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
+template 
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      dst[j*ld_dst + i] = src[i*ld_src + j];
+    }
+  }
+}
+
+}} // namespace at::vec::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h
new file mode 100644
index 0000000000000000000000000000000000000000..7cdc259581da37601221b2929702afc98938619a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_half.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+
+namespace at::vec {
+// See Note [CPU_CAPABILITY namespace]
+inline namespace CPU_CAPABILITY {
+
+#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
+    !defined(__APPLE__)
+static inline uint16_t float2half_scalar(float val) {
+#if defined(CPU_CAPABILITY_AVX2)
+#if defined(_MSC_VER)
+  __m256 v = _mm256_set1_ps(val);
+  __m128i o =
+      _mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  return static_cast(_mm_cvtsi128_si32(o));
+#else
+  return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+#endif
+#elif defined(CPU_CAPABILITY_AVX512)
+  __m512 v = _mm512_set1_ps(val);
+  __m256i o =
+      _mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+  return static_cast(
+      _mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
+#endif
+}
+
+static inline float half2float_scalar(uint16_t val) {
+#if defined(CPU_CAPABILITY_AVX2)
+#if defined(_MSC_VER)
+  __m128i v = _mm_cvtsi32_si128(val);
+  __m256 o = _mm256_cvtph_ps(v);
+  return _mm256_cvtss_f32(o);
+#else
+  return _cvtsh_ss(val);
+#endif
+#elif defined(CPU_CAPABILITY_AVX512)
+  __m256i v =
+      _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+  __m512 o = _mm512_cvtph_ps(v);
+  return _mm512_cvtss_f32(o);
+#endif
+}
+
+#endif
+
+} // namespace CPU_CAPABILITY
+} // namespace at::vec
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h
new file mode 100644
index 0000000000000000000000000000000000000000..0ad919432a05eec30bdbff189a122988731083c2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vec/vec_n.h
@@ -0,0 +1,344 @@
+#include <ATen/cpu/vec/vec_base.h>
+#include <array>
+
+namespace at::vec {
+inline namespace CPU_CAPABILITY {
+
+/**
+ * @brief A class template representing a vectorized type with
+ * `N * Vectorized<T>::size()` elements, aiming to support vectors of
+ * arbitrary size. A specific use case of it is to represent vectors
+ * converted from data types with different sizes but with the same
+ * number of vector elements, e.g., `VectorizedN<int64_t, 2>` can be
+ * a vector converted from two `Vectorized<int32_t>`, `VectorizedN<double, 2>`
+ * can be a vector converted from two `Vectorized<float>` etc.
+ *
+ * It supports most of the operations of `Vectorized<T>`
+ * and the implementation delegates to `Vectorized<T>` with loops over `N`.
+ *
+ * @tparam T The underlying type of the vectorized elements.
+ * @tparam N The number of underlying `Vectorized<T>`.
+ */
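+// Minimal usage sketch (illustrative only, not part of the upstream header):
+//
+//   float buf[2 * Vectorized<float>::size()] = {0};
+//   auto v = VectorizedN<float, 2>::loadu(buf);  // two Vectorized<float> chunks
+//   auto w = v.abs().exp();                      // ops are applied chunk by chunk
+//   w.store(buf);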
+template <typename T, int N>
+class VectorizedN {
+ public:
+  using value_type = T;
+  using size_type = int;
+
+  static constexpr size_type size_T = sizeof(T);
+  static constexpr size_type size() {
+    return Vectorized<T>::size() * N;
+  }
+
+ private:
+  std::array<Vectorized<T>, N> values;
+
+ public:
+  // methods not implemented yet:
+  // variadic constructor, operator T*, as_bytes, zero_mask
+
+#define VECTORIZEDN_DEFINE_UNARY_OP(op)                             \
+  VectorizedN<T, N> op() const {                                    \
+    return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
+  }
+
+#define VECTORIZEDN_DEFINE_BINARY_OP(op)                            \
+  VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
+    return binary_op(                                               \
+        other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
+          return a.op(b);                                           \
+        });                                                         \
+  }
+
+  template 
+  inline VectorizedN unary_op(Op op) const {
+    VectorizedN result;
+#ifndef _MSC_VER
+#pragma unroll
+#endif
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = op(values[i]);
+    }
+    return result;
+  }
+
+  template 
+  inline VectorizedN binary_op(const VectorizedN& other, Op op)
+      const {
+    VectorizedN result;
+#ifndef _MSC_VER
+#pragma unroll
+#endif
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = op(values[i], other.values[i]);
+    }
+    return result;
+  }
+
+  VectorizedN() = default;
+
+  explicit VectorizedN(T val) {
+    for (int i = 0; i < N; ++i) {
+      values[i] = Vectorized(val);
+    }
+  }
+
+  const Vectorized& operator[](int i) const {
+    return values[i];
+  }
+
+  Vectorized& operator[](int i) {
+    return values[i];
+  }
+
+  template 
+  static VectorizedN blend(
+      const VectorizedN& a,
+      const VectorizedN& b) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = Vectorized::blend(a.values[i], b.values[i]);
+    }
+    return result;
+  }
+
+  static VectorizedN blendv(
+      const VectorizedN& a,
+      const VectorizedN& b,
+      const VectorizedN& mask) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] =
+          Vectorized::blendv(a.values[i], b.values[i], mask.values[i]);
+    }
+    return result;
+  }
+
+  template 
+  static VectorizedN arange(
+      T base = static_cast(0),
+      step_t step = static_cast(1)) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = Vectorized::arange(base, step);
+      base += step * Vectorized::size();
+    }
+    return result;
+  }
+
+  static VectorizedN set(
+      const VectorizedN& a,
+      const VectorizedN& b,
+      int64_t count = size()) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] =
+          Vectorized::set(a.values[i], b.values[i], std::min(count, Vectorized::size()));
+      count -= Vectorized::size();
+      if (count <= 0) {
+        break;
+      }
+    }
+    return result;
+  }
+
+  static VectorizedN loadu(const void* ptr) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = Vectorized::loadu(ptr);
+      ptr = static_cast(ptr) + Vectorized::size();
+    }
+    return result;
+  }
+
+  static VectorizedN loadu(const void* ptr, int64_t count) {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] =
+          Vectorized::loadu(ptr, std::min(count, Vectorized::size()));
+      ptr = static_cast(ptr) + Vectorized::size();
+      count -= Vectorized::size();
+      if (count <= 0) {
+        break;
+      }
+    }
+    return result;
+  }
+
+  void store(void* ptr) const {
+    for (int i = 0; i < N; ++i) {
+      values[i].store(ptr);
+      ptr = static_cast(ptr) + Vectorized::size();
+    }
+  }
+
+  void store(void* ptr, int count) const {
+    for (int i = 0; i < N; ++i) {
+      values[i].store(ptr, std::min(count, Vectorized::size()));
+      ptr = static_cast(ptr) + Vectorized::size();
+      count -= Vectorized::size();
+      if (count <= 0) {
+        break;
+      }
+    }
+  }
+
+  bool has_inf_nan() const {
+    for (int i = 0; i < N; ++i) {
+      if (values[i].has_inf_nan()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  VectorizedN map(T (*const f)(T)) const {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = values[i].map(f);
+    }
+    return result;
+  }
+
+  VectorizedN map(T (*const f)(const T&)) const {
+    VectorizedN result;
+    for (int i = 0; i < N; ++i) {
+      result.values[i] = values[i].map(f);
+    }
+    return result;
+  }
+
+  VECTORIZEDN_DEFINE_UNARY_OP(abs)
+  VECTORIZEDN_DEFINE_UNARY_OP(sgn)
+  VECTORIZEDN_DEFINE_UNARY_OP(angle)
+  VECTORIZEDN_DEFINE_UNARY_OP(real)
+  VECTORIZEDN_DEFINE_UNARY_OP(imag)
+  VECTORIZEDN_DEFINE_UNARY_OP(conj)
+  VECTORIZEDN_DEFINE_UNARY_OP(acos)
+  VECTORIZEDN_DEFINE_UNARY_OP(acosh)
+  VECTORIZEDN_DEFINE_UNARY_OP(asin)
+  VECTORIZEDN_DEFINE_UNARY_OP(atan)
+  VECTORIZEDN_DEFINE_UNARY_OP(atanh)
+  VECTORIZEDN_DEFINE_BINARY_OP(atan2)
+  VECTORIZEDN_DEFINE_BINARY_OP(copysign)
+  VECTORIZEDN_DEFINE_UNARY_OP(erf)
+  VECTORIZEDN_DEFINE_UNARY_OP(erfc)
+  VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
+  VECTORIZEDN_DEFINE_UNARY_OP(exp)
+  VECTORIZEDN_DEFINE_UNARY_OP(exp2)
+  VECTORIZEDN_DEFINE_UNARY_OP(expm1)
+  VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
+  VECTORIZEDN_DEFINE_UNARY_OP(frac)
+  VECTORIZEDN_DEFINE_BINARY_OP(fmod)
+  VECTORIZEDN_DEFINE_UNARY_OP(log)
+  VECTORIZEDN_DEFINE_UNARY_OP(log10)
+  VECTORIZEDN_DEFINE_UNARY_OP(log1p)
+  VECTORIZEDN_DEFINE_UNARY_OP(log2)
+  VECTORIZEDN_DEFINE_UNARY_OP(ceil)
+  VECTORIZEDN_DEFINE_UNARY_OP(cos)
+  VECTORIZEDN_DEFINE_UNARY_OP(cosh)
+  VECTORIZEDN_DEFINE_UNARY_OP(floor)
+  VECTORIZEDN_DEFINE_BINARY_OP(hypot)
+  VECTORIZEDN_DEFINE_UNARY_OP(i0)
+  VECTORIZEDN_DEFINE_UNARY_OP(i0e)
+  VECTORIZEDN_DEFINE_UNARY_OP(digamma)
+  VECTORIZEDN_DEFINE_BINARY_OP(igamma)
+  VECTORIZEDN_DEFINE_BINARY_OP(igammac)
+  VECTORIZEDN_DEFINE_UNARY_OP(neg)
+  VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
+  VECTORIZEDN_DEFINE_UNARY_OP(round)
+  VECTORIZEDN_DEFINE_UNARY_OP(sin)
+  VECTORIZEDN_DEFINE_UNARY_OP(sinh)
+  VECTORIZEDN_DEFINE_UNARY_OP(tan)
+  VECTORIZEDN_DEFINE_UNARY_OP(tanh)
+  VECTORIZEDN_DEFINE_UNARY_OP(trunc)
+  VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
+  VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
+  VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
+  VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
+  VECTORIZEDN_DEFINE_BINARY_OP(pow)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator==)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator>)
+  VECTORIZEDN_DEFINE_BINARY_OP(operator<)
+  VECTORIZEDN_DEFINE_BINARY_OP(eq)
+  VECTORIZEDN_DEFINE_BINARY_OP(ne)
+  VECTORIZEDN_DEFINE_BINARY_OP(gt)
+  VECTORIZEDN_DEFINE_BINARY_OP(ge)
+  VECTORIZEDN_DEFINE_BINARY_OP(lt)
+  VECTORIZEDN_DEFINE_BINARY_OP(le)
+
+#undef VECTORIZEDN_DEFINE_UNARY_OP
+#undef VECTORIZEDN_DEFINE_BINARY_OP
+};
+
+#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op)                       \
+  template                                        \
+  inline VectorizedN op(const VectorizedN& a) {          \
+    return a.unary_op([](const Vectorized& a) { return op(a); }); \
+  }
+
+#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op)                                \
+  template                                                  \
+  inline VectorizedN op(                                                 \
+      const VectorizedN& a, const VectorizedN& b) {                \
+    return a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \
+      return op(a, b);                                                         \
+    });                                                                        \
+  }
+
+#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op)                     \
+  template                                               \
+  inline VectorizedN& op(                                             \
+      VectorizedN& a, const VectorizedN& b) {                   \
+    a = a.binary_op(b, [](const Vectorized& a, const Vectorized& b) { \
+      return op(a, b);                                                      \
+    });                                                                     \
+    return a;                                                               \
+  }
+
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
+VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
+VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
+
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
+VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
+
+#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
+#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
+#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
+
+template 
+inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN acc_vec) {
+  Vectorized vec_result = acc_vec[0];
+  for (int i = 1; i < N; i++) {
+    vec_result = vec_fun(vec_result, acc_vec[i]);
+  }
+  return vec_reduce_all(vec_fun, vec_result);
+}
+
+} // namespace CPU_CAPABILITY
+} // namespace at::vec
\ No newline at end of file
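A minimal usage sketch for the VectorizedN template added above, assuming a translation unit built against ATen with the vec headers on the include path; the buffer contents, the choice of N = 2, and the pairwise maximum at the end are illustrative only and not taken from the header itself.

#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec_n.h>
#include <cstdio>
#include <vector>

int main() {
  using VecN = at::vec::VectorizedN<float, 2>;      // N * Vectorized<float>::size() lanes

  std::vector<float> in(VecN::size(), 1.5f);
  std::vector<float> out(VecN::size(), 0.0f);

  VecN a = VecN::loadu(in.data());                  // fills both underlying Vectorized<float>
  VecN b = a * a + a;                               // the global operators loop over N internally
  b.store(out.data());

  // operator[] exposes the underlying Vectorized<float> pieces, so the usual
  // Vectorized helpers still apply.
  std::vector<float> reduced(at::vec::Vectorized<float>::size());
  at::vec::maximum(b[0], b[1]).store(reduced.data());

  std::printf("out[0] = %f, pairwise max[0] = %f\n", out[0], reduced[0]);
  return 0;
}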
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cpu/vml.h b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vml.h
new file mode 100644
index 0000000000000000000000000000000000000000..45a4b88ae93bac68da49ca2f3f25375b5d6c98e5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cpu/vml.h
@@ -0,0 +1,171 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// This header implements various unary operations using an MKL VML style
+// interface.
+
+// It implements various functions with a simple interface
+// For example it enables the user to call vsin(float* out, const float* in,
+// size). This function takes a pointer to a contiguous output array of floats and
+// a constant input array. It will then apply sin to each value in the input
+// array and write the result into the output array. out and in may point to the
+// same memory, i.e. this fully supports in-place operations. These functions
+// also implement their own parallelization, so take precautions when calling
+// these from threaded functions.
+
+// When MKL is available it will call into MKL's VML library similar to NumPy
+// If MKL is not available it will use SLEEF.
+
+// This file might be compiled under AVX or AVX2 when called from e.g.
+// UnaryOpsKernel.cpp
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if AT_MKL_ENABLED() && !defined(__APPLE__)
+#include 
+#endif
+
+namespace at {
+namespace vml {
+inline namespace CPU_CAPABILITY {
+
+using namespace vec;
+
+template 
+inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
+  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
+    map(
+        [](const Vectorized& x) {
+          return Vectorized((scalar_t)(1)) / x.sqrt();
+        },
+        out + begin,
+        in + begin,
+        end - begin);
+  });
+}
+
+// NB: We ignore numerical errors by convention and leave them to the user
+
+#define IMPLEMENT_VML(op)                                               \
+  template                                           \
+  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) {  \
+    using vec_t = Vectorized>;                   \
+    vec::map([](vec_t x) { return x.op(); }, out, in, size);            \
+  }                                                                     \
+
+IMPLEMENT_VML(abs)
+IMPLEMENT_VML(acos)
+IMPLEMENT_VML(asin)
+IMPLEMENT_VML(atan)
+IMPLEMENT_VML(atanh)
+IMPLEMENT_VML(ceil)
+IMPLEMENT_VML(cos)
+// IMPLEMENT_VML(cosh)
+IMPLEMENT_VML(erf)
+IMPLEMENT_VML(erfc)
+IMPLEMENT_VML(erfinv)
+IMPLEMENT_VML(exp)
+IMPLEMENT_VML(expm1)
+IMPLEMENT_VML(floor)
+IMPLEMENT_VML(i0)
+IMPLEMENT_VML(i0e)
+IMPLEMENT_VML(digamma)
+IMPLEMENT_VML(reciprocal)
+IMPLEMENT_VML(log)
+IMPLEMENT_VML(log10)
+IMPLEMENT_VML(log1p)
+IMPLEMENT_VML(log2)
+IMPLEMENT_VML(neg)
+IMPLEMENT_VML(sin)
+// IMPLEMENT_VML(sinh)
+IMPLEMENT_VML(sqrt)
+IMPLEMENT_VML(round)
+IMPLEMENT_VML(rsqrt)
+IMPLEMENT_VML(tan)
+IMPLEMENT_VML(tanh)
+IMPLEMENT_VML(trunc)
+IMPLEMENT_VML(lgamma)
+
+
+#if AT_MKL_ENABLED() && !defined(__APPLE__)
+
+// NB: LP64 MKL is the most commonly used and thus we assume it here. That means
+// we need to expect MKL_INT to be of type int, which in practice means int32_t
+// or int64_t.
+static_assert(
+    std::is_same_v || std::is_same_v,
+    "MKL_INT is assumed to be int32_t or int64_t");
+#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype)                \
+  template <>                                                           \
+  inline void v##op(type * out, const type * in, int64_t size) {        \
+    int64_t max_mkl_ind = std::numeric_limits::max();          \
+    if (size <= static_cast(max_mkl_ind)) {                    \
+      vm##mkltype##mklop(                                               \
+          size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
+    } else {                                                            \
+      MKL_INT ind = 0;                                                  \
+      int64_t chunks = size / max_mkl_ind;                              \
+      int64_t rest = size % max_mkl_ind;                                \
+      for (; ind < chunks; ind++) {                                     \
+        vm##mkltype##mklop(                                             \
+            max_mkl_ind,                                                \
+            in + ind * max_mkl_ind,                                     \
+            out + ind * max_mkl_ind,                                    \
+            VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);              \
+      }                                                                 \
+      vm##mkltype##mklop(                                               \
+          rest,                                                         \
+          in + ind * max_mkl_ind,                                       \
+          out + ind * max_mkl_ind,                                      \
+          VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);                \
+    }                                                                   \
+  }
+
+#define IMPLEMENT_VML_MKL(op, mklop)          \
+  IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
+  IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
+
+// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
+// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
+IMPLEMENT_VML_MKL(acos, Acos)
+IMPLEMENT_VML_MKL(asin, Asin)
+IMPLEMENT_VML_MKL(atan, Atan)
+IMPLEMENT_VML_MKL(cos, Cos)
+// IMPLEMENT_VML_MKL(cosh, Cosh)
+IMPLEMENT_VML_MKL(erf, Erf)
+IMPLEMENT_VML_MKL(erfc, Erfc)
+IMPLEMENT_VML_MKL(erfinv, ErfInv)
+IMPLEMENT_VML_MKL(exp, Exp)
+// IMPLEMENT_VML_MKL(expm1, Expm1)
+IMPLEMENT_VML_MKL(log, Ln)
+IMPLEMENT_VML_MKL(log10, Log10)
+IMPLEMENT_VML_MKL(sin, Sin)
+// IMPLEMENT_VML_MKL(sinh, Sinh)
+IMPLEMENT_VML_MKL(sqrt, Sqrt)
+IMPLEMENT_VML_MKL(tan, Tan)
+IMPLEMENT_VML_MKL(tanh, Tanh)
+IMPLEMENT_VML_MKL(trunc, Trunc)
+
+// Not vectorized in MKL version tested
+// IMPLEMENT_VML_MKL(abs, Abs)
+// IMPLEMENT_VML_MKL(log1p, Log1p)
+
+#if INTEL_MKL_VERSION >= 20180406
+IMPLEMENT_VML_MKL(log2, Log2)
+#endif
+
+#endif
+
+} // namespace
+} // namespace vml
+} // namespace at
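A small sketch of how the vml.h wrappers above are typically called, assuming the file is compiled as part of a program built against ATen/libtorch (parallel_for needs the library at link time); the array size and values are arbitrary examples.

#include <ATen/cpu/vml.h>
#include <cstdio>
#include <vector>

int main() {
  const int64_t n = 1024;                      // arbitrary example size
  std::vector<float> in(n, 0.5f), out(n, 0.0f);

  at::vml::vsin(out.data(), in.data(), n);     // out[i] = sin(in[i]); MKL or SLEEF underneath
  at::vml::vlog1p(out.data(), out.data(), n);  // in-place use is allowed per the header comment

  std::printf("%f\n", out[0]);
  return 0;
}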
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h
new file mode 100644
index 0000000000000000000000000000000000000000..3938aa341bb3943a9e42a3178d3233868b755101
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..b67b0905a09fd2a1bb17f7cc69863fd849ded1ff
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh
@@ -0,0 +1,47 @@
+#include 
+
+#include 
+
+namespace at::cuda {
+
+/**
+   Computes ceil(a / b)
+*/
+template 
+__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
+  return (a + b - 1) / b;
+}
+
+namespace {
+
+// Threads per block for our apply kernel
+// FIXME: use occupancy calculator instead
+constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
+constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
+
+template 
+inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
+  if (curDevice == -1) return false;
+  uint64_t numel_per_thread = static_cast(max_threads_per_block) * static_cast(step);
+  uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
+  uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
+  if (numBlocks > maxGridX)
+    numBlocks = maxGridX;
+  grid = dim3(numBlocks);
+  return true;
+}
+
+constexpr int getApplyBlocksPerSM() {
+  return AT_APPLY_BLOCKS_PER_SM;
+}
+
+constexpr int getApplyBlockSize() {
+  return AT_APPLY_THREADS_PER_BLOCK;
+}
+
+inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
+  return dim3(max_threads_per_block);
+}
+
+} // anonymous namespace
+} // namespace at::cuda
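A short sketch (for a .cu translation unit built with ATen's CUDA headers) of how the helpers above are meant to be combined when picking a launch configuration; the function name and its arguments are placeholders, not part of the header.

#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

// Returns true and fills block/grid if a launch configuration could be chosen.
bool pick_launch_config(uint64_t totalElements, dim3& block, dim3& grid) {
  block = at::cuda::getApplyBlock();                 // 512 threads per block by default
  c10::DeviceIndex dev = at::cuda::current_device();
  // step = 1: one element per thread; grid.x is clamped to the device's maxGridSize[0].
  return at::cuda::getApplyGrid<1>(totalElements, grid, dev);
}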
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1daf0349042c77bf0627c61ecfa294a5b5c73a3c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/AsmUtils.cuh
@@ -0,0 +1,149 @@
+#pragma once
+#include 
+
+// Collection of direct PTX functions
+
+namespace at::cuda {
+
+template 
+struct Bitfield {};
+
+template <>
+struct Bitfield {
+  static __device__ __host__ __forceinline__
+  unsigned int getBitfield(unsigned int val, int pos, int len) {
+#if !defined(__CUDA_ARCH__)
+    pos &= 0xff;
+    len &= 0xff;
+
+    unsigned int m = (1u << len) - 1u;
+    return (val >> pos) & m;
+#else
+    unsigned int ret;
+    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
+    return ret;
+#endif
+  }
+
+  static __device__ __host__ __forceinline__
+  unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
+#if !defined(__CUDA_ARCH__)
+    pos &= 0xff;
+    len &= 0xff;
+
+    unsigned int m = (1u << len) - 1u;
+    toInsert &= m;
+    toInsert <<= pos;
+    m <<= pos;
+
+    return (val & ~m) | toInsert;
+#else
+    unsigned int ret;
+    asm("bfi.b32 %0, %1, %2, %3, %4;" :
+        "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
+    return ret;
+#endif
+  }
+};
+
+template <>
+struct Bitfield {
+  static __device__ __host__ __forceinline__
+  uint64_t getBitfield(uint64_t val, int pos, int len) {
+#if !defined(__CUDA_ARCH__)
+    pos &= 0xff;
+    len &= 0xff;
+
+    uint64_t m = (1u << len) - 1u;
+    return (val >> pos) & m;
+#else
+    uint64_t ret;
+    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
+    return ret;
+#endif
+  }
+
+  static __device__ __host__ __forceinline__
+  uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
+#if !defined(__CUDA_ARCH__)
+    pos &= 0xff;
+    len &= 0xff;
+
+    uint64_t m = (1u << len) - 1u;
+    toInsert &= m;
+    toInsert <<= pos;
+    m <<= pos;
+
+    return (val & ~m) | toInsert;
+#else
+    uint64_t ret;
+    asm("bfi.b64 %0, %1, %2, %3, %4;" :
+        "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
+    return ret;
+#endif
+  }
+};
+
+__device__ __forceinline__ int getLaneId() {
+#if defined(USE_ROCM)
+  return __lane_id();
+#else
+  int laneId;
+  asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
+  return laneId;
+#endif
+}
+
+#if defined(USE_ROCM)
+__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
+  const std::uint64_t m = (1ull << getLaneId()) - 1ull;
+  return m;
+}
+#else
+__device__ __forceinline__ unsigned getLaneMaskLt() {
+  unsigned mask;
+  asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
+  return mask;
+}
+#endif
+
+#if defined (USE_ROCM)
+__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
+  std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
+  return m;
+}
+#else
+__device__ __forceinline__ unsigned getLaneMaskLe() {
+  unsigned mask;
+  asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
+  return mask;
+}
+#endif
+
+#if defined(USE_ROCM)
+__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
+  const std::uint64_t m = getLaneMaskLe();
+  return m ? ~m : m;
+}
+#else
+__device__ __forceinline__ unsigned getLaneMaskGt() {
+  unsigned mask;
+  asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
+  return mask;
+}
+#endif
+
+#if defined(USE_ROCM)
+__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
+  const std::uint64_t m = getLaneMaskLt();
+  return ~m;
+}
+#else
+__device__ __forceinline__ unsigned getLaneMaskGe() {
+  unsigned mask;
+  asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
+  return mask;
+}
+#endif
+
+} // namespace at::cuda
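A device-code sketch (CUDA path only, not the ROCm variants) using the bit-field and lane-mask helpers defined above; the kernel, the bit positions, and the output layout are invented for illustration.

#include <ATen/cuda/AsmUtils.cuh>

__global__ void lane_demo(unsigned int* out) {
  // Extract bits [8, 12) of the thread index (bfe.u32 on device, plain shifts on host).
  unsigned int field =
      at::cuda::Bitfield<unsigned int>::getBitfield(threadIdx.x, 8, 4);

  // %lanemask_lt has one bit per lane below this one, so its popcount equals the lane id.
  unsigned int lanes_below = __popc(at::cuda::getLaneMaskLt());

  unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = field + lanes_below;
}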
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1e66b2fdce4eb54e425885a9e400490350574c7f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Atomic.cuh
@@ -0,0 +1,508 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
+#include 
+#endif
+
+template 
+struct AtomicFPOp;
+
+template <>
+struct AtomicFPOp {
+  template 
+  inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
+    unsigned int * address_as_ui =
+      (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    at::Half hsum;
+    do {
+      assumed = old;
+      hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+      hsum = func(hsum, val);
+      old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+      old = atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+    return hsum;
+  }
+};
+
+template <>
+struct AtomicFPOp {
+  template 
+  inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
+    unsigned int * address_as_ui =
+      (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    at::BFloat16 bsum;
+    do {
+      assumed = old;
+      bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+      bsum = func(bsum, val);
+      old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
+      old = atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+    return bsum.x;
+  }
+};
+
+template <>
+struct AtomicFPOp {
+  template 
+  inline __device__ double operator() (double * address, double val, const func_t& func) {
+    unsigned long long int* address_as_ull = (unsigned long long int*)address;
+    unsigned long long int old = *address_as_ull;
+    unsigned long long int assumed;
+
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ull, assumed, func(val, assumed));
+      // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+
+    return __longlong_as_double(old);
+  }
+};
+
+#define ATOMIC_INTEGER_IMPL(NAME)                                                                                      \
+template                                                                                         \
+struct Atomic##NAME##IntegerImpl;                                                                                      \
+                                                                                                                       \
+template                                                                                                   \
+struct Atomic##NAME##IntegerImpl {                                                                               \
+  template                                                                                            \
+  inline __device__ void operator()(T *address, T val, const func_t& func) {                                           \
+    size_t offset = (size_t)address & 3;                                                                               \
+    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset);                                                 \
+    uint32_t old = *address_as_ui;                                                                                     \
+    uint32_t shift = offset * 8;                                                                                       \
+    uint32_t old_byte;                                                                                                 \
+    uint32_t newval;                                                                                                   \
+    uint32_t assumed;                                                                                                  \
+                                                                                                                       \
+    do {                                                                                                               \
+      assumed = old;                                                                                                   \
+      old_byte = (old >> shift) & 0xff;                                                                                \
+      newval = static_cast(func(val, static_cast(old_byte)));                                              \
+      newval = (old & ~(0x000000ff << shift)) | (newval << shift);                                                     \
+      old = atomicCAS(address_as_ui, assumed, newval);                                                                 \
+    } while (assumed != old);                                                                                          \
+  }                                                                                                                    \
+};                                                                                                                     \
+                                                                                                                       \
+template                                                                                                   \
+struct Atomic##NAME##IntegerImpl {                                                                               \
+  template                                                                                            \
+  inline __device__ void operator()(T *address, T val, const func_t& func) {                                           \
+    size_t offset = (size_t)address & 2;                                                                               \
+    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset);                                                 \
+    bool is_32_align = offset;                                                                                         \
+    uint32_t old = *address_as_ui;                                                                                     \
+    uint32_t old_bytes;                                                                                                \
+    uint32_t newval;                                                                                                   \
+    uint32_t assumed;                                                                                                  \
+                                                                                                                       \
+    do {                                                                                                               \
+      assumed = old;                                                                                                   \
+      old_bytes = is_32_align ? old >> 16 : old & 0xffff;                                                              \
+      newval = static_cast(func(val, static_cast(old_bytes)));                                            \
+      newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval;                            \
+      old = atomicCAS(address_as_ui, assumed, newval);                                                                 \
+    } while (assumed != old);                                                                                          \
+  }                                                                                                                    \
+};                                                                                                                     \
+                                                                                                                       \
+template                                                                                                   \
+struct Atomic##NAME##IntegerImpl {                                                                               \
+  template                                                                                            \
+  inline __device__ void operator()(T *address, T val, const func_t& func) {                                           \
+    uint32_t * address_as_ui = (uint32_t *) (address);                                                                 \
+    uint32_t old = *address_as_ui;                                                                                     \
+    uint32_t newval;                                                                                                   \
+    uint32_t assumed;                                                                                                  \
+                                                                                                                       \
+    do {                                                                                                               \
+      assumed = old;                                                                                                   \
+      newval = static_cast(func(val, static_cast(old)));                                                  \
+      old = atomicCAS(address_as_ui, assumed, newval);                                                                 \
+    } while (assumed != old);                                                                                          \
+  }                                                                                                                    \
+};                                                                                                                     \
+                                                                                                                       \
+template                                                                                                   \
+struct Atomic##NAME##IntegerImpl {                                                                               \
+  template                                                                                            \
+  inline __device__ void operator()(T *address, T val, const func_t& func) {                                           \
+    unsigned long long * address_as_ui = (unsigned long long *) (address);                                             \
+    unsigned long long old = *address_as_ui;                                                                           \
+    unsigned long long newval;                                                                                         \
+    unsigned long long assumed;                                                                                        \
+                                                                                                                       \
+    do {                                                                                                               \
+      assumed = old;                                                                                                   \
+      newval = static_cast(func(val, static_cast(old)));                                                  \
+      old = atomicCAS(address_as_ui, assumed, newval);                                                                 \
+    } while (assumed != old);                                                                                          \
+  }                                                                                                                    \
+};
+
+
+# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE)                                                                           \
+static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) {                                             \
+Atomic##NAME##IntegerImpl()(address,                                                             \
+                                                      val,                                                             \
+                                                      [](DTYPE a, DTYPE b) {                                           \
+                                                          return OP;                                                   \
+                                                      });                                                              \
+}                                                                                                                      \
+
+ATOMIC_INTEGER_IMPL(Add)
+GPU_ATOMIC_INTEGER(Add, a || b, bool)
+
+// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
+static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
+  AtomicAddIntegerImpl()(address,
+                                                   val,
+                                                   [](uint8_t a, uint8_t b) {
+                                                      return a + b;
+                                                   });
+}
+
+static inline  __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
+  AtomicAddIntegerImpl()(address,
+                                                 val,
+                                                 [](int8_t a, int8_t b) {
+                                                   return a + b;
+                                                 });
+}
+
+static inline  __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
+  AtomicAddIntegerImpl()(address,
+                                                   val,
+                                                   [](int16_t a, int16_t b) {
+                                                     return a + b;
+                                                   });
+}
+
+static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
+  return atomicAdd(address, val);
+}
+
+static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
+#if defined(USE_ROCM)
+  __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
+#else
+  static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
+  atomicAdd(reinterpret_cast(address), static_cast(val));
+#endif
+}
+
+static inline  __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
+#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
+  return AtomicFPOp()(address, val,
+                                [](at::Half hsum, at::Half val) {
+                                  return hsum + val;
+                                });
+#else
+  return atomicAdd(reinterpret_cast<__half*>(address), val);
+#endif
+}
+
+static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
+#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
+return AtomicFPOp()(address, val,
+                                  [](at::BFloat16 bsum, at::BFloat16 val) {
+                                    return bsum + val;
+                                  });
+#else
+  __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
+  return *reinterpret_cast(&r);
+#endif
+}
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
+// from the CUDA C Programming Guide
+static inline __device__ double atomicAdd(double* address, double val)
+#if defined(__clang__) && defined(__CUDA__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wgcc-compat"
+    __attribute__((enable_if(true, "")))
+#pragma GCC diagnostic pop
+#endif
+{
+
+  return AtomicFPOp()(address, val,
+                              [](double val, unsigned long long int assumed) {
+                                return __double_as_longlong(val + __longlong_as_double(assumed));
+                              });
+}
+#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))
+
+/* Note [hip-clang differences to hcc]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
+ * It exports the __HIP__ macro, we can hence differentiate between hcc and
+ * hip-clang. In the below, hcc only received support for atomicAdd with double
+ * typing after work week 18312. hip-clang had support from the first version.
+ * In general, the code-visible differences between hip-clang and hcc will be
+ * minimal.
+ */
+
+#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
+  // This needs to be defined for the host side pass
+  static inline  __device__  double atomicAdd(double *address, double val) { }
+#endif
+#endif
+
+static inline __device__ double gpuAtomicAdd(double *address, double val) {
+  return atomicAdd(address, val);
+}
+
+static inline __device__ float gpuAtomicAdd(float *address, float val) {
+  return atomicAdd(address, val);
+}
+
+template
+static inline __device__ void gpuAtomicAdd(c10::complex *address, c10::complex val) {
+  gpuAtomicAdd(&address->real_, val.real_);
+  gpuAtomicAdd(&address->imag_, val.imag_);
+}
+
+/* Note [gpuAtomicAdd vs atomicAdd]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Some extensions such as torchvision call atomicAdd()
+ * directly and require non-library provided data type support. Only for these, we
+ * continue to provide atomicAdd overloads.
+ */
+static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
+  return gpuAtomicAdd(address, val);
+}
+
+static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
+  return gpuAtomicAdd(address, val);
+}
+
+static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
+  gpuAtomicAdd(address, val);
+}
+
+static inline  __device__ void atomicAdd(int8_t *address, int8_t val) {
+  gpuAtomicAdd(address, val);
+}
+
+static inline  __device__ void atomicAdd(int16_t *address, int16_t val) {
+  gpuAtomicAdd(address, val);
+}
+
+static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
+  gpuAtomicAdd(address, val);
+}
+
+static inline __device__ void atomicAdd(bool *address, bool val) {
+  gpuAtomicAdd(address, val);
+}
+
+/* Note [explicitly non-returning atomics]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
+ * Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
+ * This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
+ * therefore we need a new API 'gpuAtomicAddNoReturn'.
+ */
+template
+static inline __device__ void gpuAtomicAddNoReturn(c10::complex *address, c10::complex val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
+static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
+
+/* Special case fp32 atomic. */
+#if defined(USE_ROCM)
+static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
+#else
+static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
+#endif
+
+// Atomic multiplication implementation.
+
+ATOMIC_INTEGER_IMPL(Mul)
+GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
+GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
+GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
+GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
+GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
+
+inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
+  return AtomicFPOp()(address, val,
+                                [](at::Half bsum, at::Half val) {
+                                  return bsum * val;
+                                });
+}
+
+inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
+  return AtomicFPOp()(address, val,
+                                    [](at::BFloat16 bsum, at::BFloat16 val) {
+                                      return bsum * val;
+                                    });
+}
+
+inline __device__ double gpuAtomicMul(double * address, double val) {
+  return AtomicFPOp()(address, val,
+                              [](double val, unsigned long long int assumed) {
+                                return __double_as_longlong(val * __longlong_as_double(assumed));
+                              });
+}
+
+// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
+inline __device__ float gpuAtomicMul (float * address, float val) {
+  unsigned int* address_as_ull = (unsigned int*)address;
+  unsigned int old = *address_as_ull;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __float_as_int(val *
+                                   __int_as_float(assumed)));
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+
+  return __int_as_float(old);
+}
+
+// Atomic maximum implementation.
+
+template 
+__host__ __device__ T safe_max(T a, T b) {
+  #if defined(__HIPCC__)
+  // TODO: remove this special case for HIP when issue is fixed:
+  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
+    T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
+  #else
+    T max = at::_isnan(b) ? b : std::max(a, b);
+  #endif
+
+  return max;
+}
+
+ATOMIC_INTEGER_IMPL(Max)
+GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
+GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
+GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
+GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
+GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)
+
+inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
+  return AtomicFPOp()(address, val,
+                                [](at::Half bsum, at::Half val) {
+                                  return safe_max(bsum, val);
+                                });
+}
+
+inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
+  return AtomicFPOp()(address, val,
+                                    [](at::BFloat16 bsum, at::BFloat16 val) {
+                                      return safe_max(bsum, val);
+                                    });
+}
+
+inline __device__ double gpuAtomicMax(double * address, double val) {
+  return AtomicFPOp()(address, val,
+                              [](double val, unsigned long long int assumed) {
+                                return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
+                              });
+}
+
+// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
+inline __device__ float gpuAtomicMax(float * address, float val) {
+  unsigned int* address_as_ull = (unsigned int*)address;
+  unsigned int old = *address_as_ull;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __float_as_int(safe_max(val, __int_as_float(assumed))));
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+
+  return __int_as_float(old);
+}
+
+// Atomic minimum implementation.
+
+template 
+__host__ __device__ T safe_min(T a, T b) {
+  #if defined(__HIPCC__)
+  // TODO: remove this special case for HIP when issue is fixed:
+  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
+    T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
+  #else
+    T min = at::_isnan(b) ? b : std::min(a, b);
+  #endif
+
+  return min;
+}
+
+ATOMIC_INTEGER_IMPL(Min)
+GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
+GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
+GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
+GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
+GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)
+
+inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
+  return AtomicFPOp()(address, val,
+                                [](at::Half bsum, at::Half val) {
+                                  return safe_min(bsum, val);
+                                });
+}
+
+inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
+  return AtomicFPOp()(address, val,
+                                    [](at::BFloat16 bsum, at::BFloat16 val) {
+                                      return safe_min(bsum, val);
+                                    });
+}
+
+inline __device__ double gpuAtomicMin(double * address, double val) {
+  return AtomicFPOp()(address, val,
+                              [](double val, unsigned long long int assumed) {
+                                return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
+                              });
+}
+
+// Don't use a templated function for this since the addition function defaults to the CUDA built-in.
+inline __device__ float gpuAtomicMin(float * address, float val) {
+  unsigned int* address_as_ull = (unsigned int*)address;
+  unsigned int old = *address_as_ull;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __float_as_int(safe_min(val, __int_as_float(assumed))));
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+  } while (assumed != old);
+
+  return __int_as_float(old);
+}
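A sketch of how kernels typically consume the helpers above, assuming a .cu file that includes Atomic.cuh; the reduction pattern and buffer names are examples rather than anything the header prescribes.

#include <ATen/cuda/Atomic.cuh>

__global__ void reduce_sum_and_max(const float* in, int n, float* sum, float* max_val) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Non-returning add: opts into the faster atomicAddNoRet path on ROCm,
    // and falls back to gpuAtomicAdd elsewhere.
    gpuAtomicAddNoReturn(sum, in[i]);
    // CAS-loop float max built on the NaN-aware safe_max helper.
    gpuAtomicMax(max_val, in[i]);
  }
}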
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..eb26308c52dfc4b1c62b67b22a76d6a6a37c241c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh
@@ -0,0 +1,537 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+//
+// This file contains pointwise operation functions and kernels that
+// work on both contiguous and non-contiguous tensor arguments of
+// arbitrary (up to MAX_CUTORCH_DIMS) dimensions without
+// copying or temporary storage.
+//
+
+/*
+  NOTE [ CUDA_tensor_applyN helpers ]
+
+  The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
+  functions apply a pointwise operator to N tensor(s).
+
+  The calling convention is
+
+  1. The template arguments should be, sequentially,
+    - First N typename args specify the scalar types of each of the N tensors.
+    - (Optional) `int step` arg specifies the number of elements processed
+      together at the same time.
+      Default is 1.
+    - A usually omitted (i.e., inferred) typename arg specifies the type of the
+      function/functor applied on `N * step` values  in each iteration of each
+      CUDA thread.
+  2. The arguments should be, sequentially,
+    - N tensors
+    - op: a function/functor that processes `N * step` values at the same time.
+      - If `step == 1`, it must have signature
+        `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
+        `scalar*_t`s are the first N typename template args, and the inputs
+        are the `N` values from the `N` tensors retrieved at a common index.
+      - Otherwise, it must have signature
+          void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&,  // repeat `step` times
+                         scalar2_t&, scalar2_t&, ..., scalar2_t&,  // repeat `step` times
+                         ...,
+                         scalarN_t&, scalarN_t&, ..., scalarN_t&)  // repeat `step` times
+        Different from `step == 1` case, it processes `N * step` values taken
+        from `step` common indices. Moreover, the first input `n` represents the
+        number of valid indices (it will always have `0 < n <= step`). It will
+        almost always be `step`, but at the boundary we may not have full `step`
+        elements and `n` can be a lesser value.
+
+        E.g., if `step == 4` and `N == 2`, `op` could be
+
+          [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
+                    scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
+            // Only process u1, ..., un and v1, ..., vn.
+            // So if `n == 3`, `u4` and `v4` need not be considered.
+          }
+
+      In both cases, the references can actually be const, but at least one of
+      them should be non-const in order to write the output.
+    - (Optional, but recommended) N TensorArgType args that specify for each
+      tensor whether `op` reads AND writes (i.e., TensorArgType::ReadWrite),
+      or only reads (i.e., TensorArgType::ReadOnly).
+      Default is TensorArgType::ReadWrite for first Tensor, and
+                 TensorArgType::ReadOnly  for the rest.
+
+  E.g.,
+
+  to compute a = b^2 for a and b of same dtype, we can call
+
+  CUDA_tensor_apply2(
+    a, b,
+    [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
+  );
+
+  to work on 2 values at the same time, we can call
+
+  CUDA_tensor_apply2(
+    a, b,
+    [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
+                          const scalar2 &b_val1, const scalar2 &b_val2) {
+      // call special vectorized op here, or just do elementwise and enjoy unrolling...
+      // if n == 1, only process a_val1 and b_val1
+    }
+  );
+*/
+
+namespace at::cuda {
+
+// TODO: combine with TensorArg?  So far that's been for debugging, and this is functional...
+enum class TensorArgType { ReadWrite, ReadOnly };
+
+namespace {
+
+// Rearrange dimensions for pointwise operations so that strides are in
+// decreasing order as much as possible, so that kernels have better memory
+// access patterns.
+//
+// For example, consider a binary operation on two "transposed" 2-dim tensors:
+//    sizes:          256 512
+//    aInfo->strides:   1 256
+//    bInfo->strides:   1 256
+//
+// Given this, each concurrent memory access inside kernelPointwiseApply2() is
+// exactly 256 elements apart, resulting in poor performance.
+//
+// This function exchanges dimensions so that memory access is contiguous:
+//    sizes:          512 256
+//    aInfo->strides: 256   1
+//    bInfo->strides: 256   1
+//
+// (Actually, it becomes even better because now collapseDims() can turn each
+// input into one contiguous array.)
+//
+// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
+// strides[i] (0 <= i < N) as an M-tuple.  Given each pair i < j, we exchange
+// strides[i] and [j] if
+//    (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
+//        (exchanging them will benefit input #k), and
+//    (2) strides[i][k] <= strides[j][k] for all k
+//        (exchanging them will not make any input worse).
+template 
+inline void rearrangeDims(detail::TensorInfo* aInfo,
+                          detail::TensorInfo* bInfo = nullptr,
+                          detail::TensorInfo* cInfo = nullptr,
+                          detail::TensorInfo* dInfo = nullptr) {
+  int numInfos = 1;
+  int dims = aInfo->dims;
+  IndexType *sizes[4] = { aInfo->sizes, };
+  IndexType *strides[4] = { aInfo->strides, };
+
+  if (bInfo != nullptr) {
+    ++numInfos;
+    if (bInfo->dims != dims) return;
+    sizes[1] = bInfo->sizes;
+    strides[1] = bInfo->strides;
+  }
+
+  if (cInfo != nullptr) {
+    ++numInfos;
+    if (cInfo->dims != dims) return;
+    sizes[2] = cInfo->sizes;
+    strides[2] = cInfo->strides;
+  }
+
+  if (dInfo != nullptr) {
+    ++numInfos;
+    if (dInfo->dims != dims) return;
+    sizes[3] = dInfo->sizes;
+    strides[3] = dInfo->strides;
+  }
+
+  // Bail out if sizes do not match: we are using "deprecated pointwise
+  // behavior" among tensors of different shapes but same number of elements.
+  for (int i = 1; i < numInfos; ++i) {
+    for (int j = 0; j < dims; ++j) {
+      if (sizes[i][j] != sizes[0][j]) return;
+    }
+  }
+
+  for (int i = 0; i < dims - 1; ++i) {
+    // No need to consider dimensions of size 1.
+    if (sizes[0][i] == 1) continue;
+
+    for (int j = i + 1; j < dims; ++j) {
+      if (sizes[0][j] == 1) continue;
+
+      // Compare the relative sizes of strides between dim #i and dim #j.
+      bool hasIncreasingStrides = false;
+      bool hasDecreasingStrides = false;
+
+      for (int k = 0; k < numInfos; k++) {
+        IndexType stride_i = strides[k][i];
+        IndexType stride_j = strides[k][j];
+        if (stride_i < stride_j) {
+          hasIncreasingStrides = true;
+        } else if (stride_i > stride_j) {
+          hasDecreasingStrides = true;
+        }
+      }
+
+      if (hasIncreasingStrides && !hasDecreasingStrides) {
+        for (int k = 0; k < numInfos; k++) {
+          IndexType size = sizes[k][i];
+          sizes[k][i] = sizes[k][j];
+          sizes[k][j] = size;
+
+          IndexType stride = strides[k][i];
+          strides[k][i] = strides[k][j];
+          strides[k][j] = stride;
+        }
+      }
+    }
+  }
+}
+
+// The `remaining_steps` argument is used to support Op that operates on
+// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
+//  1. Initialize `remaining_steps = step`, where `step` is the template arg of
+//     CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
+//     number of elements in bound for this call. It will almost always equal to
+//     `step` except at boundaries.
+//  2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
+//     bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
+//  3. At `remaining_steps = 0`,
+//       if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
+//       if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
+//                                  tensor2_val1, tensor2_val2, ..., tensor2_valstep,
+//                                       ...
+//                                  tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
+//
+// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
+
+template 
+struct ApplyOp1 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a, const Op &op, int n,
+                  IndexType linearIndex, Offsets... aOffsets) {
+  // Convert `linearIndex` into an offset of `a`
+  const IndexType aOffset = sizeof...(Offsets) < n ?
+    detail::IndexToOffset::get(linearIndex, a) : 0;
+
+  ApplyOp1::apply(
+    a, op, n, linearIndex + 1, aOffsets..., aOffset
+  );
+}
+};
+
+// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
+// We don't need to pass in how many elements need to be processed in this case.
+template 
+struct ApplyOp1 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a, const Op &op,
+                  int n, IndexType linearIndex, Offset offset) {
+  op(a.data[offset]);
+}
+};
+
+template 
+struct ApplyOp1 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a, const Op &op, int n,
+                 IndexType linearIndex, Offsets... offsets) {
+  op(n, a.data[offsets]...);
+}
+};
+
+template 
+#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+#endif
+__global__ void kernelPointwiseApply1(detail::TensorInfo a,
+                                      IndexType totalElements, const Op op) {
+  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
+       linearIndex < totalElements;
+       linearIndex += gridDim.x * blockDim.x * step) {
+    ApplyOp1::apply(
+      a, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex);
+  }
+}
+
+
+template 
+struct ApplyOp2 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a,
+                  detail::TensorInfo &b,
+                  const Op &op, int64_t n, IndexType linearIndex,
+                  Offsets... aOffsets, Offsets... bOffsets) {
+  // Convert `linearIndex` into an offset of `a`
+  const IndexType aOffset = static_cast(sizeof...(Offsets)) < n ?
+    detail::IndexToOffset::get(linearIndex, a) : 0;
+
+  // Convert `linearIndex` into an offset of `b`
+  const IndexType bOffset = static_cast(sizeof...(Offsets)) < n ?
+    detail::IndexToOffset::get(linearIndex, b) : 0;
+
+  ApplyOp2::apply(
+    a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
+  );
+}
+};
+
+// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
+// We don't need to pass in how many elements need to be processed in this case.
+template 
+struct ApplyOp2 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a,
+                  detail::TensorInfo &b,
+                  const Op &op, int /*n*/, IndexType /*linearIndex*/,
+                  Offset aOffset, Offset bOffset) {
+  op(a.data[aOffset], b.data[bOffset]);
+}
+};
+
+template 
+struct ApplyOp2 {
+__device__ __forceinline__
+static void apply(detail::TensorInfo &a,
+                  detail::TensorInfo &b,
+                  const Op &op, int n, IndexType linearIndex,
+                  Offsets... aOffsets, Offsets... bOffsets) {
+  op(n, a.data[aOffsets]..., b.data[bOffsets]...);
+}
+};
+
+template 
+#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
+C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
+#endif
+__global__ void
+kernelPointwiseApply2(detail::TensorInfo a,
+                      detail::TensorInfo b,
+                      IndexType totalElements,
+                      const Op op) {
+  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
+       linearIndex < totalElements;
+       linearIndex += gridDim.x * blockDim.x * step) {
+    ApplyOp2::apply(
+      a, b, op, ::min(step, static_cast(totalElements - linearIndex)),
+      linearIndex);
+  }
+}
+
+} // anonymous namespace
+
+template 
+inline bool CUDA_tensor_apply2(at::TensorBase a,
+                               at::TensorBase b,
+                               const Op op,
+                               TensorArgType aType = TensorArgType::ReadWrite,
+                               TensorArgType bType = TensorArgType::ReadOnly) {
+  TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
+              "CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
+              "tensors with type ", a.device().type(), " and ", b.device().type());
+  int64_t totalElements = a.numel();
+
+  if (totalElements != b.numel()) {
+    return false;
+  }
+
+  if (a.dim() > MAX_TENSORINFO_DIMS ||
+      b.dim() > MAX_TENSORINFO_DIMS) {
+    return false;
+  }
+
+  if (a.numel() == 0) {
+    // Empty tensor; do nothing
+    return true;
+  }
+  const dim3 block = getApplyBlock(max_threads_per_block);
+
+  dim3 grid;
+  auto curDevice = current_device();
+  if (curDevice == -1) return false;
+  if (!getApplyGrid(totalElements, grid, curDevice, max_threads_per_block)) {
+    return false;
+  }
+
+  /*
+  Expands readable/writable tensors whose indices may be "overlapped."
+  This ensures that each element of the tensor is operated on once and only
+  once.
+  */
+  TensorBase oldA;
+  TensorBase oldB;
+
+  if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
+    // Must perform in contiguous space
+    oldA = std::exchange(a, a.contiguous());
+  }
+  if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
+    // Must perform in contiguous space
+    oldB = std::exchange(b, b.contiguous());
+  }
+
+  // It is possible that the tensor dimensions are able to be collapsed,
+  // and thus we can reduce the actual code complexity of the copy by
+  // exploiting this knowledge statically, since the div/mod is the
+  // most expensive part of the operation, more so than memory accesses.
+  // For instance, when copying a non-contiguous to a contiguous tensor
+  // (or vice versa), the contiguous tensor can be collapsed to one
+  // dimension, and the loop to translate the linear index to the array
+  // index can be similarly collapsed. That is what this unrolling is for.
+
+#define HANDLE_CASE(TYPE, A, B)                                        \
+  kernelPointwiseApply2                             \
+   <<>>(    \
+       aInfo, bInfo, static_cast(totalElements), op);            \
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+#define HANDLE_B_CASE(TYPE, A, B) {         \
+  switch (B) {                              \
+    case 1:                                 \
+      HANDLE_CASE(TYPE, A, 1);              \
+      break;                                \
+    case 2:                                 \
+      HANDLE_CASE(TYPE, A, 2);              \
+      break;                                \
+    default:                                \
+      HANDLE_CASE(TYPE, A, -1);             \
+      break;                                \
+  }                                         \
+}
+
+#define HANDLE_A_CASE(TYPE, A, B) {         \
+  switch (A) {                              \
+    case 1:                                 \
+      HANDLE_B_CASE(TYPE, 1, B);            \
+      break;                                \
+    case 2:                                 \
+      HANDLE_B_CASE(TYPE, 2, B);            \
+      break;                                \
+    default:                                \
+      HANDLE_B_CASE(TYPE, -1, B);           \
+      break;                                \
+  }                                         \
+}
+
+  if (detail::canUse32BitIndexMath(a) &&
+      detail::canUse32BitIndexMath(b)) {
+    detail::TensorInfo aInfo =
+      detail::getTensorInfo(a);
+
+    detail::TensorInfo bInfo =
+      detail::getTensorInfo(b);
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
+
+    HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
+  } else {
+    detail::TensorInfo aInfo =
+      detail::getTensorInfo(a);
+
+    detail::TensorInfo bInfo =
+      detail::getTensorInfo(b);
+    rearrangeDims(&aInfo, &bInfo);
+    aInfo.collapseDims();
+    bInfo.collapseDims();
+
+    /*
+    Only instantiates the all 1D special case and the fallback all nD case for
+    large (64-bit indexed) tensors to reduce compilation time.
+    */
+    if (aInfo.dims == 1 && bInfo.dims == 1) {
+      HANDLE_CASE(uint64_t, 1, 1);
+    } else {
+      HANDLE_CASE(uint64_t, -1, -1);
+    }
+  }
+#undef HANDLE_CASE
+#undef HANDLE_B_CASE
+#undef HANDLE_A_CASE
+
+  if (oldA.defined()) {
+    at::native::copy_ignoring_overlaps(oldA, a);
+  }
+
+  if (oldB.defined()) {
+    at::native::copy_ignoring_overlaps(oldB, b);
+  }
+
+  return true;
+}
+
+/* Provides default step = 1 to CUDA_tensor_apply2. */
+template 
+inline bool CUDA_tensor_apply2(const at::TensorBase &a,
+                               const at::TensorBase &b,
+                               const Op op,
+                               TensorArgType aType = TensorArgType::ReadWrite,
+                               TensorArgType bType = TensorArgType::ReadOnly) {
+  return CUDA_tensor_apply2(a, b, op, aType, bType);
+}
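+
+/*
+  Usage sketch (illustrative only; tensor names and the lambda are hypothetical,
+  and assume the two scalar types are supplied as the leading template arguments):
+
+    at::Tensor a = at::zeros({4, 4}, at::kCUDA);
+    at::Tensor b = at::ones({4, 4}, at::kCUDA);
+    at::cuda::CUDA_tensor_apply2<float, float>(
+        a, b,
+        [] __device__ (float& x, const float& y) { x = 2.f * y; });
+
+  The call returns false (and does nothing) when the element counts differ or the
+  dimensionality exceeds MAX_TENSORINFO_DIMS, per the checks above.
+*/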
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h
new file mode 100644
index 0000000000000000000000000000000000000000..395a86902a322977291b7ab6370062dd47ac02d9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDABlas.h
@@ -0,0 +1,375 @@
+#pragma once
+/*
+  Provides a subset of CUDA BLAS functions as templates:
+
+    gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
+  ldc)
+
+    gemv(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
+
+    dot(n, x, incx, y, incy, result)
+
+  where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
+  The functions are available in at::cuda::blas namespace.
+ */
+
+#include 
+#include 
+
+namespace at::cuda::blas {
+
+// RAII guard that sets the CuBLAS pointer mode and restores it to
+// its previous value when the guard is destroyed
+class PointerModeGuard {
+public:
+  PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
+      handle(handle) {
+    TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
+    TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
+  }
+
+  ~PointerModeGuard() {
+    cublasSetPointerMode(handle, previous_mode);
+  }
+
+private:
+  cublasHandle_t handle;
+  cublasPointerMode_t previous_mode;
+};
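+
+// Usage sketch (illustrative only): temporarily switch cuBLAS to device pointer
+// mode while issuing calls whose alpha/beta scalars live in GPU memory.
+//
+//   {
+//     at::cuda::blas::PointerModeGuard guard(handle, CUBLAS_POINTER_MODE_DEVICE);
+//     // ... cuBLAS calls that read alpha/beta from device pointers ...
+//   }  // previous pointer mode is restored when the guard goes out of scope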
+
+/* LEVEL 3 BLAS FUNCTIONS */
+
+#define CUDABLAS_GEMM_ARGTYPES(Dtype)                                                       \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha,  \
+      const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type beta,\
+      Dtype *c, int64_t ldc
+
+#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
+
+template <typename Dtype>
+inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void gemm(CUDABLAS_GEMM_ARGTYPES(double));
+template <>
+void gemm(CUDABLAS_GEMM_ARGTYPES(float));
+template <>
+void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex));
+template <>
+void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex));
+template <>
+void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half));
+template <>
+void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
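+
+// Illustrative call sketch (not part of this header; pointer and size names are
+// hypothetical). Single-precision C = alpha * A * B + beta * C in column-major
+// layout with no transposition:
+//
+//   at::cuda::blas::gemm<float>(
+//       'n', 'n', m, n, k,
+//       1.0f, d_A, /*lda=*/m,
+//             d_B, /*ldb=*/k,
+//       0.0f, d_C, /*ldc=*/m);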
+
+template <typename Dtype>
+inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double));
+template <>
+void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float));
+template <>
+void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex));
+template <>
+void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex));
+template <>
+void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half));
+template <>
+void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
+
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+enum GEMMAndBiasActivationEpilogue {
+  None,
+  RELU,
+  GELU,
+};
+
+// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
+// do nothing if passed in that case.
+template <typename Dtype>
+void gemm_and_bias(
+    bool transpose_mat1,
+    bool transpose_mat2,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    at::opmath_type alpha_val,
+    const Dtype* mat1_ptr,
+    int64_t mat1_ld,
+    const Dtype* mat2_ptr,
+    int64_t mat2_ld,
+    const Dtype* bias,
+    Dtype* result_ptr,
+    int64_t result_ld,
+    GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
+
+void int8_gemm(
+    bool transpose_mat1,
+    bool transpose_mat2,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    const int8_t* mat1_ptr,
+    int64_t mat1_ld,
+    const int8_t* mat2_ptr,
+    int64_t mat2_ld,
+    int32_t* result_ptr,
+    int64_t result_ld);
+
+void scaled_gemm(
+    char transa,
+    char transb,
+    int64_t m,
+    int64_t n,
+    int64_t k,
+    const void* mat1_ptr,
+    const void* mat1_scale_ptr,
+    int64_t mat1_ld,
+    ScalarType mat1_dtype,
+    const void* mat2_ptr,
+    const void* mat2_scale_ptr,
+    int64_t mat2_ld,
+    ScalarType mat2_dtype,
+    const void* bias_ptr,
+    ScalarType bias_dtype,
+    void* result_ptr,
+    const void* result_scale_ptr,
+    int64_t result_ld,
+    ScalarType result_dtype,
+    void* amax_ptr,
+    bool use_fast_accum);
+#endif
+
+#define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                                        \
+  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type alpha,    \
+      const Dtype *a, int64_t lda, int64_t stridea,                                           \
+      const Dtype *b, int64_t ldb, int64_t strideb,                                           \
+      at::opmath_type beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
+
+#define CUDABLAS_BGEMM_ARGS(Dtype) \
+  transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
+
+template <typename Dtype>
+inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void bgemm(CUDABLAS_BGEMM_ARGTYPES(double));
+template <>
+void bgemm(CUDABLAS_BGEMM_ARGTYPES(float));
+template <>
+void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex));
+template <>
+void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex));
+template <>
+void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half));
+template <>
+void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
+
+template <typename Dtype>
+inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double));
+template <>
+void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float));
+template <>
+void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex));
+template <>
+void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex));
+template <>
+void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half));
+template <>
+void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
+
+#if defined(USE_ROCM) && ROCM_VERSION <= 50500
+// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
+#define CUDABLAS_TRSM_ARGTYPES(Dtype)                                  \
+  hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
+      hipblasOperation_t trans, hipblasDiagType_t diag, int m, int n,    \
+      const Dtype *alpha,       Dtype *A, int lda, Dtype *B, int ldb
+#else
+#define CUDABLAS_TRSM_ARGTYPES(Dtype)                                  \
+  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
+      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,    \
+      const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
+#endif
+
+template <typename Dtype>
+inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void trsm(CUDABLAS_TRSM_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex));
+template <>
+TORCH_CUDA_CU_API void trsm>(CUDABLAS_TRSM_ARGTYPES(c10::complex));
+
+#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)                          \
+  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
+      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,    \
+      const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb,    \
+      int batchCount
+
+template <typename Dtype>
+inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::blas::trsmBatched: not implemented for ",
+      typeid(Dtype).name());
+}
+
+template <>
+TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex));
+template <>
+TORCH_CUDA_CU_API void trsmBatched>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex));
+
+/* LEVEL 2 BLAS FUNCTIONS */
+
+#define CUDABLAS_GEMV_ARGTYPES(Dtype)                                         \
+  char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
+      const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
+
+template <typename Dtype>
+inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void gemv(CUDABLAS_GEMV_ARGTYPES(double));
+template <>
+void gemv(CUDABLAS_GEMV_ARGTYPES(float));
+template <>
+void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex));
+template <>
+void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex));
+template <>
+void gemv(CUDABLAS_GEMV_ARGTYPES(at::Half));
+template <>
+void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
+
+/* LEVEL 1 BLAS FUNCTIONS */
+
+#define CUDABLAS_DOT_ARGTYPES(Dtype)                                      \
+  cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
+      int incy, Dtype *result
+
+template <typename Dtype>
+inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void dot(CUDABLAS_DOT_ARGTYPES(double));
+template <>
+void dot(CUDABLAS_DOT_ARGTYPES(float));
+template <>
+void dot(CUDABLAS_DOT_ARGTYPES(at::Half));
+template <>
+void dot(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
+template <>
+void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex));
+template <>
+void dot>(CUDABLAS_DOT_ARGTYPES(c10::complex));
+
+template <typename Dtype>
+inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
+  AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name());
+}
+
+template <>
+void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex));
+template <>
+void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex));
+
+#define CUDABLAS_GETRS_ARGTYPES(Dtype)  \
+  cublasHandle_t handle, cublasOperation_t trans, \
+  int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
+  Dtype** dB_array, int ldb, int* info_array, int batchsize
+
+template <typename Dtype>
+void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
+    typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getrsBatched(CUDABLAS_GETRS_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex));
+template<>
+TORCH_CUDA_CU_API void getrsBatched>(CUDABLAS_GETRS_ARGTYPES(c10::complex));
+
+#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
+  cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
+      Dtype **tau_array, int *info, int batchsize
+
+template <typename Dtype>
+void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::blas::geqrfBatched: not implemented for ",
+      typeid(Dtype).name());
+}
+template <>
+TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched>(
+    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex));
+template <>
+TORCH_CUDA_CU_API void geqrfBatched>(
+    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex));
+
+#define CUDABLAS_GETRF_ARGTYPES(Dtype)  \
+  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
+
+template <typename Dtype>
+void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
+  TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
+}
+template<>
+TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void getrfBatched(CUDABLAS_GETRF_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex));
+template<>
+TORCH_CUDA_CU_API void getrfBatched>(CUDABLAS_GETRF_ARGTYPES(c10::complex));
+
+#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)  \
+  cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
+
+template <typename Dtype>
+void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
+  TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
+}
+
+template<>
+TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
+template<>
+TORCH_CUDA_CU_API void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
+template<>
+TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex));
+template<>
+TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex));
+
+} // namespace at::cuda::blas
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h
new file mode 100644
index 0000000000000000000000000000000000000000..b257e3f16b4adb5efde62dff92ed6f8fb9bc1a64
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContext.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include 
+
+// Preserved for BC, as many files depend on these includes
+#include 
+#include 
+#include 
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h
new file mode 100644
index 0000000000000000000000000000000000000000..efaf986bc75d611cf6cf637ca7eeebc156de9a53
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAContextLight.h
@@ -0,0 +1,95 @@
+#pragma once
+// Light-weight version of CUDAContext.h with fewer transitive includes
+
+#include 
+
+#include 
+#include 
+#include 
+
+// cublasLt was introduced in CUDA 10.1, but we enable it only for CUDA 11.1+,
+// which also added bf16 support
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#include 
+#endif
+
+#ifdef CUDART_VERSION
+#include 
+#endif
+
+#if defined(USE_ROCM) && ROCM_VERSION >= 50300
+#include 
+#endif
+
+#include 
+#include 
+
+namespace c10 {
+struct Allocator;
+}
+
+namespace at::cuda {
+
+/*
+A common CUDA interface for ATen.
+
+This interface is distinct from CUDAHooks, which defines an interface that links
+to both CPU-only and CUDA builds. That interface is intended for runtime
+dispatch and should be used from files that are included in both CPU-only and
+CUDA builds.
+
+CUDAContext, on the other hand, should be preferred by files only included in
+CUDA builds. It is intended to expose CUDA functionality in a consistent
+manner.
+
+This means there is some overlap between the CUDAContext and CUDAHooks, but
+the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
+use CUDAHooks otherwise.
+
+Note that CUDAContext simply defines an interface with no associated class.
+It is expected that the modules whose functions compose this interface will
+manage their own state. There is only a single CUDA context/state.
+*/
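+
+/*
+  Usage sketch (illustrative only), using the declarations below:
+
+    if (at::cuda::is_available()) {
+      cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+      int ws = at::cuda::warp_size();
+      cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+      // ... issue cuBLAS calls on the current stream using `handle` ...
+    }
+*/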
+
+/**
+ * DEPRECATED: use device_count() instead
+ */
+inline int64_t getNumGPUs() {
+    return c10::cuda::device_count();
+}
+
+/**
+ * CUDA is available if we compiled with CUDA, and there are one or more
+ * devices.  If we compiled with CUDA but there is a driver problem, etc.,
+ * this function will report that CUDA is not available (rather than raise an error).
+ */
+inline bool is_available() {
+    return c10::cuda::device_count() > 0;
+}
+
+TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
+
+TORCH_CUDA_CPP_API int warp_size();
+
+TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
+
+TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
+    c10::DeviceIndex device,
+    c10::DeviceIndex peer_device);
+
+TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
+
+/* Handles */
+TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
+TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
+#endif
+
+TORCH_CUDA_CPP_API void clearCublasWorkspaces();
+
+#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION >= 50300
+TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
+#endif
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h
new file mode 100644
index 0000000000000000000000000000000000000000..d14a908a54831124b8127fc5df10870fd6e31b3f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADataType.h
@@ -0,0 +1,115 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+namespace at::cuda {
+
+template <typename scalar_t>
+cudaDataType getCudaDataType() {
+  TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.")
+}
+
+template<> inline cudaDataType getCudaDataType<at::Half>() {
+  return CUDA_R_16F;
+}
+template<> inline cudaDataType getCudaDataType<float>() {
+  return CUDA_R_32F;
+}
+template<> inline cudaDataType getCudaDataType<double>() {
+  return CUDA_R_64F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
+  return CUDA_C_16F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
+  return CUDA_C_32F;
+}
+template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
+  return CUDA_C_64F;
+}
+
+// HIP doesn't define integral types
+#ifndef USE_ROCM
+template<> inline cudaDataType getCudaDataType<uint8_t>() {
+  return CUDA_R_8U;
+}
+template<> inline cudaDataType getCudaDataType<int8_t>() {
+  return CUDA_R_8I;
+}
+template<> inline cudaDataType getCudaDataType<int>() {
+  return CUDA_R_32I;
+}
+#endif
+
+#if !defined(USE_ROCM)
+template<> inline cudaDataType getCudaDataType<int16_t>() {
+  return CUDA_R_16I;
+}
+template<> inline cudaDataType getCudaDataType<int64_t>() {
+  return CUDA_R_64I;
+}
+template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
+  return CUDA_R_16BF;
+}
+#endif
+
+inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
+  switch (scalar_type) {
+// HIP doesn't define integral types
+#ifndef USE_ROCM
+    case c10::ScalarType::Byte:
+      return CUDA_R_8U;
+    case c10::ScalarType::Char:
+      return CUDA_R_8I;
+    case c10::ScalarType::Int:
+      return CUDA_R_32I;
+#endif
+    case c10::ScalarType::Half:
+      return CUDA_R_16F;
+    case c10::ScalarType::Float:
+      return CUDA_R_32F;
+    case c10::ScalarType::Double:
+      return CUDA_R_64F;
+    case c10::ScalarType::ComplexHalf:
+      return CUDA_C_16F;
+    case c10::ScalarType::ComplexFloat:
+      return CUDA_C_32F;
+    case c10::ScalarType::ComplexDouble:
+      return CUDA_C_64F;
+#if !defined(USE_ROCM)
+    case c10::ScalarType::Short:
+      return CUDA_R_16I;
+    case c10::ScalarType::Long:
+      return CUDA_R_64I;
+    case c10::ScalarType::BFloat16:
+      return CUDA_R_16BF;
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
+    case c10::ScalarType::Float8_e4m3fn:
+      return CUDA_R_8F_E4M3;
+    case c10::ScalarType::Float8_e5m2:
+      return CUDA_R_8F_E5M2;
+#endif
+#else // USE_ROCM
+    case c10::ScalarType::BFloat16:
+      return CUDA_R_16BF;
+#if defined(HIP_NEW_TYPE_ENUMS)
+    case c10::ScalarType::Float8_e4m3fnuz:
+      return HIP_R_8F_E4M3_FNUZ;
+    case c10::ScalarType::Float8_e5m2fnuz:
+      return HIP_R_8F_E5M2_FNUZ;
+#else
+    case c10::ScalarType::Float8_e4m3fnuz:
+      return static_cast<cudaDataType>(1000);
+    case c10::ScalarType::Float8_e5m2fnuz:
+      return static_cast<cudaDataType>(1001);
+#endif
+#endif
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
+  }
+}
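+
+// For example (illustrative), ScalarTypeToCudaDataType(c10::ScalarType::Float)
+// yields CUDA_R_32F, which can then be handed to cuBLAS/cuSPARSE generic APIs
+// that describe element types via cudaDataType.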
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h
new file mode 100644
index 0000000000000000000000000000000000000000..5353a06ca6b11f607151a0b7c64762234b617c79
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDADevice.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+namespace at::cuda {
+
+inline Device getDeviceFromPtr(void* ptr) {
+  cudaPointerAttributes attr{};
+
+  AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
+
+#if !defined(USE_ROCM)
+  TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
+    "The specified pointer resides on host memory and is not registered with any CUDA device.");
+#endif
+
+  return {c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(attr.device)};
+}
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h
new file mode 100644
index 0000000000000000000000000000000000000000..9a9a15c4a857b71137afc8735ad226b6d91e3e2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAEvent.h
@@ -0,0 +1,208 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+namespace at::cuda {
+
+/*
+* CUDAEvents are movable, but not copyable, wrappers around CUDA's events.
+*
+* A CUDAEvent is constructed lazily when first recorded, unless it is
+* reconstructed from a cudaIpcEventHandle_t. The event has a device, which is
+* acquired from the first recording stream. However, if reconstructed from a
+* handle, the device must be explicitly specified; and if ipc_handle() is
+* called before the event is ever recorded, the current device is used.
+* Later streams that record the event must match this device.
+*/
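+//
+// Usage sketch (illustrative only; requires timing-capable events, hence the
+// explicit cudaEventDefault flags):
+//
+//   at::cuda::CUDAEvent t0(cudaEventDefault), t1(cudaEventDefault);
+//   t0.record(at::cuda::getCurrentCUDAStream());
+//   // ... launch work on the current stream ...
+//   t1.record(at::cuda::getCurrentCUDAStream());
+//   t1.synchronize();
+//   float ms = t0.elapsed_time(t1);
+//
+// block(stream) instead makes `stream` wait on the event without blocking the host.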
+struct TORCH_CUDA_CPP_API CUDAEvent {
+  // Constructors
+  // Default value for `flags` is specified below - it's cudaEventDisableTiming
+  CUDAEvent() noexcept = default;
+  CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
+
+  CUDAEvent(
+      DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
+      device_index_ = device_index;
+      CUDAGuard guard(device_index_);
+
+      AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
+      is_created_ = true;
+  }
+
+  // Note: event destruction is done on the creating device to avoid creating
+  // a CUDA context on other devices.
+  ~CUDAEvent() {
+    try {
+      if (is_created_) {
+        CUDAGuard guard(device_index_);
+        const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+        if (C10_UNLIKELY(interp)) {
+          (*interp)->trace_gpu_event_deletion(reinterpret_cast(event_));
+        }
+        AT_CUDA_CHECK(cudaEventDestroy(event_));
+      }
+    } catch (...) { /* No throw */ }
+  }
+
+  CUDAEvent(const CUDAEvent&) = delete;
+  CUDAEvent& operator=(const CUDAEvent&) = delete;
+
+  CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
+  CUDAEvent& operator=(CUDAEvent&& other) noexcept {
+    if (this != &other) {
+      moveHelper(std::move(other));
+    }
+    return *this;
+  }
+
+  operator cudaEvent_t() const { return event(); }
+
+  // Less than operator (to allow use in sets)
+  friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
+    return left.event_ < right.event_;
+  }
+
+  optional<at::Device> device() const {
+    if (is_created_) {
+      return at::Device(at::kCUDA, device_index_);
+    } else {
+      return {};
+    }
+  }
+
+  bool isCreated() const { return is_created_; }
+  DeviceIndex device_index() const {return device_index_;}
+  cudaEvent_t event() const { return event_; }
+
+  // Note: cudaEventQuery can be safely called from any device
+  bool query() const {
+    if (!is_created_) {
+      return true;
+    }
+
+    cudaError_t err = cudaEventQuery(event_);
+    if (err == cudaSuccess) {
+      return true;
+    } else if (err != cudaErrorNotReady) {
+      C10_CUDA_CHECK(err);
+    } else {
+      // ignore and clear the error if not ready
+      (void)cudaGetLastError();
+    }
+
+    return false;
+  }
+
+  void record() { record(getCurrentCUDAStream()); }
+
+  void recordOnce(const CUDAStream& stream) {
+    if (!was_recorded_) record(stream);
+  }
+
+  // Note: cudaEventRecord must be called on the same device as the event.
+  void record(const CUDAStream& stream) {
+    if (!is_created_) {
+      createEvent(stream.device_index());
+    }
+
+    TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
+      " does not match recording stream's device ", stream.device_index(), ".");
+    CUDAGuard guard(device_index_);
+    AT_CUDA_CHECK(cudaEventRecord(event_, stream));
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_event_record(
+          reinterpret_cast(event_),
+          reinterpret_cast(stream.stream())
+      );
+    }
+    was_recorded_ = true;
+  }
+
+  // Note: cudaStreamWaitEvent must be called on the same device as the stream.
+  // The event has no actual GPU resources associated with it.
+  void block(const CUDAStream& stream) {
+    if (is_created_) {
+      CUDAGuard guard(stream.device_index());
+      AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_event_wait(
+            reinterpret_cast(event_),
+            reinterpret_cast(stream.stream())
+        );
+      }
+    }
+  }
+
+  // Note: cudaEventElapsedTime can be safely called from any device
+  float elapsed_time(const CUDAEvent& other) const {
+    TORCH_CHECK(is_created_ && other.isCreated(),
+      "Both events must be recorded before calculating elapsed time.");
+    float time_ms = 0;
+    // raise cudaErrorNotReady if either event is recorded but not yet completed
+    AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
+    return time_ms;
+  }
+
+  // Note: cudaEventSynchronize can be safely called from any device
+  void synchronize() const {
+    if (is_created_) {
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+          (*interp)->trace_gpu_event_synchronization(reinterpret_cast(event_));
+      }
+      AT_CUDA_CHECK(cudaEventSynchronize(event_));
+    }
+  }
+
+  // Note: cudaIpcGetEventHandle must be called on the same device as the event
+  void ipc_handle(cudaIpcEventHandle_t * handle) {
+      if (!is_created_) {
+        // this CUDAEvent object was initially constructed from flags but event_
+        // is not created yet.
+        createEvent(getCurrentCUDAStream().device_index());
+      }
+      CUDAGuard guard(device_index_);
+      AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
+  }
+
+private:
+  unsigned int flags_ = cudaEventDisableTiming;
+  bool is_created_ = false;
+  bool was_recorded_ = false;
+  DeviceIndex device_index_ = -1;
+  cudaEvent_t event_{};
+
+  void createEvent(DeviceIndex device_index) {
+    device_index_ = device_index;
+    CUDAGuard guard(device_index_);
+    AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_event_creation(reinterpret_cast(event_));
+    }
+    is_created_ = true;
+  }
+
+  void moveHelper(CUDAEvent&& other) {
+    std::swap(flags_, other.flags_);
+    std::swap(is_created_, other.is_created_);
+    std::swap(was_recorded_, other.was_recorded_);
+    std::swap(device_index_, other.device_index_);
+    std::swap(event_, other.event_);
+  }
+};
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..f237d77f009d32080a1d9445bc42b256939a78ae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h
@@ -0,0 +1,138 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+/**
+ * Note [CUDA Graph-safe RNG states]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ *
+ * Strategy:
+ * ~~~~~~~~~
+ * (It helps to look at
+ * cuda/detail/PhiloxCudaStateRaw.cuh and
+ * cuda/detail/UnpackRaw.cuh
+ * while you read this.)
+ *
+ * A CUDA graph containing multiple RNG ops behaves like a
+ * single giant kernel from the perspective of ops external
+ * to the graph.  During graph capture, logic in CUDAGeneratorImpl
+ * records the total of all offset increments that occur in the
+ * graphed region, and records the final total as the offset for
+ * the entire graph.
+ *
+ * When the graph reruns, the logic that reruns it
+ * increments this device's CUDA generator's offset
+ * by that total.
+ *
+ * Meanwhile, within the graph, at capture time, instead of
+ * populating PhiloxCudaStates with the uint64_t offset pulled
+ * directly from the global state, PhiloxCudaState uses a pointer
+ * to a one-element stream-local int64_t device tensor
+ * holding an initial offset value, and a uint64_t holding an
+ * intra-graph offset. (The intra-graph offset starts from zero
+ * when capture begins.)  In each consumer kernel,
+ * at::cuda::philox::unpack computes the offset to use for this kernel
+ * as intra-graph offset + *initial offset.
+ *
+ * When the graph reruns, the logic that reruns it first
+ * fill_s the initial offset tensor with this device's
+ * CUDA generator's current offset.
+ *
+ * The control flow above ensures graphed execution is bitwise
+ * identical to eager execution as long as RNG ops are enqueued
+ * from a single thread, even if RNG ops and graphs containing
+ * RNG ops are enqueued and run simultaneously on multiple streams.
+ *
+ * Usage:
+ * ~~~~~~
+ * PhiloxCudaState in this file, and unpack() in
+ * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
+ * CUDAGeneratorImpl whether graph capture is underway or not.
+ *
+ * Each PhiloxCudaState instance should be used for one and only one
+ * consumer kernel.
+ *
+ * Example (see e.g. native/cuda/Dropout.cu):
+ *
+ * #include 
+ * #include 
+ *
+ * __global__ void kernel(..., PhiloxCudaState philox_args) {
+ *   auto seeds = at::cuda::philox::unpack(philox_args);
+ *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
+ *   curandStatePhilox4_32_10_t state;
+ *   curand_init(std::get<0>(seeds), // seed
+ *               idx,                // per-thread subsequence
+ *               std::get<1>(seeds), // offset in subsequence
+ *               &state);
+ *   ...
+ * }
+ *
+ * host_caller(...) {
+ *   PhiloxCudaState rng_engine_inputs;
+ *   {
+ *     // See Note [Acquire lock when using random generators]
+ *     std::lock_guard lock(gen->mutex_);
+ *
+ *     // gen could be HostState or DevState here! No divergent code needed!
+ *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
+ *   }
+ *   kernel<<<...>>>(..., rng_engine_inputs);
+ * }
+ *
+ */
+
+struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  CUDAGeneratorImpl(DeviceIndex device_index = -1);
+  ~CUDAGeneratorImpl() override = default;
+
+  // CUDAGeneratorImpl methods
+  std::shared_ptr<CUDAGeneratorImpl> clone() const;
+  void set_current_seed(uint64_t seed) override;
+  void set_offset(uint64_t offset) override;
+  uint64_t get_offset() const override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_state(const c10::TensorImpl& new_state) override;
+  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
+  void set_philox_offset_per_thread(uint64_t offset);
+  uint64_t philox_offset_per_thread() const;
+  void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
+  uint64_t capture_epilogue();
+  PhiloxCudaState philox_cuda_state(uint64_t increment);
+
+  bool reset_rnn_state() {
+    return !no_reset_rnn_state_.test_and_set();
+  }
+
+  // Temporarily accommodates call sites that use philox_engine_inputs.
+  // Allows incremental refactor of call sites to use philox_cuda_state.
+  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
+
+  static c10::DeviceType device_type();
+
+private:
+  CUDAGeneratorImpl* clone_impl() const override;
+  uint64_t seed_ = default_rng_seed_val;
+  uint64_t philox_offset_per_thread_ = 0;
+  int64_t* seed_extragraph_{};
+  int64_t* offset_extragraph_{};
+  uint32_t offset_intragraph_ = 0;
+  bool graph_expects_this_gen_ = false;
+  std::atomic_flag no_reset_rnn_state_;
+};
+
+namespace cuda::detail {
+
+TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
+    DeviceIndex device_index = -1);
+TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
+
+} // namespace cuda::detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h
new file mode 100644
index 0000000000000000000000000000000000000000..8b3c1a3f27393f13971622f6b432818ece002cb1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraph.h
@@ -0,0 +1,92 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+
+struct CUDAGeneratorImpl;
+
+namespace cuda {
+
+// Standalone way to get a unique mempool id usable as a pool=... argument
+// to CUDAGraph::capture_begin
+TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
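+
+// Usage sketch (illustrative only; capture-stream setup and the captured work
+// are elided):
+//
+//   auto pool = at::cuda::graph_pool_handle();
+//   at::cuda::CUDAGraph g1, g2;
+//   g1.capture_begin(pool); /* ... captured work ... */ g1.capture_end();
+//   g2.capture_begin(pool); /* ... captured work ... */ g2.capture_end();
+//   g1.replay();
+//   g2.replay();
+//
+// Both captures then draw from the same private mempool, as described for
+// mempool_id_ in the struct below.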
+
+struct TORCH_CUDA_CPP_API CUDAGraph {
+  CUDAGraph();
+  ~CUDAGraph();
+
+  static void inc_pending_event_queries();
+  static void dec_pending_event_queries();
+  static int num_pending_event_queries();
+  void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
+  void capture_end();
+  void replay();
+  void reset();
+  MempoolId_t pool();
+  void enable_debug_mode();
+  void debug_dump(const std::string& debug_path);
+
+  protected:
+#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
+  cudaGraph_t graph_ = NULL;
+  cudaGraphExec_t graph_exec_ = NULL;
+#endif
+
+  static std::atomic<int> pending_event_queries;
+
+  // internal states so reset() can do its best cleaning up
+  // Set to true in capture_end if cudaStreamEndCapture succeeded
+  // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
+  // to create graph_exec_, then graph_ is deleted
+  bool has_graph_ = false;
+  // Set to true in capture_end if cudaGraphInstantiate succeeded
+  bool has_graph_exec_ = false;
+
+  // uuid of this instance's current capture, used to
+  // specify the pool.
+  CaptureId_t id_;
+
+  // the ID assigned by cuda during graph capture,
+  // used to identify when a stream is participating in capture
+  CaptureId_t capture_id_ = -1;
+
+  // uuid used to request a particular private mempool from CUDACachingAllocator.
+  // By default, this will be set to {id_, 0}.
+  //
+  // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
+  // will be set to the other graph's mempool_id_, and therefore share a mempool with the
+  // other graph.
+  //
+  // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
+  // it will share a mempool with any other captures that used "pool=handle".
+  //
+  // Sharing a mempool across graphs saves memory, and it's safe if you
+  // know you'll replay those graphs in the same order you captured them.
+  MempoolId_t mempool_id_;
+
+  // Stream on which capture began
+  at::cuda::CUDAStream capture_stream_;
+
+  // Default generator on device where capture began
+  at::CUDAGeneratorImpl* capture_gen_;
+
+  // Device where capture occurred. Right now, for simplicity, we require all ops
+  // in a capture to run on the same device, but this is a limitation of CUDAGraph,
+  // not CUDA itself.  We can straightforwardly modify CUDAGraph to support multi-device
+  // captures if needed.
+  int capture_dev_;
+
+  // RNG state trackers
+  at::Tensor seed_extragraph_;
+  at::Tensor offset_extragraph_;
+  uint64_t wholegraph_increment_;
+};
+
+} // namespace cuda
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..8de2adbb7ec9a7c8f47f23bebb855135a7452885
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh
@@ -0,0 +1,57 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
+// This file adds utils used by aten only.
+
+namespace at::cuda {
+
+using CaptureId_t = c10::cuda::CaptureId_t;
+using CaptureStatus = c10::cuda::CaptureStatus;
+
+// Use this version where you don't want to create a CUDA context if none exists.
+inline CaptureStatus currentStreamCaptureStatus() {
+#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
+  // don't create a context if we don't have to
+  if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
+    return c10::cuda::currentStreamCaptureStatusMayInitCtx();
+  } else {
+    return CaptureStatus::None;
+  }
+#else
+  return CaptureStatus::None;
+#endif
+}
+
+inline void assertNotCapturing(std::string attempt) {
+  auto status = currentStreamCaptureStatus();
+  TORCH_CHECK(status == CaptureStatus::None,
+              attempt,
+              " during CUDA graph capture. If you need this call to be captured, "
+              "please file an issue. "
+              "Current cudaStreamCaptureStatus: ",
+              status);
+}
+
+inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
+  auto status = currentStreamCaptureStatus();
+  TORCH_CHECK(status == CaptureStatus::None,
+              "Current cudaStreamCaptureStatus: ",
+              status,
+              "\nCapturing ",
+              version_specific,
+              "is prohibited. Possible causes of this error:\n"
+              "1. No warmup iterations occurred before capture.\n"
+              "2. The convolutions you're trying to capture use dynamic shapes, "
+              "in which case capturing them is generally prohibited.");
+}
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2b5c0beade3776bc8c98edad6a4d05460dff1c4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparse.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include 
+#if defined(USE_ROCM)
+#include 
+#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
+#endif
+
+// cuSparse Generic API added in CUDA 10.1
+// Windows support added in CUDA 11.0
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
+#define AT_USE_CUSPARSE_GENERIC_API() 1
+#else
+#define AT_USE_CUSPARSE_GENERIC_API() 0
+#endif
+
+// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
+    (CUSPARSE_VERSION < 12000)
+#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
+#else
+#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
+#endif
+
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
+    (CUSPARSE_VERSION >= 12000)
+#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
+#else
+#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
+#endif
+
+#if defined(USE_ROCM)
+// hipSparse const API added in v2.4.0
+#if HIPSPARSE_VERSION >= 200400
+#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
+#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
+#define AT_USE_HIPSPARSE_GENERIC_API() 1
+#else
+#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
+#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
+#define AT_USE_HIPSPARSE_GENERIC_API() 1
+#endif
+#else // USE_ROCM
+#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
+#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
+#define AT_USE_HIPSPARSE_GENERIC_API() 0
+#endif // USE_ROCM
+
+// cuSparse Generic API spsv function was added in CUDA 11.3.0
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
+#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
+#else
+#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
+#endif
+
+// cuSparse Generic API spsm function was added in CUDA 11.3.1
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
+#define AT_USE_CUSPARSE_GENERIC_SPSM() 1
+#else
+#define AT_USE_CUSPARSE_GENERIC_SPSM() 0
+#endif
+
+// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
+#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
+#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
+#else
+#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
+#endif
+
+// BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
+#if defined(CUDART_VERSION) ||                            \
+      (defined(USE_ROCM) && ROCM_VERSION >= 40500 )
+#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
+#else
+#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
+#endif
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h
new file mode 100644
index 0000000000000000000000000000000000000000..9eb0488d2b3dbe7a64dbdd0463f5fbd53b9cde18
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h
@@ -0,0 +1,318 @@
+#pragma once
+
+/*
+  Provides a subset of cuSPARSE functions as templates:
+
+    csrgeam2(...)
+
+  where scalar_t is double, float, c10::complex<float> or c10::complex<double>.
+  The functions are available in at::cuda::sparse namespace.
+*/
+
+#include 
+#include 
+
+namespace at::cuda::sparse {
+
+#define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)             \
+  cusparseHandle_t handle, int m, int n, const scalar_t *alpha,     \
+      const cusparseMatDescr_t descrA, int nnzA,                    \
+      const scalar_t *csrSortedValA, const int *csrSortedRowPtrA,   \
+      const int *csrSortedColIndA, const scalar_t *beta,            \
+      const cusparseMatDescr_t descrB, int nnzB,                    \
+      const scalar_t *csrSortedValB, const int *csrSortedRowPtrB,   \
+      const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
+      const scalar_t *csrSortedValC, const int *csrSortedRowPtrC,   \
+      const int *csrSortedColIndC, size_t *pBufferSizeInBytes
+
+template <typename scalar_t>
+inline void csrgeam2_bufferSizeExt(
+    CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void csrgeam2_bufferSizeExt(
+    CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
+template <>
+void csrgeam2_bufferSizeExt(
+    CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
+template <>
+void csrgeam2_bufferSizeExt>(
+    CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex));
+template <>
+void csrgeam2_bufferSizeExt>(
+    CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex));
+
+#define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()                                      \
+  cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA,     \
+      int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA,     \
+      const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
+      const int *csrSortedColIndB, const cusparseMatDescr_t descrC,           \
+      int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
+
+template <typename scalar_t>
+inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
+  TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
+      handle,
+      m,
+      n,
+      descrA,
+      nnzA,
+      csrSortedRowPtrA,
+      csrSortedColIndA,
+      descrB,
+      nnzB,
+      csrSortedRowPtrB,
+      csrSortedColIndB,
+      descrC,
+      csrSortedRowPtrC,
+      nnzTotalDevHostPtr,
+      workspace));
+}
+
+#define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)                                 \
+  cusparseHandle_t handle, int m, int n, const scalar_t *alpha,              \
+      const cusparseMatDescr_t descrA, int nnzA,                             \
+      const scalar_t *csrSortedValA, const int *csrSortedRowPtrA,            \
+      const int *csrSortedColIndA, const scalar_t *beta,                     \
+      const cusparseMatDescr_t descrB, int nnzB,                             \
+      const scalar_t *csrSortedValB, const int *csrSortedRowPtrB,            \
+      const int *csrSortedColIndB, const cusparseMatDescr_t descrC,          \
+      scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
+      void *pBuffer
+
+template <typename scalar_t>
+inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::csrgeam2: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(float));
+template <>
+void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(double));
+template <>
+void csrgeam2>(
+    CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex));
+template <>
+void csrgeam2>(
+    CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex));
+
+#define CUSPARSE_BSRMM_ARGTYPES(scalar_t)                                    \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
+      int kb, int nnzb, const scalar_t *alpha,                               \
+      const cusparseMatDescr_t descrA, const scalar_t *bsrValA,              \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
+      const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
+
+template <typename scalar_t>
+inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrmm: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrmm(CUSPARSE_BSRMM_ARGTYPES(float));
+template <>
+void bsrmm(CUSPARSE_BSRMM_ARGTYPES(double));
+template <>
+void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex));
+template <>
+void bsrmm>(CUSPARSE_BSRMM_ARGTYPES(c10::complex));
+
+#define CUSPARSE_BSRMV_ARGTYPES(scalar_t)                                    \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, int mb, int nb, int nnzb,                  \
+      const scalar_t *alpha, const cusparseMatDescr_t descrA,                \
+      const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
+      int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
+
+template <typename scalar_t>
+inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrmv: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrmv(CUSPARSE_BSRMV_ARGTYPES(float));
+template <>
+void bsrmv(CUSPARSE_BSRMV_ARGTYPES(double));
+template <>
+void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex));
+template <>
+void bsrmv>(CUSPARSE_BSRMV_ARGTYPES(c10::complex));
+
+#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
+
+#define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)                 \
+  cusparseHandle_t handle, cusparseDirection_t dirA,              \
+      cusparseOperation_t transA, int mb, int nnzb,               \
+      const cusparseMatDescr_t descrA, scalar_t *bsrValA,         \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
+      bsrsv2Info_t info, int *pBufferSizeInBytes
+
+template <typename scalar_t>
+inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
+template <>
+void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
+template <>
+void bsrsv2_bufferSize>(
+    CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex));
+template <>
+void bsrsv2_bufferSize>(
+    CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex));
+
+#define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)               \
+  cusparseHandle_t handle, cusparseDirection_t dirA,              \
+      cusparseOperation_t transA, int mb, int nnzb,               \
+      const cusparseMatDescr_t descrA, const scalar_t *bsrValA,   \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
+      bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
+
+template <typename scalar_t>
+inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsv2_analysis: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
+template <>
+void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
+template <>
+void bsrsv2_analysis<c10::complex<float>>(
+    CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
+template <>
+void bsrsv2_analysis<c10::complex<double>>(
+    CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
+
+#define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)                           \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                       \
+      cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
+      const cusparseMatDescr_t descrA, const scalar_t *bsrValA,            \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,          \
+      bsrsv2Info_t info, const scalar_t *x, scalar_t *y,                   \
+      cusparseSolvePolicy_t policy, void *pBuffer
+
+template <typename scalar_t>
+inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsv2_solve: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
+template <>
+void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
+template <>
+void bsrsv2_solve<c10::complex<float>>(
+    CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
+template <>
+void bsrsv2_solve<c10::complex<double>>(
+    CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
+
+#define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)                            \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
+      int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA,          \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
+      bsrsm2Info_t info, int *pBufferSizeInBytes
+
+template <typename scalar_t>
+inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
+template <>
+void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
+template <>
+void bsrsm2_bufferSize<c10::complex<float>>(
+    CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
+template <>
+void bsrsm2_bufferSize<c10::complex<double>>(
+    CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
+
+#define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)                          \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
+      int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA,    \
+      const int *bsrRowPtrA, const int *bsrColIndA, int blockDim,            \
+      bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
+
+template <typename scalar_t>
+inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsm2_analysis: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
+template <>
+void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
+template <>
+void bsrsm2_analysis<c10::complex<float>>(
+    CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
+template <>
+void bsrsm2_analysis<c10::complex<double>>(
+    CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
+
+#define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)                             \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
+      int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA,      \
+      const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
+      int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb,           \
+      scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
+
+template <typename scalar_t>
+inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrsm2_solve: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
+template <>
+void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
+template <>
+void bsrsm2_solve<c10::complex<float>>(
+    CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
+template <>
+void bsrsm2_solve<c10::complex<double>>(
+    CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
+
+#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
+
+} // namespace at::cuda::sparse
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h
new file mode 100644
index 0000000000000000000000000000000000000000..b5a5391ab3642ef5301b1d87a94a4306bf5f6929
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h
@@ -0,0 +1,290 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+#if defined(USE_ROCM)
+#include 
+#endif
+
+namespace at::cuda::sparse {
+
+template <typename T, cusparseStatus_t (*destructor)(T*)>
+struct CuSparseDescriptorDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      TORCH_CUDASPARSE_CHECK(destructor(x));
+    }
+  }
+};
+
+template <typename T, cusparseStatus_t (*destructor)(T*)>
+class CuSparseDescriptor {
+ public:
+  T* descriptor() const {
+    return descriptor_.get();
+  }
+  T* descriptor() {
+    return descriptor_.get();
+  }
+
+ protected:
+  std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
+};
+
+#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
+template <typename T, cusparseStatus_t (*destructor)(const T*)>
+struct ConstCuSparseDescriptorDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      TORCH_CUDASPARSE_CHECK(destructor(x));
+    }
+  }
+};
+
+template <typename T, cusparseStatus_t (*destructor)(const T*)>
+class ConstCuSparseDescriptor {
+ public:
+  T* descriptor() const {
+    return descriptor_.get();
+  }
+  T* descriptor() {
+    return descriptor_.get();
+  }
+
+ protected:
+  std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
+};
+#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
+
+#if defined(USE_ROCM)
+using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
+using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
+using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
+using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
+using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
+using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
+#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
+using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
+using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
+#endif
+#endif
+
+// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced
+// API for const descriptors
+cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr);
+
+class TORCH_CUDA_CPP_API CuSparseMatDescriptor
+    : public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
+ public:
+  CuSparseMatDescriptor() {
+    cusparseMatDescr_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+
+  CuSparseMatDescriptor(bool upper, bool unit) {
+    cusparseFillMode_t fill_mode =
+        upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
+    cusparseDiagType_t diag_type =
+        unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
+    cusparseMatDescr_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
+    TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
+    TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
+    descriptor_.reset(raw_descriptor);
+  }
+};
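+
+// Illustrative usage sketch: the descriptor classes above are thin RAII
+// wrappers, so a typical call site constructs one on the stack and passes
+// descriptor() straight into the cuSPARSE call. The names `triangular` and
+// `descr` below are hypothetical, not part of this header.
+//
+//   at::cuda::sparse::CuSparseMatDescriptor triangular(/*upper=*/true,
+//                                                      /*unit=*/false);
+//   cusparseMatDescr_t descr = triangular.descriptor();
+//   // descr stays valid until `triangular` goes out of scope; the deleter
+//   // then calls cusparseDestroyMatDescr automatically.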
+
+#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
+
+class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
+    : public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
+ public:
+  CuSparseBsrsv2Info() {
+    bsrsv2Info_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+};
+
+class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
+    : public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
+ public:
+  CuSparseBsrsm2Info() {
+    bsrsm2Info_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+};
+
+#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
+
+#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
+
+cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
+
+#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS()
+class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
+    : public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
+ public:
+  explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
+};
+
+class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
+    : public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> {
+ public:
+  explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
+  cusparseDnMatDescr* unsafe_mutable_descriptor() const {
+    return const_cast<cusparseDnMatDescr*>(descriptor());
+  }
+  cusparseDnMatDescr* unsafe_mutable_descriptor() {
+    return const_cast<cusparseDnMatDescr*>(descriptor());
+  }
+};
+
+class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
+    : public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
+ public:
+  explicit CuSparseDnVecDescriptor(const Tensor& input);
+};
+
+class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
+    : public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
+
+#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
+  class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
+      : public ConstCuSparseDescriptor<
+            cusparseDnMatDescr,
+            &cusparseDestroyDnMat> {
+   public:
+    explicit CuSparseDnMatDescriptor(
+        const Tensor& input,
+        int64_t batch_offset = -1);
+  };
+
+  class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
+      : public ConstCuSparseDescriptor<
+            const cusparseDnMatDescr,
+            &destroyConstDnMat> {
+   public:
+    explicit CuSparseConstDnMatDescriptor(
+        const Tensor& input,
+        int64_t batch_offset = -1);
+  cusparseDnMatDescr* unsafe_mutable_descriptor() const {
+    return const_cast<cusparseDnMatDescr*>(descriptor());
+  }
+  cusparseDnMatDescr* unsafe_mutable_descriptor() {
+    return const_cast<cusparseDnMatDescr*>(descriptor());
+  }
+  };
+
+  class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
+      : public ConstCuSparseDescriptor<
+            cusparseDnVecDescr,
+            &cusparseDestroyDnVec> {
+   public:
+    explicit CuSparseDnVecDescriptor(const Tensor& input);
+  };
+
+  class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
+      : public ConstCuSparseDescriptor<
+            cusparseSpMatDescr,
+            &cusparseDestroySpMat> {};
+#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
+
+class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
+    : public CuSparseSpMatDescriptor {
+ public:
+  explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
+
+  std::tuple<int64_t, int64_t, int64_t> get_size() {
+    int64_t rows, cols, nnz;
+    TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
+        this->descriptor(),
+        &rows,
+        &cols,
+        &nnz));
+    return std::make_tuple(rows, cols, nnz);
+  }
+
+  void set_tensor(const Tensor& input) {
+    auto crow_indices = input.crow_indices();
+    auto col_indices = input.col_indices();
+    auto values = input.values();
+
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
+    TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
+        this->descriptor(),
+        crow_indices.data_ptr(),
+        col_indices.data_ptr(),
+        values.data_ptr()));
+  }
+
+#if AT_USE_CUSPARSE_GENERIC_SPSV()
+  void set_mat_fill_mode(bool upper) {
+    cusparseFillMode_t fill_mode =
+        upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
+    TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
+        this->descriptor(),
+        CUSPARSE_SPMAT_FILL_MODE,
+        &fill_mode,
+        sizeof(fill_mode)));
+  }
+
+  void set_mat_diag_type(bool unit) {
+    cusparseDiagType_t diag_type =
+        unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
+    TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
+        this->descriptor(),
+        CUSPARSE_SPMAT_DIAG_TYPE,
+        &diag_type,
+        sizeof(diag_type)));
+  }
+#endif
+};
+
+#if AT_USE_CUSPARSE_GENERIC_SPSV()
+class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
+    : public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
+ public:
+  CuSparseSpSVDescriptor() {
+    cusparseSpSVDescr_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+};
+#endif
+
+#if AT_USE_CUSPARSE_GENERIC_SPSM()
+class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
+    : public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
+ public:
+  CuSparseSpSMDescriptor() {
+    cusparseSpSMDescr_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+};
+#endif
+
+#if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM)
+class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
+    : public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
+ public:
+  CuSparseSpGEMMDescriptor() {
+    cusparseSpGEMMDescr_t raw_descriptor;
+    TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+};
+#endif
+
+#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
+
+} // namespace at::cuda::sparse
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..e3c8526a0004cde8198965f3aea34af25ac5c452
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh
@@ -0,0 +1,15 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+template <>
+inline __half* Tensor::data() const {
+  return reinterpret_cast<__half*>(data<at::Half>());
+}
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..f41fae69ea89d078d61ebb3f698d0e24904761a0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include 
+
+namespace at::cuda {
+
+// Check if every tensor in a list of tensors matches the current
+// device.
+inline bool check_device(ArrayRef<Tensor> ts) {
+  if (ts.empty()) {
+    return true;
+  }
+  Device curDevice = Device(kCUDA, current_device());
+  for (const Tensor& t : ts) {
+    if (t.device() != curDevice) return false;
+  }
+  return true;
+}
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..550000c0e580f0a91932c44161979922d5e00227
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::cuda {
+
+//
+// A caching allocator for CUDA host allocations (pinned memory).
+//
+// This provides a drop-in replacement for THCudaHostAllocator, which re-uses
+// freed pinned (page-locked) memory allocations. This avoids device
+// synchronizations due to cudaFreeHost calls.
+//
+// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be
+// called anytime a pointer from this allocator is used in a cudaMemcpyAsync
+// call between host and device, and passed the corresponding context from the
+// allocation. This is currently invoked by at::native::copy_kernel_cuda.
+//
+// Note that this allocator does not split larger allocations into smaller
+// blocks, unlike the caching device allocator.
+//
+TORCH_CUDA_CPP_API c10::Allocator* getCachingHostAllocator();
+
+// Records an event in the specified stream. The allocation corresponding to the
+// input `ptr`/`ctx` will not be re-used until the event has occurred.
+TORCH_CUDA_CPP_API bool
+CachingHostAllocator_recordEvent(void* ptr, void* ctx, c10::cuda::CUDAStream stream);
+
+// Releases cached pinned memory allocations via cudaFreeHost
+TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache();
+
+inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) {
+  return getCachingHostAllocator()->allocate(size);
+}
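+
+// Illustrative usage sketch of the contract described above: after a
+// cudaMemcpyAsync on pinned memory obtained from this allocator, record an
+// event so the block is not recycled before the copy finishes. The helper
+// name `async_h2d_example` and its arguments are hypothetical.
+//
+//   inline void async_h2d_example(void* device_dst, size_t nbytes) {
+//     at::DataPtr pinned = at::cuda::HostAlloc(nbytes);
+//     auto stream = c10::cuda::getCurrentCUDAStream();
+//     AT_CUDA_CHECK(cudaMemcpyAsync(device_dst, pinned.get(), nbytes,
+//                                   cudaMemcpyHostToDevice, stream));
+//     at::cuda::CachingHostAllocator_recordEvent(
+//         pinned.get(), pinned.get_context(), stream);
+//   }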
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3a258954db6306d16caf24906499faa7bc54aa77
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh
@@ -0,0 +1,121 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+__device__ __forceinline__ unsigned int ACTIVE_MASK()
+{
+#if !defined(USE_ROCM)
+    return __activemask();
+#else
+// will be ignored anyway
+    return 0xffffffff;
+#endif
+}
+
+__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
+#if !defined(USE_ROCM)
+  return __syncwarp(mask);
+#endif
+}
+
+#if defined(USE_ROCM)
+__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
+{
+return __ballot(predicate);
+}
+#else
+__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return __ballot_sync(mask, predicate);
+#else
+    return __ballot(predicate);
+#endif
+}
+#endif
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+    return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return __shfl_sync(mask, value, srcLane, width);
+#else
+    return __shfl(value, srcLane, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return __shfl_up_sync(mask, value, delta, width);
+#else
+    return __shfl_up(value, delta, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return __shfl_down_sync(mask, value, delta, width);
+#else
+    return __shfl_down(value, delta, width);
+#endif
+}
+
+#if defined(USE_ROCM)
+template<>
+__device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width , unsigned int mask)
+{
+  //(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
+  int2 a = *reinterpret_cast<int2*>(&value);
+  a.x = __shfl_down(a.x, delta);
+  a.y = __shfl_down(a.y, delta);
+  return *reinterpret_cast<int64_t*>(&a);
+}
+#endif
+
+template<>
+__device__ __forceinline__ c10::Half WARP_SHFL_DOWN(c10::Half value, unsigned int delta, int width, unsigned int mask)
+{
+  return c10::Half(WARP_SHFL_DOWN(value.x, delta, width, mask), c10::Half::from_bits_t{});
+}
+
+template <typename T>
+__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if !defined(USE_ROCM)
+    return c10::complex<T>(
+        __shfl_down_sync(mask, value.real_, delta, width),
+        __shfl_down_sync(mask, value.imag_, delta, width));
+#else
+    return c10::complex<T>(
+        __shfl_down(value.real_, delta, width),
+        __shfl_down(value.imag_, delta, width));
+#endif
+}
+
+/**
+ * For CC 3.5+, perform a load using __ldg
+ */
+template <typename T>
+__device__ __forceinline__ T doLdg(const T* p) {
+#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
+  return __ldg(p);
+#else
+  return *p;
+#endif
+}
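+
+// Illustrative sketch of how the shuffle wrappers above are typically used:
+// a warp-level sum reduction built on WARP_SHFL_DOWN. The helper name
+// `warp_reduce_sum_example` is hypothetical; it relies only on the built-in
+// warpSize variable, so it works for both 32-lane and 64-lane warps.
+template <typename T>
+__device__ __forceinline__ T warp_reduce_sum_example(T val) {
+  // Fold the upper half of the warp onto the lower half, halving the span
+  // each step; lane 0 ends up holding the sum over the whole warp.
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val += WARP_SHFL_DOWN(val, offset);
+  }
+  return val;
+}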
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..54942b88f761a277809e0901e931fcc6d18f950e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/EmptyTensor.h
@@ -0,0 +1,44 @@
+#pragma once
+#include 
+
+namespace at::detail {
+
+TORCH_CUDA_CPP_API TensorBase empty_cuda(
+    IntArrayRef size,
+    ScalarType dtype,
+    c10::optional<Device> device_opt,
+    c10::optional<c10::MemoryFormat> memory_format_opt);
+
+TORCH_CUDA_CPP_API TensorBase empty_cuda(
+    IntArrayRef size,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt,
+    c10::optional<Device> device_opt,
+    c10::optional<bool> pin_memory_opt,
+    c10::optional<c10::MemoryFormat> memory_format_opt);
+
+TORCH_CUDA_CPP_API TensorBase empty_cuda(
+    IntArrayRef size,
+    const TensorOptions &options);
+
+TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    c10::optional<Device> device_opt);
+
+TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt,
+    c10::optional<Device> device_opt,
+    c10::optional<bool> pin_memory_opt);
+
+TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions &options);
+
+
+}  // namespace at::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h
new file mode 100644
index 0000000000000000000000000000000000000000..6f83d217db306ae038cc01023f54017c014af83e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Exceptions.h
@@ -0,0 +1,174 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifdef CUDART_VERSION
+#include 
+#endif
+
+#include 
+#include 
+#include 
+
+
+namespace c10 {
+
+class CuDNNError : public c10::Error {
+  using Error::Error;
+};
+
+}  // namespace c10
+
+#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                                                      \
+  do {                                                                                          \
+    auto error_object = EXPR;                                                                   \
+    if (!error_object.is_good()) {                                                              \
+      TORCH_CHECK_WITH(CuDNNError, false,                                                       \
+            "cuDNN Frontend error: ", error_object.get_message());                              \
+    }                                                                                           \
+  } while (0)                                                                                   \
+
+#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
+
+// See Note [CHECK macro]
+#define AT_CUDNN_CHECK(EXPR, ...)                                                               \
+  do {                                                                                          \
+    cudnnStatus_t status = EXPR;                                                                \
+    if (status != CUDNN_STATUS_SUCCESS) {                                                       \
+      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                                               \
+        TORCH_CHECK_WITH(CuDNNError, false,                                                     \
+            "cuDNN error: ",                                                                    \
+            cudnnGetErrorString(status),                                                        \
+            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
+      } else {                                                                                  \
+        TORCH_CHECK_WITH(CuDNNError, false,                                                     \
+            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);                       \
+      }                                                                                         \
+    }                                                                                           \
+  } while (0)
+
+namespace at::cuda::blas {
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
+} // namespace at::cuda::blas
+
+#define TORCH_CUDABLAS_CHECK(EXPR)                              \
+  do {                                                          \
+    cublasStatus_t __err = EXPR;                                \
+    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                 \
+                "CUDA error: ",                                 \
+                at::cuda::blas::_cublasGetErrorEnum(__err),     \
+                " when calling `" #EXPR "`");                   \
+  } while (0)
+
+const char *cusparseGetErrorString(cusparseStatus_t status);
+
+#define TORCH_CUDASPARSE_CHECK(EXPR)                            \
+  do {                                                          \
+    cusparseStatus_t __err = EXPR;                              \
+    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,               \
+                "CUDA error: ",                                 \
+                cusparseGetErrorString(__err),                  \
+                " when calling `" #EXPR "`");                   \
+  } while (0)
+
+// cusolver related headers are only supported on cuda now
+#ifdef CUDART_VERSION
+
+namespace at::cuda::solver {
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
+
+constexpr const char* _cusolver_backend_suggestion =            \
+  "If you keep seeing this error, you may use "                 \
+  "`torch.backends.cuda.preferred_linalg_library()` to try "    \
+  "linear algebra operators with other supported backends. "    \
+  "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
+
+} // namespace at::cuda::solver
+
+// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
+// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
+#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
+  do {                                                                  \
+    cusolverStatus_t __err = EXPR;                                      \
+    if ((CUDA_VERSION < 11500 &&                                        \
+         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                  \
+        (CUDA_VERSION >= 11500 &&                                       \
+         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                     \
+      TORCH_CHECK_LINALG(                                               \
+          false,                                                        \
+          "cusolver error: ",                                           \
+          at::cuda::solver::cusolverGetErrorMessage(__err),             \
+          ", when calling `" #EXPR "`",                                 \
+          ". This error may appear if the input matrix contains NaN. ", \
+          at::cuda::solver::_cusolver_backend_suggestion);              \
+    } else {                                                            \
+      TORCH_CHECK(                                                      \
+          __err == CUSOLVER_STATUS_SUCCESS,                             \
+          "cusolver error: ",                                           \
+          at::cuda::solver::cusolverGetErrorMessage(__err),             \
+          ", when calling `" #EXPR "`. ",                               \
+          at::cuda::solver::_cusolver_backend_suggestion);              \
+    }                                                                   \
+  } while (0)
+
+#else
+#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
+#endif
+
+#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)
+
+// For CUDA Driver API
+//
+// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
+// in ATen, and we need to use its nvrtcGetErrorString.
+// See NOTE [ USE OF NVRTC AND DRIVER API ].
+#if !defined(USE_ROCM)
+
+#define AT_CUDA_DRIVER_CHECK(EXPR)                                                                               \
+  do {                                                                                                           \
+    CUresult __err = EXPR;                                                                                       \
+    if (__err != CUDA_SUCCESS) {                                                                                 \
+      const char* err_str;                                                                                       \
+      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);  \
+      if (get_error_str_err != CUDA_SUCCESS) {                                                                   \
+        AT_ERROR("CUDA driver error: unknown error");                                                            \
+      } else {                                                                                                   \
+        AT_ERROR("CUDA driver error: ", err_str);                                                                \
+      }                                                                                                          \
+    }                                                                                                            \
+  } while (0)
+
+#else
+
+#define AT_CUDA_DRIVER_CHECK(EXPR)                                                \
+  do {                                                                            \
+    CUresult __err = EXPR;                                                        \
+    if (__err != CUDA_SUCCESS) {                                                  \
+      AT_ERROR("CUDA driver error: ", static_cast(__err));                   \
+    }                                                                             \
+  } while (0)
+
+#endif
+
+// For CUDA NVRTC
+//
+// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
+// incorrectly produces the error string "NVRTC unknown error."
+// The following maps it correctly.
+//
+// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
+// in ATen, and we need to use its nvrtcGetErrorString.
+// See NOTE [ USE OF NVRTC AND DRIVER API ].
+#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
+  do {                                                                                              \
+    nvrtcResult __err = EXPR;                                                                       \
+    if (__err != NVRTC_SUCCESS) {                                                                   \
+      if (static_cast<int>(__err) != 7) {                                                          \
+        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err));  \
+      } else {                                                                                      \
+        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                        \
+      }                                                                                             \
+    }                                                                                               \
+  } while (0)
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d02b41a8157f30aeb4e91fc865ed654598318351
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/NumericLimits.cuh
@@ -0,0 +1,121 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
+// types. This header is very specific to ROCm HIP and may be removed in the future.
+// This header is derived from the legacy THCNumerics.cuh.
+
+// The lower_bound and upper_bound constants are same as lowest and max for
+// integral types, but are -inf and +inf for floating point types. They are
+// useful in implementing min, max, etc.
+
+namespace at {
+
+template <typename T>
+struct numeric_limits {
+};
+
+// WARNING: the following at::numeric_limits definitions are there only to support
+//          HIP compilation for the moment. Use std::numeric_limits if you are not
+//          compiling for ROCm.
+//          from @colesbury: "The functions on numeric_limits aren't marked with
+//          __device__ which is why they don't work with ROCm. CUDA allows them
+//          because they're constexpr."
+
+namespace {
+  // ROCm doesn't like INFINITY too.
+  constexpr double inf = INFINITY;
+}
+
+template <>
+struct numeric_limits<bool> {
+  static inline __host__ __device__ bool lowest() { return false; }
+  static inline __host__ __device__ bool max() { return true; }
+  static inline __host__ __device__ bool lower_bound() { return false; }
+  static inline __host__ __device__ bool upper_bound() { return true; }
+};
+
+template <>
+struct numeric_limits<uint8_t> {
+  static inline __host__ __device__ uint8_t lowest() { return 0; }
+  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
+  static inline __host__ __device__ uint8_t lower_bound() { return 0; }
+  static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int8_t> {
+  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
+  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
+  static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
+  static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int16_t> {
+  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
+  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
+  static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
+  static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
+};
+
+template <>
+struct numeric_limits<int32_t> {
+  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
+  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
+  static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
+  static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
+};
+
+template <>
+struct numeric_limits<int64_t> {
+#ifdef _MSC_VER
+  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
+  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
+  static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
+  static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
+#else
+  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
+  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
+  static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
+  static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
+#endif
+};
+
+template <>
+struct numeric_limits<at::Half> {
+  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
+  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
+  static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
+  static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
+};
+
+template <>
+struct numeric_limits<at::BFloat16> {
+  static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
+  static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
+};
+
+template <>
+struct numeric_limits<float> {
+  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
+  static inline __host__ __device__ float max() { return FLT_MAX; }
+  static inline __host__ __device__ float lower_bound() { return -static_cast(inf); }
+  static inline __host__ __device__ float upper_bound() { return static_cast(inf); }
+};
+
+template <>
+struct numeric_limits<double> {
+  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
+  static inline __host__ __device__ double max() { return DBL_MAX; }
+  static inline __host__ __device__ double lower_bound() { return -inf; }
+  static inline __host__ __device__ double upper_bound() { return inf; }
+};
+
+} // namespace at
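+
+// Illustrative sketch of how the lower_bound/upper_bound pair is typically
+// used: seeding a running minimum with upper_bound() (+inf for floating
+// point, the type's max for integers) guarantees that the first real element
+// replaces the seed. The helper name is hypothetical.
+template <typename T>
+__host__ __device__ inline T running_min_seed_example() {
+  return at::numeric_limits<T>::upper_bound();
+}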
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h
new file mode 100644
index 0000000000000000000000000000000000000000..bad21b18d83c2e8110607ff83153bc568717524c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h
@@ -0,0 +1,11 @@
+#include 
+#include 
+
+namespace at::cuda {
+namespace detail {
+void init_p2p_access_cache(int64_t num_devices);
+}
+
+TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev);
+
+}  // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h
new file mode 100644
index 0000000000000000000000000000000000000000..257ac6bbb896ab2883e7e85011ddee1426f53d15
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxCudaState.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include 
+
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..091dd5e4402b9987edebed96d6d06c3baffa8272
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PhiloxUtils.cuh
@@ -0,0 +1,4 @@
+#pragma once
+
+#include 
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..ff65549139607e2ea3d5378e953010f4ef6040fd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/PinnedMemoryAllocator.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::cuda {
+
+inline TORCH_CUDA_CPP_API at::Allocator* getPinnedMemoryAllocator() {
+  return getCachingHostAllocator();
+}
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..f81f560b4b523f8bc81423183cbbccca9c9d45e2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ScanUtils.cuh
@@ -0,0 +1,78 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+// Collection of in-kernel scan / prefix sum utilities
+
+namespace at::cuda {
+
+// Inclusive prefix sum for binary vars using intra-warp voting +
+// shared memory
+template <typename T, bool KillWARDependency, class BinaryFunction>
+__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
+  // Within-warp, we use warp voting.
+#if defined (USE_ROCM)
+  unsigned long long int vote = WARP_BALLOT(in);
+  T index = __popcll(getLaneMaskLe() & vote);
+  T carry = __popcll(vote);
+#else
+  T vote = WARP_BALLOT(in);
+  T index = __popc(getLaneMaskLe() & vote);
+  T carry = __popc(vote);
+#endif
+
+  int warp = threadIdx.x / C10_WARP_SIZE;
+
+  // Per each warp, write out a value
+  if (getLaneId() == 0) {
+    smem[warp] = carry;
+  }
+
+  __syncthreads();
+
+  // Sum across warps in one thread. This appears to be faster than a
+  // warp shuffle scan for CC 3.0+
+  if (threadIdx.x == 0) {
+    int current = 0;
+    for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
+      T v = smem[i];
+      smem[i] = binop(smem[i], current);
+      current = binop(current, v);
+    }
+  }
+
+  __syncthreads();
+
+  // load the carry from the preceding warp
+  if (warp >= 1) {
+    index = binop(index, smem[warp - 1]);
+  }
+
+  *out = index;
+
+  if (KillWARDependency) {
+    __syncthreads();
+  }
+}
+
+// Exclusive prefix sum for binary vars using intra-warp voting +
+// shared memory
+template <typename T, bool KillWARDependency, class BinaryFunction>
+__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
+  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);
+
+  // Inclusive to exclusive
+  *out -= (T) in;
+
+  // The outgoing carry for all threads is the last warp's sum
+  *carry = smem[at::ceil_div(blockDim.x, C10_WARP_SIZE) - 1];
+
+  if (KillWARDependency) {
+    __syncthreads();
+  }
+}
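+
+// Illustrative usage sketch: block-wide stream compaction, the classic use of
+// the exclusive binary scan above. Each thread with `keep == true` gets its
+// write index from the scan. All names (`smem`, `keep`, `out`, `value`,
+// `block_offset`) are hypothetical.
+//
+//   __shared__ int smem[32];  // one slot per warp in the block
+//   int index, carry;
+//   exclusiveBinaryPrefixScan<int, true>(
+//       smem, keep, &index, &carry,
+//       [] __device__ (int a, int b) { return a + b; });
+//   if (keep) {
+//     out[block_offset + index] = value;
+//   }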
+
+}  // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/Sleep.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Sleep.h
new file mode 100644
index 0000000000000000000000000000000000000000..f14fbb5a8f9720b5f0da97e3d65f63bf041c0a18
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/Sleep.h
@@ -0,0 +1,10 @@
+#pragma once
+#include 
+#include 
+
+namespace at::cuda {
+
+// enqueues a kernel that spins for the specified number of cycles
+TORCH_CUDA_CU_API void sleep(int64_t cycles);
+
+}  // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..e7f56bd455e5a71bef001908cd55f0e40a45f6ad
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/ThrustAllocator.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::cuda {
+
+/// Allocator for Thrust to re-route its internal device allocations
+/// to the THC allocator
+class ThrustAllocator {
+public:
+  typedef char value_type;
+
+  char* allocate(std::ptrdiff_t size) {
+    return static_cast<char*>(c10::cuda::CUDACachingAllocator::raw_alloc(size));
+  }
+
+  void deallocate(char* p, size_t size) {
+    c10::cuda::CUDACachingAllocator::raw_delete(p);
+  }
+};
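+
+// Illustrative usage sketch: Thrust algorithms accept this allocator through
+// their execution policy, so temporary buffers come from the CUDA caching
+// allocator instead of raw cudaMalloc. Assumes the usual Thrust headers
+// (thrust/execution_policy.h, thrust/sort.h) are included at the call site;
+// `d_keys` and `n` are hypothetical device data.
+//
+//   at::cuda::ThrustAllocator allocator;
+//   auto policy = thrust::cuda::par(allocator).on(
+//       at::cuda::getCurrentCUDAStream());
+//   thrust::sort(policy, d_keys, d_keys + n);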
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..954fac05ca331f5a6fd5b89eada0e8572ba1ec77
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.cuh
@@ -0,0 +1,413 @@
+#pragma once
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
+
+#include <cub/cub.cuh>
+
+#else
+
+// include cub in a safe manner, see:
+// https://github.com/pytorch/pytorch/pull/55292
+#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
+#undef CUB_NS_PREFIX
+#undef CUB_NS_QUALIFIER
+#define CUB_NS_PREFIX namespace at_cuda_detail {
+#define CUB_NS_POSTFIX }
+#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
+#include <cub/cub.cuh>
+#undef CUB_NS_POSTFIX
+#undef CUB_NS_PREFIX
+#undef CUB_NS_QUALIFIER
+
+#endif
+
+#include 
+#include 
+#include 
+
+// handle the temporary storage and 'twice' calls for cub API
+#define CUB_WRAPPER(func, ...) do {                                       \
+  size_t temp_storage_bytes = 0;                                          \
+  func(nullptr, temp_storage_bytes, __VA_ARGS__);                         \
+  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();    \
+  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);     \
+  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);              \
+  AT_CUDA_CHECK(cudaGetLastError());                                      \
+} while (false)
+
+#ifdef USE_ROCM
+#define NO_ROCM(x)
+#define ROCM_HIPCUB(x) ::hipcub
+#else
+#define NO_ROCM(x) x
+#define ROCM_HIPCUB(x) x
+#endif
+
+#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || \
+     (defined(USE_ROCM) && ROCM_VERSION >= 40500)
+
+#if !defined(USE_ROCM)
+namespace at_cuda_detail {
+#endif
+
+// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
+
+template <>
+struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
+{
+    static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
+        unsigned short max_word = 0x7F7F;
+        return reinterpret_cast<c10::BFloat16&>(max_word);
+    }
+
+    static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
+        unsigned short lowest_word = 0xFF7F;
+        return reinterpret_cast<c10::BFloat16&>(lowest_word);
+    }
+};
+
+template <>
+struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
+       ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
+
+#if !defined(USE_ROCM)
+} // namespace at_cuda_detail
+#endif
+
+#endif
+
+#if !defined(USE_ROCM)
+namespace at::native {
+namespace cub = ::at_cuda_detail::cub;
+} // namespace at::native
+#endif
+
+namespace at::cuda::cub {
+
+namespace detail {
+
+template<typename T>
+struct cuda_type {
+  using type = T;
+};
+template<>
+struct cuda_type<c10::Half> {
+  using type = __half;
+};
+
+#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
+
+template<>
+struct cuda_type<c10::BFloat16> {
+  using type = __nv_bfloat16;
+};
+
+#elif (defined(USE_ROCM) && ROCM_VERSION >= 40500)
+
+template<>
+struct cuda_type<c10::BFloat16> {
+  using type = hip_bfloat16;
+};
+
+#endif
+
+}  // namespace detail
+
+template<typename key_t, typename value_t, typename OffsetIteratorT>
+inline void segmented_sort_pairs(
+    const key_t *keys_in, key_t *keys_out,
+    const value_t *values_in, value_t *values_out,
+    int64_t num_elements, int64_t num_segments,
+    OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
+    bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
+) {
+  TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
+  TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
+    "cub sort does not support sorting more than INT_MAX elements");
+  using key_t_ = typename detail::cuda_type<key_t>::type;
+
+  auto allocator = c10::cuda::CUDACachingAllocator::get();
+  c10::DataPtr keys_out_owner;
+
+  if (keys_out == nullptr) {
+    keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
+    keys_out = reinterpret_cast<key_t*>(keys_out_owner.get());
+  }
+
+  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
+  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
+
+  if (descending) {
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
+      keys_in_, keys_out_, values_in, values_out,
+      num_elements, num_segments, begin_offsets, end_offsets,
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+  } else {
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
+      keys_in_, keys_out_, values_in, values_out,
+      num_elements, num_segments, begin_offsets, end_offsets,
+      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
+  }
+}
+
+#if CUB_SUPPORTS_UNIQUE_BY_KEY()
+template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename KeysOutputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
+inline void unique_by_key(
+  KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
+  KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out,
+  NumSelectedIteratorT num_selected, int64_t num_input_items)
+{
+  // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
+  constexpr bool null_keys_out = std::is_same<KeysOutputIteratorT, std::nullptr_t>::value;
+  using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
+  using RealKeysOutputIteratorT = typename std::conditional<null_keys_out, KeyT*, KeysOutputIteratorT>::type;
+  RealKeysOutputIteratorT keys_out_;
+  auto allocator = c10::cuda::CUDACachingAllocator::get();
+  c10::DataPtr keys_out_owner;
+  if constexpr (null_keys_out) {
+    keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
+    keys_out_ = static_cast<KeyT*>(keys_out_owner.get());
+  } else {
+    keys_out_ = keys_out;
+  }
+  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
+    keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
+}
+#endif
+
+namespace impl {
+
+template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, typename ScanOpT>
+C10_LAUNCH_BOUNDS_1(1)
+__global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
+  // NOTE: out here not the final scan output, but an intermediate of the accumulation type.
+  using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
+  *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
+}
+
+#if !CUB_SUPPORTS_FUTURE_VALUE()
+template<typename ValueT, typename InputIteratorT>
+struct chained_iterator {
+  using iterator_category = std::random_access_iterator_tag;
+  using difference_type   = std::ptrdiff_t;
+  using value_type        = ValueT;
+  using pointer           = ValueT*;
+  using reference         = ValueT&;
+
+  InputIteratorT iter;
+  ValueT *first;
+  difference_type offset = 0;
+
+  __device__ ValueT operator[](difference_type i) {
+    i +=  offset;
+    if (i == 0) {
+      return *first;
+    } else {
+      return ValueT(iter[i - 1]);
+    }
+  }
+  __device__ chained_iterator operator+(difference_type i) {
+    return chained_iterator{iter, first, i};
+  }
+  __device__ ValueT operator*() {
+    return (*this)[0];
+  }
+};
+#endif
+
+// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
+// so split at int_max/2
+constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
+}
+
+// non synchronizing cub call
+// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
+// so split at int_max/2
+template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT>
+inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
+#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
+  //For ROCm, use hipCUB chained iterators
+  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
+      input,
+      output,
+      scan_op,
+      num_items,
+      at::cuda::getCurrentCUDAStream());
+  C10_HIP_KERNEL_LAUNCH_CHECK();
+#else
+  // non synchronizing cub call
+  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
+  // so split at int_max/2
+  int size_cub = std::min<int64_t>(num_items, max_cub_size);
+  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
+      input,
+      output,
+      scan_op,
+      size_cub,
+      at::cuda::getCurrentCUDAStream());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
+  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
+    auto allocator = c10::cuda::CUDACachingAllocator::get();
+    c10::DataPtr first_elem = allocator->allocate(sizeof(input_t));
+    auto first_elem_ptr = reinterpret_cast<input_t*>(first_elem.get());
+
+    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
+    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
+        output + i - 1,
+        input + i,
+        first_elem_ptr,
+        scan_op);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+#if !CUB_SUPPORTS_FUTURE_VALUE()
+    using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
+    using tuple = typename ArgIndexInputIterator::value_type;
+    auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  {
+      if (x.key == 0) {
+        return *first_elem_ptr;
+      } else {
+        return x.value;
+      }
+    };
+    auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
+      ArgIndexInputIterator(input + i), input_iter_transform);
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
+        input_,
+        output + i,
+        scan_op,
+        size_cub,
+        at::cuda::getCurrentCUDAStream());
+#else
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
+        input + i + 1,
+        output + i,
+        scan_op,
+        ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
+        size_cub,
+        at::cuda::getCurrentCUDAStream());
+#endif
+  }
+#endif
+}
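+
+// Illustrative usage sketch: `d_in`, `d_out`, and `n` are hypothetical device
+// pointers/sizes. The scan op must be device-callable, e.g. cub's Sum functor
+// on CUDA builds; the wrapper takes care of temporary storage and of splitting
+// inputs larger than INT_MAX/2.
+//
+//   at::cuda::cub::inclusive_scan(
+//       d_in, d_out, NO_ROCM(at_cuda_detail)::cub::Sum(), n);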
+
+template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT>
+inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
+#if defined(USE_ROCM) && (ROCM_VERSION >= 50000)
+  //For ROCm, use hipCUB chained iterators
+  CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
+      input,
+      output,
+      scan_op,
+      init_value,
+      num_items,
+      at::cuda::getCurrentCUDAStream());
+  C10_HIP_KERNEL_LAUNCH_CHECK();
+#else
+  // non synchronizing cub call
+  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
+  // so split at int_max/2
+  int size_cub = std::min<int64_t>(num_items, max_cub_size);
+  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
+      input,
+      output,
+      scan_op,
+      init_value,
+      size_cub,
+      at::cuda::getCurrentCUDAStream());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  for (int64_t i = max_cub_size; i < num_items; i += max_cub_size) {
+    auto allocator = c10::cuda::CUDACachingAllocator::get();
+    c10::DataPtr first_elem = allocator->allocate(sizeof(InitValueT));
+    auto first_elem_ptr = reinterpret_cast<InitValueT*>(first_elem.get());
+
+    size_cub = std::min<int64_t>(num_items - i, max_cub_size);
+    impl::transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
+        output + i - 1,
+        input + i - 1,
+        first_elem_ptr,
+        scan_op);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+#if !CUB_SUPPORTS_FUTURE_VALUE()
+    auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
+      input + i, first_elem_ptr};
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
+        input_,
+        output + i,
+        scan_op,
+        size_cub,
+        at::cuda::getCurrentCUDAStream());
+#else
+    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
+        input + i,
+        output + i,
+        scan_op,
+        ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
+        size_cub,
+        at::cuda::getCurrentCUDAStream());
+#endif
+  }
+#endif
+}
+
+#if CUB_SUPPORTS_SCAN_BY_KEY()
+
+template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
+inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+    "cub InclusiveSumByKey does not support more than INT_MAX elements");
+  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
+      keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
+}
+
+template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
+inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+    "cub InclusiveSumByKey does not support more than INT_MAX elements");
+  CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
+      keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
+}
+
+#endif
+
+template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
+void unique(InputIteratorT input, OutputIteratorT output,
+            NumSelectedIteratorT num_selected_out, int64_t num_items) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+              "cub unique does not support more than INT_MAX elements");
+  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
+              input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
+}
+
+template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT, typename LengthOutputIteratorT>
+void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
+                       LengthOutputIteratorT length_out, int64_t num_items) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+              "cub run_length_encode does not support more than INT_MAX elements");
+  CUB_WRAPPER(
+      NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
+      input, output, counts_out, length_out, num_items,
+      at::cuda::getCurrentCUDAStream());
+}
+
+template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
+void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
+  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
+              "cub reduce does not support more than INT_MAX elements");
+  CUB_WRAPPER(
+      NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
+      input, output, num_items, op, init,
+      at::cuda::getCurrentCUDAStream());
+
+}
+
+}  // namespace at::cuda::cub
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.h
new file mode 100644
index 0000000000000000000000000000000000000000..37e9867f39be900c5d9a0a1e525cb94676dc134b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub.h
@@ -0,0 +1,87 @@
+#pragma once
+#include 
+#include 
+#include 
+
+// NOTE: These templates are intentionally not defined in this header,
+// which avoids re-compiling them for each translation unit. If you get
+// a link error, you need to add an explicit instantiation for your
+// types in cub.cu
+
+namespace at::cuda::cub {
+
+inline int get_num_bits(uint64_t max_key) {
+  int num_bits = 1;
+  while (max_key > 1) {
+    max_key >>= 1;
+    num_bits++;
+  }
+  return num_bits;
+}
+
+namespace detail {
+
+// radix_sort_pairs doesn't interact with value_t other than to copy
+// the data, so we can save template instantiations by reinterpreting
+// it as an opaque type.
+template <int N> struct alignas(N) OpaqueType { char data[N]; };
+
+template<typename key_t, int value_size>
+void radix_sort_pairs_impl(
+    const key_t *keys_in, key_t *keys_out,
+    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
+    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);
+
+}  // namespace detail
+
+template<typename key_t, typename value_t>
+void radix_sort_pairs(
+    const key_t *keys_in, key_t *keys_out,
+    const value_t *values_in, value_t *values_out,
+    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
+  static_assert(std::is_trivially_copyable<value_t>::value ||
+                AT_ROCM_ENABLED(),  // ROCm incorrectly fails this check for vector types
+                "radix_sort_pairs value type must be trivially copyable");
+  // Make value type opaque, so all inputs of a certain size use the same template instantiation
+  using opaque_t = detail::OpaqueType<sizeof(value_t)>;
+  static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
+                "This size of value_t is not instantiated. Please instantiate it in cub.cu"
+                " and modify this check.");
+  static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
+  detail::radix_sort_pairs_impl(
+      keys_in, keys_out,
+      reinterpret_cast<const opaque_t*>(values_in),
+      reinterpret_cast<opaque_t*>(values_out),
+      n, descending, begin_bit, end_bit);
+}
+
+template<typename key_t>
+void radix_sort_keys(
+    const key_t *keys_in, key_t *keys_out,
+    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);
+
+// NOTE: Intermediate sums will be truncated to input_t precision
+template <typename input_t, typename output_t>
+void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
+
+template <typename scalar_t>
+void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
+  return inclusive_sum_truncating(input, output, n);
+}
+
+// NOTE: Sums are done in common_type<input_t, output_t>
+template <typename input_t, typename output_t>
+void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
+
+template <typename scalar_t>
+void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
+  return exclusive_sum_in_common_type(input, output, n);
+}
+
+void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
+inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
+  return mask_exclusive_sum(
+      reinterpret_cast<const uint8_t*>(mask), output_idx, n);
+}
+
+}  // namespace at::cuda::cub
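
// A standalone sketch of the OpaqueType trick radix_sort_pairs uses above: the sort
// only copies values, so every trivially copyable value type of a given size can be
// funneled through one instantiation that moves opaque, size-aligned byte blobs.
// A host std::stable_sort stands in for the device radix sort, and all names here
// are illustrative; the real header dispatches to radix_sort_pairs_impl in cub.cu.
#include <algorithm>
#include <cstdint>
#include <type_traits>
#include <vector>

template <int N> struct alignas(N) Opaque { char data[N]; };

template <typename key_t, int value_size>
void sort_pairs_impl(key_t* keys, Opaque<value_size>* values, int64_t n) {
  // Reorder keys and the opaque payloads together by ascending key.
  std::vector<int64_t> idx(n);
  for (int64_t i = 0; i < n; ++i) idx[i] = i;
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int64_t a, int64_t b) { return keys[a] < keys[b]; });
  std::vector<key_t> k(keys, keys + n);
  std::vector<Opaque<value_size>> v(values, values + n);
  for (int64_t i = 0; i < n; ++i) { keys[i] = k[idx[i]]; values[i] = v[idx[i]]; }
}

template <typename key_t, typename value_t>
void sort_pairs(key_t* keys, value_t* values, int64_t n) {
  static_assert(std::is_trivially_copyable<value_t>::value, "values are only copied");
  using opaque_t = Opaque<sizeof(value_t)>;
  // One sort_pairs_impl instantiation per (key type, value size), not per value type.
  sort_pairs_impl<key_t, sizeof(value_t)>(keys, reinterpret_cast<opaque_t*>(values), n);
}
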
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..a88086ae6d6a9b9e9dcd7a69822ef30d58481925
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/cub_definitions.cuh
@@ -0,0 +1,53 @@
+#pragma once
+
+#if !defined(USE_ROCM)
+#include <cuda.h>  // for CUDA_VERSION
+#endif
+
+#if !defined(USE_ROCM)
+#include <cub/version.cuh>
+#else
+#define CUB_VERSION 0
+#endif
+
+// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
+// https://github.com/NVIDIA/cub/pull/306
+#if CUB_VERSION >= 101300
+#define CUB_SUPPORTS_NV_BFLOAT16() true
+#else
+#define CUB_SUPPORTS_NV_BFLOAT16() false
+#endif
+
+// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
+// https://github.com/NVIDIA/cub/pull/326
+// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
+// starting from CUDA 11.5
+#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
+#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
+#else
+#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
+#endif
+
+// cub support for UniqueByKey is added to cub 1.16 in:
+// https://github.com/NVIDIA/cub/pull/405
+#if CUB_VERSION >= 101600
+#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
+#else
+#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
+#endif
+
+// cub support for scan by key is added to cub 1.15
+// in https://github.com/NVIDIA/cub/pull/376
+#if CUB_VERSION >= 101500
+#define CUB_SUPPORTS_SCAN_BY_KEY() 1
+#else
+#define CUB_SUPPORTS_SCAN_BY_KEY() 0
+#endif
+
+// cub support for cub::FutureValue is added to cub 1.15 in:
+// https://github.com/NVIDIA/cub/pull/305
+#if CUB_VERSION >= 101500
+#define CUB_SUPPORTS_FUTURE_VALUE() true
+#else
+#define CUB_SUPPORTS_FUTURE_VALUE() false
+#endif
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..904d333f72709afe077c015670f2a932c29c9882
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+// TODO: No need to have this whole header, we can just put it all in
+// the cpp file
+
+namespace at::cuda::detail {
+
+// Set the callback to initialize Magma, which is set by
+// torch_cuda_cu. This indirection is required so magma_init is called
+// in the same library where Magma will be used.
+TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
+
+
+// The real implementation of CUDAHooksInterface
+struct CUDAHooks : public at::CUDAHooksInterface {
+  CUDAHooks(at::CUDAHooksArgs) {}
+  void initCUDA() const override;
+  Device getDeviceFromPtr(void* data) const override;
+  bool isPinnedPtr(const void* data) const override;
+  const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
+  bool hasCUDA() const override;
+  bool hasMAGMA() const override;
+  bool hasCuDNN() const override;
+  bool hasCuSOLVER() const override;
+  bool hasROCM() const override;
+  const at::cuda::NVRTC& nvrtc() const override;
+  DeviceIndex current_device() const override;
+  bool hasPrimaryContext(DeviceIndex device_index) const override;
+  Allocator* getCUDADeviceAllocator() const override;
+  Allocator* getPinnedMemoryAllocator() const override;
+  bool compiledWithCuDNN() const override;
+  bool compiledWithMIOpen() const override;
+  bool supportsDilatedConvolutionWithCuDNN() const override;
+  bool supportsDepthwiseConvolutionWithCuDNN() const override;
+  bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
+  bool hasCUDART() const override;
+  long versionCUDART() const override;
+  long versionCuDNN() const override;
+  std::string showConfig() const override;
+  double batchnormMinEpsilonCuDNN() const override;
+  int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
+  void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
+  int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
+  void cuFFTClearPlanCache(DeviceIndex device_index) const override;
+  int getNumGPUs() const override;
+  void deviceSynchronize(DeviceIndex device_index) const override;
+};
+
+} // at::cuda::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
new file mode 100644
index 0000000000000000000000000000000000000000..e17eed4b63a1c84f7f873e813ab9ed8bdf849472
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
@@ -0,0 +1,151 @@
+// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
+// These handles are tied to a device, and these libraries require (or recommend) not
+// sharing handles across host threads.
+//
+// These libraries recommend using one handle per host thread. We may not want to do
+// this because threads are relatively light-weight, but creating and destroying
+// handles is expensive (destroying the handle causes synchronizations). DataParallel,
+// for example, creates new threads for each forward pass.
+//
+// This file implements a handle pool mechanism. The handle pool returns handles on
+// demand as threads request them. If all existing handles in the pool are in use,
+// it creates a new one. As threads terminate, they release handles back into the pool.
+// In this way, the handle pool never creates more handles than the high-water mark of
+// active threads, so it's efficient with DataParallel.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at::cuda { namespace {
+
+template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
+struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
+
+    struct Handle {
+    Handle_t handle;
+    Handle(bool create = false) : handle(nullptr)
+    {
+        if(create) Create(&handle);
+    }
+    // std::vector.emplace() and push_back() may route through temporaries and call
+    // copy/move constructors along the way.  If this is the case, we don't want
+    // the destructors of temporaries to call cudnnDestroy on the handle.
+    // We can achieve safety (for the narrow case of stashing within std::vectors)
+    // by making Handle moveable but not copyable, and transferring handle ownership
+    // to the latest constructed object.  This is not a substitute for full-blown
+    // reference counting, but reference counting may be overkill here.
+    // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
+    // unordered_map>> created_handles;
+    Handle(const Handle& rhs) = delete;
+    // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
+    Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
+    // operator= takes argument by value
+    Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
+    ~Handle() {
+        if(handle) Destroy(handle);
+    }
+    };
+
+    std::mutex mutex;
+
+    // Handles are lazily created as different threads request them,
+    // but are never destroyed until the end of the process.
+    // The maximum number of handles this process will create for each device is equal
+    // to the high-water mark of the number of concurrently active threads that request
+    // handles for that device.
+    // When threads terminate, they release their handles back into the pool for reuse.
+    // Otherwise, new handles would be created every time new threads were spawned,
+    // resulting in poor performance for Python modules that repeatedly or frequently
+    // spawned new sets of threads (like DataParallel, which creates a new set of threads
+    // for each forward pass).
+    //
+    // To prevent potential deadlocks, we explicitly choose not to cap the number
+    // of handles that are created per device.
+    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
+    // only 4 can make forward progress at any time. The 4 threads holding handles will not
+    // release them until they exit, so the fifth cannot make progress until then.  This is
+    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
+    // intermediate point (ie, before any of them have exited).  We have no way to anticipate
+    // or enforce that user threads will not attempt such intermediate synchronization.
+    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
+    std::unordered_map<int, std::vector<Handle>> created_handles;
+    std::unordered_map<int, std::vector<Handle_t>> available_handles;
+
+    // PoolWindow lazily creates and caches the handles that a particular thread is using,
+    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
+    class PoolWindow
+    {
+    public:
+    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
+    ~PoolWindow(){ release(); }
+
+    Handle_t reserve(int device)
+    {
+        // If this thread already has a handle for this device, return it
+        if(my_handles.find(device) != my_handles.end())
+        return my_handles[device];
+
+        // otherwise, either grab a handle from the pool if one is available,
+        // or if not, create a new one.
+        auto parent = weak_parent.lock();
+        TORCH_CHECK(parent, "Cannot create handle during program termination");
+        std::lock_guard<std::mutex> guard(parent->mutex);
+
+        if(parent->available_handles[device].size() > 0)
+        {
+        my_handles[device] = parent->available_handles[device].back();
+        parent->available_handles[device].pop_back();
+        }
+        else
+        {
+        // In local testing, I do observe that emplace_back sometimes routes through temporaries
+        // that incur move-constructor and destructor calls.  See comments in Handle above.
+        parent->created_handles[device].emplace_back(true /*create*/);
+        my_handles[device] = parent->created_handles[device].back().handle;
+        }
+
+        return my_handles[device];
+    }
+
+    private:
+    // Stores the per-device handles currently owned by this thread
+    std::unordered_map<int, Handle_t> my_handles;
+
+    std::weak_ptr<DeviceThreadHandlePool> weak_parent;
+
+    // Called by the destructor.  Releases this thread's handles back into the pool.
+    void release() {
+        if(my_handles.size() > 0) {
+            auto parent = weak_parent.lock();
+            if (!parent) {
+                // If this thread exits after atexit handlers have completed, the
+                // cuda context itself may be invalid, so we must leak the handles.
+                return;
+            }
+
+            std::lock_guard<std::mutex> guard(parent->mutex);
+            for(auto d_h : my_handles)
+                parent->available_handles[d_h.first].push_back(d_h.second);
+        }
+    }
+    };
+
+    // Warning:
+    // If you want to change this function, be aware that this function will be called
+    // by multiple threads and there is no mutex guarding the call of this function, so
+    // make sure your implementation is thread-safe.
+    PoolWindow *newPoolWindow() {
+        // The returned pointer will be owned by a thread local variable
+        // so that different threads do not share the same PoolWindow.
+        return new PoolWindow(this->shared_from_this());
+    }
+};
+
+}}  // namespace at::cuda::detail::
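
// A minimal sketch of the client pattern this pool is written for (mirroring how the
// cuBLAS/cuDNN handle pools use it): one process-wide pool behind a shared_ptr and a
// thread_local PoolWindow per thread, so each thread reuses one handle per device.
// DummyHandle and its create/destroy functions are hypothetical stand-ins for a real
// library handle such as cublasHandle_t; assumes this header and c10 are on the path.
#include <ATen/cuda/detail/DeviceThreadHandles.h>
#include <memory>

namespace at::cuda {
namespace {

struct DummyHandleImpl {};
using DummyHandle = DummyHandleImpl*;
void createDummy(DummyHandle* h) { *h = new DummyHandleImpl(); }
void destroyDummy(DummyHandle h) { delete h; }

using DummyPool = DeviceThreadHandlePool<DummyHandle, createDummy, destroyDummy>;

DummyHandle getCurrentDummyHandle(int device) {
  // One pool for the whole process; each thread lazily creates its own window.
  static auto pool = std::make_shared<DummyPool>();
  thread_local std::unique_ptr<DummyPool::PoolWindow> window(pool->newPoolWindow());
  return window->reserve(device);  // cached per (thread, device) after the first call
}

}  // namespace
}  // namespace at::cuda
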
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..a1994fee2ae3f0a0f984f4e4ec60597c1af302ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh
@@ -0,0 +1,36 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::cuda::detail {
+
+TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
+using at::native::canUse32BitIndexMath;
+
+template <typename scalar, typename IndexType>
+TensorInfo<scalar, IndexType>
+getTensorInfo(const at::TensorBase &t) {
+  IndexType sz[MAX_TENSORINFO_DIMS];
+  IndexType st[MAX_TENSORINFO_DIMS];
+
+  int dims = t.dim();
+  for (int i = 0; i < dims; ++i) {
+    sz[i] = t.size(i);
+    st[i] = t.stride(i);
+  }
+
+  scalar* data_ptr = nullptr;
+
+  if constexpr (std::is_const<scalar>::value) {
+    data_ptr = t.const_data_ptr();
+  } else {
+    data_ptr = t.mutable_data_ptr();
+  }
+
+  return TensorInfo<scalar, IndexType>(
+    data_ptr, dims, sz, st);
+}
+
+} // namespace at::cuda::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..7532aed5fee08a22c88135169634d206ab3c8982
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh
@@ -0,0 +1,124 @@
+#pragma once
+
+#include 
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+#include 
+#endif
+
+namespace at::cuda::detail {
+
+// A utility class to implement integer division by multiplication, given a fixed
+// divisor.
+//
+// WARNING: The fast divider algorithm is only implemented for unsigned int;
+//          otherwise we default to plain integer division.  For unsigned int,
+//          we further assume that the dividend is at most INT32_MAX.  Thus,
+//          IntDivider must NOT be used for general integer division.
+//
+//          This reduced range is enough for our purpose, and it allows us to
+//          slightly simplify the computation.
+//
+// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1< 0), we can find a "magic number" m (2^N
+// <= m < 2^(N+1)) and shift s such that:
+//
+//    \floor(n / d) = \floor((m * n) / 2^(N+s)).
+//
+// Given such m and s, the integer division can be then implemented as:
+//
+//    let m' = m - 2^N  // 0 <= m' < 2^N
+//
+//    fast_integer_division(n):
+//      // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
+//      // integer.  Then take the higher N bits.
+//      t = (m' * n) >> N
+//
+//      // Here we use the fact that n is less than 2^(N-1): otherwise the value
+//      // of (t + n) may not fit in an N-bit integer.
+//      return (t + n) >> s
+//
+// Finding such a magic number is surprisingly easy:
+//
+//    s  = \ceil(\log_2 d)
+//    m' = \floor(2^N * (2^s - d) / d) + 1  // Need 2N-bit integer arithmetic.
+//
+// See also:
+//    - Division by Invariant Integers Using Multiplication,
+//      Torbjörn Granlund and Peter L. Montgomery, 1994.
+//
+//    - http://www.hackersdelight.org/magic.htm
+//
+//    - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
+
+// Result of div/mod operation stored together.
+template <typename Value>
+struct DivMod {
+  Value div, mod;
+
+  C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
+};
+
+// Base case: we only have an implementation for uint32_t for now.  For
+// everything else, we use plain division.
+template <typename Value>
+struct IntDivider {
+  IntDivider() = default;
+  IntDivider(Value d) : divisor(d) { }
+
+  C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
+  C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
+  C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
+    return DivMod<Value>(n / divisor, n % divisor);
+  }
+
+  Value divisor;
+};
+
+// Implement fast integer division.
+template <>
+struct IntDivider<unsigned int> {
+  static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
+
+  IntDivider() = default;
+
+  IntDivider(unsigned int d) : divisor(d) {
+    assert(divisor >= 1 && divisor <= INT32_MAX);
+
+    // TODO: gcc/clang has __builtin_clz() but it's not portable.
+    for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
+
+    uint64_t one = 1;
+    uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
+    m1 = magic;
+    assert(m1 > 0 && m1 == magic);  // m1 must fit in 32 bits.
+  }
+
+  C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+    // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
+    // 'm1'.
+    unsigned int t = __umulhi(n, m1);
+    return (t + n) >> shift;
+#else
+    // Using uint64_t so that the addition does not overflow.
+    uint64_t t = ((uint64_t) n * m1) >> 32;
+    return (t + n) >> shift;
+#endif
+  }
+
+  C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
+    return n - div(n) * divisor;
+  }
+
+  C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
+    unsigned int q = div(n);
+    return DivMod<unsigned int>(q, n - q * divisor);
+  }
+
+  unsigned int divisor;  // d above.
+  unsigned int m1;  // Magic number: m' above.
+  unsigned int shift;  // Shift amount.
+};
+
+}  // namespace at::cuda::detail
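
// A small host-only check of the magic-number construction documented above, using
// the same formulas (s = ceil(log2 d), m' = floor(2^32 * (2^s - d) / d) + 1) and the
// 64-bit fallback path of div(). Purely illustrative; divisors and dividends below
// are arbitrary, with dividends kept <= INT32_MAX as the class requires.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t d : {1u, 2u, 3u, 7u, 10u, 255u, 1000000u}) {
    unsigned int shift = 0;
    while ((1ULL << shift) < d) ++shift;             // s = ceil(log2 d)
    uint64_t one = 1;
    uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
    uint32_t m1 = static_cast<uint32_t>(magic);      // m' must fit in 32 bits

    for (uint32_t n : {0u, 1u, d - 1, d, d + 1, 123456u, 0x7fffffffu}) {
      uint64_t t = (static_cast<uint64_t>(n) * m1) >> 32;
      uint32_t q = static_cast<uint32_t>((t + n) >> shift);
      assert(q == n / d);                            // agrees with plain division
    }
  }
  return 0;
}
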
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab85887311ebb1215e9ad670209608364cd77aee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::cuda::detail {
+
+// CUDA: grid stride looping
+//
+// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
+// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
+// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
+// greater than INT_MAX.  But in that case _i_n_d_e_x >= n, so there are no
+// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
+#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type)                         \
+  int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x;           \
+  for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
+
+#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
+
+
+// Use 1024 threads per block, which requires cuda sm_2x or above
+constexpr int CUDA_NUM_THREADS = 1024;
+
+// CUDA: number of blocks for threads.
+inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
+  TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
+  constexpr int64_t max_int = std::numeric_limits<int>::max();
+
+  // Round up division for positive number that cannot cause integer overflow
+  auto block_num = (N - 1) / max_threads_per_block + 1;
+  TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
+
+  return static_cast<int>(block_num);
+}
+
+}  // namespace at::cuda::detail
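
// A host-side sketch of how GET_BLOCKS and CUDA_KERNEL_LOOP fit together: GET_BLOCKS
// is a round-up division of N by the block size, and the grid-stride loop lets that
// fixed grid cover all N elements by stepping each thread forward by
// blockDim.x * gridDim.x. The launch is only emulated with nested loops here; the
// sizes are arbitrary examples.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t N = 10000;
  const int64_t threads = 1024;                  // CUDA_NUM_THREADS
  const int64_t blocks = (N - 1) / threads + 1;  // GET_BLOCKS(N)

  std::vector<int> hits(N, 0);
  for (int64_t block = 0; block < blocks; ++block) {
    for (int64_t thread = 0; thread < threads; ++thread) {
      // Body of CUDA_KERNEL_LOOP(i, N) for one (blockIdx.x, threadIdx.x) pair.
      for (int64_t i = block * threads + thread; i < N; i += blocks * threads) {
        hits[i] += 1;
      }
    }
  }
  for (int64_t i = 0; i < N; ++i) assert(hits[i] == 1);  // every element visited once
  return 0;
}
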
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h
new file mode 100644
index 0000000000000000000000000000000000000000..23821c88e964ea499df1479a0c369228ba854738
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h
@@ -0,0 +1,11 @@
+#pragma once
+#include 
+namespace at::cuda {
+// Forward-declares at::cuda::NVRTC
+struct NVRTC;
+
+namespace detail {
+extern NVRTC lazyNVRTC;
+} // namespace detail
+
+}  // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..4b11f1fa64be6651e4d208618d3b4a40a1a1c2fb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh
@@ -0,0 +1,119 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// OffsetCalculator computes per-operand offsets for operands that share the same
+// shape but may have different strides; it iterates the tensor in column-major order.
+// If element_sizes is nullptr, the strides are in bytes, otherwise they are in
+// # of elements.
+
+#if defined(USE_ROCM)
+constexpr int MAX_DIMS = 16;
+#else
+constexpr int MAX_DIMS = 25;
+#endif
+
+template 
+struct OffsetCalculator {
+  // We allow having negative strides to implement some operations like torch.flip
+  using stride_t = std::conditional_t,
+                                      index_t>;
+  // The offset for each argument. Wrapper around fixed-size array.
+  // On CUDA, zero sized array is not allowed, so when we are handling nullary
+  // operators, we need to create a size 1 offset to avoid compiler failure.
+  // This size 1 offset is just a placeholder, and we will not use it.
+  using offset_type = at::detail::Array(NARGS, 1)>;
+
+  // if element_sizes is nullptr, then the strides will be in bytes, otherwise
+  // the strides will be in # of elements.
+  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
+    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
+    for (int i=0; i < dims; i++){
+      sizes_[i] = at::cuda::detail::IntDivider(sizes[i]);
+      for (int arg = 0; arg < NARGS; arg++) {
+        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
+        strides_[i][arg] = strides[arg][i] / element_size;
+      }
+    }
+  }
+
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
+    offset_type offsets;
+    #pragma unroll
+    for (int arg = 0; arg < NARGS; arg++) {
+      offsets[arg] = 0;
+    }
+
+    #pragma unroll
+    for (int dim = 0; dim < MAX_DIMS; ++dim) {
+      if (dim == dims) {
+        break;
+      }
+      auto divmod = sizes_[dim].divmod(linear_idx);
+      linear_idx = divmod.div;
+
+      #pragma unroll
+      for (int arg = 0; arg < NARGS; arg++) {
+        offsets[arg] += divmod.mod * strides_[dim][arg];
+      }
+
+    }
+    return offsets;
+  }
+
+  int dims;
+  at::cuda::detail::IntDivider sizes_[MAX_DIMS];
+  stride_t strides_[MAX_DIMS][std::max(NARGS, 1)];
+};
+
+template 
+struct TrivialOffsetCalculator {
+  // The offset for each argument. Wrapper around fixed-size array.
+  // The offsets are in # of elements, not in bytes.
+  // On CUDA, zero sized array is not allowed, so when we are handling nullary
+  // operators, we need to create a size 1 offset to avoid compiler failure.
+  // This size 1 offset is just a placeholder, and we will not use it.
+  using offset_type = at::detail::Array(NARGS, 1)>;
+
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
+    offset_type offsets;
+    #pragma unroll
+    for (int arg = 0; arg < NARGS; arg++) {
+      offsets[arg] = linear_idx;
+    }
+    return offsets;
+  }
+};
+
+// Make an OffsetCalculator with byte offsets
+template
+static OffsetCalculator make_offset_calculator(const at::TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
+  std::array strides;
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i).data();
+  }
+  return OffsetCalculator(iter.ndim(), iter.shape().data(), strides.data());
+}
+
+// Make an OffsetCalculator with element offsets
+template
+static OffsetCalculator make_element_offset_calculator(
+    const at::TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
+  std::array strides;
+  std::array element_sizes;
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i).data();
+    element_sizes[i] = iter.element_size(i);
+  }
+  return OffsetCalculator(
+      iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
+}
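
// A host-side sketch of the offset computation OffsetCalculator::get performs: the
// linear index is peeled one dimension at a time with div/mod (column-major, i.e.
// innermost dimension first) and each remainder is scaled by that dimension's
// stride. Plain % and / stand in for the IntDivider fast path, and strides here are
// in elements; the shapes below are arbitrary examples.
#include <cassert>
#include <cstdint>
#include <vector>

int64_t offset_of(int64_t linear_idx,
                  const std::vector<int64_t>& sizes,
                  const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (size_t dim = 0; dim < sizes.size(); ++dim) {
    int64_t mod = linear_idx % sizes[dim];
    linear_idx /= sizes[dim];
    offset += mod * strides[dim];
  }
  return offset;
}

int main() {
  // A contiguous 2x3 row-major operand as TensorIterator reports it, innermost first:
  // sizes {3, 2}, element strides {1, 3}; a contiguous layout maps linearly.
  std::vector<int64_t> sizes{3, 2}, strides{1, 3};
  for (int64_t i = 0; i < 6; ++i) {
    assert(offset_of(i, sizes, strides) == i);
  }
  return 0;
}
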
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..94a69cacc552aaaccbd879497a47e1c8c7cf65c8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
@@ -0,0 +1,43 @@
+// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
+// Eager mode clients should not include this file directly, instead,
+// they should #include , which has a #pragma once.
+
+// Stores RNG state values. Passed as a kernel argument.
+// See Note [CUDA Graph-safe RNG states].
+//
+// The raw definition lives in its own file so jit codegen can easily copy it.
+namespace at {
+
+struct PhiloxCudaState {
+  PhiloxCudaState() = default;
+  // Called if graph capture is not underway
+  PhiloxCudaState(uint64_t seed,
+                  uint64_t offset) {
+    seed_.val = seed;
+    offset_.val = offset;
+  }
+  // Called if graph capture is underway
+  PhiloxCudaState(int64_t* seed,
+                  int64_t* offset_extragraph,
+                  uint32_t offset_intragraph) {
+    seed_.ptr = seed;
+    offset_.ptr = offset_extragraph;
+    offset_intragraph_ = offset_intragraph;
+    captured_ = true;
+  }
+
+  // Public members, directly accessible by at::cuda::philox::unpack.
+  // If we made them private with getters/setters, the getters/setters
+  // would have to be __device__, and we can't declare __device__ in ATen.
+  union Payload {
+    uint64_t val;
+    int64_t* ptr;
+  };
+
+  Payload seed_;
+  Payload offset_;
+  uint32_t offset_intragraph_ = 0;
+  bool captured_ = false;
+};
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..dec8f789c7358c3f487c1104007a0e7318829a0c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh
@@ -0,0 +1,116 @@
+#pragma once
+
+#include 
+
+namespace at::cuda::detail {
+
+#define MAX_TENSORINFO_DIMS 25
+
+// CUDA kernel argument that defines tensor layout
+template <typename T, typename IndexType>
+struct TensorInfo {
+  TensorInfo();
+  TensorInfo(T* p,
+             int dim,
+             IndexType sz[MAX_TENSORINFO_DIMS],
+             IndexType st[MAX_TENSORINFO_DIMS]);
+
+  // Set the size of the given dimension to 1, as if it were a
+  // reduction dim (allows you to calculate offsets of the reduction
+  // slice)
+  void reduceDim(int dim);
+
+  // See note on [collapse dims].
+  int collapseDims(const int excludeDim = -1);
+
+  // Contiguous tensors of more than one dimension are collapsed down
+  // to one tensor
+  __host__ __device__ inline bool isContiguous() const {
+    return (dims == 1 && strides[0] == 1);
+  }
+
+  T* data;
+  IndexType sizes[MAX_TENSORINFO_DIMS];
+  IndexType strides[MAX_TENSORINFO_DIMS];
+  int dims;
+};
+
+template <typename T, typename IndexType>
+TensorInfo<T, IndexType>::TensorInfo() {
+  data = nullptr;
+  dims = 0;
+}
+
+template <typename T, typename IndexType>
+TensorInfo<T, IndexType>::TensorInfo(T* p,
+                                     int dim,
+                                     IndexType sz[MAX_TENSORINFO_DIMS],
+                                     IndexType st[MAX_TENSORINFO_DIMS]) {
+  data = p;
+  dims = dim;
+  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
+
+  for (int i = 0; i < dim; ++i) {
+    sizes[i] = sz[i];
+    strides[i] = st[i];
+  }
+}
+
+template <typename T, typename IndexType>
+void
+TensorInfo<T, IndexType>::reduceDim(int dim) {
+  TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
+  sizes[dim] = 1;
+}
+
+template <typename T, typename IndexType>
+int
+TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
+  auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
+  dims = std::get<1>(result);
+  return std::get<0>(result);
+}
+
+// Translate a linear index for the apply to a T* offset;
+// specialized on `Dims` to reduce nvcc compilation time
+template <typename T, typename IndexType, int Dims>
+struct IndexToOffset {
+  static __host__ __device__ IndexType get(
+    IndexType linearId,
+    const TensorInfo<T, IndexType>& info) {
+
+    IndexType offset = 0;
+
+    // Uses static dims
+    for (int i = Dims - 1; i > 0; --i) {
+      IndexType curDimIndex = linearId % info.sizes[i];
+      IndexType curDimOffset = curDimIndex * info.strides[i];
+      offset += curDimOffset;
+      linearId /= info.sizes[i];
+    }
+
+    return offset + linearId * info.strides[0];
+  }
+};
+
+// Uses dynamic (runtime) instead of static (compiletime) dims
+template <typename T, typename IndexType>
+struct IndexToOffset<T, IndexType, -1> {
+  static inline __host__ __device__ IndexType get(
+    IndexType linearId,
+    const TensorInfo<T, IndexType>& info) {
+
+      IndexType offset = 0;
+
+      for (int i = info.dims - 1; i > 0; --i) {
+        IndexType curDimIndex = linearId % info.sizes[i];
+        IndexType curDimOffset = curDimIndex * info.strides[i];
+        offset += curDimOffset;
+        linearId /= info.sizes[i];
+      }
+
+      return offset + linearId * info.strides[0];
+  }
+};
+
+} // namespace at::cuda::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..274c9add050fe3a33794e39d79b661f7e2fcf8cf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh
@@ -0,0 +1,28 @@
+// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
+// Eager mode clients should not include this file directly, instead,
+// they should #include , which has a #pragma once.
+
+namespace at::cuda::philox {
+
+// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
+// that instance was created with graph capture underway or not.
+// See Note [CUDA Graph-safe RNG states].
+//
+// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
+// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
+// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
+//
+// The raw definition lives in its own file so jit codegen can easily copy it.
+__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
+unpack(at::PhiloxCudaState arg) {
+  if (arg.captured_) {
+    // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
+    // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
+    // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
+    return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
+  } else {
+    return std::make_tuple(arg.seed_.val, arg.offset_.val);
+  }
+}
+
+} // namespace at::cuda::philox
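
// A host-only sketch of the two PhiloxCudaState flavors and what unpack() returns for
// each: plain seed/offset values outside graph capture, and pointer-based state plus
// an intragraph offset during capture. The unpack logic is re-implemented inline here
// (the real helper is __device__-qualified); FakePhiloxState mirrors the payload
// layout and is purely illustrative.
#include <cassert>
#include <cstdint>
#include <tuple>

struct FakePhiloxState {
  union Payload { uint64_t val; int64_t* ptr; };
  Payload seed_{}, offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};

std::tuple<uint64_t, uint64_t> unpack_like(const FakePhiloxState& arg) {
  if (arg.captured_) {
    return {static_cast<uint64_t>(*arg.seed_.ptr),
            static_cast<uint64_t>(*arg.offset_.ptr + arg.offset_intragraph_)};
  }
  return {arg.seed_.val, arg.offset_.val};
}

int main() {
  FakePhiloxState eager;                 // non-captured: values stored directly
  eager.seed_.val = 42; eager.offset_.val = 100;
  assert(unpack_like(eager) == std::make_tuple(uint64_t{42}, uint64_t{100}));

  int64_t seed = 42, offset_extragraph = 100;
  FakePhiloxState captured;              // captured: values read through pointers at replay
  captured.seed_.ptr = &seed;
  captured.offset_.ptr = &offset_extragraph;
  captured.offset_intragraph_ = 8;
  captured.captured_ = true;
  assert(unpack_like(captured) == std::make_tuple(uint64_t{42}, uint64_t{108}));
  return 0;
}
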
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator.h
new file mode 100644
index 0000000000000000000000000000000000000000..5e67b0f83c5d8a52cb1534bdbc7879138b53bdf9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator.h
@@ -0,0 +1,40 @@
+#pragma once
+#include 
+
+#if AT_USE_JITERATOR()
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at::cuda {
+
+TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel(
+  const std::string& code_string,
+  const std::string& kernel_name,
+  const int num_outputs,
+  const c10::SmallVector& tensors,
+  const c10::SmallVector& extra_args,
+  bool return_by_ref);
+
+} // namespace at::cuda
+
+#else
+
+namespace at::cuda {
+
+TORCH_CUDA_CPP_API c10::SmallVector CompileAndLaunchKernel(
+  const std::string& code_string,
+  const std::string& kernel_name,
+  const int num_outputs,
+  const c10::SmallVector& tensors,
+  const c10::SmallVector& extra_args,
+  bool return_by_ref) {
+    TORCH_CHECK(false, "Jiterator is not supported");
+  }
+} // namespace at::cuda
+
+#endif // AT_USE_JITERATOR()
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..db8334c9ba510c2488a1e3e6d26d1f4b357cc9e4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/jiterator_impl.h
@@ -0,0 +1,249 @@
+#pragma once
+#include 
+
+#if AT_USE_JITERATOR()
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+
+#define AT_FOR_8_CASES(_)  \
+  _(1)                      \
+  _(2)                      \
+  _(3)                      \
+  _(4)                      \
+  _(5)                      \
+  _(6)                      \
+  _(7)                      \
+  _(8)
+
+#define AT_FOR_8_CASES_WITH_COMMA(_)  \
+  _(1)     ,                           \
+  _(2)     ,                           \
+  _(3)     ,                           \
+  _(4)     ,                           \
+  _(5)     ,                           \
+  _(6)     ,                           \
+  _(7)     ,                           \
+  _(8)
+
+c10::SmallVector get_extra_args_typenames(const c10::SmallVector& extra_args) {
+  c10::SmallVector args_typenames(extra_args.size());
+  for (const auto i : c10::irange(extra_args.size())) {
+    args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
+  }
+  return args_typenames;
+}
+
+int can_vectorize_up_to(at::ScalarType type, char* pointer) {
+  switch(type) {
+#define DEFINE_CASE(ctype, scalartype)                                   \
+    case ScalarType::scalartype : return memory::can_vectorize_up_to(pointer);
+
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
+#undef DEFINE_CASE
+
+    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
+  }
+}
+
+// jitted version of the above
+// See Note [Jiterator], this relies on the assumptions enumerated there
+int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
+  const at::ScalarType common_dtype = iter.common_dtype();
+  const at::ScalarType result_dtype = common_dtype;
+
+  // Deals with output
+  int result = can_vectorize_up_to(result_dtype, static_cast(iter.data_ptr(0)));
+
+  // Incorporates input(s)
+  for (auto i = 1; i < iter.ntensors(); ++i) {
+    result = std::min(result, can_vectorize_up_to(common_dtype, static_cast(iter.data_ptr(i))));
+  }
+
+  return result;
+}
+
+template
+static std::unique_ptr> make_unique_offset_calculator(
+          const TensorIteratorBase& iter) {
+  // array size can not be 0, this happens when N == 0
+  constexpr int array_size = std::max(N, 1);
+  TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));
+
+  std::array strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    int index = IS_INPUT ? i + iter.noutputs() : i;
+    strides[i] = iter.strides(index).data();
+    element_sizes[i] = iter.element_size(index);
+  }
+  return std::make_unique>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template 
+struct OffsetCalculatorVariant {
+#define DEFINE_CASE(index) std::unique_ptr>
+  using OffsetCalculatorTypes = std::variant<
+    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
+  >;
+#undef DEFINE_CASE
+
+  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
+    int num = IS_INPUT ? iter.ninputs() : iter.noutputs();
+
+    switch(num) {
+#define DEFINE_CASE(index)        \
+      case index : v = make_unique_offset_calculator(iter); break;
+
+      AT_FOR_8_CASES(DEFINE_CASE)
+#undef DEFINE_CASE
+      default:
+        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
+    }
+  }
+
+  void* data_ptr() {
+    return std::visit([](auto & v){ return static_cast(v.get()); }, v);
+  }
+
+ private:
+  OffsetCalculatorTypes v;
+};
+
+struct ArrayVariant {
+// works for up to 8 input + 8 outputs
+#define DEFINE_CASE(index) at::detail::Array, at::detail::Array
+  using ArrayTypes = std::variant<
+    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
+  >;
+#undef DEFINE_CASE
+
+  ArrayVariant(const TensorIteratorBase& iter) {
+    int ntensors = iter.ntensors();
+    switch(ntensors) {
+#define DEFINE_CASE(index)                                            \
+      case index: array = at::detail::Array{}; break;   \
+      case index+8: array = at::detail::Array{}; break;
+
+      AT_FOR_8_CASES(DEFINE_CASE)
+#undef DEFINE_CASE
+
+      default:
+        TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
+    }
+
+    std::visit([&](auto& a) {
+      for (auto i = 0; i < ntensors; ++i) {
+        a[i] = (char*)iter.data_ptr(i);
+      }
+    }, array);
+  }
+
+  void* data_ptr() {
+    return std::visit([](auto & a){ return static_cast(&a); }, array);
+  }
+
+private:
+  ArrayTypes array;
+};
+
+struct TrivialOffsetCalculatorVariant {
+#define DEFINE_CASE(index) TrivialOffsetCalculator
+  using TrivialOffsetCalculatorTypes = std::variant<
+    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
+  >;
+#undef DEFINE_CASE
+
+  TrivialOffsetCalculatorVariant(int num) {
+    switch(num) {
+#define DEFINE_CASE(index)      \
+      case index: v = TrivialOffsetCalculator(); break;
+
+      AT_FOR_8_CASES(DEFINE_CASE)
+#undef DEFINE_CASE
+
+      default:
+        TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
+    }
+  }
+
+  void* data_ptr() {
+    return std::visit([](auto & v){ return static_cast(&v); }, v);
+  }
+
+private:
+  TrivialOffsetCalculatorTypes v;
+};
+
+struct LoadWithCastVariant {
+#define DEFINE_CASE(index) std::unique_ptr>
+  using LoadWithCastPtr = std::variant<
+    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
+  >;
+#undef DEFINE_CASE
+
+  LoadWithCastVariant(const TensorIteratorBase& iter) {
+    int arity = iter.ninputs();
+    switch(arity) {
+#define DEFINE_CASE(index)      \
+      case index: v = std::make_unique>(iter); break;
+
+      AT_FOR_8_CASES(DEFINE_CASE)
+#undef DEFINE_CASE
+
+      default:
+        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
+    }
+  }
+
+  void* data_ptr() {
+    return std::visit([](auto & v){ return static_cast(v.get()); }, v);
+  }
+
+private:
+  LoadWithCastPtr v;
+};
+
+struct StoreWithCastVariant {
+#define DEFINE_CASE(index) std::unique_ptr>
+  using StoreWithCastPtr = std::variant<
+    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
+  >;
+#undef DEFINE_CASE
+
+  StoreWithCastVariant(const TensorIteratorBase& iter) {
+    int num = iter.noutputs();
+    switch(num) {
+#define DEFINE_CASE(index)      \
+      case index: v = std::make_unique>(iter); break;
+
+      AT_FOR_8_CASES(DEFINE_CASE)
+#undef DEFINE_CASE
+
+      default:
+        TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
+    }
+  }
+
+  void* data_ptr() {
+    return std::visit([](auto & v){ return static_cast(v.get()); }, v);
+  }
+
+private:
+  StoreWithCastPtr v;
+};
+
+} // namespace at::native
+
+
+#endif // AT_USE_JITERATOR()
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h
new file mode 100644
index 0000000000000000000000000000000000000000..ec2caa7b34b80eec75210988b7d6081e368f65bf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::cuda {
+
+TORCH_CUDA_CPP_API const std::string &get_traits_string();
+TORCH_CUDA_CPP_API const std::string &get_cmath_string();
+TORCH_CUDA_CPP_API const std::string &get_complex_body_string();
+TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string();
+TORCH_CUDA_CPP_API const std::string &get_complex_math_string();
+
+} // namespace at::cuda
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
new file mode 100644
index 0000000000000000000000000000000000000000..592f9fb17cd530cfa195fe073ea3156d67adbffb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
@@ -0,0 +1,174 @@
+// Original TunableOp is from onnxruntime.
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
+// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+// Adapting TunableOp into PyTorch
+// Copyright (c) Advanced Micro Devices, Inc.
+//
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at::cuda::tunable {
+
+enum class BlasOp {
+  N = 0,
+  T = 1
+};
+
+inline std::string BlasOpToString(BlasOp op) {
+  switch (op) {
+    case BlasOp::N:
+      return "N";
+    case BlasOp::T:
+      return "T";
+  }
+  TORCH_CHECK(false, "unrecognized BlasOp");
+  return "N";
+}
+
+template <typename T>
+struct GemmParams : OpParams {
+  std::string Signature() const override {
+    return c10::str(transa, transb, "_", m, "_", n, "_", k);
+  }
+
+  GemmParams* DeepCopy() const {
+    GemmParams* copy = new GemmParams;
+    *copy = *this;
+    c10::DeviceIndex device = 0;
+    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
+    size_t c_size = m * n * sizeof(T);
+    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
+    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
+        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
+    return copy;
+  }
+
+  // only call on object returned by DeepCopy
+  void Delete() {
+    c10::cuda::CUDACachingAllocator::raw_delete(c);
+  }
+
+  TuningStatus NumericalCheck(GemmParams *other) {
+    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
+    // comparison done as 1D tensor
+    at::Tensor ref = at::from_blob(c,        {m*n}, options);
+    at::Tensor oth = at::from_blob(other->c, {m*n}, options);
+    at::Tensor ref_float = ref.to(at::kFloat);
+    at::Tensor oth_float = oth.to(at::kFloat);
+    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
+    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
+    double last_succeed_atol = 1;
+    double last_succeed_rtol = 1;
+    for (auto& atol : atols) {
+      for (auto& rtol : rtols) {
+        if (at::allclose(ref_float, oth_float, rtol, atol)) {
+          last_succeed_atol = atol;
+          last_succeed_rtol = rtol;
+        }
+      }
+    }
+    if (last_succeed_atol == 1) {
+      return FAIL;
+    }
+    else {
+      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
+    }
+
+    return OK;
+  }
+
+  char transa;
+  char transb;
+  int64_t m;
+  int64_t n;
+  int64_t k;
+  at::opmath_type<T> alpha;
+  const T* a;
+  int64_t lda;
+  const T* b;
+  int64_t ldb;
+  at::opmath_type<T> beta;
+  T* c;
+  int64_t ldc;
+};
+
+template <typename T>
+struct GemmStridedBatchedParams : OpParams {
+  std::string Signature() const override {
+    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
+  }
+
+  GemmStridedBatchedParams* DeepCopy() const {
+    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
+    *copy = *this;
+    c10::DeviceIndex device = 0;
+    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
+    size_t c_size = batch * stride_c * sizeof(T);
+    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
+    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
+        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
+    return copy;
+  }
+
+  // only call on object returned by DeepCopy
+  void Delete() {
+    c10::cuda::CUDACachingAllocator::raw_delete(c);
+  }
+
+  TuningStatus NumericalCheck(GemmStridedBatchedParams *other) {
+    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
+    // comparison done as 1D tensor
+    at::Tensor ref = at::from_blob(c,        {batch*stride_c}, options);
+    at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options);
+    at::Tensor ref_float = ref.to(at::kFloat);
+    at::Tensor oth_float = oth.to(at::kFloat);
+    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
+    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
+    double last_succeed_atol = 1;
+    double last_succeed_rtol = 1;
+    for (auto& atol : atols) {
+      for (auto& rtol : rtols) {
+        if (at::allclose(ref_float, oth_float, rtol, atol)) {
+          last_succeed_atol = atol;
+          last_succeed_rtol = rtol;
+        }
+      }
+    }
+    if (last_succeed_atol == 1) {
+      return FAIL;
+    }
+    else {
+      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
+    }
+
+    return OK;
+  }
+
+  char transa;
+  char transb;
+  int64_t m;
+  int64_t n;
+  int64_t k;
+  at::opmath_type<T> alpha;
+  const T* a;
+  int64_t lda;
+  int64_t stride_a;
+  const T* b;
+  int64_t ldb;
+  int64_t stride_b;
+  at::opmath_type<T> beta;
+  T* c;
+  int64_t ldc;
+  int64_t stride_c;
+  int64_t batch;
+};
+
+} // namespace at::cuda::tunable
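
// A small host-side sketch of the tolerance sweep NumericalCheck performs above: a
// candidate result is compared against the reference at a grid of (atol, rtol) pairs
// from loosest to tightest, the last passing pair is logged, and failing even the
// loosest pair rejects the candidate. std::vector<float> stands in for the GPU
// tensors and the allclose formula follows |a - b| <= atol + rtol * |b|.
#include <cmath>
#include <cstdio>
#include <vector>

bool allclose(const std::vector<float>& a, const std::vector<float>& b,
              double rtol, double atol) {
  for (size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i])) return false;
  }
  return true;
}

bool numerical_check(const std::vector<float>& ref, const std::vector<float>& other) {
  std::vector<double> tols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double ok_atol = 1, ok_rtol = 1;                 // 1 means "never passed"
  for (double atol : tols) {
    for (double rtol : tols) {
      if (allclose(ref, other, rtol, atol)) { ok_atol = atol; ok_rtol = rtol; }
    }
  }
  if (ok_atol == 1) return false;                  // FAIL: not even the loosest pair passed
  std::printf("verify numerics: atol=%g, rtol=%g\n", ok_atol, ok_rtol);
  return true;                                     // OK
}
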
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
new file mode 100644
index 0000000000000000000000000000000000000000..91c54b229e5b61a613e711bbf6d2e2fc0d71b2fb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
@@ -0,0 +1,379 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#define TORCH_HIPBLASLT_CHECK(EXPR)               \
+  do {                                            \
+    hipblasStatus_t __err = EXPR;                 \
+    TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS,  \
+                "hipblaslt error: ",              \
+                hipblasStatusToString(__err),     \
+                " when calling `" #EXPR "`");     \
+  } while (0)
+
+namespace at::cuda::tunable {
+
+#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO
+#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
+#else
+static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
+    int* algo_ptr = (int*)algo.data;
+    if(*algo_ptr < 0) {
+        return -1;
+    }
+    return *algo_ptr;
+}
+#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
+#endif
+
+#ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE
+#define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32
+#else
+#define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F
+#endif
+
+#ifdef HIPBLASLT_CUSTOM_DATA_TYPE
+
+template 
+constexpr hipblasltDatatype_t HipBlasDataTypeFor();
+
+template <>
+constexpr hipblasltDatatype_t HipBlasDataTypeFor() {
+  return HIPBLASLT_R_32F;
+}
+
+template <>
+constexpr hipblasltDatatype_t HipBlasDataTypeFor() {
+  return HIPBLASLT_R_16F;
+}
+
+template <>
+constexpr hipblasltDatatype_t HipBlasDataTypeFor() {
+  return HIPBLASLT_R_16B;
+}
+
+template <>
+constexpr hipblasltDatatype_t HipBlasDataTypeFor() {
+  return HIPBLASLT_R_64F;
+}
+
+#define DATA_TYPE_R_32 HIPBLASLT_R_32F
+
+#else
+
+template 
+constexpr hipblasDatatype_t HipBlasDataTypeFor();
+
+template <>
+constexpr hipblasDatatype_t HipBlasDataTypeFor() {
+  return HIPBLAS_R_32F;
+}
+
+template <>
+constexpr hipblasDatatype_t HipBlasDataTypeFor() {
+  return HIPBLAS_R_16F;
+}
+
+template <>
+constexpr hipblasDatatype_t HipBlasDataTypeFor() {
+  return HIPBLAS_R_16B;
+}
+
+template <>
+constexpr hipblasDatatype_t HipBlasDataTypeFor() {
+  return HIPBLAS_R_64F;
+}
+
+#ifdef HIPBLAS_V2
+#define DATA_TYPE_R_32 HIP_R_32F
+#else
+#define DATA_TYPE_R_32 HIPBLAS_R_32F
+#endif
+
+#endif
+
+template 
+int GetBatchFromParams(const ParamsT* params) {
+  return 1;
+}
+
+template 
+int GetBatchFromParams(const GemmStridedBatchedParams* params) {
+  return params->batch;
+}
+
+template 
+int GetStrideAFromParams(const ParamsT* params) {
+  return 1;
+}
+
+template 
+int GetStrideAFromParams(const GemmStridedBatchedParams* params) {
+  return params->stride_a;
+}
+
+template 
+int GetStrideBFromParams(const ParamsT* params) {
+  return 1;
+}
+
+template 
+int GetStrideBFromParams(const GemmStridedBatchedParams* params) {
+  return params->stride_b;
+}
+
+template 
+int GetStrideCFromParams(const ParamsT* params) {
+  return 1;
+}
+
+template 
+int GetStrideCFromParams(const GemmStridedBatchedParams* params) {
+  return params->stride_c;
+}
+
+static hipblasOperation_t _hipblasOpFromChar(char op) {
+  switch (op) {
+    case 'n':
+    case 'N':
+      return HIPBLAS_OP_N;
+    case 't':
+    case 'T':
+      return HIPBLAS_OP_T;
+    case 'c':
+    case 'C':
+      return HIPBLAS_OP_C;
+  }
+  AT_ERROR(
+      "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
+}
+
+static char _charFromhipblasOp(hipblasOperation_t op) {
+  switch (op) {
+    case HIPBLAS_OP_N:
+      return 'N';
+    case HIPBLAS_OP_T:
+      return 'T';
+    case HIPBLAS_OP_C:
+      return 'C';
+  }
+  AT_ERROR(
+      "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
+}
+
+static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
+  if (layout == BlasOp::N) {
+    return HIPBLAS_OP_N;
+  }
+  return HIPBLAS_OP_T;
+}
+
+static size_t GetHipblasltWorkspaceSize() {
+  static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
+  // 256MB is max workspace size allowed for hipblaslt
+  // hipblaslt-bench uses 32MB
+  // recommendation from hipblaslt author was 76MB
+  size_t workspace_size = 2*128*1024*1024; // default 256MB
+  if (env) {
+    try {
+      workspace_size = std::stoi(env);
+    } catch(std::invalid_argument const& e) {
+      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
+                 " using default workspace size of ", workspace_size, " bytes.");
+    } catch(std::out_of_range const& e) {
+      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
+                 " using default workspace size of ", workspace_size, " bytes.");
+    }
+  }
+  return workspace_size;
+}
+
+template 
+class HipblasltGemmOp : public Callable {
+  public:
+    HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
+
+    TuningStatus Call(const ParamsT* params) override {
+      hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
+      hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
+      auto in_out_datatype = HipBlasDataTypeFor();
+      auto opa = _hipblasOpFromChar(params->transa);
+      auto opb = _hipblasOpFromChar(params->transb);
+
+      TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
+
+      float alpha = static_cast(params->alpha);
+      float beta = static_cast(params->beta);
+
+      hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
+      hipblasLtMatmulDesc_t matmul;
+      if (opa == HIPBLAS_OP_N) {
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda));
+      }
+      else {
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda));
+      }
+      if (opb == HIPBLAS_OP_N) {
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb));
+      }
+      else {
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb));
+      }
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32));
+
+      int batch = GetBatchFromParams(params);
+      if (batch > 1) {
+        int64_t stride_a = GetStrideAFromParams(params);
+        int64_t stride_b = GetStrideBFromParams(params);
+        int64_t stride_c = GetStrideCFromParams(params);
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
+        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
+            mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
+      }
+
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
+            matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t)));
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
+            matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t)));
+
+      size_t workspace_size = GetHipblasltWorkspaceSize();
+
+      auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
+
+      size_t ret_workspace_size = 0;
+      auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
+          matmul,
+          &alpha,
+          mat_a,
+          mat_b,
+          &beta,
+          mat_c,
+          mat_c,
+          algo_,
+          ret_workspace_size);
+
+      if (status == HIPBLAS_STATUS_SUCCESS) {
+        if (ret_workspace_size >= workspace_size) {
+          //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large");
+          return FAIL;
+        }
+      }
+      else {
+        //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported");
+        return FAIL;
+      }
+
+      void* workspace_buffer = nullptr;
+      if (workspace_size > 0) {
+        workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
+      }
+
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
+            matmul,
+            &alpha,
+            params->a,
+            mat_a,
+            params->b,
+            mat_b,
+            &beta,
+            params->c,
+            mat_c,
+            params->c,
+            mat_c,
+            &algo_,
+            workspace_buffer,
+            workspace_size,
+            at::cuda::getCurrentCUDAStream()));
+
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
+      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
+      if (workspace_size > 0) {
+        c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
+      }
+      return OK;
+    }
+
+  private:
+    hipblasLtMatmulAlgo_t algo_;
+};
+
+template 
+auto GetHipBlasLtTypeStringAndOps() {
+  hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
+  hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
+  auto in_out_datatype = HipBlasDataTypeFor();
+  std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
+
+  hipblasLtHandle_t handle;
+  TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
+  TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
+        hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
+        transa_outer,
+        transb_outer,
+        in_out_datatype,
+        in_out_datatype,
+        in_out_datatype,
+        in_out_datatype,
+        COMPUTE_TYPE_32,
+        heuristic_result));
+  TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
+
+  // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
+  std::sort(heuristic_result.begin(),
+      heuristic_result.end(),
+      [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
+      return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo);
+      });
+
+  int returned_algo_count = heuristic_result.size();
+  std::vector>>> ret;
+  for (int i = 0; i < returned_algo_count; i++) {
+    auto algo = heuristic_result[i].algo;
+    int algo_index = GETINDEXFROMALGO(algo);
+    auto callable = std::make_unique>(algo);
+    std::string type_string = c10::str(
+        "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
+    ret.emplace_back(type_string, std::move(callable));
+  }
+
+  return ret;
+}
+
+template 
+auto GetHipBlasLtGemmTypeStringAndOps() {
+  return GetHipBlasLtTypeStringAndOps>();
+}
+
+template 
+auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
+  return GetHipBlasLtTypeStringAndOps>();
+}
+
+#undef TORCH_HIPBLASLT_CHECK
+#undef GETINDEXFROMALGO
+#undef COMPUTE_TYPE_32
+#undef DATA_TYPE_R_32
+
+}  // namespace at::cuda::tunable
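A note on the workspace handling above: GetHipblasltWorkspaceSize reads HIPBLASLT_WORKSPACE_SIZE, warns and falls back to its built-in default when the value is missing or unparsable, and matmulIsAlgoSupported then rejects any algorithm whose required workspace is at least that budget. A minimal standalone sketch of the same parse-with-fallback pattern (the helper name and the 32 MiB default are illustrative, not taken from the header):

#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

// Illustrative helper mirroring the header's fallback behaviour.
static size_t ParseWorkspaceSize(const char* var_name, size_t default_size) {
  size_t workspace_size = default_size;
  const char* env = std::getenv(var_name);
  if (env != nullptr) {
    try {
      workspace_size = static_cast<size_t>(std::stoi(env));
    } catch (const std::invalid_argument&) {
      std::cerr << "invalid " << var_name << ", using default\n";
    } catch (const std::out_of_range&) {
      std::cerr << var_name << " out of range, using default\n";
    }
  }
  return workspace_size;
}

int main() {
  // 32 MiB is a placeholder default for this sketch only.
  std::cout << ParseWorkspaceSize("HIPBLASLT_WORKSPACE_SIZE", 32u << 20) << "\n";
}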
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h
new file mode 100644
index 0000000000000000000000000000000000000000..37fcc3bea8e880cc7a2f547e00f1bff89405f49b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h
@@ -0,0 +1,275 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#define ROCBLAS_BETA_FEATURES_API
+#include 
+
+#define TORCH_ROCBLAS_CHECK(EXPR)                 \
+  do {                                            \
+    rocblas_status __err = EXPR;                  \
+    TORCH_CHECK(__err == rocblas_status_success,  \
+                "rocblas error: ",                \
+                rocblas_status_to_string(__err),  \
+                " when calling `" #EXPR "`");     \
+  } while (0)
+
+namespace at::cuda::tunable {
+
+template <typename T>
+constexpr rocblas_datatype RocBlasDataTypeFor();
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
+  return rocblas_datatype_f32_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
+  return rocblas_datatype_f64_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
+  return rocblas_datatype_f16_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
+  return rocblas_datatype_bf16_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
+  return rocblas_datatype_f32_c;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
+  return rocblas_datatype_f64_c;
+}
+
+template <typename T>
+constexpr rocblas_datatype RocBlasComputeTypeFor();
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
+  return rocblas_datatype_f32_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
+  return rocblas_datatype_f64_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
+  // Note that we're returning the _compute_ type for a given datatype.
+  // As of 12/2022, using compute type FP16 for 16-bit floats was much
+  // slower than using compute type FP32. So we use FP32 compute even for
+  // FP16 datatypes. This is how GEMM is implemented even in the function
+  // rocblasGemmHelper (see fpgeneric.h)
+  return rocblas_datatype_f32_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
+  // Note that we're returning the _compute_ type for a given datatype.
+  // As of 12/2022, using compute type FP16 for 16-bit floats was much
+  // slower than using compute type FP32. So we use FP32 compute even for
+  // BF16 datatypes. This is how GEMM is implemented even in the function
+  // rocblasGemmHelper (see fpgeneric.h)
+  return rocblas_datatype_f32_r;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
+  return rocblas_datatype_f32_c;
+}
+
+template <>
+constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
+  return rocblas_datatype_f64_c;
+}
+
+template <typename T>
+auto DoCastForHalfOrBfloat16(const T fp) {
+  return fp;
+}
+
+template <>
+inline auto DoCastForHalfOrBfloat16(const Half fp) {
+  // alpha and beta should have the same type as compute_type; in the Half case that is float.
+  float h = fp;
+  return h;
+}
+
+template <>
+inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) {
+  // alpha and beta should have the same type as compute_type; in the BFloat16 case that is float.
+  float h = fp;
+  return h;
+}
+
+static rocblas_operation _rocblasOpFromChar(char op) {
+  switch (op) {
+    case 'n':
+    case 'N':
+      return rocblas_operation_none;
+    case 't':
+    case 'T':
+      return rocblas_operation_transpose;
+    case 'c':
+    case 'C':
+      return rocblas_operation_conjugate_transpose;
+  }
+  AT_ERROR(
+      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
+}
+
+template <typename T>
+class RocblasGemmOp : public Callable<GemmParams<T>> {
+  public:
+    RocblasGemmOp(int solution) : solution_{solution} {}
+
+    TuningStatus Call(const GemmParams<T>* params) override {
+      auto input_output_type = RocBlasDataTypeFor<T>();
+      auto compute_type = RocBlasComputeTypeFor<T>();
+      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
+      auto h_b = DoCastForHalfOrBfloat16(params->beta);
+      auto status = rocblas_gemm_ex(
+          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
+          _rocblasOpFromChar(params->transa),
+          _rocblasOpFromChar(params->transb),
+          params->m, params->n, params->k,
+          &h_a,
+          params->a, input_output_type, params->lda,
+          params->b, input_output_type, params->ldb,
+          &h_b,
+          params->c, input_output_type, params->ldc,
+          params->c, input_output_type, params->ldc,
+          compute_type,
+          rocblas_gemm_algo_solution_index,
+          solution_,
+          rocblas_gemm_flags_none);
+      if (status != rocblas_status_success) {
+        return FAIL;
+      }
+      return OK;
+    }
+
+  private:
+    int solution_;
+};
+
+template 
+auto GetRocBlasGemmTypeStringAndOps() {
+  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
+  int solution_size;
+  auto input_output_type = RocBlasDataTypeFor();
+  auto compute_type = RocBlasComputeTypeFor();
+  // Get the number of available solutions
+  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
+                                                            input_output_type,
+                                                            input_output_type,
+                                                            compute_type,
+                                                            rocblas_gemm_flags_none,
+                                                            nullptr,
+                                                            &solution_size));
+  std::vector solutions(solution_size);
+  // Get the list of available solutions
+  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
+                                                            input_output_type,
+                                                            input_output_type,
+                                                            compute_type,
+                                                            rocblas_gemm_flags_none,
+                                                            solutions.data(),
+                                                            &solution_size));
+  // Sort the solutions in ascending order to make the solution vector deterministic across runs
+  std::sort(solutions.begin(), solutions.end());
+
+  std::vector>>>> ret;
+  for (size_t i = 0; i < solutions.size(); ++i) {
+    auto callable = std::make_unique>(solutions[i]);
+    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
+  }
+  return ret;
+}
+
+template <typename T>
+class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
+  public:
+    RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
+
+    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
+      auto input_output_type = RocBlasDataTypeFor<T>();
+      auto compute_type = RocBlasComputeTypeFor<T>();
+      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
+      auto h_b = DoCastForHalfOrBfloat16(params->beta);
+      auto status = rocblas_gemm_strided_batched_ex(
+          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
+          _rocblasOpFromChar(params->transa),
+          _rocblasOpFromChar(params->transb),
+          params->m, params->n, params->k,
+          &h_a,
+          params->a, input_output_type, params->lda, params->stride_a,
+          params->b, input_output_type, params->ldb, params->stride_b,
+          &h_b,
+          params->c, input_output_type, params->ldc, params->stride_c,
+          params->c, input_output_type, params->ldc, params->stride_c,
+          params->batch,
+          compute_type,
+          rocblas_gemm_algo_solution_index,
+          solution_,
+          rocblas_gemm_flags_none);
+      if (status != rocblas_status_success) {
+        return FAIL;
+      }
+      return OK;
+    }
+
+  private:
+    int solution_;
+};
+
+template 
+auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
+  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
+  int solution_size;
+  auto input_output_type = RocBlasDataTypeFor();
+  auto compute_type = RocBlasComputeTypeFor();
+  // Get the number of available solutions
+  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
+                                                            input_output_type,
+                                                            input_output_type,
+                                                            compute_type,
+                                                            rocblas_gemm_flags_none,
+                                                            nullptr,
+                                                            &solution_size));
+  std::vector solutions(solution_size);
+  // Get the list of available solutions
+  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
+                                                            input_output_type,
+                                                            input_output_type,
+                                                            compute_type,
+                                                            rocblas_gemm_flags_none,
+                                                            solutions.data(),
+                                                            &solution_size));
+  // Sort the solutions in ascending order to make the solution vector deterministic across runs
+  std::sort(solutions.begin(), solutions.end());
+
+  std::vector>>>> ret;
+  for (size_t i = 0; i < solutions.size(); ++i) {
+    auto callable = std::make_unique>(solutions[i]);
+    ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
+  }
+  return ret;
+}
+
+}  // namespace at::cuda::tunable
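Both enumeration helpers above use the classic two-call C API idiom: the first rocblas_gemm_ex_get_solutions_by_type call passes a null buffer to learn how many solutions exist, a vector is sized accordingly, and the second call fills it before sorting for determinism. A self-contained sketch of that idiom against a mock query function (query_items is invented for illustration and is not a rocBLAS API):

#include <algorithm>
#include <vector>

// Mock of a C-style enumeration call: when `out` is null it only reports the
// count; otherwise it copies up to *count items into `out`.
static int query_items(int* out, int* count) {
  static const int available[] = {7, 3, 11, 5};
  const int n = 4;
  if (out == nullptr) { *count = n; return 0; }
  const int to_copy = std::min(*count, n);
  std::copy(available, available + to_copy, out);
  *count = to_copy;
  return 0;
}

int main() {
  int count = 0;
  query_items(nullptr, &count);                   // first call: how many solutions exist?
  std::vector<int> solutions(count);
  query_items(solutions.data(), &count);          // second call: fill the buffer
  std::sort(solutions.begin(), solutions.end());  // deterministic ordering, as in the header
  return 0;
}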
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h
new file mode 100644
index 0000000000000000000000000000000000000000..be2b23ca418a821deb10a05224868de164816c12
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h
@@ -0,0 +1,34 @@
+// Original TunableOp is from onnxruntime.
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
+// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+// Adapting TunableOp into PyTorch
+// Copyright (c) Advanced Micro Devices, Inc.
+//
+#pragma once
+
+#include 
+
+#include 
+
+namespace at::cuda::tunable {
+
+class StreamTimer : public ITimer {
+  public:
+    StreamTimer();
+    virtual ~StreamTimer();
+
+    void Start() override;
+
+    void End() override;
+
+    float Duration() override;
+
+  private:
+    cudaEvent_t start_;
+    cudaEvent_t end_;
+};
+
+} // namespace at::cuda::tunable
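StreamTimer only declares its interface here. A typical CUDA-event-based implementation records an event before and after the timed region and asks the runtime for the elapsed time; the sketch below is a plausible standalone version written against the plain CUDA runtime API, not PyTorch's actual StreamTimer.cpp, and it omits error checking for brevity.

#include <cuda_runtime.h>

// Minimal event-based timer: Duration() returns milliseconds between
// Start() and End() on the default stream.
class SimpleStreamTimer {
 public:
  SimpleStreamTimer()  { cudaEventCreate(&start_); cudaEventCreate(&end_); }
  ~SimpleStreamTimer() { cudaEventDestroy(start_); cudaEventDestroy(end_); }

  void Start() { cudaEventRecord(start_, /*stream=*/0); }
  void End()   { cudaEventRecord(end_,   /*stream=*/0); }

  float Duration() {
    cudaEventSynchronize(end_);  // wait until the timed region has finished
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start_, end_);
    return ms;
  }

 private:
  cudaEvent_t start_;
  cudaEvent_t end_;
};

Because the events are recorded into the stream, the measured interval covers the GPU work enqueued between Start() and End() rather than host-side call overhead, which is what a kernel-tuning loop needs.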
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h
new file mode 100644
index 0000000000000000000000000000000000000000..292c453aca1355ac6ded16be849933111ca259c0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/Tunable.h
@@ -0,0 +1,205 @@
+// Original TunableOp is from onnxruntime.
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
+// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+// Adapting TunableOp into PyTorch
+// Copyright (c) Advanced Micro Devices, Inc.
+//
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::cuda::tunable {
+
+static void TunableLog(const std::string& msg) {
+  static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE");
+  if (env != nullptr && strcmp(env, "1") == 0) {
+    std::cerr << msg << std::endl;
+  }
+}
+#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__))
+
+enum TuningStatus {
+  OK = 0,
+  FAIL = 1,
+  UNSUPPORTED = 2,
+};
+
+// Mapping from params signature to kernel id
+class ResultEntry {
+  public:
+    explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
+    bool operator==(const ResultEntry& other) { return key_ == other.key_; }
+    bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
+    operator std::string () { return key_; }
+    friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
+    static ResultEntry Null() { return ResultEntry("Null", 0.0); }
+    static ResultEntry Default() { return ResultEntry("Default", 0.0); }
+
+  private:
+    std::string key_;
+    double time_;
+};
+
+typedef std::unordered_map<std::string, ResultEntry> KernelMap;
+typedef std::unordered_map<std::string, KernelMap> ResultsMap;
+
+struct TuningResults {
+  // Validates if these results are compatible with the libraries
+  std::unordered_map validators;
+
+  // Mapping from Callable signature to Callable's tuning result
+  ResultsMap results;
+};
+
+class TuningResultsManager {
+  public:
+    TuningResultsManager() = default;
+    ~TuningResultsManager() = default;
+
+    KernelMap Lookup(const std::string& op_signature);
+
+    ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
+
+    inline void AddImpl(const std::string& op_signature,
+        const std::string& params_signature,
+        ResultEntry best,
+        KernelMap& kernel_map);
+
+    void Add(const std::string& op_signature,
+        const std::string& params_signature,
+        ResultEntry best);
+
+    void Delete(const std::string& op_signature, const std::string& params_signature);
+
+    inline void DisjointMergeImpl(
+        const std::string& op_signature,
+        const KernelMap& kernel_map,
+        /*out*/ ResultsMap& results);
+
+    void Load(const ResultsMap& results_to_load);
+
+    ResultsMap Dump();
+
+    void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
+
+    size_t GetSize();
+
+  private:
+    std::mutex lock_;
+    ResultsMap results_;
+};
+
+class TuningResultsValidator {
+  public:
+    using GetFunc = std::function;
+    using ValidateFunc = std::function;
+    using GetValidateFuncs = std::unordered_map>;
+
+    TuningResultsValidator();
+    ~TuningResultsValidator() = default;
+
+    std::unordered_map GetAllValidators() const;
+    TuningStatus ValidateAll(const std::unordered_map& to_validate) const;
+    void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
+
+  protected:
+    std::string GetPyTorchVersion() const;
+    TuningStatus ValidatePyTorchVersion(const std::string& value) const;
+
+  public:
+    static constexpr const std::array mandatory_keys{"PT_VERSION"};
+
+  private:
+    GetValidateFuncs validators_;
+};
+
+class TuningContext {
+  public:
+    TuningContext();
+    ~TuningContext();
+    TuningContext(TuningContext &) = delete;
+    TuningContext(TuningContext &&) = delete;
+    TuningContext &operator=(TuningContext &) = delete;
+    TuningContext &operator=(TuningContext &&) = delete;
+
+    void EnableTunableOp();
+    void DisableTunableOp();
+    bool IsTunableOpEnabled() const;
+
+    void EnableTuning();
+    void DisableTuning();
+    bool IsTuningEnabled() const;
+
+    void SetMaxTuningDurationMs(int max_duration_ms);
+    int GetMaxTuningDurationMs() const;
+
+    void SetMaxTuningIterations(int max_iter);
+    int GetMaxTuningIterations() const;
+
+    void SetMaxWarmupDurationMs(int max_duration_ms);
+    int GetMaxWarmupDurationMs() const;
+
+    void SetMaxWarmupIterations(int max_iter);
+    int GetMaxWarmupIterations() const;
+
+    void EnableTunableOpAndTuning();
+    void DisableTunableOpAndTuning();
+
+    TuningResultsManager& GetTuningResultsManager();
+
+    TuningResultsValidator& GetTuningResultsValidator();
+
+    TuningResults GetTuningResults();
+
+    TuningStatus LoadTuningResults(const TuningResults& tr);
+
+    void SetFilename(const std::string& filename);
+    std::string GetFilename() const;
+
+  protected:
+    bool ReadFile(const std::string& filename);
+    bool WriteFile(const std::string& filename);
+
+  private:
+    bool enable_;
+    bool tuning_enable_;
+    bool manager_initialized_;
+    int max_tuning_duration_ms_;
+    int max_tuning_iterations_;
+    int max_warmup_duration_ms_;
+    int max_warmup_iterations_;
+    mutable TuningResultsManager manager_;
+    mutable c10::once_flag manager_init_once_;
+    TuningResultsValidator validator_;
+    std::string filename_;
+    size_t results_count_from_input_file_;
+};
+
+TuningContext* getTuningContext();
+
+class ITimer {
+  public:
+    ITimer() = default;
+    virtual ~ITimer() = default;
+
+    virtual void Start() = 0;
+    virtual void End() = 0;
+
+    /// Computes the elapsed time in milliseconds between Start() and End()
+    virtual float Duration() = 0;
+};
+
+} // namespace at::cuda::tunable
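Putting the declarations above together, host code that wants tuned GEMMs flips the switches on the process-wide TuningContext before launching work. The sequence below is illustrative only; it uses methods declared in this header, but the numeric caps and the filename are placeholders rather than defaults taken from the source, and it assumes the header's own include path.

#include <ATen/cuda/tunable/Tunable.h>

void ConfigureTunableOps() {
  auto* ctx = at::cuda::tunable::getTuningContext();
  ctx->EnableTunableOp();             // route eligible ops through TunableOp
  ctx->EnableTuning();                // allow new timings to be collected
  ctx->SetMaxWarmupIterations(5);     // cap warmup work per candidate
  ctx->SetMaxTuningIterations(30);    // cap timed iterations per candidate
  ctx->SetFilename("my_tunableop_results.csv");  // where results get persisted
}

Setting PYTORCH_TUNABLEOP_VERBOSE=1, as checked by TunableLog above, then makes the tuning decisions visible on stderr.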
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..5e9edc0810eb5991b530afec7f17e85318143858
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h
@@ -0,0 +1,278 @@
+// Original TunableOp is from onnxruntime.
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
+// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+// Adapting TunableOp into PyTorch
+// Copyright (c) Advanced Micro Devices, Inc.
+//
+#pragma once
+
+#include 
+#ifdef USE_ROCM
+#if ROCM_VERSION >= 50700
+#include 
+#endif
+#include 
+#endif
+#include 
+#include 
+#include 
+#include 
+
+#ifdef USE_ROCM
+#include 
+#endif
+
+#define STRINGIFY(s) #s
+#define XSTRINGIFY(s) STRINGIFY(s)
+
+namespace at::cuda::tunable {
+
+template <typename T>
+class DefaultGemmOp : public Callable<GemmParams<T>> {
+  public:
+    TuningStatus Call(const GemmParams<T>* params) override {
+      at::cuda::blas::gemm_internal(
+          params->transa, params->transb,
+          params->m, params->n, params->k,
+          params->alpha,
+          params->a, params->lda,
+          params->b, params->ldb,
+          params->beta,
+          params->c, params->ldc);
+      return OK;
+    }
+};
+
+template <typename T>
+class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
+  public:
+    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
+      at::cuda::blas::bgemm_internal(
+          params->transa, params->transb,
+          params->m, params->n, params->k,
+          params->alpha,
+          params->a, params->lda, params->stride_a,
+          params->b, params->ldb, params->stride_b,
+          params->beta,
+          params->c, params->ldc, params->stride_c,
+          params->batch);
+      return OK;
+    }
+};
+
+template <typename T>
+bool IsZero(T v) {
+  return v == 0.0f;
+}
+
+template <>
+bool IsZero(BFloat16 v) {
+  return v.x == 0;
+}
+
+template <>
+bool IsZero(Half v) {
+  return float(v) == 0.0f;
+}
+
+template <>
+bool IsZero(c10::complex<double> v) {
+  return v == 0.0;
+}
+
+template <>
+bool IsZero(c10::complex<float> v) {
+  return v == 0.0f;
+}
+
+template <typename T>
+std::string TypeName(T v) {
+  return "unknown";
+}
+
+template <>
+std::string TypeName(float v) {
+  return "float";
+}
+
+template <>
+std::string TypeName(double v) {
+  return "double";
+}
+
+template <>
+std::string TypeName(BFloat16 v) {
+  return "BFloat16";
+}
+
+template <>
+std::string TypeName(Half v) {
+  return "Half";
+}
+
+template <>
+std::string TypeName(c10::complex v) {
+  return "c10::complex";
+}
+
+template <>
+std::string TypeName(c10::complex v) {
+  return "c10::complex";
+}
+
+
+template 
+class GemmTunableOp : public TunableOp, StreamTimer> {
+ public:
+  GemmTunableOp() {
+    this->RegisterOp(std::string("Default"), std::make_unique>());
+
+    auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
+
+#ifdef USE_ROCM
+    for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) {
+      this->RegisterOp(std::move(name), std::move(op));
+    }
+
+    if (validators.find("ROCM_VERSION") == validators.end()) {
+      std::string rocm_version = ROCM_BUILD_INFO;
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "ROCM_VERSION",
+          [rocm_version]() { return rocm_version; },
+          [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
+    }
+
+    if (validators.find("GCN_ARCH_NAME") == validators.end()) {
+      std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "GCN_ARCH_NAME",
+          [gcn_arch_name]() { return gcn_arch_name; },
+          [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
+    }
+
+    if (validators.find("ROCBLAS_VERSION") == validators.end()) {
+      std::string rocblas_version = c10::str(
+          XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
+          XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
+          XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
+          XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "ROCBLAS_VERSION",
+          [rocblas_version]() { return rocblas_version; },
+          [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
+    }
+#endif
+
+#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+    static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
+    if (env == nullptr || strcmp(env, "1") == 0) {
+      // disallow tuning of hipblaslt with c10::complex
+      if constexpr (
+          !std::is_same_v> &&
+          !std::is_same_v>) {
+        for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) {
+          this->RegisterOp(std::move(name), std::move(op));
+        }
+      }
+
+      if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
+        std::string hipblaslt_version = c10::str(
+            XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
+            XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
+            XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
+            XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
+        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+            "HIPBLASLT_VERSION",
+            [hipblaslt_version]() { return hipblaslt_version; },
+            [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
+      }
+    }
+#endif
+  }
+
+  std::string Signature() override {
+    return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
+  }
+};
+
+template 
+class GemmStridedBatchedTunableOp : public TunableOp, StreamTimer> {
+ public:
+  GemmStridedBatchedTunableOp() {
+    this->RegisterOp(std::string("Default"), std::make_unique>());
+
+    auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
+
+#ifdef USE_ROCM
+    for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) {
+      this->RegisterOp(std::move(name), std::move(op));
+    }
+
+    if (validators.find("ROCM_VERSION") == validators.end()) {
+      std::string rocm_version = ROCM_BUILD_INFO;
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "ROCM_VERSION",
+          [rocm_version]() { return rocm_version; },
+          [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
+    }
+
+    if (validators.find("GCN_ARCH_NAME") == validators.end()) {
+      std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "GCN_ARCH_NAME",
+          [gcn_arch_name]() { return gcn_arch_name; },
+          [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
+    }
+
+    if (validators.find("ROCBLAS_VERSION") == validators.end()) {
+      std::string rocblas_version = c10::str(
+          XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
+          XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
+          XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
+          XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
+      getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+          "ROCBLAS_VERSION",
+          [rocblas_version]() { return rocblas_version; },
+          [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
+    }
+#endif
+
+#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+    static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
+    if (env == nullptr || strcmp(env, "1") == 0) {
+      // disallow tuning of hipblaslt with c10::complex
+      if constexpr (
+          !std::is_same_v> &&
+          !std::is_same_v>) {
+        for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps()) {
+          this->RegisterOp(std::move(name), std::move(op));
+        }
+      }
+
+      if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
+        std::string hipblaslt_version = c10::str(
+            XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
+            XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
+            XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
+            XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
+        getTuningContext()->GetTuningResultsValidator().RegisterValidator(
+            "HIPBLASLT_VERSION",
+            [hipblaslt_version]() { return hipblaslt_version; },
+            [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
+      }
+    }
+#endif
+  }
+
+  std::string Signature() override {
+    return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
+  }
+};
+
+#undef XSTRINGIFY
+#undef STRINGIFY
+
+} // namespace at::cuda::tunable
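The repeated blocks above all follow one validator pattern: capture the current library version string, register a getter that reproduces it, and register a checker that compares a loaded value against it so stale tuning results are rejected. A schematic of that pattern is below; the key MY_LIB_VERSION is invented for illustration, and the checker's string parameter type is an assumption inferred from the usages above rather than read from the (elided) ValidateFunc signature.

#include <string>
#include <ATen/cuda/tunable/Tunable.h>

// Schematic registration mirroring the ROCBLAS_VERSION blocks above.
void RegisterMyLibValidator(const std::string& my_lib_version) {
  using at::cuda::tunable::OK;
  using at::cuda::tunable::FAIL;
  at::cuda::tunable::getTuningContext()->GetTuningResultsValidator().RegisterValidator(
      "MY_LIB_VERSION",
      [my_lib_version]() { return my_lib_version; },   // value written alongside results
      [my_lib_version](const std::string& loaded) {    // value checked when results are loaded
        return my_lib_version == loaded ? OK : FAIL;
      });
}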
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..3047a90efc78d2355e3b8b7c4a74d53fd1a6c644
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h
@@ -0,0 +1,242 @@
+// Original TunableOp is from onnxruntime.
+// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
+// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+// Adapting TunableOp into PyTorch
+// Copyright (c) Advanced Micro Devices, Inc.
+//
+#pragma once
+
+#include 
+#include 
+
+#ifndef _WIN32
+#include 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::cuda::tunable {
+
+template <typename ParamsT>
+class Callable {
+  public:
+    Callable() = default;
+    Callable(Callable&&) = default;
+    virtual ~Callable() = default;
+    virtual TuningStatus Call(const ParamsT*) {
+      return FAIL;
+    }
+    virtual TuningStatus IsSupported(const ParamsT* params) {
+      return Call(params);
+    }
+};
+
+template <typename ParamsT, typename TimerT>
+class TunableOp {
+  public:
+    TunableOp() = default;
+    TunableOp(TunableOp&&) = default;
+    virtual ~TunableOp() = default;
+
+    TuningStatus operator()(const ParamsT* params) {
+      ResultEntry result = ResultEntry::Null();
+      TuningContext* ctx = getTuningContext();
+      if (ctx->IsTunableOpEnabled()) {
+        auto& mgr = ctx->GetTuningResultsManager();
+        auto op_sig = Signature();
+        auto params_sig = params->Signature();
+        result = mgr.Lookup(op_sig, params_sig);
+        // If no previous tuning result was found, we do the tuning iff tuning is enabled
+        if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
+          result = FindFastest(params);
+          mgr.Add(op_sig, params_sig, result);
+        }
+      }
+      else {
+        result = ResultEntry::Default();
+      }
+      if (result == ResultEntry::Null()) {
+        TUNABLE_LOG("no result, using default");
+        result = ResultEntry::Default();
+      }
+      auto iter = ops_.find(result);
+      TORCH_CHECK(iter != ops_.end());
+      return iter->second->Call(params);
+    }
+
+    virtual std::string Signature() {
+      // According to C++17 standard https://wg21.link/n4659 section 15.7.4
+      // > if the operand of typeid refers to the
+      // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
+      // > or destructor’s class.
+      // So delay the op signature generation.
+      c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
+      return signature_;
+    }
+
+  protected:
+    void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
+      this->op_names_.emplace_back(name);
+      this->ops_.emplace(name, std::move(op));
+    }
+
+  private:
+    static void WarmUp(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
+      for (size_t i = 0; i < num_iter; i++) {
+        TORCH_CHECK(op->Call(param) == OK);
+      }
+    }
+
+    static double Profile(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
+      TimerT timer{};
+      timer.Start();
+      for (size_t i = 0; i < num_iter; i++) {
+        TORCH_CHECK(op->Call(param) == OK);
+      }
+      timer.End();
+      return timer.Duration() / num_iter;
+    }
+
+  protected:
+    bool IsNumericsCheckEnabled() {
+      static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
+      if (env != nullptr && strcmp(env, "0") == 0) {
+        return false;
+      }
+      return true;
+    }
+
+    virtual ResultEntry FindFastest(const ParamsT* params) {
+      TuningContext* ctx = getTuningContext();
+      auto op_sig = Signature();
+      auto params_sig = params->Signature();
+      TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
+      auto min_duration_ms = std::numeric_limits<double>::infinity();
+      std::string id_name = "Default";
+
+      // calculate a reference answer for the numerical check
+      ParamsT* reference_params = params->DeepCopy();
+      TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
+
+      // need a copy of params to reuse
+      ParamsT* reusable_params = params->DeepCopy();
+
+      for (size_t i = 0; i < op_names_.size(); i++) {
+        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
+        auto status = candidate->Call(reusable_params);
+        if (status != OK) {
+          TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+
+        if (IsNumericsCheckEnabled()) {
+          ParamsT* numerical_params = params->DeepCopy();
+          WarmUp(candidate, numerical_params, 1);
+          status = reference_params->NumericalCheck(numerical_params);
+          numerical_params->Delete();
+          if (status != OK) {
+            TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+            continue;
+          }
+        }
+
+        // collect a small profile
+        constexpr const int approx_num_iter = 3;
+        auto approx_duration = Profile(candidate, reusable_params, approx_num_iter);
+        // bail if too slow
+        if (approx_duration > 2 * min_duration_ms) {
+          TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+
+        // for warmup, did the user set a max duration, max iterations, or both?
+        double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
+        int max_warmup_iter = ctx->GetMaxWarmupIterations();
+        int warmup_iter = 1; // default
+        if (max_warmup_duration > 0) {
+          int duration_iters = max_warmup_duration / approx_duration;
+          if (max_warmup_iter > 0) {
+            warmup_iter = std::min(max_warmup_iter, duration_iters);
+          }
+          else {
+            warmup_iter = duration_iters;
+          }
+        }
+        else if (max_warmup_iter > 0) {
+          warmup_iter = max_warmup_iter;
+        }
+
+        // for tuning, did the user set a max duration, max iterations, or both?
+        double max_tuning_duration = ctx->GetMaxTuningDurationMs();
+        int max_tuning_iter = ctx->GetMaxTuningIterations();
+        int tuning_iter = 100; // default
+        if (max_tuning_duration > 0) {
+          int duration_iters = max_tuning_duration / approx_duration;
+          if (max_tuning_iter > 0) {
+            tuning_iter = std::min(max_tuning_iter, duration_iters);
+          }
+          else {
+            tuning_iter = duration_iters;
+          }
+        }
+        else if (max_tuning_iter > 0) {
+          tuning_iter = max_tuning_iter;
+        }
+
+        // do the full warmup followed by tuning
+        double warmup_ms = warmup_iter * approx_duration;
+        double tuning_ms = tuning_iter * approx_duration;
+        TUNABLE_LOG("├──tuning using "
+            "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
+            "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
+            "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
+        WarmUp(candidate, reusable_params, warmup_iter);
+        auto duration_ms = Profile(candidate, reusable_params, tuning_iter);
+        if (duration_ms < min_duration_ms) {
+          TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
+          min_duration_ms = duration_ms;
+          id_name = op_names_[i];
+        }
+      }
+
+      reusable_params->Delete();
+      reference_params->Delete();
+
+      TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
+      return ResultEntry(id_name, min_duration_ms);
+    }
+
+  private:
+    std::string CreateSignature() {
+#ifndef _WIN32
+      const auto* name = typeid(*this).name();
+      char buf[256];
+      size_t buf_len = 256;
+      abi::__cxa_demangle(name, buf, &buf_len, nullptr);
+      buf[255] = '\0';
+      return buf;
+#else
+      return typeid(*this).name();
+#endif
+    }
+
+    mutable c10::once_flag signature_init_once_;
+    std::string signature_;
+
+    std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
+    std::vector<std::string> op_names_;
+};
+
+struct OpParams {
+  OpParams() {}
+  virtual ~OpParams() = default;
+  virtual std::string Signature() const = 0;
+};
+
+} // namespace at::cuda::tunable
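To make the iteration-budget logic in FindFastest concrete: with an approximate per-call duration of 0.5 ms, a 30 ms tuning budget, and a 100-iteration cap, a candidate gets min(100, 30 / 0.5) = 60 timed iterations; with only the iteration cap set it gets the cap, and with neither set it keeps the default of 100. A standalone restatement of that arithmetic (the numbers are hypothetical):

#include <algorithm>
#include <iostream>

// Mirrors the tuning_iter computation in TunableOp::FindFastest.
static int TuningIterations(double approx_duration_ms,
                            double max_duration_ms,
                            int max_iterations) {
  int iters = 100;  // default, as in the header
  if (max_duration_ms > 0) {
    int duration_iters = static_cast<int>(max_duration_ms / approx_duration_ms);
    iters = (max_iterations > 0) ? std::min(max_iterations, duration_iters)
                                 : duration_iters;
  } else if (max_iterations > 0) {
    iters = max_iterations;
  }
  return iters;
}

int main() {
  std::cout << TuningIterations(0.5, 30.0, 100) << "\n";  // prints 60
  std::cout << TuningIterations(0.5, 0.0, 25)  << "\n";   // prints 25
  std::cout << TuningIterations(0.5, 0.0, 0)   << "\n";   // prints 100
}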
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h
new file mode 100644
index 0000000000000000000000000000000000000000..96d457601eb0fe115fb7816e39c43594a47414da
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Descriptors.h
@@ -0,0 +1,391 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907
+#define USE_CUDNN_RNN_V8_API
+#endif
+
+namespace at { namespace native {
+
+std::string cudnnTypeToString(cudnnDataType_t dtype);
+
+// TODO: Add constructors for all of the descriptors
+
+inline int dataSize(cudnnDataType_t dataType)
+{
+  switch (dataType) {
+#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
+    case CUDNN_DATA_BFLOAT16:
+#endif
+    case CUDNN_DATA_HALF: return 2;
+    case CUDNN_DATA_FLOAT: return 4;
+    default: return 8;
+  }
+}
+
+// The stride for a size-1 dimension is not uniquely determined; in
+// fact, it can be anything you want, because the fact that the
+// tensor is size 1 at this dimension means that you will never actually
+// try advancing your pointer by this stride.
+//
+// However, CuDNN has a much more stringent requirement on strides:
+// if you are passing a contiguous input, it better be the case
+// that the stride for dim i is the product of the sizes of dims
+// i+1 to the end.  This stride is indeed uniquely determined.  This
+// function modifies 'stride' in place so this invariant holds.
+template <typename T>
+static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
+  int64_t z = 1;
+  int index = 0;
+  std::vector permutation(dim);
+
+  if (nhwc) {
+    permutation[index++] = 1;
+  }
+  for (int d = dim-1; d > 1; d--) {
+    permutation[index++] = d;
+  }
+  if (!nhwc) {
+    permutation[index++] = 1;
+  }
+  permutation[index++] = 0;
+  for (int d : permutation) {
+    if (size[d] == 1) {
+      stride[d] = z;
+    } else {
+      z *= size[d];
+    }
+  }
+}
+
+template 
+struct DescriptorDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      AT_CUDNN_CHECK(dtor(x));
+    }
+  }
+};
+
+// A generic class for wrapping cuDNN descriptor types.  All you need
+// is to give the underlying type the Descriptor_t points to (usually,
+// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct),
+// the constructor and the destructor.  Subclasses are responsible
+// for defining a set() function to actually set the descriptor.
+//
+// Descriptors default construct to a nullptr, and have a descriptor
+// initialized the first time you call set() or any other initializing
+// function.
+template 
+class TORCH_CUDA_CPP_API Descriptor {
+ public:
+  // TODO: Figure out why const-correctness doesn't work here
+
+  // Use desc() to access the underlying descriptor pointer in
+  // a read-only fashion.  Most client code should use this.
+  // If the descriptor was never initialized, this will return
+  // nullptr.
+  T* desc() const { return desc_.get(); }
+  T* desc() { return desc_.get(); }
+
+  // Use mut_desc() to access the underlying descriptor pointer
+  // if you intend to modify what it points to (e.g., using
+  // cudnnSetFooDescriptor).  This will ensure that the descriptor
+  // is initialized.  Code in this file will use this function.
+  T* mut_desc() { init(); return desc_.get(); }
+protected:
+  void init() {
+    if (desc_ == nullptr) {
+      T* raw_desc;
+      AT_CUDNN_CHECK(ctor(&raw_desc));
+      desc_.reset(raw_desc);
+    }
+  }
+private:
+  std::unique_ptr> desc_;
+};
+
+class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor<
+                                       cudnnRNNDataStruct,
+                                       &cudnnCreateRNNDataDescriptor,
+                                       &cudnnDestroyRNNDataDescriptor> {
+public:
+  void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray);
+private:
+  void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) {
+    AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(mut_desc(), dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, NULL));
+  }
+};
+
+class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
+                                               cudnnTensorStruct,
+                                               &cudnnCreateTensorDescriptor,
+                                               &cudnnDestroyTensorDescriptor> {
+ public:
+  TensorDescriptor() = default;
+  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
+    set(t, pad);
+  }
+
+  // Note [CuDNN broadcast padding]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // pad specifies the minimum dimensionality of the tensor descriptor
+  // we produce (it doesn't have anything to do with, e.g., convolution
+  // padding).  If 't' is lower-dimensional than 'pad', the remaining
+  // dimensions (on the right) are padded with ones.  This doesn't
+  // affect the underlying data layout.  This is particularly useful for
+  // dealing with a peculiarity of the CuDNN API, which is that broadcasting in CuDNN is
+  // done in two steps: first, the client code is expected to pad out
+  // (the dimensions) input tensors to be the same dimension as the
+  // target broadcast, and then second, CuDNN takes care of actually
+  // broadcasting size 1 dimensions.
+
+  void set(const at::Tensor &t, size_t pad = 0);
+  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
+  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
+
+  void print();
+
+private:
+  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);
+
+  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    fixSizeOneDimStride(dim, size, stride, nhwc);
+    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
+  }
+};
+
+std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
+
+class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
+                                               cudnnFilterStruct,
+                                               &cudnnCreateFilterDescriptor,
+                                               &cudnnDestroyFilterDescriptor> {
+ public:
+  void set(const at::Tensor &t, int64_t pad = 0) {
+    set(t, at::MemoryFormat::Contiguous, pad);
+  }
+
+  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
+
+  void print();
+private:
+  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
+    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
+  }
+};
+
+std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
+
+struct TORCH_CUDA_CPP_API ConvolutionDescriptor
+    : public Descriptor<
+          cudnnConvolutionStruct,
+          &cudnnCreateConvolutionDescriptor,
+          &cudnnDestroyConvolutionDescriptor> {
+  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) {
+    cudnnDataType_t mathType = dataType;
+    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
+    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
+                                          CUDNN_CROSS_CORRELATION, mathType));
+    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
+    // See Note [behavior of cudnnFind and cudnnGet]
+    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
+    if(dataType == CUDNN_DATA_HALF) {
+      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
+    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
+      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
+    }
+  }
+};
+
+struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
+    : public Descriptor<
+          cudnnSpatialTransformerStruct,
+          &cudnnCreateSpatialTransformerDescriptor,
+          &cudnnDestroySpatialTransformerDescriptor> {
+  void set(cudnnDataType_t dataType, int dim, int* size) {
+    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
+  }
+};
+
+struct TORCH_CUDA_CPP_API DropoutDescriptor
+    : public Descriptor<
+          cudnnDropoutStruct,
+          &cudnnCreateDropoutDescriptor,
+          &cudnnDestroyDropoutDescriptor> {
+  at::Tensor state;
+
+  // Initialize a dropout descriptor's RNG state.
+  // WARNING: This function is very expensive, avoid calling this function!
+  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
+    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
+    size_t state_size;
+    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
+    AT_ASSERT(options.device().type() == kCUDA);
+    AT_ASSERT(options.dtype() == kByte);
+    state = at::empty({static_cast(state_size)}, options);
+    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
+  }
+
+  // Restore a dropout descriptor given a dropout probability and existing RNG state.
+  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
+    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
+    state = state_;
+    void *state_ptr = state.data_ptr();
+    size_t state_size = state.size(0);
+    // NB: The seed doesn't actually matter, so we give a dummy value
+    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
+  }
+
+  // Restore a dropout descriptor corresponding to no dropout
+  void set_no_dropout(cudnnHandle_t handle) {
+    // NB: seed doesn't matter when dropout = 0, because no random number
+    // initialization actually takes place when there is no dropout.
+    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
+    // dropout == 0
+    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
+  }
+};
+
+struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
+                                             cudnnRNNStruct,
+                                             &cudnnCreateRNNDescriptor,
+                                             &cudnnDestroyRNNDescriptor> {
+  DropoutDescriptor dropout_desc_;
+  void set(cudnnHandle_t handle,
+#ifdef USE_CUDNN_RNN_V8_API
+       int input_size,
+       bool packed,
+#endif
+       int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
+           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
+           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
+    dropout_desc_ = std::move(dropout_desc);
+#ifndef USE_CUDNN_RNN_V8_API
+    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
+          handle,
+          mut_desc(),
+          hidden_size,
+          num_layers,
+          dropout_desc_.desc(),
+          input_mode,
+          bidirectional,
+          mode,
+          algo,
+          datatype));
+    if (proj_size != 0) {
+      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
+            handle,
+            /*rnnDesc=*/mut_desc(),
+            /*recProjSize=*/proj_size,
+            /*outProjSize=*/0));
+    }
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    if (prop->major >= 7) {
+      if (input_type == CUDNN_DATA_HALF) {
+        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
+      }
+      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
+        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
+      }
+      else {
+        // Technically, since this is the default, it's not necessary to
+        // set it explicitly.
+        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
+      }
+    }
+#else
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    auto math_type = CUDNN_DEFAULT_MATH;
+    if (prop->major >= 7) {
+      if (input_type == CUDNN_DATA_HALF) {
+        math_type = CUDNN_TENSOR_OP_MATH;
+      } else if (!allow_tf32) {
+        math_type = CUDNN_FMA_MATH;
+      }
+    }
+    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8(
+          mut_desc(),
+          algo,
+          mode,
+          CUDNN_RNN_DOUBLE_BIAS,
+          bidirectional,
+          input_mode,
+          input_type,
+          datatype,
+          math_type,
+          input_size,
+          hidden_size,
+          proj_size ? proj_size : hidden_size,
+          num_layers,
+          dropout_desc_.desc(),
+          packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED));
+#endif
+  }
+};
+
+struct TORCH_CUDA_CPP_API CTCLossDescriptor
+    : public Descriptor<
+          cudnnCTCLossStruct,
+          &cudnnCreateCTCLossDescriptor,
+          &cudnnDestroyCTCLossDescriptor> {
+  void set(cudnnDataType_t datatype) {
+    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
+  }
+  void setEx(
+      cudnnDataType_t datatype,
+      cudnnLossNormalizationMode_t normMode,
+      cudnnNanPropagation_t gradMode) {
+    AT_CUDNN_CHECK(
+        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
+  }
+};
+
+struct TORCH_CUDA_CPP_API ActivationDescriptor
+    : public Descriptor<
+          cudnnActivationStruct,
+          &cudnnCreateActivationDescriptor,
+          &cudnnDestroyActivationDescriptor> {
+  void set(cudnnActivationMode_t mode) {
+    AT_ASSERT(
+        mode == CUDNN_ACTIVATION_RELU,
+        "TODO: support more cuDNN activation modes");
+    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
+        mut_desc(),
+        mode,
+        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
+        std::numeric_limits::max()));
+  }
+};
+
+union Constant
+{
+  float f;
+  double d;
+  Constant(cudnnDataType_t dataType, double value) {
+    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
+      f = static_cast(value);
+    } else {
+      d = value;
+    }
+  }
+};
+
+}}  // namespace
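As a worked example of fixSizeOneDimStride above (NCHW case, nhwc = false): for sizes {2, 1, 3, 4} the dimensions are visited in the order 3, 2, 1, 0; the running product z passes 4 and then 12, so the size-1 dimension at index 1 gets its stride rewritten to 12 = 3 * 4, exactly the product of the sizes to its right, while the other strides are left alone. A tiny self-contained check of that behaviour (a re-implementation of the non-NHWC branch, not the header's template itself):

#include <cassert>
#include <vector>

// Re-implementation of the non-NHWC branch of fixSizeOneDimStride.
static void FixSizeOneDimStride(int dim, const long* size, long* stride) {
  long z = 1;
  std::vector<int> permutation;
  for (int d = dim - 1; d > 1; d--) permutation.push_back(d);
  permutation.push_back(1);
  permutation.push_back(0);
  for (int d : permutation) {
    if (size[d] == 1) stride[d] = z;  // only size-1 dims are rewritten
    else z *= size[d];
  }
}

int main() {
  long size[4]   = {2, 1, 3, 4};
  long stride[4] = {12, 999, 4, 1};  // the size-1 dim's stride is arbitrary on input
  FixSizeOneDimStride(4, size, stride);
  assert(stride[1] == 12);  // rewritten to 3 * 4, the product of the sizes to its right
  return 0;
}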
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Exceptions.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Exceptions.h
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handle.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..3415d86dd944d4b0451ea4d3586cbe807c8f32eb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handle.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+
+TORCH_CUDA_CPP_API cudnnHandle_t getCudnnHandle();
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handles.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handles.h
new file mode 100644
index 0000000000000000000000000000000000000000..65b5d4454879ad165c8e002fc5df4c400da9303a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Handles.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Types.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Types.h
new file mode 100644
index 0000000000000000000000000000000000000000..31e39404036b946c7db6b8d9f14706905a18cf30
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Types.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+
+TORCH_CUDA_CPP_API cudnnDataType_t
+getCudnnDataTypeFromScalarType(const at::ScalarType dtype);
+cudnnDataType_t getCudnnDataType(const at::Tensor& tensor);
+
+int64_t cudnn_version();
+
+}}  // namespace at::cudnn
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Utils.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..c132840385d8e7bf2de5e1b2afc9cf05a8a2b9ed
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/Utils.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at { namespace native {
+
+// cuDNN has a buggy check for a tensor being contiguous (that is, it does
+// not ignore the stride for a dimension that is equal to 0).  This function
+// makes tensors that have a zero stride contiguous by calling contiguous(),
+// which produces the layout cuDNN expects.
+inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
+  for (auto s : t.strides()) {
+    if (s == 0) return t.contiguous();
+  }
+  return t;
+}
+
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbea50d26fd01f5ee211ce0128ac220adafdcb5d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/cudnn/cudnn-wrapper.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include 
+
+#define STRINGIFY(x) #x
+#define STRING(x) STRINGIFY(x)
+
+#if CUDNN_MAJOR < 6
+#pragma message ("CuDNN v" STRING(CUDNN_MAJOR) " found, but need at least CuDNN v6. You can get the latest version of CuDNN from https://developer.nvidia.com/cudnn or disable CuDNN with USE_CUDNN=0")
+#pragma message "We strongly encourage you to move to 6.0 and above."
+#pragma message "This message is intended to annoy you enough to update."
+#endif
+
+#undef STRINGIFY
+#undef STRING
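The STRINGIFY/STRING pair is the standard two-step expansion trick for turning a macro's value, rather than its name, into a string literal, which is how CUDNN_MAJOR ends up inside the message text. A minimal illustration with a stand-in macro:

#define STRINGIFY(x) #x
#define STRING(x) STRINGIFY(x)
#define EXAMPLE_MAJOR 8                      // stand-in for CUDNN_MAJOR
static_assert(sizeof(STRING(EXAMPLE_MAJOR)) == 2, "expands to the string \"8\"");
// Without the extra indirection, STRINGIFY(EXAMPLE_MAJOR) would yield "EXAMPLE_MAJOR".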
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..d013b91e9d50f54e518d29bff44aa34e3d6cf903
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/AcceleratorHooksInterface.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <c10/core/Device.h>
+
+namespace at {
+
+// AcceleratorHooksInterface is a shared interface provided by all
+// accelerators to allow generic code.
+// This interface is hook-based as it corresponds to all the functions
+// that are going to be called in a generic way from the CPU code.
+
+struct TORCH_API AcceleratorHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~AcceleratorHooksInterface() = default;
+
+  // Whether the device at device_index is fully initialized or not.
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
+};
+
+} // namespace at
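Since the interface above is the minimal contract every accelerator backend exposes to generic code, a tiny implementation looks like the following. The class and its answer are hypothetical; a real backend would consult its runtime rather than hard-code a result:

#include <ATen/detail/AcceleratorHooksInterface.h>

struct MyDeviceHooks final : at::AcceleratorHooksInterface {
  bool hasPrimaryContext(at::DeviceIndex device_index) const override {
    // Pretend only device 0 has been initialized so far.
    return device_index == 0;
  }
};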
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce9d84e62d0389c24aac9852aa315e5cb661844d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/CUDAHooksInterface.h
@@ -0,0 +1,201 @@
+#pragma once
+
+#include <c10/core/Allocator.h>
+#include <c10/util/Exception.h>
+#include <c10/util/Registry.h>
+
+#include <ATen/detail/AcceleratorHooksInterface.h>
+
+// Forward-declares at::Generator and at::cuda::NVRTC
+namespace at {
+struct Generator;
+namespace cuda {
+struct NVRTC;
+} // namespace cuda
+} // namespace at
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+#ifdef _MSC_VER
+constexpr const char* CUDA_HELP =
+  "PyTorch splits its backend into two shared libraries: a CPU library "
+  "and a CUDA library; this error has occurred because you are trying "
+  "to use some CUDA functionality, but the CUDA library has not been "
+  "loaded by the dynamic linker for some reason.  The CUDA library MUST "
+  "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
+  "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
+  "in your link arguments; many dynamic linkers will delete dynamic library "
+  "dependencies if you don't depend on any of their symbols.  You can check "
+  "if this has occurred by using link on your binary to see if there is a "
+  "dependency on *_cuda.dll library.";
+#else
+constexpr const char* CUDA_HELP =
+  "PyTorch splits its backend into two shared libraries: a CPU library "
+  "and a CUDA library; this error has occurred because you are trying "
+  "to use some CUDA functionality, but the CUDA library has not been "
+  "loaded by the dynamic linker for some reason.  The CUDA library MUST "
+  "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
+  "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many "
+  "dynamic linkers will delete dynamic library dependencies if you don't "
+  "depend on any of their symbols.  You can check if this has occurred by "
+  "using ldd on your binary to see if there is a dependency on *_cuda.so "
+  "library.";
+#endif
+
+// The CUDAHooksInterface is an omnibus interface for any CUDA functionality
+// which we may want to call into from CPU code (and thus must be dynamically
+// dispatched, to allow for separate compilation of CUDA code).  How do I
+// decide if a function should live in this class?  There are two tests:
+//
+//  1. Does the *implementation* of this function require linking against
+//     CUDA libraries?
+//
+//  2. Is this function *called* from non-CUDA ATen code?
+//
+// (2) should filter out many ostensible use-cases, since many times a CUDA
+// function provided by ATen is only really ever used by actual CUDA code.
+//
+// TODO: Consider putting the stub definitions in another class, so that one
+// never forgets to implement each virtual function in the real implementation
+// in CUDAHooks.  This probably doesn't buy us much though.
+struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~CUDAHooksInterface() override = default;
+
+  // Initialize THCState and, transitively, the CUDA state
+  virtual void initCUDA() const {
+    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual const Generator& getDefaultCUDAGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
+    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual Device getDeviceFromPtr(void* /*data*/) const {
+    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual bool isPinnedPtr(const void* /*data*/) const {
+    return false;
+  }
+
+  virtual bool hasCUDA() const {
+    return false;
+  }
+
+  virtual bool hasCUDART() const {
+    return false;
+  }
+
+  virtual bool hasMAGMA() const {
+    return false;
+  }
+
+  virtual bool hasCuDNN() const {
+    return false;
+  }
+
+  virtual bool hasCuSOLVER() const {
+    return false;
+  }
+
+  virtual bool hasROCM() const {
+    return false;
+  }
+
+  virtual const at::cuda::NVRTC& nvrtc() const {
+    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
+  }
+
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual DeviceIndex current_device() const {
+    return -1;
+  }
+
+  virtual Allocator* getPinnedMemoryAllocator() const {
+    TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
+  }
+
+  virtual Allocator* getCUDADeviceAllocator() const {
+    TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. ", CUDA_HELP);
+  }
+
+  virtual bool compiledWithCuDNN() const {
+    return false;
+  }
+
+  virtual bool compiledWithMIOpen() const {
+    return false;
+  }
+
+  virtual bool supportsDilatedConvolutionWithCuDNN() const {
+    return false;
+  }
+
+  virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
+    return false;
+  }
+
+  virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const {
+    return false;
+  }
+
+  virtual long versionCuDNN() const {
+    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual long versionCUDART() const {
+    TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual double batchnormMinEpsilonCuDNN() const {
+    TORCH_CHECK(false,
+        "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
+  }
+
+  virtual int getNumGPUs() const {
+    return 0;
+  }
+
+  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API CUDAHooksArgs {};
+
+TORCH_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs);
+#define REGISTER_CUDA_HOOKS(clsname) \
+  C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const CUDAHooksInterface& getCUDAHooks();
+} // namespace detail
+} // namespace at
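The registry declared just above is what lets the separately compiled CUDA library plug its implementation in at load time. A stripped-down sketch of that pattern (the class here is hypothetical; the genuine implementation is ATen's CUDAHooks):

#include <ATen/detail/CUDAHooksInterface.h>

namespace my_backend {

// Only the queries we care about are overridden; every other virtual keeps
// the error-raising default from CUDAHooksInterface.
struct MyCUDAHooks : at::CUDAHooksInterface {
  MyCUDAHooks(at::CUDAHooksArgs) {}
  bool hasCUDA() const override { return true; }
  int getNumGPUs() const override { return 1; }
};

// Makes the implementation discoverable through the registry.
REGISTER_CUDA_HOOKS(MyCUDAHooks);

} // namespace my_backend

// CPU-only code then dispatches through the interface without linking CUDA:
//   if (at::detail::getCUDAHooks().hasCUDA()) { /* ... */ }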
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h
new file mode 100644
index 0000000000000000000000000000000000000000..51fe0b8320a38aad0cb7c253f701674051d3a496
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/FunctionTraits.h
@@ -0,0 +1,102 @@
+#pragma once
+
+#include <tuple>
+
+// Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda
+
+// Fallback, anything with an operator()
+template 
+struct function_traits : public function_traits {
+};
+
+// Pointers to class members that are themselves functors.
+// For example, in the following code:
+// template 
+// struct S {
+//     func_t f;
+// };
+// template 
+// S make_s(func_t f) {
+//     return S { .f = f };
+// }
+//
+// auto s = make_s([] (int, float) -> double { /* ... */ });
+//
+// function_traits traits;
+template 
+struct function_traits : public function_traits {
+};
+
+// Const class member functions
+template 
+struct function_traits : public function_traits {
+};
+
+// Reference types
+template 
+struct function_traits : public function_traits {};
+template 
+struct function_traits : public function_traits {};
+
+// Free functions
+template 
+struct function_traits {
+  // arity is the number of arguments.
+  enum { arity = sizeof...(Args) };
+
+  typedef std::tuple ArgsTuple;
+  typedef ReturnType result_type;
+
+  template 
+  struct arg
+  {
+      typedef typename std::tuple_element>::type type;
+      // the i-th argument is equivalent to the i-th tuple element of a tuple
+      // composed of those arguments.
+  };
+};
+
+template 
+struct nullary_function_traits {
+  using traits = function_traits;
+  using result_type = typename traits::result_type;
+};
+
+template 
+struct unary_function_traits {
+  using traits = function_traits;
+  using result_type = typename traits::result_type;
+  using arg1_t = typename traits::template arg<0>::type;
+};
+
+template 
+struct binary_function_traits {
+  using traits = function_traits;
+  using result_type = typename traits::result_type;
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
+};
+
+
+// Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType
+template 
+struct invoke_traits : public function_traits{
+};
+
+template 
+struct invoke_traits : public invoke_traits{
+};
+
+template 
+struct invoke_traits : public invoke_traits{
+};
+
+template 
+struct invoke_traits :
+  public function_traits {
+};
+
+template 
+struct invoke_traits :
+  public function_traits {
+};
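To see what the traits above provide in practice, here is a short compile-time sketch (the lambda is arbitrary):

#include <ATen/detail/FunctionTraits.h>
#include <type_traits>

void traits_example() {
  auto f = [](int a, float b) -> double { return a + b; };
  using traits = function_traits<decltype(f)>;

  static_assert(traits::arity == 2, "two parameters");
  static_assert(std::is_same_v<traits::result_type, double>, "returns double");
  static_assert(std::is_same_v<traits::arg<0>::type, int>, "first parameter is int");
  static_assert(std::is_same_v<traits::arg<1>::type, float>, "second parameter is float");
}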
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..f9866a872b67849917903de7399061d2879782c4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/HIPHooksInterface.h
@@ -0,0 +1,70 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+namespace at {
+class Context;
+}
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+// The HIPHooksInterface is an omnibus interface for any HIP functionality
+// which we may want to call into from CPU code (and thus must be dynamically
+// dispatched, to allow for separate compilation of HIP code).  See
+// CUDAHooksInterface for more detailed motivation.
+struct TORCH_API HIPHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~HIPHooksInterface() = default;
+
+  // Initialize the HIP library state
+  virtual void initHIP() const {
+    AT_ERROR("Cannot initialize HIP without ATen_hip library.");
+  }
+
+  virtual std::unique_ptr initHIPGenerator(Context*) const {
+    AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
+  }
+
+  virtual bool hasHIP() const {
+    return false;
+  }
+
+  virtual c10::DeviceIndex current_device() const {
+    return -1;
+  }
+
+  virtual Allocator* getPinnedMemoryAllocator() const {
+    AT_ERROR("Pinned memory requires HIP.");
+  }
+
+  virtual void registerHIPTypes(Context*) const {
+    AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
+  }
+
+  virtual int getNumGPUs() const {
+    return 0;
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API HIPHooksArgs {};
+
+TORCH_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs);
+#define REGISTER_HIP_HOOKS(clsname) \
+  C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const HIPHooksInterface& getHIPHooks();
+
+} // namespace detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..daa89c733779d99fd49f384de638df2ce569b728
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/IPUHooksInterface.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+struct TORCH_API IPUHooksInterface {
+  virtual ~IPUHooksInterface() = default;
+
+  virtual const Generator& getDefaultIPUGenerator(
+      DeviceIndex device_index = -1) const {
+    AT_ERROR(
+        "Cannot get the default IPU generator: the IPU backend is not "
+        "available.");
+  }
+
+  virtual Generator newIPUGenerator(DeviceIndex device_index = -1) const {
+    AT_ERROR(
+        "Cannot create a new IPU generator: the IPU backend is not available.");
+  }
+};
+
+struct TORCH_API IPUHooksArgs {};
+
+TORCH_DECLARE_REGISTRY(IPUHooksRegistry, IPUHooksInterface, IPUHooksArgs);
+#define REGISTER_IPU_HOOKS(clsname) \
+  C10_REGISTER_CLASS(IPUHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const IPUHooksInterface& getIPUHooks();
+} // namespace detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..e81b590476e3b5ed0d99ec99006c2ed8e9d62572
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/MPSHooksInterface.h
@@ -0,0 +1,106 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+
+struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
+  // This macro raises an error if an MPSHooks function is called but the
+  // MPS backend is not present.
+  #define FAIL_MPSHOOKS_FUNC(func) \
+    TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
+
+  virtual ~MPSHooksInterface() override = default;
+
+  // Initialize the MPS library state
+  virtual void initMPS() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual bool hasMPS() const {
+    return false;
+  }
+  virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual const Generator& getDefaultMPSGenerator() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual Allocator* getMPSDeviceAllocator() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void deviceSynchronize() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void commitStream() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void* getCommandBuffer() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void* getDispatchQueue() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void emptyCache() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual size_t getCurrentAllocatedMemory() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual size_t getDriverAllocatedMemory() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void setMemoryFraction(double /*ratio*/) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void profilerStopTrace() const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual uint32_t acquireEvent(bool enable_timing) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void releaseEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void recordEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void waitForEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void synchronizeEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual bool queryEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  #undef FAIL_MPSHOOKS_FUNC
+};
+
+struct TORCH_API MPSHooksArgs {};
+
+TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
+#define REGISTER_MPS_HOOKS(clsname) \
+  C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const MPSHooksInterface& getMPSHooks();
+
+} // namespace detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..43c110777cd0cded229b28587f3feb0dcf1b584c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include 
+
+#include 
+
+#include 
+
+#include 
+
+namespace at {
+class Context;
+}
+
+namespace at {
+
+constexpr const char* MTIA_HELP =
+    "The MTIA backend requires MTIA extension for PyTorch;"
+    "this error has occurred because you are trying "
+    "to use some MTIA's functionality without MTIA extension included.";
+
+struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
+  virtual ~MTIAHooksInterface() override = default;
+
+  virtual void initMTIA() const {
+    TORCH_CHECK(
+        false,
+        "Cannot initialize MTIA without MTIA Extension for PyTorch.",
+        MTIA_HELP);
+  }
+
+  virtual bool hasMTIA() const {
+    return false;
+  }
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(
+        false,
+        "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
+        MTIA_HELP);
+  }
+
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    TORCH_CHECK(
+        false,
+        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
+        MTIA_HELP);
+  }
+
+};
+
+struct TORCH_API MTIAHooksArgs {};
+
+C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
+#define REGISTER_MTIA_HOOKS(clsname) \
+  C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const MTIAHooksInterface& getMTIAHooks();
+} // namespace detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/ORTHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/ORTHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..af22f687c13d6d81365dfbc5def7739165339f8b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/ORTHooksInterface.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include 
+#include 
+
+constexpr const char* ORT_HELP =
+  " You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
+  "The 'torch_ort' module is provided by the ONNX Runtime itself "
+  "(https://onnxruntime.ai).";
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+struct TORCH_API ORTHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~ORTHooksInterface() = default;
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API ORTHooksArgs {};
+
+TORCH_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
+#define REGISTER_ORT_HOOKS(clsname) \
+  C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const ORTHooksInterface& getORTHooks();
+} // namespace detail
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..330677fe95df38d471e2e106adae0aa22bb64034
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+namespace at {
+
+struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
+  virtual ~PrivateUse1HooksInterface() override = default;
+  virtual const at::Generator& getDefaultGenerator(
+      c10::DeviceIndex device_index) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
+  }
+
+  virtual at::Device getDeviceFromPtr(void* data) const {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
+  }
+
+  virtual Allocator* getPinnedMemoryAllocator() const {
+    TORCH_CHECK(
+        false,
+        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
+  }
+
+  virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
+  }
+
+  virtual void initPrivateUse1() const {}
+  virtual void resizePrivateUse1Bytes(const c10::Storage &storage, size_t newsize) const {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
+  }
+};
+
+struct TORCH_API PrivateUse1HooksArgs {};
+
+TORCH_API void RegisterPrivateUse1HooksInterface(
+    at::PrivateUse1HooksInterface* hook_);
+
+TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
+
+TORCH_API bool isPrivateUse1HooksRegistered();
+
+namespace detail {
+
+TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
+
+} // namespace detail
+
+} // namespace at
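RegisterPrivateUse1HooksInterface above is the entry point an out-of-tree backend calls once at startup; a hedged sketch of that wiring follows (the class and its answers are placeholders):

#include <ATen/detail/PrivateUse1HooksInterface.h>

struct MyPrivateUse1Hooks : at::PrivateUse1HooksInterface {
  bool hasPrimaryContext(at::DeviceIndex device_index) const override {
    return true;  // placeholder: report every device as initialized
  }
};

static MyPrivateUse1Hooks my_hooks;

void register_my_backend() {
  // Must run before any op touches the PrivateUse1 device type.
  at::RegisterPrivateUse1HooksInterface(&my_hooks);
}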
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..44b31b2348dd96ad3a0f4903cc96d046ea900e2a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h
@@ -0,0 +1,80 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+constexpr const char* XPU_HELP =
+    "The XPU backend requires Intel Extension for Pytorch;"
+    "this error has occurred because you are trying "
+    "to use some XPU's functionality, but the Intel Extension for Pytorch has not been "
+    "loaded for some reason. The Intel Extension for Pytorch MUST "
+    "be loaded, EVEN IF you don't directly use any symbols from that!";
+
+struct TORCH_API XPUHooksInterface {
+  virtual ~XPUHooksInterface() {}
+
+  virtual void initXPU() const {
+    TORCH_CHECK(
+        false,
+        "Cannot initialize XPU without Intel Extension for Pytorch.",
+        XPU_HELP);
+  }
+
+  virtual bool hasXPU() const {
+    return false;
+  }
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(
+        false,
+        "Cannot query detailed XPU version without Intel Extension for Pytorch. ",
+        XPU_HELP);
+  }
+
+  virtual int32_t getGlobalIdxFromDevice(const Device& device) const {
+    TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
+  }
+
+  virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
+    TORCH_CHECK(false, "Cannot get XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
+  }
+
+  virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
+    TORCH_CHECK(false, "Cannot get default XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
+  }
+
+  virtual DeviceIndex getNumGPUs() const {
+    return 0;
+  }
+
+  virtual DeviceIndex current_device() const {
+    TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
+  }
+
+  virtual Device getDeviceFromPtr(void* /*data*/) const {
+    TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
+  }
+
+  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
+  }
+};
+
+struct TORCH_API XPUHooksArgs {};
+
+C10_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs);
+#define REGISTER_XPU_HOOKS(clsname) \
+  C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const XPUHooksInterface& getXPUHooks();
+} // namespace detail
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/div_rtn.h b/MLPY/Lib/site-packages/torch/include/ATen/div_rtn.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a6d088b798c2ac96e58107db224a35ba5c9e8c8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/div_rtn.h
@@ -0,0 +1,11 @@
+#pragma once
+
+// Integer division rounding to -Infinity
+template <typename T>
+static inline T div_rtn(T x, T y) {
+  T q = x / y;
+  T r = x % y;
+  if ((r != 0) && ((r < 0) != (y < 0)))
+    --q;
+  return q;
+}
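The difference from built-in integer division only shows up when the operands have opposite signs: C++ truncates toward zero, while div_rtn rounds toward negative infinity. A standalone check:

#include <cassert>
#include <ATen/div_rtn.h>

int main() {
  assert(-7 / 2 == -3);               // built-in division truncates toward zero
  assert(div_rtn<int>(-7, 2) == -4);  // div_rtn rounds toward -infinity
  assert(div_rtn<int>(7, 2) == 3);    // identical when the signs agree
  return 0;
}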
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/dlpack.h b/MLPY/Lib/site-packages/torch/include/ATen/dlpack.h
new file mode 100644
index 0000000000000000000000000000000000000000..c5a3a5a0143123038d7c0a3ed43fdaffb0eae359
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/dlpack.h
@@ -0,0 +1,232 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file dlpack.h
+ * \brief The common header of DLPack.
+ */
+#ifndef DLPACK_DLPACK_H_
+#define DLPACK_DLPACK_H_
+
+/**
+ * \brief Compatibility with C++
+ */
+#ifdef __cplusplus
+#define DLPACK_EXTERN_C extern "C"
+#else
+#define DLPACK_EXTERN_C
+#endif
+
+/*! \brief The current version of dlpack */
+#define DLPACK_VERSION 80
+
+/*! \brief The current ABI version of dlpack */
+#define DLPACK_ABI_VERSION 1
+
+/*! \brief DLPACK_DLL prefix for windows */
+#ifdef _WIN32
+#ifdef DLPACK_EXPORTS
+#define DLPACK_DLL __declspec(dllexport)
+#else
+#define DLPACK_DLL __declspec(dllimport)
+#endif
+#else
+#define DLPACK_DLL
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*!
+ * \brief The device type in DLDevice.
+ */
+#ifdef __cplusplus
+typedef enum : int32_t {
+#else
+typedef enum {
+#endif
+  /*! \brief CPU device */
+  kDLCPU = 1,
+  /*! \brief CUDA GPU device */
+  kDLCUDA = 2,
+  /*!
+   * \brief Pinned CUDA CPU memory by cudaMallocHost
+   */
+  kDLCUDAHost = 3,
+  /*! \brief OpenCL devices. */
+  kDLOpenCL = 4,
+  /*! \brief Vulkan buffer for next generation graphics. */
+  kDLVulkan = 7,
+  /*! \brief Metal for Apple GPU. */
+  kDLMetal = 8,
+  /*! \brief Verilog simulator buffer */
+  kDLVPI = 9,
+  /*! \brief ROCm GPUs for AMD GPUs */
+  kDLROCM = 10,
+  /*!
+   * \brief Pinned ROCm CPU memory allocated by hipMallocHost
+   */
+  kDLROCMHost = 11,
+  /*!
+   * \brief Reserved extension device type,
+   * used for quickly testing extension devices.
+   * The semantics can differ depending on the implementation.
+   */
+  kDLExtDev = 12,
+  /*!
+   * \brief CUDA managed/unified memory allocated by cudaMallocManaged
+   */
+  kDLCUDAManaged = 13,
+  /*!
+   * \brief Unified shared memory allocated on a oneAPI non-partitioned
+   * device. Call to oneAPI runtime is required to determine the device
+   * type, the USM allocation type and the sycl context it is bound to.
+   *
+   */
+  kDLOneAPI = 14,
+  /*! \brief GPU support for next generation WebGPU standard. */
+  kDLWebGPU = 15,
+  /*! \brief Qualcomm Hexagon DSP */
+  kDLHexagon = 16,
+} DLDeviceType;
+
+/*!
+ * \brief A Device for Tensor and operator.
+ */
+typedef struct {
+  /*! \brief The device type used in the device. */
+  DLDeviceType device_type;
+  /*!
+   * \brief The device index.
+   * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
+   */
+  int32_t device_id;
+} DLDevice;
+
+/*!
+ * \brief The type code options DLDataType.
+ */
+typedef enum {
+  /*! \brief signed integer */
+  kDLInt = 0U,
+  /*! \brief unsigned integer */
+  kDLUInt = 1U,
+  /*! \brief IEEE floating point */
+  kDLFloat = 2U,
+  /*!
+   * \brief Opaque handle type, reserved for testing purposes.
+   * Frameworks need to agree on the handle data type for the exchange to be well-defined.
+   */
+  kDLOpaqueHandle = 3U,
+  /*! \brief bfloat16 */
+  kDLBfloat = 4U,
+  /*!
+   * \brief complex number
+   * (C/C++/Python layout: compact struct per complex number)
+   */
+  kDLComplex = 5U,
+  /*! \brief boolean */
+  kDLBool = 6U,
+} DLDataTypeCode;
+
+/*!
+ * \brief The data type the tensor can hold. The data type is assumed to follow the
+ * native endian-ness. An explicit error message should be raised when attempting to
+ * export an array with non-native endianness
+ *
+ *  Examples
+ *   - float: type_code = 2, bits = 32, lanes = 1
+ *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
+ *   - int8: type_code = 0, bits = 8, lanes = 1
+ *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1
+ *   - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
+ */
+typedef struct {
+  /*!
+   * \brief Type code of base types.
+   * We keep it uint8_t instead of DLDataTypeCode for minimal memory
+   * footprint, but the value should be one of DLDataTypeCode enum values.
+   * */
+  uint8_t code;
+  /*!
+   * \brief Number of bits, common choices are 8, 16, 32.
+   */
+  uint8_t bits;
+  /*! \brief Number of lanes in the type, used for vector types. */
+  uint16_t lanes;
+} DLDataType;
+
+/*!
+ * \brief Plain C Tensor object, does not manage memory.
+ */
+typedef struct {
+  /*!
+   * \brief The data pointer points to the allocated data. This will be CUDA
+   * device pointer or cl_mem handle in OpenCL. It may be opaque on some device
+   * types. This pointer is always aligned to 256 bytes as in CUDA. The
+   * `byte_offset` field should be used to point to the beginning of the data.
+   *
+   * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
+   * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
+   * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This must be fixed
+   * (after which this note will be updated); at the moment it is recommended
+   * to not rely on the data pointer being correctly aligned.
+   *
+   * For given DLTensor, the size of memory required to store the contents of
+   * data is calculated as follows:
+   *
+   * \code{.c}
+   * static inline size_t GetDataSize(const DLTensor* t) {
+   *   size_t size = 1;
+   *   for (tvm_index_t i = 0; i < t->ndim; ++i) {
+   *     size *= t->shape[i];
+   *   }
+   *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
+   *   return size;
+   * }
+   * \endcode
+   */
+  void* data;
+  /*! \brief The device of the tensor */
+  DLDevice device;
+  /*! \brief Number of dimensions */
+  int32_t ndim;
+  /*! \brief The data type of the pointer*/
+  DLDataType dtype;
+  /*! \brief The shape of the tensor */
+  const int64_t* shape;
+  /*!
+   * \brief strides of the tensor (in number of elements, not bytes)
+   *  can be NULL, indicating the tensor is compact and row-major.
+   */
+  const int64_t* strides;
+  /*! \brief The offset in bytes to the beginning pointer to data */
+  uint64_t byte_offset;
+} DLTensor;
+
+/*!
+ * \brief C Tensor object, manage memory of DLTensor. This data structure is
+ *  intended to facilitate the borrowing of DLTensor by another framework. It is
+ *  not meant to transfer the tensor. When the borrowing framework doesn't need
+ *  the tensor, it should call the deleter to notify the host that the resource
+ *  is no longer needed.
+ */
+typedef struct DLManagedTensor {
+  /*! \brief DLTensor which is being memory managed */
+  DLTensor dl_tensor;
+  /*! \brief the context of the original host framework in which this
+   *   DLManagedTensor is used. It can also be NULL.
+   */
+  void * manager_ctx;
+  /*! \brief Destructor signature void (*)(void*) - this should be called
+   *   to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
+   *   if there is no way for the caller to provide a reasonable destructor.
+   *   The destructor deletes the argument self as well.
+   */
+  void (*deleter)(struct DLManagedTensor * self);
+} DLManagedTensor;
+#ifdef __cplusplus
+}  // DLPACK_EXTERN_C
+#endif
+#endif  // DLPACK_DLPACK_H_
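The DLManagedTensor contract above boils down to: read through dl_tensor for as long as the data is needed, then call deleter exactly once. A hedged consumer sketch, assuming a CPU-resident float32 tensor:

#include <ATen/dlpack.h>

double read_first_element(DLManagedTensor* managed) {
  const DLTensor& t = managed->dl_tensor;
  double value = 0.0;
  // Only attempt the read for host-accessible float32 data with at least one dim.
  if (t.device.device_type == kDLCPU && t.dtype.code == kDLFloat &&
      t.dtype.bits == 32 && t.ndim >= 1) {
    const char* base = static_cast<const char*>(t.data) + t.byte_offset;
    value = *reinterpret_cast<const float*>(base);
  }
  if (managed->deleter != nullptr) {
    managed->deleter(managed);  // hand the memory back to the producer
  }
  return value;
}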
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc151cedd1b05936922a94df29d771f39884e749
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h
@@ -0,0 +1,38 @@
+#pragma once
+#include 
+
+namespace at::functorch {
+
+// These are the interpreters for our AD transforms
+// (grad, vjp and jvp).
+// See NOTE: [functorch interpreter stack] for more details.
+
+struct TORCH_API GradInterpreterPtr {
+  explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+  bool prevGradMode() const {
+    return std::get(base_->meta()).prevGradMode_;
+  }
+  Tensor lift(const Tensor& tensor) const;
+ private:
+  const Interpreter* base_;
+};
+
+struct TORCH_API JvpInterpreterPtr {
+  explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+  bool prevFwdGradMode() const {
+    return std::get(base_->meta()).prevFwdGradMode_;
+  }
+  Tensor lift(const Tensor& tensor) const;
+ private:
+  const Interpreter* base_;
+};
+
+} // namespace at::functorch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..c90dbabbe4c422767637946385700c2794e6d91b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
@@ -0,0 +1,475 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+// This file contains helper functions for batching rules.
+
+namespace at::functorch {
+
+TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
+TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
+
+TORCH_API Tensor reshape_dim_outof_symint(int64_t src, c10::SymInt size1, const Tensor& x);
+
+Tensor moveBatchDimToFront(const Tensor& tensor, optional maybe_batch_dim);
+int64_t rankWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim);
+int64_t numelWithoutBatchDim(const Tensor& tensor, optional maybe_batch_dim);
+optional valIfNonempty(optional maybe_empty, int64_t new_val);
+int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
+VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
+
+void vmapIncompatibleInplaceError(const char* schema_name);
+
+Tensor maybePadToLogicalRank(const Tensor& tensor, optional has_bdim, int64_t logical_rank);
+
+void check_randomness(RandomnessType randomness);
+void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
+
+inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
+  if (has_bdim) {
+    return tensor;
+  }
+  const auto sizes = tensor.sym_sizes();
+  SymDimVector expanded_shape;
+  expanded_shape.reserve(sizes.size());
+  expanded_shape.emplace_back(std::move(batch_size));
+  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
+  return tensor.expand_symint(expanded_shape);
+}
+
+#define VMAP_SUPPORT(op, batch_rule) \
+  m.impl(#op, op ## _generated_plumbing);
+
+#define VMAP_SUPPORT2(op, overload, batch_rule) \
+  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing);
+
+#define OP_DECOMPOSE(op)  m.impl(#op, static_cast(native::op));
+#define OP_DECOMPOSE2(op, overload)  m.impl(#op"."#overload, static_cast(native::op));
+
+// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
+template 
+struct BasicUnaryBatchRuleHelper;
+
+template 
+struct BasicUnaryBatchRuleHelper> {
+  static std::tuple> apply(
+      const Tensor& tensor,
+      optional batch_dim,
+      T... extra_args) {
+    return std::make_tuple(Func(tensor, std::forward(extra_args)...), batch_dim);
+  }
+};
+
+// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
+// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
+// It is important that this macro is not passed a function pointer!!
+#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
+    BasicUnaryBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+#define UNARY_POINTWISE(op) \
+  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
+
+template 
+struct VariadicBdimsBatchRuleHelper;
+
+template 
+struct VariadicBdimsBatchRuleHelper> {
+  static std::tuple> apply(
+      const Tensor& tensor,
+      optional batch_dim,
+      T... extra_args) {
+    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
+    return std::make_tuple(Func(tensor_, std::forward(extra_args)...), 0);
+  }
+};
+
+// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
+// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
+// It is important that this macro is not passed a function pointer!!
+#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
+    VariadicBdimsBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+#define VARIADIC_BDIMS(op) \
+  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));
+
+#define VARIADIC_BDIMS2(op, overload) \
+  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
+
+template
+void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
+
+  int64_t cur_level = maybe_layer->layerId();
+
+  auto orig_arguments = torch::jit::last(*stack, num_arguments);
+  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  auto arguments = torch::jit::pop(*stack, num_arguments);
+  std::vector>> tensor_inputs;
+  std::vector tensor_pos;
+  for (const auto idx : c10::irange(0, num_arguments)) {
+    const auto& ivalue = arguments[idx];
+    if (ivalue.isTensor()) {
+      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
+      tensor_pos.push_back(idx);
+    }
+  }
+  Func(tensor_inputs);
+
+  size_t tensor_idx = 0;
+  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
+  for (const auto arg_idx : c10::irange(0, num_arguments)) {
+    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
+      torch::jit::push(stack, arguments[arg_idx]);
+    } else {
+      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
+      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
+      tensor_idx++;
+    }
+  }
+
+  op.callBoxed(stack);
+  const auto returns = torch::jit::pop(*stack, num_returns);
+  for (const auto& ret : returns) {
+    if (ret.isTensor()) {
+      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
+    } else {
+      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
+    }
+  }
+}
+
+inline void handle_pointwise_ops(std::vector>> &tensor_inputs) {
+  int64_t out_logical_rank = 0;
+  for (auto& tensor_input : tensor_inputs) {
+    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
+    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
+  }
+  for (auto& tensor_input: tensor_inputs) {
+    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
+    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
+  }
+}
+
+#define POINTWISE_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+#define POINTWISE_BOXED2(op, overload) \
+  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction>());
+
+inline void handle_variadic_bdims(std::vector>> &tensor_inputs) {
+  for (auto & tensor_input : tensor_inputs) {
+    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
+  }
+}
+
+#define VARIADIC_BDIMS_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+using UnpackedBatchedTensor = std::tuple>;
+
+inline void find_and_unpack_tensors(
+    const torch::jit::Stack* stack,
+    int64_t num_args,
+    int64_t cur_level,
+    SmallVector* tensors,
+    SmallVector* tensors_pos,
+    int64_t* batch_size) {
+
+  int64_t computed_batch_size = -1;
+  int64_t args_begin = stack->size() - num_args;
+
+  for (const auto idx : c10::irange(0, num_args)) {
+    const auto& ivalue = (*stack)[args_begin + idx];
+    if (!ivalue.isTensor()) {
+      continue;
+    }
+    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
+    const auto& tensor_value = std::get<0>(unpacked);
+    const auto tensor_bdim = std::get<1>(unpacked);
+    if (tensor_bdim.has_value()) {
+      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
+      if (computed_batch_size == -1) {
+        computed_batch_size = candidate_batch_size;
+      }
+      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
+    }
+
+    tensors->push_back(std::move(unpacked));
+    tensors_pos->push_back(idx);
+  }
+  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
+  *batch_size = computed_batch_size;
+}
+
+inline void boxed_existing_bdim_all_batch_rule(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
+  int64_t cur_level = maybe_layer->layerId();
+
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector tensor_inputs;
+  SmallVector tensor_pos;
+  int64_t batch_size;
+
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  // for each tensor, ensure it has a bdim and reshape it.
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+  }
+
+  op.callBoxed(stack);
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+  }
+}
+
+// Use when all tensors arguments accept one (normal) batch dim.
+// This batching rule expands the batch dim on all Tensors, reshapes it into
+// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
+// This is not the most efficient thing; if there are alternatives, please try
+// to use them. Use this only as a last resort.
+#define EXISTING_BDIM_ALL_BOXED(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction());
+
+template 
+inline void boxed_all_tensors_have_optional_bdim(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
+  int64_t cur_level = maybe_layer->layerId();
+
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
+    op.callBoxed(stack);
+    return;
+  }
+
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector tensor_inputs;
+  SmallVector tensor_pos;
+  int64_t batch_size;
+
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  optional is_no_batch_dim_case;
+
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    const auto logical_rank = rankWithoutBatchDim(value, bdim);
+
+    if (!is_no_batch_dim_case.has_value()) {
+      is_no_batch_dim_case = (logical_rank == feature_rank);
+    }
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    if (*is_no_batch_dim_case) {
+      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
+      value_ = moveBatchDimToFront(value_, bdim);
+      if (tensor_idx == contig_tensor_index) {
+        value_ = value_.contiguous();
+      }
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
+    value_ = reshape_dim_into(*bdim, 0, value_);
+    if (tensor_idx == contig_tensor_index) {
+      value_ = value_.contiguous();
+    }
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
+  }
+
+  op.callBoxed(stack);
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    if (*is_no_batch_dim_case) {
+      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
+    } else {
+      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+    }
+  }
+}
+
+// Useful for many NN operators.
+// The operator must satisfy the following:
+// - All arguments must accept an optional batch dim.
+// - All arguments must be the same rank
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction>());
+
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
+  m.impl(#op, \
+         torch::CppFunction::makeFromBoxedFunction<\
+             boxed_all_tensors_have_optional_bdim<\
+                 feature_rank, \
+                 contig_tensor_index>\
+             >());
+
+template 
+struct ExistingBdimBatchRuleHelper;
+
+template 
+struct ExistingBdimBatchRuleHelper> {
+  static std::tuple> apply(
+      const Tensor& self,
+      optional self_bdim,
+      T... extra_args) {
+    auto self_ = reshape_dim_into(*self_bdim, 0, self);
+    auto out = Func(self_, std::forward(extra_args)...);
+    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
+  }
+};
+
+// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
+// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
+// It is important that this macro is not passed a function pointer!!
+#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
+    ExistingBdimBatchRuleHelper<\
+      decltype(&fn),\
+      &fn,\
+      c10::guts::function_traits::parameter_types>::apply)
+
+
+#define EXISTING_BDIM(op) \
+  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));
+
+#define EXISTING_BDIM2(op, overload) \
+  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
+
+#define INVOKE(object,ptrToMember)  ((object).*(ptrToMember))
+
+
+template 
+Tensor& unary_inplace_batch_rule(Tensor& self, optional, ExtraArgs... extra_args) {
+  INVOKE(self, Method)(std::forward(extra_args)...);
+  return self;
+}
+
+inline int64_t get_bdim_size4(
+    const Tensor& a_value, optional a_bdim,
+    const Tensor& b_value, optional b_bdim,
+    const Tensor& c_value, optional c_bdim,
+    const Tensor& d_value, optional d_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  if (d_bdim)
+    return d_value.size(*d_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+inline int64_t get_bdim_size3(
+    const Tensor& a_value, optional a_bdim,
+    const Tensor& b_value, optional b_bdim,
+    const Tensor& c_value, optional c_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+inline int64_t get_bdim_size2(
+    const Tensor& a_value, optional a_bdim,
+    const Tensor& b_value, optional b_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// [start, start + 1, ..., stop - 1]
+inline VmapDimVector range(int64_t start, int64_t stop) {
+  TORCH_INTERNAL_ASSERT(stop >= start);
+  VmapDimVector dims;
+  dims.reserve(stop - start);
+  for (int64_t i = start; i < stop; i++) {
+    dims.emplace_back(i);
+  }
+  return dims;
+}
+std::tuple _binary_pointwise_helper(
+    const Tensor& tensor, optional tensor_batch_dim, const Tensor& other, optional other_batch_dim,
+    bool do_type_promotion=true);
+
+} // namespace at::functorch
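For orientation, a batch rule in the style these helpers support takes each tensor together with its optional batch-dim position and returns the output plus the output's batch-dim position. A hedged, hand-written sketch in the spirit of BASIC_UNARY_BATCH_RULE (the function itself is hypothetical; real rules are wired up through VMAP_SUPPORT):

#include <ATen/functorch/BatchRulesHelper.h>

namespace at::functorch {

// relu is pointwise, so the batch dimension (if any) stays where it was.
static std::tuple<Tensor, optional<int64_t>> relu_batch_rule_sketch(
    const Tensor& self, optional<int64_t> self_bdim) {
  return std::make_tuple(at::relu(self), self_bdim);
}

} // namespace at::functorch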
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab4fbc662aa3e0f28bc4e15432e56377a471a196
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h
@@ -0,0 +1,81 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::functorch {
+
+// This file contains code for the vmap fallback (also known as the
+// BatchedTensor fallback or the Batched fallback). This code runs
+// when an operation doesn't have a batching rule implemented.
+
+// If an operator doesn't have a batching rule implemented then we fallback
+// to this implementation. The fallback doesn't work on out= variants or
+// view operations; that is, it works for out-of-place operations and
+// in-place non-view operations.
+//
+// For out-of-place operations, the fallback effectively takes all of the
+// BatchedTensors in `stack`, slices them, and runs `op` on all of the
+// corresponding slices to produce slices of the outputs. The output slices
+// then get `torch.stack`ed to create the
+// final returns.
+//
+// The performance of the fallback is not very good because it introduces an
+// extra copy from stacking the sliced outputs. Because of this, we prefer to
+// write batching rules for operators whenever possible.
+void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+// The vmap fallback emits a warning by default, but it may be disabled if
+// the user finds it to be too annoying.
+TORCH_API bool isVmapFallbackWarningEnabled();
+TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
+
+// Used for testing. The vmap fallback is enabled by default. When it is disabled,
+// it raises an error.
+TORCH_API bool isVmapFallbackEnabled();
+TORCH_API void setVmapFallbackEnabled(bool enabled);
+
+template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
+  return buffer[0].to<A>();
+}
+template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
+  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
+}
+template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
+  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<C>());
+}
+
+// slow_fallback is a way to call the vmap fallback inside some boxed kernel.
+// There is probably some better way to metaprogram this.
+template <typename Ret>
+Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
+  std::vector<IValue> stack(args.begin(), args.end());
+  batchedTensorForLoopFallback(op, &stack);
+  return vector_to_result<Ret>(stack);
+}
+
+template <typename A, typename B>
+std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
+  std::vector<IValue> stack(args.begin(), args.end());
+  batchedTensorForLoopFallback(op, &stack);
+  return vector_to_result<A, B>(stack);
+}
+
+template <typename A, typename B, typename C>
+std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
+  std::vector<IValue> stack(args.begin(), args.end());
+  batchedTensorForLoopFallback(op, &stack);
+  return vector_to_result<A, B, C>(stack);
+}
+
+
+} // namespace at::functorch
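
To make the slice / apply / stack strategy described in the comments above concrete, here is a minimal standalone sketch in plain C++. It uses std::vector as a stand-in for batched tensors; Slice, Batched, and forLoopFallback are illustrative names and not part of the ATen implementation.

#include <functional>
#include <vector>

// A stand-in for a batched input: outer index = batch slice, inner = data.
using Slice = std::vector<double>;
using Batched = std::vector<Slice>;

// Toy "fallback": run `op` once per batch slice and re-stack the results.
// The real fallback does the analogous thing with slicing and torch.stack on
// BatchedTensors, which is why it pays an extra copy for the stack step.
Batched forLoopFallback(const Batched& input,
                        const std::function<Slice(const Slice&)>& op) {
  Batched stacked_outputs;
  stacked_outputs.reserve(input.size());
  for (const Slice& slice : input) {
    stacked_outputs.push_back(op(slice));  // one kernel invocation per slice
  }
  return stacked_outputs;                  // the "torch.stack" step
}

The per-slice loop plus the final re-stack is exactly where the extra copy, and therefore the performance cost noted above, comes from.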
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..1982b94833e03dff8e15bf1f4ddddffd95260981
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h
@@ -0,0 +1,170 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at::functorch {
+
+using Tensor = at::Tensor;
+
+// We assume this in a few other places in the codebase,
+// but there isn't a centralized definition.
+constexpr int64_t kVmapMaxTensorDims = 64;
+
+// The valid vmap levels range from [0, 64). This effectively means that we
+// support a maximum of 64 nested vmaps.
+constexpr int64_t kVmapNumLevels = 64;
+
+// Store this number of elements of BatchDims on the stack. Most people will
+// probably use <= 5 nested vmaps, but adjust this number as necessary.
+constexpr int64_t kBatchDimsStackSize = 5;
+
+// A BatchedTensorImpl holds an underlying Tensor and a single batch dim
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+//
+// The batch dimensions are treated as being "private"; they are not user-visible.
+// For example, in the following Tensor,
+//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
+// dimension 0 is the batch dimension.
+//
+// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
+// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
+struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
+  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
+
+  // Returns batch dimension of this tensor
+  int64_t bdim() const { return bdim_; }
+
+  // Returns the vmap level of this tensor
+  int64_t level() const { return level_; }
+
+  // BatchedTensorImpl wraps a Tensor
+  const Tensor& value() const { return value_; }
+
+  // Given a public dimension index, return the dimension index in the underlying
+  // value() tensor.
+  // For example, if we have
+  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
+  // bt.actualDim(0) -> 1
+  // bt.actualDim(1) -> 2
+  // bt.actualDim(2) -> 3
+  // bt.actualDim(3) -> Error
+  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
+
+  IntArrayRef sizes_custom() const override;
+  SymIntArrayRef sym_sizes_custom() const override;
+  int64_t size_custom(int64_t d) const override;
+  c10::SymInt sym_size_custom(int64_t d) const override;
+  // We have to override this because we opted into CustomStrides
+  IntArrayRef strides_custom() const override;
+  SymIntArrayRef sym_strides_custom() const override;
+  // Override a bunch of methods inherited from TensorImpl to return error messages.
+  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
+  void set_size(int64_t dim, int64_t new_size) override;
+  void set_stride(int64_t dim, int64_t new_stride) override;
+  c10::intrusive_ptr shallow_copy_and_detach(
+    const c10::VariableVersion& version_counter,
+    bool allow_tensor_metadata_change) const override;
+  c10::intrusive_ptr shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+  void shallow_copy_from(const c10::intrusive_ptr& impl) override;
+#ifdef DEBUG
+  bool has_storage() const override;
+#endif
+
+  void refreshTensorMetadata();
+
+  // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
+  // accomplishes this is a hack where it is able to modify the levels of
+  // BatchedTensor to match the level of the current vmap transform.
+  void _unsafe_set_level(int64_t level) {
+    level_ = level;
+  }
+
+  // Used in batching rule for in-place view operations that can change
+  // the index of the bdim (think squeeze_, unsqueeze_)
+  void unsafe_set_bdim(int64_t bdim) {
+    // NB: you MUST call refreshTensorMetadata after doing this.
+    bdim_ = bdim;
+  }
+ private:
+  // see NOTE: [BatchedTensorImpl levels invariant]
+  void checkInvariants() const;
+  const char* tensorimpl_type_name() const override;
+
+  Tensor value_;
+
+  int64_t level_;
+  int64_t bdim_;
+};
+
+// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
+// BatchedTensorImpl.
+inline bool isBatchedTensor(const Tensor& tensor) {
+  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
+      tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
+}
+
+// It is unsafe to call this on a Tensor that is not backed by a
+// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
+inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
+  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
+}
+
+inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
+  if (!isBatchedTensor(tensor)) {
+    return nullptr;
+  }
+  return unsafeGetBatchedImpl(std::move(tensor));
+}
+
+// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
+inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
+  std::bitset<kVmapMaxTensorDims> is_bdim;
+  is_bdim.set(dim);
+  return is_bdim;
+}
+
+// Creates a bitset for the given level
+inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
+  std::bitset<kVmapNumLevels> result;
+  result.set(level);
+  return result;
+}
+
+// Use this to construct a BatchedTensor from a regular Tensor
+TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);
+
+// Adds a batch dim to `tensor`, returning a BatchedTensor
+TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
+
+// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
+// any wrapper Tensor subclasses). This is because there are methods on Tensor
+// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
+// TODO: should probably contain more (or all?) backend keys
+constexpr DispatchKeySet kKeysToPropagateToWrapper({
+  DispatchKey::Negative,
+  DispatchKey::Conjugate,
+  DispatchKey::XLA,
+  DispatchKey::CUDA,
+  DispatchKey::CPU,
+});
+
+inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
+  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
+  return key_set & to_propagate;
+}
+
+} // namespace at::functorch
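
As a rough illustration of the public-to-physical dimension mapping that actualDim() performs on a tensor with one private batch dimension, here is a hypothetical free function (toPhysicalDim is an assumed name, not the member used by BatchedTensorImpl):

#include <cstdint>
#include <stdexcept>

// Map a user-visible (public) dim to the dim in the underlying tensor,
// assuming exactly one private batch dimension at index `bdim`.
// e.g. bdim=0: public dims 0/1/2 map to physical dims 1/2/3, matching the
// actualDim() example in the header comment.
int64_t toPhysicalDim(int64_t public_dim, int64_t public_ndim, int64_t bdim) {
  if (public_dim < 0) public_dim += public_ndim;          // dim wrapping
  if (public_dim < 0 || public_dim >= public_ndim) {
    throw std::out_of_range("dim out of range");
  }
  return public_dim < bdim ? public_dim : public_dim + 1; // skip over the bdim
}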
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
new file mode 100644
index 0000000000000000000000000000000000000000..608402801abc07565e370bff52475cc7ec7f6871
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
@@ -0,0 +1,126 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include 
+#include 
+
+// This file contains template metaprogramming things that are used for our
+// batching rules.
+//
+// See NOTE: [vmap plumbing] for more details on why this is necessary.
+// The plumbing has a bunch of metaprogramming hacks for determining the signature
+// of a batching rule from the signature of the operator, many of which use the
+// helper functions in this file.
+
+namespace at::functorch {
+
+// Metaprogramming things
+template  using typelist = c10::guts::typelist::typelist;
+template  using head_t = c10::guts::typelist::head_t;
+template  using concat_t = c10::guts::typelist::concat_t;
+template  class debug_t;
+
+// tail operation
+template
+struct tail final {
+    static_assert(c10::guts::false_t::value,
+                  "In typelist::tail, the T argument must be typelist<...>.");
+};
+template
+struct tail> final {
+  using type = typelist;
+};
+template using tail_t = typename tail::type;
+
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
+  using type = Next;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> {
+  using type = Tail;
+};
+template 
+struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> {
+  using type = Tail;
+};
+template  struct RemoveBatchDimAfterTensor {
+  using first = head_t;
+  using next = tail_t;
+  using second = head_t;
+  using tail = tail_t;
+
+  using type = concat_t<
+    typelist,
+    typename RemoveBatchDimAfterTensor<
+      typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type
+    >::type
+  >;
+};
+template  struct RemoveBatchDimAfterTensor> {
+  using type = typelist;
+};
+template <> struct RemoveBatchDimAfterTensor> {
+  using type = typelist<>;
+};
+template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type;
+
+template  struct UnpackSingleItemTuple {
+  using type = T;
+};
+template  struct UnpackSingleItemTuple> {
+  using type = T;
+};
+template  using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type;
+
+template  struct BuildFunctionHelper;
+template  struct BuildFunctionHelper> {
+  using type = Return(Args...);
+};
+template 
+struct BuildFunction {
+  using type = typename BuildFunctionHelper>::type;
+};
+template  using build_function_t = typename BuildFunction::type;
+
+
+template  struct ToOperatorType {
+  using batch_rule_return_type = typename c10::guts::function_traits::return_type;
+  using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types;
+
+  using operator_parameter_types = remove_batch_dim_after_tensor_t;
+  using operator_return_type =
+    unpack_single_item_tuple_t<
+      c10::guts::typelist::to_tuple_t<
+        remove_batch_dim_after_tensor_t<
+          c10::guts::typelist::from_tuple_t>>>;
+
+  using type = build_function_t;
+};
+template  using to_operator_t = typename ToOperatorType::type;
+
+} // namespace at::functorch
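
The RemoveBatchDimAfterTensor machinery is easier to follow on a toy, self-contained typelist built from std::tuple. The sketch below uses placeholder TensorArg/BatchDimArg types and is not the c10::guts implementation, only the same drop-the-dim-that-follows-a-tensor idea used to recover an operator signature from a batching-rule signature.

#include <tuple>
#include <type_traits>

struct TensorArg {};    // stands in for Tensor
struct BatchDimArg {};  // stands in for optional<int64_t> bdim

// Prepend T to a std::tuple of types.
template <typename T, typename Tuple> struct Cons;
template <typename T, typename... Us> struct Cons<T, std::tuple<Us...>> {
  using type = std::tuple<T, Us...>;
};

// Drop every BatchDimArg that immediately follows a TensorArg.
template <typename... Ts> struct Strip;
template <> struct Strip<> { using type = std::tuple<>; };
template <typename T> struct Strip<T> { using type = std::tuple<T>; };
template <typename A, typename B, typename... Rest>
struct Strip<A, B, Rest...> {
  using tail = std::conditional_t<
      std::is_same_v<A, TensorArg> && std::is_same_v<B, BatchDimArg>,
      typename Strip<Rest...>::type,      // skip B
      typename Strip<B, Rest...>::type>;  // keep scanning from B
  using type = typename Cons<A, tail>::type;
};

// (TensorArg, BatchDimArg, double) -> (TensorArg, double), mirroring how a
// batching-rule parameter list maps back to the operator parameter list.
static_assert(std::is_same_v<
    Strip<TensorArg, BatchDimArg, double>::type,
    std::tuple<TensorArg, double>>);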
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..cede226d7945bb7dc3c13311cb74e1d7c5869613
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h
@@ -0,0 +1,124 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// Forward declared
+namespace c10 { struct AutogradMetaInterface; }
+
+namespace at::functorch  {
+
+// This file contains the implementation of functorch's interpreter stack.
+// See NOTE: [functorch interpreter stack] first before reading on.
+//
+// NB: the functorch interpreter stack is also referred to as:
+// - the "dynamic layer stack" -- an older name for "interpreter" was
+//   "dynamic layer".
+// - the "functorch mode stack". You can think of each functorch transform as a
+//   "mode" (in the same sense as torch_dispatch mode or torch_function mode),
+//   and functorch being an implementation of a "mode stack" where the modes
+//   may be arbitrarily composed.
+
+// DynamicLayer is basically the same thing as an Interpreter.
+// It represents a functorch transform and it holds an Interpreter,
+// which contains metadata related to the transform and instructions on
+// how to perform the transform.
+//
+// TODO: we can excise DynamicLayer in favor of Interpreter,
+// but I am going to leave it for now as a compatibility shim to avoid
+// needing to refactor a lot of callsites...
+struct TORCH_API DynamicLayer {
+  explicit DynamicLayer(
+      TransformType transform_type,
+      int64_t layerId,
+      optional<c10::SymInt> batchSize = nullopt,
+      optional<RandomnessType> randomness = nullopt,
+      optional<bool> prev_grad_mode = nullopt,
+      optional<bool> pre_fwd_grad_mode = nullopt,
+      optional<bool> functionalize_add_back_views = nullopt);
+
+  TransformType key() const;
+  int64_t layerId() const;
+
+  const Interpreter& interpreter() const { return interpreter_; }
+  Interpreter& interpreter() { return interpreter_; }
+
+  // Only valid for vmap
+  c10::SymInt batchSize() const;
+  RandomnessType randomness() const;
+
+ private:
+  Interpreter interpreter_;
+};
+
+TORCH_API int64_t initAndPushDynamicLayer(
+    TransformType transform_type,
+    optional<c10::SymInt> batch_size = nullopt,
+    optional<RandomnessType> randomness = nullopt,
+    optional<bool> prev_grad_mode = nullopt,
+    optional<bool> prev_fwd_grad_mode = nullopt,
+    optional<bool> functionalize_add_back_views = nullopt);
+TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
+TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
+TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
+TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
+TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
+
+// NOTE: [Life handles and lexically scoped transforms]
+// functorch transforms are lexically scoped.
+// Given a level, we store a "life handle" that is a boolean that tells us if the
+// transform with that level is active or not.
+//
+// functorch's TensorWrapper (for grad transforms) stores a life handle.
+// If a TensorWrapper escapes from the scope of the transform, then somehow
+// it must know it escaped; it can tell by querying the life handle.
+TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
+
+// Returns true if an operator is in-place. An operator is in-place if:
+// 1. The first argument is a Tensor and it is being written to
+// 2. The first argument is being returned
+// 3. No other arguments are aliased
+// Here is an example of an in-place operator:
+// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
+
+// Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped
+TORCH_API c10::optional findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
+
+TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
+TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
+
+// Pretty printers
+TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
+TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack);
+
+// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
+// is disabled by default. The following two APIs are APIs for enabling
+// it. These are not user-facing APIs. We can delete this in the future, but
+// it is useful for debugging when something goes wrong with the
+// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
+// because it leads to loud errors if something is incorrect.
+TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
+TORCH_API bool getSingleLevelAutogradFunctionAllowed();
+
+// While a functorch grad transform is active, Tensor.requires_grad_() gets
+// disabled. These two functions are the mechanism to controlling that.
+TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
+TORCH_API bool getInplaceRequiresGradAllowed();
+
+TORCH_API DynamicLayer popDynamicLayer();
+TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
+
+} // namespace at::functorch
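
The interpreter stack plus the life-handle note above can be summarized with a small self-contained sketch. ToyLayer, pushToyLayer, and popToyLayer are invented names for illustration and do not mirror the real DynamicLayer API; the point is only the push/pop discipline and the shared boolean that outlives the layer.

#include <cstdint>
#include <memory>
#include <vector>

// Toy "dynamic layer": a level plus a life handle.
struct ToyLayer {
  int64_t level;
  std::shared_ptr<bool> alive = std::make_shared<bool>(true);
};

std::vector<ToyLayer>& toyStack() {
  static std::vector<ToyLayer> stack;
  return stack;
}

int64_t pushToyLayer() {
  toyStack().push_back(ToyLayer{static_cast<int64_t>(toyStack().size()) + 1});
  return toyStack().back().level;
}

// Popping marks the layer dead; wrappers that captured the handle can tell
// they escaped the transform's scope by checking *alive afterwards.
std::shared_ptr<bool> popToyLayer() {
  auto handle = toyStack().back().alive;
  *handle = false;
  toyStack().pop_back();
  return handle;
}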
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h
new file mode 100644
index 0000000000000000000000000000000000000000..5ae0bcdccdf5fc5c061b542d175f677eace9a4c2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <ATen/functorch/Interpreter.h>
+
+namespace at::functorch {
+
+// This is the interpreter that handles the functionalize() transform.
+// See NOTE: [functorch interpreter stack] for more details.
+
+struct FunctionalizeInterpreterPtr {
+  explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+  bool functionalizeAddBackViews() const {
+    return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+} // namespace at::functorch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h
new file mode 100644
index 0000000000000000000000000000000000000000..ba44a44676cab2864781d85166f5cc780f4c23a4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h
@@ -0,0 +1,208 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::functorch {
+
+// NOTE: [functorch interpreter stack]
+//
+// functorch's dispatching system uses a stack of interpreters.
+// Historically we've referred to this as the "DynamicLayerStack".
+//
+// An interpreter is something that reads in the code it is passed
+// and then executes it. We have a different interpreter per-transform:
+// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
+// and executing the batched version of it (the batching rule for aten::mv).
+//
+// Concretely, each interpreter is responsible for two things:
+//
+// 1) process(ophandle, stack)
+// Given an operator handle and a stack of arguments, the interpreter is
+// responsible for figuring out how to execute the operation under the semantics
+// of the interpreter. For example, for the VmapInterpreter this means figuring
+// out how to call the batching rule.
+//
+// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
+// VmapInterpreter calls the batching rule is roughly: (A) exclude all
+// dispatch keys aside from the Batched key, (B) redispatch so we get to the
+// Batched key.
+//
+// 2) sendToNextInterpreter(ophandle, stack)
+// The VmapInterpreter, when it sees aten::mv, will process it into a call to
+// aten::mm. It then needs to send the call to aten::mm to the next interpreter
+// in the interpreter stack.
+//
+// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
+// and most Interpreters will implement it this way.
+
+enum class RandomnessType {
+    Error,      // always errors when calling a random function
+    Same,       // randomness appears the same across batches
+    Different,  // randomness appears different across batches
+    END
+};
+
+enum class TransformType {
+  Torch,  // Unused
+  Vmap,
+  Grad,  // reverse-mode AD, aka vjp
+  Jvp,  // forward-mode AD
+  Functionalize,
+};
+
+std::ostream& operator<<(std::ostream& os, const TransformType& t);
+
+// NOTE: [Interpreter "subclassing" design]
+//
+// How are various Interpreters for different transforms (vmap, grad, ...)
+// implemented?
+//
+// Accessing interpreters is in the hot-path of functorch so we have a constraint
+// that this code must be as fast as possible.
+//
+// As a result, we stay away from virtual methods and this causes our code
+// to look a little funny.
+//
+// `Interpreter` is the struct for Interpreters. It holds ALL of the
+// relevant information (what type of interpreter it is and the metadata).
+// Metadata for each interpreter is represented as a Union (std::variant)
+// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
+//
+// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
+// if you want to access the metadata fields (like batchSize and randomness).
+//
+// Each type of interpreter (e.g. Vmap) has a convenience struct
+// (e.g. VmapInterpreterPtr) associated with it.
+//
+// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
+// and then one can access methods on VmapInterpreterPtr like so:
+// >>> VmapInterpreterPtr(&interpreter).batchSize()
+//
+// Finally, Interpreter::process switches on the type of the interpreter
+// and calls one of {Transform}Interpreter::processImpl under the hood.
+// Same for Interpreter::sendToNextInterpreter :)
+
+struct VmapInterpreterMeta {
+  explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
+    batchSize_(std::move(batchSize)), randomness_(randomness) {}
+  c10::SymInt batchSize_;
+  RandomnessType randomness_;
+};
+
+struct GradInterpreterMeta {
+  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
+  bool prevGradMode_;
+};
+
+struct JvpInterpreterMeta {
+  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
+  bool prevFwdGradMode_;
+};
+
+struct FunctionalizeInterpreterMeta {
+  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
+    functionalizeAddBackViews_(functionalizeAddBackViews) {}
+  bool functionalizeAddBackViews_;
+};
+
+typedef std::variant<
+  int64_t,
+  GradInterpreterMeta,
+  JvpInterpreterMeta,
+  VmapInterpreterMeta,
+  FunctionalizeInterpreterMeta
+> InterpreterMeta;
+
+
+struct Interpreter {
+  // factory functions
+  static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
+    return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
+  }
+  static Interpreter Grad(int64_t level, bool prevGradMode) {
+    return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
+  }
+  static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
+    return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
+  }
+  static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
+    return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
+  }
+
+  // methods
+  TransformType key() const { return type_; }
+  int64_t level() const { return level_; }
+  const InterpreterMeta& meta() const { return meta_; }
+
+  void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+
+  void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
+    TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
+    savedLocalDispatchKeySet_ = std::move(keyset);
+  }
+  void clearSavedLocalDispatchKeySet() {
+    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+    savedLocalDispatchKeySet_ = c10::nullopt;
+  }
+  c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
+    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
+    return *savedLocalDispatchKeySet_;
+  }
+
+  // An Interpreter is alive if we are currently inside the ongoing transform
+  // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
+  // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
+  bool is_alive() const {
+    return *is_alive_;
+  }
+  const std::shared_ptr<bool>& is_alive_ptr() const {
+    return is_alive_;
+  }
+  void set_is_alive(bool alive) {
+    *is_alive_ = alive;
+  }
+
+  // Please don't use this
+  explicit Interpreter() = default;
+
+ private:
+  explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
+    type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(meta) {}
+
+  // fields
+  TransformType type_;
+  int64_t level_;
+  optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
+  std::shared_ptr<bool> is_alive_;
+  InterpreterMeta meta_;
+};
+
+// Applies the following for-loop:
+// for i in range(begin, end):
+//   args[i] = func(args[i])
+void foreachTensorInplace(std::vector& args, int64_t begin, int64_t end,
+    std::function func);
+
+// Applies the following for-loop:
+// for i in range(begin, end):
+//   if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset
+//     args[i] = func(args[i], i - begin, true)
+//   args[i] = func(args[i], i - begin)
+void foreachTensorInplaceWithFlag(std::vector& args, int64_t begin, int64_t end,
+    const std::bitset<64> use_flag_relative, std::function func);
+
+std::vector findUnwrappedInputs(std::vector& args, int64_t begin, int64_t end);
+
+DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
+
+void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
+
+void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+} // namespace at::functorch
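
The "subclassing without virtuals" design described in the note above (one struct, a transform tag, and a variant of per-transform metadata) can be sketched in a few lines of standalone C++. ToyInterpreter, ToyVmapMeta, and ToyGradMeta are illustrative stand-ins, not the real types.

#include <cstdint>
#include <iostream>
#include <variant>

struct ToyVmapMeta { int64_t batch_size; };
struct ToyGradMeta { bool prev_grad_mode; };
using ToyMeta = std::variant<ToyVmapMeta, ToyGradMeta>;

enum class ToyTransform { Vmap, Grad };

struct ToyInterpreter {
  ToyTransform type;
  ToyMeta meta;

  // Mirrors the shape of Interpreter::process: switch on the tag, then read
  // the matching metadata with std::get, as the *InterpreterPtr structs do.
  void process() const {
    switch (type) {
      case ToyTransform::Vmap:
        std::cout << "vmap, batch_size="
                  << std::get<ToyVmapMeta>(meta).batch_size << "\n";
        break;
      case ToyTransform::Grad:
        std::cout << "grad, prev_grad_mode="
                  << std::get<ToyGradMeta>(meta).prev_grad_mode << "\n";
        break;
    }
  }
};

For example, ToyInterpreter{ToyTransform::Vmap, ToyVmapMeta{8}}.process() takes the vmap branch. Keeping the metadata in a variant and switching on a tag avoids a virtual call on the hot path, which is the constraint the note above is describing.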
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h
new file mode 100644
index 0000000000000000000000000000000000000000..2ad7ee72b6425dfe3f8a9d26e8d46274bf3788c9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h
@@ -0,0 +1,187 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include 
+#include 
+
+namespace at::functorch {
+
+// This file contains the legacy (now-deprecated) batching rule API.
+// Please try to use the new-style batching rule API (see writing_batch_rules.md)
+
+// This file contains abstractions used for transforming *logical* vmap arguments
+// into *physical* arguments. (Keep reading for definitions of these terms).
+
+// NOTE: [Logical vs physical args]
+// Consider the following vmap.
+//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
+// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
+// with batch dims 0 and 2:
+//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
+//
+// We say the *logical* view of the tensor has size [3] -- tensors inside
+// `func` appear to have size [3].
+// However, the *physical* underlying tensor (the one passed to vmap) has size
+// [2, 3, 4].
+//
+// This notion of logical vs physical also extends to non-tensor arguments.
+// Consider the previous tensor; let's assume the user called
+// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
+// dimension they are reducing over is dim 0 but the physical dim is dim 1
+// (the first non-batch dimension)
+
+// Forward declared; see NOTE: [What is a VmapPhysicalView?]
+struct VmapPhysicalView;
+
+// Most PyTorch operators take 4 or fewer inputs.
+constexpr int64_t kVmapTransformStaticInputSize = 4;
+using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
+
+// PyTorch generally advertises good performance for <= 5 dims.
+// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
+// dimensions to get 8. Adjust this number as necessary.
+constexpr int64_t kVmapStaticDimVecSize = 8;
+using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
+using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
+
+// NOTE: [What is a VmapTransform?]
+// A *VmapTransform* converts logical views of tensors to physical views.
+//
+// Batching rules use VmapTransforms to convert logical arguments to
+// physical arguments, then call one or more at:: operators that handle the
+// physical arguments, and then convert the physical results back to logical
+// arguments.
+
+// VmapTransform for operators that take tensors with multiple batch dims.
+// Given one or more logical views on Tensors, `logicalToPhysical`
+// permutes all of the batch dims to the front of the tensor, aligns
+// and expands the batch dims to match each other (according to their `level`),
+// and returns a VmapPhysicalView on the tensor(s).
+struct TORCH_API MultiBatchVmapTransform {
+  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
+  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
+};
+
+// VmapTransform for operators that broadcast all inputs.
+// Given some logical views on Tensors, `logicalToPhysical`:
+// - permutes all of the batch dims to the front of the tensors
+// - aligns all the batch dims to the collective levels of all of the tensors.
+//   If a tensor does not have a batch dim for a vmap level, then it receives
+//   a size-one dimension for said level.
+// - aligns the non-batch dims to have the same dimensionality, adding extra
+//   size-1 dimensions in between the batch dimensions and the non-batch dimensions
+//   so that the batch dimensions are lined up from the right.
+//
+// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
+// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
+// of size (B, 1, 2) and (B, 3, 2).
+//
+// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
+// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
+// actually *need* to return a tensor of size (1, 2) for the second tensor
+// because the broadcasting operation takes care of that for us, but we do
+// it anyways to keep things simple.
+struct TORCH_API BroadcastingVmapTransform {
+  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
+};
+
+// Forward declared, if you're reading this file head to toe, don't worry about
+// it yet.
+struct VmapPhysicalToLogicalMap;
+
+// NOTE: [What is a VmapPhysicalView?]
+// VmapPhysicalView represents a physical view on a Tensor.
+//
+// One can use it to further convert logical dimension indices, logical shapes,
+// and more to their physical variants, or convert a new (physical) tensor into
+// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
+//
+// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
+// the front and some levels that correspond to said batch dimensions.
+//
+// The levels bitset specifies which vmap levels correspond to the batch
+// dimensions at the front of the tensor. In particular, the number of set bits
+// corresponds to the number of batch dimensions on `tensor` and the rightmost
+// bit of `levels` specifies the maximum number of nested vmaps we are in at
+// this point in time.
+// For example, given:
+//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
+//
+// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
+// than or equal to 3.
+//   bitset: 010100
+//              ^
+//              |
+//   levels: 012345
+struct TORCH_API VmapPhysicalView {
+  VmapPhysicalView(Tensor&& tensor, std::bitset levels)
+      : levels_(levels), tensor_(tensor) {
+    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
+  }
+
+  Tensor& tensor() { return tensor_; }
+  const Tensor& tensor() const { return tensor_; }
+
+  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
+  //
+  // For example, given:
+  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
+  //
+  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
+  // This is because the size of levels tells us that the first two dimensions
+  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
+  // a physical dim of `n + 2`.
+  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
+  int64_t getPhysicalDim(int64_t logical_dim) const;
+
+  // Returns a VmapPhysicalToLogicalMap object. This can be used for
+  // mapping a physical tensor to a new logical tensor (BatchedTensor)
+  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
+
+  // Maps a logical shape to a physical shape by pre-pending the batch
+  // sizes to the logical shape.
+  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
+  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;
+
+  int64_t numBatchDims() const;
+
+ private:
+  int64_t numLogicalDims() const;
+
+  std::bitset levels_;
+  Tensor tensor_;
+};
+
+// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
+// to a logical one (BatchedTensor). It holds some levels that are used to do the
+// mapping and assumes that the batch dimensions in the physical tensor all
+// occur at the front of the tensor.
+struct TORCH_API VmapPhysicalToLogicalMap {
+  VmapPhysicalToLogicalMap(std::bitset levels): levels_(levels) {}
+
+  // Maps a physical tensor to a new logical tensor (BatchedTensor).
+  // Assumes that all of the "batch dimensions" are at the front
+  // of the physical tensor. For example, given:
+  // - x = rank-4 Tensor with size 2, 3, 5, 7
+  // - levels = (2, 4)
+  // Returns:
+  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
+  Tensor apply(const Tensor& physical_tensor) const;
+
+  // Given a vector of physical tensors,
+  // 1. maps each tensor to a new logical tensor. Assumes that all of the
+  //    "batch dimensions" are at the front of the physical tensors.
+  // 2. stores the new logical tensors back into the passed-in vector. This is
+  //    to avoid additional dynamic allocations.
+  void applyInplace(std::vector& physical_tensors) const;
+
+  std::bitset levels_;
+};
+
+
+} // namespace at::functorch
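
The logical-to-physical mapping that VmapPhysicalView performs (batch dims moved to the front, logical indices and shapes shifted behind them) reduces to a couple of lines. The helpers below are illustrative sketches under that assumption, not the real getPhysicalDim/getPhysicalShape members.

#include <cstdint>
#include <vector>

// A logical dim maps to a physical dim by wrapping, then offsetting past the
// batch dims that sit at the front of the physical tensor.
int64_t physicalDim(int64_t logical_dim, int64_t logical_ndim,
                    int64_t num_batch_dims) {
  if (logical_dim < 0) logical_dim += logical_ndim;  // dim wrapping
  return logical_dim + num_batch_dims;
}

// A logical shape maps to a physical shape by pre-pending the batch sizes.
std::vector<int64_t> physicalShape(const std::vector<int64_t>& batch_sizes,
                                   const std::vector<int64_t>& logical_shape) {
  std::vector<int64_t> result(batch_sizes);          // batch sizes first...
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;                                     // ...then the logical dims
}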
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/Macros.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/Macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..b99be8781c127d5d8c49fdc1b7b80027c9383e48
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/Macros.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#define SINGLE_ARG(...) __VA_ARGS__
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..7a3b9e4df77819d155e3f4682f6b0d90d207a3b2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/PlumbingHelper.h
@@ -0,0 +1,63 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include 
+#include 
+#include 
+
+// NOTE: [vmap plumbing]
+//
+// Here's how "batching rules" work.
+// - we register kernels to the Batched key
+// - these kernels have the same signatures as the original operators.
+//   For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel
+//   must also accept a Tensor
+// - However, it is more natural for users to write a batching rule like the
+//   following: sin_batch_rule(Tensor self, optional self_bdim)
+// - There is some codegenerated layer (the "plumbing") that wraps the user
+//   defined batching rule (e.g. sin_batch_rule) in a kernel that can be
+//   registered to the Batched key.
+//
+// The plumbing is responsible for wrapping a batching rule into a form that may
+// be registered as the kernel for the batched key.
+
+namespace at::functorch {
+
+void vmap_check_escaped(const optional<DynamicLayer>& layer, const char* what);
+
+// Create a BatchedTensor given a tensor, bdim, and level
+TORCH_API Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64_t level);
+
+// Given a Tensor that may or may not be a BatchedTensor, unwrap it.
+// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
+// doesn't match, then this returns (tensor, nullopt).
+// Otherwise, it returns (unwrap(tensor), bdim).
+TORCH_API std::tuple<Tensor, optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
+
+// Creates a vector of BatchedTensor
+TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, optional<int64_t> bdim, int64_t level);
+
+// Returns True if ANY tensor in tensors is batched at level
+TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);
+TORCH_API bool isBatchedAtLevel(const c10::List<c10::optional<Tensor>>& maybe_tensors, int64_t level);
+TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level);
+TORCH_API bool isBatchedAtLevel(const c10::optional<Tensor>& maybe_tensor, int64_t level);
+
+// Convenience helper. Returns true if any tensor is batched at level
+TORCH_API bool areAnyBatchedAtLevel(ArrayRef<optional<Tensor>> maybe_tensors, int64_t level);
+
+inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
+  if (ivalue.isTensor()) {
+    auto maybe_level = maybeCurrentDynamicLayer();
+    TORCH_INTERNAL_ASSERT(maybe_level.has_value());
+    auto current_level = maybe_level->layerId();
+    return isBatchedAtLevel(ivalue.toTensor(), current_level);
+  }
+  // TODO: should really check this
+  return false;
+}
+
+} // namespace at::functorch
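
A rough picture of what the generated "plumbing" wrapper does for a single unary op: unwrap the batched input into (value, bdim), call the user-written batch rule, and rewrap the result at the current level. The sketch below is self-contained and hypothetical (ToyTensor, sin_batch_rule, and sin_plumbing are stand-ins, not the generated code).

#include <cstdint>
#include <optional>
#include <utility>

struct ToyTensor { /* stands in for at::Tensor */ };

// A user-written batch rule: takes the unwrapped value and where its batch
// dim is, returns the result and where the result's batch dim ended up.
std::pair<ToyTensor, std::optional<int64_t>> sin_batch_rule(
    const ToyTensor& self, std::optional<int64_t> self_bdim) {
  // A real rule would compute sin on `self` here.
  return {self, self_bdim};
}

// What the plumbing kernel registered to the Batched key roughly does.
ToyTensor sin_plumbing(const ToyTensor& unwrapped_self,
                       std::optional<int64_t> bdim, int64_t level) {
  auto [res, res_bdim] = sin_batch_rule(unwrapped_self, bdim);
  (void)res_bdim;
  (void)level;  // the real plumbing re-wraps via makeBatched(res, res_bdim, level)
  return res;
}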
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..b99f3f937fa678950e7833c5d617ffc7f1c5dffc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/TensorWrapper.h
@@ -0,0 +1,103 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::functorch {
+
+// NOTE: [functorch's TensorWrapper]
+//
+// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
+// Subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
+// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass.
+//
+// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
+// another Variable.
+//
+// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)).
+// The reason why is so that each TensorWrapper can hold its own AutogradMeta and
+// participate in a **separate** autograd graph.
+//
+// There are alternative designs we could have chosen (e.g. each grad transform
+// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
+// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels)
+// without much modification. Since a TensorWrapper looks like a regular Tensor,
+// the VariableType kernel can pull out the AutogradMeta struct from where it
+// expects and extend the autograd graph.
+
+struct TORCH_API TensorWrapper : public c10::TensorImpl {
+  explicit TensorWrapper(
+      c10::DispatchKeySet key_set,
+      Tensor value,
+      int64_t level,
+      std::shared_ptr is_alive,
+      bool is_immutable = false,  // if true, this came from an operation that aliases an immutable tensor
+      bool use_value_sizes_strides = true);
+
+  void refreshMetadata();
+
+  const Tensor& value() const {
+    return value_;
+  }
+  optional level() const {
+    if (is_alive()) {
+      return level_;
+    }
+    return {};
+  }
+  bool is_immutable() const {
+    return is_immutable_;
+  }
+  bool is_alive() const;
+
+  // Overrides necessary for autograd
+  c10::intrusive_ptr shallow_copy_and_detach(
+    const c10::VariableVersion& version_counter,
+    bool allow_tensor_metadata_change) const override;
+  c10::intrusive_ptr shallow_copy_and_detach(
+      c10::VariableVersion&& version_counter,
+      bool allow_tensor_metadata_change) const override;
+  void shallow_copy_from(const c10::intrusive_ptr& impl) override;
+
+ private:
+  const char* tensorimpl_type_name() const override;
+  Tensor value_;
+  int64_t level_;
+  bool is_immutable_;
+
+  // TensorWrapper receives a boolean flag on whether or not the Grad Interpreter
+  // that created it is still alive or not.
+  // If the Grad Interpreter is no longer alive then it attempts to behave like
+  // a regular Tensor.
+  //
+  // When we exit the level, this wrapper may be marked as "not alive".
+  // Wrappers that are not alive:
+  // 1) May still have autograd metadata on them
+  // 2) Forward dispatches to the underlying value()
+  std::shared_ptr is_alive_;
+};
+
+// There are two variants of makeTensorWrapper: one that accepts a level
+// and one that accepts an Interpreter.
+//
+// The one that accepts a level tries to automatically get the life handle from the
+// interpreter on the DynamicLayerStack.
+// It needs to be used with caution: if the interpreter is not on the
+// DynamicLayerStack, then we won't be able to find the life handle.
+//
+// In practice this isn't a problem: when we're constructing TensorWrapper in
+// Python, the corresponding interpreter is on the stack.
+TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
+TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
+TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
+TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
+TORCH_API void dumpTensorCout(const Tensor& tensor);
+
+} // namespace at::functorch
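
The "alive" behavior described above (report a level while the grad transform is active, act like the plain underlying value once it exits) can be pictured with a tiny stand-in wrapper. ToyWrapper is an illustrative sketch, not the TensorWrapper API.

#include <cstdint>
#include <memory>
#include <optional>

struct ToyWrapper {
  double value;                    // stands in for the wrapped Tensor
  int64_t level;
  std::shared_ptr<bool> is_alive;  // shared with the interpreter that made it

  std::optional<int64_t> current_level() const {
    if (*is_alive) return level;   // still inside the grad transform's scope
    return std::nullopt;           // escaped: behave like a regular value
  }
};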
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h b/MLPY/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h
new file mode 100644
index 0000000000000000000000000000000000000000..8a2539e24faeae1308dc8376bfc3a2b15d438179
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/functorch/VmapInterpreter.h
@@ -0,0 +1,25 @@
+#pragma once
+#include <ATen/functorch/Interpreter.h>
+
+namespace at::functorch {
+
+// This is the interpreter that handles the vmap() transform.
+// See NOTE: [functorch interpreter stack] for more details.
+
+struct VmapInterpreterPtr {
+  explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
+  TransformType key() const { return base_->key(); }
+  int64_t level() const { return base_->level(); }
+  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
+  c10::SymInt batchSize() const {
+    return std::get<VmapInterpreterMeta>(base_->meta()).batchSize_;
+  }
+  RandomnessType randomness() const {
+    return std::get<VmapInterpreterMeta>(base_->meta()).randomness_;
+  }
+ private:
+  const Interpreter* base_;
+};
+
+} // namespace at::functorch
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
new file mode 100644
index 0000000000000000000000000000000000000000..fcd0650a58d6a4d87c3cedf34fccde5c0d5b7e3d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include 
+#include 
+
+// Use of c10::hip namespace here makes hipification easier, because
+// I don't have to also fix namespaces.  Sorry!
+namespace c10 { namespace hip {
+
+// Takes a valid HIPAllocator (of any sort) and turns it into
+// an allocator pretending to be a CUDA allocator.  See
+// Note [Masquerading as CUDA]
+class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
+  Allocator* allocator_;
+public:
+  explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
+    : allocator_(allocator) {}
+  DataPtr allocate(size_t size) override {
+    DataPtr r = allocator_->allocate(size);
+    r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
+    return r;
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return allocator_->raw_deleter();
+  }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    allocator_->copy_data(dest, src, count);
+  }
+};
+
+}} // namespace c10::hip
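
The masquerading trick is a plain decorator: delegate the allocation, then relabel what the result reports. The sketch below shows that shape with invented ToyAllocator/ToyDataPtr types (they are not the c10 Allocator/DataPtr interfaces); memory handling is deliberately simplified.

#include <cstddef>
#include <cstdlib>

enum class ToyDeviceType { CUDA, HIP };

struct ToyDataPtr { void* ptr; ToyDeviceType device; };

struct ToyAllocator {
  virtual ~ToyAllocator() = default;
  virtual ToyDataPtr allocate(std::size_t size) = 0;
};

struct ToyHipAllocator final : ToyAllocator {
  ToyDataPtr allocate(std::size_t size) override {
    return {std::malloc(size), ToyDeviceType::HIP};
  }
};

// Delegates allocation, then relabels the device as CUDA, which is the same
// move HIPAllocatorMasqueradingAsCUDA makes on the DataPtr's device.
struct ToyMasqueradingAllocator final : ToyAllocator {
  ToyAllocator* inner;
  explicit ToyMasqueradingAllocator(ToyAllocator* a) : inner(a) {}
  ToyDataPtr allocate(std::size_t size) override {
    ToyDataPtr r = inner->allocate(size);
    r.device = ToyDeviceType::CUDA;   // masquerade as CUDA
    return r;
  }
};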
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h
new file mode 100644
index 0000000000000000000000000000000000000000..4811b0d5e45e984bea140496c9ae10684d16e040
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+// forward declaration
+class DataPtr;
+namespace hip {
+namespace HIPCachingAllocatorMasqueradingAsCUDA {
+
+C10_HIP_API Allocator* get();
+C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream);
+
+} // namespace HIPCachingAllocatorMasqueradingAsCUDA
+} // namespace hip
+} // namespace c10
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a3992263025cea38f7bd896cfb1efc534ff9ae3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
@@ -0,0 +1,353 @@
+#pragma once
+
+#include 
+
+// The includes of HIPGuard.h
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+// Use of c10::hip namespace here makes hipification easier, because
+// I don't have to also fix namespaces.  Sorry!
+namespace c10 { namespace hip {
+
+// Note [Masquerading as CUDA]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// c10_hip is very easy to understand: it is HIPified from c10_cuda,
+// and anywhere you said CUDA, the source code now says HIP.  HIPified
+// PyTorch is much harder to understand: it is HIPified from regular
+// PyTorch, yes, but NO source-to-source translation from CUDA to
+// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
+// For example, when you use HIPified PyTorch, you say x.cuda() to
+// move a tensor onto ROCm device.  We call this situation "HIP
+// masquerading as CUDA".
+//
+// This leads to a very awkward situation when we want to call c10_hip
+// code from PyTorch, since c10_hip is expecting things to be called
+// HIP, but PyTorch is calling them CUDA (masquerading as HIP).  To
+// fix this impedance mismatch, we have MasqueradingAsCUDA variants
+// for all c10_hip classes.  These translate between the "HIP" and "CUDA
+// masquerading as HIP" worlds.  For example,
+// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
+// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
+// returns CUDA, getDevice() reports the current HIP device as a CUDA
+// device.)
+//
+// We should be able to delete all of these classes entirely once
+// we switch PyTorch to calling a HIP a HIP.
+//
+// When you add a new MasqueradingAsCUDA class/function, you need to
+// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
+//
+//
+//
+// By the way, note that the cpp file associated with this also
+// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
+// this HIP implementation.
+
+struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
+  HIPGuardImplMasqueradingAsCUDA() {}
+  HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
+    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
+  }
+  c10::DeviceType type() const override {
+    return c10::DeviceType::CUDA;
+  }
+  Device exchangeDevice(Device d) const override {
+    TORCH_INTERNAL_ASSERT(d.is_cuda());
+    Device old_device = getDevice();
+    if (old_device.index() != d.index()) {
+      C10_HIP_CHECK(hipSetDevice(d.index()));
+    }
+    return old_device;
+  }
+  Device getDevice() const override {
+    int device;
+    C10_HIP_CHECK(hipGetDevice(&device));
+    return Device(c10::DeviceType::CUDA, device);
+  }
+  void setDevice(Device d) const override {
+    TORCH_INTERNAL_ASSERT(d.is_cuda());
+    C10_HIP_CHECK(hipSetDevice(d.index()));
+  }
+  void uncheckedSetDevice(Device d) const noexcept override {
+    C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
+  }
+  Stream getStream(Device d) const noexcept override {
+    return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
+  }
+  Stream getDefaultStream(Device d) const override {
+    return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
+  }
+  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
+    return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
+  }
+  Stream exchangeStream(Stream s) const noexcept override {
+    HIPStreamMasqueradingAsCUDA cs(s);
+    auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
+    setCurrentHIPStreamMasqueradingAsCUDA(cs);
+    return old_stream.unwrap();
+  }
+  DeviceIndex deviceCount() const noexcept override {
+    int deviceCnt;
+    hipError_t _err;
+    _err = hipGetDeviceCount(&deviceCnt);
+#if defined(USE_ROCM) && (ROCM_VERSION < 50201)
+    if(_err == hipErrorInvalidDevice)
+        return 0;
+#endif
+    if(_err != hipErrorNoDevice && _err != hipSuccess)
+        C10_HIP_CHECK(_err);
+    return deviceCnt;
+  }
+
+  // Event-related functions
+  // Note: hipEventCreateWithFlags should be called on the same device as
+  //  the recording stream's device.
+  void createEvent(
+    hipEvent_t* hip_event,
+    const EventFlag flag) const {
+    // Maps PyTorch's Event::Flag to HIP flag
+    auto hip_flag = hipEventDefault;
+    switch (flag) {
+      case EventFlag::PYTORCH_DEFAULT:
+      case EventFlag::HIP_EVENT_DISABLE_TIMING:
+        hip_flag = hipEventDisableTiming;
+        break;
+      case EventFlag::BACKEND_DEFAULT:
+      case EventFlag::HIP_EVENT_DEFAULT:
+        hip_flag = hipEventDefault;
+        break;
+      default:
+        TORCH_CHECK(false, "HIP event received unknown flag");
+    }
+
+    C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
+  }
+
+  void destroyEvent(
+    void* event,
+    const DeviceIndex device_index) const noexcept override {
+    if (!event) return;
+    auto hip_event = static_cast<hipEvent_t>(event);
+    int orig_device;
+    C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
+    C10_HIP_CHECK_WARN(hipSetDevice(device_index));
+    C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
+    C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
+  }
+
+  void record(void** event,
+    const Stream& stream,
+    const DeviceIndex device_index,
+    const EventFlag flag) const override {
+    TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
+      "Event device index ",
+      device_index,
+      " does not match recording stream's device index ",
+      stream.device_index(),
+      ".");
+
+    hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
+    HIPStreamMasqueradingAsCUDA hip_stream{stream};
+
+    // Moves to stream's device to record
+    const auto orig_device = getDevice();
+    setDevice(stream.device());
+
+    // Creates the event (lazily)
+    if (!hip_event) createEvent(&hip_event, flag);
+    C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
+    // Makes the void* point to the (possibly just allocated) HIP event
+    *event = hip_event;
+
+    // Resets device
+    setDevice(orig_device);
+  }
+
+  void block(
+    void* event,
+    const Stream& stream) const override {
+    if (!event) return;
+    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
+    HIPStreamMasqueradingAsCUDA hip_stream{stream};
+    const auto orig_device = getDevice();
+    setDevice(stream.device());
+    C10_HIP_CHECK(hipStreamWaitEvent(
+      hip_stream,
+      hip_event,
+      /*flags (must be zero)=*/ 0));
+    setDevice(orig_device);
+  }
+
+  bool queryEvent(void* event) const override {
+    if (!event) return true;
+    hipEvent_t hip_event = static_cast<hipEvent_t>(event);
+    const hipError_t err = hipEventQuery(hip_event);
+    if (err != hipErrorNotReady) C10_HIP_CHECK(err);
+    else {
+      // ignore and clear the error if not ready
+      (void)hipGetLastError();
+    }
+    return (err == hipSuccess);
+  }
+
+  // Stream-related functions
+  bool queryStream(const Stream& stream) const override {
+    HIPStreamMasqueradingAsCUDA hip_stream{stream};
+    return hip_stream.query();
+  }
+
+  void synchronizeStream(const Stream& stream) const override {
+    HIPStreamMasqueradingAsCUDA hip_stream{stream};
+    hip_stream.synchronize();
+  }
+
+  void recordDataPtrOnStream(
+    const c10::DataPtr& data_ptr,
+    const Stream& stream) const override {
+    HIPStreamMasqueradingAsCUDA hip_stream{stream};
+    HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
+  }
+};
+
+// All of the guards which have HIPGuardImpl burned in need to also have
+// variants using HIPGuardImplMasqueradingAsCUDA.
+
+/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
+/// the correct InlineDeviceGuard burned in.  Sorry about the
+/// copy-pasting.
+
+struct HIPGuardMasqueradingAsCUDA {
+  explicit HIPGuardMasqueradingAsCUDA() = delete;
+  explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
+  explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
+
+  HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
+  HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
+  HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
+  HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
+
+  void set_device(Device device) { guard_.set_device(device); }
+  void reset_device(Device device) { guard_.reset_device(device); }
+  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+  Device original_device() const { return guard_.original_device(); }
+  Device current_device() const { return guard_.current_device(); }
+
+ private:
+  c10::impl::InlineDeviceGuard guard_;
+};
+
+struct OptionalHIPGuardMasqueradingAsCUDA {
+  explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
+  explicit OptionalHIPGuardMasqueradingAsCUDA(optional device_opt) : guard_(device_opt) {}
+  explicit OptionalHIPGuardMasqueradingAsCUDA(optional device_index_opt) : guard_(device_index_opt) {}
+
+  OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+
+  void set_device(Device device) { guard_.set_device(device); }
+  void reset_device(Device device) { guard_.reset_device(device); }
+  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+  optional original_device() const { return guard_.original_device(); }
+  optional current_device() const { return guard_.current_device(); }
+  void reset() { guard_.reset(); }
+
+private:
+  c10::impl::InlineOptionalDeviceGuard guard_;
+};
+
+struct HIPStreamGuardMasqueradingAsCUDA {
+  explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
+  explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+  HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+  HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+  HIPStreamMasqueradingAsCUDA original_stream() const {
+    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
+  }
+  HIPStreamMasqueradingAsCUDA current_stream() const {
+    return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
+  }
+
+  Device current_device() const { return guard_.current_device(); }
+  Device original_device() const { return guard_.original_device(); }
+
+private:
+  c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct OptionalHIPStreamGuardMasqueradingAsCUDA {
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional<Stream> stream_opt) : guard_(stream_opt) {}
+
+  OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+  optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
+    auto r = guard_.original_stream();
+    if (r.has_value()) {
+      return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
+    } else {
+      return nullopt;
+    }
+  }
+
+  optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
+    auto r = guard_.current_stream();
+    if (r.has_value()) {
+      return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
+    } else {
+      return nullopt;
+    }
+  }
+
+  void reset() { guard_.reset(); }
+
+private:
+  c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct HIPMultiStreamGuardMasqueradingAsCUDA {
+  explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
+    : guard_(unwrapStreams(streams)) {}
+
+  HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
+  HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+private:
+  c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+
+  static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
+    std::vector<Stream> streams;
+    streams.reserve(hipStreams.size());
+    for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
+      streams.push_back(hipStream);
+    }
+    return streams;
+  }
+};
+
+}} // namespace c10::hip
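The guard variants in this header mirror the CUDA guard API one-for-one, so HIP-backed code can be driven through call sites that believe they are talking to CUDA. Below is a minimal, hedged sketch of how the device and stream guards might be used in a ROCm build of PyTorch; the surrounding helper function and its arguments are hypothetical, only the guard types come from the header above.

    #include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
    #include <functional>

    // Hypothetical helper: run a callback with both the device and the stream
    // pinned for the duration of the call, then restore the previous state.
    void withDeviceAndStream(c10::Device cuda_like_device,  // DeviceType::CUDA, backed by HIP
                             c10::Stream cuda_like_stream,  // a "CUDA" stream wrapping a HIP stream
                             const std::function<void()>& fn) {
      // RAII: switches the current HIP device while pretending to be CUDA.
      c10::hip::HIPGuardMasqueradingAsCUDA device_guard(cuda_like_device);
      // RAII: swaps the current stream for the stream's device.
      c10::hip::HIPStreamGuardMasqueradingAsCUDA stream_guard(cuda_like_stream);
      fn();
      // Both guards restore the original device/stream on scope exit.
    }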
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h
new file mode 100644
index 0000000000000000000000000000000000000000..7958146b81edcedff9facecd94d69cdb9011ecbd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h
@@ -0,0 +1,130 @@
+#pragma once
+
+#include <c10/hip/HIPStream.h>
+
+// Use of c10::hip namespace here makes hipification easier, because
+// I don't have to also fix namespaces.  Sorry!
+namespace c10 { namespace hip {
+
+// See Note [Masquerading as CUDA] for motivation
+
+class HIPStreamMasqueradingAsCUDA {
+public:
+
+  enum Unchecked { UNCHECKED };
+
+  explicit HIPStreamMasqueradingAsCUDA(Stream stream)
+    : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
+    // We did the coercion unchecked; check that it was right.
+    TORCH_CHECK(stream.device().is_cuda() /* !!! */);
+  }
+
+  explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
+    // Unsafely coerce the "CUDA" stream into a HIP stream
+    : stream_(
+        HIPStream(
+          Stream(
+            Stream::UNSAFE,
+            Device(c10::DeviceType::HIP, stream.device_index()),
+            stream.id())
+        )
+      ) {}
+
+  // New constructor, just for this.  Does NOT coerce.
+  explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
+
+  bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
+    return stream_ == other.stream_;
+  }
+
+  bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
+    return stream_ != other.stream_;
+  }
+
+  operator hipStream_t() const { return stream_.stream(); }
+
+  operator Stream() const {
+    // Unsafely coerce HIP stream into a "CUDA" stream
+    return Stream(Stream::UNSAFE, device(), id());
+  }
+
+  DeviceIndex device_index() const { return stream_.device_index(); }
+
+  // Unsafely coerce HIP device into CUDA device
+  c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
+
+  Device device() const {
+    // Unsafely coerce HIP device into CUDA device
+    return Device(c10::DeviceType::CUDA, stream_.device_index());
+  }
+
+  StreamId id() const        { return stream_.id(); }
+  bool query() const         { return stream_.query(); }
+  void synchronize() const   { stream_.synchronize(); }
+  int priority() const       { return stream_.priority(); }
+  hipStream_t stream() const { return stream_.stream(); }
+
+  Stream unwrap() const {
+    // Unsafely coerce HIP stream into "CUDA" stream
+    return Stream(Stream::UNSAFE, device(), id());
+  }
+
+  c10::StreamData3 pack3() const noexcept {
+    // Unsafely coerce HIP stream into "CUDA" stream before packing
+    return unwrap().pack3();
+  }
+
+  static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
+                                             DeviceIndex device_index,
+                                             c10::DeviceType device_type) {
+    // NB: constructor manages CUDA->HIP translation for us
+    return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
+        stream_id, device_index, device_type));
+  }
+
+  static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
+
+  // New method, gets the underlying HIPStream
+  HIPStream hip_stream() const { return stream_; }
+
+private:
+  HIPStream stream_;
+};
+
+HIPStreamMasqueradingAsCUDA
+inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
+  return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
+}
+
+HIPStreamMasqueradingAsCUDA
+inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
+  return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
+}
+
+inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
+  return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
+}
+
+inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
+  return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
+}
+
+inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
+  setCurrentHIPStream(stream.hip_stream());
+}
+
+inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
+  stream << s.hip_stream() << " (masquerading as CUDA)";
+  return stream;
+}
+
+}} // namespace c10::hip
+
+namespace std {
+  template <>
+  struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
+    size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
+      return std::hash<c10::Stream>{}(s.unwrap());
+    }
+  };
+} // namespace std
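To make the coercion concrete, here is a small hedged sketch (again assuming a ROCm build where c10::hip and these helpers exist) showing that the wrapper deliberately reports DeviceType::CUDA while still handing back a genuine hipStream_t.

    #include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
    #include <iostream>

    void inspectCurrentStream() {
      using namespace c10::hip;
      // Current HIP stream for device 0, wrapped so it looks like a CUDA stream.
      HIPStreamMasqueradingAsCUDA s = getCurrentHIPStreamMasqueradingAsCUDA(0);
      // device_type() lies on purpose: it always reports CUDA.
      std::cout << (s.device_type() == c10::DeviceType::CUDA) << "\n";  // prints 1
      // The underlying handle is still a real hipStream_t usable with HIP APIs.
      hipStream_t raw = s.stream();
      (void)raw;
    }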
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/jit_macros.h b/MLPY/Lib/site-packages/torch/include/ATen/jit_macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..ac6d0432425f11f761dcf26de7b0402a8daae5ac
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/jit_macros.h
@@ -0,0 +1,7 @@
+#pragma once
+#include <ATen/cuda/CUDAConfig.h>
+#include <string>
+
+// AT_USE_JITERATOR(), controls whether we jit some elementwise kernels
+#define AT_USE_JITERATOR() true
+#define jiterator_stringify(...) std::string(#__VA_ARGS__);
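For reference, jiterator_stringify simply turns its arguments into a std::string via preprocessor stringification. The standalone sketch below reproduces the macro to show roughly what an invocation expands to; the kernel name is made up.

    #include <string>

    #define jiterator_stringify(...) std::string(#__VA_ARGS__);

    // The macro supplies the trailing semicolon itself, so none is written here.
    // `kernel_src` holds the literal text of the code, which can later be handed
    // to a runtime compiler such as NVRTC/hipRTC.
    const std::string kernel_src = jiterator_stringify(
        template <typename T> T square_ish(T x) { return x * x; }
    )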
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/jiterator_macros.h b/MLPY/Lib/site-packages/torch/include/ATen/jiterator_macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..ccde91c67237707108eb61cc0eea38d0768aa2b5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/jiterator_macros.h
@@ -0,0 +1,38 @@
+#pragma once
+#include <c10/macros/Macros.h>
+#include <string>
+
+#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE
+#if defined(_MSC_VER) && defined(__CUDACC__)
+// NVRTC on Windows errors if __host__ __device__ attribute is
+// present on kernel.
+// error: attribute "__host__" does not apply here
+// error: attribute "__device__" does not apply here
+#define JITERATOR_HOST_DEVICE
+#endif
+
+// jiterator_also_stringify_as macro is used to define code (for CPU/ROCm)
+// and generate code string for `jiterator` (only when compiling for CUDA).
+// Usage :
+//      jiterator_also_stringify_as(
+//          jiterator_code(template <typename T> T identity(T x) { return x; }),
+//          identity_string);
+// This will define the template `identity` as present in code and
+// also define `std::string identity_string` with the code as the string
+// if this is being compiled for CUDA.
+
+// `jiterator_code` macro is to deal with `,` in the kernel code.
+// These `,`s confuse the preprocessor into thinking we are passing
+// multiple arguments to the macro.
+#define jiterator_code(...) __VA_ARGS__
+#if defined(__CUDACC__) || defined(__HIPCC__)
+// CPU and CUDA and ROCm case
+#define stringify_code(...) #__VA_ARGS__
+#define jiterator_also_stringify_as(code, str_name) \
+  code /* define the function */                    \
+      const std::string str_name = std::string(stringify_code(code));
+#else
+// CPU only or CPU and ROCm case
+// Only needs the function
+#define jiterator_also_stringify_as(code, str_name) code
+#endif
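As a rough illustration of the two branches above (a sketch, not verbatim preprocessor output): under nvcc/hipcc the macro both defines the function and captures its source as a string, while a plain CPU build keeps only the function.

    #include <string>

    // Hypothetical invocation:
    //   jiterator_also_stringify_as(
    //       jiterator_code(template <typename T> T identity(T x) { return x; }),
    //       identity_string);
    //
    // When __CUDACC__ or __HIPCC__ is defined, that is roughly equivalent to:
    template <typename T> T identity(T x) { return x; }
    const std::string identity_string =
        std::string("template <typename T> T identity(T x) { return x; }");
    // In a plain CPU build only the function definition is emitted.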
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h
new file mode 100644
index 0000000000000000000000000000000000000000..b66cb9b8720260f7e8faeca08cdfbc4e8f704100
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Descriptors.h
@@ -0,0 +1,146 @@
+#pragma once
+
+#include <ATen/miopen/Exceptions.h>
+
+#include <ATen/miopen/miopen-wrapper.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/TensorUtils.h>
+
+namespace at { namespace native {
+
+inline int dataSize(miopenDataType_t dataType)
+{
+  switch (dataType) {
+    case miopenHalf: return 2;
+    case miopenFloat: return 4;
+    case miopenBFloat16: return 2;
+    default: return 8;
+  }
+}
+
+template <typename T, miopenStatus_t (*dtor)(T*)>
+struct DescriptorDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      MIOPEN_CHECK(dtor(x));
+    }
+  }
+};
+
+// A generic class for wrapping MIOpen descriptor types.  All you need
+// is to give the underlying type the Descriptor_t points to (usually,
+// if it's miopenTensorDescriptor_t it points to miopenTensorStruct),
+// the constructor and the destructor.  Subclasses are responsible
+// for defining a set() function to actually set the descriptor.
+//
+// Descriptors default construct to a nullptr, and have a descriptor
+// initialized the first time you call set() or any other initializing
+// function.
+template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
+class Descriptor
+{
+public:
+  // Use desc() to access the underlying descriptor pointer in
+  // a read-only fashion.  Most client code should use this.
+  // If the descriptor was never initialized, this will return
+  // nullptr.
+  T* desc() const { return desc_.get(); }
+  T* desc() { return desc_.get(); }
+
+  // Use mut_desc() to access the underlying descriptor pointer
+  // if you intend to modify what it points to (e.g., using
+  // miopenSetFooDescriptor).  This will ensure that the descriptor
+  // is initialized.  Code in this file will use this function.
+  T* mut_desc() { init(); return desc_.get(); }
+protected:
+  void init() {
+    if (desc_ == nullptr) {
+      T* raw_desc;
+      MIOPEN_CHECK(ctor(&raw_desc));
+      desc_.reset(raw_desc);
+    }
+  }
+private:
+  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
+};
+
+class TensorDescriptor
+  : public Descriptor<miopenTensorDescriptor, &miopenCreateTensorDescriptor, &miopenDestroyTensorDescriptor>
+{
+public:
+  TensorDescriptor() {}
+  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
+    set(t, pad);
+  }
+
+  void set(const at::Tensor &t, size_t pad = 0);
+  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
+
+  void print();
+
+private:
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  }
+};
+
+std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
+
+class FilterDescriptor
+  : public Descriptor<miopenTensorDescriptor, &miopenCreateTensorDescriptor, &miopenDestroyTensorDescriptor>
+{
+ public:
+  void set(const at::Tensor &t, int64_t pad = 0) {
+    set(t, at::MemoryFormat::Contiguous, pad);
+  }
+
+  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
+
+private:
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  }
+};
+
+struct ConvolutionDescriptor
+  : public Descriptor<miopenConvolutionDescriptor, &miopenCreateConvolutionDescriptor, &miopenDestroyConvolutionDescriptor>
+{
+  void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode,  int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool deterministic) {
+    MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
+    MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
+    MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 1 : 0));
+  }
+};
+
+
+struct RNNDescriptor
+  : public Descriptor<miopenRNNDescriptor, &miopenCreateRNNDescriptor, &miopenDestroyRNNDescriptor>
+{
+    void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
+              miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
+      MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
+    }
+};
+
+union Constant
+{
+  float f;
+  double d;
+  Constant(miopenDataType_t dataType, double value) {
+    if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
+      f = static_cast<float>(value);
+    } else {
+      d = value;
+    }
+  }
+};
+
+}}  // namespace
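A hedged sketch of how these wrappers are typically driven (assuming a ROCm build where the MIOpen headers are available); the input tensor is illustrative only.

    #include <ATen/ATen.h>
    #include <ATen/miopen/Descriptors.h>

    void describeInput(const at::Tensor& hip_tensor) {
      // TensorDescriptor::set() reads dtype/sizes/strides from the tensor and
      // forwards them to miopenSetTensorDescriptor through mut_desc().
      at::native::TensorDescriptor desc;
      desc.set(hip_tensor, /*pad=*/4);  // pad the dimension count, cuDNN-descriptor style
      // desc.desc() now returns an initialized descriptor that can be passed to
      // MIOpen convolution/RNN calls; it is destroyed automatically by the deleter.
    }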
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h
new file mode 100644
index 0000000000000000000000000000000000000000..044ae3222aa83e512c796fc2b903b2a111285015
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Exceptions.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <ATen/miopen/miopen-wrapper.h>
+#include <string>
+#include <stdexcept>
+#include <sstream>
+
+namespace at { namespace native {
+
+class miopen_exception : public std::runtime_error {
+public:
+  miopenStatus_t status;
+  miopen_exception(miopenStatus_t status, const char* msg)
+      : std::runtime_error(msg)
+      , status(status) {}
+  miopen_exception(miopenStatus_t status, const std::string& msg)
+      : std::runtime_error(msg)
+      , status(status) {}
+};
+
+inline void MIOPEN_CHECK(miopenStatus_t status)
+{
+  if (status != miopenStatusSuccess) {
+    if (status == miopenStatusNotImplemented) {
+        throw miopen_exception(status, std::string(miopenGetErrorString(status)) +
+                ". This error may appear if you passed in a non-contiguous input.");
+    }
+    throw miopen_exception(status, miopenGetErrorString(status));
+  }
+}
+
+inline void HIP_CHECK(hipError_t error)
+{
+  if (error != hipSuccess) {
+    std::string msg("HIP error: ");
+    msg += hipGetErrorString(error);
+    throw std::runtime_error(msg);
+  }
+}
+
+}} // namespace at::native
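A small sketch of the intended call pattern: every raw MIOpen call is wrapped so that a non-success status surfaces as an exception. The surrounding function is hypothetical; the MIOpen descriptor API calls are real.

    #include <ATen/miopen/Exceptions.h>

    void createAndDestroyDescriptor() {
      miopenTensorDescriptor_t raw = nullptr;
      // A failing status becomes a miopen_exception carrying both the status
      // code and the miopenGetErrorString() text.
      at::native::MIOPEN_CHECK(miopenCreateTensorDescriptor(&raw));
      at::native::MIOPEN_CHECK(miopenDestroyTensorDescriptor(raw));
    }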
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/Handle.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..8307827d5bfd33c4173c8e14d6e91031e6f1adf9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Handle.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <ATen/miopen/miopen-wrapper.h>
+
+namespace at { namespace native {
+
+miopenHandle_t getMiopenHandle();
+
+}} // namespace
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/Types.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Types.h
new file mode 100644
index 0000000000000000000000000000000000000000..74121cbb9e62f9f974db4fd43575554c89ac8df0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Types.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/miopen/miopen-wrapper.h>
+#include <ATen/core/Tensor.h>
+
+namespace at { namespace native {
+
+miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
+
+int64_t miopen_version();
+
+}}  // namespace at::miopen
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/Utils.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..30f8e228165664c6e358838df3c26d4074ccd173
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/Utils.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <ATen/miopen/miopen-wrapper.h>
+#include <ATen/miopen/Handle.h>
+
+namespace at { namespace native {
+
+// This function makes tensors which have zero stride contiguous, by
+// setting the strides to 1.
+inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
+  for (auto s : t.strides()) {
+    if (s == 0) return t.contiguous();
+  }
+  return t;
+}
+
+}}
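For clarity, a short usage sketch: a tensor produced by expand() has stride 0 in the broadcast dimension, which MIOpen cannot consume, so the helper materializes a contiguous copy only in that case. The tensor shapes are illustrative.

    #include <ATen/ATen.h>
    #include <ATen/miopen/Utils.h>

    void example() {
      at::Tensor base = at::randn({1, 3});
      at::Tensor broadcast = base.expand({4, 3});  // strides {0, 1}: contains a zero stride
      at::Tensor safe = at::native::contiguousIfZeroInStrides(broadcast);
      // `safe` is a contiguous copy; a tensor with no zero strides is returned as-is.
    }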
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h b/MLPY/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..202e189ef6db3456c3a46e88f9cf753459e2ae0d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/miopen/miopen-wrapper.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include <miopen/miopen.h>
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..0256d2f0e25a5ee8ceb3187f64c4bc1d58043708
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/EmptyTensor.h
@@ -0,0 +1,29 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include <ATen/core/TensorBase.h>
+
+namespace at::detail {
+
+C10_EXPORT TensorBase empty_mps(
+    IntArrayRef size,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt,
+    c10::optional<Device> device_opt,
+    c10::optional<bool> pin_memory_opt,
+    c10::optional<c10::MemoryFormat> memory_format_opt);
+C10_EXPORT TensorBase empty_mps(
+    IntArrayRef size, const TensorOptions &options);
+
+C10_EXPORT TensorBase empty_strided_mps(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    c10::optional<Device> device_opt);
+
+C10_EXPORT TensorBase empty_strided_mps(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions &options);
+
+} // namespace at::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..d52c90e71f8b2ea7febcf8aa79a2a4ea9261c01b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h
@@ -0,0 +1,630 @@
+#pragma once
+
+namespace at::mps {
+
+static const char * indexing_metal_shaders = R"INDEX_METAL(
+#include <metal_stdlib>
+#include <metal_atomic>
+
+using namespace metal;
+
+#if __METAL_VERSION__ < 300
+struct IndexAB {
+    // Allow up to 16 indices
+    metal::array<constant void *, 16>  indexArray [[ id(0) ]];
+};
+#else
+struct IndexAB {
+    constant int64_t* indexArray;
+};
+
+#endif
+
+template<typename T, typename OffsetsT>
+kernel void index_select(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant OffsetsT * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
+    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
+    constant int64_t * index_strides = (constant int64_t *)indexStrides;
+    int64_t offset = 0;
+    for (uint32_t i = 0; i < num_indices; i++) {
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
+        if (index < 0) {
+            index += index_sizes[i];
+        }
+        offset += index * index_strides[i];
+     }
+    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
+    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y + offset);
+    *out = *in;
+}
+
+template<typename T, typename OffsetsT>
+void index_put_impl(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB,
+#else
+    constant IndexAB  & indexAB,
+#endif
+    constant int64_t  * index_sizes,
+    constant int64_t  * index_strides,
+    constant OffsetsT * offsets,
+    constant void     * inputData,
+    device   void     * outputData,
+    constant uint32_t & num_indices,
+    uint thread_index) {
+    int64_t offset = 0;
+    for (uint32_t i = 0; i < num_indices; i++) {
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
+
+        if (index < 0) {
+            index += index_sizes[i];
+        }
+        offset += index * index_strides[i];
+    }
+    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
+    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y);
+    *out = *in;
+}
+
+template<typename T, typename OffsetsT>
+kernel void index_put_serial(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant OffsetsT * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    constant uint     * numIters          [[buffer(7)]],
+    uint thread_index [[thread_position_in_grid]]) {
+
+    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
+    constant int64_t * index_strides = (constant int64_t *)indexStrides;
+
+    for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
+        index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
+    }
+}
+
+template<typename T, typename OffsetsT>
+kernel void index_put(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant OffsetsT * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
+
+    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
+    constant int64_t * index_strides = (constant int64_t *)indexStrides;
+    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
+}
+
+#if __METAL_VERSION__ < 300
+#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
+template                                                                           \
+[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]               \
+kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                             \
+    constant IndexAB & indexAB           [[buffer(0)]],                            \
+    constant void    * indexSizes        [[buffer(1)]],                            \
+    constant void    * indexStrides      [[buffer(2)]],                            \
+    constant IDX_DTYPE   * offsets           [[buffer(3)]],                        \
+    constant void    * inputData         [[buffer(4)]],                            \
+    device   void    * outputData        [[buffer(5)]],                            \
+    constant uint32_t & num_indices      [[buffer(6)]],                            \
+    uint thread_index [[thread_position_in_grid]]);
+#else
+#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
+template                                                                           \
+[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]               \
+kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                             \
+    constant IndexAB * indexAB           [[buffer(0)]],                            \
+    constant void    * indexSizes        [[buffer(1)]],                            \
+    constant void    * indexStrides      [[buffer(2)]],                            \
+    constant IDX_DTYPE   * offsets           [[buffer(3)]],                        \
+    constant void    * inputData         [[buffer(4)]],                            \
+    device   void    * outputData        [[buffer(5)]],                            \
+    constant uint32_t & num_indices      [[buffer(6)]],                            \
+    uint thread_index [[thread_position_in_grid]]);
+#endif
+
+#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)     \
+    REGISTER_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);     \
+    REGISTER_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3);    \
+    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);     \
+    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3);    \
+    REGISTER_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);     \
+    REGISTER_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3);    \
+    REGISTER_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);     \
+    REGISTER_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);
+
+REGISTER_INDEX_OP_ALL_DTYPES(select);
+REGISTER_INDEX_OP_ALL_DTYPES(put);
+
+#if __METAL_VERSION__ < 300
+#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
+template                                                                                           \
+[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                               \
+kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                             \
+    constant IndexAB   & indexAB           [[buffer(0)]],                                          \
+    constant void      * indexSizes        [[buffer(1)]],                                          \
+    constant void      * indexStrides      [[buffer(2)]],                                          \
+    constant IDX_DTYPE * offsets           [[buffer(3)]],                                          \
+    constant void      * inputData         [[buffer(4)]],                                          \
+    device   void      * outputData        [[buffer(5)]],                                          \
+    constant uint32_t  & num_indices       [[buffer(6)]],                                          \
+    constant uint      * numIters          [[buffer(7)]],                                          \
+    uint thread_index [[thread_position_in_grid]]);
+#else
+#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
+template                                                                                           \
+[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                               \
+kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                             \
+    constant IndexAB   * indexAB           [[buffer(0)]],                                          \
+    constant void      * indexSizes        [[buffer(1)]],                                          \
+    constant void      * indexStrides      [[buffer(2)]],                                          \
+    constant IDX_DTYPE * offsets           [[buffer(3)]],                                          \
+    constant void      * inputData         [[buffer(4)]],                                          \
+    device   void      * outputData        [[buffer(5)]],                                          \
+    constant uint32_t  & num_indices       [[buffer(6)]],                                          \
+    constant uint      * numIters          [[buffer(7)]],                                          \
+    uint thread_index [[thread_position_in_grid]]);
+#endif
+
+#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)                   \
+    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);     \
+    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3);    \
+    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);     \
+    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3);    \
+    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);     \
+    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3);    \
+    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);     \
+    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);
+
+REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
+
+template <typename StridesT, typename DataT>
+kernel void kernel_index_offsets(constant StridesT * strides         [[buffer(0)]],
+                                device DataT      * data_offsets    [[buffer(1)]],
+                                constant uint     * iter_shape      [[buffer(2)]],
+                                constant uint     & num_dimensions  [[buffer(3)]],
+                                uint thread_index [[thread_position_in_grid]]) {
+    data_offsets[thread_index] = 0;
+    uint32_t idx = thread_index;
+    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
+        uint32_t remainder = idx % iter_shape[dim];
+        idx /= iter_shape[dim];
+
+        data_offsets[thread_index] += remainder * DataT(strides[dim]);
+    }
+}
+
+template
+[[host_name("kernel_index_offsets_32")]]
+kernel void kernel_index_offsets<packed_uint3, uint3>(
+                constant packed_uint3 * strides         [[buffer(0)]],
+                device uint3          * data_offsets    [[buffer(1)]],
+                constant uint         * iter_shape      [[buffer(2)]],
+                constant uint         & num_dimensions  [[buffer(3)]],
+                uint thread_index [[thread_position_in_grid]]);
+
+template
+[[host_name("kernel_index_offsets_64")]]
+kernel void kernel_index_offsets<packed_uint3, ulong3>(
+                constant packed_uint3 * strides         [[buffer(0)]],
+                device ulong3          * data_offsets    [[buffer(1)]],
+                constant uint         * iter_shape      [[buffer(2)]],
+                constant uint         & num_dimensions  [[buffer(3)]],
+                uint thread_index [[thread_position_in_grid]]);
+
+template<typename T, typename E, typename OffsetsT>
+kernel void index_put_accumulate_native_dtypes(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB     [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB     [[buffer(0)]],
+#endif
+    constant void     * indexSizes   [[buffer(1)]],
+    constant void     * indexStrides [[buffer(2)]],
+    constant OffsetsT * offsets      [[buffer(3)]],
+    constant void     * inputData    [[buffer(4)]],
+    device void       * outputData   [[buffer(5)]],
+    constant uint32_t & num_indices  [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
+    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
+    constant int64_t * index_strides = (constant int64_t *)indexStrides;
+    int64_t offset = 0;
+    for (uint32_t i = 0; i < num_indices; i++) {
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
+        if (index < 0) {
+            index += index_sizes[i];
+        }
+        offset += index * index_strides[i];
+    }
+    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
+    constant E * in  = (constant E*)((constant char*)inputData  + offsets[thread_index].y);
+    atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
+}
+
+template <typename T>
+__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
+    device atomic_uint* uintAddr = (device atomic_uint*)addr;
+    uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
+    T updated = as_type<T>(expected) + value;
+    while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
+        updated = as_type<T>(expected) + value;
+    }
+}
+
+template<typename T, typename OffsetsT>
+kernel void atomic_index_put_accumulate(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB           [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB           [[buffer(0)]],
+#endif
+    constant void     * indexSizes        [[buffer(1)]],
+    constant void     * indexStrides      [[buffer(2)]],
+    constant OffsetsT * offsets           [[buffer(3)]],
+    constant void     * inputData         [[buffer(4)]],
+    device   void     * outputData        [[buffer(5)]],
+    constant uint32_t & num_indices       [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]) {
+    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
+    constant int64_t * index_strides = (constant int64_t *)indexStrides;
+    int64_t offset = 0;
+    for (uint32_t i = 0; i < num_indices; i++) {
+#if __METAL_VERSION__ >= 300
+        constant int64_t* indexArray = indexAB[i].indexArray;
+#else
+        constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
+#endif
+        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
+        if (index < 0) {
+            index += index_sizes[i];
+        }
+        offset += index * index_strides[i];
+    }
+    device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
+    constant T  * in  = (constant T*)((constant char*)inputData + offsets[thread_index].y);
+    atomic_fetch_add_relaxed(out, *in);
+}
+
+template
+[[host_name("index_put_accumulate_32bit_float_idx32")]]
+kernel void atomic_index_put_accumulate<float, uint3>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB     [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB     [[buffer(0)]],
+#endif
+    constant void     * indexSizes   [[buffer(1)]],
+    constant void     * indexStrides [[buffer(2)]],
+    constant uint3    * offsets      [[buffer(3)]],
+    constant void     * inputData    [[buffer(4)]],
+    device   void     * outputData   [[buffer(5)]],
+    constant uint32_t & num_indices  [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
+
+template
+[[host_name("index_put_accumulate_32bit_float_idx64")]]
+kernel void atomic_index_put_accumulate<float, ulong3>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB     [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB     [[buffer(0)]],
+#endif
+    constant void     * indexSizes   [[buffer(1)]],
+    constant void     * indexStrides [[buffer(2)]],
+    constant ulong3   * offsets      [[buffer(3)]],
+    constant void     * inputData    [[buffer(4)]],
+    device   void     * outputData   [[buffer(5)]],
+    constant uint32_t & num_indices  [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
+
+template
+[[host_name("index_put_accumulate_32bit_int_idx32")]]
+kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB     [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB     [[buffer(0)]],
+#endif
+    constant void     * indexSizes   [[buffer(1)]],
+    constant void     * indexStrides [[buffer(2)]],
+    constant uint3    * offsets      [[buffer(3)]],
+    constant void     * inputData    [[buffer(4)]],
+    device   void     * outputData   [[buffer(5)]],
+    constant uint32_t & num_indices [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
+
+template
+[[host_name("index_put_accumulate_32bit_int_idx64")]]
+kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
+#if __METAL_VERSION__ >= 300
+    constant IndexAB  * indexAB     [[buffer(0)]],
+#else
+    constant IndexAB  & indexAB     [[buffer(0)]],
+#endif
+    constant void     * indexSizes   [[buffer(1)]],
+    constant void     * indexStrides [[buffer(2)]],
+    constant ulong3   * offsets      [[buffer(3)]],
+    constant void     * inputData    [[buffer(4)]],
+    device   void     * outputData   [[buffer(5)]],
+    constant uint32_t & num_indices [[buffer(6)]],
+    uint thread_index [[thread_position_in_grid]]);
+)INDEX_METAL";
+
+static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
+struct __attribute__ ((packed)) packed_uint5{{
+  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
+}};
+
+template<typename Y, typename X>
+Y cast(const X x);
+
+template<>
+{1} cast<{1}, {0}>(const {0} x) {{
+ return {2};
+}}
+
+kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint5 & size   [[buffer(2)]],
+                             constant packed_uint5 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint5 local_index;
+    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
+    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
+    local_index.z = linear_index / (size.u * size.w) % size.z;
+    local_index.w = linear_index / size.u % size.w;
+    local_index.u = linear_index % size.u;
+
+    packed_uint5 strided_index;
+    strided_index.x = local_index.x * stride.x;
+    strided_index.y = local_index.y * stride.y;
+    strided_index.z = local_index.z * stride.z;
+    strided_index.w = local_index.w * stride.w;
+    strided_index.u = local_index.u * stride.u;
+
+    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint4 & size   [[buffer(2)]],
+                             constant packed_uint4 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint4 local_index;
+    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
+    local_index.y = linear_index / (size[3] * size[2]) % size[1];
+    local_index.z = linear_index / size[3] % size[2];
+    local_index.w = linear_index % size[3];
+
+    const packed_uint4 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint3 & size   [[buffer(2)]],
+                             constant packed_uint3 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint3 local_index;
+    local_index.x = linear_index / (size[2] * size[1]) % size[0];
+    local_index.y = linear_index / size[2] % size[1];
+    local_index.z = linear_index % size[2];
+
+    const packed_uint3 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant packed_uint2 & size   [[buffer(2)]],
+                             constant packed_uint2 & stride [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint2 local_index;
+    local_index.x = linear_index / size[1] % size[0];
+    local_index.y = linear_index % size[1];
+
+    const packed_uint2 strided_index = local_index * stride;
+    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
+}}
+
+kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
+                             constant void * src_           [[buffer(0)]],
+                             device void * dst_             [[buffer(1)]],
+                             constant int & size            [[buffer(2)]],
+                             constant int & stride          [[buffer(3)]],
+                             constant uint32_t & numel      [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    const int local_index = linear_index % size;
+    const int strided_index = local_index * stride;
+    dst[strided_index] = cast<{1}>(src[linear_index]);
+}}
+)METAL_SCATTER";
+
+static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
+struct __attribute__ ((packed)) packed_uint5{{
+  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
+}};
+
+template<typename Y, typename X>
+Y cast(const X x);
+
+template<>
+{1} cast<{1}, {0}>(const {0} x) {{
+ return {2};
+}}
+
+kernel void gather_kernel_5(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint5 & size    [[buffer(2)]],
+                            constant packed_uint5 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+
+    packed_uint5 local_index;
+    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
+    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
+    local_index.z = linear_index / (size.u * size.w) % size.z;
+    local_index.w = linear_index / size.u % size.w;
+    local_index.u = linear_index % size.u;
+
+    packed_uint5 strided_index;
+    strided_index.x = local_index.x * stride.x;
+    strided_index.y = local_index.y * stride.y;
+    strided_index.z = local_index.z * stride.z;
+    strided_index.w = local_index.w * stride.w;
+    strided_index.u = local_index.u * stride.u;
+
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
+}}
+
+kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint4 & size    [[buffer(2)]],
+                            constant packed_uint4 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint4 local_index;
+    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
+    local_index.y = linear_index / (size[3] * size[2]) % size[1];
+    local_index.z = linear_index / size[3] % size[2];
+    local_index.w = linear_index % size[3];
+
+    const packed_uint4 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
+}}
+
+kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint3 & size    [[buffer(2)]],
+                            constant packed_uint3 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint3 local_index;
+    local_index.x = linear_index / (size[2] * size[1]) % size[0];
+    local_index.y = linear_index / size[2] % size[1];
+    local_index.z = linear_index % size[2];
+
+    const packed_uint3 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
+}}
+
+kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant packed_uint2 & size    [[buffer(2)]],
+                            constant packed_uint2 & stride  [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    packed_uint2 local_index;
+    local_index.x = linear_index / size[1] % size[0];
+    local_index.y = linear_index % size[1];
+
+    const packed_uint2 strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
+}}
+
+kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
+                            constant void * src_            [[buffer(0)]],
+                            device void * dst_              [[buffer(1)]],
+                            constant int & size             [[buffer(2)]],
+                            constant int & stride           [[buffer(3)]],
+                            constant uint32_t & numel       [[buffer(4)]]) {{
+    if (linear_index >= numel) return;
+
+    constant {0} * src = (constant {0} *)src_;
+    device {1} * dst = (device {1} *)dst_;
+
+    const int local_index = linear_index % size;
+    const int strided_index = local_index * stride;
+    dst[linear_index] = cast<{1}>(src[strided_index]);
+}}
+)METAL_GATHER";
+} // namespace at::mps
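The SCATTER_OPS_TEMPLATE / GATHER_OPS_TEMPLATE strings use {0}, {1}, and {2} as placeholders for the source element type, the destination element type, and the cast expression, with literal Metal braces escaped as {{ }}. The host side fills these in before compiling the Metal library at runtime. A hedged host-side sketch of that substitution follows; the fmt-based formatting call and the example type names are illustrative, not necessarily how the real MPS code performs the substitution.

    #include <ATen/mps/IndexKernels.h>
    #include <fmt/format.h>
    #include <string>

    // Illustrative only: produce a Metal source string for a float -> half scatter.
    std::string buildScatterSource(const char* src_type,   // e.g. "float"                 -> {0}
                                   const char* dst_type,   // e.g. "half"                  -> {1}
                                   const char* cast_expr)  // e.g. "static_cast<half>(x)"  -> {2}
    {
      // Only the numbered placeholders are substituted; {{ }} stay as literal braces.
      return fmt::format(fmt::runtime(at::mps::SCATTER_OPS_TEMPLATE),
                         src_type, dst_type, cast_expr);
    }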
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..62bf958a9b95c63669a99200f7b05b262aa09f03
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocator.h
@@ -0,0 +1,401 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include <ATen/mps/MPSAllocatorInterface.h>
+#include <ATen/mps/MPSEvent.h>
+#include <ATen/mps/MPSStream.h>
+
+#include <cstdio>
+#include <mutex>
+#include <set>
+#include <unordered_set>
+#include <mach/vm_page_size.h>
+#include <c10/util/flat_hash_map.h>
+
+// this implementation is based on CUDACachingAllocator.
+// It utilizes Metal Heaps to improve the performance with buffer allocation.
+// Do not include this header. Use MPSAllocatorInterface.h instead.
+// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
+namespace at::mps::HeapAllocator {
+
+static const size_t kMaxSmallAlloc = MB(1);    // largest "small" allocation is 1 MiB
+static const size_t kMinLargeAlloc = MB(10);   // allocations between 1 and 10 MiB may use kLargeHeap
+static const size_t kRoundLarge    = MB(2);    // round up large allocations to 2 MiB
+static const size_t kSmallHeap     = MB(8);    // "small" allocations are packed in 8 MiB heaps
+static const size_t kLargeHeap     = MB(32);   // "large" allocations may be packed in 32 MiB heaps
+static const size_t kXLargeHeapD   = MB(128);  // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
+static const size_t kXLargeHeapU   = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
+static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
+
+// buffer pools could be customized with a combination of usage flags
+enum UsageFlags : uint32_t {
+  PRIVATE = 0,
+  SMALL   = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
+  SHARED  = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
+  MANAGED = (1 << 2), // managed storage mode
+  HAZARD  = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
+  SCALAR  = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
+};
+// debug verbosity flags
+enum DebugVerbosity : uint32_t {
+  SILENT      = 0,
+  PROFILING   = (1 << 0), // print generic profiling data for total system memory usage
+  ALLOCATIONS = (1 << 1), // print buffer allocations
+  RECYCLES    = (1 << 2), // print buffer recycling
+  RELEASES    = (1 << 3), // print buffer releases
+  LARGE_ONLY  = (1 << 4), // only log large buffer pool transactions
+};
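Since pools and logging are parameterized by bit-or'ed combinations of these flags, here is a small conceptual fragment showing how they compose. The particular combinations are hypothetical, not the allocator's actual defaults, and the header itself is Objective-C++ that is normally only consumed inside the MPS allocator's own .mm files, so this is not meant to compile standalone.

    using namespace at::mps::HeapAllocator;

    // A shared, hazard-tracked pool for small allocations on a unified-memory GPU.
    const uint32_t shared_small_usage =
        UsageFlags::SHARED | UsageFlags::SMALL | UsageFlags::HAZARD;

    // Log buffer allocations and releases, but only for the large-buffer pools.
    const uint32_t verbosity =
        DebugVerbosity::ALLOCATIONS | DebugVerbosity::RELEASES | DebugVerbosity::LARGE_ONLY;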
+
+struct HeapBlock;
+
+struct BufferBlock {
+  id<MTLBuffer> buffer;
+  void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
+  size_t size; // size after alignment
+  size_t requested_size; // requested size (before alignment)
+  // buffer shape is used for retrieving base of views in cached graphs
+  std::vector<int64_t> shape;
+  bool in_use = false;
+  HeapBlock* heap;
+  id_t buf_id;
+  // counter to candidate least recently used buffers for garbage collection
+  uint32_t gc_count = 0;
+  uint32_t use_count = 0;
+  // counter to assign unique ids to buffer blocks
+  static uint64_t buffer_counter;
+  // Metal events used to sync GPU/CPU operations on the shared-storage buffers
+  MPSEventPtr event;
+
+  BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
+              HeapBlock* Heap = nullptr) :
+              buffer(Buffer), size(Size), requested_size(RequestedSize),
+              heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
+
+  static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
+    return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
+  }
+  static size_t alignUp(size_t Size, size_t Alignment) {
+    assert(((Alignment - 1) & Alignment) == 0);
+    return ((Size + Alignment - 1) & ~(Alignment - 1));
+  }
+  uint32_t retainCount() const { return [buffer retainCount]; }
+};
+typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
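BufferBlock::alignUp rounds a size up to a power-of-two alignment with the usual mask trick. A short, standalone worked example (the helper is re-stated here with the same arithmetic purely for illustration):

    #include <cassert>
    #include <cstddef>

    // Same arithmetic as BufferBlock::alignUp: Alignment must be a power of two.
    static size_t alignUp(size_t Size, size_t Alignment) {
      assert(((Alignment - 1) & Alignment) == 0);
      return (Size + Alignment - 1) & ~(Alignment - 1);
    }

    int main() {
      assert(alignUp(1, 256)   == 256);  // anything in (0, 256] rounds up to 256
      assert(alignUp(256, 256) == 256);  // already-aligned sizes are unchanged
      assert(alignUp(257, 256) == 512);  // otherwise, next multiple of the alignment
      return 0;
    }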
+
+struct BufferPool;
+struct AllocParams {
+  AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
+              search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
+  size_t size() const { return search_key.size; }
+
+  BufferBlock search_key;
+  BufferPool* pool;
+  BufferBlock* buffer_block = nullptr;
+  size_t requested_size;
+  // true if we exceed the low watermark limit. In this case
+  // we apply strategies to relieve the pressure before allocation.
+  bool has_memory_pressure = false;
+  // true if we're allocating on a unified memory device
+  bool has_unified_memory = true;
+};
+
+struct HeapBlock {
+  id<MTLHeap> heap;
+  struct { size_t total, available; } size;
+  BufferPool* pool;
+  unsigned int n_buffers = 0;
+  id_t heap_id;
+  // indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
+  bool is_split;
+  // counter to assign unique ids to heap blocks
+  static uint64_t heap_counter;
+
+  HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
+            heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
+            heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
+
+  static MTLResourceOptions getOptions(uint32_t usage) {
+    // TODO: check the caching performance of write-combined mode
+    MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
+
+    if (usage & UsageFlags::MANAGED)
+      options |= MTLResourceStorageModeManaged;
+    else if (usage & UsageFlags::SHARED)
+      options |= MTLResourceStorageModeShared;
+    else
+      options |= MTLResourceStorageModePrivate;
+
+    options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
+
+    return options;
+  }
+
+  static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
+    HeapBlock *heapBlock = nullptr;
+    bool is_split = true;
+    const size_t size = params.size();
+    MTLHeapDescriptor *d = [MTLHeapDescriptor new];
+    if (d) {
+      const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
+      if (size <= kMaxSmallAlloc) {
+        d.size = kSmallHeap;
+      } else if (size < kMinLargeAlloc) {
+        d.size = kLargeHeap;
+      } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
+        d.size = kXLargeHeap;
+      } else {
+        d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
+        is_split = false;
+      }
+      d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
+      d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+      // this automatically handles Metal buffer access synchronizations at the
+      // cost of slightly lower performance.
+      d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
+      d.resourceOptions = getOptions(usage);
+      d.type = MTLHeapTypeAutomatic;
+      id<MTLHeap> heap = [device newHeapWithDescriptor: d];
+      if (heap) {
+        [heap setPurgeableState:MTLPurgeableStateNonVolatile];
+        const size_t heap_size = heapAvailableSize(heap);
+        heapBlock = new HeapBlock(heap_size, heap, params.pool);
+        if (heapBlock) {
+          heapBlock->is_split = is_split;
+        }
+      }
+      [d release];
+    }
+    return heapBlock;
+  }
+  static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
+    return (a->size.available != b->size.available) ? a->size.available < b->size.available :
+                                                      (uintptr_t)a->heap < (uintptr_t)b->heap;
+  }
+  static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
+    return [heap maxAvailableSizeWithAlignment:Alignment];
+  }
+  NSUInteger Size() {
+    return [heap size];
+  }
+  id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
+    id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
+    if (buf) {
+      updateAvailableSize();
+      n_buffers++;
+    }
+    return buf;
+  }
+  // returns the retainCount before releasing the buffer
+  uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
+    const uint32_t retainCount = [buffer retainCount];
+    [buffer release];
+    buffer = nil;
+    updateAvailableSize();
+    n_buffers--;
+    return retainCount;
+  }
+  // returns the retainCount before releasing the heap
+  uint32_t releaseMTLHeap() {
+    const uint32_t retainCount = [heap retainCount];
+    TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
+    [heap setPurgeableState:MTLPurgeableStateEmpty];
+    [heap release];
+    heap = nil;
+    size.available = 0;
+    return retainCount;
+  }
+  uint32_t retainCount() const { return [heap retainCount]; }
+  void updateAvailableSize() { size.available = heapAvailableSize(heap); }
+};
+typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
+
+struct BufferPool {
+  enum class Kind {
+    PRIVATE_SMALL,
+    PRIVATE_LARGE,
+    SHARED_SMALL,
+    SHARED_LARGE,
+    SCALAR,
+  };
+
+  BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
+             device(Device), usage(Usage),
+             heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
+
+  const id<MTLDevice> device;
+  // usage flags to customize the pool for various purposes (see UsageFlags enum)
+  const uint32_t usage;
+  // total number of buffers in the pool
+  uint32_t n_buffers = 0;
+  // total allocations size on this pool
+  size_t allocated_size = 0;
+  // total memory available in the pool
+  size_t available_size = 0;
+  // list of heaps ordered by their "available" (not total) memory size
+  std::set<HeapBlock*, HeapComparison> heaps;
+  // list of only "available" buffers in the pool (i.e., buffers not in-use)
+  std::set<BufferBlock*, BufferComparison> available_buffers;
+  // list of buffers that are in a state of "limbo" where they've already been freed
+  // from PyTorch-side, but were not returned to pool due to still being
+  // in-use by command buffers with retainCount > 1. In this state, the buffer is
+  // neither ready to be recycled, nor could be returned to pool as available.
+  // These buffers will be returned to pool once the command buffer's
+  // completionHandler callbacks are called.
+  std::unordered_set<BufferBlock*> buffers_pending_free;
+  // list of heaps pending size update
+  std::unordered_set<HeapBlock*> heaps_pending_update;
+};
+
+class MPSHeapAllocatorImpl {
+public:
+  explicit MPSHeapAllocatorImpl() :
+    m_device(at::mps::MPSDevice::getInstance()->device()),
+    m_max_buffer_size([m_device maxBufferLength]),
+    m_stream(getDefaultMPSStream()),
+    m_event_pool(getMPSEventPool()) {
+    init_allocator();
+  }
+  ~MPSHeapAllocatorImpl() {
+    emptyCache();
+  }
+  // interface exposed to at::Allocator
+  id<MTLBuffer> malloc(size_t size, uint32_t usage);
+  // frees a buffer and returns it into buffer pool
+  void free(void* ptr);
+  // releases all the cached buffers and their associated heaps
+  void emptyCache();
+  // free inactive buffers that are pending to be freed
+  void freeInactiveBuffers();
+  // returns true if buffer was allocated from the shared pool
+  bool isSharedBuffer(const void* ptr);
+  // get the requested unaligned size of an MTLBuffer
+  ssize_t getUnalignedBufferSize(const void* ptr);
+  // set the shape of a base tensor from a view tensor
+  void setBufferShape(const void* ptr, const IntArrayRef& shape);
+  // retrieve the shape of a base tensor from a view tensor
+  IntArrayRef getBufferShape(const void* ptr);
+  // get the unique ID of the buffer
+  id_t getBufferId(const void* ptr);
+  // allocate a buffer from a specialized pool to import CPU scalars into GPU
+  id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
+  // returns a CPU-mapping of the input buffer and its retainCount,
+  // if only it has Shared storage-mode and allocated on MPSAllocator
+  std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
+  // records events for a list of MTLBuffers (list is used to lock the mutex once)
+  // returns true if it records any event (i.e., if the passed buffers exist and are shared-storage)
+  bool recordEvents(c10::ArrayRef<const void*> buffers);
+  // waits for the event to signal the completion of GPU execution
+  // on the passed shared buffers (list is used to lock the mutex once)
+  // returns true if actually waited on any event
+  bool waitForEvents(c10::ArrayRef<const void*> buffers);
+  // this indicates how far (in Megabytes) the current total allocations are from the
+  // low watermark limit which is used to detect if we're under memory pressure
+  // This returns zero if we've reached the low watermark limit
+  ssize_t getLowWatermarkValue();
+  // (see m_low_watermark_ratio for description)
+  void setLowWatermarkRatio(double ratio);
+  // (see m_high_watermark_ratio for description)
+  void setHighWatermarkRatio(double ratio);
+  // (see m_low_watermark_limit for description)
+  size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
+  // (see m_max_total_allowed_size for description)
+  size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
+  // (see m_total_allocated_memory for description)
+  size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
+  // (see m_current_allocated_memory for description)
+  size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
+  // total GPU memory allocated in the process by Metal driver; including
+  // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
+  size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
+  // (see enum DebugVerbosity for description)
+  uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
+  // returns the device that we allocate from
+  inline id<MTLDevice> Device() const { return m_device; }
+
+  // TODO: make a common function to do size unit conversions in PyTorch.
+  inline std::string format_size(uint64_t size) const;
+
+private:
+  // (see m_high_watermark_ratio for description)
+  constexpr static double default_high_watermark_ratio = 1.7;
+  // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
+  constexpr static double default_high_watermark_upper_bound = 2.0;
+  // (see m_low_watermark_ratio for description)
+  // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
+  constexpr static double default_low_watermark_ratio_unified  = 1.4;
+  constexpr static double default_low_watermark_ratio_discrete = 1.0;
+
+  const id<MTLDevice> m_device;
+  std::recursive_mutex m_mutex;
+  // allocated buffers by device pointer
+  ska::flat_hash_map<void*, BufferBlock*> m_allocated_buffers;
+  // using a container for pools to simplify iterating them
+  ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
+  // total memory allocated by HeapAllocator (including blocks in pools)
+  size_t m_total_allocated_memory = 0;
+  // currently active memory allocations in use (i.e., blocks not in pools)
+  size_t m_current_allocated_memory = 0;
+  // max buffer size allowed by Metal
+  size_t m_max_buffer_size = 0;
+  // maximum total size allowed to be allocated
+  size_t m_max_total_allowed_size = 0;
+  // high watermark ratio is a hard limit for the total allowed allocations
+  // 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
+  // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
+  // >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
+  // e.g., value 0.95 means we allocate up to 95% of recommended maximum
+  // allocation size; beyond that, the allocations would fail with OOM error.
+  double m_high_watermark_ratio;
+  // low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
+  // level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
+  // Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
+  // e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
+  // allocation size.
+  double m_low_watermark_ratio;
+  // low watermark size limit (in Bytes) at the time we initialize the allocator
+  size_t m_low_watermark_limit;
+  // use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
+  uint32_t m_debug_verbosity;
+  // default MPS stream
+  MPSStream* m_stream;
+  // we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator
+  std::shared_ptr<MPSEventPool> m_event_pool;
+
+  void init_allocator();
+  void init_buffer_pools();
+  HeapBlock* get_free_heap(AllocParams& params);
+  bool get_free_buffer(AllocParams& params);
+  BufferBlock* get_allocated_buffer_block(const void* ptr);
+  BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
+  bool alloc_buffer(AllocParams& params);
+  void free_buffer(BufferBlock* buffer_block);
+  // returns true if the container heap is also released
+  bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
+  void release_buffers(BufferPool& pool);
+  bool release_available_cached_buffers(AllocParams& params);
+  bool release_cached_buffers();
+  // free unused cached blocks to reclaim GPU memory if memory pressure is high
+  void garbage_collect_cached_buffers(AllocParams& params);
+  // returns the suitable buffer pool type for the usage or
+  // requested/allocated sizes
+  BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
+  // returns the aligned allocation size that is optimized
+  // for the buffers to get reused frequently
+  size_t get_allocation_size(size_t size, uint32_t usage) const;
+  // maximum size of device memory available for allocation in current process
+  // Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
+  size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
+  // there are implicit allocations from MPS backend, so we need to query the 'device' for
+  // total allocated size instead of manually tracking in MPSAllocator
+  size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
+
+  bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
+    for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
+      MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
+    }
+    return true;
+  }
+};
+
+} // namespace at::mps::HeapAllocator
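For orientation, a minimal sketch (not the upstream implementation; the helper name is hypothetical and the high-watermark upper-bound clamping is ignored) of how the watermark ratios above could translate into the byte limits held in m_max_total_allowed_size and m_low_watermark_limit, based on the device's recommendedMaxWorkingSetSize:

    #include <limits>

    // A ratio of 0 disables the corresponding limit, mirroring the comments above.
    static size_t limit_from_ratio(double ratio, size_t recommended_working_set) {
      if (ratio == 0.0) {
        return std::numeric_limits<size_t>::max();  // limit disabled
      }
      return static_cast<size_t>(ratio * static_cast<double>(recommended_working_set));
    }
    // e.g. m_max_total_allowed_size ~ limit_from_ratio(m_high_watermark_ratio, max_device_size());
    //      m_low_watermark_limit    ~ limit_from_ratio(m_low_watermark_ratio,  max_device_size());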
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
new file mode 100644
index 0000000000000000000000000000000000000000..88a977fe48bf4db6e082e7cb9d60fe42c0616531
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
@@ -0,0 +1,61 @@
+//  Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#define MB(x) (x * 1048576UL)
+
+namespace at::mps {
+
+// this is a public interface to access MPSAllocator.
+// Do not declare methods that would depend on MPS or Metal frameworks.
+class IMPSAllocator : public c10::Allocator {
+public:
+  // see the comments in MPSAllocator.h for the description of these methods.
+  virtual void emptyCache() const = 0;
+  virtual void freeInactiveBuffers() const = 0;
+  virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
+  virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
+  virtual id_t getBufferId(const void* ptr) const = 0;
+  virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
+  virtual bool isSharedBuffer(const void* ptr) const = 0;
+  virtual bool isSharedStorageSupported() const = 0;
+  virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
+  virtual std::string formatSize(size_t size) const = 0;
+  virtual void setLowWatermarkRatio(double ratio) const = 0;
+  virtual void setHighWatermarkRatio(double ratio) const = 0;
+  virtual ssize_t getLowWatermarkValue() const = 0;
+  virtual size_t getLowWatermarkLimit() const = 0;
+  virtual size_t getHighWatermarkLimit() const = 0;
+  virtual size_t getTotalAllocatedMemory() const = 0;
+  virtual size_t getCurrentAllocatedMemory() const = 0;
+  virtual size_t getDriverAllocatedMemory() const = 0;
+  virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
+  virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
+  virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
+};
+
+class IMpsAllocatorCallback {
+ public:
+  enum class EventType {
+    ALLOCATED, // buffer got allocated to be used immediately
+    RECYCLED,  // buffer pulled from free list to be reused
+    FREED,     // buffer put to free list for future recycling
+    RELEASED,  // buffer memory released
+    ALLOCATION_FAILED // buffer allocation failed
+  };
+  virtual ~IMpsAllocatorCallback() = default;
+  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
+};
+
+// MPS allocator will execute every registered callback when a block of memory is freed.
+C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
+#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
+  C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
+
+IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
+
+} // namespace at::mps
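As a usage sketch of the registry declared above (hypothetical: the callback class and the registry key "release_counter" are invented for illustration), a class registered via REGISTER_MPS_ALLOCATOR_CALLBACK receives every allocator event:

    #include <ATen/mps/MPSAllocatorInterface.h>

    // Hypothetical callback that counts how many buffers get released.
    struct ReleaseCounter : public at::mps::IMpsAllocatorCallback {
      void executeMPSAllocatorCallback(void* ptr, EventType event) override {
        if (event == EventType::RELEASED) {
          ++released_count;
        }
      }
      size_t released_count = 0;
    };
    REGISTER_MPS_ALLOCATOR_CALLBACK(release_counter, ReleaseCounter);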
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h
new file mode 100644
index 0000000000000000000000000000000000000000..6d92ac5b7c41bd9905efdcdee77659a41b8a767b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSDevice.h
@@ -0,0 +1,85 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include 
+#include 
+#include 
+
+
+#ifdef __OBJC__
+#include <Foundation/Foundation.h>
+#include <Metal/Metal.h>
+#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
+typedef id<MTLDevice> MTLDevice_t;
+typedef id<MTLLibrary> MTLLibrary_t;
+typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
+typedef id<MTLLibrary> MTLLibrary_t;
+#else
+typedef void* MTLDevice;
+typedef void* MTLDevice_t;
+typedef void* MTLLibrary_t;
+typedef void* MTLComputePipelineState_t;
+typedef void* MTLLibrary_t;
+#endif
+
+using namespace std;
+
+namespace at::mps {
+
+// Helper enum to check if a MPSGraph op is supported in a given macOS version
+enum class MacOSVersion : uint32_t {
+  MACOS_VER_13_0_PLUS = 0,
+  MACOS_VER_13_1_PLUS,
+  MACOS_VER_13_2_PLUS,
+  MACOS_VER_13_3_PLUS,
+  MACOS_VER_14_0_PLUS,
+};
+
+//-----------------------------------------------------------------
+//  MPSDevice
+//
+// MPSDevice is a singleton class that returns the default device
+//-----------------------------------------------------------------
+
+class TORCH_API MPSDevice {
+ public:
+  /**
+   * MPSDevice should not be cloneable.
+   */
+  MPSDevice(MPSDevice& other) = delete;
+  /**
+   * MPSDevice should not be assignable.
+   */
+  void operator=(const MPSDevice&) = delete;
+  /**
+   * Gets single instance of the Device.
+   */
+  static MPSDevice* getInstance();
+  /**
+   * Returns the single device.
+   */
+  MTLDevice_t device() {
+    return _mtl_device;
+  }
+  /**
+   * Returns whether running on Ventura or newer
+   */
+  bool isMacOS13Plus(MacOSVersion version) const;
+
+  MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
+  MTLLibrary_t getMetalIndexingLibrary();
+
+  ~MPSDevice();
+
+ private:
+  static MPSDevice* _device;
+  MTLDevice_t _mtl_device;
+  MTLLibrary_t _mtl_indexing_library;
+  MPSDevice();
+};
+
+TORCH_API bool is_available();
+TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
+TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
+
+} // namespace at::mps
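A brief usage sketch of the availability helpers declared above (the branch bodies are placeholders):

    if (at::mps::is_available() &&
        at::mps::is_macos_13_or_newer(at::mps::MacOSVersion::MACOS_VER_13_3_PLUS)) {
      // take the code path that relies on the newer MPS/MPSGraph feature
    } else {
      // fall back to an older implementation (or to the CPU)
    }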
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h
new file mode 100644
index 0000000000000000000000000000000000000000..ab4ad68412a7c9f431da8032924b51c6c5c33660
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSEvent.h
@@ -0,0 +1,100 @@
+//  Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+// NOTE: don't create instances of this class directly.
+// Use MPSEventPool to acquire instances of MPSEvent.
+class MPSEvent {
+public:
+  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
+  ~MPSEvent();
+
+  // records an event on the stream
+  void record(bool needsLock, bool syncEvent = false);
+  // makes all future work submitted to the stream wait for this event.
+  bool wait(bool needsLock, bool syncEvent = false);
+  // schedules a notifyListener callback for the event.
+  bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
+  // checks if events are already signaled.
+  bool query() const;
+  // blocks the CPU thread until all the GPU work that was scheduled
+  // prior to recording this event has completed.
+  bool synchronize();
+  // resets this event with new parameters in case it gets reused from the event pool
+  void reset(MPSStream* stream, bool enable_timing);
+  // returns the unique ID of the event instance
+  id_t getID() const { return m_id; }
+  // returns the completion timestamp of the event
+  uint64_t getCompletionTime() const { return m_completion_time; }
+  // if already recorded, waits for cpu_sync_cv to be signaled
+  void waitForCpuSync();
+
+private:
+  id_t m_id;
+  // enables measuring the completion time of the notifyListener of this event
+  bool m_enable_timing;
+  uint64_t m_signalCounter = 0;
+  MPSStream* m_stream = nullptr;
+  MTLSharedEvent_t m_event = nullptr;
+  MTLSharedEventListener* m_listener = nullptr;
+  // used to sync the events created on this Stream with CPU
+  std::mutex m_cpu_sync_mutex{};
+  std::condition_variable m_cpu_sync_cv{};
+  // CondVar predicate to sync the events created on this Stream with CPU
+  bool m_cpu_sync_completed = false;
+  // used to compute elapsed time
+  uint64_t m_completion_time = 0;
+
+  void recordLocked(bool syncEvent);
+  bool waitLocked(bool syncEvent);
+  bool notifyLocked(MTLSharedEventNotificationBlock block);
+  void notifyCpuSync();
+  static uint64_t getTime() {
+    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+  }
+};
+
+typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
+
+class MPSEventPool {
+public:
+  explicit MPSEventPool(MPSStream* default_stream);
+  ~MPSEventPool();
+
+  MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
+  void emptyCache();
+
+  // these are mainly used for MPSHooks and torch.mps.Event() bindings
+  id_t acquireEvent(bool enable_timing);
+  void releaseEvent(id_t event_id);
+  void recordEvent(id_t event_id, bool syncEvent);
+  void waitForEvent(id_t event_id, bool syncEvent);
+  void synchronizeEvent(id_t event_id);
+  bool queryEvent(id_t event_id);
+  // returns elapsed time between two recorded events in milliseconds
+  double elapsedTime(id_t start_event_id, id_t end_event_id);
+
+private:
+  MPSStream* m_default_stream = nullptr;
+  std::recursive_mutex m_mutex;
+  std::stack<std::unique_ptr<MPSEvent>> m_pool{};
+  // dictionary to associate event IDs with event objects
+  // used to retain in-use events out of the pool
+  // for torch.mps.Event() bindings.
+  std::unordered_map<id_t, std::unique_ptr<MPSEvent>> m_in_use_events{};
+  uint64_t m_event_counter = 0;
+  std::function<void(MPSEvent*)> m_default_deleter;
+
+  MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
+};
+
+// shared_ptr is used to get MPSEventPool destroyed after dependent instances
+std::shared_ptr<MPSEventPool> getMPSEventPool();
+
+} // namespace at::mps
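A rough usage sketch of the event-pool API above; the calls follow the signatures in this header, but the flow is illustrative rather than upstream code:

    std::shared_ptr<at::mps::MPSEventPool> pool = at::mps::getMPSEventPool();
    at::mps::MPSEventPtr ev = pool->acquireEvent(/*enable_timing=*/false,
                                                 at::mps::getDefaultMPSStream());
    ev->record(/*needsLock=*/true);   // record the event on the stream
    ev->wait(/*needsLock=*/true);     // future stream work waits for the event
    ev->synchronize();                // or block the CPU until the recorded work completes
    // When `ev` goes out of scope, its custom deleter presumably hands the
    // event back to the pool (see m_default_deleter above).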
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..9ed6acd31cfa079a4bf9b5c1edc8824920f603d8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
@@ -0,0 +1,52 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace mps::detail {
+
+static const uint32_t PHILOX_STATE_N = 7;
+struct rng_data_pod {
+  std::array<uint32_t, PHILOX_STATE_N> state{1};
+  uint64_t seed = default_rng_seed_val;
+};
+
+TORCH_API const Generator& getDefaultMPSGenerator();
+TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
+
+} // namespace mps::detail
+
+struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
+  ~MPSGeneratorImpl() override = default;
+
+  // MPSGeneratorImpl methods
+  std::shared_ptr<MPSGeneratorImpl> clone() const;
+  void set_current_seed(uint64_t seed) override;
+  void set_offset(uint64_t offset) override;
+  uint64_t get_offset() const override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_state(const c10::TensorImpl& new_state) override;
+  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
+  void update_philox_counters();
+
+  void set_engine(at::Philox4_32 engine) { engine_ = engine; };
+  at::Philox4_32 engine() { return engine_; };
+  uint32_t* state_data() { return data_.state.data(); }
+  static DeviceType device_type() { return DeviceType::MPS; };
+
+private:
+  mps::detail::rng_data_pod data_;
+  at::Philox4_32 engine_;
+
+  MPSGeneratorImpl* clone_impl() const override;
+};
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..587ebdba6d240e68aa06e64eeccd48328e85c647
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
@@ -0,0 +1,174 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __OBJC__
+#include <Foundation/Foundation.h>
+#include <Metal/Metal.h>
+#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+namespace at::mps {
+
+typedef MPSEvent* mpsEvent_t;
+
+// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
+// https://github.com/pytorch/pytorch/issues/77170
+struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
+
+  // constructor
+  MPSGuardImpl() {}
+  explicit MPSGuardImpl(c10::DeviceType t) {
+    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
+  }
+
+  // returns the type
+  c10::DeviceType type() const override {
+    return c10::DeviceType::MPS;
+  }
+
+  Device exchangeDevice(Device d) const override {
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  Device getDevice() const override {
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  c10::optional<Device> uncheckedGetDevice() const noexcept {
+    return Device(c10::DeviceType::MPS, 0);
+  }
+
+  void setDevice(Device d) const override {
+    TORCH_INTERNAL_ASSERT(d.is_mps());
+  }
+
+  void uncheckedSetDevice(Device d) const noexcept override {
+    // TODO: Currently setting only device 0
+  }
+
+  Stream getStream(Device d) const noexcept override {
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+
+  Stream getDefaultStream(Device d) const override {
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+
+  // NB: These do NOT set the current device
+  Stream exchangeStream(Stream s) const noexcept override {
+    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
+  }
+  DeviceIndex deviceCount() const noexcept override {
+    if (at::hasMPS()) {
+      //TODO: extend it for multi-device case
+      return 1;
+    } else {
+      return 0;
+    }
+  }
+
+  // Event-related functions
+  void createEvent(
+    mpsEvent_t* event,
+    const EventFlag flag) const;
+
+  void destroyEvent(
+    void* event,
+    const DeviceIndex device_index) const noexcept override;
+
+  void record(
+    void** event,
+    const Stream& stream,
+    const DeviceIndex device_index,
+    const EventFlag flag) const override;
+
+  void block(
+    void* event,
+    const Stream& stream) const override;
+
+  bool queryEvent(void* event) const override;
+
+};
+
+/// A variant of OptionalDeviceGuard that is specialized for MPS.
+struct OptionalMPSGuard {
+  explicit OptionalMPSGuard() : guard_() {}
+
+  explicit OptionalMPSGuard(c10::optional<Device> device_opt)
+      : guard_(device_opt) {}
+
+  /// Set the current MPS device to the passed device index, if it is not
+  /// nullopt
+  explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt)
+      : guard_(device_index_opt) {}
+
+  // Copy is not allowed
+  OptionalMPSGuard(const OptionalMPSGuard&) = delete;
+  OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
+  OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
+  OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
+
+  /// Sets the MPS device to the given device, initializing the guard if it
+  /// is not already initialized.  Errors if the given device is not a MPS
+  /// device.
+  void set_device(Device device) {
+    guard_.set_device(device);
+  }
+
+  /// Sets the MPS device to the given device, initializing the guard if it is
+  /// not already initialized.  Errors if the given device is not a MPS device.
+  void reset_device(Device device) {
+    guard_.reset_device(device);
+  }
+
+  /// Sets the MPS device to the given device index, initializing the guard if
+  /// it is not already initialized.
+  void set_index(DeviceIndex device_index) {
+    guard_.set_index(device_index);
+  }
+
+  /// Returns the device that was set immediately prior to initialization of the
+  /// guard, or nullopt if the guard is uninitialized.
+  c10::optional<Device> original_device() const {
+    return guard_.original_device();
+  }
+
+  /// Returns the most recent device that was set using this device guard,
+  /// either from construction, or via set_device, if the guard is initialized,
+  /// or nullopt if the guard is uninitialized.
+  c10::optional<Device> current_device() const {
+    return guard_.current_device();
+  }
+
+  /// Restore the original MPS device, resetting this guard to uninitialized
+  /// state.
+  void reset() {
+    guard_.reset();
+  }
+
+ private:
+  c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
+};
+
+
+C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
+
+} // namespace at::mps
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..b49d620527dc1bb63833644a74d7683ac635b98b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSHooks.h
@@ -0,0 +1,57 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+// The real implementation of MPSHooksInterface
+struct MPSHooks : public at::MPSHooksInterface {
+  MPSHooks(at::MPSHooksArgs) {}
+  void initMPS() const override;
+
+  // MPSDevice interface
+  bool hasMPS() const override;
+  bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
+
+  // MPSGeneratorImpl interface
+  const Generator& getDefaultMPSGenerator() const override;
+
+  // MPSStream interface
+  void deviceSynchronize() const override;
+  void commitStream() const override;
+  void* getCommandBuffer() const override;
+  void* getDispatchQueue() const override;
+
+  // MPSAllocator interface
+  Allocator* getMPSDeviceAllocator() const override;
+  void emptyCache() const override;
+  size_t getCurrentAllocatedMemory() const override;
+  size_t getDriverAllocatedMemory() const override;
+  void setMemoryFraction(double ratio) const override;
+
+  // MPSProfiler interface
+  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
+  void profilerStopTrace() const override;
+
+  // MPSEvent interface
+  uint32_t acquireEvent(bool enable_timing) const override;
+  void releaseEvent(uint32_t event_id) const override;
+  void recordEvent(uint32_t event_id) const override;
+  void waitForEvent(uint32_t event_id) const override;
+  void synchronizeEvent(uint32_t event_id) const override;
+  bool queryEvent(uint32_t event_id) const override;
+  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
+
+  // Compatibility with Accelerator API
+  bool hasPrimaryContext(DeviceIndex device_index) const override {
+    // When MPS is available, it is always in use for the one device.
+    return true;
+  }
+};
+
+} // namespace at::mps
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h
new file mode 100644
index 0000000000000000000000000000000000000000..0c6ce179943c2a8f790329eddf14e1e89b2a04dd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSProfiler.h
@@ -0,0 +1,393 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::mps {
+
+namespace Profiler {
+
+struct BaseInfo {
+  // profiling info types
+  enum class Type {
+    GRAPH,
+    KERNEL,
+    COPY,
+    CPU_FALLBACK,
+  };
+
+  BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
+      type(infoType), profileId(Id), handle(Handle) { }
+  virtual ~BaseInfo() = default;
+
+  // type of profiling info
+  Type type;
+  // unique profile ID for execution instances of operations or copies
+  uint64_t profileId;
+  // ID generated by os_signpost
+  // since it's possible to use event and interval-based signposts at the
+  // same time, we need separate IDs for each.
+  os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
+  // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
+  std::atomic<double> totalGpuTime{0.0};
+  // accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
+  std::atomic<double> totalSchedulingTime{0.0};
+  // indicates if the operation or copy execution has completed
+  std::atomic_bool completed{false};
+  // handle used to identify the profile info's instance (usually the pointer)
+  const uintptr_t handle;
+
+  virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
+  // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
+  static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
+    if (tensor.defined()) {
+      std::stringstream tensorStr;
+      auto deviceType = tensor.device().type();
+      tensorStr << c10::DeviceTypeName(deviceType);
+      // see comments for INCLUDE_BUFFER_ID
+      if (includeBufferId && deviceType == at::kMPS) {
+        id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+        tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
+                  << ":" << buffer.retainCount << ")";
+      }
+      tensorStr << ":"
+                << tensor.scalar_type() << tensor.sizes();
+      return tensorStr.str();
+    } else {
+      return "undefined";
+    }
+  }
+  static uint64_t getTime() {
+    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+  }
+};
+
+struct OperationInfo : BaseInfo {
+  OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
+      BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
+
+  uint64_t runCount = 0;
+  std::string strKey;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
+
+  // builds a string for a kernel
+  static std::string buildKernelString(const std::string& kernelName,
+                                       const TensorList& tensors,
+                                       bool includeBufferId = false) {
+    std::stringstream kernelStr;
+    kernelStr << kernelName;
+    for (const Tensor& tensor: tensors) {
+      kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
+    }
+    return kernelStr.str();
+  }
+};
+
+struct CpuFbInfo : BaseInfo {
+  CpuFbInfo(uint64_t Id, const std::string& OpName) :
+      BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
+
+  uint64_t runCount = 0;
+  // the current and total overhead of copies in bytes required to convert the Op's
+  // input tensors from MPS to CPU and then output from CPU back to MPS
+  size_t currentCopyOverhead = 0;
+  size_t totalCopyOverhead = 0;
+  std::string opName;
+  std::string strKey;
+  uint64_t startTime = 0;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
+
+  void updateCopyOverhead(const TensorList& tensors) {
+    currentCopyOverhead = 0;
+    for (const Tensor& tensor: tensors) {
+      if (tensor.defined()) {
+        currentCopyOverhead += tensor.nbytes();
+      }
+    }
+    totalCopyOverhead += currentCopyOverhead;
+  }
+};
+
+struct CopyInfo : BaseInfo {
+  enum class Kind {
+    MPS_TO_MPS,
+    MPS_TO_CPU,
+    CPU_TO_MPS,
+  };
+
+  CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
+           BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
+           length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
+
+  Kind kind;
+  size_t length;
+  bool isNonBlocking;
+  bool usesBlitter;
+  std::string srcStrKey;
+  std::string dstStrKey;
+  // for copies that don't use blitters, we measure CPU time
+  uint64_t startTime = 0;
+
+  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
+
+  static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
+
+  static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
+    if (tensor.has_value()) {
+      return tensor->device().type() == at::kMPS;
+    }
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
+    // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
+    return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
+  }
+
+  static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
+                          const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
+    const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
+    const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
+    if (isSrcOnMPS && !isDstOnMPS) {
+      return Kind::MPS_TO_CPU;
+    } else if (!isSrcOnMPS && isDstOnMPS) {
+      return Kind::CPU_TO_MPS;
+    }
+    return Kind::MPS_TO_MPS;
+  }
+};
+
+struct CopyStat : CopyInfo {
+  explicit CopyStat(std::string CopyKindStr) :
+          CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
+  // total number of copies
+  size_t totalCount = 0;
+  // number of Scalar copies (i.e., less than sizeof(int64))
+  size_t scalarsCount = 0;
+  // number of blocking copies (i.e., require syncing to GPU)
+  size_t blockingCount = 0;
+  // number of copies that used memcpy(), instead of Metal Blit Encoder
+  size_t memcpyCount = 0;
+  // accumulated GPU time in ms for the scalar copies
+  std::atomic<double> scalarsGpuTime{0.0};
+  // copy kind in string type
+  std::string kindStr;
+};
+
+class MPSProfiler {
+public:
+  // lower 16 bits used for profiler options
+  enum ProfileOptions : uint32_t {
+    OPTIONS_NONE = 0,
+    // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
+    // (used for convenience to not compute bit flags by OR-ing manually)
+    // trace all signpost types using events
+    ALL_SIGNPOST_EVENTS    = (1 << 0),
+    // trace all signpost types using intervals
+    ALL_SIGNPOST_INTERVALS = (1 << 1),
+    // always wait for command buffer to finish executing after each commit
+    WAIT_UNTIL_COMPLETED   = (1 << 2),
+    // for interval-based signposts, include the scheduling portion of
+    // Graph/Kernel/Copy executions as well.
+    // if this flag is disabled, only the "GPU run time" is included in the
+    // interval, and not the scheduling time.
+    INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
+
+    // use these if you need to trace signposts types individually (rarely required)
+    // trace signpost using intervals
+    USE_INTERVALS = (1 << 4),
+    // trace signpost by emitting events
+    USE_EVENTS    = (1 << 5),
+    // used for sanity check (Change this when new option added)
+    OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
+  };
+
+  // when adding new types, #define the type string in MPSProfiler.mm as well.
+  // upper 16 bits used for event types
+  enum SignpostTypes : uint32_t {
+    SIGNPOST_NONE = 0,
+    // trace signposts for PyTorch operation executions
+    RUN_OPERATION = (1 << 16),
+    // trace signposts for blitter copies
+    BLIT_COPY     = (1 << 17),
+    // trace signposts for ops that fall back on CPU
+    CPU_FALLBACK  = (1 << 18),
+    // used for sanity check (Change this when new type added)
+    SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
+  };
+
+  enum LogOptions : uint32_t {
+    LOG_NONE = 0,
+
+    // Info logging options during execution
+    // -------------------------------------
+    // prints operation info (id/key/run_count) during execution
+    OPERATION_INFO      = (1 << 0),
+    // prints copy info (src/dst tensors/buffers, size, etc.) during execution
+    COPY_INFO           = (1 << 1),
+    // prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
+    CPU_FALLBACK_INFO   = (1 << 2),
+
+    // Profiling Statistics logging options when process terminates
+    // ------------------------------------------------------------
+    // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
+    // this is convenient to not combine following stats bit flags manually
+    ALL_STATS           = (1 << 3),
+    // prints operation stats (GPU times, run count, etc.) before process terminates
+    OPERATION_STATS     = (1 << 4),
+    // prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
+    COPY_STATS          = (1 << 5),
+    // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
+    // for tensors, etc.) before process terminates
+    CPU_FALLBACK_STATS  = (1 << 6),
+
+    // Metadata format options when logging the info
+    // ---------------------------------------------
+    // if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
+    // from Metal Command Buffers) (e.g., [GPU=0.324 ms])
+    INCLUDE_GPU_TIME    = (1 << 7),
+    // if enabled, includes GPU scheduling time in metadata separately
+    // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
+    // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
+    INCLUDE_KERNEL_TIME = (1 << 8),
+    // if enabled, includes the unique buffer ID in metadata for the storage
+    // of a tensor that was allocated on MPSAllocator. This is useful (along with
+    // the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
+    // with various operations.
+    INCLUDE_BUFFER_ID   = (1 << 9),
+
+    // used for sanity check (Change this when new option added)
+    LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
+  };
+
+  explicit MPSProfiler();
+  ~MPSProfiler();
+
+  // the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
+  // the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
+  uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
+  uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
+  uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
+                            const OptionalTensorRef srcTensor,
+                            const OptionalTensorRef dstTensor,
+                            size_t length, bool isNonBlocking, bool usesBlitter = true);
+  uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
+  void beginProfileGPUInterval(const void* handle);
+
+  void endProfileCopy(uint64_t profileId, SyncType syncType);
+  void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
+  void endProfileCPUFallback(const std::string& opName);
+
+  // these are used to hook into Python bindings for torch.mps.profiler module.
+  // this enables generating OS Signpost traces from MPSProfiler on-demand
+  // during runtime (instead of environment variables).
+  // The "mode" could be either "interval", "event", or both "interval,event"
+  // for interval-based and/or event-based signpost tracing.
+  void StartTrace(const string& mode, bool waitUntilCompleted);
+  void StopTrace();
+
+  // convenience functions to indicate whether signpost tracing or
+  // logging are enabled for the SignpostTypes
+  bool isOperationProfilingEnabled() const {
+    return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
+           (m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
+  }
+  bool isCopyProfilingEnabled() const {
+    return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
+           (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
+  }
+  bool isCPUFallbackProfilingEnabled() const {
+    return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
+           (m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
+  }
+  bool isSignpostTracingEnabled() const {
+    return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
+  }
+
+ private:
+  // indicates what type of signpost types are enabled and traced by MPS profiler.
+  uint32_t m_signpost_types = 0;
+  uint32_t m_profile_options = 0;
+  uint32_t m_log_options = 0;
+  uint64_t m_kernel_counter = 0;
+  uint64_t m_graph_counter = 0;
+  uint64_t m_cpu_fb_counter = 0;
+  uint64_t m_copy_counter = 0;
+  // technically, it's possible to trace both events and intervals at the same time
+  // so we use separate os_log categories for them
+  os_log_t m_os_log_events;
+  os_log_t m_os_log_intervals;
+  // stats logging could run either from destructor or signal handler
+  // so this is used to check if logging has already started.
+  std::atomic_bool hasLoggedStats{false};
+  // indicates there are pending completionHandler callbacks that haven't been called yet.
+  std::atomic_bool hasPendingCompletionHandlers{false};
+  // used to capture sigint signal to log profiling stats
+  static struct sigaction currentSigint, previousSigint;
+
+  // We use the following lists for two reasons:
+  // 1- for interval-based signposts the "begin" point won't be in same function
+  // as the "end" point where we need to be able to retrieve signpost's info
+  // 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
+
+  // the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
+  // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
+  std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
+  // the string key for this map is the op name that we fall back to execute on CPU
+  // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
+  std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
+  // this list contains the info for copies, and its key is the unique profileId
+  // which is generated from m_copy_counter
+  // The copyInfo list is not retained.
+  std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
+  // a short list that contains copy stats
+  std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
+
+  void initialize();
+  void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
+  void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
+                           os_signpost_id_t interval_signpost_id,
+                           double gpuTime, double schedulingTime);
+  void addProfilerScheduledHandler(BaseInfo& info);
+  void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
+  void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
+                         const std::string& msg) const;
+  void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
+                             const std::string& msg) const;
+  void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
+
+  void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
+  // returns true if logging the profiling info "during the execution" is enabled
+  bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
+  // logs all the profiling stats that are enabled
+  void logProfilingStats();
+  // logs kernel profiling stats when the process ends.
+  void logOperationsProfilingStats(std::FILE* f) const;
+  // logs CPU Fallback profiling stats when the process ends.
+  void logCPUFallbackProfilingStats(std::FILE* f) const;
+  // logs copy profiling stats when the process ends.
+  void logCopyProfilingStats(std::FILE* f) const;
+
+  os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
+  static SignpostTypes getSignpostType(BaseInfo::Type infoType);
+  static void handleIntSignal(int signal);
+};
+
+} // namespace Profiler
+
+Profiler::MPSProfiler& getMPSProfiler();
+
+} // namespace at::mps
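Since ProfileOptions and LogOptions above are bit flags packed into 32-bit values, a configuration is composed by OR-ing them; a small illustrative composition (not upstream code):

    using at::mps::Profiler::MPSProfiler;

    // Trace every signpost type with intervals and wait for each commit,
    // while logging copy statistics with GPU times when the process exits.
    uint32_t profile_options = MPSProfiler::ProfileOptions::ALL_SIGNPOST_INTERVALS |
                               MPSProfiler::ProfileOptions::WAIT_UNTIL_COMPLETED;
    uint32_t log_options     = MPSProfiler::LogOptions::COPY_STATS |
                               MPSProfiler::LogOptions::INCLUDE_GPU_TIME;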
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSStream.h b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSStream.h
new file mode 100644
index 0000000000000000000000000000000000000000..cc838de8e69bccebf8a385c7f9fc7eeb945302e7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/mps/MPSStream.h
@@ -0,0 +1,133 @@
+//  Copyright © 2022 Apple Inc.
+
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __OBJC__
+#include <Foundation/Foundation.h>
+#include <Metal/Metal.h>
+#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
+#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
+typedef id<MTLCommandQueue> MTLCommandQueue_t;
+typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
+typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
+typedef id<MTLSharedEvent> MTLSharedEvent_t;
+typedef id<MTLDevice> MTLDevice_t;
+#else
+typedef void* MTLCommandQueue_t;
+typedef void* MTLCommandQueue;
+typedef void* MTLCommandBuffer_t;
+typedef void* MTLCommandBuffer;
+typedef void* MTLComputeCommandEncoder_t;
+typedef void* MTLSharedEvent_t;
+typedef void* dispatch_queue_t;
+typedef void* MTLDevice_t;
+#define nil NULL;
+#endif
+
+
+namespace at::mps {
+
+//-----------------------------------------------------------------
+//  MPSStream
+//-----------------------------------------------------------------
+
+enum class SyncType {
+  NONE,               // no commit to command buffer
+  COMMIT,             // commit and flush the command buffer
+  COMMIT_AND_WAIT,    // flush and wait for command buffer execution to finish
+  COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
+  COMMIT_ADAPTIVE,    // commit adaptively based on available memory
+};
+
+class TORCH_API MPSStream
+{
+public:
+  enum Unchecked { UNCHECKED };
+
+  /// Construct a MPSStream from a Stream.  This construction is checked,
+  /// and will raise an error if the Stream is not, in fact, a MPS stream.
+  explicit MPSStream(Stream stream);
+
+  ~MPSStream();
+  MTLCommandQueue_t commandQueue() const { return _commandQueue; };
+  dispatch_queue_t queue() const { return _serialQueue; }
+
+  MPSCommandBuffer* commandBuffer();
+  MTLComputeCommandEncoder_t commandEncoder();
+  void endKernelCoalescing();
+  void synchronize(SyncType syncType);
+  void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
+  void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
+            size_t length, size_t srcOffset, size_t dstOffset,
+            uint64_t profileId, SyncType syncType = SyncType::NONE);
+  void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
+                     size_t length, size_t srcOffset, size_t dstOffset,
+                     bool non_blocking, uint64_t profileId);
+  void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
+  void addCompletedHandler(MTLCommandBufferHandler block);
+
+  /// Get the MPS device index that this stream is associated with.
+  c10::DeviceIndex device_index() const { return _stream.device_index(); }
+
+  MTLCommandQueue_t stream() const { return _commandQueue; };
+
+  MTLDevice_t device() const { return [_commandQueue device];}
+
+  /// Explicit conversion to Stream.
+  Stream unwrap() const { return _stream; }
+
+private:
+  Stream _stream;
+  MTLCommandQueue_t _commandQueue = nil;
+  MPSCommandBuffer* _commandBuffer = nil;
+  MPSCommandBuffer* _prevCommandBuffer = nil;
+  MTLComputeCommandEncoder_t _commandEncoder = nil;
+  MPSGraphExecutionDescriptor *_executionDescriptor = nil;
+  MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
+  dispatch_queue_t _serialQueue = nullptr;
+  // CommitAndContinue is enabled by default
+  bool _enableCommitAndContinue = true;
+
+  // use synchronize() to access any of these commit functions outside MPSStream
+  void commit();
+  void commitAndWait();
+  void commitAndContinue();
+  void flush();
+};
+
+/**
+ * Get the current MPS stream
+ */
+TORCH_API MPSStream* getCurrentMPSStream();
+
+/**
+ * Get the default MPS stream
+ */
+TORCH_API MPSStream* getDefaultMPSStream();
+
+//-----------------------------------------------------------------
+//  MPSStreamImpl
+//-----------------------------------------------------------------
+
+class TORCH_API MPSStreamImpl
+{
+ public:
+  /**
+   * Gets single instance of the MPSStream.
+   */
+  static MPSStream* getInstance();
+
+ private:
+  static MPSStream* _stream;
+  MPSStreamImpl();
+};
+
+} // namespace at::mps
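An illustrative sketch of driving a stream with the SyncType variants enumerated above (a usage sketch only, not upstream code):

    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    stream->synchronize(at::mps::SyncType::COMMIT);           // flush the current command buffer
    stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);  // flush and block until the GPU finishes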
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Activation.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Activation.h
new file mode 100644
index 0000000000000000000000000000000000000000..d9eb8081fe06aa8daf4d3b4a3ae5a7ebbbf37ed9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Activation.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+struct TensorIterator;
+struct TensorIteratorBase;
+class TensorBase;
+}
+
+namespace at::native {
+
+// These constants control the approximation behavior of gelu function.
+enum class GeluType {
+  None,             // Baseline Gelu
+  Tanh,             // Tanh Gelu Approximation
+  END
+};
+
+static GeluType get_gelutype_enum(const c10::string_view approximate) {
+  if (approximate == "none") {
+    return GeluType::None;
+  } else if (approximate == "tanh") {
+    return GeluType::Tanh;
+  } else {
+    TORCH_CHECK(false, "approximate argument must be either none or tanh.");
+  }
+}
+
+static std::string gelutype_to_string(const GeluType type) {
+  switch(type) {
+    case GeluType::None: return "none";
+    case GeluType::Tanh: return "tanh";
+    default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
+  }
+}
+
+using structured_activation_fn = void (*)(TensorIteratorBase&);
+using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
+
+using activation_fn = void (*)(TensorIterator&);
+using activation_backward_fn = void (*)(TensorIterator&);
+using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
+using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
+using hardsigmoid_fn = void(*)(TensorIteratorBase&);
+using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
+using hardswish_fn = void(*)(TensorIterator&);
+using hardswish_backward_fn = void(*)(TensorIterator&);
+using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
+using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
+using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
+using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
+using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
+using glu_jvp_fn = void (*)(TensorIteratorBase&);
+
+DECLARE_DISPATCH(elu_fn, elu_stub);
+DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
+DECLARE_DISPATCH(softplus_fn, softplus_stub);
+DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
+DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
+DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
+DECLARE_DISPATCH(threshold_fn, threshold_stub);
+DECLARE_DISPATCH(gelu_fn, GeluKernel);
+DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
+DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
+DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
+DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
+DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
+DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
+DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
+DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
+DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
+DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
+DECLARE_DISPATCH(structured_activation_fn, glu_stub);
+DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
+DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
+DECLARE_DISPATCH(structured_activation_fn, silu_stub);
+DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
+DECLARE_DISPATCH(structured_activation_fn, mish_stub);
+DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
+DECLARE_DISPATCH(activation_fn, prelu_stub);
+DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
+
+} // namespace at::native
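A quick illustration of the string-to-enum helpers above:

    auto gelu_type = at::native::get_gelutype_enum("tanh");     // GeluType::Tanh
    auto name      = at::native::gelutype_to_string(gelu_type); // "tanh"
    // Any other string (e.g. "fast") fails the TORCH_CHECK with
    // "approximate argument must be either none or tanh."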
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h b/MLPY/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h
new file mode 100644
index 0000000000000000000000000000000000000000..539a08ffee79c87028f10d7d71076a91006ddff2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/AdaptivePooling.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
+using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
+DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel);
+DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel);
+
+using adaptive_max_pooling_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
+using adaptive_max_pooling_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
+DECLARE_DISPATCH(adaptive_max_pooling_fn, adaptive_max_pool2d_kernel);
+DECLARE_DISPATCH(adaptive_max_pooling_backward_fn, adaptive_max_pool2d_backward_kernel);
+
+static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
+  return (a / b) * c + ((a % b) * c) / b;
+}
+
+static inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
+  return 1 + ((a + 1) * c - 1) / b;
+}
+
+static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
+  int64_t ndim = gradOutput_.ndimension();
+  for (const auto i : c10::irange(1, ndim)) {
+    TORCH_CHECK(gradOutput_.size(i) > 0,
+      arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
+      "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
+      " being empty");
+  }
+}
+
+} // namespace at::native
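To make the bin arithmetic of start_index/end_index concrete, a small worked example: adaptively pooling an input of length 5 down to 3 outputs yields overlapping half-open bins.

    // Same integer arithmetic as the helpers above, evaluated for b = 3 (output size), c = 5 (input size):
    //   a = 0: start = (0/3)*5 + ((0%3)*5)/3 = 0,  end = 1 + (1*5 - 1)/3 = 2  -> [0, 2)
    //   a = 1: start = (1/3)*5 + ((1%3)*5)/3 = 1,  end = 1 + (2*5 - 1)/3 = 4  -> [1, 4)
    //   a = 2: start = (2/3)*5 + ((2%3)*5)/3 = 3,  end = 1 + (3*5 - 1)/3 = 5  -> [3, 5)
    // Bins overlap whenever the input size is not an integer multiple of the output size.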
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/AmpKernels.h b/MLPY/Lib/site-packages/torch/include/ATen/native/AmpKernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..6bfd41885067671998aabe0bd178816c05325b1e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/AmpKernels.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <ATen/core/ATen_fwd.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
+    TensorList,
+    Tensor&,
+    const Tensor&);
+
+using _amp_update_scale_cpu__fn = Tensor& (*)(
+    Tensor&,
+    Tensor&,
+    const Tensor&,
+    double,
+    double,
+    int64_t);
+
+DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
+DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h b/MLPY/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h
new file mode 100644
index 0000000000000000000000000000000000000000..67b1c18d24e8b3e1f371020ac0fadf239fe37044
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h
@@ -0,0 +1,321 @@
+#pragma once
+
+#include <c10/util/Optional.h>
+#include <c10/util/string_view.h>
+#include <ATen/Config.h>
+#include <ATen/native/DispatchStub.h>
+
+// Forward declare TI
+namespace at {
+class Tensor;
+struct TensorIterator;
+
+namespace native {
+enum class TransposeType;
+}
+
+}
+
+namespace at::native {
+
+enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
+
+#if AT_BUILD_WITH_LAPACK()
+// Define per-batch functions to be used in the implementation of batched
+// linear algebra operations
+
+template <class scalar_t>
+void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
+
+template <class scalar_t>
+void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
+
+template <class scalar_t>
+void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
+
+template <class scalar_t>
+void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
+
+template <class scalar_t>
+void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
+
+template <class scalar_t>
+void lapackGels(char trans, int m, int n, int nrhs,
+    scalar_t *a, int lda, scalar_t *b, int ldb,
+    scalar_t *work, int lwork, int *info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackGelsd(int m, int n, int nrhs,
+    scalar_t *a, int lda, scalar_t *b, int ldb,
+    value_t *s, value_t rcond, int *rank,
+    scalar_t* work, int lwork,
+    value_t *rwork, int* iwork, int *info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackGelsy(int m, int n, int nrhs,
+    scalar_t *a, int lda, scalar_t *b, int ldb,
+    int *jpvt, value_t rcond, int *rank,
+    scalar_t *work, int lwork, value_t* rwork, int *info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackGelss(int m, int n, int nrhs,
+    scalar_t *a, int lda, scalar_t *b, int ldb,
+    value_t *s, value_t rcond, int *rank,
+    scalar_t *work, int lwork,
+    value_t *rwork, int *info);
+
+template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
+struct lapackLstsq_impl;
+
+template <class scalar_t, class value_t>
+struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
+  static void call(
+      char trans, int m, int n, int nrhs,
+      scalar_t *a, int lda, scalar_t *b, int ldb,
+      scalar_t *work, int lwork, int *info, // Gels flavor
+      int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
+      value_t *s, // Gelss flavor
+      int *iwork // Gelsd flavor
+      ) {
+    lapackGels(
+        trans, m, n, nrhs,
+        a, lda, b, ldb,
+        work, lwork, info);
+  }
+};
+
+template <class scalar_t, class value_t>
+struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
+  static void call(
+      char trans, int m, int n, int nrhs,
+      scalar_t *a, int lda, scalar_t *b, int ldb,
+      scalar_t *work, int lwork, int *info, // Gels flavor
+      int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
+      value_t *s, // Gelss flavor
+      int *iwork // Gelsd flavor
+      ) {
+    lapackGelsy(
+        m, n, nrhs,
+        a, lda, b, ldb,
+        jpvt, rcond, rank,
+        work, lwork, rwork, info);
+  }
+};
+
+template <class scalar_t, class value_t>
+struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
+  static void call(
+      char trans, int m, int n, int nrhs,
+      scalar_t *a, int lda, scalar_t *b, int ldb,
+      scalar_t *work, int lwork, int *info, // Gels flavor
+      int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
+      value_t *s, // Gelss flavor
+      int *iwork // Gelsd flavor
+      ) {
+    lapackGelsd(
+        m, n, nrhs,
+        a, lda, b, ldb,
+        s, rcond, rank,
+        work, lwork,
+        rwork, iwork, info);
+  }
+};
+
+template <class scalar_t, class value_t>
+struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
+  static void call(
+      char trans, int m, int n, int nrhs,
+      scalar_t *a, int lda, scalar_t *b, int ldb,
+      scalar_t *work, int lwork, int *info, // Gels flavor
+      int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
+      value_t *s, // Gelss flavor
+      int *iwork // Gelsd flavor
+      ) {
+    lapackGelss(
+        m, n, nrhs,
+        a, lda, b, ldb,
+        s, rcond, rank,
+        work, lwork,
+        rwork, info);
+  }
+};
+
+template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
+void lapackLstsq(
+    char trans, int m, int n, int nrhs,
+    scalar_t *a, int lda, scalar_t *b, int ldb,
+    scalar_t *work, int lwork, int *info, // Gels flavor
+    int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
+    value_t *s, // Gelss flavor
+    int *iwork // Gelsd flavor
+    ) {
+  lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
+      trans, m, n, nrhs,
+      a, lda, b, ldb,
+      work, lwork, info,
+      jpvt, rcond, rank, rwork,
+      s,
+      iwork);
+}
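+
+// Usage sketch (illustrative only): the driver tag selects the LAPACK routine at
+// compile time, e.g. lapackLstsq<LapackLstsqDriverType::Gelsd, double> forwards
+// to lapackGelsd, while each specialization simply ignores the arguments that
+// only matter for the other drivers (jpvt/rwork for Gelsy, s for Gelss, iwork
+// for Gelsd, and so on).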
+
+template <class scalar_t>
+void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
+
+template <class scalar_t>
+void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
+
+template <class scalar_t>
+void lapackLdlHermitian(
+    char uplo,
+    int n,
+    scalar_t* a,
+    int lda,
+    int* ipiv,
+    scalar_t* work,
+    int lwork,
+    int* info);
+
+template <class scalar_t>
+void lapackLdlSymmetric(
+    char uplo,
+    int n,
+    scalar_t* a,
+    int lda,
+    int* ipiv,
+    scalar_t* work,
+    int lwork,
+    int* info);
+
+template <class scalar_t>
+void lapackLdlSolveHermitian(
+    char uplo,
+    int n,
+    int nrhs,
+    scalar_t* a,
+    int lda,
+    int* ipiv,
+    scalar_t* b,
+    int ldb,
+    int* info);
+
+template <class scalar_t>
+void lapackLdlSolveSymmetric(
+    char uplo,
+    int n,
+    int nrhs,
+    scalar_t* a,
+    int lda,
+    int* ipiv,
+    scalar_t* b,
+    int ldb,
+    int* info);
+
+template <class scalar_t, class value_t = scalar_t>
+void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
+#endif
+
+#if AT_BUILD_WITH_BLAS()
+template <class scalar_t>
+void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
+#endif
+
+using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
+DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
+
+using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
+
+DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
+
+using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
+
+DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
+
+using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
+DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
+
+using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
+DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
+
+using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
+DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
+
+using linalg_eigh_fn = void (*)(
+    const Tensor& /*eigenvalues*/,
+    const Tensor& /*eigenvectors*/,
+    const Tensor& /*infos*/,
+    bool /*upper*/,
+    bool /*compute_eigenvectors*/);
+DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
+
+using lstsq_fn = void (*)(
+    const Tensor& /*a*/,
+    Tensor& /*b*/,
+    Tensor& /*rank*/,
+    Tensor& /*singular_values*/,
+    Tensor& /*infos*/,
+    double /*rcond*/,
+    std::string /*driver_name*/);
+DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
+
+using triangular_solve_fn = void (*)(
+    const Tensor& /*A*/,
+    const Tensor& /*B*/,
+    bool /*left*/,
+    bool /*upper*/,
+    TransposeType /*transpose*/,
+    bool /*unitriangular*/);
+DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
+
+using lu_factor_fn = void (*)(
+    const Tensor& /*input*/,
+    const Tensor& /*pivots*/,
+    const Tensor& /*infos*/,
+    bool /*compute_pivots*/);
+DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
+
+using unpack_pivots_fn = void(*)(
+  TensorIterator& iter,
+  const int64_t dim_size,
+  const int64_t max_pivot);
+DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
+
+using lu_solve_fn = void (*)(
+    const Tensor& /*LU*/,
+    const Tensor& /*pivots*/,
+    const Tensor& /*B*/,
+    TransposeType /*trans*/);
+DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
+
+using ldl_factor_fn = void (*)(
+    const Tensor& /*LD*/,
+    const Tensor& /*pivots*/,
+    const Tensor& /*info*/,
+    bool /*upper*/,
+    bool /*hermitian*/);
+DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
+
+using svd_fn = void (*)(
+    const Tensor& /*A*/,
+    const bool /*full_matrices*/,
+    const bool /*compute_uv*/,
+    const c10::optional<c10::string_view>& /*driver*/,
+    const Tensor& /*U*/,
+    const Tensor& /*S*/,
+    const Tensor& /*Vh*/,
+    const Tensor& /*info*/);
+DECLARE_DISPATCH(svd_fn, svd_stub);
+
+using ldl_solve_fn = void (*)(
+    const Tensor& /*LD*/,
+    const Tensor& /*pivots*/,
+    const Tensor& /*result*/,
+    bool /*upper*/,
+    bool /*hermitian*/);
+DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/BinaryOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/BinaryOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae39cea88579168b97712d6ffc9b57f92d612543
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/BinaryOps.h
@@ -0,0 +1,119 @@
+#pragma once
+
+#include <ATen/core/TensorBase.h>
+#include <ATen/native/DispatchStub.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/ScalarType.h>
+
+
+namespace at {
+struct TensorIterator;
+struct TensorIteratorBase;
+}
+
+namespace at::native {
+
+inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
+  TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
+              "Boolean alpha only supported for Boolean results.");
+  TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
+              || alpha.isIntegral(true),
+              "For integral input tensors, argument alpha must not be a floating point number.");
+  TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
+              "For non-complex input tensors, argument alpha must not be a complex number.")
+}
+
+// Basic checking for all sub functions.
+inline void sub_check(const TensorBase& self, const TensorBase& other) {
+  TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
+              "Subtraction, the `-` operator, with two bool tensors is not supported. "
+              "Use the `^` or `logical_xor()` operator instead.")
+  TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
+              "Subtraction, the `-` operator, with a bool tensor is not supported. "
+              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
+}
+
+inline void sub_check(const TensorBase& self, const Scalar& scalar) {
+  TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with two bool tensors is not supported. "
+              "Use the `^` or `logical_xor()` operator instead.")
+  TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
+              "Subtraction, the `-` operator, with a bool tensor is not supported. "
+              "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
+}
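+
+// Behaviour of the checks above, illustrated (assumed example, C++ frontend
+// syntax): for a kBool tensor `mask`,
+//
+//   // mask - mask  -> error, pointing at `^` / logical_xor()
+//   // mask - 1     -> error, pointing at `~` / logical_not() for mask inversion
+//   auto flipped = ~mask;   // the supported way to invert a boolean mask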
+
+using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
+using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
+using structured_binary_fn = void(*)(TensorIteratorBase&);
+
+using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
+using binary_fn_double = void(*)(TensorIterator&, double);
+using binary_fn = void(*)(TensorIterator&);
+using binary_clamp_fn_alpha =
+    void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
+
+// NB: codegenned
+DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
+
+DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
+DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
+DECLARE_DISPATCH(structured_binary_fn, mul_stub);
+DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
+DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
+DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
+DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
+DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
+DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
+DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
+DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
+DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
+DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
+DECLARE_DISPATCH(binary_fn, logical_xor_stub);
+DECLARE_DISPATCH(binary_fn, logical_and_stub);
+DECLARE_DISPATCH(binary_fn, logical_or_stub);
+DECLARE_DISPATCH(structured_binary_fn, lt_stub);
+DECLARE_DISPATCH(structured_binary_fn, le_stub);
+DECLARE_DISPATCH(structured_binary_fn, gt_stub);
+DECLARE_DISPATCH(structured_binary_fn, ge_stub);
+DECLARE_DISPATCH(structured_binary_fn, eq_stub);
+DECLARE_DISPATCH(structured_binary_fn, ne_stub);
+DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
+DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
+DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
+DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
+DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
+DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
+DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
+DECLARE_DISPATCH(binary_fn_double, huber_stub);
+DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
+DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
+DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
+DECLARE_DISPATCH(structured_binary_fn, mse_stub);
+DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
+DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
+DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
+DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
+DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
+DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
+DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
+DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
+DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
+DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
+DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
+DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
+DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
+DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
+DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
+DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
+DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
+DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..7c6f0599b03234ce674fff05531ba4f97450b3a4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/BucketizationUtils.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <ATen/native/TypeProperties.h>
+#include <ATen/ScalarOps.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/result_type_native.h>
+#endif
+
+namespace at::native {
+
+// original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to
+// the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not
+// match, will change them to be a common super type so comparisons are done between the same types.
+// For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the
+// corresponding raw_* version should be used since it was already contiguous of the right type.
+inline void searchsorted_maybe_trim_input_tensors(
+    Tensor& trimmed_input,
+    Tensor& trimmed_boundaries,
+    Tensor& trimmed_sorter,
+    const Tensor& raw_input,
+    const Tensor& raw_boundaries,
+    const Tensor& raw_sorter) {
+  bool in_is_contiguous = raw_input.is_contiguous();
+  bool bd_is_contiguous = raw_boundaries.is_contiguous();
+  bool sort_is_contiguous = raw_sorter.is_contiguous();
+
+  if (!in_is_contiguous) {
+    TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_input = raw_input.contiguous();
+  }
+  if (!bd_is_contiguous) {
+    TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_boundaries = raw_boundaries.contiguous();
+  }
+  if (!sort_is_contiguous) {
+    TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
+      "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
+      "tensor if possible. This message will only appear once per program.");
+    trimmed_sorter = raw_sorter.contiguous();
+  }
+  if (raw_input.dtype() != raw_boundaries.dtype()) {
+    at::native::ResultTypeState state = {};
+    state = at::native::update_result_type_state(raw_boundaries, state);
+    state = at::native::update_result_type_state(raw_input, state);
+    ScalarType common_stype = at::native::result_type(state);
+
+    TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
+    if (common_stype != raw_input.scalar_type()) {
+      trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
+    }
+    if (common_stype != raw_boundaries.scalar_type()) {
+      trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
+    }
+  }
+}
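+
+// Typical caller pattern (illustrative sketch, not part of this header): declare
+// empty "trimmed" tensors, let the helper populate them only when a copy or cast
+// is actually required, and otherwise keep using the raw inputs:
+//
+//   Tensor trimmed_in, trimmed_bd, trimmed_sort;
+//   searchsorted_maybe_trim_input_tensors(
+//       trimmed_in, trimmed_bd, trimmed_sort, raw_in, raw_bd, raw_sort);
+//   const Tensor& in = trimmed_in.defined() ? trimmed_in : raw_in;
+//   const Tensor& bd = trimmed_bd.defined() ? trimmed_bd : raw_bd;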
+
+/* unused but needed for internal jagged tensor class */
+inline void searchsorted_maybe_trim_input_tensors(
+    Tensor& trimmed_input,
+    Tensor& trimmed_boundaries,
+    const Tensor& raw_input,
+    const Tensor& raw_boundaries) {
+  Tensor trimmed_sorter;
+  Tensor raw_sorter;
+  return searchsorted_maybe_trim_input_tensors(
+      trimmed_input,
+      trimmed_boundaries,
+      trimmed_sorter,
+      raw_input,
+      raw_boundaries,
+      raw_sorter);
+}
+
+inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
+  if (boundaries.dim() != input.dim()) {
+    return false;
+  }
+  const auto& dims_bd = boundaries.sizes();
+  const auto& dims_in = input.sizes();
+  for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
+    if (dims_bd[dim] != dims_in[dim]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
+  auto tensor = c10::scalar_to_tensor(scalar, device);
+  // This is to adopt the scalar promotion rules defined in native/TypeProperties.h
+  // So we have the same type promotion rules as binary operations.
+  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
+  return tensor;
+}
+
+inline void searchsorted_pre_check(
+    const Tensor& boundaries,
+    const Tensor& input,
+    const Tensor& output,
+    const bool out_int32,
+    const bool right,
+    const c10::optional<c10::string_view> side_opt,
+    const Tensor& sorter) {
+  if (side_opt) {
+    const c10::string_view side = *side_opt;
+    TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
+      "got ", side);
+
+    // assume the user has not explicitly set (right=False, side="right")
+    TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
+    "of ", side, " while right was True");
+  }
+
+  TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
+    "should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
+    "tensor device type ", input.device());
+
+  if (sorter.defined()) {
+    TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
+      "have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
+      "device type ", boundaries.device());
+
+    TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
+      "size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes());
+
+    TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
+      "dtype but got dtype ", sorter.scalar_type());
+
+    if (sorter.numel() > 0) {
+      auto minmax = sorter.aminmax();
+      int64_t vmin = std::get<0>(minmax).item().toLong();
+      int64_t vmax = std::get<1>(minmax).item().toLong();
+      TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
+    }
+  }
+
+  TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
+    "torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
+    "boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
+    input.numel(), ")");
+
+  TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
+    "got 0 dimension");
+
+  TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
+    "torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
+    "and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
+    input.sizes());
+
+  ScalarType output_dtype = output.scalar_type();
+  TORCH_CHECK(
+      (output_dtype == ScalarType::Long && !out_int32) ||
+          (output_dtype == ScalarType::Int && out_int32),
+      "torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
+      "whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
+      " and out_int32 flag is ", (out_int32 ? "True" : "False"));
+
+  if (out_int32) {
+    TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
+      "torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
+      boundaries.sizes().back());
+  }
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/CPUBlas.h b/MLPY/Lib/site-packages/torch/include/ATen/native/CPUBlas.h
new file mode 100644
index 0000000000000000000000000000000000000000..90a3fd28fb85b17cc75ca2cc04f9bdc947dfb2a2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/CPUBlas.h
@@ -0,0 +1,189 @@
+#pragma once
+
+#include <ATen/OpMathType.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TransposeType.h>
+#include <c10/core/Scalar.h>
+#include <c10/core/ScalarType.h>
+#include <c10/util/complex.h>
+
+namespace at::native::cpublas {
+
+namespace internal {
+void normalize_last_dims(
+  TransposeType transa, TransposeType transb,
+  int64_t m, int64_t n, int64_t k,
+  int64_t *lda, int64_t *ldb, int64_t *ldc);
+}  // namespace internal
+
+using gemm_fn = void(*)(
+    at::ScalarType type,
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const Scalar& alpha,
+    const void *a, int64_t lda,
+    const void *b, int64_t ldb,
+    const Scalar& beta,
+    void *c, int64_t ldc);
+
+DECLARE_DISPATCH(gemm_fn, gemm_stub);
+
+template <typename scalar_t>
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    at::opmath_type<scalar_t> alpha,
+    const scalar_t *a, int64_t lda,
+    const scalar_t *b, int64_t ldb,
+    at::opmath_type<scalar_t> beta,
+    scalar_t *c, int64_t ldc) {
+  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
+  gemm_stub(
+    kCPU, c10::CppTypeToScalarType<scalar_t>::value,
+    transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+}
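+
+// Minimal usage sketch (illustrative; assumes the usual BLAS column-major
+// convention, so lda/ldb/ldc are the number of rows of A/B/C):
+//
+//   double A[6] = {1, 2, 3, 4, 5, 6};   // 2x3, column-major
+//   double B[6] = {1, 0, 0, 1, 0, 0};   // 3x2, column-major
+//   double C[4] = {0, 0, 0, 0};         // 2x2 result
+//   at::native::cpublas::gemm(
+//       TransposeType::NoTranspose, TransposeType::NoTranspose,
+//       /*m=*/2, /*n=*/2, /*k=*/3,
+//       /*alpha=*/1.0, A, /*lda=*/2, B, /*ldb=*/3,
+//       /*beta=*/0.0, C, /*ldc=*/2);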
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    double alpha,
+    const double *a, int64_t lda,
+    const double *b, int64_t ldb,
+    double beta,
+    double *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const float *a, int64_t lda,
+    const float *b, int64_t ldb,
+    float beta,
+    float *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const at::BFloat16 *a, int64_t lda,
+    const at::BFloat16 *b, int64_t ldb,
+    float beta,
+    at::BFloat16 *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const float alpha,
+    const at::BFloat16 *a, int64_t lda,
+    const at::BFloat16 *b, int64_t ldb,
+    const float beta,
+    float *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    float alpha,
+    const at::Half *a, int64_t lda,
+    const at::Half *b, int64_t ldb,
+    float beta,
+    at::Half *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    const float alpha,
+    const at::Half *a, int64_t lda,
+    const at::Half *b, int64_t ldb,
+    const float beta,
+    float *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    c10::complex<double> alpha,
+    const c10::complex<double> *a, int64_t lda,
+    const c10::complex<double> *b, int64_t ldb,
+    c10::complex<double> beta,
+    c10::complex<double> *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    c10::complex<float> alpha,
+    const c10::complex<float> *a, int64_t lda,
+    const c10::complex<float> *b, int64_t ldb,
+    c10::complex<float> beta,
+    c10::complex<float> *c, int64_t ldc);
+
+void gemm(
+    TransposeType transa, TransposeType transb,
+    int64_t m, int64_t n, int64_t k,
+    int64_t alpha,
+    const int64_t *a, int64_t lda,
+    const int64_t *b, int64_t ldb,
+    int64_t beta,
+    int64_t *c, int64_t ldc);
+
+template <typename scalar_t>
+void gemm_batched(
+    TransposeType transa, TransposeType transb,
+    int64_t batch_size, int64_t m, int64_t n, int64_t k,
+    scalar_t alpha,
+    const scalar_t * const *a, int64_t lda,
+    const scalar_t * const *b, int64_t ldb,
+    const scalar_t beta,
+    scalar_t * const *c, int64_t ldc);
+
+template <typename scalar_t>
+void gemm_batched_with_stride(
+    TransposeType transa, TransposeType transb,
+    int64_t batch_size, int64_t m, int64_t n, int64_t k,
+    scalar_t alpha,
+    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
+    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
+    scalar_t beta,
+    scalar_t *c, int64_t ldc, int64_t batch_stride_c);
+
+using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
+
+DECLARE_DISPATCH(axpy_fn, axpy_stub);
+
+template <typename scalar_t>
+void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
+  if(n == 1)
+  {
+    incx = 1;
+    incy = 1;
+  }
+  axpy_stub(
+      kCPU, c10::CppTypeToScalarType<scalar_t>::value,
+      n, a, x, incx, y, incy);
+}
+
+void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
+void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
+void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
+void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
+
+using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
+
+DECLARE_DISPATCH(copy_fn, copy_stub);
+
+template <typename scalar_t>
+void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
+  if(n == 1)
+  {
+    incx = 1;
+    incy = 1;
+  }
+  copy_stub(
+      kCPU, c10::CppTypeToScalarType<scalar_t>::value,
+      n, x, incx, y, incy);
+}
+
+void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
+void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
+void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
+void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
+
+}  // namespace at::native::cpublas
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/CPUFallback.h b/MLPY/Lib/site-packages/torch/include/ATen/native/CPUFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..7554956f5ce889a08a30cf140fc2c53924868b17
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/CPUFallback.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <ATen/core/stack.h>
+#include <ATen/core/boxing/KernelFunction.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/library.h>
+#include <c10/util/Metaprogramming.h>
+
+namespace at::native {
+
+// This function implements a boxed fallback to CPU.
+// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
+
+// This is a helper function that backends can use to directly call their boxed CPU fallback
+// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class FuncType>
+struct _call_fallback_fn final {};
+
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
+struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
+    static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
+        auto op = c10::Dispatcher::singleton()
+            // TODO: figure out how to make compiler happy without dynamic casts
+            .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
+            //.findSchemaOrThrow("a", "b")
+            .typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
+        return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
+            c10::BoxedKernel::makeFromFunction<fallback_fn>(),
+            op,
+            c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
+            // TODO: get std::forward<> to work
+            args...
+            );
+    }
+};
+
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
+using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
+
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
+using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
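+
+// Registration sketch (illustrative; the wrapper name and the PrivateUse1 key are
+// examples, not something this header provides):
+//
+//   void my_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+//     at::native::cpu_fallback(op, stack);
+//   }
+//   TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
+//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_cpu_fallback>());
+//   }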
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h b/MLPY/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h
new file mode 100644
index 0000000000000000000000000000000000000000..983ff7fe26e332a979ece32d42889081e6c56fcf
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h
@@ -0,0 +1,13 @@
+#pragma once
+#include <c10/macros/Export.h>
+#include <limits>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int64_t>::max());
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h
new file mode 100644
index 0000000000000000000000000000000000000000..f5e468e2be88bfdb00818a191d489bf06dab0bb9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ComplexHelper.h
@@ -0,0 +1,97 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <c10/util/irange.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/view_as_real_native.h>
+#include <ATen/ops/view_as_complex_native.h>
+
+#include <utility>
+#endif
+
+// WARNING: this header contains non-inline functions and should be only
+// included from ONE cpp file
+
+namespace at::native {
+
+// View tensor with new dtype, storage offset, sizes and strides
+inline Tensor view_tensor(
+    const Tensor &tensor, ScalarType dtype,
+    c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
+  Storage storage = tensor.storage();
+  auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
+  auto new_tensor = detail::make_tensor<TensorImpl>(
+      c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
+  auto * impl = new_tensor.unsafeGetTensorImpl();
+  impl->set_sizes_and_strides(sizes, strides, offset);
+  return new_tensor;
+}
+
+inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
+  SymDimVector res(oldstride.size() + 1);
+  for (const auto i : c10::irange(oldstride.size())) {
+    res[i] = oldstride[i] * 2;
+  }
+  res.back() = 1;
+  return res;
+}
+
+inline Tensor _view_as_real_physical(const Tensor& self) {
+  TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
+  auto old_sizes = self.sym_sizes();
+  SymDimVector new_sizes(old_sizes.size() + 1);
+  std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
+  // last dimension will always have two elements containing the real and imag vals
+  new_sizes.back() = 2;
+  auto new_strides = computeStrideForViewAsReal(self.sym_strides());
+  auto new_storage_offset = self.sym_storage_offset() * 2;
+  const auto float_type = c10::toRealValueType(self.scalar_type());
+  auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
+  return real_tensor;
+}
+
+// expects as input a complex tensor and returns back a tensor
+// with corresponding real dtype containing the complex values
+// in the last two dimensions
+Tensor view_as_real(const Tensor& self) {
+  TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
+  return _view_as_real_physical(self);
+}
+
+inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
+  const int64_t dim = oldstride.size();
+  TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");
+
+  SymDimVector res(dim - 1);
+  for (const auto i : c10::irange(res.size())) {
+    TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
+    res[i] = oldstride[i] / 2;
+  }
+  return res;
+}
+
+// expects as input a float or double tensor with last dimension of size 2
+// and returns back a tensor with corresponding complex dtype
+Tensor view_as_complex(const Tensor& self) {
+  TORCH_CHECK(
+    self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
+    "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
+
+  auto old_sizes = self.sym_sizes();
+  TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
+  TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
+  SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
+
+  const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
+  const auto complex_type = c10::toComplexType(self.scalar_type());
+
+  TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
+  const auto new_storage_offset = self.sym_storage_offset() / 2;
+
+  return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
+}
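+
+// Example (illustrative): a float tensor of shape (N, 2) with strides (2, 1) and
+// an even storage offset is viewed as a complex64 tensor of shape (N,) with
+// stride (1,). view_as_real performs the inverse mapping, appending a size-2
+// trailing dimension holding the (real, imag) pairs.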
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h b/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h
new file mode 100644
index 0000000000000000000000000000000000000000..27aefd57376f4468da7f628cec608c4e4837c4b5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/native/CompositeRandomAccessorCommon.h>
+
+namespace at::native {
+
+struct TupleInfoCPU {
+  template <typename... Types>
+  using tuple = std::tuple<Types...>;
+
+  template <typename... Types>
+  static constexpr auto tie(Types&... args) noexcept {
+    return std::tie(args...);
+  }
+};
+
+template <typename KeyAccessor, typename ValueAccessor>
+using CompositeRandomAccessorCPU =
+  CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
+
+template <typename Values, typename References>
+void swap(
+  references_holder<Values, References> rh1,
+  references_holder<Values, References> rh2
+) {
+  return std::swap(rh1.data(), rh2.data());
+}
+
+template <int N, typename Values, typename References>
+auto get(references_holder<Values, References> rh) -> decltype(std::get<N>(rh.data())) {
+  return std::get<N>(rh.data());
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h b/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
new file mode 100644
index 0000000000000000000000000000000000000000..5db76a15575c4542004ef6fbedbdc3fb73d1f3fb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
@@ -0,0 +1,263 @@
+#include <utility>
+
+#pragma once
+
+namespace at::native {
+
+namespace {
+
+// operator_brackets_proxy is used in
+// CompositeRandomAccessor in place of operator[].
+// For some iterators, references returned by operator[]
+// could become invalid, operator_brackets_proxy tries to
+// resolve that by making accessor[n] to be equivalent to
+// *(accessor + n).
+template <typename Accessor>
+class operator_brackets_proxy {
+  using reference = typename std::iterator_traits<Accessor>::reference;
+  using value_type = typename std::iterator_traits<Accessor>::value_type;
+
+public:
+  C10_HOST_DEVICE
+  operator_brackets_proxy(Accessor const& accessor)
+    : accessor(accessor)
+  {}
+
+  C10_HOST_DEVICE
+  operator reference() {
+    return *accessor;
+  }
+
+  C10_HOST_DEVICE
+  reference operator*() {
+    return *accessor;
+  }
+
+  C10_HOST_DEVICE
+  operator_brackets_proxy& operator=(value_type const& val) {
+    *accessor = val;
+    return *this;
+  }
+
+private:
+  Accessor accessor;
+};
+
+}
+
+// references_holder is used as a surrogate for the
+// references type from std::iterator_traits in CompositeRandomAccessor.
+// It is assumed in CompositeRandomAccessor that
+// References = tuple<Types&...>,
+// Values = tuple<Types...> by default,
+// but they could be anything as long as References could be
+// cast to Values.
+// If you plan to use it with STL, for example, you will need to
+// define 'swap` and `get`(aka std::get) methods.
+template <typename Values, typename References>
+class references_holder {
+public:
+  using values = Values;
+  using references = References;
+
+  C10_HOST_DEVICE
+  references_holder(references refs)
+    : refs{std::move(refs)}
+  {}
+
+  C10_HOST_DEVICE
+  operator references() {
+    return refs;
+  }
+
+  C10_HOST_DEVICE
+  operator values() {
+    return refs;
+  }
+
+  C10_HOST_DEVICE
+  references_holder& operator=(values vals) {
+    refs = vals;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  references& data() {
+    return refs;
+  }
+
+protected:
+  references refs;
+};
+
+// CompositeRandomAccessor is essentially a simplified version of
+// a random access iterator over two random access iterators.
+// TupleInfo should contain a variadic type `tuple`, and a method `tie`,
+// which constructs a tuple of references from a variadic list of arguments.
+template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
+class CompositeRandomAccessor {
+  using self_type = CompositeRandomAccessor;
+
+  using key_accessor_value_type =
+    typename std::iterator_traits<KeyAccessor>::value_type;
+  using value_accessor_value_type =
+    typename std::iterator_traits<ValueAccessor>::value_type;
+  using key_accessor_reference_type =
+    typename std::iterator_traits<KeyAccessor>::reference;
+  using value_accessor_reference_type =
+    typename std::iterator_traits<ValueAccessor>::reference;
+
+  using composite_value_type = typename TupleInfo::template tuple<
+    key_accessor_value_type,
+    value_accessor_value_type>;
+  using composite_reference = typename TupleInfo::template tuple<
+    key_accessor_reference_type,
+    value_accessor_reference_type>;
+
+public:
+  using value_type = composite_value_type;
+  using reference = references_holder<composite_value_type, composite_reference>;
+  // Note that CompositeRandomAccessor does not hold key and values
+  // in a specific datastructure, which means that a pointer to a (key, value)
+  // is not defined. Hence we just use a pointer type of the KeyAccessor.
+  using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
+  using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
+  using iterator_category = std::random_access_iterator_tag;
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor() = default;
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
+    : keys(keys), values(values)
+  {}
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return TupleInfo::tie(*keys, *values);
+  }
+
+  // operator->() is supposed to return a pointer type.
+  // Since CompositeRandomAccessor does not hold pointers to pairs,
+  // we just return a pointer to a key.
+  C10_HOST_DEVICE
+  auto* operator->() const {
+    return keys.operator->();
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](difference_type idx) {
+    return operator_brackets_proxy<self_type>(
+      CompositeRandomAccessor(keys + idx, values + idx)
+    );
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator++() {
+    ++keys;
+    ++values;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator++(int) {
+    CompositeRandomAccessor copy(*this);
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator--() {
+    --keys;
+    --values;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator--(int) {
+    CompositeRandomAccessor copy(*this);
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator+=(difference_type offset) {
+    keys += offset;
+    values += offset;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator+(difference_type offset) const {
+    return CompositeRandomAccessor(keys + offset, values + offset);
+  }
+
+  C10_HOST_DEVICE
+  friend CompositeRandomAccessor operator+(
+    difference_type offset,
+    const CompositeRandomAccessor& accessor
+  ) {
+    return accessor + offset;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor& operator-=(difference_type offset) {
+    keys -= offset;
+    values -= offset;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  CompositeRandomAccessor operator-(difference_type offset) const {
+    return CompositeRandomAccessor(keys - offset, values - offset);
+  }
+
+  C10_HOST_DEVICE
+  difference_type operator-(const CompositeRandomAccessor& other) const {
+    return keys - other.keys;
+  }
+  // }
+
+  // Comparison operators {
+  C10_HOST_DEVICE
+  bool operator==(const CompositeRandomAccessor& other) const {
+    return keys == other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator!=(const CompositeRandomAccessor& other) const {
+    return keys != other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator<(const CompositeRandomAccessor& other) const {
+    return keys < other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator<=(const CompositeRandomAccessor& other) const {
+    return keys <= other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator>(const CompositeRandomAccessor& other) const {
+    return keys > other.keys;
+  }
+
+  C10_HOST_DEVICE
+  bool operator>=(const CompositeRandomAccessor& other) const {
+    return keys >= other.keys;
+  }
+  // }
+
+protected:
+  KeyAccessor keys;
+  ValueAccessor values;
+};
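+
+// Usage sketch (illustrative; CompositeRandomAccessorCPU and the swap/get helpers
+// it relies on live in CompositeRandomAccessor.h): sorting two parallel arrays in
+// lockstep by key.
+//
+//   int64_t keys[3]   = {3, 1, 2};
+//   double  values[3] = {30.0, 10.0, 20.0};
+//   auto it = CompositeRandomAccessorCPU<int64_t*, double*>(keys, values);
+//   std::sort(it, it + 3,
+//             [](std::tuple<int64_t, double> a, std::tuple<int64_t, double> b) {
+//               return std::get<0>(a) < std::get<0>(b);
+//             });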
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ConvUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ConvUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..7784492c004f022390770d494a85f156f97ff69e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ConvUtils.h
@@ -0,0 +1,446 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/detail/CUDAHooksInterface.h>
+#include <ATen/native/DispatchStub.h>
+#include <c10/util/env.h>
+#include <c10/util/irange.h>
+
+namespace at::native {
+
+using conv_depthwise2d_backward_fn = std::tuple<at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
+DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
+using conv_depthwise3d_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
+DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
+using cudnn_convolution_backward_fn = std::tuple<at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool, 2>);
+DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
+using mps_convolution_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, std::array<bool, 3>);
+DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
+using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool, 2>);
+DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
+using miopen_convolution_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, std::array<bool, 3>);
+DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
+using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool, 3>);
+DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
+using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, bool, bool, std::array<bool, 3>);
+DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
+using mkldnn_convolution_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, int64_t, std::array<bool, 3>);
+DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
+using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const c10::optional<Tensor>&,
+    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
+DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
+using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool, 3>);
+DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
+using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
+DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
+using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
+DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
+using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
+DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
+using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor, at::Tensor, at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
+DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
+
+namespace {
+  static bool cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
+}
+
+static inline bool cudnnv8_enabled_check_debug() {
+  static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
+  static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
+  static uint8_t cudnnv8_debugcount = 0;
+  if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
+    TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", cudnnv8_heuristic_mode_b);
+    cudnnv8_debugcount++;
+  }
+  return cudnnv8_flag == 1;
+}
+
+static inline bool cudnnv8_use_heur_mode_b() {
+  return cudnnv8_heuristic_mode_b;
+}
+
+// Keep in sync with py::enum_<ConvBackend> in Module.cpp
+enum class ConvBackend {
+  CudaDepthwise2d,
+  CudaDepthwise3d,
+  Cudnn,
+  CudnnTranspose,
+  Empty,
+  Miopen,
+  MiopenDepthwise,
+  MiopenTranspose,
+  Mkldnn,
+  MkldnnTranspose,
+  MkldnnEmpty,
+  NnpackSpatial,
+  Overrideable,
+  Slow2d,
+  Slow3d,
+  SlowDilated2d,
+  SlowDilated3d,
+  SlowTranspose2d,
+  SlowTranspose3d,
+  Winograd3x3Depthwise,
+  Xnnpack2d,
+  Mps,
+  MpsTranspose,
+};
+
+// Overload for selecting the convolution backend from the full set of convolution inputs.
+// This overload is exposed to python for testing, etc.
+TORCH_API ConvBackend select_conv_backend(
+    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
+    bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
+
+TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
+    const Tensor& weight,
+    const ConvBackend backend);
+
+// ---------------------------------------------------------------------
+//
+// Math
+//
+// ---------------------------------------------------------------------
+
+constexpr int input_batch_size_dim = 0;  // also grad_input
+constexpr int input_channels_dim = 1;
+constexpr int output_batch_size_dim = 0;  // also grad_output
+constexpr int output_channels_dim = 1;
+constexpr int weight_output_channels_dim = 0;
+constexpr int weight_input_channels_dim = 1;
+
+// Often written as 2 + max_dim (extra dims for batch size and channels)
+constexpr int max_dim = 3;
+
+// ---------------------------------------------------------------------
+//
+// Checking
+//
+// ---------------------------------------------------------------------
+
+// Used on pad, stride and dilation
+static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
+{
+  TORCH_CHECK(args.size() <= expected_size,
+           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
+           expected_size, " (while checking arguments for ", c, ")");
+  TORCH_CHECK(args.size() >= expected_size,
+           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
+           expected_size, " (while checking arguments for ", c, ")");
+
+  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
+  if (num_negative_values > 0){
+    std::stringstream ss;
+    ss << arg_name << " should be greater than zero but got (";
+    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
+    ss << args.back() <<  ")" << " (while checking arguments for " << c << ")";
+    AT_ERROR(ss.str());
+  }
+}
+
+
+// NOTE [ Convolution checks ]
+//
+// NB: For many call sites, it is not strictly necessary to check all of
+// these relationships (for example, for forward convolution, we compute
+// the size of output ourselves, so we don't actually need to check
+// output.  However, writing a single function that does everything
+// means we get to reuse it for both forwards and all backwards
+// variants, even when the set of "real" inputs varies.  The magic of
+// relational computing!
+//
+// (There is one downside, which is that it is slightly harder to write
+// error messages which are able to distinguish between real inputs
+// (which the user can change) and computed inputs (which the user can
+// only indirectly affect).  It would be an interesting exercise to
+// come up with a general framework to handle such situations.)
+static void convolution_shape_check(
+    CheckedFrom c,
+    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
+{
+  check_args(c, padding, input->dim() - 2, "padding");
+  check_args(c, stride, padding.size(), "stride");
+  check_args(c, dilation, padding.size(), "dilation");
+
+  // Input
+  checkDimRange(c, input, 3, 6 /* exclusive */);
+  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);
+
+  // Weight
+  checkSameDim(c, input, weight);
+
+  // TODO: check that output->size() matches output_sizes
+  // TODO: check that weight matches output->sizes()
+  checkSameDim(c, input, output);
+}
+
+// NB: conv_output_size and conv_input_size are not bijections,
+// as conv_output_size loses information; this is why conv_input_size
+// takes an extra output_padding argument to resolve the ambiguity.
+
+template <typename T>
+static inline std::vector<T> _conv_output_size(
+    ArrayRef<T> input_size, ArrayRef<T> weight_size,
+    ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
+) {
+  // ASSERT(input_size.size() > 2)
+  // ASSERT(input_size.size() == weight_size.size())
+  bool has_dilation = !dilation.empty();
+  auto dim = input_size.size();
+  std::vector<T> output_size(dim);
+  output_size[0] = input_size[input_batch_size_dim];
+  output_size[1] = weight_size[weight_output_channels_dim];
+  for (const auto d : c10::irange(2, dim)) {
+    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
+    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
+    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
+  }
+  return output_size;
+}
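+
+// Worked example (illustrative): a 224x224 input, 3x3 weight, padding 1,
+// stride 2, dilation 1 gives, per spatial dim,
+//   kernel = 1 * (3 - 1) + 1 = 3
+//   output = (224 + 2*1 - 3) / 2 + 1 = 112,
+// i.e. the familiar 224 -> 112 halving produced by a strided 3x3 convolution.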
+
+static inline std::vector<int64_t> conv_output_size(
+    IntArrayRef input_size, IntArrayRef weight_size,
+    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
+) {
+  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
+}
+
+static inline std::vector<c10::SymInt> conv_output_size(
+    SymIntArrayRef input_size, SymIntArrayRef weight_size,
+    SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
+) {
+  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
+}
+
+template <typename T>
+std::vector<T> _conv_input_size(
+    ArrayRef<T> output_size, ArrayRef<T> weight_size,
+    ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
+) {
+  // ASSERT(output_size.size() > 2)
+  // ASSERT(output_size.size() == weight_size.size())
+  auto dim = output_size.size();
+  std::vector<T> input_size(dim);
+  input_size[0] = output_size[output_batch_size_dim];
+  input_size[1] = weight_size[weight_input_channels_dim] * groups;
+  for (const auto d : c10::irange(2, dim)) {
+    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
+    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
+                     kernel + output_padding[d - 2];
+  }
+  return input_size;
+}
+
+static inline std::vector<c10::SymInt> conv_input_size(
+    SymIntArrayRef output_size, SymIntArrayRef weight_size,
+    SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
+) {
+  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
+}
+
+static inline std::vector<int64_t> conv_input_size(
+    IntArrayRef output_size, IntArrayRef weight_size,
+    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
+}
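+
+// Illustrative note (not part of the original ATen header) on the ambiguity that
+// output_padding resolves: with kernel = 3, padding = 0, stride = 2, dilation = 1,
+// both input_size[d] = 7 and input_size[d] = 8 give conv_output_size[d] = 3, since
+//   (7 - 3) / 2 + 1 == (8 - 3) / 2 + 1 == 3   (integer division).
+// conv_input_size with output_size[d] = 3 then reconstructs
+//   (3 - 1) * 2 - 0 + 3 + output_padding[d]  ->  7 when output_padding[d] = 0,
+//                                                8 when output_padding[d] = 1.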
+
+template <typename T>
+std::vector<T> _conv_weight_size(
+    ArrayRef<T> input_size, ArrayRef<T> output_size,
+    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  auto dim = input_size.size();
+  std::vector<T> weight_size(dim);
+  weight_size[0] = output_size[1];
+  weight_size[1] = input_size[1] / groups;
+  for (const auto d : c10::irange(2, dim)) {
+    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
+               + padding[d - 2] * 2 - output_padding[d - 2];
+    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
+  }
+  return weight_size;
+}
+
+static inline std::vector<c10::SymInt> conv_weight_size(
+    SymIntArrayRef input_size, SymIntArrayRef output_size,
+    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
+}
+
+static inline std::vector<int64_t> conv_weight_size(
+    IntArrayRef input_size, IntArrayRef output_size,
+    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
+) {
+  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
+}
+
+static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
+  std::vector<int64_t> shape(dim, 1);
+  shape[1] = -1;
+  return bias.reshape(shape);
+}
+
+static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
+  // disable NHWC for float64 input.
+  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return at::MemoryFormat::Contiguous;
+  }
+  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+  auto weight_ndim = weight.ndimension();
+
+  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
+    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+    (weight_memory_format == at::MemoryFormat::ChannelsLast)
+  );
+  if (can_use_cudnn_channels_last_2d) {
+    return at::MemoryFormat::ChannelsLast;
+  }
+
+  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
+    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
+    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
+  );
+  if (can_use_cudnn_channels_last_3d) {
+    return at::MemoryFormat::ChannelsLast3d;
+  }
+
+  return at::MemoryFormat::Contiguous;
+}
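+
+// Illustrative usage sketch (an assumption, not part of the original ATen header):
+// callers typically materialize both operands in the suggested layout, e.g.
+//
+//   auto fmt      = cudnn_conv_suggest_memory_format(input, weight);
+//   auto input_c  = input.contiguous(fmt);
+//   auto weight_c = weight.contiguous(fmt);
+//
+// so a float32 4-d channels-last input on cuDNN >= 7.6.3 keeps ChannelsLast,
+// while any float64 input falls back to MemoryFormat::Contiguous.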
+
+// controls whether emptyCache will be called following cudnn conv benchmarking
+TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
+TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
+
+
+static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+  // disable NHWC for float64 input.
+  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
+      input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return false;
+  }
+
+  bool can_use_miopen_channels_last_2d = false;
+#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
+  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+  // See #64427
+  static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+
+  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC &&  *PYTORCH_MIOPEN_SUGGEST_NHWC && (
+            ( (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+            (weight_memory_format == at::MemoryFormat::ChannelsLast) )
+        );
+#endif
+
+  bool can_use_miopen_channels_last_3d = false;
+
+  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
+}
+
+static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+  // disable NHWC for float64 input.
+  if (input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return false;
+  }
+
+  // disable NHWC for MkldnnCPU tensor.
+  if (input.is_mkldnn() || weight.is_mkldnn()) {
+    return false;
+  }
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+
+  bool can_use_mkldnn_channels_last_2d =
+      (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast);
+
+  bool can_use_mkldnn_channels_last_3d =
+      (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
+
+  return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
+}
+
+static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+
+  bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
+      (input_memory_format  == at::MemoryFormat::ChannelsLast) || (
+       weight_memory_format == at::MemoryFormat::ChannelsLast));
+
+  return can_use_thnn_channels_last_2d;
+}
+
+static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+
+  // check layout only for xpu tensor.
+  if (!input.is_xpu() || !weight.is_xpu()) {
+    return false;
+  }
+
+  // disable NHWC for float64 input.
+  if (input.scalar_type() == at::kDouble ||
+      weight.scalar_type() == at::kDouble) {
+    return false;
+  }
+
+  auto input_memory_format = input.suggest_memory_format();
+  auto weight_memory_format = weight.suggest_memory_format();
+
+  bool can_use_xpu_channels_last_2d =
+      (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast);
+
+  bool can_use_xpu_channels_last_3d =
+      (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast3d);
+
+  return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h
new file mode 100644
index 0000000000000000000000000000000000000000..6db6f69d96a67c04ef0e689d88c1cf40392d9e18
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ConvolutionMM3d.h
@@ -0,0 +1,14 @@
+#include <ATen/core/Tensor.h>
+
+namespace at::native {
+
+std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
+    const Tensor& grad_output,
+    const Tensor& self,
+    const Tensor& weight,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    std::array<bool, 3> output_mask);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Copy.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Copy.h
new file mode 100644
index 0000000000000000000000000000000000000000..200ea0e1d96cc02e7de2e29ea7217f7f1cd0aea0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Copy.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+
+class Tensor;
+struct TensorIterator;
+class TensorBase;
+
+namespace native {
+
+using copy_fn = void (*)(TensorIterator&, bool non_blocking);
+
+DECLARE_DISPATCH(copy_fn, copy_stub);
+
+TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Cross.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Cross.h
new file mode 100644
index 0000000000000000000000000000000000000000..a2bf16e6cd3ad275ee224910480263418c3b91d0
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Cross.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);
+
+DECLARE_DISPATCH(cross_fn, cross_stub);
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..687700b8d2a133d4355a65495a44e2bf534ae0bd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h
@@ -0,0 +1,229 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
+  TORCH_CHECK(                                       \
+      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
+      "Need " #T " of dimension ",                   \
+      DIM,                                           \
+      " and " #T ".size[",                           \
+      DIM_SIZE,                                      \
+      "] == ",                                       \
+      SIZE,                                          \
+      " but got input to be of shape ",              \
+      T.sizes())
+
+namespace at::native::internal {
+namespace {
+inline bool all_positive(IntArrayRef& arr) {
+  return std::all_of(
+      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
+}
+
+inline bool all_nonnegative(std::vector<int64_t>& arr) {
+  return std::all_of(
+      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
+}
+
+} // namespace
+
+// calculate the rear part of output tensor sizes
+template <int64_t dim>
+std::vector<int64_t> get_output_size(
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef stride_size,
+    IntArrayRef pad_size,
+    IntArrayRef dilation_size) {
+  std::vector<int64_t> sizes;
+  for (const auto index : c10::irange(dim)) {
+    sizes.push_back(
+        div_rtn<int64_t>(
+            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
+                (dilation_size[index] * (kernel_size[index] - 1) + 1),
+            stride_size[index]) +
+        1);
+  }
+  return sizes;
+}
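+
+// Illustrative example (not part of the original ATen header) of the formula above,
+// assuming one spatial dim of extent 10 with kernel_size = {3}, stride_size = {2},
+// pad_size = {1} and dilation_size = {2}:
+//   effective kernel = 2 * (3 - 1) + 1 = 5
+//   sizes[0] = div_rtn<int64_t>(10 + 2*1 - 5, 2) + 1 = 3 + 1 = 4
+// div_rtn rounds toward negative infinity, which only matters once the numerator
+// can go negative (i.e. when the padded input is smaller than the dilated kernel).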
+
+// calculate the sizes of output tensor
+template <int64_t dim>
+std::vector<int64_t> get_output_size(
+    const Tensor& input,
+    const Tensor& weight,
+    IntArrayRef kernel_size,
+    IntArrayRef stride_size,
+    IntArrayRef pad_size,
+    IntArrayRef dilation_size) {
+  auto output_size = get_output_size<dim>(
+      input, kernel_size, stride_size, pad_size, dilation_size);
+  output_size.insert(output_size.begin(), weight.size(0));
+  if (input.dim() == dim + 2) {
+    output_size.insert(output_size.begin(), input.size(0));
+  }
+  return output_size;
+}
+/*
+  slow_conv_dilated_shape_check - check user-input to dilated convolution
+  forward and backward functions.
+*/
+template <int64_t dim>
+void slow_conv_dilated_shape_check(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const Tensor& grad_output,
+    IntArrayRef kernel_size,
+    IntArrayRef stride_size,
+    IntArrayRef pad_size,
+    IntArrayRef dilation_size) {
+  /*
+    When the following tensors are defined:
+
+    bias, grad_weight, grad_output
+
+    then these are assumed to be contiguous without checking
+    because these tensors are made contiguous by calling
+    .contiguous() method or by resizing of zero-sized tensors in
+    forward/backward functions.
+
+    When grad_weight is defined then it is assumed without
+    checking to have the same shape as weight, see backward
+    functions.
+   */
+  // Check size arguments
+  TORCH_CHECK(
+      kernel_size.size() == dim,
+      "kernel sizes length should be ",
+      dim,
+      ", but got ",
+      kernel_size.size());
+  TORCH_CHECK(
+      stride_size.size() == dim,
+      "strides length should be ",
+      dim,
+      ", but got ",
+      stride_size.size());
+  TORCH_CHECK(
+      dilation_size.size() == dim,
+      "dilations length should be ",
+      dim,
+      ", but got ",
+      dilation_size.size());
+  TORCH_CHECK(
+      pad_size.size() == dim,
+      "pads length should be ",
+      dim,
+      ", but got ",
+      pad_size.size());
+
+  TORCH_CHECK(
+      all_positive(kernel_size),
+      "kernel size should be greater than zero, but got ",
+      kernel_size);
+  TORCH_CHECK(
+      all_positive(stride_size),
+      "stride should be greater than zero, but got ",
+      stride_size);
+  TORCH_CHECK(
+      all_positive(dilation_size),
+      "dilation should be greater than zero, but got ",
+      dilation_size);
+
+  // check input
+  TORCH_CHECK(input.defined(), "input must be defined");
+  bool is_batch = input.dim() == dim + 2;
+  int64_t n = (is_batch ? 2 : 1);
+  int64_t ndim = n + dim;
+  if (!is_batch) {
+    // input dim has to be dim + 1 if not batched
+    TORCH_CHECK(
+        input.dim() == dim + 1,
+        "input must be 4D or 5D tensor but got ",
+        input.dim(),
+        "D tensor");
+  }
+
+  // check output sizes
+  auto output_size = get_output_size<dim>(
+      input, kernel_size, stride_size, pad_size, dilation_size);
+
+  TORCH_CHECK(
+      all_nonnegative(output_size),
+      "calculated output size ",
+      output_size,
+      " is too small (all sizes must be non-negative)");
+
+  // check weight
+  TORCH_CHECK(weight.defined(), "weight must be defined");
+  TORCH_CHECK(
+      weight.dim() == dim + 2,
+      "weight must be ",
+      dim + 2,
+      "D tensor but got ",
+      weight.dim(),
+      "D tensor dim=",
+      dim);
+  TORCH_CHECK(
+      weight.sizes().slice(2) == kernel_size,
+      "weight[2:] shape ",
+      weight.sizes().slice(2),
+      " must be equal to kernel_size ",
+      kernel_size);
+
+  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
+
+  // check bias when present
+  if (bias.defined()) {
+    TORCH_CHECK(
+        bias.dim() == 1,
+        "bias must be 1D tensor but got ",
+        bias.dim(),
+        "D tensor");
+    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
+  }
+
+  // check grad_output when present
+  if (grad_output.defined()) {
+    TORCH_CHECK(
+        grad_output.dim() == ndim,
+        "grad_output must be ",
+        ndim,
+        "D tensor but got ",
+        grad_output.dim(),
+        "D tensor");
+    if (is_batch) {
+      TORCH_CHECK(
+          grad_output.size(0) == input.size(0),
+          "grad_output.size(0)=",
+          grad_output.size(0),
+          " must be input.size(0)=",
+          input.size(0));
+    }
+    TORCH_CHECK(
+        grad_output.size(n - 1) == weight.size(0),
+        "grad_output.size(",
+        n - 1,
+        ")=",
+        grad_output.size(n - 1),
+        " must be weight.size(0)=",
+        weight.size(0));
+    TORCH_CHECK(
+        grad_output.sizes().slice(n) == output_size,
+        "grad_output[",
+        n,
+        ":] shape",
+        grad_output.sizes().slice(n),
+        " must be equal to output size ",
+        output_size);
+  }
+}
+
+} // namespace at::native::internal
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/DispatchStub.h b/MLPY/Lib/site-packages/torch/include/ATen/native/DispatchStub.h
new file mode 100644
index 0000000000000000000000000000000000000000..d474f2ce342c9399f52d9b654033132c0330ecdc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/DispatchStub.h
@@ -0,0 +1,315 @@
+#pragma once
+
+#include <c10/core/DeviceType.h>
+#include <c10/macros/Macros.h>
+
+#include <atomic>
+#include <utility>
+
+// Implements instruction set specific function dispatch.
+//
+// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
+// compiled multiple times with different compiler flags (e.g. -mavx2). A
+// DispatchStub contains a table of function pointers for a kernel. At runtime,
+// the fastest available kernel is chosen based on the features reported by
+// cpuinfo.
+//
+// Example:
+//
+// In native/MyKernel.h:
+//   using fn_type = void(*)(const Tensor& x);
+//   DECLARE_DISPATCH(fn_type, stub);
+//
+// In native/MyKernel.cpp
+//   DEFINE_DISPATCH(stub);
+//
+// In native/cpu/MyKernel.cpp:
+//   namespace {
+//     // use anonymous namespace so that different cpu versions won't conflict
+//     void kernel(const Tensor& x) { ... }
+//   }
+//   REGISTER_DISPATCH(stub, &kernel);
+//
+// To call:
+//   stub(kCPU, tensor);
+//
+// TODO: CPU instruction set selection should be folded into whatever
+// the main dispatch mechanism is.
+
+// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
+C10_CLANG_DIAGNOSTIC_PUSH()
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
+
+namespace at::native {
+
+enum class CPUCapability {
+  DEFAULT = 0,
+#if defined(HAVE_VSX_CPU_DEFINITION)
+  VSX = 1,
+#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
+  ZVECTOR = 1,
+#else
+  AVX2 = 1,
+  AVX512 = 2,
+#endif
+  NUM_OPTIONS
+};
+
+CPUCapability get_cpu_capability();
+
+template <typename FnPtr, typename T>
+struct DispatchStub;
+
+/**
+ * The sole purpose of this class is to outline methods that don't need to be
+ * specialized or otherwise inlined and duplicated (by the compiler due to
+ * template expansion), since it causes size bloat if there are a significant
+ * number of specialization of the DispatchStub<> class.
+ */
+struct TORCH_API DispatchStubImpl {
+  void* get_call_ptr(
+    c10::DeviceType device_type
+    , void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , void *ZVECTOR
+#endif
+  );
+
+  /**
+   * The CPU Dispatch actual method is chosen in decreasing order of preference by
+   * DispatchStubImpl::choose_cpu_impl() in case none is found by
+   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
+   */
+  void* choose_cpu_impl(
+    void *DEFAULT
+#ifdef HAVE_AVX512_CPU_DEFINITION
+    , void *AVX512
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+    , void *AVX2
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+    , void *VSX
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+    , void *ZVECTOR
+#endif
+  );
+
+  // Fixing dispatch error in Windows debug builds.
+  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
+  #if defined(_MSC_VER) && defined(_DEBUG)
+    std::atomic<void*> cpu_dispatch_ptr;
+    void* cuda_dispatch_ptr;
+    void* hip_dispatch_ptr;
+    void* mps_dispatch_ptr;
+    void* privateuse1_dispatch_ptr;
+  #else
+    std::atomic<void*> cpu_dispatch_ptr{nullptr};
+    void* cuda_dispatch_ptr = nullptr;
+    void* hip_dispatch_ptr = nullptr;
+    void* mps_dispatch_ptr = nullptr;
+    void* privateuse1_dispatch_ptr = nullptr;
+  #endif
+};
+
+template <typename rT, typename T, typename... Args>
+struct DispatchStub<rT (*) (Args...), T> {
+  using FnPtr = rT (*) (Args...);
+
+  DispatchStub() = default;
+  DispatchStub(const DispatchStub&) = delete;
+  DispatchStub& operator=(const DispatchStub&) = delete;
+
+private:
+  FnPtr get_call_ptr(c10::DeviceType device_type) {
+    return reinterpret_cast<FnPtr>(
+      impl.get_call_ptr(device_type
+      , reinterpret_cast<void*>(DEFAULT)
+#ifdef HAVE_AVX512_CPU_DEFINITION
+      , reinterpret_cast<void*>(AVX512)
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+      , reinterpret_cast<void*>(AVX2)
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+      , reinterpret_cast<void*>(VSX)
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+      , reinterpret_cast<void*>(ZVECTOR)
+#endif
+      )
+    );
+  }
+
+public:
+  template <typename... ArgTypes>
+  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
+    FnPtr call_ptr = get_call_ptr(device_type);
+    return (*call_ptr)(std::forward<ArgTypes>(args)...);
+  }
+
+  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
+    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
+  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
+    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
+  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
+    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
+  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
+    impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
+  }
+
+  static TORCH_API FnPtr DEFAULT;
+#ifdef HAVE_AVX512_CPU_DEFINITION
+  static TORCH_API FnPtr AVX512;
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+  static TORCH_API FnPtr AVX2;
+#endif
+#ifdef HAVE_VSX_CPU_DEFINITION
+  static TORCH_API FnPtr VSX;
+#endif
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+  static TORCH_API FnPtr ZVECTOR;
+#endif
+private:
+  DispatchStubImpl impl;
+};
+
+namespace {
+template <typename DispatchStub>
+struct RegisterCUDADispatch {
+  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_cuda_dispatch_ptr(value);
+  }
+};
+
+template <typename DispatchStub>
+struct RegisterMPSDispatch {
+  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_mps_dispatch_ptr(value);
+  }
+};
+
+template <typename DispatchStub>
+struct RegisterHIPDispatch {
+  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    // TODO: make this point at hip_dispatch_ptr
+    stub.set_cuda_dispatch_ptr(value);
+  }
+};
+
+template <typename DispatchStub>
+struct RegisterPRIVATEUSE1Dispatch {
+  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
+    stub.set_privateuse1_dispatch_ptr(value);
+  }
+};
+
+} // anonymous namespace
+// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
+// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
+// adding parentheses and using helper struct to get rid of the parentheses, do
+// not work with MSVC. So do a `using`-declaration if you need to pass in such
+// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
+#define DECLARE_DISPATCH(fn, name)         \
+  struct name : DispatchStub<fn, name> {   \
+    name() = default;                      \
+    name(const name&) = delete;            \
+    name& operator=(const name&) = delete; \
+  };                                       \
+  extern TORCH_API struct name name
+
+#define DEFINE_DISPATCH(name) struct name name
+
+#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
+  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;
+
+#ifdef HAVE_AVX512_CPU_DEFINITION
+#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
+#else
+#define REGISTER_AVX512_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_AVX2_CPU_DEFINITION
+#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
+#else
+#define REGISTER_AVX2_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_VSX_CPU_DEFINITION
+#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
+#else
+#define REGISTER_VSX_DISPATCH(name, fn)
+#endif
+
+#ifdef HAVE_ZVECTOR_CPU_DEFINITION
+#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
+#else
+#define REGISTER_ZVECTOR_DISPATCH(name, fn)
+#endif
+
+// Macro to register the same kernel for all CPU arch types. This is useful
+// if a kernel does not benefit from being recompiled across different arch types.
+#define REGISTER_ALL_CPU_DISPATCH(name, fn)                                    \
+  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn)                                    \
+  REGISTER_AVX512_DISPATCH(name, fn)                                           \
+  REGISTER_AVX2_DISPATCH(name, fn)                                             \
+  REGISTER_VSX_DISPATCH(name, fn)                                              \
+  REGISTER_ZVECTOR_DISPATCH(name, fn)
+
+#define REGISTER_NO_CPU_DISPATCH(name)                                         \
+  REGISTER_ALL_CPU_DISPATCH(name, nullptr)
+
+#define REGISTER_CUDA_DISPATCH(name, fn) \
+  static RegisterCUDADispatch<struct name> name ## __register(name, fn);
+
+#define REGISTER_HIP_DISPATCH(name, fn) \
+  static RegisterHIPDispatch<struct name> name ## __register(name, fn);
+
+#define REGISTER_MPS_DISPATCH(name, fn) \
+  static RegisterMPSDispatch<struct name> name ## __register(name, fn);
+
+#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
+  static RegisterPRIVATEUSE1Dispatch<struct name> name ## __register(name, fn);
+
+// NB: This macro must be used in an actual 'cu' file; if you try using
+// it from a 'cpp' file it will not work!
+#if defined(__CUDACC__)
+#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+#elif defined(__HIPCC__)
+// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
+// is HIP in the PyTorch HIPify build.
+#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
+#elif defined(__OBJC__) && defined(USE_MPS)
+// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
+#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
+#elif defined(CPU_CAPABILITY)
+// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
+// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
+#ifdef CPU_CAPABILITY_AVX512
+#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, nullptr)
+#else
+#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#endif
+#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
+#endif
+} // namespace at::native
+
+C10_CLANG_DIAGNOSTIC_POP()
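+
+// Illustrative end-to-end sketch (an assumption, not part of the original header),
+// using hypothetical names my_fn / my_stub / my_kernel to show how the pieces above
+// fit together:
+//
+//   // MyOp.h
+//   using my_fn = void (*)(at::TensorIterator&);
+//   DECLARE_DISPATCH(my_fn, my_stub);
+//
+//   // MyOp.cpp
+//   DEFINE_DISPATCH(my_stub);
+//
+//   // cpu/MyOpKernel.cpp, compiled once per CPU_CAPABILITY flavour
+//   namespace { void my_kernel(at::TensorIterator& iter) { /* ... */ } }
+//   REGISTER_DISPATCH(my_stub, &my_kernel);
+//
+//   // call site: resolves to the fastest registered kernel for the device
+//   my_stub(at::kCPU, iter);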
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Distance.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Distance.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8f02379d29e911bddbb671a6de5a238aad0b9d9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Distance.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
+using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
+using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
+
+DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
+DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
+DECLARE_DISPATCH(cdist_fn, cdist_stub);
+DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h b/MLPY/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f194014c496354ee4d324d5a01561717837e9fd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/DistributionTemplates.h
@@ -0,0 +1,394 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#include 
+#include 
+#endif
+
+namespace at::native::templates {
+
+// ==================================================== Random ========================================================
+
+// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`.
+// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t).
+// This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
+//
+//    auto actual = torch::empty({3, 3}, torch::half);
+//    actual.random_(0, 65504);
+//
+// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
+// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
+// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
+// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
+// available number for torch::half dtype.
+template<typename scalar_t>
+int64_t update_from(int64_t from) {
+  static_assert(
+    std::is_floating_point<scalar_t>::value ||
+    std::is_same<scalar_t, at::Half>::value ||
+    std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
+  const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
+  if (from_plus_1 < from) {
+    int64_t from_ = std::abs(from + 1);
+    int n = 0;
+    while (from_ >>= 1) ++n;
+    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
+    from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
+  }
+  return from;
+}
+
+template<typename scalar_t>
+int64_t update_to(int64_t to) {
+  static_assert(
+    std::is_floating_point<scalar_t>::value ||
+    std::is_same<scalar_t, at::Half>::value ||
+    std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
+  const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
+  if (to_minus_1 >= to) {
+    int64_t to_ = std::abs(to - 1);
+    int n = 0;
+    while (to_ >>= 1) ++n;
+    // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
+    to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
+  }
+  return to;
+}
+
+// Return earlier for not invoking kernel.
+// See https://github.com/pytorch/pytorch/issues/103418 for more details
+#define CHECK_EMPTY_AND_RETURN(tensor) \
+  if (tensor.numel() == 0) {  \
+    return tensor;  \
+  }
+
+template<template<typename> class random_kernel, typename RNG>
+at::Tensor& random_impl(at::Tensor& self, c10::optional<Generator> generator) {
+  CHECK_EMPTY_AND_RETURN(self);
+  auto iter = at::TensorIterator::borrowing_nullary_op(self);
+  random_kernel<RNG>()(iter, generator);
+  return self;
+}
+
+#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
+  TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \
+
+#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
+  if (var < -(1LL << digits) || var > (1LL << digits)) { \
+    TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
+      "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
+      "This warning will become an error in version 1.7 release, please fix the code in advance"); \
+  }
+
+static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
+  const auto scalar_type = typeMetaToScalarType(dtype);
+  if (isFloatingType(scalar_type)) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
+      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
+      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
+      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
+      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
+
+      constexpr auto digits = std::numeric_limits<scalar_t>::digits;
+      WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
+      WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
+    });
+  } else if (scalar_type == kUInt64) {
+    // When you do a comparison between int64_t and uint64_t, the usual
+    // arithmetic conversions say that the int64_t value is promoted to
+    // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
+    // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
+    // the right thing to do.
+    CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
+    CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
+  } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
+    AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
+      const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
+      const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
+      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
+      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
+    }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
+  } else {
+    TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
+  }
+}
+
+template<template<typename> class random_from_to_kernel, typename RNG>
+at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional<int64_t> to_opt, c10::optional<Generator> generator) {
+  uint64_t range = 0;
+  auto iter = at::TensorIterator::borrowing_nullary_op(self);
+  if (to_opt.has_value()) {
+    // [from, to)
+    int64_t to = *to_opt;
+    TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
+    if (isFloatingType(iter.dtype())) {
+      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
+        from = update_from<scalar_t>(from);
+        to = update_to<scalar_t>(to);
+        TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
+      });
+    }
+    check_from_to_in_range(from, to - 1, self.dtype());
+    CHECK_EMPTY_AND_RETURN(self);
+    range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
+    random_from_to_kernel<RNG>()(iter, range, from, generator);
+  } else if (from != std::numeric_limits<int64_t>::lowest()) {
+    // [from, std::numeric_limits<int64_t>::max()]
+    int64_t to_inc = 0;
+    if (isFloatingType(iter.dtype())) {
+      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
+        constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
+        to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
+        from = update_from<scalar_t>(from);
+        TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
+      });
+    } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
+      AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
+        if constexpr (std::is_same_v<scalar_t, bool>) {
+          to_inc = static_cast<int64_t>(true);
+        } else {
+          to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
+        }
+      }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
+    } else {
+      TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
+    }
+    check_from_to_in_range(from, to_inc, self.dtype());
+    CHECK_EMPTY_AND_RETURN(self);
+    range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
+    random_from_to_kernel<RNG>()(iter, range, from, generator);
+  } else {
+    // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
+    // range = 2^64
+    CHECK_EMPTY_AND_RETURN(self);
+    random_from_to_kernel<RNG>()(iter, generator);
+  }
+  return self;
+}
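+
+// Illustrative sketch (an assumption, not part of the original header) of which of
+// the three branches above a few call patterns reach via the usual random_ bindings:
+//   self.random_(0, 10)              -> branch 1: range = 10, samples in [0, 10)
+//   self.random_(3, c10::nullopt)    -> branch 2: samples in [3, dtype max]
+//   int64 self, from = INT64_MIN,
+//   to = c10::nullopt                -> branch 3: the full 2^64-wide range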
+
+// ==================================================== Normal ========================================================
+
+#define CHECK_NORMAL_TENSOR_STD(std) \
+  do { \
+    TORCH_CHECK( \
+      !std.is_complex(), \
+      "normal expects standard deviation to be non-complex"); \
+    TORCH_CHECK( \
+      std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
+      "normal expects all elements of std >= 0.0"); \
+  } while (0)
+
+#define CHECK_NORMAL_STD(std) \
+  TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor& normal_impl_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_STD(std);
+  CHECK_EMPTY_AND_RETURN(self);
+
+  if (self.is_complex()) {
+    auto float_tensor = at::view_as_real(self);
+    // variance for normal distribution of the real and imaginary values
+    // is half of the input variance
+    normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
+  } else {
+    normal_kernel<RNG>()(self, mean, std, gen);
+  }
+  return self;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_STD(std);
+  auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
+  auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
+  at::native::resize_output(output, shape);
+  normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
+  output.add_(mean);
+  return output;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_TENSOR_STD(std);
+  auto mean_tensor = at::full({}, mean, output.options());
+  auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
+  at::native::resize_output(output, shape);
+  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
+  // CUDA NB: addcmul_out copies the tensor to be added into the output.
+  // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
+  // The third argument is not a constant reference and hence the samples in output are overwritten.
+  // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
+  output.mul_(std).add_(mean_tensor);
+  return output;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_TENSOR_STD(std);
+  auto shape = at::infer_size(mean.sizes(), std.sizes());
+  at::native::resize_output(output, shape);
+  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
+  // CUDA NB: addcmul_out copies the tensor to be added into the output.
+  // The previous function here was addcmul_out(output, mean, output, std, 1);
+  // The third argument is not a constant reference and hence the samples in output are overwritten.
+  // Consequently, the computation performed is mean + mean * std instead of mean + output * std
+  output.mul_(std).add_(mean);
+  return output;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor normal_impl(const Tensor& mean, double std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_STD(std);
+  Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
+  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
+  return ret;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor normal_impl(double mean, const Tensor& std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_TENSOR_STD(std);
+  Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
+  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
+  return ret;
+}
+
+template<template<typename> class normal_kernel, typename RNG>
+Tensor normal_impl(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
+  CHECK_NORMAL_TENSOR_STD(std);
+  auto shape = at::infer_size(mean.sizes(), std.sizes());
+  Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
+  normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
+  return ret;
+}
+
+// ==================================================== Uniform =======================================================
+
+template<template<typename> class uniform_kernel, typename RNG>
+at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optional<Generator> generator) {
+  if (self.is_complex()) {
+    CHECK_EMPTY_AND_RETURN(self);
+    auto float_tensor = at::view_as_real(self);
+    uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
+      const auto dtype = self.dtype();
+      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
+      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
+      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
+      CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
+      TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
+      TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
+            "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
+            ">::max(), but found to=", to, " and from=", from,
+            " which result in to-from to exceed the limit");
+      from = std::min(std::max(from, min), max);
+      to = std::max(std::min(to, max), min);
+    });
+    CHECK_EMPTY_AND_RETURN(self);
+    auto iter = at::TensorIterator::borrowing_nullary_op(self);
+    uniform_kernel<RNG>()(iter, from, to, generator);
+  }
+  return self;
+}
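+
+// Illustrative example (not part of the original header): for a Half tensor the
+// bounds computed above are +/-65504, so
+//   t.uniform_(0.0, 1.0);            // accepted
+//   t.uniform_(-40000.0, 40000.0);   // rejected: to - from = 80000 > 65504
+//   t.uniform_(0.0, 1e9);            // rejected: 'to' is out of bounds for Half
+// The final clamp only guards against rounding at the very ends of a valid range.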
+
+// ================================================== LogNormal =======================================================
+
+template<template<typename> class log_normal_kernel, typename RNG>
+at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::optional<Generator> gen) {
+  TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
+  CHECK_EMPTY_AND_RETURN(self);
+  auto iter = TensorIterator::borrowing_nullary_op(self);
+  log_normal_kernel<RNG>()(iter, mean, std, gen);
+  return self;
+}
+
+// =================================================== Geometric ======================================================
+
+template<template<typename> class geometric_kernel, typename RNG>
+Tensor& geometric_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
+  TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
+  CHECK_EMPTY_AND_RETURN(self);
+  auto iter = TensorIterator::borrowing_nullary_op(self);
+  geometric_kernel<RNG>()(iter, p, gen);
+  return self;
+}
+
+// ================================================== Exponential =====================================================
+
+template<template<typename> class exponential_kernel, typename RNG>
+Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional<Generator> gen) {
+  TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
+  CHECK_EMPTY_AND_RETURN(self);
+  auto iter = TensorIterator::borrowing_nullary_op(self);
+  exponential_kernel<RNG>()(iter, lambda, gen);
+  return self;
+}
+
+// ==================================================== Cauchy ========================================================
+
+template<template<typename> class cauchy_kernel, typename RNG>
+Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional<Generator> gen) {
+  // TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
+  // the variance, squared sigma, is undefined for cauchy distribution
+  TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
+  TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
+  CHECK_EMPTY_AND_RETURN(self);
+  auto iter = TensorIterator::borrowing_nullary_op(self);
+  cauchy_kernel<RNG>()(iter, median, sigma, gen);
+  return self;
+}
+
+// ==================================================== Bernoulli =====================================================
+
+template<template<typename> class bernoulli_tensor_kernel, typename RNG>
+Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
+  CHECK_EMPTY_AND_RETURN(self);
+  NoNamesGuard guard;
+  at::assert_no_internal_overlap(self);
+  bernoulli_tensor_kernel<RNG>()(self, p_, gen);
+  return self;
+}
+
+template<template<typename> class bernoulli_scalar_kernel, typename RNG>
+Tensor& bernoulli_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
+  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
+  CHECK_EMPTY_AND_RETURN(self);
+  at::assert_no_internal_overlap(self);
+  bernoulli_scalar_kernel<RNG>()(self, p, gen);
+  return self;
+}
+
+template<template<typename> class bernoulli_tensor_kernel, typename RNG>
+Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, c10::optional<Generator> gen) {
+  // result.resize_as_(self) requires self to have same dtype as result, so we
+  // use resize_ instead.
+  // TODO: Fix resize_as_. See pytorch/pytorch#11665.
+  result.resize_(self.sizes());
+  bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
+  namedinference::propagate_names(result, self);
+  return result;
+}
+
+#undef CHECK_OUT_OF_BOUNDS
+#undef WARN_OUT_OF_BOUNDS
+
+} // namespace at::native::templates
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Distributions.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Distributions.h
new file mode 100644
index 0000000000000000000000000000000000000000..637dd73b6ba62835de12ec20cad521fc60dfcabc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Distributions.h
@@ -0,0 +1,518 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+// ROCM hcc doesn't work well with using std:: in kernel functions
+#if defined(__CUDA_ARCH__)
+#include 
+#define compat_exp c10::cuda::compat::exp
+#define compat_ceil c10::cuda::compat::ceil
+#define compat_floor c10::cuda::compat::floor
+#define compat_log c10::cuda::compat::log
+#define compat_pow c10::cuda::compat::pow
+#define compat_sqrt c10::cuda::compat::sqrt
+#define compat_tan c10::cuda::compat::tan
+#define compat_abs c10::cuda::compat::abs
+#define compat_log1p c10::cuda::compat::log1p
+#elif defined(__HIPCC__)
+#include 
+#define compat_exp c10::hip::compat::exp
+#define compat_ceil c10::hip::compat::ceil
+#define compat_floor c10::hip::compat::floor
+#define compat_log c10::hip::compat::log
+#define compat_pow c10::hip::compat::pow
+#define compat_sqrt c10::hip::compat::sqrt
+#define compat_tan c10::hip::compat::tan
+#define compat_abs c10::hip::compat::abs
+#define compat_log1p c10::hip::compat::log1p
+#else
+#define compat_exp std::exp
+#define compat_ceil std::ceil
+#define compat_floor std::floor
+#define compat_log std::log
+#define compat_pow std::pow
+#define compat_sqrt std::sqrt
+#define compat_tan std::tan
+#define compat_abs std::abs
+#define compat_log1p std::log1p
+#endif
+
+namespace {
+
+#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
+// we cannot use std::isnan directly due to some incompatibility of
+// gcc constexpr'ing and nvcc
+using std::isnan;
+#endif
+
+// Here sampler_t should be function type scalar_t(void). For gpu
+// "sampler" is a device function, but since ROCM doesn't have
+// equivalent to nvstd::function, we use a template type parameter to
+// capture it.
+template<typename scalar_t, typename sampler_t>
+struct BaseSampler {
+  sampler_t sampler;
+  C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
+  C10_DEVICE scalar_t sample() {
+    return sampler();
+  }
+};
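+
+// Illustrative sketch (an assumption, not part of the original header): on the CPU
+// path sampler_t is typically a lambda that captures the generator, e.g.
+//
+//   auto uniform_lambda = [generator] () {
+//     at::uniform_real_distribution<double> uniform(0, 1);
+//     return uniform(generator);
+//   };
+//   BaseSampler<double, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
+//
+// which can then be passed to sample_gamma / sample_binomial below.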
+
+// The function `sample_gamma` is
+// is adapted from Numpy's distributions.c implementation.
+// It is MIT licensed, so here is the copyright:
+
+/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+*/
+
+template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
+C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
+  accscalar_t scale = 1.0f;
+
+  // Boost alpha for higher acceptance probability.
+  if (alpha < 1.0f) {
+    if (alpha == 0.f) return 0.f;
+    scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
+    alpha += 1.0f;
+  }
+
+  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
+  // doi:10.1145/358407.358414
+  const accscalar_t d = alpha - 1.0f / 3.0f;
+  const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
+  for (;;) {
+    accscalar_t x, y;
+    do {
+      x = standard_normal.sample();
+      y = 1.0f + c * x;
+    } while (y <= 0);
+    const accscalar_t v = y * y * y;
+    const accscalar_t u = 1 - standard_uniform.sample();
+    const accscalar_t xx = x * x;
+    if (u < 1.0f - 0.0331f * xx * xx)
+      return static_cast(scale * d * v);
+    if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
+      return static_cast(scale * d * v);
+  }
+}
+
+/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
+ * from TensorFlow's random_binomial_op.cc implementation. That code is under
+ * copyright: 2019 The TensorFlow Authors.
+ *
+ * It was released under the Apache License, Version 2.0 (the "License"), available at:
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+template<typename scalar_t>
+C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
+  const static scalar_t kTailValues[] = {
+    0.0810614667953272,
+    0.0413406959554092,
+    0.0276779256849983,
+    0.02079067210376509,
+    0.0166446911898211,
+    0.0138761288230707,
+    0.0118967099458917,
+    0.0104112652619720,
+    0.00925546218271273,
+    0.00833056343336287
+  };
+  if (k <= 9) {
+    return kTailValues[static_cast(k)];
+  }
+  scalar_t kp1sq = (k + 1) * (k + 1);
+  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
+}
+
+
+template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
+C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
+  accscalar_t U;
+  accscalar_t geom_sum = 0;
+  scalar_t num_geom = 0;
+
+  accscalar_t logprob = compat_log1p(-prob);
+
+  while (1) {
+    U = standard_uniform.sample();
+    accscalar_t geom = compat_ceil(compat_log(U) / logprob);
+    geom_sum += geom;
+    if (geom_sum > count) {
+      break;
+    }
+    num_geom = num_geom + 1;
+  }
+  return num_geom;
+}
+
+template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
+C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
+  scalar_t k;
+  accscalar_t U, V, us;
+
+  // This is spq in the paper.
+  const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
+
+  // Other coefficients for Transformed Rejection sampling.
+  const accscalar_t b = 1.15 + 2.53 * stddev;
+  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
+  const accscalar_t c = count * prob + 0.5;
+  const accscalar_t v_r = 0.92 - 4.2 / b;
+  const accscalar_t r = prob / (1 - prob);
+
+  const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
+  const accscalar_t m = compat_floor((count + 1) * prob);
+
+  while (1) {
+    U = standard_uniform.sample() - 0.5;
+    V = standard_uniform.sample();
+
+    us = 0.5 - compat_abs(U);
+    k = static_cast(compat_floor((2 * a / us + b) * U + c));
+
+    // Reject non-sensical answers.
+    if (k < 0 || k > count) {
+      continue;
+    }
+    // Region for which the box is tight, and we can return our calculated value.
+    // This should happen 0.86 * v_r times. In the limit as n * p is large,
+    // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
+    if (us >= 0.07 && V <= v_r) {
+      return k;
+    }
+
+    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
+    // For all (u, v) pairs outside of the bounding box, this calculates the
+    // transformed-reject ratio.
+    V = compat_log(V * alpha / (a / (us * us) + b));
+    accscalar_t upperbound =
+        ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
+         (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
+         (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
+         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
+         stirling_approx_tail(k) - stirling_approx_tail(count - k));
+
+    if (V <= upperbound) {
+      return k;
+    }
+  }
+}
+
+template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
+C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
+  if (count <= 0.0 || prob <= 0.0) {
+    return 0;
+  } else if (prob >= 1.0) {
+    return count;
+  } else if (prob <= 0.5) {
+    if (count * prob >= 10.0) {
+      // btrs
+      return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
+    } else {
+      // binomial inversion
+      return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
+    }
+  } else if (prob > 0.5) {
+    scalar_t qprob = 1.0 - prob;
+    if (count * qprob >= 10.0) {
+      // btrs
+      return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
+    } else {
+      // count - binomial inversion
+      return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
+    }
+  } else {
+    // prob is nan?
+    return static_cast(NAN);
+  }
+}
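+
+// Illustrative example (not part of the original header) of the dispatch above:
+//   sample_binomial(100, 0.3,  u)  -> 100 * 0.3 = 30 >= 10   -> btrs
+//   sample_binomial(100, 0.02, u)  -> 100 * 0.02 = 2 < 10    -> inversion
+//   sample_binomial(100, 0.9,  u)  -> qprob = 0.1, 100 * 0.1 >= 10
+//                                     -> 100 - btrs(100, 0.1, u)
+// so the BTRS rejection sampler is only used when the mean n*p is large enough
+// for its acceptance bound to be tight.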
+
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
+ */
+template <typename scalar_t, typename accscalar_t>
+C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
+  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
+  if (x == 0) {
+    return INFINITY;
+  }
+  accscalar_t additional_summand = 0;
+  int x_is_integer = x == compat_floor(x);
+  if (x < 0) {
+    if (x_is_integer) {
+      return INFINITY;
+    }
+    // it is more standard to write this as recursion, but
+    // nvcc does not like that
+    additional_summand = -c10::pi /
+        compat_tan(c10::pi * x);
+    x = 1 - x;
+  }
+
+  // Push x to be >= 10
+  accscalar_t result = 0;
+  while (x < 10) {
+    result -= 1 / x;
+    x += 1;
+  }
+  if (x == 10) {
+    return result + PSI_10 + additional_summand;
+  }
+
+  // Compute asymptotic digamma
+  static const accscalar_t A[] = {
+     8.33333333333333333333E-2,
+    -2.10927960927960927961E-2,
+     7.57575757575757575758E-3,
+    -4.16666666666666666667E-3,
+     3.96825396825396825397E-3,
+    -8.33333333333333333333E-3,
+     8.33333333333333333333E-2,
+  };
+
+  accscalar_t y = 0;
+  if (x < 1.0e17f) {
+    accscalar_t z = 1.0 / (x * x);
+    y = z * polevl(z, A, 6);
+  }
+  return static_cast(
+      result + compat_log(x) - (0.5f / x) - y + additional_summand);
+}
+
+// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
+// for random number x drawn from a standard Gamma distribution Gamma(alpha).
+template <typename scalar_t, typename accscalar_t>
+C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
+  // Use a Taylor series expansion for small x.
+  accscalar_t x = static_cast(x_);
+  accscalar_t alpha = static_cast(alpha_);
+  if (x < 0.8f) {
+    accscalar_t numer = 1;
+    accscalar_t denom = alpha;
+    auto series1 = numer / denom;
+    auto series2 = numer / (denom * denom);
+    for (int i = 1; i <= 5; ++i) {
+      numer *= -x / static_cast(i);
+      denom += 1;
+      series1 += numer / denom;
+      series2 += numer / (denom * denom);
+    }
+    const auto pow_x_alpha = compat_pow(x, alpha);
+    const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
+    const auto gamma_cdf = pow_x_alpha * series1;
+    const auto gamma_cdf_alpha =
+        (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
+            gamma_cdf -
+        pow_x_alpha * series2;
+    const auto result = -gamma_cdf_alpha / gamma_pdf;
+    return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
+  }
+
+  // Use a Rice saddle point expansion for large alpha.
+  if (alpha > 8.0f) {
+    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
+      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
+      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
+          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
+      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
+      return static_cast<scalar_t>(numer_1 * numer_2 / denom);
+    }
+    const auto denom = compat_sqrt(8 * alpha);
+    const auto term2 = denom / (alpha - x);
+    const auto term3 = compat_pow(
+        x - alpha - alpha * compat_log(x / alpha),
+        static_cast<accscalar_t>(-1.5));
+    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
+    const auto term1 = compat_log(x / alpha) * term23 -
+        compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
+    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
+    const auto numer = x * term1;
+    return static_cast<scalar_t>(-stirling * numer / denom);
+  }
+
+  // Use a bivariate rational approximation to the reparameterized gradient.
+  const auto u = compat_log(x / alpha);
+  const auto v = compat_log(alpha);
+  static const accscalar_t coef_uv[3][8] = {
+    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
+     1, 0.32668115, 0.10406089, 0.0014179084},
+    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
+     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
+    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
+     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
+  };
+  accscalar_t coef_v[8];
+  for (int i = 0; i < 8; ++ i) {
+    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
+  }
+  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
+  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
+  return static_cast<scalar_t>(compat_exp(p / q));
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
+// Assumes x is close to zero and uses a Taylor expansion.
+template <typename scalar_t, typename accscalar_t>
+C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
+  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
+                        - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
+  scalar_t numer = 1;
+  scalar_t series = numer / alpha * (factor + 1 / alpha);
+  for (int i = 1; i <= 10; ++i) {
+    scalar_t casted_i = static_cast<scalar_t>(i);
+    numer *= (casted_i - beta) * x / casted_i;
+    const scalar_t denom = alpha + casted_i;
+    series += numer / denom * (factor + 1 / denom);
+  }
+  const scalar_t result = x * compat_pow(1 - x, -beta) * series;
+  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
+// Assumes x is close to zero and uses a Taylor expansion.
+template <typename scalar_t, typename accscalar_t>
+C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
+  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
+  scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
+  for (int i = 1; i <= 8; ++i) {
+    scalar_t casted_i = static_cast<scalar_t>(i);
+    numer *= -x / casted_i;
+    dbetas = dbetas * (beta - casted_i) + betas;
+    betas = betas * (beta - casted_i);
+    series += numer / (alpha + casted_i) * (dbetas + factor * betas);
+  }
+  const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
+  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
+}
+
+// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
+// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
+// To ensure numerical stability, this computation is performed at higher precision.
+template <typename scalar_t, typename accscalar_t>
+C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
+  const accscalar_t total = alpha + beta;
+  const accscalar_t mean = alpha / total;
+  const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
+  if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
+    // Avoid the singularity at x = mean.
+    const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
+                           (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
+                           3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
+                           (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
+                           8 * (1 - x) * (135 * beta - 11)))));
+    const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
+    const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
+    return prefactor_num / (1 - x) * poly / prefactor_den;
+  }
+  const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
+  const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
+                             * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
+                             / (1 + 1 / (12 * total) + 1 / (288 * total * total));
+  const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
+  const accscalar_t axbx = alpha * (x - 1) + beta * x;
+  const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast(1.5f)) * axbx * axbx;
+  const accscalar_t term1 = term1_num / term1_den;
+  const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
+  const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
+  const accscalar_t term3_den = beta * x + alpha * (x - 1);
+  const accscalar_t term3 = term3_num / term3_den;
+  const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
+                               alpha * compat_log(alpha / (total * x));
+  const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
+  const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
+  return static_cast<scalar_t>(stirling * prefactor * term1234);
+}
+
+// Computes a scaled reparameterized gradient
+//   -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
+// for random number x drawn from a Beta distribution Beta(alpha,beta).
+// This function inputs total=alpha+beta to make it easy to implement
+// Dirichlet reparameterized gradients in terms of Betas.
+template <typename scalar_t, typename accscalar_t>
+C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
+  accscalar_t x_ = static_cast<accscalar_t>(x);
+  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
+  accscalar_t total_ = static_cast<accscalar_t>(total);
+
+  const scalar_t beta = total - alpha;
+  const accscalar_t beta_ = total_ - alpha_;
+  const scalar_t boundary = total * x * (1 - x);
+
+  // Use an asymptotic approximation for x close to 0.
+  if (x <= 0.5f && boundary < 2.5f) {
+    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
+  }
+
+  // Use an asymptotic approximation for x close to 1.
+  if (x >= 0.5f && boundary < 0.75f) {
+    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
+  }
+
+  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
+  if (alpha > 6 && beta > 6) {
+    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
+  }
+
+  // Use a rational correction to an analytic approximation.
+  static const accscalar_t c[2][3][3][4] = {
+    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
+      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
+      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
+     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
+      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
+      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
+     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
+      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
+      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
+    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
+      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
+      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
+     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
+      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
+      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
+     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
+      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
+      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
+  };
+  const accscalar_t u = compat_log(x_);
+  const accscalar_t a = compat_log(alpha_) - u;
+  const accscalar_t b = compat_log(total_) - a;
+  const accscalar_t pow_u[3] = {1, u, u * u};
+  const accscalar_t pow_a[3] = {1, a, a * a};
+  accscalar_t p = 0.0;
+  accscalar_t q = 0.0;
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      const accscalar_t ua = pow_u[i] * pow_a[j];
+      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
+      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
+    }
+  }
+  const accscalar_t approx = x_ * (digamma_one<accscalar_t, accscalar_t>(total_) - digamma_one<accscalar_t, accscalar_t>(alpha_)) / beta_;
+  return static_cast<scalar_t>(p / q * approx);
+}
+
+} // namespace
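The sampler above switches between binomial inversion for a small mean and the BTRS transformed-rejection scheme for a large one, and handles prob > 0.5 through the symmetry Binomial(n, p) = n - Binomial(n, 1 - p). A minimal host-side sketch of just that dispatch policy, using the standard library's binomial sampler in place of the device kernels (the helper name and the printed strategy strings are illustrative, only the thresholds mirror the header):

#include <cstdio>
#include <random>

// Illustrative host-side dispatch mirroring sample_binomial's policy:
// small mean -> inversion-style sampling, large mean -> rejection (BTRS),
// p > 0.5 handled through the n - Binomial(n, 1 - p) symmetry.
static const char* pick_binomial_strategy(double count, double prob) {
  if (count <= 0.0 || prob <= 0.0) return "degenerate: always 0";
  if (prob >= 1.0)                 return "degenerate: always count";
  const double p = (prob <= 0.5) ? prob : 1.0 - prob;   // symmetry trick
  return (count * p >= 10.0) ? "btrs (rejection)" : "binomial inversion";
}

int main() {
  std::mt19937 gen(0);
  for (double n : {5.0, 50.0, 500.0}) {
    for (double p : {0.02, 0.3, 0.9}) {
      std::binomial_distribution<long> dist(static_cast<long>(n), p);
      std::printf("n=%5.0f p=%.2f -> %-22s sample=%ld\n",
                  n, p, pick_binomial_strategy(n, p), dist(gen));
    }
  }
}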
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h b/MLPY/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa927d7831af53976c46babe7b7b1c45f392e90a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/EmbeddingBag.h
@@ -0,0 +1,139 @@
+#include 
+#include 
+#include 
+
+#ifdef USE_FBGEMM
+#include 
+#endif
+
+namespace at::native {
+
+void check_arguments(
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const int64_t mode,
+    const c10::optional<Tensor>& per_sample_weights,
+    bool include_last_offset);
+
+void make_bag_size_out(
+    Tensor& bag_size_out,
+    const Tensor& offsets,
+    const Tensor& indices,
+    const int64_t mode,
+    const bool include_last_offset,
+    const bool requires_grad);
+
+void make_max_indices_out(
+    Tensor& max_indices_out,
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const Tensor& bag_size,
+    const int64_t mode,
+    bool include_last_offset);
+
+void make_offset2bag_out(
+    Tensor& offset2bag,
+    Tensor& output,
+    const Tensor& weight,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const int64_t mode,
+    const c10::optional<Tensor>& per_sample_weights,
+    const int64_t padding_idx = -1);
+
+#ifdef USE_FBGEMM
+
+template
+struct _CallbackAndBlockSize {
+    using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature::Type;
+
+    int64_t blockSize = -1;
+    TCallback callback = nullptr;
+
+    static TCallback generateCallback(int64_t block_size) {
+        return fbgemm::GenerateEmbeddingSpMDM(
+                block_size,
+                has_weight,
+                /* normalize_by_lengths */false,
+                /* prefetch */16,
+                /* is_weight_positional */false,
+                /* use_offsets */true);
+    }
+
+    _CallbackAndBlockSize() = default;
+
+    explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
+      : blockSize(maybe_block_size.value_or(-1))
+      , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
+    {}
+};
+
+template<typename... StorageMixins>
+struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
+
+    _EmbeddingBagKernelCacheImpl() = default;
+    // use each of the mixins to store corresponding kernel and block size
+    explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
+      : StorageMixins(maybe_block_size)...
+    {}
+
+    // this method is thread safe (call sites may call from different threads)
+    template
+    typename _CallbackAndBlockSize::TCallback
+    getCallback(int64_t block_size) const {
+        // if the cache doesn't store the kernel for the incoming block size
+        // (so it is different from the one stored in corresponding mixin)
+        // regenerate the kernel (not writing it into the cache so we avoid locks)
+        if (block_size != _CallbackAndBlockSize::blockSize) {
+            return _CallbackAndBlockSize::generateCallback(block_size);
+        }
+        // else retrieve the cached kernel from the corresponding mixin
+        return _CallbackAndBlockSize::callback;
+    }
+};
+
+// instantiate the cache with the list of storage mixins
+// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
+using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize,
+    _CallbackAndBlockSize>;
+#else
+struct _EmbeddingBagKernelCache {
+    explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
+};
+#endif
+
+void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
+    Tensor& bag_size, Tensor* max_indices,
+    const Tensor &weight, const Tensor &indices,
+    const Tensor &offsets, const int64_t mode = 0,
+    const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
+    bool include_last_offset = false,
+    int64_t padding_idx = -1,
+    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
+
+void _embedding_bag_cpu_out(
+    at::Tensor& output,
+    at::Tensor& offset2bag,
+    at::Tensor& bag_size,
+    at::Tensor* p_max_indices,
+    const at::Tensor& weight,
+    const at::Tensor& indices,
+    const at::Tensor& offsets,
+    const bool scale_grad_by_freq,
+    const int64_t mode,
+    const bool sparse,
+    const c10::optional<Tensor>& per_sample_weights,
+    const bool include_last_offset,
+    const c10::optional<int64_t>& padding_idx,
+    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
+
+} // namespace at::native
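The mixin-based cache above keeps one pre-generated FBGEMM kernel per configuration and, when a caller asks for a block size different from the cached one, regenerates a kernel on the fly without writing it back, which is what keeps `getCallback` lock-free. A stripped-down sketch of that pattern with a stand-in callback type; `Kernel`, `make_kernel`, and the `Tag` parameter are illustrative, not FBGEMM's API:

#include <cstdio>
#include <functional>
#include <optional>

// Stand-in for a generated kernel: a callable tagged with its block size.
using Kernel = std::function<int(int)>;

static Kernel make_kernel(int block_size) {            // illustrative factory
  return [block_size](int x) { return block_size * x; };
}

// One mixin per cached configuration, mirroring _CallbackAndBlockSize.
template <int Tag>
struct CachedKernel {
  int block_size = -1;
  Kernel callback;
  CachedKernel() = default;
  explicit CachedKernel(std::optional<int> maybe_bs)
      : block_size(maybe_bs.value_or(-1)),
        callback(maybe_bs ? make_kernel(*maybe_bs) : Kernel{}) {}
};

template <typename... Mixins>
struct KernelCache : private Mixins... {
  KernelCache() = default;
  explicit KernelCache(std::optional<int> bs) : Mixins(bs)... {}

  // If the cached block size matches, reuse the kernel; otherwise regenerate
  // without touching the cache, so concurrent readers stay safe.
  template <typename Mixin>
  Kernel get(int block_size) const {
    const Mixin& slot = static_cast<const Mixin&>(*this);
    if (block_size != slot.block_size) return make_kernel(block_size);
    return slot.callback;
  }
};

int main() {
  KernelCache<CachedKernel<0>, CachedKernel<1>> cache(std::optional<int>(64));
  std::printf("cached hit : %d\n", cache.get<CachedKernel<0>>(64)(2));   // 128
  std::printf("regenerated: %d\n", cache.get<CachedKernel<1>>(128)(2));  // 256
}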
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Fill.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Fill.h
new file mode 100644
index 0000000000000000000000000000000000000000..66ae4b0a14f0f73a7ac4e1ecca5e28f45867f9ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Fill.h
@@ -0,0 +1,21 @@
+// Functions that fill Tensors with constants. Implementations are in Fill.cpp.
+
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+class Tensor;
+struct TensorIterator;
+
+namespace native {
+
+DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
+
+Tensor& fill_out(Tensor& self, const Scalar& value);
+
+}} // namespace at::native
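Fill.h follows ATen's dispatch-stub pattern: the header declares a typed function-pointer stub (`fill_stub`), and the CPU/CUDA translation units register their kernels into it. A minimal stand-in for that mechanism, not ATen's actual macros; the `FillStub` type, `Device` enum, and `fill_cpu` kernel are all illustrative:

#include <cstdio>

// Minimal stand-in for the DECLARE_DISPATCH / REGISTER_DISPATCH pattern:
// a header declares a typed function-pointer "stub", backend sources fill
// it in, and callers invoke through the stub at runtime.
enum class Device { CPU, CUDA };

struct FillStub {
  using fn_t = void (*)(double* data, int n, double value);
  fn_t table[2] = {nullptr, nullptr};            // one slot per backend
  void register_kernel(Device d, fn_t fn) { table[static_cast<int>(d)] = fn; }
  void operator()(Device d, double* data, int n, double value) const {
    table[static_cast<int>(d)](data, n, value);  // dispatch at call time
  }
};

static void fill_cpu(double* data, int n, double value) {
  for (int i = 0; i < n; ++i) data[i] = value;
}

int main() {
  FillStub fill_stub;                                  // declare the stub
  fill_stub.register_kernel(Device::CPU, &fill_cpu);   // register a backend
  double buf[4];
  fill_stub(Device::CPU, buf, 4, 3.5);
  std::printf("%g %g %g %g\n", buf[0], buf[1], buf[2], buf[3]);
}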
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..18df3a2a0e78be9909231832956f066e94aaf2c4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ForeachUtils.h
@@ -0,0 +1,371 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/result_type_native.h>
+#endif
+
+#include <unordered_map>
+#include <vector>
+
+namespace at::native {
+namespace {
+// Check if tensor list has either a boolean tensor or a integer tensor
+inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
+  return std::any_of(
+      tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
+        return at::isIntegralType(t.scalar_type(), includeBool);
+      });
+}
+// check if tensor list has bool tensors
+inline bool has_bool_tensor(TensorList tensors) {
+  return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
+    return t.scalar_type() == ScalarType::Bool;
+  });
+}
+
+// Check foreach API restrictions
+// - Tensor lists must be non-empty.
+// - All TensorLists and ScalarLists must have the same number of elements.
+// - Corresponding tensors must have the same size.
+inline void check_foreach_api_restrictions(TensorList tensors) {
+  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
+}
+
+inline void check_foreach_api_restrictions(
+    TensorList tensors,
+    ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors);
+  TORCH_CHECK(
+      tensors.size() == scalars.size(),
+      "Tensor list must have same number of elements as scalar list.");
+}
+
+inline void check_foreach_api_restrictions(
+    TensorList tensors1,
+    TensorList tensors2) {
+  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
+  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
+  TORCH_CHECK(
+      tensors1.size() == tensors2.size(),
+      "Tensor lists must have the same number of tensors, got ",
+      tensors1.size(),
+      " and ",
+      tensors2.size());
+}
+
+inline void check_foreach_api_restrictions(
+    TensorList tensors1,
+    TensorList tensors2,
+    TensorList tensors3) {
+  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
+  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
+  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
+  TORCH_CHECK(
+      tensors1.size() == tensors2.size(),
+      "Tensor lists must have the same number of tensors, got ",
+      tensors1.size(),
+      " and ",
+      tensors2.size());
+  TORCH_CHECK(
+      tensors1.size() == tensors3.size(),
+      "Tensor lists must have the same number of tensors, got ",
+      tensors1.size(),
+      " and ",
+      tensors3.size());
+}
+
+inline void check_foreach_api_restrictions(
+    TensorList tensors1,
+    TensorList tensors2,
+    TensorList tensors3,
+    ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
+  TORCH_CHECK(
+      tensors1.size() == scalars.size(),
+      "Tensor list must have same number of elements as scalar list, got ",
+      tensors1.size(),
+      " and ",
+      scalars.size());
+}
+
+// Helper function called in check_fast_path_restrictions to check whether all
+// corresponding tensors (aligning in index across the tensorLists) share the
+// same device and dtype.
+inline bool _check_tensors_share_device_and_dtype(
+    ArrayRef<TensorList> tensorLists) {
+  const auto expected_dtype = tensorLists[0][0].dtype();
+  const auto expected_device = tensorLists[0][0].device();
+
+  auto is_tensor_okay = [&](const Tensor& tensor) {
+    return tensor.dtype() == expected_dtype &&
+        tensor.device() == expected_device && tensor.layout() == at::kStrided &&
+        tensor.is_non_overlapping_and_dense();
+  };
+
+  for (const auto& tensorList : tensorLists) {
+    for (const auto& tensor : tensorList) {
+      if (!is_tensor_okay(tensor)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Helper function called in check_fast_path_restrictions to check if
+// corresponding tensors in tensor lists have the same sizes and strides.
+inline bool _check_tensors_share_sizes_and_strides(
+    ArrayRef<TensorList> tensorLists) {
+  for (const auto i : c10::irange(1, tensorLists.size())) {
+    for (const auto j : c10::irange(tensorLists[0].size())) {
+      if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
+          tensorLists[0][j].strides() != tensorLists[i][j].strides()) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Helper function called in check_fast_path_restrictions to check whether
+// all tensors type promote properly with the scalars in scalarList. This
+// function assumes that _check_tensors_share_device_and_dtype has already been
+// called so that all corresponding tensors in tensorLists have the same dtype.
+// Then, it is sufficient to check the type promotion with just one tensorList.
+inline bool _check_tensors_do_type_promotion_with_scalars(
+    TensorList tensorList,
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
+  for (const auto i : c10::irange(tensorList.size())) {
+    // For division, integer inputs will result in float.
+    if (does_op_promote_integer_inputs_to_float) {
+      if (at::isIntegralType(
+              tensorList[i].scalar_type(), /*includeBool*/ true)) {
+        return false;
+      }
+    }
+    if (!scalarList.empty()) {
+      const auto& scalar =
+          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
+      const auto& tensor = tensorList[i];
+      // note(mkozuki): This check might be responsible for
+      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
+      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// To go via 'fast' path, several conditions must be satisfied
+// - All tensors in all lists must have the same dtype.
+// - All tensors must be on the same device
+// - All tensors must have strided layout
+// - All tensors must be non-overlapping and dense
+// - Resulting tensor must have the same dtype as the input one
+
+// Please, make sure to call check_foreach_api_restrictions before calling this
+// method. There is a set of preconditions that have to be satisfied.
+inline bool check_fast_path_restrictions(
+    ArrayRef<TensorList> tensorLists,
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
+  return _check_tensors_share_device_and_dtype(tensorLists) &&
+      _check_tensors_share_sizes_and_strides(tensorLists) &&
+      _check_tensors_do_type_promotion_with_scalars(
+             tensorLists[0],
+             scalarList,
+             does_op_promote_integer_inputs_to_float);
+}
+
+inline std::vector<Scalar> convert_tensor_to_scalar_list(
+    const Tensor& scalarList_,
+    int64_t expect_length) {
+  std::vector<Scalar> scalarList;
+  TORCH_CHECK(
+      scalarList_.device() == c10::kCPU,
+      "Expected scalars to be on CPU, got ",
+      scalarList_.device(),
+      " instead.");
+  TORCH_CHECK(
+      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
+  TORCH_CHECK(
+      scalarList_.dim() == 1,
+      "Expected packed scalar Tensor to be of dimension 1. Got ",
+      scalarList_.dim(),
+      " instead.");
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16,
+      scalarList_.scalar_type(),
+      "convert_tensor_to_scalar_list",
+      [&]() {
+        const scalar_t* scalar_data = scalarList_.data_ptr<scalar_t>();
+        TORCH_CHECK(
+            (expect_length == scalarList_.size(0)),
+            "Expected length of scalars to match input of length ",
+            expect_length,
+            " but got ",
+            scalarList_.size(0),
+            " instead.");
+        for (int64_t i = 0; i < scalarList_.size(0); i++) {
+          scalarList.emplace_back(scalar_data[i]);
+        }
+      });
+  return scalarList;
+}
+
+inline bool can_use_fast_route(
+    ArrayRef<TensorList> tensorLists,
+    ArrayRef<Scalar> scalarList = {},
+    bool does_op_promote_integer_inputs_to_float = false) {
+  return check_fast_path_restrictions(
+      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
+}
+
+inline bool can_use_fast_route(
+    TensorList tensors1,
+    TensorList tensors2,
+    bool does_op_promote_integer_inputs_to_float = false) {
+  return can_use_fast_route(
+      {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
+}
+
+using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
+using IndicesT = std::vector<size_t>;
+using nested_optional_tensorvec_t =
+    std::vector<std::vector<c10::optional<at::Tensor>>>;
+using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
+using FlatMap = std::unordered_map<
+    DeviceDtypeKey,
+    TensorsAndIndicesT,
+    ParamsHash<DeviceDtypeKey>>;
+
+inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
+    const nested_optional_tensorvec_t& nested_tensorlist,
+    const bool with_indices) {
+  FlatMap grouped_tensors_with_indices;
+
+  TORCH_CHECK(!nested_tensorlist.empty());
+  TORCH_CHECK(!nested_tensorlist[0].empty());
+  const auto num_lists = nested_tensorlist.size();
+  const auto num_tensors = nested_tensorlist[0].size();
+
+  TORCH_CHECK(std::all_of(
+      nested_tensorlist.cbegin(),
+      nested_tensorlist.cend(),
+      [&](const auto& tensorlist) -> bool {
+        // note(crcrpar): Allow empty tensorlists following
+        // ref:
+        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
+        return tensorlist.size() == num_tensors || tensorlist.size() == 0;
+      }));
+
+  for (const auto& tensor_index : c10::irange(num_tensors)) {
+    const auto key = [&]() -> DeviceDtypeKey {
+      const auto t = nested_tensorlist[0][tensor_index];
+      TORCH_CHECK(
+          t.has_value(),
+          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
+          "the ",
+          tensor_index,
+          "-th Tensor is not.");
+      return {t->device(), t->scalar_type()};
+    }();
+    TORCH_CHECK(
+        std::all_of(
+            nested_tensorlist.cbegin(),
+            nested_tensorlist.cend(),
+            [&](const auto& tensorlist) -> bool {
+              if (tensorlist.size() == 0) {
+                return true;
+              }
+              const auto& tensor = tensorlist[tensor_index];
+              // note(crcrpar): Currently the scope of this function is
+              // optimizers so there could be `state_steps` and other scalars
+              // whose elements are float tensors no matter what the parameter's
+              // dtype is.
+              if (!tensor.has_value()) {
+                return true;
+              } else {
+                const auto s = tensor->scalar_type();
+                const auto d = tensor->device();
+                // Note: `step` or `state_step` is float32 by default.
+                if (key.first == d) {
+                  return key.second == s || s == at::ScalarType::Float ||
+                      s == at::ScalarType::Double;
+                } else if (d.is_cpu()) {
+                  // note(crcrpar): There are some test cases (e.g.
+                  // TestOptim::test_adam) where state_steps are on CPU and the
+                  // others are on CUDA. Currently a state_step Tensor has the
+                  // dtype of float.
+                  return s == at::ScalarType::Float ||
+                      s == at::ScalarType::Double;
+                } else {
+                  return false;
+                }
+              }
+            }),
+        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
+    if (!grouped_tensors_with_indices.count(key)) {
+      grouped_tensors_with_indices.insert(
+          {key,
+           TensorsAndIndicesT{
+               [&]() -> nested_optional_tensorvec_t {
+                 nested_optional_tensorvec_t nested_tensorvec;
+                 nested_tensorvec.reserve(num_lists);
+                 for (const auto& i : c10::irange(num_lists)) {
+                   std::vector<c10::optional<at::Tensor>> tensors;
+                   if (!nested_tensorlist[i].empty()) {
+                     // NB: num_tensors is the max possible length for any of
+                     // the inner lists of tensor references. Reserving the max
+                     // trades memory for perf. This should not have significant
+                     // impact.
+                     tensors.reserve(num_tensors);
+                   }
+                   nested_tensorvec.emplace_back(tensors);
+                 }
+                 return nested_tensorvec;
+               }(),
+               [&]() -> IndicesT {
+                 if (!with_indices) {
+                   return {};
+                 } else {
+                   IndicesT indices;
+                   indices.reserve(num_tensors);
+                   return indices;
+                 }
+               }()}});
+    }
+    for (const auto& list_index : c10::irange(num_lists)) {
+      if (!nested_tensorlist[list_index].empty()) {
+        grouped_tensors_with_indices[key].first[list_index].emplace_back(
+            nested_tensorlist[list_index][tensor_index]);
+      }
+    }
+    if (with_indices) {
+      grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
+    }
+  }
+
+  return grouped_tensors_with_indices;
+}
+
+} // namespace
+} // namespace at::native
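`_group_tensors_by_first_tensors_device_and_dtype` buckets the i-th tensors of every list under a (device, dtype) key so the foreach kernels can be launched once per bucket. A simplified sketch of that grouping with plain strings standing in for Device/ScalarType and a pair hash playing the role of ParamsHash (the names `Key`, `PairHash`, and the sample data are illustrative):

#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Key = (device, dtype); PairHash stands in for ParamsHash in the header.
using Key = std::pair<std::string, std::string>;
struct PairHash {
  size_t operator()(const Key& k) const {
    return std::hash<std::string>()(k.first) ^ (std::hash<std::string>()(k.second) << 1);
  }
};

int main() {
  // One "tensor list": each entry carries its (device, dtype) and an index.
  std::vector<std::pair<Key, int>> tensors = {
      {{"cuda:0", "float"}, 0}, {{"cpu", "float"}, 1},
      {{"cuda:0", "float"}, 2}, {{"cuda:0", "half"}, 3}};

  // Group indices by key, mirroring the FlatMap built in ForeachUtils.h.
  std::unordered_map<Key, std::vector<int>, PairHash> grouped;
  for (const auto& [key, index] : tensors) grouped[key].push_back(index);

  for (const auto& [key, indices] : grouped) {
    std::printf("(%s, %s):", key.first.c_str(), key.second.c_str());
    for (int i : indices) std::printf(" %d", i);
    std::printf("\n");
  }
}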
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h b/MLPY/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h
new file mode 100644
index 0000000000000000000000000000000000000000..af87ba8b8962904ef5dbc3d7290d90c27f52c20d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/FractionalMaxPooling.h
@@ -0,0 +1,80 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+template <typename scalar_t>
+static inline std::vector<int> generate_intervals(
+    scalar_t sample,
+    int64_t inputSize,
+    int64_t outputSize,
+    int64_t poolSize) {
+  std::vector<int> sequence(outputSize);
+  if (outputSize > 1) {
+    scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
+      static_cast<scalar_t>(outputSize - 1);
+
+    for (const auto i : c10::irange(outputSize - 1)) {
+      sequence[i] =
+        static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
+    }
+  }
+  if (outputSize > 0) {
+    sequence[outputSize - 1] = inputSize - poolSize;
+  }
+  return sequence;
+}
+
+template <int64_t ndim>
+static inline void fractional_max_pool_check_shape(
+    const Tensor& input,
+    const Tensor& randomSamples) {
+
+  TORCH_CHECK(
+      input.scalar_type() == randomSamples.scalar_type(),
+      "Expect _random_samples to have the same dtype as input");
+
+  int64_t ndimension = randomSamples.ndimension();
+  TORCH_CHECK(
+      ndimension == 3,
+      "Expect _random_samples to have 3 dimensions, got ", ndimension);
+
+  int64_t N = randomSamples.size(0);
+  int64_t C = randomSamples.size(1);
+  int64_t D = randomSamples.size(2);
+
+  int64_t input_batch, input_channel;
+  if (ndim == 2) {
+    // fractional_max_pool2d
+    if (input.ndimension() == 3) {
+      input_batch = 1;
+      input_channel = input.size(0);
+    } else {
+      input_batch = input.size(0);
+      input_channel = input.size(1);
+    }
+  } else {
+    // fractional_max_pool3d
+    if (input.ndimension() == 4) {
+      input_batch = 1;
+      input_channel = input.size(0);
+    } else {
+      input_batch = input.size(0);
+      input_channel = input.size(1);
+    }
+  }
+
+  TORCH_CHECK(
+      N >= input_batch,
+      "Expect _random_samples.size(0) no less then input batch size.");
+  TORCH_CHECK(
+      C == input_channel,
+      "Expect _random_samples.size(1) equals to input channel size.");
+  TORCH_CHECK(
+      D == ndim,
+      "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
+}
+
+} // namespace at::native
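generate_intervals turns one random sample in [0, 1) into a monotone sequence of pooling offsets: alpha = (inputSize - poolSize) / (outputSize - 1), interval i starts at trunc((i + sample) * alpha) - trunc(sample * alpha), and the last interval is pinned to inputSize - poolSize. A standalone sketch of the same arithmetic (a plain-double rewrite, not the header's template):

#include <cstdio>
#include <vector>

// Same interval construction as generate_intervals, for plain doubles/ints.
std::vector<int> generate_intervals(double sample, long input_size,
                                    long output_size, long pool_size) {
  std::vector<int> seq(output_size);
  if (output_size > 1) {
    const double alpha =
        static_cast<double>(input_size - pool_size) / (output_size - 1);
    for (long i = 0; i < output_size - 1; ++i) {
      seq[i] = static_cast<int>((i + sample) * alpha) -
               static_cast<int>(sample * alpha);
    }
  }
  if (output_size > 0) seq[output_size - 1] = static_cast<int>(input_size - pool_size);
  return seq;
}

int main() {
  // 5 output cells over a length-11 input with pool size 2.
  for (int start : generate_intervals(/*sample=*/0.37, 11, 5, 2))
    std::printf("%d ", start);  // non-decreasing starts, last one is 11 - 2 = 9
  std::printf("\n");
}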
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..e760b3bfaa7eea53046b8ec7ea00605b6f209503
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+struct TensorIterator;
+
+namespace native {
+
+using _compute_linear_combination_fn = void(*)(
+  TensorIterator& iter,
+  int64_t in_stride,
+  int64_t coeff_stride,
+  int64_t num_summations
+);
+
+DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub);
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/GridSampler.h b/MLPY/Lib/site-packages/torch/include/ATen/native/GridSampler.h
new file mode 100644
index 0000000000000000000000000000000000000000..cad7bd8205bf08c83f1e322c6eca70df03c27935
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/GridSampler.h
@@ -0,0 +1,298 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at::native {
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
+// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels
+//     -1 --> 0
+//     +1 --> (size - 1)
+//     scale_factor = (size - 1) / 2
+// if not align_corners: -1 and +1 get sent to the image edges
+//     -1 --> -0.5
+//     +1 --> (size - 1) + 0.5 == size - 0.5
+//     scale_factor = size / 2
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
+                                                bool align_corners) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
+
+// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
+// except that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
+                                                         bool align_corners, scalar_t *grad_in) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    *grad_in = static_cast<scalar_t>(size - 1) / 2;
+    return ((coord + 1) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    *grad_in = static_cast<scalar_t>(size) / 2;
+    return ((coord + 1) * size - 1) / 2;
+  }
+}
+
+// Clips coordinates to between 0 and clip_limit - 1
+template<typename scalar_t>
+static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
+  return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
+}
+
+// clip_coordinates_set_grad works similarly to clip_coordinates except that
+// it also returns the `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
+                                                 scalar_t *grad_in) {
+  // Note that it is important for the gradient calculation that borders
+  // are considered out of bounds.
+  if (in <= static_cast<scalar_t>(0)) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  } else {
+    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
+    if (in >= max) {
+      *grad_in = static_cast<scalar_t>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<scalar_t>(1);
+      return in;
+    }
+  }
+}
+
+// Reflects coordinates until they fall between low and high (inclusive).
+// The bounds are passed as twice their value so that half-integer values
+// can be represented as ints.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
+                                           int64_t twice_high) {
+  if (twice_low == twice_high) {
+    return static_cast<scalar_t>(0);
+  }
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = std::fabs(in - min);
+  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    return extra + min;
+  } else {
+    return span - extra + min;
+  }
+}
+
+// reflect_coordinates_set_grad works similarly to reflect_coordinates except
+// that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template<typename scalar_t>
+static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
+                                                    int64_t twice_high, scalar_t *grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  }
+  int grad_in_mult_;
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<scalar_t>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  // `fmod` returns same sign as `in`, which is positive after the `if` above.
+  scalar_t extra = std::fmod(in, span);
+  int flips = static_cast<int>(std::floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<scalar_t>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
+    return span - extra + min;
+  }
+}
+
+// Mapping the out-of-boundary points back into boundary
+// This would only affect padding_mode=border or reflection
+template<typename scalar_t>
+static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
+                                           GridSamplerPadding padding_mode,
+                                           bool align_corners) {
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2*(size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2*size - 1);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  }
+  return coord;
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);
+  coord = compute_coordinates(coord, size, padding_mode, align_corners);
+  return coord;
+}
+
+// grid_sampler_compute_source_index_set_grad works similarly to
+// grid_sampler_compute_source_index except that it also returns the
+// `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static inline scalar_t grid_sampler_compute_source_index_set_grad(
+    scalar_t coord,
+    int64_t size,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    scalar_t *grad_in) {
+  scalar_t grad_clip, grad_refl;
+  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
+    } else {
+      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_refl * grad_clip;
+  }
+  return coord;
+}
+
+static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
+  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template<typename scalar_t>
+static inline scalar_t get_value_bounded(
+    scalar_t* data,
+    scalar_t x,
+    scalar_t y,
+    int64_t W,
+    int64_t H,
+    int64_t sW,
+    int64_t sH,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int64_t ix = static_cast<int64_t>(x);
+  int64_t iy = static_cast<int64_t>(y);
+
+  if (within_bounds_2d(iy, ix, H, W)) {
+    return data[iy * sH + ix * sW];
+  }
+  return static_cast<scalar_t>(0);
+}
+
+template<typename scalar_t>
+static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
+                               int64_t sH, int64_t sW, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_2d(h, w, H, W)) {
+    data[h * sH + w * sW] += delta;
+  }
+}
+
+template<typename scalar_t>
+static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
+                               int64_t sD, int64_t sH, int64_t sW,
+                               int64_t D, int64_t H, int64_t W,
+                               scalar_t delta) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {
+    data[d * sD + h * sH + w * sW] += delta;
+  }
+}
+
+template<typename scalar_t>
+static inline void add_value_bounded(
+    scalar_t* data,
+    scalar_t x,
+    scalar_t y,
+    int64_t W,
+    int64_t H,
+    int64_t sW,
+    int64_t sH,
+    scalar_t delta,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int64_t ix = static_cast<int64_t>(x);
+  int64_t iy = static_cast<int64_t>(y);
+
+  safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
+}
+
+// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
+template<typename scalar_t>
+static inline void get_cubic_coefficients_grad(
+    scalar_t coeffs[4],
+    scalar_t t) {
+
+  // Must be the same as forward calculation in
+  // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
+  scalar_t A = -0.75;
+
+  scalar_t x;
+  x = -1 - t; // 1 < x = |-1 - tx| < 2
+  coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
+  x = -t;     // x = |0 - tx| <= 1
+  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 1 - t;  // x = |1 - tx| <= 1
+  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 2 - t;  // 1 < x = |2 - tx| < 2
+  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
+}
+
+}  // namespace at::native
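grid_sampler_unnormalize maps a normalized coordinate in [-1, 1] onto pixel space: with align_corners the endpoints land on the centers of the corner pixels (0 and size - 1), without it they land on the image edges (-0.5 and size - 0.5). A small double-precision sketch of both mappings (the standalone `unnormalize` helper is illustrative):

#include <cstdio>

// Same formulas as grid_sampler_unnormalize, written for double.
double unnormalize(double coord, long size, bool align_corners) {
  return align_corners ? ((coord + 1) / 2) * (size - 1)
                       : ((coord + 1) * size - 1) / 2;
}

int main() {
  const long size = 5;
  for (double c : {-1.0, 0.0, 1.0}) {
    std::printf("coord %+4.1f -> align_corners=true: %4.1f  false: %4.1f\n",
                c, unnormalize(c, size, true), unnormalize(c, size, false));
  }
  // align_corners=true : -1 -> 0.0, +1 -> 4.0 (= size - 1)
  // align_corners=false: -1 -> -0.5, +1 -> 4.5 (= size - 0.5)
}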
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d1084366a4daf2fc4c783a3a5c0e0cd2533bb45
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/GridSamplerUtils.h
@@ -0,0 +1,109 @@
+#pragma once
+
+// See NOTE: [Tensor vs. TensorBase]
+// https://github.com/pytorch/pytorch/pull/66979
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+namespace detail {
+
+enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
+enum class GridSamplerPadding {Zeros, Border, Reflection};
+
+} // namespace detail
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+namespace {
+
+// See NOTE [ grid_sampler Native Functions ].
+void check_grid_sampler_common(
+  const TensorBase& input,
+  const TensorBase& grid
+) {
+  auto input_opt = input.options();
+  auto grid_opt = grid.options();
+
+  TORCH_CHECK(
+    input.defined(),
+    "grid_sampler(): expected input to not be undefined");
+  TORCH_CHECK(
+    grid.defined(),
+    "grid_sampler(): expected grid to not be undefined");
+  TORCH_CHECK(
+    input_opt.device() == grid_opt.device(),
+    "grid_sampler(): expected input and grid to be on same device, but input "
+    "is on ", input_opt.device(), " and grid is on ", grid_opt.device());
+  TORCH_CHECK(
+    input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
+    "grid_sampler(): expected input and grid to have torch.strided layout, but "
+    "input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
+  TORCH_CHECK(
+    input.size(0) == grid.size(0),
+    "grid_sampler(): expected grid and input to have same batch size, but got "
+    "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
+  TORCH_CHECK(
+    grid.size(-1) == input.dim() - 2,
+    "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
+    "dimension, but got grid with sizes ", grid.sizes());
+
+  for (const auto i : c10::irange(2, input.dim())) {
+    TORCH_CHECK(input.size(i) > 0,
+      "grid_sampler(): expected input to have non-empty spatial dimensions, "
+      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
+      "empty");
+  }
+}
+
+// See NOTE [ grid_sampler Native Functions ].
+void check_grid_sampler_2d(
+  const TensorBase& input,
+  const TensorBase& grid
+) {
+  TORCH_CHECK(
+    input.dim() == 4 && input.dim() == grid.dim(),
+    "grid_sampler(): expected 4D input and grid with same number of "
+    "dimensions, but got input with sizes ", input.sizes(),
+    " and grid with sizes ", grid.sizes());
+}
+
+// See NOTE [ grid_sampler Native Functions ].
+void check_grid_sampler_3d(
+  const TensorBase& input,
+  const TensorBase& grid,
+  int64_t interpolation_mode
+) {
+  TORCH_CHECK(
+    input.dim() == 5 && input.dim() == grid.dim(),
+    "grid_sampler(): expected 5D input and grid with same number of "
+    "dimensions, but got input with sizes ", input.sizes(),
+    " and grid with sizes ", grid.sizes());
+  TORCH_CHECK(
+    !(input.dim() == 5 &&
+      static_cast<GridSamplerInterpolation>(interpolation_mode) ==
+        GridSamplerInterpolation::Bicubic),
+    "grid_sampler(): bicubic interpolation only supports 4D input");
+}
+
+// See NOTE [ grid_sampler Native Functions ].
+// cudnn does not support inputs larger than 1024.
+bool cond_cudnn_grid_sampler(
+  const TensorBase& input,
+  const TensorBase& grid
+) {
+  return (
+    at::native::cudnn_is_acceptable(input) &&
+    at::native::cudnn_is_acceptable(grid) &&
+    at::native::canUse32BitIndexMath(input) &&
+    at::native::canUse32BitIndexMath(grid) &&
+    input.dim() == 4 &&
+    input.sym_size(1) <= 1024);
+}
+
+} // anonymous namespace
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Histogram.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Histogram.h
new file mode 100644
index 0000000000000000000000000000000000000000..27265bdc7d89e16db6af4a315b7b1f2368c6c50c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Histogram.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using histogramdd_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&);
+using histogramdd_linear_fn = void(*)(const Tensor&, const c10::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
+using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);
+
+DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
+DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
+DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/IndexKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/IndexKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..e89d349cebbbbe5501ea587aac6ea18e7a86a83e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/IndexKernel.h
@@ -0,0 +1,41 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+class Tensor;
+class TensorBase;
+struct TensorIterator;
+struct TensorIteratorBase;
+}
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at::native {
+
+using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
+using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
+using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
+using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
+using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
+using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
+using flip_fn = void(*)(TensorIterator &, const bool);
+using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
+using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
+using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
+
+DECLARE_DISPATCH(index_fn, index_stub);
+DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
+DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
+DECLARE_DISPATCH(index_put_fn, index_put_stub);
+DECLARE_DISPATCH(put_fn, put_stub);
+DECLARE_DISPATCH(take_fn, take_stub);
+DECLARE_DISPATCH(flip_fn, flip_stub);
+DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
+DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
+DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
+DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf1edc6de186211932cc1e76ce410833e43c1b06
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/IndexingUtils.h
@@ -0,0 +1,160 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+[[noreturn]]
+static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
+  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
+  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
+}
+
+
+static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
+  // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
+  std::vector<Tensor> result;
+  for (const auto& index_opt : indices) {
+    if (!index_opt.has_value()) {
+      result.emplace_back();
+    } else {
+      const auto& index = *index_opt;
+      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+        if (index.scalar_type() == kByte) {
+          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
+          " please use a dtype torch.bool instead.");
+        }
+        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
+        // corresponding dimensions in self
+        for (const auto j : c10::irange(index.dim())) {
+          int64_t srcIdx = static_cast<int64_t>(result.size() + j);
+          if (index.size(j) != self.size(srcIdx)) {
+            invalid_mask(self, srcIdx, index, j);
+          }
+        }
+        // Replace with nonzeros
+        auto nonzero = index.nonzero();
+        for (const auto j : c10::irange(index.dim())) {
+          result.emplace_back(nonzero.select(1, j));
+        }
+      } else {
+        result.emplace_back(index);
+      }
+    }
+  }
+  return result;
+}
+
+static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
+  for (const auto& tensor : indices) {
+    if (tensor.has_value() && tensor->defined()) {
+      auto scalarType = tensor->scalar_type();
+      if (allow_int) {
+        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
+            TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
+        }
+      } else {
+        if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
+            TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
+        }
+      }
+    }
+  }
+}
+
+inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const Tensor& a : list) {
+    result.push_back(a);
+  }
+  return result;
+}
+
+inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const IValue& a : list) {
+    result.push_back(a.isTensor() ? c10::optional<Tensor>(a.toTensor()) : c10::optional<Tensor>());
+  }
+  return result;
+}
+
+static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
+  // true if all the non-null tensors are adjacent
+  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
+  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
+  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
+  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
+  auto it = std::find_if(start, stop.base(), isNull);
+  return it == stop.base();
+}
+
+
+// Transposes the tensor and indices together so that all the non-null indices
+// index the first k dimensions of the tensor. Returns the transposed tensor
+// and the reordered indices. For example:
+// transposeToFront(tensor, {nullptr, a, nullptr, b})
+// returns
+// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
+static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
+transposeToFront(const Tensor& self, TensorList indices) {
+  std::vector<int64_t> dims;
+  std::vector<Tensor> transposedIndices;
+  dims.reserve(self.dim());
+  for (const auto i : c10::irange(self.dim())) {
+    if (indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back(indices[i]);
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {
+    if (!indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back();
+    }
+  }
+  return std::make_tuple(self.permute(dims), std::move(transposedIndices));
+}
+
+inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
+transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
+  std::vector<int64_t> dims;
+  std::vector<int64_t> invPerm;
+  std::vector<Tensor> transposedIndices;
+  dims.reserve(self.dim());
+  invPerm.resize(self.dim());
+  for (const auto i : c10::irange(self.dim())) {
+    if (indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back(indices[i]);
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {
+    if (!indices[i].defined()) {
+      dims.push_back(i);
+      transposedIndices.emplace_back();
+    }
+  }
+  for (const auto i : c10::irange(self.dim())) {
+    invPerm[dims[i]] = i;
+  }
+  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(const Tensor& src, TensorList indices);
+
+  Tensor src;
+  std::vector<Tensor> indices;
+  DimVector indexed_sizes;
+  DimVector indexed_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+};
+
+
+} //namespace at::native
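transposeToFront computes a permutation that moves every dimension carrying a defined index tensor to the front, so advanced indexing can treat the indexed dimensions as a contiguous prefix; for indices {null, a, null, b} it produces the permutation [1, 3, 0, 2], as the comment above notes. A sketch of just the permutation logic over a defined/undefined mask (the free function name is illustrative):

#include <cstdio>
#include <vector>

// Mirrors the dims-building loops in transposeToFront: defined dims first,
// undefined dims after, original order preserved within each group.
std::vector<int> permutation_to_front(const std::vector<bool>& defined) {
  std::vector<int> dims;
  dims.reserve(defined.size());
  for (int i = 0; i < static_cast<int>(defined.size()); ++i)
    if (defined[i]) dims.push_back(i);
  for (int i = 0; i < static_cast<int>(defined.size()); ++i)
    if (!defined[i]) dims.push_back(i);
  return dims;
}

int main() {
  // indices = {nullptr, a, nullptr, b}  ->  permute([1, 3, 0, 2])
  for (int d : permutation_to_front({false, true, false, true}))
    std::printf("%d ", d);
  std::printf("\n");
}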
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Lerp.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Lerp.h
new file mode 100644
index 0000000000000000000000000000000000000000..5fd66810125129ec9cf3fa7c3e9d218b084b687f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Lerp.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
+  return std::abs(weight) < scalar_t(0.5);
+}
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
+  // Avoid the sqrt in abs(weight)
+  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
+}
+
+template <typename scalar_t, typename weight_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using opmath_weight_t = at::opmath_type<weight_t>;
+
+  opmath_t self = self_;
+  opmath_t end = end_;
+  opmath_weight_t weight = weight_;
+
+  // Conditional for better numeric. This has been discussed in
+  // https://github.com/pytorch/pytorch/pull/18871
+  return is_lerp_weight_small(weight)
+      ? self + weight * (end - self)
+      : end - (end - self) * (opmath_t(1) - weight);
+}
+
+using lerp_fn_scalar = void (*)(
+    at::TensorIteratorBase& iter,
+    const Scalar& weight);
+
+using lerp_fn_tensor = void (*)(
+    at::TensorIteratorBase& iter);
+
+DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
+DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);
+
+} // namespace at::native
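
A minimal standalone sketch of the two-branch formula above (plain double arithmetic outside ATen; `lerp_ref` and the `main` wrapper are illustrative names): anchoring at `self` for small weights and at `end` otherwise keeps the update term small and well conditioned.

#include <cmath>
#include <cstdio>

// Same branch structure as the templated lerp above, specialized to double.
static double lerp_ref(double self, double end, double weight) {
  return std::abs(weight) < 0.5
      ? self + weight * (end - self)          // anchored at self for |weight| < 0.5
      : end - (end - self) * (1.0 - weight);  // anchored at end otherwise
}

int main() {
  // The two branches agree on the mathematical result: 1.5 and 2.5 here.
  std::printf("%f %f\n", lerp_ref(1.0, 3.0, 0.25), lerp_ref(1.0, 3.0, 0.75));
  return 0;
}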
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h b/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h
new file mode 100644
index 0000000000000000000000000000000000000000..507a5f7b9c43ec2aced978c0f83100e555477b0c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebra.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include 
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+struct TensorIterator;
+}
+
+namespace at::native {
+
+using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
+DECLARE_DISPATCH(addr_fn, addr_stub);
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..3fd7c014226cc9147bfc204d4a2775c401661731
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
@@ -0,0 +1,623 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#include 
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+static inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
+  if (tensor.is_conj()) {
+    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
+  } else {
+    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+  }
+}
+
+static inline DimVector batched_matrix_contiguous_strides(
+    const IntArrayRef sizes,
+    const bool f_contig = false) {
+  // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
+  // and C-contiguous matrices
+  auto strides = c10::contiguous_strides(sizes);
+  auto dim = strides.size();
+
+  if (f_contig && dim >= 2) {
+    // Fix the strides of the last two dimensions, so that we return
+    // C-contiguous batches of F-contiguous matrices.
+    strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
+    strides[dim - 2] = 1;
+  }
+  return strides;
+}
+
+/*
+ * Clones a Tensor so that the following conditions hold:
+ * If we think of a Tensor of having size (B, M, N), where B is any number
+ * of batch dimensions, then:
+ * - Each (M, N) matrix is in column major form
+ * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
+ *   Then when laid out in memory, the M by N matrix starting at
+ *   P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
+ *   matrix starting at Q.data_ptr()[B * M' * N'].
+ */
+static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
+  // If src is already in batched column major format, then
+  // this will be efficient (no reordering of the data will occur)
+  // because the first transpose will make the tensor contiguous,
+  // and cloning a contiguous tensor is fast.
+  auto result = src.mT().clone(at::MemoryFormat::Contiguous);
+  result.transpose_(-2, -1);
+  return result;
+}
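
A small usage sketch, assuming the ATen include tree is on the include path (the `main` wrapper is illustrative): the clone keeps the logical shape, but each trailing matrix becomes Fortran-contiguous, which is the layout LAPACK-style backends expect.

#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <iostream>

int main() {
  auto x = at::rand({2, 3, 4});                     // C-contiguous: strides (12, 4, 1)
  auto y = at::native::cloneBatchedColumnMajor(x);  // same shape, each 3x4 matrix column major
  std::cout << y.strides() << "\n";                 // prints [12, 1, 3]
  std::cout << y.equal(x) << "\n";                  // prints 1: only the memory layout changed
  return 0;
}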
+
+/*
+ * contig chooses between C-contig (true) and F-contig (false)
+ */
+static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
+  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
+              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
+                                                      : cloneBatchedColumnMajor(clone));
+}
+
+/*
+ * This method is designed to be a faster alternative to
+ * `cloneBatchedColumnMajor` with some additional features,
+ * namely:
+ * 1. It uses `copy` instead of `clone` which could be much faster.
+ * 2. `nrows` parameter used to create inputs with the number of rows larger
+ *  than the original input, which is required for some LAPACK/MAGMA methods.
+ * 3. `desired_batch_size` is used to create copies with the batch size
+ *  which is either the original batch size of the input, or its larger
+ *  broadcasted shape.
+ */
+static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
+    at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) {
+  nrows = (nrows == -1) ? src.size(-2) : nrows;
+  auto copy_sizes = desired_batch_sizes.has_value()
+    ? desired_batch_sizes.value().vec()
+    : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
+  copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
+  const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
+  auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
+  copy.narrow(-2, 0, src.size(-2)).copy_(src);
+  return copy;
+}
+
+/*
+ * Given batches of matrices with arbitrary batch dim,
+ * computes the number of batches.
+ */
+static inline int64_t batchCount(const Tensor& batched_matrices) {
+  int64_t result = 1;
+  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
+    result *= batched_matrices.size(i);
+  }
+  return result;
+}
+
+// Computes the number of elements of a matrix in a batched matrix tensor
+static inline int64_t matrixStride(const Tensor& batched_matrices) {
+  return batched_matrices.size(-1) * batched_matrices.size(-2);
+}
+
+// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
+static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
+  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
+}
+static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
+  checkIsMatrix(self, f_name, arg_name);
+  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
+              f_name,
+              ": ", arg_name, " must be batches of square matrices, "
+              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
+}
+
+static inline void checkInputsSolver(const Tensor& A,
+                                     const Tensor& B,
+                                     const bool left,
+                                     const char* const f_name) {
+  squareCheckInputs(A, f_name, "A");
+  checkIsMatrix(B, f_name, "B");
+  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
+              f_name, ": Incompatible shapes of A and B for the equation ",
+              left ? "AX = B" : "XA = B",
+              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
+}
+
+static inline bool is_row_or_column_contiguous(const Tensor& t) {
+  // This could be made more general, similar to how it's checked in matmul, which would allow to
+  // elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
+  // We choose to be conservative for simplicity
+  return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
+}
+
+static inline TransposeType to_transpose_type(const bool contig, const bool conj) {
+  if (conj) {
+    if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
+    else {        return TransposeType::ConjTranspose; }
+  } else {
+    if (contig) { return TransposeType::NoTranspose; }
+    else {        return TransposeType::Transpose; }
+  }
+}
+
+
+// This function is designed to be used with linear algebra methods that minimize
+// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
+// or the L2 norm (`lstsq`).
+// It is expected that `a` and `b` are contiguous tensors of column-major matrices
+// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
+// with the following additional properties:
+//
+// 1. a.dim() == b.dim()
+// 2. a.shape[:-2] broadcasts over b.shape[:-2]
+// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
+//
+// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
+// is to be memory efficient, which means that if there exists an index i such that
+// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
+// then instead of materializing copies of `a` in the broadcasted shape, we keep
+// a buffer copy of `a` along with flags that check whether specific batch dimension
+// indices for `a` were already accessed. If they were, we copy the data from the buffer
+// into `a`. The number of copies does not exceed
+// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
+// and this value is attained by tensors with non-empty batch dimensions.
+//
+// func_t `f` is a callable that is being supplied with
+// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
+// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
+// and a_linear_batch_idx is an index in the 3d representation which corresponds to
+// the memory a_working_ptr points to, in other words:
+// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr();
+// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
+// its rank or singular values (see linalg_lstsq).
+template <typename scalar_t, typename func_t>
+void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
+  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
+  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
+
+  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
+  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
+
+  TensorIterator iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .add_output(b_linear_batch_idx)
+    .add_input(a_linear_batch_idx)
+    .build();
+
+  auto m = a.size(-2);
+  auto n = a.size(-1);
+  auto a_3d = a.view({batchCount(a), m, n});
+  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
+
+  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
+  Tensor a_buffer, a_was_accessed, a_buffer_3d;
+  std::function<void(int64_t)> check_if_copy_needed_for_a
+    = [](int64_t /*a_curr_linear_batch_idx*/){};
+  if (a_broadcasts_over_b) {
+    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
+      .copy_(a);
+    a_was_accessed = at::zeros(batchCount(a), at::kBool);
+    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
+    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
+      auto* a_was_accessed_flag = a_was_accessed
+        .select(0, a_curr_linear_batch_idx)
+        .data_ptr<bool>();
+      if (!(*a_was_accessed_flag)) {
+        *a_was_accessed_flag = true;
+      }
+      else {
+        a_3d.select(0, a_curr_linear_batch_idx)
+          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
+      }
+    };
+  }
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
+    auto* b_batch_idx_ptr = data[0];
+    auto* a_batch_idx_ptr = data[1];
+
+    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
+      auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
+      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
+
+      check_if_copy_needed_for_a(a_curr_linear_batch_idx);
+
+      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
+        .data_ptr<scalar_t>();
+      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
+        .data_ptr<scalar_t>();
+      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
+
+      b_batch_idx_ptr += strides[0];
+      a_batch_idx_ptr += strides[1];
+    }
+  };
+  iter.serial_for_each(loop, {0, batchCount(b)});
+}
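
A hypothetical usage sketch (the shapes and the no-op lambda are made up): `a` holds a single batch entry that is broadcast over the four batches of `b`, and the callable receives raw column-major data pointers plus a's linear batch index, exactly as it would pass them on to a LAPACK/MAGMA routine.

#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebraUtils.h>

void broadcast_iter_example() {
  // Column-major, view-compatible batches, as the comment above requires.
  auto a = at::native::cloneBatchedColumnMajor(at::rand({1, 3, 3}, at::kDouble));
  auto b = at::native::cloneBatchedColumnMajor(at::rand({4, 3, 2}, at::kDouble));
  at::native::batch_iterator_with_broadcasting<double>(a, b,
      [](double* a_working_ptr, double* b_working_ptr, int64_t a_linear_batch_idx) {
        // A real caller would hand a_working_ptr/b_working_ptr to a solver here;
        // a_linear_batch_idx is 0 for every call because `a` has a single batch.
        (void)a_working_ptr; (void)b_working_ptr; (void)a_linear_batch_idx;
      });
}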
+
+// Returns the epsilon value for floating types except half
+static inline double _get_epsilon(const ScalarType& sc_type) {
+  switch (sc_type) {
+    case at::ScalarType::Float:
+      return static_cast<double>(std::numeric_limits<float>::epsilon());
+    case at::ScalarType::Double:
+      return std::numeric_limits<double>::epsilon();
+    default:
+      AT_ERROR("This function doesn't handle types other than float and double");
+  }
+}
+
+// Validates input shapes and devices
+// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
+static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
+  TORCH_CHECK(self.device() == A.device(),
+              "Expected b and A to be on the same device, but found b on ",
+              self.device(), " and A on ", A.device(), " instead.");
+
+  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
+              "Expected b and A to have the same dtype, but found b of type ",
+              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");
+
+  TORCH_CHECK(A.size(-1) == A.size(-2),
+              "A must be batches of square matrices, "
+              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");
+
+  TORCH_CHECK(A.size(-1) == self.size(-2),
+              "Incompatible matrix sizes for ", name, ": each A "
+              "matrix is ", A.size(-1), " by ", A.size(-1),
+              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
+}
+
+static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
+  auto dtype = t.scalar_type();
+  TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
+              f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
+  if (!allow_low_precision_dtypes) {
+    TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
+                f_name, ": Low precision dtypes not supported. Got ", dtype);
+  }
+}
+
+
+// Checks if all the Tensors in a TensorList are of the same dimensions
+static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
+  for (auto &t : tensors) {
+    TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
+  }
+}
+
+static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
+  // broadcast the batch dimensions of arg1 and arg2.
+  IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
+  IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
+
+  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
+  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
+
+  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
+  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
+  return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
+}
+
+static inline std::tuple<Tensor, Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
+  // If there's no name we assume we don't want to check the errors
+  if (name != nullptr) {
+    linearSolveCheckInputs(arg1, arg2, name);
+  }
+
+  auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
+
+  auto arg1_broadcasted  = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
+  auto arg2_broadcasted  = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
+  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
+}
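
An illustrative sketch (shapes chosen arbitrarily, wrapper name made up): only the batch dimensions are broadcast, the trailing matrix dimensions are kept as-is, and the results are expand() views rather than copies.

#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebraUtils.h>

void broadcast_batch_dims_example() {
  auto A = at::rand({2, 1, 3, 3});  // batches of square systems
  auto B = at::rand({5, 3, 2});     // batches of right-hand sides
  // arg1 is the right-hand side and arg2 the square matrix, matching linearSolveCheckInputs.
  auto [Bb, Ab] = at::native::_linalg_broadcast_batch_dims(B, A, "example");
  // Bb.sizes() == [2, 5, 3, 2] and Ab.sizes() == [2, 5, 3, 3]; no data is copied yet.
  (void)Bb; (void)Ab;
}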
+
+static inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
+  IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
+  IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
+  auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
+  return broadcasted_batch_sizes;
+}
+
+// Return a permutation with the given axes moved to the end.
+static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
+  const std::vector<int64_t> a = axes.vec();
+  const int64_t ndim = self.ndimension();
+  std::vector<int64_t> perm;
+
+  for (const auto i : c10::irange(ndim)) {
+    auto it = std::find(a.begin(), a.end(), i);
+    if (it == a.end()) {
+       perm.push_back(i);
+    }
+  }
+  for (auto i : a) {
+    perm.push_back(i);
+  }
+
+  TORCH_CHECK((int64_t)perm.size() == ndim,
+    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
+
+  return self.permute(perm);
+}
+
+// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
+static inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
+  bool compute_q;
+  bool reduced;
+  if (mode == "reduced") {
+    compute_q = true;
+    reduced = true;
+  } else if (mode == "complete") {
+    compute_q = true;
+    reduced = false;
+  } else if (mode == "r") {
+    compute_q = false;
+    reduced = true; // this is actually irrelevant in this mode
+  } else {
+      TORCH_CHECK(false, "qr received unrecognized mode '", mode,
+                  "' but expected one of 'reduced' (default), 'r', or 'complete'");
+  }
+  return std::make_tuple(compute_q, reduced);
+}
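
A tiny usage sketch (the wrapper function is illustrative): each mode string maps directly onto the (compute_q, reduced) pair.

#include <ATen/native/LinearAlgebraUtils.h>

void qr_mode_example() {
  // "reduced" -> (true, true), "complete" -> (true, false), "r" -> (false, true).
  auto [compute_q, reduced] = at::native::_parse_qr_mode("complete");
  (void)compute_q; (void)reduced;
}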
+
+// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
+static inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
+    const Tensor& input,
+    bool reduced) {
+  int64_t m = input.size(-2), n = input.size(-1);
+  int64_t n_columns_q;
+
+  // We need to compute the required size of Q based on the `reduced` option
+  DimVector q_sizes(input.sizes());
+  if (!reduced && m > n) {
+    q_sizes[input.dim() - 1] = m;
+    n_columns_q = m;
+  } else {
+    q_sizes[input.dim() - 1] = n;
+    n_columns_q = std::min(m, n);
+  }
+  auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
+  return std::make_tuple(q_sizes, q_strides, n_columns_q);
+}
+
+static inline bool svd_uses_cusolver(const Tensor& A) {
+  // if cusolver is available, it is used unconditionally
+  return A.is_cuda()
+         && at::globalContext().hasCuSOLVER()
+         && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
+}
+
+
+// Function used instead of .to so that the original strides are retained
+// .to doesn't retain strides and makes the output tensor contiguous
+static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
+  auto strided_to = at::empty_strided(original_tensor.sizes(),
+                                      original_tensor.strides(),
+                                      options);
+  strided_to.copy_(original_tensor);
+  return strided_to;
+}
+
+// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
+// the two specified dimensions to the end of a tensor, without changing the order of
+// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
+// placed just to the left of it.
+//
+// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
+// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
+// be `vec(0, 2, 1, 3)`.
+static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
+  TORCH_CHECK(
+    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
+    "duplicate or invalid dimensions");
+  std::vector<int64_t> permutation(ndim);
+  int64_t cur_permuted_dim = 0;
+  for (const auto dim_ind : c10::irange(ndim)) {
+    if ((dim_ind != dim0) && (dim_ind != dim1)) {
+      permutation[cur_permuted_dim++] = dim_ind;
+    }
+  }
+  permutation[cur_permuted_dim++] = dim0;
+  permutation[cur_permuted_dim] = dim1;
+  return permutation;
+}
+
+// Creates a dimension permutation array that can be given to `at::permute()`, which
+// will reverse a given permutation.
+// The reverse permutation array is created by swapping the indices and their
+// associated values from the given permutation array.
+static inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
+  int64_t ndim = permutation.size();
+  std::vector<int64_t> reverse_permutation(ndim);
+  for (const auto dim_ind : c10::irange(ndim)) {
+    reverse_permutation[permutation[dim_ind]] = dim_ind;
+  }
+  return reverse_permutation;
+}
+
+// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
+// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
+static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
+  auto mn = std::min(m, n);
+  auto mx = std::max(m, n);
+  if (jobz == 'N') {
+#ifdef __APPLE__
+    // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
+    return 7 * mn;
+#else
+    // This setting is valid for LAPACK 3.6+
+    return 5 * mn;
+#endif
+  }
+  if (mx > 10 * mn) {
+    return 5 * mn * mn + 5 * mn;
+  }
+  return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
+}
+
+// This function checks whether the uplo argument input is valid
+// Allowed strings are "u", "U", "l", "L"
+static inline void checkUplo(const c10::string_view uplo) {
+  // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
+  char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
+  TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
+    "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
+}
+
+static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
+  TORCH_CHECK(
+      result.device() == input.device(),
+      fn_name,
+      ": Expected ", result_name, " and input tensors to be on the same device, but got ",
+      result_name, " on ", result.device(), " and input on ", input.device());
+}
+
+// Check the dtype of result and input tensors (for _out variants).
+// Most linear algebra functions have the same dtype for input and output
+// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
+// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
+// c10::canCast is used for checking the "safe copy" dtype requirements.
+static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
+  bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
+  TORCH_CHECK(
+      can_cast,
+      fn_name,
+      ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
+      result_name, " with dtype ", result.scalar_type());
+}
+
+// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
+static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
+  bool can_cast = c10::canCast(result_type, out_type);
+  TORCH_CHECK(
+      can_cast,
+      fn_name,
+      ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
+      out_name, " with dtype ", out_type);
+}
+
+static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
+  TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
+              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
+}
+
+/*
+  Two types of 'other' tensors are supported when solving
+  a system of linear equations matmul(input, x) = other:
+  * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
+  * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
+  The original torch.solve supported only the matrix case, while NumPy works for both cases.
+  For the batched input we need to be able to distinguish them.
+  Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
+  This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
+*/
+static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
+  auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
+  bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
+  return vector_case;
+}
+
+/*
+  Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
+*/
+static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
+  TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
+  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
+}
+
+class BroadcastLinearIndices {
+ private:
+  Tensor linear_indices_;
+  bool is_broadcasting_;
+
+ public:
+  BroadcastLinearIndices(
+      int64_t numel,
+      IntArrayRef original_shape,
+      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
+    // The assumption is that the broadcast_shape is a materialized broadcast
+    // shape of the original_shape. We need to compute the linear indices
+    // compatible with the original_shape to access the elements in the original
+    // tensor corresponding to the broadcast tensor.
+    if (is_broadcasting_) {
+      linear_indices_ =
+          get_linear_indices(numel, original_shape, broadcast_shape);
+    }
+  }
+  int64_t operator()(int64_t broadcast_linear_index) {
+    return is_broadcasting_
+        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
+        : broadcast_linear_index;
+  }
+};
+
+static inline bool is_blas_compatible_column_major_order(const Tensor& input) {
+  IntArrayRef input_strides = input.strides();
+  IntArrayRef input_sizes = input.sizes();
+  auto ndim = input.dim();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
+  if (ndim > 3) {
+    return input.transpose(-2, -1).is_contiguous();
+  }
+  auto leading_dimension = input_strides[ndim - 1];
+  auto rows = input_sizes[ndim - 2];
+  bool batch_stride_compatible = true;
+  if (ndim == 3) {
+    auto cols = input_sizes[ndim - 1];
+    batch_stride_compatible =
+        input_strides[ndim - 3] >= leading_dimension * cols;
+  }
+  return (input_strides[ndim - 2] == 1) &&
+      (leading_dimension >= std::max<int64_t>(1, rows)) &&
+      batch_stride_compatible;
+}
+
+static inline bool is_blas_compatible_row_major_order(const Tensor& input) {
+  IntArrayRef input_strides = input.strides();
+  IntArrayRef input_sizes = input.sizes();
+  auto ndim = input.dim();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
+  if (ndim > 3) {
+    return input.is_contiguous();
+  }
+  auto leading_dimension = input_strides[ndim - 2];
+  auto cols = input_sizes[ndim - 1];
+  bool batch_stride_compatible = true;
+  if (ndim == 3) {
+    auto rows = input_sizes[ndim - 2];
+    batch_stride_compatible =
+        input_strides[ndim - 3] >= leading_dimension * rows;
+  }
+  return (input_strides[ndim - 1] == 1) &&
+      (leading_dimension >= std::max<int64_t>(1, cols)) &&
+      batch_stride_compatible;
+}
+
+}  // namespace at::native
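
A standalone sketch, assuming ATen is available (variable names are illustrative), showing which single-matrix layouts the two compatibility checks above accept.

#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebraUtils.h>

void blas_layout_example() {
  auto row_major = at::rand({3, 4});                // strides (4, 1)
  auto col_major = row_major.t().contiguous().t();  // strides (1, 3)
  bool r = at::native::is_blas_compatible_row_major_order(row_major);     // true
  bool c = at::native::is_blas_compatible_column_major_order(col_major);  // true
  bool m = at::native::is_blas_compatible_column_major_order(row_major);  // false
  (void)r; (void)c; (void)m;
}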
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/LossMulti.h b/MLPY/Lib/site-packages/torch/include/ATen/native/LossMulti.h
new file mode 100644
index 0000000000000000000000000000000000000000..d0a338234427cbf278ee70a921b769c42ad590b2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/LossMulti.h
@@ -0,0 +1,72 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+namespace {
+  static C10_UNUSED void multilabel_margin_loss_shape_check(
+    int64_t& nframe,
+    int64_t& dim,
+    const int64_t& ndims,
+    const Tensor& input,
+    const Tensor& target) {
+    TORCH_CHECK(
+        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
+        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
+        input.sizes());
+
+    if (ndims <= 1) {
+      nframe = 1;
+      dim = ndims == 0 ? 1 : input.size(0);
+      TORCH_CHECK(
+          target.dim() <= 1 && target.numel() == dim,
+          "inconsistent target size: ", target.sizes(), " for input of size: ",
+          input.sizes());
+    } else {
+      nframe = input.size(0);
+      dim = input.size(1);
+      TORCH_CHECK(
+          target.dim() == 2 && target.size(0) == nframe &&
+          target.size(1) == dim,
+          "inconsistent target size: ", target.sizes(), " for input of size: ",
+          input.sizes());
+    }
+  }
+
+  static C10_UNUSED void multi_margin_loss_shape_check(
+    int64_t& nframe,
+    int64_t& dim,
+    const int64_t& ndims,
+    const Tensor& input,
+    const Tensor& target,
+    const c10::optional<Tensor>& weight) {
+    TORCH_CHECK(
+        (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
+        "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
+        input.sizes());
+
+    if (ndims <= 1) {
+      nframe = 1;
+      dim = ndims == 0 ? 1 : input.size(0);
+    } else {
+      nframe = input.size(0);
+      dim = input.size(1);
+    }
+
+    TORCH_CHECK(
+        target.dim() <= 1 && target.numel() == nframe,
+        "inconsistent target size, expected ", nframe, " but got ",
+        target.sizes());
+    if (weight && weight->defined()) {
+      TORCH_CHECK(
+          weight->dim() <= 1 && weight->numel() == dim,
+          "inconsistent weight size, expected ", dim, " but got ",
+          weight->sizes());
+    }
+}
+
+
+}  // anonymous namespace
+} // namespace at::native
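
A usage sketch with made-up shapes: for an (N, C) input the check fills in nframe = N and dim = C and validates the target (and optional weight) against them.

#include <ATen/ATen.h>
#include <ATen/native/LossMulti.h>

void loss_shape_check_example() {
  auto input  = at::rand({8, 5});           // N = 8 samples, C = 5 classes
  auto target = at::zeros({8}, at::kLong);  // one class index per sample
  int64_t nframe = 0, dim = 0;
  at::native::multi_margin_loss_shape_check(
      nframe, dim, input.dim(), input, target, /*weight=*/c10::nullopt);
  // nframe == 8 and dim == 5 after the call.
}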
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Math.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Math.h
new file mode 100644
index 0000000000000000000000000000000000000000..3dc1f678c4546bf55cd6a2f48160e800c98e3fdc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Math.h
@@ -0,0 +1,3901 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
+#endif
+
+/* The next function is taken from  https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
+Below is the copyright.
+Output was modified to be inf or -inf when input is 1 or -1. */
+
+
+/*
+    Copyright (c) 2014 Indiana University
+    All rights reserved.
+
+    Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
+            Indiana University, Bloomington, IN
+
+    This software is licensed under the New BSD license:
+
+    Redistribution and use in source and binary forms,
+    with or without modification, are permitted provided
+    that the following conditions are met:
+
+    Redistributions of source code must retain the above
+    copyright notice, this list of conditions and the
+    following disclaimer.
+
+    Redistributions in binary form must reproduce the
+    above copyright notice, this list of conditions and
+    the following disclaimer in the documentation and/or
+    other materials provided with the distribution.
+
+    Neither the name of Indiana University nor
+    the names of its contributors may be used to endorse
+    or promote products derived from this software without
+    specific prior written permission.
+
+    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+    CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
+    WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
+    THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
+    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
+    USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+    HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
+    IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+    USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+    POSSIBILITY OF SUCH DAMAGE.
+*/
+
+namespace {
+/*
+ * This function is derived from the implementation of the i0e function in the
+ * Cephes Math Library. See note [3-Clause BSD License for the Cephes Math
+ * Library].
+ *
+ * Computes an approximation of the exponentially scaled zeroth order modified
+ * Bessel function of the first kind. The approximation is actually two
+ * (sub)approximations, both using a Chebyshev polynomial expansion. One
+ * approximates the function over [0, 8], and the other over (8, infinity). This
+ * function takes the absolute value of all inputs to convert them into the
+ * domain of the approximation.
+ */
+jiterator_also_stringify_as(jiterator_code(
+  template <typename T>
+  JITERATOR_HOST_DEVICE T chbevl(T x, const T array[], const int len) {
+    T b0, b1, b2;
+
+    b0 = array[0];
+    b1 = 0;
+
+    for (int i = 1; i < len; ++i) {
+      b2 = b1;
+      b1 = b0;
+      b0 = x * b1 - b2 + array[i];
+    }
+
+    return T{0.5} * (b0 - b2);
+  }
+
+  template <typename T>
+  JITERATOR_HOST_DEVICE T calc_i0e(T _x) {
+    T x = std::fabs(_x);
+
+    if (x <= T{8.0}) {
+      static const T coefficients[] = {
+          -4.41534164647933937950E-18, 3.33079451882223809783E-17,
+          -2.43127984654795469359E-16, 1.71539128555513303061E-15,
+          -1.16853328779934516808E-14, 7.67618549860493561688E-14,
+          -4.85644678311192946090E-13, 2.95505266312963983461E-12,
+          -1.72682629144155570723E-11, 9.67580903537323691224E-11,
+          -5.18979560163526290666E-10, 2.65982372468238665035E-9,
+          -1.30002500998624804212E-8,  6.04699502254191894932E-8,
+          -2.67079385394061173391E-7,  1.11738753912010371815E-6,
+          -4.41673835845875056359E-6,  1.64484480707288970893E-5,
+          -5.75419501008210370398E-5,  1.88502885095841655729E-4,
+          -5.76375574538582365885E-4,  1.63947561694133579842E-3,
+          -4.32430999505057594430E-3,  1.05464603945949983183E-2,
+          -2.37374148058994688156E-2,  4.93052842396707084878E-2,
+          -9.49010970480476444210E-2,  1.71620901522208775349E-1,
+          -3.04682672343198398683E-1,  6.76795274409476084995E-1};
+
+      T y = (x / T{2.0}) - T{2.0};
+      return chbevl(y, coefficients, int{30});
+    }
+
+    // x > 8
+    static const T coefficients[] = {
+        -7.23318048787475395456E-18, -4.83050448594418207126E-18,
+        4.46562142029675999901E-17,  3.46122286769746109310E-17,
+        -2.82762398051658348494E-16, -3.42548561967721913462E-16,
+        1.77256013305652638360E-15,  3.81168066935262242075E-15,
+        -9.55484669882830764870E-15, -4.15056934728722208663E-14,
+        1.54008621752140982691E-14,  3.85277838274214270114E-13,
+        7.18012445138366623367E-13,  -1.79417853150680611778E-12,
+        -1.32158118404477131188E-11, -3.14991652796324136454E-11,
+        1.18891471078464383424E-11,  4.94060238822496958910E-10,
+        3.39623202570838634515E-9,   2.26666899049817806459E-8,
+        2.04891858946906374183E-7,   2.89137052083475648297E-6,
+        6.88975834691682398426E-5,   3.36911647825569408990E-3,
+        8.04490411014108831608E-1};
+
+    return chbevl(T{32.0} / x - T{2.0}, coefficients, int{25}) / std::sqrt(x);
+  }),
+  i0e_string); // i0e_string
+}
+
+#define CENTRAL_RANGE 0.7
+
+template <typename T>
+static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+calc_erfinv(T y) {
+/* Function to calculate inverse error function.  Rational approximation
+is used to generate an initial approximation, which is then improved to
+full accuracy by two steps of Newton's method.  Code is a direct
+translation of the erfinv m file in matlab version 2.0.
+Author:  Gary L. Pavlis, Indiana University
+Date:  February 1996
+*/
+  T x, z, num, dem; /*working variables */
+  /* coefficients in rational expansion */
+  T a[4] = {  T(0.886226899), T(-1.645349621),  T(0.914624893), T(-0.140543331) };
+  T b[4] = { T(-2.118377725),  T(1.442710462), T(-0.329097515),  T(0.012229801) };
+  T c[4] = { T(-1.970840454), T(-1.624906493),  T(3.429567803),  T(1.641345311) };
+  T d[2] = {  T(3.543889200),  T(1.637067800) };
+  T y_abs = std::abs(y);
+  if(y_abs > 1.0) return std::numeric_limits<T>::quiet_NaN();
+#ifdef _WIN32
+  // error C2039: '_copysign': is not a member of 'std'
+  if(y_abs == 1.0) return copysign(std::numeric_limits<T>::infinity(), y);
+#else
+  if(y_abs == 1.0) return std::copysign(std::numeric_limits<T>::infinity(), y);
+#endif
+  if(y_abs <= static_cast<T>(CENTRAL_RANGE)) {
+    z = y * y;
+    num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
+    dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast<T>(1.0));
+    x = y * num / dem;
+  }
+  else{
+    z = std::sqrt(-std::log((static_cast<T>(1.0)-y_abs)/static_cast<T>(2.0)));
+    num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
+    dem = (d[1]*z + d[0])*z + static_cast<T>(1.0);
+#ifdef _WIN32
+    // error C2039: '_copysign': is not a member of 'std'
+    x = copysign(num, y) / dem;
+#else
+    x = std::copysign(num, y) / dem;
+#endif
+  }
+  /* Two steps of Newton-Raphson correction */
+  x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
+  x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
+
+  return(x);
+}
+
+#undef CENTRAL_RANGE
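
A self-contained sketch of the Newton-Raphson correction used above (plain <cmath>; the initial guess stands in for the rational approximation): one step is x <- x - (erf(x) - y) / erf'(x), with erf'(x) = (2/sqrt(pi)) * exp(-x*x).

#include <cmath>
#include <cstdio>

int main() {
  const double pi = std::acos(-1.0);
  const double y = 0.3;   // target: erfinv(0.3) ~ 0.2724627
  double x = 0.27;        // rough initial guess
  for (int i = 0; i < 2; ++i) {
    x -= (std::erf(x) - y) / ((2.0 / std::sqrt(pi)) * std::exp(-x * x));
  }
  std::printf("x = %.7f, erf(x) = %.7f\n", x, std::erf(x));  // erf(x) ~ 0.3000000
  return 0;
}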
+
+/*
+ * Note [3-Clause BSD License for the Cephes Math Library]
+ * Code derived from implementations in the Cephes Math Library should mention its derivation and reference
+ * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note
+ * [3-Clause BSD License for the Cephes Math Library]. The license is:
+ * Copyright (c) 2018, Steven Moshier
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of the nor the
+ * names of its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * This function is derived from the implementation of the zeta function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ */
+template <typename scalar_t, bool is_cuda=false>
+C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
+  using acc_t = at::acc_type<scalar_t, is_cuda>;
+  const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
+  constexpr acc_t zero = acc_t{0.0};
+  constexpr acc_t half = acc_t{0.5};
+  constexpr acc_t one = acc_t{1.0};
+  static const acc_t A[] = {
+      12.0,
+      -720.0,
+      30240.0,
+      -1209600.0,
+      47900160.0,
+      -1.8924375803183791606e9, /*1.307674368e12/691*/
+      7.47242496e10,
+      -2.950130727918164224e12, /*1.067062284288e16/3617*/
+      1.1646782814350067249e14, /*5.109094217170944e18/43867*/
+      -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
+      1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
+      -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
+  };
+
+  int i = 0;
+  acc_t a, b, k, s, t, w;
+  if (x == one) {
+    return std::numeric_limits<scalar_t>::infinity();
+  }
+
+  if (x < one) {
+    return std::numeric_limits<scalar_t>::quiet_NaN();
+  }
+
+  if (q <= zero) {
+    if (q == std::floor(q)) {
+      return std::numeric_limits<scalar_t>::infinity();
+    }
+    if (x != std::floor(x)) {
+      return std::numeric_limits<scalar_t>::quiet_NaN();
+    }
+  }
+
+  s = std::pow(q, -x);
+  a = q;
+  i = 0;
+  b = zero;
+  while ((i < 9) || (a <= acc_t{9.0})) {
+    i += 1;
+    a += one;
+    b = ::pow(a, -x);
+    s += b;
+    if ((-MACHEP * s < b) && (b < MACHEP * s)) {
+      return static_cast<scalar_t>(s);
+    }
+  };
+
+  w = a;
+  s += b * w / (x - one);
+  s -= half * b;
+  a = one;
+  k = zero;
+  for (int i = 0; i < 12; i++) {
+    a *= x + k;
+    b /= w;
+    t = a * b / A[i];
+    s = s + t;
+    t = ::fabs(t / s);
+    if (t < MACHEP) {
+      return static_cast<scalar_t>(s);
+    }
+    k += one;
+    a *= x + k;
+    b /= w;
+    k += one;
+  }
+  return static_cast<scalar_t>(s);
+}
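
A quick numerical sanity sketch, assuming the ATen include tree is on the include path: the Hurwitz zeta above reduces to the Riemann zeta function for q = 1, so zeta(2, 1) should equal pi^2 / 6.

#include <ATen/native/Math.h>
#include <cstdio>

int main() {
  std::printf("%.9f\n", zeta(2.0, 1.0));  // ~1.644934067 == pi^2 / 6
  return 0;
}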
+
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Evaluates polynomial of degree N:
+ *
+ *                     2          N
+ * y  =  C  + C x + C x  +...+ C x
+ *        0    1     2          N
+ *
+ * Coefficients are stored in reverse order:
+ *
+ * coef[0] = C  , ..., coef[N] = C  .
+ *            N                   0
+ */
+template <typename T>
+C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) {
+  T result = 0;
+  for (size_t i = 0; i <= len; i++) {
+    result = result * x + A[i];
+  }
+  return result;
+}
+
+static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
+  double sign = +1;
+  double result = 0;
+  if (x < 0.5) {
+    sign = -1;
+    const double sin_pi_x = sin(c10::pi<double> * x);
+    result -= (c10::pi<double> * c10::pi<double>) / (sin_pi_x * sin_pi_x);
+    x = 1 - x;
+  }
+  for (int i = 0; i < 6; ++i) {
+    result += 1 / (x * x);
+    x += 1;
+  }
+  const double ixx = 1 / (x*x);
+  result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x;
+  return sign * result;
+}
+
+static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
+  float sign = +1;
+  float result = 0;
+  if (x < 0.5f) {
+    sign = -1;
+    const float sin_pi_x = sinf(c10::pi<float> * x);
+    result -= (c10::pi<float> * c10::pi<float>) / (sin_pi_x * sin_pi_x);
+    x = 1 - x;
+  }
+  for (int i = 0; i < 6; ++i) {
+    result += 1 / (x * x);
+    x += 1;
+  }
+  const float ixx = 1 / (x*x);
+  result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x;
+  return sign * result;
+}
+
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ */
+static inline double calc_digamma(double x) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
+  static double PSI_10 = 2.25175258906672110764;
+  if (x == 0) {
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
+  }
+
+  bool x_is_integer = x == trunc(x);
+  if (x < 0) {
+    if (x_is_integer) {
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return std::numeric_limits<double>::quiet_NaN();
+    }
+    // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
+    // accurate than tan(pi * x). While these operations are mathematically equivalent
+    // since both x and r are in radians and tan() has a periodicity of pi, in practice
+    // the computation of pi * x is a source of error (when |x| > 1).
+    double q, r;
+    r = std::modf(x, &q);
+    return calc_digamma(1 - x) - c10::pi<double> / tan(c10::pi<double> * r);
+  }
+
+  // Push x to be >= 10
+  double result = 0;
+  while (x < 10) {
+    result -= 1 / x;
+    x += 1;
+  }
+  if (x == 10) {
+    return result + PSI_10;
+  }
+
+  // Compute asymptotic digamma
+  static const double A[] = {
+      8.33333333333333333333E-2,
+      -2.10927960927960927961E-2,
+      7.57575757575757575758E-3,
+      -4.16666666666666666667E-3,
+      3.96825396825396825397E-3,
+      -8.33333333333333333333E-3,
+      8.33333333333333333333E-2,
+  };
+
+  double y = 0;
+  if (x < 1.0e17) {
+    double z = 1.0 / (x * x);
+    y = z * polevl(z, A, 6);
+  }
+  return result + log(x) - (0.5 / x) - y;
+}
+
+/*
+ * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ */
+static inline float calc_digamma(float x) {
+  // See [C++ Standard Reference: Gamma Function]
+  static float PSI_10 = 2.25175258906672110764f;
+  if (x == 0) {
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
+  }
+
+  bool x_is_integer = x == truncf(x);
+  if (x < 0) {
+    if (x_is_integer) {
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is a negative integer, NaN is returned
+      return std::numeric_limits<float>::quiet_NaN();
+    }
+    // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
+    // accurate than tan(pi * x). While these operations are mathematically equivalent
+    // since both x and r are in radians and tan() has a periodicity of pi, in practice
+    // the computation of pi * x is a source of error (when |x| > 1).
+    double q, r;
+    r = std::modf(x, &q);
+    float pi_over_tan_pi_x = (float)(c10::pi<double> / tan(c10::pi<double> * r));
+    return calc_digamma(1 - x) - pi_over_tan_pi_x;
+  }
+
+  // Push x to be >= 10
+  float result = 0;
+  while (x < 10) {
+    result -= 1 / x;
+    x += 1;
+  }
+  if (x == 10) {
+    return result + PSI_10;
+  }
+
+  // Compute asymptotic digamma
+  static const float A[] = {
+      8.33333333333333333333E-2f,
+      -2.10927960927960927961E-2f,
+      7.57575757575757575758E-3f,
+      -4.16666666666666666667E-3f,
+      3.96825396825396825397E-3f,
+      -8.33333333333333333333E-3f,
+      8.33333333333333333333E-2f,
+  };
+
+  float y = 0;
+  if (x < 1.0e17f) {
+    float z = 1 / (x * x);
+    y = z * polevl(z, A, 6);
+  }
+  return result + logf(x) - (0.5f / x) - y;
+}
+
+static inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
+  return calc_digamma(static_cast<float>(a));
+}
+
+static inline c10::Half calc_digamma(c10::Half a) {
+  return calc_digamma(static_cast<float>(a));
+}
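
Two reference values for the double overload, assuming this header is included: they follow from digamma(1) = -EulerGamma and digamma(1/2) = -EulerGamma - 2 ln 2.

#include <ATen/native/Math.h>
#include <cstdio>

int main() {
  std::printf("%.7f\n", calc_digamma(1.0));  // ~ -0.5772157 (negative Euler-Mascheroni constant)
  std::printf("%.7f\n", calc_digamma(0.5));  // ~ -1.9635100 (= -EulerGamma - 2*ln 2)
  return 0;
}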
+
+template <typename scalar_t, bool is_cuda=false>
+static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
+  // already blocked if n <= 1
+  const auto one = scalar_t{1};
+  return ((n % 2) ? one : -one) *
+      std::exp(std::lgamma(static_cast<scalar_t>(n) + one)) *
+      zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
+}
+
+// regularized lower incomplete gamma
+// the regularized lower, upper incomplete gamma, as well as their
+// helper functions follow SciPy's implementation
+
+/* References
+ * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov
+ * [igam2] Maddock et. al., "Incomplete Gamma Functions",
+ *     https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html
+ */
+
+/*
+ * This implementation of the regularized incomplete gamma functions and
+ * their helper functions are derived from the implementation of SciPy's
+ * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations.
+ * See NOTICE for the licenses.
+ */
+template <typename scalar_t>
+static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
+    const scalar_t denom[], int64_t N) {
+  // evaluating rational function, i.e., the ratio of two polynomials
+  // the coefficients for numerator are given by `num` while coeffs for
+  // denumerator are given by `denom`
+
+  int64_t i, dir;
+  scalar_t y, num_ans, denom_ans;
+  scalar_t absx = std::fabs(x);
+  const scalar_t *p;
+
+  if (absx > 1) {
+    /* Evaluate as a polynomial in 1/x. */
+    dir = -1;
+    p = num + M;
+    y = 1 / x;
+  }
+  else {
+    dir = 1;
+    p = num;
+    y = x;
+  }
+
+  /* Evaluate the numerator */
+  num_ans = *p;
+  p += dir;
+  for (i = 1; i <= M; i++) {
+    num_ans = num_ans * y + *p;
+    p += dir;
+  }
+  /* Evaluate the denominator */
+  if (absx > 1) {
+    p = denom + N;
+  }
+  else {
+    p = denom;
+  }
+
+  denom_ans = *p;
+  p += dir;
+  for (i = 1; i <= N; i++) {
+    denom_ans = denom_ans * y + *p;
+    p += dir;
+  }
+  if (absx > 1) {
+    i = N - M;
+    return std::pow(x, i) * num_ans / denom_ans;
+  }
+  else {
+    return num_ans / denom_ans;
+  }
+}
+
+// SciPy's lanczos implementation is taken from Boost
+/* (C) Copyright John Maddock 2006.
+ * Use, modification and distribution are subject to the
+ * Boost Software License, Version 1.0. See
+ * https://www.boost.org/LICENSE_1_0.txt or see NOTICE.
+ */
+template <typename scalar_t>
+static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
+  // lanczos approximation
+  static const scalar_t lanczos_sum_expg_scaled_num[13] = {
+    0.006061842346248906525783753964555936883222,
+    0.5098416655656676188125178644804694509993,
+    19.51992788247617482847860966235652136208,
+    449.9445569063168119446858607650988409623,
+    6955.999602515376140356310115515198987526,
+    75999.29304014542649875303443598909137092,
+    601859.6171681098786670226533699352302507,
+    3481712.15498064590882071018964774556468,
+    14605578.08768506808414169982791359218571,
+    43338889.32467613834773723740590533316085,
+    86363131.28813859145546927288977868422342,
+    103794043.1163445451906271053616070238554,
+    56906521.91347156388090791033559122686859
+  };
+  static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
+    1.,
+    66.,
+    1925.,
+    32670.,
+    357423.,
+    2637558.,
+    13339535.,
+    45995730.,
+    105258076.,
+    150917976.,
+    120543840.,
+    39916800.,
+    0.
+  };
+  return ratevl(x, lanczos_sum_expg_scaled_num,
+      sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1,
+      lanczos_sum_expg_scaled_denom,
+      sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1);
+}
+
+template <typename scalar_t>
+static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
+  // compute x^a * exp(-a) / gamma(a)
+  // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with
+  // exp(a - x).
+
+  scalar_t ax, fac, res, num, numfac;
+  static scalar_t MAXLOG = std::is_same<scalar_t, double>::value ?
+    7.09782712893383996843E2 : 88.72283905206835;
+  static scalar_t EXP1 = 2.718281828459045;
+  static scalar_t lanczos_g = 6.024680040776729583740234375;
+
+  if (std::fabs(a - x) > 0.4 * std::fabs(a)) {
+    ax = a * std::log(x) - x - std::lgamma(a);
+    if (ax < -MAXLOG) {
+      return 0.0;
+    }
+    return std::exp(ax);
+  }
+
+  fac = a + lanczos_g - 0.5;
+  res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a);
+
+  if ((a < 200) && (x < 200)) {
+    res *= std::exp(a - x) * std::pow(x / fac, a);
+  }
+  else {
+    num = x - a - lanczos_g + 0.5;
+    numfac = num / fac;
+    res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac);
+  }
+  return res;
+}
+
+template <typename scalar_t>
+static scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
+  // Compute igam using DLMF 8.11.4. [igam1]
+  static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
+    1.11022302462515654042E-16 : 5.9604644775390625E-8;
+  static int MAXITER = 2000;
+
+  int i;
+  scalar_t ans, ax, c, r;
+
+  ax = _igam_helper_fac(a, x);
+  if (ax == 0.0) {
+    return 0.0;
+  }
+
+  /* power series */
+  r = a;
+  c = 1.0;
+  ans = 1.0;
+
+  for (i = 0; i < MAXITER; i++) {
+    r += 1.0;
+    c *= x / r;
+    ans += c;
+    if (c <= MACHEP * ans) {
+      break;
+    }
+  }
+  return (ans * ax / a);
+}
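
A small check of the series helper, assuming this header is included: for a = 1 the regularized lower incomplete gamma is P(1, x) = 1 - exp(-x), which the power series should reproduce.

#include <ATen/native/Math.h>
#include <cmath>
#include <cstdio>

int main() {
  double series = _igam_helper_series(1.0, 0.5);
  double closed_form = 1.0 - std::exp(-0.5);        // P(1, 0.5)
  std::printf("%.7f %.7f\n", series, closed_form);  // both ~0.3934693
  return 0;
}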
+
+template <typename scalar_t>
+static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
+  // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in
+  // _igam_helper_series but extra care is taken to avoid cancellation.
+
+  int n;
+  scalar_t fac = 1;
+  scalar_t sum = 0;
+  scalar_t term, logx;
+  static scalar_t MAXITER = 2000;
+  static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
+    1.11022302462515654042E-16 : 5.9604644775390625E-8;
+
+  for (n = 1; n < MAXITER; n++) {
+    fac *= -x / n;
+    term = fac / (a + n);
+    sum += term;
+    if (std::fabs(term) <= MACHEP * std::fabs(sum)) {
+        break;
+    }
+  }
+
+  logx = std::log(x);
+  term = -std::expm1(a * logx - std::lgamma(1+a));
+  return term - std::exp(a * logx - std::lgamma(a)) * sum;
+}
+
+template <typename scalar_t>
+static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
+  // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
+  static const scalar_t d[25][25] =
+    {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
+      1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
+      3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
+      8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9,
+      1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10,
+      -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11,
+      -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13,
+      -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16,
+      -1.9752288294349443e-15},
+    {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3,
+      -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7,
+      -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6,
+      4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8,
+      1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9,
+      4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14,
+      7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13,
+      -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14,
+      -4.13125571381061e-15},
+    {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4,
+      2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5,
+      -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6,
+      -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10,
+      -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9,
+      9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11,
+      1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12,
+      4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17,
+      8.8592218725911273e-15},
+    {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4,
+      2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7,
+      1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6,
+      -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8,
+      -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9,
+      -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14,
+      -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12,
+      6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14,
+      2.0453671226782849e-14},
+    {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4,
+      -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5,
+      1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6,
+      8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11,
+      2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9,
+      -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10,
+      -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12,
+      -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18,
+      -5.0453320690800944e-14},
+    {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4,
+      -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7,
+      -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6,
+      -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7,
+      4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9,
+      3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15,
+      9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11,
+      -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13,
+      -1.3249659916340829e-13},
+    {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4,
+      7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5,
+      -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6,
+      -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13,
+      -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8,
+      8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10,
+      2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11,
+      1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18,
+      3.6902800842763467e-13},
+    {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4,
+      2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7,
+      2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6,
+      4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7,
+      -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8,
+      -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17,
+      -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11,
+      2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12,
+      1.0865561947091654e-12},
+    {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4,
+      -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4,
+      4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5,
+      6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11,
+      3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8,
+      6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9,
+      -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10,
+      -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18,
+      -3.3721464474854592e-12},
+    {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4,
+      -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7,
+      -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5,
+      -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6,
+      8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7,
+      8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14,
+      3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10,
+      -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11,
+      -1.1002224534207726e-11},
+    {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3,
+      9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4,
+      -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5,
+      -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11,
+      -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7,
+      -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8,
+      1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9,
+      9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18,
+      3.7647749553543836e-11},
+    {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3,
+      2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7,
+      3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4,
+      2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5,
+      -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6,
+      -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14,
+      -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9,
+      -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10,
+      1.3481607129399749e-10},
+    {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3,
+      -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3,
+      8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4,
+      1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10,
+      1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6,
+      7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7,
+      -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8,
+      -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20,
+      -5.0423112718105824e-10},
+    {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3,
+      -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6,
+      -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4,
+      -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4,
+      4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5,
+      6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13,
+      3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8,
+      8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9,
+      -1.9661464453856102e-9},
+    {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2,
+      7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2,
+      -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3,
+      -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10,
+      -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5,
+      -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6,
+      1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7,
+      1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17,
+      7.9795091026746235e-9},
+    {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2,
+      5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6,
+      1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3,
+      3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3,
+      -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4,
+      -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12,
+      -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6,
+      -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7,
+      3.3654425209171788e-8},
+    {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1,
+      -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2,
+      4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2,
+      1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9,
+      1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4,
+      1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5,
+      -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6,
+      -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16,
+      -1.4729737374018841e-7},
+    {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1,
+      -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5,
+      -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2,
+      -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2,
+      5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3,
+      1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12,
+      8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5,
+      3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6,
+      -6.6812849447625594e-7},
+    {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968,
+      1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1,
+      -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1,
+      -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8,
+      -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3,
+      -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3,
+      3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5,
+      5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14,
+      3.1369106244517615e-6},
+    {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906,
+      4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4,
+      1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1,
+      1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1,
+      -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2,
+      -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11,
+      -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4,
+      9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5,
+      1.5227271505597605e-5},
+    {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1,
+      -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1,
+      5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816,
+      2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7,
+      3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1,
+      8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2,
+      -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3,
+      -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11,
+      -7.6340103696869031e-5},
+    {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1,
+      -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3,
+      -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1,
+      -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195,
+      1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1,
+      3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10,
+      3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3,
+      -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3,
+      -3.9479941246822517e-4},
+    {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2,
+      1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2,
+      -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1,
+      -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7,
+      -6.2716159907747034, 5.1168999071852637, -2.0319658112299095,
+      -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1,
+      1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2,
+      2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6,
+      2.1250180774699461e-3},
+    {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2,
+      7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2,
+      3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2,
+      1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1,
+      -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373,
+      -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7,
+      -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1,
+      1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2,
+      1.5109265210467774e-2},
+    {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3,
+      -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3,
+      1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2,
+      7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6,
+      1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1,
+      -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1,
+      -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468,
+      -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1,
+      4.8683443692930507e-1}};
+
+  int k, n, sgn;
+  int maxpow = 0;
+  static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
+    1.11022302462515654042E-16 : 5.9604644775390625E-8;
+  scalar_t lambda = x / a;
+  scalar_t sigma = (x - a) / a;
+  scalar_t eta, res, ck, ckterm, term, absterm;
+  scalar_t absoldterm = INFINITY;
+  scalar_t etapow[25] = {1};
+  scalar_t sum = 0;
+  scalar_t afac = 1;
+
+  if (igam) {
+    sgn = -1;
+  }
+  else {
+    sgn = 1;
+  }
+
+  if (lambda > 1) {
+    eta = std::sqrt(-2 * (std::log1p(sigma) - sigma));
+  }
+  else if (lambda < 1) {
+    eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma));
+  }
+  else {
+    eta = 0;
+  }
+  res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2));
+
+  for (k = 0; k < 25; k++) {
+    ck = d[k][0];
+    for (n = 1; n < 25; n++) {
+      if (n > maxpow) {
+        etapow[n] = eta * etapow[n-1];
+        maxpow += 1;
+      }
+      ckterm = d[k][n]*etapow[n];
+      ck += ckterm;
+      if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) {
+        break;
+      }
+    }
+    term = ck * afac;
+    absterm = std::fabs(term);
+    if (absterm > absoldterm) {
+      break;
+    }
+    sum += term;
+    if (absterm < MACHEP * std::fabs(sum)) {
+      break;
+    }
+    absoldterm = absterm;
+    afac /= a;
+  }
+  res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * c10::pi<double> * a);
+
+  return res;
+}
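+// Editorial note: the `igam` flag selects the tail being approximated.
+// igam == true yields the regularized lower incomplete gamma P(a, x),
+// igam == false the upper tail Q(a, x); the flag only flips the sign `sgn`
+// used in the leading erfc term and in the correction sum above.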
+
+template <typename scalar_t>
+static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
+  // Compute igamc using DLMF 8.9.2. [igam1]
+  int i;
+  scalar_t ans, ax, c, yc, r, t, y, z;
+  scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
+  int MAXITER = 2000;
+  static scalar_t MACHEP = std::is_same<scalar_t, double>::value ?
+    1.11022302462515654042E-16 : 5.9604644775390625E-8;
+  static scalar_t BIG = std::is_same<scalar_t, double>::value ?
+    4.503599627370496e15 : 16777216.;
+  static scalar_t BIGINV = std::is_same<scalar_t, double>::value ?
+    2.22044604925031308085e-16 : 5.9604644775390625E-8;
+
+  ax = _igam_helper_fac(a, x);
+  if (ax == 0.0) {
+    return 0.0;
+  }
+
+  /* continued fraction */
+  y = 1.0 - a;
+  z = x + y + 1.0;
+  c = 0.0;
+  pkm2 = 1.0;
+  qkm2 = x;
+  pkm1 = x + 1.0;
+  qkm1 = z * x;
+  ans = pkm1 / qkm1;
+
+  for (i = 0; i < MAXITER; i++) {
+    c += 1.0;
+    y += 1.0;
+    z += 2.0;
+    yc = y * c;
+    pk = pkm1 * z - pkm2 * yc;
+    qk = qkm1 * z - qkm2 * yc;
+    if (qk != 0) {
+      r = pk / qk;
+      t = std::fabs((ans - r) / r);
+      ans = r;
+    }
+    else {
+      t = 1.0;
+    }
+    pkm2 = pkm1;
+    pkm1 = pk;
+    qkm2 = qkm1;
+    qkm1 = qk;
+    if (std::fabs(pk) > BIG) {
+      pkm2 *= BIGINV;
+      pkm1 *= BIGINV;
+      qkm2 *= BIGINV;
+      qkm1 *= BIGINV;
+    }
+    if (t <= MACHEP) {
+      break;
+    }
+  }
+  return ans * ax;
+}
+
+template <typename scalar_t>
+static inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
+  /* the calculation of the regularized upper incomplete gamma function
+   * is done differently based on the values of a and x:
+   * - if x and/or a is at the boundary of defined region, then assign the
+   *   result at the boundary
+   * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
+   *   Large Parameter (see DLMF 8.12.4 [igam1])
+   * - if x > 1.1 and x < a, using the subtraction from the regularized lower
+   *   incomplete gamma
+   * - otherwise, calculate the series from [igam2] eq (5)
+   */
+  scalar_t absxma_a;
+
+  static scalar_t SMALL = 20.0;
+  static scalar_t LARGE = 200.0;
+  static scalar_t SMALLRATIO = 0.3;
+  static scalar_t LARGERATIO = 4.5;
+
+  // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
+  // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0.
+  if ((x < 0) || (a < 0)) {
+    // out of defined-region of the function
+    return std::numeric_limits<scalar_t>::quiet_NaN();
+  }
+  else if (a == 0) {
+    if (x > 0) {
+      return 0.0;
+    }
+    else {
+      return std::numeric_limits<scalar_t>::quiet_NaN();
+    }
+  }
+  else if (x == 0) {
+    return 1.0;
+  }
+  else if (std::isinf(a)) {
+    if (std::isinf(x)) {
+      return std::numeric_limits<scalar_t>::quiet_NaN();
+    }
+    return 1.0;
+  }
+  else if (std::isinf(x)) {
+    return 0.0;
+  }
+
+  absxma_a = std::fabs(x - a) / a;
+  if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
+     return _igam_helper_asymptotic_series(a, x, 0);
+  }
+  else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
+     return _igam_helper_asymptotic_series(a, x, 0);
+  }
+
+  if (x > 1.1) {
+    if (x < a) {
+      return 1.0 - _igam_helper_series(a, x);
+    }
+    else {
+      return _igamc_helper_continued_fraction(a, x);
+    }
+  }
+  else if (x <= 0.5) {
+    if (-0.4 / std::log(x) < a) {
+      return 1.0 - _igam_helper_series(a, x);
+    }
+    else {
+      return _igamc_helper_series(a, x);
+    }
+  }
+  else {
+    if (x * 1.1 < a) {
+      return 1.0 - _igam_helper_series(a, x);
+    }
+    else {
+      return _igamc_helper_series(a, x);
+    }
+  }
+}
+
+template <typename scalar_t>
+static inline scalar_t calc_igamma(scalar_t a, scalar_t x) {
+  /* the calculation of the regularized lower incomplete gamma function
+   * is done differently based on the values of a and x:
+   * - if x and/or a is at the boundary of defined region, then assign the
+   *   result at the boundary
+   * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
+   *   Large Parameter (see DLMF 8.12.3 [igam1])
+   * - if x > 1 and x > a, using the subtraction from the regularized upper
+   *   incomplete gamma
+   * - otherwise, calculate the series from [igam2] eq (4)
+   */
+  scalar_t absxma_a;
+  static scalar_t SMALL = 20.0;
+  static scalar_t LARGE = 200.0;
+  static scalar_t SMALLRATIO = 0.3;
+  static scalar_t LARGERATIO = 4.5;
+
+  // boundary values following SciPy
+  // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
+  // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0.
+  if ((x < 0) || (a < 0)) {
+    // out of defined-region of the function
+    return std::numeric_limits<scalar_t>::quiet_NaN();
+  }
+  else if (a == 0) {
+    if (x > 0) {
+      return 1.0;
+    }
+    else {
+      return std::numeric_limits<scalar_t>::quiet_NaN();
+    }
+  }
+  else if (x == 0) {
+    return 0.0; // zero integration limit
+  }
+  else if (std::isinf(a)) {
+    if (std::isinf(x)) {
+      return std::numeric_limits<scalar_t>::quiet_NaN();
+    }
+    return 0.0;
+  }
+  else if (std::isinf(x)) {
+    return 1.0;
+  }
+
+  /* Asymptotic regime where a ~ x. See [igam2] */
+  absxma_a = std::fabs(x - a) / a;
+  if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
+    return _igam_helper_asymptotic_series(a, x, 1);
+  }
+  else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
+    return _igam_helper_asymptotic_series(a, x, 1);
+  }
+
+  if ((x > 1.0) && (x > a)) {
+    return 1.0 - calc_igammac(a, x);
+  }
+
+  return _igam_helper_series(a, x);
+}
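+// Illustrative usage (editorial note, not part of the upstream sources): for
+// finite a > 0 and x >= 0 the two regularized functions are complementary,
+// i.e. calc_igamma(a, x) + calc_igammac(a, x) == 1 up to rounding. For
+// example, P(1, x) = 1 - exp(-x), so calc_igamma(1.0, 1.0) is roughly 0.6321.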
+
+template <>
+C10_UNUSED c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
+  return calc_igamma<float>(float(a), float(x));
+}
+
+template <>
+C10_UNUSED c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
+  return calc_igamma<float>(float(a), float(x));
+}
+
+template <>
+C10_UNUSED c10::BFloat16 calc_igammac<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
+  return calc_igammac<float>(float(a), float(x));
+}
+
+template <>
+C10_UNUSED c10::Half calc_igammac<c10::Half>(c10::Half a, c10::Half x) {
+  return calc_igammac<float>(float(a), float(x));
+}
+
+inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); }
+
+template <typename T>
+static T abs_impl(T v) {
+  return std::abs(v);
+}
+
+template <>
+C10_UNUSED uint8_t abs_impl(uint8_t v) {
+  return v;
+}
+
+template <typename T>
+static inline typename std::enable_if<std::is_integral<T>::value, T>::type
+calc_gcd(T a, T b) {
+  a = abs_impl(a);
+  b = abs_impl(b);
+  while (a != 0) {
+    T c = a;
+    a = b % a;
+    b = c;
+  }
+  return b;
+}
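+// calc_gcd is the classic Euclidean algorithm by repeated remainder.
+// Illustrative values (editorial note, not from the upstream header):
+// calc_gcd<int64_t>(12, 18) == 6 and calc_gcd<int64_t>(0, b) == b for
+// non-negative b.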
+
+template <typename T>
+C10_HOST_DEVICE T exp2_impl(T x) {
+  return std::exp2(x);
+}
+
+template <typename T>
+C10_HOST_DEVICE c10::complex<T> exp2_impl(c10::complex<T> x) {
+  // There is no std::exp2 overload for complex, so instead
+  // use the identity 2^x = e^(ln(2) * x)
+  constexpr auto ln2 = c10::ln_2<T>;
+  return std::exp(ln2 * x);
+}
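+// Sanity check of the identity above (illustrative only): for a purely real
+// argument, exp2_impl(c10::complex<double>(3.0, 0.0)) equals exp(3 * ln(2)) = 8
+// up to rounding, matching std::exp2(3.0) from the real overload.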
+
+/*
+ * This function is derived from the implementation of the chbevl function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Evaluates the series
+ *
+ *       len-1
+ *         - '
+ *  y  =   >   array[i] T (x/2)
+ *         -             i
+ *        i=0
+ *
+ * of Chebyshev polynomials Ti at argument x/2.
+ *
+ * Coefficients are stored in reverse order, i.e. the zero order term is last in the array.  Note len is the number of
+ * coefficients, not the order.
+ *
+ * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before
+ * entering the routine.  This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined.
+ *
+ * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation
+ * required is x -> 2(2ab/x - b - a)/(b-a).  If b is infinity, this becomes x -> 4a/x - 1.
+ */
+template <typename T>
+static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+chbevl(const T x, const T array[], size_t len) {
+  T b0, b1, b2;
+
+  b0 = array[0];
+  b1 = static_cast<T>(0.0);
+
+  for (size_t i = 1; i < len; ++i) {
+    b2 = b1;
+    b1 = b0;
+    b0 = x * b1 - b2 + array[i];
+  }
+
+  return (static_cast<T>(0.5) * (b0 - b2));
+}
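+// chbevl is the Clenshaw-style evaluator shared by the Bessel approximations
+// below: e.g. calc_i0 calls chbevl(x/2 - 2, A, 30) on [0, 8] and
+// chbevl(32/x - 2, B, 25) on (8, inf), matching the interval mappings
+// described in the comment above.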
+
+/*
+ * This function is derived from the implementation of the i0 function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Computes an approximation of the zeroth order modified Bessel function of the first kind.
+ * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
+ * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
+ * of all inputs to convert them into the domain of the approximation.
+ */
+template <typename T>
+static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
+  /* Chebyshev coefficients for exp(-x) I0(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I0(x) } = 1.
+   */
+  static const T coeff[] = {
+      -4.41534164647933937950E-18, 3.33079451882223809783E-17,
+      -2.43127984654795469359E-16, 1.71539128555513303061E-15,
+      -1.16853328779934516808E-14, 7.67618549860493561688E-14,
+      -4.85644678311192946090E-13, 2.95505266312963983461E-12,
+      -1.72682629144155570723E-11, 9.67580903537323691224E-11,
+      -5.18979560163526290666E-10, 2.65982372468238665035E-9,
+      -1.30002500998624804212E-8,  6.04699502254191894932E-8,
+      -2.67079385394061173391E-7,  1.11738753912010371815E-6,
+      -4.41673835845875056359E-6,  1.64484480707288970893E-5,
+      -5.75419501008210370398E-5,  1.88502885095841655729E-4,
+      -5.76375574538582365885E-4,  1.63947561694133579842E-3,
+      -4.32430999505057594430E-3,  1.05464603945949983183E-2,
+      -2.37374148058994688156E-2,  4.93052842396707084878E-2,
+      -9.49010970480476444210E-2,  1.71620901522208775349E-1,
+      -3.04682672343198398683E-1,  6.76795274409476084995E-1};
+  return std::make_tuple(coeff, 30);
+};
+
+template <typename T>
+static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
+   */
+  static const T coeff[] = {
+      -7.23318048787475395456E-18, -4.83050448594418207126E-18,
+      4.46562142029675999901E-17,  3.46122286769746109310E-17,
+      -2.82762398051658348494E-16, -3.42548561967721913462E-16,
+      1.77256013305652638360E-15,  3.81168066935262242075E-15,
+      -9.55484669882830764870E-15, -4.15056934728722208663E-14,
+      1.54008621752140982691E-14,  3.85277838274214270114E-13,
+      7.18012445138366623367E-13,  -1.79417853150680611778E-12,
+      -1.32158118404477131188E-11, -3.14991652796324136454E-11,
+      1.18891471078464383424E-11,  4.94060238822496958910E-10,
+      3.39623202570838634515E-9,   2.26666899049817806459E-8,
+      2.04891858946906374183E-7,   2.89137052083475648297E-6,
+      6.88975834691682398426E-5,   3.36911647825569408990E-3,
+      8.04490411014108831608E-1};
+
+  return std::make_tuple(coeff, 25);
+};
+
+template <typename T>
+static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
+chebyshev_coefficients_i1e_A() {
+  /* Chebyshev coefficients for exp(-x) I1(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
+   */
+  static const T coeff[] = {
+      2.77791411276104639959E-18, -2.11142121435816608115E-17,
+      1.55363195773620046921E-16, -1.10559694773538630805E-15,
+      7.60068429473540693410E-15, -5.04218550472791168711E-14,
+      3.22379336594557470981E-13, -1.98397439776494371520E-12,
+      1.17361862988909016308E-11, -6.66348972350202774223E-11,
+      3.62559028155211703701E-10, -1.88724975172282928790E-9,
+      9.38153738649577178388E-9,  -4.44505912879632808065E-8,
+      2.00329475355213526229E-7,  -8.56872026469545474066E-7,
+      3.47025130813767847674E-6,  -1.32731636560394358279E-5,
+      4.78156510755005422638E-5,  -1.61760815825896745588E-4,
+      5.12285956168575772895E-4,  -1.51357245063125314899E-3,
+      4.15642294431288815669E-3,  -1.05640848946261981558E-2,
+      2.47264490306265168283E-2,  -5.29459812080949914269E-2,
+      1.02643658689847095384E-1,  -1.76416518357834055153E-1,
+      2.52587186443633654823E-1};
+  return std::make_tuple(coeff, 29);
+};
+
+template <typename T>
+static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
+chebyshev_coefficients_i1e_A() {
+  /* Chebyshev coefficients for exp(-x) I1(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
+   */
+  static const T coeff[] = {
+      9.38153738649577178388E-9f,
+      -4.44505912879632808065E-8f,
+      2.00329475355213526229E-7f,
+      -8.56872026469545474066E-7f,
+      3.47025130813767847674E-6f,
+      -1.32731636560394358279E-5f,
+      4.78156510755005422638E-5f,
+      -1.61760815825896745588E-4f,
+      5.12285956168575772895E-4f,
+      -1.51357245063125314899E-3f,
+      4.15642294431288815669E-3f,
+      -1.05640848946261981558E-2f,
+      2.47264490306265168283E-2f,
+      -5.29459812080949914269E-2f,
+      1.02643658689847095384E-1f,
+      -1.76416518357834055153E-1f,
+      2.52587186443633654823E-1f};
+  return std::make_tuple(coeff, 17);
+};
+
+template <typename T>
+static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
+chebyshev_coefficients_i1e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
+   */
+  static const T coeff[] = {
+      7.51729631084210481353E-18,  4.41434832307170791151E-18,
+      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
+      2.96262899764595013876E-16,  3.30820231092092828324E-16,
+      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
+      1.04202769841288027642E-14,  4.27244001671195135429E-14,
+      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
+      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
+      1.41258074366137813316E-11,  3.25260358301548823856E-11,
+      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
+      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
+      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
+      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
+      7.78576235018280120474E-1};
+
+  return std::make_tuple(coeff, 25);
+};
+
+template <typename T>
+static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
+chebyshev_coefficients_i1e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
+   */
+  static const T coeff[] = {
+      -3.83538038596423702205E-9f,
+      -2.63146884688951950684E-8f,
+      -2.51223623787020892529E-7f,
+      -3.88256480887769039346E-6f,
+      -1.10588938762623716291E-4f,
+      -9.76109749136146840777E-3f,
+      7.78576235018280120474E-1f};
+
+  return std::make_tuple(coeff, 7);
+};
+
+template <typename T>
+static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+calc_i0(T _x) {
+  T x = std::abs(_x);
+
+  if (x <= T{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i0e_A<T>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    T y = (x / T{2.0}) - T{2.0};
+    return static_cast<T>(std::exp(x) * chbevl(y, A, len));
+  }
+  auto coeff_pair = chebyshev_coefficients_i0e_B<T>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
+}
+
+// Upcast bfloat16 input to float for numerical accuracy purposes
+static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
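+// Illustrative check (editorial note): I0(0) = 1, so calc_i0(0.0) returns
+// approximately 1.0, and calc_i0 is even in its argument because only
+// std::abs(_x) enters the approximation.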
+
+/*
+ * This function is derived from the implementation of the i1 function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Computes an approximation of the first order modified Bessel function of the first kind.
+ * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
+ * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
+ * of all inputs to convert them into the domain of the approximation.
+ */
+template <typename T>
+static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+calc_i1(T _x) {
+  T x = std::abs(_x);
+
+  if (x <= T{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    T y = (x / T{2.0}) - T{2.0};
+    const T out = std::exp(x) * x * chbevl(y, A, len);
+    return (_x < T{0.0}) ? -out : out;
+  }
+  auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x);
+  return (_x < T{0.0}) ? -out : out;
+}
+
+/*
+ * This function is derived from the implementation of the i1e function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Computes an approximation of the exponentially scaled first order modified Bessel function of the first kind.
+ * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
+ * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
+ * of all inputs to convert them into the domain of the approximation.
+ */
+template <typename T>
+static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+calc_i1e(T _x) {
+  T x = std::abs(_x);
+
+  if (x <= T{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    T y = (x / T{2.0}) - T{2.0};
+    const T out = chbevl(y, A, len) * x;
+    return (_x < T{0.0}) ? -out : out;
+  }
+  auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
+  return (_x < T{0.0}) ? -out : out;
+}
+
+/*
+ * This function is derived from the implementation of the ndtri function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ *
+ * Computes the argument, x, for which the area under the Gaussian probability density function
+ * (integrated from minus infinity to x) is equal to y.
+ */
+template <typename T>
+static inline C10_HOST_DEVICE T calc_ndtri(T y0) {
+
+  /* sqrt(2pi) */
+  constexpr T s2pi = 2.50662827463100050242E0;
+  constexpr T one = 1;
+  constexpr T zero = 0;
+
+  /* approximation for 0 <= |y - 0.5| <= 3/8 */
+  static const T P0[5] = {
+      -5.99633501014107895267E1,
+      9.80010754185999661536E1,
+      -5.66762857469070293439E1,
+      1.39312609387279679503E1,
+      -1.23916583867381258016E0,
+  };
+
+  static const T Q0[9] = {
+      1.00000000000000000000E0,
+      1.95448858338141759834E0,
+      4.67627912898881538453E0,
+      8.63602421390890590575E1,
+      -2.25462687854119370527E2,
+      2.00260212380060660359E2,
+      -8.20372256168333339912E1,
+      1.59056225126211695515E1,
+      -1.18331621121330003142E0,
+  };
+
+  /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
+  * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
+  */
+  static const T P1[9] = {
+      4.05544892305962419923E0,
+      3.15251094599893866154E1,
+      5.71628192246421288162E1,
+      4.40805073893200834700E1,
+      1.46849561928858024014E1,
+      2.18663306850790267539E0,
+      -1.40256079171354495875E-1,
+      -3.50424626827848203418E-2,
+      -8.57456785154685413611E-4,
+  };
+
+  static const T Q1[9] = {
+      1.00000000000000000000E0,
+      1.57799883256466749731E1,
+      4.53907635128879210584E1,
+      4.13172038254672030440E1,
+      1.50425385692907503408E1,
+      2.50464946208309415979E0,
+      -1.42182922854787788574E-1,
+      -3.80806407691578277194E-2,
+      -9.33259480895457427372E-4,
+  };
+
+  /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
+  * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
+  */
+
+  static const T P2[9] = {
+      3.23774891776946035970E0,
+      6.91522889068984211695E0,
+      3.93881025292474443415E0,
+      1.33303460815807542389E0,
+      2.01485389549179081538E-1,
+      1.23716634817820021358E-2,
+      3.01581553508235416007E-4,
+      2.65806974686737550832E-6,
+      6.23974539184983293730E-9,
+  };
+
+  static const T Q2[9] = {
+      1.00000000000000000000E0,
+      6.02427039364742014255E0,
+      3.67983563856160859403E0,
+      1.37702099489081330271E0,
+      2.16236993594496635890E-1,
+      1.34204006088543189037E-2,
+      3.28014464682127739104E-4,
+      2.89247864745380683936E-6,
+      6.79019408009981274425E-9,
+  };
+
+  if (y0 == zero) {
+    return -std::numeric_limits<T>::infinity();
+  }
+  if (y0 == one) {
+    return std::numeric_limits<T>::infinity();
+  }
+  if (y0 < zero || y0 > one) {
+    return std::numeric_limits<T>::quiet_NaN();
+  }
+  bool code = true;
+  T y = y0;
+  if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */
+    y = one - y;
+    code = false;
+  }
+
+  if (y > T{0.13533528323661269189}) {
+    y = y - T{0.5};
+    const T y2 = y * y;
+    T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8));
+    return (x * s2pi);
+  }
+
+  T x = ::sqrt(T{-2.0} * ::log(y));
+  const T x0 = x - ::log(x) / x;
+
+  const T z = one / x;
+  T x1;
+  if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */
+  {
+    x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8);
+  } else {
+    x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8);
+  }
+  x = x0 - x1;
+  if (code) {
+    x = -x;
+  }
+  return x;
+}
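+// calc_ndtri is the inverse of the standard normal CDF (the probit function).
+// Illustrative values (editorial note): calc_ndtri(0.5) == 0 and
+// calc_ndtri(0.975) is approximately 1.96.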
+
+/* The next function is taken from http://ab-initio.mit.edu/Faddeev */
+
+/* Copyright (c) 2012 Massachusetts Institute of Technology
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining
+ * a copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to
+ * the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be
+ * included in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+/* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by
+   Steven G. Johnson, October 2012.
+
+   This function combines a few different ideas.
+
+   First, for x > 50, it uses a continued-fraction expansion (same as
+   for the Faddeeva function, but with algebraic simplifications for z=i*x).
+
+   Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations,
+   but with two twists:
+
+      a) It maps x to y = 4 / (4+x) in [0,1].  This simple transformation,
+         inspired by a similar transformation in the octave-forge/specfun
+         erfcx by Soren Hauberg, results in much faster Chebyshev convergence
+         than other simple transformations I have examined.
+
+      b) Instead of using a single Chebyshev polynomial for the entire
+         [0,1] y interval, we break the interval up into 100 equal
+         subintervals, with a switch/lookup table, and use much lower
+         degree Chebyshev polynomials in each subinterval. This greatly
+         improves performance in my tests.
+
+   For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x),
+   with the usual checks for overflow etcetera.
+
+   Performance-wise, it seems to be substantially faster than either
+   the SLATEC DERFC function [or an erfcx function derived therefrom]
+   or Cody's CALERF function (from netlib.org/specfun), while
+   retaining near machine precision in accuracy.  */
+
+/* Given y100=100*y, where y = 4/(4+x) for x >= 0, compute erfcx(x).
+
+   Uses a look-up table of 100 different Chebyshev polynomials
+   for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated
+   with the help of Maple and a little shell script.   This allows
+   the Chebyshev polynomials to be of significantly lower degree (about 1/4)
+   compared to fitting the whole [0,1] interval with a single polynomial. */
+
+
+template <typename T>
+C10_HOST_DEVICE  static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+erfcx_y100(T y100)
+{
+  switch (static_cast<int>(y100)) {
+case 0: {
+T t = 2*y100 - 1;
+return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t;
+}
+case 1: {
+T t = 2*y100 - 3;
+return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t;
+}
+case 2: {
+T t = 2*y100 - 5;
+return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t;
+}
+case 3: {
+T t = 2*y100 - 7;
+return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t;
+}
+case 4: {
+T t = 2*y100 - 9;
+return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t;
+}
+case 5: {
+T t = 2*y100 - 11;
+return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t;
+}
+case 6: {
+T t = 2*y100 - 13;
+return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t;
+}
+case 7: {
+T t = 2*y100 - 15;
+return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t;
+}
+case 8: {
+T t = 2*y100 - 17;
+return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t;
+}
+case 9: {
+T t = 2*y100 - 19;
+return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t;
+}
+case 10: {
+T t = 2*y100 - 21;
+return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t;
+}
+case 11: {
+T t = 2*y100 - 23;
+return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t;
+}
+case 12: {
+T t = 2*y100 - 25;
+return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t;
+}
+case 13: {
+T t = 2*y100 - 27;
+return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t;
+}
+case 14: {
+T t = 2*y100 - 29;
+return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t;
+}
+case 15: {
+T t = 2*y100 - 31;
+return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t;
+}
+case 16: {
+T t = 2*y100 - 33;
+return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t;
+}
+case 17: {
+T t = 2*y100 - 35;
+return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t;
+}
+case 18: {
+T t = 2*y100 - 37;
+return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t;
+}
+case 19: {
+T t = 2*y100 - 39;
+return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t;
+}
+case 20: {
+T t = 2*y100 - 41;
+return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t;
+}
+case 21: {
+T t = 2*y100 - 43;
+return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t;
+}
+case 22: {
+T t = 2*y100 - 45;
+return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t;
+}
+case 23: {
+T t = 2*y100 - 47;
+return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t;
+}
+case 24: {
+T t = 2*y100 - 49;
+return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t;
+}
+case 25: {
+T t = 2*y100 - 51;
+return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t;
+}
+case 26: {
+T t = 2*y100 - 53;
+return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t;
+}
+case 27: {
+T t = 2*y100 - 55;
+return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t;
+}
+case 28: {
+T t = 2*y100 - 57;
+return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t;
+}
+case 29: {
+T t = 2*y100 - 59;
+return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t;
+}
+case 30: {
+T t = 2*y100 - 61;
+return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t;
+}
+case 31: {
+T t = 2*y100 - 63;
+return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 32: {
+T t = 2*y100 - 65;
+return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t;
+}
+case 33: {
+T t = 2*y100 - 67;
+return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t;
+}
+case 34: {
+T t = 2*y100 - 69;
+return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t;
+}
+case 35: {
+T t = 2*y100 - 71;
+return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t;
+}
+case 36: {
+T t = 2*y100 - 73;
+return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t;
+}
+case 37: {
+T t = 2*y100 - 75;
+return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t;
+}
+case 38: {
+T t = 2*y100 - 77;
+return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t;
+}
+case 39: {
+T t = 2*y100 - 79;
+return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t;
+}
+case 40: {
+T t = 2*y100 - 81;
+return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t;
+}
+case 41: {
+T t = 2*y100 - 83;
+return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 42: {
+T t = 2*y100 - 85;
+return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t;
+}
+case 43: {
+T t = 2*y100 - 87;
+return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 44: {
+T t = 2*y100 - 89;
+return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t;
+}
+case 45: {
+T t = 2*y100 - 91;
+return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t;
+}
+case 46: {
+T t = 2*y100 - 93;
+return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t;
+}
+case 47: {
+T t = 2*y100 - 95;
+return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 48: {
+T t = 2*y100 - 97;
+return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t;
+}
+case 49: {
+T t = 2*y100 - 99;
+return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t;
+}
+case 50: {
+T t = 2*y100 - 101;
+return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t;
+}
+case 51: {
+T t = 2*y100 - 103;
+return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t;
+}
+case 52: {
+T t = 2*y100 - 105;
+return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t;
+}
+case 53: {
+T t = 2*y100 - 107;
+return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t;
+}
+case 54: {
+T t = 2*y100 - 109;
+return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t;
+}
+case 55: {
+T t = 2*y100 - 111;
+return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t;
+}
+case 56: {
+T t = 2*y100 - 113;
+return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t;
+}
+case 57: {
+T t = 2*y100 - 115;
+return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t;
+}
+case 58: {
+T t = 2*y100 - 117;
+return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t;
+}
+case 59: {
+T t = 2*y100 - 119;
+return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t;
+}
+case 60: {
+T t = 2*y100 - 121;
+return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t;
+}
+case 61: {
+T t = 2*y100 - 123;
+return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 62: {
+T t = 2*y100 - 125;
+return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t;
+}
+case 63: {
+T t = 2*y100 - 127;
+return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t;
+}
+case 64: {
+T t = 2*y100 - 129;
+return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t;
+}
+case 65: {
+T t = 2*y100 - 131;
+return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t;
+}
+case 66: {
+T t = 2*y100 - 133;
+return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t;
+}
+case 67: {
+T t = 2*y100 - 135;
+return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t;
+}
+case 68: {
+T t = 2*y100 - 137;
+return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t;
+}
+case 69: {
+T t = 2*y100 - 139;
+return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t;
+}
+case 70: {
+T t = 2*y100 - 141;
+return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t;
+}
+case 71: {
+T t = 2*y100 - 143;
+return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t;
+}
+case 72: {
+T t = 2*y100 - 145;
+return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t;
+}
+case 73: {
+T t = 2*y100 - 147;
+return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t;
+}
+case 74: {
+T t = 2*y100 - 149;
+return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t;
+}
+case 75: {
+T t = 2*y100 - 151;
+return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t;
+}
+case 76: {
+T t = 2*y100 - 153;
+return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 77: {
+T t = 2*y100 - 155;
+return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t;
+}
+case 78: {
+T t = 2*y100 - 157;
+return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t;
+}
+case 79: {
+T t = 2*y100 - 159;
+return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t;
+}
+case 80: {
+T t = 2*y100 - 161;
+return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t;
+}
+case 81: {
+T t = 2*y100 - 163;
+return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t;
+}
+case 82: {
+T t = 2*y100 - 165;
+return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t;
+}
+case 83: {
+T t = 2*y100 - 167;
+return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 84: {
+T t = 2*y100 - 169;
+return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t;
+}
+case 85: {
+T t = 2*y100 - 171;
+return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t;
+}
+case 86: {
+T t = 2*y100 - 173;
+return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t;
+}
+case 87: {
+T t = 2*y100 - 175;
+return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t;
+}
+case 88: {
+T t = 2*y100 - 177;
+return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t;
+}
+case 89: {
+T t = 2*y100 - 179;
+return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 90: {
+T t = 2*y100 - 181;
+return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t;
+}
+case 91: {
+T t = 2*y100 - 183;
+return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 92: {
+T t = 2*y100 - 185;
+return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t;
+}
+case 93: {
+T t = 2*y100 - 187;
+return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t;
+}
+case 94: {
+T t = 2*y100 - 189;
+return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 95: {
+T t = 2*y100 - 191;
+return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t;
+}
+case 96: {
+T t = 2*y100 - 193;
+return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t;
+}
+case 97: {
+T t = 2*y100 - 195;
+return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t;
+}
+case 98: {
+T t = 2*y100 - 197;
+return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t;
+}
+case 99: {
+T t = 2*y100 - 199;
+return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t;
+}
+  }
+  // we only get here if y = 1, i.e. |x| < 4*eps, in which case
+  // erfcx is within 1e-15 of 1.
+  return 1.0;
+}
+
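+// calc_erfcx computes the scaled complementary error function
+// erfcx(x) = exp(x^2) * erfc(x): a continued-fraction expansion for large
+// positive x, the erfcx_y100 table above for moderate x, and the identity
+// erfcx(-x) = 2*exp(x^2) - erfcx(x) for negative arguments.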
+template <typename T>
+C10_HOST_DEVICE static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
+calc_erfcx(T x)
+{
+  if (at::_isnan(x)) {
+    return x;
+  }
+
+  if (x >= 0) {
+    if (x > 50) { // continued-fraction expansion is faster
+      const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi)
+      if (x > 5e7) { // 1-term expansion, important to avoid overflow
+        return ispi / x;
+      }
+      /* 5-term expansion (rely on compiler for CSE), simplified from:
+                ispi / (x+0.5/(x+1/(x+1.5/(x+2/x))))  */
+      return ispi*((x*x) * (x*x+4.5) + 2) / (x * ((x*x) * (x*x+5) + 3.75));
+    }
+    return erfcx_y100(400/(4+x));
+  }
+  else {
+    if (x < -26.7) {
+      return std::numeric_limits<T>::infinity();
+    }
+    else if (x < -6.1) {
+      return 2*exp(x*x);
+    }
+    else {
+      return 2*exp(x*x) - erfcx_y100(400/(4-x));
+    }
+  }
+}
+
+/*
+ * Logarithm of Gaussian cumulative distribution function.
+
+ * This implementation of log_ndtr and its helper functions
+ * follow SciPy's implementation
+ * See NOTICE for the licenses.
+ */
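+// For x < -1 the left tail is evaluated as log(erfcx(-x/sqrt(2))/2) - x*x/2,
+// which avoids the catastrophic cancellation log(erfc(.)) would suffer there;
+// closer to zero, log1p(-erfc(x/sqrt(2))/2) is accurate directly.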
+template <typename T>
+static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
+  T t = x * c10::frac_sqrt_2<T>;
+  if (x < T{-1.0}) {
+    return std::log(calc_erfcx(-t) / 2) - t * t;
+  } else {
+    return std::log1p(-std::erfc(t) / 2);
+  }
+}
+
+template<typename T>
+static inline C10_HOST_DEVICE T airy_ai_forward(T x) {
+    static const T AN[] = {
+            +3.46538101525629032477e-01,
+            +1.20075952739645805542e+01,
+            +7.62796053615234516538e+01,
+            +1.68089224934630576269e+02,
+            +1.59756391350164413639e+02,
+            +7.05360906840444183113e+01,
+            +1.40264691163389668864e+01,
+            +9.99999999999999995305e-01,
+    };
+
+    static const T AD[] = {
+            +5.67594532638770212846e-01,
+            +1.47562562584847203173e+01,
+            +8.45138970141474626562e+01,
+            +1.77318088145400459522e+02,
+            +1.64234692871529701831e+02,
+            +7.14778400825575695274e+01,
+            +1.40959135607834029598e+01,
+            +1.00000000000000000470e+00,
+    };
+
+    static const T AFN[] = {
+            -1.31696323418331795333e-01,
+            -6.26456544431912369773e-01,
+            -6.93158036036933542233e-01,
+            -2.79779981545119124951e-01,
+            -4.91900132609500318020e-02,
+            -4.06265923594885404393e-03,
+            -1.59276496239262096340e-04,
+            -2.77649108155232920844e-06,
+            -1.67787698489114633780e-08,
+    };
+
+    static const T AFD[] = {
+            +1.33560420706553243746e+01,
+            +3.26825032795224613948e+01,
+            +2.67367040941499554804e+01,
+            +9.18707402907259625840e+00,
+            +1.47529146771666414581e+00,
+            +1.15687173795188044134e-01,
+            +4.40291641615211203805e-03,
+            +7.54720348287414296618e-05,
+            +4.51850092970580378464e-07,
+    };
+
+    static const T AGN[] = {
+            +1.97339932091685679179e-02,
+            +3.91103029615688277255e-01,
+            +1.06579897599595591108e+00,
+            +9.39169229816650230044e-01,
+            +3.51465656105547619242e-01,
+            +6.33888919628925490927e-02,
+            +5.85804113048388458567e-03,
+            +2.82851600836737019778e-04,
+            +6.98793669997260967291e-06,
+            +8.11789239554389293311e-08,
+            +3.41551784765923618484e-10,
+    };
+
+    static const T AGD[] = {
+            +9.30892908077441974853e+00,
+            +1.98352928718312140417e+01,
+            +1.55646628932864612953e+01,
+            +5.47686069422975497931e+00,
+            +9.54293611618961883998e-01,
+            +8.64580826352392193095e-02,
+            +4.12656523824222607191e-03,
+            +1.01259085116509135510e-04,
+            +1.17166733214413521882e-06,
+            +4.91834570062930015649e-09,
+    };
+
+    int domain_flag = 0;
+
+    T ai;
+
+    if (std::isinf(x)) {
+        return std::numeric_limits<T>::quiet_NaN();
+    }
+
+    if (x > T(103.892)) {
+        return T(0.0);
+    }
+
+    T f;
+    T g;
+    T k;
+
+    if (x < T(-2.09)) {
+        T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0));
+
+        T afn = 0.0;
+
+        for (uint8_t index = 0; index <= 8; index++) {
+            afn = afn * (z * z) + AFN[index];
+        }
+
+        T afd = 0.0;
+
+        for (uint8_t index = 0; index <= 8; index++) {
+            afd = afd * (z * z) + AFD[index];
+        }
+
+        T agn = 0.0;
+
+        for (uint8_t index = 0; index <= 10 + 0; index++) {
+            agn = agn * (z * z) + AGN[index];
+        }
+
+        T agd = 0.0;
+
+        for (uint8_t index = 0; index <= 10 - 1; index++) {
+            agd = agd * (z * z) + AGD[index];
+        }
+
+        T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * c10::pi<T>;
+
+        return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd));
+    }
+
+    if (x >= T(2.09)) {
+        domain_flag = 5;
+
+        T zeta = T(2.0) * x * std::sqrt(x) / T(3.0);
+
+        T an = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            an = an * (T(1.0) / zeta) + AN[index];
+        }
+
+        T ad = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            ad = ad * (T(1.0) / zeta) + AD[index];
+        }
+
+        ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta));
+
+        if (x > T(8.3203353)) {
+            return ai;
+        }
+    }
+
+    f = 1.0;
+    g = x;
+    k = 1.0;
+
+    T m = 1.0;
+    T n = x;
+    T t = 1.0;
+    T z = x * x * x;
+
+    while (t > std::numeric_limits<T>::epsilon()) {
+        m *= z;
+        k += T(1.0);
+        m /= k;
+        n *= z;
+        k += T(1.0);
+        n /= k;
+        m /= k;
+        f += m;
+        k += T(1.0);
+        n /= k;
+        g += n;
+
+        t = std::abs(m / f);
+    }
+
+    if ((domain_flag & 1) == 0) {
+        return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g;
+    }
+
+    return ai;
+} // T airy_ai(T x)
+
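+// The Bessel routines below use the classic two-regime scheme: a rational
+// approximation in x*x for |x| <= 5, and for larger |x| an asymptotic form
+// P*cos(x - phase) - (5/x)*Q*sin(x - phase) scaled by sqrt(2/(pi*x));
+// 0.7978845608... is sqrt(2/pi) and 0.7853981633... is pi/4.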
+template<typename T>
+static inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
+    static const T PP[] = {
+            +7.96936729297347051624e-04,
+            +8.28352392107440799803e-02,
+            +1.23953371646414299388e+00,
+            +5.44725003058768775090e+00,
+            +8.74716500199817011941e+00,
+            +5.30324038235394892183e+00,
+            +9.99999999999999997821e-01,
+    };
+
+    static const T PQ[] = {
+            +9.24408810558863637013e-04,
+            +8.56288474354474431428e-02,
+            +1.25352743901058953537e+00,
+            +5.47097740330417105182e+00,
+            +8.76190883237069594232e+00,
+            +5.30605288235394617618e+00,
+            +1.00000000000000000218e+00,
+    };
+
+    static const T QP[] = {
+            -1.13663838898469149931e-02,
+            -1.28252718670509318512e+00,
+            -1.95539544257735972385e+01,
+            -9.32060152123768231369e+01,
+            -1.77681167980488050595e+02,
+            -1.47077505154951170175e+02,
+            -5.14105326766599330220e+01,
+            -6.05014350600728481186e+00,
+    };
+
+    static const T QQ[] = {
+            +6.43178256118178023184e+01,
+            +8.56430025976980587198e+02,
+            +3.88240183605401609683e+03,
+            +7.24046774195652478189e+03,
+            +5.93072701187316984827e+03,
+            +2.06209331660327847417e+03,
+            +2.42005740240291393179e+02,
+    };
+
+    static const T RP[] = {
+            -4.79443220978201773821e+09,
+            +1.95617491946556577543e+12,
+            -2.49248344360967716204e+14,
+            +9.70862251047306323952e+15,
+    };
+
+    static const T RQ[] = {
+            +4.99563147152651017219e+02,
+            +1.73785401676374683123e+05,
+            +4.84409658339962045305e+07,
+            +1.11855537045356834862e+10,
+            +2.11277520115489217587e+12,
+            +3.10518229857422583814e+14,
+            +3.18121955943204943306e+16,
+            +1.71086294081043136091e+18,
+    };
+
+    if (x < T(0)) {
+        x = -x;
+    }
+
+    if (x <= T(5.0)) {
+        if (x < T(0.00001)) {
+            return T(1.0) - x * x / T(4.0);
+        }
+
+        T rp = 0.0;
+
+        for (uint8_t index = 0; index <= 3; index++) {
+            rp = rp * (x * x) + RP[index];
+        }
+
+        T rq = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            rq = rq * (x * x) + RQ[index];
+        }
+
+        return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
+    }
+
+    T pp = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pp = pp * (T(25.0) / (x * x)) + PP[index];
+    }
+
+    T pq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pq = pq * (T(25.0) / (x * x)) + PQ[index];
+    }
+
+    T qp = 0.0;
+
+    for (uint8_t index = 0; index <= 7; index++) {
+        qp = qp * (T(25.0) / (x * x)) + QP[index];
+    }
+
+    T qq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        qq = qq * (T(25.0) / (x * x)) + QQ[index];
+    }
+
+    return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
+} // bessel_j0_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
+    static const T PP[] = {
+            +7.62125616208173112003e-04,
+            +7.31397056940917570436e-02,
+            +1.12719608129684925192e+00,
+            +5.11207951146807644818e+00,
+            +8.42404590141772420927e+00,
+            +5.21451598682361504063e+00,
+            +1.00000000000000000254e+00,
+    };
+
+    static const T PQ[] = {
+            +5.71323128072548699714e-04,
+            +6.88455908754495404082e-02,
+            +1.10514232634061696926e+00,
+            +5.07386386128601488557e+00,
+            +8.39985554327604159757e+00,
+            +5.20982848682361821619e+00,
+            +9.99999999999999997461e-01,
+    };
+
+    static const T QP[] = {
+            +5.10862594750176621635e-02,
+            +4.98213872951233449420e+00,
+            +7.58238284132545283818e+01,
+            +3.66779609360150777800e+02,
+            +7.10856304998926107277e+02,
+            +5.97489612400613639965e+02,
+            +2.11688757100572135698e+02,
+            +2.52070205858023719784e+01,
+    };
+
+    static const T QQ[] = {
+            +7.42373277035675149943e+01,
+            +1.05644886038262816351e+03,
+            +4.98641058337653607651e+03,
+            +9.56231892404756170795e+03,
+            +7.99704160447350683650e+03,
+            +2.82619278517639096600e+03,
+            +3.36093607810698293419e+02,
+    };
+
+    static const T RP[] = {
+            -8.99971225705559398224e+08,
+            +4.52228297998194034323e+11,
+            -7.27494245221818276015e+13,
+            +3.68295732863852883286e+15,
+    };
+
+    static const T RQ[] = {
+            +6.20836478118054335476e+02,
+            +2.56987256757748830383e+05,
+            +8.35146791431949253037e+07,
+            +2.21511595479792499675e+10,
+            +4.74914122079991414898e+12,
+            +7.84369607876235854894e+14,
+            +8.95222336184627338078e+16,
+            +5.32278620332680085395e+18,
+    };
+
+    if (x < T(0.0)) {
+        return -bessel_j1_forward(-x);
+    }
+
+    if (x <= T(5.0)) {
+        T rp = 0.0;
+
+        for (uint8_t index = 0; index <= 3; index++) {
+            rp = rp * (x * x) + RP[index];
+        }
+
+        T rq = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            rq = rq * (x * x) + RQ[index];
+        }
+
+        return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
+    }
+
+    T pp = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
+    }
+
+    T pq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
+    }
+
+    T qp = 0.0;
+
+    for (uint8_t index = 0; index <= 7; index++) {
+        qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
+    }
+
+    T qq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
+    }
+
+    return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
+} // bessel_j1_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
+    static const T PP[] = {
+            +7.96936729297347051624e-04,
+            +8.28352392107440799803e-02,
+            +1.23953371646414299388e+00,
+            +5.44725003058768775090e+00,
+            +8.74716500199817011941e+00,
+            +5.30324038235394892183e+00,
+            +9.99999999999999997821e-01,
+    };
+
+    static const T PQ[] = {
+            +9.24408810558863637013e-04,
+            +8.56288474354474431428e-02,
+            +1.25352743901058953537e+00,
+            +5.47097740330417105182e+00,
+            +8.76190883237069594232e+00,
+            +5.30605288235394617618e+00,
+            +1.00000000000000000218e+00,
+    };
+
+    static const T QP[] = {
+            -1.13663838898469149931e-02,
+            -1.28252718670509318512e+00,
+            -1.95539544257735972385e+01,
+            -9.32060152123768231369e+01,
+            -1.77681167980488050595e+02,
+            -1.47077505154951170175e+02,
+            -5.14105326766599330220e+01,
+            -6.05014350600728481186e+00,
+    };
+
+    static const T QQ[] = {
+            +6.43178256118178023184e+01,
+            +8.56430025976980587198e+02,
+            +3.88240183605401609683e+03,
+            +7.24046774195652478189e+03,
+            +5.93072701187316984827e+03,
+            +2.06209331660327847417e+03,
+            +2.42005740240291393179e+02,
+    };
+
+    static const T YP[] = {
+            +1.55924367855235737965e+04,
+            -1.46639295903971606143e+07,
+            +5.43526477051876500413e+09,
+            -9.82136065717911466409e+11,
+            +8.75906394395366999549e+13,
+            -3.46628303384729719441e+15,
+            +4.42733268572569800351e+16,
+            -1.84950800436986690637e+16,
+    };
+
+    static const T YQ[] = {
+            +1.04128353664259848412e+03,
+            +6.26107330137134956842e+05,
+            +2.68919633393814121987e+08,
+            +8.64002487103935000337e+10,
+            +2.02979612750105546709e+13,
+            +3.17157752842975028269e+15,
+            +2.50596256172653059228e+17,
+    };
+
+    if (x <= T(5.0)) {
+        if (x == T(0.0)) {
+            return -std::numeric_limits<T>::infinity();
+        }
+
+        if (x < T(0.0)) {
+            return std::numeric_limits<T>::quiet_NaN();
+        }
+
+        T yp = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            yp = yp * (x * x) + YP[index];
+        }
+
+        T yq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            yq = yq * (x * x) + YQ[index];
+        }
+
+        return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x));
+    }
+
+    T pp = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pp = pp * (T(25.0) / (x * x)) + PP[index];
+    }
+
+    T pq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pq = pq * (T(25.0) / (x * x)) + PQ[index];
+    }
+
+    T qp = 0.0;
+
+    for (uint8_t index = 0; index <= 7; index++) {
+        qp = qp * (T(25.0) / (x * x)) + QP[index];
+    }
+
+    T qq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        qq = qq * (T(25.0) / (x * x)) + QQ[index];
+    }
+
+    return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
+} // bessel_y0_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
+    static const T PP[] = {
+            +7.62125616208173112003e-04,
+            +7.31397056940917570436e-02,
+            +1.12719608129684925192e+00,
+            +5.11207951146807644818e+00,
+            +8.42404590141772420927e+00,
+            +5.21451598682361504063e+00,
+            +1.00000000000000000254e+00,
+    };
+
+    static const T PQ[] = {
+            +5.71323128072548699714e-04,
+            +6.88455908754495404082e-02,
+            +1.10514232634061696926e+00,
+            +5.07386386128601488557e+00,
+            +8.39985554327604159757e+00,
+            +5.20982848682361821619e+00,
+            +9.99999999999999997461e-01,
+    };
+
+    static const T QP[] = {
+            +5.10862594750176621635e-02,
+            +4.98213872951233449420e+00,
+            +7.58238284132545283818e+01,
+            +3.66779609360150777800e+02,
+            +7.10856304998926107277e+02,
+            +5.97489612400613639965e+02,
+            +2.11688757100572135698e+02,
+            +2.52070205858023719784e+01,
+    };
+
+    static const T QQ[] = {
+            +7.42373277035675149943e+01,
+            +1.05644886038262816351e+03,
+            +4.98641058337653607651e+03,
+            +9.56231892404756170795e+03,
+            +7.99704160447350683650e+03,
+            +2.82619278517639096600e+03,
+            +3.36093607810698293419e+02,
+    };
+
+    static const T YP[] = {
+            +1.26320474790178026440e+09,
+            -6.47355876379160291031e+11,
+            +1.14509511541823727583e+14,
+            -8.12770255501325109621e+15,
+            +2.02439475713594898196e+17,
+            -7.78877196265950026825e+17,
+    };
+
+    static const T YQ[] = {
+            +5.94301592346128195359e+02,
+            +2.35564092943068577943e+05,
+            +7.34811944459721705660e+07,
+            +1.87601316108706159478e+10,
+            +3.88231277496238566008e+12,
+            +6.20557727146953693363e+14,
+            +6.87141087355300489866e+16,
+            +3.97270608116560655612e+18,
+    };
+
+    if (x <= T(5.0)) {
+        if (x == T(0.0)) {
+            return -std::numeric_limits<T>::infinity();
+        }
+
+        if (x <= T(0.0)) {
+            return std::numeric_limits<T>::quiet_NaN();
+        }
+
+        T yp = 0.0;
+
+        for (uint8_t index = 0; index <= 5; index++) {
+            yp = yp * (x * x) + YP[index];
+        }
+
+        T yq = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            yq = yq * (x * x) + YQ[index];
+        }
+
+        return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x));
+    }
+
+    T pp = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
+    }
+
+    T pq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
+    }
+
+    T qp = 0.0;
+
+    for (uint8_t index = 0; index <= 7; index++) {
+        qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
+    }
+
+    T qq = 0.0;
+
+    for (uint8_t index = 0; index <= 6; index++) {
+        qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
+    }
+
+    return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
+} // bessel_y1_forward(T x)
+
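+// The Chebyshev family below (T, U, V, W) shares one structure: handle the
+// endpoints |x| = 1 exactly, fall back to the trigonometric definition
+// (e.g. T_n(x) = cos(n*acos(x))) for larger n inside (-1, 1), and otherwise
+// run the three-term recurrence r = 2*x*q - p from the first two members
+// of the family.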
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(1.0)) {
+        if (x > T(0.0) || n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if ((n > 6) && (std::abs(x) < T(1.0))) {
+        return std::cos(n * std::acos(x));
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x;
+    }
+
+    T p = T(1.0);
+    T q = x;
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // chebyshev_polynomial_t_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
+    return chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
+} // chebyshev_polynomial_t_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(1.0)) {
+        if (x > T(0.0) || n % 2 == 0) {
+            return n + 1;
+        }
+
+        return -(n + 1);
+    }
+
+    if ((n > 8) && (std::abs(x) < T(1.0))) {
+        if (std::sin(std::acos(x)) != T(0.0)) {
+            return std::sin((n + 1) * std::acos(x)) / std::sin(std::acos(x));
+        }
+
+        return (n + 1) * std::cos((n + 1) * std::acos(x)) / x;
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x;
+    }
+
+    T p = T(1.0);
+    T q = x + x;
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // chebyshev_polynomial_u_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
+    return chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
+} // chebyshev_polynomial_u_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(1.0)) {
+        if (x > T(0.0)) {
+            return T(1.0);
+        }
+
+        if (n % 2 == 0) {
+            return n + n + 1;
+        }
+
+        return -(n + n + 1);
+    }
+
+    if ((n > 8) && (std::abs(x) < T(1.0))) {
+        if (std::sin(std::acos(x) / T(2.0)) != T(1.0)) {
+            return std::cos((n + T(0.5)) * std::acos(x)) / std::cos(std::acos(x) / T(2.0));
+        }
+
+        if (n % 2 == 0) {
+            return n + n + 1;
+        }
+
+        return -(n + n + 1);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x - T(1.0);
+    }
+
+    T p = T(1.0);
+    T q = x + x - T(1.0);
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // chebyshev_polynomial_v_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
+    return chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
+} // chebyshev_polynomial_v_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(1.0)) {
+        if (x > T(0.0)) {
+            return n + n + 1;
+        }
+
+        if (n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if ((n > 8) && (std::abs(x) < T(1.0))) {
+        if (std::cos(std::acos(x) / T(2.0)) != T(1.0)) {
+            return std::sin((n + T(0.5)) * std::acos(x)) / std::sin(std::acos(x) / T(2.0));
+        }
+
+        if (x > T(0.0)) {
+            return n + n + 1;
+        }
+
+        if (n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x + T(1.0);
+    }
+
+    T p = T(1.0);
+    T q = x + x + T(1.0);
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // chebyshev_polynomial_w_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
+    return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
+} // chebyshev_polynomial_w_forward(T x, T n)
+
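+// Physicists' Hermite polynomials via the recurrence
+// H_{k+1}(x) = 2*x*H_k(x) - 2*k*H_{k-1}(x); the loop below steps its counter
+// by 2 so that the 2*k factor is already folded into the loop variable.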
+template<typename T>
+static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x;
+    }
+
+    T p = T(1.0);
+    T q = x + x;
+    T r = T(0.0);
+
+    for (int64_t k = 2; k < n + n; k += 2) {
+        r = (x + x) * q - k * p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // hermite_polynomial_h_forward(T x, int64_t n)
+
+template<typename T, std::enable_if_t<!std::is_floating_point<T>::value, int> = 0>
+static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
+    return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
+} // hermite_polynomial_h_forward(T x, T n)
+
+template<typename T, std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
+static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
+    return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast<int64_t>(n) : static_cast<int64_t>(-1));
+} // hermite_polynomial_h_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x;
+    }
+
+    T p = T(1.0);
+    T q = x;
+    T r;
+
+    for (int64_t k = 1; k < n; k++) {
+        r = x * q - k * p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // hermite_polynomial_he_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
+    return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
+} // hermite_polynomial_he_forward(T x, T n)
+
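+// Laguerre polynomials via the recurrence
+// (k + 1) * L_{k+1}(x) = (2*k + 1 - x) * L_k(x) - k * L_{k-1}(x).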
+template<typename T>
+static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(0.0)) {
+        return T(1.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return T(1.0) - x;
+    }
+
+    T p = T(1.0);
+    T q = T(1.0) - x;
+    T r;
+
+    for (int64_t k = 1; k < n; k++) {
+        r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1);
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // laguerre_polynomial_l_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
+    return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
+} // laguerre_polynomial_l_forward(T x, T n)
+
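+// Legendre polynomials via the Bonnet recurrence
+// (k + 1) * P_{k+1}(x) = (2*k + 1) * x * P_k(x) - k * P_{k-1}(x).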
+template<typename T>
+static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (std::abs(x) == T(1.0)) {
+        if (x > T(0.0) || n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x;
+    }
+
+    T p = T(1.0);
+    T q = x;
+    T r;
+
+    for (int64_t k = 1; k < n; k++) {
+        r = ((k + k + 1) * x * q - k * p) / (k + 1);
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // legendre_polynomial_p_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
+    return legendre_polynomial_p_forward(x, static_cast<int64_t>(n));
+} // legendre_polynomial_p_forward(T x, T n)
+
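+// The modified Bessel routines below sum Chebyshev series (coefficient tables
+// A and B) with a Clenshaw-style recurrence a = z*q - p + c[i]; the small- and
+// large-argument series are stitched together at |x| = 8 for the I functions
+// and at x = 2 for the K functions.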
+template<typename T>
+static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
+    static const T A[] = {
+            -4.41534164647933937950e-18,
+            +3.33079451882223809783e-17,
+            -2.43127984654795469359e-16,
+            +1.71539128555513303061e-15,
+            -1.16853328779934516808e-14,
+            +7.67618549860493561688e-14,
+            -4.85644678311192946090e-13,
+            +2.95505266312963983461e-12,
+            -1.72682629144155570723e-11,
+            +9.67580903537323691224e-11,
+            -5.18979560163526290666e-10,
+            +2.65982372468238665035e-09,
+            -1.30002500998624804212e-08,
+            +6.04699502254191894932e-08,
+            -2.67079385394061173391e-07,
+            +1.11738753912010371815e-06,
+            -4.41673835845875056359e-06,
+            +1.64484480707288970893e-05,
+            -5.75419501008210370398e-05,
+            +1.88502885095841655729e-04,
+            -5.76375574538582365885e-04,
+            +1.63947561694133579842e-03,
+            -4.32430999505057594430e-03,
+            +1.05464603945949983183e-02,
+            -2.37374148058994688156e-02,
+            +4.93052842396707084878e-02,
+            -9.49010970480476444210e-02,
+            +1.71620901522208775349e-01,
+            -3.04682672343198398683e-01,
+            +6.76795274409476084995e-01,
+    };
+
+    static const T B[] = {
+            -7.23318048787475395456e-18,
+            -4.83050448594418207126e-18,
+            +4.46562142029675999901e-17,
+            +3.46122286769746109310e-17,
+            -2.82762398051658348494e-16,
+            -3.42548561967721913462e-16,
+            +1.77256013305652638360e-15,
+            +3.81168066935262242075e-15,
+            -9.55484669882830764870e-15,
+            -4.15056934728722208663e-14,
+            +1.54008621752140982691e-14,
+            +3.85277838274214270114e-13,
+            +7.18012445138366623367e-13,
+            -1.79417853150680611778e-12,
+            -1.32158118404477131188e-11,
+            -3.14991652796324136454e-11,
+            +1.18891471078464383424e-11,
+            +4.94060238822496958910e-10,
+            +3.39623202570838634515e-09,
+            +2.26666899049817806459e-08,
+            +2.04891858946906374183e-07,
+            +2.89137052083475648297e-06,
+            +6.88975834691682398426e-05,
+            +3.36911647825569408990e-03,
+            +8.04490411014108831608e-01,
+    };
+
+    T p;
+    T q = 0.0;
+
+    if (std::abs(x) <= T(8.0)) {
+        T a = A[0];
+
+        for (uint8_t index = 1; index < 30; index++) {
+            p = q;
+            q = a;
+            a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
+        }
+
+        return std::exp(std::abs(x)) * (T(0.5) * (a - p));
+    }
+
+    T b = B[0];
+
+    for (uint8_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
+    }
+
+    return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
+} // modified_bessel_i0_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
+    static const T A[] = {
+            +2.77791411276104639959e-18,
+            -2.11142121435816608115e-17,
+            +1.55363195773620046921e-16,
+            -1.10559694773538630805e-15,
+            +7.60068429473540693410e-15,
+            -5.04218550472791168711e-14,
+            +3.22379336594557470981e-13,
+            -1.98397439776494371520e-12,
+            +1.17361862988909016308e-11,
+            -6.66348972350202774223e-11,
+            +3.62559028155211703701e-10,
+            -1.88724975172282928790e-09,
+            +9.38153738649577178388e-09,
+            -4.44505912879632808065e-08,
+            +2.00329475355213526229e-07,
+            -8.56872026469545474066e-07,
+            +3.47025130813767847674e-06,
+            -1.32731636560394358279e-05,
+            +4.78156510755005422638e-05,
+            -1.61760815825896745588e-04,
+            +5.12285956168575772895e-04,
+            -1.51357245063125314899e-03,
+            +4.15642294431288815669e-03,
+            -1.05640848946261981558e-02,
+            +2.47264490306265168283e-02,
+            -5.29459812080949914269e-02,
+            +1.02643658689847095384e-01,
+            -1.76416518357834055153e-01,
+            +2.52587186443633654823e-01,
+    };
+
+    static const T B[] = {
+            +7.51729631084210481353e-18,
+            +4.41434832307170791151e-18,
+            -4.65030536848935832153e-17,
+            -3.20952592199342395980e-17,
+            +2.96262899764595013876e-16,
+            +3.30820231092092828324e-16,
+            -1.88035477551078244854e-15,
+            -3.81440307243700780478e-15,
+            +1.04202769841288027642e-14,
+            +4.27244001671195135429e-14,
+            -2.10154184277266431302e-14,
+            -4.08355111109219731823e-13,
+            -7.19855177624590851209e-13,
+            +2.03562854414708950722e-12,
+            +1.41258074366137813316e-11,
+            +3.25260358301548823856e-11,
+            -1.89749581235054123450e-11,
+            -5.58974346219658380687e-10,
+            -3.83538038596423702205e-09,
+            -2.63146884688951950684e-08,
+            -2.51223623787020892529e-07,
+            -3.88256480887769039346e-06,
+            -1.10588938762623716291e-04,
+            -9.76109749136146840777e-03,
+            +7.78576235018280120474e-01,
+    };
+
+    T p;
+    T q = 0.0;
+
+    if (std::abs(x) <= T(8.0)) {
+        T a = A[0];
+
+        for (uint8_t index = 1; index < 29; index++) {
+            p = q;
+            q = a;
+            a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
+        }
+
+        if (x < T(0.0)) {
+            return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)));
+        }
+
+        return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x));
+    }
+
+    T b = B[0];
+
+    for (uint8_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index];
+    }
+
+    if (x < T(0.0)) {
+        return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)));
+    }
+
+    return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x));
+} // modified_bessel_i1_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
+    static const T A[] = {
+            +1.37446543561352307156e-16,
+            +4.25981614279661018399e-14,
+            +1.03496952576338420167e-11,
+            +1.90451637722020886025e-09,
+            +2.53479107902614945675e-07,
+            +2.28621210311945178607e-05,
+            +1.26461541144692592338e-03,
+            +3.59799365153615016266e-02,
+            +3.44289899924628486886e-01,
+            -5.35327393233902768720e-01,
+    };
+
+    static const T B[] = {
+            +5.30043377268626276149e-18,
+            -1.64758043015242134646e-17,
+            +5.21039150503902756861e-17,
+            -1.67823109680541210385e-16,
+            +5.51205597852431940784e-16,
+            -1.84859337734377901440e-15,
+            +6.34007647740507060557e-15,
+            -2.22751332699166985548e-14,
+            +8.03289077536357521100e-14,
+            -2.98009692317273043925e-13,
+            +1.14034058820847496303e-12,
+            -4.51459788337394416547e-12,
+            +1.85594911495471785253e-11,
+            -7.95748924447710747776e-11,
+            +3.57739728140030116597e-10,
+            -1.69753450938905987466e-09,
+            +8.57403401741422608519e-09,
+            -4.66048989768794782956e-08,
+            +2.76681363944501510342e-07,
+            -1.83175552271911948767e-06,
+            +1.39498137188764993662e-05,
+            -1.28495495816278026384e-04,
+            +1.56988388573005337491e-03,
+            -3.14481013119645005427e-02,
+            +2.44030308206595545468e+00,
+    };
+
+    if (x == T(0.0)) {
+        return std::numeric_limits<T>::infinity();
+    }
+
+    if (x < T(0.0)) {
+        return std::numeric_limits<T>::quiet_NaN();
+    }
+
+    T p;
+    T q = 0.0;
+
+    if (x <= T(2.0)) {
+        T a = A[0];
+
+        for (uint8_t index = 1; index < 10; index++) {
+            p = q;
+            q = a;
+            a = (x * x - T(2.0)) * q - p + A[index];
+        }
+
+        return T(0.5) * (a - p) - std::log(0.5 * x) * modified_bessel_i0_forward(x);
+    }
+
+    T b = B[0];
+
+    for (uint8_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+    }
+
+    return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
+} // modified_bessel_k0_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
+    static const T A[] = {
+            -7.02386347938628759343e-18,
+            -2.42744985051936593393e-15,
+            -6.66690169419932900609e-13,
+            -1.41148839263352776110e-10,
+            -2.21338763073472585583e-08,
+            -2.43340614156596823496e-06,
+            -1.73028895751305206302e-04,
+            -6.97572385963986435018e-03,
+            -1.22611180822657148235e-01,
+            -3.53155960776544875667e-01,
+            +1.52530022733894777053e+00,
+    };
+
+    static const T B[] = {
+            -5.75674448366501715755e-18,
+            +1.79405087314755922667e-17,
+            -5.68946255844285935196e-17,
+            +1.83809354436663880070e-16,
+            -6.05704724837331885336e-16,
+            +2.03870316562433424052e-15,
+            -7.01983709041831346144e-15,
+            +2.47715442448130437068e-14,
+            -8.97670518232499435011e-14,
+            +3.34841966607842919884e-13,
+            -1.28917396095102890680e-12,
+            +5.13963967348173025100e-12,
+            -2.12996783842756842877e-11,
+            +9.21831518760500529508e-11,
+            -4.19035475934189648750e-10,
+            +2.01504975519703286596e-09,
+            -1.03457624656780970260e-08,
+            +5.74108412545004946722e-08,
+            -3.50196060308781257119e-07,
+            +2.40648494783721712015e-06,
+            -1.93619797416608296024e-05,
+            +1.95215518471351631108e-04,
+            -2.85781685962277938680e-03,
+            +1.03923736576817238437e-01,
+            +2.72062619048444266945e+00,
+    };
+
+    if (x == T(0.0)) {
+        return std::numeric_limits<T>::infinity();
+    }
+
+    if (x < T(0.0)) {
+        return std::numeric_limits<T>::quiet_NaN();
+    }
+
+    T p;
+    T q = 0.0;
+
+    if (x <= T(2.0)) {
+        T a = A[0];
+
+        for (uint8_t index = 1; index < 11; index++) {
+            p = q;
+            q = a;
+            a = (x * x - T(2.0)) * q - p + A[index];
+        }
+
+        return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
+    }
+
+    T b = B[0];
+
+    for (uint8_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+    }
+
+    return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
+} // modified_bessel_k1_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
+    static const T A[] = {
+            +1.37446543561352307156e-16,
+            +4.25981614279661018399e-14,
+            +1.03496952576338420167e-11,
+            +1.90451637722020886025e-09,
+            +2.53479107902614945675e-07,
+            +2.28621210311945178607e-05,
+            +1.26461541144692592338e-03,
+            +3.59799365153615016266e-02,
+            +3.44289899924628486886e-01,
+            -5.35327393233902768720e-01,
+    };
+
+    static const T B[] = {
+            +5.30043377268626276149e-18,
+            -1.64758043015242134646e-17,
+            +5.21039150503902756861e-17,
+            -1.67823109680541210385e-16,
+            +5.51205597852431940784e-16,
+            -1.84859337734377901440e-15,
+            +6.34007647740507060557e-15,
+            -2.22751332699166985548e-14,
+            +8.03289077536357521100e-14,
+            -2.98009692317273043925e-13,
+            +1.14034058820847496303e-12,
+            -4.51459788337394416547e-12,
+            +1.85594911495471785253e-11,
+            -7.95748924447710747776e-11,
+            +3.57739728140030116597e-10,
+            -1.69753450938905987466e-09,
+            +8.57403401741422608519e-09,
+            -4.66048989768794782956e-08,
+            +2.76681363944501510342e-07,
+            -1.83175552271911948767e-06,
+            +1.39498137188764993662e-05,
+            -1.28495495816278026384e-04,
+            +1.56988388573005337491e-03,
+            -3.14481013119645005427e-02,
+            +2.44030308206595545468e+00,
+    };
+
+    if (x == T(0.0)) {
+        return std::numeric_limits<T>::infinity();
+    }
+
+    if (x < T(0.0)) {
+        return std::numeric_limits<T>::quiet_NaN();
+    }
+
+    T p;
+    T q = 0.0;
+
+    if (x <= T(2.0)) {
+        T a = A[0];
+
+        for (uint64_t index = 1; index < 10; index++) {
+            p = q;
+            q = a;
+            a = (x * x - T(2.0)) * q - p + A[index];
+        }
+
+        return (T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x);
+    }
+
+    T b = B[0];
+
+    for (uint64_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+    }
+
+    return T(0.5) * (b - p) / std::sqrt(x);
+} // T scaled_modified_bessel_k0_forward(T x)
+
+template<typename T>
+static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
+    static const T A[] = {
+            -7.02386347938628759343e-18,
+            -2.42744985051936593393e-15,
+            -6.66690169419932900609e-13,
+            -1.41148839263352776110e-10,
+            -2.21338763073472585583e-08,
+            -2.43340614156596823496e-06,
+            -1.73028895751305206302e-04,
+            -6.97572385963986435018e-03,
+            -1.22611180822657148235e-01,
+            -3.53155960776544875667e-01,
+            +1.52530022733894777053e+00,
+    };
+
+    static const T B[] = {
+            -5.75674448366501715755e-18,
+            +1.79405087314755922667e-17,
+            -5.68946255844285935196e-17,
+            +1.83809354436663880070e-16,
+            -6.05704724837331885336e-16,
+            +2.03870316562433424052e-15,
+            -7.01983709041831346144e-15,
+            +2.47715442448130437068e-14,
+            -8.97670518232499435011e-14,
+            +3.34841966607842919884e-13,
+            -1.28917396095102890680e-12,
+            +5.13963967348173025100e-12,
+            -2.12996783842756842877e-11,
+            +9.21831518760500529508e-11,
+            -4.19035475934189648750e-10,
+            +2.01504975519703286596e-09,
+            -1.03457624656780970260e-08,
+            +5.74108412545004946722e-08,
+            -3.50196060308781257119e-07,
+            +2.40648494783721712015e-06,
+            -1.93619797416608296024e-05,
+            +1.95215518471351631108e-04,
+            -2.85781685962277938680e-03,
+            +1.03923736576817238437e-01,
+            +2.72062619048444266945e+00,
+    };
+
+    if (x == T(0.0)) {
+        return std::numeric_limits<T>::infinity();
+    }
+
+    if (x < T(0.0)) {
+        return std::numeric_limits<T>::quiet_NaN();
+    }
+
+    T p;
+    T q = 0.0;
+
+    if (x <= T(2.0)) {
+        T a = A[0];
+
+        for (uint64_t index = 1; index < 11; index++) {
+            p = q;
+            q = a;
+            a = (x * x - T(2.0)) * q - p + A[index];
+        }
+
+        return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * std::exp(x);
+    }
+
+    T b = B[0];
+
+    for (uint64_t index = 1; index < 25; index++) {
+        p = q;
+        q = b;
+        b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+    }
+
+    return (T(0.5) * (b - p) / std::sqrt(x));
+} // T scaled_modified_bessel_k1_forward(T x)
+
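+// The shifted Chebyshev polynomials are the ordinary ones composed with the
+// map [0, 1] -> [-1, 1], i.e. T*_n(x) = T_n(2*x - 1); hence the repeated
+// (x + x - 1) argument in the routines below.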
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (x == T(1.0)) {
+        return T(1.0);
+    }
+
+    if (x == T(0.0)) {
+        if (n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) {
+        return std::cos(n * std::acos(x + x - T(1.0)));
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x - T(1.0);
+    }
+
+    T p = T(1.0);
+    T q = x + x - T(1.0);
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
+    return shifted_chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
+} // shifted_chebyshev_polynomial_t_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (x == T(1.0)) {
+        return n + 1;
+    }
+
+    if (x == T(0.0)) {
+        if (n % 2 == 0) {
+            return n + 1;
+        }
+
+        return -(n + 1);
+    }
+
+    if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) {
+        if (std::sin(std::acos(x + x - T(1.0))) != T(0.0)) {
+            return std::sin((n + 1) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0)));
+        }
+
+        return (n + 1) * std::cos((n + 1) * std::acos(x + x - T(1.0))) / (x + x - T(1.0));
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x - T(1.0) + (x + x - T(1.0));
+    }
+
+    T p = T(1.0);
+    T q = x + x - T(1.0) + (x + x - T(1.0));
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
+    return shifted_chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
+} // shifted_chebyshev_polynomial_u_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (x == T(1.0)) {
+        return T(1.0);
+    }
+
+    if (x == T(0.0)) {
+        if (n % 2 == 0) {
+            return (n + n + 1);
+        }
+
+        return -(n + n + 1);
+    }
+
+    if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) {
+        if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
+            return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0));
+        }
+
+        if (n % 2 == 0) {
+            return n + n + 1;
+        }
+
+        return -(n + n + 1);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
+    }
+
+    T p = T(1.0);
+    T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
+    return shifted_chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
+} // shifted_chebyshev_polynomial_v_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
+    if (n < 0) {
+        return T(0.0);
+    }
+
+    if (x == T(1.0)) {
+        return n + n + 1;
+    }
+
+    if (x == T(0.0)) {
+        if (n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if ((n > 4) && (std::abs(x + x - T(1.0)) < T(1.0))) {
+        if (std::cos(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
+            return std::sin((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0)) / T(2.0));
+        }
+
+        if (n % 2 == 0) {
+            return T(1.0);
+        }
+
+        return T(-1.0);
+    }
+
+    if (n == 0) {
+        return T(1.0);
+    }
+
+    if (n == 1) {
+        return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
+    }
+
+    T p = T(1.0);
+    T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
+    T r;
+
+    for (int64_t k = 2; k <= n; k++) {
+        r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+        p = q;
+        q = r;
+    }
+
+    return r;
+} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
+    return shifted_chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
+} // shifted_chebyshev_polynomial_w_forward(T x, T n)
+
+template<typename T>
+static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
+    if (std::isinf(x)) {
+        return T(0.0);
+    }
+
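+    // For small |x|, evaluate the Maclaurin series of sin(x)/x,
+    // 1 - x^2/3! + x^4/5! - x^6/7! + ..., which avoids the 0/0 form at x = 0.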
+    if (std::abs(x) < T(0.5)) {
+        return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0)))))));
+    }
+
+    return std::sin(x) / x;
+} // T spherical_bessel_j0_forward(T x)
+
+C10_CLANG_DIAGNOSTIC_POP()
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h b/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h
new file mode 100644
index 0000000000000000000000000000000000000000..a1e84f029202bdb27e825a062a63adbcb5151d76
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h
@@ -0,0 +1,71 @@
+#pragma once
+
+namespace at {
+// views and their in-place version ops
+#define TORCH_VIEW_FNS(m) \
+  m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
+  m.impl("detach", torch::CppFunction::makeFallthrough()); \
+  m.impl("detach_", torch::CppFunction::makeFallthrough()); \
+  m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
+  m.impl("expand", torch::CppFunction::makeFallthrough()); \
+  m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
+  m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
+  m.impl("narrow", torch::CppFunction::makeFallthrough()); \
+  m.impl("permute", torch::CppFunction::makeFallthrough()); \
+  m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("select.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
+  m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
+  m.impl("t", torch::CppFunction::makeFallthrough()); \
+  m.impl("t_", torch::CppFunction::makeFallthrough()); \
+  m.impl("real", torch::CppFunction::makeFallthrough()); \
+  m.impl("imag", torch::CppFunction::makeFallthrough()); \
+  m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
+  m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("unfold", torch::CppFunction::makeFallthrough()); \
+  m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
+  m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
+  m.impl("view_as", torch::CppFunction::makeFallthrough()); \
+  m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
+  m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
+  m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
+  m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
+  m.impl("chunk", torch::CppFunction::makeFallthrough()); \
+  m.impl("reshape", torch::CppFunction::makeFallthrough()); \
+  m.impl("alias", torch::CppFunction::makeFallthrough()); \
+  m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
+  m.impl("conj", torch::CppFunction::makeFallthrough()); \
+  m.impl("_conj", torch::CppFunction::makeFallthrough()); \
+  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
+  m.impl("resize_", torch::CppFunction::makeFallthrough());
+
+#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
+  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
+  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
+  m.impl("full_like", torch::CppFunction::makeFallthrough()); \
+  m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("size.int", torch::CppFunction::makeFallthrough()); \
+  m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
+  m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
+  m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
+  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
+}
+
+#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
+  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
+  m.impl("view", torch::CppFunction::makeFallthrough());
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h b/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h
new file mode 100644
index 0000000000000000000000000000000000000000..add960c5b687695fd089c63bfe8ec21a0510cf2d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/MathBitsFallback.h
@@ -0,0 +1,157 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+
+#include 
+#endif
+
+namespace at::native {
+// This fallback should only be used for operations that are self inverse and have a corresponding tensor
+// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
+// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
+// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.
+
+// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
+struct MathOpFallback {
+  MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
+  virtual bool is_bit_set(const Tensor&) = 0;
+  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+    /*
+      Situations to handle:
+        1. Out-of-place operation.  Easy: materialize all inputs and
+          call it a day.
+        2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_().
+          Materialize other inputs as in (1).
+        3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
+        Materialize other inputs as in (1).
+
+        It is important to be able to tell if we READ from an argument and if we
+        WRITE to an argument.  Conservative approach is to assume that we always
+        READ from an argument, but in out= operations you can skip
+        conjugating inputs on entry that never get used. In the current schema we
+        can't easily tell if the operation is an in-place or out= operation.
+
+        Note:
+        1. Mutable tensorlists containing tensors whose math bit is set to true are disallowed.
+        2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
+           correct behavior in the case when the mutable tensor shares memory with non mutable arguments.
+
+           If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
+           with these mutable inputs would read into wrong values in the following cases:
+           1. Non mutable inputs have their math bit set to false.
+           2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
+              with one or more mutable arg(s)) are cloned.
+           At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
+    */
+    const auto& arguments = op.schema().arguments();
+    const auto num_arguments = arguments.size();
+    const auto stack_start = stack->size() - num_arguments;
+
+    c10::optional<bool> is_write;
+    for (const auto i : c10::irange(num_arguments)) {
+      // Three possible states:
+      // 1. alias_info has no value --> out-of-place operation
+      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
+      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
+      const AliasInfo* alias_info = arguments[i].alias_info();
+      if (alias_info != nullptr) {
+        if (is_write.has_value()) {
+          TORCH_CHECK(*is_write == alias_info->isWrite(),
+            "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
+            op_name, " fallback doesn't work for operators with a mix "
+            "mutable and non-mutable inputs that alias with outputs, "
+            "this must be implemented manually.  "
+            "If you got this error on a core op, please report a bug to PyTorch.");
+        } else {
+          is_write = alias_info->isWrite();
+        }
+      }
+    }
+
+    if (is_write.has_value() && !*is_write) {
+      // We assume that view operators automatically handle the math bit
+      // correctly by propagating the dispatch key in key_set.
+      // This is not necessarily always right, so you should test these cases.
+      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
+      return;
+    }
+
+    // Mutable inputs with math bit set to True and their clones
+    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
+    for (const auto i : c10::irange(num_arguments)) {
+      auto& ivalue = (*stack)[stack_start + i];
+      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
+        continue;
+      }
+      const auto& argument = arguments[i];
+      bool mut_arg = false;
+      if (argument.alias_info()) {
+        // Was already tested by is_write loop above
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
+        mut_arg = true;
+      }
+      if (ivalue.isTensor()) {
+        if (!is_bit_set(ivalue.toTensor())) {
+          continue;
+        }
+        auto tensor = std::move(ivalue).toTensor();
+        auto resolved_tensor = at::clone(tensor);
+        if (mut_arg) {
+          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensor with ",
+            op_name, " bit set to true.");
+          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
+        }
+        (*stack)[stack_start + i] = std::move(resolved_tensor);
+      } else if (ivalue.isTensorList()) {
+        auto tensors = std::move(ivalue).toTensorList();
+        for(const auto j : c10::irange(tensors.size())) {
+          const auto& tensor = tensors[j];
+          if (!is_bit_set(tensor)) {
+            continue;
+          }
+          TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
+              op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
+              op.schema().name());
+          tensors[j] = at::clone(tensor);
+        }
+        (*stack)[stack_start + i] = std::move(tensors);
+      }
+    }
+
+    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
+
+    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
+
+    for (std::pair<Tensor, Tensor> mut_tensors : mutable_inputs_with_their_clones) {
+      auto& mutable_input =  mut_tensors.first;
+      auto& cloned_mutable_input =  mut_tensors.second;
+      auto& ivalue = (*stack)[stack_start];
+      auto returned_output = std::move(ivalue).toTensor();
+
+      // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
+      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
+
+      // necessary for out= arg
+      at::native::resize_output(mutable_input, returned_output.sizes());
+
+      mutable_input.copy_(returned_output);
+      (*stack)[stack_start] = std::move(mutable_input);
+    }
+  }
+
+  virtual ~MathOpFallback() = default;
+
+  DispatchKey key;
+  string op_name;
+};
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/MaxPooling.h b/MLPY/Lib/site-packages/torch/include/ATen/native/MaxPooling.h
new file mode 100644
index 0000000000000000000000000000000000000000..89a1ff7080deb3d91e2c1784af0942cd423beae6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/MaxPooling.h
@@ -0,0 +1,97 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+static void check_max_pool1d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+
+  TORCH_CHECK(
+      self.dim() == 2 || self.dim() == 3,
+      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
+  TORCH_CHECK(
+      kernel_size.size() == 1,
+      "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
+      kernel_size.size());
+  TORCH_CHECK(
+      stride.empty() || stride.size() == 1,
+      "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
+      stride.size());
+  TORCH_CHECK(
+      padding.size() == 1,
+      "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
+      padding.size());
+  TORCH_CHECK(
+      dilation.size() == 1,
+      "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
+      dilation.size());
+
+  // If stride=None then set it to kernel_size
+  if (stride.empty()) {
+    stride = kernel_size;
+  }
+
+  TORCH_CHECK(
+      kernel_size[0] > 0,
+      "max_pool1d() kernel_size must be greater than zero, but got ",
+      kernel_size[0]);
+  TORCH_CHECK(
+      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
+  TORCH_CHECK(
+      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
+  TORCH_CHECK(
+      padding[0] <= kernel_size[0] / 2,
+      "max_pool1d() padding should be at most half of kernel size, but got padding=",
+      padding[0],
+      " and kernel_size=",
+      kernel_size[0]);
+  TORCH_CHECK(
+      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
+
+  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
+  TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
+}
+
+// TODO(Heitor) Template by dimension
+struct PoolingParams1D {
+  int64_t NB; // Number of batches
+  int64_t NC; // Number of channels
+  int64_t IW; // Input width
+  int64_t OW; // Output width
+  int64_t KW; // Kernel width
+  int64_t SJ; // Column stride
+  int64_t PJ; // Column padding
+  int64_t DJ; // Column dilation
+
+  // Return index of input element for the given kernel and output index
+  inline int64_t index(int64_t kj, int64_t oj) const {
+    return oj * SJ + kj * DJ - PJ;
+  }
+
+  // Return index of first output within bounds for this kernel index
+  inline int64_t valid_output_start(int64_t kj) const {
+    int64_t ij = index(kj, 0);
+    return ij < 0 ? at::divup(-ij, SJ) : 0;
+  }
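+  // Example: with SJ = 2, PJ = 1, DJ = 1 and kj = 0, index(0, 0) == -1 falls to
+  // the left of the input, so valid_output_start(0) == divup(1, 2) == 1 skips
+  // that first output column for this kernel tap.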
+
+  // Return index one past last output within bounds for this kernel index
+  inline int64_t valid_output_end(int64_t kj) const {
+    int64_t ij = index(kj, OW - 1);
+    return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
+  }
+};
+
+using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
+
+DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..26cb65d844b4f0e1d88a45712159d18c0312ab73
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/NonEmptyUtils.h
@@ -0,0 +1,27 @@
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+inline int64_t ensure_nonempty_dim(int64_t dim) {
+  return std::max(dim, 1);
+}
+
+inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
+  return t.dim() == 0 ? 1 : t.size(dim);
+}
+
+inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
+  return t.dim() == 0 ? 1 : t.stride(dim);
+}
+
+using IdxVec = std::vector<int64_t>;
+inline IdxVec ensure_nonempty_vec(IdxVec vec) {
+  if (vec.empty()) {
+    vec.push_back(1);
+  }
+  return vec;
+}
+
+}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h b/MLPY/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h
new file mode 100644
index 0000000000000000000000000000000000000000..bbb4b0f7babdcf6ae263b40bc9f387bf9f7a6361
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/NonSymbolicBC.h
@@ -0,0 +1,26 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::native {
+// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
+// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
+// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
+TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
+TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
+TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt, c10::optional<bool> is_coalesced=c10::nullopt);
+TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
+TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
+// The below ops don't get a duplicated C++ implementation.
+// They are backward ops, which make them very unlikely to be called directly
+// by external code (at::native::trace_backward).
+// They get their own declaration for BC purposes however.
+TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
+TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
+TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
+TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
+TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Normalization.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Normalization.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a983e9ab6961f764c2eb1661e5f1c2cc7c6ed61
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Normalization.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
+DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
+
+}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Padding.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Padding.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee31048f559ee1797baa62f57cade7918220b115
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Padding.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
+
+// reflection padding
+DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
+DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
+DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
+DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
+DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
+DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);
+
+// replication padding
+DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
+DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
+DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
+DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
+DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
+DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
+
+namespace padding {
+
+template <int dim>
+static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
+
+  TORCH_CHECK(padding.size() == 2 * dim,
+      "padding size is expected to be ", 2 * dim,
+      ", but got: ", padding.size());
+
+  int input_dim = input.dim();
+
+  bool is_batch_mode = input_dim == (dim + 2);
+
+  bool valid_batch_mode = is_batch_mode;
+  bool valid_non_batch_mode = !is_batch_mode;
+
+  if (is_batch_mode) {
+    // allow batch size of 0-dim.
+    for (const auto d : c10::irange(1, input_dim)) {
+      valid_batch_mode = valid_batch_mode && input.size(d) != 0;
+    }
+  } else {
+    for (const auto d : c10::irange(0, input_dim)) {
+      valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
+    }
+  }
+
+  // allow empty batch size but not other dimensions.
+  TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
+      "Expected ", dim + 1, "D or ", dim + 2,
+      "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
+      input.sizes());
+}
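+// Example: for 2-D padding (dim == 2), check_valid_input expects a 4-element
+// padding list and a 3-D (C, H, W) or 4-D (N, C, H, W) input, where only the
+// batch dimension may have size zero.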
+
+} // namespace padding
+
+} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h b/MLPY/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2b83f4baa45859e219576775fcc59aa9aac7f53
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/PixelShuffle.h
@@ -0,0 +1,47 @@
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,
+              "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
+              self.dim(), " dimension(s)");
+  TORCH_CHECK(upscale_factor > 0,
+              "pixel_shuffle expects a positive upscale_factor, but got ",
+              upscale_factor);
+  int64_t c = self.size(-3);
+  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
+  TORCH_CHECK(c % upscale_factor_squared == 0,
+              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
+              "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
+}
+
+inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(
+      self.dim() >= 3,
+      "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+      self.dim(),
+      " dimension(s)");
+  TORCH_CHECK(
+      downscale_factor > 0,
+      "pixel_unshuffle expects a positive downscale_factor, but got ",
+      downscale_factor);
+  int64_t h = self.size(-2);
+  int64_t w = self.size(-1);
+  TORCH_CHECK(
+      h % downscale_factor == 0,
+      "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
+      h,
+      " is not divisible by ",
+      downscale_factor);
+  TORCH_CHECK(
+      w % downscale_factor == 0,
+      "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
+      w,
+      " is not divisible by ",
+      downscale_factor);
+}
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..785d62f8d2735f1baa3c31860b6473aa386aca43
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/PointwiseOps.h
@@ -0,0 +1,28 @@
+// Ternary and higher-order pointwise operations
+#pragma once
+
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+
+struct TensorIterator;
+struct TensorIteratorBase;
+
+namespace native {
+
+using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar);
+using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar);
+using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double);
+
+DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub);
+DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub);
+DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub);
+DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub);
+DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Pool.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..8daa5f56b8388e9090e8e45a1b6abdcefb7e0254
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Pool.h
@@ -0,0 +1,340 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#pragma once
+
+namespace at::native {
+
+using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
+    int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
+using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
+
+DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
+DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
+
+// average pooling has the same signature for forward and backward
+using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
+    int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+
+DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
+DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
+
+using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
+    int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
+using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
+
+DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel);
+DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
+namespace {
+
+template <typename dest_t, typename src_t>
+static inline dest_t
+safe_downcast(src_t v)
+{
+  TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
+              "integer out of range");
+
+  return static_cast<dest_t>(v);
+}
+
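+// Output length for pooling with (possibly asymmetric) padding:
+//   floor((inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1) / stride) + 1
+// with ceiling division instead when ceil_mode is set, corrected so the last
+// window still starts inside the input region.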
+template<typename T>
+static inline T pooling_output_shape_pad_lr(
+        T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
+        bool ceil_mode) {
+    T outputSize = div_rtn(
+        inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
+        (ceil_mode ? stride - 1 : 0), stride) + 1;
+    if (ceil_mode) {
+        // ensure that the last pooling starts inside the image
+        // needed to avoid problems in ceil mode
+        if ((outputSize - 1) * stride >= inputSize + pad_l) {
+          --outputSize;
+        }
+    }
+    return outputSize;
+}
+
+template<typename T>
+static inline T pooling_output_shape(
+      T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
+    TORCH_CHECK(stride != 0, "stride should not be zero");
+    TORCH_CHECK(pad >= 0,
+                "pad must be non-negative, but got pad: ", pad);
+    TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2,
+                "pad should be at most half of effective kernel size, but got pad=",
+                pad, ", kernel_size=", kernelSize, " and dilation=", dilation)
+    return pooling_output_shape_pad_lr(
+        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
+}
+
+template <typename T>
+std::pair<T, T> _pooling_same_mode_padding_lr(
+    T inputSize, T kernelSize, T stride, T dilation) {
+  // NOTE: with strides, the output shape is ceil(inputSize/stride)
+  auto total_padding = T(dilation) * (kernelSize - 1);
+
+  // Prefer symmetric padding if possible
+  if (stride > 2 && (total_padding % 2 == 1)) {
+    // The floor in the output size calculation gives us a little wiggle room
+    auto wiggle_room = inputSize % stride - 1;
+    if (wiggle_room > 0) {
+      total_padding = total_padding - 1;
+    }
+  }
+
+  auto left = total_padding / 2;
+  return {left, total_padding - left};
+}
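+// Example: inputSize = 5, kernelSize = 3, stride = 2, dilation = 1 gives
+// total_padding = 2, split symmetrically as {1, 1}, which produces the "same"
+// output length ceil(5 / 2) = 3.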
+
+inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
+    int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
+  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+}
+
+inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
+    c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) {
+  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation));
+}
+
+// AveragePool2d/DilatedMaxPool2d (forward)
+static inline void
+pool2d_shape_check(
+  const Tensor& input,
+  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
+{
+  const int64_t ndim = input.ndimension();
+  const int64_t nOutputPlane = nInputPlane;
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+              "kernel size should be greater than zero, but got ",
+              "kH: ", kH, " kW: ", kW);
+  TORCH_CHECK(dW > 0 && dH > 0,
+              "stride should be greater than zero, but got "
+              "dH: ", dH, " dW: ", dW);
+  TORCH_CHECK(dilationH > 0 && dilationW > 0,
+              "dilation should be greater than zero, but got ",
+              "dilationH: ", dilationH, " dilationW: ", dilationW);
+
+  bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
+  if (memory_format == at::MemoryFormat::ChannelsLast){
+    // Expect tensor in NHWC format and allow 0-dim only for N.
+    TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
+      "Expected 4D (batch mode) tensor expected for input with channels_last layout"
+      " with optional 0 dim batch size for input, but got: ", input.sizes());
+  } else {
+    TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
+      (ndim == 4 && valid_dims && input.size(3) != 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
+      input.sizes());
+  }
+
+  TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
+              "pad should be smaller than or equal to half of kernel size, but got ",
+              "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
+
+  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
+              "Given input size: (",
+              nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
+              "Calculated output size: (",
+              nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
+              "Output size is too small");
+}
+
+// DilatedMaxPool2d (backward)
+static inline void
+max_pool2d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  const Tensor& indices,
+  int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
+{
+  pool2d_shape_check(
+    input,
+    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
+
+  const int64_t ndim = input.ndimension();
+  const int64_t nOutputPlane = nInputPlane;
+
+  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
+  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
+  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
+
+  check_dim_size(indices, ndim, ndim-3, nOutputPlane);
+  check_dim_size(indices, ndim, ndim-2, outputHeight);
+  check_dim_size(indices, ndim, ndim-1, outputWidth);
+}
+
+// AveragePool2d (backward)
+static inline void
+avg_pool2d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  int64_t /*nbatch*/,
+  int kH, int kW, int dH, int dW, int padH, int padW,
+  int64_t nInputPlane,
+  int64_t inputHeight, int64_t inputWidth,
+  int64_t outputHeight, int64_t outputWidth,
+  MemoryFormat memory_format)
+{
+  pool2d_shape_check(
+    input,
+    kH, kW, dH, dW, padH, padW, 1, 1,
+    nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+    memory_format);
+
+  const int64_t ndim = input.ndimension();
+  const int64_t nOutputPlane = nInputPlane;
+
+  check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
+  check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
+  check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
+}
+
+// AveragePool3d/DilatedMaxPool3d (forward)
+static inline void
+pool3d_shape_check(
+  const Tensor& input,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int dilationT, int dilationH, int dilationW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name,
+  bool check_input_size=false)
+{
+  const int64_t ndim = input.ndimension();
+
+  TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
+              "kernel size should be greater than zero, but got ",
+              "kT: ", kT, " kH: ", kH, " kW: ", kW);
+  TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
+              "stride should be greater than zero, but got ",
+              "dT: ", dT, " dH: ", dH, " dW: ", dW);
+  TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
+              "dilation should be greater than zero, but got ",
+              "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
+
+  TORCH_CHECK(ndim == 4 || ndim == 5,
+              fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
+
+  for (const auto i : c10::irange(ndim)) {
+    if (ndim == 5 && i == 0) {
+      // size of batch-dim can be 0.
+      continue;
+    }
+    TORCH_CHECK(
+        input.size(i) > 0,
+        fn_name,
+        ": Expected input's non-batch dimensions to have positive length,"
+        " but input has a shape of ",
+        input.sizes(),
+        " and non-batch dimension ",
+        input.size(i),
+        " has length zero!")
+  }
+
+  if (check_input_size) { // AveragePool3d
+    TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
+                "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
+                "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
+  }
+
+  TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
+              "pad should be smaller than or equal to half of kernel size, but got "
+              "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
+
+  TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
+              "Given input size: (",
+              nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
+              "Calculated output size: (",
+              nslices, "x", otime, "x", oheight, "x", owidth, "). ",
+              "Output size is too small");
+}
+
+static inline void
+max_pool3d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  const Tensor& indices,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int dilationT, int dilationH, int dilationW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char* fn_name)
+{
+  const int64_t ndim = input.ndimension();
+
+  pool3d_shape_check(
+    input,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    pT, pH, pW,
+    dilationT, dilationH, dilationW,
+    itime, iheight, iwidth,
+    otime, oheight, owidth, fn_name);
+
+  check_dim_size(gradOutput, ndim, ndim-4, nslices);
+  check_dim_size(gradOutput, ndim, ndim-3, otime);
+  check_dim_size(gradOutput, ndim, ndim-2, oheight);
+  check_dim_size(gradOutput, ndim, ndim-1, owidth);
+
+  check_dim_size(indices, ndim, ndim-4, nslices);
+  check_dim_size(indices, ndim, ndim-3, otime);
+  check_dim_size(indices, ndim, ndim-2, oheight);
+  check_dim_size(indices, ndim, ndim-1, owidth);
+}
+
+static inline void
+avg_pool3d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth,
+  const char *fn_name)
+{
+  const int64_t ndim = input.ndimension();
+
+  pool3d_shape_check(
+    input,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    pT, pH, pW,
+    1, 1, 1,
+    itime, iheight, iwidth,
+    otime, oheight, owidth,
+    fn_name, true);
+
+  check_dim_size(gradOutput, ndim, ndim-4, nslices);
+  check_dim_size(gradOutput, ndim, ndim-3, otime);
+  check_dim_size(gradOutput, ndim, ndim-2, oheight);
+  check_dim_size(gradOutput, ndim, ndim-1, owidth);
+}
+
+} // anonymous namespace
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Pow.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Pow.h
new file mode 100644
index 0000000000000000000000000000000000000000..5845442f0de9102ad5c4793f23561970bc16f1fa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Pow.h
@@ -0,0 +1,69 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+
+struct TensorIterator;
+struct TensorIteratorBase;
+
+namespace native {
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#define HOST_DEVICE __host__ __device__
+#else
+#define HOST_DEVICE
+#endif
+
+// integral power in pytorch allows for negative exponents, giving truncated integral results.
+// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the
+// only non-zero result.
+template <class T, typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
+static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
+  T result = 1;
+  while (b) {
+    if (b & 1) {
+       result *= a;
+    }
+    b /= 2;
+    a *= a;
+  }
+  return result;
+}
+
+template <class T, typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
+static inline HOST_DEVICE T powi(T a, T b) {
+  return powi_impl(a, b);
+}
+
+template <class T, typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
+static inline HOST_DEVICE T powi(T a, T b) {
+  if ( b < 0 ) {
+      if ( a == 1 ) {
+          return 1;
+      } else if ( a == -1 ) {
+          auto negative = (-b) % static_cast<T>(2);
+          return negative ? -1 : 1;
+      } else {
+          return 0;
+      }
+  }
+  return powi_impl(a, b);
+}
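+// Examples: powi(2, 10) == 1024; powi(2, -1) == 0 (2**-1 truncates to zero);
+// powi(-1, -3) == -1, since only a base of +/-1 is non-zero under a negative exponent.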
+
+using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
+using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
+
+DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
+DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
+
+} // namespace native
+
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/RNN.h b/MLPY/Lib/site-packages/torch/include/ATen/native/RNN.h
new file mode 100644
index 0000000000000000000000000000000000000000..176897b560d3a71eb05a6db91b399a27d2c6f634
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/RNN.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
+using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
+using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
+using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);
+
+DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
+DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
+DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
+DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
+DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
+DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);
+DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub);
+DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub);
+DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub);
+DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub);
+DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub);
+DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub);
+DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub);
+DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub);
+DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub);
+DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub);
+DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub);
+
+inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
+  auto input_device = input.device();
+  auto input_dtype = input.scalar_type();
+
+  auto check_tensors = [&](const std::string& name, const Tensor& t) {
+    if (!t.defined()) return;
+    auto t_device = t.device();
+    TORCH_CHECK(input_device == t_device,
+             "Input and ", name, " tensors are not at the same device, found input tensor at ",
+             input_device, " and ", name, " tensor at ", t_device);
+    if (check_dtype) {
+      auto t_dtype = t.scalar_type();
+      TORCH_CHECK(input_dtype == t_dtype,
+               "Input and ", name, " tensors are not the same dtype, found input tensor with ",
+               input_dtype, " and ", name, " tensor with ", t_dtype);
+    }
+  };
+
+  for (const auto& h : hiddens) check_tensors("hidden", h);
+  for (const auto& p : params) check_tensors("parameter", p);
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/RangeFactories.h b/MLPY/Lib/site-packages/torch/include/ATen/native/RangeFactories.h
new file mode 100644
index 0000000000000000000000000000000000000000..f39e9678f76308b3f510897b66edcf525e367b19
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/RangeFactories.h
@@ -0,0 +1,12 @@
+#include 
+#include 
+
+namespace at {
+struct TensorIterator;
+
+namespace native {
+
+DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
+DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f0bae6179a8968fdab7aef95a1a2b33bfb0dbed
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceAllOps.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include 
+
+namespace at {
+class Tensor;
+}
+
+namespace at::native {
+
+using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
+using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
+DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
+DECLARE_DISPATCH(reduce_all_fn, max_all_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..2bef746b4ca89d15e162e517658c8b9544c1b81d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOps.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+struct TensorIterator;
+class Tensor;
+}
+
+namespace at::native {
+
+using reduce_fn = void(*)(TensorIterator &);
+
+DECLARE_DISPATCH(reduce_fn, sum_stub);
+DECLARE_DISPATCH(reduce_fn, nansum_stub);
+DECLARE_DISPATCH(reduce_fn, prod_stub);
+DECLARE_DISPATCH(reduce_fn, mean_stub);
+DECLARE_DISPATCH(reduce_fn, and_stub);
+DECLARE_DISPATCH(reduce_fn, or_stub);
+DECLARE_DISPATCH(reduce_fn, min_values_stub);
+DECLARE_DISPATCH(reduce_fn, max_values_stub);
+DECLARE_DISPATCH(reduce_fn, argmax_stub);
+DECLARE_DISPATCH(reduce_fn, argmin_stub);
+
+using reduce_std_var_function =
+    void (*)(TensorIterator&, double correction, bool take_sqrt);
+DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
+
+using reduce_norm_fn =
+    void (*)(Tensor&, const Tensor&, const c10::Scalar&, c10::optional<int64_t>);
+DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
+
+using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
+DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
+
+using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
+using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
+DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
+DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
+DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
+
+DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
+DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);
+
+// Used in cuda/Normalization.cu
+TORCH_API std::tuple<Tensor&, Tensor&> var_mean_out(
+    Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
+    int64_t correction, bool keepdim);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..01dc822cf23a3e55f836bbe8b85ab346e6c5dfe8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ReduceOpsUtils.h
@@ -0,0 +1,449 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+// Maximum and minimum possible scalar values, including infinities
+template <typename scalar_t>
+constexpr scalar_t upper_bound() {
+  using lim = std::numeric_limits<scalar_t>;
+  return lim::has_infinity ? lim::infinity() : lim::max();
+}
+
+template <typename scalar_t>
+constexpr scalar_t lower_bound() {
+  using lim = std::numeric_limits<scalar_t>;
+  return lim::has_infinity ? -lim::infinity() : lim::lowest();
+}
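+// For example, upper_bound<float>() is +infinity while upper_bound<int64_t>()
+// is the largest representable int64_t; lower_bound() mirrors this with
+// -infinity or the lowest finite value.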
+
+static inline Tensor restride_dim(
+  const Tensor& src, int64_t dim,
+  IntArrayRef replacement_shape
+) {
+  auto strides = ensure_nonempty_vec(src.strides().vec());
+  strides[dim] = 0;
+  return src.as_strided(replacement_shape, strides);
+}
+
+inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
+                                int64_t dim) {
+  IntArrayRef self_sizes = self.sizes();
+  std::vector<int64_t> result_sizes;
+  result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
+  result_sizes[dim] = 1;
+  result.resize_(result_sizes);
+}
+
+inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
+                                      const Scalar& ident, int64_t dim, bool keepdim) {
+  if (self.numel() == 1 && self.ndimension() == 0) {
+    result.resize_({});
+    result.fill_(self);
+    return true;
+  }
+  // Return identity
+  if (self.numel() == 0) {
+    _dimreduce_setup(result, self, dim);
+    result.fill_(ident);
+    if (!keepdim) result.squeeze_(dim);
+    return true;
+  }
+  return false;
+}
+
+inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
+                                               int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
+  if (self.numel() == 1 && self.ndimension() == 0) {
+    result.resize_({});
+    result.fill_(self);
+    return true;
+  }
+
+  return false;
+}
+
+inline c10::optional<Tensor> _allreduce_return_trivial(
+    const Tensor& self,
+    const Scalar& ident) {
+  // Return identity
+  if (self.numel() == 0) {
+    return at::scalar_tensor(ident, self.options());
+  }
+  return c10::nullopt;
+}
+
+#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
+{ \
+  TORCH_CHECK(\
+    out.option() == self.option(),\
+    "expected ", #option, " ",\
+    self.option(),\
+    " but found ", out.option())\
+}
+
+static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
+  OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
+  OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
+  OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
+}
+
+static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
+  ScalarType scalarType = self.scalar_type();
+  TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
+  ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
+  return self.toType(upcast_scalarType);
+}
+
+using DimMask = TensorIterator::DimMask;
+
+static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
+  if (opt_dims.has_value()) {
+    return DimVector(opt_dims.value());
+  } else {
+    std::vector<int64_t> all_dims(ndim);
+    std::iota(all_dims.begin(), all_dims.end(), 0);
+    return DimVector(all_dims);
+  }
+}
+
+static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
+  DimMask mask;
+  if (opt_dims.has_value()) {
+    auto dims = opt_dims.value();
+    if (dims.empty() && !allow_empty_dims) {
+      mask = DimMask().flip();
+    } else {
+      mask = at::dim_list_to_bitset(dims, ndim);
+    }
+  } else {
+    mask = DimMask().flip();
+  }
+  return mask;
+}
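+// Example: ndim = 3 with dims = {0, 2} sets mask bits 0 and 2; an empty or
+// missing dim list selects every dimension (an all-reduce), unless
+// allow_empty_dims permits an empty mask.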
+
+inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
+  auto shape = DimVector(self.sizes());
+  for (int dim = shape.size() - 1; dim >= 0; dim--) {
+    if (mask[dim]) {
+      if (keepdim) {
+        shape[dim] = 1;
+      } else {
+        shape.erase(shape.begin() + dim);
+      }
+    }
+  }
+  return shape;
+}
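+// Example: self.sizes() = {2, 3, 4} reduced over dims {0, 2} yields {3} when
+// keepdim is false and {1, 3, 1} when keepdim is true.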
+
+static void resize_reduction_result(
+    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
+    ScalarType /*dtype*/)
+{
+  auto shape = shape_from_dim_mask(self, mask, keepdim);
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  at::native::resize_output(result, shape);
+}
+
+inline Tensor create_reduction_result(
+  const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
+) {
+  DimMask mask = make_dim_mask(dim, self.dim());
+  auto shape = shape_from_dim_mask(self, mask, keepdim);
+  return at::empty(shape, self.options().dtype(dtype));
+}
+
+static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
+  if (keepdim) {
+    return result;
+  }
+  auto shape = DimVector(result.sizes());
+  auto stride = DimVector(result.strides());
+  for (const auto dim : c10::irange(ndim)) {
+    if (mask[dim]) {
+      shape.insert(shape.begin() + dim, 1);
+      stride.insert(stride.begin() + dim, 0);
+    }
+  }
+  return result.as_strided(shape, stride);
+}
+
+static TensorIterator make_reduction(
+    const char* name, Tensor& result, const Tensor& self,
+    at::OptionalIntArrayRef dim_opt,
+    bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
+  // check that result type and dtype match if provided
+  TORCH_CHECK(
+      !result.defined() || result.scalar_type() == out_dtype,
+      name, ": provided dtype must match dtype of result. Got ",
+      toString(result.scalar_type()),
+      " and ",
+      toString(out_dtype),
+      ".");
+  // dim={} performs an all-reduce, same as dim=None
+  IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
+  int64_t ndim = self.dim();
+  auto mask = make_dim_mask(dim, ndim);
+  resize_reduction_result(result, self, mask, keepdim, out_dtype);
+  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
+  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
+  if (self.scalar_type() == in_dtype) {
+    return TensorIterator::reduce_op(viewed_result, self);
+  }
+  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
+}
+
+static C10_UNUSED TensorIterator make_reduction(
+    const char* name, Tensor& result, const Tensor& self,
+    at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  const bool gpu_lowp_to_f32 = (
+    self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
+  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
+                   : self.is_complex() ? c10::toComplexType(out_dtype)
+                                       : out_dtype;
+  return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
+}
+
+static TensorIterator make_reduction(
+    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
+    at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
+    ScalarType dtype2) {
+  // check that result type and dtype match if provided
+  TORCH_CHECK(
+    (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
+    name, ": provided dtype must match dtype of result. Got ",
+    toString(result1.scalar_type()), toString(result2.scalar_type()),
+    " and ",
+    toString(dtype1), toString(dtype2),
+    ".");
+
+  // dim={} performs an all-reduce, same as dim=None
+  auto dim = dim_opt.value_or(IntArrayRef{});
+  int64_t ndim = self.dim();
+  DimMask mask = make_dim_mask(dim, ndim);
+  resize_reduction_result(result1, self, mask, keepdim, dtype1);
+  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
+
+  resize_reduction_result(result2, self, mask, keepdim, dtype2);
+  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
+
+  namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
+  namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
+
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  if (self.scalar_type() == dtype1 ||
+      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
+  }
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
+}
+
+static C10_UNUSED TensorIterator make_reduction(
+    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
+    at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
+  return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
+}
+
+static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
+  if (self.ndimension() == 0) {
+    TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
+      ": Expected reduction dim -1 or 0 for scalar but got ", dim);
+  }
+  else {
+    TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
+      ": Expected reduction dim ", dim, " to have non-zero size.");
+  }
+}
+
+static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
+  TORCH_CHECK(
+    !dim.empty(),
+      fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
+        "Specify the reduction dim with the 'dim' argument.");
+  for (const int64_t d : dim) {
+    zero_numel_check_dims(self, d, fn_name);
+  }
+}
+
+static std::vector<int64_t> get_zero_numel_tensor_size(
+    const Tensor& self,
+    const int64_t dim,
+    const bool keepdim,
+    const char* fn_name) {
+  TORCH_INTERNAL_ASSERT(self.numel() == 0,  fn_name, ": Expected self.numel() == 0.");
+  zero_numel_check_dims(self, dim, fn_name);
+  std::vector<int64_t> sizes;
+  if (keepdim) {
+    sizes = self.sizes().vec();
+    sizes[dim] = 1;
+  }
+  else {
+    for (const auto d : c10::irange(self.dim())) {
+      if (d != dim) {
+        sizes.push_back(self.sizes()[d]);
+      }
+    }
+  }
+  return sizes;
+}
+
+// Resize the result tensor and indices when result.numel() == 0 depending on values of
+// dim and keepdim for returning tensors containing reduction results.
+// This function should be called when you are reducing a zero-numel tensor and want to
+// resize the output and return it. This function exists for resizing zero-numel
+// tensors when the size of the reduction dimension is non-zero.
+static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
+                                     const Tensor& self, const int64_t dim,
+                                     const bool keepdim, const char *fn_name) {
+  auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
+  at::native::resize_output(result, sizes);
+  at::native::resize_output(result_indices, sizes);
+}
+
+inline ScalarType get_dtype_from_self(
+    const Tensor& self,
+    const c10::optional<ScalarType>& dtype,
+    bool promote_integers) {
+  if (dtype.has_value()) {
+    return dtype.value();
+  }
+  ScalarType src_type = self.scalar_type();
+  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
+    return kLong;
+  }
+  return src_type;
+}
+
+inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
+  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
+  if (dtype.has_value()) {
+    return dtype.value();
+  } else {
+    return result.scalar_type();
+  }
+}
+
+
+} // namespace at::native
+
+namespace at::meta {
+
+static C10_UNUSED DimVector get_reduction_shape(
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim,
+    bool allow_empty_dims=false) {
+  auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
+  return native::shape_from_dim_mask(self, mask, keepdim);
+}
+
+static void resize_reduction(
+    impl::MetaBase& meta,
+    const Tensor& self,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType out_dtype,
+    bool allow_empty_dims=false) {
+  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
+  maybe_wrap_dims(dims_, self.dim());
+  auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
+  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(), self, dims_, keepdim);
+}
+
+static void resize_reduction_with_indices(
+    impl::MetaBase& meta,
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim,
+    ScalarType out_dtype) {
+  DimVector dims_(dims);
+  maybe_wrap_dims(dims_, self.dim());
+  auto shape = get_reduction_shape(self, dims_, keepdim);
+  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
+  meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(0), self, dims_, keepdim);
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(1), self, dims_, keepdim);
+}
+
+static TensorIterator make_reduction(
+    const Tensor& self,
+    const Tensor& result,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType in_dtype) {
+  int64_t ndim = self.dim();
+  auto mask = at::native::make_dim_mask(opt_dims, ndim);
+  auto viewed_result =
+      at::native::review_reduce_result(result, ndim, mask, keepdim);
+  if (self.scalar_type() == in_dtype) {
+    return TensorIterator::reduce_op(viewed_result, self);
+  }
+  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
+}
+
+static TensorIterator make_reduction(
+    const Tensor& self,
+    const Tensor& result1,
+    const Tensor& result2,
+    IntArrayRef dims,
+    bool keepdim,
+    ScalarType dtype1,
+    ScalarType /*dtype2*/) {
+  int64_t ndim = self.dim();
+  auto mask = at::native::make_dim_mask(dims, ndim);
+  auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
+  auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
+  // special case for type promotion in mixed precision, improves computational efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross product
+  // of templated kernel launches.
+  if (self.scalar_type() == dtype1 ||
+      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
+  }
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
+}
+
+static C10_UNUSED TensorIterator make_reduction_from_out_ty(
+    const Tensor& self,
+    const Tensor& result,
+    OptionalIntArrayRef opt_dims,
+    bool keepdim,
+    ScalarType out_dtype) {
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid a
+  // cross product of templated kernel launches.
+  const bool gpu_lowp_to_f32 =
+      (self.is_cuda() &&
+       (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
+       out_dtype == kFloat);
+  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
+  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
+}
+
+} // namespace at::meta
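The reduction helpers above all funnel through one shape rule: build a bitmask of the reduced dims (an empty or missing dim list means an all-reduce), then either collapse each masked dim to size 1 (keepdim) or erase it. A minimal standalone sketch of that rule, written against the standard library only; the function name and types are illustrative and not part of the vendored header:

// Mirrors the shape logic of make_dim_mask / shape_from_dim_mask for plain std:: containers.
#include <bitset>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> reduced_shape(std::vector<int64_t> shape,
                                          const std::vector<int64_t>& dims,
                                          bool keepdim) {
  std::bitset<64> mask;
  if (dims.empty()) {
    mask.set();                               // dim=None / dim={} -> all-reduce
  } else {
    for (int64_t d : dims) mask.set(d);
  }
  for (int64_t d = static_cast<int64_t>(shape.size()) - 1; d >= 0; --d) {
    if (!mask[d]) continue;
    if (keepdim) shape[d] = 1;                // keep the dim with size 1
    else shape.erase(shape.begin() + d);      // drop the dim entirely
  }
  return shape;
}

int main() {
  for (int64_t s : reduced_shape({2, 3, 4}, {1}, /*keepdim=*/true)) std::cout << s << ' ';
  std::cout << '\n';                          // prints: 2 1 4
  for (int64_t s : reduced_shape({2, 3, 4}, {1}, /*keepdim=*/false)) std::cout << s << ' ';
  std::cout << '\n';                          // prints: 2 4
}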
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ReductionType.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ReductionType.h
new file mode 100644
index 0000000000000000000000000000000000000000..97328c227ad6cb7556acc3f9efeca4bf3a66eaf6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ReductionType.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include 
+
+namespace at::native {
+
+enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
+
+static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
+  if (reduce == "max" || reduce == "amax") {
+    return ReductionType::MAX;
+  } else if (reduce == "mean") {
+    return ReductionType::MEAN;
+  } else if (reduce == "min" || reduce == "amin") {
+    return ReductionType::MIN;
+  } else if (reduce == "sum") {
+    return ReductionType::SUM;
+  } else if (reduce == "prod") {
+    return ReductionType::PROD;
+  } else {
+    TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
+  }
+}
+
+// used for `scatter_reduce`, old options for BC.
+static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
+  if (use_new_options) {
+    return get_reduction_enum(reduce);
+  } else {
+    if (reduce == "add") {
+      return ReductionType::SUM;
+    } else if (reduce == "multiply") {
+      return ReductionType::PROD;
+    } else {
+      TORCH_CHECK(false, "reduce argument must be either add or multiply.")
+    }
+  }
+}
+
+} // at::native
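Once the reduce string has been mapped to ReductionType, kernels typically switch on the enum to pick a combine step (MEAN is usually accumulated like SUM and divided by the count afterwards). A hedged standalone sketch of that kind of dispatch; the combine() helper is illustrative and not part of the header:

#include <algorithm>
#include <cassert>

enum class ReductionType { MAX, MEAN, MIN, SUM, PROD };

// Combine an accumulator with one more value under the chosen reduction.
double combine(ReductionType op, double acc, double x) {
  switch (op) {
    case ReductionType::MAX:  return std::max(acc, x);
    case ReductionType::MIN:  return std::min(acc, x);
    case ReductionType::SUM:
    case ReductionType::MEAN: return acc + x;   // MEAN divides by the count at the end
    case ReductionType::PROD: return acc * x;
  }
  return acc;  // unreachable
}

int main() {
  assert(combine(ReductionType::MAX, 2.0, 5.0) == 5.0);
  assert(combine(ReductionType::PROD, 2.0, 5.0) == 10.0);
}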
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Repeat.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Repeat.h
new file mode 100644
index 0000000000000000000000000000000000000000..c3a81f0fba67747235a68fa13e8a2cc6d539b09f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Repeat.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::native {
+
+template <
+    typename index_t,
+    void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>
+static inline Tensor repeat_interleave_common(
+    const Tensor& repeats,
+    c10::optional<int64_t> output_size) {
+  TORCH_CHECK(
+      repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
+  TORCH_CHECK(
+      repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
+      "repeats has to be Long or Int tensor");
+  if (repeats.size(0) == 0) {
+    return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  Tensor repeats_ = repeats.contiguous();
+  Tensor cumsum = repeats.cumsum(0);
+  int64_t total;
+  if (output_size.has_value()) {
+    total = output_size.value();
+  } else {
+    total = cumsum[-1].item<int64_t>();
+    TORCH_CHECK(
+        (repeats >= 0).all().item<bool>(), "repeats can not be negative");
+  }
+
+  Tensor result = at::empty({total}, repeats.options());
+  index_t* repeat_ptr = repeats_.data_ptr<index_t>();
+  int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();
+  index_t* result_ptr = result.data_ptr<index_t>();
+  compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
+  return result;
+}
+
+} // namespace at::native
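repeat_interleave_common drives the expansion off the cumulative sum of the repeats: cumsum[i] marks where the run of copies of index i ends in the output. A standalone sketch of that expansion on plain std::vector data (the backend-specific `compute` callback is not shown; names here are illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> repeat_interleave(const std::vector<int64_t>& repeats) {
  // cumsum[i] is the exclusive end of the run of copies of index i.
  std::vector<int64_t> cumsum(repeats.size());
  int64_t running = 0;
  for (size_t i = 0; i < repeats.size(); ++i) {
    running += repeats[i];
    cumsum[i] = running;
  }
  std::vector<int64_t> result(running);
  int64_t start = 0;
  for (size_t i = 0; i < repeats.size(); ++i) {
    for (int64_t j = start; j < cumsum[i]; ++j) result[j] = static_cast<int64_t>(i);
    start = cumsum[i];
  }
  return result;
}

int main() {
  // repeats = [2, 0, 3] -> output indices [0, 0, 2, 2, 2]
  assert((repeat_interleave({2, 0, 3}) == std::vector<int64_t>{0, 0, 2, 2, 2}));
}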
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Resize.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Resize.h
new file mode 100644
index 0000000000000000000000000000000000000000..11aba9b4087f20b8596bae9e718829502d900867
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Resize.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+
+
+namespace at::native {
+
+// TODO: make all operations that resize given outputs use this function
+//   for consistency and maintainability.
+//   Some operations like `cat` might not be able to make the use of
+//   resize_output directly. For more details to understand how it works in `cat`,
+//   see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
+// Resizes outputs
+// Functions accepting output tensors, like with the "out" kwarg, should
+//   call this function to handle resizing their output tensor.
+// Issues a warning if the output tensor has one or more elements and
+//   needs resizing
+// NOTE: In the future the warning will become an error
+// Returns a bool saying whether or not the resize actually happened or not
+TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
+// WARNING: Do NOT call this directly. If you are resizing an output and want
+// to support dynamic shapes call at::resize__symint and resize_output_check_symint.
+// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
+TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
+
+// Utility for resize_output
+//  Returns a bool saying resize should happen or not and
+//  raises a warning if resizing for one or more elements
+TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
+TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
+
+TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
+TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
+TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes);
+
+static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
+  // It does not make sense to try to resize a storage
+  // to hold 0 elements, and this can break
+  // if storage_offset is positive but
+  // new_size is 0, so just bail in that case
+  // (same comment is in cuda/Resize.h)
+  if (self->numel() == 0) {
+    return;
+  }
+
+  const Storage& storage = self->unsafe_storage();
+  if (!storage) {
+    auto new_storage = c10::make_intrusive<StorageImpl>(
+        StorageImpl::use_byte_size_t(),
+        new_size_bytes,
+        c10::GetCPUAllocator(),
+        true);
+    self->set_storage_keep_dtype(std::move(new_storage));
+  } else if (new_size_bytes > storage.nbytes()) {
+    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
+  }
+}
+
+TORCH_API TensorImpl* resize_impl_cpu_(
+    TensorImpl* self,
+    IntArrayRef size,
+    at::OptionalIntArrayRef stride,
+    bool resize_storage = true);
+
+template <typename T>
+T maybe_convert_symint(c10::SymInt) = delete;
+
+template <>
+inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
+
+template <>
+inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
+
+template <typename T>
+static inline void checkInBoundsForStorage(
+    ArrayRef<T> size,
+    ArrayRef<T> stride,
+    T storage_offset,
+    const caffe2::TypeMeta& data_type,
+    const Storage& new_storage) {
+  T storage_size_bytes =
+      at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
+  T storage_offset_bytes = storage_offset * data_type.itemsize();
+  if (storage_size_bytes == 0) {
+    // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
+    return;
+  }
+  T new_storage_size_bytes = maybe_convert_symint(new_storage.sym_nbytes());
+  TORCH_CHECK(
+      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
+      "setStorage: sizes ",
+      size,
+      ", strides ",
+      stride,
+      ","
+      " storage offset ",
+      storage_offset,
+      ", and itemsize ",
+      data_type.itemsize(),
+      " requiring a storage size of ",
+      storage_size_bytes + storage_offset_bytes,
+      " are out of bounds for storage of size ",
+      new_storage_size_bytes);
+}
+
+template <typename T>
+static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
+                                   ArrayRef<T> size, ArrayRef<T> stride) {
+  // FIXME: stride should be optional
+  if (stride.data()) {
+    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
+                                              ") and stride length (", stride.size(), ")");
+  }
+
+#ifdef DEBUG
+  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
+#endif
+
+  // storage: note this can't be replaced with result.set_(storage) as the semantics of that
+  // function is to set the tensor size to be equal to the size of the storage.
+  if (!result.storage().is_alias_of(storage)) {
+    // Caffe2 might have tensors whose storages are null, but we
+    // don't allow it in PyTorch.
+    TORCH_INTERNAL_ASSERT(storage);
+    TORCH_INTERNAL_ASSERT(result.storage());
+
+    // We used to allow this, but this breaks device caching.
+    // Let's put an actual error message for this one.
+    TORCH_CHECK(result.storage().device() == storage.device(),
+                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
+                "\" to a storage on different device \"", storage.device(),
+                "\".  This is no longer allowed; the devices must match.");
+    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
+  }
+
+  // storageOffset
+  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
+}
+
+/**
+ * Set self's sizes, strides, and storage_offset.
+ * (size, stride, storage_offset) must be in bounds for self's storage.
+ */
+template <typename T>
+inline void setStrided(
+    const Tensor& self,
+    ArrayRef<T> size,
+    ArrayRef<T> stride,
+    T storage_offset) {
+  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
+  for (const auto& val : stride) {
+    TORCH_CHECK(val >= 0,
+                "as_strided: Negative strides are not supported at the moment, "
+                "got strides: ", stride);
+  }
+
+  auto* self_ = self.unsafeGetTensorImpl();
+  checkInBoundsForStorage(
+      size, stride, storage_offset, self_->dtype(), self_->storage());
+
+  /* storage offset */
+  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
+  self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
+}
+
+} // namespace at::native
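checkInBoundsForStorage boils down to simple arithmetic: a view needs (1 + sum over dims of (size-1)*stride) elements of storage past its offset, scaled by itemsize, and a view with any zero-sized dim needs nothing at all. A standalone sketch of that bound check on plain vectors; function names are illustrative, not part of the header:

#include <cassert>
#include <cstdint>
#include <vector>

int64_t storage_nbytes(const std::vector<int64_t>& sizes,
                       const std::vector<int64_t>& strides,
                       int64_t itemsize) {
  int64_t max_index = 0;             // linear index of the farthest reachable element
  for (size_t d = 0; d < sizes.size(); ++d) {
    if (sizes[d] == 0) return 0;     // any zero dim -> no elements needed
    max_index += (sizes[d] - 1) * strides[d];
  }
  return (max_index + 1) * itemsize;
}

bool fits_in_storage(const std::vector<int64_t>& sizes,
                     const std::vector<int64_t>& strides,
                     int64_t storage_offset, int64_t itemsize,
                     int64_t new_storage_nbytes) {
  int64_t needed = storage_nbytes(sizes, strides, itemsize);
  if (needed == 0) return true;      // empty views fit any storage
  return needed + storage_offset * itemsize <= new_storage_nbytes;
}

int main() {
  // A 2x3 float tensor with strides {3, 1} and offset 0 needs 6 * 4 = 24 bytes.
  assert(fits_in_storage({2, 3}, {3, 1}, 0, 4, 24));
  assert(!fits_in_storage({2, 3}, {3, 1}, 1, 4, 24));  // the offset pushes it out of bounds
}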
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa836fac7b06db176d1712f6a5040d1598651a7b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ResizeCommon.h
@@ -0,0 +1,75 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+template <typename T>
+inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
+      "storage_size_for(size, stride) requires that size and stride ",
+      "have the same size as a precondition.");
+  T storage_size = 1;
+  for (const auto dim : c10::irange(size.size())) {
+    if (size[dim] == 0) {
+      storage_size = 0;
+      break;
+    }
+    storage_size += (size[dim] - 1) * stride[dim];
+  }
+  return storage_size;
+}
+
+inline const Tensor& resize_named_tensor_(
+    const Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  TORCH_INTERNAL_ASSERT(self.has_names());
+  TORCH_CHECK(
+      self.sizes() == size,
+      "Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
+      "Tensor",
+      self.names(),
+      " with size ",
+      self.sizes(),
+      " to ",
+      size,
+      "). This may be caused by passing a named tensor ",
+      "as an `out=` argument; please ensure that the sizes are the same. ");
+  TORCH_CHECK(
+      !optional_memory_format.has_value(),
+      "Unsupported memory format for named tensor resize ",
+      optional_memory_format.value());
+  return self;
+}
+
+// For deterministic output, fill new elements that were added after a storage
+// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
+// before the resize happened.
+inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
+  const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
+  int64_t new_storage_nbytes = storage.nbytes();
+  int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
+  int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
+  if (new_storage_numel > old_storage_numel) {
+    at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
+    tensor_view.set_(
+      storage,
+      /*storage_offset=*/old_storage_numel,
+      /*size=*/{new_storage_numel - old_storage_numel},
+      /*stride=*/{1});
+    at::native::fill_empty_deterministic_(tensor_view);
+  }
+  return tensor;
+}
+
+} // namespace at::native
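fill_resize_deterministic_ exists so that, in deterministic mode, memory newly exposed by a storage resize is overwritten instead of leaking whatever bytes happened to be there. A standalone sketch of the same idea on a plain buffer, assuming a float payload filled with NaN (the real helper delegates the fill value choice to fill_empty_deterministic_):

#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

void grow_and_fill_nan(std::vector<float>& buf, size_t new_numel) {
  size_t old_numel = buf.size();
  buf.resize(new_numel);
  for (size_t i = old_numel; i < new_numel; ++i)
    buf[i] = std::numeric_limits<float>::quiet_NaN();  // only the newly added tail
}

int main() {
  std::vector<float> buf = {1.f, 2.f};
  grow_and_fill_nan(buf, 4);
  assert(buf[0] == 1.f && buf[1] == 2.f);              // existing data untouched
  assert(std::isnan(buf[2]) && std::isnan(buf[3]));
}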
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h b/MLPY/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b8a3a81abe3eca3cded51cb855366d388469441
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
@@ -0,0 +1,128 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+namespace {
+
+// checks whether index.dtype == int64
+// and self.dtype == src.dtype if src is a Tensor
+static void scatter_gather_dtype_check(
+  const std::string& method_name,
+  const Tensor& self,
+  const Tensor& index,
+  const c10::optional<Tensor>& src_opt = c10::nullopt
+) {
+  if (index.numel() != 0) {
+    TORCH_CHECK(
+      index.scalar_type() == at::ScalarType::Long,
+      method_name, "(): Expected dtype int64 for index"
+    );
+  }
+
+  if (src_opt.has_value()) {
+    const auto& src = src_opt.value();
+    TORCH_CHECK(
+      self.scalar_type() == src.scalar_type(),
+      method_name, "(): Expected self.dtype to be equal to src.dtype"
+    );
+  }
+}
+
+// Used for `gather`-like methods
+// Note: self means the input tensor here
+// Test:
+// 1. index.size(d) <= self.size(d) for all d != dim
+// 2. index.dim() == self.dim()
+static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
+  const Tensor& index
+) {
+  auto self_dims = ensure_nonempty_dim(self.dim());
+  TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
+    "Index tensor must have the same number of dimensions as input tensor"
+  );
+
+  for (const auto i : c10::irange(self_dims)) {
+    if (i != dim) {
+      TORCH_CHECK(
+        ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
+        "Size does not match at dimension ", i,
+        " expected index ", index.sizes(),
+        " to be smaller than self ", self.sizes(),
+        " apart from dimension ", dim
+      );
+    }
+  }
+}
+
+// Used for `scatter` and `scatter_add`
+// Tests:
+//  1. index.size(d) <= self.size(d) for all d != dim
+//  2. index.size(d) <= src.size(d) for all d if src is a Tensor
+//  3. index.dim() == self.dim() == src.dim()
+static C10_UNUSED void scatter_shape_check(
+  const Tensor& self, int64_t dim, const Tensor& index,
+  const c10::optional<Tensor>& src_opt = c10::nullopt
+) {
+  if (index.numel() == 0) return;
+  TORCH_CHECK(
+    ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
+    "Index tensor must have the same number of dimensions as self tensor"
+  );
+
+  bool is_wrong_shape = false;
+  int64_t self_dims = ensure_nonempty_dim(self.dim());
+
+  //  Check: index.size(d) <= self.size(d) for all d != dim
+  for (const auto d : c10::irange(self_dims)) {
+    int64_t index_d_size = ensure_nonempty_size(index, d);
+    if (d == dim) continue;
+    if (index_d_size > ensure_nonempty_size(self, d)) {
+      is_wrong_shape = true;
+      break;
+    }
+  }
+
+  //  Check: index.size(d) <= src.size(d) for all d if src is Tensor
+  if (!is_wrong_shape && src_opt.has_value()) {
+    const auto& src = src_opt.value();
+    for (const auto d : c10::irange(self_dims)) {
+      int64_t index_d_size = ensure_nonempty_size(index, d);
+      if (index_d_size > ensure_nonempty_size(src, d)) {
+        is_wrong_shape = true;
+        break;
+      }
+    }
+  }
+
+  if (src_opt.has_value()) {
+    const auto& src = src_opt.value();
+
+    TORCH_CHECK(
+      ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
+      "Index tensor must have the same number of dimensions as src tensor"
+    );
+
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be smaller than self ", self.sizes(),
+      " apart from dimension ", dim,
+      " and to be smaller size than src ", src.sizes()
+    );
+  }
+  else {
+    TORCH_CHECK(!is_wrong_shape,
+      "Expected index ", index.sizes(),
+      " to be smaller than self ", self.sizes(),
+      " apart from dimension ", dim
+    );
+  }
+}
+
+} // anonymous namespace
+
+} // namespace at::native
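gather_shape_check reduces to two conditions: index must have the same rank as self, and may not exceed self's size on any dimension other than dim. A standalone sketch of that rule on plain shape vectors; names are illustrative, not part of the header:

#include <cassert>
#include <cstdint>
#include <vector>

bool gather_shapes_ok(const std::vector<int64_t>& self,
                      const std::vector<int64_t>& index, int64_t dim) {
  if (self.size() != index.size()) return false;   // ranks must match
  for (size_t d = 0; d < self.size(); ++d) {
    if (static_cast<int64_t>(d) == dim) continue;  // the gathered dim may be any size
    if (index[d] > self[d]) return false;
  }
  return true;
}

int main() {
  assert(gather_shapes_ok({4, 5}, {4, 3}, /*dim=*/1));   // ok: 4 <= 4 on dim 0
  assert(!gather_shapes_ok({4, 5}, {6, 3}, /*dim=*/1));  // 6 > 4 on dim 0
}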
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h
new file mode 100644
index 0000000000000000000000000000000000000000..20c251cfd7581f5861e969880a54a6459dfca06b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SegmentReduce.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using segment_reduce_lengths_fn = Tensor (*)(
+    ReductionType,
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
+
+using segment_reduce_offsets_fn = Tensor (*)(
+    ReductionType,
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
+
+using segment_reduce_lengths_backward_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    ReductionType,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
+
+using segment_reduce_offsets_backward_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    ReductionType,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
+
+} // namespace native
+} // namespace at
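The stubs above dispatch to kernels that reduce contiguous runs of the input, with the runs described either by per-segment lengths or by offsets. A standalone sketch of the lengths-based case for a 1-D input and a sum reduction, just to pin down the semantics (names are illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<double> segment_sum(const std::vector<double>& data,
                                const std::vector<int64_t>& lengths) {
  std::vector<double> out;
  size_t pos = 0;
  for (int64_t len : lengths) {
    double acc = 0.0;
    for (int64_t i = 0; i < len; ++i) acc += data[pos++];  // one contiguous segment
    out.push_back(acc);
  }
  return out;
}

int main() {
  // lengths [2, 3] split [1, 2, 3, 4, 5] into {1, 2} and {3, 4, 5}.
  assert((segment_sum({1, 2, 3, 4, 5}, {2, 3}) == std::vector<double>{3, 12}));
}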
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..d619b05237acac6634544297b43fcf24d807ab6c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SharedReduceOps.h
@@ -0,0 +1,544 @@
+#pragma once
+// Please note that this file is
+// used across both CPU and GPU.
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#if defined(__CUDACC__)
+#include 
+#include 
+#elif defined(__HIPCC__)
+#include 
+#include 
+#endif
+#if defined(__CUDACC__) || defined(__HIPCC__)
+#include 
+#else
+#include 
+#define device_sqrt std::sqrt
+#endif
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <typename scalar_t>
+inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
+#if defined(__HIPCC__)
+  // TODO: remove this special case for HIP when issue is fixed:
+  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
+  scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
+#else
+  scalar_t max = at::_isnan(b) ? b : std::max(a, b);
+#endif
+  return max;
+}
+template <typename scalar_t>
+inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
+#if defined(__HIPCC__)
+  // TODO: remove this special case for HIP when issue is fixed:
+  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
+  scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
+#else
+  scalar_t min = at::_isnan(b) ? b : std::min(a, b);
+#endif
+  return min;
+}
+#define MAX(X, Y) max_propagate_nan(X,Y)
+#define MIN(X, Y) min_propagate_nan(X,Y)
+#else
+#include 
+#define MAX(X, Y) max_impl(X,Y)
+#define MIN(X, Y) min_impl(X,Y)
+#endif
+
+// ROCM hcc doesn't work well with using std:: in kernel functions
+#if defined(__CUDA_ARCH__)
+#include 
+#define compat_pow c10::cuda::compat::pow
+#elif defined(__HIPCC__)
+#include 
+#define compat_pow c10::hip::compat::pow
+#else
+#define compat_pow std::pow
+#endif
+
+namespace at { namespace native {
+
+namespace detail {
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
+#else
+template <typename T1, typename T2> using pair = std::pair<T1, T2>;
+#endif
+
+} // namespace detail
+
+template <typename scalar_t, typename index_t>
+struct WelfordData {
+  scalar_t mean;
+  scalar_t m2;
+  index_t n;
+  scalar_t nf;
+
+  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
+
+  C10_HOST_DEVICE WelfordData(
+      scalar_t mean,
+      scalar_t m2,
+      index_t n,
+      scalar_t nf)
+      : mean(mean), m2(m2), n(n), nf(nf) {}
+};
+
+
+template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
+struct WelfordOps {
+  acc_scalar_t correction;
+  bool take_sqrt;
+ public:
+  using acc_t = WelfordData<acc_scalar_t, index_t>;
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
+    // We accumulate n in index_t to avoid cumulative rounding error, but still
+    // need nf for use in combine where int32 may overflow.
+    index_t new_n = acc.n + 1;
+    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
+    acc_scalar_t delta = data - acc.mean;
+    acc_scalar_t new_mean = acc.mean + delta / new_nf;
+    acc_scalar_t new_delta = data - new_mean;
+    return {
+      new_mean,
+      acc.m2 + delta * new_delta,
+      new_n,
+      new_nf,
+    };
+  }
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    if (a.nf == 0) {
+      return b;
+    }
+    if (b.nf == 0) {
+      return a;
+    }
+    acc_scalar_t delta = b.mean - a.mean;
+    acc_scalar_t new_count = a.nf + b.nf;
+    acc_scalar_t nb_over_n = b.nf / new_count;
+    return {
+      a.mean + delta * nb_over_n,
+      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
+      // setting acc.n as -1 since acc.n might not be able to represent the count
+      // correctly within its range, setting it to -1 to avoid confusion
+      -1,
+      new_count
+    };
+  }
+  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
+    const auto mean = static_cast(acc.mean);
+    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
+    const auto var = acc.m2 / divisor;
+    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
+    return results;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return {
+      WARP_SHFL_DOWN(acc.mean, offset)
+      , WARP_SHFL_DOWN(acc.m2, offset)
+      , WARP_SHFL_DOWN(acc.n, offset)
+      , WARP_SHFL_DOWN(acc.nf, offset)
+    };
+  }
+#endif
+  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
+      : correction(correction), take_sqrt(take_sqrt) {}
+};
+
+template <typename scalar_t, typename acc_t, typename factor_t, typename out_t = acc_t>
+struct MeanOps {
+  factor_t factor;
+
+  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
+    return combine(a, static_cast<acc_t>(b));
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return a * factor;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+    return WARP_SHFL_DOWN(data, offset);
+  }
+#endif
+
+  MeanOps(factor_t factor): factor(factor) {
+  }
+};
+
+// This accumulator template is used to calculate the minimum absolute value of
+// a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct AbsMinOps {
+
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    return MIN(acc, static_cast<acc_t>(std::abs(data)));
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return MIN(a, b);
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return a;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+};
+
+// This accumulator template is used to calculate the maximum absolute value of
+// a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct AbsMaxOps {
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    return MAX(acc, static_cast<acc_t>(std::abs(data)));
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return MAX(a, b);
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return a;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+};
+
+// This accumulator template is used to calculate the norm of the absolute value
+// of a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct NormOps {
+  acc_t norm_;
+
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+
+  NormOps(acc_t norm_): norm_(norm_) {
+  }
+};
+
+// This accumulator template is used to calculate the order zero norm of the
+// absolute value of a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct NormZeroOps {
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return a;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+};
+
+// This accumulator template is used to calculate the order one norm of the
+// absolute value of a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct NormOneOps {
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    return acc + static_cast<acc_t>(std::abs(data));
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return a;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+};
+
+
+template <typename acc_t>
+struct AbsSwitch {};
+
+template <typename scalar_t, typename acc_t>
+inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
+  return static_cast<acc_t>(data);
+}
+
+template <typename scalar_t, typename acc_t>
+inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
+  return static_cast<acc_t>(std::abs(data));
+}
+
+template <typename scalar_t, typename acc_t>
+inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
+  return static_cast<acc_t>(std::abs(data));
+}
+
+// This accumulator template is used to calculate the order two norm of the
+// absolute value of a set of numbers.
+// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
+// value. These types differ for complex number input support.
+template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
+struct NormTwoOps {
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
+    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
+    return acc + data_ * data_;
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return a + b;
+  }
+
+  inline C10_DEVICE out_t project(acc_t a) const {
+    return device_sqrt(a);
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return WARP_SHFL_DOWN(acc, offset);
+  }
+#endif
+};
+
+template <typename acc_t, typename data_t>
+struct NanSumOps {
+  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
+    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    return  a + b;
+  }
+
+  inline C10_DEVICE data_t project(acc_t a) const {
+    return data_t{a};
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
+    return WARP_SHFL_DOWN(data, offset);
+  }
+#endif
+};
+
+namespace detail {
+
+template <typename scalar_t>
+struct LessOrNan {
+  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
+    // If (a == b), then choose the one with lower idx, else min(a, b)
+    if (at::_isnan(a)) {
+      if (at::_isnan(b)) {
+        return idx_a < idx_b;
+      }
+      return true;
+    }
+    return (a == b) ? idx_a < idx_b : (a < b);
+  }
+};
+
+template <typename scalar_t>
+struct GreaterOrNan {
+  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
+    // If (a == b), then choose the one with lower idx, else max(a, b)
+    if (at::_isnan(a)) {
+      if (at::_isnan(b)) {
+        return idx_a < idx_b;
+      }
+      return true;
+    }
+    return (a == b) ? idx_a < idx_b : (a > b);
+  }
+};
+
+template <typename comp_t>
+struct MinMaxReductionOps {
+  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
+  using index_t = int64_t;
+  using arg_t = detail::pair<scalar_t, index_t>;
+
+  static C10_DEVICE arg_t project(arg_t arg) {
+    return arg;
+  }
+
+  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
+    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
+  }
+
+  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
+    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
+  }
+
+  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
+    return {a.first, a.second + base_idx};
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
+    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
+                 WARP_SHFL_DOWN(arg.second, offset));
+  }
+#endif
+};
+
+template <typename comp_t>
+struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
+  using typename MinMaxReductionOps<comp_t>::scalar_t;
+  using typename MinMaxReductionOps<comp_t>::index_t;
+  using typename MinMaxReductionOps<comp_t>::arg_t;
+
+  static C10_DEVICE index_t project(arg_t arg) {
+    return arg.second;
+  }
+};
+
+} // namespace detail
+
+template <typename scalar_t>
+struct ArgMaxOps :
+  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
+};
+
+template <typename scalar_t>
+struct ArgMinOps :
+  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
+};
+
+template <typename scalar_t>
+struct MinOps :
+  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
+};
+
+template <typename scalar_t>
+struct MaxOps :
+  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
+};
+
+template <typename scalar_t, typename acc_scalar_t, typename index_t>
+struct MinMaxOps {
+  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
+  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
+    return combine(acc, {data, data});
+  }
+
+  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
+    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
+    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;
+
+    return {min_val, max_val};
+  }
+
+  inline C10_DEVICE acc_t project(acc_t acc) const {
+    return acc;
+  }
+
+  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
+    return acc;
+  }
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
+    return {
+      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
+    };
+  }
+#endif
+};
+
+}} // namespace at::native
+
+#undef MAX
+#undef MIN
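WelfordOps above implements Welford's streaming mean/variance update plus a pairwise combine, which is what lets the reduction be split across threads or warps and merged afterwards. A standalone sketch of the same recurrences on doubles, checked against the closed-form result; this is an illustration, not the vendored implementation:

#include <cassert>
#include <cmath>

struct Welford { double mean = 0, m2 = 0, n = 0; };

// Fold one sample into a running (mean, m2, n) triple.
Welford reduce(Welford a, double x) {
  double n = a.n + 1;
  double delta = x - a.mean;
  double mean = a.mean + delta / n;
  return {mean, a.m2 + delta * (x - mean), n};
}

// Merge two partial results (the parallel/warp-shuffle path).
Welford combine(Welford a, Welford b) {
  if (a.n == 0) return b;
  if (b.n == 0) return a;
  double n = a.n + b.n;
  double delta = b.mean - a.mean;
  return {a.mean + delta * b.n / n,
          a.m2 + b.m2 + delta * delta * a.n * b.n / n, n};
}

int main() {
  Welford left, right;
  for (double x : {1.0, 2.0}) left = reduce(left, x);
  for (double x : {3.0, 4.0}) right = reduce(right, x);
  Welford all = combine(left, right);
  assert(std::abs(all.mean - 2.5) < 1e-12);
  assert(std::abs(all.m2 / all.n - 1.25) < 1e-12);  // population variance of {1,2,3,4}
}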
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..3fbc4e311dea531b7d0a2501dad0685671f8a1b2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h
@@ -0,0 +1,55 @@
+/// This file contains some tensor-agnostic operations to be used in the
+/// core functions of the `SobolEngine`
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#include 
+#endif
+
+namespace at::native::sobol_utils {
+
+/// Function to return the minimum of number of bits to represent the integer `n`
+inline int64_t bit_length(const int64_t n) {
+  int64_t nbits, nloc;
+  for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++);
+  return nbits;
+}
+
+/// Function to get the position of the rightmost zero in the bit representation of an integer
+/// This value is the zero-indexed position
+inline int64_t rightmost_zero(const int64_t n) {
+  int64_t z, i;
+  for (z = n, i = 0; z % 2 == 1; z /= 2, i++);
+  return i;
+}
+
+/// Function to get a subsequence of bits in the representation of an integer starting from
+/// `pos` and of length `length`
+inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
+  return (n >> pos) & ((1 << length) - 1);
+}
+
+/// Function to perform the inner product between a batched square matrix and a power of 2 vector
+inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
+  at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
+  inter = at::pow(2, inter).expand_as(bmat);
+  return at::mul(inter, bmat).sum(-1);
+}
+
+/// All definitions below this point are data. These are constant, and should not be modified
+/// without notice
+
+constexpr int64_t MAXDIM = 21201;
+constexpr int64_t MAXDEG = 18;
+constexpr int64_t MAXBIT = 30;
+constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT;
+constexpr float RECIPD = 1.0 / LARGEST_NUMBER;
+
+extern const int64_t poly[MAXDIM];
+extern const int64_t initsobolstate[MAXDIM][MAXDEG];
+
+} // namespace at::native::sobol_utils
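The three bit helpers above follow the usual Sobol conventions: bit_length counts significant bits, rightmost_zero returns the zero-indexed position of the lowest zero bit (i.e. it counts trailing ones), and bitsubseq extracts `length` bits starting at `pos`. A standalone sketch with concrete values, mirroring the header's loops:

#include <cassert>
#include <cstdint>

int64_t bit_length(int64_t n) {
  int64_t bits = 0;
  for (; n > 0; n /= 2) ++bits;
  return bits;
}

int64_t rightmost_zero(int64_t n) {
  int64_t i = 0;
  for (; n % 2 == 1; n /= 2) ++i;     // count trailing ones
  return i;
}

int64_t bitsubseq(int64_t n, int64_t pos, int64_t length) {
  return (n >> pos) & ((1 << length) - 1);
}

int main() {
  assert(bit_length(6) == 3);                  // 0b110 needs 3 bits
  assert(rightmost_zero(7) == 3);              // 0b111 -> first zero at position 3
  assert(bitsubseq(0b110110, 1, 3) == 0b011);  // bits 1..3 of 0b110110
}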
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Sorting.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Sorting.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2126bd083d7ae7496ef06190557f17421708cdb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Sorting.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
+  LINEAR,
+  LOWER,
+  HIGHER,
+  MIDPOINT,
+  NEAREST
+};
+
+using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
+using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);
+
+DECLARE_DISPATCH(sort_fn, sort_stub);
+DECLARE_DISPATCH(topk_fn, topk_stub);
+
+void _fill_indices(const TensorBase &indices, int64_t dim);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SortingUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SortingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..cb9e3d37c6768e08cc091d2ce8c7efed04a8a2cc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SortingUtils.h
@@ -0,0 +1,88 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+// ensure we get good values and indices for kthvalue, mode
+// this will always be with the reducing dim as 1-d
+inline void _reduction_with_indices_allocate_or_resize_output(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim_,
+    bool keepdim) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  auto result_sizes = self.sizes().vec();
+  if (!result_sizes.empty()) {
+    result_sizes[dim] = 1;
+  }
+  if (values.defined()) {
+    TORCH_CHECK(
+        self.options().type_equal(values.options()),
+        "output values must be of same type as input");
+    if (!keepdim && values.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      values.unsqueeze_(dim);
+    }
+    resize_output(values, result_sizes);
+  } else {
+    values = at::empty(result_sizes, self.options());
+  }
+  if (indices.defined()) {
+    TORCH_CHECK(
+        indices.dtype() == kLong, "output indices must be of scalar type Long");
+    TORCH_CHECK(
+        indices.device() == self.device(),
+        "output indices must be on same device as input");
+    if (!keepdim && indices.dim() == self.dim() - 1) {
+      // unsqueeze to preserve passed in noncontiguous tensor in resize
+      indices.unsqueeze_(dim);
+    }
+    resize_output(indices, result_sizes);
+  } else {
+    indices = at::empty(result_sizes, self.options().dtype(kLong));
+  }
+}
+
+// ensure we get good values and indices for topk
+inline void _allocate_or_resize_output_with_indices(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& self,
+    int64_t dim_,
+    int64_t k) {
+  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
+  auto result_sizes = self.sizes().vec();
+  if (!result_sizes.empty()) {
+    result_sizes[dim] = k;
+  }
+  if (values.defined()) {
+    TORCH_CHECK(
+        self.options().type_equal(values.options()),
+        "output values must be of same type as input");
+    values.resize_(result_sizes);
+  } else {
+    values = at::empty(result_sizes, self.options());
+  }
+  if (indices.defined()) {
+    TORCH_CHECK(
+        indices.dtype() == kLong, "output indices must be of scalar type Long");
+    TORCH_CHECK(
+        indices.device() == self.device(),
+        "output indices must be on same device as input");
+    indices.resize_(result_sizes);
+  } else {
+    indices = at::empty(result_sizes, self.options().dtype(kLong));
+  }
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..f44d51b352eebed86c0743eca4446842a1b65ca3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SparseTensorUtils.h
@@ -0,0 +1,190 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#endif
+
+namespace at::sparse {
+
+// Just for documentary purposes
+using SparseTensor = Tensor;
+using SparseType = Type;
+
+// This is an internal utility function for getting at the SparseTensorImpl,
+// so that we can write sparse tensor specific accessors for special fields
+// in SparseTensor.  You should only use this for writing low level
+// setters/getters for SparseTensorImpl fields; otherwise, you should use
+// the low level setters/getters that were implemented using this.
+//
+// This may be called repeatedly, so make sure it's pretty cheap.
+inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
+  TORCH_INTERNAL_ASSERT(
+      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
+  return static_cast(self.unsafeGetTensorImpl());
+}
+
+// Takes indices and values and directly puts them into the sparse tensor, no
+// copy.  This used to be called THSTensor_(_move)
+inline void alias_into_sparse(
+    const SparseTensor& self,
+    const Tensor& indices,
+    const Tensor& values) {
+  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
+}
+
+// Take indices and values and makes a (data) copy of them to put into the
+// sparse indices/values.  This used to be called THSTensor_(_set)
+inline void copy_into_sparse(
+    const SparseTensor& self,
+    const Tensor& indices,
+    const Tensor& values,
+    bool non_blocking) {
+  alias_into_sparse(
+      self,
+      indices.to(self._indices().options(), non_blocking, /*copy=*/true),
+      values.to(self._values().options(), non_blocking, /*copy=*/true));
+}
+
+// TODO: put this into the public API
+inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
+  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
+}
+
+inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
+  return self.sparse_dim() == src.sparse_dim() &&
+      self.dense_dim() == src.dense_dim();
+}
+
+// Give us a new values tensor, with the same dimensionality
+// as 'values' but with a new number of non-zero elements.
+// TODO: Expose this for real in ATen, some day?
+// NB: Doesn't preserve data.
+inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
+  std::vector size = values.sizes().vec();
+  size[0] = nnz;
+  return at::empty(size, values.options());
+}
+
+// NOTE [ Flatten Sparse Indices ]
+// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
+// indices tensor. E.g.,
+//   input = [[2, 4, 0],
+//            [3, 1, 10]]
+//   full_size = [2, 12]
+//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
+//
+// In other words, assuming that each `indices[i, :]` is a valid index into a
+// tensor `t` of shape `full_size`, this returns the corresponding indices into
+// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
+// If force_clone is true, the result will be forced to be a clone of self.
+TORCH_API Tensor flatten_indices(
+    const Tensor& indices,
+    IntArrayRef full_size,
+    bool force_clone = false);
+
+// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
+// Sparse Indices ], except this one allows partial flatten: only flatten on
+// specified dims. Note that the flatten indices might be uncoalesced if
+// dims_to_flatten.size() < sparse_dim. Also if input indices is already
+// coalesced, the flattened indices will also be sorted.
+//
+// args:
+//    indices: sparse tensor indices
+//    sizes: sparse tensor sizes
+//    dims_to_flatten: a list of dim index to flatten
+//
+// Ex1:
+//   indices = [[2, 4, 0],
+//             [3, 1, 3]]
+//   sizes = [2, 12]
+//   dims_to_flatten = [0, 1]
+//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
+//
+// Ex2:
+//   dims_to_flatten = [1]
+//   new_indices = [ 3, 1, 3 ]  # uncoalesced
+TORCH_API Tensor flatten_indices_by_dims(
+    const Tensor& indices,
+    const IntArrayRef& sizes,
+    const IntArrayRef& dims_to_flatten);
+
+// Find the CSR representation for a row `indices` from the COO format
+TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
+
+TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
+
+template <size_t static_shape_max_len>
+class TensorGeometryHolder {
+  using geometry_holder_t = std::array<int64_t, static_shape_max_len>;
+
+ public:
+  explicit TensorGeometryHolder(
+      IntArrayRef sizes,
+      IntArrayRef strides,
+      TensorOptions options = {}) {
+    std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
+    std::copy(strides.begin(), strides.end(), t_strides.begin());
+  }
+
+  explicit TensorGeometryHolder(const Tensor& t)
+      : TensorGeometryHolder(t.sizes(), t.strides()) {}
+
+  auto operator*() const {
+    return std::make_tuple(t_sizes, t_strides);
+  }
+
+ private:
+  geometry_holder_t t_sizes;
+  geometry_holder_t t_strides;
+};
+
+template <>
+class TensorGeometryHolder<0> {
+  using geometry_holder_t = Tensor;
+
+ public:
+  explicit TensorGeometryHolder(
+      IntArrayRef sizes,
+      IntArrayRef strides,
+      TensorOptions options) {
+    const int64_t t_ndims = sizes.size();
+    const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
+    Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
+    t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
+    t_sizes_and_strides_cpu.select(0, 1).copy_(
+        at::tensor(strides, cpu_options));
+    const Tensor t_sizes_and_strides =
+        t_sizes_and_strides_cpu.to(options.device());
+    t_sizes = t_sizes_and_strides.select(0, 0);
+    t_strides = t_sizes_and_strides.select(0, 1);
+  }
+
+  explicit TensorGeometryHolder(const Tensor& t)
+      : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}
+
+  auto operator*() const {
+    return std::make_tuple(
+        t_sizes.template data_ptr<int64_t>(),
+        t_strides.template data_ptr<int64_t>());
+  }
+
+ private:
+  geometry_holder_t t_sizes;
+  geometry_holder_t t_strides;
+};
+
+// Return all indices of a tensor with the given shape.
+//
+// full_coo_indices(shape) is equivalent to
+// torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
+TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
+
+} // namespace at::sparse
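flatten_indices is plain row-major linearization of the sparse COO indices against full_size, exactly as described in NOTE [ Flatten Sparse Indices ]. A standalone sketch on std::vector data reproducing the numbers from that note (names are illustrative, not the TORCH_API function declared above):

#include <cassert>
#include <cstdint>
#include <vector>

// indices is sparse_dim x nnz (row major here); full_size holds the size of
// each sparse dimension.
std::vector<int64_t> flatten_coo_indices(const std::vector<std::vector<int64_t>>& indices,
                                         const std::vector<int64_t>& full_size) {
  size_t nnz = indices.empty() ? 0 : indices[0].size();
  std::vector<int64_t> flat(nnz, 0);
  for (size_t d = 0; d < indices.size(); ++d) {
    for (size_t i = 0; i < nnz; ++i) {
      flat[i] = flat[i] * full_size[d] + indices[d][i];  // Horner-style row-major fold
    }
  }
  return flat;
}

int main() {
  // Same numbers as the NOTE: indices [[2,4,0],[3,1,10]], full_size [2,12] -> [27, 49, 10].
  std::vector<std::vector<int64_t>> idx = {{2, 4, 0}, {3, 1, 10}};
  assert((flatten_coo_indices(idx, {2, 12}) == std::vector<int64_t>{27, 49, 10}));
}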
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a35d87e522307c99b9d0b95dfcc46bd0a93fe00
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/SpectralOpsUtils.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+// Normalization types used in _fft_with_size
+enum class fft_norm_mode {
+  none,       // No normalization
+  by_root_n,  // Divide by sqrt(signal_size)
+  by_n,       // Divide by signal_size
+};
+
+// NOTE [ Fourier Transform Conjugate Symmetry ]
+//
+// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
+// assuming X is the transformed K-dimensional signal, we have
+//
+//     X[i_1, ..., i_K] = X[j_1, ..., j_K]*,
+//
+//       where j_k  = (N_k - i_k)  mod N_k, N_k being the signal size at dim k,
+//             * is the conjugate operator.
+//
+// Therefore, in such cases, FFT libraries return only roughly half of the
+// values to avoid redundancy:
+//
+//     X[:, :, ..., :floor(N / 2) + 1]
+//
+// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such a
+// halved signal will also be returned by default (flag onesided=True).
+// The following infer_ft_real_to_complex_onesided_size function calculates the
+// onesided size from the twosided size.
+//
+// Note that this loses some information about the size of the signal at the
+// last dimension. E.g., both 11 and 10 map to 6. Hence, the following
+// infer_ft_complex_to_real_onesided_size function takes an optional parameter
+// to infer the twosided size from the given onesided size.
+//
+// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional
+// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE
+
+inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) {
+  return (real_size / 2) + 1;
+}
+
+inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size,
+                                                      int64_t expected_size=-1) {
+  int64_t base = (complex_size - 1) * 2;
+  if (expected_size < 0) {
+    return base + 1;
+  } else if (base == expected_size) {
+    return base;
+  } else if (base + 1 == expected_size) {
+    return base + 1;
+  } else {
+    std::ostringstream ss;
+    ss << "expected real signal size " << expected_size << " is incompatible "
+       << "with onesided complex frequency size " << complex_size;
+    AT_ERROR(ss.str());
+  }
+}
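+
+// Illustration only: both an even (10) and an odd (11) real signal map onto 6
+// one-sided complex values, so recovering the real size needs the optional
+// expected_size hint. The helper below is a hypothetical sanity-check sketch.
+inline void onesided_size_example() {
+  TORCH_INTERNAL_ASSERT(infer_ft_real_to_complex_onesided_size(10) == 6);
+  TORCH_INTERNAL_ASSERT(infer_ft_real_to_complex_onesided_size(11) == 6);
+  // Without a hint, the odd candidate 2 * (6 - 1) + 1 == 11 is returned.
+  TORCH_INTERNAL_ASSERT(infer_ft_complex_to_real_onesided_size(6) == 11);
+  // With a hint, the matching even size is recovered.
+  TORCH_INTERNAL_ASSERT(infer_ft_complex_to_real_onesided_size(6, 10) == 10);
+}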
+
+using fft_fill_with_conjugate_symmetry_fn =
+    void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes,
+             IntArrayRef in_strides, const void* in_data,
+             IntArrayRef out_strides, void* out_data);
+DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub);
+
+// In real-to-complex transform, cuFFT and MKL only fill half of the values
+// due to conjugate symmetry. This function fills in the other half of the full
+// fft by using the Hermitian symmetry in the signal.
+// self should be the shape of the full signal and dims.back() should be the
+// one-sided dimension.
+// See NOTE [ Fourier Transform Conjugate Symmetry ]
+TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h b/MLPY/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h
new file mode 100644
index 0000000000000000000000000000000000000000..5ee7da926bae6bdf5c8c9e1149152f6c75f263aa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/StridedRandomAccessor.h
@@ -0,0 +1,301 @@
+#pragma once
+
+namespace at::native {
+
+// (Const)StridedRandomAccessor is a
+// (const) random access iterator defined over
+// a strided array.
+
+// The traits below are to introduce __restrict__
+// modifier on different platforms.
+
+template <typename T>
+struct DefaultPtrTraits {
+  using PtrType = T*;
+};
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define RESTRICT __restrict
+#else
+#define RESTRICT __restrict__
+#endif
+
+template <typename T>
+struct RestrictPtrTraits {
+  using PtrType = T* RESTRICT;
+};
+
+template <
+  typename T,
+  typename index_t = int64_t,
+  template <typename U> class PtrTraits = DefaultPtrTraits
+>
+class ConstStridedRandomAccessor {
+public:
+  using difference_type = index_t;
+  using value_type = const T;
+  using pointer = const typename PtrTraits<T>::PtrType;
+  using reference = const value_type&;
+  using iterator_category = std::random_access_iterator_tag;
+
+  using PtrType = typename PtrTraits<T>::PtrType;
+  using index_type = index_t;
+
+  // Constructors {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor(PtrType ptr, index_t stride)
+    : ptr{ptr}, stride{stride}
+  {}
+
+  C10_HOST_DEVICE
+  explicit ConstStridedRandomAccessor(PtrType ptr)
+    : ptr{ptr}, stride{static_cast<index_t>(1)}
+  {}
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor()
+    : ptr{nullptr}, stride{static_cast<index_t>(1)}
+  {}
+  // }
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return *ptr;
+  }
+
+  C10_HOST_DEVICE
+  const value_type* operator->() const {
+    return reinterpret_cast<const value_type*>(ptr);
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](index_t idx) const {
+    return ptr[idx * stride];
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator++() {
+    ptr += stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator++(int) {
+    ConstStridedRandomAccessor copy(*this);
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator--() {
+    ptr -= stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator--(int) {
+    ConstStridedRandomAccessor copy(*this);
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator+=(index_t offset) {
+    ptr += offset * stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator+(index_t offset) const {
+    return ConstStridedRandomAccessor(ptr + offset * stride, stride);
+  }
+
+  C10_HOST_DEVICE
+  friend ConstStridedRandomAccessor operator+(
+    index_t offset,
+    const ConstStridedRandomAccessor& accessor
+  ) {
+    return accessor + offset;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor& operator-=(index_t offset) {
+    ptr -= offset * stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  ConstStridedRandomAccessor operator-(index_t offset) const {
+    return ConstStridedRandomAccessor(ptr - offset * stride, stride);
+  }
+
+  // Note that this operator is well-defined when `this` and `other`
+  // represent the same sequences, i.e. when
+  // 1. this.stride == other.stride,
+  // 2. |other - this| / this.stride is an Integer.
+  C10_HOST_DEVICE
+  difference_type operator-(const ConstStridedRandomAccessor& other) const {
+    return (ptr - other.ptr) / stride;
+  }
+  // }
+
+  // Comparison operators {
+  C10_HOST_DEVICE
+  bool operator==(const ConstStridedRandomAccessor& other) const {
+    return (ptr == other.ptr) && (stride == other.stride);
+  }
+
+  C10_HOST_DEVICE
+  bool operator!=(const ConstStridedRandomAccessor& other) const {
+    return !(*this == other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator<(const ConstStridedRandomAccessor& other) const {
+    return ptr < other.ptr;
+  }
+
+  C10_HOST_DEVICE
+  bool operator<=(const ConstStridedRandomAccessor& other) const {
+    return (*this < other) || (*this == other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator>(const ConstStridedRandomAccessor& other) const {
+    return !(*this <= other);
+  }
+
+  C10_HOST_DEVICE
+  bool operator>=(const ConstStridedRandomAccessor& other) const {
+    return !(*this < other);
+  }
+  // }
+
+protected:
+  PtrType ptr;
+  index_t stride;
+};
+
+template <
+  typename T,
+  typename index_t = int64_t,
+  template <typename U> class PtrTraits = DefaultPtrTraits
+>
+class StridedRandomAccessor
+  : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
+public:
+  using difference_type = index_t;
+  using value_type = T;
+  using pointer = typename PtrTraits<T>::PtrType;
+  using reference = value_type&;
+
+  using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
+  using PtrType = typename PtrTraits<T>::PtrType;
+
+  // Constructors {
+  C10_HOST_DEVICE
+  StridedRandomAccessor(PtrType ptr, index_t stride)
+    : BaseType(ptr, stride)
+  {}
+
+  C10_HOST_DEVICE
+  explicit StridedRandomAccessor(PtrType ptr)
+    : BaseType(ptr)
+  {}
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor()
+    : BaseType()
+  {}
+  // }
+
+  // Pointer-like operations {
+  C10_HOST_DEVICE
+  reference operator*() const {
+    return *this->ptr;
+  }
+
+  C10_HOST_DEVICE
+  value_type* operator->() const {
+    return reinterpret_cast<value_type*>(this->ptr);
+  }
+
+  C10_HOST_DEVICE
+  reference operator[](index_t idx) const {
+    return this->ptr[idx * this->stride];
+  }
+  // }
+
+  // Prefix/postfix increment/decrement {
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator++() {
+    this->ptr += this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator++(int) {
+    StridedRandomAccessor copy(*this);
+    ++*this;
+    return copy;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator--() {
+    this->ptr -= this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator--(int) {
+    StridedRandomAccessor copy(*this);
+    --*this;
+    return copy;
+  }
+  // }
+
+  // Arithmetic operations {
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator+=(index_t offset) {
+    this->ptr += offset * this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator+(index_t offset) const {
+    return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
+  }
+
+  C10_HOST_DEVICE
+  friend StridedRandomAccessor operator+(
+    index_t offset,
+    const StridedRandomAccessor& accessor
+  ) {
+    return accessor + offset;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor& operator-=(index_t offset) {
+    this->ptr -= offset * this->stride;
+    return *this;
+  }
+
+  C10_HOST_DEVICE
+  StridedRandomAccessor operator-(index_t offset) const {
+    return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
+  }
+
+  // Note that here we call BaseType::operator- version
+  C10_HOST_DEVICE
+  difference_type operator-(const BaseType& other) const {
+    return (static_cast<const BaseType&>(*this) - other);
+  }
+  // }
+};
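+
+// Illustration only: with stride == cols, the accessor walks down one column
+// of a row-major rows x cols buffer, so column[i] reads data[col + i * cols].
+// The helper below is a hypothetical usage sketch.
+inline float sum_column_sketch(float* data, int64_t rows, int64_t cols, int64_t col) {
+  ConstStridedRandomAccessor<float> column(data + col, cols);
+  float acc = 0.f;
+  for (int64_t i = 0; i < rows; ++i) {
+    acc += column[i];
+  }
+  return acc;
+}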
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
new file mode 100644
index 0000000000000000000000000000000000000000..7dea4a1a279d78e430a6aaf4258bdabdab2a5e71
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
@@ -0,0 +1,49 @@
+#pragma once
+
+// Indexing tensors by tensors
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+struct TensorIterator;
+}
+
+namespace at::native {
+
+using index_put_with_sort_fn = void(*)(Tensor &, const c10::List> &, const Tensor &, bool accumulate, bool unsafe);
+using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
+using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
+using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
+using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
+using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
+using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
+                                  const Tensor& src, const ReductionType& reduce);
+using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
+                                         const Scalar& value, const ReductionType& reduce);
+using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
+                                      const Tensor& src, const ReductionType& reduce);
+
+DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
+DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
+DECLARE_DISPATCH(gather_fn, gather_stub);
+DECLARE_DISPATCH(scatter_fn, scatter_stub);
+DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
+DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
+DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
+DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
+DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
+
+TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List>& indices);
+
+using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
+using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
+using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
+DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
+DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8934989512f22ab96df8d348050673a6782e2ee
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
@@ -0,0 +1,92 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::native {
+namespace {
+static std::string shapes_as_str(TensorList tensors) {
+  std::ostringstream os;
+  bool first = true;
+  for (auto& tensor : tensors) {
+    if (tensor.defined()) {
+      if (!first) {
+        os << ", ";
+      }
+      os << tensor.sizes();
+      first = false;
+    }
+  }
+  return os.str();
+}
+} // anonymous namespace
+
+static std::tuple canDispatchToMaskedFill(const Tensor& self, const torch::List>& indices,
+const Tensor& value){
+  if (!(value.numel() ==1 && value.device().is_cpu())){
+    return std::make_tuple(false,Tensor());
+  }
+  int64_t num_ind = 0;
+  Tensor mask;
+  auto self_device = self.device();
+  for (const c10::optional& i: indices) {
+    if (!i.has_value() || !(*i).defined()){
+      num_ind++;
+    } else {
+      const Tensor &index = *i;
+      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
+          index.device() != self_device || mask.defined()){
+        return std::make_tuple(false, Tensor());
+      } else {
+        mask = index;
+        for (const auto j : c10::irange(index.dim())) {
+          int64_t srcIdx = num_ind + j;
+          TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
+  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
+        }
+        num_ind += mask.ndimension();
+      }
+    }
+  }
+  for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
+    mask = mask.unsqueeze(-1);
+  }
+  return std::make_tuple(true, mask);
+}
+
+static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
+  checkIndexTensorTypes(orig, /*allow_int*/ true);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
+  // next broadcast all index tensors together
+  try {
+    indices = expand_outplace(indices);
+  } catch (std::exception& e) {
+    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
+                   " with shapes ", shapes_as_str(indices));
+  }
+  // add missing null Tensors so that it matches self.dim()
+  while (indices.size() < (size_t)self.dim()) {
+    indices.emplace_back();
+  }
+  // if the non-null indices are not all adjacent, transpose self and indices
+  // together so that they're adjacent at the front
+  if (!hasContiguousSubspace(indices)) {
+    std::tie(self, indices) = transposeToFront(self, indices);
+  }
+  // Ensure indices are on the same device as self
+  for (auto & indice : indices) {
+    if (indice.defined() && indice.device() != self.device()) {
+      indice = indice.to(self.device());
+    }
+  }
+  for (auto & indice : indices) {
+    if (indice.defined() && indice.dtype() == at::kInt) {
+      indice = indice.to(at::kLong);
+    }
+  }
+
+  return AdvancedIndex(self, indices);
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorCompare.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorCompare.h
new file mode 100644
index 0000000000000000000000000000000000000000..f61d336c5671fc1c0c35356479213192b04501ce
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorCompare.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include 
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at {
+class Tensor;
+struct TensorIterator;
+struct TensorIteratorBase;
+}
+
+namespace at::native {
+
+using reduce_minmax_fn =
+    void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
+using structured_reduce_minmax_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
+
+DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
+DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
+
+using where_fn = void (*)(TensorIterator &);
+DECLARE_DISPATCH(where_fn, where_kernel);
+
+using is_infinity_op_fn = void (*)(TensorIteratorBase &);
+DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
+DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
+
+using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
+DECLARE_DISPATCH(mode_fn, mode_stub);
+
+using clamp_tensor_fn = void (*)(TensorIteratorBase &);
+DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);
+
+namespace detail {
+    enum class ClampLimits {Min, Max, MinMax};
+}
+
+DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
+DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
+DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);
+
+using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
+DECLARE_DISPATCH(isin_default_fn, isin_default_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorConversions.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorConversions.h
new file mode 100644
index 0000000000000000000000000000000000000000..cf0ae4482d5eebc7a530283c9558ac3d0dde4408
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorConversions.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+  class Tensor;
+namespace native {
+bool to_will_alias(
+    const Tensor& self,
+    c10::optional dtype,
+    c10::optional layout,
+    c10::optional device,
+    bool copy,
+    c10::optional optional_memory_format);
+
+Tensor to_meta(const Tensor& tensor);
+c10::optional to_meta(const c10::optional& tensor);
+std::vector to_meta(at::ITensorListRef t_list);
+Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, c10::optional layout, OptionalIntArrayRef blocksize, c10::optional dense_dim_opt);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h
new file mode 100644
index 0000000000000000000000000000000000000000..6cbd8c432f9885185022ceb4dc8257e2d934d78d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorDimApply.h
@@ -0,0 +1,55 @@
+#pragma once
+#include 
+#include 
+
+namespace at::native {
+// input tensors are non-zero dim and non-empty
+template <typename T1, typename T2, typename Function>
+void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
+  int ndims = self.dim();
+  int tensor_dim_apply_has_finished = 0;
+  std::vector counter(ndims, 0);
+  const T1* self_data = self.const_data_ptr<T1>();
+  T1* values_data = values.data_ptr<T1>();
+  T2* indices_data = indices.data_ptr<T2>();
+  int64_t self_stride = self.stride(dim);
+  int64_t values_stride = values.stride(dim);
+  int64_t indices_stride = indices.stride(dim);
+  int self_dim_size = self.size(dim);
+
+  while (!tensor_dim_apply_has_finished) {
+    func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
+    if (ndims == 1) {
+       break;
+    }
+    for (const auto dim_i : c10::irange(ndims)) {
+      if (dim_i == dim) {
+        if (dim_i == (ndims - 1)) {
+          tensor_dim_apply_has_finished = 1;
+          break;
+        }
+        continue;
+      }
+      counter[dim_i]++;
+      self_data += self.stride(dim_i);
+      values_data += values.stride(dim_i);
+      indices_data += indices.stride(dim_i);
+
+      if (counter[dim_i] == self.size(dim_i)) {
+        if (dim_i == ndims-1) {
+          tensor_dim_apply_has_finished = 1;
+          break;
+        } else {
+          self_data -= counter[dim_i]*self.stride(dim_i);
+          values_data -= counter[dim_i]*values.stride(dim_i);
+          indices_data -= counter[dim_i]*indices.stride(dim_i);
+          counter[dim_i] = 0;
+        }
+      } else {
+        break;
+     }
+    }
+  }
+}
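+
+// Illustration only: tensor_dim_apply3 invokes `func` once per 1-D slice along
+// `dim`, handing it raw slice pointers plus the slice length and per-tensor
+// strides (the pattern used by the cummax/cummin CPU kernels). The helper
+// below is a hypothetical sketch assuming float values and int64_t indices.
+inline void running_max_along_dim_sketch(
+    const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
+  tensor_dim_apply3<float, int64_t>(self, values, indices, dim,
+    [](const float* self_p, float* val_p, int64_t* idx_p,
+       int64_t size, int64_t self_st, int64_t val_st, int64_t idx_st) {
+      float best = self_p[0];
+      int64_t best_i = 0;
+      for (int64_t i = 0; i < size; ++i) {
+        if (self_p[i * self_st] > best) {
+          best = self_p[i * self_st];
+          best_i = i;
+        }
+        val_p[i * val_st] = best;
+        idx_p[i * idx_st] = best_i;
+      }
+    });
+}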
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorFactories.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorFactories.h
new file mode 100644
index 0000000000000000000000000000000000000000..7eb70f18fb15e4926ed5b24edc9118569a1bebb7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorFactories.h
@@ -0,0 +1,142 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+// Different combinations of row, col, and offset can lead to two cases:
+//
+// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
+//    Example A: offset > 0
+//      1 1 0 0 0
+//      1 1 1 0 0
+//      1 1 1 1 0
+//    Example B: offset <= 0
+//      0 0 0
+//      1 0 0
+//      1 1 0
+//    In this case, we calculate the number of elements in the first row and
+//    last row of the tril respectively, and then compute the tril size.
+//
+// Case 2 - Trapezoid + Rectangle: row + offset > col
+//    Example:
+//      1 1 0
+//      1 1 1
+//      1 1 1
+//    In this case, we first calculate the size of top trapezoid, and then
+//    calculate the size of the bottom rectangle.
+inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
+  // If either dimension is 0 then there is no tril
+  if (row == 0 || col == 0) {
+    return 0;
+  }
+  // number of elements in the first row of the tril
+  auto m_first_row = offset > 0 ?
+    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+    row + offset > 0; // either 0 or 1
+  // number of elements in the last row of the tril, bounded by [0, col]
+  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
+  // number of rows, bounded by [0, row]
+  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
+  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
+
+  // calculate # of elements in the top trapezoid
+  auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
+
+  // calculate # of elements in the bottom rectangle if there is any
+  auto diff_row = n_row_all - n_row_trapezoid;
+  if (diff_row > 0) {
+    tril_size += diff_row * col;
+  }
+
+  return tril_size;
+}
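+
+// Illustration only: a worked check of the two cases above. Example A
+// (3 x 5, offset = 1) is a pure trapezoid with 2 + 3 + 4 = 9 elements, while
+// the Case 2 example (3 x 3, offset = 1) is a trapezoid of 5 plus a 1 x 3
+// rectangle, i.e. 8 elements. The helper is a hypothetical sanity check.
+inline void get_tril_size_example() {
+  TORCH_INTERNAL_ASSERT(get_tril_size(3, 5, 1) == 9);
+  TORCH_INTERNAL_ASSERT(get_tril_size(3, 3, 1) == 8);
+}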
+
+inline void check_args(
+    int64_t row, int64_t col, c10::optional layout_opt) {
+  TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
+  TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
+  if (layout_opt.has_value()) {
+    TORCH_CHECK(
+      *layout_opt == at::kStrided,
+      "only support layout=torch.strided, got",
+      *layout_opt)
+  }
+}
+
+using at::check_size_nonnegative;
+
+// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
+inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
+  // match defined() to behavior of checks below
+  TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
+              "n is too large for result tensor type: '", tensor.toString(), "'");
+
+  // Ensure sufficient precision for floating point representation.
+  switch (tensor.scalar_type()) {
+    case at::ScalarType::Half:
+      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
+      break;
+    case at::ScalarType::Float:
+      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
+      break;
+    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
+      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
+      break;
+    default:
+      break;
+  }
+}
+
+// Called by `empty*` functions when deterministic algorithms are enabled to
+// fill the tensor with NaN if it is floating point or complex type, or fill
+// with max value if it is integer type
+inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
+  if (tensor.is_floating_point() || tensor.is_complex()) {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+      kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
+        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
+    });
+  } else {
+    AT_DISPATCH_V2(
+      tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
+        tensor.fill_(std::numeric_limits<scalar_t>::max());
+    }), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
+  }
+  return tensor;
+}
+
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+  ZeroTensorAllocator(at::Device device) : device_(device) {};
+  ~ZeroTensorAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t /*nbytes*/) override {
+    return {nullptr, nullptr, &deleter, device_};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {}
+  at::Device device_;
+};
+
+using binary_fn = void (*)(TensorIterator&);
+
+DECLARE_DISPATCH(binary_fn, complex_stub);
+DECLARE_DISPATCH(binary_fn, polar_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIterator.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIterator.h
new file mode 100644
index 0000000000000000000000000000000000000000..4fb52e967ad7da6e58fca440b588f20767c0bf15
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIterator.h
@@ -0,0 +1,2 @@
+#pragma once
+#include 
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
new file mode 100644
index 0000000000000000000000000000000000000000..790e71dd5a5565da4c27869383a2f2436b774b7d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
+
+// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
+// to the function that is being called.
+// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
+// On CPU, there is currently an internal assert that a dynamic_cast is not needed.
+
+namespace at::native {
+
+// `needs_dynamic_casting` compares the types expected by iterator
+// (i.e. dtypes of the operands) with the actual type of the arguments
+// (and returns) of func_t
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+  static bool check(TensorIteratorBase& iter) {
+    using traits = function_traits<func_t>;
+    using cpp_type = typename traits::template arg<nargs - 1>::type;
+    using cpp_map = c10::CppTypeToScalarType<cpp_type>;
+
+    if (iter.input_dtype(nargs-1) != cpp_map::value) {
+      return true;
+    }
+    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+  }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+  static bool check(TensorIteratorBase& iter) {
+    using traits = function_traits<func_t>;
+    using cpp_type = typename traits::result_type;
+
+    // we could assert output numbers are correct here, but checks
+    // (including arity) are currently pushed outside of this struct.
+    if constexpr (std::is_void_v) {
+      return false;
+    } else {
+      return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
+    }
+  }
+};
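+
+// Illustration only: for a binary float functor the check compares the
+// iterator's operand dtypes against `float` (the argument types) and the
+// output dtype against `float` (the result type), so a double-typed iterator
+// would need a dynamic cast. The helper below is a hypothetical sketch that
+// assumes a fully built iterator.
+inline bool float_add_needs_cast_sketch(TensorIteratorBase& iter) {
+  auto add = [](float a, float b) -> float { return a + b; };
+  return needs_dynamic_casting<decltype(add)>::check(iter);
+}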
+
+} //namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorProperties.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorProperties.h
new file mode 100644
index 0000000000000000000000000000000000000000..8654b3dae577b192c75c9cb8f74ea417bcd3b961
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorProperties.h
@@ -0,0 +1,12 @@
+#pragma once
+
+// See NOTE: [Tensor vs. TensorBase]
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+TORCH_API bool cudnn_is_acceptable(const TensorBase& self);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorShape.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorShape.h
new file mode 100644
index 0000000000000000000000000000000000000000..623c81c0b16066fe0766405e158621097427ad59
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorShape.h
@@ -0,0 +1,105 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
+
+inline bool cat_should_skip_tensor(const Tensor& t) {
+  return t.numel() == 0 && t.dim() == 1;
+}
+
+ // Check to see if the shape of tensors is compatible
+ // for being concatenated along a given dimension.
+inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
+   int64_t first_dims = first.dim();
+   int64_t second_dims = second.dim();
+   TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
+               first_dims, " and ", second_dims);
+   for (const auto dim : c10::irange(first_dims)) {
+     if (dim == dimension) {
+       continue;
+     }
+     int64_t first_dim_size = first.sizes()[dim];
+     int64_t second_dim_size = second.sizes()[dim];
+     TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
+                 dimension, ". Expected size ", static_cast(first_dim_size), " but got size ", static_cast(second_dim_size), " for tensor number ", index, " in the list.");
+   }
+ }
+
+inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
+  int64_t i = 0;
+  for(const Tensor& t : tensors) {
+    TORCH_CHECK(t.dim() > 0,
+             "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
+    i++;
+  }
+}
+
+inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
+  TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
+  TORCH_CHECK(split_size >= 0,  "split expects split_size be non-negative, but got split_size=", split_size);
+  int64_t dim_size = self.size(dim);
+  TORCH_CHECK(split_size > 0 || dim_size == 0,
+           "split_size can only be 0 if dimension size is 0, "
+           "but got dimension size of ", dim_size);
+  // if split_size is 0 and dimension size is 0, there is 1 split.
+  int64_t num_splits = 1;
+  if (split_size != 0) {
+    // Ensuring num_splits is at least 1 keeps the case where split_size > dim_size
+    // consistent (it returns a single split). We might want to error here, but keep it for BC.
+    num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
+  }
+  return num_splits;
+}
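+
+// Illustration only: the split count is ceil(dim_size / split_size), clamped
+// to at least 1. A tensor with self.size(0) == 10 yields 4 splits of sizes
+// (3, 3, 3, 1) for split_size == 3 and a single split for split_size == 12.
+// The helper below is a hypothetical sanity check.
+inline void get_num_splits_example() {
+  const auto t = at::zeros({10});
+  TORCH_INTERNAL_ASSERT(get_num_splits(t, /*split_size=*/3, /*dim=*/0) == 4);
+  TORCH_INTERNAL_ASSERT(get_num_splits(t, /*split_size=*/12, /*dim=*/0) == 1);
+}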
+
+inline bool have_same_ndims(TensorList tensors) {
+  auto ndim = tensors[0].dim();
+  for (const auto tensor_idx : c10::irange(tensors.size())) {
+    if(tensors[tensor_idx].dim() != ndim) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
+  auto tensor_zero_size = tensors[0].sizes();
+  std::vector leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
+  for (const auto i : c10::irange(tensors.size())) {
+    at::Tensor tensor = tensors[i];
+    for(const auto j : c10::irange(dim)) {
+      TORCH_CHECK(
+        tensor.size(j) == leading_dim_sizes[j],
+        "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
+      );
+    }
+  }
+}
+
+inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
+  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
+  TORCH_CHECK(!tensors.empty(),
+           "_chunk_cat expects a non-empty input tensor list");
+  auto expected_dtype = tensors[0].dtype();
+  auto expected_device = tensors[0].device();
+  for(const auto i : c10::irange(tensors.size())) {
+    TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
+    TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
+    TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
+  }
+  if (have_same_ndims(tensors)) {
+    dim = maybe_wrap_dim(dim, tensors[0].dim());
+  } else {
+    TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
+    for(const auto i : c10::irange(tensors.size())) {
+      TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
+    }
+  }
+  leading_dimension_matches(tensors, dim);
+  return dim;
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h
new file mode 100644
index 0000000000000000000000000000000000000000..74e1e4232ce42bf827de488f3ee1ff9d50db235e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TensorTransformations.h
@@ -0,0 +1,30 @@
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+#include 
+
+namespace at::native {
+
+static inline Tensor roll_common(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
+  TORCH_CHECK(!shifts.empty(), "`shifts` required");
+  if (dims.empty() && shifts.size() == 1) {
+    auto flattened = self.contiguous().view(self.numel());
+    return roll(flattened, shifts[0], 0).view(self.sizes());
+  }
+  TORCH_CHECK(
+    shifts.size() == dims.size(),
+    "shifts and dimensions must align. shifts: ", shifts.size(), ", dims:", dims.size()
+  );
+  AT_ASSERT(dims.size() > 1);
+  auto tail_shifts = shifts.slice(1);
+  auto tail_dims = dims.slice(1);
+  auto first_dim_rolled = roll(self, shifts[0], dims[0]);
+  return at::roll(first_dim_rolled, tail_shifts, tail_dims);
+}
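+
+// Illustration only: roll_common peels off one (shift, dim) pair per step, so
+// for shifts = {1, 2} and dims = {0, 1} it is equivalent to the nested call
+// below. The helper is a hypothetical sketch.
+static inline Tensor roll_two_dims_sketch(const Tensor& self) {
+  return at::roll(at::roll(self, {1}, {0}), {2}, {1});
+}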
+
+}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TopKImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TopKImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..a8ffaf61295398c9e7a28bdcbc77d4c81e9b3846
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TopKImpl.h
@@ -0,0 +1,98 @@
+#pragma once
+#include 
+#include 
+
+namespace at::native {
+
+#ifdef CPU_CAPABILITY
+inline namespace CPU_CAPABILITY {
+#else
+inline namespace DEFAULT {
+#endif
+
+// Core topk loop, shared between CPU and QuantizedCPU
+template 
+void topk_impl_loop(
+    const int64_t mode_values_stride,
+    const int64_t mode_indices_stride,
+    const int64_t tmp_values_stride,
+    const int64_t k,
+    const int64_t dim_size,
+    const bool largest,
+    const bool sorted,
+    char** data, const int64_t* strides, const int64_t n) {
+
+  // If k is zero, then output values and indices are empty tensors
+  // So iterating over other dims is pointless
+  if (k == 0) {
+    return;
+  }
+  using elem_t = std::pair;
+  std::vector queue(dim_size);
+  for (const auto i : c10::irange(n)) {
+    TensorAccessor mode_values(
+        reinterpret_cast(data[0] + i * strides[0]),
+        &k, &mode_values_stride);
+    TensorAccessor mode_indices(
+        reinterpret_cast(data[1] + i * strides[1]),
+        &k, &mode_indices_stride);
+    TensorAccessor tmp_values(
+        reinterpret_cast(data[2] + i * strides[2]),
+        &dim_size, &tmp_values_stride);
+
+    auto n_2 = dim_size;
+    auto use_partial_sort = k * 64 <= n_2;
+
+    for (const auto j : c10::irange(n_2)) {
+      queue[j].first = tmp_values[j];
+      queue[j].second = j;
+    }
+
+    // we want nan to be sorted as top for numpy compatibility
+    if (use_partial_sort) {
+      if (largest) {
+        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
+          [](const elem_t& x, const elem_t& y) -> bool {
+            return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first));
+          });
+      } else {
+        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
+          [](const elem_t& x, const elem_t& y) -> bool {
+            return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first));
+          });
+      }
+    } else {
+      if (largest) {
+        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
+          [](const elem_t& x, const elem_t& y) -> bool {
+            return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first));
+          });
+        if (sorted) {
+          std::sort(queue.begin(), queue.begin() + k - 1,
+            [](const elem_t& x, const elem_t& y) -> bool {
+              return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first));
+            });
+        }
+      } else {
+        std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
+          [](const elem_t& x, const elem_t& y) -> bool {
+            return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first));
+          });
+        if (sorted) {
+          std::sort(queue.begin(), queue.begin() + k -1,
+            [](const elem_t& x, const elem_t& y) -> bool {
+              return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first));
+            });
+        }
+      }
+    }
+
+    for (const auto j : c10::irange(k)) {
+      mode_values[j] = queue[j].first;
+      mode_indices[j] = queue[j].second;
+    }
+  }
+}
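+
+// Illustration only: the comparators above sort NaN as the largest value for
+// NumPy compatibility, and the strategy switch is a heuristic: partial_sort is
+// O(n log k) and wins when k is small relative to the slice (k * 64 <= n);
+// otherwise an O(n) nth_element selection (plus an optional sort of the k
+// winners) is used. The helper below is a hypothetical sketch of the
+// largest=true path for a plain float slice.
+inline std::vector<std::pair<float, int64_t>> topk_largest_sketch(
+    const std::vector<float>& slice, int64_t k) {
+  std::vector<std::pair<float, int64_t>> queue(slice.size());
+  for (const auto j : c10::irange(slice.size())) {
+    queue[j] = {slice[j], static_cast<int64_t>(j)};
+  }
+  const bool use_partial_sort = k * 64 <= static_cast<int64_t>(slice.size());
+  auto greater = [](const auto& x, const auto& y) { return x.first > y.first; };
+  if (use_partial_sort) {
+    std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), greater);
+  } else {
+    std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(), greater);
+    std::sort(queue.begin(), queue.begin() + k - 1, greater);
+  }
+  queue.resize(k);
+  return queue;
+}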
+
+} // namespace CPU_CAPABILITY
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TransposeType.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TransposeType.h
new file mode 100644
index 0000000000000000000000000000000000000000..2ebdce31873a4ff7e6269551d374952a35f49fdc
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TransposeType.h
@@ -0,0 +1,23 @@
+#pragma once
+#include 
+
+namespace at::native {
+
+// Used as an interface between the different BLAS-like libraries
+enum class TransposeType {
+  NoTranspose,
+  Transpose,
+  ConjTranspose,
+};
+
+// Transforms TransposeType into the BLAS / LAPACK format
+static inline char to_blas(TransposeType trans) {
+  switch (trans) {
+    case TransposeType::Transpose: return 'T';
+    case TransposeType::NoTranspose: return 'N';
+    case TransposeType::ConjTranspose: return 'C';
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
+
+}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..27fe2e18cb685b5fe32214b0fe10466d2b5d0189
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
@@ -0,0 +1,57 @@
+#include 
+#include 
+
+namespace at::native {
+
+/*
+ * Given batches of matrices with arbitrary batch dim,
+ * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
+ */
+static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
+  int64_t result = 1;
+  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
+    if (batched_matrices.stride(i) != 0) {
+      result *= batched_matrices.size(i);
+    }
+  }
+  return result;
+}
+
+/* Checks a necessary property for the triu and tril implementations, hence the name.
+ * Here batch contiguity is checked for tensors with greater than 4 dimensions.
+ * Contiguous tensors and tensors with less than 3 dimensions pass this check
+ */
+static inline std::tuple checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
+  // Complete contiguity is the most desired property, which is why
+  // we return true if the tensor is contiguous
+  if (tensor.is_contiguous()) {
+    auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
+    if (tensor.strides() == default_strides_for_size) {
+      return std::make_tuple(true, tensor);
+    } else {
+      return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
+    }
+  }
+
+  int64_t dims = tensor.dim();
+
+  // Tensors with dimension less than 4 are handled by default
+  if (allow_zero_stride && dims <= 3) {
+    return std::make_tuple(true, tensor);
+  }
+
+  int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
+  for (int64_t i = dims - 3; i >= 0; i--) {
+    // Skip trivial dimension;
+    if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
+      continue;
+    }
+    if (expected_stride != tensor.stride(i)) {
+      return std::make_tuple(false, tensor.contiguous());
+    }
+    expected_stride *= tensor.size(i);
+  }
+  return std::make_tuple(true, tensor);
+}
+
+}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/TypeProperties.h b/MLPY/Lib/site-packages/torch/include/ATen/native/TypeProperties.h
new file mode 100644
index 0000000000000000000000000000000000000000..07f0028655e58f6c1305251782ad6a5e51ad7a74
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/TypeProperties.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+struct ResultTypeState {
+  c10::ScalarType dimResult = ScalarType::Undefined;
+  c10::ScalarType wrappedResult = ScalarType::Undefined;
+  c10::ScalarType zeroResult = ScalarType::Undefined;
+};
+
+TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state);
+TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state);
+TORCH_API ScalarType result_type(const ResultTypeState& state);
+
+TORCH_API ScalarType result_type(ITensorListRef tensors);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/UnaryOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/UnaryOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..7953186237fd071e0b03fc1acee9077507b2869d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/UnaryOps.h
@@ -0,0 +1,130 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+class Tensor;
+class TensorBase;
+struct TensorIteratorBase;
+}
+
+namespace at::native {
+
+using unary_fn = void(*)(TensorIteratorBase&);
+using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
+
+inline namespace CPU_CAPABILITY {
+void conj_kernel(TensorIteratorBase &iter);
+void neg_kernel(TensorIteratorBase &iter);
+void reciprocal_kernel(TensorIteratorBase &iter);
+void rsqrt_kernel(TensorIteratorBase& iter);
+void sqrt_kernel(TensorIteratorBase& iter);
+} // namespace CPU_CAPABILITY
+
+DECLARE_DISPATCH(unary_fn, abs_stub);
+DECLARE_DISPATCH(unary_fn, angle_stub);
+DECLARE_DISPATCH(unary_fn, conj_physical_stub);
+DECLARE_DISPATCH(unary_fn, acos_stub);
+DECLARE_DISPATCH(unary_fn, acosh_stub);
+DECLARE_DISPATCH(unary_fn, asinh_stub);
+DECLARE_DISPATCH(unary_fn, atanh_stub);
+DECLARE_DISPATCH(unary_fn, asin_stub);
+DECLARE_DISPATCH(unary_fn, atan_stub);
+DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
+DECLARE_DISPATCH(unary_fn, logical_not_stub);
+DECLARE_DISPATCH(unary_fn, ceil_stub);
+DECLARE_DISPATCH(unary_fn, cos_stub);
+DECLARE_DISPATCH(unary_fn, cosh_stub);
+DECLARE_DISPATCH(unary_fn, digamma_stub);
+DECLARE_DISPATCH(unary_fn, special_entr_stub);
+DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
+DECLARE_DISPATCH(unary_fn, erf_stub);
+DECLARE_DISPATCH(unary_fn, erfc_stub);
+DECLARE_DISPATCH(unary_fn, erfinv_stub);
+DECLARE_DISPATCH(unary_fn, exp_stub);
+DECLARE_DISPATCH(unary_fn, exp2_stub);
+DECLARE_DISPATCH(unary_fn, expm1_stub);
+DECLARE_DISPATCH(unary_fn, floor_stub);
+DECLARE_DISPATCH(unary_fn, frac_stub);
+DECLARE_DISPATCH(unary_fn, frexp_stub);
+DECLARE_DISPATCH(unary_fn, i0_stub);
+DECLARE_DISPATCH(unary_fn, special_i0e_stub);
+DECLARE_DISPATCH(unary_fn, special_i1_stub);
+DECLARE_DISPATCH(unary_fn, special_i1e_stub);
+DECLARE_DISPATCH(unary_fn, log_stub);
+DECLARE_DISPATCH(unary_fn, log10_stub);
+DECLARE_DISPATCH(unary_fn, log1p_stub);
+DECLARE_DISPATCH(unary_fn, log2_stub);
+DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
+DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
+DECLARE_DISPATCH(unary_fn, neg_stub);
+
+DECLARE_DISPATCH(unary_fn, reciprocal_stub);
+DECLARE_DISPATCH(unary_fn, round_stub);
+DECLARE_DISPATCH(unary_fn, rsqrt_stub);
+DECLARE_DISPATCH(unary_fn, sigmoid_stub);
+DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
+DECLARE_DISPATCH(unary_fn, sign_stub);
+DECLARE_DISPATCH(unary_fn, signbit_stub);
+DECLARE_DISPATCH(unary_fn, sgn_stub);
+DECLARE_DISPATCH(unary_fn, sin_stub);
+DECLARE_DISPATCH(unary_fn, sinc_stub);
+DECLARE_DISPATCH(unary_fn, sinh_stub);
+DECLARE_DISPATCH(unary_fn, sqrt_stub);
+DECLARE_DISPATCH(unary_fn, tan_stub);
+DECLARE_DISPATCH(unary_fn, tanh_stub);
+DECLARE_DISPATCH(unary_fn, trigamma_stub);
+DECLARE_DISPATCH(unary_fn, trunc_stub);
+DECLARE_DISPATCH(unary_fn, lgamma_stub);
+DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
+DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
+DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
+DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
+DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
+DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
+
+// NB: these are actually defined in Distribution
+DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, c10::optional), bernoulli_tensor_stub);
+DECLARE_DISPATCH(void(*)(const TensorBase&, const double, c10::optional), bernoulli_scalar_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional), cauchy_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional), exponential_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, c10::optional), geometric_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional), log_normal_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, c10::optional), uniform_stub);
+DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, c10::optional), normal_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, c10::optional), random_from_to_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional), random_full_64_bits_range_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, c10::optional), random_stub);
+
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
+DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
+DECLARE_DISPATCH(
+    void (*)(Tensor&, const Tensor&, int64_t, c10::optional),
+    multinomial_with_replacement_stub);
+DECLARE_DISPATCH(
+    void (*)(
+        TensorIteratorBase&,
+        c10::optional,
+        c10::optional,
+        c10::optional),
+    nan_to_num_stub);
+DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
+
+// Missing unary functions
+// digamma
+// lgamma
+// erfinv
+// clone
+// contiguous
+// zero
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold2d.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..f966d9d7c9776ab76c26da5a3be2ad98e13cf5f8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold2d.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+using unfold2d_fn = void (*)(
+    ScalarType dtype,
+    void *finput,
+    void *input,
+    int64_t kH,
+    int64_t kW,
+    int64_t dH,
+    int64_t dW,
+    int64_t padH,
+    int64_t padW,
+    int64_t n_input_plane,
+    int64_t input_height,
+    int64_t input_width,
+    int64_t output_height,
+    int64_t output_width,
+    bool is_channels_last
+);
+
+DECLARE_DISPATCH(unfold2d_fn, unfolded2d_copy_stub);
+DECLARE_DISPATCH(unfold2d_fn, unfolded2d_acc_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold3d.h b/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold3d.h
new file mode 100644
index 0000000000000000000000000000000000000000..eae526b7ec33a2ec1b34aeee808f78fc47931c82
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/Unfold3d.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include 
+
+namespace at::native {
+
+void Unfold3dCopyCPU(
+    ScalarType dtype,
+    const void *src,
+    int64_t C,
+    int64_t X_D,
+    int64_t X_H,
+    int64_t X_W,
+    int64_t Y_D,
+    int64_t Y_H,
+    int64_t Y_W,
+    int64_t kernel_d,
+    int64_t kernel_h,
+    int64_t kernel_w,
+    int64_t stride_d,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_d,
+    int64_t pad_h,
+    int64_t pad_w,
+    void* dst);
+
+void Unfold3dAccCPU(
+    ScalarType dtype,
+    const void *src,
+    int64_t C,
+    int64_t X_D,
+    int64_t X_H,
+    int64_t X_W,
+    int64_t Y_D,
+    int64_t Y_H,
+    int64_t Y_W,
+    int64_t kernel_d,
+    int64_t kernel_h,
+    int64_t kernel_w,
+    int64_t stride_d,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_d,
+    int64_t pad_h,
+    int64_t pad_w,
+    void *dst);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h b/MLPY/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h
new file mode 100644
index 0000000000000000000000000000000000000000..f715d4dbf9db2bb23715b7c87e62ccdaf989015a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/UnfoldBackward.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+namespace at::native {
+
+using unfold_backward_fn = void (*)(
+  Tensor& grad_in,
+  const Tensor& grad,
+  int64_t dim,
+  int64_t size,
+  int64_t step
+);
+
+DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
+
+namespace {
+
+// Note on naming: it is unconventional.
+// grad_in does not mean that it is a gradient w.r.t. the input;
+// grad_in/grad_out are just the input/output of the unfold_backward kernel.
+
+static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
+  Tensor& grad_out,
+  const Tensor& grad_in,
+  int64_t dim,
+  int64_t size,
+  int64_t step
+) {
+  dim = maybe_wrap_dim(dim, grad_out.dim());
+  // last dim stores the folds
+
+  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
+  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
+  // dictates the number of elements to iterate over
+  // in dimension `dim`
+  auto iter_dim_size = std::min(
+    grad_out_dim_size,
+    (grad_in_dim_size - 1) * step + size
+  );
+
+  /* prepare grad_out for TensorIterator { */
+  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
+  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());
+  grad_out_sizes[dim] = iter_dim_size;
+  auto grad_out_restrided = grad_out.as_strided(
+    grad_out_sizes, grad_out_strides
+  );
+  /* } */
+
+  /* prepare grad_in for TensorIterator { */
+  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
+  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());
+
+  // set strides for dim to 0
+  // and size to 1 because
+  // this dimension is indexed inside the kernel
+  grad_in_strides[dim] = 0;
+  grad_in_sizes[dim] = 1;
+
+  grad_in_strides.pop_back();
+  grad_in_sizes.pop_back();
+
+  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
+    grad_in_sizes, grad_in_strides
+  );
+  /* } */
+
+  // During the TensorIterator iteration we have to know
+  // i_dim in grad_out[i_1,...,i_dim,...i_n],
+  // idx_dim stores this information
+  /* prepare idx_dim for TensorIterator { */
+  auto idx_dim = at::arange(
+    0, iter_dim_size, grad_in.options().dtype(at::kLong)
+  );
+
+  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());
+
+  auto idx_dim_strides = std::vector(grad_out_dim, 0);
+  auto idx_dim_sizes = std::vector(grad_out_dim, 1);
+
+  idx_dim_strides[dim] = 1;
+  idx_dim_sizes[dim] = iter_dim_size;
+
+  // idx_dim will be broadcast over the sizes determined by grad_out in the TensorIterator
+  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
+  /* } */
+
+  auto iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .add_owned_output(grad_out_restrided)
+    .add_owned_input(grad_in_restrided)
+    .add_owned_input(idx_dim_restrided)
+    .build();
+
+  return iter;
+}
+
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/UpSample.h b/MLPY/Lib/site-packages/torch/include/ATen/native/UpSample.h
new file mode 100644
index 0000000000000000000000000000000000000000..72c4f1d72cdb847db5c7da5f11d3219317b4b187
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/UpSample.h
@@ -0,0 +1,506 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+/**
+ * Note [compute_scales_value]
+ * Note [area_pixel_compute_scale]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Interpolate with scale_factor can have different behaviors
+ * depending on the value of recompute_scale_factor:
+ *
+ * - With recompute_scale_factor = True (current default behavior):
+ * the scale_factor, when provided by the user, is used to calculate
+ * the output size. The input size and the computed output_size
+ * are then used to infer new values for the scales, which are
+ * used in the interpolation. Because floating-point math is not exact,
+ * these may differ from the user-supplied scales.
+ *
+ * - With recompute_scale_factor = False (which will be the default
+ * behavior starting 1.5.0):
+ * the behavior follows opencv logic, and the scales provided by
+ * the user are the ones used in the interpolation calculations.
+ *
+ * If the scales are not provided or if they are provided but
+ * recompute_scale_factor is set to True (default behavior), the scales
+ * are computed from the input and the output size;
+ *
+ *
+ * When the scales are inferred from the input and output sizes,
+ * we view each pixel as an area, idx + 0.5 as its center index.
+ * Here is an example formula in 1D case.
+ * if align_corners: the centers of the two corner pixel areas are preserved:
+ *     0.5 -> 0.5,
+ *     (input_size - 0.5) -> (output_size - 0.5)
+ *     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
+ *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
+ * if not align_corners: the whole range is scaled accordingly
+ *     scale = input_size / output_size
+ *     src_idx + 0.5 = scale * (dst_index + 0.5)
+ */
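+
+// Illustration only: a worked 1D check of the two conventions above.
+//   align_corners, input_size = 5, output_size = 9:
+//     scale = (5 - 1) / (9 - 1) = 0.5, so the last output index 8 maps back to
+//     the last input index 0.5 * 8 == 4.
+//   !align_corners, input_size = 4, output_size = 8:
+//     scale = 4 / 8 = 0.5, so output index 0 maps to
+//     src = 0.5 * (0 + 0.5) - 0.5 == -0.25 (negative sources are clamped by
+//     the kernels).
+static_assert((5.0 - 1.0) / (9.0 - 1.0) * 8.0 == 4.0, "align_corners example");
+static_assert(4.0 / 8.0 * (0 + 0.5) - 0.5 == -0.25, "half-pixel example");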
+
+namespace at::native {
+
+namespace upsample {
+
+TORCH_API c10::SmallVector compute_output_size(
+    c10::IntArrayRef input_size,  // Full input tensor size.
+    at::OptionalIntArrayRef output_size,
+    c10::optional> scale_factors);
+
+inline c10::optional get_scale_value(c10::optional> scales, int idx) {
+  if (!scales) {
+    return c10::nullopt;
+  }
+  return scales->at(idx);
+}
+
+} // namespace upsample
+
+using scale_t = c10::optional;
+using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
+using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
+using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
+using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
+using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
+using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
+using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
+using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
+using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
+using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
+using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
+using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
+DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
+DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
+DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
+DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
+DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
+DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
+DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
+DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
+DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
+DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
+DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
+DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
+DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
+DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
+DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
+DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
+DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
+DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);
+
+static C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
+  TORCH_CHECK(
+      output_size.size() == 1,
+      "It is expected output_size equals to 1, but got size ",
+      output_size.size());
+
+  TORCH_CHECK(
+      input_size.size() == 3,
+      "It is expected input_size equals to 3, but got size ",
+      input_size.size());
+
+  int64_t output_width = output_size[0];
+
+  int64_t nbatch = input_size[0];
+  int64_t channels = input_size[1];
+  int64_t input_width = input_size[2];
+
+  TORCH_CHECK(
+      input_width > 0 && output_width > 0,
+      "Input and output sizes should be greater than 0, but got input (W: ",
+      input_width,
+      ") and output (W: ",
+      output_width,
+      ")");
+
+  return {nbatch, channels, output_width};
+}
+
+static C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
+  TORCH_CHECK(
+      output_size.size() == 2,
+      "It is expected output_size equals to 2, but got size ",
+      output_size.size());
+
+  TORCH_CHECK(
+      input_size.size() == 4,
+      "It is expected input_size equals to 4, but got size ",
+      input_size.size());
+
+  int64_t output_height = output_size[0];
+  int64_t output_width = output_size[1];
+
+  int64_t nbatch = input_size[0];
+  int64_t channels = input_size[1];
+  int64_t input_height = input_size[2];
+  int64_t input_width = input_size[3];
+
+  TORCH_CHECK(
+      input_height > 0 && input_width > 0 && output_height > 0 &&
+          output_width > 0,
+      "Input and output sizes should be greater than 0,"
+      " but got input (H: ",
+      input_height,
+      ", W: ",
+      input_width,
+      ") output (H: ",
+      output_height,
+      ", W: ",
+      output_width,
+      ")");
+
+  return {nbatch, channels, output_height, output_width};
+}
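+
+// Example (illustrative): for a 4-D input of size {2, 3, 8, 8} (N, C, H, W) and
+// output_size = {16, 16}, upsample_2d_common_check returns {2, 3, 16, 16};
+// mismatched ranks or non-positive spatial sizes trip the TORCH_CHECKs above.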
+
+static C10_UNUSED
+std::array upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
+  TORCH_CHECK(
+      output_size.size() == 3,
+      "It is expected output_size equals to 3, but got size ",
+      output_size.size());
+
+  TORCH_CHECK(
+      input_size.size() == 5,
+      "It is expected input_size equals to 5, but got size ",
+      input_size.size());
+
+  int64_t output_depth = output_size[0];
+  int64_t output_height = output_size[1];
+  int64_t output_width = output_size[2];
+
+  int64_t nbatch = input_size[0];
+  int64_t channels = input_size[1];
+  int64_t input_depth = input_size[2];
+  int64_t input_height = input_size[3];
+  int64_t input_width = input_size[4];
+
+  TORCH_CHECK(
+      input_depth > 0 && input_height > 0 && input_width > 0 &&
+          output_depth > 0 && output_height > 0 && output_width > 0,
+      "Input and output sizes should be greater than 0, but got input (D: ",
+      input_depth,
+      ", H: ",
+      input_height,
+      ", W: ",
+      input_width,
+      ") output (D: ",
+      output_depth,
+      ", H: ",
+      output_height,
+      ", W: ",
+      output_width,
+      ")");
+
+
+  return {nbatch, channels, output_depth, output_height, output_width};
+}
+
+static inline void upsample_2d_shape_check(
+    const Tensor& input,
+    const Tensor& grad_output,
+    int64_t nbatch,
+    int64_t nchannels,
+    int64_t input_height,
+    int64_t input_width,
+    int64_t output_height,
+    int64_t output_width) {
+  TORCH_CHECK(
+      input_height > 0 && input_width > 0 && output_height > 0 &&
+          output_width > 0,
+      "Input and output sizes should be greater than 0,"
+      " but got input (H: ",
+      input_height,
+      ", W: ",
+      input_width,
+      ") output (H: ",
+      output_height,
+      ", W: ",
+      output_width,
+      ")");
+
+  if (input.defined()) {
+    // Allow for empty batch size but not other dimensions
+    TORCH_CHECK(
+                (input.numel() != 0 ||
+                 (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)
+                 ) &&
+                input.dim() == 4,
+                "Non-empty 4D data tensor expected but got a tensor with sizes ",
+                input.sizes());
+  } else if (grad_output.defined()) {
+    check_dim_size(grad_output, 4, 0, nbatch);
+    check_dim_size(grad_output, 4, 1, nchannels);
+    check_dim_size(grad_output, 4, 2, output_height);
+    check_dim_size(grad_output, 4, 3, output_width);
+  }
+}
+
+template 
+static inline scalar_t compute_scales_value(
+    const c10::optional scale,
+    int64_t input_size,
+    int64_t output_size) {
+      // see Note [compute_scales_value]
+      // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
+      return (scale.has_value() && scale.value() > 0.)
+          ? static_cast(1.0 / scale.value())
+          : (static_cast(input_size) / output_size);
+}
+
+template 
+static inline scalar_t area_pixel_compute_scale(
+    int64_t input_size,
+    int64_t output_size,
+    bool align_corners,
+    const c10::optional scale) {
+  // see Note [area_pixel_compute_scale]
+  if(align_corners) {
+    if(output_size > 1) {
+      return static_cast(input_size - 1) / (output_size - 1);
+    } else {
+      return static_cast(0);
+    }
+  } else {
+    return compute_scales_value(scale, input_size, output_size);
+  }
+}
+
+template 
+static inline scalar_t area_pixel_compute_source_index(
+    scalar_t scale,
+    int64_t dst_index,
+    bool align_corners,
+    bool cubic) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    scalar_t src_idx = scale * (dst_index + static_cast(0.5)) -
+        static_cast(0.5);
+    // [Note] Follow OpenCV resize logic:
+    // We allow negative src_idx here and later will use
+    //   dx = src_idx - floorf(src_idx)
+    // to compute the "distance"(which affects weights).
+    // For linear modes, weight distribution doesn't matter
+    // for negative indices as they use 2 pixels to interpolate.
+    // For example, [-1, 0], they both use pixel 0 value so it
+    // doesn't affect if we bound the src_idx to 0 or not.
+    // TODO: our linear mode impls should also be updated to use unbound
+    // indices; once they are, this cubic flag can be removed.
+    // This matters in cubic mode, as we might need [-1, 0, 1, 2]
+    // to interpolate and the weights can be affected.
+    return (!cubic && src_idx < static_cast(0)) ? scalar_t(0)
+                                                          : src_idx;
+  }
+}
+
+static inline int64_t nearest_neighbor_compute_source_index(
+    const float scale,
+    int64_t dst_index,
+    int64_t input_size) {
+  // Index computation matching OpenCV INTER_NEAREST
+  // which is buggy and kept for BC
+  const int64_t src_index =
+      std::min(static_cast(floorf(dst_index * scale)), input_size - 1);
+  return src_index;
+}
+
+static inline int64_t nearest_neighbor_exact_compute_source_index(
+    const float scale,
+    int64_t dst_index,
+    int64_t input_size) {
+  // index_f32 = (output_index + 0.5) * scale - 0.5
+  // input_index = round(index_f32)
+  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
+  const int64_t src_index =
+      std::min(static_cast(floorf((dst_index + 0.5) * scale)), input_size - 1);
+  return src_index;
+}
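+
+// Example (illustrative): downscaling input_size = 4 to output_size = 2 gives
+// scale = 4 / 2 = 2.0 (see compute_scales_value). The legacy OpenCV-style index
+// floorf(dst_index * scale) maps output {0, 1} to input {0, 2}, while the
+// "exact" variant floorf((dst_index + 0.5) * scale) maps output {0, 1} to
+// input {1, 3}, matching Pillow and scipy.ndimage.zoom.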
+
+static inline int64_t nearest_idx(
+    int64_t output_index,
+    int64_t input_size,
+    int64_t output_size,
+    c10::optional scales) {
+  // This method specifically handles the cases output_size == input_size and
+  // output_size == 2 * input_size, special cases we would like to get rid of.
+  // We keep this method for BC and consider it deprecated;
+  // see nearest_exact_idx as the replacement.
+  if (output_size == input_size) {
+    // scale_factor = 1, simply copy
+    return output_index;
+  } else if (output_size == 2 * input_size) {
+    // scale_factor = 2, shift input index
+    return output_index >> 1;
+  } else {
+    float scale = compute_scales_value(scales, input_size, output_size);
+    return nearest_neighbor_compute_source_index(scale, output_index, input_size);
+  }
+}
+
+static inline int64_t nearest_exact_idx(
+    int64_t output_index,
+    int64_t input_size,
+    int64_t output_size,
+    c10::optional scales) {
+  float scale = compute_scales_value(scales, input_size, output_size);
+  return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
+}
+
+// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
+typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, c10::optional);
+
+template 
+static scalar_t upsample_get_value_bounded(
+    scalar_t* data,
+    int64_t width,
+    int64_t height,
+    int64_t x,
+    int64_t y) {
+  int64_t access_x = std::max(std::min(x, width - 1), static_cast(0));
+  int64_t access_y = std::max(std::min(y, height - 1), static_cast(0));
+  return data[access_y * width + access_x];
+}
+
+template 
+static void upsample_increment_value_bounded(
+    scalar_t* data,
+    int64_t width,
+    int64_t height,
+    int64_t x,
+    int64_t y,
+    scalar_t value) {
+  int64_t access_x = std::max(std::min(x, width - 1), static_cast(0));
+  int64_t access_y = std::max(std::min(y, height - 1), static_cast(0));
+  data[access_y * width + access_x] += value;
+}
+
+// Based on
+// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+template 
+static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
+  return ((A + 2) * x - (A + 3)) * x * x + 1;
+}
+
+template 
+static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
+  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+}
+
+template 
+static inline void get_cubic_upsample_coefficients(
+    scalar_t coeffs[4],
+    scalar_t t) {
+  scalar_t A = -0.75;
+
+  scalar_t x1 = t;
+  coeffs[0] = cubic_convolution2(x1 + 1.0, A);
+  coeffs[1] = cubic_convolution1(x1, A);
+
+  // opposite coefficients
+  scalar_t x2 = 1.0 - t;
+  coeffs[2] = cubic_convolution1(x2, A);
+  coeffs[3] = cubic_convolution2(x2 + 1.0, A);
+}
+
+template 
+static inline scalar_t cubic_interp1d(
+    scalar_t x0,
+    scalar_t x1,
+    scalar_t x2,
+    scalar_t x3,
+    scalar_t t) {
+  scalar_t coeffs[4];
+  get_cubic_upsample_coefficients(coeffs, t);
+
+  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
+}
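+
+// Example (illustrative): with A = -0.75 and t = 0.5 the four weights computed by
+// get_cubic_upsample_coefficients are {-0.09375, 0.59375, 0.59375, -0.09375};
+// they are symmetric and sum to 1, so cubic_interp1d(x0, x1, x2, x3, 0.5)
+// interpolates halfway between x1 and x2 with a small negative contribution
+// from the outer neighbors x0 and x3.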
+
+// when `real_input_index` becomes larger than the range the floating point
+// type can accurately represent, the type casting to `int64_t` might exceed
+// `input_size`, causing overflow. So we guard it with `std::min` below.
+template
+static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
+  input_index = std::min(static_cast(floorf(real_input_index)), input_size - 1);
+  lambda = std::min(
+      std::max(real_input_index - input_index, static_cast(0)),
+      static_cast(1)
+    );
+}
+
+template
+static inline void compute_source_index_and_lambda(
+    int64_t& input_index0,
+    int64_t& input_index1,
+    scalar_t& lambda0,
+    scalar_t& lambda1,
+    opmath_t ratio,
+    int64_t output_index,
+    int64_t input_size,
+    int64_t output_size,
+    bool align_corners) {
+  if (output_size == input_size) {
+    // scale_factor = 1, simply copy
+    input_index0 = output_index;
+    input_index1 = output_index;
+    lambda0 = static_cast(1);
+    lambda1 = static_cast(0);
+  } else {
+    const auto real_input_index =
+        area_pixel_compute_source_index(
+            ratio, output_index, align_corners, /*cubic=*/false);
+    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
+    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
+    input_index1 = input_index0 + offset;
+    lambda0 = static_cast(1.) - lambda1;
+  }
+}
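+
+// Example (illustrative): upsampling input_size = 3 to output_size = 5 with
+// align_corners = false gives ratio = 3 / 5 = 0.6. For output_index = 1,
+// real_input_index = 0.6 * 1.5 - 0.5 = 0.4, so input_index0 = 0, input_index1 = 1,
+// lambda1 = 0.4 and lambda0 = 0.6, i.e. output[1] = 0.6 * input[0] + 0.4 * input[1].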
+
+// It will not be used by data types other than BFloat16 and Half.
+template  || !std::is_same::value, int> = 0>
+void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
+  TORCH_CHECK((is_reduced_floating_point_v),
+              "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
+  TORCH_CHECK((std::is_same::value),
+              "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
+  return;
+}
+
+template  && std::is_same::value, int> = 0>
+void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
+  using bVec = Vectorized;
+  using fVec = Vectorized;
+  int64_t d = 0;
+  for (; d < size - (size % bVec::size()); d += bVec::size()) {
+    bVec gin_bvec = bVec::loadu(gin + d);
+    fVec gin_fvec0, gin_fvec1;
+    std::tie(gin_fvec0, gin_fvec1) = convert_to_float(gin_bvec);
+    gin_fvec0 += fVec::loadu(buffer_ptr + d);
+    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
+    fVec(0).store(buffer_ptr + d);
+    fVec(0).store(buffer_ptr + d + fVec::size());
+    convert_from_float(gin_fvec0, gin_fvec1).store(gin + d);
+  }
+  for (; d < size; d++) {
+    gin[d] += buffer_ptr[d];
+    buffer_ptr[d] = 0;
+  }
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/batch_norm.h b/MLPY/Lib/site-packages/torch/include/ATen/native/batch_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..d38158cfe4b6e2c027ba4bd2daa44f2501881522
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/batch_norm.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at::native {
+
+using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
+    const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
+using batch_norm_collect_stats_fn = void (*)(Tensor&, Tensor&, const Tensor&);
+using batch_norm_backward_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&,
+        const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, bool, double);
+
+DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_stub);
+DECLARE_DISPATCH(batch_norm_collect_stats_fn, batch_norm_cpu_collect_stats_stub);
+DECLARE_DISPATCH(batch_norm_backward_fn, batch_norm_cpu_backward_stub);
+
+// TensorAccessor when it is defined to work around undefined...
+template 
+static TensorAccessor conditional_accessor_1d(const Tensor& t) {
+  if (! t.defined()) {
+    return TensorAccessor(nullptr, nullptr, nullptr);
+  }
+  return t.accessor();
+}
+
+template 
+static scalar_t* conditional_data_ptr(const Tensor& t) {
+  return t.defined() ? t.contiguous().data_ptr()
+                     : nullptr;
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f2fe7f1a32f50cd35585d4f1060cd76386beec5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h
@@ -0,0 +1,37 @@
+#ifndef ATOMIC_ADD_FLOAT
+#define ATOMIC_ADD_FLOAT
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
+#include 
+#else
+#define _mm_pause()
+#endif
+
+#include 
+
+static inline void cpu_atomic_add_float(float* dst, float fvalue)
+{
+  typedef union {
+    unsigned intV;
+    float floatV;
+  } uf32_t;
+
+  uf32_t new_value, old_value;
+  std::atomic* dst_intV = (std::atomic*)(dst);
+
+  old_value.floatV = *dst;
+  new_value.floatV = old_value.floatV + fvalue;
+
+  unsigned* old_intV = (unsigned*)(&old_value.intV);
+  while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
+#ifdef __aarch64__
+    __asm__ __volatile__("yield;" : : : "memory");
+#else
+    _mm_pause();
+#endif
+    old_value.floatV = *dst;
+    new_value.floatV = old_value.floatV + fvalue;
+  }
+}
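+
+// Usage sketch (illustrative, assuming at::parallel_for and a `values` array are
+// available at the call site): the CAS loop above makes concurrent accumulation
+// into a single float safe, e.g.
+//
+//   float acc = 0.f;
+//   at::parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
+//     for (int64_t i = begin; i < end; ++i) {
+//       cpu_atomic_add_float(&acc, values[i]);
+//     }
+//   });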
+
+#endif
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..79bf7c06af6991d656114fd0bd8678c544e96a00
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CatKernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at { namespace native {
+
+using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
+DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..a0b62ef3226e0a129990f9b107a15e7e240489ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h
@@ -0,0 +1,14 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
+DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..da04349fe44724e4e24d1e690160f2c5a3cf2fa5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/CopyKernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+namespace at {
+struct TensorIteratorBase;
+
+namespace native {
+inline namespace CPU_CAPABILITY {
+
+void direct_copy_kernel(TensorIteratorBase &iter);
+void copy_kernel(TensorIterator& iter, bool /*non_blocking*/);
+
+}}}  // namespace at::native::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbaff919a620b2c3c78603f03744016ccfdd4f10
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include 
+#include 
+
+/*
+  Depthwise 3x3 Winograd convolution operator
+*/
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using convolution_depthwise3x3_winograd_fn =
+    Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
+
+DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
+
+}  // namespace native
+}  // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h
new file mode 100644
index 0000000000000000000000000000000000000000..5bc026ae278a95436698bf600bb3e8bb61327dd2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h
@@ -0,0 +1,369 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifdef CPU_CAPABILITY_AVX2
+#include 
+#include 
+#endif
+
+
+namespace at {
+namespace native {
+namespace templates {
+namespace cpu {
+namespace {
+
+// ==================================================== Random ========================================================
+
+template
+void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
+  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
+    std::lock_guard lock(generator->mutex_);
+    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
+      uniform_int_from_to_distribution random(range, base);
+      return random(generator);
+    });
+  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+}
+
+// This is a special kernel that handles a single specific case:
+// from(inclusive) = std::numeric_limits::lowest()
+// to(exclusive) = None (= std::numeric_limits::max() + 1)
+template
+void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
+    if constexpr (std::is_same::value ||
+        std::is_same::value ||
+        std::is_same::value ||
+        std::is_same::value) {
+      std::lock_guard lock(generator->mutex_);
+      cpu_serial_kernel(iter, [generator]() -> scalar_t {
+        uniform_int_full_range_distribution random;
+        return random(generator);
+      });
+    } else {
+      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
+    }
+  });
+}
+
+template
+struct RandomFromToKernel {
+  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional gen) {
+    random_from_to_kernel(iter, range, base, check_generator(gen));
+  }
+  void operator()(TensorIteratorBase& iter, c10::optional gen) {
+    random_full_64_bits_range_kernel(iter, check_generator(gen));
+  }
+};
+
+template
+void random_kernel(TensorIteratorBase& iter, RNG generator) {
+  std::lock_guard lock(generator->mutex_);
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
+    cpu_serial_kernel(iter, [generator]() -> scalar_t {
+      uniform_int_distribution random;
+      return random(generator);
+    });
+  });
+}
+
+template
+struct RandomKernel {
+  void operator()(TensorIteratorBase& iter, c10::optional gen) {
+    random_kernel(iter, check_generator(gen));
+  }
+};
+
+// ==================================================== Normal ========================================================
+
+#ifdef CPU_CAPABILITY_AVX2
+static void normal_fill_16_AVX2(float *data,
+                         const __m256* two_pi,
+                         const __m256* one,
+                         const __m256* minus_two,
+                         const __m256* mean,
+                         const __m256* std_v) {
+  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
+  const __m256 u2 = _mm256_loadu_ps(data + 8);
+  // sincos256_ps and log256_ps are from avx_mathfun.h
+  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
+  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
+  __m256 sintheta, costheta;
+  sincos256_ps(theta, &sintheta, &costheta);
+  const __m256 n1 = _mm256_mul_ps(radius, costheta);
+  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
+  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
+  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
+}
+
+template
+void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
+  float *data = self.data_ptr();
+  auto size = self.numel();
+  std::lock_guard lock(generator->mutex_);
+  for (const auto i : c10::irange(size)) {
+    at::uniform_real_distribution uniform(0, 1);
+    data[i] = uniform(generator);
+  }
+  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi);
+  const __m256 one = _mm256_set1_ps(1.0f);
+  const __m256 minus_two = _mm256_set1_ps(-2.0f);
+  const __m256 mean_v = _mm256_set1_ps(mean);
+  const __m256 std_v = _mm256_set1_ps(std);
+
+  for (int64_t i = 0; i < size - 15; i += 16) {
+    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
+  }
+
+  if (size % 16 != 0) {
+    // Recompute the last 16 values.
+    data = data + size - 16;
+    for (const auto i : c10::irange(16)) {
+      at::uniform_real_distribution uniform(0, 1);
+      data[i] = uniform(generator);
+    }
+    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
+  }
+}
+#endif
+
+template 
+static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
+  for (const auto j : c10::irange(8)) {
+    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
+    const scalar_t u2 = data[j + 8];
+    const scalar_t radius = std::sqrt(-2 * std::log(u1));
+    const scalar_t theta = 2.0f * c10::pi * u2;
+    data[j] = radius * std::cos(theta) * std + mean;
+    data[j + 8] = radius * std::sin(theta) * std + mean;
+  }
+}
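+
+// The loop above is a Box-Muller transform: for a pair of uniforms (u1, u2) in
+// (0, 1], radius = sqrt(-2 * ln(u1)) and theta = 2 * pi * u2 yield two independent
+// standard normal samples radius * cos(theta) and radius * sin(theta), which are
+// then affinely mapped to N(mean, std^2).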
+
+template 
+void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
+  scalar_t *data = self.data_ptr();
+  auto size = self.numel();
+  std::lock_guard lock(generator->mutex_);
+  for (const auto i : c10::irange(size)) {
+    at::uniform_real_distribution uniform(0, 1);
+    data[i] = uniform(generator);
+  }
+
+  for (int64_t i = 0; i < size - 15; i += 16) {
+    normal_fill_16(data + i, mean, std);
+  }
+  if (size % 16 != 0) {
+    // Recompute the last 16 values.
+    data = data + size - 16;
+    for (const auto i : c10::irange(16)) {
+      at::uniform_real_distribution uniform(0, 1);
+      data[i] = uniform(generator);
+    }
+    normal_fill_16(data, mean, std);
+  }
+}
+
+template
+void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
+  auto size = self.numel();
+  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
+#ifdef CPU_CAPABILITY_AVX2
+    normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator);
+#else
+    normal_fill(self, static_cast(mean), static_cast(std), generator);
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
+      if (size >= 16 && self.is_contiguous()) {
+        normal_fill(self, static_cast(mean), static_cast(std), generator);
+      } else {
+        auto iter = TensorIterator::borrowing_nullary_op(self);
+        std::lock_guard lock(generator->mutex_);
+        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
+          at::normal_distribution normal(mean, std);
+          return static_cast(normal(generator));
+        });
+      }
+    });
+  }
+}
+
+template
+struct NormalKernel {
+  void operator()(Tensor& self, double mean, double std, c10::optional gen) {
+    normal_kernel(self, mean, std, check_generator(gen));
+  }
+};
+
+// ==================================================== Uniform =======================================================
+
+template
+void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
+    std::lock_guard lock(generator->mutex_);
+    auto from = static_cast(from_);
+    auto to = static_cast(to_);
+    at::uniform_real_distribution uniform(from, to);
+    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
+      return static_cast(uniform(generator));
+    });
+  });
+}
+
+template
+struct UniformKernel {
+  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional gen) {
+    uniform_kernel(iter, from, to, check_generator(gen));
+  }
+};
+
+// ==================================================== Cauchy ========================================================
+
+template
+void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
+    std::lock_guard lock(generator->mutex_);
+    at::cauchy_distribution cauchy(median, sigma);
+    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
+      return static_cast(cauchy(generator));
+    });
+  });
+}
+
+template
+struct CauchyKernel {
+  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional gen) {
+    cauchy_kernel(iter, median, sigma, check_generator(gen));
+  }
+};
+
+// ================================================== LogNormal =======================================================
+
+template
+void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
+    std::lock_guard lock(generator->mutex_);
+    at::lognormal_distribution logNormal(mean, std);
+    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
+      return static_cast(logNormal(generator));
+    });
+  });
+}
+
+template
+struct LogNormalKernel {
+  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional gen) {
+    log_normal_kernel(iter, mean, std, check_generator(gen));
+  }
+};
+
+// =================================================== Geometric ======================================================
+
+template
+void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
+    std::lock_guard lock(generator->mutex_);
+    at::geometric_distribution geometric(p);
+    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
+      return static_cast(geometric(generator));
+    });
+  });
+}
+
+template
+struct GeometricKernel {
+  void operator()(TensorIteratorBase& iter, double p, c10::optional gen) {
+    geometric_kernel(iter, p, check_generator(gen));
+  }
+};
+
+// ================================================== Exponential =====================================================
+
+template
+void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
+  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
+    std::lock_guard lock(generator->mutex_);
+    at::exponential_distribution exponential(lambda);
+    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
+      return static_cast(exponential(generator));
+    });
+  });
+}
+
+template
+struct ExponentialKernel {
+  void operator()(TensorIteratorBase& iter, double lambda, c10::optional gen) {
+    exponential_kernel(iter, lambda, check_generator(gen));
+  }
+};
+
+// ================================================== Bernoulli =======================================================
+
+template
+void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
+  self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard lock(generator->mutex_);
+    using self_t = scalar_t;
+    auto p_cpu = p_.to(kCPU);
+    auto p = expand_inplace(self, p_cpu);
+    auto iter = TensorIteratorConfig()
+        .add_output(self)
+        .add_input(*p)
+        .check_all_same_dtype(false)
+        .build();
+    if (p->scalar_type() == kDouble) {
+      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
+        at::bernoulli_distribution bernoulli(p_val);
+        return static_cast(bernoulli(generator));
+      });
+    } else {
+      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
+      p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
+        using p_t = scalar_t;
+        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
+          at::bernoulli_distribution bernoulli(p_val);
+          return static_cast(bernoulli(generator));
+        });
+      });
+    }
+  });
+}
+
+template
+void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
+  self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard lock(generator->mutex_);
+    auto iter = TensorIterator::borrowing_nullary_op(self);
+    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
+      at::bernoulli_distribution bernoulli(p);
+      return static_cast(bernoulli(generator));
+    });
+  });
+}
+
+template
+struct BernoulliKernel {
+  void operator()(const TensorBase &self, double p, c10::optional gen) {
+    bernoulli_kernel(self, p, check_generator(gen));
+  }
+  void operator()(const TensorBase &self, const TensorBase &p_, c10::optional gen) {
+    bernoulli_kernel(self, p_, check_generator(gen));
+  }
+};
+
+}}}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..137a578b3f74edc67511f0e9900fc7c320318916
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include 
+
+#include 
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+using forward_2d_fn = void (*) (
+    const TensorBase &output,
+    const TensorBase &input,
+    const TensorBase &grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners);
+using backward_2d_fn = void (*) (
+    const TensorBase &grad_input,
+    const TensorBase &grad_grid,
+    const TensorBase &grad_output,
+    const TensorBase &input,
+    const TensorBase &grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners,
+    std::array output_mask);
+DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
+DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..de667f3fe287b8434743258dec2e40d29b99a9d9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h
@@ -0,0 +1,88 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+namespace {
+static bool is_constant_index(int ntensor, const int64_t* strides) {
+  AT_ASSERT(ntensor >= 3);
+  for (const auto arg : c10::irange(2, ntensor)) {
+    if (strides[arg] != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+
+struct Indexer {
+  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
+          IntArrayRef original_sizes, IntArrayRef original_strides)
+    : num_indexers(num_indexers)
+    , indexers(indexers)
+    , indexer_strides(indexer_strides)
+    , original_strides(original_strides.data())
+    , original_sizes(original_sizes.data()) {
+    AT_ASSERT(static_cast(original_strides.size()) == num_indexers);
+    AT_ASSERT(static_cast(original_sizes.size()) == num_indexers);
+  }
+
+  int64_t num_indexers;
+  char** indexers;
+  const int64_t* indexer_strides;
+  const int64_t* original_strides;
+  const int64_t* original_sizes;
+
+  int64_t get(int64_t idx) {
+    int64_t offset = 0;
+    for (const auto j : c10::irange(num_indexers)) {
+      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
+      int64_t size = original_sizes[j];
+      TORCH_CHECK_INDEX(value >= -size && value < size,
+                        "index ", value, " is out of bounds for dimension ", j, " with size ", size);
+      if (value < 0) {
+        value += size;
+      }
+      offset += value * original_strides[j];
+    }
+    return offset;
+  }
+};
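+
+// Example (illustrative): with a single index tensor over a dimension of size 5
+// whose byte stride in the source tensor is 4, Indexer::get wraps an index value
+// of -2 to 3 and returns the byte offset 3 * 4 = 12; out-of-range values such as
+// 5 or -6 trip the TORCH_CHECK_INDEX above.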
+} // anonymous namespace
+
+template 
+void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
+                      const func_t& f, bool serial_execution=false)
+{
+  int ntensor = iter.ntensors();
+  // When launching the parallel index version, set a relatively small grain size,
+  // less than INTERNAL::GRAIN_SIZE, so that the available threads get a more
+  // balanced workload and better cache locality.
+  // The grain size here was chosen via op benchmarks to overcome the thread launch overhead.
+  const int index_parallel_grain_size = 3000;
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
+    char* dst = data[0];
+    char* src = data[1];
+    if (is_constant_index(ntensor, strides)) {
+      // specialization for when every element uses the same index
+      int64_t offset = indexer.get(0);
+      for (const auto i : c10::irange(n)) {
+        f(dst + strides[0] * i, src + strides[1] * i, offset);
+      }
+    } else {
+      for (const auto i : c10::irange(n)) {
+        int64_t offset = indexer.get(i);
+        f(dst + strides[0] * i, src + strides[1] * i, offset);
+      }
+    }
+  };
+  if (serial_execution) {
+    iter.serial_for_each(loop, {0, iter.numel()});
+  } else {
+    iter.for_each(loop, index_parallel_grain_size);
+  }
+}
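+
+// Usage sketch (illustrative): a gather-style kernel can be expressed as
+//
+//   cpu_index_kernel<scalar_t>(iter, index_size, index_stride,
+//     [](char* dst, char* src, int64_t offset) {
+//       *(scalar_t*)dst = *(scalar_t*)(src + offset);
+//     });
+//
+// where scalar_t is the element type and `offset` is the byte offset computed by
+// Indexer from the index tensors.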
+} // at
+} // native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h
new file mode 100644
index 0000000000000000000000000000000000000000..c85239e5a7067907af8c7e903208f2d4338c8213
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Intrinsics.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
+/* Clang-compatible compiler, targeting x86/x86-64 */
+#include 
+#elif defined(_MSC_VER)
+/* Microsoft C/C++-compatible compiler */
+#include 
+#if _MSC_VER <= 1900
+#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
+#endif
+#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+/* GCC-compatible compiler, targeting x86/x86-64 */
+#include 
+#elif defined(__GNUC__) && defined(__ARM_NEON__)
+/* GCC-compatible compiler, targeting ARM with NEON */
+#include 
+#elif defined(__GNUC__) && defined(__IWMMXT__)
+/* GCC-compatible compiler, targeting ARM with WMMX */
+#include 
+#elif (defined(__GNUC__) || defined(__xlC__)) && \
+    (defined(__VEC__) || defined(__ALTIVEC__))
+/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
+#include 
+/* We need to undef those tokens defined by  to avoid conflicts
+   with the C++ types. => Can still use __bool/__vector */
+#undef bool
+#undef vector
+#undef pixel
+#elif defined(__GNUC__) && defined(__SPE__)
+/* GCC-compatible compiler, targeting PowerPC with SPE */
+#include 
+#endif
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h
new file mode 100644
index 0000000000000000000000000000000000000000..d521bd122114b9abb6f052cff689a3ca120e0942
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/IsContiguous.h
@@ -0,0 +1,62 @@
+#pragma once
+
+namespace at { namespace native { inline namespace CPU_CAPABILITY {
+
+// n: number of function arguments (arity)
+// traits: function_traits (see FunctionTraits.h)
+// s: index of scalar argument or -1
+template 
+struct IsContiguous {
+  static bool eval(const int64_t* strides) {
+    using type = typename traits::template arg::type;
+    return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
+           IsContiguous::eval(strides);
+  }
+};
+
+// will be called when there is an output exists
+template 
+struct IsContiguous<0, 0, traits, s> {
+  static bool eval(const int64_t* strides) {
+    return strides[0] == sizeof(typename traits::result_type);
+  }
+};
+
+// will be called when there is no output
+template 
+struct IsContiguous<0, -1, traits, s> {
+  static bool eval(const int64_t* /*strides*/) {
+    return true;
+  }
+};
+
+// output and all inputs are contiguous
+template ::value>::type* = nullptr>
+static inline bool is_contiguous(const int64_t* strides) {
+  return IsContiguous::eval(strides);
+}
+
+template ::value>::type* = nullptr>
+static inline bool is_contiguous(const int64_t* strides) {
+  return IsContiguous::eval(strides);
+}
+
+// input at `s` is scalar (stride 0); output and other inputs are contiguous
+// NB: output is typically at strides[0] so first input corresponds to s=1
+template ::value>::type* = nullptr>
+static inline bool is_contiguous_scalar(const int64_t* strides) {
+  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
+  return IsContiguous::eval(strides);
+}
+
+template ::value>::type* = nullptr>
+static inline bool is_contiguous_scalar(const int64_t* strides) {
+  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
+  return IsContiguous::eval(strides);
+}
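+
+// Example (illustrative): for an op with signature float(float, float) the fully
+// contiguous stride pattern is {4, 4, 4} (output first, then inputs, each
+// sizeof(float)), so is_contiguous returns true; if the first input is a scalar,
+// the pattern {4, 0, 4} satisfies is_contiguous_scalar with s = 1 instead.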
+
+}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h
new file mode 100644
index 0000000000000000000000000000000000000000..1bf461849ac82c56431eb23a5e651e71c09df7aa
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/LogAddExp.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+inline namespace CPU_CAPABILITY {
+
+// custom min and max to be used in logcumsumexp for complex arguments
+template 
+std::pair, c10::complex> _logcumsumexp_minmax(c10::complex x, c10::complex y) {
+  if (at::_isnan(y)) {  // either real is nan or imag is nan
+    return std::make_pair(y, y);
+  } else if (at::_isnan(x)) {  // either real is nan or imag is nan
+    return std::make_pair(x, x);
+  } else {
+    return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x);
+  }
+}
+
+template 
+scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
+  // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
+  scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
+  scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
+  if (min != max || std::isfinite(min)) {
+    // nan will be propagated here
+    return std::log1p(std::exp(min - max)) + max;
+  } else {
+    // special case to correctly handle infinite cases
+    return x;
+  }
+}
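+
+// Example (illustrative): _log_add_exp_helper(0.0, 0.0) returns
+// std::log1p(std::exp(0.0)) + 0.0 = log(2) ~= 0.6931, i.e. log(e^0 + e^0).
+// The special case matters when x = y = -infinity: the naive formula would give
+// log1p(exp(-inf - -inf)) = log1p(exp(nan)) = nan, while the helper correctly
+// returns -infinity.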
+
+template 
+c10::complex _log_add_exp_helper(const c10::complex& x, const c10::complex& y) {
+  auto [min, max] = _logcumsumexp_minmax(x, y);
+  auto min_real = std::real(min);
+  auto max_real = std::real(max);
+
+  if (at::_isnan(min)) {  // either real is nan or imag is nan
+    // handling the "infectious" NaNs
+    return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()};
+  } else if (!std::isfinite(min_real) && (min_real == max_real)) {
+    if (min_real < 0) {
+      // handle the -inf case, the imaginary part here does not really matter as the exp(value)
+      // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
+      // It does not matter if we're taking the exp of this value
+      return min;
+    } else {
+      // handle the +inf case, we don't need the special precision for log1p for small values
+      // and to avoid producing nan in case of real(max) == real(min) == +inf
+      return std::log(std::exp(min) + std::exp(max));
+    }
+  } else {
+    return std::log1p(std::exp(min - max)) + max;
+  }
+}
+
+} // end namespace
+}} //end at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h
new file mode 100644
index 0000000000000000000000000000000000000000..016d4ded00f574af8068882ec339220d5f324cd2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Loops.h
@@ -0,0 +1,394 @@
+#pragma once
+
+// This file provides two functions to help write elementwise kernels:
+//
+//   cpu_kernel(TensorIterator iter, )
+//   cpu_kernel_vec(TensorIterator iter, , )
+//
+// Both functions may generate vectorized code. The cpu_kernel implementation
+// relies on the compiler's auto-vectorization. The cpu_kernel_vec
+// implementation uses x86 SIMD intrinsics when available. These functions
+// are only intended to be used in the ATen/native/cpu subdirectory, since files
+// in other directories are not compiled with AVX/AVX2 enabled. See README.md
+// for more details.
+//
+// For example, to write a multiplication kernel for float:
+//
+//   cpu_kernel(iter, [](float a, float b) { return a * b; });
+//
+// Or you may write:
+//
+//   cpu_kernel_vec(iter,
+//     [](float a, float b) { return a * b; },
+//     [](Vectorized a, Vectorized b) { return a * b; });
+//
+// See BinaryOpsKernel.cpp for the complete implementation
+//
+//
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at { namespace native { inline namespace CPU_CAPABILITY {
+
+using namespace vec;
+
+template 
+typename traits::ArgsTuple
+dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
+                 std::index_sequence) {
+  return std::make_tuple(
+      c10::load::type>(
+          data[INDEX] + i * strides[INDEX])...);
+}
+
+template 
+typename traits::ArgsTuple
+dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
+  using Indices = std::make_index_sequence;
+  return dereference_impl(data, strides, i, Indices{});
+}
+
+template 
+typename traits::ArgsTuple
+dereference_vec_impl(char* C10_RESTRICT data[],
+                     const typename traits::result_type& opt_scalar,
+                     size_t S,
+                     int64_t i,
+                     std::index_sequence) {
+  using Vec = typename traits::result_type;
+  using scalar_t = typename Vec::value_type;
+  return std::make_tuple(
+      S == INDEX + 1 ?
+      opt_scalar :
+      Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
+}
+
+template 
+typename traits::ArgsTuple
+dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
+  using Indices = std::make_index_sequence;
+  return dereference_vec_impl(data, opt_scalar, S, i, Indices{});
+}
+
+template ::result_type>::value>::type* = nullptr>
+static inline void
+execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
+  using traits = function_traits;
+  using result_type = typename traits::result_type;
+  for (; i < n; i++) {
+    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
+    *out_ptr = c10::guts::apply(std::forward(op), dereference(
+        &data[1],
+        &strides[1],
+        i));
+  }
+}
+
+template ::result_type>::value>::type* = nullptr>
+static inline void
+execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
+  using traits = function_traits;
+  for (; i < n; i++) {
+    c10::guts::apply(std::forward(op), dereference(
+        &data[0],
+        &strides[0],
+        i));
+  }
+}
+
+// Basic loop operation (one output, N inputs). May be auto-vectorized
+// by the compiler. Supports inputs and outputs of different types.
+template 
+static inline void
+basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
+  using traits = function_traits;
+  constexpr int ntensors = traits::arity + 1;
+
+  // Copying strides to temporary array helps auto vectorization in older GCC
+  // versions.
+  int64_t strides[ntensors];
+  for (const auto arg : c10::irange(ntensors)) {
+    strides[arg] = strides_[arg];
+  }
+
+  execute_op(data, strides, i, n, std::forward(op));
+}
+
+// the recursive variadic template for iterating over the returned tuple
+template
+struct TupleOutput {
+  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
+                     const T &tuple) {
+    TupleOutput::handle(data, strides, i, tuple);
+
+    auto output = std::get(tuple);
+    using output_type = decltype(output);
+    output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
+    *out_ptr = output;
+  }
+};
+
+// Base case for the above recursive template
+template
+struct TupleOutput {
+  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
+                     const T &tuple) {
+    auto output = std::get<0>(tuple);
+    using output_type = decltype(output);
+    output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
+    *out_ptr = output;
+  }
+};
+
+template
+void handle_tuple_outputs(char* C10_RESTRICT data[],
+                          const int64_t* strides,
+                          int64_t i,
+                          const std::tuple &tuple) {
+  TupleOutput::handle(data, strides, i, tuple);
+}
+
+// Loop operation for `cpu_kernel_multiple_outputs`.
+// 1. Use `c10::guts::apply` to make dynamic method invocation
+//    for the lambda passed in `cpu_kernel_multiple_outputs`.
+// 2. Iterate over the members of the returned tuple, set the corresponding
+//    output tensor by the tuple member in `handle_tuple_outputs` function.
+template 
+static inline void
+multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
+  using traits = function_traits;
+
+  using result_type = typename traits::result_type;
+  constexpr int num_outputs = std::tuple_size::value;
+  constexpr int ntensors = traits::arity + num_outputs;
+
+  // Copying strides to temporary array helps auto vectorization in older GCC
+  // versions.
+  int64_t strides[ntensors];
+  for (const auto arg : c10::irange(ntensors)) {
+    strides[arg] = strides_[arg];
+  }
+
+  for (; i < n; i++) {
+    auto output = c10::guts::apply(op, dereference(
+      &data[num_outputs],
+      &strides[num_outputs],
+      i));
+    handle_tuple_outputs(data, strides, i, output);
+  }
+}
+
+// Explicitly vectorized loop implementation. All inputs and outputs must be
+// the same type and contiguous with one exception: a single input may be
+// a scalar (stride 0). Its position is indicated by the argument `S`. If `S`
+// is 0, then there are no scalar inputs.
+template 
+static inline void
+vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
+  using traits = function_traits;
+  using scalar_t = typename function_traits::result_type;
+  using Vec = Vectorized;
+  constexpr int ntensors = traits::arity + 1;
+
+  char* C10_RESTRICT data[ntensors];
+  for (const auto arg : c10::irange(ntensors)) {
+    data[arg] = data_[arg];
+  }
+
+  Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
+  int64_t i = 0;
+  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
+    auto args1 = dereference_vec(&data[1], opt_scalar, S, i);
+    auto args2 = dereference_vec(&data[1], opt_scalar, S, i + Vec::size());
+    auto out1 = c10::guts::apply(std::forward(vop), std::move(args1));
+    auto out2 = c10::guts::apply(std::forward(vop), std::move(args2));
+    out1.store(data[0] + i * sizeof(scalar_t));
+    out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
+  }
+  if (i < n) {
+    int64_t strides[ntensors];
+    for (const auto arg : c10::irange(ntensors)) {
+      strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
+    }
+    basic_loop(data, strides, i, n, std::forward(op));
+  }
+}
+
+
+template 
+static inline void unroll_contiguous_scalar_checks(
+    const int64_t* /*strides*/,
+    std::index_sequence<>,
+    cb_t&& cb) {
+  cb(0);
+}
+
+template 
+static inline void unroll_contiguous_scalar_checks(
+    const int64_t* strides,
+    std::index_sequence,
+    cb_t&& cb) {
+  if (is_contiguous_scalar(strides)) {
+    cb(INDEX0 + 1);
+  } else {
+    unroll_contiguous_scalar_checks(strides, std::index_sequence{}, std::forward(cb));
+  }
+}
+
+template 
+struct VectorizedLoop2d {
+  op_t op;
+  vop_t vop;
+
+  using traits = function_traits;
+  static constexpr int ntensors = traits::arity + 1;
+  using data_t = std::array;
+
+  VectorizedLoop2d(const op_t &op, vop_t vop):
+    op(op), vop(std::move(vop)) {}
+
+  static void advance(data_t &data, const int64_t *outer_strides) {
+    for (const auto arg : c10::irange(data.size())) {
+      data[arg] += outer_strides[arg];
+    }
+  }
+
+  void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
+    data_t data;
+    std::copy_n(base, ntensors, data.data());
+    const int64_t *outer_strides = &strides[ntensors];
+
+    if (is_contiguous(strides)) {
+      for (const auto i C10_UNUSED : c10::irange(size1)) {
+        vectorized_loop(data.data(), size0, 0, op, vop);
+        advance(data, outer_strides);
+      }
+    } else {
+      using Indices = std::make_index_sequence;
+      unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t idx) {
+        if (idx) {
+          for (const auto i C10_UNUSED : c10::irange(size1)) {
+            vectorized_loop(data.data(), size0, idx, op, vop);
+            advance(data, outer_strides);
+          }
+        } else {
+          for (const auto i C10_UNUSED : c10::irange(size1)) {
+            basic_loop(data.data(), strides, 0, size0, op);
+            advance(data, outer_strides);
+          }
+        }
+      });
+    }
+  }
+};
+
+template <typename op_t, typename vop_t>
+VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
+    const op_t &op, const vop_t &vop) {
+  return VectorizedLoop2d<op_t, vop_t>(op, vop);
+}
+
+template <typename func_t>
+void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
+  using traits = function_traits<func_t>;
+  // this could be extended to work with void return types
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+  // dynamic casting not currently supported on CPU
+  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
+
+  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
+    // basic_loop can handle 1d slices with arbitrary strides, and 1d slices are all that
+    // iter.for_each ever sends to the loop lambda
+    basic_loop(data, strides, 0, n, std::forward<func_t>(op));
+  }, grain_size);
+  iter.cast_outputs();
+}
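+
+// Usage sketch (illustrative, not part of the upstream header): wiring a plain
+// float addition through cpu_kernel. Real ATen kernels wrap this in a dtype
+// dispatch macro; the function name here is made up for demonstration.
+inline void example_add_kernel_cpu(TensorIteratorBase& iter) {
+  cpu_kernel(iter, [](float a, float b) -> float { return a + b; });
+}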
+
+// This function helps write elementwise kernels that require multiple outputs.
+// It follows a structure similar to cpu_kernel.
+// Instead of the `basic_loop` function, a `multiple_outputs_loop` function is
+// used to handle the multiple return values.
+// For now the `needs_dynamic_casting` check is not added, as the lambda (`func_t`)
+// passed to `multiple_outputs_loop` returns a `std::tuple` instead of `scalar_t`.
+// `gpu_kernel_multiple_outputs` is also implemented without this check;
+// we could extend `needs_dynamic_casting` to support both `std::tuple` and
+// `thrust::tuple` in the future.
+template <typename func_t>
+void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
+  using traits = function_traits<func_t>;
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+
+  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
+    multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
+  }, grain_size);
+  iter.cast_outputs();
+}
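+
+// Usage sketch (illustrative): the functor returns a std::tuple with one
+// element per output tensor of the iterator (two outputs here), which is
+// unpacked by multiple_outputs_loop. Names are made up for demonstration.
+inline void example_divmod_kernel_cpu(TensorIteratorBase& iter) {
+  cpu_kernel_multiple_outputs(iter, [](int64_t a, int64_t b) -> std::tuple<int64_t, int64_t> {
+    return std::make_tuple(a / b, a % b);
+  });
+}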
+
+template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
+void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
+  using traits = function_traits<func_t>;
+  // this could be extended to work with void return types
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+  // dynamic casting not currently supported on CPU, but some kernels (like Fill)
+  // explicitly dynamic_cast, so we provide an opt-out for the check.
+  if constexpr (check_dynamic_cast) {
+    TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
+  }
+
+  iter.for_each(make_vectorized_loop2d(op, vop), grain_size);
+  iter.cast_outputs();
+}
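+
+// Usage sketch (illustrative): cpu_kernel_vec takes a scalar functor plus a
+// matching Vectorized functor that operates on whole SIMD lanes; e.g. a float
+// multiply. The function name is made up for demonstration.
+inline void example_mul_kernel_cpu(TensorIteratorBase& iter) {
+  cpu_kernel_vec(iter,
+      [](float a, float b) -> float { return a * b; },
+      [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
+}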
+
+template <typename func_t>
+void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
+  using traits = function_traits<func_t>;
+  constexpr bool result_void = std::is_void<typename traits::result_type>::value;
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
+                        ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
+  // dynamic casting not currently supported on CPU
+  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
+
+  iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
+    basic_loop(data, strides, 0, n, std::forward<func_t>(op));
+  }, range);
+  iter.cast_outputs();
+}
+
+template <typename func_t>
+void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
+  cpu_serial_kernel(iter, op, {0, iter.numel()});
+}
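+
+// Usage sketch (illustrative): cpu_serial_kernel also accepts a void-returning
+// functor, in which case the iterator is expected to have zero outputs (a pure
+// "for each element" pass). The lambda capture and name are for demonstration.
+inline void example_count_positive(TensorIteratorBase& iter, int64_t& counter) {
+  cpu_serial_kernel(iter, [&](float a) -> void {
+    if (a > 0.f) {
+      ++counter;
+    }
+  });
+}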
+
+template <typename func_t, typename vec_func_t>
+void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
+  using traits = function_traits<func_t>;
+  // this could be extended to work with void return types
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+  // dynamic casting not currently supported on CPU
+  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
+
+  iter.serial_for_each(make_vectorized_loop2d(op, vop), range);
+  iter.cast_outputs();
+}
+
+template <typename func_t, typename vec_func_t>
+void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
+  cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()});
+}
+
+}}}  // namespace at::native::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..0ea0bed4b1eae1e12dec1163835607310532f495
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1cb98b008bf531ad5e987176f5b8dbc2acf73872
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
+DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel);
+DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel);
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h
new file mode 100644
index 0000000000000000000000000000000000000000..d02a1bcc5171f1738909aa28b7cf0390522697f1
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/Reduce.h
@@ -0,0 +1,314 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at { namespace native { inline namespace CPU_CAPABILITY {
+
+using namespace vec;
+
+#define VEC_LOOP_HEADER(func_t, data) \
+  using scalar_t = typename function_traits<func_t>::result_type; \
+  using Vec = Vectorized<scalar_t>; \
+  char* out_ptr = data[0]; \
+  (void) out_ptr;
+
+// reduction that is contiguous over the input in dim 0
+template 
+static inline bool is_contiguous_reduction(const int64_t* strides) {
+  return strides[0] == 0 &&
+         strides[1] == sizeof(typename traits::arg2_t);
+}
+
+// reduction that is contiguous over the input in dim 1
+template 
+static inline bool is_outer_reduction(const int64_t* strides) {
+  return strides[0] == 0 &&
+         strides[2] == sizeof(typename traits::result_type) &&
+         strides[3] == sizeof(typename traits::arg2_t);
+}
+
+template 
+static inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
+                                        func_t op, vec_func_t vop, bool reduce) {
+  VEC_LOOP_HEADER(func_t, data)
+  const char* in1_ptr = data[1];
+  Vec acc[4];
+  for (const auto j : c10::irange(4)) {
+    acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t));
+  }
+  for (const auto i : c10::irange(1, n)) {
+    const char* ptr = in1_ptr + stride * i;
+    acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
+    acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
+    acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
+    acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
+  }
+  if (reduce) {
+    scalar_t buffer[Vec::size()];
+    acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
+    acc[0].store(buffer);
+    for (const auto j : c10::irange(1, Vec::size())) {
+      buffer[0] = op(buffer[0], buffer[j]);
+    }
+    auto dst = (scalar_t*)out_ptr;
+    *dst = op(*dst, buffer[0]);
+  } else {
+    for (const auto j : c10::irange(4)) {
+      auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
+      acc[j] = vop(acc[j], Vec::loadu(dst));
+      acc[j].store(dst);
+    }
+  }
+}
+
+template 
+static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
+  for (const auto j C10_UNUSED : c10::irange(n)) {
+    f();
+    data[0] += strides[0];
+    data[1] += strides[1];
+  }
+}
+
+// computes the reduction out = op(out, in)
+template 
+static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
+  VEC_LOOP_HEADER(func_t, data)
+  int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
+  int64_t count = n / (4 * Vec::size());
+  if (count > 0) {
+    vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
+  }
+  char* ptrs[3] = { data[0], data[0], data[1] };
+  int64_t strides[] = { 0, 0, sizeof(scalar_t) };
+  basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
+}
+
+// computes the reduction out = op(out, in)
+template 
+static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
+  VEC_LOOP_HEADER(func_t, data)
+
+  // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
+#if defined(CPU_CAPABILITY_AVX512)
+  int64_t outer_stride[2] = { 256, 256 };
+#else
+  int64_t outer_stride[2] = { 128, 128 };
+#endif
+  UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
+    vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
+  });
+
+  // reduce down the remaining columns
+  int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
+  int64_t remaining = size1 % (4 * Vec::size());
+  UNARY_OUTER_LOOP(data, step, remaining, [&] {
+    char* ptrs[3] = { data[0], data[0], data[1] };
+    int64_t strides[] = { 0, 0, inner_stride };
+    basic_loop(ptrs, strides, 0, size0, op);
+  });
+}
+
+template
+static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
+  // static_assert(std::is_same::value, "data types must match");
+  if (index < num_outputs) {
+    char *out = (char *) iter.data_ptr(index);
+    *(res_t *) out = result;
+  }
+}
+
+template
+static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
+  AT_ASSERT(num_outputs == 1);
+  set_result(0, result, iter, num_outputs);
+}
+
+template
+static inline typename std::enable_if::type
+for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
+  return i;
+}
+
+template
+static inline typename std::enable_if::type
+for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) {
+  if (i < (size_t)num_outputs) {
+    set_result(i, std::get(t), iter, num_outputs);
+    return for_each_in_tuple(t, iter, num_outputs);
+  }
+  return i;
+}
+
+template
+static void set_results(const std::tuple& result, const TensorIteratorBase &iter, const int num_outputs) {
+  AT_ASSERT(num_outputs >= 1);
+  std::size_t result_size = for_each_in_tuple(result, iter, num_outputs);
+  AT_ASSERT((size_t)num_outputs == result_size);
+}
+
+template 
+struct all_same : std::conjunction<
+  std::is_same...
+> {};
+
+// data_t is the input/output data type.
+// acc_t is a type that contains all the necessary data
+// to continue reducing.
+// index_t is a one-dimensional index
+//
+// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
+// the following.
+// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value.
+// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
+// project: acc_t -> out_t finishes the reduction, getting the required output.
+//
+// Additionally, acc_t must be default-constructible:
+// acc_t {} is an identity for combine,
+// and project(acc_t {}) is the value of the operation on zero elements.
+//
+// The point of `combine` is to support parallelization -
+// the idea is to run one sequence of `reduce` calls per thread of execution,
+// and then to combine them at the end with `combine`.
+//
+// If there is more than one output element,
+// our parallelization strategy is to use one thread for each of them,
+// which means that `combine` will never be called.
+//
+// If, on the other hand, there is only one, then we split the input
+// into several pieces, reduce each separately, and then combine them.
+// (A minimal example ops struct is sketched after binary_kernel_reduce below.)
+
+template 
+void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
+  using rf_t = decltype(&ops_t::reduce);
+  using cf_t = decltype(&ops_t::combine);
+  using pf_t = decltype(&ops_t::project);
+  using r_traits = binary_function_traits;
+  using c_traits = binary_function_traits;
+  using p_traits = unary_function_traits;
+  using acc_t = typename p_traits::arg1_t;
+  using data_t = typename r_traits::arg2_t;
+  static_assert(
+    all_same<
+      acc_t,
+      init_t,
+      typename r_traits::arg1_t,
+      typename r_traits::result_type,
+      typename c_traits::arg1_t,
+      typename c_traits::arg2_t,
+      typename c_traits::result_type>::value,
+    "all accumulate types must match");
+  static_assert(
+    std::is_default_constructible::value,
+    "the accumulate type must be default-constructible"
+  );
+  const int num_outputs = iter.noutputs();
+  iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) {
+    auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
+      int ntensors = sub_iter.ntensors();
+      sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) {
+        AT_ASSERT(ntensors - num_outputs == 1);
+        char *in = data[ntensors - 1];
+        int64_t stride = strides[ntensors - 1];
+        for (const auto i : c10::irange(size)) {
+          acc = ops.reduce(acc, c10::load(in), begin + i);
+          in += stride;
+        }
+      }, {begin, end});
+      return ops.translate_idx(acc, sub_iter.view_offsets()[0]);
+    };
+    acc_t total_acc = init;
+    auto numel = sub_iter.numel();
+    if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+        at::in_parallel_region()) {
+      total_acc = reduction_body(total_acc, 0, numel);
+    } else {
+      int max_threads = at::get_num_threads();
+      AT_ASSERT(max_threads > 0);
+      static_assert(
+        !std::is_same::value,
+        "Concurrently modifying different references into std::vector is UB."
+      );
+      std::vector buffer((unsigned)max_threads, init);
+      at::parallel_for(0, numel, internal::GRAIN_SIZE,
+        [&](int64_t begin, int64_t end) {
+          auto& acc = buffer[at::get_thread_num()];
+          acc = reduction_body(acc, begin, end);
+        }
+      );
+      for (const auto i : c10::irange(max_threads)) {
+        total_acc = ops.combine(total_acc, buffer[i]);
+      }
+    }
+    set_results(ops.project(total_acc), sub_iter, num_outputs);
+  });
+}
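+
+// Illustrative ops_t sketch (not upstream code) satisfying the
+// reduce/combine/project contract described above, for a plain float sum
+// accumulated in double. binary_kernel_reduce also calls ops.translate_idx,
+// so an identity implementation is included. A typical call would be:
+//   binary_kernel_reduce(iter, ExampleSumOps{}, /*init=*/0.0);
+struct ExampleSumOps {
+  using acc_t = double;
+  acc_t reduce(acc_t acc, float data, int64_t /*idx*/) const { return acc + data; }
+  acc_t combine(acc_t a, acc_t b) const { return a + b; }
+  float project(acc_t acc) const { return static_cast<float>(acc); }
+  acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) const { return acc; }
+};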
+
+template 
+void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
+  using traits = binary_function_traits;
+  static_assert(
+    all_same<
+      typename traits::result_type,
+      typename traits::arg1_t,
+      typename traits::arg2_t>::value,
+    "all types must match");
+
+  iter.output_base().fill_(ident);
+  iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
+    int64_t outer_strides[] = { strides[2], strides[3] };
+    if (is_contiguous_reduction(strides)) {
+      // input is contiguous in dim 0, output is reduced in dim 0
+      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
+        vectorized_inner_reduction(data, size0, op, vop);
+      });
+    } else if (is_outer_reduction(strides)) {
+      // input and output are contiguous in dim 1
+      int64_t inner_stride = strides[1]; // stride of input in dim 0
+      vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop);
+    } else {
+      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
+        char* ptrs[3] = { data[0], data[0], data[1] };
+        int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
+        basic_loop(ptrs, inner_strides, 0, size0, op);
+      });
+    }
+  });
+}
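+
+// Usage sketch (illustrative): binary_kernel_reduce_vec takes matching scalar
+// and Vectorized functors plus the reduction identity; e.g. a float sum with
+// identity 0. The function name is made up for demonstration.
+inline void example_sum_reduce_vec(TensorIteratorBase& iter) {
+  binary_kernel_reduce_vec(
+      iter,
+      [](float a, float b) -> float { return a + b; },
+      [](Vectorized<float> a, Vectorized<float> b) { return a + b; },
+      /*ident=*/0.0);
+}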
+
+// When the reduction is over the innermost dimension (dim 0 in TensorIterator)
+// and the input is contiguous in that dimension, `binary_kernel_reduce_lastdim`
+// can be used.
+static inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
+  return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
+      && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
+}
+
+template 
+void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
+  auto shape = iter.shape();
+  int64_t dim_size = shape[0];
+  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
+  TensorIterator sub_iter(iter);
+  // create sub iterator to parallel on all non-reduce-dims
+  sub_iter.narrow(0, 0, 1);
+  auto loop = [&](char** data, const int64_t* strides, int64_t size) {
+    char* out = data[0];
+    char* in = data[1];
+    for (int64_t i = 0; i < size; ++i) {
+      reduce_op(out, in, dim_size);
+      out += strides[0];
+      in += strides[1];
+    }
+  };
+  sub_iter.for_each(loop, grain_size);
+}
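+
+// Illustrative reduce_op sketch for binary_kernel_reduce_lastdim: the functor
+// receives raw output/input pointers for one outer element plus the size of
+// the reduced (last) dimension. This sketch assumes contiguous float input,
+// which is what is_reduce_lastdim checks for; the name is made up.
+inline void example_sum_lastdim(TensorIteratorBase& iter) {
+  binary_kernel_reduce_lastdim(iter, [](char* out, char* in, int64_t size) {
+    float acc = 0.f;
+    for (int64_t k = 0; k < size; ++k) {
+      acc += reinterpret_cast<const float*>(in)[k];
+    }
+    *reinterpret_cast<float*>(out) = acc;
+  });
+}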
+
+}}}  // namespace at::native::CPU_CAPABILITY
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..8f8c30e6fcebee4252f5272787599c1fabf67861
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h
@@ -0,0 +1,238 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+inline namespace CPU_CAPABILITY {
+
+using namespace vec;
+
+#define AT_DISPATCH_REDUCTION_TYPES(op, ...)                                   \
+  [&] {                                                                        \
+    switch (op) {                                                              \
+      case ReductionType::SUM: {                                               \
+        static constexpr auto reduce = ReductionType::SUM;                     \
+        return __VA_ARGS__();                                                  \
+      }                                                                        \
+      case ReductionType::MEAN: {                                              \
+        static constexpr auto reduce = ReductionType::MEAN;                    \
+        return __VA_ARGS__();                                                  \
+      }                                                                        \
+      case ReductionType::MIN: {                                               \
+        static constexpr auto reduce = ReductionType::MIN;                     \
+        return __VA_ARGS__();                                                  \
+      }                                                                        \
+      case ReductionType::MAX: {                                               \
+        static constexpr auto reduce = ReductionType::MAX;                     \
+        return __VA_ARGS__();                                                  \
+      }                                                                        \
+      case ReductionType::PROD: {                                              \
+        static constexpr auto reduce = ReductionType::PROD;                    \
+        return __VA_ARGS__();                                                  \
+      }                                                                        \
+    }                                                                          \
+  }()
+
+template 
+inline vec_scalar_t init_value() {
+  using acc_t = vec_scalar_t;
+  acc_t val;
+  if (reduce == ReductionType::SUM ||
+      reduce == ReductionType::MEAN) {
+    val = static_cast(0);
+  } else if (reduce == ReductionType::PROD) {
+    val = static_cast(1);
+  } else if (reduce == ReductionType::MAX) {
+    val = -std::numeric_limits::infinity();
+  } else {
+    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
+    val = std::numeric_limits::infinity();
+  }
+  return val;
+}
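+
+// Usage sketch (illustrative) for the AT_DISPATCH_REDUCTION_TYPES macro above:
+// it turns a runtime ReductionType into the compile-time constant `reduce`
+// inside the lambda, so templated helpers such as init_value can be
+// instantiated on it. The function name is made up for demonstration.
+inline float example_identity_for(ReductionType op) {
+  return AT_DISPATCH_REDUCTION_TYPES(op, [&]() {
+    return init_value<float, reduce>();
+  });
+}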
+
+template 
+inline vec_scalar_t init_value(const c10::optional& initial) {
+  using acc_t = vec_scalar_t;
+  if (initial.has_value()) {
+    return initial.value().to();
+  } else {
+    return init_value();
+  }
+}
+
+template 
+inline void init(scalar_t* out, int64_t size, const vec_scalar_t& val) {
+  using Vec = Vectorized>;
+  map(
+      [val](Vec x) { return Vec(val); },
+      out,
+      out,
+      size);
+}
+
+template 
+inline void init(scalar_t* out, int64_t size, const c10::optional& initial) {
+  using acc_t = vec_scalar_t;
+  acc_t val = init_value(initial);
+  init(out, size, val);
+}
+
+// overload with `include_self`, used by scatter_reduce
+template 
+inline void init(scalar_t* out, int64_t size, bool include_self = false) {
+  using acc_t = vec_scalar_t;
+  if (!include_self) {
+    acc_t val = init_value();
+    init(out, size, val);
+  }
+}
+
+template 
+inline void _init(scalar_t* self_ptr, at::opmath_type* buffer_ptr, int64_t size, bool include_self) {
+  if (!include_self) {
+    init, reduce>(buffer_ptr, size, include_self);
+  } else {
+    vec::convert(self_ptr, buffer_ptr, size);
+  }
+}
+
+template 
+inline typename std::enable_if::value, scalar_t>::type
+_max(const scalar_t& x, const scalar_t& y) {
+  return at::_isnan(y) ? y : std::max(x, y);
+}
+
+template 
+inline Vectorized _max(const Vectorized& x, const Vectorized& y) {
+  // vec::maximum propagates NaN
+  return vec::maximum(x, y);
+}
+
+template 
+inline typename std::enable_if::value, Vec2>::type
+_max(const vec_t& x, const vec_t& y) {
+  // vec::maximum propagates NaN
+  return maximum(x, y);
+}
+
+template 
+inline typename std::enable_if::value, scalar_t>::type
+_min(const scalar_t& x, const scalar_t& y) {
+  return at::_isnan(y) ? y : std::min(x, y);
+}
+
+template 
+inline Vectorized _min(const Vectorized& x, const Vectorized& y) {
+  // vec::minimum propagates NaN
+  return vec::minimum(x, y);
+}
+
+template 
+inline typename std::enable_if::value, Vec2>::type
+_min(const vec_t& x, const vec_t& y) {
+  // vec::minimum propagates NaN
+  return minimum(x, y);
+}
+
+template , int> = 0>
+inline void map_acc(
+    const Op& vec_fun,
+    accumut* output_data,
+    const accumut* input_data,
+    const scalar_t* input_data2,
+    int64_t size) {
+  using Vec = vec::Vectorized;
+  using aVec = vec::Vectorized;
+  int64_t d = 0;
+  constexpr int64_t kVecSize = Vec::size();
+  constexpr int64_t kaVecSize = aVec::size();
+  for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
+    Vec data2_vec = Vec::loadu(input_data2 + d);
+    auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec);
+    aVec input_vec0 = aVec::loadu(input_data + d);
+    aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
+    vec_fun(input_vec0, data2_avec0).store(output_data + d);
+    vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
+  }
+  if (size - d > 0) {
+    int64_t tail_size = size - d;
+    Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
+    auto [data2_avec0, data2_avec1] = convert_to_float(data2_vec);
+    if (tail_size > kaVecSize) {
+      aVec input_vec0 = aVec::loadu(input_data + d);
+      aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
+      vec_fun(input_vec0, data2_avec0).store(output_data + d);
+      vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
+    } else {
+      aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
+      vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
+    }
+  }
+}
+
+// for Max and Min, propagate NaN:
+template 
+inline T update(const T& x, const T& y) {
+  if (reduce == ReductionType::SUM ||
+      reduce == ReductionType::MEAN) {
+    return x + y;
+  } else if (reduce == ReductionType::PROD) {
+    return x * y;
+  } else if (reduce == ReductionType::MAX) {
+    return _max(x, y);
+  } else {
+    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
+    return _min(x, y);
+  }
+}
+
+template 
+inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
+  using Vec = vec::Vectorized>;
+  map2(
+      [](Vec x, Vec y) { return update(x, y); },
+      out,
+      out,
+      data,
+      K);
+}
+
+template , int> = 0>
+inline void update(at::opmath_type* out, const scalar_t* data, int64_t K) {
+  using opmath_t = at::opmath_type;
+  using Vec = vec::Vectorized;
+  map_acc(
+      [](Vec x, Vec y) { return update(x, y); },
+      out,
+      out,
+      data,
+      K);
+}
+
+template 
+inline void write(scalar_t* out, int64_t count, int64_t K) {
+  using Vec = vec::Vectorized>;
+  if (reduce == ReductionType::MEAN) {
+    if (count > 0) {
+      vec::map(
+          [count](Vec x) { return x / Vec(count); },
+          out,
+          out,
+          K);
+    }
+  }
+}
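+
+// Illustrative accumulate-then-finalize sketch (not upstream code) showing how
+// the init/update/write helpers above fit together for a MEAN reduction over
+// `rows` rows of length K (caller-managed, contiguous float buffers assumed):
+inline void example_mean_rows(const float* data, float* out, int64_t rows, int64_t K) {
+  init<float, ReductionType::MEAN>(out, K);                    // fill out with the identity (0)
+  for (int64_t r = 0; r < rows; ++r) {
+    update<float, ReductionType::MEAN>(out, data + r * K, K);  // out += row r (same as SUM)
+  }
+  write<float, ReductionType::MEAN>(out, /*count=*/rows, K);   // finalize: out /= rows
+}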
+
+} // namespace CPU_CAPABILITY
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ebaf769f148a7887c5473aaa54b0f05fc55715ef
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+
+using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);
+
+DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub);
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa6785e41706f2b3eea51c7821c8a388ab866e4f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h
@@ -0,0 +1,144 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+#pragma once
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at { namespace native { namespace detail {
+
+struct InputMeta {
+  void* data_ptr;
+  int64_t inner_size;
+
+  InputMeta(const Tensor& t, int64_t dim, int64_t inner)
+      : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
+};
+
+// This kernel is used by two TensorList types:
+// 1. stack_serial_kernel uses at::ArrayRef
+// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
+//    ProcessedNodeInputWrapper.
+// When making changes, make sure that they are compatible with both types!
+template 
+void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      dim >= 0 && dim <= result.dim(),
+      "dim out of range in stack_serial_kernel_impl");
+  int64_t outer =
+      result.numel() / (result.sizes()[dim] * result.strides()[dim]);
+  scalar_t* result_data = result.data_ptr();
+  int64_t ninputs = tensors.size();
+  std::vector inputs;
+  inputs.reserve(ninputs);
+  for (const auto& tensor : tensors) {
+    inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
+  }
+
+  using Vec = vec::Vectorized;
+  scalar_t* result_ptr = result_data;
+  for (const auto i : c10::irange(outer)) {
+    for (const auto j : c10::irange(ninputs)) {
+      int64_t local_inner = inputs[j].inner_size;
+      scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
+
+      if (local_inner < Vec::size()) {
+        for (const auto k : c10::irange(local_inner)) {
+          result_ptr[k] = input_ptr[k];
+        }
+      } else {
+        vec::map(
+            [](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
+      }
+      result_ptr += local_inner;
+    }
+  }
+}
+
+// Checks to see whether native stack can be invoked under these conditions:
+// - result and input tensors are contiguous
+// - only one thread is used
+// - no type promotion has to occur
+// - tensors dtype is Double or Float
+template 
+bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
+  TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
+  const Tensor& first_tensor = tensors[0];
+  // stack dimension should be in range [0,firstTensor.dim())
+  // dim == firstTensor.dim() is a valid input, but it is handled by the default code path
+  // that uses unsqueeze
+  if (dim >= first_tensor.dim()) return false;
+  // Native stack doesn't apply when any tensor would be skipped (an empty 1-d tensor).
+  if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
+  // there should be no type promotion
+  if (result.dtype() != first_tensor.dtype()) return false;
+
+  auto first_tensor_mem_format = first_tensor.suggest_memory_format();
+  ScalarType dtype = first_tensor.scalar_type();
+
+  if (!result.is_contiguous(first_tensor_mem_format)) {
+    return false;
+  }
+
+  // fast path only works for Double and Float
+  if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
+    return false;
+  }
+
+  // check remainder of inputs
+  auto const &first_tensor_shape = first_tensor.sizes();
+  for (const auto i : c10::irange(1, tensors.size())) {
+    auto const &tensor = tensors[i];
+    TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
+      "stack expects each tensor to be equal size, but got ", first_tensor_shape,
+      " at entry 0 and ", tensor.sizes(), " at entry ", i);
+
+    // every tensor must be contiguous
+    // tensor sizes and strides must be the same
+    // there should be no type promotion
+    if (!tensor.is_contiguous(first_tensor_mem_format) ||
+      tensor.strides() != first_tensor.strides() ||
+      tensor.dtype() != dtype) {
+      return false;
+    }
+  }
+
+  // fast native stack should only be used when it is not worth using multiple threads
+  // or there is only one thread. Note that we aren't checking result.numel() here because
+  // it may not have been resized and we want to defer that cost till later.
+  int64_t numel_in_stack = first_tensor.numel() * tensors.size();
+  return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
+}
+
+template 
+struct CanUseNativeSerialStack;
+
+template 
+struct CanUseNativeSerialStack {
+  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
+    // Inputs cannot alias the output tensor
+    for (const auto i : c10::irange(tensors.size())) {
+      auto lap = at::get_overlap_status(result, tensors[i]);
+      TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
+          lap != at::MemOverlapStatus::Full, 0,
+          "unsupported operation: the input tensors cannot refer to any of the "
+          "output memory locations. Found overlap in input tensor ", i);
+    }
+
+    return can_use_native_serial_stack_impl(result, tensors, dim);
+  }
+};
+
+template 
+struct CanUseNativeSerialStack {
+  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
+    return can_use_native_serial_stack_impl(result, tensors, dim);
+  }
+};
+
+}}}  // namespace at::native::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..a5e1124e54f2aa299ce2cc45e370c599610d3bb4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using forward_fn = void (*)(const Tensor&, const Tensor&);
+using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
+
+DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
+DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
+DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
+DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
+
+using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
+using backward_fn_with_dim =
+    void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
+
+DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
+DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
+DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
+DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
+}
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ae6de525c371d14093ebfe8c19f4198f9eef921
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
+using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
+using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
+using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
+using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
+
+DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
+DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
+DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
+DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
+DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
+DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);
+
+} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8259c1f546bf6804407b1357618ef564a076609
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/StackKernel.h
@@ -0,0 +1,12 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+
+using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t);
+DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h
new file mode 100644
index 0000000000000000000000000000000000000000..c5ee2344bc178f9736da0f353d02ecb543bfd598
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h
@@ -0,0 +1,1376 @@
+/*
+The Python Imaging Library (PIL) is
+
+    Copyright © 1997-2011 by Secret Labs AB
+    Copyright © 1995-2011 by Fredrik Lundh
+
+Pillow is the friendly PIL fork. It is
+
+    Copyright © 2010-2022 by Alex Clark and contributors
+
+Like PIL, Pillow is licensed under the open source HPND License
+*/
+
+// This code is heavily inspired by PILLOW-SIMD's implementation:
+// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c
+
+#pragma once
+#ifdef CPU_CAPABILITY_AVX2
+// TODO: This file only supports AVX2. We could split the AVX kernels into
+// smaller logical blocks in order to port them into the Vec.h logic. This would
+// allow us to support other vectorization architectures and perhaps also support
+// the non-vectorized fallback (we'd need to make sure it's not slower than the
+// current fallback).
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#endif
+
+
+namespace {
+
+static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
+  int32_t v;
+  if (i32_aligned) {
+    v = *(const int32_t*)ptr;
+  } else {
+    std::memcpy(&v, ptr, 4);
+  }
+  return _mm_cvtsi32_si128(v);
+}
+
+static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
+  return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
+}
+
+static inline void _write_endline_rgb_as_uint32(
+    uint8_t* C10_RESTRICT output,
+    uint32_t data
+) {
+  // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
+  // Here we explicitly set X as R1
+  uint8_t* data_ptr = reinterpret_cast(&data);
+  data_ptr[3] = output[3];
+  std::memcpy(output, data_ptr, 4);
+}
+
+at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
+  // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
+  // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
+  // as 32 bits. This generalizes to num_channels <= 4 and also works for
+  // non-channels_last tensors.
+
+  const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr();
+  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
+  auto num_channels = packed_tensor.size(0);
+
+  constexpr int rgba_size = 4;
+  auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
+  uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr();
+
+  auto stride_i = packed_tensor.stride(2);
+  auto stride_j = packed_tensor.stride(0);
+
+  for (const auto i : c10::irange(num_pixels)) {
+    for (const auto j : c10::irange(rgba_size)) {
+      unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
+    }
+  }
+  return unpacked_tensor;
+}
+
+void pack_rgb(
+    const at::Tensor& unpacked_tensor, // IN
+    const at::Tensor& packed_tensor // OUT
+) {
+  // Convert an unpacked channels-last 3-channel or 4-channel tensor back into the original data layout.
+
+  uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr();
+  uint8_t* packed = (uint8_t*)packed_tensor.data_ptr();
+  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
+  auto num_channels = packed_tensor.size(0);
+
+  auto unpacked_increment = unpacked_tensor.size(0);
+  auto packed_increment = packed_tensor.stride(2);
+  auto packed_stride = packed_tensor.stride(0);
+
+  TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
+
+  for (const auto i C10_UNUSED : c10::irange(num_pixels)) {
+    for (const auto j : c10::irange(num_channels)) {
+      packed[j * packed_stride] = unpacked[j];
+    }
+    unpacked += unpacked_increment;
+    packed += packed_increment;
+  }
+}
+
+void ImagingResampleHorizontalConvolution8u4x(
+    uint8_t* C10_RESTRICT lineOut0,
+    uint8_t* C10_RESTRICT lineOut1,
+    uint8_t* C10_RESTRICT lineOut2,
+    uint8_t* C10_RESTRICT lineOut3,
+    int64_t out_xsize,
+    const uint8_t* C10_RESTRICT lineIn0,
+    const uint8_t* C10_RESTRICT lineIn1,
+    const uint8_t* C10_RESTRICT lineIn2,
+    const uint8_t* C10_RESTRICT lineIn3,
+    int64_t in_xsize,
+    const int64_t* idx_ptr_xmin,
+    const int64_t* idx_ptr_size,
+    const int16_t* kk,
+    int kmax,
+    unsigned int coefs_precision,
+    int64_t num_channels,
+    bool is_last_line);
+
+void ImagingResampleHorizontalConvolution8u(
+    uint8_t* C10_RESTRICT lineOut,
+    int64_t out_xsize,
+    const uint8_t* C10_RESTRICT lineIn,
+    int64_t in_xsize,
+    const int64_t* idx_ptr_xmin,
+    const int64_t* idx_ptr_size,
+    const int16_t* kk,
+    int kmax,
+    unsigned int coefs_precision,
+    int64_t num_channels,
+    bool is_last_line);
+
+void ImagingResampleVerticalConvolution8u(
+    uint8_t* C10_RESTRICT lineOut,
+    const uint8_t* C10_RESTRICT lineIn,
+    int64_t xsize,
+    int64_t ids_min,
+    int64_t ids_size,
+    const int16_t* k,
+    unsigned int coefs_precision,
+    int64_t num_channels);
+
+template
+void ImagingResampleHorizontal(
+    const at::Tensor & unpacked_output,
+    const at::Tensor & unpacked_input,
+    int ksize,
+    const std::vector& horiz_indices_weights,
+    unsigned int horiz_weights_precision) {
+
+  // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.
+
+  // Input data is stored as
+  //   input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
+  // Weights are float values computed for each output pixel and rescaled to uint16:
+  //   weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
+  // We want to compute the output as following:
+  //   output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
+  // where
+  //   oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
+  //   oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
+  //   oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
+  //
+
+  // TODO: we may want to merge this into the fallback code (currently called
+  // basic_loop_aa_horizontal), although this may not be needed if / when we
+  // port all this code to use Vec.h, since that would potentially give us
+  // another fallback implementation.
+
+  const int16_t* kk = (int16_t*)(horiz_indices_weights[3].data_ptr());
+
+  auto xout = unpacked_output.size(2);
+  auto yout = unpacked_output.size(1);
+  auto xin = unpacked_input.size(2);
+  TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));
+
+  const int64_t* idx_ptr_xmin = horiz_indices_weights[0].data_ptr();
+  const int64_t* idx_ptr_size = horiz_indices_weights[1].data_ptr();
+
+  uint8_t* unpacked_output_p = unpacked_output.data_ptr();
+  const uint8_t* unpacked_input_p = unpacked_input.data_ptr();
+
+  int64_t yy = 0;
+  auto xout_stride = xout * num_channels;
+  auto xin_stride = xin * num_channels;
+  for (; yy < yout - 3; yy += 4) {
+    ImagingResampleHorizontalConvolution8u4x(
+        unpacked_output_p + yy * xout_stride,
+        unpacked_output_p + (yy + 1) * xout_stride,
+        unpacked_output_p + (yy + 2) * xout_stride,
+        unpacked_output_p + (yy + 3) * xout_stride,
+        xout,
+        unpacked_input_p + yy * xin_stride,
+        unpacked_input_p + (yy + 1) * xin_stride,
+        unpacked_input_p + (yy + 2) * xin_stride,
+        unpacked_input_p + (yy + 3) * xin_stride,
+        xin,
+        idx_ptr_xmin,
+        idx_ptr_size,
+        kk,
+        ksize,
+        horiz_weights_precision,
+        num_channels,
+        yy + 3 == yout - 1);
+  }
+  for (; yy < yout; yy++) {
+    ImagingResampleHorizontalConvolution8u(
+        unpacked_output_p + yy * xout_stride,
+        xout,
+        unpacked_input_p + yy * xin_stride,
+        xin,
+        idx_ptr_xmin,
+        idx_ptr_size,
+        kk,
+        ksize,
+        horiz_weights_precision,
+        num_channels,
+        yy == yout - 1);
+  }
+}
+
+void ImagingResampleVertical(
+    const at::Tensor & unpacked_output,
+    const at::Tensor & unpacked_input,
+    int ksize,
+    const std::vector& vert_indices_weights,
+    unsigned int vert_weights_precision) {
+
+  // Interpolation vertical pass: we compute y-axis interpolation outputs.
+  // Input data is stored as
+  //   input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
+  // Weights are float values computed for each output pixel and rescaled to uint16:
+  //   weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
+  // We want to compute the output as following:
+  //   output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
+  // where
+  //   oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
+  //   oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
+  //   oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
+
+  // TODO: we may want to merge this into the fallback code (currently called
+  // basic_loop_aa_vertical), although this may not be needed if / when we
+  // port all this code to use Vec.h, since that would potentially give us
+  // another fallback implementation.
+  const int16_t* kk = (int16_t*)(vert_indices_weights[3].data_ptr());
+
+  const int64_t* idx_ptr_xmin = vert_indices_weights[0].data_ptr();
+  const int64_t* idx_ptr_size = vert_indices_weights[1].data_ptr();
+
+  uint8_t* unpacked_output_p = unpacked_output.data_ptr();
+  const uint8_t* unpacked_input_p = unpacked_input.data_ptr();
+
+  auto xout = unpacked_output.size(2);
+  auto yout = unpacked_output.size(1);
+  const auto num_channels = unpacked_input.size(0);
+  TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
+
+  auto xout_stride = xout * num_channels;
+  for (const auto yy : c10::irange(yout)) {
+    const auto* k = &kk[yy * ksize];
+    auto ids_min = idx_ptr_xmin[yy];
+    auto ids_size = idx_ptr_size[yy];
+    ImagingResampleVerticalConvolution8u(
+        unpacked_output_p + yy * xout_stride,
+        unpacked_input_p,
+        xout,
+        ids_min,
+        ids_size,
+        k,
+        vert_weights_precision,
+        num_channels);
+  }
+}
+
+// This is the only public entry point in this file. It supports bilinear or bicubic
+// mode for uint8 dtype when C <= 4, with or without antialias. The
+// implementation is based on PIL-SIMD.
+// Its equivalent (fallback) implementation for when AVX isn't supported or when
+// C > 4 is separable_upsample_generic_Nd_kernel_impl(). There are a bunch of
+// future improvements that can be done: look for the TODOs in this file.
+// For details on how the weights are computed and how the multiplications are
+// run on int (instead of float weights), see
+// [ Weights computation for uint8_t and multiplication trick ]
+// For details on how the AVX kernels are implemented, see
+// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
+// See also [ Support for antialias=False as a subcase of antialias=True ] to
+// learn more about how the antialias=False case is computed. The same holds
+// here: all these kernels are general enough to handle an arbitrary number of
+// weights, but when aa=False they could be optimized further.
+template 
+void upsample_avx_bilinear_bicubic_uint8(
+    const at::Tensor& input_,
+    const at::Tensor& output,
+    bool align_corners,
+    const scale_type& scales,
+    bool antialias) {
+  auto batch_size = input_.size(0);
+  auto num_channels = input_.size(1);
+  auto xin = input_.size(3);
+  auto yin = input_.size(2);
+  auto xout = output.size(3);
+  auto yout = output.size(2);
+
+  if (xin == xout && yin == yout) {
+    output.copy_(input_);
+    return;
+  }
+
+  at::Tensor input = input_;
+  if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
+    // If input is not contiguous with memory format channels first or channels last,
+    // we explicitly convert the input to contiguous channels last memory format.
+    // This simplifies the rest of the code and lets us assume that the format is either contiguous channels first or channels last.
+    // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
+    // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
+    // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
+    input = input.contiguous(at::MemoryFormat::ChannelsLast);
+  }
+
+  auto need_horizontal = xout != xin;
+  auto need_vertical = yout != yin;
+
+  int ksize_horiz, ksize_vert;
+  std::vector horiz_indices_weights, vert_indices_weights;
+  unsigned int horiz_weights_precision, vert_weights_precision;
+
+  bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
+  bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
+
+  if (need_horizontal) {
+    int interp_dim = 3;
+    auto stride = (skip_unpacking) ? num_channels : 4;
+    std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
+        F::compute_index_ranges_int16_weights(
+            /*input_size=*/xin,
+            /*output_size=*/xout,
+            /*stride=*/stride,
+            /*ndims=*/4,
+            /*reshape_dim=*/interp_dim,
+            /*align_corners=*/align_corners,
+            /*opt_scale=*/scales[interp_dim - 2],
+            /*antialias=*/antialias,
+            /*align_i32=*/true);
+  }
+
+  if (need_vertical) {
+    int interp_dim = 2;
+    auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
+    std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
+        F::compute_index_ranges_int16_weights(
+            /*input_size=*/yin,
+            /*output_size=*/yout,
+            /*stride=*/stride,
+            /*ndims=*/4,
+            /*reshape_dim=*/interp_dim,
+            /*align_corners=*/align_corners,
+            /*opt_scale=*/scales[interp_dim - 2],
+            /*antialias=*/antialias,
+            /*align_i32=*/true);
+  }
+
+  at::Tensor buffer_horiz, buffer_vert;
+  // Minor optimization: we can avoid allocating an extra buffer if we're performing
+  // horizontal-only or vertical-only interpolation, and if the tensor doesn't
+  // need repacking
+  if (need_horizontal && (need_vertical || !skip_packing)) {
+    auto c = (skip_unpacking) ? num_channels : 4;
+    buffer_horiz = at::empty({c, yin, xout}, input.options());
+  }
+  if (need_vertical && !skip_packing) {
+    auto c = (skip_unpacking) ? num_channels : 4;
+    buffer_vert = at::empty({c, yout, xout}, input.options());
+  }
+
+  for (const auto i : c10::irange(batch_size)) {
+
+    at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
+    at::Tensor unpacked_output;
+
+    if (need_horizontal) {
+      at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
+
+      if (skip_unpacking && num_channels == 3) {
+        ImagingResampleHorizontal<3>(
+          unpacked_output_temp,
+          unpacked_input,
+          ksize_horiz,
+          horiz_indices_weights,
+          horiz_weights_precision);
+      } else {
+        ImagingResampleHorizontal<4>(
+            unpacked_output_temp,
+            unpacked_input,
+            ksize_horiz,
+            horiz_indices_weights,
+            horiz_weights_precision);
+      }
+      unpacked_output = unpacked_input = unpacked_output_temp;
+    }
+    if (need_vertical) {
+      unpacked_output = (skip_packing) ? output[i] : buffer_vert;
+
+      ImagingResampleVertical(
+          unpacked_output,
+          unpacked_input,
+          ksize_vert,
+          vert_indices_weights,
+          vert_weights_precision
+      );
+    }
+
+    TORCH_INTERNAL_ASSERT(unpacked_output.defined());
+
+    if (!skip_packing) {
+      pack_rgb(unpacked_output, output[i]);
+    }
+  }
+}
+
+void ImagingResampleHorizontalConvolution8u4x(
+    uint8_t* C10_RESTRICT lineOut0,
+    uint8_t* C10_RESTRICT lineOut1,
+    uint8_t* C10_RESTRICT lineOut2,
+    uint8_t* C10_RESTRICT lineOut3,
+    int64_t out_xsize,
+    const uint8_t* C10_RESTRICT lineIn0,
+    const uint8_t* C10_RESTRICT lineIn1,
+    const uint8_t* C10_RESTRICT lineIn2,
+    const uint8_t* C10_RESTRICT lineIn3,
+    int64_t in_xsize,
+    const int64_t* idx_ptr_xmin,
+    const int64_t* idx_ptr_size,
+    const int16_t* kk,
+    int kmax,
+    unsigned int coefs_precision,
+    int64_t num_channels,
+    bool is_last_line) {
+
+  // Interpolation horizontal pass processing 4 vertical lines together.
+  // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
+  //   we can encode 4 values as a single uint32 value.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
+  // - We load and process 4 weight values in a loop ("block 4"), then we process 2 weight values
+  // in another loop ("block 2"), and finally we process 1 weight value in the final loop ("block 1").
+
+  // Define shuffling masks (low/high) for num_channels 4 and 3
+  // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
+  //   [r1 g1 b1 a1  r2 g2 b2 a2  ... | R1 G1 B1 A1  R2 G2 B2 A2 ... ] ->
+  //   [r1 0 r2 0  g1 0 g2 0  b1 0 b2 0  a1 0 a2 0 | R1 0 R2 0  G1 0 G2 0  B1 0 B2 0  A1 0 A2 0]
+  // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
+  //   [ ... r3 g3 b3 a3  r4 g4 b4 a4 | ... R3 G3 B3 A3  R4 G4 B4 A4 ] ->
+  //   [r3 0 r4 0  g3 0 g4 0  b3 0 b4 0  a3 0 a4 0 | R3 0 R4 0  G3 0 G4 0  B3 0 B4 0  A3 0 A4 0]
+
+  const auto mask_low_c4 = _mm256_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+  const auto mask_high_c4 = _mm256_set_epi8(
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
+  const auto mask_low_c3 = _mm256_set_epi8(
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
+  const auto mask_high_c3 = _mm256_set_epi8(
+      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
+      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
+
+  const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
+  const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
+
+  const auto stride = num_channels * sizeof(uint8_t);
+
+  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
+
+  // out_xsize = output width, out_x = output x index
+  // ids_min is the input offset index corresponding to out_x
+  // ids_size is the interpolation size for out_x
+
+  // Let's precompute ids_size limits for block 4 and block 2.
+  //
+  // In block 4 (4 means we process 4 weight values together), we read input data
+  // with _mm_loadu_si128, i.e. 16 bytes, per one line:
+  // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
+  // --> i <= ids_size - 16.0 / stride
+  // Strict boundary:
+  // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
+  // Soft boundary for reading inside the buffer except its boundaries:
+  // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
+  // RGBA: b4_delta = b4_delta_soft = 3
+  // RGB : b4_delta = 5
+  // RGB : b4_delta_soft = 4
+  const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
+
+  // In block 2 (2 means we process 2 weights values together), we read input data
+  // with _mm_loadl_epi64, i.e. 8 bytes, per one line:
+  // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
+  // --> i <= ids_size - 8.0 / stride
+  // Strict boundary:
+  // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
+  // Soft boundary for reading inside the buffer except its boundaries:
+  // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
+  // RGBA: b2_delta = b2_delta_soft = 1
+  // RGB : b2_delta = 2
+  // RGB : b2_delta_soft = 1
+  const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
+
+  const auto max_out_x_strided = out_xsize * stride;
+  const auto max_in_x_strided = in_xsize * stride;
+
+  const auto zero = _mm256_setzero_si256();
+  const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
+
+  for (const auto out_x : c10::irange(out_xsize)) {
+    const auto ids_min = idx_ptr_xmin[out_x];
+    const auto ids_size = idx_ptr_size[out_x];
+    const auto * k = &kk[out_x * kmax];
+    int64_t i = 0;
+
+    auto sss0 = initial;
+    auto sss1 = initial;
+
+    const auto * lineIn0_min = lineIn0 + ids_min;
+    const auto * lineIn1_min = lineIn1 + ids_min;
+    const auto * lineIn2_min = lineIn2 + ids_min;
+    const auto * lineIn3_min = lineIn3 + ids_min;
+
+    // block 4
+    for (; i < ids_size - b4_delta; i += 4) {
+      // Load 4 values from weight vector
+      // mmk0 = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
+      // mmk1 = [wl_2 wh_2 wl_3 wh_3  wl_2 wh_2 wl_3 wh_3  ...]
+      const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
+      const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
+
+      // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
+      // source = [
+      //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
+      //   R0 G0 B0 A0  R1 G1 B1 A1  R2 G2 B2 A2  R3 G3 B3 A3
+      // ]
+      // RGB: Load 10 pixels (5 per line)
+      // source = [
+      //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
+      //   R0 G0 B0 R1  G1 B1 R2 G2  B2 R3 G3 B3  R4 G4 B4 R5
+      // ]
+      auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
+          _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
+
+      // Apply mask_low:
+      // RGBA:
+      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  A0 0 A1 0]
+      // RGB:
+      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  0 0 0 0]
+      auto pix1 = _mm256_shuffle_epi8(source, mask_low);
+      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
+
+      // Apply mask_high:
+      // RGBA:
+      //   [r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  a2 0 a3 0 | R2 0 R3 0  G2 0 G3 0  B2 0 B3 0  A2 0 A3 0]
+      // RGB:
+      //   [r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  0 0 0 0 | R2 0 R3 0  G2 0 G3 0  B2 0 B3 0  0 0 0 0]
+      auto pix2 = _mm256_shuffle_epi8(source, mask_high);
+      // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
+
+      // Same as above to next lines 2 and 3:
+      auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
+          _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
+      auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
+      auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
+    }
+
+    // block 2
+    for (; i < ids_size - b2_delta; i += 2) {
+      // Load 2 values from weight vector
+      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
+      const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
+
+      // Load 4 pixels (2 per line) from input lines 0 and 1:
+      // RGBA: source1 = [
+      //   r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
+      //   R0 G0 B0 A0  R1 G1 B1 A1  0 0 0 0  0 0 0 0
+      // ]
+      // RGB: source1 = [
+      //   r0 g0 b0 r1  g1 b1 r2  0 0 0 0  0 0 0 0
+      //   R0 G0 B0 R1  G1 B1 R2  0 0 0 0  0 0 0 0
+      // ]
+      auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
+          _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
+      // Apply mask_low:
+      // RGBA:
+      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  A0 0 A1 0]
+      // RGB:
+      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  0 0 0 0]
+      auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
+      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
+
+      // Same as above for lines 2 and 3:
+      auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
+          _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
+      auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
+    }
+
+    // block 1
+    const auto i32_aligned = num_channels == 4;
+    for (; i < ids_size - 1; i++) {
+      // Load 1 value from weight vector
+      // mmk = [wl_0 wh_0 0 0  wl_0 wh_0 0 0  ...]
+      const auto mmk = _mm256_set1_epi32(k[i]);
+
+      // Load 2 pixels (one per line) from input lines 0 and 1:
+      // RGBA: pix1 = [
+      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  a0 0 0 0
+      //   R0 0 0 0  G0 0 0 0  B0 0 0 0  A0 0 0 0
+      // ]
+      // RGB: pix1 = [
+      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  r1 0 0 0
+      //   R0 0 0 0  G0 0 0 0  B0 0 0 0  R1 0 0 0
+      // ]
+      auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
+          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
+      // Compute output value as C += w0 * C0 for each channel in 32-bit precision
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
+
+      // Same as above for lines 2 and 3
+      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
+          mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
+    }
+
+    if (i == ids_size - 1) {
+      // last element
+      auto mmk = _mm256_set1_epi32(k[i]);
+      // For num_channels == 3 (3 bytes = one pixel) we tolerate reading 4 bytes;
+      // lines 0, 1 and 2 won't go out of the allocated memory bounds
+      auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
+          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
+          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));
+
+      auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
+      __m128i p1;
+      if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
+        uint8_t input[4];
+        std::memcpy(input, lineIn3_min + stride * i, 3);
+        p1 = mm_cvtepu8_epi32(input, true);
+      } else {
+        p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
+      }
+      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
+    }
+
+    // Convert fixed point values back to integers (truncating)
+    sss0 = _mm256_srai_epi32(sss0, coefs_precision);
+    sss1 = _mm256_srai_epi32(sss1, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
+    sss0 = _mm256_packs_epi32(sss0, zero);
+    sss1 = _mm256_packs_epi32(sss1, zero);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d 0 0 0 0)
+    sss0 = _mm256_packus_epi16(sss0, zero);
+    sss1 = _mm256_packus_epi16(sss1, zero);
+
+    // Write the output into single uint32
+    // (a b c d) -> x_uint32
+    auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
+    auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
+    auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
+    auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));
+
+    const auto out_x_strided = stride * out_x;
+
+    if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
+      // Copying 4 bytes is faster than copying 3, and this is a boundary case where we want to write
+      // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
+      // The 4th byte in the register (X) holds a garbage value, while the 4th byte in the output buffer (R1)
+      // holds a correct value that was previously computed for another line. This means we cannot simply
+      // write 4 bytes from the register to the output. Instead, we do the following:
+      //               v----------|
+      // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
+      // First, we write the R1 value into the 4th byte of the register: (R G B | X) -> (R G B | R1)
+      // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
+      // Output = [... R G B | R1 G1 B1 R2 ...]
+
+      _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
+      _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
+      _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);
+
+      if (C10_UNLIKELY(is_last_line)) {
+        // When handling the last line, we cannot access the next 4 bytes
+        // as they are out of the memory bounds.
+        std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
+      } else {
+        _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
+      }
+    } else if (num_channels == 3) {
+      // Copying 4 bytes is faster than copying 3, so here
+      // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
+      // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
+      std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
+      std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
+      std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
+      std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
+    } else {
+      // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
+      *(uint32_t *)(lineOut0 + out_x_strided) = o0;
+      *(uint32_t *)(lineOut1 + out_x_strided) = o1;
+      *(uint32_t *)(lineOut2 + out_x_strided) = o2;
+      *(uint32_t *)(lineOut3 + out_x_strided) = o3;
+    }
+  }
+}
+
+void ImagingResampleHorizontalConvolution8u(
+    uint8_t* C10_RESTRICT lineOut,
+    int64_t out_xsize,
+    const uint8_t* C10_RESTRICT lineIn,
+    int64_t in_xsize,
+    const int64_t* idx_ptr_xmin,
+    const int64_t* idx_ptr_size,
+    const int16_t* kk,
+    int kmax,
+    unsigned int coefs_precision,
+    int64_t num_channels,
+    bool is_last_line) {
+
+  // Interpolation horizontal pass processing only one vertical line.
+  // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
+  //   we can encode 4 values as a single uint32 value.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
+  // - We load and process 8 weight values in a loop ("block 8"), then 4 and 2 weight values in
+  //   further loops ("block 4" and "block 2"), and finally we process 1 weight value in the last loop ("block 1").
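+  //
+  // Scalar sketch of what this routine computes per output pixel (illustrative only; the
+  // vectorized blocks below are equivalent, including the fixed-point rounding and saturation):
+  //   for (out_x = 0; out_x < out_xsize; ++out_x)
+  //     for (c = 0; c < num_channels; ++c) {
+  //       int acc = 1 << (coefs_precision - 1);
+  //       for (i = 0; i < ids_size; ++i)
+  //         acc += kk[out_x * kmax + i] * lineIn[idx_ptr_xmin[out_x] + i * num_channels + c];
+  //       lineOut[out_x * num_channels + c] = clamp(acc >> coefs_precision, 0, 255);
+  //     }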
+
+  // Define various shuffling masks
+  const auto kmask_low = _mm256_set_epi8(
+      11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
+      3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
+  const auto kmask_high = _mm256_set_epi8(
+      15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
+      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
+  const auto kmask_hl = _mm256_set_epi8(
+      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
+      3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
+
+  const auto mask_low_c4 = _mm256_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+  const auto mask_high_c4 = _mm256_set_epi8(
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
+  const auto mask_low_c3 = _mm256_set_epi8(
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
+  const auto mask_high_c3 = _mm256_set_epi8(
+      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
+      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
+  const auto mask_hl_c3 = _mm256_set_epi8(
+      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
+  const auto mask_hl_c4 = _mm256_set_epi8(
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+
+  const auto mask_low128_c3 = _mm_set_epi8(
+      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
+  const auto mask_low128_c4 = _mm_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+
+  const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
+  const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
+  const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
+  const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
+
+  // out_xsize = output width, out_x = output x index
+  // ids_min is the input offset index corresponding to out_x
+  // ids_size is the interpolation size for out_x
+
+  const auto stride = num_channels * sizeof(uint8_t);
+  const auto zero = _mm_setzero_si128();
+
+  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
+
+  // Let's precompute ids_size limits for block 8, block 4 and block 2
+  //
+  // In block 8 (8 means we process 8 weight values together), we read at
+  // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
+  // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
+  // --> i <= ids_size - 32.0 / stride
+  // Strict boundary:
+  // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
+  // Soft boundary for reading inside the buffer except its boundaries:
+  // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
+  // RGBA: b8_delta = b8_delta_soft = 7
+  // RGB : b8_delta = 10
+  // RGB : b8_delta_soft = 9
+  const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9);
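+  // Worked example (illustrative note, not in the original source): for RGB, stride = 3 and a
+  // block-8 iteration spans 32 bytes, so ceil(32 / 3) = 11 gives the strict bound
+  // i < ids_size + 1 - 11 = ids_size - 10 (b8_delta = 10), while int(32 / 3) = 10 gives the
+  // soft bound i < ids_size - 9 (b8_delta_soft = 9). For RGBA the two bounds coincide at 7.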
+
+  // In block 4 (4 means we process 4 weight values together), we read
+  // 16 bytes of input data.
+  // lineIn + stride * (i + ids_min) + 16 <= lineIn + stride * (ids_size + ids_min)
+  // --> i <= ids_size - 16.0 / stride
+  // Strict boundary:
+  // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
+  // Soft boundary for reading inside the buffer except its boundaries:
+  // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
+  // RGBA: b4_delta = b4_delta_soft = 3
+  // RGB : b4_delta = 5
+  // RGB : b4_delta_soft = 4
+  const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
+
+  // In block 2 (2 means we process 2 weight values together), we read
+  // 8 bytes of input data.
+  // lineIn + stride * (i + ids_min) + 8 <= lineIn + stride * (ids_size + ids_min)
+  // --> i <= ids_size - 8.0 / stride
+  // Strict boundary:
+  // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
+  // Soft boundary for reading inside the buffer except its boundaries:
+  // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
+  // RGBA: b2_delta = b2_delta_soft = 1
+  // RGB : b2_delta = 2
+  // RGB : b2_delta_soft = 1
+  const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
+
+  const auto max_out_x_strided = out_xsize * stride;
+  const auto max_in_x_strided = in_xsize * stride;
+
+  for (const auto out_x : c10::irange(out_xsize)) {
+    __m128i sss;
+    const auto ids_min = idx_ptr_xmin[out_x];
+    const auto ids_size = idx_ptr_size[out_x];
+    const auto * k = &kk[out_x * kmax];
+    int64_t i = 0;
+
+    const auto * lineIn_min = lineIn + ids_min;
+
+    if (ids_size < 8) {
+      sss = _mm_set1_epi32(1 << (coefs_precision - 1));
+    } else {
+      // The two 128-bit lanes are summed at the end, so seed each lane with half of the rounding bias
+      auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
+
+      // block 8
+      for (; i < ids_size - b8_delta; i += 8) {
+        // Load 8 values from weight vector
+        auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
+        // ksource = [
+        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  wl_4 wh_4 wl_5 wh_5  wl_6 wh_6 wl_7 wh_7
+        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  wl_4 wh_4 wl_5 wh_5  wl_6 wh_6 wl_7 wh_7
+        // ]
+        auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
+
+        // RGBA: Load 8 pixels from input:
+        // source = [
+        //    r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
+        //    r4 g4 b4 a4  r5 g5 b5 a5  r6 g6 b6 a6  r7 g7 b7 a7
+        // ]
+        // RGB: Load 10 pixels from input (however we can process only 8 pixels):
+        // source = [
+        //    r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
+        //    r4 g4 b4 r5  g5 b5 r6 g6  b6 r7 g7 b7  r8 g8 b8 r9
+        // ]
+        auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
+            _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
+            _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
+
+        // Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
+        // RGBA: pix1 = [
+        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0
+        //   r4 0 r5 0  g4 0 g5 0  b4 0 b5 0  a4 0 a5 0
+        // ]
+        // RGB: pix1 = [
+        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0
+        //   r4 0 r5 0  g4 0 g5 0  b4 0 b5 0  0 0 0 0
+        // ]
+        auto pix1 = _mm256_shuffle_epi8(source, mask_low);
+        // mmk1 = [
+        //   wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...  ...
+        //   wl_4 wh_4 wl_5 wh_5  wl_4 wh_4 wl_5 wh_5  ...  ...
+        // ]
+        auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
+        // Compute output value as
+        //   C += w0 * C0 + w1 * C1
+        //   C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
+        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
+
+        // Same as above for higher part of each lane
+        auto pix2 = _mm256_shuffle_epi8(source, mask_high);
+        auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
+        // Compute output value as
+        //    C += w2 * C2 + w3 * C3
+        //    C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
+        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
+      }
+
+      // block 4
+      for (; i < ids_size - b4_delta; i += 4) {
+        // Load 4 values from weight vector
+        auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
+        // ksource = [
+        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  0 0 0 0  0 0 0 0
+        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  0 0 0 0  0 0 0 0
+        // ]
+        auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
+
+        // Load pixels from input line
+        tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
+        // RGBA: source = [
+        //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
+        //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
+        // ]
+        // RGB: source = [
+        //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
+        //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
+        // ]
+        auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
+
+        // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
+        // RGBA: pix = [
+        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0
+        //   r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  a2 0 a3 0
+        // ]
+        // RGB: pix = [
+        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0
+        //   r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  0 0 0 0
+        // ]
+        auto pix = _mm256_shuffle_epi8(source, mask_hl);
+        // mmk = [
+        //   wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ...
+        //   wl_2 wh_2 wl_3 wh_3  wl_2 wh_2 wl_3 wh_3  ... ...
+        // ]
+        auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
+        // Compute output value as
+        //   C += w0 * C0 + w1 * C1
+        //   C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
+        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
+      }
+
+      // Sum results between the lanes
+      sss = _mm_add_epi32(
+          _mm256_extracti128_si256(sss256, 0),
+          _mm256_extracti128_si256(sss256, 1));
+    }
+
+    // block 2
+    for (; i < ids_size - b2_delta; i += 2) {
+      // Load 2 values from weight vector
+      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
+      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
+      // Load pixels from input line
+      // RGBA: source = [
+      //   r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
+      // ]
+      // RGB: source = [
+      //   r0 g0 b0 r1  g1 b1 r2 g2  0 0 0 0  0 0 0 0
+      // ]
+      auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
+      // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
+      auto pix = _mm_shuffle_epi8(source, mask_low128);
+      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // block 1
+    const auto i32_aligned = num_channels == 4;
+    for (; i < ids_size - 1; i++) {
+      // Load 1 value from weight vector
+      // mmk = [wl_0 wh_0 0 0  wl_0 wh_0 0 0  ...]
+      auto mmk = _mm_set1_epi32(k[i]);
+      // Load one pixel from input line
+      // RGBA: pix = [
+      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  a0 0 0 0
+      // ]
+      // RGB: pix = [
+      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  r1 0 0 0
+      // ]
+      auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
+      // Compute output value as C += w0 * C0 for each channel in 32-bit precision
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    if (i == ids_size - 1) {
+      // last element
+      auto mmk = _mm_set1_epi32(k[i]);
+      __m128i pix;
+      auto p = lineIn_min + stride * i;
+      if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
+        uint8_t input[4];
+        std::memcpy(input, p, 3);
+        pix = mm_cvtepu8_epi32(input, true);
+      } else {
+        pix = mm_cvtepu8_epi32(p, i32_aligned);
+      }
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // Convert fixed point values back to integers (truncating)
+    sss = _mm_srai_epi32(sss, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
+    sss = _mm_packs_epi32(sss, zero);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d 0 0 0 0)
+    sss = _mm_packus_epi16(sss, zero);
+    // Write the output into single uint32
+    // (a b c d) -> x_uint32
+    auto o = _mm_cvtsi128_si32(sss);
+    const auto out_x_strided = stride * out_x;
+    if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
+      if (C10_UNLIKELY(is_last_line)) {
+        // When handling the last line, we cannot access the next 4 bytes
+        // as they are out of the memory bounds.
+        std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
+      } else {
+        // Copying 4 bytes is faster than copying 3, and this is a boundary case where we want to write
+        // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
+        // The 4th byte in the register (X) holds a garbage value, while the 4th byte in the output buffer (R1)
+        // holds a correct value that was previously computed for another line. This means we cannot simply
+        // write 4 bytes from the register to the output. Instead, we do the following:
+        //               v----------|
+        // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
+        // First, we write the R1 value into the 4th byte of the register: (R G B | X) -> (R G B | R1)
+        // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
+        // Output = [... R G B | R1 G1 B1 R2 ...]
+        _write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
+      }
+    } else if (num_channels == 3) {
+      // Copying 4 bytes is faster than copying 3, so here
+      // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
+      // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
+      std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
+    } else {
+      // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
+      *(uint32_t *)(lineOut + out_x_strided) = o;
+    }
+  }
+}
+
+void ImagingResampleVerticalConvolution8u(
+    uint8_t* C10_RESTRICT lineOut,
+    const uint8_t* C10_RESTRICT lineIn,
+    int64_t xsize,
+    int64_t ids_min,
+    int64_t ids_size,
+    const int16_t* k,
+    unsigned int coefs_precision,
+    int64_t num_channels) {
+
+  // Interpolation vertical pass processing one output line.
+  // - We process the x-axis data in SIMD blocks of 32, 8 and 4 bytes at a time.
+  // - We split the weight vector size for a given output index as a sum: ids_size = n * 2 + m.
+
+  // xsize = output width, which also equals the input width
+  // ids_size = interpolation size
+  // ids_min = input y start index
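+  //
+  // Scalar sketch of the vertical pass (illustrative only): every output byte j of this line is
+  //   lineOut[j] = clamp(((1 << (coefs_precision - 1))
+  //                       + sum_i k[i] * lineIn[ids_min + j + i * xsize * num_channels]) >> coefs_precision,
+  //                      0, 255);
+  // the SIMD blocks below just evaluate this for many values of j at once (32, 8 or 4 bytes per iteration).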
+  const auto stride = num_channels * sizeof(uint8_t);
+
+  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
+
+  const int64_t data_size = xsize * stride;
+  const int64_t data_stride = stride;
+  constexpr auto vec_size = 256 / 8;
+
+  const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
+  const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
+  const auto zero = _mm_setzero_si128();
+  const auto zero_256 = _mm256_setzero_si256();
+
+  int64_t j = 0;
+  // block 8
+  const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
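+  // Illustrative note (not in the original source): for RGB, data_stride = 3 and vec_size = 32,
+  // so b8_usable_vec_stride = (32 / 3) * 3 = 30; advancing j by 30 bytes keeps every 32-byte
+  // load and store starting on a pixel boundary. For RGBA it is simply 32.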
+  for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
+    auto sss0 = initial_256;
+    auto sss1 = initial_256;
+    auto sss2 = initial_256;
+    auto sss3 = initial_256;
+    int64_t i = 0;
+    const auto * lineIn_min = lineIn + j + ids_min;
+
+    for (; i < ids_size - 1; i += 2) {
+      // Load 2 values from weight vector
+      auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
+
+      // RGBA: Load 8 pixels per line
+      // source1 = [
+      //    r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
+      //    r4 g4 b4 a4  r5 g5 b5 a5  r6 g6 b6 a6  r7 g7 b7 a7
+      // ]
+      // RGB: Load 10 pixels per line (however we can process only 8 pixels):
+      // source1 = [
+      //    r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
+      //    r4 g4 b4 r5  g5 b5 r6 g6  b6 r7 g7 b7  r8 g8 b8 r9
+      // ]
+      auto source1 =
+          _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
+      auto source2 =
+          _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
+
+      // Interleave source1 and source2 from the low half of each 128-bit lane
+      // and cast the result to epi16
+      // RGBA: pix1 = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
+      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  a1 0 A1 0
+      // ]
+      // RGB: pix1 = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
+      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  0 0 0 0
+      // ]
+      auto source_lo = _mm256_unpacklo_epi8(source1, source2);
+      auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
+      // Compute output value as
+      //   C += w0 * c0 + w1 * C0
+      //   C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
+
+      // RGBA: pix2 = [
+      //    r2 0 R2 0  g2 0 G2 0  b2 0 B2 0  a2 0 A2 0
+      //    r3 0 R3 0  g3 0 G3 0  b3 0 B3 0  a3 0 A3 0
+      // ]
+      // RGB: pix2 = [
+      //    r2 0 R2 0  g2 0 G2 0  b2 0 B2 0  0 0 0 0
+      //    r3 0 R3 0  g3 0 G3 0  b3 0 B3 0  0 0 0 0
+      // ]
+      auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
+      // Compute output value as
+      //   C += w0 * c2 + w1 * C2
+      //   C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
+
+      // Same as above for the high half of each 128-bit lane
+      auto source_hi = _mm256_unpackhi_epi8(source1, source2);
+      auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
+      sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
+      auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
+      sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
+    }
+    // Same processing as above but with a single weight value
+    for (; i < ids_size; i += 1) {
+      auto mmk = _mm256_set1_epi32(k[i]);
+
+      auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
+
+      auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
+      auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
+      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
+      auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
+      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
+
+      auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
+      auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
+      sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
+      auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
+      sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
+    }
+    // Convert fixed point values back to integers (truncating)
+    sss0 = _mm256_srai_epi32(sss0, coefs_precision);
+    sss1 = _mm256_srai_epi32(sss1, coefs_precision);
+    sss2 = _mm256_srai_epi32(sss2, coefs_precision);
+    sss3 = _mm256_srai_epi32(sss3, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
+    sss0 = _mm256_packs_epi32(sss0, sss1);
+    sss2 = _mm256_packs_epi32(sss2, sss3);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d)
+    sss0 = _mm256_packus_epi16(sss0, sss2);
+
+    // Stores 32 bytes
+    _mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
+  }
+
+  // TODO: Do we also need block 4 ???
+  // block 2
+  const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
+  for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
+    auto sss0 = initial;
+    auto sss1 = initial;
+    int64_t i = 0;
+    const auto * lineIn_min = lineIn + j + ids_min;
+
+    for (; i < ids_size - 1; i += 2) {
+      // Load 2 values from weight vector
+      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ]
+      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
+
+      // Load 2 pixels per line
+      // RGBA: source1 = [
+      //    r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
+      // ]
+      // RGB: source1 = [
+      //    r0 g0 b0 r1  g1 b1 r2 g2  0 0 0 0  0 0 0 0
+      // ]
+      auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
+      auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
+      // Interleave source1 and source2 and cast the result to epi16
+      // RGBA: pix = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
+      // ]
+      // RGB: pix = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
+      // ]
+      auto source = _mm_unpacklo_epi8(source1, source2);
+      auto pix = _mm_unpacklo_epi8(source, zero);
+      // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
+      sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
+      // RGBA: pix = [
+      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  a1 0 A1 0
+      // ]
+      // RGB: pix = [
+      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  0 0 0 0
+      // ]
+      pix = _mm_unpackhi_epi8(source, zero);
+      // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
+      sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
+    }
+    // Same processing as above but with a single weight value
+    for (; i < ids_size; i += 1) {
+      auto mmk = _mm_set1_epi32(k[i]);
+
+      auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
+
+      auto source = _mm_unpacklo_epi8(source1, zero);
+      auto pix1 = _mm_unpacklo_epi8(source, zero);
+      sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
+      auto pix2 = _mm_unpackhi_epi8(source, zero);
+      sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
+    }
+    // Convert fixed point values back to integers (truncating)
+    sss0 = _mm_srai_epi32(sss0, coefs_precision);
+    sss1 = _mm_srai_epi32(sss1, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
+    sss0 = _mm_packs_epi32(sss0, sss1);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d)
+    sss0 = _mm_packus_epi16(sss0, sss0);
+    // Store 2 pixels to the output
+    _mm_storel_epi64((__m128i*)(lineOut + j), sss0);
+  }
+
+  // block 1
+  const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
+  const auto i32_aligned = num_channels == 4;
+  for (; j < data_size - 4; j += b1_usable_vec_stride) {
+    auto sss = initial;
+    int64_t i = 0;
+    const auto * lineIn_min = lineIn + j + ids_min;
+
+    for (; i < ids_size - 1; i += 2) {
+      // Load 2 values from weight vector
+      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ]
+      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
+
+      // Load one pixel per line
+      // RGBA: source1 = [
+      //    r0 g0 b0 a0  0 0 0 0  0 0 0 0  0 0 0 0
+      // ]
+      // RGB: source1 = [
+      //    r0 g0 b0 r1  0 0 0 0  0 0 0 0  0 0 0 0
+      // ]
+      auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
+      auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
+
+      // Interleave source1 and source2 and cast the result to epi16
+      // RGBA: pix = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
+      // ]
+      // RGB: pix = [
+      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
+      // ]
+      auto source = _mm_unpacklo_epi8(source1, source2);
+      auto pix = _mm_unpacklo_epi8(source, zero);
+      // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    for (; i < ids_size; i++) {
+      auto mmk = _mm_set1_epi32(k[i]);
+      auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+    sss = _mm_srai_epi32(sss, coefs_precision);
+    sss = _mm_packs_epi32(sss, zero);
+    sss = _mm_packus_epi16(sss, zero);
+
+    auto o = _mm_cvtsi128_si32(sss);
+
+    // Here we write 4 bytes to the output even if num_channels < 4, e.g. o = {r,g,b,X} for num_channels=3.
+    // It is OK to write the 4th byte (X) as it will be overwritten with new data on the next step.
+    // We also won't go out of bounds of the lineOut memory allocation.
+    std::memcpy(lineOut + j, (uint8_t *) &o, 4);
+  }
+
+  for (; j < data_size; j += data_stride) {
+    auto sss = initial;
+    int64_t i = 0;
+    const auto * lineIn_min = lineIn + j + ids_min;
+    // For RGBA we could use (ids_size - 1) as a tighter limit, but for RGB we could then read outside
+    // the memory boundary on the last remaining line
+    for (; i < ids_size - 2; i += 2) {
+      // Load two coefficients at once
+      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
+
+      // Load 2 lines
+      auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
+      auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
+
+      auto source = _mm_unpacklo_epi8(source1, source2);
+      auto pix = _mm_unpacklo_epi8(source, zero);
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // Same processing as above but with a single weight value
+    for (; i < ids_size; i++) {
+      auto mmk = _mm_set1_epi32(k[i]);
+
+      const uint8_t * p = lineIn_min + i * data_size;
+      __m128i pix;
+      // There is not much performance gain in using a more detailed condition like
+      //   num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
+      // where const int64_t in_max_size = data_size * in_ysize;
+      if (num_channels == 3) {
+        uint8_t input[4];
+        std::memcpy(input, p, 3);
+        pix = mm_cvtepu8_epi32(input, true);
+      } else {
+        pix = mm_cvtepu8_epi32(p, true);
+      }
+      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
+    }
+
+    // Convert fixed point values back to integers (truncating)
+    sss = _mm_srai_epi32(sss, coefs_precision);
+    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
+    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
+    sss = _mm_packs_epi32(sss, zero);
+    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
+    // (a a b b c c d d) -> (a b c d)
+    sss = _mm_packus_epi16(sss, zero);
+    // Store one pixel to the output
+    auto o = _mm_cvtsi128_si32(sss);
+    if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
+      std::memcpy(lineOut + j, (uint8_t *) &o, 3);
+    } else {
+      std::memcpy(lineOut + j, (uint8_t *) &o, 4);
+    }
+  }
+}
+
+} // anonymous namespace
+#endif // CPU_CAPABILITY_AVX2
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..106b068d198989f5a4d71a9fe20c79f1b5a5d915
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
@@ -0,0 +1,20 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+using weight_norm_fn = void(*)(
+    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
+using weight_norm_backward_fn = void(*)(
+    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
+    const TensorBase&, const TensorBase&, int64_t);
+
+DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub);
+DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub);
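+
+// Usage sketch (illustrative, not part of this header): a CPU kernel implementation typically
+// provides a function matching weight_norm_fn and registers it in its .cpp file via
+//   REGISTER_DISPATCH(weight_norm_stub, &weight_norm_kernel);
+// after which callers invoke the stub as weight_norm_stub(kCPU, w, norm, v, g, dim);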
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce37f0aecb8cb88758703b3783c720d840a3d926
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
@@ -0,0 +1,522 @@
+#pragma once
+/*
+   AVX implementation of sin, cos, sincos, exp and log
+
+   Based on "sse_mathfun.h", by Julien Pommier
+   http://gruntthepeon.free.fr/ssemath/
+
+   Copyright (C) 2012 Giovanni Garberoglio
+   Interdisciplinary Laboratory for Computational Science (LISC)
+   Fondazione Bruno Kessler and University of Trento
+   via Sommarive, 18
+   I-38123 Trento (Italy)
+
+  This software is provided 'as-is', without any express or implied
+  warranty.  In no event will the authors be held liable for any damages
+  arising from the use of this software.
+
+  Permission is granted to anyone to use this software for any purpose,
+  including commercial applications, and to alter it and redistribute it
+  freely, subject to the following restrictions:
+
+  1. The origin of this software must not be misrepresented; you must not
+     claim that you wrote the original software. If you use this software
+     in a product, an acknowledgment in the product documentation would be
+     appreciated but is not required.
+  2. Altered source versions must be plainly marked as such, and must not be
+     misrepresented as being the original software.
+  3. This notice may not be removed or altered from any source distribution.
+
+  (this is the zlib license)
+*/
+
+#include <ATen/native/cpu/Intrinsics.h>
+
+/* The original source of this file has been modified. */
+#if defined(CPU_CAPABILITY_AVX2)
+
+#if defined(__GNUC__)
+# define ALIGN32_BEG __attribute__((aligned(32)))
+#elif defined(_WIN32)
+# define ALIGN32_BEG __declspec(align(32))
+#endif
+
+typedef __m256  v8sf; // vector of 8 float (avx2)
+typedef __m256i v8si; // vector of 8 int   (avx2)
+
+/* declare some AVX constants -- why can't I figure a better way to do that? */
+#define _PS256_CONST(Name, Val)                                            \
+  static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
+#define _PI32_CONST256(Name, Val)                                            \
+  static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
+#define _PS256_CONST_TYPE(Name, Type, Val)                                 \
+  static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
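+
+// Illustrative note (not in the original source): for example, _PS256_CONST(1, 1.0f) expands to
+//   static const ALIGN32_BEG float _ps256_1[8] = { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f };
+// i.e. a 32-byte-aligned array that can be reloaded as a full __m256 via *(v8sf*)_ps256_1.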
+
+_PS256_CONST(1  , 1.0f);
+_PS256_CONST(0p5, 0.5f);
+/* the smallest non denormalized float number */
+_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
+_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
+_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
+
+_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
+_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
+
+_PI32_CONST256(0, 0);
+_PI32_CONST256(1, 1);
+_PI32_CONST256(inv1, ~1);
+_PI32_CONST256(2, 2);
+_PI32_CONST256(4, 4);
+_PI32_CONST256(0x7f, 0x7f);
+
+_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
+_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
+_PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
+_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
+_PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
+_PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
+_PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
+_PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
+_PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
+_PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
+_PS256_CONST(cephes_log_q1, -2.12194440e-4);
+_PS256_CONST(cephes_log_q2, 0.693359375);
+
+
+/* natural logarithm computed for 8 simultaneous floats;
+   returns NaN for x <= 0
+*/
+inline v8sf log256_ps(v8sf x) {
+  v8si imm0;
+  v8sf one = *(v8sf*)_ps256_1;
+
+  //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
+  v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
+
+  x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos);  /* cut off denormalized stuff */
+
+  // can be done with AVX2
+  imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);
+
+  /* keep only the fractional part */
+  x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
+  x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
+
+  // this is again another AVX2 instruction
+  imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
+  v8sf e = _mm256_cvtepi32_ps(imm0);
+
+  e = _mm256_add_ps(e, one);
+
+  /* part2:
+     if( x < SQRTHF ) {
+       e -= 1;
+       x = x + x - 1.0;
+     } else { x = x - 1.0; }
+  */
+  //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
+  v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
+  v8sf tmp = _mm256_and_ps(x, mask);
+  x = _mm256_sub_ps(x, one);
+  e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
+  x = _mm256_add_ps(x, tmp);
+
+  v8sf z = _mm256_mul_ps(x,x);
+
+  v8sf y = *(v8sf*)_ps256_cephes_log_p0;
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
+  y = _mm256_mul_ps(y, x);
+
+  y = _mm256_mul_ps(y, z);
+
+  tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
+  y = _mm256_add_ps(y, tmp);
+
+
+  tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
+  y = _mm256_sub_ps(y, tmp);
+
+  tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
+  x = _mm256_add_ps(x, y);
+  x = _mm256_add_ps(x, tmp);
+  x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
+  return x;
+}
+
+_PS256_CONST(exp_hi,        88.3762626647949f);
+_PS256_CONST(exp_lo,        -88.3762626647949f);
+
+_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
+_PS256_CONST(cephes_exp_C1, 0.693359375);
+_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
+
+_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
+_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
+_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
+_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
+_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
+_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
+
+inline v8sf exp256_ps(v8sf x) {
+  v8sf tmp = _mm256_setzero_ps(), fx;
+  v8si imm0;
+  v8sf one = *(v8sf*)_ps256_1;
+
+  x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
+  x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
+
+  /* express exp(x) as exp(g + n*log(2)) */
+  fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
+  fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
+
+  /* how to perform a floorf with SSE: just below */
+  //imm0 = _mm256_cvttps_epi32(fx);
+  //tmp  = _mm256_cvtepi32_ps(imm0);
+
+  tmp = _mm256_floor_ps(fx);
+
+  /* if greater, subtract 1 */
+  //v8sf mask = _mm256_cmpgt_ps(tmp, fx);
+  v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
+  mask = _mm256_and_ps(mask, one);
+  fx = _mm256_sub_ps(tmp, mask);
+
+  tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
+  v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
+  x = _mm256_sub_ps(x, tmp);
+  x = _mm256_sub_ps(x, z);
+
+  z = _mm256_mul_ps(x,x);
+
+  v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
+  y = _mm256_mul_ps(y, z);
+  y = _mm256_add_ps(y, x);
+  y = _mm256_add_ps(y, one);
+
+  /* build 2^n */
+  imm0 = _mm256_cvttps_epi32(fx);
+  // another two AVX2 instructions
+  imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
+  imm0 = _mm256_slli_epi32(imm0, 23);
+  v8sf pow2n = _mm256_castsi256_ps(imm0);
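+  // Illustrative note (not in the original source): adding the bias 127 and shifting into the
+  // exponent field builds 2^n directly from its IEEE-754 bit pattern, e.g. n = 3 gives
+  // (3 + 127) << 23 = 0x41000000, which reinterprets as the float 8.0f.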
+  y = _mm256_mul_ps(y, pow2n);
+  return y;
+}
+
+_PS256_CONST(minus_cephes_DP1, -0.78515625);
+_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
+_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
+_PS256_CONST(sincof_p0, -1.9515295891E-4);
+_PS256_CONST(sincof_p1,  8.3321608736E-3);
+_PS256_CONST(sincof_p2, -1.6666654611E-1);
+_PS256_CONST(coscof_p0,  2.443315711809948E-005);
+_PS256_CONST(coscof_p1, -1.388731625493765E-003);
+_PS256_CONST(coscof_p2,  4.166664568298827E-002);
+_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
+
+
+/* evaluation of 8 sines at once using AVX intrinsics
+
+   The code is the exact rewriting of the cephes sinf function.
+   Precision is excellent as long as x < 8192 (I did not bother to
+   take into account the special handling they have for greater values
+   -- it does not return garbage for arguments over 8192, though, but
+   the extra precision is missing).
+
+   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+   surprising but correct result.
+
+*/
+inline v8sf sin256_ps(v8sf x) { // any x
+  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
+  v8si imm0, imm2;
+
+  sign_bit = x;
+  /* take the absolute value */
+  x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
+  /* extract the sign bit (upper one) */
+  sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
+
+  /* scale by 4/Pi */
+  y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
+
+  /*
+    Here we start a series of integer operations, which are in the
+    realm of AVX2.
+    If we don't have AVX, let's perform them using SSE2 directives
+  */
+
+  /* store the integer part of y in mm0 */
+  imm2 = _mm256_cvttps_epi32(y);
+  /* j=(j+1) & (~1) (see the cephes sources) */
+  // another two AVX2 instructions
+  imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
+  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
+  y = _mm256_cvtepi32_ps(imm2);
+
+  /* get the swap sign flag */
+  imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
+  imm0 = _mm256_slli_epi32(imm0, 29);
+  /* get the polynom selection mask
+     there is one polynom for 0 <= x <= Pi/4
+     and another one for Pi/4
+#include 
+
+namespace at::native {
+
+using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
+using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
+using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
+DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
+DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b84a452747be996c5064b962ff418f524d2f999
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at { namespace native {
+
+inline ScalarType first_type() {
+  return ScalarType::Undefined;
+}
+
+template <typename... Args>
+inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
+  return arg.defined() ? arg.scalar_type() : first_type(parameters...);
+}
+
+template <typename... Args>
+inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
+  const auto parameter_type = first_type(parameters...);
+  return ((parameter_type != ScalarType::Undefined) &&
+          (parameter_type != input.scalar_type()));
+}
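+
+// Usage sketch (illustrative, not part of the original header): for a normalization op on CPU,
+//   bool mixed = is_mixed_type(input, weight, bias);
+// returns true when input is BFloat16/Half while the defined parameters are Float tensors,
+// and check_mixed_data_type(input, weight, bias) then validates that combination.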
+
+// currently on CPU, mixed data type is only supported
+// when input is 'BFloat16' or 'Half' and parameters are 'Float'
+inline void check_mixed_data_type(const Tensor& input) {
+  TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
+      "mixed dtype (CPU): all inputs must share same datatype.");
+}
+
+template <typename... Args>
+inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
+  TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
+      "mixed dtype (CPU): expect parameter to have scalar type of Float");
+  check_mixed_data_type(input, parameters...);
+}
+
+inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
+  return is_mixed_type ? ScalarType::Float : t.scalar_type();
+}
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..194e53a8e1fea892351619d2d26768095d815984
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/moments_utils.h
@@ -0,0 +1,206 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+inline namespace CPU_CAPABILITY {
+
+template <typename T> using opmath_t = at::opmath_type<T>;
+
+constexpr int64_t kChunkSize = 16;
+
+template <typename T>
+void AddMoments(
+    int64_t m0_add,
+    const T& m1_add,
+    const T& m2_add,
+    int64_t& m0,
+    T& m1,
+    T& m2) {
+  const int64_t n = m0 + m0_add;
+  const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
+  const T delta = m1_add - m1;
+  m1 += c * delta;
+  m2 += m2_add + delta * delta * c * static_cast<T>(m0);
+  m0 = n;
+}
+
+template <typename T>
+C10_ALWAYS_INLINE void AddMomentsVec(
+    int64_t m0_add,
+    const vec::Vectorized<T>& m1_add,
+    const vec::Vectorized<T>& m2_add,
+    int64_t& m0,
+    vec::Vectorized<T>& m1,
+    vec::Vectorized<T>& m2) {
+  using Vec = vec::Vectorized<T>;
+  const int64_t n = m0 + m0_add;
+  const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
+  const Vec c_vec(c);
+  const Vec delta = m1_add - m1;
+  m1 += c_vec * delta;
+  m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
+  m0 = n;
+}
+
+template <typename T>
+inline typename std::enable_if<std::is_same<T, opmath_t<T>>::value, void>::type
+UpdateMomentsVec(
+    int64_t m0,
+    const T* X_ptr,
+    const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
+    int64_t& m0_stk0,
+    vec::Vectorized<opmath_t<T>>& m1_stk0,
+    vec::Vectorized<opmath_t<T>>& m2_stk0) {
+  using Vec = vec::Vectorized<opmath_t<T>>;
+  Vec m1_vec(0);
+  Vec m2_vec(0);
+  for (const auto j : c10::irange(m0)) {
+    const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
+    const Vec delta_vec = x_vec - m1_vec;
+    m1_vec += delta_vec * c_vecs[j];
+    m2_vec += delta_vec * (x_vec - m1_vec);
+  }
+  AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
+}
+
+// each bfloat16/half vector will be converted to two float vectors,
+// and accumulated successively on m1_stk0/m2_stk0.
+template <typename T>
+inline typename std::enable_if<!std::is_same<T, opmath_t<T>>::value, void>::type
+UpdateMomentsVec(
+    int64_t m0,
+    const T* X_ptr,
+    const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
+    int64_t& m0_stk0,
+    vec::Vectorized<opmath_t<T>>& m1_stk0,
+    vec::Vectorized<opmath_t<T>>& m2_stk0) {
+  using Vec = vec::Vectorized<T>;
+  using fVec = vec::Vectorized<opmath_t<T>>;
+  fVec m1_fvec0(0), m1_fvec1(0);
+  fVec m2_fvec0(0), m2_fvec1(0);
+  for (const auto j : c10::irange(m0)) {
+    const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
+    auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec);
+    const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
+    const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
+    m1_fvec0 += delta_fvec0 * c_vecs[j];
+    m1_fvec1 += delta_fvec1 * c_vecs[j];
+    m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
+    m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
+  }
+  AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
+  AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
+}
+
+// Compute rowwise moments by Welford algorithm and cascade sum to improve
+// numerical stability.
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+// https://en.wikipedia.org/wiki/Pairwise_summation
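+//
+// Merge step used below (illustrative summary of the standard parallel-variance update):
+// combining partial moments (n_a, mean_a, M2_a) and (n_b, mean_b, M2_b) with n = n_a + n_b gives
+//   delta = mean_b - mean_a
+//   mean  = mean_a + delta * n_b / n
+//   M2    = M2_a + M2_b + delta^2 * n_a * n_b / n
+// which is exactly what AddMoments/AddMomentsVec compute with c = n_b / n.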
+template <typename T, int64_t kMaxDepth>
+std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
+  using math_t = opmath_t<T>;
+
+  constexpr int64_t kVecSize = vec::Vectorized<T>::size();
+  constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
+  const int64_t n = N / kVecSize;
+  const int64_t m = divup(n, kChunkSize);
+  const int64_t depth = utils::CeilLog2(m);
+
+  using Vec = vec::Vectorized<math_t>;
+  const Vec kZeroVec(math_t(0));
+  c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
+  c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
+  c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
+
+  for (const auto i : c10::irange(m)) {
+    const T* X_ptr = X + i * kChunkSize * kVecSize;
+    const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
+    static std::array<Vec, kChunkSize> c_vecs = ([]() {
+      std::array<Vec, kChunkSize> result;
+      for (const auto i : c10::irange(kChunkSize)) {
+        result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
+      }
+      return result;
+    })();
+    UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);
+
+    int64_t mask = i + 1;
+    for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
+      AddMomentsVec(
+          m0_stk[j - 1],
+          m1_stk[j - 1],
+          m2_stk[j - 1],
+          m0_stk[j],
+          m1_stk[j],
+          m2_stk[j]);
+      m0_stk[j - 1] = 0;
+      m1_stk[j - 1] = kZeroVec;
+      m2_stk[j - 1] = kZeroVec;
+      mask >>= 1;
+    }
+  }
+  for (const auto i : c10::irange(1, depth)) {
+    AddMomentsVec(
+        m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
+  }
+
+  std::array<math_t, kAccVecSize> m1_arr{};
+  std::array<math_t, kAccVecSize> m2_arr{};
+  m1_stk[0].store(m1_arr.data());
+  m2_stk[0].store(m2_arr.data());
+
+  int64_t m0 = 0;
+  math_t m1 = 0;
+  math_t m2 = 0;
+  for (int64_t i = n * kVecSize; i < N; ++i) {
+    math_t x = static_cast(X[i]);
+    const math_t delta = x - m1;
+    ++m0;
+    m1 += delta / static_cast(m0);
+    m2 += delta * (x - m1);
+  }
+  // For BFloat16, each lane in m1_arr/m2_arr holds 2*n accumulated results.
+  int64_t m0_add = n * kVecSize / kAccVecSize;
+  for (const auto i : c10::irange(kAccVecSize)) {
+    AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
+  }
+
+  return std::make_pair(m1, m2 / static_cast(N - ddof));
+}
+
+template 
+std::pair, opmath_t> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
+  using Vec = vec::Vectorized;
+  constexpr int64_t kVecSize = Vec::size();
+  const int64_t n = N / kVecSize;
+  const int64_t m = divup(n, kChunkSize);
+  const int64_t depth = utils::CeilLog2(m);
+  if (depth <= 4) {
+    return RowwiseMomentsImpl(X, N, ddof);
+  } else if (depth <= 8) {
+    return RowwiseMomentsImpl(X, N, ddof);
+  } else if (depth <= 16) {
+    return RowwiseMomentsImpl(X, N, ddof);
+  } else if (depth <= 32) {
+    return RowwiseMomentsImpl(X, N, ddof);
+  } else {
+    return RowwiseMomentsImpl(X, N, ddof);
+  }
+}
+
+} // namespace CPU_CAPABILITY
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/utils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..62bb4f20891116ac2e4aeff5e7698013f123a84f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/utils.h
@@ -0,0 +1,198 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#ifdef USE_FBGEMM
+#include 
+#endif
+
+namespace at {
+namespace native {
+
+template 
+inline void _store(T* dst, at::vec::Vectorized src) {
+  src.store(dst);
+}
+
+inline void _store(at::BFloat16* dst, at::vec::Vectorized src) {
+  auto res = at::vec::convert_float_bfloat16(src, src);
+  res.store(dst, at::vec::Vectorized::size());
+}
+
+inline void _store(at::Half* dst, at::vec::Vectorized src) {
+  auto res = at::vec::convert_float_half(src, src);
+  res.store(dst, at::vec::Vectorized::size());
+}
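+
+// Note: the BFloat16/Half overloads above narrow a float vector and store only
+// the first Vectorized<float>::size() lanes, matching the width of `src`.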
+
+inline namespace CPU_CAPABILITY {
+
+template 
+inline T data_index_init(T offset) {
+  return offset;
+}
+
+template 
+inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
+  offset = data_index_init(offset, std::forward(args)...);
+  x = offset % X;
+  return offset / X;
+}
+
+inline bool data_index_step() {
+  return true;
+}
+
+template 
+inline bool data_index_step(T& x, const T& X, Args&&... args) {
+  if (data_index_step(std::forward(args)...)) {
+    x = ((x + 1) == X) ? 0 : (x + 1);
+    return x == 0;
+  }
+  return false;
+}
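+
+// Example usage (illustrative sketch; the dimension names are hypothetical):
+// iterate a flattened [N, C, H, W] range handed out by at::parallel_for:
+//
+//   int64_t n = 0, c = 0, h = 0, w = 0;
+//   data_index_init(begin, n, N, c, C, h, H, w, W);
+//   for (const auto i : c10::irange(begin, end)) {
+//     // ... process element (n, c, h, w) at linear index i ...
+//     data_index_step(n, N, c, C, h, H, w, W);
+//   }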
+
+// Helper struct for bfloat16 vectorization.
+// Useful when you need float as the intermediate dtype or the accumulation dtype.
+using namespace vec;
+struct Vec2 {
+  Vectorized val0, val1;
+  Vec2(Vectorized v0, Vectorized v1) : val0(v0), val1(v1) {}
+  Vec2(float v) : val0(v), val1(v) {}
+  static Vec2 loadu(const BFloat16* ptr) {
+    auto [v0, v1] = convert_bfloat16_float(Vectorized::loadu(ptr));
+    return {v0, v1};
+  }
+  static Vec2 loadu(const float* ptr) {
+    return {Vectorized::loadu(ptr), Vectorized::loadu(ptr + Vectorized::size())};
+  }
+  void store(BFloat16* ptr) const {
+    Vectorized val = convert_float_bfloat16(val0, val1);
+    val.store(ptr);
+  }
+  void store(float* ptr) const {
+    val0.store(ptr);
+    val1.store(ptr + Vectorized::size());
+  }
+};
+inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
+inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
+inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
+inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
+inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
+inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
+
+template  struct VectorizedType { using type = Vectorized; };
+template <> struct VectorizedType { using type = Vec2; };
+template  using VecType = typename VectorizedType::type;
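+
+// VecType<T> resolves to Vec2 for BFloat16 and to Vectorized<T> otherwise, so a
+// kernel can accumulate in float regardless of the storage dtype.
+// Illustrative sketch (the `in`/`out` pointers are hypothetical):
+//
+//   VecType<scalar_t> acc(0.0f);
+//   acc = acc + VecType<scalar_t>::loadu(in);
+//   acc.store(out);  // for bf16, loads widen to float and stores narrow back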
+
+// Helpers for loading mixed-data-type parameters as a pair of float vectors.
+inline std::tuple, Vectorized> load2f(const BFloat16* ptr) {
+  return convert_bfloat16_float(Vectorized::loadu(ptr));
+}
+
+inline std::tuple, Vectorized> load2f(const Half* ptr) {
+  return convert_half_float(Vectorized::loadu(ptr));
+}
+
+inline std::tuple, Vectorized> load2f(const float* ptr) {
+  using Vec = Vectorized;
+  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
+}
+
+inline std::tuple, Vectorized> load2f(const BFloat16* ptr, int64_t count) {
+  return convert_bfloat16_float(Vectorized::loadu(ptr, count));
+}
+
+inline std::tuple, Vectorized> load2f(const Half* ptr, int64_t count) {
+  return convert_half_float(Vectorized::loadu(ptr, count));
+}
+
+inline std::tuple, Vectorized> load2f(const float* ptr, int64_t count) {
+  using Vec = Vectorized;
+  if (count > Vec::size()) {
+    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
+  } else {
+    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
+  }
+}
+
+} // namespace
+
+namespace utils {
+
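+// CeilLog2(x) returns ceil(log2(x)) for x > 2 and 1 otherwise; for example,
+// CeilLog2(5) == 3 (2^3 = 8 >= 5) and CeilLog2(8) == 3 (exact power of two).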
+template 
+T CeilLog2(const T& x) {
+  if (x <= 2) {
+    return 1;
+  }
+  // The last set bit is floor(log2(x)), and floor + 1 is ceil,
+  // except when x is an exact power of 2, so subtract 1 first.
+  return static_cast(llvm::findLastSet(static_cast(x) - 1)) + 1;
+}
+
+// matrix transpose:
+//   src has shape of M by N, with leading dimension of ld_src
+//   dst has shape of N by M, with leading dimension of ld_dst
+template 
+inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  for (int64_t j = 0; j < N; j++) {
+    for (int64_t i = 0; i < M; i++) {
+      dst[j * ld_dst + i] = src[i * ld_src + j];
+    }
+  }
+}
+
+#ifdef USE_FBGEMM
+template <>
+inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
+  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
+  fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst);
+}
+#endif
+
+template 
+inline void parallel_sparse_csr(
+    const TensorAccessor& crow_acc,
+    const int64_t M,
+    const int64_t nnz,
+    const F& f) {
+  TORCH_CHECK(crow_acc.size(0) == M + 1);
+
+  // Parallelizing directly over `M` may lead to load imbalance;
+  // statically determine the thread partition here to even out
+  // the payload per thread.
+  int num_threads = at::get_num_threads();
+  std::vector thread_splits(num_threads + 1, M);
+
+  int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));
+
+  thread_splits[0] = 0;
+  int64_t sum = 0;
+  int64_t t = 1;
+  for (const auto m : c10::irange(M)) {
+    int64_t row_start = crow_acc[m];
+    int64_t row_end = crow_acc[m + 1];
+    sum += row_end - row_start;
+    if (sum > t * thread_averge_payload) {
+      thread_splits[t] = m;
+      t++;
+    }
+  }
+  // Restore the last index explicitly; rounding error when calculating
+  // `thread_averge_payload` may leave it unset.
+  thread_splits[num_threads] = M;
+
+  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
+    int tid = at::get_thread_num();
+    int64_t begin = thread_splits[tid];
+    int64_t end = thread_splits[tid + 1];
+    f(begin, end);
+  });
+}
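+
+// Illustrative usage sketch (the accessor/lambda names are hypothetical): each
+// thread receives a contiguous row range whose total nnz is roughly balanced:
+//
+//   parallel_sparse_csr(crow_acc, M, nnz, [&](int64_t row_begin, int64_t row_end) {
+//     for (const auto m : c10::irange(row_begin, row_end)) {
+//       for (int64_t e = crow_acc[m]; e < crow_acc[m + 1]; ++e) {
+//         // ... process the e-th non-zero of row m ...
+//       }
+//     }
+//   });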
+
+} // namespace utils
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h
new file mode 100644
index 0000000000000000000000000000000000000000..d978e89b1e562d294a49e890474ff9ccceb1cece
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cpu/zmath.h
@@ -0,0 +1,250 @@
+#pragma once
+
+// Complex number math operations that act as no-ops for other dtypes.
+#include 
+#include 
+#include
+
+namespace at { namespace native {
+inline namespace CPU_CAPABILITY {
+
+template 
+inline VALUE_TYPE zabs (SCALAR_TYPE z) {
+  return z;
+}
+
+template<>
+inline c10::complex zabs > (c10::complex z) {
+  return c10::complex(std::abs(z));
+}
+
+template<>
+inline float zabs , float> (c10::complex z) {
+  return std::abs(z);
+}
+
+template<>
+inline c10::complex zabs > (c10::complex z) {
+  return c10::complex(std::abs(z));
+}
+
+template<>
+inline double zabs , double> (c10::complex z) {
+  return std::abs(z);
+}
+
+// This overload corresponds to non-complex dtypes.
+// The function is consistent with its NumPy equivalent
+// for non-complex dtypes where `pi` is returned for
+// negative real numbers and `0` is returned for 0 or positive
+// real numbers.
+// Note: `nan` is propagated.
+template 
+inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
+  if (at::_isnan(z)) {
+    return z;
+  }
+  return z < 0 ? c10::pi : 0;
+}
+
+template<>
+inline c10::complex angle_impl > (c10::complex z) {
+  return c10::complex(std::arg(z), 0.0);
+}
+
+template<>
+inline float angle_impl , float> (c10::complex z) {
+  return std::arg(z);
+}
+
+template<>
+inline c10::complex angle_impl > (c10::complex z) {
+  return c10::complex(std::arg(z), 0.0);
+}
+
+template<>
+inline double angle_impl , double> (c10::complex z) {
+  return std::arg(z);
+}
+
+template 
+constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
+  return z; //No-Op
+}
+
+template<>
+constexpr c10::complex real_impl > (c10::complex z) {
+  return c10::complex(z.real(), 0.0);
+}
+
+template<>
+constexpr float real_impl , float> (c10::complex z) {
+  return z.real();
+}
+
+template<>
+constexpr c10::complex real_impl > (c10::complex z) {
+  return c10::complex(z.real(), 0.0);
+}
+
+template<>
+constexpr double real_impl , double> (c10::complex z) {
+  return z.real();
+}
+
+template 
+constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
+  return 0;
+}
+
+template<>
+constexpr c10::complex imag_impl > (c10::complex z) {
+  return c10::complex(z.imag(), 0.0);
+}
+
+template<>
+constexpr float imag_impl , float> (c10::complex z) {
+  return z.imag();
+}
+
+template<>
+constexpr c10::complex imag_impl > (c10::complex z) {
+  return c10::complex(z.imag(), 0.0);
+}
+
+template<>
+constexpr double imag_impl , double> (c10::complex z) {
+  return z.imag();
+}
+
+template 
+inline TYPE conj_impl (TYPE z) {
+  return z; //No-Op
+}
+
+template<>
+inline c10::complex conj_impl > (c10::complex z) {
+  return c10::complex{z.real(), -z.imag()};
+}
+
+template<>
+inline c10::complex conj_impl > (c10::complex z) {
+  return c10::complex(z.real(), -z.imag());
+}
+
+template<>
+inline c10::complex conj_impl > (c10::complex z) {
+  return c10::complex(z.real(), -z.imag());
+}
+
+template 
+inline TYPE ceil_impl (TYPE z) {
+  return std::ceil(z);
+}
+
+template <>
+inline c10::complex ceil_impl (c10::complex z) {
+  return c10::complex(std::ceil(z.real()), std::ceil(z.imag()));
+}
+
+template <>
+inline c10::complex ceil_impl (c10::complex z) {
+  return c10::complex(std::ceil(z.real()), std::ceil(z.imag()));
+}
+
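+// sgn(z) = z / |z| for non-zero complex z, and 0 when z == 0.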
+template
+inline c10::complex sgn_impl (c10::complex z) {
+  if (z == c10::complex(0, 0)) {
+    return c10::complex(0, 0);
+  } else {
+    return z / zabs(z);
+  }
+}
+
+template 
+inline TYPE floor_impl (TYPE z) {
+  return std::floor(z);
+}
+
+template <>
+inline c10::complex floor_impl (c10::complex z) {
+  return c10::complex(std::floor(z.real()), std::floor(z.imag()));
+}
+
+template <>
+inline c10::complex floor_impl (c10::complex z) {
+  return c10::complex(std::floor(z.real()), std::floor(z.imag()));
+}
+
+template 
+inline TYPE round_impl (TYPE z) {
+  return std::nearbyint(z);
+}
+
+template <>
+inline c10::complex round_impl (c10::complex z) {
+  return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag()));
+}
+
+template <>
+inline c10::complex round_impl (c10::complex z) {
+  return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag()));
+}
+
+template 
+inline TYPE trunc_impl (TYPE z) {
+  return std::trunc(z);
+}
+
+template <>
+inline c10::complex trunc_impl (c10::complex z) {
+  return c10::complex(std::trunc(z.real()), std::trunc(z.imag()));
+}
+
+template <>
+inline c10::complex trunc_impl (c10::complex z) {
+  return c10::complex(std::trunc(z.real()), std::trunc(z.imag()));
+}
+
+template ::value, int> = 0>
+inline TYPE max_impl (TYPE a, TYPE b) {
+  if (_isnan(a) || _isnan(b)) {
+    return std::numeric_limits::quiet_NaN();
+  } else {
+    return std::max(a, b);
+  }
+}
+
+template ::value, int> = 0>
+inline TYPE max_impl (TYPE a, TYPE b) {
+  if (_isnan(a)) {
+    return a;
+  } else if (_isnan(b)) {
+    return b;
+  } else {
+    return std::abs(a) > std::abs(b) ? a : b;
+  }
+}
+
+template ::value, int> = 0>
+inline TYPE min_impl (TYPE a, TYPE b) {
+  if (_isnan(a) || _isnan(b)) {
+    return std::numeric_limits::quiet_NaN();
+  } else {
+    return std::min(a, b);
+  }
+}
+
+template ::value, int> = 0>
+inline TYPE min_impl (TYPE a, TYPE b) {
+  if (_isnan(a)) {
+    return a;
+  } else if (_isnan(b)) {
+    return b;
+  } else {
+    return std::abs(a) < std::abs(b) ? a : b;
+  }
+}
+
+} // end namespace
+}} // end at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf13717c177eb4ca9973ff043248c1eafff87174
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Activation.h
@@ -0,0 +1,20 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+struct TensorIteratorBase;
+class TensorBase;
+}
+
+namespace at { namespace native {
+
+void launch_glu_backward_kernel(const TensorIteratorBase& iter,
+                                int64_t gI_stride, int64_t I_stride);
+
+void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);
+
+void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h
new file mode 100644
index 0000000000000000000000000000000000000000..1bfa8060f2e345945751db4330318ea43878e487
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h
@@ -0,0 +1,48 @@
+// DON'T include this except from Binary*.cu files. It should not leak into
+// headers.
+#pragma once
+#define TORCH_ASSERT_NO_OPERATORS
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+namespace native {
+namespace binary_internal {
+
+template 
+struct DivFunctor {
+  __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
+    return a / b;
+  }
+};
+
+template 
+struct MulFunctor {
+  __device__ T operator()(T a, T b) const {
+    return a * b;
+  }
+};
+
+// Workaround for the error: '*' in boolean context, suggest '&&' instead
+// [-Werror=int-in-bool-context]
+template <>
+struct MulFunctor {
+  __device__ bool operator()(bool a, bool b) const {
+    return a && b;
+  }
+};
+void div_true_kernel_cuda(TensorIteratorBase& iter);
+void div_trunc_kernel_cuda(TensorIteratorBase& iter);
+} // namespace binary_internal
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1c0f669f69913d521d3e12317fb152e25d6e48a6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh
@@ -0,0 +1,296 @@
+#pragma once
+#include 
+
+// Jiterator functions are guarded behind this macro
+#if AT_USE_JITERATOR()
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+template 
+constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence seq) {
+    constexpr auto size = seq.size();
+    (void)t; // suppress unused-parameter warning when the tuple is empty
+    return std::array{static_cast(&std::get(t))...};
+}
+
+// Helper function to convert a tuple to std::array
+// for passing the arguments to a CUDA kernel.
+// NOTE: we capture the tuple by reference, so the pointers in the
+// returned array are only valid while the tuple is alive.
+template 
+constexpr auto tuple_to_array(std::tuple& extra_args) {
+    constexpr auto tuple_size = sizeof...(Args);
+    return tuple_to_array_helper(extra_args, std::make_index_sequence{});
+}
+
+struct JittedVecKernelCache {
+  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
+  at::cuda::jit::NvrtcFunction vec1;
+  at::cuda::jit::NvrtcFunction vec2;
+  at::cuda::jit::NvrtcFunction vec4;
+};
+
+struct JittedKernelVariantCache {
+  JittedVecKernelCache vec;
+  at::cuda::jit::NvrtcFunction noncontiguous;
+  at::cuda::jit::NvrtcFunction dynamic_contiguous;
+  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
+};
+
+inline c10::SmallBuffer pack_kernel_args(
+    std::initializer_list args,
+    c10::ArrayRef extra_args) {
+  c10::SmallBuffer ret(args.size() + extra_args.size());
+  std::copy(args.begin(), args.end(), ret.data());
+  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
+  return ret;
+}
+
+template
+void launch_jitted_unrolled_kernel(
+    std::mutex &jiterator_mutex,
+    at::cuda::jit::NvrtcFunction &fn_cache,
+    const at::cuda::jit::KernelDescriptor &desc,
+    int64_t N,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s,
+    bool contiguous,
+    at::cuda::jit::BinaryFuncVariant scalar_pos,
+    void* scalar_val,
+    c10::ArrayRef extra_args) {
+
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max());
+  // Casting the result to int is always safe: the intermediate is int64 and won't overflow.
+  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
+
+  if (!fn_cache.function) {
+    const std::lock_guard lock{jiterator_mutex};
+    if (!fn_cache.function) {
+      constexpr bool dynamic_casting = !std::is_same() ||
+                                       !std::is_same();
+      auto code = at::cuda::jit::generate_code(
+          desc, contiguous, dynamic_casting, scalar_pos);
+      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
+    }
+  }
+
+  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
+  at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
+  {num_threads(), 1u, 1u});
+}
+
+template
+void launch_jitted_vectorized_kernel(
+    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
+    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
+    at::cuda::jit::BinaryFuncVariant scalar_pos,
+    void *scalar_val, c10::ArrayRef extra_args) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max());
+  // N is still int64_t for the computation, but it's always safe to cast result to int
+  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
+  const int vec_size = at::cuda::jit::can_vectorize_up_to(
+      desc, c10::ArrayRef(data.data, data.size()));
+
+  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
+  //   fn_ptr is set to the appropriate function based on the vec size and GPU used
+  at::cuda::jit::NvrtcFunction* fn_ptr;
+  if (vec_size == 4) {
+    fn_ptr = &fn_cache.vec4;
+  } else if (vec_size == 2) {
+    fn_ptr = &fn_cache.vec2;
+  } else if (vec_size == 1) {
+    fn_ptr = &fn_cache.vec1;
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
+  }
+
+  bool vectorized = vec_size > 1;
+
+  if (!fn_ptr->function) {
+    const std::lock_guard lock{jiterator_mutex};
+    if (!fn_ptr->function) { // cache miss!
+
+      // Generates program
+      auto code = at::cuda::jit::generate_code(
+          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
+          scalar_pos, vectorized, vec_size);
+      std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
+
+      // Acquires the program
+      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
+    }
+  }
+
+  if (vectorized) {
+    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
+    at::cuda::jit::launch_jitted_pwise_function(
+        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
+  } else {
+// NVCC complains about unused variables l and s.
+// It should be a false positive in most cases, so we suppress the warnings.
+#pragma nv_diagnostic push
+#pragma nv_diag_suppress 177
+    auto ic = TrivialOffsetCalculator();
+    auto oc = TrivialOffsetCalculator<1>();
+    auto l = memory::LoadWithoutCast();
+    auto s = memory::StoreWithoutCast();
+
+    auto args = pack_kernel_args(
+        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
+    at::cuda::jit::launch_jitted_pwise_function(
+        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
+#pragma nv_diagnostic pop
+  }
+}
+
+template 
+void jitted_gpu_kernel_generic(
+    std::mutex &jiterator_mutex,
+    JittedKernelVariantCache &cache,
+    const at::cuda::jit::KernelDescriptor &desc,
+    at::cuda::jit::BinaryFuncVariant scalar_pos,
+    c10::ArrayRef extra_args,
+    TensorIteratorBase& iter,
+    const bool dynamic_casting,
+    void *scalar_val) {
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+
+  constexpr int ntensors = arity + 1;
+  at::detail::Array data;
+  for (auto i : c10::irange(ntensors)) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+  bool contiguous = iter.is_contiguous();
+
+  // Decides which of 4 kernel types to launch
+  // Variations are:
+  //   - Case 1: no dynamic casting and contiguous
+  //   - Case 2: no dynamic casting and noncontiguous
+  //   - Case 3: dynamic casting and contiguous
+  //   - Case 4: dynamic casting and noncontiguous
+  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
+
+  if (!dynamic_casting) {
+    if (contiguous) {
+      // Case 1: no dynamic casting and contiguous
+      launch_jitted_vectorized_kernel(
+          jiterator_mutex, cache.vec, desc,
+          numel, data, scalar_pos, scalar_val, extra_args);
+      return;
+    }
+
+    // Case 2: no dynamic casting and noncontiguous
+    auto input_offset_calculator = make_input_offset_calculator(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    auto loader = memory::LoadWithoutCast();
+    auto storer = memory::StoreWithoutCast();
+    launch_jitted_unrolled_kernel(
+        jiterator_mutex, cache.noncontiguous, desc, numel, data,
+        input_offset_calculator, output_offset_calculator, loader,
+        storer, contiguous, scalar_pos, scalar_val, extra_args);
+    return;
+  }
+
+  // Cases 3 and 4 are handled below
+  // Both require construction of a storer (this asserts 1 output) and one or more loaders
+
+  // Creates store cast to output (the zeroth tensor in TensorIterator)
+  auto storer = memory::StoreWithCast<1>(iter);
+
+  // Creates load casts from inputs (note offset indexing into the iterator's 1...n tensors)
+  auto loader = memory::LoadWithCast(iter);
+
+  if (contiguous) {
+    // Case 3: dynamic casting and contiguous
+    auto input_offset_calculator = TrivialOffsetCalculator();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    launch_jitted_unrolled_kernel(
+        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
+        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
+    return;
+  }
+
+  // Case 4: dynamic casting and noncontiguous
+  auto input_offset_calculator = make_input_offset_calculator(iter);
+  auto output_offset_calculator = make_output_offset_calculator(iter);
+  launch_jitted_unrolled_kernel(
+      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
+      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
+}
+
+// NOTE: static to reduce chances of name collision.
+template <
+    char const* name,
+    typename result_type,
+    typename f_inputs_type,
+    int arity,
+    at::cuda::jit::BinaryFuncVariant scalar_pos =
+        at::cuda::jit::BinaryFuncVariant::NoScalar,
+    typename... ExtraArgs>
+static void jitted_gpu_kernel_impl(
+    TensorIteratorBase& iter,
+    const std::string &f,
+    const bool dynamic_casting,
+    at::opmath_type scalar_val,
+    std::tuple extra_args) {
+
+  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
+  //   the same compute capability
+  static std::mutex jiterator_mutex;
+  static std::vector device_caches(c10::cuda::device_count());
+
+  constexpr int nInputs = arity;
+  constexpr int nOutputs = 1;  // TODO: Support more than 1 output
+  static const auto desc = at::cuda::jit::make_kernel_descriptor<
+    result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
+
+  auto &cache = device_caches[iter.device().index()];
+  auto extra_args_array = tuple_to_array(extra_args);
+  return jitted_gpu_kernel_generic(
+      jiterator_mutex,
+      cache,
+      desc,
+      scalar_pos,
+      extra_args_array,
+      iter,
+      dynamic_casting,
+      &scalar_val
+    );
+}
+
+}}  // at::native
+
+#endif // AT_USE_JITERATOR()
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..ed31a997c18fb82d033e5810b9c657d5b125831e
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh
@@ -0,0 +1,348 @@
+#pragma once
+
+// This file provides two functions to help write GPU elementwise kernels:
+//
+//   gpu_kernel(TensorIterator iter, <lambda>)
+//   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
+//
+// The gpu_kernel_with_scalars generates specializations that support a
+// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
+// is lifted to a kernel parameter instead of copying to device memory.
+// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
+// which is the default for TensorIterator::binary_op. Otherwise, all inputs
+// and the output must be on the GPU.
+//
+// For example, to write a reciprocal kernel for GPU float Tensors:
+//
+//   gpu_kernel(iter, []GPU_LAMBDA(float a) {
+//    return 1.0f / a;
+//   });
+//
+// To write a multiplication kernel for GPU float Tensors where one argument
+// may be a CPU scalar:
+//
+//   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
+//     return a * b;
+//   });
+//
+// See BinaryOpsKernel.cu for the complete implementation
+//
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifdef __NVCC__
+#define ASSERT_HOST_DEVICE_LAMBDA(type)                       \
+  static_assert(                                              \
+      __nv_is_extended_host_device_lambda_closure_type(type), \
+      #type " must be a __host__ __device__ lambda")
+#else
+#define ASSERT_HOST_DEVICE_LAMBDA(type)
+#endif
+
+namespace at {
+namespace native {
+
+template 
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  using traits = function_traits;
+  int remaining = N - block_work_size() * blockIdx.x;
+
+  if (remaining < block_work_size()) { // if this block handles the remainder,
+                                       // just do a naive unrolled loop
+    auto input_calc = TrivialOffsetCalculator();
+    auto output_calc = TrivialOffsetCalculator<1>();
+    auto loader = memory::LoadWithoutCast();
+    auto storer = memory::StoreWithoutCast();
+    auto policy = memory::policies::unroll<
+        array_t,
+        decltype(input_calc),
+        decltype(output_calc),
+        memory::LoadWithoutCast,
+        memory::StoreWithoutCast>(
+        data, remaining, input_calc, output_calc, loader, storer);
+    elementwise_kernel_helper(f, policy);
+  } else { // if this block has a full `block_work_size` data to handle, use
+           // vectorized memory access
+    elementwise_kernel_helper(
+        f, memory::policies::vectorized(data));
+  }
+}
+
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void unrolled_elementwise_kernel(
+    int N,
+    func_t f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  int remaining = N - block_work_size() * blockIdx.x;
+  auto policy = memory::policies::
+      unroll(
+          data, remaining, ic, oc, l, s);
+  elementwise_kernel_helper(f, policy);
+}
+
+// This function assumes a trivial 1-d layout and no dynamic casting.
+template 
+static inline void launch_vectorized_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max());
+  using traits = function_traits;
+  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int vec_size = memory::can_vectorize_up_to(data);
+
+  switch (vec_size) {
+    case 4:
+      vectorized_elementwise_kernel<4, func_t, array_t>
+          <<>>(N, f, data);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 2:
+      vectorized_elementwise_kernel<2, func_t, array_t>
+          <<>>(N, f, data);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 1: {
+      auto input_calc = TrivialOffsetCalculator();
+      auto output_calc = TrivialOffsetCalculator<1>();
+      auto loader = memory::LoadWithoutCast();
+      auto storer = memory::StoreWithoutCast();
+      unrolled_elementwise_kernel
+          <<>>(
+              N, f, data, input_calc, output_calc, loader, storer);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    }
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
+  }
+}
+
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t>
+static inline void launch_unrolled_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max());
+  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_elementwise_kernel
+      <<>>(N, f, data, ic, oc, l, s);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
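+// Legacy (non-vectorized) elementwise kernel: each block covers nt * vt
+// elements; thread `tid` handles indices nv * blockIdx.x + tid, advancing by nt
+// on each of its vt iterations.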
+template 
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel(int N, func_t f) {
+  int tid = threadIdx.x;
+  int nv = nt * vt;
+  int idx = nv * blockIdx.x + tid;
+#pragma unroll
+  for (int i = 0; i < vt; i++) {
+    if (idx < N) {
+      f(idx);
+      idx += nt;
+    }
+  }
+}
+
+template 
+static void launch_legacy_kernel(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  elementwise_kernel<<>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template 
+C10_HOST_DEVICE typename traits::result_type invoke_impl(
+    const func_t& f,
+    char* const C10_RESTRICT data[],
+    const index_t strides[],
+    int i,
+    std::index_sequence) {
+  (void)strides;
+  (void)i;
+  return f(c10::load::type>(
+      data[INDEX] + i * strides[INDEX])...);
+}
+
+template <
+    typename func_t,
+    typename index_t,
+    typename traits = function_traits>
+C10_HOST_DEVICE typename traits::result_type invoke(
+    const func_t& f,
+    char* const C10_RESTRICT data[],
+    const index_t strides[],
+    int i) {
+  using Indices = std::make_index_sequence;
+  return invoke_impl(f, data, strides, i, Indices{});
+}
+
+template 
+C10_HOST_DEVICE typename traits::result_type invoke_impl(
+    const func_t& f,
+    char* const C10_RESTRICT data[],
+    const index_t strides[],
+    const ScalarType dtypes[],
+    int i,
+    std::index_sequence) {
+  (void)strides;
+  (void)i;
+  return f(c10::fetch_and_cast::type>(
+      dtypes[I], data[I] + i * strides[I])...);
+}
+
+template <
+    typename func_t,
+    typename index_t,
+    typename traits = function_traits>
+C10_HOST_DEVICE typename traits::result_type invoke(
+    const func_t& f,
+    char* const C10_RESTRICT data[],
+    const index_t strides[],
+    const ScalarType dtypes[],
+    int i) {
+  using Indices = std::make_index_sequence;
+  return invoke_impl(f, data, strides, dtypes, i, Indices{});
+}
+
+template 
+void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter));
+
+  at::detail::Array data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  bool contiguous = iter.is_contiguous();
+
+  if (contiguous) {
+    return launch_vectorized_kernel(numel, f, data);
+  }
+  auto offset_calc = ::make_offset_calculator(iter);
+  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
+  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
+    auto offsets = offset_calc.get(idx);
+    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+    *out = invoke(f, &data.data[1], &offsets.data[1], 1);
+  });
+}
+
+template 
+void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
+  if (!needs_dynamic_casting::check(iter)) {
+    return gpu_kernel_impl_nocast(iter, f);
+  }
+  using traits = function_traits;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+
+  at::detail::Array data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  bool contiguous = iter.is_contiguous();
+
+  if (contiguous) {
+#ifdef USE_ROCM
+    at::detail::Array dtypes;
+    auto inner_strides = iter.get_inner_strides();
+    at::detail::Array strides;
+    for (int i = 0; i < ntensors; i++) {
+      dtypes[i] = iter.dtype(i);
+      strides[i] = inner_strides[i];
+    }
+    launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      void* out = data[0] + strides[0] * idx;
+      arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+      c10::cast_and_store(dtypes[0], out, result);
+    });
+#else
+    auto loader = memory::LoadWithCast(iter);
+    auto storer = memory::StoreWithCast<1>(iter);
+    auto input_offset_calculator = TrivialOffsetCalculator();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    launch_unrolled_kernel(
+        numel,
+        f,
+        data,
+        input_offset_calculator,
+        output_offset_calculator,
+        loader,
+        storer);
+#endif
+  } else {
+    at::detail::Array dtypes;
+    for (int i = 0; i < ntensors; i++) {
+      dtypes[i] = iter.dtype(i);
+    }
+    auto offset_calc = ::make_offset_calculator(iter);
+    launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
+      auto offsets = offset_calc.get(idx);
+      void* out = data[0] + offsets[0];
+      arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+      c10::cast_and_store(dtypes[0], out, result);
+    });
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h
new file mode 100644
index 0000000000000000000000000000000000000000..f0dc24872e6157de677146db592fe0fed86d51b9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include 
+#include 
+
+namespace at { namespace native {
+
+struct TupleInfoCPU {
+  template 
+  using tuple = thrust::tuple;
+
+  template 
+  static constexpr auto tie(Types&... args) noexcept {
+    return thrust::tie(args...);
+  }
+};
+
+template 
+using CompositeRandomAccessorCPU =
+  CompositeRandomAccessor;
+
+template 
+void swap(
+  references_holder rh1,
+  references_holder rh2
+) {
+  return thrust::swap(rh1.data(), rh2.data());
+}
+
+template 
+auto get(references_holder rh) -> decltype(thrust::get(rh.data())) {
+  return thrust::get(rh.data());
+}
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h
new file mode 100644
index 0000000000000000000000000000000000000000..a9e23ad7fe8d56f7aa18833c371fd3969304e6ed
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Copy.h
@@ -0,0 +1,10 @@
+#pragma once
+
+namespace at {
+struct TensorIteratorBase;
+
+namespace native {
+
+void direct_copy_kernel_cuda(TensorIteratorBase &iter);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h
new file mode 100644
index 0000000000000000000000000000000000000000..116ed029e9e32e7ba27b4c9b5a013cd794c3362d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTPlanCache.h
@@ -0,0 +1,494 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at { namespace native { namespace detail {
+
+// Enum representing the FFT type
+enum class CuFFTTransformType : int8_t {
+  C2C,  // Complex-to-complex
+  R2C,  // Real-to-complex
+  C2R,  // Complex-to-real
+};
+
+// This struct is used to let us easily compute hashes of the
+// parameters.
+// It will be the **key** to the plan cache.
+struct CuFFTParams
+{
+  int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3
+  // These include additional batch dimension as well.
+  int64_t sizes_[max_rank + 1];
+  int64_t input_strides_[max_rank + 1];
+  int64_t output_strides_[max_rank + 1];
+  CuFFTTransformType fft_type_;
+  ScalarType value_type_;
+
+  CuFFTParams() = default;
+
+  CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides,
+      IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) {
+    // Padding bits must be zeroed for hashing
+    memset(this, 0, sizeof(*this));
+    signal_ndim_ = signal_sizes.size() - 1;
+    fft_type_ = fft_type;
+    value_type_ = value_type;
+
+    TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size());
+    TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size());
+    TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank);
+
+    std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_);
+    std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_);
+    std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_);
+  }
+};
+
+static_assert(std::is_trivial<CuFFTParams>::value, "");
+
+// Returns true if the transform type has complex input
+inline bool cufft_complex_input(CuFFTTransformType type) {
+  switch (type) {
+    case CuFFTTransformType::C2C:
+    case CuFFTTransformType::C2R:
+      return true;
+
+    case CuFFTTransformType::R2C:
+      return false;
+  }
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// Returns true if the transform type has complex output
+inline bool cufft_complex_output(CuFFTTransformType type) {
+  switch (type) {
+    case CuFFTTransformType::C2C:
+    case CuFFTTransformType::R2C:
+      return true;
+
+    case CuFFTTransformType::C2R:
+      return false;
+  }
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+// Create transform type enum from bools representing if input and output are complex
+inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) {
+  if (complex_input && complex_output) {
+    return CuFFTTransformType::C2C;
+  } else if (complex_input && !complex_output) {
+    return CuFFTTransformType::C2R;
+  } else if (!complex_input && complex_output) {
+    return CuFFTTransformType::R2C;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported");
+}
+
+
+class CuFFTHandle {
+  ::cufftHandle handle_;
+public:
+
+  CuFFTHandle() {
+    CUFFT_CHECK(cufftCreate(&handle_));
+  }
+
+  ::cufftHandle & get() { return handle_; }
+  const ::cufftHandle & get() const { return handle_; }
+
+  ~CuFFTHandle() {
+// Not using fftDestroy() for rocFFT to work around double freeing of handles
+#if !defined(USE_ROCM)
+    cufftDestroy(handle_);
+#endif
+  }
+};
+
+__forceinline__
+static bool is_pow_of_two(int64_t x) {
+  return (x & (x - 1)) == 0;
+}
+
+using cufft_size_type = long long int;
+
+using CuFFTDimVector = c10::SmallVector;
+
+// Struct representing a tensor in CuFFT's data layout for planning transforms
+// See NOTE [ cuFFT Embedded Strides ].
+struct CuFFTDataLayout {
+  CuFFTDimVector embed;
+  cufft_size_type stride, dist;
+  bool must_clone, simple;
+};
+
+// Returns a cufft embedding for a contiguous signal of the given size.
+// e.g. if the input is cloned, this will be the resulting data layout
+// See NOTE [ cuFFT Embedded Strides ].
+inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) {
+  CuFFTDataLayout layout;
+  layout.simple = true;
+  layout.must_clone = false;
+  layout.embed.assign(sizes.cbegin() + 1, sizes.cend());
+  if (onesided) {
+    layout.embed.back() = sizes.back() / 2 + 1;
+  }
+  layout.stride = 1;
+  layout.dist = 1;
+  for (const auto& len : layout.embed) {
+    layout.dist *= len;
+  }
+  return layout;
+}
+
+// Convert strides to a CuFFT embedded representation.
+// If strides cannot be embedded, returns a simple layout and sets must_clone flag
+// See NOTE [ cuFFT Embedded Strides ].
+inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) {
+  const auto signal_ndim = strides.size() - 1;
+  CuFFTDataLayout layout;
+  auto last_stride = strides[signal_ndim];
+  layout.must_clone = (last_stride <= 0);
+
+  const auto last_dim_size = onesided ?
+      sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
+  const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;
+
+  // Zero strides are not allowed, even if the batch size is one.
+  // If that happens, just set a dummy value.
+  if (sizes[0] == 1) {
+    layout.dist = signal_numel;
+  } else if (strides[0] == 0) {
+    layout.must_clone = true;
+  } else {
+    layout.dist = strides[0];
+  }
+
+  // Calculate the embedding shape, or set must_clone if the strides cannot be embedded
+  layout.embed.resize(signal_ndim);
+  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {
+    auto stride = strides[i];
+    if (sizes[i] == 1) {
+      layout.embed[i] = 1;
+    } else if (stride > 0 && stride % last_stride == 0) {
+      layout.embed[i] = stride / last_stride;
+      last_stride = stride;
+    } else {
+      layout.must_clone = true;
+    }
+  }
+
+  if (layout.must_clone) {
+    // If the input needs to be cloned, assume it will be contiguous
+    layout = cufft_simple_embed(sizes, onesided);
+    layout.must_clone = true;
+  } else {
+    layout.embed[0] = sizes[1];
+    layout.stride = strides[signal_ndim];
+    // Determine if layout represents a simple embedding (contiguous data)
+    layout.simple = [&] {
+      for (const auto i : c10::irange(1, signal_ndim - 1)) {
+        if (layout.embed[i] != sizes[i + 1]) {
+          return false;
+        }
+      }
+
+      return (layout.stride == 1 && layout.dist == signal_numel &&
+          layout.embed.back() == last_dim_size);
+    }();
+  }
+  return layout;
+}
+
+// This class contains all the information needed to execute a cuFFT plan:
+//   1. the plan
+//   2. whether to clone input before executing the plan
+//   3. the workspace size needed
+//
+// This class will be the **value** in the plan cache.
+// It **owns** the raw plan via a unique_ptr.
+class CuFFTConfig {
+public:
+
+  // Only move semantics is enough for this class. Although we already use
+  // unique_ptr for the plan, we still remove the copy constructor and assignment
+  // op so we don't accidentally copy and take a perf hit.
+  CuFFTConfig(const CuFFTConfig&) = delete;
+  CuFFTConfig& operator=(CuFFTConfig const&) = delete;
+
+  explicit CuFFTConfig(const CuFFTParams& params):
+      CuFFTConfig(
+          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
+          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
+          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
+          params.fft_type_,
+          params.value_type_) {}
+
+  // For complex types, strides are in units of 2 * element_size(dtype);
+  // sizes are for the full signal, including batch size, and are always two-sided.
+  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
+      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
+        fft_type_(fft_type), value_type_(dtype) {
+
+    // signal sizes (excluding batch dim)
+    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());
+
+    // input batch size
+    const int64_t batch = sizes[0];
+    const int64_t signal_ndim = sizes.size() - 1;
+
+    // Since cuFFT has limited non-unit stride support and various constraints, we
+    // use a flag throughout this function to keep track of whether we need to do
+    // `input = input.clone();`
+
+#if defined(USE_ROCM)
+    // clone input to avoid issues with hipfft clobbering the input and failing tests
+    clone_input = true;
+#else
+    clone_input = false;
+#endif
+
+    // For half, non-unit base strides on the real side of real-to-complex and
+    // complex-to-real transforms are not supported. Since our output is always
+    // contiguous, we only need to check the real-to-complex case.
+    if (dtype == ScalarType::Half) {
+      // cuFFT on half requires compute capability of at least SM_53
+      auto dev_prop = at::cuda::getCurrentDeviceProperties();
+      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
+               "cuFFT doesn't support signals of half type with compute "
+               "capability less than SM_53, but the device containing input half "
+               "tensor only has SM_", dev_prop->major, dev_prop->minor);
+      for (const auto i : c10::irange(signal_ndim)) {
+        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
+            "cuFFT only supports dimensions whose sizes are powers of two when"
+            " computing in half precision, but got a signal size of",
+            sizes.slice(1));
+      }
+      clone_input |= in_strides.back() != 1;
+    }
+
+    CuFFTDataLayout in_layout;
+    if (clone_input) {
+      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
+    } else {
+      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
+    }
+    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
+    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
+    clone_input |= in_layout.must_clone;
+
+    // Check if we can take advantage of simple data layout.
+    //
+    // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
+
+    const bool simple_layout = in_layout.simple && out_layout.simple;
+    cudaDataType itype, otype, exec_type;
+    const auto complex_input = cufft_complex_input(fft_type);
+    const auto complex_output = cufft_complex_output(fft_type);
+    if (dtype == ScalarType::Float) {
+      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
+      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
+      exec_type = CUDA_C_32F;
+    } else if (dtype == ScalarType::Double) {
+      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
+      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
+      exec_type = CUDA_C_64F;
+    } else if (dtype == ScalarType::Half) {
+      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
+      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
+      exec_type = CUDA_C_16F;
+    } else {
+      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
+    }
+
+    // disable auto allocation of workspace to use THC allocator
+    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));
+
+    size_t ws_size_t;
+
+    // make plan
+    if (simple_layout) {
+      // For unit-stride layouts, we tell cuFFT by setting inembed == onembed == NULL.
+      // In that case, cuFFT ignores istride, ostride, idist, and odist
+      // by assuming istride = ostride = 1.
+      //
+      // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
+      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
+        /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
+        /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
+        batch, &ws_size_t, exec_type));
+    } else {
+      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
+            in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
+            out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
+            batch, &ws_size_t, exec_type));
+    }
+    ws_size = static_cast(ws_size_t);
+  }
+
+  const cufftHandle &plan() const { return plan_ptr.get(); }
+
+  CuFFTTransformType transform_type() const { return fft_type_; }
+  ScalarType data_type() const { return value_type_; }
+  bool should_clone_input() const { return clone_input; }
+  int64_t workspace_size() const { return ws_size; }
+
+private:
+  CuFFTHandle plan_ptr;
+  bool clone_input;
+  int64_t ws_size;
+  CuFFTTransformType fft_type_;
+  ScalarType value_type_;
+};
+
+#if defined(USE_ROCM)
+  // Note that the max plan number for CUDA version < 10 has to be 1023
+  // due to a bug that fails on the 1024th plan
+  constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;
+  constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
+#else
+  constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits::max();
+  // The default max cache size chosen for CUDA version > 10 is arbitrary.
+  // This number puts a limit on how big of a plan cache we maintain by
+  // default. Users can always configure it via cufft_set_plan_cache_max_size.
+  constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
+#endif
+static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits::max(),
+              "CUFFT_MAX_PLAN_NUM not in size_t range");
+static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
+              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
+
+// This cache assumes that the mapping from key to value never changes.
+// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
+// value returned from lookup().
+// The contract of using this cache is that lookup() should only be
+// used when the max_size is positive.
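+//
+// Illustrative usage sketch (assuming the caller holds a cache instance):
+//
+//   std::lock_guard<std::mutex> guard(plan_cache.mutex);
+//   const CuFFTConfig& config = plan_cache.lookup(params);
+//   // keep holding the lock for as long as `config` is used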
+class CuFFTParamsLRUCache {
+public:
+  using kv_t = typename std::pair;
+  using map_t = typename std::unordered_map,
+                                            typename std::list::iterator,
+                                            ParamsHash,
+                                            ParamsEqual>;
+  using map_kkv_iter_t = typename map_t::iterator;
+
+
+  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}
+
+  CuFFTParamsLRUCache(int64_t max_size) {
+    _set_max_size(max_size);
+  }
+
+  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
+    _usage_list(std::move(other._usage_list)),
+    _cache_map(std::move(other._cache_map)),
+    _max_size(other._max_size) {}
+
+  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
+    _usage_list = std::move(other._usage_list);
+    _cache_map = std::move(other._cache_map);
+    _max_size = other._max_size;
+    return *this;
+  }
+
+  // If key is in this cache, return the cached config. Otherwise, emplace the
+  // config in this cache and return it.
+  // Return const reference because CuFFTConfig shouldn't be tampered with once
+  // created.
+  const CuFFTConfig &lookup(CuFFTParams params) {
+    AT_ASSERT(_max_size > 0);
+
+    map_kkv_iter_t map_it = _cache_map.find(params);
+    // Hit, put to list front
+    if (map_it != _cache_map.end()) {
+      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
+      return map_it->second->second;
+    }
+
+    // Miss
+    // remove if needed
+    if (_usage_list.size() >= _max_size) {
+      auto last = _usage_list.end();
+      last--;
+      _cache_map.erase(last->first);
+      _usage_list.pop_back();
+    }
+
+    // construct new plan at list front, then insert into _cache_map
+    _usage_list.emplace_front(std::piecewise_construct,
+                       std::forward_as_tuple(params),
+                       std::forward_as_tuple(params));
+    auto kv_it = _usage_list.begin();
+    _cache_map.emplace(std::piecewise_construct,
+                std::forward_as_tuple(kv_it->first),
+                std::forward_as_tuple(kv_it));
+    return kv_it->second;
+  }
+
+  void clear() {
+    _cache_map.clear();
+    _usage_list.clear();
+  }
+
+  void resize(int64_t new_size) {
+    _set_max_size(new_size);
+    auto cur_size = _usage_list.size();
+    if (cur_size > _max_size) {
+      auto delete_it = _usage_list.end();
+      for (size_t i = 0; i < cur_size - _max_size; i++) {
+        delete_it--;
+        _cache_map.erase(delete_it->first);
+      }
+      _usage_list.erase(delete_it, _usage_list.end());
+    }
+  }
+
+  size_t size() const { return _cache_map.size(); }
+
+  size_t max_size() const noexcept { return _max_size; }
+
+  std::mutex mutex;
+
+private:
+  // Only sets size and does value check. Does not resize the data structures.
+  void _set_max_size(int64_t new_size) {
+    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
+    // CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
+    // first.
+    TORCH_CHECK(new_size >= 0,
+             "cuFFT plan cache size must be non-negative, but got ", new_size);
+    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
+             "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
+    _max_size = static_cast(new_size);
+  }
+
+  std::list _usage_list;
+  map_t _cache_map;
+  size_t _max_size;
+};
+
+// Since ATen is separated into CPU build and CUDA build, we need a way to call
+// these functions only when CUDA is loaded. We use CUDA hooks for this purpose
+// (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual
+// native function counterparts (at native/SpectralOps.cpp), i.e.,
+// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
+// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
+int64_t cufft_get_plan_cache_max_size_impl(DeviceIndex device_index);
+void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_size);
+int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
+void cufft_clear_plan_cache_impl(DeviceIndex device_index);
+
+}}} // namespace at::native::detail
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..da5f79d8249621cde70647293e6d841eec23610f
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h
@@ -0,0 +1,73 @@
+#pragma once
+
+#include <ATen/Config.h>
+
+#include <string>
+#include <stdexcept>
+#include <sstream>
+#include <cufft.h>
+#include <cufftXt.h>
+
+namespace at { namespace native {
+
+// This means that max dim is 3 + 2 = 5 with batch dimension and possible
+// complex dimension
+constexpr int max_rank = 3;
+
+static inline std::string _cudaGetErrorEnum(cufftResult error)
+{
+  switch (error)
+  {
+    case CUFFT_SUCCESS:
+      return "CUFFT_SUCCESS";
+    case CUFFT_INVALID_PLAN:
+      return "CUFFT_INVALID_PLAN";
+    case CUFFT_ALLOC_FAILED:
+      return "CUFFT_ALLOC_FAILED";
+    case CUFFT_INVALID_TYPE:
+      return "CUFFT_INVALID_TYPE";
+    case CUFFT_INVALID_VALUE:
+      return "CUFFT_INVALID_VALUE";
+    case CUFFT_INTERNAL_ERROR:
+      return "CUFFT_INTERNAL_ERROR";
+    case CUFFT_EXEC_FAILED:
+      return "CUFFT_EXEC_FAILED";
+    case CUFFT_SETUP_FAILED:
+      return "CUFFT_SETUP_FAILED";
+    case CUFFT_INVALID_SIZE:
+      return "CUFFT_INVALID_SIZE";
+    case CUFFT_UNALIGNED_DATA:
+      return "CUFFT_UNALIGNED_DATA";
+    case CUFFT_INCOMPLETE_PARAMETER_LIST:
+      return "CUFFT_INCOMPLETE_PARAMETER_LIST";
+    case CUFFT_INVALID_DEVICE:
+      return "CUFFT_INVALID_DEVICE";
+    case CUFFT_PARSE_ERROR:
+      return "CUFFT_PARSE_ERROR";
+    case CUFFT_NO_WORKSPACE:
+      return "CUFFT_NO_WORKSPACE";
+    case CUFFT_NOT_IMPLEMENTED:
+      return "CUFFT_NOT_IMPLEMENTED";
+#if !defined(USE_ROCM)
+    case CUFFT_LICENSE_ERROR:
+      return "CUFFT_LICENSE_ERROR";
+#endif
+    case CUFFT_NOT_SUPPORTED:
+      return "CUFFT_NOT_SUPPORTED";
+    default:
+      std::ostringstream ss;
+      ss << "unknown error " << error;
+      return ss.str();
+  }
+}
+
+static inline void CUFFT_CHECK(cufftResult error)
+{
+  if (error != CUFFT_SUCCESS) {
+    std::ostringstream ss;
+    ss << "cuFFT error: " << _cudaGetErrorEnum(error);
+    AT_ERROR(ss.str());
+  }
+}
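+
+// Illustrative use only (assumes the standard cuFFT C API): wrapping a plan
+// creation call surfaces the readable error name on failure, e.g.
+//   cufftHandle plan;
+//   CUFFT_CHECK(cufftPlan1d(&plan, 1024, CUFFT_C2C, /*batch=*/1));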
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..b7f10c697d44436a55e88a836896729d6c5cea29
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh
@@ -0,0 +1,25 @@
+#pragma once
+
+namespace at { namespace native {
+#if defined(USE_ROCM)
+// take these out when ROCm implements std:: math functions
+#include <math.h>
+template <typename scalar_t>
+static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
+
+template <>
+__forceinline__ __device__ float device_sqrt(float val) {
+  return ::sqrtf(val);
+}
+
+template <>
+__forceinline__ __device__ double device_sqrt(double val) {
+  return ::sqrt(val);
+}
+#else
+template<typename scalar_t>
+__forceinline__ __device__ double device_sqrt(scalar_t val) {
+  return std::sqrt(val);
+}
+#endif
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h
new file mode 100644
index 0000000000000000000000000000000000000000..3f04779f737ba159c0ec3cbcfbae6874b05452ea
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h
@@ -0,0 +1,672 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+namespace {
+
+// launch bounds used for kernels utilizing TensorIterator
+const uint32_t block_size_bound = 256;
+const uint32_t grid_size_bound = 4;
+// number of randoms given by distributions like curand_uniform4, curand_uniform2_double
+// used in calculating philox offset.
+const uint32_t curand4_engine_calls = 4;
+
+// utility function that calculates proper philox_offset
+// for distributions utilizing TensorIterator. For distributions using
+// TensorIterator, we are using a grid-stride loop with each
+// thread yielding one element per thread. For the edge of the grid-stride
+// loop, if the tensor size is large, the unroll loop will kick in and the float4
+// from curand4 will start getting utilized (for common tensor sizes, we end up
+// using rand.x from each thread). Hence, the philox_offset is
+// (number of elements per thread * number of engine calls), which makes
+// sure that philox offset increment is not less than the number of randoms used
+// in each thread.
+std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
+  const uint64_t numel = static_cast<uint64_t>(total_elements);
+  const uint32_t block_size = block_size_bound;
+  const uint32_t unroll = curand4_engine_calls;
+  dim3 dim_block(block_size);
+  dim3 grid((numel + block_size - 1) / block_size);
+  uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
+  grid.x = std::min(
+      static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
+      grid.x);
+  //number of times random will be generated per thread, to offset philox counter in thc random state
+  uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1)
+                                * curand4_engine_calls;
+  return std::make_tuple(counter_offset, grid, dim_block);
+}
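+
+// Worked example (numbers are illustrative, not measured): for numel = 2^20,
+// block_size = 256 and unroll = 4 on a hypothetical GPU with 54 SMs and 2048
+// threads per SM, blocks_per_sm = 8 and grid.x is capped at 54 * 8 = 432, so
+// counter_offset = ((2^20 - 1) / (256 * 432 * 4) + 1) * 4 = (2 + 1) * 4 = 12.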
+
+// grid stride loop kernel for distributions
+template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
+C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
+__global__ void distribution_elementwise_grid_stride_kernel(int numel,
+                                                            PhiloxCudaState philox_args,
+                                                            const dist_t dist_func,
+                                                            const transform_t transform_func) {
+  auto seeds = at::cuda::philox::unpack(philox_args);
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  curandStatePhilox4_32_10_t state;
+  curand_init(std::get<0>(seeds),
+              idx,
+              std::get<1>(seeds),
+              &state);
+
+  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
+      blockDim.x * gridDim.x * unroll_factor;
+  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
+    auto rand = dist_func(&state);
+    #pragma unroll
+    for (int ii = 0; ii < unroll_factor; ii++) {
+      int li = linear_index + blockDim.x * gridDim.x * ii;
+      if (li < numel) {
+        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
+      }
+    }
+    __syncthreads();
+  }
+}
+
+/**
+ * distribution_nullary_kernel is analogous to gpu_kernel in
+ * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
+ * TensorIterator to launch a kernel. However, the differences are
+ *   - it launches a grid-stride loop based kernel. The kernel is not
+ *     generic like elementwise_kernel in Loops.cuh and is specialized
+ *     for the distribution kernels here.
+ *   - For big size tensors, we can launch multiple kernels recursively
+ *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
+ *     offset calculation is done in this function.
+ *
+ * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
+ * to have grid-stride loop kernel and then use that to launch our distribution
+ * kernels? Note that we need a grid-stride loop kernel because, we found by testing
+ * that it achieves peak effective bandwidth.
+ */
+template<typename scalar_t, typename accscalar_t, int unroll_factor, typename RNG, typename dist_t, typename transform_t>
+void distribution_nullary_kernel(at::TensorIteratorBase& iter,
+                                 RNG gen,
+                                 const dist_t& dist_func,
+                                 const transform_t transform_func) {
+  static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
+  int64_t numel = iter.numel();
+  if (numel == 0) {
+    return;
+  }
+
+  auto execution_policy = calc_execution_policy(numel);
+  auto counter_offset = std::get<0>(execution_policy);
+  auto grid = std::get<1>(execution_policy);
+  auto block = std::get<2>(execution_policy);
+  PhiloxCudaState rng_engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(sub_iter,
+        gen, dist_func, transform_func);
+    }
+    return;
+  }
+
+  char* out_data = (char*)iter.data_ptr(0);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  if (iter.is_trivial_1d()) {
+    auto strides = iter.get_inner_strides();
+    int stride0 = strides[0];
+    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
+      numel,
+      rng_engine_inputs,
+      dist_func,
+      [=]__device__(int idx, accscalar_t rand) {
+        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
+        *out = transform_func(rand);
+      }
+    );
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    auto offset_calc = make_offset_calculator<1>(iter);
+    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
+      numel,
+      rng_engine_inputs,
+      dist_func,
+      [=]__device__(int idx, accscalar_t rand) {
+        auto offsets = offset_calc.get(idx);
+        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
+        *out = transform_func(rand);
+      }
+    );
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
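+
+// Call sketch (illustrative only; the template argument order assumed here is
+// <scalar_t, accscalar_t, unroll_factor>, matching the declaration above): a
+// plain uniform fill of a float iterator could look like
+//   distribution_nullary_kernel<float, float, curand4_engine_calls>(
+//       iter, gen,
+//       [] __device__ (curandStatePhilox4_32_10_t* s) { return curand_uniform4(s); },
+//       [] __device__ (float rand) { return rand; });
+// The concrete kernels further below follow exactly this pattern.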
+
+// Binary kernel
+template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
+__global__ void distribution_binary_elementwise_kernel(
+    int numel,
+    func_t f,
+    PhiloxCudaState philox_args,
+    typename function_traits<func_t>::result_type *output_data,
+    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
+    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
+    inp_offset_calc_t inp_calc,
+    out_offset_calc_t out_calc) {
+  auto seeds = at::cuda::philox::unpack(philox_args);
+
+  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
+  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
+
+  input_t_1 inputs_1[thread_work_size()];
+  input_t_2 inputs_2[thread_work_size()];
+
+  int base_index = block_work_size() * blockIdx.x;
+  int remaining = std::min(numel - base_index, block_work_size());
+
+  curandStatePhilox4_32_10_t state;
+  curand_init(std::get<0>(seeds),
+              blockIdx.x * blockDim.x + threadIdx.x,
+              std::get<1>(seeds),
+              &state);
+
+  // load data into registers
+  int thread_idx = threadIdx.x;
+  #pragma unroll
+  for (int i = 0; i < thread_work_size(); i++) {
+    if (thread_idx >= remaining) {
+      break;
+    }
+    int input_idx = thread_idx + base_index;
+    auto offsets = inp_calc.get(input_idx);
+    inputs_1[i] = input_data_1[offsets[0]];
+    inputs_2[i] = input_data_2[offsets[1]];
+
+    thread_idx += num_threads();
+  }
+
+  // compute and store
+  thread_idx = threadIdx.x;
+  #pragma unroll
+  for (int i = 0; i < thread_work_size(); i++) {
+    if (thread_idx >= remaining) {
+      break;
+    }
+    int input_idx = thread_idx + base_index;
+    auto offsets = out_calc.get(input_idx);
+    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
+    thread_idx += num_threads();
+  }
+}
+
+template <typename func_t>
+void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
+  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
+  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
+  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
+  using output_t = typename function_traits<func_t>::result_type;
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      distribution_binary_kernel(sub_iter, philox_args, f);
+    }
+    return;
+  }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());
+
+  int64_t numel = iter.numel();
+  if (numel == 0) {
+    return;
+  }
+
+  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
+  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
+  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));
+
+  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (iter.is_contiguous()) {
+    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
+        numel, f, philox_args, output_data, input_data_1, input_data_2,
+        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
+        numel, f, philox_args, output_data, input_data_1, input_data_2,
+        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+} // namespace
+}} // namespace at::native
+
+
+namespace at {
+namespace native {
+namespace templates {
+namespace cuda {
+
+// ==================================================== Random ========================================================
+
+template
+void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
+  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
+    if ((
+      std::is_same<scalar_t, int64_t>::value ||
+      std::is_same<scalar_t, double>::value ||
+      std::is_same<scalar_t, float>::value ||
+      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
+    {
+      // define lambda to mod with range and add base
+      auto random_func = [range, base] __device__ (uint64_t rand) {
+        return transformation::uniform_int_from_to(rand, range, base);
+      };
+      distribution_nullary_kernel(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
+          ulonglong2 ret;
+          uint4 rand_val = curand4(state);
+          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
+          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
+          return ret;
+        },
+        random_func);
+    } else {
+      auto random_func = [range, base] __device__ (uint32_t rand) {
+        return transformation::uniform_int_from_to(rand, range, base);
+      };
+      distribution_nullary_kernel(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) {
+          return curand4(state);
+        },
+        random_func);
+    }
+   }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+}
+
+// This is the special kernel to handle single specific case:
+// from(inclusive) = std::numeric_limits::lowest()
+// to(exclusive) = None (= std::numeric_limits::max() + 1)
+template
+void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
+    if (std::is_same<scalar_t, int64_t>::value ||
+        std::is_same<scalar_t, double>::value ||
+        std::is_same<scalar_t, float>::value ||
+        std::is_same<scalar_t, at::BFloat16>::value) {
+      auto random_func = [] __device__ (uint64_t rand) {
+        return transformation::uniform_int_full_range(rand);
+      };
+      distribution_nullary_kernel(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
+          ulonglong2 ret;
+          uint4 rand_val = curand4(state);
+          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
+          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
+          return ret;
+        },
+        random_func);
+    } else {
+      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
+    }
+  });
+}
+
+template
+struct RandomFromToKernel {
+  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional gen) {
+    random_from_to_kernel(iter, range, base, check_generator(gen));
+  }
+  void operator()(TensorIteratorBase& iter, c10::optional gen) {
+    random_full_64_bits_range_kernel(iter, check_generator(gen));
+  }
+};
+
+template
+void random_kernel(TensorIteratorBase& iter, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
+    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
+      auto random_func = [] __device__ (uint64_t rand) {
+        return transformation::uniform_int(rand);
+      };
+      distribution_nullary_kernel(iter, gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
+          ulonglong2 ret;
+          uint4 rand_val = curand4(state);
+          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
+          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
+          return ret;
+        },
+        random_func);
+    } else {
+      auto random_func = [] __device__ (uint32_t rand) {
+        return transformation::uniform_int(rand);
+      };
+      distribution_nullary_kernel(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) {
+          return curand4(state);
+        },
+        random_func);
+    }
+  });
+}
+
+template
+struct RandomKernel {
+  void operator()(TensorIteratorBase& iter, RNG gen) {
+    random_kernel(iter, gen);
+  }
+};
+
+// ====================================================================================================================
+
+template
+void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
+  if (std::is_same<scalar_t, double>::value) {
+    distribution_nullary_kernel(iter,
+      gen,
+      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+      transform);
+  } else {
+    distribution_nullary_kernel(iter,
+      gen,
+      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+      transform);
+  }
+}
+
+template
+void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
+  if (std::is_same<scalar_t, double>::value) {
+    distribution_nullary_kernel(iter,
+      gen,
+      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
+      transform);
+  } else {
+    distribution_nullary_kernel(iter,
+      gen,
+      [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
+      transform);
+  }
+}
+
+// ==================================================== Normal ========================================================
+
+template
+void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
+  auto iter = TensorIterator::borrowing_nullary_op(self);
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto mean = static_cast<accscalar_t>(mean_);
+    auto std = static_cast<accscalar_t>(std_);
+    // define lambda to multiply std and add mean
+    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
+      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
+    };
+    normal_and_transform(iter, gen, normal_func);
+   });
+}
+
+template
+struct NormalKernel {
+  void operator()(const TensorBase &self, double mean, double std, c10::optional gen) {
+    normal_kernel(self, mean, std, check_generator(gen));
+  }
+};
+
+// ==================================================== Uniform ========================================================
+
+template
+void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
+    auto from = static_cast<scalar_t>(from_);
+    auto to = static_cast<scalar_t>(to_);
+    using opmath_t = at::opmath_type<scalar_t>;
+    auto range = static_cast<opmath_t>(to-from);
+    // define lambda to reverse bounds, multiply 'range' and add 'from_'
+    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
+      // Compute output value before reversing the bounds
+      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
+      auto value = static_cast<scalar_t>(rand * range + from);
+      // reverse the bounds of curand4 from (0, 1] to [0, 1)
+      // Note that this method is from legacy THCTensorRandom and is likely to give
+      // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
+      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
+      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
+      auto reverse_bound_value = value == to ? from : value;
+      return reverse_bound_value;
+    };
+    uniform_and_transform(iter, gen, uniform_func);
+   });
+}
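+
+// Worked example of the bound reversal above (illustrative): curand_uniform4
+// draws from (0, 1], so with from = 0.0 and to = 1.0 a draw of exactly 1.0
+// maps to value == to, which is then returned as `from`; every other draw is
+// returned unchanged, giving outputs in the half-open interval [from, to).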
+
+template
+struct UniformKernel {
+  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional gen) {
+    uniform_kernel(iter, from, to, check_generator(gen));
+  }
+};
+
+// ================================================== LogNormal =======================================================
+
+template
+void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto mean = static_cast<accscalar_t>(mean_);
+    auto std = static_cast<accscalar_t>(std_);
+    // define lambda for log_normal transformation
+    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
+      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
+    };
+    normal_and_transform(iter, gen, log_normal_func);
+   });
+}
+
+template
+struct LogNormalKernel {
+  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional gen) {
+    log_normal_kernel(iter, mean, std, check_generator(gen));
+  }
+};
+
+// =================================================== Geometric ======================================================
+
+template
+void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
+    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
+    // define lambda for geometric transformation
+    auto geometric_func = [p] __device__ (accscalar_t rand) {
+      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
+    };
+    uniform_and_transform(iter, gen, geometric_func);
+  });
+}
+
+template
+struct GeometricKernel {
+  void operator()(TensorIteratorBase& iter, double p, c10::optional gen) {
+    geometric_kernel(iter, p, check_generator(gen));
+  }
+};
+
+// ================================================== Exponential =====================================================
+
+template
+void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
+  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto lambda = static_cast<accscalar_t>(lambda_);
+    // define lambda for exponential transformation
+    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
+      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
+    };
+    uniform_and_transform(iter, gen, exponential_func);
+   });
+}
+
+template
+struct ExponentialKernel {
+  void operator()(TensorIteratorBase& iter, double lambda, c10::optional gen) {
+    exponential_kernel(iter, lambda, check_generator(gen));
+  }
+};
+
+// ==================================================== Cauchy ========================================================
+
+template
+void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto median = static_cast<accscalar_t>(median_);
+    auto sigma = static_cast<accscalar_t>(sigma_);
+    // define lambda for cauchy transformation
+    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
+      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
+    };
+    uniform_and_transform(iter, gen, cauchy_func);
+   });
+}
+
+template
+struct CauchyKernel {
+  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional gen) {
+    cauchy_kernel(iter, median, sigma, check_generator(gen));
+  }
+};
+
+// ==================================================== Bernoulli =====================================================
+
+template<typename scalar_t, typename prob_t>
+void bernoulli_tensor_cuda_kernel(
+    const TensorBase &ret, const at::TensorBase &p,
+    PhiloxCudaState philox_args) {
+  auto functor = [philox_args] __device__(
+          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
+          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
+        auto seeds = at::cuda::philox::unpack(philox_args);
+        curandStatePhilox4_32_10_t state;
+        curand_init(std::get<0>(seeds),
+                    blockIdx.x * blockDim.x + threadIdx.x,
+                    std::get<1>(seeds),
+                    &state);
+
+        // See Note [Register spilling in curand call for CUDA < 10]
+        float4 rand = curand_uniform4(&state);
+        switch (n) {
+          case 4: {
+            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
+            v4 = static_cast<scalar_t>(rand.w <= p4);
+            // fallthrough
+          }
+          case 3: {
+            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
+            v3 = static_cast<scalar_t>(rand.z <= p3);
+            // fallthrough
+          }
+          case 2: {
+            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
+            v2 = static_cast<scalar_t>(rand.y <= p2);
+            // fallthrough
+          }
+          case 1: {
+            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
+            v1 = static_cast<scalar_t>(rand.x <= p1);
+          }
+        }
+      };
+  // The template argument `4` below indicates that we want to operate on four
+  // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
+  at::cuda::CUDA_tensor_apply2(ret, p, functor);
+}
+
+template
+void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
+  PhiloxCudaState rng_engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_cuda_state(10);
+  }
+  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
+  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
+  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
+  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
+  auto p = expand_inplace(self, p_cuda);
+  AT_DISPATCH_ALL_TYPES_AND3(
+    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
+      if (std::is_same<scalar_t, double>::value) {
+        return bernoulli_tensor_cuda_kernel<double>(self, *p, rng_engine_inputs);
+      } else {
+        return bernoulli_tensor_cuda_kernel<float>(self, *p, rng_engine_inputs);
+      }
+   });
+}
+
+template
+void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND3(
+    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
+      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
+      // define lambda for bernoulli transformation
+      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
+        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
+      };
+      uniform_and_transform(iter, gen, bernoulli_func);
+   });
+}
+
+template
+struct BernoulliKernel {
+  void operator()(TensorIteratorBase& iter, double p, c10::optional gen) {
+    bernoulli_kernel(iter, p, check_generator(gen));
+  }
+  void operator()(const TensorBase &self, const TensorBase &p_, c10::optional gen) {
+    bernoulli_kernel(self, p_, check_generator(gen));
+  }
+};
+
+}}}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h
new file mode 100644
index 0000000000000000000000000000000000000000..053eff0c7d7a5a84db1601bf17fd19dc2cc35382
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Distributions.h
@@ -0,0 +1,25 @@
+#pragma once
+
+namespace at {
+struct CUDAGeneratorImpl;
+struct TensorIteratorBase;
+class TensorBase;
+
+namespace native {
+
+void launch_poisson_cuda_kernel(
+    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);
+
+void launch_gamma_kernel(
+    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);
+
+void launch_binomial_cuda_kernel(
+    TensorIteratorBase &iter, CUDAGeneratorImpl *gen);
+
+void launch_dirichlet_kernel(TensorIteratorBase &iter);
+
+void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);
+
+void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..ceed1577ab60b1b84a8522498e2ca438d7fb3ef4
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh
@@ -0,0 +1,22 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+Tensor embedding_backward_cuda_kernel(
+    const Tensor &grad,
+    const Tensor &orig_indices,
+    const Tensor &sorted_indices,
+    const Tensor &count,
+    int64_t num_weights,
+    int padding_idx = -1,
+    bool mode_mean = false,
+    const Tensor &offset2bag = Tensor(),
+    const Tensor &bag_size = Tensor(),
+    const Tensor &per_sample_weights = Tensor());
+
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..df0b078ba24a1bbb14f998cd569403244d4a18f2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh
@@ -0,0 +1,681 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+namespace {
+
+// TODO(crcrpar): Handle version bump in codegen.
+// rel:
+// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
+inline void increment_version(TensorList tensors) {
+  for (const auto& t : tensors) {
+    t.unsafeGetTensorImpl()->bump_version();
+  }
+}
+
+// Initializes args and checks if all args are aligned
+template 
+__device__ bool init_args(
+    T** args,
+    TensorListMetadata& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  bool all_aligned = true;
+  for (int i = 0; i < depth; i++) {
+    args[i] = (T*)tl.addresses[i][tensor_loc];
+    args[i] += chunk_idx * chunk_size;
+
+    if (!is_aligned(args[i])) {
+      all_aligned = false;
+    }
+  }
+  return all_aligned;
+}
+
+// Initializes args and checks if all args are aligned
+template 
+__device__ bool init_args(
+    T** args,
+    TensorListScalarListMetadata& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  bool all_aligned = true;
+  for (int i = 0; i < depth; i++) {
+    args[i] = (T*)tl.addresses[i][tensor_loc];
+    args[i] += chunk_idx * chunk_size;
+
+    if (!is_aligned(args[i])) {
+      all_aligned = false;
+    }
+  }
+  return all_aligned;
+}
+
+template 
+__device__ bool init_args(
+    T** args,
+    FusedOptimizerTensorListMetadata& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  bool all_aligned = true;
+  for (int i = 0; i < depth; i++) {
+    args[i] = (T*)tl.addresses[i][tensor_loc];
+    args[i] += chunk_idx * chunk_size;
+
+    if (!is_aligned(args[i])) {
+      all_aligned = false;
+    }
+  }
+  return all_aligned;
+}
+
+template 
+__device__ void load_args(
+    T r_args[][kILP],
+    T** args,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const auto i = i_start + threadIdx.x + ii * blockDim.x;
+    for (int r_index = 0; r_index < depth; r_index++) {
+      r_args[r_index][ii] = 0;
+      if (i < n && i < chunk_size) {
+        r_args[r_index][ii] = args[r_index][i];
+      }
+    }
+  }
+}
+
+template 
+__device__ void store_args(
+    T* dst,
+    T* src,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    if (i < n && i < chunk_size)
+      dst[i] = src[ii];
+  }
+}
+
+template 
+__device__ __forceinline__ void binary_op_scalar(
+    T r_args[][kILP],
+    T** args,
+    opmath_t scalar,
+    const int64_t n,
+    const int64_t chunk_size,
+    const bool all_aligned,
+    Op op) {
+  // to make things simple, we put aligned case in a different code path
+  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+    for (int64_t i_start = threadIdx.x;
+         i_start * kILP < n && i_start * kILP < chunk_size;
+         i_start += blockDim.x) {
+      // load
+      load_store(r_args[0], args[0], 0, i_start);
+#pragma unroll
+      for (int ii = 0; ii < kILP; ii++) {
+        r_args[0][ii] = static_cast(
+            op(static_cast(r_args[0][ii]),
+               static_cast(scalar)));
+      }
+      // store
+      load_store(args[res_arg_index], r_args[0], i_start, 0);
+    }
+  } else {
+    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+         i_start += blockDim.x * kILP) {
+      // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
+      // has depth 1
+      load_args<1>(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+      for (int ii = 0; ii < kILP; ii++) {
+        r_args[0][ii] = static_cast(
+            op(static_cast(r_args[0][ii]),
+               static_cast(scalar)));
+      }
+      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+    }
+  }
+}
+
+template 
+__device__ __forceinline__ void pointwise_op_scalar(
+    T r_args[][kILP],
+    T** args,
+    opmath_t scalar,
+    const int64_t n,
+    const int64_t chunk_size,
+    const bool all_aligned,
+    Op op) {
+  // to make things simple, we put aligned case in a different code path
+  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+    for (int64_t i_start = threadIdx.x;
+         i_start * kILP < n && i_start * kILP < chunk_size;
+         i_start += blockDim.x) {
+      // load
+      load_store(r_args[0], args[0], 0, i_start);
+      load_store(r_args[1], args[1], 0, i_start);
+      load_store(r_args[2], args[2], 0, i_start);
+#pragma unroll
+      for (int ii = 0; ii < kILP; ii++) {
+        r_args[0][ii] = static_cast(
+            static_cast(r_args[0][ii]) +
+            scalar *
+                op(static_cast(r_args[1][ii]),
+                   static_cast(r_args[2][ii])));
+      }
+      // store
+      load_store(args[res_arg_index], r_args[0], i_start, 0);
+    }
+  } else {
+    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+         i_start += blockDim.x * kILP) {
+      // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
+      // has depth 3
+      load_args<3>(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+      for (int ii = 0; ii < kILP; ii++) {
+        r_args[0][ii] = static_cast(
+            static_cast(r_args[0][ii]) +
+            scalar *
+                op(static_cast(r_args[1][ii]),
+                   static_cast(r_args[2][ii])));
+      }
+      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+    }
+  }
+}
+
+//
+// Binary Functors
+//
+template 
+struct BinaryOpScalarFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op,
+      opmath_t scalar) {
+    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    binary_op_scalar(
+        r_args, args, scalar, n, chunk_size, all_aligned, op);
+  }
+};
+
+template 
+struct BinaryOpScalarListFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListScalarListMetadata& tl,
+      Op op) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    opmath_t scalar = tl.scalar_vals[tensor_loc];
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    binary_op_scalar(
+        r_args, args, scalar, n, chunk_size, all_aligned, op);
+  }
+};
+
+template 
+struct BinaryOpListAlphaFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op,
+      opmath_t alpha) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+        load_store(r_args[1], args[1], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(
+              op(static_cast(r_args[0][ii]),
+                 alpha * static_cast(r_args[1][ii])));
+        }
+        // store
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(
+              op(static_cast(r_args[0][ii]),
+                 alpha * static_cast(r_args[1][ii])));
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+template 
+struct BinaryOpScalarTensorFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op,
+      T* scalar,
+      opmath_t alpha) {
+    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(op(
+              static_cast(r_args[0][ii]),
+              static_cast(alpha) * static_cast(*scalar)));
+        }
+        // store
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        // Regardless if depth is 1 (for inplace) or 2 (for out of place),
+        // r_args has depth 1
+        load_args<1>(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(op(
+              static_cast(r_args[0][ii]),
+              static_cast(alpha) * static_cast(*scalar)));
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+//
+// Unary Functors
+//
+
+template 
+struct ZeroFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata<1>& tl) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const auto all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = 0;
+        }
+        // store
+        load_store(args[0], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = 0;
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+template 
+struct UnaryOpFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              static_cast(op(static_cast(r_args[0][ii])));
+        }
+        // store
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              static_cast(op(static_cast(r_args[0][ii])));
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+//
+// Pointwise Functors
+//
+
+template 
+struct PointwiseOpScalarFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op,
+      opmath_t scalar) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    pointwise_op_scalar(
+        r_args, args, scalar, n, chunk_size, all_aligned, op);
+  }
+};
+
+template 
+struct PointwiseOpScalarListFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListScalarListMetadata& tl,
+      Op op) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    opmath_t scalar = tl.scalar_vals[tensor_loc];
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    pointwise_op_scalar(
+        r_args, args, scalar, n, chunk_size, all_aligned, op);
+  }
+};
+
+template 
+struct PointwiseOpListFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[depth - 1][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+        load_store(r_args[1], args[1], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii])));
+        }
+        // store
+        load_store(args[2], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] = static_cast(
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii])));
+        }
+        store_args(args[2], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+template 
+struct TernaryOpListFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op) {
+    static_assert(depth == 3 || depth == 4, "");
+    static_assert(depth >= r_args_depth, "");
+    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        load_store(r_args[0], args[0], 0, i_start);
+        load_store(r_args[1], args[1], 0, i_start);
+        load_store(r_args[2], args[2], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii]),
+                 static_cast(r_args[2][ii]));
+        }
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii]),
+                 static_cast(r_args[2][ii]));
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+template 
+struct TernaryOpScalarFunctor {
+  using opmath_t = at::opmath_type;
+  template 
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListMetadata& tl,
+      Op op,
+      opmath_t alpha) {
+    static_assert(depth == 2 || depth == 3, "");
+    static_assert(depth >= r_args_depth, "");
+    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+        load_store(r_args[1], args[1], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii]),
+                 alpha);
+        }
+        // store
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast(r_args[0][ii]),
+                 static_cast(r_args[1][ii]),
+                 alpha);
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
+template 
+struct power_functor {
+  C10_DEVICE T operator()(const T& a, const T& b) const {
+    return at::native::pow_(a, b);
+  }
+};
+
+template 
+struct reverse_power_functor {
+  C10_DEVICE T operator()(const T& a, const T& b) const {
+    return at::native::pow_(b, a);
+  }
+};
+
+} // namespace
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..32421ef305a9905cd6d54805429fa58bc78b0825
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh
@@ -0,0 +1,22 @@
+#pragma once
+
+#include 
+
+namespace at::native {
+
+// std:: does not have clamp functors
+template 
+struct minimum {
+  __device__ T operator()(const T& a, const T& b) const {
+    return (_isnan(a) || a < b) ? a : b;
+  }
+};
+
+template 
+struct maximum {
+  __device__ T operator()(const T& a, const T& b) const {
+    return (_isnan(a) || a > b) ? a : b;
+  }
+};
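+
+// Note the NaN handling encoded above (a consequence of the explicit _isnan
+// check, not an extra guarantee): minimum<float>{}(NAN, 1.f) and
+// maximum<float>{}(NAN, 1.f) both return NaN, whereas plain a < b / a > b
+// comparisons would have returned 1.f.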
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d23bc89fa64e55017e69f1352e86f9e36dcc36a5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh
@@ -0,0 +1,321 @@
+#pragma once
+#include 
+#include 
+
+namespace at { namespace native {
+
+using detail::GridSamplerInterpolation;
+using detail::GridSamplerPadding;
+
+// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
+// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
+// if align_corners: -1 and +1 get sent to the centers of the corner pixels
+//     -1 --> 0
+//     +1 --> (size - 1)
+//     scale_factor = (size - 1) / 2
+// if not align_corners: -1 and +1 get sent to the image edges
+//     -1 --> -0.5
+//     +1 --> (size - 1) + 0.5 == size - 0.5
+//     scale_factor = size / 2
+template 
+static __forceinline__ __device__
+scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    return ((coord + 1.f) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    return ((coord + 1.f) * size - 1) / 2;
+  }
+}
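+
+// Worked example (illustrative): with size = 5, align_corners = true maps
+// coord -1 -> 0, 0 -> 2, +1 -> 4 (corner pixel centers), while
+// align_corners = false maps -1 -> -0.5, 0 -> 2, +1 -> 4.5 (image edges).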
+
+// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
+// except that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template 
+static __forceinline__ __device__
+scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
+                                           bool align_corners, scalar_t *grad_in) {
+  if (align_corners) {
+    // unnormalize coord from [-1, 1] to [0, size - 1]
+    *grad_in = static_cast<scalar_t>(size - 1) / 2;
+    return ((coord + 1.f) / 2) * (size - 1);
+  } else {
+    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
+    *grad_in = static_cast<scalar_t>(size) / 2;
+    return ((coord + 1.f) * size - 1) / 2;
+  }
+}
+
+// Clips coordinates to between 0 and clip_limit - 1
+template 
+static __forceinline__ __device__
+scalar_t clip_coordinates(scalar_t in, int clip_limit) {
+  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
+}
+
+// clip_coordinates_set_grad works similarly to clip_coordinates except that
+// it also returns the `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template 
+static __forceinline__ __device__
+scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
+  // Note that it is important for the gradient calculation that borders
+  // are considered out of bounds.
+  if (in <= static_cast<scalar_t>(0)) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  } else {
+    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
+    if (in >= max) {
+      *grad_in = static_cast<scalar_t>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<scalar_t>(1);
+      return in;
+    }
+  }
+}
+
+// Reflects coordinates until they fall between low and high (inclusive).
+// The bounds are passed as twice their value so that half-integer values
+// can be represented as ints.
+template 
+static __forceinline__ __device__
+scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
+  if (twice_low == twice_high) {
+    return static_cast<scalar_t>(0);
+  }
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = ::fabs(in - min);
+  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+  scalar_t extra = ::fmod(in, span);
+  int flips = static_cast<int>(::floor(in / span));
+  if (flips % 2 == 0) {
+    return extra + min;
+  } else {
+    return span - extra + min;
+  }
+}
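+// Worked example (illustrative sketch): for align_corners=true and size = 5 the caller
+// passes twice_low = 0, twice_high = 8, i.e. coordinates are folded into [0, 4]:
+//   in = -1.0 -> 1.0   (flips = 0, only the fabs reflection applies)
+//   in =  5.5 -> 2.5   (flips = 1 is odd, so the remainder folds back from the far edge)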
+
+// reflect_coordinates_set_grad works similarly to reflect_coordinates except
+// that it also returns the `d output / d input` via pointer argument
+// `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
+                                      scalar_t *grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<scalar_t>(0);
+    return static_cast<scalar_t>(0);
+  }
+  int grad_in_mult_;
+  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<scalar_t>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  // `fmod` returns same sign as `in`, which is positive after the `if` above.
+  scalar_t extra = ::fmod(in, span);
+  int flips = static_cast<int>(::floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<scalar_t>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<scalar_t>(-grad_in_mult_);
+    return span - extra + min;
+  }
+}
+
+template<typename scalar_t>
+static __forceinline__ __device__
+scalar_t safe_downgrade_to_int_range(scalar_t x){
+  // -100.0 does not have special meaning. This is just to make sure
+  // it's not within_bounds_2d or within_bounds_3d, and does not cause
+  // undefined behavior. See #35506.
+  if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
+    return static_cast<scalar_t>(-100.0);
+  return x;
+}
+
+template<typename scalar_t>
+static __forceinline__ __device__
+scalar_t compute_coordinates(scalar_t coord, int size,
+                             GridSamplerPadding padding_mode,
+                             bool align_corners) {
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2*(size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2*size - 1);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  }
+
+  coord = safe_downgrade_to_int_range(coord);
+  return coord;
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_compute_source_index(
+    scalar_t coord,
+    int size,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);
+  coord = compute_coordinates(coord, size, padding_mode, align_corners);
+  return coord;
+}
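+// Worked example (illustrative sketch): size = 4, align_corners=false, padding_mode=Border:
+//   coord = 0.5  -> unnormalize -> 2.5 -> already inside [0, 3] -> 2.5
+//   coord = 1.25 -> unnormalize -> 4.0 -> clip_coordinates      -> 3.0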
+
+// grid_sampler_compute_source_index_set_grad works similarly to
+// grid_sampler_compute_source_index except that it also returns the
+// `d output / d input` via pointer argument `grad_in`.
+// This is useful in the backward pass of grid_sampler.
+template <typename scalar_t>
+static __forceinline__ __device__
+scalar_t grid_sampler_compute_source_index_set_grad(
+    scalar_t coord,
+    int size,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    scalar_t *grad_in) {
+  scalar_t grad_clip, grad_refl;
+  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
+    } else {
+      coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_refl * grad_clip;
+  }
+
+  coord = safe_downgrade_to_int_range(coord);
+  return coord;
+}
+
+static __forceinline__ __device__
+bool within_bounds_2d(int h, int w, int H, int W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+static __forceinline__ __device__
+bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
+  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template<typename scalar_t>
+static __forceinline__ __device__
+scalar_t get_value_bounded(
+    scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
+    GridSamplerPadding padding_mode,
+    bool align_corners) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int ix = static_cast<int>(x);
+  int iy = static_cast<int>(y);
+
+  if (within_bounds_2d(iy, ix, H, W)) {
+    return data[iy * sH + ix * sW];
+  }
+  return static_cast<scalar_t>(0);
+}
+
+template<typename scalar_t, typename index_t>
+static __forceinline__ __device__
+void safe_add_2d(scalar_t *data, int h, int w,
+                 int sH, int sW, int H, int W,
+                 scalar_t delta,
+                 const index_t NC_offset,
+                 const index_t memory_span) {
+  if (within_bounds_2d(h, w, H, W)) {
+    fastAtomicAdd(data,
+                  NC_offset + h * sH + w * sW,
+                  memory_span,
+                  delta,
+                  true);
+  }
+}
+
+template<typename scalar_t, typename index_t>
+static __forceinline__ __device__
+void safe_add_3d(scalar_t *data, int d, int h, int w,
+                 int sD, int sH, int sW, int D, int H, int W,
+                 scalar_t delta,
+                 const index_t NC_offset,
+                 const index_t memory_span) {
+  if (within_bounds_3d(d, h, w, D, H, W)) {
+    fastAtomicAdd(data,
+                  NC_offset + d * sD + h * sH + w * sW,
+                  memory_span,
+                  delta,
+                  true);
+  }
+}
+
+template<typename scalar_t, typename index_t>
+static __forceinline__ __device__
+void add_value_bounded(
+    scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
+    scalar_t delta,
+    GridSamplerPadding padding_mode,
+    bool align_corners,
+    const index_t NC_offset,
+    const index_t memory_span) {
+
+  x = compute_coordinates(x, W, padding_mode, align_corners);
+  y = compute_coordinates(y, H, padding_mode, align_corners);
+
+  int ix = static_cast<int>(x);
+  int iy = static_cast<int>(y);
+
+  safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
+}
+
+// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
+template<typename scalar_t>
+static __forceinline__ __device__
+void get_cubic_coefficients_grad(
+    scalar_t coeffs[4],
+    scalar_t t) {
+
+  // Must be the same as forward calculation in
+  // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
+  scalar_t A = -0.75;
+
+  scalar_t x;
+  x = -1 - t;  // 1 < x = |-1 - tx| < 2
+  coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
+  x = -t;     // x = |0 - tx| <= 1
+  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 1 - t;  // x = |1 - tx| <= 1
+  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
+  x = 2 - t;  // 1 < x = |2 - tx| < 2
+  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
+}
+
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h
new file mode 100644
index 0000000000000000000000000000000000000000..507662c13c8af7199e2620fd96f4f5309fa67884
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/GridSampler.h
@@ -0,0 +1,32 @@
+#pragma once
+#include <array>
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+
+void launch_grid_sampler_2d_forward_kernel(
+    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
+    int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_3d_forward_kernel(
+    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
+    int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_2d_backward_kernel(
+    const TensorBase &grad_input, const TensorBase &grad_grid,
+    const TensorBase &grad_output, const TensorBase &input,
+    const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
+    bool align_corners, std::array<bool, 2> output_mask);
+
+void launch_grid_sampler_3d_backward_kernel(
+    const TensorBase &grad_input, const TensorBase &grad_grid,
+    const TensorBase &grad_output, const TensorBase &input,
+    const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
+    bool align_corners, std::array<bool, 2> output_mask);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..de2ab4de3416634213bf8299ab8e06d26ed41e2b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/IndexKernel.h
@@ -0,0 +1,16 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+struct TensorIteratorBase;
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+/// @param maskPrefixSum[in,out]
+void launch_masked_scatter_kernel(
+    const TensorBase &self, const TensorBase &mask,
+    const TensorBase &maskPrefixSum, const TensorBase &source);
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..38e1f571968eb19ebc3b71a595cd09bb85a8202d
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/JitLoops.cuh
@@ -0,0 +1,187 @@
+#pragma once
+
+#include 
+
+#if AT_USE_JITERATOR()
+
+#include 
+
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+
+namespace at {
+namespace native {
+
+/* Note [Jiterator]
+The "jiterator" simply just-in-time compiles the same kernels that
+Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
+build size, and initial CUDA context size.
+
+By default on non-Windows systems, it also caches compiled kernels in ~/.cache/torch/kernels.
+This behavior is controlled with two environment variables:
+  - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
+  - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels
+
+The jiterator currently has some limitations, however. It cannot:
+  - handle math on complex datatypes
+  - handle kernels with scalar parameters
+
+These improvements will likely come soon.
+
+For examples of how to use the jiterator see the i1 and gcd kernel
+implementations, which pass jittable strings implementing their
+operations instead of the typical CUDA functors.
+
+To pass a runtime argument (similar to lambda captures in non-JIT kernels),
+we need to pass additional arguments to `jitted_gpu_kernel` by value.
+Currently only primitive C++ types used for computation are valid.
+The order of these extra arguments should be the same as the order in which
+they appear in the kernel's function signature (see polygamma for an example).
+
+NOTE: One big restriction is that these arguments must come after the
+arguments provided by TensorIterator. E.g., while capturing `n`, where
+`scalar_t x` and `scalar_t y` are provided by TensorIterator,
+* foo(scalar_t x, scalar_t y, int n) works!
+* foo(int n, scalar_t x, scalar_t y) doesn't work
+* foo(scalar_t x, int n, scalar_t y) doesn't work
+
+*/
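+// Illustrative call-site sketch (an assumption modeled on the gcd/i1 kernels, not a
+// verbatim copy of any one of them; `gcd_name` and `gcd_string` are assumed to be the
+// kernel's name literal and its jittable source string, defined elsewhere):
+//
+//   AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
+//     jitted_gpu_kernel</*name=*/gcd_name,
+//                       /*return_dtype=*/scalar_t,
+//                       /*common_dtype=*/scalar_t,
+//                       /*arity=*/2>(iter, gcd_string);
+//   });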
+
+// Entrypoint for jitted GPU kernels.
+// Only handles elementwise unary and binary kernels with a
+//   common dtype and a single output.
+// NOTE: this assumes the op's iterator has a common_dtype.
+// NOTE: We use std::tuple instead of a parameter pack
+//   for `extra_args` due to the following bug in
+//   older versions of clang:
+// https://bugs.llvm.org/show_bug.cgi?id=23029
+template <
+    char const* name,
+    typename return_type,
+    typename f_inputs_type,
+    int arity,
+    typename... Args>
+void jitted_gpu_kernel(
+    TensorIteratorBase& iter,
+    const std::string& f,
+    at::cuda::jit::BinaryFuncVariant scalar_pos =
+        at::cuda::jit::BinaryFuncVariant::NoScalar,
+    at::opmath_type<f_inputs_type> scalar_val = 0,
+    std::tuple<Args...> extra_args = std::make_tuple()) {
+  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
+  //   Maybe it could be refactored?
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(
+      iter.device(arg).is_cuda(),
+      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
+          sub_iter, f, scalar_pos, scalar_val, extra_args);
+    }
+
+    return;
+  }
+
+  // Computes if dynamic casting is needed
+  // Dynamic casting is needed if an input's dtype differs from the common dtype
+  //   or if the result dtype differs from the output's dtype
+  // Note: this is intentionally divergent from calling needs_dynamic_casting,
+  //   which is more general and inspects a lambda to determine if dynamic
+  //   casting is needed.
+  bool needs_dynamic_casting = false;
+
+  // Checks output
+  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
+  const auto dtype0 = iter.dtype(0);
+  if (dtype0 != return_scalar_type) {
+    needs_dynamic_casting = true;
+  }
+
+  // Checks input(s)
+  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
+  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
+    const auto dtypei = iter.dtype(i);
+    if (dtypei != inputs_scalar_type) {
+      needs_dynamic_casting = true;
+      break;
+    }
+  }
+  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
+    // NOTE: With `scalar_pos=NoScalar`,`scalar_val` is not used
+    // for computation in the generated code and hence we pass a dummy
+    // value of `0`.
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::NoScalar>(
+        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
+  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
+        iter,
+        f,
+        needs_dynamic_casting,
+        scalar_val,
+        extra_args);
+
+  } else {
+    jitted_gpu_kernel_impl<
+        /*name*/ name,
+        /*return_type=*/return_type,
+        /*f_inputs_type=*/f_inputs_type,
+        arity,
+        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
+        iter,
+        f,
+        needs_dynamic_casting,
+        scalar_val,
+        extra_args);
+  }
+}
+
+// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
+template <char const *name, typename return_type, typename f_inputs_type>
+void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+  //currently jiterator only handles binary functions where both inputs are of the same type (f_inputs_type)
+  using opmath_t = at::opmath_type<f_inputs_type>;
+  if (iter.is_cpu_scalar(1)) {
+    auto scalar_val = iter.scalar_value<opmath_t>(1);
+    iter.remove_operand(1);
+    // TODO: When all kernels that use gpu_kernel_with_scalars are
+    // ported to structured, this device guard can be deleted.  This
+    // works around incorrect device guard generation for pre-structured
+    // kernels device guards, but structured kernels do it right and
+    // we can assume the device is already set correctly
+    const OptionalDeviceGuard device_guard(iter.device(1));
+    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
+  } else if (iter.is_cpu_scalar(2)) {
+    auto scalar_val = iter.scalar_value<opmath_t>(2);
+    iter.remove_operand(2);
+    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
+  } else {
+    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
+  }
+}
+
+}}  // at::native
+
+#endif // AT_USE_JITERATOR()
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1c12691ac9307243b8b00a7ac30930980e6456e2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh
@@ -0,0 +1,149 @@
+#pragma once
+#include <ATen/cuda/Atomic.cuh>
+
+#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
+#include <cuda_bf16.h>
+#endif
+
+namespace at {
+namespace native {
+
+__device__ __forceinline__ size_t
+idx(const size_t nc,
+    const size_t height,
+    const size_t width,
+    const size_t h,
+    const size_t w) {
+  return (nc * height + h) * width + w;
+}
+
+// for channels-last
+__device__ __forceinline__ size_t
+idx_cl(
+  const size_t n, const size_t h, const size_t w, const size_t c,
+  const size_t height, const size_t width, const size_t channel
+) {
+  return ((n * height + h) * width + w) * channel + c;
+}
+
+// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization
+// that speed up half-precision atomics.  The situation with half
+// precision atomics is that we have a slow __half atomic, and
+// a fast vectored __half2 atomic (this can be worth up to a 6x
+// speedup, see https://github.com/pytorch/pytorch/pull/21879).
+// We can convert a __half atomic into a __half2 atomic by simply
+// pairing the __half with a zero entry on the left/right depending
+// on alignment... but only if this wouldn't cause an out of bounds
+// access!  Thus, you must specify tensor and numel so we can check
+// if you would be out-of-bounds and use a plain __half atomic if
+// you would be.
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+#if (                      \
+    (defined(USE_ROCM)) || \
+    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
+  gpuAtomicAddNoReturn(
+      reinterpret_cast<at::Half*>(tensor) + index,
+      static_cast<at::Half>(value));
+#else
+  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
+  __half* target_addr = reinterpret_cast<__half*>(tensor + index);
+  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
+
+  if (low_byte && index < (numel - 1)) {
+    __half2 value2;
+    value2.x = static_cast<__half>(value);
+    value2.y = __int2half_rz(0);
+    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
+
+  } else if (!low_byte && index > 0) {
+    __half2 value2;
+    value2.x = __int2half_rz(0);
+    value2.y = static_cast<__half>(value);
+    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
+
+  } else {
+    atomicAdd(
+        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
+  }
+#endif
+}
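+// Illustrative sketch of the pairing above (assuming the tensor's base address is
+// __half2-aligned): an even `index` hits an aligned address, so we add {value, 0} at
+// target_addr; an odd `index` adds {0, value} at target_addr - 1. Either way the
+// neighbouring lane is incremented by zero, and the index > 0 / index < numel - 1
+// guards fall back to the plain __half atomicAdd at the tensor boundaries.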
+
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+#if (                      \
+    (defined(USE_ROCM)) || \
+    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
+  gpuAtomicAddNoReturn(
+      reinterpret_cast<at::BFloat16*>(tensor) + index,
+      static_cast<at::BFloat16>(value));
+#else
+  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
+  __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
+  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);
+
+  if (low_byte && index < (numel - 1)) {
+    __nv_bfloat162 value2;
+    value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
+    value2.y = __int2bfloat16_rz(0);
+    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
+
+  } else if (!low_byte && index > 0) {
+    __nv_bfloat162 value2;
+    value2.x = __int2bfloat16_rz(0);
+    value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
+    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
+
+  } else {
+    atomicAdd(
+        reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
+  }
+#endif
+}
+
+
+template <
+    typename scalar_t,
+    typename index_t,
+    typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
+        nullptr>
+__device__ __forceinline__ void fastSpecializedAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value) {
+  gpuAtomicAddNoReturn(tensor + index, value);
+}
+
+template <class scalar_t, class index_t>
+__device__ __forceinline__ void fastAtomicAdd(
+    scalar_t* tensor,
+    index_t index,
+    const index_t numel,
+    scalar_t value,
+    bool fast_atomics) {
+  if (fast_atomics) {
+    fastSpecializedAtomicAdd(tensor, index, numel, value);
+  } else {
+    gpuAtomicAddNoReturn(tensor + index, value);
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..746fa2c34ecf7232fead66384142f9454efb67b7
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h
@@ -0,0 +1,18 @@
+#pragma once
+#include <algorithm>
+
+namespace at {
+namespace native {
+
+// returns 2**floor(log2(n))
+static int lastPow2(unsigned int n) {
+  n |= (n >> 1);
+  n |= (n >> 2);
+  n |= (n >> 4);
+  n |= (n >> 8);
+  n |= (n >> 16);
+  return std::max<int>(1, n - (n >> 1));
+}
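+// Worked example (illustrative): lastPow2(20) smears 10100b into 11111b (31) and
+// returns std::max<int>(1, 31 - 15) = 16 = 2**floor(log2(20)); lastPow2(0) returns 1.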
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..741d31711e90669c12a7e163c76ea8eb1ba78027
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Loops.cuh
@@ -0,0 +1,326 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include 
+
+
+namespace at { namespace native {
+
+template<int N>
+static OffsetCalculator<N> make_input_offset_calculator(const TensorIteratorBase& iter) {
+  // array size can not be 0, this happens when N == 0
+  constexpr int array_size = std::max(N, 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
+  std::array<const int64_t*, array_size> strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i + iter.noutputs()).data();
+    element_sizes[i] = iter.element_size(i + iter.noutputs());
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template <int num_outputs = 1>
+static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
+  std::array<const int64_t*, num_outputs> strides;
+  int64_t element_sizes[num_outputs];
+  for (int i = 0; i < num_outputs; i++) {
+    strides[i] = iter.strides(i).data();
+    element_sizes[i] = iter.element_size(i);
+  }
+  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template<typename func_t, typename policy_t>
+__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  int idx = blockIdx.x;
+
+  return_t results[thread_work_size()];
+  args_t args[thread_work_size()];
+
+  // load
+  policy.load(args, idx);
+
+  // compute
+  #pragma unroll
+  for (int i = 0; i < thread_work_size(); i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = c10::guts::apply(f, args[i]);
+    }
+  }
+
+  // store
+  policy.store(results, idx);
+}
+
+}}  // namespace at::native
+
+#include 
+
+namespace at::native {
+
+template <typename func_t>
+void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) {
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(
+      iter.device(arg).is_cuda(),
+      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel_nocast(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_impl_nocast(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(
+      iter.device(arg).is_cuda(),
+      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_impl(iter, f);
+}
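+// Illustrative call-site sketch (an assumption, not part of this header): a typical
+// elementwise op builds a TensorIterator and hands gpu_kernel a device-callable lambda,
+//
+//   at::TensorIterator iter = at::TensorIterator::binary_op(out, a, b);
+//   gpu_kernel(iter, [] GPU_LAMBDA (float x, float y) -> float { return x + y; });
+//
+// where GPU_LAMBDA marks the lambda __host__ __device__ so func_t is usable on device.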
+
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+struct AUnaryFunctor {
+  using traits = function_traits<func_t>;
+  using opmath_arg1_t = typename traits::template arg<0>::type;
+  __device__ return_t operator()(arg2_t b) const {
+    return f(a, b);
+  }
+  // NB: scalar is stored in higher precision!
+  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
+  private:
+    func_t f;
+    opmath_arg1_t a;
+};
+
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+struct BUnaryFunctor {
+  using traits = function_traits<func_t>;
+  using opmath_arg2_t = typename traits::template arg<1>::type;
+  __device__ return_t operator()(arg1_t a) const {
+    return f(a, b);
+  }
+  // NB: scalar is stored in higher precision!
+  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
+  private:
+    func_t f;
+    opmath_arg2_t b;
+};
+
+// Though seemingly noop, this inserts casts from arg1_t to func_t's type
+// (which may be higher precision), as well as casts to return_t
+template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+struct BinaryFunctor {
+  __device__ return_t operator()(arg1_t a, arg2_t b) const {
+    return f(a, b);
+  }
+  BinaryFunctor(func_t f_): f(f_) {}
+  private:
+    func_t f;
+};
+
+// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which
+// accepts inputs at higher precision (typically opmath_t), but then
+// ensure that we load from memory at the correct precision (scalar_t)
+// to avoid expensive loads.  For the whole sordid story see
+// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
+template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
+void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+  using traits = function_traits<func_t>;
+  using opmath_arg1_t = typename traits::template arg<0>::type;
+  using opmath_arg2_t = typename traits::template arg<1>::type;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+
+  if (iter.is_cpu_scalar(1)) {
+    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
+    iter.remove_operand(1);
+    // TODO: When all kernels that use gpu_kernel_with_scalars are
+    // ported to structured, this device guard can be deleted.  This
+    // works around incorrect device guard generation for pre-structured
+    // kernels device guards, but structured kernels do it right and
+    // we can assume the device is already set correctly
+    const OptionalDeviceGuard device_guard(iter.device(1));
+    gpu_kernel(iter, af);
+  } else if (iter.is_cpu_scalar(2)) {
+    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
+    iter.remove_operand(2);
+    gpu_kernel(iter, bf);
+  } else {
+    gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
+  }
+}
+
+template <typename scalar_t, typename return_t = scalar_t, typename func_t>
+void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+  // Use symmetric property of the functor to reduce number of kernels,
+  // requires f(a, b) == f(b, a)
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+  using traits = function_traits<func_t>;
+  using opmath_arg_t = typename traits::template arg<0>::type;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+  static_assert(std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
+                "f is not symmetric");
+
+  OptionalDeviceGuard device_guard;
+  opmath_arg_t scalar_val{};
+
+  if (iter.is_cpu_scalar(1)) {
+    scalar_val = iter.scalar_value(1);
+    iter.remove_operand(1);
+
+    // TODO: When all kernels that use gpu_kernel_with_scalars are
+    // ported to structured, this device guard can be deleted.  This
+    // works around incorrect device guard generation for pre-structured
+    // kernels device guards, but structured kernels do it right and
+    // we can assume the device is already set correctly
+    device_guard.reset_device(iter.device(1));
+  } else if (iter.is_cpu_scalar(2)) {
+    scalar_val = iter.scalar_value(2);
+    iter.remove_operand(2);
+  }
+
+  if (iter.ninputs() == 2) {
+    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
+  } else {
+    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> unary_f(f, scalar_val);
+    gpu_kernel(iter, unary_f);
+  }
+}
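+// Illustrative call-site sketch (an assumption modeled on the binary arithmetic
+// kernels; MulFunctor stands in for any functor with f(a, b) == f(b, a)):
+//
+//   using opmath_t = at::opmath_type<scalar_t>;
+//   opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, MulFunctor<opmath_t>());
+//
+// The functor computes in opmath_t while loads and stores stay in scalar_t.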
+
+// Legacy variant that assumes that func_t has the correct types
+// that we expect to load from memory
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
+  using return_t = typename traits::result_type;
+  opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t>(iter, f);
+}
+
+namespace { // functions for `gpu_kernel_multiple_outputs`.
+
+// check the return type is `thrust::tuple`, not `std::tuple`.
+template <typename T> struct is_tuple: std::false_type {};
+
+template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};
+
+template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+C10_LAUNCH_BOUNDS_1(num_threads())
+__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  int remaining = N - block_work_size() * blockIdx.x;
+  elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using output_t = typename traits::result_type;
+  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
+  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
+  constexpr int num_inputs = traits::arity;
+  constexpr int ntensors = num_outputs + num_inputs;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  if (iter.is_contiguous()) {
+    auto input_calc = TrivialOffsetCalculator<traits::arity>();
+    auto output_calc = TrivialOffsetCalculator<num_outputs>();
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  } else {
+    auto input_calc = make_input_offset_calculator<traits::arity>(iter);
+    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  }
+}
+} // namespace
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel_multiple_outputs(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_multiple_outputs_impl(iter, f);
+}
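+// Illustrative call-site sketch (an assumption): with `iter` configured for two outputs
+// and one input, the functor returns one thrust::tuple element per output, e.g.
+//
+//   gpu_kernel_multiple_outputs(iter,
+//       [] GPU_LAMBDA (float x) -> thrust::tuple<float, float> {
+//         return thrust::make_tuple(::sinf(x), ::cosf(x));
+//       });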
+
+} //namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..4e6effa0247e25d55d37238d7718ba47c2362713
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Math.cuh
@@ -0,0 +1,3375 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+// See note [Jiterator]
+// TODO: elaborate in this comment on the structure of math.cuh
+#if AT_USE_JITERATOR()
+
+const auto ndtri_string = jiterator_stringify(
+  /*
+  * This function is derived from the implementation of the digamma function in the Cephes Math Library.
+  * See note [3-Clause BSD License for the Cephes Math Library].
+  *
+  * Evaluates polynomial of degree N:
+  *
+  *                     2          N
+  * y  =  C  + C x + C x  +...+ C x
+  *        0    1     2          N
+  *
+  * Coefficients are stored in reverse order:
+  *
+  * coef[0] = C  , ..., coef[N] = C  .
+  *            N                   0
+  */
+  template <typename T>
+  T polevl(const T x, const T A[], const int len) {
+    // NOTE: This `polevl` is different from other `polevl`
+    // implementation (in PyTorch) which expect the `len` to be
+    // `len(A) - 1` instead of `len(A)`.
+    T result = 0;
+    for (int i = 0; i < len; ++i) {
+      result = result * x + A[i];
+    }
+    return result;
+  }
+
+  /*
+  * This function is derived from the implementation of the i1e function in the Cephes Math Library.
+  * See note [3-Clause BSD License for the Cephes Math Library].
+  *
+  * Computes the argument, x, for which the area under the Gaussian probability density function
+  * (integrated from minus infinity to x) is equal to y.
+  */
+  template <typename T>
+  T ndtri(T y0) {
+
+    constexpr T zero = 0;
+    constexpr T one = 1;
+
+    // Handles special cases
+    if (y0 == zero) {
+      return NEG_INFINITY;
+    }
+    if (y0 == one) {
+      return POS_INFINITY;
+    }
+    if (y0 < zero || y0 > one) {
+      return NAN;
+    }
+
+    bool code = true;
+    T y = y0;
+    // Note: the constant 0.135... is equal to exp(-2)
+    if (y > one - T{0.13533528323661269189}) {
+      y = one - y;
+      code = false;
+    }
+
+    if (y > T{0.13533528323661269189}) {
+      /* approximation for 0 <= |y - 0.5| <= 3/8 */
+      static const T P0[5] = {
+          -5.99633501014107895267E1,
+          9.80010754185999661536E1,
+          -5.66762857469070293439E1,
+          1.39312609387279679503E1,
+          -1.23916583867381258016E0,
+      };
+
+      static const T Q0[9] = {
+        1.00000000000000000000E0,
+        1.95448858338141759834E0,
+        4.67627912898881538453E0,
+        8.63602421390890590575E1,
+        -2.25462687854119370527E2,
+        2.00260212380060660359E2,
+        -8.20372256168333339912E1,
+        1.59056225126211695515E1,
+        -1.18331621121330003142E0,
+      };
+
+      /* sqrt(2pi) */
+      constexpr T s2pi = 2.50662827463100050242E0;
+
+      y = y - T{0.5};
+      const T y2 = y * y;
+      T x = y + y * (y2 * polevl(y2, P0, int{5}) / polevl(y2, Q0, int{9}));
+      return x * s2pi;
+    }
+
+    T x = sqrt(T{-2.} * log(y));
+    const T x0 = x - (log(x) / x);
+
+    const T z = one / x;
+    T x1;
+
+    /* y > exp(-32) = 1.2664165549e-14 */
+    if (x < T{8.0}) {
+      /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
+      * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
+      */
+      static const T P1[9] = {
+        4.05544892305962419923E0,
+        3.15251094599893866154E1,
+        5.71628192246421288162E1,
+        4.40805073893200834700E1,
+        1.46849561928858024014E1,
+        2.18663306850790267539E0,
+        -1.40256079171354495875E-1,
+        -3.50424626827848203418E-2,
+        -8.57456785154685413611E-4,
+      };
+
+      static const T Q1[9] = {
+        1.00000000000000000000E0,
+        1.57799883256466749731E1,
+        4.53907635128879210584E1,
+        4.13172038254672030440E1,
+        1.50425385692907503408E1,
+        2.50464946208309415979E0,
+        -1.42182922854787788574E-1,
+        -3.80806407691578277194E-2,
+        -9.33259480895457427372E-4,
+      };
+
+      x1 = z * polevl(z, P1, int{9}) / polevl(z, Q1, int{9});
+    } else {
+      /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
+      * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
+      */
+      static const T P2[9] = {
+        3.23774891776946035970E0,
+        6.91522889068984211695E0,
+        3.93881025292474443415E0,
+        1.33303460815807542389E0,
+        2.01485389549179081538E-1,
+        1.23716634817820021358E-2,
+        3.01581553508235416007E-4,
+        2.65806974686737550832E-6,
+        6.23974539184983293730E-9,
+      };
+
+      static const T Q2[9] = {
+        1.00000000000000000000E0,
+        6.02427039364742014255E0,
+        3.67983563856160859403E0,
+        1.37702099489081330271E0,
+        2.16236993594496635890E-1,
+        1.34204006088543189037E-2,
+        3.28014464682127739104E-4,
+        2.89247864745380683936E-6,
+        6.79019408009981274425E-9,
+      };
+
+      x1 = z * polevl(z, P2, int{9}) / polevl(z, Q2, int{9});
+    }
+
+    x = x0 - x1;
+    return (!code) ? x : -x;
+  }
+); // ndtri_string
+
+const auto log_ndtr_string = jiterator_stringify(
+  template <typename T>
+  T log_ndtr(T x) {
+    constexpr T SQRT1_2{0.707106781186547524400844362104849039};   // 1/sqrt(2)
+    T t = x * SQRT1_2;
+    if (x < T{-1.0}) {
+      return log(erfcx(-t) / 2) - t * t;
+    } else {
+      return log1p(-erfc(t) / 2);
+    }
+  }
+); // log_ndtr_string
+
+const auto gcd_string = jiterator_stringify(
+  template <typename T>
+  T gcd(const T a_in, const T b_in) {
+    T a = abs(a_in);
+    T b = abs(b_in);
+
+    while (a != T{0}) {
+      T c = a;
+      a = b % a;
+      b = c;
+    }
+
+    return b;
+  }
+); // gcd_string
+
+const auto lcm_string = jiterator_stringify(
+  template <typename T>
+  T gcd(const T a_in, const T b_in) {
+    T a = abs(a_in);
+    T b = abs(b_in);
+
+    while (a != T{0}) {
+      T c = a;
+      a = b % a;
+      b = c;
+    }
+
+    return b;
+  }
+
+  template <typename T>
+  T lcm(const T a, const T b) {
+    T g = gcd(a, b);
+    return (g == T{0}) ? T{0} : abs(a / g * b);
+  }
+); // lcm_string
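+// Worked example (illustrative): gcd(12, 18) iterates (a, b): (12, 18) -> (6, 12) -> (0, 6)
+// and returns 6, so lcm(12, 18) = |12 / 6 * 18| = 36; lcm returns 0 if either input is 0.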
+
+/*
+ * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h".
+ */
+// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
+const auto digamma_string = jiterator_stringify(
+  template <typename T>
+  T digamma(T x) {
+    static const double PI_f64 = 3.14159265358979323846;
+
+    // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
+    if (x == 0) {
+      return copysign(POS_INFINITY, -x);
+    }
+
+    T result = 0;
+    if (x < 0) {
+      // Short-circuits if x is a negative integer and returns NaN
+      //   per the C++ standard
+      const bool x_is_integer = (x == trunc(x));
+      if (x_is_integer) {
+        return NAN;
+      }
+
+      // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
+      // accurate than tan(pi * x). While these operations are mathematically equivalent
+      // since both x and r are in radians and tan() has a periodicity of pi, in practice
+      // the computation of pi * x is a source of error (when |x| > 1).
+      double q, r;
+      r = modf(static_cast<double>(x), &q);
+      result = - PI_f64 / tan(PI_f64 * r);
+      x = 1 - x;
+    }
+
+    while (x < T{10}) {
+      result -= T{1} / x;
+      x += T{1};
+    }
+
+    if (x == T{10}) {
+      return result + T{2.25175258906672110764};
+    }
+
+    T y = 0;
+    if (x < T{1.0e17}) {
+      const T A[] = {
+        8.33333333333333333333E-2,
+        -2.10927960927960927961E-2,
+        7.57575757575757575758E-3,
+        -4.16666666666666666667E-3,
+        3.96825396825396825397E-3,
+        -8.33333333333333333333E-3,
+        8.33333333333333333333E-2,
+      };
+
+
+      T z = T{1} / (x * x);
+
+      T polevl_result = 0;
+      for (int i = 0; i <= 6; i++) {
+        polevl_result = polevl_result * z + A[i];
+      }
+      y = z * polevl_result;
+    }
+
+    return log(x) - (T{0.5} / x) - y + result;
+  }
+); // digamma_string
+
+/*
+ * This function is derived from the implementation of the zeta function in the Cephes Math Library.
+ * See note [3-Clause BSD License for the Cephes Math Library].
+ */
+const auto zeta_string = jiterator_stringify(
+  template <typename T>
+  T zeta(T x, T q) {
+    const T MACHEP{1.11022302462515654042E-16};
+    constexpr T zero{0};
+    constexpr T half{0.5};
+    constexpr T one{1};
+    static const T A[] = {
+        12.0,
+        -720.0,
+        30240.0,
+        -1209600.0,
+        47900160.0,
+        -1.8924375803183791606e9, /*1.307674368e12/691*/
+        7.47242496e10,
+        -2.950130727918164224e12, /*1.067062284288e16/3617*/
+        1.1646782814350067249e14, /*5.109094217170944e18/43867*/
+        -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
+        1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
+        -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
+    };
+
+    int i = 0;
+    T a, b, k, s, t, w;
+
+    // Short-circuits x -> +infty
+    if (x == one) {
+      return POS_INFINITY;
+    }
+
+    // Short-circuits x < 1 -> NaN
+    if (x < one) {
+      return NAN;
+    }
+
+    // Short-circuits negative q integers map to +infty,
+    //   negative q non-integers map to NaN
+    if (q <= zero) {
+      if (q == floor(q)) {
+        return POS_INFINITY;
+      }
+      if (x != floor(x)) {
+        return NAN;
+      }
+    }
+
+    s = pow(q, -x);
+    a = q;
+    i = 0;
+    b = zero;
+    while ((i < 9) || (a <= T{9.0})) {
+      i += 1;
+      a += one;
+      b = pow(a, -x);
+      s += b;
+      if ((-MACHEP * s < b) && (b < MACHEP * s)) {
+        return s;
+      }
+    };
+
+    w = a;
+    s += b * w / (x - one);
+    s -= half * b;
+    a = one;
+    k = zero;
+    for (int i = 0; i < 12; i++) {
+      a *= x + k;
+      b /= w;
+      t = a * b / A[i];
+      s = s + t;
+      t = fabs(t / s);
+
+      if (t < MACHEP) {
+        return s;
+      }
+
+      k += one;
+      a *= x + k;
+      b /= w;
+      k += one;
+    }
+
+    return s;
+  }
+); // zeta_string
+
+const auto trigamma_string = jiterator_stringify(
+  template <typename T>
+  T trigamma(T x) {
+    const T PI{3.14159265358979323846};
+    T sign = 1;
+    T result = 0;
+
+    if (x < T{0.5}) {
+      sign = -1;
+      T sin_pi_x = sin(PI * x);
+      result -= (PI * PI) / (sin_pi_x * sin_pi_x);
+      x = 1 - x;
+    }
+
+    for (int i = 0; i < 6; ++i) {
+      result += T{1} / (x * x);
+      x += 1;
+    }
+
+    const T one{1};
+    const T ixx = one / (x*x);
+    result += (one + one / (T{2}*x) + ixx * (one/T{6} - ixx * (one/T{30} - ixx * (one/T{42})))) / x;
+    return sign * result;
+}
+); // trigamma_string
+
+const auto lgamma_string = jiterator_stringify(
+  template <typename T>
+  T lgamma_kernel(T a) {
+    return lgamma(a);
+  }
+); // lgamma_string
+
+const auto polygamma_string = zeta_string + jiterator_stringify(
+  template <typename T>
+  T polygamma(T x, int n) {
+    // already blocked if n <= 1
+    const auto one = T{1};
+    return ((n % 2) ? one : -one) * exp(lgamma(static_cast<T>(n) + one)) *
+        zeta(static_cast<T>(n + 1), x);
+  }
+); // polygamma_string
+
+const auto exp2_string = jiterator_stringify(
+  template <typename T>
+  T exp2_impl(T a) {
+    return exp2(a);
+  }
+
+  namespace std { template <typename T> class complex; }
+  template <typename T>
+  std::complex<T> exp2_impl(std::complex<T> x) {
+    // There is no std::exp2 overload for complex, so instead
+    // use the identity 2^x = e^(ln(2) * x)
+    const auto ln_2 = static_cast<T>(0.693147180559945309417232121458176);
+    return exp(ln_2 * x);
+  }
+
+  template <typename T>
+  T exp2_kernel(T a) {
+    return exp2_impl(a);
+  }
+); // exp2_string
+
+const auto erfc_string = jiterator_stringify(
+  template <typename T>
+  T erfc_kernel(T a) {
+    return erfc(a);
+  }
+); // erfc_string
+
+const auto erfinv_string = jiterator_stringify(
+  template <typename T>
+  T erfinv_kernel(T a) {
+    return erfinv(a);
+  }
+); // erfinv_string
+
+const auto entr_string = jiterator_stringify(
+  template <typename T>
+  T entr(T a) {
+    if (a != a) {
+      return a;
+    }
+
+    if (a > 0) {
+      return -a * log(a);
+    }
+
+    if (a == 0) {
+      return 0;
+    }
+
+    return NEG_INFINITY;
+  }
+); // entr_string
+
+// NOTE: `kaiser_window_string` depends on `i0_string`
+//       for its implementation.
+const auto i0_string = jiterator_stringify(
+  template<typename T>
+  T chbevl(T x, const T array[], const int len) {
+
+      T b0, b1, b2;
+
+      b0 = array[0];
+      b1 = 0;
+
+      for (int i = 1; i < len; ++i)  {
+          b2 = b1;
+          b1 = b0;
+          b0 = x * b1 - b2 + array[i];
+      }
+
+      return T{0.5} * (b0 - b2);
+  }
+
+  template<typename T>
+  T i0(T _x) {
+      T x = fabs(_x);
+
+      if (x <= T{8.0}) {
+          /* Chebyshev coefficients for exp(-x) I0(x)
+          *   in the interval [0,8].
+          *
+          * lim(x->0){ exp(-x) I0(x) } = 1.
+          */
+          static const T A[] = {
+              -4.41534164647933937950E-18, 3.33079451882223809783E-17,
+              -2.43127984654795469359E-16, 1.71539128555513303061E-15,
+              -1.16853328779934516808E-14, 7.67618549860493561688E-14,
+              -4.85644678311192946090E-13, 2.95505266312963983461E-12,
+              -1.72682629144155570723E-11, 9.67580903537323691224E-11,
+              -5.18979560163526290666E-10, 2.65982372468238665035E-9,
+              -1.30002500998624804212E-8,  6.04699502254191894932E-8,
+              -2.67079385394061173391E-7,  1.11738753912010371815E-6,
+              -4.41673835845875056359E-6,  1.64484480707288970893E-5,
+              -5.75419501008210370398E-5,  1.88502885095841655729E-4,
+              -5.76375574538582365885E-4,  1.63947561694133579842E-3,
+              -4.32430999505057594430E-3,  1.05464603945949983183E-2,
+              -2.37374148058994688156E-2,  4.93052842396707084878E-2,
+              -9.49010970480476444210E-2,  1.71620901522208775349E-1,
+              -3.04682672343198398683E-1,  6.76795274409476084995E-1};
+
+          T y = (x / T{2.0}) - T{2.0};
+          return exp(x) * chbevl(y, A, int{30});
+      }
+
+      // Handles x > 8 case
+      /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
+      * in the inverted interval [8,infinity].
+      *
+      * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
+      */
+      const T B[] = {
+          -7.23318048787475395456E-18, -4.83050448594418207126E-18,
+          4.46562142029675999901E-17,  3.46122286769746109310E-17,
+          -2.82762398051658348494E-16, -3.42548561967721913462E-16,
+          1.77256013305652638360E-15,  3.81168066935262242075E-15,
+          -9.55484669882830764870E-15, -4.15056934728722208663E-14,
+          1.54008621752140982691E-14,  3.85277838274214270114E-13,
+          7.18012445138366623367E-13,  -1.79417853150680611778E-12,
+          -1.32158118404477131188E-11, -3.14991652796324136454E-11,
+          1.18891471078464383424E-11,  4.94060238822496958910E-10,
+          3.39623202570838634515E-9,   2.26666899049817806459E-8,
+          2.04891858946906374183E-7,   2.89137052083475648297E-6,
+          6.88975834691682398426E-5,   3.36911647825569408990E-3,
+          8.04490411014108831608E-1};
+
+      return (exp(x) * chbevl(T{32.0} / x - T{2.0}, B, int{25})) / sqrt(x);
+  }
+); // i0_string
+
+const auto i1_string = jiterator_stringify(
+  template<typename T>
+  T chbevl(const T x, const T array[], const int len) {
+      T b0, b1, b2;
+
+      b0 = array[0];
+      b1 = 0;
+
+      for (int i = 1; i < len; ++i)  {
+          b2 = b1;
+          b1 = b0;
+          b0 = x * b1 - b2 + array[i];
+      }
+
+      return T{0.5} * (b0 - b2);
+  }
+
+  template <typename T>
+  T i1(T _x) {
+    const T x = fabs(_x);
+
+    if (x <= T{8.0}) {
+      // Chebyshev coefficients for exp(-x) i1(x) in the interval [0, 8]
+      //   lim(x->0){ exp(-x) i1(x) / x } = 1/2
+      static const T coefficients[] = {
+          2.77791411276104639959E-18, -2.11142121435816608115E-17,
+          1.55363195773620046921E-16, -1.10559694773538630805E-15,
+          7.60068429473540693410E-15, -5.04218550472791168711E-14,
+          3.22379336594557470981E-13, -1.98397439776494371520E-12,
+          1.17361862988909016308E-11, -6.66348972350202774223E-11,
+          3.62559028155211703701E-10, -1.88724975172282928790E-9,
+          9.38153738649577178388E-9,  -4.44505912879632808065E-8,
+          2.00329475355213526229E-7,  -8.56872026469545474066E-7,
+          3.47025130813767847674E-6,  -1.32731636560394358279E-5,
+          4.78156510755005422638E-5,  -1.61760815825896745588E-4,
+          5.12285956168575772895E-4,  -1.51357245063125314899E-3,
+          4.15642294431288815669E-3,  -1.05640848946261981558E-2,
+          2.47264490306265168283E-2,  -5.29459812080949914269E-2,
+          1.02643658689847095384E-1,  -1.76416518357834055153E-1,
+          2.52587186443633654823E-1};
+      const T y = x / T{2.0} - T{2.0};
+      const T out = exp(x) * x * chbevl(y, coefficients, int{29});
+      return (_x < T{0.0}) ? -out : out;
+    }
+
+    // Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
+    //   in the inverted interval [8, infinity]
+    //   lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi)
+    static const T coefficients[] = {
+      7.51729631084210481353E-18,  4.41434832307170791151E-18,
+      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
+      2.96262899764595013876E-16,  3.30820231092092828324E-16,
+      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
+      1.04202769841288027642E-14,  4.27244001671195135429E-14,
+      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
+      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
+      1.41258074366137813316E-11,  3.25260358301548823856E-11,
+      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
+      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
+      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
+      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
+      7.78576235018280120474E-1};
+    const T out = (exp(x) * chbevl(T{32.} / x - T{2.}, coefficients, int{25})) / sqrt(x);
+    return (_x < T{0.}) ? -out : out;
+  }
+); // i1_string
+
+const auto i1e_string = jiterator_stringify(
+  template<typename T>
+  T chbevl(const T x, const T array[], const int len) {
+      T b0, b1, b2;
+
+      b0 = array[0];
+      b1 = 0;
+
+      for (int i = 1; i < len; ++i)  {
+          b2 = b1;
+          b1 = b0;
+          b0 = x * b1 - b2 + array[i];
+      }
+
+      return T{0.5} * (b0 - b2);
+  }
+
+  // See double and float instantiations below
+  template <typename T>
+  T i1e(T _x) { }
+
+  // Double specialization (uses different coefficients than the float version)
+  template<>
+  double i1e(double _x) {
+    const double x = fabs(_x);
+    if (x <= double{8.}) {
+      // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8].
+      // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2.
+      static const double coefficients[] = {
+        2.77791411276104639959E-18, -2.11142121435816608115E-17,
+        1.55363195773620046921E-16, -1.10559694773538630805E-15,
+        7.60068429473540693410E-15, -5.04218550472791168711E-14,
+        3.22379336594557470981E-13, -1.98397439776494371520E-12,
+        1.17361862988909016308E-11, -6.66348972350202774223E-11,
+        3.62559028155211703701E-10, -1.88724975172282928790E-9,
+        9.38153738649577178388E-9,  -4.44505912879632808065E-8,
+        2.00329475355213526229E-7,  -8.56872026469545474066E-7,
+        3.47025130813767847674E-6,  -1.32731636560394358279E-5,
+        4.78156510755005422638E-5,  -1.61760815825896745588E-4,
+        5.12285956168575772895E-4,  -1.51357245063125314899E-3,
+        4.15642294431288815669E-3,  -1.05640848946261981558E-2,
+        2.47264490306265168283E-2,  -5.29459812080949914269E-2,
+        1.02643658689847095384E-1,  -1.76416518357834055153E-1,
+        2.52587186443633654823E-1};
+      const double y = x / double{2.} - double{2.};
+      const double out = chbevl(y, coefficients, int{29}) * x;
+      return (_x < 0.) ? -out : out;
+    }
+
+    // Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
+    //   in the inverted interval (8, infinity].
+    // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi).
+    // TODO: what's an "inverted interval"? Open on the left
+    //   and closed on the right?
+  static const double coefficients[] = {
+      7.51729631084210481353E-18,  4.41434832307170791151E-18,
+      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
+      2.96262899764595013876E-16,  3.30820231092092828324E-16,
+      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
+      1.04202769841288027642E-14,  4.27244001671195135429E-14,
+      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
+      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
+      1.41258074366137813316E-11,  3.25260358301548823856E-11,
+      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
+      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
+      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
+      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
+      7.78576235018280120474E-1};
+
+    const double out = chbevl(double{32.} / x - double{2.}, coefficients, int{25}) / sqrt(x);
+    return (_x < double{0.}) ? -out : out;
+  }
+
+  // Float specialization (uses different coefficients than the double version)
+  template<>
+  float i1e(float _x) {
+    const float x = fabsf(_x);
+    if (x <= float{8.}) {
+      // Chebyshev double coefficients for exp(-x) i1(x) in the interval [0,8].
+      // Note: lim(x->0){ exp(-x) i1(x) / x } = 1/2.
+      static const float coefficients[] = {
+        9.38153738649577178388E-9f,
+        -4.44505912879632808065E-8f,
+        2.00329475355213526229E-7f,
+        -8.56872026469545474066E-7f,
+        3.47025130813767847674E-6f,
+        -1.32731636560394358279E-5f,
+        4.78156510755005422638E-5f,
+        -1.61760815825896745588E-4f,
+        5.12285956168575772895E-4f,
+        -1.51357245063125314899E-3f,
+        4.15642294431288815669E-3f,
+        -1.05640848946261981558E-2f,
+        2.47264490306265168283E-2f,
+        -5.29459812080949914269E-2f,
+        1.02643658689847095384E-1f,
+        -1.76416518357834055153E-1f,
+        2.52587186443633654823E-1f};
+      const float y = x / float{2.} - float{2.};
+      const float out = chbevl(y, coefficients, int{17}) * x;
+      return (_x < 0.) ? -out : out;
+    }
+
+    // Chebyshev coefficients for exp(-x) sqrt(x) i1(x)
+    //   in the inverted interval (8, infinity].
+    // Note: lim(x->inf){ exp(-x) sqrt(x) i1(x) } = 1/sqrt(2pi).
+    // TODO: what's an "inverted interval"? Open on the left
+    //   and closed on the right?
+  static const float coefficients[] = {
+      -3.83538038596423702205E-9f,
+      -2.63146884688951950684E-8f,
+      -2.51223623787020892529E-7f,
+      -3.88256480887769039346E-6f,
+      -1.10588938762623716291E-4f,
+      -9.76109749136146840777E-3f,
+      7.78576235018280120474E-1f};
+
+    const float out = chbevl(float{32.} / x - float{2.}, coefficients, int{7}) / sqrt(x);
+    return (_x < float{0.}) ? -out : out;
+  }
+); // i1e_string
+
+const auto kaiser_window_string = i0_string + jiterator_stringify(
+  template <typename T>
+  T kaiser_window(T a, T inv_alpha, T beta, T inv_i0_beta) {
+    T x = a * inv_alpha - T{1};
+    T y = max(T{0}, T{1} - x * x);
+    return i0(beta * sqrt(y)) * inv_i0_beta;
+  }
+); // kaiser_window_string
+
+const auto sinc_string = jiterator_stringify(
+  template <typename T>
+  T sinc(T a) {
+    if (a == T(0)) {
+      return T(1);
+    } else {
+      constexpr T pi = T(3.14159265358979323846L);
+      T product = pi * a;
+      return std::sin(product) / product;
+    }
+  }
+); // sinc_string
+
+const auto erfcx_string = jiterator_stringify(
+  /* The next function is taken from http://ab-initio.mit.edu/Faddeev */
+
+  /* Copyright (c) 2012 Massachusetts Institute of Technology
+  *
+  * Permission is hereby granted, free of charge, to any person obtaining
+  * a copy of this software and associated documentation files (the
+  * "Software"), to deal in the Software without restriction, including
+  * without limitation the rights to use, copy, modify, merge, publish,
+  * distribute, sublicense, and/or sell copies of the Software, and to
+  * permit persons to whom the Software is furnished to do so, subject to
+  * the following conditions:
+  *
+  * The above copyright notice and this permission notice shall be
+  * included in all copies or substantial portions of the Software.
+  *
+  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+  */
+
+  /* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by
+    Steven G. Johnson, October 2012.
+
+    This function combines a few different ideas.
+
+    First, for x > 50, it uses a continued-fraction expansion (same as
+    for the Faddeeva function, but with algebraic simplifications for z=i*x).
+
+    Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations,
+    but with two twists:
+
+        a) It maps x to y = 4 / (4+x) in [0,1].  This simple transformation,
+          inspired by a similar transformation in the octave-forge/specfun
+          erfcx by Soren Hauberg, results in much faster Chebyshev convergence
+          than other simple transformations I have examined.
+
+        b) Instead of using a single Chebyshev polynomial for the entire
+          [0,1] y interval, we break the interval up into 100 equal
+          subintervals, with a switch/lookup table, and use much lower
+          degree Chebyshev polynomials in each subinterval. This greatly
+          improves performance in my tests.
+
+    For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfcx(x),
+    with the usual checks for overflow etcetera.
+
+    Performance-wise, it seems to be substantially faster than either
+    the SLATEC DERFC function [or an erfcx function derived therefrom]
+    or Cody's CALERF function (from netlib.org/specfun), while
+    retaining near machine precision in accuracy.
+  */
+
+  /* Given y100 = 100 * y, where y = 4 / (4 + x) for x >= 0, compute erfc(x).
+
+    Uses a look-up table of 100 different Chebyshev polynomials
+    for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated
+    with the help of Maple and a little shell script.   This allows
+    the Chebyshev polynomials to be of significantly lower degree (about 1/4)
+    compared to fitting the whole [0,1] interval with a single polynomial.
+  */
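+
+  /* Editorial worked example of the mapping described above:
+       y    = 4 / (4 + x)            for x >= 0, so y lies in (0, 1]
+       y100 = 100 * y = 400 / (4 + x)
+       j    = floor(y100)            picks the subinterval [j/100, (j+1)/100)
+       t    = 2*y100 - (2*j + 1)     rescales y100 onto [-1, 1] inside that subinterval,
+     which is exactly the "T t = 2*y100 - (2*j + 1)" statement in each switch case of
+     erfcx_y100 below, and T{400} / (T{4} + x) is how erfcx passes y100 in.
+  */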
+
+  // TODO: review if this is computing in double when given a float input
+  template <typename T>
+  T erfcx_y100(T y100) {
+    switch (static_cast<int>(y100)) {
+      case 0: {
+      T t = 2*y100 - 1;
+      return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 1: {
+      T t = 2*y100 - 3;
+      return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 2: {
+      T t = 2*y100 - 5;
+      return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 3: {
+      T t = 2*y100 - 7;
+      return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 4: {
+      T t = 2*y100 - 9;
+      return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 5: {
+      T t = 2*y100 - 11;
+      return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 6: {
+      T t = 2*y100 - 13;
+      return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 7: {
+      T t = 2*y100 - 15;
+      return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 8: {
+      T t = 2*y100 - 17;
+      return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 9: {
+      T t = 2*y100 - 19;
+      return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 10: {
+      T t = 2*y100 - 21;
+      return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 11: {
+      T t = 2*y100 - 23;
+      return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 12: {
+      T t = 2*y100 - 25;
+      return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 13: {
+      T t = 2*y100 - 27;
+      return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 14: {
+      T t = 2*y100 - 29;
+      return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 15: {
+      T t = 2*y100 - 31;
+      return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 16: {
+      T t = 2*y100 - 33;
+      return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 17: {
+      T t = 2*y100 - 35;
+      return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 18: {
+      T t = 2*y100 - 37;
+      return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 19: {
+      T t = 2*y100 - 39;
+      return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 20: {
+      T t = 2*y100 - 41;
+      return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 21: {
+      T t = 2*y100 - 43;
+      return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 22: {
+      T t = 2*y100 - 45;
+      return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 23: {
+      T t = 2*y100 - 47;
+      return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 24: {
+      T t = 2*y100 - 49;
+      return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 25: {
+      T t = 2*y100 - 51;
+      return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 26: {
+      T t = 2*y100 - 53;
+      return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 27: {
+      T t = 2*y100 - 55;
+      return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 28: {
+      T t = 2*y100 - 57;
+      return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 29: {
+      T t = 2*y100 - 59;
+      return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 30: {
+      T t = 2*y100 - 61;
+      return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t;
+      }
+      case 31: {
+      T t = 2*y100 - 63;
+      return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 32: {
+      T t = 2*y100 - 65;
+      return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 33: {
+      T t = 2*y100 - 67;
+      return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 34: {
+      T t = 2*y100 - 69;
+      return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 35: {
+      T t = 2*y100 - 71;
+      return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 36: {
+      T t = 2*y100 - 73;
+      return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 37: {
+      T t = 2*y100 - 75;
+      return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 38: {
+      T t = 2*y100 - 77;
+      return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 39: {
+      T t = 2*y100 - 79;
+      return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 40: {
+      T t = 2*y100 - 81;
+      return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 41: {
+      T t = 2*y100 - 83;
+      return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 42: {
+      T t = 2*y100 - 85;
+      return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 43: {
+      T t = 2*y100 - 87;
+      return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 44: {
+      T t = 2*y100 - 89;
+      return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 45: {
+      T t = 2*y100 - 91;
+      return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 46: {
+      T t = 2*y100 - 93;
+      return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 47: {
+      T t = 2*y100 - 95;
+      return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 48: {
+      T t = 2*y100 - 97;
+      return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 49: {
+      T t = 2*y100 - 99;
+      return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 50: {
+      T t = 2*y100 - 101;
+      return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 51: {
+      T t = 2*y100 - 103;
+      return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 52: {
+      T t = 2*y100 - 105;
+      return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 53: {
+      T t = 2*y100 - 107;
+      return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 54: {
+      T t = 2*y100 - 109;
+      return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 55: {
+      T t = 2*y100 - 111;
+      return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 56: {
+      T t = 2*y100 - 113;
+      return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 57: {
+      T t = 2*y100 - 115;
+      return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 58: {
+      T t = 2*y100 - 117;
+      return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 59: {
+      T t = 2*y100 - 119;
+      return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 60: {
+      T t = 2*y100 - 121;
+      return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 61: {
+      T t = 2*y100 - 123;
+      return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 62: {
+      T t = 2*y100 - 125;
+      return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 63: {
+      T t = 2*y100 - 127;
+      return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 64: {
+      T t = 2*y100 - 129;
+      return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 65: {
+      T t = 2*y100 - 131;
+      return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 66: {
+      T t = 2*y100 - 133;
+      return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 67: {
+      T t = 2*y100 - 135;
+      return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 68: {
+      T t = 2*y100 - 137;
+      return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 69: {
+      T t = 2*y100 - 139;
+      return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 70: {
+      T t = 2*y100 - 141;
+      return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 71: {
+      T t = 2*y100 - 143;
+      return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 72: {
+      T t = 2*y100 - 145;
+      return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t;
+      }
+      case 73: {
+      T t = 2*y100 - 147;
+      return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 74: {
+      T t = 2*y100 - 149;
+      return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 75: {
+      T t = 2*y100 - 151;
+      return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 76: {
+      T t = 2*y100 - 153;
+      return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 77: {
+      T t = 2*y100 - 155;
+      return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 78: {
+      T t = 2*y100 - 157;
+      return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 79: {
+      T t = 2*y100 - 159;
+      return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 80: {
+      T t = 2*y100 - 161;
+      return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 81: {
+      T t = 2*y100 - 163;
+      return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 82: {
+      T t = 2*y100 - 165;
+      return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 83: {
+      T t = 2*y100 - 167;
+      return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 84: {
+      T t = 2*y100 - 169;
+      return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 85: {
+      T t = 2*y100 - 171;
+      return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 86: {
+      T t = 2*y100 - 173;
+      return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 87: {
+      T t = 2*y100 - 175;
+      return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 88: {
+      T t = 2*y100 - 177;
+      return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 89: {
+      T t = 2*y100 - 179;
+      return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 90: {
+      T t = 2*y100 - 181;
+      return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 91: {
+      T t = 2*y100 - 183;
+      return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 92: {
+      T t = 2*y100 - 185;
+      return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 93: {
+      T t = 2*y100 - 187;
+      return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 94: {
+      T t = 2*y100 - 189;
+      return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 95: {
+      T t = 2*y100 - 191;
+      return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 96: {
+      T t = 2*y100 - 193;
+      return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 97: {
+      T t = 2*y100 - 195;
+      return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 98: {
+      T t = 2*y100 - 197;
+      return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t;
+      }
+      case 99: {
+      T t = 2*y100 - 199;
+      return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t;
+      }
+    }
+
+    // we only get here if y = 1, i.e. |x| < 4*eps, in which case
+    // erfcx is within 1e-15 of 1.
+    return 1.;
+  }
+
+  template <typename T>
+  T erfcx(T x) {
+    // Short-circuits on NaN (returning NaN)
+    if (x != x) {
+      return x;
+    }
+
+    if (x >= 0) {
+      if (x > T{50}) { // continued-fraction expansion is faster
+        const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi)
+
+        if (x > T{5e7}) { // 1-term expansion, important to avoid overflow
+          return ispi / x;
+        }
+
+        /* 5-term expansion (rely on compiler for CSE), simplified from:
+                  ispi / (x+0.5/(x+1/(x+1.5/(x+2/x))))  */
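+        /* Editorial check of that simplification, working from the inside out:
+             x + 2/x           = (x^2 + 2) / x
+             x + 1.5/(...)     = x (x^2 + 3.5) / (x^2 + 2)
+             x + 1/(...)       = (x^4 + 4.5 x^2 + 2) / (x (x^2 + 3.5))
+             x + 0.5/(...)     = x (x^4 + 5 x^2 + 3.75) / (x^4 + 4.5 x^2 + 2)
+           so ispi / (...) = ispi (x^4 + 4.5 x^2 + 2) / (x (x^4 + 5 x^2 + 3.75)),
+           which is the expression returned on the next line. */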
+        return ispi * ((x*x) * (x*x+T{4.5}) + T{2}) / (x * ((x*x) * (x*x+T{5}) + T{3.75}));
+      }
+
+      // 0 <= x <= 50
+      return erfcx_y100(T{400} / (T{4} + x));
+    }
+
+    // x < 0
+    if (x < T{-26.7}) {
+      return POS_INFINITY;
+    } else if (x < T{-6.1}) {
+      return T{2} * exp(x * x);
+    }
+
+    // x < 0 and x >= -6.1
+    return T{2} * exp(x * x) - erfcx_y100(T{400} / (T{4} - x));
+  }
+); // erfcx_string
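+
+// Illustrative host-side reference (editorial addition, not part of the kernel string):
+// by definition erfcx(x) = exp(x^2) * erfc(x), so for moderate arguments the routine
+// above can be cross-checked against the C++ standard library. The function name is
+// hypothetical; assumes <cmath> is available.
+static double erfcx_reference(double x) {
+  // Safe only for moderate |x| (roughly |x| < 25), where exp(x*x) does not overflow and
+  // erfc(x) does not underflow; the scaled kernel above stays finite well beyond that.
+  return std::exp(x * x) * std::erfc(x);
+}
+// Example: erfcx_reference(0.0) is exactly 1, and erfcx_reference(1.0) is about 0.4276.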
+
+const auto airy_ai_string = jiterator_stringify(
+    template<typename T>
+    T airy_ai_forward(T x) {
+        static const T AN[] = {
+                +3.46538101525629032477e-01,
+                +1.20075952739645805542e+01,
+                +7.62796053615234516538e+01,
+                +1.68089224934630576269e+02,
+                +1.59756391350164413639e+02,
+                +7.05360906840444183113e+01,
+                +1.40264691163389668864e+01,
+                +9.99999999999999995305e-01,
+        };
+
+        static const T AD[] = {
+                +5.67594532638770212846e-01,
+                +1.47562562584847203173e+01,
+                +8.45138970141474626562e+01,
+                +1.77318088145400459522e+02,
+                +1.64234692871529701831e+02,
+                +7.14778400825575695274e+01,
+                +1.40959135607834029598e+01,
+                +1.00000000000000000470e+00,
+        };
+
+        static const T AFN[] = {
+                -1.31696323418331795333e-01,
+                -6.26456544431912369773e-01,
+                -6.93158036036933542233e-01,
+                -2.79779981545119124951e-01,
+                -4.91900132609500318020e-02,
+                -4.06265923594885404393e-03,
+                -1.59276496239262096340e-04,
+                -2.77649108155232920844e-06,
+                -1.67787698489114633780e-08,
+        };
+
+        static const T AFD[] = {
+                +1.33560420706553243746e+01,
+                +3.26825032795224613948e+01,
+                +2.67367040941499554804e+01,
+                +9.18707402907259625840e+00,
+                +1.47529146771666414581e+00,
+                +1.15687173795188044134e-01,
+                +4.40291641615211203805e-03,
+                +7.54720348287414296618e-05,
+                +4.51850092970580378464e-07,
+        };
+
+        static const T AGN[] = {
+                +1.97339932091685679179e-02,
+                +3.91103029615688277255e-01,
+                +1.06579897599595591108e+00,
+                +9.39169229816650230044e-01,
+                +3.51465656105547619242e-01,
+                +6.33888919628925490927e-02,
+                +5.85804113048388458567e-03,
+                +2.82851600836737019778e-04,
+                +6.98793669997260967291e-06,
+                +8.11789239554389293311e-08,
+                +3.41551784765923618484e-10,
+        };
+
+        static const T AGD[] = {
+                +9.30892908077441974853e+00,
+                +1.98352928718312140417e+01,
+                +1.55646628932864612953e+01,
+                +5.47686069422975497931e+00,
+                +9.54293611618961883998e-01,
+                +8.64580826352392193095e-02,
+                +4.12656523824222607191e-03,
+                +1.01259085116509135510e-04,
+                +1.17166733214413521882e-06,
+                +4.91834570062930015649e-09,
+        };
+
+        int domain_flag = 0;
+
+        T ai;
+
+        if (isinf(x)) {
+            return NAN;
+        }
+
+        if (x > T(103.892)) {
+            return T(0.0);
+        }
+
+        T f;
+        T g;
+        T k;
+
+        if (x < T(-2.09)) {
+            T z = T(1.0) / (T(-2.0) * x * sqrt(-x) / T(3.0));
+
+            T afn = 0.0;
+
+            for (uint8_t index = 0; index <= 8; index++) {
+                afn = afn * (z * z) + AFN[index];
+            }
+
+            T afd = 0.0;
+
+            for (uint8_t index = 0; index <= 8; index++) {
+                afd = afd * (z * z) + AFD[index];
+            }
+
+            T agn = 0.0;
+
+            for (uint8_t index = 0; index <= 10; index++) {
+                agn = agn * (z * z) + AGN[index];
+            }
+
+            T agd = 0.0;
+
+            for (uint8_t index = 0; index <= 9; index++) {
+                agd = agd * (z * z) + AGD[index];
+            }
+
+            T t = T(-2.0) * x * sqrt(-x) / T(3.0) + T(0.25) * T(3.14159265358979323846);
+
+            return T(5.64189583547756286948e-01) / sqrt(sqrt(-x)) * (sin(t) * (T(1.0) + z * z * afn / afd) - cos(t) * (z * agn / agd));
+        }
+
+        if (x >= T(2.09)) {
+            domain_flag = 5;
+
+            T zeta = T(2.0) * x * sqrt(x) / T(3.0);
+
+            T an = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                an = an * (T(1.0) / zeta) + AN[index];
+            }
+
+            T ad = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                ad = ad * (T(1.0) / zeta) + AD[index];
+            }
+
+            ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * sqrt(sqrt(x)) * exp(zeta));
+
+            if (x > T(8.3203353)) {
+                return ai;
+            }
+        }
+
+        f = 1.0;
+        g = x;
+        k = 1.0;
+
+        T m = 1.0;
+        T n = x;
+        T t = 1.0;
+        T z = x * x * x;
+
+        while (t > T(1.11022302462515654042e-16)) {
+            m *= z;
+            k += T(1.0);
+            m /= k;
+            n *= z;
+            k += T(1.0);
+            n /= k;
+            m /= k;
+            f += m;
+            k += T(1.0);
+            n /= k;
+            g += n;
+
+            t = abs(m / f);
+        }
+
+        if ((domain_flag & 1) == 0) {
+            return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g;
+        }
+
+        return ai;
+    } // T airy_ai(T x)
+); // airy_ai_string
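+
+// Editorial note (not part of the kernel string): in the Maclaurin branch above the
+// result is assembled as c1 * f - c2 * g, where f and g are the two standard Airy power
+// series and the literals are
+//   c1 =  Ai(0)  = 1 / (3^(2/3) * Gamma(2/3)) ~= 0.355028053887817239
+//   c2 = -Ai'(0) = 1 / (3^(1/3) * Gamma(1/3)) ~= 0.258819403792806798.
+// A host-side check of those constants (hypothetical helper, assumes <cmath>):
+static void airy_ai_constants_check() {
+  const double c1 = 1.0 / (std::pow(3.0, 2.0 / 3.0) * std::tgamma(2.0 / 3.0));
+  const double c2 = 1.0 / (std::pow(3.0, 1.0 / 3.0) * std::tgamma(1.0 / 3.0));
+  (void)c1;  // ~0.3550280539, the coefficient of f in the return statement above
+  (void)c2;  // ~0.2588194038, the coefficient of g in the return statement above
+}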
+
+const auto bessel_j0_string = jiterator_stringify(
+    template<typename T>
+    T bessel_j0_forward(T x) {
+        static const T PP[] = {
+                +7.96936729297347051624e-04,
+                +8.28352392107440799803e-02,
+                +1.23953371646414299388e+00,
+                +5.44725003058768775090e+00,
+                +8.74716500199817011941e+00,
+                +5.30324038235394892183e+00,
+                +9.99999999999999997821e-01,
+        };
+
+        static const T PQ[] = {
+                +9.24408810558863637013e-04,
+                +8.56288474354474431428e-02,
+                +1.25352743901058953537e+00,
+                +5.47097740330417105182e+00,
+                +8.76190883237069594232e+00,
+                +5.30605288235394617618e+00,
+                +1.00000000000000000218e+00,
+        };
+
+        static const T QP[] = {
+                -1.13663838898469149931e-02,
+                -1.28252718670509318512e+00,
+                -1.95539544257735972385e+01,
+                -9.32060152123768231369e+01,
+                -1.77681167980488050595e+02,
+                -1.47077505154951170175e+02,
+                -5.14105326766599330220e+01,
+                -6.05014350600728481186e+00,
+        };
+
+        static const T QQ[] = {
+                +6.43178256118178023184e+01,
+                +8.56430025976980587198e+02,
+                +3.88240183605401609683e+03,
+                +7.24046774195652478189e+03,
+                +5.93072701187316984827e+03,
+                +2.06209331660327847417e+03,
+                +2.42005740240291393179e+02,
+        };
+
+        static const T RP[] = {
+                -4.79443220978201773821e+09,
+                +1.95617491946556577543e+12,
+                -2.49248344360967716204e+14,
+                +9.70862251047306323952e+15,
+        };
+
+        static const T RQ[] = {
+                +4.99563147152651017219e+02,
+                +1.73785401676374683123e+05,
+                +4.84409658339962045305e+07,
+                +1.11855537045356834862e+10,
+                +2.11277520115489217587e+12,
+                +3.10518229857422583814e+14,
+                +3.18121955943204943306e+16,
+                +1.71086294081043136091e+18,
+        };
+
+        if (x < T(0)) {
+            x = -x;
+        }
+
+        if (x <= T(5.0)) {
+            if (x < T(0.00001)) {
+                return T(1.0) - x * x / T(4.0);
+            }
+
+            T rp = 0.0;
+
+            for (uint8_t index = 0; index <= 3; index++) {
+                rp = rp * (x * x) + RP[index];
+            }
+
+            T rq = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                rq = rq * (x * x) + RQ[index];
+            }
+
+            return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
+        }
+
+        T pp = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pp = pp * (T(25.0) / (x * x)) + PP[index];
+        }
+
+        T pq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pq = pq * (T(25.0) / (x * x)) + PQ[index];
+        }
+
+        T qp = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            qp = qp * (T(25.0) / (x * x)) + QP[index];
+        }
+
+        T qq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            qq = qq * (T(25.0) / (x * x)) + QQ[index];
+        }
+
+        return (pp / pq * cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
+    } // bessel_j0_forward(T x)
+); // bessel_j0_string
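+
+// Editorial note (not part of the kernel string): the large-argument branch above is the
+// classic Hankel asymptotic form
+//   J0(x) ~= sqrt(2/(pi*x)) * [ (pp/pq) * cos(x - pi/4) - (5/x) * (qp/qq) * sin(x - pi/4) ],
+// where pp/pq and qp/qq are rational approximations evaluated at 25/x^2. That is why the
+// literals 0.797884560802865... (= sqrt(2/pi)) and 0.785398163397448... (= pi/4) appear in
+// the return statement; bessel_y0_forward below reuses the same polynomials with the sin
+// and cos terms swapped. A host-side check of the two constants (hypothetical helper,
+// assumes <cmath>):
+static void bessel_j0_asymptotic_constants() {
+  const double pi = std::acos(-1.0);
+  const double sqrt_two_over_pi = std::sqrt(2.0 / pi);  // ~0.79788456080286536
+  const double quarter_pi = pi / 4.0;                    // ~0.78539816339744831
+  (void)sqrt_two_over_pi;
+  (void)quarter_pi;
+}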
+
+const auto bessel_y0_string = bessel_j0_string + jiterator_stringify(
+    template<typename T>
+    T bessel_y0_forward(T x) {
+        static const T PP[] = {
+                +7.96936729297347051624e-04,
+                +8.28352392107440799803e-02,
+                +1.23953371646414299388e+00,
+                +5.44725003058768775090e+00,
+                +8.74716500199817011941e+00,
+                +5.30324038235394892183e+00,
+                +9.99999999999999997821e-01,
+        };
+
+        static const T PQ[] = {
+                +9.24408810558863637013e-04,
+                +8.56288474354474431428e-02,
+                +1.25352743901058953537e+00,
+                +5.47097740330417105182e+00,
+                +8.76190883237069594232e+00,
+                +5.30605288235394617618e+00,
+                +1.00000000000000000218e+00,
+        };
+
+        static const T QP[] = {
+                -1.13663838898469149931e-02,
+                -1.28252718670509318512e+00,
+                -1.95539544257735972385e+01,
+                -9.32060152123768231369e+01,
+                -1.77681167980488050595e+02,
+                -1.47077505154951170175e+02,
+                -5.14105326766599330220e+01,
+                -6.05014350600728481186e+00,
+        };
+
+        static const T QQ[] = {
+                +6.43178256118178023184e+01,
+                +8.56430025976980587198e+02,
+                +3.88240183605401609683e+03,
+                +7.24046774195652478189e+03,
+                +5.93072701187316984827e+03,
+                +2.06209331660327847417e+03,
+                +2.42005740240291393179e+02,
+        };
+
+        static const T YP[] = {
+                +1.55924367855235737965e+04,
+                -1.46639295903971606143e+07,
+                +5.43526477051876500413e+09,
+                -9.82136065717911466409e+11,
+                +8.75906394395366999549e+13,
+                -3.46628303384729719441e+15,
+                +4.42733268572569800351e+16,
+                -1.84950800436986690637e+16,
+        };
+
+        static const T YQ[] = {
+                +1.04128353664259848412e+03,
+                +6.26107330137134956842e+05,
+                +2.68919633393814121987e+08,
+                +8.64002487103935000337e+10,
+                +2.02979612750105546709e+13,
+                +3.17157752842975028269e+15,
+                +2.50596256172653059228e+17,
+        };
+
+        if (x <= T(5.0)) {
+            if (x == T(0.0)) {
+                return NEG_INFINITY;
+            }
+
+            if (x < T(0.0)) {
+                return NAN;
+            }
+
+            T yp = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                yp = yp * (x * x) + YP[index];
+            }
+
+            T yq = 0.0;
+
+            for (uint8_t index = 0; index <= 6; index++) {
+                yq = yq * (x * x) + YQ[index];
+            }
+
+            return yp / yq + (T(0.636619772367581343075535053490057448) * log(x) * bessel_j0_forward(x));
+        }
+
+        T pp = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pp = pp * (T(25.0) / (x * x)) + PP[index];
+        }
+
+        T pq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pq = pq * (T(25.0) / (x * x)) + PQ[index];
+        }
+
+        T qp = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            qp = qp * (T(25.0) / (x * x)) + QP[index];
+        }
+
+        T qq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            qq = qq * (T(25.0) / (x * x)) + QQ[index];
+        }
+
+        return (pp / pq * sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / sqrt(x);
+    } // bessel_y0_forward(T x)
+); // bessel_y0_string
+
+const auto bessel_j1_string = jiterator_stringify(
+    template<typename T>
+    T bessel_j1_forward(T x) {
+        static const T PP[] = {
+                +7.62125616208173112003e-04,
+                +7.31397056940917570436e-02,
+                +1.12719608129684925192e+00,
+                +5.11207951146807644818e+00,
+                +8.42404590141772420927e+00,
+                +5.21451598682361504063e+00,
+                +1.00000000000000000254e+00,
+        };
+
+        static const T PQ[] = {
+                +5.71323128072548699714e-04,
+                +6.88455908754495404082e-02,
+                +1.10514232634061696926e+00,
+                +5.07386386128601488557e+00,
+                +8.39985554327604159757e+00,
+                +5.20982848682361821619e+00,
+                +9.99999999999999997461e-01,
+        };
+
+        static const T QP[] = {
+                +5.10862594750176621635e-02,
+                +4.98213872951233449420e+00,
+                +7.58238284132545283818e+01,
+                +3.66779609360150777800e+02,
+                +7.10856304998926107277e+02,
+                +5.97489612400613639965e+02,
+                +2.11688757100572135698e+02,
+                +2.52070205858023719784e+01,
+        };
+
+        static const T QQ[] = {
+                +7.42373277035675149943e+01,
+                +1.05644886038262816351e+03,
+                +4.98641058337653607651e+03,
+                +9.56231892404756170795e+03,
+                +7.99704160447350683650e+03,
+                +2.82619278517639096600e+03,
+                +3.36093607810698293419e+02,
+        };
+
+        static const T RP[] = {
+                -8.99971225705559398224e+08,
+                +4.52228297998194034323e+11,
+                -7.27494245221818276015e+13,
+                +3.68295732863852883286e+15,
+        };
+
+        static const T RQ[] = {
+                +6.20836478118054335476e+02,
+                +2.56987256757748830383e+05,
+                +8.35146791431949253037e+07,
+                +2.21511595479792499675e+10,
+                +4.74914122079991414898e+12,
+                +7.84369607876235854894e+14,
+                +8.95222336184627338078e+16,
+                +5.32278620332680085395e+18,
+        };
+
+        if (x < T(0.0)) {
+            return -bessel_j1_forward(-x);
+        }
+
+        if (x <= T(5.0)) {
+            T rp = 0.0;
+
+            for (uint8_t index = 0; index <= 3; index++) {
+                rp = rp * (x * x) + RP[index];
+            }
+
+            T rq = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                rq = rq * (x * x) + RQ[index];
+            }
+
+            return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
+        }
+
+        T pp = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
+        }
+
+        T pq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
+        }
+
+        T qp = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
+        }
+
+        T qq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
+        }
+
+        return (pp / pq * cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
+    } // bessel_j1_forward(T x)
+); // bessel_j1_string
+
+const auto bessel_y1_string = bessel_j1_string + jiterator_stringify(
+    template<typename T>
+    T bessel_y1_forward(T x) {
+        static const T PP[] = {
+                +7.62125616208173112003e-04,
+                +7.31397056940917570436e-02,
+                +1.12719608129684925192e+00,
+                +5.11207951146807644818e+00,
+                +8.42404590141772420927e+00,
+                +5.21451598682361504063e+00,
+                +1.00000000000000000254e+00,
+        };
+
+        static const T PQ[] = {
+                +5.71323128072548699714e-04,
+                +6.88455908754495404082e-02,
+                +1.10514232634061696926e+00,
+                +5.07386386128601488557e+00,
+                +8.39985554327604159757e+00,
+                +5.20982848682361821619e+00,
+                +9.99999999999999997461e-01,
+        };
+
+        static const T QP[] = {
+                +5.10862594750176621635e-02,
+                +4.98213872951233449420e+00,
+                +7.58238284132545283818e+01,
+                +3.66779609360150777800e+02,
+                +7.10856304998926107277e+02,
+                +5.97489612400613639965e+02,
+                +2.11688757100572135698e+02,
+                +2.52070205858023719784e+01,
+        };
+
+        static const T QQ[] = {
+                +7.42373277035675149943e+01,
+                +1.05644886038262816351e+03,
+                +4.98641058337653607651e+03,
+                +9.56231892404756170795e+03,
+                +7.99704160447350683650e+03,
+                +2.82619278517639096600e+03,
+                +3.36093607810698293419e+02,
+        };
+
+        static const T YP[] = {
+                +1.26320474790178026440e+09,
+                -6.47355876379160291031e+11,
+                +1.14509511541823727583e+14,
+                -8.12770255501325109621e+15,
+                +2.02439475713594898196e+17,
+                -7.78877196265950026825e+17,
+        };
+
+        static const T YQ[] = {
+                +5.94301592346128195359e+02,
+                +2.35564092943068577943e+05,
+                +7.34811944459721705660e+07,
+                +1.87601316108706159478e+10,
+                +3.88231277496238566008e+12,
+                +6.20557727146953693363e+14,
+                +6.87141087355300489866e+16,
+                +3.97270608116560655612e+18,
+        };
+
+        if (x <= T(5.0)) {
+            if (x == T(0.0)) {
+                return NEG_INFINITY;
+            }
+
+            if (x <= T(0.0)) {
+                return NAN;
+            }
+
+            T yp = 0.0;
+
+            for (uint8_t index = 0; index <= 5; index++) {
+                yp = yp * (x * x) + YP[index];
+            }
+
+            T yq = 0.0;
+
+            for (uint8_t index = 0; index <= 7; index++) {
+                yq = yq * (x * x) + YQ[index];
+            }
+
+            return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * log(x) - T(1.0) / x));
+        }
+
+        T pp = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
+        }
+
+        T pq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
+        }
+
+        T qp = 0.0;
+
+        for (uint8_t index = 0; index <= 7; index++) {
+            qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
+        }
+
+        T qq = 0.0;
+
+        for (uint8_t index = 0; index <= 6; index++) {
+            qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
+        }
+
+        return (pp / pq * sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / sqrt(x);
+    } // bessel_y1_forward(T x)
+); // bessel_y1_string
+
+const auto chebyshev_polynomial_t_string = jiterator_stringify(
+    template<typename T>
+    T chebyshev_polynomial_t_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(1.0)) {
+            if (x > T(0.0) || n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if ((n > 6) && (abs(x) < T(1.0))) {
+            return cos(n * acos(x));
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x;
+        }
+
+        T p = T(1.0);
+        T q = x;
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // chebyshev_polynomial_t_forward(T x, int64_t n)
+
+    template<typename T>
+    T chebyshev_polynomial_t_forward(T x, T n) {
+        return chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
+    } // chebyshev_polynomial_t_forward(T x, T n)
+); // chebyshev_polynomial_t_string
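+
+// Illustrative host-side counterpart (editorial addition, not part of the kernel string):
+// the loop above is the standard three-term recurrence T_{k+1}(x) = 2*x*T_k(x) - T_{k-1}(x),
+// and for |x| < 1 it agrees with the closed form T_n(x) = cos(n * acos(x)) used on the
+// n > 6 fast path. The U/V/W variants that follow reuse the same recurrence with different
+// seed values q (and different closed forms on their fast paths). The helper name is
+// hypothetical; assumes <cmath> and <cstdint> are available.
+static double chebyshev_t_host(double x, int64_t n) {
+  if (n < 0) {
+    return 0.0;  // mirrors the n < 0 convention of the device helper above
+  }
+  if (n == 0) {
+    return 1.0;
+  }
+  if (n == 1) {
+    return x;
+  }
+  double p = 1.0;  // T_{k-2}
+  double q = x;    // T_{k-1}
+  double r = 0.0;
+  for (int64_t k = 2; k <= n; ++k) {
+    r = 2.0 * x * q - p;  // T_k = 2*x*T_{k-1} - T_{k-2}
+    p = q;
+    q = r;
+  }
+  return r;  // e.g. chebyshev_t_host(0.5, 3) == cos(3 * acos(0.5)) == -1.0
+}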
+
+const auto chebyshev_polynomial_u_string = jiterator_stringify(
+    template<typename T>
+    T chebyshev_polynomial_u_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(1.0)) {
+            if (x > T(0.0) || n % 2 == 0) {
+                return n + 1;
+            }
+
+            return -(n + 1);
+        }
+
+        if ((n > 8) && (abs(x) < T(1.0))) {
+            if (sin(acos(x)) != T(0.0)) {
+                return sin((n + 1) * acos(x)) / sin(acos(x));
+            }
+
+            return (n + 1) * cos((n + 1) * acos(x)) / x;
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x;
+        }
+
+        T p = T(1.0);
+        T q = x + x;
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // chebyshev_polynomial_u_forward(T x, int64_t n)
+
+    template<typename T>
+    T chebyshev_polynomial_u_forward(T x, T n) {
+        return chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
+    } // chebyshev_polynomial_u_forward(T x, T n)
+); // chebyshev_polynomial_u_string
+
+const auto chebyshev_polynomial_v_string = jiterator_stringify(
+    template<typename T>
+    T chebyshev_polynomial_v_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(1.0)) {
+            if (x > T(0.0)) {
+                return T(1.0);
+            }
+
+            if (n % 2 == 0) {
+                return n + n + 1;
+            }
+
+            return -(n + n + 1);
+        }
+
+        if ((n > 8) && (abs(x) < T(1.0))) {
+            if (sin(acos(x) / T(2.0)) != T(1.0)) {
+                return cos((n + T(0.5)) * acos(x)) / cos(acos(x) / T(2.0));
+            }
+
+            if (n % 2 == 0) {
+                return n + n + 1;
+            }
+
+            return -(n + n + 1);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x - T(1.0);
+        }
+
+        T p = T(1.0);
+        T q = x + x - T(1.0);
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // chebyshev_polynomial_v_forward(T x, int64_t n)
+
+    template<typename T>
+    T chebyshev_polynomial_v_forward(T x, T n) {
+        return chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
+    } // chebyshev_polynomial_v_forward(T x, T n)
+); // chebyshev_polynomial_v_string
+
+const auto chebyshev_polynomial_w_string = jiterator_stringify(
+    template<typename T>
+    T chebyshev_polynomial_w_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(1.0)) {
+            if (x > T(0.0)) {
+                return n + n + 1;
+            }
+
+            if (n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if ((n > 8) && (abs(x) < T(1.0))) {
+            if (cos(acos(x) / T(2.0)) != T(1.0)) {
+                return sin((n + T(0.5)) * acos(x)) / sin(acos(x) / T(2.0));
+            }
+
+            if (x > T(0.0)) {
+                return n + n + 1;
+            }
+
+            if (n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x + T(1.0);
+        }
+
+        T p = T(1.0);
+        T q = x + x + T(1.0);
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // chebyshev_polynomial_w_forward(T x, int64_t n)
+
+    template<typename T>
+    T chebyshev_polynomial_w_forward(T x, T n) {
+        return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
+    } // chebyshev_polynomial_w_forward(T x, T n)
+); // chebyshev_polynomial_w_string
+
+const auto hermite_polynomial_h_string = jiterator_stringify(
+    template<typename T>
+    T hermite_polynomial_h_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x;
+        }
+
+        T p = T(1.0);
+        T q = x + x;
+        T r = T(0.0);
+
+        for (int64_t k = 2; k < n + n; k += 2) {
+            r = (x + x) * q - k * p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // hermite_polynomial_h_forward(T x, int64_t n)
+
+    template<typename T>
+    T hermite_polynomial_h_forward(T x, T n) {
+        return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
+    } // hermite_polynomial_h_forward(T x, T n)
+); // hermite_polynomial_h_string
+
+const auto hermite_polynomial_he_string = jiterator_stringify(
+    template<typename T>
+    T hermite_polynomial_he_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x;
+        }
+
+        T p = T(1.0);
+        T q = x;
+        T r;
+
+        for (int64_t k = 1; k < n; k++) {
+            r = x * q - k * p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // hermite_polynomial_he_forward(T x, int64_t n)
+
+    template<typename T>
+    T hermite_polynomial_he_forward(T x, T n) {
+        return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
+    } // hermite_polynomial_he_forward(T x, T n)
+); // hermite_polynomial_he_string
+
+const auto laguerre_polynomial_l_string = jiterator_stringify(
+    template<typename T>
+    T laguerre_polynomial_l_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(0.0)) {
+            return T(1.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return T(1.0) - x;
+        }
+
+        T p = T(1.0);
+        T q = T(1.0) - x;
+        T r;
+
+        for (int64_t k = 1; k < n; k++) {
+            r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1);
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // laguerre_polynomial_l_forward(T x, int64_t n)
+
+    template<typename T>
+    T laguerre_polynomial_l_forward(T x, T n) {
+        return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
+    } // laguerre_polynomial_l_forward(T x, T n)
+); // laguerre_polynomial_l_string
+
+const auto legendre_polynomial_p_string = jiterator_stringify(
+    template<typename T>
+    T legendre_polynomial_p_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (abs(x) == T(1.0)) {
+            if (x > T(0.0) || n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x;
+        }
+
+        T p = T(1.0);
+        T q = x;
+        T r;
+
+        for (int64_t k = 1; k < n; k++) {
+            r = ((k + k + 1) * x * q - k * p) / (k + 1);
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // legendre_polynomial_p_forward(T x, int64_t n)
+
+    template<typename T>
+    T legendre_polynomial_p_forward(T x, T n) {
+        return legendre_polynomial_p_forward(x, static_cast<int64_t>(n));
+    } // legendre_polynomial_p_forward(T x, T n)
+); // legendre_polynomial_p_string
+
+const auto modified_bessel_i0_string = jiterator_stringify(
+    template<typename T>
+    T modified_bessel_i0_forward(T x) {
+        static const T A[] = {
+                -4.41534164647933937950e-18,
+                +3.33079451882223809783e-17,
+                -2.43127984654795469359e-16,
+                +1.71539128555513303061e-15,
+                -1.16853328779934516808e-14,
+                +7.67618549860493561688e-14,
+                -4.85644678311192946090e-13,
+                +2.95505266312963983461e-12,
+                -1.72682629144155570723e-11,
+                +9.67580903537323691224e-11,
+                -5.18979560163526290666e-10,
+                +2.65982372468238665035e-09,
+                -1.30002500998624804212e-08,
+                +6.04699502254191894932e-08,
+                -2.67079385394061173391e-07,
+                +1.11738753912010371815e-06,
+                -4.41673835845875056359e-06,
+                +1.64484480707288970893e-05,
+                -5.75419501008210370398e-05,
+                +1.88502885095841655729e-04,
+                -5.76375574538582365885e-04,
+                +1.63947561694133579842e-03,
+                -4.32430999505057594430e-03,
+                +1.05464603945949983183e-02,
+                -2.37374148058994688156e-02,
+                +4.93052842396707084878e-02,
+                -9.49010970480476444210e-02,
+                +1.71620901522208775349e-01,
+                -3.04682672343198398683e-01,
+                +6.76795274409476084995e-01,
+        };
+
+        static const T B[] = {
+                -7.23318048787475395456e-18,
+                -4.83050448594418207126e-18,
+                +4.46562142029675999901e-17,
+                +3.46122286769746109310e-17,
+                -2.82762398051658348494e-16,
+                -3.42548561967721913462e-16,
+                +1.77256013305652638360e-15,
+                +3.81168066935262242075e-15,
+                -9.55484669882830764870e-15,
+                -4.15056934728722208663e-14,
+                +1.54008621752140982691e-14,
+                +3.85277838274214270114e-13,
+                +7.18012445138366623367e-13,
+                -1.79417853150680611778e-12,
+                -1.32158118404477131188e-11,
+                -3.14991652796324136454e-11,
+                +1.18891471078464383424e-11,
+                +4.94060238822496958910e-10,
+                +3.39623202570838634515e-09,
+                +2.26666899049817806459e-08,
+                +2.04891858946906374183e-07,
+                +2.89137052083475648297e-06,
+                +6.88975834691682398426e-05,
+                +3.36911647825569408990e-03,
+                +8.04490411014108831608e-01,
+        };
+
+        T p;
+        T q = 0.0;
+
+        if (abs(x) <= T(8.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 30; index++) {
+                p = q;
+                q = a;
+                a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
+            }
+
+            return exp(abs(x)) * (T(0.5) * (a - p));
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
+        }
+
+        return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
+    } // modified_bessel_i0_forward(T x)
+); // modified_bessel_i0_string
+
+const auto modified_bessel_i1_string = jiterator_stringify(
+    template<typename T>
+    T modified_bessel_i1_forward(T x) {
+        static const T A[] = {
+                +2.77791411276104639959e-18,
+                -2.11142121435816608115e-17,
+                +1.55363195773620046921e-16,
+                -1.10559694773538630805e-15,
+                +7.60068429473540693410e-15,
+                -5.04218550472791168711e-14,
+                +3.22379336594557470981e-13,
+                -1.98397439776494371520e-12,
+                +1.17361862988909016308e-11,
+                -6.66348972350202774223e-11,
+                +3.62559028155211703701e-10,
+                -1.88724975172282928790e-09,
+                +9.38153738649577178388e-09,
+                -4.44505912879632808065e-08,
+                +2.00329475355213526229e-07,
+                -8.56872026469545474066e-07,
+                +3.47025130813767847674e-06,
+                -1.32731636560394358279e-05,
+                +4.78156510755005422638e-05,
+                -1.61760815825896745588e-04,
+                +5.12285956168575772895e-04,
+                -1.51357245063125314899e-03,
+                +4.15642294431288815669e-03,
+                -1.05640848946261981558e-02,
+                +2.47264490306265168283e-02,
+                -5.29459812080949914269e-02,
+                +1.02643658689847095384e-01,
+                -1.76416518357834055153e-01,
+                +2.52587186443633654823e-01,
+        };
+
+        static const T B[] = {
+                +7.51729631084210481353e-18,
+                +4.41434832307170791151e-18,
+                -4.65030536848935832153e-17,
+                -3.20952592199342395980e-17,
+                +2.96262899764595013876e-16,
+                +3.30820231092092828324e-16,
+                -1.88035477551078244854e-15,
+                -3.81440307243700780478e-15,
+                +1.04202769841288027642e-14,
+                +4.27244001671195135429e-14,
+                -2.10154184277266431302e-14,
+                -4.08355111109219731823e-13,
+                -7.19855177624590851209e-13,
+                +2.03562854414708950722e-12,
+                +1.41258074366137813316e-11,
+                +3.25260358301548823856e-11,
+                -1.89749581235054123450e-11,
+                -5.58974346219658380687e-10,
+                -3.83538038596423702205e-09,
+                -2.63146884688951950684e-08,
+                -2.51223623787020892529e-07,
+                -3.88256480887769039346e-06,
+                -1.10588938762623716291e-04,
+                -9.76109749136146840777e-03,
+                +7.78576235018280120474e-01,
+        };
+
+        T p;
+        T q = 0.0;
+
+        if (abs(x) <= T(8.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 29; index++) {
+                p = q;
+                q = a;
+                a = ((abs(x) / T(2.0)) - T(2.0)) * q - p + A[index];
+            }
+
+            if (x < T(0.0)) {
+                return -(T(0.5) * (a - p) * abs(x) * exp(abs(x)));
+            }
+
+            return T(0.5) * (a - p) * abs(x) * exp(abs(x));
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(32.0) / abs(x) - T(2.0)) * q - p + B[index];
+        }
+
+        if (x < T(0.0)) {
+            return -(exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x)));
+        }
+
+        return exp(abs(x)) * (T(0.5) * (b - p)) / sqrt(abs(x));
+    } // modified_bessel_i1_forward(T x)
+); // modified_bessel_i1_string
+
+const auto modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify(
+    template<typename T>
+    T modified_bessel_k0_forward(T x) {
+        static const T A[] = {
+                +1.37446543561352307156e-16,
+                +4.25981614279661018399e-14,
+                +1.03496952576338420167e-11,
+                +1.90451637722020886025e-09,
+                +2.53479107902614945675e-07,
+                +2.28621210311945178607e-05,
+                +1.26461541144692592338e-03,
+                +3.59799365153615016266e-02,
+                +3.44289899924628486886e-01,
+                -5.35327393233902768720e-01,
+        };
+
+        static const T B[] = {
+                +5.30043377268626276149e-18,
+                -1.64758043015242134646e-17,
+                +5.21039150503902756861e-17,
+                -1.67823109680541210385e-16,
+                +5.51205597852431940784e-16,
+                -1.84859337734377901440e-15,
+                +6.34007647740507060557e-15,
+                -2.22751332699166985548e-14,
+                +8.03289077536357521100e-14,
+                -2.98009692317273043925e-13,
+                +1.14034058820847496303e-12,
+                -4.51459788337394416547e-12,
+                +1.85594911495471785253e-11,
+                -7.95748924447710747776e-11,
+                +3.57739728140030116597e-10,
+                -1.69753450938905987466e-09,
+                +8.57403401741422608519e-09,
+                -4.66048989768794782956e-08,
+                +2.76681363944501510342e-07,
+                -1.83175552271911948767e-06,
+                +1.39498137188764993662e-05,
+                -1.28495495816278026384e-04,
+                +1.56988388573005337491e-03,
+                -3.14481013119645005427e-02,
+                +2.44030308206595545468e+00,
+        };
+
+        if (x == T(0.0)) {
+            return INFINITY;
+        }
+
+        if (x < T(0.0)) {
+            return NAN;
+        }
+
+        T p;
+        T q = 0.0;
+
+        if (x <= T(2.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 10; index++) {
+                p = q;
+                q = a;
+                a = (x * x - T(2.0)) * q - p + A[index];
+            }
+
+            return T(0.5) * (a - p) - log(0.5 * x) * modified_bessel_i0_forward(x);
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+        }
+
+        return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
+    } // modified_bessel_k0_forward(T x)
+); // modified_bessel_k0_string
+
+const auto scaled_modified_bessel_k0_string = modified_bessel_i0_string + jiterator_stringify(
+    template<typename T>
+    T scaled_modified_bessel_k0_forward(T x) {
+        static const T A[] = {
+                +1.37446543561352307156e-16,
+                +4.25981614279661018399e-14,
+                +1.03496952576338420167e-11,
+                +1.90451637722020886025e-09,
+                +2.53479107902614945675e-07,
+                +2.28621210311945178607e-05,
+                +1.26461541144692592338e-03,
+                +3.59799365153615016266e-02,
+                +3.44289899924628486886e-01,
+                -5.35327393233902768720e-01,
+        };
+
+        static const T B[] = {
+                +5.30043377268626276149e-18,
+                -1.64758043015242134646e-17,
+                +5.21039150503902756861e-17,
+                -1.67823109680541210385e-16,
+                +5.51205597852431940784e-16,
+                -1.84859337734377901440e-15,
+                +6.34007647740507060557e-15,
+                -2.22751332699166985548e-14,
+                +8.03289077536357521100e-14,
+                -2.98009692317273043925e-13,
+                +1.14034058820847496303e-12,
+                -4.51459788337394416547e-12,
+                +1.85594911495471785253e-11,
+                -7.95748924447710747776e-11,
+                +3.57739728140030116597e-10,
+                -1.69753450938905987466e-09,
+                +8.57403401741422608519e-09,
+                -4.66048989768794782956e-08,
+                +2.76681363944501510342e-07,
+                -1.83175552271911948767e-06,
+                +1.39498137188764993662e-05,
+                -1.28495495816278026384e-04,
+                +1.56988388573005337491e-03,
+                -3.14481013119645005427e-02,
+                +2.44030308206595545468e+00,
+        };
+
+        if (x == T(0.0)) {
+            return INFINITY;
+        }
+
+        if (x < T(0.0)) {
+            return NAN;
+        }
+
+        T p;
+        T q = 0.0;
+
+        if (x <= T(2.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 10; index++) {
+                p = q;
+                q = a;
+                a = (x * x - T(2.0)) * q - p + A[index];
+            }
+
+            return (T(0.5) * (a - p) - log(T(0.5) * x) * modified_bessel_i0_forward(x)) * exp(x);
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+        }
+
+        return T(0.5) * (b - p) / sqrt(x);
+    } // T scaled_modified_bessel_k0_forward(T x)
+); // scaled_modified_bessel_k0_string
+
+const auto modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify(
+    template<typename T>
+    T modified_bessel_k1_forward(T x) {
+        static const T A[] = {
+                -7.02386347938628759343e-18,
+                -2.42744985051936593393e-15,
+                -6.66690169419932900609e-13,
+                -1.41148839263352776110e-10,
+                -2.21338763073472585583e-08,
+                -2.43340614156596823496e-06,
+                -1.73028895751305206302e-04,
+                -6.97572385963986435018e-03,
+                -1.22611180822657148235e-01,
+                -3.53155960776544875667e-01,
+                +1.52530022733894777053e+00,
+        };
+
+        static const T B[] = {
+                -5.75674448366501715755e-18,
+                +1.79405087314755922667e-17,
+                -5.68946255844285935196e-17,
+                +1.83809354436663880070e-16,
+                -6.05704724837331885336e-16,
+                +2.03870316562433424052e-15,
+                -7.01983709041831346144e-15,
+                +2.47715442448130437068e-14,
+                -8.97670518232499435011e-14,
+                +3.34841966607842919884e-13,
+                -1.28917396095102890680e-12,
+                +5.13963967348173025100e-12,
+                -2.12996783842756842877e-11,
+                +9.21831518760500529508e-11,
+                -4.19035475934189648750e-10,
+                +2.01504975519703286596e-09,
+                -1.03457624656780970260e-08,
+                +5.74108412545004946722e-08,
+                -3.50196060308781257119e-07,
+                +2.40648494783721712015e-06,
+                -1.93619797416608296024e-05,
+                +1.95215518471351631108e-04,
+                -2.85781685962277938680e-03,
+                +1.03923736576817238437e-01,
+                +2.72062619048444266945e+00,
+        };
+
+        if (x == T(0.0)) {
+            return INFINITY;
+        }
+
+        if (x < T(0.0)) {
+            return NAN;
+        }
+
+        T p;
+        T q = 0.0;
+
+        if (x <= T(2.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 11; index++) {
+                p = q;
+                q = a;
+                a = (x * x - T(2.0)) * q - p + A[index];
+            }
+
+            return log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x;
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+        }
+
+        return exp(-x) * (T(0.5) * (b - p)) / sqrt(x);
+    } // modified_bessel_k1_forward(T x)
+); // modified_bessel_k1_string
+
+const auto scaled_modified_bessel_k1_string = modified_bessel_i1_string + jiterator_stringify(
+    template<typename T>
+    T scaled_modified_bessel_k1_forward(T x) {
+        static const T A[] = {
+                -7.02386347938628759343e-18,
+                -2.42744985051936593393e-15,
+                -6.66690169419932900609e-13,
+                -1.41148839263352776110e-10,
+                -2.21338763073472585583e-08,
+                -2.43340614156596823496e-06,
+                -1.73028895751305206302e-04,
+                -6.97572385963986435018e-03,
+                -1.22611180822657148235e-01,
+                -3.53155960776544875667e-01,
+                +1.52530022733894777053e+00,
+        };
+
+        static const T B[] = {
+                -5.75674448366501715755e-18,
+                +1.79405087314755922667e-17,
+                -5.68946255844285935196e-17,
+                +1.83809354436663880070e-16,
+                -6.05704724837331885336e-16,
+                +2.03870316562433424052e-15,
+                -7.01983709041831346144e-15,
+                +2.47715442448130437068e-14,
+                -8.97670518232499435011e-14,
+                +3.34841966607842919884e-13,
+                -1.28917396095102890680e-12,
+                +5.13963967348173025100e-12,
+                -2.12996783842756842877e-11,
+                +9.21831518760500529508e-11,
+                -4.19035475934189648750e-10,
+                +2.01504975519703286596e-09,
+                -1.03457624656780970260e-08,
+                +5.74108412545004946722e-08,
+                -3.50196060308781257119e-07,
+                +2.40648494783721712015e-06,
+                -1.93619797416608296024e-05,
+                +1.95215518471351631108e-04,
+                -2.85781685962277938680e-03,
+                +1.03923736576817238437e-01,
+                +2.72062619048444266945e+00,
+        };
+
+        if (x == T(0.0)) {
+            return INFINITY;
+        }
+
+        if (x < T(0.0)) {
+            return NAN;
+        }
+
+        T p;
+        T q = 0.0;
+
+        if (x <= T(2.0)) {
+            T a = A[0];
+
+            for (uint8_t index = 1; index < 11; index++) {
+                p = q;
+                q = a;
+                a = (x * x - T(2.0)) * q - p + A[index];
+            }
+
+            return (log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * exp(x);
+        }
+
+        T b = B[0];
+
+        for (uint8_t index = 1; index < 25; index++) {
+            p = q;
+            q = b;
+            b = (T(8.0) / x - T(2.0)) * q - p + B[index];
+        }
+
+        return (T(0.5) * (b - p) / sqrt(x));
+    } // T scaled_modified_bessel_k1_forward(T x)
+); // scaled_modified_bessel_k1_string
+
+const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify(
+    template<typename T>
+    T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (x == T(1.0)) {
+            return T(1.0);
+        }
+
+        if (x == T(0.0)) {
+            if (n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) {
+            return cos(n * acos(x + x - T(1.0)));
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x - T(1.0);
+        }
+
+        T p = T(1.0);
+        T q = x + x - T(1.0);
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
+
+    template<typename T>
+    T shifted_chebyshev_polynomial_t_forward(T x, T n) {
+        return shifted_chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
+    } // shifted_chebyshev_polynomial_t_forward(T x, T n)
+); // shifted_chebyshev_polynomial_t_string
+
+const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify(
+    template<typename T>
+    T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (x == T(1.0)) {
+            return n + 1;
+        }
+
+        if (x == T(0.0)) {
+            if (n % 2 == 0) {
+                return n + 1;
+            }
+
+            return -(n + 1);
+        }
+
+        if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) {
+            if (sin(acos(x + x - T(1.0))) != T(0.0)) {
+                return sin((n + 1) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0)));
+            }
+
+            return (n + 1) * cos((n + 1) * acos(x + x - T(1.0))) / (x + x - T(1.0));
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x - T(1.0) + (x + x - T(1.0));
+        }
+
+        T p = T(1.0);
+        T q = x + x - T(1.0) + (x + x - T(1.0));
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
+
+    template<typename T>
+    T shifted_chebyshev_polynomial_u_forward(T x, T n) {
+        return shifted_chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
+    } // shifted_chebyshev_polynomial_u_forward(T x, T n)
+); // shifted_chebyshev_polynomial_u_string
+
+const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify(
+    template<typename T>
+    T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (x == T(1.0)) {
+            return T(1.0);
+        }
+
+        if (x == T(0.0)) {
+            if (n % 2 == 0) {
+                return (n + n + 1);
+            }
+
+            return -(n + n + 1);
+        }
+
+        if ((n > 6) && (abs(x + x - T(1.0)) < T(1.0))) {
+            if (sin(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
+                return cos(((n) + T(0.5)) * acos(x + x - T(1.0))) / cos(acos(x + x - T(1.0)) / T(2.0));
+            }
+
+            if (n % 2 == 0) {
+                return n + n + 1;
+            }
+
+            return -(n + n + 1);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
+        }
+
+        T p = T(1.0);
+        T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0);
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
+
+    template<typename T>
+    T shifted_chebyshev_polynomial_v_forward(T x, T n) {
+        return shifted_chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
+    } // shifted_chebyshev_polynomial_v_forward(T x, T n)
+); // shifted_chebyshev_polynomial_v_string
+
+const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify(
+    template<typename T>
+    T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
+        if (n < 0) {
+            return T(0.0);
+        }
+
+        if (x == T(1.0)) {
+            return n + n + 1;
+        }
+
+        if (x == T(0.0)) {
+            if (n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if ((n > 4) && (abs(x + x - T(1.0)) < T(1.0))) {
+            if (cos(acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
+                return sin((n + T(0.5)) * acos(x + x - T(1.0))) / sin(acos(x + x - T(1.0)) / T(2.0));
+            }
+
+            if (n % 2 == 0) {
+                return T(1.0);
+            }
+
+            return T(-1.0);
+        }
+
+        if (n == 0) {
+            return T(1.0);
+        }
+
+        if (n == 1) {
+            return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
+        }
+
+        T p = T(1.0);
+        T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0);
+        T r;
+
+        for (int64_t k = 2; k <= n; k++) {
+            r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p;
+            p = q;
+            q = r;
+        }
+
+        return r;
+    } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
+
+    template<typename T>
+    T shifted_chebyshev_polynomial_w_forward(T x, T n) {
+        return shifted_chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
+    } // shifted_chebyshev_polynomial_w_forward(T x, T n)
+); // shifted_chebyshev_polynomial_w_string
+
+const auto spherical_bessel_j0_string = jiterator_stringify(
+    template<typename T>
+    T spherical_bessel_j0_forward(T x) {
+        if (isinf(x)) {
+            return T(0.0);
+        }
+
+        if (abs(x) < T(0.5)) {
+            return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0)))))));
+        }
+
+        return sin(x) / x;
+    } // T spherical_bessel_j0_forward(T x)
+); // spherical_bessel_j0_string
+
+#else // !AT_USE_JITERATOR() -- kernels must be precompiled
+
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
+  scalar_t a = ::abs(a_in);
+  scalar_t b = ::abs(b_in);
+  while (a != 0) {
+    scalar_t c = a;
+    a = b % a;
+    b = c;
+  }
+  return b;
+}
+
+/*
+ * For licensing information, please refer to the cpu implementation located in "ATen/native/Math.h".
+ */
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
+  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
+  static const double PI_f64 = 3.14159265358979323846;
+  const accscalar_t PSI_10 = 2.25175258906672110764;
+  const accscalar_t A[] = {
+      8.33333333333333333333E-2,
+      -2.10927960927960927961E-2,
+      7.57575757575757575758E-3,
+      -4.16666666666666666667E-3,
+      3.96825396825396825397E-3,
+      -8.33333333333333333333E-3,
+      8.33333333333333333333E-2,
+  };
+
+  accscalar_t x = static_cast<accscalar_t>(in);
+  if (x == 0) {
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(static_cast<scalar_t>(INFINITY), -x);
+  }
+
+  bool x_is_integer = x == ::trunc(x);
+  accscalar_t result = 0;
+  if (x < 0) {
+    if (x_is_integer) {
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return static_cast<scalar_t>(NAN);
+    }
+    // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
+    // accurate than tan(pi * x). While these operations are mathematically equivalent
+    // since both x and r are in radians and tan() has a periodicity of pi, in practice
+    // the computation of pi * x is a source of error (when |x| > 1).
+    double q, r;
+    r = ::modf(static_cast<double>(x), &q);
+    result = static_cast<accscalar_t>(- PI_f64 / ::tan(PI_f64 * r));
+    x = 1 - x;
+  }
+
+  while (x < 10) {
+    result -= 1 / x;
+    x += 1;
+  }
+  if (x == 10) {
+    return static_cast<scalar_t>(result + PSI_10);
+  }
+
+  accscalar_t y = 0;
+  if (x < 1.0e17) {
+    accscalar_t z = 1 / (x * x);
+
+    accscalar_t polevl_result = 0;
+    for (int i = 0; i <= 6; i++) {
+      polevl_result = polevl_result * z + A[i];
+    }
+    y = z * polevl_result;
+  }
+
+  return static_cast<scalar_t>(::log(x) - (static_cast<accscalar_t>(0.5) / x) - y + result);
+}
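+
+// Illustrative sketch (not part of the upstream header): the negative-argument branch above
+// applies the reflection formula digamma(x) = digamma(1 - x) - pi / tan(pi * r), with r the
+// fractional part of x, before the recurrence and asymptotic series are used. As a hedged
+// host-side sanity check (the helper name below is made up): digamma(1) equals minus the
+// Euler-Mascheroni constant, so calc_digamma(1.0) should return roughly -0.5772156649.
+static inline double example_calc_digamma_check() {
+  return calc_digamma(1.0);  // expected: approximately -0.57721566
+}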
+
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
+  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
+  const accscalar_t PI = 3.14159265358979323846;
+  accscalar_t x = static_cast<accscalar_t>(in);
+  accscalar_t sign = +1;
+  accscalar_t result = 0;
+  if (x < 0.5f) {
+    sign = -1;
+    accscalar_t sin_pi_x = ::sin(PI * x);
+    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
+    x = 1 - x;
+  }
+  for (int i = 0; i < 6; ++i) {
+    result += 1 / (x * x);
+    x += 1;
+  }
+  const accscalar_t one = static_cast<accscalar_t>(1);
+  const accscalar_t ixx = 1 / (x*x);
+  result += (1 + 1 / (2*x) + ixx * (one/6 - ixx * (one/30 - ixx * (one/42)))) / x;
+  return static_cast<scalar_t>(sign * result);
+}
+
+/*
+ * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h".
+ */
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t
+chbevl(scalar_t _x, const scalar_t array[], size_t len) {
+  static_assert(!std::is_same<scalar_t, Half>() && !std::is_same<scalar_t, BFloat16>(), "don't instantiate with low precision type");
+
+  scalar_t b0, b1, b2;
+
+  b0 = array[0];
+  b1 = 0;
+
+  for (size_t i = 1; i < len; ++i)  {
+    b2 = b1;
+    b1 = b0;
+    b0 = _x * b1 - b2 + array[i];
+  }
+
+  return (0.5 * (b0 - b2));
+}
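+
+// Illustrative note (not part of the upstream header): this is Clenshaw's recurrence for
+// evaluating a Chebyshev series; b0, b1 and b2 carry the last three partial results and the
+// value of the series is (b0 - b2) / 2, which is how the coefficient tables below are consumed.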
+
+/*
+ * For licensing information and documentation, please refer to the cpu implementation located in "ATen/native/Math.h".
+ */
+template <typename T>
+C10_HOST_DEVICE inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
+  /* Chebyshev coefficients for exp(-x) I0(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I0(x) } = 1.
+   */
+  static const T coefficients[] = {
+      -4.41534164647933937950E-18, 3.33079451882223809783E-17,
+      -2.43127984654795469359E-16, 1.71539128555513303061E-15,
+      -1.16853328779934516808E-14, 7.67618549860493561688E-14,
+      -4.85644678311192946090E-13, 2.95505266312963983461E-12,
+      -1.72682629144155570723E-11, 9.67580903537323691224E-11,
+      -5.18979560163526290666E-10, 2.65982372468238665035E-9,
+      -1.30002500998624804212E-8,  6.04699502254191894932E-8,
+      -2.67079385394061173391E-7,  1.11738753912010371815E-6,
+      -4.41673835845875056359E-6,  1.64484480707288970893E-5,
+      -5.75419501008210370398E-5,  1.88502885095841655729E-4,
+      -5.76375574538582365885E-4,  1.63947561694133579842E-3,
+      -4.32430999505057594430E-3,  1.05464603945949983183E-2,
+      -2.37374148058994688156E-2,  4.93052842396707084878E-2,
+      -9.49010970480476444210E-2,  1.71620901522208775349E-1,
+      -3.04682672343198398683E-1,  6.76795274409476084995E-1};
+
+  return std::make_tuple(coefficients, 30);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
+   */
+  static const T coefficients[] = {
+      -7.23318048787475395456E-18, -4.83050448594418207126E-18,
+      4.46562142029675999901E-17,  3.46122286769746109310E-17,
+      -2.82762398051658348494E-16, -3.42548561967721913462E-16,
+      1.77256013305652638360E-15,  3.81168066935262242075E-15,
+      -9.55484669882830764870E-15, -4.15056934728722208663E-14,
+      1.54008621752140982691E-14,  3.85277838274214270114E-13,
+      7.18012445138366623367E-13,  -1.79417853150680611778E-12,
+      -1.32158118404477131188E-11, -3.14991652796324136454E-11,
+      1.18891471078464383424E-11,  4.94060238822496958910E-10,
+      3.39623202570838634515E-9,   2.26666899049817806459E-8,
+      2.04891858946906374183E-7,   2.89137052083475648297E-6,
+      6.88975834691682398426E-5,   3.36911647825569408990E-3,
+      8.04490411014108831608E-1};
+
+  return std::make_tuple(coefficients, 25);
+}
+
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
+  static_assert(!std::is_same<scalar_t, Half>() && !std::is_same<scalar_t, BFloat16>(), "don't instantiate with low precision type");
+  // Upcast input for numerical accuracy purposes
+  // Needed for accurate results if input is bfloat16 or float16
+  scalar_t x = ::abs(_x);
+
+  if (x <= scalar_t{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i0e_A<scalar_t>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0};
+    return (::exp(x) * chbevl(y, A, len));
+  }
+
+  auto coeff_pair = chebyshev_coefficients_i0e_B<scalar_t>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  return (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x));
+}
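+
+// Illustrative sketch (not part of the upstream header): calc_i0 composes chbevl with the A
+// table for |x| <= 8 and the B table otherwise, rescaling by exp(|x|) (and 1/sqrt(|x|) in the
+// large-|x| branch). A hedged host-side check (the helper name below is made up): I0(0) = 1,
+// so calc_i0(0.0) should return approximately 1.0.
+static inline double example_calc_i0_check() {
+  return calc_i0(0.0);  // expected: approximately 1.0
+}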
+
+template <typename T>
+C10_HOST_DEVICE inline
+    typename std::enable_if<std::is_same<T, double>::value, std::tuple<const T*, size_t>>::type
+    chebyshev_coefficients_i1e_A() {
+  /* Chebyshev coefficients for exp(-x) I1(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
+   */
+  static const T coefficients[] = {
+      2.77791411276104639959E-18, -2.11142121435816608115E-17,
+      1.55363195773620046921E-16, -1.10559694773538630805E-15,
+      7.60068429473540693410E-15, -5.04218550472791168711E-14,
+      3.22379336594557470981E-13, -1.98397439776494371520E-12,
+      1.17361862988909016308E-11, -6.66348972350202774223E-11,
+      3.62559028155211703701E-10, -1.88724975172282928790E-9,
+      9.38153738649577178388E-9,  -4.44505912879632808065E-8,
+      2.00329475355213526229E-7,  -8.56872026469545474066E-7,
+      3.47025130813767847674E-6,  -1.32731636560394358279E-5,
+      4.78156510755005422638E-5,  -1.61760815825896745588E-4,
+      5.12285956168575772895E-4,  -1.51357245063125314899E-3,
+      4.15642294431288815669E-3,  -1.05640848946261981558E-2,
+      2.47264490306265168283E-2,  -5.29459812080949914269E-2,
+      1.02643658689847095384E-1,  -1.76416518357834055153E-1,
+      2.52587186443633654823E-1};
+
+  return std::make_tuple(coefficients, 29);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline
+    typename std::enable_if<std::is_same<T, float>::value, std::tuple<const T*, size_t>>::type
+    chebyshev_coefficients_i1e_A() {
+  /* Chebyshev coefficients for exp(-x) I1(x)
+   * in the interval [0,8].
+   *
+   * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
+   */
+  static const T coeff[] = {
+      9.38153738649577178388E-9f,
+      -4.44505912879632808065E-8f,
+      2.00329475355213526229E-7f,
+      -8.56872026469545474066E-7f,
+      3.47025130813767847674E-6f,
+      -1.32731636560394358279E-5f,
+      4.78156510755005422638E-5f,
+      -1.61760815825896745588E-4f,
+      5.12285956168575772895E-4f,
+      -1.51357245063125314899E-3f,
+      4.15642294431288815669E-3f,
+      -1.05640848946261981558E-2f,
+      2.47264490306265168283E-2f,
+      -5.29459812080949914269E-2f,
+      1.02643658689847095384E-1f,
+      -1.76416518357834055153E-1f,
+      2.52587186443633654823E-1f};
+  return std::make_tuple(coeff, 17);
+};
+
+template <typename T>
+C10_HOST_DEVICE inline
+    typename std::enable_if<std::is_same<T, double>::value, std::tuple<const T*, size_t>>::type
+    chebyshev_coefficients_i1e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
+   */
+  static const T coefficients[] = {
+      7.51729631084210481353E-18,  4.41434832307170791151E-18,
+      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
+      2.96262899764595013876E-16,  3.30820231092092828324E-16,
+      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
+      1.04202769841288027642E-14,  4.27244001671195135429E-14,
+      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
+      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
+      1.41258074366137813316E-11,  3.25260358301548823856E-11,
+      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
+      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
+      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
+      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
+      7.78576235018280120474E-1};
+
+  return std::make_tuple(coefficients, 25);
+}
+
+template <typename T>
+C10_HOST_DEVICE inline
+    typename std::enable_if<std::is_same<T, float>::value, std::tuple<const T*, size_t>>::type
+    chebyshev_coefficients_i1e_B() {
+  /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
+   * in the inverted interval [8,infinity].
+   *
+   * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
+   */
+  static const T coeff[] = {
+      -3.83538038596423702205E-9f,
+      -2.63146884688951950684E-8f,
+      -2.51223623787020892529E-7f,
+      -3.88256480887769039346E-6f,
+      -1.10588938762623716291E-4f,
+      -9.76109749136146840777E-3f,
+      7.78576235018280120474E-1f};
+
+  return std::make_tuple(coeff, 7);
+};
+
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_i1(scalar_t _x) {
+  const auto x = ::abs(_x);
+  if (x <= scalar_t{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
+    const scalar_t out = ::exp(x) * x * chbevl(y, A, len);
+    return (_x < scalar_t{0.0}) ? -out : out;
+  }
+
+  auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  const scalar_t out = (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len)) / ::sqrt(x);
+  return (_x < scalar_t{0.0}) ? -out : out;
+}
+
+template <typename scalar_t>
+static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) {
+  const auto x = ::abs(_x);
+  if (x <= scalar_t{8.0}) {
+    auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
+    auto A = std::get<0>(coeff_pair);
+    auto len = std::get<1>(coeff_pair);
+    const scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
+    const scalar_t out = chbevl(y, A, len) * x;
+    return (_x < scalar_t{0.0}) ? -out : out;
+  }
+
+  auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
+  auto B = std::get<0>(coeff_pair);
+  auto len = std::get<1>(coeff_pair);
+  const scalar_t out = chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x);
+  return (_x < scalar_t{0.0}) ? -out : out;
+}
+
+#endif // AT_USE_JITERATOR() (this closes the "else" branch of an if/else preprocessor directive)
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..fabc47acb137252f5b138c59fafab23d874c2c8a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh
@@ -0,0 +1,384 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+// References:
+// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
+
+namespace at { namespace native { namespace memory {
+
+namespace detail {
+
+// What does the `static_unroll` do?
+//
+// We want to do something like:
+//
+//    using args_t = typename traits::ArgsTuple;
+//    args_t args;
+//    #pragma unroll
+//    for (int i = 0; i < traits::arity; i++) {
+//      std::get(args) = ....
+//    }
+//
+// but unfortunately the above code does not work because
+// the template argument has to be a compile time constant
+// so `static_unroll` is created to simulate `#pragma unroll`
+// using template metaprogramming.
+
+template<template<int i> typename func, int end, int current=0>
+struct static_unroll {
+  template<typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
+    func<current>::apply(std::forward<Args>(args)...);
+    static_unroll<func, end, current+1>::with_args(args...);
+  }
+};
+
+template<template<int i> typename func, int end>
+struct static_unroll<func, end, end> {
+  template<typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args... args) {}
+};
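+
+// Illustrative sketch (not part of the upstream header): a hypothetical functor showing how
+// static_unroll drives a compile-time loop over tuple indices; the load/store helpers below
+// are invoked in exactly this way.
+template<int i>
+struct example_zero_tuple_element {
+  template <typename tuple_t>
+  static inline C10_HOST_DEVICE void apply(tuple_t &args) {
+    std::get<i>(args) = {};  // i is a compile-time constant on every "iteration"
+  }
+};
+// Usage, for some std::tuple type args_t:
+//   args_t args;
+//   static_unroll<example_zero_tuple_element, std::tuple_size<args_t>::value>::with_args(args);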
+
+// helper structs to be used with static_unroll to load arguments
+// one by one
+
+template<int arg_index>
+struct vectorized_load_helper {
+  template <typename args_t, typename policy_t>
+  static __device__ void apply(policy_t &self, args_t *args, int idx) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
+    // need a +1 offset to get the input
+    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
+    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
+    self.load_single_arg(args_accessor, ptr);
+  }
+};
+
+template<int arg_index>
+struct unroll_load_helper {
+  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
+  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
+    // need a +1 offset to get the input
+    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
+  }
+};
+
+template <int current>
+struct multi_outputs_store_helper {
+  template<int ntensors, int num_outputs, typename ...Args>
+  C10_HOST_DEVICE static void apply(
+      at::detail::Array<char*, ntensors> data,
+      at::detail::Array<uint32_t, num_outputs> offsets,
+      thrust::tuple<Args...> ret) {
+    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
+    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
+    *to = thrust::get<current>(ret);
+  }
+};
+
+}  // namespace detail
+
+struct LoadWithoutCast {
+  template<typename scalar_t>
+  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
+    return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
+  }
+};
+
+template <int N>
+struct LoadWithCast {
+  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
+  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+
+  array_t dtypes;
+  size_array_t element_sizes;
+
+  LoadWithCast(const TensorIteratorBase& iter) {
+    CUDA_KERNEL_ASSERT(iter.ninputs() == N);
+    #pragma unroll
+    for (auto i = 0; i < N; ++i) {
+      this->dtypes[i] = iter.dtype(i + iter.noutputs());
+      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
+    }
+  }
+
+  template<typename scalar_t>
+  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
+    void *ptr = base_ptr + element_sizes[arg] * offset;
+    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
+  }
+};
+
+struct StoreWithoutCast {
+  template<typename scalar_t>
+  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
+    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
+  }
+};
+
+template <int N = 1>
+struct StoreWithCast {
+  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
+  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+
+  array_t dtypes;
+  size_array_t element_sizes;
+
+  StoreWithCast(const TensorIteratorBase& iter) {
+    CUDA_KERNEL_ASSERT(iter.noutputs() == N);
+    #pragma unroll
+    for (auto i = 0; i < N; ++i) {
+      this->dtypes[i] = iter.dtype(i);
+      element_sizes[i] = c10::elementSize(iter.dtype(i));
+    }
+  }
+
+  template<typename scalar_t>
+  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
+    void *ptr = base_ptr + element_sizes[arg] * offset;
+    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
+  }
+};
+
+// aligned vector generates vectorized load/store on CUDA
+template<typename scalar_t, int vec_size>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+  scalar_t val[vec_size];
+};
+
+template <int vec_size, typename scalar_t>
+__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
+  using vec_t = aligned_vector<scalar_t, vec_size>;
+  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
+  return from[offset];
+}
+
+template <int vec_size>
+__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
+  // See NOTE [Loading boolean values]
+  auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
+  aligned_vector<bool, vec_size> ret;
+  for (int i = 0; i < vec_size; ++i) {
+    ret.val[i] = bool(tmp.val[i]);
+  }
+  return ret;
+}
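+
+// Illustrative sketch (not part of the upstream header): aligned_vector<scalar_t, 4> is
+// sizeof(scalar_t) * 4 bytes and aligned to that size, so one load_vector<4> call can compile
+// to a single vectorized load when the base pointer is suitably aligned. The kernel below is
+// hypothetical and only demonstrates the intended access pattern (n_vec = number of vectors).
+//
+//   template <typename scalar_t>
+//   __global__ void example_vectorized_copy(const scalar_t *in, scalar_t *out, int n_vec) {
+//     constexpr int vec_size = 4;
+//     using vec_t = aligned_vector<scalar_t, vec_size>;
+//     int i = blockIdx.x * blockDim.x + threadIdx.x;
+//     if (i < n_vec) {
+//       vec_t v = load_vector<vec_size>(in, i);   // one 16-byte load for float
+//       reinterpret_cast<vec_t *>(out)[i] = v;    // one 16-byte store
+//     }
+//   }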
+
+namespace policies {
+
+// Assumption:
+// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
+template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
+struct unroll {
+
+  data_t data;
+  int remaining;
+  inp_calc_t input_offset_calculator;
+  out_calc_t output_offset_calculator;
+  loader_t loader;
+  storer_t storer;
+
+  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
+    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
+
+  __device__ inline bool check_inbounds(int thread_work_elem) {
+    return ((int)(threadIdx.x  + thread_work_elem*num_threads()) < remaining);
+  }
+
+  template<typename args_t>
+  __device__ inline void load(args_t *args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      if (thread_idx >= remaining) {
+        return;
+      }
+      int linear_idx = thread_idx + block_work_size() * idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
+      thread_idx += num_threads();
+    }
+  }
+
+  template<typename scalar_t>
+  __device__ inline void store(scalar_t *from, int idx) {
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      if (thread_idx >= remaining) {
+        return;
+      }
+      int linear_idx = thread_idx + block_work_size() * idx;
+      int offset = output_offset_calculator.get(linear_idx)[0];
+      storer.store(from[i], data[0], offset);
+      thread_idx += num_threads();
+    }
+  }
+};
+
+// Assumption:
+// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
+// Note:
+// Functions in the vectorized policy do not do boundary checks. They assume the whole block
+// has work to do, so the remainder must be handled by the caller manually.
+template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
+struct vectorized {
+
+  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
+  static constexpr int loop_size = thread_work_size() / vec_size;
+
+  data_t data;
+
+  __device__ vectorized(data_t data) : data(data) {}
+
+  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
+    return true;
+  }
+
+  template<typename accessor_t, typename scalar_t>
+  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads();
+      auto v = load_vector<vec_size>(from, index);
+      #pragma unroll
+      for (int j = 0; j < vec_size; j++) {
+        to(vec_size * i + j) = v.val[j];
+      }
+    }
+  }
+
+  template<typename args_t>
+  __device__ inline void load(args_t *args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
+  }
+
+  template<typename scalar_t>
+  __device__ inline void store(scalar_t *from, int idx) {
+    using vec_t = aligned_vector<scalar_t, vec_size>;
+    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
+    vec_t *to_ = reinterpret_cast<vec_t *>(to);
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads();
+      vec_t v;
+      for (int j = 0; j < vec_size; j++) {
+        v.val[j] = from[vec_size * i + j];
+      }
+      to_[index] = v;
+    }
+  }
+};
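+
+// Illustrative sketch (not part of the upstream header): because this policy skips bounds
+// checks, a launcher typically routes only full blocks through `vectorized` and sends the
+// final partial block through the bounds-checked `unroll` policy, along the lines of:
+//
+//   int remaining = numel - block_work_size() * blockIdx.x;
+//   if (remaining < block_work_size()) {
+//     // tail block: use policies::unroll, which tests `remaining` per element
+//   } else {
+//     // full block: safe to use policies::vectorized and its unchecked vector loads/stores
+//   }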
+
+template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
+struct multi_outputs_unroll {
+  // multi_outputs_unroll struct members and the check_inbounds and load methods are copied from the unroll struct;
+  // we don't use inheritance because of a compiler bug in CUDA 10.2+.
+  data_t data;
+  int remaining;
+  inp_calc_t input_offset_calculator;
+  out_calc_t output_offset_calculator;
+  LoadWithoutCast loader;
+  StoreWithoutCast storer;
+
+  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
+  data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
+
+  __device__ inline bool check_inbounds(int thread_work_elem) {
+    return ((int)(threadIdx.x  + thread_work_elem*num_threads()) < remaining);
+  }
+
+  template<typename args_t>
+  __device__ inline void load(args_t *args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      if (thread_idx >= remaining) {
+        return;
+      }
+      int linear_idx = thread_idx + block_work_size() * idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
+      thread_idx += num_threads();
+    }
+  }
+
+
+  template <typename return_t>
+  __device__ inline void store(return_t *from, int idx) {
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size(); i++) {
+      if (thread_idx >= this->remaining) {
+        return;
+      }
+      int linear_idx = thread_idx + block_work_size() * idx;
+      auto offsets = this->output_offset_calculator.get(linear_idx);
+      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
+      thread_idx += num_threads();
+    }
+  }
+};
+
+}  // namespace policies
+
+// This is only used in host, but we will wrap this into some templates
+// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE
+// in order to compile
+template<typename scalar_t>
+inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
+  constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
+  if (address % vec4_alignment == 0) {
+    return 4;
+  } else if (address % vec2_alignment == 0) {
+    return 2;
+  }
+  return 1;
+}
+
+template<int i>
+struct can_vectorize_up_to_helper {
+  template <typename array_t, typename traits>
+  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
+    using arg_t = typename traits::template arg<i>::type;
+    // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we
+    // need a +1 offset to get the input
+    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
+  }
+};
+
+template<typename func_t, typename array_t>
+inline int can_vectorize_up_to(array_t pointers) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  constexpr int arity = traits::arity;
+  int result = can_vectorize_up_to<return_t>(pointers[0]);
+  // We need to get the type for each argument of `func_t`, this can only
+  // be done at compile time.
+  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
+  return result;
+}
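+
+// Illustrative sketch (not part of the upstream header): for a hypothetical unary functor the
+// call below inspects the output pointer and each input pointer and returns the widest vector
+// size (1, 2 or 4) that every pointer's alignment allows.
+//
+//   auto add_one = [] C10_HOST_DEVICE (float a) -> float { return a + 1.0f; };
+//   at::detail::Array<char*, 2> ptrs;                     // [output, input0]
+//   ptrs[0] = reinterpret_cast<char*>(out_ptr);
+//   ptrs[1] = reinterpret_cast<char*>(in_ptr);
+//   int vec_size = can_vectorize_up_to<decltype(add_one)>(ptrs);  // 4 when both are 16-byte aligned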
+
+}}} // namespace at::native::memory
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..936b4d80a179a77afec1c6df4741161b14541934
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MiscUtils.h
@@ -0,0 +1,32 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+static inline int cuda_int_cast(int64_t value, const char* varname) {
+  auto result = static_cast<int>(value);
+  TORCH_CHECK(static_cast<int64_t>(result) == value,
+              "cuda_int_cast: The value of ", varname, "(", (long long)value,
+              ") is too large to fit into a int (", sizeof(int), " bytes)");
+  return result;
+}
+
+// Creates an array of size elements of type T, backed by pinned memory
+// wrapped in a Storage
+template<typename T>
+static inline Storage pin_memory(int64_t size) {
+  auto* allocator = cuda::getPinnedMemoryAllocator();
+  int64_t adjusted_size = size * sizeof(T);
+  return Storage(
+      Storage::use_byte_size_t(),
+      adjusted_size,
+      allocator,
+      /*resizable=*/false);
+}
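+// Usage sketch (illustrative only): `auto buf = pin_memory<int>(n);` allocates
+// n * sizeof(int) bytes of page-locked host memory, which makes subsequent
+// host<->device copies from that buffer eligible for asynchronous transfers.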
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1402f3bd038847e5491686c33669e2ac6ad59cb6
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh
@@ -0,0 +1,379 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at::native {
+
+namespace {
+
+static constexpr int64_t kILP = 4;
+static constexpr int64_t kChunkSize = 65536;
+static constexpr int64_t kBlockSize = 512;
+
+// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
+// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
+static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
+static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
+static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
+    72,
+    60};
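+// Back-of-envelope check of the 4KB limit (illustrative, assuming 8-byte
+// pointers): for depth n = 1, TensorListMetadata holds 110 addresses (880 B)
+// + 110 int64_t numels (880 B) + 320 bytes of block_to_tensor + 320 ints of
+// block_to_chunk (1280 B) + one int, i.e. about 3.3 KB, safely under 4 KB.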
+
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T* p) {
+  return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
+}
+
+template <typename T>
+__device__ __forceinline__ void load_store(
+    T* dst,
+    T* src,
+    int64_t dst_offset,
+    int64_t src_offset) {
+  using LT = at::native::memory::aligned_vector<T, kILP>;
+  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
+}
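+// e.g. with T = float and kILP = 4, one load_store call moves a 16-byte
+// aligned_vector<float, 4>, i.e. four elements per thread in a single
+// vectorized load and store (valid only when is_aligned() holds for both ends).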
+
+template <int n>
+struct TensorListMetadata {
+  const void* addresses[n][depth_to_max_tensors[n - 1]];
+  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
+  int block_to_chunk[depth_to_max_blocks[n - 1]];
+  int start_tensor_this_launch;
+};
+
+template <typename scalar_vals_t, int n>
+struct TensorListScalarListMetadata {
+  const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
+  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
+  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
+  int block_to_chunk[depth_to_max_blocks[n - 1]];
+};
+
+// note(mkozuki): `n` of 1 and 2 would violate the 4KB limit on CUDA kernel
+// argument size with `c10::complex<double>`, hence the smaller tables used below.
+template <>
+struct TensorListScalarListMetadata<c10::complex<double>, 1> {
+  const void* addresses[1]
+                       [depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  int64_t
+      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  c10::complex<double>
+      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
+  int block_to_chunk[depth_to_max_blocks[1 - 1]];
+};
+
+template <>
+struct TensorListScalarListMetadata<c10::complex<double>, 2> {
+  const void* addresses[2]
+                       [depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  int64_t
+      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  c10::complex<double>
+      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
+  int block_to_chunk[depth_to_max_blocks[2 - 1]];
+};
+
+// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`,
+// each element of which is a single-element `at::Tensor` tracking the number
+// of `step`s taken so far.
+template <int n>
+struct FusedOptimizerTensorListMetadata {
+  const void* addresses[n][depth_to_max_tensors[n - 1]];
+  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
+  const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
+  int block_to_chunk[depth_to_max_blocks[n - 1]];
+  int start_tensor_this_launch;
+};
+
+template <typename T, typename U, typename... ArgTypes>
+C10_LAUNCH_BOUNDS_1(kBlockSize)
+__global__ void multi_tensor_apply_kernel(
+    T tensorListMeta,
+    U callable,
+    ArgTypes... args) {
+  // Hand the chunk information to the user-supplied functor to process however
+  // it likes.
+  callable(kChunkSize, tensorListMeta, args...);
+}
+
+} // namespace
+
+// multi_tensor_apply enables horizontal fusion across lists of tensors.
+// For example, whereas you once had a for-loop of a + b = c, where a, b,
+// and c are individual tensors in lists as, bs, and cs, you can now with
+// fewer kernel launches compute as + bs = cs.
+//
+// You can also imagine bs to be a scalar list vs a tensor list.
+//
+// The function below takes in tensor lists, scalars, and a callable and
+// chunks up the computation to launch as few kernels as possible by iterating
+// through every "chunk" in every tensor (thus the nested for loops). In the
+// simplest case, everything gets bundled into just one kernel launch, but
+// due to blocksize constraints, we may need to launch multiple kernels.
+// Each kernel launch is defined by one tensorListMeta construct, which we
+// use to track and reset the necessary metadata for each launch.
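+//
+// Usage sketch (illustrative only; `AddFunctor` is a hypothetical callable
+// exposing the chunk-processing operator() that multi_tensor_apply_kernel
+// invokes):
+//
+//   std::vector<std::vector<at::Tensor>> lists{as, bs, cs};  // depth == 3
+//   multi_tensor_apply<3>(lists, AddFunctor());              // as + bs -> cs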
+template <int depth, typename scalar_T, typename T, typename... ArgTypes>
+void multi_tensor_apply(
+    std::vector<std::vector<at::Tensor>>& tensor_lists,
+    at::ArrayRef<Scalar> scalars,
+    T callable,
+    ArgTypes... args) {
+  TORCH_CHECK(
+      tensor_lists.size() == depth,
+      "Number of tensor lists has to match the depth.");
+  const size_t n_tensors = tensor_lists[0].size();
+  using scalar_vals_t = typename T::opmath_t;
+  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
+
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for (size_t t = 0; t < n_tensors; t++) {
+    // short-circuit to avoid adding empty tensors to tensorListMeta
+    if (tensor_lists[0][t].numel() == 0) {
+      continue;
+    }
+    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_vals_t>();
+    tensorListMeta.numel_for_tensor[loc_tensor_info] =
+        tensor_lists[0][t].numel();
+    for (int d = 0; d < depth; d++) {
+      tensorListMeta.addresses[d][loc_tensor_info] =
+          tensor_lists[d][t].const_data_ptr();
+    }
+    loc_tensor_info++;
+
+    // now we enter [chunking territory].
+    // we will launch a kernel when EITHER the blocks get filled up OR
+    // the tensors get filled up. There will always be at least one block
+    // per tensor since the zero-sized ones will not enter the loop, so
+    // the nested for loop within represents iterating through the chunks
+    // of a single tensor.
+    const auto numel = tensor_lists[0][t].numel();
+    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
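+    // e.g. numel = 100000 with kChunkSize = 65536 gives chunks = 1 + 1 = 2;
+    // this is just ceil(numel / kChunkSize) written without floating point.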
+    for (auto chunk = 0; chunk < chunks; chunk++) {
+      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      // a tensor is not considered full unless all its chunks have been
+      // processed
+      const bool tensors_full =
+          (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
+           chunk == chunks - 1);
+      const bool blocks_full =
+          (loc_block_info == depth_to_max_blocks[depth - 1]);
+
+      if (tensors_full || blocks_full) {
+        multi_tensor_apply_kernel<<<
+            loc_block_info,
+            kBlockSize,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            tensorListMeta, callable, args...);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+        // Reset.
+        loc_block_info = 0;
+        // all chunks have already been handled in the kernel
+        if (chunk == chunks - 1) {
+          loc_tensor_info = 0;
+        } else { // blocks were full and tensor chunks remain
+          tensorListMeta.numel_for_tensor[0] =
+              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
+          tensorListMeta.scalar_vals[0] =
+              tensorListMeta.scalar_vals[loc_tensor_info - 1];
+          for (int d = 0; d < depth; d++) {
+            tensorListMeta.addresses[d][0] =
+                tensorListMeta.addresses[d][loc_tensor_info - 1];
+          }
+          loc_tensor_info = 1;
+        }
+      }
+    }
+  }
+
+  // note: [finishing what we started]
+  // if there's remaining work to be done but the tensors/blocks aren't full
+  // yet and we are at the end, submit the kernel to do the work!
+  if (loc_block_info != 0) {
+    multi_tensor_apply_kernel<<<
+        loc_block_info,
+        kBlockSize,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+template <int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply(
+    std::vector<std::vector<at::Tensor>>& tensor_lists,
+    T callable,
+    ArgTypes... args) {
+  TORCH_CHECK(
+      tensor_lists.size() == depth,
+      "Number of tensor lists has to match the depth.");
+  const size_t n_tensors = tensor_lists[0].size();
+  TensorListMetadata<depth> tensorListMeta;
+  tensorListMeta.start_tensor_this_launch = 0;
+
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for (size_t t = 0; t < n_tensors; t++) {
+    // short-circuit to avoid adding empty tensors to tensorListMeta
+    if (tensor_lists[0][t].numel() == 0) {
+      continue;
+    }
+    tensorListMeta.numel_for_tensor[loc_tensor_info] =
+        tensor_lists[0][t].numel();
+    for (int d = 0; d < depth; d++) {
+      tensorListMeta.addresses[d][loc_tensor_info] =
+          tensor_lists[d][t].const_data_ptr();
+    }
+    loc_tensor_info++;
+
+    // see note: [chunking territory].
+    const auto numel = tensor_lists[0][t].numel();
+    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
+    for (auto chunk = 0; chunk < chunks; chunk++) {
+      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      const bool tensors_full =
+          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+           chunk == chunks - 1);
+      const bool blocks_full =
+          (loc_block_info == depth_to_max_blocks[depth - 1]);
+
+      if (tensors_full || blocks_full) {
+        multi_tensor_apply_kernel<<<
+            loc_block_info,
+            kBlockSize,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            tensorListMeta, callable, args...);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+        // Reset.
+        loc_block_info = 0;
+        if (chunk == chunks - 1) {
+          loc_tensor_info = 0;
+          tensorListMeta.start_tensor_this_launch = t + 1;
+        } else {
+          tensorListMeta.numel_for_tensor[0] =
+              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
+          for (int d = 0; d < depth; d++) {
+            tensorListMeta.addresses[d][0] =
+                tensorListMeta.addresses[d][loc_tensor_info - 1];
+          }
+          loc_tensor_info = 1;
+          tensorListMeta.start_tensor_this_launch = t;
+        }
+      }
+    }
+  }
+
+  // see note: [finishing what we started]
+  if (loc_block_info != 0) {
+    multi_tensor_apply_kernel<<<
+        loc_block_info,
+        kBlockSize,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+template <int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply_for_fused_optimizer(
+    std::vector<std::vector<at::Tensor>>& tensor_lists,
+    at::TensorList state_steps,
+    T callable,
+    ArgTypes... args) {
+  TORCH_CHECK(
+      tensor_lists.size() == depth,
+      "Number of tensor lists has to match the depth");
+  const auto num_tensors = tensor_lists[0].size();
+  FusedOptimizerTensorListMetadata<depth> tensorListMeta;
+
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for (const auto& tensor_index : c10::irange(num_tensors)) {
+    // short-circuit to avoid adding empty tensors to tensorListMeta
+    if (tensor_lists[0][tensor_index].numel() == 0) {
+      continue;
+    }
+    tensorListMeta.state_steps_addresses[loc_tensor_info] =
+        state_steps[tensor_index].const_data_ptr();
+    tensorListMeta.numel_for_tensor[loc_tensor_info] =
+        tensor_lists[0][tensor_index].numel();
+    for (const auto& d : c10::irange(depth)) {
+      tensorListMeta.addresses[d][loc_tensor_info] =
+          tensor_lists[d][tensor_index].const_data_ptr();
+    }
+    loc_tensor_info++;
+
+    // see above note: [chunking territory]
+    const auto numel = tensor_lists[0][tensor_index].numel();
+    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
+    TORCH_CHECK(chunks > -1);
+    for (const auto& chunk : c10::irange(chunks)) {
+      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      const auto tensor_full =
+          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+           chunk == chunks - 1);
+      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
+
+      if (tensor_full || blocks_full) {
+        multi_tensor_apply_kernel<<<
+            loc_block_info,
+            kBlockSize,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            tensorListMeta, callable, args...);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+        // Reset.
+        loc_block_info = 0;
+        if (chunk == chunks - 1) {
+          loc_tensor_info = 0;
+        } else {
+          tensorListMeta.numel_for_tensor[0] =
+              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
+          tensorListMeta.state_steps_addresses[0] =
+              tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
+          for (const auto& d : c10::irange(depth)) {
+            tensorListMeta.addresses[d][0] =
+                tensorListMeta.addresses[d][loc_tensor_info - 1];
+          }
+          loc_tensor_info = 1;
+        }
+      }
+    }
+  }
+
+  // see above note: [finishing what we started]
+  if (loc_block_info != 0) {
+    multi_tensor_apply_kernel<<<
+        loc_block_info,
+        kBlockSize,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..9a609e956aa37a13249dc81e0982a34404837816
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Normalization.cuh
@@ -0,0 +1,1742 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+#include 
+#include 
+#endif
+
+namespace at { namespace native {
+
+// The maximum number of threads in a block
+#if defined(USE_ROCM)
+constexpr int MAX_BLOCK_SIZE = 256;
+#else
+constexpr int MAX_BLOCK_SIZE = 512;
+#endif
+
+constexpr unsigned MAX_GRID_SIZE = 65535u;
+
+// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
+static int getNumThreads(int nElem) {
+#if defined(USE_ROCM)
+  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+#else
+  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
+#endif
+  for (int i = 0; i != 5; ++i) {
+    if (nElem <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return MAX_BLOCK_SIZE;
+}
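+// e.g. getNumThreads(100) returns 128 with either table, and any nElem larger
+// than the last table entry saturates at MAX_BLOCK_SIZE.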
+
+// Returns the index of the most significant 1 bit in `val`.
+__device__ __forceinline__ int getMSB(int val) {
+  return 31 - __clz(val);
+}
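+// e.g. getMSB(1) == 0, getMSB(32) == 5, getMSB(33) == 5.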
+
+template <typename scalar_t, typename accscalar_t>
+struct Float2 {
+  accscalar_t v1, v2;
+  __device__ Float2() {}
+  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
+  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
+  __device__ Float2& operator+=(const Float2& a) {
+    v1 += a.v1;
+    v2 += a.v2;
+    return *this;
+  }
+  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
+    a += b;
+    return a;
+  }
+};
+
+template <typename scalar_t, typename accscalar_t, typename PTA>
+struct GradOp {
+  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
+    : mean(m), input(i), grad_output(g) {}
+  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
+    accscalar_t g = grad_output[batch][plane][n];
+    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
+    return Float2<scalar_t, accscalar_t>(g, g * c);
+  }
+  const accscalar_t mean;
+  const PTA& input;
+  const PTA& grad_output;
+};
+
+template <typename acc_t>
+struct SumReduceOp {
+    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
+
+    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
+        return WARP_SHFL_DOWN(data, offset);
+    }
+};
+
+template <typename scalar_t, typename accscalar_t>
+struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
+    using acc_t = Float2<scalar_t, accscalar_t>;
+
+    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
+
+    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
+        return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
+    }
+};
+
+// Sum across (batch, x/y/z) applying Op() pointwise
+// this works by first having each thread sum its part
+// of the data. Then there is a double-shuffling reduction.
+// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
+// data to the "warp leader", who writes its value into shared memory.
+// Then a single warp reads the remaining (at most C10_WARP_SIZE) items
+// and reduces them using another warpSum.
+// The implicit assumption is that there are no more
+// than C10_WARP_SIZE**2 threads.
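+// Shape sketch (illustrative numbers): with a 16x16 block, the 256 threads
+// first accumulate their strided subsets, the 8 warps leave 8 partial sums in
+// `shared`, warp 0 combines them, and thread (0,0) stores the final value in
+// shared[0] for every thread to read back after the __syncthreads().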
+template<typename scalar_t, typename Op, typename PTA>
+__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
+  // first the reductions each thread does separately
+  scalar_t sum = static_cast<scalar_t>(0);
+  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
+      sum += op(batch, plane, x);
+    }
+  }
+  __shared__ scalar_t shared[C10_WARP_SIZE];
+  SumReduceOp<scalar_t> reduce_op;
+  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+      shared[0] = sum;
+  }
+  __syncthreads();
+  // Everyone picks it up, should be broadcast into the whole grad_input
+  return shared[0];
+}
+
+constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
+constexpr int ELEMENTS_PER_THREAD = 16;
+constexpr int OPTIMAL_TILE_W = 32;
+constexpr int MAX_H_BLOCK = 128;
+
+__host__ void flexible_launch_configs(
+      const int reduction,
+      const int stride,
+      dim3 &block,
+      dim3 &grid,
+      const bool coop_flag = false) {
+  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
+  int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
+                         MAX_BLOCK_SIZE / block_x);
+  if (block_x * block_y != MAX_BLOCK_SIZE) {
+    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
+  }
+
+  int grid_x = at::ceil_div(stride, block_x);
+  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
+  if (coop_flag) {
+    // it's not worth having a grid reduction if the reduction dimension is not big enough
+    grid_y = grid_y < 8 ? 1 : grid_y;
+  }
+
+  block.x = block_x;
+  block.y = block_y;
+  block.z = 1;
+  grid.x = grid_x;
+  grid.y = grid_y;
+  grid.z = 1;
+}
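+// Worked example (illustrative, assuming the non-ROCm MAX_BLOCK_SIZE of 512):
+// flexible_launch_configs(4096, 64, block, grid) yields block = (32, 16) and
+// grid = (2, 16): 32 = min(lastPow2(64), OPTIMAL_TILE_W), 16 = 512 / 32,
+// grid.x = ceil(64 / 32), grid.y = min(ceil(4096 / (16 * 16)), MAX_H_BLOCK).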
+
+template<typename T, typename C>
+__device__ __forceinline__ void welford_merge_element(C& count,
+                                                      T& mean,
+                                                      T& m2n,
+                                                      const C& count_new,
+                                                      const T& mean_new,
+                                                      const T& m2n_new) {
+      T factor = T(1.0) / ::max(1, (count + count_new));
+      T delta0 = mean - mean_new;
+      mean = (mean_new * count_new + mean * count) * factor;
+      m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
+      count += count_new;
+}
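+// Worked example: merging (count=2, mean=1, m2n=2) with (count=2, mean=3, m2n=2)
+// gives count=4, mean=(3*2+1*2)/4=2 and m2n=2+2+(1-3)^2*2*2/4=8, which matches
+// the merged samples {0,2} and {2,4} (Chan et al. parallel variance update).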
+
+// merge mean/m2n among threadIdx.y within block
+template<typename T, typename C>
+__device__ __forceinline__ void welford_merge_block_vertical(C& count,
+                                                             T& mean,
+                                                             T& m2n,
+                                                             C* shmem_count,
+                                                             T* shmem_mean,
+                                                             T* shmem_m2n) {
+  // write to shared memory
+  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
+    if (threadIdx.y < offset*2) {
+      shmem_mean[address_base] = mean;
+      shmem_m2n[address_base] = m2n;
+      shmem_count[address_base] = count;
+    }
+    __syncthreads();
+    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+      auto address = address_base + offset * blockDim.x;
+      // read shared memory back to register for reduction
+      auto count_new = shmem_count[address];
+      auto mean_new = shmem_mean[address];
+      auto m2n_new = shmem_m2n[address];
+
+      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
+    }
+  }
+}
+
+template 
+__global__ void batch_norm_transform_input_kernel(
+    const GenericPackedTensorAccessor input,
+    GenericPackedTensorAccessor output,
+    const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_,
+    const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
+    const GenericPackedTensorAccessor weight,
+    const GenericPackedTensorAccessor bias,
+    stat_accscalar_t epsilon) {
+
+  index_t plane = blockIdx.x;
+
+  if (plane >= input.size(1)) {
+    return;
+  }
+
+  stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast(weight[plane]) : static_cast(1);
+  stat_accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0);
+  stat_accscalar_t mean = static_cast(mean_[plane]);
+  stat_accscalar_t invstd;
+  if (train) {
+    invstd = var_or_invstd[plane];
+  } else {
+    invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon);
+  }
+
+  index_t bs = input.size(0);
+  index_t fs = input.size(2);
+
+  index_t bstep  = blockDim.y * gridDim.y;
+  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
+    auto o = output[batch][plane];
+    auto i = input[batch][plane];
+    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
+      o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta);
+    }
+  }
+}
+
+struct InvStd {
+  template 
+  __device__ __forceinline__ T operator()(T var, double epsilon) const {
+    T invstd = 0;
+    if (var != static_cast(0) || epsilon != static_cast(0)) {
+      invstd = static_cast(1) / device_sqrt(var + epsilon);
+    }
+    return invstd;
+  }
+};
+
+struct Var {
+  template 
+  __device__ __forceinline__ T operator()(T var, double epsilon) const {
+    return var;
+  }
+};
+
+template 
+__global__ void batch_norm_collect_statistics_kernel(
+    const GenericPackedTensorAccessor input,
+    const stat_accscalar_t epsilon,
+    const stat_accscalar_t momentum,
+    GenericPackedTensorAccessor save_mean,
+    GenericPackedTensorAccessor save_transformed_var) {
+
+  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];
+
+  int plane = blockIdx.x;
+  int N = input.size(0) * input.size(2);
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+
+  // Compute the mean and variance across (batch, x/y/z)
+  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
+  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
+  // and the parallel algorithm on the same page.
+  // We use two shuffles to reduce across the entire block.
+  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
+  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];
+
+  // first the reductions each thread does separately
+  stat_accscalar_t avg = 0;
+  stat_accscalar_t var_n = 0;
+  int n = 0;
+  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
+    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
+      stat_accscalar_t v = input[batch][plane][x];
+      stat_accscalar_t d1 = v - avg;
+      n++;
+      avg += d1 / n;
+      var_n += d1 * (v - avg);
+    }
+  }
+
+  // first warpSum to get one value per thread to
+  // one value per warp
+  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
+    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
+    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
+    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
+    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
+    avg = (n * avg + o_n * o_avg) * factor;
+    n += o_n;
+  }
+
+  // this writes each warp's item into shared memory
+  // there are at most C10_WARP_SIZE items left because
+  // there are at most C10_WARP_SIZE**2 threads at the beginning
+  __syncthreads();
+  if (tid % C10_WARP_SIZE == 0) {
+    shared_n[tid / C10_WARP_SIZE] = n;
+    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
+    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
+  }
+  __syncthreads();
+  // now have a second warpSum to reduce the intermediate values
+  // from shared memory to a single number. The very first
+  // thread writes it to shared memory.
+
+  if (tid < C10_WARP_SIZE) {
+    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
+    avg = (tid < blockDim.x * blockDim.y  / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
+    var_n = (tid < blockDim.x * blockDim.y  / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
+  }
+  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
+    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
+    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
+    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
+    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
+    avg = (n * avg + o_n * o_avg) * factor;
+    n += o_n;
+  }
+
+  // Save the mean, variance, and moving averages
+  if (tid == 0) {
+    if (save_mean.data() != NULL) {
+      save_mean[plane] = avg;
+    }
+    if (save_transformed_var.data() != NULL) {
+      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
+    }
+  }
+
+}
+
+template 
+__global__ void batch_norm_backward_kernel(
+    const GenericPackedTensorAccessor input,
+    const GenericPackedTensorAccessor grad_output,
+    GenericPackedTensorAccessor grad_input,
+    GenericPackedTensorAccessor grad_weight,
+    GenericPackedTensorAccessor grad_bias,
+    const GenericPackedTensorAccessor weight,
+    const GenericPackedTensorAccessor running_mean,
+    const GenericPackedTensorAccessor running_var,
+    const GenericPackedTensorAccessor save_mean,
+    const GenericPackedTensorAccessor save_invstd,
+    bool train,
+    stat_accscalar_t epsilon) {
+
+  index_t plane = blockIdx.x;
+  index_t N = grad_output.size(0) * grad_output.size(2);
+
+  stat_accscalar_t mean, invstd;
+  if (train) {
+    mean = save_mean[plane];
+    invstd = save_invstd[plane];
+  } else {
+    mean = static_cast(running_mean[plane]);
+    invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon);
+  }
+
+  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1);
+  stat_accscalar_t norm = stat_accscalar_t(1) / N;
+
+  // Compute two values across (batch, x/y/z) in one pass:
+  // 1. Sum(grad_output)
+  // 2. DotProduct(input - mean, grad_output)
+  GradOp> g(mean, input, grad_output);
+  auto res = reduce>(g, grad_output, plane);
+
+  stat_accscalar_t grad_output_sum = res.v1;
+  stat_accscalar_t dot_p = res.v2;
+
+  stat_accscalar_t grad_mean = grad_output_sum * norm;
+  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
+  stat_accscalar_t grad_scale = invstd * weight_val;
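+  // In training mode the loop below evaluates the standard batch-norm input
+  // gradient dx = (dy - mean(dy) - x_hat * mean(dy * x_hat)) * invstd * weight,
+  // with x_hat = (x - mean) * invstd; proj_scale and grad_scale above fold the
+  // per-plane statistics into two constants so the inner loop stays cheap.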
+
+  if (grad_input.data() != NULL) {
+    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
+      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
+        input_scalar_t go = grad_output[batch][plane][x];
+        if (train) {
+          stat_accscalar_t inp = input[batch][plane][x];
+          stat_accscalar_t proj = (inp - mean) * proj_scale;
+          grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale);
+        } else {
+          grad_input[batch][plane][x] = static_cast(go * grad_scale);
+        }
+      }
+    }
+  }
+
+  if (grad_weight.size(0) > 0) {
+    if (threadIdx.x == 0) {
+      grad_weight[plane] = static_cast(dot_p * invstd);
+    }
+  }
+
+  if (grad_bias.size(0) > 0) {
+    if (threadIdx.x == 0) {
+      grad_bias[plane] = static_cast(grad_output_sum);
+    }
+  }
+}
+
+template 
+__global__ void batch_norm_reduce_statistics_kernel(
+    const GenericPackedTensorAccessor vec_mean,
+    const GenericPackedTensorAccessor vec_invstd,
+    GenericPackedTensorAccessor mean,
+    GenericPackedTensorAccessor invstd,
+    GenericPackedTensorAccessor running_mean,
+    GenericPackedTensorAccessor running_var,
+    const accscalar_t epsilon,
+    const accscalar_t momentum,
+    const GenericPackedTensorAccessor counts) {
+
+  int feature_size = vec_mean.size(1);
+  int world_size = vec_mean.size(0);
+
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+
+  // first the reductions each thread does separately
+  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
+    accscalar_t avg = 0;
+    accscalar_t var_n = 0;
+    index_t n = 0;
+    for (int j = 0; j < world_size; j++) {
+      scalar_t count = counts[j];
+      accscalar_t m = vec_mean[j][i];
+      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
+      v = (v * v - epsilon) * count;
+      accscalar_t factor = 1.0 / (n + count);
+      var_n += v + (avg - m) * (avg - m) * n * count * factor;
+      avg = n * factor * avg + count * factor * m;
+      n += count;
+    }
+    mean[i] = avg;
+    invstd[i] = static_cast(1) / device_sqrt(var_n / n + epsilon);
+    if (running_mean.data() != NULL) {
+      running_mean[i] = static_cast((1 - momentum) * running_mean[i] + momentum * avg);
+    }
+    accscalar_t unbiasedVar = var_n / (n - 1);
+    if (running_var.data() != NULL) {
+      running_var[i] = static_cast((1 - momentum) * running_var[i] + momentum * unbiasedVar);
+    }
+  }
+
+}
+
+template 
+__global__ void batch_norm_backward_reduce_kernel(
+    const GenericPackedTensorAccessor input,
+    const GenericPackedTensorAccessor grad_output,
+    GenericPackedTensorAccessor mean,
+    GenericPackedTensorAccessor invstd,
+    GenericPackedTensorAccessor sum_dy,
+    GenericPackedTensorAccessor sum_dy_xmu,
+    GenericPackedTensorAccessor grad_weight,
+    GenericPackedTensorAccessor grad_bias) {
+
+  index_t plane = blockIdx.x;
+
+  stat_accscalar_t r_mean = mean[plane];
+  stat_accscalar_t factor = invstd[plane];
+
+  GradOp> g(r_mean, input, grad_output);
+  auto res = reduce>(g, grad_output, plane);
+
+  if (threadIdx.x == 0) {
+    if (grad_weight.size(0) > 0) {
+      grad_weight[plane] = static_cast(res.v2 * factor);
+    }
+    if (grad_bias.size(0) > 0) {
+      grad_bias[plane] = static_cast(res.v1);
+    }
+    if (sum_dy.size(0) > 0) {
+      sum_dy[plane] = static_cast(res.v1);
+    }
+    if (sum_dy_xmu.size(0) > 0) {
+      sum_dy_xmu[plane] = static_cast(res.v2);
+    }
+  }
+}
+
+template 
+__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
+    const GenericPackedTensorAccessor input,
+    const GenericPackedTensorAccessor grad_output,
+    const GenericPackedTensorAccessor mean,
+    const GenericPackedTensorAccessor invstd,
+    const GenericPackedTensorAccessor weight,
+    const GenericPackedTensorAccessor sum_dy,
+    const GenericPackedTensorAccessor sum_dy_xmu,
+    GenericPackedTensorAccessor grad_input,
+    const stat_accscalar_t norm_fct) {
+  index_t plane = blockIdx.x;
+
+  if (plane >= input.size(1)) {
+    return;
+  }
+
+  stat_accscalar_t m_c = mean[plane];
+  stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
+  stat_accscalar_t factor_1_c = invstd[plane];
+  stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1);
+  factor_2_c *= factor_1_c;
+  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;
+
+  index_t bs = input.size(0);
+  index_t fs = input.size(2);
+
+  index_t bstep  = blockDim.y * gridDim.y;
+  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
+    auto g_i = grad_input[batch][plane];
+    auto g_o = grad_output[batch][plane];
+    auto i = input[batch][plane];
+    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
+      g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
+    }
+  }
+}
+
+template 
+__global__ void batch_norm_backward_elemt_kernel(
+    const GenericPackedTensorAccessor input,
+    const GenericPackedTensorAccessor grad_output,
+    const GenericPackedTensorAccessor mean,
+    const GenericPackedTensorAccessor invstd,
+    const GenericPackedTensorAccessor weight,
+    const GenericPackedTensorAccessor sum_dy,
+    const GenericPackedTensorAccessor sum_dy_xmu,
+    GenericPackedTensorAccessor grad_input,
+    const int* __restrict__ numel, const int world_size) {
+  int64_t total_numel = 0;
+  for (int i = 0; i < world_size; i ++) {
+    total_numel += numel[i];
+  }
+
+  const stat_accscalar_t norm_fct =
+      static_cast(1) / static_cast(total_numel);
+  batch_norm_backward_elemt_kernel_impl(
+      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
+}
+
+template 
+__global__ void batch_norm_backward_elemt_kernel(
+    const GenericPackedTensorAccessor input,
+    const GenericPackedTensorAccessor grad_output,
+    const GenericPackedTensorAccessor mean,
+    const GenericPackedTensorAccessor invstd,
+    const GenericPackedTensorAccessor weight,
+    const GenericPackedTensorAccessor sum_dy,
+    const GenericPackedTensorAccessor sum_dy_xmu,
+    GenericPackedTensorAccessor grad_input,
+    const stat_accscalar_t norm_fct) {
+  batch_norm_backward_elemt_kernel_impl(
+      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
+}
+
+template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
+    const Tensor& t, c10::string_view var_name) {
+  constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
+  const auto actual_type = t.scalar_type();
+  TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
+              " to have type ", expect_type, " but got ", actual_type);
+  return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
+}
+
+template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
+    const Tensor& t, c10::string_view var_name) {
+  if (!t.defined()) {
+    const std::array<index_t, dim> zeros{{0}};
+    return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
+  }
+  return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
+}
+
+template
+std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
+                                                                     const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
+                                                                     bool train, double epsilon, std::array grad_input_mask) {
+
+  using accscalar_t = at::acc_type;
+  Tensor grad_input_;
+  Tensor grad_input_reshaped;
+  Tensor grad_weight_;
+  Tensor grad_bias_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+  if (grad_input_mask[0]) {
+    grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
+  }
+  if (grad_input_mask[1]) {
+    grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (grad_input_mask[2]) {
+    grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
+  auto grad_output = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
+  auto grad_input = packed_accessor_or_dummy<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
+  auto weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
+  auto grad_weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
+  auto grad_bias = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
+  auto running_mean = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
+  auto running_var = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
+  auto save_mean = packed_accessor_or_dummy<
+      accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
+  auto save_invstd = packed_accessor_or_dummy<
+      accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf));
+
+  batch_norm_backward_kernel <<>>
+    (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
+     save_mean, save_invstd, train, epsilon);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
+}
+
+template
+void batch_norm_stats_cuda_template(
+    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
+
+  using accscalar_t = at::acc_type;
+  int64_t n_input = input_.size(1);
+  Tensor dummy_mean_;
+  Tensor dummy_var_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+  resize_output(out_mean, {n_input});
+  resize_output(out_invstd, {n_input});
+  auto input = get_packed_accessor<
+      scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
+  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
+                        out_invstd.sizes()[0]);
+  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
+                        out_mean.sizes()[0]);
+
+  auto mean = packed_accessor_or_dummy<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
+  auto invstd = packed_accessor_or_dummy<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf));
+  batch_norm_collect_statistics_kernel <<>>
+    (input, epsilon, 0.0, mean, invstd);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template
+void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
+                                    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
+
+  using stat_accscalar_t = at::acc_type;
+  int64_t n_input = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
+  auto output = get_packed_accessor<
+      input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
+  auto weight = packed_accessor_or_dummy<
+    stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
+  auto bias = packed_accessor_or_dummy<
+      stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
+  auto mean = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
+  const double dummy_epsilon = 1e-5;
+
+  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max(getNumThreads(input.size(2)/4),
+                         std::min(getNumThreads(input.size(2)), 64));
+  int tb = std::max(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
+  dim3 threads_trans(tf, tb);
+  batch_norm_transform_input_kernel <<>>
+    (input, output, mean, invstd, weight, bias, dummy_epsilon);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template
+std::tuple batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
+                                                                 const Tensor& running_mean_, const Tensor& running_var_,
+                                                                 double momentum, double epsilon, const Tensor& counts_) {
+
+  Tensor save_mean_;
+  Tensor save_invstd_;
+
+  auto features = mean_.size(1);
+  auto input_options = mean_.options();
+  if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  save_mean_ = at::empty({features}, input_options);
+  save_invstd_ = at::empty({features}, input_options);
+
+  auto mean = packed_accessor_or_dummy<
+      accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
+  auto running_mean = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
+  auto running_var = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
+  auto counts = packed_accessor_or_dummy<
+      scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
+
+  auto save_mean = get_packed_accessor<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
+  auto save_invstd = get_packed_accessor<
+      accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int block = getNumThreads(features);
+  int grid = std::max(1, features/block);
+  batch_norm_reduce_statistics_kernel <<>>
+      (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return std::make_tuple(save_mean_, save_invstd_);
+}
+
+template
+std::tuple batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+                                                                                    const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
+                                                                                    const bool input_g, const bool weight_g, const bool bias_g) {
+
+  using stat_accscalar_t = at::acc_type;
+  int64_t n_input = input_.size(1);
+  Tensor sum_dy_;
+  Tensor sum_dy_xmu_;
+  Tensor grad_weight_;
+  Tensor grad_bias_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+  if (input_g) {
+    sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  }
+  if (weight_g) {
+    grad_weight_ = at::empty({n_input}, weight_.options());
+  }
+  if (bias_g) {
+    grad_bias_ = at::empty({n_input}, weight_.options());
+  }
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
+  auto grad_output = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
+  auto grad_weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
+  auto grad_bias = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
+  auto mean = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
+  auto sum_dy = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
+  auto sum_dy_xmu = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
+
+  auto batch_size = input_reshaped.size(0);
+  auto feature_size = input_reshaped.size(2);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int warp_size = at::cuda::warp_size();
+  int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
+  // We want block_x to be at least a warp width
+  int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
+  const dim3 block(block_x, block_y);
+  const dim3 grid(n_input);
+
+  batch_norm_backward_reduce_kernel <<>>
+    (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
+}
+
+template
+Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+                                               const Tensor& mean_, const Tensor& invstd_,
+                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
+
+  using stat_accscalar_t = at::acc_type;
+  int64_t n_input = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
+  auto grad_input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
+  auto grad_output = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
+  auto mean = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
+  auto weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
+  auto sum_dy = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
+  auto sum_dy_xmu = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max(getNumThreads(input.size(2)/4),
+                         std::min(getNumThreads(input.size(2)), 64));
+  int tb = std::max(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
+  dim3 threads_trans(tf, tb);
+  auto reduction_size = input_.numel() / n_input;
+  auto norm_fct = static_cast(1.0 / reduction_size);
+  batch_norm_backward_elemt_kernel
+      <<>>
+      (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return grad_input_reshaped.view(input_.sizes());
+}
+
+template
+Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
+                                               const Tensor& mean_, const Tensor& invstd_,
+                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {
+
+  using stat_accscalar_t = at::acc_type;
+  int64_t n_input = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+
+  auto input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
+  auto grad_input = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
+  auto grad_output = get_packed_accessor<
+      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
+  auto mean = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
+  auto invstd = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
+  auto weight = packed_accessor_or_dummy<
+      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
+  auto sum_dy = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
+  auto sum_dy_xmu = packed_accessor_or_dummy<
+      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max(getNumThreads(input.size(2)/4),
+                         std::min(getNumThreads(input.size(2)), 64));
+  int tb = std::max(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
+  dim3 threads_trans(tf, tb);
+  batch_norm_backward_elemt_kernel <<>>
+    (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr(), count.numel());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return grad_input_reshaped.view(input_.sizes());
+}
+
+// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
+// original apex name: welford_kernel_c_last
+template
+   <typename VarTransform,
+    typename scalar_t,
+    typename accscalar_t,
+    int PARALLEL_LOADS>
+__global__ void
+batch_norm_collect_statistics_channels_last_kernel(
+      const scalar_t* __restrict__ input,
+      accscalar_t* __restrict__ out_mean,
+      accscalar_t* __restrict__ out_invstd,
+      volatile accscalar_t* staging_data,
+      int* semaphores,
+      const int reduction_size,
+      const int stride,
+      accscalar_t epsilon) {
+  // hide latency with concurrency
+  accscalar_t x_mean[PARALLEL_LOADS];
+  accscalar_t m_2_n[PARALLEL_LOADS];
+  int count[PARALLEL_LOADS];
+
+#pragma unroll
+  for (int i = 0; i < PARALLEL_LOADS; i++) {
+    x_mean[i] = accscalar_t(0);
+    m_2_n[i] = accscalar_t(0);
+    count[i] = accscalar_t(0);
+  }
+  // tensor dimension (m,c)
+
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+    accscalar_t x_math[PARALLEL_LOADS];
+    accscalar_t x_count_inv[PARALLEL_LOADS];
+    accscalar_t is_valid[PARALLEL_LOADS];
+
+    // load multiple data in
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        x_math[j] = input[address_base];
+        count[j]++;
+        x_count_inv[j] = accscalar_t(1) / count[j];
+        is_valid[j] = accscalar_t(1);
+      } else {
+        x_math[j] = accscalar_t(0);
+        x_count_inv[j] = accscalar_t(0);
+        is_valid[j] = accscalar_t(0);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+
+    // calculate mean/m2n with welford
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      accscalar_t delta0 = x_math[j] - x_mean[j];
+      x_mean[j] += delta0 * x_count_inv[j];
+      accscalar_t delta1 = x_math[j] - x_mean[j];
+      m_2_n[j] += delta0 * delta1 * is_valid[j];
+    }
+  }
+
+  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
+#pragma unroll
+  for (int j = 1; j < PARALLEL_LOADS; j++) {
+    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
+  }
+
+  // release x_mean / m_2_n
+  auto mean_th = x_mean[0];
+  auto m2_th = m_2_n[0];
+  auto count_th = count[0];
+
+  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
+  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
+  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
+  static __shared__ int shmem_count[MAX_BLOCK_SIZE];
+
+  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
+
+  if (gridDim.y > 1) {
+    volatile accscalar_t* staging_mean = staging_data;
+    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
+    volatile int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]);
+
+    address_base = c_offset + blockIdx.y * stride;
+    // write data to staging_data;
+    if (threadIdx.y == 0 && c_offset < stride) {
+      staging_mean[address_base] = mean_th;
+      staging_m2n[address_base] = m2_th;
+      staging_count[address_base] = count_th;
+    }
+
+    __threadfence();
+    __syncthreads(); // ensuring writes to staging_ are visible to all blocks
+
+    __shared__ bool is_last_block_done;
+    // mark block done
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int old = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done = (old == (gridDim.y-1));
+    }
+
+    __syncthreads();
+
+    // check that all data is now available in global memory
+    if (is_last_block_done) {
+      count_th = 0;
+      mean_th = accscalar_t(0.0);
+      m2_th = accscalar_t(0.0);
+
+      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+        address_base = c_offset + y * stride;
+        int count_new = c_offset < stride ? staging_count[address_base] : 0;
+        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
+        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
+
+        welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
+      }
+
+      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
+      if (threadIdx.y == 0 && c_offset < stride) {
+        out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
+        out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
+      }
+    }
+  } else {
+    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+      out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
+      out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
+    }
+  }
+}
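+
+// [Illustrative sketch, not part of the upstream kernel] welford_merge_element /
+// welford_merge_block_vertical above combine partial (count, mean, M2) triples.
+// A minimal host-side analog of that merge, with hypothetical names, in double precision:
+namespace welford_illustration {
+struct Partial { long long count; double mean; double m2; };
+// Chan et al. parallel update: merge two partial Welford accumulators.
+inline Partial merge(const Partial& a, const Partial& b) {
+  if (b.count == 0) return a;
+  if (a.count == 0) return b;
+  long long n = a.count + b.count;
+  double delta = b.mean - a.mean;
+  double nb_over_n = static_cast<double>(b.count) / static_cast<double>(n);
+  return {n,
+          a.mean + delta * nb_over_n,
+          a.m2 + b.m2 + delta * delta * static_cast<double>(a.count) * nb_over_n};
+}
+} // namespace welford_illustration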
+
+// elementwise BN kernel
+// original apex name: batchnorm_forward_c_last_kernel
+template <
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t,
+    int PARALLEL_LOADS>
+__global__ void batch_norm_transform_input_channels_last_kernel(
+      const scalar_t* __restrict__ input,
+      const scalar_t* __restrict__ z,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const layerscalar_t* __restrict__ shift,
+      scalar_t* __restrict__ out,
+      const int reduction_size,
+      const int stride,
+      const bool fuse_relu) {
+  // tensor dimension (m,c)
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (c_offset >= stride || m_offset >= reduction_size) {
+    return;
+  }
+
+  auto m_c = mean[c_offset];
+  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
+  auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
+  auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
+        if (z != nullptr) {
+          tmp += z[address_base];
+        }
+        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+  }
+}
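+
+// [Illustrative sketch, not part of upstream] Scalar reference for the per-element
+// transform applied by the kernel above (hypothetical helper, float-only):
+// y = w * (x - mean) * inv_std + b (+ z), optionally clamped at zero when fused with ReLU.
+namespace bn_transform_illustration {
+inline float transform_one(float x, float mean, float inv_std, float w, float b,
+                           float z, bool fuse_relu) {
+  float y = w * (x - mean) * inv_std + b + z;
+  return (fuse_relu && y <= 0.f) ? 0.f : y;
+}
+} // namespace bn_transform_illustration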
+
+template <typename T>
+__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
+    T& sum_dy_xmu,
+    T* shmem_sum_dy,
+    T* shmem_sum_dy_xmu) {
+  // write to shared memory
+  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
+    if (threadIdx.y < offset*2) {
+      shmem_sum_dy[address_base] = sum_dy;
+      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
+    }
+    __syncthreads();
+    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+      auto address = address_base + offset * blockDim.x;
+
+      sum_dy += shmem_sum_dy[address];
+      sum_dy_xmu += shmem_sum_dy_xmu[address];
+    }
+  }
+}
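+
+// [Illustrative sketch, not part of upstream] The loop above is a shared-memory
+// tree reduction over blockDim.y: the number of active rows halves each step.
+// Host analog (hypothetical helper), assuming a power-of-two length:
+namespace tree_reduce_illustration {
+inline double tree_sum(double* vals, int n) {
+  for (int offset = n / 2; offset > 0; offset /= 2) {
+    for (int i = 0; i < offset; ++i) {
+      vals[i] += vals[i + offset];  // fold the upper half onto the lower half
+    }
+  }
+  return vals[0];
+}
+} // namespace tree_reduce_illustration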
+
+// batchnorm backward kernel for c last tensor
+// original apex name: reduce_bn_c_last_kernel
+template <
+    int PARALLEL_LOADS,
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t>
+__global__ void batch_norm_backward_reduce_channels_last_kernel(
+      const scalar_t* __restrict__ input,
+      const scalar_t* __restrict__ grad_output,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      accscalar_t* __restrict__ sum_dy_o,
+      accscalar_t* __restrict__ sum_dy_xmu_o,
+      layerscalar_t* __restrict__ grad_weight,
+      layerscalar_t* __restrict__ grad_bias,
+      volatile accscalar_t* staging_data,
+      int* semaphores,
+      const int reduction_size,
+      const int stride) {
+
+  // hide latency with concurrency
+  accscalar_t sum_dy[PARALLEL_LOADS];
+  accscalar_t sum_dy_xmu[PARALLEL_LOADS];
+
+#pragma unroll
+  for (int i = 0; i < PARALLEL_LOADS; i++) {
+    sum_dy[i] = accscalar_t(0);
+    sum_dy_xmu[i] = accscalar_t(0);
+  }
+  // tensor dimension (m,c)
+
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (c_offset >= stride || m_offset >= reduction_size) {
+    return;
+  }
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  auto r_mean = mean[c_offset];
+  auto factor = inv_std[c_offset];
+
+  for (int i = 0; i < loop_count; i++) {
+    accscalar_t x_input[PARALLEL_LOADS];
+    accscalar_t x_grad_output[PARALLEL_LOADS];
+
+    // load multiple data in
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        x_input[j] = input[address_base];
+        x_grad_output[j] = grad_output[address_base];
+      } else {
+        x_input[j] = accscalar_t(0);
+        x_grad_output[j] = accscalar_t(0);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+
+    // calculate sum_dy / sum_dy_xmu
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      sum_dy[j] += x_grad_output[j];
+      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
+    }
+  }
+
+  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
+#pragma unroll
+  for (int j = 1; j < PARALLEL_LOADS; j++) {
+    sum_dy[0] += sum_dy[j];
+    sum_dy_xmu[0] += sum_dy_xmu[j];
+  }
+
+  // release array of registers
+  auto sum_dy_th = sum_dy[0];
+  auto sum_dy_xmu_th = sum_dy_xmu[0];
+
+  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
+  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
+  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
+
+  merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
+
+  if (gridDim.y > 1) {
+    volatile accscalar_t* staging_sum_dy = staging_data;
+    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
+
+    address_base = c_offset + blockIdx.y * stride;
+    // write data to staging_data;
+    if (threadIdx.y == 0 && c_offset < stride) {
+      staging_sum_dy[address_base] = sum_dy_th;
+      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
+    }
+
+    __threadfence();
+    __syncthreads(); // ensure writes to staging_data are visible to all blocks
+
+    __shared__ bool is_last_block_done;
+    // mark block done
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int old = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done = (old == (gridDim.y-1));
+    }
+
+    __syncthreads();
+
+    // check that all data is now available in global memory
+    if (is_last_block_done) {
+      sum_dy_th = accscalar_t(0.0);
+      sum_dy_xmu_th = accscalar_t(0.0);
+
+      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+        address_base = c_offset + y * stride;
+        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
+        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
+      }
+
+      merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
+      if (threadIdx.y == 0 && c_offset < stride) {
+        if (grad_bias != nullptr) {
+          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+        }
+        if (grad_weight != nullptr) {
+          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+        }
+        //mean_dy[c_offset] = sum_dy_th / reduction_size;
+        //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+        sum_dy_o[c_offset] = sum_dy_th;
+        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
+      }
+    }
+  } else {
+    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
+      if (grad_bias != nullptr) {
+        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
+      }
+      if (grad_weight != nullptr) {
+        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
+      }
+      //mean_dy[c_offset] = sum_dy_th / reduction_size;
+      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
+      sum_dy_o[c_offset] = sum_dy_th;
+      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
+    }
+  }
+}
+
+// elementwise BN kernel
+// original apex name: batchnorm_backward_c_last_kernel
+template <
+    int PARALLEL_LOADS,
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t>
+__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
+      const scalar_t* __restrict__ grad_output,
+      const scalar_t* __restrict__ input,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const accscalar_t* __restrict__ sum_dy,
+      const accscalar_t* __restrict__ sum_dy_xmu,
+      scalar_t* __restrict__ grad_input,
+      const accscalar_t norm_fct,
+      const int reduction_size,
+      const int stride) {
+  // tensor dimension (m,c)
+  // loop along m dimension
+  int inner_loop_stride = blockDim.y * gridDim.y;
+
+  // offset along m dimension
+  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
+  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (c_offset >= stride || m_offset >= reduction_size) {
+    return;
+  }
+
+  auto m_c = mean[c_offset];
+  auto m_dy_c = sum_dy[c_offset] * norm_fct;
+  auto factor_1_c = inv_std[c_offset];
+  auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
+  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
+
+  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
+  int address_base = m_offset * stride + c_offset;
+  int address_increment = inner_loop_stride * stride;
+
+  for (int i = 0; i < loop_count; i++) {
+#pragma unroll
+    for (int j = 0; j < PARALLEL_LOADS; j++) {
+      if (c_offset < stride && m_offset < reduction_size) {
+        grad_input[address_base] = static_cast<scalar_t>(
+            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
+            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
+            * factor_2_c);
+      }
+      m_offset += inner_loop_stride;
+      address_base += address_increment;
+    }
+  }
+}
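+
+// [Illustrative sketch, not part of upstream] Scalar reference for the gradient
+// computed above (hypothetical helper): with norm_fct = 1/N this reproduces the
+// factor_1_c / factor_2_c algebra,
+//   dx = (dy - sum_dy/N - (x - mean) * inv_std^2 * sum_dy_xmu/N) * w * inv_std.
+namespace bn_backward_illustration {
+inline float grad_input_one(float dy, float x, float mean, float inv_std, float w,
+                            float sum_dy, float sum_dy_xmu, float norm_fct) {
+  float mean_dy = sum_dy * norm_fct;
+  float factor_1 = inv_std * inv_std * sum_dy_xmu * norm_fct;
+  float factor_2 = w * inv_std;
+  return (dy - mean_dy - (x - mean) * factor_1) * factor_2;
+}
+} // namespace bn_backward_illustration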
+
+template <
+    int PARALLEL_LOADS,
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t>
+__global__ void batch_norm_backward_elemt_channels_last_kernel(
+      const scalar_t* __restrict__ grad_output,
+      const scalar_t* __restrict__ input,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const accscalar_t* __restrict__ sum_dy,
+      const accscalar_t* __restrict__ sum_dy_xmu,
+      const int* __restrict__ numel,
+      scalar_t* __restrict__ grad_input,
+      const int64_t world_size,
+      const int reduction_size,
+      const int stride) {
+
+  int64_t total_numel = 0;
+  for (int i = 0; i < world_size; i++) {
+    total_numel += numel[i];
+  }
+
+  auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
+  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
+      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
+      grad_input, norm_fct, reduction_size, stride);
+}
+
+template <
+    int PARALLEL_LOADS,
+    typename scalar_t,
+    typename accscalar_t,
+    typename layerscalar_t>
+__global__ void batch_norm_backward_elemt_channels_last_kernel(
+      const scalar_t* __restrict__ grad_output,
+      const scalar_t* __restrict__ input,
+      const accscalar_t* __restrict__ mean,
+      const accscalar_t* __restrict__ inv_std,
+      const layerscalar_t* __restrict__ weight,
+      const accscalar_t* __restrict__ sum_dy,
+      const accscalar_t* __restrict__ sum_dy_xmu,
+      scalar_t* __restrict__ grad_input,
+      const accscalar_t norm_fct,
+      const int reduction_size,
+      const int stride) {
+  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
+      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
+      grad_input, norm_fct, reduction_size, stride);
+}
+
+template<typename scalar_t, typename VarTransform>
+void batch_norm_stats_channels_last_cuda_template(
+    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
+  using accscalar_t = at::acc_type<scalar_t, true>;
+
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+
+  resize_output(out_mean, {stride});
+  resize_output(out_invstd, {stride});
+  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
+                        out_invstd.sizes()[0]);
+  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
+                        out_mean.sizes()[0]);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid, true);
+
+  at::Tensor staging_data;
+  at::Tensor semaphores;
+  if (grid.y > 1) {
+    staging_data = at::empty({4*stride*grid.y}, out_mean.options());
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
+  }
+
+  accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
+  int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
+  batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
+      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+      input.const_data_ptr<scalar_t>(),
+      out_mean.mutable_data_ptr<accscalar_t>(),
+      out_invstd.mutable_data_ptr<accscalar_t>(),
+      staging_data_ptr,
+      semaphores_ptr,
+      reduction_size,
+      stride,
+      epsilon);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+void batch_norm_elemt_channels_last_cuda_template(
+    const at::Tensor& output,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& shift,  // bias of BN
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::optional<at::Tensor>& z = c10::nullopt,  // bias after BN
+    const bool fuse_relu = false) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  const auto second_dtype = weight.defined() ? weight.scalar_type() :
+      (shift.defined() ? shift.scalar_type() : input.scalar_type());
+
+  if (input.scalar_type() != second_dtype) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.const_data_ptr<scalar_t>(),
+          z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
+          shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
+          output.mutable_data_ptr<scalar_t>(),
+          reduction_size,
+          stride,
+          fuse_relu);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  } else {
+    if (weight.defined()){
+      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
+        " is not supported with weight.scalar_type() ", weight.scalar_type());
+    }
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.const_data_ptr<scalar_t>(),
+          z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
+          shift.defined() ? shift.const_data_ptr<scalar_t>(): nullptr,
+          output.mutable_data_ptr(),
+          reduction_size,
+          stride,
+          fuse_relu);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::Tensor& weight,
+    const bool input_g, const bool weight_g, const bool bias_g) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+
+  at::Tensor sumn_dy = at::empty({stride}, mean.options());
+  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
+
+  at::Tensor grad_weight;
+  at::Tensor grad_bias;
+  if (weight.defined()) {
+    grad_weight = at::empty({stride}, weight.options());
+    grad_bias = at::empty({stride}, weight.options());
+  } else {
+    // because I cannot return an uninitialized at::Tensor
+    grad_weight = at::empty({0}, mean.options());
+    grad_bias = at::empty({0}, mean.options());
+  }
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid, true);
+
+  at::Tensor staging_data;
+  at::Tensor semaphores;
+  if (grid.y > 1) {
+    staging_data = at::empty({2*stride*grid.y}, mean.options());
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
+  }
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
+      int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
+      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.const_data_ptr<scalar_t>(),
+          grad_output.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          sumn_dy.mutable_data_ptr<accscalar_t>(),
+          sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
+          grad_weight.mutable_data_ptr<accscalar_t>(),
+          grad_bias.mutable_data_ptr<accscalar_t>(),
+          staging_data_ptr,
+          semaphores_ptr,
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  } else {
+    if (weight.defined()) {
+      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
+        " is not supported with weight.scalar_type() ", weight.scalar_type());
+    }
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
+      int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
+      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          input.const_data_ptr<scalar_t>(),
+          grad_output.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          sumn_dy.mutable_data_ptr<accscalar_t>(),
+          sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
+          weight.defined() ? grad_weight.mutable_data_ptr<scalar_t>() : nullptr,
+          weight.defined() ? grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
+          staging_data_ptr,
+          semaphores_ptr,
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  }
+
+  return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
+}
+
+at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::Tensor& weight,
+    const at::Tensor& sum_dy,
+    const at::Tensor& sum_dy_xmu,
+    const at::Tensor& count) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.const_data_ptr<accscalar_t>(),
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          count.const_data_ptr<int>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          count.numel(),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  } else {
+    if (weight.defined()) {
+      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
+        " is not supported with weight.scalar_type() ", weight.scalar_type());
+    }
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          count.const_data_ptr<int>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          count.numel(),
+          reduction_size,
+          stride);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    });
+  }
+
+  return grad_input;
+}
+
+at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
+    const at::Tensor& grad_output,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& inv_std,
+    const at::Tensor& weight,
+    const at::Tensor& sum_dy,
+    const at::Tensor& sum_dy_xmu) {
+  const auto stride = input.sizes()[1];
+  const auto reduction_size = input.numel() / stride;
+  auto norm_fct = 1.0 / reduction_size;
+
+  // Input is guaranteed to be channels-last compatible
+  at::Tensor grad_input = at::empty_like(input);
+
+  dim3 block;
+  dim3 grid;
+  flexible_launch_configs(reduction_size, stride, block, grid);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+
+    if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
+      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.const_data_ptr<accscalar_t>(),
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          static_cast<accscalar_t>(norm_fct),
+          reduction_size,
+          stride);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+    } else {
+      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
+          <<<grid, block, 0, stream>>>(
+          grad_output.const_data_ptr<scalar_t>(),
+          input.const_data_ptr<scalar_t>(),
+          mean.const_data_ptr<accscalar_t>(),
+          inv_std.const_data_ptr<accscalar_t>(),
+          weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
+          sum_dy.const_data_ptr<accscalar_t>(),
+          sum_dy_xmu.const_data_ptr<accscalar_t>(),
+          grad_input.mutable_data_ptr<scalar_t>(),
+          static_cast<accscalar_t>(norm_fct),
+          reduction_size,
+          stride);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
+  });
+
+  return grad_input;
+}
+
+} } // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..79ecec1981f376e528af2674a5822fa01cc6fd00
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/PersistentSoftmax.cuh
@@ -0,0 +1,401 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace {
+
+int log2_ceil(int value) {
+    int log2_value = 0;
+    while ((1 << log2_value) < value) ++log2_value;
+    return log2_value;
+}
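+
+// [Illustrative sketch, not part of upstream] A constexpr twin of log2_ceil
+// (hypothetical) so the rounding behaviour can be checked at compile time:
+// e.g. 1 -> 0, 33 -> 6, 1024 -> 10.
+constexpr int log2_ceil_illustration(int value) {
+    int log2_value = 0;
+    while ((1 << log2_value) < value) ++log2_value;
+    return log2_value;
+}
+static_assert(log2_ceil_illustration(1) == 0, "log2_ceil rounds up to the next power of two");
+static_assert(log2_ceil_illustration(33) == 6, "33 rounds up to 64 = 2^6");
+static_assert(log2_ceil_illustration(1024) == 10, "exact powers of two are unchanged");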
+
+template <typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+template <typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+    ReduceOp r;
+    #pragma unroll
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+        #pragma unroll
+        for (int i = 0;  i < WARP_BATCH;  ++i) {
+            acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE);
+            sum[i] = r(sum[i], b);
+        }
+    }
+}
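+
+// [Illustrative sketch, not part of upstream] Host analog of the XOR-butterfly
+// pattern used by warp_reduce (hypothetical helper, fixed 8 "lanes"): after
+// log2(8) rounds every slot holds the reduction over all 8 inputs, which is why
+// the kernels below can read the result from any lane without a broadcast.
+inline void butterfly_allreduce_illustration(float (&lanes)[8]) {
+    for (int offset = 8 / 2; offset > 0; offset /= 2) {
+        float next[8];
+        for (int lane = 0; lane < 8; ++lane) {
+            next[lane] = lanes[lane] + lanes[lane ^ offset];  // partner = lane XOR offset
+        }
+        for (int lane = 0; lane < 8; ++lane) {
+            lanes[lane] = next[lane];
+        }
+    }
+}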
+
+// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
+// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
+// The template arguments have the following meaning:
+// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
+// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
+// A "WARP" contains "C10_WARPS_SIZE" threads, these treads are guaranteed to belong to the same warp.
+// This is important because it means only __shfl_ instructions are required for reductions.
+// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
+// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future arch.
+// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs.
+// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
+// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
+// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
+// This allows SoftMax to be fused with a cast immediately following the SoftMax.
+// The mask should have the same shape as input, with a boolean indicate if the value is masked.
+// The head_chunk_size is only used for transformer mask softmax, equals to H * D * D.
+// For instance:
+// input_t=half,  acc_t=float, output_t=half  => read half tensor, float accumulators, write half tensor.
+// input_t=half,  acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
+// input_t=float, acc_t=float, output_t=half  => read float tensor, float accumulators, write half tensor.
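+// Worked example (illustrative): element_count = 200 gives log2_elements = 8 and
+// next_power_of_two = 256, so WARP_SIZE = min(256, C10_WARP_SIZE) = 32 on CUDA,
+// WARP_ITERATIONS = 256 / 32 = 8 and WARP_BATCH = 1 (256 > 128): each warp owns
+// one sample and each lane strides over up to 8 of its elements.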
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
+__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+
+    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to be computed within this WARP.
+    int local_batches = batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+    int idx_offset = first_batch * stride + local_idx;
+
+    src += idx_offset;
+    dst += idx_offset;
+
+    if (is_transformer_mask) {
+        mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
+    } else {
+        mask += idx_offset;
+    }
+    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
+    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
+    // the nested loops.
+    // This should have no impact on performance because the loops are unrolled anyway.
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            int element_index = local_idx + it * WARP_SIZE;
+            if (element_index < batch_element_count) {
+                elements[i][it] = src[i*element_count+it*WARP_SIZE];
+            } else {
+                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+        bool is_meaningful_max = false;
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            if (is_masked) {
+                int idx = it*WARP_SIZE;
+                if ((idx + local_idx) < batch_element_count) {
+                    if (!is_transformer_mask) {
+                        idx += i*element_count;
+                    }
+                    if (!mask[idx]) {
+                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+                        is_meaningful_max = true;
+                    }
+                }
+            } else {
+                max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
+            }
+        }
+        if (is_masked) {
+            if (!is_meaningful_max) {
+                max_value[i] = -std::numeric_limits<acc_t>::infinity();
+            }
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            if (!is_masked) {
+                if (is_log_softmax) {
+                    sum[i] += std::exp(elements[i][it] - max_value[i]);
+                } else {
+                    elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                    sum[i] += elements[i][it];
+                }
+            } else {
+                int idx = it*WARP_SIZE;
+                bool valid = (idx + local_idx) < batch_element_count;
+                if (!is_transformer_mask) {
+                    idx += i*element_count;
+                }
+                if (valid) {
+                    if (!mask[idx]) {
+                        if (is_log_softmax) {
+                            sum[i] += std::exp(elements[i][it] - max_value[i]);
+                        } else {
+                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                            sum[i] += elements[i][it];
+                        }
+                    } else {
+                        if (!is_log_softmax) {
+                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
+                            elements[i][it] = 0;
+                        }
+                    }
+                } else {
+                    if (!is_log_softmax) {
+                        elements[i][it] = 0.;
+                    }
+                }
+            }
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        if (is_log_softmax) sum[i] = std::log(sum[i]);
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            int element_index = local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                if (is_log_softmax) {
+                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
+                } else if (sum[i] == 0) {
+                    dst[i*element_count+it*WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
+                } else {
+                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i];
+                }
+            } else {
+                break;
+            }
+        }
+    }
+}
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
+__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+
+    // batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to be computed within this WARP.
+    int local_batches = batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x % WARP_SIZE;
+
+    // the first element to process by the current thread
+    int thread_offset = first_batch * stride + local_idx;
+    grad += thread_offset;
+    output += thread_offset;
+    gradInput += thread_offset;
+    if (is_masked) {
+        mask += thread_offset;
+    }
+
+    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
+    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
+    // the nested loops.
+    // This should have no impact on performance because the loops are unrolled anyway.
+
+    // load data from global memory
+    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            int element_index = local_idx + it * WARP_SIZE;
+            if (element_index < batch_element_count) {
+                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE];
+                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
+            } else {
+                grad_reg[i][it] = acc_t(0);
+                output_reg[i][it] = acc_t(0);
+            }
+        }
+    }
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) {
+                sum[i] += grad_reg[i][it];
+            }
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            int element_index = local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                if (is_masked && mask[i*element_count+it*WARP_SIZE]) {
+                    gradInput[i*element_count+it*WARP_SIZE] = 0;
+                }
+                // compute gradients
+                else if (is_log_softmax) {
+                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
+                } else {
+                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
+                }
+            }
+        }
+    }
+}
+
+} // end of anonymous namespace
+
+template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
+void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
+{
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
+    if (softmax_elements == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = at::cuda::warp_size();
+        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E:                    \
+            softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked>   \
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,   \
+                    src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \
+            C10_CUDA_KERNEL_LAUNCH_CHECK();                                       \
+            break;
+
+            LAUNCH_SOFTMAX_WARP_FORWARD(0);  // 1
+            LAUNCH_SOFTMAX_WARP_FORWARD(1);  // 2
+            LAUNCH_SOFTMAX_WARP_FORWARD(2);  // 4
+            LAUNCH_SOFTMAX_WARP_FORWARD(3);  // 8
+            LAUNCH_SOFTMAX_WARP_FORWARD(4);  // 16
+            LAUNCH_SOFTMAX_WARP_FORWARD(5);  // 32
+            LAUNCH_SOFTMAX_WARP_FORWARD(6);  // 64
+            LAUNCH_SOFTMAX_WARP_FORWARD(7);  // 128
+            LAUNCH_SOFTMAX_WARP_FORWARD(8);  // 256
+            LAUNCH_SOFTMAX_WARP_FORWARD(9);  // 512
+            LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
+            default:
+                break;
+        }
+    }
+}
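+
+// [Illustrative sketch, not part of upstream] The launch geometry above can be
+// summarised by a small helper (hypothetical): for softmax_elements = 200 and a
+// hardware warp size of 32 it yields threads = dim3(32, 4, 1) and 4 rows per block.
+inline int dispatch_softmax_rows_per_block_illustration(int softmax_elements, int hw_warp_size) {
+    int log2_elements = log2_ceil(softmax_elements);
+    int next_power_of_two = 1 << log2_elements;
+    // mirrors the WARP_SIZE / WARP_BATCH choices made inside softmax_warp_forward
+    int warp_size = (next_power_of_two < hw_warp_size) ? next_power_of_two : hw_warp_size;
+    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int threads_per_block = 128;
+    int warps_per_block = threads_per_block / warp_size;
+    return warps_per_block * batches_per_warp;
+}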
+
+template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
+void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
+{
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
+    if (softmax_elements == 0) {
+       return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
+        int warp_size = at::cuda::warp_size();
+        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E:                      \
+            softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>       \
+                (grad_input, grad, output, batch_count, softmax_elements_stride, \
+                softmax_elements, mask);                                              \
+            C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
+            break;
+
+            LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1
+            LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2
+            LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4
+            LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8
+            LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16
+            LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32
+            LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64
+            LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128
+            LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256
+            LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512
+            LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024
+            default:
+                break;
+        }
+    }
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..f75054e47a6cf4b401f85fb53213178b08e33a17
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Pow.cuh
@@ -0,0 +1,58 @@
+#pragma once
+#include 
+#include 
+
+namespace at { namespace native {
+
+namespace {
+
+
+// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt.
+// So we need to define the functions with the explicit function signatures.
+// As for pow, the following signatures are defined as the device function:
+//   pow(float, int)
+//   pow(double, int)
+//   pow(float, float)
+//   pow(double, double)
+#ifdef _MSC_VER
+// Functions for pow
+// pow for at::Half
+static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) {
+  return static_cast<at::Half>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
+}
+// pow for at::BFloat16
+static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) {
+  return static_cast<at::BFloat16>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
+}
+// pow (floating, floating/int)
+template <typename Base_type, typename Exp_type>
+static inline __host__ __device__ typename std::enable_if<std::is_floating_point<Base_type>::value && (std::is_same<Base_type, Exp_type>::value || std::is_same<Exp_type, int>::value), Base_type>::type
+  pow_(Base_type base, Exp_type exp) {
+  return std::pow(base, exp);
+}
+// pow (Otherwise)
+template <typename Base_type, typename Exp_type>
+static inline __host__ __device__ typename std::enable_if<!std::is_same<Base_type, Exp_type>::value && !std::is_same<Exp_type, int>::value, Base_type>::type
+  pow_(Base_type base, Exp_type exp) {
+  return static_cast<Base_type>(std::pow(static_cast<double>(base), static_cast<double>(exp)));
+}
+#else
+template <typename Base_type, typename Exp_type>
+static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) {
+  return ::pow(base, exp);
+}
+#endif
+
+template <typename T>
+static inline __host__ __device__ std::enable_if_t<std::is_integral<T>::value, T> pow_(
+    T base, T exp) {
+  return at::native::powi(base, exp);
+}
+
+template <typename T>
+static inline __host__ __device__ c10::complex<T> pow_(c10::complex<T> base, c10::complex<T> exp) {
+  return c10_complex_math::pow(base, exp);
+}
+
+} // namespace
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3f908031f1e938b65d87c77d6a0d4182bc5747de
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Randperm.cuh
@@ -0,0 +1,58 @@
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace {
+
+// See note [Algorithm of randperm]
+template <typename T, typename scalar_t>
+__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+
+  // find the beginning of islands
+  if (tid >= n - 1) return;  // out of range
+  if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return;  // not in an island
+  if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return;  // not the beginning of an island
+
+  // find the size of islands
+  int island_size = 0;
+  do { island_size++; }
+  while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask));
+
+  // do random permutation inside each island.
+  data += tid;
+  auto seeds = at::cuda::philox::unpack(philox_args);
+  curandStatePhilox4_32_10_t state;
+  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
+  for (int i = island_size - 1; i > 0; i--) {
+    unsigned int r = curand(&state) % (i + 1);
+    if (i != r) {
+      scalar_t tmp = data[i];
+      data[i] = data[r];
+      data[r] = tmp;
+    }
+  }
+}
+
+// See note [Algorithm of randperm]
+template <typename T, typename scalar_t>
+void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, c10::optional<at::Generator> &gen_) {
+  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
+  int64_t counter_offset = n;
+  at::PhiloxCudaState rng_engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
+  }
+  T mask = static_cast<T>((1UL << bits) - 1);
+  randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
+    keys, data, mask, n, rng_engine_inputs);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
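+
+// [Illustrative sketch, not part of upstream] The per-island loop inside
+// randperm_handle_duplicate_keys_kernel is a backwards Fisher-Yates shuffle.
+// Host analog (hypothetical helper; rand_fn(k) must return a uniform value in [0, k)):
+template <typename scalar_t, typename RandFn>
+void fisher_yates_illustration(scalar_t *data, int n, RandFn rand_fn) {
+  for (int i = n - 1; i > 0; i--) {
+    int r = rand_fn(i + 1);
+    if (i != r) {
+      scalar_t tmp = data[i];
+      data[i] = data[r];
+      data[r] = tmp;
+    }
+  }
+}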
+
+}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..765a2d4724e036820b33590d34304060cab7d690
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Reduce.cuh
@@ -0,0 +1,1354 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at { namespace native {
+
+using at::detail::Array;
+
+static inline int64_t div_up(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+// returns floor(log2(n))
+static inline int last_pow2(int n) {
+  n |= (n >>  1);
+  n |= (n >>  2);
+  n |= (n >>  4);
+  n |= (n >>  8);
+  n |= (n >> 16);
+  return std::max(1, n - (n >> 1));
+}
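+// e.g. last_pow2(1) == 1, last_pow2(300) == 256, last_pow2(1024) == 1024.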
+
+// returns reduced fraction numerator & denominator
+C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
+  // get GCD of num and denom using Euclid's algorithm.
+  // Can replace this with std::gcd if we ever support c++17.
+  size_t a = denominator;
+  size_t b = numerator;
+  while (b != 0) {
+      a %= b;
+      // swap(a,b)
+      size_t tmp = a;
+      a = b;
+      b = tmp;
+  }
+
+  // a is now the GCD
+  numerator /= a;
+  denominator /= a;
+}
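+// For example, (numerator, denominator) = (sizeof(double), sizeof(float)) = (8, 4)
+// reduces to (2, 1): the accumulation buffer needs 2 bytes for every output byte.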
+
+//template for changing MAX_NUM_THREADS based on op dtype
+template 
+struct mnt_wrapper {
+  static constexpr int MAX_NUM_THREADS = 512;
+};
+
+template <>
+struct mnt_wrapper <c10::complex<double>>{
+  static constexpr int MAX_NUM_THREADS = 256;
+};
+
+constexpr int max_reduce_threads(c10::ScalarType type) {
+  return type == kComplexDouble ? 256 : 512;
+}
+
+struct ReduceConfig {
+  static constexpr int BLOCK_X = 0;
+  static constexpr int BLOCK_Y = 1;
+  static constexpr int CTA = 2;
+
+  static constexpr int input_vec_size = 4;
+
+  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
+    : element_size_bytes(element_size_bytes)
+    , num_inputs(num_inputs)
+    , num_outputs(num_outputs) {}
+  int element_size_bytes;
+  int num_inputs;
+  int num_outputs;
+  int step_input = 1;
+  int step_output = 1;
+  int ctas_per_output = 1;
+  int input_mult[3] = {0, 0, 0};
+  int output_mult[2] = {0, 0};
+
+  int block_width;
+  int block_height;
+  int num_threads;
+
+  bool vectorize_input = false;
+  int output_vec_size = 1;
+
+  template <typename T>
+  void set_block_dimension(int64_t dim0, int64_t dim1) {
+    const int max_num_threads = mnt_wrapper<T>::MAX_NUM_THREADS / output_vec_size;
+    int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
+    int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
+    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
+    block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
+    block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
+    num_threads = block_width * block_height;
+  }
+
+  int split_input(int parallelism) {
+    int step = step_input;
+    step_input *= parallelism;
+    return step;
+  }
+
+  int split_output(int parallelism) {
+    int step = step_output;
+    step_output *= parallelism;
+    return step;
+  }
+
+  dim3 block() const {
+    return dim3(block_width, block_height);
+  }
+
+  dim3 grid() const {
+    return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
+  }
+
+  C10_HOST_DEVICE bool should_block_x_reduce() const {
+    return input_mult[BLOCK_X] != 0;
+  }
+
+  C10_HOST_DEVICE bool should_block_y_reduce() const {
+    return input_mult[BLOCK_Y] != 0;
+  }
+
+  C10_HOST_DEVICE bool should_global_reduce() const {
+    return input_mult[CTA] != 0;
+  }
+
+  C10_DEVICE bool should_store(int output_idx) const {
+    return output_idx < num_outputs &&
+      (!should_block_x_reduce() || threadIdx.x == 0) &&
+      (!should_block_y_reduce() || threadIdx.y == 0);
+  }
+
+  C10_DEVICE bool should_reduce_tail() const {
+    return (!should_block_y_reduce() || threadIdx.y == 0) &&
+      (!should_global_reduce() || blockIdx.y == 0);
+  }
+
+  C10_HOST_DEVICE int input_idx() const {
+    int lane = threadIdx.x;
+    int warp = threadIdx.y;
+    int cta2 = blockIdx.y;
+    return (lane * input_mult[BLOCK_X] +
+            warp * input_mult[BLOCK_Y] +
+            cta2 * input_mult[CTA]);
+  }
+
+  template <int output_vec_size>
+  C10_HOST_DEVICE int output_idx() const {
+    int lane = threadIdx.x;
+    int warp = threadIdx.y;
+    int cta1 = blockIdx.x;
+    return (lane * output_mult[BLOCK_X] +
+            warp * output_mult[BLOCK_Y] +
+            cta1 * step_output) * output_vec_size;
+  }
+
+  C10_DEVICE int shared_memory_offset(int offset) const {
+    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
+  }
+
+  C10_DEVICE int staging_memory_offset(int cta2) const {
+    int offset = cta2 + blockIdx.x * gridDim.y;
+    if (!should_block_x_reduce()) {
+      offset = threadIdx.x + offset * blockDim.x;
+    }
+    return offset;
+  }
+
+  int shared_memory_size() const {
+    if (!should_block_y_reduce() &&
+        (!should_block_x_reduce() ||
+         block_width <= at::cuda::warp_size())) {
+      return 0;
+    }
+    return element_size_bytes * num_threads * output_vec_size;
+  }
+
+  int64_t global_memory_size() const {
+    if (!should_global_reduce()) {
+      return 0;
+    }
+    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
+    if (!should_block_x_reduce()) {
+      size *= block().x * output_vec_size;
+    }
+    return size;
+  }
+
+  int semaphore_size() const {
+    if (!should_global_reduce()) {
+      return 0;
+    }
+    return sizeof(int) * grid().x;
+  }
+
+  int values_per_thread() const {
+    return div_up(num_inputs, step_input);
+  }
+};
+
+std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
+
+template<int nt, int output_vec_size, typename R>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void reduce_kernel(R reduction) {
+  reduction.template run<output_vec_size>();
+}
+
+template <typename index_t>
+static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
+  int num_reduce_dims = iter.num_reduce_dims();
+  int num_output_dims = iter.ndim() - num_reduce_dims;
+  int input_index = iter.ntensors() - 1;
+  int output_index = 0;
+  std::array<const int64_t*, 2> strides = {
+    iter.strides(output_index).data() + num_reduce_dims,
+    iter.strides(input_index).data() + num_reduce_dims,
+  };
+  auto shape = iter.shape().data() + num_reduce_dims;
+  return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
+}
+
+template <typename index_t>
+static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
+  int num_reduce_dims = iter.num_reduce_dims();
+  int input_index = iter.ntensors() - 1;
+  std::array<const int64_t*, 1> strides = {
+    iter.strides(input_index).data(),
+  };
+  return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
+}
+
+template <typename out_scalar_t, typename func_t>
+struct func_wrapper_t {
+  using arg_t = typename binary_function_traits<func_t>::arg1_t;
+  using scalar_t = typename binary_function_traits<func_t>::arg2_t;
+
+  func_t combine;
+  static inline __device__ out_scalar_t project(arg_t arg) {
+    return (out_scalar_t) arg;
+  }
+  static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
+    return WARP_SHFL_DOWN(arg, offset);
+  }
+
+  static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
+    return acc;
+  }
+
+  func_wrapper_t(const func_t& op) : combine(op) {
+  }
+
+  // wrap a normal reduction that ignores the index
+  __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const {
+    return combine(acc, val);
+  }
+};
+
+template <typename scalar_t, typename func_t>
+func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
+  return func_wrapper_t<scalar_t, func_t> { op };
+}
+
+template 
+struct ReduceJitOp {
+//ReduceJitOp is almost like ReduceOp, but it doesn't have ops functor that specifies reduction operations
+//Maybe we can find a way to unify ReduceOp and ReduceJitOp
+  using InputCalculator = OffsetCalculator<1, uint32_t>;
+  using OutputCalculator = OffsetCalculator<2, uint32_t>;
+  //TODO for now arg_t is always opmath_t of the input, later we'll need to change it
+  using arg_t = at::opmath_type;
+
+  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+  //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor,
+  //not just wrapper
+  arg_t ident;
+  ReduceConfig config;
+  InputCalculator input_calc;
+  OutputCalculator output_calc;
+  const void* src;
+  const char* dst[2]; //it accepts at most two destinations
+  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+  // output is not permissible
+  void* acc_buf;
+  // cta_buf used for accumulation between blocks during global reduction
+  void* cta_buf;
+  int* semaphores;
+  int64_t base_idx;
+  bool accumulate;
+  bool final_output;
+  int noutputs;
+
+  ReduceJitOp(
+      ReduceConfig config,
+      InputCalculator input_calc,
+      OutputCalculator output_calc,
+      const void* src,
+      char* dst0,
+      optional<char*> dst1,
+      void* acc_buf,
+      void* cta_buf,
+      int* semaphores,
+      arg_t ident,
+      int noutputs,
+      int64_t base_idx)
+      : ident(ident),
+        config(config),
+        input_calc(input_calc),
+        output_calc(output_calc),
+        src(src),
+        acc_buf(acc_buf),
+        cta_buf(cta_buf),
+        semaphores(semaphores),
+        base_idx(base_idx),
+        noutputs(noutputs) {
+    dst[0] = dst0;
+    if (dst1.has_value()) {
+      dst[1] = dst1.value();
+    }
+  }
+};
+
+template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
+struct ReduceOp {
+  using traits = function_traits<decltype(&ops_t::reduce)>;
+  using arg_t = typename std::decay<typename traits::template arg<0>::type>::type;
+
+  using InputCalculator = OffsetCalculator<1, index_t>;
+  using OutputCalculator = OffsetCalculator<2, index_t>;
+
+  static constexpr bool can_accumulate_in_output =
+    std::is_convertible<arg_t, out_scalar_t>::value
+    && std::is_convertible<out_scalar_t, arg_t>::value;
+
+  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+
+  ops_t ops;
+  arg_t ident;
+  ReduceConfig config;
+  InputCalculator input_calc;
+  OutputCalculator output_calc;
+  const void* src;
+  const char* dst[2]; //it accepts at most two destinations
+  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+  // output is not permissible
+  void* acc_buf;
+  // cta_buf used for accumulation between blocks during global reduction
+  void* cta_buf;
+  int* semaphores;
+  int64_t base_idx;
+  bool accumulate;
+  bool final_output;
+  int noutputs;
+
+  ReduceOp(
+      ops_t ops,
+      ReduceConfig config,
+      InputCalculator input_calc,
+      OutputCalculator output_calc,
+      const void* src,
+      char* dst0,
+      optional<char*> dst1,
+      void* acc_buf,
+      void* cta_buf,
+      int* semaphores,
+      arg_t ident,
+      int noutputs,
+      int64_t base_idx)
+      : ops(ops),
+        ident(ident),
+        config(config),
+        input_calc(input_calc),
+        output_calc(output_calc),
+        src(src),
+        acc_buf(acc_buf),
+        cta_buf(cta_buf),
+        semaphores(semaphores),
+        base_idx(base_idx),
+        noutputs(noutputs) {
+    dst[0] = dst0;
+    if (dst1.has_value()) {
+      dst[1] = dst1.value();
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE void run() const {
+    extern __shared__ char shared_memory[];
+    index_t output_idx = config.output_idx();
+    index_t input_idx = config.input_idx();
+    auto base_offsets1 = output_calc.get(output_idx)[1];
+
+    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    arg_vec_t value;
+
+    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
+      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
+      value = thread_reduce<output_vec_size>(input_slice);
+    }
+
+    if (config.should_block_y_reduce()) {
+      value = block_y_reduce<output_vec_size>(value, shared_memory);
+    }
+    if (config.should_block_x_reduce()) {
+      value = block_x_reduce<output_vec_size>(value, shared_memory);
+    }
+
+    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
+    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
+    offset_vec_t base_offsets;
+    out_ptr_vec_t out;
+
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      base_offsets[i] = output_calc.get(output_idx + i)[0];
+      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+    }
+
+    arg_vec_t* acc = nullptr;
+    if (acc_buf != nullptr) {
+      size_t numerator = sizeof(arg_t);
+      size_t denominator = sizeof(out_scalar_t);
+      reduce_fraction(numerator, denominator);
+      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
+    }
+
+    if (config.should_global_reduce()) {
+      value = global_reduce<output_vec_size>(value, acc, shared_memory);
+    } else if (config.should_store(output_idx)) {
+      if (accumulate) {
+        #pragma unroll
+        for (int i = 0; i < output_vec_size; i++) {
+          value[i] = ops.translate_idx(value[i], base_idx);
+        }
+      }
+
+      if (acc == nullptr) {
+        if (accumulate) {
+          value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
+        }
+        if (final_output) {
+          set_results_to_output(value, base_offsets);
+        } else {
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
+          }
+        }
+      } else {
+        if (accumulate) {
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.combine((*acc)[i], value[i]);
+          }
+        }
+        if (final_output) {
+          set_results_to_output(value, base_offsets);
+        } else {
+          *acc = value;
+        }
+      }
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
+    if (config.vectorize_input) {
+      CUDA_KERNEL_ASSERT(output_vec_size == 1);
+      // reduce at the header of input_slice where memory is not aligned,
+      // so that thread_reduce will have an aligned memory to work on.
+      return {input_vectorized_thread_reduce_impl(data)};
+    } else {
+      index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
+      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
+      if (is_contiguous) {
+        return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
+      } else if (input_calc.dims == 1) {
+        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
+      } else {
+        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
+      }
+    }
+  }
+
+  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
+    index_t end = config.num_inputs;
+
+    // Handle the head of input slice where data is not aligned
+    arg_t value = ident;
+    constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
+    constexpr int align_elements = align_bytes / sizeof(scalar_t);
+    int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
+    if (shift > 0) {
+      data -= shift;
+      end += shift;
+      if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
+        value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift);
+      }
+      end -= align_elements;
+      data += align_elements;
+      shift = align_elements - shift;
+    }
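+    // Illustrative example (hypothetical values): for scalar_t = float and input_vec_size = 4,
+    // align_bytes = 16 and align_elements = 4. If `data` starts 8 bytes past a 16-byte boundary,
+    // shift = 2: the threads with threadIdx.x == 2 and 3 (when the tail should be reduced) pick up
+    // the two misaligned leading elements, `data` is advanced to the aligned boundary, and `end`
+    // and `shift` are adjusted so the indices handed to ops.reduce still refer to positions in the
+    // original slice.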
+
+    // Do the vectorized reduction
+    using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
+
+    index_t idx = config.input_idx();
+    const index_t stride = config.step_input;
+
+    // Multiple accumulators to remove dependency between unrolled loops.
+    arg_t value_list[input_vec_size];
+    value_list[0] = value;
+
+    #pragma unroll
+    for (int i = 1; i < input_vec_size; i++) {
+      value_list[i] = ident;
+    }
+
+    while (idx * input_vec_size + input_vec_size - 1 < end) {
+      const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
+      #pragma unroll
+      for (index_t i = 0; i < input_vec_size; i++) {
+        value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i);
+      }
+      idx += stride;
+    }
+
+    // tail
+    index_t tail_start = end - end % input_vec_size;
+    if (config.should_reduce_tail()) {
+      int idx = tail_start + threadIdx.x;
+      if (idx < end) {
+        const auto value = c10::load(data + idx);
+        value_list[0] = ops.reduce(value_list[0], value, idx + shift);
+      }
+    }
+
+    // combine accumulators
+    #pragma unroll
+    for (int i = 1; i < input_vec_size; i++) {
+      value_list[0] = ops.combine(value_list[0], value_list[i]);
+    }
+    return value_list[0];
+  }
+
+  template <int output_vec_size, typename offset_calc_t>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
+    index_t idx = config.input_idx();
+    const index_t end = config.num_inputs;
+    const index_t stride = config.step_input;
+
+    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
+
+    // Multiple accumulators to remove dependency between unrolled loops.
+    arg_vec_t value_list[vt0];
+
+    #pragma unroll
+    for (int i = 0; i < vt0; i++) {
+      #pragma unroll
+      for (int j = 0; j < output_vec_size; j++) {
+        value_list[i][j] = ident;
+      }
+    }
+
+    load_t values[vt0];
+
+    while (idx + (vt0 - 1) * stride < end) {
+      #pragma unroll
+      for (index_t i = 0; i < vt0; i++) {
+        const auto offset = calc(idx + i * stride) / output_vec_size;
+        values[i] = memory::load_vector<output_vec_size>(data_, offset);
+      }
+      #pragma unroll
+      for (index_t i = 0; i < vt0; i++) {
+        #pragma unroll
+        for (index_t j = 0; j < output_vec_size; j++) {
+          value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
+        }
+      }
+      idx += stride * vt0;
+    }
+
+    // tail
+    int idx_ = idx;
+    #pragma unroll
+    for (index_t i = 0; i < vt0; i++) {
+      if (idx >= end) {
+        break;
+      }
+      const auto offset = calc(idx) / output_vec_size;
+      values[i] = memory::load_vector<output_vec_size>(data_, offset);
+      idx += stride;
+    }
+    idx = idx_;
+    #pragma unroll
+    for (index_t i = 0; i < vt0; i++) {
+      if (idx >= end) {
+        break;
+      }
+      #pragma unroll
+      for (index_t j = 0; j < output_vec_size; j++) {
+        value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
+      }
+      idx += stride;
+    }
+
+    // combine accumulators
+    #pragma unroll
+    for (int i = 1; i < vt0; i++) {
+      #pragma unroll
+      for (index_t j = 0; j < output_vec_size; j++) {
+        value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
+      }
+    }
+    return value_list[0];
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    int dim_x = blockDim.x;
+    args_vec_t* shared = (args_vec_t*)shared_memory;
+    if (dim_x > warpSize) {
+      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+      shared[address_base] = value;
+      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+        __syncthreads();
+        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+          args_vec_t other = shared[address_base + offset];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.combine(value[i], other[i]);
+          }
+          shared[address_base] = value;
+        }
+      }
+      dim_x = warpSize;
+    }
+
+    __syncthreads();
+
+    for (int offset = 1; offset < dim_x; offset <<= 1) {
+      #pragma unroll
+      for (int i = 0; i < output_vec_size; i++) {
+        arg_t other = ops.warp_shfl_down(value[i], offset);
+        value[i] = ops.combine(value[i], other);
+      }
+    }
+    return value;
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    args_vec_t* shared = (args_vec_t*)shared_memory;
+    shared[config.shared_memory_offset(0)] = value;
+    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
+      __syncthreads();
+      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+        args_vec_t other = shared[config.shared_memory_offset(offset)];
+        #pragma unroll
+        for (int i = 0; i < output_vec_size; i++) {
+          value[i] = ops.combine(value[i], other[i]);
+        }
+        shared[config.shared_memory_offset(0)] = value;
+      }
+    }
+    return value;
+  }
+
+  C10_DEVICE bool mark_block_finished() const {
+    __shared__ bool is_last_block_done_shared;
+
+    __syncthreads();
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
+    }
+
+    __syncthreads();
+
+    return is_last_block_done_shared;
+  }
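+  // The counter in semaphores[blockIdx.x] is bumped once per CTA along blockIdx.y; the CTA that
+  // observes gridDim.y - 1 previous arrivals knows every other CTA working on this output column
+  // has already published its partial result to the staging buffer and can safely do the final merge.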
+
+  template <int output_vec_size, bool can_acc>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
+    at::detail::Array<out_scalar_t*, output_vec_size> out,
+    at::detail::Array<arg_t, output_vec_size> value,
+    typename std::enable_if<can_acc>::type* = nullptr
+  ) const {
+    at::detail::Array<arg_t, output_vec_size> ret;
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      ret[i] = ops.combine(*(out[i]), value[i]);
+    }
+    return ret;
+  }
+
+  template <bool can_acc>
+  C10_DEVICE out_scalar_t get_accumulated_output(
+    out_scalar_t* out, arg_t value,
+    typename std::enable_if<can_acc>::type* = nullptr
+  ) const {
+    CUDA_KERNEL_ASSERT(!final_output);
+    return (out_scalar_t)value;
+  }
+
+  // This function should never be called --
+  // it's the version of `accumulate_in_output`
+  // when accumulation in the output is not possible.
+  template <int output_vec_size, bool can_acc>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
+    at::detail::Array<out_scalar_t*, output_vec_size>,
+    at::detail::Array<arg_t, output_vec_size>,
+    typename std::enable_if<!can_acc>::type* = nullptr
+  ) const {
+    CUDA_KERNEL_ASSERT(false);
+    return arg_t {};
+  }
+
+  // This function should never be called --
+  // it's the version of `get_accumulated_output`
+  // when accumulation in the output is not possible.
+  template <bool can_acc>
+  C10_DEVICE out_scalar_t get_accumulated_output(
+    out_scalar_t* out, arg_t value,
+    typename std::enable_if<!can_acc>::type* = nullptr
+  ) const {
+    CUDA_KERNEL_ASSERT(false);
+    return *out;
+  }
+
+  template <class T>
+  C10_DEVICE void set_results(const T x, const index_t base_offset) const {
+    CUDA_KERNEL_ASSERT(noutputs == 1);
+    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
+    *res = x;
+  }
+
+  //Currently implemented for max of two outputs
+  template <class T1, class T2>
+  C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
+    if (noutputs >= 1) {
+      auto res0 = (T1*)((char*)dst[0] + base_offset);
+      *res0 = x.first;
+    }
+    if (noutputs >= 2) {
+      // the base offset is computed assuming the element size is sizeof(T1), so we need to
+      // correct it to obtain the right base offset for the second (T2) output
+      auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
+      *res1 = x.second;
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
+    CUDA_KERNEL_ASSERT(final_output);
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      set_results(ops.project(value[i]), base_offset[i]);
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
+    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
+    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
+    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
+
+    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
+    index_t output_idx = config.output_idx();
+    offset_vec_t base_offsets;
+    out_ptr_vec_t out;
+
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      base_offsets[i] = output_calc.get(output_idx + i)[0];
+      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+    }
+
+    bool should_store = config.should_store(output_idx);
+    if (should_store) {
+      index_t offset = config.staging_memory_offset(blockIdx.y);
+      reduce_buffer[offset] = value;
+    }
+
+    __threadfence(); // make sure writes are globally visible
+    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
+    bool is_last_block_done = mark_block_finished();
+
+    if (is_last_block_done) {
+      value = ident;
+      if (config.should_block_x_reduce()) {
+        index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
+        index_t step = blockDim.x * blockDim.y;
+        for (; input_offset < config.ctas_per_output; input_offset += step) {
+          index_t idx = config.staging_memory_offset(input_offset);
+          arg_vec_t next = reduce_buffer[idx];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.combine(value[i], next[i]);
+          }
+        }
+      } else {
+        index_t input_offset = threadIdx.y;
+        index_t step = blockDim.y;
+        for (; input_offset < config.ctas_per_output; input_offset += step) {
+          index_t idx = config.staging_memory_offset(input_offset);
+          arg_vec_t next = reduce_buffer[idx];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.combine(value[i], next[i]);
+          }
+        }
+      }
+      value = block_y_reduce<output_vec_size>(value, shared_memory);
+      if (config.should_block_x_reduce()) {
+        value = block_x_reduce<output_vec_size>(value, shared_memory);
+      }
+      if (should_store) {
+        if (accumulate) {
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.translate_idx(value[i], base_idx);
+          }
+        }
+
+        if (acc == nullptr) {
+          if (accumulate) {
+            value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
+          }
+          if (final_output) {
+            set_results_to_output(value, base_offsets);
+          } else {
+            #pragma unroll
+            for (int i = 0; i < output_vec_size; i++) {
+              *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
+            }
+          }
+        } else {
+          if (accumulate) {
+            #pragma unroll
+            for (int i = 0; i < output_vec_size; i++) {
+              value[i] = ops.combine((*acc)[i], value[i]);
+            }
+          }
+          if (final_output) {
+            set_results_to_output(value, base_offsets);
+          } else {
+            *acc = value;
+          }
+        }
+      }
+    }
+
+    return value;
+  }
+};
+
+template<int max_threads, typename R>
+static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
+  dim3 block = config.block();
+  dim3 grid = config.grid();
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int shared_memory = config.shared_memory_size();
+
+  switch(config.output_vec_size) {
+  case 4:
+    reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    break;
+  case 2:
+    reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    break;
+  default:
+    reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+inline void launch_jitted_reduce_kernel(
+    std::mutex &jiterator_mutex,
+    std::array<at::cuda::jit::NvrtcFunction, 3> &fn_cache,
+    const at::cuda::jit::KernelDescriptor &desc,
+    int vt0, const ReduceConfig& config, void *reduction) {
+  dim3 block = config.block();
+  dim3 grid = config.grid();
+
+  int shared_memory = config.shared_memory_size();
+  at::cuda::jit::NvrtcFunction* fn_ptr;
+  switch(config.output_vec_size) {
+  case 4:
+    fn_ptr = &fn_cache[0];
+    break;
+  case 2:
+    fn_ptr = &fn_cache[1];
+    break;
+  default:
+    fn_ptr = &fn_cache[2];
+  }
+  if (!fn_ptr->function) {
+    int max_threads_codegen =
+        max_reduce_threads(desc.f_inputs_type) / config.output_vec_size;
+    auto code = at::cuda::jit::generate_reduction_code(
+        desc, vt0, true, false, config.output_vec_size, max_threads_codegen);
+
+    *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name);
+  }
+  constexpr int kernel_args = 1;
+  void* args[kernel_args];
+  args[0] = reduction;
+  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
+}
+
+
+class AccumulationBuffer {
+ public:
+  AccumulationBuffer() {}
+
+  AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) {
+    out_ptr_ = (char*)out_ptr;
+    if (out_t_size >= acc_t_size) {
+      // reusing output buffer for accumulation.
+      acc_ptr_ = (char*)out_ptr;
+      numerator_ = 1;
+      denominator_ = 1;
+    } else {
+      auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+      buffer_ = allocator.allocate(size);
+      acc_ptr_ = (char*)buffer_.get();
+      numerator_ = acc_t_size;
+      denominator_ = out_t_size;
+      reduce_fraction(numerator_, denominator_);
+    }
+  }
+
+  char* get_acc_slice(char* out_ptr) {
+    if (acc_ptr_ == nullptr) {
+      return nullptr;
+    }
+    return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_);
+  }
+
+ private:
+  char* acc_ptr_ = nullptr;
+  char* out_ptr_ = nullptr;
+  size_t numerator_;
+  size_t denominator_;
+  at::DataPtr buffer_;
+};
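+// Illustrative example (hypothetical sizes): with acc_t = float (4 bytes) and out_t = at::Half
+// (2 bytes), a separate accumulation buffer is allocated and numerator_/denominator_ reduce to
+// 2/1, so an output element that lives `off` bytes past out_ptr_ gets its accumulator slot at
+// acc_ptr_ + off * 2.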
+
+template <typename scalar_t>
+int get_output_vec_size(const TensorIterator &iter) {
+  int vec_size = 4;
+  auto update_vec_size = [&vec_size](uint64_t n) {
+    while(n % vec_size != 0) {
+      vec_size /= 2;
+    }
+  };
+
+  uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
+  update_vec_size(base_address);
+
+  const int output_index = iter.num_reduce_dims();
+  update_vec_size(iter.shape()[output_index]);
+
+  int j = 0;
+  for(auto i : iter.strides(iter.noutputs())) {
+    if (j != output_index) {
+      update_vec_size(i / sizeof(scalar_t));
+    }
+    j++;
+  }
+  return vec_size;
+}
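+// Illustrative example (hypothetical layout): for a float output whose base address is 16-byte
+// aligned, whose size along the first non-reduced dimension is divisible by 4, and whose other
+// output strides are multiples of 4 * sizeof(float), the loop above keeps vec_size at 4; if that
+// dimension had size 10 instead, vec_size would drop to 2.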
+
+template<typename arg_t, typename scalar_t, int vt0>
+ReduceConfig setReduceConfig(const TensorIterator& iter){
+  // Start by assuming that each thread handles a single output and all
+  // the inputs for that output.
+  int64_t num_outputs = iter.num_output_elements();
+  int64_t inputs_per_output = iter.numel() / num_outputs;
+  int input_index = iter.ntensors() - 1;
+
+  auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);
+
+  int64_t dim0;
+  int64_t dim1;
+  int64_t fastest_moving_stride;
+  bool reduction_on_fastest_striding_dimension;
+
+  if (iter.ndim() > 0) {
+    // Adjust block size to map block width to fastest changing dimension of input
+    // tensor. This grants the best possible memory accessing pattern, given that
+    // for non-contiguous tensor with space in between, we cannot have perfect
+    // memory coalescing.
+    reduction_on_fastest_striding_dimension =
+        (iter.num_reduce_dims() == iter.ndim()) ||
+        (iter.strides(/*arg=*/input_index)[0] <
+        iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
+    // Notice that dim0 & dim1 does NOT guarantee any launch configuration here!
+    // dim0 & dim1 are more like the upper bound of the block dimension. The
+    // actual launch config and reduction scheme is determined by setting values
+    // to `config.input_mult` and `config.output_mult`.
+    // We try to max out dim1 so that we have enough threads per CTA to deliver
+    // performance for larger problem size.
+    if (reduction_on_fastest_striding_dimension) {
+      // Map block.x to the fastest reducing dimension. It implies:
+      //   1. block_x_reduce is required.
+      //   2. block.y now max out to num_outputs.
+      dim0 = inputs_per_output;
+      dim1 = num_outputs;
+      fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
+    } else {
+      // Map block.x to the fastest non reducing dimension. It implies:
+      //   1. block_x_reduce is turned off.
+      //   2. block.y now max out to inputs_per_output.
+      dim0 = num_outputs;
+      dim1 = inputs_per_output;
+      fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
+    }
+  } else {
+    reduction_on_fastest_striding_dimension = true;
+    fastest_moving_stride = sizeof(scalar_t);
+    dim0 = 1;
+    dim1 = 1;
+  }
+
+  // We do vectorization to gain better memory access, there are two cases which we call
+  // "vectorize along input" and "vectorize along output". Note that the "input/output"
+  // here does not mean we are vectorizing load/store instructions. We always only vectorize
+  // load instructions.
+  //
+  // Case 1: "vectorize along input"
+  // This case happens when we are reducing along the fastest moving dimension. In this case,
+  // threads with the same threadIdx.y work on the same reduction cooperatively and will produce
+  // results for the same output. Values in each loaded vector therefore always correspond to the same output.
+  //
+  // Case 2: "vectorize along output"
+  // This case happens when the fastest moving dimension is not the dimension of reduction. In this
+  // case, threads with different threadIdx.x are independent and will produce results for different
+  // outputs. Values in each loaded vector therefore always correspond to different outputs.
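+  // Illustrative example (hypothetical shapes): summing a contiguous [4096, 4096] float tensor
+  // over dim=1 reduces along the fastest moving dimension, so the kernel vectorizes along the
+  // input (case 1); summing the same tensor over dim=0 leaves the fastest moving dimension
+  // un-reduced, so it vectorizes along the output instead (case 2).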
+  if (fastest_moving_stride == sizeof(scalar_t)) {
+    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
+      // Case 1: "vectorize along input"
+      // Note that if vt0 < ReduceConfig::input_vec_size, the register pressure could be high;
+      // in that case we should avoid vectorization.
+      config.vectorize_input = true;
+      dim0 /= config.input_vec_size;
+    } else if (!reduction_on_fastest_striding_dimension) {
+      // Case 2: "vectorize along output"
+      config.output_vec_size = get_output_vec_size<scalar_t>(iter);
+      dim0 /= config.output_vec_size;
+    }
+  }
+
+  // Adjust block_width and block_height
+  config.set_block_dimension(dim0, dim1);
+
+  int block_width = config.block_width;
+  int block_height = config.block_height;
+
+  if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) {
+    // Split the input across lanes if the input is contiguous in the reduced
+    // dimension. This will require reduction between threads using warp
+    // shuffle instructions and shared memory (if block_width > warpSize).
+    config.input_mult[0] = config.split_input(block_width);
+  } else {
+    // Otherwise split the output across lanes in a warp.
+    config.output_mult[0] = config.split_output(block_width);
+  }
+
+  constexpr int min_values_per_thread = 16;
+  constexpr int max_values_per_thread = 256;
+
+  if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
+    // Divide the input across warps in a thread-block, if that leaves at least
+    // 16 elements to be summed by each thread. This will require inter-warp
+    // reduction using shared memory.
+    config.input_mult[1] = config.split_input(block_height);
+  } else {
+    // Otherwise, each warp handles a separate output.
+    config.output_mult[1] = config.split_output(block_height);
+  }
+
+  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads;
+  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  const int target_grid_size = num_mp * blocks_per_sm;
+  int grid = config.grid().x;
+  if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) {
+    // Divide the input across thread-blocks if the amount of work per-thread
+    // is large enough and the size of the output is small enough. This will
+    // require a reduction using global memory.
+    // If we decide to split input across blocks, as long as we can get enough
+    // number of blocks (`target_grid_size`) to balance SM, we should still
+    // make the number of values per thread large for best performance.
+    int ctas_per_output1 = div_up(target_grid_size, grid);
+    int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread);
+    int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread);
+    // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have
+    // a large number of values to deal with. But we don't want values_per_thread to be larger than
+    // max_values_per_thread
+    config.ctas_per_output = std::max(std::min(ctas_per_output1, ctas_per_output2), ctas_per_output3);
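+    // Illustrative arithmetic (hypothetical numbers): with values_per_thread() == 1024,
+    // min/max_values_per_thread == 16/256 and div_up(target_grid_size, grid) == 8, the three
+    // candidates are 8, 64 and 4, so ctas_per_output = max(min(8, 64), 4) = 8 and each thread
+    // is left with 1024 / 8 = 128 values, inside the [16, 256] window.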
+    if (config.ctas_per_output > 1) {
+      config.input_mult[2] = config.split_input(config.ctas_per_output);
+    }
+  }
+  return config;
+};
+
+template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
+                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+  using traits = function_traits<decltype(&ops_t::reduce)>;
+  using arg_t = typename traits::template arg<0>::type;
+  // at::Half/at::ComplexHalf overflows easily as its range is very small.
+  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_half_or_chalf =
+      (std::is_same<at::Half, scalar_t>::value &&
+       std::is_same<at::Half, out_scalar_t>::value) ||
+      (std::is_same<c10::complex<at::Half>, scalar_t>::value &&
+       std::is_same<c10::complex<at::Half>, out_scalar_t>::value);
+  // at::BFloat16 has lower precision and can lead to rounding errors.
+  // So when scalar_t and out_scalar_t are at::BFloat16, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_bfloat16 =
+      (std::is_same<at::BFloat16, scalar_t>::value &&
+       std::is_same<at::BFloat16, out_scalar_t>::value);
+  static constexpr bool can_accumulate_in_output =
+      std::is_convertible<arg_t, out_scalar_t>::value &&
+      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
+
+  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+  // The acc_buf_ptr is a shared pointer. It is created on the first entry and
+  // reused by all recursive function calls.
+  if (acc_buf_ptr == NULL) {
+    // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
+    // when accumulation in output is not possible.
+    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+      int64_t output_memory_size = iter.element_size(0);
+      for (int dim = 0; dim < iter.ndim(); dim++) {
+        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+      }
+      output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
+                                                 sizeof(out_scalar_t),
+                                                 (char*) iter.data_ptr(0),
+                                                 output_memory_size * sizeof(arg_t)));
+    } else {
+      owned_buf_ptr.reset(new AccumulationBuffer());
+    }
+    acc_buf_ptr = owned_buf_ptr.get();
+  }
+
+  if (!can_use_32bit_indexing) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+      gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
+          acc_buf_ptr, sub_iter_base_idx);
+    }
+    return;
+  }
+
+  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+  char* out_data = (char*)iter.data_ptr(0);
+  const auto noutputs = iter.noutputs();
+  optional<char*> out_data_extra;
+  if (noutputs > 1) {
+    out_data_extra = (char*)iter.data_ptr(1);
+  } else {
+    out_data_extra = nullopt;
+  }
+  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
+  at::DataPtr buffer;
+  at::DataPtr semaphores;
+  if (config.should_global_reduce()) {
+    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+    buffer = allocator.allocate(config.global_memory_size());
+    semaphores = allocator.allocate(config.semaphore_size());
+
+    auto stream = at::cuda::getCurrentCUDAStream();
+    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
+  }
+
+  AT_ASSERT(can_use_32bit_indexing);
+  auto output_calc = make_output_calculator<uint32_t>(iter);
+  auto input_calc = make_input_calculator<uint32_t>(iter);
+  auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>(
+      ops,
+      config,
+      input_calc,
+      output_calc,
+      in_data,
+      out_data,
+      out_data_extra,
+      acc_data,
+      buffer.get(),
+      (int*)semaphores.get(),
+      ident,
+      noutputs,
+      base_idx);
+  reduce.accumulate = iter.should_accumulate();
+  reduce.final_output = iter.is_final_output();
+
+  launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
+}
+
+//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function
+//try unifying with gpu_reduce_kernel
+template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
+inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
+                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+  //TODO - this will be different for more complicated reductions, but for now reductions using
+  //func_wrapper all have arg_t = opmath
+  using arg_t = at::opmath_type<scalar_t>;
+  // at::Half/at::ComplexHalf overflows easily as its range is very small.
+  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_half_or_chalf =
+      (std::is_same<at::Half, scalar_t>::value &&
+       std::is_same<at::Half, out_scalar_t>::value) ||
+      (std::is_same<c10::complex<at::Half>, scalar_t>::value &&
+       std::is_same<c10::complex<at::Half>, out_scalar_t>::value);
+  // at::BFloat16 has lower precision and can lead to rounding errors.
+  // So when scalar_t and out_scalar_t are at::BFloat16, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_bfloat16 =
+      (std::is_same<at::BFloat16, scalar_t>::value &&
+       std::is_same<at::BFloat16, out_scalar_t>::value);
+  static constexpr bool can_accumulate_in_output =
+      std::is_convertible<arg_t, out_scalar_t>::value &&
+      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
+
+  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+
+  // The acc_buf_ptr is a shared pointer. It is created on the first entry and
+  // reused by all recursive function calls.
+  if (acc_buf_ptr == NULL) {
+    // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
+    // when accumulation in output is not possible.
+    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+      int64_t output_memory_size = iter.element_size(0);
+      for (int dim = 0; dim < iter.ndim(); dim++) {
+        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+      }
+      output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO
+                                                 sizeof(out_scalar_t),
+                                                 (char*) iter.data_ptr(0),
+                                                 output_memory_size * sizeof(out_scalar_t))); //TODO
+    } else {
+      owned_buf_ptr.reset(new AccumulationBuffer());
+    }
+    acc_buf_ptr = owned_buf_ptr.get();
+  }
+
+  if (!can_use_32bit_indexing) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+      jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
+          acc_buf_ptr, sub_iter_base_idx);
+    }
+    return;
+  }
+
+  //TODO - for now we support a single input, we may be able to relax this constraint
+  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+  char* out_data = (char*)iter.data_ptr(0);
+  const auto noutputs = iter.noutputs();
+  optional<char*> out_data_extra;
+  if (noutputs > 1) {
+    out_data_extra = (char*)iter.data_ptr(1);
+  } else {
+    out_data_extra = nullopt;
+  }
+  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
+
+  at::DataPtr buffer;
+  at::DataPtr semaphores;
+  if (config.should_global_reduce()) {
+    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+    buffer = allocator.allocate(config.global_memory_size());
+    semaphores = allocator.allocate(config.semaphore_size());
+
+    auto stream = at::cuda::getCurrentCUDAStream();
+    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
+  }
+
+  AT_ASSERT(can_use_32bit_indexing);
+  auto output_calc = make_output_calculator<uint32_t>(iter);
+  auto input_calc = make_input_calculator<uint32_t>(iter);
+  auto reduce = ReduceJitOp<scalar_t, out_scalar_t, vt0>(
+      config,
+      input_calc,
+      output_calc,
+      in_data,
+      out_data,
+      out_data_extra,
+      acc_data,
+      buffer.get(),
+      (int*)semaphores.get(),
+      ident,
+      noutputs,
+      base_idx);
+  reduce.accumulate = iter.should_accumulate();
+  reduce.final_output = iter.is_final_output();
+
+  constexpr int nInputs = 1;
+  constexpr int nOutputs = 1;
+  static auto desc = at::cuda::jit::make_kernel_descriptor<
+    out_scalar_t, scalar_t>(name, func, nInputs, nOutputs);
+
+  static std::mutex jiterator_mutex;
+  static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fn_cache(c10::cuda::device_count());
+  auto &cache = fn_cache[iter.device().index()];
+
+  launch_jitted_reduce_kernel(
+      jiterator_mutex, cache, desc, vt0, config, &reduce);
+}
+
+}} // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h
new file mode 100644
index 0000000000000000000000000000000000000000..e5ef4c050130397dea644386896e7795a2035033
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ReduceOps.h
@@ -0,0 +1,20 @@
+
+namespace at {
+struct TensorIterator;
+}
+
+namespace c10 {
+class Scalar;
+}
+
+namespace at { namespace native {
+
+void norm_launch_kernel(TensorIterator &iter, double val);
+void min_launch_kernel(TensorIterator &iter);
+void max_launch_kernel(TensorIterator &iter);
+void aminmax_launch_kernel(TensorIterator &iter);
+void min_all_launch_kernel(TensorIterator &iter);
+void max_all_launch_kernel(TensorIterator &iter);
+void aminmax_allreduce_launch_kernel(TensorIterator &iter);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h
new file mode 100644
index 0000000000000000000000000000000000000000..9740ed43ff5288b9ffe3f6666d1ce274500a83b8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Resize.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include 
+#include 
+
+#include 
+
+namespace at { namespace native {
+
+TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
+
+static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
+  // It does not make sense to try to resize a storage
+  // to hold 0 elements, and this can break
+  // if storage_offset is positive but
+  // new_size is 0, so just bail in that case
+  // (same comment is in Resize.h)
+  if (self->numel() == 0) {
+    return;
+  }
+
+  const Storage &storage = self->unsafe_storage();
+  TORCH_CHECK(storage, "Tensor: invalid null storage");
+  if (new_size_bytes > storage.nbytes()) {
+    resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
+  }
+}
+
+inline TensorImpl* resize_impl_cuda_(
+    TensorImpl* self,
+    IntArrayRef size,
+    at::OptionalIntArrayRef stride,
+    bool device_guard = true) {
+  if (self->sizes() == size && (!stride || self->strides() == stride)) {
+    return self;
+  }
+
+  // NB: We don't need to hold the device guard when calling from TH
+  cuda::OptionalCUDAGuard guard;
+  if (device_guard) {
+    guard.set_index(self->storage().device().index());
+  }
+
+  const auto itemsize = self->dtype().itemsize();
+  const auto storage_offset = self->storage_offset();
+  size_t storage_size = 1;
+  if (stride) {
+    self->set_sizes_and_strides(size, *stride);
+    storage_size = at::detail::computeStorageNbytes(
+        size, *stride, itemsize, storage_offset);
+  } else {
+    self->set_sizes_contiguous(size);
+    storage_size = at::detail::computeStorageNbytesContiguous(
+        size, itemsize, storage_offset);
+  }
+  maybe_resize_storage_cuda(self, storage_size);
+
+  return self;
+}
+
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbc3d974cf9684205036fa4d71f4fbebbfd1ed52
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanKernels.h
@@ -0,0 +1,18 @@
+#pragma once
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+
+namespace native {
+
+// NOTE: these functions require output tensors to be contiguous
+void launch_cummax_cuda_kernel(const TensorBase& self, const TensorBase& values,
+                               const TensorBase& indices, int64_t dim);
+void launch_cummin_cuda_kernel(const TensorBase& self, const TensorBase& values,
+                               const TensorBase& indices, int64_t dim);
+void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim);
+void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim);
+void launch_cumprod_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..19683e3d030d3246d8895754a84046d4f6906fae
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh
@@ -0,0 +1,459 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+template<typename integer>
+constexpr inline integer ceil_div(integer n, integer m) {
+  return (n + m - 1) / m;
+}
+
+template<typename integer>
+constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
+  integer log_num_threads_x = 0;
+  integer log_num_threads_y = 0;
+  while (((integer)1 << log_num_threads_x) < row_size) {
+    ++log_num_threads_x;
+  }
+  while (((integer)1 << log_num_threads_y) < num_rows) {
+    ++log_num_threads_y;
+  }
+  // we want to keep the ratio between the x-threads and y-threads about the same as
+  // the ratio between the row_size and num_rows, but the total number of threads in
+  // a block should be about 512
+  integer diff = log_num_threads_x - log_num_threads_y;
+  // 9 is from log2(512)
+  log_num_threads_x = ((integer)9 + diff) / (integer)2;
+  // I found that having a larger log_num_threads_x can give a significant speed-up in some cases,
+  // but be detrimental in other cases, so just keep the lower bound at log2(16) == 4 to keep it
+  // similar to the previous implementation
+  // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.
+  log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
+  return log_num_threads_x;
+}
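+// Illustrative example (hypothetical sizes): for num_rows = 8 and row_size = 1024 the loops give
+// log_num_threads_x = 10 and log_num_threads_y = 3, so diff = 7 and log_num_threads_x becomes
+// (9 + 7) / 2 = 8, which already lies inside [4, 9]: the block then uses 256 threads in x and
+// 512 / 256 = 2 threads in y.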
+
+template<typename scalar_t, typename idx_t, typename BinaryOperation>
+__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
+  if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
+    rhs = lhs;
+    rhs_idx = lhs_idx;
+  }
+}
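+// In other words, the running result rhs is replaced by the incoming lhs when lhs is NaN (so NaNs
+// propagate through the scan) or when binary_op(rhs, lhs) is false, i.e. rhs does not already beat
+// lhs under the comparison; a NaN already stored in rhs is kept as-is. For a cummax-style scan,
+// binary_op would be a greater-than style comparison (an illustrative assumption, not fixed here).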
+/* Perform an inclusive scan along the innermost dimension of a tensor.
+ *
+ * - num_rows is the size of the flattened outer dimensions;
+ * - row_size is the size of the innermost dimension;
+ *
+ * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
+ * considered as having 'num_rows' rows of size 'row_size'.
+ * Each thread block processes one or more sets of contiguous rows (processing multiple rows
+ * per thread block is quicker than processing a single row, especially for short rows).
+ */
+template<typename scalar_t, class BinaryFunction>
+__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
+                                                int num_rows, int row_size,
+                                                const uint32_t num_threads, const uint32_t log_num_threads_x,
+                                                scalar_t init, BinaryFunction binary_op) {
+  // dynamic memory allocation for vbuf and ibuf
+  alignas(sizeof(double)) extern __shared__ char buf[];
+  scalar_t* vbuf = reinterpret_cast<scalar_t*>(buf); // the size is num_threads * 2
+  int64_t* ibuf = reinterpret_cast<int64_t*>(vbuf + num_threads * 2);
+  const uint32_t num_threads_x = 1 << log_num_threads_x;
+  scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y;
+  int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y;
+
+  for (int block_row = blockIdx.x * blockDim.y;
+       block_row < num_rows;
+       block_row += blockDim.y * gridDim.x) {
+    int row = block_row + threadIdx.y;
+    const scalar_t *row_self = self_ + row * row_size;
+    scalar_t *row_values = values_ + row * row_size;
+    int64_t *row_indices = indices_ + row * row_size;
+    scalar_t block_total = init;
+    int64_t block_idx_final = 0;
+    const bool row_exists = row < num_rows;
+    // Perform scan on one block at a time, keeping track of the total value of
+    // all blocks processed so far.
+    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
+      // Load data into shared memory (two values per thread).
+      int col1 = block_col + threadIdx.x;
+      int col2 = block_col + num_threads_x + threadIdx.x;
+      if (row_exists) {
+        if (col1 < row_size) {
+          row_buf[threadIdx.x] = c10::load(&row_self[col1]);
+          row_idx_buf[threadIdx.x] = col1;
+        } else {
+          row_buf[threadIdx.x] = init;
+          // No need to set the index here as the value in init will never be selected
+        }
+
+        if (col2 < row_size) {
+          row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
+          row_idx_buf[num_threads_x + threadIdx.x] = col2;
+        } else {
+          row_buf[num_threads_x + threadIdx.x] = init;
+          // No need to set the index here as the value in init will never be selected
+        }
+
+        // Add the total value of all previous blocks to the first value of this block.
+        if (threadIdx.x == 0) {
+          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
+        }
+      }
+      __syncthreads();
+
+      // Parallel reduction with Sklansky method. The diagram can be seen on this paper:
+      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
+      for (uint32_t s = 1; s <= num_threads_x; s <<= 1) {
+        if (row_exists) {
+          uint32_t a = (threadIdx.x / s) * (2 * s) + s;
+          uint32_t ti = a + (threadIdx.x % s);
+          uint32_t si = a - 1;
+          binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op);
+        }
+        __syncthreads();
+      }
+
+      // Write back to output.
+      if (row_exists) {
+        if (col1 < row_size){
+          row_values[col1] = row_buf[threadIdx.x];
+          row_indices[col1] = row_idx_buf[threadIdx.x];
+        }
+        if (col2 < row_size) {
+          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
+          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
+        }
+      }
+      block_total = row_buf[2 * num_threads_x - 1];
+      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
+      __syncthreads();
+    }
+  }
+}
+
+/* Perform an inclusive scan along an outer dimension of a tensor.
+ *
+ * - num_orows is the size of the flattened outer dimensions;
+ * - num_irows is the size of the flattened inner dimensions;
+ * - row_size is the size of the dimension along which to compute the variance;
+ *
+ * The dimensions to the outside and inside of the specified dimension are considered as flattened.
+ * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
+ * outer dimensions, which contains several "inner rows").
+ * Each thread processes a single inner row at a time.
+ */
+template<typename scalar_t, class BinaryFunction>
+__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
+                  const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
+  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
+      const scalar_t *self = self_ + orow * row_size * num_irows + irow;
+      scalar_t *values = values_ + orow * row_size * num_irows + irow;
+      int64_t *indices = indices_ + orow * row_size * num_irows + irow;
+      scalar_t out = init;
+      int64_t out_idx = 0;
+
+      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
+        const auto val = c10::load(self);
+        if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
+          out = val;
+          out_idx = col;
+        }
+        *values = out;
+        *indices = out_idx;
+        self += num_irows;
+        values += num_irows;
+        indices += num_irows;
+      }
+    }
+  }
+}
+
+inline void check_fits_in_unsigned(int64_t val, const char* name) {
+  constexpr auto umax = std::numeric_limits<uint32_t>::max();
+  TORCH_CHECK(
+      val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value");
+}
+
+
+template<typename scalar_t, class BinaryFunction>
+__host__ void scan_outer_dim_with_indices(
+    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
+    int dim, scalar_t init, BinaryFunction binary_op) {
+  int64_t row_size = self.size(dim);
+  auto sizes = self.sizes();
+
+  // Treat all outer dimensions (i.e. dim_ < dim) as one.
+  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
+
+  // Treat all inner dimensions (i.e. dim > dimension) as one.
+  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
+  //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row,
+  //make sure that input is not bigger than supported by uint32_t
+  check_fits_in_unsigned(num_irows, "num_irows");
+  check_fits_in_unsigned(num_orows, "num_orows");
+  check_fits_in_unsigned(row_size, "row_size");
+
+
+  dim3 threads(std::min(512, int(num_irows)));
+  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
+  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
+    num_orows, num_irows, row_size, init, binary_op);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename scalar_t, class BinaryFunction>
+__host__ void scan_innermost_dim_with_indices(
+    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
+    scalar_t init, BinaryFunction binary_op) {
+  int ndim = self.dim();
+  // Treat all outer dimensions as a single dimension.
+  int row_size = self.size(ndim - 1);
+  int num_rows = self.numel() / row_size;
+
+  // assuming max_num_threads per block is 512
+  const uint32_t num_threads = 512;
+  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size);
+  const uint32_t num_threads_x = (1 << log_num_threads_x);
+  const uint32_t num_threads_y = num_threads / num_threads_x;
+  dim3 threads(num_threads_x, num_threads_y);
+  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));
+
+  const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t));
+  tensor_kernel_scan_innermost_dim_with_indices<scalar_t><<<grid, threads, mem_size, at::cuda::getCurrentCUDAStream()>>>(
+    self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
+    num_rows, row_size, num_threads, log_num_threads_x, init, binary_op);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename scalar_t, typename BinaryFunction>
+void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) {
+     int64_t dim, scalar_t init, BinaryFunction binary_op) {
+  int ndim = self.dim();
+  auto self_ = self.expect_contiguous();
+  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
+  if (dim == ndim - 1) {
+    scan_innermost_dim_with_indices(*self_, values, indices, init, binary_op);
+  } else {
+    scan_outer_dim_with_indices(*self_, values, indices, dim, init, binary_op);
+  }
+}
+
+// TODO: The implementation of `tensor_kernel_scan_outer_dim` and
+// `tensor_kernel_scan_innermost_dim` is similar to
+// `tensor_kernel_scan_outer_dim_with_indices`
+// `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to
+// remove the duplication.
+
+/* Perform an inclusive scan along an outer dimension of a tensor.
+ *
+ * - num_orows is the size of the flattened outer dimensions;
+ * - num_irows is the size of the flattened inner dimensions;
+ * - row_size is the size of the dimension along which to scan;
+ *
+ * The dimensions to the outside and inside of the specified dimension are considered as flattened.
+ * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
+ * outer dimensions, which contains several "inner rows").
+ * Each thread processes a single inner row at a time.
+ */
+template<typename scalar_t, class BinaryOp>
+__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
+                                              const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
+                                              const scalar_t init, BinaryOp binary_op)
+{
+  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
+      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
+      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
+      scalar_t acc = init;
+
+      for (uint32_t col = 0; col < row_size; ++col) {
+        acc = binary_op(acc, c10::load(src));
+        *tgt = acc;
+
+        src += num_irows;
+        tgt += num_irows;
+      }
+    }
+  }
+}
+
+/* Perform an inclusive scan along the innermost dimension of a tensor.
+ *
+ * - num_rows is the size of the flattened outer dimensions;
+ * - row_size is the size of the innermost dimension;
+ *
+ * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
+ * considered as having 'num_rows' rows of size 'row_size'.
+ * Each thread block processes one or more sets of contiguous rows (processing multiple rows
+ * per thread block is quicker than processing a single row, especially for short rows).
+ */
+template<typename T, class BinaryFunction>
+__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_,
+                                      const uint32_t num_rows, const uint32_t row_size,
+                                      const uint32_t log_num_threads_x,
+                                      T init, BinaryFunction binary_op){
+  const uint32_t num_threads_x = 1 << log_num_threads_x;
+  for (uint32_t block_row = blockIdx.x * blockDim.y;
+       block_row < num_rows;
+       block_row += blockDim.y * gridDim.x) {
+    uint32_t row = block_row + threadIdx.y;
+    T block_total = init;
+
+    const T *row_src = src_ + row * row_size;
+    T *row_tgt = tgt_ + row * row_size;
+    const bool row_exists = row < num_rows;
+
+    // Perform scan on one block at a time, keeping track of the total value of
+    // all blocks processed so far.
+    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
+      // Load data into shared memory (two values per thread).
+      uint32_t col1 = block_col + threadIdx.x;
+      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
+      if (row_exists) {
+        if (col1 < row_size) {
+          row_buf[threadIdx.x] = row_src[col1];
+        } else {
+          row_buf[threadIdx.x] = init;
+        }
+
+        if (col2 < row_size) {
+          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
+        } else {
+          row_buf[num_threads_x + threadIdx.x] = init;
+        }
+
+        // Add the total value of all previous blocks to the first value of this block.
+        if (threadIdx.x == 0) {
+          row_buf[0] = binary_op(row_buf[0], block_total);
+        }
+      }
+      __syncthreads();
+
+      // Parallel reduction with Sklansky method. The diagram can be seen on this paper:
+      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
+      for (uint32_t m = 0; m <= log_num_threads_x; ++m) {
+        if (row_exists) {
+          uint32_t s = 1 << m; // s = 2 ^ m
+          uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s
+          uint32_t ti = a + (threadIdx.x % s);
+          uint32_t si = a - 1;
+          row_buf[ti] = binary_op(row_buf[ti], row_buf[si]);
+        }
+        __syncthreads();
+      }
+
+      // Write back to output.
+      if (row_exists) {
+        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
+        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
+      }
+      block_total = row_buf[2 * num_threads_x - 1];
+      __syncthreads();
+    }
+  }
+}
+
+template <
+    typename T,
+    class BinaryFunction>
+__global__ void tensor_kernel_scan_innermost_dim(
+    T* tgt_,
+    const T* src_,
+    const uint32_t num_rows,
+    const uint32_t row_size,
+    const uint32_t log_num_threads_x,
+    T init,
+    BinaryFunction binary_op) {
+  alignas(sizeof(double)) extern __shared__ char sbuf[];
+  T* sbuf2 = reinterpret_cast<T*>(sbuf);
+  const uint32_t num_threads_x = 1 << log_num_threads_x;
+  T* row_buf = reinterpret_cast<T*>(sbuf2 + num_threads_x * 2 * threadIdx.y);
+
+  tensor_kernel_scan_innermost_dim_impl<T>(
+      row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op);
+}
+
+
+template <typename scalar_t, class BinaryFunction>
+__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
+                             int dim, scalar_t init, BinaryFunction binary_op) {
+  const int64_t row_size = self.size(dim);
+  auto sizes = self.sizes();
+
+  // Treat all outer dimensions (i.e. dim_ < dim) as one.
+  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
+
+  // Treat all inner dimensions (i.e. dim > dimension) as one.
+  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
+
+  dim3 threads(std::min(512, int(num_irows)));
+  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
+
+  check_fits_in_unsigned(num_irows, "num_irows");
+  check_fits_in_unsigned(num_orows, "num_orows");
+  check_fits_in_unsigned(row_size, "row_size");
+
+  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
+    num_orows, num_irows, row_size, init, binary_op);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename scalar_t, class BinaryFunction>
+void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
+                        scalar_t init, BinaryFunction binary_op) {
+  int64_t ndim = self.dim();
+  // Treat all outer dimensions as a single dimension.
+  int64_t row_size = self.size(ndim - 1);
+  int64_t num_rows = self.numel() / row_size;
+
+  // assuming max_num_threads per block is 512
+  const uint32_t num_threads = 512;
+  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size);
+  const uint32_t num_threads_x = (1 << log_num_threads_x);
+  const uint32_t num_threads_y = num_threads / num_threads_x;
+  dim3 threads(num_threads_x, num_threads_y);
+  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
+  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));
+
+  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
+  check_fits_in_unsigned(row_size, "row_size");
+
+  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
+                                               at::cuda::getCurrentCUDAStream()>>>(
+    result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
+    num_rows, row_size, log_num_threads_x, init, binary_op);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename scalar_t, typename BinaryFunction>
+void scan_dim(const TensorBase& self, const TensorBase& result,
+     int64_t dim, scalar_t init, BinaryFunction binary_op) {
+  int ndim = self.dim();
+  auto self_ = self.expect_contiguous();
+  TORCH_INTERNAL_ASSERT(result.is_contiguous());
+
+  if (self.numel() == self.size(dim)) {
+    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
+  } else if (dim == ndim - 1) {
+    scan_innermost_dim(*self_, result, init, binary_op);
+  } else {
+    scan_outer_dim(*self_, result, dim, init, binary_op);
+  }
+}
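+// Dispatch example (illustrative, assuming a cumsum-style caller): for a
+// contiguous [4, 5] tensor,
+//   scan_dim(self, result, /*dim=*/1, ...)  takes the scan_innermost_dim path
+//   scan_dim(self, result, /*dim=*/0, ...)  takes the scan_outer_dim path
+// and when numel() == size(dim) (all other dimensions are size 1) the single
+// cub inclusive_scan path is taken instead.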
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h
new file mode 100644
index 0000000000000000000000000000000000000000..388401118a9303955a018e1f881688e602742edd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sort.h
@@ -0,0 +1,17 @@
+#pragma once
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+inline bool should_use_small_sort(const TensorBase &self, int64_t dim) {
+  return self.size(dim) <= 4096;
+}
+
+void sortKeyValueInplace(
+    const TensorBase &key, const TensorBase &value, int dim,
+    bool descending, bool stable=false);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h
new file mode 100644
index 0000000000000000000000000000000000000000..e511e4422163da3ab41e9af206ca33e550d076cd
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortStable.h
@@ -0,0 +1,19 @@
+#pragma once
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+// Stable-sort self into values, and set indices to the
+// inverse-permutation from values back to self.
+// Output tensors must be pre-allocated and contiguous.
+void launch_stable_sort_kernel(
+    const TensorBase& self,
+    int64_t dim,
+    bool descending,
+    const TensorBase& values,
+    const TensorBase& indices);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6848fa9acd3df7279a4eb938c148621d9b9c1704
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh
@@ -0,0 +1,344 @@
+#pragma once
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600)
+
+
+namespace at { namespace native {
+
+template <typename T>
+__device__ inline void swapVars(T& t1, T& t2) {
+  T tmp = t1;
+  t1 = t2;
+  t2 = tmp;
+}
+
+template <typename K, typename V, typename Comparator>
+__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
+                                   K& kB, V& vB, bool& validB,
+                                   bool dir,
+                                   const Comparator& comp) {
+  // Invalid entries always sort to the end
+  bool swap = (comp(kA, kB) && validA) || !validB;
+  if (swap == dir) {
+    swapVars(kA, kB);
+    swapVars(vA, vB);
+    swapVars(validA, validB);
+  }
+};
+
+template <typename K, typename V, typename Comparator, int Power2SortSize>
+__device__ inline void bitonicSort(K *keys,
+                                   V *values,
+                                   bool *valid,
+                                   const Comparator& comp) {
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
+    bool flag = ((threadIdx.x & (size / 2)) != 0);
+
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
+
+      __syncthreads();
+
+      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      bitonicSwap(
+        keys[pos], values[pos], valid[pos],
+        keys[pos + stride], values[pos + stride], valid[pos + stride],
+        flag, comp);
+    }
+  }
+
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
+
+    __syncthreads();
+
+    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    bitonicSwap(
+      keys[pos], values[pos], valid[pos],
+      keys[pos + stride], values[pos + stride], valid[pos + stride],
+      false, comp);
+  }
+
+  __syncthreads();
+
+}
+
+// at::cuda::detail::TensorInfo version
+// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
+// modifies the input `keys` and `values`
+template <typename K, typename V,
+          int KeyDims, int ValueDims,
+          typename Comparator, typename IndexType,
+          int block_dim_x, int max_block_dim_y>
+C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
+__global__ void
+bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
+                     IndexType keySlices,
+                     IndexType keySliceSize,
+                     IndexType keySliceStride,
+                     at::cuda::detail::TensorInfo<V, IndexType> values,
+                     IndexType valueSliceStride,
+                     Comparator comp) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If the entire block is out of bounds exit early
+  if (blockIndex * blockDim.y >= keySlices) {
+    return;
+  }
+  // It's also possible for some rows of a block to be out of bounds
+  // but all threads need to run for __syncthreads to work.
+  const bool row_valid = linearIndex < keySlices;
+
+  constexpr int items_per_thread = 2;
+  constexpr int Power2SortSize = block_dim_x * items_per_thread;
+
+  // Storage for max_block_dim_y sorts performed in parallel
+  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
+  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
+  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];
+
+  auto sharedKeys = blockSharedKeys[threadIdx.y];
+  auto sharedValues = blockSharedValues[threadIdx.y];
+  auto sharedValid = blockSharedValid[threadIdx.y];
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+  // Load 2 values per thread into the shared workspace
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    bool valid = row_valid && idx < keySliceSize;
+
+    sharedKeys[idx] = valid ?
+        keys.data[idx * keySliceStride + keyStartOffset] : K{};
+    sharedValues[idx] = valid ?
+        values.data[idx * valueSliceStride + valueStartOffset] : V{};
+    sharedValid[idx] = valid;
+  }
+
+  // Sort!
+  bitonicSort<K, V, Comparator, Power2SortSize>(
+      sharedKeys, sharedValues, sharedValid, comp);
+
+  if (!row_valid) {
+    return;
+  }
+
+  // Store outputs
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    if (idx < keySliceSize) {
+      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
+      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
+    }
+  }
+}
+
+#if HAS_WARP_MERGE_SORT()
+
+template <typename K, typename V,
+          int KeyDims, int ValueDims,
+          typename Comparator, typename IndexType,
+          int sort_size, int max_block_dim_y>
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
+__global__ void
+warpMergeSortKVInPlace(
+    at::cuda::detail::TensorInfo<K, IndexType> keys,
+    IndexType keySlices,
+    IndexType keySliceSize,
+    IndexType keySliceStride,
+    at::cuda::detail::TensorInfo<V, IndexType> values,
+    IndexType valueSliceStride,
+    Comparator comp,
+    K invalid_key) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If this row is out of bounds exit early
+  if (linearIndex >= keySlices) {
+    return;
+  }
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+  K *keys_slice = &keys.data[keyStartOffset];
+  V *values_slice = &values.data[valueStartOffset];
+
+  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+  CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE);
+  CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y);
+  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
+  static_assert(
+      items_per_thread * C10_WARP_SIZE == sort_size,
+      "sort_size must be a multiple of C10_WARP_SIZE");
+
+
+  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
+  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+
+  __shared__ union {
+    typename LoadKeys::TempStorage load_keys;
+    typename LoadValues::TempStorage load_values;
+    typename Sort::TempStorage sort;
+    typename StoreKeys::TempStorage store_keys;
+    typename StoreValues::TempStorage store_values;
+  } tmp_storage[max_block_dim_y];
+
+  auto& warp_storage = tmp_storage[threadIdx.y];
+
+  // Load inputs
+  K local_keys[items_per_thread];
+  V local_values[items_per_thread];
+
+  const auto invalid_value = V{};
+  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+  WARP_SYNC();
+  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+  WARP_SYNC();
+
+  // Sort! We use stable sort to ensure that invalid values are never
+  // sorted before valid values. In testing it performed the same as
+  // .Sort, so there is no down-side.
+  Sort(warp_storage.sort).StableSort(
+      local_keys, local_values, comp, keySliceSize, invalid_key);
+  WARP_SYNC();
+
+  // Store outputs
+  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
+  WARP_SYNC();
+  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
+}
+
+#endif // HAS_WARP_MERGE_SORT()
+
+template <typename K, typename V,
+          int KeyDims, int ValueDims,
+          int block_size, int items_per_thread,
+          typename IndexType>
+C10_LAUNCH_BOUNDS_1(block_size)
+__global__ void
+radixSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
+                   IndexType keySlices,
+                   IndexType keySliceSize,
+                   IndexType keySliceStride,
+                   at::cuda::detail::TensorInfo<V, IndexType> values,
+                   IndexType valueSliceStride,
+                   bool descending) {
+  static_assert(block_size > 0, "");
+
+  // Find the slice of the tensor that we are sorting
+  const IndexType linearIndex = getLinearBlockId<IndexType>();
+  // Tiling the slices could have us be out of bounds, if there are a
+  // lot of slices to sort
+  if (linearIndex >= keySlices) {
+    return;
+  }
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+  K *keys_slice = &keys.data[keyStartOffset];
+  V *values_slice = &values.data[valueStartOffset];
+
+  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+  using key_t = typename at::cuda::cub::detail::cuda_type<K>::type;
+  using LoadKeys = cub::BlockLoad<K, block_size, items_per_thread,
+                                  cub::BLOCK_LOAD_TRANSPOSE>;
+  using LoadValues = cub::BlockLoad<V, block_size, items_per_thread,
+                                    cub::BLOCK_LOAD_TRANSPOSE>;
+  using Sort = cub::BlockRadixSort<key_t, block_size, items_per_thread, V>;
+  using StoreKeys = cub::BlockStore<K, block_size, items_per_thread,
+                                    cub::BLOCK_STORE_TRANSPOSE>;
+  using StoreValues = cub::BlockStore<V, block_size, items_per_thread,
+                                      cub::BLOCK_STORE_TRANSPOSE>;
+
+  __shared__ union {
+    typename LoadKeys::TempStorage load_keys;
+    typename LoadValues::TempStorage load_values;
+    typename Sort::TempStorage sort;
+    typename StoreKeys::TempStorage store_keys;
+    typename StoreValues::TempStorage store_values;
+  } tmp_storage;
+
+  // cub's Block operations operate on a fixed number of items, but the
+  // actual slice we are sorting might be smaller. So, we need to make
+  // up the difference with keys that will always sort higher.
+  const K invalid_key = [descending] {
+    using radix_t = typename cub::Traits<key_t>::UnsignedBits;
+    union {
+      K key;
+      radix_t radix;
+    } tmp;
+    tmp.radix = descending ?
+        cub::Traits<key_t>::LOWEST_KEY :
+        cub::Traits<key_t>::MAX_KEY;
+    return tmp.key;
+  }();
+  const V invalid_value = static_cast<V>(0);
+
+  // Load inputs
+  K local_keys[items_per_thread];
+  V local_values[items_per_thread];
+
+  LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+  __syncthreads();
+  LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+  __syncthreads();
+
+  // Sort!
+  if (descending) {
+    Sort(tmp_storage.sort).SortDescending(
+        reinterpret_cast<key_t(&)[items_per_thread]>(local_keys),
+        local_values);
+  } else {
+    Sort(tmp_storage.sort).Sort(
+        reinterpret_cast<key_t(&)[items_per_thread]>(local_keys),
+        local_values);
+  }
+  __syncthreads();
+
+  // Store outputs
+  StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
+  __syncthreads();
+  StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize);
+}
+
+}} // at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h
new file mode 100644
index 0000000000000000000000000000000000000000..8eddefcf1be3316cfead0630dd59a02219c9de94
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/Sorting.h
@@ -0,0 +1,18 @@
+#pragma once
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+
+void launch_kthvalue_kernel(
+    const TensorBase &values, const TensorBase &indices,
+    const TensorBase &self, int64_t dim, int64_t k);
+void launch_median_kernel(
+    const TensorBase &vals, const TensorBase &inds,
+    const TensorBase &in, int64_t dim, bool ignore_nan);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..58636e5e1d0b059d1379d8f8fed293db95c6445c
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh
@@ -0,0 +1,193 @@
+#pragma once
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+// Is this questionable namespace pollution?
+#if defined(USE_ROCM)
+constexpr int MAX_BLOCK_SIZE = 256;
+
+#else
+constexpr int MAX_BLOCK_SIZE = 1024;
+#endif
+
+// Maximum size per grid dimension that we assume (compute capability >= 2.0)
+constexpr int64_t MAX_GRID_SIZE = 65535LL;
+
+static bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
+  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
+    return false;
+  }
+
+  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+  int64_t gridY = 1;
+  int64_t gridZ = 1;
+
+  if (gridTiles > MAX_GRID_SIZE) {
+    gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
+    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+
+    if (gridTiles > MAX_GRID_SIZE) {
+      gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
+      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
+    }
+  }
+
+  grid = dim3(gridX, gridY, gridZ);
+  return true;
+}
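+// Worked example (illustrative): with MAX_GRID_SIZE == 65535, a request for
+// gridTiles == 100000 yields grid = dim3(65535, 2, 1): the x dimension is
+// clamped to 65535 and the remaining ceil_div(100000, 65535) == 2 tiles spill
+// into the y dimension; z stays at 1.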
+
+template <typename scalar_t, bool handleNaN = false>
+struct GTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs);
+  }
+};
+
+template <typename scalar_t, bool handleNaN = false>
+struct LTOp {
+  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
+    return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs);
+  }
+};
+
+template <typename index_t>
+__device__ __forceinline__ index_t getLinearBlockId() {
+  return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
+      blockIdx.x;
+}
+
+// For slice sorting in Thrust; extracts a slice index from a linear
+// index and uses that for comparison
+struct SliceComp {
+  SliceComp(int64_t size) : sliceSize(size) {}
+
+  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
+    // Since the slices are guaranteed to be innermost,
+    // the segment is just via int64_t division
+    int64_t segA = a / sliceSize;
+    int64_t segB = b / sliceSize;
+    return segA < segB;
+  }
+
+  const int64_t sliceSize;
+};
+
+// For sorting in Thrust; extracts a within-slice index from a linear index
+struct GlobalIndexToPerSliceIndex {
+  GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
+
+  __device__ inline void operator()(int64_t& v) const {
+    v = v % sliceSize;
+  }
+
+  const int64_t sliceSize;
+};
+
+// Returns 2^(ceil(lg(n))) from Stanford bit twiddling hacks
+static uint64_t nextHighestPowerOf2(uint64_t n) {
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+#ifndef _MSC_VER
+  n |= n >> 32;
+#endif
+  n++;
+
+  return n;
+}
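+// Example (illustrative): nextHighestPowerOf2(5) == 8, nextHighestPowerOf2(64) == 64,
+// and nextHighestPowerOf2(65) == 128; the bit-smearing turns n - 1 into a mask of
+// all ones below its highest set bit, and the final increment carries into the
+// next power of two.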
+
+
+// WARNING: This function assumes input tensors are contiguous
+template <typename scalar_t, typename index_t, typename Launcher>
+void run_launcher(
+    const TensorBase &values,
+    const TensorBase &indices,
+    const TensorBase &self,
+    int64_t dim,
+    Launcher l) {
+  auto self_info = cuda::detail::getTensorInfo<const scalar_t, index_t>(self);
+  auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
+  auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);
+
+  int64_t slice_size = self.size(dim);
+  /* We use these structures solely to find the offset to */
+  /* each slice we are operating on */
+  self_info.reduceDim(dim);
+  values_info.reduceDim(dim);
+  indices_info.reduceDim(dim);
+
+  /* Collapse all other dims */
+  int collapse_self_dim = self_info.collapseDims(dim);
+  int collapse_values_dim = values_info.collapseDims(dim);
+  int collapse_indices_dim = indices_info.collapseDims(dim);
+
+  int64_t num_slices = 1;
+  for (int i = 0; i < self_info.dims; ++i) {
+    num_slices *= self_info.sizes[i];
+  }
+
+  /* This is used as a template parameter to calculate indices. */
+  /* We only specialize it if all collapsed dim sizes are the */
+  /* same; otherwise, we use -1 which is the specialization */
+  /* parameter for arbitrary dimensions */
+  int all_dims = self_info.dims;
+  if (values_info.dims != all_dims || indices_info.dims != all_dims) {
+    all_dims = -1;
+  }
+
+  if (all_dims == 1) {
+    l.template launch<scalar_t, index_t, 1>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else if (all_dims == 2) {
+    l.template launch<scalar_t, index_t, 2>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else if (all_dims == 3) {
+    l.template launch<scalar_t, index_t, 3>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  } else {
+    l.template launch<scalar_t, index_t, -1>(
+        values_info,
+        collapse_values_dim,
+        indices_info,
+        collapse_indices_dim,
+        self_info,
+        collapse_self_dim,
+        num_slices,
+        slice_size);
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..446ca5d796a903155bb7987d4b2348750b685465
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh
@@ -0,0 +1,429 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+template <typename scalar_t>
+struct TopKTypeConfig {};
+
+template <>
+struct TopKTypeConfig<float> {
+  typedef uint32_t RadixType;
+
+  // Converts a float to an integer representation with the same
+  // sorting; i.e., for floats f1, f2:
+  // if f1 < f2 then convert(f1) < convert(f2)
+  // We use this to enable radix selection of floating-point values.
+  // This also gives a relative order for NaNs, but that's ok, as they
+  // will all be adjacent
+  // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
+  // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
+  // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
+  // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff..
+
+  static inline __device__ RadixType convert(float v) {
+    RadixType x = __float_as_int(v);
+    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
+    return (v == v) ? (x ^ mask) : 0xffffffff;
+  }
+
+  static inline __device__ float deconvert(RadixType v) {
+    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
+    return __int_as_float(v ^ mask);
+  }
+};
+
+template <>
+struct TopKTypeConfig<uint8_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(uint8_t v) {
+    return v;
+  }
+
+  static inline __device__ uint8_t deconvert(RadixType v) {
+    return v;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int8_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int8_t v) {
+    return 128u + v;
+  }
+
+  static inline __device__ int8_t deconvert(RadixType v) {
+    return v - 128;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int16_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int16_t v) {
+    static_assert(sizeof(short) == 2, "");
+    return 32768u + v;
+  }
+
+  static inline __device__ int16_t deconvert(RadixType v) {
+    return v - 32768;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int32_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(int32_t v) {
+    static_assert(sizeof(int) == 4, "");
+    return 2147483648u + v;
+  }
+
+  static inline __device__ int32_t deconvert(RadixType v) {
+    return v - 2147483648u;
+  }
+};
+
+template <>
+struct TopKTypeConfig<int64_t> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType convert(int64_t v) {
+    static_assert(sizeof(int64_t) == 8, "");
+    return 9223372036854775808ull + v;
+  }
+
+  static inline __device__ int64_t deconvert(RadixType v) {
+    return v - 9223372036854775808ull;
+  }
+};
+
+template <>
+struct TopKTypeConfig<double> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType convert(double v) {
+    RadixType x = __double_as_longlong(v);
+    RadixType mask = -((x >> 63)) | 0x8000000000000000;
+    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
+  }
+
+  static inline __device__ double deconvert(RadixType v) {
+    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
+    return __longlong_as_double(v ^ mask);
+  }
+};
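+
+// Illustrative property (a sketch of what convert() guarantees, not part of the
+// selection machinery): for any two finite values a < b, the unsigned radix
+// images compare the same way, so digit-by-digit selection over converted keys
+// picks the same element a comparison would. A hypothetical host-side analogue
+// of the double flip trick, for intuition only:
+//
+//   static inline uint64_t to_radix(double v) {
+//     uint64_t x;
+//     std::memcpy(&x, &v, sizeof(x));            // like __double_as_longlong
+//     uint64_t mask = -(x >> 63) | 0x8000000000000000ull;
+//     return x ^ mask;                           // unsigned order == double order
+//   }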
+
+template <>
+struct TopKTypeConfig<at::Half> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(at::Half v) {
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+    RadixType x = __half_as_ushort(v);
+    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
+    return (v == v) ? (x ^ mask) : 0xffff;
+#else
+    CUDA_KERNEL_ASSERT(false);
+    return 0u;
+#endif
+  }
+
+  static inline __device__ at::Half deconvert(RadixType v) {
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
+    return __ushort_as_half(v ^ mask);
+#else
+    CUDA_KERNEL_ASSERT(false);
+    return static_cast<at::Half>(0);
+#endif
+  }
+};
+
+template <>
+struct TopKTypeConfig<at::BFloat16> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType convert(at::BFloat16 v) {
+    RadixType x = v.x;
+    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
+    return (v == v) ? (x ^ mask) : 0xffff;
+  }
+
+  static inline __device__ at::BFloat16 deconvert(RadixType v) {
+    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
+    at::BFloat16 r;
+    r.x = (v ^ mask);
+    return r;
+  }
+};
+
+// This function counts the distribution of all input values in a
+// slice we are selecting by radix digit at `radixDigitPos`, but only
+// those that pass the filter `((v & desiredMask) == desired)`.
+// This produces and broadcasts the seen counts for a single block only.
+// `smem` must have at least `RadixSize` elements.
+template <
+    typename scalar_t,
+    typename bitwise_t,
+    typename index_t,
+    typename CountType,
+    int RadixSize,
+    int RadixBits>
+__device__ void countRadixUsingMask(
+    CountType counts[RadixSize],
+    CountType* smem,
+    bitwise_t desired,
+    bitwise_t desiredMask,
+    int radixDigitPos,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    const scalar_t* data) {
+  // Clear out per-thread counts from a previous round
+#pragma unroll
+  for (int i = 0; i < RadixSize; ++i) {
+    counts[i] = 0;
+  }
+
+  if (threadIdx.x < RadixSize) {
+    smem[threadIdx.x] = 0;
+  }
+  __syncthreads();
+
+  // Scan over all the data. Upon a read, the warp will accumulate
+  // counts per each digit in the radix using warp voting.
+#if !defined(USE_ROCM)
+  // Must be called outside of loop to ensure all threads participate
+  unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize);
+#endif
+  for (index_t i = threadIdx.x; i < sliceSize;) {
+    bitwise_t val =
+        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));
+
+    bool hasVal = ((val & desiredMask) == desired);
+    bitwise_t digitInRadix = at::cuda::Bitfield<bitwise_t>::getBitfield(
+        val, radixDigitPos, RadixBits);
+
+#pragma unroll
+    for (uint32_t j = 0; j < RadixSize; ++j) {
+      bool vote = hasVal && (digitInRadix == j);
+#if defined(USE_ROCM)
+      counts[j] += __popcll(WARP_BALLOT(vote));
+#else
+      counts[j] += __popc(WARP_BALLOT(vote, mask));
+#endif
+    }
+    i += blockDim.x;
+#if !defined(USE_ROCM)
+    mask = WARP_BALLOT(i < sliceSize, mask);
+#endif
+  }
+
+  // Now, for each warp, sum values
+  if (at::cuda::getLaneId() == 0) {
+#pragma unroll
+    for (uint32_t i = 0; i < RadixSize; ++i) {
+      gpuAtomicAddNoReturn(&smem[i], counts[i]);
+    }
+  }
+
+  __syncthreads();
+
+  // For each thread, read in the total counts
+#pragma unroll
+  for (uint32_t i = 0; i < RadixSize; ++i) {
+    counts[i] = smem[i];
+  }
+
+  __syncthreads();
+}
+
+// Over what radix we are selecting values
+constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
+constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
+constexpr int RADIX_MASK = (RADIX_SIZE - 1);
+
+// This finds the unique value `v` that matches the pattern
+// ((v & desired) == desiredMask) in our sorted int format
+template <typename scalar_t, typename bitwise_t, typename index_t>
+__device__ scalar_t findPattern(
+    scalar_t* smem,
+    const scalar_t* data,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    bitwise_t desired,
+    bitwise_t desiredMask) {
+  if (threadIdx.x < 2) {
+    smem[threadIdx.x] = static_cast<scalar_t>(0);
+  }
+  __syncthreads();
+
+  // All threads participate in the loop, in order to sync on the flag
+  index_t numIterations =
+      round_up(sliceSize, static_cast<index_t>(blockDim.x));
+  for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
+    bool inRange = (i < sliceSize);
+    scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
+                         : static_cast<scalar_t>(0);
+
+    if (inRange &&
+        ((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
+      // There should not be conflicts if we are using findPattern,
+      // since the result is unique
+      smem[0] = static_cast<scalar_t>(1);
+      smem[1] = v; // can't use val as the flag, since it could be 0
+    }
+
+    __syncthreads();
+
+    scalar_t found = smem[0];
+    scalar_t val = smem[1];
+
+    __syncthreads();
+
+    // Check to see if a thread found the value
+    if (found != static_cast<scalar_t>(0)) {
+      // all threads return this value
+      return val;
+    }
+  }
+
+  // should not get here
+  CUDA_KERNEL_ASSERT(false);
+  return static_cast<scalar_t>(0);
+}
+
+// Returns the top-Kth element found in the data using radix selection
+template 
+__device__ void radixSelect(
+    const scalar_t* data,
+    index_t k,
+    bool largest,
+    index_t sliceSize,
+    index_t withinSliceStride,
+    int* smem,
+    scalar_t* topK) {
+  // Per-thread buckets into which we accumulate digit counts in our
+  // radix
+  int counts[RADIX_SIZE];
+
+  // We only consider elements x such that (x & desiredMask) == desired
+  // Initially, we consider all elements of the array, so the above
+  // statement is true regardless of input.
+  bitwise_t desired = 0;
+  bitwise_t desiredMask = 0;
+
+  // We are looking for the top kToFind-th element when iterating over
+  // digits; this count gets reduced by elimination when counting
+  // successive digits
+  int kToFind = k;
+
+  // We start at the most significant digit in our radix, scanning
+  // through to the least significant digit
+  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
+       digitPos -= RADIX_BITS) {
+    // Count radix distribution for the current position and reduce
+    // across all threads
+    countRadixUsingMask<
+        scalar_t,
+        bitwise_t,
+        index_t,
+        int,
+        RADIX_SIZE,
+        RADIX_BITS>(
+        counts,
+        smem,
+        desired,
+        desiredMask,
+        digitPos,
+        sliceSize,
+        withinSliceStride,
+        data);
+
+    auto found_unique = [&](int i, int count) -> bool {
+      /* All threads have the same value in counts here, so all */
+      /* threads will return from the function. */
+      if (count == 1 && kToFind == 1) {
+        /* There is a unique answer. */
+        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
+            desired, i, digitPos, RADIX_BITS);
+        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
+            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
+
+        /* The answer is now the unique element v such that: */
+        /* (v & desiredMask) == desired */
+        /* However, we do not yet know what the actual element is. We */
+        /* need to perform a search through the data to find the */
+        /* element that matches this pattern. */
+        *topK = findPattern<scalar_t, bitwise_t, index_t>(
+            (scalar_t*)smem,
+            data,
+            sliceSize,
+            withinSliceStride,
+            desired,
+            desiredMask);
+        return true;
+      }
+      return false;
+    };
+    auto found_non_unique = [&](int i, int count) -> bool {
+      if (count >= kToFind) {
+        desired =
+            at::cuda::Bitfield<bitwise_t>::setBitfield(
+                desired, i, digitPos, RADIX_BITS);
+        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
+            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
+
+        /* The top-Kth element v must now be one such that: */
+        /* (v & desiredMask == desired) */
+        /* but we haven't narrowed it down; we must check the next */
+        /* least-significant digit */
+        return true;
+      }
+      kToFind -= count;
+      return false; // continue the loop
+    };
+
+    // All threads participate in the comparisons below to know the
+    // final result
+    if (largest) {
+      // Process in descending order
+#pragma unroll
+      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    } else {
+      // Process in ascending order
+#pragma unroll
+      for (int i = 0; i < RADIX_SIZE; ++i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    }
+  } // end digitPos for
+
+  // There is no unique result, but there is a non-unique result
+  // matching `desired` exactly
+  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
+}
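+// Semantics sketch (illustrative): ignoring ties and NaN handling, radixSelect
+// with largest == true returns the k-th largest element of the slice, and with
+// largest == false the k-th smallest, i.e. roughly the host-side equivalent of
+//
+//   std::vector<scalar_t> v(data, data + sliceSize);       // stride-1 slice
+//   auto comp = [largest](scalar_t a, scalar_t b) {
+//     return largest ? (a > b) : (a < b);
+//   };
+//   std::nth_element(v.begin(), v.begin() + (k - 1), v.end(), comp);
+//   scalar_t kth = v[k - 1];                                // k is 1-based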
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..aadc34b67d7dc065aac2454ec0387c7f4e4a78b3
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh
@@ -0,0 +1,435 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+
+// Used for a segmented reduction
+struct ModeUnsignedBoolPair {
+  unsigned int val;
+  bool flag;
+};
+
+// In the kernel below, we have a common pattern of reducing (unsigned int,
+// unsigned int) pairs of data
+struct ModeUnsignedPair {
+  unsigned int val;
+  unsigned int index;
+};
+
+// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
+//
+// 1. Power2ScanSize is a power of 2. This code still works for collections that
+// do not exactly contain a power of 2 number of elements, simply round up to
+// the nearest power of 2 and then call.
+//
+// 2. That there are two-elements per thread, i.e. the size of the smem storage
+// is 2 * blockDim.x * sizeof(T).
+//
+// Consider a (+)-Scan on the following elements:
+//
+// Upsweep:
+//
+//    0  1  2  3  4  5  6  7
+//       1     5     9    13
+//             6          22
+//                        28
+//
+// Downsweep:
+//                  15
+//         3     10    21
+template <int Power2ScanSize, typename T, class BinaryOp>
+__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
+  // Reduce step ("upsweep")
+#pragma unroll
+  for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
+    int index = (threadIdx.x + 1) * stride * 2 - 1;
+    if (index < Power2ScanSize) {
+      smem[index] = binop(smem[index], smem[index - stride]);
+    }
+    __syncthreads();
+  }
+
+  // Post-reduce step ("downsweep")
+#pragma unroll
+  for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
+    int index = (threadIdx.x + 1) * stride * 2 - 1;
+    if ((index + stride) < Power2ScanSize) {
+      smem[index + stride] = binop(smem[index + stride], smem[index]);
+    }
+    __syncthreads();
+  }
+}
+
+// Block-wide reduction where each thread locally reduces N
+// values before letting a single warp take over - assumes
+// threadVals is in registers, not shared memory
+//
+// If smem is not used again, there is no need to __syncthreads before this
+// call. However, if smem will be used, e.g., this function is called in a loop,
+// then __syncthreads is needed either before or afterwards to prevent non-0
+// threads overriding smem in the next loop before thread 0 reads from it.
+template <int N, typename T, typename ReduceOp>
+__device__ T reduceBlockWithNThreadLocalReductions(
+    T* smem,
+    T threadVals[N],
+    const unsigned int numVals,
+    ReduceOp reduceOp,
+    T init) {
+  int offset = threadIdx.x * N;
+  T local = offset < numVals ? threadVals[0] : init;
+
+#pragma unroll
+  for (int i = 1; i < N; ++i) {
+    ++offset;
+    T next = offset < numVals ? threadVals[i] : init;
+    local = reduceOp.combine(local, next);
+  }
+
+  return cuda_utils::BlockReduce(local, reduceOp, init, smem);
+}
+
+template <typename T>
+__device__ inline void swapVars(T& t1, T& t2) {
+  T tmp = t1;
+  t1 = t2;
+  t2 = tmp;
+}
+
+template <typename K, typename V, typename Comparator>
+__device__ inline void bitonicSwap(
+    K& kA,
+    V& vA,
+    bool& validA,
+    K& kB,
+    V& vB,
+    bool& validB,
+    bool dir,
+    const Comparator& comp) {
+  // Invalid entries always sort to the end
+  bool swap = (comp(kA, kB) && validA) || !validB;
+  if (swap == dir) {
+    swapVars(kA, kB);
+    swapVars(vA, vB);
+    swapVars(validA, validB);
+  }
+};
+
+template <typename K, typename Comparator>
+__device__ inline void bitonicSwapKeys(
+    K& kA,
+    bool& validA,
+    K& kB,
+    bool& validB,
+    bool dir,
+    const Comparator& comp) {
+  bool swap = (comp(kA, kB) && validA) || !validB;
+  if (swap == dir) {
+    swapVars(kA, kB);
+    swapVars(validA, validB);
+  }
+}
+
+template <
+    typename K,
+    typename IndexType,
+    int Power2SortSize,
+    typename Comparator>
+__device__ inline void bitonicSortKeys(
+    K keys[Power2SortSize],
+    bool valid[Power2SortSize],
+    const Comparator& comp) {
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
+    bool flag = ((threadIdx.x & (size / 2)) != 0);
+
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
+      __syncthreads();
+
+      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+      bitonicSwapKeys(
+          keys[pos],
+          valid[pos],
+          keys[pos + stride],
+          valid[pos + stride],
+          flag,
+          comp);
+    }
+  }
+
+#if !defined(USE_ROCM)
+#pragma unroll
+#endif
+  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
+    __syncthreads();
+
+    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
+    bitonicSwapKeys(
+        keys[pos],
+        valid[pos],
+        keys[pos + stride],
+        valid[pos + stride],
+        false,
+        comp);
+  }
+
+  __syncthreads();
+}
+
+// The mode kernel has the following characteristics: It uses internal shared
+// memory buffers of Power2Size, which must be greater than the number of
+// elements. Additionally, there is one block for every slice to calculate the
+// mode for, and in each block there is one thread for every two elements.
+//
+// Both sorted and positions are assumed to be contiguous Tensors with the mode
+// dimension as the innermost dim, such that we can get the particular slice for
+// a Tensor via its linear block dimension * the slice size.
+template <typename T, unsigned int Power2Size>
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
+__launch_bounds__(1024, 1)
+#endif
+__global__ void compute_mode(
+    const T* input,
+    at::cuda::detail::TensorInfo<T, unsigned int> values,
+    at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
+    int64_t sliceSize,
+    int64_t slices) {
+  int tidx = threadIdx.x;
+  int stidx = blockDim.x + threadIdx.x; // Second index this thread is responsible for
+
+  // First, we need to calculate the offset into the sorted Tensor that
+  // represents the start of the slice for this block to calculate the mode for.
+  // This offset is a combination of the gridIndices, and the number of elements
+  // in the slice.
+  unsigned int blockId = getLinearBlockId<unsigned int>();
+  unsigned int linearOffset = blockId * sliceSize;
+
+  if (blockId >= slices) {
+      return;
+  }
+
+  // shmem is a dynamically sized buffer we will use throughout the kernel to
+  // handle computation efficiently. The size of this shmem must be
+  // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size)
+  //
+  // Initially, the buffer will be organized as follows:
+  //
+  // [smem (slice elements) | bmem (valid indices) | ]
+  extern __shared__ char shmem[];
+
+  // smem represents a proportion of the shared memory buffer that is used to
+  // store the elements from the slice:
+  T* smem = reinterpret_cast<T*>(shmem);
+
+  // Each thread loads up to two elements from the Tensor into shared memory
+  if (tidx < sliceSize) {
+    smem[tidx] = c10::load(&input[linearOffset + tidx]);
+  }
+  if (stidx < sliceSize) {
+    smem[stidx] = c10::load(&input[linearOffset + stidx]);
+  }
+
+  // Next, we initialize a boolean region of the buffer, offset by the loaded
+  // element smem region
+  bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);
+
+  // The first use of this region stores bmem[i] = i < sliceSize to mark the
+  // valid components in the smem buffer
+  bmem[tidx] = tidx < sliceSize;
+  bmem[stidx] = stidx < sliceSize;
+  __syncthreads(); // barrier for smem, bmem initialization
+
+  // First, sort the input slice in ascending order. smem contains the input
+  // elements, and bmem marks the valid indices
+  bitonicSortKeys<T, unsigned int, Power2Size>(
+      smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
+        return a < b;
+      });
+  __syncthreads(); // make no assumptions that the sort syncs at end
+
+  // The next step of our algorithm is performing a block-wide comparison of
+  // neighboring elements. In particular, given a sorted input slice A, we
+  // produce an output slice B, such that B[i] = 1 if A[i-1] != A[i], otherwise
+  // 0.
+  //
+  // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
+  //                 B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
+  //
+  // In particular, we can think of B[i] true indicating the start of a sequence
+  // of equal values in the sorted list. Similarly, we will also store the
+  // negation of B, which we'll call C. In particular, we can think of C[i] =
+  // true iff A[i-1] == A[i] in our original sorted slice.
+  //
+  //                 C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
+
+  // We overwrite bmem, and treat the rest of shared memory as a buffer of
+  // (index, flag) pairs where the index represents values from C, and the flag
+  // represents values from B.
+  //
+  // [smem (sorted slice) | ubpmem (index, flag pairs)]
+
+  struct ModeUnsignedBoolPair* ubpmem =
+      reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);
+
+  if (tidx == 0) {
+    ubpmem[0].flag = true;
+    ubpmem[0].val = 0;
+  }
+
+  // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
+  ubpmem[tidx * 2 + 1].flag =
+      smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc.
+  ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;
+
+  // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
+  if (((tidx + 1) * 2) < Power2Size) {
+    ubpmem[(tidx + 1) * 2].flag =
+        smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
+    ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
+  }
+  __syncthreads(); // barrier for ubpmem initialization
+
+  // Next, we perform a segmented prefix sum on the neighboring elements, where
+  // the presence of a one indicates the start of a segment. In this case B acts
+  // as the segment start flags, and C is the buffer to be summed:
+  //
+  // Input  (C)  = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
+  // Flag   (B)  = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
+  // Output (C)  = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
+  //
+  // Afterwards, the (index) components of the ubpmem buffer contain the lengths
+  // of the segments (minus 1), i.e. the counts of each element in the original
+  // input.
+  inclusivePrefixScan<Power2Size>(
+      ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
+        ModeUnsignedBoolPair c;
+        c.val = a.flag ? a.val : a.val + b.val;
+        c.flag = a.flag | b.flag;
+        return c;
+      });
+  // assumes scan syncs at the end
+
+  // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e.
+  // we treat the boolean flag regions as integers). We initialize these to
+  // represent indices, and we'll call this buffer I
+  struct ModeUnsignedPair* uupmem =
+      reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);
+
+  // At this point, we need to find the maximum element in lengths buffer C.
+  // This element will represent the count (-1) of the mode. Because of the
+  // way we have set up the problem, the index where this mode occurs will
+  // also be the location of the mode value in the sorted array, e.g.
+  //
+  // smem = [0, 0, 1, 1, 1, 2]
+  // C    = [0, 1, 0, 1, 2, 0]
+  // I    = [0, 1, 2, 3, 4, 5]
+  //                     ^
+  //                     maximum value, also aligned with mode = 1
+  //
+  // We perform a block wide max-reduction of the C buffer, but we also need the
+  // indices to come along with it, so we utilize the uupmem construction.
+  //
+  // At the end we need to return the ModeUnsignedPair containing index = 4, val
+  // = 2, which represents the max
+
+  // In practice, we will make each thread locally reduce 2 values in its
+  // registers prior to the global block-wide reduction. Note that instead of
+  // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with
+  // adjacent elements. This is because the reduce code below relies on thread
+  // elements to be adjacent.
+  struct ModeUnsignedPair uup[2];
+  uup[0].index = tidx * 2;
+  uup[0].val = ubpmem[tidx * 2].val;
+  uup[1].index = tidx * 2 + 1;
+  uup[1].val = ubpmem[tidx * 2 + 1].val;
+  __syncthreads();
+
+  struct ModeUnsignedPair max = {0, 0};
+
+  struct MaxOp {
+    inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
+      return b.val > a.val ? b : a;
+    }
+
+    inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
+      ModeUnsignedPair ret;
+      ret.index = WARP_SHFL_DOWN(acc.index, offset);
+      ret.val = WARP_SHFL_DOWN(acc.val, offset);
+      return ret;
+    }
+  } max_op;
+
+  max = reduceBlockWithNThreadLocalReductions<2>(
+      uupmem,
+      uup,
+      sliceSize,
+      max_op,
+      max);
+
+  // Store the mode in shared memory for use in finding the mode in the input
+  // slice
+  __shared__ T mode;
+
+  // Given the above constraints, the mode is the value at the reduced index in
+  // the original sorted element buffer
+  if (tidx == 0) {
+    mode = smem[max.index];
+  }
+  __syncthreads(); // broadcast mode
+
+  // Finally, we need to find "an" index of the mode in the input
+  // Tensor. The API does not constrain which index we pick, but here
+  // we always pick the largest index. We store the index if the value
+  // is the mode, or 0 otherwise. Then find the maximum value.
+  //
+  // Again we reduce 2 elements in the thread's registers prior to the
+  // block-wide reduction
+  unsigned mode_index[2] = {0u, 0u};
+  if (tidx * 2 < sliceSize) {
+    const unsigned idx = tidx * 2;
+    mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
+  }
+  if (tidx * 2 + 1 < sliceSize) {
+    const unsigned idx = tidx * 2 + 1;
+    mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
+  }
+
+  struct MaxIndexOp {
+    inline __device__ unsigned combine(unsigned a, unsigned b) const {
+      return b > a ? b : a;
+    }
+
+    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
+      return WARP_SHFL_DOWN(acc, offset);
+    }
+  } max_index_op;
+
+  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
+      reinterpret_cast<unsigned*>(&shmem[0]),
+      mode_index,
+      sliceSize,
+      max_index_op,
+      0u);
+
+  // Finally, we have the mode, and an index where it occurs. We use a single
+  // thread to place this in the appropriate output position
+  if (tidx == 0) {
+    unsigned int outputOffset =
+        at::cuda::detail::IndexToOffset::get(
+            blockId, values);
+    values.data[outputOffset] = mode;
+    indices.data[outputOffset] = index;
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4ee6fb5c0fd2a17c4f29e44985fd7c2bab0f8b0a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.h
@@ -0,0 +1,19 @@
+#pragma once
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+
+void launch_fused_mode_kernel(
+    const TensorBase &values, const TensorBase &indices,
+    const TensorBase &self, int64_t slice_size, int64_t slices);
+
+void launch_apply_mode_kernel(
+    const TensorBase &values, const TensorBase &indices,
+    const TensorBase &self, int64_t dim, int64_t ndim);
+
+}}  // namespace at::native
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f34706de244852b83865f9d0f45601fe01e0d5a
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/TensorTopK.h
@@ -0,0 +1,14 @@
+#pragma once
+#include 
+
+namespace at {
+class TensorBase;
+}
+
+namespace at {
+namespace native {
+void launch_gather_topk_kernel(
+    const TensorBase& self,
+    int64_t k, int64_t dim, bool largest,
+    const TensorBase& values, const TensorBase& indices);
+}}
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3b39fbebc8daac24f402a7230b9547706a300b24
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UniqueCub.cuh
@@ -0,0 +1,16 @@
+#include 
+
+namespace at {
+namespace native {
+namespace internal {
+
+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
+    const Tensor& self,
+    const bool consecutive,
+    const bool return_inverse,
+    const bool return_counts);
+
+} // namespace internal
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..11f644f8d2aaa8fd5cba0bf128cf8e0d600f660b
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/UpSample.cuh
@@ -0,0 +1,370 @@
+#pragma once
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+namespace at {
+namespace native {
+
+namespace upsample {
+// TODO: Remove duplicate declaration.
+TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
+    c10::IntArrayRef input_size,  // Full input tensor size.
+    at::OptionalIntArrayRef output_size,
+    c10::optional<c10::ArrayRef<double>> scale_factors);
+} // namespace upsample
+
+namespace upsample_cuda {
+
+// TODO: Remove duplication with Upsample.h (CPU).
+inline c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx) {
+  if (!scales) {
+    return nullopt;
+  }
+  return scales->at(idx);
+}
+
+} // namespace upsample_cuda
+
+
+/* TODO: move this to a common place */
+template <typename scalar_t>
+__device__ inline scalar_t min(scalar_t a, scalar_t b) {
+  return a < b ? a : b;
+}
+
+template <typename scalar_t>
+__device__ inline scalar_t max(scalar_t a, scalar_t b) {
+  return a > b ? a : b;
+}
+
+// NOTE [ Nearest neighbor upsampling kernel implementation ]
+//
+// The nearest neighbor upsampling kernel implementation is symmetrical as
+// expected. We launch kernels with threads mapping to destination tensors where
+// kernels write data to, each thread reads data from the source tensor, this
+// means:
+// 1. In the forward kernel,
+//      src_xxx refers to properties of input tensors;
+//      dst_xxx refers to properties of output tensors;
+//      scale_factor is the ratio of src_size to dst_size;
+// 2. In the backward kernel,
+//      src_xxx refers to properties of grad_output tensors;
+//      dst_xxx refers to properties of grad_input tensors;
+//      scale_factor is the ratio of src_size to dst_size;
+//
+// Because of this, we need to take the reciprocal of the scale defined by
+// upsample layer during forward path. The motivation is to avoid slow
+// division in the kernel code, so we can use faster multiplication instead.
+// This is not necessary during backward path, since the scale_factor is already
+// the reciprocal of corresponding scale_factor used in the forward path due to
+// the swap of source and destination tensor.
+//
+// Similarly, since the mapping from grad_input to grad_output during backward
+// is the reverse of the mapping of output to input, we need to have opposite
+// mapping functions to compute the source index.
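+//
+// Worked example (illustrative): upsampling a width of src_size = 4 to
+// dst_size = 8 with no explicit scale gives scale_factor = 4 / 8 = 0.5 in the
+// forward kernel, so dst index 5 reads from src index
+// nearest_neighbor_compute_source_index(0.5, 5, 4) = min(floor(5 * 0.5), 3) = 2.
+// In the backward kernel the roles swap and the stored factor is already the
+// reciprocal, so no division is needed on the device.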
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+template <typename accscalar_t>
+__host__ __forceinline__ static accscalar_t compute_scales_value(
+    const c10::optional<double> scale,
+    int64_t src_size,
+    int64_t dst_size) {
+  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
+  return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)(1.0 / scale.value())
+                                                   : (accscalar_t)src_size / dst_size;
+}
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+template <typename accscalar_t>
+__host__ __forceinline__ static accscalar_t compute_scales_value_backwards(
+    const c10::optional<double> scale,
+    int64_t src_size,
+    int64_t dst_size) {
+  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
+  return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)scale.value()
+                                                   : (accscalar_t)src_size / dst_size;
+}
+
+template <typename accscalar_t>
+__host__ __forceinline__ static accscalar_t area_pixel_compute_scale(
+    int input_size,
+    int output_size,
+    bool align_corners,
+    const c10::optional<double> scale) {
+  if(align_corners) {
+    if(output_size > 1) {
+      return (accscalar_t)(input_size - 1) / (output_size - 1);
+    }
+    else {
+      return static_cast<accscalar_t>(0);
+    }
+  }
+  else{
+    return compute_scales_value<accscalar_t>(scale, input_size, output_size);
+  }
+}
+
+template <typename accscalar_t>
+__device__ __forceinline__ static accscalar_t area_pixel_compute_source_index(
+    accscalar_t scale,
+    int dst_index,
+    bool align_corners,
+    bool cubic) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    accscalar_t src_idx = scale * (dst_index + static_cast<accscalar_t>(0.5)) -
+        static_cast<accscalar_t>(0.5);
+    // See Note[Follow Opencv resize logic]
+    return (!cubic && src_idx < static_cast<accscalar_t>(0))
+        ? static_cast<accscalar_t>(0)
+        : src_idx;
+  }
+}
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+__device__ __forceinline__ static int nearest_neighbor_compute_source_index(
+    const float scale,
+    int dst_index,
+    int input_size) {
+  // index_f32 = (output_index) * scale
+  // input_index = round(index_f32)
+  // Same as a buggy OpenCV INTER_NEAREST
+  // We keep this method for BC and consider as deprecated.
+  // See nearest_neighbor_exact_compute_source_index as replacement
+  const int src_index =
+      min(static_cast<int>(floorf((dst_index) * scale)), input_size - 1);
+  return src_index;
+}
+
+__device__ __forceinline__ static int nearest_neighbor_exact_compute_source_index(
+    const float scale,
+    int dst_index,
+    int input_size) {
+  // index_f32 = (output_index + 0.5) * scale - 0.5
+  // input_index = round(index_f32)
+  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
+  const int src_index =
+      min(static_cast<int>(floorf((dst_index + static_cast<float>(0.5)) * scale)), input_size - 1);
+  return src_index;
+}
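+
+// Example (illustrative): when downscaling a length-5 row to 3 elements the
+// kernel scale is 5/3 ~= 1.667; for dst_index = 1 the legacy mapping above
+// picks floor(1 * 1.667) = 1, while this exact variant picks
+// floor((1 + 0.5) * 1.667) = 2, matching the Pillow/scipy behaviour.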
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+__device__ __forceinline__ static int nearest_neighbor_bw_compute_source_index(
+    const float scale,
+    int dst_index,
+    int output_size) {
+  // Equivalent to buggy OpenCV INTER_NEAREST
+  // We keep this method for BC and consider it deprecated.
+  // See nearest_neighbor_exact_bw_compute_source_index as replacement
+  const int src_index =
+      min(static_cast<int>(ceilf(dst_index * scale)), output_size);
+  return src_index;
+}
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+__device__ __forceinline__ static int nearest_neighbor_exact_bw_compute_source_index(
+    const float scale,
+    int dst_index,
+    int output_size) {
+  // Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom
+  const int src_index =
+      min(static_cast<int>(ceilf(dst_index * scale - static_cast<float>(0.5))), output_size);
+  return src_index;
+}
+
+/* Used by UpSampleBicubic2d.cu */
+template <typename scalar_t>
+__device__ __forceinline__ static scalar_t upsample_get_value_bounded(
+    const PackedTensorAccessor64<const scalar_t, 4>& data,
+    int batch,
+    int channel,
+    int height,
+    int width,
+    int y,
+    int x) {
+  int access_y = max(min(y, height - 1), 0);
+  int access_x = max(min(x, width - 1), 0);
+  return data[batch][channel][access_y][access_x];
+}
+
+/* Used by UpSampleBicubic2d.cu */
+template <typename scalar_t, typename accscalar_t>
+__device__ __forceinline__ static void upsample_increment_value_bounded(
+    PackedTensorAccessor64<scalar_t, 4>& data,
+    int batch,
+    int channel,
+    int height,
+    int width,
+    int y,
+    int x,
+    accscalar_t value) {
+  int access_y = max(min(y, height - 1), 0);
+  int access_x = max(min(x, width - 1), 0);
+  /* TODO: result here is truncated to scalar_t,
+     check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912
+   */
+  gpuAtomicAddNoReturn(
+      &data[batch][channel][access_y][access_x], static_cast<scalar_t>(value));
+}
+
+// Based on
+// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+template <typename accscalar_t>
+__device__ __forceinline__ static accscalar_t cubic_convolution1(
+    accscalar_t x,
+    accscalar_t A) {
+  return ((A + 2) * x - (A + 3)) * x * x + 1;
+}
+
+template <typename accscalar_t>
+__device__ __forceinline__ static accscalar_t cubic_convolution2(
+    accscalar_t x,
+    accscalar_t A) {
+  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+}
+
+template <typename accscalar_t>
+__device__ __forceinline__ static void get_cubic_upsampling_coefficients(
+    accscalar_t coeffs[4],
+    accscalar_t t) {
+  accscalar_t A = -0.75;
+
+  accscalar_t x1 = t;
+  coeffs[0] = cubic_convolution2<accscalar_t>(x1 + 1.0, A);
+  coeffs[1] = cubic_convolution1<accscalar_t>(x1, A);
+
+  // opposite coefficients
+  accscalar_t x2 = 1.0 - t;
+  coeffs[2] = cubic_convolution1<accscalar_t>(x2, A);
+  coeffs[3] = cubic_convolution2<accscalar_t>(x2 + 1.0, A);
+}
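+
+// Worked example (illustrative): at t = 0.5 with A = -0.75 the coefficients
+// come out as [-0.09375, 0.59375, 0.59375, -0.09375]; they sum to 1, so
+// cubic_interp1d below is an affine combination of the four taps.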
+
+template <typename scalar_t, typename accscalar_t>
+__device__ __forceinline__ static accscalar_t cubic_interp1d(
+    scalar_t x0,
+    scalar_t x1,
+    scalar_t x2,
+    scalar_t x3,
+    accscalar_t t) {
+  accscalar_t coeffs[4];
+  get_cubic_upsampling_coefficients<accscalar_t>(coeffs, t);
+
+  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
+}
+
+namespace upsample_antialias {
+
+// taken from
+// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+// src/libImaging/Resample.c#L20-L29
+struct BilinearFilterFunctor {
+
+  template <typename accscalar_t>
+  __device__ accscalar_t operator()(accscalar_t x) const {
+    if (x < 0) {
+      x = -x;
+    }
+    if (x < 1) {
+      return 1 - x;
+    }
+    return 0;
+  }
+
+  static const int size = 2;
+};
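+
+// Illustrative values of the triangle (tent) filter above: f(0) = 1,
+// f(+/-0.5) = 0.5, and f(x) = 0 for |x| >= 1, i.e. a support of `size` = 2
+// source pixels per output pixel before antialias stretching.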
+
+// taken from
+// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+// src/libImaging/Resample.c#L46-L62
+struct BicubicFilterFunctor {
+
+  template <typename accscalar_t>
+  __device__ accscalar_t operator()(accscalar_t x) const {
+    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+    const accscalar_t a = -0.5;
+    if (x < 0) {
+      x = -x;
+    }
+    if (x < 1) {
+      return ((a + 2) * x - (a + 3)) * x * x + 1;
+    }
+    if (x < 2) {
+      return (((x - 5) * x + 8) * x - 4) * a;
+    }
+    return 0;
+  }
+
+  static const int size = 4;
+};
+
+template <typename accscalar_t>
+__device__ __forceinline__ static void _compute_weights_span(
+    const int i,
+    const int input_size,
+    const accscalar_t scale,
+    const accscalar_t support,
+    int& xmin,
+    int& xsize,
+    accscalar_t& center) {
+  center = scale * (i + static_cast<accscalar_t>(0.5));
+  xmin = max(static_cast<int>(center - support + static_cast<accscalar_t>(0.5)), static_cast<int>(0));
+  xsize = min(static_cast<int>(center + support + static_cast<accscalar_t>(0.5)), input_size) - xmin;
+}
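+
+// Example (illustrative, assuming support ~= scale for the bilinear filter,
+// which is how callers typically size it when downscaling): shrinking a
+// width-10 row to 5 gives scale = 2; for output index i = 1,
+// center = 2 * 1.5 = 3.0, xmin = max(int(3 - 2 + 0.5), 0) = 1 and
+// xsize = min(int(3 + 2 + 0.5), 10) - 1 = 4, so the weighted sum reads
+// source columns 1..4.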
+
+template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
+__device__ __forceinline__ static void _compute_weights(
+    scalar_t* wt_ptr,
+    const accscalar_t scale,
+    int interp_size,
+    const interp_filter_t& interp_filter,
+    accscalar_t xmin_m_center,
+    int xsize) {
+
+  accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
+  accscalar_t total_w = 0.0;
+  int j = 0;
+  for (j = 0; j < xsize; j++) {
+    accscalar_t w = interp_filter((j + xmin_m_center + static_cast<accscalar_t>(0.5)) * invscale);
+    wt_ptr[j] = static_cast<scalar_t>(w);
+    total_w += w;
+  }
+  for (j = 0; j < xsize; j++) {
+    if (total_w != 0.0) {
+      wt_ptr[j] /= total_w;
+    }
+  }
+  for (; j < interp_size; j++) {
+    wt_ptr[j] = static_cast<scalar_t>(0.0);
+  }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
+    const scalar_t* src,
+    const scalar_t* weights,
+    int size) {
+  scalar_t t = static_cast<scalar_t>(*src);
+  scalar_t wts = static_cast<scalar_t>(weights[0]);
+  accscalar_t output = t * wts;
+
+  int j = 1;
+  for (; j < size; j++) {
+    wts = static_cast<scalar_t>(weights[j]);
+    t = static_cast<scalar_t>(*(src + j));
+    output += t * wts;
+  }
+  return output;
+}
+
+} // namespace upsample_antialias
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1366dbd3a2adc2fee0f672eaca4df59efdc5c883
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh
@@ -0,0 +1,143 @@
+#pragma once
+
+#include <thrust/tuple.h>
+
+#include <ATen/NumericUtils.h>
+#include <ATen/cuda/DeviceUtils.cuh>
+
+namespace at {
+namespace native {
+namespace cuda_utils {
+
+constexpr int kCUDABlockReduceNumThreads = 512;
+// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
+// of which reduces C10_WARP_SIZE elements. So, at most
+// C10_WARP_SIZE**2 elements can be reduced at a time.
+// NOTE: This is >= the max block size on current hardware anyway (1024).
+constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
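+// For example, with C10_WARP_SIZE == 32 (NVIDIA GPUs) this is 32 * 32 = 1024
+// threads; with a wavefront size of 64 (AMD/ROCm) it is 64 * 64 = 4096.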
+
+// Sums `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceSum(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val += WARP_SHFL_DOWN(val, offset);
+  }
+  return val;
+}
+
+// Picks the maximum `val` across all threads in a warp.
+//
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+template <typename T>
+__inline__ __device__ T WarpReduceMax(T val) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
+  }
+  return val;
+}
+
+struct Block1D {
+    static __forceinline__ __device__ int Tid() { return threadIdx.x; }
+
+    static __forceinline__ __device__ int Warps() {
+        return blockDim.x / C10_WARP_SIZE;
+    }
+};
+
+struct Block2D {
+    static __forceinline__ __device__ int Tid() {
+        return threadIdx.x + threadIdx.y * blockDim.x;
+    }
+
+    static __forceinline__ __device__ int Warps() {
+        return blockDim.x * blockDim.y / C10_WARP_SIZE;
+    }
+};
+
+// Sums `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceSum(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceSum(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceSum(val);
+  }
+  return val;
+}
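+
+// Hedged usage sketch (illustrative only, not part of this header): a caller
+// typically provides one shared slot per warp and keeps the block result from
+// thread 0, e.g.
+//
+//   __global__ void row_sum_kernel(const float* in, float* out, int n) {
+//     // one slot per warp when launching 512 threads with a warp size of 32
+//     __shared__ float shared[16];
+//     const int i = blockIdx.x * blockDim.x + threadIdx.x;
+//     float val = (i < n) ? in[i] : 0.f;
+//     val = BlockReduceSum<float>(val, shared);
+//     if (threadIdx.x == 0) {
+//       out[blockIdx.x] = val;
+//     }
+//   }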
+
+// Picks out the maximum `val` across all threads in a block.
+//
+// Warning: the return value is only valid for thread 0.
+// Assumptions:
+//   - The size of each block should be a multiple of `C10_WARP_SIZE`
+//   - `shared` should be a pointer to shared memory with size of, at least,
+//     `sizeof(T) * number_of_warps`
+template <typename T, typename B = Block1D>
+__inline__ __device__ T BlockReduceMax(T val, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduceMax(val);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : T(0);
+  if (wid == 0) {
+    val = WarpReduceMax(val);
+  }
+  return val;
+}
+
+template <typename T, class ReduceOp>
+__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
+#pragma unroll
+  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
+    val = op.combine(val, op.warp_shfl_down(val, offset));
+  }
+  return val;
+}
+
+template <typename T, class ReduceOp, typename B = Block1D>
+__inline__ __device__ T
+BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
+  const int tid = B::Tid();
+  const int lid = tid % C10_WARP_SIZE;
+  const int wid = tid / C10_WARP_SIZE;
+  val = WarpReduce(val, op);
+  __syncthreads(); // prevent races when BlockReduces are called in a row.
+  if (lid == 0) {
+    shared[wid] = val;
+  }
+  __syncthreads();
+  val = (tid < B::Warps()) ? shared[lid] : identity_element;
+  if (wid == 0) {
+    val = WarpReduce(val, op);
+  }
+  return val;
+}
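+
+// Hedged sketch of a minimal ReduceOp (illustrative; any functor exposing
+// combine() and warp_shfl_down() with these shapes should work here):
+//
+//   struct MaxOp {
+//     __device__ float combine(float a, float b) const { return a > b ? a : b; }
+//     __device__ float warp_shfl_down(float v, int offset) const {
+//       return WARP_SHFL_DOWN(v, offset);
+//     }
+//   };
+//   // val = BlockReduce<float, MaxOp, Block1D>(val, MaxOp{},
+//   //                                          -FLT_MAX /*identity*/, shared);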
+
+} // namespace cuda_utils
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..211cb62dcae2bf5feea8bca8dbf5e48f3d81dee9
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh
@@ -0,0 +1,40 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+
+void _fused_adam_amsgrad_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList max_exp_avg_sqs,
+    at::TensorList state_steps,
+    const double lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+void _fused_adam_amsgrad_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList max_exp_avg_sqs,
+    at::TensorList state_steps,
+    const at::Tensor& lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d94d65c9c2ba07f3f3b1b6342abd8ed8da1206b8
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh
@@ -0,0 +1,38 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+
+void _fused_adam_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList state_steps,
+    const double lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+void _fused_adam_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList state_steps,
+    const at::Tensor& lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..be681ad162b5951d733ef3efc0c764d5c0d45d20
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh
@@ -0,0 +1,202 @@
+#pragma once
+#include <ATen/OpMathType.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/native/cuda/ForeachFunctors.cuh>
+#include <ATen/native/cuda/MultiTensorApply.cuh>
+
+namespace at {
+namespace native {
+
+enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
+
+namespace {
+
+constexpr uint8_t kParamIdx = 0;
+constexpr uint8_t kGradIdx = 1;
+constexpr uint8_t kExpAvgIdx = 2;
+constexpr uint8_t kExpAvgSqIdx = 3;
+constexpr uint8_t kMaxExpAvgSqIdx = 4;
+
+template <
+    typename scalar_type,
+    typename opmath_t,
+    int depth,
+    ADAM_MODE adam_mode,
+    bool amsgrad>
+C10_DEVICE inline void adam_math(
+    scalar_type r_args[depth][kILP],
+    const double& lr,
+    const double& beta1,
+    const double& beta2,
+    const double& weight_decay,
+    const double& eps,
+    const bool& maximize,
+    const float* grad_scale_ptr,
+    const float* found_inf_ptr,
+    const opmath_t& bias_correction1,
+    const opmath_t& bias_correction2_sqrt) {
+  static_assert(depth == 4 || depth == 5);
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    // Load values.
+    opmath_t param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
+    opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
+    if (grad_scale_ptr) {
+      grad /= (static_cast<opmath_t>(*grad_scale_ptr));
+    }
+    const opmath_t grad_to_store = grad;
+    if (maximize) {
+      grad = -grad;
+    }
+    opmath_t exp_avg = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);
+    opmath_t exp_avg_sq = static_cast<opmath_t>(r_args[kExpAvgSqIdx][ii]);
+    opmath_t max_exp_avg_sq;
+    if (amsgrad) {
+      max_exp_avg_sq = static_cast<opmath_t>(r_args[kMaxExpAvgSqIdx][ii]);
+    }
+    // Update param, grad, 1st and 2nd order momentum.
+    if (weight_decay != 0) {
+      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
+        grad += param * weight_decay;
+      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
+        param -= lr * weight_decay * param;
+      }
+    }
+    // todo(crcrpar): use lerp
+    // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/
+    exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
+    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
+    const opmath_t step_size = lr / bias_correction1;
+    opmath_t denom;
+    if (amsgrad) {
+      max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq);
+      denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
+    } else {
+      denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
+    }
+    param -= step_size * exp_avg / denom;
+
+    // Store results.
+    r_args[kParamIdx][ii] = param;
+    if (grad_scale_ptr) {
+      r_args[kGradIdx][ii] = grad_to_store;
+    }
+    r_args[kExpAvgIdx][ii] = exp_avg;
+    r_args[kExpAvgSqIdx][ii] = exp_avg_sq;
+    if (amsgrad) {
+      r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq;
+    }
+  }
+}
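+
+// Worked example (illustrative numbers): with lr = 1e-3, beta1 = 0.9,
+// beta2 = 0.999, eps = 1e-8, weight_decay = 0, step = 1, grad = 0.1 and all
+// running stats starting at zero:
+//   exp_avg    = 0.9 * 0 + 0.1 * 0.1     = 0.01
+//   exp_avg_sq = 0.999 * 0 + 0.001 * 0.01 = 1e-5
+//   bias_correction1 = 1 - 0.9^1 = 0.1, bias_correction2_sqrt = sqrt(1 - 0.999) ~= 0.0316
+//   step_size = 1e-3 / 0.1 = 0.01, denom = sqrt(1e-5) / 0.0316 + 1e-8 ~= 0.1
+//   param update ~= -0.01 * 0.01 / 0.1 = -1e-3, i.e. the first step moves by ~lr.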
+
+// [note: Conditional Gradient Store when `optimizer.step` is called by
+// GradScaler] When a user is training their model(s) with an FP16 AMP recipe,
+// parameter updates are done via `grad_scaler.step(optimizer)` instead of
+// `optimizer.step()`. For most optimizers, GradScaler unscales gradients on
+// behalf of those optimizers. Also, before `.step`, it makes sure that all the
+// gradients involved are finite, which incurs a device sync. On the other hand,
+// fused optimizers set their member variable of `_step_supports_amp_scaling` to
+// `True` in order to remove the device sync above. This means that fused
+// optimizers have to have their CUDA kernels (a) unscale gradients and (b) skip
+// parameter updates accordingly. To be functionally on par with `torch.optim`
+// optimizers and `_multi_tensor` ones, the kernel below writes out gradients
+// only when `grad_scale_ptr != nullptr`.
+template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
+struct FusedAdamMathFunctor {
+  static_assert(
+      depth == 4 || depth == 5,
+      "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
+  using opmath_t = at::opmath_type<scalar_type>;
+  C10_DEVICE __forceinline__ void operator()(
+      int chunk_size,
+      FusedOptimizerTensorListMetadata<depth>& tl,
+      const float* lr_ptr,
+      const double& lr,
+      const double& beta1,
+      const double& beta2,
+      const double& weight_decay,
+      const double& eps,
+      const bool& maximize,
+      const float* grad_scale_ptr,
+      const float* found_inf_ptr) {
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    const double lr_double = lr_ptr ? *lr_ptr : lr;
+
+    if (found_inf_ptr && *found_inf_ptr == 1) {
+      return;
+    }
+    const auto [bias_correction1, bias_correction2_sqrt] =
+        [&]() -> std::pair<double, double> {
+      auto* step_count =
+          reinterpret_cast<const float*>(tl.state_steps_addresses[tensor_loc]);
+      const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count);
+      const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count);
+      const auto bias_correction2_sqrt = std::sqrt(bias_correction2);
+      return {bias_correction1, bias_correction2_sqrt};
+    }();
+
+    scalar_type* args[depth];
+    scalar_type r_args[depth][kILP];
+    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;
+
+    const bool all_aligned{
+        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
+    if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+#pragma unroll
+        for (int i = 0; i < depth; i++) {
+          load_store(r_args[i], args[i], 0, i_start);
+        }
+        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
+            r_args,
+            lr_double,
+            beta1,
+            beta2,
+            weight_decay,
+            eps,
+            maximize,
+            grad_scale_ptr,
+            found_inf_ptr,
+            bias_correction1,
+            bias_correction2_sqrt);
+#pragma unroll
+        for (int i = 0; i < depth; i++) {
+          if (i != kGradIdx || grad_scale_ptr) {
+            load_store(args[i], r_args[i], i_start, 0);
+          }
+        }
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args<depth>(r_args, args, i_start, chunk_size, n);
+        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
+            r_args,
+            lr_double,
+            beta1,
+            beta2,
+            weight_decay,
+            eps,
+            maximize,
+            grad_scale_ptr,
+            found_inf_ptr,
+            bias_correction1,
+            bias_correction2_sqrt);
+#pragma unroll
+        for (int i = 0; i < depth; i++) {
+          if (i != kGradIdx || grad_scale_ptr) {
+            store_args(args[i], r_args[i], i_start, chunk_size, n);
+          }
+        }
+      }
+    }
+  }
+};
+} // namespace
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..06782055593021b1301f8f670b444db5af9001da
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh
@@ -0,0 +1,40 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+
+void _fused_adamw_amsgrad_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList max_exp_avg_sqs,
+    at::TensorList state_steps,
+    const double lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+void _fused_adamw_amsgrad_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList max_exp_avg_sqs,
+    at::TensorList state_steps,
+    const at::Tensor& lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6d454ed93960ec71fc598ee1f8a2cd96f01301d2
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh
@@ -0,0 +1,38 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at {
+namespace native {
+
+void _fused_adamw_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList state_steps,
+    const double lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+void _fused_adamw_cuda_impl_(
+    at::TensorList params,
+    at::TensorList grads,
+    at::TensorList exp_avgs,
+    at::TensorList exp_avg_sqs,
+    at::TensorList state_steps,
+    const at::Tensor& lr,
+    const double beta1,
+    const double beta2,
+    const double weight_decay,
+    const double eps,
+    const bool maximize,
+    const c10::optional<at::Tensor>& grad_scale,
+    const c10::optional<at::Tensor>& found_inf);
+
+} // namespace native
+} // namespace at
diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..8829ecb6155cb12b91b35575baa29973fb963ebb
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/im2col.cuh
@@ -0,0 +1,345 @@
+#pragma once
+
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+
+#include <c10/macros/Macros.h>
+
+namespace at {
+namespace native {
+
+using namespace at::cuda::detail;
+
+// Kernel for fast unfold+copy
+// (borrowed from Caffe:
+// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
+// CUDA_NUM_THREADS = 1024
+
+template <typename dt>
+C10_LAUNCH_BOUNDS_1(1024)
+__global__ void im2col_kernel(
+    const int64_t n,
+    const dt* data_im,
+    const int64_t height,
+    const int64_t width,
+    const int64_t kernel_height,
+    const int64_t kernel_width,
+    const int64_t pad_height,
+    const int64_t pad_width,
+    const int64_t stride_height,
+    const int64_t stride_width,
+    const int64_t dilation_height,
+    const int64_t dilation_width,
+    const int64_t height_col,
+    const int64_t width_col,
+    dt* data_col) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int64_t w_out = index % width_col;
+
+    int64_t idx = index / width_col;
+
+    int64_t h_out = idx % height_col;
+    int64_t channel_in = idx / height_col;
+    int64_t channel_out = channel_in * kernel_height * kernel_width;
+    int64_t h_in = h_out * stride_height - pad_height;
+    int64_t w_in = w_out * stride_width - pad_width;
+
+    dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
+    const dt* im = data_im + (channel_in * height + h_in) * width + w_in;
+
+    for (int64_t i = 0; i < kernel_height; ++i) {
+      for (int64_t j = 0; j < kernel_width; ++j) {
+        int64_t h = h_in + i * dilation_height;
+        int64_t w = w_in + j * dilation_width;
+        *col = (h >= 0 && w >= 0 && h < height && w < width)
+            ? im[i * dilation_height * width + j * dilation_width]
+            : static_cast<dt>
(0); + col += height_col * width_col; + } + } + } +} + +template +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__forceinline__ __device__ void col2im_device( + const int64_t index, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + accT val = static_cast(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
(val); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + col2im_device( + index, + data_col, + height, + width, + channels, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + } +} + +template +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_kernel + <<>>( + num_kernels, + data_col, + height, + width, + channels, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_batched_kernel( + const int64_t n, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im, + const int64_t im_batch_stride) { + using accT = at::acc_type; + const auto im_numel = n * nbatch; + + CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) { + const auto ibatch = index / n; + const auto slice_index = index % n; + + col2im_device( + slice_index, + data_col + ibatch * col_batch_stride, + height, + width, + channels, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im + ibatch * im_batch_stride); + } +} + +template +void col2im_batched( + cudaStream_t stream, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im, + const int64_t im_batch_stride) { + const int64_t num_kernels = channels * height * width; + const int64_t output_numel = nbatch * num_kernels; + if (output_numel == 0) { + return; // No work to do + } + + // To avoid 
involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_batched_kernel<<>>( + num_kernels, + data_col, + col_batch_stride, + nbatch, + height, + width, + channels, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im, + im_batch_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..38af0cc125d5f191ea7d6321853198a3ac79d11f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/jit_utils.h @@ -0,0 +1,215 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace at { namespace cuda { namespace jit { + +enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; + +struct NvrtcFunction { + CUmodule module = CUmodule(); + CUfunction function = nullptr; +}; + +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + int result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func, + const std::string& name, + const std::string& f_input_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + bool 
vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_reduction_code( + int nOutputs, + const std::string& func, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +NvrtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name); + +void launch_jitted_pwise_function( + NvrtcFunction function, + void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem=0); + +template +struct delayed_false : std::false_type { +}; + +// Defines type names +// NOTE: General case is instantiated only for invalid types. +// All the valid types have specialization using the TYPE_NAME_FN +// macro below. +template +inline std::string typeName() { + // we can't use static_assert(false) directly as the + // program will be not compiled even if the template is not + // instantiated, so we use `delayed_false` + // to make sure compiler doesn't eagerly raise + // fail this assertion. + static_assert(delayed_false::value, "invalid type for jiterator"); + return "void"; +} + +#define TYPE_NAME_FN(ctype, name) \ +template <> inline std::string typeName(){ \ + return std::string(#ctype); \ +} + +AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) +#undef TYPE_NAME_FN +// JIT uses std::complex directly, because nvRTC compile programs +// with -default-device, so there is no such issue like: +// "std::sin(complex) is __host__ only" +template <> inline std::string typeName(){ + return "bool"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName(){ + return "at::Half"; +} +template <> inline std::string typeName(){ + return "at::BFloat16"; +} +template <> inline std::string typeName(){ + return "at::Float8_e5m2"; +} +template <> inline std::string typeName(){ + return "at::Float8_e4m3fn"; +} +template <> inline std::string typeName() { + return "at::Float8_e5m2fnuz"; +} +template <> inline std::string typeName() { + return "at::Float8_e4m3fnuz"; +} + +#define TYPE_NAME_CASE(ctype, scalartype) \ + case ScalarType::scalartype: return typeName(); +inline std::string typeName(ScalarType t) { + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); + } +} +#undef TYPE_NAME_CASE + +TORCH_CUDA_CPP_API void initializeCudaContext(); + +}}} // namespace at::cuda::jit diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6350c44eab91827ac7a7fc1df75ca7f88ad44c7e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh @@ -0,0 +1,680 @@ +namespace at { +namespace cuda { +//windows doesn't like large string literals, so split in two +const std::string reduction_template_0 = R"ESCAPE( + #define C10_HOST_DEVICE __host__ __device__ + #define C10_DEVICE __device__ + #if defined(__clang__) && defined(__HIP__) + #ifndef 
__forceinline__ + #define __forceinline__ inline __attribute__((always_inline)) + #endif + // until ROCm support for kernel asserts is restored + #define assert(expr) (static_cast(0)) + #endif + + template + __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + #if defined(__clang__) && defined(__HIP__) + return __shfl_down(value, delta, width); + #else + return __shfl_down_sync(mask, value, delta, width); + #endif + } + + + #if ${complex} + template + __device__ __forceinline__ std::complex WARP_SHFL_DOWN(std::complex value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) + { + return std::complex( + #if defined(__clang__) && defined(__HIP__) + __shfl_down(value.real(), delta, width), + __shfl_down(value.imag(), delta, width)); + #else + __shfl_down_sync(mask, value.real(), delta, width), + __shfl_down_sync(mask, value.imag(), delta, width)); + #endif + } + #endif + + // aligned vector generates vectorized load/store on CUDA + template + struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; + }; + + + C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) { + // get GCD of num and denom using Euclid's algorithm. + // Can replace this with std::gcd if we ever support c++17. + size_t a = denominator; + size_t b = numerator; + while (b != 0) { + a %= b; + // swap(a,b) + size_t tmp = a; + a = b; + b = tmp; + } + + // a is now the GCD + numerator /= a; + denominator /= a; + } + + + + + struct ReduceConfig { + //has to match host-side ReduceConfig in the eager code + static constexpr int BLOCK_X = 0; + static constexpr int BLOCK_Y = 1; + static constexpr int CTA = 2; + + static constexpr int input_vec_size = 4; + int element_size_bytes; + int num_inputs; + int num_outputs; + int step_input = 1; + int step_output = 1; + int ctas_per_output = 1; + int input_mult[3] = {0, 0, 0}; + int output_mult[2] = {0, 0}; + + int block_width; + int block_height; + int num_threads; + + bool vectorize_input = false; + int output_vec_size = 1; + + C10_HOST_DEVICE bool should_block_x_reduce() const { + return input_mult[BLOCK_X] != 0; + } + + C10_HOST_DEVICE bool should_block_y_reduce() const { + return input_mult[BLOCK_Y] != 0; + } + + C10_HOST_DEVICE bool should_global_reduce() const { + return input_mult[CTA] != 0; + } + + C10_DEVICE bool should_store(int output_idx) const { + return output_idx < num_outputs && + (!should_block_x_reduce() || threadIdx.x == 0) && + (!should_block_y_reduce() || threadIdx.y == 0); + } + + C10_DEVICE bool should_reduce_tail() const { + return (!should_block_y_reduce() || threadIdx.y == 0) && + (!should_global_reduce() || blockIdx.y == 0); + } + + C10_HOST_DEVICE int input_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta2 = blockIdx.y; + return (lane * input_mult[BLOCK_X] + + warp * input_mult[BLOCK_Y] + + cta2 * input_mult[CTA]); + } + + template + C10_HOST_DEVICE int output_idx() const { + int lane = threadIdx.x; + int warp = threadIdx.y; + int cta1 = blockIdx.x; + return (lane * output_mult[BLOCK_X] + + warp * output_mult[BLOCK_Y] + + cta1 * step_output) * output_vec_size; + } + + C10_DEVICE int shared_memory_offset(int offset) const { + return threadIdx.x + (threadIdx.y + offset) * blockDim.x; + } + + C10_DEVICE int staging_memory_offset(int cta2) const { + int offset = cta2 + blockIdx.x * gridDim.y; + if (!should_block_x_reduce()) { + offset = threadIdx.x + offset * blockDim.x; + } + 
return offset; + } + + + }; + + +//TODO this will need to be different for more generic reduction functions +namespace reducer { + + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + + inline __device__ ${functor} + + inline __device__ out_scalar_t project(arg_t arg) { + return (out_scalar_t) arg; + } + + inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) { + return WARP_SHFL_DOWN(arg, offset); + } + + inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) { + return acc; + } + + // wrap a normal reduction that ignores the index + inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) { + return combine(acc, val); + } +} + + +struct ReduceJitOp { + using scalar_t = ${scalar_type}; + using arg_t = ${reduction_accum_type}; + using out_scalar_t = ${result_type}; + + using InputCalculator = OffsetCalculator<1>; + using OutputCalculator = OffsetCalculator<2>; + +// static constexpr bool can_accumulate_in_output = +// std::is_convertible::value +// && std::is_convertible::value; + + static constexpr int input_vec_size = ReduceConfig::input_vec_size; + + arg_t ident; + ReduceConfig config; + InputCalculator input_calc; + OutputCalculator output_calc; + const void* src; + const char* dst[2]; //it accepts at most two destinations + // acc_buf used for accumulation among sub Tensor Iterator when accumulation on + // output is not permissible + void* acc_buf; + // cta_buf used for accumulation between blocks during global reduction + void* cta_buf; + int* semaphores; + int64_t base_idx; + bool accumulate; + bool final_output; + int noutputs; + + + C10_DEVICE void run() const { + extern __shared__ char shared_memory[]; + uint32_t output_idx = config.output_idx<${output_vec_size}>(); + uint32_t input_idx = config.input_idx(); + auto base_offsets1 = output_calc.get(output_idx)[1]; + + using arg_vec_t = Array; + arg_vec_t value; + + if (output_idx < config.num_outputs && input_idx < config.num_inputs) { + const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1); + + value = thread_reduce<${output_vec_size}>(input_slice); + } + + if (config.should_block_y_reduce()) { + value = block_y_reduce<${output_vec_size}>(value, shared_memory); + } + if (config.should_block_x_reduce()) { + value = block_x_reduce<${output_vec_size}>(value, shared_memory); + } + + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + arg_vec_t* acc = nullptr; + if (acc_buf != nullptr) { + size_t numerator = sizeof(arg_t); + size_t denominator = sizeof(out_scalar_t); + reduce_fraction(numerator, denominator); + acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator)); + } + + if (config.should_global_reduce()) { + value = global_reduce<${output_vec_size}>(value, acc, shared_memory); + } else if (config.should_store(output_idx)) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output<${output_vec_size}>(out, value); + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; 
i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < ${output_vec_size}; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output<${output_vec_size}>(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + template + C10_DEVICE Array thread_reduce(const scalar_t* data) const { + if (config.vectorize_input) { + assert(output_vec_size == 1); + // reduce at the header of input_slice where memory is not aligned, + // so that thread_reduce will have an aligned memory to work on. + return {input_vectorized_thread_reduce_impl(data)}; + } else { + uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t); + bool is_contiguous = (input_calc.dims == 1 && element_stride == 1); + if (is_contiguous) { + return thread_reduce_impl(data, [](uint32_t idx) { return idx; }); + } else if (input_calc.dims == 1) { + return thread_reduce_impl(data, [&](uint32_t idx) { return idx * element_stride; }); + } else { + return thread_reduce_impl(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); }); + } + } + } + + C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const { + uint32_t end = config.num_inputs; + + // Handle the head of input slice where data is not aligned + arg_t value = ident; + constexpr int align_bytes = alignof(aligned_vector); + constexpr int align_elements = align_bytes / sizeof(scalar_t); + int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t); + if (shift > 0) { + data -= shift; + end += shift; + if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){ + value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift); + } + end -= align_elements; + data += align_elements; + shift = align_elements - shift; + } + + // Do the vectorized reduction + using load_t = aligned_vector; + + uint32_t idx = config.input_idx(); + const uint32_t stride = config.step_input; + + // Multiple accumulators to remove dependency between unrolled loops. + arg_t value_list[input_vec_size]; + value_list[0] = value; + + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[i] = ident; + } + + scalar_t values[input_vec_size]; + + load_t *values_vector = reinterpret_cast(&values[0]); + + while (idx * input_vec_size + input_vec_size - 1 < end) { + *values_vector = reinterpret_cast(data)[idx]; + #pragma unroll + for (uint32_t i = 0; i < input_vec_size; i++) { + value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i); + } + idx += stride; + } + + // tail + uint32_t tail_start = end - end % input_vec_size; + if (config.should_reduce_tail()) { + int idx = tail_start + threadIdx.x; + if (idx < end) { + value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift); + } + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < input_vec_size; i++) { + value_list[0] = reducer::combine(value_list[0], value_list[i]); + } + return value_list[0]; + } + + template + C10_DEVICE Array thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const { + uint32_t idx = config.input_idx(); + const uint32_t end = config.num_inputs; + const uint32_t stride = config.step_input; + const int vt0=${vt0}; + + using arg_vec_t = Array; + using load_t = aligned_vector; + const load_t* data = reinterpret_cast(data_); + + // Multiple accumulators to remove dependency between unrolled loops. 
+ arg_vec_t value_list[vt0]; + + #pragma unroll + for (int i = 0; i < vt0; i++) { + #pragma unroll + for (int j = 0; j < output_vec_size; j++) { + value_list[i][j] = ident; + } + } + + load_t values[vt0]; + + while (idx + (vt0 - 1) * stride < end) { + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + values[i] = data[calc(idx + i * stride) / output_vec_size]; + } + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride); + } + } + idx += stride * vt0; + } + + // tail + int idx_ = idx; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + values[i] = data[calc(idx) / output_vec_size]; + idx += stride; + } + idx = idx_; + #pragma unroll + for (uint32_t i = 0; i < vt0; i++) { + if (idx >= end) { + break; + } + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx); + } + idx += stride; + } + + // combine accumulators + #pragma unroll + for (int i = 1; i < vt0; i++) { + #pragma unroll + for (uint32_t j = 0; j < output_vec_size; j++) { + value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]); + } + } + return value_list[0]; + } + template + C10_DEVICE Array block_x_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + int dim_x = blockDim.x; + args_vec_t* shared = (args_vec_t*)shared_memory; + if (dim_x > warpSize) { + int address_base = threadIdx.x + threadIdx.y*blockDim.x; + shared[address_base] = value; + for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + __syncthreads(); + if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { + args_vec_t other = shared[address_base + offset]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[address_base] = value; + } + } + dim_x = warpSize; + } + + __syncthreads(); + + for (int offset = 1; offset < dim_x; offset <<= 1) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + arg_t other = reducer::warp_shfl_down(value[i], offset); + value[i] = reducer::combine(value[i], other); + } + } + return value; + } + + template + C10_DEVICE Array block_y_reduce(Array value, char* shared_memory) const { + using args_vec_t = Array; + args_vec_t* shared = (args_vec_t*)shared_memory; + shared[config.shared_memory_offset(0)] = value; + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + args_vec_t other = shared[config.shared_memory_offset(offset)]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], other[i]); + } + shared[config.shared_memory_offset(0)] = value; + } + } + return value; + } + )ESCAPE"; + + const std::string reduction_template_1 = R"ESCAPE( + + C10_DEVICE bool mark_block_finished() const { + __shared__ bool is_last_block_done_shared; + + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0) { + int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1); + } + + __syncthreads(); + + return is_last_block_done_shared; + } + + template + C10_DEVICE Array accumulate_in_output( + Array out, + Array value + ) const { + Array ret; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + 
ret[i] = reducer::combine(*(out[i]), value[i]); + } + return ret; + } + + + C10_DEVICE out_scalar_t get_accumulated_output( + out_scalar_t* out, arg_t value + ) const { + assert(!final_output); + return (out_scalar_t)value; + } + + template + C10_DEVICE void set_results(const T x, const uint32_t base_offset) const { + assert(noutputs == 1); + auto res = (out_scalar_t*)((char*)dst[0] + base_offset); + *res = x; + } + +//TODO - multi-output reduction - we won't be able to use thrust::pair +//just explicitly specify typed output reads/writes +//Currently implemented for max of two outputs +// template +// C10_DEVICE void set_results(const thrust::pair x, const index_t base_offset) const { +// if (noutputs >= 1) { +// auto res0 = (T1*)((char*)dst[0] + base_offset); +// *res0 = x.first; +// } +// if (noutputs >= 2) { +// // base offset is computed assuming element size being sizeof(T1), so we need to make a +// // correction to obtain the correct base offset +// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2)); +// *res1 = x.second; +// } +// } + + template + C10_DEVICE void set_results_to_output(Array value, Array base_offset) const { + assert(final_output); + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + set_results(reducer::project(value[i]), base_offset[i]); + } + } + + template + C10_DEVICE Array global_reduce(Array value, Array *acc, char* shared_memory) const { + using arg_vec_t = Array; + using out_ptr_vec_t = Array; + using offset_vec_t = Array; + + arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf; + uint32_t output_idx = config.output_idx(); + offset_vec_t base_offsets; + out_ptr_vec_t out; + + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + base_offsets[i] = output_calc.get(output_idx + i)[0]; + out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]); + } + + bool should_store = config.should_store(output_idx); + if (should_store) { + uint32_t offset = config.staging_memory_offset(blockIdx.y); + reduce_buffer[offset] = value; + } + + __threadfence(); // make sure writes are globally visible + __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done + bool is_last_block_done = mark_block_finished(); + + if (is_last_block_done) { + value = ident; + if (config.should_block_x_reduce()) { + uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x; + uint32_t step = blockDim.x * blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } else { + uint32_t input_offset = threadIdx.y; + uint32_t step = blockDim.y; + for (; input_offset < config.ctas_per_output; input_offset += step) { + uint32_t idx = config.staging_memory_offset(input_offset); + arg_vec_t next = reduce_buffer[idx]; + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine(value[i], next[i]); + } + } + } + value = block_y_reduce(value, shared_memory); + if (config.should_block_x_reduce()) { + value = block_x_reduce(value, shared_memory); + } + if (should_store) { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::translate_idx(value[i], base_idx); + } + } + + if (acc == nullptr) { + if (accumulate) { + value = accumulate_in_output(out, value); + } + if (final_output) { + 
set_results_to_output(value, base_offsets); + } else { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + *(out[i]) = get_accumulated_output(out[i], value[i]); + } + } + } else { + if (accumulate) { + #pragma unroll + for (int i = 0; i < output_vec_size; i++) { + value[i] = reducer::combine((*acc)[i], value[i]); + } + } + if (final_output) { + set_results_to_output(value, base_offsets); + } else { + *acc = value; + } + } + } + } + + return value; + } +}; + +extern "C" +__launch_bounds__(${max_threads_lb}, 4) +__global__ void reduction_${name}_kernel(ReduceJitOp r){ + r.run(); +} +)ESCAPE"; + +const std::string reduction_template = reduction_template_0 + reduction_template_1; + + +const std::string &get_reduction_template() { + return reduction_template; +} + +}} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..6c8e524a0467ad1034c069d2b69a58dab92d7d68 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/thread_constants.h @@ -0,0 +1,22 @@ +#pragma once +#include + +// Marks a lambda as executable on both the host and device. The __host__ +// attribute is important so that we can access static type information from +// the host, even if the function is typically only executed on the device. +#ifndef GPU_LAMBDA +#define GPU_LAMBDA __host__ __device__ +#endif + +#if defined(USE_ROCM) +constexpr int num_threads() { + return 256; +} +#else +constexpr uint32_t num_threads() { + return C10_WARP_SIZE * 4; +} +#endif + +constexpr int thread_work_size() { return 4; } +constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..285fd470563d81562743808efdd3a39300e4264c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/cuda/vol2col.cuh @@ -0,0 +1,263 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy on volumes +template +__global__ void vol2col_kernel( + const int64_t n, + const T* data_vol, + const int depth, + const int height, + const int width, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_col) { + CUDA_KERNEL_LOOP(index, n) { + auto w_out = index % width_col; + index /= width_col; + auto h_out = index % height_col; + index /= height_col; + auto t_out = index % depth_col; + auto channel_in = index / depth_col; + auto channel_out = channel_in * ksize_t * ksize_h * ksize_w; + auto t_in = t_out * stride_t - pad_t; + auto h_in = h_out * stride_h - pad_h; + auto w_in = w_out * stride_w - pad_w; + data_col += + ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + + w_out; + data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; + for (int i = 0; i < ksize_t; ++i) { + for (int j = 0; j < ksize_h; ++j) { + for (int k = 0; k < ksize_w; ++k) { + auto t = t_in + i * dilation_t; + auto h = h_in + j * 
dilation_h; + auto w = w_in + k * dilation_w; + *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && + w < width) + ? data_vol + [i * dilation_t * height * width + j * dilation_h * width + + k * dilation_w] + : static_cast(0); + data_col += depth_col * height_col * width_col; + } + } + } + } +} + +template +void vol2col( + cudaStream_t stream, + const T* data_vol, + const int channels, + const int depth, + const int height, + const int width, + const int depth_col, + const int height_col, + const int width_col, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_col) { + // We are going to launch channels * depth_col * height_col * width_col + // kernels, each kernel responsible for copying a single-channel grid. + // We cast an operand to int64 so that the product will not overflow + const auto num_kernels = static_cast(channels) * depth_col * height_col * width_col; + // Launch + vol2col_kernel<<>>( + num_kernels, + data_vol, + depth, + height, + width, + ksize_t, + ksize_h, + ksize_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + depth_col, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__global__ void vol2im_kernel( + const int64_t n, + const T* data_col, + const unsigned depth, + const unsigned height, + const unsigned width, + const unsigned channels, + const unsigned kernel_t, + const unsigned kernel_h, + const unsigned kernel_w, + const unsigned pad_t, + const unsigned pad_h, + const unsigned pad_w, + const unsigned stride_t, + const unsigned stride_h, + const unsigned stride_w, + const unsigned dilation_t, + const unsigned dilation_h, + const unsigned dilation_w, + const unsigned depth_col, + const unsigned height_col, + const unsigned width_col, + T* data_vol) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast(0); + const auto w_im = index % width + pad_w; + const auto h_im = (index / width) % height + pad_h; + const auto t_im = (index / width / height) % depth + pad_t; + const auto c_im = index / (width * height * depth); + auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + // compute the start and end of the output + const auto w_col_start = + (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; + const auto w_col_end = std::min(w_im / stride_w + 1, width_col); + const auto h_col_start = + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const auto h_col_end = std::min(h_im / stride_h + 1, height_col); + const auto t_col_start = + (t_im < kernel_extent_t) ? 
0 : (t_im - kernel_extent_t) / stride_t + 1; + const auto t_col_end = std::min(t_im / stride_t + 1, depth_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) { + for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) { + uint64_t t_k = (t_im - t_col * stride_t); + uint64_t h_k = (h_im - h_col * stride_h); + uint64_t w_k = (w_im - w_col * stride_w); + if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && + w_k % dilation_w == 0) { + t_k /= dilation_t; + h_k /= dilation_h; + w_k /= dilation_w; + const int64_t idx_k = + ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k; + const int64_t data_col_index = + ((idx_k * depth_col + t_col) * + height_col + h_col) * + width_col + w_col; + val += data_col[data_col_index]; + } + } + } + } + data_vol[index] = static_cast(val); + } +} + +template +void col2vol( + cudaStream_t stream, + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t output_depth, + const int64_t output_height, + const int64_t output_width, + const int64_t patch_t, + const int64_t patch_h, + const int64_t patch_w, + const int64_t pad_t, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_t, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_t, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_vol) { + const auto num_kernels = channels * depth * height * width; + + auto check_fits_in_unsigned = + [](int64_t val, const char * name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK(val >= 0 && val <= umax, + name, " must fit in a 32-bit unsigned value"); + }; + check_fits_in_unsigned(num_kernels, "input size"); + check_fits_in_unsigned( + channels * patch_t * patch_h * patch_w, "channels x kernel size"); + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
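+  // Note on the gather formulation used here: vol2im_kernel parallelizes over
+  // output voxels rather than over column entries. Each thread accumulates, in
+  // a local register, every data_col entry whose sliding block covers its
+  // voxel and writes the sum back exactly once, so no atomicAdd is required.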
+ vol2im_kernel + <<>>( + num_kernels, + data_col, + depth, + height, + width, + channels, + patch_t, + patch_h, + patch_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + output_depth, + output_height, + output_width, + data_vol); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/group_norm.h b/MLPY/Lib/site-packages/torch/include/ATen/native/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..2747015d02fb20e2407719885867250cb9b4cfb1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/group_norm.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + double /* eps */, + Tensor& /* Y */, + Tensor& /* mean */, + Tensor& /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + Tensor& /* dX */, + Tensor& /* dgamma */, + Tensor& /* dbeta */); + +DECLARE_DISPATCH(forward_fn, GroupNormKernel); +DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/im2col.h b/MLPY/Lib/site-packages/torch/include/ATen/native/im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..c9093c4ae116d1135af196d943a065942527e684 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/im2col.h @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +template +static void im2col( + const T* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_col, + bool is_channels_last = false) { + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) { + int64_t h_col{0}, w_col{0}; + data_index_init(begin, h_col, height_col, w_col, width_col); + + for (const auto i_col : c10::irange(begin, end)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + const T* slice_im = data_im + (h_im * width + w_im) * channels; + T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::copy_n(slice_im, channels, slice_col); + } else { + std::fill_n(slice_col, channels, T(0)); + } + } + } + + // move the next index + data_index_step(h_col, height_col, w_col, width_col); + } + }); + } 
else { + at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) { + int64_t c_im{0}, h_offset{0}, w_offset{0}; + data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + + for (const auto c_col : c10::irange(begin, end)) { + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + data_col[(c_col * height_col + h_col) * width_col + w_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(c_im * height + h_im) * width + w_im] + : static_cast(0); + } + } + + // move to the next index + data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + } + }); + } +} + +template +static void col2im( + const T* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_im, + bool is_channels_last = false) { + std::fill_n(data_im, height * width * channels, T(0)); + + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + for (const auto h_col : c10::irange(height_col)) { + for (const auto w_col : c10::irange(width_col)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + T* slice_im = data_im + (h_im * width + w_im) * channels; + const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w + + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) { + std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus()); + } + } + } + } + } + } else { + for (const auto c_col : c10::irange(channels_col)) { + int64_t w_offset = c_col % kernel_w; + int64_t h_offset = (c_col / kernel_w) % kernel_h; + int64_t c_im = c_col / kernel_h / kernel_w; + + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) + data_im[(c_im * height + h_im) * width + w_im] += + data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + } + } +} + +} // namespace at::native diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h b/MLPY/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h new file mode 100644 index 0000000000000000000000000000000000000000..9fa2afa27cfd00c809de4852811458fc175d3e33 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/im2col_shape_check.h @@ -0,0 +1,232 @@ +#pragma once +#include +#include +#include + +namespace at::native { + +static inline void col2im_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t output_height, + int64_t output_width, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + 
int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + + int64_t ndim = input.ndimension(); + // allow dim=0 only the batch dimension. + TORCH_CHECK( + (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) || + (ndim == 3 && input.size(1) != 0 && input.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t batch_dim = (ndim == 3) ? 0 : -1; + int64_t n_input_plane = input.size(batch_dim + 1); + + if (n_input_plane % (kernel_width * kernel_height) != 0) { + AT_ERROR( + "Expected size of input's dimension 1 to be divisible by the " + "product of kernel_size, but got input.size(1)=", + n_input_plane, + " and kernel_size=(", + kernel_height, + ", ", + kernel_width, + ")."); + } + + int64_t input_length = input.size(batch_dim + 2); + int64_t n_blocks_height = + div_rtn( + output_height + 2 * pad_height - + dilation_height * (kernel_height - 1) - 1, + stride_height) + + 1; + int64_t n_blocks_width = div_rtn( + output_width + 2 * pad_width - + dilation_width * (kernel_width - 1) - 1, + stride_width) + + 1; + + if (input_length != (n_blocks_height * n_blocks_width)) { + AT_ERROR( + "Given output_size=(", + output_height, + ", ", + output_width, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), stride=(", + stride_height, + ", ", + stride_width, + "), expected size of input's dimension 2 to match the calculated number of ", + "sliding blocks ", + n_blocks_height, + " * ", + n_blocks_width, + " = ", + (n_blocks_height * n_blocks_width), + ", but got input.size(2)=", + input_length, + "."); + } + + TORCH_CHECK( + n_blocks_height >= 1 && n_blocks_width >= 1, + "Given output_size=(", output_height, ", ", output_width, "), ", + "kernel_size=(", kernel_height, ", ", kernel_width, "), ", + "dilation=(", dilation_height, ", ", dilation_width, "), ", + "padding=(", pad_height, ", ", pad_width, "), ", + "stride=(", stride_height, ", ", stride_width, "), ", + "calculated shape of the array of sliding blocks as ", + "(", n_blocks_height, ", ", n_blocks_width, "), ", + "which is too small (non-positive)"); + + if (output_width < 1 || output_height < 1) { + AT_ERROR( + "Expected output spatial size to be positive, but got: output_size=(", + output_height, + ", ", + output_width, + ")."); + } +} + +static inline void im2col_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 
0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + + int64_t ndim = input.ndimension(); + + // allow dim=0 only the batch dimension. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + TORCH_CHECK( + (ndim == 3 && input.size(0) && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t dim_batch = 0; + + if (ndim == 3) { + dim_batch = -1; + } + + int64_t input_height = input.size(dim_batch + 2); + int64_t input_width = input.size(dim_batch + 3); + int64_t output_height = div_rtn( + input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1), + stride_height) + + 1; + int64_t output_width = div_rtn( + input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1), + stride_width) + + 1; + + if (output_height < 1 || output_width < 1) { + AT_ERROR( + "Given input with spatial size (", + input_height, + ", ", + input_height, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), calculated shape of the array of sliding blocks as (", + output_height, + ", ", + output_width, + "), but its components must be at least one."); + } +} + +} // namespace at::native diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/layer_norm.h b/MLPY/Lib/site-packages/torch/include/ATen/native/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..b452a0575397cb4e764abc7186c0a77717a9ac28 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/layer_norm.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +namespace { + +C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) 
+ .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } + + const int axis = input_ndim - normalized_ndim; + const int64_t M = + c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend()); + + return std::make_pair(M, N); +} + +} // namespace + +void layer_norm_cpu_out( + at::Tensor& out, + const at::Tensor& input, + const Tensor& gamma, + const Tensor& beta, + double eps, + int64_t M, + int64_t N); + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */, + Tensor* /* mean */, + Tensor* /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* M */, + int64_t /* N */, + Tensor* /* dX */, + Tensor* /* dgamma */, + Tensor* /* dbeta */); + +DECLARE_DISPATCH(forward_fn, LayerNormKernel); +DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel); + +} // namespace at::native diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/Copy.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..03366154489d12267705dc699bde5f9bdf4e3025 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/Copy.h @@ -0,0 +1,15 @@ +// Copyright © 2022 Apple Inc. + +#pragma once +#include + +namespace at { +namespace native { +namespace mps { + +at::Tensor& mps_copy_(at::Tensor& dst, const at::Tensor& src, bool non_blocking); +void copy_blit_mps(void* dst, const void* src, size_t size); + +} // namespace mps +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..91fcba78006345dd7ce1a88bb29b7c6edc9adcf8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#if !defined(__MAC_14_0) && \ + (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) +{ + MPSGraphFFTScalingModeNone = 0L, + MPSGraphFFTScalingModeSize = 1L, + MPSGraphFFTScalingModeUnitary = 2L, +}; + +@interface FakeMPSGraphFFTDescriptor : NSObject +@property (readwrite, nonatomic) BOOL inverse; +@property (readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode; +@property (readwrite, nonatomic) BOOL roundToOddHermitean; ++(nullable instancetype) descriptor; +@end + +@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor; + +@interface MPSGraph (SonomaOps) +-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + +-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + + +-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + 
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; + +-(MPSGraphTensor * _Nonnull) realToHermiteanFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; + +-(MPSGraphTensor * _Nonnull) HermiteanToRealFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; +@end + +// define BFloat16 enums for MacOS13 +#define MPSDataTypeBFloat16 ((MPSDataType) (MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16)) + +// define Metal version +#define MTLLanguageVersion3_1 ((MTLLanguageVersion) ((3 << 16) + 1)) +#endif diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..aa5637bf562bb92349f943b8a5f4272b6d84941a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h @@ -0,0 +1,197 @@ +#pragma once +#include + +// TODO: Remove me when moved to MacOS 13 +#if !defined(__MAC_13_2) && \ + (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2)) + +@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject + +@property (readwrite, nonatomic) NSUInteger strideInX; +@property (readwrite, nonatomic) NSUInteger strideInY; +@property (readwrite, nonatomic) NSUInteger strideInZ; +@property (readwrite, nonatomic) NSUInteger dilationRateInX; +@property (readwrite, nonatomic) NSUInteger dilationRateInY; +@property (readwrite, nonatomic) NSUInteger dilationRateInZ; + +@property (readwrite, nonatomic) NSUInteger paddingLeft; +@property (readwrite, nonatomic) NSUInteger paddingRight; +@property (readwrite, nonatomic) NSUInteger paddingTop; +@property (readwrite, nonatomic) NSUInteger paddingBottom; +@property (readwrite, nonatomic) NSUInteger paddingFront; +@property (readwrite, nonatomic) NSUInteger paddingBack; + +@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle; +@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout; +@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout; + +@property (readwrite, nonatomic) NSUInteger groups; + +@end + +@compatibility_alias MPSGraphConvolution3DOpDescriptor FakeMPSGraphConvolution3DOpDescriptor; + +#endif + +@interface MPSGraph (VenturaOps) + +#if !defined(__MAC_13_0) && \ + (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) +{ + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L, + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L, + MPSGraphResizeNearestRoundingModeCeil = 2L, + MPSGraphResizeNearestRoundingModeFloor = 3L, + MPSGraphResizeNearestRoundingModeRoundToEven = 4L, + MPSGraphResizeNearestRoundingModeRoundToOdd = 5L, +}; + +// Define complex enums for MacOS 12 +#define MPSDataTypeComplexBit 0x01000000 +#define MPSDataTypeComplexFloat32 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64)) +#define MPSDataTypeComplexFloat16 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32)) +#endif + +- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source + weightsTensor:(MPSGraphTensor * _Nonnull) weights + 
descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient + weightsTensor:(MPSGraphTensor * _Nonnull) weights + outputShape:(MPSShape * _Nonnull) outputShape + forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient + sourceTensor:(MPSGraphTensor * _Nonnull) source + outputShape:(MPSShape * _Nonnull) outputShape + forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axis:(NSInteger) axis + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axis:(NSInteger) axis + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- 
(MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + samplingMode:(MPSGraphResizeMode) samplingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; +- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + +@end diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..46078ea068d406f8a2ff5eca9609c1bc51dd197b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/OperationUtils.h @@ -0,0 +1,394 @@ +// Copyright © 2022 Apple Inc. 
+ +#pragma once + +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +// Fwd declarations +namespace at { + struct TensorIteratorBase; +} +using namespace at::mps; + +namespace at::native::mps { + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +struct MPSScalar { + id getMTLBuffer() const { return __builtin_bit_cast(id, buffer.get()); } + + size_t size = 0; + ScalarType type = ScalarType::Undefined; + c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope) + union { + float f; // MPS doesn't support 'double' + at::Half h; + int64_t i; + bool b; + c10::complex cf; + c10::complex ch; + at::BFloat16 bf16; + } value {}; +}; + +void runMPSGraph(MPSStream* mpsStream, + MPSGraph* mpsGraph, + NSDictionary* feeds, + NSDictionary* results); + +MPSDataType getMPSDataType(ScalarType scalar_type); +static inline MPSDataType getMPSDataType(const Tensor& t) { + return getMPSDataType(t.scalar_type()); +} +MPSDataType getMPSScalarType(ScalarType scalar_type); +static inline MPSDataType getMPSScalarType(const Tensor& t) { + return getMPSScalarType(t.scalar_type()); +} +MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); +std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); +static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) { + return getMPSTypeString(t.scalar_type(), short_name); +} +std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); +NSArray* getTensorAxes(const Tensor& t); +NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim); +std::string getMPSShapeString(MPSShape* shape); +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true); +std::string getArrayRefString(const IntArrayRef s); +// use has_storage() on the returned tensor to determine if src actually is a view +Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst); +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output); +bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape); +MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); + +// The MPSShape could vary based on memory format +MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); + +static inline id getMTLBufferStorage(const at::Tensor& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +class Placeholder { + public: + Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr, + bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid); + MPSGraphTensor* getMPSGraphTensor() { + return _placeholder; + } + 
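+  // Returns the MPSGraphTensorData bound to this placeholder; nullptr when the
+  // placeholder wraps an intermediate graph tensor (see isIntermediate()).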
MPSGraphTensorData* getMPSGraphTensorData() { + return _value; + } + bool isIntermediate() { + return _value == nullptr; + } + + private: + MPSGraphTensor* _placeholder; + MPSGraphTensorData* _value; + Tensor _tensor; +}; + +void resize_tensor(Tensor* output); +Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device); +MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor); +MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType); +MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); +MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor); +MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); + +MPSGraph* make_mps_graph(); +void printTensorNDArray(const Tensor& t); +MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType); + +MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); + +string get_mem_format_string(c10::MemoryFormat memory_format); + +using MPSCacheKey = uint64_t; + +// derive this class to cache a graph and its inputs/outputs +// can be used to store any NSObject +struct MPSCachedGraph +{ + MPSCachedGraph(NSObject *object) : _object([object retain]) {} + virtual ~MPSCachedGraph() { + [_object release]; + _object = nullptr; + } + + template + inline T* as() { + return static_cast(this); + } + + MPSGraph *graph() const { return (MPSGraph *)_object; } + NSObject *object() const { return _object; } +private: + NSObject *_object = nullptr; +}; + +struct MPSUnaryCachedGraph : public MPSCachedGraph +{ + MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; + +struct MPSUnaryGradCachedGraph : public MPSCachedGraph +{ + MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output + MPSGraphTensor *gradInputTensor_ = nil; +}; + +struct MPSBinaryCachedGraph : public MPSCachedGraph +{ + MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *otherTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; + +struct MPSBinaryGradCachedGraph : public MPSCachedGraph +{ + MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *otherTensor_ = nil; + MPSGraphTensor *gradInputTensor_ = nil; +}; + +// TODO: Improve the overall design of MPSGraphCache. 
+// https://github.com/pytorch/pytorch/issues/77176 +// Cache holding various keys mapped to graphs +struct MPSGraphCache +{ + typedef MPSCachedGraph * (^CreateCachedGraphBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {} + MPSCachedGraph* cachedGraph_ = nullptr; + std::string key_; + }; + + public: + + static MPSGraphCache* getInstance() { + if(_instance_cache == nullptr) { + _instance_cache = new MPSGraphCache(); + } + return _instance_cache; + } + + ~MPSGraphCache() { + dispatch_release(serialQueue_); + + for (const auto& i : cache_) { + delete i.second.cachedGraph_; + } + } + + // Disallow the copy constructor and operator= functions + MPSGraphCache(const MPSGraphCache&) = delete; + void operator=(const MPSGraphCache&) = delete; + + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + + __block MPSCachedGraph* cachedGraph = nil; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync_with_rethrow(serialQueue_, ^() { + // verify the cached entry doesn't already exist + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + } else { + cachedGraph = createCacheBlock(); + CacheEntry entry(key, cachedGraph); + cache_.emplace(hash, entry); + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + return static_cast(CreateCachedGraph(key, createCacheBlock)); + } + + MPSCachedGraph* LookUp(const std::string& key) const { + + __block MPSCachedGraph* cachedGraph = nullptr; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync(serialQueue_, ^() { + + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + + private: + MPSGraphCache() { + serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); + } + // this is defined in OperationUtils.mm to not include + // MPSProfiler.h in header OperationUtils.h + void profileCachedGraph(const CacheEntry& cacheEntry) const; + + static MPSGraphCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; + +}; + +// Common template for creating graph with a specified cache if missing +template +inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function instantiate) { + auto cache_ = MPSGraphCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedGraphAs(key, ^mps::MPSCachedGraph*() { + T* newCachedGraph = nil; + @autoreleasepool { + // Initialize graph + auto mpsGraph = mps::make_mps_graph(); + newCachedGraph = new T(mpsGraph); + instantiate(mpsGraph, newCachedGraph); + } + return newCachedGraph; + }); +} + +// Common math operations +MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); + +#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \ + if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \ + TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \ + ", 
downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \ + } + +/** + * Returns distance from lowest to highest element offset in given tensor. + */ +size_t compute_storage_numel_distance(const at::Tensor& t); + +/** + * Checks whether tensor is mapped to a contiguous area in the storage. + */ +inline bool is_dense_in_storage(const at::Tensor& t) { + return compute_storage_numel_distance(t) == static_cast(t.numel()); +} + +static inline void mtl_setBuffer(id encoder, const Tensor& t, unsigned idx) { + [encoder setBuffer:getMTLBufferStorage(t) + offset:t.storage_offset() * t.element_size() + atIndex:idx]; +} + +static inline void mtl_dispatch1DJob(id encoder, + id cplState, + uint32_t length) { + const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup]; + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1); + [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; +} + +id generateKernelDataOffsets(id commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false); + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) { + return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) { + return @{ + p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) { + return @{ + p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) { + return @{ + p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(), + p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(), + }; +} + +inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) { + runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result)); +} + +inline bool supportsComplex() { + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +} + +// MPS yet to support double types, but starting from MacOS 14, supports bfloat16 +inline bool supportedFloatingType(ScalarType dtype) { + return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; +} + +inline bool supportedFloatingType(const Tensor& t) { + return supportedFloatingType(t.scalar_type()); +} + +} // namespace at::native::mps diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..5512899934355d19d78f3fc1f65700aa043fc413 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/TensorFactory.h @@ -0,0 +1,12 @@ +// Copyright © 2022 Apple Inc. + +#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)) diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/mps/UnaryConstants.h b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/UnaryConstants.h new file mode 100644 index 0000000000000000000000000000000000000000..90ac12c0b8f845940108c6585426fcec0632ef0d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/mps/UnaryConstants.h @@ -0,0 +1,43 @@ +#pragma once + +const char* UNARY_KERNEL_TEMPLATE = R"METAL( +#include +using namespace metal; + +constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}}; +constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}}; +constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}}; +constant float d[2] = {{3.543889200, 1.637067800}}; + +kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]], + device {1} *input [[buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + + float y = input[index]; + float x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + + float y_abs = abs(y); + if(y_abs > 1.0f){{ + output[index] = NAN; + return; + }} + if(y_abs == 1.0f){{ + output[index] = copysign(INFINITY, y); + return; + }} + if(y_abs <= 0.7f) {{ + z = y * y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f); + x = y * num / dem; + }} + else{{ + z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + 1.0f; + x = copysign(num, y) / dem; + }} + + output[index] = x; +}})METAL"; \ No newline at end of file diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..c0155a074db1686f62fe4d0a12b04642f2b1a9f5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +enum class NESTED_DENSE_OP: uint8_t {ADD, MUL}; + +using nested_dense_elementwise_fn = void (*)(Tensor& result, const Tensor & self, const Tensor & other, const NESTED_DENSE_OP& op); + +DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorFactories.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorFactories.h new file mode 100644 index 0000000000000000000000000000000000000000..2efb0cbfc4fd856ea38472d375c258dd01350f0f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorFactories.h @@ -0,0 +1,7 @@ +#pragma once + +namespace at { +namespace native { + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h new file mode 100644 index 
0000000000000000000000000000000000000000..9520b517c81d62dae0d24a1ec51a8089fa563eec --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace native { + +TORCH_API Tensor NestedTensor_to_padded_tensor_generic( + const Tensor& t, + double padding, + OptionalIntArrayRef output_size); + +template +Tensor map_nt(const Tensor& nt, Func f) { + auto* nt_impl = get_nested_tensor_impl(nt); + const auto& sizes = nt_impl->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); +} +template +Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){ + auto* nt_impl_1 = get_nested_tensor_impl(nt_1); + auto* nt_impl_2 = get_nested_tensor_impl(nt_2); + const auto& sizes = nt_impl_1->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes); +} + +C10_ALWAYS_INLINE std::pair _check_nested_layer_norm_inputs( + const NestedTensorImpl& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const size_t normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input + // Also, compute M and N considering the idiosyncracies of NestedTensors + int64_t N = 1; + for (const auto i: c10::irange(normalized_ndim)) { + TORCH_CHECK( + input.opt_size(-normalized_ndim + i) != c10::nullopt, + "normalized_shape extends into irregular dimensions for the nested tensor" + ); + TORCH_CHECK( + normalized_shape[i] == *input.opt_size(-normalized_ndim + i), + "The shape at dimension ", + i, + "of normalized_shape doesn't match the input" + ); + N *= normalized_shape[i]; + } + + const int64_t M = input.numel() / N; + + return std::make_pair(M, N); +} + +Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..0aa6fe52ab95d10afde810adbfd9b191416057ce --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -0,0 +1,103 @@ +/** + * Transformer-specific NestedTensor utility functions. + * + * Not co-located with NestedTensor core code yet because they only + * support specific cases needed in transformers. 
+ */ +#pragma once + +#include + +#include +#include + +namespace c10 { +class Scalar; +} // namespace c10 + +namespace at { +class Tensor; +namespace native { +struct NestedTensorImpl; + +// Requires that self is a contiguous NestedTensor, other is not a +// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self +// must have a consistent last dimension across its included Tensors +// and that dimension must match other.size(0). +Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other); + +// Requires that mat1 is a contiguous NestedTensor, self & mat2 are +// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1 +// has a consistent last dimension across its included Tensors that +// matches mat2.size(0). +Tensor NestedTensor_times_Tensor_plus_Tensor_addmm( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const c10::Scalar& beta, + const c10::Scalar& alpha, + c10::optional use_gelu = c10::nullopt); + +Tensor NestedTensor_add_NestedTensor_in_place( + const Tensor& self, + const Tensor& other); + +TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor( + const Tensor& sizes, + int64_t extra_elements); + +Tensor NestedTensor_from_padded_tensor_cpu( + const Tensor& padded, + const NestedTensorImpl& nt); + +Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional mask_dim, c10::optional mask_dim_length); + +template +void remove_padding_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size); + +template +void remove_padding_transform0213_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size); + +template +void add_padding_kernelLauncher( + T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + const std::vector& output_sizes, + const int batch_size, + const int output_batch_size); + +TORCH_API Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); + +TORCH_API std::tuple mem_efficient_helper_nested_unpacked( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ca2a4ea6c2c9b89a61eb956c586e2e1d5b365de5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h @@ -0,0 +1,44 @@ +#include + + +namespace at { +namespace native { +namespace preprocessing { + +/** + * This function will take nested query, key, and value + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels. + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing( + const Tensor& query, + const Tensor& key, + const Tensor& value); + +/** + * This function will take nested query, key, and value, grad_out, and out + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels backwards. 
+ * We use both functions to avoid having to do the same preprocessing + * for cumulative_sequence_length_q and cumulative_sequence_length_kv + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_kv, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_kv); + +} // namespace preprocessing +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..497c7ecc8cb6e88849a4ac1348ca8f4e682b4c03 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h @@ -0,0 +1,415 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS + +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +#include +#include + +namespace at { +namespace native { +struct NestedTensorImpl; + +// The following functions are used to construct nested tensors from buffers and +// metadata. + +inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) { + TORCH_CHECK( + buffer.dim() == 1, + "Expected given buffer to be 1dim, but got ", + buffer.dim(), + " instead."); + TORCH_CHECK( + buffer.is_contiguous(), "Expected given buffer to be contiguous."); + return at::detail::make_tensor( + std::move(buffer), std::move(nested_sizes)); +} + +// TODO: Figure out if we need a non-moving wrap_buffer() +inline at::Tensor wrap_buffer( + at::Tensor buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + std::move(buffer), + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} + +inline at::Tensor get_buffer(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_buffer(); +} + +/** + * Create a new nested tensor that is a view of a base nested tensor + * + * create_view_tensor calls a specialized constructor that copys the + * the keys from base onto the new view tensor being created. + * The storage is shared between the base and the returned view tensor + * + * All callers of this helper must: + * - Only return a view of the input + * - Must be explicit and define a derivative + * + * @param base Base tensor to construct view from. + * @param nested_sizes View tensors' sizes. + * @param nested_strides View tensors' strides. + * @param storage_offsets View tensors' offsets. 
+ * @return A newly constructed view tensor + */ +inline at::Tensor create_nested_view_tensor( + const at::Tensor& base, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT( + base.is_nested(), + "This function can only be used to create nested tensor views"); + TORCH_INTERNAL_ASSERT( + c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::AutogradFunctionality), + "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed."); + return at::detail::make_tensor( + c10::TensorImpl::VIEW, + base, + nested_sizes, + nested_strides, + storage_offsets); +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Helper functions for getting information about a nested tensor's shape. + +int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt); + +// The sizes of the underlying tensors +inline std::vector NestedTensor_get_sizes( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector sizes(ntensors); + if (ntensors == 0) { + return sizes; + } + const Tensor& sizemat = self_ptr->get_nested_sizes(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars has empty sizes + if (orig_dim == 0) { + return sizes; + } + const int64_t* sizemat_ptr = sizemat.data_ptr(); + + for (const auto i : c10::irange(ntensors)) { + sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim); + sizemat_ptr += orig_dim; + } + return sizes; +} + +TORCH_API std::vector NestedTensor_get_max_size( + const NestedTensorImpl& nt); + +std::vector NestedTensor_get_max_size_from_size_tensor( + const Tensor& sizes); + +inline std::vector NestedTensor_get_sizes(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_sizes(self_ptr); +} +// The strides of the underlying tensors +inline std::vector NestedTensor_get_strides( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector strides(ntensors); + if (ntensors == 0) { + return strides; + } + const Tensor& stridemat = self_ptr->get_nested_strides(); + int64_t orig_dim = stridemat.size(1); + // nesting scalars has empty strides + if (orig_dim == 0) { + return strides; + } + const int64_t* stridemat_ptr = stridemat.data_ptr(); + for (const auto i : c10::irange(ntensors)) { + strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim); + stridemat_ptr += orig_dim; + } + return strides; +} + +inline std::vector NestedTensor_get_strides( + const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_strides(self_ptr); +} + +inline void check_numel_equals_buffer_size(const at::Tensor& self) { + auto self_impl = get_nested_tensor_impl(self); + TORCH_CHECK( + self.numel() == static_cast(self_impl->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) { + TORCH_CHECK( + self_ptr->numel() == static_cast(self_ptr->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Data structures and functions for generically applying a function on a nested +// tensor. 
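A minimal standalone sketch (not from the vendored NestedTensorUtils.h itself) of the metadata layout the helpers above rely on: per-constituent shapes live as rows of a flat 2D size matrix, and `NestedTensor_get_sizes` simply walks that matrix row by row. Plain `std::vector` stands in for the ATen tensor types here; the `sizemat`/`orig_dim` names mirror the header but the values are made up for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Flattened "size matrix": one row per constituent tensor, one column per dim.
  // Here: three constituents with shapes (2, 3), (4, 3) and (1, 3).
  const int64_t ntensors = 3, orig_dim = 2;
  std::vector<int64_t> sizemat = {2, 3, 4, 3, 1, 3};

  const int64_t* ptr = sizemat.data();
  for (int64_t i = 0; i < ntensors; ++i) {
    // Each row holds one constituent's sizes, analogous to the per-row views
    // built in NestedTensor_get_sizes (and NestedTensor_get_strides).
    std::cout << "tensor " << i << ": (" << ptr[0] << ", " << ptr[1] << ")\n";
    ptr += orig_dim;
  }
}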
+namespace impl { + +template +struct NestedNode { + NestedNode() = delete; + explicit NestedNode(std::vector&& children) + : _is_leaf(false), _children(children) {} + explicit NestedNode(TensorList children) + : _is_leaf(false), _children(children.vec()) {} + // NestedNode(NestedNode&) = delete; + // NestedNode(const NestedNode&) = delete; + // NestedNode& operator=(NestedNode) = delete; + explicit NestedNode(T payload) : _is_leaf(true), _payload(std::move(payload)) {} + inline bool is_leaf() const { + return _is_leaf; + } + inline size_t degree() const { + return _children.size(); + } + inline const std::vector unbind() const { + return _children; + } + inline T children(size_t i) const { + return _children[i]; + } + inline const T& payload() const { + return _payload; + } + inline T& payload() { + return _payload; + } + + private: + bool _is_leaf; + std::vector _children; + T _payload; +}; + +using TensorNode = NestedNode; + +template +class _map; + +template +class _map> { + public: + static A function_one(F&& fn, const Args&... nested_node) { + return std::forward(fn)(nested_node...); + } + // NOTE: We must move F to avoid copying objects if it is a lambda with + // captures. + static NestedNode function( + F&& fn, + const NestedNode&... nested_node) { + size_t degree = 0; + bool all_leaf = true; + c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&all_leaf, °ree](auto n) { + all_leaf = all_leaf && (n.is_leaf()); + if (degree > 1 && n.degree() > 1) { + TORCH_CHECK( + degree == n.degree(), "NestedNodes must match in degree."); + } + if (n.degree() > degree) { + degree = n.degree(); + } + return nullptr; + }); + // All NestedNodes just wrap regular objects. + if (all_leaf) { + return NestedNode(std::forward(fn)(nested_node.payload()...)); + } + // Some NestedNodes wrap regular Tensors, some NestedTensors and some other + // types. + std::vector result; + for (size_t i = 0; i < degree; i++) { + std::tuple children = c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&i](auto a) { + static_assert( + c10::guts::is_instantiation_of::value, + "Internal error."); + // Broadcast regular arguments across NestedTensor constituents. + // This could be a Tensor, integer or anything else really. + if (a.is_leaf()) { + return a.payload(); + } + // Broadcast NestedTensors with one constituent. + if (a.degree() == 1 && !a.is_leaf()) { + return a.children(0); + } + TORCH_CHECK(a.degree() > 0, "Internal assert."); + return a.children(i); + }); + c10::guts::apply( + [&result, &fn](Args... filtered) { + result.emplace_back(function_one(std::forward(fn), filtered...)); + }, + std::move(children)); + } + return NestedNode(std::move(result)); + } +}; + +// TODO: Add static assert to verify lambda arguments match nested_node types +template +static inline NestedNode< + typename c10::guts::infer_function_traits::type::return_type> +map(F&& fn, const NestedNode&... 
nested_node) { + return _map< + F, + typename c10::guts::infer_function_traits::type::return_type, + typename c10::guts::infer_function_traits::type::parameter_types>:: + function(std::forward(fn), nested_node...); +} + +inline TensorNode get_nested_tensor_structure(at::Tensor tensor) { + if (get_nested_tensor_impl_or_null(tensor) == nullptr) { + return TensorNode(std::move(tensor)); + } + return TensorNode(tensor.unbind()); +} + +inline Tensor wrap_tensor_node( + TensorNode tensor_node, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + TORCH_CHECK( + !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors."); + TensorOptions options_ = + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( + pin_memory); + if (tensor_node.degree() == 0) { + return wrap_buffer(ones({0}, dtype, layout, device), ones({})); + } + + // Fast path: if all tensors are on CPU, have contiguous memory, and the same + // dtype, copying can be done much faster. + bool all_tensors_cpu = true; + bool all_tensors_contiguous = true; + bool all_tensors_same_dtype = true; + auto first_dtype = tensor_node.children(0).dtype(); + std::vector start_offsets(tensor_node.degree()); + start_offsets[0] = 0; + long total_size = 0; + for (const auto i : c10::irange(tensor_node.degree())) { + all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu(); + all_tensors_contiguous = + all_tensors_contiguous && tensor_node.children(i).is_contiguous(); + all_tensors_same_dtype = all_tensors_same_dtype && + (first_dtype == tensor_node.children(i).dtype()); + if (!(all_tensors_cpu && all_tensors_contiguous && + all_tensors_same_dtype)) { + break; + } + if (i > 0) { + start_offsets[i] = + start_offsets[i - 1] + tensor_node.children(i - 1).numel(); + } + total_size += tensor_node.children(i).numel(); + } + + TensorOptions options; + Tensor nt_buffer, nt_sizes; + if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) { + nt_buffer = at::empty({total_size}, tensor_node.children(0).options()); + nt_sizes = at::empty( + {static_cast(tensor_node.degree()), + static_cast(tensor_node.children(0).sizes().size())}, + TensorOptions().dtype(kLong)); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + c10::typeMetaToScalarType(first_dtype), + "create_nt_buffer", + [&]() { + at::parallel_for( + 0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // Only try copying memory if there is more than 0 elements + // for a certain tensor + if (tensor_node.children(i).numel() > 0) { + memcpy( + nt_buffer.mutable_data_ptr() + start_offsets[i], + tensor_node.children(i).const_data_ptr(), + tensor_node.children(i).numel() * sizeof(scalar_t)); + } + } + }); + }); + long sizes_offset = 0; + for (size_t i = 0; i < tensor_node.degree(); ++i) { + auto tensor_sizes = tensor_node.children(i).sizes(); + for (int64_t tensor_size : tensor_sizes) { + nt_sizes.mutable_data_ptr()[sizes_offset++] = tensor_size; + } + } + options = nt_buffer.options().merge_in(options_); + } else { // Slow path + std::vector flat_tensors; + std::vector sizes; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + options = flat_tensors[0].options().merge_in(options_); + nt_buffer = 
at::cat(flat_tensors); + nt_sizes = at::native::stack(sizes); + } + + return wrap_buffer(nt_buffer.to(options), nt_sizes); +} + +} // namespace impl + +// This function is meant to ease rapid operator coverage for +// NestedTensor kernels. It is not meant to be efficient. Use it judiciously. +template +inline at::Tensor map_nested_tensor(F&& fn, A... a) { + return wrap_tensor_node( + impl::map(std::forward(fn), impl::get_nested_tensor_structure(a)...), + c10::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt); +} + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..141bf5c6e10eb992358d5cf474306eba1cb5327c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace native { + +Tensor& quantize_tensor_per_tensor_affine( + const Tensor& rtensor, + Tensor& qtensor, + double scale, + int64_t zero_point); +Tensor& quantize_tensor_per_channel_affine( + const Tensor& rtensor, + Tensor& qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis); + +Tensor& quantize_tensor_per_channel_float_qparams( + const Tensor& rtensor, + Tensor& qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis); + +Tensor& dequantize_tensor_per_tensor_affine( + const Tensor& qtensor, + Tensor& rtensor, + double scale, + int64_t zero_point); +Tensor& dequantize_tensor_per_channel_affine( + const Tensor& qtensor, + Tensor& rtensor, + Tensor scales, + Tensor zero_points, + int64_t axis); +Tensor& dequantize_tensor_per_channel_float_qparams( + const Tensor& qtensor, + Tensor& rtensor, + Tensor scales, + Tensor zero_points, + int64_t axis); + +using quantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point); + +using quantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& rtensor, + Tensor& qtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_tensor_affine_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point); + +using dequantize_tensor_per_channel_affine_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using dequantize_tensor_per_channel_float_qparams_fn = void (*)( + const Tensor& qtensor, + Tensor& rtensor, + const Tensor& scales, + const Tensor& zero_points, + int64_t axis); + +using quantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point); + +using dequantize_tensor_per_tensor_affine_sub_byte_fn = + void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point); + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_fn, + quantize_tensor_per_tensor_affine_stub); +DECLARE_DISPATCH( + quantize_tensor_per_channel_affine_fn, + quantize_tensor_per_channel_affine_stub); +DECLARE_DISPATCH( + quantize_tensor_per_channel_float_qparams_fn, + quantize_tensor_per_channel_float_qparams_stub); + +DECLARE_DISPATCH( + 
dequantize_tensor_per_tensor_affine_fn, + dequantize_tensor_per_tensor_affine_stub); +DECLARE_DISPATCH( + dequantize_tensor_per_channel_affine_fn, + dequantize_tensor_per_channel_affine_stub); +DECLARE_DISPATCH( + dequantize_tensor_per_channel_float_qparams_fn, + dequantize_tensor_per_channel_float_qparams_stub); + +DECLARE_DISPATCH( + quantize_tensor_per_tensor_affine_sub_byte_fn, + quantize_tensor_per_tensor_affine_sub_byte_stub); + +DECLARE_DISPATCH( + dequantize_tensor_per_tensor_affine_sub_byte_fn, + dequantize_tensor_per_tensor_affine_sub_byte_stub); + +template +TORCH_API Tensor quantize_tensor( + Tensor rtensor, + Tensor qtensor, + double scale, + int64_t zero_point); +template +TORCH_API Tensor dequantize_tensor( + Tensor qtensor, + Tensor rtensor, + double scale, + int64_t zero_point); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..67cb7a7c451a587c566f0e5be6949c745e5d8b68 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h @@ -0,0 +1,47 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Quantize a float value into a uint value given scale and zero_point +template +TORCH_API T quantize_val(double scale, int64_t zero_point, float value); +// TODO combine this with quantize_val once the numerics for ARM are aligned +// with it +template +T quantize_val_arm( + const float scale, + const int32_t zero_point, + const float value); +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count = 8); +template +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value); +template +TORCH_API float dequantize_vec( + double scale, + int64_t zero_point, + const T* src, + float* dst, + size_t count = 8); +template +TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); + +// Given a multiplier and a zero_point, requantize int32_t computed values back +// to quantized values. 
See comment above +// make_per_tensor_affine_quantizer function for the usage of int64_t +template +TORCH_API DST_T +requantize_from_int(double multiplier, int64_t zero_point, int64_t src); + +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..679777f8d65862e79fb8f3deb6aaa3b2d5c9428e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/ConvUtils.h @@ -0,0 +1,62 @@ +#pragma once +#include +#include + +namespace at::native::quantized { +namespace { +// MakeConvOutputShape used from both CPU and CUDA libraries +// and exporting symbol from torch_cpu would probably take more storage +// than duplicating implementation which likely be inlined away +template +at::SmallVector MakeConvOutputShape( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const torch::List& stride, + const torch::List& padding, + const torch::List& dilation); + +#if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK) +template <> +at::SmallVector MakeConvOutputShape<2>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const at::List& dilation) { + const int H = input_image_shape[0]; + const int W = input_image_shape[1]; + const int64_t Y_H = + (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_W = + (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + return {N, M, Y_H, Y_W}; +} + +template <> +at::SmallVector MakeConvOutputShape<3>( + int N, // mini-batch + int M, // output channels + const std::array& input_image_shape, + const std::vector& kernel, + const at::List& stride, + const at::List& padding, + const torch::List& dilation) { + const int D = input_image_shape[0]; + const int H = input_image_shape[1]; + const int W = input_image_shape[2]; + const int64_t Y_D = + (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1; + const int64_t Y_H = + (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1; + const int64_t Y_W = + (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1; + return {N, M, Y_D, Y_H, Y_W}; +} + +#endif +} // anonymous namespace +} // namespace at::native::quantized diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..9c611251db31ec8ad3c5f9cb6d843a81ba886cbf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/Copy.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace at { +namespace native { + +Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src); +} +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h new file mode 100644 index 0000000000000000000000000000000000000000..2b5c36415346489ff17d9731b4b3eca09212a50e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h @@ -0,0 +1,67 @@ 
+#pragma once + +#include +#include +#include + +namespace at { + +struct TensorIterator; + +namespace native { + +using fake_quant_tensor_cachemask_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + float sc, + int64_t z_point, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)( + Tensor& output, + Tensor& mask, + const Tensor& input, + const Tensor& sc, + const Tensor& z_point, + const Tensor& fake_quant_enabled, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_learnable_grad_tensor_fn = void (*)( + TensorIterator& iter, + float scale, + float inv_scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub); +DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub); +DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub); + +using fake_quant_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max); + +using fake_quant_per_channel_cachemask_fn = void (*)( + TensorIterator &iter, + TensorIterator &iter_mask, + int64_t quant_min, + int64_t quant_max); + +DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub); + +using fake_quant_learnable_per_channel_fn = void (*)( + TensorIterator &iter, + int64_t quant_min, + int64_t quant_max, + float grad_factor); + +DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8338878cc208ddb61fd9b66ef39ec1ba3236762c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/IndexKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +namespace native { +using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point); +using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point); + +DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub); +DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub); + + +} // native +} // at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..5800a3d8fff18953192da734920cdd837c113b73 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/PackedParams.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) = 0; + + // out variant of LinearPackedParamsBase::apply + virtual at::Tensor& apply_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw 
std::runtime_error( + "apply_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual at::Tensor& apply_relu_out( + const at::Tensor& /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/, + at::Tensor& output) { + throw std::runtime_error( + "apply_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): + // input -> q* -> dq* -> linear* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): + // input -> q* -> dq* -> linear* -> relu* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // input: float32 Tensor, will be quantized to quint8 in the op + // Returns: + // float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + virtual at::Tensor apply_dynamic( + at::Tensor input, + bool reduce_range = false) = 0; + virtual at::Tensor apply_dynamic_relu( + at::Tensor input, + bool reduce_range = false) = 0; + + virtual at::Tensor& apply_dynamic_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_out is not implemented for this packed " + "parameter type"); + return output; + } + virtual at::Tensor& apply_dynamic_relu_out( + const at::Tensor& /* input */, + at::Tensor& output, + bool /* reduce_range */) { + throw std::runtime_error( + "apply_dynamic_relu_out is not implemented for this packed " + "parameter type"); + return output; + } + + virtual std::tuple> unpack() = 0; + + virtual c10::optional bias() = 0; + + virtual void set_bias(c10::optional /*bias*/) { + throw std::runtime_error( + "set_bias is not implemented for this packed " + "parameter type"); + } +}; + +template +struct ConvPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) = 0; + virtual at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) = 0; + + virtual std::tuple> unpack() = 0; + + virtual torch::List stride() const = 0; + virtual torch::List padding() const = 0; + virtual 
torch::List output_padding() const = 0; + virtual torch::List dilation() const = 0; + virtual int64_t groups() const = 0; + virtual bool transpose() const = 0; +}; diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..8e34516a9ab73887d606c7baf6886d17819fb9d6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h @@ -0,0 +1,8 @@ +#include + +namespace at { +namespace native { +TORCH_API Tensor +quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point); +} +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h new file mode 100644 index 0000000000000000000000000000000000000000..c24760a459ab7289f950758eea0588915a2707cb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { + virtual at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) = 0; + + virtual at::Tensor unpack() = 0; + + virtual int64_t bit_rate() const = 0; + virtual int64_t version() const = 0; +}; diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..f726926f7390bd37512d0bd03620235aa7cf883b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h @@ -0,0 +1,445 @@ +#pragma once + +#include +#if AT_MKLDNN_ENABLED() +#include +#include +#include +#include + +#include + +using PrimitiveCacheKey = std::tuple< + double, // input_scale + int64_t, // input_zero_point + std::vector, // input_shape + double, // output_scale + int64_t, // output_zero_point + int64_t, // OMP_number_of_threads + double, // accum_scale + int64_t>; // accum_zero_point + +enum CacheKeyIndex { + InputScale, + InputZeroPoint, + InputShape, + OutputScale, + OutputZeroPoint, + NumOfThreads, +}; + +// Base class of primitive cache +struct PrimitiveCache { + PrimitiveCacheKey key; + + bool hit(const PrimitiveCacheKey& key) { + return this->key == key; + } +}; + +using LinearParams = ideep::matmul_forward_params; +using Conv = dnnl::convolution_forward; +using ConvDesc = dnnl::convolution_forward::primitive_desc; +using ConvParams = ideep::convolution_forward_params; +using Deconv = dnnl::deconvolution_forward; +using DeconvDesc = dnnl::deconvolution_forward::primitive_desc; +using DeconvParams = ideep::deconv_forward_params; + +struct LinearPrimitiveCache : PrimitiveCache { + LinearPrimitiveCache() {} + + LinearPrimitiveCache( + const PrimitiveCacheKey& key, + const LinearParams& param) { + this->key = key; + 
this->param = param; + } + + LinearParams param; + + // For dynamic qlinear, scale and zero point + // are set at execution time. So we only need to compare + // the rest part of key. + bool hit_dynamic(const PrimitiveCacheKey& new_key) { + auto cached_input_shape = std::get(this->key); + auto new_input_shape = std::get(new_key); + return ( + cached_input_shape == new_input_shape && + std::get(this->key) == std::get(new_key)); + } + + LinearParams& get_param() { + return param; + } +}; + +struct ConvPrimitiveCache : PrimitiveCache { + ConvPrimitiveCache() {} + + ConvPrimitiveCache( + const PrimitiveCacheKey& key, + const ConvParams& params) { + this->key = key; + this->params = params; + } + + ConvParams params; + + ConvParams& get_params() { + return params; + } +}; + +struct DeconvPrimitiveCache : PrimitiveCache { + DeconvPrimitiveCache() {} + + DeconvPrimitiveCache( + const PrimitiveCacheKey& key, + const DeconvParams& params) { + this->key = key; + this->params = params; + } + + DeconvParams params; + + DeconvParams& get_params() { + return params; + } +}; + +enum PostOps { + NoPostOp, + Relu, + LeakyRelu, + Tanh, + Gelu +}; + +static std::unordered_map POST_OP_TABLE = { + {"none", NoPostOp}, + {"relu", Relu}, + {"leaky_relu", LeakyRelu}, + {"tanh", Tanh}, + {"gelu", Gelu} +}; + +struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { + PackedLinearWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)) { + cache_initialized_flag = std::make_unique(); + } + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + at::Tensor apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope); + + at::Tensor apply_tanh( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + c10::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + private: + LinearPrimitiveCache prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + torch::List post_op_args = torch::List()); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); + + LinearPrimitiveCache& get_cache() { + return prim_cache; + } +}; + +template +struct PackedConvWeightsOnednn : public ConvPackedParamsBase { + PackedConvWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + 
output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose) { + cache_initialized_flag = std::make_unique(); + } + + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + at::Tensor apply_add( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + at::Tensor apply_add_relu( + const at::Tensor& input, + const at::Tensor& accum, + double output_scale, + int64_t output_zero_point); + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + ConvPrimitiveCache conv_prim_cache; + DeconvPrimitiveCache deconv_prim_cache; + std::unique_ptr cache_initialized_flag; + + template + at::Tensor apply_impl( + const at::Tensor& input, + const c10::optional& accum, + double output_scale, + int64_t output_zero_point); + + ConvPrimitiveCache& get_conv_cache() { + assert(!transpose()); + return conv_prim_cache; + } + + DeconvPrimitiveCache& get_deconv_cache() { + assert(transpose()); + return deconv_prim_cache; + } +}; + +namespace onednn_utils { + +static ideep::attr_t create_attr_by_post_op( + const std::string& post_op_name, + const torch::List>& post_op_args, + const dnnl::algorithm post_algorithm) { + using ideep::tensor; + PostOps post_op = POST_OP_TABLE[post_op_name]; + if (post_op == Relu) { + return ideep::attr_t::fuse_relu(); + } else if (post_op == LeakyRelu) { + return ideep::attr_t::fuse_relu_v2(/*alpha=*/post_op_args[0].value().to()); + } else if (post_op == Tanh) { + return ideep::attr_t::fuse_tanh(); + } else if (post_op == Gelu) { + return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm); + } + return ideep::attr_t(); +} + +// Try to reorder tensor to expected desc at runtime +// Do it in a `try...catch...` manner to avoid oneDNN's errors +// TODO: Move it to third_party/ideep +static void try_reorder( + ideep::tensor& t, + const ideep::tensor::desc&& desc, + ideep::scale_t scales) { + if (t.get_desc() != desc) { + try { + t = t.reorder_if_differ_in(desc); + } catch (...) { + ideep::tensor&& plain = t.to_public(nullptr, t.get_data_type()); + t = plain.reorder_if_differ_in(desc); + } + t.set_scale(scales); + } +} + +// ONEDNN requires symmetric quantization of weight +// Use this util function to check. 
+static bool is_weight_symmetric_quant( + const at::Tensor& weight, + bool is_transposed_conv) { + bool is_symmetric = true; + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + is_symmetric &= (weight.q_zero_point() == 0); + } else if (qtype == c10::kPerChannelAffine) { + if (is_transposed_conv) { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } else { + auto output_channels = weight.size(0); + for (int i = 0; i < output_channels; ++i) { + auto zp = weight.q_per_channel_zero_points()[i].item(); + is_symmetric &= (zp == 0); + } + } + } else { + // This case is currently not supported in PyTorch + // but we do not want to raise an error in this util function. + is_symmetric = false; + } + return is_symmetric; +} + +// When qengine is x86, use this util func to check if onednn kernel +// is preferred than fbgemm's to get better performance. +static bool should_use_onednn_quant( + const at::Tensor& weight, + bool is_transposed_conv, + int groups, + torch::List output_padding) { + // Performance of onednn is only validated on Linux right now. + // Also, the heuristics for dispatching are based on perf data on Linux. + // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux. + // TODO Support more OSs. +#if !defined(__linux__) + return false; +#else + bool vnni_available = cpuinfo_has_x86_avx512vnni(); + bool w_sym_quant = + is_weight_symmetric_quant(weight, is_transposed_conv); + bool opad_all_zero = + std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }); + return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero; +#endif +} + +} // onednn_utils + +at::Tensor _qconv_prepack_onednn( + at::Tensor weight, // from CPU backend instead of QuantizedCPU + at::Tensor weight_scales, // Weight zero points must be 0 for onednn + double input_scale, + int64_t input_zero_point, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + c10::optional> input_shape=c10::nullopt); + +static at::Tensor _quantized_convolution_onednn( + at::Tensor act, // contains quantized values but not QTensor + double act_scale, + int64_t act_zero_point, + at::Tensor weight, // MKLDNN tensor with quantized values + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, // Bias is packed if not None + torch::List stride, + torch::List padding, + torch::List dilation, + bool transposed, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + c10::optional accum=c10::nullopt, // accum to fused with conv add + double accum_scale=1.0, + int64_t accum_zero_point=0, + bool fp32_output=false, + c10::optional binary_attr=c10::nullopt, + c10::optional binary_alpha=c10::nullopt, + c10::optional unary_attr=c10::nullopt, + torch::List> unary_scalars=torch::List>(), + c10::optional unary_algorithm=c10::nullopt); + +#endif // #if AT_MKLDNN_ENABLED() diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c1429338b6c07c1f4c3877bbc638398415fe0be2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h @@ -0,0 +1,527 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK +#include +#include +#include +#include +#include +#include +#include + +#ifndef 
AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +inline int kPaddingChannels = 8; +struct QnnpackOperatorDeleter { + void operator()(pytorch_qnnp_operator_t op) { + pytorch_qnnp_delete_operator(op); + } +}; + +// PackedWeight struct for QNNPACK stores the original Weight and Bias as +// QNNPACK currently does not support an unpack function. +// For PyTorch Mobile, once the model is scripted and serialized we don't need +// to call unpack, so we can save some memory by checking for this case and free +// the original weights after packing. +// Input scale is set to null in pre-pack step. QNNPACK needs bias quantized +// with input scale which is available at runtime in pytorch. During runtime if +// input scale value changes then we requantize bias with the updated scale. For +// inference we expect the graph to be static so the input scale should not +// change across consecutive inference calls. +struct PackedLinearWeightsQnnp : public LinearPackedParamsBase { + PackedLinearWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + c10::optional input_scale, + at::Tensor w_scales, + std::vector&& w_zps) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias_(at::native::mobile::allocate_padded_contiguous_if_needed( + bias, bias.suggest_memory_format())), + per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine), + input_scale(std::move(input_scale)), + w_scales(std::move(w_scales)), + w_zero_points(std::move(w_zps)), + q_scheme(this->orig_weight.qscheme()) { + weight_sizes = this->orig_weight.sizes().vec(); + } + + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias_; + bool per_channel_; + c10::optional input_scale; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + std::vector weight_sizes; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + std::tuple> unpack() override; + + c10::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + bool per_channel() const { + return per_channel_; + } + + private: + std::mutex qnnp_mutex_; + +#ifdef USE_XNNPACK + xnnpack_operator xnnp_linear_op; + + template + at::Tensor apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK + + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range); +}; + +template +struct PackedConvWeightsQnnp : public ConvPackedParamsBase { + PackedConvWeightsQnnp( + std::unique_ptr w, + at::Tensor orig_weight, + at::Tensor bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose, + c10::optional input_scale, + std::vector kernel, + at::Tensor w_scale, + std::vector&& w_zps, + bool is_per_channel) + : w(std::move(w)), + orig_weight(std::move(orig_weight)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + 
dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + is_per_channel_(is_per_channel), + input_scale(input_scale), + kernel_(std::move(kernel)), + w_scales(std::move(w_scale)), + w_zero_points(std::move(w_zps)) { + const bool any_padding = std::any_of( + padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; }); + const size_t kernel_size = + std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>()); + + const size_t group_input_channels = transpose + ? this->orig_weight.size(0) / groups + : this->orig_weight.size(1); + const size_t group_output_channels = transpose + ? this->orig_weight.size(1) + : this->orig_weight.size(0) / groups; + + const size_t kernel_depth = kSpatialDim == 3 ? kernel_[0] : 1; + const size_t kernel_height = kernel_[kSpatialDim - 2]; + const size_t kernel_width = kernel_[kSpatialDim - 1]; + + pytorch_qnnp_ukernel_type ukernel_type; + if (transpose_) { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_none; + + const bool has_depthwise_dimensions = + (kSpatialDim == 2 && + ((kernel_height == 3 && kernel_width == 3) || + (kernel_height == 5 && kernel_width == 5))) || + (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 && + kernel_depth == 3); + const bool has_depthwise_grouping = + group_input_channels == 1 && group_output_channels == 1 && groups > 1; + + if (has_depthwise_dimensions && has_depthwise_grouping) { + ukernel_type = pytorch_qnnp_ukernel_type_dwconv; + } else if ( + kernel_size == 1 && + std::all_of( + stride_.begin(), + stride_.end(), + [](const auto& e) { return e == 1; }) && + !any_padding) { + ukernel_type = group_input_channels >= SIZE_MAX + ? pytorch_qnnp_ukernel_type_xzp_gemm + : pytorch_qnnp_ukernel_type_gemm; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } + } + + if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) { + TORCH_INTERNAL_ASSERT( + false, "Per channel quantized weights are not supported for XZP kernels"); + } + + pytorch_qnnp_operator_t convolution{nullptr}; + // Initially all the params are set to zero. + convolution = static_cast( + calloc(1, sizeof(struct pytorch_qnnp_operator))); + if (convolution == nullptr) { + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + } + + convolution_op = + std::unique_ptr( + convolution); + + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + convolution->ukernel_type = ukernel_type; + convolution->groups = groups; + convolution->group_input_channels = group_input_channels; + convolution->group_output_channels = group_output_channels; + convolution->kernel_depth = kernel_depth; + convolution->kernel_height = kernel_height; + convolution->kernel_width = kernel_width; + convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1; + convolution->stride_height = stride_[kSpatialDim - 2]; + convolution->stride_width = stride_[kSpatialDim - 1]; + convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1; + convolution->dilation_height = dilation_[kSpatialDim - 2]; + convolution->dilation_width = dilation_[kSpatialDim - 1]; + convolution->input_padding_height = padding_[kSpatialDim - 2]; + convolution->input_padding_width = padding_[kSpatialDim - 1]; + convolution->input_padding_depth = kSpatialDim == 3 ? 
padding_[0] : 0; + convolution->per_channel = is_per_channel_; + convolution->transpose = transpose_; + + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + + size_t zero_size = sizeof(uint8_t) * k_stride; + size_t zero_offset = 0; + + if (transpose_) { + convolution->adjustment_width = output_padding_[1]; + convolution->adjustment_height = output_padding_[0]; + if (group_input_channels < 8) { + zero_size += 8; + zero_offset = 8; + } + } else { + zero_buffer_size = 0; + if (any_padding) { + zero_size = 0; + zero_offset = 0; + if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) { + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const size_t group_stride = (groups + (cr - 1)) & -cr; + if (groups >= 8) { + zero_size = sizeof(uint8_t) * group_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * group_stride + 8; + zero_offset = sizeof(uint8_t) * 8; + } + } else if ( + ukernel_type == pytorch_qnnp_ukernel_type_conv || + ukernel_type == pytorch_qnnp_ukernel_type_gemm) { + if (group_input_channels >= 8) { + zero_size = sizeof(uint8_t) * k_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * k_stride + 8; + zero_offset = 8; + } + } + } + } + + // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI) + void* zero_buffer = malloc(zero_size); + if (zero_buffer == nullptr) { + pytorch_qnnp_delete_operator(convolution); + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for zero padding", + zero_size); + } + // Need to set to input zero point + // memset(zero_buffer, input_zero_point, zero_size); + zero_buffer_size = zero_size; + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset); + } + + std::unique_ptr convolution_op; + #ifdef USE_XNNPACK + xnnpack_operator xnnp_convolution_op; + #endif // USE_XNNPACK + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + bool transpose_; + bool is_per_channel_; + c10::optional input_scale; + std::vector kernel_; + at::Tensor w_scales; + std::vector w_zero_points; + std::vector requantization_scales; + size_t zero_buffer_size; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range=false) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return transpose_; + } + + bool per_channel() const { + return is_per_channel_; + } + + private: + std::mutex qnnp_mutex_; + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); + +#ifdef USE_XNNPACK + template + at::Tensor 
apply_impl_xnnp( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +#endif // USE_XNNPACK +}; + +enum class Activation : uint8_t { NONE = 0, RELU = 1 }; + +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float Round(const float x) { + return ::nearbyintf(x); +} +inline double Round(const double x) { + return ::nearbyint(x); +} +#else +template +inline T Round(const T x) { + return std::nearbyint(x); +} +#endif + +template +inline T QuantizeValue(float scale, int32_t zero_point, float value) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + auto r = zero_point + static_cast(Round(value / scale)); + r = std::max(r, qmin); + r = std::min(r, qmax); + return static_cast(r); +} + +template +inline std::pair activationLimits( + float scale, + int32_t zero_point, + Activation Ac) { + switch (Ac) { + case Activation::NONE: + return {std::numeric_limits::min(), + std::numeric_limits::max()}; + case Activation::RELU: + return {QuantizeValue(scale, zero_point, 0.0), + std::numeric_limits::max()}; + default: +#ifdef _MSC_VER + __assume(0); +#else + __builtin_unreachable(); +#endif + } +} + +namespace at { +namespace native { +namespace qnnp_avgpool_helper { +Tensor qnnpack_avg_pool2d( + Tensor input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override); +} // qnnp_avgpool_helper +} // namespace native +} // namespace at + +namespace { +C10_UNUSED std::vector generate_requantization_scales( + const at::Tensor& weight_scales, + const float input_scale, + const float output_scale, + std::vector& requant_scales) { + // Since weight scale is allocated with padding + // weight_scales.numel() gives us padded num elements. + const auto num_output_channels_padded = weight_scales.numel(); + float *const weight_scales_data = weight_scales.data_ptr(); + if (static_cast(requant_scales.size()) < num_output_channels_padded) { + requant_scales.resize(num_output_channels_padded); + } + for (const auto i : c10::irange(num_output_channels_padded)) { + const auto inverse_output_scale = 1.f /output_scale; + requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale; + TORCH_CHECK( + (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])), + "failed to create op with requantization scale: ", + requant_scales[i], + ": requantization scale must be finite and positive"); + } + return requant_scales; +} + +C10_UNUSED std::pair, at::Tensor> make_zero_points_and_scales_tensor( + const at::Tensor& weight_contig, + bool transpose = false, + uint32_t groups = 1 + ) { + const int out_ch_idx = transpose ? 1 : 0; + const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); + // Add 8 to account for bufferring needed by QNNPACK. + const auto num_output_channels_padded = num_output_channels + kPaddingChannels; + const auto qtype = weight_contig.qscheme(); + std::vector weight_zp(num_output_channels_padded, 0); + // Adjust weight zero point, similar to weight data. 
+ if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong, + "Per channel zero points dtype must be long int."); + const int64_t* per_channel_zero_points = + weight_contig.q_per_channel_zero_points().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + at:: Tensor weight_scales = + at::empty( + {num_output_channels_padded}, + at::device(at::kCPU).dtype(at::kFloat)); + float *const weight_scales_data = weight_scales.data_ptr(); + if (qtype == at::kPerTensorAffine) { + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = weight_contig.q_scale(); + } + } else if (qtype == at::kPerChannelAffine) { + TORCH_CHECK( + weight_contig.q_per_channel_scales().scalar_type() == at::kDouble, + "Per channel scales dtype must be double."); + const double *const per_channel_scales = + weight_contig.q_per_channel_scales().data_ptr(); + for (const auto i : c10::irange(num_output_channels)) { + weight_scales_data[i] = static_cast(per_channel_scales[i]); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme."); + } + for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) { + weight_scales_data[i] = 1.f; + } + return {weight_zp, weight_scales}; +} +} // namespace + +#endif diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..4499bb73694cde5e422d1a3c7dd15cbbd17697d4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace quant_utils { +namespace { + float RawUint16ToFp16(unsigned short value) { + // Convert raw 16 bits half precision floating point number + // to single precision floating point number. + const unsigned short sign_bits = value >> 15; + const unsigned short exponent_bits = value >> 10 & 0x1f; + const unsigned short significand_bits = value & 0x3ff; + + const float sign = sign_bits ? -1 : 1; + const float significand = + 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10; + const float exponent = exponent_bits - 0xf; + + return sign * std::ldexp(significand, exponent); +} + +template +bool CheckAndSaturate(T max_val, T* element) { + if (*element > max_val) { + *element = max_val; + return true; + } + if (*element < -max_val) { + *element = -max_val; + return true; + } + return false; +} +} +using namespace std; +// A structure to hold quantization parameters 'scale' and 'zero_point'. +// The meaning of these values is as the constants in the quantization equation +// +// real_value = scale * (quantized_value - zero_point) +// +// In other words, 'zero_point' is the quantized value that corresponds +// to the real value 0, and 'scale' is the difference of real values +// corresponding to consecutive quantized values. 
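// For example (illustrative numbers only): with scale = 0.1 and
// zero_point = 30, the quantized value 42 represents
// real_value = 0.1 * (42 - 30) = 1.2, and the real value 0 is represented
// exactly by the quantized value 30.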
+struct TensorQuantizationParams { + double scale; + std::int32_t zero_point; + int precision; +}; + +// Use fp16_min as the small scale cutoff because we don't want to use scales in +// fp16 subnormal range. This is to be consistent with Glow and FakeLowP +// implementation for NNPI. +constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + +// Following implementation should be identical to fbgemm::ChooseQuantizationParams +inline TensorQuantizationParams ChooseQuantizationParams( + float min, + float max, + int32_t qmin, + int32_t qmax, + bool preserve_sparsity = false, + bool force_scale_power_of_two = false, + bool reduce_range = false) { + TORCH_CHECK( + min <= max, + "In ChooseQuantizationParams, min should be less than or equal to max"); + + if (reduce_range) { + qmin = qmin/2; + qmax = qmax/2; + } + if (min < 0 && max > 0 && preserve_sparsity) { + int symmetric_qmin = -((qmax - qmin) / 2 + 1); + int symmetric_qmax = (qmax - qmin) / 2; + double max_scale = + std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax)); + min = max_scale * symmetric_qmin; + max = max_scale * symmetric_qmax; + } + + // We extend the [min, max] interval to ensure that it contains 0. + // Otherwise, we would not meet the requirement that 0 be an exactly + // representable value. + min = std::min(min, 0.f); + max = std::max(max, 0.f); + + TORCH_CHECK( + qmin < qmax, + "In ChooseQuantizationParams, qmin should be less than qmax"); + + // Use double precision for intermediate computation but use single precision + // in final number to reflect the actual number used during quantization. + double scale = (static_cast(max) - min) / (qmax - qmin); + // If scale is 0 or too small so its reciprocal is infinity, we arbitrary + // adjust the scale to 0.1 . We want to avoid scale's reciprocal being + // infinity because some of fbgemm code pre-computes scale's reciprocal to do + // multiplication instead of division in the time critical part of code. + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + TORCH_CHECK(scale > 0, "quantization scale should be > 0"); + + if (force_scale_power_of_two) { + if (scale < 1) { + scale = 1.0 / (1 << static_cast(floor(log(1.0 / scale) / log(2)))); + } else { + scale = 1 << static_cast(ceil(log(scale) / log(2))); + } + } + + // Cut off small scale + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust the min and max based on the new scale + if (min == 0.0f) { + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else if (max == 0.0f) { + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min *= amplifier; + max *= amplifier; + } + } + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. 
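// For instance (illustrative numbers): with min = -1.0f, max = 3.0f,
// qmin = 0 and qmax = 255, the scale is 4.0 / 255 (about 0.0157) and both
// candidate zero points evaluate to 63.75; the nudging step below rounds
// this to 64 so that the real value 0 stays exactly representable.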
+ double zero_point_from_min = qmin - min / static_cast(scale); + double zero_point_from_max = qmax - max / static_cast(scale); + double zero_point_from_min_error = + std::abs(qmin) - std::abs(min / static_cast(scale)); + double zero_point_from_max_error = + std::abs(qmax) - std::abs(max / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // for symmetric quantization (preserve_sparsity == true), we force zero_point + // to be a middle value between qmin and qmax. + // If either min or max is 0, then we just use 0 as zero_point. + if (min < 0 && max > 0 && preserve_sparsity) { + initial_zero_point = static_cast(qmin + qmax) / 2; + } + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with zero + // padding). + int32_t nudged_zero_point = 0; + if (initial_zero_point < qmin) { + nudged_zero_point = qmin; + } else if (initial_zero_point > qmax) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = nearbyint(initial_zero_point); + } + + TensorQuantizationParams result; + result.scale = scale; + result.zero_point = nudged_zero_point; + return result; +} + +// This function helps to convert the Conv1D dimensions usable by the Conv2d op. +constexpr int64_t kConv1dSqueezeDim = 0; +static C10_UNUSED torch::List MakeArgForConv1d(const torch::List& arg, + int64_t base_value) { + TORCH_CHECK(!arg.empty(), "Argument must have elements."); + torch::List result({arg.get(0), base_value}); + if (arg.size() == 1) { + result[1] = arg.get(0); + } else { + result[1] = arg.get(1); + } + result[kConv1dSqueezeDim] = base_value; + return result; +} + +// The range for using FP16 quantization of weights requires that the elements +// should be in the range of [5.96e-8, 65504]. If it is out of range, then the +// number will be saturated to max or min representable values by FP16. +inline void HandleWeightsSaturation(int64_t N, float* weight) { + const float kFp16Max = RawUint16ToFp16(0x7BFF); + bool found_out_of_range = false; + for (const auto i : c10::irange(N)) { + bool saturate = CheckAndSaturate(kFp16Max, weight + i); + if (saturate) { + found_out_of_range = true; + } + } + if (found_out_of_range) { + TORCH_WARN("FOUND weight out of range "); + } +} + +// Util function for quantizing bias. 
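// The bias is quantized to int32 with zero_point = 0 and
// scale = weight_scale * input_scale (per output channel when the weights
// are per-channel quantized), so it can be added directly to the int32
// accumulator of the quantized kernels. For example (illustrative numbers),
// with weight_scale = 0.02 and input_scale = 0.5 a bias value of 1.0 is
// stored as round(1.0 / (0.02 * 0.5)) = 100.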
+inline at::Tensor QuantizeBias( + bool is_per_channel, + const at::Tensor& bias, + const at::Tensor& weight_contig, + double input_scale) { + at::Tensor qbias; + if (is_per_channel) { + auto bias_quant_scales = + weight_contig.q_per_channel_scales() * input_scale; + auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt); + qbias = at::native::quantize_per_channel( + bias, bias_quant_scales, bias_zp, 0, c10::kQInt32); + } else { + qbias = at::native::quantize_per_tensor( + bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32); + } + return qbias; +} + +} // namespace quant_utils diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h new file mode 100644 index 0000000000000000000000000000000000000000..15e442a15b01ed74981ebae504756f9b7a38aad7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h @@ -0,0 +1,258 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Scalar& /*negval_*/); +using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */); +using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point); +using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qclamp_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& min, + const Scalar& max, + at::Tensor& /*qy*/); +using qclamp_minmax_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& /*min or max*/, + at::Tensor& /*qy*/); +using qthreshold_fn = void (*)( + const at::Tensor& /*qx*/, + const Scalar& threshold, + const Scalar& value, + at::Tensor& /*qy*/); +using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qelu_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*alpha*/, + const Scalar& /*scale*/, + const Scalar& /*input_scale*/, + at::Tensor& /*qy*/); +using qbinary_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/); +using qadd_scalar_fn = + void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/); +using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); +using qdropout_fn = void(*)( + const at::Tensor& /*qx*/, + const Scalar& /*p*/, + bool training /*training*/, + at::Tensor& /*qy*/); +using qmaxpool_2d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iH, + int64_t iW, // input sizes + int64_t oH, + int64_t oW, // output sizes + int64_t kH, + int64_t kW, // kernel size + int64_t sH, + int64_t sW, // strides + int64_t pH, + int64_t pW, // padding + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qmaxpool_3d_fn = void (*)( + const Tensor& qx, + int64_t iC, // input/output channels + int64_t iT, + int64_t iH, + int64_t iW, // input sizes + int64_t oT, + int64_t oH, + int64_t oW, // output sizes + int64_t kT, + int64_t kH, + int64_t kW, // kernel size + int64_t sT, + int64_t sH, + int64_t sW, // strides + int64_t pT, + int64_t pH, + int64_t pW, // padding + int64_t dT, + int64_t dH, + int64_t dW, // dilation + Tensor& qy); +using qadaptive_avg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeH, 
+ int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideH, + int64_t istrideW); +using qadaptive_avg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t sizeB, + int64_t sizeC, + int64_t isizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeD, + int64_t osizeH, + int64_t osizeW, + int64_t istrideB, + int64_t istrideC, + int64_t istrideD, + int64_t istrideH, + int64_t istrideW); +using qavg_pool2d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t outputWidth, + int64_t outputHeight, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + bool count_include_pad, + c10::optional divisor_override); + +using qavg_pool3d_fn = void (*)( + const Tensor& qx, + Tensor& qy, + int64_t nBatch, + int64_t nInputPlane, + int64_t inputWidth, + int64_t inputHeight, + int64_t inputDepth, + int64_t outputWidth, + int64_t outputHeight, + int64_t outputDepth, + int kW, + int kH, + int kD, + int dW, + int dH, + int dD, + int padW, + int padH, + int padD, + bool count_include_pad, + c10::optional divisor_override); + +using qupsample_bilinear2d_fn = void (*)( + Tensor& output, + const Tensor& input, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +using qcat_nhwc_fn = Tensor (*)( + const MaterializedITensorListRef& qxs, + int64_t dim, + double scale, + int64_t zero_point); +using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool); + +using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&); + +using qnormalize_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qmean_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* opt_dim */, + bool /* keepdim */, + c10::optional /* opt_dtype */, + Tensor& /* Y */); + +using qstd_inner_dim_fn = void (*)( + const Tensor& /* X */, + OptionalIntArrayRef /* dim */, + const c10::optional& /* correction */, + bool /* keepdim */, + Tensor& /* Y */); + +using qnormalize_nhwc_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + bool /* affine_per_channel */, + int /* num_channels */, + int /* num_groups */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */); + +using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/, + const Tensor& /*qw*/); + +DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub); +DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub); +DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub); +DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub); +DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub); +DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub); +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub); +DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub); +DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub); +DECLARE_DISPATCH(qbinary_fn, qadd_stub); +DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub); +DECLARE_DISPATCH(qbinary_fn, qmul_stub); +DECLARE_DISPATCH(qcat_nhwc_fn, 
qcat_nhwc_stub); +DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub); +DECLARE_DISPATCH(qclamp_fn, qclamp_stub); +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub); +DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub); +DECLARE_DISPATCH(qelu_fn, qelu_stub); +DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub); +DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub); +DECLARE_DISPATCH(qdropout_fn, qdropout_stub); +DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub); +DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub); +DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub); +DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub); +DECLARE_DISPATCH(qrelu_fn, qrelu_stub); +DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub); +DECLARE_DISPATCH(qgelu_fn, qgelu_stub); +DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub); +DECLARE_DISPATCH(qtanh_fn, qtanh_stub); +DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub); +DECLARE_DISPATCH(qtopk_fn, qtopk_stub); +DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub); +DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub); +DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub); +DECLARE_DISPATCH(qprelu_fn, qprelu_stub); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ecfddcf449d5968f72b2d518914f980615183fe6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef USE_RUY_QMATMUL + +#include + +namespace at { +namespace native { +namespace ruy_utils { + +ruy::Context* get_ruy_context(); + +void quantize_multiplier(double scale, + int* multiplier_fixedpoint, + int* multiplier_exponent); + +} // namespace ruy_utils +} // namespace native +} // namespace + +#endif // USE_RUY_QMATMUL diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6b0d6ab04ddac6cdf9dbda696e72bed4afd9c7ac --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h @@ -0,0 +1,335 @@ +#pragma once + +#ifdef USE_XNNPACK +#include + +#include +#include + +using xnnpack_operator = at::native::xnnpack::Operator; + +namespace at { +namespace native { +namespace xnnp_utils { + +/* + * Return shape in the same order as the memory format + * e.g. channels_last will return NHWC instead of NCHW + */ +std::vector get_mem_format_aware_shape(const at::Tensor& in); + +/* + * Input is always int8_t, output can be [int8_t, uint8_t]. + * input + offset = output + * int8_t + 128 = uint8_t + * int8_t + 0 = int8_t + */ +template +void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out); + +template +Tensor convert_conv_weights_to_channel_last_tensor( + const at::Tensor& src, + int groups, + bool transpose); + +/* + * Series of create wrapper functions to call xnn_create_[de]conv* functions. 
+ */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_convolution2d_nhwc( + uint32_t pad_top, + uint32_t pad_right, + uint32_t pad_bottom, + uint32_t pad_left, + uint32_t kernel_h, + uint32_t kernel_w, + uint32_t stride_h, + uint32_t stride_w, + uint32_t dilation_h, + uint32_t dilation_w, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t ip_chan_stride, + size_t op_chan_stride, + int8_t izp, + float ip_scale, + int8_t kzp, + const float* k_scales, + const int8_t* kernel, + const int32_t* bias, + int8_t ozp, + float op_scale, + int8_t op_min, + int8_t op_max, + uint32_t flags, + xnn_operator_t* op, + bool per_channel, + bool transpose) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero." + "But got: ", kzp); + + if (transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_create_deconvolution2d_nhwc_qs8( + pad_top, /* uint32_t output_padding_top */ + pad_right, /* uint32_t output_padding_right */ + pad_bottom, /* uint32_t output_padding_bottom */ + pad_left, /* uint32_t output_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t stride_height */ + stride_w, /* uint32_t stride_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels */ + ip_chan_stride, /* size_t input_pixel_stride */ + op_chan_stride, /* size_t output_pixel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* deconvolution_op_out */ + + } + + if (!per_channel) { + return xnn_create_convolution2d_nhwc_qs8( + pad_top, /* uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales[0], /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } else { /* per_channel */ + return xnn_create_convolution2d_nhwc_qs8_qc8w( + pad_top, /* 
uint32_t input_padding_top */ + pad_right, /* uint32_t input_padding_right */ + pad_bottom, /* uint32_t input_padding_bottom */ + pad_left, /* uint32_t input_padding_left */ + kernel_h, /* uint32_t kernel_height */ + kernel_w, /* uint32_t kernel_width */ + stride_h, /* uint32_t subsampling_height */ + stride_w, /* uint32_t subsampling_width */ + dilation_h, /* uint32_t dilation_height */ + dilation_w, /* uint32_t dilation_width */ + groups, /* uint32_t groups */ + group_input_channels, /* size_t group_input_channels */ + group_output_channels, /* size_t group_output_channels*/ + ip_chan_stride, /* size_t input_channel_stride */ + op_chan_stride, /* size_t output_channel_stride */ + izp, /* int8_t input_zero_point */ + ip_scale, /* float input_scale */ + k_scales, /* const float* kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + ozp, /* int8_t output_zero_point */ + op_scale, /* float output_scale */ + op_min, /* int8_t output_min */ + op_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t weights_cache */ + op); /* xnn_operator_t* convolution_op_out */ + } +} + +/* + * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions. + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_convolution2d_nhwc( + xnn_operator_t op, + size_t batch, + size_t in_h, + size_t in_w, + pthreadpool_t pt_pool, + bool per_channel = false, + bool transpose = false, + uint32_t adj_h = 0, + uint32_t adj_w = 0) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + return xnn_reshape_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + adj_h, /* uint32_t adjustment_height */ + adj_w, /* uint32_t adjustment_width */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } + + size_t workspace_size = SIZE_MAX; + size_t workspace_alignment = SIZE_MAX; + + if (!per_channel) { + return xnn_reshape_convolution2d_nhwc_qs8( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } else { /* per_channel */ + return xnn_reshape_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t convolution_op */ + batch, /* size_t batch_size */ + in_h, /* size_t input_height */ + in_w, /* size_t input_width */ + &workspace_size, /* size_t* workspace_size */ + &workspace_alignment, /* size_t* workspace_alignment */ + nullptr, /* size_t* output_height_out */ + nullptr, /* size_t* output_width_out */ + pt_pool); /* pthreadpool_t threadpool */ + } +} + + +/* + * Series of setup wrapper functions to call xnn_setup_[de]conv* functions. 
+ */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_convolution2d_nhwc( + xnn_operator_t op, + const int8_t* inp, + int8_t* outp, + bool per_channel = false, + bool transpose = false) { + if(transpose) { + TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!"); + + return xnn_setup_deconvolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } + + if (!per_channel) { + return xnn_setup_convolution2d_nhwc_qs8( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } else { /* per_channel */ + return xnn_setup_convolution2d_nhwc_qs8_qc8w( + op, /* xnn_operator_t deconvolution_op */ + nullptr, /* void workspace */ + inp, /* const int8_t* input */ + outp); /* int8_t* output */ + } +} + + +/* + * Series of wrapper functions to call xnn_create* and xnn_setup* + * functions for linear + */ +C10_ALWAYS_INLINE +enum xnn_status xnnp_create_fully_connected_nc( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + int8_t kernel_zero_point, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_operator_t* fully_connected_op_out) { + /* Symmetric quantization forces kzp = 0 */ + TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero." + "But got: ", kernel_zero_point); + return xnn_create_fully_connected_nc_qs8( + input_channels, /* size_t input_channels */ + output_channels, /* size_t output_channels */ + input_stride, /* size_t input_stride */ + output_stride, /* size_t output_stride */ + input_zero_point, /* int8_t input_zero_point */ + input_scale, /* float input_scale */ + kernel_scale, /* float kernel_scale */ + kernel, /* const int8_t* kernel */ + bias, /* const int32_t* bias */ + output_zero_point, /* int8_t output_zero_point */ + output_scale, /* float output_scale */ + output_min, /* int8_t output_min */ + output_max, /* int8_t output_max */ + flags, /* uint32_t flags */ + nullptr, /* xnn_caches_t caches */ + nullptr, /* xnn_weights_cache_t */ + fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_reshape_fully_connected_nc( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) { + return xnn_reshape_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + batch_size, /* size_t batch_size */ + threadpool); /* pthreadpool_t threadpool */ +} + +C10_ALWAYS_INLINE +enum xnn_status xnnp_setup_fully_connected_nc( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output) { + return xnn_setup_fully_connected_nc_qs8( + fully_connected_op, /* xnn_operator_t fully_connected_op */ + input, /* const int8_t* input */ + output /* int8_t* output */ + ); +} + +} // namespace xnnp_utils +} // namespace native +} // namespace at + +#endif // USE_XNNPACK diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h new file mode 100644 index 0000000000000000000000000000000000000000..9af6c65af7716108586889bc5b2be0edbd32e138 --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h @@ -0,0 +1,414 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#if !defined(__s390x__) && !defined(__powerpc__) +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +#include + +/* Convolution prepacked parameters serialization. + * + * Version 1 + * + * - Fields: + * 1. weight + * 2. bias + * 3. stride x kSpatialDim + * 4. padding x kSpatialDim + * 5. dilation x kSpatialDim + * 6. groups + * + * Version 2 + * + * - Fields: + * 0. version (string) + * 1. list of non-optional tensors + * 0: packed parameters (int16_t) + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - transpose (0 or 1) + * 1: weight + * 2. list of optional tensors + * 0: bias + * + * Version 3 + * + * - Fields: + * 0. version (int64_t) + * 1. list of int64_t configuration values + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim + * - groups + * - flags (bitmask) + * - (1 << 0) transpose (1 = yes) + * 2. list of optional tensors + * 0: None (helps with type inference) + * 1: weight (this must be present) + * 2: bias + */ + +using ConvParamsSerializationTypeV2 = std::tuple< + // version, for versions 2 and up + std::string, + // non-optional tensors + std::vector, + // optional tensors + std::vector>>; + +using ConvParamsSerializationTypeV3 = std::tuple< + // version, int for versions 3 and up + int64_t, + // configuration values + std::vector, + // optional tensors + std::vector>>; + +// Parses any historical conv packed params format into +// the current format. 
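// A minimal standalone sketch of the version-detection rule documented above.
// std::variant stands in for c10::IValue and an empty struct stands in for
// at::Tensor; both are assumptions made purely for illustration. The actual
// parser follows.
#include <cstdint>
#include <string>
#include <variant>

namespace conv_serialization_sketch {

struct TensorStandIn {};  // placeholder for at::Tensor
using FirstElement = std::variant<TensorStandIn, std::string, std::int64_t>;

// Mirrors the dispatch in parse_conv_serialized_state: a tensor as element 0
// means version 1, the string "2" means version 2, the integer 3 means
// version 3, and anything else is treated as unknown (-1).
inline int detect_version(const FirstElement& first) {
  if (std::holds_alternative<TensorStandIn>(first)) {
    return 1;
  }
  if (std::holds_alternative<std::string>(first) &&
      std::get<std::string>(first) == "2") {
    return 2;
  }
  if (std::holds_alternative<std::int64_t>(first) &&
      std::get<std::int64_t>(first) == 3) {
    return 3;
  }
  return -1;
}

} // namespace conv_serialization_sketch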
+template +ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) { + + // determine the version based on IValue contents + int version = -1; + if (v.isTuple()) { + const auto& elements = v.toTupleRef().elements(); + if (!elements.empty()) { + auto firstElement = elements[0]; + if (firstElement.isTensor()) { + version = 1; + } else if (firstElement.isString()) { + const std::string& version_str = firstElement.toStringRef(); + // note: not parsing the string to automatically handle bad + // inputs + if (version_str == "2") { + version = 2; + } + } else if (firstElement.isInt()) { + auto raw_version = firstElement.toInt(); + if (raw_version == 3) { + version = 3; + } + } + } + } + TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version"); + + if (version == 1) { + // version 1 - convert to version 3 manually + + const auto& elements = v.toTupleRef().elements(); + + at::Tensor weight = elements[0].toTensor(); + c10::optional bias = elements[1].toOptional(); + torch::List stride_x_kSpatialDim = elements[2].toTensorList(); + torch::List padding_x_kSpatialDim = elements[3].toTensorList(); + torch::List dilation_x_kSpatialDim = elements[4].toTensorList(); + at::Tensor groups = elements[5].toTensor(); + + std::vector config_vals; + config_vals.reserve( + stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() + + dilation_x_kSpatialDim.size() + kSpatialDim + 3); + config_vals.push_back(kSpatialDim); + for (const auto i : c10::irange(stride_x_kSpatialDim.size())) { + auto stride = stride_x_kSpatialDim.get(i); + config_vals.push_back(stride[0].item()); + } + for (const auto i : c10::irange(padding_x_kSpatialDim.size())) { + auto padding = padding_x_kSpatialDim.get(i); + config_vals.push_back(padding[0].item()); + } + for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) { + auto dilation = dilation_x_kSpatialDim.get(i); + config_vals.push_back(dilation[0].item()); + } + // output_padding does not exist in v1, so we fill in a default value + for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + config_vals.push_back(0); + } + config_vals.push_back(groups[0].item()); + // transpose does not exist in v1, so we fill in a default value + config_vals.push_back(0); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 2) { + // version 2 + const auto& elements = v.toTupleRef().elements(); + std::vector non_optional = elements[1].toTensorList().vec(); + std::vector> optional; + + if (elements[2].isTensorList()) { + for (const auto& elem : elements[2].toTensorList()) { + optional.emplace_back(static_cast(elem)); + } + } else { + for (const auto& elem : elements[2].toList()) { + optional.emplace_back(static_cast(elem).toOptional()); + } + } + // create default optional value for bias + if (optional.empty()) { + optional.emplace_back(); + } + + auto config_a = non_optional[0].accessor(); + std::vector config_vals; + config_vals.reserve(config_a.size(0)); + for (const auto i : c10::irange(config_a.size(0))) { + config_vals.emplace_back(config_a[i]); + } + + auto weight = non_optional[1]; + auto bias = optional[0]; + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); + } else if (version == 3) { + return v.to(); + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected 
serialized qconv version: ", + version); + } +} + +#define QCONV_SERIALIZATION_VERSION 2 + +#if QCONV_SERIALIZATION_VERSION == 2 +using ConvParamsSerializationType = ConvParamsSerializationTypeV2; + +template +ConvParamsSerializationTypeV2 serialize_conv( + const c10::intrusive_ptr>& params) { + + std::string version = "2"; + std::vector non_optional; + std::vector> optional; + + // create a packed int8_t tensor for conv params + std::vector params_vec; + params_vec.push_back(kSpatialDim); + auto stride = params->stride().vec(); + params_vec.insert(params_vec.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + params_vec.insert(params_vec.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + params_vec.insert(params_vec.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + params_vec.insert(params_vec.end(), output_padding.begin(), + output_padding.end()); + params_vec.push_back(params->groups()); + params_vec.push_back(params->transpose()); + int64_t vec_size = params_vec.size(); + at::Tensor params_tensor = at::from_blob( + params_vec.data(), {vec_size}, + at::TensorOptions().dtype(at::kShort)) + // clone to retain ownership of the data + .clone(); + + auto [weight, bias] = params->unpack(); + + non_optional.emplace_back(std::move(params_tensor)); + non_optional.emplace_back(std::move(weight)); + optional.emplace_back(std::move(bias)); + + return std::tie(version, non_optional, optional); +} + +#elif QCONV_SERIALIZATION_VERSION == 3 +using ConvParamsSerializationType = ConvParamsSerializationTypeV3; + +template +ConvParamsSerializationTypeV3 serialize_conv( + const c10::intrusive_ptr>& params) { + std::vector config_vals; + config_vals.push_back(kSpatialDim); + auto stride = params->stride().vec(); + config_vals.insert(config_vals.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + config_vals.insert(config_vals.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + config_vals.insert(config_vals.end(), dilation.begin(), dilation.end()); + auto output_padding = params->output_padding().vec(); + config_vals.insert(config_vals.end(), output_padding.begin(), + output_padding.end()); + config_vals.push_back(params->groups()); + config_vals.push_back(params->transpose()); + + auto [weight, bias] = params->unpack(); + + std::vector> tensors; + tensors.emplace_back(); + tensors.emplace_back(weight); + tensors.emplace_back(bias); + + int64_t version = 3; + return std::tie(version, config_vals, tensors); +} + +#else +#error "Invalid qconv serialization version." 
+#endif + +template +c10::intrusive_ptr> deserialize_conv( + ConvParamsSerializationTypeV3 state) { + auto [version, config_vals, tensors] = state; + TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version); + + TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size()); + c10::optional weight = tensors[1]; + c10::optional bias = tensors[2]; + TORCH_INTERNAL_ASSERT(weight, "Weight should always be present in serialized qconv."); + + torch::List stride, padding, output_padding, dilation; + // skip kSpatialDim + int idx = 1; + for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + stride.emplace_back(config_vals.at(idx)); + idx++; + } + for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + padding.emplace_back(config_vals.at(idx)); + idx++; + } + for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + dilation.emplace_back(config_vals.at(idx)); + idx++; + } + for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + TORCH_INTERNAL_ASSERT(idx < static_cast(config_vals.size()), + "Unexpected index = ", idx, " for config_vals of size ", + config_vals.size()); + output_padding.emplace_back(config_vals.at(idx)); + idx++; + } + int64_t groups = config_vals.at(idx); + idx++; + int64_t flags = config_vals.at(idx); + idx++; + TORCH_INTERNAL_ASSERT(idx == static_cast(config_vals.size()), + "Unexpected length of config_vals, expected ", + idx, + " got ", + config_vals.size()); + + bool transpose = flags & (1 << 0); + + int64_t other_flags = flags & ~(1 << 0); + TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, "."); + + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::X86) { +#if AT_MKLDNN_ENABLED() + bool use_onednn = onednn_utils::should_use_onednn_quant( + weight.value(), transpose, groups, output_padding); + if (use_onednn) { + return PackedConvWeightsOnednn::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif + return PackedConvWeight::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } // x86 +#endif + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return PackedConvWeight::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + TORCH_CHECK( + kSpatialDim == 2, + "prepack/__setstate__: QNNPACK only supports Conv2d " + "now."); + return PackedConvWeightsQnnp::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // AT_MKLDNN_ENABLED() +TORCH_CHECK( + false, + "Didn't find engine for when deserializing ConvPackedParams: ", + toString(ctx.qEngine())); +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d35336fde12173423f61a74735bfd5fb65d76377 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h @@ -0,0 +1,411 
@@ +#pragma once + +#include +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#include +#include + +// The struct for the packed weight matrix (PackBMatrix) and the corresponding +// column offsets used for the fully connect layer, which are both prepared in +// the prepacking step to save the computations in the inference. Note the +// column offsets include the sum of the B columns as well as the scalar term +// B_zero_point * K, whereas the row offsets created by +// PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum +// of the A rows. The column offsets are needed for the asymmetric quantization +// (affine quantization) of input matrix. +// Note that in JIT mode we can think of a way to fuse col_offsets with bias. +struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { + PackedLinearWeight( + std::unique_ptr> w, + c10::optional bias, + std::vector col_offsets, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias_(std::move(bias)), + col_offsets(std::move(col_offsets)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(std::move(q_scheme)) {} + std::unique_ptr> w; + c10::optional bias_; + std::vector col_offsets; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor& apply_out( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output) override; + + at::Tensor& apply_relu_out( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) + override; + + std::tuple> unpack() override; + + c10::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + private: + template + at::Tensor& apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point, + at::Tensor& output); + + template + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl( + const at::Tensor& input, + double input_scale, + int64_t input_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); +}; + +struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase { + PackedLinearWeightFp16( + std::unique_ptr w, + c10::optional bias) + : w(std::move(w)), bias_(std::move(bias)) {} + + std::unique_ptr w; + c10::optional bias_; + + at::Tensor apply( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + TORCH_INTERNAL_ASSERT(false); + } + at::Tensor apply_relu( + at::Tensor /*input*/, + double /*output_scale*/, + int64_t /*output_zero_point*/) override { + TORCH_INTERNAL_ASSERT(false); + } + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) + override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool 
reduce_range = false) + override; + + at::Tensor& apply_dynamic_out( + const at::Tensor& input, + at::Tensor& output, + bool reduce_range = false) override; + at::Tensor& apply_dynamic_relu_out( + const at::Tensor& input, + at::Tensor& output, + bool reduce_range = false) override; + + std::tuple> unpack() override; + + c10::optional bias() override { + return bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + void set_bias(c10::optional bias) override; + + private: + template + at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output); +}; + +template +struct TORCH_API PackedConvWeight : public ConvPackedParamsBase { + PackedConvWeight( + std::unique_ptr> w, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose, + std::vector col_offsets, + std::vector kernel, + std::vector w_scale, + std::vector w_zp, + c10::QScheme q_scheme) + : w(std::move(w)), + bias(std::move(bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose), + col_offsets(std::move(col_offsets)), + kernel(std::move(kernel)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + q_scheme(q_scheme) {} + + std::unique_ptr> w; + c10::optional bias; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + std::vector col_offsets; + std::vector kernel; + std::vector w_scale; + std::vector w_zp; + c10::QScheme q_scheme; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + const float* GetBiasData(at::Tensor* bias); + + void GetQuantizationParams( + float act_scale, + float out_scale, + std::vector* output_multiplier_float, + std::vector* act_times_w_scale); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; + } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +// PackWeight: Convert the weight from uint8 to int8. +inline void convert_uint8_int8( + int len, + const uint8_t* src_uint8, + int8_t* dst_int8) { + for (const auto i : c10::irange(len)) { + dst_int8[i] = static_cast(static_cast(src_uint8[i]) - 128); + } +} + +// UnpackWeight: Convert the weight from int8 to uint8. 
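// The two helpers are exact inverses: packing maps a uint8 value v to the
// int8 value v - 128, and unpacking adds the 128 back, so for example
// 200 -> 72 -> 200 and 0 -> -128 -> 0.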
+inline void convert_int8_uint8( + int len, + const int8_t* src_int8, + uint8_t* dst_uint8) { + for (const auto i : c10::irange(len)) { + dst_uint8[i] = + static_cast(static_cast(src_int8[i]) + 128); + } +} + +namespace at { +namespace native { +namespace fbgemm_utils { + +template +fbgemm::conv_param_t MakeFbgemmConvParam( + int N, + int C, + int M, + const std::vector& image_shape, + int groups, + const std::vector& kernels, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilations, + const std::vector& output_padding = std::vector(kSpatialDim, 0), + bool transposed = false); + +// TODO: Remove functions below when ChannelsLast3d is ready. +Tensor MakeStridedQTensorCPU( + const IntArrayRef& sizes, + const IntArrayRef& strides, + const TensorOptions& options, + QuantizerPtr quantizer); + +Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + double scale, + int64_t zero_point); + +Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor( + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W, + const TensorOptions& options, + const Tensor& scales, + const Tensor& zero_points); + +Tensor ConvertToChannelsLast3dTensor(const Tensor& src); + +template +Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups); + +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); +} // namespace fbgemm_utils +} // namespace native +} // namespace at + +#endif // USE_FBGEMM + +struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { + PackedEmbeddingBagWeight( + at::Tensor packed_w, + std::vector w_scale, + std::vector w_zp, + int64_t bit_rate, + c10::QScheme q_scheme, + int64_t version) + : packed_w(std::move(packed_w)), + w_scale(std::move(w_scale)), + w_zp(std::move(w_zp)), + bit_rate_(bit_rate), + q_scheme(q_scheme), + version_(version) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) + if (!packed_w.is_contiguous()) { + packed_w = packed_w.contiguous(); + } + } + + at::Tensor packed_w; + std::vector w_scale; + std::vector w_zp; + int64_t bit_rate_; + c10::QScheme q_scheme; + int64_t version_; + + at::Tensor unpack() override; + static c10::intrusive_ptr prepack( + at::Tensor weight); + + int64_t bit_rate() const override { + return bit_rate_; + } + + int64_t version() const override { + return version_; + } + + at::Tensor embeddingbag_byte( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; + + at::Tensor embeddingbag_4bit( + const at::Tensor& indices, + const c10::optional& offsets, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset, + bool is_embedding_op) override; +}; diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h new file mode 100644 index 0000000000000000000000000000000000000000..e7a1033e9758b19f9f05b05e2854289be44324c7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h @@ -0,0 +1,13 @@ +#pragma once + +#ifdef USE_PYTORCH_QNNPACK + +namespace at { +namespace native { + +void initQNNPACK(); + +} // namespace 
native +} // namespace at + +#endif diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h new file mode 100644 index 0000000000000000000000000000000000000000..cd2c04e589c439d55a19fe55fee7f76fc433a5ef --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h @@ -0,0 +1,34 @@ +#pragma once +#include +#include + +namespace at { +namespace native { +Tensor& embedding_bag_byte_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& embedding_bag_4bit_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset); + +Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight); + +} // native +} // at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h new file mode 100644 index 0000000000000000000000000000000000000000..652e7501c25dd90a0d38ba2d865d09666b5434f2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at { namespace native { + +Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); + +Tensor qembeddingbag_byte_prepack(const Tensor& weight); + +Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight); + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/utils/Factory.h b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/Factory.h new file mode 100644 index 0000000000000000000000000000000000000000..28444494242ae5ce6e6d728686573491302c3721 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/Factory.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace mobile { + +Tensor allocate_padded_contiguous_if_needed( + const Tensor& input, + c10::MemoryFormat memory_format); + +// TODO: Remove this function when at::native::empty() is modified to accept a +// custom memory allocator. 
+ +at::Tensor empty_with_tail_padding( + IntArrayRef size, + const caffe2::TypeMeta dtype, + c10::MemoryFormat memory_format, + c10::optional maybe_names); + +} // namespace mobile +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..3f4cdf5b906b9ecc9b7a5ff130f1a607e7ea4d25 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamUtils.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + if (list_param.size() == 1) { + return std::vector(expected_dim, list_param[0]); + } else if ((int64_t)list_param.size() != expected_dim) { + std::ostringstream ss; + ss << "expected " << param_name << " to be a single integer value or a " + << "list of " << expected_dim << " values to match the convolution " + << "dimensions, but got " << param_name << "=" << list_param; + AT_ERROR(ss.str()); + } else { + return list_param.vec(); + } +} + +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h new file mode 100644 index 0000000000000000000000000000000000000000..24c836f3308d145f8a3e56f0d240404f53df5f0d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/utils/ParamsHash.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see +// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + size_t operator()(const Params& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(Params))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contents as char* when comparing + static_assert(std::is_standard_layout_v, "Params is not POD"); + + bool operator()(const Params& a, const Params& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(Params)) == 0; + } +}; + +// Provide explicit byte-for-byte constructors to avoid uwittingly leaving +// padding bytes unitialized (e.g., when passing Params by value) +template +struct ParamsWrapper { + T pod; + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + ParamsWrapper() { + memset(&(this->pod), 0, sizeof(this->pod)); + } + + ParamsWrapper(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), 
sizeof(this->pod)); + } + + ParamsWrapper(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + } + + ParamsWrapper& operator=(const ParamsWrapper& other) { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + ParamsWrapper& operator=(ParamsWrapper&& other) noexcept { + memcpy(&(this->pod), &(other.pod), sizeof(this->pod)); + return *this; + } + + inline friend bool operator==( + const ParamsWrapper& lhs, + const ParamsWrapper& rhs) noexcept { + auto ptr1 = reinterpret_cast(&(lhs.pod)); + auto ptr2 = reinterpret_cast(&(rhs.pod)); + return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0; + } +}; + +// Wrapped version: this allows the outer struct to have custom copy and move +// constructors for additional safety +template +struct ParamsWrapperHash { + // Params must be a POD because we read out its memory + // contents as char* when hashing + static_assert( + std::is_standard_layout_v, + "ParamsWrapper cannot wrap non-POD data"); + + size_t operator()(const ParamsWrapper& params_wrapper) const { + auto ptr = reinterpret_cast(&(params_wrapper.pod)); + uint32_t value = 0x811C9DC5; + for (const auto i : c10::irange(sizeof(params_wrapper.pod))) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +} // namespace at::native diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h b/MLPY/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..fc16ad2c373177cb92d297b4b78da0efa9800225 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/verbose_wrapper.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace torch::verbose { +TORCH_API int _mkl_set_verbose(int enable); +TORCH_API int _mkldnn_set_verbose(int level); +} // namespace torch::verbose diff --git a/MLPY/Lib/site-packages/torch/include/ATen/native/vol2col.h b/MLPY/Lib/site-packages/torch/include/ATen/native/vol2col.h new file mode 100644 index 0000000000000000000000000000000000000000..7067c741cbc6a23a13d33341963294874b5c3716 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/native/vol2col.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +namespace at::native { + +template +static void vol2col( + const T* data_vol, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t depth_col, + const int64_t height_col, + const int64_t width_col, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_col) { + int64_t c, t, h, w; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_col[((c * depth_col + t) * 
height_col + h) * width_col + w] = + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + + w_pad]; + else + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + 0; + } + } + } + } +} + +template +static void col2vol( + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t out_depth, + const int64_t out_height, + const int64_t out_width, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_vol) { + memset(data_vol, 0, sizeof(T) * depth * height * width * channels); + int64_t depth_col = out_depth; + int64_t height_col = out_height; + int64_t width_col = out_width; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (int64_t c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (int64_t t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (int64_t h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (int64_t w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] += + data_col + [((c * depth_col + t) * height_col + h) * width_col + w]; + } + } + } + } +} + +} // namespace at::native diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs.h new file mode 100644 index 0000000000000000000000000000000000000000..ec5807aed88f2adbe3743936bc29ac430f22db16 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs.h @@ -0,0 +1,44 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::abs(Tensor self) -> Tensor +inline at::Tensor abs(const at::Tensor & self) { + return at::_ops::abs::call(self); +} + +// aten::abs_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & abs_(at::Tensor & self) { + return at::_ops::abs_::call(self); +} + +// aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::abs_out::call(self, out); +} +// aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::abs_out::call(self, out); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6464bff68321e5e032d3caeab3c5607692887418 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautograd { + +TORCH_API at::Tensor abs(const at::Tensor & self); +TORCH_API at::Tensor & abs_(at::Tensor & self); + +} // namespace compositeexplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3318cfc1b25c809c02e00de2a958fa2889d0038d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..0d19eeae3b4cbb14e13e3ae89b7d264a564fdf15 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & abs_outf(const at::Tensor & self, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_native.h new file mode 100644 index 0000000000000000000000000000000000000000..4c32dc5290c27dfb642d93211a240fe9f2e56ab0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_native.h @@ -0,0 +1,31 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor abs(const at::Tensor & self); +TORCH_API at::Tensor & abs_(at::Tensor & self); +TORCH_API at::Tensor & abs_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor NestedTensor_abs(const at::Tensor & self); +TORCH_API at::Tensor & NestedTensor_abs_(at::Tensor & self); +TORCH_API at::Tensor abs_sparse(const at::Tensor & self); +TORCH_API at::Tensor & abs_sparse_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & abs_sparse_(at::Tensor & self); +TORCH_API at::Tensor abs_sparse_csr(const at::Tensor & self); +TORCH_API at::Tensor & abs_sparse_csr_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & abs_sparse_csr_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..22f93d2322e7d71bae4b02158a33d463c2af871c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/abs_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API abs { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs(Tensor self) -> Tensor") + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API abs_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs_(Tensor(a!) 
self) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API abs_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute.h new file mode 100644 index 0000000000000000000000000000000000000000..553cea7dfad348bf5cf66e9884d20ab518b89a7a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::absolute(Tensor self) -> Tensor +inline at::Tensor absolute(const at::Tensor & self) { + return at::_ops::absolute::call(self); +} + +// aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & absolute_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::absolute_out::call(self, out); +} +// aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & absolute_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::absolute_out::call(self, out); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1638f19b5f208cf4c3cf1479f2095b8e45497c6c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor absolute(const at::Tensor & self); +TORCH_API at::Tensor & absolute_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & absolute_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & absolute_(at::Tensor & self); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_native.h new file mode 100644 index 0000000000000000000000000000000000000000..025d6908fafc2af3b41ccf303fee2717660e981f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor absolute(const at::Tensor & self); +TORCH_API at::Tensor & absolute_out(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & absolute_(at::Tensor & self); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..62bb1e6ddd4451e18927e583dfc0732314fe4aae --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/absolute_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API absolute { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::absolute") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "absolute(Tensor self) -> Tensor") + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API absolute_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::absolute_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "absolute_(Tensor(a!) self) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API absolute_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::absolute") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "absolute.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos.h new file mode 100644 index 0000000000000000000000000000000000000000..446813b7360b2ed29cf88ed2228106323b178286 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos.h @@ -0,0 +1,44 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::acos(Tensor self) -> Tensor +inline at::Tensor acos(const at::Tensor & self) { + return at::_ops::acos::call(self); +} + +// aten::acos_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & acos_(at::Tensor & self) { + return at::_ops::acos_::call(self); +} + +// aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::acos_out::call(self, out); +} +// aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::acos_out::call(self, out); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..17e29662d852815777e082ad49320532f5d400b7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..2396f7b6f9ddeeee84eb60f25d40d583e154ae14 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..31049f254f7e6138f568bcd22b976218f256c57b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..1c1b110ab944a5e528aeabde83e6cb0b53ecd46a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_acos : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..67c00a184c2419173a7a3078a7409ef37588e2b6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor acos(const at::Tensor & self); +TORCH_API at::Tensor & acos_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acos_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acos_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_native.h new file mode 100644 index 0000000000000000000000000000000000000000..1a0be9eca577fd3d8a631de12de4cf3367bc8700 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_acos_out : public at::meta::structured_acos { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..9aeeb020d5e41fff6adfe89e6d03fc4f68da332d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acos_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API acos { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acos") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acos(Tensor self) -> Tensor") + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API acos_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acos_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acos_(Tensor(a!) self) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API acos_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acos") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acos.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh.h new file mode 100644 index 0000000000000000000000000000000000000000..ea56387376a1a2df0a2a592d3085dfcb1b3e3229 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh.h @@ -0,0 +1,44 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::acosh(Tensor self) -> Tensor +inline at::Tensor acosh(const at::Tensor & self) { + return at::_ops::acosh::call(self); +} + +// aten::acosh_(Tensor(a!) self) -> Tensor(a!) +inline at::Tensor & acosh_(at::Tensor & self) { + return at::_ops::acosh_::call(self); +} + +// aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self) { + return at::_ops::acosh_out::call(self, out); +} +// aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out) { + return at::_ops::acosh_out::call(self, out); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..279bd1957fe1e17321b6d062b343d12459bfb487 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..026f535fdf09102ea68284aaf189f3de877f1ed9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..a31a1c967c1185ee9b87bd763c281a1ee491d141 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..1974e9150ba6ac3c648bf38c90f4bb0b7b0f0ff7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_acosh : public TensorIteratorBase { + + + void meta(const at::Tensor & self); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d6f867a40b3cc2dc93a7b780cf0142f929228d36 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor acosh(const at::Tensor & self); +TORCH_API at::Tensor & acosh_out(at::Tensor & out, const at::Tensor & self); +TORCH_API at::Tensor & acosh_outf(const at::Tensor & self, at::Tensor & out); +TORCH_API at::Tensor & acosh_(at::Tensor & self); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_native.h new file mode 100644 index 0000000000000000000000000000000000000000..35b66762bd865fb68c256b900bb89d24b7467f11 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_acosh_out : public at::meta::structured_acosh { +void impl(const at::Tensor & self, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a6dcfda790b3c9d749632e5988d3013b59ca7ad3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/acosh_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API acosh { + using schema = at::Tensor (const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acosh") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acosh(Tensor self) -> Tensor") + static at::Tensor call(const at::Tensor & self); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self); +}; + +struct TORCH_API acosh_ { + using schema = at::Tensor & (at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acosh_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acosh_(Tensor(a!) self) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self); +}; + +struct TORCH_API acosh_out { + using schema = at::Tensor & (const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::acosh") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "acosh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h new file mode 100644 index 0000000000000000000000000000000000000000..6d28c33d8834feff6600f61635b8c2e93832d273 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d.h @@ -0,0 +1,30 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d::call(self, output_size); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..45bb93f27cacd0ded73f9227e7be915c3f3de42d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..ff7f39f6536670c5dbf7dbf9ae37179d1fe2ef12 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..61a858c40cce2203bb164cdd40ca985e449a1a6f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool1d_ops.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool1d { + using schema = at::Tensor (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool1d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor") + static at::Tensor call(const at::Tensor & self, at::IntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h new file mode 100644 index 0000000000000000000000000000000000000000..433acd7aed345d789c1b57ee6086d4b3169f262c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d.h @@ -0,0 +1,91 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, c10::fromIntArrayRefSlow(output_size)); +} +namespace symint { + template ::value>> + at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, c10::fromIntArrayRefSlow(output_size)); + } +} + +// aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); +} +namespace symint { + template ::value>> + at::Tensor adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); + } +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..b14e16398365dbed76601b8a1fecc31053869166 --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..168960f65674dc06fb52a3331be4ff949a7ac84e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6363b3a2ce1d7c7e914ad1fd03301b649157d405 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool2d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6ed978fd6a96f295ee2683e6e6c782d8488c1980 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool2d_out_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool2d_out_cuda(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & mkldnn_adaptive_avg_pool2d_out_stub(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..2cfeee2c7dd4a077debcfc4bab04b5cbce2d282d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool2d_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool2d_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool2d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); +}; + +struct TORCH_API adaptive_avg_pool2d { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool2d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor") + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h new file mode 100644 index 0000000000000000000000000000000000000000..7c1cac534a1327ae16f8db2231860530f98617d5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d.h @@ -0,0 +1,91 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, c10::fromIntArrayRefSlow(output_size), out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); +} +namespace symint { + template ::value>> + at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::call(self, output_size, out); + } +} + +// aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, c10::fromIntArrayRefSlow(output_size)); +} +namespace symint { + template ::value>> + at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, c10::fromIntArrayRefSlow(output_size)); + } +} + +// aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor +inline at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); +} +namespace symint { + template ::value>> + at::Tensor adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); + } +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..e1e984135bced3550b4f60ca58ce22e876aa7e50 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward.h @@ -0,0 +1,34 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::call(grad_output, self, grad_input); +} +// aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::call(grad_output, self, grad_input); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..7d6271488475837cc58198cabb052077194eef9f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e8433c1e6df18a07e0b1e6a229ae5357530ab4dd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..a72bd814f72257a81fd997e0435d36bd1054afa1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h @@ -0,0 +1,22 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out_cpu(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +TORCH_API at::Tensor & adaptive_avg_pool3d_backward_out_cuda(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..70eccdbad7ee1b4078119728a7f4dc13d92b0cc2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
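The Function.h-style headers added above give each pooling op several C++ spellings: a functional form, `*_out`/`*_outf` out-variants that differ only in where the `out` argument sits, and `*_symint` overloads taking c10::SymIntArrayRef. A minimal usage sketch, assuming a working libtorch build where <ATen/ATen.h> exposes these declarations; the tensor shapes and pool sizes are arbitrary and the snippet is not part of the generated headers:

// Illustrative only -- not part of the generated headers in this diff.
#include <ATen/ATen.h>
#include <c10/core/SymInt.h>
#include <vector>

int main() {
  at::Tensor self = at::randn({2, 3, 8, 8, 8});  // NCDHW input

  // Functional form: concrete sizes go through the IntArrayRef overload,
  // which the header forwards via c10::fromIntArrayRefSlow().
  at::Tensor y = at::adaptive_avg_pool3d(self, {4, 4, 4});

  // Out-variant naming convention: *_out takes `out` first, *_outf takes it
  // last; both forward to the same at::_ops::...::call.
  at::Tensor out = at::empty({0});
  at::adaptive_avg_pool3d_out(out, self, {4, 4, 4});
  at::adaptive_avg_pool3d_outf(self, {4, 4, 4}, out);

  // SymInt entry point, declared alongside the plain overloads.
  std::vector<c10::SymInt> sym_size = {c10::SymInt(4), c10::SymInt(4), c10::SymInt(4)};
  at::Tensor y_sym = at::adaptive_avg_pool3d_symint(self, sym_size);
  return 0;
}
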
+#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool3d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool3d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "grad_input") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..46872d80ad9b344940c662aac8ea906a77a329ef --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API at::Tensor adaptive_avg_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c8d2528bcee029c651ee315d8ae8e3f997ce25b6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..52f8f7a70f29d0d5da89ca7a2a5f03755c70e579 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor & adaptive_avg_pool3d_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..0949add8cb6083c204f673e05af0a20eedc0a04f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_native.h @@ -0,0 +1,24 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor adaptive_avg_pool3d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_cuda(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +TORCH_API at::Tensor & adaptive_avg_pool3d_out_quantized_cpu(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..f420ebc5b3597b4e185670c500430dab3a259886 --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_avg_pool3d_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_avg_pool3d_out { + using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool3d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); +}; + +struct TORCH_API adaptive_avg_pool3d { + using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_avg_pool3d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor") + static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef output_size); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h new file mode 100644 index 0000000000000000000000000000000000000000..22baccb753eebbd559ce06d4d449482d22d2769c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d.h @@ -0,0 +1,30 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool1d::call(self, output_size); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4546f00b62879ef7a6b3e241c3e7295bf36f61c9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward 
declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeimplicitautograd { + +TORCH_API ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeimplicitautograd +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..b2af9d523f42c94746fca4af685436f4cd6eeffe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_native.h @@ -0,0 +1,21 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API ::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..49a1393a3b436cce4ca8920c0d71ad17271ab39f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool1d_ops.h @@ -0,0 +1,28 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool1d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool1d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)") + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h new file mode 100644 index 0000000000000000000000000000000000000000..d1520b27b55ae9ca3cb71e6854982a6277d4289a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d_out::call(self, output_size, out, indices); +} +// aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_out::call(self, output_size, out, indices); +} + +// aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d::call(self, output_size); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..af85a8dae17ff557ed665c416058a178a42c93a6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::call(grad_output, self, indices, grad_input); +} +// aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::call(grad_output, self, indices, grad_input); +} + +// aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor +inline at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward::call(grad_output, self, indices); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..fde38f5cc95032d0abcefdc3587cf86990b78226 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. 
+// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..798c421dbac601d569b02fb7bda45d1c5145c3c5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e176d3228926b817dc8ea7b323f94a19a273e739 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
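The adaptive_max_pool2d backward entry points declared here take grad_output, self, and the indices tensor produced by the forward pass. A hedged sketch of how the forward and backward calls fit together, assuming libtorch; the helper name, the upstream gradient, and all shapes are illustrative:

// Illustrative only. Assumes <ATen/ATen.h> from a libtorch build.
#include <ATen/ATen.h>
#include <tuple>

void max_pool2d_round_trip() {
  at::Tensor self = at::randn({1, 3, 16, 16});

  // Forward returns (output, indices); the indices tensor is what the
  // backward declared above consumes.
  auto [output, indices] = at::adaptive_max_pool2d(self, {4, 4});

  // A hypothetical upstream gradient with the same shape as `output`.
  at::Tensor grad_output = at::ones_like(output);

  // Functional backward and its grad_input out-variant.
  at::Tensor grad_input = at::adaptive_max_pool2d_backward(grad_output, self, indices);
  at::Tensor grad_buf = at::empty({0});
  at::adaptive_max_pool2d_backward_out(grad_buf, grad_output, self, indices);
}
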
+#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..9794ab4c66de52e58f47cc0e08763ea039f58573 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool2d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..89982f55f7f3227ae11c6f416446f4ef25221fa7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
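The *_meta.h and *_meta_dispatch.h headers carry the shape/dtype half of the structured kernels. One way to exercise that layer in isolation is to run the op on Meta-device tensors, which allocate no storage; this is a sketch under the assumption of a libtorch build with Meta-device support, and the expected sizes below simply follow from the chosen input shape:

// Illustrative sketch of shape-only evaluation via the Meta device,
// which is one consumer of the structured meta() functions above.
#include <ATen/ATen.h>
#include <vector>

void meta_shape_inference() {
  // No real storage is allocated for meta tensors.
  at::Tensor self = at::empty({2, 3, 32, 32},
                              at::TensorOptions().device(at::kMeta));

  // Shapes and dtypes are computed without running the CPU/CUDA kernels.
  auto [out, indices] = at::adaptive_max_pool2d(self, {7, 7});

  TORCH_CHECK(out.sizes().vec() == std::vector<int64_t>({2, 3, 7, 7}));
  TORCH_CHECK(indices.scalar_type() == at::kLong);
}
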
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool2d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..3e380439c0ab5a40a66abf4c8ab62af387d8b3c3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool2d_backward_out_cpu : public at::meta::structured_adaptive_max_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +struct TORCH_API structured_adaptive_max_pool2d_backward_out_cuda : public at::meta::structured_adaptive_max_pool2d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..70ec6f1f710852683b88f80053564701593806b3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool2d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool2d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "grad_input") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) 
grad_input) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); +}; + +struct TORCH_API adaptive_max_pool2d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool2d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor") + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f6a6e5560eff20cd763f28e34c6932ddd48ded3a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4d3ee55bcba794e2b8c47da8d94deabed98e1a00 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
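Each Operator.h header pins one dispatcher schema down as a struct with static `call` and `redispatch` members, plus the schema name and overload strings; the public at:: wrappers shown earlier are thin forwards to `call`. A sketch of that equivalence, assuming libtorch and the ATen/ops include path added by this diff; ordinary user code would stay with the at:: wrapper:

// Illustrative: the at:: wrapper and the _ops struct reach the same kernel.
#include <ATen/ATen.h>
#include <ATen/ops/adaptive_max_pool2d_ops.h>  // at::_ops::adaptive_max_pool2d
#include <tuple>

void call_through_ops_struct() {
  at::Tensor self = at::randn({1, 1, 9, 9});

  // Public wrapper...
  auto a = at::adaptive_max_pool2d(self, {3, 3});
  // ...is a one-line forward to the generated struct's call().
  auto b = at::_ops::adaptive_max_pool2d::call(self, {3, 3});

  TORCH_CHECK(std::get<0>(a).equal(std::get<0>(b)));
}
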
+#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..e30f25019eefa8d5e58c0ba2a75594470ae62c24 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..3566130ca13f037f479c13f3e2686a4d43122db6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool2d : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef output_size); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..d9668c4ece2710352a58d9951f04dd1e5409951b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API ::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool2d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..7e968a36b208c1d4359a43e992f21e48ddd17d34 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool2d_out_cpu : public at::meta::structured_adaptive_max_pool2d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +struct TORCH_API structured_adaptive_max_pool2d_out_cuda : public at::meta::structured_adaptive_max_pool2d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..3179829d6b41fe74bebfc26e73dffa37e5b4456b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool2d_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool2d_out { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool2d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))") + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); +}; + +struct TORCH_API adaptive_max_pool2d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool2d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)") + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h new file mode 100644 index 0000000000000000000000000000000000000000..46140dc618b2f84eb8735189167a3142364b5e0a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d_out::call(self, output_size, out, indices); +} +// aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) +inline ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_out::call(self, output_size, out, indices); +} + +// aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) +inline ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d::call(self, output_size); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..a937d646f57af67af5c70fccfe8cad3b8a6e44b5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+inline at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::call(grad_output, self, indices, grad_input); +} +// aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) +inline at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::call(grad_output, self, indices, grad_input); +} + +// aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor +inline at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward::call(grad_output, self, indices); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..39370d2febf72e354824f2de2da25812378534d6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..f82cc66807a723d28089c883a1ef9091344b5d8c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
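The per-backend *_cpu_dispatch.h / *_cuda_dispatch.h headers expose at::cpu:: and at::cuda:: entry points that bypass the dispatcher, so they skip autograd and device handling. A sketch under the assumption of a CPU libtorch build; the helper name and shapes are illustrative, and application code would normally use the dispatched at:: form:

// Illustrative: calling the CPU backend entry point directly.
#include <ATen/ATen.h>
#include <ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h>  // at::cpu::

void backend_direct_backward() {
  at::Tensor self = at::randn({1, 2, 8, 8, 8});  // must already be a CPU tensor
  auto [out, indices] = at::adaptive_max_pool3d(self, {2, 2, 2});
  at::Tensor grad_out = at::ones_like(out);

  // Dispatched call vs. direct CPU backend call: for a plain CPU tensor
  // with no autograd involved, both end up in the same kernel.
  at::Tensor g1 = at::adaptive_max_pool3d_backward(grad_out, self, indices);
  at::Tensor g2 = at::cpu::adaptive_max_pool3d_backward(grad_out, self, indices);
  TORCH_CHECK(g1.equal(g2));
}
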
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..c8e4a5c65488f5d479b869a4e49280c4d9428079 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..9a36e644418b0ced5d4c1a099778e85a233c693f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool3d_backward : public at::impl::MetaBase { + + + void meta(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..ef8b885cbb1c5c98254c2eca5f8539f5e96f0310 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are 
for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_out(at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +TORCH_API at::Tensor & adaptive_max_pool3d_backward_outf(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h new file mode 100644 index 0000000000000000000000000000000000000000..caf64b5c5552cdbef6eb731aaa732f19d0731144 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool3d_backward_out_cpu : public at::meta::structured_adaptive_max_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +struct TORCH_API structured_adaptive_max_pool3d_backward_out_cuda : public at::meta::structured_adaptive_max_pool3d_backward { +void impl(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, const at::Tensor & grad_input); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..a01977617a72331ea121d30ee2dca63bb847780e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool3d_backward_grad_input { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool3d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "grad_input") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) 
grad_input) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); +}; + +struct TORCH_API adaptive_max_pool3d_backward { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool3d_backward") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor") + static at::Tensor call(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..684cfdeec4c31441721d1e75c50e4b7797773db6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..41dd65be3684d5febb0eb37d655c37c9f71d17ed --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
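The out-variants write into caller-provided tensors, resizing them as needed; for the max-pool family the second output is a Long indices tensor. A minimal sketch assuming libtorch; the zero-sized starting buffers are just one convenient choice:

// Illustrative: reusing caller-owned buffers with the out-variant.
// The kernel resizes `out` and `indices` as needed; indices must be Long.
#include <ATen/ATen.h>

void pool3d_into_buffers() {
  at::Tensor self = at::randn({1, 4, 8, 8, 8});
  at::Tensor out = at::empty({0});
  at::Tensor indices = at::empty({0}, at::kLong);

  // out-first (*_out) and out-last (*_outf) spellings of the same op.
  at::adaptive_max_pool3d_out(out, indices, self, {2, 2, 2});
  at::adaptive_max_pool3d_outf(self, {2, 2, 2}, out, indices);
}
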
+#include + +namespace at { + +namespace cpu { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..55acbfee68654cf1c7578be8c3b1c9c992329fc2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..7f923ba0637757dab45adf74824f2dba850bccfd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_adaptive_max_pool3d : public at::impl::MetaBase { + + + void meta(const at::Tensor & self, at::IntArrayRef output_size); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..aaa79aaa2193efbd7f8d7dcde2865de883604108 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. 
+// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API ::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_out(at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size); +TORCH_API ::std::tuple adaptive_max_pool3d_outf(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h new file mode 100644 index 0000000000000000000000000000000000000000..80421c103dbd626a591c9d9b2b9177cefe52869d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_native.h @@ -0,0 +1,26 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_adaptive_max_pool3d_out_cpu : public at::meta::structured_adaptive_max_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +struct TORCH_API structured_adaptive_max_pool3d_out_cuda : public at::meta::structured_adaptive_max_pool3d { +void impl(const at::Tensor & self, at::IntArrayRef output_size, const at::Tensor & out, const at::Tensor & indices); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..214ae900afcaf832a50361120a28cdd160dd254f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/adaptive_max_pool3d_ops.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API adaptive_max_pool3d_out { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef, at::Tensor &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool3d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))") + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); +}; + +struct TORCH_API adaptive_max_pool3d { + using schema = ::std::tuple (const at::Tensor &, at::IntArrayRef); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::adaptive_max_pool3d") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)") + static ::std::tuple call(const at::Tensor & self, at::IntArrayRef output_size); + static ::std::tuple redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/add.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/add.h new file mode 100644 index 0000000000000000000000000000000000000000..da1de6d07e5548f142fa2131f28d5c11af45d5a1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/add.h @@ -0,0 +1,53 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +inline at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_Tensor::call(self, other, alpha); +} + +// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_out::call(self, other, alpha, out); +} +// aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_outf(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_out::call(self, other, alpha, out); +} + +// aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor +inline at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar::call(self, other, alpha); +} + +// aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & add_out(at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar_out::call(self, other, alpha, out); +} +// aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & add_outf(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_Scalar_out::call(self, other, alpha, out); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm.h new file mode 100644 index 0000000000000000000000000000000000000000..6f74a0f869fb9afb99cd25ebbfb6d248449b32fe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_out::call(self, batch1, batch2, beta, alpha, out); +} +// aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addbmm_out::call(self, batch1, batch2, beta, alpha, out); +} + +// aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +inline at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm::call(self, batch1, batch2, beta, alpha); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..4b79e23aeeac7d2fb9d957d8e9129715f6fdd1f0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
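// Usage sketch (not part of the generated Function.h headers above): the add.h
// declarations show the two out-variant conventions used throughout ATen --
// add_out takes the destination first and keeps defaulted trailing arguments,
// while add_outf mirrors the schema exactly, with the out tensor last and no
// defaults. A minimal check, assuming a standard libtorch build:

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = at::full({2, 2}, 3.0);

  // Functional form: returns a new tensor; alpha defaults to 1.
  at::Tensor c = at::add(a, b, /*alpha=*/2);   // a + 2 * b

  // "out" form: destination first.
  at::Tensor out = at::empty({2, 2});
  at::add_out(out, a, b, /*alpha=*/2);

  // "_outf" form: same overload, out tensor last, no defaulted arguments.
  at::add_outf(a, b, /*alpha=*/2, out);

  return at::allclose(c, out) ? 0 : 1;
}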
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6bc7a7c0680d6278569997b0cba0cb28a265e6de --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_outf(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..6452de7679774bcddb9012c48260bb5e3beea5f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_meta_dispatch.h @@ -0,0 +1,23 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
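// Usage sketch (not part of the generated dispatch headers above): per the
// aten::addbmm schema, the op computes beta * self + alpha * (the sum over the
// batch of batch1[b] @ batch2[b]). A small equivalence check, assuming a
// standard libtorch build:

#include <ATen/ATen.h>

int main() {
  at::Tensor self   = at::randn({3, 5});
  at::Tensor batch1 = at::randn({4, 3, 2});
  at::Tensor batch2 = at::randn({4, 2, 5});

  at::Tensor fused = at::addbmm(self, batch1, batch2, /*beta=*/1, /*alpha=*/2);

  // Reference built directly from the schema: accumulate alpha * (batch1[b] @ batch2[b]).
  at::Tensor reference = self.clone();
  for (int64_t b = 0; b < batch1.size(0); ++b) {
    reference += 2 * at::mm(batch1[b], batch2[b]);
  }

  return at::allclose(fused, reference) ? 0 : 1;
}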
+#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h new file mode 100644 index 0000000000000000000000000000000000000000..6e50590419df3b2f79326e3abfd795b533946bdb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace at { +namespace native { +TORCH_API at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +TORCH_API at::Tensor & addbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +TORCH_API at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1); +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dc14be4b1a3af8207594ae24ca6091d25f824f88 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addbmm_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addbmm_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addbmm_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +struct TORCH_API addbmm_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addbmm") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); +}; + +struct TORCH_API addbmm { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addbmm") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor") + static at::Tensor call(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv.h new file mode 100644 index 0000000000000000000000000000000000000000..f23b44aa2f23959d3f1cb3034a20cbed7ba91ad6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_out::call(self, tensor1, tensor2, value, out); +} +// aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcdiv_out::call(self, tensor1, tensor2, value, out); +} + +// aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv::call(self, tensor1, tensor2, value); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..5c772d2c64dcd37674988ebfc110fbf642eb9f63 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..67600a3ac708466a3b8551bbc0254ae3ca7a4e6f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cpu_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +namespace at { + +namespace cpu { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cpu +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..85268d589375de645afc8e2cdb4f83b2180be8a6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_cuda_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace cuda { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace cuda +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..76775529f6e3750339f0c032f77312346ed2dc22 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta.h @@ -0,0 +1,27 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +struct TORCH_API structured_addcdiv : public TensorIteratorBase { + + + void meta(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..1b1dca2d129367c1303657d280f40b69719ac5f7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_meta_dispatch.h @@ -0,0 +1,26 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ 
file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace meta { + +TORCH_API at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcdiv_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +TORCH_API at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace meta +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h new file mode 100644 index 0000000000000000000000000000000000000000..f6ef52a6ca95add2c4609395220a645c05ebb468 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_native.h @@ -0,0 +1,23 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunction.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +struct TORCH_API structured_addcdiv_out : public at::meta::structured_addcdiv { +void impl(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, const at::Tensor & out); +}; +} // namespace native +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..fe9b02622035b42513306066fb9e4b901bf4aa37 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcdiv_ops.h @@ -0,0 +1,50 @@ +#pragma once + +// @generated by torchgen/gen.py from Operator.h + +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + + +struct TORCH_API addcdiv_out { + using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &, at::Tensor &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addcdiv") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) 
out) -> Tensor(a!)") + static at::Tensor & call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); +}; + +struct TORCH_API addcdiv { + using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addcdiv") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor") + static at::Tensor call(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +struct TORCH_API addcdiv_ { + using schema = at::Tensor & (at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Scalar &); + using ptr_schema = schema*; + // See Note [static constexpr char* members for windows NVCC] + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::addcdiv_") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)") + static at::Tensor & call(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); + static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); +}; + +}} // namespace at::_ops diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul.h new file mode 100644 index 0000000000000000000000000000000000000000..9d35481d8b78ee36ff9a60a5780394651f8fc65b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul.h @@ -0,0 +1,39 @@ +#pragma once + +// @generated by torchgen/gen.py from Function.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#include + +namespace at { + + +// aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) +inline at::Tensor & addcmul_out(at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_out::call(self, tensor1, tensor2, value, out); +} +// aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
+inline at::Tensor & addcmul_outf(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcmul_out::call(self, tensor1, tensor2, value, out); +} + +// aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor +inline at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul::call(self, tensor1, tensor2, value); +} + +} diff --git a/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..3c97b1dfdafa251d1f1ea3440bf0d62036b68bbf --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h @@ -0,0 +1,24 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunction.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { + +namespace compositeexplicitautogradnonfunctional { + +TORCH_API at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); +TORCH_API at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1); + +} // namespace compositeexplicitautogradnonfunctional +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/quantized/QTensorImpl.h b/MLPY/Lib/site-packages/torch/include/ATen/quantized/QTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..cfe4bea80c14ad8d19a987c887e935f4709e6488 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/quantized/QTensorImpl.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include + +namespace at { + +/** + * QTensorImpl is a TensorImpl for Quantized Tensors, it stores Quantizer which + * specifies the quantization scheme and parameters, for more information please + * see ATen/quantized/Quantizer.h + * + * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl. + */ +struct TORCH_API QTensorImpl : public c10::TensorImpl { + public: + QTensorImpl( + Storage&& storage, + DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + QuantizerPtr quantizer); + + // See Note [Enum ImplType] + QTensorImpl( + ImplType type, + Storage&& storage, + DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + QuantizerPtr quantizer); + + + // TODO: Expose in PyTorch Frontend + QuantizerPtr quantizer() { + return quantizer_; + } + + void set_quantizer_(QuantizerPtr quantizer) { + quantizer_ = quantizer; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. 
+ */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive( + Storage(storage()), key_set(), data_type_, quantizer_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive( + Storage(storage()), key_set(), data_type_, quantizer_); + copy_tensor_metadata( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + impl->refresh_contiguous(); + return impl; + } + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); + auto q_impl = static_cast(impl.get()); + copy_tensor_metadata( + /*src_impl=*/q_impl, + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + refresh_numel(); + refresh_contiguous(); + } + + private: + QuantizerPtr quantizer_; + + const char* tensorimpl_type_name() const override; + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) + * from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const QTensorImpl* src_q_impl, + QTensorImpl* dest_q_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. + dest_q_impl->quantizer_ = src_q_impl->quantizer_; + } +}; + +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/quantized/Quantizer.h b/MLPY/Lib/site-packages/torch/include/ATen/quantized/Quantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..92e7bb6844f5a1e010174ad37c7a9f8928392e6a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/quantized/Quantizer.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include + +namespace at { + +/** + * UnknownQuantizer is a placeholder quantizer for functions that implement + * quantization in a two step process. First a tensor is allocated but with + * unknown quantizer, and then the quantization kernel decides what the final + * quantizer will be. 
+ */ +struct TORCH_API UnknownQuantizer : public Quantizer { + explicit UnknownQuantizer(ScalarType scalar_type) + : Quantizer(scalar_type) {} + + Tensor quantize(const Tensor& tensor) override; + Tensor dequantize(const Tensor& qtensor) override; + Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override; + QScheme qscheme() const override; + bool equalTo(QuantizerPtr other) const override; +}; + +/** + * UniformQuantizer is the parent class for all uniform quantizers. + * These quantization scheme will map float value uniformly to + * the quantized value. For example, affine quantizer is + * the most commonly used scheme in this category. + */ +struct TORCH_API UniformQuantizer : public Quantizer { + explicit UniformQuantizer(ScalarType scalar_type) : Quantizer(scalar_type) {} +}; + +/** + * NonUniformQuantizer is the parent class for all non-uniform quantizers. + * These quantization scheme may map float value non-uniformly to the quantized + * value. K-means quantization is a representative example in this category. + */ +struct TORCH_API NonUniformQuantizer : public Quantizer { + explicit NonUniformQuantizer(ScalarType scalar_type) : Quantizer(scalar_type) {} +}; + +// There is also StochasticQuantizer which is uniform but not affine + +/** + * AffineQuantizer uses affine transformation to do quantization. + * + * For quantize: + * Y = clamp(round(X / scale + zero_point), min, max) + * For dequantize: + * X = (Y - zero_point) * scale + */ +struct TORCH_API AffineQuantizer : public UniformQuantizer { + explicit AffineQuantizer(ScalarType scalar_type) : UniformQuantizer(scalar_type) {} +}; + +// Note that we will not have Symmetric Quantizer in backend to reduce +// complications in quantized kernel implementation. + +/** + * PerTensorAffineQuantizer stores a scale and a zero_point, which is used for + * all the values in the Tensor. + */ +struct TORCH_API PerTensorAffineQuantizer : public AffineQuantizer { + explicit PerTensorAffineQuantizer(ScalarType scalar_type, double scale, int64_t zero_point) + : AffineQuantizer(scalar_type), + scale_(scale), + zero_point_(zero_point) {} + + Tensor quantize(const Tensor& tensor) override; + Tensor dequantize(const Tensor& qtensor) override; + Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override; + + QScheme qscheme() const override { + return kPerTensorAffine; + } + + double scale() const { + return scale_; + } + + int64_t zero_point() const { + return zero_point_; + } + + bool equalTo(QuantizerPtr other) const override { + if (!other.get() || other->qscheme() != kPerTensorAffine) { + return false; + } + auto* other_per_tensor_affine = + static_cast(other.get()); + return scalar_type() == other_per_tensor_affine->scalar_type() && + scale() == other_per_tensor_affine->scale() && + zero_point() == other_per_tensor_affine->zero_point(); + } + + private: + const double scale_; + // We use int64_t for consistency with Python + const int64_t zero_point_; +}; + +/** + * PerChannelAffineQuantizer is the same as PerTensorAffineQuantizer + * except that we have an independent scale and zero_point parameter + * for each channel. + * + * Also note that per channel quantization is mostly applied to output channels + * of weights since per-input channel of weight quantization or per-channel + * quantization for activations can't be efficiently supported in most of + * processors since it requires each multiplication result within a single + * dot-product to have a different scale. 
+ */ +struct TORCH_API PerChannelAffineQuantizer : public AffineQuantizer { + explicit PerChannelAffineQuantizer( + ScalarType scalar_type, + Tensor scales, + Tensor zero_points, + int64_t axis) + : AffineQuantizer(scalar_type), + scales_(std::move(scales)), + zero_points_(std::move(zero_points)), + axis_(axis) {} + + QScheme qscheme() const override { + return kPerChannelAffine; + } + + Tensor scales() const { + return scales_; + } + + Tensor zero_points() const { + return zero_points_; + } + + int64_t axis() const { + return axis_; + } + + Tensor quantize(const Tensor& tensor) override; + Tensor dequantize(const Tensor& qtensor) override; + Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override; + + bool equalTo(QuantizerPtr other) const override { + if (!other.get() || other->qscheme() != kPerChannelAffine) { + return false; + } + auto* other_per_channel_affine = + static_cast(other.get()); + return scalar_type() == other_per_channel_affine->scalar_type() && + scales().equal(other_per_channel_affine->scales()) && + zero_points().equal(other_per_channel_affine->zero_points()) && + axis() == other_per_channel_affine->axis(); + } + + protected: + Tensor scales_; + Tensor zero_points_; + const int64_t axis_; +}; + +/** + * PerChannelAffineFloatQParamsQuantizer is the same as PerChannelAffineQuantizer + * except that it expects both scale and zero point to be floating point values. + * + * This quantizer uses the kPerChannelAffineFloatQParams qscheme which is a variant of + * kPerChannelAffine. + * + * The quantize equation in this case looks like - + * Xq = (Xf - zero_point) * inv_scale, where inv_scale = 1.0/scale + * + * Note: Usage of floating point zero point is useful in cases where 0 doesn't need to + * be exactly represented in the quantized space. We can get additional precision by + * using floating point values for zero point. + */ +struct TORCH_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffineQuantizer { + explicit PerChannelAffineFloatQParamsQuantizer( + ScalarType scalar_type, + Tensor scales, + Tensor zero_points, + int64_t axis) + : PerChannelAffineQuantizer(scalar_type, + scales, + zero_points, + axis) {} + + QScheme qscheme() const override { + return kPerChannelAffineFloatQParams; + } + + Tensor quantize(const Tensor& tensor) override; + Tensor dequantize(const Tensor& qtensor) override; + Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override; + + bool equalTo(QuantizerPtr other) const override { + if (!other.get() || other->qscheme() != kPerChannelAffineFloatQParams) { + return false; + } + auto* other_per_channel_float_qparams = + static_cast(other.get()); + return scalar_type() == other_per_channel_float_qparams->scalar_type() && + scales().equal(other_per_channel_float_qparams->scales()) && + zero_points().equal(other_per_channel_float_qparams->zero_points()) && + axis() == other_per_channel_float_qparams->axis(); + } +}; + +// This is an internal utility function for getting at the QTensorImpl, +// You should only use this for writing low level +// setters/getters for QTensorImpl fields; otherwise, you should use +// the low level setters/getters that were implemented using this. +// This may be called repeatedly, so make sure it's pretty cheap. 
+TORCH_API QTensorImpl* get_qtensorimpl(const TensorBase& self); + +// double and int64_t are because of the native function API, we only have these +// argument types right now in native functions +TORCH_API QuantizerPtr +make_per_tensor_affine_quantizer( + double scale, int64_t zero_point, ScalarType scalar_type); + +TORCH_API QuantizerPtr make_per_channel_affine_quantizer( + const Tensor& scales, + const Tensor& zero_points, + int64_t axis, + ScalarType scalar_type); + +TORCH_API QuantizerPtr make_unknown_quantizer(ScalarType scalar_type); + +// Create a Quantized Tensor given arguments for normal Tensor and a quantizer +TORCH_API Tensor new_qtensor( + IntArrayRef sizes, + const TensorOptions& options, + QuantizerPtr quantizer); + +TORCH_API void set_quantizer_(const Tensor& self, ConstQuantizerPtr quantizer); + +TORCH_API Tensor from_blob_quantized_per_tensor_affine( + void* data, + IntArrayRef sizes, + IntArrayRef strides, + std::function deleter, + const float scale, + const int64_t zeroPoint, + const TensorOptions& options); + +TORCH_API Tensor from_blob_quantized_per_tensor_affine( + void* data, + IntArrayRef sizes, + std::function deleter, + const float scale, + const int64_t zeroPoint, + const TensorOptions& options); + +TORCH_API Tensor from_blob_quantized_per_channel_affine( + void* data, + IntArrayRef sizes, + std::function deleter, + const Tensor& scales, + const Tensor& zero_points, + const int64_t axis, + const TensorOptions& options); + +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/ATen/record_function.h b/MLPY/Lib/site-packages/torch/include/ATen/record_function.h new file mode 100644 index 0000000000000000000000000000000000000000..6aa5b67766b3c6c3be921bcf69a8c74098d682f2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/ATen/record_function.h @@ -0,0 +1,740 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { +class TORCH_API OperatorHandle; +} + +namespace at { + +// Function name to record NCCL metadata +extern TORCH_API const std::string kParamCommsCallName; + +// Kind of record function scope; +enum class C10_API_ENUM RecordScope : uint8_t { + // c10/ATen ops, autograd nodes + FUNCTION = 0, + // Functions/nodes called from the autograd + BACKWARD_FUNCTION, + // TorchScript functions, methods + TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, + // Torchbind custom class, + CUSTOM_CLASS, + // Generic Build Feature + BUILD_FEATURE, + // Kernel Function dtype Tag + LITE_INTERPRETER, + // User defined scope (e.g. 
with record_function()) + USER_SCOPE, + // Scopes for static runtime, a specialized TorchScript interpreter + STATIC_RUNTIME_OP, + STATIC_RUNTIME_MODEL, + NUM_SCOPES, // must be the last in the list +}; + +} // namespace at + +namespace std { +template <> +struct hash { + size_t operator()(const at::RecordScope& sc) const { + return static_cast(sc); + } +}; +} // namespace std + +namespace at { + +struct TORCH_API StringView { + StringView() : StringView(nullptr) {} + explicit StringView(const char* str_ptr) + : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {} + explicit StringView(std::string str) + : owned_str_ptr_(std::make_shared(std::move(str))), + str_ptr_(owned_str_ptr_->c_str()) {} + + const char* str() const { + return str_ptr_; + } + + friend std::ostream& operator<<(std::ostream& os, const StringView& dt) { + os << dt.str(); + return os; + } + + friend bool operator==(const StringView& lhs, const StringView& rhs) { + return strcmp(lhs.str(), rhs.str()) == 0; + } + + friend bool operator!=(const StringView& lhs, const StringView& rhs) { + return !(lhs == rhs); + } + + private: + std::shared_ptr owned_str_ptr_; + const char* str_ptr_; +}; + +// Soft limit on the number of callbacks to use; +constexpr std::size_t kSoftLimitCallbacks = 4; + +// An abstract base class for various observer contexts that can be attached to +// the RecordFunction. +struct ObserverContext { + virtual ~ObserverContext() = default; + + protected: + ObserverContext() = default; +}; + +typedef c10::SmallVector CallbackHandles; +typedef c10::SmallVector, kSoftLimitCallbacks> + ObserverContextList; +typedef uint64_t RecordFunctionHandle; +struct RecordFunction; + +// +// PyTorch callbacks/observers API: +// + +/** + * RecordFunctionCallback represents a pair of callbacks to be used with + * RecordFunction, members: + * start, end - the callbacks to run when entering and exiting the scope; + * optionally, the start callback may return an ObserverContext which will + * be passed to the end callback, use appropriate constructor accordingly. + * needs_inputs - whether the callbacks need the inputs passed from the + * observed function/range; NOTE: passing the inputs incurs an additional + * overhead; sampling_probability - if not 1.0, then the callback is + * probabilistically sampled to run; NOTE: start and end callbacks always run as + * a pair and are sampled together; scopes - types of scopes to execute the + * callbacks on (see RecordScope); passing empty set means the callbacks will be + * executed for all possible scope types should_run - optional function that + * returns whether this callback should run; overwrites the effect of setting + * sampling_probability + */ +class TORCH_API RecordFunctionCallback { + public: + using StartCallback = + std::unique_ptr (*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); + + // This interface supports observers that require passing an ObserverContext + // between start and end callbacks. 
+ explicit RecordFunctionCallback( + StartCallback start, + EndCallback end = nullptr) + : start_(start), end_(end) { + scopes_.fill(true); + } + + RecordFunctionCallback& needsInputs(bool needs_inputs) { + needs_inputs_ = needs_inputs; + return *this; + } + + RecordFunctionCallback& needsOutputs(bool needs_outputs) { + needs_outputs_ = needs_outputs; + return *this; + } + + RecordFunctionCallback& needsIds(bool needs_ids) { + needs_ids_ = needs_ids; + return *this; + } + + RecordFunctionCallback& samplingProb(double sampling_prob) { + TORCH_CHECK( + sampling_prob >= 0.0 && sampling_prob <= 1.0, + "Invalid sampling probability"); + sampling_prob_ = sampling_prob; + return *this; + } + + RecordFunctionCallback& scopes( + const std::unordered_set>& scopes) { + if (!scopes.empty()) { + scopes_.fill(false); + for (auto sc : scopes) { + scopes_[static_cast(sc)] = true; + } + } else { + scopes_.fill(true); + } + return *this; + } + + bool needsInputs() const { + return needs_inputs_; + } + + bool needsOutputs() const { + return needs_outputs_; + } + + bool needsIds() const { + return needs_ids_; + } + + double samplingProb() const { + return sampling_prob_; + } + + bool checkScope(RecordScope sc) const { + return scopes_[(size_t)sc]; + } + + StartCallback start() const { + return start_; + } + + EndCallback end() const { + return end_; + } + + private: + StartCallback start_; + EndCallback end_; + double sampling_prob_ = 1.0; + std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; + bool needs_inputs_ = false; + bool needs_outputs_ = false; + bool needs_ids_ = false; +}; + +// Notes: +// - two types of callbacks are provided: thread local and global +// - thread local callbacks are added/removed only for the given thread +// and are stored locally for each thread and separately from the list +// of the global callbacks +// - global callbacks are stored in a single per process list and are +// invoked by every RecordFunction, in addition to the thread local +// callbacks specific to the given thread +// - we allow the added callbacks to be sampled, by specifying a sampling +// probability for each callback pair, if the start callback is +// not picked to run, the corresponding end callback won't be called +// - a typical use case for the global callbacks is passive monitoring +// in the background (e.g. fleet-wide monitoring), without focusing on +// the specific piece of code +// - in contrast, thread local callbacks are enabled locally, on demand, +// for the specific piece of code (range) and are not sampled +// - a typical use case for thread local callbacks is profiler and code +// execution tracer +// - note, thread local callbacks are automatically propagated with +// ThreadLocalState across JIT continuations and async tasks (at::launch) + +typedef uint64_t CallbackHandle; + +constexpr CallbackHandle INVALID_CALLBACK_HANDLE{0}; + +// It is unnecessary to use atomic operations for enabling +// thread-local function callbacks. Moreover, it prevents saving to +// ThreadLocalState because std::atomic is non-copyable. +struct RecordFunctionCallbacksEntry { + RecordFunctionCallbacksEntry(RecordFunctionCallback cb, CallbackHandle h) + : callback_(cb), handle_(h) {} + + RecordFunctionCallback callback_; + bool enabled_{true}; + CallbackHandle handle_; +}; + +// Holds pairs (callbacks, unique_id) +using RecordFunctionCallbacks = std::vector; + +// Generated by the callback managers to determine which functions to run. 
+struct StepCallbacks { + StepCallbacks() = default; + StepCallbacks(uint64_t thread_id, RecordScope scope) + : thread_id_{thread_id}, scope_{scope} {} + + bool empty() const { + return callbacks_.empty(); + } + + struct StartEndPair { + RecordFunctionCallback::StartCallback start_; + RecordFunctionCallback::EndCallback end_; + }; + + using StartEndPairs = c10::SmallVector; + + StartEndPairs callbacks_; + uint64_t thread_id_{0}; + RecordScope scope_{RecordScope::FUNCTION}; + bool needs_inputs_{false}; + bool needs_outputs_{false}; + bool needs_ids_{false}; +}; + +struct TORCH_API RecordFunction { + // Default constructor is used with before function called afterwards: + // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability + explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); + explicit RecordFunction(StepCallbacks&& step_callbacks); + + template + void before( + F fn, + c10::ArrayRef args, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + inputs_ = args; + before(fn, current_sequence_nr); + } + + template + void before( + F fn, + const std::vector* args, + int64_t current_sequence_nr = -1) { + before( + std::move(fn), + c10::ArrayRef(args->data(), args->size()), + current_sequence_nr); + } + + // Destructor calls end callbacks + virtual ~RecordFunction(); + + RecordFunction(const RecordFunction&) = delete; + RecordFunction& operator=(const RecordFunction&) = delete; + + const char* name() const; + + int64_t seqNr() const { + return sequence_nr_; + } + + c10::ArrayRef inputs() const { +#ifndef NDEBUG + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + inputs_valid_, "Called inputs() outside RecordFunction start callback"); +#endif + return inputs_; + } + + const std::vector& outputs() const { + return outputs_; + } + + void setOutputs(std::vector&& outputs) { + outputs_ = std::move(outputs); + } + + void setOutputs(c10::ArrayRef outputs) { + outputs_ = outputs.vec(); + } + + size_t num_inputs() const; + size_t num_outputs() const; + + // Retrieves the thread_id that this RecordFunction ran start callbacks with. + // Useful for writing thread safe end callbacks that may be potentially + // executed in a different thread (async ops) + uint64_t threadId() const { + return step_callbacks_.thread_id_; + } + + // For backward functions - thread id of the corresponding forward function, + // or zero otherwise; + // used alongside with sequence number to correlate backward functions with + // the forward ones + uint64_t forwardThreadId() const { + return fwd_thread_id_; + } + + void setForwardThreadId(uint64_t thread_id) { + fwd_thread_id_ = thread_id; + } + + RecordScope scope() const { + return step_callbacks_.scope_; + } + + // Returns logical thread_id for the current thread + static uint64_t currentThreadId(); + + // Internal functions, do not use directly; + // used in python's context manager + + // before functions initialize RecordFunction members and call + // start callbacks + using schema_ref_t = std::reference_wrapper; + void before(const char* name, int64_t sequence_nr = -1); + void before(std::string name, int64_t sequence_nr = -1); + void before(schema_ref_t schema, int64_t sequence_nr = -1); + + // Sets node ID for distributed profiling + static void setDefaultNodeId(int64_t defaultNodeId); + // Gets node ID for distributed profiling + static int64_t getDefaultNodeId(); + + // Calls end callbacks. 
After end(), accessors will no longer provide useful + // results. + void end(); + + // Internal-only, used only force async event for distributed events + // profiling. + void _setAsync(); + + // Returns whether this RecordFunction corresponds to an async event or not. + bool isAsync() const; + + // Returns whether this RecordFunction corresponds to NCCL metadata collection + // or not. + bool isNcclMeta() const { + return is_nccl_meta_; + } + + // Internal-only, used to denote out variant used for Static Runtime execution + void _setStaticRuntimeOutVariant(); + bool isStaticRuntimeOutVariant() const; + + RecordFunctionHandle handle() const { + return handle_; + } + + c10::optional operator_name() const; + + // This method returns a copy of the FunctionSchema and can be expensive. + c10::optional operator_schema() const; + + void setHandle(RecordFunctionHandle handle) { + handle_ = handle; + } + + // Whether this RecordFunction runs any callbacks. + bool isActive() const { + return !step_callbacks_.empty(); + } + + bool needsInputs() const { + return step_callbacks_.needs_inputs_; + } + + bool needsOutputs() const { + return step_callbacks_.needs_outputs_; + } + + int64_t debugHandle() const { + return debug_handle_; + } + + void setDebugHandle(int64_t debug_handle) { + debug_handle_ = debug_handle; + } + + void invalidateInputs() { +#ifndef NDEBUG + inputs_valid_ = false; +#endif + } + + private: + void runStartCallbacks(); + + StepCallbacks step_callbacks_; + + // In cases when RecordFunction might be active but we chose not to + // use the observers (e.g. operator is not observed), this boolean + // flag is used to check whether the start callbacks were called + bool called_start_callbacks_ = false; + +#ifndef NDEBUG + bool inputs_valid_ = false; +#endif + + // Stores various ObserverContext objects with event metadata for callbacks. + ObserverContextList ctx_; + + std::variant fn_; + + int64_t sequence_nr_ = -1; + c10::ArrayRef inputs_; + std::vector outputs_; + + // For backward functions - thread id of the forward function + uint64_t fwd_thread_id_ = 0; + + // Unique id for this RecordFunction, used in callbacks to track start + // and end of ranges + RecordFunctionHandle handle_{0}; + + // Whether this record_function corresponds to an async event or not. Async + // events can complete in different threads or follow a future-like pattern + // of use. + bool is_async_{false}; + + // Debug handles are used for lazy annotation of module hierarchy + // and callstack. + // This is specifically is useful for mobile runtime, where generated + // debug handles can be lazily symbolicated using debug information + int64_t debug_handle_{-1}; + + // Whether this RecordFunction is used for an out variant run with + // Static Runtime + bool is_static_runtime_out_variant_{false}; + + // Whether this RecordFunction is used for NCCL metadata collection + bool is_nccl_meta_{false}; +}; + +TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); + +TORCH_API c10::optional getStepCallbacksUnlessEmpty( + RecordScope scope); + +namespace detail { +template +void record_function_with_scope( + RecordFunction& guard, + F fn, + const Inputs& inputs, + Args&&... args) { + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, + int64_t debug_handle, + const Inputs& inputs, + Args&&... 
args) { + guard.setDebugHandle(debug_handle); + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope( + RecordFunction& guard, + F fn, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope< + c10::ArrayRef, + F, + Args...>(guard, std::move(fn), inputs, std::forward(args)...); +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, + int64_t debug_handle, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope_and_debug_handle< + c10::ArrayRef, + F, + Args...>( + guard, std::move(fn), debug_handle, inputs, std::forward(args)...); +} + +} // namespace detail + +// optional argument - function's seq_no +#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope( \ + guard, fn, inputs, ##__VA_ARGS__); \ + } + +#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + scope, fn, inputs, outputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + if (guard.needsInputs()) { \ + guard.before(fn, inputs, ##__VA_ARGS__); \ + } else { \ + guard.before(fn, ##__VA_ARGS__); \ + } \ + if (guard.needsOutputs()) { \ + guard.setOutputs(outputs); \ + } \ + } + +#define RECORD_FUNCTION(fn, inputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__) + +#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs) + +#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__) + +// Custom user scopes in C++; similar to Python's 'with record_function("..."):' +#define RECORD_USER_SCOPE(fn) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, fn, c10::ArrayRef{}) + +// RECORD_USER_SCOPE with inputs +#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs) + +// Helper macro to pass in debug handle that is used to +// post process events +#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + scope, fn, debug_handle, inputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope_and_debug_handle( \ + guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ + } + +// Helper macros to record LITE INTERPETER scope events with debug handles +#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ + fn, debug_handle, inputs) \ + RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs) + +// Bookend to the RECORD_FUNCTION macros. Use this after the kernel +// launch to let the profiler bind the outputs to the op that produced +// them. 
Note that guard is declared by RECORD_FUNCTION so this macro +// needs to be called from the same scope as RECORD_FUNCTION +#define RECORD_OUTPUTS(outputs) \ + if (guard.needsOutputs()) { \ + guard.setOutputs( \ + std::vector(outputs.begin(), outputs.end())); \ + } + +/** + * addThreadLocalCallback adds a thread local callback to run with + * RecordFunction, returns handle to use with removeThreadLocalCallback + */ +TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb); + +/** + * hasThreadLocalCallbacks returns whether there're callbacks registered + * with addThreadLocalCallback + */ +TORCH_API bool hasThreadLocalCallbacks(); + +/** + * clearThreadLocalCallbacks removes all thread local callbacks + */ +TORCH_API void clearThreadLocalCallbacks(); + +/** + * addGlobalCallback adds a global callback to run with RecordFunction: + * + * only during the program initialization + */ +TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb); + +/** + * removeCallback removes a callback given the handle returned by + * addThreadLocalCallback or addGlobalCallback; + * + * no other code can run simultaneously + */ +TORCH_API void removeCallback(CallbackHandle handle); + +/** + * Prevent the given callback from executing. If handle is invalid, + * does nothing. + */ +TORCH_API void disableCallback(CallbackHandle handle); + +/** + * Allow the given callback, previously disabled with disableCallback, to + * execute again. If handle is invalid, does nothing. + */ +TORCH_API void reenableCallback(CallbackHandle handle); + +/** + * hasGlobalCallbacks returns whether there're global callbacks + * registered with pushGlobalCallback + */ +TORCH_API bool hasGlobalCallbacks(); + +/** + * clearGlobalCallbacks removes all global callbacks + */ +TORCH_API void clearGlobalCallbacks(); + +// for both thread local and global callbacks +TORCH_API bool hasCallbacks(); +TORCH_API void clearCallbacks(); + +/** + * enableRecordFunction enables RecordFunction thread locally + */ +TORCH_API void enableRecordFunction(bool enable = true); + +/** + * isRecordFunctionEnabled returns whether RecordFunction + * is enabled thread locally + */ +TORCH_API bool isRecordFunctionEnabled(); + +class TORCH_API RecordFunctionGuard { + public: + explicit RecordFunctionGuard(bool is_enabled = true) + : prev_value_(isRecordFunctionEnabled()) { + enableRecordFunction(is_enabled); + } + + virtual ~RecordFunctionGuard() { + enableRecordFunction(prev_value_); + } + + private: + bool prev_value_ = false; +}; + +class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard { + public: + DisableRecordFunctionGuard() : RecordFunctionGuard(false) {} + ~DisableRecordFunctionGuard() override = default; +}; + +struct TORCH_API RecordFunctionTLS { + // Thread local vector of callbacks, holds pairs (callbacks, unique_id); + // must be sorted in increasing handles order + RecordFunctionCallbacks sorted_tls_callbacks_; + + bool tls_record_function_enabled_ = true; +}; + +TORCH_API const RecordFunctionTLS& get_record_function_tls_(); + +TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); + +TORCH_API void set_record_function_seed_for_testing(uint32_t seed); + +} // namespace at diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Allocator.h b/MLPY/Lib/site-packages/torch/include/c10/core/Allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..842cae2c1234f78a9d2d2bec2b74f44109ae569b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Allocator.h @@ 
-0,0 +1,319 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// A DataPtr is a unique pointer (with an attached deleter and some +// context for the deleter) to some memory, which also records what +// device is for its data. +// +// nullptr DataPtrs can still have a nontrivial device; this allows +// us to treat zero-size allocations uniformly with non-zero allocations. +// +class C10_API DataPtr { + private: + c10::detail::UniqueVoidPtr ptr_; + Device device_; + + public: + // Choice of CPU here is arbitrary; if there's an "undefined" device + // we could use that too + DataPtr() : ptr_(), device_(DeviceType::CPU) {} + DataPtr(void* data, Device device) : ptr_(data), device_(device) {} + DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device) + : ptr_(data, ctx, ctx_deleter), device_(device) {} + void* operator->() const { + return ptr_.get(); + } + void clear() { + ptr_.clear(); + } + void* get() const { + return ptr_.get(); + } + void* mutable_get() { + return ptr_.get(); + } + void* get_context() const { + return ptr_.get_context(); + } + void* release_context() { + return ptr_.release_context(); + } + std::unique_ptr&& move_context() { + return ptr_.move_context(); + } + operator bool() const { + return static_cast(ptr_); + } + template + T* cast_context(DeleterFnPtr expected_deleter) const { + return ptr_.cast_context(expected_deleter); + } + DeleterFnPtr get_deleter() const { + return ptr_.get_deleter(); + } + /** + * Compare the deleter in a DataPtr to expected_deleter. + * If it matches, replace the deleter with new_deleter + * and return true; otherwise, does nothing and returns + * false. + * + * In general, it is not safe to unconditionally set the + * deleter on a DataPtr, because you don't know what + * the deleter is, and thus will have a hard time properly + * disposing of the deleter without storing the original + * deleter (this is difficult to do, because DeleterFnPtr + * is not a closure, and because the context on DataPtr is + * only a single word, you generally don't have enough + * space to store both the original deleter and its context). + * However, in some cases, you know /exactly/ what the deleter + * is, and you have a new deleter that manually wraps + * the old one. In this case, you can safely swap the deleter + * after asserting that the deleters line up. + * + * What are the requirements on new_deleter? It must still + * properly dispose of the void* pointer passed in as its argument, + * where void* is whatever the context of the original deleter + * is. So in general, you expect the new deleter to look something + * like this: + * + * [](void* ptr) { + * some_new_stuff(ptr); + * get_orig_allocator()->raw_deleter(ptr); + * } + * + * Note that it won't work to close over the original + * allocator; you don't have enough space to do that! Also, + * it's unsafe to assume that the passed in pointer in + * question is the memory pointer in question; it might not + * be; be sure to read the source code of the Allocator + * in question to confirm this. + */ + C10_NODISCARD bool compare_exchange_deleter( + DeleterFnPtr expected_deleter, + DeleterFnPtr new_deleter) { + return ptr_.compare_exchange_deleter(expected_deleter, new_deleter); + } + Device device() const { + return device_; + } + // Unsafely mutates the device on a DataPtr. Under normal use, + // you should never actually need to call this function. 
+ // We need this for the implementation of the hack detailed + // in Note [Masquerading as CUDA] + void unsafe_set_device(Device device) { + device_ = device; + } +}; + +// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a +// CPU nullptr + +inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept { + return !dp; +} +inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept { + return !dp; +} +inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept { + return dp; +} +inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { + return dp; +} + +// Note [raw_allocate/raw_deallocate and Thrust] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Thrust's support for custom allocators requires us to write something +// like this: +// +// class ThrustAllocator { +// char* allocate(size_t); +// void deallocate(char*, size_t); +// }; +// +// This is not good for our unique_ptr based allocator interface, as +// there is no way to get to the context when we free. +// +// However, in some cases the context is exactly the same as +// the data pointer. In this case, we can support the "raw" +// allocate and deallocate interface. This is what +// raw_deleter signifies. By default, it returns a nullptr, which means that +// the raw interface is not implemented. Be sure to implement it whenever +// possible, or the raw interface will incorrectly reported as unsupported, +// when it is actually possible. + +struct C10_API Allocator { + virtual ~Allocator() = default; + + virtual DataPtr allocate(size_t n) = 0; + + // Clones an allocation that came from this allocator. + // + // To perform the copy, this function calls `copy_data`, which + // must be implemented by derived classes. + // + // Note that this explicitly ignores any context that may have been + // attached to the input data. + // + // Requires: input data was allocated by the same allocator. + DataPtr clone(const void* data, std::size_t n); + + // Checks if DataPtr has a simple context, not wrapped with any out of the + // ordinary contexts. + virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const; + + // If this returns a non nullptr, it means that allocate() + // is guaranteed to return a unique_ptr with this deleter attached; + // it means the rawAllocate and rawDeallocate APIs are safe to use. + // This function MUST always return the same BoundDeleter. + virtual DeleterFnPtr raw_deleter() const { + return nullptr; + } + void* raw_allocate(size_t n) { + auto dptr = allocate(n); + AT_ASSERT(dptr.get() == dptr.get_context()); + return dptr.release_context(); + } + void raw_deallocate(void* ptr) { + auto d = raw_deleter(); + AT_ASSERT(d); + d(ptr); + } + + // Copies data from one allocation to another. + // Pure virtual, so derived classes must define behavior. + // Derived class implementation can simply call `default_copy_data` + // to use `std::memcpy`. + // + // Requires: src and dest were allocated by this allocator + // Requires: src and dest both have length >= count + virtual void copy_data(void* dest, const void* src, std::size_t count) + const = 0; + + protected: + // Uses `std::memcpy` to copy data. + // Child classes can use this as `copy_data` when an alternative copy + // API is not needed. + void default_copy_data(void* dest, const void* src, std::size_t count) const; +}; + +// This context is used to generate DataPtr which have arbitrary +// std::function deleters associated with them. 
In some user facing +// functions, we give a (user-friendly) interface for constructing +// tensors from external data which take an arbitrary std::function +// deleter. Grep for InefficientStdFunctionContext to find these +// occurrences. +// +// This context is inefficient because we have to do a dynamic +// allocation InefficientStdFunctionContext, on top of the dynamic +// allocation which is implied by std::function itself. +struct C10_API InefficientStdFunctionContext { + void* ptr_; + std::function deleter_; + InefficientStdFunctionContext(void* ptr, std::function deleter) + : ptr_(ptr), deleter_(std::move(deleter)) {} + ~InefficientStdFunctionContext() { + if (deleter_) { + deleter_(ptr_); + } + } + static DataPtr makeDataPtr( + void* ptr, + std::function deleter, + Device device); +}; + +/** Set the allocator for DeviceType `t`. The passed in allocator pointer is + * expected to have static lifetime; this function does NOT take ownership + * of the raw pointer. (The reason for this is to prevent existing pointers + * to an allocator of a particular device from being invalidated when + * SetAllocator is called.) + * + * Also note that this is not thread-safe, and we assume this function will + * only be called during initialization. + * + * The 'priority' flag is introduced when we want to overwrite the default + * allocator, since the allocators are set statically. The default priority + * is 0, which means the lowest. Only higher or equal priority can overwrite + * existing ones. + */ +C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0); +C10_API Allocator* GetAllocator(const DeviceType& t); + +template +struct AllocatorRegisterer { + explicit AllocatorRegisterer(Allocator* alloc) { + SetAllocator(t, alloc); + } +}; + +#define REGISTER_ALLOCATOR(t, f) \ + namespace { \ + static c10::AllocatorRegisterer g_allocator_d(f); \ + } + +// An interface for reporting thread local memory usage +// per device +struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { + MemoryReportingInfoBase(); + ~MemoryReportingInfoBase() override = default; + + /** + * alloc_size corresponds to the size of the ptr. + * + * total_allocated corresponds to total allocated memory. + * + * total_reserved corresponds to total size of memory pool, both used and + * unused, if applicable. 
+ */ + virtual void reportMemoryUsage( + void* ptr, + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device) = 0; + + virtual void reportOutOfMemory( + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + + virtual bool memoryProfilingEnabled() const = 0; +}; + +C10_API bool memoryProfilingEnabled(); +C10_API void reportMemoryUsageToProfiler( + void* ptr, + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + +C10_API void reportOutOfMemoryToProfiler( + int64_t alloc_size, + size_t total_allocated, + size_t total_reserved, + Device device); + +// used to hold traceback information in allocators +struct GatheredContext { + virtual ~GatheredContext() = default; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/AutogradState.h b/MLPY/Lib/site-packages/torch/include/c10/core/AutogradState.h new file mode 100644 index 0000000000000000000000000000000000000000..328ca686a11c441b2d0777f6045dcbcf59708715 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/AutogradState.h @@ -0,0 +1,72 @@ +#pragma once + +#include + +namespace c10 { + +// Structure used to pack all the thread local boolean +// flags used by autograd +struct C10_API AutogradState { + static AutogradState& get_tls_state(); + static void set_tls_state(AutogradState state); + + AutogradState( + bool grad_mode, + bool inference_mode, + bool fw_grad_mode, + bool multithreading_enabled) + : grad_mode_(grad_mode), + inference_mode_(inference_mode), + fw_grad_mode_(fw_grad_mode), + multithreading_enabled_(multithreading_enabled), + view_replay_enabled_(false) {} + + void set_grad_mode(bool enabled) { + grad_mode_ = enabled; + } + + void set_fw_grad_mode(bool enabled) { + fw_grad_mode_ = enabled; + } + + void set_inference_mode(bool enabled) { + inference_mode_ = enabled; + } + + void set_multithreading_enabled(bool multithreading_enabled) { + multithreading_enabled_ = multithreading_enabled; + } + + void set_view_replay_enabled(bool view_replay_enabled) { + view_replay_enabled_ = view_replay_enabled; + } + + bool get_grad_mode() const { + return grad_mode_; + } + + bool get_fw_grad_mode() const { + return fw_grad_mode_; + } + + bool get_inference_mode() const { + return inference_mode_; + } + + bool get_multithreading_enabled() const { + return multithreading_enabled_; + } + + bool get_view_replay_enabled() const { + return view_replay_enabled_; + } + + private: + bool grad_mode_ : 1; + bool inference_mode_ : 1; + bool fw_grad_mode_ : 1; + bool multithreading_enabled_ : 1; + bool view_replay_enabled_ : 1; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Backend.h b/MLPY/Lib/site-packages/torch/include/c10/core/Backend.h new file mode 100644 index 0000000000000000000000000000000000000000..352427d9ed99e260a450826edf2a4ceaa77cbb7c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Backend.h @@ -0,0 +1,388 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace c10 { + +/** + * This legacy enum class defines the set of backends supported by old school, + * code generated Type-based ATen. A "backend" in this sense roughly + * corresponds to the cartesian product of (device type, layout), but restricted + * only to combinations which we actually have kernels for. Backend does NOT + * include dtype. 
+ * + * The reason we are sunsetting this enum class is because it doesn't allow for + * open registration; e.g., if you want to add SparseXLA, you'd have to + * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is + * the replacement for Backend which supports open registration. + * + * NB: The concept of 'Backend' here disagrees with the notion of backend + * exposed to users in torch.backends. Backend here is something like "CPU" + * or "SparseCUDA"; backend in torch.backends is something like "MKL" or + * "CUDNN". + */ +enum class Backend { + CPU, + CUDA, + HIP, + VE, + FPGA, + IPU, + XPU, + SparseCPU, + SparseCUDA, + SparseCsrCPU, + SparseCsrCUDA, + SparseHIP, + SparseVE, + SparseXPU, + SparsePrivateUse1, + SparseCsrHIP, + SparseCsrVE, + SparseCsrXPU, + SparseCsrPrivateUse1, + ORT, + XLA, + Vulkan, + Metal, + Meta, + QuantizedCPU, + QuantizedCUDA, + QuantizedXPU, + QuantizedPrivateUse1, + Undefined, + MkldnnCPU, + MPS, + HPU, + Lazy, + MTIA, + PrivateUse1, + NumOptions +}; + +static inline Backend dispatchKeyToBackend(DispatchKey t) { + if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) { + return Backend::CPU; + } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) { + return Backend::CUDA; + } else if (t == DispatchKey::HIP) { + return Backend::HIP; + } else if (t == DispatchKey::VE) { + return Backend::VE; + } else if (t == DispatchKey::FPGA) { + return Backend::FPGA; + } else if (t == DispatchKey::ORT) { + return Backend::ORT; + } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) { + return Backend::XLA; + } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) { + return Backend::Lazy; + } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) { + return Backend::MPS; + } else if (t == DispatchKey::Vulkan) { + return Backend::Vulkan; + } else if (t == DispatchKey::Metal) { + return Backend::Metal; + } else if (t == DispatchKey::Meta) { + return Backend::Meta; + } else if (t == DispatchKey::SparseCPU) { + return Backend::SparseCPU; + } else if (t == DispatchKey::SparseCUDA) { + return Backend::SparseCUDA; + } else if (t == DispatchKey::SparseHIP) { + return Backend::SparseHIP; + } else if (t == DispatchKey::SparseVE) { + return Backend::SparseVE; + } else if (t == DispatchKey::SparsePrivateUse1) { + return Backend::SparsePrivateUse1; + } else if (t == DispatchKey::SparseCsrCPU) { + return Backend::SparseCsrCPU; + } else if (t == DispatchKey::SparseCsrCUDA) { + return Backend::SparseCsrCUDA; + } else if (t == DispatchKey::SparseCsrHIP) { + return Backend::SparseCsrHIP; + } else if (t == DispatchKey::SparseCsrVE) { + return Backend::SparseCsrVE; + } else if (t == DispatchKey::SparseCsrPrivateUse1) { + return Backend::SparseCsrPrivateUse1; + } else if (t == DispatchKey::MkldnnCPU) { + return Backend::MkldnnCPU; + } else if (t == DispatchKey::QuantizedCPU) { + return Backend::QuantizedCPU; + } else if (t == DispatchKey::QuantizedCUDA) { + return Backend::QuantizedCUDA; + } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) { + return Backend::IPU; + } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) { + return Backend::XPU; + } else if (t == DispatchKey::SparseXPU) { + return Backend::SparseXPU; + } else if (t == DispatchKey::SparseCsrXPU) { + return Backend::SparseCsrXPU; + } else if (t == DispatchKey::QuantizedXPU) { + return Backend::QuantizedXPU; + } else if (t == DispatchKey::QuantizedPrivateUse1) { + return Backend::QuantizedPrivateUse1; + } else if (t == 
DispatchKey::HPU || t == DispatchKey::AutogradHPU) { + return Backend::HPU; + } else if (t == DispatchKey::MTIA || t == DispatchKey::AutogradMTIA) { + return Backend::MTIA; + } else if ( + t == DispatchKey::PrivateUse1 || t == DispatchKey::AutogradPrivateUse1) { + return Backend::PrivateUse1; + } else if (t == DispatchKey::Undefined) { + return Backend::Undefined; + } else { + TORCH_CHECK(false, "Unrecognized tensor type ID: ", t); + } +} + +static inline DispatchKey backendToDispatchKey(Backend b) { + switch (b) { + case Backend::CPU: + return DispatchKey::CPU; + case Backend::CUDA: + return DispatchKey::CUDA; + case Backend::HIP: + return DispatchKey::HIP; + case Backend::VE: + return DispatchKey::VE; + case Backend::FPGA: + return DispatchKey::FPGA; + case Backend::ORT: + return DispatchKey::ORT; + case Backend::XLA: + return DispatchKey::XLA; + case Backend::Lazy: + return DispatchKey::Lazy; + case Backend::IPU: + return DispatchKey::IPU; + case Backend::XPU: + return DispatchKey::XPU; + case Backend::SparseXPU: + return DispatchKey::SparseXPU; + case Backend::SparseCsrXPU: + return DispatchKey::SparseCsrXPU; + case Backend::SparseCPU: + return DispatchKey::SparseCPU; + case Backend::SparseCUDA: + return DispatchKey::SparseCUDA; + case Backend::SparseHIP: + return DispatchKey::SparseHIP; + case Backend::SparseVE: + return DispatchKey::SparseVE; + case Backend::SparsePrivateUse1: + return DispatchKey::SparsePrivateUse1; + case Backend::SparseCsrCPU: + return DispatchKey::SparseCsrCPU; + case Backend::SparseCsrCUDA: + return DispatchKey::SparseCsrCUDA; + case Backend::SparseCsrHIP: + return DispatchKey::SparseCsrHIP; + case Backend::SparseCsrVE: + return DispatchKey::SparseCsrVE; + case Backend::SparseCsrPrivateUse1: + return DispatchKey::SparseCsrPrivateUse1; + case Backend::MkldnnCPU: + return DispatchKey::MkldnnCPU; + case Backend::Vulkan: + return DispatchKey::Vulkan; + case Backend::Metal: + return DispatchKey::Metal; + case Backend::Meta: + return DispatchKey::Meta; + case Backend::QuantizedCPU: + return DispatchKey::QuantizedCPU; + case Backend::QuantizedCUDA: + return DispatchKey::QuantizedCUDA; + case Backend::QuantizedPrivateUse1: + return DispatchKey::QuantizedPrivateUse1; + case Backend::Undefined: + return DispatchKey::Undefined; + case Backend::MPS: + return DispatchKey::MPS; + case Backend::HPU: + return DispatchKey::HPU; + case Backend::MTIA: + return DispatchKey::MTIA; + case Backend::PrivateUse1: + return DispatchKey::PrivateUse1; + default: + throw std::runtime_error("Unknown backend"); + } +} + +static inline DeviceType backendToDeviceType(Backend b) { + switch (b) { + case Backend::CPU: + case Backend::MkldnnCPU: + case Backend::SparseCPU: + case Backend::SparseCsrCPU: + case Backend::QuantizedCPU: + return DeviceType::CPU; + case Backend::CUDA: + case Backend::SparseCUDA: + case Backend::QuantizedCUDA: + case Backend::SparseCsrCUDA: + return DeviceType::CUDA; + case Backend::HIP: + return DeviceType::HIP; + case Backend::VE: + return DeviceType::VE; + case Backend::FPGA: + return DeviceType::FPGA; + case Backend::ORT: + return DeviceType::ORT; + case Backend::XLA: + return DeviceType::XLA; + case Backend::Lazy: + return DeviceType::Lazy; + case Backend::SparseHIP: + return DeviceType::HIP; + case Backend::SparseVE: + return DeviceType::VE; + case Backend::SparseCsrHIP: + return DeviceType::HIP; + case Backend::SparseCsrVE: + return DeviceType::VE; + case Backend::IPU: + return DeviceType::IPU; + case Backend::XPU: + case Backend::SparseXPU: + case 
Backend::SparseCsrXPU: + case Backend::QuantizedXPU: + return DeviceType::XPU; + case Backend::Vulkan: + return DeviceType::Vulkan; + case Backend::Metal: + return DeviceType::Metal; + case Backend::Meta: + return DeviceType::Meta; + case Backend::MPS: + return DeviceType::MPS; + case Backend::HPU: + return DeviceType::HPU; + case Backend::MTIA: + return DeviceType::MTIA; + case Backend::PrivateUse1: + case Backend::SparsePrivateUse1: + case Backend::SparseCsrPrivateUse1: + case Backend::QuantizedPrivateUse1: + return DeviceType::PrivateUse1; + case Backend::Undefined: + TORCH_CHECK(false, "Undefined backend is not a valid device type"); + default: + TORCH_CHECK(false, "Unknown backend"); + } +} + +// TODO: This probably shouldn't actually be static inline +static inline const char* toString(Backend b) { + switch (b) { + case Backend::CPU: + return "CPU"; + case Backend::CUDA: + return "CUDA"; + case Backend::HIP: + return "HIP"; + case Backend::VE: + return "VE"; + case Backend::FPGA: + return "FPGA"; + case Backend::XPU: + return "XPU"; + case Backend::IPU: + return "IPU"; + case Backend::ORT: + return "ORT"; + case Backend::XLA: + return "XLA"; + case Backend::Lazy: + return "Lazy"; + case Backend::MPS: + return "MPS"; + case Backend::SparseCPU: + return "SparseCPU"; + case Backend::SparseCUDA: + return "SparseCUDA"; + case Backend::SparseHIP: + return "SparseHIP"; + case Backend::SparseVE: + return "SparseVE"; + case Backend::SparseXPU: + return "SparseXPU"; + case Backend::SparsePrivateUse1: + return "SparsePrivateUse1"; + case Backend::SparseCsrCPU: + return "SparseCsrCPU"; + case Backend::SparseCsrCUDA: + return "SparseCsrCUDA"; + case Backend::SparseCsrHIP: + return "SparseCsrHIP"; + case Backend::SparseCsrVE: + return "SparseCsrVE"; + case Backend::SparseCsrXPU: + return "SparseCsrXPU"; + case Backend::SparseCsrPrivateUse1: + return "SparseCsrPrivateUse1"; + case Backend::MkldnnCPU: + return "MkldnnCPU"; + case Backend::Vulkan: + return "Vulkan"; + case Backend::Metal: + return "Metal"; + case Backend::Meta: + return "Meta"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; + case Backend::QuantizedCUDA: + return "QuantizedCUDA"; + case Backend::QuantizedXPU: + return "QuantizedXPU"; + case Backend::QuantizedPrivateUse1: + return "QuantizedPrivateUse1"; + case Backend::HPU: + return "HPU"; + case Backend::MTIA: + return "MTIA"; + case Backend::PrivateUse1: + return "PrivateUseOne"; + default: + return "UNKNOWN_BACKEND"; + } +} + +static inline bool isSparse(Backend b) { + switch (b) { + case Backend::SparseXPU: + case Backend::SparseCPU: + case Backend::SparseCUDA: + case Backend::SparseHIP: + case Backend::SparseVE: + case Backend::SparsePrivateUse1: + return true; + default: + return false; + } +} + +static inline bool isSparseCsr(Backend b) { + switch (b) { + case Backend::SparseCsrXPU: + case Backend::SparseCsrCPU: + case Backend::SparseCsrCUDA: + case Backend::SparseCsrHIP: + case Backend::SparseCsrVE: + case Backend::SparseCsrPrivateUse1: + return true; + default: + return false; + } +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/CPUAllocator.h b/MLPY/Lib/site-packages/torch/include/c10/core/CPUAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..d84ac28ec9e99d989692e37a1c465689dc7edf9a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/CPUAllocator.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +// TODO: rename to c10 
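+
+// Illustrative sketch (not part of this header): fetching the process-wide CPU
+// allocator declared below and letting the returned DataPtr free the memory
+// via RAII. The 64-byte size is arbitrary and used only for this sketch
+// (std::memset needs <cstring>).
+//
+//   at::Allocator* alloc = c10::GetCPUAllocator();
+//   c10::DataPtr buf = alloc->allocate(64);  // freed when buf goes out of scope
+//   std::memset(buf.get(), 0, 64);
+//   TORCH_INTERNAL_ASSERT(buf.device().type() == c10::DeviceType::CPU);
+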
+C10_DECLARE_bool(caffe2_report_cpu_memory_usage); + +namespace c10 { + +using MemoryDeleter = void (*)(void*); + +// A helper function that is basically doing nothing. +C10_API void NoDelete(void*); + +// A simple struct that is used to report C10's memory allocation, +// deallocation status and out-of-memory events to the profiler +class C10_API ProfiledCPUMemoryReporter { + public: + ProfiledCPUMemoryReporter() = default; + void New(void* ptr, size_t nbytes); + void OutOfMemory(size_t nbytes); + void Delete(void* ptr); + + private: + std::mutex mutex_; + std::unordered_map size_table_; + size_t allocated_ = 0; + size_t log_cnt_ = 0; +}; + +C10_API ProfiledCPUMemoryReporter& profiledCPUMemoryReporter(); + +// Get the CPU Allocator. +C10_API at::Allocator* GetCPUAllocator(); +// Sets the CPU allocator to the given allocator: the caller gives away the +// ownership of the pointer. +C10_API void SetCPUAllocator(at::Allocator* alloc, uint8_t priority = 0); + +// Get the Default CPU Allocator +C10_API at::Allocator* GetDefaultCPUAllocator(); + +// Get the Default Mobile CPU Allocator +C10_API at::Allocator* GetDefaultMobileCPUAllocator(); + +// The CPUCachingAllocator is experimental and might disappear in the future. +// The only place that uses it is in StaticRuntime. +// Set the CPU Caching Allocator +C10_API void SetCPUCachingAllocator(Allocator* alloc, uint8_t priority = 0); +// Get the CPU Caching Allocator +C10_API Allocator* GetCPUCachingAllocator(); + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h b/MLPY/Lib/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h new file mode 100644 index 0000000000000000000000000000000000000000..698f191056693249c0f0a95f1e671f7f6cf67d12 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/CompileTimeFunctionPointer.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +namespace c10 { + +/** + * Represent a function pointer as a C++ type. + * This allows using the function pointer as a type + * in a template and calling it from inside the template + * allows the compiler to inline the call because it + * knows the function pointer at compile time. 
+ * + * Example 1: + * int add(int a, int b) {return a + b;} + * using Add = TORCH_FN_TYPE(add); + * template struct Executor { + * int execute(int a, int b) { + * return Func::func_ptr()(a, b); + * } + * }; + * Executor executor; + * EXPECT_EQ(3, executor.execute(1, 2)); + * + * Example 2: + * int add(int a, int b) {return a + b;} + * template int execute(Func, int a, int b) { + * return Func::func_ptr()(a, b); + * } + * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); + */ +template +struct CompileTimeFunctionPointer final { + static_assert( + guts::is_function_type::value, + "TORCH_FN can only wrap function types."); + using FuncType = FuncType_; + + static constexpr FuncType* func_ptr() { + return func_ptr_; + } +}; + +template +struct is_compile_time_function_pointer : std::false_type {}; +template +struct is_compile_time_function_pointer< + CompileTimeFunctionPointer> : std::true_type {}; + +} // namespace c10 + +#define TORCH_FN_TYPE(func) \ + ::c10::CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define TORCH_FN(func) TORCH_FN_TYPE(func)() diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..863a701ecf7c86c9027f5e46b1a09746d4530bf8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/ConstantSymNodeImpl.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Unlike other SymNodeImpl, this cannot be "dispatched" conventionally, +// as it typically needs to defer to another SymNodeImpl +// +// Can either represent a bool, int (don't support float yet) this is useful +// for representing otherwise unrepresentable large negative integer constant. +template +class C10_API ConstantSymNodeImpl : public SymNodeImpl { + static_assert( + ::std::is_same_v || ::std::is_same_v, + "ConstantSymNodeImpl can only accept int64_t or bool types"); + + public: + ConstantSymNodeImpl(T val) : value_(val) {} + + bool is_int() override { + return is_int_(); + } + bool is_bool() override { + return is_bool_(); + } + bool is_float() override { + return false; + } + int64_t guard_int(const char* file, int64_t line) override { + TORCH_CHECK(is_int(), "not an int"); + return int_(); + } + bool guard_bool(const char* file, int64_t line) override { + TORCH_CHECK(is_bool(), "not a bool"); + return bool_(); + } + double guard_float(const char* file, int64_t line) override { + TORCH_CHECK(false, "not a float"); + } + int64_t int_() override { + TORCH_CHECK(is_int(), "not an int"); + return ::std::get(value_); + } + bool bool_() override { + TORCH_CHECK(is_bool(), "not a bool"); + return ::std::get(value_); + } + bool has_hint() override { + return true; + } + c10::SymNode eq(const c10::SymNode& other) override; + c10::SymNode ne(const c10::SymNode& other) override; + c10::SymNode ge(const c10::SymNode& other) override; + c10::SymNode le(const c10::SymNode& other) override; + c10::SymNode lt(const c10::SymNode& other) override; + c10::SymNode gt(const c10::SymNode& other) override; + c10::SymNode mul(const c10::SymNode& other) override; + ::std::string str() override { + if constexpr (is_int_()) { + return ::std::to_string(::std::get(value_)); + } else { + return ::std::get(value_) ? 
"true" : "false"; + } + } + c10::optional constant_int() override { + if constexpr (is_int_()) { + return ::std::get(value_); + } else { + return c10::nullopt; + } + } + c10::optional constant_bool() override { + if constexpr (is_bool_()) { + return ::std::get(value_); + } else { + return c10::nullopt; + } + } + bool is_constant() override { + return true; + } + bool is_symbolic() override { + return false; + } + + private: + ::std::variant value_; + + static constexpr bool is_int_() { + return ::std::is_same_v; + } + static constexpr bool is_bool_() { + return ::std::is_same_v; + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Contiguity.h b/MLPY/Lib/site-packages/torch/include/c10/core/Contiguity.h new file mode 100644 index 0000000000000000000000000000000000000000..9c28aa0d83f09ed477419e92db0ad758f46cbf42 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Contiguity.h @@ -0,0 +1,129 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +template +bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + bool is_contiguous = true; + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { + return is_contiguous; + } + T z = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) { + z *= size_d; + } else { + is_contiguous = false; + break; + } + } + } + return is_contiguous; +} + +template +bool _compute_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 4: { + T expected = 1; + for (auto& d : {1, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + // Please don't combine these code, constant array is used here to let + // compiler fully unroll the loop to get better performance + switch (sizes.size()) { + case 5: { + T expected = 1; + for (auto& d : {1, 4, 3, 2, 0}) { + const auto& size_d = sizes[d]; + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) { + return false; + } + expected *= size_d; + } + } + return true; + } + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +bool _compute_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + auto dim = sizes.size(); + if (dim == 1) { + return sizes[0] < 2 || strides[0] == 1; + } + SmallVector perm; + perm.resize(dim); + for (const auto i : c10::irange(dim)) { + perm[i] = i; + } + // Sort by strides, leaving 0 and 1 sized dims at the end of the array + std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { + if (sizes[a] < 2) { + return false; + } else if (sizes[b] < 2) { + return true; + } + 
return strides[a] < strides[b]; + }); + T require_stride = 1; + for (const auto i : c10::irange(dim)) { + const auto& size_perm_i = sizes[perm[i]]; + if (size_perm_i < 2) { + return true; + } + if (strides[perm[i]] != require_stride) { + return false; + } + require_stride *= size_perm_i; + } + return true; +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/CopyBytes.h b/MLPY/Lib/site-packages/torch/include/c10/core/CopyBytes.h new file mode 100644 index 0000000000000000000000000000000000000000..f5b08d74aa6ba3525d2ab50f73dd217b893b4f9d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/CopyBytes.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +using CopyBytesFunction = void (*)( + size_t nbytes, + const void* src, + Device src_device, + void* dst, + Device dst_device); + +struct C10_API _CopyBytesFunctionRegisterer { + _CopyBytesFunctionRegisterer( + DeviceType from, + DeviceType to, + CopyBytesFunction func_sync, + CopyBytesFunction func_async = nullptr); +}; + +#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...) \ + namespace { \ + static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \ + g_copy_function)(from, to, __VA_ARGS__); \ + } + +/* + * WARNING: Implementations for this function are currently registered from + * ATen and caffe2, not yet from c10. Don't use this if not either ATen + * or caffe2 is present as well. + * We can't move them yet, because the CUDA implementations aren't unified yet + * between ATen and caffe2. + * We're planning to move the implementations into c10/backend/xxx + * to make c10 self contained again. + */ +C10_API void CopyBytes( + size_t nbytes, + const void* src, + Device src_device, + void* dst, + Device dst_device, + bool async); +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DefaultDtype.h b/MLPY/Lib/site-packages/torch/include/c10/core/DefaultDtype.h new file mode 100644 index 0000000000000000000000000000000000000000..5c4cd53fa78dc59fc7da6b0fbe4fd2ad9f91cf77 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DefaultDtype.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +namespace caffe2 { +class TypeMeta; +} // namespace caffe2 + +namespace c10 { +C10_API void set_default_dtype(caffe2::TypeMeta dtype); +C10_API const caffe2::TypeMeta get_default_dtype(); +C10_API ScalarType get_default_dtype_as_scalartype(); +C10_API const caffe2::TypeMeta get_default_complex_dtype(); +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DefaultTensorOptions.h b/MLPY/Lib/site-packages/torch/include/c10/core/DefaultTensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..c00197ead055805164a6f99ff8600b46784f3963 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DefaultTensorOptions.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +struct TensorOptions; + +/// Like TensorOptions, but all fields are guaranteed to be filled. 
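+///
+/// Illustrative reading of that guarantee (the local names below exist only in
+/// this sketch; the accessors are the ones declared on the struct that
+/// follows):
+///
+///   const DefaultTensorOptions& def = getDefaultTensorOptions();
+///   Device dev = def.device();           // at::kCPU for the default object
+///   Layout layout = def.layout();        // at::kStrided for the default object
+///   bool grad = def.requires_grad();     // false for the default object
+///   // def.dtype() likewise always returns a concrete caffe2::TypeMeta.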
+struct DefaultTensorOptions { + DefaultTensorOptions() = default; + + caffe2::TypeMeta dtype() const noexcept { + return dtype_; + } + Device device() const noexcept { + return device_; + } + Layout layout() const noexcept { + return layout_; + } + bool requires_grad() const noexcept { + return requires_grad_; + } + + // Defined in TensorOptions.h + inline DefaultTensorOptions& merge(const TensorOptions& options); + + private: + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit + Device device_ = at::kCPU; // 32-bit + Layout layout_ = at::kStrided; // 8-bit + bool requires_grad_ = false; // 8-bit +}; + +inline const DefaultTensorOptions& getDefaultTensorOptions() { + static const auto options = DefaultTensorOptions(); + return options; +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Device.h b/MLPY/Lib/site-packages/torch/include/c10/core/Device.h new file mode 100644 index 0000000000000000000000000000000000000000..064b1f8c2d67d1a9d5a306fdf7db7d11d6efe83f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Device.h @@ -0,0 +1,216 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10 { + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. A device is +/// uniquely identified by a type, which specifies the type of machine it is +/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the +/// specific compute device when there is more than one of a certain type. The +/// device index is optional, and in its defaulted state represents (abstractly) +/// "the current device". Further, there are two constraints on the value of the +/// device index, if one is explicitly stored: +/// 1. A negative index represents the current device, a non-negative index +/// represents a specific, concrete device, +/// 2. When the device type is CPU, the device index must be zero. +struct C10_API Device final { + using Type = DeviceType; + + /// Constructs a new `Device` from a `DeviceType` and an optional device + /// index. + /* implicit */ Device(DeviceType type, DeviceIndex index = -1) + : type_(type), index_(index) { + validate(); + } + + /// Constructs a `Device` from a string description, for convenience. + /// The string supplied must follow the following schema: + /// `(cpu|cuda)[:]` + /// where `cpu` or `cuda` specifies the device type, and + /// `:` optionally specifies a device index. + /* implicit */ Device(const std::string& device_string); + + /// Returns true if the type and index of this `Device` matches that of + /// `other`. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this `Device` differs from that of + /// `other`. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the optional index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. 
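+  ///
+  /// For example (illustrative):
+  ///
+  ///   Device a(DeviceType::CUDA);      // index defaults to -1
+  ///   Device b(DeviceType::CUDA, 0);
+  ///   Device c("cuda:1");              // parsed from a string
+  ///   a.has_index();                   // false
+  ///   b.has_index();                   // true
+  ///   c.has_index();                   // true, and c.index() == 1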
+ bool has_index() const noexcept { + return index_ != -1; + } + + /// Return true if the device is of CUDA type. + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + /// Return true if the device is of PrivateUse1 type. + bool is_privateuseone() const noexcept { + return type_ == DeviceType::PrivateUse1; + } + + /// Return true if the device is of MPS type. + bool is_mps() const noexcept { + return type_ == DeviceType::MPS; + } + + /// Return true if the device is of HIP type. + bool is_hip() const noexcept { + return type_ == DeviceType::HIP; + } + + /// Return true if the device is of VE type. + bool is_ve() const noexcept { + return type_ == DeviceType::VE; + } + + /// Return true if the device is of XPU type. + bool is_xpu() const noexcept { + return type_ == DeviceType::XPU; + } + + /// Return true if the device is of IPU type. + bool is_ipu() const noexcept { + return type_ == DeviceType::IPU; + } + + /// Return true if the device is of XLA type. + bool is_xla() const noexcept { + return type_ == DeviceType::XLA; + } + + /// Return true if the device is of MTIA type. + bool is_mtia() const noexcept { + return type_ == DeviceType::MTIA; + } + + /// Return true if the device is of HPU type. + bool is_hpu() const noexcept { + return type_ == DeviceType::HPU; + } + + /// Return true if the device is of Lazy type. + bool is_lazy() const noexcept { + return type_ == DeviceType::Lazy; + } + + /// Return true if the device is of Vulkan type. + bool is_vulkan() const noexcept { + return type_ == DeviceType::Vulkan; + } + + /// Return true if the device is of Metal type. + bool is_metal() const noexcept { + return type_ == DeviceType::Metal; + } + + /// Return true if the device is of ORT type. + bool is_ort() const noexcept { + return type_ == DeviceType::ORT; + } + + /// Return true if the device is of META type. + bool is_meta() const noexcept { + return type_ == DeviceType::Meta; + } + + /// Return true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Return true if the device supports arbitrary strides. + bool supports_as_strided() const noexcept { + return type_ != DeviceType::IPU && type_ != DeviceType::XLA && + type_ != DeviceType::Lazy && type_ != DeviceType::MTIA; + } + + /// Same string as returned from operator<<. + std::string str() const; + + private: + DeviceType type_; + DeviceIndex index_ = -1; + void validate() { + // Removing these checks in release builds noticeably improves + // performance in micro-benchmarks. + // This is safe to do, because backends that use the DeviceIndex + // have a later check when we actually try to switch to that device. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index_ >= -1, + "Device index must be -1 or non-negative, got ", + static_cast(index_)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", + static_cast(index_)); + } +}; + +C10_API std::ostream& operator<<(std::ostream& stream, const Device& device); + +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(c10::Device d) const noexcept { + // Are you here because this static assert failed? Make sure you ensure + // that the bitmasking code below is updated accordingly! 
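+    // As a concrete, illustrative instance of the packing below: cuda:1 has
+    // type() == DeviceType::CUDA == 1 and index() == 1, so
+    // bits == (1u << 16) | 1u == 0x00010001 and the result is
+    // std::hash<uint32_t>{}(0x00010001).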
+ static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); + static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); + // Note [Hazard when concatenating signed integers] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We must first convert to a same-sized unsigned type, before promoting to + // the result type, to prevent sign extension when any of the values is -1. + // If sign extension occurs, you'll clobber all of the values in the MSB + // half of the resulting integer. + // + // Technically, by C/C++ integer promotion rules, we only need one of the + // uint32_t casts to the result type, but we put in both for explicitness's + // sake. + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DeviceArray.h b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceArray.h new file mode 100644 index 0000000000000000000000000000000000000000..1a98e7c47bce1153dbc9cde165973dfa580a9977 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceArray.h @@ -0,0 +1,28 @@ +#include +#include +#include +#include +#include + +namespace c10 { + +template +class DeviceArray { + public: + DeviceArray(c10::Allocator& allocator, size_t size) + : data_ptr_(allocator.allocate(size * sizeof(T))) { + static_assert(std::is_trivial::value, "T must be a trivial type"); + TORCH_INTERNAL_ASSERT( + 0 == (reinterpret_cast(data_ptr_.get()) % alignof(T)), + "c10::DeviceArray: Allocated memory is not aligned for this data type"); + } + + T* get() { + return static_cast(data_ptr_.get()); + } + + private: + c10::DataPtr data_ptr_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DeviceGuard.h b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..2101f1c7b001fcd62000fb729be3693cdf87b95d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceGuard.h @@ -0,0 +1,199 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +/// RAII guard that sets a certain default device in its constructor, and +/// changes it back to the device that was originally active upon destruction. +/// +/// The device is always reset to the one that was active at the time of +/// construction of the guard. Even if you `set_device` after construction, the +/// destructor will still reset the device to the one that was active at +/// construction time. +/// +/// This device guard does NOT have an uninitialized state; it is guaranteed +/// to reset a device on exit. If you are in a situation where you *might* +/// want to setup a guard (i.e., are looking for the moral equivalent +/// of optional), see OptionalDeviceGuard. +class DeviceGuard { + public: + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit DeviceGuard() = delete; + + /// Set the current device to the passed Device. + explicit DeviceGuard(Device device) : guard_(device) {} + + /// This constructor is for testing only. 
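+  ///
+  /// Ordinary code does not need it; a guard is normally built from a Device
+  /// alone, e.g. (illustrative):
+  ///
+  ///   {
+  ///     DeviceGuard g(Device(DeviceType::CUDA, 1)); // switch to cuda:1
+  ///     // ... launch work on cuda:1 ...
+  ///   } // the previously current device is restored here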
+ explicit DeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} + + /// Copy is disallowed + DeviceGuard(const DeviceGuard&) = delete; + DeviceGuard& operator=(const DeviceGuard&) = delete; + + /// Move is disallowed, as DeviceGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + DeviceGuard(DeviceGuard&& other) = delete; + DeviceGuard& operator=(DeviceGuard&& other) = delete; + + /// Sets the device to the given one. The specified device must be consistent + /// with the device type originally specified during guard construction. + /// + /// TODO: The consistency check here is inconsistent with StreamGuard's + /// behavior with set_stream, where a stream on a different device than + /// the original one isn't an error; we just reset the stream and then + /// switch devices. + void reset_device(at::Device device) { + guard_.reset_device(device); + } + + /// This method is for testing only. + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { + guard_.reset_device(device, impl); + } + + /// Sets the device index to the given one. The device type is inferred + /// from the original device type the guard was constructed with. + void set_index(DeviceIndex index) { + guard_.set_index(index); + } + + /// Returns the device that was set at the time the guard was constructed. + Device original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device. + Device current_device() const { + return guard_.current_device(); + } + + private: + impl::InlineDeviceGuard guard_; +}; + +/** + * A OptionalDeviceGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * Morally, a OptionalDeviceGuard is equivalent to optional, but + * with extra constructors and methods as appropriate. + * + * Besides its obvious use (optionally applying a DeviceGuard), + * OptionalDeviceGuard is often also used for the following idiom: + * + * OptionalDeviceGuard g; + * for (const auto& t : tensors) { + * g.set_device(t.device()); + * do_something_with(t); + * } + * + * This usage is marginally more efficient than constructing a DeviceGuard every + * iteration of the for loop, as it avoids an unnecessary device reset. + * + * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs + * when you use the nullary constructor, or pass a nullopt to the constructor. + * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the + * original device was and they do not reset on destruction. This is why + * original_device() and current_device() return optional rather than + * Device (as they do in DeviceGuard), and also is why we didn't just + * provide OptionalDeviceGuard by default and hide DeviceGuard from users. + * + * The semantics of an OptionalDeviceGuard are exactly explained by thinking + * of it as an optional. In particular, an initialized + * OptionalDeviceGuard doesn't restore device to its value at construction; it + * restores device to its value *at initialization*. So if you have the + * program: + * + * setDevice(1); + * OptionalDeviceGuard g; + * setDevice(2); + * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! + * + * On destruction, g will reset device to 2, rather than 1. 
+ * + * An uninitialized OptionalDeviceGuard is distinct from a (initialized) + * DeviceGuard whose original_device_ and current_device_ match, since the + * DeviceGuard will still reset the device to original_device_. + */ +class OptionalDeviceGuard { + public: + /// Create an uninitialized guard. Set the guard later using reset_device. + explicit OptionalDeviceGuard() = default; + + /// Initialize the guard, setting the current device to the passed Device. + explicit OptionalDeviceGuard(Device device) : guard_(device) {} + + /// Initialize the guard if a Device is passed; otherwise leave the + /// guard uninitialized. + explicit OptionalDeviceGuard(optional device) : guard_(device) {} + + /// Constructor for testing only. + explicit OptionalDeviceGuard( + Device device, + const impl::DeviceGuardImplInterface* impl) + : guard_(device, impl) {} + + /// Copy is disallowed + OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; + OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; + + /// Move is disallowed + /// See Note [Explicit initialization of optional fields] + /// and // Note [Move construction for RAII guards is tricky] + /// for rationale. + OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; + OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; + + /// Sets the device to the given one. The specified device must be consistent + /// with the device type originally specified during guard construction. + void reset_device(at::Device device) { + guard_.reset_device(device); + } + + /// For testing only + void reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl) { + guard_.reset_device(device, impl); + } + + /// Returns the device that was set at the time the guard was constructed. + optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via reset_device. + optional current_device() const { + return guard_.current_device(); + } + + private: + impl::InlineOptionalDeviceGuard guard_{}; +}; + +// Note [Whither the DeviceGuard boilerplate] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Design note: in principle, we could avoid these wrappers using: +// +// using DeviceGuard = impl::InlineDeviceGuard; +// using OptionalDeviceGuard = +// impl::InlineOptionalDeviceGuard; +// +// But the error messages are worse, and our users can't just look at the +// header file to find out what's going on. Furthermore, for specializations +// like CUDAStreamGuard, it can be profitable to replace some interfaces with +// refined types (e.g., return CUDAStream instead of Stream). So, we eat +// the boilerplate and write out the API explicitly. + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DeviceType.h b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceType.h new file mode 100644 index 0000000000000000000000000000000000000000..64b50f3f5701b4650445c556127055ee3bd1d056 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DeviceType.h @@ -0,0 +1,123 @@ +#pragma once + +// This is directly synchronized with caffe2/proto/caffe2.proto, but +// doesn't require me to figure out how to get Protobuf headers into +// ATen/core (which would require a lot more build system hacking.) +// If you modify me, keep me synchronized with that file. 
+ +#include + +#include +#include +#include +#include +#include + +namespace c10 { + +// These contains all device types that also have a BackendComponent +// and therefore participate in per-backend functionality dispatch keys. +// This is most backends except PrivateUse2 and PrivateUse3 +#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \ + _(CPU, extra) \ + _(CUDA, extra) \ + _(HIP, extra) \ + _(XLA, extra) \ + _(MPS, extra) \ + _(IPU, extra) \ + _(XPU, extra) \ + _(HPU, extra) \ + _(VE, extra) \ + _(Lazy, extra) \ + _(Meta, extra) \ + _(MTIA, extra) \ + _(PrivateUse1, extra) + +enum class DeviceType : int8_t { + CPU = 0, + CUDA = 1, // CUDA. + MKLDNN = 2, // Reserved for explicit MKLDNN + OPENGL = 3, // OpenGL + OPENCL = 4, // OpenCL + IDEEP = 5, // IDEEP. + HIP = 6, // AMD HIP + FPGA = 7, // FPGA + ORT = 8, // ONNX Runtime / Microsoft + XLA = 9, // XLA / TPU + Vulkan = 10, // Vulkan + Metal = 11, // Metal + XPU = 12, // XPU + MPS = 13, // MPS + Meta = 14, // Meta (tensors with no data) + HPU = 15, // HPU / HABANA + VE = 16, // SX-Aurora / NEC + Lazy = 17, // Lazy Tensors + IPU = 18, // Graphcore IPU + MTIA = 19, // Meta training and inference devices + PrivateUse1 = 20, // PrivateUse1 device + // NB: If you add more devices: + // - Change the implementations of DeviceTypeName and isValidDeviceType + // in DeviceType.cpp + // - Change the number below + COMPILE_TIME_MAX_DEVICE_TYPES = 21, +}; + +constexpr DeviceType kCPU = DeviceType::CPU; +constexpr DeviceType kCUDA = DeviceType::CUDA; +constexpr DeviceType kHIP = DeviceType::HIP; +constexpr DeviceType kFPGA = DeviceType::FPGA; +constexpr DeviceType kORT = DeviceType::ORT; +constexpr DeviceType kXLA = DeviceType::XLA; +constexpr DeviceType kMPS = DeviceType::MPS; +constexpr DeviceType kMeta = DeviceType::Meta; +constexpr DeviceType kVulkan = DeviceType::Vulkan; +constexpr DeviceType kMetal = DeviceType::Metal; +constexpr DeviceType kXPU = DeviceType::XPU; +constexpr DeviceType kHPU = DeviceType::HPU; +constexpr DeviceType kVE = DeviceType::VE; +constexpr DeviceType kLazy = DeviceType::Lazy; +constexpr DeviceType kIPU = DeviceType::IPU; +constexpr DeviceType kMTIA = DeviceType::MTIA; +constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1; + +// define explicit int constant +constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + +static_assert( + COMPILE_TIME_MAX_DEVICE_TYPES <= 21, + "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " + "for this constant to reflect the actual number of DeviceTypes we support " + "in PyTorch; it's important that this number is not too large as we " + "use this to allocate stack arrays in some places in our code. If you " + "are indeed just adding the 20th device type, feel free to change " + "the check to 32; but if you are adding some sort of extensible device " + "types registration, please be aware that you are affecting code that " + "this number is small. 
Try auditing uses of this constant."); + +C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false); + +C10_API bool isValidDeviceType(DeviceType d); + +C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type); + +C10_API void register_privateuse1_backend(const std::string& backend_name); +C10_API std::string get_privateuse1_backend(bool lower_case = true); + +C10_API bool is_privateuse1_backend_registered(); + +} // namespace c10 + +namespace std { +template <> +struct hash { + std::size_t operator()(c10::DeviceType k) const { + return std::hash()(static_cast(k)); + } +}; +} // namespace std + +namespace torch { +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::DeviceType; +} // namespace torch diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKey.h b/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKey.h new file mode 100644 index 0000000000000000000000000000000000000000..1eafef5a7ca4c077a98b8f62922ae82f1228d2db --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKey.h @@ -0,0 +1,748 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Semantically, each value of BackendComponent identifies a "backend" for our +// dispatch. Some functionalities that we may dispatch to are allowed to +// register different handlers for each backend. The BackendComponent is then +// used to figure out which backend implementation to dispatch to. + +// In implementation terms, the backend component identifies a specific "bit" in +// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom +// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to +// functionalities. When we encounter a functionality bit that is known to be +// customizable per-backend, then we also look at the lower BackendComponent +// bits and take the highest bit to determine which backend's implementation to +// use. + +// WARNING! If you add a new backend component to the end of this list, +// make sure you register it before Meta. +// Meta must be at the end so that meta key in tls triggers meta kernels. +// (But you shouldn't: private use keys should have higher precedence than all +// built-in keys) + +// If you add a new (non-privateuse) backend here, +// make sure to add an Autograd fallthrough kernel +// in aten/src/ATen/core/VariableFallbackKernel.cpp + +#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \ + _(CPU, extra) \ + _(CUDA, extra) \ + _(HIP, extra) \ + _(XLA, extra) \ + _(MPS, extra) \ + _(IPU, extra) \ + _(XPU, extra) \ + _(HPU, extra) \ + _(VE, extra) \ + _(Lazy, extra) \ + _(MTIA, extra) \ + _(PrivateUse1, extra) \ + _(PrivateUse2, extra) \ + _(PrivateUse3, extra) \ + _(Meta, extra) + +// WARNING! If we add a new per-backend functionality key that has higher +// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys + +#define C10_FORALL_FUNCTIONALITY_KEYS(_) \ + _(Dense, ) \ + _(Quantized, Quantized) \ + _(Sparse, Sparse) \ + _(SparseCsr, SparseCsr) \ + _(NestedTensor, NestedTensor) \ + _(AutogradFunctionality, Autograd) + +enum class BackendComponent : uint8_t { + + // A "backend" is colloquially used to refer to handlers for dispatch + // which actually implement the numerics of an operation in question. 
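Before moving on to DispatchKey.h, a brief usage sketch for the DeviceType helpers declared above (illustrative; the printed strings are what the current implementation is expected to produce, and the std::hash specialization is what lets DeviceType serve directly as a hash-map key).

#include <c10/core/DeviceType.h>
#include <iostream>
#include <unordered_map>

int main() {
  // Human-readable names; the second argument asks for lower case.
  std::cout << c10::DeviceTypeName(c10::kCUDA) << "\n";                       // CUDA
  std::cout << c10::DeviceTypeName(c10::kCUDA, /*lower_case=*/true) << "\n";  // cuda

  // The std::hash<c10::DeviceType> specialization above makes DeviceType
  // usable as an unordered_map key without any extra plumbing.
  std::unordered_map<c10::DeviceType, int> per_device_count;
  per_device_count[c10::kCPU] += 1;
  per_device_count[c10::kCUDA] += 2;
  std::cout << per_device_count.size() << "\n";  // 2
  return 0;
}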
+ // + // Due to the nature of the enum, these backends are specified in + // an ordered way, but for most backends this order is not semantically + // meaningful (e.g., it's valid to reorder these backends without changing + // semantics). The only situation when backend ordering is meaningful + // is when the backend participates in multiple dispatch with another + // backend; e.g., CPU and CUDA (cuda must have higher priority). + + // These keys don't correspond to individual kernels. + // Instead, they represent the backends that are allowed to override specific + // pieces of functionality: + // - dense kernels (e.g. DispatchKey::CPU) + // - sparse kernels (e.g. DispatchKey::SparseCPU) + // - quantized kernels (e.g. DispatchKey::QuantizedCPU) + // - autograd kernels (e.g. DispatchKey::AutogradCPU) + // We reserve space in the runtime operator table for this full cross product + // of + // [backends in this enum] x [keys below that are explicitly marked as having + // per-backend functionality] + // + // A meta tensor is a tensor without any data associated with it. (They + // have also colloquially been referred to as tensors on the "null" device). + // A meta tensor can be used to dry run operators without actually doing any + // computation, e.g., add on two meta tensors would give you another meta + // tensor with the output shape and dtype, but wouldn't actually add anything. + + InvalidBit = 0, +#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit, + C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused) +#undef DEFINE_BACKEND_COMPONENT + + // Define an alias to represent end of backend dispatch keys. + // If you add new backend keys after PrivateUse3, please also update it here. + EndOfBackendKeys = MetaBit, +}; + +// Semantically, a dispatch key identifies a possible "level" in our +// dispatch, for which a handler may be registered. Each handler corresponds +// to a type of functionality. +// +// In implementation terms, the dispatch key identifies a specific "bit" in a +// DispatchKeySet. Higher bit indexes get handled by dispatching first (because +// we "count leading zeros" when we extract the highest priority dispatch +// key.) +// +// Note [DispatchKey Classification] +// This enum actually contains several types of keys, which are explained +// in more detail further down: +// (1) non-customizable backends (e.g. FPGA) +// (2) non-customizable functionalities (e.g. Functionalize) +// (3) functionalized that are customizable per backend (e.g. Dense, Sparse, +// AutogradFunctionality) (4) per-backend instances of customizable +// functionalities (e.g. CPU, SparseCPU, AutogradCPU) (5) alias keys (e.g. +// CompositeImplicitAutograd) +// +// Of the categories above, it's important to note: +// (a) which keys are assigned individual bits in a DispatchKeySet +// (b) which keys are assigned individual slots in the runtime operator table +// ("Runtime keys") +// +// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet. +// (1), (2) and (4) all get their own dedicated slots in the runtime operator +// table. + +// See Note [DispatchKeySet Internal Representation] for more details. +// +// NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py +enum class DispatchKey : uint16_t { + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // This is not a "real" functionality, but it exists to give us a "nullopt" + // element we can return for cases when a DispatchKeySet contains no elements. 
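The backend bits above are generated by the C10_FORALL_BACKEND_COMPONENTS X-macro together with the n##Bit token pasting. A stripped-down, standalone imitation (toy names, and a single macro parameter instead of the real (_, extra) pair) may help make the expansion concrete.

#include <cstdio>

// One list of backends, expanded once into enum bits and once into names.
#define TOY_FORALL_BACKENDS(_) _(CPU) _(CUDA) _(XLA)

#define TOY_DEFINE_BIT(n) n##Bit,
enum class ToyBackend : unsigned char {
  InvalidBit = 0,
  TOY_FORALL_BACKENDS(TOY_DEFINE_BIT)  // CPUBit = 1, CUDABit = 2, XLABit = 3
  EndOfKeys = XLABit,
};
#undef TOY_DEFINE_BIT

#define TOY_DEFINE_NAME(n) #n,
constexpr const char* kToyNames[] = {"Invalid", TOY_FORALL_BACKENDS(TOY_DEFINE_NAME)};
#undef TOY_DEFINE_NAME

int main() {
  int bit = static_cast<int>(ToyBackend::CUDABit);
  std::printf("%s -> bit %d\n", kToyNames[bit], bit);  // CUDA -> bit 2
  return 0;
}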
+ // You can think a more semantically accurate definition of DispatchKey is: + // + // using DispatchKey = optional + // + // and Undefined == nullopt. We didn't actually represent + // it this way because optional would take two + // words, when DispatchKey fits in eight bits. + + Undefined = 0, + + // Define an alias for Undefined to represent CatchAll (long term + // this will get eliminated, but for now it's convenient) + CatchAll = Undefined, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ // + // Every value in the enum (up to EndOfFunctionalityKeys) + // corresponds to an individual "functionality" that can be dispatched to. + // This is represented in the DispatchKeySet by assigning each of these enum + // values + // to each of the remaining (64 - len(BackendComponent)) bits. + // + // Most of these functionalities have a single handler assigned to them, + // making them "runtime keys". + // That map to a single slot in the runtime operator table. + // + // A few functionalities are allowed to be customizable per backend. + // See [Note: Per-Backend Functionality Dispatch Keys] for details. + + // See [Note: Per-Backend Functionality Dispatch Keys] + Dense, + + // Below are non-extensible backends. + // These are backends that currently don't have their own overrides for + // Autograd/Sparse/Quantized kernels, + // and we therefore don't waste space in the runtime operator table allocating + // space for them. + // If any of these backends ever need to customize, e.g., Autograd, then we'll + // need to add a DispatchKey::*Bit for them. + + // TODO: put this in BackendComponents + FPGA, // Xilinx support lives out of tree at + // https://gitlab.com/pytorch-complex/vitis_kernels + + // TODO: put this in BackendComponents + // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and + // https://github.com/microsoft/onnxruntime, and is also used to test general + // backend/extension machinery in the core. cf: + // - test/cpp_extensions/ort_extension.cpp + // - test/test_torch.py + // - aten/src/ATen/test/extension_backend_test.cpp + ORT, + + Vulkan, // TODO: put this in BackendComponents + Metal, // TODO: put this in BackendComponents + + // See [Note: Per-Backend Functionality Dispatch Keys] + Quantized, + + // This backend is to support custom RNGs; it lets you go + // to a different kernel if you pass in a generator that is not a + // traditional CPUGeneratorImpl/CUDAGeneratorImpl. To make use of this + // key: + // 1) set it as a second parameter of at::Generator constructor call in + // the user-defined PRNG class. + // 2) use it as a dispatch key while registering custom kernels + // (templatized kernels specialized for user-defined PRNG class) + // intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp + CustomRNGKeyId, + + // TODO: Make Mkldnn a functionality key, so we can give it Meta + // support + // Here are backends which specify more specialized operators + // based on the layout of the tensor. Note that the sparse backends + // are one case where ordering matters: sparse multi-dispatches with + // the corresponding dense tensors, and must be handled before them. 
+ MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp + // NB: not to be confused with MKLDNN, which is Caffe2 only + + // See [Note: Per-Backend Functionality Dispatch Keys] + Sparse, + + SparseCsr, + + NestedTensor, + + // In some situations, it is not immediately obvious what the correct + // backend for function is, because the function in question doesn't + // have any "tensor" arguments. In this case, a BackendSelect function + // can be registered to implement the custom determination of the + // correct backend. + BackendSelect, + + Python, + + // Out-of-core key for Fake Tensor in torchdistx. + // See https://pytorch.org/torchdistx/latest/fake_tensor.html + // TODO: delete this in favor of Python-implemented fake tensor + Fake, + // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key + // is to insert code after the "autograd subsystem" runs, so this key should + // be directly after ADInplaceOrView and all of the autograd keys. + FuncTorchDynamicLayerBackMode, + + // Alias and mutation removal. + // If some backends want to opt into only alias removal or only mutation + // removal, + // we can consider adding separate keys dedicated to those individual passes. + // See Note [Functionalization Pass In Core] for details. + Functionalize, + + // The named dispatch key is set for any tensors with named dimensions. + // Although we have a dispatch key for named tensors, for historical reasons, + // this dispatch key doesn't do any of the substantive functionality for named + // tensor (though, hypothetically, it could!) At the moment, it's just + // responsible for letting us give good error messages when operations + // don't support named tensors. + // + // NB: If you ever consider moving named tensor functionality into + // this dispatch key, note that it might be necessary add another dispatch + // key that triggers before composite operators, in case a composite operator + // has named dimension propagation that doesn't match that of its + // constituent parts. + // TODO: delete this once torchdim lands in functorch + Named, + + // The Conjugate dispatch key is set for any tensors that need to perform + // conjugation + // This is implemented at a dispatch level right before any backends run + Conjugate, + + // The Negative dispatch key is set for any tensors that need to perform + // negation + // This is implemented at a dispatch level right before any backends run + Negative, + + ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp + + // Note [ADInplaceOrView key] + // ADInplaceOrView key is used by inplace or view ops to register a kernel + // that does additional setup for future autograd computation. + // + // 1. For inplace ops this kernel does version bump + // 2. For view ops this kernel does `as_view` setup where we properly setup + // DifferentiableViewMeta on the view tensors. + // + // For other ops it's fallthrough kernel since there's no extra + // work to do. + // + // Note [Dream: skip VariableType kernel when requires_grad=false] + // + // In an ideal world where we can skip VariableType kernel for inputs + // with requires_grad=false, instead of a fallthrough kernel, we'll + // register a kernel shown below to all functional ops as well: + // torch::Tensor my_functional_op(...) { + // { + // // Note for every op in VariableType, you need to go through + // // `AutoDispatchBelowADInplaceOrView` guard exactly once to add the + // // key to TLS excluded set. 
If you don't go through it at all, + // // inplace/view ops called through `at::` inside your backend + // // kernel will dispatch to ADInplaceOrView kernels and do a lot + // // of extra work. + // at::AutoDispatchBelowADInplaceOrView guard; + // at::redispatch::my_functional_op(...); + // } + // } + // But this work is currently blocked since it adds an extra dispatch + // for all ops and it's non-trivial overhead at model level(a few percents). + // Thus our current approach takes advantage of the fact every kernel go + // through VariableType kernel first and pulls the + // `at::AutoDispatchBelowADInplaceOrView` guard of functional ops + // up to the `VariableType` kernel. Thus we only add the extra dispatch + // to view/inplace ops to minimize its perf impact to real models. + ADInplaceOrView, + // Note [Alias Dispatch Key : Autograd] + // All backends are oblivious to autograd; autograd is handled as a + // layer which happens on top of all backends. It inspects the autograd + // metadata of all inputs, determines what autograd metadata should be + // constructed by the output, and otherwise defers to the backend to + // actually do the numeric computation. Autograd contains + // the bulk of this logic. + + // Autograd is now an alias dispatch key which by default maps to all + // backend-specific autograd keys. + // Backend-specific allow backends to override the default kernel registered + // to Autograd key as needed. + // For example, XLA wants to define autograd for einsum directly. + // Registering a custom autograd implementation at the XLA key won't work + // because we process Autograd before XLA. This key has higher priority and + // gets processed first. You generally should NOT redispatch after handling + // autograd here (since that would result in execution of the Autograd + // operator, which you're trying to skip). In AutogradXLA implementations, + // you are responsible for handling autograd yourself, or deferring to other + // operators which support autograd. + + // Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and + // reserved user-defined backends. All other in-tree backends share the + // AutogradOther key. We can add specific autograd key for those backends + // upon request. + AutogradOther, + + // See [Note: Per-Backend Functionality Dispatch Keys] + AutogradFunctionality, + + // NestedTensor is an example of something that isn't a "real backend" + // (because it mostly consists of redispatching kernels) + // but it would like to override autograd functionality in C++. + // We can handle cases like this by adding an extra functionality key + // exclusively for handling autograd for NestedTensor. + // lives out of tree at + // https://github.com/pytorch/nestedtensor + AutogradNestedTensor, + + Tracer, + + // TODO: make Autocast a functionality key + // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed + // and inputs are saved for backward in the post-autocast type. + AutocastCPU, + AutocastXPU, + AutocastIPU, + AutocastHPU, + AutocastXLA, + // AutocastXLA is only being used for TPUs. XLA GPUs continue to use + // AutocastCUDA. + AutocastCUDA, + AutocastPrivateUse1, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // There are a number of alternative modes which may want to handle before + // autograd; for example, error checking, tracing, profiling or vmap. They + // go here. 
+ + FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype] + + // Dispatch key for BatchedTensorImpl wrapping a nested tensor. + BatchedNestedTensor, + + FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype] + + // This is the dispatch key for BatchedTensorImpl, which is used to implement + // batching rules for vmap. + Batched, + + // When we are inside a vmap, all tensors dispatch on this key. + // See Note: [DispatchKey::VmapMode usage] for more details. + VmapMode, + + FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype] + + // Out-of-core key for Deferred Module Initialization in torchdistx. + // See https://pytorch.org/torchdistx/latest/deferred_init.html + DeferredInit, + + // Used by Python key logic to know the set of tls on entry to the dispatcher + // This kernel assumes it is the top-most non-functorch-related DispatchKey. + // If you add a key above, make sure to update the fallback implementation for + // this. + PythonTLSSnapshot, + + // This key should be at the very top of the dispatcher + FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype] + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a single + // process test. Use it by creating a TensorImpl with this DispatchKey, and + // then registering operators to operate on this type id. See + // aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example. + TESTING_ONLY_GenericWrapper, + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a ingle + // process test. Use it by toggling the mode on and off via + // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators + // to operate on this type id. See + // aten/src/ATen/core/dispatch/backend_fallback_test.cpp + // for a usage example + TESTING_ONLY_GenericMode, + + // This key is used for pre-dispatch tracing in make_fx. + // It has lower priority than the PythonDispatcher key + // because we use the PythonDispatcher to intercept the key from python, + // and avoid having to implement it in C++. + PreDispatch, + + // This is a bypass that allows you to skip running the C++ dispatcher + // entirely + PythonDispatcher, + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + EndOfFunctionalityKeys, // End of functionality keys. + +// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ // +// Here are backends which you think of as traditionally specifying +// how to implement operations on some device. + +#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n, + +#define DEFINE_PER_BACKEND_KEYS(fullname, prefix) \ + StartOf##fullname##Backends, \ + C10_FORALL_BACKEND_COMPONENTS( \ + DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \ + EndOf##fullname##Backends = prefix##Meta, + + C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS) + +#undef DEFINE_PER_BACKEND_KEYS +#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND + + EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends, + + // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // Note [Alias Dispatch Keys] + // Alias dispatch keys are synthetic dispatch keys which map to multiple + // runtime dispatch keys. Alisa keys have precedence, but they are always + // lower precedence than runtime keys. 
You can register a kernel to an + // alias key, the kernel might be populated to the mapped runtime keys + // during dispatch table computation. + // If a runtime dispatch key has multiple kernels from alias keys, which + // kernel wins is done based on the precedence of alias keys (but runtime + // keys always have precedence over alias keys). + // Alias keys won't be directly called during runtime. + + // See Note [Alias Dispatch Key : Autograd] + Autograd, + CompositeImplicitAutograd, // registered at + // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp + + // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from + // all + // other alias keysets + // and so precedence order doesn't matter + FuncTorchBatchedDecomposition, // registered at + // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp + // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is + // disjoint from all other alias keysets + CompositeImplicitAutogradNestedTensor, // registered at + // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp + CompositeExplicitAutograd, // registered at + // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + // See Note [CompositeExplicitAutogradNonFunctional Key] + CompositeExplicitAutogradNonFunctional, // registered at + // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp + + // Define an alias key to represent end of alias dispatch keys. + // If you add new alias keys after Autograd, please also update it here. + StartOfAliasKeys = Autograd, + EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, // + + // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // + // The aliases exist for backwards compatibility reasons, they shouldn't + // be used + CPUTensorId = CPU, + CUDATensorId = CUDA, + DefaultBackend = CompositeExplicitAutograd, + PrivateUse1_PreAutograd = AutogradPrivateUse1, + PrivateUse2_PreAutograd = AutogradPrivateUse2, + PrivateUse3_PreAutograd = AutogradPrivateUse3, + Autocast = AutocastCUDA, +}; + +// Note [Private use DispatchKey] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Private use tensor IDs are preallocated tensor type IDs for use in user +// applications. Similar to private use fields in HTTP, they can be used +// by end users for experimental or private applications, without needing +// to "standardize" the tensor ID (which would be done by submitting a PR +// to PyTorch to add your type ID). +// +// Private use tensor IDs are appropriate to use if you want to experiment +// with adding a new tensor type (without having to patch PyTorch first) or +// have a private, non-distributed application that needs to make use of a +// new tensor type. Private use tensor IDs are NOT appropriate to use for +// libraries intended to be distributed to further users: please contact +// the PyTorch developers to get a type ID registered in this case. +// +// We provide two classes of private user tensor id: regular DispatchKeys +// and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend" +// DispatchKeys; if you were adding support for a new type of accelerator, you +// would use a backend DispatchKey, and ideally automatically reuse +// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse +// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for +// tensors that compose multiple internal tensors, and for cases when the +// built-in autograd formulas for operators are not appropriate. 
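As a sketch of how a private-use key is typically consumed (illustrative only; my_add is a hypothetical placeholder, not part of this patch), an out-of-tree backend registers kernels against PrivateUse1 and can optionally rename the backend via the register_privateuse1_backend() hook declared in DeviceType.h above.

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical stand-in for a real backend kernel.
at::Tensor my_add(
    const at::Tensor& self,
    const at::Tensor& other,
    const at::Scalar& alpha) {
  // A real backend would run its own device computation here.
  return at::empty_like(self);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", my_add);
}

// Optionally, somewhere at startup:
//   c10::register_privateuse1_backend("my_device");  // user-facing name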
+ +static_assert( + (static_cast(BackendComponent::EndOfBackendKeys) + + static_cast(DispatchKey::EndOfFunctionalityKeys)) <= 64, + "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)" + " both map to backend and functionality bits" + " into a 64-bit bitmask; you must have less than 64 total entries between them"); + +// Check if a DispatchKey is an alias mapping to other runtime keys. +constexpr bool isAliasDispatchKey(DispatchKey k) { + return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys; +} + +// [Note: Per-Backend Functionality Dispatch Keys] +// Check if a DispatchKey is a per-backend functionality key +// Any functionalities that can be customized per-backend should be added here. +// These keys correspond to functionalities that can be customized individually +// per backend. While they only take up one bit in the `DispatchKeySet` bitset, +// they map to (# backends) slots in the operator table. +// Each of these keys also has a separate set of "runtime keys" in the dispatch +// key enum, per backend, which *do* map to the individual operator table slots. +// For example, the "Sparse" key maps to an individual bit in the +// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual +// slots in the runtime operator table. + +constexpr bool isPerBackendFunctionalityKey(DispatchKey k) { + if (k == DispatchKey::Dense || k == DispatchKey::Quantized || + k == DispatchKey::Sparse || k == DispatchKey::SparseCsr || + k == DispatchKey::AutogradFunctionality || + k == DispatchKey::NestedTensor) { + return true; + } else { + return false; + } +} + +// Note that this includes Undefined in the total count. +// BUT EndOfFunctionalityKeys is its own (placeholder) key. +// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3. +// In the above example, there are 3 total functionality keys. +constexpr uint8_t num_functionality_keys = + static_cast(DispatchKey::EndOfFunctionalityKeys); + +constexpr uint8_t num_backends = + static_cast(BackendComponent::EndOfBackendKeys); + +// Note [No More Than 16 Backends] +// Search for this note to find places in the code where the "no more than 16 +// backends" invariant is baked in. +static_assert( + static_cast(BackendComponent::EndOfBackendKeys) <= 16, + "BackendComponent currently only supports <= 16 backends. If we really need to extend this, \ +there are a few places where this invariant is baked in"); + +constexpr uint8_t numPerBackendFunctionalityKeys() { + uint8_t count = 0; + for (uint8_t k = 0; k <= num_functionality_keys; ++k) { + if (isPerBackendFunctionalityKey(static_cast(k))) + ++count; + } + return count; +} + +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) +// See [Note: Trimmed Mobile Dispatch Keys] +constexpr uint16_t num_runtime_entries = 8; +#else +constexpr uint16_t num_runtime_entries = num_functionality_keys + + (numPerBackendFunctionalityKeys() * (num_backends - 1)); +#endif + +// See Note [No More Than 16 Backends] +constexpr uint16_t full_backend_mask = + (static_cast(1) << num_backends) - 1; + +C10_API const char* toString(DispatchKey); +C10_API const char* toString(BackendComponent); +C10_API std::ostream& operator<<(std::ostream&, DispatchKey); +C10_API std::ostream& operator<<(std::ostream&, BackendComponent); + +C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k); + +// Parses a string into a dispatch key. +// If the string cannot be correctly parsed, throws an exception. 
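The sizing arithmetic above can be sanity-checked at compile time. A sketch follows; the literal counts 15 and 6 are simply what the backend and per-backend-functionality lists in this snapshot imply, and the third check assumes a non-mobile build.

#include <c10/core/DispatchKey.h>

// 15 BackendComponents are listed in C10_FORALL_BACKEND_COMPONENTS above.
static_assert(c10::num_backends == 15, "backend count drifted");

// Dense, Quantized, Sparse, SparseCsr, NestedTensor, AutogradFunctionality.
static_assert(
    c10::numPerBackendFunctionalityKeys() == 6,
    "per-backend functionality count drifted");

// Each per-backend functionality contributes (num_backends - 1) extra
// runtime slots on top of its own functionality entry.
static_assert(
    c10::num_runtime_entries ==
        c10::num_functionality_keys +
            c10::numPerBackendFunctionalityKeys() * (c10::num_backends - 1),
    "runtime table sizing no longer matches the formula above");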
+C10_API c10::DispatchKey parseDispatchKey(const std::string& k); + +// These are some convenience identifiers for dispatch keys which are +// shorter to type than their long counterparts. Note that some of these +// dispatch keys directly correspond to DeviceType; and most APIs that +// accept DispatchKey also accept DeviceType; e.g., +// torch::dispatch(torch::kCPU, ...) is also valid. +constexpr DispatchKey kAutograd = DispatchKey::Autograd; + +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] +// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. +constexpr BackendComponent toBackendComponent(DispatchKey k) { + if (k >= DispatchKey::StartOfDenseBackends && + k <= DispatchKey::EndOfDenseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfDenseBackends)); + } else if ( + k >= DispatchKey::StartOfQuantizedBackends && + k <= DispatchKey::EndOfQuantizedBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfQuantizedBackends)); + } else if ( + k >= DispatchKey::StartOfSparseBackends && + k <= DispatchKey::EndOfSparseBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfSparseBackends)); + } else if ( + k >= DispatchKey::StartOfSparseCsrBackends && + k <= DispatchKey::EndOfSparseCsrBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfSparseCsrBackends)); + } else if ( + k >= DispatchKey::StartOfNestedTensorBackends && + k <= DispatchKey::EndOfNestedTensorBackends) { + return static_cast( + static_cast(k) - + static_cast(DispatchKey::StartOfNestedTensorBackends)); + } else if ( + k >= DispatchKey::StartOfAutogradFunctionalityBackends && + k <= DispatchKey::EndOfAutogradFunctionalityBackends) { + return static_cast( + static_cast(k) - + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends)); + } else { + return BackendComponent::InvalidBit; + } +} + +constexpr DispatchKey toFunctionalityKey(DispatchKey k) { + if (k <= DispatchKey::EndOfFunctionalityKeys) { + return k; + } else if (k <= DispatchKey::EndOfDenseBackends) { + return DispatchKey::Dense; + } else if (k <= DispatchKey::EndOfQuantizedBackends) { + return DispatchKey::Quantized; + } else if (k <= DispatchKey::EndOfSparseBackends) { + return DispatchKey::Sparse; + } else if (k <= DispatchKey::EndOfSparseCsrBackends) { + return DispatchKey::SparseCsr; + } else if (k <= DispatchKey::EndOfNestedTensorBackends) { + return DispatchKey::NestedTensor; + } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) { + return DispatchKey::AutogradFunctionality; + } else { + return DispatchKey::Undefined; + } +} + +BackendComponent toBackendComponent(DeviceType device_type); + +// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns +// DispatchKey::CUDA. +// See Note [The Ordering of Per-Backend Dispatch Keys Matters!] +// This function relies on the invariant that the dispatch keys between +// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend +// in the same order as `BackendComponent`. 
+constexpr DispatchKey toRuntimePerBackendFunctionalityKey( + DispatchKey functionality_k, + BackendComponent backend_k) { + if (functionality_k == DispatchKey::Dense) { + return static_cast( + static_cast(DispatchKey::StartOfDenseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Sparse) { + return static_cast( + static_cast(DispatchKey::StartOfSparseBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::SparseCsr) { + return static_cast( + static_cast(DispatchKey::StartOfSparseCsrBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::Quantized) { + return static_cast( + static_cast(DispatchKey::StartOfQuantizedBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::NestedTensor) { + return static_cast( + static_cast(DispatchKey::StartOfNestedTensorBackends) + + static_cast(backend_k)); + } + if (functionality_k == DispatchKey::AutogradFunctionality) { + return static_cast( + static_cast( + DispatchKey::StartOfAutogradFunctionalityBackends) + + static_cast(backend_k)); + } + return DispatchKey::Undefined; +} + +} // namespace c10 + +namespace torch { +// Expose the constant, but not the TYPE (DispatchKey is an implementation +// detail!) +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::kAutograd; +} // namespace torch + +// NB: You really shouldn't use this instance; this enum is guaranteed +// to be pretty small so a regular array should be acceptable. +namespace std { +template <> +struct hash { + typedef size_t result_type; + typedef c10::DispatchKey argument_type; + + size_t operator()(c10::DispatchKey x) const { + return static_cast(x); + } +}; +} // namespace std diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKeySet.h b/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKeySet.h new file mode 100644 index 0000000000000000000000000000000000000000..45fff3879055eae7e504210a1b47e0c4cb1d743b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DispatchKeySet.h @@ -0,0 +1,941 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct FunctionalityOffsetAndMask { + // empty constructor shouldn't be used; only needed to initialize + // the array before populating it. + FunctionalityOffsetAndMask() = default; + FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask) + : offset(offset), mask(mask) {} + // This needs to big enough to cover the size of the operator table. + uint16_t offset{}; + // See Note [No More Than 16 Backends] + // This mask needs to be big enough to mask all of the backend bits. + // We probably don't ever want to have more than 16 backend bits, so uint16_t + // should be enough. + uint16_t mask{}; +}; +static_assert( + c10::num_runtime_entries < 65536, + "The dispatcher currently only supports up to 2^16 runtime entries"); + +C10_API std::array +initializeFunctionalityOffsetsAndMasks(); + +C10_ALWAYS_INLINE static const std:: + array& + offsetsAndMasks() { + static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks(); + return offsets_and_masks_; +} + +// A representation of a set of DispatchKeys. A DispatchKeySet contains both +// "functionality" bits and "backend bits", and every tensor holds its own +// DispatchKeySet. 
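The (de)composition helpers defined just above are all constexpr, so their round-trip behaviour can be checked at compile time; a small sketch:

#include <c10/core/DispatchKey.h>

// A runtime key splits into a backend "building block" and a functionality...
static_assert(
    c10::toBackendComponent(c10::DispatchKey::SparseCUDA) ==
        c10::BackendComponent::CUDABit,
    "SparseCUDA carries the CUDA backend bit");
static_assert(
    c10::toFunctionalityKey(c10::DispatchKey::SparseCUDA) ==
        c10::DispatchKey::Sparse,
    "SparseCUDA is an instance of the Sparse functionality");

// ...and the two halves recombine into the original runtime key.
static_assert(
    c10::toRuntimePerBackendFunctionalityKey(
        c10::DispatchKey::Sparse, c10::BackendComponent::CUDABit) ==
        c10::DispatchKey::SparseCUDA,
    "functionality + backend recompose to the runtime key");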
The Dispatcher implements multiple dispatch by grabbing the +// keyset on every input tensor, or’ing them together, and dispatching to a +// specific piece of functionality. The functionality bits are *ordered*. When +// multiple functionality bits are set, we use the highest priority +// functionality. Similarly, multiple backend bits can theoretically be set if +// you call an operator with multiple tensors from difference devices (e.g. CPU +// and CUDA), although support for mixed device dispatch is limited (the only +// kernels that gracefully handle mixed device inputs for now are cuda kernels +// that take in a scalar cpu tensor). + +// A representation of a set of DispatchKeys. A tensor may have multiple +// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the +// DispatchKeySet specifies what type ids apply. The internal representation is +// as a 64-bit bit set (this means only 64 tensor type ids are supported). +// +// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like +// "what is the highest priority DispatchKey in the set"? (The set itself is +// not ordered; two sets with the same ids will always have the ids ordered in +// the same way.) +// +// Note [DispatchKeySet Internal Representation] +// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects +// that get passed around at runtime. +// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset +// and individual dispatch keys. +// +// First: why do we have this distinction, and why not map every dispatch key +// directly to a bit? This is mostly because we have several types of +// functionalities that different backends would like to customize. For example, +// we have: +// - "Dense": CPU, CUDA, XLA, ... (~12 keys) +// - "Sparse": SparseCPU, SparseCUDA, ... +// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ... +// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ... +// - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ... +// The problem is that total number of keys grows quadratically with [# +// backends] x [# functionalities], making it very difficult to map each key +// directly to a bit in a bitset without dramatically increasing the size of the +// bitset over time. +// +// The two enums (BackendComponent and DispatchKey) can be divided roughly into +// 5 categories. +// +// (1) "Building block" keys +// (a) backends: Everything in the BackendComponent enum (e.g. CPUBit, +// CUDABit) (b) functionalities: (per-backend) functionality-bit DispatchKeys +// (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense) +// (2) "Runtime" keys +// (a) "non-customizable backends" (e.g. FPGA) +// (b) "non-customizable functionalities" (e.g. Functionalize) +// (c) "per-backend instances of customizable functionalities" (e.g. CPU, +// SparseCPU, AutogradCPU) +// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys]) +// +// (1) Building block keys always correspond to individual bits in a +// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual +// runtime keys. e.g. +// auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit, +// DispatchKey::Dense}); +// // The keyset has the runtime dense-cpu key. +// dense_cpu_ks.has(DispatchKey::CPU); +// // And it contains the building block keys too. +// dense_cpu_ks.has(DispatchKey::CPUBit); +// dense_cpu_ks.has(DispatchKey::Dense); +// +// Not every backend and not every functionality counts as a "building block +// key". 
This is mostly to give us more levers to pull in the design space. +// Backend keys and functionality keys that count as "building blocks" will +// contribute to a full cross product of functionality that can be overriden. +// +// For example, right now we have at least 12 "backend" building +// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality" +// building blocks (Dense, Sparse, SparseCsr, Quantized, +// AutogradFunctionality, ...). These keys together allow every +// dispatcher operator to be customized in up to 12*4 different +// ways. Each of those requires a slot in the operator table of every +// dispatcher operator. Not every piece of functionality necessarily +// needs to be customizable per-backend, and not every backend +// necessarily needs to be able to customize every type of +// functionality. +// +// +// (2) Every runtime key corresponds directly to a slot in an operator's runtime +// dispatch table, and you can directly register kernels to a runtime dispatch +// key. +// +// For per-backend functionalities like "Dense" or "AutogradFunctionality", +// you can think of the corresponding runtime dispatch keys as "instances" of +// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all +// runtime instances of the "Dense" building block key. + +// (2a) and (2b) are represented identically in the DispatchKeySet logic: +// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT +// customizable per backend. +// In order to do so, we'd need to promote it to a per-backend functionality +// "building block" key. +// - non-customizable backends (e.g. FPGA) can NOT customize existing +// functionality like Sparse, Autograd, etc. +// In order to do so, we'd need to promote it to a backend "building block" +// key. +// +// In both cases, these keys directly correspond to runtime slots in the +// operator table. +// +// +// (3) "Alias" keys +// See Note [Alias Dispatch Keys] +// +// Final note: for anyone making future changes to the Dispatcher + +// DispatchKeySet internals, there's a closed PR with a basic +// python-implementation of the Dispatcher that might be useful in quickly +// testing out and validating changes. See it at +// https://github.com/pytorch/pytorch/pull/68743 + +// An undefined tensor is one with an empty tensor type set. +class DispatchKeySet final { + public: + enum Full { FULL }; + enum FullAfter { FULL_AFTER }; + enum Raw { RAW }; + + // NB: default constructor representation as zero is MANDATORY as + // use of DispatchKeySet in TLS requires this. + constexpr DispatchKeySet() = default; + + constexpr DispatchKeySet(Full) + : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {} + + constexpr DispatchKeySet(FullAfter, DispatchKey t) + // LSB after t are OK, but not t itself. + // "functionalities" have a notion of ordering (e.g. Autograd > Sparse > + // Quantized > Dense). But backends don't really have an ordering. + // Therefore, we're enforcing that FullAfter can only be used on + // "functionality" keys. + : repr_( + (1ULL + << (num_backends + static_cast(toFunctionalityKey(t)) - + 1)) - + 1) { + *this = add(DispatchKey::PythonDispatcher); + } + + // Public version of DispatchKeySet(uint64_t) API; external users + // must be explicit when they do this! 
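A runnable variant of the dense_cpu_ks example from the note above, written against the public constructors (illustrative only):

#include <c10/core/DispatchKeySet.h>

void building_block_sketch() {
  using namespace c10;
  // Combine the Dense functionality bit with the CPU backend bit; the result
  // is a keyset that contains the *runtime* key DispatchKey::CPU.
  DispatchKeySet dense_cpu_ks = DispatchKeySet(DispatchKey::Dense) |
                                DispatchKeySet(BackendComponent::CPUBit);
  bool has_runtime_cpu = dense_cpu_ks.has(DispatchKey::CPU);              // true
  bool has_cpu_bit = dense_cpu_ks.has_backend(BackendComponent::CPUBit);  // true
  (void)has_runtime_cpu;
  (void)has_cpu_bit;
}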
+ constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {} + + constexpr explicit DispatchKeySet(BackendComponent k) { + if (k == BackendComponent::InvalidBit) { + repr_ = 0; + } else { + repr_ = 1ULL << (static_cast(k) - 1); + } + } + + constexpr explicit DispatchKeySet(DispatchKey k) { + // NOLINTNEXTLINE(bugprone-branch-clone) + if (k == DispatchKey::Undefined) { + // Case 1: handle Undefined specifically + repr_ = 0; + } else if (k <= DispatchKey::EndOfFunctionalityKeys) { + // Case 2: handle "functionality-only" keys + // These keys have a functionality bit set, but no backend bits + // These can technically be either: + // - valid runtime keys (e.g. DispatchKey::AutogradOther, + // DispatchKey::FuncTorchBatched, etc) + // - "building block" keys that aren't actual runtime keys (e.g. + // DispatchKey::Dense or Sparse) + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(k) - 1); + repr_ = functionality_val; + } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) { + // Case 3: "runtime" keys that have a functionality bit AND a backend bit. + // First compute which bit to flip for the functionality. + auto functionality_k = toFunctionalityKey(k); + // The - 1 is because Undefined is technically a "functionality" that + // doesn't show up in the bitset. So e.g. Dense is technically the second + // functionality, but the lowest functionality bit. + uint64_t functionality_val = 1ULL + << (num_backends + static_cast(functionality_k) - 1); + + // then compute which bit to flip for the backend + // Case 4a: handle the runtime instances of "per-backend functionality" + // keys For example, given DispatchKey::CPU, we should set: + // - the Dense functionality bit + // - the CPUBit backend bit + // first compute which bit to flip for the backend + auto backend_k = toBackendComponent(k); + uint64_t backend_val = backend_k == BackendComponent::InvalidBit + ? 0 + : 1ULL << (static_cast(backend_k) - 1); + repr_ = functionality_val + backend_val; + } else { + // At this point, we should have covered every case except for alias keys. + // Technically it would be possible to add alias dispatch keys to a + // DispatchKeySet, but the semantics are a little confusing and this + // currently isn't needed anywhere. + repr_ = 0; + } + } + + constexpr uint64_t keys_to_repr(std::initializer_list ks) { + uint64_t repr = 0; + for (auto k : ks) { + repr |= DispatchKeySet(k).repr_; + } + return repr; + } + + constexpr uint64_t backend_bits_to_repr( + std::initializer_list ks) { + uint64_t repr = 0; + for (auto k : ks) { + repr |= DispatchKeySet(k).repr_; + } + return repr; + } + + explicit constexpr DispatchKeySet(std::initializer_list ks) + : repr_(keys_to_repr(ks)) {} + + explicit constexpr DispatchKeySet(std::initializer_list ks) + // Note: for some reason, putting this logic directly in the constructor + // appears to fail to compile on CUDA 10.1. + // See an example internal failure at + // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr + : repr_(backend_bits_to_repr(ks)) {} + + // Test if a DispatchKey is in the set + inline bool has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); + return has_all(DispatchKeySet(t)); + } + constexpr bool has_backend(BackendComponent t) const { + return has_all(DispatchKeySet(t)); + } + + // Test if a DispatchKey is in the set + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if all of them are in the current set. 
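The case analysis in the DispatchKey constructor above is easiest to see in the raw bit pattern. A sketch (the expected values assume the 15 backend bits listed earlier, which puts the Dense functionality bit at bit 15):

#include <c10/core/DispatchKeySet.h>
#include <cstdio>

int main() {
  using namespace c10;
  // Backend-only, functionality-only, and full runtime keysets.
  auto cpu_bit = DispatchKeySet(BackendComponent::CPUBit).raw_repr();  // 0x1
  auto dense = DispatchKeySet(DispatchKey::Dense).raw_repr();          // 0x8000
  auto cpu = DispatchKeySet(DispatchKey::CPU).raw_repr();              // 0x8001

  std::printf("CPUBit: %#llx\nDense:  %#llx\nCPU:    %#llx\n",
              static_cast<unsigned long long>(cpu_bit),
              static_cast<unsigned long long>(dense),
              static_cast<unsigned long long>(cpu));
  return 0;
}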
+ constexpr bool has_all(DispatchKeySet ks) const { + return static_cast((repr_ & ks.repr_) == ks.repr_); + } + + // Given a DispatchKeySet of functionality keys and (potentially) backend + // keys, tests if any of them are in the current set. This could technically + // be pretty easily implemented using has(). It is strictly a perf + // optimization though. There are many places in the code base where we want + // to test for multiple functionality keys together. HOWEVER, runtime + // per-backend functionality keys aren't allowed to be used with this + // function, because you can end up with weird results. e.g. + // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU)) + // would return true. + inline bool has_any(DispatchKeySet ks) const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + // Either there are no backend bits in the input keyset + ((ks.repr_ & full_backend_mask) == 0) || + // or there are no per-backend-functionality bits + // See [Note: Per-Backend Functionality Dispatch Keys] + ((ks & + DispatchKeySet({ + DispatchKey::Dense, + DispatchKey::Quantized, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::AutogradFunctionality, + }) + .repr_) == 0)); + return static_cast((repr_ & ks.repr_) != 0); + } + // Test if DispatchKeySet is a superset of ks. + bool isSupersetOf(DispatchKeySet ks) const { + return (repr_ & ks.repr_) == ks.repr_; + } + // Perform set union + constexpr DispatchKeySet operator|(DispatchKeySet other) const { + return DispatchKeySet(repr_ | other.repr_); + } + // Perform set intersection + constexpr DispatchKeySet operator&(DispatchKeySet other) const { + return DispatchKeySet(repr_ & other.repr_); + } + // Compute the set difference self - other, + // but ONLY for the functionality keys. + // Any backend bits set on self will remain unchanged. + // See Note [Removing keys from DispatchKeySet Only Affects Functionality + // Keys] + constexpr DispatchKeySet operator-(DispatchKeySet other) const { + return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_)); + } + + // Compute self ^ other + constexpr DispatchKeySet operator^(DispatchKeySet other) const { + return DispatchKeySet(repr_ ^ other.repr_); + } + bool operator==(DispatchKeySet other) const { + return repr_ == other.repr_; + } + bool operator!=(DispatchKeySet other) const { + return repr_ != other.repr_; + } + // Add a DispatchKey to the DispatchKey set. Does NOT mutate, + // returns the extended DispatchKeySet! + C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const { + return *this | DispatchKeySet(t); + } + C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const { + return *this | ks; + } + + // Remove a DispatchKey from the DispatchKey set. + // This is generally not an operation you should be doing + // (it's used to implement the printing overload, operator<<) + // + // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys] + // Only functionality bits are allowed to be removed from a keyset. + // For now, we're only allowing removal of "functionality bits" from the + // keyset, which is specifically needed by the fallthrough key calculation + // logic. Why is removing backend bits problematic? Consider this example: + // + // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA, + // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA) + // DispatchKeySet([DispatchKey.CPU, + // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA) + // + // What do we want to happen? 
+ // Technically, we'd like it to be true that after removal, + // the first keyset still has the CUDA dispatch key while the second doesn't. + // Unfortunately there's no way to represent that, because the two keysets are + // represented the same way internally: functionality bits: Autograd, Dense + // backend bits: CPU, CUDA + // + // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd" + // bit from the bitset. + C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const { + return DispatchKeySet( + repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); + } + // You're allowed to remove a backend bit from a DispatchKeySet, + // but you have to be explicit about it (remove_backend() instead of + // remove()). + constexpr DispatchKeySet remove_backend(BackendComponent b) const { + return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_)); + } + // Is the set empty? (AKA undefined tensor) + bool empty() const { + return repr_ == 0; + } + uint64_t raw_repr() { + return repr_; + } + + DispatchKey highestFunctionalityKey() const { + auto functionality_idx = indexOfHighestBit(); + // This means that none of the functionality bits were set. + if (functionality_idx < num_backends) + return DispatchKey::Undefined; + // The first num_backend bits in the keyset don't correspond to real + // dispatch keys. + return static_cast(functionality_idx - num_backends); + } + + // This is similar like toBackendComponent(DispatchKey), but less restrictive. + // toBackendComponent() errors out if the key that it was passed has no + // backend bits, which is useful for error checking. We need a version of that + // here that can also handle "fake" backends like FPGA, because they need to + // map to the AutogradOther key. For those backends, we return + // BackendComponent::InvalidBit. + BackendComponent highestBackendKey() const { + // mask to mask out functionality bits + auto backend_idx = + DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit(); + // all zeros across the backend bits means that no backend bits are set. + if (backend_idx == 0) + return BackendComponent::InvalidBit; + return static_cast(backend_idx); + } + + // returns the DispatchKey of highest priority in the set. + DispatchKey highestPriorityTypeId() const { + auto functionality_k = highestFunctionalityKey(); + if (isPerBackendFunctionalityKey(functionality_k)) { + return toRuntimePerBackendFunctionalityKey( + functionality_k, highestBackendKey()); + } + return functionality_k; + } + + // Returns the index of the most-significant bit in the keyset. + // This is used to as part of the calculation into the operator table to get: + // - the highest "functionality" bit in the keyset. + // - the highest "backend" bit in the keyset. + uint8_t indexOfHighestBit() const { + return 64 - llvm::countLeadingZeros(repr_); + } + +#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) + // [Note: Trimmed Mobile Dispatch Keys] + /** + * The method below maps the dispatch key in the enum DispatchKey to an + * integer index in the dispatchTable_ array in OperatorEntry. The array + * is trimmed for mobile to reduce peak memory usage since it's + * unnecessary to reserve additional space for dispatch keys that will + * never be used on mobile. 
+ */ + int getDispatchTableIndexForDispatchKeySet() const { + auto dk = highestPriorityTypeId(); + switch (dk) { + case DispatchKey::Undefined: + return 0; + case DispatchKey::CPU: + return 1; + case DispatchKey::QuantizedCPU: + return 2; + case DispatchKey::SparseCPU: + return 3; + case DispatchKey::BackendSelect: + return 4; + case DispatchKey::ADInplaceOrView: + return 5; + case DispatchKey::AutogradOther: + return 6; + case DispatchKey::AutogradCPU: + return 7; + default: + return -1; + } + } +#else + // returns the index in the operator table of highest priority key in the the + // keyset Note that we could in theory implement this using + // highestPriorityTypeId(), but this code is very hotpath and we can do it + // faster without it. + int getDispatchTableIndexForDispatchKeySet() const { + auto functionality_idx = + DispatchKeySet(repr_ >> num_backends).indexOfHighestBit(); + auto offset_and_mask = offsetsAndMasks()[functionality_idx]; + // Mask the functionality bits out first, then right-shift by 1. + // right-shifting by 1 because everything is zero-indexed. + // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should + // give us an offset of 1, etc. + auto backend_idx = + DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit(); + return offset_and_mask.offset + backend_idx; + } +#endif + + // returns the "index" of the highest priority backend in the keyset. + // This is pretty similar to getBackendKey(), but: + // - It's hotpath code (part of the runtime bitset calculation) + // - I's returns an integer index, not an enum value + // - Everything is shifted to the right by 1. + // BackendComponent::InvalidBit is technically the lowest enum value, + // but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2, + // etc. + uint64_t getBackendIndex() const { + return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit(); + } + + private: + constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {} + uint64_t repr_ = 0; + + public: + // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys + // in the set. The iterator is only invalidated by the destruction of the + // underlying DispatchKeySet as the iterator stores a pointer to the raw + // representation of the DispatchKeySet. Note: When we encounter a per-backend + // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend + // in the keyset, for that functionality. For example, if the next + // functionality key to iterate over is Autograd, and the backend bits in the + // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit], + // then the next two keys we return will be DispatchKey::AutogradCPU, + // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than + // CUDA in DispatchKey.h). + class iterator { + public: + using self_type = iterator; + using iterator_category = std::input_iterator_tag; + using value_type = DispatchKey; + using difference_type = ptrdiff_t; + using reference = value_type&; + using pointer = value_type*; + // final mask value should mask out the entire keyset + static const uint8_t end_iter_mask_val = + num_backends + num_functionality_keys; + // final key value should be the last DispatchKey + static const uint8_t end_iter_key_val = num_functionality_keys; + + // current_dispatchkey_idx_ will iterate through all functionality bits. + // current_backendcomponent_idx_ will iterate through all backend bits. 
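+    //
+    // Editorial usage sketch (not part of the upstream header): range-based
+    // iteration expands a per-backend functionality bit into one runtime key
+    // per backend bit in the set, e.g. for a keyset built from the runtime
+    // keys CPU and CUDA:
+    //
+    //   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::CUDA});
+    //   for (DispatchKey k : ks) {
+    //     // visits DispatchKey::CPU, then DispatchKey::CUDA
+    //   }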
+ explicit iterator( + const uint64_t* data_ptr, + uint8_t next_functionality = num_backends, + uint8_t next_backend = 0) + : data_ptr_(data_ptr), + next_functionality_(next_functionality), + next_backend_(next_backend), + // These are in an invalid state at construction time, and set by the + // first increment call + current_dispatchkey_idx_(end_iter_key_val), + current_backendcomponent_idx_(end_iter_key_val) { + // Go to the first key in the set + TORCH_INTERNAL_ASSERT( + next_functionality_ >= num_backends, + "num_backends=", + static_cast(num_backends), + "next_functionality_=", + static_cast(next_functionality_)); + ++(*this); + } + + C10_API self_type& operator++(); + + self_type operator++(int) { + self_type previous_iterator = *this; + ++(*this); + return previous_iterator; + } + + bool operator==(const self_type& rhs) const { + return next_functionality_ == rhs.next_functionality_ && + current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ && + next_backend_ == rhs.next_backend_ && + current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_; + } + bool operator!=(const self_type& rhs) const { + return next_functionality_ != rhs.next_functionality_ || + current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ || + next_backend_ != rhs.next_backend_ || + current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_; + } + DispatchKey operator*() const { + auto functionality_key = + static_cast(current_dispatchkey_idx_); + if (isPerBackendFunctionalityKey(functionality_key)) { + auto next_key = toRuntimePerBackendFunctionalityKey( + functionality_key, + static_cast(current_backendcomponent_idx_)); + // We expect all of the Dense, Sparse, Quantized, and Autograd keys to + // be ordered the same way with respect to their backends + TORCH_INTERNAL_ASSERT( + toBackendComponent(next_key) == + static_cast(current_backendcomponent_idx_), + "Tried to map functionality key ", + toString(functionality_key), + " and backend bit ", + toString( + static_cast(current_backendcomponent_idx_)), + " to a runtime key, but ended up with ", + toString(next_key), + ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.", + " Please double check that enum for inconsistencies."); + return next_key; + } else { + return functionality_key; + } + } + + private: + const uint64_t* data_ptr_; + uint8_t next_functionality_; + uint8_t next_backend_; + uint8_t current_dispatchkey_idx_; + uint8_t current_backendcomponent_idx_; + }; + + public: + // Returns iterator to the first key in the set. If no keys are in the + // set, then will return the end iterator. + iterator begin() const { + return iterator(&repr_); + } + + // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat + // this as the end iterator. + iterator end() const { + return iterator(&repr_, iterator::end_iter_mask_val); + } +}; + +C10_API std::string toString(DispatchKeySet); +C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); + +C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { + return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet(); +} + +// Alias key DispatchKey::Autograd maps to +// (autograd_dispatch_keyset x full_backend_mask) +// NB: keys in this set also get associated with CompositeImplicitAutograd +// +// Note [autograd_dispatch_keyset Does Not Include Backend Bits] +// We don't want to include any backend bits (BackendComponent::CPUBit, etc) +// directly in autograd_dispatch_keyset. +// Why? 
keysets like autograd_dispatch_keyset are commonly used to remove +// autograd keys from a DispatchKeySet throughout the code base. However, you +// are only allowed to remove functionality bits from a keyset, not backend +// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality +// Keys] for details. To be consistent and avoid confusion, we're explicitly +// setting up autograd_dispatch_keyset to not have any backend bits. +constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ + DispatchKey::AutogradFunctionality, + DispatchKey::AutogradOther, + DispatchKey::AutogradNestedTensor, +}); + +constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ + DispatchKey::AutocastCPU, + DispatchKey::AutocastCUDA, + DispatchKey::AutocastXPU, + DispatchKey::AutocastIPU, + DispatchKey::AutocastHPU, + DispatchKey::AutocastXLA, + DispatchKey::AutocastPrivateUse1, +}); + +// See Note [TLS Initialization] +constexpr DispatchKeySet default_included_set = DispatchKeySet({ + DispatchKey::BackendSelect, + DispatchKey::ADInplaceOrView, +}); + +constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ + DispatchKey::AutocastCPU, + DispatchKey::AutocastCUDA, + DispatchKey::AutocastXPU, + DispatchKey::AutocastIPU, + DispatchKey::AutocastHPU, + DispatchKey::AutocastXLA, + DispatchKey::AutocastPrivateUse1, +}); + +constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView = + autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView); + +constexpr DispatchKeySet python_ks = DispatchKeySet({ + DispatchKey::Python, + DispatchKey::PythonTLSSnapshot, +}); + +constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse); + +constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr); + +constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU); + +// backend dispatch keys that map to DispatchKey::AutogradOther +// NB: keys in this set also get associated with CompositeImplicitAutograd +constexpr DispatchKeySet autogradother_backends = + DispatchKeySet( + // HIP and VE aren't in this list: they now have their own backend bits + // which means that they can now have their own Autograd keys. + // Technically, HIP will now redispatch to its own custom AutogradHIP + // slot in the runtime table. + {DispatchKey::FPGA, + DispatchKey::ORT, + DispatchKey::Vulkan, + DispatchKey::Metal, + DispatchKey::CustomRNGKeyId, + DispatchKey::MkldnnCPU, + // Sparse and Quantized backends also live here. + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::Quantized}) + // Including the backend bits because this keyset is used during op + // registration, which requires looping over all runtime autogradother + // backend keys. + | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); + +// The set of dispatch keys that come after autograd +// n.b. 
this relies on the fact that AutogradOther is currently the lowest +// Autograd key +constexpr DispatchKeySet after_autograd_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther); + +// The set of dispatch keys that come after ADInplaceOrView +constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet( + DispatchKeySet::FULL_AFTER, + c10::DispatchKey::ADInplaceOrView); + +// The set of dispatch keys that come after Functionalize +constexpr DispatchKeySet after_func_keyset = + DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize) + .remove( + // NOTE: we also need to remove ADInplaceOrView from the keyset when + // redispatching after the func kernels. This is because we're not + // calling the same op; we originally called an inplace op, and now + // we aren't. The original key calculation figured out which keys + // were Fallthrough based on the inplace op. That means that it did + // not include the ADInPlaceOrView kernel as a fallthrough key. + // However, we WANT the ADInPlaceOrView kernel to be ignored now + // that we're calling an out-of-place op. Re-invoking + // Dispatcher::call would re-run the Fallthrough key calculation and + // get us that, But at::redispatch is more performant. We can get + // away with it by explicitly removing the key here. + c10::DispatchKey::ADInplaceOrView); + +constexpr DispatchKeySet backend_bitset_mask = + DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1); + +constexpr auto inplace_or_view_ks = + DispatchKeySet(DispatchKey::ADInplaceOrView); +constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU); +constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU); +constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU); +constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA); +constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA); +constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy); +constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta); +constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS); +constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU); +constexpr auto autograd_privateuse1_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse1); +constexpr auto autograd_privateuse2_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse2); +constexpr auto autograd_privateuse3_ks = + DispatchKeySet(DispatchKey::AutogradPrivateUse3); +constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther); +constexpr auto autograd_nested = + DispatchKeySet(DispatchKey::AutogradNestedTensor); +// keyset corresponding to functorch keys that have their own dedicated +// TensorImpl subclass. 
+constexpr auto functorch_transforms_ks = DispatchKeySet( + {DispatchKey::FuncTorchBatched, + DispatchKey::FuncTorchVmapMode, + DispatchKey::Batched, + DispatchKey::VmapMode, + DispatchKey::FuncTorchGradWrapper}); + +constexpr auto functorch_batched_ks = + DispatchKeySet({DispatchKey::FuncTorchBatched}); + +// This keyset has: +// (1) the functionality bits corresponding to backends (dense, sparse, +// quantized) (2) all of the backend bits set +constexpr DispatchKeySet backend_functionality_keys = + DispatchKeySet({ + DispatchKey::Dense, + DispatchKey::Quantized, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + }) | + DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); + +struct OpTableOffsetAndMask { + uint16_t offset; + uint16_t backend_mask; +}; + +static_assert( + num_backends <= 16, + "Right now we expect the number of backends not to exceed 16. In the (unlikely) event" + " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too."); + +// true if t is a backend dispatch key +C10_API bool isBackendDispatchKey(DispatchKey t); + +// Resolve alias dispatch key to DispatchKeySet if applicable +C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); + +// Resolve alias dispatch key to DispatchKeySet if applicable, +// and check if k is a part of that set +C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k); + +// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key +// t, DispatchKeySet is empty if t is not alias of DispatchKey::Autograd. +C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); + +// Returns a DispatchKeySet of autograd related keys mapped to backend. +// for a given backend key, use the associated autograd key. +// for non-backend keys, use AutogradOther as a default. +// Note: it's convenient and fast to return a default here rather than (say) +// returning an optional, or throwing. But it makes callers +// responsible for either a) enforcing the invariant that only backend keys +// be passed as arguments, or b) interpreting our return value carefully. +inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) { + switch (t) { + case BackendComponent::CPUBit: + return inplace_or_view_ks | autograd_cpu_ks; + case BackendComponent::IPUBit: + return inplace_or_view_ks | autograd_ipu_ks; + case BackendComponent::XPUBit: + return inplace_or_view_ks | autograd_xpu_ks; + case BackendComponent::CUDABit: + return inplace_or_view_ks | autograd_cuda_ks; + case BackendComponent::XLABit: + return inplace_or_view_ks | autograd_xla_ks; + case BackendComponent::LazyBit: + return inplace_or_view_ks | autograd_lazy_ks; + case BackendComponent::MetaBit: + return inplace_or_view_ks | autograd_meta_ks; + case BackendComponent::MPSBit: + return inplace_or_view_ks | autograd_mps_ks; + case BackendComponent::HPUBit: + return inplace_or_view_ks | autograd_hpu_ks; + case BackendComponent::PrivateUse1Bit: + return inplace_or_view_ks | autograd_privateuse1_ks; + case BackendComponent::PrivateUse2Bit: + return inplace_or_view_ks | autograd_privateuse2_ks; + case BackendComponent::PrivateUse3Bit: + return inplace_or_view_ks | autograd_privateuse3_ks; + default: + return inplace_or_view_ks | autograd_other_ks; + } +} + +// Returns a DispatchKeySet of autocast related keys mapped to backend. 
+inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { + constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU); + constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU); + constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU); + constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU); + constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA); + constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA); + constexpr auto autocast_privateuse1_ks = + DispatchKeySet(DispatchKey::AutocastPrivateUse1); + switch (t) { + case BackendComponent::CPUBit: + return autocast_cpu_ks; + case BackendComponent::XPUBit: + return autocast_xpu_ks; + case BackendComponent::IPUBit: + return autocast_ipu_ks; + case BackendComponent::HPUBit: + return autocast_hpu_ks; + case BackendComponent::CUDABit: + return autocast_cuda_ks; + case BackendComponent::XLABit: + return autocast_xla_ks; + case BackendComponent::PrivateUse1Bit: + return autocast_privateuse1_ks; + default: + return DispatchKeySet(); + } +} + +// returns the "backend" DispatchKey of highest priority in the set. +// This is basically like highestBackendKey(), except that we have some +// "functionality" bits that correspond to backends (Sparse, Quantized) +inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) { + return (ks & backend_functionality_keys).highestPriorityTypeId(); +} + +// This API exists because we have a use case for checking +// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) +// in OperatorEntry.cpp but we disallow it in has() API. +C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); + +// Historically, every tensor only had a single DispatchKey, and it was always +// something like CPU, and there wasn't any of this business where TLS +// could cause the DispatchKey of a tensor to change. But we still have some +// legacy code that is still using DispatchKey for things like instanceof +// checks; if at all possible, refactor the code to stop using DispatchKey in +// those cases. +static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { + // NB: If you add any extra keys that can be stored in TensorImpl on + // top of existing "backend" keys like CPU/CUDA, you need to add it + // here. At the moment, autograd keys and ADInplaceOrView key need this + // treatment; + return (s - autograd_dispatch_keyset_with_ADInplaceOrView - + autocast_dispatch_keyset - + DispatchKeySet( + {DispatchKey::Functionalize, + DispatchKey::PythonTLSSnapshot, + DispatchKey::Python})) + .highestPriorityTypeId(); +} + +template +using is_not_DispatchKeySet = std::negation>; + +// Given a function type, constructs a function_traits type that drops the first +// parameter type if the first parameter is of type DispatchKeySet. NB: +// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid +// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through +// the Dispatcher] for details). If at any point in the future we need to expose +// this type to JIT, revisit the usage of this type alias. 
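+//
+// Editorial sketch (not part of the upstream header), for a hypothetical
+// kernel signature `Tensor(DispatchKeySet, const Tensor&)`:
+//
+//   using Traits = remove_DispatchKeySet_arg_from_func<Tensor(DispatchKeySet, const Tensor&)>;
+//
+// Traits then describes `Tensor(const Tensor&)`, i.e. the leading
+// DispatchKeySet parameter is dropped before the kernel signature is exposed
+// further up the stack.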
+template +using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t< + typename guts::infer_function_traits_t::return_type, + typename std::conditional_t< + std::is_same_v< + DispatchKeySet, + typename guts::typelist::head_with_default_t< + void, + typename guts::infer_function_traits_t< + FuncType>::parameter_types>>, + guts::typelist::drop_if_nonempty_t< + typename guts::infer_function_traits_t::parameter_types, + 1>, + typename guts::infer_function_traits_t::parameter_types>>; +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/DynamicCast.h b/MLPY/Lib/site-packages/torch/include/c10/core/DynamicCast.h new file mode 100644 index 0000000000000000000000000000000000000000..65a5e4f3b66ff99829ac8089f07dd6eb5c97b699 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/DynamicCast.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10 { + +// Dynamic type casting utils: +// - fetch_and_cast +// - cast_and_store +// +// fetch_and_cast fetch a value with dynamic type specified by a ScalarType +// from a void pointer and cast it to a static type. +// +// cast_and_store casts a static typed value into dynamic type specified +// by a ScalarType, and store it into a void pointer. +// +// NOTE: +// +// Dynamic casting allows us to support type promotion without blowing up +// the combination space: For example, without dynamic cast, in order to +// implement `add_` with type promotion, we would need something like +// +// AT_DISPATCH_ALL_TYPES(output.dtype(), +// AT_DISPATCH_ALL_TYPES(input1.dtype(), +// AT_DISPATCH_ALL_TYPES(input2.dtype(), +// [](arg0_t a, arg1_t b) -> out_t { return a + b; } +// ) +// ) +// ) +// +// If we support N dtypes, the above code would generate the a+b kernel for +// all the N * N * N different supported types, the compilation time and +// binary size would become horrible. +// +// Dynamic casting might sounds like a bad idea in terms of performance. +// Especially if you ever do it in a loop, you are going to do a billion tests. +// But in practice it is not as bad as it might look: +// +// - on CPU, this is a branch that always has the same outcome, therefore +// hopefully the branch predictor could do the job pretty well +// - on GPU, these branches will not diverge, so we could still have the same +// warp executing the same line of code +// - Most kernels, like `add`, are bandwidth bound, adding a few clock cycles to +// check an integer does not hurt the performance much because the ALUs would +// wait for load instructions anyway. +// +// For the discussion and benchmark, refer to: +// - https://github.com/pytorch/pytorch/pull/28343 +// - https://github.com/pytorch/pytorch/pull/28344 +// - https://github.com/pytorch/pytorch/pull/28345 +// + +#ifdef C10_HOST_DEVICE +#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false); +#else +#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type"); +#endif + +// Fetch a value with dynamic type src_type from ptr, and cast it to static type +// dest_t. 
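+//
+// Editorial usage sketch (not part of the upstream header): a type-erased
+// elementwise loop can promote an operand to double, compute, and then store
+// the result back in the output's dynamic dtype, e.g.
+//
+//   double a = fetch_and_cast<double>(src_type, src_ptr);
+//   cast_and_store<double>(dest_type, dest_ptr, a + 1.0);
+//
+// where src_type/dest_type are ScalarType values and src_ptr/dest_ptr are
+// assumed to point at correctly typed elements.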
+#define FETCH_AND_CAST_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + return c10::convert(c10::load(ptr)); + +template +C10_HOST_DEVICE inline dest_t fetch_and_cast( + const ScalarType src_type, + const void* ptr) { + switch (src_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE) + FETCH_AND_CAST_CASE(uint16_t, UInt16) + FETCH_AND_CAST_CASE(uint32_t, UInt32) + FETCH_AND_CAST_CASE(uint64_t, UInt64) + default: + ERROR_UNSUPPORTED_CAST + } + return dest_t(0); // just to avoid compiler warning +} + +// Cast a value with static type src_t into dynamic dest_type, and store it to +// ptr. +#define CAST_AND_STORE_CASE(type, scalartype) \ + case ScalarType::scalartype: \ + *(type*)ptr = c10::convert(value); \ + return; +template +C10_HOST_DEVICE inline void cast_and_store( + const ScalarType dest_type, + void* ptr, + src_t value) { + switch (dest_type) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE) + CAST_AND_STORE_CASE(uint16_t, UInt16) + CAST_AND_STORE_CASE(uint32_t, UInt32) + CAST_AND_STORE_CASE(uint64_t, UInt64) + default:; + } + ERROR_UNSUPPORTED_CAST +} + +#define DEFINE_UNCASTABLE(T, scalartype_) \ + template <> \ + C10_HOST_DEVICE inline T fetch_and_cast( \ + const ScalarType src_type, const void* ptr) { \ + CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type); \ + return c10::load(ptr); \ + } \ + template <> \ + C10_HOST_DEVICE inline void cast_and_store( \ + const ScalarType dest_type, void* ptr, T value) { \ + CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \ + *(T*)ptr = value; \ + } + +AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE) + +#undef FETCH_AND_CAST_CASE +#undef CAST_AND_STORE_CASE +#undef DEFINE_UNCASTABLE +#undef ERROR_UNSUPPORTED_CAST + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Event.h b/MLPY/Lib/site-packages/torch/include/c10/core/Event.h new file mode 100644 index 0000000000000000000000000000000000000000..475aca5fdf252730a88cc40de2ab56c938827335 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Event.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * A backend-generic movable, not copyable, not thread-safe event. + * + * The design of this event follows that of CUDA and HIP events. These events + * are recorded and waited on by streams and can be rerecorded to, + * each rerecording essentially creating a new version of the event. + * For example, if (in CPU time), stream X is asked to record E, + * stream Y waits on E, and stream X is asked to record E again, then Y will + * wait for X to finish the first call to record and not the second, because + * it's waiting on the first version of event E, not the second. + * Querying an event only returns the status of its most recent version. + * + * Backend-generic events are implemented by this class and + * impl::InlineEvent. In addition to these events there are also + * some backend-specific events, like ATen's CUDAEvent. Each of these + * classes has its own use. + * + * impl::InlineEvent<...> or a backend-specific event should be + * preferred when the backend is known at compile time and known to + * be compiled. Backend-specific events may have additional functionality. + * + * This Event should be used if a particular backend may not be available, + * or the backend required is not known at compile time. + * + * These generic events are built on top of DeviceGuardImpls, analogous + * to DeviceGuard and InlineDeviceGuard. 
The name "DeviceGuardImpls," + * is no longer entirely accurate, as these classes implement the + * backend-specific logic for a generic backend interface. + * + * See DeviceGuardImplInterface.h for a list of all supported flags. + */ + +struct Event final { + // Constructors + Event() = delete; + Event( + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : impl_{_device_type, _flag} {} + + // Copy constructor and copy assignment operator (deleted) + Event(const Event&) = delete; + Event& operator=(const Event&) = delete; + + // Move constructor and move assignment operator + Event(Event&&) noexcept = default; + Event& operator=(Event&&) noexcept = default; + + // Destructor + ~Event() = default; + + // Getters + Device device() const noexcept { + return Device(device_type(), device_index()); + } + DeviceType device_type() const noexcept { + return impl_.device_type(); + } + DeviceIndex device_index() const noexcept { + return impl_.device_index(); + } + EventFlag flag() const noexcept { + return impl_.flag(); + } + bool was_marked_for_recording() const noexcept { + return impl_.was_marked_for_recording(); + } + + /** + * Calls record() if and only if record() has never been called for this + * event. Note: because Event is not thread-safe recordOnce() may call + * record() multiple times if called from multiple threads. + */ + void recordOnce(const Stream& stream) { + impl_.recordOnce(stream); + } + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record(const Stream& stream) { + impl_.record(stream); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(const Stream& stream) const { + impl_.block(stream); + } + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool query() const { + return impl_.query(); + } + + private: + impl::InlineEvent impl_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/GeneratorImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/GeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..089dd1dba43611df95e1ed6076e2e1eb9149aa0a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/GeneratorImpl.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +/** + * Note [Generator] + * ~~~~~~~~~~~~~~~~ + * A Pseudo Random Number Generator (PRNG) is an engine that uses an algorithm + * to generate a seemingly random sequence of numbers, that may be later be used + * in creating a random distribution. Such an engine almost always maintains a + * state and requires a seed to start off the creation of random numbers. 
Often + * times, users have found it beneficial to be able to explicitly create, + * retain, and destroy PRNG states and also be able to have control over the + * seed value. + * + * A Generator in ATen gives users the ability to read, write and modify a PRNG + * engine. For instance, it does so by letting users seed a PRNG engine, fork + * the state of the engine, etc. + * + * By default, there is one generator per device, and a device's generator is + * lazily created. A user can use the torch.Generator() api to create their own + * generator. Currently torch.Generator() can only create a CPUGeneratorImpl. + */ + +/** + * Note [Acquire lock when using random generators] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Generator and its derived classes are NOT thread-safe. Please note that most + * of the places where we have inserted locking for generators are historically + * based, and we haven't actually checked that everything is truly thread safe + * (and it probably isn't). Please use the public mutex_ when using any methods + * from these classes, except for the read-only methods. You can learn about the + * usage by looking into the unittests (aten/src/ATen/cpu_generator_test.cpp) + * and other places where we have used lock_guard. + * + * TODO: Look into changing the threading semantics of Generators in ATen (e.g., + * making them non-thread safe and instead making the generator state + * splittable, to accommodate forks into other threads). + */ + +namespace c10 { + +// The default seed is selected to be a large number +// with good distribution of 0s and 1s in bit representation +constexpr uint64_t default_rng_seed_val = 67280421310721; + +struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { + // Constructors + GeneratorImpl(Device device_in, DispatchKeySet key_set); + + // Delete all copy and move assignment in favor of clone() + // method + GeneratorImpl(const GeneratorImpl& other) = delete; + GeneratorImpl(GeneratorImpl&& other) = delete; + GeneratorImpl& operator=(const GeneratorImpl& other) = delete; + + ~GeneratorImpl() override = default; + c10::intrusive_ptr clone() const; + + // Common methods for all generators + virtual void set_current_seed(uint64_t seed) = 0; + virtual void set_offset(uint64_t offset) = 0; + virtual uint64_t get_offset() const = 0; + virtual uint64_t current_seed() const = 0; + virtual uint64_t seed() = 0; + virtual void set_state(const c10::TensorImpl& new_state) = 0; + virtual c10::intrusive_ptr get_state() const = 0; + Device device() const; + + // See Note [Acquire lock when using random generators] + std::mutex mutex_; + + DispatchKeySet key_set() const { + return key_set_; + } + + inline void set_pyobj(PyObject* pyobj) noexcept { + pyobj_ = pyobj; + } + + inline PyObject* pyobj() const noexcept { + return pyobj_; + } + + protected: + Device device_; + DispatchKeySet key_set_; + PyObject* pyobj_ = nullptr; + + virtual GeneratorImpl* clone_impl() const = 0; +}; + +namespace detail { + +C10_API uint64_t getNonDeterministicRandom(bool is_cuda = false); + +} // namespace detail + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/GradMode.h b/MLPY/Lib/site-packages/torch/include/c10/core/GradMode.h new file mode 100644 index 0000000000000000000000000000000000000000..d49c9fdacd38d433b967e6ee0107f83e83e3feee --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/GradMode.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +namespace c10 { + +struct C10_API GradMode { + static bool is_enabled(); + 
static void set_enabled(bool enabled); +}; + +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct C10_API AutoGradMode { + AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { + GradMode::set_enabled(enabled); + } + ~AutoGradMode() { + GradMode::set_enabled(prev_mode); + } + bool prev_mode; +}; + +// A RAII, thread local (!) guard that stops future operations from building +// gradients. +struct C10_API NoGradGuard : public AutoGradMode { + NoGradGuard() : AutoGradMode(/*enabled=*/false) {} +}; + +// A RAII, thread local (!) guard that enables or disables forward grad mode +// upon construction, and sets it back to the original value upon destruction. +struct C10_API AutoFwGradMode { + AutoFwGradMode(bool enabled) + : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { + AutogradState::get_tls_state().set_fw_grad_mode(enabled); + } + ~AutoFwGradMode() { + AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); + } + bool prev_mode; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/InferenceMode.h b/MLPY/Lib/site-packages/torch/include/c10/core/InferenceMode.h new file mode 100644 index 0000000000000000000000000000000000000000..ecbbdbded7ecd7fc6adf48a67618f075dddbd310 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/InferenceMode.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +// A RAII, thread local (!) guard that enables or disables inference mode upon +// construction, and sets it back to the original value upon destruction. +struct C10_API InferenceMode { + // Note [Expected TLS state in InferenceMode]: + // InferenceMode: ADInplaceOrView not in + // raw_local_dispatch_key_set.included(), + // Autograd in raw_local_dispatch_key_set.excluded() + // GradMode is disabled. + // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(), + // Autograd not in raw_local_dispatch_key_set.excluded() + // GradMode is enabled by default unless toggled manually + // through other APIs, e.g. NoGradGuard. + // + // Invariant: + // - ADInplaceOrView is never in the excluded set + // - Autograd is never in the included set + // - Setting InferenceMode will set GradMode accordingly, but not vice versa. + // + // 1. Why do we put ADInplaceOrView in included set outside InferenceMode? + // + // Inplace update to inference tensor outside InferenceMode is not + // allowed. See Note [Inplace update inference tensor] for more details. + // Without going through ADInplaceOrView kernel, we cannot throw error + // for `inference_tensor.add_(1)` case. + // + // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode? + // + // For example: + // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); + // torch::Tensor k = a + 2; + // { + // c10::InferenceMode guard(true); + // k.add_(2); + // } + // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's + // prepared for future autograd. + // + // 3. Why does setting InferenceMode also set GradMode? + // + // This is required since InferenceMode is a faster and more restrictive + // version of NoGradGuard. All runtime checks using GradMode::is_enabled() + // are applicable to InferenceMode as well, e.g. + // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. 
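+  //
+  // Editorial usage sketch (not part of the upstream header): the guard is
+  // scoped, and nesting a guard constructed with `false` restores normal
+  // mode inside an outer inference region:
+  //
+  //   {
+  //     c10::InferenceMode guard;            // inference mode on, grad off
+  //     // tensors created here are inference tensors
+  //     {
+  //       c10::InferenceMode normal(false);  // back to normal mode
+  //     }
+  //   }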
+ InferenceMode(bool enabled = true) + : prev_mode(AutogradState::get_tls_state()), + prev_keyset(c10::impl::tls_local_dispatch_key_set()) { + // Enabling inference mode means disabling grad modes + // And disabling inference mode means enabling grad modes + AutogradState::set_tls_state(AutogradState( + /* grad_mode */ !enabled, + /* inference_mode */ enabled, + /* fw_grad_mode */ !enabled, + /* multithreading_enabled*/ !enabled)); + DispatchKeySet included = enabled + ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) + : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); + DispatchKeySet excluded = enabled + ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) + : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); + c10::impl::PODLocalDispatchKeySet cur_keyset{}; + cur_keyset.set_included(included); + cur_keyset.set_excluded(excluded); + c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); + } + + ~InferenceMode() { + AutogradState::set_tls_state(prev_mode); + c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); + } + static bool is_enabled(); + + private: + AutogradState prev_mode; + c10::impl::LocalDispatchKeySet prev_keyset; +}; +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Layout.h b/MLPY/Lib/site-packages/torch/include/c10/core/Layout.h new file mode 100644 index 0000000000000000000000000000000000000000..3ec87697d18b902a5e3ec112911a5f4c615de678 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Layout.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +inline Layout layout_from_backend(Backend backend) { + switch (backend) { + case Backend::SparseCPU: + case Backend::SparseCUDA: + case Backend::SparseHIP: + case Backend::SparseVE: + case Backend::SparseXPU: + case Backend::SparsePrivateUse1: + return Layout::Sparse; + case Backend::MkldnnCPU: + return Layout::Mkldnn; + case Backend::SparseCsrCPU: + case Backend::SparseCsrCUDA: + case Backend::SparseCsrHIP: + case Backend::SparseCsrVE: + case Backend::SparseCsrXPU: + TORCH_CHECK( + false, + "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout."); + default: + return Layout::Strided; + } +} + +inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) { + switch (layout) { + case at::kStrided: + return stream << "Strided"; + case at::kSparse: + return stream << "Sparse"; + case at::kSparseCsr: + return stream << "SparseCsr"; + case at::kSparseCsc: + return stream << "SparseCsc"; + case at::kSparseBsr: + return stream << "SparseBsr"; + case at::kSparseBsc: + return stream << "SparseBsc"; + case at::kMkldnn: + return stream << "Mkldnn"; + case at::kJagged: + return stream << "Jagged"; + default: + TORCH_CHECK(false, "Unknown layout"); + } +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/MemoryFormat.h b/MLPY/Lib/site-packages/torch/include/c10/core/MemoryFormat.h new file mode 100644 index 
0000000000000000000000000000000000000000..55d25bd24d35b5cc50afce435042bc85d4f7fb76 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/MemoryFormat.h @@ -0,0 +1,290 @@ +#pragma once + +#include +#include + +#include +#include +#include + +// Memory format is not the property of a Tensor. It is the way to tell an +// operator how the result should be organized in memory and nothing more. That +// means memory format should never be used as return value for any tensor state +// interrogation functions (internally and externally). +// +// Possible options are: +// Preserve: +// If any of the input tensors is in channels_last format, operator output +// should be in channels_last format +// +// Contiguous: +// Regardless of input tensors format, the output should be contiguous +// Tensor. +// +// ChannelsLast: +// Regardless of input tensors format, the output should be in channels_last +// format. + +namespace c10 { +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +// If you are seeing this, it means that this call site was not checked if +// the memory format could be preserved, and it was switched to old default +// behaviour of contiguous +#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() + +inline MemoryFormat get_contiguous_memory_format() { + return MemoryFormat::Contiguous; +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::MemoryFormat memory_format) { + switch (memory_format) { + case MemoryFormat::Preserve: + return stream << "Preserve"; + case MemoryFormat::Contiguous: + return stream << "Contiguous"; + case MemoryFormat::ChannelsLast: + return stream << "ChannelsLast"; + case MemoryFormat::ChannelsLast3d: + return stream << "ChannelsLast3d"; + default: + TORCH_CHECK(false, "Unknown memory format ", memory_format); + } +} + +// Note: Hardcoded the channel last stride indices here to get better +// performance +template +inline std::vector get_channels_last_strides_2d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 4: + strides[1] = 1; + strides[3] = sizes[1]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 3: + strides[0] = 1; + strides[2] = sizes[0]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast2d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_2d(IntArrayRef sizes) { + return get_channels_last_strides_2d(sizes); +} + +template +std::vector get_channels_last_strides_3d(ArrayRef sizes) { + std::vector strides(sizes.size()); + switch (sizes.size()) { + case 5: + strides[1] = 1; + strides[4] = sizes[1]; + strides[3] = strides[4] * sizes[4]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 4: + strides[0] = 1; + strides[3] = sizes[0]; + strides[2] = strides[3] * sizes[3]; + strides[1] = strides[2] * sizes[2]; + return strides; + default: + TORCH_INTERNAL_ASSERT( + false, "ChannelsLast3d doesn't support size ", sizes.size()); + } +} + +inline std::vector get_channels_last_strides_3d(IntArrayRef sizes) { + return get_channels_last_strides_3d(sizes); +} + +// NOTE: +// Below are Helper functions for is_channels_last_strides_xd. +// 1. 
Please do not combine these helper functions, each helper function handles +// exactly one case of sizes + memory_format, by doing this, the strides indices +// will be a constant array and we can access it using constant index number, +// the compiler will fully unroll the loop on strides indices to gain a better +// performance. +// 2. No error check in helper function, caller ensures the correctness of the +// input +// 3. All helper functions have similar comments, only 1st helper function is +// commented here. +template +inline bool is_channels_last_strides_2d_s4( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + // special case for trivial C dimension. default to NCHW + if (strides[1] == 0) { + return false; + } + // loop strides indices + for (auto& d : {1, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + // Fallback to NCHW as default layout for ambiguous cases + // This is the flaw of implicit memory_format from strides. + // N111 tensor with identical strides for size 1 dimension; + // Two cases could lead us here: + // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + // b. N11W contiguous Tensor sliced on the W-dimension. + // ([N,1,1,1]@[W,W,W,W]) + if (d == 0 && min == strides[1]) { + return false; + } + // This is necessary to: + // 1. distinguish the memory_format of N1H1; + // [H, 1, 1, 1] channels_last stride + // [H, H, 1, 1] contiguous stride + // 2. permutation of 1C1W: + // [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + // [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +template +inline bool is_channels_last_strides_3d_s5( + const ArrayRef sizes, + const ArrayRef strides) { + T min = 0; + if (strides[1] == 0) { + return false; + } + for (auto& d : {1, 4, 3, 2, 0}) { + if (sizes[d] == 0) { + return false; + } + if (strides[d] < min) { + return false; + } + if (d == 0 && min == strides[1]) { + return false; + } + min = strides[d]; + if (sizes[d] > 1) { + min *= sizes[d]; + } + } + return true; +} + +// Note [Ambiguous is_channels_last_strides_xd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The flaw of carrying memory_format implicitly through strides is very hard +// to WAR properly. issue #24090 +// Without the history of permutation, we can't infer the memory_format of a +// tensor from the snapshot of its size & stride +// e.g. +// +// 1. We can NOT specify the memory_format of N111 tensor through strides in a +// meaningful way; +// +// 2. Two path that ended up with identical size/stride +// N11W contiguous tensor sliced at w-dimension becomes [N,1,1,1]@[W,W,W,W] +// NC11 channels_last tensor sliced at c-dimension becomes [N,1,1,1]@[C,C,C,C] +// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer +// the memory_format of the original tensor. +// +// Due to the limitations, our temporary WAR `is_channels_last_strides` does the +// best effort to infer whether the original memory_format of a tensor is +// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered +// by their importance): +// 1. Ensure that normal shape manipulation does not accidentally change the +// MemoryFormat of an existing tensor. +// 2. Allows user to mark MemoryFormat::ChannelsLast to tensors; +// +// The function does so via checking strides of the tensor, including strides of +// size-1 dimensions. 
Although conventionally PyTorch implies no restriction on +// trivial stride (stride for size-1 dimension). +// +// Note that this approach is a compromise. We did not solve the problem +// completely. Many cases we will not be able to infer the correct memory +// format. +// The implementation of `is_channels_last_strides` is to serve the objectives: +// MemoryFormat::ChannelsLast has to be explicitly opted-in (no accidental +// conversion); Best effort to maintain the ChannelsLast flag. +// +// Due to the fact that this is not a bulletproof solution, through testing +// (aten/src/ATen/test/memory_format_test.cpp) +// a. we ensure that the common tasks are supported; +// a. we identify corner cases where the implementation compromises on. +// +// By the time accumulated permutation is enabled to replace implicit +// memory_format through strides, we should be updating our tests and fix the +// issues in our tests. +// +// We use Channels Last 2d as an example above. +// This is a general problem for all the is_channels_last_strides_xd +// implementation. Please check the helper functions +// (is_channels_last_strides_*d_s*) for more details. + +template +inline bool is_channels_last_strides_2d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 4: + return is_channels_last_strides_2d_s4(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 3: + // TODO dim == 3 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +template +inline bool is_channels_last_strides_3d( + const ArrayRef sizes, + const ArrayRef strides) { + switch (sizes.size()) { + case 5: + return is_channels_last_strides_3d_s5(sizes, strides); + // NOLINTNEXTLINE(bugprone-branch-clone) + case 4: + // TODO dim == 4 case will be enabled once it is fully tested + return false; + default: + return false; + } +} + +inline bool is_channels_last_strides_2d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_2d(sizes, strides); +} + +inline bool is_channels_last_strides_3d( + const IntArrayRef sizes, + const IntArrayRef strides) { + return is_channels_last_strides_3d(sizes, strides); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/OptionalRef.h b/MLPY/Lib/site-packages/torch/include/c10/core/OptionalRef.h new file mode 100644 index 0000000000000000000000000000000000000000..84c843ec68164c1dcf01990a0319b25f8cd674ec --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/OptionalRef.h @@ -0,0 +1,31 @@ +#pragma once + +namespace c10 { + +template +class OptionalRef { + public: + OptionalRef() : data_(nullptr) {} + OptionalRef(const T* data) : data_(data) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_); + } + OptionalRef(const T& data) : data_(&data) {} + + bool has_value() const { + return data_ != nullptr; + } + + const T& get() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data_); + return *data_; + } + + operator bool() const { + return has_value(); + } + + private: + const T* data_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/PyHandleCache.h b/MLPY/Lib/site-packages/torch/include/c10/core/PyHandleCache.h new file mode 100644 index 0000000000000000000000000000000000000000..37245dbed26c4afc2700089a6cf06a34d2a12d8c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/PyHandleCache.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace c10 { + +// A PyHandleCache 
represents a cached pointer from a C++ object to +// a Python object that represents that object analogously in Python. +// Upon a cache hit, the relevant object can be retrieved after a test +// and then a memory load. Two conditions must hold to be able to use this +// class: +// +// - This must truly be a cache; e.g., the caller must be able to produce +// the object some other way if the cache hit misses. +// +// - This must truly be a handle; e.g., the Python object referenced by +// this class must have static lifetime. This means we don't have to +// maintain strong ownership or deallocate the object when the C++ object +// dies. Static lifetime is a good idea in conjunction with the cache, +// since if you are producing a fresh object on miss you won't be +// maintaining object identity. If you need bidirectional ownership, +// you will want to factor out the pattern in TensorImpl with +// resurrection. +// +// This cache is expected to not improve perf under torchdeploy, as one +// interpreter will fill up the cache, and all the interpreters will be +// unable to use the slot. A potential improvement is to have multiple +// slots (one per interpreter), which will work in deployment scenarios +// where there a stable, fixed number of interpreters. You can also store +// the relevant state in the Python library, rather than in the non-Python +// library (although in many cases, this is not convenient, as there may +// not be a way to conveniently index based on the object.) +class PyHandleCache { + public: + PyHandleCache() : pyinterpreter_(nullptr) {} + + // Attempt to fetch the pointer from the cache, if the PyInterpreter + // matches. If it doesn't exist, or the cache entry is not valid, + // use slow_accessor to get the real pointer value and return that + // (possibly writing it to the cache, if the cache entry is + // available.) + template + PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor) + const { + // Note [Memory ordering on Python interpreter tag] + impl::PyInterpreter* interpreter = + pyinterpreter_.load(std::memory_order_acquire); + if (C10_LIKELY(interpreter == self_interpreter)) { + return data_; + } else if (interpreter == nullptr) { + auto* r = slow_accessor(); + impl::PyInterpreter* expected = nullptr; + // attempt to claim this cache entry with the specified interpreter tag + if (pyinterpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + data_ = r; + } + // This shouldn't be possible, as you should be GIL protected + TORCH_INTERNAL_ASSERT(expected != self_interpreter); + return r; + } else { + return slow_accessor(); + } + } + + private: + mutable std::atomic pyinterpreter_; + mutable PyObject* data_{nullptr}; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/QEngine.h b/MLPY/Lib/site-packages/torch/include/c10/core/QEngine.h new file mode 100644 index 0000000000000000000000000000000000000000..b8a0ac9639303e3ce466db30c53e02739ef68224 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/QEngine.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * QEngine is an enum that is used to select the engine to run quantized ops. 
+ * Keep this enum in sync with get_qengine_id() in + * torch/backends/quantized/__init__.py + */ +enum class QEngine : uint8_t { + NoQEngine = 0, + FBGEMM = 1, + QNNPACK = 2, + ONEDNN = 3, + X86 = 4, +}; + +constexpr auto kNoQEngine = QEngine::NoQEngine; +constexpr auto kFBGEMM = QEngine::FBGEMM; +constexpr auto kQNNPACK = QEngine::QNNPACK; +constexpr auto kONEDNN = QEngine::ONEDNN; +constexpr auto kX86 = QEngine::X86; + +inline std::string toString(QEngine qengine) { + switch (qengine) { + case kNoQEngine: + return "NoQEngine"; + case kFBGEMM: + return "FBGEMM"; + case kQNNPACK: + return "QNNPACK"; + case kONEDNN: + return "ONEDNN"; + case kX86: + return "X86"; + default: + TORCH_CHECK( + false, "Unrecognized Quantized Engine: ", static_cast(qengine)); + } +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/QScheme.h b/MLPY/Lib/site-packages/torch/include/c10/core/QScheme.h new file mode 100644 index 0000000000000000000000000000000000000000..158839257eabe869fdb2c04e7754a938fe2eacd8 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/QScheme.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * QScheme is an enum that specifies the type of quantization. This has a one + * to one correspondence with Quantizer + * Please refer to ATen/quantized/Quantizer.h to see the Quantizers classes. + * Keep this file in sync with torch/nn/_qscheme.py + */ +enum class QScheme : uint8_t { + PER_TENSOR_AFFINE = 0, + PER_CHANNEL_AFFINE = 1, + PER_TENSOR_SYMMETRIC = 2, + PER_CHANNEL_SYMMETRIC = 3, + PER_CHANNEL_AFFINE_FLOAT_QPARAMS = 4, + COMPILE_TIME_NUM_QSCHEMES = 5, +}; + +constexpr auto kPerTensorAffine = QScheme::PER_TENSOR_AFFINE; +constexpr auto kPerChannelAffine = QScheme::PER_CHANNEL_AFFINE; +constexpr auto kPerTensorSymmetric = QScheme::PER_TENSOR_SYMMETRIC; +constexpr auto kPerChannelSymmetric = QScheme::PER_CHANNEL_SYMMETRIC; +constexpr auto kPerChannelAffineFloatQParams = + QScheme::PER_CHANNEL_AFFINE_FLOAT_QPARAMS; +constexpr int COMPILE_TIME_NUM_QSCHEMES = + static_cast(QScheme::COMPILE_TIME_NUM_QSCHEMES); + +inline std::string toString(QScheme qscheme) { + switch (qscheme) { + case kPerTensorAffine: + return "per_tensor_affine"; + case kPerChannelAffine: + return "per_channel_affine"; + case kPerTensorSymmetric: + return "per_tensor_symmetric"; + case kPerChannelSymmetric: + return "per_channel_symmetric"; + case kPerChannelAffineFloatQParams: + return "per_channel_affine_float_qparams"; + default: + TORCH_CHECK(false, "Unrecognized qscheme: ", static_cast(qscheme)); + } +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/RefcountedDeleter.h b/MLPY/Lib/site-packages/torch/include/c10/core/RefcountedDeleter.h new file mode 100644 index 0000000000000000000000000000000000000000..2e7125f5858b8cf6b331f6d32cdad25dfa2af8f5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/RefcountedDeleter.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace c10 { + +// A RefcountedDeleterContext object is used as the `ctx` argument for DataPtr +// to implement a shared DataPtr. Normally, a DataPtr is unique, but we use +// this custom context and the `refcounted_deleter` function below to make the +// DataPtr act like a non-unique DataPtr. This context object holds onto an +// inner context and deleter function which handle the actual deletion of the +// data when the refcount reaches 0. 
+// +// This shared DataPtr feature is only used when storages are shared between +// multiple Python interpreters in MultiPy. Before storages had PyObject +// preservation, interpreters could just share the same StorageImpl instance. +// But now a StorageImpl can only be associated with one interpreter in order +// to properly manage a zombie PyObject. So we share storages across Python +// interpreters by creating a different StorageImpl instance for each one, but +// they all point to the same data. +struct C10_API RefcountedDeleterContext { + RefcountedDeleterContext(void* other_ctx, c10::DeleterFnPtr other_deleter) + : other_ctx(other_ctx, other_deleter), refcount(1) {} + + std::unique_ptr other_ctx; + std::atomic_int refcount; +}; + +// `refcounted_deleter` is used as the `ctx_deleter` for DataPtr to implement +// a shared DataPtr. +// +// Warning: This should only be called on a pointer to +// a RefcountedDeleterContext that was allocated on the heap with `new`, +// because when the refcount reaches 0, the context is deleted with `delete` +C10_API void refcounted_deleter(void* ctx_); + +// If the storage's DataPtr does not use `refcounted_deleter`, replace it with +// a DataPtr that does, so it can be shared between multiple StorageImpls +C10_API void maybeApplyRefcountedDeleter(const c10::Storage& storage); + +// Create a new StorageImpl that points to the same data. If the original +// StorageImpl's DataPtr does not use `refcounted_deleter`, it will be replaced +// with one that does +C10_API c10::Storage newStorageImplFromRefcountedDataPtr( + const c10::Storage& storage); + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SafePyObject.h b/MLPY/Lib/site-packages/torch/include/c10/core/SafePyObject.h new file mode 100644 index 0000000000000000000000000000000000000000..1f86dbd83b269c08a8396717b5c3f18fa3967646 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SafePyObject.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10 { + +// This is an safe owning holder for a PyObject, akin to pybind11's +// py::object, with two major differences: +// +// - It is in c10/core; i.e., you can use this type in contexts where +// you do not have a libpython dependency +// +// - It is multi-interpreter safe (ala torchdeploy); when you fetch +// the underlying PyObject* you are required to specify what the current +// interpreter context is and we will check that you match it. +// +// It is INVALID to store a reference to a Tensor object in this way; +// you should just use TensorImpl directly in that case! +struct C10_API SafePyObject { + // Steals a reference to data + SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) + : data_(data), pyinterpreter_(pyinterpreter) {} + SafePyObject(SafePyObject&& other) noexcept + : data_(std::exchange(other.data_, nullptr)), + pyinterpreter_(other.pyinterpreter_) {} + + // In principle this could be copyable if we add an incref to PyInterpreter + // but for now it's easier to just disallow it. 
+ SafePyObject(SafePyObject const&) = delete; + SafePyObject& operator=(SafePyObject const&) = delete; + + ~SafePyObject() { + if (data_ != nullptr) { + (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false); + } + } + + c10::impl::PyInterpreter& pyinterpreter() const { + return *pyinterpreter_; + } + PyObject* ptr(const c10::impl::PyInterpreter*) const; + + // stop tracking the current object, and return it + PyObject* release() { + auto rv = data_; + data_ = nullptr; + return rv; + } + + private: + PyObject* data_; + c10::impl::PyInterpreter* pyinterpreter_; +}; + +// Like SafePyObject, but non-owning. Good for references to global PyObjects +// that will be leaked on interpreter exit. You get a copy constructor/assign +// this way. +struct C10_API SafePyHandle { + SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {} + SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) + : data_(data), pyinterpreter_(pyinterpreter) {} + + c10::impl::PyInterpreter& pyinterpreter() const { + return *pyinterpreter_; + } + PyObject* ptr(const c10::impl::PyInterpreter*) const; + void reset() { + data_ = nullptr; + pyinterpreter_ = nullptr; + } + operator bool() { + return data_; + } + + private: + PyObject* data_; + c10::impl::PyInterpreter* pyinterpreter_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Scalar.h b/MLPY/Lib/site-packages/torch/include/c10/core/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..dc63c0738cef12123e66b23e7972217062af2324 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Scalar.h @@ -0,0 +1,461 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * Scalar represents a 0-dimensional tensor which contains a single element. + * Unlike a tensor, numeric literals (in C++) are implicitly convertible to + * Scalar (which is why, for example, we provide both add(Tensor) and + * add(Scalar) overloads for many operations). It may also be used in + * circumstances where you statically know a tensor is 0-dim and single size, + * but don't know its type. 
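 *
 * An illustrative usage sketch (assuming only the constructors and accessors
 * declared in this class; not a normative example):
 *
 *   c10::Scalar a = 2;            // implicit ctor from int; type() == Long
 *   c10::Scalar b = 3.5;          // implicit ctor from double; type() == Double
 *   bool fp = b.isFloatingPoint();    // true
 *   int64_t ai = a.toLong();          // 2
 *   double bd = b.to<double>();       // 3.5, via the DEFINE_TO specializations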
+ */ +class C10_API Scalar { + public: + Scalar() : Scalar(int64_t(0)) {} + + void destroy() { + if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) { + raw::intrusive_ptr::decref(v.p); + v.p = nullptr; + } + } + + ~Scalar() { + destroy(); + } + +#define DEFINE_IMPLICIT_CTOR(type, name) \ + Scalar(type vv) : Scalar(vv, true) {} + + AT_FORALL_SCALAR_TYPES_AND7( + Half, + BFloat16, + Float8_e5m2, + Float8_e4m3fn, + Float8_e5m2fnuz, + Float8_e4m3fnuz, + ComplexHalf, + DEFINE_IMPLICIT_CTOR) + AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) + + // Helper constructors to allow Scalar creation from long and long long types + // As std::is_same_v is false(except Android), one needs to + // provide a constructor from either long or long long in addition to one from + // int64_t +#if defined(__APPLE__) || defined(__MACOSX) + static_assert( + std::is_same_v, + "int64_t is the same as long long on MacOS"); + Scalar(long vv) : Scalar(vv, true) {} +#endif +#if defined(__linux__) && !defined(__ANDROID__) + static_assert( + std::is_same_v, + "int64_t is the same as long on Linux"); + Scalar(long long vv) : Scalar(vv, true) {} +#endif + + Scalar(uint16_t vv) : Scalar(vv, true) {} + Scalar(uint32_t vv) : Scalar(vv, true) {} + Scalar(uint64_t vv) { + if (vv > static_cast(INT64_MAX)) { + tag = Tag::HAS_u; + v.u = vv; + } else { + tag = Tag::HAS_i; + // NB: no need to use convert, we've already tested convertibility + v.i = static_cast(vv); + } + } + +#undef DEFINE_IMPLICIT_CTOR + + // Value* is both implicitly convertible to SymbolicVariable and bool which + // causes ambiguity error. Specialized constructor for bool resolves this + // problem. + template < + typename T, + typename std::enable_if_t, bool>* = nullptr> + Scalar(T vv) : tag(Tag::HAS_b) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t, bool>* = + nullptr> + Scalar(T vv) : tag(Tag::HAS_sb) { + v.i = convert(vv); + } + +#define DEFINE_ACCESSOR(type, name) \ + type to##name() const { \ + if (Tag::HAS_d == tag) { \ + return checked_convert(v.d, #type); \ + } else if (Tag::HAS_z == tag) { \ + return checked_convert>(v.z, #type); \ + } \ + if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_i == tag) { \ + return checked_convert(v.i, #type); \ + } else if (Tag::HAS_u == tag) { \ + return checked_convert(v.u, #type); \ + } else if (Tag::HAS_si == tag) { \ + return checked_convert( \ + toSymInt().guard_int(__FILE__, __LINE__), #type); \ + } else if (Tag::HAS_sd == tag) { \ + return checked_convert( \ + toSymFloat().guard_float(__FILE__, __LINE__), #type); \ + } else if (Tag::HAS_sb == tag) { \ + return checked_convert( \ + toSymBool().guard_bool(__FILE__, __LINE__), #type); \ + } \ + TORCH_CHECK(false) \ + } + + // TODO: Support ComplexHalf accessor + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) + DEFINE_ACCESSOR(uint16_t, UInt16) + DEFINE_ACCESSOR(uint32_t, UInt32) + DEFINE_ACCESSOR(uint64_t, UInt64) + +#undef DEFINE_ACCESSOR + + SymInt toSymInt() const { + if (Tag::HAS_si == tag) { + return c10::SymInt(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toLong(); + } + } + + SymFloat toSymFloat() const { + if (Tag::HAS_sd == tag) { + return c10::SymFloat(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toDouble(); + } + } + + SymBool toSymBool() const { + if (Tag::HAS_sb == tag) { + return c10::SymBool(intrusive_ptr::reclaim_copy( + static_cast(v.p))); + } else { + return toBool(); + } + } + + // also support 
scalar.to(); + // Deleted for unsupported types, but specialized below for supported types + template + T to() const = delete; + + // audit uses of data_ptr + const void* data_ptr() const { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return static_cast(&v); + } + + bool isFloatingPoint() const { + return Tag::HAS_d == tag || Tag::HAS_sd == tag; + } + + C10_DEPRECATED_MESSAGE( + "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") + bool isIntegral() const { + return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; + } + bool isIntegral(bool includeBool) const { + return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || + (includeBool && isBoolean()); + } + + bool isComplex() const { + return Tag::HAS_z == tag; + } + bool isBoolean() const { + return Tag::HAS_b == tag || Tag::HAS_sb == tag; + } + + // you probably don't actually want these; they're mostly for testing + bool isSymInt() const { + return Tag::HAS_si == tag; + } + bool isSymFloat() const { + return Tag::HAS_sd == tag; + } + bool isSymBool() const { + return Tag::HAS_sb == tag; + } + + bool isSymbolic() const { + return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag; + } + + C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { + if (&other == this) { + return *this; + } + + destroy(); + moveFrom(std::move(other)); + return *this; + } + + C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { + if (&other == this) { + return *this; + } + + *this = Scalar(other); + return *this; + } + + Scalar operator-() const; + Scalar conj() const; + Scalar log() const; + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + auto val = v.z; + return (val.real() == num) && (val.imag() == T()); + } else if (isFloatingPoint()) { + TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); + return v.d == num; + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num; + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num; + } + } else if (tag == Tag::HAS_si) { + TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return false; + } else { + TORCH_INTERNAL_ASSERT(false); + } + } + + template < + typename T, + typename std::enable_if_t::value, int> = 0> + bool equal(T num) const { + if (isComplex()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return v.z == num; + } else if (isFloatingPoint()) { + TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); + return (v.d == num.real()) && (num.imag() == T()); + } else if (tag == Tag::HAS_i) { + if (overflows(v.i, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.i) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_u) { + if (overflows(v.u, /* strict_unsigned */ true)) { + return false; + } else { + return static_cast(v.u) == num.real() && num.imag() == T(); + } + } else if (tag == Tag::HAS_si) { + TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); + } else if (isBoolean()) { + // boolean scalar does not equal to a non boolean value + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return false; + } else { + TORCH_INTERNAL_ASSERT(false); + } + } + + bool equal(bool 
num) const { + if (isBoolean()) { + TORCH_INTERNAL_ASSERT(!isSymbolic()); + return static_cast(v.i) == num; + } else { + return false; + } + } + + ScalarType type() const { + if (isComplex()) { + return ScalarType::ComplexDouble; + } else if (isFloatingPoint()) { + return ScalarType::Double; + } else if (isIntegral(/*includeBool=*/false)) { + // Represent all integers as long, UNLESS it is unsigned and therefore + // unrepresentable as long + if (Tag::HAS_u == tag) { + return ScalarType::UInt64; + } + return ScalarType::Long; + } else if (isBoolean()) { + return ScalarType::Bool; + } else { + throw std::runtime_error("Unknown scalar type."); + } + } + + Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { + moveFrom(std::move(rhs)); + } + + Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) { + if (isSymbolic()) { + c10::raw::intrusive_ptr::incref(v.p); + } + } + + Scalar(c10::SymInt si) { + if (auto m = si.maybe_as_int()) { + tag = Tag::HAS_i; + v.i = *m; + } else { + tag = Tag::HAS_si; + v.p = std::move(si).release(); + } + } + + Scalar(c10::SymFloat sd) { + if (sd.is_symbolic()) { + tag = Tag::HAS_sd; + v.p = std::move(sd).release(); + } else { + tag = Tag::HAS_d; + v.d = sd.as_float_unchecked(); + } + } + + Scalar(c10::SymBool sb) { + if (auto m = sb.maybe_as_bool()) { + tag = Tag::HAS_b; + v.i = *m; + } else { + tag = Tag::HAS_sb; + v.p = std::move(sb).release(); + } + } + + // We can't set v in the initializer list using the + // syntax v{ .member = ... } because it doesn't work on MSVC + private: + enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb }; + + // Note [Meaning of HAS_u] + // ~~~~~~~~~~~~~~~~~~~~~~~ + // HAS_u is a bit special. On its face, it just means that we + // are holding an unsigned integer. However, we generally don't + // distinguish between different bit sizes in Scalar (e.g., we represent + // float as double), instead, it represents a mathematical notion + // of some quantity (integral versus floating point). So actually, + // HAS_u is used solely to represent unsigned integers that could + // not be represented as a signed integer. That means only uint64_t + // potentially can get this tag; smaller types like uint8_t fits into a + // regular int and so for BC reasons we keep as an int. 
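  // A small worked example of Note [Meaning of HAS_u] above (illustrative
  // only; it follows directly from the uint64_t constructor and type()):
  //
  //   c10::Scalar small(uint64_t(42));     // fits in int64_t -> Tag::HAS_i,
  //                                        //   type() == ScalarType::Long
  //   c10::Scalar big(uint64_t(1) << 63);  // > INT64_MAX     -> Tag::HAS_u,
  //                                        //   type() == ScalarType::UInt64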
+ + // NB: assumes that self has already been cleared + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { + v = rhs.v; + tag = rhs.tag; + if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd || + rhs.tag == Tag::HAS_sb) { + // Move out of scalar + rhs.tag = Tag::HAS_i; + rhs.v.i = 0; + } + } + + Tag tag; + + union v_t { + double d{}; + int64_t i; + // See Note [Meaning of HAS_u] + uint64_t u; + c10::complex z; + c10::intrusive_ptr_target* p; + // NOLINTNEXTLINE(modernize-use-equals-default) + v_t() {} // default constructor + } v; + + template < + typename T, + typename std::enable_if_t< + std::is_integral_v && !std::is_same_v, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_i) { + v.i = convert(vv); + } + + template < + typename T, + typename std::enable_if_t< + !std::is_integral_v && !c10::is_complex::value, + bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_d) { + v.d = convert(vv); + } + + template < + typename T, + typename std::enable_if_t::value, bool>* = nullptr> + Scalar(T vv, bool) : tag(Tag::HAS_z) { + v.z = convert(vv); + } +}; + +using OptionalScalarRef = c10::OptionalRef; + +// define the scalar.to() specializations +#define DEFINE_TO(T, name) \ + template <> \ + inline T Scalar::to() const { \ + return to##name(); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) +DEFINE_TO(uint16_t, UInt16) +DEFINE_TO(uint32_t, UInt32) +DEFINE_TO(uint64_t, UInt64) +#undef DEFINE_TO + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/ScalarType.h b/MLPY/Lib/site-packages/torch/include/c10/core/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..08f26ac2b30d429ed7a7bc72877531661c9ac11e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/ScalarType.h @@ -0,0 +1,620 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// For the macros below: +// +// For users: If you want to macro some code for all non-QInt scalar types +// (i.e. types with complete information, you probably want one of the +// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are +// designed to behave similarly to the Dispatch macros with the same name. +// +// For adding a new dtype: In the beginning, we had an idea that there was a +// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to +// iterate over them. But over the years we added weird types which couldn't +// be handled uniformly everywhere and so in the end we ended up with some +// mish-mosh of some helper macros, but mostly use sites making a call about +// what dtypes they can or can't support. So if you want to add a new dtype, +// the preferred resolution is to find a dtype similar to what you want, +// grep for it and edit all the sites you find this way. If you need to add +// a completely new kind of dtype, you're going to have to laboriously audit +// all of the sites everywhere to figure out how it should work. Consulting +// some old PRs where we added new dtypes (check history of this file) can +// help give you an idea where to start. 
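// As a hedged illustration of the X-macro pattern these AT_FORALL_* lists
// implement: the caller supplies a two-argument macro and the list applies it
// once per (C++ type, ScalarType name) pair, e.g. something like
//
//   #define PRINT_SIZE(ctype, name) \
//     std::cout << #name << ": " << sizeof(ctype) << " bytes\n";
//   AT_FORALL_SCALAR_TYPES(PRINT_SIZE)  // Byte, Char, Short, Int, Long,
//                                       // Float, Double
//   #undef PRINT_SIZE
//
// (PRINT_SIZE is a hypothetical caller-side macro; toString() and
// elementSize() further down in this header use the same pattern.)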
+ +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ + +// If you want to support ComplexHalf for real, add ComplexHalf +// into this macro (and change the name). But beware: convert() +// doesn't work for all the conversions you need... +// +// TODO: To add unsigned int types here, we must define accumulate type. +// But uint8 currently accumulates into int64, so we would have to make +// an inconsistent choice for the larger types. Difficult. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(at::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(at::BFloat16, BFloat16) \ + _(at::Float8_e5m2, Float8_e5m2) \ + _(at::Float8_e4m3fn, Float8_e4m3fn) + +// This macro controls many of our C++ APIs, including constructors +// for Scalar as well as the data() and item() accessors on Tensor +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(at::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(c10::complex, ComplexHalf) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(at::BFloat16, BFloat16) \ + _(at::Float8_e5m2, Float8_e5m2) \ + _(at::Float8_e4m3fn, Float8_e4m3fn) \ + _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) + +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +namespace impl { + +// These are used to map ScalarTypes to C++ types. 
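// For example (illustrative, in terms of the two templates declared below and
// the scalar type list above):
//
//   static_assert(std::is_same_v<
//       c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::Float>, float>);
//   static_assert(
//       c10::CppTypeToScalarType<int64_t>::value == c10::ScalarType::Long);
//
// i.e. the two templates are inverse mappings over the scalar type list.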
+ +template +struct ScalarTypeToCPPType; + +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + \ + /* This is a workaround for the CUDA bug which prevents */ \ + /* ::detail::ScalarTypeToCType::type being used directly due to */ \ + /* ambiguous reference which can't to be resolved. For some reason it */ \ + /* can't pick between at::detail and at::cuda::detail. */ \ + /* For repro example, please see: */ \ + /* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \ + /* TODO: remove once the bug is fixed. */ \ + static type t; \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) + +#undef SPECIALIZE_ScalarTypeToCPPType + +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + +} // namespace impl + +template +struct CppTypeToScalarType; + +#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std:: \ + integral_constant { \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) + +#undef SPECIALIZE_CppTypeToScalarType + +// NB: despite its generic sounding name, the macros that don't take _AND +// are mostly only used by tensorexpr +#define AT_FORALL_INT_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) + +#define AT_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) + +// These macros are often controlling how many template instantiations we +// create for kernels. It is typically inappropriate to add new dtypes here, +// instead, new types should be added to use sites on a case-by-case basis. +// We generally are not accepting new dtypes due to binary size concerns. 
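// A hedged sketch of what the _AND forms below provide: the extra ScalarTypes
// are appended to the seven standard pairs, with their C++ types recovered via
// decltype(::c10::impl::ScalarTypeToCPPType<...>::t), so callers only name the
// enumerators. For instance (COUNT_ONE is a hypothetical caller-side macro):
//
//   #define COUNT_ONE(ctype, name) +1
//   constexpr int n = 0 AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, COUNT_ONE);
//   static_assert(n == 9);  // 7 base types + Half + BFloat16
//   #undef COUNT_ONE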
+ +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE>::t), \ + SCALARTYPE) + +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) + +#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) + +#define AT_FORALL_SCALAR_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) + +#define AT_FORALL_SCALAR_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE5>::t), \ + SCALARTYPE5) + +#define AT_FORALL_SCALAR_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE5>::t), \ + SCALARTYPE5) \ 
+ _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE6>::t), \ + SCALARTYPE6) + +#define AT_FORALL_SCALAR_TYPES_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE4>::t), \ + SCALARTYPE4) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE5>::t), \ + SCALARTYPE5) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE6>::t), \ + SCALARTYPE6) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE7>::t), \ + SCALARTYPE7) + +#define AT_FORALL_QINT_TYPES(_) \ + _(c10::qint8, QInt8) \ + _(c10::quint8, QUInt8) \ + _(c10::qint32, QInt32) \ + _(c10::quint4x2, QUInt4x2) \ + _(c10::quint2x4, QUInt2x4) + +#define AT_FORALL_COMPLEX_TYPES(_) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) + +#define DEFINE_CONSTANT(_, name) \ + constexpr ScalarType k##name = ScalarType::name; + +// NOLINTNEXTLINE(clang-diagnostic-unused-const-variable) +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) +#undef DEFINE_CONSTANT + +static inline const char* toString(ScalarType t) { +#define DEFINE_CASE(_, name) \ + case ScalarType::name: \ + return #name; + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +static inline size_t elementSize(ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, name) \ + case ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + default: + TORCH_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_ELEMENTSIZE_CASE +} + +static inline bool isIntegralType(ScalarType t, bool includeBool) { + bool isIntegral = + (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || + t == ScalarType::Long || t == ScalarType::Short || + t == ScalarType::UInt16 || t == ScalarType::UInt32 || + t == ScalarType::UInt64); + + return isIntegral || (includeBool && t == ScalarType::Bool); +} + +C10_DEPRECATED_MESSAGE( + "isIntegralType is deprecated. 
Please use the overload with 'includeBool' parameter instead.") +static inline bool isIntegralType(ScalarType t) { + return isIntegralType(t, /*includeBool=*/false); +} + +static inline bool isFloat8Type(ScalarType t) { + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || + t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz; +} + +static inline bool isReducedFloatingType(ScalarType t) { + return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); +} + +static inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Double || t == ScalarType::Float || + isReducedFloatingType(t); +} + +static inline bool isComplexType(ScalarType t) { + return ( + t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); +} + +static inline bool isQIntType(ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || + t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || + t == ScalarType::QUInt2x4; +} + +static inline bool isBitsType(ScalarType t) { + return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 || + t == ScalarType::Bits4x2 || t == ScalarType::Bits8 || + t == ScalarType::Bits16; +} + +static inline bool isBarebonesUnsignedType(ScalarType t) { + return t == ScalarType::UInt1 || t == ScalarType::UInt2 || + t == ScalarType::UInt3 || t == ScalarType::UInt4 || + t == ScalarType::UInt5 || t == ScalarType::UInt6 || + t == ScalarType::UInt7 || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64; +} + +static inline ScalarType toQIntType(ScalarType t) { + switch (t) { + case ScalarType::Byte: + return ScalarType::QUInt8; + case ScalarType::Char: + return ScalarType::QInt8; + case ScalarType::Int: + return ScalarType::QInt32; + default: + return t; + } +} + +static inline ScalarType toUnderlying(ScalarType t) { + switch (t) { + case ScalarType::QUInt8: + case ScalarType::QUInt4x2: + [[fallthrough]]; + case ScalarType::QUInt2x4: + return ScalarType::Byte; + case ScalarType::QInt8: + return ScalarType::Char; + case ScalarType::QInt32: + return ScalarType::Int; + default: + return t; + } +} + +static inline bool isSignedType(ScalarType t) { + TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types"); +#define CASE_SIGNED(ctype, name) \ + case ScalarType::name: \ + return std::numeric_limits::is_signed; + + switch (t) { + case ScalarType::Bits1x8: + case ScalarType::Bits2x4: + case ScalarType::Bits4x2: + case ScalarType::Bits8: + case ScalarType::Bits16: + TORCH_CHECK(false, "Bits types are undefined"); + case ScalarType::ComplexHalf: + case ScalarType::ComplexFloat: + case ScalarType::ComplexDouble: + return true; + AT_FORALL_SCALAR_TYPES_AND7( + Half, + Bool, + BFloat16, + Float8_e5m2, + Float8_e4m3fn, + Float8_e5m2fnuz, + Float8_e4m3fnuz, + CASE_SIGNED) + default: + TORCH_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_SIGNED +} + +static inline bool isUnderlying(ScalarType type, ScalarType qtype) { + return type == toUnderlying(qtype); +} + +static inline ScalarType toRealValueType(ScalarType t) { + switch (t) { + case ScalarType::ComplexHalf: + return ScalarType::Half; + case ScalarType::ComplexFloat: + return ScalarType::Float; + case ScalarType::ComplexDouble: + return ScalarType::Double; + default: + return t; + } +} + +static inline ScalarType toComplexType(ScalarType t) { + switch (t) { + case ScalarType::BFloat16: + // BFloat16 has range equivalent to 
Float, + // so we map it to ComplexFloat. + return ScalarType::ComplexFloat; + case ScalarType::Half: + return ScalarType::ComplexHalf; + case ScalarType::Float: + return ScalarType::ComplexFloat; + case ScalarType::Double: + return ScalarType::ComplexDouble; + case ScalarType::ComplexHalf: + return ScalarType::ComplexHalf; + case ScalarType::ComplexFloat: + return ScalarType::ComplexFloat; + case ScalarType::ComplexDouble: + return ScalarType::ComplexDouble; + default: + TORCH_CHECK(false, "Unknown Complex ScalarType for ", t); + } +} + +// see tensor_attributes.rst for detailed explanation and examples +// of casting rules. +static inline bool canCast(const ScalarType from, const ScalarType to) { + // We disallow complex -> non complex, e.g., float_tensor *= complex is + // disallowed. + if (isComplexType(from) && !isComplexType(to)) { + return false; + } + // We disallow float -> integral, e.g., int_tensor *= float is disallowed. + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + + // Treat bool as a distinct "category," to be consistent with type promotion + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same + // category as `bool_tensor`, we would not promote. Differing categories + // implies `bool_tensor += 5` is disallowed. + // + // NB: numpy distinguishes "unsigned" as a category to get the desired + // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: + // * We don't want the performance hit of checking the runtime sign of + // Scalars. + // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + +C10_API ScalarType promoteTypes(ScalarType a, ScalarType b); + +inline std::ostream& operator<<( + std::ostream& stream, + at::ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h b/MLPY/Lib/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..65f4302325727e7d15f8ccc84e2a71c129d5d387 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/ScalarTypeToTypeMeta.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include + +// these just expose TypeMeta/ScalarType bridge functions in c10 +// TODO move to typeid.h (or codemod away) when TypeMeta et al +// are moved from caffe2 to c10 (see note at top of typeid.h) + +namespace c10 { + +/** + * convert ScalarType enum values to TypeMeta handles + */ +static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { + return caffe2::TypeMeta::fromScalarType(scalar_type); +} + +/** + * convert TypeMeta handles to ScalarType enum values + */ +static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + return dtype.toScalarType(); +} + +/** + * typeMetaToScalarType(), lifted to optional + */ +static inline optional optTypeMetaToScalarType( + optional type_meta) { + if (!type_meta.has_value()) { + return c10::nullopt; + } + return type_meta->toScalarType(); +} + +/** + * convenience: equality across TypeMeta/ScalarType conversion + */ +static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { + return m.isScalarType(t); +} + +static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { + return t == m; +} + +static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { + return !(t == m); +} + +static inline bool 
operator!=(caffe2::TypeMeta m, ScalarType t) { + return !(t == m); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Storage.h b/MLPY/Lib/site-packages/torch/include/c10/core/Storage.h new file mode 100644 index 0000000000000000000000000000000000000000..09be93941bb9e2a766a0784d96ee2a35dae8d099 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Storage.h @@ -0,0 +1,272 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct Storage; + +C10_API bool isSharedStorageAlias( + const Storage& storage0, + const Storage& storage1); + +struct C10_API Storage { + public: + struct use_byte_size_t {}; + struct unsafe_borrow_t { + explicit unsafe_borrow_t() = default; + }; + + Storage() = default; + Storage(c10::intrusive_ptr ptr) + : storage_impl_(std::move(ptr)) {} + + // Allocates memory buffer using given allocator and creates a storage with it + Storage( + use_byte_size_t /*use_byte_size*/, + const SymInt& size_bytes, + Allocator* allocator = nullptr, + bool resizable = false) + : storage_impl_(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + size_bytes, + allocator, + resizable)) {} + + // Creates storage with pre-allocated memory buffer. Allocator is given for + // potential future reallocations, however it can be nullptr if the storage + // is non-resizable + Storage( + use_byte_size_t /*use_byte_size*/, + size_t size_bytes, + at::DataPtr data_ptr, + at::Allocator* allocator = nullptr, + bool resizable = false) + : storage_impl_(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + size_bytes, + std::move(data_ptr), + allocator, + resizable)) {} + + protected: + explicit Storage(unsafe_borrow_t, const Storage& rhs) + : storage_impl_(c10::intrusive_ptr::reclaim( + rhs.storage_impl_.get())) {} + + friend MaybeOwnedTraits; + + public: + // Legacy constructor for partially initialized (dtype or memory) storages + // that can be temporarily created with Caffe2 APIs. See the note on top of + // TensorImpl.h for details. + static Storage create_legacy(at::Device device) { + auto allocator = GetAllocator(device.type()); + return Storage(c10::make_intrusive( + StorageImpl::use_byte_size_t(), + 0, + allocator->allocate(0), // materialize a non-default Device. + allocator, + true)); + } + + // Mimic create_legacy, but without requiring a newly-created StorageImpl. 
+ void reset_legacy() { + TORCH_CHECK(resizable() && allocator()); + set_nbytes(0); + set_data_ptr_noswap(allocator()->allocate(0)); + } + + // TODO: remove later + void set_nbytes(size_t size_bytes) const { + storage_impl_->set_nbytes(size_bytes); + } + + void set_nbytes(c10::SymInt size_bytes) const { + storage_impl_->set_nbytes(std::move(size_bytes)); + } + + bool resizable() const { + return storage_impl_->resizable(); + } + + size_t nbytes() const { + return storage_impl_->nbytes(); + } + + SymInt sym_nbytes() const { + return storage_impl_->sym_nbytes(); + } + // get() use here is to get const-correctness + + const void* data() const { + return storage_impl_->data(); + } + + void* mutable_data() const { + return storage_impl_->mutable_data(); + } + + at::DataPtr& mutable_data_ptr() const { + return storage_impl_->mutable_data_ptr(); + } + + const at::DataPtr& data_ptr() const { + return storage_impl_->data_ptr(); + } + + // Returns the previous data_ptr + at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const { + return storage_impl_->set_data_ptr(std::move(data_ptr)); + } + + void set_data_ptr_noswap(at::DataPtr&& data_ptr) const { + return storage_impl_->set_data_ptr_noswap(std::move(data_ptr)); + } + + DeviceType device_type() const { + return storage_impl_->device_type(); + } + + at::Allocator* allocator() const { + return storage_impl_->allocator(); + } + + at::Device device() const { + return storage_impl_->device(); + } + + StorageImpl* unsafeReleaseStorageImpl() { + return storage_impl_.release(); + } + + StorageImpl* unsafeGetStorageImpl() const noexcept { + return storage_impl_.get(); + } + + c10::weak_intrusive_ptr getWeakStorageImpl() const { + return c10::weak_intrusive_ptr(storage_impl_); + } + + operator bool() const { + return storage_impl_; + } + + size_t use_count() const { + return storage_impl_.use_count(); + } + + inline bool unique() const { + return storage_impl_.unique(); + } + + bool is_alias_of(const Storage& other) const { + return ( + storage_impl_ == other.storage_impl_ || + isSharedStorageAlias(*this, other)); + } + + void UniqueStorageShareExternalPointer( + void* src, + size_t capacity, + DeleterFnPtr d = nullptr) { + if (!storage_impl_.unique()) { + TORCH_CHECK( + false, + "UniqueStorageShareExternalPointer can only be called when use_count == 1"); + } + storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d); + } + + void UniqueStorageShareExternalPointer( + at::DataPtr&& data_ptr, + size_t capacity) { + if (!storage_impl_.unique()) { + TORCH_CHECK( + false, + "UniqueStorageShareExternalPointer can only be called when use_count == 1"); + } + storage_impl_->UniqueStorageShareExternalPointer( + std::move(data_ptr), capacity); + } + + protected: + c10::intrusive_ptr storage_impl_; +}; + +template <> +struct MaybeOwnedTraits { + using owned_type = c10::Storage; + using borrow_type = c10::Storage; + + static borrow_type createBorrow(const owned_type& from) { + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseStorageImpl(); + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseStorageImpl(); // "leak" it, but it was already +0. 
+ } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = c10::Storage; + using pointer_type = c10::Storage*; + using const_pointer_type = const c10::Storage*; + + static repr_type nullRepr() { + return c10::Storage(); + } + + template + static repr_type createInPlace(Args&&... args) { + return c10::Storage(std::forward(args)...); + } + + static repr_type moveToRepr(c10::Storage&& x) { + return std::move(x); + } + + static c10::Storage take(c10::Storage& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/StorageImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/StorageImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..714b7d9fbe949285e89cd540e05ca5145549b2db --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/StorageImpl.h @@ -0,0 +1,276 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// A storage represents the underlying backing data buffer for a +// tensor. This concept was inherited from the original Torch7 +// codebase; we'd kind of like to get rid of the concept +// (see https://github.com/pytorch/pytorch/issues/14797) but +// it's hard work and no one has gotten around to doing it. +// +// NB: storage is supposed to uniquely own a data pointer; e.g., +// two non-null data pointers alias if and only if they are from +// the same storage. Technically you can violate this invariant +// (e.g., you can create a non-owning StorageImpl with at::from_blob) +// but a lot of things won't work correctly, including: +// +// - An ordinary deleter on such a storage is wrong, because normal deleters +// assume unique ownership, but if you have two storages at the same data, +// that implies there is some sort of shared ownership. 
So your deleter would +// have to actually be internally doing some sort of refcount thing +// - Deepcopy in Python side relies on storage equality and not data pointer +// equality; so if there are two separate storages pointing to the same data, +// the data will actually get duplicated in that case (one data ptr before, +// two data ptrs after) +// - Version counts won't work correctly, because we do all VC tracking at the +// level of storages (unless you explicitly disconnect the VC with detach); +// mutation because data pointers are the same are totally untracked +struct C10_API StorageImpl : public c10::intrusive_ptr_target { + public: + struct use_byte_size_t {}; + + StorageImpl( + use_byte_size_t /*use_byte_size*/, + SymInt size_bytes, + at::DataPtr data_ptr, + at::Allocator* allocator, + bool resizable) + : data_ptr_(std::move(data_ptr)), + size_bytes_(std::move(size_bytes)), + size_bytes_is_heap_allocated_(size_bytes_.is_heap_allocated()), + resizable_(resizable), + received_cuda_(false), + allocator_(allocator) { + if (resizable) { + TORCH_INTERNAL_ASSERT( + allocator_, "For resizable storage, allocator must be provided"); + } + } + + StorageImpl( + use_byte_size_t /*use_byte_size*/, + const SymInt& size_bytes, + at::Allocator* allocator, + bool resizable) + : StorageImpl( + use_byte_size_t(), + size_bytes, + size_bytes.is_heap_allocated() + ? allocator->allocate(0) + : allocator->allocate(size_bytes.as_int_unchecked()), + allocator, + resizable) {} + + StorageImpl& operator=(StorageImpl&& other) = delete; + StorageImpl& operator=(const StorageImpl&) = delete; + StorageImpl() = delete; + StorageImpl(StorageImpl&& other) = delete; + StorageImpl(const StorageImpl&) = delete; + ~StorageImpl() override = default; + + void reset() { + data_ptr_.clear(); + size_bytes_ = 0; + size_bytes_is_heap_allocated_ = false; + } + + // Destructor doesn't call release_resources because it's + // unnecessary; don't forget to change that if needed! + void release_resources() override { + data_ptr_.clear(); + } + + size_t nbytes() const { + // OK to do this instead of maybe_as_int as nbytes is guaranteed positive + TORCH_CHECK(!size_bytes_is_heap_allocated_); + return size_bytes_.as_int_unchecked(); + } + + SymInt sym_nbytes() const { + return size_bytes_; + } + + // TODO: remove later + void set_nbytes(size_t size_bytes) { + size_bytes_ = static_cast(size_bytes); + size_bytes_is_heap_allocated_ = false; + } + + void set_nbytes(c10::SymInt size_bytes) { + size_bytes_ = std::move(size_bytes); + } + + bool resizable() const { + return resizable_; + } + + at::DataPtr& mutable_data_ptr() { + maybe_materialize_cow(); + return data_ptr_; + } + + const at::DataPtr& data_ptr() const { + return data_ptr_; + } + + // Returns the previous data_ptr + at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) { + // We need to materialize the old COW DataPtr because it is + // being returned as mutable. 
+ maybe_materialize_cow(); + return set_data_ptr_no_materialize_cow(std::move(data_ptr)); + } + + void set_data_ptr_noswap(at::DataPtr&& data_ptr) { + data_ptr_ = std::move(data_ptr); + } + + const void* data() const { + return data_ptr_.get(); + } + + void* mutable_data() { + maybe_materialize_cow(); + return data_ptr_.mutable_get(); + } + + at::DeviceType device_type() const { + return data_ptr_.device().type(); + } + + at::Allocator* allocator() { + return allocator_; + } + + const at::Allocator* allocator() const { + return allocator_; + } + + // You generally shouldn't use this method, but it is occasionally + // useful if you want to override how a tensor will be reallocated, + // after it was already allocated (and its initial allocator was + // set) + void set_allocator(at::Allocator* allocator) { + allocator_ = allocator; + } + + Device device() const { + return data_ptr_.device(); + } + + void set_resizable(bool resizable) { + if (resizable) { + // We need an allocator to be resizable + AT_ASSERT(allocator_); + } + resizable_ = resizable; + } + + /** + * Can only be called when use_count is 1 + */ + void UniqueStorageShareExternalPointer( + void* src, + size_t size_bytes, + DeleterFnPtr d = nullptr) { + UniqueStorageShareExternalPointer( + at::DataPtr(src, src, d, data_ptr_.device()), size_bytes); + } + + /** + * Can only be called when use_count is 1 + */ + void UniqueStorageShareExternalPointer( + at::DataPtr&& data_ptr, + size_t size_bytes) { + data_ptr_ = std::move(data_ptr); + size_bytes_ = static_cast(size_bytes); + size_bytes_is_heap_allocated_ = false; + allocator_ = nullptr; + resizable_ = false; + } + + // This method can be used only after storage construction and cannot be used + // to modify storage status + void set_received_cuda(bool received_cuda) { + received_cuda_ = received_cuda; + } + + bool received_cuda() { + return received_cuda_; + } + + impl::PyObjectSlot* pyobj_slot() { + return &pyobj_slot_; + } + + const impl::PyObjectSlot* pyobj_slot() const { + return &pyobj_slot_; + } + + protected: + // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow + friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage); + + // Returns the previous data_ptr. If the old data_ptr was COW, + // this avoids materializing it + at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) { + at::DataPtr old_data_ptr(std::move(data_ptr_)); + data_ptr_ = std::move(data_ptr); + return old_data_ptr; + } + + private: + // Triggers a copy if this is a copy-on-write tensor. + void maybe_materialize_cow() { + if (data_ptr_.get_deleter() == impl::cow::cow_deleter) { + impl::cow::materialize_cow_storage(*this); + } + } + + DataPtr data_ptr_; + SymInt size_bytes_; + bool size_bytes_is_heap_allocated_; + bool resizable_; + // Identifies that Storage was received from another process and doesn't have + // local to process cuda memory allocation + bool received_cuda_; + Allocator* allocator_; + impl::PyObjectSlot pyobj_slot_; +}; + +// Declare StorageImpl create function pointer types. 
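// A hedged sketch of how a backend might use the factory hooks declared just
// below (the creator shown here is hypothetical and simply mirrors the default
// construction path; it is not upstream code):
//
//   c10::intrusive_ptr<c10::StorageImpl> my_storage_impl_create(
//       c10::StorageImpl::use_byte_size_t use_byte_size,
//       c10::SymInt size_bytes,
//       c10::DataPtr data_ptr,
//       c10::Allocator* allocator,
//       bool resizable) {
//     return c10::make_intrusive<c10::StorageImpl>(
//         use_byte_size, std::move(size_bytes), std::move(data_ptr),
//         allocator, resizable);
//   }
//
//   // At backend initialization time:
//   c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1,
//                             &my_storage_impl_create);
//   // make_storage_impl() can then route construction for that device type
//   // through the registered helper.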
+using StorageImplCreateHelper = intrusive_ptr (*)( + StorageImpl::use_byte_size_t, + SymInt size_bytes, + DataPtr data_ptr, + Allocator* allocator, + bool resizable); + +C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr); + +C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t); + +C10_API c10::intrusive_ptr make_storage_impl( + c10::StorageImpl::use_byte_size_t use_byte_size, + c10::SymInt size_bytes, + c10::DataPtr data_ptr, + c10::Allocator* allocator, + bool resizable, + c10::optional device_opt); + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/Stream.h b/MLPY/Lib/site-packages/torch/include/c10/core/Stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8defb338f4cb31e7b4f769722d66aa3b0ec1e46a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/Stream.h @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// An index representing a specific stream. A StreamId is not independently +/// meaningful without knowing the Device it is associated with; try to +/// use Stream rather than StreamId directly. +/// +/// StreamIds are opaque; they are assigned by some DeviceType-specific +/// numbering system which is not visible to the user. HOWEVER, we +/// guarantee that StreamId 0 is always a valid stream, and corresponds +/// to some sort of "default" stream. +using StreamId = int64_t; + +struct C10_API StreamData3 { + StreamId stream_id; + DeviceIndex device_index; + DeviceType device_type; +}; + +// NB: I decided not to call the above StreamIndex to avoid confusion with +// DeviceIndex. This way, you access device index with index(), and stream id +// with id() + +/** + * A stream is a software mechanism used to synchronize launched kernels + * without requiring explicit synchronizations between kernels. The basic + * model is that every kernel launch is associated with a stream: every + * kernel on the same stream is implicitly synchronized so that if I launch + * kernels A and B on the same stream, A is guaranteed to finish before B + * launches. If I want B to run concurrently with A, I must schedule + * it on a different stream. + * + * The Stream class is a backend agnostic value class representing a stream + * which I may schedule a kernel on. Every stream is associated with a device, + * which is recorded in stream, which is used to avoid confusion about which + * device a stream refers to. + * + * Streams are explicitly thread-safe, in the sense that it is OK to pass + * a Stream from one thread to another, and kernels queued from two different + * threads will still get serialized appropriately. (Of course, the + * time when the kernels get queued is undetermined unless you synchronize + * host side ;) + * + * Stream does NOT have a default constructor. Streams are for expert + * users; if you want to use Streams, we're going to assume you know + * how to deal with C++ template error messages if you try to + * resize() a vector of Streams. + * + * Known instances of streams in backends: + * + * - cudaStream_t (CUDA) + * - hipStream_t (HIP) + * - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration + * does NOT support command queues.) + * + * Because this class is device agnostic, it cannot provide backend-specific + * functionality (e.g., get the cudaStream_t of a CUDA stream.) There are + * wrapper classes which provide this functionality, e.g., CUDAStream. 
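 *
 * An illustrative usage sketch (assuming only the members declared below):
 *
 *   c10::Device dev(c10::DeviceType::CUDA, 0);
 *   c10::Stream s(c10::Stream::DEFAULT, dev);  // default stream; id() == 0
 *   c10::StreamData3 d = s.pack3();            // {stream_id, device_index,
 *                                              //  device_type}
 *   c10::Stream s2 = c10::Stream::unpack3(
 *       d.stream_id, d.device_index, d.device_type);
 *   // s == s2 holds, since equality compares device and stream id.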
+ */ +class C10_API Stream final { + private: + Device device_; + StreamId id_; + + public: + enum Unsafe { UNSAFE }; + enum Default { DEFAULT }; + + /// Unsafely construct a stream from a Device and a StreamId. In + /// general, only specific implementations of streams for a + /// backend should manufacture Stream directly in this way; other users + /// should use the provided APIs to get a stream. In particular, + /// we don't require backends to give any guarantees about non-zero + /// StreamIds; they are welcome to allocate in whatever way they like. + explicit Stream(Unsafe, Device device, StreamId id) + : device_(device), id_(id) {} + + /// Construct the default stream of a Device. The default stream is + /// NOT the same as the current stream; default stream is a fixed stream + /// that never changes, whereas the current stream may be changed by + /// StreamGuard. + explicit Stream(Default, Device device) : device_(device), id_(0) {} + + bool operator==(const Stream& other) const noexcept { + return this->device_ == other.device_ && this->id_ == other.id_; + } + bool operator!=(const Stream& other) const noexcept { + return !(*this == other); + } + + Device device() const noexcept { + return device_; + } + DeviceType device_type() const noexcept { + return device_.type(); + } + DeviceIndex device_index() const noexcept { + return device_.index(); + } + StreamId id() const noexcept { + return id_; + } + + // Enqueues a wait instruction in the stream's work queue. + // This instruction is a no-op unless the event is marked + // for recording. In that case the stream stops processing + // until the event is recorded. + template + void wait(const T& event) const { + event.block(*this); + } + + // Return whether all asynchronous work previously enqueued on this stream + // has completed running on the device. + bool query() const; + + // Wait (by blocking the calling thread) until all asynchronous work enqueued + // on this stream has completed running on the device. + void synchronize() const; + + // The purpose of this function is to more conveniently permit binding + // of Stream to and from Python. Without packing, I have to setup a whole + // class with two fields (device and stream id); with packing I can just + // store a single uint64_t. + // + // The particular way we pack streams into a uint64_t is considered an + // implementation detail and should not be relied upon. + uint64_t hash() const noexcept { + // Concat these together into a 64-bit integer + uint64_t bits = static_cast(device_type()) << 56 | + static_cast(device_index()) << 48 | + // Remove the sign extension part of the 64-bit address because + // the id might be used to hold a pointer. + (static_cast(id()) & ((1ull << 48) - 1)); + return bits; + } + + struct StreamData3 pack3() const { + return {id(), device_index(), device_type()}; + } + + static Stream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + TORCH_CHECK(isValidDeviceType(device_type)); + return Stream(UNSAFE, Device(device_type, device_index), stream_id); + } + + // I decided NOT to provide setters on this class, because really, + // why would you change the device of a stream? Just construct + // it correctly from the beginning dude. 
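  // Illustrative layout of the value returned by hash() above (bit positions
  // only; the packing is an implementation detail, not an API guarantee):
  //
  //   bits 63..56  device_type()   (8 bits)
  //   bits 55..48  device_index()  (8 bits)
  //   bits 47..0   id()            (low 48 bits; sign extension dropped so a
  //                                 pointer-valued id still fits)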
+}; + +C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s); + +} // namespace c10 + +namespace std { +template <> +struct hash { + size_t operator()(c10::Stream s) const noexcept { + return std::hash{}(s.hash()); + } +}; +} // namespace std diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/StreamGuard.h b/MLPY/Lib/site-packages/torch/include/c10/core/StreamGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..275de06d62d2e4fcb94a8d1ee57f63f2c9529814 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/StreamGuard.h @@ -0,0 +1,170 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/** + * A StreamGuard is an RAII class that changes the current device + * to the device corresponding to some stream, and changes the + * default stream on that device to be this stream. + * + * Use of StreamGuard is HIGHLY discouraged in operator definitions. In + * a single operator, you probably don't know enough about the global + * state of the world to profitably decide how to set streams. Let + * the caller handle this appropriately, and just use the current stream + * in your operator code. + * + * This StreamGuard does NOT have an uninitialized state; it is guaranteed + * to reset the stream and device on exit. If you are in a situation + * where you *might* want to setup a stream guard, see OptionalStreamGuard. + */ +struct StreamGuard { + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit StreamGuard() = delete; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit StreamGuard(Stream stream) : guard_(stream) {} + + /// Copy is disallowed + StreamGuard(const StreamGuard&) = delete; + StreamGuard& operator=(const StreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + StreamGuard(StreamGuard&& other) = delete; + StreamGuard& operator=(StreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// on , use MultiStreamGuard instead. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the stream that was set at the time the guard was constructed. + Stream original_stream() const { + return guard_.original_stream(); + } + + /// Returns the most recent stream that was set using this device guard, + /// either from construction, or via set_stream. + Stream current_stream() const { + return guard_.current_stream(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return guard_.current_device(); + } + + /// Returns the device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. 
+ Device original_device() const { + return guard_.original_device(); + } + + private: + c10::impl::InlineStreamGuard guard_; +}; + +/** + * An OptionalStreamGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * See OptionalDeviceGuard for more guidance on how to use this class. + */ +struct OptionalStreamGuard { + /// Create an uninitialized guard. + explicit OptionalStreamGuard() = default; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit OptionalStreamGuard(Stream stream) : guard_(stream) {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit OptionalStreamGuard(optional stream_opt) + : guard_(stream_opt) {} + + /// Copy is disallowed + OptionalStreamGuard(const OptionalStreamGuard&) = delete; + OptionalStreamGuard& operator=(const OptionalStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalStreamGuard(OptionalStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the guard if it was not previously initialized. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the stream that was set at the time the guard was most recently + /// initialized, or nullopt if the guard is uninitialized. + optional original_stream() const { + return guard_.original_stream(); + } + + /// Returns the most recent stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + optional current_stream() const { + return guard_.current_stream(); + } + + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalStreamGuard guard_{}; +}; + +/** + * A MultiStreamGuard is an RAII class that sets the current streams of a set of + * devices all at once, and resets them to their original values on destruction. + */ +struct MultiStreamGuard { + /// Set the current streams to the passed streams on each of their respective + /// devices. 
+ explicit MultiStreamGuard(ArrayRef streams) : guard_(streams) {} + + /// Copy is disallowed + MultiStreamGuard(const MultiStreamGuard&) = delete; + MultiStreamGuard& operator=(const MultiStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + MultiStreamGuard(MultiStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete; + + private: + c10::impl::InlineMultiStreamGuard guard_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymBool.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymBool.h new file mode 100644 index 0000000000000000000000000000000000000000..31073aa373fc70d896570fa574777e1731cf9b31 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymBool.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +class C10_API SymBool { + public: + /*implicit*/ SymBool(bool b) : data_(b){}; + SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_bool()); + }; + SymBool() : data_(false) {} + + SymNodeImpl* toSymNodeImplUnowned() const { + return ptr_.get(); + } + + SymNodeImpl* release() && { + return std::move(ptr_).release(); + } + + // Only valid if is_heap_allocated() + SymNode toSymNodeImpl() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + bool expect_bool() const { + c10::optional c = maybe_as_bool(); + TORCH_CHECK(c.has_value()); + return *c; + } + + SymBool sym_and(const SymBool&) const; + SymBool sym_or(const SymBool&) const; + SymBool sym_not() const; + + SymBool operator&(const SymBool& other) const { + return sym_and(other); + } + SymBool operator|(const SymBool& other) const { + return sym_or(other); + } + SymBool operator~() const { + return sym_not(); + } + + // Insert a guard for the bool to be its concrete value, and then return + // that value. Note that C++ comparison operations default to returning + // bool, so it's not so common to have to call this + bool guard_bool(const char* file, int64_t line) const; + bool expect_true(const char* file, int64_t line) const; + bool guard_size_oblivious(const char* file, int64_t line) const; + + bool has_hint() const; + + bool as_bool_unchecked() const { + return data_; + } + + c10::optional maybe_as_bool() const { + if (!is_heap_allocated()) { + return c10::make_optional(data_); + } + return toSymNodeImplUnowned()->constant_bool(); + } + + bool is_heap_allocated() const { + return ptr_; + } + + private: + // TODO: optimize to union + bool data_; + SymNode ptr_; +}; + +C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s); + +#define TORCH_SYM_CHECK(cond, ...) \ + TORCH_CHECK((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) +#define TORCH_SYM_INTERNAL_ASSERT(cond, ...) 
\ + TORCH_INTERNAL_ASSERT((cond).expect_true(__FILE__, __LINE__), __VA_ARGS__) + +inline bool guard_size_oblivious(bool b, const char* file, int64_t line) { + return b; +} + +inline bool guard_size_oblivious( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_size_oblivious(file, line); +} + +#define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \ + c10::guard_size_oblivious((cond), __FILE__, __LINE__) + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymFloat.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..38dfb025f850af31af8415cacfd143e9b63ff0bb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymFloat.h @@ -0,0 +1,113 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +// NB: this is actually double precision; we're using the Python naming here +class C10_API SymFloat { + public: + /*implicit*/ SymFloat(double d) : data_(d){}; + SymFloat(SymNode ptr) + : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_float()); + }; + SymFloat() : data_(0.0) {} + + SymNodeImpl* toSymNodeImplUnowned() const { + return ptr_.get(); + } + + SymNodeImpl* release() && { + return std::move(ptr_).release(); + } + + // Only valid if is_symbolic() + SymNode toSymNodeImpl() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + double expect_float() const { + TORCH_CHECK(!is_symbolic()); + return data_; + } + + SymFloat operator+(const SymFloat&) const; + SymFloat operator-(const SymFloat&) const; + SymFloat operator*(const SymFloat&) const; + SymFloat operator/(const SymFloat&) const; + + SymBool sym_eq(const SymFloat&) const; + SymBool sym_ne(const SymFloat&) const; + SymBool sym_lt(const SymFloat&) const; + SymBool sym_le(const SymFloat&) const; + SymBool sym_gt(const SymFloat&) const; + SymBool sym_ge(const SymFloat&) const; + + bool operator==(const SymFloat& o) const { + return sym_eq(o).guard_bool(__FILE__, __LINE__); + } + bool operator!=(const SymFloat& o) const { + return sym_ne(o).guard_bool(__FILE__, __LINE__); + } + bool operator<(const SymFloat& o) const { + return sym_lt(o).guard_bool(__FILE__, __LINE__); + } + bool operator<=(const SymFloat& o) const { + return sym_le(o).guard_bool(__FILE__, __LINE__); + } + bool operator>(const SymFloat& o) const { + return sym_gt(o).guard_bool(__FILE__, __LINE__); + } + bool operator>=(const SymFloat& o) const { + return sym_ge(o).guard_bool(__FILE__, __LINE__); + } + + SymFloat min(const SymFloat& sci) const; + SymFloat max(const SymFloat& sci) const; + + // Need guidance on where to put this code + SymFloat sqrt() const; + + // Insert a guard for the float to be its concrete value, and then return + // that value. This operation always works, even if the float is symbolic, + // so long as we know what the underlying value is. Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_float(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + double guard_float(const char* file, int64_t line) const; + + bool has_hint() const; + + // N.B. 
It's important to keep this definition in the header + // as we expect if checks to be folded for mobile builds + // where `is_symbolic` is always false + C10_ALWAYS_INLINE bool is_symbolic() const { + return ptr_; + } + + double as_float_unchecked() const { + return data_; + } + + private: + // TODO: optimize to union + double data_; + SymNode ptr_; +}; + +C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s); +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymInt.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymInt.h new file mode 100644 index 0000000000000000000000000000000000000000..8c1bba01c0065ec1d80062cb7fef7adddea36fac --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymInt.h @@ -0,0 +1,423 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10 { + +class SymFloat; + +// SymInt represents either a regular int64_t, or a symbolic integer +// (represented in a type erased way as SymNode). The intention is for SymInt +// to represent symbolic sizes that arise when doing shape computation in +// operator kernels. This allows for tracing through programs without baking in +// concrete sizes into kernel calls. +// +// SymInt has an API equivalent to int64_t. In particular, it is a value type. +// Internally, SymInt is represented in a clever packed way, so that it only +// occupies one word of space; but morally, it is a union between an int64_t +// and an intrusive pointer to SymNodeImpl. +// +// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where +// is_int() returns true + +class C10_API SymInt { + public: + enum Unchecked { + UNCHECKED, + }; + + /*implicit*/ SymInt(int64_t d) : data_(d) { + if (is_heap_allocated()) { + // Large negative number, heap allocate it + promote_to_negative(); + } + }; + SymInt() : data_(0) {} + SymInt(SymNode n); + + // unchecked c-tor accepting raw `data_` + // One appropriate use for this is when you are constructing a symint + // in a situation where you know it is non-negative (or, if it is negative, + // the negative value is -1; i.e., not user controlled) + SymInt(Unchecked, int64_t d) : data_(d) {} + + // TODO: these implementations are not optimal because they allocate a + // temporary and then use the move constructor/assignment + SymInt(const SymInt& s) : data_(0) { + if (s.is_heap_allocated()) { + *this = SymInt(s.toSymNode()); + } else { + data_ = s.data_; + } + } + SymInt(SymInt&& s) noexcept : data_(s.data_) { + s.data_ = 0; + } + + SymInt& operator=(const SymInt& s) { + if (this != &s) { + if (s.is_heap_allocated()) { + *this = SymInt(s.toSymNode()); + } else { + data_ = s.data_; + } + } + return *this; + } + SymInt& operator=(SymInt&& s) noexcept { + if (this != &s) { + release_(); // release the current SymNode if any + data_ = s.data_; + if (s.is_heap_allocated()) + s.data_ = 0; + }; + return *this; + } + + SymNodeImpl* toSymNodeImplUnowned() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_heap_allocated()); + uint64_t unextended_bits = static_cast(data_) & ~MASK; + uint64_t sign_bit_mask = 1ULL << (62 - 1); + // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c + uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; + return static_cast( + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(static_cast(extended_bits))); + } + + void release_() { + if (is_heap_allocated()) { + 
SymNode::reclaim(toSymNodeImplUnowned()); // steal + } + } + + SymNodeImpl* release() && { +#ifndef C10_MOBILE + TORCH_INTERNAL_ASSERT(is_heap_allocated()); + auto* r = toSymNodeImplUnowned(); + data_ = 0; // transfer ownership + return r; +#else + TORCH_INTERNAL_ASSERT(false); +#endif + } + + // Only valid if is_heap_allocated() + SymNode toSymNode() const; + + // Guaranteed to return a SymNode, wrapping using base if necessary + SymNode wrap_node(const SymNode& base) const; + + ~SymInt() { + release_(); + } + + // Require the int to be non-symbolic, and if it is symbolic raise an + // error. This is safe to use for C++ code that doesn't work for symbolic + // shapes, and you don't have time to fix it immediately, as if we + // try to trigger the path in C++ you'll appropriately get an error + int64_t expect_int() const { + if (auto r = maybe_as_int()) { + return *r; + } + TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( + false, "when unpacking SymInt, expected int but got ", *this); + } + + // Test if we have a hint for this int (e.g., guard_int would work). + // Most of the time this is true; it is only false when you have + // an unbacked SymInt. + bool has_hint() const; + + // Insert a guard for the int to be its concrete value, and then return + // that value. This operation always works, even if the int is symbolic, + // so long as we know what the underlying value is (e.g., this won't work + // if you call it on the size of nonzero output). Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_int(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + int64_t guard_int(const char* file, int64_t line) const; + + // Insert a guard that this SymInt must be size-like, returning true if + // the integer actually is >= 0. Unlike manually performing a >= 0 test, + // if the SymInt in question is an unbacked SymInt (or, potentially in the + // future, if it contains unbacked SymInts), we will also treat the + // unbacked SymInt as statically testing >= 2 (which will prevent us from + // choking on, e.g., contiguity checks.) + bool expect_size(const char* file, int64_t line) const; + + // Distinguish actual symbolic values from constants stored on the heap + bool is_symbolic() const { + return is_heap_allocated() && + !toSymNodeImplUnowned()->constant_int().has_value(); + } + + // N.B. 
It's important to keep this definition in the header + // as we expect if checks to be folded for mobile builds + // where `is_heap_allocated` is always false and optimize dead code paths + C10_ALWAYS_INLINE bool is_heap_allocated() const { +#ifdef C10_MOBILE + return false; +#else + return !check_range(data_); +#endif + } + + SymInt operator+(const SymInt& sci) const; + SymInt operator-(const SymInt& sci) const; + SymInt operator*(const SymInt& sci) const; + SymInt operator/(const SymInt& sci) const; + SymInt operator%(const SymInt& sci) const; + void operator*=(const SymInt& sci); + void operator+=(const SymInt& sci); + void operator/=(const SymInt& sci); + + SymInt clone() const; + + SymBool sym_eq(const SymInt&) const; + SymBool sym_ne(const SymInt&) const; + SymBool sym_lt(const SymInt&) const; + SymBool sym_le(const SymInt&) const; + SymBool sym_gt(const SymInt&) const; + SymBool sym_ge(const SymInt&) const; + + bool operator==(const SymInt& o) const { + return sym_eq(o).guard_bool(__FILE__, __LINE__); + } + bool operator!=(const SymInt& o) const { + return sym_ne(o).guard_bool(__FILE__, __LINE__); + } + bool operator<(const SymInt& o) const { + return sym_lt(o).guard_bool(__FILE__, __LINE__); + } + bool operator<=(const SymInt& o) const { + return sym_le(o).guard_bool(__FILE__, __LINE__); + } + bool operator>(const SymInt& o) const { + return sym_gt(o).guard_bool(__FILE__, __LINE__); + } + bool operator>=(const SymInt& o) const { + return sym_ge(o).guard_bool(__FILE__, __LINE__); + } + + SymInt min(const SymInt& sci) const; + SymInt max(const SymInt& sci) const; + + // If both are symbolic, this checks if + // they share the same node. + // If both are not symbolic this just checks normal equality. + bool is_same(const SymInt& other) const; + + operator SymFloat() const; + + // Don't use this. Prefer maybe_as_int instead + int64_t as_int_unchecked() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!is_heap_allocated()); + return data_; + } + + c10::optional maybe_as_int() const { + if (!is_heap_allocated()) { + return c10::make_optional(data_); + } + auto* node = toSymNodeImplUnowned(); + if (auto c = node->constant_int()) { + return c; + } + return node->maybe_as_int(); + } + + // Return whether the integer is directly coercible to a SymInt + // without requiring heap allocation. You don't need to use this + // to check if you can pass an integer to SymInt; this is guaranteed + // to work (it just might heap allocate!) + static bool check_range(int64_t i) { + return i > MAX_UNREPRESENTABLE_INT; + } + + // Return the min representable integer as a SymInt without + // heap allocation. For quantities that count bytes (or larger), + // this is still much larger than you need, so you may consider + // using this as a more efficient version of MIN_INT + static constexpr int64_t min_representable_int() { + return MAX_UNREPRESENTABLE_INT + 1; + } + + private: + void promote_to_negative(); + + // Constraints on the internal representation: + // + // - Should represent positive and small negative ints + // - No conversion necessary for operations on ints + // - Must represent valid 64-bit pointers + // - Is symbolic test should be FAST (two arithmetic instructions is too + // much). + // This code being a hotpath is based on Strobelight profiles of + // is_heap_allocated(). FB only: https://fburl.com/strobelight/5l50ncxd + // (you will need to change the time window). + // + // So, the scheme is to reserve large negative numbers (assuming + // two's complement): + // + // - 0b0.... 
means we are a positive int + // - 0b11... means we are a small negative int + // - 0b10... means we are are a pointer. This means that + // [-2^63, -2^62-1] are not representable as ints. + // We don't actually need all of this space as on x86_64 + // as the top 16bits aren't used for anything + static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62 | 1ULL << 61; + static constexpr uint64_t IS_SYM = 1ULL << 63 | 1ULL << 61; + // We must manually translate the bit pattern test into a greater + // than test because compiler doesn't figure it out: + // https://godbolt.org/z/356aferaW + static constexpr int64_t MAX_UNREPRESENTABLE_INT = + -1LL & static_cast(~(1ULL << 62)); + int64_t data_; +}; + +/// Sum of a list of SymInt; accumulates into the c10::SymInt expression +template < + typename C, + typename std::enable_if_t< + std::is_same_v, + int> = 0> +inline c10::SymInt multiply_integers(const C& container) { + return std::accumulate( + container.begin(), + container.end(), + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +template < + typename Iter, + typename = std::enable_if_t::value_type, + c10::SymInt>>> +inline c10::SymInt multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, + end, + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +#define DECLARE_SYMINT_OP_INTONLY(scalar_t, RetTy) \ + C10_API RetTy operator%(const SymInt& a, scalar_t b); \ + C10_API RetTy operator%(scalar_t a, const SymInt& b); + +#define DECLARE_SYMINT_OP(scalar_t, RetTy) \ + C10_API RetTy operator+(const SymInt& a, scalar_t b); \ + C10_API RetTy operator-(const SymInt& a, scalar_t b); \ + C10_API RetTy operator*(const SymInt& a, scalar_t b); \ + C10_API RetTy operator/(const SymInt& a, scalar_t b); \ + C10_API RetTy operator+(scalar_t a, const SymInt& b); \ + C10_API RetTy operator-(scalar_t a, const SymInt& b); \ + C10_API RetTy operator*(scalar_t a, const SymInt& b); \ + C10_API RetTy operator/(scalar_t a, const SymInt& b); \ + C10_API bool operator==(const SymInt& a, scalar_t b); \ + C10_API bool operator!=(const SymInt& a, scalar_t b); \ + C10_API bool operator<(const SymInt& a, scalar_t b); \ + C10_API bool operator<=(const SymInt& a, scalar_t b); \ + C10_API bool operator>(const SymInt& a, scalar_t b); \ + C10_API bool operator>=(const SymInt& a, scalar_t b); \ + C10_API bool operator==(scalar_t a, const SymInt& b); \ + C10_API bool operator!=(scalar_t a, const SymInt& b); \ + C10_API bool operator<(scalar_t a, const SymInt& b); \ + C10_API bool operator<=(scalar_t a, const SymInt& b); \ + C10_API bool operator>(scalar_t a, const SymInt& b); \ + C10_API bool operator>=(scalar_t a, const SymInt& b); + +DECLARE_SYMINT_OP_INTONLY(int64_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(int32_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(uint64_t, SymInt) +DECLARE_SYMINT_OP_INTONLY(uint32_t, SymInt) +DECLARE_SYMINT_OP(int64_t, SymInt) +DECLARE_SYMINT_OP(int32_t, SymInt) // make sure constants work +DECLARE_SYMINT_OP(uint64_t, SymInt) +DECLARE_SYMINT_OP(uint32_t, SymInt) +DECLARE_SYMINT_OP(double, SymFloat) +DECLARE_SYMINT_OP(float, SymFloat) // just for completeness + +// On OSX size_t is different than uint64_t so we have to +// define it separately +#if defined(__APPLE__) +DECLARE_SYMINT_OP_INTONLY(size_t, SymInt) +DECLARE_SYMINT_OP(size_t, SymInt) +#endif + +#undef DECLARE_SYMINT_OP + +C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); +C10_API SymInt operator-(const SymInt& s); + +inline bool 
sym_eq(int64_t a, int64_t b) { + return a == b; +} + +inline SymBool sym_eq(const SymInt& a, const SymInt& b) { + return a.sym_eq(b); +} + +inline bool sym_ne(int64_t a, int64_t b) { + return a != b; +} + +inline SymBool sym_ne(const SymInt& a, const SymInt& b) { + return a.sym_ne(b); +} + +inline bool sym_lt(int64_t a, int64_t b) { + return a < b; +} + +inline SymBool sym_lt(const SymInt& a, const SymInt& b) { + return a.sym_lt(b); +} + +inline bool sym_le(int64_t a, int64_t b) { + return a <= b; +} + +inline SymBool sym_le(const SymInt& a, const SymInt& b) { + return a.sym_le(b); +} + +inline bool sym_gt(int64_t a, int64_t b) { + return a > b; +} + +inline SymBool sym_gt(const SymInt& a, const SymInt& b) { + return a.sym_gt(b); +} + +inline bool sym_ge(int64_t a, int64_t b) { + return a >= b; +} + +inline SymBool sym_ge(const SymInt& a, const SymInt& b) { + return a.sym_ge(b); +} + +inline bool definitely_true( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.has_hint() && b.guard_bool(file, line); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymIntArrayRef.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymIntArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..91557143e9d8627721379a50dcff92fa60a488c5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymIntArrayRef.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { +using SymIntArrayRef = ArrayRef; + +inline at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) { + return IntArrayRef(reinterpret_cast(ar.data()), ar.size()); +} + +// TODO: a SymIntArrayRef containing a heap allocated large negative integer +// can actually technically be converted to an IntArrayRef... but not with +// the non-owning API we have here. We can't reinterpet cast; we have to +// allocate another buffer and write the integers into it. If you need it, +// we can do it. But I don't think you need it. 
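+
+// Illustrative usage sketch (not part of the upstream header) of the
+// conversion helpers declared below. `sym_sizes` stands for any
+// SymIntArrayRef (for example, sizes coming from a symbolically traced
+// tensor), and `use_int_sizes` is a hypothetical consumer of concrete sizes:
+//
+//   if (auto concrete = c10::asIntArrayRefSlowOpt(sym_sizes)) {
+//     use_int_sizes(*concrete);       // no heap-allocated (symbolic) SymInts
+//   } else {
+//     // At least one element is symbolic: handle the symbolic path here, or,
+//     // where symbolic sizes would be a bug, fail loudly with
+//     //   C10_AS_INTARRAYREF_SLOW(sym_sizes)
+//     // which raises an error pointing at the offending file/line.
+//   }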
+ +inline c10::optional asIntArrayRefSlowOpt( + c10::SymIntArrayRef ar) { + for (const c10::SymInt& sci : ar) { + if (sci.is_heap_allocated()) { + return c10::nullopt; + } + } + + return {asIntArrayRefUnchecked(ar)}; +} + +inline at::IntArrayRef asIntArrayRefSlow( + c10::SymIntArrayRef ar, + const char* file, + int64_t line) { + for (const c10::SymInt& sci : ar) { + TORCH_CHECK( + !sci.is_heap_allocated(), + file, + ":", + line, + ": SymIntArrayRef expected to contain only concrete integers"); + } + return asIntArrayRefUnchecked(ar); +} + +#define C10_AS_INTARRAYREF_SLOW(a) c10::asIntArrayRefSlow(a, __FILE__, __LINE__) + +// Prefer using a more semantic constructor, like +// fromIntArrayRefKnownNonNegative +inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) { + return SymIntArrayRef( + reinterpret_cast(array_ref.data()), array_ref.size()); +} + +inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) { + return fromIntArrayRefUnchecked(array_ref); +} + +inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) { + for (long i : array_ref) { + TORCH_CHECK( + SymInt::check_range(i), + "IntArrayRef contains an int that cannot be represented as a SymInt: ", + i); + } + return SymIntArrayRef( + reinterpret_cast(array_ref.data()), array_ref.size()); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymNodeImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymNodeImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..02fba001736a2320c1ed131d91c2f2ac9ed2450e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymNodeImpl.h @@ -0,0 +1,218 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +class SymNodeImpl; +using SymNode = c10::intrusive_ptr; + +// When you add a method, you also need to edit +// torch/csrc/jit/python/init.cpp +// torch/csrc/utils/python_symnode.h +// c10/core/ConstantSymNodeImpl.h +class C10_API SymNodeImpl : public c10::intrusive_ptr_target { + public: + ~SymNodeImpl() override = default; + + template + c10::intrusive_ptr dyn_cast() const { + return c10::intrusive_ptr::reclaim_copy(dynamic_cast(this)); + } + + // these could be pure virtual when we implement LTC versions + virtual bool is_int() { + TORCH_CHECK(false, "NYI"); + }; + virtual bool is_bool() { + TORCH_CHECK(false, "NYI"); + }; + virtual bool is_float() { + TORCH_CHECK(false, "NYI"); + }; + virtual bool is_nested_int() const { + return false; + }; + virtual SymNode add(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sub(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode mul(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode truediv(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode pow(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode floordiv(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode mod(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode eq(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode ne(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode gt(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode lt(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode le(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode ge(const SymNode& 
other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode ceil() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode floor() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode neg() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_min(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_max(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_or(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_and(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_not() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) { + TORCH_CHECK(false, "NYI"); + }; + // NB: self is ignored here, only the arguments are used + virtual SymNode is_contiguous( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode is_channels_last_contiguous_2d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode is_channels_last_contiguous_3d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode is_channels_last_strides_2d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode is_channels_last_strides_3d( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode is_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode clone() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_float() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode wrap_int(int64_t num) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode wrap_float(double num) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode wrap_bool(bool num) { + TORCH_CHECK(false, "NYI"); + }; + virtual int64_t guard_int(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; + virtual bool guard_bool(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; + virtual double guard_float(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; + virtual bool guard_size_oblivious(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } + virtual bool expect_true(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + }; + virtual bool expect_size(const char* file, int64_t line) { + // No improvement for unbacked SymInts by default, replace this + // with a better implementation! 
+ return ge(wrap_int(0))->guard_bool(file, line); + }; + virtual int64_t int_() { + TORCH_CHECK(false, "NYI"); + }; + virtual bool bool_() { + TORCH_CHECK(false, "NYI"); + }; + virtual bool has_hint() { + TORCH_CHECK(false, "NYI"); + }; + virtual std::string str() { + TORCH_CHECK(false, "NYI"); + }; + virtual c10::optional nested_int() { + return c10::nullopt; + } + virtual c10::optional nested_int_coeff() { + return c10::nullopt; + } + virtual c10::optional constant_int() { + return c10::nullopt; + } + virtual c10::optional constant_bool() { + return c10::nullopt; + } + virtual c10::optional maybe_as_int() { + return c10::nullopt; + } + virtual bool is_constant() { + return false; + } + virtual bool is_symbolic() { + return true; + } + std::ostream& operator<<(std::ostream& os) { + os << str(); + return os; + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/SymbolicShapeMeta.h b/MLPY/Lib/site-packages/torch/include/c10/core/SymbolicShapeMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..f68f87e9486833aff29924da4afdad18a788b1a9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/SymbolicShapeMeta.h @@ -0,0 +1,214 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +class C10_API SymbolicShapeMeta { + public: + // Basic metadata from which other quantities are derived + SymDimVector sizes_ = {0}; + SymDimVector strides_ = {1}; + SymInt storage_offset_ = 0; + + bool strides_valid_ = true; // e.g. for sparse where there are no strides + + SymbolicShapeMeta() = default; + SymbolicShapeMeta(const SymbolicShapeMeta& other); + + void refresh_numel() { + // Non-const, don't need to hold mutables_ lock + available_.fetch_and(~numel_avail); + numel_ = 1; + } + + void refresh_contiguous() { + // Non-const, don't need to hold mutables_ lock + available_.fetch_and(numel_avail); + is_contiguous_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_contiguous_ = false; + is_channels_last_ = false; + is_channels_last_3d_ = false; + is_non_overlapping_and_dense_ = false; + } + + int64_t dim() const { + return static_cast(sizes_.size()); + } + + // Accessors for derived quantities, computed lazily on first access + + bool has_numel() const { + return available_.load() & numel_avail; + } + bool has_is_contiguous() const { + return available_.load() & is_contiguous_avail; + } + bool has_is_channels_last_contiguous() const { + return available_.load() & is_channels_last_contiguous_avail; + } + bool has_is_channels_last_3d_contiguous() const { + return available_.load() & is_channels_last_3d_contiguous_avail; + } + bool has_is_channels_last() const { + return available_.load() & is_channels_last_avail; + } + bool has_is_channels_last_3d() const { + return available_.load() & is_channels_last_3d_avail; + } + bool has_is_non_overlapping_and_dense() const { + return available_.load() & is_non_overlapping_and_dense_avail; + } + + // Accessors to cached derived properties + // DO NOT call with mutables_ lock held + const SymInt& numel() const { + if (C10_UNLIKELY(!has_numel())) { + init_numel(); + } + return numel_; + } + + const SymBool& is_contiguous() const { + if (C10_UNLIKELY(!has_is_contiguous())) { + init_is_contiguous(); + } + return is_contiguous_; + } + + const SymBool& is_channels_last_contiguous() const { + if (C10_UNLIKELY(!has_is_channels_last_contiguous())) { + init_is_channels_last_contiguous(); + } + return 
is_channels_last_contiguous_; + } + + const SymBool& is_channels_last_3d_contiguous() const { + if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) { + init_is_channels_last_3d_contiguous(); + } + return is_channels_last_3d_contiguous_; + } + + const SymBool& is_channels_last() const { + if (C10_UNLIKELY(!has_is_channels_last())) { + init_is_channels_last(); + } + return is_channels_last_; + } + + const SymBool& is_channels_last_3d() const { + if (C10_UNLIKELY(!has_is_channels_last_3d())) { + init_is_channels_last_3d(); + } + return is_channels_last_3d_; + } + + const SymBool& is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) { + init_is_non_overlapping_and_dense(); + } + return is_non_overlapping_and_dense_; + } + + // Assumptions so we can short-circuit computation + // NOTE: Don't need to lock mutables_ since these aren't const + void assume_contiguous(SymBool val = true) { + is_contiguous_ = std::move(val); + available_.fetch_or(is_contiguous_avail); + } + void assume_channels_last_contiguous(SymBool val = true) { + is_contiguous_ = std::move(val); + available_.fetch_or(is_channels_last_contiguous_avail); + } + void assume_channels_last_3d_contiguous(SymBool val = true) { + is_channels_last_3d_contiguous_ = std::move(val); + available_.fetch_or(is_channels_last_3d_contiguous_avail); + } + void assume_channels_last(SymBool val = true) { + is_channels_last_ = std::move(val); + available_.fetch_or(is_channels_last_avail); + } + void assume_channels_last_3d(SymBool val = true) { + is_channels_last_3d_ = std::move(val); + available_.fetch_or(is_channels_last_3d_avail); + } + void assume_non_overlapping_and_dense(SymBool val = true) { + is_non_overlapping_and_dense_ = std::move(val); + available_.fetch_or(is_non_overlapping_and_dense_avail); + } + + private: + SymBool compute_contiguous() const; + SymBool compute_channels_last_contiguous_2d() const; + SymBool compute_channels_last_contiguous_3d() const; + SymBool compute_strides_like_channels_last_2d() const; + SymBool compute_strides_like_channels_last_3d() const; + SymBool compute_non_overlapping_and_dense() const; + + // These are little wrappers over the real compute_ functions that + // can make use of other contiguity fields to short circuit. + // They need to be implemented separately for SymBool, as SymBool does + // not short circuit. + // TODO: should the SymBool cases avoid the short circuit? Need to reason + // if its correct, and reason if the simpler expressions are better for + // analysis (maybe not!) 
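+
+  // Illustrative sketch (not part of the upstream header) of how the const
+  // accessors above lazily initialize a derived quantity, simplified to omit
+  // the mutables_ lock taken inside the init_ helpers:
+  //
+  //   const SymInt& numel() const {
+  //     if (!(available_.load() & numel_avail)) {  // i.e. !has_numel()
+  //       init_numel();   // computes the product of sizes_ and, via
+  //                       // set_numel(), marks numel_avail in available_
+  //     }
+  //     return numel_;    // the bit is only published after numel_ is written
+  //   }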
+ + SymBool compute_channels_last_contiguous_3d_dim5() const; + SymBool compute_channels_last_2d_dim5() const; + SymBool compute_channels_last_3d_dim5() const; + SymBool compute_is_non_overlapping_and_dense_dim4() const; + SymBool compute_is_non_overlapping_and_dense_dim5() const; + SymBool compute_is_non_overlapping_and_dense_anydim() const; + + void init_numel() const; + void init_is_contiguous() const; + void init_is_channels_last_contiguous() const; + void init_is_channels_last_3d_contiguous() const; + void init_is_channels_last() const; + void init_is_channels_last_3d() const; + void init_is_non_overlapping_and_dense() const; + + // NOTE: These only set if !has_foo() + void set_numel(SymInt val) const; + void set_is_contiguous(SymBool val) const; + void set_is_channels_last_contiguous(SymBool val) const; + void set_is_channels_last_3d_contiguous(SymBool val) const; + void set_is_channels_last(SymBool val) const; + void set_is_channels_last_3d(SymBool val) const; + void set_is_non_overlapping_and_dense(SymBool val) const; + + // Lazily initialized variables, with the corresponding available_ flag + // indicating whether the value has been initialized + mutable std::atomic available_{0}; + enum avail { + numel_avail = 1 << 0, + is_contiguous_avail = 1 << 1, + is_channels_last_contiguous_avail = 1 << 2, + is_channels_last_3d_contiguous_avail = 1 << 3, + is_channels_last_avail = 1 << 4, + is_channels_last_3d_avail = 1 << 5, + is_non_overlapping_and_dense_avail = 1 << 6, + }; + + // Mutex to prevent races when initializing the variable from const accessors + mutable std::mutex mutables_; + mutable SymInt numel_ = 1; + mutable SymBool is_contiguous_{true}; + mutable SymBool is_channels_last_contiguous_{false}; + mutable SymBool is_channels_last_3d_contiguous_{false}; + mutable SymBool is_channels_last_{false}; + mutable SymBool is_channels_last_3d_{false}; + mutable SymBool is_non_overlapping_and_dense_{true}; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/TensorImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/TensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..adfaa1adb5419fc692777994db1585ab7fae2923 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/TensorImpl.h @@ -0,0 +1,3249 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// A global boolean variable to control whether we free memory when a Tensor +// is shrunk to a smaller size. As a result, a Tensor is always going to +// keep the memory allocated for its maximum capacity reshaped to so far. +// +// This parameter is respected "upper-case" methods which call Resize() +// (e.g., CopyFrom, ResizeLike); it is NOT respected by Tensor::resize_ +// or ShrinkTo, both of which guarantee to never to free memory. +C10_DECLARE_bool(caffe2_keep_on_shrink); + +// Since we can have high variance in blob memory allocated across different +// inputs in the same run, we will shrink the blob only if the memory gain +// is larger than this flag in bytes. This only applies to functions which +// respect caffe2_keep_on_shrink. 
+C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory); + +namespace at { +class Tensor; +class TensorBase; +} // namespace at + +namespace c10 { + +/** + * A utility function to convert vector to vector. + */ +inline std::vector ToVectorint64_t(const ArrayRef& src) { + return std::vector(src.begin(), src.end()); +} + +/** + * Return product of all dimensions starting from k + */ +inline int64_t size_from_dim_(int k, IntArrayRef dims) { + int64_t r = 1; + for (const auto i : c10::irange(k, dims.size())) { + r *= dims[i]; + } + return r; +} + +// Product of all dims up to k (not including dims[k]) +inline int64_t size_to_dim_(int k, IntArrayRef dims) { + TORCH_CHECK(k >= 0 && static_cast(k) <= dims.size()); + int64_t r = 1; + for (const auto i : c10::irange(k)) { + r *= dims[i]; + } + return r; +} + +// Product of all dims between k and l (not including dims[k] and dims[l]) +inline int64_t size_between_dim_(int k, int l, IntArrayRef dims) { + TORCH_CHECK((unsigned)l < dims.size() && (unsigned)k < dims.size()); + int64_t r = 1; + if (k < l) { + for (int i = k + 1; i < l; ++i) { + r *= dims[i]; + } + } else { + for (int i = l + 1; i < k; ++i) { + r *= dims[i]; + } + } + return r; +} + +// Wrap around axis_index if it is negative, s.t., -1 is the last dim +inline int canonical_axis_index_(int axis_index, int ndims) { + TORCH_CHECK(axis_index >= -ndims); + TORCH_CHECK(axis_index < ndims); + if (axis_index < 0) { + return axis_index + ndims; + } + return axis_index; +} + +using PlacementDtor = void (*)(void*, size_t); + +/* + * A Context that will call extra placement deleter during + * deconstruction. + * + * Accept a already constructed DataPtr and store it as member + * during destruction, we'll call extra deleter on the underlying + * data pointer before the DataPtr is destructed. + * `data_ptr_` owns the memory. + */ +struct C10_API PlacementDeleteContext { + DataPtr data_ptr_; + PlacementDtor placement_dtor_; + size_t size_; + PlacementDeleteContext( + DataPtr&& data_ptr, + PlacementDtor placement_dtor, + size_t size) + : data_ptr_(std::move(data_ptr)), + placement_dtor_(placement_dtor), + size_(size) {} + static DataPtr makeDataPtr( + DataPtr&& data_ptr, + PlacementDtor placement_dtor, + size_t size, + Device device); + ~PlacementDeleteContext() { + placement_dtor_(data_ptr_.get(), size_); + // original memory will be freed when data_ptr_ is destructed + } +}; + +struct C10_API AutogradMetaInterface { + virtual void set_requires_grad( + bool requires_grad, + at::TensorImpl* self_impl) = 0; + virtual bool requires_grad() const = 0; + virtual at::Tensor& mutable_grad() = 0; + virtual const at::Tensor& grad() const = 0; + virtual const at::Tensor& fw_grad(uint64_t level, const at::TensorBase& self) + const = 0; + virtual void set_fw_grad( + const at::TensorBase& new_grad, + const at::TensorBase& self, + uint64_t level, + bool is_inplace_op) = 0; + virtual ~AutogradMetaInterface(); +}; + +namespace impl { + +// Unfortunately, the definition of AutogradMeta lives in a separate +// compilation unit than TensorImpl (libtorch.so versus libc10.so) +// which means that we cannot construct an AutogradMeta from TensorImpl, +// not even from the cpp file. So we have to indirect it through a factory +// function which will be initialized when we load libtorch.so. + +struct C10_API AutogradMetaFactory { + virtual ~AutogradMetaFactory() = default; + virtual std::unique_ptr make() const = 0; + // This method is the dumbest method. 
But I don't have access + // to Tensor (not TensorImpl) which is undefined in this header. + virtual const at::Tensor& undefined_tensor() const = 0; +}; + +C10_API void SetAutogradMetaFactory(AutogradMetaFactory* factory); +C10_API AutogradMetaFactory* GetAutogradMetaFactory(); + +struct C10_API AutogradMetaFactoryRegisterer { + explicit AutogradMetaFactoryRegisterer(AutogradMetaFactory* factory) { + SetAutogradMetaFactory(factory); + } +}; + +} // namespace impl + +struct C10_API NamedTensorMetaInterface { + virtual ~NamedTensorMetaInterface() = default; + virtual std::unique_ptr clone() const { + TORCH_INTERNAL_ASSERT( + false, "Not implemented: NamedTensorMetaInterface::clone"); + }; + virtual int64_t slow_dim() const { + TORCH_INTERNAL_ASSERT( + false, "Not implemented: NamedTensorMetaInterface::slow_dim"); + }; +}; + +// For ease of copy pasting +#if 0 +is_contiguous +is_channels_last_contiguous +is_channels_last_3d_contiguous +is_channels_last +is_channels_last_3d +is_non_overlapping_and_dense +#endif + +/** + * This structure is intended to hold additional metadata of the specific device + * backend. + **/ +struct C10_API BackendMeta : intrusive_ptr_target { + ~BackendMeta() override = default; + virtual intrusive_ptr clone( + const intrusive_ptr& ptr) const { + return ptr; + } +}; + +struct C10_API ExtraMeta { + std::unique_ptr symbolic_shape_meta_ = nullptr; + std::unique_ptr named_tensor_meta_ = nullptr; + intrusive_ptr backend_meta_ = nullptr; + c10::optional custom_data_ptr_error_msg_ = c10::nullopt; + c10::optional custom_storage_error_msg_ = c10::nullopt; + + ExtraMeta() = default; + ExtraMeta(const ExtraMeta& other) { + if (other.symbolic_shape_meta_) { + symbolic_shape_meta_ = + std::make_unique(*other.symbolic_shape_meta_); + } + if (other.named_tensor_meta_) { + named_tensor_meta_ = other.named_tensor_meta_->clone(); + } + if (other.backend_meta_) { + backend_meta_ = other.backend_meta_->clone(other.backend_meta_); + } + if (other.custom_data_ptr_error_msg_) { + custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_; + } + if (other.custom_storage_error_msg_) { + custom_storage_error_msg_ = other.custom_storage_error_msg_; + } + } + + ExtraMeta( + std::unique_ptr symbolic_shape_meta, + std::unique_ptr named_tensor_meta, + intrusive_ptr backend_meta, + c10::optional custom_data_ptr_error_msg = c10::nullopt, + c10::optional custom_storage_access_error_msg = c10::nullopt) + : symbolic_shape_meta_(std::move(symbolic_shape_meta)), + named_tensor_meta_(std::move(named_tensor_meta)), + backend_meta_(std::move(backend_meta)), + custom_data_ptr_error_msg_(std::move(custom_data_ptr_error_msg)), + custom_storage_error_msg_(std::move(custom_storage_access_error_msg)) {} + + std::unique_ptr clone() const { + return std::make_unique(*this); + } +}; + +// NOTE [ Version Counter Sharing ] +// +// Every Tensor has a version counter. Version counters are incremented whenever +// the data or size of a tensor changes through in-place Variable operations. +// Version counters are used to detect modifications to saved variables which +// would result in incorrect gradient calculations. Version counters may be +// shared between Variables: +// +// 1. A view shares the version counter of the base Variable, +// 2. `x.detach()` shares the version counter of `x`, +// 3. Unpacked saved variables share the version counter of the source. +// +// Version counters are not shared in these scenarios: +// +// 1. 
When we replace a `Variable`'s underlying `Tensor` by calling +// `set_data(...)`, +// 2. `x.data` does not share the version counter of `x`. (See discussion at +// https://github.com/pytorch/pytorch/issues/5396) +// +// Question: Why do we put the version counter in TensorImpl instead of +// AutogradMeta? +// +// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta +// when its `requires_grad_` is false, but when we use this tensor in the +// forward pass of a function that requires saving this tensor for backward, we +// need to keep track of this tensor's version to make sure it's always valid in +// the autograd graph. +// +// To achieve this goal, we put the version counter in TensorImpl instead of +// AutogradMeta, and have it always be available. This allows us to have the +// optimization of not carrying AutogradMeta when a tensor doesn't require +// gradient. +// +// A hypothetical alternative way to achieve this goal is to initialize +// AutogradMeta and create the version counter for the non-requires-grad tensor +// only when it's saved for backward. However, since saving a tensor for +// backward happens in the forward pass, and our invariant is that forward pass +// needs to be thread-safe, lazy-initializing AutogradMeta when saving a tensor +// can introduce race conditions when we are running the forward pass in +// multi-thread scenarios, thus making the forward pass not thread-safe anymore, +// which breaks the invariant. +struct C10_API VariableVersion { + private: + struct VersionCounter : intrusive_ptr_target { + VersionCounter(uint32_t version) : version_(version) {} + std::atomic version_; + }; + c10::intrusive_ptr version_counter_; + + public: + // Note [Disabled VariableVersion] + // VariableVersion struct has an intrusive_ptr pointing VersionCounter struct + // with an atomic variable. Thus `VariableVersion(/*version=*/0)` is not as + // cheap as we expected. In some cases constructing a VariableVersion with + // version 0 is not necessary so we add a cheap constructor which + // doesn't allocate the intrusive_ptr. + // Example use cases are: + // - Inference tensors don't track version counter, so they'll just always + // have disabled VariableVersion. + // - In SavedVariable class we override version_counter_ inside its + // constructor + // so that we can use the cheap constructor there. + enum Disabled { DISABLED }; + // It's okay to return true even for inference tensor which + // doesn't have version counter enabled. + // We want to be permissive here since in many cases (e.g. make_variable) + // we can std::move a TensorImpl if there's no other uses which saves us + // an additional TensorImpl allocation. + bool unique() const { + return version_counter_ ? 1 == version_counter_.use_count() : true; + } + // NOTE: As of C++11 and 14, default-constructing a std::atomic variable + // leaves it in a persistently undefined state. See + // https://cplusplus.github.io/LWG/issue2334. + VariableVersion(uint32_t version) + : version_counter_(c10::make_intrusive(version)) {} + VariableVersion(Disabled = DISABLED) {} + + bool enabled() const { + return version_counter_; + } + + // Note [Inplace update inference tensor] + // 1. Inplace update to inference tensor is forbidden in normal mode. + // For example: + // inference_tensor.copy_(normal_tensor_requires_grad) + // This inplace makes inference_tensor have requires_grad=True and + // have a grad_fn. 
This is bad because views of `inference_tensor` + // created in InferenceMode won't be able to know the grad_fn since + // their ViewMeta were not recorded. To match NoGradMode behavior + // that "inplace update to a view created in NoGradMode raise an error", + // we just ban inplace update to inference tensor since we can't tell + // if an inference tensor is a view created in InferenceMode. + // + // Note that views of normal tensor created in InferenceMode has proper + // ViewMeta so that they're aware of the grad_fn correctly. + // + // 2. Inplace update to inference tensor in inference tensor doesn't bump + // version counter. + // * It either doesn't call bump() by skipping ADInplaceOrView kernel, + // - e.g. inference_tensor.add_(1) + // * or bump() is a no-op for inference tensor. + // - e.g. inference_tensor.add_(normal_tensor) + void bump() { + // TODO: Replace the link to the documentation once it's available. + TORCH_CHECK( + version_counter_ || InferenceMode::is_enabled(), + "Inplace update to inference tensor outside InferenceMode is not allowed." + "You can make a clone to get a normal tensor before doing inplace update." + "See https://github.com/pytorch/rfcs/pull/17 for more details."); + if (version_counter_) { + ++version_counter_->version_; + } + } + + void set_version(int64_t i) { + TORCH_CHECK( + version_counter_, + "Tried to call torch.autograd._unsafe_set_version() on a tensor " + "that does not have a version counter. Was it created in inference mode?"); + TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i); + version_counter_->version_ = i; + } + + // Inference tensor doesn't have version counter so it shouldn't be + // accessed. + uint32_t current_version() const { + TORCH_CHECK( + version_counter_, "Inference tensors do not track version counter."); + return version_counter_->version_; + } +}; + +// Forward declaration of TensorImpl needed for forward declaration of +// C10_TensorImpl_Size_Check_Dummy_Class +struct C10_API TensorImpl; + +/** + * NOTE: Some TensorImpl methods are small and not overridden in the + * PyTorch codebase itself, but may theoretically need to be + * overridden by third-party TensorImpl subclasses. This macro allows + * users that need maximum performance and don't need these extension + * points to disable them with a build-time flag. (In particular, + * XLA's XLATensorImpl currently overrides these methods, so we can't + * enable this flag by default.) + */ +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY +#define TENSORIMPL_MAYBE_VIRTUAL +#else +#define TENSORIMPL_MAYBE_VIRTUAL virtual +#endif + +/** + * The low-level representation of a tensor, which contains a pointer + * to a storage (which contains the actual data) and metadata (e.g., sizes and + * strides) describing this particular view of the data as a tensor. + * + * Some basic characteristics about our in-memory representation of + * tensors: + * + * - It contains a pointer to a storage struct (Storage/StorageImpl) + * which contains the pointer to the actual data and records the + * data type and device of the view. This allows multiple tensors + * to alias the same underlying data, which allows to efficiently + * implement differing *views* on a tensor. + * + * - The tensor struct itself records view-specific metadata about + * the tensor, e.g., sizes, strides and offset into storage. + * Each view of a storage can have a different size or offset. + * + * - This class is intrusively refcounted. 
It is refcounted so that + * we can support prompt deallocation of large tensors; it is + * intrusively refcounted so that we can still perform reference + * counted operations on raw pointers, which is often more convenient + * when passing tensors across language boundaries. + * + * - For backwards-compatibility reasons, a tensor may be in an + * uninitialized state. A tensor may be uninitialized in the following + * two ways: + * + * - A tensor may be DTYPE UNINITIALIZED. A tensor of this + * form has an uninitialized dtype. This situation most + * frequently arises when a user writes Tensor x(CPU). The dtype + * is subsequently initialized when mutable_data() is + * invoked for the first time. + * + * - A tensor may be STORAGE UNINITIALIZED. A tensor of this form + * has non-zero size, but has a storage with a null data pointer. + * This situation most frequently arises when a user calls + * Resize() or FreeMemory(). This is because Caffe2 historically + * does lazy allocation: allocation of data doesn't occur until + * mutable_data() is invoked. A tensor with zero size is + * always storage initialized, because no allocation is necessary + * in this case. + * + * All combinations of these two uninitialized states are possible. + * Consider the following transcript in idiomatic Caffe2 API: + * + * Tensor x(CPU); // x is storage-initialized, dtype-UNINITIALIZED + * x.Resize(4); // x is storage-UNINITIALIZED, dtype-UNINITIALIZED + * x.mutable_data(); // x is storage-initialized, dtype-initialized + * x.FreeMemory(); // x is storage-UNINITIALIZED, dtype-initialized. + * + * All other fields on tensor are always initialized. In particular, + * size is always valid. (Historically, a tensor declared as Tensor x(CPU) + * also had uninitialized size, encoded as numel == -1, but we have now + * decided to default to zero size, resulting in numel == 0). + * + * Uninitialized storages MUST be uniquely owned, to keep our model + * simple. Thus, we will reject operations which could cause an + * uninitialized storage to become shared (or a shared storage to + * become uninitialized, e.g., from FreeMemory). + * + * In practice, tensors which are storage-UNINITIALIZED and + * dtype-UNINITIALIZED are *extremely* ephemeral: essentially, + * after you do a Resize(), you basically always call mutable_data() + * immediately afterwards. Most functions are not designed to + * work if given a storage-UNINITIALIZED, dtype-UNINITIALIZED tensor. + * + * We intend to eliminate all uninitialized states, so that every + * tensor is fully initialized in all fields. Please do not write new code + * that depends on these uninitialized states. + */ +struct C10_API TensorImpl : public c10::intrusive_ptr_target { + TensorImpl() = delete; + ~TensorImpl() override; + // Note [Enum ImplType] + // This enum is temporary. In the followup refactor we should + // think about how to specialize TensorImpl creation for view + // tensors. Currently we only special case its key_set_ but + // there's also potential to share version_counter_ directly + // without creating first and then override in as_view. + enum ImplType { VIEW }; + + /** + * Construct a 1-dim 0-size tensor backed by the given storage. + */ + TensorImpl( + Storage&& storage, + DispatchKeySet, + const caffe2::TypeMeta data_type); + + // See Note [Enum ImplType] + TensorImpl( + ImplType, + Storage&& storage, + DispatchKeySet, + const caffe2::TypeMeta data_type); + + /** + * Construct a 1-dim 0 size tensor that doesn't have a storage. 
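+   *
+   * A schematic, hypothetical call, for illustration only; it assumes a
+   * backend that carries no Storage, e.g. a meta-style tensor:
+   *
+   *   auto impl = c10::make_intrusive<TensorImpl>(
+   *       c10::DispatchKeySet(c10::DispatchKey::Meta),
+   *       caffe2::TypeMeta::Make<float>(),
+   *       c10::Device(c10::DeviceType::Meta));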
+ */ + TensorImpl( + DispatchKeySet, + const caffe2::TypeMeta data_type, + c10::optional device_opt); + + // Legacy constructors so I don't have to go update call sites. + // TODO: When Variable is added, delete these constructors + TensorImpl( + Storage&& storage, + DispatchKey dispatch_key, + const caffe2::TypeMeta data_type) + : TensorImpl( + std::move(storage), + DispatchKeySet(dispatch_key), + data_type) {} + TensorImpl( + DispatchKey dispatch_key, + const caffe2::TypeMeta data_type, + c10::optional device_opt) + : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} + + private: + // This constructor is private, because the data_type is redundant with + // storage. Still, we pass it in separately because it's easier to write + // the initializer list if we're not worried about storage being moved out + // from under us. + TensorImpl( + Storage&& storage, + DispatchKeySet, + const caffe2::TypeMeta data_type, + c10::optional); + + public: + TensorImpl(const TensorImpl&) = delete; + TensorImpl& operator=(const TensorImpl&) = delete; + TensorImpl(TensorImpl&&) = delete; + TensorImpl& operator=(TensorImpl&&) = delete; + + /** + * Release (decref) storage, and any other external allocations. This + * override is for `intrusive_ptr_target` and is used to implement weak + * tensors. + */ + void release_resources() override; + + public: + /** + * Return the DispatchKeySet corresponding to this Tensor, specifying + * all of the DispatchKeys that this Tensor identifies as. This is the + * information used to dispatch operations on this tensor. + */ + DispatchKeySet key_set() const { + return key_set_; + } + + private: + [[noreturn]] void throw_cannot_call_with_symbolic(const char* meth) const; + + // NOTE: The general recipe for customizable methods is that the fastpath + // function (e.g., sizes()) does an unlikely policy test, and if doesn't + // trigger, it does the fast path implementation with no checks and going + // directly to on-TensorImpl fields. In particular, you never need to + // check ExtraMeta if the policy doesn't trigger, as non-trivial ExtraMeta + // implies the policy will always match. + // + // The default implementations of methods are "safe": they do extra tests + // to make sure the internal state is consistent no matter if you are + // doing symbolic shapes or not. If you don't want the tests, directly + // override the custom method (e.g., custom_sizes()) to do your preferred + // behavior. + + public: + /** + * Return a reference to the sizes of this tensor. This reference remains + * valid as long as the tensor is live and not resized. 
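+   *
+   * Schematic example: for an impl describing a tensor of shape {2, 3},
+   *
+   *   IntArrayRef s = impl->sizes();  // s.size() == 2, s[0] == 2, s[1] == 3
+   *
+   * The returned IntArrayRef is non-owning; copy it (e.g. into a
+   * std::vector<int64_t>) if it must outlive a later resize.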
+ */ + IntArrayRef sizes() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sizes_custom(); + } + return sizes_and_strides_.sizes_arrayref(); + } + + SymIntArrayRef sym_sizes() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_sizes_custom(); + } + // Sizes guaranteed to be non-negative, so unchecked cast is OK + return c10::fromIntArrayRefKnownNonNegative( + sizes_and_strides_.sizes_arrayref()); + } + + IntArrayRef sizes_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("sizes"); + } + return sizes_and_strides_.sizes_arrayref(); + } + + SymIntArrayRef sym_sizes_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().sizes_; + } else { + // Sizes guaranteed to be non-negative, so unchecked cast is OK + return c10::fromIntArrayRefKnownNonNegative(sizes_default()); + } + } + + // From https://stackoverflow.com/a/3057522/23845 + // TODO: does C++14 have a stdlib template for this? + template + struct identity { + typedef T type; + }; + + template + ArrayRef generic_sizes() { + return _generic_sizes(identity()); + } + + ArrayRef _generic_sizes(identity) { + return sizes(); + } + ArrayRef _generic_sizes(identity) { + return sym_sizes(); + } + + template + ArrayRef generic_strides() { + return _generic_strides(identity()); + } + + ArrayRef _generic_strides(identity) { + return strides(); + } + ArrayRef _generic_strides(identity) { + return sym_strides(); + } + + template + T generic_storage_offset() { + return _generic_storage_offset(identity()); + } + + int64_t _generic_storage_offset(identity) { + return storage_offset(); + } + c10::SymInt _generic_storage_offset(identity) { + return sym_storage_offset(); + } + + /** + * The number of elements in a tensor. + * + * WARNING: Previously, if you were using the Caffe2 API, you could + * test numel() == -1 to see if a tensor was uninitialized. This + * is no longer true; numel always accurately reports the product + * of sizes of a tensor. + */ + int64_t numel() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return numel_custom(); + } + return numel_; + } + + c10::SymInt sym_numel() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_numel_custom(); + } + return c10::SymInt(SymInt::UNCHECKED, numel_); + } + + int64_t numel_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("numel"); + } + return numel_; + } + + c10::SymInt sym_numel_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().numel(); + } else { + return c10::SymInt(SymInt::UNCHECKED, numel_); + } + } + + /** + * Return the number of dimensions of this tensor. Note that 0-dimension + * represents a Tensor that is a Scalar, e.g., one that has a single element. + */ + int64_t dim() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return dim_custom(); + } + return static_cast(sizes_and_strides_.size()); + } + + int64_t dim_default() const { + if (has_symbolic_sizes_strides_) { + return static_cast(symbolic_shape_meta().sizes_.size()); + } else { + return static_cast(sizes_and_strides_.size()); + } + } + + /** + * Return the offset in number of elements into the storage that this + * tensor points to. Most tensors have storage_offset() == 0, but, + * for example, an index into a tensor will have a non-zero storage_offset(). 
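+   *
+   * Schematically, if `t` views the same storage as `base` but starts at
+   * element 4 (e.g. a slice that drops the first four elements), then
+   * t.storage_offset() == 4, counted in elements of the dtype.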
+ * + * WARNING: This is NOT computed in bytes. + */ + int64_t storage_offset() const { + // TODO: maybe this should be toggled by strides + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return storage_offset_custom(); + } + return storage_offset_; + } + + c10::SymInt sym_storage_offset() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_storage_offset_custom(); + } + return c10::SymInt(SymInt::UNCHECKED, storage_offset_); + } + + int64_t storage_offset_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("storage_offset"); + } + return storage_offset_; + } + + c10::SymInt sym_storage_offset_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().storage_offset_; + } else { + return c10::SymInt(SymInt::UNCHECKED, storage_offset_); + } + } + + /** + * Return a reference to the strides of this tensor. This reference remains + * valid as long as the tensor is live and not restrided. + */ + IntArrayRef strides() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return strides_custom(); + } + return sizes_and_strides_.strides_arrayref(); + } + + c10::SymIntArrayRef sym_strides() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_strides_custom(); + } + return c10::fromIntArrayRefKnownNonNegative(strides_default()); + } + + IntArrayRef strides_default() const { + if (C10_UNLIKELY(has_symbolic_sizes_strides_)) { + throw_cannot_call_with_symbolic("strides"); + } + return sizes_and_strides_.strides_arrayref(); + } + + c10::SymIntArrayRef sym_strides_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().strides_; + } else { + return c10::fromIntArrayRefKnownNonNegative(strides_default()); + } + } + + /** + * Whether or not a tensor is laid out in contiguous memory. + * + * Tensors with non-trivial strides are not contiguous. See + * compute_contiguous() for the exact definition of whether or not + * a tensor is contiguous or not. 
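+   *
+   * Schematically: a freshly allocated {2, 3} tensor with strides {3, 1} is
+   * contiguous, while a transposed view of it, with sizes {3, 2} and strides
+   * {1, 3}, is not, even though both cover exactly the same storage.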
+ */ + bool is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_contiguous_custom(memory_format); + } + return is_contiguous_default(memory_format); + } + + // These are factored into separate functions in case subclasses + // want to use them + bool is_contiguous_default(at::MemoryFormat memory_format) const { + if (has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return symbolic_shape_meta().is_channels_last_contiguous().guard_bool( + __FILE__, __LINE__); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return symbolic_shape_meta() + .is_channels_last_3d_contiguous() + .guard_bool(__FILE__, __LINE__); + } + return symbolic_shape_meta().is_contiguous().guard_bool( + __FILE__, __LINE__); + } + + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_contiguous_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_contiguous_; + } + return is_contiguous_; + } + + bool is_strides_like_default(at::MemoryFormat memory_format) const { + if (has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return symbolic_shape_meta().is_channels_last().guard_bool( + __FILE__, __LINE__); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return symbolic_shape_meta().is_channels_last_3d().guard_bool( + __FILE__, __LINE__); + } else { + return false; + } + } + + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_; + } else { + return false; + } + } + + bool is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool( + __FILE__, __LINE__); + } else { + return is_non_overlapping_and_dense_; + } + } + + // NB: these dim accessor functions don't have _default(), as you can use + // sizes_default/strides_default + /** + * Return the size of a tensor at some dimension, wrapping the dimension if + * necessary. + * + * NOTE: if you know wrapping is unnecessary, do sizes()[d] instead; it will + * be faster + */ + int64_t size(int64_t d) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return size_custom(d); + } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sizes_and_strides_.size_at_unchecked(d); + } + + c10::SymInt sym_size(int64_t d) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) { + return sym_size_custom(d); + } + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + const auto sizes = this->sym_sizes(); + return sizes[d]; + } + + /** + * Return the stride of a tensor at some dimension, wrapping the dimension + * if necessary. + * + * NOTE: if you know wrapping is unnecessary, do sizes()[d] instead; it will + * be faster + */ + int64_t stride(int64_t d) const { + d = maybe_wrap_dim(d, dim(), false); + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + // TODO: provide stride_custom, symmetrically with size_custom. 
+ // There is presently no user for it; only NestedTensor is using + // size_custom overrideability + return strides_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + // Intentionally don't call default, which also handles symbolic + return sizes_and_strides_.stride_at_unchecked(d); + } + + enum class SizesStridesPolicy : uint8_t { + // Default behavior, e.g., dense tensor. + // + // Can override: nothing + Default = 0, + // Customizable strides behavior, e.g., sparse tensor, + // mkldnn tensor. + // + // Can override: strides(), is_contiguous() + CustomStrides = 1, + // Customizable sizes behavior, e.g., nested tensor + // + // Can override: strides(), is_contiguous(), sizes(), dim(), numel() + CustomSizes = 2 + }; + + protected: + inline bool matches_policy(SizesStridesPolicy policy) const { + return sizes_strides_policy_ >= static_cast(policy); + } + + inline bool matches_custom(SizesStridesPolicy policy) const { + return custom_sizes_strides_ >= static_cast(policy); + } + + inline bool matches_python_custom(SizesStridesPolicy policy) const { + auto r = python_custom_sizes_strides_ >= static_cast(policy); + if (r) { + TORCH_INTERNAL_ASSERT(is_python_dispatch()) + } + return r; + } + + /** + * Customization points for the functions above. sizes_strides_policy_ + * must be set to enable these. + * + * NB: dim is overrideable separately from sizes because it is possible + * for a tensor to have rank, but not well defined sizes. + */ + // sizes_strides_policy_ >= CustomStrides + virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; + virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; + virtual bool is_non_overlapping_and_dense_custom() const; + // sizes_strides_policy_ >= CustomSizes + // Currently this method only exists to be overwritten by subclasses such as + // NestedTensorImpl. + virtual int64_t size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + + virtual c10::SymInt sym_size_custom(int64_t d) const { + // TODO: We could add support to Python dispatch here. + // TODO: We could call into aten::size.int instead of + // sym_sizes_custom()[d] and enable use of the dispatcher. + d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false); + return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds) + } + + virtual IntArrayRef sizes_custom() const; + virtual IntArrayRef strides_custom() const; + virtual int64_t numel_custom() const; + virtual int64_t storage_offset_custom() const; + virtual int64_t dim_custom() const; + virtual Device device_custom() const; + virtual Layout layout_custom() const; + + virtual c10::SymIntArrayRef sym_sizes_custom() const; + virtual c10::SymIntArrayRef sym_strides_custom() const; + virtual c10::SymInt sym_numel_custom() const; + virtual c10::SymInt sym_storage_offset_custom() const; + + public: + /** + * True if this tensor has storage. See storage() for details. + */ +#ifdef DEBUG + // Allow subclasses to check that their storage_ is never getting set in debug + // builds. + virtual +#else + TENSORIMPL_MAYBE_VIRTUAL +#endif + bool + has_storage() const + // NOTE: we devirtualize this because it arguably shouldn't be an + // error just to ask subclasses if they have storage. 
+ // This used to throw for most subclasses, but OpaqueTensorImpl + // wanted it to successfully return false, so we went ahead and made + // it a non-error. +#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY + { + return storage_; + } +#else + ; +#endif + + /** + * Return the underlying storage of a Tensor. Multiple tensors may share + * a single storage. A Storage is an impoverished, Tensor-like class + * which supports far less operations than Tensor. + * + * Avoid using this method if possible; try to use only Tensor APIs to perform + * operations. + */ + TENSORIMPL_MAYBE_VIRTUAL const Storage& storage() const { + if (C10_UNLIKELY(storage_access_should_throw_)) { + throw_storage_access_error(); + } + return storage_; + } + + /** + * Return the underlying storage, unsafely assuming this is a basic strided + * tensor. In cases where `storage` access would throw, this returns a + * default-constructed Storage. + */ + inline const Storage& unsafe_storage() const { + return storage_; + } + + bool unique_version() const { + return version_counter_.unique(); + } + + protected: + virtual Layout layout_impl() const { + TORCH_CHECK( + false, "layout_impl is only implemented for TensorImpl subclasses."); + } + + public: + // Whether a tensor is sparse COO or not. + bool is_sparse() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + return key_set_.has_all(c10::sparse_ks); + } + + // Whether a tensor is sparse CSR or not. + bool is_sparse_csr() const { + return layout() == kSparseCsr; + } + + // Whether a tensor is sparse CSR/CSC/BSR/BSC or not. + bool is_sparse_compressed() const { + return key_set_.has_all(c10::sparse_csr_ks); + } + + bool is_quantized() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized); + return key_set_.has_all(quantized_ks); + } + + bool is_meta() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_meta(); + } + return device_opt_.has_value() && device_opt_->type() == kMeta; + } + + bool is_cpu() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_cpu(); + } + // Note: we cannot rely on dispatch keys to determine the device type + // of a tensor, because "wrapper" tensors (like FunctionalTensorWrapper) + // don't include backend dispatch keys. + return device_opt_.has_value() && device_opt_->type() == kCPU; + } + + bool is_cuda() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_cuda(); + } + return device_opt_.has_value() && device_opt_->type() == kCUDA; + } + + bool is_xpu() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. 
+ if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_xpu(); + } + return device_opt_.has_value() && device_opt_->type() == kXPU; + } + + bool is_ipu() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_ipu(); + } + return device_opt_.has_value() && device_opt_->type() == kIPU; + } + + bool is_xla() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_xla(); + } + return device_opt_.has_value() && device_opt_->type() == kXLA; + } + + bool is_mtia() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_mtia(); + } + return device_opt_.has_value() && device_opt_->type() == kMTIA; + } + + bool is_hpu() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_hpu(); + } + return device_opt_.has_value() && device_opt_->type() == kHPU; + } + + bool is_lazy() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_lazy(); + } + return device_opt_.has_value() && device_opt_->type() == kLazy; + } + + bool is_hip() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_hip(); + } + return device_opt_.has_value() && device_opt_->type() == kHIP; + } + + bool is_ve() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_ve(); + } + return device_opt_.has_value() && device_opt_->type() == kVE; + } + + bool is_privateuseone() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_privateuseone(); + } + return device_opt_.has_value() && device_opt_->type() == kPrivateUse1; + } + + bool is_mkldnn() const { + return key_set_.has_all(c10::mkldnn_ks); + } + + bool is_vulkan() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_vulkan(); + } + return device_opt_.has_value() && device_opt_->type() == kVulkan; + } + + bool is_metal() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_metal(); + } + return device_opt_.has_value() && device_opt_->type() == kMetal; + } + + bool is_mps() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_mps(); + } + return device_opt_.has_value() && device_opt_->type() == kMPS; + } + + bool is_ort() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_ort(); + } + return device_opt_.has_value() && device_opt_->type() == kORT; + } + + bool is_nested() const { + return key_set_.has(DispatchKey::NestedTensor); + } + + // TODO: remove this once we don't automatically enabled Autograd dispatch + // keys + // in TensorImpl constructor. + // DON'T USE THIS API!! It's only created for testing purpose in + // file aten/src/ATen/core/boxing/impl/test_helpers.h + void remove_autograd_key() { + key_set_ = key_set_ - autograd_dispatch_keyset; + } + + // Inference tensor doesn't have autograd or ADInplaceOrView key. 
+ // Invariant: + // Inference tensor has version_counter_.enabled() == false + bool is_inference() { + bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks); + bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + no_ADInplaceOrView == no_Autograd, + "ADInplaceOrView and Autograd keys must be on/off at the same time."); + return no_ADInplaceOrView && no_Autograd; + } + + DeviceIndex get_device() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom().index(); + } + return device_default().index(); + } + + Device device() const { + if (C10_UNLIKELY(device_policy_)) { + return device_custom(); + } + return device_default(); + } + + protected: + c10::Device device_default() const { + TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device"); + // See NOTE [c10::optional operator usage in CUDA] + return *device_opt_; + } + + public: + Layout layout() const { + if (C10_UNLIKELY(layout_policy_)) { + return layout_custom(); + } + + // NB: This method is not virtual and avoid dispatches for perf. + // strided is also the most common layout type, so we check for + // strided case first. + // This keyset must also be kept in sync with the logic in + // is_sparse() / is_sparse_csr() / is_mkldnn() + constexpr auto sparse_and_sparsecsr_and_mkldnn_ks = + c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks; + if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) { + return kStrided; + } else if (is_sparse()) { + return kSparse; + } else if (is_sparse_compressed()) { + // Typically, the tensor dispatch keys define the tensor layout + // uniquely. This allows using non-virtual layout method for + // better performance. However, when tensor's layout depends, + // say, on tensor attributes, one must use this execution path + // where the corresponding tensor impl class overwrites virtual + // layout_impl() method. + // + // TODO: implement layout() as native function/method so that + // __torch_dispatch__ users will be able to redefine the + // layout() method. + return layout_impl(); + } else { + TORCH_INTERNAL_ASSERT( + is_mkldnn(), "There is an error in the layout calculation logic."); + return kMkldnn; + } + } + + /** + * True if a tensor was auto-wrapped from a C++ or Python number. + * For example, when you write 't + 2', 2 is auto-wrapped into a Tensor + * with `is_wrapped_number_` set to true. + * + * Wrapped numbers do not participate in the result type computation for + * mixed-type operations if there are any Tensors that are not wrapped + * numbers. This is useful, because we want 't + 2' to work with + * any type of tensor, not just LongTensor (which is what integers + * in Python represent). + * + * Otherwise, they behave like their non-wrapped equivalents. + * See [Result type computation] in TensorIterator.h. + * + * Why did we opt for wrapped numbers, as opposed to just having + * an extra function add(Tensor, Scalar)? This helps greatly reduce + * the amount of code we have to write for add, when actually + * a Tensor-Scalar addition is really just a Tensor-Tensor + * addition when the RHS is 0-dim (except for promotion behavior.) + */ + bool is_wrapped_number() const { + return is_wrapped_number_; + } + + /** + * Set whether or not a tensor was auto-wrapped from a C++ or Python + * number. You probably don't want to call this, unless you are + * writing binding code. 
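+   *
+   * Schematically: in `float_tensor + 2`, the `2` is materialized as a 0-dim
+   * tensor with is_wrapped_number() == true, and the result keeps
+   * float_tensor's dtype because the wrapped number does not drive type
+   * promotion when a non-wrapped tensor participates.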
+ */ + void set_wrapped_number(bool value) { + TORCH_INTERNAL_ASSERT(dim() == 0); + is_wrapped_number_ = value; + } + + /** + * Returns true if Tensor supports as_strided and as_strided_backward. + * This is used in autograd to perform inplace update on view Tensors. + * See Note [View + Inplace update for base tensor] and + * [View + Inplace update for view tensor] for details. + * Note this method only returns true for XLA backend, where it + * simulates strided Tensor to support most view ops, but it cannot + * fully support general `as_strided` case. + * It can be expanded as needed in the future, e.g sparse Tensor. + */ + inline bool support_as_strided() const { + if (is_nested()) { + return false; + } + if (key_set_.has(DispatchKey::Functionalize)) { + return false; + } + return device().supports_as_strided(); + } + + // ~~~~~ Autograd API ~~~~~ + // Some methods below are defined in TensorImpl.cpp because Tensor is an + // incomplete type. + + /** + * Set whether or not a tensor requires gradient. + */ + void set_requires_grad(bool requires_grad); + + /** + * True if a tensor requires gradient. Tensors which require gradient + * have history tracked for any operations performed on them, so that + * we can automatically differentiate back to them. A tensor that + * requires gradient and has no history is a "leaf" tensor, which we + * accumulate gradients into. + */ + bool requires_grad() const; + + /** + * Return a mutable reference to the gradient. This is conventionally + * used as `t.grad() = x` to set a gradient to a completely new tensor. + */ + at::Tensor& mutable_grad(); + + /** + * Return the accumulated gradient of a tensor. This gradient is written + * into when performing backwards, when this tensor is a leaf tensor. + */ + const at::Tensor& grad() const; + + /** + * Whether or not the imaginary part of the tensor should be negated + */ + inline bool is_conj() const { + constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate); + return key_set_.has_all(conjugate_ks); + } + + /** + * Set whether or not to take the conjugate of the tensor (flip the imaginary + * bit). + */ + void _set_conj(bool value) { + if (value) { + key_set_ = key_set_.add(DispatchKey::Conjugate); + TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype()))); + } else { + key_set_ = key_set_.remove(DispatchKey::Conjugate); + } + } + + /** + * XXX: do not use, private api! + * Update the backend component related keys to the backend component + * corresponding to this device. + */ + void _change_backend_component_keys(c10::Device device); + + /** + * Whether or not the tensor is a zerotensor + */ + inline bool _is_zerotensor() const { + constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor); + return key_set_.has_all(zerotensor_ks); + } + + /** + Set whether or not the tensor is a zero tensor + */ + void _set_zero(bool value) { + if (value) { + TORCH_INTERNAL_ASSERT( + false, + "Please call `torch._efficientzerotensor` if you want to create a tensor with no storage."); + } else { + key_set_ = key_set_.remove(DispatchKey::ZeroTensor); + } + } + + /** + * Whether or not the tensor should be negated + */ + inline bool is_neg() const { + constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative); + return key_set_.has_all(negative_ks); + } + + /** + * Set whether or not to take the conjugate of the tensor (flip the imaginary + * bit). 
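+   *
+   * (Despite mirroring the _set_conj wording above, this flips the *negative*
+   * bit, i.e. DispatchKey::Negative, as the implementation below shows.)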
+ */ + void _set_neg(bool value) { + if (value) { + key_set_ = key_set_.add(DispatchKey::Negative); + } else { + key_set_ = key_set_.remove(DispatchKey::Negative); + } + } + + /** + * Return the accumulated gradient of a tensor. This gradient is computed + * using forward mode AD. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be returned. Note that since levels are not fully + * supported yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. + */ + const at::Tensor& _fw_grad(uint64_t level, const at::TensorBase& self) const; + + /** + * Sets the forward gradient for this Tensor. + * The given Tensor might not be used directly and its content will be copied. + * + * This is an internal API that should never be used by end users. + * + * The API is as follows: + * - "new_grad" is a Tensor containing the new value of the gradient that + * should be set + * - "self" should represent the Tensor whose forward grad is accessed. It + * is required when dealing with view. + * - "level" allows to specify the level of forward AD nesting for which the + * gradient should be set. Note that since levels are not fully supported + * yet, this argument should be 0. See documentation for + * torch::autograd::enter_dual_level for more details about forward AD + * nesting. + * - "is_inplace_op" is a boolean flag that tells if this gradient was + * generated by an inplace operation or an out of place one. This allows + * better error checking. + */ + void _set_fw_grad( + const at::TensorBase& new_grad, + const at::TensorBase& self, + uint64_t level, + bool is_inplace_op); + + /** + * Return a typed data pointer to the actual data which this tensor refers to. + * This checks that the requested type (from the template parameter) matches + * the internal type of the tensor. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if + * the size is 0. + * + * WARNING: If a tensor is not contiguous, you MUST use strides when + * performing index calculations to determine the location of elements in + * the tensor. We recommend using 'TensorAccessor' to handle this computation + * for you; this class is available from 'Tensor'. + */ + template + const T* data_dtype_initialized() const { + return data_dtype_initialized_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * Return a mutable typed data pointer to the actual data which this + * tensor refers to. This checks that the requested type (from the + * template parameter) matches the internal type of the tensor. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if + * the size is 0. + * + * WARNING: If a tensor is not contiguous, you MUST use strides when + * performing index calculations to determine the location of elements in + * the tensor. We recommend using 'TensorAccessor' to handle this computation + * for you; this class is available from 'Tensor'. + */ + template + T* mutable_data_dtype_initialized() { + return data_dtype_initialized_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + // Shared implementation of data_dtype_initialized() and + // mutable_data_dtype_initialized(). 
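+  // Both the const and the mutable variant funnel through this helper: the
+  // caller supplies a callable that returns a (const or mutable) raw pointer
+  // from storage_, the helper checks that T matches data_type_, and then
+  // defers to data_ptr_impl_impl to apply storage_offset_.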
+ template + T* data_dtype_initialized_impl(const Func& get_data) const { + TORCH_CHECK( + data_type_.Match>(), + "Tensor type mismatch, caller expects elements to be ", + caffe2::TypeMeta::TypeName>(), + ", while tensor contains ", + data_type_.name(), + ". "); + return data_ptr_impl_impl(get_data); + } + + public: + /** + * More efficient helper for Tensor::data_ptr(). Like data(), but + * does not do a type check. Unlike the untemplated data(), does + * check has_storage() and storage_initialized(). + */ + template + inline const T* data_ptr_impl() const { + return data_ptr_impl_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * More efficient helper for Tensor::data_ptr(). Like data(), but + * does not do a type check. Unlike the untemplated data(), does + * check has_storage() and storage_initialized(). + */ + template + inline T* mutable_data_ptr_impl() { + return data_ptr_impl_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + // Shared implementation of mutable_data_ptr_impl() and the future + // mutable_data_ptr_impl(). + template + __ubsan_ignore_pointer_overflow__ T* data_ptr_impl_impl( + const Func& get_data) const { + if (C10_UNLIKELY(!has_storage())) { + throw_data_ptr_access_error(); + } + TORCH_CHECK( + storage_initialized(), + "The tensor has a non-zero number of elements, but its data is not allocated yet. " + "Caffe2 uses a lazy allocation, so you will need to call " + "mutable_data() or raw_mutable_data() to actually allocate memory."); + // Caller does the type check. + // Note: storage_offset_ can be non-null even for zero-elements tensors + // (for example if created as `torch.empty(5)[10:]`) that triggers + // applying non-zero offset to null pointer in UBSan + return get_data() + storage_offset_; + } + + public: + /** + * Return a const void* data pointer to the actual data which this + * tensor refers to. + * + * It is invalid to call data() on a dtype-uninitialized tensor, even if the + * size is 0. + * + * WARNING: The data pointed to by this tensor may not contiguous; do NOT + * assume that itemsize() * numel() is sufficient to compute the bytes that + * can be validly read from this tensor. + */ + inline const void* data() const { + return data_impl( + [this] { return static_cast(storage_.data()); }); + } + + /** + * Return a void* data pointer to the actual data which this tensor refers to. + * + * It is invalid to call mutable_data() on a dtype-uninitialized + * tensor, even if the size is 0. + * + * WARNING: The data pointed to by this tensor may not contiguous; do NOT + * assume that itemsize() * numel() is sufficient to compute the bytes that + * can be validly read from this tensor. + */ + inline void* mutable_data() { + return data_impl( + [this] { return static_cast(storage_.mutable_data()); }); + } + + private: + /// Shared implementation of data() and mutable_data(). + /// + /// get_data must return a byte-addressed pointer, e.g. char*, + /// std::byte const*, etc. 
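+  /// (The byte-addressed requirement exists because the helper below offsets
+  /// the pointer by data_type_.itemsize() * storage_offset_, i.e. a byte
+  /// count.)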
+ template + Void* data_impl(const Func& get_data) const { + if (C10_UNLIKELY(!has_storage())) { + throw_data_ptr_access_error(); + } + TORCH_CHECK( + dtype_initialized(), + "Cannot access data pointer of Tensor that doesn't have initialized dtype " + "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); + auto* data = get_data(); + static_assert( + sizeof(*data) == 1, "get_data must return a byte-addressed pointer."); + // Computing an offset into an empty tensor would be UB, since an empty + // tensor's storage will be nullptr, and adding a nonzero offset to nullptr + // is UB. So we skip the offset computation in this case. + if (is_empty()) { + return nullptr; + } + return data + data_type_.itemsize() * storage_offset_; + } + + public: + /** + * Returns the TypeMeta of a tensor, which describes what data type + * it is (e.g., int, float, ...) + */ + const caffe2::TypeMeta dtype() const { + return data_type_; + } + + /** + * Return the size of a single element of this tensor in bytes. + */ + size_t itemsize() const { + TORCH_CHECK( + dtype_initialized(), + "Cannot report itemsize of Tensor that doesn't have initialized dtype " + "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data() on x)"); + return data_type_.itemsize(); + } + + void set_backend_meta(intrusive_ptr backend_meta) { + get_extra_meta().backend_meta_ = std::move(backend_meta); + } + + c10::BackendMeta* get_backend_meta() { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->backend_meta_.get(); + } + + intrusive_ptr get_backend_meta_intrusive_ptr() const { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->backend_meta_; + } + + void release_storage_and_set_meta_custom_data_ptr_error_msg_( + c10::optional s) { + storage_ = {}; + set_storage_access_should_throw(); + get_extra_meta().custom_data_ptr_error_msg_ = s; + get_extra_meta().custom_storage_error_msg_ = std::move(s); + } + + protected: + /** + * Returns the human-readable name of the actual type of this object (e.g., + * TensorImpl, BatchedTensorImpl, etc.). Used for error messages. + */ + virtual const char* tensorimpl_type_name() const { + return "TensorImpl"; + } + + private: + [[noreturn]] void throw_storage_access_error() const; + [[noreturn]] void throw_data_ptr_access_error() const; + + ExtraMeta& get_extra_meta() { + if (!extra_meta_) { + extra_meta_ = std::make_unique(); + } + return *extra_meta_; + } + + c10::SymbolicShapeMeta& symbolic_shape_meta() { + TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_); + return *extra_meta_->symbolic_shape_meta_; + } + + const c10::SymbolicShapeMeta& symbolic_shape_meta() const { + TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_); + return *extra_meta_->symbolic_shape_meta_; + } + + public: + /** + * True if a tensor has no elements (e.g., numel() == 0). + */ + inline bool is_empty() const { + return numel() == 0; + } + + // if we are going to use sym sizes, we should be setting sym strides at the + // same time, otherwise it's very easy to misuse this API + void set_sizes_and_strides( + c10::SymIntArrayRef sizes, + c10::SymIntArrayRef strides, + c10::optional storage_offset = c10::nullopt); + // This is renamed to avoid breaking overload BC + void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes); + void generic_set_sizes_contiguous(c10::IntArrayRef sizes) { + set_sizes_contiguous(sizes); + } + + /** + * Change the size at some dimension. 
This DOES NOT update strides; + * thus, most changes to size will not preserve contiguity. You probably + * also want to call set_stride() when you call this. + * + * TODO: This should be jettisoned in favor of `set_sizes_and_strides`, + * which is harder to misuse. + */ + virtual void set_size(int64_t dim, int64_t new_size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_size ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !matches_policy(SizesStridesPolicy::CustomSizes), + "set_size() called on tensor with dynamic shapes or customized size behavior") + sizes_and_strides_.size_at(dim) = new_size; + refresh_numel(); + refresh_contiguous(); + } + + /** + * Change the stride at some dimension. + * + * TODO: This should be jettisoned in favor of `set_sizes_and_strides`, + * which is harder to misuse. + */ + virtual void set_stride(int64_t dim, int64_t new_stride) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_stride ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_stride() called on tensor with symbolic shape") + sizes_and_strides_.stride_at_unchecked(dim) = new_stride; + refresh_contiguous(); + } + + /** + * Set the offset into the storage of this tensor. + * + * WARNING: This does NOT check if the tensor is in bounds for the new + * location at the storage; the caller is responsible for checking this + * (and resizing if necessary.) + */ + virtual void set_storage_offset(int64_t storage_offset) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage_offset ", + err_msg_tensor_metadata_change_not_allowed); + // TODO: this should probably consult policy + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_storage_offset() called on tensor with symbolic shape") + storage_offset_ = storage_offset; + } + + /** + * Like set_sizes_and_strides but assumes contiguous strides. + * + * WARNING: This function does not check if the requested + * sizes/strides are in bounds for the storage that is allocated; + * this is the responsibility of the caller + */ + void set_sizes_contiguous(IntArrayRef new_size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_contiguous ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !matches_policy(SizesStridesPolicy::CustomStrides), + "tried to directly modify sizes for customized tensor"); + sizes_and_strides_.set_sizes(new_size); + + refresh_numel(); + empty_tensor_restride( + MemoryFormat::Contiguous); // calls refresh_contiguous() + } + + /** + * Set the sizes and strides of a tensor. 
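+   *
+   * Schematic use: an expand-like view that broadcasts a length-4 buffer to
+   * shape {3, 4} can be described by giving the broadcast dimension a stride
+   * of 0:
+   *
+   *   impl->set_sizes_and_strides({3, 4}, {0, 1});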
+ * + * WARNING: This function does not check if the requested + * sizes/strides are in bounds for the storage that is allocated; + * this is the responsibility of the caller + */ + void set_sizes_and_strides( + IntArrayRef new_size, + IntArrayRef new_stride, + c10::optional storage_offset = c10::nullopt) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_sizes_and_strides ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "set_sizes_and_strides() called on tensor with symbolic shape") + TORCH_CHECK( + new_size.size() == new_stride.size(), + "dimensionality of sizes (", + new_size.size(), + ") must match dimensionality of strides (", + new_stride.size(), + ")"); + const auto new_dim = new_size.size(); + bool overflowed = false; + sizes_and_strides_.set_sizes(new_size); + + if (new_dim > 0) { + for (size_t dim = new_dim - 1;; dim--) { + if (new_stride[dim] >= 0) { + sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim]; + } else { + // XXX: This behavior is surprising and may need to be removed to + // support negative strides. Some pytorch functions rely on it: + // for example, torch.cat (run TestTorch.test_cat_empty). + if (dim == new_dim - 1) { + sizes_and_strides_.stride_at_unchecked(dim) = 1; + } else { + // Keep stride monotonically increasing to match NumPy. + overflowed |= c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(dim + 1), + std::max( + sizes_and_strides_.size_at_unchecked(dim + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(dim))); + } + } + if (dim == 0) + break; + } + TORCH_CHECK(!overflowed, "Stride calculation overflowed"); + } + + refresh_numel(); + refresh_contiguous(); + + if (storage_offset.has_value()) { + storage_offset_ = *storage_offset; + } + } + + /** + * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. + */ + void set_allow_tensor_metadata_change(bool value) { + // TODO: at some point, we should kill this field completely. + allow_tensor_metadata_change_ = true; + } + + /** + * True if a tensor allows changes to its metadata (e.g. sizes / strides / + * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor + * ] for details. + */ + bool allow_tensor_metadata_change() const { + return allow_tensor_metadata_change_; + } + + /** + * Set the pointer to autograd metadata. + */ + void set_autograd_meta( + std::unique_ptr autograd_meta); + + /** + * Return the pointer to autograd metadata. May return nullptr if the + * tensor does not track gradients. + */ + c10::AutogradMetaInterface* autograd_meta() const; + + /** + * Set the pointer to named tensor metadata. + */ + void set_named_tensor_meta( + std::unique_ptr named_tensor_meta) { + TORCH_WARN_ONCE( + "Named tensors and all their associated APIs are an experimental feature ", + "and subject to change. 
Please do not use them for anything important ", + "until they are released as stable."); +#ifdef DEBUG + if (named_tensor_meta) { + TORCH_INTERNAL_ASSERT(named_tensor_meta->slow_dim() == dim()); + } +#endif + if (named_tensor_meta) { + get_extra_meta().named_tensor_meta_ = std::move(named_tensor_meta); + key_set_ = key_set_.add(DispatchKey::Named); + } else { + if (extra_meta_) { + extra_meta_->named_tensor_meta_ = nullptr; + } + key_set_ = key_set_.remove(DispatchKey::Named); + } + } + + void set_python_dispatch(bool k) { + if (k) { + key_set_ = key_set_.add(c10::python_ks); + } else { + key_set_ = key_set_ - c10::python_ks; + } + } + + bool is_python_dispatch() const { + return key_set_.has_all(c10::python_ks); + } + + /** + * Return the pointer to named tensor metadata. + */ + const c10::NamedTensorMetaInterface* named_tensor_meta() const { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->named_tensor_meta_.get(); + } + + c10::NamedTensorMetaInterface* named_tensor_meta() { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->named_tensor_meta_.get(); + } + + bool has_named_tensor_meta() const { + if (!extra_meta_) { + return false; + } + return extra_meta_->named_tensor_meta_ != nullptr; + } + + // NOTE [ TensorImpl Shallow-Copying ] + // + // TensorImpl shallow-copying is used when we want to have two Variables share + // the same tensor metadata (e.g. sizes / strides / storage pointer / + // storage_offset), but each with a different autograd history. Example call + // sites: + // + // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create + // `var_detached` that shares the same tensor metadata with `var`, but with a + // completely new autograd history. + // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor + // metadata from `tensor` into `var`, while keeping `var`'s original + // AutogradMeta. + // + // Functions that shallow-copy a TensorImpl (such as + // `shallow_copy_and_detach()` / `shallow_copy_from()` / + // `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes / + // strides / storage pointer / storage_offset) by value. However, the + // following fields are not copied: + // + // 1. the AutogradMeta pointer, because it is unique for each Variable. + // 2. the version counter, because the destination TensorImpl's version + // counter is either set to the passed-in `version_counter` (in + // `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept + // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for + // details. + // + // In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in + // `allow_tensor_metadata_change` determines whether the TensorImpl + // shallow-copy allows changes to its metadata (e.g. sizes / strides / storage + // / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for + // details. + // + // In `shallow_copy_from()`, we don't check the destination TensorImpl's + // `allow_tensor_metadata_change_`, because `shallow_copy_from()` is used for + // implementing functions such as `var.set_data(tensor)`, which changes + // `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to + // be ignored. + + /** + * One TensorImpl can be copied to another TensorImpl if they have the same + * DispatchKeySet. The only two special cases (for legacy reason) are: + * CPU is compatible with CUDA and SparseCPU is + * compatible with SparseCUDA. 
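+   *
+   * Schematically: a dense CPU impl and a dense CUDA impl count as compatible
+   * here (both are "dense"), while a dense impl and a sparse impl do not,
+   * even when they live on the same device.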
+ */ + inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + auto is_dense = [](DispatchKeySet ts) { + constexpr auto dense_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::MPSBit, + BackendComponent::HIPBit, + BackendComponent::XPUBit, + BackendComponent::HPUBit}); + constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense); + return ts.has_any(dense_k) && ts.has_any(dense_backends); + }; + auto is_sparse = [](DispatchKeySet ts) { + constexpr auto sparse_backends = DispatchKeySet( + {BackendComponent::CPUBit, + BackendComponent::CUDABit, + BackendComponent::HIPBit, + BackendComponent::XPUBit}); + constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); + return ts.has_any(sparse_k) && ts.has_any(sparse_backends); + }; + auto is_sparse_compressed = [](DispatchKeySet ts) { + constexpr auto sparse_compressed_k = + DispatchKeySet(DispatchKey::SparseCsr); + return ts.has_any(sparse_compressed_k); + }; + return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || + (is_sparse(key_set_) && is_sparse(from)) || + (is_sparse_compressed(key_set_) && is_sparse_compressed(from)); + ; + } + + private: + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + public: + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + virtual void shallow_copy_from(const c10::intrusive_ptr& impl) { + copy_tensor_metadata( + /*src_impl=*/impl.get(), + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + } + + // Inference tensor doesn't have version counter, + // set_version_counter is no-op for them. + void set_version_counter(const c10::VariableVersion& version_counter) { + TORCH_CHECK( + !(is_inference() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); + version_counter_ = version_counter; + } + + void set_version_counter(c10::VariableVersion&& version_counter) { + TORCH_CHECK( + !(is_inference() && version_counter.enabled()), + "Cannot set version_counter for inference tensor"); + version_counter_ = std::move(version_counter); + } + + const c10::VariableVersion& version_counter() const noexcept { + return version_counter_; + } + + void bump_version() { + version_counter_.bump(); + } + + impl::PyObjectSlot* pyobj_slot() { + return &pyobj_slot_; + } + + const impl::PyObjectSlot* pyobj_slot() const { + return &pyobj_slot_; + } + + private: + // See NOTE [c10::optional operator usage in CUDA] + // We probably don't want to expose this publicly until + // the note is addressed. 
+ c10::optional device_opt() const { + return device_opt_; + } + + public: + /** + * The device type of a Tensor, e.g., DeviceType::CPU or DeviceType::CUDA. + */ + DeviceType device_type() const { + // TODO: A useful internal assert would be to show that device_opt_ is null + // only if you are an undefined tensor + TORCH_CHECK( + device_opt_.has_value(), + "device_type cannot be run on undefined Tensor"); + // See NOTE [c10::optional operator usage in CUDA] + return (*device_opt_).type(); + } + + /** + * @brief Extends the outer-most dimension of this tensor by num elements, + * preserving the existing data. + * + * The underlying data may be reallocated in order to accommodate the new + * elements, in which case this tensors' capacity is grown at a factor of + * growthPct. This ensures that Extend runs on an amortized O(1) time + * complexity. + * + * This op is auto-asynchronous if the underlying device (CUDA) supports it. + */ + void Extend(int64_t num, float growthPct); + + /** + * @brief Reserve space for the underlying tensor. + * + * This must be called after Resize(), since we only specify the first + * dimension This does not copy over the old data to the newly allocated space + */ + void ReserveSpace(int64_t outer_dim); + + /** + * @brief Resizes a tensor. + * + * Resize takes in a vector of ints specifying the dimensions of the tensor. + * You can pass in an empty vector to specify that it is a scalar (i.e. + * containing one single item). + * + * The underlying storage may be deleted after calling Resize: if the new + * shape leads to a different number of items in the tensor, the old memory + * is deleted and new memory will be allocated next time you call + * mutable_data(). However, if the shape is different but the total number of + * items is the same, the underlying storage is kept. + * + * This method respects caffe2_keep_on_shrink. Consult the internal logic + * of this method to see exactly under what circumstances this flag matters. + */ + template + void Resize(Ts... dim_source) { + bool size_changed = SetDims(dim_source...); + if (size_changed) { + HandleResize(); + } + } + + template + void Resize(const std::vector& dim_source) { + Resize(ArrayRef(dim_source)); + } + + /** + * Resizes the tensor without touching underlying storage. + * This requires the total size of the tensor to remains constant. + */ + void Reshape(const std::vector& dims); + + /** + * Release whatever memory the tensor was holding but keep size and type + * information. Subsequent call to mutable_data will trigger new memory + * allocation. + */ + void FreeMemory(); + + /** + * @brief Shares the data with another tensor. + * + * To share data between two tensors, the sizes of the two tensors must be + * equal already. The reason we do not implicitly do a Resize to make the two + * tensors have the same shape is that we want to allow tensors of different + * shapes but the same number of items to still be able to share data. This + * allows one to e.g. have a n-dimensional Tensor and a flattened version + * sharing the same underlying storage. + * + * The source tensor should already have its data allocated. + */ + // To be deprecated + void ShareData(const TensorImpl& src); + + void ShareExternalPointer( + DataPtr&& data_ptr, + const caffe2::TypeMeta data_type, + size_t size_bytes); + + /** + * Returns a mutable raw pointer of the underlying storage. Since we will need + * to know the type of the data for allocation, a TypeMeta object is passed in + * to specify the necessary information. 
This is conceptually equivalent of + * calling mutable_data() where the TypeMeta parameter meta is derived from + * the type T. This function differs from mutable_data() in the sense that + * the type T can be specified during runtime via the TypeMeta object. + * + * If the existing data does not match the desired type, it will be deleted + * and a new storage will be created. + */ + inline void* raw_mutable_data(const caffe2::TypeMeta& meta) { + // For 0-size tensors it's fine to return any pointer (including nullptr) + if (data_type_ == meta && storage_initialized()) { + return static_cast( + static_cast(storage_.mutable_data()) + + storage_offset_ * meta.itemsize()); + } else { + bool had_special_dtor = data_type_.placementDelete() != nullptr; + storage_offset_ = 0; + data_type_ = meta; + // NB: device is not changed + + // We can reuse the existing buffer if the current data does not have + // a special destructor and the new data doesn't have a special + // constructor. + if (numel_ == 0 || + (meta.placementNew() == nullptr && !had_special_dtor && + (storage_.nbytes() >= (numel_ * data_type_.itemsize())))) { + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated + return storage_.mutable_data(); + } + Allocator* allocator = storage_.allocator(); + // Storage might have nullptr allocator in rare cases, for example, if + // an external memory segment has been wrapped with Tensor and we don't + // know how to reallocate it. However, in order to preserve legacy C2 + // behavior, we allow reallocating the memory using default allocator. + if (allocator == nullptr) { + allocator = GetAllocator(storage_.device_type()); + } + if (meta.placementNew()) { + // For types that need placement new, we will call it, as well as + // making sure that when the data is freed, it calls the right + // destruction procedure. + auto size = numel_; + auto dtor = data_type_.placementDelete(); + auto data_ptr = allocator->allocate(numel_ * data_type_.itemsize()); + storage_.set_data_ptr_noswap(PlacementDeleteContext::makeDataPtr( + std::move(data_ptr), dtor, size, storage_.device())); + data_type_.placementNew()(storage_.mutable_data(), numel_); + } else { + // For fundamental type, new and delete is easier. + storage_.set_data_ptr_noswap( + allocator->allocate(numel_ * data_type_.itemsize())); + } + storage_.set_nbytes(numel_ * data_type_.itemsize()); + TORCH_INTERNAL_ASSERT( + storage_offset_ == 0); // because we just reallocated + device_opt_ = storage_.device(); + return storage_.mutable_data(); + } + } + + /** + * Returns a typed pointer of the underlying storage. + * + * For fundamental types, we reuse possible existing storage if there + * is sufficient capacity. + */ + template + inline T* mutable_data() { + if (storage_initialized() && data_type_.Match()) { + return static_cast(storage_.mutable_data()) + storage_offset_; + } + // Check it here statically - otherwise TypeMeta would throw the runtime + // error in attempt to invoke TypeMeta::ctor() + static_assert( + std::is_default_constructible::value, + "Tensor can't hold non-default-constructable types"); + return static_cast(raw_mutable_data(caffe2::TypeMeta::Make())); + } + + /** + * True if a tensor is storage initialized. 
A tensor may become + * storage UNINITIALIZED after a Resize() or FreeMemory() + */ + bool storage_initialized() const { + TORCH_CHECK( + has_storage(), + "cannot call storage_initialized on tensor that does not have storage"); + return storage_.data() || numel_ == 0; + } + + /** + * True if a tensor is dtype initialized. A tensor allocated with + * Caffe2-style constructors is dtype uninitialized until the + * first time mutable_data() is called. + */ + bool dtype_initialized() const noexcept { + return data_type_ != caffe2::TypeMeta(); + } + + void set_storage_keep_dtype(at::Storage storage) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_storage ", + err_msg_tensor_metadata_change_not_allowed); + storage_ = std::move(storage); + device_opt_ = storage_.device(); + } + + void set_storage_and_dtype( + at::Storage storage, + const caffe2::TypeMeta data_type) { + set_storage_keep_dtype(std::move(storage)); + data_type_ = data_type; + } + + void empty_tensor_restride_symint(MemoryFormat memory_format); + + /** + * Set the strides of the tensor to match memory_format + * + * WARNING: This function doesn't rearrange data and assumes tensor is a + * memory contiguous + */ + void empty_tensor_restride(MemoryFormat memory_format) { + if (has_symbolic_sizes_strides_) { + empty_tensor_restride_symint(memory_format); + return; + } +#ifdef DEBUG + TORCH_INTERNAL_ASSERT( + compute_numel() == numel_, + "If you are seeing this error, that means empty_tensor_restride was " + "called before setting correct numel"); +#endif + switch (memory_format) { + case MemoryFormat::Contiguous: { + // dim_ is a virtual call, don't repeat it + const auto dim_ = dim(); + sizes_and_strides_.resize(dim_); + if (dim_ > 0) { + bool overflowed = false; + const auto last_idx = dim_ - 1; + sizes_and_strides_.stride_at_unchecked(last_idx) = 1; + for (auto i = last_idx - 1; i >= 0; --i) { + overflowed |= c10::mul_overflows( + sizes_and_strides_.stride_at_unchecked(i + 1), + std::max( + sizes_and_strides_.size_at_unchecked(i + 1), 1), + std::addressof(sizes_and_strides_.stride_at_unchecked(i))); + } + TORCH_CHECK(!overflowed, "Stride calculation overflowed"); + } + break; + } + case MemoryFormat::ChannelsLast: { + TORCH_CHECK( + dim() == 4, "required rank 4 tensor to use channels_last format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes())); + break; + } + case MemoryFormat::ChannelsLast3d: { + TORCH_CHECK( + dim() == 5, + "required rank 5 tensor to use channels_last_3d format"); + set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes())); + break; + } + case MemoryFormat::Preserve: + TORCH_CHECK(false, "unsupported memory format ", memory_format); + // Cleaning warning messages, no need to break as TORCH_CHECK(false) + // terminates flow. 
+ // break; + case MemoryFormat::NumOptions: + TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format); + } + // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually + // exclusive see #24090 + refresh_contiguous(); + } + + bool is_strides_like(at::MemoryFormat memory_format) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_strides_like_custom(memory_format); + } + return is_strides_like_default(memory_format); + } + + bool is_strides_like_channels_last() const { + return is_strides_like(at::MemoryFormat::ChannelsLast); + } + + bool is_strides_like_channels_last_3d() const { + return is_strides_like(at::MemoryFormat::ChannelsLast3d); + } + + bool is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return is_non_overlapping_and_dense_custom(); + } + return is_non_overlapping_and_dense_default(); + } + + bool has_symbolic_sizes_strides() const { + return has_symbolic_sizes_strides_; + } + + private: + void HandleResize(); + + // The Caffe2 Resize() method supports being called both as Resize({2,2}) as + // well as variadic with Resize(2, 2). These overloads provide all of the + // supported calling configurations, while being overloads (and not templates) + // so that implicit conversions still work. + // + // SetDims on ArrayRef is internally implemented as a template, so we can + // handle both ArrayRefs of different types (there are some uses of + // Resize in Caffe2 which pass in int, not int64_t.) + + template < + typename T, + typename = typename std::enable_if_t>> + bool SetDimsTemplate(ArrayRef src) { + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "SetDims() called on tensor with symbolic shape") + + auto old_numel = numel_; + sizes_and_strides_.resize(src.size()); + int64_t new_numel = 1; + for (const auto i : c10::irange(src.size())) { + new_numel *= src[i]; + sizes_and_strides_.size_at_unchecked(i) = src[i]; + } + numel_ = new_numel; + empty_tensor_restride(MemoryFormat::Contiguous); + return numel_ != old_numel; + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims(ArrayRef s) { + return SetDimsTemplate(s); + } + + bool SetDims() { + return SetDims(IntArrayRef{}); + } + + bool SetDims(const int64_t d0) { + return SetDims(IntArrayRef{d0}); + } + + bool SetDims(const int64_t d0, const int64_t d1) { + return SetDims(IntArrayRef{d0, d1}); + } + + bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2) { + return SetDims(IntArrayRef{d0, d1, d2}); + } + + bool SetDims( + const int64_t d0, + const int64_t d1, + const int64_t d2, + const int64_t d3) { + return SetDims(IntArrayRef{d0, d1, d2, d3}); + } + + /** + * Compute the number of elements based on the sizes of a tensor. + */ + // NB: This is ONLY called when sizes_and_strides_ is used directly; if + // we are virtualizing, then numel calls are virtualized as well, and this + // should never get called + int64_t compute_numel() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_); +#if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE) + // Use overflow checks if supported by the compiler + return safe_compute_numel(); +#else + return c10::multiply_integers(sizes_and_strides_.sizes_arrayref()); +#endif + } + + /** + * Compute the number of elements based on the sizes of a + * tensor. 
Catches integer overflow that may occur when a tensor + * using a sparse layout has multiple dimensions with large sizes. + */ + int64_t safe_compute_numel() const { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_); + uint64_t n = 1; + bool overflows = + c10::safe_multiplies_u64(sizes_and_strides_.sizes_arrayref(), &n); + constexpr auto numel_max = std::min( + static_cast(std::numeric_limits::max()), + static_cast(std::numeric_limits::max())); + + overflows |= (n > numel_max); + TORCH_CHECK(!overflows, "numel: integer multiplication overflow"); + return static_cast(n); + } + + /** + * Compute whether or not a tensor is contiguous based on the sizes and + * strides of a tensor. + */ + bool compute_contiguous(identity) const; + + bool compute_channels_last_contiguous_2d(identity) const; + + bool compute_channels_last_contiguous_3d(identity) const; + + bool compute_strides_like_channels_last_2d(identity) const; + + bool compute_strides_like_channels_last_3d(identity) const; + + bool compute_non_overlapping_and_dense(identity) const; + + protected: + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. + * + * For tensors with sparse layouts, use safe_refresh_numel() instead + * because it will catch integer overflow that may occur for tensors + * with sparse layouts and large dimensions. + * + * NB: We may uselessly recompute cached numel even in situations where + * it is completely never used (e.g., if CustomSizes for Python). However, + * we still must keep it up to date in case the Python overload + * returns None (in which case we will consult the field here). This also + * implies that sizes/strides will never be complete garbage; in the + * very worst case scenario, it will reflect a 1-dim zero size tensor. + */ + void refresh_numel() { + if (has_symbolic_sizes_strides_) { + symbolic_shape_meta().refresh_numel(); + } else { + numel_ = compute_numel(); + } + } + + /** + * Recompute the cached numel of a tensor. Call this if you modify + * sizes. Use only for tensors with sparse layouts because only + * sparse tensor are likely to have sizes that may lead to integer + * overflow when computing numel. + */ + void safe_refresh_numel() { + if (has_symbolic_sizes_strides_) { + // NB: sym numel is done with symbolic integers, which handle overflow + // checking + symbolic_shape_meta().refresh_numel(); + } else { + numel_ = safe_compute_numel(); + } + } + + private: + // NB: the TypeId argument prevents confusion where you pass a true/false + // literal and pick the wrong overload + + void _set_is_contiguous(identity, bool b) { + is_contiguous_ = b; + } + + void _set_is_channels_last_contiguous(identity, bool b) { + is_channels_last_contiguous_ = b; + } + + void _set_is_channels_last_3d_contiguous(identity, bool b) { + is_channels_last_3d_contiguous_ = b; + } + + void _set_is_channels_last(identity, bool b) { + is_channels_last_ = b; + } + + void _set_is_channels_last_3d(identity, bool b) { + is_channels_last_3d_ = b; + } + + void _set_is_non_overlapping_and_dense(identity, bool b) { + is_non_overlapping_and_dense_ = b; + } + + // These are little wrappers over the real compute_ functions that + // can make use of other contiguity fields to short circuit. 
+ + bool compute_is_non_overlapping_and_dense_dim4(identity type_id) { + return is_contiguous_ || is_channels_last_contiguous_ || + compute_non_overlapping_and_dense(type_id); + } + + bool compute_channels_last_contiguous_3d_dim5(identity type_id) { + return !is_channels_last_contiguous_ && + compute_channels_last_contiguous_3d(type_id); + } + + bool compute_channels_last_2d_dim5(identity type_id) { + return !is_channels_last_3d_contiguous_ && + compute_strides_like_channels_last_2d(type_id); + } + + bool compute_channels_last_3d_dim5(identity type_id) { + return !is_channels_last_ && compute_strides_like_channels_last_3d(type_id); + } + + bool compute_is_non_overlapping_and_dense_dim5(identity type_id) { + return is_contiguous_ || is_channels_last_contiguous_ || + is_channels_last_3d_contiguous_ || + compute_non_overlapping_and_dense(type_id); + } + + bool compute_is_non_overlapping_and_dense_anydim(identity type_id) { + return is_contiguous_ || compute_non_overlapping_and_dense(type_id); + } + + template + void _refresh_contiguous() { + auto type_id = identity(); + // Note: + // Dim 0, 1, 2 will never be a channels last 2d/3d format + // Dim 3+ is possibly be a channels last 2d format (Dim 4 only at this + // point) Dim 4+ is possibly be a channels last 3d format (Dim 5 only at + // this point) + switch (dim()) { + case 4: { + _set_is_contiguous(type_id, compute_contiguous(type_id)); + _set_is_channels_last_contiguous( + type_id, compute_channels_last_contiguous_2d(type_id)); + _set_is_channels_last_3d_contiguous(type_id, false); + _set_is_channels_last( + type_id, compute_strides_like_channels_last_2d(type_id)); + _set_is_channels_last_3d(type_id, false); + _set_is_non_overlapping_and_dense( + type_id, compute_is_non_overlapping_and_dense_dim4(type_id)); + break; + } + case 5: { + _set_is_contiguous(type_id, compute_contiguous(type_id)); + _set_is_channels_last_contiguous( + type_id, compute_channels_last_contiguous_2d(type_id)); + _set_is_channels_last_3d_contiguous( + type_id, compute_channels_last_contiguous_3d_dim5(type_id)); + _set_is_channels_last(type_id, compute_channels_last_2d_dim5(type_id)); + _set_is_channels_last_3d( + type_id, compute_channels_last_3d_dim5(type_id)); + _set_is_non_overlapping_and_dense( + type_id, compute_is_non_overlapping_and_dense_dim5(type_id)); + break; + } + default: + // is_channels_last_ and is_channels_last_3d_ are suggested + // memory_format. Being channels_last_contiguous doesn't necessarily + // mean the tensor is strided like channels_last: for strides on channel + // dimension could suggest desired memory_layout, but it doesn't affect + // memory storage + _set_is_contiguous(type_id, compute_contiguous(type_id)); + _set_is_channels_last_contiguous(type_id, false); + _set_is_channels_last_3d_contiguous(type_id, false); + _set_is_channels_last(type_id, false); + _set_is_channels_last_3d(type_id, false); + _set_is_non_overlapping_and_dense( + type_id, compute_is_non_overlapping_and_dense_anydim(type_id)); + break; + } + } + + protected: + /** + * Recompute the cached contiguity of a tensor. Call this if you modify sizes + * or strides. + */ + void refresh_contiguous() { + if (has_symbolic_sizes_strides_) { + symbolic_shape_meta().refresh_contiguous(); + } else { + _refresh_contiguous(); + } + } + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. 
+ * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change); + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change); + + private: + static void copy_tensor_metadata_except_version_counter( + const TensorImpl* src_impl, + TensorImpl* dest_impl, + bool allow_tensor_metadata_change); + + protected: + // Error message to show when the user tries to change tensor metadata on + // Tensor created from .data or .detach(). + // + // See NOTE [ Metadata Change for a Detached Tensor ] for details. + static const char* const err_msg_tensor_metadata_change_not_allowed; + + static void copy_generic_tensor_metadata( + const TensorImpl* src_impl, + TensorImpl* dest_impl); + + public: + void set_storage_access_should_throw() { + storage_access_should_throw_ = true; + } + + public: + void set_custom_sizes_strides(SizesStridesPolicy policy) { + custom_sizes_strides_ = static_cast(policy); + refresh_sizes_strides_policy(); + } + + void set_python_custom_sizes_strides(SizesStridesPolicy policy) { + python_custom_sizes_strides_ = static_cast(policy); + refresh_sizes_strides_policy(); + } + + void set_custom_device(bool custom_device) { + custom_device_ = custom_device; + refresh_device_policy(); + } + + void set_custom_layout(bool custom_layout) { + custom_layout_ = custom_layout; + refresh_layout_policy(); + } + + void set_python_custom_device(bool custom_device) { + python_custom_device_ = custom_device; + refresh_device_policy(); + } + + void set_python_custom_layout(bool custom_layout) { + python_custom_layout_ = custom_layout; + refresh_layout_policy(); + } + + protected: + void refresh_sizes_strides_policy() { + if (has_symbolic_sizes_strides_) { + sizes_strides_policy_ = + static_cast(SizesStridesPolicy::CustomSizes); + } else { + sizes_strides_policy_ = + std::max(custom_sizes_strides_, python_custom_sizes_strides_); + } + } + + void refresh_device_policy() { + device_policy_ = custom_device_ || python_custom_device_; + } + + void refresh_layout_policy() { + layout_policy_ = custom_layout_ || python_custom_layout_; + } + + protected: + Storage storage_; + + private: + // This pointer points to an AutogradMeta struct that stores autograd-specific + // fields (such as grad_ / grad_fn_ / grad_accumulator_). This pointer always + // has unique ownership (meaning only one TensorImpl can own it at a time). + // + // autograd_meta_ can be nullptr, as an optimization. When this occurs, it is + // equivalent to having an autograd_meta_ pointing to a default constructed + // AutogradMeta; intuitively, tensors which don't require grad will have this + // field set to null. + // + // This means accessors on autograd_meta_ have to be careful to test if they + // got a nullptr, and handle default behavior appropriately in that case. 
+ // + // Note that we don't enforce the invariant that if the AutogradMeta is + // default constructed, it is nullptr (to do this, we'd have to continuously + // check if an AutogradMeta became, by mutation, equal to the default + // constructed form. (This might be useful, but it seems rare enough that + // a requires_grad=True variable will turn back into the requires_grad=False + // version.) So there are three representable states: + // + // 1. autograd_meta_ == nullptr + // 2. autograd_meta_ is default constructed (semantically, same as (1)) + // 3. autograd_meta_ has nontrivial information content + // + std::unique_ptr autograd_meta_ = nullptr; + + protected: + std::unique_ptr extra_meta_ = nullptr; + + c10::VariableVersion version_counter_; + + impl::PyObjectSlot pyobj_slot_; + + c10::impl::SizesAndStrides sizes_and_strides_; + + int64_t storage_offset_ = 0; + // If sizes and strides are empty, the numel is 1!! However, most of the + // time, we will immediately set sizes to {0} and reset numel to 0. + // (Can't do that in the default initializers, because there's no way to + // spell "allocate a one-element array" for strides_). + int64_t numel_ = 1; + + // INVARIANT: When storage is non-null, this type meta must + // agree with the type meta in storage + caffe2::TypeMeta data_type_; + + // NOTE [c10::optional operator usage in CUDA] + // Our optional definition doesn't compile in .cu file if `value()` or + // `operator->` are used. Instead, we always use `operator*`. + // See https://github.com/pytorch/pytorch/issues/18496 for more info. + // If this is too burdensome to maintain, we can just + // manually implement this with an additional bool. + + // INVARIANT: When storage is non-null, this Device must + // agree with the type meta in storage. + // + // INVARIANT: device_opt_ is only nullopt for undefined tensors + // (which do not have a device.) + c10::optional device_opt_; + + // default member initializers for bit-fields only available with -std=c++2a + // or -std=gnu++2a + inline void init_bitfields() { + is_contiguous_ = true; + is_channels_last_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_ = false; + is_channels_last_3d_contiguous_ = false; + is_non_overlapping_and_dense_ = true; + is_wrapped_number_ = false; + allow_tensor_metadata_change_ = true; + reserved_ = false; + sizes_strides_policy_ = static_cast(SizesStridesPolicy::Default); + custom_sizes_strides_ = static_cast(SizesStridesPolicy::Default); + python_custom_sizes_strides_ = + static_cast(SizesStridesPolicy::Default); + python_custom_device_ = false; + python_custom_layout_ = false; + custom_device_ = false; + custom_layout_ = false; + device_policy_ = false; + layout_policy_ = false; + storage_access_should_throw_ = false; + has_symbolic_sizes_strides_ = false; + } + + // Tensor is contiguous + bool is_contiguous_ : 1; + + // Tensor is a subclass that does not permit storage access. + bool storage_access_should_throw_ : 1; + + // Tensor is stored in the channels last 2d memory format, when dimensions + // order is (N)CHW and C-strides < W-strides < H-strides (< N-strides) + // (If size of any dimension is equal to 1, this dimension strides value + // is not taken into account). + bool is_channels_last_ : 1; + + // Channels last contiguous tensor is channel last tensor which occupies + // contiguous memory block. 
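+  // [Illustrative note, not part of the upstream header] Concretely, a 4-d
+  // tensor with sizes (N, C, H, W) = (1, 3, 4, 4) is channels-last
+  // contiguous when its strides are (48, 1, 12, 3): C has the smallest
+  // stride, then W, then H, then N, and the 48 elements occupy one
+  // contiguous block. The same shape in default contiguous (NCHW) layout
+  // has strides (48, 16, 4, 1).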
+ bool is_channels_last_contiguous_ : 1; + + // Tensor is stored in the channels last 3d memory format, when dimensions + // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< + // N-strides) (If size of any dimension is equal to 1, this dimension strides + // value is not taken into account). + bool is_channels_last_3d_ : 1; + + // Channels last 3d contiguous tensor is channel last 3d tensor which occupies + // contiguous memory block. + bool is_channels_last_3d_contiguous_ : 1; + + // Dense tensor is the tensor that store values in a contiguous block of + // memory. Non-overlapping tensor is the tensor in which elements occupy + // individual non-repetitive memory. + bool is_non_overlapping_and_dense_ : 1; + + bool is_wrapped_number_ : 1; + + // NOTE [ Metadata Change for a Detached Tensor ] + // + // Normally, a user is allowed to change the tensor metadata + // (e.g. sizes / strides / storage / storage_offset) of a tensor. + // However, if the tensor is created by `t1_detached = t1.data` in Python + // or `t1_detached = t1.detach()` in Python/C++, those changes to the + // tensor metadata of `t1_detached` will not be propagated back to the + // original tensor `t1`. In order to make such changes explicitly illegal, + // we created the `allow_tensor_metadata_change_` flag, to prevent users + // from changing metadata of the detached tensor and expecting the original + // tensor to also be updated. + // + // NOTE: For a full list of tensor metadata fields, please see + // `copy_tensor_metadata()` in TensorImpl and its subclasses to find + // which fields are copied by value. + bool allow_tensor_metadata_change_ : 1; + + // we decide to keep reserved_ and it will + // live in Tensor after the split + // The logic is that if Extend() or ReserveSpace() were ever called, + // then subsequent Resize()s will not free up Storage. + bool reserved_ : 1; + + // Call _custom() virtual methods for + // strides()/is_contiguous()/sizes()/dim()/numel() + // This is a combination of sizes_strides_custom_dispatch_ + // and has_symbolic_sizes_strides_ + uint8_t sizes_strides_policy_ : 2; + + // Whether or not sizes_and_strides_ contains a symbolic value. + bool has_symbolic_sizes_strides_ : 1; + + // Call _custom() virtual method for + // strides()/is_contiguous()/sizes()/dim()/numel() + uint8_t custom_sizes_strides_ : 2; + + // Combo of custom_ and python_custom_ + bool device_policy_ : 1; + bool layout_policy_ : 1; + + // Call _custom() virtual method for device() + bool custom_device_ : 1; + + // Call _custom() virtual method for layout() + bool custom_layout_ : 1; + + // Call into Python for + // strides()/is_contiguous()/sizes()/dim()/numel() + uint8_t python_custom_sizes_strides_ : 2; + + // Call into Python for device() + bool python_custom_device_ : 1; + + // Call into Python for layout() + bool python_custom_layout_ : 1; + + // The set of DispatchKeys which describe this tensor. NB: this + // does NOT include Autograd (historically, it did, but + // not anymore!) 
+ // + // INVARIANT: extra_meta_->named_tensor_meta_ != nullptr <==> + // key_set_.has(DispatchKey::Named) + DispatchKeySet key_set_; + + private: + // C10_TensorImpl_Size_Check_Dummy_Class needs to be friends with + // TensorImpl so it can inspect the size of private fields + template < + size_t cplusplus, + size_t clang_ver_major, + size_t gcc_ver, + size_t gcc_ver_minor, + size_t nvcc, + size_t cuda_version, + size_t cuda_version_major, + size_t ptr_size> + friend class C10_TensorImpl_Size_Check_Dummy_Class; +}; + +// Note [TensorImpl size constraints] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Changed the size of TensorImpl? If the size went down, good for +// you! Adjust the documentation below and the expected size. +// Did it go up? Read on... +// +// Struct size matters. In some production systems at Facebook, we have +// 400M live tensors during a training run. Do the math: every 64-bit +// word you add to Tensor is an extra 3.2 gigabytes in RAM. +// +// If you are a Facebook employee, you can check if the run in question +// has tipped you over the point using the command here: +// https://fburl.com/q5enpv98 +// +// For reference, we OOMed at 160 bytes (20 words) per TensorImpl. +// This is not counting overhead from strides out-of-line allocation and +// StorageImpl space and this is from before we inlined sizes and strides +// directly into TensorImpl as SmallVectors. +// +// Our memory usage on 32-bit systems is suboptimal, but we're not checking +// for it at the moment (to help avoid rage inducing cycles when the +// 32-bit number is wrong). +// +// Current breakdown: +// +// vtable pointer +// strong refcount TODO: pack these into one word +// weak refcount +// storage pointer +// autograd metadata pointer +// named tensor metadata pointer +// version counter pointer +// PyObjectSlot +// SizesAndStrides size/pointer +// SizesAndStrides sizes (pre-allocated 0) +// SizesAndStrides sizes (pre-allocated 1) +// SizesAndStrides sizes (pre-allocated 2) +// SizesAndStrides sizes (pre-allocated 3) +// SizesAndStrides sizes (pre-allocated 4) +// SizesAndStrides strides (pre-allocated 0) +// SizesAndStrides strides (pre-allocated 1) +// SizesAndStrides strides (pre-allocated 2) +// SizesAndStrides strides (pre-allocated 3) +// SizesAndStrides strides (pre-allocated 4) +// storage offset +// numel +// data type, device, is_contiguous, storage_access_should_throw_, bitfields +// DispatchKeySet +// + +// Various preprocessor macros we use to check that the +// TensorImpl size hasn't changed unexpectedly. We undef +// these later. +#ifndef __NVCC__ +#define C10_NVCC 0 +#else +#define C10_NVCC __NVCC__ +#endif + +#ifndef __CUDA_VER_MAJOR__ +#define C10_CUDA_VERSION_MAJOR 0 +#else +#define C10_CUDA_VERSION_MAJOR __CUDA_VER_MAJOR__ +#endif + +#ifndef CUDA_VERSION +#define C10_CUDA_VERSION 0 +#else +#define C10_CUDA_VERSION CUDA_VERSION +#endif + +#ifndef __clang_major__ +#define C10_CLANG_MAJOR_VERSION 0 +#else +#define C10_CLANG_MAJOR_VERSION __clang_major__ +#endif + +#ifndef __GNUC__ +#define C10_GCC_VERSION 0 +#else +#define C10_GCC_VERSION __GNUC__ +#endif + +#ifndef __GNUC_MINOR__ +#define C10_GCC_VERSION_MINOR 0 +#else +#define C10_GCC_VERSION_MINOR __GNUC_MINOR__ +#endif + +// We use a templatized class to both contain the logic of checking the sizes +// as well as to provide compile-time information that might be useful in +// figuring out why sizes may have changed. 
+// All the compile time information is given by the template fields that are +// always printed by the compiler when the static_assert fails. +template < + size_t cplusplus = __cplusplus, + size_t clang_ver_major = C10_CLANG_MAJOR_VERSION, + size_t gcc_ver = C10_GCC_VERSION, + size_t gcc_ver_minor = C10_GCC_VERSION_MINOR, + size_t nvcc = C10_NVCC, + size_t cuda_version = C10_CUDA_VERSION, + size_t cuda_version_major = C10_CUDA_VERSION_MAJOR, + size_t ptr_size = sizeof(void*)> +class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { + // Names of (non-bitfield) fields in TensorImpl; used to provide + // compile-time info about fields whose size changes unexpectedly. + enum class FieldNameEnum { + storage_, + autograd_meta_, + extra_meta_, + version_counter_, + pyobj_slot_, + sizes_and_strides_, + storage_offset_, + numel_, + data_type_, + device_opt_, + key_set_, + TOTAL_SIZE + }; + + // Provides compile-time equality check that reveals what numbers + // were used and on which quantity + template + constexpr static bool are_equal() { + static_assert( + Actual == Expected, + "Actual and Expected sizes of a field did not match!"); + return true; + } + + // Provides compile-time <= check that reveals what numbers + // were used and on which quantity + template + constexpr static bool is_le() { + static_assert( + Actual <= Expected, + "Actual and Expected sizes of a field did not match!"); + return true; + } + + public: + // Compile-time check that TensorImpl field sizes are as expected + // + // Observed total sizes and associated versions + // If you find a flag that predicts when unique_ptr has 16 bytes + // on 64-bit systems or when sizes_and_strides_ is 84 vs 88 bytes + // on 32-bit systems you get a cookie! + // Length | LLVM | GCC | C++ | CUDA + // 192 | ? | 11.2 | 201703 | 11040 + // 208 | ? | 11.2 | 201703 | 11040 + // 208 | ? | 11.2 | 201402 | 11040 + // 192 | ? | 11.2 | 201402 | 11040 + // 160 | 12 | 4.2 | 201703 | 0 + // + // To keep things clean, we split on systems here. + +#if UINTPTR_MAX == 0xFFFFFFFF + // This is a 32-bit system + static constexpr bool check_sizes() { + constexpr size_t tsize = 20 * sizeof(int64_t); + + // clang-format off + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + // clang-format on + + return true; + } +#else + // This is a 64-bit system + static constexpr bool check_sizes() { + constexpr size_t tsize = 26 * sizeof(int64_t); + + // clang-format off + are_equal(); + // On some systems involving NVCC the size of unique_ptr is 16 bytes. We haven't + // figured out how to detect those via macro preprocessors yet, so we use <= + // comparisons for the relevant fields. + is_le(); + is_le(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + are_equal(); + is_le(); + // clang-format on + + return true; + } +#endif +}; + +// We use a class to encapsulate size-checking logic with +// templates to capture sizes and flags. We call this within +// a static assert to prove there is no run-time behaviour. +// Since the methods we call return either true or fail their +// own static_asserts, we should never see the error messages +// below. We have to provide it though for c++ <17. 
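+// [Illustrative sketch, not part of the upstream header] The same
+// "reveal the inputs in the error" idiom in miniature: putting the values
+// under test into template parameters makes the compiler print them when
+// the static_assert fires. The names fits/Probe are hypothetical:
+//
+//   template <size_t Actual, size_t Budget>
+//   constexpr bool fits() {
+//     static_assert(Actual <= Budget, "struct grew past its budget");
+//     return true;
+//   }
+//   struct Probe { void* p; int64_t a; };
+//   static_assert(fits<sizeof(Probe), 2 * sizeof(int64_t)>(), "");
+//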
+static_assert( + C10_TensorImpl_Size_Check_Dummy_Class<>::check_sizes(), + "You should not see this message."); + +// Clean up after ourselves +#undef C10_NVCC +#undef C10_CUDA_VERSION_MAJOR +#undef C10_CUDA_VERSION +#undef C10_CLANG_MAJOR_VERSION +#undef C10_GCC_VERSION +#undef C10_GCC_VERSION_MINOR + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/TensorOptions.h b/MLPY/Lib/site-packages/torch/include/c10/core/TensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..476af8d63ede96b15f013ae177230d5e492a21ee --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/TensorOptions.h @@ -0,0 +1,787 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10 { + +DispatchKey computeDispatchKey( + c10::optional dtype, + c10::optional layout, + c10::optional device); + +inline ScalarType dtype_or_default(c10::optional dtype) { + return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); }); +} + +inline caffe2::TypeMeta dtype_or_default( + c10::optional dtype) { + return value_or_else(dtype, [] { return get_default_dtype(); }); +} + +inline Layout layout_or_default(c10::optional layout) { + return layout.value_or(kStrided); +} + +inline Device device_or_default(c10::optional device) { + return value_or_else(device, [] { return Device(kCPU); }); +} + +inline bool pinned_memory_or_default(c10::optional pinned_memory) { + return pinned_memory.value_or(false); +} + +/// A class to encapsulate construction axes of an Tensor. TensorOptions was +/// designed to support the Python style API for specifying construction options +/// on factory functions, e.g., +/// +/// torch.zeros(2, 3, dtype=torch.int32) +/// +/// Because C++ doesn't natively support keyword arguments, there must be +/// another way of specifying keyword-like arguments. TensorOptions is a +/// builder class which can be used to construct this "dictionary" of keyword +/// arguments: functions which support TensorOptions conventionally take this +/// argument optionally as their last argument. +/// +/// WARNING: In PyTorch, there are `torch::` variants of factory functions, +/// e.g., torch::zeros for at::zeros. These return Variables (while the +/// stock ATen functions return plain Tensors). If you mix these functions +/// up, you WILL BE SAD. +/// +/// Rather than use the constructor of this class directly, you should prefer to +/// use the constructor functions, and then chain setter methods on top of them. +/// +/// at::device(at::kCUDA).dtype(kInt) +/// at::dtype(at::kInt) +/// +/// Additionally, anywhere a TensorOptions is expected, you can directly +/// pass at::kCUDA / at::kInt, and it will implicitly convert to a +/// TensorOptions. +/// +/// Here are some recommended ways to create a 2x2 tensor of zeros +/// with certain properties. 
These all *implicitly* make use of +/// TensorOptions, even if they don't mention the class explicitly: +/// +/// at::zeros({2,2}, at::kCUDA); +/// at::zeros({2,2}, at::kLong); +/// at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong())); +/// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1 +/// at::zeros({2,2}, at::requires_grad()); +/// + +/// NOTE [ TensorOptions Constructors ] +/// +/// TensorOptions is like a dictionary with entries from the set: +/// {requires_grad, device, dtype, layout}, where each entry may be +/// unspecified (i.e., is optional). It is used to specify the properties of +/// tensors in many places both in C++ internal and API, e.g., tensor factory +/// methods like `at::empty({10}, options)`, tensor conversions like +/// `tensor.to(...)`, etc. +/// +/// To provide a simple API that is consistent with Python, where one can do +/// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a +/// `torch.layout`, we want TensorOptions to be implicitly convertible from +/// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have +/// three implicit constructors from each of these three types. +/// +/// This is sufficient for `ScalarType` and `Layout` as they are simple Enum +/// classes. However, `Device` is an ordinary class with implicit constructors +/// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be +/// consistent with Python API, where strings are treated as equivalent with a +/// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a +/// `torch.device("cuda:1")` is accepted). To support the syntax +/// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure +/// that `TensorOptions` is implicitly constructible with any arguments that a +/// `Device` can constructed from. So we have, +/// +/// /* implicit */ TensorOptions(T&& device) : TensorOptions() { +/// this->set_device(device); +/// } +/// +/// template ::value>> +/// /* implicit */ TensorOptions(Args&&... args) +/// : TensorOptions(Device(std::forward(args)...)) {} +/// +/// +/// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`. +/// Compiler will complain about ambiguity between the copy constructor and the +/// `Device` constructor because `{kCUDA, 1}` can be converted to both a +/// `TensorOption` and a `Device`. +/// +/// To get around this, we templatize the `Device` constructor. Since overload +/// resolution is done before template resolution, our problem is solved. + +DispatchKey computeDispatchKey( + optional dtype, + optional layout, + optional device); + +struct C10_API TensorOptions { + TensorOptions() + : requires_grad_(false), + pinned_memory_(false), + has_device_(false), + has_dtype_(false), + has_layout_(false), + has_requires_grad_(false), + has_pinned_memory_(false), + has_memory_format_(false) {} + + /// Constructs a `TensorOptions` object with the given layout. + /* implicit */ TensorOptions(Layout layout) : TensorOptions() { + this->set_layout(layout); + } + + /// Constructs a `TensorOptions` object with the given device. + /// See NOTE [ TensorOptions Constructors ] on why this is templatized. + template < + typename T, + typename = std::enable_if_t, Device>>> + /* implicit */ TensorOptions(T&& device) : TensorOptions() { + this->set_device(std::forward(device)); + } + + /// Constructs a `TensorOptions` object from arguments allowed in `Device` + /// constructors. + /// + /// See NOTE [ TensorOptions Constructors ]. 
+ /// + /// NB: Ideally we only allow implicit constructors here. But there is no easy + /// way to detect them. So we have this one that allows explicit + /// constructors too. + template < + typename... Args, + typename = std::enable_if_t>> + /* implicit */ TensorOptions(Args&&... args) + : TensorOptions(Device(std::forward(args)...)) {} + + /// Constructs a `TensorOptions` object with the given dtype. + /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() { + this->set_dtype(dtype); + } + + /// legacy constructor to support ScalarType + /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() { + this->set_dtype(dtype); + } + + /// Constructs a `TensorOptions` object with the given memory format. + /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() { + set_memory_format(memory_format); + } + + /// Return a copy of `TensorOptions` with `device` set to the given one, or + /// cleared if `device` is `nullopt`. + C10_NODISCARD TensorOptions + device(c10::optional device) const noexcept { + TensorOptions r = *this; + r.set_device(device); + return r; + } + + /// Return a copy of `TensorOptions` with `device` set to the given one. + /// (This overload ensures that variadic template c10::optional constructor + /// for Device work correctly.) + template + C10_NODISCARD TensorOptions device(Args&&... args) const noexcept { + return device( + c10::optional(std::in_place, std::forward(args)...)); + } + + /// Return a copy of `TensorOptions`, but with device set to CUDA, and the + /// device index set to the given one. + /// + /// TODO: This function encourages bad behavior (assuming CUDA is + /// the only device that matters). Get rid of it / rename it. + C10_NODISCARD TensorOptions + device_index(c10::DeviceIndex device_index) const noexcept { + return device(Device::Type::CUDA, device_index); + } + + /// Return a copy of `TensorOptions` with `dtype` set to the given one. + C10_NODISCARD TensorOptions + dtype(c10::optional dtype) const noexcept { + TensorOptions r = *this; + r.set_dtype(dtype); + return r; + } + + // legacy function to support ScalarType + C10_NODISCARD TensorOptions + dtype(c10::optional dtype) const noexcept { + TensorOptions r = *this; + r.set_dtype(dtype); + return r; + } + + // Since dtype is taken... + template + TensorOptions& dtype() { + dtype_ = caffe2::TypeMeta::Make(); + has_dtype_ = true; + return *this; + } + + /// Sets the layout of the `TensorOptions`. + C10_NODISCARD TensorOptions + layout(c10::optional layout) const noexcept { + TensorOptions r = *this; + r.set_layout(layout); + return r; + } + + /// Sets the `requires_grad` property of the `TensorOptions`. + C10_NODISCARD TensorOptions + requires_grad(c10::optional requires_grad) const noexcept { + TensorOptions r = *this; + r.set_requires_grad(requires_grad); + return r; + } + + /// Sets the `pinned_memory` property on the `TensorOptions`. + C10_NODISCARD TensorOptions + pinned_memory(c10::optional pinned_memory) const noexcept { + TensorOptions r = *this; + r.set_pinned_memory(pinned_memory); + return r; + } + + /// Sets the `memory_format` property on `TensorOptions`. + C10_NODISCARD TensorOptions + memory_format(c10::optional memory_format) const noexcept { + TensorOptions r = *this; + r.set_memory_format(memory_format); + return r; + } + + /// Returns the device of the `TensorOptions`. + Device device() const noexcept { + return device_or_default(device_opt()); + } + + /// Returns whether the device is specified. 
+ bool has_device() const noexcept { + return has_device_; + } + + /// Returns the device of the `TensorOptions`, or `c10::nullopt` if + /// device is not specified. + c10::optional device_opt() const noexcept { + return has_device_ ? c10::make_optional(device_) : c10::nullopt; + } + + /// Returns the device index of the `TensorOptions`. + c10::DeviceIndex device_index() const noexcept { + return device().index(); + } + + /// Returns the dtype of the `TensorOptions`. + caffe2::TypeMeta dtype() const noexcept { + return dtype_or_default(dtype_opt()); + } + + /// Returns whether the dtype is specified. + bool has_dtype() const noexcept { + return has_dtype_; + } + + /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if + /// device is not specified. + c10::optional dtype_opt() const noexcept { + return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt; + } + + /// Returns the layout of the `TensorOptions`. + Layout layout() const noexcept { + return layout_or_default(layout_opt()); + } + + /// Returns whether the layout is specified. + bool has_layout() const noexcept { + return has_layout_; + } + + /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if + /// layout is not specified. + c10::optional layout_opt() const noexcept { + return has_layout_ ? c10::make_optional(layout_) : c10::nullopt; + } + + /// Returns the `requires_grad` property of the `TensorOptions`. + bool requires_grad() const noexcept { + return has_requires_grad_ ? requires_grad_ : false; + } + + /// Returns whether the `requires_grad` is specified. + bool has_requires_grad() const noexcept { + return has_requires_grad_; + } + + /// Returns the `requires_grad` property of the `TensorOptions`, or + /// `c10::nullopt` if `requires_grad` is not specified. + c10::optional requires_grad_opt() const noexcept { + return has_requires_grad_ ? c10::make_optional(requires_grad_) + : c10::nullopt; + } + + /// Returns the `pinned_memory` property of the `TensorOptions`. + bool pinned_memory() const noexcept { + return pinned_memory_or_default(pinned_memory_opt()); + } + + /// Returns whether the `pinned_memory` is specified. + bool has_pinned_memory() const noexcept { + return has_pinned_memory_; + } + + /// Returns if the layout is sparse + bool is_sparse() const { + return layout_ == c10::Layout::Sparse; + } + + /// Returns if the layout is sparse CSR, deprecated, use + /// is_sparse_compressed() instead + bool is_sparse_csr() const { + return layout_ == c10::Layout::SparseCsr; + } + + bool is_sparse_compressed() const { + return layout_ == c10::Layout::SparseCsr || + layout_ == c10::Layout::SparseCsc || + layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc; + } + + // For compatibility with legacy tensor.type() comparisons + bool type_equal(const TensorOptions& other) const { + return computeDispatchKey() == other.computeDispatchKey() && + typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype()); + } + + /// Returns the `pinned_memory` property of the `TensorOptions`, or + /// `c10::nullopt` if `pinned_memory` is not specified. + c10::optional pinned_memory_opt() const noexcept { + return has_pinned_memory_ ? c10::make_optional(pinned_memory_) + : c10::nullopt; + } + + /// Returns whether the `memory_layout` is specified + bool has_memory_format() const noexcept { + return has_memory_format_; + } + + // NB: memory_format() getter is PURPOSELY not defined, as the default + // behavior of memory_format varies from function to function. 
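+  // [Illustrative sketch, not part of the upstream header] How the builder
+  // methods and the getters above fit together at a call site:
+  //
+  //   TensorOptions opts = TensorOptions()
+  //       .dtype(caffe2::TypeMeta::Make<float>())
+  //       .device(kCPU)
+  //       .requires_grad(true);
+  //   opts.has_dtype();          // true
+  //   opts.has_memory_format();  // false - never set
+  //   opts.layout();             // kStrided, the layout_or_default() fallback
+  //   opts.memory_format_opt();  // c10::nullopt
+  //
+  // Each getter with a *_opt() variant distinguishes "explicitly set" from
+  // "defaulted", which is what merge_in() below relies on.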
+ + /// Returns the `memory_layout` property of `TensorOptions, or + /// `c10::nullopt` if `memory_format` is not specified. + c10::optional memory_format_opt() const noexcept { + return has_memory_format_ ? c10::make_optional(memory_format_) + : c10::nullopt; + } + + // Resolves the ATen backend specified by the current construction axes. + // TODO: Deprecate this + Backend backend() const { + return at::dispatchKeyToBackend(computeDispatchKey()); + } + + /// Return the right-biased merge of two TensorOptions. This has the + /// effect of overwriting settings from self with specified options + /// of options. + /// + /// NB: This merging operation does NOT respect device merges. + /// For example, if you device({kCUDA, 1}).merge_in(kCUDA) + /// you will get kCUDA in the end! Functions like Tensor.new_empty + /// ensure the right device is selected anyway by way of a + /// device guard. + /// + TensorOptions merge_in(TensorOptions options) const noexcept { + TensorOptions merged = *this; + if (options.has_device()) + merged.set_device(options.device_opt()); + if (options.has_dtype()) + merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) + merged.set_layout(options.layout_opt()); + // NB: requires grad is right biased; not a logical AND/OR! + if (options.has_requires_grad()) + merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) + merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) + merged.set_memory_format(options.memory_format_opt()); + return merged; + } + + // TODO remove after TensorOptions rationalization + TensorOptions merge_memory_format( + c10::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(*optional_memory_format); + } + return merged; + } + + // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for + // which dispatchKeyToBackend is injective, if it is defined at all (for + // the most part, this just means that this function never returns an + // Autograd key) + DispatchKey computeDispatchKey() const { + return c10::computeDispatchKey( + optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt()); + } + + private: + // These methods are currently private because I'm not sure if it's wise + // to actually publish them. They are methods because I need them in + // the constructor and the functional API implementation. + // + // If you really, really need it, you can make these public, but check if you + // couldn't just do what you need with the functional API. Similarly, these + // methods are not chainable, because if you wanted chaining, you probably + // want to use the functional API instead. (It's probably OK to make + // these chainable, because these functions are all explicitly annotated + // with a ref-qualifier, the trailing &, that makes them illegal to call + // on temporaries.) + + /// Mutably set the device of `TensorOptions`. + void set_device(c10::optional device) & noexcept { + if (device) { + device_ = *device; + has_device_ = true; + } else { + has_device_ = false; + } + } + + /// Mutably set the dtype of `TensorOptions`. 
+ void set_dtype(c10::optional dtype) & noexcept { + if (dtype) { + dtype_ = *dtype; + has_dtype_ = true; + } else { + has_dtype_ = false; + } + } + + // legacy function to support ScalarType + void set_dtype(c10::optional dtype) & noexcept { + if (dtype) { + dtype_ = scalarTypeToTypeMeta(*dtype); + has_dtype_ = true; + } else { + has_dtype_ = false; + } + } + + /// Mutably set the layout of `TensorOptions`. + void set_layout(c10::optional layout) & noexcept { + if (layout) { + layout_ = *layout; + has_layout_ = true; + } else { + has_layout_ = false; + } + } + + /// Mutably set the `requires_grad` property of `TensorOptions`. + void set_requires_grad(c10::optional requires_grad) & noexcept { + if (requires_grad) { + requires_grad_ = *requires_grad; + has_requires_grad_ = true; + } else { + has_requires_grad_ = false; + } + } + + /// Mutably set the `pinned_memory` property of `TensorOptions`. + void set_pinned_memory(c10::optional pinned_memory) & noexcept { + if (pinned_memory) { + pinned_memory_ = *pinned_memory; + has_pinned_memory_ = true; + } else { + has_pinned_memory_ = false; + } + } + + /// Mutably set the `memory_Format` property of `TensorOptions`. + void set_memory_format(c10::optional memory_format) & noexcept { + if (memory_format) { + memory_format_ = *memory_format; + has_memory_format_ = true; + } else { + has_memory_format_ = false; + } + } + + // WARNING: If you edit TensorOptions to add more options, you + // may need to adjust the implementation of Tensor::options. + // The criteria for whether or not Tensor::options must be adjusted + // is whether or not the new option you added should preserved + // by functions such as empty_like(); if it should be preserved, + // you must adjust options(). + // + // TODO: MemoryFormat is not implemented in this way + + // NB: We didn't use c10::optional here, because then we can't pack + // the has_***_ boolean fields. + + Device device_ = at::kCPU; // 16-bit + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 16-bit + Layout layout_ = at::kStrided; // 8-bit + MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit + + // Bitmask required here to get this to fit inside 32 bits (or even 64 bits, + // for that matter) + + bool requires_grad_ : 1; + bool pinned_memory_ : 1; + + bool has_device_ : 1; + bool has_dtype_ : 1; + bool has_layout_ : 1; + bool has_requires_grad_ : 1; + bool has_pinned_memory_ : 1; + bool has_memory_format_ : 1; +}; + +// We should aspire to fit in one machine-size word; but a size greater than two +// words is too much. (We are doing terribly on 32-bit archs, where we require +// three machine size words to store tensor options. Eek!) +static_assert( + sizeof(TensorOptions) <= sizeof(int64_t) * 2, + "TensorOptions must fit in 128-bits"); + +/// Convenience function that returns a `TensorOptions` object with the `dtype` +/// set to the given one. +inline TensorOptions dtype(caffe2::TypeMeta dtype) { + return TensorOptions().dtype(dtype); +} + +// legacy function to support ScalarType +inline TensorOptions dtype(ScalarType dtype) { + return TensorOptions().dtype(scalarTypeToTypeMeta(dtype)); +} + +/// Convenience function that returns a `TensorOptions` object with the `layout` +/// set to the given one. +inline TensorOptions layout(Layout layout) { + return TensorOptions().layout(layout); +} + +/// Convenience function that returns a `TensorOptions` object with the `device` +/// set to the given one. 
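+// [Illustrative sketch, not part of the upstream header] The free-function
+// builders above compose with the chainable member setters; the function
+// name example_cpu_long_options is hypothetical:
+inline TensorOptions example_cpu_long_options() {
+  // int64 ("long"), strided, CPU, no gradient tracking.
+  return dtype(caffe2::TypeMeta::Make<int64_t>())
+      .layout(kStrided)
+      .device(kCPU)
+      .requires_grad(false);
+}
+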
+inline TensorOptions device(Device device) { + return TensorOptions().device(device); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `device` set to CUDA and the `device_index` set to the given one. +inline TensorOptions device_index(c10::DeviceIndex device_index) { + return TensorOptions().device_index(device_index); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `requires_grad` set to the given one. +inline TensorOptions requires_grad(bool requires_grad = true) { + return TensorOptions().requires_grad(requires_grad); +} + +/// Convenience function that returns a `TensorOptions` object with the +/// `memory_format` set to the given one. +inline TensorOptions memory_format(MemoryFormat memory_format) { + return TensorOptions().memory_format(memory_format); +} + +C10_API std::ostream& operator<<( + std::ostream& stream, + const TensorOptions& options); + +template +inline TensorOptions dtype() { + return dtype(caffe2::TypeMeta::Make()); +} + +inline std::string toString(const TensorOptions& options) { + std::ostringstream stream; + stream << options; + return stream.str(); +} + +// This is intended to be a centralized location by which we can determine +// what an appropriate DispatchKey for a tensor is. +inline DispatchKey computeDispatchKey( + c10::optional dtype, + c10::optional layout, + c10::optional device) { + const auto layout_ = layout_or_default(layout); + const auto device_ = device_or_default(device); + switch (layout_) { + case Layout::Jagged: + case Layout::Strided: { + const auto dtype_ = dtype_or_default(dtype); + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + if (isQIntType(dtype_)) { \ + return DispatchKey::Quantized##device; \ + } \ + return DispatchKey::device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + case c10::DeviceType::FPGA: + return DispatchKey::FPGA; + case c10::DeviceType::ORT: + return DispatchKey::ORT; + case c10::DeviceType::Vulkan: + return DispatchKey::Vulkan; + case c10::DeviceType::Metal: + return DispatchKey::Metal; + case c10::DeviceType::MKLDNN: + case c10::DeviceType::OPENGL: + case c10::DeviceType::OPENCL: + case c10::DeviceType::IDEEP: + TORCH_INTERNAL_ASSERT( + 0, + "This is a grandfathered Caffe2 device type ", + device_.type(), + ", it shouldn't ever convert to a DispatchKey. 
File a bug describing what you were doing if you think this is in error."); + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for dense layout: ", + device_.type()); + } + } + case Layout::Sparse: + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + return DispatchKey::Sparse##device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for sparse layout: ", + device_.type()); + } + case Layout::Mkldnn: + switch (device_.type()) { + case c10::DeviceType::CPU: + return DispatchKey::MkldnnCPU; + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for mkldnn layout: ", + device_.type()); + } + case Layout::SparseCsr: + case Layout::SparseCsc: + case Layout::SparseBsr: + case Layout::SparseBsc: + switch (device_.type()) { +#define DO_CASE(device, _) \ + case c10::DeviceType::device: { \ + return DispatchKey::SparseCsr##device; \ + } + C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) +#undef DO_CASE + default: + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Unsupported device type for ", + layout_, + " layout: ", + device_.type()); + } + default: + TORCH_CHECK(false, "Unsupported layout: ", layout_); + } +} + +inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) { + switch (dispatch_key) { +#define DO_CASE(bc, _) case DispatchKey::Sparse##bc: + C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused) +#undef DO_CASE + return Layout::Sparse; +#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc: + C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused) +#undef DO_CASE + TORCH_CHECK( + false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout."); + case DispatchKey::MkldnnCPU: + return Layout::Mkldnn; + default: + return Layout::Strided; + } +} + +inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { + switch (dispatch_key) { + // stuff that's real +#define DO_CASE(suffix, prefix) \ + case DispatchKey::prefix##suffix: \ + return c10::DeviceType::suffix; +#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix) + C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES) +#undef DO_CASES +#undef DO_CASE + + case DispatchKey::MkldnnCPU: + return c10::DeviceType::CPU; + case DispatchKey::Vulkan: + return c10::DeviceType::Vulkan; + + case DispatchKey::ORT: + return c10::DeviceType::ORT; + default: + TORCH_CHECK( + false, + "DispatchKey ", + dispatch_key, + " doesn't correspond to a device"); + } +} + +inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) { + return TensorOptions() + .layout(dispatchKeyToLayout(dispatch_key)) + .device(dispatchKeyToDeviceType(dispatch_key)); +} + +namespace detail { +inline bool backend_supports_empty_operator(const TensorOptions& options) { + // Quantized backends don't support at::empty(). + // They have separate operators like at::empty_quantized() that take in + // extra information about how to quantize the tensor. 
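+  // [Illustrative note, not part of the upstream header] isQIntType() is
+  // true for the quantized scalar types (e.g. kQInt8, kQUInt8, kQInt32), so
+  // such options report false here and callers fall back to the quantized
+  // factory functions instead.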
+ return !isQIntType(typeMetaToScalarType(options.dtype())); +} + +} // namespace detail + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/UndefinedTensorImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/UndefinedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..ed74fd79b8f5c22ab9112b4e5081d78e5a49e15f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/UndefinedTensorImpl.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +struct C10_API UndefinedTensorImpl final : public TensorImpl { + public: + // Without this, we get: + // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in + // device code + // (ostensibly because the constexpr tricks MSVC into trying to compile this + // function for device as well). +#ifdef _WIN32 + static inline TensorImpl* singleton() { +#else + static constexpr inline TensorImpl* singleton() { +#endif + return &_singleton; + } +#ifdef DEBUG + bool has_storage() const override; +#endif + void set_storage_offset(int64_t offset) override; + + protected: + bool is_contiguous_custom(MemoryFormat format) const override; + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + + private: + UndefinedTensorImpl(); + static UndefinedTensorImpl _singleton; + const char* tensorimpl_type_name() const override; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/WrapDimMinimal.h b/MLPY/Lib/site-packages/torch/include/c10/core/WrapDimMinimal.h new file mode 100644 index 0000000000000000000000000000000000000000..cc3b0d3267171a60c5725d8d9185772a2b4601ea --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/WrapDimMinimal.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +namespace detail { +// This template can only be specialized at int64_t and c10::SymInt; +// you'll get linker errors otherwise +template +C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); +} // namespace detail + +template +T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { + // Inline the fast paths + if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { + // For SymInts, we want an explicit control flow to trigger a guard, so we + // may as well branch too. 
+ if (dim < 0) { + return dim + dim_post_expr; + } + return dim; + } + // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) + return c10::detail::maybe_wrap_dim_slow( + std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +inline int64_t maybe_wrap_dim( + int64_t dim, + int64_t dim_post_expr, + bool wrap_scalar = true) { + return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); +} + +inline c10::SymInt maybe_wrap_dim( + c10::SymInt dim, + c10::SymInt dim_post_expr, + bool wrap_scalar = true) { + return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/alignment.h b/MLPY/Lib/site-packages/torch/include/c10/core/alignment.h new file mode 100644 index 0000000000000000000000000000000000000000..32cac40eb982d97808989aa6245ac4081c6eb824 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/alignment.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace c10 { + +#ifdef C10_MOBILE +// Use 16-byte alignment on mobile +// - ARM NEON AArch32 and AArch64 +// - x86[-64] < AVX +constexpr size_t gAlignment = 16; +#else +// Use 64-byte alignment should be enough for computation up to AVX512. +constexpr size_t gAlignment = 64; +#endif + +constexpr size_t gPagesize = 4096; +// since the default thp pagesize is 2MB, enable thp only +// for buffers of size 2MB or larger to avoid memory bloating +constexpr size_t gAlloc_threshold_thp = static_cast(2) * 1024 * 1024; +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/COW.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/COW.h new file mode 100644 index 0000000000000000000000000000000000000000..b3a94d9681de0da82c1a1ddde114eee0376647fc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/COW.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace c10 { +struct StorageImpl; +class DataPtr; +}; // namespace c10 + +namespace c10::impl::cow { + +// Creates a Copy-on-write (COW) clone of the given storage. This will also +// convert the given storage into a COW storage if it is not COW already. +// +// Converting the storage into a COW storage will not be successful if the +// storage's DataPtr has some context (`DataPtr::get_context()`) which is not +// equal to the data pointer (`DataPtr::get()`). In this case, a nullptr is +// returned. +C10_API c10::intrusive_ptr lazy_clone_storage( + StorageImpl& storage); + +// Check if a storage has a simple DataPtr with no abnormal context +C10_API bool has_simple_data_ptr(const c10::StorageImpl& storage); + +// Check if a DataPtr is COW +C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr); + +// Eagerly copies a COW storage's data, turning it into a non-COW storage. +C10_API void materialize_cow_storage(StorageImpl& storage); + +} // namespace c10::impl::cow diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/COWDeleter.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/COWDeleter.h new file mode 100644 index 0000000000000000000000000000000000000000..58378c4ec2e3b9826cd7cb4cc11ab610662b3b16 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/COWDeleter.h @@ -0,0 +1,66 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace c10::impl::cow { + +// A COWDeleterContext object is used as the `ctx` argument for DataPtr +// to implement a Copy-on-write (COW) DataPtr. 
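+//
+// Rough usage sketch (illustrative only; this is a simplified version of what
+// lazy_clone_storage in COW.h does, and the local names here are made up):
+//
+//   // Move the original data + deleter out of the source DataPtr into a
+//   // heap-allocated context shared by every COW copy of the storage...
+//   auto* ctx = new c10::impl::cow::COWDeleterContext(src.move_context());
+//   // ...then build DataPtrs that point at the same data but whose context
+//   // deleter is cow::cow_deleter, so the data is freed exactly once.
+//   at::DataPtr copy(src.get(), ctx, c10::impl::cow::cow_deleter, src.device());
+//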
+class C10_API COWDeleterContext { + public: + // Creates an instance, holding the pair of data and original + // deleter. + // + // Note that the deleter will only be called in our destructor if + // the last reference to this goes away without getting + // materialized. + explicit COWDeleterContext(std::unique_ptr data); + + // Increments the current refcount. + void increment_refcount(); + + // See README.md in this directory to understand the locking + // strategy. + + // Represents a reference to the context. + // + // This is returned by decrement_refcount to allow the caller to + // copy the data under the shared lock. + using NotLastReference = std::shared_lock; + + // Represents the last reference to the context. + // + // This will be returned by decrement_refcount when it is the last + // reference remaining and after any pending copies have completed. + using LastReference = std::unique_ptr; + + // Decrements the refcount, returning a handle indicating what to + // do with it. + std::variant decrement_refcount(); + + private: + // The destructor is hidden, this should only ever be used within + // UniqueVoidPtr using cow::delete_context as the deleter. + ~COWDeleterContext(); + + std::shared_mutex mutex_; + std::unique_ptr data_; + std::atomic refcount_ = 1; +}; + +// `cow_deleter` is used as the `ctx_deleter` for DataPtr to implement a COW +// DataPtr. +// +// Warning: This should only be called on a pointer to a COWDeleterContext that +// was allocated on the heap with `new`, because when the refcount reaches 0, +// the context is deleted with `delete`. +C10_API void cow_deleter(void* ctx); + +} // namespace c10::impl::cow diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..19df643064a83fabfa43442590ab570d054ed096 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/DeviceGuardImplInterface.h @@ -0,0 +1,337 @@ +#pragma once + +#include +#include +#include +#include + +// Just for C10_ANONYMOUS_VARIABLE +#include + +#include + +namespace c10 { + +// Forward declaration +class DataPtr; + +/** + * Flags defining the behavior of events. + * + * PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The + * BACKEND_DEFAULT is what a particular backend would select if no + * flags were given. PYTORCH_DEFAULT is the PyTorch's framework default + * choice for events on that backend, which may not be the same. For example, + * when PyTorch creates a CUDA event it sets the flag + * CUDA_EVENT_DISABLING_TIMING by default to improve performance. + * + * The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each + * backend implementation. Backend-specific flags, like CUDA_EVENT_DEFAULT, + * should map one-to-one with actual event flags for those backends. + */ +enum class EventFlag { + PYTORCH_DEFAULT, + BACKEND_DEFAULT, + // CUDA flags + CUDA_EVENT_DEFAULT, + CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA + // HIP flags + HIP_EVENT_DEFAULT, + HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP + // FOR TESTING ONLY + INVALID +}; + +namespace impl { + +/** + * DeviceGuardImplInterface represents the virtual interface which provides + * functionality to provide an RAII class for device and stream switching, + * via DeviceGuard. Every distinct device type, e.g., CUDA and HIP, is + * expected to implement and register an implementation of this interface. 
+ * All classes which inherit from DeviceGuardImplInterface should be declared + * 'final'. + * + * This class exists because we provide a unified interface for performing + * device guards via DeviceGuard, but we cannot assume that we have actually + * compiled against the, e.g., CUDA library, which actually implements + * this guard functionality. In this case, a dynamic dispatch is required + * to cross the library boundary. + * + * If possible, you should directly use implementations of this interface; + * those uses will be devirtualized. + */ +struct C10_API DeviceGuardImplInterface { + DeviceGuardImplInterface() = default; + DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default; + DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) = + default; + DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default; + DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept = + default; + + /** + * Return the type of device managed by this guard implementation. + */ + virtual DeviceType type() const = 0; + + /** + * Set the current device to Device, and return the previous Device. + */ + virtual Device exchangeDevice(Device) const = 0; + // NB: Implementations of exchangeDevice can be a bit boilerplatey. You might + // consider replacing exchangeDevice with a non-virtual function with a baked + // in implementation; however, note that this will triple the number of + // virtual calls (when you implement exchangeDevice in a final subclass, + // the compiler gets to devirtualize everything; it won't do that if you don't + // define it in the subclass!) A common way to solve this problem is to use + // some sort of CRTP; however, we can template DeviceGuardImplInterface since + // we really *do* need it to be virtual. A little boilerplate seems easiest + // to explain. (Another way around this problem is to provide inline + // functions that provide the default implementations, but this seems a little + // hard to explain. In any case, we're only going to have on order of ten + // implementations of this anyway.) + + /** + * Get the current device. + */ + virtual Device getDevice() const = 0; + + /** + * Set the current device to Device. + */ + virtual void setDevice(Device) const = 0; + + /** + * Set the current device to Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + virtual void uncheckedSetDevice(Device) const noexcept = 0; + + /** + * Get the current stream for a given device. + */ + virtual Stream getStream(Device) const noexcept = 0; + + /** + * Get the default stream for a given device. + */ + virtual Stream getDefaultStream(Device) const { + TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.") + } + + /** + * Get a stream from the global pool for a given device. + */ + virtual Stream getStreamFromGlobalPool(Device, bool isHighPriority = false) + const { + (void)isHighPriority; // Suppress unused variable warning + TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.") + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + virtual Stream exchangeStream(Stream) const noexcept = 0; + + /** + * Destroys the given event. 
+ */ + virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + virtual void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const c10::EventFlag /*flag*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + virtual void block(void* /*event*/, const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + virtual bool queryEvent(void* /*event*/) const { + TORCH_CHECK(false, "Backend doesn't support events."); + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + virtual DeviceIndex deviceCount() const noexcept = 0; + + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + virtual bool queryStream(const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support querying streams."); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + virtual void synchronizeStream(const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support synchronizing streams."); + } + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const { + } + + /** + * Intended use of this class is to leak the DeviceGuardImpl at program end. + * So you better not call the destructor, buster! + */ + virtual ~DeviceGuardImplInterface() = default; +}; + +// A no-op device guard impl that doesn't do anything interesting. Useful +// for devices that don't actually have a concept of device index. Prominent +// examples are CPU and Meta. 
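+//
+// A backend without a meaningful device index can typically just instantiate
+// this template directly (sketch; the alias name is arbitrary):
+//
+//   using MetaGuardImpl = c10::impl::NoOpDeviceGuardImpl<DeviceType::Meta>;
+//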
+template +struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface { + NoOpDeviceGuardImpl() = default; + DeviceType type() const override { + return D; + } + Device exchangeDevice(Device) const override { + return Device(D, -1); // no-op + } + Device getDevice() const override { + return Device(D, -1); + } + void setDevice(Device) const override { + // no-op + } + void uncheckedSetDevice(Device) const noexcept override { + // no-op + } + Stream getStream(Device) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, -1)); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream) const noexcept override { + // no-op + return Stream(Stream::DEFAULT, Device(D, -1)); + } + DeviceIndex deviceCount() const noexcept override { + return 1; + } + + // Event-related functions + void record( + void** /*event*/, + const Stream& /*stream*/, + const DeviceIndex /*device_index*/, + const EventFlag /*flag*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events."); + } + void block(void* /*event*/, const Stream& /*stream*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events.") + } + bool queryEvent(void* /*event*/) const override { + TORCH_CHECK(false, D, " backend doesn't support events.") + } + void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) + const noexcept override {} + + // Stream-related functions + bool queryStream(const Stream& /*stream*/) const override { + return true; + } + void synchronizeStream(const Stream& /*stream*/) const override { + // Don't wait for anything. + } +}; + +// The registry is NON-owning. Each stored pointer is std::atomic so +// that under all interleavings of registry calls the structure is +// race-free. This doesn't cost us anything on reads in X86. (An +// unsynchronized implementation probably is OK too, but I didn't want +// to prove that we never read from device_guard_impl_registry at the +// same time some registration is occurring. Shiver.) +// +// I'd like this registry to be valid even at program destruction time +// (in case someone uses a DeviceGuard in a destructor to do some cleanup +// in the CUDA API.) Since there are no direct accesses of the underlying +// owning objects which I can use to enforce initialization order (unlike +// in a Meyer singleton), it implies that you must *leak* objects when +// putting them in the registry. This is done by deleting the destructor +// on DeviceGuardImplInterface. +// NOLINTNEXTLINE(*c-arrays*) +extern C10_API std::atomic + device_guard_impl_registry[static_cast( + DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)]; + +// I can't conveniently use c10/util/Registry.h for the following reason: +// c10/util/Registry.h gives me a slow way of Create'ing a object of some +// interface from the registry, but no way of quickly accessing an already +// created object. I'll be banging on getDeviceGuardImpl every time we do a +// DeviceGuard, so I really don't want to be doing an unordered_map lookup. +// Better if the registration mechanism directly drops its implementation +// into device_guard_impl_registry. 
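+//
+// Registration is done with the C10_REGISTER_GUARD_IMPL macro defined just
+// below, typically at namespace scope in the backend's own .cpp file
+// (sketch; MyCUDAGuardImpl is a placeholder for a concrete implementation):
+//
+//   C10_REGISTER_GUARD_IMPL(CUDA, MyCUDAGuardImpl);
+//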
+ +class C10_API DeviceGuardImplRegistrar { + public: + DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*); +}; + +#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \ + static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \ + g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl()); + +inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) { + // Two adjacent int16_t fields DeviceType and DeviceIndex has field access + // miscompiled on NVCC. To workaround this issue, we apply a mask to the + // DeviceType. First check if the DeviceType is 16-bit. + // FB employees can see + // https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/ + // for more details + static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit"); + auto p = device_guard_impl_registry[static_cast(type) & 0xFF].load(); + + // This seems to be the first place where you make use of a device + // when you pass devices to factory functions. Give a nicer error + // message in this case. + TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices"); + return p; +} + +inline bool hasDeviceGuardImpl(DeviceType type) { + return device_guard_impl_registry[static_cast(type)].load(); +} + +} // namespace impl +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..c1f015eb31cc14d0dd80c540b417151f495fa952 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/FakeGuardImpl.h @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include + +namespace c10::impl { + +// FakeGuardImpl is hardcoded to have eight devices. Not for +// any good reason, just to simplify code. +constexpr DeviceIndex kFakeGuardImplMaxDevices = 8; + +/** + * A fake implementation of DeviceGuardImplInterface suitable for testing. + * The current device is modeled as a mutable field in the guard implementation + * class. See DeviceGuard_test.cpp for an example use. 
+ */ +template +struct FakeGuardImpl final : public DeviceGuardImplInterface { + static constexpr DeviceType static_type = T; + // Runtime device type is not used + FakeGuardImpl(DeviceType) {} + FakeGuardImpl() = default; + DeviceType type() const override { + return T; + } + Device exchangeDevice(Device d) const override { + AT_ASSERT(d.type() == type()); + AT_ASSERT(d.index() < kFakeGuardImplMaxDevices); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + current_device_ = d.index(); + } + return old_device; + } + Device getDevice() const override { + return Device(type(), current_device_); + } + void setDevice(Device d) const override { + AT_ASSERT(d.type() == type()); + AT_ASSERT(d.index() >= 0); + AT_ASSERT(d.index() < kFakeGuardImplMaxDevices); + current_device_ = d.index(); + } + void uncheckedSetDevice(Device d) const noexcept override { + current_device_ = d.index(); + } + Stream getStream(Device d) const noexcept override { + return Stream(Stream::UNSAFE, d, current_streams_[d.index()]); + } + Stream exchangeStream(Stream s) const noexcept override { + auto old_id = current_streams_[s.device_index()]; + current_streams_[s.device_index()] = s.id(); + return Stream(Stream::UNSAFE, s.device(), old_id); + } + DeviceIndex deviceCount() const noexcept override { + return kFakeGuardImplMaxDevices; + } + + // Event-related functions + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override {} + void block(void* event, const Stream& stream) const override {} + bool queryEvent(void* event) const override { + return true; + } + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override {} + + // Convenience methods for testing + static DeviceIndex getDeviceIndex() { + return current_device_; + } + static void setDeviceIndex(DeviceIndex i) { + AT_ASSERT(i >= 0); + AT_ASSERT(i < kFakeGuardImplMaxDevices); + current_device_ = i; + } + static StreamId getCurrentStreamIdFor(DeviceIndex i) { + return current_streams_.at(i); + } + static void resetStreams() { + current_streams_.fill(0); + } + + private: + thread_local static DeviceIndex current_device_; + thread_local static std::array + current_streams_; +}; + +template +thread_local DeviceIndex FakeGuardImpl::current_device_ = 0; + +template +thread_local std::array + FakeGuardImpl::current_streams_ = {0, 0, 0, 0, 0, 0, 0, 0}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/GPUTrace.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/GPUTrace.h new file mode 100644 index 0000000000000000000000000000000000000000..9101b1b29c34a8b3cf61c4f0b759066a804febf5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/GPUTrace.h @@ -0,0 +1,28 @@ +#pragma once + +#include + +namespace c10::impl { + +struct C10_API GPUTrace { + // On the x86 architecture the atomic operations are lock-less. + static std::atomic gpuTraceState; + + // When PyTorch migrates to C++20, this should be changed to an atomic flag. + // Currently, the access to this variable is not synchronized, on the basis + // that it will only be flipped once and by the first interpreter that + // accesses it. + static bool haveState; + + // This function will only register the first interpreter that tries to invoke + // it. For all of the next ones it will be a no-op. 
+ static void set_trace(const PyInterpreter*); + + static const PyInterpreter* get_trace() { + if (!haveState) + return nullptr; + return gpuTraceState.load(std::memory_order_acquire); + } +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..dd22d19adbd0d75a50ca649b360340cfd12dd537 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/HermeticPyObjectTLS.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +namespace c10::impl { + +// This TLS controls whether or not we permanently associate PyObject +// with Tensor the first time it is allocated. When hermetic PyObject +// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor, +// meaning you get a distinct PyObject whenever you execute the code in +// question. +struct C10_API HermeticPyObjectTLS { + static void set_state(bool state); + static bool get_state() { + // Hypothetical fastpath if torchdeploy/multipy isn't used. Per + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // this qualifies relaxed access because it is a single-location data + // structure (only the boolean here). + // + // Forgetting about data races for a moment, is there a logical race? + // + // - Boolean only ever transitions from false to true. So the + // critical situation is when one interpreter is already running + // when a second interpreter switches haveState from false to true. + // + // - The first interpreter is indifferent whether or not it sees + // hasState true/false; obviously false works (this is what the + // interpreter was previously using; more directly, the interpreter + // calls into itself as the handler, so being hermetic is not + // required), and true simply means serviced python operator calls will + // be hermetic; in these cases it is expected to be functionally + // equivalent. + // + // - The second interpreter MUST see hasState true (as its requests will + // be forwarded to the first interpreter), but it is assumed that there + // is a synchronization between the interpreter initialization, and + // when we actually perform operations, so it is guaranteed to see + // hasState true. + // + // QED. + // + // This fastpath is currently disabled so that we can more easily test that + // hermetic mode works correctly even on stock build of PyTorch. + if (false && !haveState_.load(std::memory_order_relaxed)) + return false; + return get_tls_state(); + } + // Call this from the multipy/torchdeploy top level + static void init_state(); + + private: + // This only flipped once from false to true during torchdeploy/multipy + // initialization, and never again. + static std::atomic haveState_; + static bool get_tls_state(); +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..b5e647e7205eb9c06c17b5d662f9d8fa742cac47 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineDeviceGuard.h @@ -0,0 +1,428 @@ +#pragma once + +// This file provides implementations of InlineDeviceGuard and +// InlineOptionalDeviceGuard. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10::impl { + +/** + * A DeviceGuard is an RAII class that sets a device to some value + * on construction, and resets the device to its original value on + * destruction. + * + * InlineDeviceGuard is a helper class for implementing DeviceGuards. + * It is templated over a DeviceGuardImpl (anything that implements + * DeviceGuardImplInterface). There are two primary ways to instantiate + * InlineDeviceGuard: + * + * - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl. + * This is the best way to use InlineDeviceGuard, as all calls are + * devirtualized, giving you code as efficient as straight line + * calls to cudaGetDevice/cudaSetDevice. + * + * - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl + * retrieved from a DeviceType registry. We have explicitly instantiated + * InlineDeviceGuard this way as c10::DeviceGuard. + * + * If you are in a hurry, you can use InlineDeviceGuard directly: + * + * using CUDAGuard = impl::InlineDeviceGuard; + * + * However, you can provide a better user experience if you explicitly write a + * wrapper class that itself contains the template instantiation: + * + * class CUDAGuard { + * public: + * // ... the API ... + * private: + * impl::InlineDeviceGuard guard_; + * } + * + * The wrapper class provides a good place to write documentation, and helps + * avoid weird template instantiation errors when a user incorrectly uses the + * class. + * + * If you need to test this class, consider instantiating it with FakeGuardImpl. + */ +template +class InlineDeviceGuard { + public: + // Note [Omitted default constructor from RAII] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // In principle, we could add a default constructor to + // DeviceGuard which reads the current device and promises to + // restore to that device on exit. However, most cases where you + // would have written this, you probably meant to actually just + // use OptionalDeviceGuard (since you don't actually need the + // restore to happen if you don't ever actually set the device). + // We remove the constructor here to encourage you to think about + // what you actually want to happen. + explicit InlineDeviceGuard() = delete; + + /// Set the current device to the passed Device. + explicit InlineDeviceGuard(Device device) + : impl_(device.type()), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? original_device_ : device) {} + + /// Set the current device index to the passed DeviceIndex. (The + /// device type is inferred from the template parameter T). + template < + typename U = T, + typename = + typename std::enable_if_t>> + explicit InlineDeviceGuard(DeviceIndex device_index) + : InlineDeviceGuard(Device(U::static_type, device_index)) {} + + /// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit + /// DeviceGuardImplInterface pointer. + template < + typename U = T, + typename = typename std::enable_if_t>> + explicit InlineDeviceGuard( + Device device, + const DeviceGuardImplInterface* impl) + : impl_( + VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type()))), + original_device_( + device.index() == -1 ? impl_.getDevice() + : impl_.exchangeDevice(device)), + current_device_(device.index() == -1 ? 
original_device_ : device) {} + + /// Copy is disallowed + InlineDeviceGuard(const InlineDeviceGuard&) = delete; + InlineDeviceGuard& operator=(const InlineDeviceGuard&) = delete; + + /// Move is disallowed, as DeviceGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineDeviceGuard(InlineDeviceGuard&& other) = delete; + InlineDeviceGuard& operator=(InlineDeviceGuard&& other) = delete; + + ~InlineDeviceGuard() { + impl_.uncheckedSetDevice(original_device_); + } + + /// Sets the device to the given one. + template < + typename U = T, + typename std::enable_if_t, int> = 0> + void set_device(at::Device device) { + AT_ASSERT( + (U::static_type == DeviceType::HIP && device.is_cuda()) || + device.type() == U::static_type); + auto index = device.index(); + if (index == -1) + return; + impl_.setDevice(device); + current_device_ = device; + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device. This is effectively equivalent to + /// set_device when a guard supports only a single device type. + template + typename std::enable_if_t> reset_device( + at::Device device) { + set_device(device); + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device (for a possibly different device + /// type). + /// + /// This method is named reset_device to highlight the fact that previous + /// device settings from this guard are NOT preserved, even if the device + /// has a different device type. For example: + /// + /// // CUDA device is 0 + /// DeviceGuard g(Device(kCUDA, 1)); + /// g.reset_device(Device(kHIP, 2)); + /// // CUDA device is 0 (!!) + /// + /// NOTE: this implementation may skip some device setting if it can prove + /// that it is unnecessary. + /// + /// Optional argument is for testing only. + template + typename std::enable_if_t> reset_device( + at::Device device, + const impl::DeviceGuardImplInterface* impl = nullptr) { + auto index = device.index(); + if (index == -1) + return; + if (device.type() == original_device_.type()) { + AT_ASSERT(impl == nullptr || impl->type() == device.type()); + impl_.setDevice(device); + current_device_ = device; + } else { + // Destruct and reconstruct the DeviceGuard in place + impl_.setDevice(original_device_); + impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl); + original_device_ = impl_.exchangeDevice(device); + current_device_ = device; + } + } + + /// Sets the device index to the given one. The device type is inferred + /// from the original device type. + void set_index(DeviceIndex index) { + reset_device(Device(original_device_.type(), index)); + } + + /// Returns the device that was set at the time the most recent + /// reset_device(), or otherwise the device at construction time. + Device original_device() const { + return original_device_; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return current_device_; + } + + protected: + T impl_; + + private: + Device original_device_; + Device current_device_; +}; + +/** + * A OptionalDeviceGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * + * InlineOptionalDeviceGuard is a helper class for implementing + * OptionalDeviceGuards. 
See guidance in InlineDeviceGuard on how to + * use this. See OptionalDeviceGuard for user-oriented usage notes. + */ +template +class InlineOptionalDeviceGuard { + public: + // Note [Explicit initialization of optional fields] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Explicit initialization of optional fields + // required to workaround an nvcc bug; see + // https://github.com/pytorch/pytorch/issues/12117 + + /// Creates an uninitialized OptionalDeviceGuard. + explicit InlineOptionalDeviceGuard() + : guard_() // See Note [Explicit initialization of optional fields] + {} + + /// Set the current device to the passed Device, if it is not nullopt. + explicit InlineOptionalDeviceGuard(optional device_opt) + : guard_() { // See Note [Explicit initialization of optional fields] + if (device_opt.has_value()) { + guard_.emplace(device_opt.value()); + } + } + + /// Set the current device to the passed DeviceIndex, if it is not nullopt. + template < + typename U = T, + typename = + typename std::enable_if_t>> + explicit InlineOptionalDeviceGuard(optional device_index_opt) + : guard_() { // See Note [Explicit initialization of optional fields] + if (device_index_opt.has_value()) { + guard_.emplace(device_index_opt.value()); + } + } + + /// All constructors of DeviceGuard are valid for OptionalDeviceGuard + /// and result in initialized OptionalDeviceGuard. + template + explicit InlineOptionalDeviceGuard(Args&&... args) + : guard_(std::in_place, std::forward(args)...) {} + + // TODO: Consider reading Tensor and TensorList constructors here, when + // Tensor moves to c10. (These are only valid on OptionalDeviceGuard, + // because a Tensor may be undefined, in which case we need an uninitialized + // tensor guard.) + + // Note [Move construction for RAII guards is tricky] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // In principle, move construction is useful for terminating + // the lifetime of a `OptionalDeviceGuard` early; for example: + // + // // current device is d0 + // OptionalDeviceGuard g1(d1); + // // current device is d1 + // { + // OptionalDeviceGuard g2(std::move(g1)); + // } + // // current device is d0!! + // + // However, it's difficult to implement the move constructor + // in a way that works in all situations. For example, consider + // the following example: + // + // OptionalDeviceGuard g1(d1); + // { + // OptionalDeviceGuard g2(d2); + // { + // OptionalDeviceGuard g3(std::move(g1)); // !!! + // } + // } + // + // What should the current device be while g3 in scope... and what + // should it be after it goes out of scope? What about g2? + // There don't seem to be satisfactory answers for these questions. + // + // It's in principle possible to raise an error when this occurs + // by doing some extra thread-local bookkeeping. But why bother? + // Just don't provide the constructor. + InlineOptionalDeviceGuard(InlineOptionalDeviceGuard&& other) = delete; + + // Note [Move assignment for RAII guards is tricky] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Move assignment is deleted, because you need to know which guard was + // defined "first", as that guard's original_device_ wins--with the current + // representation, we have no way of telling which is the case. (Move + // construction does not have this problem, as one guard is always + // uninitialized.) 
+ // + // We can make this clear by way of a pair of examples: + // + // Example 1: + // + // // initial device is n0 + // { + // CUDAGuard g1(n1); + // { + // CUDAGuard g2(n2); + // // current device should be n2 + // g1 = std::move(g2); + // // current device should still be n2 + // } + // // current device should still be n2 + // } + // // current device should be n0 + // + // Example 2 (flip the order of the two guards): + // + // // initial device is n0 + // { + // CUDAGuard g2(n2); + // { + // CUDAGuard g1(n1); + // // current device should be n1 + // g1 = std::move(g2); + // // current device should be n2 + // } + // // current device should be n0 (since g2 has been vacated) + // } + // + // In both examples, we need g1 to restore to n0 after move assignment. + // However, in example 1, this is determined by the restore value of g1 + // (prior to the move). In example 2, however, it is determined by the the + // restore value of g2(!!). We don't know which one should win, without having + // a way of telling which guard was allocated first. + // + // We could solve this with an extra thread-local variable. But no one is + // actually using move-assignment. So just get rid of it. + InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = + delete; + + /// Sets the device to the given one. Initializes OptionalDeviceGuard if it + /// is not already initialized. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void set_device(at::Device device) { + if (!guard_.has_value()) { + guard_.emplace(device); + } else { + guard_->set_device(device); + } + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device (for a possibly different device + /// type). Initializes OptionalDeviceGuard if it is not already initialized. + /// + /// See notes on why this is called reset_device on InlineDeviceGuard. + /// + /// Optional argument is for testing only. + template < + typename U = T, + typename = typename std::enable_if_t>> + void reset_device( + at::Device device, + const DeviceGuardImplInterface* impl = nullptr) { + if (!guard_.has_value()) { + guard_.emplace(device, impl); + } else { + guard_->reset_device(device, impl); + } + } + + /// Resets the currently set device to its original device, and then sets the + /// current device to the passed device. Initializes the guard if it is + /// not already initialized. This is effectively equivalent to set_device + /// when a guard supports only a single device type. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void reset_device(at::Device device) { + if (!guard_.has_value()) { + guard_.emplace(device); + } else { + guard_->reset_device(device); + } + } + + /// Sets the device index to the given one. The device type is statically + /// known. + template < + typename U = T, + typename = + typename std::enable_if_t>> + void set_index(DeviceIndex index) { + if (!guard_.has_value()) { + guard_.emplace(index); + } else { + guard_->set_index(index); + } + } + + /// Returns the device that was set immediately prior to initialization of + /// the, guard, or nullopt if the guard is uninitialized. + optional original_device() const { + return guard_.has_value() ? make_optional(guard_->original_device()) + : nullopt; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. 
+ optional current_device() const { + return guard_.has_value() ? make_optional(guard_->current_device()) + : nullopt; + } + + /// Restore the original device, resetting this guard to uninitialized state. + void reset() { + guard_.reset(); + } + + private: + optional> guard_; +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineEvent.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..7ff255440af161e77e1cc8863f52d4c102c45ef1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineEvent.h @@ -0,0 +1,113 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10::impl { + +template +struct InlineEvent final { + InlineEvent() = delete; + InlineEvent( + const DeviceType _device_type, + const EventFlag _flag = EventFlag::PYTORCH_DEFAULT) + : backend_{_device_type}, device_type_{_device_type}, flag_{_flag} {} + + // Copy constructor and copy assignment operator (deleted) + InlineEvent(const InlineEvent&) = delete; + InlineEvent& operator=(const InlineEvent&) = delete; + + // Move constructor and move assignment operator + InlineEvent(InlineEvent&& other) noexcept + : event_(other.event_), + backend_(std::move(other.backend_)), + device_type_(other.device_type_), + device_index_(other.device_index_), + flag_(other.flag_), + was_marked_for_recording_(other.was_marked_for_recording_) { + other.event_ = nullptr; + } + InlineEvent& operator=(InlineEvent&& other) noexcept { + swap(other); + return *this; + } + + void swap(InlineEvent& other) noexcept { + std::swap(event_, other.event_); + std::swap(backend_, other.backend_); + std::swap(device_type_, other.device_type_); + std::swap(device_index_, other.device_index_); + std::swap(flag_, other.flag_); + std::swap(was_marked_for_recording_, other.was_marked_for_recording_); + } + + ~InlineEvent() noexcept { + if (event_) + backend_.destroyEvent(event_, device_index_); + } + + DeviceType device_type() const noexcept { + return device_type_; + } + DeviceIndex device_index() const noexcept { + return device_index_; + } + EventFlag flag() const noexcept { + return flag_; + } + bool was_marked_for_recording() const noexcept { + return was_marked_for_recording_; + } + + void recordOnce(const Stream& stream) { + if (!was_marked_for_recording_) + record(stream); + } + + void record(const Stream& stream) { + TORCH_CHECK( + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match recording stream's device type ", + DeviceTypeName(stream.device_type()), + "."); + + backend_.record(&event_, stream, device_index_, flag_); + was_marked_for_recording_ = true; + device_index_ = stream.device_index(); + } + + void block(const Stream& stream) const { + if (!was_marked_for_recording_) + return; + + TORCH_CHECK( + stream.device_type() == device_type_, + "Event device type ", + DeviceTypeName(device_type_), + " does not match blocking stream's device type ", + DeviceTypeName(stream.device_type()), + "."); + + backend_.block(event_, stream); + } + + bool query() const { + if (!was_marked_for_recording_) + return true; + return backend_.queryEvent(event_); + } + + private: + void* event_ = nullptr; + T backend_; + DeviceType device_type_; + DeviceIndex device_index_ = -1; + EventFlag flag_ = EventFlag::PYTORCH_DEFAULT; + bool was_marked_for_recording_ = false; +}; + +} // namespace c10::impl diff --git 
a/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..52d3a648aced64d62eb99ca8ce47c0069729f922 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/InlineStreamGuard.h @@ -0,0 +1,255 @@ +#pragma once + +#include +#include +#include + +namespace c10::impl { + +/** + * A StreamGuard is an RAII class that changes the current device + * to the device corresponding to some stream, and changes the + * default stream on that device to be this stream. + * + * InlineStreamGuard is a helper class for implementing StreamGuards. + * See InlineDeviceGuard for guidance on how to use this class. + */ +template +class InlineStreamGuard : private InlineDeviceGuard { + public: + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit InlineStreamGuard() = delete; + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + explicit InlineStreamGuard(Stream stream) + : InlineDeviceGuard(stream.device()), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} + + /// This constructor exists purely for testing + template < + typename U = T, + typename = typename std::enable_if_t>> + explicit InlineStreamGuard( + Stream stream, + const DeviceGuardImplInterface* impl) + : InlineDeviceGuard( + stream.device(), + impl ? impl : getDeviceGuardImpl(stream.device_type())), + original_stream_of_original_device_( + this->impl_.getStream(original_device())), + original_stream_of_current_device_(this->impl_.exchangeStream(stream)), + current_stream_(stream) {} + + /// Copy is disallowed + InlineStreamGuard(const InlineStreamGuard&) = delete; + InlineStreamGuard& operator=(const InlineStreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineStreamGuard(InlineStreamGuard&& other) = delete; + InlineStreamGuard& operator=(InlineStreamGuard&& other) = delete; + + ~InlineStreamGuard() { + this->impl_.exchangeStream(original_stream_of_current_device_); + } + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// use MultiStreamGuard instead. + void reset_stream(Stream stream) { + // TODO: make a version that takes an impl argument. Unfortunately, + // that will require SFINAE because impl is only valid for the + // VirtualGuardImpl specialization. 
+ if (stream.device() == this->current_device()) { + this->impl_.exchangeStream(stream); + current_stream_ = stream; + } else { + // Destruct and reconstruct the StreamGuard in-place + this->impl_.exchangeStream(original_stream_of_current_device_); + this->reset_device(stream.device()); + original_stream_of_current_device_ = this->impl_.exchangeStream(stream); + current_stream_ = stream; + } + } + + // It's not clear if set_device should also reset the current stream + // if the device is unchanged; therefore, we don't provide it. + // The situation is somewhat clearer with reset_device, but it's still + // a pretty weird thing to do, so haven't added this either. + + /// Returns the stream of the original device prior to this guard. Subtly, + /// the stream returned here is the original stream of the *original* + /// device; i.e., it's the stream that your computation *would* have + /// been put on, if it hadn't been for this meddling stream guard. + /// This is usually what you want. + Stream original_stream() const { + return original_stream_of_original_device_; + } + + /// Returns the most recent stream that was set using this device guard, + /// either from construction, or via set_stream. + Stream current_stream() const { + return current_stream_; + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return InlineDeviceGuard::current_device(); + } + + /// Returns the device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. + Device original_device() const { + return InlineDeviceGuard::original_device(); + } + + private: + Stream + original_stream_of_original_device_; // what the user probably cares about + Stream original_stream_of_current_device_; // what we need to restore + Stream current_stream_; +}; + +/** + * An OptionalStreamGuard is an RAII class that sets a device to some value on + * initialization, and resets the device to its original value on destruction. + * See InlineOptionalDeviceGuard for more guidance on how to use this class. + */ +template +class InlineOptionalStreamGuard { + public: + /// Creates an uninitialized stream guard. + explicit InlineOptionalStreamGuard() + : guard_() // See Note [Explicit initialization of optional fields] + {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit InlineOptionalStreamGuard(optional stream_opt) : guard_() { + if (stream_opt.has_value()) { + guard_.emplace(stream_opt.value()); + } + } + + /// All constructors of StreamGuard are valid for OptionalStreamGuard + template + explicit InlineOptionalStreamGuard(Args&&... args) + : guard_(std::in_place, std::forward(args)...) {} + + // See Note [Move construction for RAII guards is tricky] + InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = + delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. 
+ /// Initializes the OptionalStreamGuard if it was not previously initialized. + void reset_stream(Stream stream) { + if (guard_.has_value()) { + guard_->reset_stream(stream); + } else { + guard_.emplace(stream); + } + } + + /// Returns the stream that was set at the time the guard was most recently + /// initialized, or nullopt if the guard is uninitialized. + optional original_stream() const { + return guard_.has_value() ? make_optional(guard_->original_stream()) + : nullopt; + } + + /// Returns the most recent stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + optional current_stream() const { + return guard_.has_value() ? make_optional(guard_->current_stream()) + : nullopt; + } + + /// Restore the original device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + optional> guard_; +}; + +template +class InlineMultiStreamGuard { + public: + /// Calls `set_stream` on each of the streams in the list. + /// This may be useful if you need to set different streams + /// for different devices. + explicit InlineMultiStreamGuard(ArrayRef streams) { + if (!streams.empty()) { + impl_.emplace(getDeviceTypeOfStreams(streams)); + original_streams_.reserve(streams.size()); + for (const Stream& s : streams) { + original_streams_.emplace_back(this->impl_->exchangeStream(s)); + } + } + } + + /// Copy is disallowed + InlineMultiStreamGuard(const InlineMultiStreamGuard&) = delete; + InlineMultiStreamGuard& operator=(const InlineMultiStreamGuard&) = delete; + + /// Move is disallowed, as StreamGuard does not have an uninitialized state, + /// which is required for moves on types with nontrivial destructors. + InlineMultiStreamGuard(InlineMultiStreamGuard&& other) = delete; + InlineMultiStreamGuard& operator=(InlineMultiStreamGuard&& other) = delete; + + ~InlineMultiStreamGuard() noexcept { + if (this->impl_.has_value()) { + for (const Stream& s : original_streams_) { + this->impl_->exchangeStream(s); + } + } + } + + protected: + optional impl_; + + private: + /// The original streams that were active on all devices. + std::vector original_streams_; + + static DeviceType getDeviceTypeOfStreams(ArrayRef streams) { + TORCH_INTERNAL_ASSERT(!streams.empty()); + DeviceType type = streams[0].device_type(); + for (const auto idx : c10::irange(1, streams.size())) { + TORCH_CHECK_VALUE( + streams[idx].device_type() == type, + "Streams have a mix of device types: stream 0 is on ", + streams[0].device(), + " while stream ", + idx, + " is on device ", + streams[idx].device()); + } + return type; + } +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h new file mode 100644 index 0000000000000000000000000000000000000000..acf7fce944b9fc7b5be7e94a7822541ab03a407f --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/LocalDispatchKeySet.h @@ -0,0 +1,162 @@ +#pragma once + +#include +#include + +// TLS management for DispatchKeySet (the "local" DispatchKeySet(s)) +// +// This manages two thread-local DispatchKeySets: +// +// - The included type set, which adds a tensor type for consideration +// in dispatch. (For example, you might add Profiling to +// the included type set to turn on profiling on all tensor operations.) 
+// +// - The excluded type set, which disqualifies a tensor type from dispatch. +// (For example, after redispatching on variable, we disqualify +// Autograd so we don't attempt to handle variable again.) +// (Exclusion wins over inclusion.) +// +// NB: Originally, I implemented the excluded type set as storing the inverted +// set, but TLS is defined to be zero-initialized, so this doesn't actually work +// (if it's inverted, you want the set to be -1 initialized). + +namespace c10::impl { + +// POD version of LocalDispatchKeySet. Declared here just so that +// we can put it in the guards. +// This struct encapsulates special handling for TLS initialization +// in set_included()/included() API so that they reflect the truth. +// If you want to create PODLocalDispatchKeySet with non-zero state, +// use set_included() instead of default constructor. +struct C10_API PODLocalDispatchKeySet { + uint64_t included_; + uint64_t excluded_; + + // See Note [TLS Initialization] + DispatchKeySet included() const { + return DispatchKeySet(DispatchKeySet::RAW, included_) ^ + c10::default_included_set; + } + DispatchKeySet excluded() const { + return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ + c10::default_excluded_set; + } + + void set_included(DispatchKeySet x) { + included_ = (x ^ c10::default_included_set).raw_repr(); + } + void set_excluded(DispatchKeySet x) { + excluded_ = (x ^ c10::default_excluded_set).raw_repr(); + } +}; +static_assert( + std::is_trivial_v, + "PODLocalDispatchKeySet must be a POD type."); + +struct C10_API LocalDispatchKeySet { + /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x) + : included_(x.included()), excluded_(x.excluded()) {} + DispatchKeySet included_; + DispatchKeySet excluded_; +}; + +// thread_local variables cannot be C10_API on Windows. +// Inlining this seems to break AutoDispatchBelowAutograd on Android. +#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) +C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); +#else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) +extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; + +inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { + // Don't let people fiddle with the thread_local directly just + // because they include this header. + return raw_local_dispatch_key_set; +} +#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) + +// Internal, use ThreadLocalStateGuard +C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); + +// RAII API for manipulating the thread-local dispatch state. 
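+//
+// Typical use is a scoped guard (sketch; the key shown is just an example):
+//
+//   {
+//     c10::impl::ExcludeDispatchKeyGuard no_autograd(
+//         c10::DispatchKey::AutogradCPU);
+//     // dispatch in this scope skips AutogradCPU
+//   }  // previous TLS state is restored here
+//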
+ +class C10_API IncludeDispatchKeyGuard { + public: + IncludeDispatchKeyGuard(DispatchKeySet); + IncludeDispatchKeyGuard(DispatchKey k) + : IncludeDispatchKeyGuard(DispatchKeySet(k)) {} + IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete; + IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete; + IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete; + IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete; + ~IncludeDispatchKeyGuard(); + + private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalDispatchKeySet* tls_; + DispatchKeySet include_; +}; + +class C10_API ExcludeDispatchKeyGuard { + public: + ExcludeDispatchKeyGuard(DispatchKeySet); + ExcludeDispatchKeyGuard(DispatchKey k) + : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {} + ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete; + ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete; + ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete; + ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete; + ~ExcludeDispatchKeyGuard(); + + private: + // A little micro-optimization to save us from tls_get_addr call + // on destruction + PODLocalDispatchKeySet* tls_; + DispatchKeySet exclude_; +}; + +struct C10_API ForceDispatchKeyGuard { + public: + ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set) + : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) { + c10::impl::_force_tls_local_dispatch_key_set(key_set); + } + ForceDispatchKeyGuard( + c10::DispatchKeySet include, + c10::DispatchKeySet exclude) + : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) { + auto updated_set = saved_keyset_; + updated_set.included_ = include; + updated_set.excluded_ = exclude; + c10::impl::_force_tls_local_dispatch_key_set(updated_set); + } + ~ForceDispatchKeyGuard() { + c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_); + } + + private: + c10::impl::LocalDispatchKeySet saved_keyset_; +}; + +// Non-RAII API for manipulating the thread-local dispatch state. +// Please prefer the RAII API. The non-RAII API may be useful when +// the included/excluded state of a given DispatchKey must span +// many calls from the Python to the C++, so you cannot conveniently +// use an RAII guard. +// +// Example use case: a Python context manager that includes a certain +// DispatchKey, to ensure ops running under the context manager dispatch +// through that DispatchKey's registered overrides. +// +// The non-RAII API is less efficient than the RAII guards because both the +// getter and setter will do a tls_getaddr lookup (the RAII struct only needs +// one!) 
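+//
+// A minimal sketch of the non-RAII pattern described above (the helper
+// functions are hypothetical; only the tls_* calls come from this header):
+//
+//   void enter_python_mode() {
+//     c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Python, true);
+//   }
+//   void exit_python_mode() {
+//     c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Python, false);
+//   }
+//
+// The two calls can live in different C++ entry points (e.g. the __enter__
+// and __exit__ of a Python context manager), which is exactly the case the
+// RAII guards cannot cover.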
+ +C10_API bool tls_is_dispatch_key_excluded(DispatchKey x); +C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state); +C10_API bool tls_is_dispatch_key_included(DispatchKey x); +C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state); +C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks); +C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks); + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyInterpreter.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyInterpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..4f759aac2b7b5a39fe627673b03914be381156fd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyInterpreter.h @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declarations + +namespace c10 { +struct IValue; +class OperatorHandle; +struct TensorImpl; +} // namespace c10 + +namespace torch::jit { +using Stack = std::vector; +} + +// Actual implementation + +namespace c10::impl { + +struct C10_API PyInterpreter; + +// Note [Python interpreter tag] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Traditionally, PyTorch is layered such that our Python library +// (libtorch_python) references our pure C++ library (libtorch) as the +// natural order of things. However, sometimes this natural order is +// subverted: C++ objects refer to Python objects (for example, we +// store a PyObject* pointer on TensorImpl so that converting from a +// C++ Tensor to a Python Tensor is just a memory dereference). +// +// These unusual orderings must be treated with care. To start, you need to +// virtualize the destructor so that the PyObject can be decref'ed on +// destruction (because the C++ object itself doesn't know anything about +// Python--remember, layering!). This process itself is fraught, since +// acquiring the GIL could lead to deadlocks if someone is blocking on you +// while holding the GIL. Furthermore, if the C++ objects outlive the +// interpreter (which can happen if you stash them in a static global +// variable defined in libtorch), you may attempt to decref the object when +// the Python interpreter has already been shutdown. +// +// BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python +// interpreters in a single process. If a C++ object is accessible from +// multiple interpreters, we must take care not to accidentally pass a +// PyObject from one interpreter with another interpreter. +// +// To prevent these mixups, we introduce a PyInterpreter "tag" (object with +// a vtable), which specifies a specific Python interpreter. +// +// - Any given object can be associated with AT MOST one Python interpreter. +// We represent the interpreter tag as a memory address to an instance of +// a virtual class that is allocated once per interpreter (this is so that +// we can request the interpreter to perform operations for us, if +// necessary). +// +// - It can be recorded with a PyObject (PyInterpreterObject) so that +// we know what interpreter the object is associated with, and we can +// raise an error if you try to use the PyObject from the wrong +// interpreter context. +// +// - It contains a vtable that can be used to perform various Python +// operations from ordinary C++ code that ordinarily wouldn't be accessible +// from libtorch. 
+// +// A simple use case is when a C++ object must be associated with a PyObject. +// However, for TensorImpl, we lazily allocate a PyObject the first time the +// object passes into Python. The invariants for this situation are more +// subtle: +// +// - A given TensorImpl's interpreter tag can only go from uninitialized to +// tagged; once tagged, this is a quiescent state (once tagged to an +// interpreter, ALWAYS tagged to that interpreter) +// +// - A thread may mutate the PyObject field of a TensorImpl if and only if it +// holds the GIL for the interpreter tagged on the TensorImpl. (If the +// TensorImpl is not tagged, it must first atomically claim its tag before it +// can validly write) +// +// WARNING: This class has to be written very carefully, because it may be +// possible for a Tensor to have a reference an interpreter corresponding to +// a shared library that has ALREADY BEEN UNLOADED. This makes blindly calling +// virtual methods very dangerous, because the vtable may be garbage at that +// point (on a good day, you might get "pure virtual method called"). +// +// The idea to solve this problem is we always leak PyInterpreters (so they +// always stay live even after dlclose), and make sure we can disarm their +// virtual methods by indirecting through a separate PyInterpreterVTable +// object. This can be replaced with a no-op vtable from libc10.so, which +// is guaranteed to stick around until the bitter end. +// +// NB: The downside with representing PyInterpreter tags as full objects is that +// it takes an extra word on TensorImpl. If tags were instead just integer +// indices, on 64-bit architectures we could pack the tag and PyObject together +// into a single atomic word. On 32-bit architectures we could simply say that +// only one Python interpreter is supported (erroring if a nontrivial +// interpreter tag is attempted to be set). +// +// The difficulty with this scheme is we need to maintain an out-of-line table +// to get at the PyInterpreters so that we can do virtual method calls on them, +// and registration/deregistration to this table must be done in a thread safe +// manner. This can be easily done if the number of possible PyInterpreters is +// small enough (e.g., 8-bit integer) by simply preallocating an array of +// sufficient size to hold all possible interpreters. Surely 128 threads is +// more than enough for anyone! +// +// I didn't decide to do this technique at the moment, because the extra word +// added by the PyInterpreter tag takes us to 24 words, which means that we +// still fit inside three eight word cache lines. If you need to penny pinch +// another word consider doing this! + +struct C10_API PyInterpreterVTable { + virtual ~PyInterpreterVTable() = default; + + // Report the name of this interpreter + virtual std::string name() const = 0; + + // Run Py_DECREF on a PyObject. 
We DO NOT assume the GIL is held on call + // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] + virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0; + + // Perform a detach by deferring to the __torch_dispatch__ implementation of + // detach, which will also arrange for the PyObject to get copied in this + // situation + virtual c10::intrusive_ptr detach( + const TensorImpl* self) const = 0; + + // Invoke the Python boxed fallback dispatch to go back into Python + virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) + const = 0; + + virtual void reportErrorCallback(PyObject* callback, DispatchKey key) + const = 0; + + // This is only invoked in the multipy/torchdeploy situation from + // pythonOpRegistrationTrampoline; this lets us get to the Python + // interpreter to actually find the appropriate Python op registration + // entry to call. + virtual void python_op_registration_trampoline( + const c10::OperatorHandle& op, + c10::DispatchKey, + torch::jit::Stack* stack) const = 0; + + virtual void throw_abstract_impl_not_imported_error( + std::string opname, + const char* pymodule, + const char* context) const = 0; + + // Invoke the Python dispatcher to handle this call + virtual void python_dispatcher( + const c10::OperatorHandle& op, + c10::DispatchKeySet, + torch::jit::Stack* stack) const = 0; + + virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat) + const = 0; + virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat) + const = 0; + virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0; + virtual c10::Device device(const TensorImpl* self) const = 0; + virtual int64_t dim(const TensorImpl* self) const = 0; + virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0; + virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0; + virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0; + virtual c10::Layout layout(const TensorImpl* self) const = 0; + virtual int64_t numel(const TensorImpl* self) const = 0; + virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0; + virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0; + virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0; + + virtual void trace_gpu_event_creation(uintptr_t event) const = 0; + virtual void trace_gpu_event_deletion(uintptr_t event) const = 0; + virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream) + const = 0; + virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) + const = 0; + virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0; + virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0; + virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0; + virtual void trace_gpu_device_synchronization() const = 0; + virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0; + virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0; + + virtual void reset_backward_hooks(const TensorImpl* self) const = 0; +}; + +struct C10_API PyInterpreter { + const PyInterpreterVTable* vtable_; + + PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable){}; + + const PyInterpreterVTable& operator*() const noexcept { + return *vtable_; + } + const PyInterpreterVTable* operator->() const noexcept { + return vtable_; + } + + // Disarm this PyInterpreter, making all of its methods noops. 
+ // The vtable pointer is not an atomic at the moment, which means + // a disarm() invocation that is concurrent with active destructors + // is not thread safe and will trigger TSAN. My hope is that this + // situations doesn't ever actually happen; tensor destruction should + // quiesce when a dlclose happens, and any long lived tensors whose + // destructors would be disarmed here only begin the destruction process + // on process shutdown (long after the dlclose has occurred). + void disarm() noexcept; +}; + +// PyInterpreterStatus describes what the state of its interpreter tag +// is, relative to the thread currently holding the GIL. +enum class PyInterpreterStatus { + // We just allocated the Tensor, it hasn't escaped to other threads, + // we know that it definitely hasn't been tagged to be associated + // with an interpreter. + DEFINITELY_UNINITIALIZED, + // We queried the interpreter field and it looked uninitialized. But + // another thread may have raced with us to tag it with some other + // interpreter id. So we will have to do a CEX to make sure we can + // actually nab it. + MAYBE_UNINITIALIZED, + // We queried the interpreter field and it was tagged to belong to us. + // This means we have sole write access (as we hold the GIL for this + // interpreter) + TAGGED_BY_US, + // Someone else tagged this. We can't use this TensorImpl from Python. + TAGGED_BY_OTHER, +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyObjectSlot.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyObjectSlot.h new file mode 100644 index 0000000000000000000000000000000000000000..b850099490bdb336feb51ad1a91c2f191fd874e7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PyObjectSlot.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace c10::impl { + +struct C10_API PyObjectSlot { + public: + PyObjectSlot(); + + ~PyObjectSlot(); + + void maybe_destroy_pyobj(); + + // Associate the TensorImpl with the specified PyObject, and, if necessary, + // also tag the interpreter. + // + // NB: This lives in a header so that we can inline away the switch on status + // + // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after + // PyObject if necessary! + void init_pyobj( + PyInterpreter* self_interpreter, + PyObject* pyobj, + PyInterpreterStatus status) { + impl::PyInterpreter* expected = nullptr; + switch (status) { + case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED: + // caller guarantees there is no multithreaded access; if there is + // no data race OK to do a relaxed store + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + break; + case impl::PyInterpreterStatus::TAGGED_BY_US: + // no tagging is necessary, the tag is already correct + break; + case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED: + // attempt to claim this TensorImpl with the specified interpreter + // tag + if (pyobj_interpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + break; + } + // test if, actually, it was already tagged by us! this situation can't + // be caused by a race, but it could be caused by a situation + // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED + // (because they didn't pre-check the tag) when actually it was + // owned by the interpreter + if (expected == self_interpreter) { + break; + } + // fallthrough, we lost the race. 
We are guaranteed not to lose the + // race with ourself, as calls to init_pyobj with the same interpreter + // ID must be sequentialized by the GIL + [[fallthrough]]; + case impl::PyInterpreterStatus::TAGGED_BY_OTHER: + TORCH_CHECK( + false, + "cannot allocate PyObject for Tensor on interpreter ", + self_interpreter, + " that has already been used by another torch deploy interpreter ", + pyobj_interpreter_.load()); + } + + // we are the ONLY thread that can have gotten to this point. It is not + // possible to conflict with another zero interpreter as access is protected + // by GIL + // NB: owns_pyobj tag is initially false + pyobj_ = pyobj; + } + + // Query the PyObject interpreter. This may return null if there is no + // interpreter. This is racy! + PyInterpreter* pyobj_interpreter(); + + PyObject* _unchecked_untagged_pyobj() const; + + // Test the interpreter tag. If tagged for the current interpreter, return + // a non-nullopt (but possibly null) PyObject. If (possibly) untagged, + // returns a nullopt. If it is definitely invalid, raises an error. + // + // If `ignore_hermetic_tls` is false and this function is called from a + // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then + // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic + // context is ignored, allowing you to check the interpreter tag of a + // nonhermetic PyObject from within a hermetic context. This is necessary + // because there are some cases where the deallocator function of a + // nonhermetic PyObject is called from within a hermetic context, so it must + // be properly treated as a nonhermetic PyObject. + // + // NB: this lives in header so that we can avoid actually creating the + // c10::optional + c10::optional check_pyobj( + PyInterpreter* self_interpreter, + bool ignore_hermetic_tls = false) const { + // Note [Memory ordering on Python interpreter tag] + impl::PyInterpreter* interpreter = + pyobj_interpreter_.load(std::memory_order_acquire); + if (interpreter == nullptr) { + // NB: This never returns DEFINITELY_UNINITIALIZED because there is + // always the possibility that another thread races to initialize + // after we query here. The only time when we can conclude a tensor + // is definitely uninitialized is when we have just allocated it and + // it cannot have escaped to other threads yet + return c10::nullopt; + } else if (interpreter == self_interpreter) { + // NB: pyobj_ could still be null! + if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { + return c10::nullopt; + } else { + return c10::make_optional(_unchecked_untagged_pyobj()); + } + } else { + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + (*self_interpreter)->name(), + " that has already been used by another torch deploy interpreter ", + (*pyobj_interpreter_.load())->name()); + } + } + + // Clear the PyObject field for an interpreter, in situations where we + // statically know the tensor is tagged with our interpreter. + void unchecked_clear_pyobj(PyInterpreter* interpreter); + + PyInterpreter& load_pyobj_interpreter() const; + + // Check if the PyObjectSlot's interpreter is the same as the specified + // interpreter + bool check_interpreter(PyInterpreter* interpreter); + + // Check if the PyObjectSlot is holding a PyObject, owned or non-owned + bool has_pyobj_nonhermetic(); + + bool owns_pyobj(); + + void set_owns_pyobj(bool b); + + private: + // This field contains the interpreter tag for this object. 
See + // Note [Python interpreter tag] for general context + // + // Note [Memory ordering on Python interpreter tag] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // What memory_order do we need when accessing this atomic? We don't + // need a single total modification order (as provided by + // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only + // transition from -1 to some positive integer and never changes afterwards. + // Because there is only one modification, it trivially already has a total + // modification order (e.g., we don't need fences or locked instructions on + // x86) + // + // In fact, one could make a reasonable argument that relaxed reads are OK, + // due to the presence of external locking (GIL) to ensure that interactions + // with other data structures are still correctly synchronized, so that + // we fall in the "Single-Location Data Structures" case as described in + // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // However, on x86, it doesn't matter if I use acquire or relaxed on the load + // as I get the same assembly in both cases. So I just use the more + // conservative acquire (which will impede compiler optimizations but I don't + // care) + std::atomic pyobj_interpreter_; + + // This field contains a reference to a PyObject representing this Tensor. + // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new + // PyObject for it and set this field. This field does not have to be + // protected by an atomic as it is only allowed to be accessed when you hold + // the GIL, or during destruction of the tensor. + // + // When a PyObject dies, you are obligated to clear this field + // (otherwise, you will try to use-after-free the pyobj); this currently + // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp + // + // NB: Ordinarily, this should not be a strong reference, as if the + // PyObject owns the Tensor, this would create a reference cycle. + // However, sometimes this ownership flips. To track who owns + // who, this has a single pointer tag indicating whether or not the + // C++ object owns the PyObject (the common case, zero, means PyObject + // owns the C++ object); see _unchecked_untagged_pyobj for raw access + // or check_pyobj for checked access. 
See references to PyObject + // resurrection in torch/csrc/autograd/python_variable.cpp + PyObject* pyobj_; +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..0c6ed14312275653ee8bbf4c7940a4918d1df2b9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/PythonDispatcherTLS.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace c10::impl { + +struct C10_API PythonDispatcherTLS { + static void set_state(PyInterpreter* state); + static PyInterpreter* get_state(); + static void reset_state(); +}; + +struct C10_API DisablePythonDispatcher { + DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) { + PythonDispatcherTLS::set_state({}); + } + ~DisablePythonDispatcher() { + PythonDispatcherTLS::set_state(old_); + } + PyInterpreter* old_; +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/SizesAndStrides.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/SizesAndStrides.h new file mode 100644 index 0000000000000000000000000000000000000000..59f5a9a4d7fccba8d3a046da0dcc82355c580b60 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/SizesAndStrides.h @@ -0,0 +1,315 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5 + +namespace c10::impl { + +// Packed container for TensorImpl sizes and strides. +// This design improves on the previous approach of using a pair of +// c10::SmallVector by specializing for the operations we +// actually use and enforcing that the number of sizes is the same as +// the number of strides. The memory layout is as follows: +// +// 1 size_t for the size +// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer +// to out-of-line array +class C10_API SizesAndStrides { + public: + // TODO: different iterator types for sizes & strides to prevent + // mixing the two accidentally. + using sizes_iterator = int64_t*; + using sizes_const_iterator = const int64_t*; + using strides_iterator = int64_t*; + using strides_const_iterator = const int64_t*; + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + SizesAndStrides() { + size_at_unchecked(0) = 0; + stride_at_unchecked(0) = 1; + } + + ~SizesAndStrides() { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) { + if (C10_LIKELY(rhs.isInline())) { + copyDataInline(rhs); + } else { + allocateOutOfLineStorage(size_); + copyDataOutline(rhs); + } + } + + SizesAndStrides& operator=(const SizesAndStrides& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + if (isInline()) { + allocateOutOfLineStorage(rhs.size_); + } else { + resizeOutOfLineStorage(rhs.size_); + } + copyDataOutline(rhs); + } + size_ = rhs.size_; + return *this; + } + + // Move from rhs. rhs.size() == 0 afterwards. 
+ SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) { + if (C10_LIKELY(isInline())) { + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } else { + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + + rhs.size_ = 0; + } + + // Move from rhs. rhs.size() == 0 afterwards. + SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept { + if (this == &rhs) { + return *this; + } + if (C10_LIKELY(rhs.isInline())) { + if (C10_UNLIKELY(!isInline())) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + copyDataInline(rhs); + } else { + // They're outline. We're going to steal their vector. + if (!isInline()) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + free(outOfLineStorage_); + } + outOfLineStorage_ = rhs.outOfLineStorage_; + rhs.outOfLineStorage_ = nullptr; + } + size_ = rhs.size_; + rhs.size_ = 0; + + return *this; + } + + size_t size() const noexcept { + return size_; + } + + const int64_t* sizes_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + int64_t* sizes_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[0]; + } else { + return &outOfLineStorage_[0]; + } + } + + sizes_const_iterator sizes_begin() const noexcept { + return sizes_data(); + } + + sizes_iterator sizes_begin() noexcept { + return sizes_data(); + } + + sizes_const_iterator sizes_end() const noexcept { + return sizes_begin() + size(); + } + + sizes_iterator sizes_end() noexcept { + return sizes_begin() + size(); + } + + IntArrayRef sizes_arrayref() const noexcept { + return IntArrayRef{sizes_data(), size()}; + } + + void set_sizes(IntArrayRef newSizes) { + resize(newSizes.size()); + std::copy(newSizes.begin(), newSizes.end(), sizes_begin()); + } + + void set_strides(IntArrayRef strides) { + TORCH_INTERNAL_ASSERT(strides.size() == size()); + std::copy(strides.begin(), strides.end(), strides_begin()); + } + + const int64_t* strides_data() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + int64_t* strides_data() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_begin() const noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_iterator strides_begin() noexcept { + if (C10_LIKELY(isInline())) { + return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE]; + } else { + return &outOfLineStorage_[size()]; + } + } + + strides_const_iterator strides_end() const noexcept { + return strides_begin() + size(); + } + + strides_iterator strides_end() noexcept { + return strides_begin() + size(); + } + + IntArrayRef strides_arrayref() const noexcept { + return IntArrayRef{strides_data(), size()}; + } + + // Size accessors. + int64_t size_at(size_t idx) const noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t& size_at(size_t idx) noexcept { + assert(idx < size()); + return sizes_data()[idx]; + } + + int64_t size_at_unchecked(size_t idx) const noexcept { + return sizes_data()[idx]; + } + + int64_t& size_at_unchecked(size_t idx) noexcept { + return sizes_data()[idx]; + } + + // Size accessors. 
+ int64_t stride_at(size_t idx) const noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t& stride_at(size_t idx) noexcept { + assert(idx < size()); + return strides_data()[idx]; + } + + int64_t stride_at_unchecked(size_t idx) const noexcept { + return strides_data()[idx]; + } + + int64_t& stride_at_unchecked(size_t idx) noexcept { + return strides_data()[idx]; + } + + void resize(size_t newSize) { + const auto oldSize = size(); + if (newSize == oldSize) { + return; + } + if (C10_LIKELY( + newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) { + if (oldSize < newSize) { + const auto bytesToZero = + (newSize - oldSize) * sizeof(inlineStorage_[0]); + memset(&inlineStorage_[oldSize], 0, bytesToZero); + memset( + &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize], + 0, + bytesToZero); + } + size_ = newSize; + } else { + resizeSlowPath(newSize, oldSize); + } + } + + void resizeSlowPath(size_t newSize, size_t oldSize); + + private: + bool isInline() const noexcept { + return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + } + + void copyDataInline(const SizesAndStrides& rhs) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline()); + memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_)); + } + + static size_t storageBytes(size_t size) noexcept { + return size * 2 * sizeof(int64_t); + } + + void allocateOutOfLineStorage(size_t size) { + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + outOfLineStorage_ = static_cast(malloc(storageBytes(size))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void resizeOutOfLineStorage(size_t newSize) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline()); + outOfLineStorage_ = static_cast( + // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) + realloc(outOfLineStorage_, storageBytes(newSize))); + TORCH_CHECK( + outOfLineStorage_, + "Could not allocate memory for Tensor SizesAndStrides!"); + } + + void copyDataOutline(const SizesAndStrides& rhs) noexcept { + memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_)); + } + + size_t size_{1}; + union { + int64_t* outOfLineStorage_; + // NOLINTNEXTLINE(*c-array*) + int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{}; + }; +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..12546ff72817e58f8b6c5bdf286977dfc578c7fb --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/TorchDispatchModeTLS.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +namespace c10::impl { + +enum class TorchDispatchModeKey : int8_t { + FAKE, + PROXY, + FUNCTIONAL, + NUM_MODE_KEYS +}; + +struct C10_API TorchDispatchModeTLS { + // This API is NOT invariant safe. + // It must not take in an infra mode that uses TorchDispatchModeKey + // If you're pushing an infra mode onto the stack, we expect + // you to use set_mode + static void push_non_infra_mode_onto_stack( + std::shared_ptr mode); + // Pops the top mode of the stack, + // giving precedence to user modes before attempting to pop + // any infra modes + static const std::shared_ptr pop_stack(); + // Returns the highest-priority infra mode on the stack, + // along with its mode key. 
+ static const std::tuple, TorchDispatchModeKey> + pop_highest_infra_mode(); + + static const std::shared_ptr& get_stack_at(int64_t idx); + static int64_t stack_len(); + + static const c10::optional> get_mode( + TorchDispatchModeKey mode_key); + static const c10::optional> unset_mode( + TorchDispatchModeKey mode_key); + static void set_mode( + const std::shared_ptr& mode, + TorchDispatchModeKey mode_key); + + static const TorchDispatchModeTLS& get_state(); + static void set_state(TorchDispatchModeTLS state); + + static bool any_modes_set(bool skip_infra_modes = false); + + private: + std::vector> stack_; + // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the + // stack + // However, we only allow a single FakeTensorMode onto the stack at a time + // (Pushing additional FakeTensorModes onto the stack is a no-op) + std::array< + c10::optional>, + static_cast(TorchDispatchModeKey::NUM_MODE_KEYS)> + infra_modes_; +}; + +C10_API bool dispatch_mode_enabled(); + +C10_API std::string to_string(TorchDispatchModeKey mode_key); + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..62e430e423855787f6deae36735687c358f649fa --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/VirtualGuardImpl.h @@ -0,0 +1,91 @@ +#pragma once + +#include + +namespace c10::impl { + +/** + * An implementation of DeviceGuardImplInterface which delegates + * to virtual dispatch on the DeviceGuardImpl registry. + */ +class VirtualGuardImpl final : public DeviceGuardImplInterface { + public: + VirtualGuardImpl(DeviceType device_type) + : impl_(getDeviceGuardImpl(device_type)) {} + // This constructor exists purely for testing + VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {} + + // Copying and moving is OK! 
+ VirtualGuardImpl(const VirtualGuardImpl&) = default; + VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default; + VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default; + VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default; + + DeviceType type() const override { + return impl_->type(); + } + Device exchangeDevice(Device d) const override { + return impl_->exchangeDevice(d); + } + Device getDevice() const override { + return impl_->getDevice(); + } + void setDevice(Device d) const override { + impl_->setDevice(d); + } + void uncheckedSetDevice(Device d) const noexcept override { + impl_->uncheckedSetDevice(d); + } + Stream getStream(Device d) const noexcept override { + return impl_->getStream(d); + } + Stream getDefaultStream(Device d) const override { + return impl_->getDefaultStream(d); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return impl_->getStreamFromGlobalPool(d, isHighPriority); + } + Stream exchangeStream(Stream s) const noexcept override { + return impl_->exchangeStream(s); + } + DeviceIndex deviceCount() const noexcept override { + return impl_->deviceCount(); + } + + // Event functions + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + impl_->record(event, stream, device_index, flag); + } + void block(void* event, const Stream& stream) const override { + impl_->block(event, stream); + } + bool queryEvent(void* event) const override { + return impl_->queryEvent(event); + } + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + impl_->destroyEvent(event, device_index); + } + + bool queryStream(const Stream& stream) const override { + return impl_->queryStream(stream); + } + void synchronizeStream(const Stream& stream) const override { + impl_->synchronizeStream(stream); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + impl_->recordDataPtrOnStream(data_ptr, stream); + } + + private: + const DeviceGuardImplInterface* impl_ = nullptr; +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/impl/alloc_cpu.h b/MLPY/Lib/site-packages/torch/include/c10/core/impl/alloc_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..2d2b27d572fda80daa6c8705714fdf9689b9ce9a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/impl/alloc_cpu.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +#include + +namespace c10 { + +C10_API void* alloc_cpu(size_t nbytes); +C10_API void free_cpu(void* data); + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/core/thread_pool.h b/MLPY/Lib/site-packages/torch/include/c10/core/thread_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..0cc15b325dfb9593205253efba320ec428294ea0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/core/thread_pool.h @@ -0,0 +1,120 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { + +class C10_API TaskThreadPoolBase { + public: + virtual void run(std::function func) = 0; + + virtual size_t size() const = 0; + + /** + * The number of available (i.e. idle) threads in this thread pool. + */ + virtual size_t numAvailable() const = 0; + + /** + * Check if the current thread is from the thread pool. 
+ */ + virtual bool inThreadPool() const = 0; + + virtual ~TaskThreadPoolBase() noexcept = default; + + static size_t defaultNumThreads(); +}; + +class C10_API ThreadPool : public c10::TaskThreadPoolBase { + protected: + struct task_element_t { + bool run_with_id; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::function no_id; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::function with_id; + + explicit task_element_t(std::function f) + : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} + explicit task_element_t(std::function f) + : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} + }; + + std::queue tasks_; + std::vector threads_; + mutable std::mutex mutex_; + std::condition_variable condition_; + std::condition_variable completed_; + std::atomic_bool running_; + bool complete_; + std::size_t available_; + std::size_t total_; + int numa_node_id_; + + public: + ThreadPool() = delete; + + explicit ThreadPool( + int pool_size, + int numa_node_id = -1, + const std::function& init_thread = nullptr); + + ~ThreadPool() override; + + size_t size() const override; + + size_t numAvailable() const override; + + bool inThreadPool() const override; + + void run(std::function func) override; + + template + void runTaskWithID(Task task) { + std::unique_lock lock(mutex_); + + // Set task and signal condition variable so that a worker thread will + // wake up and use the task. + tasks_.emplace(static_cast>(task)); + complete_ = false; + condition_.notify_one(); + } + + /// @brief Wait for queue to be empty + void waitWorkComplete(); + + private: + // @brief Entry point for pool threads. + void main_loop(std::size_t index); +}; + +class C10_API TaskThreadPool : public c10::ThreadPool { + public: + explicit TaskThreadPool(int pool_size, int numa_node_id = -1) + : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { + setThreadName("CaffeTaskThread"); + NUMABind(numa_node_id); + }) {} +}; + +C10_DECLARE_SHARED_REGISTRY( + ThreadPoolRegistry, + TaskThreadPoolBase, + int, + int, + bool); + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h new file mode 100644 index 0000000000000000000000000000000000000000..f771421d4c8c4e4270ee6ff27ef736bbbcec318e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAlgorithm.h @@ -0,0 +1,31 @@ +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +#include +#include +#include +#include +#endif +namespace c10::cuda { +#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS +template +__forceinline__ __device__ Iter +lower_bound(Iter start, Iter end, Scalar value) { + return thrust::lower_bound(thrust::device, start, end, value); +} +#else +// thrust::lower_bound is broken on device, see +// https://github.com/NVIDIA/thrust/issues/1734 Implementation inspired by +// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28 +template +__device__ Iter lower_bound(Iter start, Iter end, Scalar value) { + while (start < end) { + auto mid = start + ((end - start) >> 1); + if (*mid < value) { + start = mid + 1; + } else { + end = mid; + } + } + return end; +} +#endif // THRUST_DEVICE_LOWER_BOUND_WORKS +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h new file mode 100644 index 
0000000000000000000000000000000000000000..ad45cd71d4e7858dee81ca82e247a5e33395e305 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAAllocatorConfig.h @@ -0,0 +1,124 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10::cuda::CUDACachingAllocator { + +// Environment config parser +class C10_CUDA_API CUDAAllocatorConfig { + public: + static size_t max_split_size() { + return instance().m_max_split_size; + } + static double garbage_collection_threshold() { + return instance().m_garbage_collection_threshold; + } + + static bool expandable_segments() { +#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED + if (instance().m_expandable_segments) { + TORCH_WARN_ONCE("expandable_segments not supported on this platform") + } + return false; +#else + return instance().m_expandable_segments; +#endif + } + + static bool release_lock_on_cudamalloc() { + return instance().m_release_lock_on_cudamalloc; + } + + /** Pinned memory allocator settings */ + static bool pinned_use_cuda_host_register() { + return instance().m_pinned_use_cuda_host_register; + } + + static size_t pinned_num_register_threads() { + return instance().m_pinned_num_register_threads; + } + + static size_t pinned_max_register_threads() { + // Based on the benchmark results, we see better allocation performance + // with 8 threads. However on future systems, we may need more threads + // and limiting this to 128 threads. + return 128; + } + + // This is used to round-up allocation size to nearest power of 2 divisions. + // More description below in function roundup_power2_next_division + // As ane example, if we want 4 divisions between 2's power, this can be done + // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 + static size_t roundup_power2_divisions(size_t size); + + static std::vector roundup_power2_divisions() { + return instance().m_roundup_power2_divisions; + } + + static std::string last_allocator_settings() { + std::lock_guard lock( + instance().m_last_allocator_settings_mutex); + return instance().m_last_allocator_settings; + } + + static CUDAAllocatorConfig& instance() { + static CUDAAllocatorConfig* s_instance = ([]() { + auto inst = new CUDAAllocatorConfig(); + const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF"); + inst->parseArgs(env); + return inst; + })(); + return *s_instance; + } + + void parseArgs(const char* env); + + private: + CUDAAllocatorConfig(); + + static void lexArgs(const char* env, std::vector& config); + static void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, + size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync); + size_t parsePinnedUseCudaHostRegister( + const std::vector& config, + size_t i); + size_t parsePinnedNumRegisterThreads( + const std::vector& config, + size_t i); + + std::atomic m_max_split_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; + std::atomic m_expandable_segments; + std::atomic m_release_lock_on_cudamalloc; + std::atomic m_pinned_use_cuda_host_register; + std::string m_last_allocator_settings; + std::mutex m_last_allocator_settings_mutex; +}; + +// General caching allocator utilities +C10_CUDA_API 
void setAllocatorSettings(const std::string& env); + +} // namespace c10::cuda::CUDACachingAllocator diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..40f6bebe187585135a2d1e5c5863bec8b6585b57 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDACachingAllocator.h @@ -0,0 +1,481 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +// Caching allocator will execute every registered callback if it unable to find +// block inside of already allocated area. +class C10_CUDA_API FreeMemoryCallback { + public: + virtual ~FreeMemoryCallback() = default; + virtual bool Execute() = 0; +}; + +C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); +#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__); +} // namespace c10 + // +// TODO: Turn this into an honest to goodness class. I briefly attempted to do +// this, but it was a bit irritating to figure out how to also correctly +// apply pimpl pattern so I didn't have to leak any internal implementation +// details in the header (CUDACachingAllocator could be made a pimpl, but +// you also need to appropriately define a class which is a subclass +// of Allocator. Not impossible, but required a bit more surgery than +// I wanted to do at the time.) +// +// Why is this using a namespace rather than old-style THCCachingAllocator_ +// prefix? Mostly because it made the HIPify rules easier to write; _ is +// not counted as a word boundary, so you would otherwise have to list each +// of these functions. + +namespace c10::cuda::CUDACachingAllocator { + +extern const size_t kLargeBuffer; + +struct Stat { + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +typedef std::array(StatType::NUM_TYPES)> StatArray; + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from cudaMalloc(). + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via cudaFree) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to CUDA malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. 
failed calls to CUDA after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of CUDA allocation calls. This includes both cuMemMap + // and cudaMalloc. + int64_t num_device_alloc = 0; + + // COUNT: total number of CUDA free calls. This includes both cuMemUnmap + // and cudaFree. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +typedef std::shared_ptr (*CreateContextFn)(); + +// Struct containing info of an allocation block (i.e. a fractional part of a +// cudaMalloc).. +struct BlockInfo { + int64_t size = 0; + int64_t requested_size = 0; + int32_t gc_counter = 0; + bool allocated = false; + bool active = false; + std::shared_ptr + context_when_allocated; // per-watcher context +}; + +// Struct containing info of a memory segment (i.e. one contiguous cudaMalloc). +struct SegmentInfo { + c10::DeviceIndex device = 0; + int64_t address = 0; + int64_t total_size = 0; + int64_t requested_size = 0; // unrounded, actually requested size + int64_t allocated_size = 0; + int64_t active_size = 0; + cudaStream_t stream = nullptr; + bool is_large = false; + bool is_expandable = false; + MempoolId_t owner_private_pool_id = {0, 0}; + std::vector blocks; + std::shared_ptr context_when_allocated; +}; + +struct AllocatorState { + virtual ~AllocatorState() = default; +}; + +union trace_time_ { + time_t t_; + approx_time_t approx_t_; +}; + +struct TraceEntry { + enum Action { + ALLOC, // API made to the caching allocator for new memory + FREE_REQUESTED, // API call made to the caching allocator to free memory + FREE_COMPLETED, // The allocator might have to delay a free because + // it is still in use on another stream via record_stream + // This event is generated when a free actually completes. + SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS + SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. 
to + // defragment or empty_caches) + SEGMENT_MAP, // a call to cuMemMap (used with expandable_segments) + SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments) + SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace + // events + OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free + // bytes reported by cuda) + }; + TraceEntry( + Action action, + c10::DeviceIndex device, + int64_t addr, + size_t size, + cudaStream_t stream, + approx_time_t time, + std::shared_ptr context = nullptr) + : action_(action), + device_(device), + addr_(addr), + context_(std::move(context)), + stream_(stream), + size_(static_cast(size)) { + time_.approx_t_ = time; + } + Action action_; + c10::DeviceIndex device_; + int64_t addr_; // for OOM, this is the amount of free bytes reported by cuda + std::shared_ptr context_; + cudaStream_t stream_{}; + int64_t size_; + trace_time_ time_{}; +}; + +struct AllocatorConfigInfo { + double garbage_collection_threshold; + size_t max_split_size; + size_t pinned_num_register_threads; + bool expandable_segments; + bool release_lock_on_malloc; + bool pinned_use_host_register; + std::string last_allocator_settings; + std::vector roundup_power2_divisions; +}; + +struct SnapshotInfo { + std::vector segments; + std::vector> device_traces; + AllocatorConfigInfo config_metadata; +}; + +// returns the pointers freed in the pool +// and the pointers allocated. Note: a pointer +// may appear in both freed and allocated +struct CheckpointDelta { + std::vector ptrs_freed; + std::vector dataptrs_allocd; +}; + +enum struct RecordContext { + NEVER = 0, + STATE = 1, // only keep stacks for active allocations + ALLOC = 2, // additionally keep stacks for allocations in the trace history + ALL = 3, // additionally record stacks for when something is freed +}; + +// Size pretty-printer +std::string format_size(uint64_t size); + +using OutOfMemoryObserver = std::function; + +using AllocatorTraceTracker = std::function; + +class CUDAAllocator : public Allocator { + public: + virtual void* raw_alloc(size_t nbytes) = 0; + virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; + virtual void raw_delete(void* ptr) = 0; + virtual void init(int device_count) = 0; + virtual bool initialized() = 0; + virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual void emptyCache() = 0; + virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; + virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; + virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; + virtual DeviceStats getDeviceStats(c10::DeviceIndex device) = 0; + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + virtual void resetPeakStats(c10::DeviceIndex device) = 0; + virtual SnapshotInfo snapshot() = 0; + virtual void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) = 0; + virtual void endAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) = 0; + virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0; + // returns true if the allocated blocks are equal to expected live allocations + virtual bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) { + TORCH_CHECK( + false, + name(), + " does not yet support checkPoolLiveAllocations. 
" + "If you need it, please file an issue describing your use case."); + } + virtual std::shared_ptr getIpcDevPtr(std::string handle) = 0; + virtual bool isHistoryEnabled() { + TORCH_CHECK( + false, + name(), + " does not yet support recordHistory. " + "If you need it, please file an issue describing your use case."); + } + virtual void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) = 0; + virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; + + // Attached AllocatorTraceTracker callbacks will be called while the + // per-device allocator lock is held. Any additional locks taken from within + // the callback must be proven to always have the lock order that never + // triggers a deadlock. In particular, Python's GIL may be held when + // calling the allocator so it is unsafe to try to acquire the GIL in this + // callback. + virtual void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) = 0; + + virtual void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) = 0; + + // memory not allocated from cudaMalloc cannot be copied + // across devices using cudaMemcpyAsync if peer to peer access is disabled. + // instead it requires cudaMemcpyAsyncPeer + // with P2P Enabled, all combinations work + // with P2P Disabled: + // cudaMalloc cudaMallocAsync/cuMemMap + // cudaMemcpyAsyncPeer works works + // cudaMemcpyAsync works error + + // This function performs chooses to use the Peer version of + // memcpy if required based on where the allocated put dst/src. + virtual cudaError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + cudaStream_t stream, + bool p2p_enabled) = 0; + virtual std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) = 0; + virtual CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) = 0; + virtual std::string name() = 0; +}; + +// Allocator object, statically initialized +// See BackendInitializer in CUDACachingAllocator.cpp. +// Atomic loads on x86 are just normal loads, +// (atomic stores are different), so reading this value +// is no different than loading a pointer. +C10_CUDA_API extern std::atomic allocator; + +inline CUDAAllocator* get() { + return allocator.load(); +} + +// Called directly by clients. 
+inline void* raw_alloc(size_t nbytes) { + return get()->raw_alloc(nbytes); +} + +inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) { + return get()->raw_alloc_with_stream(nbytes, stream); +} + +inline void raw_delete(void* ptr) { + return get()->raw_delete(ptr); +} + +inline void init(int device_count) { + return get()->init(device_count); +} + +inline void setMemoryFraction(double fraction, c10::DeviceIndex device) { + return get()->setMemoryFraction(fraction, device); +} + +inline void emptyCache() { + return get()->emptyCache(); +} + +inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) { + return get()->cacheInfo(device, largestBlock); +} + +inline void* getBaseAllocation(void* ptr, size_t* size) { + return get()->getBaseAllocation(ptr, size); +} + +inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) { + return get()->recordStream(dataPtr, stream); +} + +inline DeviceStats getDeviceStats(c10::DeviceIndex device) { + return get()->getDeviceStats(device); +} + +inline void resetAccumulatedStats(c10::DeviceIndex device) { + return get()->resetAccumulatedStats(device); +} + +inline void resetPeakStats(c10::DeviceIndex device) { + return get()->resetPeakStats(device); +} + +inline SnapshotInfo snapshot() { + return get()->snapshot(); +} + +inline std::shared_ptr getCheckpointState( + c10::DeviceIndex device, + MempoolId_t id) { + return get()->getCheckpointState(device, id); +} + +inline CheckpointDelta setCheckpointPoolState( + c10::DeviceIndex device, + std::shared_ptr pps) { + return get()->setCheckpointPoolState(device, std::move(pps)); +} + +// CUDAGraph interactions +inline void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + get()->beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + get()->endAllocateToPool(device, mempool_id); +} + +inline void recordHistory( + bool enabled, + CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + RecordContext when) { + return get()->recordHistory( + enabled, context_recorder, alloc_trace_max_entries, when); +} + +inline bool isHistoryEnabled() { + return get()->isHistoryEnabled(); +} + +inline bool checkPoolLiveAllocations( + c10::DeviceIndex device, + MempoolId_t mempool_id, + const std::unordered_set& expected_live_allocations) { + return get()->checkPoolLiveAllocations( + device, mempool_id, expected_live_allocations); +} + +inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { + return get()->attachOutOfMemoryObserver(std::move(observer)); +} + +inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { + return get()->attachAllocatorTraceTracker(std::move(tracker)); +} + +inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return get()->releasePool(device, mempool_id); +} +// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE +inline std::shared_ptr getIpcDevPtr(std::string handle) { + return get()->getIpcDevPtr(std::move(handle)); +} + +inline std::string name() { + return get()->name(); +} + +inline cudaError_t memcpyAsync( + void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + cudaStream_t stream, + bool p2p_enabled) { + return get()->memcpyAsync( + dst, dstDevice, src, srcDevice, count, stream, p2p_enabled); +} + +inline void enablePeerAccess( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access) { + return get()->enablePeerAccess(dev, 
dev_to_access); +} + +} // namespace c10::cuda::CUDACachingAllocator diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h new file mode 100644 index 0000000000000000000000000000000000000000..258abc302ae187cec646b5431ba1c96b331f0e6b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertion.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include + +namespace c10::cuda { + +#ifdef TORCH_USE_CUDA_DSA +// Copy string from `src` to `dst` +static __device__ void dstrcpy(char* dst, const char* src) { + int i = 0; + // Copy string from source to destination, ensuring that it + // isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1` + while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) { + *dst++ = *src++; + } + *dst = '\0'; +} + +static __device__ void dsa_add_new_assertion_failure( + DeviceAssertionsData* assertions_data, + const char* assertion_msg, + const char* filename, + const char* function_name, + const int line_number, + const uint32_t caller, + const dim3 block_id, + const dim3 thread_id) { + // `assertions_data` may be nullptr if device-side assertion checking + // is disabled at run-time. If it is disabled at compile time this + // function will never be called + if (!assertions_data) { + return; + } + + // Atomically increment so other threads can fail at the same time + // Note that incrementing this means that the CPU can observe that + // a failure has happened and can begin to respond before we've + // written information about that failure out to the buffer. + const auto nid = atomicAdd(&(assertions_data->assertion_count), 1); + + if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) { + // At this point we're ran out of assertion buffer space. + // We could print a message about this, but that'd get + // spammy if a lot of threads did it, so we just silently + // ignore any other assertion failures. In most cases the + // failures will all probably be analogous anyway. + return; + } + + // Write information about the assertion failure to memory. + // Note that this occurs only after the `assertion_count` + // increment broadcasts that there's been a problem. + auto& self = assertions_data->assertions[nid]; + dstrcpy(self.assertion_msg, assertion_msg); + dstrcpy(self.filename, filename); + dstrcpy(self.function_name, function_name); + self.line_number = line_number; + self.caller = caller; + self.block_id[0] = block_id.x; + self.block_id[1] = block_id.y; + self.block_id[2] = block_id.z; + self.thread_id[0] = thread_id.x; + self.thread_id[1] = thread_id.y; + self.thread_id[2] = thread_id.z; +} + +// Emulates a kernel assertion. The assertion won't stop the kernel's progress, +// so you should assume everything the kernel produces is garbage if there's an +// assertion failure. +// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are +// arguments of the kernel and therefore accessible. 
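// An illustrative kernel body (hypothetical name and arguments) that satisfies
// that assumption by taking TORCH_DSA_KERNEL_ARGS (declared in
// CUDADeviceAssertionHost.h), which expands to exactly those two parameters;
// such a kernel is then launched through TORCH_DSA_KERNEL_LAUNCH:
//
//   __global__ void clamp_kernel(int* data, int n, TORCH_DSA_KERNEL_ARGS) {
//     const int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       CUDA_KERNEL_ASSERT2(data[i] >= 0);  // records the failure, then returns
//       data[i] = min(data[i], 255);
//     }
//   }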
+#define CUDA_KERNEL_ASSERT2(condition) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + /* Has an atomic element so threads can fail at the same time */ \ + c10::cuda::dsa_add_new_assertion_failure( \ + assertions_data, \ + C10_STRINGIZE(condition), \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + assertion_caller_id, \ + blockIdx, \ + threadIdx); \ + /* Now that the kernel has failed we early exit the kernel, but */ \ + /* otherwise keep going and rely on the host to check UVM and */ \ + /* determine we've had a problem */ \ + return; \ + } \ + } while (false) +#else +#define CUDA_KERNEL_ASSERT2(condition) assert(condition) +#endif + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h new file mode 100644 index 0000000000000000000000000000000000000000..6d3e99198e7de7dd61c1880489c956231fb53337 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDADeviceAssertionHost.h @@ -0,0 +1,164 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#ifdef USE_CUDA +#define TORCH_USE_CUDA_DSA +#endif + +/// Number of assertion failure messages we can store. If this is too small +/// threads will fail silently. +constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10; +constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512; + +namespace c10::cuda { + +/// Holds information about any device-side assertions that fail. +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionData { + /// Stringification of the assertion + // NOLINTNEXTLINE(*-c-arrays) + char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// File the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char filename[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// Name of the function the assertion was in + // NOLINTNEXTLINE(*-c-arrays) + char function_name[C10_CUDA_DSA_MAX_STR_LEN]{}; + /// Line number the assertion was at + int line_number{}; + /// Number uniquely identifying the kernel launch that triggered the assertion + uint32_t caller{}; + /// block_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t block_id[3]{}; + /// third_id of the thread that failed the assertion + // NOLINTNEXTLINE(*-c-arrays) + int32_t thread_id[3]{}; +}; + +/// Used to hold assertions generated by the device +/// Held in managed memory and access by both the CPU and the GPU. 
+struct DeviceAssertionsData { + /// Total number of assertions found; a subset of thse will be recorded + /// in `assertions` + int32_t assertion_count{}; + /// An array of assertions that will be written to in a race-free manner + // NOLINTNEXTLINE(*-c-arrays) + DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]{}; +}; + +/// Use to hold info about kernel launches so that we can run kernels +/// asynchronously and still associate launches with device-side +/// assertion failures +struct CUDAKernelLaunchInfo { + /// Filename of the code where the kernel was launched from + const char* launch_filename; + /// Function from which the kernel was launched + const char* launch_function; + /// Line number of where the code was launched from + uint32_t launch_linenum; + /// Backtrace of where the kernel was launched from, only populated if + /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True + std::string launch_stacktrace; + /// Kernel that was launched + const char* kernel_name; + /// Device the kernel was launched on + int device; + /// Stream the kernel was launched on + int32_t stream; + /// A number that uniquely identifies the kernel launch + uint64_t generation_number; +}; + +/// Circular buffer used to hold information about kernel launches +/// this is later used to reconstruct how a device-side kernel assertion failure +/// occurred CUDAKernelLaunchRegistry is used as a singleton +class C10_CUDA_API CUDAKernelLaunchRegistry { + private: + /// Assume that this is the max number of kernel launches that might ever be + /// enqueued across all streams on a single device + static constexpr int max_kernel_launches = 1024; + /// How many kernel launch infos we've inserted. Used to ensure that circular + /// queue doesn't provide false information by always increasing, but also to + /// mark where we are inserting into the queue +#ifdef TORCH_USE_CUDA_DSA + uint64_t generation_number = 0; +#endif + /// Shared mutex between writer and accessor to ensure multi-threaded safety. + mutable std::mutex read_write_mutex; + /// Used to ensure prevent race conditions in GPU memory allocation + mutable std::mutex gpu_alloc_mutex; + /// Pointer to managed memory keeping track of device-side assertions. There + /// is one entry for each possible device the process might work with. Unused + /// entries are nullptrs. We could also use an unordered_set here, but this + /// vector design will be faster and the wasted memory is small since we + /// expect the number of GPUs per node will always be small + std::vector< + std::unique_ptr> + uvm_assertions; + /// A single circular buffer holds information about every kernel launch the + /// process makes across all devices. + std::vector kernel_launches; + bool check_env_for_enable_launch_stacktracing() const; + bool check_env_for_dsa_enabled() const; + + public: + CUDAKernelLaunchRegistry(); + /// Register a new kernel launch and obtain a generation number back to be + /// passed to the kernel + uint32_t insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id); + /// Get copies of the kernel launch registry and each device's assertion + /// failure buffer so they can be inspected without raising race conditions + std:: + pair, std::vector> + snapshot() const; + /// Get a pointer to the current device's assertion failure buffer. If no such + /// buffer exists then one is created. 
This means that the first kernel launch + /// made on each device will be slightly slower because memory allocations are + /// required + DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device(); + /// Gets the global singleton of the registry + static CUDAKernelLaunchRegistry& get_singleton_ref(); + /// If not all devices support DSA, we disable it + const bool do_all_devices_support_managed_memory = false; + /// Whether or not to gather stack traces when launching kernels + bool gather_launch_stacktrace = false; + /// Whether or not host-side DSA is enabled or disabled at run-time + /// Note: Device-side code cannot be enabled/disabled at run-time + bool enabled_at_runtime = false; + /// Whether or not a device has indicated a failure + bool has_failed() const; +#ifdef TORCH_USE_CUDA_DSA + const bool enabled_at_compile_time = true; +#else + const bool enabled_at_compile_time = false; +#endif +}; + +std::string c10_retrieve_device_side_assertion_info(); + +} // namespace c10::cuda + +// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH +// requires the same input arguments. We introduce the following macro to +// standardize these. +#define TORCH_DSA_KERNEL_ARGS \ + [[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \ + [[maybe_unused]] uint32_t assertion_caller_id + +// This macro can be used to pass the DSA arguments onward to another +// function +#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAException.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAException.h new file mode 100644 index 0000000000000000000000000000000000000000..4f49c852aa12f65b0df73fc4ef976595ca440332 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAException.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// Note [CHECK macro] +// ~~~~~~~~~~~~~~~~~~ +// This is a macro so that AT_ERROR can get accurate __LINE__ +// and __FILE__ information. We could split this into a short +// macro and a function implementation if we pass along __LINE__ +// and __FILE__, but no one has found this worth doing. + +// Used to denote errors from CUDA framework. +// This needs to be declared here instead util/Exception.h for proper conversion +// during hipify. 
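// Typical call sites for the macros defined below (illustrative only; the
// kernel, its launch configuration, and the variable names are hypothetical):
//
//   C10_CUDA_CHECK(cudaMemsetAsync(ptr, 0, nbytes, stream));  // throws on failure
//
//   my_kernel<<<grid, block, 0, stream>>>(args);
//   C10_CUDA_KERNEL_LAUNCH_CHECK();            // surface launch errors near the source
//
//   C10_CUDA_CHECK_WARN(cudaEventDestroy(ev)); // warn rather than throw (e.g. in a dtor)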
+namespace c10 { +class C10_CUDA_API CUDAError : public c10::Error { + using Error::Error; +}; +} // namespace c10 + +#define C10_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + c10::cuda::c10_cuda_check_implementation( \ + static_cast(__err), \ + __FILE__, \ + __func__, /* Line number data type not well-defined between \ + compilers, so we perform an explicit cast */ \ + static_cast(__LINE__), \ + true); \ + } while (0) + +#define C10_CUDA_CHECK_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != cudaSuccess)) { \ + auto error_unused C10_UNUSED = cudaGetLastError(); \ + (void)error_unused; \ + TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ + } \ + } while (0) + +// Indicates that a CUDA error is handled in a non-standard way +#define C10_CUDA_ERROR_HANDLED(EXPR) EXPR + +// Intentionally ignore a CUDA error +#define C10_CUDA_IGNORE_ERROR(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != cudaSuccess)) { \ + cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ + (void)error_unused; \ + } \ + } while (0) + +// Clear the last CUDA error +#define C10_CUDA_CLEAR_ERROR() \ + do { \ + cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ + (void)error_unused; \ + } while (0) + +// This should be used directly after every kernel launch to ensure +// the launch happened correctly and provide an early, close-to-source +// diagnostic if it didn't. +#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) + +/// Launches a CUDA kernel appending to it all the information need to handle +/// device-side assertion failures. Checks that the launch was successful. +#define TORCH_DSA_KERNEL_LAUNCH( \ + kernel, blocks, threads, shared_mem, stream, ...) \ + do { \ + auto& launch_registry = \ + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \ + kernel<<>>( \ + __VA_ARGS__, \ + launch_registry.get_uvm_assertions_ptr_for_current_device(), \ + launch_registry.insert( \ + __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } while (0) + +namespace c10::cuda { + +/// In the event of a CUDA failure, formats a nice error message about that +/// failure and also checks for device-side assertion failures +C10_CUDA_API void c10_cuda_check_implementation( + const int32_t err, + const char* filename, + const char* function_name, + const int line_number, + const bool include_device_assertions); + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAFunctions.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..98540d9ab7a79a4d86545eb506505c310e584e5e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAFunctions.h @@ -0,0 +1,116 @@ +#pragma once + +// This header provides C++ wrappers around commonly used CUDA API functions. +// The benefit of using C++ here is that we can raise an exception in the +// event of an error, rather than explicitly pass around error codes. This +// leads to more natural APIs. +// +// The naming convention used here matches the naming convention of torch.cuda + +#include +#include +#include +#include +#include +namespace c10::cuda { + +// NB: In the past, we were inconsistent about whether or not this reported +// an error if there were driver problems are not. 
Based on experience +// interacting with users, it seems that people basically ~never want this +// function to fail; it should just return zero if things are not working. +// Oblige them. +// It still might log a warning for user first time it's invoked +C10_CUDA_API DeviceIndex device_count() noexcept; + +// Version of device_count that throws is no devices are detected +C10_CUDA_API DeviceIndex device_count_ensure_non_zero(); + +C10_CUDA_API DeviceIndex current_device(); + +C10_CUDA_API void set_device(DeviceIndex device); + +C10_CUDA_API void device_synchronize(); + +C10_CUDA_API void warn_or_error_on_sync(); + +// Raw CUDA device management functions +C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count); + +C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device); + +C10_CUDA_API cudaError_t SetDevice(DeviceIndex device); + +C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device); + +C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device); + +C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device); + +C10_CUDA_API void SetTargetDevice(); + +enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR }; + +// this is a holder for c10 global state (similar to at GlobalContext) +// currently it's used to store cuda synchronization warning state, +// but can be expanded to hold other related global state, e.g. to +// record stream usage +class WarningState { + public: + void set_sync_debug_mode(SyncDebugMode l) { + sync_debug_mode = l; + } + + SyncDebugMode get_sync_debug_mode() { + return sync_debug_mode; + } + + private: + SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED; +}; + +C10_CUDA_API __inline__ WarningState& warning_state() { + static WarningState warning_state_; + return warning_state_; +} +// the subsequent functions are defined in the header because for performance +// reasons we want them to be inline +C10_CUDA_API void __inline__ memcpy_and_sync( + void* dst, + const void* src, + int64_t nbytes, + cudaMemcpyKind kind, + cudaStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + reinterpret_cast(stream)); + } +#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); +#else + C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +#endif +} + +C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { + if (C10_UNLIKELY( + warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { + warn_or_error_on_sync(); + } + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_stream_synchronization( + reinterpret_cast(stream)); + } + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index); +C10_CUDA_API c10::optional getDeviceIndexWithPrimaryContext(); + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6792068fb0377c4f4c2edc0ede11e3a3c0efdf2d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGraphsC10Utils.h @@ -0,0 +1,91 
@@ +#pragma once + +#include +#include +#include + +// CUDA Graphs utils used by c10 and aten. +// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. + +namespace c10::cuda { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::cuda::graph_pool_handle. +using MempoolId_t = std::pair; + +// RAII guard for "cudaStreamCaptureMode", a thread-local value +// that controls the error-checking strictness of a capture. +#if !defined(USE_ROCM) || ROCM_VERSION >= 50300 +struct C10_CUDA_API CUDAStreamCaptureModeGuard { + CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) + : strictness_(desired) { + C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_)); + } + ~CUDAStreamCaptureModeGuard() { + C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_)); + } + + private: + cudaStreamCaptureMode strictness_; +}; +#endif + +#if !defined(USE_ROCM) || ROCM_VERSION >= 50300 +// Protects against enum cudaStreamCaptureStatus implementation changes. +// Some compilers seem not to like static_assert without the messages. +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, + "unexpected int(cudaStreamCaptureStatusNone) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, + "unexpected int(cudaStreamCaptureStatusActive) value"); +static_assert( + int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, + "unexpected int(cudaStreamCaptureStatusInvalidated) value"); +#endif + +enum class CaptureStatus : int { +#if !defined(USE_ROCM) || ROCM_VERSION >= 50300 + None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), + Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), + Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) +#else + None = 0 +#endif +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch (status) { + case CaptureStatus::None: + os << "cudaStreamCaptureStatusNone"; + break; +#if !defined(USE_ROCM) || ROCM_VERSION >= 50300 + case CaptureStatus::Active: + os << "cudaStreamCaptureStatusActive"; + break; + case CaptureStatus::Invalidated: + os << "cudaStreamCaptureStatusInvalidated"; + break; +#endif + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown CUDA graph CaptureStatus", int(status)); + } + return os; +} + +// Use this version where you're sure a CUDA context exists already. +inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { +#if !defined(USE_ROCM) || ROCM_VERSION >= 50300 + cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone}; + C10_CUDA_CHECK( + cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); + return CaptureStatus(is_capturing); +#else + return CaptureStatus::None; +#endif +} + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGuard.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..09f1a6f2b6f8be7fbf7966a6b6256ffad1c09da6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAGuard.h @@ -0,0 +1,303 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace c10::cuda { + +// This code is kind of boilerplatey. See Note [Whither the DeviceGuard +// boilerplate] + +/// A variant of DeviceGuard that is specialized for CUDA. 
It accepts +/// integer indices (interpreting them as CUDA devices) and is a little +/// more efficient than DeviceGuard (it compiles to straight line +/// cudaSetDevice/cudaGetDevice calls); however, it can only be used +/// from code that links against CUDA directly. +struct CUDAGuard { + /// No default constructor; see Note [Omitted default constructor from RAII] + explicit CUDAGuard() = delete; + + /// Set the current CUDA device to the passed device index. + explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} + + /// Sets the current CUDA device to the passed device. Errors if the passed + /// device is not a CUDA device. + explicit CUDAGuard(Device device) : guard_(device) {} + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move is not allowed (there is no uninitialized state) + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /// Sets the CUDA device to the given device. Errors if the given device + /// is not a CUDA device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the CUDA device to the given device. Errors if the given device + /// is not a CUDA device. (This method is provided for uniformity with + /// DeviceGuard). + void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the CUDA device to the given device index. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set upon construction of the guard + Device original_device() const { + return guard_.original_device(); + } + + /// Returns the last device that was set via `set_device`, if any, otherwise + /// the device passed during construction. + Device current_device() const { + return guard_.current_device(); + } + + private: + /// The guard for the current device. + c10::impl::InlineDeviceGuard guard_; +}; + +/// A variant of OptionalDeviceGuard that is specialized for CUDA. See +/// CUDAGuard for when you can use this. +struct OptionalCUDAGuard { + /// Create an uninitialized OptionalCUDAGuard. + explicit OptionalCUDAGuard() : guard_() {} + + /// Set the current CUDA device to the passed Device, if it is not nullopt. + explicit OptionalCUDAGuard(optional device_opt) + : guard_(device_opt) {} + + /// Set the current CUDA device to the passed device index, if it is not + /// nullopt + explicit OptionalCUDAGuard(optional device_index_opt) + : guard_(device_index_opt) {} + + // Copy is not allowed + OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; + OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; + + /// Sets the CUDA device to the given device, initializing the guard if it + /// is not already initialized. Errors if the given device is not a CUDA + /// device. + void set_device(Device device) { + guard_.set_device(device); + } + + /// Sets the CUDA device to the given device, initializing the guard if it is + /// not already initialized. Errors if the given device is not a CUDA device. + /// (This method is provided for uniformity with OptionalDeviceGuard). 
+ void reset_device(Device device) { + guard_.reset_device(device); + } + + /// Sets the CUDA device to the given device index, initializing the guard if + /// it is not already initialized. + void set_index(DeviceIndex device_index) { + guard_.set_index(device_index); + } + + /// Returns the device that was set immediately prior to initialization of the + /// guard, or nullopt if the guard is uninitialized. + optional original_device() const { + return guard_.original_device(); + } + + /// Returns the most recent device that was set using this device guard, + /// either from construction, or via set_device, if the guard is initialized, + /// or nullopt if the guard is uninitialized. + optional current_device() const { + return guard_.current_device(); + } + + /// Restore the original CUDA device, resetting this guard to uninitialized + /// state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +/// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard +/// for when you can use this. +struct CUDAStreamGuard { + /// No default constructor, see Note [Omitted default constructor from RAII] + explicit CUDAStreamGuard() = delete; + + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. + explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} + + /// Copy is disallowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized + /// state, which is required for moves on types with nontrivial destructors. + CUDAStreamGuard(CUDAStreamGuard&& other) = delete; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; + + /// Resets the currently set stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Errors if the stream passed is not a CUDA stream. + /// + /// NOTE: this implementation may skip some stream/device setting if + /// it can prove that it is unnecessary. + /// + /// WARNING: reset_stream does NOT preserve previously set streams on + /// different devices. If you need to set streams on multiple devices + /// on CUDA, use CUDAMultiStreamGuard instead. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the CUDA stream that was set at the time the guard was + /// constructed. + CUDAStream original_stream() const { + return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); + } + + /// Returns the most recent CUDA stream that was set using this device guard, + /// either from construction, or via set_stream. + CUDAStream current_stream() const { + return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream()); + } + + /// Returns the most recent CUDA device that was set using this device guard, + /// either from construction, or via set_device/reset_device/set_index. + Device current_device() const { + return guard_.current_device(); + } + + /// Returns the CUDA device that was set at the most recent reset_stream(), + /// or otherwise the device at construction time. 
+ Device original_device() const { + return guard_.original_device(); + } + + private: + c10::impl::InlineStreamGuard guard_; +}; + +/// A variant of OptionalStreamGuard that is specialized for CUDA. See +/// CUDAGuard for when you can use this. +struct OptionalCUDAStreamGuard { + /// Create an uninitialized guard. + explicit OptionalCUDAStreamGuard() : guard_() {} + + /// Set the current CUDA device to the device associated with the passed + /// stream, and set the current CUDA stream on that device to the passed + /// stream. Errors if the Stream is not a CUDA stream. + explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} + + /// Set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream, + /// if the passed stream is not nullopt. + explicit OptionalCUDAStreamGuard(optional stream_opt) + : guard_(stream_opt) {} + + /// Copy is disallowed + OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; + OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete; + + /// Resets the currently set CUDA stream to the original stream and + /// the currently set device to the original device. Then, + /// set the current device to the device associated with the passed stream, + /// and set the current stream on that device to the passed stream. + /// Initializes the guard if it was not previously initialized. + void reset_stream(Stream stream) { + guard_.reset_stream(stream); + } + + /// Returns the CUDA stream that was set at the time the guard was most + /// recently initialized, or nullopt if the guard is uninitialized. + optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + /// Returns the most recent CUDA stream that was set using this stream guard, + /// either from construction, or via reset_stream, if the guard is + /// initialized, or nullopt if the guard is uninitialized. + optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + /// Restore the original CUDA device and stream, resetting this guard to + /// uninitialized state. + void reset() { + guard_.reset(); + } + + private: + c10::impl::InlineOptionalStreamGuard guard_; +}; + +/// A variant of MultiStreamGuard that is specialized for CUDA. 
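// Illustrative sketch (hypothetical function) of the device-only guard defined
// above; the stream guards additionally swap the current stream:
//
//   void launch_on(c10::DeviceIndex dev) {
//     c10::cuda::CUDAGuard guard(dev);  // switches the current CUDA device
//     // ... allocate memory or launch kernels that must run on `dev` ...
//   }                                   // original device restored here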
+struct CUDAMultiStreamGuard { + explicit CUDAMultiStreamGuard(ArrayRef streams) + : guard_(unwrapStreams(streams)) {} + + /// Copy is disallowed + CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; + CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete; + + // See Note [Move construction for RAII guards is tricky] + CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete; + + // See Note [Move assignment for RAII guards is tricky] + CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; + + private: + c10::impl::InlineMultiStreamGuard guard_; + + static std::vector unwrapStreams(ArrayRef cudaStreams) { + std::vector streams; + streams.reserve(cudaStreams.size()); + for (const CUDAStream& cudaStream : cudaStreams) { + streams.push_back(cudaStream); + } + return streams; + } +}; + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMacros.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMacros.h new file mode 100644 index 0000000000000000000000000000000000000000..3b69035c4e83fd2eca488a3a726eddd49cab4058 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMacros.h @@ -0,0 +1,51 @@ +#pragma once + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS + +// We have not yet modified the AMD HIP build to generate this file so +// we add an extra option to specifically ignore it. +#ifndef C10_CUDA_NO_CMAKE_CONFIGURE_FILE +#include +#endif // C10_CUDA_NO_CMAKE_CONFIGURE_FILE + +#endif + +// See c10/macros/Export.h for a detailed explanation of what the function +// of these macros are. We need one set of macros for every separate library +// we build. + +#ifdef _WIN32 +#if defined(C10_CUDA_BUILD_SHARED_LIBS) +#define C10_CUDA_EXPORT __declspec(dllexport) +#define C10_CUDA_IMPORT __declspec(dllimport) +#else +#define C10_CUDA_EXPORT +#define C10_CUDA_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_CUDA_EXPORT __attribute__((__visibility__("default"))) +#else // defined(__GNUC__) +#define C10_CUDA_EXPORT +#endif // defined(__GNUC__) +#define C10_CUDA_IMPORT C10_CUDA_EXPORT +#endif // _WIN32 + +// This one is being used by libc10_cuda.so +#ifdef C10_CUDA_BUILD_MAIN_LIB +#define C10_CUDA_API C10_CUDA_EXPORT +#else +#define C10_CUDA_API C10_CUDA_IMPORT +#endif + +/** + * The maximum number of GPUs that we recognizes. Increasing this beyond the + * initial limit of 16 broke Caffe2 testing, hence the ifdef guards. + * This value cannot be more than 128 because our DeviceIndex is a uint8_t. +o */ +#ifdef FBCODE_CAFFE2 +// fbcode depends on this value being 16 +#define C10_COMPILE_TIME_MAX_GPUS 16 +#else +#define C10_COMPILE_TIME_MAX_GPUS 120 +#endif diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMathCompat.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMathCompat.h new file mode 100644 index 0000000000000000000000000000000000000000..4c41bd5f382b35c7955bcda8035b156e4ad26e18 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMathCompat.h @@ -0,0 +1,152 @@ +#pragma once + +/* This file defines math functions compatible across different gpu + * platforms (currently CUDA and HIP). 
+ */ +#if defined(__CUDACC__) || defined(__HIPCC__) + +#include +#include + +#ifdef __HIPCC__ +#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE +#else /* __HIPCC__ */ +#ifdef __CUDACC_RTC__ +#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE +#else /* __CUDACC_RTC__ */ +#define __MATH_FUNCTIONS_DECL__ static inline C10_HOST_DEVICE +#endif /* __CUDACC_RTC__ */ +#endif /* __HIPCC__ */ + +namespace c10::cuda::compat { + +__MATH_FUNCTIONS_DECL__ float abs(float x) { + return ::fabsf(x); +} +__MATH_FUNCTIONS_DECL__ double abs(double x) { + return ::fabs(x); +} + +__MATH_FUNCTIONS_DECL__ float exp(float x) { + return ::expf(x); +} +__MATH_FUNCTIONS_DECL__ double exp(double x) { + return ::exp(x); +} + +__MATH_FUNCTIONS_DECL__ float ceil(float x) { + return ::ceilf(x); +} +__MATH_FUNCTIONS_DECL__ double ceil(double x) { + return ::ceil(x); +} + +__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysignf(x, y); +#else + // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64 + // (e.g. Jetson), see PyTorch PR #51834 + // This host function needs to be here for the compiler but is never used + TORCH_INTERNAL_ASSERT( + false, "CUDAMathCompat copysign should not run on the CPU"); +#endif +} +__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + return ::copysign(x, y); +#else + // see above + TORCH_INTERNAL_ASSERT( + false, "CUDAMathCompat copysign should not run on the CPU"); +#endif +} + +__MATH_FUNCTIONS_DECL__ float floor(float x) { + return ::floorf(x); +} +__MATH_FUNCTIONS_DECL__ double floor(double x) { + return ::floor(x); +} + +__MATH_FUNCTIONS_DECL__ float log(float x) { + return ::logf(x); +} +__MATH_FUNCTIONS_DECL__ double log(double x) { + return ::log(x); +} + +__MATH_FUNCTIONS_DECL__ float log1p(float x) { + return ::log1pf(x); +} + +__MATH_FUNCTIONS_DECL__ double log1p(double x) { + return ::log1p(x); +} + +__MATH_FUNCTIONS_DECL__ float max(float x, float y) { + return ::fmaxf(x, y); +} +__MATH_FUNCTIONS_DECL__ double max(double x, double y) { + return ::fmax(x, y); +} + +__MATH_FUNCTIONS_DECL__ float min(float x, float y) { + return ::fminf(x, y); +} +__MATH_FUNCTIONS_DECL__ double min(double x, double y) { + return ::fmin(x, y); +} + +__MATH_FUNCTIONS_DECL__ float pow(float x, float y) { + return ::powf(x, y); +} +__MATH_FUNCTIONS_DECL__ double pow(double x, double y) { + return ::pow(x, y); +} + +__MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) { + return ::sincosf(x, sptr, cptr); +} +__MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) { + return ::sincos(x, sptr, cptr); +} + +__MATH_FUNCTIONS_DECL__ float sqrt(float x) { + return ::sqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double sqrt(double x) { + return ::sqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float rsqrt(float x) { + return ::rsqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double rsqrt(double x) { + return ::rsqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float tan(float x) { + return ::tanf(x); +} +__MATH_FUNCTIONS_DECL__ double tan(double x) { + return ::tan(x); +} + +__MATH_FUNCTIONS_DECL__ float tanh(float x) { + return ::tanhf(x); +} +__MATH_FUNCTIONS_DECL__ double tanh(double x) { + return ::tanh(x); +} + +__MATH_FUNCTIONS_DECL__ float normcdf(float x) { + return ::normcdff(x); +} +__MATH_FUNCTIONS_DECL__ double normcdf(double x) { + return ::normcdf(x); +} + +} // namespace c10::cuda::compat + +#endif diff --git 
a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..203363028fd72d01065f161832c7e6c5f6b5217d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAMiscFunctions.h @@ -0,0 +1,12 @@ +#pragma once +// this file is to avoid circular dependency between CUDAFunctions.h and +// CUDAExceptions.h + +#include + +#include + +namespace c10::cuda { +C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +C10_CUDA_API std::mutex* getFreeMutex(); +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAStream.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAStream.h new file mode 100644 index 0000000000000000000000000000000000000000..f7b6609f210d4967c191705c31121ddf7520c30a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/CUDAStream.h @@ -0,0 +1,271 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include + +/* + * Stream pool note. + * + * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams + * are backed by cuStreams, but they use several pools to minimize the costs + * associated with creating, retaining, and destroying cuStreams. + * + * There are three pools per device, and a device's pools are lazily created. + * + * The first pool contains only the default stream. When the default stream + * is requested it's returned. + * + * The second pool is the "low priority" or "default priority" streams. In + * HIP builds there is no distinction between streams in this pool and streams + * in the third pool (below). There are 32 of these streams per device, and + * when a stream is requested one of these streams is returned round-robin. + * That is, the first stream requested is at index 0, the second at index 1... + * to index 31, then index 0 again. + * + * This means that if 33 low priority streams are requested, the first and + * last streams requested are actually the same stream (under the covers) + * and kernels enqueued on them cannot run concurrently. + * + * The third pool is the "high priority" streams. The third pool acts like + * the second pool except the streams are created with a higher priority. + * + * These pools suggest that stream users should prefer many short-lived streams, + * as the cost of acquiring and releasing streams is effectively zero. If + * many longer-lived streams are required in performance critical scenarios + * then the functionality here may need to be extended to allow, for example, + * "reserving" a subset of the pool so that other streams do not accidentally + * overlap the performance critical streams. + * + * Note: although the notion of "current stream for device" is thread local + * (every OS thread has a separate current stream, as one might expect), + * the stream pool is global across all threads; stream 0 is always stream 0 + * no matter which thread you use it on. Multiple threads can synchronize + * on the same stream. Although the CUDA documentation is not very clear + * on the matter, streams are thread safe; e.g., it is safe to enqueue + * a kernel on the same stream from two different threads. + */ + +namespace c10::cuda { + +static constexpr int max_compile_time_stream_priorities = 4; + +// Value object representing a CUDA stream. 
This is just a wrapper +// around c10::Stream, but it comes with a little extra CUDA-specific +// functionality (conversion to cudaStream_t), and a guarantee that +// the wrapped c10::Stream really is a CUDA stream. +class C10_CUDA_API CUDAStream { + public: + enum Unchecked { UNCHECKED }; + + /// Construct a CUDAStream from a Stream. This construction is checked, + /// and will raise an error if the Stream is not, in fact, a CUDA stream. + explicit CUDAStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::CUDA); + } + + /// Construct a CUDAStream from a Stream with no error checking. + /// This constructor uses the "named" constructor idiom, and can + /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream) + explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {} + + bool operator==(const CUDAStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const CUDAStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + /// Implicit conversion to cudaStream_t. + operator cudaStream_t() const { + return stream(); + } + + /// Implicit conversion to Stream (a.k.a., forget that the stream is a + /// CUDA stream). + operator Stream() const { + return unwrap(); + } + + /// Used to avoid baking in device type explicitly to Python-side API. + DeviceType device_type() const { + return DeviceType::CUDA; + } + + /// Get the CUDA device index that this stream is associated with. + DeviceIndex device_index() const { + return stream_.device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + Device device() const { + return Device(DeviceType::CUDA, device_index()); + } + + /// Return the stream ID corresponding to this particular stream. + StreamId id() const { + return stream_.id(); + } + + bool query() const { + DeviceGuard guard{stream_.device()}; + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream())); + + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void synchronize() const { + DeviceGuard guard{stream_.device()}; + c10::cuda::stream_synchronize(stream()); + } + + int priority() const { + DeviceGuard guard{stream_.device()}; + int priority = 0; + C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); + return priority; + } + + /// Explicit conversion to cudaStream_t. + cudaStream_t stream() const; + + /// Explicit conversion to Stream. + Stream unwrap() const { + return stream_; + } + + /// Reversibly pack a CUDAStream into a struct representation. + /// Previously the stream's data was packed into a single int64_t, + /// as it was assumed the fields would not require more than + /// 64 bits of storage in total. + /// See https://github.com/pytorch/pytorch/issues/75854 + /// for more information regarding newer platforms that may violate + /// this assumption. + /// + /// The CUDAStream can be unpacked using unpack(). + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + // Unpack a CUDAStream from the 3 fields generated by pack(). 
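// For example, given a CUDAStream `s` (illustrative; StreamData3 field names
// as in c10/core/Stream.h):
//
//   c10::StreamData3 d = s.pack3();
//   c10::cuda::CUDAStream s2 =
//       c10::cuda::CUDAStream::unpack3(d.stream_id, d.device_index, d.device_type);
//   // s2 == s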
+ static CUDAStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return CUDAStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + static std::tuple priority_range() { + // Note: this returns the range of priority **supported by PyTorch**, not + // the range of priority **supported by CUDA**. The former is a subset of + // the latter. + int least_priority = 0, greatest_priority = 0; + C10_CUDA_CHECK( + cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); +#ifdef USE_ROCM + // See Note [HIP stream priorities] + TORCH_INTERNAL_ASSERT( + least_priority == 1, "Unexpected HIP stream priority range"); + least_priority = 0; +#else + TORCH_INTERNAL_ASSERT( + least_priority == 0, "Unexpected CUDA stream priority range"); +#endif + TORCH_INTERNAL_ASSERT( + greatest_priority <= -1, "Unexpected CUDA stream priority range"); + greatest_priority = std::max( + -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority); + return std::make_tuple(least_priority, greatest_priority); + } + + // Deleted for now; use CUDAEvent::block instead + // void synchronize_with(const CUDAEvent& event) const; + + private: + Stream stream_; +}; + +/** + * Get a new stream from the CUDA stream pool. You can think of this + * as "creating" a new stream, but no such creation actually happens; + * instead, streams are preallocated from the pool and returned in a + * round-robin fashion. + * + * You can request a stream from the high priority pool by setting + * isHighPriority to true, or a stream for a specific device by setting device + * (defaulting to the current CUDA stream.) + */ +C10_API CUDAStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); +// no default priority to disambiguate overloads +C10_API CUDAStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/** + * Get a CUDAStream from a externally allocated one. + * + * This is mainly for interoperability with different libraries where we + * want to operate on a non-torch allocated stream for data exchange or similar + * purposes + */ +C10_API CUDAStream +getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index); + +/** + * Get the default CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The default stream is + * where most computation occurs when you aren't explicitly using + * streams. + */ +C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream, for the passed CUDA device, or for the + * current device if no device index is passed. The current CUDA stream + * will usually be the default CUDA stream for the device, but it may + * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' + * or 'CUDAStreamGuard'. + */ +C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * Set the current stream on the device of the passed in stream to be + * the passed in stream. Yes, you read that right: this function + * has *nothing* to do with the current device: it toggles the current + * stream of the device of the passed stream. + * + * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead + * (which will switch both your current device and current stream in the way you + * expect, and reset it back to its original state afterwards). 
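 *
 * A hypothetical sketch of that preferred pattern (names are illustrative):
 *
 *   c10::cuda::CUDAStream s =
 *       c10::cuda::getStreamFromPool(true);  // true = high-priority pool
 *   {
 *     c10::cuda::CUDAStreamGuard guard(s);   // makes s current on its device
 *     // ... enqueue work on s ...
 *   }                                        // previous stream and device restored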
+ */ +C10_API void setCurrentCUDAStream(CUDAStream stream); + +C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); + +} // namespace c10::cuda + +namespace std { +template <> +struct hash { + size_t operator()(c10::cuda::CUDAStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/driver_api.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/driver_api.h new file mode 100644 index 0000000000000000000000000000000000000000..80b7bcbec62c8821e597eedf2e51ac64cf48f649 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/driver_api.h @@ -0,0 +1,49 @@ +#pragma once +#include +#define NVML_NO_UNVERSIONED_FUNC_DEFS +#include + +#define C10_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err C10_UNUSED = \ + c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + AT_ERROR("CUDA driver error: unknown error"); \ + } else { \ + AT_ERROR("CUDA driver error: ", err_str); \ + } \ + } \ + } while (0) + +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ + _(cuGetErrorString) + +#define C10_NVML_DRIVER_API(_) \ + _(nvmlInit_v2) \ + _(nvmlDeviceGetHandleByPciBusId_v2) \ + _(nvmlDeviceGetNvLinkRemoteDeviceType) \ + _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \ + _(nvmlDeviceGetComputeRunningProcesses) + +namespace c10::cuda { + +struct DriverAPI { +#define CREATE_MEMBER(name) decltype(&name) name##_; + C10_LIBCUDA_DRIVER_API(CREATE_MEMBER) + C10_NVML_DRIVER_API(CREATE_MEMBER) +#undef CREATE_MEMBER + static DriverAPI* get(); + static void* get_nvml_handle(); +}; + +} // namespace c10::cuda diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..844c5dd12e340370cd220d1abb52da8bb266280b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDAGuardImpl.h @@ -0,0 +1,212 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10::cuda::impl { + +struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::CUDA; + + CUDAGuardImpl() = default; + explicit CUDAGuardImpl(DeviceType t) { + TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); + } + DeviceType type() const override { + return DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + auto old_device_index = c10::cuda::ExchangeDevice(d.index()); + return Device(DeviceType::CUDA, old_device_index); + } + Device getDevice() const override { + DeviceIndex device = 0; + C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + return Device(DeviceType::CUDA, device); + } + c10::optional uncheckedGetDevice() const noexcept { + DeviceIndex device{-1}; + const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK_WARN(err); + if (err != cudaSuccess) { + return c10::nullopt; + } + return Device(DeviceType::CUDA, device); + } + void setDevice(Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_cuda()); + 
C10_CUDA_CHECK(c10::cuda::SetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); + } + Stream getStream(Device d) const noexcept override { + return getCurrentCUDAStream(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultCUDAStream(d.index()); + } + Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) + const override { + return getStreamFromPool(isHighPriority, d.index()); + } + // NB: These do NOT set the current device + Stream exchangeStream(Stream s) const noexcept override { + CUDAStream cs(s); + auto old_stream = getCurrentCUDAStream(s.device().index()); + setCurrentCUDAStream(cs); + return old_stream.unwrap(); + } + DeviceIndex deviceCount() const noexcept override { + return device_count(); + } + + // Event-related functions + void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { + // Maps PyTorch's Event::Flag to CUDA flag + auto cuda_flag = cudaEventDefault; + switch (flag) { + case EventFlag::PYTORCH_DEFAULT: + case EventFlag::CUDA_EVENT_DISABLE_TIMING: + cuda_flag = cudaEventDisableTiming; + break; + case EventFlag::BACKEND_DEFAULT: + case EventFlag::CUDA_EVENT_DEFAULT: + cuda_flag = cudaEventDefault; + break; + default: + TORCH_CHECK(false, "CUDA event received unknown flag"); + } + + C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + reinterpret_cast(cuda_event)); + } + } + + void destroyEvent(void* event, const DeviceIndex device_index) + const noexcept override { + if (!event) + return; + auto cuda_event = static_cast(event); + DeviceIndex orig_device{-1}; + C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device)); + C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + reinterpret_cast(cuda_event)); + } + C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); + C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device)); + } + + void record( + void** event, + const Stream& stream, + const DeviceIndex device_index, + const EventFlag flag) const override { + TORCH_CHECK( + device_index == -1 || device_index == stream.device_index(), + "Event device index ", + device_index, + " does not match recording stream's device index ", + stream.device_index(), + "."); + + cudaEvent_t cuda_event = static_cast(*event); + CUDAStream cuda_stream{stream}; + + // Moves to stream's device to record + const auto orig_device = getDevice(); + setDevice(stream.device()); + + // Creates the event (lazily) + if (!cuda_event) + createEvent(&cuda_event, flag); + C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); + // Makes the void* point to the (possibly just allocated) CUDA event + *event = cuda_event; + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + reinterpret_cast(cuda_event), + reinterpret_cast(cuda_stream.stream())); + } + + // Resets device + setDevice(orig_device); + } + + void block(void* event, const Stream& stream) const override { + if (!event) + return; + cudaEvent_t cuda_event = static_cast(event); + CUDAStream cuda_stream{stream}; + const auto orig_device = getDevice(); + setDevice(stream.device()); + 
C10_CUDA_CHECK(cudaStreamWaitEvent( + cuda_stream, + cuda_event, + /*flags (must be zero)=*/0)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + reinterpret_cast(cuda_event), + reinterpret_cast(cuda_stream.stream())); + } + setDevice(orig_device); + } + + // May be called from any device + bool queryEvent(void* event) const override { + if (!event) + return true; + cudaEvent_t cuda_event = static_cast(event); + const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); + if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + return (err == cudaSuccess); + } + + // Stream-related functions + bool queryStream(const Stream& stream) const override { + CUDAStream cuda_stream{stream}; + return cuda_stream.query(); + } + + void synchronizeStream(const Stream& stream) const override { + CUDAStream cuda_stream{stream}; + cuda_stream.synchronize(); + } + + void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) + const override { + CUDAStream cuda_stream{stream}; + CUDACachingAllocator::recordStream(data_ptr, cuda_stream); + } +}; + +} // namespace c10::cuda::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDATest.h b/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDATest.h new file mode 100644 index 0000000000000000000000000000000000000000..9b288250b5cec90a2540be3b5cae5af2ae965994 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/cuda/impl/CUDATest.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace c10::cuda::impl { + +C10_CUDA_API int c10_cuda_test(); + +} diff --git a/MLPY/Lib/site-packages/torch/include/c10/macros/Export.h b/MLPY/Lib/site-packages/torch/include/c10/macros/Export.h new file mode 100644 index 0000000000000000000000000000000000000000..fce560336543b36dd50afe0ce1732aa8e04494f1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/macros/Export.h @@ -0,0 +1,160 @@ +#ifndef C10_MACROS_EXPORT_H_ +#define C10_MACROS_EXPORT_H_ + +/* Header file to define the common scaffolding for exported symbols. + * + * Export is by itself a quite tricky situation to deal with, and if you are + * hitting this file, make sure you start with the background here: + * - Linux: https://gcc.gnu.org/wiki/Visibility + * - Windows: + * https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017 + * + * Do NOT include this file directly. Instead, use c10/macros/Macros.h + */ + +// You do not need to edit this part of file unless you are changing the core +// pytorch export abstractions. +// +// This part defines the C10 core export and import macros. This is controlled +// by whether we are building shared libraries or not, which is determined +// during build time and codified in c10/core/cmake_macros.h. +// When the library is built as a shared lib, EXPORT and IMPORT will contain +// visibility attributes. If it is being built as a static lib, then EXPORT +// and IMPORT basically have no effect. + +// As a rule of thumb, you should almost NEVER mix static and shared builds for +// libraries that depend on c10. AKA, if c10 is built as a static library, we +// recommend everything dependent on c10 to be built statically. If c10 is built +// as a shared library, everything dependent on it should be built as shared. 
In +// the PyTorch project, all native libraries shall use the macro +// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static +// libraries. + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#ifdef _WIN32 +#define C10_HIDDEN +#if defined(C10_BUILD_SHARED_LIBS) +#define C10_EXPORT __declspec(dllexport) +#define C10_IMPORT __declspec(dllimport) +#else +#define C10_EXPORT +#define C10_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_EXPORT __attribute__((__visibility__("default"))) +#define C10_HIDDEN __attribute__((__visibility__("hidden"))) +#else // defined(__GNUC__) +#define C10_EXPORT +#define C10_HIDDEN +#endif // defined(__GNUC__) +#define C10_IMPORT C10_EXPORT +#endif // _WIN32 + +#ifdef NO_EXPORT +#undef C10_EXPORT +#define C10_EXPORT +#endif + +// Definition of an adaptive XX_API macro, that depends on whether you are +// building the library itself or not, routes to XX_EXPORT and XX_IMPORT. +// Basically, you will need to do this for each shared library that you are +// building, and the instruction is as follows: assuming that you are building +// a library called libawesome.so. You should: +// (1) for your cmake target (usually done by "add_library(awesome, ...)"), +// define a macro called AWESOME_BUILD_MAIN_LIB using +// target_compile_options. +// (2) define the AWESOME_API macro similar to the one below. +// And in the source file of your awesome library, use AWESOME_API to +// annotate public symbols. + +// Here, for the C10 library, we will define the macro C10_API for both import +// and export. + +// This one is being used by libc10.so +#ifdef C10_BUILD_MAIN_LIB +#define C10_API C10_EXPORT +#else +#define C10_API C10_IMPORT +#endif + +// This one is being used by libtorch.so +#ifdef CAFFE2_BUILD_MAIN_LIB +#define TORCH_API C10_EXPORT +#else +#define TORCH_API C10_IMPORT +#endif + +// You may be wondering: Whose brilliant idea was it to split torch_cuda into +// two pieces with confusing names? +// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we +// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker +// issues when linking big binaries. +// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: +// (1) Stop supporting so many GPU architectures +// (2) Do something else +// We chose #2 and decided to split the behemoth that was torch_cuda into two +// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) +// and the other that had..well..everything else (torch_cuda_cpp). The idea was +// this: instead of linking our static libraries (like the hefty +// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky +// relocation marker issues, we could link our static libraries to a smaller +// part of torch_cuda (torch_cuda_cpp) and avoid the issues. 
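// Editor's note: an illustrative sketch, not part of this header. It spells
// out the two-step recipe described above for the hypothetical libawesome.so:
// the build defines AWESOME_BUILD_MAIN_LIB only while compiling the library
// itself, so its public symbols are exported there and imported everywhere else.
#ifdef AWESOME_BUILD_MAIN_LIB
#define AWESOME_API C10_EXPORT
#else
#define AWESOME_API C10_IMPORT
#endif

// In a public header of the hypothetical library:
// AWESOME_API void do_something_awesome();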
+ +// libtorch_cuda_cu.so +#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB +#define TORCH_CUDA_CU_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +// libtorch_cuda_cpp.so +#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#endif + +// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the +// same api) +#ifdef TORCH_CUDA_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#define TORCH_CUDA_CU_API C10_EXPORT +#elif !defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +#if defined(TORCH_HIP_BUILD_MAIN_LIB) +#define TORCH_HIP_API C10_EXPORT +#else +#define TORCH_HIP_API C10_IMPORT +#endif + +#if defined(TORCH_XPU_BUILD_MAIN_LIB) +#define TORCH_XPU_API C10_EXPORT +#else +#define TORCH_XPU_API C10_IMPORT +#endif + +// Enums only need to be exported on windows for non-CUDA files +#if defined(_WIN32) && defined(__CUDACC__) +#define C10_API_ENUM C10_API +#else +#define C10_API_ENUM +#endif + +#endif // C10_MACROS_MACROS_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/macros/Macros.h b/MLPY/Lib/site-packages/torch/include/c10/macros/Macros.h new file mode 100644 index 0000000000000000000000000000000000000000..fae63b7b91ceec59b56bf6cb05111ef280602193 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/macros/Macros.h @@ -0,0 +1,546 @@ +#ifndef C10_MACROS_MACROS_H_ +#define C10_MACROS_MACROS_H_ +#include + +/* Main entry for c10/macros. + * + * In your code, include c10/macros/Macros.h directly, instead of individual + * files in this folder. + */ + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#include + +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#endif + +// Detect address sanitizer as some stuff doesn't work with it +#undef C10_ASAN_ENABLED + +// for clang +#if defined(__has_feature) +#if ((__has_feature(address_sanitizer))) +#define C10_ASAN_ENABLED 1 +#endif +#endif + +// for gcc +#if defined(__SANITIZE_ADDRESS__) +#if __SANITIZE_ADDRESS__ +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 1 +#endif +#endif +#endif + +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 0 +#endif + +// Disable the copy and assignment operator for a class. 
Note that this will +// disable the usage of the class in std containers. +#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete + +#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 +#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) + +#define C10_MACRO_EXPAND(args) args + +#define C10_STRINGIZE_IMPL(x) #x +#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) + +/** + * C10_ANONYMOUS_VARIABLE(str) introduces an identifier starting with + * str and ending with a number that varies with the line. + */ +#ifdef __COUNTER__ +#define C10_UID __COUNTER__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) +#else +#define C10_UID __LINE__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) +#endif + +#ifdef __has_cpp_attribute +#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +#define C10_HAS_CPP_ATTRIBUTE(x) (0) +#endif + +/// C10_NODISCARD - Warn if a type or return value is discarded. + +// Technically, we should check if __cplusplus > 201402L here, because +// [[nodiscard]] is only defined in C++17. However, some compilers +// we care about don't advertise being C++17 (e.g., clang), but +// support the attribute anyway. In fact, this is not just a good idea, +// it's the law: clang::warn_unused_result doesn't work on nvcc + clang +// and the best workaround for this case is to use [[nodiscard]] +// instead; see https://github.com/pytorch/pytorch/issues/13118 +// +// Note to future editors: if you have noticed that a compiler is +// misbehaving (e.g., it advertises support, but the support doesn't +// actually work, or it is emitting warnings). Some compilers which +// are strict about the matter include MSVC, which will complain: +// +// error C2429: attribute 'nodiscard' requires compiler flag '/std:c++latest' +// +// Exhibits: +// - MSVC 19.14: https://godbolt.org/z/Dzd7gn (requires /std:c++latest) +// - Clang 8.0.0: https://godbolt.org/z/3PYL4Z (always advertises support) +// - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support) +#if C10_HAS_CPP_ATTRIBUTE(nodiscard) +#define C10_NODISCARD [[nodiscard]] +// Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious +// error when __has_cpp_attribute is given a scoped attribute in C mode. +#elif __cplusplus && C10_HAS_CPP_ATTRIBUTE(clang::warn_unused_result) +// TODO: It's possible this is still triggering +// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better +// fix it. +#define C10_NODISCARD [[clang::warn_unused_result]] +#else +#define C10_NODISCARD +#endif + +// suppress an unused variable. +#if defined(_MSC_VER) && !defined(__clang__) +#define C10_UNUSED __pragma(warning(suppress : 4100 4101)) +#else +#define C10_UNUSED __attribute__((__unused__)) +#endif //_MSC_VER + +#if !defined(__has_attribute) +#define __has_attribute(x) 0 +#endif + +// Direct port of LLVM_ATTRIBUTE_USED. +#if __has_attribute(used) +#define C10_USED __attribute__((__used__)) +#else +#define C10_USED +#endif + +#define C10_RESTRICT __restrict + +// Simply define the namespace, in case a dependent library want to refer to +// the c10 namespace but not any nontrivial files. +namespace c10 {} +namespace c10::cuda {} +namespace c10::hip {} +namespace c10::xpu {} + +// Since C10 is the core library for caffe2 (and aten), we will simply reroute +// all abstractions defined in c10 to be available in caffe2 as well. +// This is only for backwards compatibility. 
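// Editor's note: an illustrative sketch, not part of this header. It shows
// two of the utility macros defined above in use; `ResourcePool` and the
// registration variable are hypothetical names.
class ResourcePool {
 public:
  ResourcePool() = default;
  // Expands to deleted copy constructor and copy assignment declarations.
  C10_DISABLE_COPY_AND_ASSIGN(ResourcePool);
};

// Expands to a per-line (or per-use, with __COUNTER__) unique identifier such
// as pool_registration42, useful for one-off static registration objects.
static int C10_ANONYMOUS_VARIABLE(pool_registration) = 0;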
Please use the symbols from the +// c10 namespace where possible. +namespace caffe2 { +using namespace c10; +} +namespace at { +using namespace c10; +} +namespace at::cuda { +using namespace c10::cuda; +} // namespace at::cuda + +// WARNING!!! THIS IS A GIANT HACK!!! +// This line means you cannot simultaneously include c10/hip +// and c10/cuda and then use them from the at::cuda namespace. +// This is true in practice, because HIPIFY works inplace on +// files in ATen/cuda, so it assumes that c10::hip is available +// from at::cuda. This namespace makes that happen. When +// HIPIFY is no longer out-of-place, we can switch the cuda +// here to hip and everyone is happy. +namespace at::cuda { +using namespace c10::hip; +} // namespace at::cuda + +namespace at::xpu { +using namespace c10::xpu; +} // namespace at::xpu + +// C10_LIKELY/C10_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if C10_LIKELY(some_expr) { +// ... +// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define C10_LIKELY(expr) (expr) +#define C10_UNLIKELY(expr) (expr) +#endif + +/// C10_NOINLINE - Functions whose declaration is annotated with this will not +/// be inlined. +#ifdef __GNUC__ +#define C10_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define C10_NOINLINE __declspec(noinline) +#else +#define C10_NOINLINE +#endif + +#if defined(_MSC_VER) +#define C10_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define C10_ALWAYS_INLINE inline +#endif + +#if defined(_MSC_VER) +#define C10_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) +#else +#define C10_ATTR_VISIBILITY_HIDDEN +#endif + +#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. +// See https://github.com/ROCm-Developer-Tools/HIP/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define C10_HOST_DEVICE __host__ __device__ +#define C10_DEVICE __device__ +#define C10_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. 
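// Editor's note: an illustrative sketch, not part of this header. It shows
// the call shape intended for C10_LIKELY/C10_UNLIKELY defined above: the
// macros supply their own parentheses, so they read like a plain condition.
// The function is hypothetical.
int checked_divide_sketch(int a, int b) {
  if (C10_UNLIKELY(b == 0)) {
    return 0;  // cold error path
  }
  return a / b;
}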
+#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK +// and C10_MIN_BLOCKS_PER_SM are kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), +// which will also properly respect limits on old architectures. +#define C10_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block)-1) / \ + (threads_per_block)))) +// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define C10_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy and + // versatility across all architectures. +#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) +#else +#define C10_HOST_DEVICE +#define C10_HOST +#define C10_DEVICE +#endif + +#if defined(USE_ROCM) +#define C10_HIP_HOST_DEVICE __host__ __device__ +#else +#define C10_HIP_HOST_DEVICE +#endif + +#if defined(USE_ROCM) +#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h) +#else +#define C10_WARP_SIZE 32 +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +// CUDA_KERNEL_ASSERT checks the assertion +// even when NDEBUG is defined. This is useful for important assertions in CUDA +// code that would otherwise be suppressed when building Release. 
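// Editor's note: an illustrative sketch, not part of this header or of the
// PyTorch sources being diffed. It shows the substitution recommended above
// (routing hand-tuned launch bounds through the clamping macros) together
// with CUDA_KERNEL_ASSERT, which stays active even in Release builds. The
// kernel name and its arguments are hypothetical.
C10_LAUNCH_BOUNDS_2(512, 4)
__global__ void fill_kernel_sketch(float* out, float value, int64_t n) {
  // Checked on-device even when NDEBUG is defined.
  CUDA_KERNEL_ASSERT(out != nullptr);
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = value;
  }
}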
+#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) || \ + (defined(USE_ROCM) && ROCM_VERSION < 40100) +// Those platforms do not support assert() +#define CUDA_KERNEL_ASSERT(cond) +#define SYCL_KERNEL_ASSERT(cond) +#elif defined(_MSC_VER) +#if defined(NDEBUG) +extern "C" { +C10_IMPORT +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void _wassert( + const wchar_t* wexpr, + const wchar_t* wfile, + unsigned line); +#else +#if defined(__CUDA_ARCH__) +__host__ __device__ +#endif // __CUDA_ARCH__ + void + _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#else // __APPLE__, _MSC_VER +#if defined(NDEBUG) +extern "C" { +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void __assert_fail( + const char* expr, + const char* file, + unsigned int line, + const char* func); +#else // __SYCL_DEVICE_ONLY__ +#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) +// CUDA supports __assert_fail function which are common for both device +// and host side code. +__host__ __device__ +#endif + + // This forward declaration matching the declaration of __assert_fail + // exactly how it is in glibc in case parts of the program are compiled with + // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' + // error. Note: On ROCm - this declaration serves for host side compilation. + void + __assert_fail( + const char* assertion, + const char* file, + unsigned int line, + const char* function) noexcept __attribute__((__noreturn__)); + +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +// ROCm disable kernel assert by default +#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) +#define CUDA_KERNEL_ASSERT(cond) +#define SYCL_KERNEL_ASSERT(cond) +#else +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM +#endif // __APPLE__ + +#ifdef __APPLE__ +#include +#endif + +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#define C10_MOBILE 1 +#elif ( \ + defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#define C10_MOBILE 1 +#endif // ANDROID / IOS + +#if defined(C10_MOBILE) && C10_MOBILE +#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline +#else +#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE +#endif + +#if defined(__CUDA_ARCH__) +#if defined(_MSC_VER) && defined(__CUDACC__) +#define CONSTEXPR_EXCEPT_WIN_CUDA const +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__ + +// Note [static constexpr char* members for windows NVCC] +// The Windows NVCC compiler doesn't handle static constexpr class members, +// although it's fixed in a later version. 
+// (see +// https://developercommunity.visualstudio.com/t/intellisense-error-c11-static-constexpr-member-ini/245425) +// +// If we want to ensure that our field is static under all builds, then we need +// to work around it specifically for windows NVCC by making it (a) const, (b) +// defined outside of the class definition We need to define it outside of the +// class definition because of the C++ standard; char* is not an integral type +// (see +// https://stackoverflow.com/questions/24278473/intellisense-a-member-of-type-const-char-const-cannot-have-an-in-class-in) +// +// So instead of this: +// struct Foo { +// static constexpr const char* name = "foo"; +// } +// In Windows NVCC, we end up with this: +// struct Foo { +// static const char* name; +// } +// const char* Foo::name = "foo"; +// +// This gives us a small perf hit for any code that wants to access these field +// members, but right now it isn't used in any perf-critical code paths. +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static const char* field; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \ + const char* cls::field = val; +#else +#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__ + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static constexpr const char* field = val; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) +#endif +#else +#if defined(_MSC_VER) && defined(__CUDACC__) +#define CONSTEXPR_EXCEPT_WIN_CUDA const +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static const char* field; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \ + const char* cls::field = val; +#else +#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static constexpr const char* field = val; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) +#endif +#endif + +#ifndef HAS_DEMANGLE +#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) +#define HAS_DEMANGLE 0 +#elif defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) +#define HAS_DEMANGLE 0 +#else +#define HAS_DEMANGLE 1 +#endif +#endif // HAS_DEMANGLE + +#define _C10_PRAGMA__(string) _Pragma(#string) +#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) + +#ifdef __clang__ +#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _C10_PRAGMA_(clang diagnostic ignored flag) +#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define C10_CLANG_DIAGNOSTIC_PUSH() +#define C10_CLANG_DIAGNOSTIC_POP() +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) +#define C10_CLANG_HAS_WARNING(flag) 0 +#endif + +#ifdef __clang__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(clang diagnostic push) \ + _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ + _C10_PRAGMA_(clang diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) + +#elif __GNUC__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(GCC diagnostic push) \ + _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ + _C10_PRAGMA_(GCC diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC 
diagnostic pop) + +#else + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) +#define C10_DIAGNOSTIC_POP() + +#endif + +#endif // C10_MACROS_MACROS_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/macros/cmake_macros.h b/MLPY/Lib/site-packages/torch/include/c10/macros/cmake_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..45bba88997cc9f1645a27c82e63fa51ad6bf3ddd --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/macros/cmake_macros.h @@ -0,0 +1,14 @@ +#ifndef C10_MACROS_CMAKE_MACROS_H_ +#define C10_MACROS_CMAKE_MACROS_H_ + +// Automatically generated header file for the C10 library. +// Do not include this file directly. Instead, include c10/macros/Macros.h. + +#define C10_BUILD_SHARED_LIBS +/* #undef C10_USE_GLOG */ +/* #undef C10_USE_GFLAGS */ +/* #undef C10_USE_NUMA */ +/* #undef C10_USE_MSVC_STATIC_RUNTIME */ +/* #undef C10_USE_ROCM_KERNEL_ASSERT */ + +#endif // C10_MACROS_CMAKE_MACROS_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/AbortHandler.h b/MLPY/Lib/site-packages/torch/include/c10/util/AbortHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..327f7a93eda1f5d1e8c65e8f0b1aca97a5886c5c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/AbortHandler.h @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +class AbortHandlerHelper { + public: + static AbortHandlerHelper& getInstance() { +#ifdef _WIN32 + thread_local +#endif // _WIN32 + static AbortHandlerHelper instance; + return instance; + } + + void set(std::terminate_handler handler) { + std::lock_guard lk(mutex); + if (!inited) { + prev = std::set_terminate(handler); + curr = std::get_terminate(); + inited = true; + } + } + + std::terminate_handler getPrev() const { + return prev; + } + + private: + std::terminate_handler prev = nullptr; + std::terminate_handler curr = nullptr; + bool inited = false; + std::mutex mutex; + AbortHandlerHelper() = default; + ~AbortHandlerHelper() { + // Only restore the handler if we are the current one + if (inited && curr == std::get_terminate()) { + std::set_terminate(prev); + } + } + + public: + AbortHandlerHelper(AbortHandlerHelper const&) = delete; + void operator=(AbortHandlerHelper const&) = delete; +}; + +namespace detail { +C10_ALWAYS_INLINE void terminate_handler() { + std::cout << "Unhandled exception caught in c10/util/AbortHandler.h" << '\n'; + auto backtrace = get_backtrace(); + std::cout << backtrace << '\n' << std::flush; + auto prev_handler = AbortHandlerHelper::getInstance().getPrev(); + if (prev_handler) { + prev_handler(); + } else { + std::abort(); + } +} +} // namespace detail + +C10_ALWAYS_INLINE void set_terminate_handler() { + bool use_custom_terminate = false; + // On Windows it is enabled by default based on + // https://github.com/pytorch/pytorch/pull/50320#issuecomment-763147062 +#ifdef _WIN32 + use_custom_terminate = true; +#endif // _WIN32 + auto result = c10::utils::check_env("TORCH_CUSTOM_TERMINATE"); + if (result != std::nullopt) { + use_custom_terminate = result.value(); + } + if (use_custom_terminate) { + AbortHandlerHelper::getInstance().set(detail::terminate_handler); + } +} +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/AlignOf.h b/MLPY/Lib/site-packages/torch/include/c10/util/AlignOf.h new file mode 100644 index 0000000000000000000000000000000000000000..3fd15693fb7369d94e9ada29509f5240d8fb061d --- /dev/null +++ 
b/MLPY/Lib/site-packages/torch/include/c10/util/AlignOf.h @@ -0,0 +1,176 @@ +//===--- AlignOf.h - Portable calculation of type alignment -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the AlignedCharArray and AlignedCharArrayUnion classes. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::AlignOf +// replaced LLVM_ALIGNAS with alignas + +#pragma once + +#include + +namespace c10 { + +/// \struct AlignedCharArray +/// \brief Helper for building an aligned character array type. +/// +/// This template is used to explicitly build up a collection of aligned +/// character array types. We have to build these up using a macro and explicit +/// specialization to cope with MSVC (at least till 2015) where only an +/// integer literal can be used to specify an alignment constraint. Once built +/// up here, we can then begin to indirect between these using normal C++ +/// template parameters. + +// MSVC requires special handling here. +#ifndef _MSC_VER + +template +struct AlignedCharArray { + // NOLINTNEXTLINE(*c-arrays) + alignas(Alignment) char buffer[Size]; +}; + +#else // _MSC_VER + +/// \brief Create a type with an aligned char buffer. +template +struct AlignedCharArray; + +// We provide special variations of this template for the most common +// alignments because __declspec(align(...)) doesn't actually work when it is +// a member of a by-value function argument in MSVC, even if the alignment +// request is something reasonably like 8-byte or 16-byte. Note that we can't +// even include the declspec with the union that forces the alignment because +// MSVC warns on the existence of the declspec despite the union member forcing +// proper alignment. + +template +struct AlignedCharArray<1, Size> { + union { + char aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<2, Size> { + union { + short aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<4, Size> { + union { + int aligned; + char buffer[Size]; + }; +}; + +template +struct AlignedCharArray<8, Size> { + union { + double aligned; + char buffer[Size]; + }; +}; + +// The rest of these are provided with a __declspec(align(...)) and we simply +// can't pass them by-value as function arguments on MSVC. 
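// Editor's note: an illustrative sketch, not part of this header. It shows
// the intended role of AlignedCharArray: raw, suitably aligned storage into
// which an object is placement-new'd. `Payload` is a hypothetical type.
#include <new>

struct Payload {
  double value;
  int tag;
};

c10::AlignedCharArray<alignof(Payload), sizeof(Payload)> payload_storage;

Payload* make_payload_sketch() {
  // The storage never runs Payload's destructor; the caller must do so.
  return new (&payload_storage.buffer) Payload{3.14, 1};
}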
+ +#define AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(x) \ + template \ + struct AlignedCharArray { \ + __declspec(align(x)) char buffer[Size]; \ + }; + +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(16) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(32) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(64) +AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT(128) + +#undef AT_ALIGNEDCHARARRAY_TEMPLATE_ALIGNMENT + +#endif // _MSC_VER + +namespace detail { +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +class AlignerImpl { + T1 t1; + T2 t2; + T3 t3; + T4 t4; + T5 t5; + T6 t6; + T7 t7; + T8 t8; + T9 t9; + T10 t10; + + public: + AlignerImpl() = delete; +}; + +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +union SizerImpl { + // NOLINTNEXTLINE(*c-arrays) + char arr1[sizeof(T1)], arr2[sizeof(T2)], arr3[sizeof(T3)], arr4[sizeof(T4)], + arr5[sizeof(T5)], arr6[sizeof(T6)], arr7[sizeof(T7)], arr8[sizeof(T8)], + arr9[sizeof(T9)], arr10[sizeof(T10)]; +}; +} // end namespace detail + +/// \brief This union template exposes a suitably aligned and sized character +/// array member which can hold elements of any of up to ten types. +/// +/// These types may be arrays, structs, or any other types. The goal is to +/// expose a char array buffer member which can be used as suitable storage for +/// a placement new of any of these types. Support for more than ten types can +/// be added at the cost of more boilerplate. +template < + typename T1, + typename T2 = char, + typename T3 = char, + typename T4 = char, + typename T5 = char, + typename T6 = char, + typename T7 = char, + typename T8 = char, + typename T9 = char, + typename T10 = char> +struct AlignedCharArrayUnion + : AlignedCharArray< + alignof(detail::AlignerImpl), + sizeof(::c10::detail:: + SizerImpl)> {}; +} // end namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ApproximateClock.h b/MLPY/Lib/site-packages/torch/include/c10/util/ApproximateClock.h new file mode 100644 index 0000000000000000000000000000000000000000..b0fb3efe5b4f3787226c6b4215f477796716f593 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ApproximateClock.h @@ -0,0 +1,115 @@ +// Copyright 2023-present Facebook. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__) +#define C10_RDTSC +#if defined(_MSC_VER) +#include +#elif defined(__CUDACC__) || defined(__HIPCC__) +#undef C10_RDTSC +#elif defined(__clang__) +// `__rdtsc` is available by default. 
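// Editor's note: an illustrative sketch, not part of this header. It shows
// the intended workflow of the approximate-clock helpers declared below:
// take cheap approximate timestamps (TSC where available) on the hot path and
// convert them to unix-epoch nanoseconds only during post-processing. The
// function name is hypothetical.
void profile_region_sketch() {
  c10::ApproximateClockToUnixTimeConverter converter;
  auto t0 = c10::getApproximateTime();
  // ... work being profiled ...
  auto t1 = c10::getApproximateTime();
  // makeConverter() yields a callable mapping approx_time_t to ns since epoch.
  auto to_ns = converter.makeConverter();
  int64_t elapsed_ns = to_ns(t1) - to_ns(t0);
  (void)elapsed_ns;
}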
+// NB: This has to be first, because Clang will also define `__GNUC__` +#elif defined(__GNUC__) +#include +#else +#undef C10_RDTSC +#endif +#endif + +namespace c10 { + +using time_t = int64_t; +using steady_clock_t = std::conditional_t< + std::chrono::high_resolution_clock::is_steady, + std::chrono::high_resolution_clock, + std::chrono::steady_clock>; + +inline time_t getTimeSinceEpoch() { + auto now = std::chrono::system_clock::now().time_since_epoch(); + return std::chrono::duration_cast(now).count(); +} + +inline time_t getTime(bool allow_monotonic = false) { +#if defined(C10_IOS) && defined(C10_MOBILE) + // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS + // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime + // is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + return std::chrono::duration_cast( + steady_clock_t::now().time_since_epoch()) + .count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t {}; + auto mode = CLOCK_REALTIME; + if (allow_monotonic) { + mode = CLOCK_MONOTONIC; + } + clock_gettime(mode, &t); + return static_cast(t.tv_sec) * 1000000000 + + static_cast(t.tv_nsec); +#endif +} + +// We often do not need to capture true wall times. If a fast mechanism such +// as TSC is available we can use that instead and convert back to epoch time +// during post processing. This greatly reduce the clock's contribution to +// profiling. +// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/ +// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io +// TODO: We should use +// `https://github.com/google/benchmark/blob/main/src/cycleclock.h` +inline auto getApproximateTime() { +#if defined(C10_RDTSC) + return static_cast(__rdtsc()); +#else + return getTime(); +#endif +} + +using approx_time_t = decltype(getApproximateTime()); +static_assert( + std::is_same_v || + std::is_same_v, + "Expected either int64_t (`getTime`) or uint64_t (some TSC reads)."); + +// Convert `getCount` results to Nanoseconds since unix epoch. +class C10_API ApproximateClockToUnixTimeConverter final { + public: + ApproximateClockToUnixTimeConverter(); + std::function makeConverter(); + + struct UnixAndApproximateTimePair { + time_t t_; + approx_time_t approx_t_; + }; + static UnixAndApproximateTimePair measurePair(); + + private: + static constexpr size_t replicates = 1001; + using time_pairs = std::array; + time_pairs measurePairs(); + + time_pairs start_times_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Array.h b/MLPY/Lib/site-packages/torch/include/c10/util/Array.h new file mode 100644 index 0000000000000000000000000000000000000000..ecf91b578137d931301f674a9cb10bd2ec86b5ff --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Array.h @@ -0,0 +1,16 @@ +#include +#include + +namespace c10 { + +// This helper function creates a constexpr std::array +// From a compile time list of values, without requiring you to explicitly +// write out the length. +// +// See also https://stackoverflow.com/a/26351760/23845 +template +inline constexpr auto array_of(T&&... 
t) -> std::array { + return {{std::forward(t)...}}; +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ArrayRef.h b/MLPY/Lib/site-packages/torch/include/c10/util/ArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..c347c5263483be94921086cd0d9c3a3b492aec0e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ArrayRef.h @@ -0,0 +1,380 @@ +//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::ArrayRef. +// removed llvm-specific functionality +// removed some implicit const -> non-const conversions that rely on +// complicated std::enable_if meta-programming +// removed a bunch of slice variants for simplicity... + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +/// ArrayRef - Represent a constant reference to an array (0 or more elements +/// consecutively in memory), i.e. a start pointer and a length. It allows +/// various APIs to take consecutive elements easily and conveniently. +/// +/// This class does not own the underlying data, it is expected to be used in +/// situations where the data resides in some other buffer, whose lifetime +/// extends past that of the ArrayRef. For this reason, it is not in general +/// safe to store an ArrayRef. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +template +class ArrayRef final { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + void debugCheckNullptrInvariant() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + Data != nullptr || Length == 0, + "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal"); + } + + public: + /// @name Constructors + /// @{ + + /// Construct an empty ArrayRef. + /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an ArrayRef from a single element. + // TODO Make this explicit + constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an ArrayRef from a pointer and length. + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length) + : Data(data), Length(length) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a range. + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a SmallVector. This is templated in order to + /// avoid instantiating SmallVectorTemplateCommon whenever we + /// copy-construct an ArrayRef. 
+ template + /* implicit */ ArrayRef(const SmallVectorTemplateCommon& Vec) + : Data(Vec.data()), Length(Vec.size()) { + debugCheckNullptrInvariant(); + } + + template < + typename Container, + typename = std::enable_if_t().data())>, + T*>>> + /* implicit */ ArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) { + debugCheckNullptrInvariant(); + } + + /// Construct an ArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. + template + /* implicit */ ArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same::value, + "ArrayRef cannot be constructed from a std::vector bitfield."); + } + + /// Construct an ArrayRef from a std::array + template + /* implicit */ constexpr ArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct an ArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} + + /// Construct an ArrayRef from a std::initializer_list. + /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return Data; + } + constexpr iterator end() const { + return Data + Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return Data; + } + constexpr const_iterator cend() const { + return Data + Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return Length == 0; + } + + constexpr const T* data() const { + return Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return Length; + } + + /// front - Get the first element. + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const { + TORCH_CHECK( + !empty(), "ArrayRef: attempted to access front() of empty list"); + return Data[0]; + } + + /// back - Get the last element. + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const { + TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list"); + return Data[Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(ArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N, size_t M) + const { + TORCH_CHECK( + N + M <= size(), + "ArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + size()); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. 
+ C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N) const { + TORCH_CHECK( + N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); + return slice(N, size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return Data[Index]; + } + + /// Vector compatibility + C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const { + TORCH_CHECK( + Index < Length, + "ArrayRef: invalid index Index = ", + Index, + "; Length = ", + Length); + return Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, ArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(Data, Data + Length); + } + + /// @} +}; + +template +std::ostream& operator<<(std::ostream& out, ArrayRef list) { + int i = 0; + out << "["; + for (const auto& e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << "]"; + return out; +} + +/// @name ArrayRef Convenience constructors +/// @{ + +/// Construct an ArrayRef from a single element. +template +ArrayRef makeArrayRef(const T& OneElt) { + return OneElt; +} + +/// Construct an ArrayRef from a pointer and length. +template +ArrayRef makeArrayRef(const T* data, size_t length) { + return ArrayRef(data, length); +} + +/// Construct an ArrayRef from a range. +template +ArrayRef makeArrayRef(const T* begin, const T* end) { + return ArrayRef(begin, end); +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVectorImpl& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a SmallVector. +template +ArrayRef makeArrayRef(const SmallVector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::vector. +template +ArrayRef makeArrayRef(const std::vector& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a std::array. +template +ArrayRef makeArrayRef(const std::array& Arr) { + return Arr; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) (const) +template +ArrayRef makeArrayRef(const ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from an ArrayRef (no-op) +template +ArrayRef& makeArrayRef(ArrayRef& Vec) { + return Vec; +} + +/// Construct an ArrayRef from a C array. +template +// NOLINTNEXTLINE(*c-arrays*) +ArrayRef makeArrayRef(const T (&Arr)[N]) { + return ArrayRef(Arr); +} + +// WARNING: Template instantiation will NOT be willing to do an implicit +// conversions to get you to an c10::ArrayRef, which is why we need so +// many overloads. 
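// Editor's note: an illustrative sketch, not part of this header. It shows
// how ArrayRef is typically consumed: an API takes c10::ArrayRef<int64_t> by
// value and callers pass vectors, initializer lists, or slices without
// copying. The functions and data are hypothetical.
#include <vector>

int64_t sum_sketch(c10::ArrayRef<int64_t> values) {
  int64_t total = 0;
  for (int64_t v : values) {
    total += v;
  }
  return total;
}

void callers_sketch() {
  std::vector<int64_t> sizes = {2, 3, 4};
  sum_sketch(sizes);                              // implicit from std::vector
  sum_sketch({1, 2, 3});                          // from std::initializer_list
  sum_sketch(c10::makeArrayRef(sizes).slice(1));  // view of the last two elements
}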
+ +template +bool operator==(c10::ArrayRef a1, c10::ArrayRef a2) { + return a1.equals(a2); +} + +template +bool operator!=(c10::ArrayRef a1, c10::ArrayRef a2) { + return !a1.equals(a2); +} + +template +bool operator==(const std::vector& a1, c10::ArrayRef a2) { + return c10::ArrayRef(a1).equals(a2); +} + +template +bool operator!=(const std::vector& a1, c10::ArrayRef a2) { + return !c10::ArrayRef(a1).equals(a2); +} + +template +bool operator==(c10::ArrayRef a1, const std::vector& a2) { + return a1.equals(c10::ArrayRef(a2)); +} + +template +bool operator!=(c10::ArrayRef a1, const std::vector& a2) { + return !a1.equals(c10::ArrayRef(a2)); +} + +using IntArrayRef = ArrayRef; + +// This alias is deprecated because it doesn't make ownership +// semantics obvious. Use IntArrayRef instead! +C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef) + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..2b8061d34a911b682e1fc1a002aeed1c9b5f1ba4 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-inl.h @@ -0,0 +1,343 @@ +#pragma once + +#include +#include + +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#else +#include // for SYCL 2020 +#endif +#include +#endif + +namespace c10 { + +/// Constructors +inline C10_HOST_DEVICE BFloat16::BFloat16(float value) + : +#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 800 + x(__bfloat16_as_ushort(__float2bfloat16(value))) +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + x(c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) +#else + // RNE by default + x(detail::round_to_nearest_even(value)) +#endif +{ +} + +/// Implicit conversions +inline C10_HOST_DEVICE BFloat16::operator float() const { +#if defined(__CUDACC__) && !defined(USE_ROCM) + return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + return float(*reinterpret_cast(&x)); +#else + return detail::f32_from_bits(x); +#endif +} + +#if defined(__CUDACC__) && !defined(USE_ROCM) +inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +inline C10_HOST_DEVICE BFloat16::BFloat16( + const sycl::ext::oneapi::bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { + 
return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { + return a * 
static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::BFloat16 min() { + return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 lowest() { + return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 max() { + return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 epsilon() { + return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 round_error() { + return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 infinity() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 quiet_NaN() { + return 
c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 signaling_NaN() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 denorm_min() { + return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-math.h b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-math.h new file mode 100644 index 0000000000000000000000000000000000000000..63c48046cf5df668db477b60db19903a966a9b59 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16-math.h @@ -0,0 +1,287 @@ +#pragma once + +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace std { + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + 
typename std::enable_if_t, int> = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std::enable_if_t, int> = 0> +C10_HOST_DEVICE inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + using float_t = T; + constexpr uint8_t bits = 16; + union { + float_t f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. 
positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16.h b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..7a4df7d934c680e945df4be7e85da7e4eab4d3ae --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/BFloat16.h @@ -0,0 +1,117 @@ +#pragma once + +// Defines the bloat16 type (brain floating-point). This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. + +#include +#include +#include +#include + +#if defined(__CUDACC__) && !defined(USE_ROCM) +#include +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#else +#include // for SYCL 2020 +#endif +#include +#endif + +namespace c10 { + +namespace detail { +inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + +#if defined(USE_ROCM) + float* tempRes; + + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; +#else + std::memcpy(&res, &tmp, sizeof(tmp)); +#endif + + return res; +} + +inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; + +#if defined(USE_ROCM) + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. 
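  // Illustrative aside (not from the upstream header): the scheme these
  // helpers rely on is simply that bfloat16 is the upper 16 bits of an IEEE
  // float32. A minimal sketch of that relationship; the helper name below is
  // invented for illustration only:
  //
  //   #include <cstdint>
  //   #include <cstring>
  //
  //   inline float bf16_bits_to_float_sketch(uint16_t bits) {
  //     // 0x3F80 (bfloat16 1.0) widens to 0x3F800000, which is 1.0f.
  //     uint32_t widened = static_cast<uint32_t>(bits) << 16;
  //     float out;
  //     std::memcpy(&out, &widened, sizeof(out));
  //     return out;
  //   }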
+ uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; +#else + std::memcpy(&res, &src, sizeof(res)); +#endif + + return res >> 16; +} + +inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +#if defined(USE_ROCM) + if (src != src) { +#elif defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + union { + uint32_t U32; + float F32; + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} +} // namespace detail + +struct alignas(2) BFloat16 { + uint16_t x; + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + C10_HOST_DEVICE BFloat16() = default; +#else + BFloat16() = default; +#endif + + struct from_bits_t {}; + static constexpr C10_HOST_DEVICE from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits){}; + inline C10_HOST_DEVICE BFloat16(float value); + inline C10_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) && !defined(USE_ROCM) + inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); + explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; +#endif +}; + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Backtrace.h b/MLPY/Lib/site-packages/torch/include/c10/util/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..f2c21db94abd40a3a3d2c1942d14bfdae289070c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Backtrace.h @@ -0,0 +1,17 @@ +#ifndef C10_UTIL_BACKTRACE_H_ +#define C10_UTIL_BACKTRACE_H_ + +#include +#include +#include + +#include + +namespace c10 { +C10_API std::string get_backtrace( + size_t frames_to_skip = 0, + size_t maximum_number_of_frames = 64, + bool skip_python_frames = true); +} // namespace c10 + +#endif // C10_UTIL_BACKTRACE_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Bitset.h b/MLPY/Lib/site-packages/torch/include/c10/util/Bitset.h new file mode 100644 index 0000000000000000000000000000000000000000..f66282e62e79c1932ff86dcdb75ab929b4e72be9 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Bitset.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#if defined(_MSC_VER) +#include +#endif + +namespace c10::utils { + +/** + * This is a simple bitset class with sizeof(long long int) bits. + * You can set bits, unset bits, query bits by index, + * and query for the first set bit. + * Before using this class, please also take a look at std::bitset, + * which has more functionality and is more generic. It is probably + * a better fit for your use case. The sole reason for c10::utils::bitset + * to exist is that std::bitset misses a find_first_set() method. 
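 *
 * A minimal usage sketch (illustrative, relying only on the members defined
 * below):
 *
 *   c10::utils::bitset dims;
 *   dims.set(0);
 *   dims.set(3);
 *   bool has_third = dims.get(3);            // true
 *   dims.for_each_set_bit([](size_t index) {
 *     // invoked with index == 0, then index == 3
 *   });
 *   dims.unset(0);
 *   bool empty = dims.is_entirely_unset();   // false, bit 3 is still set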
+ */ +struct bitset final { + private: +#if defined(_MSC_VER) + // MSVCs _BitScanForward64 expects int64_t + using bitset_type = int64_t; +#else + // POSIX ffsll expects long long int + using bitset_type = long long int; +#endif + public: + static constexpr size_t NUM_BITS() { + return 8 * sizeof(bitset_type); + } + + constexpr bitset() noexcept = default; + constexpr bitset(const bitset&) noexcept = default; + constexpr bitset(bitset&&) noexcept = default; + // there is an issure for gcc 5.3.0 when define default function as constexpr + // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754. + bitset& operator=(const bitset&) noexcept = default; + bitset& operator=(bitset&&) noexcept = default; + + constexpr void set(size_t index) noexcept { + bitset_ |= (static_cast(1) << index); + } + + constexpr void unset(size_t index) noexcept { + bitset_ &= ~(static_cast(1) << index); + } + + constexpr bool get(size_t index) const noexcept { + return bitset_ & (static_cast(1) << index); + } + + constexpr bool is_entirely_unset() const noexcept { + return 0 == bitset_; + } + + // Call the given functor with the index of each bit that is set + template + void for_each_set_bit(Func&& func) const { + bitset cur = *this; + size_t index = cur.find_first_set(); + while (0 != index) { + // -1 because find_first_set() is not one-indexed. + index -= 1; + func(index); + cur.unset(index); + index = cur.find_first_set(); + } + } + + private: + // Return the index of the first set bit. The returned index is one-indexed + // (i.e. if the very first bit is set, this function returns '1'), and a + // return of '0' means that there was no bit set. + size_t find_first_set() const { +#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64)) + unsigned long result; + bool has_bits_set = (0 != _BitScanForward64(&result, bitset_)); + if (!has_bits_set) { + return 0; + } + return result + 1; +#elif defined(_MSC_VER) && defined(_M_IX86) + unsigned long result; + if (static_cast(bitset_) != 0) { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_))); + if (!has_bits_set) { + return 0; + } + return result + 1; + } else { + bool has_bits_set = + (0 != _BitScanForward(&result, static_cast(bitset_ >> 32))); + if (!has_bits_set) { + return 32; + } + return result + 33; + } +#else + return __builtin_ffsll(bitset_); +#endif + } + + friend bool operator==(bitset lhs, bitset rhs) noexcept { + return lhs.bitset_ == rhs.bitset_; + } + + bitset_type bitset_{0}; +}; + +inline bool operator!=(bitset lhs, bitset rhs) noexcept { + return !(lhs == rhs); +} + +} // namespace c10::utils diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/C++17.h b/MLPY/Lib/site-packages/torch/include/c10/util/C++17.h new file mode 100644 index 0000000000000000000000000000000000000000..448621b758ca01ce3ab976e7d0a7ab14bdafbc12 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/C++17.h @@ -0,0 +1,166 @@ +#pragma once +#ifndef C10_UTIL_CPP17_H_ +#define C10_UTIL_CPP17_H_ + +#include +#include +#include +#include +#include + +#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ + __GNUC__ < 9 +#error \ + "You're trying to build PyTorch with a too old version of GCC. We need GCC 9 or later." +#endif + +#if defined(__clang__) && __clang_major__ < 9 +#error \ + "You're trying to build PyTorch with a too old version of Clang. We need Clang 9 or later." 
+#endif + +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) +#error You need C++17 to compile PyTorch +#endif + +#if defined(_WIN32) && (defined(min) || defined(max)) +#error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows +#endif + +/* + * This header adds some polyfills with C++17 functionality + */ + +namespace c10 { + +// in c++17 std::result_of has been superseded by std::invoke_result. Since +// c++20, std::result_of is removed. +template +#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L +using invoke_result = typename std::invoke_result; +#else +using invoke_result = typename std::result_of; +#endif + +template +using invoke_result_t = typename invoke_result::type; + +// std::is_pod is deprecated in C++20, std::is_standard_layout and +// std::is_trivial are introduced in C++11, std::conjunction has been introduced +// in C++17. +template +using is_pod = std::conjunction, std::is_trivial>; + +template +constexpr bool is_pod_v = is_pod::value; + +namespace guts { + +template +std::enable_if_t< + !std::is_array_v && !std::is_array_v && + std::is_base_of_v, + std::unique_ptr> +make_unique_base(Args&&... args) { + return std::unique_ptr(new Child(std::forward(args)...)); +} + +template +using conjunction = std::conjunction; +template +using disjunction = std::disjunction; +template +using bool_constant = std::bool_constant; +template +using negation = std::negation; + +template +using void_t = std::void_t; + +#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__) + +template +C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) { + return std::apply(std::forward(f), std::forward(t)); +} + +#else + +// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but +// modified) +// TODO This is an incomplete implementation of std::apply, not working for +// member functions. +namespace detail { +template +#if defined(_MSC_VER) +// MSVC has a problem with the decltype() return type, but it also doesn't need +// it +C10_HOST_DEVICE constexpr auto apply_impl( + F&& f, + Tuple&& t, + std::index_sequence) +#else +// GCC/Clang need the decltype() return type +C10_HOST_DEVICE constexpr decltype(auto) apply_impl( + F&& f, + Tuple&& t, + std::index_sequence) +#endif +{ + return std::forward(f)(std::get(std::forward(t))...); +} +} // namespace detail + +template +C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) { + return detail::apply_impl( + std::forward(f), + std::forward(t), + std::make_index_sequence< + std::tuple_size>::value>{}); +} + +#endif + +template +std::enable_if_t< + std::is_member_pointer_v>, + typename c10::invoke_result_t> +invoke(Functor&& f, Args&&... args) { + return std::mem_fn(std::forward(f))(std::forward(args)...); +} + +template +std::enable_if_t< + !std::is_member_pointer_v>, + typename c10::invoke_result_t> +invoke(Functor&& f, Args&&... 
args) { + return std::forward(f)(std::forward(args)...); +} + +namespace detail { +struct _identity final { + template + using type_identity = T; + + template + decltype(auto) operator()(T&& arg) { + return std::forward(arg); + } +}; + +template +struct function_takes_identity_argument : std::false_type {}; + +template +struct function_takes_identity_argument< + Func, + std::void_t()(_identity()))>> : std::true_type { +}; +} // namespace detail + +} // namespace guts +} // namespace c10 + +#endif // C10_UTIL_CPP17_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/CallOnce.h b/MLPY/Lib/site-packages/torch/include/c10/util/CallOnce.h new file mode 100644 index 0000000000000000000000000000000000000000..5ddac80f7c694dbad19076a6448481f63242580e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/CallOnce.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace c10 { + +// custom c10 call_once implementation to avoid the deadlock in std::call_once. +// The implementation here is a simplified version from folly and likely much +// much higher memory footprint. +template +inline void call_once(Flag& flag, F&& f, Args&&... args) { + if (C10_LIKELY(flag.test_once())) { + return; + } + flag.call_once_slow(std::forward(f), std::forward(args)...); +} + +class once_flag { + public: +#ifndef _WIN32 + // running into build error on MSVC. Can't seem to get a repro locally so I'm + // just avoiding constexpr + // + // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error: + // defaulted default constructor cannot be constexpr because the + // corresponding implicitly declared default constructor would not be + // constexpr 1 error detected in the compilation of + // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu". + constexpr +#endif + once_flag() noexcept = default; + once_flag(const once_flag&) = delete; + once_flag& operator=(const once_flag&) = delete; + + private: + template + friend void call_once(Flag& flag, F&& f, Args&&... args); + + template + void call_once_slow(F&& f, Args&&... 
args) { + std::lock_guard guard(mutex_); + if (init_.load(std::memory_order_relaxed)) { + return; + } + c10::guts::invoke(std::forward(f), std::forward(args)...); + init_.store(true, std::memory_order_release); + } + + bool test_once() { + return init_.load(std::memory_order_acquire); + } + + void reset_once() { + init_.store(false, std::memory_order_release); + } + + private: + std::mutex mutex_; + std::atomic init_{false}; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ConstexprCrc.h b/MLPY/Lib/site-packages/torch/include/c10/util/ConstexprCrc.h new file mode 100644 index 0000000000000000000000000000000000000000..5fb725370184efdf20ec36685006cd97bcd041da --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ConstexprCrc.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10::util { + +namespace detail { +// NOLINTNEXTLINE(*c-arrays*) +constexpr uint64_t crc64_table[] = { + 0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2, + 0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6, + 0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75, + 0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe, + 0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08, + 0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8, + 0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e, + 0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285, + 0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306, + 0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02, + 0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02, + 0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489, + 0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f, + 0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e, + 0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8, + 0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73, + 0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271, + 0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75, + 0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6, + 0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d, + 0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b, + 0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416, + 0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0, + 0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b, + 0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8, + 0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec, + 0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee, + 0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965, + 0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693, + 0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2, + 0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14, + 0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f, + 0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f, + 0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b, + 0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18, + 0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793, + 0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865, + 0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495, + 0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63, + 0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8, + 0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b, + 0x7c23a61ddbe6f812, 
0x3373d23613f9d516, 0x49aba2fe23cc5c6f, + 0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5, + 0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e, + 0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8, + 0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9, + 0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f, + 0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4, + 0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6, + 0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2, + 0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841, + 0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca, + 0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c, + 0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce, + 0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038, + 0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3, + 0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30, + 0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734, + 0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936, + 0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd, + 0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b, + 0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a, + 0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc, + 0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47, + 0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628, + 0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c, + 0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf, + 0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124, + 0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2, + 0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222, + 0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4, + 0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f, + 0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc, + 0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8, + 0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8, + 0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053, + 0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5, + 0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4, + 0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322, + 0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9, + 0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab, + 0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf, + 0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c, + 0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7, + 0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51, + 0x29b7d047efec8728, +}; + +inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA uint64_t +crc64impl(uint64_t accumulator, const char* data, size_t size) { + for (size_t i = 0; i < size; ++i) { + accumulator = + crc64_table[(accumulator ^ data[i]) & 0xFF] ^ (accumulator >> 8); + } + return accumulator; +} +} // namespace detail + +struct crc64_t final : IdWrapper { + constexpr crc64_t(uint64_t checksum) : IdWrapper(checksum) {} + constexpr uint64_t checksum() const { + return this->underlyingId(); + } +}; + +// CRC64 with Jones coefficients and an init value of 0. 
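// A minimal usage sketch (illustrative): since the routines below are
// constexpr on most toolchains, the checksum of a string literal can be
// computed at compile time.
//
//   constexpr c10::util::crc64_t kTag = c10::util::crc64("conv2d", 6);
//   constexpr uint64_t raw = kTag.checksum();  // stable 64-bit fingerprint
//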
+inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t +crc64(const char* str, size_t size) { + return crc64_t{detail::crc64impl(0, str, size)}; +} + +inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t crc64(c10::string_view str) { + return crc64(str.data(), str.size()); +} +} // namespace c10::util + +// Allow usage of crc64_t in std::unordered_set +C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::crc64_t); diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/DeadlockDetection.h b/MLPY/Lib/site-packages/torch/include/c10/util/DeadlockDetection.h new file mode 100644 index 0000000000000000000000000000000000000000..ee7ce021e69532c3b7f0e067c301391222928ba1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/DeadlockDetection.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +/// This file provides some simple utilities for detecting common deadlocks in +/// PyTorch. For now, we focus exclusively on detecting Python GIL deadlocks, +/// as the GIL is a wide ranging lock that is taken out in many situations. +/// The basic strategy is before performing an operation that may block, you +/// can use TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() to assert that the GIL is +/// not held. This macro is to be used in contexts where no static dependency +/// on Python is available (we will handle indirecting a virtual call for you). +/// +/// If the GIL is held by a torchdeploy interpreter, we always report false. +/// If you are in a context where Python bindings are available, it's better +/// to directly assert on PyGILState_Check (as it avoids a vcall and also +/// works correctly with torchdeploy.) + +#define TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() \ + TORCH_INTERNAL_ASSERT( \ + !c10::impl::check_python_gil(), \ + "Holding GIL before a blocking operation! Please release the GIL before blocking, or see https://github.com/pytorch/pytorch/issues/56297 for how to release the GIL for destructors of objects") + +namespace c10::impl { + +C10_API bool check_python_gil(); + +struct C10_API PythonGILHooks { + virtual ~PythonGILHooks() = default; + // Returns true if we hold the GIL. If not linked against Python we + // always return false. + virtual bool check_python_gil() const = 0; +}; + +C10_API void SetPythonGILHooks(PythonGILHooks* factory); + +// DO NOT call this registerer from a torch deploy instance! You will clobber +// other registrations +struct C10_API PythonGILHooksRegisterer { + explicit PythonGILHooksRegisterer(PythonGILHooks* factory) { + SetPythonGILHooks(factory); + } + ~PythonGILHooksRegisterer() { + SetPythonGILHooks(nullptr); + } +}; + +} // namespace c10::impl diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Deprecated.h b/MLPY/Lib/site-packages/torch/include/c10/util/Deprecated.h new file mode 100644 index 0000000000000000000000000000000000000000..6242b93ea400702ed47c7ddcaa6f25b325359bd3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Deprecated.h @@ -0,0 +1,102 @@ +#pragma once + +/** + * This file provides portable macros for marking declarations + * as deprecated. You should generally use C10_DEPRECATED, + * except when marking 'using' declarations as deprecated, + * in which case you should use C10_DEFINE_DEPRECATED_USING + * (due to portability concerns). + */ + +// Sample usage: +// +// C10_DEPRECATED void bad_func(); +// struct C10_DEPRECATED BadStruct { +// ... +// }; + +// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses +// the "__declspec(deprecated)" implementation and not the C++14 +// "[[deprecated]]" attribute. 
We tried enabling "[[deprecated]]" for C++14 on +// MSVC, but ran into issues with some older MSVC versions. +#if (defined(__cplusplus) && __cplusplus >= 201402L) +#define C10_DEPRECATED [[deprecated]] +#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] +#elif defined(__GNUC__) +#define C10_DEPRECATED __attribute__((deprecated)) +// TODO Is there some way to implement this? +#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) + +#elif defined(_MSC_VER) +#define C10_DEPRECATED __declspec(deprecated) +#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#else +#warning "You need to implement C10_DEPRECATED for this compiler" +#define C10_DEPRECATED +#endif + +// Sample usage: +// +// C10_DEFINE_DEPRECATED_USING(BadType, int) +// +// which is the portable version of +// +// using BadType [[deprecated]] = int; + +// technically [[deprecated]] syntax is from c++14 standard, but it works in +// many compilers. +#if defined(__has_cpp_attribute) +#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#endif +#endif + +#if defined(_MSC_VER) +#if defined(__CUDACC__) +// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows; +// you get the error: +// +// error: attribute does not apply to any entity +// +// So we just turn the macro off in this case. +#if defined(C10_DEFINE_DEPRECATED_USING) +#undef C10_DEFINE_DEPRECATED_USING +#endif +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#else +// [[deprecated]] does work in windows without nvcc, though msc doesn't support +// `__has_cpp_attribute` when c++14 is supported, otherwise +// __declspec(deprecated) is used as the alternative. +#ifndef C10_DEFINE_DEPRECATED_USING +#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#else +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = __declspec(deprecated) TypeThingy; +#endif +#endif +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__) +// nvcc has a bug where it doesn't understand __attribute__((deprecated)) +// declarations even when the host compiler supports it. We'll only use this gcc +// attribute when not cuda, and when using a GCC compiler that doesn't support +// the c++14 syntax we checked for above (available in __GNUC__ >= 5) +#if !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName __attribute__((deprecated)) = TypeThingy; +#else +// using cuda + gcc < 5, neither deprecated syntax is available so turning off. 
+#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) +#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" +#define C10_DEFINE_DEPRECATED_USING +#endif diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/DimVector.h b/MLPY/Lib/site-packages/torch/include/c10/util/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..0ae8169e86682bf2d084cba8e8eaa9a186faa9c3 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/DimVector.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { + +constexpr size_t kDimVectorStaticSize = C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE; + +/// A container for sizes or strides +using DimVector = SmallVector; +using SymDimVector = SmallVector; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Exception.h b/MLPY/Lib/site-packages/torch/include/c10/util/Exception.h new file mode 100644 index 0000000000000000000000000000000000000000..64cb7351fc47c477f9e12ca31d1ca2e15029d7ef --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Exception.h @@ -0,0 +1,711 @@ +#ifndef C10_UTIL_EXCEPTION_H_ +#define C10_UTIL_EXCEPTION_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +namespace c10 { + +/// The primary ATen error class. +/// Provides a complete error message with source location information via +/// `what()`, and a more concise message via `what_without_backtrace()`. +/// Don't throw this directly; use TORCH_CHECK/TORCH_INTERNAL_ASSERT instead. +/// +/// NB: c10::Error is handled specially by the default torch to suppress the +/// backtrace, see torch/csrc/Exceptions.h +class C10_API Error : public std::exception { + // The actual error message. + std::string msg_; + + // Context for the message (in order of decreasing specificity). Context will + // be automatically formatted appropriately, so it is not necessary to add + // extra leading/trailing newlines to strings inside this vector + std::vector context_; + + // The C++ backtrace at the point when this exception was raised. This + // may be empty if there is no valid backtrace. (We don't use optional + // here to reduce the dependencies this file has.) + std::string backtrace_; + + // These two are derived fields from msg_stack_ and backtrace_, but we need + // fields for the strings so that we can return a const char* (as the + // signature of std::exception requires). Currently, the invariant + // is that these fields are ALWAYS populated consistently with respect + // to msg_stack_ and backtrace_. + std::string what_; + std::string what_without_backtrace_; + + // This is a little debugging trick: you can stash a relevant pointer + // in caller, and then when you catch the exception, you can compare + // against pointers you have on hand to get more information about + // where the exception came from. In Caffe2, this is used to figure + // out which operator raised an exception. + const void* caller_; + + public: + // PyTorch-style Error constructor. 
NB: the implementation of this + // is actually in Logging.cpp + Error(SourceLocation source_location, std::string msg); + + // Caffe2-style error message + Error( + const char* file, + const uint32_t line, + const char* condition, + const std::string& msg, + const std::string& backtrace, + const void* caller = nullptr); + + // Base constructor + Error(std::string msg, std::string backtrace, const void* caller = nullptr); + + // Add some new context to the message stack. The last added context + // will be formatted at the end of the context list upon printing. + // WARNING: This method is O(n) in the size of the stack, so don't go + // wild adding a ridiculous amount of context to error messages. + void add_context(std::string msg); + + const std::string& msg() const { + return msg_; + } + + const std::vector& context() const { + return context_; + } + + const std::string& backtrace() const { + return backtrace_; + } + + /// Returns the complete error message, including the source location. + /// The returned pointer is invalidated if you call add_context() on + /// this object. + const char* what() const noexcept override { + return what_.c_str(); + } + + const void* caller() const noexcept { + return caller_; + } + + /// Returns only the error message string, without source location. + /// The returned pointer is invalidated if you call add_context() on + /// this object. + virtual const char* what_without_backtrace() const noexcept { + return what_without_backtrace_.c_str(); + } + + private: + void refresh_what(); + std::string compute_what(bool include_backtrace) const; +}; + +class C10_API Warning { + public: + class C10_API UserWarning {}; + class C10_API DeprecationWarning {}; + + using warning_variant_t = std::variant; + + Warning( + warning_variant_t type, + const SourceLocation& source_location, + std::string msg, + bool verbatim); + + Warning( + warning_variant_t type, + SourceLocation source_location, + const char* msg, + bool verbatim); + + Warning( + warning_variant_t type, + SourceLocation source_location, + ::c10::detail::CompileTimeEmptyString msg, + bool verbatim); + + // Getters for members + warning_variant_t type() const; + const SourceLocation& source_location() const; + const std::string& msg() const; + bool verbatim() const; + + private: + // The type of warning + warning_variant_t type_; + + // Where the warning happened. + SourceLocation source_location_; + + // The actual warning message. + std::string msg_; + + // See note: [Verbatim Warnings] + bool verbatim_; +}; + +using UserWarning = Warning::UserWarning; +using DeprecationWarning = Warning::DeprecationWarning; + +// Issue a warning with a given message. Dispatched to the current +// warning handler. +void C10_API warn(const Warning& warning); + +class C10_API WarningHandler { + public: + virtual ~WarningHandler() = default; + /// The default warning handler. Prints the message to stderr. + virtual void process(const Warning& warning); +}; + +namespace WarningUtils { + +// Note: [Verbatim Warnings] +// Warnings originating in C++ code can appear out-of-place to Python users: +// a user runs a line in Python, but the warning references a line in C++. +// Some parts of PyTorch, like the JIT, are cognizant of this mismatch +// and take care to map warnings back to the user's program, but most +// of PyTorch simply throws a context-free warning. To allow warning +// handlers to add context where appropriate, warn takes the +// "verbatim" flag. 
When this is false a warning handler might append +// the C++ warning to a Python warning message that relates the warning +// back to the user's program. Callers who have already accounted for +// context in their warnings should set verbatim to true so their warnings +// appear without modification. + +/// Sets the global warning handler. This is not thread-safe, so it should +/// generally be called once during initialization or while holding the GIL +/// for programs that use python. +/// User is responsible for keeping the WarningHandler alive until +/// it is not needed. +C10_API void set_warning_handler(WarningHandler* handler) noexcept(true); +/// Gets the global warning handler. +C10_API WarningHandler* get_warning_handler() noexcept(true); + +class C10_API WarningHandlerGuard { + WarningHandler* prev_handler_; + + public: + WarningHandlerGuard(WarningHandler* new_handler) + : prev_handler_(c10::WarningUtils::get_warning_handler()) { + c10::WarningUtils::set_warning_handler(new_handler); + } + ~WarningHandlerGuard() { + c10::WarningUtils::set_warning_handler(prev_handler_); + } +}; + +/// The TORCH_WARN_ONCE macro is difficult to test for. Use +/// setWarnAlways(true) to turn it into TORCH_WARN, which can be +/// tested for more easily. +C10_API void set_warnAlways(bool) noexcept(true); +C10_API bool get_warnAlways() noexcept(true); + +// A RAII guard that sets warn_always (not thread-local) on +// construction, and sets it back to the original value upon destruction. +struct C10_API WarnAlways { + public: + explicit WarnAlways(bool setting = true); + ~WarnAlways(); + + private: + bool prev_setting; +}; + +} // namespace WarningUtils + +// Like Error, but we always report the C++ backtrace, instead of only +// reporting when TORCH_SHOW_CPP_STACKTRACES +class C10_API ErrorAlwaysShowCppStacktrace : public Error { + using Error::Error; + const char* what_without_backtrace() const noexcept override { + return what(); + } +}; + +// Used in ATen for out-of-bound indices that can reasonably only be detected +// lazily inside a kernel (See: advanced indexing). These turn into +// IndexError when they cross to Python. +class C10_API IndexError : public Error { + using Error::Error; +}; + +// Used in ATen for invalid values. These turn into +// ValueError when they cross to Python. +class C10_API ValueError : public Error { + using Error::Error; +}; + +// Used in ATen for invalid types. These turn into +// TypeError when they cross to Python. +class C10_API TypeError : public Error { + using Error::Error; +}; + +// Used in ATen for functionality that is not implemented. These turn into +// NotImplementedError when they cross to Python. +class C10_API NotImplementedError : public Error { + using Error::Error; +}; + +// Used in ATen for non finite indices. These turn into +// ExitException when they cross to Python. +class C10_API EnforceFiniteError : public Error { + using Error::Error; +}; + +// Used in Onnxifi backend lowering. These turn into +// ExitException when they cross to Python. +class C10_API OnnxfiBackendSystemError : public Error { + using Error::Error; +}; + +// Used for numerical errors from the linalg module. These +// turn into LinAlgError when they cross into Python. +class C10_API LinAlgError : public Error { + using Error::Error; +}; + +class C10_API OutOfMemoryError : public Error { + using Error::Error; +}; + +// Base error type for all distributed errors. +// These turn into DistError when they cross into Python. 
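// Aside on the exception family above (an illustrative sketch, not upstream
// documentation): these subclasses are normally raised through the typed
// TORCH_CHECK_* macros defined later in this header rather than constructed
// directly. Function and argument names here are invented for the example:
//
//   void narrow_sketch(int64_t start, int64_t length) {
//     TORCH_CHECK_VALUE(length >= 0, "length must be non-negative, got ", length);
//     TORCH_CHECK_INDEX(start >= 0, "start index out of range: ", start);
//     // On failure these throw c10::ValueError / c10::IndexError, which the
//     // Python bindings surface as ValueError / IndexError.
//   }
//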
+class C10_API DistError : public Error { + using Error::Error; +}; + +// Used for collective communication library errors from the distributed module. +// These turn into DistBackendError when they cross into Python. +class C10_API DistBackendError : public DistError { + using DistError::DistError; +}; + +// Used for errors originating from the store. +// These turn into DistStoreError when they cross into Python. +class C10_API DistStoreError : public DistError { + using DistError::DistError; +}; + +// Used for errors originating from the TCP/IP stack and not from collective +// libraries. These turn into DistNetworkError when they cross into Python. +class C10_API DistNetworkError : public DistError { + using DistError::DistError; +}; + +// A utility function to return an exception std::string by prepending its +// exception type before its what() content +C10_API std::string GetExceptionString(const std::exception& e); + +} // namespace c10 + +// Private helper macro for implementing TORCH_INTERNAL_ASSERT and TORCH_CHECK +// +// Note: In the debug build With MSVC, __LINE__ might be of long type (a.k.a +// int32_t), which is different from the definition of `SourceLocation` that +// requires unsigned int (a.k.a uint32_t) and may cause a compile error with the +// message: error C2397: conversion from 'long' to 'uint32_t' requires a +// narrowing conversion Here the static cast is used to pass the build. if this +// is used inside a lambda the __func__ macro expands to operator(), which isn't +// very useful, but hard to fix in a macro so suppressing the warning. +#define C10_THROW_ERROR(err_type, msg) \ + throw ::c10::err_type( \ + {__func__, __FILE__, static_cast(__LINE__)}, msg) + +#define C10_BUILD_ERROR(err_type, msg) \ + ::c10::err_type({__func__, __FILE__, static_cast(__LINE__)}, msg) + +// Private helper macro for workaround MSVC misexpansion of nested macro +// invocations involving __VA_ARGS__. See +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define C10_EXPAND_MSVC_WORKAROUND(x) x + +// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases +// where the unlikely expression may be a constant, use this macro to ensure +// return statement analysis keeps working (at the cost of not getting the +// likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY +// in nvcc is causing us perf problems, this is not yet implemented, but this +// might be an interesting piece of C++ code for an intrepid bootcamper to +// write. +#if defined(__CUDACC__) +#define C10_UNLIKELY_OR_CONST(e) e +#else +#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) +#endif + +// ---------------------------------------------------------------------------- +// Error reporting macros +// ---------------------------------------------------------------------------- + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_RETHROW(e, ...) throw +#else +#define TORCH_RETHROW(e, ...) 
\ + do { \ + e.add_context(::c10::str(__VA_ARGS__)); \ + throw; \ + } while (false) +#endif + +// A utility macro to provide assert()-like functionality; that is, enforcement +// of internal invariants in code. It supports an arbitrary number of extra +// arguments (evaluated only on failure), which will be printed in the assert +// failure message using operator<< (this is useful to print some variables +// which may be useful for debugging.) +// +// Usage: +// TORCH_INTERNAL_ASSERT(should_be_true); +// TORCH_INTERNAL_ASSERT(x == 0, "x = ", x); +// +// Assuming no bugs in PyTorch, the conditions tested by this macro should +// always be true; e.g., it should be possible to disable all of these +// conditions without changing observable user behavior. If you would like to +// do error reporting for user input, please use TORCH_CHECK instead. +// +// NOTE: It is SAFE to use this macro in production code; on failure, this +// simply raises an exception, it does NOT unceremoniously quit the process +// (unlike assert()). +// +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_INTERNAL_ASSERT(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + #cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__)); \ + } +#else +// It would be nice if we could build a combined string literal out of +// the TORCH_INTERNAL_ASSERT prefix and a user-provided string literal +// as the first argument, but there doesn't seem to be any good way to +// do that while still supporting having a first argument that isn't a +// string literal. +#define TORCH_INTERNAL_ASSERT(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchInternalAssertFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + #cond \ + " INTERNAL ASSERT FAILED at " C10_STRINGIZE(__FILE__) ":" C10_STRINGIZE( \ + __LINE__) ", please report a bug to PyTorch. ", \ + c10::str(__VA_ARGS__)); \ + } +#endif + +// A utility macro to make it easier to test for error conditions from user +// input. Like TORCH_INTERNAL_ASSERT, it supports an arbitrary number of extra +// arguments (evaluated only on failure), which will be printed in the error +// message using operator<< (e.g., you can pass any object which has +// operator<< defined. Most objects in PyTorch have these definitions!) +// +// Usage: +// TORCH_CHECK(should_be_true); // A default error message will be provided +// // in this case; but we recommend writing an +// // explicit error message, as it is more +// // user friendly. +// TORCH_CHECK(x == 0, "Expected x to be 0, but got ", x); +// +// On failure, this macro will raise an exception. If this exception propagates +// to Python, it will convert into a Python RuntimeError. +// +// NOTE: It is SAFE to use this macro in production code; on failure, this +// simply raises an exception, it does NOT unceremoniously quit the process +// (unlike CHECK() from glog.) +// +#define TORCH_CHECK_WITH(error_t, cond, ...) \ + TORCH_CHECK_WITH_MSG(error_t, cond, "", __VA_ARGS__) + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(Error, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ + } +#else + +namespace c10::detail { +template +decltype(auto) torchCheckMsgImpl(const char* /*msg*/, const Args&... 
args) { + return ::c10::str(args...); +} +inline C10_API const char* torchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline C10_API const char* torchCheckMsgImpl( + const char* /*msg*/, + const char* args) { + return args; +} +} // namespace c10::detail + +#define TORCH_CHECK_MSG(cond, type, ...) \ + (::c10::detail::torchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#define TORCH_CHECK_WITH_MSG(error_t, cond, type, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + C10_THROW_ERROR(error_t, TORCH_CHECK_MSG(cond, type, __VA_ARGS__)); \ + } +#endif + +namespace c10::detail { + +[[noreturn]] C10_API void torchCheckFail( + const char* func, + const char* file, + uint32_t line, + const std::string& msg); +[[noreturn]] C10_API void torchCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg); + +// The c10::str() call that creates userMsg can have 1 of 3 return +// types depending on the number and types of arguments passed to +// TORCH_INTERNAL_ASSERT. 0 arguments will get a +// CompileTimeEmptyString, 1 const char * will be passed straight +// through, and anything else will get converted to std::string. +[[noreturn]] C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + const char* userMsg); +[[noreturn]] inline C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + ::c10::detail::CompileTimeEmptyString /*userMsg*/) { + torchCheckFail(func, file, line, condMsg); +} +[[noreturn]] C10_API void torchInternalAssertFail( + const char* func, + const char* file, + uint32_t line, + const char* condMsg, + const std::string& userMsg); + +} // namespace c10::detail + +#ifdef STRIP_ERROR_MESSAGES +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } +#else +#define TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchCheckFail( \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } +#endif + +// An utility macro that does what `TORCH_CHECK` does if compiled in the host +// code, otherwise does nothing. Supposed to be used in the code shared between +// host and device code as an alternative for `TORCH_CHECK`. +#if defined(__CUDACC__) || defined(__HIPCC__) +#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) +#else +#define TORCH_CHECK_IF_NOT_ON_CUDA(cond, ...) TORCH_CHECK(cond, ##__VA_ARGS__) +#endif + +// Debug only version of TORCH_INTERNAL_ASSERT. This macro only checks in debug +// build, and does nothing in release build. It is appropriate to use +// in situations where you want to add an assert to a hotpath, but it is +// too expensive to run this assert on production builds. +#ifdef NDEBUG +// Optimized version - generates no code. +#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \ + while (false) \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +#else +#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) 
\ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +#endif + +// TODO: We're going to get a lot of similar looking string literals +// this way; check if this actually affects binary size. + +// Like TORCH_CHECK, but raises LinAlgError instead of Error. +#define TORCH_CHECK_LINALG(cond, ...) \ + TORCH_CHECK_WITH_MSG(LinAlgError, cond, "LINALG", __VA_ARGS__) + +// Like TORCH_CHECK, but raises IndexErrors instead of Errors. +#define TORCH_CHECK_INDEX(cond, ...) \ + TORCH_CHECK_WITH_MSG(IndexError, cond, "INDEX", __VA_ARGS__) + +// Like TORCH_CHECK, but raises ValueErrors instead of Errors. +#define TORCH_CHECK_VALUE(cond, ...) \ + TORCH_CHECK_WITH_MSG(ValueError, cond, "VALUE", __VA_ARGS__) + +// Like TORCH_CHECK, but raises TypeErrors instead of Errors. +#define TORCH_CHECK_TYPE(cond, ...) \ + TORCH_CHECK_WITH_MSG(TypeError, cond, "TYPE", __VA_ARGS__) + +// Like TORCH_CHECK, but raises NotImplementedErrors instead of Errors. +#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ + TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__) + +#define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \ + TORCH_CHECK_WITH_MSG( \ + ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__) + +#ifdef STRIP_ERROR_MESSAGES +#define WARNING_MESSAGE_STRING(...) \ + ::c10::detail::CompileTimeEmptyString {} +#else +#define WARNING_MESSAGE_STRING(...) ::c10::str(__VA_ARGS__) +#endif + +// Report a warning to the user. Accepts an arbitrary number of extra +// arguments which are concatenated into the warning message using operator<< +// +#ifdef DISABLE_WARN +#define _TORCH_WARN_WITH(...) ((void)0); +#else +#define _TORCH_WARN_WITH(warning_t, ...) \ + ::c10::warn(::c10::Warning( \ + warning_t(), \ + {__func__, __FILE__, static_cast(__LINE__)}, \ + WARNING_MESSAGE_STRING(__VA_ARGS__), \ + false)); +#endif + +#define TORCH_WARN(...) _TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__); + +#define TORCH_WARN_DEPRECATION(...) \ + _TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__); + +// Report a warning to the user only once. Accepts an arbitrary number of extra +// arguments which are concatenated into the warning message using operator<< +// +#define _TORCH_WARN_ONCE(...) \ + C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \ + [&] { \ + TORCH_WARN(__VA_ARGS__); \ + return true; \ + }() + +#ifdef DISABLE_WARN +#define TORCH_WARN_ONCE(...) ((void)0); +#else +#define TORCH_WARN_ONCE(...) \ + if (::c10::WarningUtils::get_warnAlways()) { \ + TORCH_WARN(__VA_ARGS__); \ + } else { \ + _TORCH_WARN_ONCE(__VA_ARGS__); \ + } +#endif + +// Report an error with a specific argument +// NOTE: using the argument name in TORCH_CHECK's message is preferred +#define TORCH_CHECK_ARG(cond, argN, ...) \ + TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) + +// ---------------------------------------------------------------------------- +// Deprecated macros +// ---------------------------------------------------------------------------- + +namespace c10::detail { + +/* +// Deprecation disabled until we fix sites in our codebase +C10_DEPRECATED_MESSAGE("AT_ERROR(msg) is deprecated, use TORCH_CHECK(false, msg) +instead.") +*/ +inline void deprecated_AT_ERROR() {} + +/* +// Deprecation disabled until we fix sites in our codebase +C10_DEPRECATED_MESSAGE("AT_ASSERT is deprecated, if you mean to indicate an +internal invariant failure, use " \ + "TORCH_INTERNAL_ASSERT instead; if you mean to do user +error checking, use " \ "TORCH_CHECK. 
See +https://github.com/pytorch/pytorch/issues/20287 for more details.") +*/ +inline void deprecated_AT_ASSERT() {} + +/* +// Deprecation disabled until we fix sites in our codebase +C10_DEPRECATED_MESSAGE("AT_ASSERTM is deprecated, if you mean to indicate an +internal invariant failure, use " \ + "TORCH_INTERNAL_ASSERT instead; if you mean to do user +error checking, use " \ "TORCH_CHECK. See +https://github.com/pytorch/pytorch/issues/20287 for more details.") +*/ +inline void deprecated_AT_ASSERTM() {} + +} // namespace c10::detail + +// Deprecated alias; this alias was deprecated because people kept mistakenly +// using it for user error checking. Use TORCH_INTERNAL_ASSERT or TORCH_CHECK +// instead. See https://github.com/pytorch/pytorch/issues/20287 for more +// details. +#define AT_ASSERT(...) \ + do { \ + ::c10::detail::deprecated_AT_ASSERT(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)); \ + } while (false) + +// Deprecated alias, like AT_ASSERT. The new TORCH_INTERNAL_ASSERT macro +// supports both 0-ary and variadic calls, so having a separate +// message-accepting macro is not necessary. +// +// NB: we MUST include cond explicitly here, as MSVC will miscompile the macro +// expansion, shunting all of __VA_ARGS__ to cond. An alternate workaround +// can be seen at +// https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly +#define AT_ASSERTM(cond, ...) \ + do { \ + ::c10::detail::deprecated_AT_ASSERTM(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \ + } while (false) + +// Deprecated alias; this alias was deprecated because it represents extra API +// surface that makes it hard for people to understand what macro to use. +// Use TORCH_CHECK(false, ...) or TORCH_INTERNAL_ASSERT(false, ...) to +// unconditionally fail at a line of code. +#define AT_ERROR(...) \ + do { \ + ::c10::detail::deprecated_AT_ERROR(); \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \ + } while (false) + +#endif // C10_UTIL_EXCEPTION_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwned.h b/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwned.h new file mode 100644 index 0000000000000000000000000000000000000000..62a7cca47da91c98494b0006dea29d153746ecbe --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwned.h @@ -0,0 +1,140 @@ +#pragma once + +#include + +namespace c10 { + +// See example implementation in TensorBase.h and TensorBody.h. +// Synopsis: +// +// repr_type -- type to use to store an owned T in ExclusivelyOwned. +// +// pointer_type -- pointer-esque type to return from +// ExclusivelyOwned's get() and operator*() methods. +// +// const_pointer_type -- similar to pointer_type, used for the const methods. +// +// static repr_type nullRepr() -- return a null instance of repr_type. +// +// template +// static repr_type createInPlace(Args&&... args) -- used by the in-place +// ExclusivelyOwned constructor. +// +// static repr_type moveToRepr(T&& x) -- move the given x into an +// instance of repr_type. used by the ExclusivelyOwned(T&&) +// constructor. +// +// static void destroyOwned(repr_type x) -- free memory for a +// known-exclusively-owned instance of x. Replaces calling repr_type's +// destructor. Being able to implement this more efficiently than +// repr_type's destructor is the main reason to use ExclusivelyOwned +// for a type. +// +// static T take(repr_type&) -- move out of the given repr_type into an owned T. 
+// +// static pointer_type getImpl(const repr_type&) -- return a pointer +// to the given repr_type. May take repr_type by value if that is more +// efficient. +template +struct ExclusivelyOwnedTraits; + +/// ExclusivelyOwned is a smart-pointer-like wrapper around an +/// exclusively-owned instance of some type T that normally has +/// mandatory reference counting (currently just Tensor). If you have +/// an isolated piece of code that knows that it has sole ownership of +/// an object of one of these types (i.e., because you created it +/// directly or using a factory function) and that object will not +/// escape from that isolated piece of code, then moving the object +/// into an ExclusivelyOwned will avoid an atomic reference count +/// decrement at destruction time. +/// +/// If you directly create the Tensor in the first +/// place, you can use the in_place constructor of ExclusivelyOwned to +/// additionally avoid doing any stores to initialize the refcount & +/// weakcount. +template +class ExclusivelyOwned { + using EOT = ExclusivelyOwnedTraits; + typename ExclusivelyOwnedTraits::repr_type repr_; + + public: + ExclusivelyOwned() : repr_(EOT::nullRepr()) {} + + explicit ExclusivelyOwned(T&& t) : repr_(EOT::moveToRepr(std::move(t))) {} + + template + explicit ExclusivelyOwned(std::in_place_t, Args&&... args) + : repr_(EOT::createInPlace(std::forward(args)...)) {} + + ExclusivelyOwned(const ExclusivelyOwned&) = delete; + + ExclusivelyOwned(ExclusivelyOwned&& rhs) noexcept + : repr_(std::move(rhs.repr_)) { + rhs.repr_ = EOT::nullRepr(); + } + + ExclusivelyOwned& operator=(const ExclusivelyOwned&) = delete; + + ExclusivelyOwned& operator=(ExclusivelyOwned&& rhs) noexcept { + EOT::destroyOwned(repr_); + repr_ = std::move(rhs.repr_); + rhs.repr_ = EOT::nullRepr(); + return *this; + } + + ExclusivelyOwned& operator=(T&& rhs) noexcept { + EOT::destroyOwned(repr_); + repr_ = EOT::moveToRepr(std::move(rhs)); + return *this; + } + + ~ExclusivelyOwned() { + EOT::destroyOwned(repr_); + // Don't bother to call the destructor of repr_, since we already + // did specialized destruction for the exclusively-owned case in + // destroyOwned! + } + + // We don't provide this because it would require us to be able to + // differentiate an owned-but-empty T from a lack of T. This is + // particularly problematic for Tensor, which wants to use an + // undefined Tensor as its null state. + explicit operator bool() const noexcept = delete; + + operator T() && { + return take(); + } + + // NOTE: the equivalent operation on MaybeOwned is a moving + // operator*. For ExclusivelyOwned, take() and operator*() may well + // have different return types, so they are different functions. 
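+  //
+  // A minimal usage sketch (illustrative only, not part of this header;
+  // `at::zeros` and the tensor shape are assumptions for the example):
+  //
+  //   at::Tensor t = at::zeros({8});
+  //   c10::ExclusivelyOwned<at::Tensor> owned(std::move(t));
+  //   owned->add_(1);                             // pointer-like access
+  //   at::Tensor back = std::move(owned).take();  // move back out
+  //
+  // Compared with holding a plain Tensor, the ExclusivelyOwned destructor
+  // skips the atomic refcount decrement, as described above.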
+ T take() && { + return EOT::take(repr_); + } + + typename EOT::const_pointer_type operator->() const { + return get(); + } + + typename EOT::const_pointer_type get() const { + return EOT::getImpl(repr_); + } + + typename EOT::pointer_type operator->() { + return get(); + } + + typename EOT::pointer_type get() { + return EOT::getImpl(repr_); + } + + std::remove_pointer_t& operator*() const { + return *get(); + } + + std::remove_pointer_t& operator*() { + return *get(); + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h b/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..4d61440a7b4b1b6f8a38e12de0799822c6d4223b --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include + +#include + +namespace c10 { +// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and +// at::TensorBase. +template +struct ExclusivelyOwnedTensorTraits { + using repr_type = TensorType; + using pointer_type = TensorType*; + using const_pointer_type = const TensorType*; + + static repr_type nullRepr() { + return TensorType(); + } + + template + static repr_type createInPlace(Args&&... args) { + return TensorType(std::forward(args)...); + } + + static repr_type moveToRepr(TensorType&& x) { + return std::move(x); + } + + static void destroyOwned(TensorType& x) { + TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); + // May be 0 because UndefinedTensorImpl doesn't get its refcount + // incremented. + const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and refcount ", + toDestroy->refcount_, + ", expected 1 or, if isUndefined, 0!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->weakcount_ == 1 || + (toDestroy->weakcount_ == 0 && + toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and weakcount ", + toDestroy->weakcount_, + ", expected 1 or, if isUndefined, 0!"); + if (!isUndefined) { +#ifndef NDEBUG + // Needed to pass the debug assertions in ~intrusive_ptr_target. + toDestroy->refcount_ = 0; + toDestroy->weakcount_ = 0; +#endif + delete toDestroy; + } + } + + static TensorType take(TensorType& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/FbcodeMaps.h b/MLPY/Lib/site-packages/torch/include/c10/util/FbcodeMaps.h new file mode 100644 index 0000000000000000000000000000000000000000..3b8abdbcfbd99deb2842bd00bd9dddcd6b2713e1 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/FbcodeMaps.h @@ -0,0 +1,29 @@ +#ifndef C10_UTIL_FBCODEMAPS_H_ +#define C10_UTIL_FBCODEMAPS_H_ + +// Map typedefs so that we can use folly's F14 maps in fbcode without +// taking a folly dependency. 
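+//
+// Usage sketch (illustrative): these aliases resolve to either folly's F14
+// containers or the std:: equivalents, and basic usage is the same either
+// way:
+//
+//   c10::FastMap<std::string, int> op_counts;
+//   op_counts["aten::add"] += 1;
+//   c10::FastSet<int64_t> seen_ids;
+//   seen_ids.insert(42);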
+ +#ifdef FBCODE_CAFFE2 +#include +#include +#else +#include +#include +#endif + +namespace c10 { +#ifdef FBCODE_CAFFE2 +template +using FastMap = folly::F14FastMap; +template +using FastSet = folly::F14FastSet; +#else +template +using FastMap = std::unordered_map; +template +using FastSet = std::unordered_set; +#endif +} // namespace c10 + +#endif // C10_UTIL_FBCODEMAPS_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Flags.h b/MLPY/Lib/site-packages/torch/include/c10/util/Flags.h new file mode 100644 index 0000000000000000000000000000000000000000..6fef9972125def4a273bfe16a583bc51d4ffd2ce --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Flags.h @@ -0,0 +1,226 @@ +#ifndef C10_UTIL_FLAGS_H_ +#define C10_UTIL_FLAGS_H_ + +/* Commandline flags support for C10. + * + * This is a portable commandline flags tool for c10, so we can optionally + * choose to use gflags or a lightweight custom implementation if gflags is + * not possible on a certain platform. If you have gflags installed, set the + * macro C10_USE_GFLAGS will seamlessly route everything to gflags. + * + * To define a flag foo of type bool default to true, do the following in the + * *global* namespace: + * C10_DEFINE_bool(foo, true, "An example."); + * + * To use it in another .cc file, you can use C10_DECLARE_* as follows: + * C10_DECLARE_bool(foo); + * + * In both cases, you can then access the flag via FLAGS_foo. + * + * It is recommended that you build with gflags. To learn more about the flags + * usage, refer to the gflags page here: + * + * https://gflags.github.io/gflags/ + * + * Note about Python users / devs: gflags is initiated from a C++ function + * ParseCommandLineFlags, and is usually done in native binaries in the main + * function. As Python does not have a modifiable main function, it is usually + * difficult to change the flags after Python starts. Hence, it is recommended + * that one sets the default value of the flags to one that's acceptable in + * general - that will allow Python to run without wrong flags. + */ + +#include +#include + +#include + +namespace c10 { +/** + * Sets the usage message when a commandline tool is called with "--help". + */ +C10_API void SetUsageMessage(const std::string& str); + +/** + * Returns the usage message for the commandline tool set by SetUsageMessage. + */ +C10_API const char* UsageMessage(); + +/** + * Parses the commandline flags. + * + * This command parses all the commandline arguments passed in via pargc + * and argv. Once it is finished, partc and argv will contain the remaining + * commandline args that c10 does not deal with. Note that following + * convention, argv[0] contains the binary name and is not parsed. + */ +C10_API bool ParseCommandLineFlags(int* pargc, char*** pargv); + +/** + * Checks if the commandline flags has already been passed. + */ +C10_API bool CommandLineFlagsHasBeenParsed(); + +} // namespace c10 + +//////////////////////////////////////////////////////////////////////////////// +// Below are gflags and non-gflags specific implementations. +// In general, they define the following macros for one to declare (use +// C10_DECLARE) or define (use C10_DEFINE) flags: +// C10_{DECLARE,DEFINE}_{int,int64,double,bool,string} +//////////////////////////////////////////////////////////////////////////////// + +#ifdef C10_USE_GFLAGS + +//////////////////////////////////////////////////////////////////////////////// +// Begin gflags section: most functions are basically rerouted to gflags. 
+//////////////////////////////////////////////////////////////////////////////// +#include + +// C10 uses hidden visibility by default. However, in gflags, it only uses +// export on Windows platform (with dllexport) but not on linux/mac (with +// default visibility). As a result, to ensure that we are always exporting +// global variables, we will redefine the GFLAGS_DLL_DEFINE_FLAG macro if we +// are building C10 as a shared library. +// This has to be done after the inclusion of gflags, because some early +// versions of gflags.h (e.g. 2.0 on ubuntu 14.04) directly defines the +// macros, so we need to do definition after gflags is done. +#ifdef GFLAGS_DLL_DEFINE_FLAG +#undef GFLAGS_DLL_DEFINE_FLAG +#endif // GFLAGS_DLL_DEFINE_FLAG +#ifdef GFLAGS_DLL_DECLARE_FLAG +#undef GFLAGS_DLL_DECLARE_FLAG +#endif // GFLAGS_DLL_DECLARE_FLAG +#define GFLAGS_DLL_DEFINE_FLAG C10_EXPORT +#define GFLAGS_DLL_DECLARE_FLAG C10_IMPORT + +// gflags before 2.0 uses namespace google and after 2.1 uses namespace gflags. +// Using GFLAGS_GFLAGS_H_ to capture this change. +#ifndef GFLAGS_GFLAGS_H_ +namespace gflags = google; +#endif // GFLAGS_GFLAGS_H_ + +// Motivation about the gflags wrapper: +// (1) We would need to make sure that the gflags version and the non-gflags +// version of C10 are going to expose the same flags abstraction. One should +// explicitly use FLAGS_flag_name to access the flags. +// (2) For flag names, it is recommended to start with c10_ to distinguish it +// from regular gflags flags. For example, do +// C10_DEFINE_BOOL(c10_my_flag, true, "An example"); +// to allow one to use FLAGS_c10_my_flag. +// (3) Gflags has a design issue that does not properly expose the global flags, +// if one builds the library with -fvisibility=hidden. The current gflags (as of +// Aug 2018) only deals with the Windows case using dllexport, and not the Linux +// counterparts. As a result, we will explicitly use C10_EXPORT to export the +// flags defined in C10. This is done via a global reference, so the flag +// itself is not duplicated - under the hood it is the same global gflags flag. +#define C10_GFLAGS_DEF_WRAPPER(type, real_type, name, default_value, help_str) \ + DEFINE_##type(name, default_value, help_str); + +#define C10_DEFINE_int(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(int32, gflags::int32, name, default_value, help_str) +#define C10_DEFINE_int32(name, default_value, help_str) \ + C10_DEFINE_int(name, default_value, help_str) +#define C10_DEFINE_int64(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(int64, gflags::int64, name, default_value, help_str) +#define C10_DEFINE_double(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(double, double, name, default_value, help_str) +#define C10_DEFINE_bool(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(bool, bool, name, default_value, help_str) +#define C10_DEFINE_string(name, default_value, help_str) \ + C10_GFLAGS_DEF_WRAPPER(string, ::fLS::clstring, name, default_value, help_str) + +// DECLARE_typed_var should be used in header files and in the global namespace. 
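+// Usage sketch (illustrative; the flag name `c10_example_verbose` is
+// hypothetical). Define the flag once, at global scope, in a .cc file:
+//
+//   C10_DEFINE_bool(c10_example_verbose, false, "Enable verbose output.");
+//
+// Declare it wherever else it is needed and read it via FLAGS_<name>:
+//
+//   C10_DECLARE_bool(c10_example_verbose);
+//   if (FLAGS_c10_example_verbose) { /* ... */ }
+//
+// In a native binary, parse the command line once near the top of main():
+//
+//   int main(int argc, char** argv) {
+//     c10::SetUsageMessage("example tool");
+//     c10::ParseCommandLineFlags(&argc, &argv);
+//     // ...
+//   }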
+#define C10_GFLAGS_DECLARE_WRAPPER(type, real_type, name) DECLARE_##type(name); + +#define C10_DECLARE_int(name) \ + C10_GFLAGS_DECLARE_WRAPPER(int32, gflags::int32, name) +#define C10_DECLARE_int32(name) C10_DECLARE_int(name) +#define C10_DECLARE_int64(name) \ + C10_GFLAGS_DECLARE_WRAPPER(int64, gflags::int64, name) +#define C10_DECLARE_double(name) \ + C10_GFLAGS_DECLARE_WRAPPER(double, double, name) +#define C10_DECLARE_bool(name) C10_GFLAGS_DECLARE_WRAPPER(bool, bool, name) +#define C10_DECLARE_string(name) \ + C10_GFLAGS_DECLARE_WRAPPER(string, ::fLS::clstring, name) + +//////////////////////////////////////////////////////////////////////////////// +// End gflags section. +//////////////////////////////////////////////////////////////////////////////// + +#else // C10_USE_GFLAGS + +//////////////////////////////////////////////////////////////////////////////// +// Begin non-gflags section: providing equivalent functionality. +//////////////////////////////////////////////////////////////////////////////// + +namespace c10 { + +class C10_API C10FlagParser { + public: + bool success() { + return success_; + } + + protected: + template + bool Parse(const std::string& content, T* value); + bool success_{false}; +}; + +C10_DECLARE_REGISTRY(C10FlagsRegistry, C10FlagParser, const std::string&); + +} // namespace c10 + +// The macros are defined outside the c10 namespace. In your code, you should +// write the C10_DEFINE_* and C10_DECLARE_* macros outside any namespace +// as well. + +#define C10_DEFINE_typed_var(type, name, default_value, help_str) \ + C10_EXPORT type FLAGS_##name = default_value; \ + namespace c10 { \ + namespace { \ + class C10FlagParser_##name : public C10FlagParser { \ + public: \ + explicit C10FlagParser_##name(const std::string& content) { \ + success_ = C10FlagParser::Parse(content, &FLAGS_##name); \ + } \ + }; \ + } \ + RegistererC10FlagsRegistry g_C10FlagsRegistry_##name( \ + #name, \ + C10FlagsRegistry(), \ + RegistererC10FlagsRegistry::DefaultCreator, \ + "(" #type ", default " #default_value ") " help_str); \ + } + +#define C10_DEFINE_int(name, default_value, help_str) \ + C10_DEFINE_typed_var(int, name, default_value, help_str) +#define C10_DEFINE_int32(name, default_value, help_str) \ + C10_DEFINE_int(name, default_value, help_str) +#define C10_DEFINE_int64(name, default_value, help_str) \ + C10_DEFINE_typed_var(int64_t, name, default_value, help_str) +#define C10_DEFINE_double(name, default_value, help_str) \ + C10_DEFINE_typed_var(double, name, default_value, help_str) +#define C10_DEFINE_bool(name, default_value, help_str) \ + C10_DEFINE_typed_var(bool, name, default_value, help_str) +#define C10_DEFINE_string(name, default_value, help_str) \ + C10_DEFINE_typed_var(std::string, name, default_value, help_str) + +// DECLARE_typed_var should be used in header files and in the global namespace. +#define C10_DECLARE_typed_var(type, name) C10_API extern type FLAGS_##name + +#define C10_DECLARE_int(name) C10_DECLARE_typed_var(int, name) +#define C10_DECLARE_int32(name) C10_DECLARE_int(name) +#define C10_DECLARE_int64(name) C10_DECLARE_typed_var(int64_t, name) +#define C10_DECLARE_double(name) C10_DECLARE_typed_var(double, name) +#define C10_DECLARE_bool(name) C10_DECLARE_typed_var(bool, name) +#define C10_DECLARE_string(name) C10_DECLARE_typed_var(std::string, name) + +//////////////////////////////////////////////////////////////////////////////// +// End non-gflags section. 
+//////////////////////////////////////////////////////////////////////////////// + +#endif // C10_USE_GFLAGS + +#endif // C10_UTIL_FLAGS_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..902a7c0a577ec3799a474bcad26e0755257c3869 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) + : x(detail::fp8e4m3fn_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const { + return detail::fp8e4m3fn_to_fp32_value(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const { + return (x & 0b01111111) == 0b01111111; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fn +operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator/( + const Float8_e4m3fn& a, + const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator+=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator-=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator*=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator/=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) { + return a -= static_cast(b); +} 
+inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fn to float. 
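+///
+/// Illustrative sketch (host-side example, not part of the API surface):
+///
+///   c10::Float8_e4m3fn a(0.5f), b(0.25f);
+///   bool gt = a > b;      // both operands convert to float implicitly
+///   float s = a + 0.1f;   // mixed arithmetic is carried out in float32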
+ +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fn min() { + return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn lowest() { + return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn max() { + return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn epsilon() { + return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn round_error() { + return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn quiet_NaN() { + return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn denorm_min() { + return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn.h new file mode 100644 index 0000000000000000000000000000000000000000..86034ccef3f5ebae6d3ab7fb7796326f117160ff --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fn.h @@ -0,0 +1,246 @@ +#pragma once + +/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// bias = 7 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include +#include +#include +#include + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // operator typeid + +namespace c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. 
+ * + * @note The implementation doesn't use any floating-point operations. + */ +inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E4M3FN number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)input << 24; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(nonsign); +#elif defined(__SYCL_DEVICE_ONLY__) + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#endif + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + /* + * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, + * the addition overflows it into bit 31, and the subsequent shift turns the + * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number + * is Nan, 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 4 so the exponent (4 bits originally) + * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. 
Add 0x78 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0x07 + * for fp8e4m3fn number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x78, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { + /* + * Binary representation of 480.0f, which is the first value + * not representable in fp8e4m3fn range: + * 0 1111 111 - fp8e4m3fn + * 0 10000111 11100000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fn normal range + * into denorm representation + * magic number: ((127 - 7) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(141) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = 0x7f; + } else { + if (f_bits < (UINT32_C(121) << 23)) { + // Input number is smaller than 2^(-6), which is the smallest + // fp8e4m3fn normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
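+      // After the re-bias above, the 4-bit fp8 exponent occupies bits 23-26
+      // and the 3 retained mantissa bits occupy bits 20-22, so shifting right
+      // by 20 and truncating to 8 bits yields the packed E4M3 payload (the
+      // sign bit is OR-ed back in below).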
+ result = static_cast(f_bits >> 20); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fn { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fn() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) + : x(bits){}; + inline C10_HOST_DEVICE Float8_e4m3fn(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value); + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..4c54c3b7d7d04b222a6498f15cb97d076d1bf890 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) + : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<4, 3>(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { + return x == 0b10000000; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float 
operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) { + return 
static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fnuz to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fnuz min() { + return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz lowest() { + return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz max() { + return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz epsilon() { + return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz round_error() { + return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz infinity() { + // NaN (no infinities) + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz quiet_NaN() { + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz denorm_min() { + return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h new file mode 100644 index 0000000000000000000000000000000000000000..c329024b81d43f8a6fcc21ca3952514b527d6cac --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e4m3fnuz.h @@ -0,0 +1,136 @@ +#pragma once + +/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as Float8_e4m3fn: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// The key differences versus Float8_e4m3fn are: +/// bias = 8 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. 
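+///
+/// Illustrative behaviour sketch (host-side example; the values follow from
+/// the conversion routine below):
+///
+///   c10::Float8_e4m3fnuz z(-0.0f);      // no negative zero: stored as +0
+///   c10::Float8_e4m3fnuz big(1000.0f);  // above the max of 240: NaN (0x80)
+///   bool nan = big.isnan();             // true; the format has no infinities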
+ +#include +#include +#include +#include + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { + /* + * Binary representation of 256.0f, which is the first value not representable + * (i.e. the first value which would overflow in to the sign bit, resulting in + * a NaN) in fp8e4m3fnuz range: + * 1 0000 000 - fp8e4m3fnuz + * 0 10000111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range + * into denorm representation + * magic number: ((127 - 8) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s. + return 0x80; + } + + if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { + // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
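+    // Same packing as the e4m3fn variant, but with the fnuz bias of 8 already
+    // folded into the exponent adjustment above: bits 20-26 now hold the
+    // mantissa and 4-bit exponent, so the payload is the low 8 bits of
+    // (f_bits >> 20); the sign is OR-ed in below.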
+ result = static_cast(f_bits >> 20); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) + : x(bits){}; + inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fnuz& value); + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..4d242247823f43ee4b1a0536cdff8bf599a00143 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2-inl.h @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#define EXP_WIDTH_FP8 5 +#define MAN_WIDTH_FP8 2 +#define EXP_BIAS_FP8 15 + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) + : x(detail::fp8e5m2_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2::operator float() const { + return detail::fp8e5m2_to_fp32_value(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const { + return (x & 0b01111111) > 0b01111100; +} + +inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const { + return (x & 0b01111111) == 0b01111100; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2 +operator+(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator-(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator*(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator/( + const Float8_e5m2& a, + const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2& operator+=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator-=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator*=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator/=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + 
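+// Illustrative note (host-side sketch): unlike Float8_e4m3fn, the e5m2 format
+// reserves encodings for infinities, so finite inputs above the maximum
+// representable value (57344) convert to +/-infinity rather than NaN:
+//
+//   c10::Float8_e5m2 big(1.0e6f);
+//   bool inf = big.isinf();  // true
+//   bool nan = big.isnan();  // false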
+inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define 
comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2 to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2 min() { + return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 max() { + return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 lowest() { + return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 epsilon() { + return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 round_error() { + return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 infinity() { + return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 denorm_min() { + return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2.h new file mode 100644 index 0000000000000000000000000000000000000000..da2eec186a535dd51bae3670a73f04fed0e93c47 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2.h @@ -0,0 +1,143 @@ +#pragma once + +/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// bias = 15 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. 
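+ *
+ * The conversion below exploits the fact that fp8 E5M2 shares its exponent
+ * width (5 bits) and bias (15) with IEEE binary16: shifting the byte into the
+ * upper 8 bits of a 16-bit word produces a valid fp16 encoding of the same
+ * value (including denormals, infinities and NaNs), which is then widened
+ * with the existing fp16 -> fp32 helper.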
+ */ +inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E5M2 number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 26-30 24-25 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + uint16_t half_representation = input; + half_representation <<= 8; + return fp16_ieee_to_fp32_value(half_representation); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { + /* + * Binary representation of fp32 infinity + * 0 11111111 00000000000000000000000 + */ + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + + /* + * Binary representation of 65536.0f, which is the first value + * not representable in fp8e5m2 range: + * 0 11111 00 - fp8e5m2 + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2 normal range + * into denorm representation + * magic number: ((127 - 15) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); + } else { + if (f_bits < (UINT32_C(113) << 23)) { + // Input number is smaller than 2^(-14), which is the smallest + // fp8e5m2 normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint32_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! 
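+      // E5M2 keeps only 2 mantissa bits, so the packed payload starts one bit
+      // higher than in the E4M3 case: bits 21-22 hold the mantissa and bits
+      // 23-27 the 5-bit exponent, hence the shift by 21 below.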
+ result = static_cast(f_bits >> 21); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2 { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +C10_API std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value); + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..300107f14c05b8ebc30ceba2b0cfea856ed2409d --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h @@ -0,0 +1,280 @@ +#pragma once + +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) + : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<5, 2>(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { + return x == 0b10000000; +} + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz 
a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) { + return 
static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2fnuz to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2fnuz min() { + return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz max() { + return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz lowest() { + return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz epsilon() { + return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz round_error() { + return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz infinity() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz denorm_min() { + return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h new file mode 100644 index 0000000000000000000000000000000000000000..b7d8e25ab059e844d4d1c1c724e7c5fe088cca67 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_e5m2fnuz.h @@ -0,0 +1,135 @@ +#pragma once + +/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as e5m2: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// The key differences that e5m2fnuz brings are: +/// bias = 16 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. 
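+///
+/// Illustrative usage (a sketch only; it assumes nothing beyond the
+/// declarations in this header and the operators in Float8_e5m2fnuz-inl.h):
+///
+///   c10::Float8_e5m2fnuz v(0.5f);   // constructed by rounding a float32
+///   float f = v * 2.0f;             // arithmetic is carried out in float32
+///   auto nan = c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
+///   bool is_nan = nan.isnan();      // 0x80 is the single NaN encoding
+///   bool no_inf = nan.isinf();      // always false: fnuz has no infinities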
+ +#include +#include +#include + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { + /* + * Binary representation of 65536.0f, which is the first value not + * representable (i.e. the first value which would overflow in to the sign + * bit, resulting in a NaN) in fp8e4m3fnuz range: + * 1 00000 00 - fp8e5m2fnuz + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range + * into denormalized representation. + * magic number: ((127 - 16) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s + return 0x80; + } + + if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { + // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +C10_API std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2fnuz& value); + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Float8_fnuz_cvt.h b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_fnuz_cvt.h new file mode 100644 index 0000000000000000000000000000000000000000..1abf3f1c4122120b29882b0596df529720577546 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Float8_fnuz_cvt.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include + +namespace c10::detail { + +/* + * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ + * format, in bit representation, to a 32-bit floating-point number. 
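+ *
+ * The template parameters we and wm used below are the exponent and mantissa
+ * widths of the source format (4 and 3 for f8 E4M3FNUZ, 5 and 2 for bf8
+ * E5M2FNUZ), which is what the static_assert at the top of the function
+ * enforces.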
+ */ +template +inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { + static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); + constexpr uint32_t weo = 8; + constexpr uint32_t wmo = 23; + + if (x == 0) { + return 0; + } + + if (x == 0x80) { + constexpr uint32_t ifNaN = 0x7F800001; + return fp32_from_bits(ifNaN); + } + + uint32_t mantissa = x & ((1 << wm) - 1); + uint32_t exponent = (x & 0x7F) >> wm; + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(mantissa); +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(mantissa); +#endif + uint32_t sh = 1 + renorm_shift - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + + const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + uint32_t sign = x >> 7; + uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; + return fp32_from_bits(retval); +} + +} // namespace c10::detail diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/FunctionRef.h b/MLPY/Lib/site-packages/torch/include/c10/util/FunctionRef.h new file mode 100644 index 0000000000000000000000000000000000000000..aac5e11d99e13bd33431e25e68158a08a8de1223 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/FunctionRef.h @@ -0,0 +1,73 @@ +//===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some templates that are useful if you are working with the +// STL at all. +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// + +// c10: modified from llvm::function_ref +// c10: added more SFINAE to enable use in overloaded functions + +#pragma once + +#include +#include +#include + +namespace c10 { + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a function_ref. +template +class function_ref; + +template +class function_ref { + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable{}; + + template + static Ret callback_fn(intptr_t callable, Params... params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); + } + + public: + function_ref() = default; + function_ref(std::nullptr_t) {} + + template + function_ref( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + Callable&& callable, + std::enable_if_t< + !std::is_same_v, function_ref>>* = + nullptr, + std::enable_if_t, + Ret>>* = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + Ret operator()(Params... 
params) const { + return callback(callable, std::forward(params)...); + } + + operator bool() const { + return callback; + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Half-inl.h b/MLPY/Lib/site-packages/torch/include/c10/util/Half-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..14194d035739f05808e5b7071df456639ea05306 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Half-inl.h @@ -0,0 +1,350 @@ +#pragma once + +#include +#include + +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include +#endif + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +/// Constructors +inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} +inline Half::operator float16_t() const { + return detail::fp16_from_bits(x); +} +#else + +inline C10_HOST_DEVICE Half::Half(float value) + : +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x(__half_as_short(__float2half(value))) +#elif defined(__SYCL_DEVICE_ONLY__) + x(c10::bit_cast(sycl::half(value))) +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + x(at::vec::float2half_scalar(value)) +#else + x(detail::fp16_ieee_from_fp32_value(value)) +#endif +{ +} + +/// Implicit conversions + +inline C10_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) + return float(c10::bit_cast(x)); +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + return at::vec::half2float_scalar(x); +#elif defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + return detail::native_fp16_to_fp32_value(x); +#else + return detail::fp16_ieee_to_fp32_value(x); +#endif +} + +#endif /* !defined(__aarch64__) || defined(C10_MOBILE) || defined(__CUDACC__) \ + */ + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast(&x); +} +#endif + +#ifdef SYCL_LANGUAGE_VERSION +inline C10_HOST_DEVICE Half::Half(const sycl::half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator sycl::half() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ + (defined(__clang__) && defined(__CUDA__)) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast(ptr)); +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Half 
operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); +#elif defined(__SYCL_DEVICE_ONLY__) + return -c10::bit_cast(a); +#else + return -static_cast(a); +#endif +} + +inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Half a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Half a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Half a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Half a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Half a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Half a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Half operator+(Half a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int a, Half b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int a, Half b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int a, Half b) { + return static_cast(a) * b; +} 
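+// (Illustrative: given Half h(1.5f), the expression h + 2 widens the int
+// operand to float, computes the sum in float32, and narrows back to Half,
+// as with the other mixed-type overloads in this file.)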
+inline C10_HOST_DEVICE Half operator/(int a, Half b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Half to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr c10::Half min() { + return c10::Half(0x0400, c10::Half::from_bits()); + } + static constexpr c10::Half lowest() { + return c10::Half(0xFBFF, c10::Half::from_bits()); + } + static constexpr c10::Half max() { + return c10::Half(0x7BFF, c10::Half::from_bits()); + } + static constexpr c10::Half epsilon() { + return c10::Half(0x1400, c10::Half::from_bits()); + } + static constexpr c10::Half round_error() { + return c10::Half(0x3800, c10::Half::from_bits()); + } + static constexpr c10::Half infinity() { + return c10::Half(0x7C00, c10::Half::from_bits()); + } + static constexpr c10::Half quiet_NaN() { + return c10::Half(0x7E00, c10::Half::from_bits()); + } + static constexpr c10::Half signaling_NaN() { + return c10::Half(0x7D00, c10::Half::from_bits()); + } + static constexpr c10::Half denorm_min() { + return c10::Half(0x0001, c10::Half::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Half.h b/MLPY/Lib/site-packages/torch/include/c10/util/Half.h new file mode 100644 index 0000000000000000000000000000000000000000..4fee5505ca1dcea55bb6851a66ceaf06e978b159 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Half.h @@ -0,0 +1,538 @@ +#pragma once + +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C 
types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. + +#include +#include +#include +#include +#include +#include + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#include +#endif + +namespace c10 { + +namespace detail { + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. 
Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. 
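+   *
+   * Worked example (illustrative): h = 0x3C00 (fp16 1.0) gives w = 0x3C000000
+   * and two_w = 0x78000000; since two_w is above the denormalized cutoff, the
+   * normalized path below computes fp32_from_bits((two_w >> 4) + exp_offset) *
+   * exp_scale = 2^112 * 2^-112 = 1.0f, and the sign bit (0 here) is OR'ed back
+   * into the result.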
+ */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. + */ + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val = 0; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). 
A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructed single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. + */ + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val = 0, scale_to_zero_val = 0; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy( + &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign)); +} + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +constexpr inline float16_t fp16_from_bits(uint16_t h) { + union { + uint16_t as_bits; + float16_t as_value; + } fp16 = {h}; + return fp16.as_value; +} + +constexpr inline uint16_t fp16_to_bits(float16_t f) { + union { + float16_t as_value; + uint16_t as_bits; + } fp16 = {.as_value = f}; + return fp16.as_bits; +} + +// According to https://godbolt.org/z/8s14GvEjo it would translate to single +// fcvt s0, h0 +inline float native_fp16_to_fp32_value(uint16_t h) { + return static_cast(fp16_from_bits(h)); +} + +inline uint16_t native_fp16_from_fp32_value(float f) { + return fp16_to_bits(static_cast(f)); +} +#endif + +} // namespace detail + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + C10_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {} +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) + inline Half(float16_t value); + inline operator float16_t() const; +#else + inline C10_HOST_DEVICE Half(float value); + inline C10_HOST_DEVICE operator float() const; +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_HOST_DEVICE Half(const __half& value); + inline C10_HOST_DEVICE operator __half() const; +#endif +#ifdef SYCL_LANGUAGE_VERSION + inline C10_HOST_DEVICE Half(const sycl::half& value); + inline C10_HOST_DEVICE operator sycl::half() const; +#endif +}; + +// TODO : move to complex.h +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) + : real_(real), imag_(imag) {} + C10_HOST_DEVICE inline complex(const c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline C10_HOST_DEVICE operator c10::complex() const { + return {real_, imag_}; + } + + constexpr C10_HOST_DEVICE Half real() const { + return real_; + } + constexpr C10_HOST_DEVICE Half imag() const { + return imag_; + } + + C10_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +// In some versions of MSVC, there will be a compiler error when building. +// C4146: unary minus operator applied to unsigned type, result still unsigned +// C4804: unsafe use of type 'bool' in operation +// It can be addressed by disabling the following warning. 
+#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4146) +#pragma warning(disable : 4804) +#pragma warning(disable : 4018) +#endif + +// The overflow checks may involve float to int conversion which may +// trigger precision loss warning. Re-enable the warning once the code +// is fixed. See T58053069. +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +// bool can be converted to any type. +// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: +// `error: comparison of constant '255' with boolean expression is always false` +// for `f > limit::max()` below +template +std::enable_if_t, bool> overflows( + From /*f*/, + bool strict_unsigned = false) { + return false; +} + +// skip isnan and isinf check for integral types +template +std::enable_if_t && !std::is_same_v, bool> +overflows(From f, bool strict_unsigned = false) { + using limit = std::numeric_limits::type>; + if constexpr (!limit::is_signed && std::numeric_limits::is_signed) { + // allow for negative numbers to wrap using two's complement arithmetic. + // For example, with uint8, this allows for `a - b` to be treated as + // `a + 255 * b`. + if (!strict_unsigned) { + return greater_than_max(f) || + (c10::is_negative(f) && + -static_cast(f) > static_cast(limit::max())); + } + } + return c10::less_than_lowest(f) || greater_than_max(f); +} + +template +std::enable_if_t, bool> overflows( + From f, + bool strict_unsigned = false) { + using limit = std::numeric_limits::type>; + if (limit::has_infinity && std::isinf(static_cast(f))) { + return false; + } + if (!limit::has_quiet_NaN && (f != f)) { + return true; + } + return f < limit::lowest() || f > limit::max(); +} + +C10_CLANG_DIAGNOSTIC_POP() + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template +std::enable_if_t::value, bool> overflows( + From f, + bool strict_unsigned = false) { + // casts from complex to real are considered to overflow if the + // imaginary component is non-zero + if (!is_complex::value && f.imag() != 0) { + return true; + } + // Check for overflow componentwise + // (Technically, the imag overflow check is guaranteed to be false + // when !is_complex, but any optimizer worth its salt will be + // able to figure it out.) + return overflows< + typename scalar_value_type::type, + typename From::value_type>(f.real()) || + overflows< + typename scalar_value_type::type, + typename From::value_type>(f.imag()); +} + +C10_API std::ostream& operator<<(std::ostream& out, const Half& value); + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/IdWrapper.h b/MLPY/Lib/site-packages/torch/include/c10/util/IdWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..086f456fc27ab1866a5c4ad9f9e651eed91ada68 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/IdWrapper.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * This template simplifies generation of simple classes that wrap an id + * in a typesafe way. Namely, you can use it to create a very lightweight + * type that only offers equality comparators and hashing. 
Example: + * + * struct MyIdType final : IdWrapper { + * constexpr explicit MyIdType(uint32_t id): IdWrapper(id) {} + * }; + * + * Then in the global top level namespace: + * + * C10_DEFINE_HASH_FOR_IDWRAPPER(MyIdType); + * + * That's it - equality operators and hash functions are automatically defined + * for you, given the underlying type supports it. + */ +template +class IdWrapper { + public: + using underlying_type = UnderlyingType; + using concrete_type = ConcreteType; + + protected: + constexpr explicit IdWrapper(underlying_type id) noexcept( + noexcept(underlying_type(std::declval()))) + : id_(id) {} + + constexpr underlying_type underlyingId() const + noexcept(noexcept(underlying_type(std::declval()))) { + return id_; + } + + private: + friend size_t hash_value(const concrete_type& v) { + return std::hash()(v.id_); + } + + // TODO Making operator== noexcept if underlying type is noexcept equality + // comparable doesn't work with GCC 4.8. + // Fix this once we don't need GCC 4.8 anymore. + friend constexpr bool operator==( + const concrete_type& lhs, + const concrete_type& rhs) noexcept { + return lhs.id_ == rhs.id_; + } + + // TODO Making operator!= noexcept if operator== is noexcept doesn't work with + // GCC 4.8. + // Fix this once we don't need GCC 4.8 anymore. + friend constexpr bool operator!=( + const concrete_type& lhs, + const concrete_type& rhs) noexcept { + return !(lhs == rhs); + } + + underlying_type id_; +}; + +} // namespace c10 + +#define C10_DEFINE_HASH_FOR_IDWRAPPER(ClassName) \ + namespace std { \ + template <> \ + struct hash { \ + size_t operator()(ClassName x) const { \ + return hash_value(x); \ + } \ + }; \ + } diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/LeftRight.h b/MLPY/Lib/site-packages/torch/include/c10/util/LeftRight.h new file mode 100644 index 0000000000000000000000000000000000000000..a6c09ac98d964277dfe414d6d8c04cb2ac0e28d6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/LeftRight.h @@ -0,0 +1,223 @@ +#include +#include +#include +#include +#include +#include + +namespace c10 { + +namespace detail { + +struct IncrementRAII final { + public: + explicit IncrementRAII(std::atomic* counter) : _counter(counter) { + _counter->fetch_add(1); + } + + ~IncrementRAII() { + _counter->fetch_sub(1); + } + + private: + std::atomic* _counter; + + C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); +}; + +} // namespace detail + +// LeftRight wait-free readers synchronization primitive +// https://hal.archives-ouvertes.fr/hal-01207881/document +// +// LeftRight is quite easy to use (it can make an arbitrary +// data structure permit wait-free reads), but it has some +// particular performance characteristics you should be aware +// of if you're deciding to use it: +// +// - Reads still incur an atomic write (this is how LeftRight +// keeps track of how long it needs to keep around the old +// data structure) +// +// - Writes get executed twice, to keep both the left and right +// versions up to date. So if your write is expensive or +// nondeterministic, this is also an inappropriate structure +// +// LeftRight is used fairly rarely in PyTorch's codebase. If you +// are still not sure if you need it or not, consult your local +// C++ expert. +// +template +class LeftRight final { + public: + template + explicit LeftRight(const Args&... args) + : _counters{{{0}, {0}}}, + _foregroundCounterIndex(0), + _foregroundDataIndex(0), + _data{{T{args...}, T{args...}}}, + _writeMutex() {} + + // Copying and moving would not be threadsafe. 
+ // Needs more thought and careful design to make that work. + LeftRight(const LeftRight&) = delete; + LeftRight(LeftRight&&) noexcept = delete; + LeftRight& operator=(const LeftRight&) = delete; + LeftRight& operator=(LeftRight&&) noexcept = delete; + + ~LeftRight() { + // wait until any potentially running writers are finished + { std::unique_lock lock(_writeMutex); } + + // wait until any potentially running readers are finished + while (_counters[0].load() != 0 || _counters[1].load() != 0) { + std::this_thread::yield(); + } + } + + template + auto read(F&& readFunc) const { + detail::IncrementRAII _increment_counter( + &_counters[_foregroundCounterIndex.load()]); + + return std::forward(readFunc)(_data[_foregroundDataIndex.load()]); + } + + // Throwing an exception in writeFunc is ok but causes the state to be either + // the old or the new state, depending on if the first or the second call to + // writeFunc threw. + template + auto write(F&& writeFunc) { + std::unique_lock lock(_writeMutex); + + return _write(std::forward(writeFunc)); + } + + private: + template + auto _write(const F& writeFunc) { + /* + * Assume, A is in background and B in foreground. In simplified terms, we + * want to do the following: + * 1. Write to A (old background) + * 2. Switch A/B + * 3. Write to B (new background) + * + * More detailed algorithm (explanations on why this is important are below + * in code): + * 1. Write to A + * 2. Switch A/B data pointers + * 3. Wait until A counter is zero + * 4. Switch A/B counters + * 5. Wait until B counter is zero + * 6. Write to B + */ + + auto localDataIndex = _foregroundDataIndex.load(); + + // 1. Write to A + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + + // 2. Switch A/B data pointers + localDataIndex = localDataIndex ^ 1; + _foregroundDataIndex = localDataIndex; + + /* + * 3. Wait until A counter is zero + * + * In the previous write run, A was foreground and B was background. + * There was a time after switching _foregroundDataIndex (B to foreground) + * and before switching _foregroundCounterIndex, in which new readers could + * have read B but incremented A's counter. + * + * In this current run, we just switched _foregroundDataIndex (A back to + * foreground), but before writing to the new background B, we have to make + * sure A's counter was zero briefly, so all these old readers are gone. + */ + auto localCounterIndex = _foregroundCounterIndex.load(); + _waitForBackgroundCounterToBeZero(localCounterIndex); + + /* + * 4. Switch A/B counters + * + * Now that we know all readers on B are really gone, we can switch the + * counters and have new readers increment A's counter again, which is the + * correct counter since they're reading A. + */ + localCounterIndex = localCounterIndex ^ 1; + _foregroundCounterIndex = localCounterIndex; + + /* + * 5. Wait until B counter is zero + * + * This waits for all the readers on B that came in while both data and + * counter for B was in foreground, i.e. normal readers that happened + * outside of that brief gap between switching data and counter. + */ + _waitForBackgroundCounterToBeZero(localCounterIndex); + + // 6. Write to B + return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + } + + template + auto _callWriteFuncOnBackgroundInstance( + const F& writeFunc, + uint8_t localDataIndex) { + try { + return writeFunc(_data[localDataIndex ^ 1]); + } catch (...) 
{ + // recover invariant by copying from the foreground instance + _data[localDataIndex ^ 1] = _data[localDataIndex]; + // rethrow + throw; + } + } + + void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { + while (_counters[counterIndex ^ 1].load() != 0) { + std::this_thread::yield(); + } + } + + mutable std::array, 2> _counters; + std::atomic _foregroundCounterIndex; + std::atomic _foregroundDataIndex; + std::array _data; + std::mutex _writeMutex; +}; + +// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a +// read-write lock to protect T (data). +template +class RWSafeLeftRightWrapper final { + public: + template + explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} + + // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight + // is not copyable or moveable. + RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; + RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto read(F&& readFunc) const { + return data_.withLock( + [&readFunc](T const& data) { return std::forward(readFunc)(data); }); + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto write(F&& writeFunc) { + return data_.withLock( + [&writeFunc](T& data) { return std::forward(writeFunc)(data); }); + } + + private: + c10::Synchronized data_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Load.h b/MLPY/Lib/site-packages/torch/include/c10/util/Load.h new file mode 100644 index 0000000000000000000000000000000000000000..3aec348dbee66b80e2f89f3cf980dde08af2f773 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Load.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include + +namespace c10 { +namespace detail { + +template +struct LoadImpl { + C10_HOST_DEVICE static T apply(const void* src) { + return *reinterpret_cast(src); + } +}; + +template <> +struct LoadImpl { + C10_HOST_DEVICE static bool apply(const void* src) { + static_assert(sizeof(bool) == sizeof(char)); + // NOTE: [Loading boolean values] + // Protect against invalid boolean values by loading as a byte + // first, then converting to bool (see gh-54789). + return *reinterpret_cast(src); + } +}; + +} // namespace detail + +template +C10_HOST_DEVICE T load(const void* src) { + return c10::detail::LoadImpl::apply(src); +} + +template +C10_HOST_DEVICE scalar_t load(const scalar_t* src) { + return c10::detail::LoadImpl::apply(src); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Logging.h b/MLPY/Lib/site-packages/torch/include/c10/util/Logging.h new file mode 100644 index 0000000000000000000000000000000000000000..4c83b26cb451d53ae004876f2244030ce673fed2 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Logging.h @@ -0,0 +1,340 @@ +#ifndef C10_UTIL_LOGGING_H_ +#define C10_UTIL_LOGGING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// CAFFE2_LOG_THRESHOLD is a compile time flag that would allow us to turn off +// logging at compile time so no logging message below that level is produced +// at all. The value should be between INT_MIN and CAFFE_FATAL. +#ifndef CAFFE2_LOG_THRESHOLD +// If we have not defined the compile time log threshold, we keep all the +// log cases. 
+#define CAFFE2_LOG_THRESHOLD INT_MIN +#endif // CAFFE2_LOG_THRESHOLD + +// Below are different implementations for glog and non-glog cases. +#ifdef C10_USE_GLOG +#include +#else // !C10_USE_GLOG +#include +#endif // C10_USE_GLOG + +C10_DECLARE_int(caffe2_log_level); +C10_DECLARE_bool(caffe2_use_fatal_for_enforce); + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS. If it's +// not available - just short-circuit to the always working one one. +// We define the C10_ name to avoid confusing other files +#ifdef LOG_EVERY_MS +#define C10_LOG_EVERY_MS(severity, ms) LOG_EVERY_MS(severity, ms) +#else +#define C10_LOG_EVERY_MS(severity, ms) LOG(severity) +#endif + +// Same for LOG_FIRST_N +#ifdef LOG_FIRST_N +#define C10_LOG_FIRST_N(severity, n) LOG_FIRST_N(severity, n) +#else +#define C10_LOG_FIRST_N(severity, n) LOG(severity) +#endif + +// Same for LOG_EVERY_N +#ifdef LOG_EVERY_N +#define C10_LOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n) +#else +#define C10_LOG_EVERY_N(severity, n) LOG(severity) +#endif + +namespace c10 { + +using std::string; + +// Functions that we use for initialization. +C10_API bool InitCaffeLogging(int* argc, char** argv); +C10_API void UpdateLoggingLevelsFromFlags(); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] C10_API inline void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceNotMet(file, line, condition, "", caller); +} + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] C10_API inline void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceFiniteNotMet(file, line, condition, "", caller); +} + +constexpr bool IsUsingGoogleLogging() { +#ifdef C10_USE_GLOG + return true; +#else + return false; +#endif +} + +/** + * A utility to allow one to show log info to stderr after the program starts. + * + * This is similar to calling GLOG's --logtostderr, or setting caffe2_log_level + * to smaller than INFO. You are recommended to only use this in a few sparse + * cases, such as when you want to write a tutorial or something. Normally, use + * the commandline flags to set the log level. + */ +C10_API void ShowLogInfoToStderr(); + +C10_API void SetStackTraceFetcher(std::function fetcher); + +using EnforceNotMet = ::c10::Error; + +#define CAFFE_ENFORCE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_FINITE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceFiniteNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_WITH_CALLER(condition, ...) 
\ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__), this); \ + } \ + } while (false) + +#define CAFFE_THROW(...) \ + ::c10::ThrowEnforceNotMet(__FILE__, __LINE__, "", ::c10::str(__VA_ARGS__)) + +/** + * Rich logging messages + * + * CAFFE_ENFORCE_THAT can be used with one of the "checker functions" that + * capture input argument values and add it to the exception message. E.g. + * `CAFFE_ENFORCE_THAT(Equals(foo(x), bar(y)), "Optional additional message")` + * would evaluate both foo and bar only once and if the results are not equal - + * include them in the exception message. + * + * Some of the basic checker functions like Equals or Greater are already + * defined below. Other header might define customized checkers by adding + * functions to caffe2::enforce_detail namespace. For example: + * + * namespace caffe2 { namespace enforce_detail { + * inline EnforceFailMessage IsVector(const vector& shape) { + * if (shape.size() == 1) { return EnforceOK(); } + * return c10::str("Shape ", shape, " is not a vector"); + * } + * }} + * + * With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))` + * + * Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided + * too. Please use them instead of TORCH_CHECK_EQ and friends for failures in + * user-provided input. + */ + +namespace enforce_detail { + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y) { + return c10::str(x, " vs ", y); +} + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y, const Args&... args) { + return c10::str(x, " vs ", y, ". ", args...); +} + +template +void enforceThatImpl( + Pred p, + const T1& lhs, + const T2& rhs, + const char* file, + int line, + const char* expr, + const void* caller, + GetFailMsgFunc getFailMsg) { + if (C10_UNLIKELY(!(p(lhs, rhs)))) { + ::c10::ThrowEnforceNotMet(file, line, expr, getFailMsg(lhs, rhs), caller); + } +} + +#define CAFFE_ENFORCE_THAT_IMPL(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + nullptr, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +#define CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + this, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +} // namespace enforce_detail + +#define CAFFE_ENFORCE_THAT(cmp, op, lhs, rhs, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, lhs, rhs, #lhs " " #op " " #rhs, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT(x, y, ...) 
\ + CAFFE_ENFORCE_BINARY_OP(std::greater(), >, x, y, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER( \ + cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater(), >, x, y, ##__VA_ARGS__) + +/** + * Very lightweight logging for the first time API usage. It's beneficial for + * tracking of individual functionality usage in larger applications. + * + * In order to ensure light-weightedness of logging, we utilize static variable + * trick - LogAPIUsage will be invoked only once and further invocations will + * just do an atomic check. + * + * Example: + * // Logs caller info with an arbitrary text event, if there is a usage. + * C10_LOG_API_USAGE_ONCE("my_api"); + */ +#define C10_LOG_API_USAGE_ONCE(...) \ + C10_UNUSED static bool C10_ANONYMOUS_VARIABLE(logFlag) = \ + ::c10::detail::LogAPIUsageFakeReturn(__VA_ARGS__); + +// API usage logging capabilities +C10_API void SetAPIUsageLogger(std::function logger); +C10_API void LogAPIUsage(const std::string& context); + +C10_API void SetAPIUsageMetadataLogger( + std::function& metadata_map)> logger); +C10_API void LogAPIUsageMetadata( + const std::string& context, + const std::map& metadata_map); + +// PyTorch ddp usage logging capabilities +// DDPLoggingData holds data that can be logged in applications +// for analysis and debugging. Data structure is defined in +// c10 directory so that it can be easily imported by both c10 +// and torch files. +struct DDPLoggingData { + // logging fields that are string types. + std::map strs_map; + // logging fields that are int64_t types. + std::map ints_map; +}; + +C10_API void SetPyTorchDDPUsageLogger( + std::function logger); +C10_API void LogPyTorchDDPUsage(const DDPLoggingData& ddpData); + +namespace detail { +// Return value is needed to do the static variable initialization trick +C10_API bool LogAPIUsageFakeReturn(const std::string& context); +} // namespace detail + +// Initializes the c10 logger. 
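+// Illustrative usage sketch (not part of the upstream header) for the enforce
+// macros defined above; `check_batch`, `batch_size` and `expected` are
+// hypothetical names used only for illustration:
+//
+//   #include <c10/util/Logging.h>
+//
+//   void check_batch(int64_t batch_size, int64_t expected) {
+//     CAFFE_ENFORCE_GT(batch_size, 0, "batch_size must be positive");
+//     CAFFE_ENFORCE_EQ(batch_size, expected, "unexpected batch size");
+//   }
+//
+//   // Failures throw c10::Error (aka EnforceNotMet) and can be caught:
+//   // try { check_batch(0, 8); } catch (const c10::Error& e) { /* ... */ }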
+C10_API void initLogging(); + +// Sets the rank, which will be included in log messages +C10_API void SetGlobalRank(int64_t rank); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/MathConstants.h b/MLPY/Lib/site-packages/torch/include/c10/util/MathConstants.h new file mode 100644 index 0000000000000000000000000000000000000000..ecdf9f34a945b6acb73614a35c637e80da4a6fba --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/MathConstants.h @@ -0,0 +1,142 @@ +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace c10 { +// TODO: Replace me with inline constexpr variable when C++17 becomes available +namespace detail { +template +C10_HOST_DEVICE inline constexpr T e() { + return static_cast(2.718281828459045235360287471352662); +} + +template +C10_HOST_DEVICE inline constexpr T euler() { + return static_cast(0.577215664901532860606512090082402); +} + +template +C10_HOST_DEVICE inline constexpr T frac_1_pi() { + return static_cast(0.318309886183790671537767526745028); +} + +template +C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() { + return static_cast(0.564189583547756286948079451560772); +} + +template +C10_HOST_DEVICE inline constexpr T frac_sqrt_2() { + return static_cast(0.707106781186547524400844362104849); +} + +template +C10_HOST_DEVICE inline constexpr T frac_sqrt_3() { + return static_cast(0.577350269189625764509148780501957); +} + +template +C10_HOST_DEVICE inline constexpr T golden_ratio() { + return static_cast(1.618033988749894848204586834365638); +} + +template +C10_HOST_DEVICE inline constexpr T ln_10() { + return static_cast(2.302585092994045684017991454684364); +} + +template +C10_HOST_DEVICE inline constexpr T ln_2() { + return static_cast(0.693147180559945309417232121458176); +} + +template +C10_HOST_DEVICE inline constexpr T log_10_e() { + return static_cast(0.434294481903251827651128918916605); +} + +template +C10_HOST_DEVICE inline constexpr T log_2_e() { + return static_cast(1.442695040888963407359924681001892); +} + +template +C10_HOST_DEVICE inline constexpr T pi() { + return static_cast(3.141592653589793238462643383279502); +} + +template +C10_HOST_DEVICE inline constexpr T sqrt_2() { + return static_cast(1.414213562373095048801688724209698); +} + +template +C10_HOST_DEVICE inline constexpr T sqrt_3() { + return static_cast(1.732050807568877293527446341505872); +} + +template <> +C10_HOST_DEVICE inline constexpr BFloat16 pi() { + // According to + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values + // pi is encoded as 4049 + return BFloat16(0x4049, BFloat16::from_bits()); +} + +template <> +C10_HOST_DEVICE inline constexpr Half pi() { + return Half(0x4248, Half::from_bits()); +} +} // namespace detail + +template +constexpr T e = c10::detail::e(); + +template +constexpr T euler = c10::detail::euler(); + +template +constexpr T frac_1_pi = c10::detail::frac_1_pi(); + +template +constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi(); + +template +constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2(); + +template +constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3(); + +template +constexpr T golden_ratio = c10::detail::golden_ratio(); + +template +constexpr T ln_10 = c10::detail::ln_10(); + +template +constexpr T ln_2 = c10::detail::ln_2(); + +template +constexpr T log_10_e = c10::detail::log_10_e(); + 
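+// Illustrative usage sketch (not part of the upstream header): the variable
+// templates in this file are read at the desired precision, e.g.
+//
+//   #include <c10/util/MathConstants.h>
+//
+//   float circumference(float r) {
+//     return 2.0f * c10::pi<float> * r;
+//   }
+//
+// `circumference` and `r` are hypothetical names used only for illustration.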
+template +constexpr T log_2_e = c10::detail::log_2_e(); + +template +constexpr T pi = c10::detail::pi(); + +template +constexpr T sqrt_2 = c10::detail::sqrt_2(); + +template +constexpr T sqrt_3 = c10::detail::sqrt_3(); +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/MaybeOwned.h b/MLPY/Lib/site-packages/torch/include/c10/util/MaybeOwned.h new file mode 100644 index 0000000000000000000000000000000000000000..074ae1070adb2a5699561bab078c5973e013ac38 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/MaybeOwned.h @@ -0,0 +1,237 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +/// MaybeOwnedTraits describes how to borrow from T. Here is how we +/// can implement borrowing from an arbitrary type T using a raw +/// pointer to const: +template +struct MaybeOwnedTraitsGenericImpl { + using owned_type = T; + using borrow_type = const T*; + + static borrow_type createBorrow(const owned_type& from) { + return &from; + } + + static void assignBorrow(borrow_type& lhs, borrow_type rhs) { + lhs = rhs; + } + + static void destroyBorrow(borrow_type& /*toDestroy*/) {} + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return *borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static bool debugBorrowIsValid(const borrow_type& borrow) { + return borrow != nullptr; + } +}; + +/// It is possible to eliminate the extra layer of indirection for +/// borrows for some types that we control. For examples, see +/// intrusive_ptr.h and TensorBody.h. + +template +struct MaybeOwnedTraits; + +// Explicitly enable MaybeOwned>, rather than allowing +// MaybeOwned to be used for any type right away. +template +struct MaybeOwnedTraits> + : public MaybeOwnedTraitsGenericImpl> {}; + +/// A smart pointer around either a borrowed or owned T. When +/// constructed with borrowed(), the caller MUST ensure that the +/// borrowed-from argument outlives this MaybeOwned. Compare to +/// Rust's std::borrow::Cow +/// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note +/// that it is probably not suitable for general use because C++ has +/// no borrow checking. Included here to support +/// Tensor::expect_contiguous. +template +class MaybeOwned final { + using borrow_type = typename MaybeOwnedTraits::borrow_type; + using owned_type = typename MaybeOwnedTraits::owned_type; + + bool isBorrowed_; + union { + borrow_type borrow_; + owned_type own_; + }; + + /// Don't use this; use borrowed() instead. + explicit MaybeOwned(const owned_type& t) + : isBorrowed_(true), borrow_(MaybeOwnedTraits::createBorrow(t)) {} + + /// Don't use this; use owned() instead. + explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v) + : isBorrowed_(false), own_(std::move(t)) {} + + /// Don't use this; use owned() instead. + template + explicit MaybeOwned(std::in_place_t, Args&&... args) + : isBorrowed_(false), own_(std::forward(args)...) {} + + public: + explicit MaybeOwned() : isBorrowed_(true), borrow_() {} + + // Copying a borrow yields another borrow of the original, as with a + // T*. Copying an owned T yields another owned T for safety: no + // chains of borrowing by default! (Note you could get that behavior + // with MaybeOwned::borrowed(*rhs) if you wanted it.) 
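+  // Illustrative usage sketch (not part of the upstream header); `s`,
+  // `borrowed`, `owned` and `n` are names used only for illustration:
+  //
+  //   std::string s = "hello";
+  //   auto borrowed = c10::MaybeOwned<std::string>::borrowed(s);  // `s` must outlive `borrowed`
+  //   auto owned = c10::MaybeOwned<std::string>::owned(std::string("world"));
+  //   size_t n = borrowed->size() + owned->size();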
+ MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(rhs.own_); + } + } + + MaybeOwned& operator=(const MaybeOwned& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = rhs.own_; + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(rhs.own_); + isBorrowed_ = false; + } + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_); + return *this; + } + + MaybeOwned(MaybeOwned&& rhs) noexcept( + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_move_constructible_v && + std::is_nothrow_move_assignable_v) + : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(std::move(rhs.own_)); + } + } + + MaybeOwned& operator=(MaybeOwned&& rhs) noexcept( + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v && + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = std::move(rhs.own_); + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(std::move(rhs.own_)); + isBorrowed_ = false; + } + } + return *this; + } + + static MaybeOwned borrowed(const T& t) { + return MaybeOwned(t); + } + + static MaybeOwned owned(T&& t) noexcept( + std::is_nothrow_move_constructible_v) { + return MaybeOwned(std::move(t)); + } + + template + static MaybeOwned owned(std::in_place_t, Args&&... args) { + return MaybeOwned(std::in_place, std::forward(args)...); + } + + ~MaybeOwned() noexcept( + // NOLINTNEXTLINE(*-noexcept-destructor) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (C10_UNLIKELY(!isBorrowed_)) { + own_.~T(); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + } + } + + // This is an implementation detail! You should know what you're doing + // if you are testing this. If you just want to guarantee ownership move + // this into a T + bool unsafeIsBorrowed() const { + return isBorrowed_; + } + + const T& operator*() const& { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::referenceFromBorrow(borrow_) + : own_; + } + + const T* operator->() const { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::pointerFromBorrow(borrow_) + : &own_; + } + + // If borrowed, copy the underlying T. If owned, move from + // it. borrowed/owned state remains the same, and either we + // reference the same borrow as before or we are an owned moved-from + // T. 
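+  // Illustrative sketch (not part of the upstream header): for an owned
+  // MaybeOwned<std::string> `mo`, `std::string taken = *std::move(mo);` moves
+  // the string out, while for a borrowed one it yields a copy and leaves the
+  // borrowed-from object untouched. `mo` and `taken` are illustrative names.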
+ T operator*() && { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + return MaybeOwnedTraits::referenceFromBorrow(borrow_); + } else { + return std::move(own_); + } + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Metaprogramming.h b/MLPY/Lib/site-packages/torch/include/c10/util/Metaprogramming.h new file mode 100644 index 0000000000000000000000000000000000000000..0e47b356cd6cdcc3c7d4bc5811553fa371b2be5e --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Metaprogramming.h @@ -0,0 +1,224 @@ +#pragma once + +#include +#include + +namespace c10::guts { + +/** + * Access information about result type or arguments from a function type. + * Example: + * using A = function_traits::return_type // A == int + * using A = function_traits::parameter_types::tuple_type + * // A == tuple + */ +template +struct function_traits { + static_assert( + !std::is_same_v, + "In function_traits, Func must be a plain function type."); +}; +template +struct function_traits { + using func_type = Result(Args...); + using return_type = Result; + using parameter_types = typelist::typelist; + static constexpr auto number_of_parameters = sizeof...(Args); +}; + +/** + * infer_function_traits: creates a `function_traits` type for a simple + * function (pointer) or functor (lambda/struct). Currently does not support + * class methods. + */ + +template +struct infer_function_traits { + using type = function_traits< + c10::guts::detail::strip_class_t>; +}; + +template +struct infer_function_traits { + using type = function_traits; +}; + +template +struct infer_function_traits { + using type = function_traits; +}; + +template +using infer_function_traits_t = typename infer_function_traits::type; + +/** + * make_function_traits: creates a `function_traits` type given a Return type + * and a typelist of Argument types + * + * Example: + * bool f(int, int); + * + * infer_function_traits_t == make_function_traits_t> + */ +template +struct make_function_traits { + static_assert( + false_t::value, + "In guts::make_function_traits, the ArgList argument must be typelist<...>."); +}; + +template +struct make_function_traits> { + using type = function_traits; +}; + +template +using make_function_traits_t = + typename make_function_traits::type; + +/** + * make_offset_index_sequence + * Like make_index_sequence, but starting from Start instead of 0. + * + * Example: + * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> + */ +template +struct make_offset_index_sequence_impl + : make_offset_index_sequence_impl { + static_assert( + static_cast(Start) >= 0, + "make_offset_index_sequence: Start < 0"); + static_assert(static_cast(N) >= 0, "make_offset_index_sequence: N < 0"); +}; + +template +struct make_offset_index_sequence_impl { + typedef std::index_sequence type; +}; + +template +using make_offset_index_sequence = + typename make_offset_index_sequence_impl::type; + +/** + * Use tuple_elements to extract a position-indexed subset of elements + * from the argument tuple into a result tuple. + * + * Example: + * std::tuple t = std::make_tuple(0, "HEY", 2.0); + * std::tuple result = tuple_elements(t, std::index_sequence<0, + * 2>()); + */ +template +constexpr auto tuple_elements(Tuple t, std::index_sequence) { + return std::tuple...>(std::get(t)...); +} + +/** + * Use tuple_take to extract the first or last n elements from the argument + * tuple into a result tuple. 
+ * + * Example: + * std::tuple t = std::make_tuple(0, "HEY", 2.0); + * std::tuple first_two = tuple_take(t); + * std::tuple last_two = tuple_take(t); + */ +template +struct TupleTake {}; + +template +struct TupleTake= 0, void>> { + static auto call(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(N <= size, "tuple_take: N > size"); + return tuple_elements(t, std::make_index_sequence{}); + } +}; + +template + struct TupleTake < Tuple, + N, std::enable_if_t> { + static auto call(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(-N <= size, "tuple_take: -N > size"); + return tuple_elements(t, make_offset_index_sequence{}); + } +}; + +template +auto tuple_take(Tuple t) { + return TupleTake::call(t); +} + +/** + * Use tuple_slice to extract a contiguous subtuple from the argument. + * + * Example: + * std::tuple t = std::make_tuple(0, + * "HEY", 2.0, false); std::tuple middle_two = + * tuple_slice(t); + */ +template +constexpr auto tuple_slice(Tuple t) { + constexpr size_t size = std::tuple_size(); + static_assert(Start + N <= size, "tuple_slice: Start + N > size"); + return tuple_elements(t, make_offset_index_sequence{}); +} + +/** + * Use tuple_map to run a mapping function over a tuple to get a new tuple. + * + * Example 1: + * auto result = tuple_map(std::tuple(3, 4, 5), [] + * (int32_t a) -> int16_t {return a+1;}); + * // result == std::tuple(4, 5, 6) + * + * Example 2: + * struct Mapper { + * std::string operator()(int32_t a) const { + * return std::to_string(a); + * } + * int64_t operator()(const std::string& a) const { + * return atoi(a.c_str()); + * } + * }; + * auto result = tuple_map(std::tuple(3, "4"), + * Mapper()); + * // result == std::tuple("3", 4) + * + * Example 3: + * struct A final { + * int32_t func() { + * return 5; + * } + * }; + * struct B final { + * std::string func() { + * return "5"; + * } + * }; + * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return + * a.func(); }); + * // result == std::tuple(5, "5"); + */ +namespace detail { +template +auto tuple_map( + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + std::tuple&& tuple, + const Mapper& mapper, + std::index_sequence) { + return std::tuple(std::get( + tuple))))...>(mapper(std::forward(std::get(tuple)))...); +} +} // namespace detail + +template +auto tuple_map(std::tuple&& tuple, const Mapper& mapper) { + return detail::tuple_map( + std::move(tuple), mapper, std::index_sequence_for()); +} + +} // namespace c10::guts diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Optional.h b/MLPY/Lib/site-packages/torch/include/c10/util/Optional.h new file mode 100644 index 0000000000000000000000000000000000000000..3229bef1315c505d4b7c6a1c5bcbb13b618cfb63 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Optional.h @@ -0,0 +1,48 @@ +#ifndef C10_UTIL_OPTIONAL_H_ +#define C10_UTIL_OPTIONAL_H_ + +#include +#include + +// Macros.h is not needed, but it does namespace shenanigans that lots +// of downstream code seems to rely on. Feel free to remove it and fix +// up builds. 
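+// Illustrative usage sketch (not part of the upstream header) for
+// c10::value_or_else, defined below: the fallback callable runs only when the
+// optional is empty. `maybe` and `v` are hypothetical names.
+//
+//   c10::optional<int> maybe = c10::nullopt;
+//   int v = c10::value_or_else(maybe, [] { return 42; });  // v == 42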
+ +namespace c10 { +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::bad_optional_access; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::make_optional; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::nullopt; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::nullopt_t; +// NOLINTNEXTLINE(misc-unused-using-decls) +using std::optional; + +namespace detail_ { +// the call to convert(b) has return type A and converts b to type A iff b +// decltype(b) is implicitly convertible to A +template +constexpr U convert(U v) { + return v; +} +} // namespace detail_ +template +constexpr T value_or_else(const optional& v, F&& func) { + static_assert( + std::is_convertible_v, T>, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? *v : detail_::convert(std::forward(func)()); +} + +template +constexpr T value_or_else(optional&& v, F&& func) { + static_assert( + std::is_convertible_v, T>, + "func parameters must be a callable that returns a type convertible to the value stored in the optional"); + return v.has_value() ? constexpr_move(std::move(v).contained_val()) + : detail_::convert(std::forward(func)()); +} +} // namespace c10 +#endif // C10_UTIL_OPTIONAL_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/OptionalArrayRef.h b/MLPY/Lib/site-packages/torch/include/c10/util/OptionalArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..1b4ccf26679b2593f4dfa9d8ac0e679a6f892e87 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/OptionalArrayRef.h @@ -0,0 +1,236 @@ +// This file defines OptionalArrayRef, a class that has almost the same +// exact functionality as c10::optional>, except that its +// converting constructor fixes a dangling pointer issue. +// +// The implicit converting constructor of both c10::optional> and +// std::optional> can cause the underlying ArrayRef to store +// a dangling pointer. OptionalArrayRef prevents this by wrapping +// a c10::optional> and fixing the constructor implementation. +// +// See https://github.com/pytorch/pytorch/issues/63645 for more on this. 
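+//
+// Illustrative usage sketch (not part of the upstream header); `print_shape`,
+// `shape` and `d` are hypothetical names used only for illustration:
+//
+//   void print_shape(c10::OptionalIntArrayRef shape) {
+//     if (!shape.has_value()) {
+//       return;  // caller passed c10::nullopt
+//     }
+//     for (int64_t d : *shape) { /* use d */ }
+//   }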
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +template +class OptionalArrayRef final { + public: + // Constructors + + constexpr OptionalArrayRef() noexcept = default; + + constexpr OptionalArrayRef(nullopt_t) noexcept {} + + OptionalArrayRef(const OptionalArrayRef& other) = default; + + OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef(const optional>& other) noexcept + : wrapped_opt_array_ref(other) {} + + constexpr OptionalArrayRef(optional>&& other) noexcept + : wrapped_opt_array_ref(std::move(other)) {} + + constexpr OptionalArrayRef(const T& value) noexcept + : wrapped_opt_array_ref(value) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + std::is_convertible_v> && + !std::is_convertible_v, + bool> = false> + constexpr OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + !std::is_convertible_v>, + bool> = false> + constexpr explicit OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + Args&&... args) noexcept + : wrapped_opt_array_ref(ip, std::forward(args)...) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + std::initializer_list il, + Args&&... args) + : wrapped_opt_array_ref(ip, il, std::forward(args)...) 
{} + + constexpr OptionalArrayRef(const std::initializer_list& Vec) + : wrapped_opt_array_ref(ArrayRef(Vec)) {} + + // Destructor + + ~OptionalArrayRef() = default; + + // Assignment + + constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { + wrapped_opt_array_ref = c10::nullopt; + return *this; + } + + OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; + + OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef& operator=( + const optional>& other) noexcept { + wrapped_opt_array_ref = other; + return *this; + } + + constexpr OptionalArrayRef& operator=( + optional>&& other) noexcept { + wrapped_opt_array_ref = std::move(other); + return *this; + } + + template < + typename U = ArrayRef, + typename = std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + std::is_constructible_v, U&&> && + std::is_assignable_v&, U&&>>> + constexpr OptionalArrayRef& operator=(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&> && + std::is_nothrow_assignable_v&, U&&>) { + wrapped_opt_array_ref = std::forward(value); + return *this; + } + + // Observers + + constexpr ArrayRef* operator->() noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef* operator->() const noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef& operator*() & noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& operator*() const& noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& operator*() && noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& operator*() const&& noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr explicit operator bool() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr bool has_value() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr ArrayRef& value() & { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& value() const& { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& value() && { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& value() const&& { + return std::move(wrapped_opt_array_ref.value()); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) const& { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) && { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + // Modifiers + + constexpr void swap(OptionalArrayRef& other) noexcept { + std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); + } + + constexpr void reset() noexcept { + wrapped_opt_array_ref.reset(); + } + + template + constexpr std:: + enable_if_t, Args&&...>, ArrayRef&> + emplace(Args&&... args) noexcept( + std::is_nothrow_constructible_v, Args&&...>) { + return wrapped_opt_array_ref.emplace(std::forward(args)...); + } + + template + constexpr ArrayRef& emplace( + std::initializer_list il, + Args&&... 
args) noexcept { + return wrapped_opt_array_ref.emplace(il, std::forward(args)...); + } + + private: + optional> wrapped_opt_array_ref; +}; + +using OptionalIntArrayRef = OptionalArrayRef; + +inline bool operator==( + const OptionalIntArrayRef& a1, + const IntArrayRef& other) { + if (!a1.has_value()) { + return false; + } + return a1.value() == other; +} + +inline bool operator==( + const c10::IntArrayRef& a1, + const c10::OptionalIntArrayRef& a2) { + return a2 == a1; +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ParallelGuard.h b/MLPY/Lib/site-packages/torch/include/c10/util/ParallelGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..cabeafcacbdfd6235040b74c00c2e89183b3bf8a --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ParallelGuard.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace c10 { + +// RAII thread local guard that tracks whether code is being executed in +// `at::parallel_for` or `at::parallel_reduce` loop function. +class C10_API ParallelGuard { + public: + static bool is_enabled(); + + ParallelGuard(bool state); + ~ParallelGuard(); + + private: + bool previous_state_; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Registry.h b/MLPY/Lib/site-packages/torch/include/c10/util/Registry.h new file mode 100644 index 0000000000000000000000000000000000000000..20490019b3cfecd722030d19bedd3c900ea156b0 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Registry.h @@ -0,0 +1,326 @@ +#ifndef C10_UTIL_REGISTRY_H_ +#define C10_UTIL_REGISTRY_H_ + +/** + * Simple registry implementation that uses static variables to + * register object creators during program initialization time. + */ + +// NB: This Registry works poorly when you have other namespaces. +// Make all macro invocations from inside the at namespace. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace c10 { + +template +inline std::string KeyStrRepr(const KeyType& /*key*/) { + return "[key type printing not supported]"; +} + +template <> +inline std::string KeyStrRepr(const std::string& key) { + return key; +} + +enum RegistryPriority { + REGISTRY_FALLBACK = 1, + REGISTRY_DEFAULT = 2, + REGISTRY_PREFERRED = 3, +}; + +/** + * @brief A template class that allows one to register classes by keys. + * + * The keys are usually a std::string specifying the name, but can be anything + * that can be used in a std::map. + * + * You should most likely not use the Registry class explicitly, but use the + * helper macros below to declare specific registries as well as registering + * objects. + */ +template +class Registry { + public: + typedef std::function Creator; + + Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {} + + void Register( + const SrcType& key, + Creator creator, + const RegistryPriority priority = REGISTRY_DEFAULT) { + std::lock_guard lock(register_mutex_); + // The if statement below is essentially the same as the following line: + // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key + // << " registered twice."; + // However, TORCH_CHECK_EQ depends on google logging, and since registration + // is carried out at static initialization time, we do not want to have an + // explicit dependency on glog's initialization function. 
+ if (registry_.count(key) != 0) { + auto cur_priority = priority_[key]; + if (priority > cur_priority) { +#ifdef DEBUG + std::string warn_msg = + "Overwriting already registered item for key " + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); +#endif + registry_[key] = creator; + priority_[key] = priority; + } else if (priority == cur_priority) { + std::string err_msg = + "Key already registered with the same priority: " + KeyStrRepr(key); + fprintf(stderr, "%s\n", err_msg.c_str()); + if (terminate_) { + std::exit(1); + } else { + throw std::runtime_error(err_msg); + } + } else if (warning_) { + std::string warn_msg = + "Higher priority item already registered, skipping registration of " + + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); + } + } else { + registry_[key] = creator; + priority_[key] = priority; + } + } + + void Register( + const SrcType& key, + Creator creator, + const std::string& help_msg, + const RegistryPriority priority = REGISTRY_DEFAULT) { + Register(key, creator, priority); + help_message_[key] = help_msg; + } + + inline bool Has(const SrcType& key) { + return (registry_.count(key) != 0); + } + + ObjectPtrType Create(const SrcType& key, Args... args) { + auto it = registry_.find(key); + if (it == registry_.end()) { + // Returns nullptr if the key is not registered. + return nullptr; + } + return it->second(args...); + } + + /** + * Returns the keys currently registered as a std::vector. + */ + std::vector Keys() const { + std::vector keys; + keys.reserve(registry_.size()); + for (const auto& it : registry_) { + keys.push_back(it.first); + } + return keys; + } + + inline const std::unordered_map& HelpMessage() const { + return help_message_; + } + + const char* HelpMessage(const SrcType& key) const { + auto it = help_message_.find(key); + if (it == help_message_.end()) { + return nullptr; + } + return it->second.c_str(); + } + + // Used for testing, if terminate is unset, Registry throws instead of + // calling std::exit + void SetTerminate(bool terminate) { + terminate_ = terminate; + } + + private: + std::unordered_map registry_; + std::unordered_map priority_; + bool terminate_{true}; + const bool warning_; + std::unordered_map help_message_; + std::mutex register_mutex_; + + C10_DISABLE_COPY_AND_ASSIGN(Registry); +}; + +template +class Registerer { + public: + explicit Registerer( + const SrcType& key, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg); + } + + explicit Registerer( + const SrcType& key, + const RegistryPriority priority, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg, priority); + } + + template + static ObjectPtrType DefaultCreator(Args... args) { + return ObjectPtrType(new DerivedType(args...)); + } +}; + +/** + * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function + * declaration, as well as creating a convenient typename for its corresponding + * registerer. + */ +// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE +// as import and DEFINE as export, because these registry macros will be used +// in downstream shared libraries as well, and one cannot use *_API - the API +// macro will be defined on a per-shared-library basis. 
Semantically, when one +// declares a typed registry it is always going to be IMPORT, and when one +// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE), +// the instantiation unit is always going to be exported. +// +// The only unique condition is when in the same file one does DECLARE and +// DEFINE - in Windows compilers, this generates a warning that dllimport and +// dllexport are mixed, but the warning is fine and linker will be properly +// exporting the symbol. Same thing happens in the gflags flag declaration and +// definition caes. +#define C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + TORCH_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = new ::c10:: \ + Registry, ##__VA_ARGS__>(); \ + return registry; \ + } + +#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = \ + new ::c10::Registry, ##__VA_ARGS__>( \ + false); \ + return registry; \ + } + +// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated +// creator with comma in its templated arguments. +#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, priority, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + priority, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use +// std::string as the key type, because that is the most commonly used cases. +#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) 
\ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string +// as the key +// type, because that is the most commonly used cases. +#define C10_REGISTER_CREATOR(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +#define C10_REGISTER_CLASS(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +} // namespace c10 + +#endif // C10_UTIL_REGISTRY_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ScopeExit.h b/MLPY/Lib/site-packages/torch/include/c10/util/ScopeExit.h new file mode 100644 index 0000000000000000000000000000000000000000..8ed2373eea03bb8615d6742f05996e9aa27158d6 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ScopeExit.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +namespace c10 { + +/** + * Mostly copied from https://llvm.org/doxygen/ScopeExit_8h_source.html + */ +template +class scope_exit { + Callable ExitFunction; + bool Engaged = true; // False once moved-from or release()d. + + public: + template + // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + explicit scope_exit(Fp&& F) : ExitFunction(std::forward(F)) {} + + scope_exit(scope_exit&& Rhs) noexcept + : ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) { + Rhs.release(); + } + scope_exit(const scope_exit&) = delete; + scope_exit& operator=(scope_exit&&) = delete; + scope_exit& operator=(const scope_exit&) = delete; + + void release() { + Engaged = false; + } + + ~scope_exit() { + if (Engaged) { + ExitFunction(); + } + } +}; + +// Keeps the callable object that is passed in, and execute it at the +// destruction of the returned object (usually at the scope exit where the +// returned object is kept). +// +// Interface is specified by p0052r2. 
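+//
+// Illustrative usage sketch (not part of the upstream header); `f` and `guard`
+// are hypothetical names and "data.bin" a hypothetical file:
+//
+//   FILE* f = std::fopen("data.bin", "rb");
+//   auto guard = c10::make_scope_exit([&] { if (f) std::fclose(f); });
+//   // Early returns and exceptions still run the cleanup; call
+//   // guard.release() to keep the file open past this scope.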
+template +scope_exit> make_scope_exit(Callable&& F) { + return scope_exit>(std::forward(F)); +} + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/SmallBuffer.h b/MLPY/Lib/site-packages/torch/include/c10/util/SmallBuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..1e4317ee03d0913d1a4128c7330631d01b889510 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/SmallBuffer.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include +#include +#include + +/** Helper class for allocating temporary fixed size arrays with SBO. + * + * This is intentionally much simpler than SmallVector, to improve performance + * at the expense of many features: + * - No zero-initialization for numeric types + * - No resizing after construction + * - No copy/move + * - No non-trivial types + */ + +namespace c10 { + +template +class SmallBuffer { + static_assert(std::is_trivial_v, "SmallBuffer is intended for POD types"); + + std::array storage_; + size_t size_{}; + T* data_{}; + + public: + SmallBuffer(size_t size) : size_(size) { + if (size > N) { + data_ = new T[size]; + } else { + data_ = &storage_[0]; + } + } + + SmallBuffer(const SmallBuffer&) = delete; + SmallBuffer& operator=(const SmallBuffer&) = delete; + + // move constructor is needed in function return + SmallBuffer(SmallBuffer&& rhs) noexcept : size_{rhs.size_} { + rhs.size_ = 0; + if (size_ > N) { + data_ = rhs.data_; + rhs.data_ = nullptr; + } else { + storage_ = std::move(rhs.storage_); + data_ = &storage_[0]; + } + } + + SmallBuffer& operator=(SmallBuffer&&) = delete; + + ~SmallBuffer() { + if (size_ > N) { + delete[] data_; + } + } + T& operator[](size_t idx) { + return data()[idx]; + } + const T& operator[](size_t idx) const { + return data()[idx]; + } + T* data() { + return data_; + } + const T* data() const { + return data_; + } + size_t size() const { + return size_; + } + T* begin() { + return data_; + } + const T* begin() const { + return data_; + } + T* end() { + return data_ + size_; + } + const T* end() const { + return data_ + size_; + } +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/SmallVector.h b/MLPY/Lib/site-packages/torch/include/c10/util/SmallVector.h new file mode 100644 index 0000000000000000000000000000000000000000..2a72446179e4493d6033fac59664586a16ad921c --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/SmallVector.h @@ -0,0 +1,1476 @@ +//===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SmallVector class. +// +//===----------------------------------------------------------------------===// + +// ATen: modified from llvm::SmallVector. 
+// used std::is_trivially_{copy,move}_constructible +// replaced iterator_range constructor with inline Container&& constructor +// replaced LLVM_NODISCARD, LLVM_LIKELY, and LLVM_UNLIKELY with c10 equivalents +// removed LLVM_GSL_OWNER +// added SmallVector::at +// added operator<< for std::ostream +// added C10_API to export SmallVectorBase + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") +#endif + +namespace c10 { + +/// This is all the stuff common to all SmallVectors. +/// +/// The template parameter specifies the type which should be used to hold the +/// Size and Capacity of the SmallVector, so it can be adjusted. +/// Using 32 bit size is desirable to shrink the size of the SmallVector. +/// Using 64 bit size is desirable for cases like SmallVector, where a +/// 32 bit size would limit the vector to ~4GB. SmallVectors are used for +/// buffering bitcode output - which can exceed 4GB. +template +class C10_API SmallVectorBase { + protected: + void* BeginX; + Size_T Size = 0, Capacity; + + /// The maximum value of the Size_T used. + static constexpr size_t SizeTypeMax() { + return std::numeric_limits::max(); + } + + SmallVectorBase(void* FirstEl, size_t TotalCapacity) + : BeginX(FirstEl), Capacity(TotalCapacity) {} + + /// This is a helper for \a grow() that's out of line to reduce code + /// duplication. This function will report a fatal error if it can't grow at + /// least to \p MinSize. + void* mallocForGrow(size_t MinSize, size_t TSize, size_t& NewCapacity); + + /// This is an implementation of the grow() method which only works + /// on POD-like data types and is out of line to reduce code duplication. + /// This function will report a fatal error if it cannot increase capacity. + void grow_pod(void* FirstEl, size_t MinSize, size_t TSize); + + public: + SmallVectorBase() = delete; + size_t size() const { + return Size; + } + size_t capacity() const { + return Capacity; + } + + C10_NODISCARD bool empty() const { + return !Size; + } + + /// Set the array size to \p N, which the current array must have enough + /// capacity for. + /// + /// This does not construct or destroy any elements in the vector. + /// + /// Clients can use this in conjunction with capacity() to write past the end + /// of the buffer when they know that more elements are available, and only + /// update the size later. This avoids the cost of value initializing elements + /// which will only be overwritten. + void set_size(size_t N) { + assert(N <= capacity()); + Size = N; + } +}; + +template +using SmallVectorSizeType = + std::conditional_t= 8, uint64_t, uint32_t>; + +/// Figure out the offset of the first element. +template +struct SmallVectorAlignmentAndSize { + alignas(SmallVectorBase>) char Base[sizeof( + SmallVectorBase>)]; + alignas(T) char FirstEl[sizeof(T)]; +}; + +/// This is the part of SmallVectorTemplateBase which does not depend on whether +/// the type T is a POD. The extra dummy template argument is used by ArrayRef +/// to avoid unnecessarily requiring T to be complete. +template +class SmallVectorTemplateCommon + : public SmallVectorBase> { + using Base = SmallVectorBase>; + + /// Find the address of the first element. 
For this pointer math to be valid + /// with small-size of 0 for T with lots of alignment, it's important that + /// SmallVectorStorage is properly-aligned even for small-size of 0. + void* getFirstEl() const { + return const_cast(reinterpret_cast( + reinterpret_cast(this) + + offsetof(SmallVectorAlignmentAndSize, FirstEl))); + } + // Space after 'FirstEl' is clobbered, do not add any instance vars after it. + + protected: + SmallVectorTemplateCommon(size_t Size) : Base(getFirstEl(), Size) {} + + void grow_pod(size_t MinSize, size_t TSize) { + Base::grow_pod(getFirstEl(), MinSize, TSize); + } + + /// Return true if this is a smallvector which has not had dynamic + /// memory allocated for it. + bool isSmall() const { + return this->BeginX == getFirstEl(); + } + + /// Put this vector in a state of being small. + void resetToSmall() { + this->BeginX = getFirstEl(); + this->Size = this->Capacity = 0; // FIXME: Setting Capacity to 0 is suspect. + } + + /// Return true if V is an internal reference to the given range. + bool isReferenceToRange(const void* V, const void* First, const void* Last) + const { + // Use std::less to avoid UB. + std::less<> LessThan; + return !LessThan(V, First) && LessThan(V, Last); + } + + /// Return true if V is an internal reference to this vector. + bool isReferenceToStorage(const void* V) const { + return isReferenceToRange(V, this->begin(), this->end()); + } + + /// Return true if First and Last form a valid (possibly empty) range in this + /// vector's storage. + bool isRangeInStorage(const void* First, const void* Last) const { + // Use std::less to avoid UB. + std::less<> LessThan; + return !LessThan(First, this->begin()) && !LessThan(Last, First) && + !LessThan(this->end(), Last); + } + + /// Return true unless Elt will be invalidated by resizing the vector to + /// NewSize. + bool isSafeToReferenceAfterResize(const void* Elt, size_t NewSize) { + // Past the end. + if (C10_LIKELY(!isReferenceToStorage(Elt))) + return true; + + // Return false if Elt will be destroyed by shrinking. + if (NewSize <= this->size()) + return Elt < this->begin() + NewSize; + + // Return false if we need to grow. + return NewSize <= this->capacity(); + } + + /// Check whether Elt will be invalidated by resizing the vector to NewSize. + void assertSafeToReferenceAfterResize(const void* Elt, size_t NewSize) { + (void)Elt; // Suppress unused variable warning + (void)NewSize; // Suppress unused variable warning + assert( + isSafeToReferenceAfterResize(Elt, NewSize) && + "Attempting to reference an element of the vector in an operation " + "that invalidates it"); + } + + /// Check whether Elt will be invalidated by increasing the size of the + /// vector by N. + void assertSafeToAdd(const void* Elt, size_t N = 1) { + this->assertSafeToReferenceAfterResize(Elt, this->size() + N); + } + + /// Check whether any part of the range will be invalidated by clearing. + void assertSafeToReferenceAfterClear(const T* From, const T* To) { + if (From == To) + return; + this->assertSafeToReferenceAfterResize(From, 0); + this->assertSafeToReferenceAfterResize(To - 1, 0); + } + template < + class ItTy, + std::enable_if_t, T*>, bool> = + false> + void assertSafeToReferenceAfterClear(ItTy, ItTy) {} + + /// Check whether any part of the range will be invalidated by growing. 
+ void assertSafeToAddRange(const T* From, const T* To) { + if (From == To) + return; + this->assertSafeToAdd(From, To - From); + this->assertSafeToAdd(To - 1, To - From); + } + template < + class ItTy, + std::enable_if_t, T*>, bool> = + false> + void assertSafeToAddRange(ItTy, ItTy) {} + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + template + static const T* reserveForParamAndGetAddressImpl( + U* This, + const T& Elt, + size_t N) { + size_t NewSize = This->size() + N; + if (C10_LIKELY(NewSize <= This->capacity())) + return &Elt; + + bool ReferencesStorage = false; + int64_t Index = -1; + if (!U::TakesParamByValue) { + if (C10_UNLIKELY(This->isReferenceToStorage(&Elt))) { + ReferencesStorage = true; + Index = &Elt - This->begin(); + } + } + This->grow(NewSize); + return ReferencesStorage ? This->begin() + Index : &Elt; + } + + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using value_type = T; + using iterator = T*; + using const_iterator = const T*; + + using const_reverse_iterator = std::reverse_iterator; + using reverse_iterator = std::reverse_iterator; + + using reference = T&; + using const_reference = const T&; + using pointer = T*; + using const_pointer = const T*; + + using Base::capacity; + using Base::empty; + using Base::size; + + // forward iterator creation methods. + iterator begin() { + return (iterator)this->BeginX; + } + const_iterator begin() const { + return (const_iterator)this->BeginX; + } + iterator end() { + return begin() + size(); + } + const_iterator end() const { + return begin() + size(); + } + + // reverse iterator creation methods. + reverse_iterator rbegin() { + return reverse_iterator(end()); + } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + reverse_iterator rend() { + return reverse_iterator(begin()); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + + size_type size_in_bytes() const { + return size() * sizeof(T); + } + size_type max_size() const { + return std::min(this->SizeTypeMax(), size_type(-1) / sizeof(T)); + } + + size_t capacity_in_bytes() const { + return capacity() * sizeof(T); + } + + /// Return a pointer to the vector's buffer, even if empty(). + pointer data() { + return pointer(begin()); + } + /// Return a pointer to the vector's buffer, even if empty(). + const_pointer data() const { + return const_pointer(begin()); + } + + // SmallVector::at is NOT from LLVM. + reference at(size_type idx) { + assert(idx < size()); + return begin()[idx]; + } + const_reference at(size_type idx) const { + assert(idx < size()); + return begin()[idx]; + } + reference operator[](size_type idx) { + assert(idx < size()); + return begin()[idx]; + } + const_reference operator[](size_type idx) const { + assert(idx < size()); + return begin()[idx]; + } + + reference front() { + assert(!empty()); + return begin()[0]; + } + const_reference front() const { + assert(!empty()); + return begin()[0]; + } + + reference back() { + assert(!empty()); + return end()[-1]; + } + const_reference back() const { + assert(!empty()); + return end()[-1]; + } +}; + +/// SmallVectorTemplateBase - This is where we put +/// method implementations that are designed to work with non-trivial T's. +/// +/// We approximate is_trivially_copyable with trivial move/copy construction and +/// trivial destruction. 
While the standard doesn't specify that you're allowed +/// copy these types with memcpy, there is no way for the type to observe this. +/// This catches the important case of std::pair, which is not +/// trivially assignable. +/// +/// XXX: if build fails here fall back to C10_IS_TRIVIALLY_COPYABLE and make a +/// note +template < + typename T, + bool = (std::is_trivially_copy_constructible_v)&&( + std::is_trivially_move_constructible_v< + T>)&&std::is_trivially_destructible_v> +class SmallVectorTemplateBase : public SmallVectorTemplateCommon { + friend class SmallVectorTemplateCommon; + + protected: + static constexpr bool TakesParamByValue = false; + using ValueParamT = const T&; + + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} + + static void destroy_range(T* S, T* E) { + while (S != E) { + --E; + E->~T(); + } + } + + /// Move the range [I, E) into the uninitialized memory starting with "Dest", + /// constructing elements as needed. + template + static void uninitialized_move(It1 I, It1 E, It2 Dest) { + std::uninitialized_copy( + std::make_move_iterator(I), std::make_move_iterator(E), Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory starting with "Dest", + /// constructing elements as needed. + template + static void uninitialized_copy(It1 I, It1 E, It2 Dest) { + std::uninitialized_copy(I, E, Dest); + } + + /// Grow the allocated memory (without initializing new elements), doubling + /// the size of the allocated memory. Guarantees space for at least one more + /// element, or MinSize more elements if specified. + void grow(size_t MinSize = 0); + + /// Create a new allocation big enough for \p MinSize and pass back its size + /// in \p NewCapacity. This is the first section of \a grow(). + T* mallocForGrow(size_t MinSize, size_t& NewCapacity) { + return static_cast( + SmallVectorBase>::mallocForGrow( + MinSize, sizeof(T), NewCapacity)); + } + + /// Move existing elements over to the new allocation \p NewElts, the middle + /// section of \a grow(). + void moveElementsForGrow(T* NewElts); + + /// Transfer ownership of the allocation, finishing up \a grow(). + void takeAllocationForGrow(T* NewElts, size_t NewCapacity); + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + const T* reserveForParamAndGetAddress(const T& Elt, size_t N = 1) { + return this->reserveForParamAndGetAddressImpl(this, Elt, N); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + T* reserveForParamAndGetAddress(T& Elt, size_t N = 1) { + return const_cast(this->reserveForParamAndGetAddressImpl(this, Elt, N)); + } + + static T&& forward_value_param(T&& V) { + return std::move(V); + } + static const T& forward_value_param(const T& V) { + return V; + } + + void growAndAssign(size_t NumElts, const T& Elt) { + // Grow manually in case Elt is an internal reference. + size_t NewCapacity = 0; + T* NewElts = mallocForGrow(NumElts, NewCapacity); + std::uninitialized_fill_n(NewElts, NumElts, Elt); + this->destroy_range(this->begin(), this->end()); + takeAllocationForGrow(NewElts, NewCapacity); + this->set_size(NumElts); + } + + template + T& growAndEmplaceBack(ArgTypes&&... Args) { + // Grow manually in case one of Args is an internal reference. 
+ size_t NewCapacity = 0; + T* NewElts = mallocForGrow(0, NewCapacity); + ::new ((void*)(NewElts + this->size())) T(std::forward(Args)...); + moveElementsForGrow(NewElts); + takeAllocationForGrow(NewElts, NewCapacity); + this->set_size(this->size() + 1); + return this->back(); + } + + public: + void push_back(const T& Elt) { + const T* EltPtr = reserveForParamAndGetAddress(Elt); + ::new ((void*)this->end()) T(*EltPtr); + this->set_size(this->size() + 1); + } + + void push_back(T&& Elt) { + T* EltPtr = reserveForParamAndGetAddress(Elt); + ::new ((void*)this->end()) T(::std::move(*EltPtr)); + this->set_size(this->size() + 1); + } + + void pop_back() { + this->set_size(this->size() - 1); + this->end()->~T(); + } +}; + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::grow(size_t MinSize) { + size_t NewCapacity = 0; + T* NewElts = mallocForGrow(MinSize, NewCapacity); + moveElementsForGrow(NewElts); + takeAllocationForGrow(NewElts, NewCapacity); +} + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::moveElementsForGrow( + T* NewElts) { + // Move the elements over. + this->uninitialized_move(this->begin(), this->end(), NewElts); + + // Destroy the original elements. + destroy_range(this->begin(), this->end()); +} + +// Define this out-of-line to dissuade the C++ compiler from inlining it. +template +void SmallVectorTemplateBase::takeAllocationForGrow( + T* NewElts, + size_t NewCapacity) { + // If this wasn't grown from the inline copy, deallocate the old space. + if (!this->isSmall()) + free(this->begin()); + + this->BeginX = NewElts; + this->Capacity = NewCapacity; +} + +/// SmallVectorTemplateBase - This is where we put +/// method implementations that are designed to work with trivially copyable +/// T's. This allows using memcpy in place of copy/move construction and +/// skipping destruction. +template +class SmallVectorTemplateBase : public SmallVectorTemplateCommon { + friend class SmallVectorTemplateCommon; + + protected: + /// True if it's cheap enough to take parameters by value. Doing so avoids + /// overhead related to mitigations for reference invalidation. + static constexpr bool TakesParamByValue = sizeof(T) <= 2 * sizeof(void*); + + /// Either const T& or T, depending on whether it's cheap enough to take + /// parameters by value. + using ValueParamT = std::conditional_t; + + SmallVectorTemplateBase(size_t Size) : SmallVectorTemplateCommon(Size) {} + + // No need to do a destroy loop for POD's. + static void destroy_range(T*, T*) {} + + /// Move the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. + template + static void uninitialized_move(It1 I, It1 E, It2 Dest) { + // Just do a copy. + uninitialized_copy(I, E, Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. + template + static void uninitialized_copy(It1 I, It1 E, It2 Dest) { + // Arbitrary iterator types; just use the basic implementation. + std::uninitialized_copy(I, E, Dest); + } + + /// Copy the range [I, E) onto the uninitialized memory + /// starting with "Dest", constructing elements into it as needed. 
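+  /// For trivially copyable element types this specialization is what keeps
+  /// SmallVector cheap; an illustrative sketch (not part of the upstream
+  /// header):
+  ///
+  ///   c10::SmallVector<int64_t, 5> shape;
+  ///   shape.push_back(7);   // parameter taken by value (TakesParamByValue) and
+  ///                         // stored with a raw memcpy
+  ///   shape.clear();        // destroy_range is a no-op, so no destructor loop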
+ template + static void uninitialized_copy( + T1* I, + T1* E, + T2* Dest, + std::enable_if_t, T2>>* = + nullptr) { + // Use memcpy for PODs iterated by pointers (which includes SmallVector + // iterators): std::uninitialized_copy optimizes to memmove, but we can + // use memcpy here. Note that I and E are iterators and thus might be + // invalid for memcpy if they are equal. + if (I != E) + memcpy(reinterpret_cast(Dest), I, (E - I) * sizeof(T)); + } + + /// Double the size of the allocated memory, guaranteeing space for at + /// least one more element or MinSize if specified. + void grow(size_t MinSize = 0) { + this->grow_pod(MinSize, sizeof(T)); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + const T* reserveForParamAndGetAddress(const T& Elt, size_t N = 1) { + return this->reserveForParamAndGetAddressImpl(this, Elt, N); + } + + /// Reserve enough space to add one element, and return the updated element + /// pointer in case it was a reference to the storage. + T* reserveForParamAndGetAddress(T& Elt, size_t N = 1) { + return const_cast(this->reserveForParamAndGetAddressImpl(this, Elt, N)); + } + + /// Copy \p V or return a reference, depending on \a ValueParamT. + static ValueParamT forward_value_param(ValueParamT V) { + return V; + } + + void growAndAssign(size_t NumElts, T Elt) { + // Elt has been copied in case it's an internal reference, side-stepping + // reference invalidation problems without losing the realloc optimization. + this->set_size(0); + this->grow(NumElts); + std::uninitialized_fill_n(this->begin(), NumElts, Elt); + this->set_size(NumElts); + } + + template + T& growAndEmplaceBack(ArgTypes&&... Args) { + // Use push_back with a copy in case Args has an internal reference, + // side-stepping reference invalidation problems without losing the realloc + // optimization. + push_back(T(std::forward(Args)...)); + return this->back(); + } + + public: + void push_back(ValueParamT Elt) { + const T* EltPtr = reserveForParamAndGetAddress(Elt); + memcpy(reinterpret_cast(this->end()), EltPtr, sizeof(T)); + this->set_size(this->size() + 1); + } + + void pop_back() { + this->set_size(this->size() - 1); + } +}; + +/// This class consists of common code factored out of the SmallVector class to +/// reduce code duplication based on the SmallVector 'N' template parameter. +template +class SmallVectorImpl : public SmallVectorTemplateBase { + using SuperClass = SmallVectorTemplateBase; + + public: + using iterator = typename SuperClass::iterator; + using const_iterator = typename SuperClass::const_iterator; + using reference = typename SuperClass::reference; + using size_type = typename SuperClass::size_type; + + protected: + using SmallVectorTemplateBase::TakesParamByValue; + using ValueParamT = typename SuperClass::ValueParamT; + + // Default ctor - Initialize to empty. + explicit SmallVectorImpl(unsigned N) : SmallVectorTemplateBase(N) {} + + public: + SmallVectorImpl(const SmallVectorImpl&) = delete; + + ~SmallVectorImpl() { + // Subclass has already destructed this vector's elements. + // If this wasn't grown from the inline copy, deallocate the old space. 
+ if (!this->isSmall()) + free(this->begin()); + } + + void clear() { + this->destroy_range(this->begin(), this->end()); + this->Size = 0; + } + + private: + template + void resizeImpl(size_type N) { + if (N < this->size()) { + this->pop_back_n(this->size() - N); + } else if (N > this->size()) { + this->reserve(N); + for (auto I = this->end(), E = this->begin() + N; I != E; ++I) + if (ForOverwrite) + new (&*I) T; + else + new (&*I) T(); + this->set_size(N); + } + } + + public: + void resize(size_type N) { + resizeImpl(N); + } + + /// Like resize, but \ref T is POD, the new values won't be initialized. + void resize_for_overwrite(size_type N) { + resizeImpl(N); + } + + void resize(size_type N, ValueParamT NV) { + if (N == this->size()) + return; + + if (N < this->size()) { + this->pop_back_n(this->size() - N); + return; + } + + // N > this->size(). Defer to append. + this->append(N - this->size(), NV); + } + + void reserve(size_type N) { + if (this->capacity() < N) + this->grow(N); + } + + void pop_back_n(size_type NumItems) { + assert(this->size() >= NumItems); + this->destroy_range(this->end() - NumItems, this->end()); + this->set_size(this->size() - NumItems); + } + + C10_NODISCARD T pop_back_val() { + T Result = ::std::move(this->back()); + this->pop_back(); + return Result; + } + + void swap(SmallVectorImpl& RHS) noexcept; + + /// Add the specified range to the end of the SmallVector. + template < + typename in_iter, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + void append(in_iter in_start, in_iter in_end) { + this->assertSafeToAddRange(in_start, in_end); + size_type NumInputs = std::distance(in_start, in_end); + this->reserve(this->size() + NumInputs); + this->uninitialized_copy(in_start, in_end, this->end()); + this->set_size(this->size() + NumInputs); + } + + /// Append \p NumInputs copies of \p Elt to the end. + void append(size_type NumInputs, ValueParamT Elt) { + const T* EltPtr = this->reserveForParamAndGetAddress(Elt, NumInputs); + std::uninitialized_fill_n(this->end(), NumInputs, *EltPtr); + this->set_size(this->size() + NumInputs); + } + + void append(std::initializer_list IL) { + append(IL.begin(), IL.end()); + } + + void append(const SmallVectorImpl& RHS) { + append(RHS.begin(), RHS.end()); + } + + void assign(size_type NumElts, ValueParamT Elt) { + // Note that Elt could be an internal reference. + if (NumElts > this->capacity()) { + this->growAndAssign(NumElts, Elt); + return; + } + + // Assign over existing elements. + std::fill_n(this->begin(), std::min(NumElts, this->size()), Elt); + if (NumElts > this->size()) + std::uninitialized_fill_n(this->end(), NumElts - this->size(), Elt); + else if (NumElts < this->size()) + this->destroy_range(this->begin() + NumElts, this->end()); + this->set_size(NumElts); + } + + // FIXME: Consider assigning over existing elements, rather than clearing & + // re-initializing them - for all assign(...) variants. + + template < + typename in_iter, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + void assign(in_iter in_start, in_iter in_end) { + this->assertSafeToReferenceAfterClear(in_start, in_end); + clear(); + append(in_start, in_end); + } + + void assign(std::initializer_list IL) { + clear(); + append(IL); + } + + void assign(const SmallVectorImpl& RHS) { + assign(RHS.begin(), RHS.end()); + } + + iterator erase(const_iterator CI) { + // Just cast away constness because this is a non-const member function. 
+ iterator I = const_cast(CI); + + assert( + this->isReferenceToStorage(CI) && + "Iterator to erase is out of bounds."); + + iterator N = I; + // Shift all elts down one. + std::move(I + 1, this->end(), I); + // Drop the last elt. + this->pop_back(); + return (N); + } + + iterator erase(const_iterator CS, const_iterator CE) { + // Just cast away constness because this is a non-const member function. + iterator S = const_cast(CS); + iterator E = const_cast(CE); + + assert(this->isRangeInStorage(S, E) && "Range to erase is out of bounds."); + + iterator N = S; + // Shift all elts down. + iterator I = std::move(E, this->end(), S); + // Drop the last elts. + this->destroy_range(I, this->end()); + this->set_size(I - this->begin()); + return (N); + } + + private: + template + iterator insert_one_impl(iterator I, ArgType&& Elt) { + // Callers ensure that ArgType is derived from T. + static_assert( + std::is_same>, T>:: + value, + "ArgType must be derived from T!"); + + if (I == this->end()) { // Important special case for empty vector. + this->push_back(::std::forward(Elt)); + return this->end() - 1; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Grow if necessary. + size_t Index = I - this->begin(); + std::remove_reference_t* EltPtr = + this->reserveForParamAndGetAddress(Elt); + I = this->begin() + Index; + + ::new ((void*)this->end()) T(::std::move(this->back())); + // Push everything else over. + std::move_backward(I, this->end() - 1, this->end()); + this->set_size(this->size() + 1); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + static_assert( + !TakesParamByValue || std::is_same::value, + "ArgType must be 'T' when taking by value!"); + if (!TakesParamByValue && this->isReferenceToRange(EltPtr, I, this->end())) + ++EltPtr; + + *I = ::std::forward(*EltPtr); + return I; + } + + public: + iterator insert(iterator I, T&& Elt) { + return insert_one_impl(I, this->forward_value_param(std::move(Elt))); + } + + iterator insert(iterator I, const T& Elt) { + return insert_one_impl(I, this->forward_value_param(Elt)); + } + + iterator insert(iterator I, size_type NumToInsert, ValueParamT Elt) { + // Convert iterator to elt# to avoid invalidating iterator when we reserve() + size_t InsertElt = I - this->begin(); + + if (I == this->end()) { // Important special case for empty vector. + append(NumToInsert, Elt); + return this->begin() + InsertElt; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Ensure there is enough space, and get the (maybe updated) address of + // Elt. + const T* EltPtr = this->reserveForParamAndGetAddress(Elt, NumToInsert); + + // Uninvalidate the iterator. + I = this->begin() + InsertElt; + + // If there are more elements between the insertion point and the end of the + // range than there are being inserted, we can use a simple approach to + // insertion. Since we already reserved space, we know that this won't + // reallocate the vector. + if (size_t(this->end() - I) >= NumToInsert) { + T* OldEnd = this->end(); + append( + std::move_iterator(this->end() - NumToInsert), + std::move_iterator(this->end())); + + // Copy the existing elements that get replaced. + std::move_backward(I, OldEnd - NumToInsert, OldEnd); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). 
+ if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + std::fill_n(I, NumToInsert, *EltPtr); + return I; + } + + // Otherwise, we're inserting more elements than exist already, and we're + // not inserting at the end. + + // Move over the elements that we're about to overwrite. + T* OldEnd = this->end(); + this->set_size(this->size() + NumToInsert); + size_t NumOverwritten = OldEnd - I; + this->uninitialized_move(I, OldEnd, this->end() - NumOverwritten); + + // If we just moved the element we're inserting, be sure to update + // the reference (never happens if TakesParamByValue). + if (!TakesParamByValue && I <= EltPtr && EltPtr < this->end()) + EltPtr += NumToInsert; + + // Replace the overwritten part. + std::fill_n(I, NumOverwritten, *EltPtr); + + // Insert the non-overwritten middle part. + std::uninitialized_fill_n(OldEnd, NumToInsert - NumOverwritten, *EltPtr); + return I; + } + + template < + typename ItTy, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + iterator insert(iterator I, ItTy From, ItTy To) { + // Convert iterator to elt# to avoid invalidating iterator when we reserve() + size_t InsertElt = I - this->begin(); + + if (I == this->end()) { // Important special case for empty vector. + append(From, To); + return this->begin() + InsertElt; + } + + assert( + this->isReferenceToStorage(I) && + "Insertion iterator is out of bounds."); + + // Check that the reserve that follows doesn't invalidate the iterators. + this->assertSafeToAddRange(From, To); + + size_t NumToInsert = std::distance(From, To); + + // Ensure there is enough space. + reserve(this->size() + NumToInsert); + + // Uninvalidate the iterator. + I = this->begin() + InsertElt; + + // If there are more elements between the insertion point and the end of the + // range than there are being inserted, we can use a simple approach to + // insertion. Since we already reserved space, we know that this won't + // reallocate the vector. + if (size_t(this->end() - I) >= NumToInsert) { + T* OldEnd = this->end(); + append( + std::move_iterator(this->end() - NumToInsert), + std::move_iterator(this->end())); + + // Copy the existing elements that get replaced. + std::move_backward(I, OldEnd - NumToInsert, OldEnd); + + std::copy(From, To, I); + return I; + } + + // Otherwise, we're inserting more elements than exist already, and we're + // not inserting at the end. + + // Move over the elements that we're about to overwrite. + T* OldEnd = this->end(); + this->set_size(this->size() + NumToInsert); + size_t NumOverwritten = OldEnd - I; + this->uninitialized_move(I, OldEnd, this->end() - NumOverwritten); + + // Replace the overwritten part. + for (T* J = I; NumOverwritten > 0; --NumOverwritten) { + *J = *From; + ++J; + ++From; + } + + // Insert the non-overwritten middle part. + this->uninitialized_copy(From, To, OldEnd); + return I; + } + + void insert(iterator I, std::initializer_list IL) { + insert(I, IL.begin(), IL.end()); + } + + template + reference emplace_back(ArgTypes&&... 
Args) { + if (C10_UNLIKELY(this->size() >= this->capacity())) + return this->growAndEmplaceBack(std::forward(Args)...); + + ::new ((void*)this->end()) T(std::forward(Args)...); + this->set_size(this->size() + 1); + return this->back(); + } + + SmallVectorImpl& operator=(const SmallVectorImpl& RHS); + + SmallVectorImpl& operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_destructible_v); + + bool operator==(const SmallVectorImpl& RHS) const { + if (this->size() != RHS.size()) + return false; + return std::equal(this->begin(), this->end(), RHS.begin()); + } + bool operator!=(const SmallVectorImpl& RHS) const { + return !(*this == RHS); + } + + bool operator<(const SmallVectorImpl& RHS) const { + return std::lexicographical_compare( + this->begin(), this->end(), RHS.begin(), RHS.end()); + } +}; + +template +void SmallVectorImpl::swap(SmallVectorImpl& RHS) noexcept { + if (this == &RHS) + return; + + // We can only avoid copying elements if neither vector is small. + if (!this->isSmall() && !RHS.isSmall()) { + std::swap(this->BeginX, RHS.BeginX); + std::swap(this->Size, RHS.Size); + std::swap(this->Capacity, RHS.Capacity); + return; + } + this->reserve(RHS.size()); + RHS.reserve(this->size()); + + // Swap the shared elements. + size_t NumShared = this->size(); + if (NumShared > RHS.size()) + NumShared = RHS.size(); + for (size_type i = 0; i != NumShared; ++i) + std::swap((*this)[i], RHS[i]); + + // Copy over the extra elts. + if (this->size() > RHS.size()) { + size_t EltDiff = this->size() - RHS.size(); + this->uninitialized_copy(this->begin() + NumShared, this->end(), RHS.end()); + RHS.set_size(RHS.size() + EltDiff); + this->destroy_range(this->begin() + NumShared, this->end()); + this->set_size(NumShared); + } else if (RHS.size() > this->size()) { + size_t EltDiff = RHS.size() - this->size(); + this->uninitialized_copy(RHS.begin() + NumShared, RHS.end(), this->end()); + this->set_size(this->size() + EltDiff); + this->destroy_range(RHS.begin() + NumShared, RHS.end()); + RHS.set_size(NumShared); + } +} + +template +SmallVectorImpl& SmallVectorImpl::operator=( + const SmallVectorImpl& RHS) { + // Avoid self-assignment. + if (this == &RHS) + return *this; + + // If we already have sufficient space, assign the common elements, then + // destroy any excess. + size_t RHSSize = RHS.size(); + size_t CurSize = this->size(); + if (CurSize >= RHSSize) { + // Assign common elements. + iterator NewEnd; + if (RHSSize) + NewEnd = std::copy(RHS.begin(), RHS.begin() + RHSSize, this->begin()); + else + NewEnd = this->begin(); + + // Destroy excess elements. + this->destroy_range(NewEnd, this->end()); + + // Trim. + this->set_size(RHSSize); + return *this; + } + + // If we have to grow to have enough elements, destroy the current elements. + // This allows us to avoid copying them during the grow. + // FIXME: don't do this if they're efficiently moveable. + if (this->capacity() < RHSSize) { + // Destroy current elements. + this->clear(); + CurSize = 0; + this->grow(RHSSize); + } else if (CurSize) { + // Otherwise, use assignment for the already-constructed elements. + std::copy(RHS.begin(), RHS.begin() + CurSize, this->begin()); + } + + // Copy construct the new elements in place. + this->uninitialized_copy( + RHS.begin() + CurSize, RHS.end(), this->begin() + CurSize); + + // Set end. 
+ this->set_size(RHSSize); + return *this; +} + +template +SmallVectorImpl& SmallVectorImpl:: +operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v && + std::is_nothrow_destructible_v) { + // Avoid self-assignment. + if (this == &RHS) + return *this; + + // If the RHS isn't small, clear this vector and then steal its buffer. + if (!RHS.isSmall()) { + this->destroy_range(this->begin(), this->end()); + if (!this->isSmall()) + free(this->begin()); + this->BeginX = RHS.BeginX; + this->Size = RHS.Size; + this->Capacity = RHS.Capacity; + RHS.resetToSmall(); + return *this; + } + + // If we already have sufficient space, assign the common elements, then + // destroy any excess. + size_t RHSSize = RHS.size(); + size_t CurSize = this->size(); + if (CurSize >= RHSSize) { + // Assign common elements. + iterator NewEnd = this->begin(); + if (RHSSize) + NewEnd = std::move(RHS.begin(), RHS.end(), NewEnd); + + // Destroy excess elements and trim the bounds. + this->destroy_range(NewEnd, this->end()); + this->set_size(RHSSize); + + // Clear the RHS. + RHS.clear(); + + return *this; + } + + // If we have to grow to have enough elements, destroy the current elements. + // This allows us to avoid copying them during the grow. + // FIXME: this may not actually make any sense if we can efficiently move + // elements. + if (this->capacity() < RHSSize) { + // Destroy current elements. + this->clear(); + CurSize = 0; + this->grow(RHSSize); + } else if (CurSize) { + // Otherwise, use assignment for the already-constructed elements. + std::move(RHS.begin(), RHS.begin() + CurSize, this->begin()); + } + + // Move-construct the new elements in place. + this->uninitialized_move( + RHS.begin() + CurSize, RHS.end(), this->begin() + CurSize); + + // Set end. + this->set_size(RHSSize); + + RHS.clear(); + return *this; +} + +/// Storage for the SmallVector elements. This is specialized for the N=0 case +/// to avoid allocating unnecessary storage. +template +struct SmallVectorStorage { + alignas(T) char InlineElts[N * sizeof(T)]; +}; + +/// We need the storage to be properly aligned even for small-size of 0 so that +/// the pointer math in \a SmallVectorTemplateCommon::getFirstEl() is +/// well-defined. +template +struct alignas(T) SmallVectorStorage {}; + +/// Forward declaration of SmallVector so that +/// calculateSmallVectorDefaultInlinedElements can reference +/// `sizeof(SmallVector)`. +template +class /* LLVM_GSL_OWNER */ SmallVector; + +/// Helper class for calculating the default number of inline elements for +/// `SmallVector`. +/// +/// This should be migrated to a constexpr function when our minimum +/// compiler support is enough for multi-statement constexpr functions. +template +struct CalculateSmallVectorDefaultInlinedElements { + // Parameter controlling the default number of inlined elements + // for `SmallVector`. + // + // The default number of inlined elements ensures that + // 1. There is at least one inlined element. + // 2. `sizeof(SmallVector) <= kPreferredSmallVectorSizeof` unless + // it contradicts 1. + static constexpr size_t kPreferredSmallVectorSizeof = 64; + + // static_assert that sizeof(T) is not "too big". + // + // Because our policy guarantees at least one inlined element, it is possible + // for an arbitrarily large inlined element to allocate an arbitrarily large + // amount of inline storage. 
We generally consider it an antipattern for a + // SmallVector to allocate an excessive amount of inline storage, so we want + // to call attention to these cases and make sure that users are making an + // intentional decision if they request a lot of inline storage. + // + // We want this assertion to trigger in pathological cases, but otherwise + // not be too easy to hit. To accomplish that, the cutoff is actually somewhat + // larger than kPreferredSmallVectorSizeof (otherwise, + // `SmallVector>` would be one easy way to trip it, and that + // pattern seems useful in practice). + // + // One wrinkle is that this assertion is in theory non-portable, since + // sizeof(T) is in general platform-dependent. However, we don't expect this + // to be much of an issue, because most LLVM development happens on 64-bit + // hosts, and therefore sizeof(T) is expected to *decrease* when compiled for + // 32-bit hosts, dodging the issue. The reverse situation, where development + // happens on a 32-bit host and then fails due to sizeof(T) *increasing* on a + // 64-bit host, is expected to be very rare. + static_assert( + sizeof(T) <= 256, + "You are trying to use a default number of inlined elements for " + "`SmallVector` but `sizeof(T)` is really big! Please use an " + "explicit number of inlined elements with `SmallVector` to make " + "sure you really want that much inline storage."); + + // Discount the size of the header itself when calculating the maximum inline + // bytes. + static constexpr size_t PreferredInlineBytes = + kPreferredSmallVectorSizeof - sizeof(SmallVector); + static constexpr size_t NumElementsThatFit = PreferredInlineBytes / sizeof(T); + static constexpr size_t value = + NumElementsThatFit == 0 ? 1 : NumElementsThatFit; +}; + +/// This is a 'vector' (really, a variable-sized array), optimized +/// for the case when the array is small. It contains some number of elements +/// in-place, which allows it to avoid heap allocation when the actual number of +/// elements is below that threshold. This allows normal "small" cases to be +/// fast without losing generality for large inputs. +/// +/// \note +/// In the absence of a well-motivated choice for the number of inlined +/// elements \p N, it is recommended to use \c SmallVector (that is, +/// omitting the \p N). This will choose a default number of inlined elements +/// reasonable for allocation on the stack (for example, trying to keep \c +/// sizeof(SmallVector) around 64 bytes). +/// +/// \warning This does not attempt to be exception safe. +/// +/// \see https://llvm.org/docs/ProgrammersManual.html#llvm-adt-smallvector-h +template < + typename T, + unsigned N = CalculateSmallVectorDefaultInlinedElements::value> +class /* LLVM_GSL_OWNER */ SmallVector : public SmallVectorImpl, + SmallVectorStorage { + public: + SmallVector() : SmallVectorImpl(N) {} + + ~SmallVector() { + // Destroy the constructed elements in the vector. + this->destroy_range(this->begin(), this->end()); + } + + explicit SmallVector(size_t Size, const T& Value = T()) + : SmallVectorImpl(N) { + this->assign(Size, Value); + } + + template < + typename ItTy, + typename = std::enable_if_t::iterator_category, + std::input_iterator_tag>>> + SmallVector(ItTy S, ItTy E) : SmallVectorImpl(N) { + this->append(S, E); + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. 
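+  // For example (illustrative, not part of the upstream header):
+  //
+  //   std::vector<int64_t> sizes = {2, 3, 4};
+  //   c10::SmallVector<int64_t, 5> shape(sizes);      // this Container&& constructor
+  //   c10::SmallVector<int64_t, 5> literal = {2, 3};  // initializer_list constructor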
+ template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + explicit SmallVector(Container&& c) : SmallVectorImpl(N) { + this->append(c.begin(), c.end()); + } + + SmallVector(std::initializer_list IL) : SmallVectorImpl(N) { + this->assign(IL); + } + + SmallVector(const SmallVector& RHS) : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(RHS); + } + + SmallVector& operator=(const SmallVector& RHS) { + SmallVectorImpl::operator=(RHS); + return *this; + } + + SmallVector(SmallVector&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) + : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(::std::move(RHS)); + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. + template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + SmallVector& operator=(const Container& RHS) { + this->assign(RHS.begin(), RHS.end()); + return *this; + } + + SmallVector(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) + : SmallVectorImpl(N) { + if (!RHS.empty()) + SmallVectorImpl::operator=(::std::move(RHS)); + } + + SmallVector& operator=(SmallVector&& RHS) noexcept( + std::is_nothrow_move_assignable_v>) { + SmallVectorImpl::operator=(::std::move(RHS)); + return *this; + } + + SmallVector& operator=(SmallVectorImpl&& RHS) noexcept( + std::is_nothrow_move_constructible_v>) { + SmallVectorImpl::operator=(::std::move(RHS)); + return *this; + } + + // note: The enable_if restricts Container to types that have a .begin() and + // .end() that return valid input iterators. + template < + typename Container, + std::enable_if_t< + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .begin())>::iterator_category, + std::input_iterator_tag> && + std::is_convertible_v< + typename std::iterator_traits< + decltype(std::declval() + .end())>::iterator_category, + std::input_iterator_tag>, + int> = 0> + SmallVector& operator=(Container&& C) { + this->assign(C.begin(), C.end()); + return *this; + } + + SmallVector& operator=(std::initializer_list IL) { + this->assign(IL); + return *this; + } +}; + +template +inline size_t capacity_in_bytes(const SmallVector& X) { + return X.capacity_in_bytes(); +} + +template +std::ostream& operator<<(std::ostream& out, const SmallVector& list) { + int i = 0; + out << "["; + for (auto e : list) { + if (i++ > 0) + out << ", "; + out << e; + } + out << "]"; + return out; +} + +template +using ValueTypeFromRangeType = std::remove_const_t< + std::remove_reference_t()))>>; + +/// Given a range of type R, iterate the entire range and return a +/// SmallVector with elements of the vector. This is useful, for example, +/// when you want to iterate a range and then sort the results. 
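+/// An illustrative usage sketch (not part of the upstream header):
+///
+///   SmallVector<int, 4> v;          // inline storage for 4 elements, no heap yet
+///   for (int i = 0; i < 4; ++i)
+///     v.push_back(i);               // still inline
+///   v.push_back(4);                 // exceeds N = 4, so grow() reallocates on the heap
+///   int last = v.pop_back_val();    // returns the last element, then pops it
+///   auto copied = to_vector<4>(v);  // the helpers below copy any range into a
+///                                   // SmallVector with the requested inline size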
+template +SmallVector, Size> to_vector(R&& Range) { + return {std::begin(Range), std::end(Range)}; +} +template +SmallVector< + ValueTypeFromRangeType, + CalculateSmallVectorDefaultInlinedElements< + ValueTypeFromRangeType>::value> +to_vector(R&& Range) { + return {std::begin(Range), std::end(Range)}; +} + +} // end namespace c10 + +namespace std { + +/// Implement std::swap in terms of SmallVector swap. +template +inline void swap( + c10::SmallVectorImpl& LHS, + c10::SmallVectorImpl& RHS) noexcept { + LHS.swap(RHS); +} + +/// Implement std::swap in terms of SmallVector swap. +template +inline void swap( + c10::SmallVector& LHS, + c10::SmallVector& RHS) noexcept { + LHS.swap(RHS); +} + +} // end namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/StringUtil.h b/MLPY/Lib/site-packages/torch/include/c10/util/StringUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..35e6a30540e1d99521bb9deb8dd793f7f3a7e104 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/StringUtil.h @@ -0,0 +1,211 @@ +#ifndef C10_UTIL_STRINGUTIL_H_ +#define C10_UTIL_STRINGUTIL_H_ + +#include +#include +#include + +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") +#endif + +namespace c10 { + +namespace detail { + +// Obtains the base name from a full path. +C10_API std::string StripBasename(const std::string& full_path); + +C10_API std::string ExcludeFileExtension(const std::string& full_path); + +struct CompileTimeEmptyString { + operator const std::string&() const { + static const std::string empty_string_literal; + return empty_string_literal; + } + operator const char*() const { + return ""; + } +}; + +template +struct CanonicalizeStrTypes { + using type = const T&; +}; + +template +// NOLINTNEXTLINE(*c-arrays*) +struct CanonicalizeStrTypes { + using type = const char*; +}; + +inline std::ostream& _str(std::ostream& ss) { + return ss; +} + +template +inline std::ostream& _str(std::ostream& ss, const T& t) { + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + ss << t; + return ss; +} + +// Overloads of _str for wide types; forces narrowing. +C10_API std::ostream& _str(std::ostream& ss, const wchar_t* wCStr); +C10_API std::ostream& _str(std::ostream& ss, const wchar_t& wChar); +C10_API std::ostream& _str(std::ostream& ss, const std::wstring& wString); + +template <> +inline std::ostream& _str( + std::ostream& ss, + const CompileTimeEmptyString&) { + return ss; +} + +template +inline std::ostream& _str(std::ostream& ss, const T& t, const Args&... args) { + return _str(_str(ss, t), args...); +} + +template +struct _str_wrapper final { + static std::string call(const Args&... args) { + std::ostringstream ss; + _str(ss, args...); + return ss.str(); + } +}; + +// Specializations for already-a-string types. +template <> +struct _str_wrapper final { + // return by reference to avoid the binary size of a string copy + static const std::string& call(const std::string& str) { + return str; + } +}; + +template <> +struct _str_wrapper final { + static const char* call(const char* str) { + return str; + } +}; + +// For c10::str() with an empty argument list (which is common in our assert +// macros), we don't want to pay the binary size for constructing and +// destructing a stringstream or even constructing a string. 
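+// For reference, an illustrative sketch of how the public helpers below are
+// used (not part of the upstream header):
+//
+//   std::string msg = c10::str("expected ", 4, " dims, got ", 3);
+//   std::string csv = c10::Join(", ", std::vector<int>{1, 2, 3});  // "1, 2, 3"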
+template <> +struct _str_wrapper<> final { + static CompileTimeEmptyString call() { + return CompileTimeEmptyString(); + } +}; + +} // namespace detail + +// Convert a list of string-like arguments into a single string. +template +inline decltype(auto) str(const Args&... args) { + return detail::_str_wrapper< + typename detail::CanonicalizeStrTypes::type...>::call(args...); +} + +template +inline std::string Join(const std::string& delimiter, const Container& v) { + std::stringstream s; + int cnt = static_cast(v.size()) - 1; + for (auto i = v.begin(); i != v.end(); ++i, --cnt) { + s << (*i) << (cnt ? delimiter : ""); + } + return s.str(); +} + +// Replace all occurrences of "from" substring to "to" string. +// Returns number of replacements +size_t C10_API +ReplaceAll(std::string& s, c10::string_view from, c10::string_view to); + +/// Represents a location in source code (for debugging). +struct C10_API SourceLocation { + const char* function; + const char* file; + uint32_t line; +}; + +std::ostream& operator<<(std::ostream& out, const SourceLocation& loc); + +// unix isprint but insensitive to locale +inline static bool isPrint(char s) { + return s > 0x1f && s < 0x7f; +} + +inline void printQuotedString(std::ostream& stmt, const string_view str) { + stmt << "\""; + for (auto s : str) { + switch (s) { + case '\\': + stmt << "\\\\"; + break; + case '\'': + stmt << "\\'"; + break; + case '\"': + stmt << "\\\""; + break; + case '\a': + stmt << "\\a"; + break; + case '\b': + stmt << "\\b"; + break; + case '\f': + stmt << "\\f"; + break; + case '\n': + stmt << "\\n"; + break; + case '\r': + stmt << "\\r"; + break; + case '\t': + stmt << "\\t"; + break; + case '\v': + stmt << "\\v"; + break; + default: + if (isPrint(s)) { + stmt << s; + } else { + // C++ io has stateful formatting settings. Messing with + // them is probably worse than doing this manually. + // NOLINTNEXTLINE(*c-arrays*) + char buf[4] = "000"; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[2] += s % 8; + s /= 8; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[1] += s % 8; + s /= 8; + // NOLINTNEXTLINE(*narrowing-conversions) + buf[0] += s; + stmt << "\\" << buf; + } + break; + } + } + stmt << "\""; +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +#endif // C10_UTIL_STRINGUTIL_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Synchronized.h b/MLPY/Lib/site-packages/torch/include/c10/util/Synchronized.h new file mode 100644 index 0000000000000000000000000000000000000000..da39195c1b2441cb10be89d91eef2c189441a5d7 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Synchronized.h @@ -0,0 +1,61 @@ +#pragma once + +#include + +namespace c10 { + +/** + * A very simple Synchronization class for error-free use of data + * in a multi-threaded context. See folly/docs/Synchronized.md for + * the inspiration of this class. + * + * Full URL: + * https://github.com/facebook/folly/blob/main/folly/docs/Synchronized.md + * + * This class implements a small subset of the generic functionality + * implemented by folly:Synchronized. Specifically, only withLock + * is implemented here since it's the smallest possible API that is + * able to cover a large surface area of functionality offered by + * folly::Synchronized. 
+ */ +template +class Synchronized final { + mutable std::mutex mutex_; + T data_; + + public: + Synchronized() = default; + Synchronized(T const& data) : data_(data) {} + Synchronized(T&& data) : data_(std::move(data)) {} + + // Don't permit copy construction, move, assignment, or + // move assignment, since the underlying std::mutex + // isn't necessarily copyable/moveable. + Synchronized(Synchronized const&) = delete; + Synchronized(Synchronized&&) = delete; + Synchronized operator=(Synchronized const&) = delete; + Synchronized operator=(Synchronized&&) = delete; + + /** + * To use, call withLock with a callback that accepts T either + * by copy or by reference. Use the protected variable in the + * provided callback safely. + */ + template + auto withLock(CB&& cb) { + std::lock_guard guard(this->mutex_); + return std::forward(cb)(this->data_); + } + + /** + * To use, call withLock with a callback that accepts T either + * by copy or by const reference. Use the protected variable in + * the provided callback safely. + */ + template + auto withLock(CB&& cb) const { + std::lock_guard guard(this->mutex_); + return std::forward(cb)(this->data_); + } +}; +} // end namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocal.h b/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocal.h new file mode 100644 index 0000000000000000000000000000000000000000..f44f297500d2a4f31fc3598a5c31b012459f1474 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocal.h @@ -0,0 +1,153 @@ +#pragma once + +#include + +/** + * Android versions with libgnustl incorrectly handle thread_local C++ + * qualifier with composite types. NDK up to r17 version is affected. + * + * (A fix landed on Jun 4 2018: + * https://android-review.googlesource.com/c/toolchain/gcc/+/683601) + * + * In such cases, use c10::ThreadLocal wrapper + * which is `pthread_*` based with smart pointer semantics. + * + * In addition, convenient macro C10_DEFINE_TLS_static is available. + * To define static TLS variable of type std::string, do the following + * ``` + * C10_DEFINE_TLS_static(std::string, str_tls_); + * /////// + * { + * *str_tls_ = "abc"; + * assert(str_tls_->length(), 3); + * } + * ``` + * + * (see c10/test/util/ThreadLocal_test.cpp for more examples) + */ +#if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 +#define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE +#endif // defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 + +#endif // !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) +#include +#include +#include +#include +namespace c10 { + +/** + * @brief Temporary thread_local C++ qualifier replacement for Android + * based on `pthread_*`. + * To be used with composite types that provide default ctor. 
+ */ +template +class ThreadLocal { + public: + ThreadLocal() { + pthread_key_create( + &key_, [](void* buf) { delete static_cast(buf); }); + } + + ~ThreadLocal() { + if (void* current = pthread_getspecific(key_)) { + delete static_cast(current); + } + + pthread_key_delete(key_); + } + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal& operator=(const ThreadLocal&) = delete; + + Type& get() { + if (void* current = pthread_getspecific(key_)) { + return *static_cast(current); + } + + std::unique_ptr ptr = std::make_unique(); + if (0 == pthread_setspecific(key_, ptr.get())) { + return *ptr.release(); + } + + int err = errno; + TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + pthread_key_t key_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) static ::c10::ThreadLocal Name + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name + +#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +namespace c10 { + +/** + * @brief Default thread_local implementation for non-Android cases. + * To be used with composite types that provide default ctor. + */ +template +class ThreadLocal { + public: + using Accessor = Type* (*)(); + explicit ThreadLocal(Accessor accessor) : accessor_(accessor) {} + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal& operator=(const ThreadLocal&) = delete; + + Type& get() { + return *accessor_(); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + Accessor accessor_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) \ + static ::c10::ThreadLocal Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h b/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h new file mode 100644 index 0000000000000000000000000000000000000000..e9540dcd9783c9a265232837220451010cb6e7fc --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/ThreadLocalDebugInfo.h @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include + +namespace c10 { + +enum class C10_API_ENUM DebugInfoKind : uint8_t { + PRODUCER_INFO = 0, + MOBILE_RUNTIME_INFO, + PROFILER_STATE, + INFERENCE_CONTEXT, // for inference usage + PARAM_COMMS_INFO, + + TEST_INFO, // used only in tests + TEST_INFO_2, // used only in tests +}; + +class C10_API DebugInfoBase { + public: + DebugInfoBase() = default; + virtual ~DebugInfoBase() = default; +}; + +// Thread local debug information is propagated across the forward +// (including async fork tasks) and backward passes and is supposed +// to be utilized by the user's code to pass extra information from +// the higher layers (e.g. model id) down to the lower levels +// (e.g. 
to the operator observers used for debugging, logging, +// profiling, etc) +class C10_API ThreadLocalDebugInfo { + public: + static DebugInfoBase* get(DebugInfoKind kind); + + // Get current ThreadLocalDebugInfo + static std::shared_ptr current(); + + // Internal, use DebugInfoGuard/ThreadLocalStateGuard + static void _forceCurrentDebugInfo( + std::shared_ptr info); + + // Push debug info struct of a given kind + static void _push(DebugInfoKind kind, std::shared_ptr info); + // Pop debug info, throws in case the last pushed + // debug info is not of a given kind + static std::shared_ptr _pop(DebugInfoKind kind); + // Peek debug info, throws in case the last pushed debug info is not of the + // given kind + static std::shared_ptr _peek(DebugInfoKind kind); + + private: + std::shared_ptr info_; + DebugInfoKind kind_; + std::shared_ptr parent_info_; + + friend class DebugInfoGuard; +}; + +// DebugInfoGuard is used to set debug information, +// ThreadLocalDebugInfo is semantically immutable, the values are set +// through the scope-based guard object. +// Nested DebugInfoGuard adds/overrides existing values in the scope, +// restoring the original values after exiting the scope. +// Users can access the values through the ThreadLocalDebugInfo::get() call; +class C10_API DebugInfoGuard { + public: + DebugInfoGuard(DebugInfoKind kind, std::shared_ptr info); + + explicit DebugInfoGuard(std::shared_ptr info); + + ~DebugInfoGuard(); + + DebugInfoGuard(const DebugInfoGuard&) = delete; + DebugInfoGuard(DebugInfoGuard&&) = delete; + + private: + bool active_ = false; + std::shared_ptr prev_info_ = nullptr; +}; + +} // namespace c10 diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/Type.h b/MLPY/Lib/site-packages/torch/include/c10/util/Type.h new file mode 100644 index 0000000000000000000000000000000000000000..f2d73853bc3a6a5a573fc7710b362d75c6496d10 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/Type.h @@ -0,0 +1,30 @@ +#ifndef C10_UTIL_TYPE_H_ +#define C10_UTIL_TYPE_H_ + +#include +#include +#ifdef __GXX_RTTI +#include +#endif // __GXX_RTTI + +#include + +namespace c10 { + +/// Utility to demangle a C++ symbol name. +C10_API std::string demangle(const char* name); + +/// Returns the printable name of the type. 
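+/// For example (illustrative, not part of the upstream header):
+///
+///   c10::demangle(typeid(std::vector<float>).name())
+///       // yields something like "std::vector<float, std::allocator<float> >"
+///   c10::demangle_type<float>()   // "float"; with RTTI disabled it returns a
+///                                 // fixed placeholder string instead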
+template +inline const char* demangle_type() { +#ifdef __GXX_RTTI + static const auto& name = *(new std::string(demangle(typeid(T).name()))); + return name.c_str(); +#else // __GXX_RTTI + return "(RTTI disabled, cannot show name)"; +#endif // __GXX_RTTI +} + +} // namespace c10 + +#endif // C10_UTIL_TYPE_H_ diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/TypeCast.h b/MLPY/Lib/site-packages/torch/include/c10/util/TypeCast.h new file mode 100644 index 0000000000000000000000000000000000000000..beda372caa02f40da8869e5b9c9839bca4f9d0c5 --- /dev/null +++ b/MLPY/Lib/site-packages/torch/include/c10/util/TypeCast.h @@ -0,0 +1,169 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +template +struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + return src.real(); + } +}; + +// Note: deliberately ignores undefined behavior, consistent with NumPy. +// PyTorch's type conversions can cause a variety of undefined behavior, +// including float to integral overflow and signed to unsigned integer overflow. +// Some of this undefined behavior is addressed below. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + auto r = maybe_real::apply(src); + return static_cast(r); + } +}; + +// Partial template instantiation for casting to uint8. +// Note: Converting from negative float values to unsigned integer types is +// undefined behavior in C++, and current CPU and GPU compilers exhibit +// divergent behavior. Casting from negative float values to signed +// integer types and then to unsigned integer types is not undefined, +// however, so this cast improves the consistency of type conversions +// to uint8 across compilers. +// Further note: Type conversions across compilers still have other undefined +// and divergent behavior. 
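+// For example (illustrative, not part of the upstream header):
+//
+//   c10::convert<uint8_t>(-1.0f)   // 255 on every backend: the value is cast to
+//                                  // int64_t first and then to uint8_t, so the
+//                                  // result is defined rather than compiler-dependent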
+template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } +}; + +template <> +struct static_cast_with_inter_type, c10::BFloat16> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::BFloat16 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Float8_e5m2> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e5m2fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fn> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fn src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Half> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Half src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::complex> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::complex src) { + return static_cast>( + static_cast>(src)); + } +}; + +template +C10_HOST_DEVICE To convert(From f) { + return static_cast_with_inter_type::apply(f); +} + +// Define separately to avoid being inlined and prevent code-size bloat +C10_API void report_overflow(const char* name); + +template +To checked_convert(From f, const char* name) { + // Converting to bool can't overflow so we exclude this case from checking. + if (!std::is_same_v && overflows(f)) { + report_overflow(name); + } + return convert(f); +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +// Trigger tests for D25440771. TODO: Remove this line any time you want. 
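+// Illustrative usage sketch (not part of the upstream header): a hypothetical
+// helper exercising the conversion entry points above; example code only, not
+// an API of this file.
+inline void c10_typecast_usage_sketch() {
+  auto a = c10::convert<int64_t>(3.9f);  // 3: truncates like static_cast
+  auto b = c10::convert<float>(c10::complex<float>(2.0f, 5.0f));  // 2.0f: real part only
+  (void)a;
+  (void)b;
+  // checked_convert reports overflow instead of silently wrapping; the call
+  // below would invoke report_overflow("int8"):
+  // c10::checked_convert<int8_t>(1000, "int8");
+}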
diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/TypeIndex.h b/MLPY/Lib/site-packages/torch/include/c10/util/TypeIndex.h
new file mode 100644
index 0000000000000000000000000000000000000000..a36c7030cb915c4c617df537fd77785074430ec5
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/c10/util/TypeIndex.h
@@ -0,0 +1,196 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <c10/util/ConstexprCrc.h>
+#include <c10/util/IdWrapper.h>
+#include <c10/util/string_view.h>
+#include <cstdint>
+#include <ostream>
+#include <stdexcept>
+#include <string>
+
+namespace c10::util {
+
+// TODO Make it work for more compilers
+
+// Intel compiler works
+#if defined(__INTEL_COMPILER)
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
+#define C10_TYPENAME_CONSTEXPR
+
+// Clang works
+#elif defined(__clang__)
+
+// except for NVCC
+#if defined(__CUDACC__)
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
+#define C10_TYPENAME_CONSTEXPR
+#else
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
+#define C10_TYPENAME_CONSTEXPR constexpr
+#endif
+
+// Windows works
+#elif defined(_MSC_VER)
+
+// except for NVCC
+#if defined(__CUDACC__)
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
+#define C10_TYPENAME_CONSTEXPR
+#else
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
+#define C10_TYPENAME_CONSTEXPR constexpr
+#endif
+
+// GCC works
+#elif defined(__GNUC__)
+
+// except when gcc < 9
+#if (__GNUC__ < 9) || defined(__CUDACC__)
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0
+#define C10_TYPENAME_CONSTEXPR
+#else
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
+#define C10_TYPENAME_CONSTEXPR constexpr
+#endif
+
+// some other compiler we don't know about
+#else
+#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1
+#define C10_TYPENAME_CONSTEXPR constexpr
+#endif
+
+struct type_index final : IdWrapper<type_index, uint64_t> {
+  constexpr explicit type_index(uint64_t checksum) : IdWrapper(checksum) {}
+
+  // Allow usage in std::map / std::set
+  // TODO Disallow this and rather use std::unordered_map/set everywhere
+  friend constexpr bool operator<(type_index lhs, type_index rhs) noexcept {
+    return lhs.underlyingId() < rhs.underlyingId();
+  }
+
+  friend std::ostream& operator<<(std::ostream& stream, type_index typeId) {
+    return stream << typeId.underlyingId();
+  }
+};
+
+namespace detail {
+
+#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \
+    __GNUC__ < 5
+// Getting __PRETTY_FUNCTION__ at compile time only works with GCC >= 5
+#error "You're running a too old version of GCC. We need GCC 5 or later."
+#endif
+
+#if defined(__clang__) && __clang_major__ < 4
+// Getting __PRETTY_FUNCTION__ at compile time only works with Clang >= 4
+#error "You're running a too old version of Clang. We need Clang 4 or later."
+#endif
+
+inline constexpr string_view extract(
+    string_view prefix,
+    string_view suffix,
+    string_view str) {
+#if !defined(__CUDA_ARCH__) // CUDA doesn't like std::logic_error in device code
+  return (!str.starts_with(prefix) || !str.ends_with(suffix))
+      ? (throw std::logic_error("Invalid pattern"), string_view())
+      : str.substr(prefix.size(), str.size() - prefix.size() - suffix.size());
+#else
+  return str.substr(prefix.size(), str.size() - prefix.size() - suffix.size());
+#endif
+}
+
+template <typename T>
+inline C10_TYPENAME_CONSTEXPR c10::string_view fully_qualified_type_name_impl() {
+#if defined(_MSC_VER) && !defined(__clang__)
+#if defined(__NVCC__)
+  return extract(
+      "c10::basic_string_view<char> c10::util::detail::fully_qualified_type_name_impl<",
+      ">()",
+      __FUNCSIG__);
+#else
+  return extract(
+      "class c10::basic_string_view<char> __cdecl c10::util::detail::fully_qualified_type_name_impl<",
+      ">(void)",
+      __FUNCSIG__);
+#endif
+#elif defined(__clang__)
+  return extract(
+      "c10::string_view c10::util::detail::fully_qualified_type_name_impl() [T = ",
+      "]",
+      __PRETTY_FUNCTION__);
+#elif defined(__GNUC__)
+  return extract(
+#if C10_TYPENAME_SUPPORTS_CONSTEXPR
+      "constexpr c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ",
+#else
+      "c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ",
+#endif
+      "; c10::string_view = c10::basic_string_view<char>]",
+      __PRETTY_FUNCTION__);
+#endif
+}
+
+#if !defined(__CUDA_ARCH__)
+template <typename T>
+inline constexpr uint64_t type_index_impl() {
+// Idea: __PRETTY_FUNCTION__ (or __FUNCSIG__ on msvc) contains a qualified name
+// of this function, including its template parameter, i.e. including the
+// type we want an id for. We use this name and run crc64 on it to get a type
+// id.
+#if defined(_MSC_VER) && !defined(__clang__)
+  return crc64(__FUNCSIG__, sizeof(__FUNCSIG__)).checksum();
+#elif defined(__clang__)
+  return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum();
+#elif defined(__GNUC__)
+  return crc64(__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__)).checksum();
+#endif
+}
+#endif
+
+} // namespace detail
+
+template <typename T>
+inline constexpr type_index get_type_index() {
+#if !defined(__CUDA_ARCH__)
+  // To enforce that this is really computed at compile time, we pass the
+  // type index through std::integral_constant.
+  return type_index{std::integral_constant<
+      uint64_t,
+      detail::type_index_impl<std::decay_t<T>>()>::value};
+#else
+  // There's nothing in theory preventing us from running this on device code
+  // except for nvcc throwing a compiler error if we enable it.
+  return (abort(), type_index(0));
+#endif
+}
+
+#if !defined(TORCH_PEDANTIC)
+// Use precomputed hashsum for std::string
+// Needed to workaround ambiguity in class name resolution
+// into __PRETTY_FUNCTION__ when abovementioned class is defined in inlined
+// namespace.
+// In multi-ABI C++ library, `std::string` is an alias to
+// `std::__cxx11::basic_string<char>` which depending on compiler flags can be
+// resolved to `basic_string<char>` either in `std` namespace or in
+// `std::__cxx11` one (`__cxx11` is an inline namespace)
+template <>
+inline constexpr type_index get_type_index<std::string>() {
+  // hashsum for std::basic_string
+  return type_index{4193213214807308375ULL};
+}
+#endif
+
+template <typename T>
+inline C10_TYPENAME_CONSTEXPR string_view
+get_fully_qualified_type_name() noexcept {
+#if C10_TYPENAME_SUPPORTS_CONSTEXPR
+  constexpr
+#else
+  static
+#endif
+      string_view name = detail::fully_qualified_type_name_impl<T>();
+  return name;
+}
+} // namespace c10::util
+
+C10_DEFINE_HASH_FOR_IDWRAPPER(c10::util::type_index);
diff --git a/MLPY/Lib/site-packages/torch/include/c10/util/TypeList.h b/MLPY/Lib/site-packages/torch/include/c10/util/TypeList.h
new file mode 100644
index 0000000000000000000000000000000000000000..438477a73518f24a5a978fe164a1a8dcdcffd721
--- /dev/null
+++ b/MLPY/Lib/site-packages/torch/include/c10/util/TypeList.h
@@ -0,0 +1,515 @@
+#pragma once
+
+#include <c10/util/TypeTraits.h>
+#include <algorithm>
+#include <cstddef>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+namespace c10::guts {
+
+template <class... T>
+struct false_t : std::false_type {};
+template <template <class> class... T>
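// --- Editorial aside (not part of the vendored headers; TypeList.h continues
// below) --- A hedged sketch of how the TypeIndex.h utilities above are meant
// to be used, assuming a c10 build where <c10/util/TypeIndex.h> is on the
// include path; `MyType`, `example_type_ids`, and the variable names are
// illustrative only.
#include <c10/util/TypeIndex.h>

#include <iostream>
#include <string>

struct MyType {};

void example_type_ids() {
  // get_type_index() hashes __PRETTY_FUNCTION__ / __FUNCSIG__ with crc64 at
  // compile time, so the same type yields the same id in every translation
  // unit without relying on RTTI.
  constexpr auto id_a = c10::util::get_type_index<MyType>();
  constexpr auto id_b = c10::util::get_type_index<MyType>();
  std::cout << (id_a == id_b) << "\n"; // prints 1: same type, same index

  // std::string hits the precomputed-hashsum specialization above, which
  // sidesteps the libstdc++ inline-namespace (__cxx11) naming ambiguity.
  std::cout << c10::util::get_type_index<std::string>() << "\n";

  // The human-readable name is also available (constexpr where supported).
  auto name = c10::util::get_fully_qualified_type_name<MyType>();
  std::cout.write(name.data(), name.size());
  std::cout << "\n";
}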